From 36d50de50e30e92950070c3449b99d78143fb221 Mon Sep 17 00:00:00 2001 From: felix Date: Fri, 28 Mar 2025 00:04:31 +0800 Subject: [PATCH 001/443] ckmoe: change cmake; use smaller shape for i4 (#2027) * change cmake; use smaller shape for i4 * fix pki4 run * fix typo * fix runtime arch logic for moe_gemm2 example --------- Co-authored-by: coderfeli Co-authored-by: illsilin --- example/65_gemm_multiply_multiply/CMakeLists.txt | 4 ++-- .../65_gemm_multiply_multiply/moe_gemm1_xdl_pk_i4.cpp | 10 +++++----- .../65_gemm_multiply_multiply/moe_gemm2_xdl_pk_i4.cpp | 4 ++-- 3 files changed, 9 insertions(+), 9 deletions(-) diff --git a/example/65_gemm_multiply_multiply/CMakeLists.txt b/example/65_gemm_multiply_multiply/CMakeLists.txt index 38b42fefc4..95fd8bace8 100644 --- a/example/65_gemm_multiply_multiply/CMakeLists.txt +++ b/example/65_gemm_multiply_multiply/CMakeLists.txt @@ -3,14 +3,14 @@ add_example_executable(example_gemm_multiply_multiply_xdl_fp8_ab_scale gemm_mult add_example_executable(example_gemm_multiply_multiply_xdl_fp8_bpreshuffle gemm_multiply_multiply_xdl_fp8_bpreshuffle.cpp) add_example_executable(example_gemm_add_add_xdl_fp16 gemm_add_add_xdl_fp16.cpp) add_example_executable(example_gemm_multiply_multiply_xdl_int8 gemm_multiply_multiply_xdl_int8.cpp) -# add_example_executable(example_moe_gemm1_xdl_fp8 moe_gemm1_xdl_fp8.cpp) +add_example_executable(example_moe_gemm1_xdl_fp8 moe_gemm1_xdl_fp8.cpp) add_example_executable(example_moe_gemm2_xdl_fp8 moe_gemm2_xdl_fp8.cpp) list(APPEND gpu_list gfx942) set(target 0) foreach(gpu IN LISTS GPU_TARGETS) if(gpu IN_LIST gpu_list AND target EQUAL 0) - # add_example_executable(example_moe_gemm1_xdl_pk_i4 moe_gemm1_xdl_pk_i4.cpp) + add_example_executable(example_moe_gemm1_xdl_pk_i4 moe_gemm1_xdl_pk_i4.cpp) add_example_executable(example_moe_gemm2_xdl_pk_i4 moe_gemm2_xdl_pk_i4.cpp) set(target 1) endif() diff --git a/example/65_gemm_multiply_multiply/moe_gemm1_xdl_pk_i4.cpp 
b/example/65_gemm_multiply_multiply/moe_gemm1_xdl_pk_i4.cpp index 17f4cd8a3f..1102ce1054 100644 --- a/example/65_gemm_multiply_multiply/moe_gemm1_xdl_pk_i4.cpp +++ b/example/65_gemm_multiply_multiply/moe_gemm1_xdl_pk_i4.cpp @@ -191,14 +191,14 @@ int main(int argc, char* argv[]) // experts = 8 // per expert: // GEMM shape - ck::index_t N = 14336 * 2; - ck::index_t K = 4096; + ck::index_t N = 4096 * 2; + ck::index_t K = 6144; ck::index_t experts = 8; ck::index_t sorted_tile_num = 16; ck::index_t valid_tile_num = 13; ck::index_t sorted_size = sorted_tile_num * MPerBlock; ck::index_t valid_size = valid_tile_num * MPerBlock; - ck::index_t tokens = 64; + ck::index_t tokens = 644; ck::index_t topk = 2; if(argc == 1) @@ -440,8 +440,8 @@ int main(int argc, char* argv[]) b_element_op, cde_element_op); - if(!device_op.IsSupportedArgument(argument) || ck::get_device_name() != "gfx942" || - ck::get_device_name() != "gfx950") + if(!device_op.IsSupportedArgument(argument) || + !(ck::get_device_name() == "gfx942" || ck::get_device_name() == "gfx950")) { throw std::runtime_error( "wrong! device_gemm with the specified compilation parameters does " diff --git a/example/65_gemm_multiply_multiply/moe_gemm2_xdl_pk_i4.cpp b/example/65_gemm_multiply_multiply/moe_gemm2_xdl_pk_i4.cpp index 8441862004..528503a2c4 100644 --- a/example/65_gemm_multiply_multiply/moe_gemm2_xdl_pk_i4.cpp +++ b/example/65_gemm_multiply_multiply/moe_gemm2_xdl_pk_i4.cpp @@ -407,8 +407,8 @@ int main(int argc, char* argv[]) b_element_op, cde_element_op); - if(!device_op.IsSupportedArgument(argument) || ck::get_device_name() != "gfx942" || - ck::get_device_name() != "gfx950") + if(!device_op.IsSupportedArgument(argument) || + !(ck::get_device_name() == "gfx942" || ck::get_device_name() == "gfx950")) { throw std::runtime_error( "wrong! 
device_gemm with the specified compilation parameters does " From a426f673018465e057f19355c444ff1c0eb2ff35 Mon Sep 17 00:00:00 2001 From: spolifroni-amd Date: Thu, 27 Mar 2025 17:13:18 -0400 Subject: [PATCH 002/443] creation of install doc and refactor of doc in general (#1908) * creation of install doc and refactor of doc in general * updates based on review comments * updated based on review comments * updated readme and contributors markdown * added extra note to not use -j on its own * added note about smoke tests and regression tests * made changes as per Illia's feedback --------- Co-authored-by: Aviral Goel --- CONTRIBUTORS.md | 5 +- README.md | 6 +- .../Composable-Kernel-math.rst} | 11 +- .../Composable-Kernel-structure.rst | 29 +++ docs/conceptual/what-is-ck.rst | 41 ----- docs/index.rst | 23 +-- docs/install/Composable-Kernel-Docker.rst | 16 ++ docs/install/Composable-Kernel-install.rst | 72 ++++++++ .../Composable-Kernel-prerequisites.rst | 32 ++++ docs/install/dockerhub.rst | 101 ----------- ...st => Composable-Kernel-API-reference.rst} | 16 +- ...pper.rst => Composable-Kernel-wrapper.rst} | 13 +- docs/sphinx/_toc.yml.in | 40 +++-- docs/tutorial/Composable-Kernel-examples.rst | 40 +++++ docs/tutorial/tutorial_hello_world.rst | 165 ------------------ 15 files changed, 244 insertions(+), 366 deletions(-) rename docs/{reference/Supported_Primitives_Guide.rst => conceptual/Composable-Kernel-math.rst} (85%) create mode 100644 docs/conceptual/Composable-Kernel-structure.rst delete mode 100644 docs/conceptual/what-is-ck.rst create mode 100644 docs/install/Composable-Kernel-Docker.rst create mode 100644 docs/install/Composable-Kernel-install.rst create mode 100644 docs/install/Composable-Kernel-prerequisites.rst delete mode 100644 docs/install/dockerhub.rst rename docs/reference/{API_Reference_Guide.rst => Composable-Kernel-API-reference.rst} (79%) rename docs/reference/{wrapper.rst => Composable-Kernel-wrapper.rst} (88%) create mode 100644 
docs/tutorial/Composable-Kernel-examples.rst delete mode 100644 docs/tutorial/tutorial_hello_world.rst diff --git a/CONTRIBUTORS.md b/CONTRIBUTORS.md index 8ef5c2b726..0900b7a1f8 100644 --- a/CONTRIBUTORS.md +++ b/CONTRIBUTORS.md @@ -20,10 +20,11 @@ Tejash Shah, 2019-2020 Xiaoyan Zhou, 2020 [Jianfeng Yan](https://github.com/j4yan), 2021-2022 - +[Jun Liu](https://github.com/junliume), 2021-2024 ## Product Manager -[Jun Liu](https://github.com/junliume) +[John Afaganis](https://github.com/afagaj) + ## Contributors diff --git a/README.md b/README.md index c316a0a322..29d3d4e85a 100644 --- a/README.md +++ b/README.md @@ -104,6 +104,7 @@ Docker images are available on [DockerHub](https://hub.docker.com/r/rocm/composa ```bash make -j install ``` + **[See Note on -j](#notes)** ## Optional post-install steps @@ -146,7 +147,8 @@ Docker images are available on [DockerHub](https://hub.docker.com/r/rocm/composa python3 -m sphinx -T -E -b html -d _build/doctrees -D language=en . _build/html ``` -Note the `-j` option for building with multiple threads in parallel, which speeds up the build significantly. +### Notes +The `-j` option builds with multiple threads in parallel, which speeds up the build significantly. However, `-j` launches unlimited number of threads, which can cause the build to run out of memory and crash. On average, you should expect each thread to use ~2Gb of RAM. Depending on the number of CPU cores and the amount of RAM on your system, you may want to @@ -211,4 +213,4 @@ script/uninstall_precommit.sh ``` If you need to temporarily disable pre-commit hooks, you can add the `--no-verify` option to the -`git commit` command. +`git commit` command. 
\ No newline at end of file diff --git a/docs/reference/Supported_Primitives_Guide.rst b/docs/conceptual/Composable-Kernel-math.rst similarity index 85% rename from docs/reference/Supported_Primitives_Guide.rst rename to docs/conceptual/Composable-Kernel-math.rst index e24acf5656..1c21fd8a11 100644 --- a/docs/reference/Supported_Primitives_Guide.rst +++ b/docs/conceptual/Composable-Kernel-math.rst @@ -1,18 +1,15 @@ .. meta:: - :description: Composable Kernel documentation and API reference library - :keywords: composable kernel, CK, ROCm, API, documentation + :description: Composable Kernel mathematical basis + :keywords: composable kernel, CK, ROCm, API, mathematics, algorithm .. _supported-primitives: ******************************************************************** -Supported Primitives Guide +Composable Kernel mathematical basis ******************************************************************** -This document contains details of supported primitives in Composable Kernel (CK). In contrast to the API Reference Guide, the Supported Primitives Guide is an introduction to the math which underpins the algorithms implemented in CK. +This is an introduction to the math which underpins the algorithms implemented in Composable Kernel. ------------- -Softmax ------------- For vectors :math:`x^{(1)}, x^{(2)}, \ldots, x^{(T)}` of size :math:`B` you can decompose the softmax of concatenated :math:`x = [ x^{(1)}\ | \ \ldots \ | \ x^{(T)} ]` as, diff --git a/docs/conceptual/Composable-Kernel-structure.rst b/docs/conceptual/Composable-Kernel-structure.rst new file mode 100644 index 0000000000..43c3603b95 --- /dev/null +++ b/docs/conceptual/Composable-Kernel-structure.rst @@ -0,0 +1,29 @@ +.. meta:: + :description: Composable Kernel structure + :keywords: composable kernel, CK, ROCm, API, structure + +.. 
_what-is-ck: + +******************************************************************** +Composable Kernel structure +******************************************************************** + +The Composable Kernel library uses a tile-based programming model and tensor coordinate transformation to achieve performance portability and code maintainability. Tensor coordinate transformation is a complexity reduction technique for complex machine learning operators. + + +.. image:: ../data/ck_component.png + :alt: CK Components + + +The Composable Kernel library consists of four layers: + +* a templated tile operator layer +* a templated kernel and invoker layer +* an instantiated kernel and invoker layer +* a client API layer. + +A wrapper component is included to simplify tensor transform operations. + +.. image:: ../data/ck_layer.png + :alt: CK Layers + \ No newline at end of file diff --git a/docs/conceptual/what-is-ck.rst b/docs/conceptual/what-is-ck.rst deleted file mode 100644 index 36785fc6ca..0000000000 --- a/docs/conceptual/what-is-ck.rst +++ /dev/null @@ -1,41 +0,0 @@ -.. meta:: - :description: Composable Kernel documentation and API reference library - :keywords: composable kernel, CK, ROCm, API, documentation - -.. _what-is-ck: - -******************************************************************** -What is the Composable Kernel library -******************************************************************** - - -Methodology -=========== - -The Composable Kernel (CK) library provides a programming model for writing performance critical kernels for machine learning workloads across multiple architectures including GPUs and CPUs, through general purpose kernel languages like HIP C++. - -CK utilizes two concepts to achieve performance portability and code maintainability: - -* A tile-based programming model -* Algorithm complexity reduction for complex ML operators using an innovative technique called - "Tensor Coordinate Transformation". - -.. 
image:: ../data/ck_component.png - :alt: CK Components - - -Code Structure -============== - -The CK library is structured into 4 layers: - -* "Templated Tile Operators" layer -* "Templated Kernel and Invoker" layer -* "Instantiated Kernel and Invoker" layer -* "Client API" layer - -It also includes a simple wrapper component used to perform tensor transform operations more easily and with fewer lines of code. - -.. image:: ../data/ck_layer.png - :alt: CK Layers - \ No newline at end of file diff --git a/docs/index.rst b/docs/index.rst index 30ef672f84..82e4c48001 100644 --- a/docs/index.rst +++ b/docs/index.rst @@ -8,30 +8,33 @@ Composable Kernel User Guide ******************************************************************** -The Composable Kernel (CK) library provides a programming model for writing performance critical kernels for machine learning workloads across multiple architectures including GPUs and CPUs, through general purpose kernel languages like HIP C++. This document contains instructions for installing, using, and contributing to the Composable Kernel project. To learn more see :ref:`what-is-ck`. +The Composable Kernel library provides a programming model for writing performance critical kernels for machine learning workloads across multiple architectures including GPUs and CPUs, through general purpose kernel languages such as `HIP C++ `_. -The CK documentation is structured as follows: +The Composable Kernel repository is located at `https://github.com/ROCm/composable-kernel `_. .. grid:: 2 :gutter: 3 - .. grid-item-card:: Installation + .. grid-item-card:: Install - * :ref:`docker-hub` + * :doc:`Composable Kernel prerequisites <./install/Composable-Kernel-prerequisites>` + * :doc:`Build and install Composable Kernel <./install/Composable-Kernel-install>` + * :doc:`Build and install Composable Kernel on a Docker image <./install/Composable-Kernel-Docker>` .. 
grid-item-card:: Conceptual - * :ref:`what-is-ck` + * :doc:`Composable Kernel structure <./conceptual/Composable-Kernel-structure>` + * :doc:`Composable Kernel mathematical basis <./conceptual/Composable-Kernel-math>` - .. grid-item-card:: API reference + .. grid-item-card:: Tutorials + + * :doc:`Composable Kernel examples and tests <./tutorial/Composable-Kernel-examples>` + + .. grid-item-card:: Reference - * :ref:`supported-primitives` * :ref:`api-reference` * :ref:`wrapper` - .. grid-item-card:: Tutorial - - * :ref:`hello-world` To contribute to the documentation refer to `Contributing to ROCm `_. diff --git a/docs/install/Composable-Kernel-Docker.rst b/docs/install/Composable-Kernel-Docker.rst new file mode 100644 index 0000000000..d40cc2bff5 --- /dev/null +++ b/docs/install/Composable-Kernel-Docker.rst @@ -0,0 +1,16 @@ +.. meta:: + :description: Composable Kernel docker files + :keywords: composable kernel, CK, ROCm, API, docker + +.. _docker-hub: + +******************************************************************** +Composable Kernel Docker containers +******************************************************************** + +Docker images that include all the required prerequisites for building Composable Kernel are available on `Docker Hub `_. + +The images also contain `ROCm `_, `CMake `_, and the `ROCm LLVM compiler infrastructure `_. + +Composable Kernel Docker images are named according to their operating system and ROCm version. For example, a Docker image named ``ck_ub22.04_rocm6.3`` would correspond to an Ubuntu 22.04 image with ROCm 6.3. + diff --git a/docs/install/Composable-Kernel-install.rst b/docs/install/Composable-Kernel-install.rst new file mode 100644 index 0000000000..61b1fe0fcb --- /dev/null +++ b/docs/install/Composable-Kernel-install.rst @@ -0,0 +1,72 @@ +.. 
meta:: + :description: Composable Kernel build and install + :keywords: composable kernel, CK, ROCm, API, documentation, install + +****************************************************** +Building and installing Composable Kernel with CMake +****************************************************** + +Before you begin, clone the `Composable Kernel GitHub repository `_ and create a ``build`` directory in its root: + +.. code:: shell + + git clone https://github.com/ROCm/composable_kernel.git + cd composable_kernel + mkdir build + +Change directory to the ``build`` directory and generate the makefile using the ``cmake`` command. Two build options are required: + +* ``CMAKE_PREFIX_PATH``: The ROCm installation path. ROCm is installed in ``/opt/rocm`` by default. +* ``CMAKE_CXX_COMPILER``: The path to the Clang compiler. Clang is found at ``/opt/rocm/llvm/bin/clang++`` by default. + + +.. code:: shell + + cd build + cmake ../. -D CMAKE_PREFIX_PATH="/opt/rocm" -D CMAKE_CXX_COMPILER="/opt/rocm/llvm/bin/clang++" [-D [-D] ...] + + +Other build options are: + +* ``DISABLE_DL_KERNELS``: Set this to "ON" to not build deep learning (DL) and data parallel primitive (DPP) instances. + + .. note:: + + DL and DPP instances are useful on architectures that don't support XDL or WMMA. + +* ``CK_USE_FP8_ON_UNSUPPORTED_ARCH``: Set to ``ON`` to build FP8 data type instances on gfx90a without native FP8 support. +* ``GPU_TARGETS``: Target architectures. Target architectures in this list must all be different versions of the same architectures. Enclose the list of targets in quotation marks. Separate multiple targets with semicolons (``;``). For example, ``cmake -D GPU_TARGETS="gfx908;gfx90a"``. This option is required to build tests and examples. +* ``GPU_ARCHS``: Target architectures. Target architectures in this list are not limited to different versions of the same architectures. Enclose the list of targets in quotation marks. Separate multiple targets with semicolons (``;``). 
For example, ``cmake -D GPU_ARCHS="gfx908;gfx1100"``. +* ``CMAKE_BUILD_TYPE``: The build type. Can be ``None``, ``Release``, ``Debug``, ``RelWithDebInfo``, or ``MinSizeRel``. CMake will use ``Release`` by default. + +.. Note:: + + If neither ``GPU_TARGETS`` nor ``GPU_ARCHS`` is specified, Composable Kernel will be built for all targets supported by the compiler. + +Build Composable Kernel using the generated makefile. This will build the library, the examples, and the tests, and save them to ``bin``. + +.. code:: shell + + make -j20 + +The ``-j`` option speeds up the build by using multiple threads in parallel. For example, ``-j20`` uses twenty threads in parallel. On average, each thread will use 2GB of memory. Make sure that the number of threads you use doesn't exceed the available memory in your system. + +Using ``-j`` alone will launch an unlimited number of threads and is not recommended. + +Install the Composable Kernel library: + +.. code:: shell + + make install + +After running ``make install``, the Composable Kernel files will be saved to the following locations: + +* Library files: ``/opt/rocm/lib/`` +* Header files: ``/opt/rocm/include/ck/`` and ``/opt/rocm/include/ck_tile/`` +* Examples, tests, and ckProfiler: ``/opt/rocm/bin/`` + +For information about ckProfiler, see `the ckProfiler readme file `_. + +For information about running the examples and tests, see :doc:`Composable Kernel examples and tests <../tutorial/Composable-Kernel-examples>`. + + diff --git a/docs/install/Composable-Kernel-prerequisites.rst b/docs/install/Composable-Kernel-prerequisites.rst new file mode 100644 index 0000000000..10be849ea6 --- /dev/null +++ b/docs/install/Composable-Kernel-prerequisites.rst @@ -0,0 +1,32 @@ +.. 
meta:: + :description: Composable Kernel prerequisites + :keywords: composable kernel, CK, ROCm, API, documentation, prerequisites + +****************************************************** +Composable Kernel prerequisites +****************************************************** + +Docker images that include all the required prerequisites for building Composable Kernel are available on `Docker Hub `_. + +The following prerequisites are required to build and install Composable Kernel: + +* cmake +* hip-rocclr +* iputils-ping +* jq +* libelf-dev +* libncurses5-dev +* libnuma-dev +* libpthread-stubs0-dev +* llvm-amdgpu +* mpich +* net-tools +* python3 +* python3-dev +* python3-pip +* redis +* rocm-llvm-dev +* zlib1g-dev +* libzstd-dev +* openssh-server +* clang-format-12 diff --git a/docs/install/dockerhub.rst b/docs/install/dockerhub.rst deleted file mode 100644 index 87eb5a4f81..0000000000 --- a/docs/install/dockerhub.rst +++ /dev/null @@ -1,101 +0,0 @@ -.. meta:: - :description: Composable Kernel documentation and API reference library - :keywords: composable kernel, CK, ROCm, API, documentation - -.. _docker-hub: - -******************************************************************** -CK Docker Hub -******************************************************************** - -Why do I need this? -=================== - -To make things simpler, and bring Composable Kernel and its dependencies together, -docker images can be found on `Docker Hub `_. Docker images provide a complete image of the OS, the Composable Kernel library, and its dependencies in a single downloadable file. - -Refer to `Docker Overview `_ for more information on Docker images and containers. - -Which image is right for me? -============================ - -The image naming includes information related to the docker image. 
-For example ``ck_ub20.04_rocm6.0`` indicates the following: - -* ``ck`` - made for running Composable Kernel; -* ``ub20.04`` - based on Ubuntu 20.04; -* ``rocm6.0`` - ROCm platform version 6.0. - -Download a docker image suitable for your OS and ROCm release, run or start the docker container, and then resume the tutorial from this point. Use the ``docker pull`` command to download the file:: - - docker pull rocm/composable_kernel:ck_ub20.04_rocm6.0 - - -What is inside the image? -------------------------- - -The docker images have everything you need for running CK including: - -* `ROCm `_ -* `CMake `_ -* `Compiler `_ -* `Composable Kernel library `_ - -Running the docker container -============================ - -After downloading the docker image, you can start the container using one of a number of commands. Start with the ``docker run`` command as shown below:: - - docker run \ - -it \ - --privileged \ - --group-add sudo \ - -w /root/workspace \ - -v ${PATH_TO_LOCAL_WORKSPACE}:/root/workspace \ - rocm/composable_kernel:ck_ub20.04_rocm6.0 \ - /bin/bash - -After starting the bash shell, the docker container current folder is `~/workspace`. The library path is ``~/workspace/composable_kernel``. Navigate to the library to begin the tutorial as explained in :ref:`hello-world`: - -.. note:: - - If your current folder is different from `${HOME}`, adjust the line ``-v ${HOME}:/root/workspace`` in the ``docker run`` command to fit your folder structure. - -Stop and restart the docker image -================================= - -After finishing the tutorial, or just when you have completed your work session, you can close the docker container, or stop the docker container to restart it at another time. Closing the docker container means that it is still in the active state, and can be resumed from where you left it. Stopping the container closes it, and returns the image to its initial state. 
- -Use the ``Ctrl-D`` option to exit the container, while leaving it active, so you can return to the container in its current state to resume the tutorial, or pickup your project where you left off. - -To restart the active container use the ``docker exec`` command to specify the container name and options as follows:: - - docker exec -it bash - -Where: - -* `exec` is the docker command -* `-it` is the interactive option for `exec` -* `` specifies an active container on the system -* `bash` specifies the command to run in the interactive shell - -.. note:: - - You can use the ``docker container ls`` command to list the active containers on the system. - -To start a container from the image, use the ``docker start`` command:: - - docker start - -Then use the docker exec command as shown above to start the bash shell. - -Use the ``docker stop`` command to stop the container and restore the image to its initial state:: - - docker stop - -Editing the docker image -======================= - -If you want to customize the docker image, edit the -`Dockerfile `_ -from the GitHub repository to suit your needs. diff --git a/docs/reference/API_Reference_Guide.rst b/docs/reference/Composable-Kernel-API-reference.rst similarity index 79% rename from docs/reference/API_Reference_Guide.rst rename to docs/reference/Composable-Kernel-API-reference.rst index 0d2d41c1eb..b6ee9f7790 100644 --- a/docs/reference/API_Reference_Guide.rst +++ b/docs/reference/Composable-Kernel-API-reference.rst @@ -5,26 +5,20 @@ .. _api-reference: ******************************************************************** -API reference guide +Composable Kernel API reference guide ******************************************************************** - -This document contains details of the APIs for the Composable Kernel (CK) library and introduces -some of the key design principles that are used to write new classes that extend CK functionality. 
+This document contains details of the APIs for the Composable Kernel library and introduces some of the key design principles that are used to write new classes that extend the functionality of the Composable Kernel library. ================= -CK Datatypes -================= - ------------------ DeviceMem ------------------ +================= .. doxygenstruct:: DeviceMem ---------------------------- +============================= Kernels For Flashattention ---------------------------- +============================= The Flashattention algorithm is defined in :cite:t:`dao2022flashattention`. This section lists the classes that are used in the CK GPU implementation of Flashattention. diff --git a/docs/reference/wrapper.rst b/docs/reference/Composable-Kernel-wrapper.rst similarity index 88% rename from docs/reference/wrapper.rst rename to docs/reference/Composable-Kernel-wrapper.rst index 190fbcd445..4baa8d2b64 100644 --- a/docs/reference/wrapper.rst +++ b/docs/reference/Composable-Kernel-wrapper.rst @@ -1,20 +1,15 @@ .. meta:: - :description: Composable Kernel documentation and API reference library - :keywords: composable kernel, CK, ROCm, API, documentation + :description: Composable Kernel wrapper + :keywords: composable kernel, CK, ROCm, API, wrapper .. _wrapper: ******************************************************************** -Wrapper +Composable Kernel wrapper ******************************************************************** -------------------------------------- -Description -------------------------------------- - -The CK library provides a lightweight wrapper for more complex operations implemented in -the library. +The Composable Kernel library provides a lightweight wrapper to simplify the more complex operations. 
Example: diff --git a/docs/sphinx/_toc.yml.in b/docs/sphinx/_toc.yml.in index 533b81cd39..ab82b7deb1 100644 --- a/docs/sphinx/_toc.yml.in +++ b/docs/sphinx/_toc.yml.in @@ -3,34 +3,38 @@ defaults: root: index subtrees: -- caption: Conceptual - entries: - - file: conceptual/what-is-ck.rst - title: What is Composable Kernel? - - caption: Install entries: - - file: install/dockerhub.rst - title: Docker Hub - -- caption: CK API Reference + - file: install/Composable-Kernel-prerequisites.rst + title: Composable Kernel prerequisites + - file: install/Composable-Kernel-install.rst + title: Build and install Composable Kernel + - file: install/Composable-Kernel-Docker.rst + title: Composable Kernel Docker images + +- caption: Conceptual entries: - - file: reference/Supported_Primitives_Guide.rst - title: Supported Primitives - - file: reference/API_Reference_Guide.rst - title: API Reference - - file: reference/wrapper.rst - title: Wrapper + - file: conceptual/Composable-Kernel-structure.rst + title: Composable Kernel structure + - file: conceptual/Composable-Kernel-math.rst + title: Composable Kernel mathematical basis - caption: Tutorial entries: - - file: tutorial/tutorial_hello_world.rst - title: Hello World Tutorial + - file: tutorial/Composable-Kernel-examples.rst + title: Composable Kernel examples + +- caption: Reference + entries: + - file: reference/Composable-Kernel-API-reference.rst + title: Composable Kernel API reference + - file: reference/Composable-Kernel-wrapper.rst + title: Composable Kernel Wrapper - caption: About entries: - file: Contributors_Guide.rst - title: Contributing to CK + title: Contributing to Composable Kernel - file: license.rst title: License \ No newline at end of file diff --git a/docs/tutorial/Composable-Kernel-examples.rst b/docs/tutorial/Composable-Kernel-examples.rst new file mode 100644 index 0000000000..62422d6f15 --- /dev/null +++ b/docs/tutorial/Composable-Kernel-examples.rst @@ -0,0 +1,40 @@ +.. 
meta:: + :description: Composable Kernel examples and tests + :keywords: composable kernel, CK, ROCm, API, examples, tests + +******************************************************************** +Composable Kernel examples and tests +******************************************************************** + +After :doc:`building and installing Composable Kernel <../install/Composable-Kernel-install>`, the examples and tests will be moved to ``/opt/rocm/bin/``. + +All tests have the prefix ``test`` and all examples have the prefix ``example``. + +Use ``ctest`` with no arguments to run all examples and tests, or use ``ctest -R`` to run a single test. For example: + +.. code:: shell + + ctest -R test_gemm_fp16 + +Examples can be run individually as well. For example: + +.. code:: shell + + ./bin/example_gemm_xdl_fp16 1 1 1 + +For instructions on how to run individual examples and tests, see their README files in the |example|_ and |test|_ GitHub folders. + +To run smoke tests, use ``make smoke``. + +To run regression tests, use ``make regression``. + +In general, tests that run for under thirty seconds are included in the smoke tests and tests that run for over thirty seconds are included in the regression tests. + +.. |example| replace:: ``example`` +.. _example: https://github.com/ROCm/composable_kernel/tree/develop/example + +.. |client_example| replace:: ``client_example`` +.. _client_example: https://github.com/ROCm/composable_kernel/tree/develop/client_example + +.. |test| replace:: ``test`` +.. _test: https://github.com/ROCm/composable_kernel/tree/develop/test \ No newline at end of file diff --git a/docs/tutorial/tutorial_hello_world.rst b/docs/tutorial/tutorial_hello_world.rst deleted file mode 100644 index c31460785b..0000000000 --- a/docs/tutorial/tutorial_hello_world.rst +++ /dev/null @@ -1,165 +0,0 @@ -.. meta:: - :description: Composable Kernel documentation and API reference library - :keywords: composable kernel, CK, ROCm, API, documentation - -.. 
_hello-world: - -******************************************************************** -Hello World Tutorial -******************************************************************** - -This tutorial is for engineers dealing with artificial intelligence and machine learning who -would like to optimize pipelines and improve performance using the Composable -Kernel (CK) library. This tutorial provides an introduction to the CK library. You will build the library and run some examples using a "Hello World" example. - -Description -=========== - -Modern AI technology solves more and more problems in a variety of fields, but crafting fast and -efficient workflows is still challenging. CK can make the AI workflow fast -and efficient. CK is a collection of optimized AI operator kernels with tools to create -new kernels. The library has components required for modern neural network architectures -including matrix multiplication, convolution, contraction, reduction, attention modules, a variety of activation functions, and fused operators. - -CK library acceleration features are based on: - -* Layered structure -* Tile-based computation model -* Tensor coordinate transformation -* Hardware acceleration use -* Support of low precision data types including fp16, bf16, int8 and int4 - -If you need more technical details and benchmarking results read the following -`blog post `_. - -To download the library visit the `composable_kernel repository `_. - -Hardware targets -================ - -CK library fully supports `gfx908` and `gfx90a` GPU architectures, while only some operators are -supported for `gfx1030` devices. Check your hardware to determine the target GPU architecture. 
- -========== ========= -GPU Target AMD GPU -========== ========= -gfx908 Radeon Instinct MI100 -gfx90a Radeon Instinct MI210, MI250, MI250X -gfx1030 Radeon PRO V620, W6800, W6800X, W6800X Duo, W6900X, RX 6800, RX 6800 XT, RX 6900 XT, RX 6900 XTX, RX 6950 XT -========== ========= - -There are also `cloud options `_ you can find if -you don't have an AMD GPU at hand. - -Build the library -================= - -This tutorial is based on the use of docker images as explained in :ref:`docker-hub`. Download a docker image suitable for your OS and ROCm release, run or start the docker container, and then resume the tutorial from this point. - -.. note:: - - You can also `install ROCm `_ on your system, clone the `Composable Kernel repository `_ on GitHub, and use that to build and run the examples using the commands described below. - -Both the docker container and GitHub repository include the Composable Kernel library. Navigate to the library:: - - cd composable_kernel/ - -Create and change to a ``build`` directory:: - - mkdir build && cd build - -The previous section discussed supported GPU architecture. Once you decide which hardware targets are needed, run CMake using the ``GPU_TARGETS`` flag:: - - cmake \ - -D CMAKE_PREFIX_PATH=/opt/rocm \ - -D CMAKE_CXX_COMPILER=/opt/rocm/bin/hipcc \ - -D CMAKE_CXX_FLAGS="-O3" \ - -D CMAKE_BUILD_TYPE=Release \ - -D BUILD_DEV=OFF \ - -D GPU_TARGETS="gfx908;gfx90a;gfx1030" .. 
- -If everything goes well the CMake command will return:: - - -- Configuring done - -- Generating done - -- Build files have been written to: "/root/workspace/composable_kernel/build" - -Finally, you can build examples and tests:: - - make -j examples tests - -When complete you should see:: - - Scanning dependencies of target tests - [100%] Built target tests - -Run examples and tests -====================== - -Examples are listed as test cases as well, so you can run all examples and tests with:: - - ctest - -You can check the list of all tests by running:: - - ctest -N - -You can also run examples separately as shown in the following example execution:: - - ./bin/example_gemm_xdl_fp16 1 1 1 - -The arguments ``1 1 1`` mean that you want to run this example in the mode: verify results with CPU, initialize matrices with integers, and benchmark the kernel execution. You can play around with these parameters and see how output and execution results change. - -If you have a device based on `gfx908` or `gfx90a` architecture, and if the example runs as expected, you should see something like:: - - a_m_k: dim 2, lengths {3840, 4096}, strides {4096, 1} - b_k_n: dim 2, lengths {4096, 4096}, strides {4096, 1} - c_m_n: dim 2, lengths {3840, 4096}, strides {4096, 1} - Perf: 1.08153 ms, 119.136 TFlops, 89.1972 GB/s, DeviceGemm_Xdl_CShuffle LoopScheduler: Interwave, PipelineVersion: v1 - -However, running it on a `gfx1030` device should result in the following:: - - a_m_k: dim 2, lengths {3840, 4096}, strides {4096, 1} - b_k_n: dim 2, lengths {4096, 4096}, strides {1, 4096} - c_m_n: dim 2, lengths {3840, 4096}, strides {4096, 1} - DeviceGemmXdl<256, 256, 128, 4, 8, 32, 32, 4, 2> NumPrefetch: 1, LoopScheduler: Default, PipelineVersion: v1 does not support this problem - -Don't worry, some operators are supported on `gfx1030` architecture, so you can run a -separate example like:: - - ./bin/example_gemm_dl_fp16 1 1 1 - -and it should return something like:: - - a_m_k: dim 2, 
lengths {3840, 4096}, strides {1, 4096} - b_k_n: dim 2, lengths {4096, 4096}, strides {4096, 1} - c_m_n: dim 2, lengths {3840, 4096}, strides {4096, 1} - arg.a_grid_desc_k0_m0_m1_k1_{2048, 3840, 2} - arg.b_grid_desc_k0_n0_n1_k1_{2048, 4096, 2} - arg.c_grid_desc_m_n_{ 3840, 4096} - launch_and_time_kernel: grid_dim {960, 1, 1}, block_dim {256, 1, 1} - Warm up 1 time - Start running 10 times... - Perf: 3.65695 ms, 35.234 TFlops, 26.3797 GB/s, DeviceGemmDl<256, 128, 128, 16, 2, 4, 4, 1> - -.. note:: - - A new CMake flag ``DL_KERNELS`` has been added to the latest versions of CK. If you do not see the above results when running ``example_gemm_dl_fp16``, you might need to add ``-D DL_KERNELS=ON`` to your CMake command to build the operators supported on the `gfx1030` architecture. - -You can also run a separate test:: - - ctest -R test_gemm_fp16 - -If everything goes well you should see something like:: - - Start 121: test_gemm_fp16 - 1/1 Test #121: test_gemm_fp16 ................... Passed 51.81 sec - - 100% tests passed, 0 tests failed out of 1 - -Summary -======= - -In this tutorial you took the first look at the Composable Kernel library, built it on your system and ran some examples and tests. In the next tutorial you will run kernels with different configurations to find out the best one for your hardware and task. - -P.S.: If you are running on a cloud instance, don't forget to switch off the cloud instance. 
From d142e15f5e18f9c9cfa66d1de6479d8f2583827d Mon Sep 17 00:00:00 2001 From: Illia Silin <98187287+illsilin@users.noreply.github.com> Date: Thu, 27 Mar 2025 18:48:47 -0700 Subject: [PATCH 003/443] add gfx950 to default targets for rocm6.4+ (#2032) --- CMakeLists.txt | 4 +++- 1 file changed, 3 insertions(+), 1 deletion(-) diff --git a/CMakeLists.txt b/CMakeLists.txt index bb0c254e06..4c1ca789f5 100644 --- a/CMakeLists.txt +++ b/CMakeLists.txt @@ -167,8 +167,10 @@ if(NOT ENABLE_ASAN_PACKAGING) if(NOT WIN32 AND ${hip_VERSION_FLAT} LESS 600300000) # WORKAROUND: compiler does not yet fully support gfx12 targets, need to fix version above set(CK_GPU_TARGETS "gfx908;gfx90a;gfx942;gfx1030;gfx1100;gfx1101;gfx1102") - else() + elseif(NOT WIN32 AND ${hip_VERSION_FLAT} GREATER_EQUAL 600300000 AND ${hip_VERSION_FLAT} LESS 600400000) set(CK_GPU_TARGETS "gfx908;gfx90a;gfx942;gfx1030;gfx1100;gfx1101;gfx1102;gfx1200;gfx1201") + elseif(NOT WIN32 AND ${hip_VERSION_FLAT} GREATER_EQUAL 600400000) + set(CK_GPU_TARGETS "gfx908;gfx90a;gfx942;gfx1030;gfx1100;gfx1101;gfx1102;gfx1200;gfx1201;gfx950") endif() else() #build CK only for xnack-supported targets when using ASAN From a82f338fb9fb5743f071c5e6831c3dd92fcd7982 Mon Sep 17 00:00:00 2001 From: felix Date: Fri, 28 Mar 2025 11:31:52 +0800 Subject: [PATCH 004/443] hotfix fix sorting int64 (#2025) * fix sorting int64 * clang format * fix example issue * update WA issue # --------- Co-authored-by: coderfeli Co-authored-by: carlushuang --- .../ck_tile/13_moe_sorting/moe_sorting.cpp | 4 +-- .../15_fused_moe/instances/fused_moe_api.cpp | 29 ++++++++++--------- include/ck_tile/core/config.hpp | 4 +++ .../fused_moe/kernel/moe_sorting_kernel.hpp | 18 ++++++++---- 4 files changed, 33 insertions(+), 22 deletions(-) diff --git a/example/ck_tile/13_moe_sorting/moe_sorting.cpp b/example/ck_tile/13_moe_sorting/moe_sorting.cpp index f00d948f25..e59fcaedad 100644 --- a/example/ck_tile/13_moe_sorting/moe_sorting.cpp +++ 
b/example/ck_tile/13_moe_sorting/moe_sorting.cpp @@ -74,7 +74,7 @@ bool test_moe_sorting(ck_tile::ArgParser args) int topk = args.get_int("k"); int seed = args.get_int("seed"); int unit_size = args.get_int("unit"); - int moe_buf_size = args.get_int("moe_buf_size"); + int64_t moe_buf_size = static_cast(args.get_uint64("moe_buf_size")); int kname = args.get_int("kname"); int warmup = args.get_int("warmup"); int repeat = args.get_int("repeat"); @@ -175,7 +175,7 @@ bool test_moe_sorting(ck_tile::ArgParser args) unit_size, num_experts, topk, - static_cast(moe_buf_size * sizeof(float))}; + static_cast(moe_buf_size * sizeof(float))}; ck_tile::stream_config sc{nullptr, true, diff --git a/example/ck_tile/15_fused_moe/instances/fused_moe_api.cpp b/example/ck_tile/15_fused_moe/instances/fused_moe_api.cpp index 466420f066..f887d57aa9 100644 --- a/example/ck_tile/15_fused_moe/instances/fused_moe_api.cpp +++ b/example/ck_tile/15_fused_moe/instances/fused_moe_api.cpp @@ -19,20 +19,21 @@ float fused_moe(fused_moe_traits t, fused_moe_args a, const ck_tile::stream_conf auto t0 = fused_moesorting_trait{"int32", "fp32", t.local_expert_masking}; auto a0 = fused_moesorting_args{ - a.topk_ids_ptr, // const void* p_topk_ids; - a.topk_weight_ptr, // const void* p_weights; - a.local_expert_mask_ptr, // const void* p_local_expert_mask; - a.sorted_token_ids_ptr, // void* p_sorted_token_ids; - a.sorted_weight_ptr, // void* p_sorted_weights; - a.sorted_expert_ids_ptr, // void* p_sorted_expert_ids; - a.num_sorted_tiles_ptr, // void* p_total_tokens_post_pad; - a.o_ptr, // void* p_moe_buf; - a.ws_ptr, // void* p_ws; - a.num_tokens, // index_t tokens; - a.block_m, // index_t unit_size; - a.num_experts, // index_t num_experts; - a.topk, // index_t topk; - a.num_tokens * a.stride_token * o_data_bytes // index_t moe_buf_bytes; + a.topk_ids_ptr, // const void* p_topk_ids; + a.topk_weight_ptr, // const void* p_weights; + a.local_expert_mask_ptr, // const void* p_local_expert_mask; + 
a.sorted_token_ids_ptr, // void* p_sorted_token_ids; + a.sorted_weight_ptr, // void* p_sorted_weights; + a.sorted_expert_ids_ptr, // void* p_sorted_expert_ids; + a.num_sorted_tiles_ptr, // void* p_total_tokens_post_pad; + a.o_ptr, // void* p_moe_buf; + a.ws_ptr, // void* p_ws; + a.num_tokens, // index_t tokens; + a.block_m, // index_t unit_size; + a.num_experts, // index_t num_experts; + a.topk, // index_t topk; + static_cast(a.num_tokens) * a.stride_token * + o_data_bytes // index_t moe_buf_bytes; }; auto t1 = fused_moegemm_traits{t.prec_i, diff --git a/include/ck_tile/core/config.hpp b/include/ck_tile/core/config.hpp index eeaf0dca6f..b1d201e30e 100644 --- a/include/ck_tile/core/config.hpp +++ b/include/ck_tile/core/config.hpp @@ -260,3 +260,7 @@ CK_TILE_DECLARE_ENV_VAR_BOOL(CK_TILE_LOGGING) #define CK_TILE_USE_BUFFER_ADDRESSING_BUILTIN 0 #endif #endif + +#ifndef CK_TILE_WA_ISSUE_2028 +#define CK_TILE_WA_ISSUE_2028 1 +#endif diff --git a/include/ck_tile/ops/fused_moe/kernel/moe_sorting_kernel.hpp b/include/ck_tile/ops/fused_moe/kernel/moe_sorting_kernel.hpp index a1410d1f4f..6a7ccd2472 100644 --- a/include/ck_tile/ops/fused_moe/kernel/moe_sorting_kernel.hpp +++ b/include/ck_tile/ops/fused_moe/kernel/moe_sorting_kernel.hpp @@ -192,7 +192,7 @@ struct MoeSortingHostArgs index_t unit_size; // this is the M_a of fused-moe kernel index_t num_experts; index_t topk; - index_t moe_buf_bytes; // byte size of p_moe_buf + long_index_t moe_buf_bytes; // byte size of p_moe_buf }; template @@ -219,7 +219,7 @@ struct MoeSortingKernel void* p_moe_buf; index_t tokens; index_t num_experts; - index_t moe_buf_bytes; + long_index_t moe_buf_bytes; index_t tokens_per_thread; index_t smem_rows; @@ -426,7 +426,7 @@ struct MoeSortingKernel return row * total_col + col; } - CK_TILE_DEVICE void moe_buf_set_zero_kernel(uint8x16_t* buf, index_t buf_bytes) const + CK_TILE_DEVICE void moe_buf_set_zero_kernel(uint8x16_t* buf, long_index_t buf_bytes) const { const index_t offset = (blockIdx.x - 1) 
* blockDim.x + threadIdx.x; if(offset < buf_bytes / 16) @@ -1218,10 +1218,10 @@ CK_TILE_DEVICE void moe_sorting_wave_cumsum(data_t& thread_data) } template -CK_TILE_DEVICE void moe_buf_set_zero_kernel(uint8x16_t* buf, index_t buf_bytes, index_t gid) +CK_TILE_DEVICE void moe_buf_set_zero_kernel(uint8x16_t* buf, long_index_t buf_bytes, index_t gid) { // const index_t offset = (blockIdx.x - 1) * BLOCK_SIZE + threadIdx.x; - index_t offset = gid * BLOCK_SIZE + threadIdx.x; + long_index_t offset = static_cast(gid) * BLOCK_SIZE + threadIdx.x; if(offset < buf_bytes / 16) { buf[offset] = uint8x16_t{0}; @@ -1233,6 +1233,12 @@ CK_TILE_DEVICE void moe_buf_set_zero_kernel(uint8x16_t* buf, index_t buf_bytes, // prefer to run mp kernel if is not oneshot CK_TILE_HOST bool moe_sorting_is_oneshot(int tokens_, int num_experts_) { +#if CK_TILE_WA_ISSUE_2028 + if(tokens_ >= 65536 * 2) + { + return true; + } +#endif auto sub_token_ = moe_sorting_get_sub_token(tokens_, num_experts_); bool is_sub_token_onshot = tokens_ <= sub_token_; return is_sub_token_onshot; @@ -1523,7 +1529,7 @@ struct MoeSortingMultiPhaseKernel_P2 index_t num_experts; index_t mesh_stride; // mesh_stride for p_expert_mesh mdiv unit_size_mdiv; - index_t moe_buf_bytes; + long_index_t moe_buf_bytes; }; CK_TILE_HOST static constexpr auto MakeKargs(const Hargs& h) From 8a20b62e9124c10f4f240dce2c312b0a332bce6c Mon Sep 17 00:00:00 2001 From: rocking Date: Fri, 28 Mar 2025 21:58:06 +0800 Subject: [PATCH 005/443] Reduce redundant space in bias tensor (#2024) Co-authored-by: Po Yen Chen --- example/ck_tile/01_fmha/fmha_fwd.cpp | 12 ++++++------ include/ck_tile/ops/fmha/kernel/fmha_fwd_kernel.hpp | 2 +- 2 files changed, 7 insertions(+), 7 deletions(-) diff --git a/example/ck_tile/01_fmha/fmha_fwd.cpp b/example/ck_tile/01_fmha/fmha_fwd.cpp index b3855e59df..8f6fb8df54 100644 --- a/example/ck_tile/01_fmha/fmha_fwd.cpp +++ b/example/ck_tile/01_fmha/fmha_fwd.cpp @@ -620,7 +620,7 @@ bool run(const ck_tile::ArgParser& arg_parser) : 
std::array{1, 1, 1, 1} /* dummy shape for simplifying code */); ck_tile::HostTensor bias_host( bias.type == bias_enum::elementwise_bias - ? get_lengths(i_perm, 1, 1, shape_seqlen_q, shape_seqlen_k) + ? get_lengths(i_perm, 1, 1, shape_seqlen_q, max_seqlen_k) : std::array{1, 1, 1, 1} /* dummy shape for simplifying code */); ck_tile::HostTensor alibi_slope_host( @@ -884,7 +884,7 @@ bool run(const ck_tile::ArgParser& arg_parser) else return i_perm ? seqlen_knew : nhead_k * seqlen_knew; }(); - const ck_tile::index_t stride_bias = (i_perm ? shape_seqlen_k : 1 * shape_seqlen_k); + const ck_tile::index_t stride_bias = (i_perm ? max_seqlen_k : 1 * max_seqlen_k); const ck_tile::index_t stride_randval = (max_seqlen_k); const ck_tile::index_t stride_o_acc = (hdim_v); const ck_tile::index_t stride_o = (o_perm ? hdim_v : nhead * hdim_v); @@ -909,7 +909,7 @@ bool run(const ck_tile::ArgParser& arg_parser) return i_perm ? hdim_v * seqlen_knew : seqlen_knew; }(); const ck_tile::index_t nhead_stride_bias = - (i_perm ? 0 * shape_seqlen_q * shape_seqlen_k : 0 * shape_seqlen_k); + (i_perm ? 0 * shape_seqlen_q * max_seqlen_k : 0 * max_seqlen_k); const ck_tile::index_t nhead_stride_randval = (shape_seqlen_q * max_seqlen_k); const ck_tile::index_t nhead_stride_lse = shape_seqlen_q; const ck_tile::index_t nhead_stride_lse_acc = (num_splits * shape_seqlen_q); @@ -925,7 +925,7 @@ bool run(const ck_tile::ArgParser& arg_parser) (0 < page_block_size ? 
(nhead_k * hdim_v * page_block_size) : (nhead_k * hdim_v * shape_seqlen_k)); const ck_tile::index_t batch_stride_vnew = (nhead_k * hdim_v * seqlen_knew); - const ck_tile::index_t batch_stride_bias = (0 * nhead * shape_seqlen_q * shape_seqlen_k); + const ck_tile::index_t batch_stride_bias = (0 * nhead * shape_seqlen_q * max_seqlen_k); const ck_tile::index_t batch_stride_randval = (nhead * shape_seqlen_q * max_seqlen_k); const ck_tile::index_t batch_stride_lse = (nhead * shape_seqlen_q); const ck_tile::index_t batch_stride_lse_acc = (nhead * num_splits * shape_seqlen_q); @@ -1381,9 +1381,9 @@ bool run(const ck_tile::ArgParser& arg_parser) ck_tile::HostTensor bias_host_ref({1, real_seqlen_q, real_seqlen_k}); // clang-format off if(i_perm) - bias_host_ref.ForEach([&](auto& self, auto i) { self(i) = bias_host(0, 0, i[1] + query_offset, i[2] + key_offset); }); + bias_host_ref.ForEach([&](auto& self, auto i) { self(i) = bias_host(0, 0, i[1] + query_offset, i[2]); }); else - bias_host_ref.ForEach([&](auto& self, auto i) { self(i) = bias_host(0, i[1] + query_offset, 0, i[2] + key_offset); }); + bias_host_ref.ForEach([&](auto& self, auto i) { self(i) = bias_host(0, i[1] + query_offset, 0, i[2]); }); // clang-format on // broadcast from [1, real_seqlen_q, real_seqlen_k] to [nhead, real_seqlen_q, diff --git a/include/ck_tile/ops/fmha/kernel/fmha_fwd_kernel.hpp b/include/ck_tile/ops/fmha/kernel/fmha_fwd_kernel.hpp index a578f0c2f4..1202524950 100644 --- a/include/ck_tile/ops/fmha/kernel/fmha_fwd_kernel.hpp +++ b/include/ck_tile/ops/fmha/kernel/fmha_fwd_kernel.hpp @@ -983,7 +983,7 @@ struct FmhaFwdKernel } if constexpr(BiasEnum == BlockAttentionBiasEnum::ELEMENTWISE_BIAS) { - batch_offset_bias = query_start * kargs.stride_bias + key_start; + batch_offset_bias = query_start * kargs.stride_bias; } if constexpr(kStoreLSE) { From fc073b483e03caa7377e56a6b8b3573054e031fa Mon Sep 17 00:00:00 2001 From: Adel Johar Date: Fri, 28 Mar 2025 15:12:27 +0100 Subject: [PATCH 006/443] Docs: Add 
precision support reference page (#1973) * Docs: Add precision support reference page * edit of the precision type content * added more description on scalars --------- Co-authored-by: spolifroni-amd Co-authored-by: Aviral Goel --- .gitignore | 2 + docs/index.rst | 6 +- .../Composable_Kernel_custom_types.rst | 39 +++++++++++ ...mposable_Kernel_supported_scalar_types.rst | 69 +++++++++++++++++++ .../Composable_Kernel_vector_utilities.rst | 16 +++++ docs/sphinx/_toc.yml.in | 9 ++- 6 files changed, 137 insertions(+), 4 deletions(-) create mode 100644 docs/reference/Composable_Kernel_custom_types.rst create mode 100644 docs/reference/Composable_Kernel_supported_scalar_types.rst create mode 100644 docs/reference/Composable_Kernel_vector_utilities.rst diff --git a/.gitignore b/.gitignore index f4d5ff7abd..599ef99e35 100644 --- a/.gitignore +++ b/.gitignore @@ -55,6 +55,8 @@ _static/ _templates/ _toc.yml _doxygen/ +docs/doxygen/html +docs/doxygen/xml # JetBrains IDE .idea/ diff --git a/docs/index.rst b/docs/index.rst index 82e4c48001..15a9321d43 100644 --- a/docs/index.rst +++ b/docs/index.rst @@ -8,7 +8,7 @@ Composable Kernel User Guide ******************************************************************** -The Composable Kernel library provides a programming model for writing performance critical kernels for machine learning workloads across multiple architectures including GPUs and CPUs, through general purpose kernel languages such as `HIP C++ `_. +The Composable Kernel library provides a programming model for writing performance critical kernels for machine learning workloads across multiple architectures including GPUs and CPUs, through general purpose kernel languages such as `HIP C++ `_. The Composable Kernel repository is located at `https://github.com/ROCm/composable-kernel `_. @@ -32,10 +32,12 @@ The Composable Kernel repository is located at `https://github.com/ROCm/composab .. 
grid-item-card:: Reference + * :doc:`Composable Kernel supported scalar types <./reference/Composable_Kernel_supported_scalar_types>` + * :doc:`Composable Kernel custom types <./reference/Composable_Kernel_custom_types>` + * :doc:`Composable Kernel vector utilities <./reference/Composable_Kernel_vector_utilities>` * :ref:`api-reference` * :ref:`wrapper` - To contribute to the documentation refer to `Contributing to ROCm `_. You can find licensing information on the `Licensing `_ page. diff --git a/docs/reference/Composable_Kernel_custom_types.rst b/docs/reference/Composable_Kernel_custom_types.rst new file mode 100644 index 0000000000..863d4131b9 --- /dev/null +++ b/docs/reference/Composable_Kernel_custom_types.rst @@ -0,0 +1,39 @@ +.. meta:: + :description: Composable Kernel supported custom types + :keywords: composable kernel, custom, data types, support, CK, ROCm + +****************************************************** +Composable Kernel custom data types +****************************************************** + +Composable Kernel supports the use of custom types that provide a way to implement specialized numerical formats. + +To use custom types, a C++ type that implements the necessary operations for tensor computations needs to be created. These should include: + +* Constructors and initialization methods +* Arithmetic operators if the type will be used in computational operations +* Any conversion functions needed to interface with other parts of an application + +For example, to create a complex half-precision type: + +.. 
code:: cpp + + struct complex_half_t + { + half_t real; + half_t img; + }; + + struct complex_half_t + { + using type = half_t; + type real; + type img; + + complex_half_t() : real{type{}}, img{type{}} {} + complex_half_t(type real_init, type img_init) : real{real_init}, img{img_init} {} + }; + +Custom types can be particularly useful for specialized applications such as complex number arithmetic, +custom quantization schemes, or domain-specific number representations. + diff --git a/docs/reference/Composable_Kernel_supported_scalar_types.rst b/docs/reference/Composable_Kernel_supported_scalar_types.rst new file mode 100644 index 0000000000..7ea1a9eaeb --- /dev/null +++ b/docs/reference/Composable_Kernel_supported_scalar_types.rst @@ -0,0 +1,69 @@ +.. meta:: + :description: Composable Kernel supported scalar types + :keywords: composable kernel, scalar, data types, support, CK, ROCm + +*************************************************** +Composable Kernel supported scalar data types +*************************************************** + +The Composable Kernel library provides support for the following scalar data types: + +.. 
list-table:: + :header-rows: 1 + :widths: 25 15 60 + + * - Type + - Bit Width + - Description + + * - ``double`` + - 64-bit + - Standard IEEE 754 double precision floating point + + * - ``float`` + - 32-bit + - Standard IEEE 754 single precision floating point + + * - ``int32_t`` + - 32-bit + - Standard signed 32-bit integer + + * - ``int8_t`` + - 8-bit + - Standard signed 8-bit integer + + * - ``uint8_t`` + - 8-bit + - Standard unsigned 8-bit integer + + * - ``bool`` + - 1-bit + - Boolean type + + * - ``ck::half_t`` + - 16-bit + - IEEE 754 half precision floating point with 5 exponent bits, 10 mantissa bits, and 1 sign bit + + * - ``ck::bhalf_t`` + - 16-bit + - Brain floating point with 8 exponent bits, 7 mantissa bits, and 1 sign bit + + * - ``ck::f8_t`` + - 8-bit + - 8-bit floating point (E4M3 format) with 4 exponent bits, 3 mantissa bits, and 1 sign bit + + * - ``ck::bf8_t`` + - 8-bit + - 8-bit brain floating point (E5M2 format) with 5 exponent bits, 2 mantissa bits, and 1 sign bit + + * - ``ck::f4_t`` + - 4-bit + - 4-bit floating point format (E2M1 format) with 2 exponent bits, 1 mantissa bit, and 1 sign bit + + * - ``ck::f6_t`` + - 6-bit + - 6-bit floating point format (E2M3 format) with 2 exponent bits, 3 mantissa bits, and 1 sign bit + + * - ``ck::bf6_t`` + - 6-bit + - 6-bit brain floating point format (E3M2 format) with 3 exponent bits, 2 mantissa bits, and 1 sign bit \ No newline at end of file diff --git a/docs/reference/Composable_Kernel_vector_utilities.rst b/docs/reference/Composable_Kernel_vector_utilities.rst new file mode 100644 index 0000000000..3103653191 --- /dev/null +++ b/docs/reference/Composable_Kernel_vector_utilities.rst @@ -0,0 +1,16 @@ +.. 
meta:: + :description: Composable Kernel supported precision types and custom type support + :keywords: composable kernel, precision, data types, ROCm + +****************************************************** +Composable Kernel vector template utilities +****************************************************** + +Composable Kernel includes template utilities for creating vector types with customizable widths. These template utilities also flatten nested vector types into a single, wider vector, preventing the creation of vectors of vectors. + +Vectors composed of supported scalar and custom types can be created with the ``ck::vector_type`` template. + +For example, ``ck::vector_type`` creates a vector composed of four floats and ``ck::vector_type`` creates a vector composed of eight half-precision scalars. + +For vector operations to be valid, the underlying types must be either a :doc:`supported scalar type ` or :doc:`a custom type ` that implements the required operations. + diff --git a/docs/sphinx/_toc.yml.in b/docs/sphinx/_toc.yml.in index ab82b7deb1..df98998224 100644 --- a/docs/sphinx/_toc.yml.in +++ b/docs/sphinx/_toc.yml.in @@ -11,7 +11,7 @@ subtrees: title: Build and install Composable Kernel - file: install/Composable-Kernel-Docker.rst title: Composable Kernel Docker images - + - caption: Conceptual entries: - file: conceptual/Composable-Kernel-structure.rst @@ -26,6 +26,12 @@ subtrees: - caption: Reference entries: + - file: reference/Composable_Kernel_supported_scalar_types.rst + title: Composable Kernel scalar types + - file: reference/Composable_Kernel_custom_types.rst + title: Composable Kernel custom types + - file: reference/Composable_Kernel_vector_utilities.rst + title: Composable Kernel vector utilities - file: reference/Composable-Kernel-API-reference.rst title: Composable Kernel API reference - file: reference/Composable-Kernel-wrapper.rst @@ -37,4 +43,3 @@ subtrees: title: Contributing to Composable Kernel - file: license.rst title: License - 
\ No newline at end of file From 16b15e336a13e60f54ac9ea03975b9cf44b1d6f3 Mon Sep 17 00:00:00 2001 From: jefyang1 <146495389+jefyang1@users.noreply.github.com> Date: Mon, 31 Mar 2025 09:20:52 -0700 Subject: [PATCH 007/443] Fix gemm universal and grouped_conv_fwd test failures on gfx950 (#2031) --- .../device_grouped_conv_fwd_xdl_comp_instance.hpp | 5 ++++- .../device_gemm_xdl_universal_bf16_bf16_bf16_km_nk_mn.hpp | 2 +- .../device_gemm_xdl_universal_bf16_bf16_bf16_mk_kn_mn.hpp | 2 +- 3 files changed, 6 insertions(+), 3 deletions(-) diff --git a/library/include/ck/library/tensor_operation_instance/gpu/grouped_conv_fwd/device_grouped_conv_fwd_xdl_comp_instance.hpp b/library/include/ck/library/tensor_operation_instance/gpu/grouped_conv_fwd/device_grouped_conv_fwd_xdl_comp_instance.hpp index e7bbf8a26a..f491474d38 100644 --- a/library/include/ck/library/tensor_operation_instance/gpu/grouped_conv_fwd/device_grouped_conv_fwd_xdl_comp_instance.hpp +++ b/library/include/ck/library/tensor_operation_instance/gpu/grouped_conv_fwd/device_grouped_conv_fwd_xdl_comp_instance.hpp @@ -84,7 +84,6 @@ using device_grouped_conv_fwd_xdl_bf16_comp_instances = std::tuple< DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle_V3, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 0, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 0, 1, 1, S<1, 32, 1, 8>, 8, BlockGemmPipelineScheduler::Intrawave, BlockGemmPipelineVersion::v4>, DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle_V3, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 0, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 0, 1, 1, S<1, 32, 1, 8>, 8, BlockGemmPipelineScheduler::Intrawave, BlockGemmPipelineVersion::v3>, DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle_V3, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 0, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 0, 1, 1, S<1, 32, 1, 8>, 8, BlockGemmPipelineScheduler::Intrawave, BlockGemmPipelineVersion::v5>, - DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle_V3, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 0, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 0, 1, 2, 
S<1, 32, 1, 8>, 8, BlockGemmPipelineScheduler::Intrawave, BlockGemmPipelineVersion::v3>, DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle_V3, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 0, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 0, 1, 1, S<1, 32, 1, 8>, 8, BlockGemmPipelineScheduler::Interwave, BlockGemmPipelineVersion::v1>, DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle_V3, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 0, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 0, 1, 1, S<1, 32, 1, 8>, 8, BlockGemmPipelineScheduler::Interwave, BlockGemmPipelineVersion::v1>, DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle_V3, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 0, S<8, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 0, 1, 1, S<1, 32, 1, 8>, 8, BlockGemmPipelineScheduler::Interwave, BlockGemmPipelineVersion::v1>, @@ -94,6 +93,7 @@ using device_grouped_conv_fwd_xdl_bf16_comp_instances = std::tuple< // clang-format on >; +// instances not working on gfx950 template using device_grouped_conv_fwd_xdl_bf16_comp_instances_part2 = std::tuple< // clang-format off + DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle_V3, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 0, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 0, 1, 2, S<1, 32, 1, 8>, 8, BlockGemmPipelineScheduler::Intrawave, BlockGemmPipelineVersion::v3>, // AGPR Spill when use permuted lds layout. so, use padding for these two. 
DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle_V3, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 0, S<8, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 0, 1, 2, S<1, 32, 1, 8>, 8, BlockGemmPipelineScheduler::Intrawave, BlockGemmPipelineVersion::v3>, DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle_V3, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 0, S<8, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 0, 2, 1, S<1, 64, 1, 4>, 8, BlockGemmPipelineScheduler::Intrawave, BlockGemmPipelineVersion::v3>, @@ -143,6 +144,7 @@ using device_grouped_conv_fwd_xdl_f16_comp_instances = std::tuple< // clang-format on >; +// instances not working on gfx950 template ; +// instances not working on gfx950 template , S<0, 2, 1>, S<0, 2, 1>, 1, 8, 4, 0, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 0, 1, 1, S<1, 16, 1, 16>, 4, BlockGemmPipelineScheduler::Intrawave, BlockGemmPipelineVersion::v3>, DeviceGemm_Xdl_CShuffleV3< Col, Col, Row, BF16, BF16, BF16, F32, BF16, PassThrough, PassThrough, PassThrough, GemmSpec, 256, 256, 256, 32, 4, 4, 32, 32, 4, 4, S<8, 32, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 4, 4, 0, S<8, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 4, 4, 0, 1, 1, S<1, 16, 1, 16>, 4, BlockGemmPipelineScheduler::Intrawave, BlockGemmPipelineVersion::v3>, DeviceGemm_Xdl_CShuffleV3< Col, Col, Row, BF16, BF16, BF16, F32, BF16, PassThrough, PassThrough, PassThrough, GemmSpec, 256, 256, 256, 32, 2, 2, 32, 32, 4, 4, S<16,16, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 2, 2, 0, S<16,16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 2, 2, 0, 1, 1, S<1, 16, 1, 16>, 4, BlockGemmPipelineScheduler::Intrawave, BlockGemmPipelineVersion::v3>, - DeviceGemm_Xdl_CShuffleV3< Col, Col, Row, BF16, BF16, BF16, F32, BF16, PassThrough, PassThrough, PassThrough, GemmSpec, 256, 256, 256, 32, 4, 8, 32, 32, 4, 4, S<8, 32, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 8, 4, 0, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 0, 1, 1, S<1, 16, 1, 16>, 4, BlockGemmPipelineScheduler::Intrawave, BlockGemmPipelineVersion::v5>, DeviceGemm_Xdl_CShuffleV3< Col, Col, Row, BF16, BF16, BF16, F32, BF16, PassThrough, PassThrough, 
PassThrough, GemmSpec, 256, 256, 256, 32, 4, 4, 32, 32, 4, 4, S<8, 32, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 4, 4, 0, S<8, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 4, 4, 0, 1, 1, S<1, 16, 1, 16>, 4, BlockGemmPipelineScheduler::Intrawave, BlockGemmPipelineVersion::v5>, DeviceGemm_Xdl_CShuffleV3< Col, Col, Row, BF16, BF16, BF16, F32, BF16, PassThrough, PassThrough, PassThrough, GemmSpec, 256, 256, 256, 32, 2, 2, 32, 32, 4, 4, S<16,16, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 2, 2, 0, S<16,16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 2, 2, 0, 1, 1, S<1, 16, 1, 16>, 4, BlockGemmPipelineScheduler::Intrawave, BlockGemmPipelineVersion::v5>, DeviceGemm_Xdl_CShuffleV3< Col, Col, Row, BF16, BF16, BF16, F32, BF16, PassThrough, PassThrough, PassThrough, GemmSpec, 256, 128, 128, 64, 4, 8, 32, 32, 2, 2, S<16, 16, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 8, 4, 0, S<8, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 0, 1, 1, S<1, 16, 1, 16>, 4, BlockGemmPipelineScheduler::Intrawave, BlockGemmPipelineVersion::v3>, @@ -62,6 +61,7 @@ using device_gemm_xdl_universal_bf16_bf16_bf16_km_nk_mn_comp_instances = std::tu template using device_gemm_xdl_universal_bf16_bf16_bf16_km_nk_mn_comp_instances_part2 = std::tuple< // clang-format off + DeviceGemm_Xdl_CShuffleV3< Col, Col, Row, BF16, BF16, BF16, F32, BF16, PassThrough, PassThrough, PassThrough, GemmSpec, 256, 256, 256, 32, 4, 8, 32, 32, 4, 4, S<8, 32, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 8, 4, 0, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 0, 1, 1, S<1, 16, 1, 16>, 4, BlockGemmPipelineScheduler::Intrawave, BlockGemmPipelineVersion::v5>, DeviceGemm_Xdl_CShuffleV3< Col, Col, Row, BF16, BF16, BF16, F32, BF16, PassThrough, PassThrough, PassThrough, GemmSpec, 256, 256, 256, 32, 4, 8, 32, 32, 4, 4, S<8, 32, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 8, 4, 0, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 0, 1, 1, S<1, 16, 1, 16>, 4, BlockGemmPipelineScheduler::Intrawave, BlockGemmPipelineVersion::v4>, DeviceGemm_Xdl_CShuffleV3< Col, Col, Row, BF16, BF16, BF16, F32, BF16, PassThrough, PassThrough, PassThrough, 
GemmSpec, 256, 256, 224, 64, 8, 8, 16, 16, 8, 7, S<8, 32, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 8, 8, 0, S<8, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 0, 2, 1, S<1, 32, 1, 8>, 4, BlockGemmPipelineScheduler::Intrawave, BlockGemmPipelineVersion::v3> // clang-format on diff --git a/library/src/tensor_operation_instance/gpu/gemm_universal/device_gemm_xdl_universal_bf16_bf16_bf16/device_gemm_xdl_universal_bf16_bf16_bf16_mk_kn_mn.hpp b/library/src/tensor_operation_instance/gpu/gemm_universal/device_gemm_xdl_universal_bf16_bf16_bf16/device_gemm_xdl_universal_bf16_bf16_bf16_mk_kn_mn.hpp index 50fdca9348..9f142ad831 100644 --- a/library/src/tensor_operation_instance/gpu/gemm_universal/device_gemm_xdl_universal_bf16_bf16_bf16/device_gemm_xdl_universal_bf16_bf16_bf16_mk_kn_mn.hpp +++ b/library/src/tensor_operation_instance/gpu/gemm_universal/device_gemm_xdl_universal_bf16_bf16_bf16/device_gemm_xdl_universal_bf16_bf16_bf16_mk_kn_mn.hpp @@ -44,7 +44,6 @@ using device_gemm_xdl_universal_bf16_bf16_bf16_mk_kn_mn_comp_instances = std::tu DeviceGemm_Xdl_CShuffleV3< Row, Row, Row, BF16, BF16, BF16, F32, BF16, PassThrough, PassThrough, PassThrough, GemmSpec, 256, 256, 256, 32, 4, 4, 32, 32, 4, 4, S<8, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 4, 4, 0, S<8, 32, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 8, 4, 0, 1, 1, S<1, 16, 1, 16>, 4, BlockGemmPipelineScheduler::Intrawave, BlockGemmPipelineVersion::v4>, DeviceGemm_Xdl_CShuffleV3< Row, Row, Row, BF16, BF16, BF16, F32, BF16, PassThrough, PassThrough, PassThrough, GemmSpec, 256, 256, 256, 32, 2, 2, 32, 32, 4, 4, S<16,16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 2, 2, 0, S<16,16, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 2, 2, 0, 1, 1, S<1, 16, 1, 16>, 4, BlockGemmPipelineScheduler::Intrawave, BlockGemmPipelineVersion::v4>, DeviceGemm_Xdl_CShuffleV3< Row, Row, Row, BF16, BF16, BF16, F32, BF16, PassThrough, PassThrough, PassThrough, GemmSpec, 256, 128, 128, 64, 8, 4, 32, 32, 2, 2, S<8, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 0, S<16, 16, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 8, 4, 0, 1, 1, 
S<1, 16, 1, 16>, 4, BlockGemmPipelineScheduler::Intrawave, BlockGemmPipelineVersion::v4>, - DeviceGemm_Xdl_CShuffleV3< Row, Row, Row, BF16, BF16, BF16, F32, BF16, PassThrough, PassThrough, PassThrough, GemmSpec, 256, 256, 256, 32, 8, 4, 32, 32, 4, 4, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 0, S<8, 32, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 8, 4, 0, 1, 1, S<1, 16, 1, 16>, 4, BlockGemmPipelineScheduler::Intrawave, BlockGemmPipelineVersion::v5>, DeviceGemm_Xdl_CShuffleV3< Row, Row, Row, BF16, BF16, BF16, F32, BF16, PassThrough, PassThrough, PassThrough, GemmSpec, 256, 256, 256, 32, 8, 4, 32, 32, 4, 4, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 0, S<8, 32, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 8, 4, 0, 1, 1, S<1, 16, 1, 16>, 4, BlockGemmPipelineScheduler::Intrawave, BlockGemmPipelineVersion::v3>, DeviceGemm_Xdl_CShuffleV3< Row, Row, Row, BF16, BF16, BF16, F32, BF16, PassThrough, PassThrough, PassThrough, GemmSpec, 256, 128, 128, 64, 8, 4, 32, 32, 2, 2, S<8, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 0, S<16, 16, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 8, 4, 0, 1, 1, S<1, 16, 1, 16>, 4, BlockGemmPipelineScheduler::Intrawave, BlockGemmPipelineVersion::v3>, DeviceGemm_Xdl_CShuffleV3< Row, Row, Row, BF16, BF16, BF16, F32, BF16, PassThrough, PassThrough, PassThrough, GemmSpec, 256, 128, 128, 64, 8, 4, 32, 32, 2, 2, S<8, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 0, S<16, 16, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 8, 4, 0, 1, 1, S<1, 16, 1, 16>, 4, BlockGemmPipelineScheduler::Intrawave, BlockGemmPipelineVersion::v5>, @@ -55,6 +54,7 @@ using device_gemm_xdl_universal_bf16_bf16_bf16_mk_kn_mn_comp_instances = std::tu template using device_gemm_xdl_universal_bf16_bf16_bf16_mk_kn_mn_comp_instances_part2 = std::tuple< // clang-format off + DeviceGemm_Xdl_CShuffleV3< Row, Row, Row, BF16, BF16, BF16, F32, BF16, PassThrough, PassThrough, PassThrough, GemmSpec, 256, 256, 256, 32, 8, 4, 32, 32, 4, 4, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 0, S<8, 32, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 8, 4, 0, 1, 1, S<1, 16, 
1, 16>, 4, BlockGemmPipelineScheduler::Intrawave, BlockGemmPipelineVersion::v5>, DeviceGemm_Xdl_CShuffleV3< Row, Row, Row, BF16, BF16, BF16, F32, BF16, PassThrough, PassThrough, PassThrough, GemmSpec, 256, 256, 256, 32, 8, 4, 32, 32, 4, 4, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 0, S<8, 32, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 8, 4, 0, 1, 1, S<1, 16, 1, 16>, 4, BlockGemmPipelineScheduler::Intrawave, BlockGemmPipelineVersion::v4>, DeviceGemm_Xdl_CShuffleV3< Row, Row, Row, BF16, BF16, BF16, F32, BF16, PassThrough, PassThrough, PassThrough, GemmSpec, 256, 224, 256, 64, 8, 8, 16, 16, 7, 8, S<8, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 0, S<8, 32, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 8, 8, 0, 1, 2, S<1, 16, 1, 16>, 4, BlockGemmPipelineScheduler::Intrawave, BlockGemmPipelineVersion::v3> // clang-format on From dd4c12b155c6eece31e851b3aa46939d00f6adbd Mon Sep 17 00:00:00 2001 From: Muhammed Emin Ozturk Date: Mon, 31 Mar 2025 19:30:17 -0700 Subject: [PATCH 008/443] f8/bf16 GEMM Stream-K (#1879) --- CHANGELOG.md | 2 +- example/01_gemm/CMakeLists.txt | 6 + .../01_gemm/gemm_xdl_fp16_fp8_streamk_v3.cpp | 64 +++++++ .../gpu/gemm_universal_streamk.hpp | 129 ++++++++++++- .../gpu/gemm_universal_streamk/CMakeLists.txt | 23 ++- ...versal_streamk_bf16_bf16_bf16_mk_nk_mn.hpp | 2 + ...bf16_mk_nk_mn_comp_mnkpadding_instance.cpp | 31 +++ ..._bf16_mk_nk_mn_comp_mnpadding_instance.cpp | 30 +++ ...16_mk_nk_mn_mem_v1_mnkpadding_instance.cpp | 31 +++ ...16_mk_nk_mn_mem_v2_mnkpadding_instance.cpp | 31 +++ ..._universal_streamk_f8_f8_bf16_mk_kn_mn.hpp | 99 ++++++++++ ...f8_bf16_mk_kn_mn_comp_default_instance.cpp | 24 +++ ...8_bf16_mk_kn_mn_comp_kpadding_instance.cpp | 24 +++ ..._bf16_mk_kn_mn_comp_nkpadding_instance.cpp | 24 +++ ..._bf16_mk_kn_mn_mem_v1_default_instance.cpp | 25 +++ ...bf16_mk_kn_mn_mem_v1_kpadding_instance.cpp | 25 +++ ...f16_mk_kn_mn_mem_v1_nkpadding_instance.cpp | 25 +++ ..._bf16_mk_kn_mn_mem_v2_default_instance.cpp | 25 +++ ...bf16_mk_kn_mn_mem_v2_kpadding_instance.cpp | 25 +++ 
...f16_mk_kn_mn_mem_v2_nkpadding_instance.cpp | 25 +++ ..._universal_streamk_f8_f8_bf16_mk_nk_mn.hpp | 107 +++++++++++ ...f8_bf16_mk_nk_mn_comp_default_instance.cpp | 24 +++ ...8_bf16_mk_nk_mn_comp_kpadding_instance.cpp | 24 +++ ..._bf16_mk_nk_mn_mem_v1_default_instance.cpp | 25 +++ ...bf16_mk_nk_mn_mem_v1_kpadding_instance.cpp | 25 +++ ..._bf16_mk_nk_mn_mem_v2_default_instance.cpp | 25 +++ ...bf16_mk_nk_mn_mem_v2_kpadding_instance.cpp | 25 +++ .../profile_gemm_universal_streamk_impl.hpp | 38 ++-- .../src/profile_gemm_universal_streamk.cpp | 44 +++-- test/CMakeLists.txt | 4 + test/gemm_universal_streamk/CMakeLists.txt | 15 ++ ...t_gemm_universal_streamk_ut_cases_bf16.inc | 177 ++++++++++++++++++ ...t_gemm_universal_streamk_ut_cases_fp16.inc | 113 +++++++++++ ...st_gemm_universal_streamk_ut_cases_fp8.inc | 113 +++++++++++ .../test_gemm_universal_streamk_util.hpp | 104 ++++++++++ .../test_gemm_universal_streamk_xdl_bf16.cpp | 85 +++++++++ .../test_gemm_universal_streamk_xdl_fp16.cpp | 84 +++++++++ .../test_gemm_universal_streamk_xdl_fp8.cpp | 74 ++++++++ 38 files changed, 1738 insertions(+), 38 deletions(-) create mode 100644 example/01_gemm/gemm_xdl_fp16_fp8_streamk_v3.cpp mode change 100644 => 100755 library/src/tensor_operation_instance/gpu/gemm_universal_streamk/CMakeLists.txt create mode 100755 library/src/tensor_operation_instance/gpu/gemm_universal_streamk/device_gemm_xdl_universal_streamk_bf16_bf16_bf16/device_gemm_xdl_universal_streamk_bf16_bf16_bf16_mk_nk_mn_comp_mnkpadding_instance.cpp create mode 100755 library/src/tensor_operation_instance/gpu/gemm_universal_streamk/device_gemm_xdl_universal_streamk_bf16_bf16_bf16/device_gemm_xdl_universal_streamk_bf16_bf16_bf16_mk_nk_mn_comp_mnpadding_instance.cpp create mode 100755 library/src/tensor_operation_instance/gpu/gemm_universal_streamk/device_gemm_xdl_universal_streamk_bf16_bf16_bf16/device_gemm_xdl_universal_streamk_bf16_bf16_bf16_mk_nk_mn_mem_v1_mnkpadding_instance.cpp create mode 100755 
library/src/tensor_operation_instance/gpu/gemm_universal_streamk/device_gemm_xdl_universal_streamk_bf16_bf16_bf16/device_gemm_xdl_universal_streamk_bf16_bf16_bf16_mk_nk_mn_mem_v2_mnkpadding_instance.cpp create mode 100755 library/src/tensor_operation_instance/gpu/gemm_universal_streamk/device_gemm_xdl_universal_streamk_f8_f8_bf16/device_gemm_xdl_universal_streamk_f8_f8_bf16_mk_kn_mn.hpp create mode 100644 library/src/tensor_operation_instance/gpu/gemm_universal_streamk/device_gemm_xdl_universal_streamk_f8_f8_bf16/device_gemm_xdl_universal_streamk_f8_f8_bf16_mk_kn_mn_comp_default_instance.cpp create mode 100644 library/src/tensor_operation_instance/gpu/gemm_universal_streamk/device_gemm_xdl_universal_streamk_f8_f8_bf16/device_gemm_xdl_universal_streamk_f8_f8_bf16_mk_kn_mn_comp_kpadding_instance.cpp create mode 100644 library/src/tensor_operation_instance/gpu/gemm_universal_streamk/device_gemm_xdl_universal_streamk_f8_f8_bf16/device_gemm_xdl_universal_streamk_f8_f8_bf16_mk_kn_mn_comp_nkpadding_instance.cpp create mode 100644 library/src/tensor_operation_instance/gpu/gemm_universal_streamk/device_gemm_xdl_universal_streamk_f8_f8_bf16/device_gemm_xdl_universal_streamk_f8_f8_bf16_mk_kn_mn_mem_v1_default_instance.cpp create mode 100644 library/src/tensor_operation_instance/gpu/gemm_universal_streamk/device_gemm_xdl_universal_streamk_f8_f8_bf16/device_gemm_xdl_universal_streamk_f8_f8_bf16_mk_kn_mn_mem_v1_kpadding_instance.cpp create mode 100644 library/src/tensor_operation_instance/gpu/gemm_universal_streamk/device_gemm_xdl_universal_streamk_f8_f8_bf16/device_gemm_xdl_universal_streamk_f8_f8_bf16_mk_kn_mn_mem_v1_nkpadding_instance.cpp create mode 100644 library/src/tensor_operation_instance/gpu/gemm_universal_streamk/device_gemm_xdl_universal_streamk_f8_f8_bf16/device_gemm_xdl_universal_streamk_f8_f8_bf16_mk_kn_mn_mem_v2_default_instance.cpp create mode 100644 
library/src/tensor_operation_instance/gpu/gemm_universal_streamk/device_gemm_xdl_universal_streamk_f8_f8_bf16/device_gemm_xdl_universal_streamk_f8_f8_bf16_mk_kn_mn_mem_v2_kpadding_instance.cpp create mode 100644 library/src/tensor_operation_instance/gpu/gemm_universal_streamk/device_gemm_xdl_universal_streamk_f8_f8_bf16/device_gemm_xdl_universal_streamk_f8_f8_bf16_mk_kn_mn_mem_v2_nkpadding_instance.cpp create mode 100755 library/src/tensor_operation_instance/gpu/gemm_universal_streamk/device_gemm_xdl_universal_streamk_f8_f8_bf16/device_gemm_xdl_universal_streamk_f8_f8_bf16_mk_nk_mn.hpp create mode 100644 library/src/tensor_operation_instance/gpu/gemm_universal_streamk/device_gemm_xdl_universal_streamk_f8_f8_bf16/device_gemm_xdl_universal_streamk_f8_f8_bf16_mk_nk_mn_comp_default_instance.cpp create mode 100644 library/src/tensor_operation_instance/gpu/gemm_universal_streamk/device_gemm_xdl_universal_streamk_f8_f8_bf16/device_gemm_xdl_universal_streamk_f8_f8_bf16_mk_nk_mn_comp_kpadding_instance.cpp create mode 100644 library/src/tensor_operation_instance/gpu/gemm_universal_streamk/device_gemm_xdl_universal_streamk_f8_f8_bf16/device_gemm_xdl_universal_streamk_f8_f8_bf16_mk_nk_mn_mem_v1_default_instance.cpp create mode 100644 library/src/tensor_operation_instance/gpu/gemm_universal_streamk/device_gemm_xdl_universal_streamk_f8_f8_bf16/device_gemm_xdl_universal_streamk_f8_f8_bf16_mk_nk_mn_mem_v1_kpadding_instance.cpp create mode 100644 library/src/tensor_operation_instance/gpu/gemm_universal_streamk/device_gemm_xdl_universal_streamk_f8_f8_bf16/device_gemm_xdl_universal_streamk_f8_f8_bf16_mk_nk_mn_mem_v2_default_instance.cpp create mode 100644 library/src/tensor_operation_instance/gpu/gemm_universal_streamk/device_gemm_xdl_universal_streamk_f8_f8_bf16/device_gemm_xdl_universal_streamk_f8_f8_bf16_mk_nk_mn_mem_v2_kpadding_instance.cpp mode change 100644 => 100755 profiler/src/profile_gemm_universal_streamk.cpp mode change 100644 => 100755 test/CMakeLists.txt create mode 
100755 test/gemm_universal_streamk/CMakeLists.txt create mode 100644 test/gemm_universal_streamk/test_gemm_universal_streamk_ut_cases_bf16.inc create mode 100644 test/gemm_universal_streamk/test_gemm_universal_streamk_ut_cases_fp16.inc create mode 100755 test/gemm_universal_streamk/test_gemm_universal_streamk_ut_cases_fp8.inc create mode 100644 test/gemm_universal_streamk/test_gemm_universal_streamk_util.hpp create mode 100755 test/gemm_universal_streamk/test_gemm_universal_streamk_xdl_bf16.cpp create mode 100644 test/gemm_universal_streamk/test_gemm_universal_streamk_xdl_fp16.cpp create mode 100755 test/gemm_universal_streamk/test_gemm_universal_streamk_xdl_fp8.cpp diff --git a/CHANGELOG.md b/CHANGELOG.md index 0d07abfc24..de831a6898 100644 --- a/CHANGELOG.md +++ b/CHANGELOG.md @@ -8,7 +8,7 @@ Documentation for Composable Kernel available at [https://rocm.docs.amd.com/proj * Added support for bf16, f32, and f16 for 2D and 3D NGCHW grouped convolution backward data * Added support GKCYX layout for grouped convolution forward (NGCHW/GKCYX/NGKHW, number of instances in instance factory for NGCHW/GKYXC/NGKHW has been reduced). 
- +* Added support for Stream-K version of mixed fp8/bf16 GEMM ### Optimized None diff --git a/example/01_gemm/CMakeLists.txt b/example/01_gemm/CMakeLists.txt index ee9f959d94..96678d275a 100755 --- a/example/01_gemm/CMakeLists.txt +++ b/example/01_gemm/CMakeLists.txt @@ -28,8 +28,14 @@ add_example_executable(example_gemm_xdl_fp16_v3 gemm_xdl_fp16_v3.cpp) add_example_dependencies(example_gemm_xdl example_gemm_xdl_fp16_v3) add_example_executable(example_gemm_xdl_fp8_v3 gemm_xdl_fp8_v3.cpp) add_example_dependencies(example_gemm_xdl example_gemm_xdl_fp8_v3) + add_example_executable(example_gemm_xdl_fp16_fp8_v3 gemm_xdl_fp16_fp8_v3.cpp) add_example_dependencies(example_gemm_xdl example_gemm_xdl_fp16_fp8_v3) + + +add_example_executable(example_gemm_xdl_fp16_fp8_streamk_v3 gemm_xdl_fp16_fp8_streamk_v3.cpp) +add_example_dependencies(example_gemm_xdl example_gemm_xdl_fp16_fp8_streamk_v3) + add_example_executable(example_gemm_xdl_bf16_v3 gemm_xdl_bf16_v3.cpp) add_example_dependencies(example_gemm_xdl example_gemm_xdl_bf16_v3) diff --git a/example/01_gemm/gemm_xdl_fp16_fp8_streamk_v3.cpp b/example/01_gemm/gemm_xdl_fp16_fp8_streamk_v3.cpp new file mode 100644 index 0000000000..bd38eb17ee --- /dev/null +++ b/example/01_gemm/gemm_xdl_fp16_fp8_streamk_v3.cpp @@ -0,0 +1,64 @@ +// SPDX-License-Identifier: MIT +// Copyright (c) 2024, Advanced Micro Devices, Inc. All rights reserved. 
+ +#include "common.hpp" + +#include "ck/tensor_operation/gpu/device/impl/device_gemm_xdl_cshuffle_streamk_v3.hpp" + +using ADataType = ck::half_t; +using BDataType = ck::f8_t; +using AccDataType = float; +using CShuffleDataType = ck::half_t; +using CDataType = ck::half_t; + +using ALayout = Row; +using BLayout = Col; +using CLayout = Row; + +using AElementOp = PassThrough; +using BElementOp = PassThrough; +using CElementOp = PassThrough; + +static constexpr auto GemmDefault = ck::tensor_operation::device::GemmSpecialization::Default; + +// clang-format off +using DeviceGemmV2_Streamk_Instance = + ck::tensor_operation::device::DeviceGemm_Xdl_CShuffle_Streamk_V3< + ALayout, BLayout, CLayout, + ADataType, BDataType, CDataType, AccDataType, CShuffleDataType, + AElementOp, BElementOp, CElementOp, GemmDefault, + 64, + 16, 16, + 256, 8, 16, + 16, 16, + 1, 1, + S<32, 2, 1>, S<1, 0, 2>, S<1, 0, 2>, + 2, 8, 8, 0, + S<16, 4, 1>, S<1, 0, 2>, S<1, 0, 2>, + 2, 16, 16, 0, + 1, 1, S<1, 16, 1, 4>, 4, + ck::BlockGemmPipelineScheduler::Interwave, ck::BlockGemmPipelineVersion::v1>; +// clang-format on + +using ReferenceGemmInstanceGPU = ck::tensor_operation::device::ReferenceGemm; + +using ReferenceGemmInstance = ck::tensor_operation::host::ReferenceGemm; + +#include "run_gemm_example_streamk_v2.inc" + +int main(int argc, char* argv[]) { return !run_gemm_universal_streamk_example(argc, argv); } diff --git a/library/include/ck/library/tensor_operation_instance/gpu/gemm_universal_streamk.hpp b/library/include/ck/library/tensor_operation_instance/gpu/gemm_universal_streamk.hpp index 18203e7d5c..372e744bd7 100644 --- a/library/include/ck/library/tensor_operation_instance/gpu/gemm_universal_streamk.hpp +++ b/library/include/ck/library/tensor_operation_instance/gpu/gemm_universal_streamk.hpp @@ -635,7 +635,7 @@ void add_device_gemm_xdl_universal_streamk_bf16_bf16_bf16_km_nk_mn_mem_v2_mkpadd PassThrough>>>& instances); #endif -#if(defined(CK_ENABLE_FP8)) +#if(defined(CK_ENABLE_FP16) && 
defined(CK_ENABLE_FP8)) void add_device_gemm_xdl_universal_streamk_f16_f8_f16_mk_kn_mn_comp_default_instances( std::vector>>& @@ -834,6 +834,83 @@ void add_device_gemm_xdl_universal_streamk_f8_f16_f16_mk_nk_mn_mem_v2_mnkpadding instances); #endif +#if(defined(CK_ENABLE_BF16) && defined(CK_ENABLE_FP8)) +void add_device_gemm_xdl_universal_streamk_f8_f8_bf16_mk_kn_mn_comp_default_instances( + std::vector>>& + instances); + +void add_device_gemm_xdl_universal_streamk_f8_f8_bf16_mk_kn_mn_comp_kpadding_instances( + std::vector>>& + instances); + +void add_device_gemm_xdl_universal_streamk_f8_f8_bf16_mk_kn_mn_comp_nkpadding_instances( + std::vector>>& + instances); + +void add_device_gemm_xdl_universal_streamk_f8_f8_bf16_mk_kn_mn_mem_v1_default_instances( + std::vector>>& + instances); + +void add_device_gemm_xdl_universal_streamk_f8_f8_bf16_mk_kn_mn_mem_v1_kpadding_instances( + std::vector>>& + instances); + +void add_device_gemm_xdl_universal_streamk_f8_f8_bf16_mk_kn_mn_mem_v1_nkpadding_instances( + std::vector>>& + instances); + +void add_device_gemm_xdl_universal_streamk_f8_f8_bf16_mk_kn_mn_mem_v2_default_instances( + std::vector>>& + instances); + +void add_device_gemm_xdl_universal_streamk_f8_f8_bf16_mk_kn_mn_mem_v2_kpadding_instances( + std::vector>>& + instances); + +void add_device_gemm_xdl_universal_streamk_f8_f8_bf16_mk_kn_mn_mem_v2_nkpadding_instances( + std::vector>>& + instances); + +void add_device_gemm_xdl_universal_streamk_f8_f8_bf16_mk_nk_mn_comp_default_instances( + std::vector>>& + instances); + +void add_device_gemm_xdl_universal_streamk_f8_f8_bf16_mk_nk_mn_comp_kpadding_instances( + std::vector>>& + instances); + +void add_device_gemm_xdl_universal_streamk_f8_f8_bf16_mk_nk_mn_mem_v1_default_instances( + std::vector>>& + instances); + +void add_device_gemm_xdl_universal_streamk_f8_f8_bf16_mk_nk_mn_mem_v1_kpadding_instances( + std::vector>>& + instances); + +void add_device_gemm_xdl_universal_streamk_f8_f8_bf16_mk_nk_mn_mem_v2_default_instances( + 
std::vector>>& + instances); + +void add_device_gemm_xdl_universal_streamk_f8_f8_bf16_mk_nk_mn_mem_v2_kpadding_instances( + std::vector>>& + instances); +#endif + template && is_same_v && is_same_v) { @@ -1141,6 +1218,54 @@ struct DeviceOperationInstanceFactory && is_same_v && + is_same_v) + { + if constexpr(is_same_v && is_same_v && + is_same_v) + { + add_device_gemm_xdl_universal_streamk_f8_f8_bf16_mk_kn_mn_comp_default_instances( + op_ptrs); + add_device_gemm_xdl_universal_streamk_f8_f8_bf16_mk_kn_mn_comp_kpadding_instances( + op_ptrs); + add_device_gemm_xdl_universal_streamk_f8_f8_bf16_mk_kn_mn_comp_nkpadding_instances( + op_ptrs); + + add_device_gemm_xdl_universal_streamk_f8_f8_bf16_mk_kn_mn_mem_v1_default_instances( + op_ptrs); + add_device_gemm_xdl_universal_streamk_f8_f8_bf16_mk_kn_mn_mem_v1_kpadding_instances( + op_ptrs); + add_device_gemm_xdl_universal_streamk_f8_f8_bf16_mk_kn_mn_mem_v1_nkpadding_instances( + op_ptrs); + + add_device_gemm_xdl_universal_streamk_f8_f8_bf16_mk_kn_mn_mem_v2_default_instances( + op_ptrs); + add_device_gemm_xdl_universal_streamk_f8_f8_bf16_mk_kn_mn_mem_v2_kpadding_instances( + op_ptrs); + add_device_gemm_xdl_universal_streamk_f8_f8_bf16_mk_kn_mn_mem_v2_nkpadding_instances( + op_ptrs); + } + else if constexpr(is_same_v && is_same_v && + is_same_v) + { + add_device_gemm_xdl_universal_streamk_f8_f8_bf16_mk_nk_mn_comp_default_instances( + op_ptrs); + add_device_gemm_xdl_universal_streamk_f8_f8_bf16_mk_nk_mn_comp_kpadding_instances( + op_ptrs); + + add_device_gemm_xdl_universal_streamk_f8_f8_bf16_mk_nk_mn_mem_v1_default_instances( + op_ptrs); + add_device_gemm_xdl_universal_streamk_f8_f8_bf16_mk_nk_mn_mem_v1_kpadding_instances( + op_ptrs); + + add_device_gemm_xdl_universal_streamk_f8_f8_bf16_mk_nk_mn_mem_v2_default_instances( + op_ptrs); + add_device_gemm_xdl_universal_streamk_f8_f8_bf16_mk_nk_mn_mem_v2_kpadding_instances( + op_ptrs); + } + } +#endif return op_ptrs; } diff --git 
a/library/src/tensor_operation_instance/gpu/gemm_universal_streamk/CMakeLists.txt b/library/src/tensor_operation_instance/gpu/gemm_universal_streamk/CMakeLists.txt old mode 100644 new mode 100755 index e1612bcd24..b7391d3446 --- a/library/src/tensor_operation_instance/gpu/gemm_universal_streamk/CMakeLists.txt +++ b/library/src/tensor_operation_instance/gpu/gemm_universal_streamk/CMakeLists.txt @@ -21,9 +21,7 @@ list(APPEND GEMM_UNIVERSAL_STREAMK_INSTANCES device_gemm_xdl_universal_streamk_f16_f16_f16/device_gemm_xdl_universal_streamk_f16_f16_f16_mk_nk_mn_mem_v1_mnkpadding_instance.cpp device_gemm_xdl_universal_streamk_f16_f16_f16/device_gemm_xdl_universal_streamk_f16_f16_f16_mk_nk_mn_mem_v2_default_instance.cpp device_gemm_xdl_universal_streamk_f16_f16_f16/device_gemm_xdl_universal_streamk_f16_f16_f16_mk_nk_mn_mem_v2_kpadding_instance.cpp - device_gemm_xdl_universal_streamk_f16_f16_f16/device_gemm_xdl_universal_streamk_f16_f16_f16_mk_nk_mn_mem_v2_mnkpadding_instance.cpp - device_gemm_xdl_universal_streamk_f16_f8_f16/device_gemm_xdl_universal_streamk_f16_f8_f16_mk_kn_mn_comp_default_instance.cpp device_gemm_xdl_universal_streamk_f16_f8_f16/device_gemm_xdl_universal_streamk_f16_f8_f16_mk_kn_mn_comp_kpadding_instance.cpp device_gemm_xdl_universal_streamk_f16_f8_f16/device_gemm_xdl_universal_streamk_f16_f8_f16_mk_kn_mn_comp_mnpadding_instance.cpp @@ -44,7 +42,6 @@ list(APPEND GEMM_UNIVERSAL_STREAMK_INSTANCES device_gemm_xdl_universal_streamk_f16_f8_f16/device_gemm_xdl_universal_streamk_f16_f8_f16_mk_nk_mn_mem_v2_default_instance.cpp device_gemm_xdl_universal_streamk_f16_f8_f16/device_gemm_xdl_universal_streamk_f16_f8_f16_mk_nk_mn_mem_v2_kpadding_instance.cpp device_gemm_xdl_universal_streamk_f16_f8_f16/device_gemm_xdl_universal_streamk_f16_f8_f16_mk_nk_mn_mem_v2_mnkpadding_instance.cpp - device_gemm_xdl_universal_streamk_f8_f16_f16/device_gemm_xdl_universal_streamk_f8_f16_f16_mk_kn_mn_comp_default_instance.cpp 
device_gemm_xdl_universal_streamk_f8_f16_f16/device_gemm_xdl_universal_streamk_f8_f16_f16_mk_kn_mn_comp_kpadding_instance.cpp device_gemm_xdl_universal_streamk_f8_f16_f16/device_gemm_xdl_universal_streamk_f8_f16_f16_mk_kn_mn_comp_mnpadding_instance.cpp @@ -65,7 +62,6 @@ list(APPEND GEMM_UNIVERSAL_STREAMK_INSTANCES device_gemm_xdl_universal_streamk_f8_f16_f16/device_gemm_xdl_universal_streamk_f8_f16_f16_mk_nk_mn_mem_v2_default_instance.cpp device_gemm_xdl_universal_streamk_f8_f16_f16/device_gemm_xdl_universal_streamk_f8_f16_f16_mk_nk_mn_mem_v2_kpadding_instance.cpp device_gemm_xdl_universal_streamk_f8_f16_f16/device_gemm_xdl_universal_streamk_f8_f16_f16_mk_nk_mn_mem_v2_mnkpadding_instance.cpp - device_gemm_xdl_universal_streamk_bf16_bf16_bf16/device_gemm_xdl_universal_streamk_bf16_bf16_bf16_km_kn_mn_comp_default_instance.cpp device_gemm_xdl_universal_streamk_bf16_bf16_bf16/device_gemm_xdl_universal_streamk_bf16_bf16_bf16_km_kn_mn_comp_mnkpadding_instance.cpp device_gemm_xdl_universal_streamk_bf16_bf16_bf16/device_gemm_xdl_universal_streamk_bf16_bf16_bf16_km_kn_mn_comp_kpadding_instance.cpp @@ -101,6 +97,21 @@ list(APPEND GEMM_UNIVERSAL_STREAMK_INSTANCES device_gemm_xdl_universal_streamk_bf16_bf16_bf16/device_gemm_xdl_universal_streamk_bf16_bf16_bf16_mk_nk_mn_mem_v1_default_instance.cpp device_gemm_xdl_universal_streamk_bf16_bf16_bf16/device_gemm_xdl_universal_streamk_bf16_bf16_bf16_mk_nk_mn_mem_v1_kpadding_instance.cpp device_gemm_xdl_universal_streamk_bf16_bf16_bf16/device_gemm_xdl_universal_streamk_bf16_bf16_bf16_mk_nk_mn_mem_v2_default_instance.cpp - device_gemm_xdl_universal_streamk_bf16_bf16_bf16/device_gemm_xdl_universal_streamk_bf16_bf16_bf16_mk_nk_mn_mem_v2_kpadding_instance.cpp) - + device_gemm_xdl_universal_streamk_bf16_bf16_bf16/device_gemm_xdl_universal_streamk_bf16_bf16_bf16_mk_nk_mn_mem_v2_kpadding_instance.cpp + device_gemm_xdl_universal_streamk_f8_f8_bf16/device_gemm_xdl_universal_streamk_f8_f8_bf16_mk_kn_mn_comp_default_instance.cpp + 
device_gemm_xdl_universal_streamk_f8_f8_bf16/device_gemm_xdl_universal_streamk_f8_f8_bf16_mk_kn_mn_comp_kpadding_instance.cpp + device_gemm_xdl_universal_streamk_f8_f8_bf16/device_gemm_xdl_universal_streamk_f8_f8_bf16_mk_kn_mn_comp_nkpadding_instance.cpp + device_gemm_xdl_universal_streamk_f8_f8_bf16/device_gemm_xdl_universal_streamk_f8_f8_bf16_mk_kn_mn_mem_v1_default_instance.cpp + device_gemm_xdl_universal_streamk_f8_f8_bf16/device_gemm_xdl_universal_streamk_f8_f8_bf16_mk_kn_mn_mem_v1_kpadding_instance.cpp + device_gemm_xdl_universal_streamk_f8_f8_bf16/device_gemm_xdl_universal_streamk_f8_f8_bf16_mk_kn_mn_mem_v1_nkpadding_instance.cpp + device_gemm_xdl_universal_streamk_f8_f8_bf16/device_gemm_xdl_universal_streamk_f8_f8_bf16_mk_kn_mn_mem_v2_default_instance.cpp + device_gemm_xdl_universal_streamk_f8_f8_bf16/device_gemm_xdl_universal_streamk_f8_f8_bf16_mk_kn_mn_mem_v2_kpadding_instance.cpp + device_gemm_xdl_universal_streamk_f8_f8_bf16/device_gemm_xdl_universal_streamk_f8_f8_bf16_mk_kn_mn_mem_v2_nkpadding_instance.cpp + device_gemm_xdl_universal_streamk_f8_f8_bf16/device_gemm_xdl_universal_streamk_f8_f8_bf16_mk_nk_mn_comp_default_instance.cpp + device_gemm_xdl_universal_streamk_f8_f8_bf16/device_gemm_xdl_universal_streamk_f8_f8_bf16_mk_nk_mn_comp_kpadding_instance.cpp + device_gemm_xdl_universal_streamk_f8_f8_bf16/device_gemm_xdl_universal_streamk_f8_f8_bf16_mk_nk_mn_mem_v1_default_instance.cpp + device_gemm_xdl_universal_streamk_f8_f8_bf16/device_gemm_xdl_universal_streamk_f8_f8_bf16_mk_nk_mn_mem_v1_kpadding_instance.cpp + device_gemm_xdl_universal_streamk_f8_f8_bf16/device_gemm_xdl_universal_streamk_f8_f8_bf16_mk_nk_mn_mem_v2_default_instance.cpp + device_gemm_xdl_universal_streamk_f8_f8_bf16/device_gemm_xdl_universal_streamk_f8_f8_bf16_mk_nk_mn_mem_v2_kpadding_instance.cpp +) add_instance_library(device_gemm_universal_streamk_instance ${GEMM_UNIVERSAL_STREAMK_INSTANCES}) diff --git 
a/library/src/tensor_operation_instance/gpu/gemm_universal_streamk/device_gemm_xdl_universal_streamk_bf16_bf16_bf16/device_gemm_xdl_universal_streamk_bf16_bf16_bf16_mk_nk_mn.hpp b/library/src/tensor_operation_instance/gpu/gemm_universal_streamk/device_gemm_xdl_universal_streamk_bf16_bf16_bf16/device_gemm_xdl_universal_streamk_bf16_bf16_bf16_mk_nk_mn.hpp index 209d8f644e..959c1c0992 100755 --- a/library/src/tensor_operation_instance/gpu/gemm_universal_streamk/device_gemm_xdl_universal_streamk_bf16_bf16_bf16/device_gemm_xdl_universal_streamk_bf16_bf16_bf16_mk_nk_mn.hpp +++ b/library/src/tensor_operation_instance/gpu/gemm_universal_streamk/device_gemm_xdl_universal_streamk_bf16_bf16_bf16/device_gemm_xdl_universal_streamk_bf16_bf16_bf16_mk_nk_mn.hpp @@ -51,8 +51,10 @@ using device_gemm_xdl_universal_streamk_bf16_bf16_bf16_mk_nk_mn_comp_instances = // AGPR Spill // DeviceGemm_Xdl_CShuffle_Streamk_V3< Row, Col, Row, BF16, BF16, BF16, F32, BF16, PassThrough, PassThrough, PassThrough, GemmSpec, 256, 256, 256, 32, 8, 8, 16, 16, 8, 8, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, 1, 2, S<1, 32, 1, 8>, 8, BlockGemmPipelineScheduler::Intrawave, BlockGemmPipelineVersion::v5>, // AGPR Spill when use permuted lds layout. so, use padding for these two. 
+ DeviceGemm_Xdl_CShuffle_Streamk_V3< Row, Col, Row, BF16, BF16, BF16, F32, BF16, PassThrough, PassThrough, PassThrough, GemmSpec, 256, 256, 224, 64, 8, 8, 16, 16, 8, 7, S<8, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 0, S<8, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 0, 2, 1, S<1, 32, 1, 8>, 4, BlockGemmPipelineScheduler::Intrawave, BlockGemmPipelineVersion::v3>, + DeviceGemm_Xdl_CShuffle_Streamk_V3< Row, Col, Row, BF16, BF16, BF16, F32, BF16, PassThrough, PassThrough, PassThrough, GemmSpec, 256, 128, 128, 64, 8, 8, 32, 32, 2, 2, S<8, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 0, S<8, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 0, 1, 1, S<1, 16, 1, 16>, 4, BlockGemmPipelineScheduler::Intrawave, BlockGemmPipelineVersion::v3>, DeviceGemm_Xdl_CShuffle_Streamk_V3< Row, Col, Row, BF16, BF16, BF16, F32, BF16, PassThrough, PassThrough, PassThrough, GemmSpec, 256, 128, 128, 64, 8, 8, 32, 32, 2, 2, S<8, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 0, S<8, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 0, 1, 1, S<1, 16, 1, 16>, 4, BlockGemmPipelineScheduler::Intrawave, BlockGemmPipelineVersion::v5>, DeviceGemm_Xdl_CShuffle_Streamk_V3< Row, Col, Row, BF16, BF16, BF16, F32, BF16, PassThrough, PassThrough, PassThrough, GemmSpec, 256, 128, 128, 64, 8, 8, 32, 32, 2, 2, S<8, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 0, S<8, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 0, 1, 1, S<1, 16, 1, 16>, 4, BlockGemmPipelineScheduler::Interwave, BlockGemmPipelineVersion::v1> diff --git a/library/src/tensor_operation_instance/gpu/gemm_universal_streamk/device_gemm_xdl_universal_streamk_bf16_bf16_bf16/device_gemm_xdl_universal_streamk_bf16_bf16_bf16_mk_nk_mn_comp_mnkpadding_instance.cpp b/library/src/tensor_operation_instance/gpu/gemm_universal_streamk/device_gemm_xdl_universal_streamk_bf16_bf16_bf16/device_gemm_xdl_universal_streamk_bf16_bf16_bf16_mk_nk_mn_comp_mnkpadding_instance.cpp new file mode 100755 index 0000000000..a16d3988fe --- /dev/null +++ 
b/library/src/tensor_operation_instance/gpu/gemm_universal_streamk/device_gemm_xdl_universal_streamk_bf16_bf16_bf16/device_gemm_xdl_universal_streamk_bf16_bf16_bf16_mk_nk_mn_comp_mnkpadding_instance.cpp @@ -0,0 +1,31 @@ + +// SPDX-License-Identifier: MIT +// Copyright (c) 2025, Advanced Micro Devices, Inc. All rights reserved. + +#include "device_gemm_xdl_universal_streamk_bf16_bf16_bf16_mk_nk_mn.hpp" + +namespace ck { +namespace tensor_operation { +namespace device { +namespace instance { + +void add_device_gemm_xdl_universal_streamk_bf16_bf16_bf16_mk_nk_mn_comp_mnkpadding_instances( + std::vector>>& instances) +{ + add_device_operation_instances( + instances, + device_gemm_xdl_universal_streamk_bf16_bf16_bf16_mk_nk_mn_comp_instances{}); +} + +} // namespace instance +} // namespace device +} // namespace tensor_operation +} // namespace ck diff --git a/library/src/tensor_operation_instance/gpu/gemm_universal_streamk/device_gemm_xdl_universal_streamk_bf16_bf16_bf16/device_gemm_xdl_universal_streamk_bf16_bf16_bf16_mk_nk_mn_comp_mnpadding_instance.cpp b/library/src/tensor_operation_instance/gpu/gemm_universal_streamk/device_gemm_xdl_universal_streamk_bf16_bf16_bf16/device_gemm_xdl_universal_streamk_bf16_bf16_bf16_mk_nk_mn_comp_mnpadding_instance.cpp new file mode 100755 index 0000000000..3716b46f6c --- /dev/null +++ b/library/src/tensor_operation_instance/gpu/gemm_universal_streamk/device_gemm_xdl_universal_streamk_bf16_bf16_bf16/device_gemm_xdl_universal_streamk_bf16_bf16_bf16_mk_nk_mn_comp_mnpadding_instance.cpp @@ -0,0 +1,30 @@ +// SPDX-License-Identifier: MIT +// Copyright (c) 2025, Advanced Micro Devices, Inc. All rights reserved. 
+ +#include "device_gemm_xdl_universal_streamk_bf16_bf16_bf16_mk_nk_mn.hpp" + +namespace ck { +namespace tensor_operation { +namespace device { +namespace instance { + +void add_device_gemm_xdl_universal_streamk_bf16_bf16_bf16_mk_nk_mn_comp_mnpadding_instances( + std::vector>>& instances) +{ + add_device_operation_instances( + instances, + device_gemm_xdl_universal_streamk_bf16_bf16_bf16_mk_nk_mn_comp_instances{}); +} + +} // namespace instance +} // namespace device +} // namespace tensor_operation +} // namespace ck diff --git a/library/src/tensor_operation_instance/gpu/gemm_universal_streamk/device_gemm_xdl_universal_streamk_bf16_bf16_bf16/device_gemm_xdl_universal_streamk_bf16_bf16_bf16_mk_nk_mn_mem_v1_mnkpadding_instance.cpp b/library/src/tensor_operation_instance/gpu/gemm_universal_streamk/device_gemm_xdl_universal_streamk_bf16_bf16_bf16/device_gemm_xdl_universal_streamk_bf16_bf16_bf16_mk_nk_mn_mem_v1_mnkpadding_instance.cpp new file mode 100755 index 0000000000..00ed1698dd --- /dev/null +++ b/library/src/tensor_operation_instance/gpu/gemm_universal_streamk/device_gemm_xdl_universal_streamk_bf16_bf16_bf16/device_gemm_xdl_universal_streamk_bf16_bf16_bf16_mk_nk_mn_mem_v1_mnkpadding_instance.cpp @@ -0,0 +1,31 @@ +// SPDX-License-Identifier: MIT +// Copyright (c) 2025, Advanced Micro Devices, Inc. All rights reserved. 
+ +#include "device_gemm_xdl_universal_streamk_bf16_bf16_bf16_mk_nk_mn.hpp" + +namespace ck { +namespace tensor_operation { +namespace device { +namespace instance { + +void add_device_gemm_xdl_universal_streamk_bf16_bf16_bf16_mk_nk_mn_mem_v1_mnkpadding_instances( + std::vector>>& instances) +{ + add_device_operation_instances( + instances, + device_gemm_xdl_universal_streamk_bf16_bf16_bf16_mk_nk_mn_mem_instances{}); +} + +} // namespace instance +} // namespace device +} // namespace tensor_operation +} // namespace ck diff --git a/library/src/tensor_operation_instance/gpu/gemm_universal_streamk/device_gemm_xdl_universal_streamk_bf16_bf16_bf16/device_gemm_xdl_universal_streamk_bf16_bf16_bf16_mk_nk_mn_mem_v2_mnkpadding_instance.cpp b/library/src/tensor_operation_instance/gpu/gemm_universal_streamk/device_gemm_xdl_universal_streamk_bf16_bf16_bf16/device_gemm_xdl_universal_streamk_bf16_bf16_bf16_mk_nk_mn_mem_v2_mnkpadding_instance.cpp new file mode 100755 index 0000000000..bee03061a0 --- /dev/null +++ b/library/src/tensor_operation_instance/gpu/gemm_universal_streamk/device_gemm_xdl_universal_streamk_bf16_bf16_bf16/device_gemm_xdl_universal_streamk_bf16_bf16_bf16_mk_nk_mn_mem_v2_mnkpadding_instance.cpp @@ -0,0 +1,31 @@ +// SPDX-License-Identifier: MIT +// Copyright (c) 2025, Advanced Micro Devices, Inc. All rights reserved. 
+ +#include "device_gemm_xdl_universal_streamk_bf16_bf16_bf16_mk_nk_mn.hpp" + +namespace ck { +namespace tensor_operation { +namespace device { +namespace instance { + +void add_device_gemm_xdl_universal_streamk_bf16_bf16_bf16_mk_nk_mn_mem_v2_mnkpadding_instances( + std::vector>>& instances) +{ + add_device_operation_instances( + instances, + device_gemm_xdl_universal_streamk_bf16_bf16_bf16_mk_nk_mn_mem_instances{}); +} + +} // namespace instance +} // namespace device +} // namespace tensor_operation +} // namespace ck diff --git a/library/src/tensor_operation_instance/gpu/gemm_universal_streamk/device_gemm_xdl_universal_streamk_f8_f8_bf16/device_gemm_xdl_universal_streamk_f8_f8_bf16_mk_kn_mn.hpp b/library/src/tensor_operation_instance/gpu/gemm_universal_streamk/device_gemm_xdl_universal_streamk_f8_f8_bf16/device_gemm_xdl_universal_streamk_f8_f8_bf16_mk_kn_mn.hpp new file mode 100755 index 0000000000..5bf5c01b97 --- /dev/null +++ b/library/src/tensor_operation_instance/gpu/gemm_universal_streamk/device_gemm_xdl_universal_streamk_f8_f8_bf16/device_gemm_xdl_universal_streamk_f8_f8_bf16_mk_kn_mn.hpp @@ -0,0 +1,99 @@ +// SPDX-License-Identifier: MIT +// Copyright (c) 2024, Advanced Micro Devices, Inc. All rights reserved. 
+ +#include "ck/ck.hpp" +#include "ck/tensor_operation/gpu/device/tensor_layout.hpp" +#include "ck/tensor_operation/gpu/device/gemm_specialization.hpp" +#include "ck/tensor_operation/gpu/device/impl/device_gemm_xdl_cshuffle_streamk_v3.hpp" + +#include "ck/library/tensor_operation_instance/add_device_operation_instance.hpp" + +namespace ck { +namespace tensor_operation { +namespace device { +namespace instance { + +using F8 = f8_t; +using BF16 = bhalf_t; +using F32 = float; + +using Row = tensor_layout::gemm::RowMajor; + +template +using S = Sequence; + +using PassThrough = element_wise::PassThrough; + +static constexpr auto GemmDefault = GemmSpecialization::Default; +static constexpr auto GemmKPadding = GemmSpecialization::KPadding; +static constexpr auto GemmMNPadding = GemmSpecialization::MNPadding; +static constexpr auto GemmNKPadding = GemmSpecialization::NKPadding; +static constexpr auto GemmMNKPadding = GemmSpecialization::MNKPadding; + +static constexpr auto Intrawave = BlockGemmPipelineScheduler::Intrawave; +static constexpr auto Interwave = BlockGemmPipelineScheduler::Interwave; + +template +using device_gemm_xdl_universal_streamk_f8_f8_bf16_mk_kn_mn_comp_instances = std::tuple< +// clang-format off + //#########################| ALayout| BLayout| CLayout|AData| BData| CData| AccData| Cshuffle| A| B| C| GEMM| Block| MPer| NPer| KPer| AK1| BK1|MPer| NPer| MXdl| NXdl| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockLds| BBlockTransfer| BBlockTransfer| BBlockTransfer| BlockTransfer| BBlockTransfer| BBlockTransfer| BBlockLds| CShuffle| CShuffle| CBlockTransferClusterLengths| CBlockTransfer| Block-wiseGemm| Block-wiseGemm| + //#########################| | | | Type| Type| Type| Type| Type| Elementwise| Elementwise| Elementwise|Specialization| Size| Block| Block| Block| | | XDL| XDL| Per| Per| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraM| ThreadCluster| 
ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraN| MXdlPerWave| NXdlPerWave| _MBlock_MXdlPerWave_MWaveMPerXdl| ScalarPerVector| Pipeline| Pipeline| + //#########################| | | | | | | | | Operation| Operation| Operation| | | | | | | | | | Wave| Wave| Lengths_K0_M_K1| ArrangeOrder| | | PerVector| PerVector_K1| | Lengths_K0_N_K1| ArrangeOrder| | | PerVector| PerVector_K1| | PerShuffle| PerShuffle| _NBlock_NXdlPerWave_NWaveNPerXdl| _NWaveNPerXdl| Scheduler| Verision| + //#########################| | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | +#if defined(__gfx94__) || defined(CK_USE_GFX94) || defined(CK_USE_FP8_ON_UNSUPPORTED_ARCH) + //Only enable these instances on gfx94x + // Compute friendly + DeviceGemm_Xdl_CShuffle_Streamk_V3< Row, Row, Row, F8, F8, BF16, F32, BF16, PassThrough, PassThrough, PassThrough, GemmSpec, 256, 256, 256, 64, 16, 4, 32, 32, 4, 4, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, 0, S<16, 16, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 16, 4, 0, 1, 1, S<1, 32, 1, 8>, 8, BlockGemmPipelineScheduler::Intrawave, BlockGemmPipelineVersion::v4, F8>, + DeviceGemm_Xdl_CShuffle_Streamk_V3< Row, Row, Row, F8, F8, BF16, F32, BF16, PassThrough, PassThrough, PassThrough, GemmSpec, 256, 128, 128, 128, 16, 4, 32, 32, 2, 2, S<8, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, 0, S<32, 8, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 16, 4, 0, 1, 1, S<1, 32, 1, 8>, 8, BlockGemmPipelineScheduler::Intrawave, BlockGemmPipelineVersion::v4, F8>, + DeviceGemm_Xdl_CShuffle_Streamk_V3< Row, Row, Row, F8, F8, BF16, F32, BF16, PassThrough, PassThrough, PassThrough, GemmSpec, 256, 128, 128, 64, 16, 4, 32, 32, 2, 2, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, 0, S<16, 16, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 8, 4, 0, 1, 1, S<1, 32, 1, 8>, 8, BlockGemmPipelineScheduler::Intrawave, BlockGemmPipelineVersion::v4, F8>, + DeviceGemm_Xdl_CShuffle_Streamk_V3< Row, Row, Row, F8, F8, BF16, F32, BF16, PassThrough, PassThrough, PassThrough, 
GemmSpec, 256, 256, 256, 128, 16, 8, 16, 16, 8, 8, S<8, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, 0, S<16, 16, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 16, 8, 0, 1, 2, S<1, 32, 1, 8>, 8, BlockGemmPipelineScheduler::Intrawave, BlockGemmPipelineVersion::v3, F8>, + DeviceGemm_Xdl_CShuffle_Streamk_V3< Row, Row, Row, F8, F8, BF16, F32, BF16, PassThrough, PassThrough, PassThrough, GemmSpec, 256, 256, 256, 64, 16, 4, 16, 16, 8, 8, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, 0, S<16, 16, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 16, 4, 0, 1, 2, S<1, 32, 1, 8>, 8, BlockGemmPipelineScheduler::Intrawave, BlockGemmPipelineVersion::v3, F8>, + DeviceGemm_Xdl_CShuffle_Streamk_V3< Row, Row, Row, F8, F8, BF16, F32, BF16, PassThrough, PassThrough, PassThrough, GemmSpec, 256, 224, 256, 128, 16, 8, 16, 16, 7, 8, S<8, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, 0, S<16, 16, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 16, 8, 0, 1, 2, S<1, 32, 1, 8>, 8, BlockGemmPipelineScheduler::Intrawave, BlockGemmPipelineVersion::v3, F8>, + DeviceGemm_Xdl_CShuffle_Streamk_V3< Row, Row, Row, F8, F8, BF16, F32, BF16, PassThrough, PassThrough, PassThrough, GemmSpec, 256, 128, 128, 128, 16, 4, 32, 32, 2, 2, S<8, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, 0, S<32, 8, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 16, 4, 0, 1, 1, S<1, 32, 1, 8>, 8, BlockGemmPipelineScheduler::Intrawave, BlockGemmPipelineVersion::v3, F8>, + DeviceGemm_Xdl_CShuffle_Streamk_V3< Row, Row, Row, F8, F8, BF16, F32, BF16, PassThrough, PassThrough, PassThrough, GemmSpec, 256, 128, 128, 128, 16, 4, 32, 32, 2, 2, S<8, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, 0, S<32, 8, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 16, 4, 0, 1, 1, S<1, 32, 1, 8>, 8, BlockGemmPipelineScheduler::Intrawave, BlockGemmPipelineVersion::v5, F8>, + DeviceGemm_Xdl_CShuffle_Streamk_V3< Row, Row, Row, F8, F8, BF16, F32, BF16, PassThrough, PassThrough, PassThrough, GemmSpec, 256, 128, 256, 64, 16, 4, 32, 32, 2, 4, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, 0, S<16, 16, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 16, 4, 0, 1, 1, 
S<1, 32, 1, 8>, 8, BlockGemmPipelineScheduler::Interwave, BlockGemmPipelineVersion::v1, F8>, + DeviceGemm_Xdl_CShuffle_Streamk_V3< Row, Row, Row, F8, F8, BF16, F32, BF16, PassThrough, PassThrough, PassThrough, GemmSpec, 256, 256, 128, 64, 16, 4, 32, 32, 4, 2, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, 0, S<16, 16, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 8, 4, 0, 1, 1, S<1, 32, 1, 8>, 8, BlockGemmPipelineScheduler::Interwave, BlockGemmPipelineVersion::v1, F8>, + DeviceGemm_Xdl_CShuffle_Streamk_V3< Row, Row, Row, F8, F8, BF16, F32, BF16, PassThrough, PassThrough, PassThrough, GemmSpec, 256, 128, 128, 128, 16, 4, 32, 32, 2, 2, S<8, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, 0, S<32, 8, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 16, 4, 0, 1, 1, S<1, 32, 1, 8>, 8, BlockGemmPipelineScheduler::Interwave, BlockGemmPipelineVersion::v1, F8>, + DeviceGemm_Xdl_CShuffle_Streamk_V3< Row, Row, Row, F8, F8, BF16, F32, BF16, PassThrough, PassThrough, PassThrough, GemmSpec, 256, 128, 64, 128, 16, 4, 32, 32, 2, 1, S<8, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, 0, S<32, 8, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 8, 4, 0, 1, 1, S<1, 32, 1, 8>, 8, BlockGemmPipelineScheduler::Intrawave, BlockGemmPipelineVersion::v3, F8>, + DeviceGemm_Xdl_CShuffle_Streamk_V3< Row, Row, Row, F8, F8, BF16, F32, BF16, PassThrough, PassThrough, PassThrough, GemmSpec, 256, 64, 64, 128, 16, 4, 32, 32, 1, 1, S<8, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, 0, S<32, 8, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 8, 4, 0, 1, 1, S<1, 32, 1, 8>, 8, BlockGemmPipelineScheduler::Intrawave, BlockGemmPipelineVersion::v3, F8> +#endif + // clang-format on + >; + +template +using device_gemm_xdl_universal_streamk_f8_f8_bf16_mk_kn_mn_mem_instances = std::tuple< +// clang-format off + //#########################| ALayout| BLayout| CLayout|AData| BData| CData| AccData| Cshuffle| A| B| C| GEMM| Block| MPer| NPer| KPer| AK1| BK1|MPer| NPer| MXdl| NXdl| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockLds| 
BBlockTransfer| BBlockTransfer| BBlockTransfer| BlockTransfer| BBlockTransfer| BBlockTransfer| BBlockLds| CShuffle| CShuffle| CBlockTransferClusterLengths| CBlockTransfer| Block-wiseGemm| Block-wiseGemm| + //#########################| | | | Type| Type| Type| Type| Type| Elementwise| Elementwise| Elementwise|Specialization| Size| Block| Block| Block| | | XDL| XDL| Per| Per| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraM| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraN| MXdlPerWave| NXdlPerWave| _MBlock_MXdlPerWave_MWaveMPerXdl| ScalarPerVector| Pipeline| Pipeline| + //#########################| | | | | | | | | Operation| Operation| Operation| | | | | | | | | | Wave| Wave| Lengths_K0_M_K1| ArrangeOrder| | | PerVector| PerVector_K1| | Lengths_K0_N_K1| ArrangeOrder| | | PerVector| PerVector_K1| | PerShuffle| PerShuffle| _NBlock_NXdlPerWave_NWaveNPerXdl| _NWaveNPerXdl| Scheduler| Verision| + //#########################| | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | +#if defined(__gfx94__) || defined(CK_USE_GFX94) || defined(CK_USE_FP8_ON_UNSUPPORTED_ARCH) + // Latency friendly + DeviceGemm_Xdl_CShuffle_Streamk_V3< Row, Row, Row, F8, F8, BF16, F32, BF16, PassThrough, PassThrough, PassThrough, GemmSpec, 128, 32, 16, 128, 16, 4, 16, 16, 1, 1, S<8, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, 0, S<32, 4, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 4, 4, 0, 1, 1, S<1, 16, 1, 8>, 2, BlkGemmPipeSched, BlockGemmPipelineVersion::v1, F8>, + DeviceGemm_Xdl_CShuffle_Streamk_V3< Row, Row, Row, F8, F8, BF16, F32, BF16, PassThrough, PassThrough, PassThrough, GemmSpec, 64, 16, 16, 128, 16, 4, 16, 16, 1, 1, S<8, 8, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, 0, S<32, 2, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 8, 4, 0, 1, 1, S<1, 16, 1, 4>, 4, BlkGemmPipeSched, BlockGemmPipelineVersion::v1, F8>, + DeviceGemm_Xdl_CShuffle_Streamk_V3< Row, Row, Row, F8, F8, BF16, F32, BF16, PassThrough, 
PassThrough, PassThrough, GemmSpec, 64, 16, 16, 256, 16, 4, 16, 16, 1, 1, S<8, 8, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, 0, S<64, 1, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 16, 4, 0, 1, 1, S<1, 16, 1, 4>, 4, BlkGemmPipeSched, BlockGemmPipelineVersion::v1, F8>, + DeviceGemm_Xdl_CShuffle_Streamk_V3< Row, Row, Row, F8, F8, BF16, F32, BF16, PassThrough, PassThrough, PassThrough, GemmSpec, 64, 16, 16, 512, 16, 8, 16, 16, 1, 1, S<8, 8, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, 0, S<64, 1, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 16, 8, 0, 1, 1, S<1, 16, 1, 4>, 4, BlkGemmPipeSched, BlockGemmPipelineVersion::v1, F8>, + DeviceGemm_Xdl_CShuffle_Streamk_V3< Row, Row, Row, F8, F8, BF16, F32, BF16, PassThrough, PassThrough, PassThrough, GemmSpec, 128, 16, 32, 128, 16, 4, 16, 16, 1, 1, S<8, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, 0, S<32, 4, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 8, 4, 0, 1, 1, S<1, 16, 1, 8>, 4, BlkGemmPipeSched, BlockGemmPipelineVersion::v1, F8>, + DeviceGemm_Xdl_CShuffle_Streamk_V3< Row, Row, Row, F8, F8, BF16, F32, BF16, PassThrough, PassThrough, PassThrough, GemmSpec, 128, 16, 32, 256, 16, 4, 16, 16, 1, 1, S<8, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, 0, S<64, 2, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 16, 4, 0, 1, 1, S<1, 16, 1, 8>, 4, BlkGemmPipeSched, BlockGemmPipelineVersion::v1, F8>, + DeviceGemm_Xdl_CShuffle_Streamk_V3< Row, Row, Row, F8, F8, BF16, F32, BF16, PassThrough, PassThrough, PassThrough, GemmSpec, 128, 16, 32, 512, 16, 8, 16, 16, 1, 1, S<8, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, 0, S<64, 2, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 16, 8, 0, 1, 1, S<1, 16, 1, 8>, 4, BlkGemmPipeSched, BlockGemmPipelineVersion::v1, F8>, + // Memory friendly + DeviceGemm_Xdl_CShuffle_Streamk_V3< Row, Row, Row, F8, F8, BF16, F32, BF16, PassThrough, PassThrough, PassThrough, GemmSpec, 256, 256, 16, 128, 16, 4, 16, 16, 4, 1, S<8, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, 0, S<32, 4, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 4, 4, 0, 1, 1, S<1, 32, 1, 8>, 2, BlkGemmPipeSched, BlockGemmPipelineVersion::v2, F8>, + 
DeviceGemm_Xdl_CShuffle_Streamk_V3< Row, Row, Row, F8, F8, BF16, F32, BF16, PassThrough, PassThrough, PassThrough, GemmSpec, 128, 128, 16, 128, 16, 4, 16, 16, 4, 1, S<8, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, 0, S<32, 4, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 4, 4, 0, 1, 1, S<1, 16, 1, 8>, 2, BlkGemmPipeSched, BlockGemmPipelineVersion::v2, F8>, + DeviceGemm_Xdl_CShuffle_Streamk_V3< Row, Row, Row, F8, F8, BF16, F32, BF16, PassThrough, PassThrough, PassThrough, GemmSpec, 128, 64, 16, 128, 16, 4, 16, 16, 2, 1, S<8, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, 0, S<32, 4, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 4, 4, 0, 1, 1, S<1, 16, 1, 8>, 2, BlkGemmPipeSched, BlockGemmPipelineVersion::v2, F8>, + DeviceGemm_Xdl_CShuffle_Streamk_V3< Row, Row, Row, F8, F8, BF16, F32, BF16, PassThrough, PassThrough, PassThrough, GemmSpec, 128, 32, 16, 128, 16, 4, 16, 16, 1, 1, S<8, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, 0, S<32, 4, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 4, 4, 0, 1, 1, S<1, 16, 1, 8>, 2, BlkGemmPipeSched, BlockGemmPipelineVersion::v2, F8>, + DeviceGemm_Xdl_CShuffle_Streamk_V3< Row, Row, Row, F8, F8, BF16, F32, BF16, PassThrough, PassThrough, PassThrough, GemmSpec, 64, 16, 16, 128, 16, 4, 16, 16, 1, 1, S<8, 8, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, 0, S<32, 2, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 8, 4, 0, 1, 1, S<1, 16, 1, 4>, 4, BlkGemmPipeSched, BlockGemmPipelineVersion::v2, F8>, + DeviceGemm_Xdl_CShuffle_Streamk_V3< Row, Row, Row, F8, F8, BF16, F32, BF16, PassThrough, PassThrough, PassThrough, GemmSpec, 64, 16, 16, 256, 16, 4, 16, 16, 1, 1, S<8, 8, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, 0, S<64, 1, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 16, 4, 0, 1, 1, S<1, 16, 1, 4>, 4, BlkGemmPipeSched, BlockGemmPipelineVersion::v2, F8>, + DeviceGemm_Xdl_CShuffle_Streamk_V3< Row, Row, Row, F8, F8, BF16, F32, BF16, PassThrough, PassThrough, PassThrough, GemmSpec, 64, 16, 16, 512, 16, 8, 16, 16, 1, 1, S<8, 8, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, 0, S<64, 1, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 16, 8, 0, 1, 1, S<1, 16, 1, 
4>, 4, BlkGemmPipeSched, BlockGemmPipelineVersion::v2, F8>, + DeviceGemm_Xdl_CShuffle_Streamk_V3< Row, Row, Row, F8, F8, BF16, F32, BF16, PassThrough, PassThrough, PassThrough, GemmSpec, 128, 16, 32, 128, 16, 4, 16, 16, 1, 1, S<8, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, 0, S<32, 4, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 8, 4, 0, 1, 1, S<1, 16, 1, 8>, 4, BlkGemmPipeSched, BlockGemmPipelineVersion::v2, F8>, + DeviceGemm_Xdl_CShuffle_Streamk_V3< Row, Row, Row, F8, F8, BF16, F32, BF16, PassThrough, PassThrough, PassThrough, GemmSpec, 128, 16, 32, 256, 16, 4, 16, 16, 1, 1, S<8, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, 0, S<64, 2, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 16, 4, 0, 1, 1, S<1, 16, 1, 8>, 4, BlkGemmPipeSched, BlockGemmPipelineVersion::v2, F8>, + DeviceGemm_Xdl_CShuffle_Streamk_V3< Row, Row, Row, F8, F8, BF16, F32, BF16, PassThrough, PassThrough, PassThrough, GemmSpec, 128, 16, 32, 512, 16, 8, 16, 16, 1, 1, S<8, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, 0, S<64, 2, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 16, 8, 0, 1, 1, S<1, 16, 1, 8>, 4, BlkGemmPipeSched, BlockGemmPipelineVersion::v2, F8>, + DeviceGemm_Xdl_CShuffle_Streamk_V3< Row, Row, Row, F8, F8, BF16, F32, BF16, PassThrough, PassThrough, PassThrough, GemmSpec, 128, 16, 64, 128, 16, 4, 16, 16, 1, 2, S<8, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, 0, S<32, 4, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 16, 4, 0, 1, 1, S<1, 16, 1, 8>, 4, BlkGemmPipeSched, BlockGemmPipelineVersion::v2, F8>, + DeviceGemm_Xdl_CShuffle_Streamk_V3< Row, Row, Row, F8, F8, BF16, F32, BF16, PassThrough, PassThrough, PassThrough, GemmSpec, 128, 16, 128, 128, 16, 8, 16, 16, 1, 4, S<8, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, 0, S<16, 8, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 16, 8, 0, 1, 1, S<1, 16, 1, 8>, 4, BlkGemmPipeSched, BlockGemmPipelineVersion::v2, F8>, + DeviceGemm_Xdl_CShuffle_Streamk_V3< Row, Row, Row, F8, F8, BF16, F32, BF16, PassThrough, PassThrough, PassThrough, GemmSpec, 256, 16, 256, 128, 8, 8, 16, 16, 1, 4, S<16, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 
0, S<16, 16, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 16, 8, 0, 1, 1, S<1, 16, 1, 16>, 4, BlkGemmPipeSched, BlockGemmPipelineVersion::v2, F8> +#endif + // clang-format on + >; +} // namespace instance +} // namespace device +} // namespace tensor_operation +} // namespace ck diff --git a/library/src/tensor_operation_instance/gpu/gemm_universal_streamk/device_gemm_xdl_universal_streamk_f8_f8_bf16/device_gemm_xdl_universal_streamk_f8_f8_bf16_mk_kn_mn_comp_default_instance.cpp b/library/src/tensor_operation_instance/gpu/gemm_universal_streamk/device_gemm_xdl_universal_streamk_f8_f8_bf16/device_gemm_xdl_universal_streamk_f8_f8_bf16_mk_kn_mn_comp_default_instance.cpp new file mode 100644 index 0000000000..689c2bbbec --- /dev/null +++ b/library/src/tensor_operation_instance/gpu/gemm_universal_streamk/device_gemm_xdl_universal_streamk_f8_f8_bf16/device_gemm_xdl_universal_streamk_f8_f8_bf16_mk_kn_mn_comp_default_instance.cpp @@ -0,0 +1,24 @@ +// SPDX-License-Identifier: MIT +// Copyright (c) 2024, Advanced Micro Devices, Inc. All rights reserved. 
+ +#include "device_gemm_xdl_universal_streamk_f8_f8_bf16_mk_kn_mn.hpp" + +namespace ck { +namespace tensor_operation { +namespace device { +namespace instance { + +void add_device_gemm_xdl_universal_streamk_f8_f8_bf16_mk_kn_mn_comp_default_instances( + std::vector>>& + instances) +{ + add_device_operation_instances( + instances, + device_gemm_xdl_universal_streamk_f8_f8_bf16_mk_kn_mn_comp_instances{}); +} + +} // namespace instance +} // namespace device +} // namespace tensor_operation +} // namespace ck diff --git a/library/src/tensor_operation_instance/gpu/gemm_universal_streamk/device_gemm_xdl_universal_streamk_f8_f8_bf16/device_gemm_xdl_universal_streamk_f8_f8_bf16_mk_kn_mn_comp_kpadding_instance.cpp b/library/src/tensor_operation_instance/gpu/gemm_universal_streamk/device_gemm_xdl_universal_streamk_f8_f8_bf16/device_gemm_xdl_universal_streamk_f8_f8_bf16_mk_kn_mn_comp_kpadding_instance.cpp new file mode 100644 index 0000000000..149b830a83 --- /dev/null +++ b/library/src/tensor_operation_instance/gpu/gemm_universal_streamk/device_gemm_xdl_universal_streamk_f8_f8_bf16/device_gemm_xdl_universal_streamk_f8_f8_bf16_mk_kn_mn_comp_kpadding_instance.cpp @@ -0,0 +1,24 @@ +// SPDX-License-Identifier: MIT +// Copyright (c) 2024, Advanced Micro Devices, Inc. All rights reserved. 
+ +#include "device_gemm_xdl_universal_streamk_f8_f8_bf16_mk_kn_mn.hpp" + +namespace ck { +namespace tensor_operation { +namespace device { +namespace instance { + +void add_device_gemm_xdl_universal_streamk_f8_f8_bf16_mk_kn_mn_comp_kpadding_instances( + std::vector>>& + instances) +{ + add_device_operation_instances( + instances, + device_gemm_xdl_universal_streamk_f8_f8_bf16_mk_kn_mn_comp_instances{}); +} + +} // namespace instance +} // namespace device +} // namespace tensor_operation +} // namespace ck diff --git a/library/src/tensor_operation_instance/gpu/gemm_universal_streamk/device_gemm_xdl_universal_streamk_f8_f8_bf16/device_gemm_xdl_universal_streamk_f8_f8_bf16_mk_kn_mn_comp_nkpadding_instance.cpp b/library/src/tensor_operation_instance/gpu/gemm_universal_streamk/device_gemm_xdl_universal_streamk_f8_f8_bf16/device_gemm_xdl_universal_streamk_f8_f8_bf16_mk_kn_mn_comp_nkpadding_instance.cpp new file mode 100644 index 0000000000..db5082f25c --- /dev/null +++ b/library/src/tensor_operation_instance/gpu/gemm_universal_streamk/device_gemm_xdl_universal_streamk_f8_f8_bf16/device_gemm_xdl_universal_streamk_f8_f8_bf16_mk_kn_mn_comp_nkpadding_instance.cpp @@ -0,0 +1,24 @@ +// SPDX-License-Identifier: MIT +// Copyright (c) 2024, Advanced Micro Devices, Inc. All rights reserved. 
+ +#include "device_gemm_xdl_universal_streamk_f8_f8_bf16_mk_kn_mn.hpp" + +namespace ck { +namespace tensor_operation { +namespace device { +namespace instance { + +void add_device_gemm_xdl_universal_streamk_f8_f8_bf16_mk_kn_mn_comp_nkpadding_instances( + std::vector>>& + instances) +{ + add_device_operation_instances( + instances, + device_gemm_xdl_universal_streamk_f8_f8_bf16_mk_kn_mn_comp_instances{}); +} + +} // namespace instance +} // namespace device +} // namespace tensor_operation +} // namespace ck diff --git a/library/src/tensor_operation_instance/gpu/gemm_universal_streamk/device_gemm_xdl_universal_streamk_f8_f8_bf16/device_gemm_xdl_universal_streamk_f8_f8_bf16_mk_kn_mn_mem_v1_default_instance.cpp b/library/src/tensor_operation_instance/gpu/gemm_universal_streamk/device_gemm_xdl_universal_streamk_f8_f8_bf16/device_gemm_xdl_universal_streamk_f8_f8_bf16_mk_kn_mn_mem_v1_default_instance.cpp new file mode 100644 index 0000000000..cd2ad4f654 --- /dev/null +++ b/library/src/tensor_operation_instance/gpu/gemm_universal_streamk/device_gemm_xdl_universal_streamk_f8_f8_bf16/device_gemm_xdl_universal_streamk_f8_f8_bf16_mk_kn_mn_mem_v1_default_instance.cpp @@ -0,0 +1,25 @@ +// SPDX-License-Identifier: MIT +// Copyright (c) 2024, Advanced Micro Devices, Inc. All rights reserved. 
+ +#include "device_gemm_xdl_universal_streamk_f8_f8_bf16_mk_kn_mn.hpp" + +namespace ck { +namespace tensor_operation { +namespace device { +namespace instance { + +void add_device_gemm_xdl_universal_streamk_f8_f8_bf16_mk_kn_mn_mem_v1_default_instances( + std::vector>>& + instances) +{ + add_device_operation_instances( + instances, + device_gemm_xdl_universal_streamk_f8_f8_bf16_mk_kn_mn_mem_instances{}); +} + +} // namespace instance +} // namespace device +} // namespace tensor_operation +} // namespace ck diff --git a/library/src/tensor_operation_instance/gpu/gemm_universal_streamk/device_gemm_xdl_universal_streamk_f8_f8_bf16/device_gemm_xdl_universal_streamk_f8_f8_bf16_mk_kn_mn_mem_v1_kpadding_instance.cpp b/library/src/tensor_operation_instance/gpu/gemm_universal_streamk/device_gemm_xdl_universal_streamk_f8_f8_bf16/device_gemm_xdl_universal_streamk_f8_f8_bf16_mk_kn_mn_mem_v1_kpadding_instance.cpp new file mode 100644 index 0000000000..1ed170785b --- /dev/null +++ b/library/src/tensor_operation_instance/gpu/gemm_universal_streamk/device_gemm_xdl_universal_streamk_f8_f8_bf16/device_gemm_xdl_universal_streamk_f8_f8_bf16_mk_kn_mn_mem_v1_kpadding_instance.cpp @@ -0,0 +1,25 @@ +// SPDX-License-Identifier: MIT +// Copyright (c) 2024, Advanced Micro Devices, Inc. All rights reserved. 
+ +#include "device_gemm_xdl_universal_streamk_f8_f8_bf16_mk_kn_mn.hpp" + +namespace ck { +namespace tensor_operation { +namespace device { +namespace instance { + +void add_device_gemm_xdl_universal_streamk_f8_f8_bf16_mk_kn_mn_mem_v1_kpadding_instances( + std::vector>>& + instances) +{ + add_device_operation_instances( + instances, + device_gemm_xdl_universal_streamk_f8_f8_bf16_mk_kn_mn_mem_instances{}); +} + +} // namespace instance +} // namespace device +} // namespace tensor_operation +} // namespace ck diff --git a/library/src/tensor_operation_instance/gpu/gemm_universal_streamk/device_gemm_xdl_universal_streamk_f8_f8_bf16/device_gemm_xdl_universal_streamk_f8_f8_bf16_mk_kn_mn_mem_v1_nkpadding_instance.cpp b/library/src/tensor_operation_instance/gpu/gemm_universal_streamk/device_gemm_xdl_universal_streamk_f8_f8_bf16/device_gemm_xdl_universal_streamk_f8_f8_bf16_mk_kn_mn_mem_v1_nkpadding_instance.cpp new file mode 100644 index 0000000000..9e28c16191 --- /dev/null +++ b/library/src/tensor_operation_instance/gpu/gemm_universal_streamk/device_gemm_xdl_universal_streamk_f8_f8_bf16/device_gemm_xdl_universal_streamk_f8_f8_bf16_mk_kn_mn_mem_v1_nkpadding_instance.cpp @@ -0,0 +1,25 @@ +// SPDX-License-Identifier: MIT +// Copyright (c) 2024, Advanced Micro Devices, Inc. All rights reserved. 
+ +#include "device_gemm_xdl_universal_streamk_f8_f8_bf16_mk_kn_mn.hpp" + +namespace ck { +namespace tensor_operation { +namespace device { +namespace instance { + +void add_device_gemm_xdl_universal_streamk_f8_f8_bf16_mk_kn_mn_mem_v1_nkpadding_instances( + std::vector>>& + instances) +{ + add_device_operation_instances( + instances, + device_gemm_xdl_universal_streamk_f8_f8_bf16_mk_kn_mn_mem_instances{}); +} + +} // namespace instance +} // namespace device +} // namespace tensor_operation +} // namespace ck diff --git a/library/src/tensor_operation_instance/gpu/gemm_universal_streamk/device_gemm_xdl_universal_streamk_f8_f8_bf16/device_gemm_xdl_universal_streamk_f8_f8_bf16_mk_kn_mn_mem_v2_default_instance.cpp b/library/src/tensor_operation_instance/gpu/gemm_universal_streamk/device_gemm_xdl_universal_streamk_f8_f8_bf16/device_gemm_xdl_universal_streamk_f8_f8_bf16_mk_kn_mn_mem_v2_default_instance.cpp new file mode 100644 index 0000000000..85dc38fbe4 --- /dev/null +++ b/library/src/tensor_operation_instance/gpu/gemm_universal_streamk/device_gemm_xdl_universal_streamk_f8_f8_bf16/device_gemm_xdl_universal_streamk_f8_f8_bf16_mk_kn_mn_mem_v2_default_instance.cpp @@ -0,0 +1,25 @@ +// SPDX-License-Identifier: MIT +// Copyright (c) 2024, Advanced Micro Devices, Inc. All rights reserved. 
+ +#include "device_gemm_xdl_universal_streamk_f8_f8_bf16_mk_kn_mn.hpp" + +namespace ck { +namespace tensor_operation { +namespace device { +namespace instance { + +void add_device_gemm_xdl_universal_streamk_f8_f8_bf16_mk_kn_mn_mem_v2_default_instances( + std::vector>>& + instances) +{ + add_device_operation_instances( + instances, + device_gemm_xdl_universal_streamk_f8_f8_bf16_mk_kn_mn_mem_instances{}); +} + +} // namespace instance +} // namespace device +} // namespace tensor_operation +} // namespace ck diff --git a/library/src/tensor_operation_instance/gpu/gemm_universal_streamk/device_gemm_xdl_universal_streamk_f8_f8_bf16/device_gemm_xdl_universal_streamk_f8_f8_bf16_mk_kn_mn_mem_v2_kpadding_instance.cpp b/library/src/tensor_operation_instance/gpu/gemm_universal_streamk/device_gemm_xdl_universal_streamk_f8_f8_bf16/device_gemm_xdl_universal_streamk_f8_f8_bf16_mk_kn_mn_mem_v2_kpadding_instance.cpp new file mode 100644 index 0000000000..2f188ac939 --- /dev/null +++ b/library/src/tensor_operation_instance/gpu/gemm_universal_streamk/device_gemm_xdl_universal_streamk_f8_f8_bf16/device_gemm_xdl_universal_streamk_f8_f8_bf16_mk_kn_mn_mem_v2_kpadding_instance.cpp @@ -0,0 +1,25 @@ +// SPDX-License-Identifier: MIT +// Copyright (c) 2024, Advanced Micro Devices, Inc. All rights reserved. 
+ +#include "device_gemm_xdl_universal_streamk_f8_f8_bf16_mk_kn_mn.hpp" + +namespace ck { +namespace tensor_operation { +namespace device { +namespace instance { + +void add_device_gemm_xdl_universal_streamk_f8_f8_bf16_mk_kn_mn_mem_v2_kpadding_instances( + std::vector>>& + instances) +{ + add_device_operation_instances( + instances, + device_gemm_xdl_universal_streamk_f8_f8_bf16_mk_kn_mn_mem_instances{}); +} + +} // namespace instance +} // namespace device +} // namespace tensor_operation +} // namespace ck diff --git a/library/src/tensor_operation_instance/gpu/gemm_universal_streamk/device_gemm_xdl_universal_streamk_f8_f8_bf16/device_gemm_xdl_universal_streamk_f8_f8_bf16_mk_kn_mn_mem_v2_nkpadding_instance.cpp b/library/src/tensor_operation_instance/gpu/gemm_universal_streamk/device_gemm_xdl_universal_streamk_f8_f8_bf16/device_gemm_xdl_universal_streamk_f8_f8_bf16_mk_kn_mn_mem_v2_nkpadding_instance.cpp new file mode 100644 index 0000000000..94684921c7 --- /dev/null +++ b/library/src/tensor_operation_instance/gpu/gemm_universal_streamk/device_gemm_xdl_universal_streamk_f8_f8_bf16/device_gemm_xdl_universal_streamk_f8_f8_bf16_mk_kn_mn_mem_v2_nkpadding_instance.cpp @@ -0,0 +1,25 @@ +// SPDX-License-Identifier: MIT +// Copyright (c) 2024, Advanced Micro Devices, Inc. All rights reserved. 
+ +#include "device_gemm_xdl_universal_streamk_f8_f8_bf16_mk_kn_mn.hpp" + +namespace ck { +namespace tensor_operation { +namespace device { +namespace instance { + +void add_device_gemm_xdl_universal_streamk_f8_f8_bf16_mk_kn_mn_mem_v2_nkpadding_instances( + std::vector>>& + instances) +{ + add_device_operation_instances( + instances, + device_gemm_xdl_universal_streamk_f8_f8_bf16_mk_kn_mn_mem_instances{}); +} + +} // namespace instance +} // namespace device +} // namespace tensor_operation +} // namespace ck diff --git a/library/src/tensor_operation_instance/gpu/gemm_universal_streamk/device_gemm_xdl_universal_streamk_f8_f8_bf16/device_gemm_xdl_universal_streamk_f8_f8_bf16_mk_nk_mn.hpp b/library/src/tensor_operation_instance/gpu/gemm_universal_streamk/device_gemm_xdl_universal_streamk_f8_f8_bf16/device_gemm_xdl_universal_streamk_f8_f8_bf16_mk_nk_mn.hpp new file mode 100755 index 0000000000..540b90e54b --- /dev/null +++ b/library/src/tensor_operation_instance/gpu/gemm_universal_streamk/device_gemm_xdl_universal_streamk_f8_f8_bf16/device_gemm_xdl_universal_streamk_f8_f8_bf16_mk_nk_mn.hpp @@ -0,0 +1,107 @@ +// SPDX-License-Identifier: MIT +// Copyright (c) 2024, Advanced Micro Devices, Inc. All rights reserved. 
+ +#include "ck/ck.hpp" +#include "ck/tensor_operation/gpu/device/tensor_layout.hpp" +#include "ck/tensor_operation/gpu/device/gemm_specialization.hpp" +#include "ck/tensor_operation/gpu/device/impl/device_gemm_xdl_cshuffle_streamk_v3.hpp" + +#include "ck/library/tensor_operation_instance/add_device_operation_instance.hpp" + +namespace ck { +namespace tensor_operation { +namespace device { +namespace instance { + +using F8 = f8_t; +using BF16 = bhalf_t; +using F32 = float; + +using Row = tensor_layout::gemm::RowMajor; +using Col = tensor_layout::gemm::ColumnMajor; + +template +using S = Sequence; + +using PassThrough = element_wise::PassThrough; + +static constexpr auto GemmDefault = GemmSpecialization::Default; +static constexpr auto GemmKPadding = GemmSpecialization::KPadding; +static constexpr auto GemmMNPadding = GemmSpecialization::MNPadding; +static constexpr auto GemmMNKPadding = GemmSpecialization::MNKPadding; + +static constexpr auto Intrawave = BlockGemmPipelineScheduler::Intrawave; +static constexpr auto Interwave = BlockGemmPipelineScheduler::Interwave; + +template +using device_gemm_xdl_universal_streamk_f8_f8_bf16_mk_nk_mn_comp_instances = std::tuple< +// clang-format off + //#########################| ALayout| BLayout| CLayout|AData| BData| CData| AccData| Cshuffle| A| B| C| GEMM| Block| MPer| NPer| KPer| AK1| BK1|MPer| NPer| MXdl| NXdl| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockLds| BBlockTransfer| BBlockTransfer| BBlockTransfer| BlockTransfer| BBlockTransfer| BBlockTransfer| BBlockLds| CShuffle| CShuffle| CBlockTransferClusterLengths| CBlockTransfer| Block-wiseGemm| Block-wiseGemm| + //#########################| | | | Type| Type| Type| Type| Type| Elementwise| Elementwise| Elementwise|Specialization| Size| Block| Block| Block| | | XDL| XDL| Per| Per| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraM| ThreadCluster| ThreadCluster| SrcAccessOrder| 
SrcVectorDim| SrcScalar| DstScalar| AddExtraN| MXdlPerWave| NXdlPerWave| _MBlock_MXdlPerWave_MWaveMPerXdl| ScalarPerVector| Pipeline| Pipeline| + //#########################| | | | | | | | | Operation| Operation| Operation| | | | | | | | | | Wave| Wave| Lengths_K0_M_K1| ArrangeOrder| | | PerVector| PerVector_K1| | Lengths_K0_N_K1| ArrangeOrder| | | PerVector| PerVector_K1| | PerShuffle| PerShuffle| _NBlock_NXdlPerWave_NWaveNPerXdl| _NWaveNPerXdl| Scheduler| Verision| + //#########################| | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | +#if defined(__gfx94__) || defined(CK_USE_GFX94) || defined(CK_USE_FP8_ON_UNSUPPORTED_ARCH) + // Compute friendly + DeviceGemm_Xdl_CShuffle_Streamk_V3< Row, Col, Row, F8, F8, BF16, F32, BF16, PassThrough, PassThrough, PassThrough, GemmSpec, 256, 256, 256, 64, 16, 16, 32, 32, 4, 4, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, 0, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, 0, 1, 1, S<1, 32, 1, 8>, 8, BlockGemmPipelineScheduler::Intrawave, BlockGemmPipelineVersion::v4, F8>, + DeviceGemm_Xdl_CShuffle_Streamk_V3< Row, Col, Row, F8, F8, BF16, F32, BF16, PassThrough, PassThrough, PassThrough, GemmSpec, 256, 128, 128, 128, 16, 16, 32, 32, 2, 2, S<8, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, 0, S<8, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, 0, 1, 1, S<1, 32, 1, 8>, 8, BlockGemmPipelineScheduler::Intrawave, BlockGemmPipelineVersion::v4, F8>, + DeviceGemm_Xdl_CShuffle_Streamk_V3< Row, Col, Row, F8, F8, BF16, F32, BF16, PassThrough, PassThrough, PassThrough, GemmSpec, 256, 128, 128, 64, 16, 16, 32, 32, 2, 2, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, 0, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, 0, 1, 1, S<1, 32, 1, 8>, 8, BlockGemmPipelineScheduler::Intrawave, BlockGemmPipelineVersion::v4, F8>, + DeviceGemm_Xdl_CShuffle_Streamk_V3< Row, Col, Row, F8, F8, BF16, F32, BF16, PassThrough, PassThrough, PassThrough, GemmSpec, 256, 256, 256, 128, 16, 16, 16, 16, 8, 8, S<8, 32, 1>, S<1, 0, 
2>, S<1, 0, 2>, 2, 16, 16, 0, S<8, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, 0, 1, 2, S<1, 32, 1, 8>, 8, BlockGemmPipelineScheduler::Intrawave, BlockGemmPipelineVersion::v3, F8>, + DeviceGemm_Xdl_CShuffle_Streamk_V3< Row, Col, Row, F8, F8, BF16, F32, BF16, PassThrough, PassThrough, PassThrough, GemmSpec, 256, 256, 256, 64, 16, 16, 16, 16, 8, 8, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, 0, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, 0, 1, 2, S<1, 32, 1, 8>, 8, BlockGemmPipelineScheduler::Intrawave, BlockGemmPipelineVersion::v3, F8>, + DeviceGemm_Xdl_CShuffle_Streamk_V3< Row, Col, Row, F8, F8, BF16, F32, BF16, PassThrough, PassThrough, PassThrough, GemmSpec, 256, 224, 256, 128, 16, 16, 16, 16, 7, 8, S<8, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, 0, S<8, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, 0, 1, 2, S<1, 32, 1, 8>, 8, BlockGemmPipelineScheduler::Intrawave, BlockGemmPipelineVersion::v3, F8>, + DeviceGemm_Xdl_CShuffle_Streamk_V3< Row, Col, Row, F8, F8, BF16, F32, BF16, PassThrough, PassThrough, PassThrough, GemmSpec, 256, 256, 224, 128, 16, 16, 16, 16, 8, 7, S<8, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, 0, S<8, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, 0, 2, 1, S<1, 64, 1, 4>, 8, BlockGemmPipelineScheduler::Intrawave, BlockGemmPipelineVersion::v3, F8>, + DeviceGemm_Xdl_CShuffle_Streamk_V3< Row, Col, Row, F8, F8, BF16, F32, BF16, PassThrough, PassThrough, PassThrough, GemmSpec, 256, 128, 128, 128, 16, 16, 32, 32, 2, 2, S<8, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, 0, S<8, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, 0, 1, 1, S<1, 32, 1, 8>, 8, BlockGemmPipelineScheduler::Intrawave, BlockGemmPipelineVersion::v3, F8>, + DeviceGemm_Xdl_CShuffle_Streamk_V3< Row, Col, Row, F8, F8, BF16, F32, BF16, PassThrough, PassThrough, PassThrough, GemmSpec, 256, 128, 128, 128, 16, 16, 32, 32, 2, 2, S<8, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, 0, S<8, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, 0, 1, 1, S<1, 32, 1, 8>, 8, BlockGemmPipelineScheduler::Intrawave, 
BlockGemmPipelineVersion::v5, F8>, + DeviceGemm_Xdl_CShuffle_Streamk_V3< Row, Col, Row, F8, F8, BF16, F32, BF16, PassThrough, PassThrough, PassThrough, GemmSpec, 256, 128, 256, 64, 16, 16, 32, 32, 2, 4, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, 0, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, 0, 1, 1, S<1, 32, 1, 8>, 8, BlockGemmPipelineScheduler::Interwave, BlockGemmPipelineVersion::v1, F8>, + DeviceGemm_Xdl_CShuffle_Streamk_V3< Row, Col, Row, F8, F8, BF16, F32, BF16, PassThrough, PassThrough, PassThrough, GemmSpec, 256, 256, 128, 64, 16, 16, 32, 32, 4, 2, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, 0, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, 0, 1, 1, S<1, 32, 1, 8>, 8, BlockGemmPipelineScheduler::Interwave, BlockGemmPipelineVersion::v1, F8>, + DeviceGemm_Xdl_CShuffle_Streamk_V3< Row, Col, Row, F8, F8, BF16, F32, BF16, PassThrough, PassThrough, PassThrough, GemmSpec, 256, 128, 128, 128, 16, 16, 32, 32, 2, 2, S<8, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, 0, S<8, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, 0, 1, 1, S<1, 32, 1, 8>, 8, BlockGemmPipelineScheduler::Interwave, BlockGemmPipelineVersion::v1, F8>, + DeviceGemm_Xdl_CShuffle_Streamk_V3< Row, Col, Row, F8, F8, BF16, F32, BF16, PassThrough, PassThrough, PassThrough, GemmSpec, 256, 128, 64, 128, 16, 16, 32, 32, 2, 1, S<8, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, 0, S<8, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, 0, 1, 1, S<1, 32, 1, 8>, 8, BlockGemmPipelineScheduler::Intrawave, BlockGemmPipelineVersion::v3, F8>, + // DeviceGemm_Xdl_CShuffle_Streamk_V3< Row, Col, Row, F8, F8, BF16, F32, BF16, PassThrough, PassThrough, PassThrough, GemmSpec, 256, 64, 128, 128, 16, 16, 32, 32, 1, 2, S<8, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, 0, S<8, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, 0, 1, 1, S<1, 32, 1, 8>, 8, BlockGemmPipelineScheduler::Intrawave, BlockGemmPipelineVersion::v3, F8>, + DeviceGemm_Xdl_CShuffle_Streamk_V3< Row, Col, Row, F8, F8, BF16, F32, BF16, PassThrough, PassThrough, 
PassThrough, GemmSpec, 256, 64, 64, 128, 16, 16, 32, 32, 1, 1, S<8, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, 0, S<8, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, 0, 1, 1, S<1, 32, 1, 8>, 8, BlockGemmPipelineScheduler::Intrawave, BlockGemmPipelineVersion::v3, F8> +#endif + // clang-format on + >; + +template +using device_gemm_xdl_universal_streamk_f8_f8_bf16_mk_nk_mn_mem_instances = std::tuple< +// clang-format off + //#########################| ALayout| BLayout| CLayout|AData| BData| CData| AccData| Cshuffle| A| B| C| GEMM| Block| MPer| NPer| KPer| AK1| BK1|MPer| NPer| MXdl| NXdl| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockLds| BBlockTransfer| BBlockTransfer| BBlockTransfer| BlockTransfer| BBlockTransfer| BBlockTransfer| BBlockLds| CShuffle| CShuffle| CBlockTransferClusterLengths| CBlockTransfer| Block-wiseGemm| Block-wiseGemm| + //#########################| | | | Type| Type| Type| Type| Type| Elementwise| Elementwise| Elementwise|Specialization| Size| Block| Block| Block| | | XDL| XDL| Per| Per| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraM| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraN| MXdlPerWave| NXdlPerWave| _MBlock_MXdlPerWave_MWaveMPerXdl| ScalarPerVector| Pipeline| Pipeline| + //#########################| | | | | | | | | Operation| Operation| Operation| | | | | | | | | | Wave| Wave| Lengths_K0_M_K1| ArrangeOrder| | | PerVector| PerVector_K1| | Lengths_K0_N_K1| ArrangeOrder| | | PerVector| PerVector_K1| | PerShuffle| PerShuffle| _NBlock_NXdlPerWave_NWaveNPerXdl| _NWaveNPerXdl| Scheduler| Verision| + //#########################| | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | +#if defined(__gfx94__) || defined(CK_USE_GFX94) || defined(CK_USE_FP8_ON_UNSUPPORTED_ARCH) + // Latency friendly + DeviceGemm_Xdl_CShuffle_Streamk_V3< Row, Col, Row, F8, F8, BF16, F32, BF16, 
PassThrough, PassThrough, PassThrough, GemmSpec, 128, 32, 16, 128, 16, 16, 16, 16, 1, 1, S<8, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, 0, S<8, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, 0, 1, 1, S<1, 16, 1, 8>, 2, BlkGemmPipeSched, BlockGemmPipelineVersion::v1, F8>, + DeviceGemm_Xdl_CShuffle_Streamk_V3< Row, Col, Row, F8, F8, BF16, F32, BF16, PassThrough, PassThrough, PassThrough, GemmSpec, 64, 16, 16, 128, 16, 16, 16, 16, 1, 1, S<8, 8, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, 0, S<8, 8, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, 0, 1, 1, S<1, 16, 1, 4>, 4, BlkGemmPipeSched, BlockGemmPipelineVersion::v1, F8>, + DeviceGemm_Xdl_CShuffle_Streamk_V3< Row, Col, Row, F8, F8, BF16, F32, BF16, PassThrough, PassThrough, PassThrough, GemmSpec, 64, 16, 16, 256, 16, 16, 16, 16, 1, 1, S<8, 8, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, 0, S<8, 8, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, 0, 1, 1, S<1, 16, 1, 4>, 4, BlkGemmPipeSched, BlockGemmPipelineVersion::v1, F8>, + DeviceGemm_Xdl_CShuffle_Streamk_V3< Row, Col, Row, F8, F8, BF16, F32, BF16, PassThrough, PassThrough, PassThrough, GemmSpec, 64, 16, 16, 512, 16, 16, 16, 16, 1, 1, S<8, 8, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, 0, S<8, 8, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, 0, 1, 1, S<1, 16, 1, 4>, 4, BlkGemmPipeSched, BlockGemmPipelineVersion::v1, F8>, + DeviceGemm_Xdl_CShuffle_Streamk_V3< Row, Col, Row, F8, F8, BF16, F32, BF16, PassThrough, PassThrough, PassThrough, GemmSpec, 128, 16, 32, 128, 16, 16, 16, 16, 1, 1, S<8, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, 0, S<8, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, 0, 1, 1, S<1, 16, 1, 8>, 4, BlkGemmPipeSched, BlockGemmPipelineVersion::v1, F8>, + DeviceGemm_Xdl_CShuffle_Streamk_V3< Row, Col, Row, F8, F8, BF16, F32, BF16, PassThrough, PassThrough, PassThrough, GemmSpec, 128, 16, 32, 256, 16, 16, 16, 16, 1, 1, S<8, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, 0, S<8, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, 0, 1, 1, S<1, 16, 1, 8>, 4, BlkGemmPipeSched, BlockGemmPipelineVersion::v1, F8>, + 
DeviceGemm_Xdl_CShuffle_Streamk_V3< Row, Col, Row, F8, F8, BF16, F32, BF16, PassThrough, PassThrough, PassThrough, GemmSpec, 128, 16, 32, 512, 16, 16, 16, 16, 1, 1, S<8, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, 0, S<8, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, 0, 1, 1, S<1, 16, 1, 8>, 4, BlkGemmPipeSched, BlockGemmPipelineVersion::v1, F8>, + // Memory friendly + DeviceGemm_Xdl_CShuffle_Streamk_V3< Row, Col, Row, F8, F8, BF16, F32, BF16, PassThrough, PassThrough, PassThrough, GemmSpec, 256, 256, 32, 128, 16, 16, 32, 32, 2, 1, S<8, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, 0, S<8, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, 0, 1, 1, S<1, 32, 1, 8>, 4, BlkGemmPipeSched, BlockGemmPipelineVersion::v2, F8>, + DeviceGemm_Xdl_CShuffle_Streamk_V3< Row, Col, Row, F8, F8, BF16, F32, BF16, PassThrough, PassThrough, PassThrough, GemmSpec, 256, 256, 16, 128, 16, 16, 16, 16, 4, 1, S<8, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, 0, S<8, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, 0, 1, 1, S<1, 32, 1, 8>, 2, BlkGemmPipeSched, BlockGemmPipelineVersion::v2, F8>, + DeviceGemm_Xdl_CShuffle_Streamk_V3< Row, Col, Row, F8, F8, BF16, F32, BF16, PassThrough, PassThrough, PassThrough, GemmSpec, 128, 128, 32, 128, 16, 16, 32, 32, 2, 1, S<8, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, 0, S<8, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, 0, 1, 1, S<1, 16, 1, 8>, 4, BlkGemmPipeSched, BlockGemmPipelineVersion::v2, F8>, + DeviceGemm_Xdl_CShuffle_Streamk_V3< Row, Col, Row, F8, F8, BF16, F32, BF16, PassThrough, PassThrough, PassThrough, GemmSpec, 128, 128, 16, 128, 16, 16, 16, 16, 4, 1, S<8, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, 0, S<8, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, 0, 1, 1, S<1, 16, 1, 8>, 2, BlkGemmPipeSched, BlockGemmPipelineVersion::v2, F8>, + DeviceGemm_Xdl_CShuffle_Streamk_V3< Row, Col, Row, F8, F8, BF16, F32, BF16, PassThrough, PassThrough, PassThrough, GemmSpec, 128, 64, 32, 128, 16, 16, 32, 32, 1, 1, S<8, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, 0, S<8, 16, 1>, S<1, 0, 
2>, S<1, 0, 2>, 2, 16, 16, 0, 1, 1, S<1, 16, 1, 8>, 4, BlkGemmPipeSched, BlockGemmPipelineVersion::v2, F8>, + DeviceGemm_Xdl_CShuffle_Streamk_V3< Row, Col, Row, F8, F8, BF16, F32, BF16, PassThrough, PassThrough, PassThrough, GemmSpec, 128, 64, 16, 128, 16, 16, 16, 16, 2, 1, S<8, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, 0, S<8, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, 0, 1, 1, S<1, 16, 1, 8>, 2, BlkGemmPipeSched, BlockGemmPipelineVersion::v2, F8>, + DeviceGemm_Xdl_CShuffle_Streamk_V3< Row, Col, Row, F8, F8, BF16, F32, BF16, PassThrough, PassThrough, PassThrough, GemmSpec, 128, 32, 16, 128, 16, 16, 16, 16, 1, 1, S<8, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, 0, S<8, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, 0, 1, 1, S<1, 16, 1, 8>, 2, BlkGemmPipeSched, BlockGemmPipelineVersion::v2, F8>, + DeviceGemm_Xdl_CShuffle_Streamk_V3< Row, Col, Row, F8, F8, BF16, F32, BF16, PassThrough, PassThrough, PassThrough, GemmSpec, 64, 16, 16, 64, 16, 16, 16, 16, 1, 1, S<4, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, 0, S<4, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, 0, 1, 1, S<1, 16, 1, 4>, 4, BlkGemmPipeSched, BlockGemmPipelineVersion::v2, F8>, + DeviceGemm_Xdl_CShuffle_Streamk_V3< Row, Col, Row, F8, F8, BF16, F32, BF16, PassThrough, PassThrough, PassThrough, GemmSpec, 64, 16, 16, 128, 16, 16, 16, 16, 1, 1, S<8, 8, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, 0, S<8, 8, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, 0, 1, 1, S<1, 16, 1, 4>, 4, BlkGemmPipeSched, BlockGemmPipelineVersion::v2, F8>, + DeviceGemm_Xdl_CShuffle_Streamk_V3< Row, Col, Row, F8, F8, BF16, F32, BF16, PassThrough, PassThrough, PassThrough, GemmSpec, 64, 16, 16, 256, 16, 16, 16, 16, 1, 1, S<8, 8, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, 0, S<8, 8, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, 0, 1, 1, S<1, 16, 1, 4>, 4, BlkGemmPipeSched, BlockGemmPipelineVersion::v2, F8>, + DeviceGemm_Xdl_CShuffle_Streamk_V3< Row, Col, Row, F8, F8, BF16, F32, BF16, PassThrough, PassThrough, PassThrough, GemmSpec, 64, 16, 16, 512, 16, 16, 16, 16, 1, 
1, S<8, 8, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, 0, S<8, 8, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, 0, 1, 1, S<1, 16, 1, 4>, 4, BlkGemmPipeSched, BlockGemmPipelineVersion::v2, F8>, + DeviceGemm_Xdl_CShuffle_Streamk_V3< Row, Col, Row, F8, F8, BF16, F32, BF16, PassThrough, PassThrough, PassThrough, GemmSpec, 128, 16, 32, 128, 16, 16, 16, 16, 1, 1, S<8, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, 0, S<8, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, 0, 1, 1, S<1, 16, 1, 8>, 4, BlkGemmPipeSched, BlockGemmPipelineVersion::v2, F8>, + DeviceGemm_Xdl_CShuffle_Streamk_V3< Row, Col, Row, F8, F8, BF16, F32, BF16, PassThrough, PassThrough, PassThrough, GemmSpec, 128, 16, 32, 256, 16, 16, 16, 16, 1, 1, S<8, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, 0, S<8, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, 0, 1, 1, S<1, 16, 1, 8>, 4, BlkGemmPipeSched, BlockGemmPipelineVersion::v2, F8>, + DeviceGemm_Xdl_CShuffle_Streamk_V3< Row, Col, Row, F8, F8, BF16, F32, BF16, PassThrough, PassThrough, PassThrough, GemmSpec, 128, 16, 32, 512, 16, 16, 16, 16, 1, 1, S<8, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, 0, S<8, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, 0, 1, 1, S<1, 16, 1, 8>, 4, BlkGemmPipeSched, BlockGemmPipelineVersion::v2, F8>, + DeviceGemm_Xdl_CShuffle_Streamk_V3< Row, Col, Row, F8, F8, BF16, F32, BF16, PassThrough, PassThrough, PassThrough, GemmSpec, 128, 16, 64, 128, 16, 16, 16, 16, 1, 2, S<8, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, 0, S<8, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, 0, 1, 1, S<1, 16, 1, 8>, 4, BlkGemmPipeSched, BlockGemmPipelineVersion::v2, F8>, + DeviceGemm_Xdl_CShuffle_Streamk_V3< Row, Col, Row, F8, F8, BF16, F32, BF16, PassThrough, PassThrough, PassThrough, GemmSpec, 128, 32, 64, 128, 16, 16, 32, 32, 1, 1, S<8, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, 0, S<8, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, 0, 1, 1, S<1, 16, 1, 8>, 8, BlkGemmPipeSched, BlockGemmPipelineVersion::v2, F8>, + DeviceGemm_Xdl_CShuffle_Streamk_V3< Row, Col, Row, F8, F8, BF16, F32, BF16, 
PassThrough, PassThrough, PassThrough, GemmSpec, 128, 16, 128, 128, 16, 16, 16, 16, 1, 4, S<8, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, 0, S<8, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, 0, 1, 1, S<1, 16, 1, 8>, 4, BlkGemmPipeSched, BlockGemmPipelineVersion::v2, F8>, + DeviceGemm_Xdl_CShuffle_Streamk_V3< Row, Col, Row, F8, F8, BF16, F32, BF16, PassThrough, PassThrough, PassThrough, GemmSpec, 128, 32, 128, 128, 16, 16, 32, 32, 1, 2, S<8, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, 0, S<8, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, 0, 1, 1, S<1, 16, 1, 8>, 8, BlkGemmPipeSched, BlockGemmPipelineVersion::v2, F8>, + DeviceGemm_Xdl_CShuffle_Streamk_V3< Row, Col, Row, F8, F8, BF16, F32, BF16, PassThrough, PassThrough, PassThrough, GemmSpec, 256, 16, 256, 128, 16, 16, 16, 16, 1, 4, S<8, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, 0, S<8, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, 0, 1, 1, S<1, 16, 1, 16>, 4, BlkGemmPipeSched, BlockGemmPipelineVersion::v2, F8>, + DeviceGemm_Xdl_CShuffle_Streamk_V3< Row, Col, Row, F8, F8, BF16, F32, BF16, PassThrough, PassThrough, PassThrough, GemmSpec, 256, 32, 256, 128, 16, 16, 32, 32, 1, 2, S<8, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, 0, S<8, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, 0, 1, 1, S<1, 16, 1, 16>, 8, BlkGemmPipeSched, BlockGemmPipelineVersion::v2, F8> +#endif + // clang-format on + >; +} // namespace instance +} // namespace device +} // namespace tensor_operation +} // namespace ck diff --git a/library/src/tensor_operation_instance/gpu/gemm_universal_streamk/device_gemm_xdl_universal_streamk_f8_f8_bf16/device_gemm_xdl_universal_streamk_f8_f8_bf16_mk_nk_mn_comp_default_instance.cpp b/library/src/tensor_operation_instance/gpu/gemm_universal_streamk/device_gemm_xdl_universal_streamk_f8_f8_bf16/device_gemm_xdl_universal_streamk_f8_f8_bf16_mk_nk_mn_comp_default_instance.cpp new file mode 100644 index 0000000000..df07e21eef --- /dev/null +++ 
b/library/src/tensor_operation_instance/gpu/gemm_universal_streamk/device_gemm_xdl_universal_streamk_f8_f8_bf16/device_gemm_xdl_universal_streamk_f8_f8_bf16_mk_nk_mn_comp_default_instance.cpp @@ -0,0 +1,24 @@ +// SPDX-License-Identifier: MIT +// Copyright (c) 2024, Advanced Micro Devices, Inc. All rights reserved. + +#include "device_gemm_xdl_universal_streamk_f8_f8_bf16_mk_nk_mn.hpp" + +namespace ck { +namespace tensor_operation { +namespace device { +namespace instance { + +void add_device_gemm_xdl_universal_streamk_f8_f8_bf16_mk_nk_mn_comp_default_instances( + std::vector>>& + instances) +{ + add_device_operation_instances( + instances, + device_gemm_xdl_universal_streamk_f8_f8_bf16_mk_nk_mn_comp_instances{}); +} + +} // namespace instance +} // namespace device +} // namespace tensor_operation +} // namespace ck diff --git a/library/src/tensor_operation_instance/gpu/gemm_universal_streamk/device_gemm_xdl_universal_streamk_f8_f8_bf16/device_gemm_xdl_universal_streamk_f8_f8_bf16_mk_nk_mn_comp_kpadding_instance.cpp b/library/src/tensor_operation_instance/gpu/gemm_universal_streamk/device_gemm_xdl_universal_streamk_f8_f8_bf16/device_gemm_xdl_universal_streamk_f8_f8_bf16_mk_nk_mn_comp_kpadding_instance.cpp new file mode 100644 index 0000000000..22ffb264b7 --- /dev/null +++ b/library/src/tensor_operation_instance/gpu/gemm_universal_streamk/device_gemm_xdl_universal_streamk_f8_f8_bf16/device_gemm_xdl_universal_streamk_f8_f8_bf16_mk_nk_mn_comp_kpadding_instance.cpp @@ -0,0 +1,24 @@ +// SPDX-License-Identifier: MIT +// Copyright (c) 2024, Advanced Micro Devices, Inc. All rights reserved. 
+ +#include "device_gemm_xdl_universal_streamk_f8_f8_bf16_mk_nk_mn.hpp" + +namespace ck { +namespace tensor_operation { +namespace device { +namespace instance { + +void add_device_gemm_xdl_universal_streamk_f8_f8_bf16_mk_nk_mn_comp_kpadding_instances( + std::vector>>& + instances) +{ + add_device_operation_instances( + instances, + device_gemm_xdl_universal_streamk_f8_f8_bf16_mk_nk_mn_comp_instances{}); +} + +} // namespace instance +} // namespace device +} // namespace tensor_operation +} // namespace ck diff --git a/library/src/tensor_operation_instance/gpu/gemm_universal_streamk/device_gemm_xdl_universal_streamk_f8_f8_bf16/device_gemm_xdl_universal_streamk_f8_f8_bf16_mk_nk_mn_mem_v1_default_instance.cpp b/library/src/tensor_operation_instance/gpu/gemm_universal_streamk/device_gemm_xdl_universal_streamk_f8_f8_bf16/device_gemm_xdl_universal_streamk_f8_f8_bf16_mk_nk_mn_mem_v1_default_instance.cpp new file mode 100644 index 0000000000..d5e84297d9 --- /dev/null +++ b/library/src/tensor_operation_instance/gpu/gemm_universal_streamk/device_gemm_xdl_universal_streamk_f8_f8_bf16/device_gemm_xdl_universal_streamk_f8_f8_bf16_mk_nk_mn_mem_v1_default_instance.cpp @@ -0,0 +1,25 @@ +// SPDX-License-Identifier: MIT +// Copyright (c) 2024, Advanced Micro Devices, Inc. All rights reserved. 
+ +#include "device_gemm_xdl_universal_streamk_f8_f8_bf16_mk_nk_mn.hpp" + +namespace ck { +namespace tensor_operation { +namespace device { +namespace instance { + +void add_device_gemm_xdl_universal_streamk_f8_f8_bf16_mk_nk_mn_mem_v1_default_instances( + std::vector>>& + instances) +{ + add_device_operation_instances( + instances, + device_gemm_xdl_universal_streamk_f8_f8_bf16_mk_nk_mn_mem_instances{}); +} + +} // namespace instance +} // namespace device +} // namespace tensor_operation +} // namespace ck diff --git a/library/src/tensor_operation_instance/gpu/gemm_universal_streamk/device_gemm_xdl_universal_streamk_f8_f8_bf16/device_gemm_xdl_universal_streamk_f8_f8_bf16_mk_nk_mn_mem_v1_kpadding_instance.cpp b/library/src/tensor_operation_instance/gpu/gemm_universal_streamk/device_gemm_xdl_universal_streamk_f8_f8_bf16/device_gemm_xdl_universal_streamk_f8_f8_bf16_mk_nk_mn_mem_v1_kpadding_instance.cpp new file mode 100644 index 0000000000..314aec027a --- /dev/null +++ b/library/src/tensor_operation_instance/gpu/gemm_universal_streamk/device_gemm_xdl_universal_streamk_f8_f8_bf16/device_gemm_xdl_universal_streamk_f8_f8_bf16_mk_nk_mn_mem_v1_kpadding_instance.cpp @@ -0,0 +1,25 @@ +// SPDX-License-Identifier: MIT +// Copyright (c) 2024, Advanced Micro Devices, Inc. All rights reserved. 
+ +#include "device_gemm_xdl_universal_streamk_f8_f8_bf16_mk_nk_mn.hpp" + +namespace ck { +namespace tensor_operation { +namespace device { +namespace instance { + +void add_device_gemm_xdl_universal_streamk_f8_f8_bf16_mk_nk_mn_mem_v1_kpadding_instances( + std::vector>>& + instances) +{ + add_device_operation_instances( + instances, + device_gemm_xdl_universal_streamk_f8_f8_bf16_mk_nk_mn_mem_instances{}); +} + +} // namespace instance +} // namespace device +} // namespace tensor_operation +} // namespace ck diff --git a/library/src/tensor_operation_instance/gpu/gemm_universal_streamk/device_gemm_xdl_universal_streamk_f8_f8_bf16/device_gemm_xdl_universal_streamk_f8_f8_bf16_mk_nk_mn_mem_v2_default_instance.cpp b/library/src/tensor_operation_instance/gpu/gemm_universal_streamk/device_gemm_xdl_universal_streamk_f8_f8_bf16/device_gemm_xdl_universal_streamk_f8_f8_bf16_mk_nk_mn_mem_v2_default_instance.cpp new file mode 100644 index 0000000000..eb0c871a04 --- /dev/null +++ b/library/src/tensor_operation_instance/gpu/gemm_universal_streamk/device_gemm_xdl_universal_streamk_f8_f8_bf16/device_gemm_xdl_universal_streamk_f8_f8_bf16_mk_nk_mn_mem_v2_default_instance.cpp @@ -0,0 +1,25 @@ +// SPDX-License-Identifier: MIT +// Copyright (c) 2024, Advanced Micro Devices, Inc. All rights reserved. 
+ +#include "device_gemm_xdl_universal_streamk_f8_f8_bf16_mk_nk_mn.hpp" + +namespace ck { +namespace tensor_operation { +namespace device { +namespace instance { + +void add_device_gemm_xdl_universal_streamk_f8_f8_bf16_mk_nk_mn_mem_v2_default_instances( + std::vector>>& + instances) +{ + add_device_operation_instances( + instances, + device_gemm_xdl_universal_streamk_f8_f8_bf16_mk_nk_mn_mem_instances{}); +} + +} // namespace instance +} // namespace device +} // namespace tensor_operation +} // namespace ck diff --git a/library/src/tensor_operation_instance/gpu/gemm_universal_streamk/device_gemm_xdl_universal_streamk_f8_f8_bf16/device_gemm_xdl_universal_streamk_f8_f8_bf16_mk_nk_mn_mem_v2_kpadding_instance.cpp b/library/src/tensor_operation_instance/gpu/gemm_universal_streamk/device_gemm_xdl_universal_streamk_f8_f8_bf16/device_gemm_xdl_universal_streamk_f8_f8_bf16_mk_nk_mn_mem_v2_kpadding_instance.cpp new file mode 100644 index 0000000000..df92ed71c4 --- /dev/null +++ b/library/src/tensor_operation_instance/gpu/gemm_universal_streamk/device_gemm_xdl_universal_streamk_f8_f8_bf16/device_gemm_xdl_universal_streamk_f8_f8_bf16_mk_nk_mn_mem_v2_kpadding_instance.cpp @@ -0,0 +1,25 @@ +// SPDX-License-Identifier: MIT +// Copyright (c) 2024, Advanced Micro Devices, Inc. All rights reserved. 
+ +#include "device_gemm_xdl_universal_streamk_f8_f8_bf16_mk_nk_mn.hpp" + +namespace ck { +namespace tensor_operation { +namespace device { +namespace instance { + +void add_device_gemm_xdl_universal_streamk_f8_f8_bf16_mk_nk_mn_mem_v2_kpadding_instances( + std::vector>>& + instances) +{ + add_device_operation_instances( + instances, + device_gemm_xdl_universal_streamk_f8_f8_bf16_mk_nk_mn_mem_instances{}); +} + +} // namespace instance +} // namespace device +} // namespace tensor_operation +} // namespace ck diff --git a/profiler/include/profiler/profile_gemm_universal_streamk_impl.hpp b/profiler/include/profiler/profile_gemm_universal_streamk_impl.hpp index 72194e8e61..d145ab1766 100644 --- a/profiler/include/profiler/profile_gemm_universal_streamk_impl.hpp +++ b/profiler/include/profiler/profile_gemm_universal_streamk_impl.hpp @@ -11,6 +11,7 @@ #include "ck/tensor_operation/gpu/device/tensor_layout.hpp" #include "ck/tensor_operation/gpu/device/impl/device_gemm_xdl_cshuffle_streamk_v3.hpp" #include "ck/tensor_operation/gpu/element/element_wise_operation.hpp" +#include "ck/tensor_operation/gpu/device/impl/device_gemm_wmma.hpp" #include "ck/library/tensor_operation_instance/gpu/gemm_universal_streamk.hpp" @@ -20,12 +21,14 @@ #include "ck/library/utility/host_tensor_generator.hpp" #include "ck/library/utility/literals.hpp" #include "ck/library/reference_tensor_operation/cpu/reference_gemm.hpp" +#include "ck/library/reference_tensor_operation/gpu/reference_gemm.hpp" namespace ck { namespace profiler { template a_m_k(f_host_tensor_descriptor(M, K, StrideA, ALayout{})); Tensor b_k_n(f_host_tensor_descriptor(K, N, StrideB, BLayout{})); Tensor c_m_n_host_result(f_host_tensor_descriptor(M, N, StrideC, CLayout{})); + Tensor c_m_n_device_result(f_host_tensor_descriptor(M, N, StrideC, CLayout{})); + Tensor c_m_n_device_ref_result(f_host_tensor_descriptor(M, N, StrideC, CLayout{})); int total_gemm_needed = a_m_k.GetElementSpaceSizeInBytes() + 
b_k_n.GetElementSpaceSizeInBytes(); int rotating_count = std::max( @@ -103,6 +108,9 @@ bool profile_gemm_universal_streamk_impl(int do_verification, DeviceMem b_device_buf(sizeof(BDataType) * b_k_n.mDesc.GetElementSpaceSize()); DeviceMem c_device_buf(sizeof(CDataType) * c_m_n_device_result.mDesc.GetElementSpaceSize()); + DeviceMem c_m_n_device_ref_buf(sizeof(CDataType) * + c_m_n_device_ref_result.mDesc.GetElementSpaceSize()); + a_device_buf.ToDevice(a_m_k.mData.data()); b_device_buf.ToDevice(b_k_n.mData.data()); @@ -125,21 +133,22 @@ bool profile_gemm_universal_streamk_impl(int do_verification, // Run reference GEMM if(do_verification) { - using ReferenceGemmInstance = ck::tensor_operation::host::ReferenceGemm; - auto ref_gemm = ReferenceGemmInstance{}; - auto ref_invoker = ref_gemm.MakeInvoker(); - - auto ref_argument = ref_gemm.MakeArgument( + // Use CPU validation + // Note: GPU validation is not supported for fp8 !!! + using ReferenceGemmInstanceCPU = ck::tensor_operation::host::ReferenceGemm; + auto ref_gemm_cpu = ReferenceGemmInstanceCPU{}; + auto ref_invoker_cpu = ref_gemm_cpu.MakeInvoker(); + auto ref_argument_cpu = ref_gemm_cpu.MakeArgument( a_m_k, b_k_n, c_m_n_host_result, a_element_op, b_element_op, c_element_op); - - ref_invoker.Run(ref_argument); + ref_invoker_cpu.Run(ref_argument_cpu); } std::string best_op_name; @@ -157,7 +166,7 @@ bool profile_gemm_universal_streamk_impl(int do_verification, 0, 1, 2, 3, 4}; // 0: Data Parallel (DP) mode (Stream-K OFF), 1: 1-tile Stream-K+ DP, // 2:2-tile Stream-K + DP - if(Grid_size == -1) + if(Grid_size != -1) { grid_size_list = {Grid_size}; } @@ -203,6 +212,7 @@ bool profile_gemm_universal_streamk_impl(int do_verification, { c_device_buf.FromDevice(c_m_n_device_result.mData.data()); + // Always compare against CPU reference results computed earlier pass = pass & ck::utils::check_err(c_m_n_device_result, c_m_n_host_result); if(do_log) diff --git a/profiler/src/profile_gemm_universal_streamk.cpp 
b/profiler/src/profile_gemm_universal_streamk.cpp old mode 100644 new mode 100755 index b0f66a0c73..4d1ab811ee --- a/profiler/src/profile_gemm_universal_streamk.cpp +++ b/profiler/src/profile_gemm_universal_streamk.cpp @@ -26,6 +26,7 @@ enum struct GemmDataType F8_F16_F16, // 4 F16_F8_F16, // 5 F16_F16_F16_F8, // 6 + F8_F8_BF16, // 7 }; #define OP_NAME "gemm_universal_streamk" @@ -37,7 +38,7 @@ int profile_gemm_universal_streamk(int argc, char* argv[]) { printf("arg1: tensor operation (" OP_NAME ": " OP_DESC ")\n"); printf("arg2: data type (0: fp32; 1: fp16; 2: bf16; 3: int8; 4: f8@f16; 5: f16@f8; 6: f16, " - "comp f8)\n"); + "comp f8; 7: f8->bf16,)\n"); printf("arg3: matrix layout (0: A[m, k] * B[k, n] = C[m, n];\n"); printf(" 1: A[m, k] * B[n, k] = C[m, n];\n"); printf(" 2: A[k, m] * B[k, n] = C[m, n];\n"); @@ -112,15 +113,17 @@ int profile_gemm_universal_streamk(int argc, char* argv[]) auto profile = [&](auto a_type, auto b_type, + auto comp_type, auto acc_type, auto c_type, auto a_layout, auto b_layout, auto c_layout) { - using ADataType = decltype(a_type); - using BDataType = decltype(b_type); - using AccDataType = decltype(acc_type); - using CDataType = decltype(c_type); + using ADataType = decltype(a_type); + using BDataType = decltype(b_type); + using ComputeDataType = decltype(comp_type); + using AccDataType = decltype(acc_type); + using CDataType = decltype(c_type); using ALayout = decltype(a_layout); using BLayout = decltype(b_layout); @@ -132,6 +135,7 @@ int profile_gemm_universal_streamk(int argc, char* argv[]) bool pass = ck::profiler::profile_gemm_universal_streamk_impl Ms{1, 2, 3, 4, 5, 6}; + constexpr int N = 512; + constexpr int K = 320; + + constexpr int StrideA = K; + constexpr int StrideB = N; + constexpr int StrideC = N; + + for(int M : Ms) + this->Run(M, N, K, StrideA, StrideB, StrideC); +} + +TYPED_TEST(TestGemmUniversal_Streamk_BF16_MK_NK, SmallM) +{ + std::vector Ms{1, 2, 3, 4, 5, 6}; + constexpr int N = 512; + constexpr int K = 320; + + 
constexpr int StrideA = K; + constexpr int StrideB = K; + constexpr int StrideC = N; + + for(int M : Ms) + this->Run(M, N, K, StrideA, StrideB, StrideC); +} + +TYPED_TEST(TestGemmUniversal_Streamk_BF16_KM_KN, SmallM) +{ + std::vector Ms{1, 2, 3, 4, 5, 6}; + constexpr int N = 512; + constexpr int K = 320; + + constexpr int StrideB = N; + constexpr int StrideC = N; + + for(int M : Ms) + { + int StrideA = M; + this->Run(M, N, K, StrideA, StrideB, StrideC); + } +} + +TYPED_TEST(TestGemmUniversal_Streamk_BF16_MK_KN, MidLargeM) +{ + std::vector Ms{127, 255, 312, 799, 1573}; + constexpr int N = 512; + constexpr int K = 320; + + constexpr int StrideA = K; + constexpr int StrideB = N; + constexpr int StrideC = N; + + for(int M : Ms) + this->Run(M, N, K, StrideA, StrideB, StrideC); +} + +TYPED_TEST(TestGemmUniversal_Streamk_BF16_MK_NK, MidLargeM) +{ + std::vector Ms{127, 255, 312, 799, 1573}; + constexpr int N = 512; + constexpr int K = 320; + + constexpr int StrideA = K; + constexpr int StrideB = K; + constexpr int StrideC = N; + + for(int M : Ms) + this->Run(M, N, K, StrideA, StrideB, StrideC); +} + +TYPED_TEST(TestGemmUniversal_Streamk_BF16_KM_KN, MidLargeM) +{ + std::vector Ms{127, 255, 312, 799, 1573}; + constexpr int N = 512; + constexpr int K = 320; + + constexpr int StrideB = N; + constexpr int StrideC = N; + + for(int M : Ms) + { + int StrideA = M; + this->Run(M, N, K, StrideA, StrideB, StrideC); + } +} + +TYPED_TEST(TestGemmUniversal_Streamk_BF16_MK_KN, PaddK) +{ + std::vector Ms{127}; + constexpr int N = 512; + constexpr int K = 437; + + constexpr int StrideA = K; + constexpr int StrideB = N; + constexpr int StrideC = N; + + for(int M : Ms) + this->Run(M, N, K, StrideA, StrideB, StrideC); +} + +TYPED_TEST(TestGemmUniversal_Streamk_BF16_MK_NK, PaddK) +{ + std::vector Ms{127}; + constexpr int N = 512; + constexpr int K = 437; + + constexpr int StrideA = K; + constexpr int StrideB = K; + constexpr int StrideC = N; + + for(int M : Ms) + this->Run(M, N, K, StrideA, 
StrideB, StrideC); +} + +TYPED_TEST(TestGemmUniversal_Streamk_BF16_KM_KN, PaddK) +{ + std::vector Ms{127}; + constexpr int N = 512; + constexpr int K = 437; + + constexpr int StrideB = N; + constexpr int StrideC = N; + + for(int M : Ms) + { + int StrideA = M; + this->Run(M, N, K, StrideA, StrideB, StrideC); + } +} + +TYPED_TEST(TestGemmUniversal_Streamk_BF16_KM_NK, PaddK) +{ + std::vector Ms{127}; + constexpr int N = 512; + constexpr int K = 437; + + constexpr int StrideB = N; + constexpr int StrideC = N; + + for(int M : Ms) + { + int StrideA = M; + this->Run(M, N, K, StrideA, StrideB, StrideC); + } +} + +TYPED_TEST(TestGemmUniversal_Streamk_BF16_MK_KN, Regular) +{ + std::vector Ms{512}; + constexpr int N = 512; + constexpr int K = 512; + + constexpr int StrideA = K; + constexpr int StrideB = N; + constexpr int StrideC = N; + + for(int M : Ms) + this->Run(M, N, K, StrideA, StrideB, StrideC); +} + +TYPED_TEST(TestGemmUniversal_Streamk_BF16_MK_NK, Regular) +{ + std::vector Ms{512}; + constexpr int N = 512; + constexpr int K = 512; + + constexpr int StrideA = K; + constexpr int StrideB = K; + constexpr int StrideC = N; + + for(int M : Ms) + this->Run(M, N, K, StrideA, StrideB, StrideC); +} diff --git a/test/gemm_universal_streamk/test_gemm_universal_streamk_ut_cases_fp16.inc b/test/gemm_universal_streamk/test_gemm_universal_streamk_ut_cases_fp16.inc new file mode 100644 index 0000000000..b2fdfe8193 --- /dev/null +++ b/test/gemm_universal_streamk/test_gemm_universal_streamk_ut_cases_fp16.inc @@ -0,0 +1,113 @@ +#pragma once + +TYPED_TEST(TestGemmUniversal_Streamk_FP16_MK_KN, SmallM) +{ + std::vector Ms{1, 2, 3, 4, 5, 6}; + constexpr int N = 512; + constexpr int K = 320; + + constexpr int StrideA = K; + constexpr int StrideB = N; + constexpr int StrideC = N; + + for(int M : Ms) + this->Run(M, N, K, StrideA, StrideB, StrideC); +} + +TYPED_TEST(TestGemmUniversal_Streamk_FP16_MK_NK, SmallM) +{ + std::vector Ms{1, 2, 3, 4, 5, 6}; + constexpr int N = 512; + constexpr int K = 
320; + + constexpr int StrideA = K; + constexpr int StrideB = K; + constexpr int StrideC = N; + + for(int M : Ms) + this->Run(M, N, K, StrideA, StrideB, StrideC); +} + +TYPED_TEST(TestGemmUniversal_Streamk_FP16_MK_KN, MidLargeM) +{ + std::vector Ms{127, 255, 312, 799, 1573}; + constexpr int N = 512; + constexpr int K = 320; + + constexpr int StrideA = K; + constexpr int StrideB = N; + constexpr int StrideC = N; + + for(int M : Ms) + this->Run(M, N, K, StrideA, StrideB, StrideC); +} + +TYPED_TEST(TestGemmUniversal_Streamk_FP16_MK_NK, MidLargeM) +{ + std::vector Ms{127, 255, 312, 799, 1573}; + constexpr int N = 512; + constexpr int K = 320; + + constexpr int StrideA = K; + constexpr int StrideB = K; + constexpr int StrideC = N; + + for(int M : Ms) + this->Run(M, N, K, StrideA, StrideB, StrideC); +} + +TYPED_TEST(TestGemmUniversal_Streamk_FP16_MK_KN, PaddK) +{ + std::vector Ms{127}; + constexpr int N = 512; + constexpr int K = 437; + + constexpr int StrideA = K; + constexpr int StrideB = N; + constexpr int StrideC = N; + + for(int M : Ms) + this->Run(M, N, K, StrideA, StrideB, StrideC); +} + +TYPED_TEST(TestGemmUniversal_Streamk_FP16_MK_NK, PaddK) +{ + std::vector Ms{127}; + constexpr int N = 512; + constexpr int K = 437; + + constexpr int StrideA = K; + constexpr int StrideB = K; + constexpr int StrideC = N; + + for(int M : Ms) + this->Run(M, N, K, StrideA, StrideB, StrideC); +} + +TYPED_TEST(TestGemmUniversal_Streamk_FP16_MK_KN, Regular) +{ + std::vector Ms{512}; + constexpr int N = 512; + constexpr int K = 512; + + constexpr int StrideA = K; + constexpr int StrideB = N; + constexpr int StrideC = N; + + for(int M : Ms) + this->Run(M, N, K, StrideA, StrideB, StrideC); +} + +TYPED_TEST(TestGemmUniversal_Streamk_FP16_MK_NK, Regular) +{ + std::vector Ms{512}; + constexpr int N = 512; + constexpr int K = 512; + + constexpr int StrideA = K; + constexpr int StrideB = K; + constexpr int StrideC = N; + + for(int M : Ms) + this->Run(M, N, K, StrideA, StrideB, StrideC); +} 
diff --git a/test/gemm_universal_streamk/test_gemm_universal_streamk_ut_cases_fp8.inc b/test/gemm_universal_streamk/test_gemm_universal_streamk_ut_cases_fp8.inc new file mode 100755 index 0000000000..b3da08f703 --- /dev/null +++ b/test/gemm_universal_streamk/test_gemm_universal_streamk_ut_cases_fp8.inc @@ -0,0 +1,113 @@ +#pragma once + +TYPED_TEST(TestGemmUniversal_Streamk_FP8_MK_KN, SmallM) +{ + std::vector Ms{1, 2, 3, 4, 5, 6}; + constexpr int N = 512; + constexpr int K = 320; + + constexpr int StrideA = K; + constexpr int StrideB = N; + constexpr int StrideC = N; + + for(int M : Ms) + this->Run(M, N, K, StrideA, StrideB, StrideC); +} + +TYPED_TEST(TestGemmUniversal_Streamk_FP8_MK_NK, SmallM) +{ + std::vector Ms{1, 2, 3, 4, 5, 6}; + constexpr int N = 512; + constexpr int K = 320; + + constexpr int StrideA = K; + constexpr int StrideB = K; + constexpr int StrideC = N; + + for(int M : Ms) + this->Run(M, N, K, StrideA, StrideB, StrideC); +} + +TYPED_TEST(TestGemmUniversal_Streamk_FP8_MK_KN, MidLargeM) +{ + std::vector Ms{127, 255, 312, 799, 1573}; + constexpr int N = 512; + constexpr int K = 320; + + constexpr int StrideA = K; + constexpr int StrideB = N; + constexpr int StrideC = N; + + for(int M : Ms) + this->Run(M, N, K, StrideA, StrideB, StrideC); +} + +TYPED_TEST(TestGemmUniversal_Streamk_FP8_MK_NK, MidLargeM) +{ + std::vector Ms{127, 255, 312, 799, 1573}; + constexpr int N = 512; + constexpr int K = 320; + + constexpr int StrideA = K; + constexpr int StrideB = K; + constexpr int StrideC = N; + + for(int M : Ms) + this->Run(M, N, K, StrideA, StrideB, StrideC); +} + +TYPED_TEST(TestGemmUniversal_Streamk_FP8_MK_KN, PaddK) +{ + std::vector Ms{127}; + constexpr int N = 512; + constexpr int K = 437; + + constexpr int StrideA = K; + constexpr int StrideB = N; + constexpr int StrideC = N; + + for(int M : Ms) + this->Run(M, N, K, StrideA, StrideB, StrideC); +} + +TYPED_TEST(TestGemmUniversal_Streamk_FP8_MK_NK, PaddK) +{ + std::vector Ms{127}; + constexpr int N = 512; + 
constexpr int K = 437; + + constexpr int StrideA = K; + constexpr int StrideB = K; + constexpr int StrideC = N; + + for(int M : Ms) + this->Run(M, N, K, StrideA, StrideB, StrideC); +} + +TYPED_TEST(TestGemmUniversal_Streamk_FP8_MK_KN, Regular) +{ + std::vector Ms{512}; + constexpr int N = 512; + constexpr int K = 512; + + constexpr int StrideA = K; + constexpr int StrideB = N; + constexpr int StrideC = N; + + for(int M : Ms) + this->Run(M, N, K, StrideA, StrideB, StrideC); +} + +TYPED_TEST(TestGemmUniversal_Streamk_FP8_MK_NK, Regular) +{ + std::vector Ms{512}; + constexpr int N = 512; + constexpr int K = 512; + + constexpr int StrideA = K; + constexpr int StrideB = K; + constexpr int StrideC = N; + + for(int M : Ms) + this->Run(M, N, K, StrideA, StrideB, StrideC); +} diff --git a/test/gemm_universal_streamk/test_gemm_universal_streamk_util.hpp b/test/gemm_universal_streamk/test_gemm_universal_streamk_util.hpp new file mode 100644 index 0000000000..ef3509c0ca --- /dev/null +++ b/test/gemm_universal_streamk/test_gemm_universal_streamk_util.hpp @@ -0,0 +1,104 @@ +// SPDX-License-Identifier: MIT +// Copyright (c) 2023-2024, Advanced Micro Devices, Inc. All rights reserved. 
+ +#pragma once + +#include +#include +#include +#include +#include + +#include "ck/ck.hpp" +#include "ck/tensor_operation/gpu/element/element_wise_operation.hpp" +#include "ck/tensor_operation/gpu/device/tensor_layout.hpp" +#include "include/ck/utility/data_type.hpp" +#include "profiler/profile_gemm_universal_streamk_impl.hpp" + +namespace ck { +namespace test { + +template +class TestGemmUniversal_Streamk : public testing::Test +{ + using Row = ck::tensor_layout::gemm::RowMajor; + using F32 = float; + + protected: + using ALayout = std::tuple_element_t<0, Tuple>; + using BLayout = std::tuple_element_t<1, Tuple>; + using CLayout = Row; + using ADataType = std::tuple_element_t<2, Tuple>; + using BDataType = std::tuple_element_t<3, Tuple>; + using ComputeDataType = std::tuple_element_t<4, Tuple>; + using CDataType = std::tuple_element_t<5, Tuple>; + + public: + static constexpr bool verify_ = true; + static constexpr int init_method_ = 1; // decimal value initialization + static constexpr bool log_ = false; + static constexpr bool bench_ = false; // measure kernel performance + + std::vector grid_size_list; + std::vector streamk_sel_list; + + void SetUp() override + { + grid_size_list = {38, 114, 228}; // {38, 76, 114, 152, 190, 228, 266, 304, 342, 380}; + streamk_sel_list = {0, 1, 2}; // 0: Data Parallel (DP) mode (Stream-K OFF), 1: 1-tile + // Stream-K+ DP, // {0, 1, 2, 3, 4} + // 2:2-tile Stream-K + DP + } + + void Run(const int M, + const int N, + const int K, + const int StrideA, + const int StrideB, + const int StrideC) + { + for(auto streamk_sel : streamk_sel_list) + for(auto grid_size : grid_size_list) + { + RunSingle(M, N, K, StrideA, StrideB, StrideC, streamk_sel, grid_size); + } + } + + void RunSingle(const int M, + const int N, + const int K, + const int StrideA, + const int StrideB, + const int StrideC, + int streamk_sel, + int Grid_size, + int n_warmup = 1, + int n_iter = 10) + { + bool pass = ck::profiler::profile_gemm_universal_streamk_impl(verify_, 
+ init_method_, + log_, + bench_, + M, + N, + K, + StrideA, + StrideB, + StrideC, + streamk_sel, + Grid_size, + n_warmup, + n_iter); + EXPECT_TRUE(pass); + } +}; + +} // namespace test +} // namespace ck diff --git a/test/gemm_universal_streamk/test_gemm_universal_streamk_xdl_bf16.cpp b/test/gemm_universal_streamk/test_gemm_universal_streamk_xdl_bf16.cpp new file mode 100755 index 0000000000..1aef74cf18 --- /dev/null +++ b/test/gemm_universal_streamk/test_gemm_universal_streamk_xdl_bf16.cpp @@ -0,0 +1,85 @@ +// SPDX-License-Identifier: MIT +// Copyright (c) 2023-2024, Advanced Micro Devices, Inc. All rights reserved. + +#include + +#include "gtest/gtest.h" +#include "ck/tensor_operation/gpu/device/tensor_layout.hpp" +#include "test_gemm_universal_streamk_util.hpp" + +using BF16 = ck::bhalf_t; +using F32 = float; + +using Row = ck::tensor_layout::gemm::RowMajor; +using Col = ck::tensor_layout::gemm::ColumnMajor; + +namespace { + +template +struct tuple_concat; + +template +struct tuple_concat, std::tuple> +{ + using type = std::tuple; +}; + +} // namespace + +template +class TestGemmUniversal_Streamk_BF16_MK_KN + : public ck::test::TestGemmUniversal_Streamk< + typename tuple_concat, Tuple>::type> +{ +}; + +template +class TestGemmUniversal_Streamk_BF16_MK_NK + : public ck::test::TestGemmUniversal_Streamk< + typename tuple_concat, Tuple>::type> +{ +}; + +template +class TestGemmUniversal_Streamk_BF16_KM_KN + : public ck::test::TestGemmUniversal_Streamk< + typename tuple_concat, Tuple>::type> +{ +}; + +template +class TestGemmUniversal_Streamk_BF16_KM_NK + : public ck::test::TestGemmUniversal_Streamk< + typename tuple_concat, Tuple>::type> +{ +}; + +// clang-format off +using KernelTypes_MK_KN = ::testing::Types< + // ADataType, BDataType, ComputeDataType, CDataType + std::tuple< BF16, BF16, BF16, BF16> + >; +using KernelTypes_MK_NK = ::testing::Types< + // ADataType, BDataType, ComputeDataType, CDataType + + std::tuple< BF16, BF16, BF16, BF16> + >; + +using 
KernelTypes_KM_NK = ::testing::Types< + // ADataType, BDataType, ComputeDataType, CDataType + std::tuple< BF16, BF16, BF16, BF16> + >; + +using KernelTypes_KM_KN = ::testing::Types< + // ADataType, BDataType, ComputeDataType, CDataType + std::tuple< BF16, BF16, BF16, BF16> + >; + +// clang-format on + +TYPED_TEST_SUITE(TestGemmUniversal_Streamk_BF16_MK_KN, KernelTypes_MK_KN); +TYPED_TEST_SUITE(TestGemmUniversal_Streamk_BF16_MK_NK, KernelTypes_MK_NK); +TYPED_TEST_SUITE(TestGemmUniversal_Streamk_BF16_KM_KN, KernelTypes_KM_KN); +TYPED_TEST_SUITE(TestGemmUniversal_Streamk_BF16_KM_NK, KernelTypes_KM_NK); + +#include "test_gemm_universal_streamk_ut_cases_bf16.inc" diff --git a/test/gemm_universal_streamk/test_gemm_universal_streamk_xdl_fp16.cpp b/test/gemm_universal_streamk/test_gemm_universal_streamk_xdl_fp16.cpp new file mode 100644 index 0000000000..43b122ff0d --- /dev/null +++ b/test/gemm_universal_streamk/test_gemm_universal_streamk_xdl_fp16.cpp @@ -0,0 +1,84 @@ +// SPDX-License-Identifier: MIT +// Copyright (c) 2023-2024, Advanced Micro Devices, Inc. All rights reserved. 
+ +#include + +#include "gtest/gtest.h" +#include "ck/tensor_operation/gpu/device/tensor_layout.hpp" +#include "test_gemm_universal_streamk_util.hpp" + +using F8 = ck::f8_t; +using F16 = ck::half_t; + +using F32 = float; + +using Row = ck::tensor_layout::gemm::RowMajor; +using Col = ck::tensor_layout::gemm::ColumnMajor; + +namespace { + +template +struct tuple_concat; + +template +struct tuple_concat, std::tuple> +{ + using type = std::tuple; +}; + +} // namespace + +template +class TestGemmUniversal_Streamk_FP16_MK_KN + : public ck::test::TestGemmUniversal_Streamk< + typename tuple_concat, Tuple>::type> +{ +}; + +template +class TestGemmUniversal_Streamk_FP16_MK_NK + : public ck::test::TestGemmUniversal_Streamk< + typename tuple_concat, Tuple>::type> +{ +}; + +template +class TestGemmUniversal_Streamk_FP16_KM_KN + : public ck::test::TestGemmUniversal_Streamk< + typename tuple_concat, Tuple>::type> +{ +}; + +template +class TestGemmUniversal_Streamk_FP16_KM_NK + : public ck::test::TestGemmUniversal_Streamk< + typename tuple_concat, Tuple>::type> +{ +}; + +// clang-format off +using KernelTypes_MK_KN = ::testing::Types< + // ADataType, BDataType, ComputeDataType, CDataType +#if defined(CK_ENABLE_FP8) && (defined(CK_USE_FP8_ON_UNSUPPORTED_ARCH) || defined(CK_USE_GFX94)) + std::tuple< F16, F8, F16, F16>, + std::tuple< F8, F16, F16, F16>, +#endif + + std::tuple< F16, F16, F16, F16> + >; +using KernelTypes_MK_NK = ::testing::Types< + // ADataType, BDataType, ComputeDataType, CDataType + +#if defined(CK_ENABLE_FP8) && (defined(CK_USE_FP8_ON_UNSUPPORTED_ARCH) || defined(CK_USE_GFX94)) + std::tuple< F16, F8, F16, F16>, + std::tuple< F8, F16, F16, F16>, +#endif + std::tuple< F16, F16, F16, F16> + >; + +// clang-format on + +TYPED_TEST_SUITE(TestGemmUniversal_Streamk_FP16_MK_KN, KernelTypes_MK_KN); +TYPED_TEST_SUITE(TestGemmUniversal_Streamk_FP16_MK_NK, KernelTypes_MK_NK); + +#include "test_gemm_universal_streamk_ut_cases_fp16.inc" diff --git 
a/test/gemm_universal_streamk/test_gemm_universal_streamk_xdl_fp8.cpp b/test/gemm_universal_streamk/test_gemm_universal_streamk_xdl_fp8.cpp new file mode 100755 index 0000000000..3836de056c --- /dev/null +++ b/test/gemm_universal_streamk/test_gemm_universal_streamk_xdl_fp8.cpp @@ -0,0 +1,74 @@ +// SPDX-License-Identifier: MIT +// Copyright (c) 2023-2024, Advanced Micro Devices, Inc. All rights reserved. + +#include + +#include "gtest/gtest.h" +#include "ck/tensor_operation/gpu/device/tensor_layout.hpp" +#include "test_gemm_universal_streamk_util.hpp" + +using F8 = ck::f8_t; +using F16 = ck::half_t; +using BF16 = ck::bhalf_t; +using F32 = float; + +using Row = ck::tensor_layout::gemm::RowMajor; +using Col = ck::tensor_layout::gemm::ColumnMajor; + +namespace { + +template +struct tuple_concat; + +template +struct tuple_concat, std::tuple> +{ + using type = std::tuple; +}; + +} // namespace + +template +class TestGemmUniversal_Streamk_FP8_MK_KN + : public ck::test::TestGemmUniversal_Streamk< + typename tuple_concat, Tuple>::type> +{ +}; + +template +class TestGemmUniversal_Streamk_FP8_MK_NK + : public ck::test::TestGemmUniversal_Streamk< + typename tuple_concat, Tuple>::type> +{ +}; + +// clang-format off +using KernelTypes_MK_KN = ::testing::Types< + // ADataType, BDataType, ComputeDataType, CDataType + +#if defined(CK_ENABLE_FP8) && (defined(CK_USE_FP8_ON_UNSUPPORTED_ARCH) || defined(CK_USE_GFX94)) + std::tuple< F16, F8, F16, F16>, + std::tuple< F8, F16, F16, F16>, + std::tuple< F8, F8, F8, BF16>, +#endif + // Fallback test type when FP8 is not enabled + std::tuple< F16, F16, F16, F16> + >; +using KernelTypes_MK_NK = ::testing::Types< + // ADataType, BDataType, ComputeDataType, CDataType + +#if defined(CK_ENABLE_FP8) && (defined(CK_USE_FP8_ON_UNSUPPORTED_ARCH) || defined(CK_USE_GFX94)) + std::tuple< F16, F8, F16, F16>, + std::tuple< F8, F16, F16, F16>, + std::tuple< F8, F8, F8, BF16>, +#endif + // Fallback test type when FP8 is not enabled + std::tuple< F16, F16, 
F16, F16> + >; + +// clang-format on + +TYPED_TEST_SUITE(TestGemmUniversal_Streamk_FP8_MK_KN, KernelTypes_MK_KN); +TYPED_TEST_SUITE(TestGemmUniversal_Streamk_FP8_MK_NK, KernelTypes_MK_NK); + +#include "test_gemm_universal_streamk_ut_cases_fp8.inc" From 6355ee7ca5c6f3c1a22ee40f58fe6dc956b94242 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Bart=C5=82omiej=20Kocot?= Date: Tue, 1 Apr 2025 16:11:42 +0200 Subject: [PATCH 009/443] Improve compilation time for grouped conv fwd (#2039) * Improve compilation time for grouped conv fwd * Fix --- .../gpu/grouped_convolution_forward.hpp | 12 ++ .../grouped_convolution_forward_comp_xdl.inc | 112 ++++++++++++++++++ .../gpu/grouped_conv2d_fwd/CMakeLists.txt | 8 ++ ...gchw_gkcyx_ngkhw_bf16_comp_2x_instance.cpp | 43 +++++++ ...l_ngchw_gkcyx_ngkhw_bf16_comp_instance.cpp | 24 ---- ...w_gkcyx_ngkhw_bf16_comp_part2_instance.cpp | 45 +++++++ ...ngchw_gkcyx_ngkhw_f16_comp_2x_instance.cpp | 43 +++++++ ...dl_ngchw_gkcyx_ngkhw_f16_comp_instance.cpp | 24 ---- ...hw_gkcyx_ngkhw_f16_comp_part2_instance.cpp | 45 +++++++ ...hwgc_gkyxc_nhwgk_bf16_comp_2x_instance.cpp | 70 +++++++++++ ...l_nhwgc_gkyxc_nhwgk_bf16_comp_instance.cpp | 80 +------------ ...c_gkyxc_nhwgk_bf16_comp_part2_instance.cpp | 70 +++++++++++ ...nhwgc_gkyxc_nhwgk_f16_comp_2x_instance.cpp | 70 +++++++++++ ...dl_nhwgc_gkyxc_nhwgk_f16_comp_instance.cpp | 80 +------------ ...gc_gkyxc_nhwgk_f16_comp_part2_instance.cpp | 70 +++++++++++ 15 files changed, 590 insertions(+), 206 deletions(-) create mode 100644 library/src/tensor_operation_instance/gpu/grouped_conv2d_fwd/xdl/comp/device_grouped_conv2d_fwd_xdl_ngchw_gkcyx_ngkhw_bf16_comp_2x_instance.cpp create mode 100644 library/src/tensor_operation_instance/gpu/grouped_conv2d_fwd/xdl/comp/device_grouped_conv2d_fwd_xdl_ngchw_gkcyx_ngkhw_bf16_comp_part2_instance.cpp create mode 100644 library/src/tensor_operation_instance/gpu/grouped_conv2d_fwd/xdl/comp/device_grouped_conv2d_fwd_xdl_ngchw_gkcyx_ngkhw_f16_comp_2x_instance.cpp create mode 100644 
library/src/tensor_operation_instance/gpu/grouped_conv2d_fwd/xdl/comp/device_grouped_conv2d_fwd_xdl_ngchw_gkcyx_ngkhw_f16_comp_part2_instance.cpp create mode 100644 library/src/tensor_operation_instance/gpu/grouped_conv2d_fwd/xdl/comp/device_grouped_conv2d_fwd_xdl_nhwgc_gkyxc_nhwgk_bf16_comp_2x_instance.cpp create mode 100644 library/src/tensor_operation_instance/gpu/grouped_conv2d_fwd/xdl/comp/device_grouped_conv2d_fwd_xdl_nhwgc_gkyxc_nhwgk_bf16_comp_part2_instance.cpp create mode 100644 library/src/tensor_operation_instance/gpu/grouped_conv2d_fwd/xdl/comp/device_grouped_conv2d_fwd_xdl_nhwgc_gkyxc_nhwgk_f16_comp_2x_instance.cpp create mode 100644 library/src/tensor_operation_instance/gpu/grouped_conv2d_fwd/xdl/comp/device_grouped_conv2d_fwd_xdl_nhwgc_gkyxc_nhwgk_f16_comp_part2_instance.cpp diff --git a/library/include/ck/library/tensor_operation_instance/gpu/grouped_convolution_forward.hpp b/library/include/ck/library/tensor_operation_instance/gpu/grouped_convolution_forward.hpp index c2e1337737..0b7df6ecfb 100644 --- a/library/include/ck/library/tensor_operation_instance/gpu/grouped_convolution_forward.hpp +++ b/library/include/ck/library/tensor_operation_instance/gpu/grouped_convolution_forward.hpp @@ -226,6 +226,9 @@ struct DeviceOperationInstanceFactory>>& instances); + +void add_device_grouped_conv2d_fwd_xdl_nhwgc_gkyxc_nhwgk_bf16_comp_2x_instances( + std::vector>>& instances); + +void add_device_grouped_conv2d_fwd_xdl_nhwgc_gkyxc_nhwgk_bf16_comp_part2_instances( + std::vector>>& instances); #endif #ifdef CK_ENABLE_FP16 @@ -39,6 +67,34 @@ void add_device_grouped_conv2d_fwd_xdl_nhwgc_gkyxc_nhwgk_f16_comp_instances( PassThrough, PassThrough, PassThrough>>>& instances); + +void add_device_grouped_conv2d_fwd_xdl_nhwgc_gkyxc_nhwgk_f16_comp_2x_instances( + std::vector>>& instances); + +void add_device_grouped_conv2d_fwd_xdl_nhwgc_gkyxc_nhwgk_f16_comp_part2_instances( + std::vector>>& instances); #endif #ifdef CK_ENABLE_FP32 @@ -88,6 +144,34 @@ void 
add_device_grouped_conv2d_fwd_xdl_ngchw_gkcyx_ngkhw_f16_comp_instances( PassThrough, PassThrough, PassThrough>>>& instances); + +void add_device_grouped_conv2d_fwd_xdl_ngchw_gkcyx_ngkhw_f16_comp_2x_instances( + std::vector>>& instances); + +void add_device_grouped_conv2d_fwd_xdl_ngchw_gkcyx_ngkhw_f16_comp_part2_instances( + std::vector>>& instances); #endif #ifdef CK_ENABLE_BF16 @@ -104,6 +188,34 @@ void add_device_grouped_conv2d_fwd_xdl_ngchw_gkcyx_ngkhw_bf16_comp_instances( PassThrough, PassThrough, PassThrough>>>& instances); + +void add_device_grouped_conv2d_fwd_xdl_ngchw_gkcyx_ngkhw_bf16_comp_2x_instances( + std::vector>>& instances); + +void add_device_grouped_conv2d_fwd_xdl_ngchw_gkcyx_ngkhw_bf16_comp_part2_instances( + std::vector>>& instances); #endif #ifdef CK_ENABLE_FP32 diff --git a/library/src/tensor_operation_instance/gpu/grouped_conv2d_fwd/CMakeLists.txt b/library/src/tensor_operation_instance/gpu/grouped_conv2d_fwd/CMakeLists.txt index b095840a34..c1790901ec 100644 --- a/library/src/tensor_operation_instance/gpu/grouped_conv2d_fwd/CMakeLists.txt +++ b/library/src/tensor_operation_instance/gpu/grouped_conv2d_fwd/CMakeLists.txt @@ -60,10 +60,18 @@ add_instance_library(device_grouped_conv2d_fwd_instance xdl/comp/device_grouped_conv2d_fwd_xdl_nhwgc_gkyxc_nhwgk_f16_comp_instance.cpp xdl/comp/device_grouped_conv2d_fwd_xdl_nhwgc_gkyxc_nhwgk_f32_comp_instance.cpp xdl/comp/device_grouped_conv2d_fwd_xdl_nhwgc_gkyxc_nhwgk_int8_comp_instance.cpp + xdl/comp/device_grouped_conv2d_fwd_xdl_nhwgc_gkyxc_nhwgk_bf16_comp_2x_instance.cpp + xdl/comp/device_grouped_conv2d_fwd_xdl_nhwgc_gkyxc_nhwgk_f16_comp_2x_instance.cpp + xdl/comp/device_grouped_conv2d_fwd_xdl_nhwgc_gkyxc_nhwgk_bf16_comp_part2_instance.cpp + xdl/comp/device_grouped_conv2d_fwd_xdl_nhwgc_gkyxc_nhwgk_f16_comp_part2_instance.cpp # NGCHW, GKCYX, NGKHW xdl/comp/device_grouped_conv2d_fwd_xdl_ngchw_gkcyx_ngkhw_bf16_comp_instance.cpp xdl/comp/device_grouped_conv2d_fwd_xdl_ngchw_gkcyx_ngkhw_f16_comp_instance.cpp 
xdl/comp/device_grouped_conv2d_fwd_xdl_ngchw_gkcyx_ngkhw_f32_comp_instance.cpp + xdl/comp/device_grouped_conv2d_fwd_xdl_ngchw_gkcyx_ngkhw_bf16_comp_2x_instance.cpp + xdl/comp/device_grouped_conv2d_fwd_xdl_ngchw_gkcyx_ngkhw_f16_comp_2x_instance.cpp + xdl/comp/device_grouped_conv2d_fwd_xdl_ngchw_gkcyx_ngkhw_bf16_comp_part2_instance.cpp + xdl/comp/device_grouped_conv2d_fwd_xdl_ngchw_gkcyx_ngkhw_f16_comp_part2_instance.cpp #dl # GNHWC, GKYXC, GNHWK dl/device_grouped_conv2d_fwd_dl_gnhwc_gkyxc_gnhwk_f16_instance.cpp diff --git a/library/src/tensor_operation_instance/gpu/grouped_conv2d_fwd/xdl/comp/device_grouped_conv2d_fwd_xdl_ngchw_gkcyx_ngkhw_bf16_comp_2x_instance.cpp b/library/src/tensor_operation_instance/gpu/grouped_conv2d_fwd/xdl/comp/device_grouped_conv2d_fwd_xdl_ngchw_gkcyx_ngkhw_bf16_comp_2x_instance.cpp new file mode 100644 index 0000000000..6cb4ca5652 --- /dev/null +++ b/library/src/tensor_operation_instance/gpu/grouped_conv2d_fwd/xdl/comp/device_grouped_conv2d_fwd_xdl_ngchw_gkcyx_ngkhw_bf16_comp_2x_instance.cpp @@ -0,0 +1,43 @@ +// SPDX-License-Identifier: MIT +// Copyright (c) 2024-2025, Advanced Micro Devices, Inc. All rights reserved. 
+ +#include "ck/library/tensor_operation_instance/add_device_operation_instance.hpp" +#include "ck/library/tensor_operation_instance/gpu/grouped_conv_fwd/device_grouped_conv_fwd_xdl_comp_instance.hpp" +#include "ck/host_utility/device_prop.hpp" + +namespace ck { +namespace tensor_operation { +namespace device { +namespace instance { +// Compilation parameters for in[n, hi, wi, g, c] * wei[g, k, y, x, c] = out[n, ho, wo, g, k] +void add_device_grouped_conv2d_fwd_xdl_ngchw_gkcyx_ngkhw_bf16_comp_2x_instances( + std::vector>>& instances) +{ + if(ck::get_device_name() == "gfx950") + { + add_device_operation_instances( + instances, + device_grouped_conv_fwd_xdl_bf16_comp_instances_2x<2, + NGCHW, + GKCYX, + Empty_Tuple, + NGKHW, + ConvFwdDefault>{}); + } +} + +} // namespace instance +} // namespace device +} // namespace tensor_operation +} // namespace ck diff --git a/library/src/tensor_operation_instance/gpu/grouped_conv2d_fwd/xdl/comp/device_grouped_conv2d_fwd_xdl_ngchw_gkcyx_ngkhw_bf16_comp_instance.cpp b/library/src/tensor_operation_instance/gpu/grouped_conv2d_fwd/xdl/comp/device_grouped_conv2d_fwd_xdl_ngchw_gkcyx_ngkhw_bf16_comp_instance.cpp index b055e782c2..7368587c93 100644 --- a/library/src/tensor_operation_instance/gpu/grouped_conv2d_fwd/xdl/comp/device_grouped_conv2d_fwd_xdl_ngchw_gkcyx_ngkhw_bf16_comp_instance.cpp +++ b/library/src/tensor_operation_instance/gpu/grouped_conv2d_fwd/xdl/comp/device_grouped_conv2d_fwd_xdl_ngchw_gkcyx_ngkhw_bf16_comp_instance.cpp @@ -32,30 +32,6 @@ void add_device_grouped_conv2d_fwd_xdl_ngchw_gkcyx_ngkhw_bf16_comp_instances( Empty_Tuple, NGKHW, ConvFwdDefault>{}); - - if(ck::get_device_name() != "gfx950") - { - add_device_operation_instances( - instances, - device_grouped_conv_fwd_xdl_bf16_comp_instances_part2<2, - NGCHW, - GKCYX, - Empty_Tuple, - NGKHW, - ConvFwdDefault>{}); - } - - if(ck::get_device_name() == "gfx950") - { - add_device_operation_instances( - instances, - device_grouped_conv_fwd_xdl_bf16_comp_instances_2x<2, - 
NGCHW, - GKCYX, - Empty_Tuple, - NGKHW, - ConvFwdDefault>{}); - } } } // namespace instance diff --git a/library/src/tensor_operation_instance/gpu/grouped_conv2d_fwd/xdl/comp/device_grouped_conv2d_fwd_xdl_ngchw_gkcyx_ngkhw_bf16_comp_part2_instance.cpp b/library/src/tensor_operation_instance/gpu/grouped_conv2d_fwd/xdl/comp/device_grouped_conv2d_fwd_xdl_ngchw_gkcyx_ngkhw_bf16_comp_part2_instance.cpp new file mode 100644 index 0000000000..7f0feb61d8 --- /dev/null +++ b/library/src/tensor_operation_instance/gpu/grouped_conv2d_fwd/xdl/comp/device_grouped_conv2d_fwd_xdl_ngchw_gkcyx_ngkhw_bf16_comp_part2_instance.cpp @@ -0,0 +1,45 @@ +// SPDX-License-Identifier: MIT +// Copyright (c) 2024-2025, Advanced Micro Devices, Inc. All rights reserved. + +#include "ck/library/tensor_operation_instance/add_device_operation_instance.hpp" +#include "ck/library/tensor_operation_instance/gpu/grouped_conv_fwd/device_grouped_conv_fwd_xdl_comp_instance.hpp" +#include "ck/host_utility/device_prop.hpp" + +namespace ck { +namespace tensor_operation { +namespace device { +namespace instance { +// Compilation parameters for in[n, hi, wi, g, c] * wei[g, k, y, x, c] = out[n, ho, wo, g, k] +void add_device_grouped_conv2d_fwd_xdl_ngchw_gkcyx_ngkhw_bf16_comp_part2_instances( + std::vector>>&) +{ + if(ck::get_device_name() != "gfx950") + { +#if 0 // TODO: Improve compilation time and enable these instances + add_device_operation_instances( + instances, + device_grouped_conv_fwd_xdl_bf16_comp_instances_part2<2, + NGCHW, + GKCYX, + Empty_Tuple, + NGKHW, + ConvFwdDefault>{}); +#endif + } +} + +} // namespace instance +} // namespace device +} // namespace tensor_operation +} // namespace ck diff --git a/library/src/tensor_operation_instance/gpu/grouped_conv2d_fwd/xdl/comp/device_grouped_conv2d_fwd_xdl_ngchw_gkcyx_ngkhw_f16_comp_2x_instance.cpp b/library/src/tensor_operation_instance/gpu/grouped_conv2d_fwd/xdl/comp/device_grouped_conv2d_fwd_xdl_ngchw_gkcyx_ngkhw_f16_comp_2x_instance.cpp new file mode 
100644 index 0000000000..f9ad6b8212 --- /dev/null +++ b/library/src/tensor_operation_instance/gpu/grouped_conv2d_fwd/xdl/comp/device_grouped_conv2d_fwd_xdl_ngchw_gkcyx_ngkhw_f16_comp_2x_instance.cpp @@ -0,0 +1,43 @@ +// SPDX-License-Identifier: MIT +// Copyright (c) 2024-2025, Advanced Micro Devices, Inc. All rights reserved. + +#include "ck/library/tensor_operation_instance/add_device_operation_instance.hpp" +#include "ck/library/tensor_operation_instance/gpu/grouped_conv_fwd/device_grouped_conv_fwd_xdl_comp_instance.hpp" +#include "ck/host_utility/device_prop.hpp" + +namespace ck { +namespace tensor_operation { +namespace device { +namespace instance { +// Compilation parameters for in[n, hi, wi, g, c] * wei[g, k, y, x, c] = out[n, ho, wo, g, k] +void add_device_grouped_conv2d_fwd_xdl_ngchw_gkcyx_ngkhw_f16_comp_2x_instances( + std::vector>>& instances) +{ + if(ck::get_device_name() == "gfx950") + { + add_device_operation_instances( + instances, + device_grouped_conv_fwd_xdl_f16_comp_instances_2x<2, + NGCHW, + GKCYX, + Empty_Tuple, + NGKHW, + ConvFwdDefault>{}); + } +} + +} // namespace instance +} // namespace device +} // namespace tensor_operation +} // namespace ck diff --git a/library/src/tensor_operation_instance/gpu/grouped_conv2d_fwd/xdl/comp/device_grouped_conv2d_fwd_xdl_ngchw_gkcyx_ngkhw_f16_comp_instance.cpp b/library/src/tensor_operation_instance/gpu/grouped_conv2d_fwd/xdl/comp/device_grouped_conv2d_fwd_xdl_ngchw_gkcyx_ngkhw_f16_comp_instance.cpp index 13e0e91f97..803de2de55 100644 --- a/library/src/tensor_operation_instance/gpu/grouped_conv2d_fwd/xdl/comp/device_grouped_conv2d_fwd_xdl_ngchw_gkcyx_ngkhw_f16_comp_instance.cpp +++ b/library/src/tensor_operation_instance/gpu/grouped_conv2d_fwd/xdl/comp/device_grouped_conv2d_fwd_xdl_ngchw_gkcyx_ngkhw_f16_comp_instance.cpp @@ -32,30 +32,6 @@ void add_device_grouped_conv2d_fwd_xdl_ngchw_gkcyx_ngkhw_f16_comp_instances( Empty_Tuple, NGKHW, ConvFwdDefault>{}); - - if(ck::get_device_name() != "gfx950") - { - 
add_device_operation_instances( - instances, - device_grouped_conv_fwd_xdl_f16_comp_instances_part2<2, - NGCHW, - GKCYX, - Empty_Tuple, - NGKHW, - ConvFwdDefault>{}); - } - - if(ck::get_device_name() == "gfx950") - { - add_device_operation_instances( - instances, - device_grouped_conv_fwd_xdl_f16_comp_instances_2x<2, - NGCHW, - GKCYX, - Empty_Tuple, - NGKHW, - ConvFwdDefault>{}); - } } } // namespace instance diff --git a/library/src/tensor_operation_instance/gpu/grouped_conv2d_fwd/xdl/comp/device_grouped_conv2d_fwd_xdl_ngchw_gkcyx_ngkhw_f16_comp_part2_instance.cpp b/library/src/tensor_operation_instance/gpu/grouped_conv2d_fwd/xdl/comp/device_grouped_conv2d_fwd_xdl_ngchw_gkcyx_ngkhw_f16_comp_part2_instance.cpp new file mode 100644 index 0000000000..da7949668a --- /dev/null +++ b/library/src/tensor_operation_instance/gpu/grouped_conv2d_fwd/xdl/comp/device_grouped_conv2d_fwd_xdl_ngchw_gkcyx_ngkhw_f16_comp_part2_instance.cpp @@ -0,0 +1,45 @@ +// SPDX-License-Identifier: MIT +// Copyright (c) 2024-2025, Advanced Micro Devices, Inc. All rights reserved. 
+ +#include "ck/library/tensor_operation_instance/add_device_operation_instance.hpp" +#include "ck/library/tensor_operation_instance/gpu/grouped_conv_fwd/device_grouped_conv_fwd_xdl_comp_instance.hpp" +#include "ck/host_utility/device_prop.hpp" + +namespace ck { +namespace tensor_operation { +namespace device { +namespace instance { +// Compilation parameters for in[n, hi, wi, g, c] * wei[g, k, y, x, c] = out[n, ho, wo, g, k] +void add_device_grouped_conv2d_fwd_xdl_ngchw_gkcyx_ngkhw_f16_comp_part2_instances( + std::vector>>&) +{ + if(ck::get_device_name() != "gfx950") + { +#if 0 // TODO: Improve compilation time and enable these instances + add_device_operation_instances( + instances, + device_grouped_conv_fwd_xdl_f16_comp_instances_part2<2, + NGCHW, + GKCYX, + Empty_Tuple, + NGKHW, + ConvFwdDefault>{}); +#endif + } +} + +} // namespace instance +} // namespace device +} // namespace tensor_operation +} // namespace ck diff --git a/library/src/tensor_operation_instance/gpu/grouped_conv2d_fwd/xdl/comp/device_grouped_conv2d_fwd_xdl_nhwgc_gkyxc_nhwgk_bf16_comp_2x_instance.cpp b/library/src/tensor_operation_instance/gpu/grouped_conv2d_fwd/xdl/comp/device_grouped_conv2d_fwd_xdl_nhwgc_gkyxc_nhwgk_bf16_comp_2x_instance.cpp new file mode 100644 index 0000000000..c078f8ed04 --- /dev/null +++ b/library/src/tensor_operation_instance/gpu/grouped_conv2d_fwd/xdl/comp/device_grouped_conv2d_fwd_xdl_nhwgc_gkyxc_nhwgk_bf16_comp_2x_instance.cpp @@ -0,0 +1,70 @@ +// SPDX-License-Identifier: MIT +// Copyright (c) 2018-2025, Advanced Micro Devices, Inc. All rights reserved. 
+ +#include "ck/library/tensor_operation_instance/add_device_operation_instance.hpp" +#include "ck/library/tensor_operation_instance/gpu/grouped_conv_fwd/device_grouped_conv_fwd_xdl_comp_instance.hpp" +#include "ck/host_utility/device_prop.hpp" + +namespace ck { +namespace tensor_operation { +namespace device { +namespace instance { +// Compilation parameters for in[n, hi, wi, g, c] * wei[g, k, y, x, c] = out[n, ho, wo, g, k] +void add_device_grouped_conv2d_fwd_xdl_nhwgc_gkyxc_nhwgk_bf16_comp_2x_instances( + std::vector>>& instances) +{ + if(ck::get_device_name() == "gfx950") + { + add_device_operation_instances( + instances, + device_grouped_conv_fwd_xdl_bf16_comp_instances_2x<2, + NHWGC, + GKYXC, + Empty_Tuple, + NHWGK, + ConvFwdDefault>{}); + + add_device_operation_instances( + instances, + device_grouped_conv_fwd_xdl_bf16_comp_instances_2x<2, + NHWGC, + GKYXC, + Empty_Tuple, + NHWGK, + ConvFwd1x1P0>{}); + + add_device_operation_instances( + instances, + device_grouped_conv_fwd_xdl_bf16_comp_instances_2x<2, + NHWGC, + GKYXC, + Empty_Tuple, + NHWGK, + ConvFwd1x1S1P0>{}); + + add_device_operation_instances( + instances, + device_grouped_conv_fwd_xdl_bf16_comp_instances_2x<2, + NHWGC, + GKYXC, + Empty_Tuple, + NHWGK, + ConvFwdOddC>{}); + } +} + +} // namespace instance +} // namespace device +} // namespace tensor_operation +} // namespace ck diff --git a/library/src/tensor_operation_instance/gpu/grouped_conv2d_fwd/xdl/comp/device_grouped_conv2d_fwd_xdl_nhwgc_gkyxc_nhwgk_bf16_comp_instance.cpp b/library/src/tensor_operation_instance/gpu/grouped_conv2d_fwd/xdl/comp/device_grouped_conv2d_fwd_xdl_nhwgc_gkyxc_nhwgk_bf16_comp_instance.cpp index a344e35c8d..a67b11f1cf 100644 --- a/library/src/tensor_operation_instance/gpu/grouped_conv2d_fwd/xdl/comp/device_grouped_conv2d_fwd_xdl_nhwgc_gkyxc_nhwgk_bf16_comp_instance.cpp +++ b/library/src/tensor_operation_instance/gpu/grouped_conv2d_fwd/xdl/comp/device_grouped_conv2d_fwd_xdl_nhwgc_gkyxc_nhwgk_bf16_comp_instance.cpp @@ -1,5 
+1,5 @@ // SPDX-License-Identifier: MIT -// Copyright (c) 2018-2024, Advanced Micro Devices, Inc. All rights reserved. +// Copyright (c) 2018-2025, Advanced Micro Devices, Inc. All rights reserved. #include "ck/library/tensor_operation_instance/add_device_operation_instance.hpp" #include "ck/library/tensor_operation_instance/gpu/grouped_conv_fwd/device_grouped_conv_fwd_xdl_comp_instance.hpp" @@ -57,84 +57,6 @@ void add_device_grouped_conv2d_fwd_xdl_nhwgc_gkyxc_nhwgk_bf16_comp_instances( Empty_Tuple, NHWGK, ConvFwdOddC>{}); - - if(ck::get_device_name() != "gfx950") - { - add_device_operation_instances( - instances, - device_grouped_conv_fwd_xdl_bf16_comp_instances_part2<2, - NHWGC, - GKYXC, - Empty_Tuple, - NHWGK, - ConvFwdDefault>{}); - - add_device_operation_instances( - instances, - device_grouped_conv_fwd_xdl_bf16_comp_instances_part2<2, - NHWGC, - GKYXC, - Empty_Tuple, - NHWGK, - ConvFwd1x1P0>{}); - - add_device_operation_instances( - instances, - device_grouped_conv_fwd_xdl_bf16_comp_instances_part2<2, - NHWGC, - GKYXC, - Empty_Tuple, - NHWGK, - ConvFwd1x1S1P0>{}); - - add_device_operation_instances( - instances, - device_grouped_conv_fwd_xdl_bf16_comp_instances_part2<2, - NHWGC, - GKYXC, - Empty_Tuple, - NHWGK, - ConvFwdOddC>{}); - } - - if(ck::get_device_name() == "gfx950") - { - add_device_operation_instances( - instances, - device_grouped_conv_fwd_xdl_bf16_comp_instances_2x<2, - NHWGC, - GKYXC, - Empty_Tuple, - NHWGK, - ConvFwdDefault>{}); - - add_device_operation_instances( - instances, - device_grouped_conv_fwd_xdl_bf16_comp_instances_2x<2, - NHWGC, - GKYXC, - Empty_Tuple, - NHWGK, - ConvFwd1x1P0>{}); - - add_device_operation_instances( - instances, - device_grouped_conv_fwd_xdl_bf16_comp_instances_2x<2, - NHWGC, - GKYXC, - Empty_Tuple, - NHWGK, - ConvFwd1x1S1P0>{}); - - add_device_operation_instances( - instances, - device_grouped_conv_fwd_xdl_bf16_comp_instances_2x<2, - NHWGC, - GKYXC, - Empty_Tuple, - NHWGK, - ConvFwdOddC>{}); - } } } // namespace 
instance diff --git a/library/src/tensor_operation_instance/gpu/grouped_conv2d_fwd/xdl/comp/device_grouped_conv2d_fwd_xdl_nhwgc_gkyxc_nhwgk_bf16_comp_part2_instance.cpp b/library/src/tensor_operation_instance/gpu/grouped_conv2d_fwd/xdl/comp/device_grouped_conv2d_fwd_xdl_nhwgc_gkyxc_nhwgk_bf16_comp_part2_instance.cpp new file mode 100644 index 0000000000..5c0391a25f --- /dev/null +++ b/library/src/tensor_operation_instance/gpu/grouped_conv2d_fwd/xdl/comp/device_grouped_conv2d_fwd_xdl_nhwgc_gkyxc_nhwgk_bf16_comp_part2_instance.cpp @@ -0,0 +1,70 @@ +// SPDX-License-Identifier: MIT +// Copyright (c) 2018-2025, Advanced Micro Devices, Inc. All rights reserved. + +#include "ck/library/tensor_operation_instance/add_device_operation_instance.hpp" +#include "ck/library/tensor_operation_instance/gpu/grouped_conv_fwd/device_grouped_conv_fwd_xdl_comp_instance.hpp" +#include "ck/host_utility/device_prop.hpp" + +namespace ck { +namespace tensor_operation { +namespace device { +namespace instance { +// Compilation parameters for in[n, hi, wi, g, c] * wei[g, k, y, x, c] = out[n, ho, wo, g, k] +void add_device_grouped_conv2d_fwd_xdl_nhwgc_gkyxc_nhwgk_bf16_comp_part2_instances( + std::vector>>& instances) +{ + if(ck::get_device_name() != "gfx950") + { + add_device_operation_instances( + instances, + device_grouped_conv_fwd_xdl_bf16_comp_instances_part2<2, + NHWGC, + GKYXC, + Empty_Tuple, + NHWGK, + ConvFwdDefault>{}); + + add_device_operation_instances( + instances, + device_grouped_conv_fwd_xdl_bf16_comp_instances_part2<2, + NHWGC, + GKYXC, + Empty_Tuple, + NHWGK, + ConvFwd1x1P0>{}); + + add_device_operation_instances( + instances, + device_grouped_conv_fwd_xdl_bf16_comp_instances_part2<2, + NHWGC, + GKYXC, + Empty_Tuple, + NHWGK, + ConvFwd1x1S1P0>{}); + + add_device_operation_instances( + instances, + device_grouped_conv_fwd_xdl_bf16_comp_instances_part2<2, + NHWGC, + GKYXC, + Empty_Tuple, + NHWGK, + ConvFwdOddC>{}); + } +} + +} // namespace instance +} // namespace device +} // 
namespace tensor_operation +} // namespace ck diff --git a/library/src/tensor_operation_instance/gpu/grouped_conv2d_fwd/xdl/comp/device_grouped_conv2d_fwd_xdl_nhwgc_gkyxc_nhwgk_f16_comp_2x_instance.cpp b/library/src/tensor_operation_instance/gpu/grouped_conv2d_fwd/xdl/comp/device_grouped_conv2d_fwd_xdl_nhwgc_gkyxc_nhwgk_f16_comp_2x_instance.cpp new file mode 100644 index 0000000000..726276c461 --- /dev/null +++ b/library/src/tensor_operation_instance/gpu/grouped_conv2d_fwd/xdl/comp/device_grouped_conv2d_fwd_xdl_nhwgc_gkyxc_nhwgk_f16_comp_2x_instance.cpp @@ -0,0 +1,70 @@ +// SPDX-License-Identifier: MIT +// Copyright (c) 2018-2025, Advanced Micro Devices, Inc. All rights reserved. + +#include "ck/library/tensor_operation_instance/add_device_operation_instance.hpp" +#include "ck/library/tensor_operation_instance/gpu/grouped_conv_fwd/device_grouped_conv_fwd_xdl_comp_instance.hpp" +#include "ck/host_utility/device_prop.hpp" + +namespace ck { +namespace tensor_operation { +namespace device { +namespace instance { +// Compilation parameters for in[n, hi, wi, g, c] * wei[g, k, y, x, c] = out[n, ho, wo, g, k] +void add_device_grouped_conv2d_fwd_xdl_nhwgc_gkyxc_nhwgk_f16_comp_2x_instances( + std::vector>>& instances) +{ + if(ck::get_device_name() == "gfx950") + { + add_device_operation_instances( + instances, + device_grouped_conv_fwd_xdl_f16_comp_instances_2x<2, + NHWGC, + GKYXC, + Empty_Tuple, + NHWGK, + ConvFwdDefault>{}); + + add_device_operation_instances( + instances, + device_grouped_conv_fwd_xdl_f16_comp_instances_2x<2, + NHWGC, + GKYXC, + Empty_Tuple, + NHWGK, + ConvFwd1x1P0>{}); + + add_device_operation_instances( + instances, + device_grouped_conv_fwd_xdl_f16_comp_instances_2x<2, + NHWGC, + GKYXC, + Empty_Tuple, + NHWGK, + ConvFwd1x1S1P0>{}); + + add_device_operation_instances( + instances, + device_grouped_conv_fwd_xdl_f16_comp_instances_2x<2, + NHWGC, + GKYXC, + Empty_Tuple, + NHWGK, + ConvFwdOddC>{}); + } +} + +} // namespace instance +} // namespace device +} 
// namespace tensor_operation +} // namespace ck diff --git a/library/src/tensor_operation_instance/gpu/grouped_conv2d_fwd/xdl/comp/device_grouped_conv2d_fwd_xdl_nhwgc_gkyxc_nhwgk_f16_comp_instance.cpp b/library/src/tensor_operation_instance/gpu/grouped_conv2d_fwd/xdl/comp/device_grouped_conv2d_fwd_xdl_nhwgc_gkyxc_nhwgk_f16_comp_instance.cpp index 30a8b60bfc..8b7bdec2a8 100644 --- a/library/src/tensor_operation_instance/gpu/grouped_conv2d_fwd/xdl/comp/device_grouped_conv2d_fwd_xdl_nhwgc_gkyxc_nhwgk_f16_comp_instance.cpp +++ b/library/src/tensor_operation_instance/gpu/grouped_conv2d_fwd/xdl/comp/device_grouped_conv2d_fwd_xdl_nhwgc_gkyxc_nhwgk_f16_comp_instance.cpp @@ -1,5 +1,5 @@ // SPDX-License-Identifier: MIT -// Copyright (c) 2018-2024, Advanced Micro Devices, Inc. All rights reserved. +// Copyright (c) 2018-2025, Advanced Micro Devices, Inc. All rights reserved. #include "ck/library/tensor_operation_instance/add_device_operation_instance.hpp" #include "ck/library/tensor_operation_instance/gpu/grouped_conv_fwd/device_grouped_conv_fwd_xdl_comp_instance.hpp" @@ -57,84 +57,6 @@ void add_device_grouped_conv2d_fwd_xdl_nhwgc_gkyxc_nhwgk_f16_comp_instances( Empty_Tuple, NHWGK, ConvFwdOddC>{}); - - if(ck::get_device_name() != "gfx950") - { - add_device_operation_instances( - instances, - device_grouped_conv_fwd_xdl_f16_comp_instances_part2<2, - NHWGC, - GKYXC, - Empty_Tuple, - NHWGK, - ConvFwdDefault>{}); - - add_device_operation_instances( - instances, - device_grouped_conv_fwd_xdl_f16_comp_instances_part2<2, - NHWGC, - GKYXC, - Empty_Tuple, - NHWGK, - ConvFwd1x1P0>{}); - - add_device_operation_instances( - instances, - device_grouped_conv_fwd_xdl_f16_comp_instances_part2<2, - NHWGC, - GKYXC, - Empty_Tuple, - NHWGK, - ConvFwd1x1S1P0>{}); - - add_device_operation_instances( - instances, - device_grouped_conv_fwd_xdl_f16_comp_instances_part2<2, - NHWGC, - GKYXC, - Empty_Tuple, - NHWGK, - ConvFwdOddC>{}); - } - - if(ck::get_device_name() == "gfx950") - { - 
add_device_operation_instances( - instances, - device_grouped_conv_fwd_xdl_f16_comp_instances_2x<2, - NHWGC, - GKYXC, - Empty_Tuple, - NHWGK, - ConvFwdDefault>{}); - - add_device_operation_instances( - instances, - device_grouped_conv_fwd_xdl_f16_comp_instances_2x<2, - NHWGC, - GKYXC, - Empty_Tuple, - NHWGK, - ConvFwd1x1P0>{}); - - add_device_operation_instances( - instances, - device_grouped_conv_fwd_xdl_f16_comp_instances_2x<2, - NHWGC, - GKYXC, - Empty_Tuple, - NHWGK, - ConvFwd1x1S1P0>{}); - - add_device_operation_instances( - instances, - device_grouped_conv_fwd_xdl_f16_comp_instances_2x<2, - NHWGC, - GKYXC, - Empty_Tuple, - NHWGK, - ConvFwdOddC>{}); - } } } // namespace instance diff --git a/library/src/tensor_operation_instance/gpu/grouped_conv2d_fwd/xdl/comp/device_grouped_conv2d_fwd_xdl_nhwgc_gkyxc_nhwgk_f16_comp_part2_instance.cpp b/library/src/tensor_operation_instance/gpu/grouped_conv2d_fwd/xdl/comp/device_grouped_conv2d_fwd_xdl_nhwgc_gkyxc_nhwgk_f16_comp_part2_instance.cpp new file mode 100644 index 0000000000..c66114b9a3 --- /dev/null +++ b/library/src/tensor_operation_instance/gpu/grouped_conv2d_fwd/xdl/comp/device_grouped_conv2d_fwd_xdl_nhwgc_gkyxc_nhwgk_f16_comp_part2_instance.cpp @@ -0,0 +1,70 @@ +// SPDX-License-Identifier: MIT +// Copyright (c) 2018-2025, Advanced Micro Devices, Inc. All rights reserved. 
+ +#include "ck/library/tensor_operation_instance/add_device_operation_instance.hpp" +#include "ck/library/tensor_operation_instance/gpu/grouped_conv_fwd/device_grouped_conv_fwd_xdl_comp_instance.hpp" +#include "ck/host_utility/device_prop.hpp" + +namespace ck { +namespace tensor_operation { +namespace device { +namespace instance { +// Compilation parameters for in[n, hi, wi, g, c] * wei[g, k, y, x, c] = out[n, ho, wo, g, k] +void add_device_grouped_conv2d_fwd_xdl_nhwgc_gkyxc_nhwgk_f16_comp_part2_instances( + std::vector>>& instances) +{ + if(ck::get_device_name() != "gfx950") + { + add_device_operation_instances( + instances, + device_grouped_conv_fwd_xdl_f16_comp_instances_part2<2, + NHWGC, + GKYXC, + Empty_Tuple, + NHWGK, + ConvFwdDefault>{}); + + add_device_operation_instances( + instances, + device_grouped_conv_fwd_xdl_f16_comp_instances_part2<2, + NHWGC, + GKYXC, + Empty_Tuple, + NHWGK, + ConvFwd1x1P0>{}); + + add_device_operation_instances( + instances, + device_grouped_conv_fwd_xdl_f16_comp_instances_part2<2, + NHWGC, + GKYXC, + Empty_Tuple, + NHWGK, + ConvFwd1x1S1P0>{}); + + add_device_operation_instances( + instances, + device_grouped_conv_fwd_xdl_f16_comp_instances_part2<2, + NHWGC, + GKYXC, + Empty_Tuple, + NHWGK, + ConvFwdOddC>{}); + } +} + +} // namespace instance +} // namespace device +} // namespace tensor_operation +} // namespace ck From c59a8bb206d4dc763d07e16f730e563849e68cb6 Mon Sep 17 00:00:00 2001 From: Max Podkorytov <4273004+tenpercent@users.noreply.github.com> Date: Tue, 1 Apr 2025 12:06:25 -0700 Subject: [PATCH 010/443] add a fast compilation path for static for (0..N) (#2005) * add a fast compilation path for static for (0..N) * Update functional2.hpp add comment and put range applier into detail namespace * Update functional.hpp ditto for ck-tile * prettify * prettify more * add comment * clang-format --- include/ck/utility/functional2.hpp | 24 +++++++++++++++++++++ include/ck_tile/core/utility/functional.hpp | 24 
+++++++++++++++++++++ 2 files changed, 48 insertions(+) diff --git a/include/ck/utility/functional2.hpp b/include/ck/utility/functional2.hpp index 99c65f4eb8..a11963cb47 100644 --- a/include/ck/utility/functional2.hpp +++ b/include/ck/utility/functional2.hpp @@ -46,4 +46,28 @@ struct static_for } }; +namespace detail { + +template +struct applier +{ + template + __host__ __device__ constexpr void operator()(F f) const + { + // tweak -fbracket-depth if compilation fails. Clang default limit is 256 + (f(Number{}), ...); + } +}; + +template // == sizeof...(Is) +using make_applier = __make_integer_seq; + +} // namespace detail + +template +struct static_for<0, N, 1> : detail::make_applier +{ + using detail::make_applier::operator(); +}; + } // namespace ck diff --git a/include/ck_tile/core/utility/functional.hpp b/include/ck_tile/core/utility/functional.hpp index 2cdce94063..fd0252d3ca 100644 --- a/include/ck_tile/core/utility/functional.hpp +++ b/include/ck_tile/core/utility/functional.hpp @@ -58,6 +58,30 @@ struct static_for } }; +namespace detail { + +template +struct applier +{ + template + CK_TILE_HOST_DEVICE constexpr void operator()(F f) const + { + // tweak -fbracket-depth if compilation fails. Clang default limit is 256 + (f(number{}), ...); + } +}; + +template // == sizeof...(Is) +using make_applier = __make_integer_seq; + +} // namespace detail + +template +struct static_for<0, N, 1> : detail::make_applier +{ + using detail::make_applier::operator(); +}; + struct identity { template From df32020f93880a0086ac10a4e5cdbce47e6a1b41 Mon Sep 17 00:00:00 2001 From: Seunghoon Lee Date: Wed, 2 Apr 2025 04:22:10 +0900 Subject: [PATCH 011/443] Fix Windows build. (#2012) * Remove duplicate using uint64_t. * Cast before shift. 
--- include/ck/utility/dtype_vector.hpp | 2 -- include/ck_tile/core/utility/magic_div.hpp | 2 +- 2 files changed, 1 insertion(+), 3 deletions(-) diff --git a/include/ck/utility/dtype_vector.hpp b/include/ck/utility/dtype_vector.hpp index 8f70962fa6..9c40d923d3 100644 --- a/include/ck/utility/dtype_vector.hpp +++ b/include/ck/utility/dtype_vector.hpp @@ -2000,8 +2000,6 @@ struct vector_type()>> } }; -using int64_t = long; - // fp32 using float2_t = typename vector_type::type; using float4_t = typename vector_type::type; diff --git a/include/ck_tile/core/utility/magic_div.hpp b/include/ck_tile/core/utility/magic_div.hpp index fd9c733c52..1715983c09 100644 --- a/include/ck_tile/core/utility/magic_div.hpp +++ b/include/ck_tile/core/utility/magic_div.hpp @@ -38,7 +38,7 @@ struct magic_division32_bit_range shift_u32++; }; - uint64_t tmp_u64 = ((1UL << shift_u32) - divisor) << 32; + uint64_t tmp_u64 = static_cast((1UL << shift_u32) - divisor) << 32; uint32_t multiplier_u32 = tmp_u64 / divisor + 1; return make_tuple(multiplier_u32, shift_u32); From ec742908bdae09387e76980af628f7c1125473cf Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Bart=C5=82omiej=20Kocot?= Date: Tue, 1 Apr 2025 22:19:35 +0200 Subject: [PATCH 012/443] Grouped conv fwd v3 fix for SplitN an G > 1 (#2038) * Grouped conv fwd v3 fix for SplitN an G > 1 * Remove int8 large test * Retore int8 test --- ..._conv_fwd_multiple_abd_xdl_cshuffle_v3.hpp | 27 +++++++------------ ...est_grouped_convnd_fwd_large_cases_xdl.cpp | 5 +++- 2 files changed, 14 insertions(+), 18 deletions(-) diff --git a/include/ck/tensor_operation/gpu/device/impl/device_grouped_conv_fwd_multiple_abd_xdl_cshuffle_v3.hpp b/include/ck/tensor_operation/gpu/device/impl/device_grouped_conv_fwd_multiple_abd_xdl_cshuffle_v3.hpp index e91496f6a5..b2f1dbfa5c 100644 --- a/include/ck/tensor_operation/gpu/device/impl/device_grouped_conv_fwd_multiple_abd_xdl_cshuffle_v3.hpp +++ 
b/include/ck/tensor_operation/gpu/device/impl/device_grouped_conv_fwd_multiple_abd_xdl_cshuffle_v3.hpp @@ -79,15 +79,12 @@ __global__ void [[maybe_unused]] const CGridDesc_MBlock_MPerBlock_NBlock_NPerBlock c_grid_desc_mblock_mperblock_nblock_nperblock, [[maybe_unused]] const ComputePtrOffset compute_ptr_offset_of_groups, - [[maybe_unused]] const ComputePtrOffset compute_ptr_offset_of_n, - [[maybe_unused]] const index_t groups_count) + [[maybe_unused]] const ComputePtrOffset compute_ptr_offset_of_n) { #if(!defined(__HIP_DEVICE_COMPILE__) || defined(__gfx9__)) // offset base pointer for each work-group - const index_t num_blocks_per_batch = __builtin_amdgcn_readfirstlane(gridDim.y / groups_count); - const index_t& num_blocks_per_n = groups_count; - const index_t g_idx = __builtin_amdgcn_readfirstlane(blockIdx.y / num_blocks_per_batch); - const index_t n_idx = __builtin_amdgcn_readfirstlane(blockIdx.y / num_blocks_per_n); + const index_t g_idx = __builtin_amdgcn_readfirstlane(blockIdx.y); + const index_t n_idx = __builtin_amdgcn_readfirstlane(blockIdx.z); const long_index_t a_batch_offset = amd_wave_read_first_lane(compute_ptr_offset_of_groups.GetAPtrOffset(g_idx)); @@ -141,15 +138,12 @@ __global__ void [[maybe_unused]] const CGridDesc_MBlock_MPerBlock_NBlock_NPerBlock c_grid_desc_mblock_mperblock_nblock_nperblock, [[maybe_unused]] const ComputePtrOffset compute_ptr_offset_of_groups, - [[maybe_unused]] const ComputePtrOffset compute_ptr_offset_of_n, - [[maybe_unused]] const index_t groups_count) + [[maybe_unused]] const ComputePtrOffset compute_ptr_offset_of_n) { #if(!defined(__HIP_DEVICE_COMPILE__) || defined(__gfx9__)) // offset base pointer for each work-group - const index_t num_blocks_per_batch = __builtin_amdgcn_readfirstlane(gridDim.y / groups_count); - const index_t& num_blocks_per_n = groups_count; - const index_t g_idx = __builtin_amdgcn_readfirstlane(blockIdx.y / num_blocks_per_batch); - const index_t n_idx = __builtin_amdgcn_readfirstlane(blockIdx.y / 
num_blocks_per_n); + const index_t g_idx = __builtin_amdgcn_readfirstlane(blockIdx.y); + const index_t n_idx = __builtin_amdgcn_readfirstlane(blockIdx.z); const long_index_t a_batch_offset = amd_wave_read_first_lane(compute_ptr_offset_of_groups.GetAPtrOffset(g_idx)); @@ -766,7 +760,8 @@ struct DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle_V3 std::tie(gdx, gdy, gdz) = GridwiseGemm::CalculateGridSize(GemmM, GemmN, I1 /*arg.KBatch*/); - gdy *= arg.num_group_ * num_workgroups_per_Conv_N; + gdy = arg.num_group_; + gdz = num_workgroups_per_Conv_N; index_t K_split = (GemmK + KPerBlock - 1) / KPerBlock * KPerBlock; const bool has_main_k_block_loop = GridwiseGemm::CalculateHasMainKBlockLoop(K_split); @@ -820,8 +815,7 @@ struct DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle_V3 arg.b_grid_desc_bk0_n_bk1_, arg.e_grid_desc_mblock_mperblock_nblock_nperblock_, arg.compute_ptr_offset_of_groups_, - arg.compute_ptr_offset_of_n_, - arg.num_group_); + arg.compute_ptr_offset_of_n_); } else { @@ -836,8 +830,7 @@ struct DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle_V3 arg.b_grid_desc_bk0_n_bk1_, arg.e_grid_desc_mblock_mperblock_nblock_nperblock_, arg.compute_ptr_offset_of_groups_, - arg.compute_ptr_offset_of_n_, - arg.num_group_); + arg.compute_ptr_offset_of_n_); } }; diff --git a/test/grouped_convnd_fwd/test_grouped_convnd_fwd_large_cases_xdl.cpp b/test/grouped_convnd_fwd/test_grouped_convnd_fwd_large_cases_xdl.cpp index 088fed89ff..d017a40bce 100644 --- a/test/grouped_convnd_fwd/test_grouped_convnd_fwd_large_cases_xdl.cpp +++ b/test/grouped_convnd_fwd/test_grouped_convnd_fwd_large_cases_xdl.cpp @@ -1,5 +1,5 @@ // SPDX-License-Identifier: MIT -// Copyright (c) 2024, Advanced Micro Devices, Inc. All rights reserved. +// Copyright (c) 2024-2025, Advanced Micro Devices, Inc. All rights reserved. 
#include #include @@ -83,6 +83,9 @@ TYPED_TEST(TestGroupedConvndFwd2d, Test2D) // When image is larger than 2GB this->conv_params.push_back( {2, 2, 2, 128, 128, {3, 3}, {4096, 2048}, {300, 300}, {3, 3}, {1, 1}, {1, 1}}); + // Split N and G > 1 + this->conv_params.push_back( + {2, 4, 112, 8, 8, {3, 3}, {469, 724}, {2, 2}, {2, 2}, {1, 1}, {1, 1}}); this->template Run<2>(); } From 8c0ab61ece87f47e4ffece69e27c22b33f6074f7 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Bart=C5=82omiej=20Kocot?= Date: Tue, 1 Apr 2025 22:24:38 +0200 Subject: [PATCH 013/443] Grouped conv backward data GKCYX support (#2029) * Grouped conv backward data GKCYX support * profiler * Converter * split instances --- CHANGELOG.md | 5 + ...nv_bwd_data_multiple_d_xdl_cshuffle_v1.hpp | 241 +++++++++++++----- .../gpu/grid/gridwise_elementwise_2d.hpp | 113 ++++++++ ...ice_grouped_conv_bwd_data_xdl_instance.hpp | 55 +++- .../device_grouped_conv_fwd_xdl_instance.hpp | 70 ++++- .../gpu/grouped_convolution_backward_data.hpp | 72 ++++++ .../grouped_convolution_backward_data_xdl.inc | 175 +++++++++++++ .../grouped_conv2d_bwd_data/CMakeLists.txt | 6 + ...ta_xdl_ngchw_gkcyx_ngkhw_bf16_instance.cpp | 40 +++ ...kcyx_ngkhw_bf16_vec_transpose_instance.cpp | 40 +++ ...ata_xdl_ngchw_gkcyx_ngkhw_f16_instance.cpp | 40 +++ ...gkcyx_ngkhw_f16_vec_transpose_instance.cpp | 40 +++ ...ata_xdl_ngchw_gkcyx_ngkhw_f32_instance.cpp | 40 +++ ...gkcyx_ngkhw_f32_vec_transpose_instance.cpp | 40 +++ ...ta_xdl_ngchw_gkyxc_ngkhw_bf16_instance.cpp | 22 +- ...ata_xdl_ngchw_gkyxc_ngkhw_f16_instance.cpp | 22 +- ...ata_xdl_ngchw_gkyxc_ngkhw_f32_instance.cpp | 22 +- ...ta_xdl_nhwgc_gkyxc_nhwgk_bf16_instance.cpp | 4 +- ...ata_xdl_nhwgc_gkyxc_nhwgk_f16_instance.cpp | 4 +- ...ata_xdl_nhwgc_gkyxc_nhwgk_f32_instance.cpp | 4 +- ...wd_xdl_ngchw_gkyxc_ngkhw_bf16_instance.cpp | 17 +- ...fwd_xdl_ngchw_gkyxc_ngkhw_f16_instance.cpp | 17 +- ...fwd_xdl_ngchw_gkyxc_ngkhw_f32_instance.cpp | 17 +- ...wd_xdl_ngchw_gkyxc_ngkhw_int8_instance.cpp | 17 +- 
.../grouped_conv3d_bwd_data/CMakeLists.txt | 6 + ...xdl_ngcdhw_gkczyx_ngkdhw_bf16_instance.cpp | 41 +++ ...zyx_ngkdhw_bf16_vec_transpose_instance.cpp | 41 +++ ..._xdl_ngcdhw_gkczyx_ngkdhw_f16_instance.cpp | 41 +++ ...czyx_ngkdhw_f16_vec_transpose_instance.cpp | 41 +++ ..._xdl_ngcdhw_gkczyx_ngkdhw_f32_instance.cpp | 41 +++ ...czyx_ngkdhw_f32_vec_transpose_instance.cpp | 41 +++ ...xdl_ngcdhw_gkzyxc_ngkdhw_bf16_instance.cpp | 20 +- ..._xdl_ngcdhw_gkzyxc_ngkdhw_f16_instance.cpp | 20 +- ..._xdl_ngcdhw_gkzyxc_ngkdhw_f32_instance.cpp | 20 +- .../src/profile_grouped_conv_bwd_data.cpp | 38 ++- script/convert_miopen_driver_to_profiler.py | 5 +- .../test_grouped_convnd_bwd_data_xdl.cpp | 6 + 37 files changed, 1286 insertions(+), 198 deletions(-) create mode 100644 library/src/tensor_operation_instance/gpu/grouped_conv2d_bwd_data/xdl/device_grouped_conv2d_bwd_data_xdl_ngchw_gkcyx_ngkhw_bf16_instance.cpp create mode 100644 library/src/tensor_operation_instance/gpu/grouped_conv2d_bwd_data/xdl/device_grouped_conv2d_bwd_data_xdl_ngchw_gkcyx_ngkhw_bf16_vec_transpose_instance.cpp create mode 100644 library/src/tensor_operation_instance/gpu/grouped_conv2d_bwd_data/xdl/device_grouped_conv2d_bwd_data_xdl_ngchw_gkcyx_ngkhw_f16_instance.cpp create mode 100644 library/src/tensor_operation_instance/gpu/grouped_conv2d_bwd_data/xdl/device_grouped_conv2d_bwd_data_xdl_ngchw_gkcyx_ngkhw_f16_vec_transpose_instance.cpp create mode 100644 library/src/tensor_operation_instance/gpu/grouped_conv2d_bwd_data/xdl/device_grouped_conv2d_bwd_data_xdl_ngchw_gkcyx_ngkhw_f32_instance.cpp create mode 100644 library/src/tensor_operation_instance/gpu/grouped_conv2d_bwd_data/xdl/device_grouped_conv2d_bwd_data_xdl_ngchw_gkcyx_ngkhw_f32_vec_transpose_instance.cpp create mode 100644 library/src/tensor_operation_instance/gpu/grouped_conv3d_bwd_data/xdl/device_grouped_conv3d_bwd_data_xdl_ngcdhw_gkczyx_ngkdhw_bf16_instance.cpp create mode 100644 
library/src/tensor_operation_instance/gpu/grouped_conv3d_bwd_data/xdl/device_grouped_conv3d_bwd_data_xdl_ngcdhw_gkczyx_ngkdhw_bf16_vec_transpose_instance.cpp create mode 100644 library/src/tensor_operation_instance/gpu/grouped_conv3d_bwd_data/xdl/device_grouped_conv3d_bwd_data_xdl_ngcdhw_gkczyx_ngkdhw_f16_instance.cpp create mode 100644 library/src/tensor_operation_instance/gpu/grouped_conv3d_bwd_data/xdl/device_grouped_conv3d_bwd_data_xdl_ngcdhw_gkczyx_ngkdhw_f16_vec_transpose_instance.cpp create mode 100644 library/src/tensor_operation_instance/gpu/grouped_conv3d_bwd_data/xdl/device_grouped_conv3d_bwd_data_xdl_ngcdhw_gkczyx_ngkdhw_f32_instance.cpp create mode 100644 library/src/tensor_operation_instance/gpu/grouped_conv3d_bwd_data/xdl/device_grouped_conv3d_bwd_data_xdl_ngcdhw_gkczyx_ngkdhw_f32_vec_transpose_instance.cpp diff --git a/CHANGELOG.md b/CHANGELOG.md index de831a6898..8cc32e7bda 100644 --- a/CHANGELOG.md +++ b/CHANGELOG.md @@ -7,8 +7,11 @@ Documentation for Composable Kernel available at [https://rocm.docs.amd.com/proj ### Added * Added support for bf16, f32, and f16 for 2D and 3D NGCHW grouped convolution backward data +* Added support for GKCYX layout for grouped convolution forward (NGCHW/GKCYX/NGKHW). +* Added support for GKCYX layout for grouped convolution backward data (NGCHW/GKCYX/NGKHW). * Added support GKCYX layout for grouped convolution forward (NGCHW/GKCYX/NGKHW, number of instances in instance factory for NGCHW/GKYXC/NGKHW has been reduced). * Added support for Stream-K version of mixed fp8/bf16 GEMM + ### Optimized None @@ -22,6 +25,8 @@ None * Removed support for gfx940 and gfx941 targets (#1944) * Replaced the raw buffer load/store intrinsics with Clang20 built-ins (#1876) * DL and DPP kernels are now enabled by default. +* Number of instances in instance factory for grouped convolution forward NGCHW/GKYXC/NGKHW has been reduced. 
+* Number of instances in instance factory for grouped convolution backward data NGCHW/GKYXC/NGKHW has been reduced. ### Known issues diff --git a/include/ck/tensor_operation/gpu/device/impl/device_grouped_conv_bwd_data_multiple_d_xdl_cshuffle_v1.hpp b/include/ck/tensor_operation/gpu/device/impl/device_grouped_conv_bwd_data_multiple_d_xdl_cshuffle_v1.hpp index 38e9e3c3d5..770e531e44 100644 --- a/include/ck/tensor_operation/gpu/device/impl/device_grouped_conv_bwd_data_multiple_d_xdl_cshuffle_v1.hpp +++ b/include/ck/tensor_operation/gpu/device/impl/device_grouped_conv_bwd_data_multiple_d_xdl_cshuffle_v1.hpp @@ -243,15 +243,21 @@ struct DeviceGroupedConvBwdDataMultipleD_Xdl_CShuffle_v1 static constexpr auto I3 = Number<3>{}; using ALayoutAfterTranspose = - std::conditional_t(), + std::conditional_t(), tensor_layout::convolution::NHWGK, - std::conditional_t(), + std::conditional_t(), tensor_layout::convolution::NDHWGK, ALayout>>; + using BLayoutAfterTranspose = + std::conditional_t(), + tensor_layout::convolution::GKYXC, + std::conditional_t(), + tensor_layout::convolution::GKZYXC, + BLayout>>; using ELayoutAfterTranspose = - std::conditional_t(), + std::conditional_t(), tensor_layout::convolution::NHWGC, - std::conditional_t(), + std::conditional_t(), tensor_layout::convolution::NDHWGC, ELayout>>; @@ -265,7 +271,7 @@ struct DeviceGroupedConvBwdDataMultipleD_Xdl_CShuffle_v1 DoPadGemmM, DoPadGemmN, ALayoutAfterTranspose, - BLayout, + BLayoutAfterTranspose, ELayoutAfterTranspose, true, /*SplitConvN*/ ABDataType, @@ -392,7 +398,8 @@ struct DeviceGroupedConvBwdDataMultipleD_Xdl_CShuffle_v1 // block-to-e-tile map using Block2ETileMap = remove_cvref_t; - using Block2TileMapElementwise = BlockToCTileMap_M00_N0_M01Adapt; + using Block2TileMapInOutElementwise = BlockToCTileMap_M00_N0_M01Adapt; + using Block2TileMapWeiElementwise = BlockToCTileMap_M00_N0_M01Adapt; static constexpr index_t ClusterLengthMPerBlock = 
CDEBlockTransferClusterLengths_MBlock_MPerBlock_NBlock_NPerBlock::At(1); @@ -418,6 +425,12 @@ struct DeviceGroupedConvBwdDataMultipleD_Xdl_CShuffle_v1 using NHWGCTransposeDescType = remove_cvref_t({}, {}))>; + using GKCYXTransposeDescType = + remove_cvref_t({}, {}))>; + using GKYXCTransposeDescType = + remove_cvref_t({}, {}))>; static constexpr index_t ElementwiseBlocksize = ClusterLengthMPerBlock * ClusterLengthNPerBlock; @@ -426,7 +439,7 @@ struct DeviceGroupedConvBwdDataMultipleD_Xdl_CShuffle_v1 Tuple, Tuple, Tuple, - Block2TileMapElementwise, + Block2TileMapInOutElementwise, element_wise::PassThrough, ElementwiseBlocksize, NPerBlock, @@ -439,12 +452,30 @@ struct DeviceGroupedConvBwdDataMultipleD_Xdl_CShuffle_v1 I1, I0>; + using GridwiseElementwiseWeightTranspose = + GridwiseElementwise, + Tuple, + Tuple, + Tuple, + Block2TileMapWeiElementwise, + element_wise::PassThrough, + ElementwiseBlocksize, + MPerBlock, + NPerBlock, + MPerBlock / ClusterLengthMPerBlock, + NPerBlock / ClusterLengthNPerBlock, + Sequence<1, 0>, + Sequence<1>, + Sequence, + I0, + I1>; + using GridwiseElementwiseOutputTranspose = GridwiseElementwise, Tuple, Tuple, Tuple, - Block2TileMapElementwise, + Block2TileMapInOutElementwise, element_wise::PassThrough, ElementwiseBlocksize, NPerBlock, @@ -498,6 +529,9 @@ struct DeviceGroupedConvBwdDataMultipleD_Xdl_CShuffle_v1 std::array a_g_n_k_wos_strides_transposed = conv_ngchw_to_nhwgc_transformer.TransposeInOutStrides(a_g_n_k_wos_lengths, a_g_n_k_wos_strides); + std::array b_g_k_c_xs_strides_transposed = + conv_ngchw_to_nhwgc_transformer.TransposeWeiStrides(b_g_k_c_xs_lengths, + b_g_k_c_xs_strides); std::array e_g_n_c_wis_strides_transposed = conv_ngchw_to_nhwgc_transformer.TransposeInOutStrides(e_g_n_c_wis_lengths, e_g_n_c_wis_strides); @@ -584,7 +618,7 @@ struct DeviceGroupedConvBwdDataMultipleD_Xdl_CShuffle_v1 a_g_n_k_wos_lengths, a_g_n_k_wos_strides_transposed, b_g_k_c_xs_lengths, - b_g_k_c_xs_strides, + b_g_k_c_xs_strides_transposed, 
e_g_n_c_wis_lengths, e_g_n_c_wis_strides_transposed, conv_filter_strides, @@ -618,7 +652,7 @@ struct DeviceGroupedConvBwdDataMultipleD_Xdl_CShuffle_v1 DoPadGemmM, DoPadGemmN, ALayoutAfterTranspose, - BLayout, + BLayoutAfterTranspose, DLayout, true, /*SplitConvN*/ ABDataType, @@ -627,7 +661,7 @@ struct DeviceGroupedConvBwdDataMultipleD_Xdl_CShuffle_v1 a_g_n_k_wos_lengths, a_g_n_k_wos_strides_transposed, b_g_k_c_xs_lengths, - b_g_k_c_xs_strides, + b_g_k_c_xs_strides_transposed, ds_g_n_c_wis_lengths[i], ds_g_n_c_wis_strides[i], conv_filter_strides, @@ -682,7 +716,7 @@ struct DeviceGroupedConvBwdDataMultipleD_Xdl_CShuffle_v1 } // A/B/Ds/E Batch Stride compute_ptr_offset_of_batch_.BatchStrideA_ = a_g_n_k_wos_strides_transposed[0]; - compute_ptr_offset_of_batch_.BatchStrideB_ = b_g_k_c_xs_strides[0]; + compute_ptr_offset_of_batch_.BatchStrideB_ = b_g_k_c_xs_strides_transposed[0]; compute_ptr_offset_of_batch_.BatchStrideE_ = e_g_n_c_wis_strides_transposed[0]; compute_ptr_offset_of_n_.BatchStrideA_ = @@ -692,8 +726,8 @@ struct DeviceGroupedConvBwdDataMultipleD_Xdl_CShuffle_v1 num_workgroups_per_Conv_N_ = a_g_n_k_wos_lengths_[I1] / conv_N_per_block_; - if constexpr(is_NGCHW_GKYXC_NGKHW() || - is_NGCDHW_GKZYXC_NGKDHW()) + if constexpr(is_NGCHW_NGKHW() || + is_NGCDHW_NGKDHW()) { // Use not modified base strides a_in_transpose_desc_ = @@ -703,6 +737,13 @@ struct DeviceGroupedConvBwdDataMultipleD_Xdl_CShuffle_v1 conv_ngchw_to_nhwgc_transformer.template MakeNHWGCTransposeDesc( a_g_n_k_wos_lengths, a_g_n_k_wos_strides, num_workgroups_per_Conv_N_); + b_in_transpose_desc_ = + conv_ngchw_to_nhwgc_transformer.template MakeGKCYXTransposeDesc( + b_g_k_c_xs_lengths, b_g_k_c_xs_strides); + b_out_transpose_desc_ = + conv_ngchw_to_nhwgc_transformer.template MakeGKYXCTransposeDesc( + b_g_k_c_xs_lengths, b_g_k_c_xs_strides); + e_in_transpose_desc_ = conv_ngchw_to_nhwgc_transformer.template MakeNHWGCTransposeDesc( e_g_n_c_wis_lengths, e_g_n_c_wis_strides, num_workgroups_per_Conv_N_); @@ 
-710,9 +751,11 @@ struct DeviceGroupedConvBwdDataMultipleD_Xdl_CShuffle_v1 conv_ngchw_to_nhwgc_transformer.template MakeNGCHWTransposeDesc( e_g_n_c_wis_lengths, e_g_n_c_wis_strides, num_workgroups_per_Conv_N_); - elementwise_block_2_ctile_map_transpose_a_ = Block2TileMapElementwise{ + elementwise_block_2_ctile_map_transpose_a_ = Block2TileMapInOutElementwise{ a_in_transpose_desc_.GetLength(I0), a_in_transpose_desc_.GetLength(I1)}; - elementwise_block_2_ctile_map_transpose_e_ = Block2TileMapElementwise{ + elementwise_block_2_ctile_map_transpose_b_ = Block2TileMapWeiElementwise{ + b_in_transpose_desc_.GetLength(I0), b_in_transpose_desc_.GetLength(I1)}; + elementwise_block_2_ctile_map_transpose_e_ = Block2TileMapInOutElementwise{ e_in_transpose_desc_.GetLength(I0), e_in_transpose_desc_.GetLength(I1)}; compute_ptr_offset_of_workspace_n_.BatchStrideA_ = @@ -724,25 +767,13 @@ struct DeviceGroupedConvBwdDataMultipleD_Xdl_CShuffle_v1 std::size_t GetWorkspaceATensorSizeBytes() const { - const long_index_t a_acum = ck::accumulate_n( - a_g_n_k_wos_lengths_.begin(), NDimSpatial + I3, 1, std::multiplies<>()); - return sizeof(ADataType) * a_acum; - } - - std::size_t GetWorkspaceETensorSizeBytes() const - { - const long_index_t e_accum = ck::accumulate_n( - e_g_n_c_wis_lengths_.begin(), NDimSpatial + I3, 1, std::multiplies<>()); - return sizeof(EDataType) * e_accum; - } - - std::size_t GetWorkspaceSizeBytes() const - { - // Transpose require workspace for A and B - if constexpr(is_NGCHW_GKYXC_NGKHW() || - is_NGCDHW_GKZYXC_NGKDHW()) + if constexpr(is_NGCHW_NGKHW() || + is_NGCDHW_NGKDHW()) { - return GetWorkspaceATensorSizeBytes() + GetWorkspaceETensorSizeBytes(); + const long_index_t a_acum = ck::accumulate_n( + a_g_n_k_wos_lengths_.begin(), NDimSpatial + I3, 1, std::multiplies<>()); + // Align to 128B + return math::integer_divide_ceil(sizeof(ADataType) * a_acum, 128) * 128; } else { @@ -750,6 +781,43 @@ struct DeviceGroupedConvBwdDataMultipleD_Xdl_CShuffle_v1 } } + std::size_t 
GetWorkspaceBTensorSizeBytes() const + { + if constexpr(is_NGCHW_GKCYX_NGKHW() || + is_NGCDHW_GKCZYX_NGKDHW()) + { + const long_index_t b_acum = ck::accumulate_n( + b_g_k_c_xs_lengths_.begin(), NDimSpatial + I3, 1, std::multiplies<>()); + // Align to 128B + return math::integer_divide_ceil(sizeof(BDataType) * b_acum, 128) * 128; + } + else + { + return 0; + } + } + + std::size_t GetWorkspaceETensorSizeBytes() const + { + if constexpr(is_NGCHW_NGKHW() || + is_NGCDHW_NGKDHW()) + { + const long_index_t e_accum = ck::accumulate_n( + e_g_n_c_wis_lengths_.begin(), NDimSpatial + I3, 1, std::multiplies<>()); + return sizeof(EDataType) * e_accum; + } + else + { + return 0; + } + } + + std::size_t GetWorkspaceSizeBytes() const + { + return GetWorkspaceATensorSizeBytes() + GetWorkspaceBTensorSizeBytes() + + GetWorkspaceETensorSizeBytes(); + } + void Print() const { for(std::size_t i = 0; i < a_grid_desc_ak0_m_ak1_container_.size(); i++) @@ -796,11 +864,14 @@ struct DeviceGroupedConvBwdDataMultipleD_Xdl_CShuffle_v1 // block-to-e-tile map std::vector block_2_etile_map_container_; - Block2TileMapElementwise elementwise_block_2_ctile_map_transpose_a_, + Block2TileMapInOutElementwise elementwise_block_2_ctile_map_transpose_a_, elementwise_block_2_ctile_map_transpose_e_; + Block2TileMapWeiElementwise elementwise_block_2_ctile_map_transpose_b_; NGCHWTransposeDescType a_in_transpose_desc_, e_out_transpose_desc_; NHWGCTransposeDescType a_out_transpose_desc_, e_in_transpose_desc_; + GKCYXTransposeDescType b_in_transpose_desc_; + GKYXCTransposeDescType b_out_transpose_desc_; // for computing batch offset ComputePtrOffsetOfStridedBatch compute_ptr_offset_of_batch_; @@ -835,14 +906,24 @@ struct DeviceGroupedConvBwdDataMultipleD_Xdl_CShuffle_v1 const index_t gdz = arg.num_workgroups_per_Conv_N_; const ADataType* p_a_grid = arg.p_a_grid_; + const BDataType* p_b_grid = arg.p_b_grid_; EDataType* p_e_grid = arg.p_e_grid_; - if constexpr(is_NGCHW_GKYXC_NGKHW() || - is_NGCDHW_GKZYXC_NGKDHW()) + 
if constexpr(is_NGCHW_NGKHW() || + is_NGCDHW_NGKDHW()) { p_a_grid = type_convert(arg.p_workspace_); - p_e_grid = type_convert(arg.p_workspace_) + - arg.GetWorkspaceATensorSizeBytes() / sizeof(EDataType); + p_e_grid = + type_convert(arg.p_workspace_) + + (arg.GetWorkspaceATensorSizeBytes() + arg.GetWorkspaceBTensorSizeBytes()) / + sizeof(EDataType); + } + + if constexpr(is_NGCHW_GKCYX_NGKHW() || + is_NGCDHW_GKCZYX_NGKDHW()) + { + p_b_grid = type_convert(arg.p_workspace_) + + arg.GetWorkspaceATensorSizeBytes() / sizeof(BDataType); } for(std::size_t i = 0; i < arg.a_grid_desc_ak0_m_ak1_container_.size(); i++) @@ -888,7 +969,7 @@ struct DeviceGroupedConvBwdDataMultipleD_Xdl_CShuffle_v1 dim3(BlockSize), 0, p_a_grid, - arg.p_b_grid_, + p_b_grid, arg.p_ds_grid_, p_e_grid, arg.a_element_op_, @@ -925,11 +1006,13 @@ struct DeviceGroupedConvBwdDataMultipleD_Xdl_CShuffle_v1 arg.Print(); } // Transpose from NGKHW to NHWGK - if constexpr(is_NGCHW_GKYXC_NGKHW() || - is_NGCDHW_GKZYXC_NGKDHW()) + if constexpr(is_NGCHW_NGKHW() || + is_NGCDHW_NGKDHW()) { - EDataType* p_e_in_grid = type_convert(arg.p_workspace_) + - arg.GetWorkspaceATensorSizeBytes() / sizeof(EDataType); + EDataType* p_e_in_grid = + type_convert(arg.p_workspace_) + + (arg.GetWorkspaceATensorSizeBytes() + arg.GetWorkspaceBTensorSizeBytes()) / + sizeof(EDataType); const auto clear_workspace = [&]() { hip_check_error(hipMemsetAsync(p_e_in_grid, @@ -938,47 +1021,72 @@ struct DeviceGroupedConvBwdDataMultipleD_Xdl_CShuffle_v1 stream_config.stream_id_)); }; - const index_t grid_size = + const index_t a_grid_size = arg.elementwise_block_2_ctile_map_transpose_a_.CalculateGridSize( arg.a_in_transpose_desc_) * arg.num_workgroups_per_Conv_N_; + const index_t b_grid_size = + (is_NGCHW_GKCYX_NGKHW() || + is_NGCDHW_GKCZYX_NGKDHW()) + ? 
arg.elementwise_block_2_ctile_map_transpose_b_.CalculateGridSize( + arg.b_in_transpose_desc_) + : 0; // Dont run transpose B if not needed ADataType* p_a_out_grid = type_convert(arg.p_workspace_); + BDataType* p_b_out_grid = type_convert(arg.p_workspace_) + + arg.GetWorkspaceATensorSizeBytes() / sizeof(BDataType); auto kernel_transpose = - kernel_batched_elementwise, - ck::Tuple, - ck::Tuple, - ck::Tuple, - Block2TileMapElementwise, - element_wise::PassThrough, - I1, - I1>; + kernel_elementwise_batched_dual, + ck::Tuple, + ck::Tuple, + ck::Tuple, + ck::Tuple, + ck::Tuple, + ck::Tuple, + ck::Tuple, + Block2TileMapInOutElementwise, + Block2TileMapWeiElementwise, + element_wise::PassThrough, + I1, + I1, + I1, + I1>; ave_time += launch_and_time_kernel_with_preprocess( stream_config, clear_workspace, kernel_transpose, - dim3(grid_size), + dim3(a_grid_size + b_grid_size), dim3(ElementwiseBlocksize), 0, make_tuple(arg.a_in_transpose_desc_), + make_tuple(arg.b_in_transpose_desc_), make_tuple(arg.a_out_transpose_desc_), + make_tuple(arg.b_out_transpose_desc_), make_tuple(arg.p_a_grid_), + make_tuple(arg.p_b_grid_), make_tuple(p_a_out_grid), + make_tuple(p_b_out_grid), arg.elementwise_block_2_ctile_map_transpose_a_, + arg.elementwise_block_2_ctile_map_transpose_b_, element_wise::PassThrough{}, + a_grid_size, arg.num_workgroups_per_Conv_N_, + I1, // B is not splited per N std::array{ static_cast(arg.compute_ptr_offset_of_workspace_n_.BatchStrideA_)}, + std::array{0}, std::array{ - static_cast(arg.compute_ptr_offset_of_n_.BatchStrideA_)}); + static_cast(arg.compute_ptr_offset_of_n_.BatchStrideA_)}, + std::array{0}); } ave_time += RunGemm(arg, stream_config); // Transpose from NHWGC to NGCHW - if constexpr(is_NGCHW_GKYXC_NGKHW() || - is_NGCDHW_GKZYXC_NGKDHW()) + if constexpr(is_NGCHW_NGKHW() || + is_NGCDHW_NGKDHW()) { const index_t grid_size = arg.elementwise_block_2_ctile_map_transpose_e_.CalculateGridSize( @@ -987,7 +1095,8 @@ struct 
DeviceGroupedConvBwdDataMultipleD_Xdl_CShuffle_v1 const EDataType* p_e_in_grid = type_convert(arg.p_workspace_) + - arg.GetWorkspaceATensorSizeBytes() / sizeof(EDataType); + (arg.GetWorkspaceATensorSizeBytes() + arg.GetWorkspaceBTensorSizeBytes()) / + sizeof(EDataType); EDataType* p_e_out_grid = arg.p_e_grid_; @@ -997,7 +1106,7 @@ struct DeviceGroupedConvBwdDataMultipleD_Xdl_CShuffle_v1 ck::Tuple, ck::Tuple, ck::Tuple, - Block2TileMapElementwise, + Block2TileMapInOutElementwise, element_wise::PassThrough, I1, I1>; @@ -1077,7 +1186,9 @@ struct DeviceGroupedConvBwdDataMultipleD_Xdl_CShuffle_v1 // vector load for B matrix from global memory to LDS if constexpr(is_same_v || - is_same_v) + is_same_v || + is_same_v || + is_same_v) { if(!(BBlockTransferSrcVectorDim == 1 && ConvC % BBlockTransferSrcScalarPerVector == 0)) { @@ -1152,8 +1263,8 @@ struct DeviceGroupedConvBwdDataMultipleD_Xdl_CShuffle_v1 } } - if constexpr(is_NGCHW_GKYXC_NGKHW() || - is_NGCDHW_GKZYXC_NGKDHW()) + if constexpr(is_NGCHW_NGKHW() || + is_NGCDHW_NGKDHW()) { if((ConvG * ConvC) % CDEBlockTransferScalarPerVector_NPerBlock != 0) { @@ -1320,8 +1431,8 @@ struct DeviceGroupedConvBwdDataMultipleD_Xdl_CShuffle_v1 << CShuffleMXdlPerWavePerShuffle << ", " << CShuffleNXdlPerWavePerShuffle; - if constexpr(is_NGCHW_GKYXC_NGKHW() || - is_NGCDHW_GKZYXC_NGKDHW()) { + if constexpr(is_NGCHW_NGKHW() || + is_NGCDHW_NGKDHW()) { str << ", TransposeTransferInScalarPerVectorAligned: " << TransposeTransferInScalarPerVectorAligned <<", " << "TransposeTransferOutScalarPerVectorAligned: " << TransposeTransferOutScalarPerVectorAligned; diff --git a/include/ck/tensor_operation/gpu/grid/gridwise_elementwise_2d.hpp b/include/ck/tensor_operation/gpu/grid/gridwise_elementwise_2d.hpp index 0edfc9b0ee..1326c5d62d 100644 --- a/include/ck/tensor_operation/gpu/grid/gridwise_elementwise_2d.hpp +++ b/include/ck/tensor_operation/gpu/grid/gridwise_elementwise_2d.hpp @@ -93,6 +93,119 @@ __global__ void } } +template +__global__ void +#if 
CK_USE_LAUNCH_BOUNDS + __launch_bounds__(CK_MAX_THREAD_PER_BLOCK, CK_MIN_BLOCK_PER_CU) +#endif + kernel_elementwise_batched_dual( + const InAGridDescTuple in_grid_desc_tuple_a, + const InBGridDescTuple in_grid_desc_tuple_b, + const OutAGridDescTuple out_grid_desc_tuple_a, + const OutBGridDescTuple out_grid_desc_tuple_b, + const InADataTypePointerTuple p_in_global_tuple_a, + const InBDataTypePointerTuple p_in_global_tuple_b, + const OutADataTypePointerTuple p_out_global_tuple_a, + const OutBDataTypePointerTuple p_out_global_tuple_b, + const Block2TileMapA block_2_tile_map_a, + const Block2TileMapB block_2_tile_map_b, + const ElementwiseOperation elementwise_op, + const index_t a_grid_size, + const index_t batch_count_a, + const index_t batch_count_b, + const std::array input_batch_strides_a, + const std::array input_batch_strides_b, + const std::array output_batch_strides_a, + const std::array output_batch_strides_b) +{ + static_assert(InAGridDescTuple::Size() == NumInputsA && + InADataTypePointerTuple::Size() == NumInputsA); + static_assert(OutAGridDescTuple::Size() == NumOutputsA && + OutADataTypePointerTuple::Size() == NumOutputsA); + static_assert(InBGridDescTuple::Size() == NumInputsB && + InBDataTypePointerTuple::Size() == NumInputsB); + static_assert(OutBGridDescTuple::Size() == NumOutputsB && + OutBDataTypePointerTuple::Size() == NumOutputsB); + + const index_t block_id = __builtin_amdgcn_readfirstlane(get_block_1d_id()); + + if(block_id < a_grid_size) + { + const index_t num_blocks_per_batch = + __builtin_amdgcn_readfirstlane(a_grid_size / batch_count_a); + const index_t g_idx = __builtin_amdgcn_readfirstlane(block_id / num_blocks_per_batch); + + InADataTypePointerTuple p_in_global_with_offset_tuple; + OutADataTypePointerTuple p_out_global_with_offset_tuple; + + static_for<0, InADataTypePointerTuple::Size(), 1>{}([&](auto i) { + p_in_global_with_offset_tuple(i) = + p_in_global_tuple_a.At(i) + + type_convert(input_batch_strides_a[i]) * g_idx; + }); + + 
static_for<0, OutADataTypePointerTuple::Size(), 1>{}([&](auto i) { + p_out_global_with_offset_tuple(i) = + p_out_global_tuple_a.At(i) + + type_convert(output_batch_strides_a[i]) * g_idx; + }); + + GridwiseElementwiseFunctorA::Run(in_grid_desc_tuple_a, + out_grid_desc_tuple_a, + p_in_global_with_offset_tuple, + p_out_global_with_offset_tuple, + block_2_tile_map_a, + elementwise_op, + block_id); + } + else + { + const index_t num_blocks_per_batch = + __builtin_amdgcn_readfirstlane((get_grid_size() - a_grid_size) / batch_count_b); + const index_t g_idx = + __builtin_amdgcn_readfirstlane((block_id - a_grid_size) / num_blocks_per_batch); + + InBDataTypePointerTuple p_in_global_with_offset_tuple; + OutBDataTypePointerTuple p_out_global_with_offset_tuple; + + static_for<0, InBDataTypePointerTuple::Size(), 1>{}([&](auto i) { + p_in_global_with_offset_tuple(i) = + p_in_global_tuple_b.At(i) + + type_convert(input_batch_strides_b[i]) * g_idx; + }); + + static_for<0, OutBDataTypePointerTuple::Size(), 1>{}([&](auto i) { + p_out_global_with_offset_tuple(i) = + p_out_global_tuple_b.At(i) + + type_convert(output_batch_strides_b[i]) * g_idx; + }); + + GridwiseElementwiseFunctorB::Run(in_grid_desc_tuple_b, + out_grid_desc_tuple_b, + p_in_global_with_offset_tuple, + p_out_global_with_offset_tuple, + block_2_tile_map_b, + elementwise_op, + block_id - a_grid_size); + } +} + template +using device_grouped_conv_bwd_data_xdl_f16_generic_instances = + std::tuple< + // clang-format off + // ##############################################| NDim| ALayout| BLayout| DsLayout| ELayout| AData| BData| AccData| CShuffle| DsData| EData| AElementwise| BElementwise| CDEElementwise| ConvolutionBackward| DoPad| DoPad| NumGemmK| Block| MPer| NPer| KPer| AK1| BK1| MPer| NPer| MXdl| NXdl| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockLds| BBlockTransfer| BBlockTransfer| BBlockTransfer| BBlockTransfer| BBlockTransfer| BBlockTransfer| BBlockLds| 
CShuffleMXdl| CShuffleNXdl| CDEBlockTransfer| CDEBlockTransfer| + // ##############################################| Spatial| | | | | Type| Type| Type| DataType| Type| Type| Operation| Operation| Operation| DataSpecialization| GemmM| GemmN| PrefetchStage| Size| Block| Block| Block| | | XDL| XDL| PerWave| PerWave| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| ExtraM| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| ExtraN| PerWave| PerWave| _MBlock_MPerBlock| ScalarPerVector| + // ##############################################| | | | | | | | | | | | | | | | | | | | | | | | | | | | | Lengths_AK0_M_AK1| ArrangeOrder| | | PerVector| PerVector_AK1| | Lengths_BK0_N_BK1| ArrangeOrder| | | PerVector| PerVector_BK1| | PerShuffle| PerShuffle| _NBlock_NPerBlock| _NPerBlock| + // ##############################################| | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | + // generic instance + DeviceGroupedConvBwdDataMultipleD_Xdl_CShuffle_v1, S<1, 0, 2>, S<1, 0, 2>, 2, 1, 8, 1, S<4, 8, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 1, 8, 1, 1, 1, S<1, 16, 1, 4>, 1> + // clang-format on + >; + template ; // bf16_bf16_f32_bf16 +template +using device_grouped_conv_bwd_data_xdl_bf16_generic_instances = std::tuple< + // clang-format off + // ##############################################| NDim| ALayout| BLayout| DsLayout| ELayout| AData| BData| AccData| CShuffle| DsData| EData| AElementwise| BElementwise| CDEElementwise| ConvolutionBackward| DoPad| DoPad| NumGemmK| Block| MPer| NPer| KPer| AK1| BK1| MPer| NPer| MXdl| NXdl| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockLds| BBlockTransfer| BBlockTransfer| BBlockTransfer| BBlockTransfer| BBlockTransfer| BBlockTransfer| BBlockLds| CShuffleMXdl| CShuffleNXdl| CDEBlockTransfer| CDEBlockTransfer| + // ##############################################| Spatial| | | | | Type| Type| Type| 
DataType| Type| Type| Operation| Operation| Operation| DataSpecialization| GemmM| GemmN| PrefetchStage| Size| Block| Block| Block| | | XDL| XDL| PerWave| PerWave| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| ExtraM| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| ExtraN| PerWave| PerWave| _MBlock_MPerBlock| ScalarPerVector| + // ##############################################| | | | | | | | | | | | | | | | | | | | | | | | | | | | | Lengths_AK0_M_AK1| ArrangeOrder| | | PerVector| PerVector_AK1| | Lengths_BK0_N_BK1| ArrangeOrder| | | PerVector| PerVector_BK1| | PerShuffle| PerShuffle| _NBlock_NPerBlock| _NPerBlock| + // ##############################################| | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | + // generic instance + DeviceGroupedConvBwdDataMultipleD_Xdl_CShuffle_v1, S<1, 0, 2>, S<1, 0, 2>, 2, 1, 8, 1, S<4, 8, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 1, 8, 1, 1, 1, S<1, 16, 1, 4>, 1> + // clang-format on + >; + template ; // f32_f32_f32_f32 +template +using device_grouped_conv_bwd_data_xdl_f32_generic_instances = + std::tuple< + // clang-format off + // ##############################################| NDim| ALayout| BLayout| DsLayout| ELayout| AData| BData| AccData| CShuffle| DsData| EData| AElementwise| BElementwise| CDEElementwise| ConvolutionBackward| DoPad| DoPad| NumGemmK| Block| MPer| NPer| KPer| AK1| BK1| MPer| NPer| MXdl| NXdl| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockLds| BBlockTransfer| BBlockTransfer| BBlockTransfer| BBlockTransfer| BBlockTransfer| BBlockTransfer| BBlockLds| CShuffleMXdl| CShuffleNXdl| CDEBlockTransfer| CDEBlockTransfer| + // ##############################################| Spatial| | | | | Type| Type| Type| DataType| Type| Type| Operation| Operation| Operation| DataSpecialization| GemmM| GemmN| PrefetchStage| Size| Block| Block| Block| | | XDL| XDL| PerWave| 
PerWave| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| ExtraM| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| ExtraN| PerWave| PerWave| _MBlock_MPerBlock| ScalarPerVector| + // ##############################################| | | | | | | | | | | | | | | | | | | | | | | | | | | | | Lengths_AK0_M_AK1| ArrangeOrder| | | PerVector| PerVector_AK1| | Lengths_BK0_N_BK1| ArrangeOrder| | | PerVector| PerVector_BK1| | PerShuffle| PerShuffle| _NBlock_NPerBlock| _NPerBlock| + // ##############################################| | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | + // generic instance + DeviceGroupedConvBwdDataMultipleD_Xdl_CShuffle_v1< NDimSpatial, ALayout, BLayout, DsLayout, ELayout, F32, F32, F32, F32, Empty_Tuple, F32, PassThrough, PassThrough, PassThrough, ConvSpec, true, true, 1, 64, 64, 64, 32, 8, 8, 32, 32, 2, 2, S<4, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 1, 4, 1, S<4, 16, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 1, 4, 1, 1, 1, S<1, 16, 1, 4>, 1> + // clang-format on + >; + template +using device_grouped_conv_fwd_xdl_bf16_generic_instances = std::tuple< + // clang-format off + //########################################| NumDim| A| B| Ds| E| AData| BData| AccData| CShuffle| Ds| EData| A| B| CDE| ConvForward| GEMM| NumGemmK| Block| MPer| NPer| KPer| AK1| BK1| MPer| NPer| MXdl| NXdl| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockLds| BBlockTransfer| BBlockTransfer| BBlockTransfer| BlockTransfer| BBlockTransfer| BBlockTransfer| BBlockLds| CShuffle| CShuffle| CBlockTransferClusterLengths| CBlockTransfer| + //########################################| Spatial| Layout| Layout| Layout| Layout| Type| Type| Type| DataType| DataType| Type| Elementwise| Elementwise| Elementwise| Specialization| Specialization| Prefetch| Size| Block| Block| Block| | | XDL| XDL| Per| Per| ThreadCluster| ThreadCluster| SrcAccessOrder| 
SrcVectorDim| SrcScalar| DstScalar| AddExtraM| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraN| MXdlPerWave| NXdlPerWave| _MBlock_MWaveMPerXdl| ScalarPerVector| + //########################################| | | | | | | | | | | | Operation| Operation| Operation| | | Stage| | | | | | | | | Wave| Wave| Lengths_K0_M_K1| ArrangeOrder| | | PerVector| PerVector_K1| | Lengths_K0_N_K1| ArrangeOrder| | | PerVector| PerVector_K1| | PerShuffle| PerShuffle| _NBlock_NWaveNPerXdl| _NWaveNPerXdl| + //########################################| | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | + // generic instance + DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle, S<1, 0, 2>, S<1, 0, 2>, 2, 1, 8, 1, S<4, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 1, 8, 1, 1, 1, S<1, 16, 1, 4>, 1> + // clang-format on + >; + template ; +template +using device_grouped_conv_fwd_xdl_f16_generic_instances = std::tuple< + // clang-format off + //########################################| NumDim| A| B| Ds| E| AData| BData| AccData| CShuffle| Ds| EData| A| B| CDE| ConvForward| GEMM| NumGemmK| Block| MPer| NPer| KPer| AK1| BK1| MPer| NPer| MXdl| NXdl| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockLds| BBlockTransfer| BBlockTransfer| BBlockTransfer| BlockTransfer| BBlockTransfer| BBlockTransfer| BBlockLds| CShuffle| CShuffle| CBlockTransferClusterLengths| CBlockTransfer| + //########################################| Spatial| Layout| Layout| Layout| Layout| Type| Type| Type| DataType| DataType| Type| Elementwise| Elementwise| Elementwise| Specialization| Specialization| Prefetch| Size| Block| Block| Block| | | XDL| XDL| Per| Per| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraM| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraN| MXdlPerWave| NXdlPerWave| _MBlock_MWaveMPerXdl| ScalarPerVector| + 
//########################################| | | | | | | | | | | | Operation| Operation| Operation| | | Stage| | | | | | | | | Wave| Wave| Lengths_K0_M_K1| ArrangeOrder| | | PerVector| PerVector_K1| | Lengths_K0_N_K1| ArrangeOrder| | | PerVector| PerVector_K1| | PerShuffle| PerShuffle| _NBlock_NWaveNPerXdl| _NWaveNPerXdl| + //########################################| | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | + // generic instance + DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle, S<1, 0, 2>, S<1, 0, 2>, 2, 1, 8, 1, S<4, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 1, 8, 1, 1, 1, S<1, 16, 1, 4>, 1> + // clang-format on + >; + template ; +template +using device_grouped_conv_fwd_xdl_f32_generic_instances = std::tuple< + // clang-format off + //########################################| NumDim| A| B| Ds| E| AData| BData| AccData| CShuffle| Ds| EData| A| B| CDE| ConvForward| GEMM| NumGemmK| Block| MPer| NPer| KPer| AK1| BK1| MPer| NPer| MXdl| NXdl| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockLds| BBlockTransfer| BBlockTransfer| BBlockTransfer| BlockTransfer| BBlockTransfer| BBlockTransfer| BBlockLds| CShuffle| CShuffle| CBlockTransferClusterLengths| CBlockTransfer| + //########################################| Spatial| Layout| Layout| Layout| Layout| Type| Type| Type| DataType| DataType| Type| Elementwise| Elementwise| Elementwise| Specialization| Specialization| Prefetch| Size| Block| Block| Block| | | XDL| XDL| Per| Per| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraM| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraN| MXdlPerWave| NXdlPerWave| _MBlock_MWaveMPerXdl| ScalarPerVector| + //########################################| | | | | | | | | | | | Operation| Operation| Operation| | | Stage| | | | | | | | | Wave| Wave| Lengths_K0_M_K1| ArrangeOrder| | | PerVector| PerVector_K1| | 
Lengths_K0_N_K1| ArrangeOrder| | | PerVector| PerVector_K1| | PerShuffle| PerShuffle| _NBlock_NWaveNPerXdl| _NWaveNPerXdl| + //########################################| | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | + // generic instance + DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle, S<1, 0, 2>, S<1, 0, 2>, 2, 1, 4, 1, S<4, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 1, 4, 1, 1, 1, S<1, 8, 1, 8>, 1> + // clang-format on + >; + template ; +template +using device_grouped_conv_fwd_xdl_int8_generic_instances = std::tuple< + // clang-format off + //########################################| NumDim| A| B| Ds| E| AData| BData| AccData| CShuffle| Ds| EData| A| B| CDE| ConvForward| GEMM| NumGemmK| Block| MPer| NPer| KPer| AK1| BK1| MPer| NPer| MXdl| NXdl| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockLds| BBlockTransfer| BBlockTransfer| BBlockTransfer| BlockTransfer| BBlockTransfer| BBlockTransfer| BBlockLds| CShuffle| CShuffle| CBlockTransferClusterLengths| CBlockTransfer| + //########################################| Spatial| Layout| Layout| Layout| Layout| Type| Type| Type| DataType| DataType| Type| Elementwise| Elementwise| Elementwise| Specialization| Specialization| Prefetch| Size| Block| Block| Block| | | XDL| XDL| Per| Per| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraM| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraN| MXdlPerWave| NXdlPerWave| _MBlock_MWaveMPerXdl| ScalarPerVector| + //########################################| | | | | | | | | | | | Operation| Operation| Operation| | | Stage| | | | | | | | | Wave| Wave| Lengths_K0_M_K1| ArrangeOrder| | | PerVector| PerVector_K1| | Lengths_K0_N_K1| ArrangeOrder| | | PerVector| PerVector_K1| | PerShuffle| PerShuffle| _NBlock_NWaveNPerXdl| _NWaveNPerXdl| + //########################################| | | | | | | | | | | | | | | | | | | | | | | 
| | | | | | | | | | | | | | | | | | | | | | | + // generic instance + DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle, S<1, 0, 2>, S<1, 0, 2>, 2, 1, 8, 1, S<4, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 1, 8, 1, 1, 1, S<1, 16, 1, 4>, 1> + // clang-format on + >; + template && is_same_v && + is_same_v) + { +#ifdef CK_ENABLE_FP16 + if constexpr(is_same_v && is_same_v && + is_same_v && is_same_v && + is_same_v) + { + add_device_grouped_conv2d_bwd_data_xdl_ngkhw_gkcyx_ngchw_f16_instances(op_ptrs); + add_device_grouped_conv2d_bwd_data_xdl_ngkhw_gkcyx_ngchw_f16_vec_transpose_instances( + op_ptrs); + } +#endif +#ifdef CK_ENABLE_FP32 + if constexpr(is_same_v && is_same_v && + is_same_v && is_same_v && + is_same_v) + { + add_device_grouped_conv2d_bwd_data_xdl_ngkhw_gkcyx_ngchw_f32_instances(op_ptrs); + add_device_grouped_conv2d_bwd_data_xdl_ngkhw_gkcyx_ngchw_f32_vec_transpose_instances( + op_ptrs); + } +#endif +#ifdef CK_ENABLE_BF16 + if constexpr(is_same_v && is_same_v && + is_same_v && is_same_v && + is_same_v) + { + add_device_grouped_conv2d_bwd_data_xdl_ngkhw_gkcyx_ngchw_bf16_instances( + op_ptrs); + add_device_grouped_conv2d_bwd_data_xdl_ngkhw_gkcyx_ngchw_bf16_vec_transpose_instances( + op_ptrs); + } #endif } } @@ -261,6 +296,43 @@ struct DeviceOperationInstanceFactory< add_device_grouped_conv3d_bwd_data_xdl_ngkdhw_gkzyxc_ngcdhw_bf16_instances( op_ptrs); } +#endif + } + if constexpr(is_same_v && is_same_v && + is_same_v) + { +#ifdef CK_ENABLE_FP16 + if constexpr(is_same_v && is_same_v && + is_same_v && is_same_v && + is_same_v) + { + add_device_grouped_conv3d_bwd_data_xdl_ngkdhw_gkczyx_ngcdhw_f16_instances( + op_ptrs); + add_device_grouped_conv3d_bwd_data_xdl_ngkdhw_gkczyx_ngcdhw_f16_vec_transpose_instances( + op_ptrs); + } +#endif +#ifdef CK_ENABLE_FP32 + if constexpr(is_same_v && is_same_v && + is_same_v && is_same_v && + is_same_v) + { + add_device_grouped_conv3d_bwd_data_xdl_ngkdhw_gkczyx_ngcdhw_f32_instances( + op_ptrs); + 
add_device_grouped_conv3d_bwd_data_xdl_ngkdhw_gkczyx_ngcdhw_f32_vec_transpose_instances( + op_ptrs); + } +#endif +#ifdef CK_ENABLE_BF16 + if constexpr(is_same_v && is_same_v && + is_same_v && is_same_v && + is_same_v) + { + add_device_grouped_conv3d_bwd_data_xdl_ngkdhw_gkczyx_ngcdhw_bf16_instances( + op_ptrs); + add_device_grouped_conv3d_bwd_data_xdl_ngkdhw_gkczyx_ngcdhw_bf16_vec_transpose_instances( + op_ptrs); + } #endif } } diff --git a/library/include/ck/library/tensor_operation_instance/gpu/grouped_convolution_backward_data_xdl.inc b/library/include/ck/library/tensor_operation_instance/gpu/grouped_convolution_backward_data_xdl.inc index 6f82117731..5be8f29e99 100644 --- a/library/include/ck/library/tensor_operation_instance/gpu/grouped_convolution_backward_data_xdl.inc +++ b/library/include/ck/library/tensor_operation_instance/gpu/grouped_convolution_backward_data_xdl.inc @@ -147,6 +147,94 @@ void add_device_grouped_conv2d_bwd_data_xdl_ngkhw_gkyxc_ngchw_bf16_instances( PassThrough>>>& instances); #endif +#ifdef CK_ENABLE_FP16 +void add_device_grouped_conv2d_bwd_data_xdl_ngkhw_gkcyx_ngchw_f16_instances( + std::vector>>& instances); + +void add_device_grouped_conv2d_bwd_data_xdl_ngkhw_gkcyx_ngchw_f16_vec_transpose_instances( + std::vector>>& instances); +#endif +#ifdef CK_ENABLE_FP32 +void add_device_grouped_conv2d_bwd_data_xdl_ngkhw_gkcyx_ngchw_f32_instances( + std::vector>>& instances); + +void add_device_grouped_conv2d_bwd_data_xdl_ngkhw_gkcyx_ngchw_f32_vec_transpose_instances( + std::vector>>& instances); +#endif +#ifdef CK_ENABLE_BF16 +void add_device_grouped_conv2d_bwd_data_xdl_ngkhw_gkcyx_ngchw_bf16_instances( + std::vector>>& instances); + +void add_device_grouped_conv2d_bwd_data_xdl_ngkhw_gkcyx_ngchw_bf16_vec_transpose_instances( + std::vector>>& instances); +#endif + // conv3d backward data #ifdef CK_ENABLE_FP16 void add_device_grouped_conv3d_bwd_data_xdl_gndhwk_gkzyxc_gndhwc_f16_instances( @@ -300,6 +388,93 @@ void 
add_device_grouped_conv3d_bwd_data_xdl_ngkdhw_gkzyxc_ngcdhw_bf16_instances( PassThrough, PassThrough>>>& instances); #endif +#ifdef CK_ENABLE_FP16 +void add_device_grouped_conv3d_bwd_data_xdl_ngkdhw_gkczyx_ngcdhw_f16_instances( + std::vector>>& instances); + +void add_device_grouped_conv3d_bwd_data_xdl_ngkdhw_gkczyx_ngcdhw_f16_vec_transpose_instances( + std::vector>>& instances); +#endif +#ifdef CK_ENABLE_FP32 +void add_device_grouped_conv3d_bwd_data_xdl_ngkdhw_gkczyx_ngcdhw_f32_instances( + std::vector>>& instances); + +void add_device_grouped_conv3d_bwd_data_xdl_ngkdhw_gkczyx_ngcdhw_f32_vec_transpose_instances( + std::vector>>& instances); +#endif +#ifdef CK_ENABLE_BF16 +void add_device_grouped_conv3d_bwd_data_xdl_ngkdhw_gkczyx_ngcdhw_bf16_instances( + std::vector>>& instances); + +void add_device_grouped_conv3d_bwd_data_xdl_ngkdhw_gkczyx_ngcdhw_bf16_vec_transpose_instances( + std::vector>>& instances); +#endif } // namespace instance } // namespace device diff --git a/library/src/tensor_operation_instance/gpu/grouped_conv2d_bwd_data/CMakeLists.txt b/library/src/tensor_operation_instance/gpu/grouped_conv2d_bwd_data/CMakeLists.txt index 50b724206e..913ebd3a12 100644 --- a/library/src/tensor_operation_instance/gpu/grouped_conv2d_bwd_data/CMakeLists.txt +++ b/library/src/tensor_operation_instance/gpu/grouped_conv2d_bwd_data/CMakeLists.txt @@ -10,6 +10,12 @@ add_instance_library( xdl/device_grouped_conv2d_bwd_data_xdl_ngchw_gkyxc_ngkhw_f16_instance.cpp xdl/device_grouped_conv2d_bwd_data_xdl_ngchw_gkyxc_ngkhw_bf16_instance.cpp xdl/device_grouped_conv2d_bwd_data_xdl_ngchw_gkyxc_ngkhw_f32_instance.cpp + xdl/device_grouped_conv2d_bwd_data_xdl_ngchw_gkcyx_ngkhw_f16_instance.cpp + xdl/device_grouped_conv2d_bwd_data_xdl_ngchw_gkcyx_ngkhw_bf16_instance.cpp + xdl/device_grouped_conv2d_bwd_data_xdl_ngchw_gkcyx_ngkhw_f32_instance.cpp + xdl/device_grouped_conv2d_bwd_data_xdl_ngchw_gkcyx_ngkhw_f16_vec_transpose_instance.cpp + 
xdl/device_grouped_conv2d_bwd_data_xdl_ngchw_gkcyx_ngkhw_bf16_vec_transpose_instance.cpp + xdl/device_grouped_conv2d_bwd_data_xdl_ngchw_gkcyx_ngkhw_f32_vec_transpose_instance.cpp wmma/device_grouped_conv2d_bwd_data_wmma_gnhwc_gkyxc_gnhwk_f16_1x1s1p0_instance.cpp wmma/device_grouped_conv2d_bwd_data_wmma_nhwgc_gkyxc_nhwgk_f16_1x1s1p0_instance.cpp diff --git a/library/src/tensor_operation_instance/gpu/grouped_conv2d_bwd_data/xdl/device_grouped_conv2d_bwd_data_xdl_ngchw_gkcyx_ngkhw_bf16_instance.cpp b/library/src/tensor_operation_instance/gpu/grouped_conv2d_bwd_data/xdl/device_grouped_conv2d_bwd_data_xdl_ngchw_gkcyx_ngkhw_bf16_instance.cpp new file mode 100644 index 0000000000..23aeeaf505 --- /dev/null +++ b/library/src/tensor_operation_instance/gpu/grouped_conv2d_bwd_data/xdl/device_grouped_conv2d_bwd_data_xdl_ngchw_gkcyx_ngkhw_bf16_instance.cpp @@ -0,0 +1,40 @@ +// SPDX-License-Identifier: MIT +// Copyright (c) 2025, Advanced Micro Devices, Inc. All rights reserved. + +#include "ck/library/tensor_operation_instance/add_device_operation_instance.hpp" +#include "ck/library/tensor_operation_instance/gpu/grouped_conv_bwd_data/device_grouped_conv_bwd_data_xdl_instance.hpp" +#include "ck/library/tensor_operation_instance/gpu/grouped_conv_bwd_data/device_grouped_conv_bwd_data_transpose_xdl_instance.hpp" + +namespace ck { +namespace tensor_operation { +namespace device { +namespace instance { +// Compilation parameters for out[n, ho, wo, g, c] * wei[g, k, y, x, c] = in[n, hi, wi, g, k] +void add_device_grouped_conv2d_bwd_data_xdl_ngkhw_gkcyx_ngchw_bf16_instances( + std::vector>>& instances) +{ + add_device_operation_instances( + instances, + device_grouped_conv_bwd_data_xdl_bf16_instances<2, + NGKHW, + GKCYX, + Empty_Tuple, + NGCHW, + ConvBwdDataDefault>{}); +} + +} // namespace instance +} // namespace device +} // namespace tensor_operation +} // namespace ck diff --git 
a/library/src/tensor_operation_instance/gpu/grouped_conv2d_bwd_data/xdl/device_grouped_conv2d_bwd_data_xdl_ngchw_gkcyx_ngkhw_bf16_vec_transpose_instance.cpp b/library/src/tensor_operation_instance/gpu/grouped_conv2d_bwd_data/xdl/device_grouped_conv2d_bwd_data_xdl_ngchw_gkcyx_ngkhw_bf16_vec_transpose_instance.cpp new file mode 100644 index 0000000000..b6e4c170df --- /dev/null +++ b/library/src/tensor_operation_instance/gpu/grouped_conv2d_bwd_data/xdl/device_grouped_conv2d_bwd_data_xdl_ngchw_gkcyx_ngkhw_bf16_vec_transpose_instance.cpp @@ -0,0 +1,40 @@ +// SPDX-License-Identifier: MIT +// Copyright (c) 2025, Advanced Micro Devices, Inc. All rights reserved. + +#include "ck/library/tensor_operation_instance/add_device_operation_instance.hpp" +#include "ck/library/tensor_operation_instance/gpu/grouped_conv_bwd_data/device_grouped_conv_bwd_data_xdl_instance.hpp" +#include "ck/library/tensor_operation_instance/gpu/grouped_conv_bwd_data/device_grouped_conv_bwd_data_transpose_xdl_instance.hpp" + +namespace ck { +namespace tensor_operation { +namespace device { +namespace instance { +// Compilation parameters for out[n, ho, wo, g, c] * wei[g, k, y, x, c] = in[n, hi, wi, g, k] +void add_device_grouped_conv2d_bwd_data_xdl_ngkhw_gkcyx_ngchw_bf16_vec_transpose_instances( + std::vector>>& instances) +{ + add_device_operation_instances( + instances, + device_grouped_conv_bwd_data_transpose_xdl_bf16_instances<2, + NGKHW, + GKCYX, + Empty_Tuple, + NGCHW, + ConvBwdDataDefault>{}); +} + +} // namespace instance +} // namespace device +} // namespace tensor_operation +} // namespace ck diff --git a/library/src/tensor_operation_instance/gpu/grouped_conv2d_bwd_data/xdl/device_grouped_conv2d_bwd_data_xdl_ngchw_gkcyx_ngkhw_f16_instance.cpp b/library/src/tensor_operation_instance/gpu/grouped_conv2d_bwd_data/xdl/device_grouped_conv2d_bwd_data_xdl_ngchw_gkcyx_ngkhw_f16_instance.cpp new file mode 100644 index 0000000000..beeda26690 --- /dev/null +++ 
b/library/src/tensor_operation_instance/gpu/grouped_conv2d_bwd_data/xdl/device_grouped_conv2d_bwd_data_xdl_ngchw_gkcyx_ngkhw_f16_instance.cpp @@ -0,0 +1,40 @@ +// SPDX-License-Identifier: MIT +// Copyright (c) 2025, Advanced Micro Devices, Inc. All rights reserved. + +#include "ck/library/tensor_operation_instance/add_device_operation_instance.hpp" +#include "ck/library/tensor_operation_instance/gpu/grouped_conv_bwd_data/device_grouped_conv_bwd_data_xdl_instance.hpp" +#include "ck/library/tensor_operation_instance/gpu/grouped_conv_bwd_data/device_grouped_conv_bwd_data_transpose_xdl_instance.hpp" + +namespace ck { +namespace tensor_operation { +namespace device { +namespace instance { +// Compilation parameters for out[n, ho, wo, g, c] * wei[g, k, y, x, c] = in[n, hi, wi, g, k] +void add_device_grouped_conv2d_bwd_data_xdl_ngkhw_gkcyx_ngchw_f16_instances( + std::vector>>& instances) +{ + add_device_operation_instances( + instances, + device_grouped_conv_bwd_data_xdl_f16_instances<2, + NGKHW, + GKCYX, + Empty_Tuple, + NGCHW, + ConvBwdDataDefault>{}); +} + +} // namespace instance +} // namespace device +} // namespace tensor_operation +} // namespace ck diff --git a/library/src/tensor_operation_instance/gpu/grouped_conv2d_bwd_data/xdl/device_grouped_conv2d_bwd_data_xdl_ngchw_gkcyx_ngkhw_f16_vec_transpose_instance.cpp b/library/src/tensor_operation_instance/gpu/grouped_conv2d_bwd_data/xdl/device_grouped_conv2d_bwd_data_xdl_ngchw_gkcyx_ngkhw_f16_vec_transpose_instance.cpp new file mode 100644 index 0000000000..234fd53c8c --- /dev/null +++ b/library/src/tensor_operation_instance/gpu/grouped_conv2d_bwd_data/xdl/device_grouped_conv2d_bwd_data_xdl_ngchw_gkcyx_ngkhw_f16_vec_transpose_instance.cpp @@ -0,0 +1,40 @@ +// SPDX-License-Identifier: MIT +// Copyright (c) 2025, Advanced Micro Devices, Inc. All rights reserved. 
+ +#include "ck/library/tensor_operation_instance/add_device_operation_instance.hpp" +#include "ck/library/tensor_operation_instance/gpu/grouped_conv_bwd_data/device_grouped_conv_bwd_data_xdl_instance.hpp" +#include "ck/library/tensor_operation_instance/gpu/grouped_conv_bwd_data/device_grouped_conv_bwd_data_transpose_xdl_instance.hpp" + +namespace ck { +namespace tensor_operation { +namespace device { +namespace instance { +// Compilation parameters for out[n, ho, wo, g, c] * wei[g, k, y, x, c] = in[n, hi, wi, g, k] +void add_device_grouped_conv2d_bwd_data_xdl_ngkhw_gkcyx_ngchw_f16_vec_transpose_instances( + std::vector>>& instances) +{ + add_device_operation_instances( + instances, + device_grouped_conv_bwd_data_transpose_xdl_f16_instances<2, + NGKHW, + GKCYX, + Empty_Tuple, + NGCHW, + ConvBwdDataDefault>{}); +} + +} // namespace instance +} // namespace device +} // namespace tensor_operation +} // namespace ck diff --git a/library/src/tensor_operation_instance/gpu/grouped_conv2d_bwd_data/xdl/device_grouped_conv2d_bwd_data_xdl_ngchw_gkcyx_ngkhw_f32_instance.cpp b/library/src/tensor_operation_instance/gpu/grouped_conv2d_bwd_data/xdl/device_grouped_conv2d_bwd_data_xdl_ngchw_gkcyx_ngkhw_f32_instance.cpp new file mode 100644 index 0000000000..a1d768f4eb --- /dev/null +++ b/library/src/tensor_operation_instance/gpu/grouped_conv2d_bwd_data/xdl/device_grouped_conv2d_bwd_data_xdl_ngchw_gkcyx_ngkhw_f32_instance.cpp @@ -0,0 +1,40 @@ +// SPDX-License-Identifier: MIT +// Copyright (c) 2025, Advanced Micro Devices, Inc. All rights reserved. 
+ +#include "ck/library/tensor_operation_instance/add_device_operation_instance.hpp" +#include "ck/library/tensor_operation_instance/gpu/grouped_conv_bwd_data/device_grouped_conv_bwd_data_xdl_instance.hpp" +#include "ck/library/tensor_operation_instance/gpu/grouped_conv_bwd_data/device_grouped_conv_bwd_data_transpose_xdl_instance.hpp" + +namespace ck { +namespace tensor_operation { +namespace device { +namespace instance { +// Compilation parameters for out[n, ho, wo, g, c] * wei[g, k, y, x, c] = in[n, hi, wi, g, k] +void add_device_grouped_conv2d_bwd_data_xdl_ngkhw_gkcyx_ngchw_f32_instances( + std::vector>>& instances) +{ + add_device_operation_instances( + instances, + device_grouped_conv_bwd_data_xdl_f32_instances<2, + NGKHW, + GKCYX, + Empty_Tuple, + NGCHW, + ConvBwdDataDefault>{}); +} + +} // namespace instance +} // namespace device +} // namespace tensor_operation +} // namespace ck diff --git a/library/src/tensor_operation_instance/gpu/grouped_conv2d_bwd_data/xdl/device_grouped_conv2d_bwd_data_xdl_ngchw_gkcyx_ngkhw_f32_vec_transpose_instance.cpp b/library/src/tensor_operation_instance/gpu/grouped_conv2d_bwd_data/xdl/device_grouped_conv2d_bwd_data_xdl_ngchw_gkcyx_ngkhw_f32_vec_transpose_instance.cpp new file mode 100644 index 0000000000..3a8b22924a --- /dev/null +++ b/library/src/tensor_operation_instance/gpu/grouped_conv2d_bwd_data/xdl/device_grouped_conv2d_bwd_data_xdl_ngchw_gkcyx_ngkhw_f32_vec_transpose_instance.cpp @@ -0,0 +1,40 @@ +// SPDX-License-Identifier: MIT +// Copyright (c) 2025, Advanced Micro Devices, Inc. All rights reserved. 
+ +#include "ck/library/tensor_operation_instance/add_device_operation_instance.hpp" +#include "ck/library/tensor_operation_instance/gpu/grouped_conv_bwd_data/device_grouped_conv_bwd_data_xdl_instance.hpp" +#include "ck/library/tensor_operation_instance/gpu/grouped_conv_bwd_data/device_grouped_conv_bwd_data_transpose_xdl_instance.hpp" + +namespace ck { +namespace tensor_operation { +namespace device { +namespace instance { +// Compilation parameters for out[n, ho, wo, g, c] * wei[g, k, y, x, c] = in[n, hi, wi, g, k] +void add_device_grouped_conv2d_bwd_data_xdl_ngkhw_gkcyx_ngchw_f32_vec_transpose_instances( + std::vector>>& instances) +{ + add_device_operation_instances( + instances, + device_grouped_conv_bwd_data_transpose_xdl_f32_instances<2, + NGKHW, + GKCYX, + Empty_Tuple, + NGCHW, + ConvBwdDataDefault>{}); +} + +} // namespace instance +} // namespace device +} // namespace tensor_operation +} // namespace ck diff --git a/library/src/tensor_operation_instance/gpu/grouped_conv2d_bwd_data/xdl/device_grouped_conv2d_bwd_data_xdl_ngchw_gkyxc_ngkhw_bf16_instance.cpp b/library/src/tensor_operation_instance/gpu/grouped_conv2d_bwd_data/xdl/device_grouped_conv2d_bwd_data_xdl_ngchw_gkyxc_ngkhw_bf16_instance.cpp index 974615c434..38c3ebc67b 100644 --- a/library/src/tensor_operation_instance/gpu/grouped_conv2d_bwd_data/xdl/device_grouped_conv2d_bwd_data_xdl_ngchw_gkyxc_ngkhw_bf16_instance.cpp +++ b/library/src/tensor_operation_instance/gpu/grouped_conv2d_bwd_data/xdl/device_grouped_conv2d_bwd_data_xdl_ngchw_gkyxc_ngkhw_bf16_instance.cpp @@ -9,7 +9,7 @@ namespace ck { namespace tensor_operation { namespace device { namespace instance { -// Compilation parameters for out[n, hi, wi, g, c] * wei[g, k, y, x, c] = in[n, ho, wo, g, k] +// Compilation parameters for out[n, ho, wo, g, c] * wei[g, k, y, x, c] = in[n, hi, wi, g, k] void add_device_grouped_conv2d_bwd_data_xdl_ngkhw_gkyxc_ngchw_bf16_instances( std::vector{}); - add_device_operation_instances( - instances, - 
device_grouped_conv_bwd_data_transpose_xdl_bf16_instances<2, - NGKHW, - GKYXC, - Empty_Tuple, - NGCHW, - ConvBwdDataDefault>{}); + device_grouped_conv_bwd_data_xdl_bf16_generic_instances<2, + NGKHW, + GKYXC, + Empty_Tuple, + NGCHW, + ConvBwdDataDefault>{}); } } // namespace instance diff --git a/library/src/tensor_operation_instance/gpu/grouped_conv2d_bwd_data/xdl/device_grouped_conv2d_bwd_data_xdl_ngchw_gkyxc_ngkhw_f16_instance.cpp b/library/src/tensor_operation_instance/gpu/grouped_conv2d_bwd_data/xdl/device_grouped_conv2d_bwd_data_xdl_ngchw_gkyxc_ngkhw_f16_instance.cpp index 272e5f3cb7..e6f3985935 100644 --- a/library/src/tensor_operation_instance/gpu/grouped_conv2d_bwd_data/xdl/device_grouped_conv2d_bwd_data_xdl_ngchw_gkyxc_ngkhw_f16_instance.cpp +++ b/library/src/tensor_operation_instance/gpu/grouped_conv2d_bwd_data/xdl/device_grouped_conv2d_bwd_data_xdl_ngchw_gkyxc_ngkhw_f16_instance.cpp @@ -9,7 +9,7 @@ namespace ck { namespace tensor_operation { namespace device { namespace instance { -// Compilation parameters for out[n, hi, wi, g, c] * wei[g, k, y, x, c] = in[n, ho, wo, g, k] +// Compilation parameters for out[n, ho, wo, g, c] * wei[g, k, y, x, c] = in[n, hi, wi, g, k] void add_device_grouped_conv2d_bwd_data_xdl_ngkhw_gkyxc_ngchw_f16_instances( std::vector{}); - add_device_operation_instances( - instances, - device_grouped_conv_bwd_data_transpose_xdl_f16_instances<2, - NGKHW, - GKYXC, - Empty_Tuple, - NGCHW, - ConvBwdDataDefault>{}); + device_grouped_conv_bwd_data_xdl_f16_generic_instances<2, + NGKHW, + GKYXC, + Empty_Tuple, + NGCHW, + ConvBwdDataDefault>{}); } } // namespace instance diff --git a/library/src/tensor_operation_instance/gpu/grouped_conv2d_bwd_data/xdl/device_grouped_conv2d_bwd_data_xdl_ngchw_gkyxc_ngkhw_f32_instance.cpp b/library/src/tensor_operation_instance/gpu/grouped_conv2d_bwd_data/xdl/device_grouped_conv2d_bwd_data_xdl_ngchw_gkyxc_ngkhw_f32_instance.cpp index 01cd2c9206..9212c546ca 100644 --- 
a/library/src/tensor_operation_instance/gpu/grouped_conv2d_bwd_data/xdl/device_grouped_conv2d_bwd_data_xdl_ngchw_gkyxc_ngkhw_f32_instance.cpp +++ b/library/src/tensor_operation_instance/gpu/grouped_conv2d_bwd_data/xdl/device_grouped_conv2d_bwd_data_xdl_ngchw_gkyxc_ngkhw_f32_instance.cpp @@ -9,7 +9,7 @@ namespace ck { namespace tensor_operation { namespace device { namespace instance { -// Compilation parameters for out[n, hi, wi, g, c] * wei[g, k, y, x, c] = in[n, ho, wo, g, k] +// Compilation parameters for out[n, ho, wo, g, c] * wei[g, k, y, x, c] = in[n, hi, wi, g, k] void add_device_grouped_conv2d_bwd_data_xdl_ngkhw_gkyxc_ngchw_f32_instances( std::vector{}); - add_device_operation_instances( - instances, - device_grouped_conv_bwd_data_transpose_xdl_f32_instances<2, - NGKHW, - GKYXC, - Empty_Tuple, - NGCHW, - ConvBwdDataDefault>{}); + device_grouped_conv_bwd_data_xdl_f32_generic_instances<2, + NGKHW, + GKYXC, + Empty_Tuple, + NGCHW, + ConvBwdDataDefault>{}); } } // namespace instance diff --git a/library/src/tensor_operation_instance/gpu/grouped_conv2d_bwd_data/xdl/device_grouped_conv2d_bwd_data_xdl_nhwgc_gkyxc_nhwgk_bf16_instance.cpp b/library/src/tensor_operation_instance/gpu/grouped_conv2d_bwd_data/xdl/device_grouped_conv2d_bwd_data_xdl_nhwgc_gkyxc_nhwgk_bf16_instance.cpp index 5d9194798b..75e7f61f8a 100644 --- a/library/src/tensor_operation_instance/gpu/grouped_conv2d_bwd_data/xdl/device_grouped_conv2d_bwd_data_xdl_nhwgc_gkyxc_nhwgk_bf16_instance.cpp +++ b/library/src/tensor_operation_instance/gpu/grouped_conv2d_bwd_data/xdl/device_grouped_conv2d_bwd_data_xdl_nhwgc_gkyxc_nhwgk_bf16_instance.cpp @@ -1,5 +1,5 @@ // SPDX-License-Identifier: MIT -// Copyright (c) 2018-2023, Advanced Micro Devices, Inc. All rights reserved. +// Copyright (c) 2018-2025, Advanced Micro Devices, Inc. All rights reserved. 
#include "ck/library/tensor_operation_instance/add_device_operation_instance.hpp" #include "ck/library/tensor_operation_instance/gpu/grouped_conv_bwd_data/device_grouped_conv_bwd_data_xdl_instance.hpp" @@ -8,7 +8,7 @@ namespace ck { namespace tensor_operation { namespace device { namespace instance { -// Compilation parameters for out[n, hi, wi, g, c] * wei[g, k, y, x, c] = in[n, ho, wo, g, k] +// Compilation parameters for out[n, ho, wo, g, c] * wei[g, k, y, x, c] = in[n, hi, wi, g, k] void add_device_grouped_conv2d_bwd_data_xdl_nhwgk_gkyxc_nhwgc_bf16_instances( std::vector>>& instances) { - add_device_operation_instances(instances, - device_grouped_conv_fwd_xdl_bf16_instances<2, - NGCHW, - GKYXC, - Empty_Tuple, - NGKHW, - ConvFwdDefault>{}); + add_device_operation_instances( + instances, + device_grouped_conv_fwd_xdl_bf16_generic_instances<2, + NGCHW, + GKYXC, + Empty_Tuple, + NGKHW, + ConvFwdDefault>{}); } } // namespace instance diff --git a/library/src/tensor_operation_instance/gpu/grouped_conv2d_fwd/xdl/device_grouped_conv2d_fwd_xdl_ngchw_gkyxc_ngkhw_f16_instance.cpp b/library/src/tensor_operation_instance/gpu/grouped_conv2d_fwd/xdl/device_grouped_conv2d_fwd_xdl_ngchw_gkyxc_ngkhw_f16_instance.cpp index e002058557..78d1747548 100644 --- a/library/src/tensor_operation_instance/gpu/grouped_conv2d_fwd/xdl/device_grouped_conv2d_fwd_xdl_ngchw_gkyxc_ngkhw_f16_instance.cpp +++ b/library/src/tensor_operation_instance/gpu/grouped_conv2d_fwd/xdl/device_grouped_conv2d_fwd_xdl_ngchw_gkyxc_ngkhw_f16_instance.cpp @@ -1,5 +1,5 @@ // SPDX-License-Identifier: MIT -// Copyright (c) 2024, Advanced Micro Devices, Inc. All rights reserved. +// Copyright (c) 2024-2025, Advanced Micro Devices, Inc. All rights reserved. 
#include "ck/library/tensor_operation_instance/add_device_operation_instance.hpp" #include "ck/library/tensor_operation_instance/gpu/grouped_conv_fwd/device_grouped_conv_fwd_xdl_instance.hpp" @@ -23,13 +23,14 @@ void add_device_grouped_conv2d_fwd_xdl_ngchw_gkyxc_ngkhw_f16_instances( PassThrough, PassThrough>>>& instances) { - add_device_operation_instances(instances, - device_grouped_conv_fwd_xdl_f16_instances<2, - NGCHW, - GKYXC, - Empty_Tuple, - NGKHW, - ConvFwdDefault>{}); + add_device_operation_instances( + instances, + device_grouped_conv_fwd_xdl_f16_generic_instances<2, + NGCHW, + GKYXC, + Empty_Tuple, + NGKHW, + ConvFwdDefault>{}); } } // namespace instance diff --git a/library/src/tensor_operation_instance/gpu/grouped_conv2d_fwd/xdl/device_grouped_conv2d_fwd_xdl_ngchw_gkyxc_ngkhw_f32_instance.cpp b/library/src/tensor_operation_instance/gpu/grouped_conv2d_fwd/xdl/device_grouped_conv2d_fwd_xdl_ngchw_gkyxc_ngkhw_f32_instance.cpp index 1033db4972..5c8c3cb8a5 100644 --- a/library/src/tensor_operation_instance/gpu/grouped_conv2d_fwd/xdl/device_grouped_conv2d_fwd_xdl_ngchw_gkyxc_ngkhw_f32_instance.cpp +++ b/library/src/tensor_operation_instance/gpu/grouped_conv2d_fwd/xdl/device_grouped_conv2d_fwd_xdl_ngchw_gkyxc_ngkhw_f32_instance.cpp @@ -1,5 +1,5 @@ // SPDX-License-Identifier: MIT -// Copyright (c) 2024, Advanced Micro Devices, Inc. All rights reserved. +// Copyright (c) 2024-2025, Advanced Micro Devices, Inc. All rights reserved. 
#include "ck/library/tensor_operation_instance/add_device_operation_instance.hpp" #include "ck/library/tensor_operation_instance/gpu/grouped_conv_fwd/device_grouped_conv_fwd_xdl_instance.hpp" @@ -23,13 +23,14 @@ void add_device_grouped_conv2d_fwd_xdl_ngchw_gkyxc_ngkhw_f32_instances( PassThrough, PassThrough>>>& instances) { - add_device_operation_instances(instances, - device_grouped_conv_fwd_xdl_f32_instances<2, - NGCHW, - GKYXC, - Empty_Tuple, - NGKHW, - ConvFwdDefault>{}); + add_device_operation_instances( + instances, + device_grouped_conv_fwd_xdl_f32_generic_instances<2, + NGCHW, + GKYXC, + Empty_Tuple, + NGKHW, + ConvFwdDefault>{}); } } // namespace instance diff --git a/library/src/tensor_operation_instance/gpu/grouped_conv2d_fwd/xdl/device_grouped_conv2d_fwd_xdl_ngchw_gkyxc_ngkhw_int8_instance.cpp b/library/src/tensor_operation_instance/gpu/grouped_conv2d_fwd/xdl/device_grouped_conv2d_fwd_xdl_ngchw_gkyxc_ngkhw_int8_instance.cpp index 65c75fa043..d89c29327c 100644 --- a/library/src/tensor_operation_instance/gpu/grouped_conv2d_fwd/xdl/device_grouped_conv2d_fwd_xdl_ngchw_gkyxc_ngkhw_int8_instance.cpp +++ b/library/src/tensor_operation_instance/gpu/grouped_conv2d_fwd/xdl/device_grouped_conv2d_fwd_xdl_ngchw_gkyxc_ngkhw_int8_instance.cpp @@ -1,5 +1,5 @@ // SPDX-License-Identifier: MIT -// Copyright (c) 2024, Advanced Micro Devices, Inc. All rights reserved. +// Copyright (c) 2024-2025, Advanced Micro Devices, Inc. All rights reserved. 
#include "ck/library/tensor_operation_instance/add_device_operation_instance.hpp" #include "ck/library/tensor_operation_instance/gpu/grouped_conv_fwd/device_grouped_conv_fwd_xdl_instance.hpp" @@ -23,13 +23,14 @@ void add_device_grouped_conv2d_fwd_xdl_ngchw_gkyxc_ngkhw_int8_instances( PassThrough, PassThrough>>>& instances) { - add_device_operation_instances(instances, - device_grouped_conv_fwd_xdl_int8_instances<2, - NGCHW, - GKYXC, - Empty_Tuple, - NGKHW, - ConvFwdDefault>{}); + add_device_operation_instances( + instances, + device_grouped_conv_fwd_xdl_int8_generic_instances<2, + NGCHW, + GKYXC, + Empty_Tuple, + NGKHW, + ConvFwdDefault>{}); } } // namespace instance diff --git a/library/src/tensor_operation_instance/gpu/grouped_conv3d_bwd_data/CMakeLists.txt b/library/src/tensor_operation_instance/gpu/grouped_conv3d_bwd_data/CMakeLists.txt index 4ab7335f7d..a656c79289 100644 --- a/library/src/tensor_operation_instance/gpu/grouped_conv3d_bwd_data/CMakeLists.txt +++ b/library/src/tensor_operation_instance/gpu/grouped_conv3d_bwd_data/CMakeLists.txt @@ -9,6 +9,12 @@ set(GROUPED_CONV3D_BWD_DATA xdl/device_grouped_conv3d_bwd_data_xdl_ngcdhw_gkzyxc_ngkdhw_f16_instance.cpp xdl/device_grouped_conv3d_bwd_data_xdl_ngcdhw_gkzyxc_ngkdhw_bf16_instance.cpp xdl/device_grouped_conv3d_bwd_data_xdl_ngcdhw_gkzyxc_ngkdhw_f32_instance.cpp + xdl/device_grouped_conv3d_bwd_data_xdl_ngcdhw_gkczyx_ngkdhw_f16_instance.cpp + xdl/device_grouped_conv3d_bwd_data_xdl_ngcdhw_gkczyx_ngkdhw_bf16_instance.cpp + xdl/device_grouped_conv3d_bwd_data_xdl_ngcdhw_gkczyx_ngkdhw_f32_instance.cpp + xdl/device_grouped_conv3d_bwd_data_xdl_ngcdhw_gkczyx_ngkdhw_f16_vec_transpose_instance.cpp + xdl/device_grouped_conv3d_bwd_data_xdl_ngcdhw_gkczyx_ngkdhw_bf16_vec_transpose_instance.cpp + xdl/device_grouped_conv3d_bwd_data_xdl_ngcdhw_gkczyx_ngkdhw_f32_vec_transpose_instance.cpp wmma/device_grouped_conv3d_bwd_data_wmma_gndhwc_gkzyxc_gndhwk_f16_instance.cpp 
wmma/device_grouped_conv3d_bwd_data_wmma_ndhwgc_gkzyxc_ndhwgk_f16_instance.cpp wmma/device_grouped_conv3d_bwd_data_wmma_gndhwc_gkzyxc_gndhwk_i8_instance.cpp diff --git a/library/src/tensor_operation_instance/gpu/grouped_conv3d_bwd_data/xdl/device_grouped_conv3d_bwd_data_xdl_ngcdhw_gkczyx_ngkdhw_bf16_instance.cpp b/library/src/tensor_operation_instance/gpu/grouped_conv3d_bwd_data/xdl/device_grouped_conv3d_bwd_data_xdl_ngcdhw_gkczyx_ngkdhw_bf16_instance.cpp new file mode 100644 index 0000000000..a9a6b4d281 --- /dev/null +++ b/library/src/tensor_operation_instance/gpu/grouped_conv3d_bwd_data/xdl/device_grouped_conv3d_bwd_data_xdl_ngcdhw_gkczyx_ngkdhw_bf16_instance.cpp @@ -0,0 +1,41 @@ +// SPDX-License-Identifier: MIT +// Copyright (c) 2025, Advanced Micro Devices, Inc. All rights reserved. + +#include "ck/library/tensor_operation_instance/add_device_operation_instance.hpp" +#include "ck/library/tensor_operation_instance/gpu/grouped_conv_bwd_data/device_grouped_conv_bwd_data_xdl_instance.hpp" +#include "ck/library/tensor_operation_instance/gpu/grouped_conv_bwd_data/device_grouped_conv_bwd_data_transpose_xdl_instance.hpp" + +namespace ck { +namespace tensor_operation { +namespace device { +namespace instance { +// Compilation parameters for out[n, di, hi, wi, g, c] * wei[g, k, z, y, x, c] = in[n, do, ho, wo, +// g, k] +void add_device_grouped_conv3d_bwd_data_xdl_ngkdhw_gkczyx_ngcdhw_bf16_instances( + std::vector>>& instances) +{ + add_device_operation_instances( + instances, + device_grouped_conv_bwd_data_xdl_bf16_instances<3, + NGKDHW, + GKCZYX, + Empty_Tuple, + NGCDHW, + ConvBwdDataDefault>{}); +} + +} // namespace instance +} // namespace device +} // namespace tensor_operation +} // namespace ck diff --git a/library/src/tensor_operation_instance/gpu/grouped_conv3d_bwd_data/xdl/device_grouped_conv3d_bwd_data_xdl_ngcdhw_gkczyx_ngkdhw_bf16_vec_transpose_instance.cpp 
b/library/src/tensor_operation_instance/gpu/grouped_conv3d_bwd_data/xdl/device_grouped_conv3d_bwd_data_xdl_ngcdhw_gkczyx_ngkdhw_bf16_vec_transpose_instance.cpp new file mode 100644 index 0000000000..e0703a60fd --- /dev/null +++ b/library/src/tensor_operation_instance/gpu/grouped_conv3d_bwd_data/xdl/device_grouped_conv3d_bwd_data_xdl_ngcdhw_gkczyx_ngkdhw_bf16_vec_transpose_instance.cpp @@ -0,0 +1,41 @@ +// SPDX-License-Identifier: MIT +// Copyright (c) 2025, Advanced Micro Devices, Inc. All rights reserved. + +#include "ck/library/tensor_operation_instance/add_device_operation_instance.hpp" +#include "ck/library/tensor_operation_instance/gpu/grouped_conv_bwd_data/device_grouped_conv_bwd_data_xdl_instance.hpp" +#include "ck/library/tensor_operation_instance/gpu/grouped_conv_bwd_data/device_grouped_conv_bwd_data_transpose_xdl_instance.hpp" + +namespace ck { +namespace tensor_operation { +namespace device { +namespace instance { +// Compilation parameters for out[n, di, hi, wi, g, c] * wei[g, k, z, y, x, c] = in[n, do, ho, wo, +// g, k] +void add_device_grouped_conv3d_bwd_data_xdl_ngkdhw_gkczyx_ngcdhw_bf16_vec_transpose_instances( + std::vector>>& instances) +{ + add_device_operation_instances( + instances, + device_grouped_conv_bwd_data_transpose_xdl_bf16_instances<3, + NGKDHW, + GKCZYX, + Empty_Tuple, + NGCDHW, + ConvBwdDataDefault>{}); +} + +} // namespace instance +} // namespace device +} // namespace tensor_operation +} // namespace ck diff --git a/library/src/tensor_operation_instance/gpu/grouped_conv3d_bwd_data/xdl/device_grouped_conv3d_bwd_data_xdl_ngcdhw_gkczyx_ngkdhw_f16_instance.cpp b/library/src/tensor_operation_instance/gpu/grouped_conv3d_bwd_data/xdl/device_grouped_conv3d_bwd_data_xdl_ngcdhw_gkczyx_ngkdhw_f16_instance.cpp new file mode 100644 index 0000000000..eec3944078 --- /dev/null +++ b/library/src/tensor_operation_instance/gpu/grouped_conv3d_bwd_data/xdl/device_grouped_conv3d_bwd_data_xdl_ngcdhw_gkczyx_ngkdhw_f16_instance.cpp @@ -0,0 +1,41 @@ +// 
SPDX-License-Identifier: MIT +// Copyright (c) 2025, Advanced Micro Devices, Inc. All rights reserved. + +#include "ck/library/tensor_operation_instance/add_device_operation_instance.hpp" +#include "ck/library/tensor_operation_instance/gpu/grouped_conv_bwd_data/device_grouped_conv_bwd_data_xdl_instance.hpp" +#include "ck/library/tensor_operation_instance/gpu/grouped_conv_bwd_data/device_grouped_conv_bwd_data_transpose_xdl_instance.hpp" + +namespace ck { +namespace tensor_operation { +namespace device { +namespace instance { +// Compilation parameters for out[n, di, hi, wi, g, c] * wei[g, k, z, y, x, c] = in[n, do, ho, wo, +// g, k] +void add_device_grouped_conv3d_bwd_data_xdl_ngkdhw_gkczyx_ngcdhw_f16_instances( + std::vector>>& instances) +{ + add_device_operation_instances( + instances, + device_grouped_conv_bwd_data_xdl_f16_instances<3, + NGKDHW, + GKCZYX, + Empty_Tuple, + NGCDHW, + ConvBwdDataDefault>{}); +} + +} // namespace instance +} // namespace device +} // namespace tensor_operation +} // namespace ck diff --git a/library/src/tensor_operation_instance/gpu/grouped_conv3d_bwd_data/xdl/device_grouped_conv3d_bwd_data_xdl_ngcdhw_gkczyx_ngkdhw_f16_vec_transpose_instance.cpp b/library/src/tensor_operation_instance/gpu/grouped_conv3d_bwd_data/xdl/device_grouped_conv3d_bwd_data_xdl_ngcdhw_gkczyx_ngkdhw_f16_vec_transpose_instance.cpp new file mode 100644 index 0000000000..5bbd7863da --- /dev/null +++ b/library/src/tensor_operation_instance/gpu/grouped_conv3d_bwd_data/xdl/device_grouped_conv3d_bwd_data_xdl_ngcdhw_gkczyx_ngkdhw_f16_vec_transpose_instance.cpp @@ -0,0 +1,41 @@ +// SPDX-License-Identifier: MIT +// Copyright (c) 2025, Advanced Micro Devices, Inc. All rights reserved. 
+ +#include "ck/library/tensor_operation_instance/add_device_operation_instance.hpp" +#include "ck/library/tensor_operation_instance/gpu/grouped_conv_bwd_data/device_grouped_conv_bwd_data_xdl_instance.hpp" +#include "ck/library/tensor_operation_instance/gpu/grouped_conv_bwd_data/device_grouped_conv_bwd_data_transpose_xdl_instance.hpp" + +namespace ck { +namespace tensor_operation { +namespace device { +namespace instance { +// Compilation parameters for out[n, di, hi, wi, g, c] * wei[g, k, z, y, x, c] = in[n, do, ho, wo, +// g, k] +void add_device_grouped_conv3d_bwd_data_xdl_ngkdhw_gkczyx_ngcdhw_f16_vec_transpose_instances( + std::vector>>& instances) +{ + add_device_operation_instances( + instances, + device_grouped_conv_bwd_data_transpose_xdl_f16_instances<3, + NGKDHW, + GKCZYX, + Empty_Tuple, + NGCDHW, + ConvBwdDataDefault>{}); +} + +} // namespace instance +} // namespace device +} // namespace tensor_operation +} // namespace ck diff --git a/library/src/tensor_operation_instance/gpu/grouped_conv3d_bwd_data/xdl/device_grouped_conv3d_bwd_data_xdl_ngcdhw_gkczyx_ngkdhw_f32_instance.cpp b/library/src/tensor_operation_instance/gpu/grouped_conv3d_bwd_data/xdl/device_grouped_conv3d_bwd_data_xdl_ngcdhw_gkczyx_ngkdhw_f32_instance.cpp new file mode 100644 index 0000000000..a596482ca8 --- /dev/null +++ b/library/src/tensor_operation_instance/gpu/grouped_conv3d_bwd_data/xdl/device_grouped_conv3d_bwd_data_xdl_ngcdhw_gkczyx_ngkdhw_f32_instance.cpp @@ -0,0 +1,41 @@ +// SPDX-License-Identifier: MIT +// Copyright (c) 2025, Advanced Micro Devices, Inc. All rights reserved. 
+ +#include "ck/library/tensor_operation_instance/add_device_operation_instance.hpp" +#include "ck/library/tensor_operation_instance/gpu/grouped_conv_bwd_data/device_grouped_conv_bwd_data_xdl_instance.hpp" +#include "ck/library/tensor_operation_instance/gpu/grouped_conv_bwd_data/device_grouped_conv_bwd_data_transpose_xdl_instance.hpp" + +namespace ck { +namespace tensor_operation { +namespace device { +namespace instance { +// Compilation parameters for out[n, di, hi, wi, g, c] * wei[g, k, z, y, x, c] = in[n, do, ho, wo, +// g, k] +void add_device_grouped_conv3d_bwd_data_xdl_ngkdhw_gkczyx_ngcdhw_f32_instances( + std::vector>>& instances) +{ + add_device_operation_instances( + instances, + device_grouped_conv_bwd_data_xdl_f32_instances<3, + NGKDHW, + GKCZYX, + Empty_Tuple, + NGCDHW, + ConvBwdDataDefault>{}); +} + +} // namespace instance +} // namespace device +} // namespace tensor_operation +} // namespace ck diff --git a/library/src/tensor_operation_instance/gpu/grouped_conv3d_bwd_data/xdl/device_grouped_conv3d_bwd_data_xdl_ngcdhw_gkczyx_ngkdhw_f32_vec_transpose_instance.cpp b/library/src/tensor_operation_instance/gpu/grouped_conv3d_bwd_data/xdl/device_grouped_conv3d_bwd_data_xdl_ngcdhw_gkczyx_ngkdhw_f32_vec_transpose_instance.cpp new file mode 100644 index 0000000000..d68062a707 --- /dev/null +++ b/library/src/tensor_operation_instance/gpu/grouped_conv3d_bwd_data/xdl/device_grouped_conv3d_bwd_data_xdl_ngcdhw_gkczyx_ngkdhw_f32_vec_transpose_instance.cpp @@ -0,0 +1,41 @@ +// SPDX-License-Identifier: MIT +// Copyright (c) 2025, Advanced Micro Devices, Inc. All rights reserved. 
+ +#include "ck/library/tensor_operation_instance/add_device_operation_instance.hpp" +#include "ck/library/tensor_operation_instance/gpu/grouped_conv_bwd_data/device_grouped_conv_bwd_data_xdl_instance.hpp" +#include "ck/library/tensor_operation_instance/gpu/grouped_conv_bwd_data/device_grouped_conv_bwd_data_transpose_xdl_instance.hpp" + +namespace ck { +namespace tensor_operation { +namespace device { +namespace instance { +// Compilation parameters for out[n, di, hi, wi, g, c] * wei[g, k, z, y, x, c] = in[n, do, ho, wo, +// g, k] +void add_device_grouped_conv3d_bwd_data_xdl_ngkdhw_gkczyx_ngcdhw_f32_vec_transpose_instances( + std::vector>>& instances) +{ + add_device_operation_instances( + instances, + device_grouped_conv_bwd_data_transpose_xdl_f32_instances<3, + NGKDHW, + GKCZYX, + Empty_Tuple, + NGCDHW, + ConvBwdDataDefault>{}); +} + +} // namespace instance +} // namespace device +} // namespace tensor_operation +} // namespace ck diff --git a/library/src/tensor_operation_instance/gpu/grouped_conv3d_bwd_data/xdl/device_grouped_conv3d_bwd_data_xdl_ngcdhw_gkzyxc_ngkdhw_bf16_instance.cpp b/library/src/tensor_operation_instance/gpu/grouped_conv3d_bwd_data/xdl/device_grouped_conv3d_bwd_data_xdl_ngcdhw_gkzyxc_ngkdhw_bf16_instance.cpp index 88e091568c..b42eca238f 100644 --- a/library/src/tensor_operation_instance/gpu/grouped_conv3d_bwd_data/xdl/device_grouped_conv3d_bwd_data_xdl_ngcdhw_gkzyxc_ngkdhw_bf16_instance.cpp +++ b/library/src/tensor_operation_instance/gpu/grouped_conv3d_bwd_data/xdl/device_grouped_conv3d_bwd_data_xdl_ngcdhw_gkzyxc_ngkdhw_bf16_instance.cpp @@ -27,20 +27,12 @@ void add_device_grouped_conv3d_bwd_data_xdl_ngkdhw_gkzyxc_ngcdhw_bf16_instances( { add_device_operation_instances( instances, - device_grouped_conv_bwd_data_xdl_bf16_instances<3, - NGKDHW, - GKZYXC, - Empty_Tuple, - NGCDHW, - ConvBwdDataDefault>{}); - add_device_operation_instances( - instances, - device_grouped_conv_bwd_data_transpose_xdl_bf16_instances<3, - NGKDHW, - GKZYXC, - 
Empty_Tuple, - NGCDHW, - ConvBwdDataDefault>{}); + device_grouped_conv_bwd_data_xdl_bf16_generic_instances<3, + NGKDHW, + GKZYXC, + Empty_Tuple, + NGCDHW, + ConvBwdDataDefault>{}); } } // namespace instance diff --git a/library/src/tensor_operation_instance/gpu/grouped_conv3d_bwd_data/xdl/device_grouped_conv3d_bwd_data_xdl_ngcdhw_gkzyxc_ngkdhw_f16_instance.cpp b/library/src/tensor_operation_instance/gpu/grouped_conv3d_bwd_data/xdl/device_grouped_conv3d_bwd_data_xdl_ngcdhw_gkzyxc_ngkdhw_f16_instance.cpp index 0378ec13cb..a66965b4a3 100644 --- a/library/src/tensor_operation_instance/gpu/grouped_conv3d_bwd_data/xdl/device_grouped_conv3d_bwd_data_xdl_ngcdhw_gkzyxc_ngkdhw_f16_instance.cpp +++ b/library/src/tensor_operation_instance/gpu/grouped_conv3d_bwd_data/xdl/device_grouped_conv3d_bwd_data_xdl_ngcdhw_gkzyxc_ngkdhw_f16_instance.cpp @@ -27,20 +27,12 @@ void add_device_grouped_conv3d_bwd_data_xdl_ngkdhw_gkzyxc_ngcdhw_f16_instances( { add_device_operation_instances( instances, - device_grouped_conv_bwd_data_xdl_f16_instances<3, - NGKDHW, - GKZYXC, - Empty_Tuple, - NGCDHW, - ConvBwdDataDefault>{}); - add_device_operation_instances( - instances, - device_grouped_conv_bwd_data_transpose_xdl_f16_instances<3, - NGKDHW, - GKZYXC, - Empty_Tuple, - NGCDHW, - ConvBwdDataDefault>{}); + device_grouped_conv_bwd_data_xdl_f16_generic_instances<3, + NGKDHW, + GKZYXC, + Empty_Tuple, + NGCDHW, + ConvBwdDataDefault>{}); } } // namespace instance diff --git a/library/src/tensor_operation_instance/gpu/grouped_conv3d_bwd_data/xdl/device_grouped_conv3d_bwd_data_xdl_ngcdhw_gkzyxc_ngkdhw_f32_instance.cpp b/library/src/tensor_operation_instance/gpu/grouped_conv3d_bwd_data/xdl/device_grouped_conv3d_bwd_data_xdl_ngcdhw_gkzyxc_ngkdhw_f32_instance.cpp index 066fc8a3eb..af21d6dc5d 100644 --- a/library/src/tensor_operation_instance/gpu/grouped_conv3d_bwd_data/xdl/device_grouped_conv3d_bwd_data_xdl_ngcdhw_gkzyxc_ngkdhw_f32_instance.cpp +++ 
b/library/src/tensor_operation_instance/gpu/grouped_conv3d_bwd_data/xdl/device_grouped_conv3d_bwd_data_xdl_ngcdhw_gkzyxc_ngkdhw_f32_instance.cpp @@ -27,20 +27,12 @@ void add_device_grouped_conv3d_bwd_data_xdl_ngkdhw_gkzyxc_ngcdhw_f32_instances( { add_device_operation_instances( instances, - device_grouped_conv_bwd_data_xdl_f32_instances<3, - NGKDHW, - GKZYXC, - Empty_Tuple, - NGCDHW, - ConvBwdDataDefault>{}); - add_device_operation_instances( - instances, - device_grouped_conv_bwd_data_transpose_xdl_f32_instances<3, - NGKDHW, - GKZYXC, - Empty_Tuple, - NGCDHW, - ConvBwdDataDefault>{}); + device_grouped_conv_bwd_data_xdl_f32_generic_instances<3, + NGKDHW, + GKZYXC, + Empty_Tuple, + NGCDHW, + ConvBwdDataDefault>{}); } } // namespace instance diff --git a/profiler/src/profile_grouped_conv_bwd_data.cpp b/profiler/src/profile_grouped_conv_bwd_data.cpp index 9565833b32..1515f1105f 100644 --- a/profiler/src/profile_grouped_conv_bwd_data.cpp +++ b/profiler/src/profile_grouped_conv_bwd_data.cpp @@ -16,6 +16,7 @@ enum struct ConvLayout GNHWC_GKYXC_GNHWK, // 0 NHWGC_GKYXC_NHWGK, // 1 NGCHW_GKYXC_NGKHW, // 2 + NGCHW_GKCYX_NGKHW, // 3 }; enum struct ConvDataType @@ -36,9 +37,10 @@ static void print_helper_msg() << "arg2: data type (0: Output fp32, Weight fp32, Input fp32\n" << " 1: Output fp16, Weight fp16, Input fp16\n" << " 2: Output bf16, Weight bf16, Input bf16\n" - << "arg3: tensor layout (0: Output[G, N, Hi, Wi, C], Weight[G, K, Y, X, C], Input[G, N, Ho, Wo, K]\n" - << " 1: Output[N, Hi, Wi, G, C], Weight[G, K, Y, X, C], Input[N, Ho, Wo, G, K])\n" - << " 2: Output[N, G, C, Hi, Wi], Weight[G, K, Y, X, C], Input[N, G, K, Ho, Wo])\n" + << "arg3: tensor layout (0: Output[G, N, Ho, Wo, C], Weight[G, K, Y, X, C], Input[G, N, Hi, Wi, K]\n" + << " 1: Output[N, Ho, Wo, G, C], Weight[G, K, Y, X, C], Input[N, Hi, Wi, G, K])\n" + << " 2: Output[N, G, C, Ho, Wo], Weight[G, K, Y, X, C], Input[N, G, K, Hi, Wi])\n" + << " 3: Output[N, G, C, Ho, Wo], Weight[G, K, C, Y, X], Input[N, G, K, 
Hi, Wi])\n" << "arg4: verification (0: no, 1: yes)\n" << "arg5: initialization (0: no init, 1: integer value, 2: decimal value)\n" << "arg6: print tensor value (0: no; 1: yes)\n" @@ -160,6 +162,21 @@ int profile_grouped_conv_bwd_data(int argc, char* argv[]) return profile(I2, NGKHW{}, GKYXC{}, NGCHW{}, BF16{}, BF16{}, BF16{}); } } + else if(layout == ConvLayout::NGCHW_GKCYX_NGKHW) + { + if(data_type == ConvDataType::F32_F32_F32) + { + return profile(I2, NGKHW{}, GKCYX{}, NGCHW{}, F32{}, F32{}, F32{}); + } + else if(data_type == ConvDataType::F16_F16_F16) + { + return profile(I2, NGKHW{}, GKCYX{}, NGCHW{}, F16{}, F16{}, F16{}); + } + else if(data_type == ConvDataType::BF16_BF16_BF16) + { + return profile(I2, NGKHW{}, GKCYX{}, NGCHW{}, BF16{}, BF16{}, BF16{}); + } + } } else if(num_dim_spatial == 3) { @@ -208,6 +225,21 @@ int profile_grouped_conv_bwd_data(int argc, char* argv[]) return profile(I3, NGKDHW{}, GKZYXC{}, NGCDHW{}, BF16{}, BF16{}, BF16{}); } } + else if(layout == ConvLayout::NGCHW_GKCYX_NGKHW) + { + if(data_type == ConvDataType::F32_F32_F32) + { + return profile(I3, NGKDHW{}, GKCZYX{}, NGCDHW{}, F32{}, F32{}, F32{}); + } + else if(data_type == ConvDataType::F16_F16_F16) + { + return profile(I3, NGKDHW{}, GKCZYX{}, NGCDHW{}, F16{}, F16{}, F16{}); + } + else if(data_type == ConvDataType::BF16_BF16_BF16) + { + return profile(I3, NGKDHW{}, GKCZYX{}, NGCDHW{}, BF16{}, BF16{}, BF16{}); + } + } } std::cout << "this data_type & layout is not implemented" << std::endl; diff --git a/script/convert_miopen_driver_to_profiler.py b/script/convert_miopen_driver_to_profiler.py index 9bb668e164..81f9977542 100644 --- a/script/convert_miopen_driver_to_profiler.py +++ b/script/convert_miopen_driver_to_profiler.py @@ -30,10 +30,9 @@ def parse_layouts(args): if args.in_layout == "NCW" or args.in_layout == "NCHW" or \ args.in_layout == "NCDHW": if args.ck_profier_op == "grouped_conv_bwd_weight" or \ - args.ck_profier_op == 
"grouped_conv_fwd" or \ + args.ck_profier_op == "grouped_conv_bwd_data": args.layout = 3 - elif args.ck_profier_op == "grouped_conv_bwd_data": - args.layout = 2 else: print('Not supported layout for this op') exit(1) diff --git a/test/grouped_convnd_bwd_data/test_grouped_convnd_bwd_data_xdl.cpp b/test/grouped_convnd_bwd_data/test_grouped_convnd_bwd_data_xdl.cpp index 3fe4dac2ba..eb6083c521 100644 --- a/test/grouped_convnd_bwd_data/test_grouped_convnd_bwd_data_xdl.cpp +++ b/test/grouped_convnd_bwd_data/test_grouped_convnd_bwd_data_xdl.cpp @@ -54,6 +54,9 @@ using KernelTypes2d = ::testing::Types, std::tuple, std::tuple, std::tuple, + std::tuple, + std::tuple, + std::tuple, std::tuple, std::tuple, std::tuple>; @@ -64,6 +67,9 @@ using KernelTypes3d = ::testing::Types std::tuple, std::tuple, std::tuple, + std::tuple, + std::tuple, + std::tuple, std::tuple, std::tuple, std::tuple>; From e5ad48a7843a16a1ed0c1268b5dba7dfe2d59e4d Mon Sep 17 00:00:00 2001 From: Adam Osewski <19374865+aosewski@users.noreply.github.com> Date: Wed, 2 Apr 2025 11:03:40 +0200 Subject: [PATCH 014/443] Basic docs for universal gemm & ck-tile gemm. (#2014) MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit * Basic docs for universal gemm & ck-tile gemm. 
* Update include/ck/tensor_operation/gpu/device/impl/device_gemm_xdl_cshuffle_v3.hpp Co-authored-by: BartÅ‚omiej Kocot * Update include/ck/tensor_operation/gpu/grid/gridwise_gemm_xdl_cshuffle_v3.hpp Co-authored-by: BartÅ‚omiej Kocot * Update include/ck/tensor_operation/gpu/device/impl/device_gemm_xdl_cshuffle_v3.hpp Co-authored-by: BartÅ‚omiej Kocot * Update include/ck/tensor_operation/gpu/grid/gridwise_gemm_xdl_cshuffle_v3.hpp Co-authored-by: BartÅ‚omiej Kocot * Update include/ck/tensor_operation/gpu/device/impl/device_gemm_xdl_cshuffle_v3.hpp Co-authored-by: BartÅ‚omiej Kocot * Update include/ck/tensor_operation/gpu/grid/gridwise_gemm_xdl_cshuffle_v3.hpp Co-authored-by: BartÅ‚omiej Kocot * Update include/ck/tensor_operation/gpu/device/impl/device_gemm_xdl_cshuffle_v3.hpp Co-authored-by: BartÅ‚omiej Kocot * Update include/ck/tensor_operation/gpu/device/impl/device_gemm_xdl_cshuffle_v3.hpp Co-authored-by: spolifroni-amd * Update include/ck/tensor_operation/gpu/grid/gridwise_gemm_xdl_cshuffle_v3.hpp Co-authored-by: BartÅ‚omiej Kocot * Update include/ck/tensor_operation/gpu/grid/gridwise_gemm_xdl_cshuffle_v3.hpp Co-authored-by: BartÅ‚omiej Kocot * Update include/ck/tensor_operation/gpu/grid/gridwise_gemm_xdl_cshuffle_v3.hpp Co-authored-by: BartÅ‚omiej Kocot * Update include/ck/tensor_operation/gpu/device/impl/device_gemm_xdl_cshuffle_v3.hpp Co-authored-by: BartÅ‚omiej Kocot * Update include/ck/tensor_operation/gpu/device/impl/device_gemm_xdl_cshuffle_v3.hpp Co-authored-by: spolifroni-amd * Update include/ck/tensor_operation/gpu/device/impl/device_gemm_xdl_cshuffle_v3.hpp Co-authored-by: BartÅ‚omiej Kocot * Update include/ck/tensor_operation/gpu/device/impl/device_gemm_xdl_cshuffle_v3.hpp Co-authored-by: spolifroni-amd * Update include/ck/tensor_operation/gpu/device/impl/device_gemm_xdl_cshuffle_v3.hpp Co-authored-by: spolifroni-amd * Update include/ck/tensor_operation/gpu/device/impl/device_gemm_xdl_cshuffle_v3.hpp Co-authored-by: spolifroni-amd * Update 
include/ck/tensor_operation/gpu/device/impl/device_gemm_xdl_cshuffle_v3.hpp Co-authored-by: spolifroni-amd * Update include/ck/tensor_operation/gpu/device/impl/device_gemm_xdl_cshuffle_v3.hpp Co-authored-by: spolifroni-amd * Update include/ck/tensor_operation/gpu/device/impl/device_gemm_xdl_cshuffle_v3.hpp Co-authored-by: spolifroni-amd * Update include/ck/tensor_operation/gpu/device/impl/device_gemm_xdl_cshuffle_v3.hpp Co-authored-by: spolifroni-amd * Reviewers suggestions. * Align tparam names in doc with class tparams. * More reviewers fine tuning ;) --------- Co-authored-by: BartÅ‚omiej Kocot Co-authored-by: spolifroni-amd --- .../impl/device_gemm_xdl_cshuffle_v3.hpp | 116 +++++++++++++++++- .../grid/gridwise_gemm_xdl_cshuffle_v3.hpp | 103 ++++++++++++++++ .../ck_tile/ops/gemm/kernel/gemm_kernel.hpp | 60 +++++++++ 3 files changed, 277 insertions(+), 2 deletions(-) mode change 100755 => 100644 include/ck/tensor_operation/gpu/grid/gridwise_gemm_xdl_cshuffle_v3.hpp diff --git a/include/ck/tensor_operation/gpu/device/impl/device_gemm_xdl_cshuffle_v3.hpp b/include/ck/tensor_operation/gpu/device/impl/device_gemm_xdl_cshuffle_v3.hpp index a8cf681995..51c223efd2 100644 --- a/include/ck/tensor_operation/gpu/device/impl/device_gemm_xdl_cshuffle_v3.hpp +++ b/include/ck/tensor_operation/gpu/device/impl/device_gemm_xdl_cshuffle_v3.hpp @@ -1,5 +1,5 @@ // SPDX-License-Identifier: MIT -// Copyright (c) 2018-2023, Advanced Micro Devices, Inc. All rights reserved. +// Copyright (c) 2018-2025, Advanced Micro Devices, Inc. All rights reserved. #pragma once @@ -21,6 +21,105 @@ namespace ck { namespace tensor_operation { namespace device { +/// @brief \"Universal\" GEMM operation with SplitK support. +/// +/// @par Overview +/// This GEMM operation implements the following mathematical equation: +/// C{M,N} = C_op(A_op(A{M,K}) * B_op(B{K,N})) +/// Where A, B are input tensors and C is the output tensor. 
The A/B/C_op are +/// elementwise operations applied to the A, B, and C tensors, respectively. +/// The \"universal\" gemm comes with multiple pipelines optimized for different usage +/// scenarios. That's why it's called \"universal\". It's universal through its design +/// and versatility. +/// +/// @note This Kernel implementation supports SplitK algorithm. It can be configured +/// to split the dot product accumulated over the K dimension into multiple working groups. +/// The partial products of different workgroups are then reduced using the AtomicAdd +/// operation. +/// +/// @tparam ALayout A tensor data layout. +/// @tparam BLayout B tensor data layout. +/// @tparam CLayout C tensor data layout. +/// @tparam ADataType A tensor data type. +/// @tparam BDataType B tensor data type. +/// @tparam CDataType C tensor data type. +/// @tparam GemmAccDataType The accumulation data type related to the hardware +/// matrix-multiplication instruction. +/// @tparam CShuffleDataType The data type used to store matrix-multiplication results into +/// LDS memory during \"CShuffle\" data layout optimization. +/// @tparam AElementwiseOperation Elementwise operation applied to the A input tensor elements. +/// @tparam BElementwiseOperation Elementwise operation applied to the B input tensor elements. +/// @tparam CElementwiseOperation Elementwise operation applied to the C output tensor +/// (after GEMM). +/// @tparam GemmSpec Determines used "padding" version. +/// @tparam BlockSize The number of threads within workgroup. +/// @tparam MPerBlock The input/output data tile size in the M dimension. +/// @tparam NPerBlock The input/output data tile size in the N dimension. +/// @tparam KPerBlock The input data tile size in the K dimension. +/// @tparam AK1 The vector load size from global memory for A tensor. +/// @tparam BK1 The vector load size from global memory for B tensor. +/// @tparam MPerXDL M size of matrix-fused-multiply-add instruction. 
+/// @tparam NPerXDL N size of matrix-fused-multiply-add instruction. +/// @tparam MXdlPerWave The number of iterations in the M dimension over output tile per wavefront. +/// @tparam NXdlPerWave The number of iterations in the N dimension over output tile per wavefront. +/// @tparam ABlockTransferThreadClusterLengths_AK0_M_AK1 Spatial thread distribution over the input +/// data. Can be interpreted as the answer +/// to the question, "How many threads can be +/// arranged on each input data axis?" +/// @tparam ABlockTransferThreadClusterArrangeOrder The order of thread spatial distribution over +/// the input tensor dimension. Can be interpreted +/// as the answer to the question: "In which +/// order to spread threads through tensor axes?". +/// @tparam ABlockTransferSrcAccessOrder The order of accessing input tensor axes. Can be +/// interpreted as the answer to the question "Which dimension +/// to read first? And which next?" etc. +/// @tparam ABlockTransferSrcVectorDim The index of axis on which we could do vectorized memory +/// access - the one with contiguous memory. +/// @tparam ABlockTransferSrcScalarPerVector The size of vector access instruction - the number of +/// elements accessed per thread per instruction. +/// @tparam ABlockTransferDstScalarPerVector_AK1 The size of vectorized store into LDS memory. +/// @tparam ABlockLdsExtraM Whether to use padding for LDS or not. With +/// universal GEMM there's no need for padding. +/// @tparam BBlockTransferThreadClusterLengths_BK0_N_BK1 Spatial thread distribution over the input +/// data. Can be interpreted as the answer +/// to the question: "How many threads to +/// arrange on each input data axis?" +/// @tparam BBlockTransferThreadClusterArrangeOrder The order of thread spatial distribution over +/// the input tensor dimension. Can be interpreted +/// as the answer to the question: "In which +/// order to spread threads through tensor axes?". 
+/// @tparam BBlockTransferSrcAccessOrder The order of accessing input tensor axes. Can be +/// interpreted as the answer to the question "Which dimension +/// to read first? And which next?" etc. +/// @tparam BBlockTransferSrcVectorDim The index of axis on which we could do vectorized memory +/// access - the one with contiguous memory. +/// @tparam BBlockTransferSrcScalarPerVector The size of vector access instruction - the number of +/// elements accessed per thread per instruction. +/// @tparam BBlockTransferDstScalarPerVector_BK1 The size of vectorized store into LDS memory. +/// @tparam BBlockLdsExtraN Whether to use padding for LDS or not. With +/// universal GEMM there's no need for padding. +/// @tparam CShuffleMXdlPerWavePerShuffle The number of matrix-multiplication instructions +/// results to process per wave per iteration of CShuffle +/// in M dimension. +/// @tparam CShuffleNXdlPerWavePerShuffle The number of matrix-multiplication instructions +/// results to process per wave per iteration of CShuffle +/// in N dimension. +/// @tparam CShuffleBlockTransferClusterLengths_MBlock_MPerBlock_NBlock_NPerBlock The spatial +/// thread distribution used for storing data into output +/// tensor across output data layout dimensions. +/// @tparam CShuffleBlockTransferScalarPerVector_NPerBlock The size of vectorized memory access. +/// Used when storing data to output tensor. +/// @tparam BlkGemmPipeSched The version of blockwise-gemm pipeline scheduler (interwave or +/// intrawave). +/// @tparam BlkGemmPipelineVer The version of blockwise-gemm pipeline. +/// @tparam ComputeTypeA Data type used for A input of hardware matrix-multiplication +/// instructions. +/// @tparam ComputeTypeB Data type used for B input of hardware matrix-multiplication +/// instructions. +/// @tparam PermuteA Whether the A input tensor has gridwise-gemm friendly data layout +/// in global memory. Currently not supported! 
+/// @tparam PermuteB Whether the B input tensor has gridwise-gemm friendly data layout +/// in global memory (pre-shuffled). template 0) diff --git a/include/ck/tensor_operation/gpu/grid/gridwise_gemm_xdl_cshuffle_v3.hpp b/include/ck/tensor_operation/gpu/grid/gridwise_gemm_xdl_cshuffle_v3.hpp old mode 100755 new mode 100644 index 55639f4aee..9f6d85dd78 --- a/include/ck/tensor_operation/gpu/grid/gridwise_gemm_xdl_cshuffle_v3.hpp +++ b/include/ck/tensor_operation/gpu/grid/gridwise_gemm_xdl_cshuffle_v3.hpp @@ -82,6 +82,109 @@ __global__ void #endif // end of if (defined(__gfx9__)) } +/// @brief \"Universal\" GEMM kernel with SplitK support. +/// +/// @par Overview +/// This GEMM kernel is carrying out following mathematical equation: +/// C{M,N} = C_op(A_op(A{M,K}) * B_op(B{K,N})) +/// Where A, B are input tensors and C is the output tensor. The A/B/C_op are +/// elementwise operations that could be applied on each tensor respectively. +/// The \"universal\" gemm comes with multiple pipelines optimized for different usage +/// scenarios. That's why it's called \"universal\". It's universal through its design +/// and versatility. +/// +/// @note This Kernel implementation supports SplitK algorithm. It can be configured +/// to split the dot product accumulated over the K dimension into multiple working groups. +/// The partial products of different workgroups are then reduced using the AtomicAdd +/// operation. +/// +/// @tparam ALayout A tensor data layout. +/// @tparam BLayout B tensor data layout. +/// @tparam CLayout C tensor data layout. +/// @tparam ADataType A tensor data type. +/// @tparam BDataType B tensor data type. +/// @tparam AccDataType The accumulation data type related to the hardware +/// matrix-multiplication instruction. +/// @tparam CShuffleDataType The data type used to store matrix-multiplication results into +/// LDS memory during \"CShuffle\" data layout optimization. +/// @tparam CDataType C tensor data type. 
+/// @tparam AElementwiseOperation Elementwise operation applied to the A input tensor elements. +/// @tparam BElementwiseOperation Elementwise operation applied to the B input tensor elements. +/// @tparam CElementwiseOperation Elementwise operation applied to the C output tensor +/// (after GEMM). +/// @tparam GemmSpec Determines used "padding" version. +/// @tparam BlockSize The number of threads within workgroup. +/// @tparam MPerBlock The input/output data tile size in the M dimension. +/// @tparam NPerBlock The input/output data tile size in the N dimension. +/// @tparam KPerBlock The input data tile size in the K dimension. +/// @tparam AK1Value The vector load size from global memory for A tensor. +/// @tparam BK1Value The vector load size from global memory for B tensor. +/// @tparam MPerXdl M size of matrix-fused-multiply-add instruction. +/// @tparam NPerXdl N size of matrix-fused-multiply-add instruction. +/// @tparam MXdlPerWave The number of iterations in the M dimension over output tile per wavefront. +/// @tparam NXdlPerWave The number of iterations in the N dimension over output tile per wavefront. +/// @tparam ABlockTransferThreadClusterLengths_AK0_M_AK1 Spatial thread distribution over the input +/// data. Can be interpreted as the answer +/// to the question, "How many threads can be +/// arranged on each input data axis?" +/// @tparam ABlockTransferThreadClusterArrangeOrder The order of thread spatial distribution over +/// the input tensor dimension. Can be interpreted +/// as the answer to the question: "In which +/// order to spread threads through tensor axes?". +/// @tparam ABlockTransferSrcAccessOrder The order of accessing input tensor axes. Can be +/// interpreted as the answer to the question "Which dimension +/// to read first? And which next?" etc. +/// @tparam ABlockTransferSrcVectorDim The index of axis on which we could do vectorized memory +/// access - the one with contiguous memory. 
+/// @tparam ABlockTransferSrcScalarPerVector The size of vector access instruction - the number of +/// elements accessed per thread per instruction. +/// @tparam ABlockTransferDstScalarPerVector_AK1 The size of vectorized store into LDS memory. +/// @tparam AThreadTransferSrcResetCoordinateAfterRun Decides whether we reset thread coordinate +/// (return back to the window origin) after all thread finish data copy. +/// @tparam ABlockLdsExtraM Whether to use padding for LDS or not. With +/// universal GEMM there's no need for padding. +/// @tparam BBlockTransferThreadClusterLengths_BK0_N_BK1 Spatial thread distribution over the input +/// data. Can be interpreted as the answer +/// to the question: "How many threads to +/// arrange on each input data axis?" +/// @tparam BBlockTransferThreadClusterArrangeOrder The order of thread spatial distribution over +/// the input tensor dimension. Can be interpreted +/// as the answer to the question: "In which +/// order to spread threads through tensor axes?". +/// @tparam BBlockTransferSrcAccessOrder The order of accessing input tensor axes. Can be +/// interpreted as the answer to the question "Which dimension +/// to read first? And which next?" etc. +/// @tparam BBlockTransferSrcVectorDim The index of axis on which we could do vectorized memory +/// access - the one with contiguous memory. +/// @tparam BBlockTransferSrcScalarPerVector The size of vector access instruction - the number of +/// elements accessed per thread per instruction. +/// @tparam BBlockTransferDstScalarPerVector_BK1 The size of vectorized store into LDS memory. +/// @tparam BThreadTransferSrcResetCoordinateAfterRun Decides whether we reset thread coordinate +/// (return back to the window origin) after all thread finish data copy. +/// @tparam BBlockLdsExtraN Whether to use padding for LDS or not. With +/// universal GEMM there's no need for padding. 
+/// @tparam CShuffleMXdlPerWavePerShuffle The number of matrix-multiplication instructions +/// results to process per wave per iteration of CShuffle +/// in M dimension. +/// @tparam CShuffleNXdlPerWavePerShuffle The number of matrix-multiplication instructions +/// results to process per wave per iteration of CShuffle +/// in N dimension. +/// @tparam CShuffleBlockTransferClusterLengths_MBlock_MPerBlock_NBlock_NPerBlock The spatial +/// thread distribution used for storing data into output +/// tensor across output data layout dimensions. +/// @tparam CShuffleBlockTransferScalarPerVector_NPerBlock The size of vectorized memory access. +/// Used when storing data to output tensor. +/// @tparam BlkGemmPipeSched The version of blockwise-gemm pipeline scheduler (interwave or +/// intrawave). +/// @tparam BlkGemmPipelineVer The version of blockwise-gemm pipeline. +/// @tparam ComputeTypeA Data type used for A input of hardware matrix-multiplication +/// instructions. +/// @tparam ComputeTypeB Data type used for B input of hardware matrix-multiplication +/// instructions. +/// @tparam PermuteA Whether the A input tensor has gridwise-gemm friendly data layout +/// in global memory. Currently not supported! +/// @tparam PermuteB Whether the B input tensor has gridwise-gemm friendly data layout +/// in global memory (pre-shuffled). 
template struct GemmKernel { From 2ccf91488878239c8dde9b3be312b84311907a44 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Bart=C5=82omiej=20Kocot?= Date: Wed, 2 Apr 2025 23:59:49 +0200 Subject: [PATCH 015/443] Add support for GKCYX grouped conv weight (#2023) * Grouped conv bwd weight GKCYX support * fix and changelog * fix * fix * fixes * comments * fix --- CHANGELOG.md | 3 +- .../07_grouped_convnd_fwd/README.md | 16 +- .../11_grouped_conv_bwd_weight/README.md | 12 +- ...conv_bwd_weight_two_stage_xdl_cshuffle.hpp | 203 +++++++++----- ...e_grouped_conv_bwd_weight_xdl_cshuffle.hpp | 249 +++++++++++++----- ...ped_conv_fwd_multiple_abd_xdl_cshuffle.hpp | 5 + .../transform_conv_ngchw_to_nhwgc.hpp | 9 + .../grouped_convolution_backward_weight.hpp | 114 ++++++-- ...rouped_convolution_backward_weight_xdl.inc | 120 +++++++-- .../grouped_conv2d_bwd_weight/CMakeLists.txt | 90 ++++--- ...hwc_gkyxc_gnhwk_bf16_f32_bf16_instance.cpp | 0 ...kyxc_gnhwk_f16_default_pipev1_instance.cpp | 0 ...ght_xdl_gnhwc_gkyxc_gnhwk_f16_instance.cpp | 0 ...c_gkyxc_gnhwk_f16_pad0_pipev1_instance.cpp | 0 ...kyxc_gnhwk_f32_default_pipev1_instance.cpp | 0 ...ght_xdl_gnhwc_gkyxc_gnhwk_f32_instance.cpp | 0 ...c_gkyxc_gnhwk_f32_pad0_pipev1_instance.cpp | 0 ...ngchw_gkcyx_ngkhw_bf16_pipev1_instance.cpp | 41 +++ ...gchw_gkcyx_ngkhw_bf16_pipev2_instance.cpp} | 8 +- ...gchw_gkcyx_ngkhw_bf16_pipev5_instance.cpp} | 8 +- ..._ngchw_gkcyx_ngkhw_f16_pipev1_instance.cpp | 41 +++ ...ngchw_gkcyx_ngkhw_f16_pipev2_instance.cpp} | 8 +- ...ngchw_gkcyx_ngkhw_f16_pipev5_instance.cpp} | 8 +- ...t_xdl_ngchw_gkcyx_ngkhw_bf16_instance.cpp} | 8 +- ...ht_xdl_ngchw_gkcyx_ngkhw_f16_instance.cpp} | 8 +- ...ht_xdl_ngchw_gkcyx_ngkhw_f32_instance.cpp} | 8 +- ...ngchw_gkyxc_ngkhw_bf16_pipev1_instance.cpp | 2 +- ..._ngchw_gkyxc_ngkhw_f16_pipev1_instance.cpp | 2 +- ...ght_xdl_ngchw_gkyxc_ngkhw_f32_instance.cpp | 38 +++ ...nhwgc_gkyxc_nhwgk_bf16_pipev1_instance.cpp | 2 +- ...nhwgc_gkyxc_nhwgk_bf16_pipev2_instance.cpp | 2 +- 
...c_nhwgk_bf16_pipev2_irregular_instance.cpp | 0 ...nhwgc_gkyxc_nhwgk_bf16_pipev5_instance.cpp | 2 +- ...c_nhwgk_bf16_pipev5_irregular_instance.cpp | 0 ..._nhwgc_gkyxc_nhwgk_f16_pipev1_instance.cpp | 2 +- ..._nhwgc_gkyxc_nhwgk_f16_pipev2_instance.cpp | 2 +- ...xc_nhwgk_f16_pipev2_irregular_instance.cpp | 0 ..._nhwgc_gkyxc_nhwgk_f16_pipev5_instance.cpp | 2 +- ...xc_nhwgk_f16_pipev5_irregular_instance.cpp | 0 ...yxc_nhwgk_bf16_default_pipev2_instance.cpp | 0 ...yxc_nhwgk_bf16_default_pipev5_instance.cpp | 0 ...wgc_gkyxc_nhwgk_bf16_f32_bf16_instance.cpp | 2 +- ...ht_xdl_nhwgc_gkyxc_nhwgk_bf16_instance.cpp | 2 +- ..._gkyxc_nhwgk_bf16_pad0_pipev2_instance.cpp | 0 ..._gkyxc_nhwgk_bf16_pad0_pipev5_instance.cpp | 0 ...kyxc_nhwgk_f16_default_pipev2_instance.cpp | 0 ...kyxc_nhwgk_f16_default_pipev5_instance.cpp | 0 ...ght_xdl_nhwgc_gkyxc_nhwgk_f16_instance.cpp | 2 +- ...c_gkyxc_nhwgk_f16_pad0_pipev2_instance.cpp | 0 ...c_gkyxc_nhwgk_f16_pad0_pipev5_instance.cpp | 0 ...kyxc_nhwgk_f32_default_pipev2_instance.cpp | 0 ...kyxc_nhwgk_f32_default_pipev5_instance.cpp | 0 ...ght_xdl_nhwgc_gkyxc_nhwgk_f32_instance.cpp | 2 +- ...c_gkyxc_nhwgk_f32_pad0_pipev2_instance.cpp | 0 ...c_gkyxc_nhwgk_f32_pad0_pipev5_instance.cpp | 0 .../grouped_conv3d_bwd_weight/CMakeLists.txt | 84 +++--- ...c_gkzyxc_gndhwk_bf16_f32_bf16_instance.cpp | 0 ..._xdl_gndhwc_gkzyxc_gndhwk_f16_instance.cpp | 0 ..._xdl_gndhwc_gkzyxc_gndhwk_f32_instance.cpp | 0 ...wgc_gkzyxc_ndhwgk_bf16_pipev1_instance.cpp | 2 +- ...wgc_gkzyxc_ndhwgk_bf16_pipev2_instance.cpp | 2 +- ..._ndhwgk_bf16_pipev2_irregular_instance.cpp | 0 ...wgc_gkzyxc_ndhwgk_bf16_pipev5_instance.cpp | 2 +- ..._ndhwgk_bf16_pipev5_irregular_instance.cpp | 0 ...hwgc_gkzyxc_ndhwgk_f16_pipev1_instance.cpp | 2 +- ...hwgc_gkzyxc_ndhwgk_f16_pipev2_instance.cpp | 2 +- ...c_ndhwgk_f16_pipev2_irregular_instance.cpp | 0 ...hwgc_gkzyxc_ndhwgk_f16_pipev5_instance.cpp | 2 +- ...c_ndhwgk_f16_pipev5_irregular_instance.cpp | 0 ...xc_ndhwgk_bf16_default_pipev2_instance.cpp | 0 
...xc_ndhwgk_bf16_default_pipev5_instance.cpp | 0 ...c_gkzyxc_ndhwgk_bf16_f32_bf16_instance.cpp | 2 +- ...xdl_ndhwgc_gkzyxc_ndhwgk_bf16_instance.cpp | 2 +- ...kzyxc_ndhwgk_bf16_pad0_pipev2_instance.cpp | 0 ...kzyxc_ndhwgk_bf16_pad0_pipev5_instance.cpp | 0 ...kzyxc_ndhwgk_f16_comp_bf8_fp8_instance.cpp | 2 +- ...yxc_ndhwgk_f16_default_pipev2_instance.cpp | 0 ...yxc_ndhwgk_f16_default_pipev5_instance.cpp | 0 ..._xdl_ndhwgc_gkzyxc_ndhwgk_f16_instance.cpp | 2 +- ...gkzyxc_ndhwgk_f16_pad0_pipev2_instance.cpp | 0 ...gkzyxc_ndhwgk_f16_pad0_pipev5_instance.cpp | 0 ...yxc_ndhwgk_f32_default_pipev2_instance.cpp | 0 ...yxc_ndhwgk_f32_default_pipev5_instance.cpp | 0 ..._xdl_ndhwgc_gkzyxc_ndhwgk_f32_instance.cpp | 2 +- ...gkzyxc_ndhwgk_f32_pad0_pipev2_instance.cpp | 0 ...gkzyxc_ndhwgk_f32_pad0_pipev5_instance.cpp | 0 ...dhw_gkczyx_ngkdhw_bf16_pipev1_instance.cpp | 41 +++ ...hw_gkczyx_ngkdhw_bf16_pipev2_instance.cpp} | 8 +- ...hw_gkczyx_ngkdhw_bf16_pipev5_instance.cpp} | 8 +- ...cdhw_gkczyx_ngkdhw_f16_pipev1_instance.cpp | 41 +++ ...dhw_gkczyx_ngkdhw_f16_pipev2_instance.cpp} | 8 +- ...dhw_gkczyx_ngkdhw_f16_pipev5_instance.cpp} | 8 +- ...dl_ngcdhw_gkczyx_ngkdhw_bf16_instance.cpp} | 8 +- ...xdl_ngcdhw_gkczyx_ngkdhw_f16_instance.cpp} | 8 +- ...xdl_ngcdhw_gkczyx_ngkdhw_f32_instance.cpp} | 8 +- ...dhw_gkzyxc_ngkdhw_bf16_pipev1_instance.cpp | 2 +- ...cdhw_gkzyxc_ngkdhw_f16_pipev1_instance.cpp | 2 +- ..._xdl_ngcdhw_gkzyxc_ngkdhw_f32_instance.cpp | 38 +++ .../src/profile_grouped_conv_bwd_weight.cpp | 36 ++- script/convert_miopen_driver_to_profiler.py | 5 +- .../test_grouped_convnd_bwd_weight.cpp | 12 +- 101 files changed, 1004 insertions(+), 356 deletions(-) rename library/src/tensor_operation_instance/gpu/grouped_conv2d_bwd_weight/xdl/{ => gnhwc_gkyxc_gnhwk}/device_grouped_conv2d_bwd_weight_xdl_gnhwc_gkyxc_gnhwk_bf16_f32_bf16_instance.cpp (100%) rename library/src/tensor_operation_instance/gpu/grouped_conv2d_bwd_weight/xdl/{ => 
gnhwc_gkyxc_gnhwk}/device_grouped_conv2d_bwd_weight_xdl_gnhwc_gkyxc_gnhwk_f16_default_pipev1_instance.cpp (100%) rename library/src/tensor_operation_instance/gpu/grouped_conv2d_bwd_weight/xdl/{ => gnhwc_gkyxc_gnhwk}/device_grouped_conv2d_bwd_weight_xdl_gnhwc_gkyxc_gnhwk_f16_instance.cpp (100%) rename library/src/tensor_operation_instance/gpu/grouped_conv2d_bwd_weight/xdl/{ => gnhwc_gkyxc_gnhwk}/device_grouped_conv2d_bwd_weight_xdl_gnhwc_gkyxc_gnhwk_f16_pad0_pipev1_instance.cpp (100%) rename library/src/tensor_operation_instance/gpu/grouped_conv2d_bwd_weight/xdl/{ => gnhwc_gkyxc_gnhwk}/device_grouped_conv2d_bwd_weight_xdl_gnhwc_gkyxc_gnhwk_f32_default_pipev1_instance.cpp (100%) rename library/src/tensor_operation_instance/gpu/grouped_conv2d_bwd_weight/xdl/{ => gnhwc_gkyxc_gnhwk}/device_grouped_conv2d_bwd_weight_xdl_gnhwc_gkyxc_gnhwk_f32_instance.cpp (100%) rename library/src/tensor_operation_instance/gpu/grouped_conv2d_bwd_weight/xdl/{ => gnhwc_gkyxc_gnhwk}/device_grouped_conv2d_bwd_weight_xdl_gnhwc_gkyxc_gnhwk_f32_pad0_pipev1_instance.cpp (100%) create mode 100644 library/src/tensor_operation_instance/gpu/grouped_conv2d_bwd_weight/xdl/ngchw_gkcyx_ngkhw/device_grouped_conv2d_bwd_weight_two_stage_xdl_ngchw_gkcyx_ngkhw_bf16_pipev1_instance.cpp rename library/src/tensor_operation_instance/gpu/grouped_conv2d_bwd_weight/xdl/{device_grouped_conv2d_bwd_weight_two_stage_xdl_ngchw_gkyxc_ngkhw_bf16_pipev2_instance.cpp => ngchw_gkcyx_ngkhw/device_grouped_conv2d_bwd_weight_two_stage_xdl_ngchw_gkcyx_ngkhw_bf16_pipev2_instance.cpp} (88%) rename library/src/tensor_operation_instance/gpu/grouped_conv2d_bwd_weight/xdl/{device_grouped_conv2d_bwd_weight_two_stage_xdl_ngchw_gkyxc_ngkhw_bf16_pipev5_instance.cpp => ngchw_gkcyx_ngkhw/device_grouped_conv2d_bwd_weight_two_stage_xdl_ngchw_gkcyx_ngkhw_bf16_pipev5_instance.cpp} (88%) create mode 100644 
library/src/tensor_operation_instance/gpu/grouped_conv2d_bwd_weight/xdl/ngchw_gkcyx_ngkhw/device_grouped_conv2d_bwd_weight_two_stage_xdl_ngchw_gkcyx_ngkhw_f16_pipev1_instance.cpp rename library/src/tensor_operation_instance/gpu/grouped_conv2d_bwd_weight/xdl/{device_grouped_conv2d_bwd_weight_two_stage_xdl_ngchw_gkyxc_ngkhw_f16_pipev2_instance.cpp => ngchw_gkcyx_ngkhw/device_grouped_conv2d_bwd_weight_two_stage_xdl_ngchw_gkcyx_ngkhw_f16_pipev2_instance.cpp} (88%) rename library/src/tensor_operation_instance/gpu/grouped_conv2d_bwd_weight/xdl/{device_grouped_conv2d_bwd_weight_two_stage_xdl_ngchw_gkyxc_ngkhw_f16_pipev5_instance.cpp => ngchw_gkcyx_ngkhw/device_grouped_conv2d_bwd_weight_two_stage_xdl_ngchw_gkcyx_ngkhw_f16_pipev5_instance.cpp} (88%) rename library/src/tensor_operation_instance/gpu/grouped_conv2d_bwd_weight/xdl/{device_grouped_conv2d_bwd_weight_xdl_ngchw_gkyxc_ngkhw_bf16_instance.cpp => ngchw_gkcyx_ngkhw/device_grouped_conv2d_bwd_weight_xdl_ngchw_gkcyx_ngkhw_bf16_instance.cpp} (94%) rename library/src/tensor_operation_instance/gpu/grouped_conv2d_bwd_weight/xdl/{device_grouped_conv2d_bwd_weight_xdl_ngchw_gkyxc_ngkhw_f16_instance.cpp => ngchw_gkcyx_ngkhw/device_grouped_conv2d_bwd_weight_xdl_ngchw_gkcyx_ngkhw_f16_instance.cpp} (94%) rename library/src/tensor_operation_instance/gpu/grouped_conv2d_bwd_weight/xdl/{device_grouped_conv2d_bwd_weight_xdl_ngchw_gkyxc_ngkhw_f32_instance.cpp => ngchw_gkcyx_ngkhw/device_grouped_conv2d_bwd_weight_xdl_ngchw_gkcyx_ngkhw_f32_instance.cpp} (94%) rename library/src/tensor_operation_instance/gpu/grouped_conv2d_bwd_weight/xdl/{ => ngchw_gkyxc_ngkhw}/device_grouped_conv2d_bwd_weight_two_stage_xdl_ngchw_gkyxc_ngkhw_bf16_pipev1_instance.cpp (95%) rename library/src/tensor_operation_instance/gpu/grouped_conv2d_bwd_weight/xdl/{ => ngchw_gkyxc_ngkhw}/device_grouped_conv2d_bwd_weight_two_stage_xdl_ngchw_gkyxc_ngkhw_f16_pipev1_instance.cpp (95%) create mode 100644 
library/src/tensor_operation_instance/gpu/grouped_conv2d_bwd_weight/xdl/ngchw_gkyxc_ngkhw/device_grouped_conv2d_bwd_weight_xdl_ngchw_gkyxc_ngkhw_f32_instance.cpp rename library/src/tensor_operation_instance/gpu/grouped_conv2d_bwd_weight/xdl/{ => nhwgc_gkyxc_nhwgk}/device_grouped_conv2d_bwd_weight_two_stage_xdl_nhwgc_gkyxc_nhwgk_bf16_pipev1_instance.cpp (95%) rename library/src/tensor_operation_instance/gpu/grouped_conv2d_bwd_weight/xdl/{ => nhwgc_gkyxc_nhwgk}/device_grouped_conv2d_bwd_weight_two_stage_xdl_nhwgc_gkyxc_nhwgk_bf16_pipev2_instance.cpp (95%) rename library/src/tensor_operation_instance/gpu/grouped_conv2d_bwd_weight/xdl/{ => nhwgc_gkyxc_nhwgk}/device_grouped_conv2d_bwd_weight_two_stage_xdl_nhwgc_gkyxc_nhwgk_bf16_pipev2_irregular_instance.cpp (100%) rename library/src/tensor_operation_instance/gpu/grouped_conv2d_bwd_weight/xdl/{ => nhwgc_gkyxc_nhwgk}/device_grouped_conv2d_bwd_weight_two_stage_xdl_nhwgc_gkyxc_nhwgk_bf16_pipev5_instance.cpp (95%) rename library/src/tensor_operation_instance/gpu/grouped_conv2d_bwd_weight/xdl/{ => nhwgc_gkyxc_nhwgk}/device_grouped_conv2d_bwd_weight_two_stage_xdl_nhwgc_gkyxc_nhwgk_bf16_pipev5_irregular_instance.cpp (100%) rename library/src/tensor_operation_instance/gpu/grouped_conv2d_bwd_weight/xdl/{ => nhwgc_gkyxc_nhwgk}/device_grouped_conv2d_bwd_weight_two_stage_xdl_nhwgc_gkyxc_nhwgk_f16_pipev1_instance.cpp (95%) rename library/src/tensor_operation_instance/gpu/grouped_conv2d_bwd_weight/xdl/{ => nhwgc_gkyxc_nhwgk}/device_grouped_conv2d_bwd_weight_two_stage_xdl_nhwgc_gkyxc_nhwgk_f16_pipev2_instance.cpp (95%) rename library/src/tensor_operation_instance/gpu/grouped_conv2d_bwd_weight/xdl/{ => nhwgc_gkyxc_nhwgk}/device_grouped_conv2d_bwd_weight_two_stage_xdl_nhwgc_gkyxc_nhwgk_f16_pipev2_irregular_instance.cpp (100%) rename library/src/tensor_operation_instance/gpu/grouped_conv2d_bwd_weight/xdl/{ => nhwgc_gkyxc_nhwgk}/device_grouped_conv2d_bwd_weight_two_stage_xdl_nhwgc_gkyxc_nhwgk_f16_pipev5_instance.cpp (95%) rename 
library/src/tensor_operation_instance/gpu/grouped_conv2d_bwd_weight/xdl/{ => nhwgc_gkyxc_nhwgk}/device_grouped_conv2d_bwd_weight_two_stage_xdl_nhwgc_gkyxc_nhwgk_f16_pipev5_irregular_instance.cpp (100%) rename library/src/tensor_operation_instance/gpu/grouped_conv2d_bwd_weight/xdl/{ => nhwgc_gkyxc_nhwgk}/device_grouped_conv2d_bwd_weight_xdl_nhwgc_gkyxc_nhwgk_bf16_default_pipev2_instance.cpp (100%) rename library/src/tensor_operation_instance/gpu/grouped_conv2d_bwd_weight/xdl/{ => nhwgc_gkyxc_nhwgk}/device_grouped_conv2d_bwd_weight_xdl_nhwgc_gkyxc_nhwgk_bf16_default_pipev5_instance.cpp (100%) rename library/src/tensor_operation_instance/gpu/grouped_conv2d_bwd_weight/xdl/{ => nhwgc_gkyxc_nhwgk}/device_grouped_conv2d_bwd_weight_xdl_nhwgc_gkyxc_nhwgk_bf16_f32_bf16_instance.cpp (95%) rename library/src/tensor_operation_instance/gpu/grouped_conv2d_bwd_weight/xdl/{ => nhwgc_gkyxc_nhwgk}/device_grouped_conv2d_bwd_weight_xdl_nhwgc_gkyxc_nhwgk_bf16_instance.cpp (96%) rename library/src/tensor_operation_instance/gpu/grouped_conv2d_bwd_weight/xdl/{ => nhwgc_gkyxc_nhwgk}/device_grouped_conv2d_bwd_weight_xdl_nhwgc_gkyxc_nhwgk_bf16_pad0_pipev2_instance.cpp (100%) rename library/src/tensor_operation_instance/gpu/grouped_conv2d_bwd_weight/xdl/{ => nhwgc_gkyxc_nhwgk}/device_grouped_conv2d_bwd_weight_xdl_nhwgc_gkyxc_nhwgk_bf16_pad0_pipev5_instance.cpp (100%) rename library/src/tensor_operation_instance/gpu/grouped_conv2d_bwd_weight/xdl/{ => nhwgc_gkyxc_nhwgk}/device_grouped_conv2d_bwd_weight_xdl_nhwgc_gkyxc_nhwgk_f16_default_pipev2_instance.cpp (100%) rename library/src/tensor_operation_instance/gpu/grouped_conv2d_bwd_weight/xdl/{ => nhwgc_gkyxc_nhwgk}/device_grouped_conv2d_bwd_weight_xdl_nhwgc_gkyxc_nhwgk_f16_default_pipev5_instance.cpp (100%) rename library/src/tensor_operation_instance/gpu/grouped_conv2d_bwd_weight/xdl/{ => nhwgc_gkyxc_nhwgk}/device_grouped_conv2d_bwd_weight_xdl_nhwgc_gkyxc_nhwgk_f16_instance.cpp (97%) rename 
library/src/tensor_operation_instance/gpu/grouped_conv2d_bwd_weight/xdl/{ => nhwgc_gkyxc_nhwgk}/device_grouped_conv2d_bwd_weight_xdl_nhwgc_gkyxc_nhwgk_f16_pad0_pipev2_instance.cpp (100%) rename library/src/tensor_operation_instance/gpu/grouped_conv2d_bwd_weight/xdl/{ => nhwgc_gkyxc_nhwgk}/device_grouped_conv2d_bwd_weight_xdl_nhwgc_gkyxc_nhwgk_f16_pad0_pipev5_instance.cpp (100%) rename library/src/tensor_operation_instance/gpu/grouped_conv2d_bwd_weight/xdl/{ => nhwgc_gkyxc_nhwgk}/device_grouped_conv2d_bwd_weight_xdl_nhwgc_gkyxc_nhwgk_f32_default_pipev2_instance.cpp (100%) rename library/src/tensor_operation_instance/gpu/grouped_conv2d_bwd_weight/xdl/{ => nhwgc_gkyxc_nhwgk}/device_grouped_conv2d_bwd_weight_xdl_nhwgc_gkyxc_nhwgk_f32_default_pipev5_instance.cpp (100%) rename library/src/tensor_operation_instance/gpu/grouped_conv2d_bwd_weight/xdl/{ => nhwgc_gkyxc_nhwgk}/device_grouped_conv2d_bwd_weight_xdl_nhwgc_gkyxc_nhwgk_f32_instance.cpp (97%) rename library/src/tensor_operation_instance/gpu/grouped_conv2d_bwd_weight/xdl/{ => nhwgc_gkyxc_nhwgk}/device_grouped_conv2d_bwd_weight_xdl_nhwgc_gkyxc_nhwgk_f32_pad0_pipev2_instance.cpp (100%) rename library/src/tensor_operation_instance/gpu/grouped_conv2d_bwd_weight/xdl/{ => nhwgc_gkyxc_nhwgk}/device_grouped_conv2d_bwd_weight_xdl_nhwgc_gkyxc_nhwgk_f32_pad0_pipev5_instance.cpp (100%) rename library/src/tensor_operation_instance/gpu/grouped_conv3d_bwd_weight/xdl/{ => gndhwc_gkzyxc_gndhwk}/device_grouped_conv3d_bwd_weight_xdl_gndhwc_gkzyxc_gndhwk_bf16_f32_bf16_instance.cpp (100%) rename library/src/tensor_operation_instance/gpu/grouped_conv3d_bwd_weight/xdl/{ => gndhwc_gkzyxc_gndhwk}/device_grouped_conv3d_bwd_weight_xdl_gndhwc_gkzyxc_gndhwk_f16_instance.cpp (100%) rename library/src/tensor_operation_instance/gpu/grouped_conv3d_bwd_weight/xdl/{ => gndhwc_gkzyxc_gndhwk}/device_grouped_conv3d_bwd_weight_xdl_gndhwc_gkzyxc_gndhwk_f32_instance.cpp (100%) rename library/src/tensor_operation_instance/gpu/grouped_conv3d_bwd_weight/xdl/{ 
=> ndhwgc_gkzyxc_ndhwgk}/device_grouped_conv3d_bwd_weight_two_stage_xdl_ndhwgc_gkzyxc_ndhwgk_bf16_pipev1_instance.cpp (95%) rename library/src/tensor_operation_instance/gpu/grouped_conv3d_bwd_weight/xdl/{ => ndhwgc_gkzyxc_ndhwgk}/device_grouped_conv3d_bwd_weight_two_stage_xdl_ndhwgc_gkzyxc_ndhwgk_bf16_pipev2_instance.cpp (95%) rename library/src/tensor_operation_instance/gpu/grouped_conv3d_bwd_weight/xdl/{ => ndhwgc_gkzyxc_ndhwgk}/device_grouped_conv3d_bwd_weight_two_stage_xdl_ndhwgc_gkzyxc_ndhwgk_bf16_pipev2_irregular_instance.cpp (100%) rename library/src/tensor_operation_instance/gpu/grouped_conv3d_bwd_weight/xdl/{ => ndhwgc_gkzyxc_ndhwgk}/device_grouped_conv3d_bwd_weight_two_stage_xdl_ndhwgc_gkzyxc_ndhwgk_bf16_pipev5_instance.cpp (95%) rename library/src/tensor_operation_instance/gpu/grouped_conv3d_bwd_weight/xdl/{ => ndhwgc_gkzyxc_ndhwgk}/device_grouped_conv3d_bwd_weight_two_stage_xdl_ndhwgc_gkzyxc_ndhwgk_bf16_pipev5_irregular_instance.cpp (100%) rename library/src/tensor_operation_instance/gpu/grouped_conv3d_bwd_weight/xdl/{ => ndhwgc_gkzyxc_ndhwgk}/device_grouped_conv3d_bwd_weight_two_stage_xdl_ndhwgc_gkzyxc_ndhwgk_f16_pipev1_instance.cpp (95%) rename library/src/tensor_operation_instance/gpu/grouped_conv3d_bwd_weight/xdl/{ => ndhwgc_gkzyxc_ndhwgk}/device_grouped_conv3d_bwd_weight_two_stage_xdl_ndhwgc_gkzyxc_ndhwgk_f16_pipev2_instance.cpp (95%) rename library/src/tensor_operation_instance/gpu/grouped_conv3d_bwd_weight/xdl/{ => ndhwgc_gkzyxc_ndhwgk}/device_grouped_conv3d_bwd_weight_two_stage_xdl_ndhwgc_gkzyxc_ndhwgk_f16_pipev2_irregular_instance.cpp (100%) rename library/src/tensor_operation_instance/gpu/grouped_conv3d_bwd_weight/xdl/{ => ndhwgc_gkzyxc_ndhwgk}/device_grouped_conv3d_bwd_weight_two_stage_xdl_ndhwgc_gkzyxc_ndhwgk_f16_pipev5_instance.cpp (95%) rename library/src/tensor_operation_instance/gpu/grouped_conv3d_bwd_weight/xdl/{ => 
ndhwgc_gkzyxc_ndhwgk}/device_grouped_conv3d_bwd_weight_two_stage_xdl_ndhwgc_gkzyxc_ndhwgk_f16_pipev5_irregular_instance.cpp (100%) rename library/src/tensor_operation_instance/gpu/grouped_conv3d_bwd_weight/xdl/{ => ndhwgc_gkzyxc_ndhwgk}/device_grouped_conv3d_bwd_weight_xdl_ndhwgc_gkzyxc_ndhwgk_bf16_default_pipev2_instance.cpp (100%) rename library/src/tensor_operation_instance/gpu/grouped_conv3d_bwd_weight/xdl/{ => ndhwgc_gkzyxc_ndhwgk}/device_grouped_conv3d_bwd_weight_xdl_ndhwgc_gkzyxc_ndhwgk_bf16_default_pipev5_instance.cpp (100%) rename library/src/tensor_operation_instance/gpu/grouped_conv3d_bwd_weight/xdl/{ => ndhwgc_gkzyxc_ndhwgk}/device_grouped_conv3d_bwd_weight_xdl_ndhwgc_gkzyxc_ndhwgk_bf16_f32_bf16_instance.cpp (95%) rename library/src/tensor_operation_instance/gpu/grouped_conv3d_bwd_weight/xdl/{ => ndhwgc_gkzyxc_ndhwgk}/device_grouped_conv3d_bwd_weight_xdl_ndhwgc_gkzyxc_ndhwgk_bf16_instance.cpp (96%) rename library/src/tensor_operation_instance/gpu/grouped_conv3d_bwd_weight/xdl/{ => ndhwgc_gkzyxc_ndhwgk}/device_grouped_conv3d_bwd_weight_xdl_ndhwgc_gkzyxc_ndhwgk_bf16_pad0_pipev2_instance.cpp (100%) rename library/src/tensor_operation_instance/gpu/grouped_conv3d_bwd_weight/xdl/{ => ndhwgc_gkzyxc_ndhwgk}/device_grouped_conv3d_bwd_weight_xdl_ndhwgc_gkzyxc_ndhwgk_bf16_pad0_pipev5_instance.cpp (100%) rename library/src/tensor_operation_instance/gpu/grouped_conv3d_bwd_weight/xdl/{ => ndhwgc_gkzyxc_ndhwgk}/device_grouped_conv3d_bwd_weight_xdl_ndhwgc_gkzyxc_ndhwgk_f16_comp_bf8_fp8_instance.cpp (97%) rename library/src/tensor_operation_instance/gpu/grouped_conv3d_bwd_weight/xdl/{ => ndhwgc_gkzyxc_ndhwgk}/device_grouped_conv3d_bwd_weight_xdl_ndhwgc_gkzyxc_ndhwgk_f16_default_pipev2_instance.cpp (100%) rename library/src/tensor_operation_instance/gpu/grouped_conv3d_bwd_weight/xdl/{ => ndhwgc_gkzyxc_ndhwgk}/device_grouped_conv3d_bwd_weight_xdl_ndhwgc_gkzyxc_ndhwgk_f16_default_pipev5_instance.cpp (100%) rename 
library/src/tensor_operation_instance/gpu/grouped_conv3d_bwd_weight/xdl/{ => ndhwgc_gkzyxc_ndhwgk}/device_grouped_conv3d_bwd_weight_xdl_ndhwgc_gkzyxc_ndhwgk_f16_instance.cpp (97%) rename library/src/tensor_operation_instance/gpu/grouped_conv3d_bwd_weight/xdl/{ => ndhwgc_gkzyxc_ndhwgk}/device_grouped_conv3d_bwd_weight_xdl_ndhwgc_gkzyxc_ndhwgk_f16_pad0_pipev2_instance.cpp (100%) rename library/src/tensor_operation_instance/gpu/grouped_conv3d_bwd_weight/xdl/{ => ndhwgc_gkzyxc_ndhwgk}/device_grouped_conv3d_bwd_weight_xdl_ndhwgc_gkzyxc_ndhwgk_f16_pad0_pipev5_instance.cpp (100%) rename library/src/tensor_operation_instance/gpu/grouped_conv3d_bwd_weight/xdl/{ => ndhwgc_gkzyxc_ndhwgk}/device_grouped_conv3d_bwd_weight_xdl_ndhwgc_gkzyxc_ndhwgk_f32_default_pipev2_instance.cpp (100%) rename library/src/tensor_operation_instance/gpu/grouped_conv3d_bwd_weight/xdl/{ => ndhwgc_gkzyxc_ndhwgk}/device_grouped_conv3d_bwd_weight_xdl_ndhwgc_gkzyxc_ndhwgk_f32_default_pipev5_instance.cpp (100%) rename library/src/tensor_operation_instance/gpu/grouped_conv3d_bwd_weight/xdl/{ => ndhwgc_gkzyxc_ndhwgk}/device_grouped_conv3d_bwd_weight_xdl_ndhwgc_gkzyxc_ndhwgk_f32_instance.cpp (97%) rename library/src/tensor_operation_instance/gpu/grouped_conv3d_bwd_weight/xdl/{ => ndhwgc_gkzyxc_ndhwgk}/device_grouped_conv3d_bwd_weight_xdl_ndhwgc_gkzyxc_ndhwgk_f32_pad0_pipev2_instance.cpp (100%) rename library/src/tensor_operation_instance/gpu/grouped_conv3d_bwd_weight/xdl/{ => ndhwgc_gkzyxc_ndhwgk}/device_grouped_conv3d_bwd_weight_xdl_ndhwgc_gkzyxc_ndhwgk_f32_pad0_pipev5_instance.cpp (100%) create mode 100644 library/src/tensor_operation_instance/gpu/grouped_conv3d_bwd_weight/xdl/ngcdhw_gkczyx_ngkdhw/device_grouped_conv3d_bwd_weight_two_stage_xdl_ngcdhw_gkczyx_ngkdhw_bf16_pipev1_instance.cpp rename library/src/tensor_operation_instance/gpu/grouped_conv3d_bwd_weight/xdl/{device_grouped_conv3d_bwd_weight_two_stage_xdl_ngcdhw_gkzyxc_ngkdhw_bf16_pipev2_instance.cpp => 
ngcdhw_gkczyx_ngkdhw/device_grouped_conv3d_bwd_weight_two_stage_xdl_ngcdhw_gkczyx_ngkdhw_bf16_pipev2_instance.cpp} (89%) rename library/src/tensor_operation_instance/gpu/grouped_conv3d_bwd_weight/xdl/{device_grouped_conv3d_bwd_weight_two_stage_xdl_ngcdhw_gkzyxc_ngkdhw_bf16_pipev5_instance.cpp => ngcdhw_gkczyx_ngkdhw/device_grouped_conv3d_bwd_weight_two_stage_xdl_ngcdhw_gkczyx_ngkdhw_bf16_pipev5_instance.cpp} (89%) create mode 100644 library/src/tensor_operation_instance/gpu/grouped_conv3d_bwd_weight/xdl/ngcdhw_gkczyx_ngkdhw/device_grouped_conv3d_bwd_weight_two_stage_xdl_ngcdhw_gkczyx_ngkdhw_f16_pipev1_instance.cpp rename library/src/tensor_operation_instance/gpu/grouped_conv3d_bwd_weight/xdl/{device_grouped_conv3d_bwd_weight_two_stage_xdl_ngcdhw_gkzyxc_ngkdhw_f16_pipev2_instance.cpp => ngcdhw_gkczyx_ngkdhw/device_grouped_conv3d_bwd_weight_two_stage_xdl_ngcdhw_gkczyx_ngkdhw_f16_pipev2_instance.cpp} (89%) rename library/src/tensor_operation_instance/gpu/grouped_conv3d_bwd_weight/xdl/{device_grouped_conv3d_bwd_weight_two_stage_xdl_ngcdhw_gkzyxc_ngkdhw_f16_pipev5_instance.cpp => ngcdhw_gkczyx_ngkdhw/device_grouped_conv3d_bwd_weight_two_stage_xdl_ngcdhw_gkczyx_ngkdhw_f16_pipev5_instance.cpp} (89%) rename library/src/tensor_operation_instance/gpu/grouped_conv3d_bwd_weight/xdl/{device_grouped_conv3d_bwd_weight_xdl_ngcdhw_gkzyxc_ngkdhw_bf16_instance.cpp => ngcdhw_gkczyx_ngkdhw/device_grouped_conv3d_bwd_weight_xdl_ngcdhw_gkczyx_ngkdhw_bf16_instance.cpp} (93%) rename library/src/tensor_operation_instance/gpu/grouped_conv3d_bwd_weight/xdl/{device_grouped_conv3d_bwd_weight_xdl_ngcdhw_gkzyxc_ngkdhw_f16_instance.cpp => ngcdhw_gkczyx_ngkdhw/device_grouped_conv3d_bwd_weight_xdl_ngcdhw_gkczyx_ngkdhw_f16_instance.cpp} (93%) rename library/src/tensor_operation_instance/gpu/grouped_conv3d_bwd_weight/xdl/{device_grouped_conv3d_bwd_weight_xdl_ngcdhw_gkzyxc_ngkdhw_f32_instance.cpp => ngcdhw_gkczyx_ngkdhw/device_grouped_conv3d_bwd_weight_xdl_ngcdhw_gkczyx_ngkdhw_f32_instance.cpp} (93%) 
rename library/src/tensor_operation_instance/gpu/grouped_conv3d_bwd_weight/xdl/{ => ngcdhw_gkzyxc_ngkdhw}/device_grouped_conv3d_bwd_weight_two_stage_xdl_ngcdhw_gkzyxc_ngkdhw_bf16_pipev1_instance.cpp (95%) rename library/src/tensor_operation_instance/gpu/grouped_conv3d_bwd_weight/xdl/{ => ngcdhw_gkzyxc_ngkdhw}/device_grouped_conv3d_bwd_weight_two_stage_xdl_ngcdhw_gkzyxc_ngkdhw_f16_pipev1_instance.cpp (95%) create mode 100644 library/src/tensor_operation_instance/gpu/grouped_conv3d_bwd_weight/xdl/ngcdhw_gkzyxc_ngkdhw/device_grouped_conv3d_bwd_weight_xdl_ngcdhw_gkzyxc_ngkdhw_f32_instance.cpp diff --git a/CHANGELOG.md b/CHANGELOG.md index 8cc32e7bda..f9da2b3117 100644 --- a/CHANGELOG.md +++ b/CHANGELOG.md @@ -8,8 +8,8 @@ Documentation for Composable Kernel available at [https://rocm.docs.amd.com/proj * Added support for bf16, f32, and f16 for 2D and 3D NGCHW grouped convolution backward data * Added support for GKCYX layout for grouped convolution forward (NGCHW/GKCYX/NGKHW). +* Added support for GKCYX layout for grouped convolution backward weight (NGCHW/GKCYX/NGKHW). * Added support for GKCYX layout for grouped convolution backward data (NGCHW/GKCYX/NGKHW). -* Added support GKCYX layout for grouped convolution forward (NGCHW/GKCYX/NGKHW, number of instances in instance factory for NGCHW/GKYXC/NGKHW has been reduced). * Added support for Stream-K version of mixed fp8/bf16 GEMM ### Optimized @@ -26,6 +26,7 @@ None * Replaced the raw buffer load/store intrinsics with Clang20 built-ins (#1876) * DL and DPP kernels are now enabled by default. * Number of instances in instance factory for grouped convolution forward NGCHW/GKYXC/NGKHW has been reduced. +* Number of instances in instance factory for grouped convolution backward weight NGCHW/GKYXC/NGKHW has been reduced. * Number of instances in instance factory for grouped convolution backward data NGCHW/GKYXC/NGKHW has been reduced. 
### Known issues diff --git a/client_example/07_grouped_convnd_fwd/README.md b/client_example/07_grouped_convnd_fwd/README.md index 28a64ad733..9e96df222d 100644 --- a/client_example/07_grouped_convnd_fwd/README.md +++ b/client_example/07_grouped_convnd_fwd/README.md @@ -30,14 +30,14 @@ List of the device operations for grouped convolution forward in CK: Table of supported cases by instance factory with XDL instruction: -| |NHWGC/GKYXC/NHWGK|NGCHW/GKYXC/NGKHW|GNHWC/GKYXC/GNHWK| -|-------|---|---|---| -|bf16 |2D, 3D|2D|1D, 2D, 3D| -|fp16 |2D, 3D|2D|1D, 2D, 3D| -|fp32 |2D, 3D|2D|1D, 2D, 3D| -|int8 |2D, 3D|2D|1D, 3D| -|fp8 |3D|✗|✗| -|bf8 |3D|✗|✗| +| |NHWGC/GKYXC/NHWGK|NGCHW/GKYXC/NGKHW|NGCHW/GKCYX/NGKHW|GNHWC/GKYXC/GNHWK| +|-------|---|---|---|---| +|bf16 |2D, 3D|2D|2D|1D, 2D, 3D| +|fp16 |2D, 3D|2D|2D|1D, 2D, 3D| +|fp32 |2D, 3D|2D|2D|1D, 2D, 3D| +|int8 |2D, 3D|2D|2D|1D, 3D| +|fp8 |3D|✗|✗|✗| +|bf8 |3D|✗|✗|✗| Table of supported cases by instance factory with WMMA instruction: diff --git a/client_example/11_grouped_conv_bwd_weight/README.md b/client_example/11_grouped_conv_bwd_weight/README.md index 834fd62c8f..f1ba95e9cd 100644 --- a/client_example/11_grouped_conv_bwd_weight/README.md +++ b/client_example/11_grouped_conv_bwd_weight/README.md @@ -34,12 +34,12 @@ List of the device operations for grouped convolution backward weight in CK: Table of supported cases by instance factory with XDL instruction: -| |NHWGC/GKYXC/NHWGK|NGCHW/GKYXC/NGKHW|GNHWC/GKYXC/GNHWK| -|-------|---|---|---| -|bf16|2D, 3D|2D, 3D|✗| -|bf16(fp32 for weight)|2D, 3D|✗|1D, 2D, 3D| -|fp16 |2D, 3D|2D, 3D|1D, 2D, 3D| -|fp32 |2D, 3D|2D, 3D|1D, 2D, 3D| +| |NHWGC/GKYXC/NHWGK|NGCHW/GKYXC/NGKHW|NGCHW/GKCYX/NGKHW|GNHWC/GKYXC/GNHWK| +|-------|---|---|---|---| +|bf16|2D, 3D|2D, 3D|2D, 3D|✗| +|bf16(fp32 for weight)|2D, 3D|✗|✗|1D, 2D, 3D| +|fp16 |2D, 3D|2D, 3D|2D, 3D|1D, 2D, 3D| +|fp32 |2D, 3D|2D, 3D|2D, 3D|1D, 2D, 3D| Table of supported cases by instance factory with WMMA instruction: diff --git 
a/include/ck/tensor_operation/gpu/device/impl/device_grouped_conv_bwd_weight_two_stage_xdl_cshuffle.hpp b/include/ck/tensor_operation/gpu/device/impl/device_grouped_conv_bwd_weight_two_stage_xdl_cshuffle.hpp index 033b84aafc..4d730b1f37 100644 --- a/include/ck/tensor_operation/gpu/device/impl/device_grouped_conv_bwd_weight_two_stage_xdl_cshuffle.hpp +++ b/include/ck/tensor_operation/gpu/device/impl/device_grouped_conv_bwd_weight_two_stage_xdl_cshuffle.hpp @@ -218,8 +218,8 @@ struct DeviceGroupedConvBwdWeightTwoStage_Xdl_CShuffle using EDataType = WeiDataType; // If NGCHW then ADataType must be equal to BDataType - static_assert(!(is_NGCHW_GKYXC_NGKHW() || - is_NGCDHW_GKZYXC_NGKDHW()) || + static_assert(!(is_NGCHW_NGKHW() || + is_NGCDHW_NGKDHW()) || is_same_v); using AElementwiseOperation = OutElementwiseOperation; @@ -376,6 +376,12 @@ struct DeviceGroupedConvBwdWeightTwoStage_Xdl_CShuffle using NHWGCTransposeDescType = remove_cvref_t({}, {}))>; + using GKCYXTransposeDescType = + remove_cvref_t({}, {}))>; + using GKYXCTransposeDescType = + remove_cvref_t({}, {}))>; using ABCGridDescs = decltype(GetABCGridDesc()); @@ -452,6 +458,28 @@ struct DeviceGroupedConvBwdWeightTwoStage_Xdl_CShuffle Sequence, I1, I1>; + // NPerBlock is used for the first dim which is store dimension + // (with CBlockTransferScalarPerVector_NWaveNPerXdl scalar per vector). + // CBlockTransferScalarPerVector_NWaveNPerXdl is aligned to NPerBlock so + // it is more flexible to use this dim for store dimension with such scalar + // per vector. 
+ using GridwiseElementwiseWeightTransposeCast = + GridwiseElementwise, + Tuple, + Tuple, + Tuple, + Block2TileMapElementwise, + CDEElementwiseOperation, + BlockSize, + MPerBlock, + NPerBlock, + MPerBlock / ClusterLengthMPerBlock, + NPerBlock / ClusterLengthNPerBlock, + Sequence<0, 1>, + Sequence, + Sequence<1>, + I1, + I0>; using GridwiseElementwiseTranspose = GridwiseElementwise, @@ -533,12 +561,15 @@ struct DeviceGroupedConvBwdWeightTwoStage_Xdl_CShuffle end(a_g_n_k_wos_lengths), begin(output_spatial_lengths_)); - std::array b_g_n_c_wis_strides_transposed = - conv_ngchw_to_nhwgc_transformer.TransposeInOutStrides(b_g_n_c_wis_lengths, - b_g_n_c_wis_strides); std::array a_g_n_k_wos_strides_transposed = conv_ngchw_to_nhwgc_transformer.TransposeInOutStrides(a_g_n_k_wos_lengths, a_g_n_k_wos_strides); + std::array b_g_n_c_wis_strides_transposed = + conv_ngchw_to_nhwgc_transformer.TransposeInOutStrides(b_g_n_c_wis_lengths, + b_g_n_c_wis_strides); + std::array e_g_k_c_xs_strides_transposed = + conv_ngchw_to_nhwgc_transformer.TransposeWeiStrides(e_g_k_c_xs_lengths, + e_g_k_c_xs_strides); const auto descs = conv_to_gemm_transformer_v2 @@ -550,7 +581,7 @@ struct DeviceGroupedConvBwdWeightTwoStage_Xdl_CShuffle filter_spatial_lengths_, output_spatial_lengths_, b_g_n_c_wis_strides_transposed, - e_g_k_c_xs_strides, + e_g_k_c_xs_strides_transposed, a_g_n_k_wos_strides_transposed, conv_filter_strides, conv_filter_dilations, @@ -580,29 +611,21 @@ struct DeviceGroupedConvBwdWeightTwoStage_Xdl_CShuffle input_right_pads, k_batch_)[I2]; - elementwise_block_2_ctile_map_ = Block2TileMapElementwise{ - ce_grid_desc_m_n_.GetLength(I0), ce_grid_desc_m_n_.GetLength(I1)}; - const index_t GemmM = a_grid_desc_k0_m_k1_.GetLength(I1); const index_t GemmN = b_grid_desc_k0_n_k1_.GetLength(I1); // A/B/C Batch Stride compute_ptr_offset_of_batch_.BatchStrideA_ = a_g_n_k_wos_strides_transposed[0]; compute_ptr_offset_of_batch_.BatchStrideB_ = b_g_n_c_wis_strides_transposed[0]; - 
compute_ptr_offset_of_batch_.BatchStrideC_ = - Conv_K_ * Conv_C_ * - std::accumulate(begin(filter_spatial_lengths_), - end(filter_spatial_lengths_), - index_t{1}, - std::multiplies<>{}); + compute_ptr_offset_of_batch_.BatchStrideC_ = e_g_k_c_xs_strides_transposed[0]; c_grid_desc_mblock_mperblock_nblock_nperblock_ = GridwiseGemm::MakeCGridDescriptor_MBlock_MPerBlock_NBlock_NPerBlock( ce_grid_desc_m_n_, GridwiseGemm::CalculateMBlock(GemmM), GridwiseGemm::CalculateNBlock(GemmN)); - if constexpr(is_NGCHW_GKYXC_NGKHW() || - is_NGCDHW_GKZYXC_NGKDHW()) + if constexpr(is_NGCHW_NGKHW() || + is_NGCDHW_NGKDHW()) { a_in_transpose_desc_ = conv_ngchw_to_nhwgc_transformer.template MakeNGCHWTransposeDesc( @@ -618,17 +641,35 @@ struct DeviceGroupedConvBwdWeightTwoStage_Xdl_CShuffle conv_ngchw_to_nhwgc_transformer.template MakeNHWGCTransposeDesc( b_g_n_c_wis_lengths, b_g_n_c_wis_strides); + e_in_transpose_desc_ = + conv_ngchw_to_nhwgc_transformer.template MakeGKYXCTransposeDesc( + e_g_k_c_xs_lengths, e_g_k_c_xs_strides); + e_out_transpose_desc_ = + conv_ngchw_to_nhwgc_transformer.template MakeGKCYXTransposeDesc( + e_g_k_c_xs_lengths, e_g_k_c_xs_strides); + elementwise_block_2_ctile_map_transpose_a_ = Block2TileMapElementwise{ a_in_transpose_desc_.GetLength(I0), a_in_transpose_desc_.GetLength(I1)}; elementwise_block_2_ctile_map_transpose_b_ = Block2TileMapElementwise{ b_in_transpose_desc_.GetLength(I0), b_in_transpose_desc_.GetLength(I1)}; } + + elementwise_block_2_ctile_map_ = + is_NGCHW_GKCYX_NGKHW() || + is_NGCDHW_GKCZYX_NGKDHW() + ? 
Block2TileMapElementwise{e_in_transpose_desc_.GetLength(I0), + e_in_transpose_desc_.GetLength(I1)} + : Block2TileMapElementwise{ce_grid_desc_m_n_.GetLength(I0), + ce_grid_desc_m_n_.GetLength(I1)}; } std::size_t GetWorkspaceATensorSizeBytes() const { - return sizeof(ADataType) * a_in_transpose_desc_.GetElementSpaceSize(); + // Align to 128B + return math::integer_divide_ceil( + sizeof(ADataType) * a_in_transpose_desc_.GetElementSpaceSize(), 128) * + 128; } std::size_t GetWorkspaceBTensorSizeBytes() const @@ -638,14 +679,23 @@ struct DeviceGroupedConvBwdWeightTwoStage_Xdl_CShuffle std::size_t GetWorkspaceETensorSizeBytes() const { - return sizeof(AccDataType) * ce_grid_desc_m_n_.GetElementSpaceSize() * Conv_G_; + // Align to 128B + return math::integer_divide_ceil(sizeof(AccDataType) * + ce_grid_desc_m_n_.GetElementSpaceSize() * Conv_G_, + 128) * + 128; } std::size_t GetWorkspaceSizeBytes() const { - // Transpose require workspace for A and B - if constexpr(is_NGCHW_GKYXC_NGKHW() || - is_NGCDHW_GKZYXC_NGKDHW()) + // 1. We need to transpose A and B for NGCHW and NGKHW layouts + // 2. If C format is GKCYX then transpose during second stage. + // If C format is GKYXC then just perform second stage. + // Due to the fact that E workspace is always needed, we + // allocate them as the first part of the workspace.
+ // [EWorkspace, AWorkspace, BWorkspace] + if constexpr(is_NGCHW_NGKHW() || + is_NGCDHW_NGKDHW()) { return GetWorkspaceATensorSizeBytes() + GetWorkspaceBTensorSizeBytes() + GetWorkspaceETensorSizeBytes(); @@ -672,6 +722,8 @@ struct DeviceGroupedConvBwdWeightTwoStage_Xdl_CShuffle NGCHWTransposeDescType a_in_transpose_desc_, b_in_transpose_desc_; NHWGCTransposeDescType a_out_transpose_desc_, b_out_transpose_desc_; + GKYXCTransposeDescType e_in_transpose_desc_; + GKCYXTransposeDescType e_out_transpose_desc_; // for computing batch offset ComputePtrOffsetOfStridedBatch compute_ptr_offset_of_batch_; @@ -728,11 +780,11 @@ struct DeviceGroupedConvBwdWeightTwoStage_Xdl_CShuffle const ADataType* p_a_grid = arg.p_a_grid_; const BDataType* p_b_grid = arg.p_b_grid_; - if constexpr(is_NGCHW_GKYXC_NGKHW() || - is_NGCDHW_GKZYXC_NGKDHW()) + if constexpr(is_NGCHW_NGKHW() || + is_NGCDHW_NGKDHW()) { p_a_grid = type_convert(arg.p_workspace_) + - arg.GetWorkspaceETensorSizeBytes() / sizeof(BDataType); + arg.GetWorkspaceETensorSizeBytes() / sizeof(ADataType); p_b_grid = type_convert(arg.p_workspace_) + (arg.GetWorkspaceETensorSizeBytes() + arg.GetWorkspaceATensorSizeBytes()) / @@ -1373,41 +1425,72 @@ struct DeviceGroupedConvBwdWeightTwoStage_Xdl_CShuffle float avg_time = 0.f; auto launch_elementwise_kernel = [&]() { const AccDataType* p_c_grid = type_convert(arg.p_workspace_); - const index_t grid_size = arg.elementwise_block_2_ctile_map_.CalculateGridSize( - arg.ce_elementwise_grid_desc_m_n_) * - arg.Conv_G_; std::array in_out_batch_strides = { static_cast(arg.compute_ptr_offset_of_batch_.BatchStrideC_)}; - const auto kernel = kernel_batched_elementwise, - ck::Tuple, - ck::Tuple, - ck::Tuple, - Block2TileMapElementwise, - CDEElementwiseOperation, - I1, - I1>; + if constexpr(is_NGCHW_GKCYX_NGKHW() || + is_NGCDHW_GKCZYX_NGKDHW()) + { + const index_t grid_size = arg.elementwise_block_2_ctile_map_.CalculateGridSize( + arg.e_in_transpose_desc_); - return 
launch_and_time_kernel(stream_config, - kernel, - dim3(grid_size), - dim3(BlockSize), - 0, - make_tuple(arg.ce_elementwise_grid_desc_m_n_), - make_tuple(arg.ce_elementwise_grid_desc_m_n_), - make_tuple(p_c_grid), - make_tuple(arg.p_e_grid_), - arg.elementwise_block_2_ctile_map_, - arg.cde_element_op_, - arg.Conv_G_, - in_out_batch_strides, - in_out_batch_strides); + const auto kernel = kernel_elementwise, + ck::Tuple, + ck::Tuple, + ck::Tuple, + Block2TileMapElementwise, + CDEElementwiseOperation>; + + return launch_and_time_kernel(stream_config, + kernel, + dim3(grid_size), + dim3(BlockSize), + 0, + make_tuple(arg.e_in_transpose_desc_), + make_tuple(arg.e_out_transpose_desc_), + make_tuple(p_c_grid), + make_tuple(arg.p_e_grid_), + arg.elementwise_block_2_ctile_map_, + arg.cde_element_op_); + } + else + { + const index_t grid_size = arg.elementwise_block_2_ctile_map_.CalculateGridSize( + arg.ce_elementwise_grid_desc_m_n_) * + arg.Conv_G_; + + const auto kernel = + kernel_batched_elementwise, + ck::Tuple, + ck::Tuple, + ck::Tuple, + Block2TileMapElementwise, + CDEElementwiseOperation, + I1, + I1>; + + return launch_and_time_kernel(stream_config, + kernel, + dim3(grid_size), + dim3(BlockSize), + 0, + make_tuple(arg.ce_elementwise_grid_desc_m_n_), + make_tuple(arg.ce_elementwise_grid_desc_m_n_), + make_tuple(p_c_grid), + make_tuple(arg.p_e_grid_), + arg.elementwise_block_2_ctile_map_, + arg.cde_element_op_, + arg.Conv_G_, + in_out_batch_strides, + in_out_batch_strides); + } }; - if constexpr(is_NGCHW_GKYXC_NGKHW() || - is_NGCDHW_GKZYXC_NGKDHW()) + if constexpr(is_NGCHW_NGKHW() || + is_NGCDHW_NGKDHW()) { const index_t grid_size_a = arg.elementwise_block_2_ctile_map_transpose_a_.CalculateGridSize( @@ -1417,7 +1500,7 @@ struct DeviceGroupedConvBwdWeightTwoStage_Xdl_CShuffle arg.b_in_transpose_desc_); ADataType* p_a_out_grid = type_convert(arg.p_workspace_) + - arg.GetWorkspaceETensorSizeBytes() / sizeof(BDataType); + arg.GetWorkspaceETensorSizeBytes() / 
sizeof(ADataType); BDataType* p_b_out_grid = type_convert(arg.p_workspace_) + (arg.GetWorkspaceETensorSizeBytes() + arg.GetWorkspaceATensorSizeBytes()) / @@ -1514,7 +1597,7 @@ struct DeviceGroupedConvBwdWeightTwoStage_Xdl_CShuffle if constexpr(NDimSpatial == 2) { if constexpr(!(is_NHWGC_GKYXC_NHWGK() || - is_NGCHW_GKYXC_NGKHW())) + is_NGCHW_NGKHW())) { return false; } @@ -1522,7 +1605,7 @@ struct DeviceGroupedConvBwdWeightTwoStage_Xdl_CShuffle else if constexpr(NDimSpatial == 3) { if constexpr(!(is_NDHWGC_GKZYXC_NDHWGK() || - is_NGCDHW_GKZYXC_NGKDHW())) + is_NGCDHW_NGKDHW())) { return false; } @@ -1597,8 +1680,8 @@ struct DeviceGroupedConvBwdWeightTwoStage_Xdl_CShuffle return false; } - if constexpr(is_NGCHW_GKYXC_NGKHW() || - is_NGCDHW_GKZYXC_NGKDHW()) + if constexpr(is_NGCHW_NGKHW() || + is_NGCDHW_NGKDHW()) { if((arg.Conv_G_ * arg.Conv_C_) % TransposeTransferDstScalarPerVector != 0) { @@ -1767,8 +1850,8 @@ struct DeviceGroupedConvBwdWeightTwoStage_Xdl_CShuffle << BlkGemmPipelineVersionToString[BlkGemmPipelineVer] << ", " << NumGroupsToMerge; - if constexpr(is_NGCHW_GKYXC_NGKHW() || - is_NGCDHW_GKZYXC_NGKDHW()) { + if constexpr(is_NGCHW_NGKHW() || + is_NGCDHW_NGKDHW()) { str << ", TransposeTransferSrcScalarPerVector: " << TransposeTransferSrcScalarPerVector <<", " << "TransposeTransferDstScalarPerVector: " << TransposeTransferDstScalarPerVector; diff --git a/include/ck/tensor_operation/gpu/device/impl/device_grouped_conv_bwd_weight_xdl_cshuffle.hpp b/include/ck/tensor_operation/gpu/device/impl/device_grouped_conv_bwd_weight_xdl_cshuffle.hpp index 6d2a354ce3..f40b238c8a 100644 --- a/include/ck/tensor_operation/gpu/device/impl/device_grouped_conv_bwd_weight_xdl_cshuffle.hpp +++ b/include/ck/tensor_operation/gpu/device/impl/device_grouped_conv_bwd_weight_xdl_cshuffle.hpp @@ -165,8 +165,8 @@ struct DeviceGroupedConvBwdWeight_Xdl_CShuffle using CDataType = WeiDataType; // If NGCHW then ADataType must be equal to BDataType - static_assert(!(is_NGCHW_GKYXC_NGKHW() || - 
is_NGCDHW_GKZYXC_NGKDHW()) || + static_assert(!(is_NGCHW_NGKHW() || + is_NGCDHW_NGKDHW()) || is_same_v); using AElementwiseOperation = OutElementwiseOperation; @@ -301,7 +301,7 @@ struct DeviceGroupedConvBwdWeight_Xdl_CShuffle MPerBlock / ClusterLengthMPerBlock, NPerBlock / ClusterLengthNPerBlock>{}; - using Block2TileMapElementwise = BlockToCTileMap_M00_N0_M01Adapt; + using Block2TileMapTranspose = BlockToCTileMap_M00_N0_M01Adapt; static constexpr index_t TransposeTransferSrcScalarPerVectorAligned = std::min(NPerBlock / ClusterLengthNPerBlock, MaxTransposeTransferSrcScalarPerVector); @@ -314,13 +314,19 @@ struct DeviceGroupedConvBwdWeight_Xdl_CShuffle using NHWGCTransposeDescType = remove_cvref_t({}, {}))>; + using GKCYXTransposeDescType = + remove_cvref_t({}, {}))>; + using GKYXCTransposeDescType = + remove_cvref_t({}, {}))>; - using GridwiseElementwiseTranspose = + using GridwiseInOutTranspose = GridwiseElementwise, Tuple, Tuple, Tuple, - Block2TileMapElementwise, + Block2TileMapTranspose, element_wise::PassThrough, BlockSize, MPerBlock, @@ -333,6 +339,26 @@ struct DeviceGroupedConvBwdWeight_Xdl_CShuffle I1, I0>; + // NPerBlock is used for the first dim which is store dimension + // (with CBlockTransferScalarPerVector_NWaveNPerXdl scalar per vector). 
+ using GridwiseElementwiseWeightTranspose = + GridwiseElementwise, + Tuple, + Tuple, + Tuple, + Block2TileMapTranspose, + element_wise::PassThrough, + BlockSize, + MPerBlock, + NPerBlock, + MPerBlock / ClusterLengthMPerBlock, + NPerBlock / ClusterLengthNPerBlock, + Sequence<1, 0>, + Sequence, + Sequence<1>, + I1, + I0>; + using GridwiseGemm = GridwiseGemm_bk0mk1_bk0nk1_mn_xdlops_bwd_weight< BlockSize, ADataType, @@ -452,13 +478,15 @@ struct DeviceGroupedConvBwdWeight_Xdl_CShuffle end(a_g_n_k_wos_lengths), begin(output_spatial_lengths_)); - std::array b_g_n_c_wis_strides_transposed = - conv_ngchw_to_nhwgc_transformer.TransposeInOutStrides(b_g_n_c_wis_lengths, - b_g_n_c_wis_strides); std::array a_g_n_k_wos_strides_transposed = conv_ngchw_to_nhwgc_transformer.TransposeInOutStrides(a_g_n_k_wos_lengths, a_g_n_k_wos_strides); - + std::array b_g_n_c_wis_strides_transposed = + conv_ngchw_to_nhwgc_transformer.TransposeInOutStrides(b_g_n_c_wis_lengths, + b_g_n_c_wis_strides); + std::array e_g_k_c_xs_strides_transposed = + conv_ngchw_to_nhwgc_transformer.TransposeWeiStrides(e_g_k_c_xs_lengths, + e_g_k_c_xs_strides); const auto descs = conv_to_gemm_transformer .template MakeABCGridDescriptor_A_K0_M_K1_B_K0_N_K1_C_M_N( @@ -469,7 +497,7 @@ struct DeviceGroupedConvBwdWeight_Xdl_CShuffle filter_spatial_lengths_, output_spatial_lengths_, b_g_n_c_wis_strides_transposed, - e_g_k_c_xs_strides, + e_g_k_c_xs_strides_transposed, a_g_n_k_wos_strides_transposed, conv_filter_strides, conv_filter_dilations, @@ -487,12 +515,7 @@ struct DeviceGroupedConvBwdWeight_Xdl_CShuffle // A/B/C Batch Stride compute_ptr_offset_of_batch_.BatchStrideA_ = a_g_n_k_wos_strides_transposed[0]; compute_ptr_offset_of_batch_.BatchStrideB_ = b_g_n_c_wis_strides_transposed[0]; - compute_ptr_offset_of_batch_.BatchStrideC_ = - Conv_K_ * Conv_C_ * - std::accumulate(begin(filter_spatial_lengths_), - end(filter_spatial_lengths_), - index_t{1}, - std::multiplies<>{}); + compute_ptr_offset_of_batch_.BatchStrideC_ = 
e_g_k_c_xs_strides_transposed[0]; if(GridwiseGemm::CheckValidity(a_grid_desc_kbatch_k0_m_k1_, b_grid_desc_kbatch_k0_n_k1_, @@ -503,8 +526,8 @@ struct DeviceGroupedConvBwdWeight_Xdl_CShuffle GridwiseGemm::MakeCGridDesc_MBlock_MPerBlock_NBlock_NPerBlock(c_grid_desc_m_n_); } - if constexpr(is_NGCHW_GKYXC_NGKHW() || - is_NGCDHW_GKZYXC_NGKDHW()) + if constexpr(is_NGCHW_NGKHW() || + is_NGCDHW_NGKDHW()) { a_in_transpose_desc_ = conv_ngchw_to_nhwgc_transformer.template MakeNGCHWTransposeDesc( @@ -520,31 +543,33 @@ struct DeviceGroupedConvBwdWeight_Xdl_CShuffle conv_ngchw_to_nhwgc_transformer.template MakeNHWGCTransposeDesc( b_g_n_c_wis_lengths, b_g_n_c_wis_strides); - elementwise_block_2_ctile_map_transpose_a_ = Block2TileMapElementwise{ + e_in_transpose_desc_ = + conv_ngchw_to_nhwgc_transformer.template MakeGKYXCTransposeDesc( + e_g_k_c_xs_lengths, e_g_k_c_xs_strides); + e_out_transpose_desc_ = + conv_ngchw_to_nhwgc_transformer.template MakeGKCYXTransposeDesc( + e_g_k_c_xs_lengths, e_g_k_c_xs_strides); + + elementwise_block_2_ctile_map_transpose_a_ = Block2TileMapTranspose{ a_in_transpose_desc_.GetLength(I0), a_in_transpose_desc_.GetLength(I1)}; - elementwise_block_2_ctile_map_transpose_b_ = Block2TileMapElementwise{ + elementwise_block_2_ctile_map_transpose_b_ = Block2TileMapTranspose{ b_in_transpose_desc_.GetLength(I0), b_in_transpose_desc_.GetLength(I1)}; + + elementwise_block_2_ctile_map_transpose_e_ = Block2TileMapTranspose{ + e_in_transpose_desc_.GetLength(I0), e_in_transpose_desc_.GetLength(I1)}; } } std::size_t GetWorkspaceATensorSizeBytes() const { - return sizeof(ADataType) * a_in_transpose_desc_.GetElementSpaceSize(); - } - - std::size_t GetWorkspaceBTensorSizeBytes() const - { - return sizeof(BDataType) * b_in_transpose_desc_.GetElementSpaceSize(); - } - - std::size_t GetWorkspaceSizeBytes() const - { - // Transpose require workspace for A and B - if constexpr(is_NGCHW_GKYXC_NGKHW() || - is_NGCDHW_GKZYXC_NGKDHW()) + if constexpr(is_NGCHW_NGKHW() || + 
is_NGCDHW_NGKDHW()) { - return GetWorkspaceATensorSizeBytes() + GetWorkspaceBTensorSizeBytes(); + // Align to 128B + return math::integer_divide_ceil( + sizeof(ADataType) * a_in_transpose_desc_.GetElementSpaceSize(), 128) * + 128; } else { @@ -552,6 +577,41 @@ struct DeviceGroupedConvBwdWeight_Xdl_CShuffle } } + std::size_t GetWorkspaceBTensorSizeBytes() const + { + if constexpr(is_NGCHW_NGKHW() || + is_NGCDHW_NGKDHW()) + { + // Align to 128B + return math::integer_divide_ceil( + sizeof(BDataType) * b_in_transpose_desc_.GetElementSpaceSize(), 128) * + 128; + } + else + { + return 0; + } + } + + std::size_t GetWorkspaceETensorSizeBytes() const + { + if constexpr(is_NGCHW_GKCYX_NGKHW() || + is_NGCDHW_GKCZYX_NGKDHW()) + { + return sizeof(CDataType) * e_in_transpose_desc_.GetElementSpaceSize(); + } + else + { + return 0; + } + } + + std::size_t GetWorkspaceSizeBytes() const + { + return GetWorkspaceATensorSizeBytes() + GetWorkspaceBTensorSizeBytes() + + GetWorkspaceETensorSizeBytes(); + } + const ADataType* p_a_grid_; const BDataType* p_b_grid_; CDataType* p_c_grid_; @@ -562,12 +622,15 @@ struct DeviceGroupedConvBwdWeight_Xdl_CShuffle Block2CTileMap block_2_ctile_map_; - Block2TileMapElementwise elementwise_block_2_ctile_map_transpose_a_, - elementwise_block_2_ctile_map_transpose_b_; + Block2TileMapTranspose elementwise_block_2_ctile_map_transpose_a_, + elementwise_block_2_ctile_map_transpose_b_, elementwise_block_2_ctile_map_transpose_e_; NGCHWTransposeDescType a_in_transpose_desc_, b_in_transpose_desc_; NHWGCTransposeDescType a_out_transpose_desc_, b_out_transpose_desc_; + GKYXCTransposeDescType e_in_transpose_desc_; + GKCYXTransposeDescType e_out_transpose_desc_; + // for computing batch offset ComputePtrOffsetOfStridedBatch<> compute_ptr_offset_of_batch_; @@ -621,9 +684,19 @@ struct DeviceGroupedConvBwdWeight_Xdl_CShuffle const ADataType* p_a_grid = arg.p_a_grid_; const BDataType* p_b_grid = arg.p_b_grid_; + CDataType* p_e_grid = arg.p_c_grid_; - if 
constexpr(is_NGCHW_GKYXC_NGKHW() || - is_NGCDHW_GKZYXC_NGKDHW()) + if constexpr(is_NGCHW_GKCYX_NGKHW() || + is_NGCDHW_GKCZYX_NGKDHW()) + { + p_e_grid = + type_convert(arg.p_workspace_) + + (arg.GetWorkspaceATensorSizeBytes() + arg.GetWorkspaceBTensorSizeBytes()) / + sizeof(CDataType); + } + + if constexpr(is_NGCHW_NGKHW() || + is_NGCDHW_NGKDHW()) { const index_t grid_size_a = arg.elementwise_block_2_ctile_map_transpose_a_.CalculateGridSize( @@ -640,8 +713,8 @@ struct DeviceGroupedConvBwdWeight_Xdl_CShuffle arg.GetWorkspaceATensorSizeBytes() / sizeof(BDataType); // Different data type for A and B is not supported - auto kernel_transpose = kernel_elementwise_dual, ck::Tuple, ck::Tuple, @@ -650,8 +723,8 @@ struct DeviceGroupedConvBwdWeight_Xdl_CShuffle ck::Tuple, ck::Tuple, ck::Tuple, - Block2TileMapElementwise, - Block2TileMapElementwise, + Block2TileMapTranspose, + Block2TileMapTranspose, element_wise::PassThrough>; avg_time += launch_and_time_kernel(stream_config, @@ -698,24 +771,36 @@ struct DeviceGroupedConvBwdWeight_Xdl_CShuffle ComputePtrOffsetOfStridedBatch<>, has_main_loop>; - avg_time += - launch_and_time_kernel(stream_config, - kernel, - dim3(grid_size), - dim3(BlockSize), - 0, - p_a_grid, - p_b_grid, - arg.p_c_grid_, - arg.a_element_op_, - arg.b_element_op_, - arg.c_element_op_, - arg.Conv_G_, - arg.a_grid_desc_kbatch_k0_m_k1_, - arg.b_grid_desc_kbatch_k0_n_k1_, - arg.c_grid_desc_mblock_mperblock_nblock_nperblock_, - arg.block_2_ctile_map_, - arg.compute_ptr_offset_of_batch_); + const auto clear_workspace = [&]() { + if constexpr(is_NGCHW_GKCYX_NGKHW() || + is_NGCDHW_GKCZYX_NGKDHW()) + { + hip_check_error(hipMemsetAsync(p_e_grid, + 0, + arg.GetWorkspaceETensorSizeBytes(), + stream_config.stream_id_)); + } + }; + + avg_time += launch_and_time_kernel_with_preprocess( + stream_config, + clear_workspace, + kernel, + dim3(grid_size), + dim3(BlockSize), + 0, + p_a_grid, + p_b_grid, + p_e_grid, + arg.a_element_op_, + arg.b_element_op_, + arg.c_element_op_, + 
arg.Conv_G_, + arg.a_grid_desc_kbatch_k0_m_k1_, + arg.b_grid_desc_kbatch_k0_n_k1_, + arg.c_grid_desc_mblock_mperblock_nblock_nperblock_, + arg.block_2_ctile_map_, + arg.compute_ptr_offset_of_batch_); }; if(has_main_k0_block_loop) @@ -726,6 +811,38 @@ struct DeviceGroupedConvBwdWeight_Xdl_CShuffle { launch_kernel(integral_constant{}); } + + if constexpr(is_NGCHW_GKCYX_NGKHW() || + is_NGCDHW_GKCZYX_NGKDHW()) + { + const index_t grid_size_e = + arg.elementwise_block_2_ctile_map_transpose_e_.CalculateGridSize( + arg.e_in_transpose_desc_); + + const CDataType* p_e_in_grid = static_cast(p_e_grid); + + // Different data type for A and B is not supported + auto kernel_transpose = kernel_elementwise, + ck::Tuple, + ck::Tuple, + ck::Tuple, + Block2TileMapTranspose, + element_wise::PassThrough>; + + avg_time += launch_and_time_kernel(stream_config, + kernel_transpose, + dim3(grid_size_e), + dim3(BlockSize), + 0, + make_tuple(arg.e_in_transpose_desc_), + make_tuple(arg.e_out_transpose_desc_), + make_tuple(p_e_in_grid), + make_tuple(arg.p_c_grid_), + arg.elementwise_block_2_ctile_map_transpose_e_, + element_wise::PassThrough{}); + } + return avg_time; } @@ -763,7 +880,7 @@ struct DeviceGroupedConvBwdWeight_Xdl_CShuffle { if constexpr(!(is_NHWGC_GKYXC_NHWGK() || is_GNHWC_GKYXC_GNHWK() || - is_NGCHW_GKYXC_NGKHW())) + is_NGCHW_NGKHW())) { return false; } @@ -772,7 +889,7 @@ struct DeviceGroupedConvBwdWeight_Xdl_CShuffle { if constexpr(!(is_NDHWGC_GKZYXC_NDHWGK() || is_GNDHWC_GKZYXC_GNDHWK() || - is_NGCDHW_GKZYXC_NGKDHW())) + is_NGCDHW_NGKDHW())) { return false; } @@ -810,8 +927,8 @@ struct DeviceGroupedConvBwdWeight_Xdl_CShuffle return false; } - if constexpr(is_NGCHW_GKYXC_NGKHW() || - is_NGCDHW_GKZYXC_NGKDHW()) + if constexpr(is_NGCHW_NGKHW() || + is_NGCDHW_NGKDHW()) { if((arg.Conv_G_ * arg.Conv_C_) % TransposeTransferDstScalarPerVectorAligned != 0) { @@ -980,8 +1097,8 @@ struct DeviceGroupedConvBwdWeight_Xdl_CShuffle << CShuffleNXdlPerWavePerShuffle << ", " << 
CBlockTransferScalarPerVector_NWaveNPerXdl; - if constexpr(is_NGCHW_GKYXC_NGKHW() || - is_NGCDHW_GKZYXC_NGKDHW()) { + if constexpr(is_NGCHW_NGKHW() || + is_NGCDHW_NGKDHW()) { str << ", TransposeTransferSrcScalarPerVectorAligned: " << TransposeTransferSrcScalarPerVectorAligned <<", " << "TransposeTransferDstScalarPerVectorAligned: " << TransposeTransferDstScalarPerVectorAligned; diff --git a/include/ck/tensor_operation/gpu/device/impl/device_grouped_conv_fwd_multiple_abd_xdl_cshuffle.hpp b/include/ck/tensor_operation/gpu/device/impl/device_grouped_conv_fwd_multiple_abd_xdl_cshuffle.hpp index 69913163f0..272b832e11 100644 --- a/include/ck/tensor_operation/gpu/device/impl/device_grouped_conv_fwd_multiple_abd_xdl_cshuffle.hpp +++ b/include/ck/tensor_operation/gpu/device/impl/device_grouped_conv_fwd_multiple_abd_xdl_cshuffle.hpp @@ -502,6 +502,11 @@ struct DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle static constexpr index_t ElementwiseBlocksize = ClusterLengthNPerBlock * ClusterLengthNPerBlock; + // NPerBlock is used for the first and second dim which to use + // CDEBlockTransferScalarPerVector_NPerBlock for load and store during + // transposition. CBlockTransferScalarPerVector_NWaveNPerXdl is aligned to + // NPerBlock so it is more flexible to use this dim for load store dimension + // with such scalar per vector. using GridwiseElementwiseInputTranspose = GridwiseElementwise, Tuple, diff --git a/include/ck/tensor_operation/operator_transform/transform_conv_ngchw_to_nhwgc.hpp b/include/ck/tensor_operation/operator_transform/transform_conv_ngchw_to_nhwgc.hpp index 7bf52cb229..0f28fe8169 100644 --- a/include/ck/tensor_operation/operator_transform/transform_conv_ngchw_to_nhwgc.hpp +++ b/include/ck/tensor_operation/operator_transform/transform_conv_ngchw_to_nhwgc.hpp @@ -12,6 +12,15 @@ namespace ck { namespace tensor_operation { +/* + * Transform Convolution NGCHW to NHWGC. We transform [N, G, C, H, W] tensor + * descriptor to [N * G * C, H * W] (input or output image). 
The first + * dimension is store dimension, the second one is load dimension. For + * NHWGC to NGCHW load and store are reverted. For weight we transform + * [G, K, C, Y, X] to [G * K * Y * X, C]. First dim is load dimension, + * second dim is store dimension. + */ + template && is_same_v && + is_same_v) + { +#ifdef CK_ENABLE_FP16 + if constexpr(is_same_v && is_same_v && + is_same_v && is_same_v && + is_same_v) + { + add_device_grouped_conv2d_bwd_weight_two_stage_xdl_ngchw_gkcyx_ngkhw_f16_pipev1_instances( + op_ptrs); + add_device_grouped_conv2d_bwd_weight_two_stage_xdl_ngchw_gkcyx_ngkhw_f16_pipev2_instances( + op_ptrs); + add_device_grouped_conv2d_bwd_weight_two_stage_xdl_ngchw_gkcyx_ngkhw_f16_pipev5_instances( + op_ptrs); + add_device_grouped_conv2d_bwd_weight_xdl_ngchw_gkcyx_ngkhw_f16_instances( + op_ptrs); + } +#endif +#ifdef CK_ENABLE_BF16 + if constexpr(is_same_v && + is_same_v && + is_same_v && + is_same_v && + is_same_v) + { + add_device_grouped_conv2d_bwd_weight_two_stage_xdl_ngchw_gkcyx_ngkhw_bf16_pipev1_instances( + op_ptrs); + add_device_grouped_conv2d_bwd_weight_two_stage_xdl_ngchw_gkcyx_ngkhw_bf16_pipev2_instances( + op_ptrs); + add_device_grouped_conv2d_bwd_weight_two_stage_xdl_ngchw_gkcyx_ngkhw_bf16_pipev5_instances( + op_ptrs); + add_device_grouped_conv2d_bwd_weight_xdl_ngchw_gkcyx_ngkhw_bf16_instances( + op_ptrs); + } +#endif +#ifdef CK_ENABLE_FP32 + if constexpr(is_same_v && is_same_v && + is_same_v && is_same_v && + is_same_v) + { + add_device_grouped_conv2d_bwd_weight_xdl_ngchw_gkcyx_ngkhw_f32_instances( + op_ptrs); + } #endif } if constexpr(is_same_v && is_same_v && @@ -443,12 +488,6 @@ struct DeviceOperationInstanceFactory && is_same_v && + is_same_v) + { +#ifdef CK_ENABLE_FP16 + if constexpr(is_same_v && is_same_v && + is_same_v && is_same_v && + is_same_v) + { + add_device_grouped_conv3d_bwd_weight_two_stage_xdl_ngcdhw_gkczyx_ngkdhw_f16_pipev1_instances( + op_ptrs); + 
add_device_grouped_conv3d_bwd_weight_two_stage_xdl_ngcdhw_gkczyx_ngkdhw_f16_pipev2_instances( + op_ptrs); + add_device_grouped_conv3d_bwd_weight_two_stage_xdl_ngcdhw_gkczyx_ngkdhw_f16_pipev5_instances( + op_ptrs); + add_device_grouped_conv3d_bwd_weight_xdl_ngcdhw_gkczyx_ngkdhw_f16_instances( + op_ptrs); + } +#endif +#ifdef CK_ENABLE_BF16 + if constexpr(is_same_v && + is_same_v && + is_same_v && + is_same_v && + is_same_v) + { + add_device_grouped_conv3d_bwd_weight_two_stage_xdl_ngcdhw_gkczyx_ngkdhw_bf16_pipev1_instances( + op_ptrs); + add_device_grouped_conv3d_bwd_weight_two_stage_xdl_ngcdhw_gkczyx_ngkdhw_bf16_pipev2_instances( + op_ptrs); + add_device_grouped_conv3d_bwd_weight_two_stage_xdl_ngcdhw_gkczyx_ngkdhw_bf16_pipev5_instances( + op_ptrs); + add_device_grouped_conv3d_bwd_weight_xdl_ngcdhw_gkczyx_ngkdhw_bf16_instances( + op_ptrs); + } +#endif +#ifdef CK_ENABLE_FP32 + if constexpr(is_same_v && is_same_v && + is_same_v && is_same_v && + is_same_v) + { + add_device_grouped_conv3d_bwd_weight_xdl_ngcdhw_gkczyx_ngkdhw_f32_instances( + op_ptrs); + } #endif } if constexpr(is_same_v && is_same_v && @@ -622,12 +700,6 @@ struct DeviceOperationInstanceFactory>>& instances); -void add_device_grouped_conv2d_bwd_weight_xdl_ngchw_gkyxc_ngkhw_bf16_instances( +void add_device_grouped_conv2d_bwd_weight_xdl_ngchw_gkcyx_ngkhw_bf16_instances( std::vector>>& instances); -void add_device_grouped_conv2d_bwd_weight_two_stage_xdl_ngchw_gkyxc_ngkhw_bf16_pipev2_instances( +void add_device_grouped_conv2d_bwd_weight_two_stage_xdl_ngchw_gkcyx_ngkhw_bf16_pipev1_instances( std::vector>>& instances); -void add_device_grouped_conv2d_bwd_weight_two_stage_xdl_ngchw_gkyxc_ngkhw_bf16_pipev5_instances( +void add_device_grouped_conv2d_bwd_weight_two_stage_xdl_ngchw_gkcyx_ngkhw_bf16_pipev2_instances( std::vector>>& instances); + +void add_device_grouped_conv2d_bwd_weight_two_stage_xdl_ngchw_gkcyx_ngkhw_bf16_pipev5_instances( + std::vector>>& instances); -void 
add_device_grouped_conv2d_bwd_weight_xdl_ngchw_gkyxc_ngkhw_f16_instances( +void add_device_grouped_conv2d_bwd_weight_xdl_ngchw_gkcyx_ngkhw_f16_instances( std::vector>>& instances); -void add_device_grouped_conv2d_bwd_weight_two_stage_xdl_ngchw_gkyxc_ngkhw_f16_pipev2_instances( +void add_device_grouped_conv2d_bwd_weight_two_stage_xdl_ngchw_gkcyx_ngkhw_f16_pipev1_instances( std::vector>>& instances); -void add_device_grouped_conv2d_bwd_weight_two_stage_xdl_ngchw_gkyxc_ngkhw_f16_pipev5_instances( +void add_device_grouped_conv2d_bwd_weight_two_stage_xdl_ngchw_gkcyx_ngkhw_f16_pipev2_instances( std::vector>>& instances); + +void add_device_grouped_conv2d_bwd_weight_two_stage_xdl_ngchw_gkcyx_ngkhw_f16_pipev5_instances( + std::vector>>& instances); +void add_device_grouped_conv2d_bwd_weight_xdl_ngchw_gkcyx_ngkhw_f32_instances( + std::vector>>& instances); + void add_device_grouped_conv2d_bwd_weight_xdl_nhwgc_gkyxc_nhwgk_f32_default_pipev2_instances( std::vector>>& instances); -void add_device_grouped_conv3d_bwd_weight_xdl_ngcdhw_gkzyxc_ngkdhw_bf16_instances( +void add_device_grouped_conv3d_bwd_weight_xdl_ngcdhw_gkczyx_ngkdhw_bf16_instances( std::vector>>& instances); -void add_device_grouped_conv3d_bwd_weight_two_stage_xdl_ngcdhw_gkzyxc_ngkdhw_bf16_pipev2_instances( +void add_device_grouped_conv3d_bwd_weight_two_stage_xdl_ngcdhw_gkczyx_ngkdhw_bf16_pipev1_instances( std::vector>>& instances); -void add_device_grouped_conv3d_bwd_weight_two_stage_xdl_ngcdhw_gkzyxc_ngkdhw_bf16_pipev5_instances( +void add_device_grouped_conv3d_bwd_weight_two_stage_xdl_ngcdhw_gkczyx_ngkdhw_bf16_pipev2_instances( std::vector>>& instances); + +void add_device_grouped_conv3d_bwd_weight_two_stage_xdl_ngcdhw_gkczyx_ngkdhw_bf16_pipev5_instances( + std::vector>>& instances); -void add_device_grouped_conv3d_bwd_weight_xdl_ngcdhw_gkzyxc_ngkdhw_f16_instances( +void add_device_grouped_conv3d_bwd_weight_xdl_ngcdhw_gkczyx_ngkdhw_f16_instances( std::vector>>& instances); -void 
add_device_grouped_conv3d_bwd_weight_two_stage_xdl_ngcdhw_gkzyxc_ngkdhw_f16_pipev2_instances( +void add_device_grouped_conv3d_bwd_weight_two_stage_xdl_ngcdhw_gkczyx_ngkdhw_f16_pipev1_instances( std::vector>>& instances); -void add_device_grouped_conv3d_bwd_weight_two_stage_xdl_ngcdhw_gkzyxc_ngkdhw_f16_pipev5_instances( +void add_device_grouped_conv3d_bwd_weight_two_stage_xdl_ngcdhw_gkczyx_ngkdhw_f16_pipev2_instances( std::vector>>& instances); + +void add_device_grouped_conv3d_bwd_weight_two_stage_xdl_ngcdhw_gkczyx_ngkdhw_f16_pipev5_instances( + std::vector>>& instances); +void add_device_grouped_conv3d_bwd_weight_xdl_ngcdhw_gkczyx_ngkdhw_f32_instances( + std::vector>>& instances); + void add_device_grouped_conv3d_bwd_weight_xdl_ndhwgc_gkzyxc_ndhwgk_f32_instances( std::vector>>& instances) +{ + // 1. Default + add_device_operation_instances( + instances, + device_grouped_conv_bwd_weight_two_stage_ngchw_xdl_c_shuffle_bf16_generic_instances< + 2, + NGCHW, + GKCYX, + NGKHW, + ConvBwdWeightDefault, + BlockGemmPipelineScheduler::Intrawave, + BlockGemmPipelineVersion::v1>{}); +} + +} // namespace instance +} // namespace device +} // namespace tensor_operation +} // namespace ck diff --git a/library/src/tensor_operation_instance/gpu/grouped_conv2d_bwd_weight/xdl/device_grouped_conv2d_bwd_weight_two_stage_xdl_ngchw_gkyxc_ngkhw_bf16_pipev2_instance.cpp b/library/src/tensor_operation_instance/gpu/grouped_conv2d_bwd_weight/xdl/ngchw_gkcyx_ngkhw/device_grouped_conv2d_bwd_weight_two_stage_xdl_ngchw_gkcyx_ngkhw_bf16_pipev2_instance.cpp similarity index 88% rename from library/src/tensor_operation_instance/gpu/grouped_conv2d_bwd_weight/xdl/device_grouped_conv2d_bwd_weight_two_stage_xdl_ngchw_gkyxc_ngkhw_bf16_pipev2_instance.cpp rename to library/src/tensor_operation_instance/gpu/grouped_conv2d_bwd_weight/xdl/ngchw_gkcyx_ngkhw/device_grouped_conv2d_bwd_weight_two_stage_xdl_ngchw_gkcyx_ngkhw_bf16_pipev2_instance.cpp index 9fbdc6c461..0f0817b775 100644 --- 
a/library/src/tensor_operation_instance/gpu/grouped_conv2d_bwd_weight/xdl/device_grouped_conv2d_bwd_weight_two_stage_xdl_ngchw_gkyxc_ngkhw_bf16_pipev2_instance.cpp +++ b/library/src/tensor_operation_instance/gpu/grouped_conv2d_bwd_weight/xdl/ngchw_gkcyx_ngkhw/device_grouped_conv2d_bwd_weight_two_stage_xdl_ngchw_gkcyx_ngkhw_bf16_pipev2_instance.cpp @@ -1,5 +1,5 @@ // SPDX-License-Identifier: MIT -// Copyright (c) 2024, Advanced Micro Devices, Inc. All rights reserved. +// Copyright (c) 2024-2025, Advanced Micro Devices, Inc. All rights reserved. #include "ck/library/tensor_operation_instance/add_device_operation_instance.hpp" #include "ck/library/tensor_operation_instance/gpu/grouped_conv_bwd_weight/device_grouped_conv_bwd_weight_two_stage_xdl_instance.hpp" @@ -10,10 +10,10 @@ namespace device { namespace instance { // Compilation parameters for in[n, hi, wi, g, c] * wei[g, k, y, x, c] = out[n, ho, wo, g, k] -void add_device_grouped_conv2d_bwd_weight_two_stage_xdl_ngchw_gkyxc_ngkhw_bf16_pipev2_instances( +void add_device_grouped_conv2d_bwd_weight_two_stage_xdl_ngchw_gkcyx_ngkhw_bf16_pipev2_instances( std::vector>>& instances) +{ + // 1. 
Default + add_device_operation_instances( + instances, + device_grouped_conv_bwd_weight_two_stage_ngchw_xdl_c_shuffle_f16_generic_instances< + 2, + NGCHW, + GKCYX, + NGKHW, + ConvBwdWeightDefault, + BlockGemmPipelineScheduler::Intrawave, + BlockGemmPipelineVersion::v1>{}); +} + +} // namespace instance +} // namespace device +} // namespace tensor_operation +} // namespace ck diff --git a/library/src/tensor_operation_instance/gpu/grouped_conv2d_bwd_weight/xdl/device_grouped_conv2d_bwd_weight_two_stage_xdl_ngchw_gkyxc_ngkhw_f16_pipev2_instance.cpp b/library/src/tensor_operation_instance/gpu/grouped_conv2d_bwd_weight/xdl/ngchw_gkcyx_ngkhw/device_grouped_conv2d_bwd_weight_two_stage_xdl_ngchw_gkcyx_ngkhw_f16_pipev2_instance.cpp similarity index 88% rename from library/src/tensor_operation_instance/gpu/grouped_conv2d_bwd_weight/xdl/device_grouped_conv2d_bwd_weight_two_stage_xdl_ngchw_gkyxc_ngkhw_f16_pipev2_instance.cpp rename to library/src/tensor_operation_instance/gpu/grouped_conv2d_bwd_weight/xdl/ngchw_gkcyx_ngkhw/device_grouped_conv2d_bwd_weight_two_stage_xdl_ngchw_gkcyx_ngkhw_f16_pipev2_instance.cpp index bbab53d9b5..7efe6f7bc1 100644 --- a/library/src/tensor_operation_instance/gpu/grouped_conv2d_bwd_weight/xdl/device_grouped_conv2d_bwd_weight_two_stage_xdl_ngchw_gkyxc_ngkhw_f16_pipev2_instance.cpp +++ b/library/src/tensor_operation_instance/gpu/grouped_conv2d_bwd_weight/xdl/ngchw_gkcyx_ngkhw/device_grouped_conv2d_bwd_weight_two_stage_xdl_ngchw_gkcyx_ngkhw_f16_pipev2_instance.cpp @@ -1,5 +1,5 @@ // SPDX-License-Identifier: MIT -// Copyright (c) 2024, Advanced Micro Devices, Inc. All rights reserved. +// Copyright (c) 2024-2025, Advanced Micro Devices, Inc. All rights reserved. 
#include "ck/library/tensor_operation_instance/add_device_operation_instance.hpp" #include "ck/library/tensor_operation_instance/gpu/grouped_conv_bwd_weight/device_grouped_conv_bwd_weight_two_stage_xdl_instance.hpp" @@ -10,10 +10,10 @@ namespace device { namespace instance { // Compilation parameters for in[n, hi, wi, g, c] * wei[g, k, y, x, c] = out[n, ho, wo, g, k] -void add_device_grouped_conv2d_bwd_weight_two_stage_xdl_ngchw_gkyxc_ngkhw_f16_pipev2_instances( +void add_device_grouped_conv2d_bwd_weight_two_stage_xdl_ngchw_gkcyx_ngkhw_f16_pipev2_instances( std::vector>>& instances) +{ + // 1. Default + add_device_operation_instances( + instances, + device_grouped_conv_bwd_weight_xdl_c_shuffle_f32_generic_instances<2, + NGCHW, + GKYXC, + NGKHW, + ConvBwdWeightDefault>{}); +} + +} // namespace instance +} // namespace device +} // namespace tensor_operation +} // namespace ck diff --git a/library/src/tensor_operation_instance/gpu/grouped_conv2d_bwd_weight/xdl/device_grouped_conv2d_bwd_weight_two_stage_xdl_nhwgc_gkyxc_nhwgk_bf16_pipev1_instance.cpp b/library/src/tensor_operation_instance/gpu/grouped_conv2d_bwd_weight/xdl/nhwgc_gkyxc_nhwgk/device_grouped_conv2d_bwd_weight_two_stage_xdl_nhwgc_gkyxc_nhwgk_bf16_pipev1_instance.cpp similarity index 95% rename from library/src/tensor_operation_instance/gpu/grouped_conv2d_bwd_weight/xdl/device_grouped_conv2d_bwd_weight_two_stage_xdl_nhwgc_gkyxc_nhwgk_bf16_pipev1_instance.cpp rename to library/src/tensor_operation_instance/gpu/grouped_conv2d_bwd_weight/xdl/nhwgc_gkyxc_nhwgk/device_grouped_conv2d_bwd_weight_two_stage_xdl_nhwgc_gkyxc_nhwgk_bf16_pipev1_instance.cpp index 74ccc4c89b..6e77488299 100644 --- a/library/src/tensor_operation_instance/gpu/grouped_conv2d_bwd_weight/xdl/device_grouped_conv2d_bwd_weight_two_stage_xdl_nhwgc_gkyxc_nhwgk_bf16_pipev1_instance.cpp +++ 
b/library/src/tensor_operation_instance/gpu/grouped_conv2d_bwd_weight/xdl/nhwgc_gkyxc_nhwgk/device_grouped_conv2d_bwd_weight_two_stage_xdl_nhwgc_gkyxc_nhwgk_bf16_pipev1_instance.cpp @@ -1,5 +1,5 @@ // SPDX-License-Identifier: MIT -// Copyright (c) 2024, Advanced Micro Devices, Inc. All rights reserved. +// Copyright (c) 2024-2025, Advanced Micro Devices, Inc. All rights reserved. #include "ck/library/tensor_operation_instance/add_device_operation_instance.hpp" #include "ck/library/tensor_operation_instance/gpu/grouped_conv_bwd_weight/device_grouped_conv_bwd_weight_two_stage_xdl_instance.hpp" diff --git a/library/src/tensor_operation_instance/gpu/grouped_conv2d_bwd_weight/xdl/device_grouped_conv2d_bwd_weight_two_stage_xdl_nhwgc_gkyxc_nhwgk_bf16_pipev2_instance.cpp b/library/src/tensor_operation_instance/gpu/grouped_conv2d_bwd_weight/xdl/nhwgc_gkyxc_nhwgk/device_grouped_conv2d_bwd_weight_two_stage_xdl_nhwgc_gkyxc_nhwgk_bf16_pipev2_instance.cpp similarity index 95% rename from library/src/tensor_operation_instance/gpu/grouped_conv2d_bwd_weight/xdl/device_grouped_conv2d_bwd_weight_two_stage_xdl_nhwgc_gkyxc_nhwgk_bf16_pipev2_instance.cpp rename to library/src/tensor_operation_instance/gpu/grouped_conv2d_bwd_weight/xdl/nhwgc_gkyxc_nhwgk/device_grouped_conv2d_bwd_weight_two_stage_xdl_nhwgc_gkyxc_nhwgk_bf16_pipev2_instance.cpp index fab2898559..4a0e89f0fe 100644 --- a/library/src/tensor_operation_instance/gpu/grouped_conv2d_bwd_weight/xdl/device_grouped_conv2d_bwd_weight_two_stage_xdl_nhwgc_gkyxc_nhwgk_bf16_pipev2_instance.cpp +++ b/library/src/tensor_operation_instance/gpu/grouped_conv2d_bwd_weight/xdl/nhwgc_gkyxc_nhwgk/device_grouped_conv2d_bwd_weight_two_stage_xdl_nhwgc_gkyxc_nhwgk_bf16_pipev2_instance.cpp @@ -1,5 +1,5 @@ // SPDX-License-Identifier: MIT -// Copyright (c) 2024, Advanced Micro Devices, Inc. All rights reserved. +// Copyright (c) 2024-2025, Advanced Micro Devices, Inc. All rights reserved. 
#include "ck/library/tensor_operation_instance/add_device_operation_instance.hpp" #include "ck/library/tensor_operation_instance/gpu/grouped_conv_bwd_weight/device_grouped_conv_bwd_weight_two_stage_xdl_instance.hpp" diff --git a/library/src/tensor_operation_instance/gpu/grouped_conv2d_bwd_weight/xdl/device_grouped_conv2d_bwd_weight_two_stage_xdl_nhwgc_gkyxc_nhwgk_bf16_pipev2_irregular_instance.cpp b/library/src/tensor_operation_instance/gpu/grouped_conv2d_bwd_weight/xdl/nhwgc_gkyxc_nhwgk/device_grouped_conv2d_bwd_weight_two_stage_xdl_nhwgc_gkyxc_nhwgk_bf16_pipev2_irregular_instance.cpp similarity index 100% rename from library/src/tensor_operation_instance/gpu/grouped_conv2d_bwd_weight/xdl/device_grouped_conv2d_bwd_weight_two_stage_xdl_nhwgc_gkyxc_nhwgk_bf16_pipev2_irregular_instance.cpp rename to library/src/tensor_operation_instance/gpu/grouped_conv2d_bwd_weight/xdl/nhwgc_gkyxc_nhwgk/device_grouped_conv2d_bwd_weight_two_stage_xdl_nhwgc_gkyxc_nhwgk_bf16_pipev2_irregular_instance.cpp diff --git a/library/src/tensor_operation_instance/gpu/grouped_conv2d_bwd_weight/xdl/device_grouped_conv2d_bwd_weight_two_stage_xdl_nhwgc_gkyxc_nhwgk_bf16_pipev5_instance.cpp b/library/src/tensor_operation_instance/gpu/grouped_conv2d_bwd_weight/xdl/nhwgc_gkyxc_nhwgk/device_grouped_conv2d_bwd_weight_two_stage_xdl_nhwgc_gkyxc_nhwgk_bf16_pipev5_instance.cpp similarity index 95% rename from library/src/tensor_operation_instance/gpu/grouped_conv2d_bwd_weight/xdl/device_grouped_conv2d_bwd_weight_two_stage_xdl_nhwgc_gkyxc_nhwgk_bf16_pipev5_instance.cpp rename to library/src/tensor_operation_instance/gpu/grouped_conv2d_bwd_weight/xdl/nhwgc_gkyxc_nhwgk/device_grouped_conv2d_bwd_weight_two_stage_xdl_nhwgc_gkyxc_nhwgk_bf16_pipev5_instance.cpp index 407645e893..9a0da7c431 100644 --- a/library/src/tensor_operation_instance/gpu/grouped_conv2d_bwd_weight/xdl/device_grouped_conv2d_bwd_weight_two_stage_xdl_nhwgc_gkyxc_nhwgk_bf16_pipev5_instance.cpp +++ 
b/library/src/tensor_operation_instance/gpu/grouped_conv2d_bwd_weight/xdl/nhwgc_gkyxc_nhwgk/device_grouped_conv2d_bwd_weight_two_stage_xdl_nhwgc_gkyxc_nhwgk_bf16_pipev5_instance.cpp @@ -1,5 +1,5 @@ // SPDX-License-Identifier: MIT -// Copyright (c) 2024, Advanced Micro Devices, Inc. All rights reserved. +// Copyright (c) 2024-2025, Advanced Micro Devices, Inc. All rights reserved. #include "ck/library/tensor_operation_instance/add_device_operation_instance.hpp" #include "ck/library/tensor_operation_instance/gpu/grouped_conv_bwd_weight/device_grouped_conv_bwd_weight_two_stage_xdl_instance.hpp" diff --git a/library/src/tensor_operation_instance/gpu/grouped_conv2d_bwd_weight/xdl/device_grouped_conv2d_bwd_weight_two_stage_xdl_nhwgc_gkyxc_nhwgk_bf16_pipev5_irregular_instance.cpp b/library/src/tensor_operation_instance/gpu/grouped_conv2d_bwd_weight/xdl/nhwgc_gkyxc_nhwgk/device_grouped_conv2d_bwd_weight_two_stage_xdl_nhwgc_gkyxc_nhwgk_bf16_pipev5_irregular_instance.cpp similarity index 100% rename from library/src/tensor_operation_instance/gpu/grouped_conv2d_bwd_weight/xdl/device_grouped_conv2d_bwd_weight_two_stage_xdl_nhwgc_gkyxc_nhwgk_bf16_pipev5_irregular_instance.cpp rename to library/src/tensor_operation_instance/gpu/grouped_conv2d_bwd_weight/xdl/nhwgc_gkyxc_nhwgk/device_grouped_conv2d_bwd_weight_two_stage_xdl_nhwgc_gkyxc_nhwgk_bf16_pipev5_irregular_instance.cpp diff --git a/library/src/tensor_operation_instance/gpu/grouped_conv2d_bwd_weight/xdl/device_grouped_conv2d_bwd_weight_two_stage_xdl_nhwgc_gkyxc_nhwgk_f16_pipev1_instance.cpp b/library/src/tensor_operation_instance/gpu/grouped_conv2d_bwd_weight/xdl/nhwgc_gkyxc_nhwgk/device_grouped_conv2d_bwd_weight_two_stage_xdl_nhwgc_gkyxc_nhwgk_f16_pipev1_instance.cpp similarity index 95% rename from library/src/tensor_operation_instance/gpu/grouped_conv2d_bwd_weight/xdl/device_grouped_conv2d_bwd_weight_two_stage_xdl_nhwgc_gkyxc_nhwgk_f16_pipev1_instance.cpp rename to 
library/src/tensor_operation_instance/gpu/grouped_conv2d_bwd_weight/xdl/nhwgc_gkyxc_nhwgk/device_grouped_conv2d_bwd_weight_two_stage_xdl_nhwgc_gkyxc_nhwgk_f16_pipev1_instance.cpp index 807de66ca5..e2ecee734f 100644 --- a/library/src/tensor_operation_instance/gpu/grouped_conv2d_bwd_weight/xdl/device_grouped_conv2d_bwd_weight_two_stage_xdl_nhwgc_gkyxc_nhwgk_f16_pipev1_instance.cpp +++ b/library/src/tensor_operation_instance/gpu/grouped_conv2d_bwd_weight/xdl/nhwgc_gkyxc_nhwgk/device_grouped_conv2d_bwd_weight_two_stage_xdl_nhwgc_gkyxc_nhwgk_f16_pipev1_instance.cpp @@ -1,5 +1,5 @@ // SPDX-License-Identifier: MIT -// Copyright (c) 2024, Advanced Micro Devices, Inc. All rights reserved. +// Copyright (c) 2024-2025, Advanced Micro Devices, Inc. All rights reserved. #include "ck/library/tensor_operation_instance/add_device_operation_instance.hpp" #include "ck/library/tensor_operation_instance/gpu/grouped_conv_bwd_weight/device_grouped_conv_bwd_weight_two_stage_xdl_instance.hpp" diff --git a/library/src/tensor_operation_instance/gpu/grouped_conv2d_bwd_weight/xdl/device_grouped_conv2d_bwd_weight_two_stage_xdl_nhwgc_gkyxc_nhwgk_f16_pipev2_instance.cpp b/library/src/tensor_operation_instance/gpu/grouped_conv2d_bwd_weight/xdl/nhwgc_gkyxc_nhwgk/device_grouped_conv2d_bwd_weight_two_stage_xdl_nhwgc_gkyxc_nhwgk_f16_pipev2_instance.cpp similarity index 95% rename from library/src/tensor_operation_instance/gpu/grouped_conv2d_bwd_weight/xdl/device_grouped_conv2d_bwd_weight_two_stage_xdl_nhwgc_gkyxc_nhwgk_f16_pipev2_instance.cpp rename to library/src/tensor_operation_instance/gpu/grouped_conv2d_bwd_weight/xdl/nhwgc_gkyxc_nhwgk/device_grouped_conv2d_bwd_weight_two_stage_xdl_nhwgc_gkyxc_nhwgk_f16_pipev2_instance.cpp index 084c83cd65..a65c20c840 100644 --- a/library/src/tensor_operation_instance/gpu/grouped_conv2d_bwd_weight/xdl/device_grouped_conv2d_bwd_weight_two_stage_xdl_nhwgc_gkyxc_nhwgk_f16_pipev2_instance.cpp +++ 
b/library/src/tensor_operation_instance/gpu/grouped_conv2d_bwd_weight/xdl/nhwgc_gkyxc_nhwgk/device_grouped_conv2d_bwd_weight_two_stage_xdl_nhwgc_gkyxc_nhwgk_f16_pipev2_instance.cpp @@ -1,5 +1,5 @@ // SPDX-License-Identifier: MIT -// Copyright (c) 2024, Advanced Micro Devices, Inc. All rights reserved. +// Copyright (c) 2024-2025, Advanced Micro Devices, Inc. All rights reserved. #include "ck/library/tensor_operation_instance/add_device_operation_instance.hpp" #include "ck/library/tensor_operation_instance/gpu/grouped_conv_bwd_weight/device_grouped_conv_bwd_weight_two_stage_xdl_instance.hpp" diff --git a/library/src/tensor_operation_instance/gpu/grouped_conv2d_bwd_weight/xdl/device_grouped_conv2d_bwd_weight_two_stage_xdl_nhwgc_gkyxc_nhwgk_f16_pipev2_irregular_instance.cpp b/library/src/tensor_operation_instance/gpu/grouped_conv2d_bwd_weight/xdl/nhwgc_gkyxc_nhwgk/device_grouped_conv2d_bwd_weight_two_stage_xdl_nhwgc_gkyxc_nhwgk_f16_pipev2_irregular_instance.cpp similarity index 100% rename from library/src/tensor_operation_instance/gpu/grouped_conv2d_bwd_weight/xdl/device_grouped_conv2d_bwd_weight_two_stage_xdl_nhwgc_gkyxc_nhwgk_f16_pipev2_irregular_instance.cpp rename to library/src/tensor_operation_instance/gpu/grouped_conv2d_bwd_weight/xdl/nhwgc_gkyxc_nhwgk/device_grouped_conv2d_bwd_weight_two_stage_xdl_nhwgc_gkyxc_nhwgk_f16_pipev2_irregular_instance.cpp diff --git a/library/src/tensor_operation_instance/gpu/grouped_conv2d_bwd_weight/xdl/device_grouped_conv2d_bwd_weight_two_stage_xdl_nhwgc_gkyxc_nhwgk_f16_pipev5_instance.cpp b/library/src/tensor_operation_instance/gpu/grouped_conv2d_bwd_weight/xdl/nhwgc_gkyxc_nhwgk/device_grouped_conv2d_bwd_weight_two_stage_xdl_nhwgc_gkyxc_nhwgk_f16_pipev5_instance.cpp similarity index 95% rename from library/src/tensor_operation_instance/gpu/grouped_conv2d_bwd_weight/xdl/device_grouped_conv2d_bwd_weight_two_stage_xdl_nhwgc_gkyxc_nhwgk_f16_pipev5_instance.cpp rename to 
library/src/tensor_operation_instance/gpu/grouped_conv2d_bwd_weight/xdl/nhwgc_gkyxc_nhwgk/device_grouped_conv2d_bwd_weight_two_stage_xdl_nhwgc_gkyxc_nhwgk_f16_pipev5_instance.cpp index d174e5b6c0..089953dad2 100644 --- a/library/src/tensor_operation_instance/gpu/grouped_conv2d_bwd_weight/xdl/device_grouped_conv2d_bwd_weight_two_stage_xdl_nhwgc_gkyxc_nhwgk_f16_pipev5_instance.cpp +++ b/library/src/tensor_operation_instance/gpu/grouped_conv2d_bwd_weight/xdl/nhwgc_gkyxc_nhwgk/device_grouped_conv2d_bwd_weight_two_stage_xdl_nhwgc_gkyxc_nhwgk_f16_pipev5_instance.cpp @@ -1,5 +1,5 @@ // SPDX-License-Identifier: MIT -// Copyright (c) 2024, Advanced Micro Devices, Inc. All rights reserved. +// Copyright (c) 2024-2025, Advanced Micro Devices, Inc. All rights reserved. #include "ck/library/tensor_operation_instance/add_device_operation_instance.hpp" #include "ck/library/tensor_operation_instance/gpu/grouped_conv_bwd_weight/device_grouped_conv_bwd_weight_two_stage_xdl_instance.hpp" diff --git a/library/src/tensor_operation_instance/gpu/grouped_conv2d_bwd_weight/xdl/device_grouped_conv2d_bwd_weight_two_stage_xdl_nhwgc_gkyxc_nhwgk_f16_pipev5_irregular_instance.cpp b/library/src/tensor_operation_instance/gpu/grouped_conv2d_bwd_weight/xdl/nhwgc_gkyxc_nhwgk/device_grouped_conv2d_bwd_weight_two_stage_xdl_nhwgc_gkyxc_nhwgk_f16_pipev5_irregular_instance.cpp similarity index 100% rename from library/src/tensor_operation_instance/gpu/grouped_conv2d_bwd_weight/xdl/device_grouped_conv2d_bwd_weight_two_stage_xdl_nhwgc_gkyxc_nhwgk_f16_pipev5_irregular_instance.cpp rename to library/src/tensor_operation_instance/gpu/grouped_conv2d_bwd_weight/xdl/nhwgc_gkyxc_nhwgk/device_grouped_conv2d_bwd_weight_two_stage_xdl_nhwgc_gkyxc_nhwgk_f16_pipev5_irregular_instance.cpp diff --git a/library/src/tensor_operation_instance/gpu/grouped_conv2d_bwd_weight/xdl/device_grouped_conv2d_bwd_weight_xdl_nhwgc_gkyxc_nhwgk_bf16_default_pipev2_instance.cpp 
b/library/src/tensor_operation_instance/gpu/grouped_conv2d_bwd_weight/xdl/nhwgc_gkyxc_nhwgk/device_grouped_conv2d_bwd_weight_xdl_nhwgc_gkyxc_nhwgk_bf16_default_pipev2_instance.cpp similarity index 100% rename from library/src/tensor_operation_instance/gpu/grouped_conv2d_bwd_weight/xdl/device_grouped_conv2d_bwd_weight_xdl_nhwgc_gkyxc_nhwgk_bf16_default_pipev2_instance.cpp rename to library/src/tensor_operation_instance/gpu/grouped_conv2d_bwd_weight/xdl/nhwgc_gkyxc_nhwgk/device_grouped_conv2d_bwd_weight_xdl_nhwgc_gkyxc_nhwgk_bf16_default_pipev2_instance.cpp diff --git a/library/src/tensor_operation_instance/gpu/grouped_conv2d_bwd_weight/xdl/device_grouped_conv2d_bwd_weight_xdl_nhwgc_gkyxc_nhwgk_bf16_default_pipev5_instance.cpp b/library/src/tensor_operation_instance/gpu/grouped_conv2d_bwd_weight/xdl/nhwgc_gkyxc_nhwgk/device_grouped_conv2d_bwd_weight_xdl_nhwgc_gkyxc_nhwgk_bf16_default_pipev5_instance.cpp similarity index 100% rename from library/src/tensor_operation_instance/gpu/grouped_conv2d_bwd_weight/xdl/device_grouped_conv2d_bwd_weight_xdl_nhwgc_gkyxc_nhwgk_bf16_default_pipev5_instance.cpp rename to library/src/tensor_operation_instance/gpu/grouped_conv2d_bwd_weight/xdl/nhwgc_gkyxc_nhwgk/device_grouped_conv2d_bwd_weight_xdl_nhwgc_gkyxc_nhwgk_bf16_default_pipev5_instance.cpp diff --git a/library/src/tensor_operation_instance/gpu/grouped_conv2d_bwd_weight/xdl/device_grouped_conv2d_bwd_weight_xdl_nhwgc_gkyxc_nhwgk_bf16_f32_bf16_instance.cpp b/library/src/tensor_operation_instance/gpu/grouped_conv2d_bwd_weight/xdl/nhwgc_gkyxc_nhwgk/device_grouped_conv2d_bwd_weight_xdl_nhwgc_gkyxc_nhwgk_bf16_f32_bf16_instance.cpp similarity index 95% rename from library/src/tensor_operation_instance/gpu/grouped_conv2d_bwd_weight/xdl/device_grouped_conv2d_bwd_weight_xdl_nhwgc_gkyxc_nhwgk_bf16_f32_bf16_instance.cpp rename to 
library/src/tensor_operation_instance/gpu/grouped_conv2d_bwd_weight/xdl/nhwgc_gkyxc_nhwgk/device_grouped_conv2d_bwd_weight_xdl_nhwgc_gkyxc_nhwgk_bf16_f32_bf16_instance.cpp index cac9353354..678e5d234f 100644 --- a/library/src/tensor_operation_instance/gpu/grouped_conv2d_bwd_weight/xdl/device_grouped_conv2d_bwd_weight_xdl_nhwgc_gkyxc_nhwgk_bf16_f32_bf16_instance.cpp +++ b/library/src/tensor_operation_instance/gpu/grouped_conv2d_bwd_weight/xdl/nhwgc_gkyxc_nhwgk/device_grouped_conv2d_bwd_weight_xdl_nhwgc_gkyxc_nhwgk_bf16_f32_bf16_instance.cpp @@ -1,5 +1,5 @@ // SPDX-License-Identifier: MIT -// Copyright (c) 2024, Advanced Micro Devices, Inc. All rights reserved. +// Copyright (c) 2024-2025, Advanced Micro Devices, Inc. All rights reserved. #include "ck/library/tensor_operation_instance/add_device_operation_instance.hpp" #include "ck/library/tensor_operation_instance/gpu/grouped_conv_bwd_weight/device_grouped_conv_bwd_weight_xdl_instance.hpp" diff --git a/library/src/tensor_operation_instance/gpu/grouped_conv2d_bwd_weight/xdl/device_grouped_conv2d_bwd_weight_xdl_nhwgc_gkyxc_nhwgk_bf16_instance.cpp b/library/src/tensor_operation_instance/gpu/grouped_conv2d_bwd_weight/xdl/nhwgc_gkyxc_nhwgk/device_grouped_conv2d_bwd_weight_xdl_nhwgc_gkyxc_nhwgk_bf16_instance.cpp similarity index 96% rename from library/src/tensor_operation_instance/gpu/grouped_conv2d_bwd_weight/xdl/device_grouped_conv2d_bwd_weight_xdl_nhwgc_gkyxc_nhwgk_bf16_instance.cpp rename to library/src/tensor_operation_instance/gpu/grouped_conv2d_bwd_weight/xdl/nhwgc_gkyxc_nhwgk/device_grouped_conv2d_bwd_weight_xdl_nhwgc_gkyxc_nhwgk_bf16_instance.cpp index ee71e37e79..54edc0d247 100644 --- a/library/src/tensor_operation_instance/gpu/grouped_conv2d_bwd_weight/xdl/device_grouped_conv2d_bwd_weight_xdl_nhwgc_gkyxc_nhwgk_bf16_instance.cpp +++ b/library/src/tensor_operation_instance/gpu/grouped_conv2d_bwd_weight/xdl/nhwgc_gkyxc_nhwgk/device_grouped_conv2d_bwd_weight_xdl_nhwgc_gkyxc_nhwgk_bf16_instance.cpp @@ -1,5 +1,5 @@ 
// SPDX-License-Identifier: MIT -// Copyright (c) 2024, Advanced Micro Devices, Inc. All rights reserved. +// Copyright (c) 2024-2025, Advanced Micro Devices, Inc. All rights reserved. #include "ck/library/tensor_operation_instance/add_device_operation_instance.hpp" #include "ck/library/tensor_operation_instance/gpu/grouped_conv_bwd_weight/device_grouped_conv_bwd_weight_xdl_instance.hpp" diff --git a/library/src/tensor_operation_instance/gpu/grouped_conv2d_bwd_weight/xdl/device_grouped_conv2d_bwd_weight_xdl_nhwgc_gkyxc_nhwgk_bf16_pad0_pipev2_instance.cpp b/library/src/tensor_operation_instance/gpu/grouped_conv2d_bwd_weight/xdl/nhwgc_gkyxc_nhwgk/device_grouped_conv2d_bwd_weight_xdl_nhwgc_gkyxc_nhwgk_bf16_pad0_pipev2_instance.cpp similarity index 100% rename from library/src/tensor_operation_instance/gpu/grouped_conv2d_bwd_weight/xdl/device_grouped_conv2d_bwd_weight_xdl_nhwgc_gkyxc_nhwgk_bf16_pad0_pipev2_instance.cpp rename to library/src/tensor_operation_instance/gpu/grouped_conv2d_bwd_weight/xdl/nhwgc_gkyxc_nhwgk/device_grouped_conv2d_bwd_weight_xdl_nhwgc_gkyxc_nhwgk_bf16_pad0_pipev2_instance.cpp diff --git a/library/src/tensor_operation_instance/gpu/grouped_conv2d_bwd_weight/xdl/device_grouped_conv2d_bwd_weight_xdl_nhwgc_gkyxc_nhwgk_bf16_pad0_pipev5_instance.cpp b/library/src/tensor_operation_instance/gpu/grouped_conv2d_bwd_weight/xdl/nhwgc_gkyxc_nhwgk/device_grouped_conv2d_bwd_weight_xdl_nhwgc_gkyxc_nhwgk_bf16_pad0_pipev5_instance.cpp similarity index 100% rename from library/src/tensor_operation_instance/gpu/grouped_conv2d_bwd_weight/xdl/device_grouped_conv2d_bwd_weight_xdl_nhwgc_gkyxc_nhwgk_bf16_pad0_pipev5_instance.cpp rename to library/src/tensor_operation_instance/gpu/grouped_conv2d_bwd_weight/xdl/nhwgc_gkyxc_nhwgk/device_grouped_conv2d_bwd_weight_xdl_nhwgc_gkyxc_nhwgk_bf16_pad0_pipev5_instance.cpp diff --git 
a/library/src/tensor_operation_instance/gpu/grouped_conv2d_bwd_weight/xdl/device_grouped_conv2d_bwd_weight_xdl_nhwgc_gkyxc_nhwgk_f16_default_pipev2_instance.cpp b/library/src/tensor_operation_instance/gpu/grouped_conv2d_bwd_weight/xdl/nhwgc_gkyxc_nhwgk/device_grouped_conv2d_bwd_weight_xdl_nhwgc_gkyxc_nhwgk_f16_default_pipev2_instance.cpp similarity index 100% rename from library/src/tensor_operation_instance/gpu/grouped_conv2d_bwd_weight/xdl/device_grouped_conv2d_bwd_weight_xdl_nhwgc_gkyxc_nhwgk_f16_default_pipev2_instance.cpp rename to library/src/tensor_operation_instance/gpu/grouped_conv2d_bwd_weight/xdl/nhwgc_gkyxc_nhwgk/device_grouped_conv2d_bwd_weight_xdl_nhwgc_gkyxc_nhwgk_f16_default_pipev2_instance.cpp diff --git a/library/src/tensor_operation_instance/gpu/grouped_conv2d_bwd_weight/xdl/device_grouped_conv2d_bwd_weight_xdl_nhwgc_gkyxc_nhwgk_f16_default_pipev5_instance.cpp b/library/src/tensor_operation_instance/gpu/grouped_conv2d_bwd_weight/xdl/nhwgc_gkyxc_nhwgk/device_grouped_conv2d_bwd_weight_xdl_nhwgc_gkyxc_nhwgk_f16_default_pipev5_instance.cpp similarity index 100% rename from library/src/tensor_operation_instance/gpu/grouped_conv2d_bwd_weight/xdl/device_grouped_conv2d_bwd_weight_xdl_nhwgc_gkyxc_nhwgk_f16_default_pipev5_instance.cpp rename to library/src/tensor_operation_instance/gpu/grouped_conv2d_bwd_weight/xdl/nhwgc_gkyxc_nhwgk/device_grouped_conv2d_bwd_weight_xdl_nhwgc_gkyxc_nhwgk_f16_default_pipev5_instance.cpp diff --git a/library/src/tensor_operation_instance/gpu/grouped_conv2d_bwd_weight/xdl/device_grouped_conv2d_bwd_weight_xdl_nhwgc_gkyxc_nhwgk_f16_instance.cpp b/library/src/tensor_operation_instance/gpu/grouped_conv2d_bwd_weight/xdl/nhwgc_gkyxc_nhwgk/device_grouped_conv2d_bwd_weight_xdl_nhwgc_gkyxc_nhwgk_f16_instance.cpp similarity index 97% rename from library/src/tensor_operation_instance/gpu/grouped_conv2d_bwd_weight/xdl/device_grouped_conv2d_bwd_weight_xdl_nhwgc_gkyxc_nhwgk_f16_instance.cpp rename to 
library/src/tensor_operation_instance/gpu/grouped_conv2d_bwd_weight/xdl/nhwgc_gkyxc_nhwgk/device_grouped_conv2d_bwd_weight_xdl_nhwgc_gkyxc_nhwgk_f16_instance.cpp index f6e1ada352..f77d88e71c 100644 --- a/library/src/tensor_operation_instance/gpu/grouped_conv2d_bwd_weight/xdl/device_grouped_conv2d_bwd_weight_xdl_nhwgc_gkyxc_nhwgk_f16_instance.cpp +++ b/library/src/tensor_operation_instance/gpu/grouped_conv2d_bwd_weight/xdl/nhwgc_gkyxc_nhwgk/device_grouped_conv2d_bwd_weight_xdl_nhwgc_gkyxc_nhwgk_f16_instance.cpp @@ -1,5 +1,5 @@ // SPDX-License-Identifier: MIT -// Copyright (c) 2018-2023, Advanced Micro Devices, Inc. All rights reserved. +// Copyright (c) 2018-2025, Advanced Micro Devices, Inc. All rights reserved. #include "ck/library/tensor_operation_instance/add_device_operation_instance.hpp" #include "ck/library/tensor_operation_instance/gpu/grouped_conv_bwd_weight/device_grouped_conv_bwd_weight_xdl_instance.hpp" diff --git a/library/src/tensor_operation_instance/gpu/grouped_conv2d_bwd_weight/xdl/device_grouped_conv2d_bwd_weight_xdl_nhwgc_gkyxc_nhwgk_f16_pad0_pipev2_instance.cpp b/library/src/tensor_operation_instance/gpu/grouped_conv2d_bwd_weight/xdl/nhwgc_gkyxc_nhwgk/device_grouped_conv2d_bwd_weight_xdl_nhwgc_gkyxc_nhwgk_f16_pad0_pipev2_instance.cpp similarity index 100% rename from library/src/tensor_operation_instance/gpu/grouped_conv2d_bwd_weight/xdl/device_grouped_conv2d_bwd_weight_xdl_nhwgc_gkyxc_nhwgk_f16_pad0_pipev2_instance.cpp rename to library/src/tensor_operation_instance/gpu/grouped_conv2d_bwd_weight/xdl/nhwgc_gkyxc_nhwgk/device_grouped_conv2d_bwd_weight_xdl_nhwgc_gkyxc_nhwgk_f16_pad0_pipev2_instance.cpp diff --git a/library/src/tensor_operation_instance/gpu/grouped_conv2d_bwd_weight/xdl/device_grouped_conv2d_bwd_weight_xdl_nhwgc_gkyxc_nhwgk_f16_pad0_pipev5_instance.cpp b/library/src/tensor_operation_instance/gpu/grouped_conv2d_bwd_weight/xdl/nhwgc_gkyxc_nhwgk/device_grouped_conv2d_bwd_weight_xdl_nhwgc_gkyxc_nhwgk_f16_pad0_pipev5_instance.cpp 
similarity index 100% rename from library/src/tensor_operation_instance/gpu/grouped_conv2d_bwd_weight/xdl/device_grouped_conv2d_bwd_weight_xdl_nhwgc_gkyxc_nhwgk_f16_pad0_pipev5_instance.cpp rename to library/src/tensor_operation_instance/gpu/grouped_conv2d_bwd_weight/xdl/nhwgc_gkyxc_nhwgk/device_grouped_conv2d_bwd_weight_xdl_nhwgc_gkyxc_nhwgk_f16_pad0_pipev5_instance.cpp diff --git a/library/src/tensor_operation_instance/gpu/grouped_conv2d_bwd_weight/xdl/device_grouped_conv2d_bwd_weight_xdl_nhwgc_gkyxc_nhwgk_f32_default_pipev2_instance.cpp b/library/src/tensor_operation_instance/gpu/grouped_conv2d_bwd_weight/xdl/nhwgc_gkyxc_nhwgk/device_grouped_conv2d_bwd_weight_xdl_nhwgc_gkyxc_nhwgk_f32_default_pipev2_instance.cpp similarity index 100% rename from library/src/tensor_operation_instance/gpu/grouped_conv2d_bwd_weight/xdl/device_grouped_conv2d_bwd_weight_xdl_nhwgc_gkyxc_nhwgk_f32_default_pipev2_instance.cpp rename to library/src/tensor_operation_instance/gpu/grouped_conv2d_bwd_weight/xdl/nhwgc_gkyxc_nhwgk/device_grouped_conv2d_bwd_weight_xdl_nhwgc_gkyxc_nhwgk_f32_default_pipev2_instance.cpp diff --git a/library/src/tensor_operation_instance/gpu/grouped_conv2d_bwd_weight/xdl/device_grouped_conv2d_bwd_weight_xdl_nhwgc_gkyxc_nhwgk_f32_default_pipev5_instance.cpp b/library/src/tensor_operation_instance/gpu/grouped_conv2d_bwd_weight/xdl/nhwgc_gkyxc_nhwgk/device_grouped_conv2d_bwd_weight_xdl_nhwgc_gkyxc_nhwgk_f32_default_pipev5_instance.cpp similarity index 100% rename from library/src/tensor_operation_instance/gpu/grouped_conv2d_bwd_weight/xdl/device_grouped_conv2d_bwd_weight_xdl_nhwgc_gkyxc_nhwgk_f32_default_pipev5_instance.cpp rename to library/src/tensor_operation_instance/gpu/grouped_conv2d_bwd_weight/xdl/nhwgc_gkyxc_nhwgk/device_grouped_conv2d_bwd_weight_xdl_nhwgc_gkyxc_nhwgk_f32_default_pipev5_instance.cpp diff --git a/library/src/tensor_operation_instance/gpu/grouped_conv2d_bwd_weight/xdl/device_grouped_conv2d_bwd_weight_xdl_nhwgc_gkyxc_nhwgk_f32_instance.cpp 
b/library/src/tensor_operation_instance/gpu/grouped_conv2d_bwd_weight/xdl/nhwgc_gkyxc_nhwgk/device_grouped_conv2d_bwd_weight_xdl_nhwgc_gkyxc_nhwgk_f32_instance.cpp similarity index 97% rename from library/src/tensor_operation_instance/gpu/grouped_conv2d_bwd_weight/xdl/device_grouped_conv2d_bwd_weight_xdl_nhwgc_gkyxc_nhwgk_f32_instance.cpp rename to library/src/tensor_operation_instance/gpu/grouped_conv2d_bwd_weight/xdl/nhwgc_gkyxc_nhwgk/device_grouped_conv2d_bwd_weight_xdl_nhwgc_gkyxc_nhwgk_f32_instance.cpp index 384706414a..e6115f28a1 100644 --- a/library/src/tensor_operation_instance/gpu/grouped_conv2d_bwd_weight/xdl/device_grouped_conv2d_bwd_weight_xdl_nhwgc_gkyxc_nhwgk_f32_instance.cpp +++ b/library/src/tensor_operation_instance/gpu/grouped_conv2d_bwd_weight/xdl/nhwgc_gkyxc_nhwgk/device_grouped_conv2d_bwd_weight_xdl_nhwgc_gkyxc_nhwgk_f32_instance.cpp @@ -1,5 +1,5 @@ // SPDX-License-Identifier: MIT -// Copyright (c) 2018-2023, Advanced Micro Devices, Inc. All rights reserved. +// Copyright (c) 2018-2025, Advanced Micro Devices, Inc. All rights reserved. 
#include "ck/library/tensor_operation_instance/add_device_operation_instance.hpp" #include "ck/library/tensor_operation_instance/gpu/grouped_conv_bwd_weight/device_grouped_conv_bwd_weight_xdl_instance.hpp" diff --git a/library/src/tensor_operation_instance/gpu/grouped_conv2d_bwd_weight/xdl/device_grouped_conv2d_bwd_weight_xdl_nhwgc_gkyxc_nhwgk_f32_pad0_pipev2_instance.cpp b/library/src/tensor_operation_instance/gpu/grouped_conv2d_bwd_weight/xdl/nhwgc_gkyxc_nhwgk/device_grouped_conv2d_bwd_weight_xdl_nhwgc_gkyxc_nhwgk_f32_pad0_pipev2_instance.cpp similarity index 100% rename from library/src/tensor_operation_instance/gpu/grouped_conv2d_bwd_weight/xdl/device_grouped_conv2d_bwd_weight_xdl_nhwgc_gkyxc_nhwgk_f32_pad0_pipev2_instance.cpp rename to library/src/tensor_operation_instance/gpu/grouped_conv2d_bwd_weight/xdl/nhwgc_gkyxc_nhwgk/device_grouped_conv2d_bwd_weight_xdl_nhwgc_gkyxc_nhwgk_f32_pad0_pipev2_instance.cpp diff --git a/library/src/tensor_operation_instance/gpu/grouped_conv2d_bwd_weight/xdl/device_grouped_conv2d_bwd_weight_xdl_nhwgc_gkyxc_nhwgk_f32_pad0_pipev5_instance.cpp b/library/src/tensor_operation_instance/gpu/grouped_conv2d_bwd_weight/xdl/nhwgc_gkyxc_nhwgk/device_grouped_conv2d_bwd_weight_xdl_nhwgc_gkyxc_nhwgk_f32_pad0_pipev5_instance.cpp similarity index 100% rename from library/src/tensor_operation_instance/gpu/grouped_conv2d_bwd_weight/xdl/device_grouped_conv2d_bwd_weight_xdl_nhwgc_gkyxc_nhwgk_f32_pad0_pipev5_instance.cpp rename to library/src/tensor_operation_instance/gpu/grouped_conv2d_bwd_weight/xdl/nhwgc_gkyxc_nhwgk/device_grouped_conv2d_bwd_weight_xdl_nhwgc_gkyxc_nhwgk_f32_pad0_pipev5_instance.cpp diff --git a/library/src/tensor_operation_instance/gpu/grouped_conv3d_bwd_weight/CMakeLists.txt b/library/src/tensor_operation_instance/gpu/grouped_conv3d_bwd_weight/CMakeLists.txt index 860e08cafe..1b0d2dd0b2 100644 --- a/library/src/tensor_operation_instance/gpu/grouped_conv3d_bwd_weight/CMakeLists.txt +++ 
b/library/src/tensor_operation_instance/gpu/grouped_conv3d_bwd_weight/CMakeLists.txt @@ -1,43 +1,49 @@ # XDL_DL_WMMA_KERNELS set(GROUPED_CONV3D_BWD_WEIGHT - xdl/device_grouped_conv3d_bwd_weight_xdl_gndhwc_gkzyxc_gndhwk_f16_instance.cpp - xdl/device_grouped_conv3d_bwd_weight_xdl_gndhwc_gkzyxc_gndhwk_f32_instance.cpp - xdl/device_grouped_conv3d_bwd_weight_xdl_gndhwc_gkzyxc_gndhwk_bf16_f32_bf16_instance.cpp - xdl/device_grouped_conv3d_bwd_weight_xdl_ndhwgc_gkzyxc_ndhwgk_f16_instance.cpp - xdl/device_grouped_conv3d_bwd_weight_xdl_ndhwgc_gkzyxc_ndhwgk_f32_instance.cpp - xdl/device_grouped_conv3d_bwd_weight_xdl_ndhwgc_gkzyxc_ndhwgk_bf16_f32_bf16_instance.cpp - xdl/device_grouped_conv3d_bwd_weight_xdl_ndhwgc_gkzyxc_ndhwgk_bf16_instance.cpp - xdl/device_grouped_conv3d_bwd_weight_xdl_ngcdhw_gkzyxc_ngkdhw_f16_instance.cpp - xdl/device_grouped_conv3d_bwd_weight_xdl_ngcdhw_gkzyxc_ngkdhw_f32_instance.cpp - xdl/device_grouped_conv3d_bwd_weight_xdl_ngcdhw_gkzyxc_ngkdhw_bf16_instance.cpp - xdl/device_grouped_conv3d_bwd_weight_xdl_ndhwgc_gkzyxc_ndhwgk_bf16_default_pipev2_instance.cpp - xdl/device_grouped_conv3d_bwd_weight_xdl_ndhwgc_gkzyxc_ndhwgk_bf16_default_pipev5_instance.cpp - xdl/device_grouped_conv3d_bwd_weight_xdl_ndhwgc_gkzyxc_ndhwgk_bf16_pad0_pipev2_instance.cpp - xdl/device_grouped_conv3d_bwd_weight_xdl_ndhwgc_gkzyxc_ndhwgk_bf16_pad0_pipev5_instance.cpp - xdl/device_grouped_conv3d_bwd_weight_xdl_ndhwgc_gkzyxc_ndhwgk_f16_default_pipev2_instance.cpp - xdl/device_grouped_conv3d_bwd_weight_xdl_ndhwgc_gkzyxc_ndhwgk_f16_default_pipev5_instance.cpp - xdl/device_grouped_conv3d_bwd_weight_xdl_ndhwgc_gkzyxc_ndhwgk_f16_pad0_pipev2_instance.cpp - xdl/device_grouped_conv3d_bwd_weight_xdl_ndhwgc_gkzyxc_ndhwgk_f16_pad0_pipev5_instance.cpp - xdl/device_grouped_conv3d_bwd_weight_xdl_ndhwgc_gkzyxc_ndhwgk_f32_default_pipev2_instance.cpp - xdl/device_grouped_conv3d_bwd_weight_xdl_ndhwgc_gkzyxc_ndhwgk_f32_default_pipev5_instance.cpp - 
xdl/device_grouped_conv3d_bwd_weight_xdl_ndhwgc_gkzyxc_ndhwgk_f32_pad0_pipev2_instance.cpp - xdl/device_grouped_conv3d_bwd_weight_xdl_ndhwgc_gkzyxc_ndhwgk_f32_pad0_pipev5_instance.cpp - xdl/device_grouped_conv3d_bwd_weight_two_stage_xdl_ndhwgc_gkzyxc_ndhwgk_f16_pipev2_instance.cpp - xdl/device_grouped_conv3d_bwd_weight_two_stage_xdl_ndhwgc_gkzyxc_ndhwgk_f16_pipev5_instance.cpp - xdl/device_grouped_conv3d_bwd_weight_two_stage_xdl_ngcdhw_gkzyxc_ngkdhw_f16_pipev2_instance.cpp - xdl/device_grouped_conv3d_bwd_weight_two_stage_xdl_ngcdhw_gkzyxc_ngkdhw_f16_pipev5_instance.cpp - xdl/device_grouped_conv3d_bwd_weight_two_stage_xdl_ndhwgc_gkzyxc_ndhwgk_bf16_pipev2_instance.cpp - xdl/device_grouped_conv3d_bwd_weight_two_stage_xdl_ndhwgc_gkzyxc_ndhwgk_bf16_pipev5_instance.cpp - xdl/device_grouped_conv3d_bwd_weight_two_stage_xdl_ngcdhw_gkzyxc_ngkdhw_bf16_pipev2_instance.cpp - xdl/device_grouped_conv3d_bwd_weight_two_stage_xdl_ngcdhw_gkzyxc_ngkdhw_bf16_pipev5_instance.cpp - xdl/device_grouped_conv3d_bwd_weight_two_stage_xdl_ndhwgc_gkzyxc_ndhwgk_f16_pipev1_instance.cpp - xdl/device_grouped_conv3d_bwd_weight_two_stage_xdl_ngcdhw_gkzyxc_ngkdhw_f16_pipev1_instance.cpp - xdl/device_grouped_conv3d_bwd_weight_two_stage_xdl_ndhwgc_gkzyxc_ndhwgk_bf16_pipev1_instance.cpp - xdl/device_grouped_conv3d_bwd_weight_two_stage_xdl_ngcdhw_gkzyxc_ngkdhw_bf16_pipev1_instance.cpp - xdl/device_grouped_conv3d_bwd_weight_two_stage_xdl_ndhwgc_gkzyxc_ndhwgk_f16_pipev2_irregular_instance.cpp - xdl/device_grouped_conv3d_bwd_weight_two_stage_xdl_ndhwgc_gkzyxc_ndhwgk_f16_pipev5_irregular_instance.cpp - xdl/device_grouped_conv3d_bwd_weight_two_stage_xdl_ndhwgc_gkzyxc_ndhwgk_bf16_pipev2_irregular_instance.cpp - xdl/device_grouped_conv3d_bwd_weight_two_stage_xdl_ndhwgc_gkzyxc_ndhwgk_bf16_pipev5_irregular_instance.cpp + xdl/gndhwc_gkzyxc_gndhwk/device_grouped_conv3d_bwd_weight_xdl_gndhwc_gkzyxc_gndhwk_f16_instance.cpp + 
xdl/gndhwc_gkzyxc_gndhwk/device_grouped_conv3d_bwd_weight_xdl_gndhwc_gkzyxc_gndhwk_f32_instance.cpp + xdl/gndhwc_gkzyxc_gndhwk/device_grouped_conv3d_bwd_weight_xdl_gndhwc_gkzyxc_gndhwk_bf16_f32_bf16_instance.cpp + + xdl/ndhwgc_gkzyxc_ndhwgk/device_grouped_conv3d_bwd_weight_xdl_ndhwgc_gkzyxc_ndhwgk_f16_instance.cpp + xdl/ndhwgc_gkzyxc_ndhwgk/device_grouped_conv3d_bwd_weight_xdl_ndhwgc_gkzyxc_ndhwgk_f32_instance.cpp + xdl/ndhwgc_gkzyxc_ndhwgk/device_grouped_conv3d_bwd_weight_xdl_ndhwgc_gkzyxc_ndhwgk_bf16_f32_bf16_instance.cpp + xdl/ndhwgc_gkzyxc_ndhwgk/device_grouped_conv3d_bwd_weight_xdl_ndhwgc_gkzyxc_ndhwgk_bf16_instance.cpp + xdl/ndhwgc_gkzyxc_ndhwgk/device_grouped_conv3d_bwd_weight_xdl_ndhwgc_gkzyxc_ndhwgk_bf16_default_pipev2_instance.cpp + xdl/ndhwgc_gkzyxc_ndhwgk/device_grouped_conv3d_bwd_weight_xdl_ndhwgc_gkzyxc_ndhwgk_bf16_default_pipev5_instance.cpp + xdl/ndhwgc_gkzyxc_ndhwgk/device_grouped_conv3d_bwd_weight_xdl_ndhwgc_gkzyxc_ndhwgk_bf16_pad0_pipev2_instance.cpp + xdl/ndhwgc_gkzyxc_ndhwgk/device_grouped_conv3d_bwd_weight_xdl_ndhwgc_gkzyxc_ndhwgk_bf16_pad0_pipev5_instance.cpp + xdl/ndhwgc_gkzyxc_ndhwgk/device_grouped_conv3d_bwd_weight_xdl_ndhwgc_gkzyxc_ndhwgk_f16_default_pipev2_instance.cpp + xdl/ndhwgc_gkzyxc_ndhwgk/device_grouped_conv3d_bwd_weight_xdl_ndhwgc_gkzyxc_ndhwgk_f16_default_pipev5_instance.cpp + xdl/ndhwgc_gkzyxc_ndhwgk/device_grouped_conv3d_bwd_weight_xdl_ndhwgc_gkzyxc_ndhwgk_f16_pad0_pipev2_instance.cpp + xdl/ndhwgc_gkzyxc_ndhwgk/device_grouped_conv3d_bwd_weight_xdl_ndhwgc_gkzyxc_ndhwgk_f16_pad0_pipev5_instance.cpp + xdl/ndhwgc_gkzyxc_ndhwgk/device_grouped_conv3d_bwd_weight_xdl_ndhwgc_gkzyxc_ndhwgk_f32_default_pipev2_instance.cpp + xdl/ndhwgc_gkzyxc_ndhwgk/device_grouped_conv3d_bwd_weight_xdl_ndhwgc_gkzyxc_ndhwgk_f32_default_pipev5_instance.cpp + xdl/ndhwgc_gkzyxc_ndhwgk/device_grouped_conv3d_bwd_weight_xdl_ndhwgc_gkzyxc_ndhwgk_f32_pad0_pipev2_instance.cpp + 
xdl/ndhwgc_gkzyxc_ndhwgk/device_grouped_conv3d_bwd_weight_xdl_ndhwgc_gkzyxc_ndhwgk_f32_pad0_pipev5_instance.cpp + xdl/ndhwgc_gkzyxc_ndhwgk/device_grouped_conv3d_bwd_weight_two_stage_xdl_ndhwgc_gkzyxc_ndhwgk_f16_pipev2_instance.cpp + xdl/ndhwgc_gkzyxc_ndhwgk/device_grouped_conv3d_bwd_weight_two_stage_xdl_ndhwgc_gkzyxc_ndhwgk_f16_pipev5_instance.cpp + xdl/ndhwgc_gkzyxc_ndhwgk/device_grouped_conv3d_bwd_weight_two_stage_xdl_ndhwgc_gkzyxc_ndhwgk_bf16_pipev2_instance.cpp + xdl/ndhwgc_gkzyxc_ndhwgk/device_grouped_conv3d_bwd_weight_two_stage_xdl_ndhwgc_gkzyxc_ndhwgk_bf16_pipev5_instance.cpp + xdl/ndhwgc_gkzyxc_ndhwgk/device_grouped_conv3d_bwd_weight_two_stage_xdl_ndhwgc_gkzyxc_ndhwgk_f16_pipev1_instance.cpp + xdl/ndhwgc_gkzyxc_ndhwgk/device_grouped_conv3d_bwd_weight_two_stage_xdl_ndhwgc_gkzyxc_ndhwgk_bf16_pipev1_instance.cpp + xdl/ndhwgc_gkzyxc_ndhwgk/device_grouped_conv3d_bwd_weight_two_stage_xdl_ndhwgc_gkzyxc_ndhwgk_f16_pipev2_irregular_instance.cpp + xdl/ndhwgc_gkzyxc_ndhwgk/device_grouped_conv3d_bwd_weight_two_stage_xdl_ndhwgc_gkzyxc_ndhwgk_f16_pipev5_irregular_instance.cpp + xdl/ndhwgc_gkzyxc_ndhwgk/device_grouped_conv3d_bwd_weight_two_stage_xdl_ndhwgc_gkzyxc_ndhwgk_bf16_pipev2_irregular_instance.cpp + xdl/ndhwgc_gkzyxc_ndhwgk/device_grouped_conv3d_bwd_weight_two_stage_xdl_ndhwgc_gkzyxc_ndhwgk_bf16_pipev5_irregular_instance.cpp + + xdl/ngcdhw_gkzyxc_ngkdhw/device_grouped_conv3d_bwd_weight_xdl_ngcdhw_gkzyxc_ngkdhw_f32_instance.cpp + xdl/ngcdhw_gkzyxc_ngkdhw/device_grouped_conv3d_bwd_weight_two_stage_xdl_ngcdhw_gkzyxc_ngkdhw_f16_pipev1_instance.cpp + xdl/ngcdhw_gkzyxc_ngkdhw/device_grouped_conv3d_bwd_weight_two_stage_xdl_ngcdhw_gkzyxc_ngkdhw_bf16_pipev1_instance.cpp + + xdl/ngcdhw_gkczyx_ngkdhw/device_grouped_conv3d_bwd_weight_xdl_ngcdhw_gkczyx_ngkdhw_f16_instance.cpp + xdl/ngcdhw_gkczyx_ngkdhw/device_grouped_conv3d_bwd_weight_xdl_ngcdhw_gkczyx_ngkdhw_f32_instance.cpp + xdl/ngcdhw_gkczyx_ngkdhw/device_grouped_conv3d_bwd_weight_xdl_ngcdhw_gkczyx_ngkdhw_bf16_instance.cpp 
+ xdl/ngcdhw_gkczyx_ngkdhw/device_grouped_conv3d_bwd_weight_two_stage_xdl_ngcdhw_gkczyx_ngkdhw_f16_pipev2_instance.cpp + xdl/ngcdhw_gkczyx_ngkdhw/device_grouped_conv3d_bwd_weight_two_stage_xdl_ngcdhw_gkczyx_ngkdhw_f16_pipev5_instance.cpp + xdl/ngcdhw_gkczyx_ngkdhw/device_grouped_conv3d_bwd_weight_two_stage_xdl_ngcdhw_gkczyx_ngkdhw_bf16_pipev2_instance.cpp + xdl/ngcdhw_gkczyx_ngkdhw/device_grouped_conv3d_bwd_weight_two_stage_xdl_ngcdhw_gkczyx_ngkdhw_bf16_pipev5_instance.cpp + xdl/ngcdhw_gkczyx_ngkdhw/device_grouped_conv3d_bwd_weight_two_stage_xdl_ngcdhw_gkczyx_ngkdhw_f16_pipev1_instance.cpp + xdl/ngcdhw_gkczyx_ngkdhw/device_grouped_conv3d_bwd_weight_two_stage_xdl_ngcdhw_gkczyx_ngkdhw_bf16_pipev1_instance.cpp ) if(DL_KERNELS) @@ -62,7 +68,7 @@ list(APPEND GROUPED_CONV3D_BWD_WEIGHT if((DTYPES MATCHES "fp8" AND DTYPES MATCHES "bf8" AND DTYPES MATCHES "fp16") OR NOT DEFINED DTYPES) list(APPEND GROUPED_CONV3D_BWD_WEIGHT - xdl/device_grouped_conv3d_bwd_weight_xdl_ndhwgc_gkzyxc_ndhwgk_f16_comp_bf8_fp8_instance.cpp) + xdl/ndhwgc_gkzyxc_ndhwgk/device_grouped_conv3d_bwd_weight_xdl_ndhwgc_gkzyxc_ndhwgk_f16_comp_bf8_fp8_instance.cpp) endif() add_instance_library(device_grouped_conv3d_bwd_weight_instance ${GROUPED_CONV3D_BWD_WEIGHT}) diff --git a/library/src/tensor_operation_instance/gpu/grouped_conv3d_bwd_weight/xdl/device_grouped_conv3d_bwd_weight_xdl_gndhwc_gkzyxc_gndhwk_bf16_f32_bf16_instance.cpp b/library/src/tensor_operation_instance/gpu/grouped_conv3d_bwd_weight/xdl/gndhwc_gkzyxc_gndhwk/device_grouped_conv3d_bwd_weight_xdl_gndhwc_gkzyxc_gndhwk_bf16_f32_bf16_instance.cpp similarity index 100% rename from library/src/tensor_operation_instance/gpu/grouped_conv3d_bwd_weight/xdl/device_grouped_conv3d_bwd_weight_xdl_gndhwc_gkzyxc_gndhwk_bf16_f32_bf16_instance.cpp rename to library/src/tensor_operation_instance/gpu/grouped_conv3d_bwd_weight/xdl/gndhwc_gkzyxc_gndhwk/device_grouped_conv3d_bwd_weight_xdl_gndhwc_gkzyxc_gndhwk_bf16_f32_bf16_instance.cpp diff --git 
a/library/src/tensor_operation_instance/gpu/grouped_conv3d_bwd_weight/xdl/device_grouped_conv3d_bwd_weight_xdl_gndhwc_gkzyxc_gndhwk_f16_instance.cpp b/library/src/tensor_operation_instance/gpu/grouped_conv3d_bwd_weight/xdl/gndhwc_gkzyxc_gndhwk/device_grouped_conv3d_bwd_weight_xdl_gndhwc_gkzyxc_gndhwk_f16_instance.cpp similarity index 100% rename from library/src/tensor_operation_instance/gpu/grouped_conv3d_bwd_weight/xdl/device_grouped_conv3d_bwd_weight_xdl_gndhwc_gkzyxc_gndhwk_f16_instance.cpp rename to library/src/tensor_operation_instance/gpu/grouped_conv3d_bwd_weight/xdl/gndhwc_gkzyxc_gndhwk/device_grouped_conv3d_bwd_weight_xdl_gndhwc_gkzyxc_gndhwk_f16_instance.cpp diff --git a/library/src/tensor_operation_instance/gpu/grouped_conv3d_bwd_weight/xdl/device_grouped_conv3d_bwd_weight_xdl_gndhwc_gkzyxc_gndhwk_f32_instance.cpp b/library/src/tensor_operation_instance/gpu/grouped_conv3d_bwd_weight/xdl/gndhwc_gkzyxc_gndhwk/device_grouped_conv3d_bwd_weight_xdl_gndhwc_gkzyxc_gndhwk_f32_instance.cpp similarity index 100% rename from library/src/tensor_operation_instance/gpu/grouped_conv3d_bwd_weight/xdl/device_grouped_conv3d_bwd_weight_xdl_gndhwc_gkzyxc_gndhwk_f32_instance.cpp rename to library/src/tensor_operation_instance/gpu/grouped_conv3d_bwd_weight/xdl/gndhwc_gkzyxc_gndhwk/device_grouped_conv3d_bwd_weight_xdl_gndhwc_gkzyxc_gndhwk_f32_instance.cpp diff --git a/library/src/tensor_operation_instance/gpu/grouped_conv3d_bwd_weight/xdl/device_grouped_conv3d_bwd_weight_two_stage_xdl_ndhwgc_gkzyxc_ndhwgk_bf16_pipev1_instance.cpp b/library/src/tensor_operation_instance/gpu/grouped_conv3d_bwd_weight/xdl/ndhwgc_gkzyxc_ndhwgk/device_grouped_conv3d_bwd_weight_two_stage_xdl_ndhwgc_gkzyxc_ndhwgk_bf16_pipev1_instance.cpp similarity index 95% rename from library/src/tensor_operation_instance/gpu/grouped_conv3d_bwd_weight/xdl/device_grouped_conv3d_bwd_weight_two_stage_xdl_ndhwgc_gkzyxc_ndhwgk_bf16_pipev1_instance.cpp rename to 
library/src/tensor_operation_instance/gpu/grouped_conv3d_bwd_weight/xdl/ndhwgc_gkzyxc_ndhwgk/device_grouped_conv3d_bwd_weight_two_stage_xdl_ndhwgc_gkzyxc_ndhwgk_bf16_pipev1_instance.cpp index 63249a1c13..4c4589d128 100644 --- a/library/src/tensor_operation_instance/gpu/grouped_conv3d_bwd_weight/xdl/device_grouped_conv3d_bwd_weight_two_stage_xdl_ndhwgc_gkzyxc_ndhwgk_bf16_pipev1_instance.cpp +++ b/library/src/tensor_operation_instance/gpu/grouped_conv3d_bwd_weight/xdl/ndhwgc_gkzyxc_ndhwgk/device_grouped_conv3d_bwd_weight_two_stage_xdl_ndhwgc_gkzyxc_ndhwgk_bf16_pipev1_instance.cpp @@ -1,5 +1,5 @@ // SPDX-License-Identifier: MIT -// Copyright (c) 2024, Advanced Micro Devices, Inc. All rights reserved. +// Copyright (c) 2024-2025, Advanced Micro Devices, Inc. All rights reserved. #include "ck/library/tensor_operation_instance/add_device_operation_instance.hpp" #include "ck/library/tensor_operation_instance/gpu/grouped_conv_bwd_weight/device_grouped_conv_bwd_weight_two_stage_xdl_instance.hpp" diff --git a/library/src/tensor_operation_instance/gpu/grouped_conv3d_bwd_weight/xdl/device_grouped_conv3d_bwd_weight_two_stage_xdl_ndhwgc_gkzyxc_ndhwgk_bf16_pipev2_instance.cpp b/library/src/tensor_operation_instance/gpu/grouped_conv3d_bwd_weight/xdl/ndhwgc_gkzyxc_ndhwgk/device_grouped_conv3d_bwd_weight_two_stage_xdl_ndhwgc_gkzyxc_ndhwgk_bf16_pipev2_instance.cpp similarity index 95% rename from library/src/tensor_operation_instance/gpu/grouped_conv3d_bwd_weight/xdl/device_grouped_conv3d_bwd_weight_two_stage_xdl_ndhwgc_gkzyxc_ndhwgk_bf16_pipev2_instance.cpp rename to library/src/tensor_operation_instance/gpu/grouped_conv3d_bwd_weight/xdl/ndhwgc_gkzyxc_ndhwgk/device_grouped_conv3d_bwd_weight_two_stage_xdl_ndhwgc_gkzyxc_ndhwgk_bf16_pipev2_instance.cpp index 7841ddad99..b6d8c7f635 100644 --- a/library/src/tensor_operation_instance/gpu/grouped_conv3d_bwd_weight/xdl/device_grouped_conv3d_bwd_weight_two_stage_xdl_ndhwgc_gkzyxc_ndhwgk_bf16_pipev2_instance.cpp +++ 
b/library/src/tensor_operation_instance/gpu/grouped_conv3d_bwd_weight/xdl/ndhwgc_gkzyxc_ndhwgk/device_grouped_conv3d_bwd_weight_two_stage_xdl_ndhwgc_gkzyxc_ndhwgk_bf16_pipev2_instance.cpp @@ -1,5 +1,5 @@ // SPDX-License-Identifier: MIT -// Copyright (c) 2024, Advanced Micro Devices, Inc. All rights reserved. +// Copyright (c) 2024-2025, Advanced Micro Devices, Inc. All rights reserved. #include "ck/library/tensor_operation_instance/add_device_operation_instance.hpp" #include "ck/library/tensor_operation_instance/gpu/grouped_conv_bwd_weight/device_grouped_conv_bwd_weight_two_stage_xdl_instance.hpp" diff --git a/library/src/tensor_operation_instance/gpu/grouped_conv3d_bwd_weight/xdl/device_grouped_conv3d_bwd_weight_two_stage_xdl_ndhwgc_gkzyxc_ndhwgk_bf16_pipev2_irregular_instance.cpp b/library/src/tensor_operation_instance/gpu/grouped_conv3d_bwd_weight/xdl/ndhwgc_gkzyxc_ndhwgk/device_grouped_conv3d_bwd_weight_two_stage_xdl_ndhwgc_gkzyxc_ndhwgk_bf16_pipev2_irregular_instance.cpp similarity index 100% rename from library/src/tensor_operation_instance/gpu/grouped_conv3d_bwd_weight/xdl/device_grouped_conv3d_bwd_weight_two_stage_xdl_ndhwgc_gkzyxc_ndhwgk_bf16_pipev2_irregular_instance.cpp rename to library/src/tensor_operation_instance/gpu/grouped_conv3d_bwd_weight/xdl/ndhwgc_gkzyxc_ndhwgk/device_grouped_conv3d_bwd_weight_two_stage_xdl_ndhwgc_gkzyxc_ndhwgk_bf16_pipev2_irregular_instance.cpp diff --git a/library/src/tensor_operation_instance/gpu/grouped_conv3d_bwd_weight/xdl/device_grouped_conv3d_bwd_weight_two_stage_xdl_ndhwgc_gkzyxc_ndhwgk_bf16_pipev5_instance.cpp b/library/src/tensor_operation_instance/gpu/grouped_conv3d_bwd_weight/xdl/ndhwgc_gkzyxc_ndhwgk/device_grouped_conv3d_bwd_weight_two_stage_xdl_ndhwgc_gkzyxc_ndhwgk_bf16_pipev5_instance.cpp similarity index 95% rename from library/src/tensor_operation_instance/gpu/grouped_conv3d_bwd_weight/xdl/device_grouped_conv3d_bwd_weight_two_stage_xdl_ndhwgc_gkzyxc_ndhwgk_bf16_pipev5_instance.cpp rename to 
library/src/tensor_operation_instance/gpu/grouped_conv3d_bwd_weight/xdl/ndhwgc_gkzyxc_ndhwgk/device_grouped_conv3d_bwd_weight_two_stage_xdl_ndhwgc_gkzyxc_ndhwgk_bf16_pipev5_instance.cpp index ba6285a380..5b295e728b 100644 --- a/library/src/tensor_operation_instance/gpu/grouped_conv3d_bwd_weight/xdl/device_grouped_conv3d_bwd_weight_two_stage_xdl_ndhwgc_gkzyxc_ndhwgk_bf16_pipev5_instance.cpp +++ b/library/src/tensor_operation_instance/gpu/grouped_conv3d_bwd_weight/xdl/ndhwgc_gkzyxc_ndhwgk/device_grouped_conv3d_bwd_weight_two_stage_xdl_ndhwgc_gkzyxc_ndhwgk_bf16_pipev5_instance.cpp @@ -1,5 +1,5 @@ // SPDX-License-Identifier: MIT -// Copyright (c) 2024, Advanced Micro Devices, Inc. All rights reserved. +// Copyright (c) 2024-2025, Advanced Micro Devices, Inc. All rights reserved. #include "ck/library/tensor_operation_instance/add_device_operation_instance.hpp" #include "ck/library/tensor_operation_instance/gpu/grouped_conv_bwd_weight/device_grouped_conv_bwd_weight_two_stage_xdl_instance.hpp" diff --git a/library/src/tensor_operation_instance/gpu/grouped_conv3d_bwd_weight/xdl/device_grouped_conv3d_bwd_weight_two_stage_xdl_ndhwgc_gkzyxc_ndhwgk_bf16_pipev5_irregular_instance.cpp b/library/src/tensor_operation_instance/gpu/grouped_conv3d_bwd_weight/xdl/ndhwgc_gkzyxc_ndhwgk/device_grouped_conv3d_bwd_weight_two_stage_xdl_ndhwgc_gkzyxc_ndhwgk_bf16_pipev5_irregular_instance.cpp similarity index 100% rename from library/src/tensor_operation_instance/gpu/grouped_conv3d_bwd_weight/xdl/device_grouped_conv3d_bwd_weight_two_stage_xdl_ndhwgc_gkzyxc_ndhwgk_bf16_pipev5_irregular_instance.cpp rename to library/src/tensor_operation_instance/gpu/grouped_conv3d_bwd_weight/xdl/ndhwgc_gkzyxc_ndhwgk/device_grouped_conv3d_bwd_weight_two_stage_xdl_ndhwgc_gkzyxc_ndhwgk_bf16_pipev5_irregular_instance.cpp diff --git a/library/src/tensor_operation_instance/gpu/grouped_conv3d_bwd_weight/xdl/device_grouped_conv3d_bwd_weight_two_stage_xdl_ndhwgc_gkzyxc_ndhwgk_f16_pipev1_instance.cpp 
b/library/src/tensor_operation_instance/gpu/grouped_conv3d_bwd_weight/xdl/ndhwgc_gkzyxc_ndhwgk/device_grouped_conv3d_bwd_weight_two_stage_xdl_ndhwgc_gkzyxc_ndhwgk_f16_pipev1_instance.cpp similarity index 95% rename from library/src/tensor_operation_instance/gpu/grouped_conv3d_bwd_weight/xdl/device_grouped_conv3d_bwd_weight_two_stage_xdl_ndhwgc_gkzyxc_ndhwgk_f16_pipev1_instance.cpp rename to library/src/tensor_operation_instance/gpu/grouped_conv3d_bwd_weight/xdl/ndhwgc_gkzyxc_ndhwgk/device_grouped_conv3d_bwd_weight_two_stage_xdl_ndhwgc_gkzyxc_ndhwgk_f16_pipev1_instance.cpp index a8fbefb5bd..125b324985 100644 --- a/library/src/tensor_operation_instance/gpu/grouped_conv3d_bwd_weight/xdl/device_grouped_conv3d_bwd_weight_two_stage_xdl_ndhwgc_gkzyxc_ndhwgk_f16_pipev1_instance.cpp +++ b/library/src/tensor_operation_instance/gpu/grouped_conv3d_bwd_weight/xdl/ndhwgc_gkzyxc_ndhwgk/device_grouped_conv3d_bwd_weight_two_stage_xdl_ndhwgc_gkzyxc_ndhwgk_f16_pipev1_instance.cpp @@ -1,5 +1,5 @@ // SPDX-License-Identifier: MIT -// Copyright (c) 2024, Advanced Micro Devices, Inc. All rights reserved. +// Copyright (c) 2024-2025, Advanced Micro Devices, Inc. All rights reserved. 
#include "ck/library/tensor_operation_instance/add_device_operation_instance.hpp" #include "ck/library/tensor_operation_instance/gpu/grouped_conv_bwd_weight/device_grouped_conv_bwd_weight_two_stage_xdl_instance.hpp" diff --git a/library/src/tensor_operation_instance/gpu/grouped_conv3d_bwd_weight/xdl/device_grouped_conv3d_bwd_weight_two_stage_xdl_ndhwgc_gkzyxc_ndhwgk_f16_pipev2_instance.cpp b/library/src/tensor_operation_instance/gpu/grouped_conv3d_bwd_weight/xdl/ndhwgc_gkzyxc_ndhwgk/device_grouped_conv3d_bwd_weight_two_stage_xdl_ndhwgc_gkzyxc_ndhwgk_f16_pipev2_instance.cpp similarity index 95% rename from library/src/tensor_operation_instance/gpu/grouped_conv3d_bwd_weight/xdl/device_grouped_conv3d_bwd_weight_two_stage_xdl_ndhwgc_gkzyxc_ndhwgk_f16_pipev2_instance.cpp rename to library/src/tensor_operation_instance/gpu/grouped_conv3d_bwd_weight/xdl/ndhwgc_gkzyxc_ndhwgk/device_grouped_conv3d_bwd_weight_two_stage_xdl_ndhwgc_gkzyxc_ndhwgk_f16_pipev2_instance.cpp index e4baafc0be..beb937f185 100644 --- a/library/src/tensor_operation_instance/gpu/grouped_conv3d_bwd_weight/xdl/device_grouped_conv3d_bwd_weight_two_stage_xdl_ndhwgc_gkzyxc_ndhwgk_f16_pipev2_instance.cpp +++ b/library/src/tensor_operation_instance/gpu/grouped_conv3d_bwd_weight/xdl/ndhwgc_gkzyxc_ndhwgk/device_grouped_conv3d_bwd_weight_two_stage_xdl_ndhwgc_gkzyxc_ndhwgk_f16_pipev2_instance.cpp @@ -1,5 +1,5 @@ // SPDX-License-Identifier: MIT -// Copyright (c) 2024, Advanced Micro Devices, Inc. All rights reserved. +// Copyright (c) 2024-2025, Advanced Micro Devices, Inc. All rights reserved. 
#include "ck/library/tensor_operation_instance/add_device_operation_instance.hpp" #include "ck/library/tensor_operation_instance/gpu/grouped_conv_bwd_weight/device_grouped_conv_bwd_weight_two_stage_xdl_instance.hpp" diff --git a/library/src/tensor_operation_instance/gpu/grouped_conv3d_bwd_weight/xdl/device_grouped_conv3d_bwd_weight_two_stage_xdl_ndhwgc_gkzyxc_ndhwgk_f16_pipev2_irregular_instance.cpp b/library/src/tensor_operation_instance/gpu/grouped_conv3d_bwd_weight/xdl/ndhwgc_gkzyxc_ndhwgk/device_grouped_conv3d_bwd_weight_two_stage_xdl_ndhwgc_gkzyxc_ndhwgk_f16_pipev2_irregular_instance.cpp similarity index 100% rename from library/src/tensor_operation_instance/gpu/grouped_conv3d_bwd_weight/xdl/device_grouped_conv3d_bwd_weight_two_stage_xdl_ndhwgc_gkzyxc_ndhwgk_f16_pipev2_irregular_instance.cpp rename to library/src/tensor_operation_instance/gpu/grouped_conv3d_bwd_weight/xdl/ndhwgc_gkzyxc_ndhwgk/device_grouped_conv3d_bwd_weight_two_stage_xdl_ndhwgc_gkzyxc_ndhwgk_f16_pipev2_irregular_instance.cpp diff --git a/library/src/tensor_operation_instance/gpu/grouped_conv3d_bwd_weight/xdl/device_grouped_conv3d_bwd_weight_two_stage_xdl_ndhwgc_gkzyxc_ndhwgk_f16_pipev5_instance.cpp b/library/src/tensor_operation_instance/gpu/grouped_conv3d_bwd_weight/xdl/ndhwgc_gkzyxc_ndhwgk/device_grouped_conv3d_bwd_weight_two_stage_xdl_ndhwgc_gkzyxc_ndhwgk_f16_pipev5_instance.cpp similarity index 95% rename from library/src/tensor_operation_instance/gpu/grouped_conv3d_bwd_weight/xdl/device_grouped_conv3d_bwd_weight_two_stage_xdl_ndhwgc_gkzyxc_ndhwgk_f16_pipev5_instance.cpp rename to library/src/tensor_operation_instance/gpu/grouped_conv3d_bwd_weight/xdl/ndhwgc_gkzyxc_ndhwgk/device_grouped_conv3d_bwd_weight_two_stage_xdl_ndhwgc_gkzyxc_ndhwgk_f16_pipev5_instance.cpp index f9bc5b1349..5274ff74a0 100644 --- a/library/src/tensor_operation_instance/gpu/grouped_conv3d_bwd_weight/xdl/device_grouped_conv3d_bwd_weight_two_stage_xdl_ndhwgc_gkzyxc_ndhwgk_f16_pipev5_instance.cpp +++ 
b/library/src/tensor_operation_instance/gpu/grouped_conv3d_bwd_weight/xdl/ndhwgc_gkzyxc_ndhwgk/device_grouped_conv3d_bwd_weight_two_stage_xdl_ndhwgc_gkzyxc_ndhwgk_f16_pipev5_instance.cpp @@ -1,5 +1,5 @@ // SPDX-License-Identifier: MIT -// Copyright (c) 2024, Advanced Micro Devices, Inc. All rights reserved. +// Copyright (c) 2024-2025, Advanced Micro Devices, Inc. All rights reserved. #include "ck/library/tensor_operation_instance/add_device_operation_instance.hpp" #include "ck/library/tensor_operation_instance/gpu/grouped_conv_bwd_weight/device_grouped_conv_bwd_weight_two_stage_xdl_instance.hpp" diff --git a/library/src/tensor_operation_instance/gpu/grouped_conv3d_bwd_weight/xdl/device_grouped_conv3d_bwd_weight_two_stage_xdl_ndhwgc_gkzyxc_ndhwgk_f16_pipev5_irregular_instance.cpp b/library/src/tensor_operation_instance/gpu/grouped_conv3d_bwd_weight/xdl/ndhwgc_gkzyxc_ndhwgk/device_grouped_conv3d_bwd_weight_two_stage_xdl_ndhwgc_gkzyxc_ndhwgk_f16_pipev5_irregular_instance.cpp similarity index 100% rename from library/src/tensor_operation_instance/gpu/grouped_conv3d_bwd_weight/xdl/device_grouped_conv3d_bwd_weight_two_stage_xdl_ndhwgc_gkzyxc_ndhwgk_f16_pipev5_irregular_instance.cpp rename to library/src/tensor_operation_instance/gpu/grouped_conv3d_bwd_weight/xdl/ndhwgc_gkzyxc_ndhwgk/device_grouped_conv3d_bwd_weight_two_stage_xdl_ndhwgc_gkzyxc_ndhwgk_f16_pipev5_irregular_instance.cpp diff --git a/library/src/tensor_operation_instance/gpu/grouped_conv3d_bwd_weight/xdl/device_grouped_conv3d_bwd_weight_xdl_ndhwgc_gkzyxc_ndhwgk_bf16_default_pipev2_instance.cpp b/library/src/tensor_operation_instance/gpu/grouped_conv3d_bwd_weight/xdl/ndhwgc_gkzyxc_ndhwgk/device_grouped_conv3d_bwd_weight_xdl_ndhwgc_gkzyxc_ndhwgk_bf16_default_pipev2_instance.cpp similarity index 100% rename from library/src/tensor_operation_instance/gpu/grouped_conv3d_bwd_weight/xdl/device_grouped_conv3d_bwd_weight_xdl_ndhwgc_gkzyxc_ndhwgk_bf16_default_pipev2_instance.cpp rename to 
library/src/tensor_operation_instance/gpu/grouped_conv3d_bwd_weight/xdl/ndhwgc_gkzyxc_ndhwgk/device_grouped_conv3d_bwd_weight_xdl_ndhwgc_gkzyxc_ndhwgk_bf16_default_pipev2_instance.cpp diff --git a/library/src/tensor_operation_instance/gpu/grouped_conv3d_bwd_weight/xdl/device_grouped_conv3d_bwd_weight_xdl_ndhwgc_gkzyxc_ndhwgk_bf16_default_pipev5_instance.cpp b/library/src/tensor_operation_instance/gpu/grouped_conv3d_bwd_weight/xdl/ndhwgc_gkzyxc_ndhwgk/device_grouped_conv3d_bwd_weight_xdl_ndhwgc_gkzyxc_ndhwgk_bf16_default_pipev5_instance.cpp similarity index 100% rename from library/src/tensor_operation_instance/gpu/grouped_conv3d_bwd_weight/xdl/device_grouped_conv3d_bwd_weight_xdl_ndhwgc_gkzyxc_ndhwgk_bf16_default_pipev5_instance.cpp rename to library/src/tensor_operation_instance/gpu/grouped_conv3d_bwd_weight/xdl/ndhwgc_gkzyxc_ndhwgk/device_grouped_conv3d_bwd_weight_xdl_ndhwgc_gkzyxc_ndhwgk_bf16_default_pipev5_instance.cpp diff --git a/library/src/tensor_operation_instance/gpu/grouped_conv3d_bwd_weight/xdl/device_grouped_conv3d_bwd_weight_xdl_ndhwgc_gkzyxc_ndhwgk_bf16_f32_bf16_instance.cpp b/library/src/tensor_operation_instance/gpu/grouped_conv3d_bwd_weight/xdl/ndhwgc_gkzyxc_ndhwgk/device_grouped_conv3d_bwd_weight_xdl_ndhwgc_gkzyxc_ndhwgk_bf16_f32_bf16_instance.cpp similarity index 95% rename from library/src/tensor_operation_instance/gpu/grouped_conv3d_bwd_weight/xdl/device_grouped_conv3d_bwd_weight_xdl_ndhwgc_gkzyxc_ndhwgk_bf16_f32_bf16_instance.cpp rename to library/src/tensor_operation_instance/gpu/grouped_conv3d_bwd_weight/xdl/ndhwgc_gkzyxc_ndhwgk/device_grouped_conv3d_bwd_weight_xdl_ndhwgc_gkzyxc_ndhwgk_bf16_f32_bf16_instance.cpp index 679f30a3d9..767e091b94 100644 --- a/library/src/tensor_operation_instance/gpu/grouped_conv3d_bwd_weight/xdl/device_grouped_conv3d_bwd_weight_xdl_ndhwgc_gkzyxc_ndhwgk_bf16_f32_bf16_instance.cpp +++ 
b/library/src/tensor_operation_instance/gpu/grouped_conv3d_bwd_weight/xdl/ndhwgc_gkzyxc_ndhwgk/device_grouped_conv3d_bwd_weight_xdl_ndhwgc_gkzyxc_ndhwgk_bf16_f32_bf16_instance.cpp @@ -1,5 +1,5 @@ // SPDX-License-Identifier: MIT -// Copyright (c) 2024, Advanced Micro Devices, Inc. All rights reserved. +// Copyright (c) 2024-2025, Advanced Micro Devices, Inc. All rights reserved. #include "ck/library/tensor_operation_instance/add_device_operation_instance.hpp" #include "ck/library/tensor_operation_instance/gpu/grouped_conv_bwd_weight/device_grouped_conv_bwd_weight_xdl_instance.hpp" diff --git a/library/src/tensor_operation_instance/gpu/grouped_conv3d_bwd_weight/xdl/device_grouped_conv3d_bwd_weight_xdl_ndhwgc_gkzyxc_ndhwgk_bf16_instance.cpp b/library/src/tensor_operation_instance/gpu/grouped_conv3d_bwd_weight/xdl/ndhwgc_gkzyxc_ndhwgk/device_grouped_conv3d_bwd_weight_xdl_ndhwgc_gkzyxc_ndhwgk_bf16_instance.cpp similarity index 96% rename from library/src/tensor_operation_instance/gpu/grouped_conv3d_bwd_weight/xdl/device_grouped_conv3d_bwd_weight_xdl_ndhwgc_gkzyxc_ndhwgk_bf16_instance.cpp rename to library/src/tensor_operation_instance/gpu/grouped_conv3d_bwd_weight/xdl/ndhwgc_gkzyxc_ndhwgk/device_grouped_conv3d_bwd_weight_xdl_ndhwgc_gkzyxc_ndhwgk_bf16_instance.cpp index f1ea371819..53011b4972 100644 --- a/library/src/tensor_operation_instance/gpu/grouped_conv3d_bwd_weight/xdl/device_grouped_conv3d_bwd_weight_xdl_ndhwgc_gkzyxc_ndhwgk_bf16_instance.cpp +++ b/library/src/tensor_operation_instance/gpu/grouped_conv3d_bwd_weight/xdl/ndhwgc_gkzyxc_ndhwgk/device_grouped_conv3d_bwd_weight_xdl_ndhwgc_gkzyxc_ndhwgk_bf16_instance.cpp @@ -1,5 +1,5 @@ // SPDX-License-Identifier: MIT -// Copyright (c) 2024, Advanced Micro Devices, Inc. All rights reserved. +// Copyright (c) 2024-2025, Advanced Micro Devices, Inc. All rights reserved. 
#include "ck/library/tensor_operation_instance/add_device_operation_instance.hpp" #include "ck/library/tensor_operation_instance/gpu/grouped_conv_bwd_weight/device_grouped_conv_bwd_weight_xdl_instance.hpp" diff --git a/library/src/tensor_operation_instance/gpu/grouped_conv3d_bwd_weight/xdl/device_grouped_conv3d_bwd_weight_xdl_ndhwgc_gkzyxc_ndhwgk_bf16_pad0_pipev2_instance.cpp b/library/src/tensor_operation_instance/gpu/grouped_conv3d_bwd_weight/xdl/ndhwgc_gkzyxc_ndhwgk/device_grouped_conv3d_bwd_weight_xdl_ndhwgc_gkzyxc_ndhwgk_bf16_pad0_pipev2_instance.cpp similarity index 100% rename from library/src/tensor_operation_instance/gpu/grouped_conv3d_bwd_weight/xdl/device_grouped_conv3d_bwd_weight_xdl_ndhwgc_gkzyxc_ndhwgk_bf16_pad0_pipev2_instance.cpp rename to library/src/tensor_operation_instance/gpu/grouped_conv3d_bwd_weight/xdl/ndhwgc_gkzyxc_ndhwgk/device_grouped_conv3d_bwd_weight_xdl_ndhwgc_gkzyxc_ndhwgk_bf16_pad0_pipev2_instance.cpp diff --git a/library/src/tensor_operation_instance/gpu/grouped_conv3d_bwd_weight/xdl/device_grouped_conv3d_bwd_weight_xdl_ndhwgc_gkzyxc_ndhwgk_bf16_pad0_pipev5_instance.cpp b/library/src/tensor_operation_instance/gpu/grouped_conv3d_bwd_weight/xdl/ndhwgc_gkzyxc_ndhwgk/device_grouped_conv3d_bwd_weight_xdl_ndhwgc_gkzyxc_ndhwgk_bf16_pad0_pipev5_instance.cpp similarity index 100% rename from library/src/tensor_operation_instance/gpu/grouped_conv3d_bwd_weight/xdl/device_grouped_conv3d_bwd_weight_xdl_ndhwgc_gkzyxc_ndhwgk_bf16_pad0_pipev5_instance.cpp rename to library/src/tensor_operation_instance/gpu/grouped_conv3d_bwd_weight/xdl/ndhwgc_gkzyxc_ndhwgk/device_grouped_conv3d_bwd_weight_xdl_ndhwgc_gkzyxc_ndhwgk_bf16_pad0_pipev5_instance.cpp diff --git a/library/src/tensor_operation_instance/gpu/grouped_conv3d_bwd_weight/xdl/device_grouped_conv3d_bwd_weight_xdl_ndhwgc_gkzyxc_ndhwgk_f16_comp_bf8_fp8_instance.cpp 
b/library/src/tensor_operation_instance/gpu/grouped_conv3d_bwd_weight/xdl/ndhwgc_gkzyxc_ndhwgk/device_grouped_conv3d_bwd_weight_xdl_ndhwgc_gkzyxc_ndhwgk_f16_comp_bf8_fp8_instance.cpp similarity index 97% rename from library/src/tensor_operation_instance/gpu/grouped_conv3d_bwd_weight/xdl/device_grouped_conv3d_bwd_weight_xdl_ndhwgc_gkzyxc_ndhwgk_f16_comp_bf8_fp8_instance.cpp rename to library/src/tensor_operation_instance/gpu/grouped_conv3d_bwd_weight/xdl/ndhwgc_gkzyxc_ndhwgk/device_grouped_conv3d_bwd_weight_xdl_ndhwgc_gkzyxc_ndhwgk_f16_comp_bf8_fp8_instance.cpp index 6e7f22b7e5..8a1e0b2008 100644 --- a/library/src/tensor_operation_instance/gpu/grouped_conv3d_bwd_weight/xdl/device_grouped_conv3d_bwd_weight_xdl_ndhwgc_gkzyxc_ndhwgk_f16_comp_bf8_fp8_instance.cpp +++ b/library/src/tensor_operation_instance/gpu/grouped_conv3d_bwd_weight/xdl/ndhwgc_gkzyxc_ndhwgk/device_grouped_conv3d_bwd_weight_xdl_ndhwgc_gkzyxc_ndhwgk_f16_comp_bf8_fp8_instance.cpp @@ -1,5 +1,5 @@ // SPDX-License-Identifier: MIT -// Copyright (c) 2018-2023, Advanced Micro Devices, Inc. All rights reserved. +// Copyright (c) 2018-2025, Advanced Micro Devices, Inc. All rights reserved. 
#include "ck/library/tensor_operation_instance/gpu/grouped_conv_bwd_weight/device_grouped_conv_bwd_weight_xdl_instance.hpp" #include "ck/library/tensor_operation_instance/add_device_operation_instance.hpp" diff --git a/library/src/tensor_operation_instance/gpu/grouped_conv3d_bwd_weight/xdl/device_grouped_conv3d_bwd_weight_xdl_ndhwgc_gkzyxc_ndhwgk_f16_default_pipev2_instance.cpp b/library/src/tensor_operation_instance/gpu/grouped_conv3d_bwd_weight/xdl/ndhwgc_gkzyxc_ndhwgk/device_grouped_conv3d_bwd_weight_xdl_ndhwgc_gkzyxc_ndhwgk_f16_default_pipev2_instance.cpp similarity index 100% rename from library/src/tensor_operation_instance/gpu/grouped_conv3d_bwd_weight/xdl/device_grouped_conv3d_bwd_weight_xdl_ndhwgc_gkzyxc_ndhwgk_f16_default_pipev2_instance.cpp rename to library/src/tensor_operation_instance/gpu/grouped_conv3d_bwd_weight/xdl/ndhwgc_gkzyxc_ndhwgk/device_grouped_conv3d_bwd_weight_xdl_ndhwgc_gkzyxc_ndhwgk_f16_default_pipev2_instance.cpp diff --git a/library/src/tensor_operation_instance/gpu/grouped_conv3d_bwd_weight/xdl/device_grouped_conv3d_bwd_weight_xdl_ndhwgc_gkzyxc_ndhwgk_f16_default_pipev5_instance.cpp b/library/src/tensor_operation_instance/gpu/grouped_conv3d_bwd_weight/xdl/ndhwgc_gkzyxc_ndhwgk/device_grouped_conv3d_bwd_weight_xdl_ndhwgc_gkzyxc_ndhwgk_f16_default_pipev5_instance.cpp similarity index 100% rename from library/src/tensor_operation_instance/gpu/grouped_conv3d_bwd_weight/xdl/device_grouped_conv3d_bwd_weight_xdl_ndhwgc_gkzyxc_ndhwgk_f16_default_pipev5_instance.cpp rename to library/src/tensor_operation_instance/gpu/grouped_conv3d_bwd_weight/xdl/ndhwgc_gkzyxc_ndhwgk/device_grouped_conv3d_bwd_weight_xdl_ndhwgc_gkzyxc_ndhwgk_f16_default_pipev5_instance.cpp diff --git a/library/src/tensor_operation_instance/gpu/grouped_conv3d_bwd_weight/xdl/device_grouped_conv3d_bwd_weight_xdl_ndhwgc_gkzyxc_ndhwgk_f16_instance.cpp 
b/library/src/tensor_operation_instance/gpu/grouped_conv3d_bwd_weight/xdl/ndhwgc_gkzyxc_ndhwgk/device_grouped_conv3d_bwd_weight_xdl_ndhwgc_gkzyxc_ndhwgk_f16_instance.cpp similarity index 97% rename from library/src/tensor_operation_instance/gpu/grouped_conv3d_bwd_weight/xdl/device_grouped_conv3d_bwd_weight_xdl_ndhwgc_gkzyxc_ndhwgk_f16_instance.cpp rename to library/src/tensor_operation_instance/gpu/grouped_conv3d_bwd_weight/xdl/ndhwgc_gkzyxc_ndhwgk/device_grouped_conv3d_bwd_weight_xdl_ndhwgc_gkzyxc_ndhwgk_f16_instance.cpp index eba721c7b8..d23b8516ca 100644 --- a/library/src/tensor_operation_instance/gpu/grouped_conv3d_bwd_weight/xdl/device_grouped_conv3d_bwd_weight_xdl_ndhwgc_gkzyxc_ndhwgk_f16_instance.cpp +++ b/library/src/tensor_operation_instance/gpu/grouped_conv3d_bwd_weight/xdl/ndhwgc_gkzyxc_ndhwgk/device_grouped_conv3d_bwd_weight_xdl_ndhwgc_gkzyxc_ndhwgk_f16_instance.cpp @@ -1,5 +1,5 @@ // SPDX-License-Identifier: MIT -// Copyright (c) 2018-2023, Advanced Micro Devices, Inc. All rights reserved. +// Copyright (c) 2018-2025, Advanced Micro Devices, Inc. All rights reserved. 
#include "ck/library/tensor_operation_instance/add_device_operation_instance.hpp" #include "ck/library/tensor_operation_instance/gpu/grouped_conv_bwd_weight/device_grouped_conv_bwd_weight_xdl_instance.hpp" diff --git a/library/src/tensor_operation_instance/gpu/grouped_conv3d_bwd_weight/xdl/device_grouped_conv3d_bwd_weight_xdl_ndhwgc_gkzyxc_ndhwgk_f16_pad0_pipev2_instance.cpp b/library/src/tensor_operation_instance/gpu/grouped_conv3d_bwd_weight/xdl/ndhwgc_gkzyxc_ndhwgk/device_grouped_conv3d_bwd_weight_xdl_ndhwgc_gkzyxc_ndhwgk_f16_pad0_pipev2_instance.cpp similarity index 100% rename from library/src/tensor_operation_instance/gpu/grouped_conv3d_bwd_weight/xdl/device_grouped_conv3d_bwd_weight_xdl_ndhwgc_gkzyxc_ndhwgk_f16_pad0_pipev2_instance.cpp rename to library/src/tensor_operation_instance/gpu/grouped_conv3d_bwd_weight/xdl/ndhwgc_gkzyxc_ndhwgk/device_grouped_conv3d_bwd_weight_xdl_ndhwgc_gkzyxc_ndhwgk_f16_pad0_pipev2_instance.cpp diff --git a/library/src/tensor_operation_instance/gpu/grouped_conv3d_bwd_weight/xdl/device_grouped_conv3d_bwd_weight_xdl_ndhwgc_gkzyxc_ndhwgk_f16_pad0_pipev5_instance.cpp b/library/src/tensor_operation_instance/gpu/grouped_conv3d_bwd_weight/xdl/ndhwgc_gkzyxc_ndhwgk/device_grouped_conv3d_bwd_weight_xdl_ndhwgc_gkzyxc_ndhwgk_f16_pad0_pipev5_instance.cpp similarity index 100% rename from library/src/tensor_operation_instance/gpu/grouped_conv3d_bwd_weight/xdl/device_grouped_conv3d_bwd_weight_xdl_ndhwgc_gkzyxc_ndhwgk_f16_pad0_pipev5_instance.cpp rename to library/src/tensor_operation_instance/gpu/grouped_conv3d_bwd_weight/xdl/ndhwgc_gkzyxc_ndhwgk/device_grouped_conv3d_bwd_weight_xdl_ndhwgc_gkzyxc_ndhwgk_f16_pad0_pipev5_instance.cpp diff --git a/library/src/tensor_operation_instance/gpu/grouped_conv3d_bwd_weight/xdl/device_grouped_conv3d_bwd_weight_xdl_ndhwgc_gkzyxc_ndhwgk_f32_default_pipev2_instance.cpp 
b/library/src/tensor_operation_instance/gpu/grouped_conv3d_bwd_weight/xdl/ndhwgc_gkzyxc_ndhwgk/device_grouped_conv3d_bwd_weight_xdl_ndhwgc_gkzyxc_ndhwgk_f32_default_pipev2_instance.cpp similarity index 100% rename from library/src/tensor_operation_instance/gpu/grouped_conv3d_bwd_weight/xdl/device_grouped_conv3d_bwd_weight_xdl_ndhwgc_gkzyxc_ndhwgk_f32_default_pipev2_instance.cpp rename to library/src/tensor_operation_instance/gpu/grouped_conv3d_bwd_weight/xdl/ndhwgc_gkzyxc_ndhwgk/device_grouped_conv3d_bwd_weight_xdl_ndhwgc_gkzyxc_ndhwgk_f32_default_pipev2_instance.cpp diff --git a/library/src/tensor_operation_instance/gpu/grouped_conv3d_bwd_weight/xdl/device_grouped_conv3d_bwd_weight_xdl_ndhwgc_gkzyxc_ndhwgk_f32_default_pipev5_instance.cpp b/library/src/tensor_operation_instance/gpu/grouped_conv3d_bwd_weight/xdl/ndhwgc_gkzyxc_ndhwgk/device_grouped_conv3d_bwd_weight_xdl_ndhwgc_gkzyxc_ndhwgk_f32_default_pipev5_instance.cpp similarity index 100% rename from library/src/tensor_operation_instance/gpu/grouped_conv3d_bwd_weight/xdl/device_grouped_conv3d_bwd_weight_xdl_ndhwgc_gkzyxc_ndhwgk_f32_default_pipev5_instance.cpp rename to library/src/tensor_operation_instance/gpu/grouped_conv3d_bwd_weight/xdl/ndhwgc_gkzyxc_ndhwgk/device_grouped_conv3d_bwd_weight_xdl_ndhwgc_gkzyxc_ndhwgk_f32_default_pipev5_instance.cpp diff --git a/library/src/tensor_operation_instance/gpu/grouped_conv3d_bwd_weight/xdl/device_grouped_conv3d_bwd_weight_xdl_ndhwgc_gkzyxc_ndhwgk_f32_instance.cpp b/library/src/tensor_operation_instance/gpu/grouped_conv3d_bwd_weight/xdl/ndhwgc_gkzyxc_ndhwgk/device_grouped_conv3d_bwd_weight_xdl_ndhwgc_gkzyxc_ndhwgk_f32_instance.cpp similarity index 97% rename from library/src/tensor_operation_instance/gpu/grouped_conv3d_bwd_weight/xdl/device_grouped_conv3d_bwd_weight_xdl_ndhwgc_gkzyxc_ndhwgk_f32_instance.cpp rename to 
library/src/tensor_operation_instance/gpu/grouped_conv3d_bwd_weight/xdl/ndhwgc_gkzyxc_ndhwgk/device_grouped_conv3d_bwd_weight_xdl_ndhwgc_gkzyxc_ndhwgk_f32_instance.cpp index 7dd289139c..4de221a885 100644 --- a/library/src/tensor_operation_instance/gpu/grouped_conv3d_bwd_weight/xdl/device_grouped_conv3d_bwd_weight_xdl_ndhwgc_gkzyxc_ndhwgk_f32_instance.cpp +++ b/library/src/tensor_operation_instance/gpu/grouped_conv3d_bwd_weight/xdl/ndhwgc_gkzyxc_ndhwgk/device_grouped_conv3d_bwd_weight_xdl_ndhwgc_gkzyxc_ndhwgk_f32_instance.cpp @@ -1,5 +1,5 @@ // SPDX-License-Identifier: MIT -// Copyright (c) 2018-2023, Advanced Micro Devices, Inc. All rights reserved. +// Copyright (c) 2018-2025, Advanced Micro Devices, Inc. All rights reserved. #include "ck/library/tensor_operation_instance/add_device_operation_instance.hpp" #include "ck/library/tensor_operation_instance/gpu/grouped_conv_bwd_weight/device_grouped_conv_bwd_weight_xdl_instance.hpp" diff --git a/library/src/tensor_operation_instance/gpu/grouped_conv3d_bwd_weight/xdl/device_grouped_conv3d_bwd_weight_xdl_ndhwgc_gkzyxc_ndhwgk_f32_pad0_pipev2_instance.cpp b/library/src/tensor_operation_instance/gpu/grouped_conv3d_bwd_weight/xdl/ndhwgc_gkzyxc_ndhwgk/device_grouped_conv3d_bwd_weight_xdl_ndhwgc_gkzyxc_ndhwgk_f32_pad0_pipev2_instance.cpp similarity index 100% rename from library/src/tensor_operation_instance/gpu/grouped_conv3d_bwd_weight/xdl/device_grouped_conv3d_bwd_weight_xdl_ndhwgc_gkzyxc_ndhwgk_f32_pad0_pipev2_instance.cpp rename to library/src/tensor_operation_instance/gpu/grouped_conv3d_bwd_weight/xdl/ndhwgc_gkzyxc_ndhwgk/device_grouped_conv3d_bwd_weight_xdl_ndhwgc_gkzyxc_ndhwgk_f32_pad0_pipev2_instance.cpp diff --git a/library/src/tensor_operation_instance/gpu/grouped_conv3d_bwd_weight/xdl/device_grouped_conv3d_bwd_weight_xdl_ndhwgc_gkzyxc_ndhwgk_f32_pad0_pipev5_instance.cpp 
b/library/src/tensor_operation_instance/gpu/grouped_conv3d_bwd_weight/xdl/ndhwgc_gkzyxc_ndhwgk/device_grouped_conv3d_bwd_weight_xdl_ndhwgc_gkzyxc_ndhwgk_f32_pad0_pipev5_instance.cpp similarity index 100% rename from library/src/tensor_operation_instance/gpu/grouped_conv3d_bwd_weight/xdl/device_grouped_conv3d_bwd_weight_xdl_ndhwgc_gkzyxc_ndhwgk_f32_pad0_pipev5_instance.cpp rename to library/src/tensor_operation_instance/gpu/grouped_conv3d_bwd_weight/xdl/ndhwgc_gkzyxc_ndhwgk/device_grouped_conv3d_bwd_weight_xdl_ndhwgc_gkzyxc_ndhwgk_f32_pad0_pipev5_instance.cpp diff --git a/library/src/tensor_operation_instance/gpu/grouped_conv3d_bwd_weight/xdl/ngcdhw_gkczyx_ngkdhw/device_grouped_conv3d_bwd_weight_two_stage_xdl_ngcdhw_gkczyx_ngkdhw_bf16_pipev1_instance.cpp b/library/src/tensor_operation_instance/gpu/grouped_conv3d_bwd_weight/xdl/ngcdhw_gkczyx_ngkdhw/device_grouped_conv3d_bwd_weight_two_stage_xdl_ngcdhw_gkczyx_ngkdhw_bf16_pipev1_instance.cpp new file mode 100644 index 0000000000..e7cfcf1e5f --- /dev/null +++ b/library/src/tensor_operation_instance/gpu/grouped_conv3d_bwd_weight/xdl/ngcdhw_gkczyx_ngkdhw/device_grouped_conv3d_bwd_weight_two_stage_xdl_ngcdhw_gkczyx_ngkdhw_bf16_pipev1_instance.cpp @@ -0,0 +1,41 @@ +// SPDX-License-Identifier: MIT +// Copyright (c) 2025, Advanced Micro Devices, Inc. All rights reserved. + +#include "ck/library/tensor_operation_instance/add_device_operation_instance.hpp" +#include "ck/library/tensor_operation_instance/gpu/grouped_conv_bwd_weight/device_grouped_conv_bwd_weight_two_stage_xdl_instance.hpp" + +namespace ck { +namespace tensor_operation { +namespace device { +namespace instance { + +// Compilation parameters for in[n, hi, wi, g, c] * wei[g, k, y, x, c] = out[n, ho, wo, g, k] +void add_device_grouped_conv3d_bwd_weight_two_stage_xdl_ngcdhw_gkczyx_ngkdhw_bf16_pipev1_instances( + std::vector>>& instances) +{ + // 1. 
Default + add_device_operation_instances( + instances, + device_grouped_conv_bwd_weight_two_stage_ngchw_xdl_c_shuffle_bf16_generic_instances< + 3, + NGCDHW, + GKCZYX, + NGKDHW, + ConvBwdWeightDefault, + BlockGemmPipelineScheduler::Intrawave, + BlockGemmPipelineVersion::v1>{}); +} + +} // namespace instance +} // namespace device +} // namespace tensor_operation +} // namespace ck diff --git a/library/src/tensor_operation_instance/gpu/grouped_conv3d_bwd_weight/xdl/device_grouped_conv3d_bwd_weight_two_stage_xdl_ngcdhw_gkzyxc_ngkdhw_bf16_pipev2_instance.cpp b/library/src/tensor_operation_instance/gpu/grouped_conv3d_bwd_weight/xdl/ngcdhw_gkczyx_ngkdhw/device_grouped_conv3d_bwd_weight_two_stage_xdl_ngcdhw_gkczyx_ngkdhw_bf16_pipev2_instance.cpp similarity index 89% rename from library/src/tensor_operation_instance/gpu/grouped_conv3d_bwd_weight/xdl/device_grouped_conv3d_bwd_weight_two_stage_xdl_ngcdhw_gkzyxc_ngkdhw_bf16_pipev2_instance.cpp rename to library/src/tensor_operation_instance/gpu/grouped_conv3d_bwd_weight/xdl/ngcdhw_gkczyx_ngkdhw/device_grouped_conv3d_bwd_weight_two_stage_xdl_ngcdhw_gkczyx_ngkdhw_bf16_pipev2_instance.cpp index ac6cb82681..8d9c3c56ed 100644 --- a/library/src/tensor_operation_instance/gpu/grouped_conv3d_bwd_weight/xdl/device_grouped_conv3d_bwd_weight_two_stage_xdl_ngcdhw_gkzyxc_ngkdhw_bf16_pipev2_instance.cpp +++ b/library/src/tensor_operation_instance/gpu/grouped_conv3d_bwd_weight/xdl/ngcdhw_gkczyx_ngkdhw/device_grouped_conv3d_bwd_weight_two_stage_xdl_ngcdhw_gkczyx_ngkdhw_bf16_pipev2_instance.cpp @@ -1,5 +1,5 @@ // SPDX-License-Identifier: MIT -// Copyright (c) 2024, Advanced Micro Devices, Inc. All rights reserved. +// Copyright (c) 2025, Advanced Micro Devices, Inc. All rights reserved. 
#include "ck/library/tensor_operation_instance/add_device_operation_instance.hpp" #include "ck/library/tensor_operation_instance/gpu/grouped_conv_bwd_weight/device_grouped_conv_bwd_weight_two_stage_xdl_instance.hpp" @@ -10,10 +10,10 @@ namespace device { namespace instance { // Compilation parameters for in[n, hi, wi, g, c] * wei[g, k, y, x, c] = out[n, ho, wo, g, k] -void add_device_grouped_conv3d_bwd_weight_two_stage_xdl_ngcdhw_gkzyxc_ngkdhw_bf16_pipev2_instances( +void add_device_grouped_conv3d_bwd_weight_two_stage_xdl_ngcdhw_gkczyx_ngkdhw_bf16_pipev2_instances( std::vector>>& instances) +{ + // 1. Default + add_device_operation_instances( + instances, + device_grouped_conv_bwd_weight_two_stage_ngchw_xdl_c_shuffle_f16_generic_instances< + 3, + NGCDHW, + GKCZYX, + NGKDHW, + ConvBwdWeightDefault, + BlockGemmPipelineScheduler::Intrawave, + BlockGemmPipelineVersion::v1>{}); +} + +} // namespace instance +} // namespace device +} // namespace tensor_operation +} // namespace ck diff --git a/library/src/tensor_operation_instance/gpu/grouped_conv3d_bwd_weight/xdl/device_grouped_conv3d_bwd_weight_two_stage_xdl_ngcdhw_gkzyxc_ngkdhw_f16_pipev2_instance.cpp b/library/src/tensor_operation_instance/gpu/grouped_conv3d_bwd_weight/xdl/ngcdhw_gkczyx_ngkdhw/device_grouped_conv3d_bwd_weight_two_stage_xdl_ngcdhw_gkczyx_ngkdhw_f16_pipev2_instance.cpp similarity index 89% rename from library/src/tensor_operation_instance/gpu/grouped_conv3d_bwd_weight/xdl/device_grouped_conv3d_bwd_weight_two_stage_xdl_ngcdhw_gkzyxc_ngkdhw_f16_pipev2_instance.cpp rename to library/src/tensor_operation_instance/gpu/grouped_conv3d_bwd_weight/xdl/ngcdhw_gkczyx_ngkdhw/device_grouped_conv3d_bwd_weight_two_stage_xdl_ngcdhw_gkczyx_ngkdhw_f16_pipev2_instance.cpp index 489fa81a7f..c8c6253362 100644 --- a/library/src/tensor_operation_instance/gpu/grouped_conv3d_bwd_weight/xdl/device_grouped_conv3d_bwd_weight_two_stage_xdl_ngcdhw_gkzyxc_ngkdhw_f16_pipev2_instance.cpp +++ 
b/library/src/tensor_operation_instance/gpu/grouped_conv3d_bwd_weight/xdl/ngcdhw_gkczyx_ngkdhw/device_grouped_conv3d_bwd_weight_two_stage_xdl_ngcdhw_gkczyx_ngkdhw_f16_pipev2_instance.cpp @@ -1,5 +1,5 @@ // SPDX-License-Identifier: MIT -// Copyright (c) 2024, Advanced Micro Devices, Inc. All rights reserved. +// Copyright (c) 2025, Advanced Micro Devices, Inc. All rights reserved. #include "ck/library/tensor_operation_instance/add_device_operation_instance.hpp" #include "ck/library/tensor_operation_instance/gpu/grouped_conv_bwd_weight/device_grouped_conv_bwd_weight_two_stage_xdl_instance.hpp" @@ -10,10 +10,10 @@ namespace device { namespace instance { // Compilation parameters for in[n, hi, wi, g, c] * wei[g, k, y, x, c] = out[n, ho, wo, g, k] -void add_device_grouped_conv3d_bwd_weight_two_stage_xdl_ngcdhw_gkzyxc_ngkdhw_f16_pipev2_instances( +void add_device_grouped_conv3d_bwd_weight_two_stage_xdl_ngcdhw_gkczyx_ngkdhw_f16_pipev2_instances( std::vector>>& instances) +{ + // 1. Default + add_device_operation_instances( + instances, + device_grouped_conv_bwd_weight_xdl_c_shuffle_f32_generic_instances<3, + NGCDHW, + GKZYXC, + NGKDHW, + ConvBwdWeightDefault>{}); +} + +} // namespace instance +} // namespace device +} // namespace tensor_operation +} // namespace ck diff --git a/profiler/src/profile_grouped_conv_bwd_weight.cpp b/profiler/src/profile_grouped_conv_bwd_weight.cpp index 4170ac65aa..1640b48ffd 100644 --- a/profiler/src/profile_grouped_conv_bwd_weight.cpp +++ b/profiler/src/profile_grouped_conv_bwd_weight.cpp @@ -1,5 +1,5 @@ // SPDX-License-Identifier: MIT -// Copyright (c) 2018-2024, Advanced Micro Devices, Inc. All rights reserved. +// Copyright (c) 2018-2025, Advanced Micro Devices, Inc. All rights reserved. 
#include #include @@ -17,6 +17,7 @@ enum struct ConvLayout GNHWC_GKYXC_GNHWK, // 1 NHWGC_GKYXC_NHWGK, // 2 NGCHW_GKYXC_NGKHW, // 3 + NGCHW_GKCYX_NGKHW, // 4 }; enum struct ConvDataType @@ -49,6 +50,8 @@ static void print_helper_msg() "Ho, Wo, G, K]\n" << " 3: Input[N, G, C, Hi, Wi], Weight[G, K, Y, X, C], Output[N, " "G, K, Ho, Wo]\n" + << " 4: Input[N, G, C, Hi, Wi], Weight[G, K, C, Y, X], Output[N, " + "G, K, Ho, Wo]\n" << "arg4: verification (0: no, 1: yes)\n" << "arg5: initialization (0: no init, 1: integer value, 2: decimal value)\n" << "arg6: print tensor value (0: no; 1: yes)\n" @@ -199,6 +202,21 @@ int profile_grouped_conv_bwd_weight(int argc, char* argv[]) return profile(I2, NGCHW{}, GKYXC{}, NGKHW{}, BF16{}, BF16{}, BF16{}, BF16{}, BF16{}); } } + else if(num_dim_spatial == 2 && layout == ConvLayout::NGCHW_GKCYX_NGKHW) + { + if(data_type == ConvDataType::F16_F16_F16) + { + return profile(I2, NGCHW{}, GKCYX{}, NGKHW{}, F16{}, F16{}, F16{}, F16{}, F16{}); + } + if(data_type == ConvDataType::BF16_BF16_BF16) + { + return profile(I2, NGCHW{}, GKCYX{}, NGKHW{}, BF16{}, BF16{}, BF16{}, BF16{}, BF16{}); + } + if(data_type == ConvDataType::F32_F32_F32) + { + return profile(I2, NGCHW{}, GKCYX{}, NGKHW{}, F32{}, F32{}, F32{}, F32{}, F32{}); + } + } if(num_dim_spatial == 3 && layout == ConvLayout::GNHWC_GKYXC_GNHWK) { if(data_type == ConvDataType::F32_F32_F32) @@ -262,6 +280,22 @@ int profile_grouped_conv_bwd_weight(int argc, char* argv[]) I3, NGCDHW{}, GKZYXC{}, NGKDHW{}, BF16{}, BF16{}, BF16{}, BF16{}, BF16{}); } } + else if(num_dim_spatial == 3 && layout == ConvLayout::NGCHW_GKCYX_NGKHW) + { + if(data_type == ConvDataType::F16_F16_F16) + { + return profile(I3, NGCDHW{}, GKCZYX{}, NGKDHW{}, F16{}, F16{}, F16{}, F16{}, F16{}); + } + if(data_type == ConvDataType::BF16_BF16_BF16) + { + return profile( + I3, NGCDHW{}, GKCZYX{}, NGKDHW{}, BF16{}, BF16{}, BF16{}, BF16{}, BF16{}); + } + if(data_type == ConvDataType::F32_F32_F32) + { + return profile(I3, NGCDHW{}, GKCZYX{}, 
NGKDHW{}, F32{}, F32{}, F32{}, F32{}, F32{}); + } + } std::cout << "this data_type & layout is not implemented" << std::endl; diff --git a/script/convert_miopen_driver_to_profiler.py b/script/convert_miopen_driver_to_profiler.py index 81f9977542..1278b6744d 100644 --- a/script/convert_miopen_driver_to_profiler.py +++ b/script/convert_miopen_driver_to_profiler.py @@ -29,8 +29,9 @@ def run_ck_profiler_cmd(cmd): def parse_layouts(args): if args.in_layout == "NCW" or args.in_layout == "NCHW" or \ args.in_layout == "NCDHW": - if args.ck_profier_op == "grouped_conv_bwd_weight" or \ - args.ck_profier_op == "grouped_conv_fwd" or \ + if args.ck_profier_op == "grouped_conv_bwd_weight": + args.layout = 4 + elif args.ck_profier_op == "grouped_conv_fwd" or \ args.ck_profier_op == "grouped_conv_bwd_data": args.layout = 3 else: diff --git a/test/grouped_convnd_bwd_weight/test_grouped_convnd_bwd_weight.cpp b/test/grouped_convnd_bwd_weight/test_grouped_convnd_bwd_weight.cpp index 54b96d775c..21f2cb5ce6 100644 --- a/test/grouped_convnd_bwd_weight/test_grouped_convnd_bwd_weight.cpp +++ b/test/grouped_convnd_bwd_weight/test_grouped_convnd_bwd_weight.cpp @@ -146,8 +146,12 @@ using KernelTypes2d = ::testing::Types< std::tuple>, std::tuple>, std::tuple>, + std::tuple>, std::tuple>, - std::tuple>>; + std::tuple>, + std::tuple>, + std::tuple>, + std::tuple>>; using KernelTypes3d = ::testing::Types< std::tuple>, std::tuple>, @@ -158,8 +162,12 @@ using KernelTypes3d = ::testing::Types< std::tuple>, std::tuple>, std::tuple>, + std::tuple>, std::tuple>, - std::tuple>>; + std::tuple>, + std::tuple>, + std::tuple>, + std::tuple>>; TYPED_TEST_SUITE(TestGroupedConvndBwdWeight1d, KernelTypes1d); TYPED_TEST_SUITE(TestGroupedConvndBwdWeight2d, KernelTypes2d); From 9329432f6c3d4ddd8d5b836245bd44acef89be3d Mon Sep 17 00:00:00 2001 From: aledudek Date: Thu, 3 Apr 2025 13:35:43 +0200 Subject: [PATCH 016/443] Post-merge changes for fully async args copy in ck grouped gemm (#1991) * Post-merge changes for 
fully async args copy in ck grouped gemm * Post-merge documentation and naming changes * Build fix and updated changelog * Revised comments --- CHANGELOG.md | 2 ++ .../run_grouped_gemm_example.inc | 35 +++++++++++++------ .../device_grouped_gemm_multiple_d_dl.hpp | 15 ++++++-- .../device/impl/device_grouped_gemm_xdl.hpp | 16 +++++++-- ...evice_grouped_gemm_xdl_splitk_cshuffle.hpp | 16 +++++++-- 5 files changed, 68 insertions(+), 16 deletions(-) diff --git a/CHANGELOG.md b/CHANGELOG.md index f9da2b3117..49ef2998eb 100644 --- a/CHANGELOG.md +++ b/CHANGELOG.md @@ -7,6 +7,8 @@ Documentation for Composable Kernel available at [https://rocm.docs.amd.com/proj ### Added * Added support for bf16, f32, and f16 for 2D and 3D NGCHW grouped convolution backward data +* Added a fully asynchronous HOST (CPU) arguments copy flow for CK grouped GEMM kernels. +* Added support for GKCYX layout for grouped convolution forward (NGCHW/GKCYX/NGKHW, number of instances in instance factory for NGCHW/GKYXC/NGKHW has been reduced). * Added support for GKCYX layout for grouped convolution forward (NGCHW/GKCYX/NGKHW). * Added support for GKCYX layout for grouped convolution backward weight (NGCHW/GKCYX/NGKHW). * Added support for GKCYX layout for grouped convolution backward data (NGCHW/GKCYX/NGKHW). 
diff --git a/example/15_grouped_gemm/run_grouped_gemm_example.inc b/example/15_grouped_gemm/run_grouped_gemm_example.inc index 86b3182a52..7186c22233 100644 --- a/example/15_grouped_gemm/run_grouped_gemm_example.inc +++ b/example/15_grouped_gemm/run_grouped_gemm_example.inc @@ -21,6 +21,7 @@ struct ExecutionConfig final bool do_verification = true; int init_method = 1; bool time_kernel = false; + bool async_hargs = false; }; bool run_grouped_gemm(const ProblemSize& problem_size, const ExecutionConfig& config) @@ -190,10 +191,10 @@ bool run_grouped_gemm(const ProblemSize& problem_size, const ExecutionConfig& co gemm_workspace.Realloc(workspace_size); gemm.SetWorkSpacePointer(&argument, gemm_workspace.GetDeviceBuffer()); } - if(hargs_size > 0) + if(config.async_hargs && hargs_size > 0) { hip_check_error(hipHostMalloc(&gemm_hargs, hargs_size)); - gemm.SetHostKernelArgs(&argument, gemm_hargs); + gemm.SetHostKernelArgsPointer(&argument, gemm_hargs); } if(!gemm.IsSupportedArgument(argument)) @@ -203,16 +204,23 @@ bool run_grouped_gemm(const ProblemSize& problem_size, const ExecutionConfig& co "not support this GEMM problem"); } - hipStream_t stream0 = nullptr; - hip_check_error(hipStreamCreate(&stream0)); + if(!config.async_hargs) + { + invoker.Run(argument, StreamConfig{nullptr, false}); + } + else + { + hipStream_t stream0 = nullptr; + hip_check_error(hipStreamCreate(&stream0)); - hipEvent_t event0 = nullptr; - hip_check_error(hipEventCreate(&event0)); + hipEvent_t event0 = nullptr; + hip_check_error(hipEventCreate(&event0)); - invoker.Run(argument, StreamConfig{nullptr, false}, stream0, event0); + invoker.Run(argument, StreamConfig{nullptr, false}, stream0, event0); - hip_check_error(hipEventSynchronize(event0)); - hip_check_error(hipStreamSynchronize(stream0)); + hip_check_error(hipEventSynchronize(event0)); + hip_check_error(hipStreamSynchronize(stream0)); + } bool pass = true; if(config.do_verification) @@ -280,18 +288,25 @@ bool run_grouped_gemm_example(int argc, 
char* argv[]) problem_size.stride_Bs.push_back(problem_size.Ks[i]); problem_size.stride_Cs.push_back(problem_size.Ns[i]); } - if(argc == 4) { config.do_verification = std::stoi(argv[1]); config.init_method = std::stoi(argv[2]); config.time_kernel = std::stoi(argv[3]); } + else if(argc == 5) + { + config.do_verification = std::stoi(argv[1]); + config.init_method = std::stoi(argv[2]); + config.time_kernel = std::stoi(argv[3]); + config.async_hargs = std::stoi(argv[4]); + } else { printf("arg1: verification (0=no, 1=yes)\n"); printf("arg2: initialization (0=no init, 1=integer value, 2=decimal value)\n"); printf("arg3: time kernel (0=n0, 1=yes)\n"); + printf("arg4: async hargs (0=n0, 1=yes)\n"); exit(0); } diff --git a/include/ck/tensor_operation/gpu/device/impl/device_grouped_gemm_multiple_d_dl.hpp b/include/ck/tensor_operation/gpu/device/impl/device_grouped_gemm_multiple_d_dl.hpp index c148d7dbb7..463b10de43 100644 --- a/include/ck/tensor_operation/gpu/device/impl/device_grouped_gemm_multiple_d_dl.hpp +++ b/include/ck/tensor_operation/gpu/device/impl/device_grouped_gemm_multiple_d_dl.hpp @@ -607,6 +607,9 @@ struct DeviceGroupedGemmMultipleD_Dl : public DeviceGroupedGemmSetWorkSpacePointer(p_arg, p_dev_kernel_args); } - void SetHostKernelArgs(BaseArgument* p_arg, void* p_host_kernel_args) const + //---------------------------------------------------------------------------------------------- + /// @brief Sets the host kernel arguments pointer and copies that data on the host side. + /// This function can be utilised to use pinned memory for the host args and + /// achieve fully async data copy. + /// + /// @param p_arg The pointer to the Argument we're going to update. 
+ /// @param[in] p_host_kernel_args The pointer to the host memory where the kernel + /// arguments will be copied + void SetHostKernelArgsPointer(BaseArgument* p_arg, void* p_host_kernel_args) const { Argument* pArg_ = dynamic_cast(p_arg); if(!pArg_) diff --git a/include/ck/tensor_operation/gpu/device/impl/device_grouped_gemm_xdl.hpp b/include/ck/tensor_operation/gpu/device/impl/device_grouped_gemm_xdl.hpp index 2a6406aac3..d9a0249da8 100644 --- a/include/ck/tensor_operation/gpu/device/impl/device_grouped_gemm_xdl.hpp +++ b/include/ck/tensor_operation/gpu/device/impl/device_grouped_gemm_xdl.hpp @@ -560,6 +560,9 @@ struct DeviceGroupedGemm_Xdl : public DeviceGroupedGemm(p_arg); if(!pArg_) diff --git a/include/ck/tensor_operation/gpu/device/impl/device_grouped_gemm_xdl_splitk_cshuffle.hpp b/include/ck/tensor_operation/gpu/device/impl/device_grouped_gemm_xdl_splitk_cshuffle.hpp index 03431d7156..a2afb62eec 100644 --- a/include/ck/tensor_operation/gpu/device/impl/device_grouped_gemm_xdl_splitk_cshuffle.hpp +++ b/include/ck/tensor_operation/gpu/device/impl/device_grouped_gemm_xdl_splitk_cshuffle.hpp @@ -423,6 +423,9 @@ struct DeviceGroupedGemmXdlSplitKCShuffle : public DeviceGroupedGemmSplitKSetWorkSpacePointer(p_arg, p_dev_kernel_args); } - void SetHostKernelArgs(BaseArgument* p_arg, void* p_host_kernel_args) const + //---------------------------------------------------------------------------------------------- + /// @brief Sets the host kernel arguments pointer and copies that data on the host side. + /// This function can be utilised to use pinned memory for the host args and + /// achieve fully async data copy. + /// + /// @param p_arg The pointer to the Argument we're going to update. 
+ /// @param[in] p_host_kernel_args The pointer to the host memory where the kernel + /// arguments will be copied + /// + void SetHostKernelArgsPointer(BaseArgument* p_arg, void* p_host_kernel_args) const { Argument* pArg_ = dynamic_cast(p_arg); if(!pArg_) From 265af71a71fd81c99988365477973c337c512e13 Mon Sep 17 00:00:00 2001 From: Rostyslav Geyyer <46627076+geyyer@users.noreply.github.com> Date: Thu, 3 Apr 2025 12:42:03 -0500 Subject: [PATCH 017/443] Add FP16/BF16<->FP8/BF8 conversions (#2035) * Move conversion functions and add missing conversions * Add tests * Add missing conversions * Add missing conversions * Add bf8 tests * Update clipping for vectors * Add missing conversions * Add bf16 fp8 tests * Add bf16 bf8 tests * Fix device conversion * Fix conversions * Fix vector use * Minor fix * Add a workaround flag * Add a workaround flag for bf16 conversion * Add another workaround * Add a workaround for fp16 to bf8 conversion * Update type alias * Add docstrings and missing wrappers * Fix if defined macros * Fix more if defined macros * Add comments * Remove __host__ specifier * Add a gfx950 guard * Update function naming --- include/ck/ck.hpp | 6 + include/ck/utility/amd_ck_fp8.hpp | 864 +++++++++++++++++++-- include/ck/utility/mxf8_utils.hpp | 2 +- include/ck/utility/scaled_type_convert.hpp | 4 +- include/ck/utility/type_convert.hpp | 696 ++++++++++++++++- test/data_type/test_bf8_ocp.cpp | 595 +++++++++++++- test/data_type/test_fp8_ocp.cpp | 571 +++++++++++++- 7 files changed, 2628 insertions(+), 110 deletions(-) diff --git a/include/ck/ck.hpp b/include/ck/ck.hpp index 5fa73d2fda..1d49b68a32 100644 --- a/include/ck/ck.hpp +++ b/include/ck/ck.hpp @@ -248,6 +248,12 @@ CK_DECLARE_ENV_VAR_BOOL(CK_LOGGING) // workaround: compiler issue on gfx950 #define CK_WORKAROUND_FP32_TO_FP4_SR_CONVERSION 1 +// workaround: compiler issue on gfx950 +#define CK_WORKAROUND_FP16_TO_FP8_CONVERSION 1 + +// workaround: compiler issue on gfx950 +#define 
CK_WORKAROUND_BF16_TO_FP8_CONVERSION 1 + // denorm test fix, necessary for gfx90a #ifndef CK_GFX90A_DENORM_WORKAROUND #define CK_GFX90A_DENORM_WORKAROUND 0 diff --git a/include/ck/utility/amd_ck_fp8.hpp b/include/ck/utility/amd_ck_fp8.hpp index 5c80c42d6c..b0089bb2d1 100644 --- a/include/ck/utility/amd_ck_fp8.hpp +++ b/include/ck/utility/amd_ck_fp8.hpp @@ -64,6 +64,9 @@ enum class ck_saturation_t namespace fp8_impl { typedef fp8_storage_t fp8x2_storage_t __attribute__((ext_vector_type(2))); +typedef _Float16 half2_t __attribute__((ext_vector_type(2))); +typedef ushort ushortx2_t __attribute__((ext_vector_type(2))); +typedef short shortx2_t __attribute__((ext_vector_type(2))); typedef float float2_t __attribute__((ext_vector_type(2))); __host__ __device__ static inline constexpr bool fnuz_f8_is_nan(f8_fnuz_t a) @@ -270,7 +273,7 @@ static __host__ __device__ float cast_to_f32_from_f8(fp8_storage_t v) } template -static __device__ float2_t cast_to_f32x2_from_f8x2(fp8x2_storage_t v) +static __device__ float2_t cast_to_f32_from_f8(fp8x2_storage_t v) { const auto i16val = bit_cast(v); @@ -458,6 +461,510 @@ __is_interpret_supported([[maybe_unused]] ck_fp8_interpretation_t interp) #endif } +#if defined(__gfx950__) +template = false, + ck::enable_if_t = false> +static __device__ fp8_storage_t cast_to_f8_from_f16(_Float16 v, unsigned int rng = 0) +{ + union + { + unsigned int i32val; + half2_t half_vec; + fp8_storage_t i8val[4]; + } val; + + constexpr unsigned int i32val = 0; + val.half_vec[0] = v; + + if constexpr(saturate) + { + if((val.i32val & 0x7FFF) != 0x7FFF) + { + val.half_vec[0] = __builtin_amdgcn_fmed3h(val.half_vec[0], 448.0, -448.0); + } + } + + val.i32val = + __builtin_amdgcn_cvt_scalef32_sr_fp8_f16(i32val, val.half_vec[0], rng, /* scale */ 1.f, 0); + + return val.i8val[0]; +} + +template = false, + ck::enable_if_t = false> +static __device__ fp8x2_storage_t cast_to_f8_from_f16(half2_t v, unsigned int rng = 0) +{ + // there is no packed conversion with SR, so 
convert one element at a time + return fp8x2_storage_t{ + cast_to_f8_from_f16(v[0], rng), + cast_to_f8_from_f16(v[1], rng)}; +} + +template = false, + ck::enable_if_t = false> +static __device__ fp8_storage_t cast_to_f8_from_f16(_Float16 v, unsigned int rng = 0) +{ + union + { + unsigned int i32val; + half2_t half_vec; + fp8_storage_t i8val[4]; + } val; + + constexpr unsigned int i32val = 0; + val.half_vec[0] = v; + + if constexpr(saturate) + { + if((val.i32val & 0x7FFF) != 0x7FFF) + { + val.half_vec[0] = __builtin_amdgcn_fmed3h(val.half_vec[0], 57344.0, -57344.0); + } + } + + val.i32val = + __builtin_amdgcn_cvt_scalef32_sr_bf8_f16(i32val, val.half_vec[0], rng, /* scale */ 1.f, 0); + + return val.i8val[0]; +} + +template = false, + ck::enable_if_t = false> +static __device__ fp8x2_storage_t cast_to_f8_from_f16(half2_t v, unsigned int rng = 0) +{ + // there is no packed conversion with SR, so convert one element at a time + return fp8x2_storage_t{ + cast_to_f8_from_f16(v[0], rng), + cast_to_f8_from_f16(v[1], rng)}; +} + +template = false, + ck::enable_if_t = false> +static __device__ fp8_storage_t cast_to_f8_from_f16(_Float16 v, unsigned int rng = 0) +{ + std::ignore = rng; + + union + { + unsigned int i32val; + half2_t half_vec; + shortx2_t i16_vec; + fp8_storage_t i8val[4]; + } val; + + constexpr shortx2_t i16x2val = {0, 0}; + val.half_vec[0] = v; + + if constexpr(saturate) + { + if((val.i32val & 0x7FFF) != 0x7FFF) + { + val.half_vec[0] = __builtin_amdgcn_fmed3h(val.half_vec[0], 448.0, -448.0); + } + } + + val.i16_vec = + __builtin_amdgcn_cvt_scalef32_pk_fp8_f16(i16x2val, val.half_vec, /* scale */ 1.f, 0); + + return val.i8val[0]; +} + +template = false, + ck::enable_if_t = false> +static __device__ fp8x2_storage_t cast_to_f8_from_f16(half2_t v, unsigned int rng = 0) +{ +#if CK_WORKAROUND_FP16_TO_FP8_CONVERSION + return fp8x2_storage_t{ + cast_to_f8_from_f16(v[0], rng), + cast_to_f8_from_f16(v[1], rng)}; +#else + std::ignore = rng; + + union + { + half2_t 
half_vec; + shortx2_t i16_vec; + fp8_storage_t i8val[4]; + } val; + + constexpr shortx2_t i16x2val = {0, 0}; + val.half_vec = v; + + if constexpr(saturate) + { + if((val.i16_vec[0] & 0x7FFF) != 0x7FFF) + { + val.half_vec[0] = __builtin_amdgcn_fmed3h(val.half_vec[0], 448.0, -448.0); + } + if((val.i16_vec[1] & 0x7FFF) != 0x7FFF) + { + val.half_vec[1] = __builtin_amdgcn_fmed3h(val.half_vec[1], 448.0, -448.0); + } + } + + val.i16_vec = + __builtin_amdgcn_cvt_scalef32_pk_fp8_f16(i16x2val, val.half_vec, /* scale */ 1.f, 0); + + return fp8x2_storage_t{val.i8val[0], val.i8val[1]}; +#endif +} + +template = false, + ck::enable_if_t = false> +static __device__ fp8_storage_t cast_to_f8_from_f16(_Float16 v, unsigned int rng = 0) +{ + std::ignore = rng; + + union + { + unsigned int i32val; + half2_t half_vec; + shortx2_t i16_vec; + fp8_storage_t i8val[4]; + } val; + + constexpr shortx2_t i16x2val = {0, 0}; + val.half_vec[0] = v; + + if constexpr(saturate) + { + if((val.i32val & 0x7FFF) != 0x7FFF) + { + val.half_vec[0] = __builtin_amdgcn_fmed3h(val.half_vec[0], 57344.0, -57344.0); + } + } + + val.half_vec = + __builtin_amdgcn_cvt_scalef32_pk_bf8_f16(i16x2val, val.half_vec, /* scale */ 1.f, 0); + + return val.i8val[0]; +} + +template = false, + ck::enable_if_t = false> +static __device__ fp8x2_storage_t cast_to_f8_from_f16(half2_t v, unsigned int rng = 0) +{ +#if CK_WORKAROUND_FP16_TO_FP8_CONVERSION + return fp8x2_storage_t{ + cast_to_f8_from_f16(v[0], rng), + cast_to_f8_from_f16(v[1], rng)}; +#else + std::ignore = rng; + + union + { + half2_t half_vec; + shortx2_t i16_vec; + fp8_storage_t i8val[4]; + } val; + + constexpr shortx2_t i16x2val = {0, 0}; + val.half_vec = v; + + if constexpr(saturate) + { + if((val.i16_vec[0] & 0x7FFF) != 0x7FFF) + { + val.half_vec[0] = __builtin_amdgcn_fmed3h(val.half_vec[0], 57344.0, -57344.0); + } + if((val.i16_vec[1] & 0x7FFF) != 0x7FFF) + { + val.half_vec[1] = __builtin_amdgcn_fmed3h(val.half_vec[1], 57344.0, -57344.0); + } + } + + val.i16_vec = + 
__builtin_amdgcn_cvt_scalef32_pk_bf8_f16(i16x2val, val.half_vec, /* scale */ 1.f, 0); + + return fp8x2_storage_t{val.i8val[0], val.i8val[1]}; +#endif +} + +template = false, + ck::enable_if_t = false> +static __device__ fp8_storage_t cast_to_f8_from_bf16(ushort v, unsigned int rng = 0) +{ + union + { + unsigned int i32val; + ushortx2_t bhalf_vec; + fp8_storage_t i8val[4]; + } val; + + constexpr unsigned int i32val = 0; + val.bhalf_vec[0] = v; + + if constexpr(saturate) + { + if((val.i32val & 0x7FFF) != 0x7FFF) + { + val.bhalf_vec[0] = + ushort((bit_cast(__builtin_amdgcn_fmed3f( + bit_cast(uint32_t{val.bhalf_vec[0]} << 16), 448.0, -448.0)) >> + 16)); // convert to float and back + } + } + + val.i32val = __builtin_amdgcn_cvt_scalef32_sr_fp8_bf16( + i32val, val.bhalf_vec[0], rng, /* scale */ 1.f, 0); + + return val.i8val[0]; +} + +template = false, + ck::enable_if_t = false> +static __device__ fp8x2_storage_t cast_to_f8_from_bf16(ushortx2_t v, unsigned int rng = 0) +{ + // there is no packed conversion with SR, so convert one element at a time + return fp8x2_storage_t{ + cast_to_f8_from_bf16(v[0], rng), + cast_to_f8_from_bf16(v[1], rng)}; +} + +template = false, + ck::enable_if_t = false> +static __device__ fp8_storage_t cast_to_f8_from_bf16(ushort v, unsigned int rng = 0) +{ + union + { + unsigned int i32val; + ushortx2_t bhalf_vec; + fp8_storage_t i8val[4]; + } val; + + constexpr unsigned int i32val = 0; + val.bhalf_vec[0] = v; + + if constexpr(saturate) + { + if((val.i32val & 0x7FFF) != 0x7FFF) + { + val.bhalf_vec[0] = ushort( + (bit_cast(__builtin_amdgcn_fmed3f( + bit_cast(uint32_t{val.bhalf_vec[0]} << 16), 57344.0, -57344.0)) >> + 16)); // convert to float and back + } + } + + val.i32val = __builtin_amdgcn_cvt_scalef32_sr_bf8_bf16( + i32val, val.bhalf_vec[0], rng, /* scale */ 1.f, 0); + + return val.i8val[0]; +} + +template = false, + ck::enable_if_t = false> +static __device__ fp8x2_storage_t cast_to_f8_from_bf16(ushortx2_t v, unsigned int rng = 0) +{ + // there 
is no packed conversion with SR, so convert one element at a time + return fp8x2_storage_t{ + cast_to_f8_from_bf16(v[0], rng), + cast_to_f8_from_bf16(v[1], rng)}; +} + +template = false, + ck::enable_if_t = false> +static __device__ fp8_storage_t cast_to_f8_from_bf16(ushort v, unsigned int rng = 0) +{ + std::ignore = rng; + + union + { + unsigned int i32val; + ushortx2_t bhalf_vec; + shortx2_t i16_vec; + fp8_storage_t i8val[4]; + } val; + + constexpr shortx2_t i16x2val = {0, 0}; + val.bhalf_vec[0] = v; + + if constexpr(saturate) + { + if((val.i32val & 0x7FFF) != 0x7FFF) + { + val.bhalf_vec[0] = + ushort((bit_cast(__builtin_amdgcn_fmed3f( + bit_cast(uint32_t{val.bhalf_vec[0]} << 16), 448.0, -448.0)) >> + 16)); // convert to float and back + } + } + + val.i16_vec = + __builtin_amdgcn_cvt_scalef32_pk_fp8_bf16(i16x2val, val.bhalf_vec, /* scale */ 1.f, 0); + + return val.i8val[0]; +} + +template = false, + ck::enable_if_t = false> +static __device__ fp8x2_storage_t cast_to_f8_from_bf16(ushortx2_t v, unsigned int rng = 0) +{ +#if CK_WORKAROUND_BF16_TO_FP8_CONVERSION + return fp8x2_storage_t{ + cast_to_f8_from_bf16(v[0], rng), + cast_to_f8_from_bf16(v[1], rng)}; +#else + std::ignore = rng; + + union + { + ushortx2_t bhalf_vec; + shortx2_t i16_vec; + fp8_storage_t i8val[4]; + } val; + + constexpr shortx2_t i16x2val = {0, 0}; + val.bhalf_vec = v; + + if constexpr(saturate) + { + if((val.i16_vec[0] & 0x7FFF) != 0x7FFF) + { + val.bhalf_vec[0] = + ushort((bit_cast(__builtin_amdgcn_fmed3f( + bit_cast(uint32_t{val.bhalf_vec[0]} << 16), 448.0, -448.0)) >> + 16)); // convert to float and back + } + if((val.i16_vec[1] & 0x7FFF) != 0x7FFF) + { + val.bhalf_vec[1] = + ushort((bit_cast(__builtin_amdgcn_fmed3f( + bit_cast(uint32_t{val.bhalf_vec[1]} << 16), 448.0, -448.0)) >> + 16)); // convert to float and back + } + } + + val.i16_vec = + __builtin_amdgcn_cvt_scalef32_pk_fp8_bf16(i16x2val, val.bhalf_vec, /* scale */ 1.f, 0); + + return fp8x2_storage_t{val.i8val[0], val.i8val[1]}; 
+#endif +} + +template = false, + ck::enable_if_t = false> +static __device__ fp8_storage_t cast_to_f8_from_bf16(ushort v, unsigned int rng = 0) +{ + std::ignore = rng; + + union + { + unsigned int i32val; + ushortx2_t bhalf_vec; + shortx2_t i16_vec; + fp8_storage_t i8val[4]; + } val; + + constexpr shortx2_t i16x2val = {0, 0}; + val.bhalf_vec[0] = v; + + if constexpr(saturate) + { + if((val.i32val & 0x7FFF) != 0x7FFF) + { + val.bhalf_vec[0] = ushort( + (bit_cast(__builtin_amdgcn_fmed3f( + bit_cast(uint32_t{val.bhalf_vec[0]} << 16), 57344.0, -57344.0)) >> + 16)); // convert to float and back + } + } + + val.i16_vec = + __builtin_amdgcn_cvt_scalef32_pk_bf8_bf16(i16x2val, val.bhalf_vec, /* scale */ 1.f, 0); + + return val.i8val[0]; +} + +template = false, + ck::enable_if_t = false> +static __device__ fp8x2_storage_t cast_to_f8_from_bf16(ushortx2_t v, unsigned int rng = 0) +{ + std::ignore = rng; + + union + { + ushortx2_t bhalf_vec; + shortx2_t i16_vec; + fp8_storage_t i8val[4]; + } val; + + constexpr shortx2_t i16x2val = {0, 0}; + val.bhalf_vec = v; + + if constexpr(saturate) + { + if((val.i16_vec[0] & 0x7FFF) != 0x7FFF) + { + val.bhalf_vec[0] = ushort( + (bit_cast(__builtin_amdgcn_fmed3f( + bit_cast(uint32_t{val.bhalf_vec[0]} << 16), 57344.0, -57344.0)) >> + 16)); // convert to float and back + } + if((val.i16_vec[1] & 0x7FFF) != 0x7FFF) + { + val.bhalf_vec[1] = ushort( + (bit_cast(__builtin_amdgcn_fmed3f( + bit_cast(uint32_t{val.bhalf_vec[1]} << 16), 57344.0, -57344.0)) >> + 16)); // convert to float and back + } + } + + val.i16_vec = + __builtin_amdgcn_cvt_scalef32_pk_bf8_bf16(i16x2val, val.bhalf_vec, /* scale */ 1.f, 0); + + return fp8x2_storage_t{val.i8val[0], val.i8val[1]}; +} +#endif // defined(__gfx950__) + #if CK_FP8_CVT_FAST_PATH // The conversion function is from rocblas // https://github.com/ROCm/rocBLAS/blob/9b7f692abe3c54b88d1e77e045a7db7f1f188b69/library/include/internal/rocblas_float8.h#L79 @@ -523,6 +1030,84 @@ static __device__ fp8_storage_t 
cast_to_f8_from_f32(float v, unsigned int rng = } return i8data; } + +template +static __device__ fp8x2_storage_t cast_to_f8_from_f32(float2_t v, unsigned int rng = 0) +{ + if constexpr(stochastic_rounding) + { + // there is no packed conversion with SR, so convert one element at a time + return fp8x2_storage_t{ + cast_to_f8_from_f32(v[0], rng), + cast_to_f8_from_f32(v[1], rng)}; + } + else + { + union + { + float fval; + unsigned int i32val; + unsigned char i8val[4]; + } val0, val1; + + val0.fval = v[0]; + val1.fval = v[1]; + + unsigned int ival = 0; + + if constexpr(saturate) + { + if constexpr(interpret == ck_fp8_interpretation_t::CK_E4M3_FNUZ) + { + if((val0.i32val & 0x7F800000) != 0x7F800000) + { /// propagate NAN/INF, no clipping + val0.fval = __builtin_amdgcn_fmed3f(val0.fval, 240.0, -240.0); + } + if((val1.i32val & 0x7F800000) != 0x7F800000) + { /// propagate NAN/INF, no clipping + val1.fval = __builtin_amdgcn_fmed3f(val1.fval, 240.0, -240.0); + } + } + else if constexpr(interpret == ck_fp8_interpretation_t::CK_E4M3_OCP) + { // OCP type + if((val0.i32val & 0x7F800000) != 0x7F800000) + { /// propagate NAN/INF, no clipping + val0.fval = __builtin_amdgcn_fmed3f(val0.fval, 448.0, -448.0); + } + if((val1.i32val & 0x7F800000) != 0x7F800000) + { /// propagate NAN/INF, no clipping + val1.fval = __builtin_amdgcn_fmed3f(val1.fval, 448.0, -448.0); + } + } + else + { + if((val0.i32val & 0x7F800000) != 0x7F800000) + { /// propagate NAN/INF, no clipping + val0.fval = __builtin_amdgcn_fmed3f(val0.fval, 57344.0, -57344.0); + } + if((val1.i32val & 0x7F800000) != 0x7F800000) + { /// propagate NAN/INF, no clipping + val1.fval = __builtin_amdgcn_fmed3f(val1.fval, 57344.0, -57344.0); + } + } + } + + // RNE CVT + if constexpr((interpret == ck_fp8_interpretation_t::CK_E4M3_FNUZ) || + (interpret == ck_fp8_interpretation_t::CK_E4M3_OCP)) + { + ival = __builtin_amdgcn_cvt_pk_fp8_f32(val0.fval, val1.fval, ival, false); + } + else + { + ival = 
__builtin_amdgcn_cvt_pk_bf8_f32(val0.fval, val1.fval, ival, false); + } + + val0.i32val = ival; + + return fp8x2_storage_t{val0.i8val[0], val0.i8val[1]}; + } +} #endif // CK_FP8_CVT_FAST_PATH // The conversion function is from rocblas @@ -797,6 +1382,7 @@ __host__ __device__ static inline fp8_storage_t cast_to_f8(T _x, unsigned int rn * * \tparam interp interpretation of fp8 * \tparam sat saturation of fp8 + * \tparam stochastic_rounding switch between RNE and SR * \param f float number * \return fp8_storage_t */ @@ -882,6 +1468,47 @@ __host__ static inline fp8_storage_t cvt_float_to_fp8(const float f) #endif // CK_FP8_CVT_FAST_PATH } +/** + * \brief convert vector of 2 floats to vector of 2 @p fp8_storage_t + * + * \tparam interp interpretation of fp8 + * \tparam sat saturation of fp8 + * \tparam stochastic_rounding switch between RNE and SR + * \param f vector of 2 floats + * \return fp8x2_storage_t + */ +template +#if CK_FP8_CVT_FAST_PATH +__device__ static inline fp8x2_storage_t cvt_float_to_fp8(const float2_t f) +{ + __is_interpret_supported(interp); + uint32_t rng = 0; + if constexpr(stochastic_rounding) + { + constexpr int seed = 1254739; +#ifndef CK_CODE_GEN_RTC + rng = prand_generator(reinterpret_cast(&f), f[0]); +#else + rng = prand_generator(reinterpret_cast(&f), f[0]); +#endif + } + return cast_to_f8_from_f32( + f, rng); +#else +#if CK_USE_OCP_FP8 +__host__ __device__ static inline fp8x2_storage_t cvt_float_to_fp8(const float2_t f) +{ +#else +__host__ static inline fp8x2_storage_t cvt_float_to_fp8(const float2_t f) +{ +#endif // CK_USE_OCP_FP8 + return fp8x2_storage_t{cvt_float_to_fp8(f[0]), + cvt_float_to_fp8(f[1])}; +#endif // CK_FP8_CVT_FAST_PATH +} + /** * \brief convert _Float16 to @p fp8_storage_t * @@ -900,87 +1527,168 @@ __host__ __device__ static inline fp8_storage_t cvt_half_t_to_fp8(const _Float16 __host__ static inline fp8_storage_t cvt_half_t_to_fp8(const _Float16 x) #endif { - return cvt_float_to_fp8(static_cast(x)); + { + 
__is_interpret_supported(interp); + uint32_t rng = 0; + if constexpr(stochastic_rounding) + { + constexpr int seed = 1254739; +#ifndef CK_CODE_GEN_RTC + rng = prand_generator(reinterpret_cast(&x), x); +#else + rng = prand_generator(reinterpret_cast(&x), x); +#endif + } +#if defined(__gfx950__) + return cast_to_f8_from_f16(x, rng); +#else + std::ignore = rng; + return cvt_float_to_fp8( + static_cast(x)); +#endif // defined(__gfx950__) + } +} + +/** + * \brief convert vector of 2 _Float16 to vector of 2 @p fp8_storage_t + * + * \tparam sat saturation of fp8 + * \tparam interp interpretation of fp8 + * \tparam stochastic_rounding switch between RNE and SR + * \param x vector of 2 _Float16 + * \return fp8x2_storage_t + */ +template +#if CK_FP8_CVT_FAST_PATH || CK_USE_OCP_FP8 +__host__ __device__ static inline fp8x2_storage_t cvt_half_t_to_fp8(const half2_t x) +#else +__host__ static inline fp8x2_storage_t cvt_half_t_to_fp8(const half2_t x) +#endif +{ + { + __is_interpret_supported(interp); + uint32_t rng = 0; + if constexpr(stochastic_rounding) + { + constexpr int seed = 1254739; +#ifndef CK_CODE_GEN_RTC + rng = prand_generator(reinterpret_cast(&x), x[0]); +#else + rng = prand_generator(reinterpret_cast(&x), x[0]); +#endif + } +#if defined(__gfx950__) + return cast_to_f8_from_f16(x, rng); +#else + std::ignore = rng; + return cvt_float_to_fp8( + float2_t{static_cast(x[0]), static_cast(x[1])}); +#endif // defined(__gfx950__) + } +} + +/** + * \brief convert bhalf_t to @p fp8_storage_t + * + * \tparam sat saturation of fp8 + * \tparam interp interpretation of fp8 + * \tparam stochastic_rounding switch between RNE and SR + * \param x bhalf_t value + * \return fp8_storage_t + */ +template +#if CK_FP8_CVT_FAST_PATH || CK_USE_OCP_FP8 +__host__ __device__ static inline fp8_storage_t cvt_bhalf_t_to_fp8(const ushort x) +#else +__host__ static inline fp8_storage_t cvt_bhalf_t_to_fp8(const ushort x) +#endif +{ + { + __is_interpret_supported(interp); + uint32_t rng = 0; + if 
constexpr(stochastic_rounding) + { + constexpr int seed = 1254739; +#ifndef CK_CODE_GEN_RTC + rng = prand_generator(reinterpret_cast(&x), + static_cast(x)); +#else + rng = prand_generator(reinterpret_cast(&x), static_cast(x)); +#endif + } +#if defined(__gfx950__) + return cast_to_f8_from_bf16(x, rng); +#else + std::ignore = rng; + return cvt_float_to_fp8( + bit_cast(uint32_t{x} << 16)); // convert value to float +#endif // defined(__gfx950__) + } +} + +/** + * \brief convert vector of 2 bhalf_t to vector of 2 @p fp8_storage_t + * + * \tparam sat saturation of fp8 + * \tparam interp interpretation of fp8 + * \tparam stochastic_rounding switch between RNE and SR + * \param x vector of 2 bhalf_t + * \return fp8x2_storage_t + */ +template +#if CK_FP8_CVT_FAST_PATH || CK_USE_OCP_FP8 +__host__ __device__ static inline fp8x2_storage_t cvt_bhalf_t_to_fp8(const ushortx2_t x) +#else +__host__ static inline fp8x2_storage_t cvt_bhalf_t_to_fp8(const ushortx2_t x) +#endif +{ +#if CK_WORKAROUND_BF16_TO_FP8_CONVERSION + return cvt_float_to_fp8( + float2_t{bit_cast(uint32_t{x[0]} << 16), + bit_cast(uint32_t{x[1]} << 16)}); // convert values to float +#else // CK_WORKAROUND_BF16_TO_FP8_CONVERSION + { + __is_interpret_supported(interp); + uint32_t rng = 0; + if constexpr(stochastic_rounding) + { + constexpr int seed = 1254739; +#ifndef CK_CODE_GEN_RTC + rng = prand_generator(reinterpret_cast(&x), + static_cast(x[0])); +#else + rng = prand_generator(reinterpret_cast(&x), + static_cast(x[0])); +#endif + } +#if defined(__gfx950__) + return cast_to_f8_from_bf16(x, rng); +#else + std::ignore = rng; + return cvt_float_to_fp8( + float2_t{bit_cast(uint32_t{x[0]} << 16), + bit_cast(uint32_t{x[1]} << 16)}); // convert values to float +#endif // defined(__gfx950__) + } +#endif // CK_WORKAROUND_BF16_TO_FP8_CONVERSION } } // namespace fp8_impl -// Declare a template function for fp8 conversion using RNE -template -__host__ __device__ constexpr Y f8_convert_rne(X x); - -// convert fp32 to fp8 with 
rounding to nearest even -template <> -inline __host__ __device__ f8_ocp_t f8_convert_rne(float x) -{ - return f8_ocp_t{ - fp8_impl::cvt_float_to_fp8(x)}; -} - -// convert fp32 to bf8 with rounding to nearest even -template <> -inline __host__ __device__ bf8_ocp_t f8_convert_rne(float x) -{ - return bf8_ocp_t{ - fp8_impl::cvt_float_to_fp8(x)}; -} - -// convert _Float16 to fp8 with rounding to nearest even -template <> -inline __host__ __device__ f8_ocp_t f8_convert_rne(_Float16 x) -{ - return f8_ocp_t{ - fp8_impl::cvt_half_t_to_fp8(x)}; -} - -template <> -inline __host__ __device__ bf8_ocp_t f8_convert_rne(_Float16 x) -{ - return bf8_ocp_t{ - fp8_impl::cvt_half_t_to_fp8( - x)}; -} - -// Declare a template function for fp8 conversion using RNE -template -__host__ __device__ constexpr Y f8_convert_sr(X x); - -// convert fp32 to fp8 with stochastic rounding -template <> -inline __host__ __device__ f8_ocp_t f8_convert_sr(float x) -{ - return f8_ocp_t{ - fp8_impl::cvt_float_to_fp8( - x)}; -} - -// convert fp32 to bf8 with stochastic rounding -template <> -inline __host__ __device__ bf8_ocp_t f8_convert_sr(float x) -{ - return bf8_ocp_t{fp8_impl::cvt_float_to_fp8(x)}; -} - -// convert _Float16 to fp8 with stochastic rounding -template <> -inline __host__ __device__ f8_ocp_t f8_convert_sr(_Float16 x) -{ - return f8_ocp_t{fp8_impl::cvt_half_t_to_fp8(x)}; -} - -// convert _Float16 to bf8 with stochastic rounding -template <> -inline __host__ __device__ bf8_ocp_t f8_convert_sr(_Float16 x) -{ - return bf8_ocp_t{fp8_impl::cvt_half_t_to_fp8(x)}; -} - #if CK_USE_OCP_FP8 using f8_t = f8_ocp_t; using bf8_t = bf8_ocp_t; diff --git a/include/ck/utility/mxf8_utils.hpp b/include/ck/utility/mxf8_utils.hpp index b7b98c6455..9046a24a3a 100644 --- a/include/ck/utility/mxf8_utils.hpp +++ b/include/ck/utility/mxf8_utils.hpp @@ -39,7 +39,7 @@ static __device__ float cast_to_f32_from_f8_scaled(float scale, fp8_storage_t v) } template -static __device__ float2_t 
cast_to_f32x2_from_f8x2_scaled(float scale, fp8x2_storage_t v) +static __device__ float2_t cast_to_f32_from_f8_scaled(float scale, fp8x2_storage_t v) { const auto i16val = bit_cast(v); diff --git a/include/ck/utility/scaled_type_convert.hpp b/include/ck/utility/scaled_type_convert.hpp index 9a9c53caec..f3e2bd3dd9 100644 --- a/include/ck/utility/scaled_type_convert.hpp +++ b/include/ck/utility/scaled_type_convert.hpp @@ -67,7 +67,7 @@ inline __host__ float2_t scaled_type_convert(e8m0_bexp_t s #endif { #if CK_MX_FP8_CVT_FAST_PATH - return fp8_impl::cast_to_f32x2_from_f8x2_scaled( + return fp8_impl::cast_to_f32_from_f8_scaled( type_convert(scale), x.AsType()[Number<0>{}]); #else return float2_t{scaled_type_convert(scale, x.AsType()[Number<0>{}]), @@ -86,7 +86,7 @@ inline __host__ float2_t scaled_type_convert(e8m0_bexp_t #endif { #if CK_MX_FP8_CVT_FAST_PATH - return fp8_impl::cast_to_f32x2_from_f8x2_scaled( + return fp8_impl::cast_to_f32_from_f8_scaled( type_convert(scale), x.AsType()[Number<0>{}]); #else return float2_t{scaled_type_convert(scale, x.AsType()[Number<0>{}]), diff --git a/include/ck/utility/type_convert.hpp b/include/ck/utility/type_convert.hpp index b9aeb44999..c8127aa887 100644 --- a/include/ck/utility/type_convert.hpp +++ b/include/ck/utility/type_convert.hpp @@ -117,7 +117,7 @@ inline __host__ __device__ constexpr bhalf_t type_convert(float #if CK_USE_RNE_BF16_CONVERSION return bf16_convert_rtn(x); #else - return uint16_t(u.int32 >> 16); + return uint16_t(uint32_t{x} >> 16); #endif } @@ -356,6 +356,180 @@ inline __host__ __device__ bf8_fnuz_t f8_convert_sr(half_t x #endif } +/** + * @brief Converts a float to a 8-bit float type (f8_ocp_t) using stochastic rounding. + * + * @param x The input float value. + * @return The converted f8_ocp_t value. 
+ */ +template <> +inline __host__ __device__ f8_ocp_t f8_convert_sr(float x) +{ + return f8_ocp_t{ + fp8_impl::cvt_float_to_fp8( + x)}; +} + +/** + * @brief Converts a vector of 2 floats to a vector of 2 8-bit float types (f8_ocp_t) using + * stochastic rounding. + * + * @param x The input vector of 2 floats. + * @return The converted vector of 2 f8_ocp_t. + */ +template <> +inline __host__ __device__ f8x2_ocp_t f8_convert_sr(float2_t x) +{ + return f8x2_ocp_t{ + fp8_impl::cvt_float_to_fp8( + x)}; +} + +/** + * @brief Converts a float to a 8-bit float type (bf8_ocp_t) using stochastic rounding. + * + * @param x The input float value. + * @return The converted bf8_ocp_t value. + */ +template <> +inline __host__ __device__ bf8_ocp_t f8_convert_sr(float x) +{ + return bf8_ocp_t{fp8_impl::cvt_float_to_fp8(x)}; +} + +/** + * @brief Converts a vector of 2 floats to a vector of 2 8-bit float types (bf8_ocp_t) using + * stochastic rounding. + * + * @param x The input vector of 2 floats. + * @return The converted vector of 2 bf8_ocp_t. + */ +template <> +inline __host__ __device__ bf8x2_ocp_t f8_convert_sr(float2_t x) +{ + return bf8x2_ocp_t{fp8_impl::cvt_float_to_fp8(x)}; +} + +/** + * @brief Converts a half_t to a 8-bit float type (f8_ocp_t) using stochastic rounding. + * + * @param x The input half_t value. + * @return The converted f8_ocp_t value. + */ +template <> +inline __host__ __device__ f8_ocp_t f8_convert_sr(half_t x) +{ + return f8_ocp_t{fp8_impl::cvt_half_t_to_fp8(x)}; +} + +/** + * @brief Converts a vector of 2 half_t to a vector of 2 8-bit float types (f8_ocp_t) using + * stochastic rounding. + * + * @param x The input vector of 2 half_t. + * @return The converted vector of 2 f8_ocp_t. + */ +template <> +inline __host__ __device__ f8x2_ocp_t f8_convert_sr(half2_t x) +{ + return f8x2_ocp_t{fp8_impl::cvt_half_t_to_fp8(x)}; +} + +/** + * @brief Converts a half_t to a 8-bit half_t type (bf8_ocp_t) using stochastic rounding. 
+ * + * @param x The input half_t value. + * @return The converted bf8_ocp_t value. + */ +template <> +inline __host__ __device__ bf8_ocp_t f8_convert_sr(half_t x) +{ + return bf8_ocp_t{fp8_impl::cvt_half_t_to_fp8(x)}; +} + +/** + * @brief Converts a vector of 2 half_t to a vector of 2 8-bit float types (bf8_ocp_t) using + * stochastic rounding. + * + * @param x The input vector of 2 half_t. + * @return The converted vector of 2 bf8_ocp_t. + */ +template <> +inline __host__ __device__ bf8x2_ocp_t f8_convert_sr(half2_t x) +{ + return bf8x2_ocp_t{fp8_impl::cvt_half_t_to_fp8(x)}; +} + +/** + * @brief Converts a bhalf_t to a 8-bit float type (f8_ocp_t) using stochastic rounding. + * + * @param x The input bhalf_t value. + * @return The converted f8_ocp_t value. + */ +template <> +inline __host__ __device__ f8_ocp_t f8_convert_sr(bhalf_t x) +{ + return f8_ocp_t{fp8_impl::cvt_bhalf_t_to_fp8(x)}; +} + +/** + * @brief Converts a vector of 2 bhalf_t to a vector of 2 8-bit float types (f8_ocp_t) using + * stochastic rounding. + * + * @param x The input vector of 2 bhalf_t. + * @return The converted vector of 2 f8_ocp_t. + */ +template <> +inline __host__ __device__ f8x2_ocp_t f8_convert_sr(bhalf2_t x) +{ + return f8x2_ocp_t{fp8_impl::cvt_bhalf_t_to_fp8(x)}; +} + +/** + * @brief Converts a bhalf_t to a 8-bit half_t type (bf8_ocp_t) using stochastic rounding. + * + * @param x The input bhalf_t value. + * @return The converted bf8_ocp_t value. + */ +template <> +inline __host__ __device__ bf8_ocp_t f8_convert_sr(bhalf_t x) +{ + return bf8_ocp_t{fp8_impl::cvt_bhalf_t_to_fp8(x)}; +} + +/** + * @brief Converts a vector of 2 bhalf_t to a vector of 2 8-bit float types (bf8_ocp_t) using + * stochastic rounding. + * + * @param x The input vector of 2 bhalf_t. + * @return The converted vector of 2 bf8_ocp_t. 
+ */ +template <> +inline __host__ __device__ bf8x2_ocp_t f8_convert_sr(bhalf2_t x) +{ + return bf8x2_ocp_t{fp8_impl::cvt_bhalf_t_to_fp8(x)}; +} + // Declare a template function for fp8 conversion using RNE template __host__ __device__ constexpr Y f8_convert_rne(X x); @@ -466,6 +640,172 @@ inline __host__ __device__ bf8_fnuz_t f8_convert_rne(half_t #endif } +/** + * @brief Converts a float to a 8-bit float type (f8_ocp_t) using rounding to nearest/even. + * + * @param x The input float value. + * @return The converted f8_ocp_t value. + */ +template <> +inline __host__ __device__ f8_ocp_t f8_convert_rne(float x) +{ + return f8_ocp_t{ + fp8_impl::cvt_float_to_fp8(x)}; +} + +/** + * @brief Converts a vector of 2 floats to a vector of 2 8-bit float types (f8_ocp_t) using rounding + * to nearest/even. + * + * @param x The input vector of 2 floats. + * @return The converted vector of 2 f8_ocp_t. + */ +template <> +inline __host__ __device__ f8x2_ocp_t f8_convert_rne(float2_t x) +{ + return f8x2_ocp_t{ + fp8_impl::cvt_float_to_fp8(x)}; +} + +/** + * @brief Converts a float to a 8-bit float type (bf8_ocp_t) using rounding to nearest/even. + * + * @param x The input float value. + * @return The converted bf8_ocp_t value. + */ +template <> +inline __host__ __device__ bf8_ocp_t f8_convert_rne(float x) +{ + return bf8_ocp_t{ + fp8_impl::cvt_float_to_fp8(x)}; +} + +/** + * @brief Converts a vector of 2 floats to a vector of 2 8-bit float types (bf8_ocp_t) using + * rounding to nearest/even. + * + * @param x The input vector of 2 floats. + * @return The converted vector of 2 bf8_ocp_t. + */ +template <> +inline __host__ __device__ bf8x2_ocp_t f8_convert_rne(float2_t x) +{ + return bf8x2_ocp_t{ + fp8_impl::cvt_float_to_fp8(x)}; +} + +/** + * @brief Converts a half_t to a 8-bit float type (f8_ocp_t) using rounding to nearest/even. + * + * @param x The input half_t value. + * @return The converted f8_ocp_t value. 
+ */ +template <> +inline __host__ __device__ f8_ocp_t f8_convert_rne(half_t x) +{ + return f8_ocp_t{ + fp8_impl::cvt_half_t_to_fp8(x)}; +} + +/** + * @brief Converts a vector of 2 half_t to a vector of 2 8-bit float types (f8_ocp_t) using rounding + * to nearest/even. + * + * @param x The input vector of 2 half_t. + * @return The converted vector of 2 f8_ocp_t. + */ +template <> +inline __host__ __device__ f8x2_ocp_t f8_convert_rne(half2_t x) +{ + return f8x2_ocp_t{ + fp8_impl::cvt_half_t_to_fp8(x)}; +} + +/** + * @brief Converts a half_t to a 8-bit half_t type (bf8_ocp_t) using rounding to nearest/even. + * + * @param x The input half_t value. + * @return The converted bf8_ocp_t value. + */ +template <> +inline __host__ __device__ bf8_ocp_t f8_convert_rne(half_t x) +{ + return bf8_ocp_t{ + fp8_impl::cvt_half_t_to_fp8( + x)}; +} + +/** + * @brief Converts a vector of 2 half_t to a vector of 2 8-bit float types (bf8_ocp_t) using + * rounding to nearest/even. + * + * @param x The input vector of 2 half_t. + * @return The converted vector of 2 bf8_ocp_t. + */ +template <> +inline __host__ __device__ bf8x2_ocp_t f8_convert_rne(half2_t x) +{ + return bf8x2_ocp_t{ + fp8_impl::cvt_half_t_to_fp8( + x)}; +} + +/** + * @brief Converts a bhalf_t to a 8-bit float type (f8_ocp_t) using rounding to nearest/even. + * + * @param x The input bhalf_t value. + * @return The converted f8_ocp_t value. + */ +template <> +inline __host__ __device__ f8_ocp_t f8_convert_rne(bhalf_t x) +{ + return f8_ocp_t{ + fp8_impl::cvt_bhalf_t_to_fp8(x)}; +} + +/** + * @brief Converts a vector of 2 bhalf_t to a vector of 2 8-bit float types (f8_ocp_t) using + * rounding to nearest/even. + * + * @param x The input vector of 2 bhalf_t. + * @return The converted vector of 2 f8_ocp_t. 
+ */ +template <> +inline __host__ __device__ f8x2_ocp_t f8_convert_rne(bhalf2_t x) +{ + return f8x2_ocp_t{ + fp8_impl::cvt_bhalf_t_to_fp8(x)}; +} + +/** + * @brief Converts a bhalf_t to a 8-bit half_t type (bf8_ocp_t) using rounding to nearest/even. + * + * @param x The input bhalf_t value. + * @return The converted bf8_ocp_t value. + */ +template <> +inline __host__ __device__ bf8_ocp_t f8_convert_rne(bhalf_t x) +{ + return bf8_ocp_t{ + fp8_impl::cvt_bhalf_t_to_fp8( + x)}; +} + +/** + * @brief Converts a vector of 2 bhalf_t to a vector of 2 8-bit float types (bf8_ocp_t) using + * rounding to nearest/even. + * + * @param x The input vector of 2 bhalf_t. + * @return The converted vector of 2 bf8_ocp_t. + */ +template <> +inline __host__ __device__ bf8x2_ocp_t f8_convert_rne(bhalf2_t x) +{ + return bf8x2_ocp_t{ + fp8_impl::cvt_bhalf_t_to_fp8( + x)}; +} + // convert fp32 to fp8 template <> inline __host__ __device__ f8_fnuz_t type_convert(float x) @@ -477,17 +817,6 @@ inline __host__ __device__ f8_fnuz_t type_convert(float x) #endif } -// convert fp32 to fp8 -template <> -inline __host__ __device__ f8_ocp_t type_convert(float x) -{ -#if CK_USE_SR_F8_CONVERSION - return f8_convert_sr(x); -#else - return f8_convert_rne(x); -#endif -} - // convert fp8 to fp32 template <> inline __host__ __device__ float type_convert(f8_fnuz_t x) @@ -524,12 +853,39 @@ inline __host__ __device__ float2_t type_convert(f8x2_fnu #endif } +/** + * @brief Converts a f8_ocp_t value to a float value. + * + * @param x The input f8_ocp_t value. + * @return The converted float value. + */ +template <> +inline __host__ __device__ float type_convert(f8_ocp_t x) +{ +#if CK_OCP_FP8_CVT_FAST_PATH + union + { + unsigned int i32val; + fp8_storage_t i8val[4]; + } val; + val.i8val[0] = x.data; + return __builtin_amdgcn_cvt_f32_fp8(val.i32val, 0); +#else + return fp8_impl::cast_from_f8(x.data); +#endif +} + +/** + * @brief Converts a vector of 2 f8_ocp_t values to a vector of 2 float values. 
+ * + * @param x The input vector of 2 f8_ocp_t values. + * @return The converted vector of 2 float values. + */ template <> inline __host__ __device__ float2_t type_convert(f8x2_ocp_t x) { #if CK_OCP_FP8_CVT_FAST_PATH - return fp8_impl::cast_to_f32x2_from_f8x2( - x.AsType()[Number<0>{}]); + return __builtin_amdgcn_cvt_pk_f32_fp8(bit_cast(x), false); #else return float2_t{fp8_impl::cast_from_f8( x.AsType()[Number<0>{}]), @@ -538,6 +894,229 @@ inline __host__ __device__ float2_t type_convert(f8x2_ocp_ #endif } +/** + * @brief Converts a f8_ocp_t value to a half_t value. + * + * @param x The input f8_ocp_t value. + * @return The converted half_t value. + */ +template <> +inline __host__ __device__ half_t type_convert(f8_ocp_t x) +{ +#if defined(__gfx950__) + union + { + uint16_t i16val; + fp8_storage_t i8val[2]; + } input; + input.i8val[0] = x.data; + + union + { + half2_t half_vec; + half_t half_arr[2]; + } output; + output.half_vec = __builtin_amdgcn_cvt_scalef32_pk_f16_fp8(input.i16val, /*scale*/ 1.f, 0); + + return output.half_arr[0]; +#else + return fp8_impl::cast_from_f8(x.data); +#endif +} + +/** + * @brief Converts a vector of 2 f8_ocp_t values to a vector of 2 half_t values. + * + * @param x The input vector of 2 f8_ocp_t values. + * @return The converted vector of 2 half_t values. + */ +template <> +inline __host__ __device__ half2_t type_convert(f8x2_ocp_t x) +{ +#if defined(__gfx950__) + return __builtin_amdgcn_cvt_scalef32_pk_f16_fp8(bit_cast(x), /*scale*/ 1.f, 0); +#else + return half2_t{type_convert(float(x.AsType()[Number<0>{}])), + type_convert(float(x.AsType()[Number<1>{}]))}; +#endif +} + +/** + * @brief Converts a f8_ocp_t value to a bhalf_t value. + * + * @param x The input f8_ocp_t value. + * @return The converted bhalf_t value. 
+ */ +template <> +inline __host__ __device__ bhalf_t type_convert(f8_ocp_t x) +{ +#if defined(__gfx950__) + union + { + uint16_t i16val; + fp8_storage_t i8val[2]; + } input; + input.i8val[0] = x.data; + + union + { + bhalf2_t bhalf_vec; + bhalf_t bhalf_arr[2]; + } output; + output.bhalf_vec = __builtin_amdgcn_cvt_scalef32_pk_bf16_fp8(input.i16val, /*scale*/ 1.f, 0); + + return output.bhalf_arr[0]; +#else + return type_convert( + fp8_impl::cast_from_f8(x.data)); +#endif +} + +/** + * @brief Converts a vector of 2 f8_ocp_t values to a vector of 2 bhalf_t values. + * + * @param x The input vector of 2 f8_ocp_t values. + * @return The converted vector of 2 bhalf_t values. + */ +template <> +inline __host__ __device__ bhalf2_t type_convert(f8x2_ocp_t x) +{ +#if defined(__gfx950__) + return __builtin_amdgcn_cvt_scalef32_pk_bf16_fp8(bit_cast(x), /*scale*/ 1.f, 0); +#else + return bhalf2_t{type_convert(float(x.AsType()[Number<0>{}])), + type_convert(float(x.AsType()[Number<1>{}]))}; +#endif +} + +/** + * @brief Converts a bf8_ocp_t value to a float value. + * + * @param x The input bf8_ocp_t value. + * @return The converted float value. + */ +template <> +inline __host__ __device__ float type_convert(bf8_ocp_t x) +{ +#if CK_OCP_FP8_CVT_FAST_PATH + union + { + unsigned int i32val; + fp8_storage_t i8val[4]; + } val; + val.i8val[0] = x.data; + return __builtin_amdgcn_cvt_f32_bf8(val.i32val, 0); +#else + return fp8_impl::cast_from_f8(x.data); +#endif +} + +/** + * @brief Converts a vector of 2 bf8_ocp_t values to a vector of 2 float values. + * + * @param x The input vector of 2 bf8_ocp_t values. + * @return The converted vector of 2 float values. 
+ */ +template <> +inline __host__ __device__ float2_t type_convert(bf8x2_ocp_t x) +{ +#if CK_OCP_FP8_CVT_FAST_PATH + return __builtin_amdgcn_cvt_pk_f32_bf8(bit_cast(x), false); +#else + return float2_t{fp8_impl::cast_from_f8( + x.AsType()[Number<0>{}]), + fp8_impl::cast_from_f8( + x.AsType()[Number<1>{}])}; +#endif +} + +/** + * @brief Converts a bf8_ocp_t value to a half_t value. + * + * @param x The input bf8_ocp_t value. + * @return The converted half_t value. + */ +template <> +inline __host__ __device__ half_t type_convert(bf8_ocp_t x) +{ +#if defined(__gfx950__) + union + { + uint16_t i16val; + fp8_storage_t i8val[2]; + } val; + val.i8val[0] = x.data; + return __builtin_amdgcn_cvt_scalef32_pk_f16_bf8(val.i16val, /*scale*/ 1.f, 0)[0]; +#else + return fp8_impl::cast_from_f8(x.data); +#endif +} + +/** + * @brief Converts a vector of 2 bf8_ocp_t values to a vector of 2 half_t values. + * + * @param x The input vector of 2 bf8_ocp_t values. + * @return The converted vector of 2 half_t values. + */ +template <> +inline __host__ __device__ half2_t type_convert(bf8x2_ocp_t x) +{ +#if defined(__gfx950__) + return __builtin_amdgcn_cvt_scalef32_pk_f16_bf8(bit_cast(x), /*scale*/ 1.f, 0); +#else + return half2_t{type_convert(float(x.AsType()[Number<0>{}])), + type_convert(float(x.AsType()[Number<1>{}]))}; +#endif +} + +/** + * @brief Converts a bf8_ocp_t value to a bhalf_t value. + * + * @param x The input bf8_ocp_t value. + * @return The converted bhalf_t value. 
+ */ +template <> +inline __host__ __device__ bhalf_t type_convert(bf8_ocp_t x) +{ +#if defined(__gfx950__) + union + { + uint16_t i16val; + fp8_storage_t i8val[2]; + } input; + input.i8val[0] = x.data; + + union + { + bhalf2_t bhalf_vec; + bhalf_t bhalf_arr[2]; + } output; + output.bhalf_vec = __builtin_amdgcn_cvt_scalef32_pk_bf16_bf8(input.i16val, /*scale*/ 1.f, 0); + + return output.bhalf_arr[0]; +#else + return type_convert( + fp8_impl::cast_from_f8(x.data)); +#endif +} + +/** + * @brief Converts a vector of 2 bf8_ocp_t values to a vector of 2 bhalf_t values. + * + * @param x The input vector of 2 bf8_ocp_t values. + * @return The converted vector of 2 bhalf_t values. + */ +template <> +inline __host__ __device__ bhalf2_t type_convert(bf8x2_ocp_t x) +{ +#if defined(__gfx950__) + return __builtin_amdgcn_cvt_scalef32_pk_bf16_bf8(bit_cast(x), /*scale*/ 1.f, 0); +#else + return bhalf2_t{type_convert(float(x.AsType()[Number<0>{}])), + type_convert(float(x.AsType()[Number<1>{}]))}; +#endif +} + template <> inline __host__ __device__ float2_t type_convert(pk_i4_t x) { @@ -610,7 +1189,12 @@ inline __host__ __device__ f8_fnuz_t type_convert(half_t x) #endif } -// convert fp16 to fp8 +/** + * @brief Converts a half_t value to a f8_ocp_t value with rounding determined by a flag. + * + * @param x The input half_t value. + * @return The converted f8_ocp_t value. + */ template <> inline __host__ __device__ f8_ocp_t type_convert(half_t x) { @@ -621,6 +1205,22 @@ inline __host__ __device__ f8_ocp_t type_convert(half_t x) #endif } +/** + * @brief Converts a half_t value to a bf8_ocp_t value with rounding determined by a flag. + * + * @param x The input half_t value. + * @return The converted bf8_ocp_t value. 
+ */ +template <> +inline __host__ __device__ bf8_ocp_t type_convert(half_t x) +{ +#if CK_USE_SR_F8_CONVERSION + return f8_convert_sr(x); +#else + return f8_convert_rne(x); +#endif +} + // convert fp8 to fp16 template <> inline __host__ __device__ half_t type_convert(f8_fnuz_t x) @@ -645,7 +1245,28 @@ inline __host__ __device__ bf8_fnuz_t type_convert(float x) #endif } -// convert fp32 to bf8 +/** + * @brief Converts a float value to a f8_ocp_t value with rounding determined by a flag. + * + * @param x The input float value. + * @return The converted f8_ocp_t value. + */ +template <> +inline __host__ __device__ f8_ocp_t type_convert(float x) +{ +#if CK_USE_SR_F8_CONVERSION + return f8_convert_sr(x); +#else + return f8_convert_rne(x); +#endif +} + +/** + * @brief Converts a float value to a bf8_ocp_t value with rounding determined by a flag. + * + * @param x The input float value. + * @return The converted bf8_ocp_t value. + */ template <> inline __host__ __device__ bf8_ocp_t type_convert(float x) { @@ -656,6 +1277,38 @@ inline __host__ __device__ bf8_ocp_t type_convert(float x) #endif } +/** + * @brief Converts a bhalf_t value to a f8_ocp_t value with rounding determined by a flag. + * + * @param x The input bhalf_t value. + * @return The converted f8_ocp_t value. + */ +template <> +inline __host__ __device__ f8_ocp_t type_convert(bhalf_t x) +{ +#if CK_USE_SR_F8_CONVERSION + return f8_convert_sr(x); +#else + return f8_convert_rne(x); +#endif +} + +/** + * @brief Converts a bhalf_t value to a bf8_ocp_t value with rounding determined by a flag. + * + * @param x The input bhalf_t value. + * @return The converted bf8_ocp_t value. 
+ */ +template <> +inline __host__ __device__ bf8_ocp_t type_convert(bhalf_t x) +{ +#if CK_USE_SR_F8_CONVERSION + return f8_convert_sr(x); +#else + return f8_convert_rne(x); +#endif +} + // convert bf8 to fp32 template <> inline __host__ __device__ float type_convert(bf8_fnuz_t x) @@ -683,17 +1336,6 @@ inline __host__ __device__ bf8_fnuz_t type_convert(half_t x) #endif } -// convert fp16 to bf8 -template <> -inline __host__ __device__ bf8_ocp_t type_convert(half_t x) -{ -#if CK_USE_SR_F8_CONVERSION - return f8_convert_sr(x); -#else - return f8_convert_rne(x); -#endif -} - // convert bf8 to fp16 template <> inline __host__ __device__ half_t type_convert(bf8_fnuz_t x) diff --git a/test/data_type/test_bf8_ocp.cpp b/test/data_type/test_bf8_ocp.cpp index 9d4ee38b15..285e7e69fc 100644 --- a/test/data_type/test_bf8_ocp.cpp +++ b/test/data_type/test_bf8_ocp.cpp @@ -1,13 +1,19 @@ // SPDX-License-Identifier: MIT -// Copyright (c) 2024, Advanced Micro Devices, Inc. All rights reserved. +// Copyright (c) 2024-2025, Advanced Micro Devices, Inc. All rights reserved. 
#include "gtest/gtest.h" +#include "ck/library/utility/device_memory.hpp" #include "ck/utility/data_type.hpp" #include "ck/utility/type_convert.hpp" using ck::bf8_ocp_t; +using ck::bf8x2_ocp_t; +using ck::bhalf2_t; +using ck::bhalf_t; using ck::f8_convert_rne; using ck::f8_convert_sr; +using ck::float2_t; +using ck::half2_t; using ck::half_t; using ck::type_convert; @@ -266,3 +272,590 @@ TEST(BF8OCP, ConvertFP16Stochastic) const auto bf8_nan = f8_convert_sr(ck::NumericLimits::QuietNaN()); ASSERT_TRUE(ck::fp8_impl::ocp_bf8_is_nan(bf8_nan.data)); } + +constexpr uint64_t test_size = 256 + 6; + +__host__ __device__ void +test_fp32_bf8_type_convert(uint64_t N, float* p_test, uint64_t* p_completed) +{ + if(p_completed == nullptr) + { + return; + } + + uint64_t& i = *p_completed; + i = 0; + + if(p_test == nullptr) + { + return; + } + + for(ck::index_t bf8_id = 0; bf8_id < 256; bf8_id++) + { + uint8_t bf8_uid = static_cast(bf8_id); + auto v = type_convert(bf8_ocp_t{bf8_uid}); + p_test[i] = v; + i++; + if(i >= N) + { + return; + } + } + + /// Test vector conversion + // bf8x2 -> fp32x2 + bf8x2_ocp_t bf8x2{bf8x2_ocp_t::data_v{0b10000100, 0b00000001}}; //-2^-14, 2^-16 + + float2_t f32x2 = type_convert(bf8x2); + p_test[i++] = f32x2[0]; + if(i >= N) + { + return; + } + p_test[i++] = f32x2[1]; + if(i >= N) + { + return; + } + + // fp32x2 -> bf8x2 + f32x2 = {-4.0f, 2.0f}; + bf8x2 = f8_convert_rne(f32x2); // expect {-4, 2} + + p_test[i++] = type_convert(bf8x2.AsType()(ck::Number<0>{})); //-4f + if(i >= N) + { + return; + } + p_test[i++] = type_convert(bf8x2.AsType()(ck::Number<1>{})); // 2f + if(i >= N) + { + return; + } + + bf8x2 = f8_convert_sr(f32x2); // expect {-4, 2} + + p_test[i++] = type_convert(bf8x2.AsType()(ck::Number<0>{})); //-4f + if(i >= N) + { + return; + } + p_test[i++] = type_convert(bf8x2.AsType()(ck::Number<1>{})); // 2f + if(i >= N) + { + return; + } +} + +TEST(BF8OCP, HostFP32BF8Convert) +{ + std::vector out(test_size, -1.0f); + uint64_t completed = 0; + + 
test_fp32_bf8_type_convert(test_size, out.data(), &completed); + + std::set bf8_nan_ids; + bf8_nan_ids.insert(0b11111111); + bf8_nan_ids.insert(0b01111111); + bf8_nan_ids.insert(0b11111101); + bf8_nan_ids.insert(0b01111101); + bf8_nan_ids.insert(0b11111110); + bf8_nan_ids.insert(0b01111110); + for(auto bf8_nan_id : bf8_nan_ids) + { + auto idx = bf8_nan_id; + ASSERT_TRUE(std::isnan(out[idx])); + } + + for(ck::index_t bf8_id = 0; bf8_id < 256; bf8_id++) + { + if(bf8_nan_ids.find(bf8_id) != bf8_nan_ids.end()) + continue; + + uint8_t bf8_uid = static_cast(bf8_id); + auto idx = bf8_uid; + ASSERT_FLOAT_EQ(out[idx], type_convert(bf8_ocp_t{bf8_uid})) + << " bf8_id: " << bf8_id << std::endl + << type_convert(bf8_ocp_t{bf8_uid}); + } + + // /// Test vector conversions + + auto i = 256; + + // bf8x2 -> fp32x2 + EXPECT_EQ(out[i++], -powf(2.0f, -14.0f)); + EXPECT_EQ(out[i++], powf(2.0f, -16.0f)); + + // fp32x2 -> bf8x2 + // RNE + EXPECT_EQ(out[i++], -4.0f); + EXPECT_EQ(out[i++], 2.0f); + // SR + EXPECT_EQ(out[i++], -4.0f); + EXPECT_EQ(out[i++], 2.0f); + + EXPECT_EQ(test_size, completed); + EXPECT_EQ(test_size, i); +} + +__global__ void device_test_fp32_bf8_type_convert(uint64_t N, float* p_test, uint64_t* p_completed) +{ + test_fp32_bf8_type_convert(N, p_test, p_completed); +} + +TEST(BF8OCP, DeviceFP32BF8Convert) +{ + std::vector out(test_size, -1.0f); + + DeviceMem device_out(test_size * sizeof(float)); + DeviceMem device_completed(sizeof(uint64_t)); + + device_out.SetValue(-21.0f); + device_completed.SetValue(-21.0f); + + device_test_fp32_bf8_type_convert<<<1, 1>>>( + test_size, + static_cast(device_out.GetDeviceBuffer()), + static_cast(device_completed.GetDeviceBuffer())); + + uint64_t completed = 0; + device_completed.FromDevice(&completed); + device_out.FromDevice(out.data()); + + std::set bf8_nan_ids; + bf8_nan_ids.insert(0b11111111); + bf8_nan_ids.insert(0b01111111); + bf8_nan_ids.insert(0b11111101); + bf8_nan_ids.insert(0b01111101); + bf8_nan_ids.insert(0b11111110); + 
bf8_nan_ids.insert(0b01111110); + for(auto bf8_nan_id : bf8_nan_ids) + { + auto idx = bf8_nan_id; + ASSERT_TRUE(std::isnan(out[idx])) << "idx: " << idx << " out[idx]: " << out[idx]; + } + + for(ck::index_t bf8_id = 0; bf8_id < 256; bf8_id++) + { + if(bf8_nan_ids.find(bf8_id) != bf8_nan_ids.end()) + continue; + + uint8_t bf8_uid = static_cast(bf8_id); + auto idx = bf8_uid; + ASSERT_FLOAT_EQ(out[idx], type_convert(bf8_ocp_t{bf8_uid})) + << " bf8_id: " << bf8_id << std::endl + << type_convert(bf8_ocp_t{bf8_uid}); + } + + /// Test vector conversions + + auto i = 256; + + // bf8x2 -> fp32x2 + EXPECT_EQ(out[i++], -powf(2.0f, -14.0f)); + EXPECT_EQ(out[i++], powf(2.0f, -16.0f)); + + // fp32x2 -> bf8x2 + // RNE + EXPECT_EQ(out[i++], -4.0f); + EXPECT_EQ(out[i++], 2.0f); + // SR + EXPECT_EQ(out[i++], -4.0f); + EXPECT_EQ(out[i++], 2.0f); + + EXPECT_EQ(test_size, completed); + EXPECT_EQ(test_size, i); +} + +__host__ __device__ void +test_fp16_bf8_type_convert(uint64_t N, half_t* p_test, uint64_t* p_completed) +{ + if(p_completed == nullptr) + { + return; + } + + uint64_t& i = *p_completed; + i = 0; + + if(p_test == nullptr) + { + return; + } + + for(ck::index_t bf8_id = 0; bf8_id < 256; bf8_id++) + { + uint8_t bf8_uid = static_cast(bf8_id); + auto v = type_convert(bf8_ocp_t{bf8_uid}); + p_test[i] = v; + i++; + if(i >= N) + { + return; + } + } + + /// Test vector conversion + // bf8x2 -> fp16x2 + bf8x2_ocp_t bf8x2{bf8x2_ocp_t::data_v{0b10000100, 0b00000001}}; //-2^-14, 2^-16 + + half2_t f16x2 = type_convert(bf8x2); + p_test[i++] = f16x2[0]; + if(i >= N) + { + return; + } + p_test[i++] = f16x2[1]; + if(i >= N) + { + return; + } + + // fp16x2 -> bf8x2 + f16x2 = {-4.0f, 2.0f}; + bf8x2 = f8_convert_rne(f16x2); // expect {-4, 2} + + p_test[i++] = type_convert(bf8x2.AsType()(ck::Number<0>{})); //-4f + if(i >= N) + { + return; + } + p_test[i++] = type_convert(bf8x2.AsType()(ck::Number<1>{})); // 2f + if(i >= N) + { + return; + } + + bf8x2 = f8_convert_sr(f16x2); // expect {-4, 2} + + 
p_test[i++] = type_convert(bf8x2.AsType()(ck::Number<0>{})); //-4f + if(i >= N) + { + return; + } + p_test[i++] = type_convert(bf8x2.AsType()(ck::Number<1>{})); // 2f + if(i >= N) + { + return; + } +} + +TEST(BF8OCP, HostFP16BF8Convert) +{ + std::vector out(test_size, -1.0f); + uint64_t completed = 0; + + test_fp16_bf8_type_convert(test_size, out.data(), &completed); + + std::set bf8_nan_ids; + bf8_nan_ids.insert(0b11111111); + bf8_nan_ids.insert(0b01111111); + bf8_nan_ids.insert(0b11111101); + bf8_nan_ids.insert(0b01111101); + bf8_nan_ids.insert(0b11111110); + bf8_nan_ids.insert(0b01111110); + for(auto bf8_nan_id : bf8_nan_ids) + { + auto idx = bf8_nan_id; + ASSERT_TRUE(std::isnan(type_convert(out[idx]))); + } + + for(ck::index_t bf8_id = 0; bf8_id < 256; bf8_id++) + { + if(bf8_nan_ids.find(bf8_id) != bf8_nan_ids.end()) + continue; + + uint8_t bf8_uid = static_cast(bf8_id); + auto idx = bf8_uid; + ASSERT_FLOAT_EQ(out[idx], type_convert(bf8_ocp_t{bf8_uid})) + << " bf8_id: " << bf8_id << std::endl + << type_convert(type_convert(bf8_ocp_t{bf8_uid})); + } + + // /// Test vector conversions + + auto i = 256; + + // bf8x2 -> fp16x2 + EXPECT_EQ(out[i++], type_convert(-powf(2.0f, -14.0f))); + EXPECT_EQ(out[i++], type_convert(powf(2.0f, -16.0f))); + + // fp16x2 -> bf8x2 + // RNE + EXPECT_EQ(out[i++], type_convert(-4.0f)); + EXPECT_EQ(out[i++], type_convert(2.0f)); + // SR + EXPECT_EQ(out[i++], type_convert(-4.0f)); + EXPECT_EQ(out[i++], type_convert(2.0f)); + + EXPECT_EQ(test_size, completed); + EXPECT_EQ(test_size, i); +} + +__global__ void device_test_fp16_bf8_type_convert(uint64_t N, half_t* p_test, uint64_t* p_completed) +{ + test_fp16_bf8_type_convert(N, p_test, p_completed); +} + +TEST(BF8OCP, DeviceFP16BF8Convert) +{ + std::vector out(test_size, -1.0f); + + DeviceMem device_out(test_size * sizeof(half_t)); + DeviceMem device_completed(sizeof(uint64_t)); + + device_out.SetValue(-21.0f); + device_completed.SetValue(-21.0f); + + device_test_fp16_bf8_type_convert<<<1, 
1>>>( + test_size, + static_cast(device_out.GetDeviceBuffer()), + static_cast(device_completed.GetDeviceBuffer())); + + uint64_t completed = 0; + device_completed.FromDevice(&completed); + device_out.FromDevice(out.data()); + + std::set bf8_nan_ids; + bf8_nan_ids.insert(0b11111111); + bf8_nan_ids.insert(0b01111111); + bf8_nan_ids.insert(0b11111101); + bf8_nan_ids.insert(0b01111101); + bf8_nan_ids.insert(0b11111110); + bf8_nan_ids.insert(0b01111110); + for(auto bf8_nan_id : bf8_nan_ids) + { + auto idx = bf8_nan_id; + ASSERT_TRUE(std::isnan(type_convert(out[idx]))) + << "idx: " << idx << " out[idx]: " << type_convert(out[idx]); + } + + for(ck::index_t bf8_id = 0; bf8_id < 256; bf8_id++) + { + if(bf8_nan_ids.find(bf8_id) != bf8_nan_ids.end()) + continue; + + uint8_t bf8_uid = static_cast(bf8_id); + auto idx = bf8_uid; + ASSERT_FLOAT_EQ(out[idx], type_convert(bf8_ocp_t{bf8_uid})) + << " bf8_id: " << bf8_id << std::endl + << type_convert(type_convert(bf8_ocp_t{bf8_uid})); + } + + /// Test vector conversions + + auto i = 256; + + // bf8x2 -> fp16x2 + EXPECT_EQ(out[i++], type_convert(-powf(2.0f, -14.0f))); + EXPECT_EQ(out[i++], type_convert(powf(2.0f, -16.0f))); + + // fp16x2 -> bf8x2 + // RNE + EXPECT_EQ(out[i++], type_convert(-4.0f)); + EXPECT_EQ(out[i++], type_convert(2.0f)); + // SR + EXPECT_EQ(out[i++], type_convert(-4.0f)); + EXPECT_EQ(out[i++], type_convert(2.0f)); + + EXPECT_EQ(test_size, completed); + EXPECT_EQ(test_size, i); +} + +__host__ __device__ void +test_bf16_bf8_type_convert(uint64_t N, bhalf_t* p_test, uint64_t* p_completed) +{ + if(p_completed == nullptr) + { + return; + } + + uint64_t& i = *p_completed; + i = 0; + + if(p_test == nullptr) + { + return; + } + + for(ck::index_t bf8_id = 0; bf8_id < 256; bf8_id++) + { + uint8_t bf8_uid = static_cast(bf8_id); + auto v = type_convert(bf8_ocp_t{bf8_uid}); + p_test[i] = v; + i++; + if(i >= N) + { + return; + } + } + + /// Test vector conversion + // bf8x2 -> bf16x2 + bf8x2_ocp_t 
bf8x2{bf8x2_ocp_t::data_v{0b10000100, 0b00000001}}; //-2^-14, 2^-16 + + bhalf2_t bf16x2 = type_convert(bf8x2); + p_test[i++] = bf16x2[0]; + if(i >= N) + { + return; + } + p_test[i++] = bf16x2[1]; + if(i >= N) + { + return; + } + + // bf16x2 -> bf8x2 + bf16x2 = {type_convert(-4.0f), type_convert(2.0f)}; + bf8x2 = f8_convert_rne(bf16x2); // expect {-4, 2} + + p_test[i++] = type_convert(bf8x2.AsType()(ck::Number<0>{})); //-4f + if(i >= N) + { + return; + } + p_test[i++] = type_convert(bf8x2.AsType()(ck::Number<1>{})); // 2f + if(i >= N) + { + return; + } + + bf8x2 = f8_convert_sr(bf16x2); // expect {-4, 2} + + p_test[i++] = type_convert(bf8x2.AsType()(ck::Number<0>{})); //-4f + if(i >= N) + { + return; + } + p_test[i++] = type_convert(bf8x2.AsType()(ck::Number<1>{})); // 2f + if(i >= N) + { + return; + } +} + +TEST(BF8OCP, HostBF16BF8Convert) +{ + std::vector out(test_size, -1.0f); + uint64_t completed = 0; + + test_bf16_bf8_type_convert(test_size, out.data(), &completed); + + std::set bf8_nan_ids; + bf8_nan_ids.insert(0b11111111); + bf8_nan_ids.insert(0b01111111); + bf8_nan_ids.insert(0b11111101); + bf8_nan_ids.insert(0b01111101); + bf8_nan_ids.insert(0b11111110); + bf8_nan_ids.insert(0b01111110); + for(auto bf8_nan_id : bf8_nan_ids) + { + auto idx = bf8_nan_id; + ASSERT_TRUE(std::isnan(type_convert(out[idx]))); + } + + for(ck::index_t bf8_id = 0; bf8_id < 256; bf8_id++) + { + if(bf8_nan_ids.find(bf8_id) != bf8_nan_ids.end()) + continue; + + uint8_t bf8_uid = static_cast(bf8_id); + auto idx = bf8_uid; + ASSERT_FLOAT_EQ(out[idx], type_convert(bf8_ocp_t{bf8_uid})) + << " bf8_id: " << bf8_id << std::endl + << type_convert(type_convert(bf8_ocp_t{bf8_uid})); + } + + // /// Test vector conversions + + auto i = 256; + + // bf8x2 -> bf16x2 + EXPECT_EQ(out[i++], type_convert(-powf(2.0f, -14.0f))); + EXPECT_EQ(out[i++], type_convert(powf(2.0f, -16.0f))); + + // bf16x2 -> bf8x2 + // RNE + EXPECT_EQ(out[i++], type_convert(-4.0f)); + EXPECT_EQ(out[i++], type_convert(2.0f)); + // 
SR + EXPECT_EQ(out[i++], type_convert(-4.0f)); + EXPECT_EQ(out[i++], type_convert(2.0f)); + + EXPECT_EQ(test_size, completed); + EXPECT_EQ(test_size, i); +} + +__global__ void +device_test_bf16_bf8_type_convert(uint64_t N, bhalf_t* p_test, uint64_t* p_completed) +{ + test_bf16_bf8_type_convert(N, p_test, p_completed); +} + +TEST(BF8OCP, DeviceBF16BF8Convert) +{ + std::vector out(test_size, -1.0f); + + DeviceMem device_out(test_size * sizeof(bhalf_t)); + DeviceMem device_completed(sizeof(uint64_t)); + + device_out.SetValue(-21.0f); + device_completed.SetValue(-21.0f); + + device_test_bf16_bf8_type_convert<<<1, 1>>>( + test_size, + static_cast(device_out.GetDeviceBuffer()), + static_cast(device_completed.GetDeviceBuffer())); + + uint64_t completed = 0; + device_completed.FromDevice(&completed); + device_out.FromDevice(out.data()); + + std::set bf8_nan_ids; + bf8_nan_ids.insert(0b11111111); + bf8_nan_ids.insert(0b01111111); + bf8_nan_ids.insert(0b11111101); + bf8_nan_ids.insert(0b01111101); + bf8_nan_ids.insert(0b11111110); + bf8_nan_ids.insert(0b01111110); + for(auto bf8_nan_id : bf8_nan_ids) + { + auto idx = bf8_nan_id; + ASSERT_TRUE(std::isnan(type_convert(out[idx]))) + << "idx: " << idx << " out[idx]: " << type_convert(out[idx]); + } + + for(ck::index_t bf8_id = 0; bf8_id < 256; bf8_id++) + { + if(bf8_nan_ids.find(bf8_id) != bf8_nan_ids.end()) + continue; + + uint8_t bf8_uid = static_cast(bf8_id); + auto idx = bf8_uid; + ASSERT_FLOAT_EQ(out[idx], type_convert(bf8_ocp_t{bf8_uid})) + << " bf8_id: " << bf8_id << std::endl + << type_convert(type_convert(bf8_ocp_t{bf8_uid})); + } + + /// Test vector conversions + + auto i = 256; + + // bf8x2 -> bf16x2 + EXPECT_EQ(out[i++], type_convert(-powf(2.0f, -14.0f))); + EXPECT_EQ(out[i++], type_convert(powf(2.0f, -16.0f))); + + // bf16x2 -> bf8x2 + // RNE + EXPECT_EQ(out[i++], type_convert(-4.0f)); + EXPECT_EQ(out[i++], type_convert(2.0f)); + // SR + EXPECT_EQ(out[i++], type_convert(-4.0f)); + EXPECT_EQ(out[i++], 
type_convert(2.0f)); + + EXPECT_EQ(test_size, completed); + EXPECT_EQ(test_size, i); +} diff --git a/test/data_type/test_fp8_ocp.cpp b/test/data_type/test_fp8_ocp.cpp index 944dd89930..bf562112c8 100644 --- a/test/data_type/test_fp8_ocp.cpp +++ b/test/data_type/test_fp8_ocp.cpp @@ -1,13 +1,19 @@ // SPDX-License-Identifier: MIT -// Copyright (c) 2024, Advanced Micro Devices, Inc. All rights reserved. +// Copyright (c) 2024-2025, Advanced Micro Devices, Inc. All rights reserved. #include "gtest/gtest.h" +#include "ck/library/utility/device_memory.hpp" #include "ck/utility/data_type.hpp" #include "ck/utility/type_convert.hpp" +using ck::bhalf2_t; +using ck::bhalf_t; using ck::f8_convert_rne; using ck::f8_convert_sr; using ck::f8_ocp_t; +using ck::f8x2_ocp_t; +using ck::float2_t; +using ck::half2_t; using ck::half_t; using ck::type_convert; @@ -248,3 +254,566 @@ TEST(FP8OCP, ConvertFP16Stochastic) auto f8_nan = f8_convert_sr(ck::NumericLimits::QuietNaN()); ASSERT_TRUE(ck::fp8_impl::ocp_f8_is_nan(f8_nan.data)); } + +constexpr uint64_t test_size = 256 + 6; + +__host__ __device__ void +test_fp32_fp8_type_convert(uint64_t N, float* p_test, uint64_t* p_completed) +{ + if(p_completed == nullptr) + { + return; + } + + uint64_t& i = *p_completed; + i = 0; + + if(p_test == nullptr) + { + return; + } + + for(ck::index_t fp8_id = 0; fp8_id < 256; fp8_id++) + { + uint8_t fp8_uid = static_cast(fp8_id); + auto v = type_convert(f8_ocp_t{fp8_uid}); + p_test[i] = v; + i++; + if(i >= N) + { + return; + } + } + + /// Test vector conversion + // fp8x2 -> fp32x2 + f8x2_ocp_t fp8x2{f8x2_ocp_t::data_v{0b10001000, 0b00000001}}; //-2^-6, 2^-9 + + float2_t f32x2 = type_convert(fp8x2); + p_test[i++] = f32x2[0]; + if(i >= N) + { + return; + } + p_test[i++] = f32x2[1]; + if(i >= N) + { + return; + } + + // fp32x2 -> fp8x2 + f32x2 = {-4.0f, 2.0f}; + fp8x2 = f8_convert_rne(f32x2); // expect {-4, 2} + + p_test[i++] = type_convert(fp8x2.AsType()(ck::Number<0>{})); //-4f + if(i >= N) + { + return; + } 
+ p_test[i++] = type_convert(fp8x2.AsType()(ck::Number<1>{})); // 2f + if(i >= N) + { + return; + } + + fp8x2 = f8_convert_sr(f32x2); // expect {-4, 2} + + p_test[i++] = type_convert(fp8x2.AsType()(ck::Number<0>{})); //-4f + if(i >= N) + { + return; + } + p_test[i++] = type_convert(fp8x2.AsType()(ck::Number<1>{})); // 2f + if(i >= N) + { + return; + } +} + +TEST(FP8OCP, HostFP32FP8Convert) +{ + std::vector out(test_size, -1.0f); + uint64_t completed = 0; + + test_fp32_fp8_type_convert(test_size, out.data(), &completed); + + std::set fp8_nan_ids; + fp8_nan_ids.insert(0b11111111); //-NaN + fp8_nan_ids.insert(0b01111111); // +NaN + for(auto fp8_nan_id : fp8_nan_ids) + { + auto idx = fp8_nan_id; + ASSERT_TRUE(std::isnan(out[idx])); + } + + for(ck::index_t fp8_id = 0; fp8_id < 256; fp8_id++) + { + if(fp8_nan_ids.find(fp8_id) != fp8_nan_ids.end()) + continue; + + uint8_t fp8_uid = static_cast(fp8_id); + auto idx = fp8_uid; + ASSERT_FLOAT_EQ(out[idx], type_convert(f8_ocp_t{fp8_uid})) + << " fp8_id: " << fp8_id << std::endl + << type_convert(f8_ocp_t{fp8_uid}); + } + + // /// Test vector conversions + + auto i = 256; + + // fp8x2 -> fp32x2 + EXPECT_EQ(out[i++], -powf(2.0f, -6.0f)); + EXPECT_EQ(out[i++], powf(2.0f, -9.0f)); + + // fp32x2 -> fp8x2 + // RNE + EXPECT_EQ(out[i++], -4.0f); + EXPECT_EQ(out[i++], 2.0f); + // SR + EXPECT_EQ(out[i++], -4.0f); + EXPECT_EQ(out[i++], 2.0f); + + EXPECT_EQ(test_size, completed); + EXPECT_EQ(test_size, i); +} + +__global__ void device_test_fp32_fp8_type_convert(uint64_t N, float* p_test, uint64_t* p_completed) +{ + test_fp32_fp8_type_convert(N, p_test, p_completed); +} + +TEST(FP8OCP, DeviceFP32FP8Convert) +{ + std::vector out(test_size, -1.0f); + + DeviceMem device_out(test_size * sizeof(float)); + DeviceMem device_completed(sizeof(uint64_t)); + + device_out.SetValue(-21.0f); + device_completed.SetValue(-21.0f); + + device_test_fp32_fp8_type_convert<<<1, 1>>>( + test_size, + static_cast(device_out.GetDeviceBuffer()), + 
static_cast(device_completed.GetDeviceBuffer())); + + uint64_t completed = 0; + device_completed.FromDevice(&completed); + device_out.FromDevice(out.data()); + + std::set fp8_nan_ids; + fp8_nan_ids.insert(0b11111111); //-NaN + fp8_nan_ids.insert(0b01111111); // +NaN + for(auto fp8_nan_id : fp8_nan_ids) + { + auto idx = fp8_nan_id; + ASSERT_TRUE(std::isnan(out[idx])) << "idx: " << idx << " out[idx]: " << out[idx]; + } + + for(ck::index_t fp8_id = 0; fp8_id < 256; fp8_id++) + { + if(fp8_nan_ids.find(fp8_id) != fp8_nan_ids.end()) + continue; + + uint8_t fp8_uid = static_cast(fp8_id); + auto idx = fp8_uid; + ASSERT_FLOAT_EQ(out[idx], type_convert(f8_ocp_t{fp8_uid})) + << " fp8_id: " << fp8_id << std::endl + << type_convert(f8_ocp_t{fp8_uid}); + } + + /// Test vector conversions + + auto i = 256; + + // fp8x2 -> fp32x2 + EXPECT_EQ(out[i++], -powf(2.0f, -6.0f)); + EXPECT_EQ(out[i++], powf(2.0f, -9.0f)); + + // fp32x2 -> fp8x2 + // RNE + EXPECT_EQ(out[i++], -4.0f); + EXPECT_EQ(out[i++], 2.0f); + // SR + EXPECT_EQ(out[i++], -4.0f); + EXPECT_EQ(out[i++], 2.0f); + + EXPECT_EQ(test_size, completed); + EXPECT_EQ(test_size, i); +} + +__host__ __device__ void +test_fp16_fp8_type_convert(uint64_t N, half_t* p_test, uint64_t* p_completed) +{ + if(p_completed == nullptr) + { + return; + } + + uint64_t& i = *p_completed; + i = 0; + + if(p_test == nullptr) + { + return; + } + + for(ck::index_t fp8_id = 0; fp8_id < 256; fp8_id++) + { + uint8_t fp8_uid = static_cast(fp8_id); + auto v = type_convert(f8_ocp_t{fp8_uid}); + p_test[i] = v; + i++; + if(i >= N) + { + return; + } + } + + /// Test vector conversion + // fp8x2 -> fp16x2 + f8x2_ocp_t fp8x2{f8x2_ocp_t::data_v{0b10001000, 0b00000001}}; //-2^-6, 2^-9 + + half2_t f16x2 = type_convert(fp8x2); + p_test[i++] = f16x2[0]; + if(i >= N) + { + return; + } + p_test[i++] = f16x2[1]; + if(i >= N) + { + return; + } + + // fp16x2 -> fp8x2 + f16x2 = {-4.0f, 2.0f}; + fp8x2 = f8_convert_rne(f16x2); // expect {-4, 2} + + p_test[i++] = 
type_convert(fp8x2.AsType()(ck::Number<0>{})); //-4f + if(i >= N) + { + return; + } + p_test[i++] = type_convert(fp8x2.AsType()(ck::Number<1>{})); // 2f + if(i >= N) + { + return; + } + + fp8x2 = f8_convert_sr(f16x2); // expect {-4, 2} + + p_test[i++] = type_convert(fp8x2.AsType()(ck::Number<0>{})); //-4f + if(i >= N) + { + return; + } + p_test[i++] = type_convert(fp8x2.AsType()(ck::Number<1>{})); // 2f + if(i >= N) + { + return; + } +} + +TEST(FP8OCP, HostFP16FP8Convert) +{ + std::vector out(test_size, -1.0f); + uint64_t completed = 0; + + test_fp16_fp8_type_convert(test_size, out.data(), &completed); + + std::set fp8_nan_ids; + fp8_nan_ids.insert(0b11111111); //-NaN + fp8_nan_ids.insert(0b01111111); // +NaN + for(auto fp8_nan_id : fp8_nan_ids) + { + auto idx = fp8_nan_id; + ASSERT_TRUE(std::isnan(type_convert(out[idx]))); + } + + for(ck::index_t fp8_id = 0; fp8_id < 256; fp8_id++) + { + if(fp8_nan_ids.find(fp8_id) != fp8_nan_ids.end()) + continue; + + uint8_t fp8_uid = static_cast(fp8_id); + auto idx = fp8_uid; + ASSERT_FLOAT_EQ(out[idx], type_convert(f8_ocp_t{fp8_uid})) + << " fp8_id: " << fp8_id << std::endl + << type_convert(type_convert(f8_ocp_t{fp8_uid})); + } + + // /// Test vector conversions + + auto i = 256; + + // fp8x2 -> fp16x2 + EXPECT_EQ(out[i++], type_convert(-powf(2.0f, -6.0f))); + EXPECT_EQ(out[i++], type_convert(powf(2.0f, -9.0f))); + + // fp16x2 -> fp8x2 + // RNE + EXPECT_EQ(out[i++], type_convert(-4.0f)); + EXPECT_EQ(out[i++], type_convert(2.0f)); + // SR + EXPECT_EQ(out[i++], type_convert(-4.0f)); + EXPECT_EQ(out[i++], type_convert(2.0f)); + + EXPECT_EQ(test_size, completed); + EXPECT_EQ(test_size, i); +} + +__global__ void device_test_fp16_fp8_type_convert(uint64_t N, half_t* p_test, uint64_t* p_completed) +{ + test_fp16_fp8_type_convert(N, p_test, p_completed); +} + +TEST(FP8OCP, DeviceFP16FP8Convert) +{ + std::vector out(test_size, -1.0f); + + DeviceMem device_out(test_size * sizeof(half_t)); + DeviceMem device_completed(sizeof(uint64_t)); 
+ + device_out.SetValue(-21.0f); + device_completed.SetValue(-21.0f); + + device_test_fp16_fp8_type_convert<<<1, 1>>>( + test_size, + static_cast(device_out.GetDeviceBuffer()), + static_cast(device_completed.GetDeviceBuffer())); + + uint64_t completed = 0; + device_completed.FromDevice(&completed); + device_out.FromDevice(out.data()); + + std::set fp8_nan_ids; + fp8_nan_ids.insert(0b11111111); //-NaN + fp8_nan_ids.insert(0b01111111); // +NaN + for(auto fp8_nan_id : fp8_nan_ids) + { + auto idx = fp8_nan_id; + ASSERT_TRUE(std::isnan(type_convert(out[idx]))) + << "idx: " << idx << " out[idx]: " << type_convert(out[idx]); + } + + for(ck::index_t fp8_id = 0; fp8_id < 256; fp8_id++) + { + if(fp8_nan_ids.find(fp8_id) != fp8_nan_ids.end()) + continue; + + uint8_t fp8_uid = static_cast(fp8_id); + auto idx = fp8_uid; + ASSERT_FLOAT_EQ(out[idx], type_convert(f8_ocp_t{fp8_uid})) + << " fp8_id: " << fp8_id << std::endl + << type_convert(type_convert(f8_ocp_t{fp8_uid})); + } + + /// Test vector conversions + + auto i = 256; + + // fp8x2 -> fp16x2 + EXPECT_EQ(out[i++], type_convert(-powf(2.0f, -6.0f))); + EXPECT_EQ(out[i++], type_convert(powf(2.0f, -9.0f))); + + // fp16x2 -> fp8x2 + // RNE + EXPECT_EQ(out[i++], type_convert(-4.0f)); + EXPECT_EQ(out[i++], type_convert(2.0f)); + // SR + EXPECT_EQ(out[i++], type_convert(-4.0f)); + EXPECT_EQ(out[i++], type_convert(2.0f)); + + EXPECT_EQ(test_size, completed); + EXPECT_EQ(test_size, i); +} + +__host__ __device__ void +test_bf16_fp8_type_convert(uint64_t N, bhalf_t* p_test, uint64_t* p_completed) +{ + if(p_completed == nullptr) + { + return; + } + + uint64_t& i = *p_completed; + i = 0; + + if(p_test == nullptr) + { + return; + } + + for(ck::index_t fp8_id = 0; fp8_id < 256; fp8_id++) + { + uint8_t fp8_uid = static_cast(fp8_id); + auto v = type_convert(f8_ocp_t{fp8_uid}); + p_test[i] = v; + i++; + if(i >= N) + { + return; + } + } + + /// Test vector conversion + // fp8x2 -> bf16x2 + f8x2_ocp_t fp8x2{f8x2_ocp_t::data_v{0b10001000, 
0b00000001}}; //-2^-6, 2^-9 + + bhalf2_t bf16x2 = type_convert(fp8x2); + p_test[i++] = bf16x2[0]; + if(i >= N) + { + return; + } + p_test[i++] = bf16x2[1]; + if(i >= N) + { + return; + } + + // bf16x2 -> fp8x2 + bf16x2 = {type_convert(-4.0f), type_convert(2.0f)}; + fp8x2 = f8_convert_rne(bf16x2); // expect {-4, 2} + + p_test[i++] = type_convert(fp8x2.AsType()(ck::Number<0>{})); //-4f + if(i >= N) + { + return; + } + p_test[i++] = type_convert(fp8x2.AsType()(ck::Number<1>{})); // 2f + if(i >= N) + { + return; + } + + fp8x2 = f8_convert_sr(bf16x2); // expect {-4, 2} + + p_test[i++] = type_convert(fp8x2.AsType()(ck::Number<0>{})); //-4f + if(i >= N) + { + return; + } + p_test[i++] = type_convert(fp8x2.AsType()(ck::Number<1>{})); // 2f + if(i >= N) + { + return; + } +} + +TEST(FP8OCP, HostBF16FP8Convert) +{ + std::vector out(test_size, -1.0f); + uint64_t completed = 0; + + test_bf16_fp8_type_convert(test_size, out.data(), &completed); + + std::set fp8_nan_ids; + fp8_nan_ids.insert(0b11111111); //-NaN + fp8_nan_ids.insert(0b01111111); // +NaN + for(auto fp8_nan_id : fp8_nan_ids) + { + auto idx = fp8_nan_id; + ASSERT_TRUE(std::isnan(type_convert(out[idx]))); + } + + for(ck::index_t fp8_id = 0; fp8_id < 256; fp8_id++) + { + if(fp8_nan_ids.find(fp8_id) != fp8_nan_ids.end()) + continue; + + uint8_t fp8_uid = static_cast(fp8_id); + auto idx = fp8_uid; + ASSERT_FLOAT_EQ(out[idx], type_convert(f8_ocp_t{fp8_uid})) + << " fp8_id: " << fp8_id << std::endl + << type_convert(type_convert(f8_ocp_t{fp8_uid})); + } + + // /// Test vector conversions + + auto i = 256; + + // fp8x2 -> bf16x2 + EXPECT_EQ(out[i++], type_convert(-powf(2.0f, -6.0f))); + EXPECT_EQ(out[i++], type_convert(powf(2.0f, -9.0f))); + + // bf16x2 -> fp8x2 + // RNE + EXPECT_EQ(out[i++], type_convert(-4.0f)); + EXPECT_EQ(out[i++], type_convert(2.0f)); + // SR + EXPECT_EQ(out[i++], type_convert(-4.0f)); + EXPECT_EQ(out[i++], type_convert(2.0f)); + + EXPECT_EQ(test_size, completed); + EXPECT_EQ(test_size, i); +} + 
+__global__ void +device_test_bf16_fp8_type_convert(uint64_t N, bhalf_t* p_test, uint64_t* p_completed) +{ + test_bf16_fp8_type_convert(N, p_test, p_completed); +} + +TEST(FP8OCP, DeviceBF16FP8Convert) +{ + std::vector out(test_size, -1.0f); + + DeviceMem device_out(test_size * sizeof(bhalf_t)); + DeviceMem device_completed(sizeof(uint64_t)); + + device_out.SetValue(-21.0f); + device_completed.SetValue(-21.0f); + + device_test_bf16_fp8_type_convert<<<1, 1>>>( + test_size, + static_cast(device_out.GetDeviceBuffer()), + static_cast(device_completed.GetDeviceBuffer())); + + uint64_t completed = 0; + device_completed.FromDevice(&completed); + device_out.FromDevice(out.data()); + + std::set fp8_nan_ids; + fp8_nan_ids.insert(0b11111111); //-NaN + fp8_nan_ids.insert(0b01111111); // +NaN + for(auto fp8_nan_id : fp8_nan_ids) + { + auto idx = fp8_nan_id; + ASSERT_TRUE(std::isnan(type_convert(out[idx]))) + << "idx: " << idx << " out[idx]: " << type_convert(out[idx]); + } + + for(ck::index_t fp8_id = 0; fp8_id < 256; fp8_id++) + { + if(fp8_nan_ids.find(fp8_id) != fp8_nan_ids.end()) + continue; + + uint8_t fp8_uid = static_cast(fp8_id); + auto idx = fp8_uid; + ASSERT_FLOAT_EQ(out[idx], type_convert(f8_ocp_t{fp8_uid})) + << " fp8_id: " << fp8_id << std::endl + << type_convert(type_convert(f8_ocp_t{fp8_uid})); + } + + /// Test vector conversions + + auto i = 256; + + // fp8x2 -> bf16x2 + EXPECT_EQ(out[i++], type_convert(-powf(2.0f, -6.0f))); + EXPECT_EQ(out[i++], type_convert(powf(2.0f, -9.0f))); + + // bf16x2 -> fp8x2 + // RNE + EXPECT_EQ(out[i++], type_convert(-4.0f)); + EXPECT_EQ(out[i++], type_convert(2.0f)); + // SR + EXPECT_EQ(out[i++], type_convert(-4.0f)); + EXPECT_EQ(out[i++], type_convert(2.0f)); + + EXPECT_EQ(test_size, completed); + EXPECT_EQ(test_size, i); +} From 50d1f8ff905eeabc61123864d9a805d215676a53 Mon Sep 17 00:00:00 2001 From: Thomas Ning Date: Thu, 3 Apr 2025 11:48:54 -0700 Subject: [PATCH 018/443] Add the MI355 support for CK TILE GEMM (#2046) * Get the 
root cause of the ck tile gemm failing on mi355 * Fix the ck tile gemm on MI355 * delete the debug info --- example/ck_tile/03_gemm/CMakeLists.txt | 9 ++++++--- example/ck_tile/03_gemm/run_gemm_example.inc | 8 ++++---- test/ck_tile/gemm/CMakeLists.txt | 20 +++++++++++++++++++ .../gemm/test_gemm_pipeline_compv3.cpp | 2 +- .../gemm/test_gemm_pipeline_compv4.cpp | 2 +- 5 files changed, 32 insertions(+), 9 deletions(-) diff --git a/example/ck_tile/03_gemm/CMakeLists.txt b/example/ck_tile/03_gemm/CMakeLists.txt index 30cfee22f6..61c3a57391 100644 --- a/example/ck_tile/03_gemm/CMakeLists.txt +++ b/example/ck_tile/03_gemm/CMakeLists.txt @@ -1,5 +1,8 @@ add_executable(tile_example_gemm_basic EXCLUDE_FROM_ALL gemm_basic.cpp) add_executable(tile_example_gemm_universal EXCLUDE_FROM_ALL universal_gemm.cpp) -target_compile_options(tile_example_gemm_universal PRIVATE - -mllvm -enable-noalias-to-md-conversion=0 -) +set(EXAMPLE_GEMM_COMPILE_OPTIONS) +if(CK_USE_OCP_FP8) + list(APPEND EXAMPLE_GEMM_COMPILE_OPTIONS -DCK_TILE_USE_OCP_FP8) +endif() +list(APPEND EXAMPLE_GEMM_COMPILE_OPTIONS -mllvm -enable-noalias-to-md-conversion=0) +target_compile_options(tile_example_gemm_universal PRIVATE ${EXAMPLE_GEMM_COMPILE_OPTIONS}) diff --git a/example/ck_tile/03_gemm/run_gemm_example.inc b/example/ck_tile/03_gemm/run_gemm_example.inc index 6cb40e45d1..c3b4ec609c 100644 --- a/example/ck_tile/03_gemm/run_gemm_example.inc +++ b/example/ck_tile/03_gemm/run_gemm_example.inc @@ -240,8 +240,8 @@ int run_gemm_example_with_layouts(int argc, if(init_method == 0) { - ck_tile::FillUniformDistribution{-5.f, 5.f}(a_m_k); - ck_tile::FillUniformDistribution{-5.f, 5.f}(b_k_n); + ck_tile::FillUniformDistribution{-1.f, 1.f}(a_m_k); + ck_tile::FillUniformDistribution{-1.f, 1.f}(b_k_n); } else if(init_method == 1) { @@ -250,8 +250,8 @@ int run_gemm_example_with_layouts(int argc, } else if(init_method == 2) { - ck_tile::FillConstant{static_cast(1)}(a_m_k); - ck_tile::FillConstant{static_cast(1)}(b_k_n); + 
ck_tile::FillUniformDistribution{1.f, 1.f}(a_m_k); + ck_tile::FillUniformDistribution{1.f, 1.f}(b_k_n); } else { diff --git a/test/ck_tile/gemm/CMakeLists.txt b/test/ck_tile/gemm/CMakeLists.txt index 7701e451ad..3e7296b1eb 100644 --- a/test/ck_tile/gemm/CMakeLists.txt +++ b/test/ck_tile/gemm/CMakeLists.txt @@ -1,8 +1,28 @@ # Currently ck_tile is only built on gfx94/gfx95 +set(EXAMPLE_GEMM_COMPILE_OPTIONS "") +if(CK_USE_OCP_FP8) + list(APPEND EXAMPLE_GEMM_COMPILE_OPTIONS -DCK_TILE_USE_OCP_FP8) +endif() +set(EXAMPLE_GEMM_COMPILE_COMPUTE_V4_OPTIONS "") +if(CK_USE_OCP_FP8) + list(APPEND EXAMPLE_GEMM_COMPILE_COMPUTE_V4_OPTIONS -DCK_TILE_USE_OCP_FP8) +endif() +list(APPEND EXAMPLE_GEMM_COMPILE_COMPUTE_V4_OPTIONS + -mllvm + -enable-noalias-to-md-conversion=0 +) + +if(CK_USE_OCP_FP8) + list(APPEND EXAMPLE_GEMM_COMPILE_OPTIONS -DCK_TILE_USE_OCP_FP8) if(GPU_TARGETS MATCHES "gfx94" OR GPU_TARGETS MATCHES "gfx95") add_gtest_executable(test_ck_tile_gemm_pipeline_mem test_gemm_pipeline_mem.cpp) add_gtest_executable(test_ck_tile_gemm_pipeline_compv3 test_gemm_pipeline_compv3.cpp) add_gtest_executable(test_ck_tile_gemm_pipeline_compv4 test_gemm_pipeline_compv4.cpp) + + target_compile_options(test_ck_tile_gemm_pipeline_mem PRIVATE ${EXAMPLE_GEMM_COMPILE_OPTIONS}) + target_compile_options(test_ck_tile_gemm_pipeline_compv3 PRIVATE ${EXAMPLE_GEMM_COMPILE_OPTIONS}) + target_compile_options(test_ck_tile_gemm_pipeline_compv4 PRIVATE ${EXAMPLE_GEMM_COMPILE_COMPUTE_V4_OPTIONS}) else() message("Skipping ck_tile_gemm tests for current target") endif() +endif() diff --git a/test/ck_tile/gemm/test_gemm_pipeline_compv3.cpp b/test/ck_tile/gemm/test_gemm_pipeline_compv3.cpp index d81e870ffc..8944e6865d 100644 --- a/test/ck_tile/gemm/test_gemm_pipeline_compv3.cpp +++ b/test/ck_tile/gemm/test_gemm_pipeline_compv3.cpp @@ -9,7 +9,7 @@ class TestCkTileGemmPipelineCompV3 : public TestCkTileGemmPipeline #define TEST_SUITE_NAME TestCkTileGemmPipelineCompV3 -TYPED_TEST_SUITE(TestCkTileGemmPipelineCompV3, 
KernelTypesMem); +TYPED_TEST_SUITE(TestCkTileGemmPipelineCompV3, KernelTypesCompV3); #include "test_gemm_pipeline_ut_cases.inc" diff --git a/test/ck_tile/gemm/test_gemm_pipeline_compv4.cpp b/test/ck_tile/gemm/test_gemm_pipeline_compv4.cpp index 1da0028f63..22e77fac41 100644 --- a/test/ck_tile/gemm/test_gemm_pipeline_compv4.cpp +++ b/test/ck_tile/gemm/test_gemm_pipeline_compv4.cpp @@ -9,7 +9,7 @@ class TestCkTileGemmPipelineCompV4 : public TestCkTileGemmPipeline #define TEST_SUITE_NAME TestCkTileGemmPipelineCompV4 -TYPED_TEST_SUITE(TestCkTileGemmPipelineCompV4, KernelTypesMem); +TYPED_TEST_SUITE(TestCkTileGemmPipelineCompV4, KernelTypesCompV4); #include "test_gemm_pipeline_ut_cases.inc" From fed0709121365e4ce8208a1a0a988905d43a1963 Mon Sep 17 00:00:00 2001 From: Khushbu Agarwal Date: Thu, 3 Apr 2025 11:54:12 -0700 Subject: [PATCH 019/443] [New] Build up the feature of CK Tile GEMM CodeGen (#1994) * New branch for codegen changes * Fix verify function for int4 * pk_int4 codegen * Update to review comments * Remove codegen directory and rename filenames * Remove extra files; clean up CMake file * New branch for codegen changes * Fix verify function for int4 * pk_int4 codegen * Update to review comments * Remove codegen directory and rename filenames * Remove extra files; clean up CMake file * code changes for single instance * config file rename, added few more combinations in json file * Fix cmake file * Addressing review comments * Reverting files changed by merge to develop --------- Co-authored-by: ThomasNing --- CMakeLists.txt | 1 + tile_engine/CMakeLists.txt | 5 + tile_engine/ops/CMakeLists.txt | 1 + tile_engine/ops/gemm/CMakeLists.txt | 45 ++ .../gemm/configs/instance_combination.json | 60 ++ tile_engine/ops/gemm/gemm_host_api.cpp | 169 +++++ tile_engine/ops/gemm/gemm_host_api.hpp | 287 +++++++++ tile_engine/ops/gemm/gemm_instance_builder.py | 596 ++++++++++++++++++ 8 files changed, 1164 insertions(+) create mode 100755 tile_engine/CMakeLists.txt create mode 
100755 tile_engine/ops/CMakeLists.txt create mode 100644 tile_engine/ops/gemm/CMakeLists.txt create mode 100644 tile_engine/ops/gemm/configs/instance_combination.json create mode 100644 tile_engine/ops/gemm/gemm_host_api.cpp create mode 100644 tile_engine/ops/gemm/gemm_host_api.hpp create mode 100755 tile_engine/ops/gemm/gemm_instance_builder.py diff --git a/CMakeLists.txt b/CMakeLists.txt index 4c1ca789f5..ba57ead09a 100644 --- a/CMakeLists.txt +++ b/CMakeLists.txt @@ -610,6 +610,7 @@ if(NOT GPU_ARCHS AND USER_GPU_TARGETS) PACKAGE_NAME examples ) add_subdirectory(example) + add_subdirectory(tile_engine) if(BUILD_TESTING) add_subdirectory(test) endif() diff --git a/tile_engine/CMakeLists.txt b/tile_engine/CMakeLists.txt new file mode 100755 index 0000000000..cd1a192a74 --- /dev/null +++ b/tile_engine/CMakeLists.txt @@ -0,0 +1,5 @@ +include_directories(BEFORE + ${CMAKE_CURRENT_LIST_DIR}/include + ) + +add_subdirectory(ops) diff --git a/tile_engine/ops/CMakeLists.txt b/tile_engine/ops/CMakeLists.txt new file mode 100755 index 0000000000..0cf2c16da2 --- /dev/null +++ b/tile_engine/ops/CMakeLists.txt @@ -0,0 +1 @@ +add_subdirectory(gemm) diff --git a/tile_engine/ops/gemm/CMakeLists.txt b/tile_engine/ops/gemm/CMakeLists.txt new file mode 100644 index 0000000000..d28017ca0c --- /dev/null +++ b/tile_engine/ops/gemm/CMakeLists.txt @@ -0,0 +1,45 @@ + + +# generate a list of kernels, but not actually emit files at config stage +execute_process( + COMMAND ${Python3_EXECUTABLE} ${CMAKE_CURRENT_LIST_DIR}/gemm_instance_builder.py + --working_path ${CMAKE_CURRENT_BINARY_DIR} + --json ${CMAKE_CURRENT_LIST_DIR}/configs/instance_combination.json + --list_blobs + RESULT_VARIABLE ret +) + +if(ret AND NOT ret EQUAL 0) + message( FATAL_ERROR "Fail to generate kernels via Python. 
${ret}") +endif() + +file(STRINGS ${CMAKE_CURRENT_BINARY_DIR}/gemm_instance_blobs.txt GEMM_CODEGEN_BLOBS) + +add_custom_command( + OUTPUT ${GEMM_CODEGEN_BLOBS} + COMMAND ${Python3_EXECUTABLE} ${CMAKE_CURRENT_LIST_DIR}/gemm_instance_builder.py + --working_path ${CMAKE_CURRENT_BINARY_DIR} + --json ${CMAKE_CURRENT_LIST_DIR}/configs/instance_combination.json + --gen_blobs + DEPENDS ${GEMM_CODEGEN_BLOBS} +) + +set(EXECUTABLE_GEMM_INSTANCE "tile_engine_gemm") +message("adding example ${EXECUTABLE_GEMM_INSTANCE}") + +# use build as include directory +include_directories(${CMAKE_CURRENT_BINARY_DIR}) +add_executable(${EXECUTABLE_GEMM_INSTANCE} EXCLUDE_FROM_ALL gemm_host_api.cpp) +target_include_directories(${EXECUTABLE_GEMM_INSTANCE} PRIVATE ${CMAKE_CURRENT_LIST_DIR}) +target_sources(${EXECUTABLE_GEMM_INSTANCE} PRIVATE ${GEMM_CODEGEN_BLOBS}) + +set(EXECUTABLE_GEMM_INSTANCE_COMPILE_OPTIONS) + +list(APPEND EXECUTABLE_GEMM_INSTANCE_COMPILE_OPTIONS + -Wno-undefined-func-template + -Wno-float-equal + --offload-compress) + +target_compile_options(${EXECUTABLE_GEMM_INSTANCE} PRIVATE ${EXECUTABLE_GEMM_INSTANCE_COMPILE_OPTIONS}) + +set_property(GLOBAL PROPERTY RULE_MESSAGES OFF) \ No newline at end of file diff --git a/tile_engine/ops/gemm/configs/instance_combination.json b/tile_engine/ops/gemm/configs/instance_combination.json new file mode 100644 index 0000000000..e21197d1de --- /dev/null +++ b/tile_engine/ops/gemm/configs/instance_combination.json @@ -0,0 +1,60 @@ +{ + + "layout_a": { + "values": ["r"] + }, + "layout_b": { + "values": ["c"] + }, + "layout_c": { + "values": ["r"] + }, + "datatype": { + "values": ["fp16"] + }, + "tile_m": { + "values": [256] + }, + "tile_n": { + "values": [256] + }, + "tile_k": { + "values": [64] + }, + "warp_m": { + "values": [2] + }, + "warp_n": { + "values": [2] + }, + "warp_k": { + "values": [1] + }, + "warp_tile_m": { + "values": [32] + }, + "warp_tile_n": { + "values": [32] + }, + "warp_tile_k": { + "values": [16] + }, + "kPadM": { + 
"values": [false] + }, + "kPadN": { + "values": [false] + }, + "kPadK": { + "values": [false] + }, + "pipeline": { + "values": ["compv3", "mem"] + }, + "scheduler": { + "values": ["intrawave", "interwave"] + }, + "epilogue": { + "values": ["default", "cshuffle"] + } +} diff --git a/tile_engine/ops/gemm/gemm_host_api.cpp b/tile_engine/ops/gemm/gemm_host_api.cpp new file mode 100644 index 0000000000..508f634920 --- /dev/null +++ b/tile_engine/ops/gemm/gemm_host_api.cpp @@ -0,0 +1,169 @@ +// SPDX-License-Identifier: MIT +// Copyright (c) 2025, Advanced Micro Devices, Inc. All rights reserved. + +#include "ck_tile/host.hpp" +#include "gemm_common.hpp" +#include "gemm_dispatcher.hpp" +#include "gemm_host_api.hpp" + +float gemm_kernel_launch(KernelTraits& trait, + ck_tile::GemmHostArgs& args, + const ck_tile::stream_config& s) +{ + return GemmDispatcher::dispatch(trait, args, s); +} + +template +bool run(const ck_tile::ArgParser& arg_parser) +{ + const ALayout a_layout = ALayout{}; + const BLayout b_layout = BLayout{}; + // const CLayout c_layout = CLayout{}; + + ck_tile::index_t kbatch = arg_parser.get_int("split_k"); + ck_tile::index_t M = arg_parser.get_int("m"); + ck_tile::index_t N = arg_parser.get_int("n"); + ck_tile::index_t K = arg_parser.get_int("k"); + + ck_tile::index_t stride_A = arg_parser.get_int("stride_a"); + ck_tile::index_t stride_B = arg_parser.get_int("stride_b"); + ck_tile::index_t stride_C = arg_parser.get_int("stride_c"); + + int n_warmup = arg_parser.get_int("warmup"); + int n_repeat = arg_parser.get_int("repeat"); + int verify = arg_parser.get_int("v"); + ck_tile::index_t init_method = arg_parser.get_int("init"); + + stride_A = ck_tile::get_default_stride(M, K, stride_A, is_row_major(a_layout)); + stride_B = ck_tile::get_default_stride(K, N, stride_B, is_row_major(b_layout)); + stride_C = ck_tile::get_default_stride(M, N, stride_C, is_row_major(CLayout{})); + + ck_tile::HostTensor a_m_k( + ck_tile::host_tensor_descriptor(M, K, stride_A, 
is_row_major(a_layout))); + ck_tile::HostTensor b_k_n( + ck_tile::host_tensor_descriptor(K, N, stride_B, is_row_major(b_layout))); + ck_tile::HostTensor c_m_n_dev_result( + ck_tile::host_tensor_descriptor(M, N, stride_C, is_row_major(CLayout{}))); + + if(init_method == 0) + { + ck_tile::FillUniformDistribution{-1.f, 1.f}(a_m_k); + ck_tile::FillUniformDistribution{-1.f, 1.f}(b_k_n); + } + else if(init_method == 1) + { + ck_tile::FillMonotonicSeq{}(a_m_k); + ck_tile::FillMonotonicSeq{}(b_k_n); + } + else if(init_method == 2) + { + ck_tile::FillConstant{static_cast(1)}(a_m_k); + ck_tile::FillConstant{static_cast(1)}(b_k_n); + } + else + { + a_m_k.SetZero(); + b_k_n.SetZero(); + } + + ck_tile::DeviceMem a_m_k_dev_buf(a_m_k.get_element_space_size_in_bytes()); + ck_tile::DeviceMem b_k_n_dev_buf(b_k_n.get_element_space_size_in_bytes()); + ck_tile::DeviceMem c_m_n_dev_buf(c_m_n_dev_result.get_element_space_size_in_bytes()); + + if constexpr(std::is_same_v) + { + // Permute vector pk_i4x4 data for device implementation + ck_tile::HostTensor b_k_n_dev = b_k_n; + // permute_tensor_b(b_k_n_dev); + permute_vectors_i4x4_b(b_k_n_dev); + b_k_n_dev_buf.ToDevice(b_k_n_dev.data()); + } + else + { + b_k_n_dev_buf.ToDevice(b_k_n.data()); + } + + a_m_k_dev_buf.ToDevice(a_m_k.data()); + c_m_n_dev_buf.SetZero(); + c_m_n_dev_result.SetZero(); + + ck_tile::GemmHostArgs gemm_args; + gemm_args.a_ptr = a_m_k_dev_buf.GetDeviceBuffer(); + gemm_args.b_ptr = b_k_n_dev_buf.GetDeviceBuffer(); + gemm_args.c_ptr = c_m_n_dev_buf.GetDeviceBuffer(); + gemm_args.k_batch = kbatch; + gemm_args.M = M; + gemm_args.N = N; + gemm_args.K = K; + gemm_args.stride_A = stride_A; + gemm_args.stride_B = stride_B; + gemm_args.stride_C = stride_C; + + KernelTraits trait; + trait.pipeline = arg_parser.get_str("pipeline"); + trait.scheduler = arg_parser.get_str("scheduler"); + trait.epilogue = arg_parser.get_str("epilogue"); + trait.kPadM = arg_parser.get_bool("pad_m"); + trait.kPadN = arg_parser.get_bool("pad_n"); + 
trait.kPadK = arg_parser.get_bool("pad_k"); + + float ave_time = gemm_kernel_launch( + trait, gemm_args, ck_tile::stream_config{nullptr, true, 1, n_warmup, n_repeat}); + + std::size_t flop = std::size_t(2) * M * N * K; + std::size_t num_byte = + sizeof(ADataType) * M * K + sizeof(BDataType) * N * K + sizeof(CDataType) * M * N; + float tflops = static_cast(flop) / 1.E9 / ave_time; + float gb_per_sec = num_byte / 1.E6 / ave_time; + + std::cout << "Run Gemm kernel with M =" << M << " N =" << N << " K =" << K + << " StrideA =" << stride_A << " StrideB =" << stride_B << " StrideC =" << stride_C + << " A_Layout =" << ALayout::name << " B_Layout =" << BLayout::name + << " C_Layout =" << CLayout::name << " A Type = " << DataTypeTraits::name + << " B Type = " << DataTypeTraits::name + << " C Type = " << DataTypeTraits::name << " : " << ave_time << " ms, " + << tflops << " TFlops, " << gb_per_sec << " GB/s, " << std::endl; + + c_m_n_dev_buf.FromDevice(c_m_n_dev_result.data()); + bool pass = true; + if(verify) + { + pass = gemm_verify( + verify, + a_m_k, + b_k_n, + c_m_n_dev_result, + a_m_k_dev_buf, + b_k_n_dev_buf, + M, + N, + K, + stride_A, + stride_B, + stride_C, + kbatch); + } + return pass; +} + +int main(int argc, char* argv[]) +{ + try + { + auto [result, parser] = create_args(argc, argv); + if(!result) + return EXIT_FAILURE; + return run(parser); + } + catch(const std::exception& e) + { + std::cerr << "Error: " << e.what() << "\n"; + return EXIT_FAILURE; + } +} diff --git a/tile_engine/ops/gemm/gemm_host_api.hpp b/tile_engine/ops/gemm/gemm_host_api.hpp new file mode 100644 index 0000000000..4f0ea52a18 --- /dev/null +++ b/tile_engine/ops/gemm/gemm_host_api.hpp @@ -0,0 +1,287 @@ +#include + +#include +#include +#include +#include +#include +#include "ck_tile/ops/gemm.hpp" + +#pragma once + +template +struct DataTypeTraits; + +template <> +struct DataTypeTraits +{ + static constexpr const char* name = "fp32"; +}; + +template <> +struct DataTypeTraits +{ + static 
constexpr const char* name = "fp64"; +}; + +template <> +struct DataTypeTraits +{ + static constexpr const char* name = "fp16"; +}; + +template <> +struct DataTypeTraits +{ + static constexpr const char* name = "bf16"; +}; + +template <> +struct DataTypeTraits +{ + static constexpr const char* name = "fp8"; +}; + +template <> +struct DataTypeTraits +{ + static constexpr const char* name = "bf8"; +}; + +template <> +struct DataTypeTraits +{ + static constexpr const char* name = "pk_int4_t"; +}; + +struct KernelTraits +{ + std::string pipeline; + std::string scheduler; + std::string epilogue; + bool kPadM; + bool kPadN; + bool kPadK; +}; + +template +static constexpr inline auto is_row_major(Layout layout_) +{ + return ck_tile::bool_constant, + ck_tile::tensor_layout::gemm::RowMajor>>{}; +} + +template +auto calculate_rtol_atol(const ck_tile::index_t K, + const ck_tile::index_t kbatch, + const float max_accumulated_value) +{ + using ComputeType = + std::conditional_t; + // Calculate thresholds + const auto rtol = ck_tile::get_relative_threshold( + ck_tile::integer_divide_ceil(K, kbatch)); + const auto atol = ck_tile::get_absolute_threshold( + max_accumulated_value / kbatch, ck_tile::integer_divide_ceil(K, kbatch)); + // Calculate error due to split_k accumulation + const auto rtol_split_k = + ck_tile::get_relative_threshold(kbatch); + const auto atol_split_k = ck_tile::get_absolute_threshold( + max_accumulated_value, kbatch); + // Use higher threshold + return ck_tile::make_tuple(std::max(rtol, rtol_split_k), std::max(atol, atol_split_k)); +} + +inline auto create_args(int argc, char* argv[]) +{ + ck_tile::ArgParser arg_parser; + arg_parser.insert("m", "3840", "m dimension") + .insert("n", "4096", "n dimension") + .insert("k", "2048", "k dimension") + .insert("stride_a", "0", "Tensor A stride") + .insert("stride_b", "0", "Tensor B stride") + .insert("stride_c", "0", "Tensor C stride") + .insert("split_k", "1", "splitK value") + .insert("v", "2", "0. No validation, 1. 
Validation on CPU, 2. Validation on GPU") + .insert("warmup", "50", "number of iterations before benchmark the kernel") + .insert("repeat", "100", "number of iterations to benchmark the kernel") + .insert("timer", "gpu", "gpu:gpu timer, cpu:cpu timer") + .insert("init", "0", "0:random, 1:linear, 2:constant(1)") + .insert("pipeline", "compv3", "compv3, compv4, mem") + .insert("scheduler", "intrawave", "intrawave, interwave") + .insert("epilogue", "cshuffle", "cshuffle, default") + .insert("pad_m", "false", "true, false") + .insert("pad_n", "false", "true, false") + .insert("pad_k", "false", "true, false"); + + bool result = arg_parser.parse(argc, argv); + return std::make_tuple(result, arg_parser); +} + +template +void permute_vectors_i4x4_b(Tensor& tensor) +{ + const ck_tile::index_t K = tensor.get_length(0); + const ck_tile::index_t N = tensor.get_length(1); + // vector pk_i4x4 permute + for(int i = 0; i < N; i++) + { + for(int j = 0; j < K; j += 8) + { + int8_t input[8]; + + for(int k = 0; k < 4; k++) + { + int8_t i4x2 = tensor(j + k * 2, i).data; + input[k * 2 + 0] = (i4x2 >> 4) & 0xf; + input[k * 2 + 1] = (i4x2 >> 0) & 0xf; + } + + // permute 01234567->20643175 + { + int8_t hi = input[2]; + int8_t lo = input[0]; + int8_t i4x2 = (hi << 4) | lo; + + tensor(j + 0, i) = i4x2; + } + + { + int8_t hi = input[6]; + int8_t lo = input[4]; + int8_t i4x2 = (hi << 4) | lo; + + tensor(j + 2, i) = i4x2; + } + + { + int8_t hi = input[3]; + int8_t lo = input[1]; + int8_t i4x2 = (hi << 4) | lo; + + tensor(j + 4, i) = i4x2; + } + + { + int8_t hi = input[7]; + int8_t lo = input[5]; + int8_t i4x2 = (hi << 4) | lo; + + tensor(j + 6, i) = i4x2; + } + } + } +} + +// verification code +template +bool gemm_verify(int verify, + ck_tile::HostTensor& a_m_k, + ck_tile::HostTensor& b_k_n, + ck_tile::HostTensor& c_m_n_dev_result, + ck_tile::DeviceMem& a_m_k_dev_buf, + ck_tile::DeviceMem& b_k_n_dev_buf, + ck_tile::index_t M, + ck_tile::index_t N, + ck_tile::index_t K, + ck_tile::index_t 
stride_A, + ck_tile::index_t stride_B, + ck_tile::index_t stride_C, + ck_tile::index_t kbatch) +{ + bool pass = true; + if(verify == 1) + { + ck_tile::HostTensor c_m_n_host_ref( + ck_tile::host_tensor_descriptor(M, N, stride_C, is_row_major(CLayout{}))); + c_m_n_host_ref.SetZero(); + + ck_tile::reference_gemm( + a_m_k, b_k_n, c_m_n_host_ref); + const float max_accumulated_value = + *std::max_element(c_m_n_host_ref.mData.begin(), c_m_n_host_ref.mData.end()); + const auto rtol_atol = calculate_rtol_atol( + K, kbatch, max_accumulated_value); + pass = ck_tile::check_err(c_m_n_dev_result, + c_m_n_host_ref, + "Error: Incorrect results!", + rtol_atol.at(ck_tile::number<0>{}), + rtol_atol.at(ck_tile::number<1>{})); + + std::cout << "Relative error threshold: " << rtol_atol.at(ck_tile::number<0>{}) + << " Absolute error threshold: " << rtol_atol.at(ck_tile::number<1>{}) + << std::endl; + std::cout << "The CPU verification result is:" << (pass ? "correct" : "fail") << std::endl; + } + else if(verify == 2) + { + if constexpr(std::is_same_v) + { + // Restore input for B for gpu reference + b_k_n_dev_buf.ToDevice(b_k_n.data()); + } + ck_tile::HostTensor c_m_n_gpu_ref( + ck_tile::host_tensor_descriptor(M, N, stride_C, is_row_major(CLayout{}))); + ck_tile::DeviceMem c_m_n_gpu_buf_ref(c_m_n_gpu_ref.get_element_space_size_in_bytes()); + c_m_n_gpu_ref.SetZero(); + c_m_n_gpu_buf_ref.SetZero(); + + ADataType* d_A; + BDataType* d_B; + CDataType* d_C; + + ck_tile::hip_check_error(hipMalloc(&d_A, a_m_k.get_element_space_size_in_bytes())); + ck_tile::hip_check_error(hipMalloc(&d_B, b_k_n.get_element_space_size_in_bytes())); + ck_tile::hip_check_error( + hipMalloc(&d_C, c_m_n_dev_result.get_element_space_size_in_bytes())); + + ck_tile::hip_check_error(hipMemcpy(d_A, + a_m_k_dev_buf.GetDeviceBuffer(), + a_m_k.get_element_space_size_in_bytes(), + hipMemcpyHostToDevice)); + ck_tile::hip_check_error(hipMemcpy(d_B, + b_k_n_dev_buf.GetDeviceBuffer(), + b_k_n.get_element_space_size_in_bytes(), + 
hipMemcpyHostToDevice)); + + ck_tile::reference_gemm_gpu(d_A, d_B, d_C, M, N, K, stride_A, stride_B, stride_C); + + ck_tile::hip_check_error(hipMemcpy(c_m_n_gpu_buf_ref.GetDeviceBuffer(), + d_C, + c_m_n_dev_result.get_element_space_size_in_bytes(), + hipMemcpyDeviceToHost)); + + ck_tile::hip_check_error(hipFree(d_A)); + ck_tile::hip_check_error(hipFree(d_B)); + ck_tile::hip_check_error(hipFree(d_C)); + + c_m_n_gpu_buf_ref.FromDevice(c_m_n_gpu_ref.data()); + const float max_accumulated_value = + *std::max_element(c_m_n_gpu_ref.mData.begin(), c_m_n_gpu_ref.mData.end()); + const auto rtol_atol = calculate_rtol_atol( + K, kbatch, max_accumulated_value); + pass = ck_tile::check_err(c_m_n_dev_result, + c_m_n_gpu_ref, + "Error: Incorrect results!", + rtol_atol.at(ck_tile::number<0>{}), + rtol_atol.at(ck_tile::number<1>{})); + + std::cout << "Relative error threshold: " << rtol_atol.at(ck_tile::number<0>{}) + << " Absolute error threshold: " << rtol_atol.at(ck_tile::number<1>{}) + << std::endl; + std::cout << "The GPU verification result is: " << (pass ? "correct" : "fail") << std::endl; + } + return pass; +} diff --git a/tile_engine/ops/gemm/gemm_instance_builder.py b/tile_engine/ops/gemm/gemm_instance_builder.py new file mode 100755 index 0000000000..c0dad03ef0 --- /dev/null +++ b/tile_engine/ops/gemm/gemm_instance_builder.py @@ -0,0 +1,596 @@ +# SPDX-License-Identifier: MIT +# Copyright (c) 2025, Advanced Micro Devices, Inc. All rights reserved. 
+# generate kernel instances to speed up compilation + +import argparse +from enum import IntEnum +from pathlib import Path +import sys +from typing import List, Optional, Dict, Any +import functools +import itertools +import copy +import json +from dataclasses import dataclass + +DATA_TYPE_MAP = {'fp32' : 'float', + 'fp16' : 'ck_tile::half_t', + 'bf16' : 'ck_tile::bf16_t', + 'int8' : 'ck_tile::int8_t', + 'fp8' : 'ck_tile::fp8_t', + 'bf8' : 'ck_tile::bf8_t', + 'int4' : 'ck_tile::pk_int4_t' + } + +LAYOUT_MAP = {'r' : 'ck_tile::tensor_layout::gemm::RowMajor', + 'c' : 'ck_tile::tensor_layout::gemm::ColumnMajor'} + +DEFAULT_EPILOGUE = """ + using GemmEpilogue = ck_tile::DefaultGemm2DEpilogue< + ck_tile::DefaultGemm2DEpilogueProblem>; +""" + +CSHUFFLE_EPILOGUE = """ + using GemmEpilogue = ck_tile::CShuffleEpilogue< + ck_tile::CShuffleEpilogueProblem>; +""" +HOT_LOOP_FALSE = """ + if(tail_num == ck_tile::TailNumber::Full) + { + Run(ck_tile::bool_constant{}, + ck_tile::integral_constant{}); + } + else if(tail_num == ck_tile::TailNumber::Odd) + { + Run(ck_tile::bool_constant{}, + ck_tile::integral_constant{}); + } + else if(tail_num == ck_tile::TailNumber::Even) + { + Run(ck_tile::bool_constant{}, + ck_tile::integral_constant{}); + } + else + { + throw std::runtime_error("Num K loop must be larger than number of prefetech stages."); + } +""" +RUN_MEM = """ + if(tail_num == ck_tile::TailNumber::One) + { + Run(ck_tile::bool_constant{}, + ck_tile::integral_constant{}); + } + else if(tail_num == ck_tile::TailNumber::Full) + { + Run(ck_tile::bool_constant{}, + ck_tile::integral_constant{}); + } + + if constexpr(BaseGemmPipeline::PrefetchStages > 2) + { + if(tail_num == ck_tile::TailNumber::Two) + { + Run(ck_tile::bool_constant{}, + ck_tile::integral_constant{}); + } + + if(tail_num == ck_tile::TailNumber::Three) + { + Run(ck_tile::bool_constant{}, + ck_tile::integral_constant{}); + } + if(tail_num == ck_tile::TailNumber::Four) + { + Run(ck_tile::bool_constant{}, + 
ck_tile::integral_constant{}); + } + if(tail_num == ck_tile::TailNumber::Five) + { + Run(ck_tile::bool_constant{}, + ck_tile::integral_constant{}); + } + if(tail_num == ck_tile::TailNumber::Six) + { + Run(ck_tile::bool_constant{}, + ck_tile::integral_constant{}); + } + if(tail_num == ck_tile::TailNumber::Seven) + { + Run(ck_tile::bool_constant{}, + ck_tile::integral_constant{}); + } + throw std::runtime_error("The tile number is wrong! It should not exceed the prefetch stage numbers"); + } +""" + +RUN_COMPV3 = """ + if(tail_num == ck_tile::TailNumber::Full) + { + Run(ck_tile::bool_constant{}, + ck_tile::integral_constant{}); + } + else if(tail_num == ck_tile::TailNumber::Odd) + { + Run(ck_tile::bool_constant{}, + ck_tile::integral_constant{}); + } + else if(tail_num == ck_tile::TailNumber::Even) + { + Run(ck_tile::bool_constant{}, + ck_tile::integral_constant{}); + } + else + { + throw std::runtime_error("The tail number is wrong. It should be Full, Odd, or Even."); + } +""" + +RUN_COMPV4 = """ + if(tail_num == ck_tile::TailNumber::Three) + { + Run(ck_tile::bool_constant{}, + ck_tile::integral_constant{}); + } + else + { + Run(ck_tile::bool_constant{}, + ck_tile::integral_constant{}); + } +""" + + +PIPELINE_MAP = {'mem' : ['ck_tile::BaseGemmPipelineAgBgCrMem', 'ck_tile::GemmPipelineAgBgCrMem'], + 'compv3' : ['ck_tile::BaseGemmPipelineAgBgCrCompV3', 'ck_tile::GemmPipelineAgBgCrCompV3'], + 'compv4' : ['ck_tile::BaseGemmPipelineAgBgCrCompV4', 'ck_tile::GemmPipelineAgBgCrCompV4']} + +SCHEDULER_MAP = {'interwave' : 'ck_tile::GemmPipelineScheduler::Interwave', + 'intrawave' : 'ck_tile::GemmPipelineScheduler::Intrawave'} + +EPILOGUE_MAP = {'default' :DEFAULT_EPILOGUE, + 'cshuffle' : CSHUFFLE_EPILOGUE} + +HOT_LOOP_TRUE = {'mem' : RUN_MEM, + 'compv3' : RUN_COMPV3, + 'compv4' : RUN_COMPV4} + + +def BOOL_MAP(b_) -> str: + if b_: + return 'true' + else: + return 'false' + +@dataclass +class GemmConfig: + def __init__(self, config_data): + self.matrix_cfg : Dict[str, Any] = {} 
+ self.impl_cfg : Dict[str, Any] = {} + for key, value in config_data.items(): + if key in ["datatype", "layout_a", "layout_b", "layout_c"]: + self.matrix_cfg[key] = value + else: + self.impl_cfg[key] = value + + @property + def datatype(self) -> str: + return self.matrix_cfg["datatype"]["values"][0] + + @property + def layouts(self) -> List[str]: + return [ + self.matrix_cfg["layout_a"]["values"][0], + self.matrix_cfg["layout_b"]["values"][0], + self.matrix_cfg["layout_c"]["values"][0] + ] + + +class GemmCodeGenerator: + def __init__(self, output_dir: str, config: GemmConfig): + self.output_dir = Path(output_dir) + if not self.output_dir.exists(): + self.output_dir.mkdir() + + self.config = config + self.all_kernels = [] + self.unique_configs = [] + # Validate configurations + self._validate_config() + + def _validate_config(self): + """Validate matrix and implementation configurations""" + # Matrix config validation + for param in ["datatype", "layout_a", "layout_b", "layout_c"]: + if len(self.config.matrix_cfg[param]["values"]) != 1: + raise ValueError(f"Matrix config {param} must have exactly one value") + + # Implementation traits validation + required_params = ["tile_m", "tile_n", "tile_k", "warp_m", "warp_n", "warp_k", + "warp_tile_m", "warp_tile_n", "warp_tile_k", "pipeline", + "epilogue", "scheduler", "kPadM", "kPadN", "kPadK"] + for param in required_params: + if not self.config.impl_cfg.get(param, {}).get("values"): + raise ValueError(f"Missing implementation parameter: {param}") + + def list_all(self): + """List all possible kernel configurations""" + w_p = Path(self.output_dir) + list_p = w_p / 'gemm_instance_blobs.txt' + self._list_config_groups() + with list_p.open('w') as list_f: + list_f.write(str(w_p / ("gemm_common.hpp")) + "\n") + list_f.write(str(w_p / ("gemm_instances.hpp")) + "\n") + list_f.write(str(w_p / ("gemm_dispatcher.hpp")) + "\n") + for group in self.all_kernels: + list_f.write(str(w_p / ("gemm_" + group + ".hpp")) + "\n") + + + + def 
_list_config_groups(self): + params = [ + ("pipeline", "pipeline"), + ("epilogue", "epilogue"), + ("scheduler", "scheduler"), + ("kPadM", "kPadM"), + ("kPadN", "kPadN"), + ("kPadK", "kPadK") + ] + + # Generate all unique_combinations + _unique = set(itertools.product(*[self.config.impl_cfg[p]["values"] for (p, _) in params])) + for combo in _unique: + config = {name: value for (_, name), value in zip(params, combo)} + pipeline, epilogue, scheduler, kPadM, kPadN, kPadK = config.values() + # To remove some unsupported combinations + unsupported_combination = [("compv3", "cshuffle", "interwave"), + ("compv3", "default", "interwave"), + ("compv4", "cshuffle", "interwave"), + ("compv4", "default", "interwave")] + if (pipeline, epilogue, scheduler) not in unsupported_combination: + group_name = f"{pipeline}_{epilogue}_{scheduler}_pad_{BOOL_MAP(kPadM)}_{BOOL_MAP(kPadN)}_{BOOL_MAP(kPadK)}" + self.all_kernels.append(group_name) + self.unique_configs.append(config) + + def generate_all(self): + self._generate_common_header() + self._generate_config_groups() + self._generate_dispatcher() + + + def _generate_common_header(self): + """Generate common header with datatypes and layout""" + ctype = self.config.datatype + atype = self.config.datatype + btype = self.config.datatype + if self.config.datatype in ['fp8', 'bf8']: + ctype = 'fp16' + elif self.config.datatype in ['int4']: + atype = 'fp16' + ctype = 'fp16' + + content = f"""// SPDX-License-Identifier: MIT +// Copyright (c) 2025, Advanced Micro Devices, Inc. All rights reserved. 
+ +#pragma once +#include "ck_tile/core.hpp" + +// Data types +using ADataType = {DATA_TYPE_MAP[atype]}; +using BDataType = {DATA_TYPE_MAP[btype]}; +using AccDataType = float; +using CDataType = {DATA_TYPE_MAP[ctype]}; + +// Layout configurations +using ALayout = {LAYOUT_MAP[self.config.layouts[0]]}; +using BLayout = {LAYOUT_MAP[self.config.layouts[1]]}; +using CLayout = {LAYOUT_MAP[self.config.layouts[2]]}; +""" + + + (self.output_dir / "gemm_common.hpp").write_text(content) + + def _generate_config_groups(self): + """Generate implementation configuration groups""" + if not self.unique_configs: # Check if the list is empty + self._list_config_groups() + for config in self.unique_configs: + self._generate_config_group(**config) + self.generate_common_instances_header() + + + def _generate_config_group(self, pipeline: str, epilogue: str, scheduler: str, + kPadM: bool, kPadN: bool, kPadK: bool): + """Generate a configuration group with all tile/warp combinations""" + group_name = f"{pipeline}_{epilogue}_{scheduler}_pad_{BOOL_MAP(kPadM)}_{BOOL_MAP(kPadN)}_{BOOL_MAP(kPadK)}" + filename = f"gemm_{group_name}.hpp" + + content = f"""// SPDX-License-Identifier: MIT +// Copyright (c) 2025, Advanced Micro Devices, Inc. All rights reserved. 
+ +#include "gemm_common.hpp" +#include "ck_tile/ops/gemm.hpp" +#include "ck_tile/ops/epilogue.hpp" +#include "ck_tile/host.hpp" + +namespace {group_name} {{ +""" + # Add template struct with configuration + content += self._generate_kernel_struct(pipeline, epilogue, scheduler, kPadM, kPadN, kPadK) + + content += f"\n}} // namespace {group_name}\n" + (self.output_dir / filename).write_text(content) + + def _generate_kernel_struct(self, pipeline: str, epilogue: str, scheduler: str, + kPadM: bool, kPadN: bool, kPadK: bool) -> str: + """Generate kernel struct template""" + return f""" +template +struct GemmKernel {{ + static constexpr bool kPadM = {BOOL_MAP(kPadM)}; + static constexpr bool kPadN = {BOOL_MAP(kPadN)}; + static constexpr bool kPadK = {BOOL_MAP(kPadK)}; + + static float launch(ck_tile::GemmHostArgs& args, const ck_tile::stream_config& s) {{ + static constexpr bool permuteA = false; + static constexpr bool permuteB = false; + static constexpr bool DoubleSmemBuffer = false; + static constexpr bool TransposeC = false; + + static constexpr int kBlockPerCu = 1; + static constexpr ck_tile::index_t TileParitionerGroupNum = 8; + static constexpr ck_tile::index_t TileParitionerM01 = 4; + + using GemmShape = + ck_tile::TileGemmShape, + ck_tile::sequence, + ck_tile::sequence, + permuteA, + permuteB>; + + + using TilePartitioner = + ck_tile::GemmSpatiallyLocalTilePartitioner; + + using Traits = + ck_tile::TileGemmTraits; + + using GemmUniversalTraits = + ck_tile::TileGemmUniversalTraits; + + using GemmPipelineProblem = + ck_tile::GemmPipelineProblem; + + using BaseGemmPipeline = {PIPELINE_MAP[pipeline][0]}; + + const ck_tile::index_t k_grain = args.k_batch * TileK; + const ck_tile::index_t K_split = (args.K + k_grain - 1) / k_grain * TileK; + const ck_tile::index_t num_loop = TilePartitioner::GetLoopNum(K_split); + const bool has_hot_loop = BaseGemmPipeline::BlockHasHotloop(num_loop); + const ck_tile::TailNumber tail_num = 
BaseGemmPipeline::GetBlockLoopTailNum(num_loop); + + float ave_time{{0}}; + + const auto Run = [&](const auto has_hot_loop_, const auto tail_number_) {{ + constexpr bool has_hot_loop_v = has_hot_loop_.value; + constexpr auto tail_number_v = tail_number_.value; + constexpr auto scheduler = {SCHEDULER_MAP[scheduler]}; + + using UniversalGemmProblem = + ck_tile::UniversalGemmPipelineProblem; + + using GemmPipeline = {PIPELINE_MAP[pipeline][1]}; + {EPILOGUE_MAP[epilogue]} + using Kernel = ck_tile::GemmKernel; + auto kargs = Kernel::MakeKernelArgs(args); + + const dim3 grids = Kernel::GridSize(args.M, args.N, args.k_batch); + constexpr dim3 blocks = Kernel::BlockSize(); + + if(!Kernel::IsSupportedArgument(kargs)) + {{ + throw std::runtime_error("Wrong! Arguments not supported! Skipping gemm!"); + }} + + if(s.log_level_ > 0) + {{ + std::cout << "Launching kernel with args:" + << " grid: {{" << grids.x << ", " << grids.y << ", " << grids.z << "}}" + << ", blocks: {{" << blocks.x << ", " << blocks.y << ", " << blocks.z << "}}" + << std::endl; + }} + + ave_time = ck_tile::launch_kernel(s, + ck_tile::make_kernel( + Kernel{{}}, grids, blocks, 0, kargs)); + return ave_time; + + }}; + + if(has_hot_loop) {{ + {HOT_LOOP_TRUE[pipeline]} + }} else {{ + {HOT_LOOP_FALSE} + }} + + return ave_time; + }} +}}; +""" + + def generate_common_instances_header(self): + """Generate common instances header""" + content = """// SPDX-License-Identifier: MIT +// Copyright (c) 2025, Advanced Micro Devices, Inc. All rights reserved. +#pragma once +""" + for group in self.all_kernels: + content += f"#include \"gemm_{group}.hpp\"\n" + (self.output_dir / "gemm_instances.hpp").write_text(content) + + def _generate_dispatcher(self): + """Generate dispatch mechanism""" + content = """// SPDX-License-Identifier: MIT +// Copyright (c) 2025, Advanced Micro Devices, Inc. All rights reserved. 
+ +#include "gemm_common.hpp" +#include "gemm_instances.hpp" +#include "gemm_host_api.hpp" +#include +#include +#include + +struct GemmDispatcher { + static auto& get_kernel_map() { + // Use a static local variable + static std::unordered_map> kernel_map; + return kernel_map; + } + + static void init() { + auto& kernel_map = get_kernel_map(); + if(!kernel_map.empty()) return; + \n""" + # Add tile/warp instantiations + tile_params = set(itertools.product( + self.config.impl_cfg["tile_m"]["values"], + self.config.impl_cfg["tile_n"]["values"], + self.config.impl_cfg["tile_k"]["values"], + self.config.impl_cfg["warp_m"]["values"], + self.config.impl_cfg["warp_n"]["values"], + self.config.impl_cfg["warp_k"]["values"], + self.config.impl_cfg["warp_tile_m"]["values"], + self.config.impl_cfg["warp_tile_n"]["values"], + self.config.impl_cfg["warp_tile_k"]["values"] + )) + + + for group in self.all_kernels: + content += f""" kernel_map["{group}"] = [](ck_tile::GemmHostArgs& args, + const ck_tile::stream_config& s) {{ + std::vector results;""" + for tile in tile_params: + # Check if we have valid tile/warp combinations + # (tile_m/(warp_m*warp_tile_m)) * warp_m * warp_tile_m == tile_m + if ((tile[0]/(tile[3] * tile[7]) * tile[3] * tile[7]) != tile[0]) or \ + ((tile[1]/(tile[4] * tile[8]) * tile[4] * tile[8]) != tile[1]): + continue + content += f""" + //we can have multiple tiles config for the one kernel_trait + return {group}::GemmKernel<{tile[0]}, {tile[1]}, {tile[2]}, {tile[3]}, {tile[4]}, {tile[5]}, {tile[6]}, {tile[7]}, {tile[8]}>::launch(args, s);""" + content += """ + };\n""" + + content += """ } + + + static float dispatch(const KernelTraits &trait, ck_tile::GemmHostArgs& gemm_args, + const ck_tile::stream_config& s) { + init(); + const std::string key = assemble_key(trait); + auto& kernel_map = get_kernel_map(); + if(auto it = kernel_map.find(key); it != kernel_map.end()) { + return it->second(gemm_args, s); //Running single instance + } + throw 
std::runtime_error("No suitable kernel found: " + key); + } + +private: + static std::string assemble_key(const KernelTraits &trait) { + return std::string(trait.pipeline) + "_" + + trait.epilogue + "_" + + trait.scheduler + "_" + + "pad_" + + (trait.kPadM ? "true" : "false") + "_" + + (trait.kPadN ? "true" : "false") + "_" + + (trait.kPadK ? "true" : "false"); + } +}; + +""" + (self.output_dir / "gemm_dispatcher.hpp").write_text(content) + + +def do_list_blobs(args, gemm_config): + generator = GemmCodeGenerator(args.working_path, gemm_config) + generator.list_all() + +def do_gen_blobs(args, gemm_config): + generator = GemmCodeGenerator(args.working_path, gemm_config) + generator.generate_all() + + + +def main(args): + # Read and validate json file + with open(args.json, 'r') as json_file: + config_data = json.load(json_file) + + # Validate and parse configuration + gemm_config = GemmConfig(config_data) + + if args.list_blobs: + do_list_blobs(args, gemm_config) + elif args.gen_blobs: + do_gen_blobs(args, gemm_config) + else: + # If neither was specified, either do nothing or default to gen_blobs + print("No mode specified (use --list_blobs or --gen_blobs). 
Generating by default...") + do_gen_blobs(args, gemm_config) + + + +if __name__ == "__main__": + parser = argparse.ArgumentParser( + prog="generate", + description="gen API for CK gemm kernel", + ) + parser.add_argument( + "-w", "--working_path", default="./", required=False, help="the path where all the blobs are going to be generated" + ) + parser.add_argument( + "-j", "--json", required=True, help="Path to the json which contains the kernel configurations" + ) + parser.add_argument( + "-l", "--list_blobs", action = 'store_true', help="List all kernel to file" + ) + parser.add_argument( + "-g", "--gen_blobs", action = 'store_true', help="Generate all kernels into different files" + ) + + args = parser.parse_args() + + main(args) From 7142d8003c6a99f952a62bbd0b90d5f0261fc807 Mon Sep 17 00:00:00 2001 From: Muhammed Emin Ozturk Date: Thu, 3 Apr 2025 14:22:43 -0700 Subject: [PATCH 020/443] CkProfiler StreamK GemmUniversal Fix and Split Gemm_universal Test (#2044) * fix and split gemm_universal test * clang * Update test_gemm_universal_ut_cases_bf16.inc * Update test_gemm_universal_xdl_bf16.cpp * Update test_gemm_universal_ut_cases_fp16.inc --- .../profile_gemm_universal_streamk_impl.hpp | 2 +- test/gemm_universal/CMakeLists.txt | 15 ++- ... 
=> test_gemm_universal_ut_cases_bf16.inc} | 60 +++------- .../test_gemm_universal_ut_cases_fp16.inc | 99 +++++++++++++++ .../test_gemm_universal_ut_cases_fp8.inc | 113 ++++++++++++++++++ ...l.cpp => test_gemm_universal_xdl_bf16.cpp} | 34 ++---- .../test_gemm_universal_xdl_fp16.cpp | 82 +++++++++++++ .../test_gemm_universal_xdl_fp8.cpp | 71 +++++++++++ .../test_gemm_universal_streamk_util.hpp | 12 +- 9 files changed, 409 insertions(+), 79 deletions(-) mode change 100644 => 100755 profiler/include/profiler/profile_gemm_universal_streamk_impl.hpp mode change 100644 => 100755 test/gemm_universal/CMakeLists.txt rename test/gemm_universal/{test_gemm_universal_ut_cases.inc => test_gemm_universal_ut_cases_bf16.inc} (75%) create mode 100644 test/gemm_universal/test_gemm_universal_ut_cases_fp16.inc create mode 100644 test/gemm_universal/test_gemm_universal_ut_cases_fp8.inc rename test/gemm_universal/{test_gemm_universal_xdl.cpp => test_gemm_universal_xdl_bf16.cpp} (61%) create mode 100644 test/gemm_universal/test_gemm_universal_xdl_fp16.cpp create mode 100644 test/gemm_universal/test_gemm_universal_xdl_fp8.cpp diff --git a/profiler/include/profiler/profile_gemm_universal_streamk_impl.hpp b/profiler/include/profiler/profile_gemm_universal_streamk_impl.hpp old mode 100644 new mode 100755 index d145ab1766..e625fae808 --- a/profiler/include/profiler/profile_gemm_universal_streamk_impl.hpp +++ b/profiler/include/profiler/profile_gemm_universal_streamk_impl.hpp @@ -166,7 +166,7 @@ bool profile_gemm_universal_streamk_impl(int do_verification, 0, 1, 2, 3, 4}; // 0: Data Parallel (DP) mode (Stream-K OFF), 1: 1-tile Stream-K+ DP, // 2:2-tile Stream-K + DP - if(Grid_size != -1) + if(Grid_size == -1) { grid_size_list = {Grid_size}; } diff --git a/test/gemm_universal/CMakeLists.txt b/test/gemm_universal/CMakeLists.txt old mode 100644 new mode 100755 index 4aab6323cc..cf5c68e220 --- a/test/gemm_universal/CMakeLists.txt +++ b/test/gemm_universal/CMakeLists.txt @@ -1,4 +1,15 @@ 
-add_gtest_executable(test_gemm_universal test_gemm_universal_xdl.cpp) +add_gtest_executable(test_gemm_universal_fp16 test_gemm_universal_xdl_fp16.cpp) if(result EQUAL 0) - target_link_libraries(test_gemm_universal PRIVATE utility device_gemm_universal_instance) + target_link_libraries(test_gemm_universal_fp16 PRIVATE utility device_gemm_universal_instance) endif() + +add_gtest_executable(test_gemm_universal_fp8 test_gemm_universal_xdl_fp8.cpp) +if(result EQUAL 0) + target_link_libraries(test_gemm_universal_fp8 PRIVATE utility device_gemm_universal_instance) +endif() + +add_gtest_executable(test_gemm_universal_bf16 test_gemm_universal_xdl_bf16.cpp) +if(result EQUAL 0) + target_link_libraries(test_gemm_universal_bf16 PRIVATE utility device_gemm_universal_instance) +endif() + diff --git a/test/gemm_universal/test_gemm_universal_ut_cases.inc b/test/gemm_universal/test_gemm_universal_ut_cases_bf16.inc similarity index 75% rename from test/gemm_universal/test_gemm_universal_ut_cases.inc rename to test/gemm_universal/test_gemm_universal_ut_cases_bf16.inc index 9a21666856..8a6c672a9f 100644 --- a/test/gemm_universal/test_gemm_universal_ut_cases.inc +++ b/test/gemm_universal/test_gemm_universal_ut_cases_bf16.inc @@ -1,6 +1,6 @@ #pragma once -TYPED_TEST(TestGemmUniversal_MK_KN, SmallM) +TYPED_TEST(TestGemmUniversal_BF16_MK_KN, SmallM) { std::vector Ms{1, 2, 3, 4, 5, 6}; constexpr int N = 512; @@ -14,7 +14,7 @@ TYPED_TEST(TestGemmUniversal_MK_KN, SmallM) this->Run(M, N, K, StrideA, StrideB, StrideC); } -TYPED_TEST(TestGemmUniversal_MK_NK, SmallM) +TYPED_TEST(TestGemmUniversal_BF16_MK_NK, SmallM) { std::vector Ms{1, 2, 3, 4, 5, 6}; constexpr int N = 512; @@ -28,7 +28,7 @@ TYPED_TEST(TestGemmUniversal_MK_NK, SmallM) this->Run(M, N, K, StrideA, StrideB, StrideC); } -TYPED_TEST(TestGemmUniversal_KM_KN, SmallM) +TYPED_TEST(TestGemmUniversal_BF16_KM_KN, SmallM) { std::vector Ms{1, 2, 3, 4, 5, 6}; constexpr int N = 512; @@ -44,7 +44,7 @@ TYPED_TEST(TestGemmUniversal_KM_KN, SmallM) 
} } -TYPED_TEST(TestGemmUniversal_KM_NK, SmallM) +TYPED_TEST(TestGemmUniversal_BF16_KM_NK, SmallM) { std::vector Ms{1, 2, 3, 4, 5, 6}; constexpr int N = 512; @@ -60,7 +60,7 @@ TYPED_TEST(TestGemmUniversal_KM_NK, SmallM) } } -TYPED_TEST(TestGemmUniversal_MK_KN, MidLargeM) +TYPED_TEST(TestGemmUniversal_BF16_MK_KN, MidLargeM) { std::vector Ms{127, 255, 312, 799, 1573}; constexpr int N = 512; @@ -74,7 +74,7 @@ TYPED_TEST(TestGemmUniversal_MK_KN, MidLargeM) this->Run(M, N, K, StrideA, StrideB, StrideC); } -TYPED_TEST(TestGemmUniversal_MK_NK, MidLargeM) +TYPED_TEST(TestGemmUniversal_BF16_MK_NK, MidLargeM) { std::vector Ms{127, 255, 312, 799, 1573}; constexpr int N = 512; @@ -88,7 +88,7 @@ TYPED_TEST(TestGemmUniversal_MK_NK, MidLargeM) this->Run(M, N, K, StrideA, StrideB, StrideC); } -TYPED_TEST(TestGemmUniversal_KM_KN, MidLargeM) +TYPED_TEST(TestGemmUniversal_BF16_KM_KN, MidLargeM) { std::vector Ms{127, 255, 312, 799, 1573}; constexpr int N = 512; @@ -104,7 +104,7 @@ TYPED_TEST(TestGemmUniversal_KM_KN, MidLargeM) } } -TYPED_TEST(TestGemmUniversal_KM_NK, MidLargeM) +TYPED_TEST(TestGemmUniversal_BF16_KM_NK, MidLargeM) { std::vector Ms{127, 255, 312, 799, 1573}; constexpr int N = 512; @@ -120,7 +120,7 @@ TYPED_TEST(TestGemmUniversal_KM_NK, MidLargeM) } } -TYPED_TEST(TestGemmUniversal_MK_KN, PaddK) +TYPED_TEST(TestGemmUniversal_BF16_MK_KN, PaddK) { std::vector Ms{127}; constexpr int N = 512; @@ -134,7 +134,7 @@ TYPED_TEST(TestGemmUniversal_MK_KN, PaddK) this->Run(M, N, K, StrideA, StrideB, StrideC); } -TYPED_TEST(TestGemmUniversal_MK_NK, PaddK) +TYPED_TEST(TestGemmUniversal_BF16_MK_NK, PaddK) { std::vector Ms{127}; constexpr int N = 512; @@ -148,7 +148,7 @@ TYPED_TEST(TestGemmUniversal_MK_NK, PaddK) this->Run(M, N, K, StrideA, StrideB, StrideC); } -TYPED_TEST(TestGemmUniversal_KM_KN, PaddK) +TYPED_TEST(TestGemmUniversal_BF16_KM_KN, PaddK) { std::vector Ms{127}; constexpr int N = 512; @@ -164,7 +164,7 @@ TYPED_TEST(TestGemmUniversal_KM_KN, PaddK) } } 
-TYPED_TEST(TestGemmUniversal_KM_NK, PaddK) +TYPED_TEST(TestGemmUniversal_BF16_KM_NK, PaddK) { std::vector Ms{127}; constexpr int N = 512; @@ -180,7 +180,7 @@ TYPED_TEST(TestGemmUniversal_KM_NK, PaddK) } } -TYPED_TEST(TestGemmUniversal_MK_KN, Regular) +TYPED_TEST(TestGemmUniversal_BF16_MK_KN, Regular) { std::vector Ms{512}; constexpr int N = 512; @@ -194,7 +194,7 @@ TYPED_TEST(TestGemmUniversal_MK_KN, Regular) this->Run(M, N, K, StrideA, StrideB, StrideC); } -TYPED_TEST(TestGemmUniversal_MK_NK, Regular) +TYPED_TEST(TestGemmUniversal_BF16_MK_NK, Regular) { std::vector Ms{512}; constexpr int N = 512; @@ -207,35 +207,3 @@ TYPED_TEST(TestGemmUniversal_MK_NK, Regular) for(int M : Ms) this->Run(M, N, K, StrideA, StrideB, StrideC); } - -TYPED_TEST(TestGemmUniversal_KM_KN, Regular) -{ - std::vector Ms{512}; - constexpr int N = 512; - constexpr int K = 512; - - constexpr int StrideB = N; - constexpr int StrideC = N; - - for(int M : Ms) - { - int StrideA = M; - this->Run(M, N, K, StrideA, StrideB, StrideC); - } -} - -TYPED_TEST(TestGemmUniversal_KM_NK, Regular) -{ - std::vector Ms{512}; - constexpr int N = 512; - constexpr int K = 512; - - constexpr int StrideB = N; - constexpr int StrideC = N; - - for(int M : Ms) - { - int StrideA = M; - this->Run(M, N, K, StrideA, StrideB, StrideC); - } -} diff --git a/test/gemm_universal/test_gemm_universal_ut_cases_fp16.inc b/test/gemm_universal/test_gemm_universal_ut_cases_fp16.inc new file mode 100644 index 0000000000..b61ea0e6b4 --- /dev/null +++ b/test/gemm_universal/test_gemm_universal_ut_cases_fp16.inc @@ -0,0 +1,99 @@ +#pragma once + +TYPED_TEST(TestGemmUniversal_FP16_MK_KN, SmallM) +{ + std::vector Ms{1, 2, 3, 4, 5, 6}; + constexpr int N = 512; + constexpr int K = 320; + + constexpr int StrideA = K; + constexpr int StrideB = N; + constexpr int StrideC = N; + + for(int M : Ms) + this->Run(M, N, K, StrideA, StrideB, StrideC); +} + +TYPED_TEST(TestGemmUniversal_FP16_MK_NK, SmallM) +{ + std::vector Ms{1, 2, 3, 4, 5, 6}; + constexpr 
int N = 512; + constexpr int K = 320; + + constexpr int StrideA = K; + constexpr int StrideB = K; + constexpr int StrideC = N; + + for(int M : Ms) + this->Run(M, N, K, StrideA, StrideB, StrideC); +} + +TYPED_TEST(TestGemmUniversal_FP16_MK_NK, MidLargeM) +{ + std::vector Ms{127, 255, 312, 799, 1573}; + constexpr int N = 512; + constexpr int K = 320; + + constexpr int StrideA = K; + constexpr int StrideB = K; + constexpr int StrideC = N; + + for(int M : Ms) + this->Run(M, N, K, StrideA, StrideB, StrideC); +} + +TYPED_TEST(TestGemmUniversal_FP16_MK_KN, PaddK) +{ + std::vector Ms{127}; + constexpr int N = 512; + constexpr int K = 437; + + constexpr int StrideA = K; + constexpr int StrideB = N; + constexpr int StrideC = N; + + for(int M : Ms) + this->Run(M, N, K, StrideA, StrideB, StrideC); +} + +TYPED_TEST(TestGemmUniversal_FP16_MK_NK, PaddK) +{ + std::vector Ms{127}; + constexpr int N = 512; + constexpr int K = 437; + + constexpr int StrideA = K; + constexpr int StrideB = K; + constexpr int StrideC = N; + + for(int M : Ms) + this->Run(M, N, K, StrideA, StrideB, StrideC); +} + +TYPED_TEST(TestGemmUniversal_FP16_MK_KN, Regular) +{ + std::vector Ms{512}; + constexpr int N = 512; + constexpr int K = 512; + + constexpr int StrideA = K; + constexpr int StrideB = N; + constexpr int StrideC = N; + + for(int M : Ms) + this->Run(M, N, K, StrideA, StrideB, StrideC); +} + +TYPED_TEST(TestGemmUniversal_FP16_MK_NK, Regular) +{ + std::vector Ms{512}; + constexpr int N = 512; + constexpr int K = 512; + + constexpr int StrideA = K; + constexpr int StrideB = K; + constexpr int StrideC = N; + + for(int M : Ms) + this->Run(M, N, K, StrideA, StrideB, StrideC); +} diff --git a/test/gemm_universal/test_gemm_universal_ut_cases_fp8.inc b/test/gemm_universal/test_gemm_universal_ut_cases_fp8.inc new file mode 100644 index 0000000000..b831e15e9c --- /dev/null +++ b/test/gemm_universal/test_gemm_universal_ut_cases_fp8.inc @@ -0,0 +1,113 @@ +#pragma once + +TYPED_TEST(TestGemmUniversal_FP8_MK_KN, 
SmallM) +{ + std::vector Ms{1, 2, 3, 4, 5, 6}; + constexpr int N = 512; + constexpr int K = 320; + + constexpr int StrideA = K; + constexpr int StrideB = N; + constexpr int StrideC = N; + + for(int M : Ms) + this->Run(M, N, K, StrideA, StrideB, StrideC); +} + +TYPED_TEST(TestGemmUniversal_FP8_MK_NK, SmallM) +{ + std::vector Ms{1, 2, 3, 4, 5, 6}; + constexpr int N = 512; + constexpr int K = 320; + + constexpr int StrideA = K; + constexpr int StrideB = K; + constexpr int StrideC = N; + + for(int M : Ms) + this->Run(M, N, K, StrideA, StrideB, StrideC); +} + +TYPED_TEST(TestGemmUniversal_FP8_MK_KN, MidLargeM) +{ + std::vector Ms{127, 255, 312, 799, 1573}; + constexpr int N = 512; + constexpr int K = 320; + + constexpr int StrideA = K; + constexpr int StrideB = N; + constexpr int StrideC = N; + + for(int M : Ms) + this->Run(M, N, K, StrideA, StrideB, StrideC); +} + +TYPED_TEST(TestGemmUniversal_FP8_MK_NK, MidLargeM) +{ + std::vector Ms{127, 255, 312, 799, 1573}; + constexpr int N = 512; + constexpr int K = 320; + + constexpr int StrideA = K; + constexpr int StrideB = K; + constexpr int StrideC = N; + + for(int M : Ms) + this->Run(M, N, K, StrideA, StrideB, StrideC); +} + +TYPED_TEST(TestGemmUniversal_FP8_MK_KN, PaddK) +{ + std::vector Ms{127}; + constexpr int N = 512; + constexpr int K = 437; + + constexpr int StrideA = K; + constexpr int StrideB = N; + constexpr int StrideC = N; + + for(int M : Ms) + this->Run(M, N, K, StrideA, StrideB, StrideC); +} + +TYPED_TEST(TestGemmUniversal_FP8_MK_NK, PaddK) +{ + std::vector Ms{127}; + constexpr int N = 512; + constexpr int K = 437; + + constexpr int StrideA = K; + constexpr int StrideB = K; + constexpr int StrideC = N; + + for(int M : Ms) + this->Run(M, N, K, StrideA, StrideB, StrideC); +} + +TYPED_TEST(TestGemmUniversal_FP8_MK_KN, Regular) +{ + std::vector Ms{512}; + constexpr int N = 512; + constexpr int K = 512; + + constexpr int StrideA = K; + constexpr int StrideB = N; + constexpr int StrideC = N; + + for(int M : Ms) + 
this->Run(M, N, K, StrideA, StrideB, StrideC); +} + +TYPED_TEST(TestGemmUniversal_FP8_MK_NK, Regular) +{ + std::vector Ms{512}; + constexpr int N = 512; + constexpr int K = 512; + + constexpr int StrideA = K; + constexpr int StrideB = K; + constexpr int StrideC = N; + + for(int M : Ms) + this->Run(M, N, K, StrideA, StrideB, StrideC); +} diff --git a/test/gemm_universal/test_gemm_universal_xdl.cpp b/test/gemm_universal/test_gemm_universal_xdl_bf16.cpp similarity index 61% rename from test/gemm_universal/test_gemm_universal_xdl.cpp rename to test/gemm_universal/test_gemm_universal_xdl_bf16.cpp index b872d7089a..8fde65657a 100644 --- a/test/gemm_universal/test_gemm_universal_xdl.cpp +++ b/test/gemm_universal/test_gemm_universal_xdl_bf16.cpp @@ -7,8 +7,6 @@ #include "ck/tensor_operation/gpu/device/tensor_layout.hpp" #include "test_gemm_universal_util.hpp" -using F8 = ck::f8_t; -using F16 = ck::half_t; using BF16 = ck::bhalf_t; using F32 = float; @@ -29,25 +27,25 @@ struct tuple_concat, std::tuple> } // namespace template -class TestGemmUniversal_MK_KN +class TestGemmUniversal_BF16_MK_KN : public ck::test::TestGemmUniversal, Tuple>::type> { }; template -class TestGemmUniversal_MK_NK +class TestGemmUniversal_BF16_MK_NK : public ck::test::TestGemmUniversal, Tuple>::type> { }; template -class TestGemmUniversal_KM_KN +class TestGemmUniversal_BF16_KM_KN : public ck::test::TestGemmUniversal, Tuple>::type> { }; template -class TestGemmUniversal_KM_NK +class TestGemmUniversal_BF16_KM_NK : public ck::test::TestGemmUniversal, Tuple>::type> { }; @@ -55,22 +53,12 @@ class TestGemmUniversal_KM_NK // clang-format off using KernelTypes_MK_KN = ::testing::Types< // ADataType, BDataType, ComputeDataType, CDataType - std::tuple< F16, F16, F16, F16>, -#if defined(CK_ENABLE_FP8) && (defined(CK_USE_FP8_ON_UNSUPPORTED_ARCH) || defined(CK_USE_GFX94)) - std::tuple< F16, F8, F16, F16>, - std::tuple< F8, F16, F16, F16>, - std::tuple< F8, F8, F8, BF16>, -#endif + std::tuple< BF16, BF16, BF16, 
BF16> >; using KernelTypes_MK_NK = ::testing::Types< // ADataType, BDataType, ComputeDataType, CDataType - std::tuple< F16, F16, F16, F16>, -#if defined(CK_ENABLE_FP8) && (defined(CK_USE_FP8_ON_UNSUPPORTED_ARCH) || defined(CK_USE_GFX94)) - std::tuple< F16, F8, F16, F16>, - std::tuple< F8, F16, F16, F16>, - std::tuple< F8, F8, F8, BF16>, -#endif + std::tuple< BF16, BF16, BF16, BF16> >; @@ -86,9 +74,9 @@ using KernelTypes_KM_KN = ::testing::Types< // clang-format on -TYPED_TEST_SUITE(TestGemmUniversal_MK_KN, KernelTypes_MK_KN); -TYPED_TEST_SUITE(TestGemmUniversal_MK_NK, KernelTypes_MK_NK); -TYPED_TEST_SUITE(TestGemmUniversal_KM_KN, KernelTypes_KM_KN); -TYPED_TEST_SUITE(TestGemmUniversal_KM_NK, KernelTypes_KM_NK); +TYPED_TEST_SUITE(TestGemmUniversal_BF16_MK_KN, KernelTypes_MK_KN); +TYPED_TEST_SUITE(TestGemmUniversal_BF16_MK_NK, KernelTypes_MK_NK); +TYPED_TEST_SUITE(TestGemmUniversal_BF16_KM_KN, KernelTypes_KM_KN); +TYPED_TEST_SUITE(TestGemmUniversal_BF16_KM_NK, KernelTypes_KM_NK); -#include "test_gemm_universal_ut_cases.inc" +#include "test_gemm_universal_ut_cases_bf16.inc" diff --git a/test/gemm_universal/test_gemm_universal_xdl_fp16.cpp b/test/gemm_universal/test_gemm_universal_xdl_fp16.cpp new file mode 100644 index 0000000000..24f587daf6 --- /dev/null +++ b/test/gemm_universal/test_gemm_universal_xdl_fp16.cpp @@ -0,0 +1,82 @@ +// SPDX-License-Identifier: MIT +// Copyright (c) 2023-2024, Advanced Micro Devices, Inc. All rights reserved. 
+ +#include + +#include "gtest/gtest.h" +#include "ck/tensor_operation/gpu/device/tensor_layout.hpp" +#include "test_gemm_universal_util.hpp" + +using F8 = ck::f8_t; +using F16 = ck::half_t; + +using F32 = float; + +using Row = ck::tensor_layout::gemm::RowMajor; +using Col = ck::tensor_layout::gemm::ColumnMajor; + +namespace { + +template +struct tuple_concat; + +template +struct tuple_concat, std::tuple> +{ + using type = std::tuple; +}; + +} // namespace + +template +class TestGemmUniversal_FP16_MK_KN + : public ck::test::TestGemmUniversal, Tuple>::type> +{ +}; + +template +class TestGemmUniversal_FP16_MK_NK + : public ck::test::TestGemmUniversal, Tuple>::type> +{ +}; + +template +class TestGemmUniversal_FP16_KM_KN + : public ck::test::TestGemmUniversal, Tuple>::type> +{ +}; + +template +class TestGemmUniversal_FP16_KM_NK + : public ck::test::TestGemmUniversal, Tuple>::type> +{ +}; + +// clang-format off +using KernelTypes_MK_KN = ::testing::Types< + // ADataType, BDataType, ComputeDataType, CDataType + +#if defined(CK_ENABLE_FP8) && (defined(CK_USE_FP8_ON_UNSUPPORTED_ARCH) || defined(CK_USE_GFX94)) + std::tuple< F16, F8, F16, F16>, + std::tuple< F8, F16, F16, F16>, + +#endif + std::tuple< F16, F16, F16, F16> + >; +using KernelTypes_MK_NK = ::testing::Types< + // ADataType, BDataType, ComputeDataType, CDataType + +#if defined(CK_ENABLE_FP8) && (defined(CK_USE_FP8_ON_UNSUPPORTED_ARCH) || defined(CK_USE_GFX94)) + std::tuple< F16, F8, F16, F16>, + std::tuple< F8, F16, F16, F16>, + +#endif + std::tuple< F16, F16, F16, F16> + >; + +// clang-format on + +TYPED_TEST_SUITE(TestGemmUniversal_FP16_MK_KN, KernelTypes_MK_KN); +TYPED_TEST_SUITE(TestGemmUniversal_FP16_MK_NK, KernelTypes_MK_NK); + +#include "test_gemm_universal_ut_cases_fp16.inc" diff --git a/test/gemm_universal/test_gemm_universal_xdl_fp8.cpp b/test/gemm_universal/test_gemm_universal_xdl_fp8.cpp new file mode 100644 index 0000000000..e833ab7825 --- /dev/null +++ 
b/test/gemm_universal/test_gemm_universal_xdl_fp8.cpp @@ -0,0 +1,71 @@ +// SPDX-License-Identifier: MIT +// Copyright (c) 2023-2024, Advanced Micro Devices, Inc. All rights reserved. + +#include + +#include "gtest/gtest.h" +#include "ck/tensor_operation/gpu/device/tensor_layout.hpp" +#include "test_gemm_universal_util.hpp" + +using F8 = ck::f8_t; +using F16 = ck::half_t; +using BF16 = ck::bhalf_t; +using F32 = float; + +using Row = ck::tensor_layout::gemm::RowMajor; +using Col = ck::tensor_layout::gemm::ColumnMajor; + +namespace { + +template +struct tuple_concat; + +template +struct tuple_concat, std::tuple> +{ + using type = std::tuple; +}; + +} // namespace + +template +class TestGemmUniversal_FP8_MK_KN + : public ck::test::TestGemmUniversal, Tuple>::type> +{ +}; + +template +class TestGemmUniversal_FP8_MK_NK + : public ck::test::TestGemmUniversal, Tuple>::type> +{ +}; + +// clang-format off +using KernelTypes_MK_KN = ::testing::Types< + // ADataType, BDataType, ComputeDataType, CDataType +#if defined(CK_ENABLE_FP8) && (defined(CK_USE_FP8_ON_UNSUPPORTED_ARCH) || defined(CK_USE_GFX94)) + std::tuple< F16, F8, F16, F16>, + std::tuple< F8, F16, F16, F16>, + std::tuple< F8, F8, F8, BF16>, +#endif + // Fallback test type when FP8 is not enabled + std::tuple< F16, F16, F16, F16> + >; +using KernelTypes_MK_NK = ::testing::Types< + // ADataType, BDataType, ComputeDataType, CDataType + +#if defined(CK_ENABLE_FP8) && (defined(CK_USE_FP8_ON_UNSUPPORTED_ARCH) || defined(CK_USE_GFX94)) + std::tuple< F16, F8, F16, F16>, + std::tuple< F8, F16, F16, F16>, + std::tuple< F8, F8, F8, BF16>, +#endif + // Fallback test type when FP8 is not enabled + std::tuple< F16, F16, F16, F16> + >; + + +TYPED_TEST_SUITE(TestGemmUniversal_FP8_MK_KN, KernelTypes_MK_KN); +TYPED_TEST_SUITE(TestGemmUniversal_FP8_MK_NK, KernelTypes_MK_NK); + + +#include "test_gemm_universal_ut_cases_fp8.inc" diff --git a/test/gemm_universal_streamk/test_gemm_universal_streamk_util.hpp 
b/test/gemm_universal_streamk/test_gemm_universal_streamk_util.hpp index ef3509c0ca..805587a274 100644 --- a/test/gemm_universal_streamk/test_gemm_universal_streamk_util.hpp +++ b/test/gemm_universal_streamk/test_gemm_universal_streamk_util.hpp @@ -44,9 +44,8 @@ class TestGemmUniversal_Streamk : public testing::Test void SetUp() override { - grid_size_list = {38, 114, 228}; // {38, 76, 114, 152, 190, 228, 266, 304, 342, 380}; - streamk_sel_list = {0, 1, 2}; // 0: Data Parallel (DP) mode (Stream-K OFF), 1: 1-tile - // Stream-K+ DP, // {0, 1, 2, 3, 4} + streamk_sel_list = {0, 1, 2}; // 0: Data Parallel (DP) mode (Stream-K OFF), 1: 1-tile + // Stream-K+ DP, // {0, 1, 2, 3, 4} // 2:2-tile Stream-K + DP } @@ -58,10 +57,9 @@ class TestGemmUniversal_Streamk : public testing::Test const int StrideC) { for(auto streamk_sel : streamk_sel_list) - for(auto grid_size : grid_size_list) - { - RunSingle(M, N, K, StrideA, StrideB, StrideC, streamk_sel, grid_size); - } + { + RunSingle(M, N, K, StrideA, StrideB, StrideC, streamk_sel, -1); + } } void RunSingle(const int M, From 572cd820ce720aed32168660f7d3d41304390776 Mon Sep 17 00:00:00 2001 From: Illia Silin <98187287+illsilin@users.noreply.github.com> Date: Thu, 3 Apr 2025 15:30:21 -0700 Subject: [PATCH 021/443] Split env.hpp header from the ck.hpp header. 
(#2049) * split env.hpp out of main headers * fix namespace logic --- include/ck/ck.hpp | 5 ----- include/ck/host_utility/flush_cache.hpp | 1 + include/ck/host_utility/kernel_launch.hpp | 1 + ...batched_gemm_multiple_d_gemm_multiple_d_xdl_cshuffle.hpp | 1 + .../device/impl/device_batched_gemm_reduce_xdl_cshuffle.hpp | 3 ++- ...evice_batched_gemm_softmax_gemm_permute_xdl_cshuffle.hpp | 3 ++- .../impl/device_conv2d_bwd_data_xdl_nhwc_kyxc_nhwk.hpp | 3 ++- ...fwd_xdl_c_shuffle_bias_activation_add_nhwc_kyxc_nhwk.hpp | 3 ++- ...v2d_fwd_xdl_c_shuffle_bias_activation_nhwc_kyxc_nhwk.hpp | 3 ++- .../impl/device_conv2d_fwd_xdl_c_shuffle_nhwc_kyxc_nhwk.hpp | 3 ++- .../device/impl/device_conv2d_fwd_xdl_nhwc_kyxc_nhwk.hpp | 3 ++- .../device/impl/device_conv3d_fwd_xdl_ndhwc_kzyxc_ndhwk.hpp | 3 ++- .../device/impl/device_convnd_bwd_data_nwc_kxc_nwk_dl.hpp | 1 + .../device/impl/device_convnd_bwd_data_nwc_kxc_nwk_xdl.hpp | 3 ++- .../ck/tensor_operation/gpu/device/impl/device_gemm_dl.hpp | 3 ++- .../gpu/device/impl/device_gemm_xdl_layernorm_cshuffle.hpp | 3 ++- .../gpu/device/impl/device_gemm_xdl_skip_b_lds.hpp | 3 ++- ...ice_grouped_conv_bwd_data_multiple_d_xdl_cshuffle_v1.hpp | 1 + ...evice_grouped_conv_bwd_weight_two_stage_xdl_cshuffle.hpp | 2 +- .../impl/device_grouped_conv_bwd_weight_xdl_cshuffle.hpp | 1 + .../device_grouped_conv_fwd_multiple_abd_xdl_cshuffle.hpp | 1 + ...device_grouped_conv_fwd_multiple_abd_xdl_cshuffle_v3.hpp | 1 + .../gpu/device/impl/device_grouped_gemm_multiple_d_dl.hpp | 1 + ...rouped_gemm_multiple_d_splitk_xdl_cshuffle_two_stage.hpp | 3 ++- .../gpu/device/impl/device_grouped_gemm_xdl.hpp | 1 + .../device/impl/device_grouped_gemm_xdl_splitk_cshuffle.hpp | 1 + .../gpu/grid/gridwise_gemm_xdl_cshuffle_streamk_v3.hpp | 1 + .../gpu/grid/gridwise_gemm_xdl_cshuffle_v3.hpp | 3 ++- .../gpu/grid/gridwise_gemm_xdl_cshuffle_v3_b_preshuffle.hpp | 1 + .../gpu/grid/gridwise_gemm_xdl_cshuffle_v3_b_scale.hpp | 3 ++- .../gpu/grid/gridwise_gemm_xdl_cshuffle_v3_multi_abd.hpp | 
3 ++- .../gpu/grid/gridwise_gemm_xdl_cshuffle_v3_mx.hpp | 1 + .../gpu/grid/gridwise_gemm_xdlops_v2r4r2.hpp | 3 ++- include/ck/utility/env.hpp | 5 +++++ include/ck_tile/core.hpp | 1 - include/ck_tile/core/config.hpp | 6 ------ include/ck_tile/core/utility/env.hpp | 4 ++++ include/ck_tile/ops/gemm/kernel/gemm_kernel.hpp | 1 + .../include/profiler/profile_grouped_gemm_fixed_nk_impl.hpp | 1 + profiler/include/profiler/profile_grouped_gemm_impl.hpp | 3 ++- .../profile_grouped_gemm_multiply_tile_loop_impl.hpp | 1 + .../profiler/profile_grouped_gemm_tile_loop_impl.hpp | 1 + 42 files changed, 64 insertions(+), 31 deletions(-) diff --git a/include/ck/ck.hpp b/include/ck/ck.hpp index 1d49b68a32..9d5d5fbc0b 100644 --- a/include/ck/ck.hpp +++ b/include/ck/ck.hpp @@ -6,15 +6,10 @@ #include "ck/config.h" #if !defined(__HIPCC_RTC__) || !defined(CK_CODE_GEN_RTC) -#include "ck/utility/env.hpp" #ifndef CK_DONT_USE_HIP_RUNTIME_HEADERS #include "hip/hip_runtime.h" #include "hip/hip_fp16.h" #endif - -// environment variable to enable logging: -// export CK_LOGGING=ON or CK_LOGGING=1 or CK_LOGGING=ENABLED -CK_DECLARE_ENV_VAR_BOOL(CK_LOGGING) #endif // to do: add various levels of logging with CK_LOG_LEVEL diff --git a/include/ck/host_utility/flush_cache.hpp b/include/ck/host_utility/flush_cache.hpp index 918fb28ea9..08b3aba2b3 100644 --- a/include/ck/host_utility/flush_cache.hpp +++ b/include/ck/host_utility/flush_cache.hpp @@ -8,6 +8,7 @@ #include #include "ck/ck.hpp" +#include "ck/utility/env.hpp" #include "ck/stream_config.hpp" #include "ck/host_utility/hip_check_error.hpp" #include "ck/utility/flush_icache.hpp" diff --git a/include/ck/host_utility/kernel_launch.hpp b/include/ck/host_utility/kernel_launch.hpp index 5c1c1c4e60..11a1c9bbc0 100644 --- a/include/ck/host_utility/kernel_launch.hpp +++ b/include/ck/host_utility/kernel_launch.hpp @@ -6,6 +6,7 @@ #include #include "ck/ck.hpp" +#include "ck/utility/env.hpp" #include "ck/stream_config.hpp" #include 
"ck/host_utility/hip_check_error.hpp" diff --git a/include/ck/tensor_operation/gpu/device/impl/device_batched_gemm_multiple_d_gemm_multiple_d_xdl_cshuffle.hpp b/include/ck/tensor_operation/gpu/device/impl/device_batched_gemm_multiple_d_gemm_multiple_d_xdl_cshuffle.hpp index f6c228fb7b..d38698af4b 100644 --- a/include/ck/tensor_operation/gpu/device/impl/device_batched_gemm_multiple_d_gemm_multiple_d_xdl_cshuffle.hpp +++ b/include/ck/tensor_operation/gpu/device/impl/device_batched_gemm_multiple_d_gemm_multiple_d_xdl_cshuffle.hpp @@ -7,6 +7,7 @@ #include #include "ck/utility/common_header.hpp" +#include "ck/utility/env.hpp" #include "ck/tensor_description/tensor_descriptor.hpp" #include "ck/tensor_description/tensor_descriptor_helper.hpp" #include "ck/tensor_operation/gpu/device/tensor_layout.hpp" diff --git a/include/ck/tensor_operation/gpu/device/impl/device_batched_gemm_reduce_xdl_cshuffle.hpp b/include/ck/tensor_operation/gpu/device/impl/device_batched_gemm_reduce_xdl_cshuffle.hpp index 30ae72a63e..de7d67f08b 100644 --- a/include/ck/tensor_operation/gpu/device/impl/device_batched_gemm_reduce_xdl_cshuffle.hpp +++ b/include/ck/tensor_operation/gpu/device/impl/device_batched_gemm_reduce_xdl_cshuffle.hpp @@ -1,5 +1,5 @@ // SPDX-License-Identifier: MIT -// Copyright (c) 2018-2023, Advanced Micro Devices, Inc. All rights reserved. +// Copyright (c) 2018-2025, Advanced Micro Devices, Inc. All rights reserved. 
#pragma once @@ -7,6 +7,7 @@ #include #include "ck/utility/common_header.hpp" +#include "ck/utility/env.hpp" #include "ck/tensor_description/tensor_descriptor.hpp" #include "ck/tensor_description/tensor_descriptor_helper.hpp" #include "ck/tensor_operation/gpu/device/tensor_layout.hpp" diff --git a/include/ck/tensor_operation/gpu/device/impl/device_batched_gemm_softmax_gemm_permute_xdl_cshuffle.hpp b/include/ck/tensor_operation/gpu/device/impl/device_batched_gemm_softmax_gemm_permute_xdl_cshuffle.hpp index 2662e5c360..bae5c6019d 100644 --- a/include/ck/tensor_operation/gpu/device/impl/device_batched_gemm_softmax_gemm_permute_xdl_cshuffle.hpp +++ b/include/ck/tensor_operation/gpu/device/impl/device_batched_gemm_softmax_gemm_permute_xdl_cshuffle.hpp @@ -1,5 +1,5 @@ // SPDX-License-Identifier: MIT -// Copyright (c) 2018-2023, Advanced Micro Devices, Inc. All rights reserved. +// Copyright (c) 2018-2025, Advanced Micro Devices, Inc. All rights reserved. #pragma once @@ -7,6 +7,7 @@ #include #include "ck/utility/common_header.hpp" +#include "ck/utility/env.hpp" #include "ck/tensor_description/tensor_descriptor.hpp" #include "ck/tensor_description/tensor_descriptor_helper.hpp" #include "ck/tensor_operation/gpu/device/device_batched_gemm_softmax_gemm_permute.hpp" diff --git a/include/ck/tensor_operation/gpu/device/impl/device_conv2d_bwd_data_xdl_nhwc_kyxc_nhwk.hpp b/include/ck/tensor_operation/gpu/device/impl/device_conv2d_bwd_data_xdl_nhwc_kyxc_nhwk.hpp index 0b73317c5e..d4f89b3e09 100644 --- a/include/ck/tensor_operation/gpu/device/impl/device_conv2d_bwd_data_xdl_nhwc_kyxc_nhwk.hpp +++ b/include/ck/tensor_operation/gpu/device/impl/device_conv2d_bwd_data_xdl_nhwc_kyxc_nhwk.hpp @@ -1,5 +1,5 @@ // SPDX-License-Identifier: MIT -// Copyright (c) 2018-2023, Advanced Micro Devices, Inc. All rights reserved. +// Copyright (c) 2018-2025, Advanced Micro Devices, Inc. All rights reserved. 
#pragma once @@ -7,6 +7,7 @@ #include #include "ck/utility/common_header.hpp" +#include "ck/utility/env.hpp" #include "ck/tensor_description/tensor_descriptor.hpp" #include "ck/tensor_description/tensor_descriptor_helper.hpp" #include "ck/tensor_operation/gpu/device/tensor_layout.hpp" diff --git a/include/ck/tensor_operation/gpu/device/impl/device_conv2d_fwd_xdl_c_shuffle_bias_activation_add_nhwc_kyxc_nhwk.hpp b/include/ck/tensor_operation/gpu/device/impl/device_conv2d_fwd_xdl_c_shuffle_bias_activation_add_nhwc_kyxc_nhwk.hpp index 13eb23574f..a8eb73d730 100644 --- a/include/ck/tensor_operation/gpu/device/impl/device_conv2d_fwd_xdl_c_shuffle_bias_activation_add_nhwc_kyxc_nhwk.hpp +++ b/include/ck/tensor_operation/gpu/device/impl/device_conv2d_fwd_xdl_c_shuffle_bias_activation_add_nhwc_kyxc_nhwk.hpp @@ -1,5 +1,5 @@ // SPDX-License-Identifier: MIT -// Copyright (c) 2018-2023, Advanced Micro Devices, Inc. All rights reserved. +// Copyright (c) 2018-2025, Advanced Micro Devices, Inc. All rights reserved. #pragma once @@ -7,6 +7,7 @@ #include #include "ck/utility/common_header.hpp" +#include "ck/utility/env.hpp" #include "ck/tensor_description/tensor_descriptor.hpp" #include "ck/tensor_description/tensor_descriptor_helper.hpp" #include "ck/tensor_operation/gpu/device/tensor_layout.hpp" diff --git a/include/ck/tensor_operation/gpu/device/impl/device_conv2d_fwd_xdl_c_shuffle_bias_activation_nhwc_kyxc_nhwk.hpp b/include/ck/tensor_operation/gpu/device/impl/device_conv2d_fwd_xdl_c_shuffle_bias_activation_nhwc_kyxc_nhwk.hpp index 28778d825b..6eb9281d30 100644 --- a/include/ck/tensor_operation/gpu/device/impl/device_conv2d_fwd_xdl_c_shuffle_bias_activation_nhwc_kyxc_nhwk.hpp +++ b/include/ck/tensor_operation/gpu/device/impl/device_conv2d_fwd_xdl_c_shuffle_bias_activation_nhwc_kyxc_nhwk.hpp @@ -1,5 +1,5 @@ // SPDX-License-Identifier: MIT -// Copyright (c) 2018-2023, Advanced Micro Devices, Inc. All rights reserved. +// Copyright (c) 2018-2025, Advanced Micro Devices, Inc. 
All rights reserved. #pragma once @@ -8,6 +8,7 @@ #include #include "ck/utility/common_header.hpp" +#include "ck/utility/env.hpp" #include "ck/tensor_description/tensor_descriptor.hpp" #include "ck/tensor_description/tensor_descriptor_helper.hpp" #include "ck/tensor_operation/gpu/device/tensor_layout.hpp" diff --git a/include/ck/tensor_operation/gpu/device/impl/device_conv2d_fwd_xdl_c_shuffle_nhwc_kyxc_nhwk.hpp b/include/ck/tensor_operation/gpu/device/impl/device_conv2d_fwd_xdl_c_shuffle_nhwc_kyxc_nhwk.hpp index 7fa231d4f4..5fad21f521 100644 --- a/include/ck/tensor_operation/gpu/device/impl/device_conv2d_fwd_xdl_c_shuffle_nhwc_kyxc_nhwk.hpp +++ b/include/ck/tensor_operation/gpu/device/impl/device_conv2d_fwd_xdl_c_shuffle_nhwc_kyxc_nhwk.hpp @@ -1,5 +1,5 @@ // SPDX-License-Identifier: MIT -// Copyright (c) 2018-2023, Advanced Micro Devices, Inc. All rights reserved. +// Copyright (c) 2018-2025, Advanced Micro Devices, Inc. All rights reserved. #pragma once @@ -7,6 +7,7 @@ #include #include "ck/utility/common_header.hpp" +#include "ck/utility/env.hpp" #include "ck/tensor_description/tensor_descriptor.hpp" #include "ck/tensor_description/tensor_descriptor_helper.hpp" #include "ck/tensor_operation/gpu/device/tensor_layout.hpp" diff --git a/include/ck/tensor_operation/gpu/device/impl/device_conv2d_fwd_xdl_nhwc_kyxc_nhwk.hpp b/include/ck/tensor_operation/gpu/device/impl/device_conv2d_fwd_xdl_nhwc_kyxc_nhwk.hpp index 3be7313d2b..c7aa54f1d9 100644 --- a/include/ck/tensor_operation/gpu/device/impl/device_conv2d_fwd_xdl_nhwc_kyxc_nhwk.hpp +++ b/include/ck/tensor_operation/gpu/device/impl/device_conv2d_fwd_xdl_nhwc_kyxc_nhwk.hpp @@ -1,5 +1,5 @@ // SPDX-License-Identifier: MIT -// Copyright (c) 2018-2023, Advanced Micro Devices, Inc. All rights reserved. +// Copyright (c) 2018-2025, Advanced Micro Devices, Inc. All rights reserved. 
#pragma once @@ -7,6 +7,7 @@ #include #include "ck/utility/common_header.hpp" +#include "ck/utility/env.hpp" #include "ck/tensor_description/tensor_descriptor.hpp" #include "ck/tensor_description/tensor_descriptor_helper.hpp" #include "ck/tensor_operation/gpu/device/tensor_layout.hpp" diff --git a/include/ck/tensor_operation/gpu/device/impl/device_conv3d_fwd_xdl_ndhwc_kzyxc_ndhwk.hpp b/include/ck/tensor_operation/gpu/device/impl/device_conv3d_fwd_xdl_ndhwc_kzyxc_ndhwk.hpp index 8aa20f7ad4..68ec8187a4 100644 --- a/include/ck/tensor_operation/gpu/device/impl/device_conv3d_fwd_xdl_ndhwc_kzyxc_ndhwk.hpp +++ b/include/ck/tensor_operation/gpu/device/impl/device_conv3d_fwd_xdl_ndhwc_kzyxc_ndhwk.hpp @@ -1,5 +1,5 @@ // SPDX-License-Identifier: MIT -// Copyright (c) 2018-2023, Advanced Micro Devices, Inc. All rights reserved. +// Copyright (c) 2018-2025, Advanced Micro Devices, Inc. All rights reserved. #ifndef DEVICE_CONV3D_FWD_XDL_HPP #define DEVICE_CONV3D_FWD_XDL_HPP @@ -10,6 +10,7 @@ #include "device.hpp" #include "device_conv_fwd.hpp" #include "common_header.hpp" +#include "ck/utility/env.hpp" #include "tensor_layout.hpp" #include "convolution_forward_specialization.hpp" #include "tensor_descriptor.hpp" diff --git a/include/ck/tensor_operation/gpu/device/impl/device_convnd_bwd_data_nwc_kxc_nwk_dl.hpp b/include/ck/tensor_operation/gpu/device/impl/device_convnd_bwd_data_nwc_kxc_nwk_dl.hpp index 1edae33be3..ddabd61c3d 100644 --- a/include/ck/tensor_operation/gpu/device/impl/device_convnd_bwd_data_nwc_kxc_nwk_dl.hpp +++ b/include/ck/tensor_operation/gpu/device/impl/device_convnd_bwd_data_nwc_kxc_nwk_dl.hpp @@ -7,6 +7,7 @@ #include #include "ck/utility/common_header.hpp" +#include "ck/utility/env.hpp" #include "ck/tensor_description/tensor_descriptor.hpp" #include "ck/tensor_description/tensor_descriptor_helper.hpp" #include "ck/tensor_operation/gpu/device/tensor_layout.hpp" diff --git a/include/ck/tensor_operation/gpu/device/impl/device_convnd_bwd_data_nwc_kxc_nwk_xdl.hpp 
b/include/ck/tensor_operation/gpu/device/impl/device_convnd_bwd_data_nwc_kxc_nwk_xdl.hpp index de8f35a640..2881036bee 100644 --- a/include/ck/tensor_operation/gpu/device/impl/device_convnd_bwd_data_nwc_kxc_nwk_xdl.hpp +++ b/include/ck/tensor_operation/gpu/device/impl/device_convnd_bwd_data_nwc_kxc_nwk_xdl.hpp @@ -1,5 +1,5 @@ // SPDX-License-Identifier: MIT -// Copyright (c) 2018-2023, Advanced Micro Devices, Inc. All rights reserved. +// Copyright (c) 2018-2025, Advanced Micro Devices, Inc. All rights reserved. #pragma once @@ -7,6 +7,7 @@ #include #include "ck/utility/common_header.hpp" +#include "ck/utility/env.hpp" #include "ck/tensor_description/tensor_descriptor.hpp" #include "ck/tensor_description/tensor_descriptor_helper.hpp" #include "ck/tensor_operation/gpu/device/tensor_layout.hpp" diff --git a/include/ck/tensor_operation/gpu/device/impl/device_gemm_dl.hpp b/include/ck/tensor_operation/gpu/device/impl/device_gemm_dl.hpp index eb0fb55f5d..7faee161c1 100644 --- a/include/ck/tensor_operation/gpu/device/impl/device_gemm_dl.hpp +++ b/include/ck/tensor_operation/gpu/device/impl/device_gemm_dl.hpp @@ -1,5 +1,5 @@ // SPDX-License-Identifier: MIT -// Copyright (c) 2018-2023, Advanced Micro Devices, Inc. All rights reserved. +// Copyright (c) 2018-2025, Advanced Micro Devices, Inc. All rights reserved. 
#pragma once @@ -7,6 +7,7 @@ #include #include "ck/utility/common_header.hpp" +#include "ck/utility/env.hpp" #include "ck/tensor_description/tensor_descriptor.hpp" #include "ck/tensor_description/tensor_descriptor_helper.hpp" #include "ck/tensor_operation/gpu/device/tensor_layout.hpp" diff --git a/include/ck/tensor_operation/gpu/device/impl/device_gemm_xdl_layernorm_cshuffle.hpp b/include/ck/tensor_operation/gpu/device/impl/device_gemm_xdl_layernorm_cshuffle.hpp index fd6f3b65f2..213501468a 100644 --- a/include/ck/tensor_operation/gpu/device/impl/device_gemm_xdl_layernorm_cshuffle.hpp +++ b/include/ck/tensor_operation/gpu/device/impl/device_gemm_xdl_layernorm_cshuffle.hpp @@ -1,5 +1,5 @@ // SPDX-License-Identifier: MIT -// Copyright (c) 2018-2023, Advanced Micro Devices, Inc. All rights reserved. +// Copyright (c) 2018-2025, Advanced Micro Devices, Inc. All rights reserved. #pragma once @@ -7,6 +7,7 @@ #include #include "ck/utility/common_header.hpp" +#include "ck/utility/env.hpp" #include "ck/tensor_description/tensor_descriptor.hpp" #include "ck/tensor_description/tensor_descriptor_helper.hpp" #include "ck/tensor_operation/gpu/device/tensor_layout.hpp" diff --git a/include/ck/tensor_operation/gpu/device/impl/device_gemm_xdl_skip_b_lds.hpp b/include/ck/tensor_operation/gpu/device/impl/device_gemm_xdl_skip_b_lds.hpp index c2a27ebbdb..7315fe75a3 100644 --- a/include/ck/tensor_operation/gpu/device/impl/device_gemm_xdl_skip_b_lds.hpp +++ b/include/ck/tensor_operation/gpu/device/impl/device_gemm_xdl_skip_b_lds.hpp @@ -1,5 +1,5 @@ // SPDX-License-Identifier: MIT -// Copyright (c) 2018-2023, Advanced Micro Devices, Inc. All rights reserved. +// Copyright (c) 2018-2025, Advanced Micro Devices, Inc. All rights reserved. 
#pragma once @@ -7,6 +7,7 @@ #include #include "ck/utility/common_header.hpp" +#include "ck/utility/env.hpp" #include "ck/tensor_description/tensor_descriptor.hpp" #include "ck/tensor_description/tensor_descriptor_helper.hpp" #include "ck/tensor_operation/gpu/device/tensor_layout.hpp" diff --git a/include/ck/tensor_operation/gpu/device/impl/device_grouped_conv_bwd_data_multiple_d_xdl_cshuffle_v1.hpp b/include/ck/tensor_operation/gpu/device/impl/device_grouped_conv_bwd_data_multiple_d_xdl_cshuffle_v1.hpp index 770e531e44..08edddf107 100644 --- a/include/ck/tensor_operation/gpu/device/impl/device_grouped_conv_bwd_data_multiple_d_xdl_cshuffle_v1.hpp +++ b/include/ck/tensor_operation/gpu/device/impl/device_grouped_conv_bwd_data_multiple_d_xdl_cshuffle_v1.hpp @@ -8,6 +8,7 @@ #include "ck/library/utility/numeric.hpp" #include "ck/utility/common_header.hpp" +#include "ck/utility/env.hpp" #include "ck/tensor_description/tensor_descriptor.hpp" #include "ck/tensor_description/tensor_descriptor_helper.hpp" #include "ck/tensor_operation/gpu/device/tensor_layout.hpp" diff --git a/include/ck/tensor_operation/gpu/device/impl/device_grouped_conv_bwd_weight_two_stage_xdl_cshuffle.hpp b/include/ck/tensor_operation/gpu/device/impl/device_grouped_conv_bwd_weight_two_stage_xdl_cshuffle.hpp index 4d730b1f37..da7c4f759b 100644 --- a/include/ck/tensor_operation/gpu/device/impl/device_grouped_conv_bwd_weight_two_stage_xdl_cshuffle.hpp +++ b/include/ck/tensor_operation/gpu/device/impl/device_grouped_conv_bwd_weight_two_stage_xdl_cshuffle.hpp @@ -8,7 +8,7 @@ #include #include "ck/utility/common_header.hpp" - +#include "ck/utility/env.hpp" #include "ck/tensor_description/tensor_descriptor.hpp" #include "ck/tensor_description/tensor_descriptor_helper.hpp" #include "ck/tensor_operation/gpu/device/tensor_layout.hpp" diff --git a/include/ck/tensor_operation/gpu/device/impl/device_grouped_conv_bwd_weight_xdl_cshuffle.hpp 
b/include/ck/tensor_operation/gpu/device/impl/device_grouped_conv_bwd_weight_xdl_cshuffle.hpp index f40b238c8a..c904b4e7d5 100644 --- a/include/ck/tensor_operation/gpu/device/impl/device_grouped_conv_bwd_weight_xdl_cshuffle.hpp +++ b/include/ck/tensor_operation/gpu/device/impl/device_grouped_conv_bwd_weight_xdl_cshuffle.hpp @@ -8,6 +8,7 @@ #include #include "ck/utility/common_header.hpp" +#include "ck/utility/env.hpp" #include "ck/tensor_description/tensor_descriptor.hpp" #include "ck/tensor_description/tensor_descriptor_helper.hpp" #include "ck/tensor_operation/gpu/device/tensor_layout.hpp" diff --git a/include/ck/tensor_operation/gpu/device/impl/device_grouped_conv_fwd_multiple_abd_xdl_cshuffle.hpp b/include/ck/tensor_operation/gpu/device/impl/device_grouped_conv_fwd_multiple_abd_xdl_cshuffle.hpp index 272b832e11..c0148c3b9c 100644 --- a/include/ck/tensor_operation/gpu/device/impl/device_grouped_conv_fwd_multiple_abd_xdl_cshuffle.hpp +++ b/include/ck/tensor_operation/gpu/device/impl/device_grouped_conv_fwd_multiple_abd_xdl_cshuffle.hpp @@ -11,6 +11,7 @@ #include "ck/library/utility/numeric.hpp" #include "ck/utility/common_header.hpp" +#include "ck/utility/env.hpp" #include "ck/tensor_description/tensor_descriptor.hpp" #include "ck/tensor_description/tensor_descriptor_helper.hpp" #include "ck/tensor_operation/gpu/device/tensor_layout.hpp" diff --git a/include/ck/tensor_operation/gpu/device/impl/device_grouped_conv_fwd_multiple_abd_xdl_cshuffle_v3.hpp b/include/ck/tensor_operation/gpu/device/impl/device_grouped_conv_fwd_multiple_abd_xdl_cshuffle_v3.hpp index b2f1dbfa5c..a93e6ded96 100644 --- a/include/ck/tensor_operation/gpu/device/impl/device_grouped_conv_fwd_multiple_abd_xdl_cshuffle_v3.hpp +++ b/include/ck/tensor_operation/gpu/device/impl/device_grouped_conv_fwd_multiple_abd_xdl_cshuffle_v3.hpp @@ -11,6 +11,7 @@ #include "ck/library/utility/numeric.hpp" #include "ck/utility/common_header.hpp" +#include "ck/utility/env.hpp" #include 
"ck/tensor_description/tensor_descriptor.hpp" #include "ck/tensor_description/tensor_descriptor_helper.hpp" #include "ck/tensor_operation/gpu/device/tensor_layout.hpp" diff --git a/include/ck/tensor_operation/gpu/device/impl/device_grouped_gemm_multiple_d_dl.hpp b/include/ck/tensor_operation/gpu/device/impl/device_grouped_gemm_multiple_d_dl.hpp index 463b10de43..10d8a4a44d 100644 --- a/include/ck/tensor_operation/gpu/device/impl/device_grouped_gemm_multiple_d_dl.hpp +++ b/include/ck/tensor_operation/gpu/device/impl/device_grouped_gemm_multiple_d_dl.hpp @@ -8,6 +8,7 @@ #include #include "ck/utility/common_header.hpp" +#include "ck/utility/env.hpp" #include "ck/tensor_description/tensor_descriptor.hpp" #include "ck/tensor_description/tensor_descriptor_helper.hpp" #include "ck/tensor_operation/gpu/device/tensor_layout.hpp" diff --git a/include/ck/tensor_operation/gpu/device/impl/device_grouped_gemm_multiple_d_splitk_xdl_cshuffle_two_stage.hpp b/include/ck/tensor_operation/gpu/device/impl/device_grouped_gemm_multiple_d_splitk_xdl_cshuffle_two_stage.hpp index d692aa05ce..18872e38ea 100644 --- a/include/ck/tensor_operation/gpu/device/impl/device_grouped_gemm_multiple_d_splitk_xdl_cshuffle_two_stage.hpp +++ b/include/ck/tensor_operation/gpu/device/impl/device_grouped_gemm_multiple_d_splitk_xdl_cshuffle_two_stage.hpp @@ -1,5 +1,5 @@ // SPDX-License-Identifier: MIT -// Copyright (c) 2018-2024, Advanced Micro Devices, Inc. All rights reserved. +// Copyright (c) 2018-2025, Advanced Micro Devices, Inc. All rights reserved. 
#pragma once @@ -8,6 +8,7 @@ #include #include "ck/ck.hpp" +#include "ck/utility/env.hpp" #include "ck/host_utility/device_prop.hpp" #include "ck/host_utility/kernel_launch.hpp" #include "ck/host_utility/hip_check_error.hpp" diff --git a/include/ck/tensor_operation/gpu/device/impl/device_grouped_gemm_xdl.hpp b/include/ck/tensor_operation/gpu/device/impl/device_grouped_gemm_xdl.hpp index d9a0249da8..aa70a24fc1 100644 --- a/include/ck/tensor_operation/gpu/device/impl/device_grouped_gemm_xdl.hpp +++ b/include/ck/tensor_operation/gpu/device/impl/device_grouped_gemm_xdl.hpp @@ -8,6 +8,7 @@ #include #include "ck/utility/common_header.hpp" +#include "ck/utility/env.hpp" #include "ck/tensor_description/tensor_descriptor.hpp" #include "ck/tensor_description/tensor_descriptor_helper.hpp" #include "ck/tensor_operation/gpu/device/tensor_layout.hpp" diff --git a/include/ck/tensor_operation/gpu/device/impl/device_grouped_gemm_xdl_splitk_cshuffle.hpp b/include/ck/tensor_operation/gpu/device/impl/device_grouped_gemm_xdl_splitk_cshuffle.hpp index a2afb62eec..01f52881f4 100644 --- a/include/ck/tensor_operation/gpu/device/impl/device_grouped_gemm_xdl_splitk_cshuffle.hpp +++ b/include/ck/tensor_operation/gpu/device/impl/device_grouped_gemm_xdl_splitk_cshuffle.hpp @@ -7,6 +7,7 @@ #include #include "ck/ck.hpp" +#include "ck/utility/env.hpp" #include "ck/host_utility/device_prop.hpp" #include "ck/host_utility/kernel_launch.hpp" #include "ck/host_utility/hip_check_error.hpp" diff --git a/include/ck/tensor_operation/gpu/grid/gridwise_gemm_xdl_cshuffle_streamk_v3.hpp b/include/ck/tensor_operation/gpu/grid/gridwise_gemm_xdl_cshuffle_streamk_v3.hpp index cc8ae1806a..e5e32a8535 100644 --- a/include/ck/tensor_operation/gpu/grid/gridwise_gemm_xdl_cshuffle_streamk_v3.hpp +++ b/include/ck/tensor_operation/gpu/grid/gridwise_gemm_xdl_cshuffle_streamk_v3.hpp @@ -4,6 +4,7 @@ #pragma once #include "ck/utility/common_header.hpp" +#include "ck/utility/env.hpp" #include 
"ck/tensor_description/multi_index_transform_helper.hpp" #include "ck/tensor_description/tensor_descriptor.hpp" #include "ck/tensor_description/tensor_descriptor_helper.hpp" diff --git a/include/ck/tensor_operation/gpu/grid/gridwise_gemm_xdl_cshuffle_v3.hpp b/include/ck/tensor_operation/gpu/grid/gridwise_gemm_xdl_cshuffle_v3.hpp index 9f6d85dd78..29150c0688 100644 --- a/include/ck/tensor_operation/gpu/grid/gridwise_gemm_xdl_cshuffle_v3.hpp +++ b/include/ck/tensor_operation/gpu/grid/gridwise_gemm_xdl_cshuffle_v3.hpp @@ -1,9 +1,10 @@ // SPDX-License-Identifier: MIT -// Copyright (c) 2018-2024, Advanced Micro Devices, Inc. All rights reserved. +// Copyright (c) 2018-2025, Advanced Micro Devices, Inc. All rights reserved. #pragma once #include "ck/utility/common_header.hpp" +#include "ck/utility/env.hpp" #include "ck/tensor_description/multi_index_transform_helper.hpp" #include "ck/tensor_description/tensor_descriptor.hpp" #include "ck/tensor_description/tensor_descriptor_helper.hpp" diff --git a/include/ck/tensor_operation/gpu/grid/gridwise_gemm_xdl_cshuffle_v3_b_preshuffle.hpp b/include/ck/tensor_operation/gpu/grid/gridwise_gemm_xdl_cshuffle_v3_b_preshuffle.hpp index ffa01efe17..a22fc06a50 100644 --- a/include/ck/tensor_operation/gpu/grid/gridwise_gemm_xdl_cshuffle_v3_b_preshuffle.hpp +++ b/include/ck/tensor_operation/gpu/grid/gridwise_gemm_xdl_cshuffle_v3_b_preshuffle.hpp @@ -4,6 +4,7 @@ #pragma once #include "ck/utility/common_header.hpp" +#include "ck/utility/env.hpp" #include "ck/tensor_description/multi_index_transform_helper.hpp" #include "ck/tensor_description/tensor_descriptor.hpp" #include "ck/tensor_description/tensor_descriptor_helper.hpp" diff --git a/include/ck/tensor_operation/gpu/grid/gridwise_gemm_xdl_cshuffle_v3_b_scale.hpp b/include/ck/tensor_operation/gpu/grid/gridwise_gemm_xdl_cshuffle_v3_b_scale.hpp index 27818b6964..7124687d5d 100644 --- a/include/ck/tensor_operation/gpu/grid/gridwise_gemm_xdl_cshuffle_v3_b_scale.hpp +++ 
b/include/ck/tensor_operation/gpu/grid/gridwise_gemm_xdl_cshuffle_v3_b_scale.hpp @@ -1,5 +1,5 @@ // SPDX-License-Identifier: MIT -// Copyright (c) 2018-2024, Advanced Micro Devices, Inc. All rights reserved. +// Copyright (c) 2018-2025, Advanced Micro Devices, Inc. All rights reserved. #pragma once @@ -13,6 +13,7 @@ #include "ck/tensor_operation/gpu/grid/block_to_ctile_map.hpp" #include "ck/tensor_operation/gpu/thread/threadwise_tensor_slice_transfer.hpp" #include "ck/utility/common_header.hpp" +#include "ck/utility/env.hpp" namespace ck { diff --git a/include/ck/tensor_operation/gpu/grid/gridwise_gemm_xdl_cshuffle_v3_multi_abd.hpp b/include/ck/tensor_operation/gpu/grid/gridwise_gemm_xdl_cshuffle_v3_multi_abd.hpp index b805f600d5..ac3e821340 100644 --- a/include/ck/tensor_operation/gpu/grid/gridwise_gemm_xdl_cshuffle_v3_multi_abd.hpp +++ b/include/ck/tensor_operation/gpu/grid/gridwise_gemm_xdl_cshuffle_v3_multi_abd.hpp @@ -1,9 +1,10 @@ // SPDX-License-Identifier: MIT -// Copyright (c) 2018-2024, Advanced Micro Devices, Inc. All rights reserved. +// Copyright (c) 2018-2025, Advanced Micro Devices, Inc. All rights reserved. 
#pragma once #include "ck/utility/common_header.hpp" +#include "ck/utility/env.hpp" #include "ck/tensor_description/multi_index_transform_helper.hpp" #include "ck/tensor_description/tensor_descriptor.hpp" #include "ck/tensor_description/tensor_descriptor_helper.hpp" diff --git a/include/ck/tensor_operation/gpu/grid/gridwise_gemm_xdl_cshuffle_v3_mx.hpp b/include/ck/tensor_operation/gpu/grid/gridwise_gemm_xdl_cshuffle_v3_mx.hpp index 715fcbcfef..c204b95d0f 100644 --- a/include/ck/tensor_operation/gpu/grid/gridwise_gemm_xdl_cshuffle_v3_mx.hpp +++ b/include/ck/tensor_operation/gpu/grid/gridwise_gemm_xdl_cshuffle_v3_mx.hpp @@ -13,6 +13,7 @@ #include "ck/tensor_operation/gpu/grid/block_to_ctile_map.hpp" #include "ck/tensor_operation/gpu/thread/threadwise_tensor_slice_transfer.hpp" #include "ck/utility/common_header.hpp" +#include "ck/utility/env.hpp" namespace ck { diff --git a/include/ck/tensor_operation/gpu/grid/gridwise_gemm_xdlops_v2r4r2.hpp b/include/ck/tensor_operation/gpu/grid/gridwise_gemm_xdlops_v2r4r2.hpp index 6ee279a3f1..256b495c6e 100644 --- a/include/ck/tensor_operation/gpu/grid/gridwise_gemm_xdlops_v2r4r2.hpp +++ b/include/ck/tensor_operation/gpu/grid/gridwise_gemm_xdlops_v2r4r2.hpp @@ -1,9 +1,10 @@ // SPDX-License-Identifier: MIT -// Copyright (c) 2018-2023, Advanced Micro Devices, Inc. All rights reserved. +// Copyright (c) 2018-2025, Advanced Micro Devices, Inc. All rights reserved. 
#pragma once #include "ck/utility/common_header.hpp" +#include "ck/utility/env.hpp" #include "ck/tensor_description/multi_index_transform_helper.hpp" #include "ck/tensor_description/tensor_descriptor.hpp" #include "ck/tensor_description/tensor_descriptor_helper.hpp" diff --git a/include/ck/utility/env.hpp b/include/ck/utility/env.hpp index 809f302f74..469fb70f10 100644 --- a/include/ck/utility/env.hpp +++ b/include/ck/utility/env.hpp @@ -184,4 +184,9 @@ void UpdateEnvVar(EnvVar, const std::string_view& val) } } // namespace ck + +// environment variable to enable logging: +// export CK_LOGGING=ON or CK_LOGGING=1 or CK_LOGGING=ENABLED +CK_DECLARE_ENV_VAR_BOOL(CK_LOGGING) + #endif diff --git a/include/ck_tile/core.hpp b/include/ck_tile/core.hpp index 821b3a8e84..d9aa8b3551 100644 --- a/include/ck_tile/core.hpp +++ b/include/ck_tile/core.hpp @@ -59,7 +59,6 @@ #include "ck_tile/core/tensor/transpose_tile.hpp" #include "ck_tile/core/tensor/update_tile.hpp" #include "ck_tile/core/utility/bit_cast.hpp" -#include "ck_tile/core/utility/env.hpp" #include "ck_tile/core/utility/functional.hpp" #include "ck_tile/core/utility/functional_with_tuple.hpp" #include "ck_tile/core/utility/ignore.hpp" diff --git a/include/ck_tile/core/config.hpp b/include/ck_tile/core/config.hpp index b1d201e30e..978f673346 100644 --- a/include/ck_tile/core/config.hpp +++ b/include/ck_tile/core/config.hpp @@ -28,12 +28,6 @@ #include "hip/hip_fp16.h" #endif -#include "ck_tile/core/utility/env.hpp" - -// environment variable to enable logging: -// export CK_TILE_LOGGING=ON or CK_TILE_LOGGING=1 or CK_TILE_LOGGING=ENABLED -CK_TILE_DECLARE_ENV_VAR_BOOL(CK_TILE_LOGGING) - #ifdef __HIPCC__ #define CK_TILE_HOST inline __host__ #define CK_TILE_DEVICE inline __device__ diff --git a/include/ck_tile/core/utility/env.hpp b/include/ck_tile/core/utility/env.hpp index 5b0b7a9071..9b148b3e0b 100644 --- a/include/ck_tile/core/utility/env.hpp +++ b/include/ck_tile/core/utility/env.hpp @@ -202,3 +202,7 @@ void 
UpdateEnvVar(EnvVar, const std::string_view& val) } } // namespace ck_tile + +// environment variable to enable logging: +// export CK_TILE_LOGGING=ON or CK_TILE_LOGGING=1 or CK_TILE_LOGGING=ENABLED +CK_TILE_DECLARE_ENV_VAR_BOOL(CK_TILE_LOGGING) diff --git a/include/ck_tile/ops/gemm/kernel/gemm_kernel.hpp b/include/ck_tile/ops/gemm/kernel/gemm_kernel.hpp index e5b9d17bac..bc41f680f2 100644 --- a/include/ck_tile/ops/gemm/kernel/gemm_kernel.hpp +++ b/include/ck_tile/ops/gemm/kernel/gemm_kernel.hpp @@ -9,6 +9,7 @@ #include "ck_tile/core.hpp" #include "ck_tile/ops/common.hpp" #include "ck_tile/host/concat.hpp" +#include "ck_tile/core/utility/env.hpp" namespace ck_tile { diff --git a/profiler/include/profiler/profile_grouped_gemm_fixed_nk_impl.hpp b/profiler/include/profiler/profile_grouped_gemm_fixed_nk_impl.hpp index 09e03de99c..8fb20f0135 100644 --- a/profiler/include/profiler/profile_grouped_gemm_fixed_nk_impl.hpp +++ b/profiler/include/profiler/profile_grouped_gemm_fixed_nk_impl.hpp @@ -6,6 +6,7 @@ #include #include "ck/ck.hpp" +#include "ck/utility/env.hpp" #include "ck/tensor_operation/gpu/device/tensor_layout.hpp" #include "ck/tensor_operation/gpu/device/device_grouped_gemm_fixed_nk.hpp" #include "ck/tensor_operation/gpu/element/element_wise_operation.hpp" diff --git a/profiler/include/profiler/profile_grouped_gemm_impl.hpp b/profiler/include/profiler/profile_grouped_gemm_impl.hpp index 367e94de11..fc2ba5a650 100644 --- a/profiler/include/profiler/profile_grouped_gemm_impl.hpp +++ b/profiler/include/profiler/profile_grouped_gemm_impl.hpp @@ -1,11 +1,12 @@ // SPDX-License-Identifier: MIT -// Copyright (c) 2018-2024, Advanced Micro Devices, Inc. All rights reserved. +// Copyright (c) 2018-2025, Advanced Micro Devices, Inc. All rights reserved. 
#pragma once #include #include "ck/ck.hpp" +#include "ck/utility/env.hpp" #include "ck/tensor_operation/gpu/device/tensor_layout.hpp" #include "ck/tensor_operation/gpu/device/device_grouped_gemm.hpp" #include "ck/tensor_operation/gpu/device/device_grouped_gemm_splitk.hpp" diff --git a/profiler/include/profiler/profile_grouped_gemm_multiply_tile_loop_impl.hpp b/profiler/include/profiler/profile_grouped_gemm_multiply_tile_loop_impl.hpp index 94ee2a37e4..1b17f05760 100644 --- a/profiler/include/profiler/profile_grouped_gemm_multiply_tile_loop_impl.hpp +++ b/profiler/include/profiler/profile_grouped_gemm_multiply_tile_loop_impl.hpp @@ -6,6 +6,7 @@ #include #include "ck/ck.hpp" +#include "ck/utility/env.hpp" #include "ck/host_utility/hip_check_error.hpp" #include "ck/tensor_operation/gpu/device/tensor_layout.hpp" #include "ck/tensor_operation/gpu/device/device_grouped_gemm_tile_loop.hpp" diff --git a/profiler/include/profiler/profile_grouped_gemm_tile_loop_impl.hpp b/profiler/include/profiler/profile_grouped_gemm_tile_loop_impl.hpp index 3a4ca24dda..cf3c3a6bae 100644 --- a/profiler/include/profiler/profile_grouped_gemm_tile_loop_impl.hpp +++ b/profiler/include/profiler/profile_grouped_gemm_tile_loop_impl.hpp @@ -6,6 +6,7 @@ #include #include "ck/ck.hpp" +#include "ck/utility/env.hpp" #include "ck/host_utility/hip_check_error.hpp" #include "ck/tensor_operation/gpu/device/tensor_layout.hpp" #include "ck/tensor_operation/gpu/device/device_grouped_gemm_tile_loop.hpp" From b443056a26cd25e6e621ff1c026b02eefdfe1f29 Mon Sep 17 00:00:00 2001 From: Khushbu Agarwal Date: Thu, 3 Apr 2025 16:24:34 -0700 Subject: [PATCH 022/443] Documentation for newly added struct (#2051) --- tile_engine/ops/gemm/gemm_host_api.hpp | 17 ++++++++++++++++- tile_engine/ops/gemm/gemm_instance_builder.py | 3 +-- 2 files changed, 17 insertions(+), 3 deletions(-) diff --git a/tile_engine/ops/gemm/gemm_host_api.hpp b/tile_engine/ops/gemm/gemm_host_api.hpp index 4f0ea52a18..3fa6dca863 100644 --- 
a/tile_engine/ops/gemm/gemm_host_api.hpp +++ b/tile_engine/ops/gemm/gemm_host_api.hpp @@ -54,6 +54,17 @@ struct DataTypeTraits static constexpr const char* name = "pk_int4_t"; }; +/** + * @brief trait for GEMM kernel + * @param pipeline: pipeline name + * @param scheduler: scheduler name + * @param epilogue: epilogue name + * @param kPadM: padding for M dimension + * @param kPadN: padding for N dimension + * @param kPadK: padding for K dimension + * + */ + struct KernelTraits { std::string pipeline; @@ -173,7 +184,11 @@ void permute_vectors_i4x4_b(Tensor& tensor) } } -// verification code +/** + * @brief Function to verify the kernel output with reference implementation on CPU/GPU + * + */ + template Date: Thu, 3 Apr 2025 16:55:49 -0700 Subject: [PATCH 023/443] file clang formatted (#2053) --- tile_engine/ops/gemm/gemm_host_api.hpp | 16 ++++++++-------- 1 file changed, 8 insertions(+), 8 deletions(-) diff --git a/tile_engine/ops/gemm/gemm_host_api.hpp b/tile_engine/ops/gemm/gemm_host_api.hpp index 3fa6dca863..375f808966 100644 --- a/tile_engine/ops/gemm/gemm_host_api.hpp +++ b/tile_engine/ops/gemm/gemm_host_api.hpp @@ -56,13 +56,13 @@ struct DataTypeTraits /** * @brief trait for GEMM kernel - * @param pipeline: pipeline name - * @param scheduler: scheduler name - * @param epilogue: epilogue name - * @param kPadM: padding for M dimension - * @param kPadN: padding for N dimension - * @param kPadK: padding for K dimension - * + * @param pipeline: pipeline name + * @param scheduler: scheduler name + * @param epilogue: epilogue name + * @param kPadM: padding for M dimension + * @param kPadN: padding for N dimension + * @param kPadK: padding for K dimension + * */ struct KernelTraits @@ -186,7 +186,7 @@ void permute_vectors_i4x4_b(Tensor& tensor) /** * @brief Function to verify the kernel output with reference implementation on CPU/GPU - * + * */ template Date: Mon, 7 Apr 2025 14:18:01 +0800 Subject: [PATCH 024/443] Add new receipt (#2055) --- 
example/ck_tile/01_fmha/codegen/ops/fmha_bwd.py | 7 +++++++ example/ck_tile/01_fmha/codegen/ops/fmha_fwd.py | 8 ++++++++ 2 files changed, 15 insertions(+) diff --git a/example/ck_tile/01_fmha/codegen/ops/fmha_bwd.py b/example/ck_tile/01_fmha/codegen/ops/fmha_bwd.py index 6326a97f8e..94f89256f9 100644 --- a/example/ck_tile/01_fmha/codegen/ops/fmha_bwd.py +++ b/example/ck_tile/01_fmha/codegen/ops/fmha_bwd.py @@ -545,6 +545,13 @@ def get_bwd_dq_dk_dv_blobs(kernel_filter : Optional[str], receipt, mask_impl) -> cond &= dpad == dvpad if not cond: continue + elif receipt == 600: + cond = dtype in ['fp16', 'bf16'] + cond &= mode in ["batch", "group"] + cond &= dropout in ['no', 'dropout_wg32', 'dropout_wg16'] + cond &= dpad == dvpad + if not cond: + continue api_pool.register_dq_dk_dv_traits(k.api_trait()) gen.append(k) diff --git a/example/ck_tile/01_fmha/codegen/ops/fmha_fwd.py b/example/ck_tile/01_fmha/codegen/ops/fmha_fwd.py index e5d11c6dc9..d978cc1d9b 100644 --- a/example/ck_tile/01_fmha/codegen/ops/fmha_fwd.py +++ b/example/ck_tile/01_fmha/codegen/ops/fmha_fwd.py @@ -536,6 +536,14 @@ def get_fwd_blobs(kernel_filter : Optional[str], receipt, mask_impl) -> Tuple[Fm cond &= pipeline.F_squant == 'f' if not cond: continue + # Aiter aiter::mha_fwd integration + elif receipt == 500: + cond = dtype in ['fp16', 'bf16'] + cond &= mode in ['batch', 'group'] + cond &= pipeline.F_vlayout == 'row' + cond &= pipeline.F_squant == 'f' + if not cond: + continue api_pool.register_traits(k.api_trait()) gen.append(k) From 29f72662165bcdfa746b1a247d9c8487cbb68f2e Mon Sep 17 00:00:00 2001 From: Illia Silin <98187287+illsilin@users.noreply.github.com> Date: Mon, 7 Apr 2025 06:49:36 -0700 Subject: [PATCH 025/443] =?UTF-8?q?Revert=20"CkProfiler=20StreamK=20GemmUn?= =?UTF-8?q?iversal=20Fix=20and=20Split=20Gemm=5Funiversal=20Test=20=20(?= =?UTF-8?q?=E2=80=A6"=20(#2054)?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit This reverts commit 
7142d8003c6a99f952a62bbd0b90d5f0261fc807. --- .../profile_gemm_universal_streamk_impl.hpp | 2 +- test/gemm_universal/CMakeLists.txt | 15 +-- ...6.inc => test_gemm_universal_ut_cases.inc} | 60 +++++++--- .../test_gemm_universal_ut_cases_fp16.inc | 99 --------------- .../test_gemm_universal_ut_cases_fp8.inc | 113 ------------------ ...l_bf16.cpp => test_gemm_universal_xdl.cpp} | 34 ++++-- .../test_gemm_universal_xdl_fp16.cpp | 82 ------------- .../test_gemm_universal_xdl_fp8.cpp | 71 ----------- .../test_gemm_universal_streamk_util.hpp | 12 +- 9 files changed, 79 insertions(+), 409 deletions(-) mode change 100755 => 100644 profiler/include/profiler/profile_gemm_universal_streamk_impl.hpp mode change 100755 => 100644 test/gemm_universal/CMakeLists.txt rename test/gemm_universal/{test_gemm_universal_ut_cases_bf16.inc => test_gemm_universal_ut_cases.inc} (75%) delete mode 100644 test/gemm_universal/test_gemm_universal_ut_cases_fp16.inc delete mode 100644 test/gemm_universal/test_gemm_universal_ut_cases_fp8.inc rename test/gemm_universal/{test_gemm_universal_xdl_bf16.cpp => test_gemm_universal_xdl.cpp} (61%) delete mode 100644 test/gemm_universal/test_gemm_universal_xdl_fp16.cpp delete mode 100644 test/gemm_universal/test_gemm_universal_xdl_fp8.cpp diff --git a/profiler/include/profiler/profile_gemm_universal_streamk_impl.hpp b/profiler/include/profiler/profile_gemm_universal_streamk_impl.hpp old mode 100755 new mode 100644 index e625fae808..d145ab1766 --- a/profiler/include/profiler/profile_gemm_universal_streamk_impl.hpp +++ b/profiler/include/profiler/profile_gemm_universal_streamk_impl.hpp @@ -166,7 +166,7 @@ bool profile_gemm_universal_streamk_impl(int do_verification, 0, 1, 2, 3, 4}; // 0: Data Parallel (DP) mode (Stream-K OFF), 1: 1-tile Stream-K+ DP, // 2:2-tile Stream-K + DP - if(Grid_size == -1) + if(Grid_size != -1) { grid_size_list = {Grid_size}; } diff --git a/test/gemm_universal/CMakeLists.txt b/test/gemm_universal/CMakeLists.txt old mode 100755 new mode 
100644 index cf5c68e220..4aab6323cc --- a/test/gemm_universal/CMakeLists.txt +++ b/test/gemm_universal/CMakeLists.txt @@ -1,15 +1,4 @@ -add_gtest_executable(test_gemm_universal_fp16 test_gemm_universal_xdl_fp16.cpp) +add_gtest_executable(test_gemm_universal test_gemm_universal_xdl.cpp) if(result EQUAL 0) - target_link_libraries(test_gemm_universal_fp16 PRIVATE utility device_gemm_universal_instance) + target_link_libraries(test_gemm_universal PRIVATE utility device_gemm_universal_instance) endif() - -add_gtest_executable(test_gemm_universal_fp8 test_gemm_universal_xdl_fp8.cpp) -if(result EQUAL 0) - target_link_libraries(test_gemm_universal_fp8 PRIVATE utility device_gemm_universal_instance) -endif() - -add_gtest_executable(test_gemm_universal_bf16 test_gemm_universal_xdl_bf16.cpp) -if(result EQUAL 0) - target_link_libraries(test_gemm_universal_bf16 PRIVATE utility device_gemm_universal_instance) -endif() - diff --git a/test/gemm_universal/test_gemm_universal_ut_cases_bf16.inc b/test/gemm_universal/test_gemm_universal_ut_cases.inc similarity index 75% rename from test/gemm_universal/test_gemm_universal_ut_cases_bf16.inc rename to test/gemm_universal/test_gemm_universal_ut_cases.inc index 8a6c672a9f..9a21666856 100644 --- a/test/gemm_universal/test_gemm_universal_ut_cases_bf16.inc +++ b/test/gemm_universal/test_gemm_universal_ut_cases.inc @@ -1,6 +1,6 @@ #pragma once -TYPED_TEST(TestGemmUniversal_BF16_MK_KN, SmallM) +TYPED_TEST(TestGemmUniversal_MK_KN, SmallM) { std::vector Ms{1, 2, 3, 4, 5, 6}; constexpr int N = 512; @@ -14,7 +14,7 @@ TYPED_TEST(TestGemmUniversal_BF16_MK_KN, SmallM) this->Run(M, N, K, StrideA, StrideB, StrideC); } -TYPED_TEST(TestGemmUniversal_BF16_MK_NK, SmallM) +TYPED_TEST(TestGemmUniversal_MK_NK, SmallM) { std::vector Ms{1, 2, 3, 4, 5, 6}; constexpr int N = 512; @@ -28,7 +28,7 @@ TYPED_TEST(TestGemmUniversal_BF16_MK_NK, SmallM) this->Run(M, N, K, StrideA, StrideB, StrideC); } -TYPED_TEST(TestGemmUniversal_BF16_KM_KN, SmallM) 
+TYPED_TEST(TestGemmUniversal_KM_KN, SmallM) { std::vector Ms{1, 2, 3, 4, 5, 6}; constexpr int N = 512; @@ -44,7 +44,7 @@ TYPED_TEST(TestGemmUniversal_BF16_KM_KN, SmallM) } } -TYPED_TEST(TestGemmUniversal_BF16_KM_NK, SmallM) +TYPED_TEST(TestGemmUniversal_KM_NK, SmallM) { std::vector Ms{1, 2, 3, 4, 5, 6}; constexpr int N = 512; @@ -60,7 +60,7 @@ TYPED_TEST(TestGemmUniversal_BF16_KM_NK, SmallM) } } -TYPED_TEST(TestGemmUniversal_BF16_MK_KN, MidLargeM) +TYPED_TEST(TestGemmUniversal_MK_KN, MidLargeM) { std::vector Ms{127, 255, 312, 799, 1573}; constexpr int N = 512; @@ -74,7 +74,7 @@ TYPED_TEST(TestGemmUniversal_BF16_MK_KN, MidLargeM) this->Run(M, N, K, StrideA, StrideB, StrideC); } -TYPED_TEST(TestGemmUniversal_BF16_MK_NK, MidLargeM) +TYPED_TEST(TestGemmUniversal_MK_NK, MidLargeM) { std::vector Ms{127, 255, 312, 799, 1573}; constexpr int N = 512; @@ -88,7 +88,7 @@ TYPED_TEST(TestGemmUniversal_BF16_MK_NK, MidLargeM) this->Run(M, N, K, StrideA, StrideB, StrideC); } -TYPED_TEST(TestGemmUniversal_BF16_KM_KN, MidLargeM) +TYPED_TEST(TestGemmUniversal_KM_KN, MidLargeM) { std::vector Ms{127, 255, 312, 799, 1573}; constexpr int N = 512; @@ -104,7 +104,7 @@ TYPED_TEST(TestGemmUniversal_BF16_KM_KN, MidLargeM) } } -TYPED_TEST(TestGemmUniversal_BF16_KM_NK, MidLargeM) +TYPED_TEST(TestGemmUniversal_KM_NK, MidLargeM) { std::vector Ms{127, 255, 312, 799, 1573}; constexpr int N = 512; @@ -120,7 +120,7 @@ TYPED_TEST(TestGemmUniversal_BF16_KM_NK, MidLargeM) } } -TYPED_TEST(TestGemmUniversal_BF16_MK_KN, PaddK) +TYPED_TEST(TestGemmUniversal_MK_KN, PaddK) { std::vector Ms{127}; constexpr int N = 512; @@ -134,7 +134,7 @@ TYPED_TEST(TestGemmUniversal_BF16_MK_KN, PaddK) this->Run(M, N, K, StrideA, StrideB, StrideC); } -TYPED_TEST(TestGemmUniversal_BF16_MK_NK, PaddK) +TYPED_TEST(TestGemmUniversal_MK_NK, PaddK) { std::vector Ms{127}; constexpr int N = 512; @@ -148,7 +148,7 @@ TYPED_TEST(TestGemmUniversal_BF16_MK_NK, PaddK) this->Run(M, N, K, StrideA, StrideB, StrideC); } 
-TYPED_TEST(TestGemmUniversal_BF16_KM_KN, PaddK) +TYPED_TEST(TestGemmUniversal_KM_KN, PaddK) { std::vector Ms{127}; constexpr int N = 512; @@ -164,7 +164,7 @@ TYPED_TEST(TestGemmUniversal_BF16_KM_KN, PaddK) } } -TYPED_TEST(TestGemmUniversal_BF16_KM_NK, PaddK) +TYPED_TEST(TestGemmUniversal_KM_NK, PaddK) { std::vector Ms{127}; constexpr int N = 512; @@ -180,7 +180,7 @@ TYPED_TEST(TestGemmUniversal_BF16_KM_NK, PaddK) } } -TYPED_TEST(TestGemmUniversal_BF16_MK_KN, Regular) +TYPED_TEST(TestGemmUniversal_MK_KN, Regular) { std::vector Ms{512}; constexpr int N = 512; @@ -194,7 +194,7 @@ TYPED_TEST(TestGemmUniversal_BF16_MK_KN, Regular) this->Run(M, N, K, StrideA, StrideB, StrideC); } -TYPED_TEST(TestGemmUniversal_BF16_MK_NK, Regular) +TYPED_TEST(TestGemmUniversal_MK_NK, Regular) { std::vector Ms{512}; constexpr int N = 512; @@ -207,3 +207,35 @@ TYPED_TEST(TestGemmUniversal_BF16_MK_NK, Regular) for(int M : Ms) this->Run(M, N, K, StrideA, StrideB, StrideC); } + +TYPED_TEST(TestGemmUniversal_KM_KN, Regular) +{ + std::vector Ms{512}; + constexpr int N = 512; + constexpr int K = 512; + + constexpr int StrideB = N; + constexpr int StrideC = N; + + for(int M : Ms) + { + int StrideA = M; + this->Run(M, N, K, StrideA, StrideB, StrideC); + } +} + +TYPED_TEST(TestGemmUniversal_KM_NK, Regular) +{ + std::vector Ms{512}; + constexpr int N = 512; + constexpr int K = 512; + + constexpr int StrideB = N; + constexpr int StrideC = N; + + for(int M : Ms) + { + int StrideA = M; + this->Run(M, N, K, StrideA, StrideB, StrideC); + } +} diff --git a/test/gemm_universal/test_gemm_universal_ut_cases_fp16.inc b/test/gemm_universal/test_gemm_universal_ut_cases_fp16.inc deleted file mode 100644 index b61ea0e6b4..0000000000 --- a/test/gemm_universal/test_gemm_universal_ut_cases_fp16.inc +++ /dev/null @@ -1,99 +0,0 @@ -#pragma once - -TYPED_TEST(TestGemmUniversal_FP16_MK_KN, SmallM) -{ - std::vector Ms{1, 2, 3, 4, 5, 6}; - constexpr int N = 512; - constexpr int K = 320; - - constexpr int StrideA = K; - 
constexpr int StrideB = N; - constexpr int StrideC = N; - - for(int M : Ms) - this->Run(M, N, K, StrideA, StrideB, StrideC); -} - -TYPED_TEST(TestGemmUniversal_FP16_MK_NK, SmallM) -{ - std::vector Ms{1, 2, 3, 4, 5, 6}; - constexpr int N = 512; - constexpr int K = 320; - - constexpr int StrideA = K; - constexpr int StrideB = K; - constexpr int StrideC = N; - - for(int M : Ms) - this->Run(M, N, K, StrideA, StrideB, StrideC); -} - -TYPED_TEST(TestGemmUniversal_FP16_MK_NK, MidLargeM) -{ - std::vector Ms{127, 255, 312, 799, 1573}; - constexpr int N = 512; - constexpr int K = 320; - - constexpr int StrideA = K; - constexpr int StrideB = K; - constexpr int StrideC = N; - - for(int M : Ms) - this->Run(M, N, K, StrideA, StrideB, StrideC); -} - -TYPED_TEST(TestGemmUniversal_FP16_MK_KN, PaddK) -{ - std::vector Ms{127}; - constexpr int N = 512; - constexpr int K = 437; - - constexpr int StrideA = K; - constexpr int StrideB = N; - constexpr int StrideC = N; - - for(int M : Ms) - this->Run(M, N, K, StrideA, StrideB, StrideC); -} - -TYPED_TEST(TestGemmUniversal_FP16_MK_NK, PaddK) -{ - std::vector Ms{127}; - constexpr int N = 512; - constexpr int K = 437; - - constexpr int StrideA = K; - constexpr int StrideB = K; - constexpr int StrideC = N; - - for(int M : Ms) - this->Run(M, N, K, StrideA, StrideB, StrideC); -} - -TYPED_TEST(TestGemmUniversal_FP16_MK_KN, Regular) -{ - std::vector Ms{512}; - constexpr int N = 512; - constexpr int K = 512; - - constexpr int StrideA = K; - constexpr int StrideB = N; - constexpr int StrideC = N; - - for(int M : Ms) - this->Run(M, N, K, StrideA, StrideB, StrideC); -} - -TYPED_TEST(TestGemmUniversal_FP16_MK_NK, Regular) -{ - std::vector Ms{512}; - constexpr int N = 512; - constexpr int K = 512; - - constexpr int StrideA = K; - constexpr int StrideB = K; - constexpr int StrideC = N; - - for(int M : Ms) - this->Run(M, N, K, StrideA, StrideB, StrideC); -} diff --git a/test/gemm_universal/test_gemm_universal_ut_cases_fp8.inc 
b/test/gemm_universal/test_gemm_universal_ut_cases_fp8.inc deleted file mode 100644 index b831e15e9c..0000000000 --- a/test/gemm_universal/test_gemm_universal_ut_cases_fp8.inc +++ /dev/null @@ -1,113 +0,0 @@ -#pragma once - -TYPED_TEST(TestGemmUniversal_FP8_MK_KN, SmallM) -{ - std::vector Ms{1, 2, 3, 4, 5, 6}; - constexpr int N = 512; - constexpr int K = 320; - - constexpr int StrideA = K; - constexpr int StrideB = N; - constexpr int StrideC = N; - - for(int M : Ms) - this->Run(M, N, K, StrideA, StrideB, StrideC); -} - -TYPED_TEST(TestGemmUniversal_FP8_MK_NK, SmallM) -{ - std::vector Ms{1, 2, 3, 4, 5, 6}; - constexpr int N = 512; - constexpr int K = 320; - - constexpr int StrideA = K; - constexpr int StrideB = K; - constexpr int StrideC = N; - - for(int M : Ms) - this->Run(M, N, K, StrideA, StrideB, StrideC); -} - -TYPED_TEST(TestGemmUniversal_FP8_MK_KN, MidLargeM) -{ - std::vector Ms{127, 255, 312, 799, 1573}; - constexpr int N = 512; - constexpr int K = 320; - - constexpr int StrideA = K; - constexpr int StrideB = N; - constexpr int StrideC = N; - - for(int M : Ms) - this->Run(M, N, K, StrideA, StrideB, StrideC); -} - -TYPED_TEST(TestGemmUniversal_FP8_MK_NK, MidLargeM) -{ - std::vector Ms{127, 255, 312, 799, 1573}; - constexpr int N = 512; - constexpr int K = 320; - - constexpr int StrideA = K; - constexpr int StrideB = K; - constexpr int StrideC = N; - - for(int M : Ms) - this->Run(M, N, K, StrideA, StrideB, StrideC); -} - -TYPED_TEST(TestGemmUniversal_FP8_MK_KN, PaddK) -{ - std::vector Ms{127}; - constexpr int N = 512; - constexpr int K = 437; - - constexpr int StrideA = K; - constexpr int StrideB = N; - constexpr int StrideC = N; - - for(int M : Ms) - this->Run(M, N, K, StrideA, StrideB, StrideC); -} - -TYPED_TEST(TestGemmUniversal_FP8_MK_NK, PaddK) -{ - std::vector Ms{127}; - constexpr int N = 512; - constexpr int K = 437; - - constexpr int StrideA = K; - constexpr int StrideB = K; - constexpr int StrideC = N; - - for(int M : Ms) - this->Run(M, N, K, StrideA, 
StrideB, StrideC); -} - -TYPED_TEST(TestGemmUniversal_FP8_MK_KN, Regular) -{ - std::vector Ms{512}; - constexpr int N = 512; - constexpr int K = 512; - - constexpr int StrideA = K; - constexpr int StrideB = N; - constexpr int StrideC = N; - - for(int M : Ms) - this->Run(M, N, K, StrideA, StrideB, StrideC); -} - -TYPED_TEST(TestGemmUniversal_FP8_MK_NK, Regular) -{ - std::vector Ms{512}; - constexpr int N = 512; - constexpr int K = 512; - - constexpr int StrideA = K; - constexpr int StrideB = K; - constexpr int StrideC = N; - - for(int M : Ms) - this->Run(M, N, K, StrideA, StrideB, StrideC); -} diff --git a/test/gemm_universal/test_gemm_universal_xdl_bf16.cpp b/test/gemm_universal/test_gemm_universal_xdl.cpp similarity index 61% rename from test/gemm_universal/test_gemm_universal_xdl_bf16.cpp rename to test/gemm_universal/test_gemm_universal_xdl.cpp index 8fde65657a..b872d7089a 100644 --- a/test/gemm_universal/test_gemm_universal_xdl_bf16.cpp +++ b/test/gemm_universal/test_gemm_universal_xdl.cpp @@ -7,6 +7,8 @@ #include "ck/tensor_operation/gpu/device/tensor_layout.hpp" #include "test_gemm_universal_util.hpp" +using F8 = ck::f8_t; +using F16 = ck::half_t; using BF16 = ck::bhalf_t; using F32 = float; @@ -27,25 +29,25 @@ struct tuple_concat, std::tuple> } // namespace template -class TestGemmUniversal_BF16_MK_KN +class TestGemmUniversal_MK_KN : public ck::test::TestGemmUniversal, Tuple>::type> { }; template -class TestGemmUniversal_BF16_MK_NK +class TestGemmUniversal_MK_NK : public ck::test::TestGemmUniversal, Tuple>::type> { }; template -class TestGemmUniversal_BF16_KM_KN +class TestGemmUniversal_KM_KN : public ck::test::TestGemmUniversal, Tuple>::type> { }; template -class TestGemmUniversal_BF16_KM_NK +class TestGemmUniversal_KM_NK : public ck::test::TestGemmUniversal, Tuple>::type> { }; @@ -53,12 +55,22 @@ class TestGemmUniversal_BF16_KM_NK // clang-format off using KernelTypes_MK_KN = ::testing::Types< // ADataType, BDataType, ComputeDataType, CDataType - + 
std::tuple< F16, F16, F16, F16>, +#if defined(CK_ENABLE_FP8) && (defined(CK_USE_FP8_ON_UNSUPPORTED_ARCH) || defined(CK_USE_GFX94)) + std::tuple< F16, F8, F16, F16>, + std::tuple< F8, F16, F16, F16>, + std::tuple< F8, F8, F8, BF16>, +#endif std::tuple< BF16, BF16, BF16, BF16> >; using KernelTypes_MK_NK = ::testing::Types< // ADataType, BDataType, ComputeDataType, CDataType - + std::tuple< F16, F16, F16, F16>, +#if defined(CK_ENABLE_FP8) && (defined(CK_USE_FP8_ON_UNSUPPORTED_ARCH) || defined(CK_USE_GFX94)) + std::tuple< F16, F8, F16, F16>, + std::tuple< F8, F16, F16, F16>, + std::tuple< F8, F8, F8, BF16>, +#endif std::tuple< BF16, BF16, BF16, BF16> >; @@ -74,9 +86,9 @@ using KernelTypes_KM_KN = ::testing::Types< // clang-format on -TYPED_TEST_SUITE(TestGemmUniversal_BF16_MK_KN, KernelTypes_MK_KN); -TYPED_TEST_SUITE(TestGemmUniversal_BF16_MK_NK, KernelTypes_MK_NK); -TYPED_TEST_SUITE(TestGemmUniversal_BF16_KM_KN, KernelTypes_KM_KN); -TYPED_TEST_SUITE(TestGemmUniversal_BF16_KM_NK, KernelTypes_KM_NK); +TYPED_TEST_SUITE(TestGemmUniversal_MK_KN, KernelTypes_MK_KN); +TYPED_TEST_SUITE(TestGemmUniversal_MK_NK, KernelTypes_MK_NK); +TYPED_TEST_SUITE(TestGemmUniversal_KM_KN, KernelTypes_KM_KN); +TYPED_TEST_SUITE(TestGemmUniversal_KM_NK, KernelTypes_KM_NK); -#include "test_gemm_universal_ut_cases_bf16.inc" +#include "test_gemm_universal_ut_cases.inc" diff --git a/test/gemm_universal/test_gemm_universal_xdl_fp16.cpp b/test/gemm_universal/test_gemm_universal_xdl_fp16.cpp deleted file mode 100644 index 24f587daf6..0000000000 --- a/test/gemm_universal/test_gemm_universal_xdl_fp16.cpp +++ /dev/null @@ -1,82 +0,0 @@ -// SPDX-License-Identifier: MIT -// Copyright (c) 2023-2024, Advanced Micro Devices, Inc. All rights reserved. 
- -#include - -#include "gtest/gtest.h" -#include "ck/tensor_operation/gpu/device/tensor_layout.hpp" -#include "test_gemm_universal_util.hpp" - -using F8 = ck::f8_t; -using F16 = ck::half_t; - -using F32 = float; - -using Row = ck::tensor_layout::gemm::RowMajor; -using Col = ck::tensor_layout::gemm::ColumnMajor; - -namespace { - -template -struct tuple_concat; - -template -struct tuple_concat, std::tuple> -{ - using type = std::tuple; -}; - -} // namespace - -template -class TestGemmUniversal_FP16_MK_KN - : public ck::test::TestGemmUniversal, Tuple>::type> -{ -}; - -template -class TestGemmUniversal_FP16_MK_NK - : public ck::test::TestGemmUniversal, Tuple>::type> -{ -}; - -template -class TestGemmUniversal_FP16_KM_KN - : public ck::test::TestGemmUniversal, Tuple>::type> -{ -}; - -template -class TestGemmUniversal_FP16_KM_NK - : public ck::test::TestGemmUniversal, Tuple>::type> -{ -}; - -// clang-format off -using KernelTypes_MK_KN = ::testing::Types< - // ADataType, BDataType, ComputeDataType, CDataType - -#if defined(CK_ENABLE_FP8) && (defined(CK_USE_FP8_ON_UNSUPPORTED_ARCH) || defined(CK_USE_GFX94)) - std::tuple< F16, F8, F16, F16>, - std::tuple< F8, F16, F16, F16>, - -#endif - std::tuple< F16, F16, F16, F16> - >; -using KernelTypes_MK_NK = ::testing::Types< - // ADataType, BDataType, ComputeDataType, CDataType - -#if defined(CK_ENABLE_FP8) && (defined(CK_USE_FP8_ON_UNSUPPORTED_ARCH) || defined(CK_USE_GFX94)) - std::tuple< F16, F8, F16, F16>, - std::tuple< F8, F16, F16, F16>, - -#endif - std::tuple< F16, F16, F16, F16> - >; - -// clang-format on - -TYPED_TEST_SUITE(TestGemmUniversal_FP16_MK_KN, KernelTypes_MK_KN); -TYPED_TEST_SUITE(TestGemmUniversal_FP16_MK_NK, KernelTypes_MK_NK); - -#include "test_gemm_universal_ut_cases_fp16.inc" diff --git a/test/gemm_universal/test_gemm_universal_xdl_fp8.cpp b/test/gemm_universal/test_gemm_universal_xdl_fp8.cpp deleted file mode 100644 index e833ab7825..0000000000 --- a/test/gemm_universal/test_gemm_universal_xdl_fp8.cpp +++ 
/dev/null @@ -1,71 +0,0 @@ -// SPDX-License-Identifier: MIT -// Copyright (c) 2023-2024, Advanced Micro Devices, Inc. All rights reserved. - -#include - -#include "gtest/gtest.h" -#include "ck/tensor_operation/gpu/device/tensor_layout.hpp" -#include "test_gemm_universal_util.hpp" - -using F8 = ck::f8_t; -using F16 = ck::half_t; -using BF16 = ck::bhalf_t; -using F32 = float; - -using Row = ck::tensor_layout::gemm::RowMajor; -using Col = ck::tensor_layout::gemm::ColumnMajor; - -namespace { - -template -struct tuple_concat; - -template -struct tuple_concat, std::tuple> -{ - using type = std::tuple; -}; - -} // namespace - -template -class TestGemmUniversal_FP8_MK_KN - : public ck::test::TestGemmUniversal, Tuple>::type> -{ -}; - -template -class TestGemmUniversal_FP8_MK_NK - : public ck::test::TestGemmUniversal, Tuple>::type> -{ -}; - -// clang-format off -using KernelTypes_MK_KN = ::testing::Types< - // ADataType, BDataType, ComputeDataType, CDataType -#if defined(CK_ENABLE_FP8) && (defined(CK_USE_FP8_ON_UNSUPPORTED_ARCH) || defined(CK_USE_GFX94)) - std::tuple< F16, F8, F16, F16>, - std::tuple< F8, F16, F16, F16>, - std::tuple< F8, F8, F8, BF16>, -#endif - // Fallback test type when FP8 is not enabled - std::tuple< F16, F16, F16, F16> - >; -using KernelTypes_MK_NK = ::testing::Types< - // ADataType, BDataType, ComputeDataType, CDataType - -#if defined(CK_ENABLE_FP8) && (defined(CK_USE_FP8_ON_UNSUPPORTED_ARCH) || defined(CK_USE_GFX94)) - std::tuple< F16, F8, F16, F16>, - std::tuple< F8, F16, F16, F16>, - std::tuple< F8, F8, F8, BF16>, -#endif - // Fallback test type when FP8 is not enabled - std::tuple< F16, F16, F16, F16> - >; - - -TYPED_TEST_SUITE(TestGemmUniversal_FP8_MK_KN, KernelTypes_MK_KN); -TYPED_TEST_SUITE(TestGemmUniversal_FP8_MK_NK, KernelTypes_MK_NK); - - -#include "test_gemm_universal_ut_cases_fp8.inc" diff --git a/test/gemm_universal_streamk/test_gemm_universal_streamk_util.hpp b/test/gemm_universal_streamk/test_gemm_universal_streamk_util.hpp index 
805587a274..ef3509c0ca 100644 --- a/test/gemm_universal_streamk/test_gemm_universal_streamk_util.hpp +++ b/test/gemm_universal_streamk/test_gemm_universal_streamk_util.hpp @@ -44,8 +44,9 @@ class TestGemmUniversal_Streamk : public testing::Test void SetUp() override { - streamk_sel_list = {0, 1, 2}; // 0: Data Parallel (DP) mode (Stream-K OFF), 1: 1-tile - // Stream-K+ DP, // {0, 1, 2, 3, 4} + grid_size_list = {38, 114, 228}; // {38, 76, 114, 152, 190, 228, 266, 304, 342, 380}; + streamk_sel_list = {0, 1, 2}; // 0: Data Parallel (DP) mode (Stream-K OFF), 1: 1-tile + // Stream-K+ DP, // {0, 1, 2, 3, 4} // 2:2-tile Stream-K + DP } @@ -57,9 +58,10 @@ class TestGemmUniversal_Streamk : public testing::Test const int StrideC) { for(auto streamk_sel : streamk_sel_list) - { - RunSingle(M, N, K, StrideA, StrideB, StrideC, streamk_sel, -1); - } + for(auto grid_size : grid_size_list) + { + RunSingle(M, N, K, StrideA, StrideB, StrideC, streamk_sel, grid_size); + } } void RunSingle(const int M, From 179322842274a635f6bd6141c7251a2f65b5fa34 Mon Sep 17 00:00:00 2001 From: Illia Silin <98187287+illsilin@users.noreply.github.com> Date: Mon, 7 Apr 2025 07:08:39 -0700 Subject: [PATCH 026/443] fix codegen issues (#2052) --- include/ck/utility/amd_ck_fp8.hpp | 24 ++++++++++++------------ 1 file changed, 12 insertions(+), 12 deletions(-) diff --git a/include/ck/utility/amd_ck_fp8.hpp b/include/ck/utility/amd_ck_fp8.hpp index b0089bb2d1..d079639c6a 100644 --- a/include/ck/utility/amd_ck_fp8.hpp +++ b/include/ck/utility/amd_ck_fp8.hpp @@ -557,7 +557,7 @@ template = false> static __device__ fp8_storage_t cast_to_f8_from_f16(_Float16 v, unsigned int rng = 0) { - std::ignore = rng; + ignore = rng; union { @@ -596,7 +596,7 @@ static __device__ fp8x2_storage_t cast_to_f8_from_f16(half2_t v, unsigned int rn cast_to_f8_from_f16(v[0], rng), cast_to_f8_from_f16(v[1], rng)}; #else - std::ignore = rng; + ignore = rng; union { @@ -634,7 +634,7 @@ template = false> static __device__ fp8_storage_t 
cast_to_f8_from_f16(_Float16 v, unsigned int rng = 0) { - std::ignore = rng; + ignore = rng; union { @@ -673,7 +673,7 @@ static __device__ fp8x2_storage_t cast_to_f8_from_f16(half2_t v, unsigned int rn cast_to_f8_from_f16(v[0], rng), cast_to_f8_from_f16(v[1], rng)}; #else - std::ignore = rng; + ignore = rng; union { @@ -805,7 +805,7 @@ template = false> static __device__ fp8_storage_t cast_to_f8_from_bf16(ushort v, unsigned int rng = 0) { - std::ignore = rng; + ignore = rng; union { @@ -847,7 +847,7 @@ static __device__ fp8x2_storage_t cast_to_f8_from_bf16(ushortx2_t v, unsigned in cast_to_f8_from_bf16(v[0], rng), cast_to_f8_from_bf16(v[1], rng)}; #else - std::ignore = rng; + ignore = rng; union { @@ -891,7 +891,7 @@ template = false> static __device__ fp8_storage_t cast_to_f8_from_bf16(ushort v, unsigned int rng = 0) { - std::ignore = rng; + ignore = rng; union { @@ -928,7 +928,7 @@ template = false> static __device__ fp8x2_storage_t cast_to_f8_from_bf16(ushortx2_t v, unsigned int rng = 0) { - std::ignore = rng; + ignore = rng; union { @@ -1544,7 +1544,7 @@ __host__ static inline fp8_storage_t cvt_half_t_to_fp8(const _Float16 x) sat == ck_saturation_t::CK_SATFINITE, stochastic_rounding>(x, rng); #else - std::ignore = rng; + ignore = rng; return cvt_float_to_fp8( static_cast(x)); #endif // defined(__gfx950__) @@ -1586,7 +1586,7 @@ __host__ static inline fp8x2_storage_t cvt_half_t_to_fp8(const half2_t x) sat == ck_saturation_t::CK_SATFINITE, stochastic_rounding>(x, rng); #else - std::ignore = rng; + ignore = rng; return cvt_float_to_fp8( float2_t{static_cast(x[0]), static_cast(x[1])}); #endif // defined(__gfx950__) @@ -1629,7 +1629,7 @@ __host__ static inline fp8_storage_t cvt_bhalf_t_to_fp8(const ushort x) sat == ck_saturation_t::CK_SATFINITE, stochastic_rounding>(x, rng); #else - std::ignore = rng; + ignore = rng; return cvt_float_to_fp8( bit_cast(uint32_t{x} << 16)); // convert value to float #endif // defined(__gfx950__) @@ -1678,7 +1678,7 @@ __host__ static 
inline fp8x2_storage_t cvt_bhalf_t_to_fp8(const ushortx2_t x) sat == ck_saturation_t::CK_SATFINITE, stochastic_rounding>(x, rng); #else - std::ignore = rng; + ignore = rng; return cvt_float_to_fp8( float2_t{bit_cast(uint32_t{x[0]} << 16), bit_cast(uint32_t{x[1]} << 16)}); // convert values to float From 72c0261ef1b40587ee8674b9d49b4fd6b46b0335 Mon Sep 17 00:00:00 2001 From: Illia Silin <98187287+illsilin@users.noreply.github.com> Date: Mon, 7 Apr 2025 12:48:34 -0700 Subject: [PATCH 027/443] Fix a couple of CI issues. (#2050) * fix jenkins jobs * fix perf log name for gfx908 * only run gemm perf tests on gfx908 --- Jenkinsfile | 16 +++++++--------- 1 file changed, 7 insertions(+), 9 deletions(-) diff --git a/Jenkinsfile b/Jenkinsfile index 86cac3c485..dbd484d7bd 100644 --- a/Jenkinsfile +++ b/Jenkinsfile @@ -331,8 +331,10 @@ def cmake_build(Map conf=[:]){ } } else{ - // run unit tests - sh "make check" + // run unit tests unless building library for all targets + if (!params.BUILD_INSTANCES_ONLY){ + sh "make check" + } } } } @@ -604,12 +606,9 @@ def Build_CK(Map conf=[:]){ else if ( arch_type == 6 ){ // run standard tests on gfx908 echo "Run performance tests" - sh "./run_performance_tests.sh 0 CI_${params.COMPILER_VERSION} ${env.BRANCH_NAME} ${NODE_NAME}" - archiveArtifacts "perf_gemm_gfx908.log" + sh "./run_gemm_performance_tests.sh 0 CI_${params.COMPILER_VERSION} ${env.BRANCH_NAME} ${NODE_NAME} gfx908" archiveArtifacts "perf_onnx_gemm_gfx908.log" - archiveArtifacts "perf_resnet50_N256_gfx908.log" - archiveArtifacts "perf_resnet50_N4_gfx908.log" - stash includes: "perf_**.log", name: "perf_log_gfx908" + stash includes: "perf_onnx_gemm_gfx908.log", name: "perf_log_gfx908" } } } @@ -746,8 +745,7 @@ def process_results(Map conf=[:]){ //launch develop branch daily at 23:00 UT in FULL_QA mode and at 19:00 UT with latest staging compiler version CRON_SETTINGS = BRANCH_NAME == "develop" ? 
'''0 23 * * * % RUN_FULL_QA=true;DISABLE_DL_KERNELS=true;ROCMVERSION=6.3;RUN_CK_TILE_FMHA_TESTS=true;RUN_CK_TILE_GEMM_TESTS=true - 0 22 * * * % ROCMVERSION=6.3;BUILD_GFX908=true;BUILD_GFX12=false;RUN_PERFORMANCE_TESTS=false - 0 21 * * * % ROCMVERSION=6.3;hipTensor_test=true;RUN_CODEGEN_TESTS=true + 0 21 * * * % ROCMVERSION=6.3;hipTensor_test=true;RUN_CODEGEN_TESTS=true;BUILD_GFX908=true; 0 19 * * * % BUILD_DOCKER=true;COMPILER_VERSION=amd-staging;BUILD_COMPILER=/llvm-project/build/bin/clang++;USE_SCCACHE=false;NINJA_BUILD_TRACE=true 0 17 * * * % BUILD_DOCKER=true;COMPILER_VERSION=amd-mainline;BUILD_COMPILER=/llvm-project/build/bin/clang++;USE_SCCACHE=false;NINJA_BUILD_TRACE=true 0 15 * * * % BUILD_INSTANCES_ONLY=true;RUN_PERFORMANCE_TESTS=false;USE_SCCACHE=false From 80aae6119b47d02ffebaa0d9b153fb075d0da140 Mon Sep 17 00:00:00 2001 From: aledudek Date: Tue, 8 Apr 2025 12:40:04 +0200 Subject: [PATCH 028/443] [CK_TILE] Fix GEMM Memory Pipeline (#2034) * [CK_TILE] Fix GEMM Memory Pipeline * Fix transpose tile * Add comments --- .../ck_tile/core/tensor/transpose_tile.hpp | 108 +++++++++++------- 1 file changed, 69 insertions(+), 39 deletions(-) diff --git a/include/ck_tile/core/tensor/transpose_tile.hpp b/include/ck_tile/core/tensor/transpose_tile.hpp index f34efe5c2f..5b65b79c1a 100644 --- a/include/ck_tile/core/tensor/transpose_tile.hpp +++ b/include/ck_tile/core/tensor/transpose_tile.hpp @@ -83,9 +83,6 @@ CK_TILE_DEVICE void transpose_tile2d_impl_in_thread(OutTensor& out_tensor, constexpr index_t num_vec_in = vec_length_out; constexpr index_t num_vec_out = vec_length_in; - using InVec = array; - using OutVec = array; - // SFC constexpr auto scalars_per_access_arr = generate_array( [&](auto i) { return (i == y_dim_vec_in or i == y_dim_vec_out) ? y_lengths[i] : 1; }, @@ -101,51 +98,84 @@ CK_TILE_DEVICE void transpose_tile2d_impl_in_thread(OutTensor& out_tensor, static_assert(num_access > 0, "wrong! 
num_access should be larger than 0"); - // in/out vectors to be transposed - thread_buffer in_vectors; - thread_buffer out_vectors; + if constexpr(num_vec_in == 1 || num_vec_out == 1) + { + // loop over SFC + static_for<0, num_access, 1>{}([&](auto iAccess) { + // data index [y0, y1, ...] in the order of input tensor + constexpr auto idx_y = SFC_Y::get_index(iAccess); - // loop over SFC and do transpose - static_for<0, num_access, 1>{}([&](auto iAccess) { - // data index [y0, y1, ...] in the order of input tensor - constexpr auto idx_y_start = SFC_Y::get_index(iAccess); + constexpr index_t in_offset = y_in_desc.calculate_offset(idx_y); + constexpr index_t out_offset = y_out_desc.calculate_offset(idx_y); - // get input vectors - static_for<0, num_vec_in, 1>{}([&](auto i) { - constexpr auto idx_y_in = generate_tuple( - [&](auto ii) { - return ii == y_dim_vec_out ? idx_y_start[ii] + i : idx_y_start[ii]; - }, - number{}); - - constexpr index_t in_offset = y_in_desc.calculate_offset(idx_y_in); - static_assert(in_offset % vec_length_in == 0); - - in_vectors(i).template get_as()(I0) = - in_tensor.get_thread_buffer() - .template get_as()[number{}]; + if constexpr(vec_length_in == 1) + { + out_tensor.get_thread_buffer()[number{}] = + in_tensor.get_thread_buffer()[number{}]; + } + else + { + using Vec = array; + out_tensor.get_thread_buffer().template get_as( + number{}) = + in_tensor.get_thread_buffer().template get_as( + number{}); + } }); + } + else + { + using InVec = array; + using OutVec = array; - // transpose - transpose_vectors{}(in_vectors, out_vectors); + // in/out vectors to be transposed + thread_buffer in_vectors; + thread_buffer out_vectors; - // set output vectors - static_for<0, num_vec_out, 1>{}([&](auto i) { - constexpr auto idx_y_out_tmp = generate_array( - [&](auto ii) { return ii == y_dim_vec_in ? 
idx_y_start[ii] + i : idx_y_start[ii]; }, - number{}); + // loop over SFC and do transpose + static_for<0, num_access, 1>{}([&](auto iAccess) { + // data index [y0, y1, ...] in the order of input tensor + constexpr auto idx_y_start = SFC_Y::get_index(iAccess); - constexpr auto idx_y_out = - container_reorder_given_new2old(idx_y_out_tmp, y_dim_out_to_in); + // get input vectors + static_for<0, num_vec_in, 1>{}([&](auto i) { + constexpr auto idx_y_in = generate_tuple( + [&](auto ii) { + return ii == y_dim_vec_out ? idx_y_start[ii] + i : idx_y_start[ii]; + }, + number{}); - constexpr index_t out_offset = y_out_desc.calculate_offset(idx_y_out); - static_assert(out_offset % vec_length_out == 0); + constexpr index_t in_offset = y_in_desc.calculate_offset(idx_y_in); + static_assert(in_offset % vec_length_in == 0); - out_tensor.get_thread_buffer().template set_as( - number{}, - out_vectors[i].template get_as()[I0]); + in_vectors(i).template get_as()(I0) = + in_tensor.get_thread_buffer() + .template get_as()[number{}]; + }); + + // transpose + transpose_vectors{}(in_vectors, out_vectors); + + // set output vectors + static_for<0, num_vec_out, 1>{}([&](auto i) { + constexpr auto idx_y_out_tmp = generate_array( + [&](auto ii) { + return ii == y_dim_vec_in ? 
idx_y_start[ii] + i : idx_y_start[ii]; + }, + number{}); + + constexpr auto idx_y_out = + container_reorder_given_new2old(idx_y_out_tmp, y_dim_out_to_in); + + constexpr index_t out_offset = y_out_desc.calculate_offset(idx_y_out); + static_assert(out_offset % vec_length_out == 0); + + out_tensor.get_thread_buffer().template set_as( + number{}, + out_vectors[i].template get_as()[I0]); + }); }); - }); + } } } // namespace detail From 6ce0797dadfc6d0c6cdde3e01532e90137fc5b0c Mon Sep 17 00:00:00 2001 From: Max Podkorytov <4273004+tenpercent@users.noreply.github.com> Date: Tue, 8 Apr 2025 09:00:51 -0700 Subject: [PATCH 029/443] simplify generate_tuple (#2043) --- include/ck/utility/sequence.hpp | 15 +++++++++++++++ include/ck/utility/tuple_helper.hpp | 9 +++++++-- include/ck_tile/core/container/tuple.hpp | 9 +++++++-- 3 files changed, 29 insertions(+), 4 deletions(-) diff --git a/include/ck/utility/sequence.hpp b/include/ck/utility/sequence.hpp index 99935a6d8d..497625f7e2 100644 --- a/include/ck/utility/sequence.hpp +++ b/include/ck/utility/sequence.hpp @@ -184,6 +184,21 @@ struct Sequence } }; +namespace impl { +template +struct __integer_sequence; + +template +struct __integer_sequence +{ + using seq_type = Sequence; +}; +} // namespace impl + +template +using make_index_sequence = + typename __make_integer_seq::seq_type; + // merge sequence template struct sequence_merge diff --git a/include/ck/utility/tuple_helper.hpp b/include/ck/utility/tuple_helper.hpp index b4f1545aa9..b1a0c1fc5d 100644 --- a/include/ck/utility/tuple_helper.hpp +++ b/include/ck/utility/tuple_helper.hpp @@ -11,11 +11,16 @@ namespace ck { +template +__host__ __device__ constexpr auto generate_tuple_for(F&& f, Sequence) +{ + return make_tuple(f(Number{})...); +} + template __host__ __device__ constexpr auto generate_tuple(F&& f, Number) { - return unpack([&f](auto&&... 
xs) { return make_tuple(f(xs)...); }, - typename arithmetic_sequence_gen<0, N, 1>::type{}); + return generate_tuple_for(f, make_index_sequence{}); } template diff --git a/include/ck_tile/core/container/tuple.hpp b/include/ck_tile/core/container/tuple.hpp index fd02177e25..3700d348e7 100644 --- a/include/ck_tile/core/container/tuple.hpp +++ b/include/ck_tile/core/container/tuple.hpp @@ -396,11 +396,16 @@ struct tuple_array_impl }; } // namespace impl +template +CK_TILE_HOST_DEVICE constexpr auto generate_tuple_for(F&& f, sequence) +{ + return make_tuple(f(number{})...); +} + template CK_TILE_HOST_DEVICE constexpr auto generate_tuple(F&& f, number) { - return unpack([&f](auto&&... is) { return make_tuple(f(is)...); }, - typename arithmetic_sequence_gen<0, N, 1>::type{}); + return generate_tuple_for(f, make_index_sequence{}); } template From b12cd6580b9737a9e8c6c055b25babc579242184 Mon Sep 17 00:00:00 2001 From: "dependabot[bot]" <49699333+dependabot[bot]@users.noreply.github.com> Date: Tue, 8 Apr 2025 09:06:38 -0700 Subject: [PATCH 030/443] Bump rocm-docs-core from 1.18.1 to 1.18.2 in /docs/sphinx (#2047) Bumps [rocm-docs-core](https://github.com/ROCm/rocm-docs-core) from 1.18.1 to 1.18.2. - [Release notes](https://github.com/ROCm/rocm-docs-core/releases) - [Changelog](https://github.com/ROCm/rocm-docs-core/blob/develop/CHANGELOG.md) - [Commits](https://github.com/ROCm/rocm-docs-core/compare/v1.18.1...v1.18.2) --- updated-dependencies: - dependency-name: rocm-docs-core dependency-version: 1.18.2 dependency-type: direct:production update-type: version-update:semver-patch ... 
Signed-off-by: dependabot[bot] Co-authored-by: dependabot[bot] <49699333+dependabot[bot]@users.noreply.github.com> --- docs/sphinx/requirements.in | 2 +- docs/sphinx/requirements.txt | 2 +- 2 files changed, 2 insertions(+), 2 deletions(-) diff --git a/docs/sphinx/requirements.in b/docs/sphinx/requirements.in index 2fcf3b3935..b89cb9fec8 100644 --- a/docs/sphinx/requirements.in +++ b/docs/sphinx/requirements.in @@ -1,2 +1,2 @@ -rocm-docs-core==1.18.1 +rocm-docs-core==1.18.2 sphinxcontrib-bibtex==2.6.3 diff --git a/docs/sphinx/requirements.txt b/docs/sphinx/requirements.txt index 12572d400e..2a52a48e4c 100644 --- a/docs/sphinx/requirements.txt +++ b/docs/sphinx/requirements.txt @@ -199,7 +199,7 @@ requests==2.32.3 # via # pygithub # sphinx -rocm-docs-core==1.18.1 +rocm-docs-core==1.18.2 # via -r requirements.in rpds-py==0.22.3 # via From 2c8132126ce089885d7aca40bc277196d8e78b34 Mon Sep 17 00:00:00 2001 From: spolifroni-amd Date: Tue, 8 Apr 2025 13:20:31 -0400 Subject: [PATCH 031/443] fixed broken github link (#2063) --- docs/index.rst | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/docs/index.rst b/docs/index.rst index 15a9321d43..6d46eb49b1 100644 --- a/docs/index.rst +++ b/docs/index.rst @@ -10,7 +10,7 @@ Composable Kernel User Guide The Composable Kernel library provides a programming model for writing performance critical kernels for machine learning workloads across multiple architectures including GPUs and CPUs, through general purpose kernel languages such as `HIP C++ `_. -The Composable Kernel repository is located at `https://github.com/ROCm/composable-kernel `_. +The Composable Kernel repository is located at `https://github.com/ROCm/composable_kernel `_. .. 
grid:: 2 :gutter: 3 From 263ff689e0cc03f9772f6a76eca57258db48698e Mon Sep 17 00:00:00 2001 From: Khushbu Agarwal Date: Tue, 8 Apr 2025 15:14:53 -0700 Subject: [PATCH 032/443] New instances for gemm_multiply_multiply_weightpreshuffle operator (#2061) * Add new instances for weight_preshuffle for f8->bf16 * Add new instances for weight_preshuffle for f8->f16 * clang formatted --------- Co-authored-by: Khushbu Agarwal Co-authored-by: Thomas Ning --- ...ultiply_multiply_wp_xdl_f8_f8_bf16_mk_mfma_mn.hpp | 12 +++++++++++- ...multiply_multiply_wp_xdl_f8_f8_f16_mk_mfma_mn.hpp | 8 +++++++- 2 files changed, 18 insertions(+), 2 deletions(-) diff --git a/library/src/tensor_operation_instance/gpu/gemm_multiply_multiply_wp/f8_f8_bf16/device_gemm_multiply_multiply_wp_xdl_f8_f8_bf16_mk_mfma_mn.hpp b/library/src/tensor_operation_instance/gpu/gemm_multiply_multiply_wp/f8_f8_bf16/device_gemm_multiply_multiply_wp_xdl_f8_f8_bf16_mk_mfma_mn.hpp index 4266ab9aa3..e5ada03a46 100644 --- a/library/src/tensor_operation_instance/gpu/gemm_multiply_multiply_wp/f8_f8_bf16/device_gemm_multiply_multiply_wp_xdl_f8_f8_bf16_mk_mfma_mn.hpp +++ b/library/src/tensor_operation_instance/gpu/gemm_multiply_multiply_wp/f8_f8_bf16/device_gemm_multiply_multiply_wp_xdl_f8_f8_bf16_mk_mfma_mn.hpp @@ -100,7 +100,17 @@ using device_gemm_multiply_multiply_weight_preshuffle_xdl_f8_f8_bf16_mk_mfma_mn_ //##########################################| | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | DeviceGemmMultiD_Xdl_CShuffle_V3_BPreshuffle< Row, Col, Tuple, Row, F8, F8, Tuple, BF16, F32, F32, PassThrough, PassThrough, MultiplyMultiply, GemmSpec, 256, 16, 64, 512, 16, 16, 16, 16, 1, 1, S<32, 8, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, 0, S<32, 8, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, 0, 1, 1, S<1, 16, 1, 16>, S<4, 4, 1>, BlockGemmPipelineScheduler::Intrawave, BlkGemmPipeVer, F8>, DeviceGemmMultiD_Xdl_CShuffle_V3_BPreshuffle< Row, Col, Tuple, Row, F8, F8, Tuple, BF16, F32, F32, 
PassThrough, PassThrough, MultiplyMultiply, GemmSpec, 256, 16, 128, 512, 16, 16, 16, 16, 1, 2, S<32, 8, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, 0, S<32, 8, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, 0, 1, 2, S<1, 16, 1, 16>, S<8, 8, 1>, BlockGemmPipelineScheduler::Intrawave, BlkGemmPipeVer, F8>, - DeviceGemmMultiD_Xdl_CShuffle_V3_BPreshuffle< Row, Col, Tuple, Row, F8, F8, Tuple, BF16, F32, F32, PassThrough, PassThrough, MultiplyMultiply, GemmSpec, 256, 16, 256, 512, 16, 16, 16, 16, 1, 4, S<32, 8, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, 0, S<32, 8, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, 0, 1, 2, S<1, 16, 1, 16>, S<8, 8, 1>, BlockGemmPipelineScheduler::Intrawave, BlkGemmPipeVer, F8> + DeviceGemmMultiD_Xdl_CShuffle_V3_BPreshuffle< Row, Col, Tuple, Row, F8, F8, Tuple, BF16, F32, F32, PassThrough, PassThrough, MultiplyMultiply, GemmSpec, 256, 16, 256, 512, 16, 16, 16, 16, 1, 4, S<32, 8, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, 0, S<32, 8, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, 0, 1, 2, S<1, 16, 1, 16>, S<8, 8, 1>, BlockGemmPipelineScheduler::Intrawave, BlkGemmPipeVer, F8>, + + DeviceGemmMultiD_Xdl_CShuffle_V3_BPreshuffle< Row, Col, Tuple, Row, F8, F8, Tuple, BF16, F32, F32, PassThrough, PassThrough, MultiplyMultiply, GemmSpec, 128, 32, 16, 512, 16, 16, 16, 16, 1, 1, S<32, 4, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, 0, S<32, 4, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, 0, 1, 1, S<1, 32, 1, 4>, S<4, 4, 1>, BlockGemmPipelineScheduler::Intrawave, BlkGemmPipeVer, F8>, + DeviceGemmMultiD_Xdl_CShuffle_V3_BPreshuffle< Row, Col, Tuple, Row, F8, F8, Tuple, BF16, F32, F32, PassThrough, PassThrough, MultiplyMultiply, GemmSpec, 128, 16, 32, 128, 16, 16, 16, 16, 1, 1, S<8, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, 0, S<8, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, 0, 1, 1, S<1, 16, 1, 8>, S<4, 4, 1>, BlockGemmPipelineScheduler::Intrawave, BlkGemmPipeVer, F8>, + DeviceGemmMultiD_Xdl_CShuffle_V3_BPreshuffle< Row, Col, Tuple, Row, F8, F8, Tuple, BF16, F32, F32, PassThrough, PassThrough, 
MultiplyMultiply, GemmSpec, 128, 16, 32, 512, 16, 16, 16, 16, 1, 1, S<32, 4, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, 0, S<32, 4, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, 0, 1, 1, S<1, 16, 1, 8>, S<4, 4, 1>, BlockGemmPipelineScheduler::Intrawave, BlkGemmPipeVer, F8>, + DeviceGemmMultiD_Xdl_CShuffle_V3_BPreshuffle< Row, Col, Tuple, Row, F8, F8, Tuple, BF16, F32, F32, PassThrough, PassThrough, MultiplyMultiply, GemmSpec, 256, 16, 64, 512, 16, 16, 16, 16, 1, 1, S<32, 8, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, 0, S<32, 8, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, 0, 1, 1, S<1, 16, 1, 16>, S<4, 4, 1>, BlockGemmPipelineScheduler::Intrawave, BlkGemmPipeVer, F8>, + DeviceGemmMultiD_Xdl_CShuffle_V3_BPreshuffle< Row, Col, Tuple, Row, F8, F8, Tuple, BF16, F32, F32, PassThrough, PassThrough, MultiplyMultiply, GemmSpec, 256, 32, 64, 512, 16, 16, 16, 16, 1, 2, S<32, 8, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, 0, S<32, 8, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, 0, 1, 2, S<1, 32, 1, 8>, S<8, 8, 1>, BlockGemmPipelineScheduler::Intrawave, BlkGemmPipeVer, F8>, + DeviceGemmMultiD_Xdl_CShuffle_V3_BPreshuffle< Row, Col, Tuple, Row, F8, F8, Tuple, BF16, F32, F32, PassThrough, PassThrough, MultiplyMultiply, GemmSpec, 256, 64, 64, 512, 16, 16, 32, 32, 1, 1, S<32, 8, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, 0, S<32, 8, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, 0, 1, 1, S<1, 32, 1, 8>, S<8, 8, 1>, BlockGemmPipelineScheduler::Intrawave, BlkGemmPipeVer, F8>, + DeviceGemmMultiD_Xdl_CShuffle_V3_BPreshuffle< Row, Col, Tuple, Row, F8, F8, Tuple, BF16, F32, F32, PassThrough, PassThrough, MultiplyMultiply, GemmSpec, 256, 64, 16, 512, 16, 16, 16, 16, 1, 1, S<32, 8, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, 0, S<32, 8, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, 0, 1, 1, S<1, 64, 1, 4>, S<4, 4, 1>, BlockGemmPipelineScheduler::Intrawave, BlkGemmPipeVer, F8>, + DeviceGemmMultiD_Xdl_CShuffle_V3_BPreshuffle< Row, Col, Tuple, Row, F8, F8, Tuple, BF16, F32, F32, PassThrough, PassThrough, MultiplyMultiply, GemmSpec, 256, 64, 16, 
512, 16, 16, 16, 16, 1, 1, S<32, 8, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, 0, S<32, 8, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, 0, 1, 1, S<1, 64, 1, 4>, S<4, 4, 1>, BlockGemmPipelineScheduler::Intrawave, BlkGemmPipeVer, F8> + // clang-format on >; diff --git a/library/src/tensor_operation_instance/gpu/gemm_multiply_multiply_wp/f8_f8_f16/device_gemm_multiply_multiply_wp_xdl_f8_f8_f16_mk_mfma_mn.hpp b/library/src/tensor_operation_instance/gpu/gemm_multiply_multiply_wp/f8_f8_f16/device_gemm_multiply_multiply_wp_xdl_f8_f8_f16_mk_mfma_mn.hpp index 94e44ee600..dc9db8889a 100644 --- a/library/src/tensor_operation_instance/gpu/gemm_multiply_multiply_wp/f8_f8_f16/device_gemm_multiply_multiply_wp_xdl_f8_f8_f16_mk_mfma_mn.hpp +++ b/library/src/tensor_operation_instance/gpu/gemm_multiply_multiply_wp/f8_f8_f16/device_gemm_multiply_multiply_wp_xdl_f8_f8_f16_mk_mfma_mn.hpp @@ -115,7 +115,13 @@ using device_gemm_multiply_multiply_weight_preshuffle_xdl_f8_f8_f16_mk_mfma_mn_p DeviceGemmMultiD_Xdl_CShuffle_V3_BPreshuffle< Row, Col, Tuple, Row, F8, F8, Tuple, F16, F32, F32, PassThrough, PassThrough, MultiplyMultiply, GemmSpec, 256, 16, 64, 256, 16, 16, 16, 16, 1, 1, S<16, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, 0, S<16, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, 0, 1, 1, S<1, 16, 1, 16>, S<4, 4, 1>, BlockGemmPipelineScheduler::Intrawave, BlkGemmPipeVer, F8>, DeviceGemmMultiD_Xdl_CShuffle_V3_BPreshuffle< Row, Col, Tuple, Row, F8, F8, Tuple, F16, F32, F32, PassThrough, PassThrough, MultiplyMultiply, GemmSpec, 256, 16, 128, 256, 16, 16, 16, 16, 1, 2, S<16, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, 0, S<16, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, 0, 1, 2, S<1, 16, 1, 16>, S<8, 8, 1>, BlockGemmPipelineScheduler::Intrawave, BlkGemmPipeVer, F8>, DeviceGemmMultiD_Xdl_CShuffle_V3_BPreshuffle< Row, Col, Tuple, Row, F8, F8, Tuple, F16, F32, F32, PassThrough, PassThrough, MultiplyMultiply, GemmSpec, 256, 16, 256, 256, 16, 16, 16, 16, 1, 4, S<16, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, 0, 
S<16, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, 0, 1, 2, S<1, 16, 1, 16>, S<8, 8, 1>, BlockGemmPipelineScheduler::Intrawave, BlkGemmPipeVer, F8>, - DeviceGemmMultiD_Xdl_CShuffle_V3_BPreshuffle< Row, Col, Tuple, Row, F8, F8, Tuple, F16, F32, F32, PassThrough, PassThrough, MultiplyMultiply, GemmSpec, 256, 16, 512, 256, 16, 16, 16, 16, 1, 8, S<16, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, 0, S<16, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, 0, 1, 2, S<1, 16, 1, 16>, S<8, 8, 1>, BlockGemmPipelineScheduler::Intrawave, BlkGemmPipeVer, F8> + DeviceGemmMultiD_Xdl_CShuffle_V3_BPreshuffle< Row, Col, Tuple, Row, F8, F8, Tuple, F16, F32, F32, PassThrough, PassThrough, MultiplyMultiply, GemmSpec, 256, 16, 512, 256, 16, 16, 16, 16, 1, 8, S<16, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, 0, S<16, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, 0, 1, 2, S<1, 16, 1, 16>, S<8, 8, 1>, BlockGemmPipelineScheduler::Intrawave, BlkGemmPipeVer, F8>, + + DeviceGemmMultiD_Xdl_CShuffle_V3_BPreshuffle< Row, Col, Tuple, Row, F8, F8, Tuple, F16, F32, F32, PassThrough, PassThrough, MultiplyMultiply, GemmSpec, 128, 16, 32, 512, 16, 16, 16, 16, 1, 1, S<32, 4, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, 0, S<32, 4, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, 0, 1, 1, S<1, 16, 1, 8>, S<4, 4, 1>, BlockGemmPipelineScheduler::Intrawave, BlkGemmPipeVer, F8>, + DeviceGemmMultiD_Xdl_CShuffle_V3_BPreshuffle< Row, Col, Tuple, Row, F8, F8, Tuple, F16, F32, F32, PassThrough, PassThrough, MultiplyMultiply, GemmSpec, 128, 16, 32, 128, 16, 16, 16, 16, 1, 1, S<8, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, 0, S<8, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, 0, 1, 1, S<1, 16, 1, 8>, S<4, 4, 1>, BlockGemmPipelineScheduler::Intrawave, BlkGemmPipeVer, F8>, + DeviceGemmMultiD_Xdl_CShuffle_V3_BPreshuffle< Row, Col, Tuple, Row, F8, F8, Tuple, F16, F32, F32, PassThrough, PassThrough, MultiplyMultiply, GemmSpec, 128, 16, 32, 256, 16, 16, 16, 16, 1, 1, S<16, 8, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, 0, S<16, 8, 1>, S<1, 0, 2>, S<1, 0, 2>, 
2, 16, 16, 0, 1, 1, S<1, 16, 1, 8>, S<4, 4, 1>, BlockGemmPipelineScheduler::Intrawave, BlkGemmPipeVer, F8>, + DeviceGemmMultiD_Xdl_CShuffle_V3_BPreshuffle< Row, Col, Tuple, Row, F8, F8, Tuple, F16, F32, F32, PassThrough, PassThrough, MultiplyMultiply, GemmSpec, 256, 16, 64, 128, 8, 16, 16, 16, 1, 1, S<16, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 0, S<8, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, 0, 1, 1, S<1, 16, 1, 16>, S<4, 4, 1>, BlockGemmPipelineScheduler::Intrawave, BlkGemmPipeVer, F8> + // clang-format on >; From 2c563fecf76eeecd49a28950ca601ff5ba5a735f Mon Sep 17 00:00:00 2001 From: valarLip <103567126+valarLip@users.noreply.github.com> Date: Wed, 9 Apr 2025 06:16:30 +0800 Subject: [PATCH 033/443] add passthrough for int32->float32 (#2062) --- .../gpu/element/unary_element_wise_operation.hpp | 6 ++++++ 1 file changed, 6 insertions(+) diff --git a/include/ck/tensor_operation/gpu/element/unary_element_wise_operation.hpp b/include/ck/tensor_operation/gpu/element/unary_element_wise_operation.hpp index f602e36e73..672998d811 100644 --- a/include/ck/tensor_operation/gpu/element/unary_element_wise_operation.hpp +++ b/include/ck/tensor_operation/gpu/element/unary_element_wise_operation.hpp @@ -357,6 +357,12 @@ struct PassThrough y = type_convert(x); } + template <> + __host__ __device__ void operator()(float& y, const int32_t& x) const + { + y = type_convert(x); + } + template <> __host__ __device__ void operator()(bhalf_t& y, const bhalf_t& x) const { From 03ce8729fd52ce1e8e8c4290d5d1ea79ec12ffa4 Mon Sep 17 00:00:00 2001 From: MHYang-gh Date: Wed, 9 Apr 2025 06:34:11 +0800 Subject: [PATCH 034/443] Make buffer coherence configurable in tensor view (#2041) * Make buffer coherence configurable in tensor view * Fix clang-format for tensor_view.hpp --- include/ck_tile/core/tensor/tensor_view.hpp | 12 +++++++++--- 1 file changed, 9 insertions(+), 3 deletions(-) diff --git a/include/ck_tile/core/tensor/tensor_view.hpp b/include/ck_tile/core/tensor/tensor_view.hpp index 
336793c5b1..32de227b52 100644 --- a/include/ck_tile/core/tensor/tensor_view.hpp +++ b/include/ck_tile/core/tensor/tensor_view.hpp @@ -411,18 +411,21 @@ struct null_tensor_view }; template CK_TILE_HOST_DEVICE constexpr auto make_tensor_view(DataType* p, const tensor_descriptor& desc) { - auto buffer_view = make_buffer_view(p, desc.get_element_space_size()); + auto buffer_view = + make_buffer_view(p, desc.get_element_space_size()); return tensor_view{buffer_view, desc}; } template {}, number{}); - auto buffer_view = make_buffer_view(p, desc.get_element_space_size()); + auto buffer_view = + make_buffer_view(p, desc.get_element_space_size()); return tensor_view{buffer_view, desc}; } template @@ -458,7 +463,8 @@ make_naive_tensor_view_packed(DataType* p, auto desc = make_naive_tensor_descriptor_packed(lengths, number{}); - auto buffer_view = make_buffer_view(p, desc.get_element_space_size()); + auto buffer_view = + make_buffer_view(p, desc.get_element_space_size()); return tensor_view{buffer_view, desc}; } From 3e6d21adeb33db1319899a3833113c9caf715358 Mon Sep 17 00:00:00 2001 From: Illia Silin <98187287+illsilin@users.noreply.github.com> Date: Wed, 9 Apr 2025 10:06:42 -0700 Subject: [PATCH 035/443] enable gfx115x support (#2065) --- example/CMakeLists.txt | 8 ++++---- include/ck/ck.hpp | 3 ++- include/ck/host_utility/device_prop.hpp | 4 +++- include/ck_tile/core/config.hpp | 3 ++- .../src/tensor_operation_instance/gpu/CMakeLists.txt | 12 ++++++------ test/CMakeLists.txt | 10 +++++----- 6 files changed, 22 insertions(+), 18 deletions(-) diff --git a/example/CMakeLists.txt b/example/CMakeLists.txt index 64ff2a6813..996a543ecc 100644 --- a/example/CMakeLists.txt +++ b/example/CMakeLists.txt @@ -114,14 +114,14 @@ function(add_example_executable EXAMPLE_NAME FILE_NAME) #only continue if there are some source files left on the list if(FILE_NAME) if(FILE_NAME MATCHES "_xdl" AND NOT FILE_NAME MATCHES "_pk_i4") - list(REMOVE_ITEM EX_TARGETS gfx900 gfx906 gfx906:xnack- gfx1030 
gfx1100 gfx1101 gfx1102 gfx1103 gfx1200 gfx1201 gfx10.3-generic gfx11-generic gfx12-generic) + list(REMOVE_ITEM EX_TARGETS gfx900 gfx906 gfx906:xnack- gfx1030 gfx1100 gfx1101 gfx1102 gfx1103 gfx1150 gfx1151 gfx1152 gfx1200 gfx1201 gfx10.3-generic gfx11-generic gfx12-generic) elseif(FILE_NAME MATCHES "_wmma") list(REMOVE_ITEM EX_TARGETS gfx900 gfx906 gfx906:xnack- gfx908:xnack+ gfx908:xnack- gfx90a:xnack+ gfx90a:xnack- gfx908 gfx90a gfx942 gfx1030 gfx950) elseif(FILE_NAME MATCHES "_mx") #only build mx example for gfx950 - list(REMOVE_ITEM EX_TARGETS gfx900 gfx906 gfx906:xnack- gfx908:xnack+ gfx908:xnack- gfx90a:xnack+ gfx90a:xnack- gfx908 gfx90a gfx942 gfx1030 gfx1100 gfx1101 gfx1102 gfx1103 gfx1200 gfx1201 gfx10.3-generic gfx11-generic gfx12-generic) + list(REMOVE_ITEM EX_TARGETS gfx900 gfx906 gfx906:xnack- gfx908:xnack+ gfx908:xnack- gfx90a:xnack+ gfx90a:xnack- gfx908 gfx90a gfx942 gfx1030 gfx1100 gfx1101 gfx1102 gfx1103 gfx1150 gfx1151 gfx1152 gfx1200 gfx1201 gfx10.3-generic gfx11-generic gfx12-generic) elseif(FILE_NAME MATCHES "_pk_i4") #only build these examples for gfx942 and gfx950 message("trimming targets for ${FILE_NAME}") - list(REMOVE_ITEM EX_TARGETS gfx900 gfx906 gfx906:xnack- gfx908:xnack+ gfx908:xnack- gfx90a:xnack+ gfx90a:xnack- gfx908 gfx90a gfx1030 gfx1100 gfx1101 gfx1102 gfx1103 gfx1200 gfx1201 gfx10.3-generic gfx11-generic gfx12-generic) + list(REMOVE_ITEM EX_TARGETS gfx900 gfx906 gfx906:xnack- gfx908:xnack+ gfx908:xnack- gfx90a:xnack+ gfx90a:xnack- gfx908 gfx90a gfx1030 gfx1100 gfx1101 gfx1102 gfx1103 gfx1150 gfx1151 gfx1152 gfx1200 gfx1201 gfx10.3-generic gfx11-generic gfx12-generic) endif() set_source_files_properties(${FILE_NAME} PROPERTIES LANGUAGE HIP) add_executable(${EXAMPLE_NAME} ${FILE_NAME}) @@ -212,7 +212,7 @@ function(add_example_executable_no_testing EXAMPLE_NAME FILE_NAME) #only continue if there are some source files left on the list if(FILE_NAME) if(FILE_NAME MATCHES "_xdl") - list(REMOVE_ITEM EX_TARGETS gfx900 gfx906 
gfx906:xnack- gfx1030 gfx1100 gfx1101 gfx1102 gfx1103 gfx1200 gfx1201 gfx10.3-generic gfx11-generic gfx12-generic) + list(REMOVE_ITEM EX_TARGETS gfx900 gfx906 gfx906:xnack- gfx1030 gfx1100 gfx1101 gfx1102 gfx1103 gfx1150 gfx1151 gfx1152 gfx1200 gfx1201 gfx10.3-generic gfx11-generic gfx12-generic) elseif(FILE_NAME MATCHES "_wmma") list(REMOVE_ITEM EX_TARGETS gfx900 gfx906 gfx906:xnack- gfx908:xnack+ gfx908:xnack- gfx90a:xnack+ gfx90a:xnack- gfx908 gfx90a gfx942 gfx1030 gfx950) endif() diff --git a/include/ck/ck.hpp b/include/ck/ck.hpp index 9d5d5fbc0b..0c2dc799ab 100644 --- a/include/ck/ck.hpp +++ b/include/ck/ck.hpp @@ -65,7 +65,8 @@ #define __gfx103__ #endif #if defined(__gfx1100__) || defined(__gfx1101__) || defined(__gfx1102__) || \ - defined(__gfx1103__) || defined(__gfx11_generic__) + defined(__gfx1103__) || defined(__gfx1150__) || defined(__gfx1151__) || \ + defined(__gfx1152__) || defined(__gfx11_generic__) #define __gfx11__ #endif #if defined(__gfx1200__) || defined(__gfx1201__) || defined(__gfx12_generic__) diff --git a/include/ck/host_utility/device_prop.hpp b/include/ck/host_utility/device_prop.hpp index 3323ab6c7b..5439bbe1f0 100644 --- a/include/ck/host_utility/device_prop.hpp +++ b/include/ck/host_utility/device_prop.hpp @@ -86,7 +86,9 @@ inline bool is_gfx103_supported() inline bool is_gfx11_supported() { return ck::get_device_name() == "gfx1100" || ck::get_device_name() == "gfx1101" || - ck::get_device_name() == "gfx1102" || ck::get_device_name() == "gfx1103"; + ck::get_device_name() == "gfx1102" || ck::get_device_name() == "gfx1103" || + ck::get_device_name() == "gfx1150" || ck::get_device_name() == "gfx1151" || + ck::get_device_name() == "gfx1152"; } inline bool is_gfx12_supported() diff --git a/include/ck_tile/core/config.hpp b/include/ck_tile/core/config.hpp index 978f673346..414509e479 100644 --- a/include/ck_tile/core/config.hpp +++ b/include/ck_tile/core/config.hpp @@ -15,7 +15,8 @@ #define __gfx103__ #endif #if defined(__gfx1100__) || 
defined(__gfx1101__) || defined(__gfx1102__) || \ - defined(__gfx1103__) || defined(__gfx11_generic__) + defined(__gfx1103__) || defined(__gfx1150__) || defined(__gfx1151__) || \ + defined(__gfx1152__) || defined(__gfx11_generic__) #define __gfx11__ #endif #if defined(__gfx1200__) || defined(__gfx1201__) || defined(__gfx12_generic__) diff --git a/library/src/tensor_operation_instance/gpu/CMakeLists.txt b/library/src/tensor_operation_instance/gpu/CMakeLists.txt index a16418ec7e..2542dd236b 100755 --- a/library/src/tensor_operation_instance/gpu/CMakeLists.txt +++ b/library/src/tensor_operation_instance/gpu/CMakeLists.txt @@ -95,26 +95,26 @@ function(add_instance_library INSTANCE_NAME) foreach(source IN LISTS ARGN) set(INST_TARGETS ${SUPPORTED_GPU_TARGETS}) if(source MATCHES "_xdl") - list(REMOVE_ITEM INST_TARGETS gfx900 gfx906 gfx906:xnack- gfx1030 gfx1100 gfx1101 gfx1102 gfx1103 gfx1200 gfx1201 gfx10.3-generic gfx11-generic gfx12-generic) + list(REMOVE_ITEM INST_TARGETS gfx900 gfx906 gfx906:xnack- gfx1030 gfx1100 gfx1101 gfx1102 gfx1103 gfx1150 gfx1151 gfx1152 gfx1200 gfx1201 gfx10.3-generic gfx11-generic gfx12-generic) elseif(source MATCHES "_wmma") list(REMOVE_ITEM INST_TARGETS gfx900 gfx906 gfx906:xnack- gfx908:xnack+ gfx908:xnack- gfx90a:xnack+ gfx90a:xnack- gfx908 gfx90a gfx942 gfx1030 gfx950) elseif(source MATCHES "mha") - list(REMOVE_ITEM INST_TARGETS gfx900 gfx906 gfx906:xnack- gfx908:xnack- gfx908:xnack+ gfx908 gfx1030 gfx1100 gfx1101 gfx1102 gfx1103 gfx1200 gfx1201 gfx10.3-generic gfx11-generic gfx12-generic) + list(REMOVE_ITEM INST_TARGETS gfx900 gfx906 gfx906:xnack- gfx908:xnack- gfx908:xnack+ gfx908 gfx1030 gfx1100 gfx1101 gfx1102 gfx1103 gfx1150 gfx1151 gfx1152 gfx1200 gfx1201 gfx10.3-generic gfx11-generic gfx12-generic) endif() #only build the fp8 gemm instances for gfx90a if the build argument is set, otherwise only build for gfx942/gfx950 if(NOT CK_USE_FP8_ON_UNSUPPORTED_ARCH) if(source MATCHES "gemm_xdl_universal" AND source MATCHES "f8") - 
list(REMOVE_ITEM INST_TARGETS gfx900 gfx906 gfx906:xnack- gfx908:xnack- gfx908:xnack+ gfx90a:xnack+ gfx90a:xnack- gfx908 gfx90a gfx1030 gfx1100 gfx1101 gfx1102 gfx1103 gfx1200 gfx1201 gfx10.3-generic gfx11-generic gfx12-generic) + list(REMOVE_ITEM INST_TARGETS gfx900 gfx906 gfx906:xnack- gfx908:xnack- gfx908:xnack+ gfx90a:xnack+ gfx90a:xnack- gfx908 gfx90a gfx1030 gfx1100 gfx1101 gfx1102 gfx1103 gfx1150 gfx1151 gfx1152 gfx1200 gfx1201 gfx10.3-generic gfx11-generic gfx12-generic) endif() if(source MATCHES "gemm_multiply_multiply" AND source MATCHES "f8") - list(REMOVE_ITEM INST_TARGETS gfx900 gfx906 gfx906:xnack- gfx908:xnack- gfx908:xnack+ gfx90a:xnack+ gfx90a:xnack- gfx908 gfx90a gfx1030 gfx1100 gfx1101 gfx1102 gfx1103 gfx1200 gfx1201 gfx10.3-generic gfx11-generic gfx12-generic) + list(REMOVE_ITEM INST_TARGETS gfx900 gfx906 gfx906:xnack- gfx908:xnack- gfx908:xnack+ gfx90a:xnack+ gfx90a:xnack- gfx908 gfx90a gfx1030 gfx1100 gfx1101 gfx1102 gfx1103 gfx1150 gfx1151 gfx1152 gfx1200 gfx1201 gfx10.3-generic gfx11-generic gfx12-generic) endif() else() if(source MATCHES "gemm_xdl_universal" AND source MATCHES "f8") - list(REMOVE_ITEM INST_TARGETS gfx900 gfx906 gfx906:xnack- gfx908:xnack- gfx908:xnack+ gfx908 gfx1030 gfx1100 gfx1101 gfx1102 gfx1103 gfx1200 gfx1201 gfx10.3-generic gfx11-generic gfx12-generic) + list(REMOVE_ITEM INST_TARGETS gfx900 gfx906 gfx906:xnack- gfx908:xnack- gfx908:xnack+ gfx908 gfx1030 gfx1100 gfx1101 gfx1102 gfx1103 gfx1150 gfx1151 gfx1152 gfx1200 gfx1201 gfx10.3-generic gfx11-generic gfx12-generic) endif() if(source MATCHES "gemm_multiply_multiply" AND source MATCHES "f8") - list(REMOVE_ITEM INST_TARGETS gfx900 gfx906 gfx906:xnack- gfx908:xnack- gfx908:xnack+ gfx908 gfx1030 gfx1100 gfx1101 gfx1102 gfx1103 gfx1200 gfx1201 gfx10.3-generic gfx11-generic gfx12-generic) + list(REMOVE_ITEM INST_TARGETS gfx900 gfx906 gfx906:xnack- gfx908:xnack- gfx908:xnack+ gfx908 gfx1030 gfx1100 gfx1101 gfx1102 gfx1103 gfx1150 gfx1151 gfx1152 gfx1200 gfx1201 
gfx10.3-generic gfx11-generic gfx12-generic) endif() endif() set(offload_targets) diff --git a/test/CMakeLists.txt b/test/CMakeLists.txt index 38fbf5385f..18611d8052 100755 --- a/test/CMakeLists.txt +++ b/test/CMakeLists.txt @@ -101,11 +101,11 @@ function(add_test_executable TEST_NAME) #only continue if there are some source files left on the list if(ARGN) if(ARGN MATCHES "_xdl") - list(REMOVE_ITEM TEST_TARGETS gfx900 gfx906 gfx906:xnack- gfx1030 gfx1100 gfx1101 gfx1102 gfx1103 gfx1200 gfx1201 gfx10.3-generic gfx11-generic gfx12-generic) + list(REMOVE_ITEM TEST_TARGETS gfx900 gfx906 gfx906:xnack- gfx1030 gfx1100 gfx1101 gfx1102 gfx1103 gfx1150 gfx1151 gfx1152 gfx1200 gfx1201 gfx10.3-generic gfx11-generic gfx12-generic) elseif(ARGN MATCHES "_wmma") list(REMOVE_ITEM TEST_TARGETS gfx900 gfx906 gfx906:xnack- gfx908:xnack+ gfx908:xnack- gfx90a:xnack+ gfx90a:xnack- gfx908 gfx90a gfx942 gfx1030 gfx950) elseif(ARGN MATCHES "_smfmac") - list(REMOVE_ITEM TEST_TARGETS gfx900 gfx906 gfx906:xnack- gfx1030 gfx1100 gfx1101 gfx1102 gfx1103 gfx908 gfx90a gfx1200 gfx1201 gfx10.3-generic gfx11-generic gfx12-generic) + list(REMOVE_ITEM TEST_TARGETS gfx900 gfx906 gfx906:xnack- gfx1030 gfx1100 gfx1101 gfx1102 gfx1103 gfx1150 gfx1151 gfx1152 gfx908 gfx90a gfx1200 gfx1201 gfx10.3-generic gfx11-generic gfx12-generic) endif() set_source_files_properties(${ARGN} PROPERTIES LANGUAGE HIP) add_executable(${TEST_NAME} ${ARGN}) @@ -197,13 +197,13 @@ function(add_gtest_executable TEST_NAME) #only continue if there are some source files left on the list if(ARGN) if(ARGN MATCHES "_xdl") - list(REMOVE_ITEM TEST_TARGETS gfx900 gfx906 gfx906:xnack- gfx1030 gfx1100 gfx1101 gfx1102 gfx1103 gfx1200 gfx1201 gfx10.3-generic gfx11-generic gfx12-generic) + list(REMOVE_ITEM TEST_TARGETS gfx900 gfx906 gfx906:xnack- gfx1030 gfx1100 gfx1101 gfx1102 gfx1103 gfx1150 gfx1151 gfx1152 gfx1200 gfx1201 gfx10.3-generic gfx11-generic gfx12-generic) elseif(ARGN MATCHES "_wmma") list(REMOVE_ITEM TEST_TARGETS gfx900 gfx906 
gfx906:xnack- gfx908:xnack+ gfx908:xnack- gfx90a:xnack+ gfx90a:xnack- gfx908 gfx90a gfx942 gfx1030 gfx950) elseif(ARGN MATCHES "_smfmac") - list(REMOVE_ITEM TEST_TARGETS gfx900 gfx906 gfx906:xnack- gfx1030 gfx1100 gfx1101 gfx1102 gfx1103 gfx908 gfx90a gfx1200 gfx1201 gfx10.3-generic gfx11-generic gfx12-generic) + list(REMOVE_ITEM TEST_TARGETS gfx900 gfx906 gfx906:xnack- gfx1030 gfx1100 gfx1101 gfx1102 gfx1103 gfx1150 gfx1151 gfx1152 gfx908 gfx90a gfx1200 gfx1201 gfx10.3-generic gfx11-generic gfx12-generic) elseif(ARGN MATCHES "_mx") #only build mx example for gfx950 - list(REMOVE_ITEM TEST_TARGETS gfx900 gfx906 gfx906:xnack- gfx908:xnack+ gfx908:xnack- gfx90a:xnack+ gfx90a:xnack- gfx908 gfx90a gfx942 gfx1030 gfx1100 gfx1101 gfx1102 gfx1103 gfx1200 gfx1201 gfx10.3-generic gfx11-generic gfx12-generic) + list(REMOVE_ITEM TEST_TARGETS gfx900 gfx906 gfx906:xnack- gfx908:xnack+ gfx908:xnack- gfx90a:xnack+ gfx90a:xnack- gfx908 gfx90a gfx942 gfx1030 gfx1100 gfx1101 gfx1102 gfx1103 gfx1150 gfx1151 gfx1152 gfx1200 gfx1201 gfx10.3-generic gfx11-generic gfx12-generic) endif() set_source_files_properties(${ARGN} PROPERTIES LANGUAGE HIP) add_executable(${TEST_NAME} ${ARGN}) From f14e648e7ca69c161c8910778e50c4c3a9d63f1a Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Juan=20Manuel=20Martinez=20Caama=C3=B1o?= Date: Thu, 10 Apr 2025 09:48:37 +0200 Subject: [PATCH 036/443] Replace inline assembly with builtins in FHMA (#2067) * Replace inline assembly with builtins in FHMA --------- Co-authored-by: illsilin --- .../core/arch/amd_buffer_addressing.hpp | 174 +++++++++++++++--- 1 file changed, 153 insertions(+), 21 deletions(-) diff --git a/include/ck_tile/core/arch/amd_buffer_addressing.hpp b/include/ck_tile/core/arch/amd_buffer_addressing.hpp index 33faa3a18b..5d6d6ce348 100644 --- a/include/ck_tile/core/arch/amd_buffer_addressing.hpp +++ b/include/ck_tile/core/arch/amd_buffer_addressing.hpp @@ -14,6 +14,15 @@ #include "ck_tile/core/utility/bit_cast.hpp" #include 
"ck_tile/core/utility/functional.hpp" +// This attribute gives a hint to the compiler that a branch is likely to be taken. +// Then, the compiler should remove if possible the associated s_cbranch_execz branch that would +// have been generated. +#if __cplusplus >= 202002L +#define LIKELY(x) (x) [[likely]] +#else +#define LIKELY(x) (__builtin_expect(!!(x), 1)) +#endif + namespace ck_tile { // 128 bit SGPRs to supply buffer resource in buffer instructions @@ -58,10 +67,36 @@ template<> struct buffer_load_trait<4 , thread_buffer> { using payloa // TODO: glc/slc/... template struct buffer_load; + +template +struct buffer_load_if; + +template +struct buffer_store; + +template +struct buffer_store_if; + #pragma clang diagnostic push #pragma clang diagnostic ignored "-Wundefined-reinterpret-cast" // TODO: strict aliasing rule seems fail when reinterpret_cast between vector type // (exp_vector_type(xxx)) + +#define HAS_RAW_BUFFER_BUILTINS \ + __has_builtin(__builtin_amdgcn_raw_buffer_load_b32) && \ + __has_builtin(__builtin_amdgcn_make_buffer_rsrc) && \ + __has_builtin(__builtin_amdgcn_raw_buffer_store_b32) + +#if HAS_RAW_BUFFER_BUILTINS +CK_TILE_DEVICE __amdgpu_buffer_rsrc_t cast_to_amdgpu_buffer_rsrc_t(int32x4_t res) +{ + __amdgpu_buffer_rsrc_t as_rsrc; + static_assert(sizeof(res) == sizeof(as_rsrc) && "Size of buffer resource should match"); + memcpy(&as_rsrc, &res, sizeof(res)); + return as_rsrc; +} +#endif + template struct buffer_load<16, pre_nop> { @@ -76,6 +111,11 @@ struct buffer_load<16, pre_nop> { static_assert(sizeof(T) == 16); using mbuf_t = typename impl::buffer_load_trait<16, T>::payload_t; +#if HAS_RAW_BUFFER_BUILTINS + index_t s_offset = i_offset; + reinterpret_cast(value) = __builtin_amdgcn_raw_buffer_load_b128( + cast_to_amdgpu_buffer_rsrc_t(res), v_offset, s_offset, 0); +#else if constexpr(pre_nop) asm volatile("s_nop 4\n" "buffer_load_dwordx4 %0, %1, %2, 0 offen offset:%3" @@ -87,6 +127,7 @@ struct buffer_load<16, pre_nop> : 
"+v"(reinterpret_cast(value)) : "v"(v_offset), "s"(res), "n"(i_offset) : "memory"); +#endif } }; @@ -104,6 +145,11 @@ struct buffer_load<8, pre_nop> { static_assert(sizeof(T) == 8); using mbuf_t = typename impl::buffer_load_trait<8, T>::payload_t; +#if HAS_RAW_BUFFER_BUILTINS + index_t s_offset = i_offset; + reinterpret_cast(value) = __builtin_amdgcn_raw_buffer_load_b64( + cast_to_amdgpu_buffer_rsrc_t(res), v_offset, s_offset, 0); +#else if constexpr(pre_nop) asm volatile("s_nop 4\n" "buffer_load_dwordx2 %0, %1, %2, 0 offen offset:%3" @@ -115,6 +161,7 @@ struct buffer_load<8, pre_nop> : "+v"(reinterpret_cast(value)) : "v"(v_offset), "s"(res), "n"(i_offset) : "memory"); +#endif } }; @@ -132,6 +179,12 @@ struct buffer_load<4, pre_nop> { static_assert(sizeof(T) == 4); using mbuf_t = typename impl::buffer_load_trait<4, T>::payload_t; + +#if HAS_RAW_BUFFER_BUILTINS + index_t s_offset = i_offset; + reinterpret_cast(value) = __builtin_amdgcn_raw_buffer_load_b32( + cast_to_amdgpu_buffer_rsrc_t(res), v_offset, s_offset, 0); +#else if constexpr(pre_nop) asm volatile("s_nop 4\n" "buffer_load_dword %0, %1, %2, 0 offen offset:%3" @@ -143,6 +196,7 @@ struct buffer_load<4, pre_nop> : "+v"(reinterpret_cast(value)) : "v"(v_offset), "s"(res), "n"(i_offset) : "memory"); +#endif } }; @@ -160,6 +214,12 @@ struct buffer_load<2, pre_nop> { static_assert(sizeof(T) == 4); // subdword is buggy, use dword buf and convert manually using mbuf_t = typename impl::buffer_load_trait<2, T>::payload_t; + +#if HAS_RAW_BUFFER_BUILTINS + index_t s_offset = i_offset; + reinterpret_cast(value) = __builtin_amdgcn_raw_buffer_load_b16( + cast_to_amdgpu_buffer_rsrc_t(res), v_offset, s_offset, 0); +#else if constexpr(pre_nop) asm volatile("s_nop 4\n" "buffer_load_ushort %0, %1, %2, 0 offen offset:%3" @@ -171,6 +231,7 @@ struct buffer_load<2, pre_nop> : "+v"(reinterpret_cast(value)) : "v"(v_offset), "s"(res), "n"(i_offset) : "memory"); +#endif } }; @@ -188,6 +249,11 @@ struct buffer_load<1, pre_nop> { 
static_assert(sizeof(T) == 4); using mbuf_t = typename impl::buffer_load_trait<1, T>::payload_t; +#if HAS_RAW_BUFFER_BUILTINS + index_t s_offset = i_offset; + reinterpret_cast(value) = __builtin_amdgcn_raw_buffer_load_b16( + cast_to_amdgpu_buffer_rsrc_t(res), v_offset, s_offset, 0); +#else if constexpr(pre_nop) asm volatile("s_nop 4\n" "buffer_load_ubyte %0, %1, %2, 0 offen offset:%3" @@ -199,12 +265,31 @@ struct buffer_load<1, pre_nop> : "+v"(reinterpret_cast(value)) : "v"(v_offset), "s"(res), "n"(i_offset) : "memory"); +#endif } }; -template -struct buffer_load_if; - +#if HAS_RAW_BUFFER_BUILTINS +template +struct buffer_load_if +{ + template + CK_TILE_DEVICE void operator()(T& value, + int32x4_t res /*buffer resource*/, + index_t v_offset, + index_t s_offset, + index_t i_offset /*max 0xFFF*/, + index_t flag = 0, + bool_constant = {}) + { + if LIKELY(1 <= flag) + { + buffer_load{}( + value, res, v_offset, s_offset, i_offset, flag, bool_constant{}); + } + } +}; +#else template struct buffer_load_if<16, pre_nop> { @@ -214,12 +299,12 @@ struct buffer_load_if<16, pre_nop> index_t v_offset, index_t /*s_offset*/, index_t i_offset /*max 0xFFF*/, - index_t flag = 0, + index_t flag = 0, bool_constant = {}) { static_assert(sizeof(T) == 16); auto saved_exec = __builtin_amdgcn_read_exec(); - using mbuf_t = typename impl::buffer_load_trait<16, T>::payload_t; + using mbuf_t = typename impl::buffer_load_trait<16, T>::payload_t; static_assert(sizeof(mbuf_t) == sizeof(T)); if constexpr(pre_nop) asm volatile("s_nop 4\n" @@ -248,12 +333,12 @@ struct buffer_load_if<8, pre_nop> index_t v_offset, index_t /*s_offset*/, index_t i_offset /*max 0xFFF*/, - index_t flag = 0, + index_t flag = 0, bool_constant = {}) { static_assert(sizeof(T) == 8); auto saved_exec = __builtin_amdgcn_read_exec(); - using mbuf_t = typename impl::buffer_load_trait<8, T>::payload_t; + using mbuf_t = typename impl::buffer_load_trait<8, T>::payload_t; if constexpr(pre_nop) asm volatile("s_nop 4\n" "v_cmpx_le_u32 
exec, 1, %4\n" @@ -281,12 +366,12 @@ struct buffer_load_if<4, pre_nop> index_t v_offset, index_t /*s_offset*/, index_t i_offset /*max 0xFFF*/, - index_t flag = 0, + index_t flag = 0, bool_constant = {}) { static_assert(sizeof(T) == 4); auto saved_exec = __builtin_amdgcn_read_exec(); - using mbuf_t = typename impl::buffer_load_trait<4, T>::payload_t; + using mbuf_t = typename impl::buffer_load_trait<4, T>::payload_t; if constexpr(pre_nop) asm volatile("s_nop 4\n" "v_cmpx_le_u32 exec, 1, %4\n" @@ -314,12 +399,12 @@ struct buffer_load_if<2, pre_nop> index_t v_offset, index_t /*s_offset*/, index_t i_offset /*max 0xFFF*/, - index_t flag = 0, + index_t flag = 0, bool_constant = {}) { static_assert(sizeof(T) == 4); auto saved_exec = __builtin_amdgcn_read_exec(); - using mbuf_t = typename impl::buffer_load_trait<2, T>::payload_t; + using mbuf_t = typename impl::buffer_load_trait<2, T>::payload_t; if constexpr(pre_nop) asm volatile("s_nop 4\n" "v_cmpx_le_u32 exec, 1, %4\n" @@ -347,12 +432,12 @@ struct buffer_load_if<1, pre_nop> index_t v_offset, index_t /*s_offset*/, index_t i_offset /*max 0xFFF*/, - index_t flag = 0, + index_t flag = 0, bool_constant = {}) { static_assert(sizeof(T) == 4); auto saved_exec = __builtin_amdgcn_read_exec(); - using mbuf_t = typename impl::buffer_load_trait<1, T>::payload_t; + using mbuf_t = typename impl::buffer_load_trait<1, T>::payload_t; if constexpr(pre_nop) asm volatile("s_nop 4\n" "v_cmpx_le_u32 exec, 1, %4\n" @@ -370,9 +455,9 @@ struct buffer_load_if<1, pre_nop> : "memory"); } }; +#endif + #pragma clang diagnostic pop // "-Wundefined-reinterpret-cast" -template -struct buffer_store; template <> struct buffer_store<16> @@ -387,10 +472,16 @@ struct buffer_store<16> { static_assert(sizeof(T) == 16); using mbuf_t = fp32x4_t; +#if HAS_RAW_BUFFER_BUILTINS + index_t s_offset = i_offset; + __builtin_amdgcn_raw_buffer_store_b128( + bit_cast(value), cast_to_amdgpu_buffer_rsrc_t(res), v_offset, s_offset, 0); +#else asm 
volatile("buffer_store_dwordx4 %0, %1, %2, 0 offen offset:%3" : : "v"(bit_cast(value)), "v"(v_offset), "s"(res), "n"(i_offset) : "memory"); +#endif } }; @@ -407,10 +498,16 @@ struct buffer_store<8> { static_assert(sizeof(T) == 8); using mbuf_t = fp32x2_t; +#if HAS_RAW_BUFFER_BUILTINS + index_t s_offset = i_offset; + __builtin_amdgcn_raw_buffer_store_b64( + bit_cast(value), cast_to_amdgpu_buffer_rsrc_t(res), v_offset, s_offset, 0); +#else asm volatile("buffer_store_dwordx2 %0, %1, %2, 0 offen offset:%3" : : "v"(bit_cast(value)), "v"(v_offset), "s"(res), "n"(i_offset) : "memory"); +#endif } }; @@ -427,10 +524,16 @@ struct buffer_store<4> { static_assert(sizeof(T) == 4); using mbuf_t = float; +#if HAS_RAW_BUFFER_BUILTINS + index_t s_offset = i_offset; + __builtin_amdgcn_raw_buffer_store_b32( + bit_cast(value), cast_to_amdgpu_buffer_rsrc_t(res), v_offset, s_offset, 0); +#else asm volatile("buffer_store_dword %0, %1, %2, 0 offen offset:%3" : : "v"(bit_cast(value)), "v"(v_offset), "s"(res), "n"(i_offset) : "memory"); +#endif } }; @@ -447,10 +550,16 @@ struct buffer_store<2> { static_assert(sizeof(T) == 2); using mbuf_t = short; +#if HAS_RAW_BUFFER_BUILTINS + index_t s_offset = i_offset; + __builtin_amdgcn_raw_buffer_store_b16( + bit_cast(value), cast_to_amdgpu_buffer_rsrc_t(res), v_offset, s_offset, 0); +#else asm volatile("buffer_store_short %0, %1, %2, 0 offen offset:%3" : : "v"(bit_cast(value)), "v"(v_offset), "s"(res), "n"(i_offset) : "memory"); +#endif } }; @@ -467,16 +576,38 @@ struct buffer_store<1> { static_assert(sizeof(T) == 4); using mbuf_t = float; +#if HAS_RAW_BUFFER_BUILTINS + index_t s_offset = i_offset; + __builtin_amdgcn_raw_buffer_store_b8( + bit_cast(value), cast_to_amdgpu_buffer_rsrc_t(res), v_offset, s_offset, 0); +#else asm volatile("buffer_store_byte %0, %1, %2, 0 offen offset:%3" : : "v"(bit_cast(value)), "v"(v_offset), "s"(res), "n"(i_offset) : "memory"); +#endif } }; +#if HAS_RAW_BUFFER_BUILTINS template -struct buffer_store_if; - +struct 
buffer_store_if +{ + template + CK_TILE_DEVICE void operator()(const T& value, + int32x4_t res /*buffer resource*/, + index_t v_offset, + index_t s_offset, + index_t i_offset /*max 0xFFF*/, + index_t flag = 1) + { + if LIKELY(1 <= flag) + { + buffer_store{}(value, res, v_offset, s_offset, i_offset); + } + } +}; +#else template <> struct buffer_store_if<16> { @@ -490,7 +621,7 @@ struct buffer_store_if<16> { static_assert(sizeof(T) == 16); auto save_exec = __builtin_amdgcn_read_exec(); - using mbuf_t = fp32x4_t; + using mbuf_t = fp32x4_t; asm volatile("v_cmpx_le_u32 exec, 1, %4\n" "buffer_store_dwordx4 %0, %1, %2, 0 offen offset:%3\n" "s_mov_b64 exec %5" @@ -547,7 +678,7 @@ struct buffer_store_if<4> { static_assert(sizeof(T) == 4); auto save_exec = __builtin_amdgcn_read_exec(); - using mbuf_t = float; + using mbuf_t = float; asm volatile("v_cmpx_le_u32 exec, 1, %4\n" "buffer_store_dword %0, %1, %2, 0 offen offset:%3\n" "s_mov_b64 exec %5" @@ -575,7 +706,7 @@ struct buffer_store_if<2> { static_assert(sizeof(T) == 2); auto save_exec = __builtin_amdgcn_read_exec(); - using mbuf_t = short; + using mbuf_t = short; asm volatile("v_cmpx_le_u32 exec, 1, %4\n" "buffer_store_short %0, %1, %2, 0 offen offset:%3\n" "s_mov_b64 exec %5" @@ -603,7 +734,7 @@ struct buffer_store_if<1> { static_assert(sizeof(T) == 4); auto save_exec = __builtin_amdgcn_read_exec(); - using mbuf_t = float; + using mbuf_t = float; asm volatile("v_cmpx_le_u32 exec, 1, %4\n" "buffer_store_byte %0, %1, %2, 0 offen offset:%3\n" "s_mov_b64 exec %5" @@ -617,6 +748,7 @@ struct buffer_store_if<1> : "memory"); } }; +#endif CK_TILE_DEVICE void buffer_load_fence(index_t cnt = 0) { From 5f885d2b7af1e6b2f40eefa9126f58e93e164e6d Mon Sep 17 00:00:00 2001 From: slippedJim Date: Thu, 10 Apr 2025 23:21:13 +0800 Subject: [PATCH 037/443] add fmha fwd splitkv receipt for aiter c++ api (#2068) * add s_randval for c++ api * Fix bug of bias in splitkv --------- Co-authored-by: rocking --- 
example/ck_tile/01_fmha/codegen/ops/fmha_bwd.py | 13 +++++++++++-- example/ck_tile/01_fmha/codegen/ops/fmha_fwd.py | 5 ++--- .../ck_tile/01_fmha/codegen/ops/fmha_fwd_splitkv.py | 12 ++++++++++++ example/ck_tile/01_fmha/generate.py | 4 ++-- .../ops/fmha/kernel/fmha_fwd_splitkv_kernel.hpp | 6 +++--- 5 files changed, 30 insertions(+), 10 deletions(-) diff --git a/example/ck_tile/01_fmha/codegen/ops/fmha_bwd.py b/example/ck_tile/01_fmha/codegen/ops/fmha_bwd.py index 94f89256f9..1e6755c631 100644 --- a/example/ck_tile/01_fmha/codegen/ops/fmha_bwd.py +++ b/example/ck_tile/01_fmha/codegen/ops/fmha_bwd.py @@ -545,10 +545,9 @@ def get_bwd_dq_dk_dv_blobs(kernel_filter : Optional[str], receipt, mask_impl) -> cond &= dpad == dvpad if not cond: continue + # aiter::mha_bwd C++ api integration elif receipt == 600: cond = dtype in ['fp16', 'bf16'] - cond &= mode in ["batch", "group"] - cond &= dropout in ['no', 'dropout_wg32', 'dropout_wg16'] cond &= dpad == dvpad if not cond: continue @@ -689,6 +688,11 @@ def get_bwd_dot_do_o_blobs(kernel_filter : Optional[str], receipt) -> List[FmhaB cond &= mode == "group" if not cond: continue + # aiter::mha_bwd C++ api integration + elif receipt == 600: + cond = dtype in ['fp16', 'bf16'] + if not cond: + continue gen.append(k) return gen @@ -841,6 +845,11 @@ def get_bwd_convert_dq_blobs(kernel_filter : Optional[str], receipt) -> List[Fmh cond &= mode == "group" if not cond: continue + # aiter::mha_bwd C++ api integration + elif receipt == 600: + cond = dtype in ['fp16', 'bf16'] + if not cond: + continue gen.append(k) return gen diff --git a/example/ck_tile/01_fmha/codegen/ops/fmha_fwd.py b/example/ck_tile/01_fmha/codegen/ops/fmha_fwd.py index d978cc1d9b..10a6e5c1d7 100644 --- a/example/ck_tile/01_fmha/codegen/ops/fmha_fwd.py +++ b/example/ck_tile/01_fmha/codegen/ops/fmha_fwd.py @@ -536,10 +536,9 @@ def get_fwd_blobs(kernel_filter : Optional[str], receipt, mask_impl) -> Tuple[Fm cond &= pipeline.F_squant == 'f' if not cond: continue - # Aiter 
aiter::mha_fwd integration - elif receipt == 500: + # aiter::mha_fwd C++ api integration + elif receipt == 600: cond = dtype in ['fp16', 'bf16'] - cond &= mode in ['batch', 'group'] cond &= pipeline.F_vlayout == 'row' cond &= pipeline.F_squant == 'f' if not cond: diff --git a/example/ck_tile/01_fmha/codegen/ops/fmha_fwd_splitkv.py b/example/ck_tile/01_fmha/codegen/ops/fmha_fwd_splitkv.py index c6d1a01792..0dccdf6bd6 100644 --- a/example/ck_tile/01_fmha/codegen/ops/fmha_fwd_splitkv.py +++ b/example/ck_tile/01_fmha/codegen/ops/fmha_fwd_splitkv.py @@ -738,6 +738,13 @@ def get_fwd_splitkv_blobs(kernel_filter : Optional[str], receipt, mask_impl) -> cond &= pipeline.F_squant == 'f' if not cond: continue + # aiter::mha_fwd_splikv C++ api integration + elif receipt == 600: + cond = dtype in ['fp16', 'bf16'] + cond &= pipeline.F_vlayout == 'row' + cond &= pipeline.F_squant == 'f' + if not cond: + continue api_pool.register_traits(k.api_trait()) gen.append(k) @@ -796,6 +803,11 @@ def get_fwd_splitkv_combine_blobs(kernel_filter : Optional[str], receipt) -> Lis cond &= mode == "group" if not cond: continue + # aiter::mha_fwd_splikv C++ api integration + elif receipt == 600: + cond = dtype in ['fp16', 'bf16'] + if not cond: + continue gen.append(k) return gen diff --git a/example/ck_tile/01_fmha/generate.py b/example/ck_tile/01_fmha/generate.py index 0d35db14d4..25931da141 100644 --- a/example/ck_tile/01_fmha/generate.py +++ b/example/ck_tile/01_fmha/generate.py @@ -109,8 +109,8 @@ if __name__ == "__main__": " 100-199: Only generate instance for Aiter(mha_fwd) integration\n" + \ " 200-299: Only generate instance for Aiter(mha_varlen_fwd) integration\n" + \ " 300-399: Only generate instance for Aiter(mha_bwd) integration\n" + \ - " 400-499: Only generate instance for Aiter(mha_varlen_bwd) integration" - + " 400-499: Only generate instance for Aiter(mha_varlen_bwd) integration\n" + \ + " 600-699: Only generate instance for aiter::mha_fwd && aiter::mha_fwd_splitkv && 
aiter::mha_bwd C++ api integration" ) args = parser.parse_args() diff --git a/include/ck_tile/ops/fmha/kernel/fmha_fwd_splitkv_kernel.hpp b/include/ck_tile/ops/fmha/kernel/fmha_fwd_splitkv_kernel.hpp index 143abe8048..ea1762abc1 100644 --- a/include/ck_tile/ops/fmha/kernel/fmha_fwd_splitkv_kernel.hpp +++ b/include/ck_tile/ops/fmha/kernel/fmha_fwd_splitkv_kernel.hpp @@ -95,8 +95,8 @@ struct FmhaFwdSplitKVKernel "w" + _TS_(g1wt::at(ck_tile::number<0>{})) + "x" + _TS_(g1wt::at(ck_tile::number<1>{})) + "x" + _TS_(g1wt::at(ck_tile::number<2>{})) + "_" + (kBlockPerCuInput == -1 ? "" : ("o" + _TS_(kBlockPerCu) + "_")) + _SS_(FmhaPipeline::name) + "_" + "v" + (std::is_same_v ? "r" : "c") + (pn.empty() ? "_npad" : "_" + pn) + - (BiasEnum == BlockAttentionBiasEnum::NO_BIAS ? _SS_("_nbias") : (_SS_("_") + BlockAttentionBiasEnumToStr::name)) + - (kHasMask ? "_" + _SS_(FmhaMask::name) : "_nmask") + (kStoreLSE ? "_lse" : "_nlse" ) + + (BiasEnum == BlockAttentionBiasEnum::NO_BIAS ? _SS_("_nbias") : (_SS_("_") + BlockAttentionBiasEnumToStr::name)) + + (kHasMask ? "_" + _SS_(FmhaMask::name) : "_nmask") + (kStoreLSE ? "_lse" : "_nlse" ) + (kDoFp8StaticQuant ? "_squant" : "_nsquant") + (kIsPagedKV ? 
"_pagedkv" : "_npagedkv" ); #undef _SS_ #undef _TS_ @@ -563,7 +563,7 @@ struct FmhaFwdSplitKVKernel } if constexpr(BiasEnum == BlockAttentionBiasEnum::ELEMENTWISE_BIAS) { - batch_offset_bias = query_start * kargs.stride_bias + key_start; + batch_offset_bias = query_start * kargs.stride_bias; } batch_offset_lse_acc = query_start; From 6c61f4d237a9841c5b5d8b4380eaf9c2af14947e Mon Sep 17 00:00:00 2001 From: jakpiase Date: Fri, 11 Apr 2025 12:18:26 +0200 Subject: [PATCH 038/443] [CK_TILE] Add 2:4 structured sparsity support for fp16 gemm (#1957) * add structured sparsity fp16 support for gemm * added reviewer suggestions * update changelog * update changelog * add reviewers suggestions * Minor fix * clang fix * fix doxygen --- CHANGELOG.md | 1 + example/ck_tile/03_gemm/gemm_utils.hpp | 3 +- example/ck_tile/03_gemm/run_gemm_example.inc | 24 ++-- example/ck_tile/03_gemm/universal_gemm.cpp | 3 +- include/ck_tile/host/fill.hpp | 43 +++++++ .../gemm/pipeline/gemm_pipeline_problem.hpp | 3 +- ...emm_universal_pipeline_ag_bg_cr_policy.hpp | 4 +- .../ops/gemm/pipeline/tile_gemm_traits.hpp | 8 +- include/ck_tile/ops/gemm/warp/warp_gemm.hpp | 13 +- .../gemm/warp/warp_gemm_attribute_smfmac.hpp | 80 ++++++++++++ .../warp/warp_gemm_attribute_smfmac_impl.hpp | 114 ++++++++++++++++++ .../ops/gemm/warp/warp_gemm_dispatcher.hpp | 15 ++- .../ops/gemm/warp/warp_gemm_smfmac_impl.hpp | 110 +++++++++++++++++ 13 files changed, 401 insertions(+), 20 deletions(-) create mode 100644 include/ck_tile/ops/gemm/warp/warp_gemm_attribute_smfmac.hpp create mode 100644 include/ck_tile/ops/gemm/warp/warp_gemm_attribute_smfmac_impl.hpp create mode 100644 include/ck_tile/ops/gemm/warp/warp_gemm_smfmac_impl.hpp diff --git a/CHANGELOG.md b/CHANGELOG.md index 49ef2998eb..e3d7971c71 100644 --- a/CHANGELOG.md +++ b/CHANGELOG.md @@ -13,6 +13,7 @@ Documentation for Composable Kernel available at [https://rocm.docs.amd.com/proj * Added support for GKCYX layout for grouped convolution backward weight 
(NGCHW/GKCYX/NGKHW). * Added support for GKCYX layout for grouped convolution backward data (NGCHW/GKCYX/NGKHW). * Added support for Stream-K version of mixed fp8/bf16 GEMM +* Added support for FP16 2:4 structured sparsity to universal GEMM. ### Optimized diff --git a/example/ck_tile/03_gemm/gemm_utils.hpp b/example/ck_tile/03_gemm/gemm_utils.hpp index 3254a407fd..973006196b 100644 --- a/example/ck_tile/03_gemm/gemm_utils.hpp +++ b/example/ck_tile/03_gemm/gemm_utils.hpp @@ -93,7 +93,8 @@ struct GemmConfig static constexpr bool PermuteA = false; static constexpr bool PermuteB = false; - static constexpr bool TransposeC = false; + static constexpr bool TransposeC = false; + static constexpr bool UseStructuredSparsity = false; static constexpr int kBlockPerCu = 1; static constexpr ck_tile::index_t TileParitionerGroupNum = 8; diff --git a/example/ck_tile/03_gemm/run_gemm_example.inc b/example/ck_tile/03_gemm/run_gemm_example.inc index c3b4ec609c..b4ea5d22c0 100644 --- a/example/ck_tile/03_gemm/run_gemm_example.inc +++ b/example/ck_tile/03_gemm/run_gemm_example.inc @@ -55,7 +55,8 @@ void permute_tensor_b(Tensor& tensor) ALayout, BLayout, CLayout, - GemmConfig::TransposeC>; + GemmConfig::TransposeC, + GemmConfig::UseStructuredSparsity>; using UniversalGemmProblem = ck_tile::UniversalGemmPipelineProblem(flop) / 1.E9 / ave_time; float gb_per_sec = num_byte / 1.E6 / ave_time; - std::cout << "Run Gemm kernel with M =" << M << " N =" << N << " K =" << K - << " StrideA =" << stride_A << " StrideB =" << stride_B << " StrideC =" << stride_C - << " A_Layout =" << ALayout::name << " B_Layout =" << BLayout::name - << " C_Layout =" << CLayout::name << " A Type = " << DataTypeTraits::name - << " B Type = " << DataTypeTraits::name - << " C Type = " << DataTypeTraits::name << " : " << ave_time << " ms, " - << tflops << " TFlops, " << gb_per_sec << " GB/s, " << std::endl; + std::cout << "Run Gemm kernel with M=" << M << " N=" << N << " K=" << K + << " StrideA=" << stride_A << " 
StrideB=" << stride_B << " StrideC=" << stride_C + << " A_Layout=" << ALayout::name << " B_Layout =" << BLayout::name + << " C_Layout=" << CLayout::name << " A_Type=" << DataTypeTraits::name + << " B_Type=" << DataTypeTraits::name + << " C_Type=" << DataTypeTraits::name + << " StructuredSparsity=" << (GemmConfig::UseStructuredSparsity ? "on" : "off") + << " : " << ave_time << " ms, " << tflops << " TFlops, " << gb_per_sec << " GB/s, " + << std::endl; return ave_time; } @@ -259,6 +262,11 @@ int run_gemm_example_with_layouts(int argc, b_k_n.SetZero(); } + if(GemmConfig::UseStructuredSparsity) + { + ck_tile::AdjustToStructuredSparsity{}(a_m_k); + } + ck_tile::DeviceMem a_m_k_dev_buf(a_m_k.get_element_space_size_in_bytes()); ck_tile::DeviceMem b_k_n_dev_buf(b_k_n.get_element_space_size_in_bytes()); ck_tile::DeviceMem c_m_n_dev_buf(c_m_n_dev_result.get_element_space_size_in_bytes()); diff --git a/example/ck_tile/03_gemm/universal_gemm.cpp b/example/ck_tile/03_gemm/universal_gemm.cpp index eef8d3b60e..2ba16ca89d 100644 --- a/example/ck_tile/03_gemm/universal_gemm.cpp +++ b/example/ck_tile/03_gemm/universal_gemm.cpp @@ -46,7 +46,8 @@ float gemm_calc(const ck_tile::GemmHostArgs& args, const ck_tile::stream_config& ALayout, BLayout, CLayout, - GemmConfig::TransposeC>; + GemmConfig::TransposeC, + GemmConfig::UseStructuredSparsity>; using GemmPipelineProblem = ck_tile::GemmPipelineProblem; diff --git a/include/ck_tile/host/fill.hpp b/include/ck_tile/host/fill.hpp index 006026470b..d90c0cf6cf 100644 --- a/include/ck_tile/host/fill.hpp +++ b/include/ck_tile/host/fill.hpp @@ -364,6 +364,49 @@ struct FillConstant } }; +//---------------------------------------------------------------------------------------------- +/// @brief Transforms given input to fit 2:4 structured sparsity pattern so +/// every subgroup of 4 elements contain at most 2 non-zero elements +template +struct AdjustToStructuredSparsity +{ + size_t start{0}; + // masks represent all valid 2:4 structured sparsity 
permutations + // clang-format off + static constexpr int32_t masks[] = {0, 0, 1, 1, + 0, 1, 0, 1, + 0, 1, 1, 0, + 1, 0, 0, 1, + 1, 0, 1, 0, + 1, 1, 0, 0, + 0, 0, 0, 1, + 0, 0, 1, 0, + 0, 1, 0, 0, + 1, 0, 0, 0}; + // clang-format on + + template + void operator()(ForwardIter first, ForwardIter last) const + { + std::transform(first, last, first, [=, index = start](T val) mutable { + auto tmp = val * masks[index % (sizeof(masks) / sizeof(int32_t))]; + index += 1; + + return type_convert(tmp); + }); + } + + template + auto operator()(ForwardRange&& range) const + -> std::void_t()( + std::begin(std::forward(range)), + std::end(std::forward(range))))> + { + (*this)(std::begin(std::forward(range)), + std::end(std::forward(range))); + } +}; + template struct FillTrigValue { diff --git a/include/ck_tile/ops/gemm/pipeline/gemm_pipeline_problem.hpp b/include/ck_tile/ops/gemm/pipeline/gemm_pipeline_problem.hpp index f833ccc849..cba3677332 100644 --- a/include/ck_tile/ops/gemm/pipeline/gemm_pipeline_problem.hpp +++ b/include/ck_tile/ops/gemm/pipeline/gemm_pipeline_problem.hpp @@ -194,7 +194,8 @@ struct UniversalGemmPipelineProblem static constexpr auto HasHotLoop = HasHotLoop_; static constexpr auto TailNum = TailNum_; - static constexpr bool TransposeC = Traits::TransposeC; + static constexpr bool TransposeC = Traits::TransposeC; + static constexpr bool UseStructuredSparsity = Traits::UseStructuredSparsity; }; } // namespace ck_tile diff --git a/include/ck_tile/ops/gemm/pipeline/gemm_universal_pipeline_ag_bg_cr_policy.hpp b/include/ck_tile/ops/gemm/pipeline/gemm_universal_pipeline_ag_bg_cr_policy.hpp index c504a51ad0..b555cf75e0 100644 --- a/include/ck_tile/ops/gemm/pipeline/gemm_universal_pipeline_ag_bg_cr_policy.hpp +++ b/include/ck_tile/ops/gemm/pipeline/gemm_universal_pipeline_ag_bg_cr_policy.hpp @@ -580,7 +580,9 @@ struct UniversalGemmPipelineAgBgCrPolicy WarpTile::at(I0), WarpTile::at(I1), WarpTile::at(I2), - Problem::TransposeC>; + Problem::TransposeC, + false, + 
Problem::UseStructuredSparsity>; using BlockGemmPolicy = BlockGemmASmemBSmemCRegV1CustomPolicy + bool TransposeC_ = false, + bool UseStructuredSparsity_ = false> struct TileGemmUniversalTraits { static constexpr bool kPadM = kPadM_; @@ -49,7 +50,8 @@ struct TileGemmUniversalTraits using BLayout = BLayout_; using CLayout = CLayout_; - static constexpr bool TransposeC = TransposeC_; + static constexpr bool TransposeC = TransposeC_; + static constexpr bool UseStructuredSparsity = UseStructuredSparsity_; }; } // namespace ck_tile diff --git a/include/ck_tile/ops/gemm/warp/warp_gemm.hpp b/include/ck_tile/ops/gemm/warp/warp_gemm.hpp index 1fd12973f6..33f3dde256 100644 --- a/include/ck_tile/ops/gemm/warp/warp_gemm.hpp +++ b/include/ck_tile/ops/gemm/warp/warp_gemm.hpp @@ -1,5 +1,5 @@ // SPDX-License-Identifier: MIT -// Copyright (c) 2018-2023, Advanced Micro Devices, Inc. All rights reserved. +// Copyright (c) 2018-2025, Advanced Micro Devices, Inc. All rights reserved. #pragma once @@ -7,6 +7,9 @@ #include "ck_tile/ops/gemm/warp/warp_gemm_impl.hpp" #include "ck_tile/ops/gemm/warp/warp_gemm_attribute_mfma.hpp" +#include "ck_tile/ops/gemm/warp/warp_gemm_smfmac_impl.hpp" +#include "ck_tile/ops/gemm/warp/warp_gemm_attribute_smfmac.hpp" + namespace ck_tile { // fp16 @@ -64,6 +67,14 @@ using WarpGemmMfmaF16F16F32M64N4K16 = WarpGemmImpl, 4>>; +// fp16 2:4 structured sparsity + +using WarpGemmSmfmacF16F16F32M32N32K16 = WarpGemmSmfmacImpl>>; + +using WarpGemmSmfmacF16F16F32M16N16K32 = WarpGemmSmfmacImpl>>; + // bf16 using WarpGemmMfmaBf16Bf16F32M32N32K8 = WarpGemmImpl< diff --git a/include/ck_tile/ops/gemm/warp/warp_gemm_attribute_smfmac.hpp b/include/ck_tile/ops/gemm/warp/warp_gemm_attribute_smfmac.hpp new file mode 100644 index 0000000000..adf548aaca --- /dev/null +++ b/include/ck_tile/ops/gemm/warp/warp_gemm_attribute_smfmac.hpp @@ -0,0 +1,80 @@ +#include "ck_tile/core.hpp" +#include "ck_tile/ops/gemm/warp/warp_gemm_attribute_smfmac_impl.hpp" + +namespace ck_tile { + +/** + * 
@brief Class describing structured sparsity mfma instructions. + * + * @paragraph Overview "Overview" + * Currently only 2:4 structured sparsity is supported, which is based on requirement that in every + * groups of four continuous elements there are at most two non-zero, which results in processing + * only half of elements in smfmac instruction. Because of structured sparsity A vector in smfmac + * instruction will be smaller than B vector by the factor of CompressionRatio. The indexes of + * non-zero elements are stored in `index` which is an additional parameter to assembly instruction. + * Every pair of two bit indexes are containing information about which two elements in current + * group of 4 values are non-zero and should be used inside smfmac instruction. Structured sparsity + * format is supported only for A matrix for now. + */ +template +struct WarpGemmAttributeSmfmac +{ + using Impl = remove_cvref_t; + + using ADataType = typename Impl::ADataType; + using BDataType = typename Impl::BDataType; + using IdxDataType = typename Impl::IdxDataType; + using CDataType = typename Impl::CDataType; + + using AVecType = typename Impl::AVecType; + using BVecType = typename Impl::BVecType; + using CVecType = typename Impl::CVecType; + + static constexpr index_t kM = Impl::kM; + static constexpr index_t kN = Impl::kN; + static constexpr index_t kK = Impl::kK; + static constexpr index_t kKPerThread = Impl::kABKPerLane; + static constexpr index_t kCompressionRatio = Impl::CompressionRatio; + + CK_TILE_HOST_DEVICE static constexpr auto get_num_of_access() { return 1; } + + static_assert(Impl::kAMBlock == 1 && Impl::kBNBlock == 1, + "Multi-block WarpGemmAttributeSmfmacImpl is not supported"); + + using AWarpDstrEncoding = tile_distribution_encoding< + sequence<>, + tuple, sequence>, + tuple>, + tuple>, + sequence<2>, + sequence<1>>; + + using BWarpDstrEncoding = tile_distribution_encoding< + sequence<>, + tuple, sequence>, + tuple>, + tuple>, + sequence<2>, + 
sequence<1>>; + + using CWarpDstrEncoding = tile_distribution_encoding< + sequence<>, + tuple, + sequence>, + tuple>, + tuple>, + sequence<1, 1>, + sequence<0, 2>>; + + // c_vec += a_vec * b_vec[idx] + template + CK_TILE_DEVICE void operator()(CVecType& c_vec, + const AVecType& a_vec, + const BVecType& b_vec, + const int32_t& idx, + bool_constant = {}) const + { + Impl{}(c_vec, a_vec, b_vec, idx, bool_constant{}); + } +}; +} // namespace ck_tile diff --git a/include/ck_tile/ops/gemm/warp/warp_gemm_attribute_smfmac_impl.hpp b/include/ck_tile/ops/gemm/warp/warp_gemm_attribute_smfmac_impl.hpp new file mode 100644 index 0000000000..97fd2a8742 --- /dev/null +++ b/include/ck_tile/ops/gemm/warp/warp_gemm_attribute_smfmac_impl.hpp @@ -0,0 +1,114 @@ +// SPDX-License-Identifier: MIT +// Copyright (c) 2025, Advanced Micro Devices, Inc. All rights reserved. + +#pragma once + +#include "warp_gemm_attribute_mfma_impl.hpp" + +namespace ck_tile { + +// fp16 2:4 structured sparsity + +template +struct WarpGemmAttributeSmfmacImplF16F16F32M32N32K16 +{ + static constexpr WGAttrCtlEnum Ctrl = Ctrl_; + using ADataType = fp16_t; + using BDataType = fp16_t; + using IdxDataType = int32_t; + using CDataType = float; + + using AVecType = ext_vector_t; + using BVecType = ext_vector_t; + using CVecType = ext_vector_t; + + static constexpr index_t kM = 32; + static constexpr index_t kN = 32; + static constexpr index_t kK = 16; + + static constexpr index_t kAMBlock = 1; + static constexpr index_t kBNBlock = 1; + + static constexpr index_t kAMLane = 32; + static constexpr index_t kBNLane = 32; + static constexpr index_t kABKLane = 2; + static constexpr index_t kABKPerLane = 8; + + static constexpr index_t kCMLane = 2; + static constexpr index_t kCNLane = 32; + static constexpr index_t kCM0PerLane = 4; + static constexpr index_t kCM1PerLane = 4; + + static constexpr index_t CompressionRatio = 2; + + // c_vec += a_vec * b_vec[idx] + template + CK_TILE_DEVICE void operator()(CVecType& c_vec, + const 
AVecType& a_vec, + const BVecType& b_vec, + const int32_t& idx, + bool_constant = {}) const + { +#if defined(__gfx9__) + c_vec = __builtin_amdgcn_smfmac_f32_32x32x16_f16(a_vec, b_vec, c_vec, idx, 0, 0); +#else + ck_tile::ignore = c_vec; + ck_tile::ignore = a_vec; + ck_tile::ignore = b_vec; + ck_tile::ignore = idx; +#endif + } +}; + +template +struct WarpGemmAttributeSmfmacImplF16F16F32M16N16K32 +{ + static constexpr WGAttrCtlEnum Ctrl = Ctrl_; + using ADataType = fp16_t; + using BDataType = fp16_t; + using IdxDataType = int32_t; + using CDataType = float; + + using AVecType = ext_vector_t; + using BVecType = ext_vector_t; + using CVecType = ext_vector_t; + + static constexpr index_t kM = 16; + static constexpr index_t kN = 16; + static constexpr index_t kK = 32; + + static constexpr index_t kAMBlock = 1; + static constexpr index_t kBNBlock = 1; + + static constexpr index_t kAMLane = 16; + static constexpr index_t kBNLane = 16; + static constexpr index_t kABKLane = 4; + static constexpr index_t kABKPerLane = 8; + + static constexpr index_t kCMLane = 4; + static constexpr index_t kCNLane = 16; + static constexpr index_t kCM0PerLane = 1; + static constexpr index_t kCM1PerLane = 4; + + static constexpr index_t CompressionRatio = 2; + + // c_vec += a_vec * b_vec[idx] + template + CK_TILE_DEVICE void operator()(CVecType& c_vec, + const AVecType& a_vec, + const BVecType& b_vec, + const int32_t& idx, + bool_constant = {}) const + { +#if defined(__gfx9__) + c_vec = __builtin_amdgcn_smfmac_f32_16x16x32_f16(a_vec, b_vec, c_vec, idx, 0, 0); +#else + ck_tile::ignore = c_vec; + ck_tile::ignore = a_vec; + ck_tile::ignore = b_vec; + ck_tile::ignore = idx; +#endif + } +}; + +} // namespace ck_tile diff --git a/include/ck_tile/ops/gemm/warp/warp_gemm_dispatcher.hpp b/include/ck_tile/ops/gemm/warp/warp_gemm_dispatcher.hpp index 9c319b5e5f..6320b33598 100644 --- a/include/ck_tile/ops/gemm/warp/warp_gemm_dispatcher.hpp +++ b/include/ck_tile/ops/gemm/warp/warp_gemm_dispatcher.hpp @@ 
-1,5 +1,5 @@ // SPDX-License-Identifier: MIT -// Copyright (c) 2018-2023, Advanced Micro Devices, Inc. All rights reserved. +// Copyright (c) 2018-2025, Advanced Micro Devices, Inc. All rights reserved. #pragma once @@ -16,7 +16,8 @@ template + bool SwizzleA = false, + bool UseStructuredSparsity = false> struct WarpGemmMfmaDispatcher; // clang-format off @@ -35,6 +36,10 @@ template<> struct WarpGemmMfmaDispatcher struct WarpGemmMfmaDispatcher { using Type = WarpGemmMfmaF16F16F32M32N32K8SwizzleA; }; template<> struct WarpGemmMfmaDispatcher { using Type = WarpGemmMfmaF16F16F32M32N32K16SwizzleA; }; +// fp16 2:4 structural sparsity +template<> struct WarpGemmMfmaDispatcher { using Type = WarpGemmSmfmacF16F16F32M32N32K16; }; +template<> struct WarpGemmMfmaDispatcher { using Type = WarpGemmSmfmacF16F16F32M16N16K32; }; + // bf16 template<> struct WarpGemmMfmaDispatcher { using Type = WarpGemmMfmaBf16Bf16F32M32N32K8; }; template<> struct WarpGemmMfmaDispatcher { using Type = WarpGemmMfmaBf16Bf16F32M32N32K8TransposedCDistribution; }; @@ -70,7 +75,8 @@ template + bool SwizzleA = false, + bool UseStructuredSparsity = false> using WarpGemmMfmaDispatcher = typename impl::WarpGemmMfmaDispatcher::Type; + SwizzleA, + UseStructuredSparsity>::Type; } // namespace ck_tile diff --git a/include/ck_tile/ops/gemm/warp/warp_gemm_smfmac_impl.hpp b/include/ck_tile/ops/gemm/warp/warp_gemm_smfmac_impl.hpp new file mode 100644 index 0000000000..9e028ddab0 --- /dev/null +++ b/include/ck_tile/ops/gemm/warp/warp_gemm_smfmac_impl.hpp @@ -0,0 +1,110 @@ +// SPDX-License-Identifier: MIT +// Copyright (c) 2018-2025, Advanced Micro Devices, Inc. All rights reserved. 
+ +#pragma once + +#include "ck_tile/core.hpp" +namespace ck_tile { + +template +struct WarpGemmSmfmacImpl +{ + using WarpGemmAttribute = remove_cvref_t; + + static constexpr index_t kM = WarpGemmAttribute::kM; + static constexpr index_t kN = WarpGemmAttribute::kN; + static constexpr index_t kK = WarpGemmAttribute::kK; + /// @brief The number of elements in K dimension processed by single thread in wavefront. + /// + /// @note Note that WarpGemm may run MFMA instruction multiple times (on different K). + /// In such situation this value reflects this fact. + static constexpr index_t kKPerThread = WarpGemmAttribute::kKPerThread; + + using ADataType = typename WarpGemmAttribute::ADataType; + using BDataType = typename WarpGemmAttribute::BDataType; + using CDataType = typename WarpGemmAttribute::CDataType; + + using AWarpDstrEncoding = typename WarpGemmAttribute::AWarpDstrEncoding; + using BWarpDstrEncoding = typename WarpGemmAttribute::BWarpDstrEncoding; + using CWarpDstrEncoding = typename WarpGemmAttribute::CWarpDstrEncoding; + + using AWarpDstr = remove_cvref_t; + using BWarpDstr = remove_cvref_t; + using CWarpDstr = remove_cvref_t; + + using AWarpTensor = static_distributed_tensor; + using BWarpTensor = static_distributed_tensor; + using CWarpTensor = static_distributed_tensor; + + CK_TILE_HOST_DEVICE static constexpr auto get_num_of_access() + { + return WarpGemmAttribute_::get_num_of_access(); + } + + //---------------------------------------------------------------------------------------------- + /// @brief Compress A vector for 2:4 structured sparsity instruction by moving all non-zero + /// elements into lower part of a_vec to half its effective size. + /// + /// @param a_vec Vector to be compressed. 
+ /// + /// @return Four 2-bit indexes of non-zero elements locations + /// + template + CK_TILE_DEVICE int32_t compress_a(AVec& a_vec) const + { + int32_t idx = 0b11101110; + + static_for<0, 2, 1>{}([&](auto i) { + ADataType nonzero_elems[2] = {a_vec[i * 4 + 2], a_vec[i * 4 + 3]}; + int32_t non_zero_pos = 0; + + static_for<0, 3, 1>{}([&](auto j) { + if(a_vec[i * 4 + j] != 0.0f) + { + nonzero_elems[non_zero_pos] = a_vec[i * 4 + j]; + idx &= ~(0b11 << 2 * (i * 2 + non_zero_pos)); + idx |= j << 2 * (i * 2 + non_zero_pos); + ++non_zero_pos; + } + }); + a_vec[i * 2] = nonzero_elems[0]; + a_vec[i * 2 + 1] = nonzero_elems[1]; + }); + + return idx; + } + + template + CK_TILE_DEVICE void + operator()(CTensor& c, const ATensor& a, const BTensor& b, bool_constant = {}) const + { + static_assert(detail::is_similiar_distributed_tensor_v && + detail::is_similiar_distributed_tensor_v && + detail::is_similiar_distributed_tensor_v); + constexpr auto CompressionRatio = WarpGemmAttribute::kCompressionRatio; + + using AVec = ext_vector_t; + using AVecCompressed = + ext_vector_t; + using BVec = ext_vector_t; + using CVec = ext_vector_t; + + constexpr auto I0 = number<0>{}; + + auto a_vec = a.get_thread_buffer().template get_as()[I0]; + const auto b_vec = b.get_thread_buffer().template get_as()[I0]; + auto c_vec = c.get_thread_buffer().template get_as()[I0]; + + const int32_t idx = compress_a(a_vec); + + // @TODO can we simply set a_vec_pruned to a_vec[0:3]? 
+ const AVecCompressed a_vec_pruned = {a_vec[0], a_vec[1], a_vec[2], a_vec[3]}; + + // c_vec += a_vec * b_vec[idx] + WarpGemmAttribute{}(c_vec, a_vec_pruned, b_vec, idx, bool_constant{}); + + c.get_thread_buffer().template set_as(I0, c_vec); + } +}; + +} // namespace ck_tile From 74fda2e796fbdce6688882347c12a3710eeef250 Mon Sep 17 00:00:00 2001 From: Muhammed Emin Ozturk Date: Fri, 11 Apr 2025 10:17:29 -0700 Subject: [PATCH 039/443] CkProfiler StreamK GemmUniversal Fix and Split Gemm_universal Test Redo PR #2044 (#2070) * fix and split gemm_universal test * Update test_gemm_universal_streamk_ut_cases_fp8.inc --- .../profile_gemm_universal_streamk_impl.hpp | 2 +- test/gemm_universal/CMakeLists.txt | 15 ++- ... => test_gemm_universal_ut_cases_bf16.inc} | 60 +++------- .../test_gemm_universal_ut_cases_fp16.inc | 113 ++++++++++++++++++ .../test_gemm_universal_ut_cases_fp8.inc | 113 ++++++++++++++++++ ...l.cpp => test_gemm_universal_xdl_bf16.cpp} | 34 ++---- .../test_gemm_universal_xdl_fp16.cpp | 82 +++++++++++++ .../test_gemm_universal_xdl_fp8.cpp | 71 +++++++++++ ...t_gemm_universal_streamk_ut_cases_fp16.inc | 28 ----- ...st_gemm_universal_streamk_ut_cases_fp8.inc | 28 ----- .../test_gemm_universal_streamk_util.hpp | 12 +- 11 files changed, 423 insertions(+), 135 deletions(-) mode change 100644 => 100755 profiler/include/profiler/profile_gemm_universal_streamk_impl.hpp mode change 100644 => 100755 test/gemm_universal/CMakeLists.txt rename test/gemm_universal/{test_gemm_universal_ut_cases.inc => test_gemm_universal_ut_cases_bf16.inc} (75%) create mode 100644 test/gemm_universal/test_gemm_universal_ut_cases_fp16.inc create mode 100644 test/gemm_universal/test_gemm_universal_ut_cases_fp8.inc rename test/gemm_universal/{test_gemm_universal_xdl.cpp => test_gemm_universal_xdl_bf16.cpp} (61%) create mode 100644 test/gemm_universal/test_gemm_universal_xdl_fp16.cpp create mode 100644 test/gemm_universal/test_gemm_universal_xdl_fp8.cpp diff --git 
a/profiler/include/profiler/profile_gemm_universal_streamk_impl.hpp b/profiler/include/profiler/profile_gemm_universal_streamk_impl.hpp old mode 100644 new mode 100755 index d145ab1766..e625fae808 --- a/profiler/include/profiler/profile_gemm_universal_streamk_impl.hpp +++ b/profiler/include/profiler/profile_gemm_universal_streamk_impl.hpp @@ -166,7 +166,7 @@ bool profile_gemm_universal_streamk_impl(int do_verification, 0, 1, 2, 3, 4}; // 0: Data Parallel (DP) mode (Stream-K OFF), 1: 1-tile Stream-K+ DP, // 2:2-tile Stream-K + DP - if(Grid_size != -1) + if(Grid_size == -1) { grid_size_list = {Grid_size}; } diff --git a/test/gemm_universal/CMakeLists.txt b/test/gemm_universal/CMakeLists.txt old mode 100644 new mode 100755 index 4aab6323cc..cf5c68e220 --- a/test/gemm_universal/CMakeLists.txt +++ b/test/gemm_universal/CMakeLists.txt @@ -1,4 +1,15 @@ -add_gtest_executable(test_gemm_universal test_gemm_universal_xdl.cpp) +add_gtest_executable(test_gemm_universal_fp16 test_gemm_universal_xdl_fp16.cpp) if(result EQUAL 0) - target_link_libraries(test_gemm_universal PRIVATE utility device_gemm_universal_instance) + target_link_libraries(test_gemm_universal_fp16 PRIVATE utility device_gemm_universal_instance) endif() + +add_gtest_executable(test_gemm_universal_fp8 test_gemm_universal_xdl_fp8.cpp) +if(result EQUAL 0) + target_link_libraries(test_gemm_universal_fp8 PRIVATE utility device_gemm_universal_instance) +endif() + +add_gtest_executable(test_gemm_universal_bf16 test_gemm_universal_xdl_bf16.cpp) +if(result EQUAL 0) + target_link_libraries(test_gemm_universal_bf16 PRIVATE utility device_gemm_universal_instance) +endif() + diff --git a/test/gemm_universal/test_gemm_universal_ut_cases.inc b/test/gemm_universal/test_gemm_universal_ut_cases_bf16.inc similarity index 75% rename from test/gemm_universal/test_gemm_universal_ut_cases.inc rename to test/gemm_universal/test_gemm_universal_ut_cases_bf16.inc index 9a21666856..8a6c672a9f 100644 --- 
a/test/gemm_universal/test_gemm_universal_ut_cases.inc +++ b/test/gemm_universal/test_gemm_universal_ut_cases_bf16.inc @@ -1,6 +1,6 @@ #pragma once -TYPED_TEST(TestGemmUniversal_MK_KN, SmallM) +TYPED_TEST(TestGemmUniversal_BF16_MK_KN, SmallM) { std::vector Ms{1, 2, 3, 4, 5, 6}; constexpr int N = 512; @@ -14,7 +14,7 @@ TYPED_TEST(TestGemmUniversal_MK_KN, SmallM) this->Run(M, N, K, StrideA, StrideB, StrideC); } -TYPED_TEST(TestGemmUniversal_MK_NK, SmallM) +TYPED_TEST(TestGemmUniversal_BF16_MK_NK, SmallM) { std::vector Ms{1, 2, 3, 4, 5, 6}; constexpr int N = 512; @@ -28,7 +28,7 @@ TYPED_TEST(TestGemmUniversal_MK_NK, SmallM) this->Run(M, N, K, StrideA, StrideB, StrideC); } -TYPED_TEST(TestGemmUniversal_KM_KN, SmallM) +TYPED_TEST(TestGemmUniversal_BF16_KM_KN, SmallM) { std::vector Ms{1, 2, 3, 4, 5, 6}; constexpr int N = 512; @@ -44,7 +44,7 @@ TYPED_TEST(TestGemmUniversal_KM_KN, SmallM) } } -TYPED_TEST(TestGemmUniversal_KM_NK, SmallM) +TYPED_TEST(TestGemmUniversal_BF16_KM_NK, SmallM) { std::vector Ms{1, 2, 3, 4, 5, 6}; constexpr int N = 512; @@ -60,7 +60,7 @@ TYPED_TEST(TestGemmUniversal_KM_NK, SmallM) } } -TYPED_TEST(TestGemmUniversal_MK_KN, MidLargeM) +TYPED_TEST(TestGemmUniversal_BF16_MK_KN, MidLargeM) { std::vector Ms{127, 255, 312, 799, 1573}; constexpr int N = 512; @@ -74,7 +74,7 @@ TYPED_TEST(TestGemmUniversal_MK_KN, MidLargeM) this->Run(M, N, K, StrideA, StrideB, StrideC); } -TYPED_TEST(TestGemmUniversal_MK_NK, MidLargeM) +TYPED_TEST(TestGemmUniversal_BF16_MK_NK, MidLargeM) { std::vector Ms{127, 255, 312, 799, 1573}; constexpr int N = 512; @@ -88,7 +88,7 @@ TYPED_TEST(TestGemmUniversal_MK_NK, MidLargeM) this->Run(M, N, K, StrideA, StrideB, StrideC); } -TYPED_TEST(TestGemmUniversal_KM_KN, MidLargeM) +TYPED_TEST(TestGemmUniversal_BF16_KM_KN, MidLargeM) { std::vector Ms{127, 255, 312, 799, 1573}; constexpr int N = 512; @@ -104,7 +104,7 @@ TYPED_TEST(TestGemmUniversal_KM_KN, MidLargeM) } } -TYPED_TEST(TestGemmUniversal_KM_NK, MidLargeM) 
+TYPED_TEST(TestGemmUniversal_BF16_KM_NK, MidLargeM) { std::vector Ms{127, 255, 312, 799, 1573}; constexpr int N = 512; @@ -120,7 +120,7 @@ TYPED_TEST(TestGemmUniversal_KM_NK, MidLargeM) } } -TYPED_TEST(TestGemmUniversal_MK_KN, PaddK) +TYPED_TEST(TestGemmUniversal_BF16_MK_KN, PaddK) { std::vector Ms{127}; constexpr int N = 512; @@ -134,7 +134,7 @@ TYPED_TEST(TestGemmUniversal_MK_KN, PaddK) this->Run(M, N, K, StrideA, StrideB, StrideC); } -TYPED_TEST(TestGemmUniversal_MK_NK, PaddK) +TYPED_TEST(TestGemmUniversal_BF16_MK_NK, PaddK) { std::vector Ms{127}; constexpr int N = 512; @@ -148,7 +148,7 @@ TYPED_TEST(TestGemmUniversal_MK_NK, PaddK) this->Run(M, N, K, StrideA, StrideB, StrideC); } -TYPED_TEST(TestGemmUniversal_KM_KN, PaddK) +TYPED_TEST(TestGemmUniversal_BF16_KM_KN, PaddK) { std::vector Ms{127}; constexpr int N = 512; @@ -164,7 +164,7 @@ TYPED_TEST(TestGemmUniversal_KM_KN, PaddK) } } -TYPED_TEST(TestGemmUniversal_KM_NK, PaddK) +TYPED_TEST(TestGemmUniversal_BF16_KM_NK, PaddK) { std::vector Ms{127}; constexpr int N = 512; @@ -180,7 +180,7 @@ TYPED_TEST(TestGemmUniversal_KM_NK, PaddK) } } -TYPED_TEST(TestGemmUniversal_MK_KN, Regular) +TYPED_TEST(TestGemmUniversal_BF16_MK_KN, Regular) { std::vector Ms{512}; constexpr int N = 512; @@ -194,7 +194,7 @@ TYPED_TEST(TestGemmUniversal_MK_KN, Regular) this->Run(M, N, K, StrideA, StrideB, StrideC); } -TYPED_TEST(TestGemmUniversal_MK_NK, Regular) +TYPED_TEST(TestGemmUniversal_BF16_MK_NK, Regular) { std::vector Ms{512}; constexpr int N = 512; @@ -207,35 +207,3 @@ TYPED_TEST(TestGemmUniversal_MK_NK, Regular) for(int M : Ms) this->Run(M, N, K, StrideA, StrideB, StrideC); } - -TYPED_TEST(TestGemmUniversal_KM_KN, Regular) -{ - std::vector Ms{512}; - constexpr int N = 512; - constexpr int K = 512; - - constexpr int StrideB = N; - constexpr int StrideC = N; - - for(int M : Ms) - { - int StrideA = M; - this->Run(M, N, K, StrideA, StrideB, StrideC); - } -} - -TYPED_TEST(TestGemmUniversal_KM_NK, Regular) -{ - std::vector Ms{512}; - 
constexpr int N = 512; - constexpr int K = 512; - - constexpr int StrideB = N; - constexpr int StrideC = N; - - for(int M : Ms) - { - int StrideA = M; - this->Run(M, N, K, StrideA, StrideB, StrideC); - } -} diff --git a/test/gemm_universal/test_gemm_universal_ut_cases_fp16.inc b/test/gemm_universal/test_gemm_universal_ut_cases_fp16.inc new file mode 100644 index 0000000000..6f6d550625 --- /dev/null +++ b/test/gemm_universal/test_gemm_universal_ut_cases_fp16.inc @@ -0,0 +1,113 @@ +#pragma once + +TYPED_TEST(TestGemmUniversal_FP16_MK_KN, SmallM) +{ + std::vector Ms{1, 2, 3, 4, 5, 6}; + constexpr int N = 512; + constexpr int K = 320; + + constexpr int StrideA = K; + constexpr int StrideB = N; + constexpr int StrideC = N; + + for(int M : Ms) + this->Run(M, N, K, StrideA, StrideB, StrideC); +} + +TYPED_TEST(TestGemmUniversal_FP16_MK_NK, SmallM) +{ + std::vector Ms{1, 2, 3, 4, 5, 6}; + constexpr int N = 512; + constexpr int K = 320; + + constexpr int StrideA = K; + constexpr int StrideB = K; + constexpr int StrideC = N; + + for(int M : Ms) + this->Run(M, N, K, StrideA, StrideB, StrideC); +} + +TYPED_TEST(TestGemmUniversal_FP16_MK_KN, MidLargeM) +{ + std::vector Ms{127, 255, 312, 799, 1573}; + constexpr int N = 512; + constexpr int K = 320; + + constexpr int StrideA = K; + constexpr int StrideB = N; + constexpr int StrideC = N; + + for(int M : Ms) + this->Run(M, N, K, StrideA, StrideB, StrideC); +} + +TYPED_TEST(TestGemmUniversal_FP16_MK_NK, MidLargeM) +{ + std::vector Ms{127, 255, 312, 799, 1573}; + constexpr int N = 512; + constexpr int K = 320; + + constexpr int StrideA = K; + constexpr int StrideB = K; + constexpr int StrideC = N; + + for(int M : Ms) + this->Run(M, N, K, StrideA, StrideB, StrideC); +} + +TYPED_TEST(TestGemmUniversal_FP16_MK_KN, PaddK) +{ + std::vector Ms{127}; + constexpr int N = 512; + constexpr int K = 437; + + constexpr int StrideA = K; + constexpr int StrideB = N; + constexpr int StrideC = N; + + for(int M : Ms) + this->Run(M, N, K, StrideA, 
StrideB, StrideC); +} + +TYPED_TEST(TestGemmUniversal_FP16_MK_NK, PaddK) +{ + std::vector Ms{127}; + constexpr int N = 512; + constexpr int K = 437; + + constexpr int StrideA = K; + constexpr int StrideB = K; + constexpr int StrideC = N; + + for(int M : Ms) + this->Run(M, N, K, StrideA, StrideB, StrideC); +} + +TYPED_TEST(TestGemmUniversal_FP16_MK_KN, Regular) +{ + std::vector Ms{512}; + constexpr int N = 512; + constexpr int K = 512; + + constexpr int StrideA = K; + constexpr int StrideB = N; + constexpr int StrideC = N; + + for(int M : Ms) + this->Run(M, N, K, StrideA, StrideB, StrideC); +} + +TYPED_TEST(TestGemmUniversal_FP16_MK_NK, Regular) +{ + std::vector Ms{512}; + constexpr int N = 512; + constexpr int K = 512; + + constexpr int StrideA = K; + constexpr int StrideB = K; + constexpr int StrideC = N; + + for(int M : Ms) + this->Run(M, N, K, StrideA, StrideB, StrideC); +} diff --git a/test/gemm_universal/test_gemm_universal_ut_cases_fp8.inc b/test/gemm_universal/test_gemm_universal_ut_cases_fp8.inc new file mode 100644 index 0000000000..b831e15e9c --- /dev/null +++ b/test/gemm_universal/test_gemm_universal_ut_cases_fp8.inc @@ -0,0 +1,113 @@ +#pragma once + +TYPED_TEST(TestGemmUniversal_FP8_MK_KN, SmallM) +{ + std::vector Ms{1, 2, 3, 4, 5, 6}; + constexpr int N = 512; + constexpr int K = 320; + + constexpr int StrideA = K; + constexpr int StrideB = N; + constexpr int StrideC = N; + + for(int M : Ms) + this->Run(M, N, K, StrideA, StrideB, StrideC); +} + +TYPED_TEST(TestGemmUniversal_FP8_MK_NK, SmallM) +{ + std::vector Ms{1, 2, 3, 4, 5, 6}; + constexpr int N = 512; + constexpr int K = 320; + + constexpr int StrideA = K; + constexpr int StrideB = K; + constexpr int StrideC = N; + + for(int M : Ms) + this->Run(M, N, K, StrideA, StrideB, StrideC); +} + +TYPED_TEST(TestGemmUniversal_FP8_MK_KN, MidLargeM) +{ + std::vector Ms{127, 255, 312, 799, 1573}; + constexpr int N = 512; + constexpr int K = 320; + + constexpr int StrideA = K; + constexpr int StrideB = N; + 
constexpr int StrideC = N; + + for(int M : Ms) + this->Run(M, N, K, StrideA, StrideB, StrideC); +} + +TYPED_TEST(TestGemmUniversal_FP8_MK_NK, MidLargeM) +{ + std::vector Ms{127, 255, 312, 799, 1573}; + constexpr int N = 512; + constexpr int K = 320; + + constexpr int StrideA = K; + constexpr int StrideB = K; + constexpr int StrideC = N; + + for(int M : Ms) + this->Run(M, N, K, StrideA, StrideB, StrideC); +} + +TYPED_TEST(TestGemmUniversal_FP8_MK_KN, PaddK) +{ + std::vector Ms{127}; + constexpr int N = 512; + constexpr int K = 437; + + constexpr int StrideA = K; + constexpr int StrideB = N; + constexpr int StrideC = N; + + for(int M : Ms) + this->Run(M, N, K, StrideA, StrideB, StrideC); +} + +TYPED_TEST(TestGemmUniversal_FP8_MK_NK, PaddK) +{ + std::vector Ms{127}; + constexpr int N = 512; + constexpr int K = 437; + + constexpr int StrideA = K; + constexpr int StrideB = K; + constexpr int StrideC = N; + + for(int M : Ms) + this->Run(M, N, K, StrideA, StrideB, StrideC); +} + +TYPED_TEST(TestGemmUniversal_FP8_MK_KN, Regular) +{ + std::vector Ms{512}; + constexpr int N = 512; + constexpr int K = 512; + + constexpr int StrideA = K; + constexpr int StrideB = N; + constexpr int StrideC = N; + + for(int M : Ms) + this->Run(M, N, K, StrideA, StrideB, StrideC); +} + +TYPED_TEST(TestGemmUniversal_FP8_MK_NK, Regular) +{ + std::vector Ms{512}; + constexpr int N = 512; + constexpr int K = 512; + + constexpr int StrideA = K; + constexpr int StrideB = K; + constexpr int StrideC = N; + + for(int M : Ms) + this->Run(M, N, K, StrideA, StrideB, StrideC); +} diff --git a/test/gemm_universal/test_gemm_universal_xdl.cpp b/test/gemm_universal/test_gemm_universal_xdl_bf16.cpp similarity index 61% rename from test/gemm_universal/test_gemm_universal_xdl.cpp rename to test/gemm_universal/test_gemm_universal_xdl_bf16.cpp index b872d7089a..8fde65657a 100644 --- a/test/gemm_universal/test_gemm_universal_xdl.cpp +++ b/test/gemm_universal/test_gemm_universal_xdl_bf16.cpp @@ -7,8 +7,6 @@ #include 
"ck/tensor_operation/gpu/device/tensor_layout.hpp" #include "test_gemm_universal_util.hpp" -using F8 = ck::f8_t; -using F16 = ck::half_t; using BF16 = ck::bhalf_t; using F32 = float; @@ -29,25 +27,25 @@ struct tuple_concat, std::tuple> } // namespace template -class TestGemmUniversal_MK_KN +class TestGemmUniversal_BF16_MK_KN : public ck::test::TestGemmUniversal, Tuple>::type> { }; template -class TestGemmUniversal_MK_NK +class TestGemmUniversal_BF16_MK_NK : public ck::test::TestGemmUniversal, Tuple>::type> { }; template -class TestGemmUniversal_KM_KN +class TestGemmUniversal_BF16_KM_KN : public ck::test::TestGemmUniversal, Tuple>::type> { }; template -class TestGemmUniversal_KM_NK +class TestGemmUniversal_BF16_KM_NK : public ck::test::TestGemmUniversal, Tuple>::type> { }; @@ -55,22 +53,12 @@ class TestGemmUniversal_KM_NK // clang-format off using KernelTypes_MK_KN = ::testing::Types< // ADataType, BDataType, ComputeDataType, CDataType - std::tuple< F16, F16, F16, F16>, -#if defined(CK_ENABLE_FP8) && (defined(CK_USE_FP8_ON_UNSUPPORTED_ARCH) || defined(CK_USE_GFX94)) - std::tuple< F16, F8, F16, F16>, - std::tuple< F8, F16, F16, F16>, - std::tuple< F8, F8, F8, BF16>, -#endif + std::tuple< BF16, BF16, BF16, BF16> >; using KernelTypes_MK_NK = ::testing::Types< // ADataType, BDataType, ComputeDataType, CDataType - std::tuple< F16, F16, F16, F16>, -#if defined(CK_ENABLE_FP8) && (defined(CK_USE_FP8_ON_UNSUPPORTED_ARCH) || defined(CK_USE_GFX94)) - std::tuple< F16, F8, F16, F16>, - std::tuple< F8, F16, F16, F16>, - std::tuple< F8, F8, F8, BF16>, -#endif + std::tuple< BF16, BF16, BF16, BF16> >; @@ -86,9 +74,9 @@ using KernelTypes_KM_KN = ::testing::Types< // clang-format on -TYPED_TEST_SUITE(TestGemmUniversal_MK_KN, KernelTypes_MK_KN); -TYPED_TEST_SUITE(TestGemmUniversal_MK_NK, KernelTypes_MK_NK); -TYPED_TEST_SUITE(TestGemmUniversal_KM_KN, KernelTypes_KM_KN); -TYPED_TEST_SUITE(TestGemmUniversal_KM_NK, KernelTypes_KM_NK); +TYPED_TEST_SUITE(TestGemmUniversal_BF16_MK_KN, 
KernelTypes_MK_KN); +TYPED_TEST_SUITE(TestGemmUniversal_BF16_MK_NK, KernelTypes_MK_NK); +TYPED_TEST_SUITE(TestGemmUniversal_BF16_KM_KN, KernelTypes_KM_KN); +TYPED_TEST_SUITE(TestGemmUniversal_BF16_KM_NK, KernelTypes_KM_NK); -#include "test_gemm_universal_ut_cases.inc" +#include "test_gemm_universal_ut_cases_bf16.inc" diff --git a/test/gemm_universal/test_gemm_universal_xdl_fp16.cpp b/test/gemm_universal/test_gemm_universal_xdl_fp16.cpp new file mode 100644 index 0000000000..24f587daf6 --- /dev/null +++ b/test/gemm_universal/test_gemm_universal_xdl_fp16.cpp @@ -0,0 +1,82 @@ +// SPDX-License-Identifier: MIT +// Copyright (c) 2023-2024, Advanced Micro Devices, Inc. All rights reserved. + +#include + +#include "gtest/gtest.h" +#include "ck/tensor_operation/gpu/device/tensor_layout.hpp" +#include "test_gemm_universal_util.hpp" + +using F8 = ck::f8_t; +using F16 = ck::half_t; + +using F32 = float; + +using Row = ck::tensor_layout::gemm::RowMajor; +using Col = ck::tensor_layout::gemm::ColumnMajor; + +namespace { + +template +struct tuple_concat; + +template +struct tuple_concat, std::tuple> +{ + using type = std::tuple; +}; + +} // namespace + +template +class TestGemmUniversal_FP16_MK_KN + : public ck::test::TestGemmUniversal, Tuple>::type> +{ +}; + +template +class TestGemmUniversal_FP16_MK_NK + : public ck::test::TestGemmUniversal, Tuple>::type> +{ +}; + +template +class TestGemmUniversal_FP16_KM_KN + : public ck::test::TestGemmUniversal, Tuple>::type> +{ +}; + +template +class TestGemmUniversal_FP16_KM_NK + : public ck::test::TestGemmUniversal, Tuple>::type> +{ +}; + +// clang-format off +using KernelTypes_MK_KN = ::testing::Types< + // ADataType, BDataType, ComputeDataType, CDataType + +#if defined(CK_ENABLE_FP8) && (defined(CK_USE_FP8_ON_UNSUPPORTED_ARCH) || defined(CK_USE_GFX94)) + std::tuple< F16, F8, F16, F16>, + std::tuple< F8, F16, F16, F16>, + +#endif + std::tuple< F16, F16, F16, F16> + >; +using KernelTypes_MK_NK = ::testing::Types< + // ADataType, BDataType, 
ComputeDataType, CDataType + +#if defined(CK_ENABLE_FP8) && (defined(CK_USE_FP8_ON_UNSUPPORTED_ARCH) || defined(CK_USE_GFX94)) + std::tuple< F16, F8, F16, F16>, + std::tuple< F8, F16, F16, F16>, + +#endif + std::tuple< F16, F16, F16, F16> + >; + +// clang-format on + +TYPED_TEST_SUITE(TestGemmUniversal_FP16_MK_KN, KernelTypes_MK_KN); +TYPED_TEST_SUITE(TestGemmUniversal_FP16_MK_NK, KernelTypes_MK_NK); + +#include "test_gemm_universal_ut_cases_fp16.inc" diff --git a/test/gemm_universal/test_gemm_universal_xdl_fp8.cpp b/test/gemm_universal/test_gemm_universal_xdl_fp8.cpp new file mode 100644 index 0000000000..e833ab7825 --- /dev/null +++ b/test/gemm_universal/test_gemm_universal_xdl_fp8.cpp @@ -0,0 +1,71 @@ +// SPDX-License-Identifier: MIT +// Copyright (c) 2023-2024, Advanced Micro Devices, Inc. All rights reserved. + +#include + +#include "gtest/gtest.h" +#include "ck/tensor_operation/gpu/device/tensor_layout.hpp" +#include "test_gemm_universal_util.hpp" + +using F8 = ck::f8_t; +using F16 = ck::half_t; +using BF16 = ck::bhalf_t; +using F32 = float; + +using Row = ck::tensor_layout::gemm::RowMajor; +using Col = ck::tensor_layout::gemm::ColumnMajor; + +namespace { + +template +struct tuple_concat; + +template +struct tuple_concat, std::tuple> +{ + using type = std::tuple; +}; + +} // namespace + +template +class TestGemmUniversal_FP8_MK_KN + : public ck::test::TestGemmUniversal, Tuple>::type> +{ +}; + +template +class TestGemmUniversal_FP8_MK_NK + : public ck::test::TestGemmUniversal, Tuple>::type> +{ +}; + +// clang-format off +using KernelTypes_MK_KN = ::testing::Types< + // ADataType, BDataType, ComputeDataType, CDataType +#if defined(CK_ENABLE_FP8) && (defined(CK_USE_FP8_ON_UNSUPPORTED_ARCH) || defined(CK_USE_GFX94)) + std::tuple< F16, F8, F16, F16>, + std::tuple< F8, F16, F16, F16>, + std::tuple< F8, F8, F8, BF16>, +#endif + // Fallback test type when FP8 is not enabled + std::tuple< F16, F16, F16, F16> + >; +using KernelTypes_MK_NK = ::testing::Types< + // 
ADataType, BDataType, ComputeDataType, CDataType + +#if defined(CK_ENABLE_FP8) && (defined(CK_USE_FP8_ON_UNSUPPORTED_ARCH) || defined(CK_USE_GFX94)) + std::tuple< F16, F8, F16, F16>, + std::tuple< F8, F16, F16, F16>, + std::tuple< F8, F8, F8, BF16>, +#endif + // Fallback test type when FP8 is not enabled + std::tuple< F16, F16, F16, F16> + >; + + +TYPED_TEST_SUITE(TestGemmUniversal_FP8_MK_KN, KernelTypes_MK_KN); +TYPED_TEST_SUITE(TestGemmUniversal_FP8_MK_NK, KernelTypes_MK_NK); + + +#include "test_gemm_universal_ut_cases_fp8.inc" diff --git a/test/gemm_universal_streamk/test_gemm_universal_streamk_ut_cases_fp16.inc b/test/gemm_universal_streamk/test_gemm_universal_streamk_ut_cases_fp16.inc index b2fdfe8193..99c8e6d163 100644 --- a/test/gemm_universal_streamk/test_gemm_universal_streamk_ut_cases_fp16.inc +++ b/test/gemm_universal_streamk/test_gemm_universal_streamk_ut_cases_fp16.inc @@ -28,34 +28,6 @@ TYPED_TEST(TestGemmUniversal_Streamk_FP16_MK_NK, SmallM) this->Run(M, N, K, StrideA, StrideB, StrideC); } -TYPED_TEST(TestGemmUniversal_Streamk_FP16_MK_KN, MidLargeM) -{ - std::vector Ms{127, 255, 312, 799, 1573}; - constexpr int N = 512; - constexpr int K = 320; - - constexpr int StrideA = K; - constexpr int StrideB = N; - constexpr int StrideC = N; - - for(int M : Ms) - this->Run(M, N, K, StrideA, StrideB, StrideC); -} - -TYPED_TEST(TestGemmUniversal_Streamk_FP16_MK_NK, MidLargeM) -{ - std::vector Ms{127, 255, 312, 799, 1573}; - constexpr int N = 512; - constexpr int K = 320; - - constexpr int StrideA = K; - constexpr int StrideB = K; - constexpr int StrideC = N; - - for(int M : Ms) - this->Run(M, N, K, StrideA, StrideB, StrideC); -} - TYPED_TEST(TestGemmUniversal_Streamk_FP16_MK_KN, PaddK) { std::vector Ms{127}; diff --git a/test/gemm_universal_streamk/test_gemm_universal_streamk_ut_cases_fp8.inc b/test/gemm_universal_streamk/test_gemm_universal_streamk_ut_cases_fp8.inc index b3da08f703..b98ee92800 100755 --- 
a/test/gemm_universal_streamk/test_gemm_universal_streamk_ut_cases_fp8.inc +++ b/test/gemm_universal_streamk/test_gemm_universal_streamk_ut_cases_fp8.inc @@ -28,34 +28,6 @@ TYPED_TEST(TestGemmUniversal_Streamk_FP8_MK_NK, SmallM) this->Run(M, N, K, StrideA, StrideB, StrideC); } -TYPED_TEST(TestGemmUniversal_Streamk_FP8_MK_KN, MidLargeM) -{ - std::vector Ms{127, 255, 312, 799, 1573}; - constexpr int N = 512; - constexpr int K = 320; - - constexpr int StrideA = K; - constexpr int StrideB = N; - constexpr int StrideC = N; - - for(int M : Ms) - this->Run(M, N, K, StrideA, StrideB, StrideC); -} - -TYPED_TEST(TestGemmUniversal_Streamk_FP8_MK_NK, MidLargeM) -{ - std::vector Ms{127, 255, 312, 799, 1573}; - constexpr int N = 512; - constexpr int K = 320; - - constexpr int StrideA = K; - constexpr int StrideB = K; - constexpr int StrideC = N; - - for(int M : Ms) - this->Run(M, N, K, StrideA, StrideB, StrideC); -} - TYPED_TEST(TestGemmUniversal_Streamk_FP8_MK_KN, PaddK) { std::vector Ms{127}; diff --git a/test/gemm_universal_streamk/test_gemm_universal_streamk_util.hpp b/test/gemm_universal_streamk/test_gemm_universal_streamk_util.hpp index ef3509c0ca..805587a274 100644 --- a/test/gemm_universal_streamk/test_gemm_universal_streamk_util.hpp +++ b/test/gemm_universal_streamk/test_gemm_universal_streamk_util.hpp @@ -44,9 +44,8 @@ class TestGemmUniversal_Streamk : public testing::Test void SetUp() override { - grid_size_list = {38, 114, 228}; // {38, 76, 114, 152, 190, 228, 266, 304, 342, 380}; - streamk_sel_list = {0, 1, 2}; // 0: Data Parallel (DP) mode (Stream-K OFF), 1: 1-tile - // Stream-K+ DP, // {0, 1, 2, 3, 4} + streamk_sel_list = {0, 1, 2}; // 0: Data Parallel (DP) mode (Stream-K OFF), 1: 1-tile + // Stream-K+ DP, // {0, 1, 2, 3, 4} // 2:2-tile Stream-K + DP } @@ -58,10 +57,9 @@ class TestGemmUniversal_Streamk : public testing::Test const int StrideC) { for(auto streamk_sel : streamk_sel_list) - for(auto grid_size : grid_size_list) - { - RunSingle(M, N, K, StrideA, 
StrideB, StrideC, streamk_sel, grid_size); - } + { + RunSingle(M, N, K, StrideA, StrideB, StrideC, streamk_sel, -1); + } } void RunSingle(const int M, From 0d4f14507818d118696fc345a3a7623b20470c4e Mon Sep 17 00:00:00 2001 From: Illia Silin <98187287+illsilin@users.noreply.github.com> Date: Fri, 11 Apr 2025 12:12:53 -0700 Subject: [PATCH 040/443] Fix build issues for multiple targets. (#2077) * build for multiple targets on gfx942 * add missing ignore statements --- Jenkinsfile | 28 +++---------------- ...rouped_conv_bwd_weight_xdl_cshuffle_v3.hpp | 10 +++++++ 2 files changed, 14 insertions(+), 24 deletions(-) diff --git a/Jenkinsfile b/Jenkinsfile index dbd484d7bd..d105e385ab 100644 --- a/Jenkinsfile +++ b/Jenkinsfile @@ -1112,7 +1112,7 @@ pipeline { beforeAgent true expression { params.RUN_FULL_QA.toBoolean() && !params.BUILD_LEGACY_OS.toBoolean() } } - agent{ label rocmnode("gfx90a") } + agent{ label rocmnode("gfx942") } environment{ setup_args = """ -DCMAKE_INSTALL_PREFIX=../install \ -DGPU_TARGETS="gfx908;gfx90a;gfx942" \ @@ -1128,26 +1128,6 @@ pipeline { cleanWs() } } - stage("Build CK and run Tests on gfx942") - { - when { - beforeAgent true - expression { params.RUN_FULL_QA.toBoolean() && !params.BUILD_LEGACY_OS.toBoolean() } - } - agent{ label rocmnode("gfx942") } - environment{ - setup_args = """ -DCMAKE_INSTALL_PREFIX=../install -DGPU_TARGETS="gfx942" -DCMAKE_CXX_FLAGS=" -O3 " """ - execute_args = """ cd ../client_example && rm -rf build && mkdir build && cd build && \ - cmake -DCMAKE_PREFIX_PATH="${env.WORKSPACE}/install;/opt/rocm" \ - -DGPU_TARGETS="gfx942" \ - -DCMAKE_CXX_COMPILER="${build_compiler()}" \ - -DCMAKE_CXX_FLAGS=" -O3 " .. 
&& make -j """ - } - steps{ - Build_CK_and_Reboot(setup_args: setup_args, config_targets: "install", no_reboot:true, build_type: 'Release', execute_cmd: execute_args, prefixpath: '/usr/local') - cleanWs() - } - } stage("Build CK and run Tests on gfx908") { when { @@ -1194,13 +1174,13 @@ pipeline { beforeAgent true expression { params.BUILD_INSTANCES_ONLY.toBoolean() && !params.RUN_FULL_QA.toBoolean() && !params.BUILD_LEGACY_OS.toBoolean() } } - agent{ label rocmnode("gfx90a") } + agent{ label rocmnode("gfx942") } environment{ execute_args = """ cmake -G Ninja -D CMAKE_PREFIX_PATH=/opt/rocm \ -D CMAKE_CXX_COMPILER="${build_compiler()}" \ -D CMAKE_BUILD_TYPE=Release \ - -D GPU_ARCHS="gfx908;gfx90a;gfx942;gfx1030;gfx1100;gfx1101;gfx1102" \ - -D CMAKE_CXX_FLAGS=" -O3 " .. && ninja -j32 """ + -D GPU_ARCHS="gfx908;gfx90a;gfx942;gfx950;gfx1030;gfx1100;gfx1151;gfx1201" \ + -D CMAKE_CXX_FLAGS=" -O3 " .. && ninja -j64 """ } steps{ buildHipClangJobAndReboot(setup_cmd: "", build_cmd: "", no_reboot:true, build_type: 'Release', execute_cmd: execute_args) diff --git a/include/ck/tensor_operation/gpu/device/impl/device_grouped_conv_bwd_weight_xdl_cshuffle_v3.hpp b/include/ck/tensor_operation/gpu/device/impl/device_grouped_conv_bwd_weight_xdl_cshuffle_v3.hpp index f6ea23a1e7..d56c7abcde 100644 --- a/include/ck/tensor_operation/gpu/device/impl/device_grouped_conv_bwd_weight_xdl_cshuffle_v3.hpp +++ b/include/ck/tensor_operation/gpu/device/impl/device_grouped_conv_bwd_weight_xdl_cshuffle_v3.hpp @@ -81,6 +81,11 @@ __global__ void k_idx); #else ignore = karg; + ignore = a_grid_desc_ak0_m_ak1; + ignore = b_grid_desc_bk0_n_bk1; + ignore = c_grid_desc_mblock_mperblock_nblock_nperblock; + ignore = compute_ptr_offset_of_batch; + ignore = num_k_per_block; #endif // end of if (defined(__gfx9__) } @@ -140,6 +145,11 @@ __global__ void k_idx); #else ignore = karg; + ignore = a_grid_desc_ak0_m_ak1; + ignore = b_grid_desc_bk0_n_bk1; + ignore = c_grid_desc_mblock_mperblock_nblock_nperblock; + ignore 
= compute_ptr_offset_of_batch; + ignore = num_k_per_block; #endif // end of if (defined(__gfx9__) } From 269f4f6af5aba8c8ac6fe215fcf6ea604dc6b101 Mon Sep 17 00:00:00 2001 From: Thomas Ning Date: Sun, 13 Apr 2025 20:09:30 -0700 Subject: [PATCH 041/443] Solve the Static Encoding Pattern compile error when the tile size is too small (#2079) --- include/ck_tile/core.hpp | 1 + .../algorithm/static_encoding_pattern.hpp | 27 ++++++++++--------- include/ck_tile/ops/epilogue.hpp | 2 +- 3 files changed, 17 insertions(+), 13 deletions(-) diff --git a/include/ck_tile/core.hpp b/include/ck_tile/core.hpp index d9aa8b3551..821b3a8e84 100644 --- a/include/ck_tile/core.hpp +++ b/include/ck_tile/core.hpp @@ -59,6 +59,7 @@ #include "ck_tile/core/tensor/transpose_tile.hpp" #include "ck_tile/core/tensor/update_tile.hpp" #include "ck_tile/core/utility/bit_cast.hpp" +#include "ck_tile/core/utility/env.hpp" #include "ck_tile/core/utility/functional.hpp" #include "ck_tile/core/utility/functional_with_tuple.hpp" #include "ck_tile/core/utility/ignore.hpp" diff --git a/include/ck_tile/core/algorithm/static_encoding_pattern.hpp b/include/ck_tile/core/algorithm/static_encoding_pattern.hpp index 78884f3f9f..b56bda3741 100644 --- a/include/ck_tile/core/algorithm/static_encoding_pattern.hpp +++ b/include/ck_tile/core/algorithm/static_encoding_pattern.hpp @@ -73,10 +73,11 @@ struct TileDistributionEncodingPattern2D LargestVec ? LargestVec : VecSize; + static constexpr index_t X0 = XPerTile / X1; // # of threads in X dim // # of rows in Y dim accessed by single wavefront in one iteration static constexpr index_t Y1 = warp_size / X0; @@ -124,10 +125,11 @@ struct TileDistributionEncodingPattern2D LargestVec ? 
LargestVec : VecSize; + static constexpr index_t X0 = XPerTile / X1; // # of threads in X dim static constexpr index_t Y2 = warp_size / X0; // # of rows in Y dim to cover whole wavefront static_assert(X0 * Y2 == warp_size, "X0 * Y2 must cover whole wavefront!"); @@ -173,10 +175,11 @@ struct TileDistributionEncodingPattern2D LargestVec ? LargestVec : VecSize; + static constexpr index_t X0 = XPerTile / X1; // # of threads in X dim static constexpr index_t Y2 = warp_size / X0; // # of rows in Y dim to cover whole wavefront static_assert(X0 * Y2 == warp_size, "X0 * Y2 must cover whole wavefront!"); static constexpr index_t Y1 = num_warps; diff --git a/include/ck_tile/ops/epilogue.hpp b/include/ck_tile/ops/epilogue.hpp index 12e53e13e6..6cc0fa8540 100644 --- a/include/ck_tile/ops/epilogue.hpp +++ b/include/ck_tile/ops/epilogue.hpp @@ -4,9 +4,9 @@ #pragma once #include "ck_tile/ops/epilogue/cshuffle_epilogue.hpp" +#include "ck_tile/ops/epilogue/default_2d_and_dynamic_quant_epilogue.hpp" #include "ck_tile/ops/epilogue/default_2d_epilogue.hpp" #include "ck_tile/ops/epilogue/dynamic_quant_epilogue.hpp" -#include "ck_tile/ops/epilogue/default_2d_and_dynamic_quant_epilogue.hpp" #include "ck_tile/ops/common/generic_2d_block_shape.hpp" #include "ck_tile/ops/common/tensor_layout.hpp" #include "ck_tile/ops/common/utils.hpp" From 56378f810fdd328fec449e6574af656148e4c894 Mon Sep 17 00:00:00 2001 From: Mingtao Gu <145657261+mtgu0705@users.noreply.github.com> Date: Mon, 14 Apr 2025 16:58:57 +0800 Subject: [PATCH 042/443] CK pk_i4_t test failures fix (SWDEV-518629) (#2075) * fix pk_i4_v3 tests failures in Unbuntu env. * fix pk_i4_t tests failure on Unbuntu issues. * some fixed. 
--------- Co-authored-by: mtgu0705 --- example/01_gemm/gemm_xdl_bf16_pk_i4_v3.cpp | 12 +++++++--- example/01_gemm/gemm_xdl_fp16_pk_i4_v3.cpp | 12 +++++++--- .../gemm_xdl_fp16_pk_i4_v3_b_scale.cpp | 12 +++++++--- .../gemm_xdl_fp8_pk_i4_bpreshuffle_v3.cpp | 13 ++++++++--- example/01_gemm/gemm_xdl_fp8_pk_i4_v3.cpp | 12 +++++++--- ..._batched_gemm_example_fp16int4_b_scale.inc | 3 ++- .../moe_gemm1_xdl_pk_i4.cpp | 11 +++++++--- .../moe_gemm2_xdl_pk_i4.cpp | 11 +++++++--- ...evice_batched_gemm_xdl_fpAintB_b_scale.hpp | 16 ++++++++++---- .../impl/device_gemm_xdl_cshuffle_v3.hpp | 22 +++++++++++++++---- ...vice_gemm_xdl_cshuffle_v3_b_preshuffle.hpp | 22 +++++++++++++++---- .../device_gemm_xdl_cshuffle_v3_b_scale.hpp | 22 +++++++++++++++---- .../gpu/device/impl/device_moe_gemm.hpp | 22 +++++++++++++++---- 13 files changed, 148 insertions(+), 42 deletions(-) diff --git a/example/01_gemm/gemm_xdl_bf16_pk_i4_v3.cpp b/example/01_gemm/gemm_xdl_bf16_pk_i4_v3.cpp index 7c232f1bcf..7178ad46b9 100644 --- a/example/01_gemm/gemm_xdl_bf16_pk_i4_v3.cpp +++ b/example/01_gemm/gemm_xdl_bf16_pk_i4_v3.cpp @@ -133,7 +133,7 @@ bool run_gemm(const ProblemType& problem_size, const ExecutionConfig& config) std::cout << "c_m_n: " << c_m_n_host_result.mDesc << std::endl; DeviceMem a_m_k_device_buf(sizeof(ADataType) * a_m_k.mDesc.GetElementSpaceSize()); - DeviceMem b_k_n_device_buf(sizeof(BDataType) * b_k_n_permute.mDesc.GetElementSpaceSize()); + DeviceMem b_k_n_device_buf(sizeof(BDataType) * b_k_n_permute.mDesc.GetElementSpaceSize() / 2); DeviceMem c_m_n_device_buf(sizeof(CDataType) * c_m_n_device_result.mDesc.GetElementSpaceSize()); // weight permute @@ -192,14 +192,20 @@ bool run_gemm(const ProblemType& problem_size, const ExecutionConfig& config) b_element_op, c_element_op); - if(!gemm.IsSupportedArgument(argument) || ck::get_device_name() != "gfx942" || - ck::get_device_name() != "gfx950") + if(!gemm.IsSupportedArgument(argument)) { std::cerr << gemm.GetTypeString() << " does not support this 
problem" << std::endl; return true; } + if(!(ck::get_device_name() == "gfx942" || ck::get_device_name() == "gfx950")) + { + std::cout << "This kernel support gfx942 and gfx950 only" << std::endl; + + return true; + } + bool pass = true; if(config.do_verification) { diff --git a/example/01_gemm/gemm_xdl_fp16_pk_i4_v3.cpp b/example/01_gemm/gemm_xdl_fp16_pk_i4_v3.cpp index 61c5a32d5d..e16f184a20 100644 --- a/example/01_gemm/gemm_xdl_fp16_pk_i4_v3.cpp +++ b/example/01_gemm/gemm_xdl_fp16_pk_i4_v3.cpp @@ -134,7 +134,7 @@ bool run_gemm(const ProblemType& problem_size, const ExecutionConfig& config) std::cout << "c_m_n: " << c_m_n_host_result.mDesc << std::endl; DeviceMem a_m_k_device_buf(sizeof(ADataType) * a_m_k.mDesc.GetElementSpaceSize()); - DeviceMem b_k_n_device_buf(sizeof(BDataType) * b_k_n_permute.mDesc.GetElementSpaceSize()); + DeviceMem b_k_n_device_buf(sizeof(BDataType) * b_k_n_permute.mDesc.GetElementSpaceSize() / 2); DeviceMem c_m_n_device_buf(sizeof(CDataType) * c_m_n_device_result.mDesc.GetElementSpaceSize()); // weight permute @@ -242,14 +242,20 @@ bool run_gemm(const ProblemType& problem_size, const ExecutionConfig& config) b_element_op, c_element_op); - if(!gemm.IsSupportedArgument(argument) || ck::get_device_name() != "gfx942" || - ck::get_device_name() != "gfx950") + if(!gemm.IsSupportedArgument(argument)) { std::cerr << gemm.GetTypeString() << " does not support this problem" << std::endl; return true; } + if(!(ck::get_device_name() == "gfx942" || ck::get_device_name() == "gfx950")) + { + std::cout << "This kernel support gfx942 and gfx950 only" << std::endl; + + return true; + } + bool pass = true; if(config.do_verification) { diff --git a/example/01_gemm/gemm_xdl_fp16_pk_i4_v3_b_scale.cpp b/example/01_gemm/gemm_xdl_fp16_pk_i4_v3_b_scale.cpp index 468dd699a1..f83d479713 100644 --- a/example/01_gemm/gemm_xdl_fp16_pk_i4_v3_b_scale.cpp +++ b/example/01_gemm/gemm_xdl_fp16_pk_i4_v3_b_scale.cpp @@ -161,7 +161,7 @@ bool run_gemm(const ProblemType& 
problem_size, const ExecutionConfig& config) std::cout << "c_m_n: " << c_m_n_host_result.mDesc << std::endl; DeviceMem a_m_k_device_buf(sizeof(ADataType) * a_m_k.mDesc.GetElementSpaceSize()); - DeviceMem b_k_n_device_buf(sizeof(BDataType) * b_k_n_permute.mDesc.GetElementSpaceSize()); + DeviceMem b_k_n_device_buf(sizeof(BDataType) * b_k_n_permute.mDesc.GetElementSpaceSize() / 2); DeviceMem b1_scale_device_buf(sizeof(BScaleDataType) * b1_k_n.mDesc.GetElementSpaceSize()); DeviceMem c_m_n_device_buf(sizeof(CDataType) * c_m_n_device_result.mDesc.GetElementSpaceSize()); @@ -274,14 +274,20 @@ bool run_gemm(const ProblemType& problem_size, const ExecutionConfig& config) b_element_op, c_element_op); - if(!gemm.IsSupportedArgument(argument) || ck::get_device_name() != "gfx942" || - ck::get_device_name() != "gfx950") + if(!gemm.IsSupportedArgument(argument)) { std::cerr << gemm.GetTypeString() << " does not support this problem" << std::endl; return true; } + if(!(ck::get_device_name() == "gfx942" || ck::get_device_name() == "gfx950")) + { + std::cout << "This kernel support gfx942 and gfx950 only" << std::endl; + + return true; + } + bool pass = true; if(config.do_verification) { diff --git a/example/01_gemm/gemm_xdl_fp8_pk_i4_bpreshuffle_v3.cpp b/example/01_gemm/gemm_xdl_fp8_pk_i4_bpreshuffle_v3.cpp index 80f7e95d30..266a1e9d3e 100644 --- a/example/01_gemm/gemm_xdl_fp8_pk_i4_bpreshuffle_v3.cpp +++ b/example/01_gemm/gemm_xdl_fp8_pk_i4_bpreshuffle_v3.cpp @@ -152,7 +152,8 @@ bool run_gemm(const ProblemType& problem_size, const ExecutionConfig& config) std::cout << "c_m_n: " << c_m_n_host_result.mDesc << std::endl; DeviceMem a_m_k_device_buf(sizeof(ADataType) * a_m_k.mDesc.GetElementSpaceSize()); - DeviceMem b_k_n_device_buf(sizeof(BDataType) * b_k_n_preshuffled.mDesc.GetElementSpaceSize()); + DeviceMem b_k_n_device_buf(sizeof(BDataType) * b_k_n_preshuffled.mDesc.GetElementSpaceSize() / + 2); DeviceMem c_m_n_device_buf(sizeof(CDataType) * 
c_m_n_device_result.mDesc.GetElementSpaceSize()); // do GEMM @@ -261,14 +262,20 @@ bool run_gemm(const ProblemType& problem_size, const ExecutionConfig& config) b_element_op, c_element_op); - if(!gemm.IsSupportedArgument(argument) || ck::get_device_name() != "gfx942" || - ck::get_device_name() != "gfx950") + if(!gemm.IsSupportedArgument(argument)) { std::cerr << gemm.GetTypeString() << " does not support this problem" << std::endl; return true; } + if(!(ck::get_device_name() == "gfx942" || ck::get_device_name() == "gfx950")) + { + std::cout << "This kernel support gfx942 and gfx950 only" << std::endl; + + return true; + } + bool pass = true; if(config.do_verification) { diff --git a/example/01_gemm/gemm_xdl_fp8_pk_i4_v3.cpp b/example/01_gemm/gemm_xdl_fp8_pk_i4_v3.cpp index 7b72461dd9..0575314dff 100644 --- a/example/01_gemm/gemm_xdl_fp8_pk_i4_v3.cpp +++ b/example/01_gemm/gemm_xdl_fp8_pk_i4_v3.cpp @@ -132,7 +132,7 @@ bool run_gemm(const ProblemType& problem_size, const ExecutionConfig& config) std::cout << "c_m_n: " << c_m_n_host_result.mDesc << std::endl; DeviceMem a_m_k_device_buf(sizeof(ADataType) * a_m_k.mDesc.GetElementSpaceSize()); - DeviceMem b_k_n_device_buf(sizeof(BDataType) * b_k_n_permute.mDesc.GetElementSpaceSize()); + DeviceMem b_k_n_device_buf(sizeof(BDataType) * b_k_n_permute.mDesc.GetElementSpaceSize() / 2); DeviceMem c_m_n_device_buf(sizeof(CDataType) * c_m_n_device_result.mDesc.GetElementSpaceSize()); // weight permute @@ -240,14 +240,20 @@ bool run_gemm(const ProblemType& problem_size, const ExecutionConfig& config) b_element_op, c_element_op); - if(!gemm.IsSupportedArgument(argument) || ck::get_device_name() != "gfx942" || - ck::get_device_name() != "gfx950") + if(!gemm.IsSupportedArgument(argument)) { std::cerr << gemm.GetTypeString() << " does not support this problem" << std::endl; return true; } + if(!(ck::get_device_name() == "gfx942" || ck::get_device_name() == "gfx950")) + { + std::cout << "This kernel support gfx942 and gfx950 only" << 
std::endl; + + return true; + } + bool pass = true; if(config.do_verification) { diff --git a/example/24_batched_gemm/run_batched_gemm_example_fp16int4_b_scale.inc b/example/24_batched_gemm/run_batched_gemm_example_fp16int4_b_scale.inc index 8c4913dbcc..3582bc5e33 100644 --- a/example/24_batched_gemm/run_batched_gemm_example_fp16int4_b_scale.inc +++ b/example/24_batched_gemm/run_batched_gemm_example_fp16int4_b_scale.inc @@ -212,7 +212,8 @@ bool run_batched_gemm(const ProblemSize& problem_size, const ExecutionConfig& co std::cout << "c_g_m_n: " << c_g_m_n_host_result.mDesc << std::endl; DeviceMem a_g_m_k_device_buf(sizeof(ADataType) * a_g_m_k.mDesc.GetElementSpaceSize()); - DeviceMem b_g_k_n_device_buf(sizeof(BDataType) * b_g_k_n_permute.mDesc.GetElementSpaceSize()); + DeviceMem b_g_k_n_device_buf(sizeof(BDataType) * b_g_k_n_permute.mDesc.GetElementSpaceSize() / + 2); DeviceMem b1_g_scale_device_buf(sizeof(BScaleDataType) * b1_g_k_n.mDesc.GetElementSpaceSize()); DeviceMem c_g_m_n_device_buf(sizeof(CDataType) * c_g_m_n_device_result.mDesc.GetElementSpaceSize()); diff --git a/example/65_gemm_multiply_multiply/moe_gemm1_xdl_pk_i4.cpp b/example/65_gemm_multiply_multiply/moe_gemm1_xdl_pk_i4.cpp index 1102ce1054..a25d1b5fa3 100644 --- a/example/65_gemm_multiply_multiply/moe_gemm1_xdl_pk_i4.cpp +++ b/example/65_gemm_multiply_multiply/moe_gemm1_xdl_pk_i4.cpp @@ -301,7 +301,7 @@ int main(int argc, char* argv[]) DeviceMem expert_ids_dev(sizeof(ck::index_t) * expert_ids.mDesc.GetElementSpaceSize()); DeviceMem max_token_id_dev(sizeof(ck::index_t) * max_token_id.mDesc.GetElementSpaceSize()); DeviceMem a0_device_buf(sizeof(A0DataType) * a0_t_k.mDesc.GetElementSpaceSize()); - DeviceMem b0_device_buf(sizeof(B0DataType) * b0_e_n_k.mDesc.GetElementSpaceSize()); + DeviceMem b0_device_buf(sizeof(B0DataType) * b0_e_n_k.mDesc.GetElementSpaceSize() / 2); DeviceMem d0_device_buf(sizeof(D0DataType) * d0_t_n.mDesc.GetElementSpaceSize()); DeviceMem d1_device_buf(sizeof(D1DataType) * 
d1_e_n.mDesc.GetElementSpaceSize()); DeviceMem e_device_buf(sizeof(EDataType) * e_t_n_device_result.mDesc.GetElementSpaceSize()); @@ -440,13 +440,18 @@ int main(int argc, char* argv[]) b_element_op, cde_element_op); - if(!device_op.IsSupportedArgument(argument) || - !(ck::get_device_name() == "gfx942" || ck::get_device_name() == "gfx950")) + if(!device_op.IsSupportedArgument(argument)) { throw std::runtime_error( "wrong! device_gemm with the specified compilation parameters does " "not support this GEMM problem"); } + + if(!(ck::get_device_name() == "gfx942" || ck::get_device_name() == "gfx950")) + { + std::cout << "This kernel support gfx942 and gfx950 only" << std::endl; + } + if(time_kernel) { float ave_time = invoker.Run(argument, StreamConfig{nullptr, time_kernel}); diff --git a/example/65_gemm_multiply_multiply/moe_gemm2_xdl_pk_i4.cpp b/example/65_gemm_multiply_multiply/moe_gemm2_xdl_pk_i4.cpp index 528503a2c4..8c2c70b4a1 100644 --- a/example/65_gemm_multiply_multiply/moe_gemm2_xdl_pk_i4.cpp +++ b/example/65_gemm_multiply_multiply/moe_gemm2_xdl_pk_i4.cpp @@ -298,7 +298,7 @@ int main(int argc, char* argv[]) DeviceMem expert_ids_dev(sizeof(ck::index_t) * expert_ids.mDesc.GetElementSpaceSize()); DeviceMem max_token_id_dev(sizeof(ck::index_t) * max_token_id.mDesc.GetElementSpaceSize()); DeviceMem a0_device_buf(sizeof(A0DataType) * a0_t_k_k.mDesc.GetElementSpaceSize()); - DeviceMem b0_device_buf(sizeof(B0DataType) * b0_e_n_k.mDesc.GetElementSpaceSize()); + DeviceMem b0_device_buf(sizeof(B0DataType) * b0_e_n_k.mDesc.GetElementSpaceSize() / 2); DeviceMem d0_device_buf(sizeof(D0DataType) * d0_t_n.mDesc.GetElementSpaceSize()); DeviceMem d1_device_buf(sizeof(D1DataType) * d1_e_n.mDesc.GetElementSpaceSize()); DeviceMem d2_device_buf(sizeof(D2DataType) * d2_e_n.mDesc.GetElementSpaceSize()); @@ -407,13 +407,18 @@ int main(int argc, char* argv[]) b_element_op, cde_element_op); - if(!device_op.IsSupportedArgument(argument) || - !(ck::get_device_name() == "gfx942" || 
ck::get_device_name() == "gfx950")) + if(!device_op.IsSupportedArgument(argument)) { throw std::runtime_error( "wrong! device_gemm with the specified compilation parameters does " "not support this GEMM problem"); } + + if(!(ck::get_device_name() == "gfx942" || ck::get_device_name() == "gfx950")) + { + std::cout << "This kernel support gfx942 and gfx950 only" << std::endl; + } + if(time_kernel) { // not result correct here because output buf not setzero diff --git a/include/ck/tensor_operation/gpu/device/impl/device_batched_gemm_xdl_fpAintB_b_scale.hpp b/include/ck/tensor_operation/gpu/device/impl/device_batched_gemm_xdl_fpAintB_b_scale.hpp index 963f0edd08..7d9555dc82 100644 --- a/include/ck/tensor_operation/gpu/device/impl/device_batched_gemm_xdl_fpAintB_b_scale.hpp +++ b/include/ck/tensor_operation/gpu/device/impl/device_batched_gemm_xdl_fpAintB_b_scale.hpp @@ -224,12 +224,20 @@ struct DeviceBatchedGemm_Xdl_CShuffleV3_BScale PermuteA, PermuteB>; + static constexpr index_t APackedSize = []() { + if constexpr(is_same_v, pk_i4_t>) + return 2; + else + return 1; + }(); + static constexpr index_t BPackedSize = []() { if constexpr(is_same_v, pk_i4_t>) return 2; else return 1; }(); + struct ComputePtrOffsetOfStridedBatch { ComputePtrOffsetOfStridedBatch(index_t BatchStrideA, @@ -352,10 +360,10 @@ struct DeviceBatchedGemm_Xdl_CShuffleV3_BScale const auto b_grid_desc_bk0_n_bk1 = GridwiseGemm::MakeBGridDescriptor_BK0_N_BK1( arg_.K, arg_.KPadded, arg_.N, arg_.NPadded, arg_.StrideB, arg_.BK0); - auto size_a_buffer = - a_grid_desc_ak0_m_ak1.GetElementSpaceSize() * sizeof(ADataType); - auto size_b_buffer = - b_grid_desc_bk0_n_bk1.GetElementSpaceSize() * sizeof(BDataType); + auto size_a_buffer = a_grid_desc_ak0_m_ak1.GetElementSpaceSize() * + sizeof(ADataType) / APackedSize; + auto size_b_buffer = b_grid_desc_bk0_n_bk1.GetElementSpaceSize() * + sizeof(BDataType) / BPackedSize; ck::utility::RotatingMemWrapper rotating_mem( arg_, stream_config.rotating_count, size_a_buffer, 
size_b_buffer); diff --git a/include/ck/tensor_operation/gpu/device/impl/device_gemm_xdl_cshuffle_v3.hpp b/include/ck/tensor_operation/gpu/device/impl/device_gemm_xdl_cshuffle_v3.hpp index 51c223efd2..dde21725d0 100644 --- a/include/ck/tensor_operation/gpu/device/impl/device_gemm_xdl_cshuffle_v3.hpp +++ b/include/ck/tensor_operation/gpu/device/impl/device_gemm_xdl_cshuffle_v3.hpp @@ -229,6 +229,20 @@ struct DeviceGemm_Xdl_CShuffleV3 : public DeviceGemmV2, pk_i4_t>) + return 2; + else + return 1; + }(); + + static constexpr index_t BPackedSize = []() { + if constexpr(is_same_v, pk_i4_t>) + return 2; + else + return 1; + }(); + /// @brief Helper structure responsible for kernel invocation. /// /// @paragraph The `Invoker` class is responsible for preparation and invocation of actual GPU @@ -278,10 +292,10 @@ struct DeviceGemm_Xdl_CShuffleV3 : public DeviceGemmV2 rotating_mem( arg_, stream_config.rotating_count, size_a_buffer, size_b_buffer); diff --git a/include/ck/tensor_operation/gpu/device/impl/device_gemm_xdl_cshuffle_v3_b_preshuffle.hpp b/include/ck/tensor_operation/gpu/device/impl/device_gemm_xdl_cshuffle_v3_b_preshuffle.hpp index 58a182b924..faa235be50 100644 --- a/include/ck/tensor_operation/gpu/device/impl/device_gemm_xdl_cshuffle_v3_b_preshuffle.hpp +++ b/include/ck/tensor_operation/gpu/device/impl/device_gemm_xdl_cshuffle_v3_b_preshuffle.hpp @@ -130,6 +130,20 @@ struct DeviceGemm_Xdl_CShuffleV3_BPreshuffle : public DeviceGemmV2BPreshuffle, pk_i4_t>) + return 2; + else + return 1; + }(); + + static constexpr index_t BPackedSize = []() { + if constexpr(is_same_v, pk_i4_t>) + return 2; + else + return 1; + }(); + int GetPreShuffleParameters() override { return NPerXDL; } // Invoker @@ -168,10 +182,10 @@ struct DeviceGemm_Xdl_CShuffleV3_BPreshuffle : public DeviceGemmV2BPreshuffle rotating_mem( arg_, stream_config.rotating_count, size_a_buffer, size_b_buffer); diff --git a/include/ck/tensor_operation/gpu/device/impl/device_gemm_xdl_cshuffle_v3_b_scale.hpp 
b/include/ck/tensor_operation/gpu/device/impl/device_gemm_xdl_cshuffle_v3_b_scale.hpp index 044350d11c..456e5e90d1 100644 --- a/include/ck/tensor_operation/gpu/device/impl/device_gemm_xdl_cshuffle_v3_b_scale.hpp +++ b/include/ck/tensor_operation/gpu/device/impl/device_gemm_xdl_cshuffle_v3_b_scale.hpp @@ -139,6 +139,20 @@ struct DeviceGemm_Xdl_CShuffleV3 : public DeviceGemmV2BScale, pk_i4_t>) + return 2; + else + return 1; + }(); + + static constexpr index_t BPackedSize = []() { + if constexpr(is_same_v, pk_i4_t>) + return 2; + else + return 1; + }(); + // Invoker struct Invoker : public BaseInvoker { @@ -174,10 +188,10 @@ struct DeviceGemm_Xdl_CShuffleV3 : public DeviceGemmV2BScale rotating_mem( arg_, stream_config.rotating_count, size_a_buffer, size_b_buffer); diff --git a/include/ck/tensor_operation/gpu/device/impl/device_moe_gemm.hpp b/include/ck/tensor_operation/gpu/device/impl/device_moe_gemm.hpp index 950fe0236d..f3fc1aaa9f 100644 --- a/include/ck/tensor_operation/gpu/device/impl/device_moe_gemm.hpp +++ b/include/ck/tensor_operation/gpu/device/impl/device_moe_gemm.hpp @@ -139,6 +139,20 @@ struct DeviceMoeGemm : public DeviceGemmMultipleDSplitKBPreShuffle, pk_i4_t>) + return 2; + else + return 1; + }(); + + static constexpr index_t BPackedSize = []() { + if constexpr(is_same_v, pk_i4_t>) + return 2; + else + return 1; + }(); + int GetPreShuffleParameters() override { return NPerXDL; } // Invoker @@ -179,10 +193,10 @@ struct DeviceMoeGemm : public DeviceGemmMultipleDSplitKBPreShuffle Date: Mon, 14 Apr 2025 16:41:47 -0700 Subject: [PATCH 043/443] Upgrade default docker image to ROCm6.4 release. 
(#2082) * upgrade to rocm6.4 * fix gfx10 generic target syntax * use gfx1101 target for unit tests * use gfx1201 target for unit tests * do not use generic targets until 6.4.1 release * update target list and dockerfile.compiler --- Dockerfile | 15 +++++++-------- Dockerfile.compiler | 2 +- Jenkinsfile | 14 +++++++------- 3 files changed, 15 insertions(+), 16 deletions(-) diff --git a/Dockerfile b/Dockerfile index 17800d92d5..2a8fb707c9 100644 --- a/Dockerfile +++ b/Dockerfile @@ -1,6 +1,6 @@ FROM ubuntu:22.04 ARG DEBIAN_FRONTEND=noninteractive -ARG ROCMVERSION=6.3 +ARG ROCMVERSION=6.4 ARG compiler_version="" ARG compiler_commit="" ARG CK_SCCACHE="" @@ -13,15 +13,15 @@ RUN set -xe && \ apt-get update && apt-get install -y --allow-unauthenticated apt-utils wget gnupg2 curl && \ curl -fsSL https://repo.radeon.com/rocm/rocm.gpg.key | gpg --dearmor -o /etc/apt/trusted.gpg.d/rocm-keyring.gpg -RUN if [ "$ROCMVERSION" != "6.4" ]; then \ - sh -c "wget https://repo.radeon.com/amdgpu-install/$ROCMVERSION/ubuntu/focal/amdgpu-install_6.3.60300-1_all.deb --no-check-certificate" && \ +RUN if [ "$ROCMVERSION" != "6.5" ]; then \ + sh -c "wget https://repo.radeon.com/amdgpu-install/$ROCMVERSION/ubuntu/jammy/amdgpu-install_6.3.60300-1_all.deb --no-check-certificate" && \ apt-get update && DEBIAN_FRONTEND=noninteractive apt-get install -y --allow-unauthenticated ./amdgpu-install_6.3.60300-1_all.deb && \ wget -qO - http://repo.radeon.com/rocm/rocm.gpg.key | apt-key add - && \ - sh -c "echo deb [arch=amd64 signed-by=/etc/apt/trusted.gpg.d/rocm-keyring.gpg] $DEB_ROCM_REPO focal main > /etc/apt/sources.list.d/rocm.list" && \ - sh -c 'echo deb [arch=amd64 signed-by=/etc/apt/trusted.gpg.d/rocm-keyring.gpg] https://repo.radeon.com/amdgpu/$ROCMVERSION/ubuntu focal main > /etc/apt/sources.list.d/amdgpu.list'; \ + sh -c "echo deb [arch=amd64 signed-by=/etc/apt/trusted.gpg.d/rocm-keyring.gpg] $DEB_ROCM_REPO jammy main > /etc/apt/sources.list.d/rocm.list" && \ + sh -c 'echo deb [arch=amd64 
signed-by=/etc/apt/trusted.gpg.d/rocm-keyring.gpg] https://repo.radeon.com/amdgpu/$ROCMVERSION/ubuntu jammy main > /etc/apt/sources.list.d/amdgpu.list'; \ fi -RUN sh -c "echo deb http://mirrors.kernel.org/ubuntu focal main universe | tee -a /etc/apt/sources.list" && \ +RUN sh -c "echo deb http://mirrors.kernel.org/ubuntu jammy main universe | tee -a /etc/apt/sources.list" && \ amdgpu-install -y --usecase=rocm --no-dkms ## Sccache binary built from source for ROCm, only install if CK_SCCACHE is defined @@ -51,7 +51,6 @@ RUN apt-get update && DEBIAN_FRONTEND=noninteractive apt-get install -y --allow- mpich \ net-tools \ pkg-config \ - python \ python3 \ python3-dev \ python3-pip \ @@ -99,7 +98,7 @@ RUN pip install --upgrade cmake==3.27.5 && \ dpkg -i dumb-init_*.deb && rm dumb-init_*.deb && \ # Install packages for processing the performance results pip3 install --upgrade pip && \ - pip3 install --upgrade pytest sqlalchemy==2.0.36 pymysql pandas==2.2.3 setuptools-rust setuptools>=75 sshtunnel==0.4.0 && \ + pip3 install --upgrade pytest sqlalchemy==2.0.36 pymysql pandas==2.2.3 setuptools-rust setuptools sshtunnel==0.4.0 && \ # Add render group groupadd -f render && \ # Install the new rocm-cmake version diff --git a/Dockerfile.compiler b/Dockerfile.compiler index a22103b96b..f4aa12f356 100644 --- a/Dockerfile.compiler +++ b/Dockerfile.compiler @@ -1,4 +1,4 @@ -ARG BASE_DOCKER="rocm/composable_kernel:ck_ub22.04_rocm6.3" +ARG BASE_DOCKER="rocm/composable_kernel:ck_ub22.04_rocm6.4" FROM $BASE_DOCKER ARG compiler_version="" ARG compiler_commit="" diff --git a/Jenkinsfile b/Jenkinsfile index d105e385ab..e6256fc3d8 100644 --- a/Jenkinsfile +++ b/Jenkinsfile @@ -39,7 +39,7 @@ def getBaseDockerImageName(){ } else{ def ROCM_numeric = "${params.ROCMVERSION}" as float - if ( ROCM_numeric < 6.4 ){ + if ( ROCM_numeric < 6.5 ){ img = "${env.CK_DOCKERHUB}:ck_ub22.04_rocm${params.ROCMVERSION}" } else{ @@ -519,13 +519,13 @@ def Build_CK(Map conf=[:]){ else if ( runShell('grep -n 
"gfx942" rocminfo.log') ) { arch_type = 2 } - else if ( runShell('grep -n "gfx1030" rocminfo.log') ) { + else if ( runShell('grep -n "gfx10" rocminfo.log') ) { arch_type = 3 } - else if ( runShell('grep -n "gfx1101" rocminfo.log') ) { + else if ( runShell('grep -n "gfx11" rocminfo.log') ) { arch_type = 4 } - else if ( runShell('grep -n "gfx1201" rocminfo.log') ) { + else if ( runShell('grep -n "gfx12" rocminfo.log') ) { arch_type = 5 } else if ( runShell('grep -n "gfx908" rocminfo.log') ) { @@ -744,8 +744,8 @@ def process_results(Map conf=[:]){ } //launch develop branch daily at 23:00 UT in FULL_QA mode and at 19:00 UT with latest staging compiler version -CRON_SETTINGS = BRANCH_NAME == "develop" ? '''0 23 * * * % RUN_FULL_QA=true;DISABLE_DL_KERNELS=true;ROCMVERSION=6.3;RUN_CK_TILE_FMHA_TESTS=true;RUN_CK_TILE_GEMM_TESTS=true - 0 21 * * * % ROCMVERSION=6.3;hipTensor_test=true;RUN_CODEGEN_TESTS=true;BUILD_GFX908=true; +CRON_SETTINGS = BRANCH_NAME == "develop" ? '''0 23 * * * % RUN_FULL_QA=true;DISABLE_DL_KERNELS=true;ROCMVERSION=6.4;RUN_CK_TILE_FMHA_TESTS=true;RUN_CK_TILE_GEMM_TESTS=true + 0 21 * * * % ROCMVERSION=6.4;hipTensor_test=true;RUN_CODEGEN_TESTS=true;BUILD_GFX908=true; 0 19 * * * % BUILD_DOCKER=true;COMPILER_VERSION=amd-staging;BUILD_COMPILER=/llvm-project/build/bin/clang++;USE_SCCACHE=false;NINJA_BUILD_TRACE=true 0 17 * * * % BUILD_DOCKER=true;COMPILER_VERSION=amd-mainline;BUILD_COMPILER=/llvm-project/build/bin/clang++;USE_SCCACHE=false;NINJA_BUILD_TRACE=true 0 15 * * * % BUILD_INSTANCES_ONLY=true;RUN_PERFORMANCE_TESTS=false;USE_SCCACHE=false @@ -770,7 +770,7 @@ pipeline { description: 'If you want to use a custom docker image, please specify it here (default: leave blank).') string( name: 'ROCMVERSION', - defaultValue: '6.3', + defaultValue: '6.4', description: 'Specify which ROCM version to use: 6.3 (default).') string( name: 'COMPILER_VERSION', From 7106976a72897f44b05260bd1ae1f70b319a4e75 Mon Sep 17 00:00:00 2001 From: Andriy Roshchenko 
<107577548+andriy-ca@users.noreply.github.com> Date: Tue, 15 Apr 2025 17:17:07 -0600 Subject: [PATCH 044/443] MX GEMM - New GEMM pipeline for MX data types (#2059) * Allow selection of mfma_scale instructions * Read B tensor from LDS to VGPR in chunks of 16 in MFMA order * Add constexpr and synchronize return type for `get_exponent_value` * Pass scales by reference and add comments to `mfma_scale_f32_32x32x64` * Add support for microscaling instructions in `XdlopsGemm` * Fix `mfma_scale_f32_16x16x128f8f6f4` wrapper * Remove software implementation of MX GEMM * Make interface of `intrin_mfma_scale_f32_16x16x128f8f6f4<16, 16>` consistent with the other scale instruction * Update README * Updated CHANGELOG * Remove unused static methods --- CHANGELOG.md | 1 + example/67_gemm_microscaling/CMakeLists.txt | 9 +- example/67_gemm_microscaling/README.md | 8 +- .../67_gemm_microscaling/gemm_mx_common.hpp | 79 +-- example/67_gemm_microscaling/gemm_mx_fp8.cpp | 98 ++++ .../gemm_mx_fp8_e8m0_scale.cpp | 42 -- .../gemm_mx_fp8_fp16_scale.cpp | 42 -- .../gemm_mx_fp8_fp8_scale.cpp | 42 -- ...blockwise_gemm_mx_pipeline_xdlops_base.hpp | 363 ++++++++++++ ...kwise_gemm_pipeline_xdlops_mx_selector.hpp | 35 +- .../blockwise_gemm_pipeline_xdlops_v1_mx.hpp | 546 +++++++++--------- .../impl/device_gemm_xdl_cshuffle_v3_mx.hpp | 14 +- .../grid/gridwise_gemm_xdl_cshuffle_v3_mx.hpp | 122 ++-- .../threadwise_tensor_slice_transfer.hpp | 3 +- .../tensor_operation/gpu/warp/xdlops_gemm.hpp | 89 ++- include/ck/utility/amd_xdlops.hpp | 16 +- include/ck/utility/e8m0.hpp | 4 +- include/ck/utility/mxfp_utils.hpp | 4 +- test/mx_mfma_op/mx_mfma_op.hpp | 98 ++-- 19 files changed, 1007 insertions(+), 608 deletions(-) create mode 100644 example/67_gemm_microscaling/gemm_mx_fp8.cpp delete mode 100644 example/67_gemm_microscaling/gemm_mx_fp8_e8m0_scale.cpp delete mode 100644 example/67_gemm_microscaling/gemm_mx_fp8_fp16_scale.cpp delete mode 100644 example/67_gemm_microscaling/gemm_mx_fp8_fp8_scale.cpp create 
mode 100644 include/ck/tensor_operation/gpu/block/blockwise_gemm_mx_pipeline_xdlops_base.hpp diff --git a/CHANGELOG.md b/CHANGELOG.md index e3d7971c71..b9012c0a77 100644 --- a/CHANGELOG.md +++ b/CHANGELOG.md @@ -13,6 +13,7 @@ Documentation for Composable Kernel available at [https://rocm.docs.amd.com/proj * Added support for GKCYX layout for grouped convolution backward weight (NGCHW/GKCYX/NGKHW). * Added support for GKCYX layout for grouped convolution backward data (NGCHW/GKCYX/NGKHW). * Added support for Stream-K version of mixed fp8/bf16 GEMM +* Added GEMM pipeline for microscaling (MX) data types * Added support for FP16 2:4 structured sparsity to universal GEMM. ### Optimized diff --git a/example/67_gemm_microscaling/CMakeLists.txt b/example/67_gemm_microscaling/CMakeLists.txt index 9e95c3e007..93770684df 100644 --- a/example/67_gemm_microscaling/CMakeLists.txt +++ b/example/67_gemm_microscaling/CMakeLists.txt @@ -1,10 +1,5 @@ add_custom_target(example_gemm_mx) -add_example_executable(example_gemm_mx_fp8_e8m0_scale gemm_mx_fp8_e8m0_scale.cpp) -add_example_dependencies(example_gemm_mx example_gemm_mx_fp8_e8m0_scale) +add_example_executable(example_gemm_mx_fp8 gemm_mx_fp8.cpp) +add_example_dependencies(example_gemm_mx example_gemm_mx_fp8) -add_example_executable(example_gemm_mx_fp8_fp8_scale gemm_mx_fp8_fp8_scale.cpp) -add_example_dependencies(example_gemm_mx example_gemm_mx_fp8_fp8_scale) - -add_example_executable(example_gemm_mx_fp8_fp16_scale gemm_mx_fp8_fp16_scale.cpp) -add_example_dependencies(example_gemm_mx example_gemm_mx_fp8_fp16_scale) diff --git a/example/67_gemm_microscaling/README.md b/example/67_gemm_microscaling/README.md index 713902588d..57b6490eda 100644 --- a/example/67_gemm_microscaling/README.md +++ b/example/67_gemm_microscaling/README.md @@ -10,16 +10,16 @@ Custom verification parameters: # arg4: verbosity (0=no info, 1=verbose info) # arg5 to 10: M(128x), N(128x), K(64x), StrideA, StrideB, StrideC # arg11: KBatch 
-./bin/example_gemm_mx_fp8_e8m0_scale 1 1 0 1 +./bin/example_gemm_mx_fp8 1 1 0 1 ``` Custom tensor shapes: ```bash -./bin/example_gemm_mx_fp8_fp16_scale 1 2 1 0 128 128 64 -1 -1 -1 1 +./bin/example_gemm_mx_fp8 1 2 1 0 128 128 256 -1 -1 -1 1 ``` Default invocation: ```bash -# Implies: ./bin/example_gemm_mx_fp8_fp8_scale 1 2 0 0 -./bin/example_gemm_mx_fp8_fp8_scale +# Implies: ./bin/example_gemm_mx_fp8 1 2 0 0 +./bin/example_gemm_mx_fp8 ``` \ No newline at end of file diff --git a/example/67_gemm_microscaling/gemm_mx_common.hpp b/example/67_gemm_microscaling/gemm_mx_common.hpp index 9a05954c73..32ef975192 100644 --- a/example/67_gemm_microscaling/gemm_mx_common.hpp +++ b/example/67_gemm_microscaling/gemm_mx_common.hpp @@ -95,7 +95,7 @@ bool parse_cmd_args(int argc, << std::endl << "arg3: time kernel (0=no, 1=yes)" << std::endl << "arg4: verbosity (0=no info, 1=verbose info)" << std::endl - << "arg5 to 10: M(128x), N(128x), K(64x), StrideA, StrideB, StrideC" << std::endl + << "arg5 to 10: M(128x), N(128x), K(256x), StrideA, StrideB, StrideC" << std::endl << "arg11: KBatch" << std::endl; return false; } @@ -103,7 +103,8 @@ bool parse_cmd_args(int argc, return true; } -template + ck::index_t ScaleBlockSize> bool run_mx_gemm(const ProblemSizeSplitK& problem_size, const ExecutionConfig& config) { - static constexpr auto GemmSpec = ck::tensor_operation::device::GemmSpecialization::Default; - static constexpr auto BlkGemmPSched = ck::BlockGemmPipelineScheduler::Intrawave; - static constexpr auto BlkGemmPVer = ck::BlockGemmPipelineVersion::v1; - - static constexpr ck::index_t ScaleBlockSize = MXVectorSize; - - static constexpr ck::index_t KPerBlock = 64; - using DeviceOpInstance = ck::tensor_operation::device::DeviceGemmMX_Xdl_CShuffleV3< - ALayout, // ALayout - BLayout, // BLayout - CLayout, // CLayout - ADataType, // ADataType - XDataType, // AScaleDataType - BDataType, // BDataType - XDataType, // BScaleDataType - CDataType, // CDataType - AccDataType, // GemmAccDataType 
- CShuffleDataType, // CShuffleDataType - AElementOp, // AElementwiseOperation - BElementOp, // BElementwiseOperation - CElementOp, // CElementwiseOperation - GemmSpec, // GemmSpec - MXVectorSize, // ScaleBlockSize: Scaling block size - 256, // BlockSize: Thread block size - 128, // MPerBlock - 128, // NPerBlock - KPerBlock, // KPerBlock - 16, // AK1 - 16, // BK1 - 32, // MPerXDL - 32, // NPerXDL - 2, // MXdlPerWave - 2, // NXdlPerWave - S<4, 64, 1>, // ABlockTransferThreadClusterLengths_AK0_M_AK1 - S<1, 0, 2>, // ABlockTransferThreadClusterArrangeOrder - S<1, 0, 2>, // ABlockTransferSrcAccessOrder - 2, // ABlockTransferSrcVectorDim - 16, // ABlockTransferSrcScalarPerVector - 16, // ABlockTransferDstScalarPerVector_AK1 - false, // ABlockLdsExtraM - S<4, 64, 1>, // BBlockTransferThreadClusterLengths_BK0_N_BK1 - S<1, 0, 2>, // BBlockTransferThreadClusterArrangeOrder - S<1, 0, 2>, // BBlockTransferSrcAccessOrder - 2, // BBlockTransferSrcVectorDim - 16, // BBlockTransferSrcScalarPerVector - 16, // BBlockTransferDstScalarPerVector_BK1 - false, // BBlockLdsExtraN - 1, // CShuffleMXdlPerWavePerShuffle - 1, // CShuffleNXdlPerWavePerShuffle - S<1, 32, 1, 8>, // CShuffleBlockTransferClusterLengths_MBlock_MPerBlock_NBlock_NPerBlock - 8, // CShuffleBlockTransferScalarPerVector_NPerBlock - BlkGemmPSched, // BlkGemmPipeSched - BlkGemmPVer, // BlkGemmPipelineVer - ADataType, // ComputeTypeA - BDataType // ComputeTypeB - >; auto M = problem_size.M; auto N = problem_size.N; @@ -230,8 +175,8 @@ bool run_mx_gemm(const ProblemSizeSplitK& problem_size, const ExecutionConfig& c auto Scale_Stride_AM = f_get_default_stride(M, K / ScaleBlockSize, -1, AScaleLayout{}); auto Scale_Stride_BN = f_get_default_stride(K / ScaleBlockSize, N, -1, BScaleLayout{}); - Tensor a_m_k(f_host_tensor_descriptor(M, K, StrideA, AScaleLayout{})); - Tensor b_k_n(f_host_tensor_descriptor(K, N, StrideB, BScaleLayout{})); + Tensor a_m_k(f_host_tensor_descriptor(M, K, StrideA, ALayout{})); + Tensor 
b_k_n(f_host_tensor_descriptor(K, N, StrideB, BLayout{})); Tensor a_m_k_scale(f_host_tensor_descriptor( M, K / ScaleBlockSize, Scale_Stride_AM, AScaleLayout{})); // scales for A @@ -428,8 +373,10 @@ bool run_mx_gemm(const ProblemSizeSplitK& problem_size, const ExecutionConfig& c if(config.time_kernel) { - std::size_t flop = std::size_t(2) * M * N * K + - std::size_t(2) * M * N * K / ScaleBlockSize; // GEMM + A scale + B scale + // Output size(M*N) * [dot product(2K) + product of scales(K/ScaleBlockSize) + scaling of + // partial sums(K/ScaleBlockSize)] + // FLOPS = 2 * M * N * K + 2 * M * N * K / ScaleBlockSize + std::size_t flop = std::size_t(2) * M * N * K + std::size_t(2) * M * N * K / ScaleBlockSize; std::size_t num_btype = sizeof(ADataType) * M * K + sizeof(BDataType) * K * N + sizeof(CDataType) * M * N + sizeof(XDataType) * (M * K + K * N) / ScaleBlockSize; @@ -445,7 +392,8 @@ bool run_mx_gemm(const ProblemSizeSplitK& problem_size, const ExecutionConfig& c return res_verified; } -template , // ABlockTransferThreadClusterLengths_AK0_M_AK1 + S<1, 0, 2>, // ABlockTransferThreadClusterArrangeOrder + S<1, 0, 2>, // ABlockTransferSrcAccessOrder + 2, // ABlockTransferSrcVectorDim + 16, // ABlockTransferSrcScalarPerVector + 16, // ABlockTransferDstScalarPerVector_AK1 + false, // ABlockLdsExtraM + S<4, 64, 1>, // BBlockTransferThreadClusterLengths_BK0_N_BK1 + S<1, 0, 2>, // BBlockTransferThreadClusterArrangeOrder + S<1, 0, 2>, // BBlockTransferSrcAccessOrder + 2, // BBlockTransferSrcVectorDim + 16, // BBlockTransferSrcScalarPerVector + 16, // BBlockTransferDstScalarPerVector_BK1 + false, // BBlockLdsExtraN + 1, // CShuffleMXdlPerWavePerShuffle + 1, // CShuffleNXdlPerWavePerShuffle + S<1, 32, 1, 8>, // CShuffleBlockTransferClusterLengths_MBlock_MPerBlock_NBlock_NPerBlock + 8, // CShuffleBlockTransferScalarPerVector_NPerBlock + BlkGemmPSched, // BlkGemmPipeSched + BlkGemmPVer, // BlkGemmPipelineVer + ADataType, // ComputeTypeA + BDataType // ComputeTypeB + >; + +int 
main(int argc, char* argv[]) +{ + return run_mx_gemm_example(argc, argv) + ? 0 + : -1; +} diff --git a/example/67_gemm_microscaling/gemm_mx_fp8_e8m0_scale.cpp b/example/67_gemm_microscaling/gemm_mx_fp8_e8m0_scale.cpp deleted file mode 100644 index 393f4a2ea7..0000000000 --- a/example/67_gemm_microscaling/gemm_mx_fp8_e8m0_scale.cpp +++ /dev/null @@ -1,42 +0,0 @@ -// SPDX-License-Identifier: MIT -// Copyright (c) 2025, Advanced Micro Devices, Inc. All rights reserved. - -#include "gemm_mx_common.hpp" - -using ADataType = ck::f8_t; -using BDataType = ck::f8_t; - -using XDataType = ck::e8m0_bexp_t; - -using CDataType = ck::half_t; -using AccDataType = float; -using CShuffleDataType = CDataType; - -using ALayout = Row; -using BLayout = Col; -using CLayout = Row; - -using AElementOp = PassThrough; // elementwise transformation for A matrix -using BElementOp = PassThrough; // elementwise transformation for B matrix -using CElementOp = PassThrough; // elementwise transformation for C matrix - -constexpr ck::index_t mx_vector_size = 32; // scaling block size - -int main(int argc, char* argv[]) -{ - return run_mx_gemm_example(argc, argv) - ? 0 - : -1; -} diff --git a/example/67_gemm_microscaling/gemm_mx_fp8_fp16_scale.cpp b/example/67_gemm_microscaling/gemm_mx_fp8_fp16_scale.cpp deleted file mode 100644 index dd654a8f69..0000000000 --- a/example/67_gemm_microscaling/gemm_mx_fp8_fp16_scale.cpp +++ /dev/null @@ -1,42 +0,0 @@ -// SPDX-License-Identifier: MIT -// Copyright (c) 2025, Advanced Micro Devices, Inc. All rights reserved. 
- -#include "gemm_mx_common.hpp" - -using ADataType = ck::f8_t; -using BDataType = ck::f8_t; - -using XDataType = ck::half_t; - -using CDataType = ck::half_t; -using AccDataType = float; -using CShuffleDataType = CDataType; - -using ALayout = Row; -using BLayout = Col; -using CLayout = Row; - -using AElementOp = PassThrough; // elementwise transformation for A matrix -using BElementOp = PassThrough; // elementwise transformation for B matrix -using CElementOp = PassThrough; // elementwise transformation for C matrix - -constexpr ck::index_t mx_vector_size = 32; // scaling block size - -int main(int argc, char* argv[]) -{ - return run_mx_gemm_example(argc, argv) - ? 0 - : -1; -} diff --git a/example/67_gemm_microscaling/gemm_mx_fp8_fp8_scale.cpp b/example/67_gemm_microscaling/gemm_mx_fp8_fp8_scale.cpp deleted file mode 100644 index c42d9783be..0000000000 --- a/example/67_gemm_microscaling/gemm_mx_fp8_fp8_scale.cpp +++ /dev/null @@ -1,42 +0,0 @@ -// SPDX-License-Identifier: MIT -// Copyright (c) 2025, Advanced Micro Devices, Inc. All rights reserved. - -#include "gemm_mx_common.hpp" - -using ADataType = ck::f8_t; -using BDataType = ck::f8_t; - -using XDataType = ck::f8_t; - -using CDataType = ck::half_t; -using AccDataType = float; -using CShuffleDataType = CDataType; - -using ALayout = Row; -using BLayout = Col; -using CLayout = Row; - -using AElementOp = PassThrough; // elementwise transformation for A matrix -using BElementOp = PassThrough; // elementwise transformation for B matrix -using CElementOp = PassThrough; // elementwise transformation for C matrix - -constexpr ck::index_t mx_vector_size = 32; // scaling block size - -int main(int argc, char* argv[]) -{ - return run_mx_gemm_example(argc, argv) - ? 
0 - : -1; -} diff --git a/include/ck/tensor_operation/gpu/block/blockwise_gemm_mx_pipeline_xdlops_base.hpp b/include/ck/tensor_operation/gpu/block/blockwise_gemm_mx_pipeline_xdlops_base.hpp new file mode 100644 index 0000000000..ebe075b55d --- /dev/null +++ b/include/ck/tensor_operation/gpu/block/blockwise_gemm_mx_pipeline_xdlops_base.hpp @@ -0,0 +1,363 @@ +// SPDX-License-Identifier: MIT +// Copyright (c) 2025, Advanced Micro Devices, Inc. All rights reserved. + +#pragma once + +#include "ck/utility/common_header.hpp" +#include "ck/utility/blkgemmpipe_scheduler.hpp" +#include "ck/tensor_operation/gpu/thread/threadwise_tensor_slice_transfer.hpp" +#include "ck/tensor_operation/gpu/warp/xdlops_gemm.hpp" +#include "ck/tensor_description/tensor_adaptor.hpp" + +namespace ck { + +template +struct BlockwiseGemmXdlops_mx_pipeline_base +{ + using ComputeTypeA = ADataType; + using ComputeTypeB = BDataType; + using AccType = float; // for now only support V_MFMA_SCALE_F32 + + static constexpr auto I0 = Number<0>{}; + static constexpr auto I1 = Number<1>{}; + static constexpr auto I2 = Number<2>{}; + static constexpr auto I3 = Number<3>{}; + + using ThisThreadBlock = ThisThreadBlock; + + // Hardcode to 64, as HIP-provided "warpSize" would return 32 on RDNA GPUs. + static constexpr index_t WaveSize = 64; + + static constexpr index_t A_K0 = ATileDesc{}.GetLength(I0); + static constexpr index_t B_K0 = BTileDesc{}.GetLength(I0); + static constexpr index_t A_K1 = ATileDesc{}.GetLength(I2); + static constexpr index_t B_K1 = BTileDesc{}.GetLength(I2); + + static constexpr auto xdlops_gemm = + XdlopsGemm{}; + + static constexpr index_t AMmaKStride = KPack; + static constexpr index_t BMmaKStride = KPack; + + //> store rows/cols into thread registers in chunks of 16 + //> e.g. 
[k0,...,k15,k64,...,k79] or [k0,...,k15,k32,...,k47] + static constexpr index_t KThreadChunk = 16; + + static constexpr index_t KPerThread = KPerBlock / xdlops_gemm.K0PerXdlops; + static constexpr index_t KRepeat = KPerThread / KPack; + static constexpr index_t KPerInnerLoop = KPack; + + static constexpr index_t MWaves = MPerBlock / (MRepeat * MPerXDL); + static constexpr index_t NWaves = NPerBlock / (NRepeat * NPerXDL); + + using HotLoopInstList = + ck::BlockwiseGemmXdlops_pipeline_hotloop_inst; + + static_assert(KPerThread % KPack == 0, + "Wrong KPack setting; try increasing KPerThread or decreasing KPack"); + + StaticBufferTupleOfVector + c_thread_buf_; + + __host__ __device__ constexpr auto& GetCThreadBuffer() { return c_thread_buf_; } + + __device__ static auto GetWaveIdx() + { + const index_t thread_id = ThisThreadBlock::GetThreadId(); + + constexpr auto threadid_to_wave_idx_adaptor = make_single_stage_tensor_adaptor( + make_tuple(make_merge_transform(make_tuple(MWaves, NWaves, WaveSize))), + make_tuple(Sequence<0, 1, 2>{}), + make_tuple(Sequence<0>{})); + + return threadid_to_wave_idx_adaptor.CalculateBottomIndex(make_multi_index(thread_id)); + } + + __device__ static auto CalculateAThreadOriginDataIndex() + { + const auto wave_idx = GetWaveIdx(); + + const auto waveId_m = wave_idx[I0]; + + const auto xdlops_a_idx = xdlops_gemm.CalculateAThreadOriginDataIndex(); + + return make_tuple(0, waveId_m, xdlops_a_idx[I1], KThreadChunk * xdlops_a_idx[I0]); + } + + __device__ static auto CalculateBThreadOriginDataIndex() + { + const auto wave_idx = GetWaveIdx(); + + const auto waveId_n = wave_idx[I1]; + + const auto xdlops_b_idx = xdlops_gemm.CalculateBThreadOriginDataIndex(); + + return make_tuple(0, waveId_n, xdlops_b_idx[I1], KThreadChunk * xdlops_b_idx[I0]); + } + + template + __device__ static auto + CalculateCThreadOriginDataIndex(Number, Number, Number, Number) + { + const auto wave_idx = GetWaveIdx(); + + const auto waveId_m = wave_idx[I0]; + const auto 
waveId_n = wave_idx[I1]; + + const auto blk_idx = xdlops_gemm.GetBeginOfThreadBlk(xdlops_i, blk_i); + + constexpr auto mrepeat_mwave_mperxdl_to_m_adaptor = make_single_stage_tensor_adaptor( + make_tuple(make_unmerge_transform(make_tuple(MRepeat, MWaves, MPerXDL))), + make_tuple(Sequence<0>{}), + make_tuple(Sequence<0, 1, 2>{})); + + constexpr auto nrepeat_nwave_nperxdl_to_n_adaptor = make_single_stage_tensor_adaptor( + make_tuple(make_unmerge_transform(make_tuple(NRepeat, NWaves, NPerXDL))), + make_tuple(Sequence<0>{}), + make_tuple(Sequence<0, 1, 2>{})); + + const index_t c_thread_m = mrepeat_mwave_mperxdl_to_m_adaptor.CalculateBottomIndex( + make_tuple(m0, waveId_m, blk_idx[I0]))[I0]; + const index_t c_thread_n = nrepeat_nwave_nperxdl_to_n_adaptor.CalculateBottomIndex( + make_tuple(n0, waveId_n, blk_idx[I1]))[I0]; + + return make_tuple(c_thread_m, c_thread_n); + } + + using Tuple4 = decltype(CalculateAThreadOriginDataIndex()); + + /** + * @brief Constructor for BlockwiseGemmXdlops_mx_pipeline_base. + * + * This constructor initializes the thread copy objects for matrices A and B. + * It also performs several compile-time checks to ensure the correctness of the + * matrix tile descriptors. + * + * @param a_origin The origin data index for matrix A. + * @param b_origin The origin data index for matrix B. + * + * @note The constructor includes static assertions to ensure that: + * - The matrix tile descriptors for A and B are known at compile-time. + * - The number of threads in the thread block matches the product of MWaves, NWaves, and + * WaveSize. + * - The dimensions of the block are divisible by the product of the corresponding XDL and + * repeat dimensions. 
+ */ + __host__ __device__ + BlockwiseGemmXdlops_mx_pipeline_base(Tuple4 a_origin = CalculateAThreadOriginDataIndex(), + Tuple4 b_origin = CalculateBThreadOriginDataIndex()) + : a_thread_copy_(a_origin), b_thread_copy_(b_origin) + { + static_assert(AMmaTileDesc::IsKnownAtCompileTime() && BMmaTileDesc::IsKnownAtCompileTime(), + "wrong! Desc should be known at compile-time"); + + static_assert(ThisThreadBlock::GetNumOfThread() == MWaves * NWaves * WaveSize, + "ThisThreadBlock::GetNumOfThread() != MWaves * NWaves * WaveSize\n"); + + static_assert(MPerBlock % (MPerXDL * MRepeat) == 0 && NPerBlock % (NPerXDL * NRepeat) == 0, + "wrong!"); + } + + // transposed XDL output supporting C_xdl' = B_xdl' * A_xdl' + __host__ __device__ static constexpr auto GetCThreadDescriptor_M0_N0_M1_N1_M2_N2_N3_N4() + { + constexpr auto c_m0_m1_m2_n_tblk_lens = xdlops_gemm.GetCM0M1M2NThreadBlkLengths(); + + constexpr auto M0 = c_m0_m1_m2_n_tblk_lens[I0]; + constexpr auto M1 = c_m0_m1_m2_n_tblk_lens[I1]; + constexpr auto M2 = c_m0_m1_m2_n_tblk_lens[I2]; + constexpr auto N = c_m0_m1_m2_n_tblk_lens[I3]; + + return make_naive_tensor_descriptor_packed( + make_tuple(Number{}, Number{}, I1, I1, N, M0, M1, M2)); + } + + // XDL output supporting C_xdl = A_xdl * B_xdl + __host__ __device__ static constexpr auto GetCThreadDescriptor_M0_N0_M1_N1_M2_M3_M4_N2() + { + constexpr auto c_m0_m1_m2_n_tblk_lens = xdlops_gemm.GetCM0M1M2NThreadBlkLengths(); + + constexpr auto M0 = c_m0_m1_m2_n_tblk_lens[I0]; + constexpr auto M1 = c_m0_m1_m2_n_tblk_lens[I1]; + constexpr auto M2 = c_m0_m1_m2_n_tblk_lens[I2]; + constexpr auto N = c_m0_m1_m2_n_tblk_lens[I3]; + + return make_naive_tensor_descriptor_packed( + make_tuple(Number{}, Number{}, I1, I1, M0, M1, M2, N)); + } + + __host__ __device__ static constexpr auto GetCThreadDescriptor_G_M0_N0_M1_N1_M2_M3_M4_N2() + { + constexpr auto c_m0_m1_m2_n_tblk_lens = xdlops_gemm.GetCM0M1M2NThreadBlkLengths(); + + constexpr auto M0 = c_m0_m1_m2_n_tblk_lens[I0]; + constexpr auto M1 
= c_m0_m1_m2_n_tblk_lens[I1]; + constexpr auto M2 = c_m0_m1_m2_n_tblk_lens[I2]; + constexpr auto N = c_m0_m1_m2_n_tblk_lens[I3]; + + return make_naive_tensor_descriptor_packed( + make_tuple(I1, Number{}, Number{}, I1, I1, M0, M1, M2, N)); + } + + // transposed XDL output supporting C_xdl' = B_xdl' * A_xdl' + __host__ __device__ static constexpr auto GetCBlockDescriptor_M0_N0_M1_N1_M2_N2_N3_N4() + { + constexpr auto c_block_desc_m0_n0_m1_n1_m2_n2 = + make_naive_tensor_descriptor_packed(make_tuple(Number{}, + Number{}, + Number{}, + Number{}, + Number{}, + Number{})); + + return xdlops_gemm.MakeCDescriptor_M0_N0_M1_N1_M2_N2_N3_N4(c_block_desc_m0_n0_m1_n1_m2_n2); + } + + // XDL output supporting C_xdl = A_xdl * B_xdl + __host__ __device__ static constexpr auto GetCBlockDescriptor_M0_N0_M1_N1_M2_M3_M4_N2() + { + constexpr auto c_block_desc_m0_n0_m1_n1_m2_n2 = + make_naive_tensor_descriptor_packed(make_tuple(Number{}, + Number{}, + Number{}, + Number{}, + Number{}, + Number{})); + + return xdlops_gemm.MakeCDescriptor_M0_N0_M1_N1_M2_M3_M4_N2(c_block_desc_m0_n0_m1_n1_m2_n2); + } + + __host__ __device__ static constexpr auto GetCBlockDescriptor_G_M0_N0_M1_N1_M2_M3_M4_N2() + { + constexpr auto c_block_desc_g_m0_n0_m1_n1_m2_n2 = + make_naive_tensor_descriptor_packed(make_tuple(I1, + Number{}, + Number{}, + Number{}, + Number{}, + Number{}, + Number{})); + + return xdlops_gemm.MakeCDescriptor_G_M0_N0_M1_N1_M2_M3_M4_N2( + c_block_desc_g_m0_n0_m1_n1_m2_n2); + } + + template + __host__ __device__ static constexpr auto + MakeCGridDescriptor_M0_N0_M1_N1_M2_M3_M4_N2(const CGridDesc_M_N& c_grid_desc_m_n) + { + const auto M = c_grid_desc_m_n.GetLength(I0); + const auto N = c_grid_desc_m_n.GetLength(I1); + + const auto c_grid_desc_m0_n0_m1_n1_m2_n2 = transform_tensor_descriptor( + c_grid_desc_m_n, + make_tuple(make_unmerge_transform(make_tuple(M / (MWaves * MPerXDL), MWaves, MPerXDL)), + make_unmerge_transform(make_tuple(N / (NWaves * NPerXDL), NWaves, NPerXDL))), + 
make_tuple(Sequence<0>{}, Sequence<1>{}), + make_tuple(Sequence<0, 2, 4>{}, Sequence<1, 3, 5>{})); + + return xdlops_gemm.MakeCDescriptor_M0_N0_M1_N1_M2_M3_M4_N2(c_grid_desc_m0_n0_m1_n1_m2_n2); + } + + template + __host__ __device__ static constexpr auto + MakeCGridDescriptor_G_M0_N0_M1_N1_M2_M3_M4_N2(const CGridDesc_G_M_N& c_grid_desc_g_m_n) + { + const auto G = c_grid_desc_g_m_n.GetLength(I0); + const auto M = c_grid_desc_g_m_n.GetLength(I1); + const auto N = c_grid_desc_g_m_n.GetLength(I2); + + const auto c_grid_desc_g_m0_n0_m1_n1_m2_n2 = transform_tensor_descriptor( + c_grid_desc_g_m_n, + make_tuple(make_pass_through_transform(G), + make_unmerge_transform(make_tuple(M / (MWaves * MPerXDL), MWaves, MPerXDL)), + make_unmerge_transform(make_tuple(N / (NWaves * NPerXDL), NWaves, NPerXDL))), + make_tuple(Sequence<0>{}, Sequence<1>{}, Sequence<2>{}), + make_tuple(Sequence<0>{}, Sequence<1, 3, 5>{}, Sequence<2, 4, 6>{})); + + return xdlops_gemm.MakeCDescriptor_G_M0_N0_M1_N1_M2_M3_M4_N2( + c_grid_desc_g_m0_n0_m1_n1_m2_n2); + } + + static constexpr AMmaTileDesc a_block_desc_m0_m1_m2_k; + static constexpr BMmaTileDesc b_block_desc_n0_n1_n2_k; + + protected: + // M1, N1 as double buffer index + // Read buffer + Compute buffer + // A[M0, M1, M2, KPack] + static constexpr auto a_thread_desc_ = make_naive_tensor_descriptor( + make_tuple(Number{}, I1, Number{}, Number{}), + make_tuple( + Number{}, Number{}, Number{}, I1)); + + // B[N0, N1, N2, KPack] + static constexpr auto b_thread_desc_ = make_naive_tensor_descriptor( + make_tuple(Number{}, I1, Number{}, Number{}), + make_tuple( + Number{}, Number{}, Number{}, I1)); + + // C[M, N, NumRegXdlops] + static constexpr auto c_thread_desc_ = make_naive_tensor_descriptor_packed( + make_tuple(Number{}, Number{}, xdlops_gemm.GetRegSizePerXdlops())); + + using AThreadCopy = ThreadwiseTensorSliceTransfer_v4, + Sequence<0, 1, 2, 3>, + 3, + A_K1, + A_K1>; + + using BThreadCopy = ThreadwiseTensorSliceTransfer_v4, + Sequence<0, 1, 2, 3>, + 
3, + B_K1, + B_K1>; + + AThreadCopy a_thread_copy_; + BThreadCopy b_thread_copy_; +}; + +} // namespace ck diff --git a/include/ck/tensor_operation/gpu/block/blockwise_gemm_pipeline_xdlops_mx_selector.hpp b/include/ck/tensor_operation/gpu/block/blockwise_gemm_pipeline_xdlops_mx_selector.hpp index 24f6afc381..c1433659d6 100644 --- a/include/ck/tensor_operation/gpu/block/blockwise_gemm_pipeline_xdlops_mx_selector.hpp +++ b/include/ck/tensor_operation/gpu/block/blockwise_gemm_pipeline_xdlops_mx_selector.hpp @@ -7,6 +7,35 @@ namespace ck { +/** + * @brief Define matrix data types that have hardware support for MX GEMMs + */ +template +static constexpr bool is_scale_mfma_data_type() +{ + return is_same_v || is_same_v || is_same_v || + is_same_v || is_same_v; +} + +/** + * @brief Define scale data types that have hardware support for MX GEMMs + */ +template +static constexpr bool is_scale_mfma_scale_type() +{ + return is_same_v; +} + +/** + * @brief Combination of data types that have hardware support for MX GEMMs + */ +template +static constexpr bool scale_mfma_hw_support() +{ + return is_scale_mfma_data_type() && is_scale_mfma_data_type() && + is_scale_mfma_scale_type() && is_scale_mfma_scale_type(); +} + template constexpr auto BlockGemmMXPipeline_Selector() { + + // Hardware MX GEMM pipeline if constexpr(BlkGemmPipelineVer == BlockGemmPipelineVersion::v1) { return BlockwiseGemmXdlops_pipeline_v1_mx - : BlockwiseGemmXdlops_pipeline_base + : BlockwiseGemmXdlops_mx_pipeline_base { - using Base = BlockwiseGemmXdlops_pipeline_base; + + using Base = BlockwiseGemmXdlops_mx_pipeline_base; using Base::I0; using Base::I1; using Base::KRepeat; @@ -134,7 +125,6 @@ struct BlockwiseGemmXdlops_pipeline_v1_mx How many mx-vectors in each row/col is processed in one call to xdlops_gemm.Run() + static constexpr auto ScalesPerXdlopsRun = (KPack * xdlops_gemm.K0PerXdlops) / ScaleBlockSize; + + //> How many scales a thread must read to accommodate one call to xdlops_gemm.Run() + static 
constexpr auto ScalesPerXdlopsRunPerThread = + ScalesPerXdlopsRun / xdlops_gemm.mfma_instr.num_input_blks; __host__ static constexpr bool BlockHasHotloop(index_t num_loop) { @@ -172,45 +173,6 @@ struct BlockwiseGemmXdlops_pipeline_v1_mx( + auto a_thread_buf = make_static_buffer( a_thread_desc_.GetElementSpaceSize()); - auto b_thread_buf = make_static_buffer( + auto b_thread_buf = make_static_buffer( b_thread_desc_.GetElementSpaceSize()); auto a_scale_thread_buf = make_static_buffer( @@ -276,49 +238,31 @@ struct BlockwiseGemmXdlops_pipeline_v1_mx{}([&](auto m0) { static_for<0, KRepeat, 1>{}([&](auto k0) { - static_for<0, xdlops_gemm.mfma_instr.num_groups_per_blk, 1>{}([&](auto g) { - auto a_scale_thread_buf_group = + static_for<0, ScalesPerXdlopsRunPerThread, 1>{}([&](auto s) { + constexpr auto a_scale_offset = + a_scale_thread_desc.CalculateOffset(make_tuple(m0, k0, s)); + auto a_scale_thread_buf_copy = make_static_buffer( - a_scale_thread_desc_group.GetElementSpaceSize()); - + a_scale_thread_desc_copy.GetElementSpaceSize()); a_scale_thread_copy.Run(a_scale_grid_desc, a_scale_grid_buf, - a_scale_thread_desc_group, + a_scale_thread_desc_copy, make_tuple(I0, I0), - a_scale_thread_buf_group); + a_scale_thread_buf_copy); - static_for<0, xdlops_gemm.mfma_instr.group_size, 1>{}([&](auto i) { - constexpr index_t a_scale_offset = - a_scale_thread_desc.CalculateOffset(make_tuple(m0, k0, g, i)); - a_scale_thread_buf(Number{}) = - a_scale_thread_buf_group[Number{}]; - }); - // go to the next group + a_scale_thread_buf(Number{}) = + a_scale_thread_buf_copy[Number<0>{}]; a_scale_thread_copy.MoveSrcSliceWindow( a_scale_grid_desc, - make_multi_index(2 * xdlops_gemm.mfma_instr.group_size, 0)); - }); // g - - // restore row id and advance to the next scale - a_scale_thread_copy.MoveSrcSliceWindow( - a_scale_grid_desc, - make_multi_index(-2 * xdlops_gemm.mfma_instr.group_size * - xdlops_gemm.mfma_instr.num_groups_per_blk, - 1)); - }); // k0 - - // restore column id and advance to 
the next set of rows + make_multi_index(0, xdlops_gemm.KPerXdlops / ScaleBlockSize)); + }); + }); a_scale_thread_copy.MoveSrcSliceWindow( a_scale_grid_desc, make_multi_index(MWaves * MPerXDL, -ScalesPerKBlockSize)); - }); // m0 + }); // restore row id and advance to the next set of scales a_scale_thread_copy.MoveSrcSliceWindow(a_scale_grid_desc, @@ -326,15 +270,32 @@ struct BlockwiseGemmXdlops_pipeline_v1_mx{}([&](auto n0) { - b_scale_thread_copy.Run(b_scale_grid_desc, - b_scale_grid_buf, - b_scale_thread_desc, - make_tuple(n0, I0), - b_scale_thread_buf); - b_scale_thread_copy.MoveSrcSliceWindow(b_scale_grid_desc, - make_multi_index(NWaves * NPerXDL, 0)); + static_for<0, KRepeat, 1>{}([&](auto k0) { + static_for<0, ScalesPerXdlopsRunPerThread, 1>{}([&](auto s) { + constexpr auto b_scale_offset = + b_scale_thread_desc.CalculateOffset(make_tuple(n0, k0, s)); + auto b_scale_thread_buf_copy = + make_static_buffer( + b_scale_thread_desc_copy.GetElementSpaceSize()); + b_scale_thread_copy.Run(b_scale_grid_desc, + b_scale_grid_buf, + b_scale_thread_desc_copy, + make_tuple(I0, I0), + b_scale_thread_buf_copy); + + b_scale_thread_buf(Number{}) = + b_scale_thread_buf_copy[Number<0>{}]; + b_scale_thread_copy.MoveSrcSliceWindow( + b_scale_grid_desc, + make_multi_index(0, xdlops_gemm.KPerXdlops / ScaleBlockSize)); + }); + }); + b_scale_thread_copy.MoveSrcSliceWindow( + b_scale_grid_desc, make_multi_index(NWaves * NPerXDL, -ScalesPerKBlockSize)); }); + // restore col id and advance to the next set of scales + // NWaves * NPerXDL * NRepeat == NPerBlock b_scale_thread_copy.MoveSrcSliceWindow(b_scale_grid_desc, make_multi_index(-NPerBlock, ScalesPerKBlockSize)); @@ -345,8 +306,6 @@ struct BlockwiseGemmXdlops_pipeline_v1_mx(); - // main body if constexpr(HasMainLoop) { @@ -363,141 +322,166 @@ struct BlockwiseGemmXdlops_pipeline_v1_mx 15 32 --> 47 | 64 --> 79 96 --> 111 | etc. + // t32: |16 --> 31 48 --> 63 | 80 --> 95 112 --> 127 | etc. 
+ // k = 0 k = 1 + + // k indexes mapping to threads for 16x16x128: + // t0 : |0 --> 15 64 --> 79 | 128 --> 143 192 --> 207| etc. + // t16: |16 --> 31 80 --> 95 | 144 --> 159 208 --> 223| etc. + // t32: |32 --> 47 96 --> 111| 160 --> 175 224 --> 239| etc. + // t48: |48 --> 63 112 --> 127| 176 --> 191 240 --> 255| etc. + // k = 0 k = 1 static_for<0, KRepeat, 1>{}([&](auto k) { - constexpr auto a_k_step = k * AMmaKStride * KPack / xdlops_gemm.K1PerXdlops; - constexpr auto b_k_step = k * BMmaKStride * KPack / xdlops_gemm.K1PerXdlops; + constexpr auto k_step = + k * xdlops_gemm.KPerXdlops * (KPack / xdlops_gemm.K1PerXdlops); static_for<0, MRepeat, 1>{}([&](auto m0) { - a_thread_copy_.Run(a_block_desc_m0_m1_m2_k, - make_tuple(m0, I0, I0, Number{}), - a_block_buf, - a_thread_desc_, - make_tuple(m0, I0, k, I0), - a_thread_buf); + static_for<0, xdlops_gemm.K1PerXdlops / KThreadChunk, 1>{}([&](auto chunk) { + constexpr auto a_k_step_chunk = + k_step + + chunk * KThreadChunk * xdlops_gemm.mfma_instr.num_input_blks; + a_thread_copy_.Run( + a_block_desc_m0_m1_m2_k, + make_tuple(m0, I0, I0, Number{}), + a_block_buf, + a_thread_desc_, + make_tuple(m0, I0, k, Number{}), + a_thread_buf); + }); }); static_for<0, NRepeat, 1>{}([&](auto n0) { - b_thread_copy_.Run(b_block_desc_n0_n1_n2_k, - make_tuple(n0, I0, I0, Number{}), - b_block_buf, - b_thread_desc_, - make_tuple(n0, I0, k, I0), - b_thread_buf); + // read block data in chunks to assemble correct thread vectors + static_for<0, xdlops_gemm.K1PerXdlops / KThreadChunk, 1>{}([&](auto chunk) { + constexpr auto b_k_step_chunk = + k_step + + chunk * KThreadChunk * xdlops_gemm.mfma_instr.num_input_blks; + b_thread_copy_.Run( + b_block_desc_n0_n1_n2_k, + make_tuple(n0, I0, I0, Number{}), + b_block_buf, + b_thread_desc_, + make_tuple(n0, I0, k, Number{}), + b_thread_buf); + }); }); }); static_for<0, MRepeat, 1>{}([&](auto m0) { static_for<0, NRepeat, 1>{}([&](auto n0) { static_for<0, KRepeat, 1>{}([&](auto k0) { - 
c_thread_buf_per_scale.Clear(); - vector_type a_thread_vec; - vector_type b_thread_vec; + vector_type a_thread_vec; + vector_type b_thread_vec; static_for<0, KPack, 1>{}([&](auto ik) { - a_thread_vec.template AsType()(ik) = + a_thread_vec.template AsType()(ik) = a_thread_buf[Number{}]; - b_thread_vec.template AsType()(ik) = + b_thread_vec.template AsType()(ik) = b_thread_buf[Number{}]; }); - using mfma_input_type = - typename vector_type::type; + constexpr index_t a_scale_offset = + a_scale_thread_desc.CalculateOffset(make_tuple(m0, k0, I0)); + constexpr index_t b_scale_offset = + b_scale_thread_desc.CalculateOffset(make_tuple(n0, k0, I0)); + + static_assert(0 < ScalesPerXdlopsRunPerThread, + "Must have at least one scale per Xdlops per Thread."); + + vector_type + a_scale_thread_vec; + vector_type + b_scale_thread_vec; + + // Pack scale_thread_buf into scale_thread_vec + static_for<0, ScalesPerXdlopsRunPerThread, 1>{}([&](auto s) { + a_scale_thread_vec.template AsType()(s) = + a_scale_thread_buf[Number{}]; + b_scale_thread_vec.template AsType()(s) = + b_scale_thread_buf[Number{}]; + }); + + using mfma_input_type_a = + typename vector_type::type; + using mfma_input_type_b = + typename vector_type::type; + + constexpr index_t c_offset = + c_thread_desc_.CalculateOffset(make_tuple(m0, n0, 0)); // MFMA accumulation - // m = 1:MPerXDL - // n = 1:NPerXDL - // k = 1:KPack - // c(m,n) += a(m,k)*b(k,n) xdlops_gemm.template Run<>( - a_thread_vec.template AsType(), - b_thread_vec.template AsType(), - c_thread_buf_per_scale.GetVectorTypeReference(I0)); - - // one scale per k0 - constexpr index_t b_scale_offset = - b_scale_thread_desc.CalculateOffset(make_tuple(n0, k0)); - - static_for<0, xdlops_gemm.mfma_instr.num_groups_per_blk, 1>{}( - [&](auto g) { - static_for<0, xdlops_gemm.mfma_instr.group_size, 1>{}( - [&](auto r) { - constexpr index_t a_scale_offset = - a_scale_thread_desc.CalculateOffset( - make_tuple(m0, k0, g, r)); - - constexpr auto reg_offset = - g * 
xdlops_gemm.mfma_instr.group_size + r; - - constexpr index_t c_offset = - c_thread_desc_.CalculateOffset( - make_tuple(m0, n0, reg_offset)); - - c_thread_buf(Number{}) += - c_thread_buf_per_scale[Number{}] * - type_convert( - b_scale_thread_buf[Number{}]) * - type_convert( - a_scale_thread_buf[Number{}]); - }); - }); + a_thread_vec.template AsType(), + a_scale_thread_vec.template AsType(), + b_thread_vec.template AsType(), + b_scale_thread_vec.template AsType(), + c_thread_buf.GetVectorTypeReference(Number{})); }); }); }); + // Prefetch a_scales static_for<0, MRepeat, 1>{}([&](auto m0) { static_for<0, KRepeat, 1>{}([&](auto k0) { - static_for<0, xdlops_gemm.mfma_instr.num_groups_per_blk, 1>{}([&](auto g) { - auto a_scale_thread_buf_group = + static_for<0, ScalesPerXdlopsRunPerThread, 1>{}([&](auto s) { + constexpr auto a_scale_offset = + a_scale_thread_desc.CalculateOffset(make_tuple(m0, k0, s)); + auto a_scale_thread_buf_copy = make_static_buffer( - a_scale_thread_desc_group.GetElementSpaceSize()); - + a_scale_thread_desc_copy.GetElementSpaceSize()); a_scale_thread_copy.Run(a_scale_grid_desc, a_scale_grid_buf, - a_scale_thread_desc_group, + a_scale_thread_desc_copy, make_tuple(I0, I0), - a_scale_thread_buf_group); + a_scale_thread_buf_copy); - static_for<0, xdlops_gemm.mfma_instr.group_size, 1>{}([&](auto r) { - constexpr index_t a_scale_offset = - a_scale_thread_desc.CalculateOffset(make_tuple(m0, k0, g, r)); - a_scale_thread_buf(Number{}) = - a_scale_thread_buf_group[Number{}]; - }); - // go to the next group + a_scale_thread_buf(Number{}) = + a_scale_thread_buf_copy[Number<0>{}]; a_scale_thread_copy.MoveSrcSliceWindow( a_scale_grid_desc, - make_multi_index(2 * xdlops_gemm.mfma_instr.group_size, 0)); - }); // g - - // restore row id and advance to the next scale - a_scale_thread_copy.MoveSrcSliceWindow( - a_scale_grid_desc, - make_multi_index(-2 * xdlops_gemm.mfma_instr.group_size * - xdlops_gemm.mfma_instr.num_groups_per_blk, - 1)); - }); // k0 - - // restore 
column id and advance to the next set of rows + make_multi_index(0, xdlops_gemm.KPerXdlops / ScaleBlockSize)); + }); + }); a_scale_thread_copy.MoveSrcSliceWindow( a_scale_grid_desc, make_multi_index(MWaves * MPerXDL, -ScalesPerKBlockSize)); - }); // m0 + }); // restore row id and advance to the next set of scales a_scale_thread_copy.MoveSrcSliceWindow( a_scale_grid_desc, make_multi_index(-MPerBlock, ScalesPerKBlockSize)); + // Prefetch b_scales static_for<0, NRepeat, 1>{}([&](auto n0) { - b_scale_thread_copy.Run(b_scale_grid_desc, - b_scale_grid_buf, - b_scale_thread_desc, - make_tuple(n0, I0), - b_scale_thread_buf); - b_scale_thread_copy.MoveSrcSliceWindow(b_scale_grid_desc, - make_multi_index(NWaves * NPerXDL, 0)); + static_for<0, KRepeat, 1>{}([&](auto k0) { + static_for<0, ScalesPerXdlopsRunPerThread, 1>{}([&](auto s) { + constexpr auto b_scale_offset = + b_scale_thread_desc.CalculateOffset(make_tuple(n0, k0, s)); + auto b_scale_thread_buf_copy = + make_static_buffer( + b_scale_thread_desc_copy.GetElementSpaceSize()); + b_scale_thread_copy.Run(b_scale_grid_desc, + b_scale_grid_buf, + b_scale_thread_desc_copy, + make_tuple(I0, I0), + b_scale_thread_buf_copy); + + b_scale_thread_buf(Number{}) = + b_scale_thread_buf_copy[Number<0>{}]; + b_scale_thread_copy.MoveSrcSliceWindow( + b_scale_grid_desc, + make_multi_index(0, xdlops_gemm.KPerXdlops / ScaleBlockSize)); + }); + }); + b_scale_thread_copy.MoveSrcSliceWindow( + b_scale_grid_desc, + make_multi_index(NWaves * NPerXDL, -ScalesPerKBlockSize)); }); + + // restore col id and advance to the next set of scales // NWaves * NPerXDL * NRepeat == NPerBlock b_scale_thread_copy.MoveSrcSliceWindow( b_scale_grid_desc, make_multi_index(-NPerBlock, ScalesPerKBlockSize)); @@ -507,7 +491,6 @@ struct BlockwiseGemmXdlops_pipeline_v1_mx{}([&](auto k) { - constexpr auto a_k_step = k * AMmaKStride * KPack / xdlops_gemm.K1PerXdlops; - constexpr auto b_k_step = k * BMmaKStride * KPack / xdlops_gemm.K1PerXdlops; + constexpr auto k_step = 
+ k * xdlops_gemm.KPerXdlops * (KPack / xdlops_gemm.K1PerXdlops); static_for<0, MRepeat, 1>{}([&](auto m0) { - a_thread_copy_.Run(a_block_desc_m0_m1_m2_k, - make_tuple(m0, I0, I0, Number{}), - a_block_buf, - a_thread_desc_, - make_tuple(m0, I0, k, I0), - a_thread_buf); + // read block data in chunks to assemble correct thread + static_for<0, xdlops_gemm.K1PerXdlops / KThreadChunk, 1>{}([&](auto chunk) { + constexpr auto a_k_step_chunk = + k_step + chunk * KThreadChunk * xdlops_gemm.mfma_instr.num_input_blks; + a_thread_copy_.Run(a_block_desc_m0_m1_m2_k, + make_tuple(m0, I0, I0, Number{}), + a_block_buf, + a_thread_desc_, + make_tuple(m0, I0, k, Number{}), + a_thread_buf); + }); }); static_for<0, NRepeat, 1>{}([&](auto n0) { - b_thread_copy_.Run(b_block_desc_n0_n1_n2_k, - make_tuple(n0, I0, I0, Number{}), - b_block_buf, - b_thread_desc_, - make_tuple(n0, I0, k, I0), - b_thread_buf); + // read block data in chunks to assemble correct thread + static_for<0, xdlops_gemm.K1PerXdlops / KThreadChunk, 1>{}([&](auto chunk) { + constexpr auto b_k_step_chunk = + k_step + chunk * KThreadChunk * xdlops_gemm.mfma_instr.num_input_blks; + b_thread_copy_.Run(b_block_desc_n0_n1_n2_k, + make_tuple(n0, I0, I0, Number{}), + b_block_buf, + b_thread_desc_, + make_tuple(n0, I0, k, Number{}), + b_thread_buf); + }); }); }); static_for<0, MRepeat, 1>{}([&](auto m0) { static_for<0, NRepeat, 1>{}([&](auto n0) { static_for<0, KRepeat, 1>{}([&](auto k0) { - c_thread_buf_per_scale.Clear(); - vector_type a_thread_vec; - vector_type b_thread_vec; + vector_type a_thread_vec; + vector_type b_thread_vec; static_for<0, KPack, 1>{}([&](auto ik) { - a_thread_vec.template AsType()(ik) = + a_thread_vec.template AsType()(ik) = a_thread_buf[Number{}]; - b_thread_vec.template AsType()(ik) = + b_thread_vec.template AsType()(ik) = b_thread_buf[Number{}]; }); - using mfma_input_type = - typename vector_type::type; + constexpr index_t a_scale_offset = + a_scale_thread_desc.CalculateOffset(make_tuple(m0, k0, I0)); 
- xdlops_gemm.template Run<>( - a_thread_vec.template AsType(), - b_thread_vec.template AsType(), - c_thread_buf_per_scale.GetVectorTypeReference(I0)); - - // one scale per k0 constexpr index_t b_scale_offset = - b_scale_thread_desc.CalculateOffset(make_tuple(n0, k0)); + b_scale_thread_desc.CalculateOffset(make_tuple(n0, k0, I0)); - static_for<0, xdlops_gemm.mfma_instr.num_groups_per_blk, 1>{}([&](auto g) { - static_for<0, xdlops_gemm.mfma_instr.group_size, 1>{}([&](auto r) { - constexpr index_t a_scale_offset = - a_scale_thread_desc.CalculateOffset(make_tuple(m0, k0, g, r)); + vector_type a_scale_thread_vec; + vector_type b_scale_thread_vec; - constexpr auto reg_offset = - g * xdlops_gemm.mfma_instr.group_size + r; - - constexpr index_t c_offset = - c_thread_desc_.CalculateOffset(make_tuple(m0, n0, reg_offset)); - - c_thread_buf(Number{}) += - c_thread_buf_per_scale[Number{}] * - type_convert( - b_scale_thread_buf[Number{}]) * - type_convert( - a_scale_thread_buf[Number{}]); - }); + // Pack b_scale_thread_buf into b_scale_thread_vec + static_for<0, ScalesPerXdlopsRunPerThread, 1>{}([&](auto s) { + a_scale_thread_vec.template AsType()(s) = + a_scale_thread_buf[Number{}]; + b_scale_thread_vec.template AsType()(s) = + b_scale_thread_buf[Number{}]; }); + + using mfma_input_type_a = + typename vector_type::type; + using mfma_input_type_b = + typename vector_type::type; + + constexpr index_t c_offset = + c_thread_desc_.CalculateOffset(make_tuple(m0, n0, 0)); + + // MFMA accumulation + xdlops_gemm.template Run<>( + a_thread_vec.template AsType(), + a_scale_thread_vec.template AsType(), + b_thread_vec.template AsType(), + b_scale_thread_vec.template AsType(), + c_thread_buf.GetVectorTypeReference(Number{})); }); }); }); } } - // TODO: make this field protected when a_scale_thread_copy_ is moved here + // TODO: make this field protected when a_scale_thread_copy_ is moved + // here static constexpr auto a_scale_thread_desc = make_naive_tensor_descriptor_packed( - 
make_tuple(Number{}, - Number{}, - Number{}, - Number{})); + make_tuple(Number{}, Number{}, Number{})); // Is used to copy data from a_scale_grid to a_scale_thread - static constexpr auto a_scale_thread_desc_group = make_naive_tensor_descriptor_packed( - make_tuple(Number{}, Number<1>{})); + static constexpr auto a_scale_thread_desc_copy = + make_naive_tensor_descriptor_packed(make_tuple(Number<1>{}, Number<1>{})); - // TODO: make this field protected when b_scale_thread_copy_ is moved here - static constexpr auto b_scale_thread_desc = - make_naive_tensor_descriptor_packed(make_tuple(Number{}, Number{})); + // TODO: make this field protected when b_scale_thread_copy_ is moved + // here + static constexpr auto b_scale_thread_desc = make_naive_tensor_descriptor_packed( + make_tuple(Number{}, Number{}, Number{})); + + // Is used to copy data from b_scale_grid to b_scale_thread_buf + static constexpr auto b_scale_thread_desc_copy = + make_naive_tensor_descriptor_packed(make_tuple(Number<1>{}, Number<1>{})); protected: using Base::a_thread_copy_; diff --git a/include/ck/tensor_operation/gpu/device/impl/device_gemm_xdl_cshuffle_v3_mx.hpp b/include/ck/tensor_operation/gpu/device/impl/device_gemm_xdl_cshuffle_v3_mx.hpp index 34df9a1d7b..8a370304c6 100644 --- a/include/ck/tensor_operation/gpu/device/impl/device_gemm_xdl_cshuffle_v3_mx.hpp +++ b/include/ck/tensor_operation/gpu/device/impl/device_gemm_xdl_cshuffle_v3_mx.hpp @@ -694,14 +694,7 @@ struct DeviceGemmMX_Xdl_CShuffleV3 : public DeviceGemmMX || is_same_v || - is_same_v || is_same_v || - is_same_v)&&(is_same_v || - is_same_v || - is_same_v || - is_same_v || - is_same_v), + static_assert(is_scale_mfma_data_type() && is_scale_mfma_data_type(), "Only microscaling formats are supported for ADataType and BDataType"); static_assert(ScaleBlockSize == 32, "Only ScaleBlockSize 32 is supported"); @@ -711,6 +704,11 @@ struct DeviceGemmMX_Xdl_CShuffleV3 : public DeviceGemmMX{}; static constexpr auto BK1Number = Number{}; - static 
constexpr auto lcm_AK1_BK1 = math::lcm(AK1Number, BK1Number); - static constexpr bool is_single_rate_mfma = - ((is_same::value || is_same::value) && - lcm_AK1_BK1 <= 4) - ? true - : false; + static constexpr auto lcm_AK1_BK1 = math::lcm(AK1Number, BK1Number); + static constexpr bool is_single_rate_mfma = false; + static constexpr auto is_scale_mfma = true; + + //> KPack is at least the k_per_blk of selected mfma + // + // Should be a multiple of k_per_blk. + // TODO: Move this to blockwise pipeline base static constexpr index_t KPack = math::max(lcm_AK1_BK1, - MfmaSelector:: - selected_mfma.k_per_blk); + MfmaSelector::selected_mfma.k_per_blk); using ThisThreadBlock = ThisThreadBlock; @@ -1088,10 +1094,6 @@ struct GridwiseGemmMX_xdl_cshuffle_v3 static_assert(KPerBlock % ScaleBlockSize == 0, "KPerBlock should be multiple of ScaleBlockSize"); - static_assert(KPerBlock / ScaleBlockSize == BlockwiseGemmPipe::KRepeat, - "Single call to xdlops_gemm::Run should process exactly ScaleBlockSize " - "elements in k dimension"); - if constexpr(!(GemmSpec == tensor_operation::device::GemmSpecialization::MPadding || GemmSpec == tensor_operation::device::GemmSpecialization::MNPadding || GemmSpec == tensor_operation::device::GemmSpecialization::MKPadding || @@ -1476,61 +1478,63 @@ struct GridwiseGemmMX_xdl_cshuffle_v3 (a_grid_desc_ak0_m_ak1.GetLength(I0) * a_grid_desc_ak0_m_ak1.GetLength(I2)) / KPerBlock); - static constexpr auto mfma = BlockwiseGemmPipe::xdlops_gemm.mfma; - static constexpr auto KPerXdlops = mfma.GetKPerXdlops(); - static constexpr auto K1PerXdlops = mfma.GetK1PerXdlops(); - static constexpr auto K0PerXdlops = KPerXdlops / K1PerXdlops; - static constexpr auto KPerThread = KPerBlock / K0PerXdlops; - - // NXdlPerWave == NRepeat - // MXdlPerWave == MRepeat - constexpr index_t NWaves = NPerBlock / (NXdlPerWave * NPerXdl); - constexpr index_t MWaves = MPerBlock / (MXdlPerWave * MPerXdl); - - // Initial thread mapping for MPerXdl=NPerXdl=32 and MPerBlock=NPerBlock=128 
MWaves=NWaves=2 + // Initial thread mapping for: + // BlockSize = 256 + // MPerXdl=NPerXdl=32 and MPerBlock=NPerBlock=128 MRepeat=NRepeat=2 MWaves=NWaves=2 + // For each [m0, n0] tile, there are 4 waves: // tId in [ 0, 63] m x n = [ 0, 31] x [ 0, 31] waveId = [0, 0] // tId in [ 64, 127] m x n = [ 0, 31] x [32, 63] waveId = [0, 1] // tId in [128, 191] m x n = [32, 63] x [ 0, 31] waveId = [1, 0] // tId in [192, 255] m x n = [32, 63] x [32, 63] waveId = [1, 1] - auto a_thread_offset_m = - MPerXdl * ((get_thread_local_1d_id() / BlockwiseGemmPipe::WaveSize) / MWaves) + - mfma.selected_mfma.group_size * - ((get_thread_local_1d_id() % BlockwiseGemmPipe::WaveSize) / MPerXdl); - auto a_thread_offset_k = KPerThread * (get_thread_local_1d_id() % MPerXdl) / MPerXdl; + // BlockSize = 128 + // MPerXdl=NPerXdl=16 and MPerBlock=128 NPerBlock=16 MRepeat=4 NRepeat=1 MWaves=2 NWaves=1 + // For each [m0, n0] tile, there are 2 waves: + // tId in [ 0, 63] m x n = [ 0, 15] x [0, 15] waveId = [0, 0] + // tId in [ 64, 127] m x n = [16, 31] x [0, 15] waveId = [1, 0] - auto b_thread_offset_n = - get_thread_local_1d_id() % NPerXdl + - (get_thread_local_1d_id() / BlockwiseGemmPipe::WaveSize) % NWaves * NPerXdl; - auto b_thread_offset_k = KPerThread * (get_thread_local_1d_id() % NPerXdl) / NPerXdl; + // TODO: Document initial thread mapping for more combinations of parameters - auto a_scale_thread_copy = ThreadwiseTensorSliceTransfer_v2< - AScaleDataType, - AScaleDataType, - decltype(a_scale_grid_desc_am_ak), // SrcDesc - decltype(BlockwiseGemmPipe::a_scale_thread_desc_group), // DstDesc - Sequence, // SliceLengths - Sequence<0, 1>, // DimAccessOrder - 0, // SrcVectorDim - 1, // SrcScalarPerVector - 1, // SrcScalarStrideInVector - true>(a_scale_grid_desc_am_ak, - make_multi_index(block_m_id * MPerBlock + a_thread_offset_m, - a_thread_offset_k / ScaleBlockSize)); + const auto wave_idx = BlockwiseGemmPipe::GetWaveIdx(); + const auto waveId_m = wave_idx[I0]; + const auto waveId_n = wave_idx[I1]; - 
auto b_scale_thread_copy = ThreadwiseTensorSliceTransfer_v2< - BScaleDataType, - BScaleDataType, - decltype(b_scale_grid_desc_bn_ak), - decltype(BlockwiseGemmPipe::b_scale_thread_desc), - Sequence<1, BlockwiseGemmPipe::KRepeat>, // SliceLengths - Sequence<0, 1>, // DimAccessOrder - 1, // SrcVectorDim - BlockwiseGemmPipe::KRepeat, // SrcScalarPerVector - 1, - false>(b_scale_grid_desc_bn_ak, - make_multi_index(block_n_id * NPerBlock + b_thread_offset_n, - b_thread_offset_k / ScaleBlockSize)); + static constexpr auto mfma = BlockwiseGemmPipe::xdlops_gemm.mfma; + + auto thread_offset_k = (get_thread_local_1d_id() % BlockwiseGemmPipe::WaveSize) / + mfma.selected_mfma.num_threads_per_blk; + + auto a_thread_offset_m = get_thread_local_1d_id() % MPerXdl + waveId_m * MPerXdl; + + auto a_scale_thread_copy = + ThreadwiseTensorSliceTransfer_v2, // SliceLengths + Sequence<0, 1>, // DimAccessOrder + 1, // SrcVectorDim + 1, // SrcScalarPerVector + 1, // SrcScalarStrideInVector + true>( + a_scale_grid_desc_am_ak, + make_multi_index(block_m_id * MPerBlock + a_thread_offset_m, thread_offset_k)); + + auto b_thread_offset_n = get_thread_local_1d_id() % NPerXdl + waveId_n * NPerXdl; + + auto b_scale_thread_copy = + ThreadwiseTensorSliceTransfer_v2, // SliceLengths + Sequence<0, 1>, // DimAccessOrder + 1, // SrcVectorDim + 1, // SrcScalarPerVector + 1, + true>( + b_scale_grid_desc_bn_ak, + make_multi_index(block_n_id * NPerBlock + b_thread_offset_n, thread_offset_k)); blockwise_gemm_pipeline.template Run(a_grid_desc_ak0_m_ak1, a_block_desc_ak0_m_ak1, diff --git a/include/ck/tensor_operation/gpu/thread/threadwise_tensor_slice_transfer.hpp b/include/ck/tensor_operation/gpu/thread/threadwise_tensor_slice_transfer.hpp index 0310fe37a0..2255505985 100644 --- a/include/ck/tensor_operation/gpu/thread/threadwise_tensor_slice_transfer.hpp +++ b/include/ck/tensor_operation/gpu/thread/threadwise_tensor_slice_transfer.hpp @@ -211,8 +211,7 @@ struct ThreadwiseTensorSliceTransfer_v1r3 * @tparam 
SrcVectorDim The dimension along which vectorized access is performed in the source * tensor. * @tparam SrcScalarPerVector The number of scalar elements per vector in the source tensor. - * @tparam SrcScalarStrideInVector The stride of scalar elements within a vector in the source - * tensor. + * @tparam SrcScalarStrideInVector Not used. * @tparam SrcResetCoordinateAfterRun controls whether source coordinate is restored after each Run * or rolled back one step in MoveSrcSliceWindow * @tparam InvalidElementAsNaN Whether to fill invalid elements with NaN (only applicable for diff --git a/include/ck/tensor_operation/gpu/warp/xdlops_gemm.hpp b/include/ck/tensor_operation/gpu/warp/xdlops_gemm.hpp index a638ca8608..529a1a1729 100644 --- a/include/ck/tensor_operation/gpu/warp/xdlops_gemm.hpp +++ b/include/ck/tensor_operation/gpu/warp/xdlops_gemm.hpp @@ -845,15 +845,24 @@ struct mfma_type static constexpr bool is_k_reduction = true; // ??? // clang-format on - template + template __device__ void run(const FloatA& a, - const int32_t scale_a, + const ScaleA& scale_a, const FloatB& b, - const int32_t scale_b, + const ScaleB& scale_b, FloatC& reg_c) const { + static_assert(scalar_type::vector_size == 1, "Expect single scale at this point."); + static_assert(scalar_type::vector_size == 1, "Expect single scale at this point."); + intrin_mfma_scale_f32_32x32x64f8f6f4::Run( - a, scale_a, b, scale_b, reg_c); + a, utils::get_exponent_value(scale_a), b, utils::get_exponent_value(scale_b), reg_c); } }; @@ -874,15 +883,24 @@ struct mfma_type static constexpr bool is_k_reduction = true; // ??? 
// clang-format on - template + template __device__ void run(const FloatA& a, - const int32_t scale_a, + const ScaleA& scale_a, const FloatB& b, - const int32_t scale_b, + const ScaleB& scale_b, FloatC& reg_c) const { + static_assert(scalar_type::vector_size == 1, "Expect single scale at this point."); + static_assert(scalar_type::vector_size == 1, "Expect single scale at this point."); + intrin_mfma_scale_f32_16x16x128f8f6f4::Run( - a, scale_a, b, scale_b, reg_c); + a, utils::get_exponent_value(scale_a), b, utils::get_exponent_value(scale_b), reg_c); } }; @@ -890,14 +908,16 @@ template + bool is_single_rate_mfma = false, + bool is_scale_mfma = false> struct MfmaSelector { template + bool is_single_rate_mfma_ = false, + bool is_scale_mfma_ = false> static constexpr auto GetMfma(); template <> @@ -1103,12 +1123,24 @@ struct MfmaSelector return MfmaInstr::mfma_f32_32x32x16f8f8; } + template <> + constexpr auto GetMfma() + { + return MfmaInstr::mfma_scale_f32_32x32x64f8f6f4; + } + template <> constexpr auto GetMfma() { return MfmaInstr::mfma_f32_16x16x32f8f8; } + template <> + constexpr auto GetMfma() + { + return MfmaInstr::mfma_scale_f32_16x16x128f8f6f4; + } + template <> constexpr auto GetMfma() { @@ -1145,8 +1177,12 @@ struct MfmaSelector return MfmaInstr::mfma_f32_16x16x32bf8f8; } - static constexpr auto selected_mfma = mfma_type< - GetMfma()>{}; + static constexpr auto selected_mfma = mfma_type()>{}; __host__ __device__ constexpr MfmaSelector() { @@ -1194,7 +1230,8 @@ template + bool TransposeC = false, + bool is_scale_mfma = false> struct XdlopsGemm { static constexpr auto I0 = Number<0>{}; @@ -1225,7 +1262,7 @@ struct XdlopsGemm MPerXdlops == 64, "Only support GemmMPerXdlops == 4, 8, 16, 32 or 64 for xdlops"); - static_assert(KPack % mfma_instr.k_per_blk == 0, "KPack cannot be divided by k_per_blk"); + static_assert(KPack % mfma_instr.k_per_blk == 0, "KPack should be a multiple of k_per_blk"); } // XDL output supporting C = A * B @@ -1368,6 +1405,27 @@ struct 
XdlopsGemm }); } + template + __device__ void Run(const FloatA& p_a_wave, + const ScaleA& a_scale_thread, + const FloatB& p_b_wave, + const ScaleB& b_scale_thread, + FloatC& p_c_thread) const + { + static_for<0, KPack / mfma_instr.k_per_blk, 1>{}([&](auto k) { + if constexpr(!TransposeC) + { + mfma_instr.template run( + p_a_wave[k], a_scale_thread[k], p_b_wave[k], b_scale_thread[k], p_c_thread); + } + else + { + mfma_instr.template run( + p_b_wave[k], b_scale_thread[k], p_a_wave[k], a_scale_thread[k], p_c_thread); + } + }); + } + __device__ static auto GetLaneId() { return get_thread_local_1d_id() % mfma_instr.wave_size; } __device__ static auto GetBlkIdx() @@ -1455,7 +1513,8 @@ struct XdlopsGemm KPack <= 4) || (is_same::value && KPack <= 8)) ? true - : false > {}; + : false, + is_scale_mfma > {}; static constexpr auto mfma_instr = mfma.selected_mfma; diff --git a/include/ck/utility/amd_xdlops.hpp b/include/ck/utility/amd_xdlops.hpp index 0d4611becc..a54a181bf1 100644 --- a/include/ck/utility/amd_xdlops.hpp +++ b/include/ck/utility/amd_xdlops.hpp @@ -520,9 +520,9 @@ struct intrin_mfma_scale_f32_32x32x64f8f6f4<32, 32> { template __device__ static void Run(const f8x32_t& reg_a, - const int32_t scale_a, + const int32_t& scale_a, const f8x32_t& reg_b, - const int32_t scale_b, + const int32_t& scale_b, FloatC& reg_c) { #if defined(__gfx950__) @@ -538,6 +538,14 @@ struct intrin_mfma_scale_f32_32x32x64f8f6f4<32, 32> scale_a, 0, // OPSEL scale_b); + // XXX: Note on the scale_a and scale_b parameters: + // If compiler detects that one or both scales are constant values, it will treat that + // constant as F32 constant. I.e., if scale_a at some point was declared as + // `e8m0_bexp_t a_scale{1.0f}`, the instruction would only work if scale_a parameter is + // assigned value `bit_cast(static_cast(a_scale))`. + + // XXX: Note on the OPSEL parameters: Instruction always takes byte0 as a scale value even + // when OPSEL is set otherwise. 
#else ignore = reg_a; ignore = scale_a; @@ -556,9 +564,9 @@ struct intrin_mfma_scale_f32_16x16x128f8f6f4<16, 16> { template __device__ static void Run(const f8x32_t& reg_a, - const int32_t scale_a, + const int32_t& scale_a, const f8x32_t& reg_b, - const int32_t scale_b, + const int32_t& scale_b, FloatC& reg_c) { #if defined(__gfx950__) diff --git a/include/ck/utility/e8m0.hpp b/include/ck/utility/e8m0.hpp index a692f533f8..f7d2a2f594 100644 --- a/include/ck/utility/e8m0.hpp +++ b/include/ck/utility/e8m0.hpp @@ -67,10 +67,10 @@ struct e8m0_bexp_t namespace utils { template -__host__ __device__ inline int get_exponent_value(T x); +__host__ __device__ inline constexpr int32_t get_exponent_value(T x); template <> -__host__ __device__ inline int get_exponent_value(e8m0_bexp_t x) +__host__ __device__ inline constexpr int32_t get_exponent_value(e8m0_bexp_t x) { return x.data; } diff --git a/include/ck/utility/mxfp_utils.hpp b/include/ck/utility/mxfp_utils.hpp index f0a86f8750..cf7a3e8713 100644 --- a/include/ck/utility/mxfp_utils.hpp +++ b/include/ck/utility/mxfp_utils.hpp @@ -32,13 +32,13 @@ template __host__ __device__ inline bool is_inf(e8m0_bexp_t const scale, T const data); template -__host__ __device__ inline int get_exponent_value(T x) +__host__ __device__ inline constexpr int32_t get_exponent_value(T x) { x >>= NumericUtils::mant; x &= ((1 << NumericUtils::exp) - 1); - return static_cast(x); + return static_cast(x); } template diff --git a/test/mx_mfma_op/mx_mfma_op.hpp b/test/mx_mfma_op/mx_mfma_op.hpp index 1f9091ebc5..d22157c3b3 100644 --- a/test/mx_mfma_op/mx_mfma_op.hpp +++ b/test/mx_mfma_op/mx_mfma_op.hpp @@ -30,48 +30,69 @@ enum class MFMA_F8F6F4 }; -template +template struct mfma_type_selector; -template -struct mfma_type_selector +template <> +struct mfma_type_selector<16, 16> { - __device__ void operator()(AFragT const& fragA, BFragT const& fragB, AccumFragT& fragAcc) + template + __device__ static void run(AFragT const& fragA, BFragT const& fragB, 
AccumFragT& fragAcc) { auto op = mfma_type{}; - op.template run<16, 16, AFragT, BFragT, AccumFragT>(fragA, fragB, fragAcc); - } - - __device__ void operator()(AFragT const& fragA, - const int32_t scale_a, - BFragT const& fragB, - const int32_t scale_b, - AccumFragT& fragAcc) - { - auto op = mfma_type{}; - op.template run<16, 16, AFragT, BFragT, AccumFragT>( - fragA, scale_a, fragB, scale_b, fragAcc); + op.template run<16, 16>(fragA, fragB, fragAcc); } }; -template -struct mfma_type_selector +template <> +struct mfma_type_selector<32, 32> { - __device__ void operator()(AFragT const& fragA, BFragT const& fragB, AccumFragT& fragAcc) + template + __device__ static void run(AFragT const& fragA, BFragT const& fragB, AccumFragT& fragAcc) { auto op = mfma_type{}; - op.template run<32, 32, AFragT, BFragT, AccumFragT>(fragA, fragB, fragAcc); + op.template run<32, 32>(fragA, fragB, fragAcc); } +}; - __device__ void operator()(AFragT const& fragA, - const int32_t scale_a, +template +struct mfma_scale_type_selector; + +template <> +struct mfma_scale_type_selector<16, 16> +{ + template + __device__ static void run(AFragT const& fragA, + AScaleFragT const& scale_a, BFragT const& fragB, - const int32_t scale_b, + BScaleFragT const& scale_b, + AccumFragT& fragAcc) + { + auto op = mfma_type{}; + op.template run<16, 16>(fragA, scale_a[Number<0>{}], fragB, scale_b[Number<0>{}], fragAcc); + } +}; + +template <> +struct mfma_scale_type_selector<32, 32> +{ + template + __device__ static void run(AFragT const& fragA, + AScaleFragT const& scale_a, + BFragT const& fragB, + BScaleFragT const& scale_b, AccumFragT& fragAcc) { auto op = mfma_type{}; - op.template run<32, 32, AFragT, BFragT, AccumFragT>( - fragA, scale_a, fragB, scale_b, fragAcc); + op.template run<32, 32>(fragA, scale_a[Number<0>{}], fragB, scale_b[Number<0>{}], fragAcc); } }; @@ -334,8 +355,7 @@ __device__ AFragT load_mx_A_row_major(AType const* input_ptr, // BLOCK_K / BLOCK_X is a stride in xA matrix auto startOffset = 
row_major(startCoord2D, BLOCK_K / BLOCK_X); - // obtain 8-bit exponent - fragX = utils::get_exponent_value(scale_ptr[startOffset]) & 0xFF; + fragX = scale_ptr[startOffset]; return load_A_row_major(input_ptr); } @@ -502,7 +522,7 @@ __device__ BFragT load_mx_B_col_major(BType const* input_ptr, auto startOffset = col_major(startCoord2D, BLOCK_K / BLOCK_X); // obtain 8-bit exponent - fragX = utils::get_exponent_value(scale_ptr[startOffset]) & 0xFF; + fragX = scale_ptr[startOffset]; return load_B_col_major(input_ptr); } @@ -773,7 +793,8 @@ __global__ void matmul(const AType* a, const BType* b, CType* c) // Matrix multiply-accumulate using MFMA units // Accumulation intermediate = BLOCK_M x BLOCK_N - mfma_type_selector{}(fragA, fragB, fragAcc); + using mfma = mfma_type_selector; + mfma::template run<>(fragA, fragB, fragAcc); for(int i = 0; i < vectorSize(fragC); ++i) { @@ -805,29 +826,34 @@ matmul(const AType* a, const ScaleType* xa, const BType* b, const ScaleType* xb, using CFragT = vector_type::type; using AccumFragT = vector_type; using RawAccumFragT = vector_type::type; - using ScaleFragT = int32_t; + using AScaleFragT = vector_type::type; + using BScaleFragT = vector_type::type; // Create frags auto fragA = AFragT{}; auto fragB = BFragT{}; auto fragC = CFragT{}; auto fragAcc = AccumFragT{0}; - auto fragXa = ScaleFragT{0}; - auto fragXb = ScaleFragT{0}; + auto fragXa = AScaleFragT{}; + auto fragXb = BScaleFragT{}; // Load the inputs. 
// A = col major, BLOCK_M x BLOCK_K - fragA = load_mx_A_row_major( + fragA = load_mx_A_row_major( a, xa, fragXa); // B = col major, BLOCK_K x BLOCK_N - fragB = load_mx_B_col_major( + fragB = load_mx_B_col_major( b, xb, fragXb); // Scaled Matrix multiply-accumulate using MFMA units // Accumulation intermediate = BLOCK_M x BLOCK_N - mfma_type_selector{}( - fragA, fragXa, fragB, fragXb, fragAcc); + using mfma = mfma_scale_type_selector; + mfma::template run<>(fragA, + fragXa.template AsType(), + fragB, + fragXb.template AsType(), + fragAcc); for(int i = 0; i < vectorSize(fragC); ++i) { From 94d47b1680eaafacca142f2498fb94d08a5b66d3 Mon Sep 17 00:00:00 2001 From: joyeamd Date: Wed, 16 Apr 2025 09:21:04 +0800 Subject: [PATCH 045/443] fmha hdim256 vectorize improve (#2086) For hdim 256, will not have vectorized buffer load when seqlen % 256 != 0 and hdim % 256 = 0; this commit tries to solve this condition. --- example/ck_tile/01_fmha/codegen/ops/fmha_fwd.py | 3 +++ 1 file changed, 3 insertions(+) diff --git a/example/ck_tile/01_fmha/codegen/ops/fmha_fwd.py b/example/ck_tile/01_fmha/codegen/ops/fmha_fwd.py index 10a6e5c1d7..3634810b37 100644 --- a/example/ck_tile/01_fmha/codegen/ops/fmha_fwd.py +++ b/example/ck_tile/01_fmha/codegen/ops/fmha_fwd.py @@ -445,6 +445,9 @@ def get_fwd_blobs(kernel_filter : Optional[str], receipt, mask_impl) -> Tuple[Fm # if True: pipelines.append(FmhaFwdPipeline('qr', 'row', 'f', 'f', 'f', 'f', bias, lse, dropout, squant, mask)) pipelines.append(FmhaFwdPipeline('qr', 'col', 'f', 'f', 'f', 'f', bias, lse, dropout, squant, mask)) + # the below two is used for hdim vectorize load + pipelines.append(FmhaFwdPipeline('qr', 'row', 't', 't', 'f', 'f', bias, lse, dropout, squant, mask)) + pipelines.append(FmhaFwdPipeline('qr', 'col', 't', 't', 'f', 'f', bias, lse, dropout, squant, mask)) pipelines.append(FmhaFwdPipeline('qr', 'row', 't', 't', 't', 't', bias, lse, dropout, squant, mask)) pipelines.append(FmhaFwdPipeline('qr', 'col', 't', 't', 't', 't', 
bias, lse, dropout, squant, mask)) From c5975529bb016318ae135431d61761b885f0f5b9 Mon Sep 17 00:00:00 2001 From: felix Date: Wed, 16 Apr 2025 10:53:21 +0800 Subject: [PATCH 046/443] add preshuffle gemm fp16 (#2036) * add preshuffle gemm fp16 * clang format and test ok * Update gemm_multiply_multiply_xdl_fp16_bpreshuffle.cpp remove useless comments in example * Update gemm_multiply_multiply_xdl_fp16_bpreshuffle.cpp remove 2 --------- Co-authored-by: coderfeli --- .../65_gemm_multiply_multiply/CMakeLists.txt | 1 + ...multiply_multiply_xdl_fp16_bpreshuffle.cpp | 371 ++++++++++++++++++ 2 files changed, 372 insertions(+) create mode 100644 example/65_gemm_multiply_multiply/gemm_multiply_multiply_xdl_fp16_bpreshuffle.cpp diff --git a/example/65_gemm_multiply_multiply/CMakeLists.txt b/example/65_gemm_multiply_multiply/CMakeLists.txt index 95fd8bace8..deca85ae64 100644 --- a/example/65_gemm_multiply_multiply/CMakeLists.txt +++ b/example/65_gemm_multiply_multiply/CMakeLists.txt @@ -1,6 +1,7 @@ add_example_executable(example_gemm_multiply_multiply_xdl_fp8 gemm_multiply_multiply_xdl_fp8.cpp) add_example_executable(example_gemm_multiply_multiply_xdl_fp8_ab_scale gemm_multiply_multiply_xdl_fp8_ab_scale.cpp) add_example_executable(example_gemm_multiply_multiply_xdl_fp8_bpreshuffle gemm_multiply_multiply_xdl_fp8_bpreshuffle.cpp) +add_example_executable(example_gemm_multiply_multiply_xdl_fp16_bpreshuffle gemm_multiply_multiply_xdl_fp16_bpreshuffle.cpp) add_example_executable(example_gemm_add_add_xdl_fp16 gemm_add_add_xdl_fp16.cpp) add_example_executable(example_gemm_multiply_multiply_xdl_int8 gemm_multiply_multiply_xdl_int8.cpp) add_example_executable(example_moe_gemm1_xdl_fp8 moe_gemm1_xdl_fp8.cpp) diff --git a/example/65_gemm_multiply_multiply/gemm_multiply_multiply_xdl_fp16_bpreshuffle.cpp b/example/65_gemm_multiply_multiply/gemm_multiply_multiply_xdl_fp16_bpreshuffle.cpp new file mode 100644 index 0000000000..69803c7eeb --- /dev/null +++ 
b/example/65_gemm_multiply_multiply/gemm_multiply_multiply_xdl_fp16_bpreshuffle.cpp @@ -0,0 +1,371 @@ +// SPDX-License-Identifier: MIT +// Copyright (c) 2025, Advanced Micro Devices, Inc. All rights reserved. + +#include +#include +#include +#include + +#include "ck/ck.hpp" +#include "ck/tensor_operation/gpu/device/gemm_specialization.hpp" +#include "ck/tensor_operation/gpu/device/impl/device_gemm_multiple_d_xdl_cshuffle_v3_b_preshuffle.hpp" +#include "ck/tensor_operation/gpu/device/impl/device_gemm_multiple_d_xdl_cshuffle_v3.hpp" +#include "ck/tensor_operation/gpu/element/element_wise_operation.hpp" +#include "ck/tensor_operation/gpu/element/unary_element_wise_operation.hpp" + +#include "ck/library/utility/device_memory.hpp" +#include "ck/library/utility/host_tensor.hpp" +#include "ck/library/utility/host_tensor_generator.hpp" +#include "ck/library/utility/literals.hpp" +#include "ck/library/reference_tensor_operation/cpu/reference_gemm.hpp" +#include "ck/library/utility/check_err.hpp" + +#include "ck/utility/blkgemmpipe_scheduler.hpp" + +template +using S = ck::Sequence; + +using F16 = ck::half_t; +using BF16 = ck::bhalf_t; +using F32 = float; + +using Row = ck::tensor_layout::gemm::RowMajor; +using Col = ck::tensor_layout::gemm::ColumnMajor; + +using A0DataType = F16; +using B0DataType = F16; +using AccDataType = F32; +using CShuffleDataType = F32; +using D0DataType = F32; +using D1DataType = F32; +using DsDataType = ck::Tuple; +using EDataType = F16; + +using A0Layout = Row; +using B0Layout = Col; +using D0Layout = Row; +using D1Layout = Col; +using DsLayout = ck::Tuple; +using ELayout = Row; + +struct MultiplyMultiply +{ + template + __host__ __device__ constexpr void + operator()(E& e, const C& c, const D0& d0, const D1& d1) const; + + template <> + __host__ __device__ constexpr void operator()(F16& e, + const float& c, + const float& d0, + const float& d1) const + { + const float x0_f = c * d0 * d1; + + e = ck::type_convert(x0_f); + } + + template <> + 
__host__ __device__ constexpr void operator()(BF16& e, + const float& c, + const float& d0, + const float& d1) const + { + const float x0_f = c * d0 * d1; + + e = ck::type_convert(x0_f); + } + + template <> + __host__ __device__ constexpr void operator()( + ck::half_t& e, const int& c, const float& d0, const float& d1) const + { + const float x0_f = + ck::type_convert(c) * ck::type_convert(d0) * ck::type_convert(d1); + + e = ck::type_convert(x0_f); + } + + template <> + __host__ __device__ constexpr void operator()( + ck::bhalf_t& e, const int& c, const float& d0, const float& d1) const + { + const float x0_f = + ck::type_convert(c) * ck::type_convert(d0) * ck::type_convert(d1); + + e = ck::type_convert(x0_f); + } +}; + +void preShuffleBuffer(const F16* src, F16* dst, int N, int K, int NXdl) +{ + int KPack = 16 / sizeof(F16); + int NLane = NXdl; + int KLane = 64 / NLane; + + int K0 = K / (KLane * KPack); + // K -> K0 KLane KPack + // N -> N0 NLane + // N, K -> N0 K0 KLane NLane KPack + int tempk; + for(int n = 0; n < N; ++n) + { + for(int k = 0; k < K; ++k) + { + int n0 = n / NLane; + int n1 = n % NLane; + + int k0 = k / (KLane * KPack); + tempk = k % (KLane * KPack); + int k1 = tempk / KPack; + int k2 = tempk % KPack; + + int outputIndex = n0 * KPack * NLane * KLane * K0 + k0 * KPack * NLane * KLane + + k1 * KPack * NLane + n1 * KPack + k2; + + dst[outputIndex] = src[n * K + k]; + } + } +} +using PassThrough = ck::tensor_operation::element_wise::PassThrough; + +using AElementOp = PassThrough; +using BElementOp = PassThrough; +using CDEElementOp = MultiplyMultiply; + +static constexpr auto GemmSpec = ck::tensor_operation::device::GemmSpecialization::Default; + +// using DeviceOpInstance = ck::tensor_operation::device::DeviceGemmMultiD_Xdl_CShuffle_V3 +using DeviceOpInstance = ck::tensor_operation::device::DeviceGemmMultiD_Xdl_CShuffle_V3_BPreshuffle + // clang-format off +///######| ALayout| BLayout| DsLayout| ELayout| AData| BData| DsData| EData| AccData| 
CShuffle| A| B| CDE| GEMM| Block| MPer| NPer| KPer| AK1| BK1| MPer| NPer| MXdl| NXdl| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockLds| BBlockTransfer| BBlockTransfer| BBlockTransfer| BlockTransfer| BBlockTransfer| BBlockTransfer| BBlockLds| CShuffle| CShuffle| CBlockTransferClusterLengths| CBlockTransfer| +///######| | | | | Type| Type| Type| Type| Type| DataType| Elementwise| Elementwise| Elementwise| Spacialization| Size| Block| Block| Block| | | XDL| XDL| Per| Per| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraM| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraN| MXdlPerWave| NXdlPerWave| _MBlock_MWaveMPerXdl| ScalarPerVector| +///######| | | | | | | | | | | Operation| Operation| Operation| | | | | | | | | | Wave| Wave| Lengths_K0_M_K1| ArrangeOrder| | | PerVector| PerVector_K1| | Lengths_K0_N_K1| ArrangeOrder| | | PerVector| PerVector_K1| | PerShuffle| PerShuffle| _NBlock_NWaveNPerXdl| _NWaveNPerXdl| +///######| | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | S| +///###### RCR + // kernel 1: 256->32x128x128 + < Row, Col, DsLayout, ELayout, A0DataType, B0DataType, DsDataType, EDataType, AccDataType, CShuffleDataType, + AElementOp, BElementOp, CDEElementOp, GemmSpec, 256, + 32, 128, 128, + 8, 8, + 32, 32, + 1, 1, + S<16, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 0, + S<16, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 0, + 1, 1, S<1, 16, 1, 16>, S<8, 8, 1>, + ck::BlockGemmPipelineScheduler::Intrawave, ck::BlockGemmPipelineVersion::v1, F16>; +// clang-format on + +int main(int argc, char* argv[]) +{ + bool do_verification = true; + int init_method = 1; + bool time_kernel = false; + + // GEMM shape + ck::index_t M = 3840; + ck::index_t N = 4096; + ck::index_t K = 4096; + + ck::index_t StrideA = K; + ck::index_t StrideB = K; + ck::index_t StrideD = 0; + ck::index_t StrideE = N; + + ck::index_t 
KBatch = 1; + + if(argc == 1) + { + // use default case + } + else if(argc == 4) + { + do_verification = std::stoi(argv[1]); + init_method = std::stoi(argv[2]); + time_kernel = std::stoi(argv[3]); + } + else if(argc == 12) + { + do_verification = std::stoi(argv[1]); + init_method = std::stoi(argv[2]); + time_kernel = std::stoi(argv[3]); + + M = std::stoi(argv[4]); + N = std::stoi(argv[5]); + K = std::stoi(argv[6]); + + StrideA = std::stoi(argv[7]); + StrideB = std::stoi(argv[8]); + StrideD = std::stoi(argv[9]); + StrideE = std::stoi(argv[10]); + + KBatch = std::stoi(argv[11]); + } + else + { + printf("arg1: verification (0=no, 1=yes)\n"); + printf("arg2: initialization (0=no init, 1=integer value, 2=decimal value)\n"); + printf("arg3: time kernel (0=no, 1=yes)\n"); + printf( + "arg4 to 9: M (256x), N(128x), K(32x), StrideA, StrideB, StrideD, StrideE, KBatch\n"); + exit(0); + } + + auto f_host_tensor_descriptor = + [](std::size_t row, std::size_t col, std::size_t stride, auto layout) { + using namespace ck::literals; + + if(std::is_same::value) + { + return HostTensorDescriptor({row, col}, {stride, 1_uz}); + } + else + { + return HostTensorDescriptor({row, col}, {1_uz, stride}); + } + }; + + Tensor a0_m_k(f_host_tensor_descriptor(M, K, StrideA, A0Layout{})); + Tensor b0_k_n(f_host_tensor_descriptor(K, N, StrideB, B0Layout{})); + Tensor b0_preshuffled( + f_host_tensor_descriptor(K, N, StrideB, B0Layout{})); // use laout only for size + Tensor d0_m_n(f_host_tensor_descriptor(M, N, StrideD, D0Layout{})); + Tensor d1_m_n(f_host_tensor_descriptor(M, N, StrideD, D1Layout{})); + Tensor e_m_n_host_result(f_host_tensor_descriptor(M, N, StrideE, ELayout{})); + Tensor e_m_n_device_result(f_host_tensor_descriptor(M, N, StrideE, ELayout{})); + + std::cout << "a0_m_k: " << a0_m_k.mDesc << std::endl; + std::cout << "b0_k_n: " << b0_k_n.mDesc << std::endl; + std::cout << "d1_m_n: " << d1_m_n.mDesc << std::endl; + std::cout << "d0_m_n: " << d0_m_n.mDesc << std::endl; + std::cout << 
"e_m_n: " << e_m_n_host_result.mDesc << std::endl; + + switch(init_method) + { + case 0: break; + case 1: + a0_m_k.GenerateTensorValue(GeneratorTensor_2{-2, 2}); + b0_k_n.GenerateTensorValue(GeneratorTensor_2{0, 2}); + d0_m_n.GenerateTensorValue(GeneratorTensor_2{-2, 2}); + d1_m_n.GenerateTensorValue(GeneratorTensor_2{-2, 2}); + break; + case 2: + a0_m_k.GenerateTensorValue(GeneratorTensor_1{}); + b0_k_n.GenerateTensorValue(GeneratorTensor_1{}); + d0_m_n.GenerateTensorValue(GeneratorTensor_1{}); + d1_m_n.GenerateTensorValue(GeneratorTensor_1{}); + break; + default: + a0_m_k.GenerateTensorValue(GeneratorTensor_3{0.0, 1.0}); + b0_k_n.GenerateTensorValue(GeneratorTensor_3{-0.5, 0.5}); + d0_m_n.GenerateTensorValue(GeneratorTensor_3{0.0, 1.0}); + d1_m_n.GenerateTensorValue(GeneratorTensor_3{0.0, 1.0}); + } + DeviceMem a0_device_buf(sizeof(A0DataType) * a0_m_k.mDesc.GetElementSpaceSize()); + DeviceMem b0_device_buf(sizeof(B0DataType) * b0_k_n.mDesc.GetElementSpaceSize()); + DeviceMem d0_device_buf(sizeof(D0DataType) * d0_m_n.mDesc.GetElementSpaceSize()); + DeviceMem d1_device_buf(sizeof(D1DataType) * d1_m_n.mDesc.GetElementSpaceSize()); + DeviceMem e_device_buf(sizeof(EDataType) * e_m_n_device_result.mDesc.GetElementSpaceSize()); + + a0_device_buf.ToDevice(a0_m_k.mData.data()); + d0_device_buf.ToDevice(d0_m_n.mData.data()); + d1_device_buf.ToDevice(d1_m_n.mData.data()); + e_device_buf.ToDevice(e_m_n_device_result.mData.data()); + + auto a_element_op = AElementOp{}; + auto b_element_op = BElementOp{}; + auto cde_element_op = CDEElementOp{}; + + constexpr ck::index_t NumDTensor = DsDataType::Size(); + + constexpr auto I0 = ck::Number<0>{}; + + // do GEMM + auto device_op = DeviceOpInstance{}; + + int NPerXdl = device_op.GetPreShuffleParameters(); + + preShuffleBuffer(b0_k_n.mData.data(), b0_preshuffled.mData.data(), N, K, NPerXdl); + + b0_device_buf.ToDevice(b0_preshuffled.mData.data()); + + auto invoker = device_op.MakeInvoker(); + auto argument = + 
device_op.MakeArgument(a0_device_buf.GetDeviceBuffer(), + b0_device_buf.GetDeviceBuffer(), + std::array{d0_device_buf.GetDeviceBuffer(), + d1_device_buf.GetDeviceBuffer()}, + e_device_buf.GetDeviceBuffer(), + M, + N, + K, + StrideA, + StrideB, + std::array{I0, I0}, + StrideE, + KBatch, + a_element_op, + b_element_op, + cde_element_op); + + if(!device_op.IsSupportedArgument(argument)) + { + throw std::runtime_error( + "wrong! device_gemm with the specified compilation parameters does " + "not support this GEMM problem"); + } + + float ave_time = invoker.Run(argument, StreamConfig{nullptr, time_kernel, 0, 50, 50, false, 1}); + + std::size_t flop = std::size_t(2) * M * N * K; + std::size_t num_btype = + sizeof(A0DataType) * M * K + sizeof(B0DataType) * K * N + sizeof(EDataType) * M * N; + + float tflops = static_cast(flop) / 1.E9 / ave_time; + + float gb_per_sec = num_btype / 1.E6 / ave_time; + + std::cout << "Perf: " << ave_time << " ms, " << tflops << " TFlops, " << gb_per_sec << " GB/s" + << std::endl; + + if(do_verification) + { + invoker.Run(argument, StreamConfig{nullptr, false}); + + e_device_buf.FromDevice(e_m_n_device_result.mData.data()); + + Tensor c_m_n({M, N}); + + using ReferenceGemmInstance = ck::tensor_operation::host::ReferenceGemm; + auto ref_gemm = ReferenceGemmInstance{}; + auto ref_invoker = ref_gemm.MakeInvoker(); + + auto ref_argument = ref_gemm.MakeArgument( + a0_m_k, b0_k_n, c_m_n, PassThrough{}, PassThrough{}, PassThrough{}); + + ref_invoker.Run(ref_argument); + + for(int m = 0; m < M; ++m) + { + for(int n = 0; n < N; ++n) + { + cde_element_op(e_m_n_host_result(m, n), c_m_n(m, n), d0_m_n(m, n), d1_m_n(m, n)); + } + } + + e_device_buf.FromDevice(e_m_n_device_result.mData.data()); + + return ck::utils::check_err( + e_m_n_device_result, e_m_n_host_result, "Error: Incorrect results!", 1e-3, 5e-2) + ? 
0 + : 1; + } + + return 0; +} From eaf1f0bf3b8fc015971be2e300e82abdd97ccfed Mon Sep 17 00:00:00 2001 From: "BingYuan.Zhou" Date: Wed, 16 Apr 2025 16:51:17 +0800 Subject: [PATCH 047/443] [flatmm] implement basic fp16 flatmm (#2089) * [flatmm] implement basic fp16 flatmm * fix CI build fail --------- Co-authored-by: root Co-authored-by: solin --- example/ck_tile/18_flatmm/CMakeLists.txt | 7 + example/ck_tile/18_flatmm/README.md | 35 ++ example/ck_tile/18_flatmm/flatmm_basic.cpp | 102 ++++ example/ck_tile/18_flatmm/flatmm_basic.hpp | 100 ++++ .../ck_tile/18_flatmm/run_flatmm_example.inc | 281 ++++++++++ .../18_flatmm/script/smoke_test_basic.sh | 34 ++ example/ck_tile/CMakeLists.txt | 1 + include/ck_tile/ops/flatmm.hpp | 6 + .../block_flatmm_asmem_bsmem_creg_v1.hpp | 187 +++++++ ...atmm_asmem_bsmem_creg_v1_custom_policy.hpp | 38 ++ .../ops/flatmm/kernel/flatmm_kernel.hpp | 496 ++++++++++++++++++ .../flatmm_pipeline_agmem_bgmem_creg_v1.hpp | 208 ++++++++ ...mm_pipeline_agmem_bgmem_creg_v1_policy.hpp | 265 ++++++++++ .../ops/flatmm/pipeline/tile_flatmm_shape.hpp | 43 ++ 14 files changed, 1803 insertions(+) create mode 100644 example/ck_tile/18_flatmm/CMakeLists.txt create mode 100644 example/ck_tile/18_flatmm/README.md create mode 100644 example/ck_tile/18_flatmm/flatmm_basic.cpp create mode 100644 example/ck_tile/18_flatmm/flatmm_basic.hpp create mode 100644 example/ck_tile/18_flatmm/run_flatmm_example.inc create mode 100755 example/ck_tile/18_flatmm/script/smoke_test_basic.sh create mode 100644 include/ck_tile/ops/flatmm/block/block_flatmm_asmem_bsmem_creg_v1.hpp create mode 100644 include/ck_tile/ops/flatmm/block/block_flatmm_asmem_bsmem_creg_v1_custom_policy.hpp create mode 100644 include/ck_tile/ops/flatmm/kernel/flatmm_kernel.hpp create mode 100644 include/ck_tile/ops/flatmm/pipeline/flatmm_pipeline_agmem_bgmem_creg_v1.hpp create mode 100644 include/ck_tile/ops/flatmm/pipeline/flatmm_pipeline_agmem_bgmem_creg_v1_policy.hpp create mode 100644 
include/ck_tile/ops/flatmm/pipeline/tile_flatmm_shape.hpp diff --git a/example/ck_tile/18_flatmm/CMakeLists.txt b/example/ck_tile/18_flatmm/CMakeLists.txt new file mode 100644 index 0000000000..9fbe65e3a7 --- /dev/null +++ b/example/ck_tile/18_flatmm/CMakeLists.txt @@ -0,0 +1,7 @@ +add_executable(tile_example_flatmm_basic EXCLUDE_FROM_ALL flatmm_basic.cpp) + +set(EXAMPLE_FLATMM_COMPILE_OPTIONS) +# list(APPEND EXAMPLE_FLATMM_COMPILE_OPTIONS -Wno-undefined-func-template -Wno-float-equal) +# list(APPEND EXAMPLE_FLATMM_COMPILE_OPTIONS -Wno-unused-variable -Wno-unused-parameter) +# list(APPEND EXAMPLE_FLATMM_COMPILE_OPTIONS -Wno-unused-local-typedef) +target_compile_options(tile_example_flatmm_basic PRIVATE ${EXAMPLE_FLATMM_COMPILE_OPTIONS}) diff --git a/example/ck_tile/18_flatmm/README.md b/example/ck_tile/18_flatmm/README.md new file mode 100644 index 0000000000..beaac785fc --- /dev/null +++ b/example/ck_tile/18_flatmm/README.md @@ -0,0 +1,35 @@ +# FLATMM Matrix Multiplication + +This folder contains example for FLATMM using ck_tile tile-programming implementation. Currently, it only supports the basic feature of the CK Tile FLATMM, but creates the placeholders for the future support on different FLATMM pipeline and different FLATMM modules. In the near future, we will gradually migrate all the FLATMM features from old CK to CK Tile. 
+ +## build +``` +# in the root of ck_tile +mkdir build && cd build +# you can replace with the appropriate architecture (for example gfx90a or gfx942) or leave it blank +sh ../script/cmake-ck-dev.sh ../ +# The basic pipeline method on the flatmm calculation +make tile_example_flatmm_basic -j +``` +This will result in an executable `build/bin/tile_example_flatmm_basic` + +## example +``` +args: + -b batch size (default:1) + -m m dimension (default:1024) + -n n dimension (default:2048) + -k k dimension (default:64) + -a_layout Tensor A data layout (default: R) + -b_layout Tensor B data layout (default: R) + -c_layout Tensor C data layout (default: R) + -stride_a Tensor A stride (default:0) + -stride_b Tensor B stride (default:0) + -stride_c Tensor C stride (default:0) + -v 0. No validation, 1. Validation on CPU, 2. Validation on GPU (default:2) + -e Absolute error tolerance (default:1e-5) + -prec data type. fp16/bf16/fp8/bf8 (default:fp16) + -warmup number of iterations before benchmark the kernel (default:10) + -repeat number of iterations to benchmark the kernel (default:100) + -timer gpu:gpu timer, cpu:cpu timer (default:gpu) +``` diff --git a/example/ck_tile/18_flatmm/flatmm_basic.cpp b/example/ck_tile/18_flatmm/flatmm_basic.cpp new file mode 100644 index 0000000000..05d0c73b7e --- /dev/null +++ b/example/ck_tile/18_flatmm/flatmm_basic.cpp @@ -0,0 +1,102 @@ +// SPDX-License-Identifier: MIT +// Copyright (c) 2024-2025, Advanced Micro Devices, Inc. All rights reserved. + +#include + +#include +#include +#include +#include +#include + +#include "ck_tile/host.hpp" +#include "flatmm_basic.hpp" + +template +float flatmm_calc(const ck_tile::FlatmmHostArgs& args, const ck_tile::stream_config& s) +{ + // The kPadM, kPadN, kPadK & kBlockPerCu should also come from the Codegen part. 
+ constexpr bool kPadM = false; + constexpr bool kPadN = false; + constexpr bool kPadK = false; + + constexpr int kBlockPerCu = 2; + + // This part comes from the Codegen + constexpr ck_tile::index_t M_Tile = 128; + constexpr ck_tile::index_t N_Tile = 128; + constexpr ck_tile::index_t K_Tile = 64; + + constexpr ck_tile::index_t M_Warp = 1; + constexpr ck_tile::index_t N_Warp = 4; + constexpr ck_tile::index_t K_Warp = 1; + + constexpr ck_tile::index_t M_Warp_Tile = 32; + constexpr ck_tile::index_t N_Warp_Tile = 32; + constexpr ck_tile::index_t K_Warp_Tile = 16; + + using CodegenFlatmmShape = + ck_tile::TileFlatmmShape, + ck_tile::sequence, + ck_tile::sequence>; + + using TilePartitioner = ck_tile::GemmTile1DPartitioner; + + using CodegenGemmTraits = + ck_tile::TileGemmTraits; + using CodegenPipelineProblem = ck_tile::GemmPipelineProblem; + using GemmEpilogue = ck_tile::CShuffleEpilogue< + ck_tile::CShuffleEpilogueProblem>; + + using CodegenFlatmmPolicy = ck_tile::UniversalFlatmmPipelineAgBgCrPolicy; + using CodegenFlatmmPipeline = + ck_tile::FlatmmPipelineAGmemBGmemCRegV1; + + // ToDo: Will add the codegen part to test different pipeline policies in GEMM. + // Now we only use the BlockGemmASmemBSmemCRegV1DefaultPolicy. + using Kernel = ck_tile::FlatmmKernel; + + auto kargs = Kernel::MakeKernelArgs(args); + + const dim3 grids = Kernel::GridSize(args.M, args.N, args.k_batch); + constexpr dim3 blocks = Kernel::BlockSize(); + + if(!Kernel::IsSupportedArgument(kargs)) + { + throw std::runtime_error("Wrong! Arguments not supported! 
Skipping gemm!\n"); + } + + if(s.log_level_ > 0) + { + std::cout << "Launching kernel with args:" + << " grid: {" << grids.x << ", " << grids.y << ", " << grids.z << "}" + << ", blocks: {" << blocks.x << ", " << blocks.y << ", " << blocks.z << "}" + << std::endl; + } + + float ave_time = ck_tile::launch_kernel( + s, ck_tile::make_kernel(Kernel{}, grids, blocks, 0, kargs)); + + return ave_time; +} + +#include "run_flatmm_example.inc" + +int main(int argc, char* argv[]) { return !run_flatmm_example(argc, argv); } diff --git a/example/ck_tile/18_flatmm/flatmm_basic.hpp b/example/ck_tile/18_flatmm/flatmm_basic.hpp new file mode 100644 index 0000000000..355ac45ebe --- /dev/null +++ b/example/ck_tile/18_flatmm/flatmm_basic.hpp @@ -0,0 +1,100 @@ + +// SPDX-License-Identifier: MIT +// Copyright (c) 2025, Advanced Micro Devices, Inc. All rights reserved. + +#pragma once + +#include + +#include "ck_tile/core.hpp" +#include "ck_tile/host/kernel_launch.hpp" +#include "ck_tile/ops/epilogue.hpp" +#include "ck_tile/ops/flatmm.hpp" +#include "ck_tile/ops/gemm.hpp" + +#define CK_TILE_PIPELINE_COMPUTE 1 +#define CK_TILE_PIPELINE_MEMORY 2 + +#ifndef CK_TILE_PIPELINE_DEFAULT +#define CK_TILE_PIPELINE_DEFAULT CK_TILE_PIPELINE_COMPUTE +#endif + +#if(CK_TILE_PIPELINE_DEFAULT == CK_TILE_PIPELINE_MEMORY) +#define GEMM_PIPELINE ck_tile::GemmPipelineAgBgCrMem +#define UNIVERSAL_GEMM_PIPELINE ck_tile::BaseGemmPipelineAgBgCrMem +#define GEMM_PIPELINE_SCHEDULER ck_tile::GemmPipelineScheduler::Interwave +#elif(CK_TILE_PIPELINE_DEFAULT == CK_TILE_PIPELINE_COMPUTE) +#define GEMM_PIPELINE ck_tile::GemmPipelineAgBgCrCompV3 +#define UNIVERSAL_GEMM_PIPELINE ck_tile::BaseGemmPipelineAgBgCrCompV3 +#define GEMM_PIPELINE_SCHEDULER ck_tile::GemmPipelineScheduler::Intrawave +#else +#error "unsupported CK_TILE_PIPELINE_DEFAULT value" +#endif + +template +struct GemmBasicTypeConfig; + +template <> +struct GemmBasicTypeConfig +{ + using ADataType = ck_tile::half_t; + using BDataType = ck_tile::half_t; + using 
AccDataType = float; + using CDataType = ck_tile::half_t; + // ToDo: Add more bias config to support different categories of GEMM. +}; + +template +struct DataTypeTraits; + +template <> +struct DataTypeTraits +{ + static constexpr const char* name = "fp32"; +}; + +template <> +struct DataTypeTraits +{ + static constexpr const char* name = "fp64"; +}; + +template <> +struct DataTypeTraits +{ + static constexpr const char* name = "fp16"; +}; + +using Types = GemmBasicTypeConfig; + +// Specific type aliases for easy access +using ADataType = Types::ADataType; +using BDataType = Types::BDataType; +using AccDataType = Types::AccDataType; +using CDataType = Types::CDataType; + +auto create_args(int argc, char* argv[]) +{ + ck_tile::ArgParser arg_parser; + arg_parser.insert("m", "256", "m dimension") + .insert("n", "256", "n dimension") + .insert("k", "128", "k dimension") + .insert("a_layout", "R", "A tensor data layout - Row by default") + .insert("b_layout", "C", "B tensor data layout - Row by default") + .insert("c_layout", "R", "C tensor data layout - Row by default") + .insert("stride_a", "0", "Tensor A stride") + .insert("stride_b", "0", "Tensor B stride") + .insert("stride_c", "0", "Tensor C stride") + .insert("v", "1", "0. No validation, 1. Validation on CPU, 2. Validation on GPU") + .insert("prec", "fp16", "data type. 
fp16/bf16/fp8/bf8") + .insert("warmup", "50", "number of iterations before benchmark the kernel") + .insert("repeat", "100", "number of iterations to benchmark the kernel") + .insert("timer", "gpu", "gpu:gpu timer, cpu:cpu timer") + .insert("split_k", "1", "splitK value"); + + bool result = arg_parser.parse(argc, argv); + return std::make_tuple(result, arg_parser); +} + +// host API +float flatmm_calc(const ck_tile::FlatmmHostArgs& args, const ck_tile::stream_config& s); diff --git a/example/ck_tile/18_flatmm/run_flatmm_example.inc b/example/ck_tile/18_flatmm/run_flatmm_example.inc new file mode 100644 index 0000000000..864d888074 --- /dev/null +++ b/example/ck_tile/18_flatmm/run_flatmm_example.inc @@ -0,0 +1,281 @@ +// SPDX-License-Identifier: MIT +// Copyright (c) 2024-2025, Advanced Micro Devices, Inc. All rights reserved. +#pragma once + +template +static constexpr inline auto is_row_major(Layout layout_) +{ + return ck_tile::bool_constant, + ck_tile::tensor_layout::gemm::RowMajor>>{}; +} + +// mfma_type, 0:32x32, 1:16x16 +template +auto shuffle_b(const ck_tile::HostTensor& t, std::string mfma_dtype, int mfma_type = 0) +{ + assert(t.get_lengths().size() == 2); + int n_ = t.get_lengths()[1]; + int k_ = t.get_lengths()[0]; + + if((mfma_dtype == "bf16" || mfma_dtype == "fp16") && mfma_type == 0) + { + ck_tile::HostTensor t_view({n_ / 32, 32, k_ / 16, 2, 8}); + std::copy(t.begin(), t.end(), t_view.begin()); + return ck_tile::reference_permute(t_view, {0, 2, 3, 1, 4}); + } + else if((mfma_dtype == "bf16" || mfma_dtype == "fp16") && mfma_type == 1) + { + ck_tile::HostTensor t_view({n_ / 16, 16, k_ / 32, 4, 8}); + std::copy(t.begin(), t.end(), t_view.begin()); + return ck_tile::reference_permute(t_view, {0, 2, 3, 1, 4}); + } + else if((mfma_dtype == "int8" || mfma_dtype == "fp8") && mfma_type == 0) + { + ck_tile::HostTensor t_view({n_ / 32, 32, k_ / 32, 2, 16}); + std::copy(t.begin(), t.end(), t_view.begin()); + return ck_tile::reference_permute(t_view, {0, 2, 3, 1, 
4}); + } + else if((mfma_dtype == "int8" || mfma_dtype == "fp8") && mfma_type == 1) + { + ck_tile::HostTensor t_view({n_ / 16, 16, k_ / 64, 4, 16}); + std::copy(t.begin(), t.end(), t_view.begin()); + return ck_tile::reference_permute(t_view, {0, 2, 3, 1, 4}); + } + return t; +} + +auto calculate_rtol_atol(const ck_tile::index_t K, + const ck_tile::index_t kbatch, + const float max_accumulated_value) +{ + using ComputeType = + std::conditional_t; + // Calculate thresholds + const auto rtol = ck_tile::get_relative_threshold( + ck_tile::integer_divide_ceil(K, kbatch)); + const auto atol = ck_tile::get_absolute_threshold( + max_accumulated_value / kbatch, ck_tile::integer_divide_ceil(K, kbatch)); + // Calculate error due to split_k accumulation + const auto rtol_split_k = + ck_tile::get_relative_threshold(kbatch); + const auto atol_split_k = ck_tile::get_absolute_threshold( + max_accumulated_value, kbatch); + // Use higher threshold + return ck_tile::make_tuple(std::max(rtol, rtol_split_k), std::max(atol, atol_split_k)); +} + +template +float invoke_flatmm(ck_tile::DeviceMem& a_dev_buf, + ck_tile::DeviceMem& b_shuffle_dev_buf, + ck_tile::DeviceMem& c_dev_buf, + ck_tile::index_t M, + ck_tile::index_t N, + ck_tile::index_t K, + ck_tile::index_t stride_A, + ck_tile::index_t stride_B, + ck_tile::index_t stride_C, + ck_tile::index_t kbatch, + int n_warmup, + int n_repeat) +{ + ck_tile::FlatmmHostArgs args; + args.a_ptr = a_dev_buf.GetDeviceBuffer(); + args.b_shuffle_ptr = b_shuffle_dev_buf.GetDeviceBuffer(); + args.c_ptr = c_dev_buf.GetDeviceBuffer(); + + args.k_batch = kbatch; + args.M = M; + args.N = N; + args.K = K; + args.stride_A = stride_A; + args.stride_B = stride_B; + args.stride_C = stride_C; + + float ave_time = flatmm_calc( + args, ck_tile::stream_config{nullptr, true, 1, n_warmup, n_repeat}); + + std::size_t flop = std::size_t(2) * M * N * K; + std::size_t num_byte = + sizeof(ADataType) * M * K + sizeof(BDataType) * N * K + sizeof(CDataType) * M * N; + float 
tflops = static_cast(flop) / 1.E9 / ave_time; + float gb_per_sec = num_byte / 1.E6 / ave_time; + + std::cout << "Run Flatmm kernel with M =" << M << " N =" << N << " K =" << K + << " StrideA =" << stride_A << " StrideB =" << stride_B << " StrideC =" << stride_C + << " : " << ave_time << " ms, " << tflops << " TFlops, " << gb_per_sec << " GB/s, " + << std::endl; + + return ave_time; +} + +template +int run_flatmm_example_with_layouts(int argc, + char* argv[], + const ALayout a_layout = ALayout{}, + const BLayout b_layout = BLayout{}, + [[maybe_unused]] const CLayout c_layout = CLayout{}) +{ + auto [result, arg_parser] = create_args(argc, argv); + if(!result) + return -1; + + ck_tile::index_t M = arg_parser.get_int("m"); + ck_tile::index_t N = arg_parser.get_int("n"); + ck_tile::index_t K = arg_parser.get_int("k"); + + ck_tile::index_t stride_A = arg_parser.get_int("stride_a"); + ck_tile::index_t stride_B = arg_parser.get_int("stride_b"); + ck_tile::index_t stride_C = arg_parser.get_int("stride_c"); + + ck_tile::index_t kbatch = arg_parser.get_int("split_k"); + int n_warmup = arg_parser.get_int("warmup"); + int n_repeat = arg_parser.get_int("repeat"); + + stride_A = ck_tile::get_default_stride(M, K, stride_A, is_row_major(a_layout)); + stride_B = ck_tile::get_default_stride(K, N, stride_B, is_row_major(b_layout)); + stride_C = ck_tile::get_default_stride(M, N, stride_C, is_row_major(CLayout{})); + + ck_tile::HostTensor a_host( + ck_tile::host_tensor_descriptor(M, K, stride_A, is_row_major(a_layout))); + ck_tile::HostTensor b_origin_host( + ck_tile::host_tensor_descriptor(K, N, stride_B, is_row_major(b_layout))); + ck_tile::HostTensor c_rslt_host( + ck_tile::host_tensor_descriptor(M, N, stride_C, is_row_major(CLayout{}))); + + // TODO: add different init types + ck_tile::FillUniformDistribution{-.5f, .5f}(a_host); + ck_tile::FillUniformDistribution{-.5f, .5f}(b_origin_host); + + ck_tile::DeviceMem a_dev_buf(a_host.get_element_space_size_in_bytes()); + 
ck_tile::DeviceMem c_dev_buf(c_rslt_host.get_element_space_size_in_bytes()); + + a_dev_buf.ToDevice(a_host.data()); + c_rslt_host.SetZero(); + + // do pre-shuffle + std::string mfma = arg_parser.get_str("prec"); + ck_tile::HostTensor b_shuffle_host = shuffle_b(b_origin_host, mfma, 0); + ck_tile::DeviceMem b_shuffle_dev_buf(b_shuffle_host.get_element_space_size_in_bytes()); + b_shuffle_dev_buf.ToDevice(b_shuffle_host.data()); + + invoke_flatmm(a_dev_buf, + b_shuffle_dev_buf, + c_dev_buf, + M, + N, + K, + stride_A, + stride_B, + stride_C, + kbatch, + n_warmup, + n_repeat); + + c_dev_buf.FromDevice(c_rslt_host.data()); + bool pass = true; + + if(arg_parser.get_int("v") == 1) + { + ck_tile::HostTensor c_ref_host( + ck_tile::host_tensor_descriptor(M, N, stride_C, is_row_major(CLayout{}))); + c_ref_host.SetZero(); + + ck_tile::reference_gemm( + a_host, b_origin_host, c_ref_host); + const float max_accumulated_value = + *std::max_element(c_ref_host.mData.begin(), c_ref_host.mData.end()); + const auto rtol_atol = calculate_rtol_atol(K, kbatch, max_accumulated_value); + pass = ck_tile::check_err(c_rslt_host, + c_ref_host, + "Error: Incorrect results!", + rtol_atol.at(ck_tile::number<0>{}), + rtol_atol.at(ck_tile::number<1>{})); + + std::cout << "Relative error threshold: " << rtol_atol.at(ck_tile::number<0>{}) + << " Absolute error threshold: " << rtol_atol.at(ck_tile::number<1>{}) + << std::endl; + std::cout << "The CPU veification result is:" << (pass ? 
"correct" : "fail") << std::endl; + } + else if(arg_parser.get_int("v") == 2) + { + ck_tile::DeviceMem b_origin_dev_buf(b_origin_host.get_element_space_size_in_bytes()); + b_origin_dev_buf.ToDevice(b_origin_host.data()); + + ck_tile::HostTensor c_gpu_ref_host( + ck_tile::host_tensor_descriptor(M, N, stride_C, is_row_major(CLayout{}))); + ck_tile::DeviceMem c_gpu_ref_dev_buf(c_gpu_ref_host.get_element_space_size_in_bytes()); + c_gpu_ref_host.SetZero(); + c_gpu_ref_dev_buf.SetZero(); + + ADataType* d_A; + BDataType* d_B; + CDataType* d_C; + + ck_tile::hip_check_error(hipMalloc(&d_A, M * K * sizeof(ADataType))); + ck_tile::hip_check_error(hipMalloc(&d_B, N * K * sizeof(BDataType))); + ck_tile::hip_check_error(hipMalloc(&d_C, M * N * sizeof(CDataType))); + + ck_tile::hip_check_error(hipMemcpy( + d_A, a_dev_buf.GetDeviceBuffer(), M * K * sizeof(ADataType), hipMemcpyHostToDevice)); + ck_tile::hip_check_error(hipMemcpy(d_B, + b_origin_dev_buf.GetDeviceBuffer(), + N * K * sizeof(BDataType), + hipMemcpyHostToDevice)); + + ck_tile::reference_gemm_gpu(d_A, d_B, d_C, M, N, K, stride_A, stride_B, stride_C); + + ck_tile::hip_check_error(hipMemcpy(c_gpu_ref_dev_buf.GetDeviceBuffer(), + d_C, + M * N * sizeof(CDataType), + hipMemcpyDeviceToHost)); + + ck_tile::hip_check_error(hipFree(d_A)); + ck_tile::hip_check_error(hipFree(d_B)); + ck_tile::hip_check_error(hipFree(d_C)); + + c_gpu_ref_dev_buf.FromDevice(c_gpu_ref_host.data()); + const float max_accumulated_value = + *std::max_element(c_gpu_ref_host.mData.begin(), c_gpu_ref_host.mData.end()); + const auto rtol_atol = calculate_rtol_atol(K, kbatch, max_accumulated_value); + pass = ck_tile::check_err(c_rslt_host, + c_gpu_ref_host, + "Error: Incorrect results!", + rtol_atol.at(ck_tile::number<0>{}), + rtol_atol.at(ck_tile::number<1>{})); + + std::cout << "Relative error threshold: " << rtol_atol.at(ck_tile::number<0>{}) + << " Absolute error threshold: " << rtol_atol.at(ck_tile::number<1>{}) + << std::endl; + std::cout << "The GPU 
veification result is: " << (pass ? "correct" : "fail") << std::endl; + } + + return pass; +} + +int run_flatmm_example(int argc, char* argv[]) +{ + auto [result, arg_parser] = create_args(argc, argv); + if(!result) + return -1; + + using Row = ck_tile::tensor_layout::gemm::RowMajor; + using Col = ck_tile::tensor_layout::gemm::ColumnMajor; + + std::string a_layout = arg_parser.get_str("a_layout"); + std::string b_layout = arg_parser.get_str("b_layout"); + + if(a_layout == "R" && b_layout == "C") + { + return run_flatmm_example_with_layouts(argc, argv, Row{}, Col{}, Row{}); + } + else + { + throw std::runtime_error("Unsupported data layout configuration for A,B and C tensors!"); + } +} diff --git a/example/ck_tile/18_flatmm/script/smoke_test_basic.sh b/example/ck_tile/18_flatmm/script/smoke_test_basic.sh new file mode 100755 index 0000000000..a3fc61cc31 --- /dev/null +++ b/example/ck_tile/18_flatmm/script/smoke_test_basic.sh @@ -0,0 +1,34 @@ +#!/bin/bash +EXE="$(find . -name tile_example_flatmm_basic -type f | head -n 1)" +KNAME=1 + +export CK_WARMUP=0 +export CK_REPEAT=1 + +COMMON_ARGS='-v=2 -warmup=0 -repeat=1' + +run_tests() { + for m in 128 1024; do + for n in 128 2048; do + for k in 128 4096; do + + $EXE -m=$m -n=$n -k=$k -stride_a=0 -stride_b=0 -stride_c=0 -prec=$1 $COMMON_ARGS + if [ $? -eq 0 ]; then + echo "Success: Test with m=$m, n=$n, k=$k executed successfully." + else + echo "Error: Test with m=$m, n=$n, k=$k failed to execute properly." 
+ # Optionally, exit or break if you need to halt further execution + # exit 1 + fi + + done + done + done +} + +set -x + +run_tests "bf16" +run_tests "fp16" + +set +x diff --git a/example/ck_tile/CMakeLists.txt b/example/ck_tile/CMakeLists.txt index 7f4ba2ed35..88efe0d8d9 100644 --- a/example/ck_tile/CMakeLists.txt +++ b/example/ck_tile/CMakeLists.txt @@ -17,4 +17,5 @@ add_subdirectory(14_moe_smoothquant) add_subdirectory(15_fused_moe) add_subdirectory(16_batched_gemm) add_subdirectory(17_grouped_gemm) +add_subdirectory(18_flatmm) add_subdirectory(35_batched_transpose) diff --git a/include/ck_tile/ops/flatmm.hpp b/include/ck_tile/ops/flatmm.hpp index 82f6d48eda..1714789e63 100644 --- a/include/ck_tile/ops/flatmm.hpp +++ b/include/ck_tile/ops/flatmm.hpp @@ -3,10 +3,16 @@ #pragma once +#include "ck_tile/ops/flatmm/block/block_flatmm_asmem_bsmem_creg_v1.hpp" +#include "ck_tile/ops/flatmm/block/block_flatmm_asmem_bsmem_creg_v1_custom_policy.hpp" #include "ck_tile/ops/flatmm/block/flatmm_32x512x128_1x4x1_16x16x32.hpp" #include "ck_tile/ops/flatmm/block/flatmm_sn_32x128x512_1x4x1_16x16x32.hpp" #include "ck_tile/ops/flatmm/block/flatmm_sn_32x128x512_1x4x1_16x16x32_itl.hpp" #include "ck_tile/ops/flatmm/block/flatmm_uk_config.hpp" +#include "ck_tile/ops/flatmm/kernel/flatmm_kernel.hpp" +#include "ck_tile/ops/flatmm/pipeline/flatmm_pipeline_agmem_bgmem_creg_v1.hpp" +#include "ck_tile/ops/flatmm/pipeline/flatmm_pipeline_agmem_bgmem_creg_v1_policy.hpp" +#include "ck_tile/ops/flatmm/pipeline/tile_flatmm_shape.hpp" #include "ck_tile/ops/common/generic_2d_block_shape.hpp" #include "ck_tile/ops/common/tensor_layout.hpp" #include "ck_tile/ops/common/utils.hpp" diff --git a/include/ck_tile/ops/flatmm/block/block_flatmm_asmem_bsmem_creg_v1.hpp b/include/ck_tile/ops/flatmm/block/block_flatmm_asmem_bsmem_creg_v1.hpp new file mode 100644 index 0000000000..935eb2c028 --- /dev/null +++ b/include/ck_tile/ops/flatmm/block/block_flatmm_asmem_bsmem_creg_v1.hpp @@ -0,0 +1,187 @@ +// 
SPDX-License-Identifier: MIT +// Copyright (c) 2018-2025, Advanced Micro Devices, Inc. All rights reserved. + +#pragma once + +#include "ck_tile/core.hpp" +#include "ck_tile/ops/flatmm/block/block_flatmm_asmem_bsmem_creg_v1_custom_policy.hpp" + +namespace ck_tile { + +// A is block window on shared memory +// B is block window on shared memory +// C is block distributed tensor +template +struct BlockFlatmmASmemBSmemCRegV1 +{ + using Problem = remove_cvref_t; + using BlockPolicy = remove_cvref_t; + using ADataType = remove_cvref_t; + using BDataType = remove_cvref_t; + using CDataType = remove_cvref_t; + using BlockGemmShape = remove_cvref_t; // TileFlatmmShape + + static constexpr auto I0 = number<0>(); + static constexpr auto I1 = number<1>(); + static constexpr auto I2 = number<2>(); + static constexpr auto idxM = I0; + static constexpr auto idxN = I1; + static constexpr auto idxK = I2; + using BlockTile = remove_cvref_t; + using BlockWarps = remove_cvref_t; + using WarpTile = remove_cvref_t; + + static constexpr index_t kBlockSize = Problem::kBlockSize; + + CK_TILE_DEVICE static constexpr auto MakeCBlockTile() + { + constexpr index_t MPerBlock = BlockGemmShape::kM; + constexpr index_t NPerBlock = BlockGemmShape::kN; + + constexpr auto config = BlockPolicy::template GetWarpGemmMWarpNWarp(); + + using WG = remove_cvref_t())>; + + constexpr index_t MWarp = config.template at<1>(); + constexpr index_t NWarp = config.template at<2>(); + + constexpr index_t MIterPerWarp = MPerBlock / (MWarp * WG::kM); + constexpr index_t NIterPerWarp = NPerBlock / (NWarp * WG::kN); + + constexpr auto c_block_outer_dstr_encoding = tile_distribution_encoding< + sequence<>, + tuple, sequence>, + tuple>, + tuple>, + sequence<1, 2>, + sequence<0, 0>>{}; + + constexpr auto c_block_dstr_encode = detail::make_embed_tile_distribution_encoding( + c_block_outer_dstr_encoding, typename WG::CWarpDstrEncoding{}); + + constexpr auto c_block_dstr = make_static_tile_distribution(c_block_dstr_encode); 
+ + auto c_block_tensor = make_static_distributed_tensor(c_block_dstr); + return c_block_tensor; + } + + // C += A * B + template + CK_TILE_DEVICE void operator()(CBlockTensor& c_block_tensor, + const ABlockWindow& a_block_window, + const BFlatBlockWindow& b_flat_block_window) const + { + static_assert(std::is_same_v && + std::is_same_v && + std::is_same_v, + "wrong!"); + constexpr index_t MPerBlock = ABlockWindow{}.get_window_lengths()[number<0>{}]; + constexpr index_t KPerBlock = ABlockWindow{}.get_window_lengths()[number<1>{}]; + + static_assert(MPerBlock == BlockGemmShape::kM && KPerBlock == BlockGemmShape::kK, "wrong!"); + + constexpr auto config = BlockPolicy::template GetWarpGemmMWarpNWarp(); + using WG = remove_cvref_t())>; + + constexpr index_t MWarp = config.template at<1>(); + constexpr index_t NWarp = config.template at<2>(); + + constexpr index_t MIterPerWarp = MPerBlock / (MWarp * WG::kM); + constexpr index_t NIterPerWarp = + BlockTile::at(idxN) / (WarpTile::at(idxN) * BlockWarps::at(idxN)); + constexpr index_t KIterPerWarp = KPerBlock / WG::kK; + + constexpr index_t MPerBlockPerIter = MPerBlock / MIterPerWarp; + constexpr index_t KPerBlockPerIter = KPerBlock / KIterPerWarp; + + constexpr index_t NFlatPerBlockPerIter = BlockGemmShape::flatNPerWarp; + constexpr index_t KFlatPerBlockPerIter = BlockGemmShape::flatKPerWarp; + + const index_t iMWarp = get_warp_id() / NWarp; + + // construct A-warp-window + auto a_warp_window_tmp = make_tile_window( + a_block_window.get_bottom_tensor_view(), + make_tuple(number{}, number{}), + a_block_window.get_window_origin() + multi_index<2>{iMWarp * WG::kM, 0}, + make_static_tile_distribution(typename WG::AWarpDstrEncoding{})); + statically_indexed_array< + statically_indexed_array, + MIterPerWarp> + a_warp_windows; + static_for<0, MIterPerWarp, 1>{}([&](auto mIter) { + static_for<0, KIterPerWarp, 1>{}([&](auto kIter) { + a_warp_windows(mIter)(kIter) = a_warp_window_tmp; + + 
move_tile_window(a_warp_windows(mIter)(kIter), + {mIter * MPerBlockPerIter, kIter * KPerBlockPerIter}); + }); + }); + + // construct Bflat-warp-window + auto b_flat_warp_windows_tmp = b_flat_block_window; + statically_indexed_array< + statically_indexed_array, + NIterPerWarp> + b_flat_warp_windows; + static_for<0, NIterPerWarp, 1>{}([&](auto nIter) { + static_for<0, KIterPerWarp, 1>{}([&](auto kIter) { + b_flat_warp_windows(nIter)(kIter) = b_flat_warp_windows_tmp; + + move_tile_window(b_flat_warp_windows(nIter)(kIter), + {nIter * NFlatPerBlockPerIter, kIter * KFlatPerBlockPerIter}); + }); + }); + + // auto b_warp_windows = b_origin_warp_windows; + auto b_warp_windows = b_flat_warp_windows; + + using CWarpDstr = typename WG::CWarpDstr; + using CWarpTensor = typename WG::CWarpTensor; + + constexpr auto c_warp_y_lengths = + to_sequence(CWarpDstr{}.get_ys_to_d_descriptor().get_lengths()); + constexpr auto c_warp_y_index_zeros = uniform_sequence_gen_t{}; + + // hot loop: + static_for<0, KIterPerWarp, 1>{}([&](auto kIter) { + static_for<0, MIterPerWarp, 1>{}([&](auto mIter) { + // read A warp tensor from A block window + const auto a_warp_tensor = load_tile(a_warp_windows(mIter)(kIter)); + + static_for<0, NIterPerWarp, 1>{}([&](auto nIter) { + // read B warp tensor from B Block window + const auto b_warp_tensor = load_tile(b_warp_windows(nIter)(kIter)); + + // read C warp tensor from C block tensor + CWarpTensor c_warp_tensor; + + c_warp_tensor.get_thread_buffer() = c_block_tensor.get_y_sliced_thread_data( + merge_sequences(sequence{}, c_warp_y_index_zeros), + merge_sequences(sequence<1, 1>{}, c_warp_y_lengths)); + + // warp GEMM + WG{}(c_warp_tensor, a_warp_tensor, b_warp_tensor); + + // write C warp tensor into C block tensor + c_block_tensor.set_y_sliced_thread_data( + merge_sequences(sequence{}, c_warp_y_index_zeros), + merge_sequences(sequence<1, 1>{}, c_warp_y_lengths), + c_warp_tensor.get_thread_buffer()); + }); + }); + }); + } + + // C = A * B + template + 
CK_TILE_DEVICE auto operator()(const ABlockTensorTmp& a_block_tensor_tmp, + const BFlatBlockWindow& b_flat_block_window) const + { + auto c_block_tensor = MakeCBlockTile(); + operator()(c_block_tensor, a_block_tensor_tmp, b_flat_block_window); + return c_block_tensor; + } +}; + +} // namespace ck_tile diff --git a/include/ck_tile/ops/flatmm/block/block_flatmm_asmem_bsmem_creg_v1_custom_policy.hpp b/include/ck_tile/ops/flatmm/block/block_flatmm_asmem_bsmem_creg_v1_custom_policy.hpp new file mode 100644 index 0000000000..d5b062a1b3 --- /dev/null +++ b/include/ck_tile/ops/flatmm/block/block_flatmm_asmem_bsmem_creg_v1_custom_policy.hpp @@ -0,0 +1,38 @@ +// SPDX-License-Identifier: MIT +// Copyright (c) 2018-2025, Advanced Micro Devices, Inc. All rights reserved. + +#pragma once + +#include "ck_tile/core.hpp" + +namespace ck_tile { + +// Default policy for BlockGemmASmemBSmemCRegV1 +// Default policy class should not be templated, put template on member functions instead +template +struct BlockFlatmmASmemBSmemCRegV1CustomPolicy +{ + using AType = remove_cvref_t; + using BType = remove_cvref_t; + using CType = remove_cvref_t; + + using BlockWarps = remove_cvref_t; + + static constexpr index_t kMWarps = BlockWarps::at(number<0>{}); + static constexpr index_t kNWarps = BlockWarps::at(number<1>{}); + static constexpr index_t kKWarps = BlockWarps::at(number<2>{}); + + using WarpGemm = remove_cvref_t; + + template + CK_TILE_HOST_DEVICE static constexpr auto GetWarpGemmMWarpNWarp() + { + return make_tuple(WarpGemm{}, kMWarps, kNWarps); + } +}; + +} // namespace ck_tile diff --git a/include/ck_tile/ops/flatmm/kernel/flatmm_kernel.hpp b/include/ck_tile/ops/flatmm/kernel/flatmm_kernel.hpp new file mode 100644 index 0000000000..eb45e6c0bd --- /dev/null +++ b/include/ck_tile/ops/flatmm/kernel/flatmm_kernel.hpp @@ -0,0 +1,496 @@ +// SPDX-License-Identifier: MIT +// Copyright (c) 2018-2025, Advanced Micro Devices, Inc. All rights reserved. 
+ +#pragma once + +#include +#include + +#include "ck_tile/core.hpp" +#include "ck_tile/ops/common.hpp" +#include "ck_tile/ops/gemm/pipeline/gemm_pipeline_ag_bg_cr_scheduler.hpp" + +namespace ck_tile { + +struct FlatmmProblem +{ + CK_TILE_HOST FlatmmProblem() = default; + CK_TILE_HOST FlatmmProblem( + index_t M_, index_t N_, index_t K_, index_t stride_A_, index_t stride_B_, index_t stride_C_) + : M(M_), N(N_), K(K_), stride_A(stride_A_), stride_B(stride_B_), stride_C(stride_C_) + { + } + + index_t M; + index_t N; + index_t K; + index_t stride_A; + index_t stride_B; + index_t stride_C; +}; + +struct FlatmmHostArgs : public FlatmmProblem +{ + CK_TILE_HOST FlatmmHostArgs() = default; + CK_TILE_HOST FlatmmHostArgs(const void* a_ptr_, + const void* b_shuffle_ptr_, + void* c_ptr_, + index_t k_batch_, + index_t M_, + index_t N_, + index_t K_, + index_t stride_A_, + index_t stride_B_, + index_t stride_C_) + : FlatmmProblem(M_, N_, K_, stride_A_, stride_B_, stride_C_), + a_ptr(a_ptr_), + b_shuffle_ptr(b_shuffle_ptr_), + c_ptr(c_ptr_), + k_batch(k_batch_) + { + } + + const void* a_ptr; + const void* b_shuffle_ptr; + void* c_ptr; + index_t k_batch; +}; + +template +struct FlatmmKernel +{ + using TilePartitioner = remove_cvref_t; + using FlatmmPipeline = remove_cvref_t; + using BlockGemmShape = + remove_cvref_t; // TileFlatmmShape + using EpiloguePipeline = remove_cvref_t; + using ALayout = remove_cvref_t; + using BLayout = remove_cvref_t; + using CLayout = remove_cvref_t; + static constexpr index_t KernelBlockSize = FlatmmPipeline::BlockSize; + + using ADataType = remove_cvref_t; + using BDataType = remove_cvref_t; + // Below type is actually accumulation data type - the output of block GEMM. 
+ using CDataType = remove_cvref_t; + + static constexpr auto I0 = number<0>(); + static constexpr auto I1 = number<1>(); + static constexpr auto I2 = number<2>(); + static constexpr auto idxM = I0; + static constexpr auto idxN = I1; + static constexpr auto idxK = I2; + + [[nodiscard]] CK_TILE_HOST static const std::string GetName() + { + // clang-format off + return concat('_', "gemm", gemm_prec_str, FlatmmPipeline::GetName()); + // clang-format on + } + + CK_TILE_HOST static constexpr auto GridSize(index_t M, index_t N, index_t KBatch) + { + return dim3(TilePartitioner::GridSize(M, N), 1, KBatch); + } + + CK_TILE_HOST static constexpr auto BlockSize() { return dim3(KernelBlockSize); } + + struct FlatmmKernelArgs + { + const void* a_ptr; + const void* b_shuffle_ptr; + void* c_ptr; + index_t M; + index_t N; + index_t K; + index_t stride_A; + index_t stride_B; + index_t stride_C; + index_t k_batch; + }; + + CK_TILE_HOST static constexpr FlatmmKernelArgs MakeKernelArgs(const FlatmmHostArgs& hostArgs) + { + return FlatmmKernelArgs{hostArgs.a_ptr, + hostArgs.b_shuffle_ptr, + hostArgs.c_ptr, + hostArgs.M, + hostArgs.N, + hostArgs.K, + hostArgs.stride_A, + hostArgs.stride_B, + hostArgs.stride_C, + hostArgs.k_batch}; + } + + CK_TILE_HOST_DEVICE static constexpr index_t GetSmemSize() + { + return max(FlatmmPipeline::GetSmemSize(), EpiloguePipeline::GetSmemSize()); + } + + struct SplitKBatchOffset + { + __device__ SplitKBatchOffset(const FlatmmKernelArgs& kargs, + const std::size_t k_id = blockIdx.z) + { + constexpr auto K1 = TilePartitioner::BlockGemmShape::WarpTile::at(number<2>{}); + const index_t K_t = kargs.k_batch * K1; + const index_t KRead = (kargs.K + K_t - 1) / K_t * K1; + + if constexpr(std::is_same_v) + { + a_k_split_offset = k_id * KRead; + } + else if constexpr(std::is_same_v) + { + a_k_split_offset = k_id * KRead * kargs.stride_A; + } + + if constexpr(std::is_same_v) + { + b_k_split_offset = k_id * KRead * kargs.stride_B; + } + else if 
constexpr(std::is_same_v) + { + b_k_split_offset = k_id * KRead; + } + + if(k_id < static_cast(kargs.k_batch - 1)) + { + splitted_k = KRead; + } + else + { + splitted_k = kargs.K - KRead * (kargs.k_batch - 1); + } + } + + index_t a_k_split_offset; + index_t b_k_split_offset; + index_t splitted_k; + }; + + CK_TILE_HOST static bool IsSupportedArgument(const FlatmmKernelArgs& kargs) + { + if constexpr(EpiloguePipeline::GetVectorSizeC() % 2 != 0 && + is_any_of::value) + { + if(kargs.k_batch != 1) + { + std::cerr << "Conditions not met for Kbatch >1 !" << std::endl; + return false; + } + } + + if constexpr(std::is_same_v) + { + if(kargs.K % TilePartitioner::KPerBlock != 0 && FlatmmPipeline::kPadK == false) + { + std::cerr << "Can't support K that is not a multiple of KPerBlock" + " without padding!" + << std::endl; + return false; + } + if(kargs.K % FlatmmPipeline::GetVectorSizeA() != 0) + { + std::cerr << "K is not a multiple of vector load size for A tensor!" << std::endl; + return false; + } + } + else + { + if(kargs.M % TilePartitioner::MPerBlock != 0 && FlatmmPipeline::kPadM == false) + { + std::cerr << "Can't support M that is not a multiple of MPerBlock" + " without padding!" + << std::endl; + return false; + } + if(kargs.M % FlatmmPipeline::GetVectorSizeA() != 0) + { + std::cerr << "M is not a multiple of vector load size for A tensor!" << std::endl; + return false; + } + } + + if constexpr(std::is_same_v) + { + if(kargs.N % TilePartitioner::NPerBlock != 0 && FlatmmPipeline::kPadN == false) + { + std::cerr << "Can't support N that is not a multiple of NPerBlock" + " without padding!" + << std::endl; + return false; + } + if(kargs.N % FlatmmPipeline::GetVectorSizeB() != 0) + { + std::cerr << "N is not a multiple of vector load size for B tensor!" 
<< std::endl; + return false; + } + } + else + { + if(kargs.K % TilePartitioner::KPerBlock != 0 && FlatmmPipeline::kPadK == false) + { + std::cerr << "Can't support K that is not a multiple of KPerBlock" + " without padding!" + << std::endl; + return false; + } + if(kargs.K % FlatmmPipeline::GetVectorSizeB() != 0) + { + std::cerr << "K is not a multiple of vector load size for B tensor!" << std::endl; + return false; + } + } + + if constexpr(std::is_same_v) + { + if(kargs.N % TilePartitioner::NPerBlock != 0 && FlatmmPipeline::kPadN == false) + { + std::cerr << "Can't support N that is not a multiple of NPerBlock" + " without padding!" + << std::endl; + return false; + } + if(kargs.N % EpiloguePipeline::GetVectorSizeC() != 0) + { + std::cerr << "N is not a multiple of vector load size for C tensor!" << std::endl; + return false; + } + } + else + { + if(kargs.M % TilePartitioner::MPerBlock != 0 && FlatmmPipeline::kPadM == false) + { + std::cerr << "Can't support M that is not a multiple of MPerBlock" + " without padding!" + << std::endl; + return false; + } + if(kargs.M % EpiloguePipeline::GetVectorSizeC() != 0) + { + std::cerr << "M is not a multiple of vector load size for C tensor!" 
<< std::endl; + return false; + } + } + return true; + } + + template + CK_TILE_DEVICE static auto MakeGemmTensorViews(const ADataType* a_ptr, + const BDataType* b_flat_ptr, + CDataType* c_ptr, + const FlatmmKernelArgs& kargs, + const SplitKBatchOffset& splitk_batch_offset) + { + const auto& a_tensor_view = [&]() { + if constexpr(std::is_same_v) + { + return make_naive_tensor_view( + a_ptr, + make_tuple(kargs.M, splitk_batch_offset.splitted_k), + make_tuple(kargs.stride_A, 1), + number{}, + number<1>{}); + } + else + { + return make_naive_tensor_view( + a_ptr, + make_tuple(splitk_batch_offset.splitted_k, kargs.M), + make_tuple(kargs.stride_A, 1), + number{}, + number<1>{}); + } + }(); + + index_t kFlatK = FlatmmPipeline::flatKPerWarp * (splitk_batch_offset.splitted_k / + BlockGemmShape::WarpTile::at(number<2>{})); + index_t kFlatN = kargs.N * kargs.K / kFlatK; + const auto& b_flat_tensor_view = [&]() { + return make_naive_tensor_view( + b_flat_ptr, + make_tuple(kFlatN, kFlatK), + make_tuple(kFlatK, 1), + number{}, + number<1>{}); + }(); + + // TODO: enable vector write for C in ColMajor + const auto& c_tensor_view = [&]() { + if constexpr(std::is_same_v) + { + return make_naive_tensor_view( + c_ptr, + make_tuple(kargs.M, kargs.N), + make_tuple(kargs.stride_C, 1), + number{}, + number<1>{}); + } + else + { + return make_naive_tensor_view( + c_ptr, + make_tuple(kargs.M, kargs.N), + make_tuple(1, kargs.stride_C), + number<1>{}, + number<1>{}); + } + }(); + + return make_tuple(a_tensor_view, b_flat_tensor_view, c_tensor_view); + } + + template + CK_TILE_DEVICE static auto MakeGemmPadViews(const TensorView& views) + { + const auto& a_pad_view = [&]() { + const auto& a_tensor_view = views.at(I0); + if constexpr(std::is_same_v) + { + return pad_tensor_view(a_tensor_view, + make_tuple(number{}, + number{}), + sequence{}); + } + else + { + return pad_tensor_view(a_tensor_view, + make_tuple(number{}, + number{}), + sequence{}); + } + }(); + + const auto& b_flat_tensor_view = 
views.at(I1); + + // TODO vector write in for C in ColMajor + const auto& c_pad_view = [&]() { + const auto& c_tensor_view = views.at(I2); + if constexpr(std::is_same_v) + { + return pad_tensor_view(c_tensor_view, + make_tuple(number{}, + number{}), + sequence{}); + } + else + { + return pad_tensor_view(c_tensor_view, + make_tuple(number{}, + number{}), + sequence{}); + } + }(); + + return make_tuple(a_pad_view, b_flat_tensor_view, c_pad_view); + } + + template + CK_TILE_DEVICE static auto + MakeGemmTileWindows(const PadView& views, const index_t i_m, const index_t i_n) + { + const auto& a_pad_view = views.at(I0); + const auto& b_flat_pad_view = views.at(I1); + const auto& c_pad_view = views.at(I2); + + const auto& a_block_window = [&]() { + if constexpr(std::is_same_v) + { + return make_tile_window(a_pad_view, + make_tuple(number{}, + number{}), + {i_m, 0}); + } + else + { + return make_tile_window(a_pad_view, + make_tuple(number{}, + number{}), + {0, i_m}); + } + }(); + + const auto& b_flat_block_window = + make_tile_window(b_flat_pad_view, + make_tuple(number{}, + number{}), + {static_cast(i_n / BlockGemmShape::WarpTile::at(idxN)), 0}); + + auto c_block_window = make_tile_window( + c_pad_view, + make_tuple(number{}, number{}), + {i_m, i_n}); + + return make_tuple(a_block_window, b_flat_block_window, c_block_window); + } + + template + CK_TILE_DEVICE static void RunFlatmm(const ADataType* a_ptr, + const BDataType* b_flat_ptr, + CDataType* c_ptr, + void* smem_ptr, + const FlatmmKernelArgs& kargs, + const SplitKBatchOffset& splitk_batch_offset, + const index_t block_idx_m, + const index_t block_idx_n) + { + // Create Gemm tensor views, pad views and tile windows + const auto& gemm_tensor_views_tuple = + MakeGemmTensorViews(a_ptr, b_flat_ptr, c_ptr, kargs, splitk_batch_offset); + const auto& gemm_pad_views = MakeGemmPadViews(gemm_tensor_views_tuple); + auto gemm_tile_windows = MakeGemmTileWindows(gemm_pad_views, block_idx_m, block_idx_n); + + const index_t num_loop 
= TilePartitioner::GetLoopNum(splitk_batch_offset.splitted_k); + + // Run GEMM cooperatively by whole workgroup. + const auto& a_block_window = gemm_tile_windows.at(I0); + const auto& b_flat_block_window = gemm_tile_windows.at(I1); + const auto& c_block_tile = FlatmmPipeline{}.template operator()( + a_block_window, b_flat_block_window, num_loop, smem_ptr); + + // Run Epilogue Pipeline + auto& c_block_window = gemm_tile_windows.at(I2); + + EpiloguePipeline{} + .template operator()( + c_block_window, c_block_tile, smem_ptr); + } + + CK_TILE_DEVICE void operator()(FlatmmKernelArgs kargs) const + { + const auto [iM, iN] = TilePartitioner{kargs.M, kargs.N}.GetOutputTileIndex(blockIdx.x); + const index_t i_m = __builtin_amdgcn_readfirstlane(iM * TilePartitioner::MPerBlock); + const index_t i_n = __builtin_amdgcn_readfirstlane(iN * TilePartitioner::NPerBlock); + + const SplitKBatchOffset splitk_batch_offset(kargs); + // options + const ADataType* a_ptr = + static_cast(kargs.a_ptr) + splitk_batch_offset.a_k_split_offset; + const BDataType* b_flat_ptr = static_cast(kargs.b_shuffle_ptr) + + splitk_batch_offset.b_k_split_offset; + CDataType* c_ptr = static_cast(kargs.c_ptr); + + // allocate LDS + __shared__ char smem_ptr[GetSmemSize()]; + + if(kargs.k_batch == 1) + { + RunFlatmm(a_ptr, b_flat_ptr, c_ptr, smem_ptr, kargs, splitk_batch_offset, i_m, i_n); + } + else + { + // Do not compile in case where we have unsupported + // VectorSizeC & data type configuration. 
+ if constexpr(!(EpiloguePipeline::GetVectorSizeC() % 2 != 0 && + is_any_of::value)) + { + RunFlatmm( + a_ptr, b_flat_ptr, c_ptr, smem_ptr, kargs, splitk_batch_offset, i_m, i_n); + } + } + } +}; + +} // namespace ck_tile diff --git a/include/ck_tile/ops/flatmm/pipeline/flatmm_pipeline_agmem_bgmem_creg_v1.hpp b/include/ck_tile/ops/flatmm/pipeline/flatmm_pipeline_agmem_bgmem_creg_v1.hpp new file mode 100644 index 0000000000..3d08c7a788 --- /dev/null +++ b/include/ck_tile/ops/flatmm/pipeline/flatmm_pipeline_agmem_bgmem_creg_v1.hpp @@ -0,0 +1,208 @@ +// SPDX-License-Identifier: MIT +// Copyright (c) 2018-2025, Advanced Micro Devices, Inc. All rights reserved. + +#pragma once + +#include "ck_tile/core.hpp" +#include "ck_tile/ops/flatmm/pipeline/flatmm_pipeline_agmem_bgmem_creg_v1_policy.hpp" + +namespace ck_tile { + +template +struct FlatmmPipelineAGmemBGmemCRegV1 +{ + using ADataType = remove_cvref_t; + using BDataType = remove_cvref_t; + using CDataType = remove_cvref_t; + using BlockGemmShape = remove_cvref_t; // TileFlatmmShape + + using ALayout = remove_cvref_t; + using BLayout = remove_cvref_t; + using CLayout = remove_cvref_t; + + using BlockFlatmm = + remove_cvref_t())>; + + static constexpr index_t BlockSize = Problem::kBlockSize; + + static constexpr index_t kMPerBlock = BlockGemmShape::kM; + static constexpr index_t kNPerBlock = BlockGemmShape::kN; + static constexpr index_t kKPerBlock = BlockGemmShape::kK; + + static constexpr index_t flatKPerWarp = BlockGemmShape::flatKPerWarp; + static constexpr index_t flatNPerWarp = BlockGemmShape::flatNPerWarp; + + static constexpr index_t GetVectorSizeA() { return Problem::VectorSizeA; } + static constexpr index_t GetVectorSizeB() { return Problem::VectorSizeB; } + static constexpr index_t GetVectorSizeC() { return Problem::VectorSizeC; } + + static constexpr bool kPadM = Problem::kPadM; + static constexpr bool kPadN = Problem::kPadN; + static constexpr bool kPadK = Problem::kPadK; + + static constexpr index_t 
kLdsAlignmentInBytes = 16; + + static constexpr auto I0 = number<0>(); + static constexpr auto I1 = number<1>(); + static constexpr auto I2 = number<2>(); + static constexpr auto idxM = I0; + static constexpr auto idxN = I1; + static constexpr auto idxK = I2; + using BlockTile = remove_cvref_t; + using BlockWarps = remove_cvref_t; + using WarpTile = remove_cvref_t; + + [[nodiscard]] CK_TILE_HOST static const std::string GetName() + { + // clang-format off + return concat('_', "pipeline_AGmemBGmemCRegV1", + concat('x', kMPerBlock, kNPerBlock, kKPerBlock, BlockSize), + concat('x', GetVectorSizeA(), GetVectorSizeB(), GetVectorSizeC()), + concat('x', kPadM, kPadN, kPadK)); + // clang-format on + } + + // For the basic gemm pipelien DoubleSmemBuffer set to be false naturally. + static constexpr bool DoubleSmemBuffer = false; + + CK_TILE_HOST_DEVICE static constexpr auto TransposeC() { return Problem::TransposeC; } + + CK_TILE_HOST_DEVICE static constexpr index_t GetSmemSize() + { + return PipelinePolicy::template GetSmemSize(); + } + + template + CK_TILE_HOST_DEVICE auto operator()(const ADramBlockWindowTmp& a_dram_block_window_tmp, + const AElementFunction& a_element_func, + const BFlatBlockWindowTmp& b_flat_dram_block_window_tmp, + index_t num_loop, + void* p_smem) const + { + static_assert( + std::is_same_v>, + "wrong!"); + + static_assert(kMPerBlock == ADramBlockWindowTmp{}.get_window_lengths()[number<0>{}], + "wrong!"); + static_assert(kKPerBlock == ADramBlockWindowTmp{}.get_window_lengths()[number<1>{}], + "wrong!"); + + // A tile in LDS + ADataType* p_a_lds = static_cast(p_smem); + + constexpr auto a_lds_block_desc = + PipelinePolicy::template MakeALdsBlockDescriptor(); + + auto a_lds_block = make_tensor_view(p_a_lds, a_lds_block_desc); + + // A DRAM tile window for load + auto a_copy_dram_window = + make_tile_window(a_dram_block_window_tmp.get_bottom_tensor_view(), + make_tuple(number{}, number{}), + a_dram_block_window_tmp.get_window_origin(), + 
PipelinePolicy::template MakeADramTileDistribution()); + + // A LDS tile window for store + auto a_copy_lds_window = make_tile_window( + a_lds_block, make_tuple(number{}, number{}), {0, 0}); + + // A LDS tile for block GEMM + auto a_lds_gemm_window = make_tile_window( + a_lds_block, make_tuple(number{}, number{}), {0, 0}); + + // Block GEMM + auto block_flatmm = BlockFlatmm(); + + // B flat DRAM window for load + auto b_flat_distribution = + PipelinePolicy::template MakeBFlatDramTileDistribution(); + auto b_flat_dram_window = // tile_window_with_static_distribution + make_tile_window( + b_flat_dram_block_window_tmp.get_bottom_tensor_view(), // from kernel gemm_pad_views + make_tuple(number{}, number{}), + b_flat_dram_block_window_tmp.get_window_origin(), + b_flat_distribution); + + // Acc register tile + auto c_block_tile = decltype(block_flatmm(a_lds_gemm_window, b_flat_dram_window)){}; + + // prefetch + // global read 0 + auto a_block_tile = load_tile(a_copy_dram_window); + + { + // move to 1 + move_tile_window(a_copy_dram_window, {0, kKPerBlock}); + + // initialize C + tile_elementwise_inout([](auto& c) { c = 0; }, c_block_tile); + + // LDS write 0 + if constexpr(std::is_same_v) + { + auto a_shuffle_tmp = make_static_distributed_tensor( + PipelinePolicy::template MakeShuffledARegBlockDistribution()); + shuffle_tile(a_shuffle_tmp, a_block_tile); + const auto a_block_tile_tmp = tile_elementwise_in(a_element_func, a_shuffle_tmp); + store_tile(a_copy_lds_window, a_block_tile_tmp); + } + else + { + store_tile(a_copy_lds_window, tile_elementwise_in(a_element_func, a_block_tile)); + } + } + + index_t iCounter = num_loop - 1; + while(iCounter > 0) + { + // global read i + 1 + a_block_tile = load_tile(a_copy_dram_window); + + block_sync_lds(); + + // GEMM i + block_flatmm(c_block_tile, a_lds_gemm_window, b_flat_dram_window); + + block_sync_lds(); + + // move to i + 2 + move_tile_window(a_copy_dram_window, {0, kKPerBlock}); + + // LDS write i + 1 + const auto 
a_block_tile_tmp = tile_elementwise_in(a_element_func, a_block_tile); + store_tile(a_copy_lds_window, a_block_tile_tmp); + + // move to next flat K + move_tile_window(b_flat_dram_window, {0, BlockGemmShape::flatKPerBlock}); + + iCounter--; + } + + // tail + { + block_sync_lds(); + + // GEMM num_loop - 1 + block_flatmm(c_block_tile, a_lds_gemm_window, b_flat_dram_window); + } + + return c_block_tile; + } + + template + CK_TILE_DEVICE auto operator()(const ADramBlockWindowTmp& a_dram_block_window_tmp, + const BFlatBlockWindowTmp& b_flat_dram_block_window_tmp, + index_t num_loop, + void* p_smem) const + { + return operator()( + a_dram_block_window_tmp, + [](const ADataType& a) { return a; }, + b_flat_dram_block_window_tmp, + num_loop, + p_smem); + } +}; + +} // namespace ck_tile diff --git a/include/ck_tile/ops/flatmm/pipeline/flatmm_pipeline_agmem_bgmem_creg_v1_policy.hpp b/include/ck_tile/ops/flatmm/pipeline/flatmm_pipeline_agmem_bgmem_creg_v1_policy.hpp new file mode 100644 index 0000000000..d1aac07d54 --- /dev/null +++ b/include/ck_tile/ops/flatmm/pipeline/flatmm_pipeline_agmem_bgmem_creg_v1_policy.hpp @@ -0,0 +1,265 @@ +// SPDX-License-Identifier: MIT +// Copyright (c) 2025, Advanced Micro Devices, Inc. All rights reserved. 
+ +#pragma once + +#include "ck_tile/core.hpp" +#include "ck_tile/ops/gemm/warp/warp_gemm_dispatcher.hpp" + +namespace ck_tile { + +struct UniversalFlatmmPipelineAgBgCrPolicy +{ + static constexpr auto I0 = number<0>{}; + static constexpr auto I1 = number<1>{}; + static constexpr auto I2 = number<2>{}; + + // 3d + padding + template + CK_TILE_HOST_DEVICE static constexpr auto MakeALdsBlockDescriptor() + { + using namespace ck_tile; + + constexpr index_t kMPerBlock = Problem::BlockGemmShape::kM; + constexpr index_t kKPerBlock = Problem::BlockGemmShape::kK; + + constexpr auto a_lds_block_desc_0 = make_naive_tensor_descriptor( + make_tuple(number{}, number{}, number<8>{}), + make_tuple(number<(kMPerBlock + 1) * 8>{}, number<8>{}, number<1>{}), + number<8>{}, + number<1>{}); + + constexpr auto a_lds_block_desc = transform_tensor_descriptor( + a_lds_block_desc_0, + make_tuple(make_pass_through_transform(kMPerBlock), + make_merge_transform(make_tuple(kKPerBlock / 8, 8))), + make_tuple(sequence<1>{}, sequence<0, 2>{}), + make_tuple(sequence<0>{}, sequence<1>{})); + + return a_lds_block_desc; + } + + template + CK_TILE_HOST_DEVICE static constexpr index_t GetSmemSizeA() + { + constexpr index_t smem_size_a = sizeof(typename Problem::ADataType) * + MakeALdsBlockDescriptor().get_element_space_size(); + return smem_size_a; + } + + template + CK_TILE_HOST_DEVICE static constexpr index_t GetSmemSize() + { + constexpr index_t smem_size_a = GetSmemSizeA(); + + return smem_size_a; + } + + template + CK_TILE_HOST_DEVICE static constexpr auto GetSmemPackA() + { + return Problem::VectorLoadSize; + } + + template + CK_TILE_HOST_DEVICE static constexpr auto MakeADramTileDistribution() + { + using ADataType = remove_cvref_t; + using ALayout = remove_cvref_t; + + constexpr index_t BlockSize = Problem::kBlockSize; + + constexpr index_t MPerBlock = Problem::BlockGemmShape::kM; + constexpr index_t KPerBlock = Problem::BlockGemmShape::kK; + + if constexpr(std::is_same_v) + { + constexpr 
index_t M1 = Problem::VectorLoadSize / sizeof(ADataType); + constexpr index_t M0 = MPerBlock / M1; + constexpr index_t total_pixels = MPerBlock * KPerBlock / BlockSize; + static_assert(total_pixels % M1 == 0); + constexpr index_t K3 = total_pixels / M1; + constexpr index_t KPack = GetSmemPackA(); + static_assert(KPack % K3 == 0); + constexpr index_t K2 = KPack / K3; + if constexpr(get_warp_size() % (K2 * M0)) + { + constexpr index_t K1 = get_warp_size() / (K2 * M0); + constexpr index_t K0 = BlockSize / get_warp_size(); + static_assert(KPerBlock == K0 * K1 * K2 * K3); + return make_static_tile_distribution( + tile_distribution_encoding, + tuple, sequence>, + tuple, sequence<2, 1, 2>>, + tuple, sequence<1, 0, 2>>, + sequence<2, 1>, + sequence<3, 1>>{}); + } + else + { + constexpr index_t K1 = (K2 * M0) / get_warp_size(); + constexpr index_t K2_m = K2 / K1; + constexpr index_t K0 = BlockSize / get_warp_size() / K1; + static_assert(KPerBlock == K0 * K1 * K2_m * K3); + return make_static_tile_distribution( + tile_distribution_encoding, + tuple, sequence>, + tuple, sequence<1, 2>>, + tuple, sequence<0, 2>>, + sequence<2, 1>, + sequence<3, 1>>{}); + } + } + else + { + constexpr index_t K1 = 16 / sizeof(ADataType); + constexpr index_t K0 = KPerBlock / K1; + constexpr index_t M2 = get_warp_size() / K0; + // coalesce reading for each blocks + if constexpr(get_warp_size() % (M2 * K0) == 0) + { + constexpr index_t M1 = BlockSize / get_warp_size(); + static_assert(M2 != 0, "M2 is zero, which will lead to a division by zero error."); + static_assert(M1 != 0, "M1 is zero, which will lead to a division by zero error."); + constexpr index_t M0 = MPerBlock / (M2 * M1); + static_assert(M0 * M1 * M2 == MPerBlock, + "Incorrect M0, M2, M1 configuration! 
" + "M0, M1, M2 must cover whole MPerBlock!"); + + return make_static_tile_distribution( + tile_distribution_encoding, + tuple, sequence>, + tuple, sequence<1, 2>>, + tuple, sequence<2, 0>>, + sequence<1, 2>, + sequence<0, 1>>{}); + } + else + { + constexpr index_t M0 = BlockSize / get_warp_size(); + constexpr index_t M1 = MPerBlock / (M2 * M0); + static_assert(M0 * M1 * M2 == MPerBlock, + "Incorrect M0, M1, M2 configuration! " + "M0, M1, M2 must cover whole MPerBlock!"); + return make_static_tile_distribution( + tile_distribution_encoding, + tuple, sequence>, + tuple, sequence<1, 2>>, + tuple, sequence<2, 0>>, + sequence<1, 2>, + sequence<1, 1>>{}); + } + } + } + + template + CK_TILE_HOST_DEVICE static constexpr auto MakeBFlatDramTileDistribution() + { + using BDataType = remove_cvref_t; + + using TileShape = typename Problem::BlockGemmShape; // ck_tile::TileFlatmmShape + + constexpr index_t BlockSize = Problem::kBlockSize; + constexpr index_t WaveSize = get_warp_size(); + constexpr index_t WaveNum = BlockSize / WaveSize; + + constexpr index_t KBPerLoad = + Problem::VectorLoadSize / sizeof(BDataType); // dwordx4 load B elem cnt + constexpr index_t KThdPerWave = WaveSize; // threads cnt in K dim + constexpr index_t KWavePerBlk = 1; + constexpr index_t KRepeat = 1; + + constexpr index_t NBPerLoad = 1; + constexpr index_t NThdPerWave = 1; + constexpr index_t NWavePerBlk = TileShape::BlockWarps::at(TileShape::idxN); // N_Warp + constexpr index_t NRepeat = 1; + + constexpr index_t WaveRepeat = WaveNum / TileShape::flatNPerWarp; + + return make_static_tile_distribution( + tile_distribution_encoding< + sequence, // ? 
+ tuple, // second direction + sequence>, // first direction + // wave in blk, // thd in wave + // // + tuple, sequence<1, 2>>, // which direction + tuple, sequence<2, 2>>, // which index + // + sequence<1, 1, 2, 2>, + sequence<0, 3, 0, 3>>{}); + } + + template + CK_TILE_HOST_DEVICE static constexpr auto MakeShuffledARegBlockDistribution() + { + using ALayout = remove_cvref_t; + using ADataType = remove_cvref_t; + static_assert(std::is_same_v); + constexpr index_t kBlockSize = Problem::kBlockSize; + constexpr index_t kMPerBlock = Problem::BlockGemmShape::kM; + constexpr index_t kKPerBlock = Problem::BlockGemmShape::kK; + + constexpr index_t M1 = Problem::VectorLoadSize / sizeof(ADataType); + constexpr index_t M0 = kMPerBlock / M1; + constexpr index_t total_pixels = kMPerBlock * kKPerBlock / kBlockSize; + static_assert(total_pixels % M1 == 0); + constexpr index_t K3 = total_pixels / M1; + constexpr index_t kKPack = GetSmemPackA(); + static_assert(kKPack % K3 == 0); + constexpr index_t K2 = kKPack / K3; // TODO: this dimention could be outside single wave + constexpr index_t warp_size = get_warp_size(); + if constexpr(warp_size % (K2 * M0) == 0) + { + constexpr index_t K1 = warp_size / (K2 * M0); + constexpr index_t K0 = kBlockSize / warp_size; + + return make_static_tile_distribution( + tile_distribution_encoding, + tuple, sequence>, + tuple, sequence<2, 1, 2>>, + tuple, sequence<1, 0, 2>>, + sequence<1, 2>, + sequence<1, 3>>{}); + } + else + { + constexpr index_t K1 = (K2 * M0) / get_warp_size(); + constexpr index_t K2_m = K2 / K1; + constexpr index_t K0 = kBlockSize / get_warp_size() / K1; + static_assert(kKPerBlock == K0 * K1 * K2_m * K3); + return make_static_tile_distribution( + tile_distribution_encoding, + tuple, sequence>, + tuple, sequence<1, 2>>, + tuple, sequence<0, 2>>, + sequence<1, 2>, + sequence<1, 3>>{}); + } + } + + template + CK_TILE_HOST_DEVICE static constexpr auto GetBlockFlatmm() + { + using AccDataType = float; + using BlockWarps = typename 
Problem::BlockGemmShape::BlockWarps; + using WarpTile = typename Problem::BlockGemmShape::WarpTile; + using WarpGemm = WarpGemmMfmaDispatcher; + + using BlockFlatmmPolicy = + BlockFlatmmASmemBSmemCRegV1CustomPolicy; + return BlockFlatmmASmemBSmemCRegV1{}; + } +}; + +} // namespace ck_tile diff --git a/include/ck_tile/ops/flatmm/pipeline/tile_flatmm_shape.hpp b/include/ck_tile/ops/flatmm/pipeline/tile_flatmm_shape.hpp new file mode 100644 index 0000000000..551d390ec6 --- /dev/null +++ b/include/ck_tile/ops/flatmm/pipeline/tile_flatmm_shape.hpp @@ -0,0 +1,43 @@ +// SPDX-License-Identifier: MIT +// Copyright (c) 2018-2025, Advanced Micro Devices, Inc. All rights reserved. + +#pragma once + +#include "ck_tile/core.hpp" +#include "ck_tile/host/concat.hpp" + +namespace ck_tile { + +template +struct TileFlatmmShape +{ + using BlockTile = remove_cvref_t; + using BlockWarps = remove_cvref_t; + using WarpTile = remove_cvref_t; + + static constexpr auto idxM = number<0>{}; + static constexpr auto idxN = number<1>{}; + static constexpr auto idxK = number<2>{}; + + static constexpr index_t NumWarps = reduce_on_sequence(BlockWarps{}, multiplies{}, number<1>{}); + + static constexpr index_t kM = BlockTile::at(idxM); + static constexpr index_t kN = BlockTile::at(idxN); + static constexpr index_t kK = BlockTile::at(idxK); + + static constexpr index_t flatNPerWarp = BlockWarps::at(idxN); + static constexpr index_t flatKPerWarp = WarpTile::at(idxK) * WarpTile::at(idxN); + static constexpr index_t flatKPerBlock = flatKPerWarp * kK / WarpTile::at(idxK); + + CK_TILE_HOST static std::string GetName() + { + // clang-format off + return concat('_', "tile_flatmm_shape", + concat('x', kM, kN, kK, NumWarps), + concat('x', BlockWarps::at(idxM), BlockWarps::at(idxN), BlockWarps::at(idxK)), + concat('x', (WarpTile::at(idxM)), WarpTile::at(idxN), WarpTile::at(idxK))); + // clang-format on + } +}; + +} // namespace ck_tile From 7c32652e03d9a2015f5ab04c5193723869e2525e Mon Sep 17 00:00:00 2001 
From: aledudek Date: Wed, 16 Apr 2025 11:00:55 +0200 Subject: [PATCH 048/443] Add grouped conv fwd 3d GKCYX instances for f32, f16, bf16 (#2069) * Part1 * Add grouped conv fwd 3d GKCYX instances for f32, f16, bf16 * Add missing coma * Add missing cpp instance files * Fix 3d layout * Add missing closing bracket * Add missing comp x2 and part2 instances * Fix typo in instance name * fix * Fix --------- Co-authored-by: Bartlomiej Kocot --- .../gpu/grouped_convolution_forward.hpp | 64 ++++++++++- .../grouped_convolution_forward_comp_xdl.inc | 105 ++++++++++++++++++ ...uped_convolution_forward_mem_inter_xdl.inc | 49 ++++++++ ...uped_convolution_forward_mem_intra_xdl.inc | 49 ++++++++ .../gpu/grouped_convolution_forward_xdl.inc | 49 ++++++++ ..._convolution_forward_xdl_merged_groups.inc | 49 ++++++++ .../gpu/grouped_conv3d_fwd/CMakeLists.txt | 19 ++++ ...hw_gkczyx_ngkdhw_bf16_comp_2x_instance.cpp | 43 +++++++ ...gcdhw_gkczyx_ngkdhw_bf16_comp_instance.cpp | 54 +++++++++ ...gkczyx_ngkdhw_bf16_comp_part2_instance.cpp | 45 ++++++++ ...dhw_gkczyx_ngkdhw_f16_comp_2x_instance.cpp | 43 +++++++ ...ngcdhw_gkczyx_ngkdhw_f16_comp_instance.cpp | 54 +++++++++ ..._gkczyx_ngkdhw_f16_comp_part2_instance.cpp | 45 ++++++++ ...ngcdhw_gkczyx_ngkdhw_f32_comp_instance.cpp | 54 +++++++++ ...xdl_ngcdhw_gkczyx_ngkdhw_bf16_instance.cpp | 53 +++++++++ ..._xdl_ngcdhw_gkczyx_ngkdhw_f16_instance.cpp | 53 +++++++++ ..._xdl_ngcdhw_gkczyx_ngkdhw_f32_instance.cpp | 53 +++++++++ ..._gkczyx_ngkdhw_bf16_mem_inter_instance.cpp | 55 +++++++++ ..._gkczyx_ngkdhw_bf16_mem_intra_instance.cpp | 55 +++++++++ ...w_gkczyx_ngkdhw_f16_mem_inter_instance.cpp | 55 +++++++++ ...w_gkczyx_ngkdhw_f16_mem_intra_instance.cpp | 55 +++++++++ ...w_gkczyx_ngkdhw_f32_mem_inter_instance.cpp | 55 +++++++++ ...w_gkczyx_ngkdhw_f32_mem_intra_instance.cpp | 55 +++++++++ ...ups_ngcdhw_gkczyx_ngkdhw_bf16_instance.cpp | 47 ++++++++ ...oups_ngcdhw_gkczyx_ngkdhw_f16_instance.cpp | 47 ++++++++ ...oups_ngcdhw_gkczyx_ngkdhw_f32_instance.cpp | 47 
++++++++ profiler/src/profile_grouped_conv_fwd.cpp | 27 ++++- .../test_grouped_convnd_fwd.cpp | 4 +- 28 files changed, 1377 insertions(+), 6 deletions(-) create mode 100644 library/src/tensor_operation_instance/gpu/grouped_conv3d_fwd/xdl/comp/device_grouped_conv3d_fwd_xdl_ngcdhw_gkczyx_ngkdhw_bf16_comp_2x_instance.cpp create mode 100644 library/src/tensor_operation_instance/gpu/grouped_conv3d_fwd/xdl/comp/device_grouped_conv3d_fwd_xdl_ngcdhw_gkczyx_ngkdhw_bf16_comp_instance.cpp create mode 100644 library/src/tensor_operation_instance/gpu/grouped_conv3d_fwd/xdl/comp/device_grouped_conv3d_fwd_xdl_ngcdhw_gkczyx_ngkdhw_bf16_comp_part2_instance.cpp create mode 100644 library/src/tensor_operation_instance/gpu/grouped_conv3d_fwd/xdl/comp/device_grouped_conv3d_fwd_xdl_ngcdhw_gkczyx_ngkdhw_f16_comp_2x_instance.cpp create mode 100644 library/src/tensor_operation_instance/gpu/grouped_conv3d_fwd/xdl/comp/device_grouped_conv3d_fwd_xdl_ngcdhw_gkczyx_ngkdhw_f16_comp_instance.cpp create mode 100644 library/src/tensor_operation_instance/gpu/grouped_conv3d_fwd/xdl/comp/device_grouped_conv3d_fwd_xdl_ngcdhw_gkczyx_ngkdhw_f16_comp_part2_instance.cpp create mode 100644 library/src/tensor_operation_instance/gpu/grouped_conv3d_fwd/xdl/comp/device_grouped_conv3d_fwd_xdl_ngcdhw_gkczyx_ngkdhw_f32_comp_instance.cpp create mode 100644 library/src/tensor_operation_instance/gpu/grouped_conv3d_fwd/xdl/device_grouped_conv3d_fwd_xdl_ngcdhw_gkczyx_ngkdhw_bf16_instance.cpp create mode 100644 library/src/tensor_operation_instance/gpu/grouped_conv3d_fwd/xdl/device_grouped_conv3d_fwd_xdl_ngcdhw_gkczyx_ngkdhw_f16_instance.cpp create mode 100644 library/src/tensor_operation_instance/gpu/grouped_conv3d_fwd/xdl/device_grouped_conv3d_fwd_xdl_ngcdhw_gkczyx_ngkdhw_f32_instance.cpp create mode 100644 library/src/tensor_operation_instance/gpu/grouped_conv3d_fwd/xdl/mem/device_grouped_conv3d_fwd_xdl_ngcdhw_gkczyx_ngkdhw_bf16_mem_inter_instance.cpp create mode 100644 
library/src/tensor_operation_instance/gpu/grouped_conv3d_fwd/xdl/mem/device_grouped_conv3d_fwd_xdl_ngcdhw_gkczyx_ngkdhw_bf16_mem_intra_instance.cpp create mode 100644 library/src/tensor_operation_instance/gpu/grouped_conv3d_fwd/xdl/mem/device_grouped_conv3d_fwd_xdl_ngcdhw_gkczyx_ngkdhw_f16_mem_inter_instance.cpp create mode 100644 library/src/tensor_operation_instance/gpu/grouped_conv3d_fwd/xdl/mem/device_grouped_conv3d_fwd_xdl_ngcdhw_gkczyx_ngkdhw_f16_mem_intra_instance.cpp create mode 100644 library/src/tensor_operation_instance/gpu/grouped_conv3d_fwd/xdl/mem/device_grouped_conv3d_fwd_xdl_ngcdhw_gkczyx_ngkdhw_f32_mem_inter_instance.cpp create mode 100644 library/src/tensor_operation_instance/gpu/grouped_conv3d_fwd/xdl/mem/device_grouped_conv3d_fwd_xdl_ngcdhw_gkczyx_ngkdhw_f32_mem_intra_instance.cpp create mode 100644 library/src/tensor_operation_instance/gpu/grouped_conv3d_fwd/xdl/merged_groups/device_grouped_conv3d_fwd_xdl_merged_groups_ngcdhw_gkczyx_ngkdhw_bf16_instance.cpp create mode 100644 library/src/tensor_operation_instance/gpu/grouped_conv3d_fwd/xdl/merged_groups/device_grouped_conv3d_fwd_xdl_merged_groups_ngcdhw_gkczyx_ngkdhw_f16_instance.cpp create mode 100644 library/src/tensor_operation_instance/gpu/grouped_conv3d_fwd/xdl/merged_groups/device_grouped_conv3d_fwd_xdl_merged_groups_ngcdhw_gkczyx_ngkdhw_f32_instance.cpp diff --git a/library/include/ck/library/tensor_operation_instance/gpu/grouped_convolution_forward.hpp b/library/include/ck/library/tensor_operation_instance/gpu/grouped_convolution_forward.hpp index 0b7df6ecfb..638a3f98a3 100644 --- a/library/include/ck/library/tensor_operation_instance/gpu/grouped_convolution_forward.hpp +++ b/library/include/ck/library/tensor_operation_instance/gpu/grouped_convolution_forward.hpp @@ -523,7 +523,69 @@ struct DeviceOperationInstanceFactory && + is_same_v && is_same_v) + { +#ifdef CK_ENABLE_FP32 + if constexpr(is_same_v && is_same_v && + is_same_v && is_same_v && + is_same_v) + { + 
add_device_grouped_conv3d_fwd_xdl_merged_groups_ngcdhw_gkczyx_ngkdhw_f32_instances( + op_ptrs); + add_device_grouped_conv3d_fwd_xdl_ngcdhw_gkczyx_ngkdhw_f32_instances(op_ptrs); + add_device_grouped_conv3d_fwd_xdl_ngcdhw_gkczyx_ngkdhw_f32_comp_instances(op_ptrs); + add_device_grouped_conv3d_fwd_xdl_ngcdhw_gkczyx_ngkdhw_f32_mem_intra_instances( + op_ptrs); + add_device_grouped_conv3d_fwd_xdl_ngcdhw_gkczyx_ngkdhw_f32_mem_inter_instances( + op_ptrs); + } #endif +#ifdef CK_ENABLE_FP16 + if constexpr(is_same_v && is_same_v && + is_same_v && is_same_v && + is_same_v) + { + add_device_grouped_conv3d_fwd_xdl_merged_groups_ngcdhw_gkczyx_ngkdhw_f16_instances( + op_ptrs); + add_device_grouped_conv3d_fwd_xdl_ngcdhw_gkczyx_ngkdhw_f16_instances(op_ptrs); + add_device_grouped_conv3d_fwd_xdl_ngcdhw_gkczyx_ngkdhw_f16_comp_instances(op_ptrs); + add_device_grouped_conv3d_fwd_xdl_ngcdhw_gkczyx_ngkdhw_f16_comp_2x_instances( + op_ptrs); + add_device_grouped_conv3d_fwd_xdl_ngcdhw_gkczyx_ngkdhw_f16_comp_part2_instances( + op_ptrs); + add_device_grouped_conv3d_fwd_xdl_ngcdhw_gkczyx_ngkdhw_f16_mem_intra_instances( + op_ptrs); + add_device_grouped_conv3d_fwd_xdl_ngcdhw_gkczyx_ngkdhw_f16_mem_inter_instances( + op_ptrs); + } +#endif +#ifdef CK_ENABLE_BF16 + if constexpr(is_same_v && + is_same_v && + is_same_v && + is_same_v && + is_same_v) + { + add_device_grouped_conv3d_fwd_xdl_merged_groups_ngcdhw_gkczyx_ngkdhw_bf16_instances( + op_ptrs); + add_device_grouped_conv3d_fwd_xdl_ngcdhw_gkczyx_ngkdhw_bf16_instances(op_ptrs); + add_device_grouped_conv3d_fwd_xdl_ngcdhw_gkczyx_ngkdhw_bf16_comp_instances(op_ptrs); + add_device_grouped_conv3d_fwd_xdl_ngcdhw_gkczyx_ngkdhw_bf16_comp_2x_instances( + op_ptrs); + add_device_grouped_conv3d_fwd_xdl_ngcdhw_gkczyx_ngkdhw_bf16_comp_part2_instances( + op_ptrs); + add_device_grouped_conv3d_fwd_xdl_ngcdhw_gkczyx_ngkdhw_bf16_mem_intra_instances( + op_ptrs); + add_device_grouped_conv3d_fwd_xdl_ngcdhw_gkczyx_ngkdhw_bf16_mem_inter_instances( + op_ptrs); + } +#endif + } 
+ +#endif // CK_USE_XDL #ifdef CK_USE_WMMA if constexpr(NumDimSpatial == 2 && is_same_v && @@ -639,7 +701,7 @@ struct DeviceOperationInstanceFactory>>& instances); #endif +// grouped conv3d forward, NGCDHW/GKCZYX/NGKDHW +#ifdef CK_ENABLE_FP16 +void add_device_grouped_conv3d_fwd_xdl_ngcdhw_gkczyx_ngkdhw_f16_comp_instances( + std::vector>>& instances); + +void add_device_grouped_conv3d_fwd_xdl_ngcdhw_gkczyx_ngkdhw_f16_comp_2x_instances( + std::vector>>& instances); + +void add_device_grouped_conv3d_fwd_xdl_ngcdhw_gkczyx_ngkdhw_f16_comp_part2_instances( + std::vector>>& instances); +#endif + +#ifdef CK_ENABLE_BF16 +void add_device_grouped_conv3d_fwd_xdl_ngcdhw_gkczyx_ngkdhw_bf16_comp_instances( + std::vector>>& instances); + +void add_device_grouped_conv3d_fwd_xdl_ngcdhw_gkczyx_ngkdhw_bf16_comp_2x_instances( + std::vector>>& instances); + +void add_device_grouped_conv3d_fwd_xdl_ngcdhw_gkczyx_ngkdhw_bf16_comp_part2_instances( + std::vector>>& instances); +#endif + +#ifdef CK_ENABLE_FP32 +void add_device_grouped_conv3d_fwd_xdl_ngcdhw_gkczyx_ngkdhw_f32_comp_instances( + std::vector>>& instances); +#endif + } // namespace instance } // namespace device } // namespace tensor_operation diff --git a/library/include/ck/library/tensor_operation_instance/gpu/grouped_convolution_forward_mem_inter_xdl.inc b/library/include/ck/library/tensor_operation_instance/gpu/grouped_convolution_forward_mem_inter_xdl.inc index 3900c7a0fb..00351ceefd 100644 --- a/library/include/ck/library/tensor_operation_instance/gpu/grouped_convolution_forward_mem_inter_xdl.inc +++ b/library/include/ck/library/tensor_operation_instance/gpu/grouped_convolution_forward_mem_inter_xdl.inc @@ -171,6 +171,55 @@ void add_device_grouped_conv3d_fwd_xdl_ndhwgc_gkzyxc_ndhwgk_f32_mem_inter_instan PassThrough>>>& instances); #endif +// grouped conv3d forward, NGCDHW/GKCZYX/NGKDHW +#ifdef CK_ENABLE_FP16 +void add_device_grouped_conv3d_fwd_xdl_ngcdhw_gkczyx_ngkdhw_f16_mem_inter_instances( + std::vector>>& instances); 
+#endif + +#ifdef CK_ENABLE_BF16 +void add_device_grouped_conv3d_fwd_xdl_ngcdhw_gkczyx_ngkdhw_bf16_mem_inter_instances( + std::vector>>& instances); +#endif + +#ifdef CK_ENABLE_FP32 +void add_device_grouped_conv3d_fwd_xdl_ngcdhw_gkczyx_ngkdhw_f32_mem_inter_instances( + std::vector>>& instances); +#endif + } // namespace instance } // namespace device } // namespace tensor_operation diff --git a/library/include/ck/library/tensor_operation_instance/gpu/grouped_convolution_forward_mem_intra_xdl.inc b/library/include/ck/library/tensor_operation_instance/gpu/grouped_convolution_forward_mem_intra_xdl.inc index b7815f5023..bd44116057 100644 --- a/library/include/ck/library/tensor_operation_instance/gpu/grouped_convolution_forward_mem_intra_xdl.inc +++ b/library/include/ck/library/tensor_operation_instance/gpu/grouped_convolution_forward_mem_intra_xdl.inc @@ -171,6 +171,55 @@ void add_device_grouped_conv3d_fwd_xdl_ndhwgc_gkzyxc_ndhwgk_f32_mem_intra_instan PassThrough>>>& instances); #endif +// grouped conv3d forward, NGCDHW/GKCZYX/NGKDHW +#ifdef CK_ENABLE_FP16 +void add_device_grouped_conv3d_fwd_xdl_ngcdhw_gkczyx_ngkdhw_f16_mem_intra_instances( + std::vector>>& instances); +#endif + +#ifdef CK_ENABLE_BF16 +void add_device_grouped_conv3d_fwd_xdl_ngcdhw_gkczyx_ngkdhw_bf16_mem_intra_instances( + std::vector>>& instances); +#endif + +#ifdef CK_ENABLE_FP32 +void add_device_grouped_conv3d_fwd_xdl_ngcdhw_gkczyx_ngkdhw_f32_mem_intra_instances( + std::vector>>& instances); +#endif + } // namespace instance } // namespace device } // namespace tensor_operation diff --git a/library/include/ck/library/tensor_operation_instance/gpu/grouped_convolution_forward_xdl.inc b/library/include/ck/library/tensor_operation_instance/gpu/grouped_convolution_forward_xdl.inc index b934b9aef6..d3624b0fd9 100644 --- a/library/include/ck/library/tensor_operation_instance/gpu/grouped_convolution_forward_xdl.inc +++ 
b/library/include/ck/library/tensor_operation_instance/gpu/grouped_convolution_forward_xdl.inc @@ -517,6 +517,55 @@ void add_device_grouped_conv3d_fwd_xdl_ndhwgc_gkzyxc_ndhwgk_bf8_f8_instances( F8>>>& instances); #endif +// grouped conv3d forward, NGCDHW/GKCZYX/NGKDHW +#ifdef CK_ENABLE_FP16 +void add_device_grouped_conv3d_fwd_xdl_ngcdhw_gkczyx_ngkdhw_f16_instances( + std::vector>>& instances); +#endif + +#ifdef CK_ENABLE_BF16 +void add_device_grouped_conv3d_fwd_xdl_ngcdhw_gkczyx_ngkdhw_bf16_instances( + std::vector>>& instances); +#endif + +#ifdef CK_ENABLE_FP32 +void add_device_grouped_conv3d_fwd_xdl_ngcdhw_gkczyx_ngkdhw_f32_instances( + std::vector>>& instances); +#endif + } // namespace instance } // namespace device } // namespace tensor_operation diff --git a/library/include/ck/library/tensor_operation_instance/gpu/grouped_convolution_forward_xdl_merged_groups.inc b/library/include/ck/library/tensor_operation_instance/gpu/grouped_convolution_forward_xdl_merged_groups.inc index 966b883301..9f54c4b633 100644 --- a/library/include/ck/library/tensor_operation_instance/gpu/grouped_convolution_forward_xdl_merged_groups.inc +++ b/library/include/ck/library/tensor_operation_instance/gpu/grouped_convolution_forward_xdl_merged_groups.inc @@ -178,6 +178,55 @@ void add_device_grouped_conv3d_fwd_xdl_merged_groups_ndhwgc_gkzyxc_ndhwgk_f32_in PassThrough>>>& instances); #endif +// grouped conv3d forward, NGCDHW/GKCZYX/NGKDHW +#ifdef CK_ENABLE_BF16 +void add_device_grouped_conv3d_fwd_xdl_merged_groups_ngcdhw_gkczyx_ngkdhw_bf16_instances( + std::vector>>& instances); +#endif + +#ifdef CK_ENABLE_FP16 +void add_device_grouped_conv3d_fwd_xdl_merged_groups_ngcdhw_gkczyx_ngkdhw_f16_instances( + std::vector>>& instances); +#endif + +#ifdef CK_ENABLE_FP32 +void add_device_grouped_conv3d_fwd_xdl_merged_groups_ngcdhw_gkczyx_ngkdhw_f32_instances( + std::vector>>& instances); +#endif + } // namespace instance } // namespace device } // namespace tensor_operation diff --git 
a/library/src/tensor_operation_instance/gpu/grouped_conv3d_fwd/CMakeLists.txt b/library/src/tensor_operation_instance/gpu/grouped_conv3d_fwd/CMakeLists.txt index 1e572f9ceb..7b9ccf6609 100644 --- a/library/src/tensor_operation_instance/gpu/grouped_conv3d_fwd/CMakeLists.txt +++ b/library/src/tensor_operation_instance/gpu/grouped_conv3d_fwd/CMakeLists.txt @@ -8,6 +8,9 @@ set(GROUPED_CONV3D_FWD xdl/device_grouped_conv3d_fwd_xdl_ndhwgc_gkzyxc_ndhwgk_f16_instance.cpp xdl/device_grouped_conv3d_fwd_xdl_ndhwgc_gkzyxc_ndhwgk_f32_instance.cpp xdl/device_grouped_conv3d_fwd_xdl_ndhwgc_gkzyxc_ndhwgk_int8_instance.cpp + xdl/device_grouped_conv3d_fwd_xdl_ngcdhw_gkczyx_ngkdhw_bf16_instance.cpp + xdl/device_grouped_conv3d_fwd_xdl_ngcdhw_gkczyx_ngkdhw_f16_instance.cpp + xdl/device_grouped_conv3d_fwd_xdl_ngcdhw_gkczyx_ngkdhw_f32_instance.cpp xdl/large_tensor/device_grouped_conv3d_fwd_xdl_large_tensor_ndhwgc_gkzyxc_ndhwgk_bf16_instance.cpp xdl/large_tensor/device_grouped_conv3d_fwd_xdl_large_tensor_ndhwgc_gkzyxc_ndhwgk_f16_instance.cpp @@ -16,18 +19,34 @@ set(GROUPED_CONV3D_FWD xdl/merged_groups/device_grouped_conv3d_fwd_xdl_merged_groups_ndhwgc_gkzyxc_ndhwgk_bf16_instance.cpp xdl/merged_groups/device_grouped_conv3d_fwd_xdl_merged_groups_ndhwgc_gkzyxc_ndhwgk_f16_instance.cpp xdl/merged_groups/device_grouped_conv3d_fwd_xdl_merged_groups_ndhwgc_gkzyxc_ndhwgk_f32_instance.cpp + xdl/merged_groups/device_grouped_conv3d_fwd_xdl_merged_groups_ngcdhw_gkczyx_ngkdhw_bf16_instance.cpp + xdl/merged_groups/device_grouped_conv3d_fwd_xdl_merged_groups_ngcdhw_gkczyx_ngkdhw_f16_instance.cpp + xdl/merged_groups/device_grouped_conv3d_fwd_xdl_merged_groups_ngcdhw_gkczyx_ngkdhw_f32_instance.cpp xdl/mem/device_grouped_conv3d_fwd_xdl_ndhwgc_gkzyxc_ndhwgk_bf16_mem_inter_instance.cpp xdl/mem/device_grouped_conv3d_fwd_xdl_ndhwgc_gkzyxc_ndhwgk_f16_mem_inter_instance.cpp xdl/mem/device_grouped_conv3d_fwd_xdl_ndhwgc_gkzyxc_ndhwgk_f32_mem_inter_instance.cpp + 
xdl/mem/device_grouped_conv3d_fwd_xdl_ngcdhw_gkczyx_ngkdhw_bf16_mem_inter_instance.cpp + xdl/mem/device_grouped_conv3d_fwd_xdl_ngcdhw_gkczyx_ngkdhw_f16_mem_inter_instance.cpp + xdl/mem/device_grouped_conv3d_fwd_xdl_ngcdhw_gkczyx_ngkdhw_f32_mem_inter_instance.cpp xdl/mem/device_grouped_conv3d_fwd_xdl_ndhwgc_gkzyxc_ndhwgk_bf16_mem_intra_instance.cpp xdl/mem/device_grouped_conv3d_fwd_xdl_ndhwgc_gkzyxc_ndhwgk_f16_mem_intra_instance.cpp xdl/mem/device_grouped_conv3d_fwd_xdl_ndhwgc_gkzyxc_ndhwgk_f32_mem_intra_instance.cpp + xdl/mem/device_grouped_conv3d_fwd_xdl_ngcdhw_gkczyx_ngkdhw_bf16_mem_intra_instance.cpp + xdl/mem/device_grouped_conv3d_fwd_xdl_ngcdhw_gkczyx_ngkdhw_f16_mem_intra_instance.cpp + xdl/mem/device_grouped_conv3d_fwd_xdl_ngcdhw_gkczyx_ngkdhw_f32_mem_intra_instance.cpp xdl/comp/device_grouped_conv3d_fwd_xdl_ndhwgc_gkzyxc_ndhwgk_bf16_comp_instance.cpp xdl/comp/device_grouped_conv3d_fwd_xdl_ndhwgc_gkzyxc_ndhwgk_f16_comp_instance.cpp xdl/comp/device_grouped_conv3d_fwd_xdl_ndhwgc_gkzyxc_ndhwgk_f32_comp_instance.cpp + xdl/comp/device_grouped_conv3d_fwd_xdl_ngcdhw_gkczyx_ngkdhw_bf16_comp_instance.cpp + xdl/comp/device_grouped_conv3d_fwd_xdl_ngcdhw_gkczyx_ngkdhw_f16_comp_instance.cpp + xdl/comp/device_grouped_conv3d_fwd_xdl_ngcdhw_gkczyx_ngkdhw_f32_comp_instance.cpp + xdl/comp/device_grouped_conv3d_fwd_xdl_ngcdhw_gkczyx_ngkdhw_bf16_comp_2x_instance.cpp + xdl/comp/device_grouped_conv3d_fwd_xdl_ngcdhw_gkczyx_ngkdhw_f16_comp_2x_instance.cpp + xdl/comp/device_grouped_conv3d_fwd_xdl_ngcdhw_gkczyx_ngkdhw_bf16_comp_part2_instance.cpp + xdl/comp/device_grouped_conv3d_fwd_xdl_ngcdhw_gkczyx_ngkdhw_f16_comp_part2_instance.cpp wmma/device_grouped_conv3d_fwd_wmma_gndhwc_gkzyxc_gndhwk_f16_instance.cpp wmma/device_grouped_conv3d_fwd_wmma_gndhwc_gkzyxc_gndhwk_i8_instance.cpp diff --git a/library/src/tensor_operation_instance/gpu/grouped_conv3d_fwd/xdl/comp/device_grouped_conv3d_fwd_xdl_ngcdhw_gkczyx_ngkdhw_bf16_comp_2x_instance.cpp 
b/library/src/tensor_operation_instance/gpu/grouped_conv3d_fwd/xdl/comp/device_grouped_conv3d_fwd_xdl_ngcdhw_gkczyx_ngkdhw_bf16_comp_2x_instance.cpp new file mode 100644 index 0000000000..3e1a2dd48b --- /dev/null +++ b/library/src/tensor_operation_instance/gpu/grouped_conv3d_fwd/xdl/comp/device_grouped_conv3d_fwd_xdl_ngcdhw_gkczyx_ngkdhw_bf16_comp_2x_instance.cpp @@ -0,0 +1,43 @@ +// SPDX-License-Identifier: MIT +// Copyright (c) 2025, Advanced Micro Devices, Inc. All rights reserved. + +#include "ck/library/tensor_operation_instance/add_device_operation_instance.hpp" +#include "ck/library/tensor_operation_instance/gpu/grouped_conv_fwd/device_grouped_conv_fwd_xdl_comp_instance.hpp" +#include "ck/host_utility/device_prop.hpp" + +namespace ck { +namespace tensor_operation { +namespace device { +namespace instance { +// Compilation parameters for in[n, hi, wi, g, c] * wei[g, k, y, x, c] = out[n, ho, wo, g, k] +void add_device_grouped_conv3d_fwd_xdl_ngcdhw_gkczyx_ngkdhw_bf16_comp_2x_instances( + std::vector>>& instances) +{ + if(ck::get_device_name() == "gfx950") + { + add_device_operation_instances( + instances, + device_grouped_conv_fwd_xdl_bf16_comp_instances_2x<3, + NGCDHW, + GKCZYX, + Empty_Tuple, + NGKDHW, + ConvFwdDefault>{}); + } +} + +} // namespace instance +} // namespace device +} // namespace tensor_operation +} // namespace ck diff --git a/library/src/tensor_operation_instance/gpu/grouped_conv3d_fwd/xdl/comp/device_grouped_conv3d_fwd_xdl_ngcdhw_gkczyx_ngkdhw_bf16_comp_instance.cpp b/library/src/tensor_operation_instance/gpu/grouped_conv3d_fwd/xdl/comp/device_grouped_conv3d_fwd_xdl_ngcdhw_gkczyx_ngkdhw_bf16_comp_instance.cpp new file mode 100644 index 0000000000..43241454a5 --- /dev/null +++ b/library/src/tensor_operation_instance/gpu/grouped_conv3d_fwd/xdl/comp/device_grouped_conv3d_fwd_xdl_ngcdhw_gkczyx_ngkdhw_bf16_comp_instance.cpp @@ -0,0 +1,54 @@ +// SPDX-License-Identifier: MIT +// Copyright (c) 2025, Advanced Micro Devices, Inc. All rights reserved. 
+ +#include "ck/library/tensor_operation_instance/gpu/grouped_conv_fwd/device_grouped_conv_fwd_xdl_comp_instance.hpp" +#include "ck/library/tensor_operation_instance/add_device_operation_instance.hpp" + +namespace ck { +namespace tensor_operation { +namespace device { +namespace instance { + +void add_device_grouped_conv3d_fwd_xdl_ngcdhw_gkczyx_ngkdhw_bf16_comp_instances( + std::vector>>& instances) +{ + add_device_operation_instances( + instances, + device_grouped_conv_fwd_xdl_bf16_comp_instances<3, + NGCDHW, + GKCZYX, + Empty_Tuple, + NGKDHW, + ConvFwdDefault>{}); + add_device_operation_instances(instances, + device_grouped_conv_fwd_xdl_bf16_comp_instances<3, + NGCDHW, + GKCZYX, + Empty_Tuple, + NGKDHW, + ConvFwd1x1P0>{}); + add_device_operation_instances( + instances, + device_grouped_conv_fwd_xdl_bf16_comp_instances<3, + NGCDHW, + GKCZYX, + Empty_Tuple, + NGKDHW, + ConvFwd1x1S1P0>{}); +} + +} // namespace instance +} // namespace device +} // namespace tensor_operation +} // namespace ck diff --git a/library/src/tensor_operation_instance/gpu/grouped_conv3d_fwd/xdl/comp/device_grouped_conv3d_fwd_xdl_ngcdhw_gkczyx_ngkdhw_bf16_comp_part2_instance.cpp b/library/src/tensor_operation_instance/gpu/grouped_conv3d_fwd/xdl/comp/device_grouped_conv3d_fwd_xdl_ngcdhw_gkczyx_ngkdhw_bf16_comp_part2_instance.cpp new file mode 100644 index 0000000000..85a1c9137c --- /dev/null +++ b/library/src/tensor_operation_instance/gpu/grouped_conv3d_fwd/xdl/comp/device_grouped_conv3d_fwd_xdl_ngcdhw_gkczyx_ngkdhw_bf16_comp_part2_instance.cpp @@ -0,0 +1,45 @@ +// SPDX-License-Identifier: MIT +// Copyright (c) 2025, Advanced Micro Devices, Inc. All rights reserved. 
+ +#include "ck/library/tensor_operation_instance/add_device_operation_instance.hpp" +#include "ck/library/tensor_operation_instance/gpu/grouped_conv_fwd/device_grouped_conv_fwd_xdl_comp_instance.hpp" +#include "ck/host_utility/device_prop.hpp" + +namespace ck { +namespace tensor_operation { +namespace device { +namespace instance { +// Compilation parameters for in[n, hi, wi, g, c] * wei[g, k, y, x, c] = out[n, ho, wo, g, k] +void add_device_grouped_conv3d_fwd_xdl_ngcdhw_gkczyx_ngkdhw_bf16_comp_part2_instances( + std::vector>>&) +{ + if(ck::get_device_name() != "gfx950") + { +#if 0 // TODO: Improve compilation time and enable these instances + add_device_operation_instances( + instances, + device_grouped_conv_fwd_xdl_bf16_comp_instances_part2<3, + NGCDHW, + GKCZYX, + Empty_Tuple, + NGKDHW, + ConvFwdDefault>{}); +#endif + } +} + +} // namespace instance +} // namespace device +} // namespace tensor_operation +} // namespace ck diff --git a/library/src/tensor_operation_instance/gpu/grouped_conv3d_fwd/xdl/comp/device_grouped_conv3d_fwd_xdl_ngcdhw_gkczyx_ngkdhw_f16_comp_2x_instance.cpp b/library/src/tensor_operation_instance/gpu/grouped_conv3d_fwd/xdl/comp/device_grouped_conv3d_fwd_xdl_ngcdhw_gkczyx_ngkdhw_f16_comp_2x_instance.cpp new file mode 100644 index 0000000000..9b8bf4fa42 --- /dev/null +++ b/library/src/tensor_operation_instance/gpu/grouped_conv3d_fwd/xdl/comp/device_grouped_conv3d_fwd_xdl_ngcdhw_gkczyx_ngkdhw_f16_comp_2x_instance.cpp @@ -0,0 +1,43 @@ +// SPDX-License-Identifier: MIT +// Copyright (c) 2025, Advanced Micro Devices, Inc. All rights reserved. 
+ +#include "ck/library/tensor_operation_instance/add_device_operation_instance.hpp" +#include "ck/library/tensor_operation_instance/gpu/grouped_conv_fwd/device_grouped_conv_fwd_xdl_comp_instance.hpp" +#include "ck/host_utility/device_prop.hpp" + +namespace ck { +namespace tensor_operation { +namespace device { +namespace instance { +// Compilation parameters for in[n, hi, wi, g, c] * wei[g, k, y, x, c] = out[n, ho, wo, g, k] +void add_device_grouped_conv3d_fwd_xdl_ngcdhw_gkczyx_ngkdhw_f16_comp_2x_instances( + std::vector>>& instances) +{ + if(ck::get_device_name() == "gfx950") + { + add_device_operation_instances( + instances, + device_grouped_conv_fwd_xdl_f16_comp_instances_2x<3, + NGCDHW, + GKCZYX, + Empty_Tuple, + NGKDHW, + ConvFwdDefault>{}); + } +} + +} // namespace instance +} // namespace device +} // namespace tensor_operation +} // namespace ck diff --git a/library/src/tensor_operation_instance/gpu/grouped_conv3d_fwd/xdl/comp/device_grouped_conv3d_fwd_xdl_ngcdhw_gkczyx_ngkdhw_f16_comp_instance.cpp b/library/src/tensor_operation_instance/gpu/grouped_conv3d_fwd/xdl/comp/device_grouped_conv3d_fwd_xdl_ngcdhw_gkczyx_ngkdhw_f16_comp_instance.cpp new file mode 100644 index 0000000000..d02d9f6778 --- /dev/null +++ b/library/src/tensor_operation_instance/gpu/grouped_conv3d_fwd/xdl/comp/device_grouped_conv3d_fwd_xdl_ngcdhw_gkczyx_ngkdhw_f16_comp_instance.cpp @@ -0,0 +1,54 @@ +// SPDX-License-Identifier: MIT +// Copyright (c) 2025, Advanced Micro Devices, Inc. All rights reserved. 
+ +#include "ck/library/tensor_operation_instance/gpu/grouped_conv_fwd/device_grouped_conv_fwd_xdl_comp_instance.hpp" +#include "ck/library/tensor_operation_instance/add_device_operation_instance.hpp" + +namespace ck { +namespace tensor_operation { +namespace device { +namespace instance { + +void add_device_grouped_conv3d_fwd_xdl_ngcdhw_gkczyx_ngkdhw_f16_comp_instances( + std::vector>>& instances) +{ + add_device_operation_instances( + instances, + device_grouped_conv_fwd_xdl_f16_comp_instances<3, + NGCDHW, + GKCZYX, + Empty_Tuple, + NGKDHW, + ConvFwdDefault>{}); + add_device_operation_instances(instances, + device_grouped_conv_fwd_xdl_f16_comp_instances<3, + NGCDHW, + GKCZYX, + Empty_Tuple, + NGKDHW, + ConvFwd1x1P0>{}); + add_device_operation_instances( + instances, + device_grouped_conv_fwd_xdl_f16_comp_instances<3, + NGCDHW, + GKCZYX, + Empty_Tuple, + NGKDHW, + ConvFwd1x1S1P0>{}); +} + +} // namespace instance +} // namespace device +} // namespace tensor_operation +} // namespace ck diff --git a/library/src/tensor_operation_instance/gpu/grouped_conv3d_fwd/xdl/comp/device_grouped_conv3d_fwd_xdl_ngcdhw_gkczyx_ngkdhw_f16_comp_part2_instance.cpp b/library/src/tensor_operation_instance/gpu/grouped_conv3d_fwd/xdl/comp/device_grouped_conv3d_fwd_xdl_ngcdhw_gkczyx_ngkdhw_f16_comp_part2_instance.cpp new file mode 100644 index 0000000000..eaac75ee9e --- /dev/null +++ b/library/src/tensor_operation_instance/gpu/grouped_conv3d_fwd/xdl/comp/device_grouped_conv3d_fwd_xdl_ngcdhw_gkczyx_ngkdhw_f16_comp_part2_instance.cpp @@ -0,0 +1,45 @@ +// SPDX-License-Identifier: MIT +// Copyright (c) 2025, Advanced Micro Devices, Inc. All rights reserved. 
+ +#include "ck/library/tensor_operation_instance/add_device_operation_instance.hpp" +#include "ck/library/tensor_operation_instance/gpu/grouped_conv_fwd/device_grouped_conv_fwd_xdl_comp_instance.hpp" +#include "ck/host_utility/device_prop.hpp" + +namespace ck { +namespace tensor_operation { +namespace device { +namespace instance { +// Compilation parameters for in[n, hi, wi, g, c] * wei[g, k, y, x, c] = out[n, ho, wo, g, k] +void add_device_grouped_conv3d_fwd_xdl_ngcdhw_gkczyx_ngkdhw_f16_comp_part2_instances( + std::vector>>&) +{ + if(ck::get_device_name() != "gfx950") + { +#if 0 // TODO: Improve compilation time and enable these instances + add_device_operation_instances( + instances, + device_grouped_conv_fwd_xdl_f16_comp_instances_part2<3, + NGCDHW, + GKCZYX, + Empty_Tuple, + NGKDHW, + ConvFwdDefault>{}); +#endif + } +} + +} // namespace instance +} // namespace device +} // namespace tensor_operation +} // namespace ck diff --git a/library/src/tensor_operation_instance/gpu/grouped_conv3d_fwd/xdl/comp/device_grouped_conv3d_fwd_xdl_ngcdhw_gkczyx_ngkdhw_f32_comp_instance.cpp b/library/src/tensor_operation_instance/gpu/grouped_conv3d_fwd/xdl/comp/device_grouped_conv3d_fwd_xdl_ngcdhw_gkczyx_ngkdhw_f32_comp_instance.cpp new file mode 100644 index 0000000000..696ea7f34e --- /dev/null +++ b/library/src/tensor_operation_instance/gpu/grouped_conv3d_fwd/xdl/comp/device_grouped_conv3d_fwd_xdl_ngcdhw_gkczyx_ngkdhw_f32_comp_instance.cpp @@ -0,0 +1,54 @@ +// SPDX-License-Identifier: MIT +// Copyright (c) 2025, Advanced Micro Devices, Inc. All rights reserved. 
+ +#include "ck/library/tensor_operation_instance/gpu/grouped_conv_fwd/device_grouped_conv_fwd_xdl_comp_instance.hpp" +#include "ck/library/tensor_operation_instance/add_device_operation_instance.hpp" + +namespace ck { +namespace tensor_operation { +namespace device { +namespace instance { + +void add_device_grouped_conv3d_fwd_xdl_ngcdhw_gkczyx_ngkdhw_f32_comp_instances( + std::vector>>& instances) +{ + add_device_operation_instances( + instances, + device_grouped_conv_fwd_xdl_f32_comp_instances<3, + NGCDHW, + GKCZYX, + Empty_Tuple, + NGKDHW, + ConvFwdDefault>{}); + add_device_operation_instances(instances, + device_grouped_conv_fwd_xdl_f32_comp_instances<3, + NGCDHW, + GKCZYX, + Empty_Tuple, + NGKDHW, + ConvFwd1x1P0>{}); + add_device_operation_instances( + instances, + device_grouped_conv_fwd_xdl_f32_comp_instances<3, + NGCDHW, + GKCZYX, + Empty_Tuple, + NGKDHW, + ConvFwd1x1S1P0>{}); +} + +} // namespace instance +} // namespace device +} // namespace tensor_operation +} // namespace ck diff --git a/library/src/tensor_operation_instance/gpu/grouped_conv3d_fwd/xdl/device_grouped_conv3d_fwd_xdl_ngcdhw_gkczyx_ngkdhw_bf16_instance.cpp b/library/src/tensor_operation_instance/gpu/grouped_conv3d_fwd/xdl/device_grouped_conv3d_fwd_xdl_ngcdhw_gkczyx_ngkdhw_bf16_instance.cpp new file mode 100644 index 0000000000..060eebebc1 --- /dev/null +++ b/library/src/tensor_operation_instance/gpu/grouped_conv3d_fwd/xdl/device_grouped_conv3d_fwd_xdl_ngcdhw_gkczyx_ngkdhw_bf16_instance.cpp @@ -0,0 +1,53 @@ +// SPDX-License-Identifier: MIT +// Copyright (c) 2025, Advanced Micro Devices, Inc. All rights reserved. 
+ +#include "ck/library/tensor_operation_instance/add_device_operation_instance.hpp" +#include "ck/library/tensor_operation_instance/gpu/grouped_conv_fwd/device_grouped_conv_fwd_xdl_instance.hpp" + +namespace ck { +namespace tensor_operation { +namespace device { +namespace instance { +// Compilation parameters for in[n, hi, wi, g, c] * wei[g, k, y, x, c] = out[n, ho, wo, g, k] +void add_device_grouped_conv3d_fwd_xdl_ngcdhw_gkczyx_ngkdhw_bf16_instances( + std::vector>>& instances) +{ + add_device_operation_instances(instances, + device_grouped_conv_fwd_xdl_bf16_instances<3, + NGCDHW, + GKCZYX, + Empty_Tuple, + NGKDHW, + ConvFwdDefault>{}); + + add_device_operation_instances(instances, + device_grouped_conv_fwd_xdl_bf16_instances<3, + NGCDHW, + GKCZYX, + Empty_Tuple, + NGKDHW, + ConvFwd1x1P0>{}); + add_device_operation_instances(instances, + device_grouped_conv_fwd_xdl_bf16_instances<3, + NGCDHW, + GKCZYX, + Empty_Tuple, + NGKDHW, + ConvFwd1x1S1P0>{}); +} + +} // namespace instance +} // namespace device +} // namespace tensor_operation +} // namespace ck diff --git a/library/src/tensor_operation_instance/gpu/grouped_conv3d_fwd/xdl/device_grouped_conv3d_fwd_xdl_ngcdhw_gkczyx_ngkdhw_f16_instance.cpp b/library/src/tensor_operation_instance/gpu/grouped_conv3d_fwd/xdl/device_grouped_conv3d_fwd_xdl_ngcdhw_gkczyx_ngkdhw_f16_instance.cpp new file mode 100644 index 0000000000..85b088f416 --- /dev/null +++ b/library/src/tensor_operation_instance/gpu/grouped_conv3d_fwd/xdl/device_grouped_conv3d_fwd_xdl_ngcdhw_gkczyx_ngkdhw_f16_instance.cpp @@ -0,0 +1,53 @@ +// SPDX-License-Identifier: MIT +// Copyright (c) 2025, Advanced Micro Devices, Inc. All rights reserved. 
+ +#include "ck/library/tensor_operation_instance/add_device_operation_instance.hpp" +#include "ck/library/tensor_operation_instance/gpu/grouped_conv_fwd/device_grouped_conv_fwd_xdl_instance.hpp" + +namespace ck { +namespace tensor_operation { +namespace device { +namespace instance { +// Compilation parameters for in[n, hi, wi, g, c] * wei[g, k, y, x, c] = out[n, ho, wo, g, k] +void add_device_grouped_conv3d_fwd_xdl_ngcdhw_gkczyx_ngkdhw_f16_instances( + std::vector>>& instances) +{ + add_device_operation_instances(instances, + device_grouped_conv_fwd_xdl_f16_instances<3, + NGCDHW, + GKCZYX, + Empty_Tuple, + NGKDHW, + ConvFwdDefault>{}); + + add_device_operation_instances(instances, + device_grouped_conv_fwd_xdl_f16_instances<3, + NGCDHW, + GKCZYX, + Empty_Tuple, + NGKDHW, + ConvFwd1x1P0>{}); + add_device_operation_instances(instances, + device_grouped_conv_fwd_xdl_f16_instances<3, + NGCDHW, + GKCZYX, + Empty_Tuple, + NGKDHW, + ConvFwd1x1S1P0>{}); +} + +} // namespace instance +} // namespace device +} // namespace tensor_operation +} // namespace ck diff --git a/library/src/tensor_operation_instance/gpu/grouped_conv3d_fwd/xdl/device_grouped_conv3d_fwd_xdl_ngcdhw_gkczyx_ngkdhw_f32_instance.cpp b/library/src/tensor_operation_instance/gpu/grouped_conv3d_fwd/xdl/device_grouped_conv3d_fwd_xdl_ngcdhw_gkczyx_ngkdhw_f32_instance.cpp new file mode 100644 index 0000000000..2b3e596355 --- /dev/null +++ b/library/src/tensor_operation_instance/gpu/grouped_conv3d_fwd/xdl/device_grouped_conv3d_fwd_xdl_ngcdhw_gkczyx_ngkdhw_f32_instance.cpp @@ -0,0 +1,53 @@ +// SPDX-License-Identifier: MIT +// Copyright (c) 2025, Advanced Micro Devices, Inc. All rights reserved. 
+ +#include "ck/library/tensor_operation_instance/add_device_operation_instance.hpp" +#include "ck/library/tensor_operation_instance/gpu/grouped_conv_fwd/device_grouped_conv_fwd_xdl_instance.hpp" + +namespace ck { +namespace tensor_operation { +namespace device { +namespace instance { +// Compilation parameters for in[n, hi, wi, g, c] * wei[g, k, y, x, c] = out[n, ho, wo, g, k] +void add_device_grouped_conv3d_fwd_xdl_ngcdhw_gkczyx_ngkdhw_f32_instances( + std::vector>>& instances) +{ + add_device_operation_instances(instances, + device_grouped_conv_fwd_xdl_f32_instances<3, + NGCDHW, + GKCZYX, + Empty_Tuple, + NGKDHW, + ConvFwdDefault>{}); + + add_device_operation_instances(instances, + device_grouped_conv_fwd_xdl_f32_instances<3, + NGCDHW, + GKCZYX, + Empty_Tuple, + NGKDHW, + ConvFwd1x1P0>{}); + add_device_operation_instances(instances, + device_grouped_conv_fwd_xdl_f32_instances<3, + NGCDHW, + GKCZYX, + Empty_Tuple, + NGKDHW, + ConvFwd1x1S1P0>{}); +} + +} // namespace instance +} // namespace device +} // namespace tensor_operation +} // namespace ck diff --git a/library/src/tensor_operation_instance/gpu/grouped_conv3d_fwd/xdl/mem/device_grouped_conv3d_fwd_xdl_ngcdhw_gkczyx_ngkdhw_bf16_mem_inter_instance.cpp b/library/src/tensor_operation_instance/gpu/grouped_conv3d_fwd/xdl/mem/device_grouped_conv3d_fwd_xdl_ngcdhw_gkczyx_ngkdhw_bf16_mem_inter_instance.cpp new file mode 100644 index 0000000000..fac3098341 --- /dev/null +++ b/library/src/tensor_operation_instance/gpu/grouped_conv3d_fwd/xdl/mem/device_grouped_conv3d_fwd_xdl_ngcdhw_gkczyx_ngkdhw_bf16_mem_inter_instance.cpp @@ -0,0 +1,55 @@ +// SPDX-License-Identifier: MIT +// Copyright (c) 2025, Advanced Micro Devices, Inc. All rights reserved. 
+ +#include "ck/library/tensor_operation_instance/gpu/grouped_conv_fwd/device_grouped_conv_fwd_xdl_mem_instance.hpp" +#include "ck/library/tensor_operation_instance/add_device_operation_instance.hpp" + +namespace ck { +namespace tensor_operation { +namespace device { +namespace instance { + +void add_device_grouped_conv3d_fwd_xdl_ngcdhw_gkczyx_ngkdhw_bf16_mem_inter_instances( + std::vector>>& instances) +{ + add_device_operation_instances(instances, + device_grouped_conv_fwd_xdl_bf16_mem_instances<3, + NGCDHW, + GKCZYX, + Empty_Tuple, + NGKDHW, + ConvFwdDefault, + Interwave>{}); + add_device_operation_instances(instances, + device_grouped_conv_fwd_xdl_bf16_mem_instances<3, + NGCDHW, + GKCZYX, + Empty_Tuple, + NGKDHW, + ConvFwd1x1P0, + Interwave>{}); + add_device_operation_instances(instances, + device_grouped_conv_fwd_xdl_bf16_mem_instances<3, + NGCDHW, + GKCZYX, + Empty_Tuple, + NGKDHW, + ConvFwd1x1S1P0, + Interwave>{}); +} + +} // namespace instance +} // namespace device +} // namespace tensor_operation +} // namespace ck diff --git a/library/src/tensor_operation_instance/gpu/grouped_conv3d_fwd/xdl/mem/device_grouped_conv3d_fwd_xdl_ngcdhw_gkczyx_ngkdhw_bf16_mem_intra_instance.cpp b/library/src/tensor_operation_instance/gpu/grouped_conv3d_fwd/xdl/mem/device_grouped_conv3d_fwd_xdl_ngcdhw_gkczyx_ngkdhw_bf16_mem_intra_instance.cpp new file mode 100644 index 0000000000..f3eccc7dc8 --- /dev/null +++ b/library/src/tensor_operation_instance/gpu/grouped_conv3d_fwd/xdl/mem/device_grouped_conv3d_fwd_xdl_ngcdhw_gkczyx_ngkdhw_bf16_mem_intra_instance.cpp @@ -0,0 +1,55 @@ +// SPDX-License-Identifier: MIT +// Copyright (c) 2025, Advanced Micro Devices, Inc. All rights reserved. 
+ +#include "ck/library/tensor_operation_instance/gpu/grouped_conv_fwd/device_grouped_conv_fwd_xdl_mem_instance.hpp" +#include "ck/library/tensor_operation_instance/add_device_operation_instance.hpp" + +namespace ck { +namespace tensor_operation { +namespace device { +namespace instance { + +void add_device_grouped_conv3d_fwd_xdl_ngcdhw_gkczyx_ngkdhw_bf16_mem_intra_instances( + std::vector>>& instances) +{ + add_device_operation_instances(instances, + device_grouped_conv_fwd_xdl_bf16_mem_instances<3, + NGCDHW, + GKCZYX, + Empty_Tuple, + NGKDHW, + ConvFwdDefault, + Intrawave>{}); + add_device_operation_instances(instances, + device_grouped_conv_fwd_xdl_bf16_mem_instances<3, + NGCDHW, + GKCZYX, + Empty_Tuple, + NGKDHW, + ConvFwd1x1P0, + Intrawave>{}); + add_device_operation_instances(instances, + device_grouped_conv_fwd_xdl_bf16_mem_instances<3, + NGCDHW, + GKCZYX, + Empty_Tuple, + NGKDHW, + ConvFwd1x1S1P0, + Intrawave>{}); +} + +} // namespace instance +} // namespace device +} // namespace tensor_operation +} // namespace ck diff --git a/library/src/tensor_operation_instance/gpu/grouped_conv3d_fwd/xdl/mem/device_grouped_conv3d_fwd_xdl_ngcdhw_gkczyx_ngkdhw_f16_mem_inter_instance.cpp b/library/src/tensor_operation_instance/gpu/grouped_conv3d_fwd/xdl/mem/device_grouped_conv3d_fwd_xdl_ngcdhw_gkczyx_ngkdhw_f16_mem_inter_instance.cpp new file mode 100644 index 0000000000..abea0bea81 --- /dev/null +++ b/library/src/tensor_operation_instance/gpu/grouped_conv3d_fwd/xdl/mem/device_grouped_conv3d_fwd_xdl_ngcdhw_gkczyx_ngkdhw_f16_mem_inter_instance.cpp @@ -0,0 +1,55 @@ +// SPDX-License-Identifier: MIT +// Copyright (c) 2025, Advanced Micro Devices, Inc. All rights reserved. 
+ +#include "ck/library/tensor_operation_instance/gpu/grouped_conv_fwd/device_grouped_conv_fwd_xdl_mem_instance.hpp" +#include "ck/library/tensor_operation_instance/add_device_operation_instance.hpp" + +namespace ck { +namespace tensor_operation { +namespace device { +namespace instance { + +void add_device_grouped_conv3d_fwd_xdl_ngcdhw_gkczyx_ngkdhw_f16_mem_inter_instances( + std::vector>>& instances) +{ + add_device_operation_instances(instances, + device_grouped_conv_fwd_xdl_f16_mem_instances<3, + NGCDHW, + GKCZYX, + Empty_Tuple, + NGKDHW, + ConvFwdDefault, + Interwave>{}); + add_device_operation_instances(instances, + device_grouped_conv_fwd_xdl_f16_mem_instances<3, + NGCDHW, + GKCZYX, + Empty_Tuple, + NGKDHW, + ConvFwd1x1P0, + Interwave>{}); + add_device_operation_instances(instances, + device_grouped_conv_fwd_xdl_f16_mem_instances<3, + NGCDHW, + GKCZYX, + Empty_Tuple, + NGKDHW, + ConvFwd1x1S1P0, + Interwave>{}); +} + +} // namespace instance +} // namespace device +} // namespace tensor_operation +} // namespace ck diff --git a/library/src/tensor_operation_instance/gpu/grouped_conv3d_fwd/xdl/mem/device_grouped_conv3d_fwd_xdl_ngcdhw_gkczyx_ngkdhw_f16_mem_intra_instance.cpp b/library/src/tensor_operation_instance/gpu/grouped_conv3d_fwd/xdl/mem/device_grouped_conv3d_fwd_xdl_ngcdhw_gkczyx_ngkdhw_f16_mem_intra_instance.cpp new file mode 100644 index 0000000000..ba5d9fb1de --- /dev/null +++ b/library/src/tensor_operation_instance/gpu/grouped_conv3d_fwd/xdl/mem/device_grouped_conv3d_fwd_xdl_ngcdhw_gkczyx_ngkdhw_f16_mem_intra_instance.cpp @@ -0,0 +1,55 @@ +// SPDX-License-Identifier: MIT +// Copyright (c) 2025, Advanced Micro Devices, Inc. All rights reserved. 
+ +#include "ck/library/tensor_operation_instance/gpu/grouped_conv_fwd/device_grouped_conv_fwd_xdl_mem_instance.hpp" +#include "ck/library/tensor_operation_instance/add_device_operation_instance.hpp" + +namespace ck { +namespace tensor_operation { +namespace device { +namespace instance { + +void add_device_grouped_conv3d_fwd_xdl_ngcdhw_gkczyx_ngkdhw_f16_mem_intra_instances( + std::vector>>& instances) +{ + add_device_operation_instances(instances, + device_grouped_conv_fwd_xdl_f16_mem_instances<3, + NGCDHW, + GKCZYX, + Empty_Tuple, + NGKDHW, + ConvFwdDefault, + Intrawave>{}); + add_device_operation_instances(instances, + device_grouped_conv_fwd_xdl_f16_mem_instances<3, + NGCDHW, + GKCZYX, + Empty_Tuple, + NGKDHW, + ConvFwd1x1P0, + Intrawave>{}); + add_device_operation_instances(instances, + device_grouped_conv_fwd_xdl_f16_mem_instances<3, + NGCDHW, + GKCZYX, + Empty_Tuple, + NGKDHW, + ConvFwd1x1S1P0, + Intrawave>{}); +} + +} // namespace instance +} // namespace device +} // namespace tensor_operation +} // namespace ck diff --git a/library/src/tensor_operation_instance/gpu/grouped_conv3d_fwd/xdl/mem/device_grouped_conv3d_fwd_xdl_ngcdhw_gkczyx_ngkdhw_f32_mem_inter_instance.cpp b/library/src/tensor_operation_instance/gpu/grouped_conv3d_fwd/xdl/mem/device_grouped_conv3d_fwd_xdl_ngcdhw_gkczyx_ngkdhw_f32_mem_inter_instance.cpp new file mode 100644 index 0000000000..5a2c4a0d5b --- /dev/null +++ b/library/src/tensor_operation_instance/gpu/grouped_conv3d_fwd/xdl/mem/device_grouped_conv3d_fwd_xdl_ngcdhw_gkczyx_ngkdhw_f32_mem_inter_instance.cpp @@ -0,0 +1,55 @@ +// SPDX-License-Identifier: MIT +// Copyright (c) 2025, Advanced Micro Devices, Inc. All rights reserved. 
+ +#include "ck/library/tensor_operation_instance/gpu/grouped_conv_fwd/device_grouped_conv_fwd_xdl_mem_instance.hpp" +#include "ck/library/tensor_operation_instance/add_device_operation_instance.hpp" + +namespace ck { +namespace tensor_operation { +namespace device { +namespace instance { + +void add_device_grouped_conv3d_fwd_xdl_ngcdhw_gkczyx_ngkdhw_f32_mem_inter_instances( + std::vector>>& instances) +{ + add_device_operation_instances(instances, + device_grouped_conv_fwd_xdl_f32_mem_instances<3, + NGCDHW, + GKCZYX, + Empty_Tuple, + NGKDHW, + ConvFwdDefault, + Interwave>{}); + add_device_operation_instances(instances, + device_grouped_conv_fwd_xdl_f32_mem_instances<3, + NGCDHW, + GKCZYX, + Empty_Tuple, + NGKDHW, + ConvFwd1x1P0, + Interwave>{}); + add_device_operation_instances(instances, + device_grouped_conv_fwd_xdl_f32_mem_instances<3, + NGCDHW, + GKCZYX, + Empty_Tuple, + NGKDHW, + ConvFwd1x1S1P0, + Interwave>{}); +} + +} // namespace instance +} // namespace device +} // namespace tensor_operation +} // namespace ck diff --git a/library/src/tensor_operation_instance/gpu/grouped_conv3d_fwd/xdl/mem/device_grouped_conv3d_fwd_xdl_ngcdhw_gkczyx_ngkdhw_f32_mem_intra_instance.cpp b/library/src/tensor_operation_instance/gpu/grouped_conv3d_fwd/xdl/mem/device_grouped_conv3d_fwd_xdl_ngcdhw_gkczyx_ngkdhw_f32_mem_intra_instance.cpp new file mode 100644 index 0000000000..701b8eb4a4 --- /dev/null +++ b/library/src/tensor_operation_instance/gpu/grouped_conv3d_fwd/xdl/mem/device_grouped_conv3d_fwd_xdl_ngcdhw_gkczyx_ngkdhw_f32_mem_intra_instance.cpp @@ -0,0 +1,55 @@ +// SPDX-License-Identifier: MIT +// Copyright (c) 2025, Advanced Micro Devices, Inc. All rights reserved. 
+ +#include "ck/library/tensor_operation_instance/gpu/grouped_conv_fwd/device_grouped_conv_fwd_xdl_mem_instance.hpp" +#include "ck/library/tensor_operation_instance/add_device_operation_instance.hpp" + +namespace ck { +namespace tensor_operation { +namespace device { +namespace instance { + +void add_device_grouped_conv3d_fwd_xdl_ngcdhw_gkczyx_ngkdhw_f32_mem_intra_instances( + std::vector>>& instances) +{ + add_device_operation_instances(instances, + device_grouped_conv_fwd_xdl_f32_mem_instances<3, + NGCDHW, + GKCZYX, + Empty_Tuple, + NGKDHW, + ConvFwdDefault, + Intrawave>{}); + add_device_operation_instances(instances, + device_grouped_conv_fwd_xdl_f32_mem_instances<3, + NGCDHW, + GKCZYX, + Empty_Tuple, + NGKDHW, + ConvFwd1x1P0, + Intrawave>{}); + add_device_operation_instances(instances, + device_grouped_conv_fwd_xdl_f32_mem_instances<3, + NGCDHW, + GKCZYX, + Empty_Tuple, + NGKDHW, + ConvFwd1x1S1P0, + Intrawave>{}); +} + +} // namespace instance +} // namespace device +} // namespace tensor_operation +} // namespace ck diff --git a/library/src/tensor_operation_instance/gpu/grouped_conv3d_fwd/xdl/merged_groups/device_grouped_conv3d_fwd_xdl_merged_groups_ngcdhw_gkczyx_ngkdhw_bf16_instance.cpp b/library/src/tensor_operation_instance/gpu/grouped_conv3d_fwd/xdl/merged_groups/device_grouped_conv3d_fwd_xdl_merged_groups_ngcdhw_gkczyx_ngkdhw_bf16_instance.cpp new file mode 100644 index 0000000000..71bde2faa5 --- /dev/null +++ b/library/src/tensor_operation_instance/gpu/grouped_conv3d_fwd/xdl/merged_groups/device_grouped_conv3d_fwd_xdl_merged_groups_ngcdhw_gkczyx_ngkdhw_bf16_instance.cpp @@ -0,0 +1,47 @@ +// SPDX-License-Identifier: MIT +// Copyright (c) 2025, Advanced Micro Devices, Inc. All rights reserved. 
+ +#include "ck/library/tensor_operation_instance/gpu/grouped_conv_fwd/device_grouped_conv_fwd_xdl_merged_groups_instance.hpp" +#include "ck/library/tensor_operation_instance/add_device_operation_instance.hpp" + +namespace ck { +namespace tensor_operation { +namespace device { +namespace instance { + +void add_device_grouped_conv3d_fwd_xdl_merged_groups_ngcdhw_gkczyx_ngkdhw_bf16_instances( + std::vector>>& instances) +{ + add_device_operation_instances( + instances, + device_grouped_conv_fwd_xdl_merged_groups_bf16_instances<3, + NGCDHW, + GKCZYX, + Empty_Tuple, + NGKDHW, + ConvFwdDefault>{}); + add_device_operation_instances( + instances, + device_grouped_conv_fwd_xdl_merged_groups_bf16_instances<3, + NGCDHW, + GKCZYX, + Empty_Tuple, + NGKDHW, + ConvFwd3x3>{}); +} + +} // namespace instance +} // namespace device +} // namespace tensor_operation +} // namespace ck diff --git a/library/src/tensor_operation_instance/gpu/grouped_conv3d_fwd/xdl/merged_groups/device_grouped_conv3d_fwd_xdl_merged_groups_ngcdhw_gkczyx_ngkdhw_f16_instance.cpp b/library/src/tensor_operation_instance/gpu/grouped_conv3d_fwd/xdl/merged_groups/device_grouped_conv3d_fwd_xdl_merged_groups_ngcdhw_gkczyx_ngkdhw_f16_instance.cpp new file mode 100644 index 0000000000..2e71b76256 --- /dev/null +++ b/library/src/tensor_operation_instance/gpu/grouped_conv3d_fwd/xdl/merged_groups/device_grouped_conv3d_fwd_xdl_merged_groups_ngcdhw_gkczyx_ngkdhw_f16_instance.cpp @@ -0,0 +1,47 @@ +// SPDX-License-Identifier: MIT +// Copyright (c) 2025, Advanced Micro Devices, Inc. All rights reserved. 
+ +#include "ck/library/tensor_operation_instance/gpu/grouped_conv_fwd/device_grouped_conv_fwd_xdl_merged_groups_instance.hpp" +#include "ck/library/tensor_operation_instance/add_device_operation_instance.hpp" + +namespace ck { +namespace tensor_operation { +namespace device { +namespace instance { + +void add_device_grouped_conv3d_fwd_xdl_merged_groups_ngcdhw_gkczyx_ngkdhw_f16_instances( + std::vector>>& instances) +{ + add_device_operation_instances( + instances, + device_grouped_conv_fwd_xdl_merged_groups_f16_instances<3, + NGCDHW, + GKCZYX, + Empty_Tuple, + NGKDHW, + ConvFwdDefault>{}); + add_device_operation_instances( + instances, + device_grouped_conv_fwd_xdl_merged_groups_f16_instances<3, + NGCDHW, + GKCZYX, + Empty_Tuple, + NGKDHW, + ConvFwd3x3>{}); +} + +} // namespace instance +} // namespace device +} // namespace tensor_operation +} // namespace ck diff --git a/library/src/tensor_operation_instance/gpu/grouped_conv3d_fwd/xdl/merged_groups/device_grouped_conv3d_fwd_xdl_merged_groups_ngcdhw_gkczyx_ngkdhw_f32_instance.cpp b/library/src/tensor_operation_instance/gpu/grouped_conv3d_fwd/xdl/merged_groups/device_grouped_conv3d_fwd_xdl_merged_groups_ngcdhw_gkczyx_ngkdhw_f32_instance.cpp new file mode 100644 index 0000000000..8a53dea612 --- /dev/null +++ b/library/src/tensor_operation_instance/gpu/grouped_conv3d_fwd/xdl/merged_groups/device_grouped_conv3d_fwd_xdl_merged_groups_ngcdhw_gkczyx_ngkdhw_f32_instance.cpp @@ -0,0 +1,47 @@ +// SPDX-License-Identifier: MIT +// Copyright (c) 2025, Advanced Micro Devices, Inc. All rights reserved. 
+ +#include "ck/library/tensor_operation_instance/gpu/grouped_conv_fwd/device_grouped_conv_fwd_xdl_merged_groups_instance.hpp" +#include "ck/library/tensor_operation_instance/add_device_operation_instance.hpp" + +namespace ck { +namespace tensor_operation { +namespace device { +namespace instance { + +void add_device_grouped_conv3d_fwd_xdl_merged_groups_ngcdhw_gkczyx_ngkdhw_f32_instances( + std::vector>>& instances) +{ + add_device_operation_instances( + instances, + device_grouped_conv_fwd_xdl_merged_groups_f32_instances<3, + NGCDHW, + GKCZYX, + Empty_Tuple, + NGKDHW, + ConvFwdDefault>{}); + add_device_operation_instances( + instances, + device_grouped_conv_fwd_xdl_merged_groups_f32_instances<3, + NGCDHW, + GKCZYX, + Empty_Tuple, + NGKDHW, + ConvFwd3x3>{}); +} + +} // namespace instance +} // namespace device +} // namespace tensor_operation +} // namespace ck diff --git a/profiler/src/profile_grouped_conv_fwd.cpp b/profiler/src/profile_grouped_conv_fwd.cpp index 9ee05d1304..a7714b4c73 100644 --- a/profiler/src/profile_grouped_conv_fwd.cpp +++ b/profiler/src/profile_grouped_conv_fwd.cpp @@ -114,17 +114,19 @@ int profile_grouped_conv_fwd(int argc, char* argv[]) using GKZYXC = ck::tensor_layout::convolution::GKZYXC; // using GKCX = ck::tensor_layout::convolution::GKXC; - using GKCYX = ck::tensor_layout::convolution::GKCYX; - // using GKCZYX = ck::tensor_layout::convolution::GKZYXC; + using GKCYX = ck::tensor_layout::convolution::GKCYX; + using GKCZYX = ck::tensor_layout::convolution::GKCZYX; using GNWK = ck::tensor_layout::convolution::GNWK; using GNHWK = ck::tensor_layout::convolution::GNHWK; using GNDHWK = ck::tensor_layout::convolution::GNDHWK; // - using NGCHW = ck::tensor_layout::convolution::NGCHW; + using NGCHW = ck::tensor_layout::convolution::NGCHW; + using NGCDHW = ck::tensor_layout::convolution::NGCDHW; - using NGKHW = ck::tensor_layout::convolution::NGKHW; + using NGKHW = ck::tensor_layout::convolution::NGKHW; + using NGKDHW = 
ck::tensor_layout::convolution::NGKDHW; // using NWGC = ck::tensor_layout::convolution::NWGC; @@ -366,6 +368,23 @@ int profile_grouped_conv_fwd(int argc, char* argv[]) return profile(I3, NDHWGC{}, GKZYXC{}, NDHWGK{}, BF8{}, F8{}, F8{}, BF8{}, F8{}); } } + // NGCDHW_GKCZYX_NGKDHW + else if(num_dim_spatial == 3 && layout == ConvLayout::NGCHW_GKCYX_NGKHW) + { + if(data_type == ConvDataType::F32_F32_F32) + { + return profile(I3, NGCDHW{}, GKCZYX{}, NGKDHW{}, F32{}, F32{}, F32{}, F32{}, F32{}); + } + else if(data_type == ConvDataType::F16_F16_F16) + { + return profile(I3, NGCDHW{}, GKCZYX{}, NGKDHW{}, F16{}, F16{}, F16{}, F16{}, F16{}); + } + else if(data_type == ConvDataType::BF16_BF16_BF16) + { + return profile( + I3, NGCDHW{}, GKCZYX{}, NGKDHW{}, BF16{}, BF16{}, BF16{}, BF16{}, BF16{}); + } + } std::cout << "this data_type & layout is not implemented" << std::endl; diff --git a/test/grouped_convnd_fwd/test_grouped_convnd_fwd.cpp b/test/grouped_convnd_fwd/test_grouped_convnd_fwd.cpp index 43b77641d1..1cf91df52c 100644 --- a/test/grouped_convnd_fwd/test_grouped_convnd_fwd.cpp +++ b/test/grouped_convnd_fwd/test_grouped_convnd_fwd.cpp @@ -77,7 +77,9 @@ using KernelTypes3d = ::testing::Types std::tuple, std::tuple, std::tuple, - std::tuple>; + std::tuple, + std::tuple, + std::tuple>; template class TestGroupedConvndFwd1d : public TestGroupedConvndFwd From 3bb62f16cd023095dac9467351253861b9d92555 Mon Sep 17 00:00:00 2001 From: Illia Silin <98187287+illsilin@users.noreply.github.com> Date: Wed, 16 Apr 2025 12:10:15 -0700 Subject: [PATCH 049/443] Upgrade default docker to Ubuntu24.04 (#2090) * upgrade docker to Ubuntu24.04 * add break-system-packages flag to pip install * fix dockerfile --- Dockerfile | 14 +++++--------- Dockerfile.compiler | 2 +- Jenkinsfile | 8 ++++---- 3 files changed, 10 insertions(+), 14 deletions(-) diff --git a/Dockerfile b/Dockerfile index 2a8fb707c9..f77c685000 100644 --- a/Dockerfile +++ b/Dockerfile @@ -1,4 +1,4 @@ -FROM ubuntu:22.04 +FROM 
ubuntu:24.04 ARG DEBIAN_FRONTEND=noninteractive ARG ROCMVERSION=6.4 ARG compiler_version="" @@ -14,8 +14,8 @@ RUN set -xe && \ curl -fsSL https://repo.radeon.com/rocm/rocm.gpg.key | gpg --dearmor -o /etc/apt/trusted.gpg.d/rocm-keyring.gpg RUN if [ "$ROCMVERSION" != "6.5" ]; then \ - sh -c "wget https://repo.radeon.com/amdgpu-install/$ROCMVERSION/ubuntu/jammy/amdgpu-install_6.3.60300-1_all.deb --no-check-certificate" && \ - apt-get update && DEBIAN_FRONTEND=noninteractive apt-get install -y --allow-unauthenticated ./amdgpu-install_6.3.60300-1_all.deb && \ + sh -c "wget https://repo.radeon.com/amdgpu-install/$ROCMVERSION/ubuntu/jammy/amdgpu-install_6.4.60400-1_all.deb --no-check-certificate" && \ + apt-get update && DEBIAN_FRONTEND=noninteractive apt-get install -y --allow-unauthenticated ./amdgpu-install_6.4.60400-1_all.deb && \ wget -qO - http://repo.radeon.com/rocm/rocm.gpg.key | apt-key add - && \ sh -c "echo deb [arch=amd64 signed-by=/etc/apt/trusted.gpg.d/rocm-keyring.gpg] $DEB_ROCM_REPO jammy main > /etc/apt/sources.list.d/rocm.list" && \ sh -c 'echo deb [arch=amd64 signed-by=/etc/apt/trusted.gpg.d/rocm-keyring.gpg] https://repo.radeon.com/amdgpu/$ROCMVERSION/ubuntu jammy main > /etc/apt/sources.list.d/amdgpu.list'; \ @@ -44,7 +44,6 @@ RUN apt-get update && DEBIAN_FRONTEND=noninteractive apt-get install -y --allow- iputils-ping \ jq \ libelf-dev \ - libncurses5-dev \ libnuma-dev \ libpthread-stubs0-dev \ llvm-amdgpu \ @@ -73,10 +72,8 @@ RUN apt-get update && DEBIAN_FRONTEND=noninteractive apt-get install -y --allow- # Remove unnecessary rocm components that take a lot of space apt-get remove -y rocblas rocfft rocsparse composablekernel-dev hipblaslt -# Update the cmake to version 3.27.5 -RUN pip install --upgrade cmake==3.27.5 && \ #Install latest ccache - git clone https://github.com/ccache/ccache.git && \ +RUN git clone https://github.com/ccache/ccache.git && \ cd ccache && mkdir build && cd build && cmake .. 
&& make install && \ #Install ninja build tracing tools cd / && \ @@ -97,8 +94,7 @@ RUN pip install --upgrade cmake==3.27.5 && \ wget https://github.com/Yelp/dumb-init/releases/download/v1.2.0/dumb-init_1.2.0_amd64.deb && \ dpkg -i dumb-init_*.deb && rm dumb-init_*.deb && \ # Install packages for processing the performance results - pip3 install --upgrade pip && \ - pip3 install --upgrade pytest sqlalchemy==2.0.36 pymysql pandas==2.2.3 setuptools-rust setuptools sshtunnel==0.4.0 && \ + pip3 install --break-system-packages --upgrade pytest pymysql pandas==2.2.3 sqlalchemy==2.0.3 setuptools-rust setuptools sshtunnel==0.4.0 && \ # Add render group groupadd -f render && \ # Install the new rocm-cmake version diff --git a/Dockerfile.compiler b/Dockerfile.compiler index f4aa12f356..7534910681 100644 --- a/Dockerfile.compiler +++ b/Dockerfile.compiler @@ -1,4 +1,4 @@ -ARG BASE_DOCKER="rocm/composable_kernel:ck_ub22.04_rocm6.4" +ARG BASE_DOCKER="rocm/composable_kernel:ck_ub24.04_rocm6.4" FROM $BASE_DOCKER ARG compiler_version="" ARG compiler_commit="" diff --git a/Jenkinsfile b/Jenkinsfile index e6256fc3d8..3d7019bd1f 100644 --- a/Jenkinsfile +++ b/Jenkinsfile @@ -40,10 +40,10 @@ def getBaseDockerImageName(){ else{ def ROCM_numeric = "${params.ROCMVERSION}" as float if ( ROCM_numeric < 6.5 ){ - img = "${env.CK_DOCKERHUB}:ck_ub22.04_rocm${params.ROCMVERSION}" + img = "${env.CK_DOCKERHUB}:ck_ub24.04_rocm${params.ROCMVERSION}" } else{ - img = "${env.CK_DOCKERHUB_PRIVATE}:ck_ub22.04_rocm${params.ROCMVERSION}" + img = "${env.CK_DOCKERHUB_PRIVATE}:ck_ub24.04_rocm${params.ROCMVERSION}" } } return img @@ -535,7 +535,7 @@ def Build_CK(Map conf=[:]){ if ( !params.BUILD_LEGACY_OS && arch_type == 1 ){ echo "Run inductor codegen tests" sh """ - pip install --verbose . + pip install --break-system-packages --verbose . 
pytest python/test/test_gen_instances.py """ } @@ -745,7 +745,7 @@ def process_results(Map conf=[:]){ //launch develop branch daily at 23:00 UT in FULL_QA mode and at 19:00 UT with latest staging compiler version CRON_SETTINGS = BRANCH_NAME == "develop" ? '''0 23 * * * % RUN_FULL_QA=true;DISABLE_DL_KERNELS=true;ROCMVERSION=6.4;RUN_CK_TILE_FMHA_TESTS=true;RUN_CK_TILE_GEMM_TESTS=true - 0 21 * * * % ROCMVERSION=6.4;hipTensor_test=true;RUN_CODEGEN_TESTS=true;BUILD_GFX908=true; + 0 21 * * * % ROCMVERSION=6.4;hipTensor_test=true;RUN_CODEGEN_TESTS=true;BUILD_GFX908=true 0 19 * * * % BUILD_DOCKER=true;COMPILER_VERSION=amd-staging;BUILD_COMPILER=/llvm-project/build/bin/clang++;USE_SCCACHE=false;NINJA_BUILD_TRACE=true 0 17 * * * % BUILD_DOCKER=true;COMPILER_VERSION=amd-mainline;BUILD_COMPILER=/llvm-project/build/bin/clang++;USE_SCCACHE=false;NINJA_BUILD_TRACE=true 0 15 * * * % BUILD_INSTANCES_ONLY=true;RUN_PERFORMANCE_TESTS=false;USE_SCCACHE=false From da54464cce95c2f0334676ce24b863eed202d873 Mon Sep 17 00:00:00 2001 From: Andriy Roshchenko <107577548+andriy-ca@users.noreply.github.com> Date: Wed, 16 Apr 2025 15:25:02 -0600 Subject: [PATCH 050/443] MX GEMM - Add MX BF8 example (#2071) * Add MX GEMM example for MX BF8 * Verified MX FP8 with 16x16x128 scale builtin * Verify MX BF8 GEMM with BF16 output --- example/67_gemm_microscaling/CMakeLists.txt | 3 + example/67_gemm_microscaling/gemm_mx_bf8.cpp | 98 +++++++++++++++++++ .../impl/device_gemm_xdl_cshuffle_v3_mx.hpp | 3 + .../tensor_operation/gpu/warp/xdlops_gemm.hpp | 6 ++ include/ck/utility/amd_xdlops.hpp | 29 ++++++ 5 files changed, 139 insertions(+) create mode 100644 example/67_gemm_microscaling/gemm_mx_bf8.cpp diff --git a/example/67_gemm_microscaling/CMakeLists.txt b/example/67_gemm_microscaling/CMakeLists.txt index 93770684df..34125465a9 100644 --- a/example/67_gemm_microscaling/CMakeLists.txt +++ b/example/67_gemm_microscaling/CMakeLists.txt @@ -3,3 +3,6 @@ add_custom_target(example_gemm_mx) 
add_example_executable(example_gemm_mx_fp8 gemm_mx_fp8.cpp) add_example_dependencies(example_gemm_mx example_gemm_mx_fp8) +add_example_executable(example_gemm_mx_bf8 gemm_mx_bf8.cpp) +add_example_dependencies(example_gemm_mx example_gemm_mx_bf8) + diff --git a/example/67_gemm_microscaling/gemm_mx_bf8.cpp b/example/67_gemm_microscaling/gemm_mx_bf8.cpp new file mode 100644 index 0000000000..8e341fb591 --- /dev/null +++ b/example/67_gemm_microscaling/gemm_mx_bf8.cpp @@ -0,0 +1,98 @@ +// SPDX-License-Identifier: MIT +// Copyright (c) 2025, Advanced Micro Devices, Inc. All rights reserved. + +#include "gemm_mx_common.hpp" + +using ADataType = ck::bf8_t; +using BDataType = ck::bf8_t; + +using XDataType = ck::e8m0_bexp_t; + +using CDataType = ck::bhalf_t; +using AccDataType = float; +using CShuffleDataType = CDataType; + +using ALayout = Row; +using BLayout = Col; +using CLayout = Row; + +using AElementOp = PassThrough; // elementwise transformation for A matrix +using BElementOp = PassThrough; // elementwise transformation for B matrix +using CElementOp = PassThrough; // elementwise transformation for C matrix + +constexpr ck::index_t ScaleBlockSize = 32; // scaling block size +constexpr ck::index_t KPerBlock = 128; + +constexpr auto GemmSpec = ck::tensor_operation::device::GemmSpecialization::Default; +constexpr auto BlkGemmPSched = ck::BlockGemmPipelineScheduler::Intrawave; +constexpr auto BlkGemmPVer = ck::BlockGemmPipelineVersion::v1; + +using DeviceOpInstance = ck::tensor_operation::device::DeviceGemmMX_Xdl_CShuffleV3< + ALayout, // ALayout + BLayout, // BLayout + CLayout, // CLayout + ADataType, // ADataType + XDataType, // AScaleDataType + BDataType, // BDataType + XDataType, // BScaleDataType + CDataType, // CDataType + AccDataType, // GemmAccDataType + CShuffleDataType, // CShuffleDataType + AElementOp, // AElementwiseOperation + BElementOp, // BElementwiseOperation + CElementOp, // CElementwiseOperation + GemmSpec, // GemmSpec + ScaleBlockSize, // 
ScaleBlockSize: Scaling block size + 128, // BlockSize: Thread block size + 128, // MPerBlock + 16, // NPerBlock + KPerBlock, // KPerBlock + 16, // AK1 + 16, // BK1 + 16, // MPerXDL + 16, // NPerXDL + 4, // MXdlPerWave + 1, // NXdlPerWave + S<8, 16, 1>, // ABlockTransferThreadClusterLengths_AK0_M_AK1 + S<1, 0, 2>, // ABlockTransferThreadClusterArrangeOrder + S<1, 0, 2>, // ABlockTransferSrcAccessOrder + 2, // ABlockTransferSrcVectorDim + 16, // ABlockTransferSrcScalarPerVector + 16, // ABlockTransferDstScalarPerVector_AK1 + false, // ABlockLdsExtraM + S<8, 16, 1>, // BBlockTransferThreadClusterLengths_BK0_N_BK1 + S<1, 0, 2>, // BBlockTransferThreadClusterArrangeOrder + S<1, 0, 2>, // BBlockTransferSrcAccessOrder + 2, // BBlockTransferSrcVectorDim + 16, // BBlockTransferSrcScalarPerVector + 16, // BBlockTransferDstScalarPerVector_BK1 + false, // BBlockLdsExtraN + 1, // CShuffleMXdlPerWavePerShuffle + 1, // CShuffleNXdlPerWavePerShuffle + S<1, 16, 1, 8>, // CShuffleBlockTransferClusterLengths_MBlock_MPerBlock_NBlock_NPerBlock + 2, // CShuffleBlockTransferScalarPerVector_NPerBlock + BlkGemmPSched, // BlkGemmPipeSched + BlkGemmPVer, // BlkGemmPipelineVer + ADataType, // ComputeTypeA + BDataType // ComputeTypeB + >; + +int main(int argc, char* argv[]) +{ + return run_mx_gemm_example(argc, argv) + ? 
0 + : -1; +} diff --git a/include/ck/tensor_operation/gpu/device/impl/device_gemm_xdl_cshuffle_v3_mx.hpp b/include/ck/tensor_operation/gpu/device/impl/device_gemm_xdl_cshuffle_v3_mx.hpp index 8a370304c6..62bc2c4499 100644 --- a/include/ck/tensor_operation/gpu/device/impl/device_gemm_xdl_cshuffle_v3_mx.hpp +++ b/include/ck/tensor_operation/gpu/device/impl/device_gemm_xdl_cshuffle_v3_mx.hpp @@ -699,6 +699,9 @@ struct DeviceGemmMX_Xdl_CShuffleV3 : public DeviceGemmMX && is_same_v, + "ComputeTypeA and ComputeTypeB must be the same as ADataType and BDataType"); + return true; } diff --git a/include/ck/tensor_operation/gpu/warp/xdlops_gemm.hpp b/include/ck/tensor_operation/gpu/warp/xdlops_gemm.hpp index 529a1a1729..08c4e4ba6e 100644 --- a/include/ck/tensor_operation/gpu/warp/xdlops_gemm.hpp +++ b/include/ck/tensor_operation/gpu/warp/xdlops_gemm.hpp @@ -1141,6 +1141,12 @@ struct MfmaSelector return MfmaInstr::mfma_scale_f32_16x16x128f8f6f4; } + template <> + constexpr auto GetMfma() + { + return MfmaInstr::mfma_scale_f32_16x16x128f8f6f4; + } + template <> constexpr auto GetMfma() { diff --git a/include/ck/utility/amd_xdlops.hpp b/include/ck/utility/amd_xdlops.hpp index a54a181bf1..a8c3baa31b 100644 --- a/include/ck/utility/amd_xdlops.hpp +++ b/include/ck/utility/amd_xdlops.hpp @@ -588,6 +588,35 @@ struct intrin_mfma_scale_f32_16x16x128f8f6f4<16, 16> ignore = reg_b; ignore = scale_b; ignore = reg_c; +#endif + } + + template + __device__ static void Run(const bf8x32_t& reg_a, + const int32_t& scale_a, + const bf8x32_t& reg_b, + const int32_t& scale_b, + FloatC& reg_c) + { +#if defined(__gfx950__) + // https://github.com/ROCm/llvm-project/blob/656552edc693e2bb4abc9258399c39d190fce2b3/llvm/test/Verifier/AMDGPU/mfma-scale.ll#L10 + reg_c.template AsType()(Number<0>{}) = + __builtin_amdgcn_mfma_scale_f32_16x16x128_f8f6f4( + reg_a, + reg_b, + reg_c.template AsType()[Number<0>{}], + 1, // cbsz + 1, // blgp + 0, // OPSEL + scale_a, + 0, // OPSEL + scale_b); +#else + ignore = reg_a; 
+ ignore = scale_a; + ignore = reg_b; + ignore = scale_b; + ignore = reg_c; #endif } }; From 213b203a3c4409cc0906cf13ecbc3a09092f67b2 Mon Sep 17 00:00:00 2001 From: Andriy Roshchenko <107577548+andriy-ca@users.noreply.github.com> Date: Wed, 16 Apr 2025 19:56:00 -0600 Subject: [PATCH 051/443] MX GEMM - Parameterized Test Template (#2088) * Tests for MX FP8 GEMM * Improve documentation --- .../impl/device_gemm_xdl_cshuffle_v3_mx.hpp | 16 +- .../tensor_operation_instance/gpu/gemm_mx.hpp | 111 ++++ .../gpu/CMakeLists.txt | 13 + .../gpu/gemm_mx/CMakeLists.txt | 14 + ...device_gemm_mx_xdl_f8_f8_bf16_mk_nk_mn.hpp | 63 +++ ...l_f8_f8_bf16_mk_nk_mn_default_instance.cpp | 32 ++ .../device_gemm_mx_xdl_f8_f8_f16_mk_nk_mn.hpp | 63 +++ ...dl_f8_f8_f16_mk_nk_mn_default_instance.cpp | 32 ++ test/CMakeLists.txt | 1 + test/gemm_mx/CMakeLists.txt | 4 + test/gemm_mx/test_gemm_mx.cpp | 108 ++++ test/gemm_mx/test_gemm_mx_util.hpp | 498 ++++++++++++++++++ 12 files changed, 948 insertions(+), 7 deletions(-) create mode 100644 library/include/ck/library/tensor_operation_instance/gpu/gemm_mx.hpp create mode 100644 library/src/tensor_operation_instance/gpu/gemm_mx/CMakeLists.txt create mode 100644 library/src/tensor_operation_instance/gpu/gemm_mx/device_gemm_mx_xdl_f8_f8_bf16/device_gemm_mx_xdl_f8_f8_bf16_mk_nk_mn.hpp create mode 100644 library/src/tensor_operation_instance/gpu/gemm_mx/device_gemm_mx_xdl_f8_f8_bf16/device_gemm_mx_xdl_f8_f8_bf16_mk_nk_mn_default_instance.cpp create mode 100644 library/src/tensor_operation_instance/gpu/gemm_mx/device_gemm_mx_xdl_f8_f8_f16/device_gemm_mx_xdl_f8_f8_f16_mk_nk_mn.hpp create mode 100644 library/src/tensor_operation_instance/gpu/gemm_mx/device_gemm_mx_xdl_f8_f8_f16/device_gemm_mx_xdl_f8_f8_f16_mk_nk_mn_default_instance.cpp create mode 100644 test/gemm_mx/CMakeLists.txt create mode 100644 test/gemm_mx/test_gemm_mx.cpp create mode 100644 test/gemm_mx/test_gemm_mx_util.hpp diff --git 
a/include/ck/tensor_operation/gpu/device/impl/device_gemm_xdl_cshuffle_v3_mx.hpp b/include/ck/tensor_operation/gpu/device/impl/device_gemm_xdl_cshuffle_v3_mx.hpp index 62bc2c4499..c37af49387 100644 --- a/include/ck/tensor_operation/gpu/device/impl/device_gemm_xdl_cshuffle_v3_mx.hpp +++ b/include/ck/tensor_operation/gpu/device/impl/device_gemm_xdl_cshuffle_v3_mx.hpp @@ -22,6 +22,7 @@ namespace ck { namespace tensor_operation { namespace device { +// clang-format off /** * \brief WIP: Implements XDL CShuffle V3 GEMM for microscale-compliant data types * @@ -31,8 +32,8 @@ namespace device { * Assumptions: * - A and B data types are compliant with the OCP Microscaling Formats (MX) Specification * - Each scale applies to ScaleBlockSize elements in K direction - * - A scale matrix is row-major - * - B scale matrix is column-major + * - A scale matrix is a row-major + * - B scale matrix is a column-major * - Scale data types must have get_exponent_value() specialization, whereas lowest 8 bits of the * exponent will be interpreted as conventional biased Float32 exponent (E8M0) * @@ -72,10 +73,10 @@ namespace device { * for(int mw = m0; mw < m0 + MWaves * MPerXDL; mw += MPerXDL){ * for(int nw = n0; nw < n0 + NWaves * NPerXDL; nw += NPerXDL){ * for(int k0 = kb; k0 < kb + KPerBlock; k0 += mfma.num_input_blks*KPack){ - * // MFMA accumulation for multirate instructions - * for(int k_pack = k0; k_pack < k0 + mfma.num_input_blks*KPack; k_pack += KPack){ - * for(int k_mfma = k_pack; k_mfma < k_pack + KPack; k_mfma += mfma.k_per_blk){ - * // MFMA instruction + * // MFMA accumulation + * for(int k_pack = k0; k_pack < k0 + mfma.num_input_blks*KPack; k_pack += KPerXdlops){ + * // MFMA instruction + * for(int k_mfma = k_pack; k_mfma < k_pack + KPerXdlops; k_mfma += mfma.k_per_blk){ * for(int m = mw; m < mw + MPerXDL; m++){ * for(int n = nw; n < nw + NPerXDL; n++){ * for(int k = k_mfma; k < k_mfma + mfma.k_per_blk; k++){ @@ -96,6 +97,7 @@ namespace device { * \endcode * */ +// 
clang-format on template +#include +#include "ck/ck.hpp" +#include "ck/tensor_operation/gpu/device/tensor_layout.hpp" +#include "ck/tensor_operation/gpu/device/impl/device_gemm_xdl_cshuffle_v3_mx.hpp" +#include "ck/tensor_operation/gpu/element/element_wise_operation.hpp" + +#include "ck/library/tensor_operation_instance/device_operation_instance_factory.hpp" + +namespace ck { +namespace tensor_operation { +namespace device { +namespace instance { + +void add_device_gemm_mx_xdl_f8_f8_f16_mk_nk_mn_default_instances( + std::vector>>& instances); + +void add_device_gemm_mx_xdl_f8_f8_bf16_mk_nk_mn_default_instances( + std::vector>>& instances); + +template +struct DeviceOperationInstanceFactory< + ck::tensor_operation::device::DeviceGemmMX> +{ + using DeviceOp = DeviceGemmMX; + + static auto GetInstances() + { + std::vector> op_ptrs; + + if constexpr(is_same_v && is_same_v && is_same_v) + { + if constexpr(is_same_v && is_same_v && + is_same_v) + { + + add_device_gemm_mx_xdl_f8_f8_f16_mk_nk_mn_default_instances(op_ptrs); + } + if constexpr(is_same_v && is_same_v && + is_same_v) + { + + add_device_gemm_mx_xdl_f8_f8_bf16_mk_nk_mn_default_instances(op_ptrs); + } + } + + return op_ptrs; + } +}; + +} // namespace instance +} // namespace device +} // namespace tensor_operation +} // namespace ck diff --git a/library/src/tensor_operation_instance/gpu/CMakeLists.txt b/library/src/tensor_operation_instance/gpu/CMakeLists.txt index 2542dd236b..70e54962ed 100755 --- a/library/src/tensor_operation_instance/gpu/CMakeLists.txt +++ b/library/src/tensor_operation_instance/gpu/CMakeLists.txt @@ -60,6 +60,13 @@ function(add_instance_library INSTANCE_NAME) list(REMOVE_ITEM ARGN "${source}") endif() endforeach() + # Do not build MX instances if gfx950 targets are not on the target list + foreach(source IN LISTS ARGN) + if(NOT INST_TARGETS MATCHES "gfx950" AND source MATCHES "_mx") + message("removing MX instance ${source} ") + list(REMOVE_ITEM ARGN "${source}") + endif() + endforeach() # 
Do not build WMMA instances if gfx11 targets are not on the target list foreach(source IN LISTS ARGN) if(NOT INST_TARGETS MATCHES "gfx11" AND NOT INST_TARGETS MATCHES "gfx12" AND source MATCHES "_wmma") @@ -100,6 +107,8 @@ function(add_instance_library INSTANCE_NAME) list(REMOVE_ITEM INST_TARGETS gfx900 gfx906 gfx906:xnack- gfx908:xnack+ gfx908:xnack- gfx90a:xnack+ gfx90a:xnack- gfx908 gfx90a gfx942 gfx1030 gfx950) elseif(source MATCHES "mha") list(REMOVE_ITEM INST_TARGETS gfx900 gfx906 gfx906:xnack- gfx908:xnack- gfx908:xnack+ gfx908 gfx1030 gfx1100 gfx1101 gfx1102 gfx1103 gfx1150 gfx1151 gfx1152 gfx1200 gfx1201 gfx10.3-generic gfx11-generic gfx12-generic) + elseif(source MATCHES "_mx") + list(REMOVE_ITEM INST_TARGETS gfx900 gfx906 gfx906:xnack- gfx908:xnack- gfx908:xnack+ gfx90a:xnack+ gfx90a:xnack- gfx908 gfx90a gfx942 gfx1030 gfx1100 gfx1101 gfx1102 gfx1103 gfx1150 gfx1151 gfx1152 gfx1200 gfx1201 gfx10.3-generic gfx11-generic gfx12-generic) endif() #only build the fp8 gemm instances for gfx90a if the build argument is set, otherwise only build for gfx942/gfx950 if(NOT CK_USE_FP8_ON_UNSUPPORTED_ARCH) @@ -234,6 +243,10 @@ FOREACH(subdir_path ${dir_list}) if(("${cmake_instance}" MATCHES "ONLY XDL_KERNELS") AND (NOT INST_TARGETS MATCHES "gfx9")) message("Found only xdl instances, but gfx9 is not on the targets list. Skipping.") set(add_inst 0) + endif() + if(("${cmake_instance}" MATCHES "ONLY MX_KERNELS") AND (NOT INST_TARGETS MATCHES "gfx950")) + message("Found only MX instances, but gfx950 is not on the targets list. Skipping.") + set(add_inst 0) endif() if(("${cmake_instance}" MATCHES "ONLY WMMA_KERNELS") AND (NOT INST_TARGETS MATCHES "gfx11") AND (NOT INST_TARGETS MATCHES "gfx12")) message("Found only wmma instances, but gfx11 is not on the targets list. 
Skipping.") diff --git a/library/src/tensor_operation_instance/gpu/gemm_mx/CMakeLists.txt b/library/src/tensor_operation_instance/gpu/gemm_mx/CMakeLists.txt new file mode 100644 index 0000000000..a166fc4ce4 --- /dev/null +++ b/library/src/tensor_operation_instance/gpu/gemm_mx/CMakeLists.txt @@ -0,0 +1,14 @@ +# ONLY MX_KERNELS +set(GEMM_MX_INSTANCES) + +list(APPEND GEMM_MX_INSTANCES + device_gemm_mx_xdl_f8_f8_f16/device_gemm_mx_xdl_f8_f8_f16_mk_nk_mn_default_instance.cpp + device_gemm_mx_xdl_f8_f8_bf16/device_gemm_mx_xdl_f8_f8_bf16_mk_nk_mn_default_instance.cpp + ) + + +set_source_files_properties(device_gemm_mx_xdl_f8_f8_f16/device_gemm_mx_xdl_f8_f8_f16_mk_nk_mn_default_instance.cpp PROPERTIES COMPILE_OPTIONS ";-mllvm;-greedy-reverse-local-assignment=1") +set_source_files_properties(device_gemm_mx_xdl_f8_f8_bf16/device_gemm_mx_xdl_f8_f8_bf16_mk_nk_mn_default_instance.cpp PROPERTIES COMPILE_OPTIONS ";-mllvm;-greedy-reverse-local-assignment=1") + + +add_instance_library(device_gemm_mx_instance ${GEMM_MX_INSTANCES}) diff --git a/library/src/tensor_operation_instance/gpu/gemm_mx/device_gemm_mx_xdl_f8_f8_bf16/device_gemm_mx_xdl_f8_f8_bf16_mk_nk_mn.hpp b/library/src/tensor_operation_instance/gpu/gemm_mx/device_gemm_mx_xdl_f8_f8_bf16/device_gemm_mx_xdl_f8_f8_bf16_mk_nk_mn.hpp new file mode 100644 index 0000000000..1e979f69ca --- /dev/null +++ b/library/src/tensor_operation_instance/gpu/gemm_mx/device_gemm_mx_xdl_f8_f8_bf16/device_gemm_mx_xdl_f8_f8_bf16_mk_nk_mn.hpp @@ -0,0 +1,63 @@ +// SPDX-License-Identifier: MIT +// Copyright (c) 2025, Advanced Micro Devices, Inc. All rights reserved. 
+ +#include "ck/ck.hpp" +#include "ck/tensor_operation/gpu/device/tensor_layout.hpp" +#include "ck/tensor_operation/gpu/device/gemm_specialization.hpp" +#include "ck/tensor_operation/gpu/device/impl/device_gemm_xdl_cshuffle_v3_mx.hpp" + +#include "ck/library/tensor_operation_instance/add_device_operation_instance.hpp" + +namespace ck { +namespace tensor_operation { +namespace device { +namespace instance { + +using F8 = f8_t; +using F16 = half_t; +using BF16 = bhalf_t; +using F32 = float; +using E8M0 = ck::e8m0_bexp_t; + +using Row = tensor_layout::gemm::RowMajor; +using Col = tensor_layout::gemm::ColumnMajor; + +template +using S = Sequence; + +using PassThrough = element_wise::PassThrough; + +static constexpr auto GemmDefault = GemmSpecialization::Default; +static constexpr auto GemmKPadding = GemmSpecialization::KPadding; +static constexpr auto GemmMNPadding = GemmSpecialization::MNPadding; +static constexpr auto GemmMNKPadding = GemmSpecialization::MNKPadding; + +static constexpr auto Intrawave = BlockGemmPipelineScheduler::Intrawave; +static constexpr auto Interwave = BlockGemmPipelineScheduler::Interwave; + +static constexpr auto ScaleBlockSize = 32; + +template +using device_gemm_mx_xdl_f8_f8_bf16_mk_nk_mn_instances = std::tuple< +// clang-format off + //#########################| ALayout| BLayout| CLayout|AData|AScale|BData|BScale| CData| AccData| Cshuffle| A| B| C| GEMM| Scale Block| Block| MPer| NPer| KPer| AK1| BK1|MPer| NPer| MXdl| NXdl| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockLds| BBlockTransfer| BBlockTransfer| BBlockTransfer| BlockTransfer| BBlockTransfer| BBlockTransfer| BBlockLds| CShuffle| CShuffle| CBlockTransferClusterLengths| CBlockTransfer| Block-wiseGemm| Block-wiseGemm| + //#########################| | | | Type| Data| Type| Data| Type| Type| Type| Elementwise| Elementwise| Elementwise|Specialization| Size| Size| Block| Block| Block| | | XDL| XDL| Per| Per| ThreadCluster| 
ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraM| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraN| MXdlPerWave| NXdlPerWave| _MBlock_MXdlPerWave_MWaveMPerXdl| ScalarPerVector| Pipeline| Pipeline| + //#########################| | | | | Type| | Type| | | | Operation| Operation| Operation| | | | | | | | | | | Wave| Wave| Lengths_K0_M_K1| ArrangeOrder| | | PerVector| PerVector_K1| | Lengths_K0_N_K1| ArrangeOrder| | | PerVector| PerVector_K1| | PerShuffle| PerShuffle| _NBlock_NXdlPerWave_NWaveNPerXdl| _NWaveNPerXdl| Scheduler| Verision| + //#########################| | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | +#if defined(__gfx950__) || defined(CK_USE_NATIVE_MX_SUPPORT) + DeviceGemmMX_Xdl_CShuffleV3< Row, Col, Row, F8, E8M0, F8, E8M0, BF16, F32, BF16, PassThrough, PassThrough, PassThrough, GemmSpec, ScaleBlockSize, 128, 128, 16, 128, 16, 16, 16, 16, 4, 1, S<8, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, 0, S<8, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, 0, 1, 1, S<1, 16, 1, 8>, 2, BlkGemmPipeSched, BlockGemmPipelineVersion::v1>, + DeviceGemmMX_Xdl_CShuffleV3< Row, Col, Row, F8, E8M0, F8, E8M0, BF16, F32, BF16, PassThrough, PassThrough, PassThrough, GemmSpec, ScaleBlockSize, 256, 128, 128, 256, 16, 16, 32, 32, 2, 2, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, false, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, false, 1, 1, S<1, 32, 1, 8>, 8, BlkGemmPipeSched, BlockGemmPipelineVersion::v1>, + DeviceGemmMX_Xdl_CShuffleV3< Row, Col, Row, F8, E8M0, F8, E8M0, BF16, F32, BF16, PassThrough, PassThrough, PassThrough, GemmSpec, ScaleBlockSize, 256, 256, 128, 64, 16, 16, 32, 32, 4, 2, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, 0, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, 0, 1, 1, S<1, 32, 1, 8>, 8, BlkGemmPipeSched, BlockGemmPipelineVersion::v1>, + DeviceGemmMX_Xdl_CShuffleV3< Row, Col, Row, F8, E8M0, F8, E8M0, BF16, F32, BF16, PassThrough, 
PassThrough, PassThrough, GemmSpec, ScaleBlockSize, 256, 128, 128, 128, 16, 16, 32, 32, 2, 2, S<8, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, 0, S<8, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, 0, 1, 1, S<1, 32, 1, 8>, 8, BlkGemmPipeSched, BlockGemmPipelineVersion::v1>, + DeviceGemmMX_Xdl_CShuffleV3< Row, Col, Row, F8, E8M0, F8, E8M0, BF16, F32, BF16, PassThrough, PassThrough, PassThrough, GemmSpec, ScaleBlockSize, 64, 16, 16, 512, 16, 16, 16, 16, 1, 1, S<8, 8, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, 0, S<8, 8, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, 0, 1, 1, S<1, 16, 1, 4>, 4, BlkGemmPipeSched, BlockGemmPipelineVersion::v1> + +//Require verification + //DeviceGemmMX_Xdl_CShuffleV3< Row, Col, Row, F8, E8M0, F8, E8M0, BF16, F32, BF16, PassThrough, PassThrough, PassThrough, GemmSpec, ScaleBlockSize, 256, 256, 256, 128, 16, 16, 16, 16, 8, 8, S<8, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, 0, S<8, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, 0, 1, 2, S<1, 32, 1, 8>, 8, BlkGemmPipeSched, BlockGemmPipelineVersion::v1> +#endif + // clang-format on + >; + +} // namespace instance +} // namespace device +} // namespace tensor_operation +} // namespace ck diff --git a/library/src/tensor_operation_instance/gpu/gemm_mx/device_gemm_mx_xdl_f8_f8_bf16/device_gemm_mx_xdl_f8_f8_bf16_mk_nk_mn_default_instance.cpp b/library/src/tensor_operation_instance/gpu/gemm_mx/device_gemm_mx_xdl_f8_f8_bf16/device_gemm_mx_xdl_f8_f8_bf16_mk_nk_mn_default_instance.cpp new file mode 100644 index 0000000000..05914e06b5 --- /dev/null +++ b/library/src/tensor_operation_instance/gpu/gemm_mx/device_gemm_mx_xdl_f8_f8_bf16/device_gemm_mx_xdl_f8_f8_bf16_mk_nk_mn_default_instance.cpp @@ -0,0 +1,32 @@ +// SPDX-License-Identifier: MIT +// Copyright (c) 2025, Advanced Micro Devices, Inc. All rights reserved. 
+ +#include "device_gemm_mx_xdl_f8_f8_bf16_mk_nk_mn.hpp" + +namespace ck { +namespace tensor_operation { +namespace device { +namespace instance { + +void add_device_gemm_mx_xdl_f8_f8_bf16_mk_nk_mn_default_instances( + std::vector>>& instances) +{ + add_device_operation_instances( + instances, device_gemm_mx_xdl_f8_f8_bf16_mk_nk_mn_instances{}); +} + +} // namespace instance +} // namespace device +} // namespace tensor_operation +} // namespace ck diff --git a/library/src/tensor_operation_instance/gpu/gemm_mx/device_gemm_mx_xdl_f8_f8_f16/device_gemm_mx_xdl_f8_f8_f16_mk_nk_mn.hpp b/library/src/tensor_operation_instance/gpu/gemm_mx/device_gemm_mx_xdl_f8_f8_f16/device_gemm_mx_xdl_f8_f8_f16_mk_nk_mn.hpp new file mode 100644 index 0000000000..0ca4f2a3ce --- /dev/null +++ b/library/src/tensor_operation_instance/gpu/gemm_mx/device_gemm_mx_xdl_f8_f8_f16/device_gemm_mx_xdl_f8_f8_f16_mk_nk_mn.hpp @@ -0,0 +1,63 @@ +// SPDX-License-Identifier: MIT +// Copyright (c) 2025, Advanced Micro Devices, Inc. All rights reserved. 
+ +#include "ck/ck.hpp" +#include "ck/tensor_operation/gpu/device/tensor_layout.hpp" +#include "ck/tensor_operation/gpu/device/gemm_specialization.hpp" +#include "ck/tensor_operation/gpu/device/impl/device_gemm_xdl_cshuffle_v3_mx.hpp" + +#include "ck/library/tensor_operation_instance/add_device_operation_instance.hpp" + +namespace ck { +namespace tensor_operation { +namespace device { +namespace instance { + +using F8 = f8_t; +using F16 = half_t; +using BF16 = bhalf_t; +using F32 = float; +using E8M0 = ck::e8m0_bexp_t; + +using Row = tensor_layout::gemm::RowMajor; +using Col = tensor_layout::gemm::ColumnMajor; + +template +using S = Sequence; + +using PassThrough = element_wise::PassThrough; + +static constexpr auto GemmDefault = GemmSpecialization::Default; +static constexpr auto GemmKPadding = GemmSpecialization::KPadding; +static constexpr auto GemmMNPadding = GemmSpecialization::MNPadding; +static constexpr auto GemmMNKPadding = GemmSpecialization::MNKPadding; + +static constexpr auto Intrawave = BlockGemmPipelineScheduler::Intrawave; +static constexpr auto Interwave = BlockGemmPipelineScheduler::Interwave; + +static constexpr auto ScaleBlockSize = 32; + +template +using device_gemm_mx_xdl_f8_f8_f16_mk_nk_mn_instances = std::tuple< +// clang-format off + //#########################| ALayout| BLayout| CLayout|AData|AScale|BData|BScale| CData| AccData| Cshuffle| A| B| C| GEMM| Scale Block| Block| MPer| NPer| KPer| AK1| BK1|MPer| NPer| MXdl| NXdl| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockLds| BBlockTransfer| BBlockTransfer| BBlockTransfer| BlockTransfer| BBlockTransfer| BBlockTransfer| BBlockLds| CShuffle| CShuffle| CBlockTransferClusterLengths| CBlockTransfer| Block-wiseGemm| Block-wiseGemm| + //#########################| | | | Type| Data| Type| Data| Type| Type| Type| Elementwise| Elementwise| Elementwise|Specialization| Size| Size| Block| Block| Block| | | XDL| XDL| Per| Per| ThreadCluster| 
ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraM| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraN| MXdlPerWave| NXdlPerWave| _MBlock_MXdlPerWave_MWaveMPerXdl| ScalarPerVector| Pipeline| Pipeline| + //#########################| | | | | Type| | Type| | | | Operation| Operation| Operation| | | | | | | | | | | Wave| Wave| Lengths_K0_M_K1| ArrangeOrder| | | PerVector| PerVector_K1| | Lengths_K0_N_K1| ArrangeOrder| | | PerVector| PerVector_K1| | PerShuffle| PerShuffle| _NBlock_NXdlPerWave_NWaveNPerXdl| _NWaveNPerXdl| Scheduler| Verision| + //#########################| | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | +#if defined(__gfx950__) || defined(CK_USE_NATIVE_MX_SUPPORT) + DeviceGemmMX_Xdl_CShuffleV3< Row, Col, Row, F8, E8M0, F8, E8M0, F16, F32, F16, PassThrough, PassThrough, PassThrough, GemmSpec, ScaleBlockSize, 128, 128, 16, 128, 16, 16, 16, 16, 4, 1, S<8, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, 0, S<8, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, 0, 1, 1, S<1, 16, 1, 8>, 2, BlkGemmPipeSched, BlockGemmPipelineVersion::v1>, + DeviceGemmMX_Xdl_CShuffleV3< Row, Col, Row, F8, E8M0, F8, E8M0, F16, F32, F16, PassThrough, PassThrough, PassThrough, GemmSpec, ScaleBlockSize, 256, 128, 128, 256, 16, 16, 32, 32, 2, 2, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, false, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, false, 1, 1, S<1, 32, 1, 8>, 8, BlkGemmPipeSched, BlockGemmPipelineVersion::v1>, + DeviceGemmMX_Xdl_CShuffleV3< Row, Col, Row, F8, E8M0, F8, E8M0, F16, F32, F16, PassThrough, PassThrough, PassThrough, GemmSpec, ScaleBlockSize, 256, 256, 128, 64, 16, 16, 32, 32, 4, 2, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, 0, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, 0, 1, 1, S<1, 32, 1, 8>, 8, BlkGemmPipeSched, BlockGemmPipelineVersion::v1>, + DeviceGemmMX_Xdl_CShuffleV3< Row, Col, Row, F8, E8M0, F8, E8M0, F16, F32, F16, PassThrough, PassThrough, 
PassThrough, GemmSpec, ScaleBlockSize, 256, 128, 128, 128, 16, 16, 32, 32, 2, 2, S<8, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, 0, S<8, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, 0, 1, 1, S<1, 32, 1, 8>, 8, BlkGemmPipeSched, BlockGemmPipelineVersion::v1>, + DeviceGemmMX_Xdl_CShuffleV3< Row, Col, Row, F8, E8M0, F8, E8M0, F16, F32, F16, PassThrough, PassThrough, PassThrough, GemmSpec, ScaleBlockSize, 64, 16, 16, 512, 16, 16, 16, 16, 1, 1, S<8, 8, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, 0, S<8, 8, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, 0, 1, 1, S<1, 16, 1, 4>, 4, BlkGemmPipeSched, BlockGemmPipelineVersion::v1> + + //Require verification + //DeviceGemmMX_Xdl_CShuffleV3< Row, Col, Row, F8, E8M0, F8, E8M0, F16, F32, F16, PassThrough, PassThrough, PassThrough, GemmSpec, ScaleBlockSize, 256, 256, 256, 128, 16, 16, 16, 16, 8, 8, S<8, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, 0, S<8, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, 0, 1, 2, S<1, 32, 1, 8>, 8, BlkGemmPipeSched, BlockGemmPipelineVersion::v1>, +#endif + // clang-format on + >; + +} // namespace instance +} // namespace device +} // namespace tensor_operation +} // namespace ck diff --git a/library/src/tensor_operation_instance/gpu/gemm_mx/device_gemm_mx_xdl_f8_f8_f16/device_gemm_mx_xdl_f8_f8_f16_mk_nk_mn_default_instance.cpp b/library/src/tensor_operation_instance/gpu/gemm_mx/device_gemm_mx_xdl_f8_f8_f16/device_gemm_mx_xdl_f8_f8_f16_mk_nk_mn_default_instance.cpp new file mode 100644 index 0000000000..f4e59cf92d --- /dev/null +++ b/library/src/tensor_operation_instance/gpu/gemm_mx/device_gemm_mx_xdl_f8_f8_f16/device_gemm_mx_xdl_f8_f8_f16_mk_nk_mn_default_instance.cpp @@ -0,0 +1,32 @@ +// SPDX-License-Identifier: MIT +// Copyright (c) 2025, Advanced Micro Devices, Inc. All rights reserved. 
+ +#include "device_gemm_mx_xdl_f8_f8_f16_mk_nk_mn.hpp" + +namespace ck { +namespace tensor_operation { +namespace device { +namespace instance { + +void add_device_gemm_mx_xdl_f8_f8_f16_mk_nk_mn_default_instances( + std::vector>>& instances) +{ + add_device_operation_instances( + instances, device_gemm_mx_xdl_f8_f8_f16_mk_nk_mn_instances{}); +} + +} // namespace instance +} // namespace device +} // namespace tensor_operation +} // namespace ck diff --git a/test/CMakeLists.txt b/test/CMakeLists.txt index 18611d8052..72c51823be 100755 --- a/test/CMakeLists.txt +++ b/test/CMakeLists.txt @@ -279,6 +279,7 @@ if(SUPPORTED_GPU_TARGETS MATCHES "gfx942" OR SUPPORTED_GPU_TARGETS MATCHES "gfx9 endif() if(SUPPORTED_GPU_TARGETS MATCHES "gfx950") add_subdirectory(mx_mfma_op) + add_subdirectory(gemm_mx) endif() add_subdirectory(position_embedding) add_subdirectory(scatter_gather) diff --git a/test/gemm_mx/CMakeLists.txt b/test/gemm_mx/CMakeLists.txt new file mode 100644 index 0000000000..71a0a98f2d --- /dev/null +++ b/test/gemm_mx/CMakeLists.txt @@ -0,0 +1,4 @@ +add_gtest_executable(test_gemm_mx test_gemm_mx.cpp) +if(result EQUAL 0) + target_link_libraries(test_gemm_mx PRIVATE utility device_gemm_mx_instance) + endif() diff --git a/test/gemm_mx/test_gemm_mx.cpp b/test/gemm_mx/test_gemm_mx.cpp new file mode 100644 index 0000000000..6e1957e60a --- /dev/null +++ b/test/gemm_mx/test_gemm_mx.cpp @@ -0,0 +1,108 @@ +// SPDX-License-Identifier: MIT +// Copyright (c) 2025, Advanced Micro Devices, Inc. All rights reserved. 
+ +#include + +#include "gtest/gtest.h" + +#include "test_gemm_mx_util.hpp" + +using E8M0 = ck::e8m0_bexp_t; +using F8 = ck::f8_t; +using BF8 = ck::bf8_t; +using F6 = ck::f6_t; +using BF6 = ck::bf6_t; +using F4 = ck::f4_t; +using F16 = ck::half_t; +using BF16 = ck::bhalf_t; +using F32 = float; + +using Row = ck::tensor_layout::gemm::RowMajor; +using Col = ck::tensor_layout::gemm::ColumnMajor; + +namespace { + +template +struct tuple_concat; + +template +struct tuple_concat, std::tuple> +{ + using type = std::tuple; +}; + +} // namespace + +template +class TestGemmMX_MK_NK + : public ck::test::TestGemmMX, Tuple>::type> +{ +}; + +// clang-format off +using KernelTypes_MK_NK = ::testing::Types< +#if defined(CK_ENABLE_FP8) + // ADataType, BDataType, CDataType, ScaleBlockSize + std::tuple< F8, F8, F16, ck::Number<32> >, + std::tuple< F8, F8, BF16, ck::Number<32> > +#endif + >; +// clang-format on + +TYPED_TEST_SUITE(TestGemmMX_MK_NK, KernelTypes_MK_NK); + +TYPED_TEST(TestGemmMX_MK_NK, SmallM) +{ + std::vector Ms{1, 2, 3, 4, 5, 6}; + constexpr int N = 256; + constexpr int K = 512; + + constexpr int StrideA = K; + constexpr int StrideB = K; + constexpr int StrideC = N; + + for(int M : Ms) + this->Run(M, N, K, StrideA, StrideB, StrideC); +} + +TYPED_TEST(TestGemmMX_MK_NK, MidLargeM) +{ + std::vector Ms{127, 255, 312, 799, 1573}; + constexpr int N = 256; + constexpr int K = 512; + + constexpr int StrideA = K; + constexpr int StrideB = K; + constexpr int StrideC = N; + + for(int M : Ms) + this->Run(M, N, K, StrideA, StrideB, StrideC); +} + +TYPED_TEST(TestGemmMX_MK_NK, Regular) +{ + std::vector Ms{3840}; + constexpr int N = 512; + constexpr int K = 1024; + + constexpr int StrideA = K; + constexpr int StrideB = K; + constexpr int StrideC = N; + + for(int M : Ms) + this->Run(M, N, K, StrideA, StrideB, StrideC); +} + +TYPED_TEST(TestGemmMX_MK_NK, Large) +{ + std::vector Ms{4096}; + constexpr int N = 3840; + constexpr int K = 4096; + + constexpr int StrideA = K; + constexpr int 
StrideB = K; + constexpr int StrideC = N; + + for(int M : Ms) + this->Run(M, N, K, StrideA, StrideB, StrideC); +} diff --git a/test/gemm_mx/test_gemm_mx_util.hpp b/test/gemm_mx/test_gemm_mx_util.hpp new file mode 100644 index 0000000000..3bca4ceded --- /dev/null +++ b/test/gemm_mx/test_gemm_mx_util.hpp @@ -0,0 +1,498 @@ +// SPDX-License-Identifier: MIT +// Copyright (c) 2025, Advanced Micro Devices, Inc. All rights reserved. + +#pragma once + +#include + +#include "ck/utility/data_type.hpp" +#include "ck/tensor_operation/gpu/device/tensor_layout.hpp" +#include "ck/utility/number.hpp" +#include "ck/library/utility/literals.hpp" +#include "ck/library/utility/host_tensor.hpp" +#include "ck/library/utility/host_tensor_generator.hpp" +#include "ck/tensor_operation/gpu/element/unary_element_wise_operation.hpp" +#include "ck/library/utility/device_memory.hpp" +#include "ck/library/utility/fill.hpp" +#include "ck/tensor_operation/gpu/device/device_gemm_mx.hpp" +#include "ck/library/tensor_operation_instance/gpu/gemm_mx.hpp" +#include "ck/library/reference_tensor_operation/cpu/reference_mx_gemm.hpp" +#include "ck/library/utility/check_err.hpp" + +namespace ck { +namespace test { + +namespace { +using Row = ck::tensor_layout::gemm::RowMajor; +using Col = ck::tensor_layout::gemm::ColumnMajor; +} // namespace + +template +bool profile_gemm_mx_impl(int do_verification, + int init_method, + bool do_log, + bool time_kernel, + int M, + int N, + int K, + int StrideA, + int StrideB, + int StrideC, + int KBatch, + int n_warmup, + int n_iter, + uint64_t rotating = 0) +{ + if(K % ScaleBlockSize != 0) + { + throw std::runtime_error("wrong! 
K must be multiple of ScaleBlockSize."); + }; + + using ScaleDataType = e8m0_bexp_t; + using AScaleLayout = Row; + using BScaleLayout = Col; + + bool pass = true; + + auto f_host_tensor_descriptor = + [](std::size_t row, std::size_t col, std::size_t stride, auto layout) { + using namespace ck::literals; + + if(is_same::value) + { + return HostTensorDescriptor({row, col}, {stride, 1_uz}); + } + else + { + return HostTensorDescriptor({row, col}, {1_uz, stride}); + } + }; + auto f_get_default_stride = + [](ck::index_t row, ck::index_t col, ck::index_t stride, auto layout) { + if(stride == -1) + { + // give a chance if stride is -1, return a default packed stride + if constexpr(std::is_same_v) + { + return static_cast(col); + } + else + { + return static_cast(row); + } + } + else + return static_cast(stride); + }; + + auto Scale_Stride_AM = f_get_default_stride(M, K / ScaleBlockSize, -1, AScaleLayout{}); + auto Scale_Stride_BN = f_get_default_stride(K / ScaleBlockSize, N, -1, BScaleLayout{}); + + Tensor a_m_k(f_host_tensor_descriptor(M, K, StrideA, ALayout{})); + Tensor b_k_n(f_host_tensor_descriptor(K, N, StrideB, BLayout{})); + + Tensor a_m_k_scale(f_host_tensor_descriptor( + M, K / ScaleBlockSize, Scale_Stride_AM, AScaleLayout{})); // scales for A + Tensor b_k_n_scale(f_host_tensor_descriptor( + K / ScaleBlockSize, N, Scale_Stride_BN, BScaleLayout{})); // scales for B + + Tensor c_m_n_host_result(f_host_tensor_descriptor(M, N, StrideC, CLayout{})); + Tensor c_m_n_device_result(f_host_tensor_descriptor(M, N, StrideC, CLayout{})); + + std::size_t total_gemm_needed = + a_m_k.GetElementSpaceSizeInBytes() + b_k_n.GetElementSpaceSizeInBytes() + + a_m_k_scale.GetElementSpaceSizeInBytes() + b_k_n_scale.GetElementSpaceSizeInBytes(); + int rotating_count = std::max( + 1, + std::min(n_iter, + static_cast(std::ceil(static_cast(rotating) / total_gemm_needed)))); + + std::cout << "a_m_k: " << a_m_k.mDesc << std::endl; + std::cout << "a_m_k_scale: " << a_m_k_scale.mDesc << 
std::endl; + std::cout << "b_k_n: " << b_k_n.mDesc << std::endl; + std::cout << "b_k_n_scale: " << b_k_n_scale.mDesc << std::endl; + std::cout << "c_m_n: " << c_m_n_device_result.mDesc << std::endl; + std::cout << "rotating count: " << rotating_count << std::endl; + + switch(init_method) + { + case 0: // Initializations for development and debugging + ck::utils::FillConstant{ck::type_convert(1.0f)}(a_m_k); + ck::utils::FillConstant{ck::type_convert(2.0f)}(a_m_k_scale); + ck::utils::FillConstant{ck::type_convert(0.5f)}(b_k_n); + ck::utils::FillConstant{ck::type_convert(1.0f)}(b_k_n_scale); + if(do_log) + { + std::cout << "Init A = {1}" << std::endl; + std::cout << "Init A scale = {2.0}" << std::endl; + std::cout << "Init B = {0.5}" << std::endl; + std::cout << "Init B scale = {1.0}" << std::endl; + std::cout << "Expect C = {K}" << std::endl; + } + break; + + case 1: + + a_m_k.GenerateTensorValue(GeneratorTensor_2{-4, 5}); // Z[-4,4] + b_k_n.GenerateTensorValue(GeneratorTensor_2{-4, 5}); // Z[-4,4] + + a_m_k_scale.GenerateTensorValue( + GeneratorTensor_2{125, 129}); // scales: {0.25, 0.5, 1, 2} + b_k_n_scale.GenerateTensorValue( + GeneratorTensor_2{125, 129}); // scales: {0.25, 0.5, 1, 2} + + break; + + default: + a_m_k.GenerateTensorValue(GeneratorTensor_3{-2.0, 2.0}); + a_m_k_scale.GenerateTensorValue( + GeneratorTensor_3{powf(2.0f, -125.0f), 1.0f}); // R[2^-125, 1] + + b_k_n.GenerateTensorValue(GeneratorTensor_3{-2.0, 2.0}); + b_k_n_scale.GenerateTensorValue( + GeneratorTensor_3{powf(2.0f, -125.0f), 1.0f}); + break; + } + + using AElementOp = ck::tensor_operation::element_wise::PassThrough; + using BElementOp = ck::tensor_operation::element_wise::PassThrough; + using CElementOp = ck::tensor_operation::element_wise::PassThrough; + + const auto a_element_op = AElementOp{}; + const auto b_element_op = BElementOp{}; + const auto c_element_op = CElementOp{}; + + if(do_log > 0) + std::cout << "Device memory allocation..." 
<< std::endl; + + DeviceMem a_device_buf(sizeof(ADataType) * a_m_k.mDesc.GetElementSpaceSize()); + DeviceMem a_scale_device_buf(sizeof(ScaleDataType) * a_m_k_scale.mDesc.GetElementSpaceSize()); + DeviceMem b_device_buf(sizeof(BDataType) * b_k_n.mDesc.GetElementSpaceSize()); + DeviceMem b_scale_device_buf(sizeof(ScaleDataType) * b_k_n_scale.mDesc.GetElementSpaceSize()); + DeviceMem c_device_buf(sizeof(CDataType) * c_m_n_device_result.mDesc.GetElementSpaceSize()); + + if(do_log > 0) + std::cout << "Upload data to device..." << std::endl; + a_device_buf.ToDevice(a_m_k.mData.data()); + a_scale_device_buf.ToDevice(a_m_k_scale.mData.data()); + b_device_buf.ToDevice(b_k_n.mData.data()); + b_scale_device_buf.ToDevice(b_k_n_scale.mData.data()); + + if(do_log > 0) + std::cout << "Done." << std::endl; + + using DeviceOp = ck::tensor_operation::device::DeviceGemmMX; + + // get device op instances + const auto op_ptrs = ck::tensor_operation::device::instance::DeviceOperationInstanceFactory< + DeviceOp>::GetInstances(); + + std::cout << "found " << op_ptrs.size() << " instances" << std::endl; + + // Run reference GEMM + if(do_verification) + { + using ReferenceGemmInstance = + ck::tensor_operation::host::ReferenceMXGemm; + + auto ref_gemm = ReferenceGemmInstance{}; + auto ref_invoker = ref_gemm.MakeInvoker(); + + auto ref_argument = ref_gemm.MakeArgument(a_m_k, + a_m_k_scale, + b_k_n, + b_k_n_scale, + c_m_n_host_result, + a_element_op, + b_element_op, + c_element_op); + + ref_invoker.Run(ref_argument); + } + + std::string best_op_name; + std::optional best_op_object_name; + float best_ave_time = 0; + float best_tflops = 0; + float best_gb_per_sec = 0; + float best_kbatch = 0; + + // profile device GEMM instances + for(auto& op_ptr : op_ptrs) + { + std::vector kbatch_list = {1, 2, 4, 8, 16, 19, 32, 38}; // use these when KBatch <= 0 + + if(KBatch > 0) + { + kbatch_list = {KBatch}; + } + + for(std::size_t i = 0; i < kbatch_list.size(); i++) + { + auto kbatch_curr = kbatch_list[i]; 
+ + auto argument_ptr = op_ptr->MakeArgumentPointer( + static_cast(a_device_buf.GetDeviceBuffer()), + static_cast(a_scale_device_buf.GetDeviceBuffer()), + static_cast(b_device_buf.GetDeviceBuffer()), + static_cast(b_scale_device_buf.GetDeviceBuffer()), + static_cast(c_device_buf.GetDeviceBuffer()), + M, + N, + K, + StrideA, + Scale_Stride_AM, + StrideB, + Scale_Stride_BN, + StrideC, + kbatch_curr, + a_element_op, + b_element_op, + c_element_op); + + auto invoker_ptr = op_ptr->MakeInvokerPointer(); + + if(op_ptr->IsSupportedArgument(argument_ptr.get())) + { + + // re-init C to zero before profiling next kernel + c_device_buf.SetZero(); + + invoker_ptr->Run(argument_ptr.get(), + StreamConfig{nullptr, false, 0, n_warmup, n_iter}); + + if(do_verification) + { + c_device_buf.FromDevice(c_m_n_device_result.mData.data()); + + if(do_log) + { + + if(init_method == 0) + { + auto expected = static_cast(K); + auto computed = type_convert(c_m_n_device_result(0, 12)); + + pass = pass & (std::abs(expected - computed) <= 0.0f); + std::cout << "\nExpected vs Computed: " << expected << " vs " + << computed << ((pass) ? 
" (PASSED!)" : " (FAILED!)") + << std::endl + << std::endl; + } + else + { + LogRangeAsType(std::cout << "a : ", a_m_k.mData, ",") + << std::endl; + LogRangeAsType(std::cout << "a_scale : ", a_m_k_scale.mData, ",") + << std::endl; + LogRangeAsType(std::cout << "b: ", b_k_n.mData, ",") + << std::endl; + LogRangeAsType(std::cout << "b_scale: ", b_k_n_scale.mData, ",") + << std::endl; + LogRangeAsType( + std::cout << "c_host : ", c_m_n_host_result.mData, ",") + << std::endl; + LogRangeAsType( + std::cout << "c_device: ", c_m_n_device_result.mData, ",") + << std::endl; + } + } + + pass = pass & ck::utils::check_err(c_m_n_device_result, c_m_n_host_result); + } + + std::string op_name = op_ptr->GetTypeString(); + std::optional op_obj_name = op_ptr->GetObjectName(); + + float ave_time = invoker_ptr->Run(argument_ptr.get(), + StreamConfig{nullptr, + time_kernel, + 0, + n_warmup, + n_iter, + rotating_count > 1, + rotating_count}); + + // Output size(M*N) * [dot product(2K) + product of scales(K/ScaleBlockSize) + + // scaling of partial sums(K/ScaleBlockSize)] + // FLOPS = 2 * M * N * K + 2 * M * N * K / ScaleBlockSize + std::size_t flop = + std::size_t(2) * M * N * K + std::size_t(2) * M * N * K / ScaleBlockSize; + + std::size_t num_btype = sizeof(ADataType) * M * K + sizeof(BDataType) * K * N + + sizeof(CDataType) * M * N + + sizeof(ScaleDataType) * (M * K + K * N) / ScaleBlockSize; + + float tflops = static_cast(flop) / 1.E9 / ave_time; + + float gb_per_sec = num_btype / 1.E6 / ave_time; + + std::cout << "Perf: " << std::setw(10) << ave_time << " ms, " << tflops + << " TFlops, " << gb_per_sec << " GB/s, " << op_name << ", KBatch " + << kbatch_curr << std::endl; + + if(tflops > best_tflops && ave_time > 1e-10) + { + best_op_name = op_name; + best_op_object_name = op_obj_name; + best_tflops = tflops; + best_ave_time = ave_time; + best_gb_per_sec = gb_per_sec; + best_kbatch = kbatch_curr; + } + } + else + { + std::cout << op_ptr->GetTypeString() << " does not support this 
problem" + << std::endl; + } + } + } + + if constexpr(is_same::value) + { + std::cout << "Best Perf for datatype = f32"; + } + else if constexpr(is_same::value) + { + std::cout << "Best Perf for datatype = f16"; + } + else if constexpr(is_same::value) + { + std::cout << "Best Perf for datatype = bf16"; + } + else if constexpr(is_same::value) + { + std::cout << "Best Perf for datatype = int8"; + } + + if constexpr(is_same::value) + { + std::cout << " ALayout = RowMajor"; + } + else if constexpr(is_same::value) + { + std::cout << " ALayout = ColumnMajor"; + } + + if constexpr(is_same::value) + { + std::cout << " BLayout = RowMajor"; + } + else if constexpr(is_same::value) + { + std::cout << " BLayout = ColumnMajor"; + } + + std::cout << " M = " << M << " N = " << N << " K = " << K << " StrideA = " << StrideA + << " StrideB = " << StrideB << " StrideC = " << StrideC << " KBatch = " << best_kbatch + << " : " << best_ave_time << " ms, " << best_tflops << " TFlops, " << best_gb_per_sec + << " GB/s, " << best_op_name << std::endl; + + if(best_op_object_name) + std::cout << best_op_object_name.value() << std::endl; + + return pass; +} + +template +class TestGemmMX : public testing::Test +{ + using Row = ck::tensor_layout::gemm::RowMajor; + using F32 = float; + using ScaleType = e8m0_bexp_t; + + protected: + using ALayout = std::tuple_element_t<0, Tuple>; + using BLayout = std::tuple_element_t<1, Tuple>; + using CLayout = Row; + using ADataType = std::tuple_element_t<2, Tuple>; + using BDataType = std::tuple_element_t<3, Tuple>; + using CDataType = std::tuple_element_t<4, Tuple>; + using AccDataType = float; + + public: + static constexpr index_t ScaleBlockSize = std::tuple_element_t<5, Tuple>{}; + static constexpr bool verify_ = true; + static constexpr int init_method_ = 2; // decimal value initialization + static constexpr bool log_ = false; + static constexpr bool bench_ = false; // measure kernel performance + std::vector k_batches_; + + void SetUp() override { 
k_batches_ = {1}; } + + void Run(const int M, + const int N, + const int K, + const int StrideA, + const int StrideB, + const int StrideC) + { + for(auto kb : k_batches_) + { + RunSingle(M, N, K, StrideA, StrideB, StrideC, kb); + } + } + + void RunSingle(const int M, + const int N, + const int K, + const int StrideA, + const int StrideB, + const int StrideC, + int kbatch = 1, + int n_warmup = 1, + int n_iter = 10) + { + bool pass = ck::test::profile_gemm_mx_impl(verify_, + init_method_, + log_, + bench_, + M, + N, + K, + StrideA, + StrideB, + StrideC, + kbatch, + n_warmup, + n_iter); + EXPECT_TRUE(pass); + } +}; + +} // namespace test +} // namespace ck From bcf5bb41be976d948b504f3d66c29e5baa82618a Mon Sep 17 00:00:00 2001 From: lalala-sh Date: Fri, 18 Apr 2025 10:45:49 +0800 Subject: [PATCH 052/443] enable do top k weights in moe stage1 gemm (#2094) * add switch for mul topk weights * fix bf16/f16 bugs * complete --- .../moe_gemm1_xdl_fp8.cpp | 64 +++++++++++-- .../moe_gemm1_xdl_pk_i4.cpp | 63 +++++++++++-- .../moe_gemm2_xdl_fp8.cpp | 8 +- .../moe_gemm2_xdl_pk_i4.cpp | 8 +- .../gpu/device/impl/device_moe_gemm.hpp | 8 +- .../gpu/grid/gridwise_moe_gemm.hpp | 93 +++++++++++-------- .../cpu/reference_moe_gemm.hpp | 15 ++- .../cpu/reference_moe_gemm2.hpp | 12 ++- 8 files changed, 203 insertions(+), 68 deletions(-) diff --git a/example/65_gemm_multiply_multiply/moe_gemm1_xdl_fp8.cpp b/example/65_gemm_multiply_multiply/moe_gemm1_xdl_fp8.cpp index 66825edcf9..f594080755 100644 --- a/example/65_gemm_multiply_multiply/moe_gemm1_xdl_fp8.cpp +++ b/example/65_gemm_multiply_multiply/moe_gemm1_xdl_fp8.cpp @@ -1,5 +1,5 @@ // SPDX-License-Identifier: MIT -// Copyright (c) 2024, Advanced Micro Devices, Inc. All rights reserved. +// Copyright (c) 2024-2025, Advanced Micro Devices, Inc. All rights reserved. 
#include #include @@ -39,14 +39,16 @@ using AccDataType = F32; using CShuffleDataType = F32; using D0DataType = F32; using D1DataType = F32; -using DsDataType = ck::Tuple; +using D2DataType = F32; +using DsDataType = ck::Tuple; using A0Layout = Row; using B0Layout = Col; using ELayout = Row; using D0Layout = Row; using D1Layout = Col; -using DsLayout = ck::Tuple; +using D2Layout = ELayout; +using DsLayout = ck::Tuple; // for gate, a_scale, b_scale struct MulABScale @@ -83,9 +85,36 @@ struct MulABScaleSilu } }; +struct MulABScaleExpertWeight +{ + template + __host__ __device__ constexpr void + operator()(E& e, const C& c, const D0& d0, const D1& d1, const D2& d2) const; + // for real kernel use + template <> + __host__ __device__ constexpr void operator()( + EDataType& e, const float& c, const float& d0, const float& d1, const float& d2) const + { + // for real kernel use + // warning: hack hack hack here!!!! ignore d0 right now as kernel mul d0 * d2 outside. + // tofix:felix + (void)d2; + e = ck::type_convert(c * d1 * d0); + } + // for reference cpu + template <> + __host__ __device__ constexpr void operator()( + float& e, const float& c, const float& d0, const float& d1, const float& d2) const + { + // for reference cpu + e = ck::type_convert(c * d0 * d1 * d2); + } +}; + +using CDEElementOp = MulABScaleExpertWeight; // combine MulRoutedWeight = true // using DsLayout = DsLayoutGate; // using DsDataType = DsDataTypeGate; -using CDEElementOp = MulABScale; +// using CDEElementOp = MulABScale; // combine MulRoutedWeight = false // using CDEElementOp = MulABScaleSiluMulGate; @@ -133,11 +162,13 @@ static constexpr ck::index_t NPerBlock = 128; static constexpr ck::index_t MNPerXDL = 32; static constexpr ck::index_t KPerBlock = 128 / sizeof(A0DataType); static constexpr ck::index_t Nswizzle = true; +static constexpr bool MulRoutedWeight = false; static constexpr ck::index_t AK1 = 16 / sizeof(A0DataType); static constexpr ck::index_t BK1 = 16 / sizeof(B0DataType); static 
constexpr ck::index_t EVec = 16 / sizeof(EDataType); static constexpr ck::index_t D0Vec = 1; static constexpr ck::index_t D1Vec = 1; +static constexpr ck::index_t D2Vec = 1; // using DeviceOpInstance = ck::tensor_operation::device::DeviceGemmMultiD_Xdl_CShuffle_V3 using DeviceOpInstance = ck::tensor_operation::device::DeviceMoeGemm // clang-format off @@ -157,8 +188,8 @@ using DeviceOpInstance = ck::tensor_operation::device::DeviceMoeGemm // CShuffle| CShuffle| CBlockTransferClusterLengths| CBlockTransfer| // MXdlPerWave| NXdlPerWave| _MBlock_MWaveMPerXdl| ScalarPerVector| // PerShuffle| PerShuffle| _NBlock_NWaveNPerXdl| _NWaveNPerXdl| - 2, 1, S<1, 32, 1, 8>, S, - ck::BlockGemmPipelineScheduler::Intrawave, ck::BlockGemmPipelineVersion::v1, Nswizzle, true, A0DataType>; + 2, 1, S<1, 32, 1, 8>, S, + ck::BlockGemmPipelineScheduler::Intrawave, ck::BlockGemmPipelineVersion::v1, Nswizzle, true, MulRoutedWeight, A0DataType>; // clang-format on @@ -224,7 +255,7 @@ int main(int argc, char* argv[]) ck::index_t StrideB = K; ck::index_t StrideE = N; constexpr ck::index_t NumDTensor = DsDataType::Size(); - constexpr auto StrideDs = std::array{1, 0}; + constexpr auto StrideDs = std::array{0, 0, 0}; ck::index_t KBatch = 1; @@ -266,6 +297,7 @@ int main(int argc, char* argv[]) Tensor b0_preshuffled(HostTensorDescriptor({experts, K, N}, {N * K, 1, K})); Tensor d0_t_n(HostTensorDescriptor({tokens, N}, {StrideDs[0], 0})); Tensor d1_e_n(HostTensorDescriptor({experts, N}, {1, StrideDs[1]})); + Tensor d2_e_n(HostTensorDescriptor({sorted_size, N}, {1, 0})); Tensor e_t_n_host_result(HostTensorDescriptor({tokens, topk, N}, {topk * N, N, 1})); Tensor e_t_n_device_result( HostTensorDescriptor({tokens, topk, N}, {topk * N, N, 1})); @@ -273,6 +305,7 @@ int main(int argc, char* argv[]) std::cout << "b0_e_n_k: " << b0_e_n_k.mDesc << std::endl; std::cout << "d1_e_n: " << d1_e_n.mDesc << std::endl; std::cout << "d0_t_n: " << d0_t_n.mDesc << std::endl; + std::cout << "d2_e_n: " << d2_e_n.mDesc << 
std::endl; std::cout << "e_t_n: " << e_t_n_host_result.mDesc << std::endl; switch(init_method) @@ -283,24 +316,28 @@ int main(int argc, char* argv[]) b0_e_n_k.GenerateTensorValue(GeneratorTensor_2{-2, 2}); d0_t_n.GenerateTensorValue(GeneratorTensor_2{-2, 2}); d1_e_n.GenerateTensorValue(GeneratorTensor_2{-2, 2}); + d2_e_n.GenerateTensorValue(GeneratorTensor_3{-2, 2}); break; case 2: a0_t_k.GenerateTensorValue(GeneratorTensor_1{}); b0_e_n_k.GenerateTensorValue(GeneratorTensor_1{}); d0_t_n.GenerateTensorValue(GeneratorTensor_1{}); d1_e_n.GenerateTensorValue(GeneratorTensor_1{}); + d2_e_n.GenerateTensorValue(GeneratorTensor_3{}); break; case 3: a0_t_k.GenerateTensorValue(GeneratorTensor_1{}); b0_e_n_k.GenerateTensorValue(GeneratorTensor_2{-2, 2}); d0_t_n.GenerateTensorValue(GeneratorTensor_1{}); d1_e_n.GenerateTensorValue(GeneratorTensor_1{}); + d2_e_n.GenerateTensorValue(GeneratorTensor_3{}); break; default: a0_t_k.GenerateTensorValue(GeneratorTensor_3{0.0, 1.0}); b0_e_n_k.GenerateTensorValue(GeneratorTensor_3{-0.5, 0.5}); d0_t_n.GenerateTensorValue(GeneratorTensor_3{0.0, 1.0}); d1_e_n.GenerateTensorValue(GeneratorTensor_3{0.0, 1.0}); + d2_e_n.GenerateTensorValue(GeneratorTensor_3{0.0, 1.0}); } DeviceMem sorted_token_ids_dev(sizeof(ck::index_t) * sorted_token_ids.mDesc.GetElementSpaceSize()); @@ -310,6 +347,7 @@ int main(int argc, char* argv[]) DeviceMem b0_device_buf(sizeof(B0DataType) * b0_e_n_k.mDesc.GetElementSpaceSize()); DeviceMem d0_device_buf(sizeof(D0DataType) * d0_t_n.mDesc.GetElementSpaceSize()); DeviceMem d1_device_buf(sizeof(D1DataType) * d1_e_n.mDesc.GetElementSpaceSize()); + DeviceMem d2_device_buf(sizeof(D2DataType) * d2_e_n.mDesc.GetElementSpaceSize()); DeviceMem e_device_buf(sizeof(EDataType) * e_t_n_device_result.mDesc.GetElementSpaceSize()); // a0_t_k.savetxt("a.txt"); // d0_t_n.savetxt("d0_t_n.txt", "int"); @@ -320,6 +358,7 @@ int main(int argc, char* argv[]) a0_device_buf.ToDevice(a0_t_k.mData.data()); d0_device_buf.ToDevice(d0_t_n.mData.data()); 
d1_device_buf.ToDevice(d1_e_n.mData.data()); + d2_device_buf.ToDevice(d2_e_n.mData.data()); auto a_element_op = AElementOp{}; auto b_element_op = BElementOp{}; @@ -342,7 +381,8 @@ int main(int argc, char* argv[]) a0_device_buf.GetDeviceBuffer(), b0_device_buf.GetDeviceBuffer(), std::array{d0_device_buf.GetDeviceBuffer(), - d1_device_buf.GetDeviceBuffer()}, + d1_device_buf.GetDeviceBuffer(), + d2_device_buf.GetDeviceBuffer()}, e_device_buf.GetDeviceBuffer(), tokens, topk, @@ -392,10 +432,12 @@ int main(int argc, char* argv[]) using ReferenceGemmInstance = ck::tensor_operation::host::ReferenceMoeGemm; + PassThrough, + MulRoutedWeight>; auto ref_moe_gemm = ReferenceGemmInstance{}; auto ref_invoker = ref_moe_gemm.MakeInvoker(); @@ -406,6 +448,7 @@ int main(int argc, char* argv[]) a0_t_k, b0_e_n_k, c_t_k_n, + d2_e_n, PassThrough{}, PassThrough{}, PassThrough{}); @@ -428,7 +471,8 @@ int main(int argc, char* argv[]) cde_element_op(e_t_n_host_result(t, topk_id, n), c_t_k_n(t, topk_id, n), d0_t_n(t, n), - d1_e_n(e, n)); + d1_e_n(e, n), + 1.f); } } diff --git a/example/65_gemm_multiply_multiply/moe_gemm1_xdl_pk_i4.cpp b/example/65_gemm_multiply_multiply/moe_gemm1_xdl_pk_i4.cpp index a25d1b5fa3..fb8a8b9826 100644 --- a/example/65_gemm_multiply_multiply/moe_gemm1_xdl_pk_i4.cpp +++ b/example/65_gemm_multiply_multiply/moe_gemm1_xdl_pk_i4.cpp @@ -1,5 +1,5 @@ // SPDX-License-Identifier: MIT -// Copyright (c) 2024, Advanced Micro Devices, Inc. All rights reserved. +// Copyright (c) 2024-2025, Advanced Micro Devices, Inc. All rights reserved. 
#include #include @@ -39,14 +39,15 @@ using AccDataType = F32; using CShuffleDataType = F32; using D0DataType = F32; using D1DataType = F32; -using DsDataType = ck::Tuple; +using D2DataType = F32; +using DsDataType = ck::Tuple; using A0Layout = Row; using B0Layout = Col; using ELayout = Row; using D0Layout = Row; using D1Layout = Col; -using DsLayout = ck::Tuple; +using DsLayout = ck::Tuple; // for gate, a_scale, b_scale struct MulABScale @@ -91,7 +92,39 @@ struct MulABScaleSilu } }; -using CDEElementOp = MulABScale; +struct MulABScaleExpertWeight +{ + template + __host__ __device__ constexpr void + operator()(E& e, const C& c, const D0& d0, const D1& d1, const D2& d2) const; + // for real kernel use + template <> + __host__ __device__ constexpr void operator()( + EDataType& e, const float& c, const float& d0, const float& d1, const float& d2) const + { + (void)d2; + +#if CK_USE_PK4_LAYOUT_SHUFFLE + e = ck::type_convert(c * d1 * d0 * 16); +#else + e = ck::type_convert(c * d1 * d0); +#endif + } + // for reference cpu + template <> + __host__ __device__ constexpr void operator()( + float& e, const float& c, const float& d0, const float& d1, const float& d2) const + { + // for reference cpu +#if CK_USE_PK4_LAYOUT_SHUFFLE + e = ck::type_convert(c * d0 * d1 * d2 * 16); +#else + e = ck::type_convert(c * d0 * d1 * d2); +#endif + } +}; + +using CDEElementOp = MulABScaleExpertWeight; #if 1 void preShuffleBuffer(const I4* src, I4* dst, int N, int K, int NXdl) @@ -164,6 +197,7 @@ using DeviceOpInstance = ck::tensor_operation::device::DeviceMoeGemm< #else static constexpr ck::index_t MPerBlock = 128; static constexpr ck::index_t Nswizzle = false; +static constexpr bool MulRoutedWeight = false; // clang-format off using DeviceOpInstance = ck::tensor_operation::device::DeviceMoeGemm< Row, Col, DsLayout, ELayout, @@ -175,8 +209,8 @@ using DeviceOpInstance = ck::tensor_operation::device::DeviceMoeGemm< 4, 1, S<8, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, 0, S<4, 64, 1>, S<1, 0, 
2>, S<1, 0, 2>, 2, 32, 32, 0, - 1, 1, S<1, 32, 1, 8>, S<8, 1, 1>, - ck::BlockGemmPipelineScheduler::Intrawave, ck::BlockGemmPipelineVersion::v1, Nswizzle, true, A0DataType>; + 1, 1, S<1, 32, 1, 8>, S<8, 1, 1, 1>, + ck::BlockGemmPipelineScheduler::Intrawave, ck::BlockGemmPipelineVersion::v1, Nswizzle, true, MulRoutedWeight, A0DataType>; // clang-format on #endif @@ -265,6 +299,7 @@ int main(int argc, char* argv[]) Tensor b0_preshuffled(HostTensorDescriptor({experts, K, N}, {N * K, 1, K})); Tensor d0_t_n(HostTensorDescriptor({tokens, N}, {StrideDs[0], 0})); Tensor d1_e_n(HostTensorDescriptor({experts, N}, {1, StrideDs[1]})); + Tensor d2_e_n(HostTensorDescriptor({sorted_size, N}, {1, 0})); Tensor e_t_n_host_result(HostTensorDescriptor({tokens, topk, N}, {topk * N, N, 1})); Tensor e_t_n_device_result( HostTensorDescriptor({tokens, topk, N}, {topk * N, N, 1})); @@ -283,18 +318,21 @@ int main(int argc, char* argv[]) b0_e_n_k.GenerateTensorValue(GeneratorTensor_2{-2, 2}); d0_t_n.GenerateTensorValue(GeneratorTensor_2{-2, 2}); d1_e_n.GenerateTensorValue(GeneratorTensor_2{-2, 2}); + d2_e_n.GenerateTensorValue(GeneratorTensor_3{-2, 2}); break; case 2: a0_t_k.GenerateTensorValue(GeneratorTensor_1{}); b0_e_n_k.GenerateTensorValue(GeneratorTensor_1{}); d0_t_n.GenerateTensorValue(GeneratorTensor_1{}); d1_e_n.GenerateTensorValue(GeneratorTensor_1{}); + d2_e_n.GenerateTensorValue(GeneratorTensor_3{}); break; default: a0_t_k.GenerateTensorValue(GeneratorTensor_3{0.0, 1.0}); b0_e_n_k.GenerateTensorValue(GeneratorTensor_3{-0.5, 0.5}); d0_t_n.GenerateTensorValue(GeneratorTensor_3{0.0, 1.0}); d1_e_n.GenerateTensorValue(GeneratorTensor_3{0.0, 1.0}); + d2_e_n.GenerateTensorValue(GeneratorTensor_3{0.0, 1.0}); } DeviceMem sorted_token_ids_dev(sizeof(ck::index_t) * sorted_token_ids.mDesc.GetElementSpaceSize()); @@ -304,6 +342,7 @@ int main(int argc, char* argv[]) DeviceMem b0_device_buf(sizeof(B0DataType) * b0_e_n_k.mDesc.GetElementSpaceSize() / 2); DeviceMem d0_device_buf(sizeof(D0DataType) 
* d0_t_n.mDesc.GetElementSpaceSize()); DeviceMem d1_device_buf(sizeof(D1DataType) * d1_e_n.mDesc.GetElementSpaceSize()); + DeviceMem d2_device_buf(sizeof(D2DataType) * d2_e_n.mDesc.GetElementSpaceSize()); DeviceMem e_device_buf(sizeof(EDataType) * e_t_n_device_result.mDesc.GetElementSpaceSize()); sorted_token_ids_dev.ToDevice(sorted_token_ids.mData.data()); @@ -312,6 +351,7 @@ int main(int argc, char* argv[]) a0_device_buf.ToDevice(a0_t_k.mData.data()); d0_device_buf.ToDevice(d0_t_n.mData.data()); d1_device_buf.ToDevice(d1_e_n.mData.data()); + d2_device_buf.ToDevice(d2_e_n.mData.data()); auto a_element_op = AElementOp{}; auto b_element_op = BElementOp{}; @@ -424,7 +464,8 @@ int main(int argc, char* argv[]) a0_device_buf.GetDeviceBuffer(), b0_device_buf.GetDeviceBuffer(), std::array{d0_device_buf.GetDeviceBuffer(), - d1_device_buf.GetDeviceBuffer()}, + d1_device_buf.GetDeviceBuffer(), + d2_device_buf.GetDeviceBuffer()}, e_device_buf.GetDeviceBuffer(), tokens, topk, @@ -480,10 +521,12 @@ int main(int argc, char* argv[]) using ReferenceGemmInstance = ck::tensor_operation::host::ReferenceMoeGemm; + PassThrough, + MulRoutedWeight>; auto ref_moe_gemm = ReferenceGemmInstance{}; auto ref_invoker = ref_moe_gemm.MakeInvoker(); @@ -494,6 +537,7 @@ int main(int argc, char* argv[]) a0_t_k, b0_e_n_k, c_t_k_n, + d2_e_n, PassThrough{}, PassThrough{}, PassThrough{}); @@ -516,7 +560,8 @@ int main(int argc, char* argv[]) cde_element_op(e_t_n_host_result(t, topk_id, n), c_t_k_n(t, topk_id, n), d0_t_n(t, n), - d1_e_n(e, n)); + d1_e_n(e, n), + 1.f); } } diff --git a/example/65_gemm_multiply_multiply/moe_gemm2_xdl_fp8.cpp b/example/65_gemm_multiply_multiply/moe_gemm2_xdl_fp8.cpp index 0d12441016..04f10b53ae 100644 --- a/example/65_gemm_multiply_multiply/moe_gemm2_xdl_fp8.cpp +++ b/example/65_gemm_multiply_multiply/moe_gemm2_xdl_fp8.cpp @@ -1,5 +1,5 @@ // SPDX-License-Identifier: MIT -// Copyright (c) 2024, Advanced Micro Devices, Inc. All rights reserved. 
+// Copyright (c) 2024-2025, Advanced Micro Devices, Inc. All rights reserved. #include #include @@ -135,6 +135,7 @@ static constexpr ck::index_t EVec = 2; static constexpr ck::index_t D0Vec = 1; static constexpr ck::index_t D1Vec = 1; static constexpr ck::index_t D2Vec = 1; +static constexpr bool MulRoutedWeight = false; using DeviceOpInstance = ck::tensor_operation::device::DeviceMoeGemm // clang-format off ///######| ALayout| BLayout| DsLayout| ELayout| AData| BData| DsData| EData| AccData| CShuffle| A| B| CDE| GEMM| Block| MPer| NPer| KPer| AK1| BK1| MPer| NPer| MXdl| NXdl| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockLds| BBlockTransfer| BBlockTransfer| BBlockTransfer| BlockTransfer| BBlockTransfer| BBlockTransfer| BBlockLds| CShuffle| CShuffle| CBlockTransferClusterLengths| CBlockTransfer| @@ -164,7 +165,7 @@ using DeviceOpInstance = ck::tensor_operation::device::Devic // MXdlPerWave| NXdlPerWave| _MBlock_MWaveMPerXdl| ScalarPerVector| // PerShuffle| PerShuffle| _NBlock_NWaveNPerXdl| _NWaveNPerXdl| 2, 1, S<1, CShuffleMLane, 1, CShuffleNLane>, S, - ck::BlockGemmPipelineScheduler::Intrawave, ck::BlockGemmPipelineVersion::v1, false, false, A0DataType>; + ck::BlockGemmPipelineScheduler::Intrawave, ck::BlockGemmPipelineVersion::v1, false, false, MulRoutedWeight, A0DataType>; // kernel 2: 128->32x128x128 // < Row, Col, DsLayout, ELayout, A0DataType, B0DataType, DsDataType, EDataType, AccDataType, CShuffleDataType, AElementOp, BElementOp, CDEElementOp, GemmSpec, 128, 32, 128, 128, 16, 16, 32, 32, 1, 2, S<8, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, 0, S<8, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, 0, 1, 1, S<1, 16, 1, 8>, S<8, 8, 1>, ck::BlockGemmPipelineScheduler::Interwave, ck::BlockGemmPipelineVersion::v1, EDataType>; @@ -409,7 +410,8 @@ int main(int argc, char* argv[]) AccDataType, PassThrough, PassThrough, - CDEElementOp>; + CDEElementOp, + MulRoutedWeight>; auto ref_moe_gemm = ReferenceGemmInstance{}; 
auto ref_invoker = ref_moe_gemm.MakeInvoker(); auto ref_argument = ref_moe_gemm.MakeArgument(sorted_token_ids, diff --git a/example/65_gemm_multiply_multiply/moe_gemm2_xdl_pk_i4.cpp b/example/65_gemm_multiply_multiply/moe_gemm2_xdl_pk_i4.cpp index 8c2c70b4a1..ba4e40151f 100644 --- a/example/65_gemm_multiply_multiply/moe_gemm2_xdl_pk_i4.cpp +++ b/example/65_gemm_multiply_multiply/moe_gemm2_xdl_pk_i4.cpp @@ -1,5 +1,5 @@ // SPDX-License-Identifier: MIT -// Copyright (c) 2024, Advanced Micro Devices, Inc. All rights reserved. +// Copyright (c) 2024-2025, Advanced Micro Devices, Inc. All rights reserved. #include #include @@ -138,6 +138,7 @@ static constexpr ck::index_t EVec = 2; static constexpr ck::index_t D0Vec = 1; static constexpr ck::index_t D1Vec = 1; static constexpr ck::index_t D2Vec = 1; +static constexpr bool MulRoutedWeight = true; using DeviceOpInstance = ck::tensor_operation::device::DeviceMoeGemm // clang-format off < Row, Col, DsLayout, ELayout, A0DataType, B0DataType, DsDataType, EDataType, AccDataType, CShuffleDataType, @@ -149,7 +150,7 @@ using DeviceOpInstance = ck::tensor_operation::device::Devic S<8, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, AK1, AK1, 0, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, BK1, BK1, 0, 1, 1, S<1, CShuffleMLane, 1, CShuffleNLane>, S, - ck::BlockGemmPipelineScheduler::Intrawave, ck::BlockGemmPipelineVersion::v1, false, false, A0DataType>; + ck::BlockGemmPipelineScheduler::Intrawave, ck::BlockGemmPipelineVersion::v1, false, false, MulRoutedWeight, A0DataType>; // clang-format on int main(int argc, char* argv[]) @@ -455,7 +456,8 @@ int main(int argc, char* argv[]) AccDataType, PassThrough, PassThrough, - CDEElementOp>; + CDEElementOp, + MulRoutedWeight>; auto ref_moe_gemm = ReferenceGemmInstance{}; auto ref_invoker = ref_moe_gemm.MakeInvoker(); diff --git a/include/ck/tensor_operation/gpu/device/impl/device_moe_gemm.hpp b/include/ck/tensor_operation/gpu/device/impl/device_moe_gemm.hpp index f3fc1aaa9f..03db4bdd41 100644 --- 
a/include/ck/tensor_operation/gpu/device/impl/device_moe_gemm.hpp +++ b/include/ck/tensor_operation/gpu/device/impl/device_moe_gemm.hpp @@ -1,5 +1,5 @@ // SPDX-License-Identifier: MIT -// Copyright (c) 2024, Advanced Micro Devices, Inc. All rights reserved. +// Copyright (c) 2025, Advanced Micro Devices, Inc. All rights reserved. #pragma once @@ -67,6 +67,7 @@ template ; RunKernel(kernel); } @@ -280,6 +282,7 @@ struct DeviceMoeGemm : public DeviceGemmMultipleDSplitKBPreShuffle; RunKernel(kernel); } @@ -295,6 +298,7 @@ struct DeviceMoeGemm : public DeviceGemmMultipleDSplitKBPreShuffle; RunKernel(kernel); } @@ -305,6 +309,7 @@ struct DeviceMoeGemm : public DeviceGemmMultipleDSplitKBPreShuffle; RunKernel(kernel); } @@ -325,6 +330,7 @@ struct DeviceMoeGemm : public DeviceGemmMultipleDSplitKBPreShuffle; RunKernel(kernel); } diff --git a/include/ck/tensor_operation/gpu/grid/gridwise_moe_gemm.hpp b/include/ck/tensor_operation/gpu/grid/gridwise_moe_gemm.hpp index 1924c27b2b..a2d1114bbe 100644 --- a/include/ck/tensor_operation/gpu/grid/gridwise_moe_gemm.hpp +++ b/include/ck/tensor_operation/gpu/grid/gridwise_moe_gemm.hpp @@ -1,5 +1,5 @@ // SPDX-License-Identifier: MIT -// Copyright (c) 2018-2024, Advanced Micro Devices, Inc. All rights reserved. +// Copyright (c) 2018-2025, Advanced Micro Devices, Inc. All rights reserved. 
#pragma once @@ -31,6 +31,7 @@ template __global__ void #if CK_USE_LAUNCH_BOUNDS @@ -44,19 +45,22 @@ __global__ void auto splitk_batch_offset = typename GridwiseGemm::SplitKBatchOffset(karg, blockIdx.z); - GridwiseGemm::template Run( - karg.p_sorted_token_ids, - karg.p_sorted_expert_ids, - karg.p_max_token_id, - karg.p_a_grid + splitk_batch_offset.a_k_split_offset, - karg.p_b_grid + splitk_batch_offset.b_k_split_offset, - karg.p_ds_grid, - karg.p_c_grid, - p_shared, - karg, - karg.a_element_op, - karg.b_element_op, - karg.c_element_op); + GridwiseGemm::template Run(karg.p_sorted_token_ids, + karg.p_sorted_expert_ids, + karg.p_max_token_id, + karg.p_a_grid + splitk_batch_offset.a_k_split_offset, + karg.p_b_grid + splitk_batch_offset.b_k_split_offset, + karg.p_ds_grid, + karg.p_c_grid, + p_shared, + karg, + karg.a_element_op, + karg.b_element_op, + karg.c_element_op); #else ignore = karg; #endif // end of if (defined(__gfx9__)) @@ -67,6 +71,7 @@ template __global__ void #if CK_USE_LAUNCH_BOUNDS @@ -81,21 +86,23 @@ __global__ void auto splitk_batch_offset = typename GridwiseGemm::SplitKBatchOffset(karg, blockIdx.z); - GridwiseGemm:: - template Run_2Lds( - karg.p_sorted_token_ids, - karg.p_sorted_expert_ids, - karg.p_max_token_id, - karg.p_a_grid + splitk_batch_offset.a_k_split_offset, - karg.p_b_grid + splitk_batch_offset.b_k_split_offset, - karg.p_ds_grid, - karg.p_c_grid, - p_shared, - p_shared1, - karg, - karg.a_element_op, - karg.b_element_op, - karg.c_element_op); + GridwiseGemm::template Run_2Lds(karg.p_sorted_token_ids, + karg.p_sorted_expert_ids, + karg.p_max_token_id, + karg.p_a_grid + splitk_batch_offset.a_k_split_offset, + karg.p_b_grid + splitk_batch_offset.b_k_split_offset, + karg.p_ds_grid, + karg.p_c_grid, + p_shared, + p_shared1, + karg, + karg.a_element_op, + karg.b_element_op, + karg.c_element_op); #else ignore = karg; #endif // end of if (defined(__gfx9__)) @@ -1134,8 +1141,9 @@ struct GridwiseMoeGemm template + bool IsInputGemm = true, + bool 
MulRoutedWeight = true, + TailNumber TailNum = TailNumber::Odd> __device__ static void Run(const index_t* p_sorted_token_ids, const index_t* p_sorted_expert_ids, const index_t* p_max_token_id, @@ -1492,7 +1500,7 @@ struct GridwiseMoeGemm using CDEBlockTransferCluster = CShuffleBlockTransferClusterLengths_MBlock_MPerBlock_NBlock_NPerBlock; const auto EGlobalMemoryDataOperation = CGlobalMemoryDataOperation; - constexpr index_t scatter_weight_idx = IsInputGemm ? 1 : 3; // hack fix felix + constexpr index_t scatter_weight_idx = 3; // hack fix felix auto cde_block_copy_lds_and_global = ThreadGroupTensorSliceTransfer_v7r3_scatter< ThisThreadBlock, decltype(container_concat(make_tuple(CShuffleDataType{}), DsDataType{})), @@ -1579,10 +1587,13 @@ struct GridwiseMoeGemm { token_offset = token_offset * problem.TopK + (fused_token >> 24); } - else + if constexpr(MulRoutedWeight) { const float* p_sorted_weights_2 = p_ds_grid[I2]; - weight = weight * p_sorted_weights_2[c_token_pos + m0]; + if constexpr(sizeof(ADataType) < 2) + weight = p_sorted_weights_2[c_token_pos + m0] * weight; + else + weight = p_sorted_weights_2[c_token_pos + m0]; } scatter_offsets(m0) = token_offset * problem.N; scatter_weights(m0) = weight; @@ -1632,8 +1643,9 @@ struct GridwiseMoeGemm template + bool IsInputGemm = true, + bool MulRoutedWeight = true, + TailNumber TailNum = TailNumber::Odd> __device__ static void Run_2Lds(const index_t* p_sorted_token_ids, const index_t* p_sorted_expert_ids, const index_t* p_max_token_id, @@ -1998,7 +2010,7 @@ struct GridwiseMoeGemm using CDEBlockTransferCluster = CShuffleBlockTransferClusterLengths_MBlock_MPerBlock_NBlock_NPerBlock; const auto EGlobalMemoryDataOperation = CGlobalMemoryDataOperation; - constexpr index_t scatter_weight_idx = IsInputGemm ? 
1 : 3; // hack fix felix + constexpr index_t scatter_weight_idx = 3; // hack fix felix auto cde_block_copy_lds_and_global = ThreadGroupTensorSliceTransfer_v7r3_scatter< ThisThreadBlock, decltype(container_concat(make_tuple(CShuffleDataType{}), DsDataType{})), @@ -2086,10 +2098,13 @@ struct GridwiseMoeGemm { token_offset = token_offset * problem.TopK + (fused_token >> 24); } - else + if constexpr(MulRoutedWeight) { const float* p_sorted_weights_2 = p_ds_grid[I2]; - weight = weight * p_sorted_weights_2[c_token_pos + m0]; + if constexpr(sizeof(ADataType) < 2) + weight = p_sorted_weights_2[c_token_pos + m0] * weight; + else + weight = p_sorted_weights_2[c_token_pos + m0]; } scatter_offsets(m0) = token_offset * problem.N; scatter_weights(m0) = weight; diff --git a/library/include/ck/library/reference_tensor_operation/cpu/reference_moe_gemm.hpp b/library/include/ck/library/reference_tensor_operation/cpu/reference_moe_gemm.hpp index af735925ed..72c9dc86ac 100644 --- a/library/include/ck/library/reference_tensor_operation/cpu/reference_moe_gemm.hpp +++ b/library/include/ck/library/reference_tensor_operation/cpu/reference_moe_gemm.hpp @@ -1,5 +1,5 @@ // SPDX-License-Identifier: MIT -// Copyright (c) 2018-2024, Advanced Micro Devices, Inc. All rights reserved. +// Copyright (c) 2018-2025, Advanced Micro Devices, Inc. All rights reserved. 
#pragma once @@ -18,10 +18,12 @@ namespace host { template struct ReferenceMoeGemm : public device::BaseOperator @@ -36,6 +38,7 @@ struct ReferenceMoeGemm : public device::BaseOperator const Tensor& a_t_k, const Tensor& b_e_n_k, Tensor& c_t_k_n, + const Tensor& d2, AElementwiseOperation a_element_op, BElementwiseOperation b_element_op, CElementwiseOperation c_element_op) @@ -46,6 +49,7 @@ struct ReferenceMoeGemm : public device::BaseOperator a_t_k_{a_t_k}, b_e_n_k_{b_e_n_k}, c_t_k_n_{c_t_k_n}, + d2_{d2}, a_element_op_{a_element_op}, b_element_op_{b_element_op}, c_element_op_{c_element_op} @@ -59,6 +63,7 @@ struct ReferenceMoeGemm : public device::BaseOperator const Tensor& a_t_k_; const Tensor& b_e_n_k_; Tensor& c_t_k_n_; + const Tensor& d2_; AElementwiseOperation a_element_op_; BElementwiseOperation b_element_op_; @@ -81,6 +86,7 @@ struct ReferenceMoeGemm : public device::BaseOperator const int topk_id = (arg.sorted_token_ids_(m) & 0xff000000) >> 24; const int e = arg.expert_ids_(m / arg.sorted_tile_size_); const int token_cnt = arg.a_t_k_.mDesc.GetLengths()[0]; + D2DataType v_topk_w = arg.d2_(m, 0); // expert if(t < token_cnt) { for(int k = 0; k < K; ++k) @@ -128,6 +134,11 @@ struct ReferenceMoeGemm : public device::BaseOperator } CDataType v_c{0}; + if constexpr(MulRoutedWeight) + { + v_acc *= v_topk_w; + } + arg.c_element_op_(v_c, v_acc); arg.c_t_k_n_(t, topk_id, n) = v_c; @@ -164,6 +175,7 @@ struct ReferenceMoeGemm : public device::BaseOperator const Tensor& a_t_k, const Tensor& b_e_n_k, Tensor& c_t_k_n, + const Tensor& d2, AElementwiseOperation a_element_op, BElementwiseOperation b_element_op, CElementwiseOperation c_element_op) @@ -175,6 +187,7 @@ struct ReferenceMoeGemm : public device::BaseOperator a_t_k, b_e_n_k, c_t_k_n, + d2, a_element_op, b_element_op, c_element_op}; diff --git a/library/include/ck/library/reference_tensor_operation/cpu/reference_moe_gemm2.hpp b/library/include/ck/library/reference_tensor_operation/cpu/reference_moe_gemm2.hpp index 
1e8a086bc4..fb5c71e30a 100644 --- a/library/include/ck/library/reference_tensor_operation/cpu/reference_moe_gemm2.hpp +++ b/library/include/ck/library/reference_tensor_operation/cpu/reference_moe_gemm2.hpp @@ -1,5 +1,5 @@ // SPDX-License-Identifier: MIT -// Copyright (c) 2018-2024, Advanced Micro Devices, Inc. All rights reserved. +// Copyright (c) 2018-2025, Advanced Micro Devices, Inc. All rights reserved. #pragma once @@ -25,6 +25,7 @@ template struct ReferenceMoeGemm2 : public device::BaseOperator @@ -143,7 +144,14 @@ struct ReferenceMoeGemm2 : public device::BaseOperator CDataType v_c{0}; D0DataType v_d0 = arg.d0_(m, n); // a D0DataType v_d1 = arg.d1_(e, n); // b - arg.c_element_op_(v_c, v_acc, v_d0, v_d1, v_topk_w); + if constexpr(MulRoutedWeight) + { + arg.c_element_op_(v_c, v_acc, v_d0, v_d1, v_topk_w); + } + else + { + arg.c_element_op_(v_c, v_acc, v_d0, v_d1, 1.f); + } arg.c_t_n_(t, n) += v_c; } }; From c318ec0778f0b9db90618ac51185ff6f9dfab0e1 Mon Sep 17 00:00:00 2001 From: solin Date: Fri, 18 Apr 2025 09:15:27 +0000 Subject: [PATCH 053/443] fix CI build fail --- .../ops/flatmm/pipeline/flatmm_pipeline_agmem_bgmem_creg_v1.hpp | 1 + 1 file changed, 1 insertion(+) diff --git a/include/ck_tile/ops/flatmm/pipeline/flatmm_pipeline_agmem_bgmem_creg_v1.hpp b/include/ck_tile/ops/flatmm/pipeline/flatmm_pipeline_agmem_bgmem_creg_v1.hpp index 3d08c7a788..611aff318f 100644 --- a/include/ck_tile/ops/flatmm/pipeline/flatmm_pipeline_agmem_bgmem_creg_v1.hpp +++ b/include/ck_tile/ops/flatmm/pipeline/flatmm_pipeline_agmem_bgmem_creg_v1.hpp @@ -4,6 +4,7 @@ #pragma once #include "ck_tile/core.hpp" +#include "ck_tile/host/concat.hpp" #include "ck_tile/ops/flatmm/pipeline/flatmm_pipeline_agmem_bgmem_creg_v1_policy.hpp" namespace ck_tile { From 7cadf187e28693eb211c9cfb76d72ba0d6fb28b8 Mon Sep 17 00:00:00 2001 From: Khushbu Agarwal Date: Mon, 21 Apr 2025 08:39:45 -0700 Subject: [PATCH 054/443] multi instance generation for CkTileEngine (#2080) * Add support for multi-instance 
verification, print detail for each instance, documentation fix * clang formatted * Added Readme file * updated readme * Addressing review comments * clang formatted * Updated ReadMe and GPU reference code * simplified dispatch kernel code * indentation --- tile_engine/ops/gemm/README.md | 51 ++++++ .../gemm/configs/instance_combination.json | 2 +- tile_engine/ops/gemm/gemm_host_api.cpp | 79 +++++----- tile_engine/ops/gemm/gemm_host_api.hpp | 146 +++++++----------- tile_engine/ops/gemm/gemm_instance_builder.py | 64 ++++++-- 5 files changed, 202 insertions(+), 140 deletions(-) create mode 100644 tile_engine/ops/gemm/README.md diff --git a/tile_engine/ops/gemm/README.md b/tile_engine/ops/gemm/README.md new file mode 100644 index 0000000000..495232f19b --- /dev/null +++ b/tile_engine/ops/gemm/README.md @@ -0,0 +1,51 @@ +# GEMM Matrix Multiplication + +Use the files in this folder to generate and build applications that run Matrix multiplications using ck_tile programming based on the kernel parameters mentioned in the config file `./configs/instance_combination.json`. + +# Kernel Configurations + +User needs to provide kernel configuration such as datatype, layout, tile size, warp size, padding, pipeline, scheduler and epilogue in the config file. For reference please see `./configs/instance_combination.json` + +## Build +``` +# in the root of ck_tile +mkdir build && cd build +# you can replace with the appropriate architecture (for example gfx90a or gfx942) or leave it blank +sh ../script/cmake-ck-dev.sh ../ +# To generate the executable +make tile_engine_gemm -j +``` +`tile_engine_gemm` will be located in the `./bin/` directory. 
+ +## tile_engine_gemm inputs +``` + + -m m dimension (default:3840) + -n n dimension (default:4096) + -k k dimension (default:2048) + -stride_a Tensor A stride (default:0) + -stride_b Tensor B stride (default:0) + -stride_c Tensor C stride (default:0) + -split_k SplitK value (default:1) + -v No validation: 0, Validation on CPU: 1, Validation on GPU: 2 (default:2) + -warmup Number of iterations before benchmark the kernel (default:50) + -repeat Number of iterations to benchmark the kernel (default:100) + -timer gpu:gpu timer, cpu:cpu timer (default:gpu) + -init Value for initializing tensor- random: 0, linear: 1, constant(1): 2 (default:0) + -pipeline possible values are: compv3, compv4, mem (default:compv3) + -scheduler possible values are: intrawave, interwave (default:intrawave) + -epilogue possible values are: cshuffle, default (default:cshuffle) + -pad_m Pad in m direction - true/false (default:false) + -pad_n Pad in n direction - true/false (default:false) + -pad_k Pad in k direction - true/false (default:false) + +Note: pipeline, scheduler, epilogue, pad_m, pad_n, pad_k should be one of the options specified in instance_combination.json +``` + +## Example + +Below example will run gemm kernel with default dimensions of matrices, for compv3 pipeline, intrawave scheduler and default epilogue with all possible tile sizes mentioned in Config file. 
+ +``` +./bin/tile_engine_gemm -pipeline=compv3 -scheduler=intrawave -epilogue=default +``` diff --git a/tile_engine/ops/gemm/configs/instance_combination.json b/tile_engine/ops/gemm/configs/instance_combination.json index e21197d1de..e23df11500 100644 --- a/tile_engine/ops/gemm/configs/instance_combination.json +++ b/tile_engine/ops/gemm/configs/instance_combination.json @@ -19,7 +19,7 @@ "values": [256] }, "tile_k": { - "values": [64] + "values": [64, 32] }, "warp_m": { "values": [2] diff --git a/tile_engine/ops/gemm/gemm_host_api.cpp b/tile_engine/ops/gemm/gemm_host_api.cpp index 508f634920..3cef425a51 100644 --- a/tile_engine/ops/gemm/gemm_host_api.cpp +++ b/tile_engine/ops/gemm/gemm_host_api.cpp @@ -6,11 +6,16 @@ #include "gemm_dispatcher.hpp" #include "gemm_host_api.hpp" -float gemm_kernel_launch(KernelTraits& trait, - ck_tile::GemmHostArgs& args, - const ck_tile::stream_config& s) +void gemm_kernel_launch(ck_tile::DeviceMem& c_m_n_dev_buf, + ck_tile::HostTensor& c_m_n_host_result, + ck_tile::HostTensor& c_m_n_dev_result, + int verify, + KernelTraits& trait, + ck_tile::GemmHostArgs& args, + const ck_tile::stream_config& s) { - return GemmDispatcher::dispatch(trait, args, s); + return GemmDispatcher::dispatch( + c_m_n_dev_buf, c_m_n_host_result, c_m_n_dev_result, verify, trait, args, s); } template -bool run(const ck_tile::ArgParser& arg_parser) +void run(const ck_tile::ArgParser& arg_parser) { const ALayout a_layout = ALayout{}; const BLayout b_layout = BLayout{}; - // const CLayout c_layout = CLayout{}; ck_tile::index_t kbatch = arg_parser.get_int("split_k"); ck_tile::index_t M = arg_parser.get_int("m"); @@ -113,43 +117,47 @@ bool run(const ck_tile::ArgParser& arg_parser) trait.kPadN = arg_parser.get_bool("pad_n"); trait.kPadK = arg_parser.get_bool("pad_k"); - float ave_time = gemm_kernel_launch( - trait, gemm_args, ck_tile::stream_config{nullptr, true, 1, n_warmup, n_repeat}); - - std::size_t flop = std::size_t(2) * M * N * K; - std::size_t num_byte = - 
sizeof(ADataType) * M * K + sizeof(BDataType) * N * K + sizeof(CDataType) * M * N; - float tflops = static_cast(flop) / 1.E9 / ave_time; - float gb_per_sec = num_byte / 1.E6 / ave_time; - std::cout << "Run Gemm kernel with M =" << M << " N =" << N << " K =" << K << " StrideA =" << stride_A << " StrideB =" << stride_B << " StrideC =" << stride_C << " A_Layout =" << ALayout::name << " B_Layout =" << BLayout::name << " C_Layout =" << CLayout::name << " A Type = " << DataTypeTraits::name << " B Type = " << DataTypeTraits::name - << " C Type = " << DataTypeTraits::name << " : " << ave_time << " ms, " - << tflops << " TFlops, " << gb_per_sec << " GB/s, " << std::endl; + << " C Type = " << DataTypeTraits::name << std::endl; + + ck_tile::HostTensor c_m_n_host_result( + ck_tile::host_tensor_descriptor(M, N, stride_C, is_row_major(CLayout{}))); - c_m_n_dev_buf.FromDevice(c_m_n_dev_result.data()); - bool pass = true; if(verify) { - pass = gemm_verify( - verify, - a_m_k, - b_k_n, - c_m_n_dev_result, - a_m_k_dev_buf, - b_k_n_dev_buf, - M, - N, - K, - stride_A, - stride_B, - stride_C, - kbatch); + gemm_host_reference(verify, + a_m_k, + b_k_n, + c_m_n_host_result, + a_m_k_dev_buf, + b_k_n_dev_buf, + M, + N, + K, + stride_A, + stride_B, + stride_C); } - return pass; + + gemm_kernel_launch(c_m_n_dev_buf, + c_m_n_host_result, + c_m_n_dev_result, + verify, + trait, + gemm_args, + ck_tile::stream_config{nullptr, true, 1, n_warmup, n_repeat}); + + return; } int main(int argc, char* argv[]) @@ -159,7 +167,8 @@ int main(int argc, char* argv[]) auto [result, parser] = create_args(argc, argv); if(!result) return EXIT_FAILURE; - return run(parser); + run(parser); + return 0; } catch(const std::exception& e) { diff --git a/tile_engine/ops/gemm/gemm_host_api.hpp b/tile_engine/ops/gemm/gemm_host_api.hpp index 375f808966..c1e1e1dc4f 100644 --- a/tile_engine/ops/gemm/gemm_host_api.hpp +++ b/tile_engine/ops/gemm/gemm_host_api.hpp @@ -1,3 +1,6 @@ +// SPDX-License-Identifier: MIT +// Copyright (c) 
2025, Advanced Micro Devices, Inc. All rights reserved. + #include #include @@ -54,24 +57,21 @@ struct DataTypeTraits static constexpr const char* name = "pk_int4_t"; }; -/** - * @brief trait for GEMM kernel - * @param pipeline: pipeline name - * @param scheduler: scheduler name - * @param epilogue: epilogue name - * @param kPadM: padding for M dimension - * @param kPadN: padding for N dimension - * @param kPadK: padding for K dimension - * - */ - +/// @brief Defines the configuration parameters for a GEMM operation, enabling the selection of a +/// specific kernel instance based on the provided settings. struct KernelTraits { + /// @brief The name of the pipeline. std::string pipeline; + /// @brief The name of the scheduler (e.g., "intrawave", "interwave"). std::string scheduler; + /// @brief The name of the epilogue (e.g., "cshuffle", "default"). std::string epilogue; + /// @brief Indicates whether padding is applied to the M dimension. bool kPadM; + /// @brief Indicates whether padding is applied to the N dimension. bool kPadN; + /// @brief Indicates whether padding is applied to the K dimension. 
bool kPadK; }; @@ -184,11 +184,28 @@ void permute_vectors_i4x4_b(Tensor& tensor) } } -/** - * @brief Function to verify the kernel output with reference implementation on CPU/GPU - * - */ +/// @brief Function to compare the results of the device and host computations +void compare(ck_tile::index_t K, + ck_tile::index_t kbatch, + ck_tile::HostTensor& c_m_n_dev_result, + ck_tile::HostTensor& c_m_n_host_result) +{ + const float max_accumulated_value = + *std::max_element(c_m_n_host_result.mData.begin(), c_m_n_host_result.mData.end()); + const auto rtol_atol = calculate_rtol_atol( + K, kbatch, max_accumulated_value); + bool pass = ck_tile::check_err(c_m_n_dev_result, + c_m_n_host_result, + "Error: Incorrect results!", + rtol_atol.at(ck_tile::number<0>{}), + rtol_atol.at(ck_tile::number<1>{})); + std::cout << "Relative error threshold: " << rtol_atol.at(ck_tile::number<0>{}) + << " Absolute error threshold: " << rtol_atol.at(ck_tile::number<1>{}) << std::endl; + std::cout << "The verification result is:" << (pass ? 
"correct" : "fail") << std::endl; +} + +/// @brief Function to get the kernel output with reference implementation on CPU/GPU template -bool gemm_verify(int verify, - ck_tile::HostTensor& a_m_k, - ck_tile::HostTensor& b_k_n, - ck_tile::HostTensor& c_m_n_dev_result, - ck_tile::DeviceMem& a_m_k_dev_buf, - ck_tile::DeviceMem& b_k_n_dev_buf, - ck_tile::index_t M, - ck_tile::index_t N, - ck_tile::index_t K, - ck_tile::index_t stride_A, - ck_tile::index_t stride_B, - ck_tile::index_t stride_C, - ck_tile::index_t kbatch) +void gemm_host_reference(int verify, + ck_tile::HostTensor& a_m_k, + ck_tile::HostTensor& b_k_n, + ck_tile::HostTensor& c_m_n_host_result, + ck_tile::DeviceMem& a_m_k_dev_buf, + ck_tile::DeviceMem& b_k_n_dev_buf, + ck_tile::index_t M, + ck_tile::index_t N, + ck_tile::index_t K, + ck_tile::index_t stride_A, + ck_tile::index_t stride_B, + ck_tile::index_t stride_C) { - bool pass = true; if(verify == 1) { - ck_tile::HostTensor c_m_n_host_ref( - ck_tile::host_tensor_descriptor(M, N, stride_C, is_row_major(CLayout{}))); - c_m_n_host_ref.SetZero(); + c_m_n_host_result.SetZero(); ck_tile::reference_gemm( - a_m_k, b_k_n, c_m_n_host_ref); - const float max_accumulated_value = - *std::max_element(c_m_n_host_ref.mData.begin(), c_m_n_host_ref.mData.end()); - const auto rtol_atol = calculate_rtol_atol( - K, kbatch, max_accumulated_value); - pass = ck_tile::check_err(c_m_n_dev_result, - c_m_n_host_ref, - "Error: Incorrect results!", - rtol_atol.at(ck_tile::number<0>{}), - rtol_atol.at(ck_tile::number<1>{})); - - std::cout << "Relative error threshold: " << rtol_atol.at(ck_tile::number<0>{}) - << " Absolute error threshold: " << rtol_atol.at(ck_tile::number<1>{}) - << std::endl; - std::cout << "The CPU verification result is:" << (pass ? 
"correct" : "fail") << std::endl; + a_m_k, b_k_n, c_m_n_host_result); } else if(verify == 2) { @@ -241,29 +240,14 @@ bool gemm_verify(int verify, // Restore input for B for gpu reference b_k_n_dev_buf.ToDevice(b_k_n.data()); } - ck_tile::HostTensor c_m_n_gpu_ref( - ck_tile::host_tensor_descriptor(M, N, stride_C, is_row_major(CLayout{}))); - ck_tile::DeviceMem c_m_n_gpu_buf_ref(c_m_n_gpu_ref.get_element_space_size_in_bytes()); - c_m_n_gpu_ref.SetZero(); + + ck_tile::DeviceMem c_m_n_gpu_buf_ref(c_m_n_host_result.get_element_space_size_in_bytes()); + c_m_n_host_result.SetZero(); c_m_n_gpu_buf_ref.SetZero(); - ADataType* d_A; - BDataType* d_B; - CDataType* d_C; - - ck_tile::hip_check_error(hipMalloc(&d_A, a_m_k.get_element_space_size_in_bytes())); - ck_tile::hip_check_error(hipMalloc(&d_B, b_k_n.get_element_space_size_in_bytes())); - ck_tile::hip_check_error( - hipMalloc(&d_C, c_m_n_dev_result.get_element_space_size_in_bytes())); - - ck_tile::hip_check_error(hipMemcpy(d_A, - a_m_k_dev_buf.GetDeviceBuffer(), - a_m_k.get_element_space_size_in_bytes(), - hipMemcpyHostToDevice)); - ck_tile::hip_check_error(hipMemcpy(d_B, - b_k_n_dev_buf.GetDeviceBuffer(), - b_k_n.get_element_space_size_in_bytes(), - hipMemcpyHostToDevice)); + ADataType* d_A = static_cast(a_m_k_dev_buf.GetDeviceBuffer()); + BDataType* d_B = static_cast(b_k_n_dev_buf.GetDeviceBuffer()); + CDataType* d_C = static_cast(c_m_n_gpu_buf_ref.GetDeviceBuffer()); ck_tile::reference_gemm_gpu(d_A, d_B, d_C, M, N, K, stride_A, stride_B, stride_C); - ck_tile::hip_check_error(hipMemcpy(c_m_n_gpu_buf_ref.GetDeviceBuffer(), - d_C, - c_m_n_dev_result.get_element_space_size_in_bytes(), - hipMemcpyDeviceToHost)); - - ck_tile::hip_check_error(hipFree(d_A)); - ck_tile::hip_check_error(hipFree(d_B)); - ck_tile::hip_check_error(hipFree(d_C)); - - c_m_n_gpu_buf_ref.FromDevice(c_m_n_gpu_ref.data()); - const float max_accumulated_value = - *std::max_element(c_m_n_gpu_ref.mData.begin(), c_m_n_gpu_ref.mData.end()); - const auto 
rtol_atol = calculate_rtol_atol( - K, kbatch, max_accumulated_value); - pass = ck_tile::check_err(c_m_n_dev_result, - c_m_n_gpu_ref, - "Error: Incorrect results!", - rtol_atol.at(ck_tile::number<0>{}), - rtol_atol.at(ck_tile::number<1>{})); - - std::cout << "Relative error threshold: " << rtol_atol.at(ck_tile::number<0>{}) - << " Absolute error threshold: " << rtol_atol.at(ck_tile::number<1>{}) - << std::endl; - std::cout << "The GPU verification result is: " << (pass ? "correct" : "fail") << std::endl; + c_m_n_gpu_buf_ref.FromDevice(c_m_n_host_result.data()); } - return pass; } diff --git a/tile_engine/ops/gemm/gemm_instance_builder.py b/tile_engine/ops/gemm/gemm_instance_builder.py index e449dff94d..cfefd38cd2 100755 --- a/tile_engine/ops/gemm/gemm_instance_builder.py +++ b/tile_engine/ops/gemm/gemm_instance_builder.py @@ -447,6 +447,17 @@ struct GemmKernel {{ return ave_time; }} + static std::string get_name() {{ + return std::string("GemmKernel> kernel_map; + std::function& c_m_n_host_result, + ck_tile::HostTensor& c_m_n_dev_result, + int verify, ck_tile::GemmHostArgs&, const ck_tile::stream_config&)>> kernel_map; return kernel_map; } @@ -499,9 +513,12 @@ struct GemmDispatcher { for group in self.all_kernels: - content += f""" kernel_map["{group}"] = [](ck_tile::GemmHostArgs& args, - const ck_tile::stream_config& s) {{ - std::vector results;""" + content += f""" kernel_map["{group}"] = [](ck_tile::DeviceMem& c_m_n_dev_buf, + ck_tile::HostTensor& c_m_n_host_result, + ck_tile::HostTensor& c_m_n_dev_result, + int verify, ck_tile::GemmHostArgs& args, + const ck_tile::stream_config& s) {{ + """ for tile in tile_params: # Check if we have valid tile/warp combinations # (tile_m/(warp_m*warp_tile_m)) * warp_m * warp_tile_m == tile_m @@ -509,21 +526,46 @@ struct GemmDispatcher { ((tile[1]/(tile[4] * tile[8]) * tile[4] * tile[8]) != tile[1]): continue content += f""" - //we can have multiple tiles config for the one kernel_trait - return {group}::GemmKernel<{tile[0]}, 
{tile[1]}, {tile[2]}, {tile[3]}, {tile[4]}, {tile[5]}, {tile[6]}, {tile[7]}, {tile[8]}>::launch(args, s);""" - content += """ - };\n""" + run_kernel<{group}::GemmKernel<{tile[0]}, {tile[1]}, {tile[2]}, {tile[3]}, {tile[4]}, {tile[5]}, {tile[6]}, {tile[7]}, {tile[8]}>>(c_m_n_dev_buf, c_m_n_host_result, c_m_n_dev_result, verify, args, s);""" + content += f""" + }};\n""" content += """ } - - static float dispatch(const KernelTraits &trait, ck_tile::GemmHostArgs& gemm_args, + template + static void run_kernel(ck_tile::DeviceMem& c_m_n_dev_buf, + ck_tile::HostTensor& c_m_n_host_result, + ck_tile::HostTensor& c_m_n_dev_result, + int verify, ck_tile::GemmHostArgs& args, const ck_tile::stream_config& s) + { + float avg_time = Kernel::launch(args, s); + std::string description = Kernel::get_name(); + c_m_n_dev_buf.FromDevice(c_m_n_dev_result.data()); + + std::size_t flop = std::size_t(2) * args.M * args.N * args.K; + std::size_t num_byte = sizeof(ADataType) * args.M * args.K + sizeof(BDataType) * args.N * args.K + sizeof(CDataType) * args.M * args.N; + float tflops = static_cast(flop) / 1.E9 / avg_time; + float gb_per_sec = num_byte / 1.E6 / avg_time; + + std::cout << "Performance for " << description << " : " << avg_time << " ms, " + << tflops << " TFlops, " << gb_per_sec << " GB/s, " << std::endl; + + if(verify) + compare(args.K, args.k_batch, c_m_n_dev_result, c_m_n_host_result); + c_m_n_dev_buf.SetZero(); + c_m_n_dev_result.SetZero(); + } + + static auto dispatch(ck_tile::DeviceMem& c_m_n_dev_buf, + ck_tile::HostTensor& c_m_n_host_result, + ck_tile::HostTensor& c_m_n_dev_result, + int verify, const KernelTraits &trait, ck_tile::GemmHostArgs& gemm_args, const ck_tile::stream_config& s) { init(); const std::string key = assemble_key(trait); auto& kernel_map = get_kernel_map(); if(auto it = kernel_map.find(key); it != kernel_map.end()) { - return it->second(gemm_args, s); //Running single instance + return it->second(c_m_n_dev_buf, c_m_n_host_result, c_m_n_dev_result, 
verify,gemm_args, s); } throw std::runtime_error("No suitable kernel found: " + key); } From ce6175953804dceec37cb1f19e4b5194b3ed9a24 Mon Sep 17 00:00:00 2001 From: Illia Silin <98187287+illsilin@users.noreply.github.com> Date: Mon, 21 Apr 2025 08:48:22 -0700 Subject: [PATCH 055/443] fix daily gfx942 build (#2106) --- Jenkinsfile | 11 +++-------- 1 file changed, 3 insertions(+), 8 deletions(-) diff --git a/Jenkinsfile b/Jenkinsfile index 3d7019bd1f..f8043ba918 100644 --- a/Jenkinsfile +++ b/Jenkinsfile @@ -291,11 +291,6 @@ def cmake_build(Map conf=[:]){ setup_cmd = conf.get("setup_cmd", """${cmake_envs} cmake -G Ninja ${setup_args} -DCMAKE_CXX_FLAGS=" -O3 -ftime-trace " .. """) build_cmd = conf.get("build_cmd", "${build_envs} ninja -j${nt} ${config_targets}") } - else if (setup_args.contains("gfx908;gfx90a;gfx942")){ - //limit the number of build threads when building for multiple gfx9 targets - setup_cmd = conf.get("setup_cmd", "${cmake_envs} cmake ${setup_args} .. ") - build_cmd = conf.get("build_cmd", "${build_envs} make -j32 ${config_targets}") - } else{ setup_cmd = conf.get("setup_cmd", "${cmake_envs} cmake ${setup_args} .. 
") build_cmd = conf.get("build_cmd", "${build_envs} make -j${nt} ${config_targets}") @@ -604,7 +599,7 @@ def Build_CK(Map conf=[:]){ stash includes: "perf_onnx_gemm_gfx12.log", name: "perf_log_gfx12" } else if ( arch_type == 6 ){ - // run standard tests on gfx908 + // run basic tests on gfx908 echo "Run performance tests" sh "./run_gemm_performance_tests.sh 0 CI_${params.COMPILER_VERSION} ${env.BRANCH_NAME} ${NODE_NAME} gfx908" archiveArtifacts "perf_onnx_gemm_gfx908.log" @@ -1115,11 +1110,11 @@ pipeline { agent{ label rocmnode("gfx942") } environment{ setup_args = """ -DCMAKE_INSTALL_PREFIX=../install \ - -DGPU_TARGETS="gfx908;gfx90a;gfx942" \ + -DGPU_TARGETS="gfx90a;gfx942" \ -DCMAKE_CXX_FLAGS=" -O3 " """ execute_args = """ cd ../client_example && rm -rf build && mkdir build && cd build && \ cmake -DCMAKE_PREFIX_PATH="${env.WORKSPACE}/install;/opt/rocm" \ - -DGPU_TARGETS="gfx908;gfx90a;gfx942" \ + -DGPU_TARGETS="gfx90a;gfx942" \ -DCMAKE_CXX_COMPILER="${build_compiler()}" \ -DCMAKE_CXX_FLAGS=" -O3 " .. 
&& make -j """ } From a738e43445f9f82227220922fcd2d683cc9ef626 Mon Sep 17 00:00:00 2001 From: Thomas Ning Date: Mon, 21 Apr 2025 10:21:35 -0700 Subject: [PATCH 056/443] MFMA 16x16x32fp8 (#2103) * add mfma_16x16x32_fp8 * clang format code * Finished the fix for gemm basic * clang foramt * rebuild CI * recover gemm.hpp * add MFMA 16*16*32bf8 --------- Co-authored-by: solin --- .../gemm/pipeline/gemm_pipeline_problem.hpp | 2 + .../ops/gemm/pipeline/tile_gemm_traits.hpp | 3 +- include/ck_tile/ops/gemm/warp/warp_gemm.hpp | 14 ++ .../warp/warp_gemm_attribute_mfma_impl.hpp | 167 +++++++++++++++++- .../ops/gemm/warp/warp_gemm_dispatcher.hpp | 4 + 5 files changed, 188 insertions(+), 2 deletions(-) diff --git a/include/ck_tile/ops/gemm/pipeline/gemm_pipeline_problem.hpp b/include/ck_tile/ops/gemm/pipeline/gemm_pipeline_problem.hpp index cba3677332..0b38e7789e 100644 --- a/include/ck_tile/ops/gemm/pipeline/gemm_pipeline_problem.hpp +++ b/include/ck_tile/ops/gemm/pipeline/gemm_pipeline_problem.hpp @@ -32,6 +32,8 @@ struct GemmPipelineProblemBase static constexpr bool TransposeC = Traits::TransposeC; + static constexpr bool UseStructuredSparsity = Traits::UseStructuredSparsity; + static constexpr index_t kBlockSize = BlockGemmShape::NumWarps * get_warp_size(); static constexpr bool kPadM = Traits::kPadM; diff --git a/include/ck_tile/ops/gemm/pipeline/tile_gemm_traits.hpp b/include/ck_tile/ops/gemm/pipeline/tile_gemm_traits.hpp index 0dae2eeca5..a31004b425 100644 --- a/include/ck_tile/ops/gemm/pipeline/tile_gemm_traits.hpp +++ b/include/ck_tile/ops/gemm/pipeline/tile_gemm_traits.hpp @@ -26,7 +26,8 @@ struct TileGemmTraits using BLayout = BLayout_; using CLayout = CLayout_; - static constexpr bool TransposeC = false; + static constexpr bool TransposeC = false; + static constexpr bool UseStructuredSparsity = false; }; template >>; +using WarpGemmMfma_f32_16x16x64_fp8_fp8 = WarpGemmImpl, + 2>>; + +using WarpGemmMfma_f32_16x16x32_fp8_fp8 = WarpGemmImpl< + WarpGemmAtrributeMfma>>; + 
+using WarpGemmMfma_f32_16x16x64_bf8_bf8 = WarpGemmImpl, + 2>>; + +using WarpGemmMfma_f32_16x16x32_bf8_bf8 = WarpGemmImpl< + WarpGemmAtrributeMfma>>; + using WarpGemmMfma_f32_32x32x16_fp8_fp8_CTransposed = WarpGemmImpl>>; diff --git a/include/ck_tile/ops/gemm/warp/warp_gemm_attribute_mfma_impl.hpp b/include/ck_tile/ops/gemm/warp/warp_gemm_attribute_mfma_impl.hpp index 21a865e792..64c7543ffe 100644 --- a/include/ck_tile/ops/gemm/warp/warp_gemm_attribute_mfma_impl.hpp +++ b/include/ck_tile/ops/gemm/warp/warp_gemm_attribute_mfma_impl.hpp @@ -623,6 +623,165 @@ struct WarpGemmAttributeMfmaImplBf16Bf16F32M64N4K4 }; // FP8 +template +struct WarpGemmAttributeMfmaImpl_f32_16x16x32_f8_base +{ + static constexpr WGAttrCtlEnum Ctrl = Ctrl_; + using ADataType = AType_; + using BDataType = BType_; + using CDataType = float; + + using AVecType = ext_vector_t; + using BVecType = ext_vector_t; + using CVecType = ext_vector_t; + + static constexpr index_t kM = 16; + static constexpr index_t kN = 16; + static constexpr index_t kK = 32; + + static constexpr index_t kAMBlock = 1; + static constexpr index_t kBNBlock = 1; + + static constexpr index_t kAMLane = 16; + static constexpr index_t kBNLane = 16; + static constexpr index_t kABKLane = 4; + static constexpr index_t kABKPerLane = 8; + + static constexpr index_t kCMLane = 4; + static constexpr index_t kCNLane = 16; + static constexpr index_t kCM0PerLane = 1; + static constexpr index_t kCM1PerLane = 4; + + // c_vec += a_vec * b_vec + template + CK_TILE_DEVICE void operator()(CVecType& c_vec, + const AVecType& a_vec, + const BVecType& b_vec, + bool_constant = {}) const + { + if constexpr(Ctrl == WGAttrCtlEnum::Raw_vvv) + { + if constexpr(std::is_same_v && std::is_same_v) + { + DISPATCH_MFMA_("mfma_f32_16x16x32_fp8_fp8", "+v", "v", "v", "v") + } + else if constexpr(std::is_same_v && std::is_same_v) + { + DISPATCH_MFMA_("mfma_f32_16x16x32_fp8_bf8", "+v", "v", "v", "v") + } + else if constexpr(std::is_same_v && std::is_same_v) + { + 
DISPATCH_MFMA_("mfma_f32_16x16x32_bf8_fp8", "+v", "v", "v", "v") + } + else if constexpr(std::is_same_v && std::is_same_v) + { + DISPATCH_MFMA_("mfma_f32_16x16x32_bf8_bf8", "+v", "v", "v", "v") + } + } + else if constexpr(Ctrl == WGAttrCtlEnum::Raw_vaa) + { + if constexpr(std::is_same_v && std::is_same_v) + { + DISPATCH_MFMA_("mfma_f32_16x16x32_fp8_fp8", "+v", "a", "a", "v") + } + else if constexpr(std::is_same_v && std::is_same_v) + { + DISPATCH_MFMA_("mfma_f32_16x16x32_fp8_bf8", "+v", "a", "a", "v") + } + else if constexpr(std::is_same_v && std::is_same_v) + { + DISPATCH_MFMA_("mfma_f32_16x16x32_bf8_fp8", "+v", "a", "a", "v") + } + else if constexpr(std::is_same_v && std::is_same_v) + { + DISPATCH_MFMA_("mfma_f32_16x16x32_bf8_bf8", "+v", "a", "a", "v") + } + } + else if constexpr(Ctrl == WGAttrCtlEnum::Raw_vav) + { + if constexpr(std::is_same_v && std::is_same_v) + { + DISPATCH_MFMA_("mfma_f32_16x16x32_fp8_fp8", "+v", "a", "v", "v") + } + else if constexpr(std::is_same_v && std::is_same_v) + { + DISPATCH_MFMA_("mfma_f32_16x16x32_fp8_bf8", "+v", "a", "v", "v") + } + else if constexpr(std::is_same_v && std::is_same_v) + { + DISPATCH_MFMA_("mfma_f32_16x16x32_bf8_fp8", "+v", "a", "v", "v") + } + else if constexpr(std::is_same_v && std::is_same_v) + { + DISPATCH_MFMA_("mfma_f32_16x16x32_bf8_bf8", "+v", "a", "v", "v") + } + } + else if constexpr(Ctrl == WGAttrCtlEnum::Raw_vva) + { + if constexpr(std::is_same_v && std::is_same_v) + { + DISPATCH_MFMA_("mfma_f32_16x16x32_fp8_fp8", "+v", "v", "a", "v") + } + else if constexpr(std::is_same_v && std::is_same_v) + { + DISPATCH_MFMA_("mfma_f32_16x16x32_fp8_bf8", "+v", "v", "a", "v") + } + else if constexpr(std::is_same_v && std::is_same_v) + { + DISPATCH_MFMA_("mfma_f32_16x16x32_bf8_fp8", "+v", "v", "a", "v") + } + else if constexpr(std::is_same_v && std::is_same_v) + { + DISPATCH_MFMA_("mfma_f32_16x16x32_bf8_bf8", "+v", "v", "a", "v") + } + } + else + { +#if defined(__gfx94__) + if constexpr(std::is_same_v && std::is_same_v) 
+ c_vec = __builtin_amdgcn_mfma_f32_16x16x32_fp8_fp8( + bit_cast(a_vec), bit_cast(b_vec), c_vec, 0, 0, 0); + else if constexpr(std::is_same_v && std::is_same_v) + c_vec = __builtin_amdgcn_mfma_f32_16x16x32_fp8_bf8( + bit_cast(a_vec), bit_cast(b_vec), c_vec, 0, 0, 0); + else if constexpr(std::is_same_v && std::is_same_v) + c_vec = __builtin_amdgcn_mfma_f32_16x16x32_bf8_fp8( + bit_cast(a_vec), bit_cast(b_vec), c_vec, 0, 0, 0); + else if constexpr(std::is_same_v && std::is_same_v) + c_vec = __builtin_amdgcn_mfma_f32_16x16x32_bf8_bf8( + bit_cast(a_vec), bit_cast(b_vec), c_vec, 0, 0, 0); +#else + ck_tile::ignore = c_vec; + ck_tile::ignore = a_vec; + ck_tile::ignore = b_vec; +#endif + } + } + + // c_vec = a_vec * b_vec + CK_TILE_DEVICE CVecType operator()(const AVecType& a_vec, const BVecType& b_vec) const + { +#if defined(__gfx94__) + if constexpr(std::is_same_v && std::is_same_v) + return bit_cast(__builtin_amdgcn_mfma_f32_16x16x32_fp8_fp8( + bit_cast(a_vec), bit_cast(b_vec), CVecType{0.f}, 0, 0, 0)); + else if constexpr(std::is_same_v && std::is_same_v) + return bit_cast(__builtin_amdgcn_mfma_f32_16x16x32_fp8_bf8( + bit_cast(a_vec), bit_cast(b_vec), CVecType{0.f}, 0, 0, 0)); + else if constexpr(std::is_same_v && std::is_same_v) + return bit_cast(__builtin_amdgcn_mfma_f32_16x16x32_bf8_fp8( + bit_cast(a_vec), bit_cast(b_vec), CVecType{0.f}, 0, 0, 0)); + else if constexpr(std::is_same_v && std::is_same_v) + return bit_cast(__builtin_amdgcn_mfma_f32_16x16x32_bf8_bf8( + bit_cast(a_vec), bit_cast(b_vec), CVecType{0.f}, 0, 0, 0)); +#else + ck_tile::ignore = a_vec; + ck_tile::ignore = b_vec; + return CVecType{0.f}; +#endif + } +}; + template struct WarpGemmAttributeMfmaImpl_f32_32x32x16_f8_base { @@ -809,11 +968,17 @@ struct WarpGemmAttributeMfmaImpl_f32_32x32x16_f8_base template using WarpGemmAttributeMfmaImpl_f32_32x32x16_fp8_fp8 = WarpGemmAttributeMfmaImpl_f32_32x32x16_f8_base; - +template +using WarpGemmAttributeMfmaImpl_f32_16x16x32_fp8_fp8 = + 
WarpGemmAttributeMfmaImpl_f32_16x16x32_f8_base; template using WarpGemmAttributeMfmaImpl_f32_32x32x16_fp8_bf8 = WarpGemmAttributeMfmaImpl_f32_32x32x16_f8_base; +template +using WarpGemmAttributeMfmaImpl_f32_16x16x32_bf8_bf8 = + WarpGemmAttributeMfmaImpl_f32_16x16x32_f8_base; + template using WarpGemmAttributeMfmaImpl_f32_32x32x16_bf8_fp8 = WarpGemmAttributeMfmaImpl_f32_32x32x16_f8_base; diff --git a/include/ck_tile/ops/gemm/warp/warp_gemm_dispatcher.hpp b/include/ck_tile/ops/gemm/warp/warp_gemm_dispatcher.hpp index 6320b33598..f437ee10c5 100644 --- a/include/ck_tile/ops/gemm/warp/warp_gemm_dispatcher.hpp +++ b/include/ck_tile/ops/gemm/warp/warp_gemm_dispatcher.hpp @@ -57,12 +57,16 @@ template<> struct WarpGemmMfmaDispatcher struct WarpGemmMfmaDispatcher { using Type = WarpGemmMfma_f32_32x32x16_fp8_fp8; }; +template<> struct WarpGemmMfmaDispatcher { using Type = WarpGemmMfma_f32_16x16x32_fp8_fp8; }; +template<> struct WarpGemmMfmaDispatcher { using Type = WarpGemmMfma_f32_16x16x64_fp8_fp8; }; template<> struct WarpGemmMfmaDispatcher { using Type = WarpGemmMfma_f32_32x32x16_fp8_fp8_CTransposed; }; template<> struct WarpGemmMfmaDispatcher { using Type = WarpGemmMfma_f32_32x32x16_fp8_bf8; }; template<> struct WarpGemmMfmaDispatcher { using Type = WarpGemmMfma_f32_32x32x16_fp8_bf8_CTransposed; }; template<> struct WarpGemmMfmaDispatcher { using Type = WarpGemmMfma_f32_32x32x16_bf8_fp8; }; template<> struct WarpGemmMfmaDispatcher { using Type = WarpGemmMfma_f32_32x32x16_bf8_fp8_CTransposed; }; template<> struct WarpGemmMfmaDispatcher { using Type = WarpGemmMfma_f32_32x32x16_bf8_bf8; }; +template<> struct WarpGemmMfmaDispatcher { using Type = WarpGemmMfma_f32_16x16x32_bf8_bf8; }; +template<> struct WarpGemmMfmaDispatcher { using Type = WarpGemmMfma_f32_16x16x64_bf8_bf8; }; template<> struct WarpGemmMfmaDispatcher { using Type = WarpGemmMfma_f32_32x32x16_bf8_bf8_CTransposed; }; // clang-format on From b092c18da708422fb529193de40b6224446007c5 Mon Sep 17 00:00:00 2001 From: 
Muhammed Emin Ozturk Date: Mon, 21 Apr 2025 11:44:07 -0700 Subject: [PATCH 057/443] MI308 fix for streamk 1-Tile floating point exception (#2101) --- .../gpu/grid/block_to_ctile_map.hpp | 67 ++++++++++++++++--- ...t_gemm_universal_streamk_ut_cases_bf16.inc | 28 -------- 2 files changed, 56 insertions(+), 39 deletions(-) diff --git a/include/ck/tensor_operation/gpu/grid/block_to_ctile_map.hpp b/include/ck/tensor_operation/gpu/grid/block_to_ctile_map.hpp index 64fad1ca48..311545aad6 100644 --- a/include/ck/tensor_operation/gpu/grid/block_to_ctile_map.hpp +++ b/include/ck/tensor_operation/gpu/grid/block_to_ctile_map.hpp @@ -1438,6 +1438,7 @@ struct BlockToCTileMap_GemmStreamK_v2 __host__ __device__ BlockToCTileMap_GemmStreamK_v2( uint32_t m, uint32_t n, uint32_t k, uint32_t grid_size = 1, uint32_t streamk_sel = 1) { + // total output tiles uint32_t num_tiles = math::integer_divide_ceil(m, MPerBlock) * math::integer_divide_ceil(n, NPerBlock); @@ -1445,6 +1446,9 @@ struct BlockToCTileMap_GemmStreamK_v2 uint32_t dp_tiles, dp_num_blocks, sk_total_iters; + // Ensure grid_size is at least 1 to avoid division by zero + grid_size = math::max(grid_size, 1u); + // default to regular DP GEMM if sk blocks == 0 if(streamk_sel == 0) { @@ -1460,31 +1464,45 @@ struct BlockToCTileMap_GemmStreamK_v2 // 2-tile sk + DP GEMM else { - // check if there's enough work for DP+ stream-k bool bigEnough = num_tiles > grid_size; - // select between stream-k strategies + + // Select between stream-k strategies + // Add safety checks to prevent zero or negative values uint32_t sk_tiles = 0; if(streamk_sel == 1) // 1 tile stream-k { sk_tiles = bigEnough ? (num_tiles % grid_size) : num_tiles; + + // Ensure sk_tiles is at least 1 + sk_tiles = math::max(sk_tiles, 1u); } else if(streamk_sel == 2) // 2-tile stream-k { sk_tiles = bigEnough ? 
(grid_size + num_tiles % grid_size) : num_tiles; + + // Ensure sk_tiles is at least 1 but not more than num_tiles + sk_tiles = math::min(math::max(sk_tiles, 1u), num_tiles); } else if(streamk_sel == 3) // 3-tile stream-k { sk_tiles = (num_tiles > (2 * grid_size)) ? (2 * grid_size + num_tiles % grid_size) : num_tiles; + + // Ensure sk_tiles is at least 1 but not more than num_tiles + sk_tiles = math::min(math::max(sk_tiles, 1u), num_tiles); } else if(streamk_sel == 4) // 4-tile stream-k { sk_tiles = (num_tiles > (3 * grid_size)) ? (3 * grid_size + num_tiles % grid_size) : num_tiles; + + // Ensure sk_tiles is at least 1 but not more than num_tiles + sk_tiles = math::min(math::max(sk_tiles, 1u), num_tiles); } + sk_num_blocks = sk_tiles; - // remaining tiles are DP tiles + // Remaining tiles are DP tiles dp_tiles = bigEnough ? (num_tiles - sk_tiles) : 0; sk_total_iters = k_iters_per_tile.get() * sk_tiles; @@ -1500,24 +1518,51 @@ struct BlockToCTileMap_GemmStreamK_v2 // => sk_blocks * m + b = sk_total_iters // => b = sk_total_iters - m * sk_blocks // NOTE: big could be zero - uint32_t k_iters_per_sk_block = sk_total_iters / sk_num_blocks; - sk_num_big_blocks = sk_total_iters - k_iters_per_sk_block * sk_num_blocks; - k_iters_per_big_block = k_iters_per_sk_block + 1; + + // Add safety check for sk_num_blocks to prevent division by zero + if(sk_num_blocks > 0) + { + uint32_t k_iters_per_sk_block = sk_total_iters / sk_num_blocks; + sk_num_big_blocks = sk_total_iters - k_iters_per_sk_block * sk_num_blocks; + k_iters_per_big_block = k_iters_per_sk_block + 1; + } + else + { + // Fallback to default GEMM if no stream-k blocks + sk_num_blocks = 0; + sk_num_big_blocks = 0; + k_iters_per_big_block = 0; + dp_tiles = num_tiles; + dp_num_blocks = num_tiles; + dp_start_block_idx = 0; + sk_total_iters = 0; + } dp_num_blocks = dp_tiles; dp_start_block_idx = sk_num_blocks; } n_tiles = MDiv2(math::integer_divide_ceil(n, NPerBlock)); - // using multiple blocks for parallel reduction + // 
Using multiple blocks for parallel reduction reduction_start_block_idx = dp_start_block_idx + dp_num_blocks; if constexpr(ReductionStrategy == StreamKReductionStrategy::Reduction) { - uint32_t upper_big = math::lcm(k_iters_per_big_block, k_iters_per_tile.get()); - uint32_t upper_little = math::lcm(k_iters_per_big_block - 1, k_iters_per_tile.get()); - equiv_tiles_big = MDiv(upper_big / k_iters_per_tile.get()); - equiv_tiles_little = MDiv(upper_little / k_iters_per_tile.get()); + // Add additional safety checks + if(k_iters_per_big_block > 0 && k_iters_per_tile.get() > 0) + { + uint32_t upper_big = math::lcm(k_iters_per_big_block, k_iters_per_tile.get()); + uint32_t upper_little = + math::lcm(math::max(k_iters_per_big_block - 1, 1u), k_iters_per_tile.get()); + equiv_tiles_big = MDiv(upper_big / k_iters_per_tile.get()); + equiv_tiles_little = MDiv(upper_little / k_iters_per_tile.get()); + } + else + { + // Default safe values + equiv_tiles_big = MDiv(1); + equiv_tiles_little = MDiv(1); + } } } diff --git a/test/gemm_universal_streamk/test_gemm_universal_streamk_ut_cases_bf16.inc b/test/gemm_universal_streamk/test_gemm_universal_streamk_ut_cases_bf16.inc index b6970c4233..22977866b5 100644 --- a/test/gemm_universal_streamk/test_gemm_universal_streamk_ut_cases_bf16.inc +++ b/test/gemm_universal_streamk/test_gemm_universal_streamk_ut_cases_bf16.inc @@ -44,34 +44,6 @@ TYPED_TEST(TestGemmUniversal_Streamk_BF16_KM_KN, SmallM) } } -TYPED_TEST(TestGemmUniversal_Streamk_BF16_MK_KN, MidLargeM) -{ - std::vector Ms{127, 255, 312, 799, 1573}; - constexpr int N = 512; - constexpr int K = 320; - - constexpr int StrideA = K; - constexpr int StrideB = N; - constexpr int StrideC = N; - - for(int M : Ms) - this->Run(M, N, K, StrideA, StrideB, StrideC); -} - -TYPED_TEST(TestGemmUniversal_Streamk_BF16_MK_NK, MidLargeM) -{ - std::vector Ms{127, 255, 312, 799, 1573}; - constexpr int N = 512; - constexpr int K = 320; - - constexpr int StrideA = K; - constexpr int StrideB = K; - constexpr int 
StrideC = N; - - for(int M : Ms) - this->Run(M, N, K, StrideA, StrideB, StrideC); -} - TYPED_TEST(TestGemmUniversal_Streamk_BF16_KM_KN, MidLargeM) { std::vector Ms{127, 255, 312, 799, 1573}; From 4bef60aa57c35575708a4af636f838e6cf26147d Mon Sep 17 00:00:00 2001 From: Thomas Ning Date: Mon, 21 Apr 2025 13:53:03 -0700 Subject: [PATCH 058/443] update code owner (#2113) --- .github/CODEOWNERS | 12 ++++++------ 1 file changed, 6 insertions(+), 6 deletions(-) diff --git a/.github/CODEOWNERS b/.github/CODEOWNERS index 15903314f9..eb69bd7f39 100644 --- a/.github/CODEOWNERS +++ b/.github/CODEOWNERS @@ -1,8 +1,8 @@ -* @illsilin @carlushuang @qianfengz @aosewski @poyenc @geyyer @bartekxk @andriy-ca @afagaj @asleepzzz @tenpercent +* @illsilin @carlushuang @qianfengz @aosewski @poyenc @geyyer @bartekxk @andriy-ca @afagaj @asleepzzz @tenpercent @ThomasNing # Documentation files -docs/ @ROCm/rocm-documentation @illsilin @carlushuang @qianfengz @aosewski @poyenc @geyyer @bartekxk @andriy-ca @afagaj @asleepzzz -*.md @ROCm/rocm-documentation @illsilin @carlushuang @qianfengz @aosewski @poyenc @geyyer @bartekxk @andriy-ca @afagaj @asleepzzz -*.rst @ROCm/rocm-documentation @illsilin @carlushuang @qianfengz @aosewski @poyenc @geyyer @bartekxk @andriy-ca @afagaj @asleepzzz -.readthedocs.yaml @ROCm/rocm-documentation @illsilin @carlushuang @qianfengz @aosewski @poyenc @geyyer @bartekxk @andriy-ca @afagaj @asleepzzz +docs/ @ROCm/rocm-documentation @illsilin @carlushuang @qianfengz @aosewski @poyenc @geyyer @bartekxk @andriy-ca @afagaj @asleepzzz @ThomasNing +*.md @ROCm/rocm-documentation @illsilin @carlushuang @qianfengz @aosewski @poyenc @geyyer @bartekxk @andriy-ca @afagaj @asleepzzz @ThomasNing +*.rst @ROCm/rocm-documentation @illsilin @carlushuang @qianfengz @aosewski @poyenc @geyyer @bartekxk @andriy-ca @afagaj @asleepzzz @ThomasNing +.readthedocs.yaml @ROCm/rocm-documentation @illsilin @carlushuang @qianfengz @aosewski @poyenc @geyyer @bartekxk @andriy-ca @afagaj @asleepzzz 
@ThomasNing # Header directory for Doxygen documentation -library/include/ @ROCm/rocm-documentation @illsilin @carlushuang @qianfengz @aosewski @poyenc @geyyer @bartekxk @andriy-ca @afagaj @asleepzzz +library/include/ @ROCm/rocm-documentation @illsilin @carlushuang @qianfengz @aosewski @poyenc @geyyer @bartekxk @andriy-ca @afagaj @asleepzzz @ThomasNing From 0cca8fa28ff31ee7403a667deffc954bd467041f Mon Sep 17 00:00:00 2001 From: Thomas Ning Date: Tue, 22 Apr 2025 01:13:22 -0700 Subject: [PATCH 059/443] GEMM Multiply Multiply Fix (#2102) * fix the type convert and increase the BF16 conversion + the profile comment * fix the CI --- include/ck/utility/type_convert.hpp | 2 +- profiler/src/profile_gemm_multiply_multiply.cpp | 2 +- 2 files changed, 2 insertions(+), 2 deletions(-) diff --git a/include/ck/utility/type_convert.hpp b/include/ck/utility/type_convert.hpp index c8127aa887..04ae046ac8 100644 --- a/include/ck/utility/type_convert.hpp +++ b/include/ck/utility/type_convert.hpp @@ -117,7 +117,7 @@ inline __host__ __device__ constexpr bhalf_t type_convert(float #if CK_USE_RNE_BF16_CONVERSION return bf16_convert_rtn(x); #else - return uint16_t(uint32_t{x} >> 16); + return uint16_t(static_cast(x) >> 16); #endif } diff --git a/profiler/src/profile_gemm_multiply_multiply.cpp b/profiler/src/profile_gemm_multiply_multiply.cpp index ad2bb77544..42192b5985 100644 --- a/profiler/src/profile_gemm_multiply_multiply.cpp +++ b/profiler/src/profile_gemm_multiply_multiply.cpp @@ -42,7 +42,7 @@ int profile_gemm_multiply_multiply(int argc, char* argv[]) printf("arg1: tensor operation (" OP_NAME ": " OP_DESC ")\n"); printf("arg2: data type (0: fp32; 1: fp16; 2: bf16; 3: int8; 4: f8@f16; 5: f16@f8; 6: " "f16->f8; 7: f8->bf16, " - "comp f8; 8: int8->bf16; 9: f8->f16, comp f8;)\n"); + "comp f8; 8: int8->bf16; 9: int8->f16, 10. 
f8->f16;)\n"); printf("arg3: matrix layout (0: A[m, k] * B[k, n] = C[m, n];\n"); printf(" 1: A[m, k] * B[n, k] = C[m, n];\n"); printf(" 2: A[k, m] * B[k, n] = C[m, n];\n"); From 416e851584f5ec7d8b9cfc6ea73b829900b73750 Mon Sep 17 00:00:00 2001 From: Rostyslav Geyyer <46627076+geyyer@users.noreply.github.com> Date: Tue, 22 Apr 2025 16:08:48 -0500 Subject: [PATCH 060/443] Temporarily disable MX FP4 device tests (#2112) --- include/ck/ck.hpp | 3 +++ test/data_type/test_mx_fp4.cpp | 2 ++ 2 files changed, 5 insertions(+) diff --git a/include/ck/ck.hpp b/include/ck/ck.hpp index 0c2dc799ab..83b76382bc 100644 --- a/include/ck/ck.hpp +++ b/include/ck/ck.hpp @@ -244,6 +244,9 @@ // workaround: compiler issue on gfx950 #define CK_WORKAROUND_FP32_TO_FP4_SR_CONVERSION 1 +// workaround: compiler issue on gfx950 +#define CK_TEMP_DISABLE_FP4_TESTS 1 + // workaround: compiler issue on gfx950 #define CK_WORKAROUND_FP16_TO_FP8_CONVERSION 1 diff --git a/test/data_type/test_mx_fp4.cpp b/test/data_type/test_mx_fp4.cpp index 449f6fc777..7aca42567c 100644 --- a/test/data_type/test_mx_fp4.cpp +++ b/test/data_type/test_mx_fp4.cpp @@ -240,6 +240,7 @@ TEST(MXFP4, HostScaledConvert) EXPECT_EQ(test_size, i); } +#if !CK_TEMP_DISABLE_FP4_TESTS __global__ void test_mx_fp4_device_scaled_convert(uint64_t N, float* p_test, uint64_t* p_completed) { test_mx_fp4_scaled_convert(N, p_test, p_completed); @@ -539,3 +540,4 @@ TEST(MXFP4, DeviceF4x32ToF32x32ScaledConvert) EXPECT_EQ(N, completed); EXPECT_EQ(N, i); } +#endif // CK_TEMP_DISABLE_FP4_TESTS From 504f563f78fbf1a78d1d68fc94cdd69dfea2fb60 Mon Sep 17 00:00:00 2001 From: Gino Lu Date: Wed, 23 Apr 2025 06:52:36 +0800 Subject: [PATCH 061/443] [CK-Tile] warp-gemm support for using V_MFMA_F32_16x16x32_BF16 (#2073) * draft v_mfma_f32_16x16x32_bf16 * fix error config and add debug code. * Solve the CShuffle Problem * draft v_mfma_f32_16x16x32_bf16 * fix error config and add debug code. 
* Solve the CShuffle Problem * fix error while testing new command * Finished the feature of new mfma 16*16*32 * Addressed the comment --------- Co-authored-by: ThomasNing --- example/ck_tile/03_gemm/gemm_basic.cpp | 0 example/ck_tile/03_gemm/gemm_utils.hpp | 12 +- example/ck_tile/03_gemm/run_gemm_example.inc | 1 - include/ck_tile/ops/gemm/warp/warp_gemm.hpp | 23 +++- .../warp/warp_gemm_attribute_mfma_impl.hpp | 126 ++++++++++++++++++ 5 files changed, 154 insertions(+), 8 deletions(-) mode change 100755 => 100644 example/ck_tile/03_gemm/gemm_basic.cpp diff --git a/example/ck_tile/03_gemm/gemm_basic.cpp b/example/ck_tile/03_gemm/gemm_basic.cpp old mode 100755 new mode 100644 diff --git a/example/ck_tile/03_gemm/gemm_utils.hpp b/example/ck_tile/03_gemm/gemm_utils.hpp index 973006196b..25fab6bde0 100644 --- a/example/ck_tile/03_gemm/gemm_utils.hpp +++ b/example/ck_tile/03_gemm/gemm_utils.hpp @@ -55,17 +55,17 @@ struct GemmConfig #endif #if(CK_TILE_PIPELINE_DEFAULT == CK_TILE_PIPELINE_COMPUTE_V3) // Compute friendly for Intrawave scheduler - static constexpr ck_tile::index_t M_Tile = 256; - static constexpr ck_tile::index_t N_Tile = 256; - static constexpr ck_tile::index_t K_Tile = 64; + static constexpr ck_tile::index_t M_Tile = 128; + static constexpr ck_tile::index_t N_Tile = 128; + static constexpr ck_tile::index_t K_Tile = 128; static constexpr ck_tile::index_t M_Warp = 2; static constexpr ck_tile::index_t N_Warp = 2; static constexpr ck_tile::index_t K_Warp = 1; - static constexpr ck_tile::index_t M_Warp_Tile = 32; - static constexpr ck_tile::index_t N_Warp_Tile = 32; - static constexpr ck_tile::index_t K_Warp_Tile = 16; + static constexpr ck_tile::index_t M_Warp_Tile = 16; + static constexpr ck_tile::index_t N_Warp_Tile = 16; + static constexpr ck_tile::index_t K_Warp_Tile = 32; static constexpr bool DoubleSmemBuffer = false; #elif(CK_TILE_PIPELINE_DEFAULT == CK_TILE_PIPELINE_COMPUTE_V4) diff --git a/example/ck_tile/03_gemm/run_gemm_example.inc 
b/example/ck_tile/03_gemm/run_gemm_example.inc index b4ea5d22c0..79ed9ce76b 100644 --- a/example/ck_tile/03_gemm/run_gemm_example.inc +++ b/example/ck_tile/03_gemm/run_gemm_example.inc @@ -402,7 +402,6 @@ int run_gemm_example_with_layouts(int argc, "Error: Incorrect results!", rtol_atol.at(ck_tile::number<0>{}), rtol_atol.at(ck_tile::number<1>{})); - std::cout << "Relative error threshold: " << rtol_atol.at(ck_tile::number<0>{}) << " Absolute error threshold: " << rtol_atol.at(ck_tile::number<1>{}) << std::endl; diff --git a/include/ck_tile/ops/gemm/warp/warp_gemm.hpp b/include/ck_tile/ops/gemm/warp/warp_gemm.hpp index 2c29814b73..bd7a0566a2 100644 --- a/include/ck_tile/ops/gemm/warp/warp_gemm.hpp +++ b/include/ck_tile/ops/gemm/warp/warp_gemm.hpp @@ -24,9 +24,14 @@ using WarpGemmMfmaF16F16F32M32N32K16 = WarpGemmImpl, 2>>; +#if defined(__gfx950__) +using WarpGemmMfmaF16F16F32M16N16K32 = WarpGemmImpl< + WarpGemmAtrributeMfma>>; +#else using WarpGemmMfmaF16F16F32M16N16K32 = WarpGemmImpl, 2>>; +#endif using WarpGemmMfmaF16F16F32M32N32K8SwizzleA = WarpGemmImpl, @@ -49,10 +54,16 @@ using WarpGemmMfmaF16F16F32M32N32K16TransposedCDistribution = WarpGemmAttributeMfmaImplF16F16F32M32N32K8, 2>>; +#if defined(__gfx950__) +using WarpGemmMfmaF16F16F32M16N16K32TransposedCDistribution = + WarpGemmImpl>>; +#else using WarpGemmMfmaF16F16F32M16N16K32TransposedCDistribution = WarpGemmImpl, 2>>; +#endif using WarpGemmMfmaF16F16F32M32N32K16SwizzleBTransposedCDistribution = WarpGemmImpl>>; // bf16 - using WarpGemmMfmaBf16Bf16F32M32N32K8 = WarpGemmImpl< WarpGemmAtrributeMfma>>; @@ -87,9 +97,14 @@ using WarpGemmMfmaBf16Bf16F32M32N32K16 = WarpGemmImpl, 2>>; +#if defined(__gfx950__) +using WarpGemmMfmaBf16Bf16F32M16N16K32 = WarpGemmImpl< + WarpGemmAtrributeMfma>>; +#else using WarpGemmMfmaBf16Bf16F32M16N16K32 = WarpGemmImpl, 2>>; +#endif using WarpGemmMfmaBf16Bf16F32M32N32K8SwizzleA = WarpGemmImpl, @@ -113,10 +128,16 @@ using WarpGemmMfmaBf16Bf16F32M32N32K16TransposedCDistribution = 
WarpGemmAttributeMfmaImplBf16Bf16F32M32N32K8, 2>>; +#if defined(__gfx950__) +using WarpGemmMfmaBf16Bf16F32M16N16K32TransposedCDistribution = + WarpGemmImpl>>; +#else using WarpGemmMfmaBf16Bf16F32M16N16K32TransposedCDistribution = WarpGemmImpl, 2>>; +#endif using WarpGemmMfmaBf16Bf16F32M32N32K16SwizzleBTransposedCDistribution = WarpGemmImpl +struct WarpGemmAttributeMfmaImplBf16Bf16F32M16N16K32 +{ + static constexpr WGAttrCtlEnum Ctrl = Ctrl_; + using ADataType = bf16_t; + using BDataType = bf16_t; + using CDataType = float; + + using AVecType = ext_vector_t; + using BVecType = ext_vector_t; + using CVecType = ext_vector_t; + + static constexpr index_t kM = 16; + static constexpr index_t kN = 16; + static constexpr index_t kK = 32; + + static constexpr index_t kAMBlock = 1; + static constexpr index_t kBNBlock = 1; + + static constexpr index_t kAMLane = 16; + static constexpr index_t kBNLane = 16; + static constexpr index_t kABKLane = 4; + static constexpr index_t kABKPerLane = 8; + + static constexpr index_t kCMLane = 4; + static constexpr index_t kCNLane = 16; + static constexpr index_t kCM0PerLane = 1; + static constexpr index_t kCM1PerLane = 4; + + // c_vec += a_vec * b_vec + template + CK_TILE_DEVICE void operator()(CVecType& c_vec, + const AVecType& a_vec, + const BVecType& b_vec, + bool_constant = {}) const + { + DISPATCH_MFMA_CTRL_("v_mfma_f32_16x16x32_bf16", Ctrl) + else + { +#if defined(__gfx950__) + c_vec = __builtin_amdgcn_mfma_f32_16x16x32_bf16(a_vec, b_vec, c_vec, 0, 0, 0); +#else + ck_tile::ignore = c_vec; + ck_tile::ignore = a_vec; + ck_tile::ignore = b_vec; +#endif + } + } + + // c_vec = a_vec * b_vec + CK_TILE_DEVICE CVecType operator()(const AVecType& a_vec, const BVecType& b_vec) const + { +#if defined(__gfx950__) + return bit_cast( + __builtin_amdgcn_mfma_f32_16x16x32_bf16(a_vec, b_vec, fp32x4_t{0.f}, 0, 0, 0)); +#else + ck_tile::ignore = a_vec; + ck_tile::ignore = b_vec; + return CVecType{0.f}; +#endif + } +}; // FP16 template struct 
WarpGemmAttributeMfmaImplF16F16F32M32N32K8 @@ -188,6 +251,69 @@ struct WarpGemmAttributeMfmaImplF16F16F32M16N16K16 } }; +template +struct WarpGemmAttributeMfmaImplF16F16F32M16N16K32 +{ + static constexpr WGAttrCtlEnum Ctrl = Ctrl_; + using ADataType = fp16_t; + using BDataType = fp16_t; + using CDataType = float; + + using AVecType = ext_vector_t; + using BVecType = ext_vector_t; + using CVecType = ext_vector_t; + + static constexpr index_t kM = 16; + static constexpr index_t kN = 16; + static constexpr index_t kK = 32; + + static constexpr index_t kAMBlock = 1; + static constexpr index_t kBNBlock = 1; + + static constexpr index_t kAMLane = 16; + static constexpr index_t kBNLane = 16; + static constexpr index_t kABKLane = 4; + static constexpr index_t kABKPerLane = 8; + + static constexpr index_t kCMLane = 4; + static constexpr index_t kCNLane = 16; + static constexpr index_t kCM0PerLane = 1; + static constexpr index_t kCM1PerLane = 4; + + // c_vec += a_vec * b_vec + template + CK_TILE_DEVICE void operator()(CVecType& c_vec, + const AVecType& a_vec, + const BVecType& b_vec, + bool_constant = {}) const + { + DISPATCH_MFMA_CTRL_("v_mfma_f32_16x16x32f16", Ctrl) + else + { +#if defined(__gfx950__) + c_vec = __builtin_amdgcn_mfma_f32_16x16x32_f16(a_vec, b_vec, c_vec, 0, 0, 0); +#else + ck_tile::ignore = c_vec; + ck_tile::ignore = a_vec; + ck_tile::ignore = b_vec; +#endif + } + } + + // c_vec = a_vec * b_vec + CK_TILE_DEVICE CVecType operator()(const AVecType& a_vec, const BVecType& b_vec) const + { +#if defined(__gfx950__) + return bit_cast( + __builtin_amdgcn_mfma_f32_16x16x32_f16(a_vec, b_vec, fp32x4_t{0.f}, 0, 0, 0)); +#else + ck_tile::ignore = a_vec; + ck_tile::ignore = b_vec; + return CVecType{0.f}; +#endif + } +}; + template struct WarpGemmAttributeMfmaImplF16F16F32M4N64K4 { From 94662b02d0456bd29c7d3c36eeff39a0f7f49eed Mon Sep 17 00:00:00 2001 From: Khushbu Agarwal Date: Tue, 22 Apr 2025 15:55:19 -0700 Subject: [PATCH 062/443] Adding include directory in 
tile_engine (#2116) --- tile_engine/include/CMakeLists.txt | 1 + 1 file changed, 1 insertion(+) create mode 100755 tile_engine/include/CMakeLists.txt diff --git a/tile_engine/include/CMakeLists.txt b/tile_engine/include/CMakeLists.txt new file mode 100755 index 0000000000..d11a4b3bee --- /dev/null +++ b/tile_engine/include/CMakeLists.txt @@ -0,0 +1 @@ +message("Add include directory") From 39ba03f25d4c4c4e9f551a2dcf001cadd0b86cbe Mon Sep 17 00:00:00 2001 From: lalala-sh Date: Wed, 23 Apr 2025 10:35:34 +0800 Subject: [PATCH 063/443] Moe gemm activation (#2026) * fix useless code and remove usless oob * clang format * fix coredump in e2e test * fix2 * fix clang format * fix output oob * impl int64 but result not correct * int64 index ok now * input output all ok * fix uint32 * revert v1 test * use uint32 * mork to support 13w tokens * moe sorting fix moebuf * fix merge * update moe api fix aiter build * fix buid * fuse silu * silu ok * acale ok * add silu * change code * gemm2 ok * gufusion compatible ok, fix warnings * gu fusion for m32 m64 ok * support bf16 cshuffle * i4 gemm2 ok * i4 gemm2 ok and i4 gemm1 build * 16x16 run ok * change flops; change cshuffle dtype * fuse gelu silu act in moe gemm1 * fp8 with act ready * int4 act ready * remove useless changes * remove useless code change * fix clang format * add the arch limit of int4 moe gemm * fuse moe activation * fix fp8 16x16 * fix no quant case * fix bugs * fix fp8 gufusion bug * remove useless comments * refine activation code & complete moe example * fix int8 bugs * merge tkw1 --------- Co-authored-by: coderfeli Co-authored-by: feli Co-authored-by: illsilin Co-authored-by: root Co-authored-by: Illia Silin <98187287+illsilin@users.noreply.github.com> --- .../65_gemm_multiply_multiply/CMakeLists.txt | 6 + .../moe_gemm1_xdl_fp8.cpp | 163 +++-- .../moe_gemm1_xdl_pk_i4.cpp | 166 ++--- .../moe_gemm2_xdl_fp8.cpp | 82 +-- .../moe_gemm2_xdl_pk_i4.cpp | 19 +- ...dlops_b_preshuffle_gufusion_dequant_v1.hpp | 621 
++++++++++++++++++ ...peline_xdlops_b_preshuffle_gufusion_v1.hpp | 573 ++++++++++++++++ ..._pipeline_xdlops_b_preshuffle_selector.hpp | 141 ++-- .../blockwise_gemm_pipeline_xdlops_base.hpp | 5 +- ...roup_tensor_slice_transfer_v4r1_gather.hpp | 4 +- ...oup_tensor_slice_transfer_v7r3_scatter.hpp | 14 +- .../gpu/device/impl/device_moe_gemm.hpp | 26 +- .../gpu/grid/gridwise_moe_gemm.hpp | 444 +++++++++---- ...wise_tensor_slice_transfer_v3r1_gather.hpp | 7 +- ...ise_tensor_slice_transfer_v7r3_scatter.hpp | 46 +- include/ck/utility/dynamic_buffer.hpp | 60 +- include/ck/utility/tuple_helper.hpp | 7 + .../cpu/reference_moe_gemm.hpp | 85 ++- .../cpu/reference_moe_gemm2.hpp | 2 +- 19 files changed, 1975 insertions(+), 496 deletions(-) create mode 100644 include/ck/tensor_operation/gpu/block/blockwise_gemm_pipeline_xdlops_b_preshuffle_gufusion_dequant_v1.hpp create mode 100644 include/ck/tensor_operation/gpu/block/blockwise_gemm_pipeline_xdlops_b_preshuffle_gufusion_v1.hpp diff --git a/example/65_gemm_multiply_multiply/CMakeLists.txt b/example/65_gemm_multiply_multiply/CMakeLists.txt index deca85ae64..3c1947c058 100644 --- a/example/65_gemm_multiply_multiply/CMakeLists.txt +++ b/example/65_gemm_multiply_multiply/CMakeLists.txt @@ -13,6 +13,12 @@ foreach(gpu IN LISTS GPU_TARGETS) if(gpu IN_LIST gpu_list AND target EQUAL 0) add_example_executable(example_moe_gemm1_xdl_pk_i4 moe_gemm1_xdl_pk_i4.cpp) add_example_executable(example_moe_gemm2_xdl_pk_i4 moe_gemm2_xdl_pk_i4.cpp) + if(CK_hip_VERSION VERSION_LESS_EQUAL 6.3.42132) + set(EXAMPLE_COMPILE_OPTIONS) + list(APPEND EXAMPLE_COMPILE_OPTIONS -mllvm --amdgpu-enable-max-ilp-scheduling-strategy=1) + target_compile_options(example_moe_gemm1_xdl_pk_i4 PRIVATE ${EXAMPLE_COMPILE_OPTIONS}) + target_compile_options(example_moe_gemm2_xdl_pk_i4 PRIVATE ${EXAMPLE_COMPILE_OPTIONS}) + endif() set(target 1) endif() endforeach() diff --git a/example/65_gemm_multiply_multiply/moe_gemm1_xdl_fp8.cpp 
b/example/65_gemm_multiply_multiply/moe_gemm1_xdl_fp8.cpp index f594080755..3b31460953 100644 --- a/example/65_gemm_multiply_multiply/moe_gemm1_xdl_fp8.cpp +++ b/example/65_gemm_multiply_multiply/moe_gemm1_xdl_fp8.cpp @@ -25,7 +25,6 @@ template using S = ck::Sequence; using F16 = ck::half_t; -// using BF16 = ck::bhalf_t; using F8 = ck::f8_t; using F32 = float; @@ -36,7 +35,7 @@ using A0DataType = F8; using B0DataType = F8; using EDataType = F16; using AccDataType = F32; -using CShuffleDataType = F32; +using CShuffleDataType = EDataType; using D0DataType = F32; using D1DataType = F32; using D2DataType = F32; @@ -61,27 +60,25 @@ struct MulABScale __host__ __device__ constexpr void operator()( EDataType& e, const float& c, const float& d0, const float& d1) const { - e = ck::type_convert(c * d1 * d0); + (void)d0; + (void)d1; + e = ck::type_convert(c); } -}; - -// for gate, a_scale, b_scale, fuse silu, -struct MulABScaleSilu -{ - template - __host__ __device__ constexpr void - operator()(E& e, const C& c, const D0& d0, const D1& d1) const; - template <> - __host__ __device__ constexpr void operator()(EDataType& e, - const float& c, - const float& d0, - const float& d1) const + __host__ __device__ constexpr void operator()( + EDataType& e, const EDataType& c, const float& d0, const float& d1) const { - // act - float x0 = 0; - ck::tensor_operation::element_wise::Silu{}(x0, c * d1 * d0); - e = ck::type_convert(x0); + (void)d0; + (void)d1; + e = ck::type_convert(c); + } + template <> + __host__ __device__ constexpr void operator()( + EDataType& e, const EDataType& c, const EDataType& d0, const EDataType& d1) const + { + (void)d0; + (void)d1; + e = ck::type_convert(c); } }; @@ -95,11 +92,19 @@ struct MulABScaleExpertWeight __host__ __device__ constexpr void operator()( EDataType& e, const float& c, const float& d0, const float& d1, const float& d2) const { - // for real kernel use - // warning: hack hack hack here!!!! ignore d0 right now as kernel mul d0 * d2 outside. 
- // tofix:felix + (void)d0; + (void)d1; (void)d2; - e = ck::type_convert(c * d1 * d0); + e = ck::type_convert(c); + } + template <> + __host__ __device__ constexpr void operator()( + EDataType& e, const EDataType& c, const float& d0, const float& d1, const float& d2) const + { + (void)d0; + (void)d1; + (void)d2; + e = ck::type_convert(c); } // for reference cpu template <> @@ -107,16 +112,14 @@ struct MulABScaleExpertWeight float& e, const float& c, const float& d0, const float& d1, const float& d2) const { // for reference cpu - e = ck::type_convert(c * d0 * d1 * d2); + (void)d0; + (void)d1; + (void)d2; + e = ck::type_convert(c); } }; -using CDEElementOp = MulABScaleExpertWeight; // combine MulRoutedWeight = true -// using DsLayout = DsLayoutGate; -// using DsDataType = DsDataTypeGate; -// using CDEElementOp = MulABScale; // combine MulRoutedWeight = false - -// using CDEElementOp = MulABScaleSiluMulGate; +using CDEElementOp = MulABScaleExpertWeight; void preShuffleBuffer(const B0DataType* src, B0DataType* dst, int N, int K, int NXdl) { @@ -155,22 +158,21 @@ using BElementOp = PassThrough; static constexpr auto GemmSpec = ck::tensor_operation::device::GemmSpecialization::Default; static constexpr ck::index_t MPerBlock = 128; -static constexpr ck::index_t MXDLPerWave = 2; +static constexpr ck::index_t MXDLPerWave = 4; static constexpr ck::index_t NXDLPerWave = 2; static constexpr ck::index_t BLOCKSIZE = 256; -static constexpr ck::index_t NPerBlock = 128; -static constexpr ck::index_t MNPerXDL = 32; +static constexpr ck::index_t NPerBlock = 64; +static constexpr ck::index_t MNPerXDL = 16; static constexpr ck::index_t KPerBlock = 128 / sizeof(A0DataType); -static constexpr ck::index_t Nswizzle = true; -static constexpr bool MulRoutedWeight = false; +static constexpr ck::index_t Nswizzle = false; static constexpr ck::index_t AK1 = 16 / sizeof(A0DataType); static constexpr ck::index_t BK1 = 16 / sizeof(B0DataType); static constexpr ck::index_t EVec = 16 / 
sizeof(EDataType); static constexpr ck::index_t D0Vec = 1; static constexpr ck::index_t D1Vec = 1; -static constexpr ck::index_t D2Vec = 1; -// using DeviceOpInstance = ck::tensor_operation::device::DeviceGemmMultiD_Xdl_CShuffle_V3 -using DeviceOpInstance = ck::tensor_operation::device::DeviceMoeGemm +static constexpr ck::index_t ActOP = 1; // 0: gelu_and_mul, 1: silu_and_mul +static constexpr bool MulRoutedWeight = false; +using DeviceOpInstance = ck::tensor_operation::device::DeviceMoeGemm // clang-format off < Row, Col, DsLayout, ELayout, A0DataType, B0DataType, DsDataType, EDataType, AccDataType, CShuffleDataType, AElementOp, BElementOp, CDEElementOp, GemmSpec, @@ -188,8 +190,8 @@ using DeviceOpInstance = ck::tensor_operation::device::DeviceMoeGemm // CShuffle| CShuffle| CBlockTransferClusterLengths| CBlockTransfer| // MXdlPerWave| NXdlPerWave| _MBlock_MWaveMPerXdl| ScalarPerVector| // PerShuffle| PerShuffle| _NBlock_NWaveNPerXdl| _NWaveNPerXdl| - 2, 1, S<1, 32, 1, 8>, S, - ck::BlockGemmPipelineScheduler::Intrawave, ck::BlockGemmPipelineVersion::v1, Nswizzle, true, MulRoutedWeight, A0DataType>; + 2, 2, S<1, 32, 1, 8>, S, + ck::BlockGemmPipelineScheduler::Intrawave, ck::BlockGemmPipelineVersion::v1, ActOP, Nswizzle, true, MulRoutedWeight, true, int32_t, A0DataType>; // clang-format on @@ -201,15 +203,13 @@ int main(int argc, char* argv[]) // GEMM shape ck::index_t N = 4096; - ck::index_t K = 4096; + ck::index_t K = 6144; ck::index_t experts = 8; - ck::index_t sorted_tile_num = 8; - ck::index_t valid_tile_num = 8; - ck::index_t tokens = 128; + ck::index_t sorted_tile_num = 16; + ck::index_t valid_tile_num = 13; + ck::index_t tokens = 64; ck::index_t topk = 2; - // ck::index_t tokens = batch * topk; - if(argc == 1) { // use default case @@ -255,28 +255,22 @@ int main(int argc, char* argv[]) ck::index_t StrideB = K; ck::index_t StrideE = N; constexpr ck::index_t NumDTensor = DsDataType::Size(); - constexpr auto StrideDs = std::array{0, 0, 0}; + constexpr auto 
StrideDs = std::array{1, 1, 1}; ck::index_t KBatch = 1; - // const ck::index_t experts = 8; Tensor expert_ids(HostTensorDescriptor({sorted_tile_num}, {1})); Tensor sorted_token_ids(HostTensorDescriptor({sorted_size}, {1})); Tensor max_token_id(HostTensorDescriptor({1 + sorted_tile_num})); - // max_token_id.mData = {valid_size, 2, 2, 1, 1, 2, 2, 2,2, 2, 2, 2, 2,1,0,0,0}; - // max_token_id.mData = {valid_size, 0, 2, 3, 4, 6, 8, 10, 12, 13}; - // int eids[] = {0, 0,1, 2,3, 3, 4,4, 5, 5, 6, 6, 7, 3, 3, 3}; // {2, 1, 1, 2, 2, 2, 1, 2} - // max_token_id.mData = {valid_size, 0, 2, 3, 4, 6, 8, 10, 12, 13}; - // int eids[] = {0, 0, 1, 2, 3, 3, 4, 4, 5, 5, 6, 6, 7, 3, 3, 3}; // {2, 1, 1, 2, 2, 2, 1, 2} - max_token_id.mData = {valid_size, 0, 1, 2, 3, 4, 5, 6, 7, 8}; - int eids[] = {0, 1, 2, 3, 4, 5, 6, 7, 3, 3, 3}; // {2, 1, 1, 2, 2, 2, 1, 2} + max_token_id.mData = {valid_size}; + int eids[] = {0, 0, 1, 2, 3, 3, 4, 4, 5, 5, 6, 6, 7, 3, 3, 3}; for(int i = 0; i < sorted_tile_num; i++) { expert_ids.mData[i] = eids[i]; } int token_per_tile = (tokens * topk + valid_tile_num - 1) / valid_tile_num; int tokenid = 0; - // sorted_token_ids.mData[0] = 0; + for(int i = 0; i < sorted_size; i++) { int tile_off = i % MPerBlock; @@ -290,13 +284,12 @@ int main(int argc, char* argv[]) sorted_token_ids.mData[i] = tokens; } } - // expert_ids.savetxt("expert_ids.txt", "int"); - // sorted_token_ids.savetxt("sorted_token_ids.txt", "int"); Tensor a0_t_k(HostTensorDescriptor({tokens, K}, {K, 1})); - Tensor b0_e_n_k(HostTensorDescriptor({experts, K, N}, {N * K, 1, K})); - Tensor b0_preshuffled(HostTensorDescriptor({experts, K, N}, {N * K, 1, K})); + Tensor b0_e_n_k(HostTensorDescriptor({experts, K, N * 2}, {N * 2 * K, 1, K})); + Tensor b0_preshuffled(HostTensorDescriptor({experts, K, N * 2}, {N * 2 * K, 1, K})); Tensor d0_t_n(HostTensorDescriptor({tokens, N}, {StrideDs[0], 0})); - Tensor d1_e_n(HostTensorDescriptor({experts, N}, {1, StrideDs[1]})); + Tensor d1_e_n( + HostTensorDescriptor({experts, N * 
2}, {StrideDs[1] * N * 2, StrideDs[1]})); Tensor d2_e_n(HostTensorDescriptor({sorted_size, N}, {1, 0})); Tensor e_t_n_host_result(HostTensorDescriptor({tokens, topk, N}, {topk * N, N, 1})); Tensor e_t_n_device_result( @@ -304,6 +297,7 @@ int main(int argc, char* argv[]) std::cout << "a0_t_k: " << a0_t_k.mDesc << std::endl; std::cout << "b0_e_n_k: " << b0_e_n_k.mDesc << std::endl; std::cout << "d1_e_n: " << d1_e_n.mDesc << std::endl; + std::cout << "d2_e_n: " << d2_e_n.mDesc << std::endl; std::cout << "d0_t_n: " << d0_t_n.mDesc << std::endl; std::cout << "d2_e_n: " << d2_e_n.mDesc << std::endl; std::cout << "e_t_n: " << e_t_n_host_result.mDesc << std::endl; @@ -312,25 +306,25 @@ int main(int argc, char* argv[]) { case 0: break; case 1: - a0_t_k.GenerateTensorValue(GeneratorTensor_2{-2, 2}); - b0_e_n_k.GenerateTensorValue(GeneratorTensor_2{-2, 2}); - d0_t_n.GenerateTensorValue(GeneratorTensor_2{-2, 2}); - d1_e_n.GenerateTensorValue(GeneratorTensor_2{-2, 2}); - d2_e_n.GenerateTensorValue(GeneratorTensor_3{-2, 2}); + a0_t_k.GenerateTensorValue(GeneratorTensor_3{0.0, 1.0}); + b0_e_n_k.GenerateTensorValue(GeneratorTensor_3{-0.5, 0.5}); + d0_t_n.GenerateTensorValue(GeneratorTensor_3{0.0, 1.0}); + d1_e_n.GenerateTensorValue(GeneratorTensor_3{0.0, 1.0}); + d2_e_n.GenerateTensorValue(GeneratorTensor_3{0.0, 1.0}); break; case 2: - a0_t_k.GenerateTensorValue(GeneratorTensor_1{}); - b0_e_n_k.GenerateTensorValue(GeneratorTensor_1{}); - d0_t_n.GenerateTensorValue(GeneratorTensor_1{}); + a0_t_k.GenerateTensorValue(GeneratorTensor_3{0.0, 1.0}); + b0_e_n_k.GenerateTensorValue(GeneratorTensor_3{-0.5, 0.5}); + d0_t_n.GenerateTensorValue(GeneratorTensor_3{0, 1}); d1_e_n.GenerateTensorValue(GeneratorTensor_1{}); d2_e_n.GenerateTensorValue(GeneratorTensor_3{}); break; case 3: - a0_t_k.GenerateTensorValue(GeneratorTensor_1{}); - b0_e_n_k.GenerateTensorValue(GeneratorTensor_2{-2, 2}); + a0_t_k.GenerateTensorValue(GeneratorTensor_3{0.0, 1.0}); + 
b0_e_n_k.GenerateTensorValue(GeneratorTensor_3{0.0, 1.0}); d0_t_n.GenerateTensorValue(GeneratorTensor_1{}); - d1_e_n.GenerateTensorValue(GeneratorTensor_1{}); - d2_e_n.GenerateTensorValue(GeneratorTensor_3{}); + d1_e_n.GenerateTensorValue(GeneratorTensor_3{0.0, 1.0}); + d2_e_n.GenerateTensorValue(GeneratorTensor_3{0.0, 1.0}); break; default: a0_t_k.GenerateTensorValue(GeneratorTensor_3{0.0, 1.0}); @@ -349,9 +343,7 @@ int main(int argc, char* argv[]) DeviceMem d1_device_buf(sizeof(D1DataType) * d1_e_n.mDesc.GetElementSpaceSize()); DeviceMem d2_device_buf(sizeof(D2DataType) * d2_e_n.mDesc.GetElementSpaceSize()); DeviceMem e_device_buf(sizeof(EDataType) * e_t_n_device_result.mDesc.GetElementSpaceSize()); - // a0_t_k.savetxt("a.txt"); - // d0_t_n.savetxt("d0_t_n.txt", "int"); - // d1_e_n.savetxt("d1_e_n.txt", "int"); + sorted_token_ids_dev.ToDevice(sorted_token_ids.mData.data()); expert_ids_dev.ToDevice(expert_ids.mData.data()); max_token_id_dev.ToDevice(max_token_id.mData.data()); @@ -369,7 +361,8 @@ int main(int argc, char* argv[]) int NPerXdl = device_op.GetPreShuffleParameters(); - preShuffleBuffer(b0_e_n_k.mData.data(), b0_preshuffled.mData.data(), N * experts, K, NPerXdl); + preShuffleBuffer( + b0_e_n_k.mData.data(), b0_preshuffled.mData.data(), N * 2 * experts, K, NPerXdl); b0_device_buf.ToDevice(b0_preshuffled.mData.data()); @@ -408,9 +401,9 @@ int main(int argc, char* argv[]) { float ave_time = invoker.Run(argument, StreamConfig{nullptr, time_kernel}); - std::size_t flop = std::size_t(2) * tokens * topk * N * K; + std::size_t flop = std::size_t(2) * tokens * topk * N * 2 * K; std::size_t num_btype = sizeof(A0DataType) * valid_tile_num * K + - sizeof(B0DataType) * K * N * experts + + sizeof(B0DataType) * K * N * 2 * experts + sizeof(EDataType) * valid_tile_num * N; float tflops = static_cast(flop) / 1.E9 / ave_time; @@ -437,6 +430,7 @@ int main(int argc, char* argv[]) PassThrough, PassThrough, PassThrough, + ActOP, MulRoutedWeight>; auto ref_moe_gemm = 
ReferenceGemmInstance{}; auto ref_invoker = ref_moe_gemm.MakeInvoker(); @@ -446,7 +440,9 @@ int main(int argc, char* argv[]) max_token_id, MPerBlock, a0_t_k, + d0_t_n, b0_e_n_k, + d1_e_n, c_t_k_n, d2_e_n, PassThrough{}, @@ -472,15 +468,14 @@ int main(int argc, char* argv[]) c_t_k_n(t, topk_id, n), d0_t_n(t, n), d1_e_n(e, n), - 1.f); + d2_e_n(e, n)); } } e_device_buf.FromDevice(e_t_n_device_result.mData.data()); - // e_t_n_device_result.savetxt("out.txt"); - // e_t_n_host_result.savetxt("ref.txt"); + return ck::utils::check_err( - e_t_n_device_result, e_t_n_host_result, "Error: Incorrect results!", 1e-3, 5e-2) + e_t_n_device_result, e_t_n_host_result, "Error: Incorrect results!", 1e-3, 5e-1) ? 0 : 1; } diff --git a/example/65_gemm_multiply_multiply/moe_gemm1_xdl_pk_i4.cpp b/example/65_gemm_multiply_multiply/moe_gemm1_xdl_pk_i4.cpp index fb8a8b9826..3c3ef16198 100644 --- a/example/65_gemm_multiply_multiply/moe_gemm1_xdl_pk_i4.cpp +++ b/example/65_gemm_multiply_multiply/moe_gemm1_xdl_pk_i4.cpp @@ -36,7 +36,7 @@ using A0DataType = F8; using B0DataType = I4; using EDataType = F16; using AccDataType = F32; -using CShuffleDataType = F32; +using CShuffleDataType = F16; using D0DataType = F32; using D1DataType = F32; using D2DataType = F32; @@ -47,7 +47,8 @@ using B0Layout = Col; using ELayout = Row; using D0Layout = Row; using D1Layout = Col; -using DsLayout = ck::Tuple; +using D2Layout = ELayout; +using DsLayout = ck::Tuple; // for gate, a_scale, b_scale struct MulABScale @@ -56,42 +57,32 @@ struct MulABScale __host__ __device__ constexpr void operator()(E& e, const C& c, const D0& d0, const D1& d1) const; + template <> + __host__ __device__ constexpr void operator()( + EDataType& e, const EDataType& c, const float& d0, const float& d1) const + { + (void)d0; + (void)d1; +#if CK_USE_PK4_LAYOUT_SHUFFLE + e = ck::type_convert(c); +#else + e = ck::type_convert(c); +#endif + } template <> __host__ __device__ constexpr void operator()( EDataType& e, const float& c, const float& 
d0, const float& d1) const { + (void)d0; + (void)d1; #if CK_USE_PK4_LAYOUT_SHUFFLE - e = ck::type_convert(c * d1 * d0 * 16); + e = ck::type_convert(c); #else - e = ck::type_convert(c * d1 * d0); + e = ck::type_convert(c); #endif } }; -// for gate, a_scale, b_scale, fuse silu, -struct MulABScaleSilu -{ - template - __host__ __device__ constexpr void - operator()(E& e, const C& c, const D0& d0, const D1& d1) const; - - template <> - __host__ __device__ constexpr void operator()(EDataType& e, - const float& c, - const float& d0, - const float& d1) const - { - // act - float x0 = 0; -#if CK_USE_PK4_LAYOUT_SHUFFLE - ck::tensor_operation::element_wise::Silu{}(x0, c * d1 * d0 * 16); -#else - ck::tensor_operation::element_wise::Silu{}(x0, c * d1 * d0); -#endif - e = ck::type_convert(x0); - } -}; - struct MulABScaleExpertWeight { template @@ -102,13 +93,19 @@ struct MulABScaleExpertWeight __host__ __device__ constexpr void operator()( EDataType& e, const float& c, const float& d0, const float& d1, const float& d2) const { + (void)d0; + (void)d1; (void)d2; - -#if CK_USE_PK4_LAYOUT_SHUFFLE - e = ck::type_convert(c * d1 * d0 * 16); -#else - e = ck::type_convert(c * d1 * d0); -#endif + e = ck::type_convert(c); + } + template <> + __host__ __device__ constexpr void operator()( + EDataType& e, const EDataType& c, const float& d0, const float& d1, const float& d2) const + { + (void)d0; + (void)d1; + (void)d2; + e = ck::type_convert(c); } // for reference cpu template <> @@ -116,15 +113,18 @@ struct MulABScaleExpertWeight float& e, const float& c, const float& d0, const float& d1, const float& d2) const { // for reference cpu -#if CK_USE_PK4_LAYOUT_SHUFFLE - e = ck::type_convert(c * d0 * d1 * d2 * 16); -#else - e = ck::type_convert(c * d0 * d1 * d2); -#endif + (void)d0; + (void)d1; + (void)d2; + e = ck::type_convert(c); } }; -using CDEElementOp = MulABScaleExpertWeight; +static constexpr bool MulRoutedWeight = true; + +using CDEElementOp = MulABScaleExpertWeight; // combine 
MulRoutedWeight = true + +// using CDEElementOp = MulABScale; // combine MulRoutedWeight = true #if 1 void preShuffleBuffer(const I4* src, I4* dst, int N, int K, int NXdl) @@ -165,54 +165,24 @@ using AElementOp = PassThrough; using BElementOp = PassThrough; static constexpr auto GemmSpec = ck::tensor_operation::device::GemmSpecialization::Default; -#if 0 -static constexpr ck::index_t MPerBlock = 64; -static constexpr ck::index_t MXDLPerWave = 1; -static constexpr ck::index_t NXDLPerWave = 2; -static constexpr ck::index_t BLOCKSIZE = 256; -static constexpr ck::index_t NPerBlock = 128; -static constexpr ck::index_t MNPerXDL = 32; -static constexpr ck::index_t KPerBlock = 64 / sizeof(A0DataType); -static constexpr ck::index_t Nswizzle = false; -static constexpr ck::index_t AK1 = 16 / sizeof(A0DataType); -static constexpr ck::index_t BK1 = 32 / sizeof(B0DataType); -static constexpr ck::index_t EVec = 16 / sizeof(EDataType); -static constexpr ck::index_t D0Vec = 1; -static constexpr ck::index_t D1Vec = 1; -// clang-format off -using DeviceOpInstance = ck::tensor_operation::device::DeviceMoeGemm< - Row, Col, DsLayout, ELayout, - A0DataType, B0DataType, DsDataType, EDataType, AccDataType, CShuffleDataType, - AElementOp, BElementOp, CDEElementOp, GemmSpec, - BLOCKSIZE, MPerBlock, NPerBlock, KPerBlock, - AK1, BK1, - MNPerXDL, MNPerXDL, - MXDLPerWave, NXDLPerWave, - S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, AK1, AK1, 0, - S<2, 128, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, BK1, BK1, 0, - MXDLPerWave, 1, S<1, 32, 1, 8>, S, - ck::BlockGemmPipelineScheduler::Intrawave, ck::BlockGemmPipelineVersion::v1, Nswizzle, true, A0DataType>; -// clang-format on -#else static constexpr ck::index_t MPerBlock = 128; -static constexpr ck::index_t Nswizzle = false; -static constexpr bool MulRoutedWeight = false; +static constexpr ck::index_t Nswizzle = false; +static constexpr ck::index_t Act_OP = 1; // 0: gelu_and_mul, 1: silu_and_mul // clang-format off using DeviceOpInstance = 
ck::tensor_operation::device::DeviceMoeGemm< Row, Col, DsLayout, ELayout, A0DataType, B0DataType, DsDataType, EDataType, AccDataType, CShuffleDataType, AElementOp, BElementOp, CDEElementOp, GemmSpec, - 256, MPerBlock, 128, 128, + 256, MPerBlock, 64, 128, 16, 32, - 32, 32, - 4, 1, + 16, 16, + 8, 1, S<8, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, 0, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 32, 32, 0, - 1, 1, S<1, 32, 1, 8>, S<8, 1, 1, 1>, - ck::BlockGemmPipelineScheduler::Intrawave, ck::BlockGemmPipelineVersion::v1, Nswizzle, true, MulRoutedWeight, A0DataType>; + 2, 1, S<1, 32, 1, 8>, S<8, 1, 1>, + ck::BlockGemmPipelineScheduler::Intrawave, ck::BlockGemmPipelineVersion::v1, Act_OP, Nswizzle, true, MulRoutedWeight, true, ck::index_t, A0DataType>; // clang-format on -#endif int main(int argc, char* argv[]) { @@ -220,13 +190,10 @@ int main(int argc, char* argv[]) int init_method = 1; bool time_kernel = true; - // tokens = 1 - // topk = 1 - // experts = 8 // per expert: // GEMM shape - ck::index_t N = 4096 * 2; - ck::index_t K = 6144; + ck::index_t N = 14336; + ck::index_t K = 4096; ck::index_t experts = 8; ck::index_t sorted_tile_num = 16; ck::index_t valid_tile_num = 13; @@ -266,20 +233,20 @@ int main(int argc, char* argv[]) ck::index_t StrideB = K; ck::index_t StrideE = N; constexpr ck::index_t NumDTensor = DsDataType::Size(); - constexpr auto StrideDs = std::array{0, 0}; + constexpr auto StrideDs = std::array{0, 0, 0}; ck::index_t KBatch = 1; Tensor expert_ids(HostTensorDescriptor({sorted_tile_num}, {1})); Tensor sorted_token_ids(HostTensorDescriptor({sorted_size}, {1})); Tensor max_token_id(HostTensorDescriptor({1 + sorted_tile_num})); - max_token_id.mData = {valid_size, 2, 2, 1, 1, 2, 2, 2, 2, 2, 2, 1, 2, 2, 0, 0, 0}; + max_token_id.mData = {valid_size}; int eids[] = {0, 0, 1, 2, 3, 3, 4, 4, 5, 5, 6, 7, 7, 3, 3, 3}; for(int i = 0; i < sorted_tile_num; i++) { expert_ids.mData[i] = eids[i]; } - int token_per_tile = tokens * topk / valid_tile_num; + int token_per_tile 
= (tokens * topk + valid_tile_num - 1) / valid_tile_num; int tokenid = 0; for(int i = 0; i < sorted_size; i++) { @@ -294,11 +261,12 @@ int main(int argc, char* argv[]) sorted_token_ids.mData[i] = tokens; } } + Tensor a0_t_k(HostTensorDescriptor({tokens, K}, {K, 1})); - Tensor b0_e_n_k(HostTensorDescriptor({experts, K, N}, {N * K, 1, K})); - Tensor b0_preshuffled(HostTensorDescriptor({experts, K, N}, {N * K, 1, K})); + Tensor b0_e_n_k(HostTensorDescriptor({experts, K, N * 2}, {N * 2 * K, 1, K})); + Tensor b0_preshuffled(HostTensorDescriptor({experts, K, N * 2}, {N * 2 * K, 1, K})); Tensor d0_t_n(HostTensorDescriptor({tokens, N}, {StrideDs[0], 0})); - Tensor d1_e_n(HostTensorDescriptor({experts, N}, {1, StrideDs[1]})); + Tensor d1_e_n(HostTensorDescriptor({experts, N * 2}, {1, StrideDs[1]})); Tensor d2_e_n(HostTensorDescriptor({sorted_size, N}, {1, 0})); Tensor e_t_n_host_result(HostTensorDescriptor({tokens, topk, N}, {topk * N, N, 1})); Tensor e_t_n_device_result( @@ -306,6 +274,7 @@ int main(int argc, char* argv[]) std::cout << "a0_t_k: " << a0_t_k.mDesc << std::endl; std::cout << "b0_e_n_k: " << b0_e_n_k.mDesc << std::endl; + std::cout << "d2_e_n: " << d2_e_n.mDesc << std::endl; std::cout << "d1_e_n: " << d1_e_n.mDesc << std::endl; std::cout << "d0_t_n: " << d0_t_n.mDesc << std::endl; std::cout << "e_t_n: " << e_t_n_host_result.mDesc << std::endl; @@ -314,11 +283,11 @@ int main(int argc, char* argv[]) { case 0: break; case 1: - a0_t_k.GenerateTensorValue(GeneratorTensor_2{-2, 2}); - b0_e_n_k.GenerateTensorValue(GeneratorTensor_2{-2, 2}); - d0_t_n.GenerateTensorValue(GeneratorTensor_2{-2, 2}); - d1_e_n.GenerateTensorValue(GeneratorTensor_2{-2, 2}); - d2_e_n.GenerateTensorValue(GeneratorTensor_3{-2, 2}); + a0_t_k.GenerateTensorValue(GeneratorTensor_3{0.0, 1.0}); + b0_e_n_k.GenerateTensorValue(GeneratorTensor_3{-0.5, 0.5}); + d0_t_n.GenerateTensorValue(GeneratorTensor_3{0.0, 1.0}); + d1_e_n.GenerateTensorValue(GeneratorTensor_3{0.0, 1.0}); + 
d2_e_n.GenerateTensorValue(GeneratorTensor_3{0.0, 1.0}); break; case 2: a0_t_k.GenerateTensorValue(GeneratorTensor_1{}); @@ -497,9 +466,9 @@ int main(int argc, char* argv[]) { float ave_time = invoker.Run(argument, StreamConfig{nullptr, time_kernel}); - std::size_t flop = std::size_t(2) * tokens * topk * N * K; + std::size_t flop = std::size_t(2) * tokens * topk * N * 2 * K; std::size_t num_btype = sizeof(A0DataType) * valid_tile_num * K + - sizeof(B0DataType) / 2 * K * N * experts + + sizeof(B0DataType) / 2 * K * N * 2 * experts + sizeof(EDataType) * valid_tile_num * N; float tflops = static_cast(flop) / 1.E9 / ave_time; @@ -526,6 +495,7 @@ int main(int argc, char* argv[]) PassThrough, PassThrough, PassThrough, + Act_OP, MulRoutedWeight>; auto ref_moe_gemm = ReferenceGemmInstance{}; auto ref_invoker = ref_moe_gemm.MakeInvoker(); @@ -535,7 +505,9 @@ int main(int argc, char* argv[]) max_token_id, MPerBlock, a0_t_k, + d0_t_n, b0_e_n_k, + d1_e_n, c_t_k_n, d2_e_n, PassThrough{}, @@ -561,13 +533,13 @@ int main(int argc, char* argv[]) c_t_k_n(t, topk_id, n), d0_t_n(t, n), d1_e_n(e, n), - 1.f); + d2_e_n(e, n)); } } e_device_buf.FromDevice(e_t_n_device_result.mData.data()); return ck::utils::check_err( - e_t_n_device_result, e_t_n_host_result, "Error: Incorrect results!", 1e-3, 5e-2) + e_t_n_device_result, e_t_n_host_result, "Error: Incorrect results!", 1e-3, 5e-1) ? 
0 : 1; } diff --git a/example/65_gemm_multiply_multiply/moe_gemm2_xdl_fp8.cpp b/example/65_gemm_multiply_multiply/moe_gemm2_xdl_fp8.cpp index 04f10b53ae..42d892fe26 100644 --- a/example/65_gemm_multiply_multiply/moe_gemm2_xdl_fp8.cpp +++ b/example/65_gemm_multiply_multiply/moe_gemm2_xdl_fp8.cpp @@ -25,7 +25,6 @@ template using S = ck::Sequence; using F16 = ck::half_t; -// using BF16 = ck::bhalf_t; using F8 = ck::f8_t; using F32 = float; @@ -36,7 +35,7 @@ using A0DataType = F8; using B0DataType = F8; using EDataType = F16; using AccDataType = F32; -using CShuffleDataType = F32; +using CShuffleDataType = F16; using D0DataType = F32; using D1DataType = F32; using D2DataType = F32; @@ -48,7 +47,6 @@ using ELayout = Row; using D0Layout = Row; using D1Layout = Col; using D2Layout = ELayout; -// using DsLayoutGate = ck::Tuple; using DsLayout = ck::Tuple; // d0: ascale, d1: bscale, d2:expert weight @@ -62,11 +60,19 @@ struct MulABScaleExpertWeight __host__ __device__ constexpr void operator()( EDataType& e, const float& c, const float& d0, const float& d1, const float& d2) const { - // for real kernel use - // warning: hack hack hack here!!!! ignore d0 right now as kernel mul d0 * d2 outside. 
- // tofix:felix (void)d0; - e = ck::type_convert(c * d1 * d2); + (void)d1; + (void)d2; + e = ck::type_convert(c); + } + template <> + __host__ __device__ constexpr void operator()( + EDataType& e, const EDataType& c, const float& d0, const float& d1, const float& d2) const + { + (void)d0; + (void)d1; + (void)d2; + e = ck::type_convert(c); } // for reference cpu template <> @@ -119,14 +125,12 @@ using CDEElementOp = MulABScaleExpertWeight; static constexpr auto GemmSpec = ck::tensor_operation::device::GemmSpecialization::Default; static constexpr ck::index_t MPerBlock = 128; static constexpr ck::index_t BLOCKSIZE = 256; -static constexpr ck::index_t MXDLPerWave = 2; -static constexpr ck::index_t NXDLPerWave = 2; +static constexpr ck::index_t MXDLPerWave = 4; +static constexpr ck::index_t NXDLPerWave = 4; static constexpr ck::index_t NPerBlock = 128; -static constexpr ck::index_t MNPerXDL = 32; +static constexpr ck::index_t MNPerXDL = 16; static constexpr ck::index_t KPerBlock = 128 / sizeof(A0DataType); -// static constexpr ck::index_t MXDLPerWave = MPerBlock / 32; //todo fix this constraint -// static constexpr ck::index_t CShuffleMXDLPerWave = MPerBlock / 32; static constexpr ck::index_t CShuffleNLane = 32; static constexpr ck::index_t CShuffleMLane = BLOCKSIZE / CShuffleNLane; static constexpr ck::index_t AK1 = 16 / sizeof(A0DataType); @@ -135,7 +139,7 @@ static constexpr ck::index_t EVec = 2; static constexpr ck::index_t D0Vec = 1; static constexpr ck::index_t D1Vec = 1; static constexpr ck::index_t D2Vec = 1; -static constexpr bool MulRoutedWeight = false; +static constexpr bool MulRoutedWeight = true; using DeviceOpInstance = ck::tensor_operation::device::DeviceMoeGemm // clang-format off ///######| ALayout| BLayout| DsLayout| ELayout| AData| BData| DsData| EData| AccData| CShuffle| A| B| CDE| GEMM| Block| MPer| NPer| KPer| AK1| BK1| MPer| NPer| MXdl| NXdl| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| 
ABlockLds| BBlockTransfer| BBlockTransfer| BBlockTransfer| BlockTransfer| BBlockTransfer| BBlockTransfer| BBlockLds| CShuffle| CShuffle| CBlockTransferClusterLengths| CBlockTransfer| @@ -164,8 +168,8 @@ using DeviceOpInstance = ck::tensor_operation::device::Devic // CShuffle| CShuffle| CBlockTransferClusterLengths| CBlockTransfer| // MXdlPerWave| NXdlPerWave| _MBlock_MWaveMPerXdl| ScalarPerVector| // PerShuffle| PerShuffle| _NBlock_NWaveNPerXdl| _NWaveNPerXdl| - 2, 1, S<1, CShuffleMLane, 1, CShuffleNLane>, S, - ck::BlockGemmPipelineScheduler::Intrawave, ck::BlockGemmPipelineVersion::v1, false, false, MulRoutedWeight, A0DataType>; + 4, 2, S<1, CShuffleMLane, 1, CShuffleNLane>, S, + ck::BlockGemmPipelineScheduler::Intrawave, ck::BlockGemmPipelineVersion::v1, 0, false, false, MulRoutedWeight, false, int32_t, A0DataType>; // kernel 2: 128->32x128x128 // < Row, Col, DsLayout, ELayout, A0DataType, B0DataType, DsDataType, EDataType, AccDataType, CShuffleDataType, AElementOp, BElementOp, CDEElementOp, GemmSpec, 128, 32, 128, 128, 16, 16, 32, 32, 1, 2, S<8, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, 0, S<8, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, 0, 1, 1, S<1, 16, 1, 8>, S<8, 8, 1>, ck::BlockGemmPipelineScheduler::Interwave, ck::BlockGemmPipelineVersion::v1, EDataType>; @@ -177,16 +181,13 @@ int main(int argc, char* argv[]) int init_method = 1; bool time_kernel = true; - // tokens = 1 - // topk = 1 - // experts = 8 // per expert: // GEMM shape ck::index_t N = 4096; ck::index_t K = 4096; ck::index_t experts = 8; - ck::index_t sorted_tile_num = 6; - ck::index_t valid_tile_num = 6; + ck::index_t sorted_tile_num = 16; + ck::index_t valid_tile_num = 13; ck::index_t sorted_size = sorted_tile_num * MPerBlock; ck::index_t valid_size = valid_tile_num * MPerBlock; ck::index_t tokens = 128; @@ -212,6 +213,18 @@ int main(int argc, char* argv[]) K = std::stoi(argv[5]); tokens = std::stoi(argv[6]); } + else if(argc == 9) + { + + do_verification = std::stoi(argv[1]); + init_method = 
std::stoi(argv[2]); + time_kernel = std::stoi(argv[3]); + N = std::stoi(argv[4]); + K = std::stoi(argv[5]); + tokens = std::stoi(argv[6]); + sorted_tile_num = std::stoi(argv[7]); + valid_tile_num = std::stoi(argv[8]); + } else { printf("arg1: verification (0=no, 1=yes)\n"); @@ -229,15 +242,13 @@ int main(int argc, char* argv[]) ck::index_t KBatch = 1; - // const ck::index_t experts = 8; Tensor expert_ids(HostTensorDescriptor({sorted_tile_num}, {1})); Tensor sorted_token_ids(HostTensorDescriptor({sorted_size}, {1})); Tensor max_token_id(HostTensorDescriptor({1})); - // max_token_id.mData[0] = valid_size; - // max_token_id.mData = {valid_size, 0, 2, 3, 4, 6, 8, 10, 12, 13}; - // int eids[] = {0, 0, 1, 2, 3, 3, 4, 4, 5, 5, 6, 7, 7, 3, 3, 3}; - max_token_id.mData = {valid_size, 0, 1, 2, 3, 4, 5, 6, 7, 8}; - int eids[] = {0, 1, 2, 3, 4, 5, 6, 7, 3, 3, 3}; // {2, 1, 1, 2, 2, 2, 1, 2} + + max_token_id.mData = {valid_size, 0, 2, 3, 4, 6, 8, 10, 12, 13}; + int eids[] = {0, 0, 1, 2, 3, 3, 4, 4, 5, 5, 6, 7, 7, 3, 3, 3}; + for(int i = 0; i < sorted_tile_num; i++) { expert_ids.mData[i] = eids[i]; @@ -249,7 +260,7 @@ int main(int argc, char* argv[]) } int token_per_tile = tokens * topk / valid_tile_num; int tokenid = 0; - // sorted_token_ids.mData[0] = 0; + for(int i = 0; i < sorted_size; i++) { int tile_off = i % MPerBlock; @@ -263,8 +274,7 @@ int main(int argc, char* argv[]) sorted_token_ids.mData[i] = tokens; } } - expert_ids.savetxt("expert_ids.txt", "int"); - sorted_token_ids.savetxt("sorted_token_ids.txt", "int"); + Tensor a0_t_k_k(HostTensorDescriptor({tokens, topk, K}, {topk * K, K, 1})); Tensor b0_e_n_k(HostTensorDescriptor({experts, K, N}, {N * K, 1, K})); Tensor b0_preshuffled(HostTensorDescriptor({experts, K, N}, {N * K, 1, K})); @@ -315,12 +325,7 @@ int main(int argc, char* argv[]) DeviceMem d1_device_buf(sizeof(D1DataType) * d1_e_n.mDesc.GetElementSpaceSize()); DeviceMem d2_device_buf(sizeof(D2DataType) * d2_e_n.mDesc.GetElementSpaceSize()); DeviceMem 
e_device_buf(sizeof(EDataType) * e_t_n_device_result.mDesc.GetElementSpaceSize()); - // a0_t_k_k.savetxt("a.txt"); - // expert_ids.savetxt("expert_ids.txt", "int"); - // sorted_token_ids.savetxt("sorted_token_ids.txt", "int"); - // d0_t_n.savetxt("d0_t_n.txt", "int"); - // d1_e_n.savetxt("d1_e_n.txt", "int"); - // d2_e_n.savetxt("d2_e_n.txt", "int"); + sorted_token_ids_dev.ToDevice(sorted_token_ids.mData.data()); expert_ids_dev.ToDevice(expert_ids.mData.data()); max_token_id_dev.ToDevice(max_token_id.mData.data()); @@ -398,7 +403,7 @@ int main(int argc, char* argv[]) e_device_buf.ToDevice(e_t_n_device_result.mData.data()); invoker.Run(argument, StreamConfig{nullptr, false, 0, 0, 1}); - Tensor c_t_n({tokens, N}); + Tensor c_t_n({tokens, N}); using ReferenceGemmInstance = ck::tensor_operation::host::ReferenceMoeGemm2(c * d1 * d2 * 16); + e = ck::type_convert(c * 16); #else - e = ck::type_convert(c * d1 * d2); + e = ck::type_convert(c); #endif } // for reference cpu @@ -125,10 +127,10 @@ using CDEElementOp = MulABScaleExpertWeight; static constexpr auto GemmSpec = ck::tensor_operation::device::GemmSpecialization::Default; static constexpr ck::index_t MPerBlock = 128; static constexpr ck::index_t BLOCKSIZE = 256; -static constexpr ck::index_t MXDLPerWave = 4; -static constexpr ck::index_t NXDLPerWave = 1; +static constexpr ck::index_t MXDLPerWave = 8; +static constexpr ck::index_t NXDLPerWave = 2; static constexpr ck::index_t NPerBlock = 128; -static constexpr ck::index_t MNPerXDL = 32; +static constexpr ck::index_t MNPerXDL = 16; static constexpr ck::index_t KPerBlock = 128 / sizeof(A0DataType); static constexpr ck::index_t CShuffleNLane = 32; static constexpr ck::index_t CShuffleMLane = BLOCKSIZE / CShuffleNLane; @@ -149,8 +151,8 @@ using DeviceOpInstance = ck::tensor_operation::device::Devic MXDLPerWave, NXDLPerWave, S<8, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, AK1, AK1, 0, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, BK1, BK1, 0, - 1, 1, S<1, CShuffleMLane, 1, 
CShuffleNLane>, S, - ck::BlockGemmPipelineScheduler::Intrawave, ck::BlockGemmPipelineVersion::v1, false, false, MulRoutedWeight, A0DataType>; + 2, 2, S<1, CShuffleMLane, 1, CShuffleNLane>, S, + ck::BlockGemmPipelineScheduler::Intrawave, ck::BlockGemmPipelineVersion::v1, 0, false, false, MulRoutedWeight, false, ck::index_t, A0DataType>; // clang-format on int main(int argc, char* argv[]) @@ -159,9 +161,6 @@ int main(int argc, char* argv[]) int init_method = 1; bool time_kernel = true; - // tokens = 1 - // topk = 1 - // experts = 8 // per expert: // GEMM shape ck::index_t N = 4096; diff --git a/include/ck/tensor_operation/gpu/block/blockwise_gemm_pipeline_xdlops_b_preshuffle_gufusion_dequant_v1.hpp b/include/ck/tensor_operation/gpu/block/blockwise_gemm_pipeline_xdlops_b_preshuffle_gufusion_dequant_v1.hpp new file mode 100644 index 0000000000..29750b8baa --- /dev/null +++ b/include/ck/tensor_operation/gpu/block/blockwise_gemm_pipeline_xdlops_b_preshuffle_gufusion_dequant_v1.hpp @@ -0,0 +1,621 @@ +// SPDX-License-Identifier: MIT +// Copyright (c) 2018-2025, Advanced Micro Devices, Inc. All rights reserved. 
+ +#pragma once + +#include "ck/tensor_operation/gpu/block/blockwise_gemm_pipeline_xdlops_base.hpp" + +namespace ck { + +// Compute optimized pipeline +// GlobalPrefetchStages: 2 +// LocalPreFillStages: 1 +// LocalPreFetchStages: 1 +// LocalSharedMemoryBuffer: 1 + +template +struct BlockwiseGemmXdlops_pipeline_bpreshuffle_gufusion_bdequant_v1 +{ +}; + +template +struct BlockwiseGemmXdlops_pipeline_bpreshuffle_gufusion_bdequant_v1< + BlockGemmPipelineScheduler::Intrawave, + BlockSize, + ADataType, + BDataType, + ComputeDataType, + AccDataType, + ATileDesc, + BTileDesc, + AMmaTileDesc, + BMmaTileDesc, + ABlockTransferSrcScalarPerVector, + BBlockTransferSrcScalarPerVector, + MPerBlock, + NPerBlock, + KPerBlock, + MPerXDL, + NPerXDL, + MRepeat, + NRepeat, + KPack> : BlockwiseGemmXdlops_pipeline_base + +{ + using Base = BlockwiseGemmXdlops_pipeline_base; + using Base::A_K1; + using Base::B_K1; + using Base::I0; + using Base::I1; + using Base::KRepeat; + using Base::xdlops_gemm; + using typename Base::HotLoopInstList; + + using Base::a_block_desc_m0_m1_m2_k; + using Base::CalculateCThreadOriginDataIndex; + using Base::CalculateCThreadOriginDataIndex8D; + using Base::GetCBlockDescriptor_G_M0_N0_M1_N1_M2_M3_M4_N2; + using Base::GetCBlockDescriptor_M0_N0_M1_N1_M2_M3_M4_N2; + using Base::GetCBlockDescriptor_M0_N0_M1_N1_M2_N2_N3_N4; + using Base::GetCThreadBuffer; + using Base::GetCThreadDescriptor_G_M0_N0_M1_N1_M2_M3_M4_N2; + using Base::GetCThreadDescriptor_M0_N0_M1_N1_M2_M3_M4_N2; + using Base::GetCThreadDescriptor_M0_N0_M1_N1_M2_N2_N3_N4; + using Base::MakeCGridDescriptor_G_M0_N0_M1_N1_M2_M3_M4_N2; + using Base::MakeCGridDescriptor_M0_N0_M1_N1_M2_M3_M4_N2; + + using Base::AMmaKStride; + using Base::BMmaKStride; + using Base::c_thread_desc_; + + static constexpr index_t PrefetchStages = 2; + static constexpr index_t PrefillStages = 1; + static constexpr index_t GlobalBufferNum = 2; + + template + __host__ __device__ static constexpr auto MakeAGemmMmaTileDescriptor(const 
TileDesc_M0_M1_M2_K&) + { + constexpr index_t M0 = TileDesc_M0_M1_M2_K{}.GetLength(Number<0>{}); + constexpr index_t M1 = TileDesc_M0_M1_M2_K{}.GetLength(Number<1>{}); + constexpr index_t M2 = TileDesc_M0_M1_M2_K{}.GetLength(Number<2>{}); + constexpr index_t K2 = KPack; + constexpr index_t K1 = 64 / NPerXDL; + constexpr index_t K0 = KRepeat; + + return transform_tensor_descriptor( + TileDesc_M0_M1_M2_K{}, + make_tuple( + make_pass_through_transform(Number{}), + make_pass_through_transform(Number{}), + make_pass_through_transform(Number{}), + make_unmerge_transform(make_tuple(Number{}, Number{}, Number{}))), + make_tuple(Sequence<0>{}, Sequence<1>{}, Sequence<2>{}, Sequence<3>{}), + make_tuple(Sequence<0>{}, Sequence<1>{}, Sequence<2>{}, Sequence<3, 4, 5>{})); + } + + static constexpr auto a_block_desc_m0_m1_m2_k0_k1_k2 = + MakeAGemmMmaTileDescriptor(a_block_desc_m0_m1_m2_k); + + __host__ __device__ static constexpr bool BlockHasHotloop(index_t num_loop) + { + return num_loop > PrefetchStages; + } + + __host__ __device__ static constexpr TailNumber BlockLoopTailNum(index_t num_loop) + { + return num_loop % 2 == 0 ? 
TailNumber::Even : TailNumber::Odd; + } + + __device__ static constexpr auto HotLoopScheduler() + { + constexpr auto num_ds_read_inst_a = HotLoopInstList::A_LDS_Read_Inst_Num; + constexpr auto num_buffer_load_inst_a = HotLoopInstList::A_Buffer_Load_Inst_Num; + constexpr auto num_buffer_load_inst_b = HotLoopInstList::B_Buffer_Load_Inst_Num; + + // B global + static_for<0, num_buffer_load_inst_b, 1>{}([&](auto i) { + ignore = i; + __builtin_amdgcn_sched_group_barrier(0x008, 1, 0); // MFMA + __builtin_amdgcn_sched_group_barrier(0x020, 1, 0); // VMEM read + }); + + // A global + static_for<0, num_buffer_load_inst_a, 1>{}([&](auto i) { + ignore = i; + __builtin_amdgcn_sched_group_barrier(0x008, 1, 0); // MFMA + __builtin_amdgcn_sched_group_barrier(0x200, 1, 0); // DS write + __builtin_amdgcn_sched_group_barrier(0x008, 1, 0); // MFMA + __builtin_amdgcn_sched_group_barrier(0x020, 1, 0); // VMEM read + }); + + // A local + static_for<0, num_ds_read_inst_a / 2, 1>{}([&](auto i) { + ignore = i; + __builtin_amdgcn_sched_group_barrier(0x008, 1, 0); // MFMA + __builtin_amdgcn_sched_group_barrier(0x100, 2, 0); // DS read + }); + } + + template + __device__ void Run(const AGridDesc& a_grid_desc, + const ABlockDesc& a_block_desc, + ABlockTransfer& a_blockwise_copy, + const AGridBuffer& a_grid_buf, + ABlockBuffer& a_block_buf, + const ABlockTransferStep& a_block_copy_step, + const BGridDesc& b_grid_desc, + BBlockTransfer& b_blockwise_copy, + BBlockTransfer& b_blockwise_copy_up, + const BGridBuffer& b_grid_buf, + const BGridBuffer& b_grid_buf_up, + BBlockBuffer& b_block_buf, + const BBlockTransferStep& b_block_copy_step, + CThreadBuffer& c_thread_buf, + CThreadBuffer& c_thread_buf_up, + index_t num_loop) const + + { + ignore = b_block_buf; + __builtin_amdgcn_sched_barrier(0); + auto a_thread_buf = make_static_buffer( + a_thread_desc_.GetElementSpaceSize()); + auto b_thread_buf = make_static_buffer( + b_thread_desc_.GetElementSpaceSize()); + + auto b_thread_dequant_buf = 
make_static_buffer( + b_thread_desc_.GetElementSpaceSize()); + + StaticallyIndexedArray{}> b_thread_bufs; + StaticallyIndexedArray{}> b_thread_bufs_up; + constexpr auto b_block_origin_idx = make_tuple(I0, I0, I0, I0); + + StaticallyIndexedArray{}> b_thread_dequant_bufs; + StaticallyIndexedArray{}> + b_thread_dequant_bufs_up; + + // Global prefetch A1 B1 + a_blockwise_copy.RunRead(a_grid_desc, a_grid_buf, I0); + b_blockwise_copy.Run(b_grid_desc, + b_grid_buf, + b_block_desc_n0_n1_k0_k1, + b_block_origin_idx, + b_thread_bufs(I0)); + b_blockwise_copy_up.Run(b_grid_desc, + b_grid_buf_up, + b_block_desc_n0_n1_k0_k1, + b_block_origin_idx, + b_thread_bufs_up(I0)); + + a_blockwise_copy.MoveSrcSliceWindow(a_grid_desc, a_block_copy_step); + b_blockwise_copy.MoveSrcSliceWindow(b_grid_desc, b_block_copy_step); + b_blockwise_copy_up.MoveSrcSliceWindow(b_grid_desc, b_block_copy_step); + __builtin_amdgcn_sched_barrier(0); + + // // Local prefill A1 + a_blockwise_copy.RunWrite(a_block_desc, a_block_buf, I0); + + // // Global prefetch A2 + a_blockwise_copy.RunRead(a_grid_desc, a_grid_buf, I0); + a_blockwise_copy.MoveSrcSliceWindow(a_grid_desc, a_block_copy_step); + + // Local prefetch A1 + block_sync_lds(); + static_for<0, MRepeat, 1>{}([&](auto m0) { + static_for<0, KRepeat, 1>{}([&](auto k0) { + a_thread_copy_.Run(a_block_desc_m0_m1_m2_k0_k1_k2, + make_tuple(m0, I0, I0, k0, I0, I0), + a_block_buf, + a_thread_desc_, + make_tuple(m0, I0, I0, k0, I0, I0), + a_thread_buf); + }); + }); + // B VGPR->VGPR dequant + b_thread_dequant_copy_.Run(b_block_desc_n0_n1_k0_k1, + b_block_origin_idx, + b_thread_bufs(I0), + b_thread_desc_, + make_tuple(I0, I0, I0, I0), + b_thread_dequant_bufs(I0)); + b_thread_dequant_copy_.Run(b_block_desc_n0_n1_k0_k1, + b_block_origin_idx, + b_thread_bufs_up(I0), + b_thread_desc_, + make_tuple(I0, I0, I0, I0), + b_thread_dequant_bufs_up(I0)); + + // Initialize C + c_thread_buf.Clear(); + c_thread_buf_up.Clear(); + + __builtin_amdgcn_sched_barrier(0); + + // main 
body + if constexpr(HasMainLoop) + { + index_t i = 0; + do + { + auto LoopFunc = [&](auto mfma_reg_buf, auto local_read_buf) { + b_blockwise_copy.Run(b_grid_desc, + b_grid_buf, + b_block_desc_n0_n1_k0_k1, + b_block_origin_idx, + b_thread_bufs(local_read_buf)); + b_blockwise_copy.MoveSrcSliceWindow(b_grid_desc, b_block_copy_step); + b_blockwise_copy_up.Run(b_grid_desc, + b_grid_buf_up, + b_block_desc_n0_n1_k0_k1, + b_block_origin_idx, + b_thread_bufs_up(local_read_buf)); + b_blockwise_copy_up.MoveSrcSliceWindow(b_grid_desc, b_block_copy_step); + + block_sync_lds(); + a_blockwise_copy.RunWrite(a_block_desc, a_block_buf, mfma_reg_buf); + + a_blockwise_copy.RunRead(a_grid_desc, a_grid_buf, local_read_buf); + a_blockwise_copy.MoveSrcSliceWindow(a_grid_desc, a_block_copy_step); + + static_for<0, MRepeat, 1>{}([&](auto m0) { + static_for<0, NRepeat, 1>{}([&](auto n0) { + static_for<0, KRepeat, 1>{}([&](auto k0) { + vector_type a_thread_vec; + vector_type b_thread_vec; + vector_type b_thread_vec_up; + + static_for<0, KPack, 1>{}([&](auto ik) { + a_thread_vec.template AsType()(ik) = + a_thread_buf[Number{}]; + b_thread_vec.template AsType()(ik) = + b_thread_dequant_bufs[mfma_reg_buf] + [Number{}]; + b_thread_vec_up.template AsType()(ik) = + b_thread_dequant_bufs_up + [mfma_reg_buf][Number{}]; + }); + using mfma_input_type = + typename vector_type::type; + + constexpr index_t c_offset = + c_thread_desc_.CalculateOffset(make_tuple(m0, n0, 0)); + + xdlops_gemm.Run( + a_thread_vec.template AsType(), + b_thread_vec.template AsType(), + c_thread_buf.GetVectorTypeReference(Number{})); + + xdlops_gemm.Run( + a_thread_vec.template AsType(), + b_thread_vec_up.template AsType(), + c_thread_buf_up.GetVectorTypeReference(Number{})); + }); + }); + }); + + block_sync_lds(); + + static_for<0, MRepeat, 1>{}([&](auto m0) { + static_for<0, KRepeat, 1>{}([&](auto k0) { + a_thread_copy_.Run(a_block_desc_m0_m1_m2_k0_k1_k2, + make_tuple(m0, I0, I0, k0, I0, I0), + a_block_buf, + a_thread_desc_, + 
make_tuple(m0, I0, I0, k0, I0, I0), + a_thread_buf); + }); + }); + // B VGPR->VGPR dequant + b_thread_dequant_copy_.Run(b_block_desc_n0_n1_k0_k1, + b_block_origin_idx, + b_thread_bufs(local_read_buf), + b_thread_desc_, + make_tuple(I0, I0, I0, I0), + b_thread_dequant_bufs(local_read_buf)); + b_thread_dequant_copy_.Run(b_block_desc_n0_n1_k0_k1, + b_block_origin_idx, + b_thread_bufs_up(local_read_buf), + b_thread_desc_, + make_tuple(I0, I0, I0, I0), + b_thread_dequant_bufs_up(local_read_buf)); + + HotLoopScheduler(); + __builtin_amdgcn_sched_barrier(0); + }; + + LoopFunc(I0, I1); + LoopFunc(I1, I0); + + i += 2; + } while(i < (num_loop - 2)); + } + // tail + if constexpr(TailNum == TailNumber::Even) + { + b_blockwise_copy.Run(b_grid_desc, + b_grid_buf, + b_block_desc_n0_n1_k0_k1, + b_block_origin_idx, + b_thread_bufs(I1)); + + b_blockwise_copy_up.Run(b_grid_desc, + b_grid_buf_up, + b_block_desc_n0_n1_k0_k1, + b_block_origin_idx, + b_thread_bufs_up(I1)); + + block_sync_lds(); + a_blockwise_copy.RunWrite(a_block_desc, a_block_buf); + + static_for<0, MRepeat, 1>{}([&](auto m0) { + static_for<0, NRepeat, 1>{}([&](auto n0) { + static_for<0, KRepeat, 1>{}([&](auto k0) { + vector_type a_thread_vec; + vector_type b_thread_vec; + vector_type b_thread_vec_up; + + static_for<0, KPack, 1>{}([&](auto ik) { + a_thread_vec.template AsType()(ik) = + a_thread_buf[Number{}]; + b_thread_vec.template AsType()(ik) = + b_thread_dequant_bufs[I0][Number{}]; + b_thread_vec_up.template AsType()(ik) = + b_thread_dequant_bufs_up[I0][Number{}]; + }); + + using mfma_input_type = + typename vector_type::type; + + constexpr index_t c_offset = + c_thread_desc_.CalculateOffset(make_tuple(m0, n0, 0)); + + xdlops_gemm.Run(a_thread_vec.template AsType(), + b_thread_vec.template AsType(), + c_thread_buf.GetVectorTypeReference(Number{})); + xdlops_gemm.Run(a_thread_vec.template AsType(), + b_thread_vec_up.template AsType(), + c_thread_buf_up.GetVectorTypeReference(Number{})); + }); + }); + }); + + 
block_sync_lds(); + + static_for<0, MRepeat, 1>{}([&](auto m0) { + static_for<0, KRepeat, 1>{}([&](auto k0) { + a_thread_copy_.Run(a_block_desc_m0_m1_m2_k0_k1_k2, + make_tuple(m0, I0, I0, k0, I0, I0), + a_block_buf, + a_thread_desc_, + make_tuple(m0, I0, I0, k0, I0, I0), + a_thread_buf); + }); + }); + // B VGPR->VGPR dequant + b_thread_dequant_copy_.Run(b_block_desc_n0_n1_k0_k1, + b_block_origin_idx, + b_thread_bufs(I1), + b_thread_desc_, + make_tuple(I0, I0, I0, I0), + b_thread_dequant_bufs(I1)); + + b_thread_dequant_copy_.Run(b_block_desc_n0_n1_k0_k1, + b_block_origin_idx, + b_thread_bufs_up(I1), + b_thread_desc_, + make_tuple(I0, I0, I0, I0), + b_thread_dequant_bufs_up(I1)); + __builtin_amdgcn_sched_barrier(0); + + static_for<0, MRepeat, 1>{}([&](auto m0) { + static_for<0, NRepeat, 1>{}([&](auto n0) { + static_for<0, KRepeat, 1>{}([&](auto k0) { + vector_type a_thread_vec; + vector_type b_thread_vec; + vector_type b_thread_vec_up; + + static_for<0, KPack, 1>{}([&](auto ik) { + a_thread_vec.template AsType()(ik) = + a_thread_buf[Number{}]; + b_thread_vec.template AsType()(ik) = + b_thread_dequant_bufs[I1][Number{}]; + b_thread_vec_up.template AsType()(ik) = + b_thread_dequant_bufs_up[I1][Number{}]; + }); + + using mfma_input_type = + typename vector_type::type; + + constexpr index_t c_offset = + c_thread_desc_.CalculateOffset(make_tuple(m0, n0, 0)); + + xdlops_gemm.Run(a_thread_vec.template AsType(), + b_thread_vec.template AsType(), + c_thread_buf.GetVectorTypeReference(Number{})); + xdlops_gemm.Run(a_thread_vec.template AsType(), + b_thread_vec_up.template AsType(), + c_thread_buf_up.GetVectorTypeReference(Number{})); + }); + }); + }); + // Let's leak last MFMA block to epilogue region, cover the potential lds-shuffle + // latency + // __builtin_amdgcn_sched_barrier(0); + } + else + { + static_for<0, MRepeat, 1>{}([&](auto m0) { + static_for<0, NRepeat, 1>{}([&](auto n0) { + static_for<0, KRepeat, 1>{}([&](auto k0) { + vector_type a_thread_vec; + vector_type 
b_thread_vec; + vector_type b_thread_vec_up; + + static_for<0, KPack, 1>{}([&](auto ik) { + a_thread_vec.template AsType()(ik) = + a_thread_buf[Number{}]; + b_thread_vec.template AsType()(ik) = + b_thread_dequant_bufs[I0][Number{}]; + b_thread_vec_up.template AsType()(ik) = + b_thread_dequant_bufs_up[I0][Number{}]; + }); + + using mfma_input_type = + typename vector_type::type; + + constexpr index_t c_offset = + c_thread_desc_.CalculateOffset(make_tuple(m0, n0, 0)); + + xdlops_gemm.Run(a_thread_vec.template AsType(), + b_thread_vec.template AsType(), + c_thread_buf.GetVectorTypeReference(Number{})); + xdlops_gemm.Run(a_thread_vec.template AsType(), + b_thread_vec_up.template AsType(), + c_thread_buf_up.GetVectorTypeReference(Number{})); + }); + }); + }); + } + } + + protected: + // MRepeat MWave MLane KRepeat KLane KPack + // KRepeat -> MRepeat-> Mwave->KLane->MLane->KPack + static constexpr auto a_thread_desc_ = make_naive_tensor_descriptor_packed( + make_tuple(Number{}, I1, I1, Number{}, I1, Number{})); + + using AThreadCopy = ThreadwiseTensorSliceTransfer_v4, + Sequence<0, 1, 2, 3, 4, 5>, + 5, + A_K1, + A_K1>; + + AThreadCopy a_thread_copy_{Base::CalculateAThreadOriginDataIndex6D()}; + + static constexpr auto b_thread_desc_ = make_naive_tensor_descriptor_packed( + make_tuple(Number{}, I1, Number{}, Number{})); + + static constexpr BTileDesc b_block_desc_n0_n1_k0_k1; + + using PassThrough = ck::tensor_operation::element_wise::PassThrough; + + using BThreadDequantCopy = ThreadwiseTensorSliceTransfer_StaticToStatic< + BDataType, + ComputeDataType, + decltype(b_block_desc_n0_n1_k0_k1), + decltype(b_block_desc_n0_n1_k0_k1), + tensor_operation::element_wise::PassThrough, + Sequence{}, I1, Number{}, Number{}>, + Sequence<1, 2, 0, 3>, + 3, + KPack>; + + const PassThrough b_element_op{}; + BThreadDequantCopy b_thread_dequant_copy_{b_element_op}; +}; + +} // namespace ck diff --git 
a/include/ck/tensor_operation/gpu/block/blockwise_gemm_pipeline_xdlops_b_preshuffle_gufusion_v1.hpp b/include/ck/tensor_operation/gpu/block/blockwise_gemm_pipeline_xdlops_b_preshuffle_gufusion_v1.hpp new file mode 100644 index 0000000000..73749c6309 --- /dev/null +++ b/include/ck/tensor_operation/gpu/block/blockwise_gemm_pipeline_xdlops_b_preshuffle_gufusion_v1.hpp @@ -0,0 +1,573 @@ +// SPDX-License-Identifier: MIT +// Copyright (c) 2018-2025, Advanced Micro Devices, Inc. All rights reserved. + +#pragma once + +#include "ck/tensor_operation/gpu/block/blockwise_gemm_pipeline_xdlops_base.hpp" + +namespace ck { + +// Compute optimized pipeline +// GlobalPrefetchStages: 2 +// LocalPreFillStages: 1 +// LocalPreFetchStages: 1 +// LocalSharedMemoryBuffer: 1 + +template +struct BlockwiseGemmXdlops_pipeline_bpreshuffle_gufusion_v1 +{ +}; + +template +struct BlockwiseGemmXdlops_pipeline_bpreshuffle_gufusion_v1 + : BlockwiseGemmXdlops_pipeline_base + +{ + using Base = BlockwiseGemmXdlops_pipeline_base; + using Base::A_K1; + using Base::B_K1; + using Base::I0; + using Base::I1; + using Base::KRepeat; + using Base::xdlops_gemm; + using typename Base::HotLoopInstList; + + using Base::a_block_desc_m0_m1_m2_k; + using Base::CalculateCThreadOriginDataIndex; + using Base::CalculateCThreadOriginDataIndex8D; + using Base::GetCBlockDescriptor_G_M0_N0_M1_N1_M2_M3_M4_N2; + using Base::GetCBlockDescriptor_M0_N0_M1_N1_M2_M3_M4_N2; + using Base::GetCBlockDescriptor_M0_N0_M1_N1_M2_N2_N3_N4; + using Base::GetCThreadBuffer; + using Base::GetCThreadDescriptor_G_M0_N0_M1_N1_M2_M3_M4_N2; + using Base::GetCThreadDescriptor_M0_N0_M1_N1_M2_M3_M4_N2; + using Base::GetCThreadDescriptor_M0_N0_M1_N1_M2_N2_N3_N4; + using Base::MakeCGridDescriptor_G_M0_N0_M1_N1_M2_M3_M4_N2; + using Base::MakeCGridDescriptor_M0_N0_M1_N1_M2_M3_M4_N2; + + using Base::AMmaKStride; + using Base::BMmaKStride; + using Base::c_thread_desc_; + using Base::MWaves; + + static constexpr index_t PrefetchStages = 2; + static constexpr 
index_t PrefillStages = 1; + static constexpr index_t GlobalBufferNum = 2; + + template + __host__ __device__ static constexpr auto MakeAGemmMmaTileDescriptor(const TileDesc_M0_M1_M2_K&) + { + constexpr index_t M0 = TileDesc_M0_M1_M2_K{}.GetLength(Number<0>{}); + constexpr index_t M1 = TileDesc_M0_M1_M2_K{}.GetLength(Number<1>{}); + constexpr index_t M2 = TileDesc_M0_M1_M2_K{}.GetLength(Number<2>{}); + constexpr index_t K2 = KPack; + constexpr index_t K1 = 64 / NPerXDL; + constexpr index_t K0 = KRepeat; + + return transform_tensor_descriptor( + TileDesc_M0_M1_M2_K{}, + make_tuple( + make_pass_through_transform(Number{}), + make_pass_through_transform(Number{}), + make_pass_through_transform(Number{}), + make_unmerge_transform(make_tuple(Number{}, Number{}, Number{}))), + make_tuple(Sequence<0>{}, Sequence<1>{}, Sequence<2>{}, Sequence<3>{}), + make_tuple(Sequence<0>{}, Sequence<1>{}, Sequence<2>{}, Sequence<3, 4, 5>{})); + } + + static constexpr auto a_block_desc_m0_m1_m2_k0_k1_k2 = + MakeAGemmMmaTileDescriptor(a_block_desc_m0_m1_m2_k); + + __host__ __device__ static constexpr bool BlockHasHotloop(index_t num_loop) + { + return num_loop > PrefetchStages; + } + + __host__ __device__ static constexpr TailNumber BlockLoopTailNum(index_t num_loop) + { + return num_loop % 2 == 0 ? TailNumber::Even : TailNumber::Odd; + } + + __device__ static constexpr auto HotLoopScheduler() + { + constexpr auto num_ds_read_inst_a = HotLoopInstList::A_LDS_Read_Inst_Num; + constexpr auto num_buffer_load_inst_a = HotLoopInstList::A_Buffer_Load_Inst_Num; + constexpr auto num_buffer_load_inst_b = + HotLoopInstList::B_Buffer_Load_Inst_Num * MWaves * 2; + constexpr auto mfma_interleave = MPerXDL == 32 ? 
1 : 2; + // B global + static_for<0, num_buffer_load_inst_b, 1>{}([&](auto i) { + ignore = i; + if constexpr(MPerBlock >= 128 && NPerBlock >= 64) + { + __builtin_amdgcn_sched_group_barrier(0x008, 2 * mfma_interleave, 0); + } + else + { + __builtin_amdgcn_sched_group_barrier(0x008, mfma_interleave, 0); + } + __builtin_amdgcn_sched_group_barrier(0x020, 1, 0); // VMEM read + // if constexpr(i.value < num_buffer_load_inst_a) { + // __builtin_amdgcn_sched_group_barrier(0x008, 1, 0); // MFMA + // __builtin_amdgcn_sched_group_barrier(0x200, 1, 0); // DS write + // __builtin_amdgcn_sched_group_barrier(0x008, 1, 0); // MFMA + // __builtin_amdgcn_sched_group_barrier(0x020, 1, 0); // VMEM read + // } + }); + + // A global + static_for<0, num_buffer_load_inst_a, 1>{}([&](auto i) { + ignore = i; + __builtin_amdgcn_sched_group_barrier(0x008, 1, 0); // MFMA + __builtin_amdgcn_sched_group_barrier(0x200, 1, 0); // DS write + __builtin_amdgcn_sched_group_barrier(0x008, 1, 0); // MFMA + __builtin_amdgcn_sched_group_barrier(0x020, 1, 0); // VMEM read + }); + + // A local + static_for<0, MPerXDL == 32 ? num_ds_read_inst_a / 2 : num_ds_read_inst_a, 1>{}( + [&](auto i) { + ignore = i; + __builtin_amdgcn_sched_group_barrier(0x008, 1, 0); // MFMA + __builtin_amdgcn_sched_group_barrier(0x100, MPerXDL == 32 ? 
2 : 1, 0); // DS read + }); + } + + template + __device__ void Run(const AGridDesc& a_grid_desc, + const ABlockDesc& a_block_desc, + ABlockTransfer& a_blockwise_copy, + const AGridBuffer& a_grid_buf, + ABlockBuffer& a_block_buf, + const ABlockTransferStep& a_block_copy_step, + const BGridDesc& b_grid_desc, + BBlockTransfer& b_blockwise_copy, + BBlockTransfer& b_blockwise_copy_up, + const BGridBuffer& b_grid_buf, + const BGridBuffer& b_grid_buf_up, + BBlockBuffer& b_block_buf, + const BBlockTransferStep& b_block_copy_step, + CThreadBuffer& c_thread_buf, + CThreadBuffer& c_thread_buf_up, + index_t num_loop) const + { + ignore = b_block_buf; + __builtin_amdgcn_sched_barrier(0); + auto a_thread_buf = make_static_buffer( + a_thread_desc_.GetElementSpaceSize()); + auto b_thread_buf = make_static_buffer( + b_thread_desc_.GetElementSpaceSize()); + + StaticallyIndexedArray{}> b_thread_bufs; + StaticallyIndexedArray{}> b_thread_bufs_up; + constexpr auto b_block_origin_idx = make_tuple(I0, I0, I0, I0); + + // Global prefetch A1 B1 + a_blockwise_copy.RunRead(a_grid_desc, a_grid_buf, I0); + b_blockwise_copy.Run(b_grid_desc, + b_grid_buf, + b_block_desc_n0_n1_k0_k1, + b_block_origin_idx, + b_thread_bufs(I0)); + b_blockwise_copy_up.Run(b_grid_desc, + b_grid_buf_up, + b_block_desc_n0_n1_k0_k1, + b_block_origin_idx, + b_thread_bufs_up(I0)); + + a_blockwise_copy.MoveSrcSliceWindow(a_grid_desc, a_block_copy_step); + b_blockwise_copy.MoveSrcSliceWindow(b_grid_desc, b_block_copy_step); + b_blockwise_copy_up.MoveSrcSliceWindow(b_grid_desc, b_block_copy_step); + __builtin_amdgcn_sched_barrier(0); + + // // Local prefill A1 + a_blockwise_copy.RunWrite(a_block_desc, a_block_buf, I0); + + // // Global prefetch A2 + a_blockwise_copy.RunRead(a_grid_desc, a_grid_buf, I0); + a_blockwise_copy.MoveSrcSliceWindow(a_grid_desc, a_block_copy_step); + + // Local prefetch A1 + block_sync_lds(); + static_for<0, MRepeat, 1>{}([&](auto m0) { + static_for<0, KRepeat, 1>{}([&](auto k0) { + 
a_thread_copy_.Run(a_block_desc_m0_m1_m2_k0_k1_k2, + make_tuple(m0, I0, I0, k0, I0, I0), + a_block_buf, + a_thread_desc_, + make_tuple(m0, I0, I0, k0, I0, I0), + a_thread_buf); + }); + }); + + // Initialize C + c_thread_buf.Clear(); + c_thread_buf_up.Clear(); + + __builtin_amdgcn_sched_barrier(0); + + // main body + if constexpr(HasMainLoop) + { + index_t i = 0; + do + { + auto LoopFunc = [&](auto mfma_reg_buf, auto local_read_buf) { + b_blockwise_copy.Run(b_grid_desc, + b_grid_buf, + b_block_desc_n0_n1_k0_k1, + b_block_origin_idx, + b_thread_bufs(local_read_buf)); + b_blockwise_copy.MoveSrcSliceWindow(b_grid_desc, b_block_copy_step); + + b_blockwise_copy_up.Run(b_grid_desc, + b_grid_buf_up, + b_block_desc_n0_n1_k0_k1, + b_block_origin_idx, + b_thread_bufs_up(local_read_buf)); + b_blockwise_copy_up.MoveSrcSliceWindow(b_grid_desc, b_block_copy_step); + block_sync_lds(); + a_blockwise_copy.RunWrite(a_block_desc, a_block_buf, mfma_reg_buf); + + a_blockwise_copy.RunRead(a_grid_desc, a_grid_buf, local_read_buf); + a_blockwise_copy.MoveSrcSliceWindow(a_grid_desc, a_block_copy_step); + static_for<0, MRepeat, 1>{}([&](auto m0) { + static_for<0, NRepeat, 1>{}([&](auto n0) { + static_for<0, KRepeat, 1>{}([&](auto k0) { + vector_type a_thread_vec; + vector_type b_thread_vec; + vector_type b_thread_vec_up; + + static_for<0, KPack, 1>{}([&](auto ik) { + a_thread_vec.template AsType()(ik) = + a_thread_buf[Number{}]; + b_thread_vec.template AsType()(ik) = + b_thread_bufs[mfma_reg_buf] + [Number{}]; + b_thread_vec_up.template AsType()(ik) = + b_thread_bufs_up[mfma_reg_buf] + [Number{}]; + }); + using mfma_input_type = + typename vector_type::type; + + constexpr index_t c_offset = + c_thread_desc_.CalculateOffset(make_tuple(m0, n0, 0)); + + xdlops_gemm.Run( + a_thread_vec.template AsType(), + b_thread_vec.template AsType(), + c_thread_buf.GetVectorTypeReference(Number{})); + + xdlops_gemm.Run( + a_thread_vec.template AsType(), + b_thread_vec_up.template AsType(), + 
c_thread_buf_up.GetVectorTypeReference(Number{})); + }); + }); + }); + + block_sync_lds(); + + static_for<0, MRepeat, 1>{}([&](auto m0) { + static_for<0, KRepeat, 1>{}([&](auto k0) { + a_thread_copy_.Run(a_block_desc_m0_m1_m2_k0_k1_k2, + make_tuple(m0, I0, I0, k0, I0, I0), + a_block_buf, + a_thread_desc_, + make_tuple(m0, I0, I0, k0, I0, I0), + a_thread_buf); + }); + }); + + HotLoopScheduler(); + __builtin_amdgcn_sched_barrier(0); + }; + + LoopFunc(I0, I1); + LoopFunc(I1, I0); + + i += 2; + } while(i < (num_loop - 2)); + } + // tail + if constexpr(TailNum == TailNumber::Even) + { + b_blockwise_copy.Run(b_grid_desc, + b_grid_buf, + b_block_desc_n0_n1_k0_k1, + b_block_origin_idx, + b_thread_bufs(I1)); + + b_blockwise_copy_up.Run(b_grid_desc, + b_grid_buf_up, + b_block_desc_n0_n1_k0_k1, + b_block_origin_idx, + b_thread_bufs_up(I1)); + block_sync_lds(); + a_blockwise_copy.RunWrite(a_block_desc, a_block_buf); + + static_for<0, MRepeat, 1>{}([&](auto m0) { + static_for<0, NRepeat, 1>{}([&](auto n0) { + static_for<0, KRepeat, 1>{}([&](auto k0) { + vector_type a_thread_vec; + vector_type b_thread_vec; + vector_type b_thread_vec_up; + + static_for<0, KPack, 1>{}([&](auto ik) { + a_thread_vec.template AsType()(ik) = + a_thread_buf[Number{}]; + b_thread_vec.template AsType()(ik) = + b_thread_bufs[I0][Number{}]; + b_thread_vec_up.template AsType()(ik) = + b_thread_bufs_up[I0][Number{}]; + }); + + using mfma_input_type = + typename vector_type::type; + + constexpr index_t c_offset = + c_thread_desc_.CalculateOffset(make_tuple(m0, n0, 0)); + + xdlops_gemm.Run(a_thread_vec.template AsType(), + b_thread_vec.template AsType(), + c_thread_buf.GetVectorTypeReference(Number{})); + + xdlops_gemm.Run(a_thread_vec.template AsType(), + b_thread_vec_up.template AsType(), + c_thread_buf_up.GetVectorTypeReference(Number{})); + }); + }); + }); + + block_sync_lds(); + + static_for<0, MRepeat, 1>{}([&](auto m0) { + static_for<0, KRepeat, 1>{}([&](auto k0) { + 
a_thread_copy_.Run(a_block_desc_m0_m1_m2_k0_k1_k2, + make_tuple(m0, I0, I0, k0, I0, I0), + a_block_buf, + a_thread_desc_, + make_tuple(m0, I0, I0, k0, I0, I0), + a_thread_buf); + }); + }); + + __builtin_amdgcn_sched_barrier(0); + + static_for<0, MRepeat, 1>{}([&](auto m0) { + static_for<0, NRepeat, 1>{}([&](auto n0) { + static_for<0, KRepeat, 1>{}([&](auto k0) { + vector_type a_thread_vec; + vector_type b_thread_vec; + vector_type b_thread_vec_up; + + static_for<0, KPack, 1>{}([&](auto ik) { + a_thread_vec.template AsType()(ik) = + a_thread_buf[Number{}]; + b_thread_vec.template AsType()(ik) = + b_thread_bufs[I1][Number{}]; + b_thread_vec_up.template AsType()(ik) = + b_thread_bufs_up[I1][Number{}]; + }); + + using mfma_input_type = + typename vector_type::type; + + constexpr index_t c_offset = + c_thread_desc_.CalculateOffset(make_tuple(m0, n0, 0)); + + xdlops_gemm.Run(a_thread_vec.template AsType(), + b_thread_vec.template AsType(), + c_thread_buf.GetVectorTypeReference(Number{})); + xdlops_gemm.Run(a_thread_vec.template AsType(), + b_thread_vec_up.template AsType(), + c_thread_buf_up.GetVectorTypeReference(Number{})); + }); + }); + }); + // Let's leak last MFMA block to epilogue region, cover the potential lds-shuffle + // latency + // __builtin_amdgcn_sched_barrier(0); + } + else if constexpr(TailNum == TailNumber::Odd) + { + static_for<0, MRepeat, 1>{}([&](auto m0) { + static_for<0, NRepeat, 1>{}([&](auto n0) { + static_for<0, KRepeat, 1>{}([&](auto k0) { + vector_type a_thread_vec; + vector_type b_thread_vec; + vector_type b_thread_vec_up; + + static_for<0, KPack, 1>{}([&](auto ik) { + a_thread_vec.template AsType()(ik) = + a_thread_buf[Number{}]; + b_thread_vec.template AsType()(ik) = + b_thread_bufs[I0][Number{}]; + b_thread_vec_up.template AsType()(ik) = + b_thread_bufs_up[I0][Number{}]; + }); + + using mfma_input_type = + typename vector_type::type; + + constexpr index_t c_offset = + c_thread_desc_.CalculateOffset(make_tuple(m0, n0, 0)); + + 
xdlops_gemm.Run(a_thread_vec.template AsType(), + b_thread_vec.template AsType(), + c_thread_buf.GetVectorTypeReference(Number{})); + xdlops_gemm.Run(a_thread_vec.template AsType(), + b_thread_vec_up.template AsType(), + c_thread_buf_up.GetVectorTypeReference(Number{})); + }); + }); + }); + } + } + + protected: + // MRepeat MWave MLane KRepeat KLane KPack + // KRepeat -> MRepeat-> Mwave->KLane->MLane->KPack + static constexpr auto a_thread_desc_ = make_naive_tensor_descriptor_packed( + make_tuple(Number{}, I1, I1, Number{}, I1, Number{})); + + using AThreadCopy = ThreadwiseTensorSliceTransfer_v4, + Sequence<0, 1, 2, 3, 4, 5>, + 5, + A_K1, + A_K1>; + + AThreadCopy a_thread_copy_{Base::CalculateAThreadOriginDataIndex6D()}; + + static constexpr auto b_thread_desc_ = make_naive_tensor_descriptor_packed( + make_tuple(Number{}, I1, Number{}, Number{})); + + static constexpr BTileDesc b_block_desc_n0_n1_k0_k1; +}; + +} // namespace ck diff --git a/include/ck/tensor_operation/gpu/block/blockwise_gemm_pipeline_xdlops_b_preshuffle_selector.hpp b/include/ck/tensor_operation/gpu/block/blockwise_gemm_pipeline_xdlops_b_preshuffle_selector.hpp index a94ef03008..074b5873ee 100644 --- a/include/ck/tensor_operation/gpu/block/blockwise_gemm_pipeline_xdlops_b_preshuffle_selector.hpp +++ b/include/ck/tensor_operation/gpu/block/blockwise_gemm_pipeline_xdlops_b_preshuffle_selector.hpp @@ -3,8 +3,10 @@ #pragma once +#include "ck/tensor_operation/gpu/block/blockwise_gemm_pipeline_xdlops_b_preshuffle_gufusion_v1.hpp" #include "ck/tensor_operation/gpu/block/blockwise_gemm_pipeline_xdlops_b_preshuffle_v1.hpp" #include "ck/tensor_operation/gpu/block/blockwise_gemm_pipeline_xdlops_b_preshuffle_dequant_v1.hpp" +#include "ck/tensor_operation/gpu/block/blockwise_gemm_pipeline_xdlops_b_preshuffle_gufusion_dequant_v1.hpp" #include "ck/tensor_operation/gpu/block/blockwise_gemm_pipeline_xdlops_b_preshuffle_v2.hpp" #include 
"ck/tensor_operation/gpu/block/blockwise_gemm_pipeline_xdlops_b_preshuffle_v3.hpp" #include "ck/tensor_operation/gpu/block/blockwise_gemm_pipeline_xdlops_b_preshuffle_dequant_v3.hpp" @@ -33,57 +35,112 @@ template + index_t KPack, + bool GUFusion = false> constexpr auto BlockGemmBPreshufflePipeline_Selector() { if constexpr(BlkGemmPipelineVer == BlockGemmPipelineVersion::v1) { if constexpr(std::is_same::value) { - return BlockwiseGemmXdlops_pipeline_bpreshuffle_v1{}; + if constexpr(GUFusion) + { + return BlockwiseGemmXdlops_pipeline_bpreshuffle_gufusion_v1< + BlkGemmPipeSche, + BlockSize, + ADataType, + BDataType, + ComputeDataType, + AccDataType, + ATileDesc, + BTileDesc, + AMmaTileDesc, + BMmaTileDesc, + ABlockTransferSrcScalarPerVector, + BBlockTransferSrcScalarPerVector, + MPerBlock, + NPerBlock, + KPerBlock, + MPerXDL, + NPerXDL, + MRepeat, + NRepeat, + KPack>{}; + } + else + { + return BlockwiseGemmXdlops_pipeline_bpreshuffle_v1{}; + } } else { - return BlockwiseGemmXdlops_pipeline_bpreshuffle_bdequant_v1< - BlkGemmPipeSche, - BlockSize, - ADataType, - BDataType, - ComputeDataType, - AccDataType, - ATileDesc, - BTileDesc, - AMmaTileDesc, - BMmaTileDesc, - ABlockTransferSrcScalarPerVector, - BBlockTransferSrcScalarPerVector, - MPerBlock, - NPerBlock, - KPerBlock, - MPerXDL, - NPerXDL, - MRepeat, - NRepeat, - KPack>{}; + if constexpr(GUFusion) + { + return BlockwiseGemmXdlops_pipeline_bpreshuffle_gufusion_bdequant_v1< + BlkGemmPipeSche, + BlockSize, + ADataType, + BDataType, + ComputeDataType, + AccDataType, + ATileDesc, + BTileDesc, + AMmaTileDesc, + BMmaTileDesc, + ABlockTransferSrcScalarPerVector, + BBlockTransferSrcScalarPerVector, + MPerBlock, + NPerBlock, + KPerBlock, + MPerXDL, + NPerXDL, + MRepeat, + NRepeat, + KPack>{}; + } + else + { + return BlockwiseGemmXdlops_pipeline_bpreshuffle_bdequant_v1< + BlkGemmPipeSche, + BlockSize, + ADataType, + BDataType, + ComputeDataType, + AccDataType, + ATileDesc, + BTileDesc, + AMmaTileDesc, + BMmaTileDesc, + 
ABlockTransferSrcScalarPerVector, + BBlockTransferSrcScalarPerVector, + MPerBlock, + NPerBlock, + KPerBlock, + MPerXDL, + NPerXDL, + MRepeat, + NRepeat, + KPack>{}; + } } } else if constexpr(BlkGemmPipelineVer == BlockGemmPipelineVersion::v2) diff --git a/include/ck/tensor_operation/gpu/block/blockwise_gemm_pipeline_xdlops_base.hpp b/include/ck/tensor_operation/gpu/block/blockwise_gemm_pipeline_xdlops_base.hpp index d7ba2559ea..ce507ca8d3 100644 --- a/include/ck/tensor_operation/gpu/block/blockwise_gemm_pipeline_xdlops_base.hpp +++ b/include/ck/tensor_operation/gpu/block/blockwise_gemm_pipeline_xdlops_base.hpp @@ -46,7 +46,8 @@ struct BlockwiseGemmXdlops_pipeline_base static constexpr index_t A_K0 = ATileDesc{}.GetLength(I0); static constexpr index_t B_K0 = BTileDesc{}.GetLength(I0); static constexpr index_t A_K1 = ATileDesc{}.GetLength(I2); - static constexpr index_t B_K1 = BTileDesc{}.GetLength(I2); + static constexpr index_t B_K1 = + BTileDesc{}.GetLength(Number < BTileDesc{}.GetNumOfDimension() == 4 ? 
3 : 2 > {}); static constexpr auto xdlops_gemm = XdlopsGemm{}; @@ -333,7 +334,7 @@ struct BlockwiseGemmXdlops_pipeline_base return xdlops_gemm.MakeCDescriptor_G_M0_N0_M1_N1_M2_M3_M4_N2( c_grid_desc_g_m0_n0_m1_n1_m2_n2); } - + __host__ __device__ static constexpr auto GetCThreadDesc() { return c_thread_desc_; } static constexpr AMmaTileDesc a_block_desc_m0_m1_m2_k; static constexpr BMmaTileDesc b_block_desc_n0_n1_n2_k; diff --git a/include/ck/tensor_operation/gpu/block/thread_group_tensor_slice_transfer_v4r1_gather.hpp b/include/ck/tensor_operation/gpu/block/thread_group_tensor_slice_transfer_v4r1_gather.hpp index 859649185a..92aef65388 100644 --- a/include/ck/tensor_operation/gpu/block/thread_group_tensor_slice_transfer_v4r1_gather.hpp +++ b/include/ck/tensor_operation/gpu/block/thread_group_tensor_slice_transfer_v4r1_gather.hpp @@ -41,6 +41,7 @@ template struct ThreadGroupTensorSliceTransfer_v4r1_gather @@ -58,7 +59,7 @@ struct ThreadGroupTensorSliceTransfer_v4r1_gather const DstDesc& dst_desc, const Index& dst_block_slice_origin, const DstElementwiseOperation& dst_element_op, - const StaticallyIndexedArray& gather_offsets) + const StaticallyIndexedArray& gather_offsets) : threadwise_transfer_(src_desc, make_zero_multi_index(), src_element_op, @@ -190,6 +191,7 @@ struct ThreadGroupTensorSliceTransfer_v4r1_gather DstScalarStrideInVector, ThreadTransferSrcResetCoordinateAfterRun, ThreadTransferDstResetCoordinateAfterRun, + IndexType, GatherDim, NumThreadScratch>; diff --git a/include/ck/tensor_operation/gpu/block/thread_group_tensor_slice_transfer_v7r3_scatter.hpp b/include/ck/tensor_operation/gpu/block/thread_group_tensor_slice_transfer_v7r3_scatter.hpp index cf758e4d5f..bee0b01a74 100644 --- a/include/ck/tensor_operation/gpu/block/thread_group_tensor_slice_transfer_v7r3_scatter.hpp +++ b/include/ck/tensor_operation/gpu/block/thread_group_tensor_slice_transfer_v7r3_scatter.hpp @@ -1,5 +1,5 @@ // SPDX-License-Identifier: MIT -// Copyright (c) 2018-2023, Advanced 
Micro Devices, Inc. All rights reserved. +// Copyright (c) 2018-2025, Advanced Micro Devices, Inc. All rights reserved. #pragma once @@ -42,6 +42,7 @@ template __device__ void RunRead(const SrcDescs& src_descs, const SrcBuffers& src_bufs, - StaticallyIndexedArray& scatter_weights, Number thread_scratch_id = Number{}) { if(ThreadGroup::GetNumOfThread() == thread_cluster_desc_.GetElementSize() or ThreadGroup::GetThreadId() < thread_cluster_desc_.GetElementSize()) { - threadwise_transfer_.RunRead(src_descs, src_bufs, scatter_weights, thread_scratch_id); + threadwise_transfer_.RunRead(src_descs, src_bufs, thread_scratch_id); } } @@ -149,7 +149,7 @@ struct ThreadGroupTensorSliceTransfer_v7r3_scatter template __device__ void RunWrite(const DstDescs& dst_descs, DstBuffers dst_bufs, - StaticallyIndexedArray& scatter_offsets, + StaticallyIndexedArray& scatter_offsets, Number thread_scratch_id = Number{}) { if(ThreadGroup::GetNumOfThread() == thread_cluster_desc_.GetElementSize() or @@ -169,10 +169,9 @@ struct ThreadGroupTensorSliceTransfer_v7r3_scatter const SrcBuffers& src_bufs, const DstDescs& dst_descs, DstBuffers dst_bufs, - StaticallyIndexedArray& scatter_offsets, - StaticallyIndexedArray& scatter_weights) + StaticallyIndexedArray& scatter_offsets) { - RunRead(src_descs, src_bufs, scatter_weights); + RunRead(src_descs, src_bufs); RunWrite(dst_descs, dst_bufs, scatter_offsets); } @@ -230,6 +229,7 @@ struct ThreadGroupTensorSliceTransfer_v7r3_scatter DstScalarPerVector, ThreadTransferSrcResetCoordinateAfterRunFlags, ThreadTransferDstResetCoordinateAfterRunFlags, + IndexType, ScatterDim, OutputScatter, ScatterWeightIdx, diff --git a/include/ck/tensor_operation/gpu/device/impl/device_moe_gemm.hpp b/include/ck/tensor_operation/gpu/device/impl/device_moe_gemm.hpp index 03db4bdd41..08d177035e 100644 --- a/include/ck/tensor_operation/gpu/device/impl/device_moe_gemm.hpp +++ b/include/ck/tensor_operation/gpu/device/impl/device_moe_gemm.hpp @@ -65,9 +65,12 @@ template ; 
RunKernel(kernel); } @@ -281,8 +287,6 @@ struct DeviceMoeGemm : public DeviceGemmMultipleDSplitKBPreShuffle; RunKernel(kernel); } @@ -297,8 +301,6 @@ struct DeviceMoeGemm : public DeviceGemmMultipleDSplitKBPreShuffle; RunKernel(kernel); } @@ -308,8 +310,6 @@ struct DeviceMoeGemm : public DeviceGemmMultipleDSplitKBPreShuffle; RunKernel(kernel); } @@ -329,8 +329,6 @@ struct DeviceMoeGemm : public DeviceGemmMultipleDSplitKBPreShuffle; RunKernel(kernel); } diff --git a/include/ck/tensor_operation/gpu/grid/gridwise_moe_gemm.hpp b/include/ck/tensor_operation/gpu/grid/gridwise_moe_gemm.hpp index a2d1114bbe..255fb8cff4 100644 --- a/include/ck/tensor_operation/gpu/grid/gridwise_moe_gemm.hpp +++ b/include/ck/tensor_operation/gpu/grid/gridwise_moe_gemm.hpp @@ -12,7 +12,7 @@ #include "ck/tensor_operation/gpu/block/thread_group_tensor_slice_transfer_v4r1_gather.hpp" #include "ck/tensor_operation/gpu/block/thread_group_tensor_slice_transfer_v6r1.hpp" #include "ck/tensor_operation/gpu/thread/threadwise_tensor_slice_transfer.hpp" -#include "ck/tensor_operation/gpu/element/element_wise_operation.hpp" +#include "ck/tensor_operation/gpu/element/unary_element_wise_operation.hpp" #include "ck/tensor_operation/gpu/block/thread_group_tensor_slice_transfer_v7r3_scatter.hpp" @@ -26,12 +26,17 @@ namespace ck { // two lds chunks. // 2. 
Occupied __shared__ won't release until whole shader end, a.k.a AB and C may not use same lds // buffer when we declare __shared__ inside blkgemmpipe + +enum Activation +{ + gelu_and_mul = 0, + silu_and_mul = 1 +}; + template __global__ void #if CK_USE_LAUNCH_BOUNDS @@ -45,22 +50,19 @@ __global__ void auto splitk_batch_offset = typename GridwiseGemm::SplitKBatchOffset(karg, blockIdx.z); - GridwiseGemm::template Run(karg.p_sorted_token_ids, - karg.p_sorted_expert_ids, - karg.p_max_token_id, - karg.p_a_grid + splitk_batch_offset.a_k_split_offset, - karg.p_b_grid + splitk_batch_offset.b_k_split_offset, - karg.p_ds_grid, - karg.p_c_grid, - p_shared, - karg, - karg.a_element_op, - karg.b_element_op, - karg.c_element_op); + GridwiseGemm::template Run( + karg.p_sorted_token_ids, + karg.p_sorted_expert_ids, + karg.p_max_token_id, + karg.p_a_grid + splitk_batch_offset.a_k_split_offset, + karg.p_b_grid + splitk_batch_offset.b_k_split_offset, + karg.p_ds_grid, + karg.p_c_grid, + p_shared, + karg, + karg.a_element_op, + karg.b_element_op, + karg.c_element_op); #else ignore = karg; #endif // end of if (defined(__gfx9__)) @@ -70,8 +72,6 @@ template __global__ void #if CK_USE_LAUNCH_BOUNDS @@ -86,23 +86,20 @@ __global__ void auto splitk_batch_offset = typename GridwiseGemm::SplitKBatchOffset(karg, blockIdx.z); - GridwiseGemm::template Run_2Lds(karg.p_sorted_token_ids, - karg.p_sorted_expert_ids, - karg.p_max_token_id, - karg.p_a_grid + splitk_batch_offset.a_k_split_offset, - karg.p_b_grid + splitk_batch_offset.b_k_split_offset, - karg.p_ds_grid, - karg.p_c_grid, - p_shared, - p_shared1, - karg, - karg.a_element_op, - karg.b_element_op, - karg.c_element_op); + GridwiseGemm::template Run_2Lds( + karg.p_sorted_token_ids, + karg.p_sorted_expert_ids, + karg.p_max_token_id, + karg.p_a_grid + splitk_batch_offset.a_k_split_offset, + karg.p_b_grid + splitk_batch_offset.b_k_split_offset, + karg.p_ds_grid, + karg.p_c_grid, + p_shared, + p_shared1, + karg, + karg.a_element_op, + 
karg.b_element_op, + karg.c_element_op); #else ignore = karg; #endif // end of if (defined(__gfx9__)) @@ -154,7 +151,12 @@ template ) @@ -497,8 +500,8 @@ struct GridwiseMoeGemm } template - __host__ __device__ static auto - MakeCGridDescriptor_M_N(index_t M, index_t MPad, index_t N, index_t NPad, index_t StrideC) + __host__ __device__ static auto MakeCGridDescriptor_M_N( + IndexType M, IndexType MPad, IndexType N, IndexType NPad, IndexType StrideC) { const auto c_grid_desc_mraw_nraw = [&]() { if constexpr(is_same::value) @@ -909,7 +912,8 @@ struct GridwiseMoeGemm NPerXdl, MXdlPerWave, NXdlPerWave, - KPack>())>; + KPack, + IsInputGemm>())>; __device__ static constexpr index_t GetSharedMemoryNumberOfByte() { @@ -1141,9 +1145,7 @@ struct GridwiseMoeGemm template + TailNumber TailNum = TailNumber::Odd> __device__ static void Run(const index_t* p_sorted_token_ids, const index_t* p_sorted_expert_ids, const index_t* p_max_token_id, @@ -1203,6 +1205,7 @@ struct GridwiseMoeGemm return {blockIdx.x, blockIdx.y}; } }(); + const index_t block_n_id = block_mn.first; const index_t block_m_id = block_mn.second; const index_t token0 = @@ -1218,7 +1221,7 @@ struct GridwiseMoeGemm if(token_pos >= max_token_id || token0 >= problem.NumTokens) return; - StaticallyIndexedArray gather_offsets; + StaticallyIndexedArray gather_offsets; static_for<0, AMRepeats, 1>{}([&](auto m0) { const index_t fused_token = p_sorted_token_ids[token_pos + m0]; index_t token_offset = fused_token & 0xffffff; @@ -1226,9 +1229,10 @@ struct GridwiseMoeGemm { token_offset = token_offset * problem.TopK + (fused_token >> 24); } - gather_offsets(m0) = token_offset * problem.K; + gather_offsets(m0) = static_cast(token_offset) * problem.K; }); - const index_t expert_stride = __builtin_amdgcn_readfirstlane(problem.N * problem.K); + const index_t expert_stride = + __builtin_amdgcn_readfirstlane(problem.N * problem.K * (IsInputGemm ? 
2 : 1)); // N0, K0, Blocksize*KPack const index_t n_block_data_idx_on_grid = @@ -1239,7 +1243,6 @@ struct GridwiseMoeGemm const auto b_grid_buf = make_dynamic_buffer( p_b_grid + expert_id * expert_stride / BPackedSize, b_grid_desc_bpreshuffled.GetElementSpaceSize()); - // A matrix in LDS memory, dst of blockwise copy constexpr auto a_block_desc_ak0_m_ak1 = GetABlockDescriptor_AK0PerBlock_MPerBlock_AK1(); @@ -1269,6 +1272,7 @@ struct GridwiseMoeGemm 1, AThreadTransferSrcResetCoordinateAfterRun, true, + IndexType, 1, BlockwiseGemmPipe::GlobalBufferNum>(a_grid_desc_ak0_m_ak1, make_multi_index(0, 0, 0), @@ -1311,24 +1315,74 @@ struct GridwiseMoeGemm static_assert(std::is_default_constructible_v); auto blockwise_gemm_pipeline = BlockwiseGemmPipe{}; auto c_thread_buf = blockwise_gemm_pipeline.GetCThreadBuffer(); + decltype(c_thread_buf) c_thread_buf_up; + + StaticBufferTupleOfVector + c_thread_buf_fp32; const index_t num_k_block_main_loop = __builtin_amdgcn_readfirstlane( (a_grid_desc_ak0_m_ak1.GetLength(I0) * a_grid_desc_ak0_m_ak1.GetLength(I2)) / KPerBlock); - - blockwise_gemm_pipeline.template Run(a_grid_desc_ak0_m_ak1, - a_block_desc_ak0_m_ak1, - a_blockwise_copy, - a_grid_buf, - a_block_buf, - a_block_slice_copy_step, - b_grid_desc_bpreshuffled, - b_blockwise_copy, - b_grid_buf, - b_block_buf, - b_block_slice_copy_step, - c_thread_buf, - num_k_block_main_loop); + if constexpr(IsInputGemm) + { + const BDataType* p_b_grid_up = p_b_grid + expert_stride / 2 / BPackedSize; + const auto b_grid_buf_up = make_dynamic_buffer( + p_b_grid_up + expert_id * expert_stride / BPackedSize, + b_grid_desc_bpreshuffled.GetElementSpaceSize()); + auto b_blockwise_copy_up = ThreadwiseTensorSliceTransfer_v2< + BDataType, + BDataType, + decltype(b_grid_desc_bpreshuffled), + decltype(b_block_desc_bk0_n_bk1), + Sequence{}, I1, Number{}, Number{}>, + Sequence<1, 2, 0, 3>, + 3, + BBlockTransferSrcScalarPerVector, + BThreadTransferSrcResetCoordinateAfterRun, + true>(b_grid_desc_bpreshuffled, + 
make_multi_index(n_block_data_idx_on_grid, + get_warp_local_1d_id() % NWave, + 0, + KPack * (get_thread_local_1d_id() % warpSize))); + blockwise_gemm_pipeline.template Run( + a_grid_desc_ak0_m_ak1, + a_block_desc_ak0_m_ak1, + a_blockwise_copy, + a_grid_buf, + a_block_buf, + a_block_slice_copy_step, + b_grid_desc_bpreshuffled, + b_blockwise_copy, + b_blockwise_copy_up, + b_grid_buf, + b_grid_buf_up, + b_block_buf, + b_block_slice_copy_step, + c_thread_buf, + c_thread_buf_up, + num_k_block_main_loop); + } + else + { + blockwise_gemm_pipeline.template Run( + a_grid_desc_ak0_m_ak1, + a_block_desc_ak0_m_ak1, + a_blockwise_copy, + a_grid_buf, + a_block_buf, + a_block_slice_copy_step, + b_grid_desc_bpreshuffled, + b_blockwise_copy, + b_grid_buf, + b_block_buf, + b_block_slice_copy_step, + c_thread_buf, + num_k_block_main_loop); + } // shuffle C and write out { @@ -1356,6 +1410,185 @@ struct GridwiseMoeGemm constexpr auto M4 = c_block_desc_m0_n0_m1_n1_m2_m3_m4_n2_tmp.GetLength(I6); constexpr auto N2 = c_block_desc_m0_n0_m1_n1_m2_m3_m4_n2_tmp.GetLength(I7); + // mul scales + const float* p_sorted_weights_0 = p_ds_grid[I0]; + const float* p_scale_b = p_ds_grid[I1]; + + static_assert(M0 * M1 * M2 * M3 * M4 == MPerBlock); + static_assert(M4 == 4); + const index_t m1 = get_warp_local_1d_id() / NWave; + const index_t m3 = threadIdx.x % get_warp_size() / MPerXdl; + + if(p_sorted_weights_0 != nullptr && p_scale_b != nullptr) + { + if constexpr(PerTokenQuant) + { + constexpr index_t scale_stride = (IsInputGemm ? 
2 : 1); + p_scale_b += expert_id * problem.N * scale_stride + block_n_id * NPerBlock + + get_warp_local_1d_id() % NWave * NPerXdl + threadIdx.x % NPerXdl; + } + else + { + p_scale_b += expert_id; + } + + vector_type scale_token_ids; + vector_type topk_weights; + static_for<0, NXdlPerWave, 1>{}([&](auto n0) { + const float scale_b = p_scale_b[n0 * NWave * NPerXdl * PerTokenQuant]; + static_for<0, MXdlPerWave, 1>{}([&](auto m0) { // MXDLPerWave + static_for<0, M2, 1>{}([&](auto m2) { // m_inst_num_groups_per_blk + const index_t m_pos = block_m_id * MPerBlock + m0 * M1 * M2 * M3 * M4 + + m1 * M2 * M3 * M4 + m2 * M3 * M4 + m3 * M4; + if constexpr(PerTokenQuant) + { + scale_token_ids = + *c_style_pointer_cast*>( + p_sorted_token_ids + m_pos); + } + if constexpr(MulRoutedWeight) + { + topk_weights = *c_style_pointer_cast*>( + p_ds_grid[I2] + m_pos); + } + static_for<0, M4, 1>{}([&](auto m4) { // m_inst_group_size + float scale_a = [&]() { + if constexpr(PerTokenQuant) + { + index_t fused_token = scale_token_ids.AsType()[m4]; + const index_t token_offset = fused_token & 0xffffff; + return token_offset < problem.NumTokens + ? 
p_sorted_weights_0[token_offset] + : 0.0; + } + else + { + return p_sorted_weights_0[0]; + } + }(); + constexpr index_t c_offset = + blockwise_gemm_pipeline.GetCThreadDesc().CalculateOffset( + make_tuple(m0, n0, m2 * M4 + m4)); + constexpr auto cidx = Number{}; + if constexpr(IsInputGemm) // gu fusion + { + if constexpr(ActivationOperation == Activation::silu_and_mul) + { + const float scale_up = + p_scale_b[(n0 * NWave * NPerXdl + problem.N) * + PerTokenQuant]; + float gate = scale_a * scale_b * c_thread_buf[cidx]; + float up = scale_a * scale_up * c_thread_buf_up[cidx]; + if constexpr(MulRoutedWeight) + { + gate = gate * topk_weights.AsType()[m4]; + up = up * topk_weights.AsType()[m4]; + } + if constexpr(is_same_v, pk_i4_t>) + { + gate *= 16; + up *= 16; + } + tensor_operation::element_wise::Silu{}(gate, gate); + c_thread_buf_fp32(cidx) = gate * up; + } + else if(ActivationOperation == Activation::gelu_and_mul) + { + const float scale_up = + p_scale_b[(n0 * NWave * NPerXdl + problem.N) * + PerTokenQuant]; + float gate = scale_a * scale_b * c_thread_buf[cidx]; + float up = scale_a * scale_up * c_thread_buf_up[cidx]; + if constexpr(MulRoutedWeight) + { + gate = gate * topk_weights.AsType()[m4]; + up = up * topk_weights.AsType()[m4]; + } + if constexpr(is_same_v, pk_i4_t>) + { + gate *= 16; + up *= 16; + } + tensor_operation::element_wise::Gelu{}(gate, gate); + c_thread_buf_fp32(cidx) = gate * up; + } + } + else + { + c_thread_buf_fp32(cidx) = + scale_a * scale_b * c_thread_buf[cidx]; + if constexpr(MulRoutedWeight) + { + c_thread_buf_fp32(cidx) = c_thread_buf_fp32(cidx) * + topk_weights.AsType()[m4]; + } + } + }); + }); + }); + }); + } + else + { + vector_type topk_weights; // for gemm2 only + static_for<0, NXdlPerWave, 1>{}([&](auto n0) { + static_for<0, MXdlPerWave, 1>{}([&](auto m0) { // MXDLPerWave + static_for<0, M2, 1>{}([&](auto m2) { // m_inst_num_groups_per_blk + const index_t m_pos = block_m_id * MPerBlock + m0 * M1 * M2 * M3 * M4 + + m1 * M2 * M3 * M4 + 
m2 * M3 * M4 + m3 * M4; + if constexpr(MulRoutedWeight) + { + topk_weights = *c_style_pointer_cast*>( + p_ds_grid[I2] + m_pos); + } + static_for<0, M4, 1>{}([&](auto m4) { // m_inst_group_size + constexpr index_t c_offset = + blockwise_gemm_pipeline.GetCThreadDesc().CalculateOffset( + make_tuple(m0, n0, m2 * M4 + m4)); + constexpr auto cidx = Number{}; + + if constexpr(IsInputGemm) // gu fusion + { + if constexpr(ActivationOperation == Activation::silu_and_mul) + { + float gate = c_thread_buf[cidx]; + float up = c_thread_buf_up[cidx]; + if constexpr(MulRoutedWeight) + { + gate = gate * topk_weights.AsType()[m4]; + up = up * topk_weights.AsType()[m4]; + } + tensor_operation::element_wise::Silu{}(gate, gate); + c_thread_buf_fp32(cidx) = gate * up; + } + else if(ActivationOperation == Activation::gelu_and_mul) + { + float gate = c_thread_buf[cidx]; + float up = c_thread_buf_up[cidx]; + if constexpr(MulRoutedWeight) + { + gate = gate * topk_weights.AsType()[m4]; + up = up * topk_weights.AsType()[m4]; + } + tensor_operation::element_wise::Gelu{}(gate, gate); + c_thread_buf_fp32(cidx) = gate * up; + } + } + else + { + c_thread_buf_fp32(cidx) = c_thread_buf[cidx]; + if constexpr(MulRoutedWeight) + { + c_thread_buf_fp32(cidx) = topk_weights.AsType()[m4] * + c_thread_buf_fp32[cidx]; + } + } + }); + }); + }); + }); + } + constexpr auto c_shuffle_block_desc_mblock_mperblock_nblock_nperblock = GetCShuffleBlockDescriptor_MBlock_MPerBlock_NBlock_NPerBlock(); @@ -1453,17 +1686,8 @@ struct GridwiseMoeGemm const auto ds_grid_buf = generate_tuple( [&](auto i) { - using DDataType = remove_cvref_t>; - const DDataType* ptr_ = p_ds_grid[i]; - // hack logic here to support different kind of strides. todo fix it. - // ascale t, 1; bscale E, N, 1, move ptr to E - if(i.value == 1) - { - ptr_ += - expert_id * (problem.StrideDs[1] ? 
problem.StrideDs[1] * problem.N : 1); - } return make_dynamic_buffer( - ptr_, ds_grid_desc_m_n[i].GetElementSpaceSize()); + p_ds_grid[i], ds_grid_desc_m_n[i].GetElementSpaceSize()); }, Number{}); @@ -1526,7 +1750,8 @@ struct GridwiseMoeGemm Sequence, uniform_sequence_gen_t>, // ThreadTransferSrcResetCoordinateAfterRunFlags - Sequence, // ThreadTransferDstResetCoordinateAfterRunFlags + Sequence, // ThreadTransferDstResetCoordinateAfterRunFlags + IndexType, 1, // ScatterDim true, // OutputScatter: false, only use scatter weights scatter_weight_idx // ScatterWeightIdx: ascale @@ -1538,7 +1763,6 @@ struct GridwiseMoeGemm auto c_grid_buf = make_dynamic_buffer( p_c_grid, c_grid_desc_mblock_mperblock_nblock_nperblock.GetElementSpaceSize()); - // space filling curve for threadwise C in VGPR constexpr auto sfc_c_vgpr = SpaceFillingCurve, Sequence<0, 1, 2, 3, 4, 5, 6, 7>, @@ -1568,35 +1792,21 @@ struct GridwiseMoeGemm constexpr auto EMRepeats = CShuffleMXdlPerWavePerShuffle * MWave * MPerXdl / EMThreads; constexpr auto ENThreads = CDEBlockTransferCluster{}.At(I2) * CDEBlockTransferCluster{}.At(I3); - const float* p_sorted_weights_0 = p_ds_grid[I0]; static_for<0, num_access, 1>{}([&](auto access_id) { // make sure it's safe to write to LDS - StaticallyIndexedArray scatter_offsets; - StaticallyIndexedArray scatter_weights; //= for topk + StaticallyIndexedArray scatter_offsets; auto dstidx = sfc_cde_block.GetIndex(access_id); const index_t c_token_pos = block_m_id * MPerBlock + threadIdx.x / ENThreads * EMRepeats + dstidx(I1); static_for<0, EMRepeats, 1>{}([&](auto m0) { const index_t fused_token = p_sorted_token_ids[c_token_pos + m0]; - index_t token_offset = fused_token & 0xffffff; - float weight = token_offset < problem.NumTokens - ? 
p_sorted_weights_0[token_offset * problem.StrideDs[0]] - : 0.0; + IndexType token_offset = fused_token & 0xffffff; if constexpr(IsInputGemm) { token_offset = token_offset * problem.TopK + (fused_token >> 24); } - if constexpr(MulRoutedWeight) - { - const float* p_sorted_weights_2 = p_ds_grid[I2]; - if constexpr(sizeof(ADataType) < 2) - weight = p_sorted_weights_2[c_token_pos + m0] * weight; - else - weight = p_sorted_weights_2[c_token_pos + m0]; - } - scatter_offsets(m0) = token_offset * problem.N; - scatter_weights(m0) = weight; + scatter_offsets(m0) = static_cast(token_offset) * problem.N; }); block_sync_lds(); @@ -1604,7 +1814,7 @@ struct GridwiseMoeGemm // each thread write its data from VGPR to LDS c_thread_copy_vgpr_to_lds.Run(c_thread_desc_m0_n0_m1_n1_m2_m3_m4_n2, sfc_c_vgpr.GetIndexTupleOfNumber(access_id), - c_thread_buf, + c_thread_buf_fp32, c_block_desc_m0_n0_m1_n1_m2_m3_m4_n2, c_shuffle_block_buf); @@ -1617,8 +1827,7 @@ struct GridwiseMoeGemm c_ds_buf_refs, tie(e_grid_desc_mblock_mperblock_nblock_nperblock), tie(c_grid_buf), - scatter_offsets, - scatter_weights); + scatter_offsets); if constexpr(access_id < num_access - 1) { @@ -1643,9 +1852,7 @@ struct GridwiseMoeGemm template + TailNumber TailNum = TailNumber::Odd> __device__ static void Run_2Lds(const index_t* p_sorted_token_ids, const index_t* p_sorted_expert_ids, const index_t* p_max_token_id, @@ -1721,7 +1928,7 @@ struct GridwiseMoeGemm if(token_pos >= max_token_id || expert_block_id * MPerBlock >= max_token_id || token0 >= problem.NumTokens) return; - StaticallyIndexedArray + StaticallyIndexedArray gather_offsets; //= p_sorted_token_ids[token_pos]; static_for<0, AMRepeats, 1>{}([&](auto m0) { const index_t fused_token = p_sorted_token_ids[token_pos + m0]; @@ -1730,7 +1937,7 @@ struct GridwiseMoeGemm { token_offset = token_offset * problem.TopK + (fused_token >> 24); } - gather_offsets(m0) = token_offset * problem.K; + gather_offsets(m0) = static_cast(token_offset) * problem.K; }); const index_t 
expert_stride = __builtin_amdgcn_readfirstlane(problem.N * problem.K); @@ -1773,6 +1980,7 @@ struct GridwiseMoeGemm 1, AThreadTransferSrcResetCoordinateAfterRun, true, + IndexType, 1, 2>(a_grid_desc_ak0_m_ak1, make_multi_index(0, 0, 0), @@ -1967,11 +2175,12 @@ struct GridwiseMoeGemm const DDataType* ptr_ = p_ds_grid[i]; // hack logic here to support different kind of strides. todo fix it. // ascale t, 1; bscale E, N, 1, move ptr to E - if(i.value == 1) - { - ptr_ += - expert_id * (problem.StrideDs[1] ? problem.StrideDs[1] * problem.N : 1); - } + // if(i.value == 1) + // { + // ptr_ += + // expert_id * (problem.StrideDs[1] ? problem.StrideDs[1] * problem.N : + // 1); + // } return make_dynamic_buffer( ptr_, ds_grid_desc_m_n[i].GetElementSpaceSize()); }, @@ -2036,7 +2245,8 @@ struct GridwiseMoeGemm Sequence, uniform_sequence_gen_t>, // ThreadTransferSrcResetCoordinateAfterRunFlags - Sequence, // ThreadTransferDstResetCoordinateAfterRunFlags + Sequence, // ThreadTransferDstResetCoordinateAfterRunFlags + IndexType, 1, // ScatterDim true, // OutputScatter: false, only use scatter weights scatter_weight_idx // ScatterWeightIdx: ascale @@ -2078,12 +2288,9 @@ struct GridwiseMoeGemm constexpr auto EMRepeats = CShuffleMXdlPerWavePerShuffle * MWave * MPerXdl / EMThreads; constexpr auto ENThreads = CDEBlockTransferCluster{}.At(I2) * CDEBlockTransferCluster{}.At(I3); - const float* p_sorted_weights_0 = p_ds_grid[I0]; static_for<0, num_access, 1>{}([&](auto access_id) { // make sure it's safe to write to LDS - StaticallyIndexedArray - scatter_offsets; //= p_sorted_token_ids[c_token_pos]; - StaticallyIndexedArray scatter_weights; //= for topk + StaticallyIndexedArray scatter_offsets; auto dstidx = sfc_cde_block.GetIndex(access_id); const index_t c_token_pos = @@ -2091,23 +2298,11 @@ struct GridwiseMoeGemm static_for<0, EMRepeats, 1>{}([&](auto m0) { const index_t fused_token = p_sorted_token_ids[c_token_pos + m0]; index_t token_offset = fused_token & 0xffffff; - float weight = 
token_offset < problem.NumTokens - ? p_sorted_weights_0[token_offset * problem.StrideDs[0]] - : 0.0; if constexpr(IsInputGemm) { token_offset = token_offset * problem.TopK + (fused_token >> 24); } - if constexpr(MulRoutedWeight) - { - const float* p_sorted_weights_2 = p_ds_grid[I2]; - if constexpr(sizeof(ADataType) < 2) - weight = p_sorted_weights_2[c_token_pos + m0] * weight; - else - weight = p_sorted_weights_2[c_token_pos + m0]; - } - scatter_offsets(m0) = token_offset * problem.N; - scatter_weights(m0) = weight; + scatter_offsets(m0) = static_cast(token_offset) * problem.N; }); block_sync_lds(); @@ -2128,8 +2323,7 @@ struct GridwiseMoeGemm c_ds_buf_refs, tie(e_grid_desc_mblock_mperblock_nblock_nperblock), tie(c_grid_buf), - scatter_offsets, - scatter_weights); + scatter_offsets); if constexpr(access_id < num_access - 1) { diff --git a/include/ck/tensor_operation/gpu/thread/threadwise_tensor_slice_transfer_v3r1_gather.hpp b/include/ck/tensor_operation/gpu/thread/threadwise_tensor_slice_transfer_v3r1_gather.hpp index bb9a452761..bd6fe772e4 100644 --- a/include/ck/tensor_operation/gpu/thread/threadwise_tensor_slice_transfer_v3r1_gather.hpp +++ b/include/ck/tensor_operation/gpu/thread/threadwise_tensor_slice_transfer_v3r1_gather.hpp @@ -41,6 +41,7 @@ template struct ThreadwiseTensorSliceTransfer_v3r1_gather @@ -88,7 +89,7 @@ struct ThreadwiseTensorSliceTransfer_v3r1_gather const DstDesc& dst_desc, const Index& dst_slice_origin, const DstElementwiseOperation& dst_element_op, - const StaticallyIndexedArray& gather_offsets) + const StaticallyIndexedArray& gather_offsets) : src_coord_(make_tensor_coordinate(src_desc, src_slice_origin)), dst_coord_(make_tensor_coordinate(dst_desc, dst_slice_origin)), src_element_op_(src_element_op), @@ -221,7 +222,7 @@ struct ThreadwiseTensorSliceTransfer_v3r1_gather auto gather_offset = gather_offsets_(ordered_src_access_idx[Number{}]); - const index_t ld_offset = src_coord_.GetOffset() + gather_offset; + const IndexType ld_offset = 
src_coord_.GetOffset() + gather_offset; src_oob_thread_scratch_tuple_(thread_scratch_id) .template SetAsType(src_data_idx_seq, true); @@ -935,7 +936,7 @@ struct ThreadwiseTensorSliceTransfer_v3r1_gather DstCoord dst_coord_; const SrcElementwiseOperation src_element_op_; const DstElementwiseOperation dst_element_op_; - StaticallyIndexedArray gather_offsets_; + StaticallyIndexedArray gather_offsets_; }; } // namespace ck diff --git a/include/ck/tensor_operation/gpu/thread/threadwise_tensor_slice_transfer_v7r3_scatter.hpp b/include/ck/tensor_operation/gpu/thread/threadwise_tensor_slice_transfer_v7r3_scatter.hpp index 6a1c195dc1..7cd0a0fc7f 100644 --- a/include/ck/tensor_operation/gpu/thread/threadwise_tensor_slice_transfer_v7r3_scatter.hpp +++ b/include/ck/tensor_operation/gpu/thread/threadwise_tensor_slice_transfer_v7r3_scatter.hpp @@ -1,5 +1,5 @@ // SPDX-License-Identifier: MIT -// Copyright (c) 2018-2024, Advanced Micro Devices, Inc. All rights reserved. +// Copyright (c) 2018-2025, Advanced Micro Devices, Inc. All rights reserved. 
#pragma once @@ -43,6 +43,7 @@ template typename DstResetCoordinateAfterRunFlags, // Sequence + typename IndexType, index_t ScatterDim = 1, bool OutputScatter = true, index_t ScatterWeightIdx = 3, @@ -153,7 +154,6 @@ struct ThreadwiseTensorSliceTransfer_v7r3_scatter enable_if_t = false> __device__ void RunRead(const SrcDescs& src_descs, const SrcBuffers& src_bufs, - StaticallyIndexedArray& scatter_weights, Number thread_scratch_id = Number{}) { // loop over space-filling curve @@ -172,31 +172,8 @@ struct ThreadwiseTensorSliceTransfer_v7r3_scatter src_coords_[i]); oob_val = oob_val & is_src_valid; - if(i.value == ScatterWeightIdx) - { - static_assert(SrcScalarPerVectors{}[Number{}] == 1, - "scatter weight dim, should only one vec"); - constexpr auto iScatter = - SrcSpaceFillingCurve::GetIndex(iAccess)(Number{}); - static_for<0, SrcScalarPerVector, 1>{}([&](auto j) { - src_vectors(i).template AsType()(j) = - scatter_weights(Number{}); - }); - } - else if constexpr(SrcScalarPerVectors{}[i] == 1) - { - auto data_types = SrcDatas{}; - using DataType = remove_cvref_t; - const auto tmp = - src_bufs[i].template Get(src_coords_[i].GetOffset(), true); - static_for<0, SrcScalarPerVector, 1>{}( - [&](auto j) { src_vectors(i).template AsType()(j) = tmp; }); - } - else - { - src_vectors(i).template AsType()(I0) = - src_bufs[i].template Get(src_coords_[i].GetOffset(), true); - } + src_vectors(i).template AsType()(I0) = + src_bufs[i].template Get(src_coords_[i].GetOffset(), true); }); constexpr auto get_elem_op_vec_len = []() { @@ -412,7 +389,7 @@ struct ThreadwiseTensorSliceTransfer_v7r3_scatter enable_if_t = false> __device__ void RunWrite(const DstDescs& dst_descs, DstBuffers dst_bufs, - StaticallyIndexedArray& scatter_offsets, + StaticallyIndexedArray& scatter_offsets, Number thread_scratch_id = Number{}) { OOBCheck(thread_scratch_id); @@ -420,8 +397,8 @@ struct ThreadwiseTensorSliceTransfer_v7r3_scatter // loop over space-filling curve static_for<0, dst_num_access, 
1>{}([&](auto iAccess) { - auto dst_vectors = dst_vectors_tuple_[thread_scratch_id][iAccess]; - auto scatter_offset = 0; + auto dst_vectors = dst_vectors_tuple_[thread_scratch_id][iAccess]; + IndexType scatter_offset = 0; if constexpr(OutputScatter) { constexpr auto iScatter = @@ -431,8 +408,10 @@ struct ThreadwiseTensorSliceTransfer_v7r3_scatter // copy data from buf_vectors into dst_bufs static_for<0, nDst, 1>{}([&](auto i) { using dst_vector_t = typename remove_cvref_t::type; - auto dst_offset = scatter_offset + dst_coords_[i].GetOffset(); + IndexType dst_offset = scatter_offset + (dst_coords_[i].GetOffset()); const bool is_dst_valid = dst_offset < dst_descs[i].GetElementSpaceSize(); + // coordinate_has_valid_offset_assuming_visible_index_is_valid(dst_descs[i], + // dst_coords_[i]); constexpr InMemoryDataOperationEnum DstInMemOp = static_cast(DstInMemOps::At(i.value)); dst_bufs(i).template Update( @@ -488,10 +467,9 @@ struct ThreadwiseTensorSliceTransfer_v7r3_scatter const SrcBuffers& src_bufs, const DstDescs& dst_descs, DstBuffers dst_bufs, - StaticallyIndexedArray& scatter_offsets, - StaticallyIndexedArray& scatter_weights) + StaticallyIndexedArray& scatter_offsets) { - RunRead(src_descs, src_bufs, scatter_weights); + RunRead(src_descs, src_bufs); RunWrite(dst_descs, dst_bufs, scatter_offsets); } diff --git a/include/ck/utility/dynamic_buffer.hpp b/include/ck/utility/dynamic_buffer.hpp index 1a0ea27eab..1d80f196b5 100644 --- a/include/ck/utility/dynamic_buffer.hpp +++ b/include/ck/utility/dynamic_buffer.hpp @@ -24,7 +24,8 @@ template + AmdBufferCoherenceEnum coherence = AmdBufferCoherenceEnum::DefaultCoherence, + typename IndexType = index_t> struct DynamicBuffer { using type = T; @@ -59,16 +60,16 @@ struct DynamicBuffer return BufferAddressSpace; } - __host__ __device__ constexpr const T& operator[](index_t i) const { return p_data_[i]; } + __host__ __device__ constexpr const T& operator[](IndexType i) const { return p_data_[i]; } - __host__ __device__ 
constexpr T& operator()(index_t i) { return p_data_[i]; } + __host__ __device__ constexpr T& operator()(IndexType i) { return p_data_[i]; } template >::type, typename scalar_type>::type>::value || !is_native_type(), bool>::type = false> - __host__ __device__ constexpr auto Get(index_t i, bool is_valid_element) const + __host__ __device__ constexpr auto Get(IndexType i, bool is_valid_element) const { // X contains multiple T constexpr index_t scalar_per_t_vector = scalar_type>::vector_size; @@ -79,7 +80,7 @@ struct DynamicBuffer "wrong! X should contain multiple T"); #if CK_USE_AMD_BUFFER_LOAD - bool constexpr use_amd_buffer_addressing = true; + bool constexpr use_amd_buffer_addressing = sizeof(IndexType) <= sizeof(int32_t); #else bool constexpr use_amd_buffer_addressing = false; #endif @@ -140,7 +141,7 @@ struct DynamicBuffer typename enable_if>::type, typename scalar_type>::type>::value, bool>::type = false> - __host__ __device__ void Update(index_t i, bool is_valid_element, const X& x) + __host__ __device__ void Update(IndexType i, bool is_valid_element, const X& x) { if constexpr(Op == InMemoryDataOperationEnum::Set) { @@ -191,8 +192,8 @@ struct DynamicBuffer template __host__ __device__ void DirectCopyToLds(DstBuffer& dst_buf, - index_t src_offset, - index_t dst_offset, + IndexType src_offset, + IndexType dst_offset, bool is_valid_element) const { // Copy data from global to LDS memory using direct loads. @@ -214,7 +215,7 @@ struct DynamicBuffer typename scalar_type>::type>::value || !is_native_type(), bool>::type = false> - __host__ __device__ void Set(index_t i, bool is_valid_element, const X& x) + __host__ __device__ void Set(IndexType i, bool is_valid_element, const X& x) { // X contains multiple T constexpr index_t scalar_per_t_vector = scalar_type>::vector_size; @@ -224,8 +225,8 @@ struct DynamicBuffer static_assert(scalar_per_x_vector % scalar_per_t_vector == 0, "wrong! 
X should contain multiple T"); -#if CK_USE_AMD_BUFFER_STORE - bool constexpr use_amd_buffer_addressing = true; +#if CK_USE_AMD_BUFFER_LOAD + bool constexpr use_amd_buffer_addressing = sizeof(IndexType) <= sizeof(int32_t); #else bool constexpr use_amd_buffer_addressing = false; #endif @@ -342,11 +343,12 @@ struct DynamicBuffer { if(is_valid_element) { -#if CK_EXPERIMENTAL_USE_MEMCPY_FOR_VECTOR_ACCESS +#if 0 X tmp = x; __builtin_memcpy(&(p_data_[i]), &tmp, sizeof(X)); #else + // if(i >= 2169041600) *c_style_pointer_cast(&p_data_[i]) = x; #endif } @@ -357,7 +359,7 @@ struct DynamicBuffer typename enable_if>::type, typename scalar_type>::type>::value, bool>::type = false> - __host__ __device__ void AtomicAdd(index_t i, bool is_valid_element, const X& x) + __host__ __device__ void AtomicAdd(IndexType i, bool is_valid_element, const X& x) { using scalar_t = typename scalar_type>::type; @@ -378,12 +380,14 @@ struct DynamicBuffer (is_same_v, half_t> && scalar_per_x_vector % 2 == 0) || (is_same_v, bhalf_t> && scalar_per_x_vector % 2 == 0); #elif CK_USE_AMD_BUFFER_ATOMIC_ADD_INTEGER && (!CK_USE_AMD_BUFFER_ATOMIC_ADD_FLOAT) - bool constexpr use_amd_buffer_addressing = is_same_v, int32_t>; + bool constexpr use_amd_buffer_addressing = + sizeof(IndexType) <= sizeof(int32_t) && is_same_v, int32_t>; #elif(!CK_USE_AMD_BUFFER_ATOMIC_ADD_INTEGER) && CK_USE_AMD_BUFFER_ATOMIC_ADD_FLOAT bool constexpr use_amd_buffer_addressing = - is_same_v, float> || - (is_same_v, half_t> && scalar_per_x_vector % 2 == 0) || - (is_same_v, bhalf_t> && scalar_per_x_vector % 2 == 0); + sizeof(IndexType) <= sizeof(int32_t) && + (is_same_v, float> || + (is_same_v, half_t> && scalar_per_x_vector % 2 == 0) || + (is_same_v, bhalf_t> && scalar_per_x_vector % 2 == 0)); #else bool constexpr use_amd_buffer_addressing = false; #endif @@ -408,12 +412,12 @@ struct DynamicBuffer typename enable_if>::type, typename scalar_type>::type>::value, bool>::type = false> - __host__ __device__ void AtomicMax(index_t i, bool 
is_valid_element, const X& x) + __host__ __device__ void AtomicMax(IndexType i, bool is_valid_element, const X& x) { // X contains multiple T - constexpr index_t scalar_per_t_vector = scalar_type>::vector_size; + constexpr IndexType scalar_per_t_vector = scalar_type>::vector_size; - constexpr index_t scalar_per_x_vector = scalar_type>::vector_size; + constexpr IndexType scalar_per_x_vector = scalar_type>::vector_size; static_assert(scalar_per_x_vector % scalar_per_t_vector == 0, "wrong! X should contain multiple T"); @@ -421,8 +425,9 @@ struct DynamicBuffer static_assert(GetAddressSpace() == AddressSpaceEnum::Global, "only support global mem"); #if CK_USE_AMD_BUFFER_ATOMIC_MAX_FLOAT64 - using scalar_t = typename scalar_type>::type; - bool constexpr use_amd_buffer_addressing = is_same_v, double>; + using scalar_t = typename scalar_type>::type; + bool constexpr use_amd_buffer_addressing = + sizeof(IndexType) <= sizeof(int32_t) && is_same_v, double>; #else bool constexpr use_amd_buffer_addressing = false; #endif @@ -455,6 +460,17 @@ __host__ __device__ constexpr auto make_dynamic_buffer(T* p, ElementSpaceSize el p, element_space_size}; } +template +__host__ __device__ constexpr auto make_long_dynamic_buffer(T* p, + ElementSpaceSize element_space_size) +{ + return DynamicBuffer{ + p, element_space_size}; +} + template < AddressSpaceEnum BufferAddressSpace, AmdBufferCoherenceEnum coherence = AmdBufferCoherenceEnum::DefaultCoherence, diff --git a/include/ck/utility/tuple_helper.hpp b/include/ck/utility/tuple_helper.hpp index b1a0c1fc5d..ec055fb2a2 100644 --- a/include/ck/utility/tuple_helper.hpp +++ b/include/ck/utility/tuple_helper.hpp @@ -23,6 +23,13 @@ __host__ __device__ constexpr auto generate_tuple(F&& f, Number) return generate_tuple_for(f, make_index_sequence{}); } +template +__host__ __device__ constexpr auto generate_tuple(F&& f, LongNumber) +{ + return unpack([&f](auto&&... 
xs) { return make_tuple(f(xs)...); }, + typename arithmetic_sequence_gen<0, N, 1>::type{}); +} + template __host__ __device__ constexpr auto generate_tie(F&& f, Number) { diff --git a/library/include/ck/library/reference_tensor_operation/cpu/reference_moe_gemm.hpp b/library/include/ck/library/reference_tensor_operation/cpu/reference_moe_gemm.hpp index 72c9dc86ac..120bf7484a 100644 --- a/library/include/ck/library/reference_tensor_operation/cpu/reference_moe_gemm.hpp +++ b/library/include/ck/library/reference_tensor_operation/cpu/reference_moe_gemm.hpp @@ -23,12 +23,14 @@ template + index_t ActivationType_ = 0, + bool MulRoutedWeight = true, + typename ComputeTypeA = CDataType, + typename ComputeTypeB = ComputeTypeA> struct ReferenceMoeGemm : public device::BaseOperator { // Argument + static constexpr auto ActivationType = ActivationType_; struct Argument : public device::BaseArgument { Argument(const Tensor& sorted_token_ids, @@ -36,7 +38,9 @@ struct ReferenceMoeGemm : public device::BaseOperator const Tensor& max_token_id, const index_t sorted_tile_size, const Tensor& a_t_k, + const Tensor& a_scale_t, const Tensor& b_e_n_k, + const Tensor& b_scale_e_n, Tensor& c_t_k_n, const Tensor& d2, AElementwiseOperation a_element_op, @@ -47,7 +51,9 @@ struct ReferenceMoeGemm : public device::BaseOperator max_token_id_{max_token_id}, sorted_tile_size_{sorted_tile_size}, a_t_k_{a_t_k}, + a_scale_t_{a_scale_t}, b_e_n_k_{b_e_n_k}, + b_scale_e_n_{b_scale_e_n}, c_t_k_n_{c_t_k_n}, d2_{d2}, a_element_op_{a_element_op}, @@ -61,7 +67,9 @@ struct ReferenceMoeGemm : public device::BaseOperator const Tensor& max_token_id_; index_t sorted_tile_size_; const Tensor& a_t_k_; + const Tensor& a_scale_t_; const Tensor& b_e_n_k_; + const Tensor& b_scale_e_n_; Tensor& c_t_k_n_; const Tensor& d2_; @@ -77,11 +85,17 @@ struct ReferenceMoeGemm : public device::BaseOperator float Run(const Argument& arg) { - auto f_mk_kn_mn = [&](auto m, auto n) { + static_assert(ActivationType < 2, "Not supported 
activation type"); + const int full_n = arg.c_t_k_n_.mDesc.GetLengths()[2]; + auto f_mk_kn_mn = [&](auto m, auto n) { const int K = arg.a_t_k_.mDesc.GetLengths()[1]; + AccDataType v_acc_up{0}; + ComputeTypeB v_b_up{0}; AccDataType v_acc{0}; + ComputeTypeA v_a{0}; ComputeTypeB v_b{0}; + const int t = arg.sorted_token_ids_(m) & 0xffffff; const int topk_id = (arg.sorted_token_ids_(m) & 0xff000000) >> 24; const int e = arg.expert_ids_(m / arg.sorted_tile_size_); @@ -102,7 +116,7 @@ struct ReferenceMoeGemm : public device::BaseOperator #if CK_USE_PK4_LAYOUT_SHUFFLE v_a = i4_to_f32_gfx9(i4); #else - v_a = i4 - 8; + v_a = i4 - 8; #endif } else @@ -112,42 +126,79 @@ struct ReferenceMoeGemm : public device::BaseOperator // same for B matrix if constexpr(is_same_v) { - uint8_t i4x2 = arg.b_e_n_k_(e, k, n).data; - uint8_t i4 = 0; + uint8_t i4x2 = arg.b_e_n_k_(e, k, n).data; + uint8_t i4x2_up = arg.b_e_n_k_(e, k, n + full_n).data; + uint8_t i4 = 0; + uint8_t i4_up = 0; if(k % 2 == 1) - i4 = (i4x2 >> 0) & 0xf; + { + i4 = (i4x2 >> 0) & 0xf; + i4_up = (i4x2_up >> 0) & 0xf; + } else - i4 = (i4x2 >> 4) & 0xf; + { + i4 = (i4x2 >> 4) & 0xf; + i4_up = (i4x2_up >> 4) & 0xf; + } #if CK_USE_PK4_LAYOUT_SHUFFLE - v_b = i4_to_f32_gfx9(i4); + v_b = i4_to_f32_gfx9(i4); + v_b_up = i4_to_f32_gfx9(i4_up); #else - v_b = i4 - 8; + v_b = i4 - 8; + v_b_up = i4_up - 8; #endif } else { arg.b_element_op_(v_b, arg.b_e_n_k_(e, k, n)); + arg.b_element_op_(v_b_up, arg.b_e_n_k_(e, k, n + full_n)); } v_acc += ck::type_convert(v_a) * ck::type_convert(v_b); + v_acc_up += ck::type_convert(v_a) * + ck::type_convert(v_b_up); } CDataType v_c{0}; - + CDataType v_c_up{0}; if constexpr(MulRoutedWeight) { v_acc *= v_topk_w; + v_acc_up *= v_topk_w; } arg.c_element_op_(v_c, v_acc); + arg.c_element_op_(v_c_up, v_acc_up); - arg.c_t_k_n_(t, topk_id, n) = v_c; + if constexpr(ActivationType == 1) + { + v_c = v_c * arg.b_scale_e_n_(e, n) * arg.a_scale_t_(t); + if constexpr(is_same_v) + { + v_c_up *= 16; + v_c *= 16; + } + 
tensor_operation::element_wise::Silu{}(v_c, v_c); + v_c_up = v_c_up * arg.b_scale_e_n_(e, n + full_n) * arg.a_scale_t_(t); + arg.c_t_k_n_(t, topk_id, n) = v_c * v_c_up; + } + else if constexpr(ActivationType == 0) + { + v_c = v_c * arg.b_scale_e_n_(e, n) * arg.a_scale_t_(t); + if constexpr(is_same_v) + { + v_c_up *= 16; + v_c *= 16; + } + tensor_operation::element_wise::Gelu{}(v_c, v_c); + v_c_up = v_c_up * arg.b_scale_e_n_(e, n + full_n) * arg.a_scale_t_(t); + arg.c_t_k_n_(t, topk_id, n) = v_c * v_c_up; + } } }; const ck::index_t max_token_id = arg.max_token_id_(0); - make_ParallelTensorFunctor( - f_mk_kn_mn, max_token_id, arg.c_t_k_n_.mDesc.GetLengths()[2])( + make_ParallelTensorFunctor(f_mk_kn_mn, max_token_id, full_n)( std::thread::hardware_concurrency()); return 0; @@ -173,7 +224,9 @@ struct ReferenceMoeGemm : public device::BaseOperator const Tensor& max_token_id, const index_t sorted_tile_size, const Tensor& a_t_k, + const Tensor& a_scale_n, const Tensor& b_e_n_k, + const Tensor& b_scale_e_n, Tensor& c_t_k_n, const Tensor& d2, AElementwiseOperation a_element_op, @@ -185,7 +238,9 @@ struct ReferenceMoeGemm : public device::BaseOperator max_token_id, sorted_tile_size, a_t_k, + a_scale_n, b_e_n_k, + b_scale_e_n, c_t_k_n, d2, a_element_op, diff --git a/library/include/ck/library/reference_tensor_operation/cpu/reference_moe_gemm2.hpp b/library/include/ck/library/reference_tensor_operation/cpu/reference_moe_gemm2.hpp index fb5c71e30a..5c932fcb18 100644 --- a/library/include/ck/library/reference_tensor_operation/cpu/reference_moe_gemm2.hpp +++ b/library/include/ck/library/reference_tensor_operation/cpu/reference_moe_gemm2.hpp @@ -25,7 +25,7 @@ template struct ReferenceMoeGemm2 : public device::BaseOperator From 854159fd00d43736bc7c69a491f30006a5dce67e Mon Sep 17 00:00:00 2001 From: John Afaganis Date: Wed, 23 Apr 2025 11:25:41 -0600 Subject: [PATCH 064/443] Update CODEOWNERS (#2119) --- .github/CODEOWNERS | 12 ++++++------ 1 file changed, 6 insertions(+), 6 
deletions(-) diff --git a/.github/CODEOWNERS b/.github/CODEOWNERS index eb69bd7f39..ccdfb0f6fb 100644 --- a/.github/CODEOWNERS +++ b/.github/CODEOWNERS @@ -1,8 +1,8 @@ -* @illsilin @carlushuang @qianfengz @aosewski @poyenc @geyyer @bartekxk @andriy-ca @afagaj @asleepzzz @tenpercent @ThomasNing +* @illsilin @carlushuang @qianfengz @aosewski @poyenc @geyyer @bartekxk @andriy-ca @afagaj @asleepzzz @tenpercent @ThomasNing @coderfeli # Documentation files -docs/ @ROCm/rocm-documentation @illsilin @carlushuang @qianfengz @aosewski @poyenc @geyyer @bartekxk @andriy-ca @afagaj @asleepzzz @ThomasNing -*.md @ROCm/rocm-documentation @illsilin @carlushuang @qianfengz @aosewski @poyenc @geyyer @bartekxk @andriy-ca @afagaj @asleepzzz @ThomasNing -*.rst @ROCm/rocm-documentation @illsilin @carlushuang @qianfengz @aosewski @poyenc @geyyer @bartekxk @andriy-ca @afagaj @asleepzzz @ThomasNing -.readthedocs.yaml @ROCm/rocm-documentation @illsilin @carlushuang @qianfengz @aosewski @poyenc @geyyer @bartekxk @andriy-ca @afagaj @asleepzzz @ThomasNing +docs/ @ROCm/rocm-documentation @illsilin @carlushuang @qianfengz @aosewski @poyenc @geyyer @bartekxk @andriy-ca @afagaj @asleepzzz @ThomasNing @coderfeli +*.md @ROCm/rocm-documentation @illsilin @carlushuang @qianfengz @aosewski @poyenc @geyyer @bartekxk @andriy-ca @afagaj @asleepzzz @ThomasNing @coderfeli +*.rst @ROCm/rocm-documentation @illsilin @carlushuang @qianfengz @aosewski @poyenc @geyyer @bartekxk @andriy-ca @afagaj @asleepzzz @ThomasNing @coderfeli +.readthedocs.yaml @ROCm/rocm-documentation @illsilin @carlushuang @qianfengz @aosewski @poyenc @geyyer @bartekxk @andriy-ca @afagaj @asleepzzz @ThomasNing @coderfeli # Header directory for Doxygen documentation -library/include/ @ROCm/rocm-documentation @illsilin @carlushuang @qianfengz @aosewski @poyenc @geyyer @bartekxk @andriy-ca @afagaj @asleepzzz @ThomasNing +library/include/ @ROCm/rocm-documentation @illsilin @carlushuang @qianfengz @aosewski @poyenc @geyyer @bartekxk @andriy-ca 
@afagaj @asleepzzz @ThomasNing @coderfeli From 5487289fc479c875b181152c0383fdf1da7b2f00 Mon Sep 17 00:00:00 2001 From: carlushuang Date: Thu, 24 Apr 2025 03:40:18 +0800 Subject: [PATCH 065/443] [CK_TILE] support gfx950 matrix core in 01_fmha fwd (#2110) * gfx950 01_fmha fwd * fix comment --------- Co-authored-by: Thomas Ning --- include/ck_tile/ops/gemm.hpp | 3 + include/ck_tile/ops/gemm/warp/warp_gemm.hpp | 24 ++ .../gemm/warp/warp_gemm_attribute_mfma.hpp | 27 ++- .../warp/warp_gemm_attribute_mfma_impl.hpp | 229 ++++++++++++++++++ .../gemm/warp/warp_gemm_attribute_smfmac.hpp | 5 + 5 files changed, 286 insertions(+), 2 deletions(-) diff --git a/include/ck_tile/ops/gemm.hpp b/include/ck_tile/ops/gemm.hpp index 794f7f21f2..35f5170179 100644 --- a/include/ck_tile/ops/gemm.hpp +++ b/include/ck_tile/ops/gemm.hpp @@ -44,8 +44,11 @@ #include "ck_tile/ops/gemm/warp/warp_gemm.hpp" #include "ck_tile/ops/gemm/warp/warp_gemm_attribute_mfma.hpp" #include "ck_tile/ops/gemm/warp/warp_gemm_attribute_mfma_impl.hpp" +#include "ck_tile/ops/gemm/warp/warp_gemm_attribute_smfmac.hpp" +#include "ck_tile/ops/gemm/warp/warp_gemm_attribute_smfmac_impl.hpp" #include "ck_tile/ops/gemm/warp/warp_gemm_dispatcher.hpp" #include "ck_tile/ops/gemm/warp/warp_gemm_impl.hpp" +#include "ck_tile/ops/gemm/warp/warp_gemm_smfmac_impl.hpp" #include "ck_tile/ops/common/generic_2d_block_shape.hpp" #include "ck_tile/ops/common/tensor_layout.hpp" #include "ck_tile/ops/common/utils.hpp" diff --git a/include/ck_tile/ops/gemm/warp/warp_gemm.hpp b/include/ck_tile/ops/gemm/warp/warp_gemm.hpp index bd7a0566a2..e6350a8827 100644 --- a/include/ck_tile/ops/gemm/warp/warp_gemm.hpp +++ b/include/ck_tile/ops/gemm/warp/warp_gemm.hpp @@ -49,10 +49,16 @@ using WarpGemmMfmaF16F16F32M16N16K16TransposedCDistribution = WarpGemmImpl>>; +#if defined(__gfx950__) +using WarpGemmMfmaF16F16F32M32N32K16TransposedCDistribution = + WarpGemmImpl>>; +#else using WarpGemmMfmaF16F16F32M32N32K16TransposedCDistribution = WarpGemmImpl, 2>>; 
+#endif #if defined(__gfx950__) using WarpGemmMfmaF16F16F32M16N16K32TransposedCDistribution = @@ -65,10 +71,16 @@ using WarpGemmMfmaF16F16F32M16N16K32TransposedCDistribution = 2>>; #endif +#if defined(__gfx950__) +using WarpGemmMfmaF16F16F32M32N32K16SwizzleBTransposedCDistribution = + WarpGemmImpl>>; +#else using WarpGemmMfmaF16F16F32M32N32K16SwizzleBTransposedCDistribution = WarpGemmImpl, 2>>; +#endif using WarpGemmMfmaF16F16F32M4N64K16 = WarpGemmImpl, @@ -123,10 +135,16 @@ using WarpGemmMfmaBf16Bf16F32M16N16K16TransposedCDistribution = WarpGemmImpl>>; +#if defined(__gfx950__) +using WarpGemmMfmaBf16Bf16F32M32N32K16TransposedCDistribution = + WarpGemmImpl>>; +#else using WarpGemmMfmaBf16Bf16F32M32N32K16TransposedCDistribution = WarpGemmImpl, 2>>; +#endif #if defined(__gfx950__) using WarpGemmMfmaBf16Bf16F32M16N16K32TransposedCDistribution = @@ -139,10 +157,16 @@ using WarpGemmMfmaBf16Bf16F32M16N16K32TransposedCDistribution = 2>>; #endif +#if defined(__gfx950__) +using WarpGemmMfmaBf16Bf16F32M32N32K16SwizzleBTransposedCDistribution = + WarpGemmImpl>>; +#else using WarpGemmMfmaBf16Bf16F32M32N32K16SwizzleBTransposedCDistribution = WarpGemmImpl, 2>>; +#endif using WarpGemmMfmaBf16Bf16F32M4N64K16 = WarpGemmImpl, diff --git a/include/ck_tile/ops/gemm/warp/warp_gemm_attribute_mfma.hpp b/include/ck_tile/ops/gemm/warp/warp_gemm_attribute_mfma.hpp index e7d4c37966..93ccdb5f57 100644 --- a/include/ck_tile/ops/gemm/warp/warp_gemm_attribute_mfma.hpp +++ b/include/ck_tile/ops/gemm/warp/warp_gemm_attribute_mfma.hpp @@ -356,7 +356,7 @@ struct WarpGemmAtrributeMfmaTransposedCDistribution } }; -template +template struct WarpGemmAtrributeMfmaTransposedCDistribution_SwizzleB { using Impl = remove_cvref_t; @@ -373,6 +373,7 @@ struct WarpGemmAtrributeMfmaTransposedCDistribution_SwizzleB static constexpr index_t kN = Impl::kM; static constexpr index_t kK = Impl::kK; static constexpr index_t kKPerThread = Impl::kABKPerLane; + static constexpr index_t SFactor = SFactor_; // group how many 
CM1 together CK_TILE_HOST_DEVICE static constexpr auto get_num_of_access() { return 1; } @@ -386,7 +387,7 @@ struct WarpGemmAtrributeMfmaTransposedCDistribution_SwizzleB tuple>, sequence<2>, sequence<1>>; - +#if 0 using BWarpDstrEncoding = tile_distribution_encoding< sequence<>, tuple>, sequence<2, 2>, sequence<0, 2>>; +#else + // TODO: more test not only 32x32 + using BWarpDstrEncoding = tile_distribution_encoding< + sequence<>, + tuple, + sequence>, + tuple>, + tuple>, + sequence<2>, + sequence<1>>; + using CWarpDstrEncoding = tile_distribution_encoding< + sequence<>, + tuple, + sequence>, + tuple>, + tuple>, + sequence<2, 2>, + sequence<0, 2>>; +#endif template // c_vec += a_vec * b_vec CK_TILE_DEVICE void operator()(CVecType& c_vec, diff --git a/include/ck_tile/ops/gemm/warp/warp_gemm_attribute_mfma_impl.hpp b/include/ck_tile/ops/gemm/warp/warp_gemm_attribute_mfma_impl.hpp index f937899ffd..08f813a1e3 100644 --- a/include/ck_tile/ops/gemm/warp/warp_gemm_attribute_mfma_impl.hpp +++ b/include/ck_tile/ops/gemm/warp/warp_gemm_attribute_mfma_impl.hpp @@ -748,6 +748,235 @@ struct WarpGemmAttributeMfmaImplBf16Bf16F32M64N4K4 } }; +// gfx950 +template +struct WarpGemmAttributeMfmaImplF16F16F32M32N32K16 +{ + static constexpr WGAttrCtlEnum Ctrl = Ctrl_; + using ADataType = fp16_t; + using BDataType = fp16_t; + using CDataType = float; + + using AVecType = ext_vector_t; + using BVecType = ext_vector_t; + using CVecType = ext_vector_t; + + static constexpr index_t kM = 32; + static constexpr index_t kN = 32; + static constexpr index_t kK = 16; + + static constexpr index_t kAMBlock = 1; + static constexpr index_t kBNBlock = 1; + + static constexpr index_t kAMLane = 32; + static constexpr index_t kBNLane = 32; + static constexpr index_t kABKLane = 2; + static constexpr index_t kABKPerLane = 8; + + static constexpr index_t kCMLane = 2; + static constexpr index_t kCNLane = 32; + static constexpr index_t kCM0PerLane = 4; + static constexpr index_t kCM1PerLane = 4; + + // c_vec 
+= a_vec * b_vec + template + CK_TILE_DEVICE void operator()(CVecType& c_vec, + const AVecType& a_vec, + const BVecType& b_vec, + bool_constant = {}) const + { + DISPATCH_MFMA_CTRL_("v_mfma_f32_32x32x16_f16", Ctrl) + else + { +#if defined(__gfx950__) + c_vec = __builtin_amdgcn_mfma_f32_32x32x16_f16(a_vec, b_vec, c_vec, 0, 0, 0); +#elif defined(__gfx90a__) || defined(__gfx94__) + static_for<0, 2, 1>{}([&](auto k) { + c_vec = __builtin_amdgcn_mfma_f32_32x32x8f16( + reinterpret_cast&>(a_vec) + .template get_as>()[number{}], + reinterpret_cast&>(b_vec) + .template get_as>()[number{}], + c_vec, + 0, + 0, + 0); + }); +#elif defined(__gfx908__) + static_for<0, 4, 1>{}([&](auto k) { + c_vec = __builtin_amdgcn_mfma_f32_32x32x4f16( + reinterpret_cast&>(a_vec) + .template get_as>()[number{}], + reinterpret_cast&>(b_vec) + .template get_as>()[number{}], + c_vec, + 0, + 0, + 0); + }); +#else + ck_tile::ignore = c_vec; + ck_tile::ignore = a_vec; + ck_tile::ignore = b_vec; +#endif + } + } + + // c_vec = a_vec * b_vec + CK_TILE_DEVICE CVecType operator()(const AVecType& a_vec, const BVecType& b_vec) const + { +#if defined(__gfx950__) + return __builtin_amdgcn_mfma_f32_32x32x16_f16(a_vec, b_vec, fp32x16_t{0.f}, 0, 0, 0); +#elif defined(__gfx90a__) || defined(__gfx94__) + CVecType c_vec{0.f}; + static_for<0, 2, 1>{}([&](auto k) { + c_vec = __builtin_amdgcn_mfma_f32_32x32x8f16( + reinterpret_cast&>(a_vec) + .template get_as>()[number{}], + reinterpret_cast&>(b_vec) + .template get_as>()[number{}], + c_vec, + 0, + 0, + 0); + }); + return c_vec; +#elif defined(__gfx908__) + CVecType c_vec{0.f}; + static_for<0, 4, 1>{}([&](auto k) { + c_vec = __builtin_amdgcn_mfma_f32_32x32x4f16( + reinterpret_cast&>(a_vec) + .template get_as>()[number{}], + reinterpret_cast&>(b_vec) + .template get_as>()[number{}], + c_vec, + 0, + 0, + 0); + }); + return c_vec; +#else + ck_tile::ignore = a_vec; + ck_tile::ignore = b_vec; + return CVecType{0.f}; +#endif + } +}; + +template +struct 
WarpGemmAttributeMfmaImplBf16Bf16F32M32N32K16 +{ + static constexpr WGAttrCtlEnum Ctrl = Ctrl_; + using ADataType = bf16_t; + using BDataType = bf16_t; + using CDataType = float; + + using AVecType = ext_vector_t; + using BVecType = ext_vector_t; + using CVecType = ext_vector_t; + + static constexpr index_t kM = 32; + static constexpr index_t kN = 32; + static constexpr index_t kK = 16; + + static constexpr index_t kAMBlock = 1; + static constexpr index_t kBNBlock = 1; + + static constexpr index_t kAMLane = 32; + static constexpr index_t kBNLane = 32; + static constexpr index_t kABKLane = 2; + static constexpr index_t kABKPerLane = 8; + + static constexpr index_t kCMLane = 2; + static constexpr index_t kCNLane = 32; + static constexpr index_t kCM0PerLane = 4; + static constexpr index_t kCM1PerLane = 4; + + // c_vec += a_vec * b_vec + template + CK_TILE_DEVICE void operator()(CVecType& c_vec, + const AVecType& a_vec, + const BVecType& b_vec, + bool_constant = {}) const + { + DISPATCH_MFMA_CTRL_("v_mfma_f32_32x32x16_bf16", Ctrl) + else + { +#if defined(__gfx950__) + c_vec = __builtin_amdgcn_mfma_f32_32x32x16_bf16(a_vec, b_vec, c_vec, 0, 0, 0); +#elif defined(__gfx90a__) || defined(__gfx94__) + static_for<0, 2, 1>{}([&](auto k) { + c_vec = __builtin_amdgcn_mfma_f32_32x32x8bf16_1k( + reinterpret_cast&>(a_vec) + .template get_as>()[number{}], + reinterpret_cast&>(b_vec) + .template get_as>()[number{}], + c_vec, + 0, + 0, + 0); + }); +#elif defined(__gfx908__) + static_for<0, 4, 1>{}([&](auto k) { + c_vec = __builtin_amdgcn_mfma_f32_32x32x4bf16( + reinterpret_cast&>(a_vec) + .template get_as>()[number{}], + reinterpret_cast&>(b_vec) + .template get_as>()[number{}], + c_vec, + 0, + 0, + 0); + }); +#else + ck_tile::ignore = c_vec; + ck_tile::ignore = a_vec; + ck_tile::ignore = b_vec; +#endif + } + } + + // c_vec = a_vec * b_vec + CK_TILE_DEVICE CVecType operator()(const AVecType& a_vec, const BVecType& b_vec) const + { +#if defined(__gfx950__) + return 
__builtin_amdgcn_mfma_f32_32x32x16_bf16(a_vec, b_vec, fp32x16_t{0.f}, 0, 0, 0); +#elif defined(__gfx90a__) || defined(__gfx94__) + CVecType c_vec{0.f}; + static_for<0, 2, 1>{}([&](auto k) { + c_vec = __builtin_amdgcn_mfma_f32_32x32x8bf16_1k( + reinterpret_cast&>(a_vec) + .template get_as>()[number{}], + reinterpret_cast&>(b_vec) + .template get_as>()[number{}], + c_vec, + 0, + 0, + 0); + }); + return c_vec; +#elif defined(__gfx908__) + CVecType c_vec{0.f}; + static_for<0, 4, 1>{}([&](auto k) { + c_vec = __builtin_amdgcn_mfma_f32_32x32x4bf16( + reinterpret_cast&>(a_vec) + .template get_as>()[number{}], + reinterpret_cast&>(b_vec) + .template get_as>()[number{}], + c_vec, + 0, + 0, + 0); + }); + return c_vec; +#else + ck_tile::ignore = a_vec; + ck_tile::ignore = b_vec; + return CVecType{0.f}; +#endif + } +}; + // FP8 template struct WarpGemmAttributeMfmaImpl_f32_16x16x32_f8_base diff --git a/include/ck_tile/ops/gemm/warp/warp_gemm_attribute_smfmac.hpp b/include/ck_tile/ops/gemm/warp/warp_gemm_attribute_smfmac.hpp index adf548aaca..84cdf17d66 100644 --- a/include/ck_tile/ops/gemm/warp/warp_gemm_attribute_smfmac.hpp +++ b/include/ck_tile/ops/gemm/warp/warp_gemm_attribute_smfmac.hpp @@ -1,3 +1,8 @@ +// SPDX-License-Identifier: MIT +// Copyright (c) 2025, Advanced Micro Devices, Inc. All rights reserved. 
+ +#pragma once + #include "ck_tile/core.hpp" #include "ck_tile/ops/gemm/warp/warp_gemm_attribute_smfmac_impl.hpp" From 02ce6d39ea11b06d583da04a5d3feb4cb66a55a0 Mon Sep 17 00:00:00 2001 From: rocking Date: Thu, 24 Apr 2025 18:52:58 +0800 Subject: [PATCH 066/443] Only generate specific hdim (#2120) --- .../ck_tile/01_fmha/codegen/ops/fmha_bwd.py | 8 +++++-- .../ck_tile/01_fmha/codegen/ops/fmha_fwd.py | 13 ++++++---- .../01_fmha/codegen/ops/fmha_fwd_appendkv.py | 6 +++-- .../01_fmha/codegen/ops/fmha_fwd_splitkv.py | 10 ++++---- example/ck_tile/01_fmha/generate.py | 24 ++++++++++++++----- 5 files changed, 42 insertions(+), 19 deletions(-) diff --git a/example/ck_tile/01_fmha/codegen/ops/fmha_bwd.py b/example/ck_tile/01_fmha/codegen/ops/fmha_bwd.py index 1e6755c631..932f6020b6 100644 --- a/example/ck_tile/01_fmha/codegen/ops/fmha_bwd.py +++ b/example/ck_tile/01_fmha/codegen/ops/fmha_bwd.py @@ -866,9 +866,11 @@ def write_single_bwd_convert_dq_kernel(kernel: FmhaBwdConvertQGradKernel, autoge def write_bwd_api(api_pool : FmhaBwdApiPool, autogen_dir: Path) -> None: (autogen_dir / FMHA_BWD_API_FILENAME).write_text(api_pool.api) -def write_blobs(output_dir : Path, filter_list : str, receipt, mask_impl) -> None: +def write_blobs(output_dir : Path, filter_list : str, receipt, optdim_list, mask_impl) -> None: filter_list = filter_list.split('@') filter_list.extend([''] * (3 - len(filter_list))) + # TODO + assert optdim_list == [-1] kernels = get_bwd_dot_do_o_blobs(filter_list[0], receipt) for kernel in kernels: @@ -881,9 +883,11 @@ def write_blobs(output_dir : Path, filter_list : str, receipt, mask_impl) -> Non write_single_bwd_dq_dk_dv_kernel(kernel, output_dir) write_bwd_api(api_pool, output_dir) -def list_blobs(file_path : Path, filter_list : str, receipt, mask_impl) -> None: +def list_blobs(file_path : Path, filter_list : str, receipt, optdim_list, mask_impl) -> None: filter_list = filter_list.split('@') filter_list.extend([''] * (3 - len(filter_list))) + # TODO + assert 
optdim_list == [-1] with file_path.open('a') as f: kernels = get_bwd_dot_do_o_blobs(filter_list[0], receipt) diff --git a/example/ck_tile/01_fmha/codegen/ops/fmha_fwd.py b/example/ck_tile/01_fmha/codegen/ops/fmha_fwd.py index 3634810b37..c31a0ce954 100644 --- a/example/ck_tile/01_fmha/codegen/ops/fmha_fwd.py +++ b/example/ck_tile/01_fmha/codegen/ops/fmha_fwd.py @@ -429,7 +429,7 @@ def get_fmha_fwd_tile_dict_from_dtype(dtype : str) -> Optional[dict]: else: return None -def get_fwd_blobs(kernel_filter : Optional[str], receipt, mask_impl) -> Tuple[FmhaFwdApiPool, List[FmhaFwdKernel]]: +def get_fwd_blobs(kernel_filter : Optional[str], receipt, optdim_list, mask_impl) -> Tuple[FmhaFwdApiPool, List[FmhaFwdKernel]]: # TODO: we don't support tuning yet, so pick up one value for vlayout/pipeline/pad # support this in future def get_pipelines(dtype, hdim) -> List[FmhaFwdPipeline]: @@ -507,6 +507,9 @@ def get_fwd_blobs(kernel_filter : Optional[str], receipt, mask_impl) -> Tuple[Fm if kernel_filter != '': if not fnmatch.fnmatch(k.name, kernel_filter): continue + if optdim_list != [-1]: + if hdim not in optdim_list: + continue # 2 - Flash attention integration if receipt in (2, 3): cond = dtype in ['fp16', 'bf16'] @@ -557,15 +560,15 @@ def write_single_fwd_kernel(kernel: FmhaFwdKernel, autogen_dir: Path) -> None: def write_fwd_api(api_pool : FmhaFwdApiPool, autogen_dir: Path) -> None: (autogen_dir / FMHA_FWD_API_FILENAME).write_text(api_pool.api) -def write_blobs(output_dir : Path, kernel_filter : str, receipt, mask_impl) -> None: - api_pool, kernels = get_fwd_blobs(kernel_filter, receipt, mask_impl) +def write_blobs(output_dir : Path, kernel_filter : str, receipt, optdim_list, mask_impl) -> None: + api_pool, kernels = get_fwd_blobs(kernel_filter, receipt, optdim_list, mask_impl) for kernel in kernels: write_single_fwd_kernel(kernel, output_dir) write_fwd_api(api_pool, output_dir) -def list_blobs(file_path : Path, kernel_filter : str, receipt, mask_impl) -> None: +def 
list_blobs(file_path : Path, kernel_filter : str, receipt, optdim_list, mask_impl) -> None: with file_path.open('a') as f: - _, kernels = get_fwd_blobs(kernel_filter, receipt, mask_impl) + _, kernels = get_fwd_blobs(kernel_filter, receipt, optdim_list, mask_impl) for kernel in kernels: f.write(str(file_path.parent / GEN_DIR / kernel.filename) + "\n") f.write(str(file_path.parent / GEN_DIR / FMHA_FWD_API_FILENAME) + "\n") diff --git a/example/ck_tile/01_fmha/codegen/ops/fmha_fwd_appendkv.py b/example/ck_tile/01_fmha/codegen/ops/fmha_fwd_appendkv.py index f243020dc4..dc7ef712e2 100644 --- a/example/ck_tile/01_fmha/codegen/ops/fmha_fwd_appendkv.py +++ b/example/ck_tile/01_fmha/codegen/ops/fmha_fwd_appendkv.py @@ -343,13 +343,15 @@ def write_single_kernel(kernel: FmhaFwdAppendKVKernel, autogen_dir: Path) -> Non def write_fwd_appendkv_api(api_pool : FmhaFwdAppendKVApiPool, autogen_dir: Path) -> None: (autogen_dir / FMHA_FWD_APPENDKV_API_FILENAME).write_text(api_pool.api) -def write_blobs(output_dir : Path, kernel_filter : Optional[str], receipt, mask_impl) -> None: +def write_blobs(output_dir : Path, kernel_filter : Optional[str], receipt, optdim_list, mask_impl) -> None: + assert optdim_list == [-1] api_pool, kernels = get_fwd_appendkv_blobs(kernel_filter, receipt, mask_impl) for kernel in kernels: write_single_kernel(kernel, output_dir) write_fwd_appendkv_api(api_pool, output_dir) -def list_blobs(file_path : Path, kernel_filter : Optional[str], receipt, mask_impl) -> None: +def list_blobs(file_path : Path, kernel_filter : Optional[str], receipt, optdim_list, mask_impl) -> None: + assert optdim_list == [-1] with file_path.open('a') as f: _, kernels = get_fwd_appendkv_blobs(kernel_filter, receipt, mask_impl) for kernel in kernels: diff --git a/example/ck_tile/01_fmha/codegen/ops/fmha_fwd_splitkv.py b/example/ck_tile/01_fmha/codegen/ops/fmha_fwd_splitkv.py index 0dccdf6bd6..ca49af1496 100644 --- a/example/ck_tile/01_fmha/codegen/ops/fmha_fwd_splitkv.py +++ 
b/example/ck_tile/01_fmha/codegen/ops/fmha_fwd_splitkv.py @@ -440,10 +440,10 @@ class FmhaFwdSplitKVCombinePipeline: n = f'{self.tag}' if pn != '' : n += f'_{pn}' else: n += '_npad' - + if self.F_lse == 't' : n += '_lse' else: n += '_nlse' - + if self.F_squant == 't' : n += '_squant' else: n += '_nsquant' return n @@ -819,9 +819,10 @@ def write_fwd_splitkv_api(api_pool : FmhaFwdSplitKVApiPool, autogen_dir: Path) - file_path = autogen_dir / FMHA_FWD_SPLITKV_API_FILENAME file_path.write_text(api_pool.api) -def write_blobs(output_dir : Path, filter_list : str, receipt, mask_impl) -> None: +def write_blobs(output_dir : Path, filter_list : str, receipt, optdim_list, mask_impl) -> None: filter_list = filter_list.split('@') filter_list.extend([''] * (2 - len(filter_list))) + assert optdim_list == [-1] kernels = get_fwd_splitkv_combine_blobs(filter_list[0], receipt) for kernel in kernels: @@ -831,9 +832,10 @@ def write_blobs(output_dir : Path, filter_list : str, receipt, mask_impl) -> Non write_single_kernel(kernel, output_dir) write_fwd_splitkv_api(api_pool, output_dir) -def list_blobs(file_path : Path, filter_list : str, receipt, mask_impl) -> None: +def list_blobs(file_path : Path, filter_list : str, receipt, optdim_list, mask_impl) -> None: filter_list = filter_list.split('@') filter_list.extend([''] * (2 - len(filter_list))) + assert optdim_list == [-1] with file_path.open('a') as f: kernels = get_fwd_splitkv_combine_blobs(filter_list[0], receipt) diff --git a/example/ck_tile/01_fmha/generate.py b/example/ck_tile/01_fmha/generate.py index 25931da141..c2b0924eb3 100644 --- a/example/ck_tile/01_fmha/generate.py +++ b/example/ck_tile/01_fmha/generate.py @@ -30,7 +30,7 @@ handlers = dict( ) assert 0 < len(handlers) -def write_blobs(output_dir: Optional[str], api_list : List[str], filters_list : List[str], receipt, mask_impl) -> None: +def write_blobs(output_dir: Optional[str], api_list : List[str], filters_list : List[str], optdim_list : List[int], receipt, mask_impl) -> 
None: if output_dir is None: output_dir = Path(__file__).parent else: @@ -40,10 +40,10 @@ def write_blobs(output_dir: Optional[str], api_list : List[str], filters_list : for api, kernel_filter in zip(api_list, filters_list): handler = handlers[api][HandlerId.WRITE_BLOBS] - handler(output_dir, kernel_filter, receipt, mask_impl) + handler(output_dir, kernel_filter, receipt, optdim_list, mask_impl) # list all the files that will be generated -def list_blobs(output_file : Optional[str], api_list : List[str], filters_list : List[str], receipt, mask_impl) -> None: +def list_blobs(output_file : Optional[str], api_list : List[str], filters_list : List[str], optdim_list : List[int], receipt, mask_impl) -> None: assert output_file is not None file_path = Path(output_file) @@ -52,7 +52,7 @@ def list_blobs(output_file : Optional[str], api_list : List[str], filters_list : for api, kernel_filter in zip(api_list, filters_list): handler = handlers[api][HandlerId.LIST_BLOBS] - handler(file_path, kernel_filter, receipt, mask_impl) + handler(file_path, kernel_filter, receipt, optdim_list, mask_impl) if __name__ == "__main__": parser = argparse.ArgumentParser( @@ -113,12 +113,24 @@ if __name__ == "__main__": " 600-699: Only generate instance for aiter::mha_fwd && aiter::mha_fwd_splitkv && aiter::mha_bwd C++ api integration" ) + parser.add_argument( + "--optdim", + default='-1', + required=False, + help="only optimize the hdim in the list. separated by comma. -1 is the default choice" + \ + "eg. 
--optdim=32,64,128,256" + ) + args = parser.parse_args() api_list = args.direction.split(',') filter_list = args.filter.split(',') filter_list.extend([''] * (len(api_list) - len(filter_list))) + optdim_list = [int(hdim) for hdim in args.optdim.split(',')] + + if len(api_list) > 1: + assert optdim_list == [-1] if args.list_blobs is not None: - list_blobs(args.list_blobs, api_list, filter_list, int(args.receipt), mask_impl=args.mask) + list_blobs(args.list_blobs, api_list, filter_list, optdim_list, int(args.receipt), mask_impl=args.mask) else: - write_blobs(args.output_dir, api_list, filter_list, int(args.receipt), mask_impl=args.mask) + write_blobs(args.output_dir, api_list, filter_list, optdim_list, int(args.receipt), mask_impl=args.mask) From ba97363acd615efba5a4c5e3e0553c3ee14e566f Mon Sep 17 00:00:00 2001 From: alexxu-amd <159800977+alexxu-amd@users.noreply.github.com> Date: Thu, 24 Apr 2025 11:35:06 -0400 Subject: [PATCH 067/443] Setup Doxygen API reference for Docs (#2115) * setup Doxygen settings * add api_reference to requirements.txt * add doxygen file header * omit latex generation * remove testing entry * update Doxyfile --- docs/conf.py | 1 + docs/doxygen/Doxyfile | 938 +++++++++++++++++++++++------------ docs/sphinx/requirements.in | 2 +- docs/sphinx/requirements.txt | 143 ++++-- 4 files changed, 724 insertions(+), 360 deletions(-) diff --git a/docs/conf.py b/docs/conf.py index e8617a09ef..fe8a1c1d79 100644 --- a/docs/conf.py +++ b/docs/conf.py @@ -28,6 +28,7 @@ external_toc_path = "./sphinx/_toc.yml" docs_core = ROCmDocs(left_nav_title) docs_core.run_doxygen(doxygen_root="doxygen", doxygen_path="doxygen/xml") +docs_core.enable_api_reference() docs_core.setup() external_projects_current_project = "composable_kernel" diff --git a/docs/doxygen/Doxyfile b/docs/doxygen/Doxyfile index fac9e138e1..d6f38e0ca9 100644 --- a/docs/doxygen/Doxyfile +++ b/docs/doxygen/Doxyfile @@ -1,4 +1,4 @@ -# Doxyfile 1.8.10 +# Doxyfile 1.9.7 # This file describes the settings to 
be used by the documentation system # doxygen (www.doxygen.org) for a project. @@ -12,16 +12,26 @@ # For lists, items can also be appended using: # TAG += value [value, ...] # Values that contain spaces should be placed between quotes (\" \"). +# +# Note: +# +# Use doxygen to compare the used configuration file with the template +# configuration file: +# doxygen -x [configFile] +# Use doxygen to compare the used configuration file with the template +# configuration file without replacing the environment variables or CMake type +# replacement variables: +# doxygen -x_noenv [configFile] #--------------------------------------------------------------------------- # Project related configuration options #--------------------------------------------------------------------------- -# This tag specifies the encoding used for all characters in the config file -# that follow. The default is UTF-8 which is also the encoding used for all text -# before the first occurrence of this tag. Doxygen uses libiconv (or the iconv -# built into libc) for the transcoding. See http://www.gnu.org/software/libiconv -# for the list of possible encodings. +# This tag specifies the encoding used for all characters in the configuration +# file that follow. The default is UTF-8 which is also the encoding used for all +# text before the first occurrence of this tag. Doxygen uses libiconv (or the +# iconv built into libc) for the transcoding. See +# https://www.gnu.org/software/libiconv/ for the list of possible encodings. # The default value is: UTF-8. DOXYFILE_ENCODING = UTF-8 @@ -44,14 +54,14 @@ PROJECT_NUMBER = v3.0.1.0 # for a project that appears at the top of each page and should give viewer a # quick idea about the purpose of the project. Keep the description short. 
-PROJECT_BRIEF = "prototype interfaces compatible with ROCm platform and HiP" +PROJECT_BRIEF = "prototype interfaces compatible with ROCm platform and HIP" # With the PROJECT_LOGO tag one can specify a logo or an icon that is included # in the documentation. The maximum height of the logo should not exceed 55 # pixels and the maximum width should not exceed 200 pixels. Doxygen will copy # the logo to the output directory. -PROJECT_LOGO = +PROJECT_LOGO = # The OUTPUT_DIRECTORY tag is used to specify the (relative or absolute) path # into which the generated documentation will be written. If a relative path is @@ -60,16 +70,28 @@ PROJECT_LOGO = OUTPUT_DIRECTORY = . -# If the CREATE_SUBDIRS tag is set to YES then doxygen will create 4096 sub- -# directories (in 2 levels) under the output directory of each output format and -# will distribute the generated files over these directories. Enabling this +# If the CREATE_SUBDIRS tag is set to YES then doxygen will create up to 4096 +# sub-directories (in 2 levels) under the output directory of each output format +# and will distribute the generated files over these directories. Enabling this # option can be useful when feeding doxygen a huge amount of source files, where # putting all generated files in the same directory would otherwise causes -# performance problems for the file system. +# performance problems for the file system. Adapt CREATE_SUBDIRS_LEVEL to +# control the number of sub-directories. # The default value is: NO. CREATE_SUBDIRS = NO +# Controls the number of sub-directories that will be created when +# CREATE_SUBDIRS tag is set to YES. Level 0 represents 16 directories, and every +# level increment doubles the number of directories, resulting in 4096 +# directories at level 8 which is the default and also the maximum value. The +# sub-directories are organized in 2 levels, the first level always has a fixed +# number of 16 directories. +# Minimum value: 0, maximum value: 8, default value: 8. 
+# This tag requires that the tag CREATE_SUBDIRS is set to YES. + +CREATE_SUBDIRS_LEVEL = 8 + # If the ALLOW_UNICODE_NAMES tag is set to YES, doxygen will allow non-ASCII # characters to appear in the names of generated files. If set to NO, non-ASCII # characters will be escaped, for example _xE3_x81_x84 will be used for Unicode @@ -81,14 +103,14 @@ ALLOW_UNICODE_NAMES = NO # The OUTPUT_LANGUAGE tag is used to specify the language in which all # documentation generated by doxygen is written. Doxygen will use this # information to generate all constant output in the proper language. -# Possible values are: Afrikaans, Arabic, Armenian, Brazilian, Catalan, Chinese, -# Chinese-Traditional, Croatian, Czech, Danish, Dutch, English (United States), -# Esperanto, Farsi (Persian), Finnish, French, German, Greek, Hungarian, -# Indonesian, Italian, Japanese, Japanese-en (Japanese with English messages), -# Korean, Korean-en (Korean with English messages), Latvian, Lithuanian, -# Macedonian, Norwegian, Persian (Farsi), Polish, Portuguese, Romanian, Russian, -# Serbian, Serbian-Cyrillic, Slovak, Slovene, Spanish, Swedish, Turkish, -# Ukrainian and Vietnamese. +# Possible values are: Afrikaans, Arabic, Armenian, Brazilian, Bulgarian, +# Catalan, Chinese, Chinese-Traditional, Croatian, Czech, Danish, Dutch, English +# (United States), Esperanto, Farsi (Persian), Finnish, French, German, Greek, +# Hindi, Hungarian, Indonesian, Italian, Japanese, Japanese-en (Japanese with +# English messages), Korean, Korean-en (Korean with English messages), Latvian, +# Lithuanian, Macedonian, Norwegian, Persian (Farsi), Polish, Portuguese, +# Romanian, Russian, Serbian, Serbian-Cyrillic, Slovak, Slovene, Spanish, +# Swedish, Turkish, Ukrainian and Vietnamese. # The default value is: English. OUTPUT_LANGUAGE = English @@ -162,7 +184,8 @@ FULL_PATH_NAMES = YES # will be relative from the directory where doxygen is started. # This tag requires that the tag FULL_PATH_NAMES is set to YES. 
-STRIP_FROM_PATH = +#STRIP_FROM_PATH = +STRIP_FROM_PATH = /home/docs/checkouts/readthedocs.org/user_builds/advanced-micro-devices-composable-kernel/checkouts/latest/ # The STRIP_FROM_INC_PATH tag can be used to strip a user-defined part of the # path mentioned in the documentation of a class, which tells the reader which @@ -171,7 +194,8 @@ STRIP_FROM_PATH = # specify the list of include paths that are normally passed to the compiler # using the -I flag. -STRIP_FROM_INC_PATH = +STRIP_FROM_INC_PATH = + # If the SHORT_NAMES tag is set to YES, doxygen will generate much shorter (but # less readable) file names. This can be useful is your file systems doesn't @@ -189,6 +213,16 @@ SHORT_NAMES = NO JAVADOC_AUTOBRIEF = NO +# If the JAVADOC_BANNER tag is set to YES then doxygen will interpret a line +# such as +# /*************** +# as being the beginning of a Javadoc-style comment "banner". If set to NO, the +# Javadoc-style will behave just like regular comments and it will not be +# interpreted by doxygen. +# The default value is: NO. + +JAVADOC_BANNER = NO + # If the QT_AUTOBRIEF tag is set to YES then doxygen will interpret the first # line (until the first dot) of a Qt-style comment as the brief description. If # set to NO, the Qt-style will behave just like regular Qt-style comments (thus @@ -209,6 +243,14 @@ QT_AUTOBRIEF = NO MULTILINE_CPP_IS_BRIEF = NO +# By default Python docstrings are displayed as preformatted text and doxygen's +# special commands cannot be used. By setting PYTHON_DOCSTRING to NO the +# doxygen's special commands can be used and the contents of the docstring +# documentation blocks is shown as doxygen documentation. +# The default value is: YES. + +PYTHON_DOCSTRING = YES + # If the INHERIT_DOCS tag is set to YES then an undocumented member inherits the # documentation from any documented member that it re-implements. # The default value is: YES. @@ -232,20 +274,19 @@ TAB_SIZE = 4 # the documentation. 
An alias has the form: # name=value # For example adding -# "sideeffect=@par Side Effects:\n" +# "sideeffect=@par Side Effects:^^" # will allow you to put the command \sideeffect (or @sideeffect) in the # documentation, which will result in a user-defined paragraph with heading -# "Side Effects:". You can put \n's in the value part of an alias to insert -# newlines. +# "Side Effects:". Note that you cannot put \n's in the value part of an alias +# to insert newlines (in the resulting output). You can put ^^ in the value part +# of an alias to insert a newline as if a physical newline was in the original +# file. When you need a literal { or } or , in the value part of an alias you +# have to escape them by means of a backslash (\), this can lead to conflicts +# with the commands \{ and \} for these it is advised to use the version @{ and +# @} or use a double escape (\\{ and \\}) ALIASES = -# This tag can be used to specify a number of word-keyword mappings (TCL only). -# A mapping has the form "name=value". For example adding "class=itcl::class" -# will allow you to use the command class in the itcl::class meaning. - -TCL_SUBST = - # Set the OPTIMIZE_OUTPUT_FOR_C tag to YES if your project consists of C sources # only. Doxygen will then generate output that is more tailored for C. For # instance, some of the names that are used will be different. The list of all @@ -274,28 +315,40 @@ OPTIMIZE_FOR_FORTRAN = NO OPTIMIZE_OUTPUT_VHDL = NO +# Set the OPTIMIZE_OUTPUT_SLICE tag to YES if your project consists of Slice +# sources only. Doxygen will then generate output that is more tailored for that +# language. For instance, namespaces will be presented as modules, types will be +# separated into more groups, etc. +# The default value is: NO. + +OPTIMIZE_OUTPUT_SLICE = NO + # Doxygen selects the parser to use depending on the extension of the files it # parses. With this tag you can assign which parser to use for a given # extension. 
Doxygen has a built-in mapping, but you can override or extend it # using this tag. The format is ext=language, where ext is a file extension, and -# language is one of the parsers supported by doxygen: IDL, Java, Javascript, -# C#, C, C++, D, PHP, Objective-C, Python, Fortran (fixed format Fortran: -# FortranFixed, free formatted Fortran: FortranFree, unknown formatted Fortran: -# Fortran. In the later case the parser tries to guess whether the code is fixed -# or free formatted code, this is the default for Fortran type files), VHDL. For -# instance to make doxygen treat .inc files as Fortran files (default is PHP), -# and .f files as C (default is Fortran), use: inc=Fortran f=C. +# language is one of the parsers supported by doxygen: IDL, Java, JavaScript, +# Csharp (C#), C, C++, Lex, D, PHP, md (Markdown), Objective-C, Python, Slice, +# VHDL, Fortran (fixed format Fortran: FortranFixed, free formatted Fortran: +# FortranFree, unknown formatted Fortran: Fortran. In the later case the parser +# tries to guess whether the code is fixed or free formatted code, this is the +# default for Fortran type files). For instance to make doxygen treat .inc files +# as Fortran files (default is PHP), and .f files as C (default is Fortran), +# use: inc=Fortran f=C. # # Note: For files without extension you can use no_extension as a placeholder. # # Note that for custom extensions you also need to set FILE_PATTERNS otherwise -# the files are not read by doxygen. +# the files are not read by doxygen. When specifying no_extension you should add +# * to the FILE_PATTERNS. +# +# Note see also the list of default file extension mappings. EXTENSION_MAPPING = # If the MARKDOWN_SUPPORT tag is enabled then doxygen pre-processes all comments # according to the Markdown format, which allows for more readable -# documentation. See http://daringfireball.net/projects/markdown/ for details. +# documentation. See https://daringfireball.net/projects/markdown/ for details. 
# The output of markdown processing is further processed by doxygen, so you can # mix doxygen, HTML, and XML commands with Markdown formatting. Disable only in # case of backward compatibilities issues. @@ -303,6 +356,26 @@ EXTENSION_MAPPING = MARKDOWN_SUPPORT = YES +# When the TOC_INCLUDE_HEADINGS tag is set to a non-zero value, all headings up +# to that level are automatically included in the table of contents, even if +# they do not have an id attribute. +# Note: This feature currently applies only to Markdown headings. +# Minimum value: 0, maximum value: 99, default value: 5. +# This tag requires that the tag MARKDOWN_SUPPORT is set to YES. + +TOC_INCLUDE_HEADINGS = 5 + +# The MARKDOWN_ID_STYLE tag can be used to specify the algorithm used to +# generate identifiers for the Markdown headings. Note: Every identifier is +# unique. +# Possible values are: DOXYGEN Use a fixed 'autotoc_md' string followed by a +# sequence number starting at 0. and GITHUB Use the lower case version of title +# with any whitespace replaced by '-' and punctations characters removed.. +# The default value is: DOXYGEN. +# This tag requires that the tag MARKDOWN_SUPPORT is set to YES. + +MARKDOWN_ID_STYLE = DOXYGEN + # When enabled doxygen tries to link words that correspond to documented # classes, or namespaces to their corresponding documentation. Such a link can # be prevented in individual cases by putting a % sign in front of the word or @@ -328,7 +401,7 @@ BUILTIN_STL_SUPPORT = YES CPP_CLI_SUPPORT = NO # Set the SIP_SUPPORT tag to YES if your project consists of sip (see: -# http://www.riverbankcomputing.co.uk/software/sip/intro) sources only. Doxygen +# https://www.riverbankcomputing.com/software/sip/intro) sources only. Doxygen # will parse them like normal C++ but will assume all classes use public instead # of private inheritance when no explicit protection keyword is present. # The default value is: NO. 
@@ -414,6 +487,27 @@ TYPEDEF_HIDES_STRUCT = YES LOOKUP_CACHE_SIZE = 0 +# The NUM_PROC_THREADS specifies the number of threads doxygen is allowed to use +# during processing. When set to 0 doxygen will based this on the number of +# cores available in the system. You can set it explicitly to a value larger +# than 0 to get more control over the balance between CPU load and processing +# speed. At this moment only the input processing can be done using multiple +# threads. Since this is still an experimental feature the default is set to 1, +# which effectively disables parallel processing. Please report any issues you +# encounter. Generating dot graphs in parallel is controlled by the +# DOT_NUM_THREADS setting. +# Minimum value: 0, maximum value: 32, default value: 1. + +NUM_PROC_THREADS = 1 + +# If the TIMESTAMP tag is set different from NO then each generated page will +# contain the date or date and time when the page was generated. Setting this to +# NO can help when comparing the output of multiple runs. +# Possible values are: YES, NO, DATETIME and DATE. +# The default value is: NO. + +TIMESTAMP = YES + #--------------------------------------------------------------------------- # Build related configuration options #--------------------------------------------------------------------------- @@ -434,6 +528,12 @@ EXTRACT_ALL = YES EXTRACT_PRIVATE = NO +# If the EXTRACT_PRIV_VIRTUAL tag is set to YES, documented private virtual +# methods of a class will be included in the documentation. +# The default value is: NO. + +EXTRACT_PRIV_VIRTUAL = NO + # If the EXTRACT_PACKAGE tag is set to YES, all members with package or internal # scope will be included in the documentation. # The default value is: NO. @@ -471,6 +571,13 @@ EXTRACT_LOCAL_METHODS = NO EXTRACT_ANON_NSPACES = NO +# If this flag is set to YES, the name of an unnamed parameter in a declaration +# will be determined by the corresponding definition. 
By default unnamed +# parameters remain unnamed in the output. +# The default value is: YES. + +RESOLVE_UNNAMED_PARAMS = YES + # If the HIDE_UNDOC_MEMBERS tag is set to YES, doxygen will hide all # undocumented members inside documented classes or files. If set to NO these # members will be included in the various overviews, but no documentation @@ -482,14 +589,15 @@ HIDE_UNDOC_MEMBERS = NO # If the HIDE_UNDOC_CLASSES tag is set to YES, doxygen will hide all # undocumented classes that are normally visible in the class hierarchy. If set # to NO, these classes will be included in the various overviews. This option -# has no effect if EXTRACT_ALL is enabled. +# will also hide undocumented C++ concepts if enabled. This option has no effect +# if EXTRACT_ALL is enabled. # The default value is: NO. HIDE_UNDOC_CLASSES = NO # If the HIDE_FRIEND_COMPOUNDS tag is set to YES, doxygen will hide all friend -# (class|struct|union) declarations. If set to NO, these declarations will be -# included in the documentation. +# declarations. If set to NO, these declarations will be included in the +# documentation. # The default value is: NO. HIDE_FRIEND_COMPOUNDS = NO @@ -508,12 +616,20 @@ HIDE_IN_BODY_DOCS = NO INTERNAL_DOCS = NO -# If the CASE_SENSE_NAMES tag is set to NO then doxygen will only generate file -# names in lower-case letters. If set to YES, upper-case letters are also -# allowed. This is useful if you have classes or files whose names only differ -# in case and if your file system supports case sensitive file names. Windows -# and Mac users are advised to set this option to NO. -# The default value is: system dependent. +# With the correct setting of option CASE_SENSE_NAMES doxygen will better be +# able to match the capabilities of the underlying filesystem. In case the +# filesystem is case sensitive (i.e. 
it supports files in the same directory +# whose names only differ in casing), the option must be set to YES to properly +# deal with such files in case they appear in the input. For filesystems that +# are not case sensitive the option should be set to NO to properly deal with +# output files written for symbols that only differ in casing, such as for two +# classes, one named CLASS and the other named Class, and to also support +# references to files without having to specify the exact matching casing. On +# Windows (including Cygwin) and MacOS, users should typically set this option +# to NO, whereas on Linux or other Unix flavors it should typically be set to +# YES. +# Possible values are: SYSTEM, NO and YES. +# The default value is: SYSTEM. CASE_SENSE_NAMES = NO @@ -531,6 +647,12 @@ HIDE_SCOPE_NAMES = NO HIDE_COMPOUND_REFERENCE= NO +# If the SHOW_HEADERFILE tag is set to YES then the documentation for a class +# will show which file needs to be included to use the class. +# The default value is: YES. + +SHOW_HEADERFILE = YES + # If the SHOW_INCLUDE_FILES tag is set to YES then doxygen will put a list of # the files that are included by a file in the documentation of that file. # The default value is: YES. @@ -688,7 +810,8 @@ FILE_VERSION_FILTER = # output files in an output format independent way. To create the layout file # that represents doxygen's defaults, run doxygen with the -l option. You can # optionally specify a file name after the option, if omitted DoxygenLayout.xml -# will be used as the name of the layout file. +# will be used as the name of the layout file. See also section "Changing the +# layout of pages" for information. # # Note that if you run doxygen from a directory containing a file called # DoxygenLayout.xml, doxygen will parse it automatically even if the LAYOUT_FILE @@ -699,7 +822,7 @@ LAYOUT_FILE = # The CITE_BIB_FILES tag can be used to specify one or more bib files containing # the reference definitions. 
This must be a list of .bib files. The .bib # extension is automatically appended if omitted. This requires the bibtex tool -# to be installed. See also http://en.wikipedia.org/wiki/BibTeX for more info. +# to be installed. See also https://en.wikipedia.org/wiki/BibTeX for more info. # For LaTeX the style of the bibliography can be controlled using # LATEX_BIB_STYLE. To use this feature you need bibtex and perl available in the # search path. See also \cite for info how to create references. @@ -734,34 +857,81 @@ WARNINGS = YES WARN_IF_UNDOCUMENTED = YES # If the WARN_IF_DOC_ERROR tag is set to YES, doxygen will generate warnings for -# potential errors in the documentation, such as not documenting some parameters -# in a documented function, or documenting parameters that don't exist or using -# markup commands wrongly. +# potential errors in the documentation, such as documenting some parameters in +# a documented function twice, or documenting parameters that don't exist or +# using markup commands wrongly. # The default value is: YES. WARN_IF_DOC_ERROR = YES +# If WARN_IF_INCOMPLETE_DOC is set to YES, doxygen will warn about incomplete +# function parameter documentation. If set to NO, doxygen will accept that some +# parameters have no documentation without warning. +# The default value is: YES. + +WARN_IF_INCOMPLETE_DOC = YES + # This WARN_NO_PARAMDOC option can be enabled to get warnings for functions that # are documented, but have no documentation for their parameters or return -# value. If set to NO, doxygen will only warn about wrong or incomplete -# parameter documentation, but not about the absence of documentation. +# value. If set to NO, doxygen will only warn about wrong parameter +# documentation, but not about the absence of documentation. If EXTRACT_ALL is +# set to YES then this flag will automatically be disabled. See also +# WARN_IF_INCOMPLETE_DOC # The default value is: NO. 
WARN_NO_PARAMDOC = NO +# If WARN_IF_UNDOC_ENUM_VAL option is set to YES, doxygen will warn about +# undocumented enumeration values. If set to NO, doxygen will accept +# undocumented enumeration values. If EXTRACT_ALL is set to YES then this flag +# will automatically be disabled. +# The default value is: NO. + +WARN_IF_UNDOC_ENUM_VAL = NO + +# If the WARN_AS_ERROR tag is set to YES then doxygen will immediately stop when +# a warning is encountered. If the WARN_AS_ERROR tag is set to FAIL_ON_WARNINGS +# then doxygen will continue running as if WARN_AS_ERROR tag is set to NO, but +# at the end of the doxygen process doxygen will return with a non-zero status. +# If the WARN_AS_ERROR tag is set to FAIL_ON_WARNINGS_PRINT then doxygen behaves +# like FAIL_ON_WARNINGS but in case no WARN_LOGFILE is defined doxygen will not +# write the warning messages in between other messages but write them at the end +# of a run, in case a WARN_LOGFILE is defined the warning messages will be +# besides being in the defined file also be shown at the end of a run, unless +# the WARN_LOGFILE is defined as - i.e. standard output (stdout) in that case +# the behavior will remain as with the setting FAIL_ON_WARNINGS. +# Possible values are: NO, YES, FAIL_ON_WARNINGS and FAIL_ON_WARNINGS_PRINT. +# The default value is: NO. + +WARN_AS_ERROR = NO + # The WARN_FORMAT tag determines the format of the warning messages that doxygen # can produce. The string should contain the $file, $line, and $text tags, which # will be replaced by the file and line number from which the warning originated # and the warning text. Optionally the format may contain $version, which will # be replaced by the version of the file (if it could be obtained via # FILE_VERSION_FILTER) +# See also: WARN_LINE_FORMAT # The default value is: $file:$line: $text. WARN_FORMAT = "$file:$line: $text" +# In the $text part of the WARN_FORMAT command it is possible that a reference +# to a more specific place is given. 
To make it easier to jump to this place +# (outside of doxygen) the user can define a custom "cut" / "paste" string. +# Example: +# WARN_LINE_FORMAT = "'vi $file +$line'" +# See also: WARN_FORMAT +# The default value is: at line $line of file $file. + +WARN_LINE_FORMAT = "at line $line of file $file" + # The WARN_LOGFILE tag can be used to specify a file to which warning and error # messages should be written. If left blank the output is written to standard -# error (stderr). +# error (stderr). In case the file specified cannot be opened for writing the +# warning and error messages are written to standard error. When as file - is +# specified the warning and error messages are written to standard output +# (stdout). WARN_LOGFILE = @@ -785,12 +955,23 @@ INPUT = ../../include/ck/tensor_operation/gpu/grid \ # This tag can be used to specify the character encoding of the source files # that doxygen parses. Internally doxygen uses the UTF-8 encoding. Doxygen uses # libiconv (or the iconv built into libc) for the transcoding. See the libiconv -# documentation (see: http://www.gnu.org/software/libiconv) for the list of -# possible encodings. +# documentation (see: +# https://www.gnu.org/software/libiconv/) for the list of possible encodings. +# See also: INPUT_FILE_ENCODING # The default value is: UTF-8. INPUT_ENCODING = UTF-8 +# This tag can be used to specify the character encoding of the source files +# that doxygen parses The INPUT_FILE_ENCODING tag can be used to specify +# character encoding on a per file pattern basis. Doxygen will compare the file +# name with each pattern and apply the encoding instead of the default +# INPUT_ENCODING) if there is a match. The character encodings are a list of the +# form: pattern=encoding (like *.php=ISO-8859-1). See cfg_input_encoding +# "INPUT_ENCODING" for further information on supported encodings. 
+ +INPUT_FILE_ENCODING = + # If the value of the INPUT tag contains directories, you can use the # FILE_PATTERNS tag to specify one or more wildcard patterns (like *.cpp and # *.h) to filter out the source-files in the directories. @@ -799,11 +980,15 @@ INPUT_ENCODING = UTF-8 # need to set EXTENSION_MAPPING for the extension otherwise the files are not # read by doxygen. # +# Note the list of default checked file patterns might differ from the list of +# default file extension mappings. +# # If left blank the following patterns are tested:*.c, *.cc, *.cxx, *.cpp, # *.c++, *.java, *.ii, *.ixx, *.ipp, *.i++, *.inl, *.idl, *.ddl, *.odl, *.h, -# *.hh, *.hxx, *.hpp, *.h++, *.cs, *.d, *.php, *.php4, *.php5, *.phtml, *.inc, -# *.m, *.markdown, *.md, *.mm, *.dox, *.py, *.f90, *.f, *.for, *.tcl, *.vhd, -# *.vhdl, *.ucf, *.qsf, *.as and *.js. +# *.hh, *.hxx, *.hpp, *.h++, *.l, *.cs, *.d, *.php, *.php4, *.php5, *.phtml, +# *.inc, *.m, *.markdown, *.md, *.mm, *.dox (to be provided as doxygen C +# comment), *.py, *.pyw, *.f90, *.f95, *.f03, *.f08, *.f18, *.f, *.for, *.vhd, +# *.vhdl, *.ucf, *.qsf and *.ice. FILE_PATTERNS = *.c \ *.cc \ @@ -824,6 +1009,7 @@ FILE_PATTERNS = *.c \ *.hxx \ *.hpp \ *.h++ \ + *.l \ *.cs \ *.d \ *.php \ @@ -837,13 +1023,19 @@ FILE_PATTERNS = *.c \ *.mm \ *.dox \ *.py \ - *.tcl \ + *.pyw \ + *.f90 \ + *.f95 \ + *.f03 \ + *.f08 \ + *.f18 \ + *.f \ + *.for \ *.vhd \ *.vhdl \ *.ucf \ *.qsf \ - *.as \ - *.js + *.ice # The RECURSIVE tag can be used to specify whether or not subdirectories should # be searched for input files as well. @@ -880,10 +1072,7 @@ EXCLUDE_PATTERNS = # (namespaces, classes, functions, etc.) that should be excluded from the # output. The symbol name can be a fully qualified name, a word, or if the # wildcard * is used, a substring. 
Examples: ANamespace, AClass, -# AClass::ANamespace, ANamespace::*Test -# -# Note that the wildcards are matched against the file with absolute path, so to -# exclude all test directories use the pattern */test/* +# ANamespace::AClass, ANamespace::*Test EXCLUDE_SYMBOLS = @@ -927,6 +1116,15 @@ IMAGE_PATH = # Note that the filter must not add or remove lines; it is applied before the # code is scanned, but not when the output code is generated. If lines are added # or removed, the anchors will not be placed correctly. +# +# Note that doxygen will use the data processed and written to standard output +# for further processing, therefore nothing else, like debug statements or used +# commands (so in case of a Windows batch file always use @echo OFF), should be +# written to standard output. +# +# Note that for custom extensions or not directly supported extensions you also +# need to set EXTENSION_MAPPING for the extension otherwise the files are not +# properly processed by doxygen. INPUT_FILTER = @@ -936,6 +1134,10 @@ INPUT_FILTER = # (like *.cpp=my_cpp_filter). See INPUT_FILTER for further information on how # filters are used. If the FILTER_PATTERNS tag is empty or if none of the # patterns match the file name, INPUT_FILTER is applied. +# +# Note that for custom extensions or not directly supported extensions you also +# need to set EXTENSION_MAPPING for the extension otherwise the files are not +# properly processed by doxygen. FILTER_PATTERNS = @@ -959,7 +1161,16 @@ FILTER_SOURCE_PATTERNS = # (index.html). This can be useful if you have a project on for instance GitHub # and want to reuse the introduction page also for the doxygen output. -USE_MDFILE_AS_MAINPAGE = ../README.md +USE_MDFILE_AS_MAINPAGE = ../../README.md + +# The Fortran standard specifies that for fixed formatted Fortran code all +# characters from position 72 are to be considered as comment. A common +# extension is to allow longer lines before the automatic comment starts. 
The +# setting FORTRAN_COMMENT_AFTER will also make it possible that longer lines can +# be processed before the automatic comment starts. +# Minimum value: 7, maximum value: 10000, default value: 72. + +FORTRAN_COMMENT_AFTER = 72 #--------------------------------------------------------------------------- # Configuration options related to source browsing @@ -988,7 +1199,7 @@ INLINE_SOURCES = NO STRIP_CODE_COMMENTS = YES # If the REFERENCED_BY_RELATION tag is set to YES then for each documented -# function all documented functions referencing it will be listed. +# entity all documented functions referencing it will be listed. # The default value is: NO. REFERENCED_BY_RELATION = NO @@ -1020,12 +1231,12 @@ SOURCE_TOOLTIPS = YES # If the USE_HTAGS tag is set to YES then the references to source code will # point to the HTML generated by the htags(1) tool instead of doxygen built-in # source browser. The htags tool is part of GNU's global source tagging system -# (see http://www.gnu.org/software/global/global.html). You will need version +# (see https://www.gnu.org/software/global/global.html). You will need version # 4.8.6 or higher. # # To use it do the following: # - Install the latest version of global -# - Enable SOURCE_BROWSER and USE_HTAGS in the config file +# - Enable SOURCE_BROWSER and USE_HTAGS in the configuration file # - Make sure the INPUT points to the root of the source tree # - Run doxygen as normal # @@ -1047,25 +1258,6 @@ USE_HTAGS = NO VERBATIM_HEADERS = YES -# If the CLANG_ASSISTED_PARSING tag is set to YES then doxygen will use the -# clang parser (see: http://clang.llvm.org/) for more accurate parsing at the -# cost of reduced performance. This can be particularly helpful with template -# rich C++ code for which doxygen's built-in parser lacks the necessary type -# information. -# Note: The availability of this option depends on whether or not doxygen was -# compiled with the --with-libclang option. -# The default value is: NO. 
- -CLANG_ASSISTED_PARSING = NO - -# If clang assisted parsing is enabled you can provide the compiler with command -# line options that you would normally use when invoking the compiler. Note that -# the include paths will already be set by doxygen for the files and directories -# specified with INPUT and INCLUDE_PATH. -# This tag requires that the tag CLANG_ASSISTED_PARSING is set to YES. - -CLANG_OPTIONS = - #--------------------------------------------------------------------------- # Configuration options related to the alphabetical class index #--------------------------------------------------------------------------- @@ -1077,17 +1269,11 @@ CLANG_OPTIONS = ALPHABETICAL_INDEX = YES -# The COLS_IN_ALPHA_INDEX tag can be used to specify the number of columns in -# which the alphabetical index list will be split. -# Minimum value: 1, maximum value: 20, default value: 5. -# This tag requires that the tag ALPHABETICAL_INDEX is set to YES. - -COLS_IN_ALPHA_INDEX = 5 - -# In case all classes in a project start with a common prefix, all classes will -# be put under the same header in the alphabetical index. The IGNORE_PREFIX tag -# can be used to specify a prefix (or a list of prefixes) that should be ignored -# while generating the index headers. +# The IGNORE_PREFIX tag can be used to specify a prefix (or a list of prefixes) +# that should be ignored while generating the index headers. The IGNORE_PREFIX +# tag works for classes, function and member names. The entity will be placed in +# the alphabetical list under the first letter of the entity name that remains +# after removing the prefix. # This tag requires that the tag ALPHABETICAL_INDEX is set to YES. IGNORE_PREFIX = @@ -1134,7 +1320,7 @@ HTML_FILE_EXTENSION = .html # of the possible markers and block names see the documentation. # This tag requires that the tag GENERATE_HTML is set to YES. 
-HTML_HEADER = +HTML_HEADER = ../_doxygen/header.html # The HTML_FOOTER tag can be used to specify a user-defined HTML footer for each # generated HTML page. If the tag is left blank doxygen will generate a standard @@ -1144,7 +1330,7 @@ HTML_HEADER = # that doxygen normally uses. # This tag requires that the tag GENERATE_HTML is set to YES. -HTML_FOOTER = +HTML_FOOTER = ../_doxygen/footer.html # The HTML_STYLESHEET tag can be used to specify a user-defined cascading style # sheet that is used by each HTML page. It can be used to fine-tune the look of @@ -1156,7 +1342,7 @@ HTML_FOOTER = # obsolete. # This tag requires that the tag GENERATE_HTML is set to YES. -HTML_STYLESHEET = +HTML_STYLESHEET = ../_doxygen/stylesheet.css # The HTML_EXTRA_STYLESHEET tag can be used to specify additional user-defined # cascading style sheets that are included after the standard style sheets @@ -1166,10 +1352,15 @@ HTML_STYLESHEET = # Doxygen will copy the style sheet files to the output directory. # Note: The order of the extra style sheet files is of importance (e.g. the last # style sheet in the list overrules the setting of the previous ones in the -# list). For an example see the documentation. +# list). +# Note: Since the styling of scrollbars can currently not be overruled in +# Webkit/Chromium, the styling will be left out of the default doxygen.css if +# one or more extra stylesheets have been specified. So if scrollbar +# customization is desired it has to be added explicitly. For an example see the +# documentation. # This tag requires that the tag GENERATE_HTML is set to YES. -HTML_EXTRA_STYLESHEET = +HTML_EXTRA_STYLESHEET = ../_doxygen/extra_stylesheet.css # The HTML_EXTRA_FILES tag can be used to specify one or more extra images or # other source files which should be copied to the HTML output directory. 
Note @@ -1181,19 +1372,32 @@ HTML_EXTRA_STYLESHEET = HTML_EXTRA_FILES = +# The HTML_COLORSTYLE tag can be used to specify if the generated HTML output +# should be rendered with a dark or light theme. +# Possible values are: LIGHT always generate light mode output, DARK always +# generate dark mode output, AUTO_LIGHT automatically set the mode according to +# the user preference, use light mode if no preference is set (the default), +# AUTO_DARK automatically set the mode according to the user preference, use +# dark mode if no preference is set and TOGGLE allow the user to switch between +# light and dark mode via a button. +# The default value is: AUTO_LIGHT. +# This tag requires that the tag GENERATE_HTML is set to YES. + +HTML_COLORSTYLE = LIGHT + # The HTML_COLORSTYLE_HUE tag controls the color of the HTML output. Doxygen # will adjust the colors in the style sheet and background images according to -# this color. Hue is specified as an angle on a colorwheel, see -# http://en.wikipedia.org/wiki/Hue for more information. For instance the value +# this color. Hue is specified as an angle on a color-wheel, see +# https://en.wikipedia.org/wiki/Hue for more information. For instance the value # 0 represents red, 60 is yellow, 120 is green, 180 is cyan, 240 is blue, 300 # purple, and 360 is red again. # Minimum value: 0, maximum value: 359, default value: 220. # This tag requires that the tag GENERATE_HTML is set to YES. -HTML_COLORSTYLE_HUE = 220 +HTML_COLORSTYLE_HUE = 240 # The HTML_COLORSTYLE_SAT tag controls the purity (or saturation) of the colors -# in the HTML output. For a value of 0 the output will use grayscales only. A +# in the HTML output. For a value of 0 the output will use gray-scales only. A # value of 255 will produce the most vivid colors. # Minimum value: 0, maximum value: 255, default value: 100. # This tag requires that the tag GENERATE_HTML is set to YES. 
@@ -1211,14 +1415,16 @@ HTML_COLORSTYLE_SAT = 100 HTML_COLORSTYLE_GAMMA = 80 -# If the HTML_TIMESTAMP tag is set to YES then the footer of each generated HTML -# page will contain the date and time when the page was generated. Setting this -# to YES can help to show when doxygen was last run and thus if the -# documentation is up to date. -# The default value is: NO. +# If the HTML_DYNAMIC_MENUS tag is set to YES then the generated HTML +# documentation will contain a main index with vertical navigation menus that +# are dynamically created via JavaScript. If disabled, the navigation index will +# consists of multiple levels of tabs that are statically embedded in every HTML +# page. Disable this option to support browsers that do not have JavaScript, +# like the Qt help browser. +# The default value is: YES. # This tag requires that the tag GENERATE_HTML is set to YES. -HTML_TIMESTAMP = NO +HTML_DYNAMIC_MENUS = YES # If the HTML_DYNAMIC_SECTIONS tag is set to YES then the generated HTML # documentation will contain sections that can be hidden and shown after the @@ -1243,13 +1449,14 @@ HTML_INDEX_NUM_ENTRIES = 100 # If the GENERATE_DOCSET tag is set to YES, additional index files will be # generated that can be used as input for Apple's Xcode 3 integrated development -# environment (see: http://developer.apple.com/tools/xcode/), introduced with -# OSX 10.5 (Leopard). To create a documentation set, doxygen will generate a -# Makefile in the HTML output directory. Running make will produce the docset in -# that directory and running make install will install the docset in +# environment (see: +# https://developer.apple.com/xcode/), introduced with OSX 10.5 (Leopard). To +# create a documentation set, doxygen will generate a Makefile in the HTML +# output directory. Running make will produce the docset in that directory and +# running make install will install the docset in # ~/Library/Developer/Shared/Documentation/DocSets so that Xcode will find it at -# startup. 
See http://developer.apple.com/tools/creatingdocsetswithdoxygen.html -# for more information. +# startup. See https://developer.apple.com/library/archive/featuredarticles/Doxy +# genXcode/_index.html for more information. # The default value is: NO. # This tag requires that the tag GENERATE_HTML is set to YES. @@ -1263,6 +1470,13 @@ GENERATE_DOCSET = NO DOCSET_FEEDNAME = "Doxygen generated docs" +# This tag determines the URL of the docset feed. A documentation feed provides +# an umbrella under which multiple documentation sets from a single provider +# (such as a company or product suite) can be grouped. +# This tag requires that the tag GENERATE_DOCSET is set to YES. + +DOCSET_FEEDURL = + # This tag specifies a string that should uniquely identify the documentation # set bundle. This should be a reverse domain-name style string, e.g. # com.mycompany.MyDocSet. Doxygen will append .docset to the name. @@ -1288,8 +1502,12 @@ DOCSET_PUBLISHER_NAME = Publisher # If the GENERATE_HTMLHELP tag is set to YES then doxygen generates three # additional HTML index files: index.hhp, index.hhc, and index.hhk. The # index.hhp is a project file that can be read by Microsoft's HTML Help Workshop -# (see: http://www.microsoft.com/en-us/download/details.aspx?id=21138) on -# Windows. +# on Windows. In the beginning of 2021 Microsoft took the original page, with +# a.o. the download links, offline the HTML help workshop was already many years +# in maintenance mode). You can download the HTML help workshop from the web +# archives at Installation executable (see: +# http://web.archive.org/web/20160201063255/http://download.microsoft.com/downlo +# ad/0/A/9/0A939EF6-E31C-430F-A3DF-DFAE7960D564/htmlhelp.exe). # # The HTML Help Workshop contains a compiler that can convert all HTML output # generated by doxygen into a single compiled HTML file (.chm). 
Compiled HTML @@ -1319,7 +1537,7 @@ CHM_FILE = HHC_LOCATION = # The GENERATE_CHI flag controls if a separate .chi index file is generated -# (YES) or that it should be included in the master .chm file (NO). +# (YES) or that it should be included in the main .chm file (NO). # The default value is: NO. # This tag requires that the tag GENERATE_HTMLHELP is set to YES. @@ -1346,6 +1564,16 @@ BINARY_TOC = NO TOC_EXPAND = NO +# The SITEMAP_URL tag is used to specify the full URL of the place where the +# generated documentation will be placed on the server by the user during the +# deployment of the documentation. The generated sitemap is called sitemap.xml +# and placed on the directory specified by HTML_OUTPUT. In case no SITEMAP_URL +# is specified no sitemap is generated. For information about the sitemap +# protocol see https://www.sitemaps.org +# This tag requires that the tag GENERATE_HTML is set to YES. + +SITEMAP_URL = + # If the GENERATE_QHP tag is set to YES and both QHP_NAMESPACE and # QHP_VIRTUAL_FOLDER are set, an additional index file will be generated that # can be used as input for Qt's qhelpgenerator to generate a Qt Compressed Help @@ -1364,7 +1592,8 @@ QCH_FILE = # The QHP_NAMESPACE tag specifies the namespace to use when generating Qt Help # Project output. For more information please see Qt Help Project / Namespace -# (see: http://qt-project.org/doc/qt-4.8/qthelpproject.html#namespace). +# (see: +# https://doc.qt.io/archives/qt-4.8/qthelpproject.html#namespace). # The default value is: org.doxygen.Project. # This tag requires that the tag GENERATE_QHP is set to YES. @@ -1372,8 +1601,8 @@ QHP_NAMESPACE = org.doxygen.Project # The QHP_VIRTUAL_FOLDER tag specifies the namespace to use when generating Qt # Help Project output. For more information please see Qt Help Project / Virtual -# Folders (see: http://qt-project.org/doc/qt-4.8/qthelpproject.html#virtual- -# folders). 
+# Folders (see: +# https://doc.qt.io/archives/qt-4.8/qthelpproject.html#virtual-folders). # The default value is: doc. # This tag requires that the tag GENERATE_QHP is set to YES. @@ -1381,30 +1610,30 @@ QHP_VIRTUAL_FOLDER = doc # If the QHP_CUST_FILTER_NAME tag is set, it specifies the name of a custom # filter to add. For more information please see Qt Help Project / Custom -# Filters (see: http://qt-project.org/doc/qt-4.8/qthelpproject.html#custom- -# filters). +# Filters (see: +# https://doc.qt.io/archives/qt-4.8/qthelpproject.html#custom-filters). # This tag requires that the tag GENERATE_QHP is set to YES. QHP_CUST_FILTER_NAME = # The QHP_CUST_FILTER_ATTRS tag specifies the list of the attributes of the # custom filter to add. For more information please see Qt Help Project / Custom -# Filters (see: http://qt-project.org/doc/qt-4.8/qthelpproject.html#custom- -# filters). +# Filters (see: +# https://doc.qt.io/archives/qt-4.8/qthelpproject.html#custom-filters). # This tag requires that the tag GENERATE_QHP is set to YES. QHP_CUST_FILTER_ATTRS = # The QHP_SECT_FILTER_ATTRS tag specifies the list of the attributes this # project's filter section matches. Qt Help Project / Filter Attributes (see: -# http://qt-project.org/doc/qt-4.8/qthelpproject.html#filter-attributes). +# https://doc.qt.io/archives/qt-4.8/qthelpproject.html#filter-attributes). # This tag requires that the tag GENERATE_QHP is set to YES. QHP_SECT_FILTER_ATTRS = -# The QHG_LOCATION tag can be used to specify the location of Qt's -# qhelpgenerator. If non-empty doxygen will try to run qhelpgenerator on the -# generated .qhp file. +# The QHG_LOCATION tag can be used to specify the location (absolute path +# including file name) of Qt's qhelpgenerator. If non-empty doxygen will try to +# run qhelpgenerator on the generated .qhp file. # This tag requires that the tag GENERATE_QHP is set to YES. 
QHG_LOCATION = @@ -1447,16 +1676,28 @@ DISABLE_INDEX = NO # to work a browser that supports JavaScript, DHTML, CSS and frames is required # (i.e. any modern browser). Windows users are probably better off using the # HTML help feature. Via custom style sheets (see HTML_EXTRA_STYLESHEET) one can -# further fine-tune the look of the index. As an example, the default style -# sheet generated by doxygen has an example that shows how to put an image at -# the root of the tree instead of the PROJECT_NAME. Since the tree basically has -# the same information as the tab index, you could consider setting -# DISABLE_INDEX to YES when enabling this option. +# further fine tune the look of the index (see "Fine-tuning the output"). As an +# example, the default style sheet generated by doxygen has an example that +# shows how to put an image at the root of the tree instead of the PROJECT_NAME. +# Since the tree basically has the same information as the tab index, you could +# consider setting DISABLE_INDEX to YES when enabling this option. # The default value is: NO. # This tag requires that the tag GENERATE_HTML is set to YES. GENERATE_TREEVIEW = NO +# When both GENERATE_TREEVIEW and DISABLE_INDEX are set to YES, then the +# FULL_SIDEBAR option determines if the side bar is limited to only the treeview +# area (value NO) or if it should extend to the full height of the window (value +# YES). Setting this to YES gives a layout similar to +# https://docs.readthedocs.io with more room for contents, but less room for the +# project logo, title, and description. If either GENERATE_TREEVIEW or +# DISABLE_INDEX is set to NO, this option has no effect. +# The default value is: NO. +# This tag requires that the tag GENERATE_HTML is set to YES. + +FULL_SIDEBAR = NO + # The ENUM_VALUES_PER_LINE tag can be used to set the number of enum values that # doxygen will group on one line in the generated HTML documentation. 
# @@ -1481,6 +1722,24 @@ TREEVIEW_WIDTH = 250 EXT_LINKS_IN_WINDOW = NO +# If the OBFUSCATE_EMAILS tag is set to YES, doxygen will obfuscate email +# addresses. +# The default value is: YES. +# This tag requires that the tag GENERATE_HTML is set to YES. + +OBFUSCATE_EMAILS = YES + +# If the HTML_FORMULA_FORMAT option is set to svg, doxygen will use the pdf2svg +# tool (see https://github.com/dawbarton/pdf2svg) or inkscape (see +# https://inkscape.org) to generate formulas as SVG images instead of PNGs for +# the HTML output. These images will generally look nicer at scaled resolutions. +# Possible values are: png (the default) and svg (looks nicer but requires the +# pdf2svg or inkscape tool). +# The default value is: png. +# This tag requires that the tag GENERATE_HTML is set to YES. + +HTML_FORMULA_FORMAT = png + # Use this tag to change the font size of LaTeX formulas included as images in # the HTML documentation. When you change the font size after a successful # doxygen run you need to manually remove any form_*.png images from the HTML @@ -1490,19 +1749,14 @@ EXT_LINKS_IN_WINDOW = NO FORMULA_FONTSIZE = 10 -# Use the FORMULA_TRANPARENT tag to determine whether or not the images -# generated for formulas are transparent PNGs. Transparent PNGs are not -# supported properly for IE 6.0, but are supported on all modern browsers. -# -# Note that when changing this option you need to delete any form_*.png files in -# the HTML output directory before the changes have effect. -# The default value is: YES. -# This tag requires that the tag GENERATE_HTML is set to YES. +# The FORMULA_MACROFILE can contain LaTeX \newcommand and \renewcommand commands +# to create new LaTeX commands to be used in formulas as building blocks. See +# the section "Including formulas" for details. 
-FORMULA_TRANSPARENT = YES +FORMULA_MACROFILE = # Enable the USE_MATHJAX option to render LaTeX formulas using MathJax (see -# http://www.mathjax.org) which uses client side Javascript for the rendering +# https://www.mathjax.org) which uses client side JavaScript for the rendering # instead of using pre-rendered bitmaps. Use this if you do not have LaTeX # installed or if you want to formulas look prettier in the HTML output. When # enabled you may also need to install MathJax separately and configure the path @@ -1512,11 +1766,29 @@ FORMULA_TRANSPARENT = YES USE_MATHJAX = YES +# With MATHJAX_VERSION it is possible to specify the MathJax version to be used. +# Note that the different versions of MathJax have different requirements with +# regards to the different settings, so it is possible that also other MathJax +# settings have to be changed when switching between the different MathJax +# versions. +# Possible values are: MathJax_2 and MathJax_3. +# The default value is: MathJax_2. +# This tag requires that the tag USE_MATHJAX is set to YES. + +MATHJAX_VERSION = MathJax_2 + # When MathJax is enabled you can set the default output format to be used for -# the MathJax output. See the MathJax site (see: -# http://docs.mathjax.org/en/latest/output.html) for more details. +# the MathJax output. For more details about the output format see MathJax +# version 2 (see: +# http://docs.mathjax.org/en/v2.7-latest/output.html) and MathJax version 3 +# (see: +# http://docs.mathjax.org/en/latest/web/components/output.html). # Possible values are: HTML-CSS (which is slower, but has the best -# compatibility), NativeMML (i.e. MathML) and SVG. +# compatibility. This is the name for Mathjax version 2, for MathJax version 3 +# this will be translated into chtml), NativeMML (i.e. MathML. Only supported +# for MathJax 2. 
For MathJax version 3 chtml will be used instead.), chtml (This +# is the name for Mathjax version 3, for MathJax version 2 this will be +# translated into HTML-CSS) and SVG. # The default value is: HTML-CSS. # This tag requires that the tag USE_MATHJAX is set to YES. @@ -1529,22 +1801,29 @@ MATHJAX_FORMAT = HTML-CSS # MATHJAX_RELPATH should be ../mathjax. The default value points to the MathJax # Content Delivery Network so you can quickly see the result without installing # MathJax. However, it is strongly recommended to install a local copy of -# MathJax from http://www.mathjax.org before deployment. -# The default value is: http://cdn.mathjax.org/mathjax/latest. +# MathJax from https://www.mathjax.org before deployment. The default value is: +# - in case of MathJax version 2: https://cdn.jsdelivr.net/npm/mathjax@2 +# - in case of MathJax version 3: https://cdn.jsdelivr.net/npm/mathjax@3 # This tag requires that the tag USE_MATHJAX is set to YES. -MATHJAX_RELPATH = http://cdn.mathjax.org/mathjax/latest +MATHJAX_RELPATH = # The MATHJAX_EXTENSIONS tag can be used to specify one or more MathJax # extension names that should be enabled during MathJax rendering. For example +# for MathJax version 2 (see +# https://docs.mathjax.org/en/v2.7-latest/tex.html#tex-and-latex-extensions): # MATHJAX_EXTENSIONS = TeX/AMSmath TeX/AMSsymbols +# For example for MathJax version 3 (see +# http://docs.mathjax.org/en/latest/input/tex/extensions/index.html): +# MATHJAX_EXTENSIONS = ams # This tag requires that the tag USE_MATHJAX is set to YES. MATHJAX_EXTENSIONS = # The MATHJAX_CODEFILE tag can be used to specify a file with javascript pieces # of code that will be used on startup of the MathJax code. See the MathJax site -# (see: http://docs.mathjax.org/en/latest/output.html) for more details. For an +# (see: +# http://docs.mathjax.org/en/v2.7-latest/output.html) for more details. For an # example see the documentation. # This tag requires that the tag USE_MATHJAX is set to YES. 
@@ -1572,7 +1851,7 @@ MATHJAX_CODEFILE = SEARCHENGINE = YES # When the SERVER_BASED_SEARCH tag is enabled the search engine will be -# implemented using a web server instead of a web client using Javascript. There +# implemented using a web server instead of a web client using JavaScript. There # are two flavors of web server based searching depending on the EXTERNAL_SEARCH # setting. When disabled, doxygen will generate a PHP script for searching and # an index file used by the script. When EXTERNAL_SEARCH is enabled the indexing @@ -1591,7 +1870,8 @@ SERVER_BASED_SEARCH = NO # # Doxygen ships with an example indexer (doxyindexer) and search engine # (doxysearch.cgi) which are based on the open source search engine library -# Xapian (see: http://xapian.org/). +# Xapian (see: +# https://xapian.org/). # # See the section "External Indexing and Searching" for details. # The default value is: NO. @@ -1604,8 +1884,9 @@ EXTERNAL_SEARCH = NO # # Doxygen ships with an example indexer (doxyindexer) and search engine # (doxysearch.cgi) which are based on the open source search engine library -# Xapian (see: http://xapian.org/). See the section "External Indexing and -# Searching" for details. +# Xapian (see: +# https://xapian.org/). See the section "External Indexing and Searching" for +# details. # This tag requires that the tag SEARCHENGINE is set to YES. SEARCHENGINE_URL = @@ -1656,21 +1937,35 @@ LATEX_OUTPUT = latex # The LATEX_CMD_NAME tag can be used to specify the LaTeX command name to be # invoked. # -# Note that when enabling USE_PDFLATEX this option is only used for generating -# bitmaps for formulas in the HTML output, but not in the Makefile that is -# written to the output directory. -# The default file is: latex. +# Note that when not enabling USE_PDFLATEX the default is latex when enabling +# USE_PDFLATEX the default is pdflatex and when in the later case latex is +# chosen this is overwritten by pdflatex. 
For specific output languages the +# default can have been set differently, this depends on the implementation of +# the output language. # This tag requires that the tag GENERATE_LATEX is set to YES. LATEX_CMD_NAME = latex # The MAKEINDEX_CMD_NAME tag can be used to specify the command name to generate # index for LaTeX. +# Note: This tag is used in the Makefile / make.bat. +# See also: LATEX_MAKEINDEX_CMD for the part in the generated output file +# (.tex). # The default file is: makeindex. # This tag requires that the tag GENERATE_LATEX is set to YES. MAKEINDEX_CMD_NAME = makeindex +# The LATEX_MAKEINDEX_CMD tag can be used to specify the command name to +# generate index for LaTeX. In case there is no backslash (\) as first character +# it will be automatically added in the LaTeX code. +# Note: This tag is used in the generated output file (.tex). +# See also: MAKEINDEX_CMD_NAME for the part in the Makefile / make.bat. +# The default value is: makeindex. +# This tag requires that the tag GENERATE_LATEX is set to YES. + +LATEX_MAKEINDEX_CMD = makeindex + # If the COMPACT_LATEX tag is set to YES, doxygen generates more compact LaTeX # documents. This may be useful for small projects and may help to save some # trees in general. @@ -1700,29 +1995,31 @@ PAPER_TYPE = a4 EXTRA_PACKAGES = -# The LATEX_HEADER tag can be used to specify a personal LaTeX header for the -# generated LaTeX document. The header should contain everything until the first -# chapter. If it is left blank doxygen will generate a standard header. See -# section "Doxygen usage" for information on how to let doxygen write the -# default header to a separate file. +# The LATEX_HEADER tag can be used to specify a user-defined LaTeX header for +# the generated LaTeX document. The header should contain everything until the +# first chapter. If it is left blank doxygen will generate a standard header. 
It +# is highly recommended to start with a default header using +# doxygen -w latex new_header.tex new_footer.tex new_stylesheet.sty +# and then modify the file new_header.tex. See also section "Doxygen usage" for +# information on how to generate the default header that doxygen normally uses. # -# Note: Only use a user-defined header if you know what you are doing! The -# following commands have a special meaning inside the header: $title, -# $datetime, $date, $doxygenversion, $projectname, $projectnumber, -# $projectbrief, $projectlogo. Doxygen will replace $title with the empty -# string, for the replacement values of the other commands the user is referred -# to HTML_HEADER. +# Note: Only use a user-defined header if you know what you are doing! +# Note: The header is subject to change so you typically have to regenerate the +# default header when upgrading to a newer version of doxygen. The following +# commands have a special meaning inside the header (and footer): For a +# description of the possible markers and block names see the documentation. # This tag requires that the tag GENERATE_LATEX is set to YES. LATEX_HEADER = -# The LATEX_FOOTER tag can be used to specify a personal LaTeX footer for the -# generated LaTeX document. The footer should contain everything after the last -# chapter. If it is left blank doxygen will generate a standard footer. See +# The LATEX_FOOTER tag can be used to specify a user-defined LaTeX footer for +# the generated LaTeX document. The footer should contain everything after the +# last chapter. If it is left blank doxygen will generate a standard footer. See # LATEX_HEADER for more information on how to generate a default footer and what -# special commands can be used inside the footer. -# -# Note: Only use a user-defined footer if you know what you are doing! +# special commands can be used inside the footer. 
See also section "Doxygen +# usage" for information on how to generate the default footer that doxygen +# normally uses. Note: Only use a user-defined footer if you know what you are +# doing! # This tag requires that the tag GENERATE_LATEX is set to YES. LATEX_FOOTER = @@ -1755,18 +2052,26 @@ LATEX_EXTRA_FILES = PDF_HYPERLINKS = YES -# If the USE_PDFLATEX tag is set to YES, doxygen will use pdflatex to generate -# the PDF file directly from the LaTeX files. Set this option to YES, to get a -# higher quality PDF documentation. +# If the USE_PDFLATEX tag is set to YES, doxygen will use the engine as +# specified with LATEX_CMD_NAME to generate the PDF file directly from the LaTeX +# files. Set this option to YES, to get a higher quality PDF documentation. +# +# See also section LATEX_CMD_NAME for selecting the engine. # The default value is: YES. # This tag requires that the tag GENERATE_LATEX is set to YES. USE_PDFLATEX = YES -# If the LATEX_BATCHMODE tag is set to YES, doxygen will add the \batchmode -# command to the generated LaTeX files. This will instruct LaTeX to keep running -# if errors occur, instead of asking the user for help. This option is also used -# when generating formulas in HTML. +# The LATEX_BATCHMODE tag signals the behavior of LaTeX in case of an error. +# Possible values are: NO same as ERROR_STOP, YES same as BATCH, BATCH In batch +# mode nothing is printed on the terminal, errors are scrolled as if <return> is +# hit at every error; missing files that TeX tries to input or request from +# keyboard input (\read on a not open input stream) cause the job to abort, +# NON_STOP In nonstop mode the diagnostic message will appear on the terminal, +# but there is no possibility of user interaction just like in batch mode, +# SCROLL In scroll mode, TeX will stop only for missing files to input or if +# keyboard input is necessary and ERROR_STOP In errorstop mode, TeX will stop at +# each error, asking for user intervention. # The default value is: NO. 
# This tag requires that the tag GENERATE_LATEX is set to YES. @@ -1779,24 +2084,22 @@ LATEX_BATCHMODE = NO LATEX_HIDE_INDICES = NO -# If the LATEX_SOURCE_CODE tag is set to YES then doxygen will include source -# code with syntax highlighting in the LaTeX output. -# -# Note that which sources are shown also depends on other settings such as -# SOURCE_BROWSER. -# The default value is: NO. -# This tag requires that the tag GENERATE_LATEX is set to YES. - -LATEX_SOURCE_CODE = NO - # The LATEX_BIB_STYLE tag can be used to specify the style to use for the # bibliography, e.g. plainnat, or ieeetr. See -# http://en.wikipedia.org/wiki/BibTeX and \cite for more info. +# https://en.wikipedia.org/wiki/BibTeX and \cite for more info. # The default value is: plain. # This tag requires that the tag GENERATE_LATEX is set to YES. LATEX_BIB_STYLE = plain +# The LATEX_EMOJI_DIRECTORY tag is used to specify the (relative or absolute) +# path from which the emoji images will be read. If a relative path is entered, +# it will be relative to the LATEX_OUTPUT directory. If left blank the +# LATEX_OUTPUT directory will be used. +# This tag requires that the tag GENERATE_LATEX is set to YES. + +LATEX_EMOJI_DIRECTORY = + #--------------------------------------------------------------------------- # Configuration options related to the RTF output #--------------------------------------------------------------------------- @@ -1836,9 +2139,9 @@ COMPACT_RTF = NO RTF_HYPERLINKS = NO -# Load stylesheet definitions from file. Syntax is similar to doxygen's config -# file, i.e. a series of assignments. You only have to provide replacements, -# missing definitions are set to their default value. +# Load stylesheet definitions from file. Syntax is similar to doxygen's +# configuration file, i.e. a series of assignments. You only have to provide +# replacements, missing definitions are set to their default value. 
# # See also section "Doxygen usage" for information on how to generate the # default style sheet that doxygen normally uses. @@ -1847,22 +2150,12 @@ RTF_HYPERLINKS = NO RTF_STYLESHEET_FILE = # Set optional variables used in the generation of an RTF document. Syntax is -# similar to doxygen's config file. A template extensions file can be generated -# using doxygen -e rtf extensionFile. +# similar to doxygen's configuration file. A template extensions file can be +# generated using doxygen -e rtf extensionFile. # This tag requires that the tag GENERATE_RTF is set to YES. RTF_EXTENSIONS_FILE = -# If the RTF_SOURCE_CODE tag is set to YES then doxygen will include source code -# with syntax highlighting in the RTF output. -# -# Note that which sources are shown also depends on other settings such as -# SOURCE_BROWSER. -# The default value is: NO. -# This tag requires that the tag GENERATE_RTF is set to YES. - -RTF_SOURCE_CODE = NO - #--------------------------------------------------------------------------- # Configuration options related to the man page output #--------------------------------------------------------------------------- @@ -1934,6 +2227,13 @@ XML_OUTPUT = xml XML_PROGRAMLISTING = YES +# If the XML_NS_MEMB_FILE_SCOPE tag is set to YES, doxygen will include +# namespace members in file scope as well, matching the HTML output. +# The default value is: NO. +# This tag requires that the tag GENERATE_XML is set to YES. + +XML_NS_MEMB_FILE_SCOPE = NO + #--------------------------------------------------------------------------- # Configuration options related to the DOCBOOK output #--------------------------------------------------------------------------- @@ -1952,23 +2252,14 @@ GENERATE_DOCBOOK = NO DOCBOOK_OUTPUT = docbook -# If the DOCBOOK_PROGRAMLISTING tag is set to YES, doxygen will include the -# program listings (including syntax highlighting and cross-referencing -# information) to the DOCBOOK output. 
Note that enabling this will significantly -# increase the size of the DOCBOOK output. -# The default value is: NO. -# This tag requires that the tag GENERATE_DOCBOOK is set to YES. - -DOCBOOK_PROGRAMLISTING = NO - #--------------------------------------------------------------------------- # Configuration options for the AutoGen Definitions output #--------------------------------------------------------------------------- # If the GENERATE_AUTOGEN_DEF tag is set to YES, doxygen will generate an -# AutoGen Definitions (see http://autogen.sf.net) file that captures the -# structure of the code including all documentation. Note that this feature is -# still experimental and incomplete at the moment. +# AutoGen Definitions (see https://autogen.sourceforge.net/) file that captures +# the structure of the code including all documentation. Note that this feature +# is still experimental and incomplete at the moment. # The default value is: NO. GENERATE_AUTOGEN_DEF = NO @@ -2047,7 +2338,8 @@ SEARCH_INCLUDES = NO # The INCLUDE_PATH tag can be used to specify one or more directories that # contain include files that are not input files but should be processed by the -# preprocessor. +# preprocessor. Note that the INCLUDE_PATH is not recursive, so the setting of +# RECURSIVE has no effect here. # This tag requires that the tag SEARCH_INCLUDES is set to YES. INCLUDE_PATH = @@ -2136,41 +2428,10 @@ EXTERNAL_GROUPS = YES EXTERNAL_PAGES = YES -# The PERL_PATH should be the absolute path and name of the perl script -# interpreter (i.e. the result of 'which perl'). -# The default file (with absolute path) is: /usr/bin/perl. 
- -PERL_PATH = /usr/bin/perl - #--------------------------------------------------------------------------- -# Configuration options related to the dot tool +# Configuration options related to diagram generator tools #--------------------------------------------------------------------------- -# If the CLASS_DIAGRAMS tag is set to YES, doxygen will generate a class diagram -# (in HTML and LaTeX) for classes with base or super classes. Setting the tag to -# NO turns the diagrams off. Note that this option also works with HAVE_DOT -# disabled, but it is recommended to install and use dot, since it yields more -# powerful graphs. -# The default value is: YES. - -CLASS_DIAGRAMS = NO - -# You can define message sequence charts within doxygen comments using the \msc -# command. Doxygen will then run the mscgen tool (see: -# http://www.mcternan.me.uk/mscgen/)) to produce the chart and insert it in the -# documentation. The MSCGEN_PATH tag allows you to specify the directory where -# the mscgen tool resides. If left empty the tool is assumed to be found in the -# default search path. - -MSCGEN_PATH = - -# You can include diagrams made with dia in doxygen documentation. Doxygen will -# then run dia to produce the diagram and insert it in the documentation. The -# DIA_PATH tag allows you to specify the directory where the dia binary resides. -# If left empty dia is assumed to be found in the default search path. - -DIA_PATH = - # If set to YES the inheritance and collaboration graphs will hide inheritance # and usage relations if the target is undocumented or is not a class. # The default value is: YES. @@ -2179,7 +2440,7 @@ HIDE_UNDOC_RELATIONS = YES # If you set the HAVE_DOT tag to YES then doxygen will assume the dot tool is # available from the path. This tool is part of Graphviz (see: -# http://www.graphviz.org/), a graph visualization toolkit from AT&T and Lucent +# https://www.graphviz.org/), a graph visualization toolkit from AT&T and Lucent # Bell Labs. 
The other options in this section have no effect if this option is # set to NO # The default value is: NO. @@ -2196,35 +2457,52 @@ HAVE_DOT = NO DOT_NUM_THREADS = 0 -# When you want a differently looking font in the dot files that doxygen -# generates you can specify the font name using DOT_FONTNAME. You need to make -# sure dot is able to find the font, which can be done by putting it in a -# standard location or by setting the DOTFONTPATH environment variable or by -# setting DOT_FONTPATH to the directory containing the font. -# The default value is: Helvetica. +# DOT_COMMON_ATTR is common attributes for nodes, edges and labels of +# subgraphs. When you want a differently looking font in the dot files that +# doxygen generates you can specify fontname, fontcolor and fontsize attributes. +# For details please see Node, +# Edge and Graph Attributes specification You need to make sure dot is able +# to find the font, which can be done by putting it in a standard location or by +# setting the DOTFONTPATH environment variable or by setting DOT_FONTPATH to the +# directory containing the font. Default graphviz fontsize is 14. +# The default value is: fontname=Helvetica,fontsize=10. # This tag requires that the tag HAVE_DOT is set to YES. -DOT_FONTNAME = Helvetica +DOT_COMMON_ATTR = "fontname=Helvetica,fontsize=10" -# The DOT_FONTSIZE tag can be used to set the size (in points) of the font of -# dot graphs. -# Minimum value: 4, maximum value: 24, default value: 10. +# DOT_EDGE_ATTR is concatenated with DOT_COMMON_ATTR. For elegant style you can +# add 'arrowhead=open, arrowtail=open, arrowsize=0.5'. Complete documentation about +# arrows shapes. +# The default value is: labelfontname=Helvetica,labelfontsize=10. # This tag requires that the tag HAVE_DOT is set to YES. -DOT_FONTSIZE = 10 +DOT_EDGE_ATTR = "labelfontname=Helvetica,labelfontsize=10" -# By default doxygen will tell dot to use the default font as specified with -# DOT_FONTNAME. 
If you specify a different font using DOT_FONTNAME you can set -# the path where dot can find it using this tag. +# DOT_NODE_ATTR is concatenated with DOT_COMMON_ATTR. For view without boxes +# around nodes set 'shape=plain' or 'shape=plaintext' Shapes specification +# The default value is: shape=box,height=0.2,width=0.4. +# This tag requires that the tag HAVE_DOT is set to YES. + +DOT_NODE_ATTR = "shape=box,height=0.2,width=0.4" + +# You can set the path where dot can find font specified with fontname in +# DOT_COMMON_ATTR and others dot attributes. # This tag requires that the tag HAVE_DOT is set to YES. DOT_FONTPATH = -# If the CLASS_GRAPH tag is set to YES then doxygen will generate a graph for -# each documented class showing the direct and indirect inheritance relations. -# Setting this tag to YES will force the CLASS_DIAGRAMS tag to NO. +# If the CLASS_GRAPH tag is set to YES or GRAPH or BUILTIN then doxygen will +# generate a graph for each documented class showing the direct and indirect +# inheritance relations. In case the CLASS_GRAPH tag is set to YES or GRAPH and +# HAVE_DOT is enabled as well, then dot will be used to draw the graph. In case +# the CLASS_GRAPH tag is set to YES and HAVE_DOT is disabled or if the +# CLASS_GRAPH tag is set to BUILTIN, then the built-in generator will be used. +# If the CLASS_GRAPH tag is set to TEXT the direct and indirect inheritance +# relations will be shown as texts / links. +# Possible values are: NO, YES, TEXT, GRAPH and BUILTIN. # The default value is: YES. -# This tag requires that the tag HAVE_DOT is set to YES. CLASS_GRAPH = YES @@ -2238,7 +2516,8 @@ CLASS_GRAPH = YES COLLABORATION_GRAPH = YES # If the GROUP_GRAPHS tag is set to YES then doxygen will generate a graph for -# groups, showing the direct groups dependencies. +# groups, showing the direct groups dependencies. See also the chapter Grouping +# in the manual. # The default value is: YES. # This tag requires that the tag HAVE_DOT is set to YES. 
@@ -2261,10 +2540,32 @@ UML_LOOK = NO # but if the number exceeds 15, the total amount of fields shown is limited to # 10. # Minimum value: 0, maximum value: 100, default value: 10. -# This tag requires that the tag HAVE_DOT is set to YES. +# This tag requires that the tag UML_LOOK is set to YES. UML_LIMIT_NUM_FIELDS = 10 +# If the DOT_UML_DETAILS tag is set to NO, doxygen will show attributes and +# methods without types and arguments in the UML graphs. If the DOT_UML_DETAILS +# tag is set to YES, doxygen will add type and arguments for attributes and +# methods in the UML graphs. If the DOT_UML_DETAILS tag is set to NONE, doxygen +# will not generate fields with class member information in the UML graphs. The +# class diagrams will look similar to the default class diagrams but using UML +# notation for the relationships. +# Possible values are: NO, YES and NONE. +# The default value is: NO. +# This tag requires that the tag UML_LOOK is set to YES. + +DOT_UML_DETAILS = NO + +# The DOT_WRAP_THRESHOLD tag can be used to set the maximum number of characters +# to display on a single line. If the actual line length exceeds this threshold +# significantly it will wrapped across multiple lines. Some heuristics are apply +# to avoid ugly line breaks. +# Minimum value: 0, maximum value: 1000, default value: 17. +# This tag requires that the tag HAVE_DOT is set to YES. + +DOT_WRAP_THRESHOLD = 17 + # If the TEMPLATE_RELATIONS tag is set to YES then the inheritance and # collaboration graphs will show the relations between templates and their # instances. @@ -2331,10 +2632,17 @@ GRAPHICAL_HIERARCHY = YES DIRECTORY_GRAPH = YES +# The DIR_GRAPH_MAX_DEPTH tag can be used to limit the maximum number of levels +# of child directories generated in directory dependency graphs by dot. +# Minimum value: 1, maximum value: 25, default value: 1. +# This tag requires that the tag DIRECTORY_GRAPH is set to YES. 
+ +DIR_GRAPH_MAX_DEPTH = 1 + # The DOT_IMAGE_FORMAT tag can be used to set the image format of the images # generated by dot. For an explanation of the image formats see the section # output formats in the documentation of the dot tool (Graphviz (see: -# http://www.graphviz.org/)). +# https://www.graphviz.org/)). # Note: If you choose svg you need to set HTML_FILE_EXTENSION to xhtml in order # to make the SVG files visible in IE 9+ (other browsers do not have this # requirement). @@ -2371,11 +2679,12 @@ DOT_PATH = DOTFILE_DIRS = -# The MSCFILE_DIRS tag can be used to specify one or more directories that -# contain msc files that are included in the documentation (see the \mscfile -# command). +# You can include diagrams made with dia in doxygen documentation. Doxygen will +# then run dia to produce the diagram and insert it in the documentation. The +# DIA_PATH tag allows you to specify the directory where the dia binary resides. +# If left empty dia is assumed to be found in the default search path. -MSCFILE_DIRS = +DIA_PATH = # The DIAFILE_DIRS tag can be used to specify one or more directories that # contain dia files that are included in the documentation (see the \diafile @@ -2384,13 +2693,18 @@ MSCFILE_DIRS = DIAFILE_DIRS = # When using plantuml, the PLANTUML_JAR_PATH tag should be used to specify the -# path where java can find the plantuml.jar file. If left blank, it is assumed -# PlantUML is not used or called during a preprocessing step. Doxygen will -# generate a warning when it encounters a \startuml command in this case and -# will not generate output for the diagram. +# path where java can find the plantuml.jar file or to the filename of jar file +# to be used. If left blank, it is assumed PlantUML is not used or called during +# a preprocessing step. Doxygen will generate a warning when it encounters a +# \startuml command in this case and will not generate output for the diagram. 
PLANTUML_JAR_PATH = +# When using plantuml, the PLANTUML_CFG_FILE tag can be used to specify a +# configuration file for plantuml. + +PLANTUML_CFG_FILE = + # When using plantuml, the specified paths are searched for files specified by # the !include statement in a plantuml block. @@ -2420,18 +2734,6 @@ DOT_GRAPH_MAX_NODES = 50 MAX_DOT_GRAPH_DEPTH = 0 -# Set the DOT_TRANSPARENT tag to YES to generate images with a transparent -# background. This is disabled by default, because dot on Windows does not seem -# to support this out of the box. -# -# Warning: Depending on the platform used, enabling this option may lead to -# badly anti-aliased labels on the edges of a graph (i.e. they become hard to -# read). -# The default value is: NO. -# This tag requires that the tag HAVE_DOT is set to YES. - -DOT_TRANSPARENT = NO - # Set the DOT_MULTI_TARGETS tag to YES to allow dot to generate multiple output # files in one run (i.e. multiple -o and -T options on the command line). This # makes dot run faster, but since only newer versions of dot (>1.8.10) support @@ -2444,14 +2746,34 @@ DOT_MULTI_TARGETS = NO # If the GENERATE_LEGEND tag is set to YES doxygen will generate a legend page # explaining the meaning of the various boxes and arrows in the dot generated # graphs. +# Note: This tag requires that UML_LOOK isn't set, i.e. the doxygen internal +# graphical representation for inheritance and collaboration diagrams is used. # The default value is: YES. # This tag requires that the tag HAVE_DOT is set to YES. GENERATE_LEGEND = YES -# If the DOT_CLEANUP tag is set to YES, doxygen will remove the intermediate dot +# If the DOT_CLEANUP tag is set to YES, doxygen will remove the intermediate # files that are used to generate the various graphs. +# +# Note: This setting is not only used for dot files but also for msc temporary +# files. # The default value is: YES. -# This tag requires that the tag HAVE_DOT is set to YES. 
DOT_CLEANUP = YES + +# You can define message sequence charts within doxygen comments using the \msc +# command. If the MSCGEN_TOOL tag is left empty (the default), then doxygen will +# use a built-in version of mscgen tool to produce the charts. Alternatively, +# the MSCGEN_TOOL tag can also specify the name an external tool. For instance, +# specifying prog as the value, doxygen will call the tool as prog -T +# -o . The external tool should support +# output file formats "png", "eps", "svg", and "ismap". + +MSCGEN_TOOL = + +# The MSCFILE_DIRS tag can be used to specify one or more directories that +# contain msc files that are included in the documentation (see the \mscfile +# command). + +MSCFILE_DIRS = diff --git a/docs/sphinx/requirements.in b/docs/sphinx/requirements.in index b89cb9fec8..ac03e40939 100644 --- a/docs/sphinx/requirements.in +++ b/docs/sphinx/requirements.in @@ -1,2 +1,2 @@ -rocm-docs-core==1.18.2 +rocm-docs-core[api_reference]==1.18.2 sphinxcontrib-bibtex==2.6.3 diff --git a/docs/sphinx/requirements.txt b/docs/sphinx/requirements.txt index 2a52a48e4c..3742eeebba 100644 --- a/docs/sphinx/requirements.txt +++ b/docs/sphinx/requirements.txt @@ -6,68 +6,79 @@ # accessible-pygments==0.0.5 # via pydata-sphinx-theme -alabaster==0.7.16 +alabaster==1.0.0 # via sphinx asttokens==3.0.0 # via stack-data -attrs==24.3.0 +attrs==25.3.0 # via # jsonschema # jupyter-cache # referencing -babel==2.15.0 +babel==2.17.0 # via # pydata-sphinx-theme # sphinx -beautifulsoup4==4.12.3 +beautifulsoup4==4.13.4 # via pydata-sphinx-theme -breathe==4.35.0 +breathe==4.36.0 # via rocm-docs-core -certifi==2024.7.4 +certifi==2025.1.31 # via requests -cffi==1.16.0 +cffi==1.17.1 # via # cryptography # pynacl -charset-normalizer==3.3.2 +charset-normalizer==3.4.1 # via requests -click==8.1.7 +click==8.1.8 # via + # click-log + # doxysphinx # jupyter-cache # sphinx-external-toc +click-log==0.4.0 + # via doxysphinx comm==0.2.2 # via ipykernel -cryptography==43.0.0 +contourpy==1.3.2 + # 
via matplotlib +cryptography==44.0.2 # via pyjwt -debugpy==1.8.12 +cycler==0.12.1 + # via matplotlib +debugpy==1.8.14 # via ipykernel -decorator==5.1.1 +decorator==5.2.1 # via ipython -deprecated==1.2.14 +deprecated==1.2.18 # via pygithub docutils==0.21.2 # via - # breathe # myst-parser # pybtex-docutils # pydata-sphinx-theme # sphinx # sphinxcontrib-bibtex +doxysphinx==3.3.12 + # via rocm-docs-core exceptiongroup==1.2.2 # via ipython -executing==2.1.0 +executing==2.2.0 # via stack-data -fastjsonschema==2.20.0 +fastjsonschema==2.21.1 # via # nbformat # rocm-docs-core -gitdb==4.0.11 +fonttools==4.57.0 + # via matplotlib +gitdb==4.0.12 # via gitpython -gitpython==3.1.43 +gitpython==3.1.44 # via rocm-docs-core -greenlet==3.1.1 +greenlet==3.2.1 # via sqlalchemy -idna==3.7 +idna==3.10 # via requests imagesize==1.4.1 # via sphinx @@ -77,13 +88,13 @@ importlib-metadata==8.6.1 # myst-nb ipykernel==6.29.5 # via myst-nb -ipython==8.31.0 +ipython==8.35.0 # via # ipykernel # myst-nb jedi==0.19.2 # via ipython -jinja2==3.1.4 +jinja2==3.1.6 # via # myst-parser # sphinx @@ -103,25 +114,35 @@ jupyter-core==5.7.2 # jupyter-client # nbclient # nbformat +kiwisolver==1.4.8 + # via matplotlib latexcodec==3.0.0 # via pybtex +libsass==0.22.0 + # via doxysphinx +lxml==5.2.1 + # via doxysphinx markdown-it-py==3.0.0 # via # mdit-py-plugins # myst-parser -markupsafe==2.1.5 +markupsafe==3.0.2 # via jinja2 +matplotlib==3.10.1 + # via doxysphinx matplotlib-inline==0.1.7 # via # ipykernel # ipython -mdit-py-plugins==0.4.1 +mdit-py-plugins==0.4.2 # via myst-parser mdurl==0.1.2 # via markdown-it-py -myst-nb==1.1.2 +mpire==2.10.2 + # via doxysphinx +myst-nb==1.2.0 # via rocm-docs-core -myst-parser==3.0.1 +myst-parser==4.0.1 # via myst-nb nbclient==0.10.2 # via @@ -134,20 +155,28 @@ nbformat==5.10.4 # nbclient nest-asyncio==1.6.0 # via ipykernel -packaging==24.1 +numpy==1.26.4 + # via + # contourpy + # doxysphinx + # matplotlib +packaging==25.0 # via # ipykernel + # matplotlib # pydata-sphinx-theme 
# sphinx parso==0.8.4 # via jedi pexpect==4.9.0 # via ipython -platformdirs==4.3.6 +pillow==11.2.1 + # via matplotlib +platformdirs==4.3.7 # via jupyter-core -prompt-toolkit==3.0.50 +prompt-toolkit==3.0.51 # via ipython -psutil==6.1.1 +psutil==7.0.0 # via ipykernel ptyprocess==0.7.0 # via pexpect @@ -165,21 +194,30 @@ pydata-sphinx-theme==0.15.4 # via # rocm-docs-core # sphinx-book-theme -pygithub==2.3.0 +pygithub==2.6.1 # via rocm-docs-core -pygments==2.18.0 +pygments==2.19.1 # via # accessible-pygments # ipython + # mpire # pydata-sphinx-theme # sphinx -pyjwt[crypto]==2.8.0 +pyjson5==1.6.8 + # via doxysphinx +pyjwt[crypto]==2.10.1 # via pygithub pynacl==1.5.0 # via pygithub +pyparsing==3.2.3 + # via + # doxysphinx + # matplotlib python-dateutil==2.9.0.post0 - # via jupyter-client -pyyaml==6.0.1 + # via + # jupyter-client + # matplotlib +pyyaml==6.0.2 # via # jupyter-cache # myst-nb @@ -187,11 +225,11 @@ pyyaml==6.0.1 # pybtex # rocm-docs-core # sphinx-external-toc -pyzmq==26.2.0 +pyzmq==26.4.0 # via # ipykernel # jupyter-client -referencing==0.36.1 +referencing==0.36.2 # via # jsonschema # jsonschema-specifications @@ -199,23 +237,23 @@ requests==2.32.3 # via # pygithub # sphinx -rocm-docs-core==1.18.2 +rocm-docs-core[api-reference]==1.18.2 # via -r requirements.in -rpds-py==0.22.3 +rpds-py==0.24.0 # via # jsonschema # referencing -six==1.16.0 +six==1.17.0 # via # pybtex # python-dateutil -smmap==5.0.1 +smmap==5.0.2 # via gitdb snowballstemmer==2.2.0 # via sphinx -soupsieve==2.5 +soupsieve==2.7 # via beautifulsoup4 -sphinx==7.4.7 +sphinx==8.1.3 # via # breathe # myst-nb @@ -228,15 +266,15 @@ sphinx==7.4.7 # sphinx-external-toc # sphinx-notfound-page # sphinxcontrib-bibtex -sphinx-book-theme==1.1.3 +sphinx-book-theme==1.1.4 # via rocm-docs-core sphinx-copybutton==0.5.2 # via rocm-docs-core -sphinx-design==0.6.0 +sphinx-design==0.6.1 # via rocm-docs-core sphinx-external-toc==1.0.1 # via rocm-docs-core -sphinx-notfound-page==1.0.3 +sphinx-notfound-page==1.1.0 # via 
rocm-docs-core sphinxcontrib-applehelp==2.0.0 # via sphinx @@ -252,18 +290,20 @@ sphinxcontrib-qthelp==2.0.0 # via sphinx sphinxcontrib-serializinghtml==2.0.0 # via sphinx -sqlalchemy==2.0.37 +sqlalchemy==2.0.40 # via jupyter-cache stack-data==0.6.3 # via ipython tabulate==0.9.0 # via jupyter-cache -tomli==2.0.1 +tomli==2.2.1 # via sphinx tornado==6.4.2 # via # ipykernel # jupyter-client +tqdm==4.67.1 + # via mpire traitlets==5.14.3 # via # comm @@ -274,21 +314,22 @@ traitlets==5.14.3 # matplotlib-inline # nbclient # nbformat -typing-extensions==4.12.2 +typing-extensions==4.13.2 # via + # beautifulsoup4 # ipython # myst-nb # pydata-sphinx-theme # pygithub # referencing # sqlalchemy -urllib3==2.2.2 +urllib3==2.4.0 # via # pygithub # requests wcwidth==0.2.13 # via prompt-toolkit -wrapt==1.16.0 +wrapt==1.17.2 # via deprecated zipp==3.21.0 # via importlib-metadata From 01cb8379cd9b7ce401085e60b39abde50e7dc734 Mon Sep 17 00:00:00 2001 From: Illia Silin <98187287+illsilin@users.noreply.github.com> Date: Thu, 24 Apr 2025 10:14:52 -0700 Subject: [PATCH 068/443] make code compliant with std=c++20 (#2123) --- include/ck/library/utility/fill.hpp | 4 ++-- include/ck/library/utility/host_tensor.hpp | 2 +- include/ck_tile/host/fill.hpp | 6 +++--- 3 files changed, 6 insertions(+), 6 deletions(-) diff --git a/include/ck/library/utility/fill.hpp b/include/ck/library/utility/fill.hpp index 3336041354..35625d142e 100644 --- a/include/ck/library/utility/fill.hpp +++ b/include/ck/library/utility/fill.hpp @@ -94,7 +94,7 @@ struct FillMonotonicSeq template void operator()(ForwardIter first, ForwardIter last) const { - std::generate(first, last, [=, n = init_value_]() mutable { + std::generate(first, last, [=, *this, n = init_value_]() mutable { auto tmp = n; n += step_; return tmp; @@ -150,7 +150,7 @@ struct TransformIntoStructuralSparsity template void operator()(ForwardIter first, ForwardIter last) const { - std::for_each(first, last, [=, idx = 0](T& elem) mutable { + 
std::for_each(first, last, [=, *this, idx = 0](T& elem) mutable { auto tmp_idx = idx; idx += 1; return elem *= valid_sequences[tmp_idx % (sizeof(valid_sequences) / sizeof(T))]; diff --git a/include/ck/library/utility/host_tensor.hpp b/include/ck/library/utility/host_tensor.hpp index edf58b20b4..2cbca29afc 100644 --- a/include/ck/library/utility/host_tensor.hpp +++ b/include/ck/library/utility/host_tensor.hpp @@ -252,7 +252,7 @@ struct ParallelTensorFunctor std::size_t iw_begin = it * work_per_thread; std::size_t iw_end = std::min((it + 1) * work_per_thread, mN1d); - auto f = [=] { + auto f = [=, *this] { for(std::size_t iw = iw_begin; iw < iw_end; ++iw) { call_f_unpack_args(mF, GetNdIndices(iw)); diff --git a/include/ck_tile/host/fill.hpp b/include/ck_tile/host/fill.hpp index d90c0cf6cf..3f64eb28cd 100644 --- a/include/ck_tile/host/fill.hpp +++ b/include/ck_tile/host/fill.hpp @@ -280,7 +280,7 @@ struct FillMonotonicSeq template void operator()(ForwardIter first, ForwardIter last) const { - std::generate(first, last, [=, n = init_value_]() mutable { + std::generate(first, last, [=, *this, n = init_value_]() mutable { auto tmp = n; if constexpr(std::is_same_v) { @@ -315,7 +315,7 @@ struct FillStepRange template void operator()(ForwardIter first, ForwardIter last) const { - std::generate(first, last, [=, n = start_value_]() mutable { + std::generate(first, last, [=, *this, n = start_value_]() mutable { auto tmp = n; n += step_; if constexpr(IsAscending) @@ -388,7 +388,7 @@ struct AdjustToStructuredSparsity template void operator()(ForwardIter first, ForwardIter last) const { - std::transform(first, last, first, [=, index = start](T val) mutable { + std::transform(first, last, first, [=, *this, index = start](T val) mutable { auto tmp = val * masks[index % (sizeof(masks) / sizeof(int32_t))]; index += 1; From a2ed34a112982664132db5283ee4d1b1aac746d5 Mon Sep 17 00:00:00 2001 From: Khushbu Agarwal Date: Thu, 24 Apr 2025 10:20:22 -0700 Subject: [PATCH 069/443] 
MFMA_32x32x16 for gfx950 (#2121) * Enable MFMA_32x32x16 for fp16/BF16 for gfx950 * clang formatted --- include/ck_tile/ops/gemm/warp/warp_gemm.hpp | 12 ++++++++++++ 1 file changed, 12 insertions(+) diff --git a/include/ck_tile/ops/gemm/warp/warp_gemm.hpp b/include/ck_tile/ops/gemm/warp/warp_gemm.hpp index e6350a8827..4732027e57 100644 --- a/include/ck_tile/ops/gemm/warp/warp_gemm.hpp +++ b/include/ck_tile/ops/gemm/warp/warp_gemm.hpp @@ -20,9 +20,15 @@ using WarpGemmMfmaF16F16F32M32N32K8 = WarpGemmImpl< using WarpGemmMfmaF16F16F32M16N16K16 = WarpGemmImpl< WarpGemmAtrributeMfma>>; +#if defined(__gfx950__) +using WarpGemmMfmaF16F16F32M32N32K16 = WarpGemmImpl< + WarpGemmAtrributeMfma>>; + +#else using WarpGemmMfmaF16F16F32M32N32K16 = WarpGemmImpl, 2>>; +#endif #if defined(__gfx950__) using WarpGemmMfmaF16F16F32M16N16K32 = WarpGemmImpl< @@ -105,9 +111,15 @@ using WarpGemmMfmaBf16Bf16F32M32N32K8 = WarpGemmImpl< using WarpGemmMfmaBf16Bf16F32M16N16K16 = WarpGemmImpl< WarpGemmAtrributeMfma>>; +#if defined(__gfx950__) +using WarpGemmMfmaBf16Bf16F32M32N32K16 = WarpGemmImpl< + WarpGemmAtrributeMfma>>; + +#else using WarpGemmMfmaBf16Bf16F32M32N32K16 = WarpGemmImpl, 2>>; +#endif #if defined(__gfx950__) using WarpGemmMfmaBf16Bf16F32M16N16K32 = WarpGemmImpl< From 41541aff7a3651b72977d3c52786a37bba24a7d2 Mon Sep 17 00:00:00 2001 From: joyeamd Date: Fri, 25 Apr 2025 16:31:09 +0800 Subject: [PATCH 070/443] SWDEV-52596 for hdim=256, when use splitkv pipeline, two new pipelines need to be added (#2126) --- example/ck_tile/01_fmha/codegen/ops/fmha_fwd_splitkv.py | 6 ++++++ 1 file changed, 6 insertions(+) diff --git a/example/ck_tile/01_fmha/codegen/ops/fmha_fwd_splitkv.py b/example/ck_tile/01_fmha/codegen/ops/fmha_fwd_splitkv.py index ca49af1496..75d84daf32 100644 --- a/example/ck_tile/01_fmha/codegen/ops/fmha_fwd_splitkv.py +++ b/example/ck_tile/01_fmha/codegen/ops/fmha_fwd_splitkv.py @@ -676,6 +676,12 @@ def get_fwd_splitkv_blobs(kernel_filter : Optional[str], receipt, mask_impl) -> 
pipelines.append(Pipeline('qr', 'row', 'f', 't', 'f', 'f', bias, 't', squant, pagedkv, mask)) pipelines.append(Pipeline('qr', 'col', 'f', 't', 'f', 'f', bias, 't', squant, pagedkv, mask)) + pipelines.append(Pipeline('qr', 'row', 't', 'f', 'f', 'f', bias, 't', squant, pagedkv, mask)) + pipelines.append(Pipeline('qr', 'col', 't', 'f', 'f', 'f', bias, 't', squant, pagedkv, mask)) + + pipelines.append(Pipeline('qr', 'row', 't', 't', 'f', 'f', bias, 't', squant, pagedkv, mask)) + pipelines.append(Pipeline('qr', 'col', 't', 't', 'f', 'f', bias, 't', squant, pagedkv, mask)) + pipelines.append(Pipeline('qr', 'row', 't', 't', 't', 't', bias, 't', squant, pagedkv, mask)) pipelines.append(Pipeline('qr', 'col', 't', 't', 't', 't', bias, 't', squant, pagedkv, mask)) else: From 3d4d70d2fc6b1fe77d82e3cd2b5c9aae3a315b42 Mon Sep 17 00:00:00 2001 From: Po Yen Chen Date: Sun, 27 Apr 2025 14:07:41 +0800 Subject: [PATCH 071/443] Avoid using store_tile_raw() for fp32 tensors (#2072) --- example/ck_tile/01_fmha/codegen/ops/fmha_fwd_splitkv.py | 4 +++- 1 file changed, 3 insertions(+), 1 deletion(-) diff --git a/example/ck_tile/01_fmha/codegen/ops/fmha_fwd_splitkv.py b/example/ck_tile/01_fmha/codegen/ops/fmha_fwd_splitkv.py index 75d84daf32..5ad118fd1a 100644 --- a/example/ck_tile/01_fmha/codegen/ops/fmha_fwd_splitkv.py +++ b/example/ck_tile/01_fmha/codegen/ops/fmha_fwd_splitkv.py @@ -91,10 +91,12 @@ using fmha_pipeline_problem = ck_tile::BlockFmhaFwdSplitKVPipelineProblem< using fmha_pipeline = {F_pipeline}< fmha_pipeline_problem>; +/// FIXME: use {F_spad}/{F_dvpad} as kPadM/kPadN parameters after solving +/// store_tile_raw() data corruption issue using fmha_epilogue = ck_tile::Default2DEpilogue::OaccDataType, typename FmhaFwdTypeConfig<{F_dtype}>::OaccDataType, - {F_spad}, {F_dvpad}>>; + false, false>>; using fmha_kernel = ck_tile::FmhaFwdSplitKVKernel; From 8add2cf45d8c9b298d820c6cf7f158cc13936352 Mon Sep 17 00:00:00 2001 From: Yi DING Date: Mon, 28 Apr 2025 07:26:05 +0800 Subject: 
[PATCH 072/443] Fix fp8 convert & add option for basic example (#2129) --- example/ck_tile/03_gemm/CMakeLists.txt | 1 + include/ck_tile/core/numeric/float8.hpp | 2 +- 2 files changed, 2 insertions(+), 1 deletion(-) diff --git a/example/ck_tile/03_gemm/CMakeLists.txt b/example/ck_tile/03_gemm/CMakeLists.txt index 61c3a57391..411db2e317 100644 --- a/example/ck_tile/03_gemm/CMakeLists.txt +++ b/example/ck_tile/03_gemm/CMakeLists.txt @@ -5,4 +5,5 @@ if(CK_USE_OCP_FP8) list(APPEND EXAMPLE_GEMM_COMPILE_OPTIONS -DCK_TILE_USE_OCP_FP8) endif() list(APPEND EXAMPLE_GEMM_COMPILE_OPTIONS -mllvm -enable-noalias-to-md-conversion=0) +target_compile_options(tile_example_gemm_basic PRIVATE ${EXAMPLE_GEMM_COMPILE_OPTIONS}) target_compile_options(tile_example_gemm_universal PRIVATE ${EXAMPLE_GEMM_COMPILE_OPTIONS}) diff --git a/include/ck_tile/core/numeric/float8.hpp b/include/ck_tile/core/numeric/float8.hpp index a4e8ca6a2b..b5da468319 100644 --- a/include/ck_tile/core/numeric/float8.hpp +++ b/include/ck_tile/core/numeric/float8.hpp @@ -530,7 +530,7 @@ CK_TILE_HOST_DEVICE DstT run_cast_from_f8(SrcT x) } else { - if(x == 0x80) + if(x == SrcT(0x80)) { return fNeg0; } From edd92fc546663094f42366e12a172701f18a2fd9 Mon Sep 17 00:00:00 2001 From: Anton Gorenko Date: Mon, 28 Apr 2025 11:14:21 +0600 Subject: [PATCH 073/443] DeviceGemm_Wmma_CShuffleV3 with BlockGemmPipelineVersion::v3 (#2096) * Prepare files for DeviceGemm_Wmma_CShuffleV3 * Implement main part of CShuffleV3 with block pipeline v3 for WMMA * Remove unused functions and template params for A/B descriptors * Support both gfx11 and gfx12 * Enable SplitK for gfx12 and disable for gfx11 * Added RowColRow layout for DeviceGemmV2 fp16 * Added more instances for Row, Col, Row data layout * Added instances for DeviceGemm_Wmma_CShuffleV3, Col, Row, Row data layout * Added instances for DeviceGemm_Wmma_CShuffleV3, Col, Col, Row data layout * Added more instances for DeviceGemm_Wmma_CShuffleV3, Row, Row, Row data layout * Fix formatting * 
Add documentation Based on e5ad48a7843a16a1ed0c1268b5dba7dfe2d59e4d * Enable gemm_universal profiling for gfx11/12 * Add WMMA intrinsics for F8/BF8 * Support F8/BF8 DeviceGemm_Wmma_CShuffleV3, add basic instances * Add BF16 instances and tests * Fix test_gemm_universal_wmma_fp8 by adding CK_USE_WMMA_FP8 --------- Co-authored-by: Anca Hamuraru --- CMakeLists.txt | 7 +- include/ck/ck.hpp | 2 +- include/ck/config.h.in | 6 +- .../blockwise_gemm_pipeline_wmma_selector.hpp | 60 + .../block/blockwise_gemm_pipeline_wmmaops.hpp | 85 + .../blockwise_gemm_pipeline_wmmaops_base.hpp | 309 +++ .../blockwise_gemm_pipeline_wmmaops_v3.hpp | 466 +++++ .../impl/device_gemm_wmma_cshuffle_v3.hpp | 542 ++++++ .../grid/gridwise_gemm_wmma_cshuffle_v3.hpp | 1725 +++++++++++++++++ .../tensor_operation/gpu/warp/wmma_gemm.hpp | 184 +- include/ck/utility/amd_buffer_addressing.hpp | 2 +- include/ck/utility/amd_wmma.hpp | 98 +- .../gpu/gemm_universal.hpp | 599 +----- .../gpu/gemm_universal_wmma.inc | 68 + .../gpu/gemm_universal_xdl.inc | 521 +++++ .../gpu/CMakeLists.txt | 38 +- .../gpu/gemm_universal/CMakeLists.txt | 68 +- ...wmma_universal_bf16_bf16_bf16_km_kn_mn.hpp | 64 + ...16_bf16_km_kn_mn_comp_default_instance.cpp | 25 + ...wmma_universal_bf16_bf16_bf16_km_nk_mn.hpp | 64 + ...16_bf16_km_nk_mn_comp_default_instance.cpp | 25 + ...wmma_universal_bf16_bf16_bf16_mk_kn_mn.hpp | 67 + ...16_bf16_mk_kn_mn_comp_default_instance.cpp | 25 + ...wmma_universal_bf16_bf16_bf16_mk_nk_mn.hpp | 64 + ...16_bf16_mk_nk_mn_comp_default_instance.cpp | 25 + ...mm_wmma_universal_f16_f16_f16_km_kn_mn.hpp | 64 + ...f16_f16_km_kn_mn_comp_default_instance.cpp | 24 + ...mm_wmma_universal_f16_f16_f16_km_nk_mn.hpp | 64 + ...f16_f16_km_nk_mn_comp_default_instance.cpp | 24 + ...mm_wmma_universal_f16_f16_f16_mk_kn_mn.hpp | 67 + ...f16_f16_mk_kn_mn_comp_default_instance.cpp | 24 + ...mm_wmma_universal_f16_f16_f16_mk_nk_mn.hpp | 64 + ...f16_f16_mk_nk_mn_comp_default_instance.cpp | 24 + 
...emm_wmma_universal_f8_f8_bf16_mk_kn_mn.hpp | 51 + ...f8_bf16_mk_kn_mn_comp_default_instance.cpp | 27 + ...emm_wmma_universal_f8_f8_bf16_mk_nk_mn.hpp | 51 + ...f8_bf16_mk_nk_mn_comp_default_instance.cpp | 27 + .../profiler/profile_gemm_universal_impl.hpp | 2 +- profiler/src/CMakeLists.txt | 4 +- profiler/src/profile_gemm_universal.cpp | 10 +- test/gemm_universal/CMakeLists.txt | 32 +- .../test_gemm_universal_wmma_bf16.cpp | 80 + .../test_gemm_universal_wmma_fp16.cpp | 57 + .../test_gemm_universal_wmma_fp8.cpp | 61 + 44 files changed, 5326 insertions(+), 570 deletions(-) create mode 100644 include/ck/tensor_operation/gpu/block/blockwise_gemm_pipeline_wmma_selector.hpp create mode 100644 include/ck/tensor_operation/gpu/block/blockwise_gemm_pipeline_wmmaops.hpp create mode 100644 include/ck/tensor_operation/gpu/block/blockwise_gemm_pipeline_wmmaops_base.hpp create mode 100644 include/ck/tensor_operation/gpu/block/blockwise_gemm_pipeline_wmmaops_v3.hpp create mode 100644 include/ck/tensor_operation/gpu/device/impl/device_gemm_wmma_cshuffle_v3.hpp create mode 100644 include/ck/tensor_operation/gpu/grid/gridwise_gemm_wmma_cshuffle_v3.hpp create mode 100644 library/include/ck/library/tensor_operation_instance/gpu/gemm_universal_wmma.inc create mode 100644 library/include/ck/library/tensor_operation_instance/gpu/gemm_universal_xdl.inc create mode 100644 library/src/tensor_operation_instance/gpu/gemm_universal/device_gemm_wmma_universal_bf16_bf16_bf16/device_gemm_wmma_universal_bf16_bf16_bf16_km_kn_mn.hpp create mode 100644 library/src/tensor_operation_instance/gpu/gemm_universal/device_gemm_wmma_universal_bf16_bf16_bf16/device_gemm_wmma_universal_bf16_bf16_bf16_km_kn_mn_comp_default_instance.cpp create mode 100644 library/src/tensor_operation_instance/gpu/gemm_universal/device_gemm_wmma_universal_bf16_bf16_bf16/device_gemm_wmma_universal_bf16_bf16_bf16_km_nk_mn.hpp create mode 100644 
library/src/tensor_operation_instance/gpu/gemm_universal/device_gemm_wmma_universal_bf16_bf16_bf16/device_gemm_wmma_universal_bf16_bf16_bf16_km_nk_mn_comp_default_instance.cpp create mode 100644 library/src/tensor_operation_instance/gpu/gemm_universal/device_gemm_wmma_universal_bf16_bf16_bf16/device_gemm_wmma_universal_bf16_bf16_bf16_mk_kn_mn.hpp create mode 100644 library/src/tensor_operation_instance/gpu/gemm_universal/device_gemm_wmma_universal_bf16_bf16_bf16/device_gemm_wmma_universal_bf16_bf16_bf16_mk_kn_mn_comp_default_instance.cpp create mode 100644 library/src/tensor_operation_instance/gpu/gemm_universal/device_gemm_wmma_universal_bf16_bf16_bf16/device_gemm_wmma_universal_bf16_bf16_bf16_mk_nk_mn.hpp create mode 100644 library/src/tensor_operation_instance/gpu/gemm_universal/device_gemm_wmma_universal_bf16_bf16_bf16/device_gemm_wmma_universal_bf16_bf16_bf16_mk_nk_mn_comp_default_instance.cpp create mode 100644 library/src/tensor_operation_instance/gpu/gemm_universal/device_gemm_wmma_universal_f16_f16_f16/device_gemm_wmma_universal_f16_f16_f16_km_kn_mn.hpp create mode 100644 library/src/tensor_operation_instance/gpu/gemm_universal/device_gemm_wmma_universal_f16_f16_f16/device_gemm_wmma_universal_f16_f16_f16_km_kn_mn_comp_default_instance.cpp create mode 100644 library/src/tensor_operation_instance/gpu/gemm_universal/device_gemm_wmma_universal_f16_f16_f16/device_gemm_wmma_universal_f16_f16_f16_km_nk_mn.hpp create mode 100644 library/src/tensor_operation_instance/gpu/gemm_universal/device_gemm_wmma_universal_f16_f16_f16/device_gemm_wmma_universal_f16_f16_f16_km_nk_mn_comp_default_instance.cpp create mode 100644 library/src/tensor_operation_instance/gpu/gemm_universal/device_gemm_wmma_universal_f16_f16_f16/device_gemm_wmma_universal_f16_f16_f16_mk_kn_mn.hpp create mode 100644 library/src/tensor_operation_instance/gpu/gemm_universal/device_gemm_wmma_universal_f16_f16_f16/device_gemm_wmma_universal_f16_f16_f16_mk_kn_mn_comp_default_instance.cpp create mode 100644 
library/src/tensor_operation_instance/gpu/gemm_universal/device_gemm_wmma_universal_f16_f16_f16/device_gemm_wmma_universal_f16_f16_f16_mk_nk_mn.hpp create mode 100644 library/src/tensor_operation_instance/gpu/gemm_universal/device_gemm_wmma_universal_f16_f16_f16/device_gemm_wmma_universal_f16_f16_f16_mk_nk_mn_comp_default_instance.cpp create mode 100644 library/src/tensor_operation_instance/gpu/gemm_universal/device_gemm_wmma_universal_f8_f8_bf16/device_gemm_wmma_universal_f8_f8_bf16_mk_kn_mn.hpp create mode 100644 library/src/tensor_operation_instance/gpu/gemm_universal/device_gemm_wmma_universal_f8_f8_bf16/device_gemm_wmma_universal_f8_f8_bf16_mk_kn_mn_comp_default_instance.cpp create mode 100644 library/src/tensor_operation_instance/gpu/gemm_universal/device_gemm_wmma_universal_f8_f8_bf16/device_gemm_wmma_universal_f8_f8_bf16_mk_nk_mn.hpp create mode 100644 library/src/tensor_operation_instance/gpu/gemm_universal/device_gemm_wmma_universal_f8_f8_bf16/device_gemm_wmma_universal_f8_f8_bf16_mk_nk_mn_comp_default_instance.cpp create mode 100644 test/gemm_universal/test_gemm_universal_wmma_bf16.cpp create mode 100644 test/gemm_universal/test_gemm_universal_wmma_fp16.cpp create mode 100644 test/gemm_universal/test_gemm_universal_wmma_fp8.cpp diff --git a/CMakeLists.txt b/CMakeLists.txt index ba57ead09a..4e12462a41 100644 --- a/CMakeLists.txt +++ b/CMakeLists.txt @@ -202,7 +202,7 @@ if (SUPPORTED_GPU_TARGETS MATCHES "gfx9") set(CK_USE_XDL "ON") endif() if (SUPPORTED_GPU_TARGETS MATCHES "gfx94" OR SUPPORTED_GPU_TARGETS MATCHES "gfx95") - message("Enabling FP8 gemms on native architectures") + message("Enabling XDL FP8 gemms on native architectures") add_definitions(-DCK_USE_GFX94) set(CK_USE_GFX94 "ON") endif() @@ -211,6 +211,11 @@ if (SUPPORTED_GPU_TARGETS MATCHES "gfx11" OR SUPPORTED_GPU_TARGETS MATCHES "gfx1 add_definitions(-DCK_USE_WMMA) set(CK_USE_WMMA "ON") endif() +if (SUPPORTED_GPU_TARGETS MATCHES "gfx12") + message("Enabling WMMA FP8 gemms on native 
architectures") + add_definitions(-DCK_USE_WMMA_FP8) + set(CK_USE_WMMA_FP8 "ON") +endif() if (SUPPORTED_GPU_TARGETS MATCHES "gfx12" OR SUPPORTED_GPU_TARGETS MATCHES "gfx950") add_definitions(-DCK_USE_OCP_FP8) set(CK_USE_OCP_FP8 "ON") diff --git a/include/ck/ck.hpp b/include/ck/ck.hpp index 83b76382bc..e38f166c1a 100644 --- a/include/ck/ck.hpp +++ b/include/ck/ck.hpp @@ -125,7 +125,7 @@ // buffer atomic add: floating point #ifndef __HIP_DEVICE_COMPILE__ // for host code #define CK_USE_AMD_BUFFER_ATOMIC_ADD_FLOAT 1 -#elif defined(__gfx9__) // for GPU code +#elif defined(__gfx9__) || defined(__gfx12__) // for GPU code #define CK_USE_AMD_BUFFER_ATOMIC_ADD_FLOAT 1 #else // for GPU code #define CK_USE_AMD_BUFFER_ATOMIC_ADD_FLOAT 0 diff --git a/include/ck/config.h.in b/include/ck/config.h.in index 994e60025d..306a6c2ff1 100644 --- a/include/ck/config.h.in +++ b/include/ck/config.h.in @@ -2,7 +2,7 @@ * * MIT License * - * Copyright (c) 2023 Advanced Micro Devices, Inc. + * Copyright (c) 2025 Advanced Micro Devices, Inc. * * Permission is hereby granted, free of charge, to any person obtaining a copy * of this software and associated documentation files (the "Software"), to deal @@ -115,6 +115,10 @@ #cmakedefine CK_USE_WMMA @CK_USE_WMMA@ #endif +#ifndef CK_USE_WMMA_FP8 +#cmakedefine CK_USE_WMMA_FP8 @CK_USE_WMMA_FP8@ +#endif + #ifndef CK_USE_GFX94 #cmakedefine CK_USE_GFX94 @CK_USE_GFX94@ #endif diff --git a/include/ck/tensor_operation/gpu/block/blockwise_gemm_pipeline_wmma_selector.hpp b/include/ck/tensor_operation/gpu/block/blockwise_gemm_pipeline_wmma_selector.hpp new file mode 100644 index 0000000000..2fdabc6bc7 --- /dev/null +++ b/include/ck/tensor_operation/gpu/block/blockwise_gemm_pipeline_wmma_selector.hpp @@ -0,0 +1,60 @@ +// SPDX-License-Identifier: MIT +// Copyright (c) 2025, Advanced Micro Devices, Inc. All rights reserved. 
+ +#pragma once + +#include "ck/tensor_operation/gpu/block/blockwise_gemm_pipeline_wmmaops_v3.hpp" + +namespace ck { + +template +constexpr auto BlockGemmPipeline_Selector() +{ + if constexpr(BlkGemmPipelineVer == BlockGemmPipelineVersion::v3) + { + return BlockwiseGemmWmmaops_pipeline_v3{}; + } + else + { + static_assert(false, "BlockGemmPipeline configuration is not available"); + } +} + +} // namespace ck diff --git a/include/ck/tensor_operation/gpu/block/blockwise_gemm_pipeline_wmmaops.hpp b/include/ck/tensor_operation/gpu/block/blockwise_gemm_pipeline_wmmaops.hpp new file mode 100644 index 0000000000..31c4729760 --- /dev/null +++ b/include/ck/tensor_operation/gpu/block/blockwise_gemm_pipeline_wmmaops.hpp @@ -0,0 +1,85 @@ +// SPDX-License-Identifier: MIT +// Copyright (c) 2025, Advanced Micro Devices, Inc. All rights reserved. + +#pragma once + +#include "ck/utility/common_header.hpp" + +namespace ck { + +template +struct BlockwiseGemmWmmaops_pipeline_hotloop_inst +{ + static constexpr index_t WaveSize = 32; + static constexpr index_t WaveNumM = MPerBlock / (MRepeat * MPerWmma); + static constexpr index_t WaveNumN = NPerBlock / (NRepeat * NPerWmma); + + static constexpr index_t A_LDS_Read_Width = ALDSReadWidth; + static constexpr index_t B_LDS_Read_Width = BLDSReadWidth; + + static constexpr index_t A_Buffer_Load_Inst_Num = + MPerBlock * KPerBlock / (BlockSize * ABufferLoadWidth); + static constexpr index_t B_Buffer_Load_Inst_Num = + NPerBlock * KPerBlock / (BlockSize * BBufferLoadWidth); + + static constexpr index_t A_LDS_Write_Inst_Num = + MPerBlock * KPerBlock / (BlockSize * ALDSWriteWidth); + static constexpr index_t B_LDS_Write_Inst_Num = + NPerBlock * KPerBlock / (BlockSize * BLDSWriteWidth); + + static constexpr index_t A_LDS_Read_Inst_Num = + WaveNumN * MPerBlock * KPerBlock / (BlockSize * ALDSReadWidth); + static constexpr index_t B_LDS_Read_Inst_Num = + WaveNumM * NPerBlock * KPerBlock / (BlockSize * BLDSReadWidth); + + static constexpr index_t 
C_WMMA_Inst_Num = MPerBlock * NPerBlock * KPerBlock / + (BlockSize / WaveSize) / + (MPerWmma * NPerWmma * KPerWmma); + + static constexpr auto Print() + { + printf(" Blk/Wave Size: %d, %d, M/N/K PerBlk: %d, %d, %d, M/N/K PerWmma: %d, %d, %d\n", + BlockSize, + WaveSize, + MPerBlock, + NPerBlock, + KPerBlock, + MPerWmma, + NPerWmma, + KPerWmma); + + printf(" A/B buffer load inst: %d, %d\n A/B LDS write inst: %d, %d\n A/B LDS read inst: " + "%d, %d\n C WMMA inst: %d\n" + "A/B LDS read width: %d, %d, A/B LDS write width: %d, %d, A/B buffer load width: " + "%d, %d\n", + A_Buffer_Load_Inst_Num, + B_Buffer_Load_Inst_Num, + A_LDS_Write_Inst_Num, + B_LDS_Write_Inst_Num, + A_LDS_Read_Inst_Num, + B_LDS_Read_Inst_Num, + C_WMMA_Inst_Num, + A_LDS_Read_Width, + B_LDS_Read_Width, + ALDSWriteWidth, + BLDSWriteWidth, + ABufferLoadWidth, + BBufferLoadWidth); + } +}; + +} // namespace ck diff --git a/include/ck/tensor_operation/gpu/block/blockwise_gemm_pipeline_wmmaops_base.hpp b/include/ck/tensor_operation/gpu/block/blockwise_gemm_pipeline_wmmaops_base.hpp new file mode 100644 index 0000000000..a63d32802e --- /dev/null +++ b/include/ck/tensor_operation/gpu/block/blockwise_gemm_pipeline_wmmaops_base.hpp @@ -0,0 +1,309 @@ +// SPDX-License-Identifier: MIT +// Copyright (c) 2025, Advanced Micro Devices, Inc. All rights reserved. 
+ +#pragma once + +#include "ck/utility/common_header.hpp" +#include "ck/utility/blkgemmpipe_scheduler.hpp" +#include "ck/tensor_operation/gpu/block/blockwise_gemm_pipeline_wmmaops.hpp" +#include "ck/tensor_operation/gpu/thread/threadwise_tensor_slice_transfer.hpp" +#include "ck/tensor_operation/gpu/warp/wmma_gemm.hpp" +#include "ck/tensor_description/tensor_adaptor.hpp" + +namespace ck { + +template +struct BlockwiseGemmWmmaops_pipeline_base +{ + static constexpr auto I0 = Number<0>{}; + static constexpr auto I1 = Number<1>{}; + static constexpr auto I2 = Number<2>{}; + static constexpr auto I3 = Number<3>{}; + static constexpr auto I5 = Number<5>{}; + + using ThisThreadBlock = ThisThreadBlock; + + static constexpr index_t WaveSize = 32; + + static constexpr index_t MWaves = MPerBlock / (MRepeat * MPerWmma); + static constexpr index_t NWaves = NPerBlock / (NRepeat * NPerWmma); + +#if defined(__gfx12__) + static constexpr index_t A_KRow = 2; + static constexpr index_t B_KRow = 2; +#else + static constexpr index_t A_KRow = 1; + static constexpr index_t B_KRow = 1; +#endif + + static constexpr index_t A_K1 = AWmmaTileDesc{}.GetLength(I5); + static constexpr index_t B_K1 = BWmmaTileDesc{}.GetLength(I5); + + static_assert(KPack % (A_K1 * A_KRow) == 0, "wrong!"); + static_assert(KPack % (B_K1 * B_KRow) == 0, "wrong!"); + + static constexpr auto wmma_gemm = + WmmaGemm{}; + + static constexpr index_t KRepeat = KPerBlock / KPack; + + static constexpr auto WmmaK = Number{}; + + using HotLoopInstList = + ck::BlockwiseGemmWmmaops_pipeline_hotloop_inst; + + StaticBufferTupleOfVector + c_thread_buf_; + + __host__ __device__ constexpr auto& GetCThreadBuffer() { return c_thread_buf_; } + + __device__ static auto GetWaveIdx() + { + const index_t thread_id = ThisThreadBlock::GetThreadId(); + + constexpr auto threadid_to_wave_idx_adaptor = make_single_stage_tensor_adaptor( + make_tuple(make_merge_transform(make_tuple(MWaves, NWaves, WaveSize))), + make_tuple(Sequence<0, 1, 2>{}), + 
make_tuple(Sequence<0>{})); + + return threadid_to_wave_idx_adaptor.CalculateBottomIndex(make_multi_index(thread_id)); + } + + __device__ static auto CalculateAThreadOriginDataIndex() + { + const auto wave_idx = GetWaveIdx(); + + const auto waveId_m = wave_idx[I0]; + + const auto wmma_a_idx = wmma_gemm.CalculateAThreadOriginDataIndex(); + +#if defined(__gfx12__) + const auto wmma_krow = wmma_gemm.GetSubGroupId(); +#else + const auto wmma_krow = 0; +#endif + + // |KRepeat |MRepeat|MWave |KRow |MLane |KPack + return make_tuple(0, 0, waveId_m, wmma_krow, wmma_a_idx, 0); + } + + __device__ static auto CalculateBThreadOriginDataIndex() + { + const auto wave_idx = GetWaveIdx(); + + const auto waveId_n = wave_idx[I1]; + + const auto wmma_b_idx = wmma_gemm.CalculateBThreadOriginDataIndex(); + +#if defined(__gfx12__) + const auto wmma_krow = wmma_gemm.GetSubGroupId(); +#else + const auto wmma_krow = 0; +#endif + + // |KRepeat |NRepeat|Nwave |KRow |NLane |KPack + return make_tuple(0, 0, waveId_n, wmma_krow, wmma_b_idx, 0); + } + + template + __device__ static auto CalculateCThreadOriginDataIndex(Number, Number) + { + const auto wave_idx = GetWaveIdx(); + + const auto waveId_m = wave_idx[I0]; + const auto waveId_n = wave_idx[I1]; + + const auto blk_idx = wmma_gemm.GetBeginOfThreadBlk(); + + constexpr auto mrepeat_mwave_mperwmma_to_m_adaptor = make_single_stage_tensor_adaptor( + make_tuple(make_unmerge_transform(make_tuple(MRepeat, MWaves, MPerWmma))), + make_tuple(Sequence<0>{}), + make_tuple(Sequence<0, 1, 2>{})); + + constexpr auto nrepeat_nwave_nperwmma_to_n_adaptor = make_single_stage_tensor_adaptor( + make_tuple(make_unmerge_transform(make_tuple(NRepeat, NWaves, NPerWmma))), + make_tuple(Sequence<0>{}), + make_tuple(Sequence<0, 1, 2>{})); + + const index_t c_thread_m = mrepeat_mwave_mperwmma_to_m_adaptor.CalculateBottomIndex( + make_tuple(m0, waveId_m, blk_idx[I0]))[I0]; + const index_t c_thread_n = nrepeat_nwave_nperwmma_to_n_adaptor.CalculateBottomIndex( + 
make_tuple(n0, waveId_n, blk_idx[I1]))[I0]; + + return make_tuple(c_thread_m, c_thread_n); + } + + using Tuple6 = decltype(CalculateAThreadOriginDataIndex()); + + /** + * @brief Constructor for BlockwiseGemmWmmaops_pipeline_base. + * + * This constructor initializes the thread copy objects for matrices A and B. + * It also performs several compile-time checks to ensure the correctness of the + * matrix tile descriptors. + * + * @param a_origin The origin data index for matrix A. + * @param b_origin The origin data index for matrix B. + * + * @note The constructor includes static assertions to ensure that: + * - The matrix tile descriptors for A and B are known at compile-time. + * - The number of threads in the thread block matches the product of MWaves, NWaves, and + * WaveSize. + * - The dimensions of the block are divisible by the product of the corresponding WMMA and + * repeat dimensions. + */ + __host__ __device__ + BlockwiseGemmWmmaops_pipeline_base(Tuple6 a_origin = CalculateAThreadOriginDataIndex(), + Tuple6 b_origin = CalculateBThreadOriginDataIndex()) + : a_thread_copy_(a_origin), b_thread_copy_(b_origin) + { + static_assert(AWmmaTileDesc::IsKnownAtCompileTime() && + BWmmaTileDesc::IsKnownAtCompileTime(), + "wrong! 
Desc should be known at compile-time"); + + static_assert(ThisThreadBlock::GetNumOfThread() == MWaves * NWaves * WaveSize, + "ThisThreadBlock::GetNumOfThread() != MWaves * NWaves * WaveSize\n"); + + static_assert(MPerBlock % (MPerWmma * MRepeat) == 0 && + NPerBlock % (NPerWmma * NRepeat) == 0, + "wrong!"); + } + + __host__ __device__ static constexpr auto + GetCThreadDescriptor_MRepeat_MWave_MSubGroup_NRepeat_NWave_NThreadPerSubGroup_MAccVgprs() + { + constexpr auto c_msubgroup_nthreadpersubgroup_maccvgprs_tblk_lens = + wmma_gemm.GetCMSubGroupNThreadPerSubGroupMAccVgprsThreadBlkLengths(); + + constexpr auto MAccVgprs = c_msubgroup_nthreadpersubgroup_maccvgprs_tblk_lens[I2]; + constexpr auto AccStride = c_msubgroup_nthreadpersubgroup_maccvgprs_tblk_lens[I3]; + return make_naive_tensor_descriptor( + // |MRepeat |MWave |MSubGroup |NRepeat |NWave + // |NThreadPerSubGroup |MAccVgprs + make_tuple(Number{}, I1, I1, Number{}, I1, I1, MAccVgprs), + make_tuple(Number{} * MAccVgprs * AccStride, + Number{} * MAccVgprs * AccStride, + Number{} * MAccVgprs * AccStride, + MAccVgprs * AccStride, + MAccVgprs * AccStride, + MAccVgprs * AccStride, + AccStride)); + } + + __host__ __device__ static constexpr auto + GetCBlockDescriptor_MRepeat_MWave_MSubGroup_NRepeat_NWave_NThreadPerSubGroup_MAccVgprs() + { + constexpr auto c_block_desc_mrepeat_mwave_mperwmma_nrepeat_nwave_nperwmma = + make_naive_tensor_descriptor_packed(make_tuple(Number{}, + Number{}, + Number{}, + Number{}, + Number{}, + Number{})); + + return wmma_gemm + .MakeCDesc_MBlockxRepeat_MWave_MSubGroup_NBlockxRepeat_NWave_NThreadPerSubGroup_MAccVgprs( + c_block_desc_mrepeat_mwave_mperwmma_nrepeat_nwave_nperwmma); + } + + // Describe how data allocated in thread copy src buffer + // M0_M1_M2 = MRepeat_MWave_MPerWmma, N0_N1_N2 = NRepeat_NWave_NPerWmma + static constexpr AWmmaTileDesc a_block_desc_k0_m0_m1_m2_k1; + static constexpr BWmmaTileDesc b_block_desc_k0_n0_n1_n2_k1; + + protected: + static constexpr auto a_thread_desc_ 
= + make_naive_tensor_descriptor(make_tuple(Number{}, + Number{}, + Number{}, + I1, + I1, + Number{}), + make_tuple(Number{}, + Number{}, + Number{}, + Number{}, + Number{}, + Number<1>{})); + + static constexpr auto b_thread_desc_ = + make_naive_tensor_descriptor(make_tuple(Number{}, + Number{}, + Number{}, + I1, + I1, + Number{}), + make_tuple(Number{}, + Number{}, + Number{}, + Number{}, + Number{}, + Number<1>{})); + + // C[M, N, NumRegWmma] + static constexpr auto c_thread_desc_ = make_naive_tensor_descriptor_packed( + make_tuple(Number{}, Number{}, wmma_gemm.GetRegSizePerWmma())); + + using AThreadCopy = + ThreadwiseTensorSliceTransfer_v4, + Sequence<0, 1, 2, 3, 4, 5>, + 5, + A_K1, + A_K1>; + + using BThreadCopy = + ThreadwiseTensorSliceTransfer_v4, + Sequence<0, 1, 2, 3, 4, 5>, + 5, + B_K1, + B_K1>; + + AThreadCopy a_thread_copy_; + BThreadCopy b_thread_copy_; +}; + +} // namespace ck diff --git a/include/ck/tensor_operation/gpu/block/blockwise_gemm_pipeline_wmmaops_v3.hpp b/include/ck/tensor_operation/gpu/block/blockwise_gemm_pipeline_wmmaops_v3.hpp new file mode 100644 index 0000000000..2fb95f0f8d --- /dev/null +++ b/include/ck/tensor_operation/gpu/block/blockwise_gemm_pipeline_wmmaops_v3.hpp @@ -0,0 +1,466 @@ +// SPDX-License-Identifier: MIT +// Copyright (c) 2025, Advanced Micro Devices, Inc. All rights reserved. 
+ +#pragma once + +#include "ck/tensor_operation/gpu/block/blockwise_gemm_pipeline_wmmaops_base.hpp" + +namespace ck { + +// Compute optimized pipeline +// GlobalPrefetchStages: 2 +// LocalPreFillStages: 1 +// LocalPreFetchStages: 1 +// LocalSharedMemoryBuffer: 1 + +template +struct BlockwiseGemmWmmaops_pipeline_v3 +{ +}; + +template +struct BlockwiseGemmWmmaops_pipeline_v3 + : BlockwiseGemmWmmaops_pipeline_base +{ + using Base = BlockwiseGemmWmmaops_pipeline_base; + using Base::I0; + + using Base::A_K1; + using Base::A_KRow; + using Base::B_K1; + using Base::B_KRow; + using Base::KRepeat; + using Base::WmmaK; + + using Base::wmma_gemm; + using typename Base::HotLoopInstList; + + using Base::CalculateCThreadOriginDataIndex; + using Base:: + GetCBlockDescriptor_MRepeat_MWave_MSubGroup_NRepeat_NWave_NThreadPerSubGroup_MAccVgprs; + using Base::GetCThreadBuffer; + using Base:: + GetCThreadDescriptor_MRepeat_MWave_MSubGroup_NRepeat_NWave_NThreadPerSubGroup_MAccVgprs; + + using Base::a_block_desc_k0_m0_m1_m2_k1; + using Base::b_block_desc_k0_n0_n1_n2_k1; + + static constexpr index_t PrefetchStages = 2; + static constexpr index_t PrefillStages = 1; + static constexpr index_t GlobalBufferNum = 1; + + __host__ __device__ static constexpr bool BlockHasHotloop(index_t num_loop) + { + return num_loop > PrefetchStages; + } + + __host__ __device__ static constexpr TailNumber BlockLoopTailNum(index_t num_loop) + { + ignore = num_loop; + return TailNumber::Full; + } + + __device__ static constexpr auto HotLoopScheduler() + { + // TODO: Calculation of the number of instructions may require changes for WMMA + /* + // A/B split schedule + // compiler is likely to use ds_read2 when instruction width smaller than 16bytes + constexpr auto num_ds_read_inst_a = + HotLoopInstList::A_LDS_Read_Width * sizeof(ADataType) == 16 + ? 
HotLoopInstList::A_LDS_Read_Inst_Num + : HotLoopInstList::A_LDS_Read_Inst_Num / 2; + constexpr auto num_ds_read_inst_b = + HotLoopInstList::B_LDS_Read_Width * sizeof(BDataType) == 16 + ? HotLoopInstList::B_LDS_Read_Inst_Num + : HotLoopInstList::B_LDS_Read_Inst_Num / 2; + + constexpr auto num_ds_write_inst_a = HotLoopInstList::A_LDS_Write_Inst_Num; + constexpr auto num_ds_write_inst_b = HotLoopInstList::B_LDS_Write_Inst_Num; + + constexpr auto num_buffer_load_inst_a = HotLoopInstList::A_Buffer_Load_Inst_Num; + constexpr auto num_buffer_load_inst_b = HotLoopInstList::B_Buffer_Load_Inst_Num; + + constexpr auto num_wmma_inst = HotLoopInstList::C_WMMA_Inst_Num; + + constexpr auto wmma_cycle = NPerWmma == 16 ? 16 : 32; + constexpr auto ds_read_a_issue_cycle = + HotLoopInstList::A_LDS_Read_Width * sizeof(ADataType) == 16 ? 8 : 4; + constexpr auto ds_read_b_issue_cycle = + HotLoopInstList::B_LDS_Read_Width * sizeof(BDataType) == 16 ? 8 : 4; + constexpr auto ds_read_a_wmma_rate = + (wmma_cycle - 4 + 2 * ds_read_a_issue_cycle - 1) / (2 * ds_read_a_issue_cycle); + constexpr auto ds_read_b_wmma_rate = + (wmma_cycle - 4 + 2 * ds_read_b_issue_cycle - 1) / (2 * ds_read_b_issue_cycle); + + constexpr auto num_dsread_a_wmma = + (num_ds_read_inst_a + ds_read_a_wmma_rate - 1) / ds_read_a_wmma_rate; + constexpr auto num_dsread_b_wmma = + (num_ds_read_inst_b + ds_read_b_wmma_rate - 1) / ds_read_b_wmma_rate; + + // stage 1 + // Separate this part? + // constexpr auto num_wmma_per_ds_read = sizeof(ComputeDataType) / sizeof(ADataType) > + // sizeof(ComputeDataType) / sizeof(BDataType) + // ? 
sizeof(ComputeDataType) / sizeof(ADataType) + // : sizeof(ComputeDataType) / sizeof(BDataType); + constexpr auto num_wmma_stage1 = num_wmma_inst - (num_dsread_a_wmma + num_dsread_b_wmma); + constexpr auto num_wmma_per_issue = + num_wmma_stage1 / (num_buffer_load_inst_a + num_buffer_load_inst_b); + constexpr auto num_dswrite_per_issue_a = num_ds_write_inst_a / num_buffer_load_inst_a; + constexpr auto num_dswrite_per_issue_b = num_ds_write_inst_b / num_buffer_load_inst_b; + + static_for<0, num_buffer_load_inst_a, 1>{}([&](auto i) { + ignore = i; + static_for<0, num_dswrite_per_issue_a, 1>{}([&](auto idswrite) { + ignore = idswrite; + __builtin_amdgcn_sched_group_barrier(0x200, 1, 0); // DS write + __builtin_amdgcn_sched_group_barrier(0x008, 1, 0); // WMMA + }); + __builtin_amdgcn_sched_group_barrier(0x020, 1, 0); // VMEM read + __builtin_amdgcn_sched_group_barrier( + 0x008, num_wmma_per_issue - num_dswrite_per_issue_a, 0); // WMMA + }); + static_for<0, num_buffer_load_inst_b, 1>{}([&](auto i) { + ignore = i; + static_for<0, num_dswrite_per_issue_b, 1>{}([&](auto idswrite) { + ignore = idswrite; + __builtin_amdgcn_sched_group_barrier(0x200, 1, 0); // DS write + __builtin_amdgcn_sched_group_barrier(0x008, 1, 0); // WMMA + }); + __builtin_amdgcn_sched_group_barrier(0x020, 1, 0); // VMEM read + __builtin_amdgcn_sched_group_barrier( + 0x008, num_wmma_per_issue - num_dswrite_per_issue_b, 0); // WMMA + }); + + // stage 2 + static_for<0, num_dsread_a_wmma, 1>{}([&](auto i) { + if constexpr((num_ds_read_inst_a - (i + 1) * ds_read_a_wmma_rate) >= + ds_read_a_wmma_rate) + { + __builtin_amdgcn_sched_group_barrier(0x100, ds_read_a_wmma_rate, 0); // DS read + } + else + { + __builtin_amdgcn_sched_group_barrier(0x100, + num_ds_read_inst_a - (num_dsread_a_wmma - 1) * + ds_read_a_wmma_rate, + 0); // DS read + } + __builtin_amdgcn_sched_group_barrier(0x008, 1, 0); // WMMA + }); + + static_for<0, num_dsread_b_wmma, 1>{}([&](auto i) { + if constexpr((num_ds_read_inst_b - (i + 1) * 
ds_read_b_wmma_rate) >= + ds_read_b_wmma_rate) + { + __builtin_amdgcn_sched_group_barrier(0x100, ds_read_b_wmma_rate, 0); // DS read + } + else + { + __builtin_amdgcn_sched_group_barrier(0x100, + num_ds_read_inst_b - (num_dsread_b_wmma - 1) * + ds_read_b_wmma_rate, + 0); // DS read + } + __builtin_amdgcn_sched_group_barrier(0x008, 1, 0); // WMMA + }); + */ + } + + template + __device__ void Run(const AGridDesc& a_grid_desc, + const ABlockDesc& a_block_desc, + ABlockTransfer& a_blockwise_copy, + const AGridBuffer& a_grid_buf, + ABlockBuffer& a_block_buf, + const ABlockTransferStep& a_block_copy_step, + const BGridDesc& b_grid_desc, + const BBlockDesc& b_block_desc, + BBlockTransfer& b_blockwise_copy, + const BGridBuffer& b_grid_buf, + BBlockBuffer& b_block_buf, + const BBlockTransferStep& b_block_copy_step, + CThreadBuffer& c_thread_buf, + index_t num_loop) const + { + __builtin_amdgcn_sched_barrier(0); + auto a_thread_buf = make_static_buffer( + a_thread_desc_.GetElementSpaceSize()); + auto b_thread_buf = make_static_buffer( + b_thread_desc_.GetElementSpaceSize()); + + // Global prefetch 1 + a_blockwise_copy.RunRead(a_grid_desc, a_grid_buf); + b_blockwise_copy.RunRead(b_grid_desc, b_grid_buf); + + a_blockwise_copy.MoveSrcSliceWindow(a_grid_desc, a_block_copy_step); + b_blockwise_copy.MoveSrcSliceWindow(b_grid_desc, b_block_copy_step); + + // Local prefill 1 + a_blockwise_copy.RunWrite(a_block_desc, a_block_buf); + b_blockwise_copy.RunWrite(b_block_desc, b_block_buf); + + // Global prefetch 2 + a_blockwise_copy.RunRead(a_grid_desc, a_grid_buf); + b_blockwise_copy.RunRead(b_grid_desc, b_grid_buf); + + a_blockwise_copy.MoveSrcSliceWindow(a_grid_desc, a_block_copy_step); + b_blockwise_copy.MoveSrcSliceWindow(b_grid_desc, b_block_copy_step); + + // Initialize C + c_thread_buf.Clear(); + + // Local prefetch 1 + block_sync_lds(); + static_for<0, KRepeat, 1>{}([&](auto k0) { + static_for<0, MRepeat, 1>{}([&](auto m0) { + a_thread_copy_.Run( + a_block_desc_k0_m0_m1_m2_k1, + 
make_tuple(Number{}, m0, I0, I0, I0, I0), + a_block_buf, + a_thread_desc_, + make_tuple(I0, m0, k0, I0, I0, I0), + a_thread_buf); + }); + static_for<0, NRepeat, 1>{}([&](auto n0) { + b_thread_copy_.Run( + b_block_desc_k0_n0_n1_n2_k1, + make_tuple(Number{}, n0, I0, I0, I0, I0), + b_block_buf, + b_thread_desc_, + make_tuple(I0, n0, k0, I0, I0, I0), + b_thread_buf); + }); + }); + + __builtin_amdgcn_sched_barrier(0); + + // main body + if constexpr(HasMainLoop) + { + index_t i = 0; + do + { + block_sync_lds(); + + a_blockwise_copy.RunWrite(a_block_desc, a_block_buf); + b_blockwise_copy.RunWrite(b_block_desc, b_block_buf); + + a_blockwise_copy.RunRead(a_grid_desc, a_grid_buf); + b_blockwise_copy.RunRead(b_grid_desc, b_grid_buf); + + a_blockwise_copy.MoveSrcSliceWindow(a_grid_desc, a_block_copy_step); + b_blockwise_copy.MoveSrcSliceWindow(b_grid_desc, b_block_copy_step); + + static_for<0, KRepeat, 1>{}([&](auto k0) { + static_for<0, MRepeat, 1>{}([&](auto m0) { + static_for<0, NRepeat, 1>{}([&](auto n0) { + vector_type a_thread_vec; + vector_type b_thread_vec; + + static_for<0, KPack / A_KRow, 1>{}([&](auto ik) { + a_thread_vec.template AsType()(ik) = + a_thread_buf[Number{}]; + }); + static_for<0, KPack / B_KRow, 1>{}([&](auto ik) { + b_thread_vec.template AsType()(ik) = + b_thread_buf[Number{}]; + }); + + using wmma_input_type_a = + typename vector_type::type; + using wmma_input_type_b = + typename vector_type::type; + + constexpr index_t c_offset = + c_thread_desc_.CalculateOffset(make_tuple(m0, n0, 0)); + + wmma_gemm.Run(a_thread_vec.template AsType(), + b_thread_vec.template AsType(), + c_thread_buf.GetVectorTypeReference(Number{})); + }); + }); + }); + + block_sync_lds(); + + static_for<0, KRepeat, 1>{}([&](auto k0) { + static_for<0, MRepeat, 1>{}([&](auto m0) { + a_thread_copy_.Run( + a_block_desc_k0_m0_m1_m2_k1, + make_tuple(Number{}, m0, I0, I0, I0, I0), + a_block_buf, + a_thread_desc_, + make_tuple(I0, m0, k0, I0, I0, I0), + a_thread_buf); + }); + static_for<0, 
NRepeat, 1>{}([&](auto n0) { + b_thread_copy_.Run( + b_block_desc_k0_n0_n1_n2_k1, + make_tuple(Number{}, n0, I0, I0, I0, I0), + b_block_buf, + b_thread_desc_, + make_tuple(I0, n0, k0, I0, I0, I0), + b_thread_buf); + }); + }); + + HotLoopScheduler(); + __builtin_amdgcn_sched_barrier(0); + + i += 1; + } while(i < (num_loop - 1)); + } + // tail + if constexpr(TailNum == TailNumber::Full) + { + static_for<0, KRepeat, 1>{}([&](auto k0) { + static_for<0, MRepeat, 1>{}([&](auto m0) { + static_for<0, NRepeat, 1>{}([&](auto n0) { + vector_type a_thread_vec; + vector_type b_thread_vec; + + static_for<0, KPack / A_KRow, 1>{}([&](auto ik) { + a_thread_vec.template AsType()(ik) = + a_thread_buf[Number{}]; + }); + static_for<0, KPack / B_KRow, 1>{}([&](auto ik) { + b_thread_vec.template AsType()(ik) = + b_thread_buf[Number{}]; + }); + + using wmma_input_type_a = + typename vector_type::type; + using wmma_input_type_b = + typename vector_type::type; + + constexpr index_t c_offset = + c_thread_desc_.CalculateOffset(make_tuple(m0, n0, 0)); + + wmma_gemm.Run(a_thread_vec.template AsType(), + b_thread_vec.template AsType(), + c_thread_buf.GetVectorTypeReference(Number{})); + }); + }); + }); + // Let's leak last WMMA block to epilogue region, cover the potential lds-shuffle + // latency + // __builtin_amdgcn_sched_barrier(0); + } + } + + protected: + using Base::a_thread_copy_; + using Base::a_thread_desc_; + using Base::b_thread_copy_; + using Base::b_thread_desc_; + using Base::c_thread_desc_; +}; + +} // namespace ck diff --git a/include/ck/tensor_operation/gpu/device/impl/device_gemm_wmma_cshuffle_v3.hpp b/include/ck/tensor_operation/gpu/device/impl/device_gemm_wmma_cshuffle_v3.hpp new file mode 100644 index 0000000000..1ef8a9b8ad --- /dev/null +++ b/include/ck/tensor_operation/gpu/device/impl/device_gemm_wmma_cshuffle_v3.hpp @@ -0,0 +1,542 @@ +// SPDX-License-Identifier: MIT +// Copyright (c) 2025, Advanced Micro Devices, Inc. All rights reserved. 
+ +#pragma once + +#include +#include + +#include "ck/utility/common_header.hpp" +#include "ck/tensor_description/tensor_descriptor.hpp" +#include "ck/tensor_description/tensor_descriptor_helper.hpp" +#include "ck/tensor_operation/gpu/device/tensor_layout.hpp" +#include "ck/tensor_operation/gpu/device/device_gemm_v2.hpp" +#include "ck/tensor_operation/gpu/device/gemm_specialization.hpp" +#include "ck/tensor_operation/gpu/grid/gridwise_gemm_wmma_cshuffle_v3.hpp" +#include "ck/host_utility/device_prop.hpp" +#include "ck/host_utility/kernel_launch.hpp" +#include "ck/host_utility/flush_cache.hpp" + +namespace ck { +namespace tensor_operation { +namespace device { + +/// @brief \"Universal\" GEMM operation with SplitK support. +/// +/// @par Overview +/// This GEMM operation implements the following mathematical equation: +/// C{M,N} = C_op(A_op(A{M,K}) * B_op(B{K,N})) +/// Where A, B are input tensors and C is the output tensor. The A/B/C_op are +/// elementwise operations applied to the A, B, and C tensors, respectively. +/// The \"universal\" gemm comes with multiple pipelines optimized for different usage +/// scenarios. That's why it's called \"universal\". It's universal through it's design +/// and versatilty. +/// +/// @note This Kernel implementation supports SplitK algorithm. It can be configured +/// to split the dot product accumulated over the K dimension into multiple working groups. +/// The partial products of different workgroups are then reduced using the AtomicAdd +/// operation. +/// +/// @tparam ALayout A tensor data layout. +/// @tparam BLayout B tensor data layout. +/// @tparam CLayout C tensor data layout. +/// @tparam ADataType A tensor data type. +/// @tparam BDataType B tensor data type. +/// @tparam CDataType C tensor data type. +/// @tparam AccDataType The accumulation data type related to the hardware +/// matrix-multiplication instruction. 
+/// @tparam CShuffleDataType The data type used to store matrix-multiplication results into +/// LDS memory during \"CShuffle\" data layout optimization. +/// @tparam AElementwiseOperation Elementwise operation applied to the A input tensor elements. +/// @tparam BElementwiseOperation Elementwise operation applied to the B input tensor elements. +/// @tparam CElementwiseOperation Elementwise operation applied to the C output tensor +/// (after GEMM). +/// @tparam GemmSpec Determines used "padding" version. +/// @tparam BlockSize The number of threads within workgroup. +/// @tparam MPerBlock The input/output data tile size in the M dimension. +/// @tparam NPerBlock The input/output data tile size in the N dimension. +/// @tparam KPerBlock The input data tile size in the K dimension. +/// @tparam AK1 The vector load size from global memory for A tensor. +/// @tparam BK1 The vector load size from global memory for B tensor. +/// @tparam MPerWmma M size of Wave Matrix Multiply Accumulate (WMMA) instruction. +/// @tparam NPerWmma N size of Wave Matrix Multiply Accumulate (WMMA) instruction. +/// @tparam MRepeat The number of iterations in the M dimension over output tile per wavefront. +/// @tparam NRepeat The number of iterations in the N dimension over output tile per wavefront. +/// @tparam ABlockTransferThreadClusterLengths_AK0_M_AK1 Spatial thread distribution over the input +/// data. Can be interpreted as the answer +/// to the question, "How many threads can be +/// arranged on each input data axis?" +/// @tparam ABlockTransferThreadClusterArrangeOrder The order of thread spatial distribution over +/// the input tensor dimension. Can be interpreted +/// as the answer to the question: "In which +/// order to spread threads through tensor axes?". +/// @tparam ABlockTransferSrcAccessOrder The order of accessing input tensor axes. Can be +/// interpreted as the answer to the question "Which dimension +/// to read first? And which next?" etc. 
+/// @tparam ABlockTransferSrcVectorDim The index of axis on which we could do vectorized memory +/// access - the one with contiguous memory. +/// @tparam ABlockTransferSrcScalarPerVector The size of vector access instruction - the number of +/// elements accessed per thread per instruction. +/// @tparam ABlockTransferDstScalarPerVector_AK1 The size of vectorized store into LDS memory. +/// @tparam ABlockLdsExtraM Whether to use padding for LDS or not. With +/// universal GEMM there's no need for padding. +/// @tparam BBlockTransferThreadClusterLengths_BK0_N_BK1 Spatial thread distribution over the input +/// data. Can be interpreted as the answer +/// to the question: "How many threads to +/// arrange on each input data axis?" +/// @tparam BBlockTransferThreadClusterArrangeOrder The order of thread spatial distribution over +/// the input tensor dimension. Can be interpreted +/// as the answer to the question: "In which +/// order to spread threads through tensor axes?". +/// @tparam BBlockTransferSrcAccessOrder he order of accessing input tensor axes. Can be +/// interpreted as the answer to the question "Which dimension +/// to read first? And which next?" etc. +/// @tparam BBlockTransferSrcVectorDim The index of axis on which we could do vectorized memory +/// access - the one with contiguous memory. +/// @tparam BBlockTransferSrcScalarPerVector The size of vector access instruction - the number of +/// elements accessed per thread per instruction. +/// @tparam BBlockTransferDstScalarPerVector_BK1 The size of vectorized store into LDS memory. +/// @tparam BBlockLdsExtraN Whether to use padding for LDS or not. With +/// universal GEMM there's no need for padding. +/// @tparam CShuffleMRepeatPerShuffle The number of matrix-multiplication instructions +/// results to process per wave per iteration of CShuffle +/// in M dimension. 
+/// @tparam CShuffleNRepeatPerShuffle The number of matrix-multiplication instructions +/// results to process per wave per iteration of CShuffle +/// in N dimension. +/// @tparam CShuffleBlockTransferClusterLengths_MBlock_MPerBlock_NBlock_NPerBlock The spatial +/// thread distribution used for storing data into output +/// tensor across output data layout dimensions. +/// @tparam CShuffleBlockTransferScalarPerVector_NPerBlock The size of vectorized memory access. +/// Used when storing data to output tensor. +/// @tparam BlkGemmPipeSched The version of blockwise-gemm pipeline scheduler (interwave or +/// intrawave). +/// @tparam BlkGemmPipelineVer The version of blockwise-gemm pipeline. +/// @tparam ComputeTypeA Data type used for A input of hardware matrix-multiplication +/// instructions. +/// @tparam ComputeTypeB Data type used for B input of hardware matrix-multiplication +/// instructions. +/// @tparam PermuteA Whether the A input tensor has gridwise-gemm friendly data layout +/// in global memory. Currently not supported! +/// @tparam PermuteB Whether the B input tensor has gridwise-gemm friendly data layout +/// in global memory (pre-shuffled). 
+template +struct DeviceGemm_Wmma_CShuffleV3 : public DeviceGemmV2 +{ + // GridwiseGemm + using GridwiseGemm = GridwiseGemm_wmma_cshuffle_v3< + ALayout, + BLayout, + CLayout, + ADataType, + BDataType, + AccDataType, + CShuffleDataType, + CDataType, + AElementwiseOperation, + BElementwiseOperation, + CElementwiseOperation, + GemmSpec, + BlockSize, + MPerBlock, + NPerBlock, + KPerBlock, + AK1, + BK1, + MPerWmma, + NPerWmma, + MRepeat, + NRepeat, + ABlockTransferThreadClusterLengths_AK0_M_AK1, + ABlockTransferThreadClusterArrangeOrder, + ABlockTransferSrcAccessOrder, + ABlockTransferSrcVectorDim, + ABlockTransferSrcScalarPerVector, + ABlockTransferDstScalarPerVector_AK1, + false, + ABlockLdsExtraM, + BBlockTransferThreadClusterLengths_BK0_N_BK1, + BBlockTransferThreadClusterArrangeOrder, + BBlockTransferSrcAccessOrder, + BBlockTransferSrcVectorDim, + BBlockTransferSrcScalarPerVector, + BBlockTransferDstScalarPerVector_BK1, + false, + BBlockLdsExtraN, + CShuffleMRepeatPerShuffle, + CShuffleNRepeatPerShuffle, + CShuffleBlockTransferClusterLengths_MBlock_MPerBlock_NBlock_NPerBlock, + CShuffleBlockTransferScalarPerVector_NPerBlock, + BlkGemmPipeSched, + BlkGemmPipelineVer, + ComputeTypeA, + ComputeTypeB, + PermuteA, + PermuteB>; + + using Argument = typename GridwiseGemm::Argument; + + /// @brief Helper structure responsible for kernel invocation. + /// + /// @paragraph The `Invoker` class is responsible for preparation and invocation of actual GPU + /// kernel function. It usually determines the launched grid size prepares kernel + /// arguments as well as perform specific kernel configuration selection based on + /// runtime arguments. + /// + /// @note If appropriately configured it may measure kernel execution time. + /// + struct Invoker : public BaseInvoker + { + /// @brief This function issues GPU kernel execution. + /// @param arg The GPU kernel arguments. + /// @param stream_config The HIP stream configuration helper structure. 
+ /// @return The kernel's average execution time (if time measurement is + /// enabled). + float Run(const Argument& arg, const StreamConfig& stream_config = StreamConfig{}) + { + if(stream_config.log_level_ > 0) + { + arg.Print(); + GridwiseGemm::BlockwiseGemmPipe::HotLoopInstList::Print(); + } + + if(!GridwiseGemm::CheckValidity(arg)) + { + throw std::runtime_error("wrong! GridwiseGemm has invalid setting"); + } + + index_t gdx, gdy, gdz; + std::tie(gdx, gdy, gdz) = GridwiseGemm::CalculateGridSize(arg.M, arg.N, arg.KBatch); + + float ave_time = 0; + + index_t k_grain = arg.KBatch * KPerBlock; + index_t K_split = (arg.K + k_grain - 1) / k_grain * KPerBlock; + + const bool has_main_k_block_loop = GridwiseGemm::CalculateHasMainKBlockLoop(K_split); + + const auto Run = [&](const auto& kernel) { + if(stream_config.flush_cache) + { + Argument arg_ = arg; + + const auto a_grid_desc_ak0_m_ak1 = GridwiseGemm::MakeAGridDescriptor_AK0_M_AK1( + arg_.M, arg_.MPadded, arg_.K, arg_.KPadded, arg_.StrideA, arg_.AK0); + const auto b_grid_desc_bk0_n_bk1 = GridwiseGemm::MakeBGridDescriptor_BK0_N_BK1( + arg_.K, arg_.KPadded, arg_.N, arg_.NPadded, arg_.StrideB, arg_.BK0); + + auto size_a_buffer = + a_grid_desc_ak0_m_ak1.GetElementSpaceSize() * sizeof(ADataType); + auto size_b_buffer = + b_grid_desc_bk0_n_bk1.GetElementSpaceSize() * sizeof(BDataType); + + ck::utility::RotatingMemWrapper rotating_mem( + arg_, stream_config.rotating_count, size_a_buffer, size_b_buffer); + rotating_mem.Print(); + + auto run_flush_cache = [&]() { + // flush icache + ck::utility::flush_icache(); + // rotating mem + rotating_mem.Next(); + // clear c mem + if(arg_.KBatch > 1) + HIP_CHECK_ERROR(hipMemsetAsync(arg_.p_c_grid, + 0, + arg_.M * arg_.N * sizeof(CDataType), + stream_config.stream_id_)); + }; + + ave_time = ck::utility::launch_and_time_kernel_with_preprocess( + stream_config, + run_flush_cache, + kernel, + dim3(gdx, gdy, gdz), + dim3(BlockSize), + 0, + arg_); + } + else + { + if(arg.KBatch > 1) + 
HIP_CHECK_ERROR(hipMemsetAsync(arg.p_c_grid, + 0, + arg.M * arg.N * sizeof(CDataType), + stream_config.stream_id_)); + + ave_time = launch_and_time_kernel( + stream_config, kernel, dim3(gdx, gdy, gdz), dim3(BlockSize), 0, arg); + } + }; + + constexpr index_t minimum_occupancy = []() { + if constexpr(BlkGemmPipeSched == BlockGemmPipelineScheduler::Interwave) + { + return 2; + } + else if constexpr(BlkGemmPipelineVer == BlockGemmPipelineVersion::v3) + { + return (MPerBlock * NPerBlock / BlockSize <= 128) ? 2 : 1; + } + else + { + return 1; + } + }(); + + if(has_main_k_block_loop) + { + // Tail number always full + if constexpr(BlkGemmPipelineVer == BlockGemmPipelineVersion::v3) + { + if(arg.KBatch > 1) + { + const auto kernel = + kernel_gemm_wmma_cshuffle_v3; + Run(kernel); + } + else + { + const auto kernel = + kernel_gemm_wmma_cshuffle_v3; + Run(kernel); + } + } + else + { + // TODO: Implement + } + } + else + { + // TODO: Implement + } + + return ave_time; + } + + // polymorphic + float Run(const BaseArgument* p_arg, + const StreamConfig& stream_config = StreamConfig{}) override + { + return Run(*dynamic_cast(p_arg), stream_config); + } + }; + + static constexpr bool IsValidCompilationParameter() + { + // TODO: properly implement this check + return true; + } + + static bool IsSupportedArgument(const Argument& arg) + { + if(!ck::is_gfx11_supported() && !ck::is_gfx12_supported()) + { + return false; + } + + if constexpr(std::is_same_v || + std::is_same_v) + { + if(arg.KBatch > 1 && ck::is_gfx11_supported()) + { + // gfx11 does not support *_atomic_pk_add_f16/bf16 instructions + return false; + } + } + + if constexpr(std::is_same_v || std::is_same_v || + std::is_same_v || std::is_same_v) + { + if(ck::is_gfx11_supported()) + { + return false; + } + } + + if((arg.K % AK1 != 0 || arg.K % BK1 != 0) && !(GemmSpec == GemmSpecialization::MKPadding || + GemmSpec == GemmSpecialization::NKPadding || + GemmSpec == GemmSpecialization::MNKPadding || + GemmSpec == 
GemmSpecialization::KPadding)) + { + return false; + } + + return GridwiseGemm::CheckValidity(arg); + } + + // polymorphic + bool IsSupportedArgument(const BaseArgument* p_arg) override + { + return IsSupportedArgument(*dynamic_cast(p_arg)); + } + + index_t GetKPerBlock() override { return KPerBlock; } + + bool GetPermuteA() override { return PermuteA; } + bool GetPermuteB() override { return PermuteB; } + + static auto MakeArgument(const ADataType* p_a, + const BDataType* p_b, + CDataType* p_c, + index_t M, + index_t N, + index_t K, + index_t StrideA, + index_t StrideB, + index_t StrideC, + index_t KBatch, + AElementwiseOperation, + BElementwiseOperation, + CElementwiseOperation) + { + return Argument{p_a, p_b, p_c, M, N, K, StrideA, StrideB, StrideC, KBatch}; + } + + static auto MakeInvoker() { return Invoker{}; } + + // polymorphic + std::unique_ptr MakeArgumentPointer(const void* p_a, + const void* p_b, + void* p_c, + index_t M, + index_t N, + index_t K, + index_t StrideA, + index_t StrideB, + index_t StrideC, + index_t KBatch, + AElementwiseOperation, + BElementwiseOperation, + CElementwiseOperation) override + { + return std::make_unique(static_cast(p_a), + static_cast(p_b), + static_cast(p_c), + M, + N, + K, + StrideA, + StrideB, + StrideC, + KBatch); + } + + // polymorphic + std::unique_ptr MakeInvokerPointer() override + { + return std::make_unique(Invoker{}); + } + + // polymorphic + std::string GetTypeString() const override + { + auto str = std::stringstream(); + + std::map BlkGemmPipelineSchedulerToString{ + {BlockGemmPipelineScheduler::Intrawave, "Intrawave"}, + {BlockGemmPipelineScheduler::Interwave, "Interwave"}}; + + std::map BlkGemmPipelineVersionToString{ + {BlockGemmPipelineVersion::v1, "v1"}, + {BlockGemmPipelineVersion::v2, "v2"}, + {BlockGemmPipelineVersion::v3, "v3"}, + {BlockGemmPipelineVersion::v4, "v4"}, + {BlockGemmPipelineVersion::v5, "v5"}}; + + // clang-format off + str << "DeviceGemm_Wmma_CShuffleV3" + << "<" + << 
getGemmSpecializationString(GemmSpec) << ", " + << std::string(ALayout::name)[0] + << std::string(BLayout::name)[0] + << std::string(CLayout::name)[0] + << ">" + << " BlkSize: " + << BlockSize << ", " + << "BlkTile: " + << MPerBlock << "x" << NPerBlock << "x" << KPerBlock << ", " + << "WaveTile: " + << MPerWmma << "x"< +__global__ void +#if CK_USE_LAUNCH_BOUNDS + __launch_bounds__(CK_MAX_THREAD_PER_BLOCK, MinimumOccupancy) +#endif + kernel_gemm_wmma_cshuffle_v3(typename GridwiseGemm::Argument karg) +{ +#if(!defined(__HIP_DEVICE_COMPILE__) || defined(__gfx11__) || defined(__gfx12__)) +#if defined(__gfx11__) + // gfx11 does not support *_atomic_pk_add_f16/bf16 instructions + using c_data_type = remove_cvref_t>; + if constexpr(!(CGlobalMemoryDataOperation == InMemoryDataOperationEnum::AtomicAdd && + (std::is_same_v || + std::is_same_v))) + { +#endif + __shared__ char p_shared[GridwiseGemm::GetSharedMemoryNumberOfByte()]; + + auto splitk_batch_offset = typename GridwiseGemm::SplitKBatchOffset(karg); + + GridwiseGemm::template Run( + karg.p_a_grid + splitk_batch_offset.a_k_split_offset, + karg.p_b_grid + splitk_batch_offset.b_k_split_offset, + karg.p_c_grid + splitk_batch_offset.c_reduce_offset, + p_shared, + karg); +#if defined(__gfx11__) + } +#endif +#else + ignore = karg; +#endif +} + +/// @brief \"Universal\" GEMM kernel with SplitK support. +/// +/// @par Overview +/// This GEMM kernel is carrying out following mathematical equation: +/// C{M,N} = C_op(A_op(A{M,K}) * B_op(B{K,N})) +/// Where A, B are input tensors and C is the output tensor. The A/B/C_op are +/// elementwise operations that could be applied on each tensor respectively. +/// The \"universal\" gemm comes with multiple pipelines optimized for different usage +/// scenarios. That's why it's called \"universal\". It's universal through it's design +/// and versatilty. +/// +/// @note This Kernel implementation supports SplitK algorithm. 
It can be configured +/// to split the dot product accumulated over the K dimension into multiple working groups. +/// The partial products of different workgroups are then reduced using the AtomicAdd +/// operation. +/// +/// @tparam ALayout A tensor data layout. +/// @tparam BLayout B tensor data layout. +/// @tparam CLayout C tensor data layout. +/// @tparam ADataType A tensor data type. +/// @tparam BDataType B tensor data type. +/// @tparam AccDataType The accumulation data type related to the hardware +/// matrix-multiplication instruction. +/// @tparam CShuffleDataType The data type used to store matrix-multiplication results into +/// LDS memory during \"CShuffle\" data layout optimization. +/// @tparam CDataType C tensor data type. +/// @tparam AElementwiseOperation Elementwise operation applied to the A input tensor elements. +/// @tparam BElementwiseOperation Elementwise operation applied to the B input tensor elements. +/// @tparam CElementwiseOperation Elementwise operation applied to the C output tensor +/// (after GEMM). +/// @tparam GemmSpec Determines used "padding" version. +/// @tparam BlockSize The number of threads within workgroup. +/// @tparam MPerBlock The input/output data tile size in the M dimension. +/// @tparam NPerBlock The input/output data tile size in the N dimension. +/// @tparam KPerBlock The input data tile size in the K dimension. +/// @tparam AK1Value The vector load size from global memory for A tensor. +/// @tparam BK1Value The vector load size from global memory for B tensor. +/// @tparam MPerWmma M size of Wave Matrix Multiply Accumulate (WMMA) instruction. +/// @tparam NPerWmma N size of Wave Matrix Multiply Accumulate (WMMA) instruction. +/// @tparam MRepeat The number of iterations in the M dimension over output tile per wavefront. +/// @tparam NRepeat The number of iterations in the N dimension over output tile per wavefront. 
+/// @tparam ABlockTransferThreadClusterLengths_AK0_M_AK1 Spatial thread distribution over the input +/// data. Can be interpreted as the answer +/// to the question, "How many threads can be +/// arranged on each input data axis?" +/// @tparam ABlockTransferThreadClusterArrangeOrder The order of thread spatial distribution over +/// the input tensor dimension. Can be interpreted +/// as the answer to the question: "In which +/// order to spread threads through tensor axes?". +/// @tparam ABlockTransferSrcAccessOrder The order of accessing input tensor axes. Can be +/// interpreted as the answer to the question "Which dimension +/// to read first? And which next?" etc. +/// @tparam ABlockTransferSrcVectorDim The index of axis on which we could do vectorized memory +/// access - the one with contiguous memory. +/// @tparam ABlockTransferSrcScalarPerVector The size of vector access instruction - the number of +/// elements accessed per thread per instruction. +/// @tparam ABlockTransferDstScalarPerVector_AK1 The size of vectorized store into LDS memory. +/// @tparam AThreadTransferSrcResetCoordinateAfterRun Decides whether we reset thread coordinate +/// (return back to the window origin) after all thread finish data copy. +/// @tparam ABlockLdsExtraM Whether to use padding for LDS or not. With +/// universal GEMM there's no need for padding. +/// @tparam BBlockTransferThreadClusterLengths_BK0_N_BK1 Spatial thread distribution over the input +/// data. Can be interpreted as the answer +/// to the question: "How many threads to +/// arrange on each input data axis?" +/// @tparam BBlockTransferThreadClusterArrangeOrder The order of thread spatial distribution over +/// the input tensor dimension. Can be interpreted +/// as the answer to the question: "In which +/// order to spread threads through tensor axes?". +/// @tparam BBlockTransferSrcAccessOrder he order of accessing input tensor axes. 
Can be +/// interpreted as the answer to the question "Which dimension +/// to read first? And which next?" etc. +/// @tparam BBlockTransferSrcVectorDim The index of axis on which we could do vectorized memory +/// access - the one with contiguous memory. +/// @tparam BBlockTransferSrcScalarPerVector The size of vector access instruction - the number of +/// elements accessed per thread per instruction. +/// @tparam BBlockTransferDstScalarPerVector_BK1 The size of vectorized store into LDS memory. +/// @tparam BThreadTransferSrcResetCoordinateAfterRun Decides whether we reset thread coordinate +/// (return back to the window origin) after all thread finish data copy. +/// @tparam BBlockLdsExtraN Whether to use padding for LDS or not. With universal GEMM +/// there's no need for padding. +/// @tparam CShuffleMRepeatPerShuffle The number of matrix-multiplication instructions +/// results to process per wave per iteration of CShuffle +/// in M dimension. +/// @tparam CShuffleNRepeatPerShuffle The number of matrix-multiplication instructions +/// results to process per wave per iteration of CShuffle +/// in N dimension. +/// @tparam CShuffleBlockTransferClusterLengths_MBlock_MPerBlock_NBlock_NPerBlock The spatial +/// thread distribution used for storing data into output +/// tensor across output data layout dimensions. +/// @tparam CShuffleBlockTransferScalarPerVector_NPerBlock The size of vectorized memory access. +/// Used when storing data to output tensor. +/// @tparam BlkGemmPipeSched The version of blockwise-gemm pipeline scheduler (interwave or +/// intrawave). +/// @tparam BlkGemmPipelineVer The version of blockwise-gemm pipeline. +/// @tparam ComputeTypeA Data type used for A input of hardware matrix-multiplication +/// instructions. +/// @tparam ComputeTypeB Data type used for B input of hardware matrix-multiplication +/// instructions. +/// @tparam PermuteA Whether the A input tensor has gridwise-gemm friendly data layout +/// in global memory. 
Currently not supported! +/// @tparam PermuteB Whether the B input tensor has gridwise-gemm friendly data layout +/// in global memory (pre-shuffled). +template +struct GridwiseGemm_wmma_cshuffle_v3 +{ + static constexpr auto I0 = Number<0>{}; + static constexpr auto I1 = Number<1>{}; + static constexpr auto I2 = Number<2>{}; + static constexpr auto I3 = Number<3>{}; + static constexpr auto I4 = Number<4>{}; + static constexpr auto I5 = Number<5>{}; + static constexpr auto I6 = Number<6>{}; + static constexpr auto I7 = Number<7>{}; + + // K1 should be Number<...> + static constexpr auto AK0Number = Number{}; + static constexpr auto BK0Number = Number{}; + static constexpr auto AK1Number = Number{}; + static constexpr auto BK1Number = Number{}; + + static constexpr index_t KPack = math::max( + math::lcm(AK1Number, BK1Number), + WmmaSelector::selected_wmma + .k_per_wmma); + + using ThisThreadBlock = ThisThreadBlock; + + static constexpr index_t APackedSize = []() { + if constexpr(is_same_v, pk_i4_t>) + return 2; + else + return 1; + }(); + + static constexpr index_t BPackedSize = []() { + if constexpr(is_same_v, pk_i4_t>) + return 2; + else + return 1; + }(); + + __host__ static auto CalculateGridSize(index_t M, index_t N, index_t KBatch) + { + return std::make_tuple(Block2CTileMap::CalculateGridSize(M, N), 1, KBatch); + } + + __host__ static auto CalculateMPadded(index_t M) + { + return math::integer_least_multiple(M, MPerBlock); + } + + __host__ static auto CalculateNPadded(index_t N) + { + return math::integer_least_multiple(N, NPerBlock); + } + + __host__ static auto CalculateKPadded(index_t K) + { + return math::integer_divide_ceil(K, KPerBlock) * KPerBlock; + } + + __host__ static auto CalculateAK0Padded(index_t K, index_t K_Batch = 1) + { + auto K_t = K_Batch * KPerBlock; + return (K + K_t - 1) / K_t * (KPerBlock / AK1Value); + } + + __host__ static auto CalculateBK0Padded(index_t K, index_t K_Batch = 1) + { + auto K_t = K_Batch * KPerBlock; + return (K + K_t 
- 1) / K_t * (KPerBlock / BK1Value); + } + + __host__ static auto CalculateKPadded(index_t K, index_t K_Batch = 1) + { + auto K_t = K_Batch * KPerBlock; + return (K + K_t - 1) / K_t * KPerBlock; + } + + __host__ static auto CalculateKRead(index_t K, index_t K_Batch = 1) + { + constexpr auto KReadVec = math::lcm(AK1Number, BK1Number); + auto K_t = K_Batch * KReadVec; + return (K + K_t - 1) / K_t * KReadVec; + } + + __host__ static auto CalculateMBlock(index_t M) + { + return math::integer_divide_ceil(M, MPerBlock); + } + + __host__ static auto CalculateNBlock(index_t N) + { + return math::integer_divide_ceil(N, NPerBlock); + } + + template + __host__ __device__ static constexpr auto MakeWmmaTileDescriptor(const BlockDesc&) + { + // K0_N_K1 -> K0_MNRepeat_MNWaves_MNPerWmma_K1 + constexpr auto K0 = BlockDesc{}.GetLength(I0); + constexpr auto K1 = BlockDesc{}.GetLength(I2); +#ifdef __gfx12__ + constexpr auto KRow = I2; +#else + constexpr auto KRow = I1; +#endif + return transform_tensor_descriptor( + BlockDesc{}, + make_tuple(make_unmerge_transform(make_tuple(Number{}, KRow)), + make_unmerge_transform( + make_tuple(Number{}, Number{}, Number{})), + make_pass_through_transform(Number{})), + make_tuple(Sequence<0>{}, Sequence<1>{}, Sequence<2>{}), + make_tuple(Sequence<0, 3>{}, Sequence<1, 2, 4>{}, Sequence<5>{})); + } + + __host__ __device__ static auto MakeAGridDescriptor_AK0_M_AK1( + index_t M, index_t MPad, index_t K, index_t KPad, index_t StrideA, index_t AK0) + { + const auto a_grid_desc_mraw_kraw = [&]() { + if constexpr(is_same_v) + { + return make_naive_tensor_descriptor(make_tuple(M, K), make_tuple(StrideA, I1)); + } + else if constexpr(is_same_v) + { + return make_naive_tensor_descriptor(make_tuple(M, K), make_tuple(I1, StrideA)); + } + }(); + + using GemmSpecialization = tensor_operation::device::GemmSpecialization; + + if constexpr(GemmSpec == GemmSpecialization::MKPadding || + GemmSpec == GemmSpecialization::MNKPadding) + { + // pad both M and K + const 
auto a_grid_desc_m_k = + transform_tensor_descriptor(a_grid_desc_mraw_kraw, + make_tuple(make_right_pad_transform(M, MPad - M), + make_right_pad_transform(K, KPad - K)), + make_tuple(Sequence<0>{}, Sequence<1>{}), + make_tuple(Sequence<0>{}, Sequence<1>{})); + + const auto a_grid_desc_ak0_m_ak1 = transform_tensor_descriptor( + a_grid_desc_m_k, + make_tuple(make_unmerge_transform(make_tuple(AK0, AK1Value)), + make_pass_through_transform(MPad)), + make_tuple(Sequence<1>{}, Sequence<0>{}), + make_tuple(Sequence<0, 2>{}, Sequence<1>{})); + + return a_grid_desc_ak0_m_ak1; + } + else if constexpr(GemmSpec == GemmSpecialization::MPadding || + GemmSpec == GemmSpecialization::MNPadding) + { + // pad M, but not K + const auto a_grid_desc_ak0_m_ak1 = transform_tensor_descriptor( + a_grid_desc_mraw_kraw, + make_tuple(make_unmerge_transform(make_tuple(AK0, AK1Value)), + make_right_pad_transform(M, MPad - M)), + make_tuple(Sequence<1>{}, Sequence<0>{}), + make_tuple(Sequence<0, 2>{}, Sequence<1>{})); + + return a_grid_desc_ak0_m_ak1; + } + else if constexpr(GemmSpec == GemmSpecialization::KPadding || + GemmSpec == GemmSpecialization::NKPadding) + { + // pad K, but not M + const auto a_grid_desc_m_k = transform_tensor_descriptor( + a_grid_desc_mraw_kraw, + make_tuple(make_pass_through_transform(M), make_right_pad_transform(K, KPad - K)), + make_tuple(Sequence<0>{}, Sequence<1>{}), + make_tuple(Sequence<0>{}, Sequence<1>{})); + + const auto a_grid_desc_ak0_m_ak1 = transform_tensor_descriptor( + a_grid_desc_m_k, + make_tuple(make_unmerge_transform(make_tuple(AK0, AK1Value)), + make_pass_through_transform(M)), + make_tuple(Sequence<1>{}, Sequence<0>{}), + make_tuple(Sequence<0, 2>{}, Sequence<1>{})); + + return a_grid_desc_ak0_m_ak1; + } + else + { + static_assert(!PermuteA, "PermuteA is not supported"); + + // not pad M or K + const auto a_grid_desc_ak0_m_ak1 = transform_tensor_descriptor( + a_grid_desc_mraw_kraw, + make_tuple(make_unmerge_transform(make_tuple(AK0, AK1Value)), + 
make_pass_through_transform(M)), + make_tuple(Sequence<1>{}, Sequence<0>{}), + make_tuple(Sequence<0, 2>{}, Sequence<1>{})); + + return a_grid_desc_ak0_m_ak1; + } + } + + __host__ __device__ static auto MakeBGridDescriptor_BK0_N_BK1( + index_t K, index_t KPad, index_t N, index_t NPad, index_t StrideB, index_t BK0) + { + const auto b_grid_desc_nraw_kraw = [&]() { + if constexpr(is_same::value) + { + return make_naive_tensor_descriptor(make_tuple(N, K), make_tuple(I1, StrideB)); + } + else if constexpr(is_same::value) + { + return make_naive_tensor_descriptor(make_tuple(N, K), make_tuple(StrideB, I1)); + } + }(); + + using GemmSpecialization = tensor_operation::device::GemmSpecialization; + + static_assert(!(is_same_v, pk_i4_t> && + GemmSpec != GemmSpecialization::Default), + "pk_i4_t does not support padding"); + + if constexpr(GemmSpec == GemmSpecialization::NKPadding || + GemmSpec == GemmSpecialization::MNKPadding) + { + // pad both N and K + const auto b_grid_desc_n_k = + transform_tensor_descriptor(b_grid_desc_nraw_kraw, + make_tuple(make_right_pad_transform(N, NPad - N), + make_right_pad_transform(K, KPad - K)), + make_tuple(Sequence<0>{}, Sequence<1>{}), + make_tuple(Sequence<0>{}, Sequence<1>{})); + + const auto b_grid_desc_bk0_n_bk1 = transform_tensor_descriptor( + b_grid_desc_n_k, + make_tuple(make_unmerge_transform(make_tuple(BK0, BK1Value)), + make_pass_through_transform(NPad)), + make_tuple(Sequence<1>{}, Sequence<0>{}), + make_tuple(Sequence<0, 2>{}, Sequence<1>{})); + + return b_grid_desc_bk0_n_bk1; + } + else if constexpr(GemmSpec == GemmSpecialization::NPadding || + GemmSpec == GemmSpecialization::MNPadding) + { + // pad N, but not K + const auto b_grid_desc_bk0_n_bk1 = transform_tensor_descriptor( + b_grid_desc_nraw_kraw, + make_tuple(make_unmerge_transform(make_tuple(BK0, BK1Value)), + make_right_pad_transform(N, NPad - N)), + make_tuple(Sequence<1>{}, Sequence<0>{}), + make_tuple(Sequence<0, 2>{}, Sequence<1>{})); + + return b_grid_desc_bk0_n_bk1; 
+ } + else if constexpr(GemmSpec == GemmSpecialization::KPadding || + GemmSpec == GemmSpecialization::MKPadding) + { + // pad K, but not N + const auto b_grid_desc_n_k = transform_tensor_descriptor( + b_grid_desc_nraw_kraw, + make_tuple(make_pass_through_transform(N), make_right_pad_transform(K, KPad - K)), + make_tuple(Sequence<0>{}, Sequence<1>{}), + make_tuple(Sequence<0>{}, Sequence<1>{})); + + const auto b_grid_desc_bk0_n_bk1 = transform_tensor_descriptor( + b_grid_desc_n_k, + make_tuple(make_unmerge_transform(make_tuple(BK0, BK1Value)), + make_pass_through_transform(N)), + make_tuple(Sequence<1>{}, Sequence<0>{}), + make_tuple(Sequence<0, 2>{}, Sequence<1>{})); + + return b_grid_desc_bk0_n_bk1; + } + else + { + if constexpr(!PermuteB) + { + // not pad N or K + const auto b_grid_desc_bk0_n_bk1 = transform_tensor_descriptor( + b_grid_desc_nraw_kraw, + make_tuple(make_unmerge_transform(make_tuple(BK0, BK1Value)), + make_pass_through_transform(N)), + make_tuple(Sequence<1>{}, Sequence<0>{}), + make_tuple(Sequence<0, 2>{}, Sequence<1>{})); + + return b_grid_desc_bk0_n_bk1; + } + else + { + // Pre-shuffled Weight + // BGlobal[K / KPerBlock, N, KPerBlock / K1, K1] -> BTile[K / K1, N, K1] + constexpr index_t BK01 = KPerBlock / BK1Value; + const index_t BK0_ = StrideB / BK1Value; + const index_t BK00 = BK0_ / BK01; + + const auto b_grid_desc_bk00_n_bk01_bk1_permute = + make_naive_tensor_descriptor_packed(make_tuple(BK00, N, BK01, BK1Value)); + + const auto b_grid_desc_bk0_n_bk1_permute = transform_tensor_descriptor( + b_grid_desc_bk00_n_bk01_bk1_permute, + make_tuple(make_merge_transform(make_tuple(BK00, BK01)), + make_pass_through_transform(make_tuple(N)), + make_pass_through_transform(BK1Value)), + make_tuple(Sequence<0, 2>{}, Sequence<1>{}, Sequence<3>{}), + make_tuple(Sequence<0>{}, Sequence<1>{}, Sequence<2>{})); + + return b_grid_desc_bk0_n_bk1_permute; + } + } + } + + template + __host__ __device__ static constexpr auto MakeAWmmaTileDescriptor(const 
ABlockDesc_AK0_M_AK1&) + { + constexpr index_t MWaves = MPerBlock / (MRepeat * MPerWmma); + + return MakeWmmaTileDescriptor(ABlockDesc_AK0_M_AK1{}); + } + + template + __host__ __device__ static constexpr auto MakeBWmmaTileDescriptor(const BBlockDesc_BK0_N_BK1&) + { + constexpr index_t NWaves = NPerBlock / (NRepeat * NPerWmma); + + return MakeWmmaTileDescriptor(BBlockDesc_BK0_N_BK1{}); + } + + __host__ __device__ static auto + MakeCGridDescriptor_M_N(index_t M, index_t MPad, index_t N, index_t NPad, index_t StrideC) + { + const auto c_grid_desc_mraw_nraw = [&]() { + if constexpr(is_same::value) + { + return make_naive_tensor_descriptor(make_tuple(M, N), make_tuple(StrideC, I1)); + } + else if constexpr(is_same::value) + { + return make_naive_tensor_descriptor(make_tuple(M, N), make_tuple(I1, StrideC)); + } + }(); + + // pad M and N + return transform_tensor_descriptor(c_grid_desc_mraw_nraw, + make_tuple(make_right_pad_transform(M, MPad - M), + make_right_pad_transform(N, NPad - N)), + make_tuple(Sequence<0>{}, Sequence<1>{}), + make_tuple(Sequence<0>{}, Sequence<1>{})); + // TODO: Investigate why this path is not used in the original + // gridwise_gemm_xdl_cshuffle_v3.hpp +#if 0 + using GemmSpecialization = tensor_operation::device::GemmSpecialization; + + if constexpr(GemmSpec == GemmSpecialization::MNPadding || + GemmSpec == GemmSpecialization::MNKPadding) + { + // pad M and N + return transform_tensor_descriptor(c_grid_desc_mraw_nraw, + make_tuple(make_right_pad_transform(M, MPad - M), + make_right_pad_transform(N, NPad - N)), + make_tuple(Sequence<0>{}, Sequence<1>{}), + make_tuple(Sequence<0>{}, Sequence<1>{})); + } + else if constexpr(GemmSpec == GemmSpecialization::MPadding || + GemmSpec == GemmSpecialization::MKPadding) + { + // pad M, but not N + return transform_tensor_descriptor( + c_grid_desc_mraw_nraw, + make_tuple(make_right_pad_transform(M, MPad - M), make_pass_through_transform(N)), + make_tuple(Sequence<0>{}, Sequence<1>{}), + 
make_tuple(Sequence<0>{}, Sequence<1>{})); + } + else if constexpr(GemmSpec == GemmSpecialization::NPadding || + GemmSpec == GemmSpecialization::NKPadding) + { + // pad N, but not M + return transform_tensor_descriptor( + c_grid_desc_mraw_nraw, + make_tuple(make_pass_through_transform(M), make_right_pad_transform(N, NPad - N)), + make_tuple(Sequence<0>{}, Sequence<1>{}), + make_tuple(Sequence<0>{}, Sequence<1>{})); + } + else + { + // not pad M or N + return c_grid_desc_mraw_nraw; + } +#endif + } + + struct Problem + { + __host__ Problem(index_t M_, + index_t N_, + index_t K_, + index_t StrideA_, + index_t StrideB_, + index_t StrideC_, + index_t KBatch_) + : M{M_}, + N{N_}, + K{K_}, + StrideA{StrideA_}, + StrideB{StrideB_}, + StrideC{StrideC_}, + KBatch{KBatch_}, + MPadded{CalculateMPadded(M_)}, + NPadded{CalculateNPadded(N_)}, + KRead{CalculateKRead(K_, KBatch_)}, + KPadded{CalculateKPadded(K_, KBatch_)}, + AK0{CalculateAK0Padded(K_, KBatch_)}, + BK0{CalculateBK0Padded(K_, KBatch_)}, + MBlock{CalculateMBlock(M_)}, + NBlock{CalculateNBlock(N_)} + { + } + + __host__ void Print() const + { + std::cout << "problem {" + << "M:" << M << ", " + << "N:" << N << ", " + << "K:" << K << ", " + << "SA:" << StrideA << ", " + << "SB:" << StrideB << ", " + << "SC:" << StrideC << ", " + << "MP:" << MPadded << ", " + << "NP:" << NPadded << ", " + << "KRead:" << KRead << ", " + << "KP:" << KPadded << ", " + << "AK0:" << AK0 << ", " + << "BK0:" << BK0 << ", " + << "MBlock: " << MBlock << ", " + << "NBlock: " << NBlock << "}" << std::endl; + } + + index_t M; + index_t N; + index_t K; + index_t StrideA; + index_t StrideB; + index_t StrideC; + index_t KBatch; + index_t MPadded; + index_t NPadded; + index_t KRead; + index_t KPadded; + index_t AK0; + index_t BK0; + index_t MBlock; + index_t NBlock; + }; + + // Argument + struct Argument : public tensor_operation::device::BaseArgument, public Problem + { + __host__ Argument(const ADataType* p_a_grid_, + const BDataType* p_b_grid_, + 
CDataType* p_c_grid_, + index_t M_, + index_t N_, + index_t K_, + index_t StrideA_, + index_t StrideB_, + index_t StrideC_, + index_t k_batch_, + bool is_reduce_ = false) + : Problem{M_, N_, K_, StrideA_, StrideB_, StrideC_, k_batch_}, + p_a_grid{p_a_grid_}, + p_b_grid{p_b_grid_}, + p_c_grid{p_c_grid_}, + is_reduce(is_reduce_) + { + } + + __host__ __device__ inline bool IsReduceAdd() const + { + return (Problem::KBatch > 1) && is_reduce; + } + + __host__ __device__ inline bool IsAtomicAdd() const + { + return (Problem::KBatch > 1) && (!is_reduce); + } + + const ADataType* p_a_grid; + const BDataType* p_b_grid; + CDataType* p_c_grid; + bool is_reduce; + }; + + struct SplitKBatchOffset + { + + __device__ SplitKBatchOffset(Argument& karg) + { + if constexpr(is_same_v) + { + a_k_split_offset = blockIdx.z * karg.KRead / APackedSize; + } + else if constexpr(is_same_v) + { + a_k_split_offset = blockIdx.z * karg.KRead * karg.StrideA; + } + + if constexpr(is_same_v) + { + b_k_split_offset = blockIdx.z * karg.KRead * karg.StrideB; + } + else if constexpr(is_same_v) + { + if constexpr(!PermuteB) + { + b_k_split_offset = blockIdx.z * karg.KRead / BPackedSize; + } + else + { + const int k0_offset = karg.KRead * karg.N; + b_k_split_offset = blockIdx.z * k0_offset / BPackedSize; + } + } + + if(blockIdx.z < static_cast(karg.KBatch - 1)) + { + karg.K = karg.KRead; + } + else + { + karg.K = karg.K - karg.KRead * (karg.KBatch - 1); + } + + if(karg.IsReduceAdd()) + { + c_reduce_offset = blockIdx.z * karg.M * karg.N; + } + else + { + c_reduce_offset = 0; + } + } + + index_t a_k_split_offset; + index_t b_k_split_offset; + index_t c_reduce_offset; + }; + + __device__ static constexpr auto GetABlockDescriptor_AK0PerBlock_MPerBlock_AK1() + { + // A matrix in LDS memory, dst of blockwise copy + if constexpr(ABlockLdsExtraM || BlkGemmPipelineVer == BlockGemmPipelineVersion::v4) + { + // bank conflict when writting the data into LDS, but don't worry, we have whole entire + // loop to hide it 
in v4. it may give you some benefit from less valu in compute address + return make_naive_tensor_descriptor( + make_tuple(AK0Number, Number{}, AK1Number), + make_tuple(Number{} * AK1Number, AK1Number, I1)); + } + // xor tensor transformation request more unnecessary vgpr usage, would cause register spill + // in some cases. + else if constexpr(is_same::value) + { + constexpr index_t LdsSize = 32 * 4 / KPerBlock / sizeof(ADataType) / APackedSize; + constexpr auto MLdsLayer = LdsSize < 1 ? 1 : LdsSize; + constexpr auto a_lds_block_desc = make_naive_tensor_descriptor( + make_tuple( + AK0Number * Number{}, Number{}, AK1Number), + make_tuple(AK1Number, Number{}, I1)); + + constexpr auto a_lds_block_desc_permuted = transform_tensor_descriptor( + a_lds_block_desc, + make_tuple(make_xor_with_modulo_transform(make_tuple( + Number{}, Number{})), + make_pass_through_transform(AK1Number)), + make_tuple(Sequence<1, 0>{}, Sequence<2>{}), + make_tuple(Sequence<1, 0>{}, Sequence<2>{})); + + constexpr auto a_lds_block_desc_ak0_mldslayer_m_ak1 = transform_tensor_descriptor( + a_lds_block_desc_permuted, + make_tuple(make_unmerge_transform(make_tuple(AK0Number, Number{})), + make_pass_through_transform(Number{}), + make_pass_through_transform(AK1Number)), + make_tuple(Sequence<0>{}, Sequence<1>{}, Sequence<2>{}), + make_tuple(Sequence<0, 2>{}, Sequence<1>{}, Sequence<3>{})); + + constexpr auto a_lds_block_desc_ak0_m_ak1 = transform_tensor_descriptor( + a_lds_block_desc_ak0_mldslayer_m_ak1, + make_tuple(make_pass_through_transform(AK0Number), + make_merge_transform_v3_division_mod( + make_tuple(Number{}, Number{})), + make_pass_through_transform(AK1Number)), + make_tuple(Sequence<0>{}, Sequence<1, 2>{}, Sequence<3>{}), + make_tuple(Sequence<0>{}, Sequence<1>{}, Sequence<2>{})); + + return a_lds_block_desc_ak0_m_ak1; + } + else // ColumnMajor A + { + // kfold and mpair dimension is not always required. 
+ // more dimension in merge_transform increase the difficulty of generating immarg offset + // for compiler. + constexpr auto M0 = ABlockTransferThreadClusterLengths_AK0_M_AK1{}.At(I1); + constexpr auto M1 = MPerBlock / M0; + + constexpr auto KThreadWrite = ABlockTransferThreadClusterLengths_AK0_M_AK1{}.At(I0); + constexpr auto K0PerThreadWrite = AK0Number / KThreadWrite; + constexpr auto KThreadRead = 64 / MPerWmma; + constexpr auto K0PerThreadRead = AK0Number / KThreadRead; + + constexpr auto kfold = (AK1Number * M0 * sizeof(ADataType) > 128) + ? 1 + : 128 / (AK1Number * M0 * sizeof(ADataType)); + constexpr auto KThreadReadPerm = + (kfold * K0PerThreadWrite / K0PerThreadRead) > 1 + ? KThreadRead / (kfold * K0PerThreadWrite / K0PerThreadRead) + : KThreadRead; + + // 1<=mpair<=n0 + constexpr auto mpair = (AK1Number * MPerWmma * sizeof(ADataType) > 128) + ? 1 + : ((128 / (AK1Number * MPerWmma * sizeof(ADataType))) > M0 + ? M0 + : 128 / (AK1Number * MPerWmma * sizeof(ADataType))); + + constexpr auto a_lds_block_desc = make_naive_tensor_descriptor_packed( + make_tuple(Number{}, + Number{}, + Number{}, + Number{}, + Number{}, + AK1Number)); + + constexpr auto a_lds_block_desc_permuted = transform_tensor_descriptor( + a_lds_block_desc, + make_tuple( + make_pass_through_transform(Number{}), + make_pass_through_transform(Number{}), + make_xor_with_modulo_transform( + make_tuple(Number{}, Number{})), + make_pass_through_transform(Number{}), + make_pass_through_transform(AK1Number)), + make_tuple( + Sequence<0>{}, Sequence<1>{}, Sequence<2, 3>{}, Sequence<4>{}, Sequence<5>{}), + make_tuple( + Sequence<0>{}, Sequence<1>{}, Sequence<2, 3>{}, Sequence<4>{}, Sequence<5>{})); + + constexpr auto a_lds_block_desc_unmerged = transform_tensor_descriptor( + a_lds_block_desc_permuted, + make_tuple( + make_pass_through_transform(Number{}), + make_pass_through_transform(Number{}), + make_unmerge_transform(make_tuple(Number{}, Number{})), + make_unmerge_transform(make_tuple(Number{}, 
Number{})), + make_pass_through_transform(Number{}), + make_pass_through_transform(AK1Number)), + make_tuple(Sequence<0>{}, + Sequence<1>{}, + Sequence<2>{}, + Sequence<3>{}, + Sequence<4>{}, + Sequence<5>{}), + make_tuple(Sequence<1>{}, + Sequence<2>{}, + Sequence<0, 3>{}, + Sequence<4, 5>{}, + Sequence<6>{}, + Sequence<7>{})); + + constexpr auto a_lds_block_desc_ak0_m_ak1 = transform_tensor_descriptor( + a_lds_block_desc_unmerged, + make_tuple(make_merge_transform_v3_division_mod( + make_tuple(Number{}, + Number{}, + Number{}, + Number{})), + make_merge_transform_v3_division_mod( + make_tuple(Number{}, Number{}, Number{})), + make_pass_through_transform(AK1Number)), + make_tuple(Sequence<0, 1, 4, 2>{}, Sequence<5, 6, 3>{}, Sequence<7>{}), + make_tuple(Sequence<0>{}, Sequence<1>{}, Sequence<2>{})); + + return a_lds_block_desc_ak0_m_ak1; + } + } + + __device__ static constexpr auto GetBBlockDescriptor_BK0PerBlock_NPerBlock_BK1() + { + // B matrix in LDS memory, dst of blockwise copy + if constexpr(BBlockLdsExtraN || BlkGemmPipelineVer == BlockGemmPipelineVersion::v4) + { + // bank conflict when writting the data into LDS, but don't worry, we have whole entire + // loop to hide it in v4. it may give you some benefit from less valu in compute address + return make_naive_tensor_descriptor( + make_tuple(BK0Number, Number{}, BK1Number), + make_tuple(Number{} * BK1Number, BK1Number, I1)); + } + else if constexpr(is_same::value) + { + // NLdsLayer * K0 as logical Bank + constexpr index_t LdsSize = 32 * 4 / KPerBlock / sizeof(BDataType) / BPackedSize; + constexpr index_t NLdsLayer = LdsSize < 1 ? 
1 : LdsSize; + constexpr auto b_lds_block_desc = make_naive_tensor_descriptor( + make_tuple( + BK0Number * Number{}, Number{}, BK1Number), + make_tuple(BK1Number, Number{}, I1)); + + constexpr auto b_lds_block_desc_permuted = transform_tensor_descriptor( + b_lds_block_desc, + make_tuple(make_xor_with_modulo_transform(make_tuple( + Number{}, Number{})), + make_pass_through_transform(BK1Number)), + make_tuple(Sequence<1, 0>{}, Sequence<2>{}), + make_tuple(Sequence<1, 0>{}, Sequence<2>{})); + + constexpr auto b_lds_block_desc_bk0_nldslayer_n_bk1 = transform_tensor_descriptor( + b_lds_block_desc_permuted, + make_tuple(make_unmerge_transform(make_tuple(BK0Number, Number{})), + make_pass_through_transform(Number{}), + make_pass_through_transform(BK1Number)), + make_tuple(Sequence<0>{}, Sequence<1>{}, Sequence<2>{}), + make_tuple(Sequence<0, 2>{}, Sequence<1>{}, Sequence<3>{})); + + constexpr auto b_lds_block_desc_bk0_n_bk1 = transform_tensor_descriptor( + b_lds_block_desc_bk0_nldslayer_n_bk1, + make_tuple(make_pass_through_transform(BK0Number), + make_merge_transform_v3_division_mod( + make_tuple(Number{}, Number{})), + make_pass_through_transform(BK1Number)), + make_tuple(Sequence<0>{}, Sequence<1, 2>{}, Sequence<3>{}), + make_tuple(Sequence<0>{}, Sequence<1>{}, Sequence<2>{})); + + return b_lds_block_desc_bk0_n_bk1; + } + else // RowMajor B + { + constexpr auto N0 = BBlockTransferThreadClusterLengths_BK0_N_BK1{}.At(I1); + constexpr auto N1 = NPerBlock / N0; + + constexpr auto KThreadWrite = BBlockTransferThreadClusterLengths_BK0_N_BK1{}.At(I0); + constexpr auto K0PerThreadWrite = BK0Number / KThreadWrite; + constexpr auto KThreadRead = 64 / NPerWmma; + constexpr auto K0PerThreadRead = BK0Number / KThreadRead; + + constexpr auto kfold = (BK1Number * N0 * sizeof(BDataType) > 128) + ? 1 + : 128 / (BK1Number * N0 * sizeof(BDataType)); + constexpr auto KThreadReadPerm = + (kfold * K0PerThreadWrite / K0PerThreadRead) > 1 + ? 
KThreadRead / (kfold * K0PerThreadWrite / K0PerThreadRead) + : KThreadRead; + + // 1<=npair<=n0 + constexpr auto npair = (BK1Number * NPerWmma * sizeof(BDataType) > 128) + ? 1 + : ((128 / (BK1Number * NPerWmma * sizeof(BDataType))) > N0 + ? N0 + : 128 / (BK1Number * NPerWmma * sizeof(BDataType))); + + constexpr auto b_lds_block_desc = make_naive_tensor_descriptor_packed( + make_tuple(Number{}, + Number{}, + Number{}, + Number{}, + Number{}, + BK1Number)); + + constexpr auto b_lds_block_desc_permuted = transform_tensor_descriptor( + b_lds_block_desc, + make_tuple( + make_pass_through_transform(Number{}), + make_pass_through_transform(Number{}), + make_xor_with_modulo_transform( + make_tuple(Number{}, Number{})), + make_pass_through_transform(Number{}), + make_pass_through_transform(BK1Number)), + make_tuple( + Sequence<0>{}, Sequence<1>{}, Sequence<2, 3>{}, Sequence<4>{}, Sequence<5>{}), + make_tuple( + Sequence<0>{}, Sequence<1>{}, Sequence<2, 3>{}, Sequence<4>{}, Sequence<5>{})); + + constexpr auto b_lds_block_desc_unmerged = transform_tensor_descriptor( + b_lds_block_desc_permuted, + make_tuple( + make_pass_through_transform(Number{}), + make_pass_through_transform(Number{}), + make_unmerge_transform(make_tuple(Number{}, Number{})), + make_unmerge_transform(make_tuple(Number{}, Number{})), + make_pass_through_transform(Number{}), + make_pass_through_transform(BK1Number)), + make_tuple(Sequence<0>{}, + Sequence<1>{}, + Sequence<2>{}, + Sequence<3>{}, + Sequence<4>{}, + Sequence<5>{}), + make_tuple(Sequence<1>{}, + Sequence<2>{}, + Sequence<0, 3>{}, + Sequence<4, 5>{}, + Sequence<6>{}, + Sequence<7>{})); + + constexpr auto b_lds_block_desc_bk0_n_bk1 = transform_tensor_descriptor( + b_lds_block_desc_unmerged, + make_tuple(make_merge_transform_v3_division_mod( + make_tuple(Number{}, + Number{}, + Number{}, + Number{})), + make_merge_transform_v3_division_mod( + make_tuple(Number{}, Number{}, Number{})), + make_pass_through_transform(BK1Number)), + 
make_tuple(Sequence<0, 1, 4, 2>{}, Sequence<5, 6, 3>{}, Sequence<7>{}), + make_tuple(Sequence<0>{}, Sequence<1>{}, Sequence<2>{})); + + return b_lds_block_desc_bk0_n_bk1; + } + } + + __host__ __device__ static constexpr auto + // *Caution Here repeat is shuffle repeat + GetCShuffleBlockDescriptor_MShRepeat_MPerShRepeat_NShRepeat_NPerShRepeat() + { + constexpr index_t MWaves = MPerBlock / (MRepeat * MPerWmma); + constexpr index_t NWaves = NPerBlock / (NRepeat * NPerWmma); + + constexpr auto c_shuffle_block_desc_mshrepeat_mpershrepeat_nshrepeat_npershrepeat = + make_naive_tensor_descriptor_packed( + make_tuple(I1, + Number{}, + I1, + Number{})); + + return c_shuffle_block_desc_mshrepeat_mpershrepeat_nshrepeat_npershrepeat; + } + + using BlockwiseGemmPipe = remove_cvref_t< + decltype(BlockGemmPipeline_Selector< + BlkGemmPipelineVer, + BlkGemmPipeSched, + BlockSize, + ADataType, + BDataType, + ComputeTypeA, + ComputeTypeB, + AccDataType, + decltype(MakeAWmmaTileDescriptor(GetABlockDescriptor_AK0PerBlock_MPerBlock_AK1())), + decltype(MakeBWmmaTileDescriptor(GetBBlockDescriptor_BK0PerBlock_NPerBlock_BK1())), + ABlockTransferSrcScalarPerVector, + BBlockTransferSrcScalarPerVector, + MPerBlock, + NPerBlock, + KPerBlock, + MPerWmma, + NPerWmma, + MRepeat, + NRepeat, + KPack>())>; + + __device__ static constexpr index_t GetSharedMemoryNumberOfByte() + { + // LDS allocation for A and B: be careful of alignment + constexpr auto a_block_desc_ak0_m_ak1 = GetABlockDescriptor_AK0PerBlock_MPerBlock_AK1(); + constexpr auto b_block_desc_bk0_n_bk1 = GetBBlockDescriptor_BK0PerBlock_NPerBlock_BK1(); + + // lds max alignment + constexpr auto max_lds_align = math::lcm(AK1Number, BK1Number); + + constexpr auto a_block_space_size_aligned = math::integer_least_multiple( + a_block_desc_ak0_m_ak1.GetElementSpaceSize(), max_lds_align); + + constexpr auto b_block_space_size_aligned = math::integer_least_multiple( + b_block_desc_bk0_n_bk1.GetElementSpaceSize(), max_lds_align); + + // LDS 
allocation for C shuffle in LDS + constexpr auto c_shuffle_block_desc_mshrepeat_mpershrepeat_nshrepeat_npershrepeat = + GetCShuffleBlockDescriptor_MShRepeat_MPerShRepeat_NShRepeat_NPerShRepeat(); + + constexpr auto c_block_size = + c_shuffle_block_desc_mshrepeat_mpershrepeat_nshrepeat_npershrepeat + .GetElementSpaceSize(); + + return math::max((a_block_space_size_aligned * sizeof(ADataType) / APackedSize + + b_block_space_size_aligned * sizeof(BDataType) / BPackedSize), + c_block_size * sizeof(CShuffleDataType)); + } + + // block_id to matrix tile idx (m0, n0) mapping are controlled by {M01, N01} + __host__ static constexpr bool CheckValidity(const Argument& karg) + { + static_assert((MPerBlock % (MPerWmma * MRepeat) == 0) && + (NPerBlock % (NPerWmma * NRepeat)) == 0, + "Invalid tuning param!"); + + if constexpr(!(GemmSpec == tensor_operation::device::GemmSpecialization::MPadding || + GemmSpec == tensor_operation::device::GemmSpecialization::MNPadding || + GemmSpec == tensor_operation::device::GemmSpecialization::MKPadding || + GemmSpec == tensor_operation::device::GemmSpecialization::MNKPadding) && + !(is_same::value)) + { + if(!(karg.M % MPerBlock == 0)) + { + if(ck::EnvIsEnabled(CK_ENV(CK_LOGGING))) + { + std::cout << "Arg M value is not a multiple of MPerBlock! M: " << karg.M << " " + << __FILE__ << ":" << __LINE__ << ", in function: " << __func__ + << std::endl; + } + return false; + } + } + + if constexpr(!(GemmSpec == tensor_operation::device::GemmSpecialization::NPadding || + GemmSpec == tensor_operation::device::GemmSpecialization::MNPadding || + GemmSpec == tensor_operation::device::GemmSpecialization::NKPadding || + GemmSpec == tensor_operation::device::GemmSpecialization::MNKPadding) && + (is_same::value)) + { + if(!(karg.N % NPerBlock == 0)) + { + if(ck::EnvIsEnabled(CK_ENV(CK_LOGGING))) + { + std::cout << "Arg N value is not a multiple of NPerBlock! 
N: " << karg.N << " " + << __FILE__ << ":" << __LINE__ << ", in function: " << __func__ + << std::endl; + } + return false; + } + } + + if constexpr(!(GemmSpec == tensor_operation::device::GemmSpecialization::KPadding || + GemmSpec == tensor_operation::device::GemmSpecialization::MKPadding || + GemmSpec == tensor_operation::device::GemmSpecialization::NKPadding || + GemmSpec == tensor_operation::device::GemmSpecialization::MNKPadding)) + { + + auto K_t = karg.KBatch * KPerBlock; + if(!(karg.K % K_t == 0)) + { + if(ck::EnvIsEnabled(CK_ENV(CK_LOGGING))) + { + std::cout << "Arg K value is not a multiple of K_Batch * K0PerBlock * K1! K: " + << karg.K << " " << __FILE__ << ":" << __LINE__ + << ", in function: " << __func__ << std::endl; + } + return false; + } + } + else + { + constexpr auto KReadVec = math::lcm(AK1Number, BK1Number); + auto K_t = karg.KBatch * KReadVec; + auto KReadPadSplited = math::integer_divide_ceil(karg.K, K_t) * KReadVec; + if((KReadPadSplited * (karg.KBatch - 1)) >= karg.K) + { + return false; + } + } + + if constexpr(is_same::value) + { + if(karg.K % ABlockTransferSrcScalarPerVector != 0) + { + if(ck::EnvIsEnabled(CK_ENV(CK_LOGGING))) + { + std::cout << "Arg K (" << karg.K + << ") value is not a multiple of ABlockTransferSrcScalarPerVector (" + << ABlockTransferSrcScalarPerVector << " )! " << __FILE__ << ":" + << __LINE__ << ", in function: " << __func__ << std::endl; + } + return false; + } + } + else + { + if(karg.M % ABlockTransferSrcScalarPerVector != 0) + { + if(ck::EnvIsEnabled(CK_ENV(CK_LOGGING))) + { + std::cout << "Arg M (" << karg.M + << ") value is not a multiple of ABlockTransferSrcScalarPerVector (" + << ABlockTransferSrcScalarPerVector << " )! 
" << __FILE__ << ":" + << __LINE__ << ", in function: " << __func__ << std::endl; + } + return false; + } + } + + if constexpr(is_same::value) + { + if(karg.N % BBlockTransferSrcScalarPerVector != 0) + { + if(ck::EnvIsEnabled(CK_ENV(CK_LOGGING))) + { + std::cout << "Arg N (" << karg.N + << ") value is not a multiple of BBlockTransferSrcScalarPerVector (" + << BBlockTransferSrcScalarPerVector << " )! " << __FILE__ << ":" + << __LINE__ << ", in function: " << __func__ << std::endl; + } + return false; + } + } + else + { + if(karg.K % BBlockTransferSrcScalarPerVector != 0) + { + if(ck::EnvIsEnabled(CK_ENV(CK_LOGGING))) + { + std::cout << "Arg K (" << karg.K + << ") value is not a multiple of BBlockTransferSrcScalarPerVector (" + << BBlockTransferSrcScalarPerVector << " )! " << __FILE__ << ":" + << __LINE__ << ", in function: " << __func__ << std::endl; + } + return false; + } + } + + if constexpr(is_same::value) + { + if(karg.N % CShuffleBlockTransferScalarPerVector_NPerBlock != 0) + { + if(ck::EnvIsEnabled(CK_ENV(CK_LOGGING))) + { + std::cout << "Arg N (" << karg.N + << ") value is not a multiple of " + "CShuffleBlockTransferScalarPerVector_NPerBlock (" + << CShuffleBlockTransferScalarPerVector_NPerBlock << " )! " + << __FILE__ << ":" << __LINE__ << ", in function: " << __func__ + << std::endl; + } + return false; + } + } + else + { + if(karg.M % CShuffleBlockTransferScalarPerVector_NPerBlock != 0) + { + if(ck::EnvIsEnabled(CK_ENV(CK_LOGGING))) + { + std::cout << "Arg M (" << karg.M + << ") value is not a multiple of " + "CShuffleBlockTransferScalarPerVector_NPerBlock (" + << CShuffleBlockTransferScalarPerVector_NPerBlock << " )! 
" + << __FILE__ << ":" << __LINE__ << ", in function: " << __func__ + << std::endl; + } + return false; + } + } + + if constexpr(!(is_same, half_t>::value || + is_same, float>::value || + is_same, bhalf_t>::value || + is_same, int32_t>::value)) + { + if(!karg.IsReduceAdd()) + { + if(ck::EnvIsEnabled(CK_ENV(CK_LOGGING))) + { + std::cout << " KBatch: " << karg.KBatch << " > 1 is not supported yet" + << __FILE__ << ":" << __LINE__ << ", in function: " << __func__ + << std::endl; + } + if(karg.KBatch > 1) + { + return false; + } + } + } + + // check gridwise gemm pipeline + const auto num_k_loop = karg.AK0 / (KPerBlock / AK1Value); + + if constexpr(BlkGemmPipelineVer != BlockGemmPipelineVersion::v1) + { + if(num_k_loop <= BlockwiseGemmPipe::PrefetchStages) + { + return false; + } + } + + // TODO: also check validity of all components (blockwise-copy, threadwise-copy, etc) + return true; + } + + __host__ static constexpr bool CalculateHasMainKBlockLoop(index_t K) + { + const index_t num_loop = K / KPerBlock; + + return BlockwiseGemmPipe::BlockHasHotloop(num_loop); + } + + __host__ static constexpr TailNumber CalculateKBlockLoopTailNum(index_t K) + { + const index_t num_loop = K / KPerBlock; + + return BlockwiseGemmPipe::BlockLoopTailNum(num_loop); + } + + template + __host__ __device__ static constexpr auto MakeCGridDescriptor_MBlock_MPerBlock_NBlock_NPerBlock( + const CGridDesc& c_grid_desc_m_n, index_t MBlock, index_t NBlock) + { + const auto c_grid_desc_mblock_mperblock_nblock_nperblock = transform_tensor_descriptor( + c_grid_desc_m_n, + make_tuple(make_unmerge_transform(make_tuple(MBlock, Number{})), + make_unmerge_transform(make_tuple(NBlock, Number{}))), + make_tuple(Sequence<0>{}, Sequence<1>{}), + make_tuple(Sequence<0, 1>{}, Sequence<2, 3>{})); + + return c_grid_desc_mblock_mperblock_nblock_nperblock; + } + + // return block_id to C matrix tile idx (m0, n0) mapping + // if arch = gfx942 + using Block2CTileMap = BlockToCTileMap_Grouped_M00_N0_M01Adapt<8, 
MPerBlock, NPerBlock>; + // using Block2CTileMap = BlockToCTileMap_3DGrid_KSplit; + + template + __device__ static void Run(const ADataType* p_a_grid, + const BDataType* p_b_grid, + CDataType* p_c_grid, + void* p_shared, + const Problem& problem, + const AGridDesc_AK0_M_K1& a_grid_desc_ak0_m_ak1, + const BGridDesc_BK0_N_K1& b_grid_desc_bk0_n_bk1, + const CGridDesc_MBlock_MPerBlock_NBlock_NPerBlock& + c_grid_desc_mblock_mperblock_nblock_nperblock) + { + const auto a_grid_buf = make_dynamic_buffer( + p_a_grid, a_grid_desc_ak0_m_ak1.GetElementSpaceSize()); + const auto b_grid_buf = make_dynamic_buffer( + p_b_grid, b_grid_desc_bk0_n_bk1.GetElementSpaceSize()); + auto c_grid_buf = make_dynamic_buffer( + p_c_grid, c_grid_desc_mblock_mperblock_nblock_nperblock.GetElementSpaceSize()); + + const AElementwiseOperation a_element_op{}; + const BElementwiseOperation b_element_op{}; + const CElementwiseOperation c_element_op{}; + + // divide block work by [M, N] + const auto block_2_ctile_map = Block2CTileMap{problem.M, problem.N, 4}; + + const auto block_work_idx = + block_2_ctile_map.CalculateBottomIndex(make_multi_index(get_block_1d_id())); + + if(!block_2_ctile_map.ValidCTileIndex( + block_work_idx, + make_tuple(c_grid_desc_mblock_mperblock_nblock_nperblock.GetLength(I0), + c_grid_desc_mblock_mperblock_nblock_nperblock.GetLength(I2)))) + { + return; + } + + const index_t block_m_id = __builtin_amdgcn_readfirstlane(block_work_idx[I0]); + const index_t block_n_id = __builtin_amdgcn_readfirstlane(block_work_idx[I1]); + + // HACK: this force m/n_block_data_idx_on_grid into SGPR + const index_t m_block_data_idx_on_grid = + __builtin_amdgcn_readfirstlane(block_m_id * MPerBlock); + + const index_t n_block_data_idx_on_grid = + __builtin_amdgcn_readfirstlane(block_n_id * NPerBlock); + + // lds max alignment + constexpr auto max_lds_align = math::lcm(AK1Number, BK1Number); + + // A matrix in LDS memory, dst of blockwise copy + constexpr auto a_block_desc_ak0_m_ak1 = 
GetABlockDescriptor_AK0PerBlock_MPerBlock_AK1(); + + // B matrix in LDS memory, dst of blockwise copy + constexpr auto b_block_desc_bk0_n_bk1 = GetBBlockDescriptor_BK0PerBlock_NPerBlock_BK1(); + + // A matrix blockwise copy + auto a_blockwise_copy = + ThreadGroupTensorSliceTransfer_v4r1, + ABlockTransferThreadClusterLengths_AK0_M_AK1, + ABlockTransferThreadClusterArrangeOrder, + ADataType, + ADataType, + decltype(a_grid_desc_ak0_m_ak1), + decltype(a_block_desc_ak0_m_ak1), + ABlockTransferSrcAccessOrder, + Sequence<0, 1, 2>, + ABlockTransferSrcVectorDim, + 2, + ABlockTransferSrcScalarPerVector, + ABlockTransferDstScalarPerVector_AK1, + 1, + 1, + AThreadTransferSrcResetCoordinateAfterRun, + true, + BlockwiseGemmPipe::GlobalBufferNum>( + a_grid_desc_ak0_m_ak1, + make_multi_index(0, m_block_data_idx_on_grid, 0), + a_element_op, + a_block_desc_ak0_m_ak1, + make_multi_index(0, 0, 0), + ck::tensor_operation::element_wise::PassThrough{}); + + // B matrix blockwise copy + auto b_blockwise_copy = + ThreadGroupTensorSliceTransfer_v4r1, + BBlockTransferThreadClusterLengths_BK0_N_BK1, + BBlockTransferThreadClusterArrangeOrder, + BDataType, + BDataType, + decltype(b_grid_desc_bk0_n_bk1), + decltype(b_block_desc_bk0_n_bk1), + BBlockTransferSrcAccessOrder, + Sequence<0, 1, 2>, + BBlockTransferSrcVectorDim, + 2, + BBlockTransferSrcScalarPerVector, + BBlockTransferDstScalarPerVector_BK1, + 1, + 1, + BThreadTransferSrcResetCoordinateAfterRun, + true, + BlockwiseGemmPipe::GlobalBufferNum>( + b_grid_desc_bk0_n_bk1, + make_multi_index(0, n_block_data_idx_on_grid, 0), + b_element_op, + b_block_desc_bk0_n_bk1, + make_multi_index(0, 0, 0), + ck::tensor_operation::element_wise::PassThrough{}); + + // LDS allocation for A and B: be careful of alignment + constexpr auto a_block_space_size_aligned = math::integer_least_multiple( + a_block_desc_ak0_m_ak1.GetElementSpaceSize(), max_lds_align); + + // Cast after lds + auto a_block_buf = make_dynamic_buffer( + static_cast(p_shared), 
a_block_desc_ak0_m_ak1.GetElementSpaceSize()); + + auto b_block_buf = make_dynamic_buffer( + reinterpret_cast(static_cast(p_shared) + a_block_space_size_aligned * + sizeof(ADataType) / + APackedSize), + b_block_desc_bk0_n_bk1.GetElementSpaceSize()); + + constexpr auto a_block_slice_copy_step = make_multi_index(KPerBlock / AK1Number, 0, 0); + constexpr auto b_block_slice_copy_step = make_multi_index(KPerBlock / BK1Number, 0, 0); + + // Blockwise GEMM pipeline + static_assert(std::is_default_constructible_v); + auto blockwise_gemm_pipeline = BlockwiseGemmPipe{}; + auto c_thread_buf = blockwise_gemm_pipeline.GetCThreadBuffer(); + + const index_t num_k_block_main_loop = __builtin_amdgcn_readfirstlane( + (a_grid_desc_ak0_m_ak1.GetLength(I0) * a_grid_desc_ak0_m_ak1.GetLength(I2)) / + KPerBlock); + + blockwise_gemm_pipeline.template Run(a_grid_desc_ak0_m_ak1, + a_block_desc_ak0_m_ak1, + a_blockwise_copy, + a_grid_buf, + a_block_buf, + a_block_slice_copy_step, + b_grid_desc_bk0_n_bk1, + b_block_desc_bk0_n_bk1, + b_blockwise_copy, + b_grid_buf, + b_block_buf, + b_block_slice_copy_step, + c_thread_buf, + num_k_block_main_loop); + + // shuffle C and write out + { + // C mapping in single thread. 
+ constexpr auto c_thread_desc_mrepeat_mwave_msubgroup_nrepeat_nwave_nthreadpersubgroup_maccvgprs = + blockwise_gemm_pipeline + .GetCThreadDescriptor_MRepeat_MWave_MSubGroup_NRepeat_NWave_NThreadPerSubGroup_MAccVgprs(); + + // C mapping in single block + constexpr auto + c_block_desc_mrepeat_mwave_msubgroup_nrepeat_nwave_nthreadpersubgroup_maccvgprs_tmp = + blockwise_gemm_pipeline + .GetCBlockDescriptor_MRepeat_MWave_MSubGroup_NRepeat_NWave_NThreadPerSubGroup_MAccVgprs(); + + constexpr auto MWave = + c_block_desc_mrepeat_mwave_msubgroup_nrepeat_nwave_nthreadpersubgroup_maccvgprs_tmp + .GetLength(I1); + constexpr auto MSubGroup = + c_block_desc_mrepeat_mwave_msubgroup_nrepeat_nwave_nthreadpersubgroup_maccvgprs_tmp + .GetLength(I2); + constexpr auto NWave = + c_block_desc_mrepeat_mwave_msubgroup_nrepeat_nwave_nthreadpersubgroup_maccvgprs_tmp + .GetLength(I4); + constexpr auto NThreadPerSubGroup = + c_block_desc_mrepeat_mwave_msubgroup_nrepeat_nwave_nthreadpersubgroup_maccvgprs_tmp + .GetLength(I5); + constexpr auto MAccVgprs = + c_block_desc_mrepeat_mwave_msubgroup_nrepeat_nwave_nthreadpersubgroup_maccvgprs_tmp + .GetLength(I6); + + // LDS descriptor, shuffle and write out in MRepeat x NRepeat times + constexpr auto c_shuffle_block_desc_mshrepeat_mpershrepeat_nshrepeat_npershrepeat = + GetCShuffleBlockDescriptor_MShRepeat_MPerShRepeat_NShRepeat_NPerShRepeat(); + + auto c_shuffle_block_buf = make_dynamic_buffer( + static_cast(p_shared), + c_shuffle_block_desc_mshrepeat_mpershrepeat_nshrepeat_npershrepeat + .GetElementSpaceSize()); + + constexpr auto + c_block_desc_mrepeat_mwave_msubgroup_nrepeat_nwave_nthreadpersubgroup_maccvgprs = + transform_tensor_descriptor( + c_shuffle_block_desc_mshrepeat_mpershrepeat_nshrepeat_npershrepeat, + make_tuple( + make_freeze_transform(I0), + make_unmerge_transform(make_tuple( + Number{}, // MRepeat per shuffle repeat + MWave, // MWave + MSubGroup, // MSubGroup * MAccVgprs = MPerWmma + MAccVgprs)), + make_freeze_transform(I0), + 
make_unmerge_transform(make_tuple( + Number{}, // NRepeat per shuffle repeat + NWave, // NWave + NThreadPerSubGroup))), // NThreadPerSubGroup = NPerWmma + make_tuple(Sequence<0>{}, Sequence<1>{}, Sequence<2>{}, Sequence<3>{}), + make_tuple(Sequence<>{}, + Sequence<0, 1, 2, 6>{}, + Sequence<>{}, + Sequence<3, 4, 5>{})); + + // calculate origin of thread output tensor on global memory + // blockwise GEMM c matrix starting index + const auto c_thread_mtx_on_block = + blockwise_gemm_pipeline.CalculateCThreadOriginDataIndex(I0, I0); + + const index_t m_thread_data_on_block = c_thread_mtx_on_block[I0]; + const index_t n_thread_data_on_block = c_thread_mtx_on_block[I1]; + + const auto m_thread_data_on_block_to_mrepeat_mwave_msubgroup_maccvgprs_adaptor = + make_single_stage_tensor_adaptor(make_tuple(make_merge_transform(make_tuple( + MRepeat, MWave, MSubGroup, MAccVgprs))), + make_tuple(Sequence<0, 1, 2, 3>{}), + make_tuple(Sequence<0>{})); + + const auto m_thread_data_on_block_idx = + m_thread_data_on_block_to_mrepeat_mwave_msubgroup_maccvgprs_adaptor + .CalculateBottomIndex(make_multi_index(m_thread_data_on_block)); + + const auto n_thread_data_on_block_to_nrepeat_nwave_nthreadpersubgroup_adaptor = + make_single_stage_tensor_adaptor(make_tuple(make_merge_transform(make_tuple( + NRepeat, NWave, NThreadPerSubGroup))), + make_tuple(Sequence<0, 1, 2>{}), + make_tuple(Sequence<0>{})); + + const auto n_thread_data_on_block_idx = + n_thread_data_on_block_to_nrepeat_nwave_nthreadpersubgroup_adaptor + .CalculateBottomIndex(make_multi_index(n_thread_data_on_block)); + + // shuffle: threadwise copy C from VGPR to LDS + auto c_thread_copy_vgpr_to_lds = ThreadwiseTensorSliceTransfer_v1r3< + AccDataType, + CShuffleDataType, + decltype(c_thread_desc_mrepeat_mwave_msubgroup_nrepeat_nwave_nthreadpersubgroup_maccvgprs), + decltype(c_block_desc_mrepeat_mwave_msubgroup_nrepeat_nwave_nthreadpersubgroup_maccvgprs), + ck::tensor_operation::element_wise::PassThrough, + Sequence, + Sequence<0, 
1, 2, 3, 4, 5, 6>, + 6, + 1, // vector write pixel + InMemoryDataOperationEnum::Set, + 1, + true>{ + c_block_desc_mrepeat_mwave_msubgroup_nrepeat_nwave_nthreadpersubgroup_maccvgprs, + make_multi_index(0, + m_thread_data_on_block_idx[I1], + m_thread_data_on_block_idx[I2], + 0, + n_thread_data_on_block_idx[I1], + n_thread_data_on_block_idx[I2], + m_thread_data_on_block_idx[I3]), + ck::tensor_operation::element_wise::PassThrough{}}; + + // shuffle: blockwise copy C from LDS to global + auto c_shuffle_block_copy_lds_to_global = ThreadGroupTensorSliceTransfer_v6r1< + ThisThreadBlock, // ThreadGroup + CElementwiseOperation, // ElementwiseOperation, + CGlobalMemoryDataOperation, // DstInMemOp, + Sequence<1, + CShuffleMRepeatPerShuffle * MWave * MPerWmma, + 1, + CShuffleNRepeatPerShuffle * NWave * NPerWmma>, // BlockSliceLengths, + CShuffleBlockTransferClusterLengths_MBlock_MPerBlock_NBlock_NPerBlock, + Sequence<0, 1, 2, 3>, // typename ThreadClusterArrangeOrder, + CShuffleDataType, // typename SrcData, + CDataType, // typename DstData, + decltype(c_shuffle_block_desc_mshrepeat_mpershrepeat_nshrepeat_npershrepeat), + decltype(c_grid_desc_mblock_mperblock_nblock_nperblock), + Sequence<0, 1, 2, 3>, // typename DimAccessOrder, + 3, // index_t VectorDim, + CShuffleBlockTransferScalarPerVector_NPerBlock, // index_t ScalarPerVector, + true, // bool ThreadTransferSrcResetCoordinateAfterRun, + false> // bool ThreadTransferDstResetCoordinateAfterRun> + {c_shuffle_block_desc_mshrepeat_mpershrepeat_nshrepeat_npershrepeat, + make_multi_index(0, 0, 0, 0), + c_grid_desc_mblock_mperblock_nblock_nperblock, + make_multi_index(block_work_idx[I0], 0, block_work_idx[I1], 0), + c_element_op}; + + // space filling curve for local reg & global memory + // space filling curve for threadwise C in VGPR + constexpr auto sfc_c_vgpr = + SpaceFillingCurve, + Sequence<0, 1, 2, 3, 4, 5, 6>, + Sequence>{}; + + // space filling curve for shuffled blockwise C in global mem + constexpr auto sfc_c_global = + 
SpaceFillingCurve, + Sequence<0, 2, 1, 3>, + Sequence<1, + CShuffleMRepeatPerShuffle * MWave * MPerWmma, + 1, + CShuffleNRepeatPerShuffle * NWave * NPerWmma>>{}; + + constexpr index_t num_access = sfc_c_vgpr.GetNumOfAccess(); + + static_assert(num_access == sfc_c_global.GetNumOfAccess(), "wrong!"); + + static_for<0, num_access, 1>{}([&](auto access_id) { + // make sure it's safe to write to LDS + block_sync_lds(); + + // each thread write its data from VGPR to LDS + c_thread_copy_vgpr_to_lds.Run( + c_thread_desc_mrepeat_mwave_msubgroup_nrepeat_nwave_nthreadpersubgroup_maccvgprs, + sfc_c_vgpr.GetIndexTupleOfNumber(access_id), + c_thread_buf, + c_block_desc_mrepeat_mwave_msubgroup_nrepeat_nwave_nthreadpersubgroup_maccvgprs, + c_shuffle_block_buf); + + // make sure it's safe to read from LDS + block_sync_lds(); + + // each block copy its data from LDS to global + c_shuffle_block_copy_lds_to_global.Run( + c_shuffle_block_desc_mshrepeat_mpershrepeat_nshrepeat_npershrepeat, + c_shuffle_block_buf, + c_grid_desc_mblock_mperblock_nblock_nperblock, + c_grid_buf); + + if constexpr(access_id < num_access - 1) + { + constexpr auto c_global_step = sfc_c_global.GetForwardStep(access_id); + + // move on C + c_shuffle_block_copy_lds_to_global.MoveDstSliceWindow( + c_grid_desc_mblock_mperblock_nblock_nperblock, c_global_step); + } + }); + } + } + + template + __device__ static void Run(const ADataType* p_a_grid, + const BDataType* p_b_grid, + CDataType* p_c_grid, + void* p_shared, + const Problem& problem) + { + const auto a_grid_desc_ak0_m_ak1 = MakeAGridDescriptor_AK0_M_AK1( + problem.M, problem.MPadded, problem.K, problem.KPadded, problem.StrideA, problem.AK0); + const auto b_grid_desc_bk0_n_bk1 = MakeBGridDescriptor_BK0_N_BK1( + problem.K, problem.KPadded, problem.N, problem.NPadded, problem.StrideB, problem.BK0); + const auto c_grid_desc_m_n = MakeCGridDescriptor_M_N( + problem.M, problem.MPadded, problem.N, problem.NPadded, problem.StrideC); + const auto 
c_grid_desc_mblock_mperblock_nblock_nperblock = + MakeCGridDescriptor_MBlock_MPerBlock_NBlock_NPerBlock( + c_grid_desc_m_n, problem.MBlock, problem.NBlock); + + Run(p_a_grid, + p_b_grid, + p_c_grid, + p_shared, + problem, + a_grid_desc_ak0_m_ak1, + b_grid_desc_bk0_n_bk1, + c_grid_desc_mblock_mperblock_nblock_nperblock); + } +}; + +} // namespace ck diff --git a/include/ck/tensor_operation/gpu/warp/wmma_gemm.hpp b/include/ck/tensor_operation/gpu/warp/wmma_gemm.hpp index 1abae56be4..429df2413f 100644 --- a/include/ck/tensor_operation/gpu/warp/wmma_gemm.hpp +++ b/include/ck/tensor_operation/gpu/warp/wmma_gemm.hpp @@ -1,5 +1,5 @@ // SPDX-License-Identifier: MIT -// Copyright (c) 2018-2023, Advanced Micro Devices, Inc. All rights reserved. +// Copyright (c) 2018-2025, Advanced Micro Devices, Inc. All rights reserved. #pragma once @@ -22,6 +22,10 @@ enum struct WmmaInstr wmma_f32_16x16x16_f16_gfx12, wmma_f32_16x16x16_bf16_gfx12, wmma_i32_16x16x16_iu8_gfx12, + wmma_f32_16x16x16_f8f8_gfx12, + wmma_f32_16x16x16_f8bf8_gfx12, + wmma_f32_16x16x16_bf8f8_gfx12, + wmma_f32_16x16x16_bf8bf8_gfx12, }; /* @@ -400,6 +404,146 @@ struct wmma_type +struct wmma_type> +{ + // Absolute fixing property + static constexpr index_t m_per_wmma = 16; + static constexpr index_t n_per_wmma = 16; + static constexpr index_t k_per_wmma = 16; + static constexpr index_t acc_data_size = 4; + static constexpr index_t acc_pack_number = 1; + static constexpr index_t num_thread_per_subgroups = n_per_wmma; + + // Wave mode dependent propety + static constexpr index_t wave_size = Number{}; + static constexpr index_t num_acc_vgprs_per_wave = m_per_wmma * n_per_wmma / wave_size; + static constexpr index_t num_subgroups = wave_size / num_thread_per_subgroups; + + template + __device__ void run(const FloatA& a, const FloatB& b, FloatC& reg_c) const + { + static_assert(wave_size == 32, "only support wave32 for gfx12 wmma"); + if constexpr(wave_size == 32) + { +#ifdef __gfx12__ + 
intrin_wmma_f32_16x16x16_f8f8_w32_gfx12::Run(a, b, reg_c); +#else + ignore = a; + ignore = b; + ignore = reg_c; +#endif + } + } +}; + +template +struct wmma_type> +{ + // Absolute fixing property + static constexpr index_t m_per_wmma = 16; + static constexpr index_t n_per_wmma = 16; + static constexpr index_t k_per_wmma = 16; + static constexpr index_t acc_data_size = 4; + static constexpr index_t acc_pack_number = 1; + static constexpr index_t num_thread_per_subgroups = n_per_wmma; + + // Wave mode dependent propety + static constexpr index_t wave_size = Number{}; + static constexpr index_t num_acc_vgprs_per_wave = m_per_wmma * n_per_wmma / wave_size; + static constexpr index_t num_subgroups = wave_size / num_thread_per_subgroups; + + template + __device__ void run(const FloatA& a, const FloatB& b, FloatC& reg_c) const + { + static_assert(wave_size == 32, "only support wave32 for gfx12 wmma"); + if constexpr(wave_size == 32) + { +#ifdef __gfx12__ + intrin_wmma_f32_16x16x16_f8bf8_w32_gfx12::Run(a, b, reg_c); +#else + ignore = a; + ignore = b; + ignore = reg_c; +#endif + } + } +}; + +template +struct wmma_type> +{ + // Absolute fixing property + static constexpr index_t m_per_wmma = 16; + static constexpr index_t n_per_wmma = 16; + static constexpr index_t k_per_wmma = 16; + static constexpr index_t acc_data_size = 4; + static constexpr index_t acc_pack_number = 1; + static constexpr index_t num_thread_per_subgroups = n_per_wmma; + + // Wave mode dependent propety + static constexpr index_t wave_size = Number{}; + static constexpr index_t num_acc_vgprs_per_wave = m_per_wmma * n_per_wmma / wave_size; + static constexpr index_t num_subgroups = wave_size / num_thread_per_subgroups; + + template + __device__ void run(const FloatA& a, const FloatB& b, FloatC& reg_c) const + { + static_assert(wave_size == 32, "only support wave32 for gfx12 wmma"); + if constexpr(wave_size == 32) + { +#ifdef __gfx12__ + intrin_wmma_f32_16x16x16_bf8f8_w32_gfx12::Run(a, b, reg_c); +#else + 
ignore = a; + ignore = b; + ignore = reg_c; +#endif + } + } +}; + +template +struct wmma_type> +{ + // Absolute fixing property + static constexpr index_t m_per_wmma = 16; + static constexpr index_t n_per_wmma = 16; + static constexpr index_t k_per_wmma = 16; + static constexpr index_t acc_data_size = 4; + static constexpr index_t acc_pack_number = 1; + static constexpr index_t num_thread_per_subgroups = n_per_wmma; + + // Wave mode dependent propety + static constexpr index_t wave_size = Number{}; + static constexpr index_t num_acc_vgprs_per_wave = m_per_wmma * n_per_wmma / wave_size; + static constexpr index_t num_subgroups = wave_size / num_thread_per_subgroups; + + template + __device__ void run(const FloatA& a, const FloatB& b, FloatC& reg_c) const + { + static_assert(wave_size == 32, "only support wave32 for gfx12 wmma"); + if constexpr(wave_size == 32) + { +#ifdef __gfx12__ + intrin_wmma_f32_16x16x16_bf8bf8_w32_gfx12::Run(a, b, reg_c); +#else + ignore = a; + ignore = b; + ignore = reg_c; +#endif + } + } +}; + template + constexpr auto GetWmma() + { + return WmmaInstr::wmma_f32_16x16x16_f8f8_gfx12; + } + + template <> + constexpr auto GetWmma() + { + return WmmaInstr::wmma_f32_16x16x16_f8bf8_gfx12; + } + + template <> + constexpr auto GetWmma() + { + return WmmaInstr::wmma_f32_16x16x16_bf8f8_gfx12; + } + + template <> + constexpr auto GetWmma() + { + return WmmaInstr::wmma_f32_16x16x16_bf8bf8_gfx12; + } + // get_warp_size do not return the correct wavesize, hardcode to 32 as workaround static constexpr auto selected_wmma = wmma_type(), Number<32>{}>{}; @@ -612,14 +781,17 @@ struct WmmaGemm (is_same::value && is_same::value && is_same::value) || (is_same::value && is_same::value && - is_same::value) + is_same::value) || + ((is_same::value || is_same::value) && + (is_same::value || is_same::value) && + is_same::value) || #ifdef CK_EXPERIMENTAL_BIT_INT_EXTENSION_INT4 - || (is_same::value && is_same::value && - is_same::value) + (is_same::value && is_same::value 
&& + is_same::value) || #endif - , + false, "base type couple must be (half, float), (bhalf, float), (half, half), (bhalf, bhalf), " - "(int8, int32) or (int4, int32)!"); + "((f8 or bf8, f8 or bf8), float), (int8, int32) or (int4, int32)!"); static_for<0, KPack / wmma_instr.k_per_wmma, 1>{}([&](auto k) { if constexpr(!TransposeC) { diff --git a/include/ck/utility/amd_buffer_addressing.hpp b/include/ck/utility/amd_buffer_addressing.hpp index 317f324e6d..62e3220b5a 100644 --- a/include/ck/utility/amd_buffer_addressing.hpp +++ b/include/ck/utility/amd_buffer_addressing.hpp @@ -581,7 +581,7 @@ __device__ void amd_global_atomic_add_impl(const typename vector_type::typ tmp.template AsType()[i]); }); } -#if defined(__gfx942__) || defined(__gfx950__) +#if defined(__gfx942__) || defined(__gfx950__) || defined(__gfx12__) else if constexpr(is_same::value) { vector_type tmp{src_thread_data}; diff --git a/include/ck/utility/amd_wmma.hpp b/include/ck/utility/amd_wmma.hpp index aa519fb2be..e14c0d62a8 100644 --- a/include/ck/utility/amd_wmma.hpp +++ b/include/ck/utility/amd_wmma.hpp @@ -1,5 +1,5 @@ // SPDX-License-Identifier: MIT -// Copyright (c) 2018-2023, Advanced Micro Devices, Inc. All rights reserved. +// Copyright (c) 2018-2025, Advanced Micro Devices, Inc. All rights reserved. 
#ifndef CK_AMD_WMMA_HPP #define CK_AMD_WMMA_HPP @@ -341,5 +341,101 @@ struct intrin_wmma_i32_16x16x16_iu8_w32_gfx12<16, 16, neg_a, neg_b, clamp> } }; +// src: f8, f8, dst: fp32 +template +struct intrin_wmma_f32_16x16x16_f8f8_w32_gfx12; + +template <> +struct intrin_wmma_f32_16x16x16_f8f8_w32_gfx12<16, 16> +{ + template + __device__ static void Run(const f8x8_t& reg_a, const f8x8_t& reg_b, FloatC& reg_c) + { +#if defined(__gfx12__) + reg_c.template AsType()(Number<0>{}) = + __builtin_amdgcn_wmma_f32_16x16x16_fp8_fp8_w32_gfx12( + bit_cast(reg_a), + bit_cast(reg_b), + reg_c.template AsType()[Number<0>{}]); +#else + ignore = reg_a; + ignore = reg_b; + ignore = reg_c; +#endif + } +}; + +// src: f8, bf8, dst: fp32 +template +struct intrin_wmma_f32_16x16x16_f8bf8_w32_gfx12; + +template <> +struct intrin_wmma_f32_16x16x16_f8bf8_w32_gfx12<16, 16> +{ + template + __device__ static void Run(const f8x8_t& reg_a, const bf8x8_t& reg_b, FloatC& reg_c) + { +#if defined(__gfx12__) + reg_c.template AsType()(Number<0>{}) = + __builtin_amdgcn_wmma_f32_16x16x16_fp8_bf8_w32_gfx12( + bit_cast(reg_a), + bit_cast(reg_b), + reg_c.template AsType()[Number<0>{}]); +#else + ignore = reg_a; + ignore = reg_b; + ignore = reg_c; +#endif + } +}; + +// src: bf8, f8, dst: fp32 +template +struct intrin_wmma_f32_16x16x16_bf8f8_w32_gfx12; + +template <> +struct intrin_wmma_f32_16x16x16_bf8f8_w32_gfx12<16, 16> +{ + template + __device__ static void Run(const bf8x8_t& reg_a, const f8x8_t& reg_b, FloatC& reg_c) + { +#if defined(__gfx12__) + reg_c.template AsType()(Number<0>{}) = + __builtin_amdgcn_wmma_f32_16x16x16_bf8_fp8_w32_gfx12( + bit_cast(reg_a), + bit_cast(reg_b), + reg_c.template AsType()[Number<0>{}]); +#else + ignore = reg_a; + ignore = reg_b; + ignore = reg_c; +#endif + } +}; + +// src: bf8, bf8, dst: fp32 +template +struct intrin_wmma_f32_16x16x16_bf8bf8_w32_gfx12; + +template <> +struct intrin_wmma_f32_16x16x16_bf8bf8_w32_gfx12<16, 16> +{ + template + __device__ static void Run(const bf8x8_t& 
reg_a, const bf8x8_t& reg_b, FloatC& reg_c) + { +#if defined(__gfx12__) + reg_c.template AsType()(Number<0>{}) = + __builtin_amdgcn_wmma_f32_16x16x16_bf8_bf8_w32_gfx12( + bit_cast(reg_a), + bit_cast(reg_b), + reg_c.template AsType()[Number<0>{}]); +#else + ignore = reg_a; + ignore = reg_b; + ignore = reg_c; +#endif + } +}; + } // namespace ck #endif diff --git a/library/include/ck/library/tensor_operation_instance/gpu/gemm_universal.hpp b/library/include/ck/library/tensor_operation_instance/gpu/gemm_universal.hpp index 4218c51ca3..79212e16dd 100644 --- a/library/include/ck/library/tensor_operation_instance/gpu/gemm_universal.hpp +++ b/library/include/ck/library/tensor_operation_instance/gpu/gemm_universal.hpp @@ -1,5 +1,5 @@ // SPDX-License-Identifier: MIT -// Copyright (c) 2024, Advanced Micro Devices, Inc. All rights reserved. +// Copyright (c) 2024-2025, Advanced Micro Devices, Inc. All rights reserved. #pragma once @@ -7,521 +7,22 @@ #include #include "ck/ck.hpp" #include "ck/tensor_operation/gpu/device/tensor_layout.hpp" -#include "ck/tensor_operation/gpu/device/impl/device_gemm_xdl_cshuffle_v3.hpp" +#include "ck/tensor_operation/gpu/device/device_gemm_v2.hpp" #include "ck/tensor_operation/gpu/element/element_wise_operation.hpp" #include "ck/library/tensor_operation_instance/device_operation_instance_factory.hpp" +#ifdef CK_USE_WMMA +#include "gemm_universal_wmma.inc" +#endif +#ifdef CK_USE_XDL +#include "gemm_universal_xdl.inc" +#endif + namespace ck { namespace tensor_operation { namespace device { namespace instance { -#ifdef CK_ENABLE_FP16 -void add_device_gemm_xdl_universal_f16_f16_f16_mk_kn_mn_comp_default_instances( - std::vector>>& - instances); - -void add_device_gemm_xdl_universal_f16_f16_f16_mk_kn_mn_comp_kpadding_instances( - std::vector>>& - instances); - -void add_device_gemm_xdl_universal_f16_f16_f16_mk_kn_mn_comp_mnpadding_instances( - std::vector>>& - instances); - -void 
add_device_gemm_xdl_universal_f16_f16_f16_mk_kn_mn_comp_mnkpadding_instances( - std::vector>>& - instances); - -void add_device_gemm_xdl_universal_f16_f16_f16_mk_kn_mn_mem_v1_default_instances( - std::vector>>& - instances); - -void add_device_gemm_xdl_universal_f16_f16_f16_mk_kn_mn_mem_v1_kpadding_instances( - std::vector>>& - instances); - -void add_device_gemm_xdl_universal_f16_f16_f16_mk_kn_mn_mem_v1_mnkpadding_instances( - std::vector>>& - instances); - -void add_device_gemm_xdl_universal_f16_f16_f16_mk_kn_mn_mem_v2_default_instances( - std::vector>>& - instances); - -void add_device_gemm_xdl_universal_f16_f16_f16_mk_kn_mn_mem_v2_kpadding_instances( - std::vector>>& - instances); - -void add_device_gemm_xdl_universal_f16_f16_f16_mk_kn_mn_mem_v2_mnkpadding_instances( - std::vector>>& - instances); - -void add_device_gemm_xdl_universal_f16_f16_f16_mk_nk_mn_comp_default_instances( - std::vector>>& - instances); - -void add_device_gemm_xdl_universal_f16_f16_f16_mk_nk_mn_comp_kpadding_instances( - std::vector>>& - instances); - -void add_device_gemm_xdl_universal_f16_f16_f16_mk_nk_mn_mem_v1_default_instances( - std::vector>>& - instances); - -void add_device_gemm_xdl_universal_f16_f16_f16_mk_nk_mn_mem_v1_kpadding_instances( - std::vector>>& - instances); - -void add_device_gemm_xdl_universal_f16_f16_f16_mk_nk_mn_mem_v2_default_instances( - std::vector>>& - instances); - -void add_device_gemm_xdl_universal_f16_f16_f16_mk_nk_mn_mem_v2_kpadding_instances( - std::vector>>& - instances); -#endif -#if(defined(CK_ENABLE_FP16) && defined(CK_ENABLE_FP8)) -void add_device_gemm_xdl_universal_f16_f8_f16_mk_kn_mn_comp_default_instances( - std::vector>>& - instances); - -void add_device_gemm_xdl_universal_f16_f8_f16_mk_kn_mn_comp_kpadding_instances( - std::vector>>& - instances); - -void add_device_gemm_xdl_universal_f16_f8_f16_mk_kn_mn_comp_mnpadding_instances( - std::vector>>& - instances); - -void add_device_gemm_xdl_universal_f16_f8_f16_mk_kn_mn_comp_mnkpadding_instances( - 
std::vector>>& - instances); - -void add_device_gemm_xdl_universal_f16_f8_f16_mk_kn_mn_mem_v1_default_instances( - std::vector>>& - instances); - -void add_device_gemm_xdl_universal_f16_f8_f16_mk_kn_mn_mem_v1_kpadding_instances( - std::vector>>& - instances); - -void add_device_gemm_xdl_universal_f16_f8_f16_mk_kn_mn_mem_v1_mnkpadding_instances( - std::vector>>& - instances); - -void add_device_gemm_xdl_universal_f16_f8_f16_mk_kn_mn_mem_v2_default_instances( - std::vector>>& - instances); - -void add_device_gemm_xdl_universal_f16_f8_f16_mk_kn_mn_mem_v2_kpadding_instances( - std::vector>>& - instances); - -void add_device_gemm_xdl_universal_f16_f8_f16_mk_kn_mn_mem_v2_mnkpadding_instances( - std::vector>>& - instances); -void add_device_gemm_xdl_universal_f16_f8_f16_mk_nk_mn_comp_default_instances( - std::vector>>& - instances); - -void add_device_gemm_xdl_universal_f16_f8_f16_mk_nk_mn_comp_kpadding_instances( - std::vector>>& - instances); - -void add_device_gemm_xdl_universal_f16_f8_f16_mk_nk_mn_mem_v1_default_instances( - std::vector>>& - instances); - -void add_device_gemm_xdl_universal_f16_f8_f16_mk_nk_mn_mem_v1_kpadding_instances( - std::vector>>& - instances); - -void add_device_gemm_xdl_universal_f16_f8_f16_mk_nk_mn_mem_v2_default_instances( - std::vector>>& - instances); - -void add_device_gemm_xdl_universal_f16_i4_f16_mk_nk_mn_mem_v2_default_instances( - std::vector>>& - instances); - -void add_device_gemm_xdl_universal_bf16_i4_bf16_mk_nk_mn_mem_v2_default_instances( - std::vector>>& - instances); - -void add_device_gemm_xdl_universal_f16_f8_f16_mk_nk_mn_mem_v2_kpadding_instances( - std::vector>>& - instances); - -void add_device_gemm_xdl_universal_f8_f16_f16_mk_kn_mn_comp_default_instances( - std::vector>>& - instances); - -void add_device_gemm_xdl_universal_f8_f16_f16_mk_kn_mn_comp_kpadding_instances( - std::vector>>& - instances); - -void add_device_gemm_xdl_universal_f8_f16_f16_mk_kn_mn_comp_mnpadding_instances( - std::vector>>& - instances); - -void 
add_device_gemm_xdl_universal_f8_f16_f16_mk_kn_mn_comp_mnkpadding_instances( - std::vector>>& - instances); - -void add_device_gemm_xdl_universal_f8_f16_f16_mk_kn_mn_mem_v1_default_instances( - std::vector>>& - instances); - -void add_device_gemm_xdl_universal_f8_f16_f16_mk_kn_mn_mem_v1_kpadding_instances( - std::vector>>& - instances); - -void add_device_gemm_xdl_universal_f8_f16_f16_mk_kn_mn_mem_v1_mnkpadding_instances( - std::vector>>& - instances); - -void add_device_gemm_xdl_universal_f8_f16_f16_mk_kn_mn_mem_v2_default_instances( - std::vector>>& - instances); - -void add_device_gemm_xdl_universal_f8_f16_f16_mk_kn_mn_mem_v2_kpadding_instances( - std::vector>>& - instances); - -void add_device_gemm_xdl_universal_f8_f16_f16_mk_kn_mn_mem_v2_mnkpadding_instances( - std::vector>>& - instances); -void add_device_gemm_xdl_universal_f8_f16_f16_mk_nk_mn_comp_default_instances( - std::vector>>& - instances); - -void add_device_gemm_xdl_universal_f8_f16_f16_mk_nk_mn_comp_kpadding_instances( - std::vector>>& - instances); - -void add_device_gemm_xdl_universal_f8_f16_f16_mk_nk_mn_mem_v1_default_instances( - std::vector>>& - instances); - -void add_device_gemm_xdl_universal_f8_f16_f16_mk_nk_mn_mem_v1_kpadding_instances( - std::vector>>& - instances); - -void add_device_gemm_xdl_universal_f8_f16_f16_mk_nk_mn_mem_v2_default_instances( - std::vector>>& - instances); - -void add_device_gemm_xdl_universal_f8_f16_f16_mk_nk_mn_mem_v2_kpadding_instances( - std::vector>>& - instances); -#endif -#ifdef CK_ENABLE_BF16 -void add_device_gemm_xdl_universal_bf16_bf16_bf16_mk_kn_mn_comp_default_instances( - std::vector>>& - instances); - -void add_device_gemm_xdl_universal_bf16_bf16_bf16_mk_kn_mn_comp_kpadding_instances( - std::vector>>& - instances); - -void add_device_gemm_xdl_universal_bf16_bf16_bf16_mk_kn_mn_comp_mnpadding_instances( - std::vector>>& - instances); - -void add_device_gemm_xdl_universal_bf16_bf16_bf16_mk_kn_mn_comp_mnkpadding_instances( - std::vector>>& - instances); - 
-void add_device_gemm_xdl_universal_bf16_bf16_bf16_mk_kn_mn_mem_v1_default_instances( - std::vector>>& - instances); - -void add_device_gemm_xdl_universal_bf16_bf16_bf16_mk_kn_mn_mem_v1_kpadding_instances( - std::vector>>& - instances); - -void add_device_gemm_xdl_universal_bf16_bf16_bf16_mk_kn_mn_mem_v1_mnkpadding_instances( - std::vector>>& - instances); - -void add_device_gemm_xdl_universal_bf16_bf16_bf16_mk_kn_mn_mem_v2_default_instances( - std::vector>>& - instances); - -void add_device_gemm_xdl_universal_bf16_bf16_bf16_mk_kn_mn_mem_v2_kpadding_instances( - std::vector>>& - instances); - -void add_device_gemm_xdl_universal_bf16_bf16_bf16_mk_kn_mn_mem_v2_mnkpadding_instances( - std::vector>>& - instances); - -void add_device_gemm_xdl_universal_bf16_bf16_bf16_mk_nk_mn_comp_default_instances( - std::vector>>& - instances); - -void add_device_gemm_xdl_universal_bf16_bf16_bf16_mk_nk_mn_comp_kpadding_instances( - std::vector>>& - instances); - -void add_device_gemm_xdl_universal_bf16_bf16_bf16_mk_nk_mn_mem_v1_default_instances( - std::vector>>& - instances); - -void add_device_gemm_xdl_universal_bf16_bf16_bf16_mk_nk_mn_mem_v1_kpadding_instances( - std::vector>>& - instances); - -void add_device_gemm_xdl_universal_bf16_bf16_bf16_mk_nk_mn_mem_v2_default_instances( - std::vector>>& - instances); - -void add_device_gemm_xdl_universal_bf16_bf16_bf16_mk_nk_mn_mem_v2_kpadding_instances( - std::vector>>& - instances); -void add_device_gemm_xdl_universal_bf16_bf16_bf16_km_kn_mn_comp_default_instances( - std::vector>>& - instances); - -void add_device_gemm_xdl_universal_bf16_bf16_bf16_km_kn_mn_comp_kpadding_instances( - std::vector>>& - instances); - -void add_device_gemm_xdl_universal_bf16_bf16_bf16_km_kn_mn_comp_mnpadding_instances( - std::vector>>& - instances); - -void add_device_gemm_xdl_universal_bf16_bf16_bf16_km_kn_mn_comp_mnkpadding_instances( - std::vector>>& - instances); - -void add_device_gemm_xdl_universal_bf16_bf16_bf16_km_kn_mn_mem_v1_default_instances( - 
std::vector>>& - instances); - -void add_device_gemm_xdl_universal_bf16_bf16_bf16_km_kn_mn_mem_v1_kpadding_instances( - std::vector>>& - instances); - -void add_device_gemm_xdl_universal_bf16_bf16_bf16_km_kn_mn_mem_v1_mnkpadding_instances( - std::vector>>& - instances); - -void add_device_gemm_xdl_universal_bf16_bf16_bf16_km_kn_mn_mem_v2_default_instances( - std::vector>>& - instances); - -void add_device_gemm_xdl_universal_bf16_bf16_bf16_km_kn_mn_mem_v2_kpadding_instances( - std::vector>>& - instances); - -void add_device_gemm_xdl_universal_bf16_bf16_bf16_km_kn_mn_mem_v2_mnkpadding_instances( - std::vector>>& - instances); - -void add_device_gemm_xdl_universal_bf16_bf16_bf16_km_nk_mn_comp_default_instances( - std::vector>>& - instances); - -void add_device_gemm_xdl_universal_bf16_bf16_bf16_km_nk_mn_comp_kpadding_instances( - std::vector>>& - instances); - -void add_device_gemm_xdl_universal_bf16_bf16_bf16_km_nk_mn_comp_mpadding_instances( - std::vector>>& - instances); - -void add_device_gemm_xdl_universal_bf16_bf16_bf16_km_nk_mn_comp_mkpadding_instances( - std::vector>>& - instances); - -void add_device_gemm_xdl_universal_bf16_bf16_bf16_km_nk_mn_mem_v1_default_instances( - std::vector>>& - instances); - -void add_device_gemm_xdl_universal_bf16_bf16_bf16_km_nk_mn_mem_v1_kpadding_instances( - std::vector>>& - instances); - -void add_device_gemm_xdl_universal_bf16_bf16_bf16_km_nk_mn_mem_v1_mkpadding_instances( - std::vector>>& - instances); - -void add_device_gemm_xdl_universal_bf16_bf16_bf16_km_nk_mn_mem_v2_default_instances( - std::vector>>& - instances); - -void add_device_gemm_xdl_universal_bf16_bf16_bf16_km_nk_mn_mem_v2_kpadding_instances( - std::vector>>& - instances); - -void add_device_gemm_xdl_universal_bf16_bf16_bf16_km_nk_mn_mem_v2_mkpadding_instances( - std::vector>>& - instances); -#endif -#if(defined(CK_ENABLE_BF16) && defined(CK_ENABLE_FP8)) -void add_device_gemm_xdl_universal_f8_f8_bf16_mk_kn_mn_comp_default_instances( - std::vector>>& - instances); 
- -void add_device_gemm_xdl_universal_f8_f8_bf16_mk_kn_mn_comp_kpadding_instances( - std::vector>>& - instances); - -void add_device_gemm_xdl_universal_f8_f8_bf16_mk_kn_mn_comp_nkpadding_instances( - std::vector>>& - instances); - -void add_device_gemm_xdl_universal_f8_f8_bf16_mk_kn_mn_mem_v1_default_instances( - std::vector>>& - instances); - -void add_device_gemm_xdl_universal_f8_f8_bf16_mk_kn_mn_mem_v1_kpadding_instances( - std::vector>>& - instances); - -void add_device_gemm_xdl_universal_f8_f8_bf16_mk_kn_mn_mem_v1_nkpadding_instances( - std::vector>>& - instances); - -void add_device_gemm_xdl_universal_f8_f8_bf16_mk_kn_mn_mem_v2_default_instances( - std::vector>>& - instances); - -void add_device_gemm_xdl_universal_f8_f8_bf16_mk_kn_mn_mem_v2_kpadding_instances( - std::vector>>& - instances); - -void add_device_gemm_xdl_universal_f8_f8_bf16_mk_kn_mn_mem_v2_nkpadding_instances( - std::vector>>& - instances); - -void add_device_gemm_xdl_universal_f8_f8_bf16_mk_nk_mn_comp_default_instances( - std::vector>>& - instances); - -void add_device_gemm_xdl_universal_f8_f8_bf16_mk_nk_mn_comp_kpadding_instances( - std::vector>>& - instances); - -void add_device_gemm_xdl_universal_f8_f8_bf16_mk_nk_mn_mem_v1_default_instances( - std::vector>>& - instances); - -void add_device_gemm_xdl_universal_f8_f8_bf16_mk_nk_mn_mem_v1_kpadding_instances( - std::vector>>& - instances); - -void add_device_gemm_xdl_universal_f8_f8_bf16_mk_nk_mn_mem_v2_default_instances( - std::vector>>& - instances); - -void add_device_gemm_xdl_universal_f8_f8_bf16_mk_nk_mn_mem_v2_kpadding_instances( - std::vector>>& - instances); -#endif template > op_ptrs; +#ifdef CK_USE_WMMA +#ifdef CK_ENABLE_FP16 + if constexpr(is_same_v && is_same_v && + is_same_v) + { + if constexpr(is_same_v && is_same_v && + is_same_v) + { + add_device_gemm_wmma_universal_f16_f16_f16_mk_kn_mn_comp_default_instances(op_ptrs); + } + else if constexpr(is_same_v && is_same_v && + is_same_v) + { + 
add_device_gemm_wmma_universal_f16_f16_f16_mk_nk_mn_comp_default_instances(op_ptrs); + } + else if constexpr(is_same_v && is_same_v && + is_same_v) + { + add_device_gemm_wmma_universal_f16_f16_f16_km_kn_mn_comp_default_instances(op_ptrs); + } + else if constexpr(is_same_v && is_same_v && + is_same_v) + { + add_device_gemm_wmma_universal_f16_f16_f16_km_nk_mn_comp_default_instances(op_ptrs); + } + } +#endif +#ifdef CK_ENABLE_BF16 + if constexpr(is_same_v && is_same_v && + is_same_v) + { + if constexpr(is_same_v && is_same_v && + is_same_v) + { + add_device_gemm_wmma_universal_bf16_bf16_bf16_mk_kn_mn_comp_default_instances( + op_ptrs); + } + else if constexpr(is_same_v && is_same_v && + is_same_v) + { + add_device_gemm_wmma_universal_bf16_bf16_bf16_mk_nk_mn_comp_default_instances( + op_ptrs); + } + else if constexpr(is_same_v && is_same_v && + is_same_v) + { + add_device_gemm_wmma_universal_bf16_bf16_bf16_km_kn_mn_comp_default_instances( + op_ptrs); + } + else if constexpr(is_same_v && is_same_v && + is_same_v) + { + add_device_gemm_wmma_universal_bf16_bf16_bf16_km_nk_mn_comp_default_instances( + op_ptrs); + } + } +#endif +#if(defined(CK_ENABLE_BF16) && defined(CK_ENABLE_FP8)) + if constexpr(is_same_v && is_same_v && + is_same_v) + { + if constexpr(is_same_v && is_same_v && + is_same_v) + { + add_device_gemm_wmma_universal_f8_f8_bf16_mk_kn_mn_comp_default_instances(op_ptrs); + } + else if constexpr(is_same_v && is_same_v && + is_same_v) + { + add_device_gemm_wmma_universal_f8_f8_bf16_mk_nk_mn_comp_default_instances(op_ptrs); + } + } +#endif +#endif // CK_USE_WMMA + +#ifdef CK_USE_XDL #ifdef CK_ENABLE_FP16 if constexpr(is_same_v && is_same_v && is_same_v) @@ -822,6 +399,7 @@ struct DeviceOperationInstanceFactory< } #endif +#ifdef CK_ENABLE_FP16 if constexpr(is_same_v && is_same_v && is_same_v) { @@ -831,7 +409,8 @@ struct DeviceOperationInstanceFactory< add_device_gemm_xdl_universal_f16_i4_f16_mk_nk_mn_mem_v2_default_instances(op_ptrs); } } - +#endif +#ifdef 
CK_ENABLE_BF16 if constexpr(is_same_v && is_same_v && is_same_v) { @@ -842,6 +421,8 @@ struct DeviceOperationInstanceFactory< op_ptrs); } } +#endif +#endif // CK_USE_XDL return op_ptrs; } diff --git a/library/include/ck/library/tensor_operation_instance/gpu/gemm_universal_wmma.inc b/library/include/ck/library/tensor_operation_instance/gpu/gemm_universal_wmma.inc new file mode 100644 index 0000000000..1396437326 --- /dev/null +++ b/library/include/ck/library/tensor_operation_instance/gpu/gemm_universal_wmma.inc @@ -0,0 +1,68 @@ +// SPDX-License-Identifier: MIT +// Copyright (c) 2025, Advanced Micro Devices, Inc. All rights reserved. + +#pragma once + +namespace ck { +namespace tensor_operation { +namespace device { +namespace instance { + +#ifdef CK_ENABLE_FP16 +void add_device_gemm_wmma_universal_f16_f16_f16_mk_kn_mn_comp_default_instances( + std::vector>>& + instances); + +void add_device_gemm_wmma_universal_f16_f16_f16_mk_nk_mn_comp_default_instances( + std::vector>>& + instances); + +void add_device_gemm_wmma_universal_f16_f16_f16_km_kn_mn_comp_default_instances( + std::vector>>& + instances); + +void add_device_gemm_wmma_universal_f16_f16_f16_km_nk_mn_comp_default_instances( + std::vector>>& + instances); +#endif +#ifdef CK_ENABLE_BF16 +void add_device_gemm_wmma_universal_bf16_bf16_bf16_mk_kn_mn_comp_default_instances( + std::vector>>& + instances); + +void add_device_gemm_wmma_universal_bf16_bf16_bf16_mk_nk_mn_comp_default_instances( + std::vector>>& + instances); + +void add_device_gemm_wmma_universal_bf16_bf16_bf16_km_kn_mn_comp_default_instances( + std::vector>>& + instances); + +void add_device_gemm_wmma_universal_bf16_bf16_bf16_km_nk_mn_comp_default_instances( + std::vector>>& + instances); +#endif +#if(defined(CK_ENABLE_BF16) && defined(CK_ENABLE_FP8)) +void add_device_gemm_wmma_universal_f8_f8_bf16_mk_kn_mn_comp_default_instances( + std::vector>>& + instances); + +void add_device_gemm_wmma_universal_f8_f8_bf16_mk_nk_mn_comp_default_instances( + 
std::vector>>& + instances); +#endif + +} // namespace instance +} // namespace device +} // namespace tensor_operation +} // namespace ck diff --git a/library/include/ck/library/tensor_operation_instance/gpu/gemm_universal_xdl.inc b/library/include/ck/library/tensor_operation_instance/gpu/gemm_universal_xdl.inc new file mode 100644 index 0000000000..f0de713834 --- /dev/null +++ b/library/include/ck/library/tensor_operation_instance/gpu/gemm_universal_xdl.inc @@ -0,0 +1,521 @@ +// SPDX-License-Identifier: MIT +// Copyright (c) 2024-2025, Advanced Micro Devices, Inc. All rights reserved. + +#pragma once + +namespace ck { +namespace tensor_operation { +namespace device { +namespace instance { + +#ifdef CK_ENABLE_FP16 +void add_device_gemm_xdl_universal_f16_f16_f16_mk_kn_mn_comp_default_instances( + std::vector>>& + instances); + +void add_device_gemm_xdl_universal_f16_f16_f16_mk_kn_mn_comp_kpadding_instances( + std::vector>>& + instances); + +void add_device_gemm_xdl_universal_f16_f16_f16_mk_kn_mn_comp_mnpadding_instances( + std::vector>>& + instances); + +void add_device_gemm_xdl_universal_f16_f16_f16_mk_kn_mn_comp_mnkpadding_instances( + std::vector>>& + instances); + +void add_device_gemm_xdl_universal_f16_f16_f16_mk_kn_mn_mem_v1_default_instances( + std::vector>>& + instances); + +void add_device_gemm_xdl_universal_f16_f16_f16_mk_kn_mn_mem_v1_kpadding_instances( + std::vector>>& + instances); + +void add_device_gemm_xdl_universal_f16_f16_f16_mk_kn_mn_mem_v1_mnkpadding_instances( + std::vector>>& + instances); + +void add_device_gemm_xdl_universal_f16_f16_f16_mk_kn_mn_mem_v2_default_instances( + std::vector>>& + instances); + +void add_device_gemm_xdl_universal_f16_f16_f16_mk_kn_mn_mem_v2_kpadding_instances( + std::vector>>& + instances); + +void add_device_gemm_xdl_universal_f16_f16_f16_mk_kn_mn_mem_v2_mnkpadding_instances( + std::vector>>& + instances); + +void add_device_gemm_xdl_universal_f16_f16_f16_mk_nk_mn_comp_default_instances( + std::vector>>& + 
instances); + +void add_device_gemm_xdl_universal_f16_f16_f16_mk_nk_mn_comp_kpadding_instances( + std::vector>>& + instances); + +void add_device_gemm_xdl_universal_f16_f16_f16_mk_nk_mn_mem_v1_default_instances( + std::vector>>& + instances); + +void add_device_gemm_xdl_universal_f16_f16_f16_mk_nk_mn_mem_v1_kpadding_instances( + std::vector>>& + instances); + +void add_device_gemm_xdl_universal_f16_f16_f16_mk_nk_mn_mem_v2_default_instances( + std::vector>>& + instances); + +void add_device_gemm_xdl_universal_f16_f16_f16_mk_nk_mn_mem_v2_kpadding_instances( + std::vector>>& + instances); +#endif +#if(defined(CK_ENABLE_FP16) && defined(CK_ENABLE_FP8)) +void add_device_gemm_xdl_universal_f16_f8_f16_mk_kn_mn_comp_default_instances( + std::vector>>& + instances); + +void add_device_gemm_xdl_universal_f16_f8_f16_mk_kn_mn_comp_kpadding_instances( + std::vector>>& + instances); + +void add_device_gemm_xdl_universal_f16_f8_f16_mk_kn_mn_comp_mnpadding_instances( + std::vector>>& + instances); + +void add_device_gemm_xdl_universal_f16_f8_f16_mk_kn_mn_comp_mnkpadding_instances( + std::vector>>& + instances); + +void add_device_gemm_xdl_universal_f16_f8_f16_mk_kn_mn_mem_v1_default_instances( + std::vector>>& + instances); + +void add_device_gemm_xdl_universal_f16_f8_f16_mk_kn_mn_mem_v1_kpadding_instances( + std::vector>>& + instances); + +void add_device_gemm_xdl_universal_f16_f8_f16_mk_kn_mn_mem_v1_mnkpadding_instances( + std::vector>>& + instances); + +void add_device_gemm_xdl_universal_f16_f8_f16_mk_kn_mn_mem_v2_default_instances( + std::vector>>& + instances); + +void add_device_gemm_xdl_universal_f16_f8_f16_mk_kn_mn_mem_v2_kpadding_instances( + std::vector>>& + instances); + +void add_device_gemm_xdl_universal_f16_f8_f16_mk_kn_mn_mem_v2_mnkpadding_instances( + std::vector>>& + instances); +void add_device_gemm_xdl_universal_f16_f8_f16_mk_nk_mn_comp_default_instances( + std::vector>>& + instances); + +void 
add_device_gemm_xdl_universal_f16_f8_f16_mk_nk_mn_comp_kpadding_instances( + std::vector>>& + instances); + +void add_device_gemm_xdl_universal_f16_f8_f16_mk_nk_mn_mem_v1_default_instances( + std::vector>>& + instances); + +void add_device_gemm_xdl_universal_f16_f8_f16_mk_nk_mn_mem_v1_kpadding_instances( + std::vector>>& + instances); + +void add_device_gemm_xdl_universal_f16_f8_f16_mk_nk_mn_mem_v2_default_instances( + std::vector>>& + instances); + +void add_device_gemm_xdl_universal_f16_i4_f16_mk_nk_mn_mem_v2_default_instances( + std::vector>>& + instances); + +void add_device_gemm_xdl_universal_bf16_i4_bf16_mk_nk_mn_mem_v2_default_instances( + std::vector>>& + instances); + +void add_device_gemm_xdl_universal_f16_f8_f16_mk_nk_mn_mem_v2_kpadding_instances( + std::vector>>& + instances); + +void add_device_gemm_xdl_universal_f8_f16_f16_mk_kn_mn_comp_default_instances( + std::vector>>& + instances); + +void add_device_gemm_xdl_universal_f8_f16_f16_mk_kn_mn_comp_kpadding_instances( + std::vector>>& + instances); + +void add_device_gemm_xdl_universal_f8_f16_f16_mk_kn_mn_comp_mnpadding_instances( + std::vector>>& + instances); + +void add_device_gemm_xdl_universal_f8_f16_f16_mk_kn_mn_comp_mnkpadding_instances( + std::vector>>& + instances); + +void add_device_gemm_xdl_universal_f8_f16_f16_mk_kn_mn_mem_v1_default_instances( + std::vector>>& + instances); + +void add_device_gemm_xdl_universal_f8_f16_f16_mk_kn_mn_mem_v1_kpadding_instances( + std::vector>>& + instances); + +void add_device_gemm_xdl_universal_f8_f16_f16_mk_kn_mn_mem_v1_mnkpadding_instances( + std::vector>>& + instances); + +void add_device_gemm_xdl_universal_f8_f16_f16_mk_kn_mn_mem_v2_default_instances( + std::vector>>& + instances); + +void add_device_gemm_xdl_universal_f8_f16_f16_mk_kn_mn_mem_v2_kpadding_instances( + std::vector>>& + instances); + +void add_device_gemm_xdl_universal_f8_f16_f16_mk_kn_mn_mem_v2_mnkpadding_instances( + std::vector>>& + instances); +void 
add_device_gemm_xdl_universal_f8_f16_f16_mk_nk_mn_comp_default_instances( + std::vector>>& + instances); + +void add_device_gemm_xdl_universal_f8_f16_f16_mk_nk_mn_comp_kpadding_instances( + std::vector>>& + instances); + +void add_device_gemm_xdl_universal_f8_f16_f16_mk_nk_mn_mem_v1_default_instances( + std::vector>>& + instances); + +void add_device_gemm_xdl_universal_f8_f16_f16_mk_nk_mn_mem_v1_kpadding_instances( + std::vector>>& + instances); + +void add_device_gemm_xdl_universal_f8_f16_f16_mk_nk_mn_mem_v2_default_instances( + std::vector>>& + instances); + +void add_device_gemm_xdl_universal_f8_f16_f16_mk_nk_mn_mem_v2_kpadding_instances( + std::vector>>& + instances); +#endif +#ifdef CK_ENABLE_BF16 +void add_device_gemm_xdl_universal_bf16_bf16_bf16_mk_kn_mn_comp_default_instances( + std::vector>>& + instances); + +void add_device_gemm_xdl_universal_bf16_bf16_bf16_mk_kn_mn_comp_kpadding_instances( + std::vector>>& + instances); + +void add_device_gemm_xdl_universal_bf16_bf16_bf16_mk_kn_mn_comp_mnpadding_instances( + std::vector>>& + instances); + +void add_device_gemm_xdl_universal_bf16_bf16_bf16_mk_kn_mn_comp_mnkpadding_instances( + std::vector>>& + instances); + +void add_device_gemm_xdl_universal_bf16_bf16_bf16_mk_kn_mn_mem_v1_default_instances( + std::vector>>& + instances); + +void add_device_gemm_xdl_universal_bf16_bf16_bf16_mk_kn_mn_mem_v1_kpadding_instances( + std::vector>>& + instances); + +void add_device_gemm_xdl_universal_bf16_bf16_bf16_mk_kn_mn_mem_v1_mnkpadding_instances( + std::vector>>& + instances); + +void add_device_gemm_xdl_universal_bf16_bf16_bf16_mk_kn_mn_mem_v2_default_instances( + std::vector>>& + instances); + +void add_device_gemm_xdl_universal_bf16_bf16_bf16_mk_kn_mn_mem_v2_kpadding_instances( + std::vector>>& + instances); + +void add_device_gemm_xdl_universal_bf16_bf16_bf16_mk_kn_mn_mem_v2_mnkpadding_instances( + std::vector>>& + instances); + +void add_device_gemm_xdl_universal_bf16_bf16_bf16_mk_nk_mn_comp_default_instances( + 
std::vector>>& + instances); + +void add_device_gemm_xdl_universal_bf16_bf16_bf16_mk_nk_mn_comp_kpadding_instances( + std::vector>>& + instances); + +void add_device_gemm_xdl_universal_bf16_bf16_bf16_mk_nk_mn_mem_v1_default_instances( + std::vector>>& + instances); + +void add_device_gemm_xdl_universal_bf16_bf16_bf16_mk_nk_mn_mem_v1_kpadding_instances( + std::vector>>& + instances); + +void add_device_gemm_xdl_universal_bf16_bf16_bf16_mk_nk_mn_mem_v2_default_instances( + std::vector>>& + instances); + +void add_device_gemm_xdl_universal_bf16_bf16_bf16_mk_nk_mn_mem_v2_kpadding_instances( + std::vector>>& + instances); +void add_device_gemm_xdl_universal_bf16_bf16_bf16_km_kn_mn_comp_default_instances( + std::vector>>& + instances); + +void add_device_gemm_xdl_universal_bf16_bf16_bf16_km_kn_mn_comp_kpadding_instances( + std::vector>>& + instances); + +void add_device_gemm_xdl_universal_bf16_bf16_bf16_km_kn_mn_comp_mnpadding_instances( + std::vector>>& + instances); + +void add_device_gemm_xdl_universal_bf16_bf16_bf16_km_kn_mn_comp_mnkpadding_instances( + std::vector>>& + instances); + +void add_device_gemm_xdl_universal_bf16_bf16_bf16_km_kn_mn_mem_v1_default_instances( + std::vector>>& + instances); + +void add_device_gemm_xdl_universal_bf16_bf16_bf16_km_kn_mn_mem_v1_kpadding_instances( + std::vector>>& + instances); + +void add_device_gemm_xdl_universal_bf16_bf16_bf16_km_kn_mn_mem_v1_mnkpadding_instances( + std::vector>>& + instances); + +void add_device_gemm_xdl_universal_bf16_bf16_bf16_km_kn_mn_mem_v2_default_instances( + std::vector>>& + instances); + +void add_device_gemm_xdl_universal_bf16_bf16_bf16_km_kn_mn_mem_v2_kpadding_instances( + std::vector>>& + instances); + +void add_device_gemm_xdl_universal_bf16_bf16_bf16_km_kn_mn_mem_v2_mnkpadding_instances( + std::vector>>& + instances); + +void add_device_gemm_xdl_universal_bf16_bf16_bf16_km_nk_mn_comp_default_instances( + std::vector>>& + instances); + +void 
add_device_gemm_xdl_universal_bf16_bf16_bf16_km_nk_mn_comp_kpadding_instances( + std::vector>>& + instances); + +void add_device_gemm_xdl_universal_bf16_bf16_bf16_km_nk_mn_comp_mpadding_instances( + std::vector>>& + instances); + +void add_device_gemm_xdl_universal_bf16_bf16_bf16_km_nk_mn_comp_mkpadding_instances( + std::vector>>& + instances); + +void add_device_gemm_xdl_universal_bf16_bf16_bf16_km_nk_mn_mem_v1_default_instances( + std::vector>>& + instances); + +void add_device_gemm_xdl_universal_bf16_bf16_bf16_km_nk_mn_mem_v1_kpadding_instances( + std::vector>>& + instances); + +void add_device_gemm_xdl_universal_bf16_bf16_bf16_km_nk_mn_mem_v1_mkpadding_instances( + std::vector>>& + instances); + +void add_device_gemm_xdl_universal_bf16_bf16_bf16_km_nk_mn_mem_v2_default_instances( + std::vector>>& + instances); + +void add_device_gemm_xdl_universal_bf16_bf16_bf16_km_nk_mn_mem_v2_kpadding_instances( + std::vector>>& + instances); + +void add_device_gemm_xdl_universal_bf16_bf16_bf16_km_nk_mn_mem_v2_mkpadding_instances( + std::vector>>& + instances); +#endif +#if(defined(CK_ENABLE_BF16) && defined(CK_ENABLE_FP8)) +void add_device_gemm_xdl_universal_f8_f8_bf16_mk_kn_mn_comp_default_instances( + std::vector>>& + instances); + +void add_device_gemm_xdl_universal_f8_f8_bf16_mk_kn_mn_comp_kpadding_instances( + std::vector>>& + instances); + +void add_device_gemm_xdl_universal_f8_f8_bf16_mk_kn_mn_comp_nkpadding_instances( + std::vector>>& + instances); + +void add_device_gemm_xdl_universal_f8_f8_bf16_mk_kn_mn_mem_v1_default_instances( + std::vector>>& + instances); + +void add_device_gemm_xdl_universal_f8_f8_bf16_mk_kn_mn_mem_v1_kpadding_instances( + std::vector>>& + instances); + +void add_device_gemm_xdl_universal_f8_f8_bf16_mk_kn_mn_mem_v1_nkpadding_instances( + std::vector>>& + instances); + +void add_device_gemm_xdl_universal_f8_f8_bf16_mk_kn_mn_mem_v2_default_instances( + std::vector>>& + instances); + +void 
add_device_gemm_xdl_universal_f8_f8_bf16_mk_kn_mn_mem_v2_kpadding_instances( + std::vector>>& + instances); + +void add_device_gemm_xdl_universal_f8_f8_bf16_mk_kn_mn_mem_v2_nkpadding_instances( + std::vector>>& + instances); + +void add_device_gemm_xdl_universal_f8_f8_bf16_mk_nk_mn_comp_default_instances( + std::vector>>& + instances); + +void add_device_gemm_xdl_universal_f8_f8_bf16_mk_nk_mn_comp_kpadding_instances( + std::vector>>& + instances); + +void add_device_gemm_xdl_universal_f8_f8_bf16_mk_nk_mn_mem_v1_default_instances( + std::vector>>& + instances); + +void add_device_gemm_xdl_universal_f8_f8_bf16_mk_nk_mn_mem_v1_kpadding_instances( + std::vector>>& + instances); + +void add_device_gemm_xdl_universal_f8_f8_bf16_mk_nk_mn_mem_v2_default_instances( + std::vector>>& + instances); + +void add_device_gemm_xdl_universal_f8_f8_bf16_mk_nk_mn_mem_v2_kpadding_instances( + std::vector>>& + instances); +#endif + +} // namespace instance +} // namespace device +} // namespace tensor_operation +} // namespace ck diff --git a/library/src/tensor_operation_instance/gpu/CMakeLists.txt b/library/src/tensor_operation_instance/gpu/CMakeLists.txt index 70e54962ed..fe35d9ca76 100755 --- a/library/src/tensor_operation_instance/gpu/CMakeLists.txt +++ b/library/src/tensor_operation_instance/gpu/CMakeLists.txt @@ -81,21 +81,29 @@ function(add_instance_library INSTANCE_NAME) list(REMOVE_ITEM ARGN "${source}") endif() endforeach() - # Do not build gemm_universal_f8 or gemm_multiply_multiply_f8 for any targets except gfx94 + # Do not build XDL gemm_universal_f8 or gemm_multiply_multiply_f8 for any targets except gfx94 if(NOT CK_USE_FP8_ON_UNSUPPORTED_ARCH) + foreach(source IN LISTS ARGN) + if(NOT INST_TARGETS MATCHES "gfx94" AND NOT INST_TARGETS MATCHES "gfx95" AND source MATCHES "gemm_multiply_multiply" AND source MATCHES "_f8_") + message("removing gemm_multiply_multiply_f8 instance ${source} ") + list(REMOVE_ITEM ARGN "${source}") + endif() + endforeach() + foreach(source IN LISTS 
ARGN) + if(NOT INST_TARGETS MATCHES "gfx94" AND NOT INST_TARGETS MATCHES "gfx95" AND source MATCHES "gemm_xdl_universal" AND source MATCHES "_f8_") + message("removing gemm_universal_f8 instance ${source} ") + list(REMOVE_ITEM ARGN "${source}") + endif() + endforeach() + endif() + # Do not build WMMA gemm_universal_f8 for any targets except gfx12+ foreach(source IN LISTS ARGN) - if(NOT INST_TARGETS MATCHES "gfx94" AND NOT INST_TARGETS MATCHES "gfx95" AND source MATCHES "gemm_multiply_multiply" AND source MATCHES "_f8_") - message("removing gemm_multiply_multiply_f8 instance ${source} ") - list(REMOVE_ITEM ARGN "${source}") - endif() + if(NOT INST_TARGETS MATCHES "gfx12" AND source MATCHES "gemm_wmma_universal" AND source MATCHES "_f8_") + message("removing gemm_universal_f8 instance ${source} ") + list(REMOVE_ITEM ARGN "${source}") + endif() endforeach() - foreach(source IN LISTS ARGN) - if(NOT INST_TARGETS MATCHES "gfx94" AND NOT INST_TARGETS MATCHES "gfx95" AND source MATCHES "gemm_xdl_universal" AND source MATCHES "_f8_") - message("removing gemm_universal_f8 instance ${source} ") - list(REMOVE_ITEM ARGN "${source}") - endif() - endforeach() - endif() + message("remaining instances: ${ARGN}") #only continue if there are some source files left on the list if(ARGN) set(INST_OBJ) @@ -124,7 +132,10 @@ function(add_instance_library INSTANCE_NAME) endif() if(source MATCHES "gemm_multiply_multiply" AND source MATCHES "f8") list(REMOVE_ITEM INST_TARGETS gfx900 gfx906 gfx906:xnack- gfx908:xnack- gfx908:xnack+ gfx908 gfx1030 gfx1100 gfx1101 gfx1102 gfx1103 gfx1150 gfx1151 gfx1152 gfx1200 gfx1201 gfx10.3-generic gfx11-generic gfx12-generic) - endif() + endif() + endif() + if(source MATCHES "gemm_wmma_universal" AND source MATCHES "f8") + list(FILTER INST_TARGETS INCLUDE REGEX "gfx12") endif() set(offload_targets) foreach(target IN LISTS INST_TARGETS) @@ -455,4 +466,3 @@ set(DEV_OPS_INC_DIRS ${PROJECT_SOURCE_DIR}/library/include/ck/ ) rocm_install(DIRECTORY 
${DEV_OPS_INC_DIRS} DESTINATION ${CMAKE_INSTALL_INCLUDEDIR}/ck) - diff --git a/library/src/tensor_operation_instance/gpu/gemm_universal/CMakeLists.txt b/library/src/tensor_operation_instance/gpu/gemm_universal/CMakeLists.txt index ade65eacf3..18eeefa522 100644 --- a/library/src/tensor_operation_instance/gpu/gemm_universal/CMakeLists.txt +++ b/library/src/tensor_operation_instance/gpu/gemm_universal/CMakeLists.txt @@ -1,7 +1,17 @@ -# ONLY XDL_KERNELS +# ONLY XDL_AND_WMMA_KERNELS set(GEMM_UNIVERSAL_INSTANCES) -list(APPEND GEMM_UNIVERSAL_INSTANCES +list(APPEND GEMM_UNIVERSAL_INSTANCES + device_gemm_wmma_universal_f16_f16_f16/device_gemm_wmma_universal_f16_f16_f16_mk_kn_mn_comp_default_instance.cpp + device_gemm_wmma_universal_f16_f16_f16/device_gemm_wmma_universal_f16_f16_f16_mk_nk_mn_comp_default_instance.cpp + device_gemm_wmma_universal_f16_f16_f16/device_gemm_wmma_universal_f16_f16_f16_km_kn_mn_comp_default_instance.cpp + device_gemm_wmma_universal_f16_f16_f16/device_gemm_wmma_universal_f16_f16_f16_km_nk_mn_comp_default_instance.cpp + + device_gemm_wmma_universal_bf16_bf16_bf16/device_gemm_wmma_universal_bf16_bf16_bf16_mk_kn_mn_comp_default_instance.cpp + device_gemm_wmma_universal_bf16_bf16_bf16/device_gemm_wmma_universal_bf16_bf16_bf16_mk_nk_mn_comp_default_instance.cpp + device_gemm_wmma_universal_bf16_bf16_bf16/device_gemm_wmma_universal_bf16_bf16_bf16_km_kn_mn_comp_default_instance.cpp + device_gemm_wmma_universal_bf16_bf16_bf16/device_gemm_wmma_universal_bf16_bf16_bf16_km_nk_mn_comp_default_instance.cpp + device_gemm_xdl_universal_f16_f16_f16/device_gemm_xdl_universal_f16_f16_f16_mk_kn_mn_comp_default_instance.cpp device_gemm_xdl_universal_f16_f16_f16/device_gemm_xdl_universal_f16_f16_f16_mk_kn_mn_comp_kpadding_instance.cpp device_gemm_xdl_universal_f16_f16_f16/device_gemm_xdl_universal_f16_f16_f16_mk_kn_mn_comp_mnpadding_instance.cpp @@ -18,7 +28,7 @@ list(APPEND GEMM_UNIVERSAL_INSTANCES 
device_gemm_xdl_universal_f16_f16_f16/device_gemm_xdl_universal_f16_f16_f16_mk_nk_mn_mem_v1_kpadding_instance.cpp device_gemm_xdl_universal_f16_f16_f16/device_gemm_xdl_universal_f16_f16_f16_mk_nk_mn_mem_v2_default_instance.cpp device_gemm_xdl_universal_f16_f16_f16/device_gemm_xdl_universal_f16_f16_f16_mk_nk_mn_mem_v2_kpadding_instance.cpp - + device_gemm_xdl_universal_bf16_bf16_bf16/device_gemm_xdl_universal_bf16_bf16_bf16_mk_kn_mn_comp_default_instance.cpp device_gemm_xdl_universal_bf16_bf16_bf16/device_gemm_xdl_universal_bf16_bf16_bf16_mk_kn_mn_comp_kpadding_instance.cpp device_gemm_xdl_universal_bf16_bf16_bf16/device_gemm_xdl_universal_bf16_bf16_bf16_mk_kn_mn_comp_mnpadding_instance.cpp @@ -57,6 +67,16 @@ list(APPEND GEMM_UNIVERSAL_INSTANCES device_gemm_xdl_universal_bf16_bf16_bf16/device_gemm_xdl_universal_bf16_bf16_bf16_km_nk_mn_mem_v2_mkpadding_instance.cpp ) +set_source_files_properties(device_gemm_wmma_universal_f16_f16_f16/device_gemm_wmma_universal_f16_f16_f16_mk_kn_mn_comp_default_instance.cpp PROPERTIES COMPILE_OPTIONS ";-mllvm;-greedy-reverse-local-assignment=1") +set_source_files_properties(device_gemm_wmma_universal_f16_f16_f16/device_gemm_wmma_universal_f16_f16_f16_mk_nk_mn_comp_default_instance.cpp PROPERTIES COMPILE_OPTIONS ";-mllvm;-greedy-reverse-local-assignment=1") +set_source_files_properties(device_gemm_wmma_universal_f16_f16_f16/device_gemm_wmma_universal_f16_f16_f16_km_kn_mn_comp_default_instance.cpp PROPERTIES COMPILE_OPTIONS ";-mllvm;-greedy-reverse-local-assignment=1") +set_source_files_properties(device_gemm_wmma_universal_f16_f16_f16/device_gemm_wmma_universal_f16_f16_f16_km_nk_mn_comp_default_instance.cpp PROPERTIES COMPILE_OPTIONS ";-mllvm;-greedy-reverse-local-assignment=1") + +set_source_files_properties(device_gemm_wmma_universal_bf16_bf16_bf16/device_gemm_wmma_universal_bf16_bf16_bf16_mk_kn_mn_comp_default_instance.cpp PROPERTIES COMPILE_OPTIONS ";-mllvm;-greedy-reverse-local-assignment=1") 
+set_source_files_properties(device_gemm_wmma_universal_bf16_bf16_bf16/device_gemm_wmma_universal_bf16_bf16_bf16_mk_nk_mn_comp_default_instance.cpp PROPERTIES COMPILE_OPTIONS ";-mllvm;-greedy-reverse-local-assignment=1") +set_source_files_properties(device_gemm_wmma_universal_bf16_bf16_bf16/device_gemm_wmma_universal_bf16_bf16_bf16_km_kn_mn_comp_default_instance.cpp PROPERTIES COMPILE_OPTIONS ";-mllvm;-greedy-reverse-local-assignment=1") +set_source_files_properties(device_gemm_wmma_universal_bf16_bf16_bf16/device_gemm_wmma_universal_bf16_bf16_bf16_km_nk_mn_comp_default_instance.cpp PROPERTIES COMPILE_OPTIONS ";-mllvm;-greedy-reverse-local-assignment=1") + set_source_files_properties(device_gemm_xdl_universal_f16_f16_f16/device_gemm_xdl_universal_f16_f16_f16_mk_kn_mn_comp_default_instance.cpp PROPERTIES COMPILE_OPTIONS ";-mllvm;-greedy-reverse-local-assignment=1") set_source_files_properties(device_gemm_xdl_universal_f16_f16_f16/device_gemm_xdl_universal_f16_f16_f16_mk_kn_mn_comp_kpadding_instance.cpp PROPERTIES COMPILE_OPTIONS ";-mllvm;-greedy-reverse-local-assignment=1") set_source_files_properties(device_gemm_xdl_universal_f16_f16_f16/device_gemm_xdl_universal_f16_f16_f16_mk_kn_mn_comp_mnpadding_instance.cpp PROPERTIES COMPILE_OPTIONS ";-mllvm;-greedy-reverse-local-assignment=1") @@ -80,6 +100,9 @@ set_source_files_properties(device_gemm_xdl_universal_bf16_bf16_bf16/device_gemm set_source_files_properties(device_gemm_xdl_universal_bf16_bf16_bf16/device_gemm_xdl_universal_bf16_bf16_bf16_km_nk_mn_comp_mkpadding_instance.cpp PROPERTIES COMPILE_OPTIONS ";-mllvm;-greedy-reverse-local-assignment=1") list(APPEND GEMM_UNIVERSAL_INSTANCES + device_gemm_wmma_universal_f8_f8_bf16/device_gemm_wmma_universal_f8_f8_bf16_mk_kn_mn_comp_default_instance.cpp + device_gemm_wmma_universal_f8_f8_bf16/device_gemm_wmma_universal_f8_f8_bf16_mk_nk_mn_comp_default_instance.cpp + device_gemm_xdl_universal_f16_f8_f16/device_gemm_xdl_universal_f16_f8_f16_mk_kn_mn_comp_default_instance.cpp 
device_gemm_xdl_universal_f16_f8_f16/device_gemm_xdl_universal_f16_f8_f16_mk_kn_mn_comp_kpadding_instance.cpp device_gemm_xdl_universal_f16_f8_f16/device_gemm_xdl_universal_f16_f8_f16_mk_kn_mn_comp_mnpadding_instance.cpp @@ -134,25 +157,28 @@ list(APPEND GEMM_UNIVERSAL_INSTANCES device_gemm_xdl_universal_f8_f8_bf16/device_gemm_xdl_universal_f8_f8_bf16_mk_nk_mn_mem_v2_kpadding_instance.cpp ) - set_source_files_properties(device_gemm_xdl_universal_f16_f8_f16/device_gemm_xdl_universal_f16_f8_f16_mk_kn_mn_comp_default_instance.cpp PROPERTIES COMPILE_OPTIONS ";-mllvm;-greedy-reverse-local-assignment=1") - set_source_files_properties(device_gemm_xdl_universal_f16_f8_f16/device_gemm_xdl_universal_f16_f8_f16_mk_kn_mn_comp_kpadding_instance.cpp PROPERTIES COMPILE_OPTIONS ";-mllvm;-greedy-reverse-local-assignment=1") - set_source_files_properties(device_gemm_xdl_universal_f16_f8_f16/device_gemm_xdl_universal_f16_f8_f16_mk_kn_mn_comp_mnpadding_instance.cpp PROPERTIES COMPILE_OPTIONS ";-mllvm;-greedy-reverse-local-assignment=1") - set_source_files_properties(device_gemm_xdl_universal_f16_f8_f16/device_gemm_xdl_universal_f16_f8_f16_mk_kn_mn_comp_mnkpadding_instance.cpp PROPERTIES COMPILE_OPTIONS ";-mllvm;-greedy-reverse-local-assignment=1") - set_source_files_properties(device_gemm_xdl_universal_f16_f8_f16/device_gemm_xdl_universal_f16_f8_f16_mk_nk_mn_comp_default_instance.cpp PROPERTIES COMPILE_OPTIONS ";-mllvm;-greedy-reverse-local-assignment=1") - set_source_files_properties(device_gemm_xdl_universal_f16_f8_f16/device_gemm_xdl_universal_f16_f8_f16_mk_nk_mn_comp_kpadding_instance.cpp PROPERTIES COMPILE_OPTIONS ";-mllvm;-greedy-reverse-local-assignment=1") - - set_source_files_properties(device_gemm_xdl_universal_f8_f16_f16/device_gemm_xdl_universal_f8_f16_f16_mk_kn_mn_comp_default_instance.cpp PROPERTIES COMPILE_OPTIONS ";-mllvm;-greedy-reverse-local-assignment=1") - 
set_source_files_properties(device_gemm_xdl_universal_f8_f16_f16/device_gemm_xdl_universal_f8_f16_f16_mk_kn_mn_comp_kpadding_instance.cpp PROPERTIES COMPILE_OPTIONS ";-mllvm;-greedy-reverse-local-assignment=1") - set_source_files_properties(device_gemm_xdl_universal_f8_f16_f16/device_gemm_xdl_universal_f8_f16_f16_mk_kn_mn_comp_mnpadding_instance.cpp PROPERTIES COMPILE_OPTIONS ";-mllvm;-greedy-reverse-local-assignment=1") - set_source_files_properties(device_gemm_xdl_universal_f8_f16_f16/device_gemm_xdl_universal_f8_f16_f16_mk_kn_mn_comp_mnkpadding_instance.cpp PROPERTIES COMPILE_OPTIONS ";-mllvm;-greedy-reverse-local-assignment=1") - set_source_files_properties(device_gemm_xdl_universal_f8_f16_f16/device_gemm_xdl_universal_f8_f16_f16_mk_nk_mn_comp_default_instance.cpp PROPERTIES COMPILE_OPTIONS ";-mllvm;-greedy-reverse-local-assignment=1") - set_source_files_properties(device_gemm_xdl_universal_f8_f16_f16/device_gemm_xdl_universal_f8_f16_f16_mk_nk_mn_comp_kpadding_instance.cpp PROPERTIES COMPILE_OPTIONS ";-mllvm;-greedy-reverse-local-assignment=1") +set_source_files_properties(device_gemm_wmma_universal_f8_f8_bf16/device_gemm_wmma_universal_f8_f8_bf16_mk_kn_mn_comp_default_instance.cpp PROPERTIES COMPILE_OPTIONS ";-mllvm;-greedy-reverse-local-assignment=1") +set_source_files_properties(device_gemm_wmma_universal_f8_f8_bf16/device_gemm_wmma_universal_f8_f8_bf16_mk_nk_mn_comp_default_instance.cpp PROPERTIES COMPILE_OPTIONS ";-mllvm;-greedy-reverse-local-assignment=1") - set_source_files_properties(device_gemm_xdl_universal_f8_f8_bf16/device_gemm_xdl_universal_f8_f8_bf16_mk_kn_mn_comp_default_instance.cpp PROPERTIES COMPILE_OPTIONS ";-mllvm;-greedy-reverse-local-assignment=1") - set_source_files_properties(device_gemm_xdl_universal_f8_f8_bf16/device_gemm_xdl_universal_f8_f8_bf16_mk_kn_mn_comp_kpadding_instance.cpp PROPERTIES COMPILE_OPTIONS ";-mllvm;-greedy-reverse-local-assignment=1") - 
set_source_files_properties(device_gemm_xdl_universal_f8_f8_bf16/device_gemm_xdl_universal_f8_f8_bf16_mk_kn_mn_comp_nkpadding_instance.cpp PROPERTIES COMPILE_OPTIONS ";-mllvm;-greedy-reverse-local-assignment=1") - set_source_files_properties(device_gemm_xdl_universal_f8_f8_bf16/device_gemm_xdl_universal_f8_f8_bf16_mk_nk_mn_comp_default_instance.cpp PROPERTIES COMPILE_OPTIONS ";-mllvm;-greedy-reverse-local-assignment=1") - set_source_files_properties(device_gemm_xdl_universal_f8_f8_bf16/device_gemm_xdl_universal_f8_f8_bf16_mk_nk_mn_comp_kpadding_instance.cpp PROPERTIES COMPILE_OPTIONS ";-mllvm;-greedy-reverse-local-assignment=1") +set_source_files_properties(device_gemm_xdl_universal_f16_f8_f16/device_gemm_xdl_universal_f16_f8_f16_mk_kn_mn_comp_default_instance.cpp PROPERTIES COMPILE_OPTIONS ";-mllvm;-greedy-reverse-local-assignment=1") +set_source_files_properties(device_gemm_xdl_universal_f16_f8_f16/device_gemm_xdl_universal_f16_f8_f16_mk_kn_mn_comp_kpadding_instance.cpp PROPERTIES COMPILE_OPTIONS ";-mllvm;-greedy-reverse-local-assignment=1") +set_source_files_properties(device_gemm_xdl_universal_f16_f8_f16/device_gemm_xdl_universal_f16_f8_f16_mk_kn_mn_comp_mnpadding_instance.cpp PROPERTIES COMPILE_OPTIONS ";-mllvm;-greedy-reverse-local-assignment=1") +set_source_files_properties(device_gemm_xdl_universal_f16_f8_f16/device_gemm_xdl_universal_f16_f8_f16_mk_kn_mn_comp_mnkpadding_instance.cpp PROPERTIES COMPILE_OPTIONS ";-mllvm;-greedy-reverse-local-assignment=1") +set_source_files_properties(device_gemm_xdl_universal_f16_f8_f16/device_gemm_xdl_universal_f16_f8_f16_mk_nk_mn_comp_default_instance.cpp PROPERTIES COMPILE_OPTIONS ";-mllvm;-greedy-reverse-local-assignment=1") +set_source_files_properties(device_gemm_xdl_universal_f16_f8_f16/device_gemm_xdl_universal_f16_f8_f16_mk_nk_mn_comp_kpadding_instance.cpp PROPERTIES COMPILE_OPTIONS ";-mllvm;-greedy-reverse-local-assignment=1") + 
+set_source_files_properties(device_gemm_xdl_universal_f8_f16_f16/device_gemm_xdl_universal_f8_f16_f16_mk_kn_mn_comp_default_instance.cpp PROPERTIES COMPILE_OPTIONS ";-mllvm;-greedy-reverse-local-assignment=1") +set_source_files_properties(device_gemm_xdl_universal_f8_f16_f16/device_gemm_xdl_universal_f8_f16_f16_mk_kn_mn_comp_kpadding_instance.cpp PROPERTIES COMPILE_OPTIONS ";-mllvm;-greedy-reverse-local-assignment=1") +set_source_files_properties(device_gemm_xdl_universal_f8_f16_f16/device_gemm_xdl_universal_f8_f16_f16_mk_kn_mn_comp_mnpadding_instance.cpp PROPERTIES COMPILE_OPTIONS ";-mllvm;-greedy-reverse-local-assignment=1") +set_source_files_properties(device_gemm_xdl_universal_f8_f16_f16/device_gemm_xdl_universal_f8_f16_f16_mk_kn_mn_comp_mnkpadding_instance.cpp PROPERTIES COMPILE_OPTIONS ";-mllvm;-greedy-reverse-local-assignment=1") +set_source_files_properties(device_gemm_xdl_universal_f8_f16_f16/device_gemm_xdl_universal_f8_f16_f16_mk_nk_mn_comp_default_instance.cpp PROPERTIES COMPILE_OPTIONS ";-mllvm;-greedy-reverse-local-assignment=1") +set_source_files_properties(device_gemm_xdl_universal_f8_f16_f16/device_gemm_xdl_universal_f8_f16_f16_mk_nk_mn_comp_kpadding_instance.cpp PROPERTIES COMPILE_OPTIONS ";-mllvm;-greedy-reverse-local-assignment=1") + +set_source_files_properties(device_gemm_xdl_universal_f8_f8_bf16/device_gemm_xdl_universal_f8_f8_bf16_mk_kn_mn_comp_default_instance.cpp PROPERTIES COMPILE_OPTIONS ";-mllvm;-greedy-reverse-local-assignment=1") +set_source_files_properties(device_gemm_xdl_universal_f8_f8_bf16/device_gemm_xdl_universal_f8_f8_bf16_mk_kn_mn_comp_kpadding_instance.cpp PROPERTIES COMPILE_OPTIONS ";-mllvm;-greedy-reverse-local-assignment=1") +set_source_files_properties(device_gemm_xdl_universal_f8_f8_bf16/device_gemm_xdl_universal_f8_f8_bf16_mk_kn_mn_comp_nkpadding_instance.cpp PROPERTIES COMPILE_OPTIONS ";-mllvm;-greedy-reverse-local-assignment=1") 
+set_source_files_properties(device_gemm_xdl_universal_f8_f8_bf16/device_gemm_xdl_universal_f8_f8_bf16_mk_nk_mn_comp_default_instance.cpp PROPERTIES COMPILE_OPTIONS ";-mllvm;-greedy-reverse-local-assignment=1") +set_source_files_properties(device_gemm_xdl_universal_f8_f8_bf16/device_gemm_xdl_universal_f8_f8_bf16_mk_nk_mn_comp_kpadding_instance.cpp PROPERTIES COMPILE_OPTIONS ";-mllvm;-greedy-reverse-local-assignment=1") add_instance_library(device_gemm_universal_instance ${GEMM_UNIVERSAL_INSTANCES}) diff --git a/library/src/tensor_operation_instance/gpu/gemm_universal/device_gemm_wmma_universal_bf16_bf16_bf16/device_gemm_wmma_universal_bf16_bf16_bf16_km_kn_mn.hpp b/library/src/tensor_operation_instance/gpu/gemm_universal/device_gemm_wmma_universal_bf16_bf16_bf16/device_gemm_wmma_universal_bf16_bf16_bf16_km_kn_mn.hpp new file mode 100644 index 0000000000..5d3bb3f7b4 --- /dev/null +++ b/library/src/tensor_operation_instance/gpu/gemm_universal/device_gemm_wmma_universal_bf16_bf16_bf16/device_gemm_wmma_universal_bf16_bf16_bf16_km_kn_mn.hpp @@ -0,0 +1,64 @@ +// SPDX-License-Identifier: MIT +// Copyright (c) 2025, Advanced Micro Devices, Inc. All rights reserved. 
+ +#include "ck/ck.hpp" +#include "ck/tensor_operation/gpu/device/tensor_layout.hpp" +#include "ck/tensor_operation/gpu/device/gemm_specialization.hpp" +#include "ck/tensor_operation/gpu/device/impl/device_gemm_wmma_cshuffle_v3.hpp" + +#include "ck/library/tensor_operation_instance/add_device_operation_instance.hpp" + +namespace ck { +namespace tensor_operation { +namespace device { +namespace instance { + +using BF16 = bhalf_t; +using F32 = float; + +using Row = tensor_layout::gemm::RowMajor; +using Col = tensor_layout::gemm::ColumnMajor; + +template +using S = Sequence; + +using PassThrough = element_wise::PassThrough; + +static constexpr auto GemmDefault = GemmSpecialization::Default; +static constexpr auto GemmKPadding = GemmSpecialization::KPadding; +static constexpr auto GemmMNPadding = GemmSpecialization::MNPadding; +static constexpr auto GemmMNKPadding = GemmSpecialization::MNKPadding; + +static constexpr auto Intrawave = BlockGemmPipelineScheduler::Intrawave; +static constexpr auto Interwave = BlockGemmPipelineScheduler::Interwave; + +template +using device_gemm_wmma_universal_bf16_bf16_bf16_km_kn_mn_comp_instances = + std::tuple< + // clang-format off + //#########################| ALayout| BLayout| CLayout|AData| BData| CData| AccData| CShuffle| A| B| C| GemmSpec| Block| MPer| NPer| KPer| AK1| BK1| MPer| NPer| MRepeat| NRepeat| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockLds| BBlockTransfer| BBlockTransfer| BBlockTransfer| BBlockTransfer| BBlockTransfer| BBlockTransfer| BBlockLds| CShuffle| CShuffle| CShuffleBlockTransfer| CShuffleBlockTransfer| BlkGemm| BlkGemm| + //#########################| | | | Type| Type| Type| Type| DataType| Elementwise| Elementwise| Elementwise| | Size| Block| Block| Block| | | Wmma| Wmma| | | ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| ExtraM| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| ExtraN| 
MRepeat| NRepeat| ClusterLengths| ScalarPerVector| PipeSched| PipelineVer| + //#########################| | | | | | | | | Operation| Operation| Operation| | | | | | | | | | | | Lengths_AK0_M_AK1| ArrangeOrder| | | PerVector| PerVector_AK1| | Lengths_BK0_N_BK1| ArrangeOrder| | | PerVector| PerVector_BK1| | PerShuffle| PerShuffle| _MBlock_MPerBlock| _NPerBlock| | | + //#########################| | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | _NBlock_NPerBlock| | | | + DeviceGemm_Wmma_CShuffleV3< Col, Row, Row, BF16, BF16, BF16, F32, BF16, PassThrough, PassThrough, PassThrough, GemmSpec, 256, 128, 128, 32, 8, 8, 16, 16, 4, 2, S<4, 64, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 1, 8, 1, S<4, 64, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 1, 8, 1, 1, 1, S<1, 32, 1, 8>, 8, Intrawave, BlockGemmPipelineVersion::v3>, + DeviceGemm_Wmma_CShuffleV3< Col, Row, Row, BF16, BF16, BF16, F32, BF16, PassThrough, PassThrough, PassThrough, GemmSpec, 128, 128, 64, 64, 8, 8, 16, 16, 4, 2, S<4, 32, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 1, 8, 1, S<4, 32, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 1, 8, 1, 1, 1, S<1, 32, 1, 4>, 8, Intrawave, BlockGemmPipelineVersion::v3>, + DeviceGemm_Wmma_CShuffleV3< Col, Row, Row, BF16, BF16, BF16, F32, BF16, PassThrough, PassThrough, PassThrough, GemmSpec, 64, 64, 32, 32, 8, 8, 16, 16, 4, 1, S<4, 16, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 1, 8, 1, S<4, 16, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 1, 8, 1, 1, 1, S<1, 16, 1, 4>, 8, Intrawave, BlockGemmPipelineVersion::v3>, + DeviceGemm_Wmma_CShuffleV3< Col, Row, Row, BF16, BF16, BF16, F32, BF16, PassThrough, PassThrough, PassThrough, GemmSpec, 32, 16, 16, 32, 8, 8, 16, 16, 1, 1, S<2, 16, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 1, 8, 1, S<2, 16, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 1, 8, 1, 1, 1, S<1, 16, 1, 2>, 8, Intrawave, BlockGemmPipelineVersion::v3>, + DeviceGemm_Wmma_CShuffleV3< Col, Row, Row, BF16, BF16, BF16, F32, BF16, PassThrough, PassThrough, PassThrough, GemmSpec, 256, 128, 256, 64, 8, 8, 16, 16, 4, 4, S<4, 64, 1>, S<0, 2, 1>, S<0, 2, 1>, 
1, 1, 8, 1, S<4, 64, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 1, 8, 1, 1, 1, S<1, 32, 1, 8>, 8, Intrawave, BlockGemmPipelineVersion::v3>, + DeviceGemm_Wmma_CShuffleV3< Col, Row, Row, BF16, BF16, BF16, F32, BF16, PassThrough, PassThrough, PassThrough, GemmSpec, 256, 128, 128, 64, 8, 8, 16, 16, 4, 2, S<4, 64, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 1, 8, 1, S<4, 64, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 1, 8, 1, 1, 1, S<1, 32, 1, 8>, 8, Intrawave, BlockGemmPipelineVersion::v3>, + DeviceGemm_Wmma_CShuffleV3< Col, Row, Row, BF16, BF16, BF16, F32, BF16, PassThrough, PassThrough, PassThrough, GemmSpec, 256, 128, 160, 64, 8, 8, 16, 16, 2, 5, S<4, 64, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 1, 8, 1, S<8, 32, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 1, 8, 1, 1, 1, S<1, 64, 1, 4>, 8, Intrawave, BlockGemmPipelineVersion::v3>, + DeviceGemm_Wmma_CShuffleV3< Col, Row, Row, BF16, BF16, BF16, F32, BF16, PassThrough, PassThrough, PassThrough, GemmSpec, 128, 128, 128, 32, 8, 8, 16, 16, 4, 4, S<4, 32, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 1, 8, 1, S<4, 32, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 1, 8, 1, 1, 1, S<1, 32, 1, 4>, 8, Intrawave, BlockGemmPipelineVersion::v3>, + DeviceGemm_Wmma_CShuffleV3< Col, Row, Row, BF16, BF16, BF16, F32, BF16, PassThrough, PassThrough, PassThrough, GemmSpec, 128, 256, 64, 64, 8, 8, 16, 16, 8, 2, S<4, 32, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 1, 8, 1, S<4, 32, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 1, 8, 1, 1, 1, S<1, 32, 1, 4>, 8, Intrawave, BlockGemmPipelineVersion::v3>, + DeviceGemm_Wmma_CShuffleV3< Col, Row, Row, BF16, BF16, BF16, F32, BF16, PassThrough, PassThrough, PassThrough, GemmSpec, 128, 64, 256, 64, 8, 8, 16, 16, 2, 8, S<4, 32, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 1, 8, 1, S<4, 32, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 1, 8, 1, 1, 1, S<1, 32, 1, 4>, 8, Intrawave, BlockGemmPipelineVersion::v3>, + DeviceGemm_Wmma_CShuffleV3< Col, Row, Row, BF16, BF16, BF16, F32, BF16, PassThrough, PassThrough, PassThrough, GemmSpec, 128, 64, 80, 64, 8, 8, 16, 16, 1, 5, S<4, 32, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 1, 8, 1, S<8, 16, 1>, S<0, 2, 1>, S<0, 
2, 1>, 1, 1, 8, 1, 1, 1, S<1, 64, 1, 2>, 8, Intrawave, BlockGemmPipelineVersion::v3>, + DeviceGemm_Wmma_CShuffleV3< Col, Row, Row, BF16, BF16, BF16, F32, BF16, PassThrough, PassThrough, PassThrough, GemmSpec, 64, 16, 64, 64, 8, 8, 16, 16, 1, 2, S<4, 16, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 1, 8, 1, S<4, 16, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 1, 8, 1, 1, 1, S<1, 16, 1, 4>, 8, Intrawave, BlockGemmPipelineVersion::v3>, + DeviceGemm_Wmma_CShuffleV3< Col, Row, Row, BF16, BF16, BF16, F32, BF16, PassThrough, PassThrough, PassThrough, GemmSpec, 64, 64, 32, 64, 8, 8, 16, 16, 4, 1, S<4, 16, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 1, 8, 1, S<4, 16, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 1, 8, 1, 1, 1, S<1, 16, 1, 4>, 8, Intrawave, BlockGemmPipelineVersion::v3>, + DeviceGemm_Wmma_CShuffleV3< Col, Row, Row, BF16, BF16, BF16, F32, BF16, PassThrough, PassThrough, PassThrough, GemmSpec, 64, 32, 64, 64, 8, 8, 16, 16, 2, 2, S<4, 16, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 1, 8, 1, S<4, 16, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 1, 8, 1, 1, 1, S<1, 16, 1, 4>, 8, Intrawave, BlockGemmPipelineVersion::v3>, + DeviceGemm_Wmma_CShuffleV3< Col, Row, Row, BF16, BF16, BF16, F32, BF16, PassThrough, PassThrough, PassThrough, GemmSpec, 32, 16, 32, 64, 8, 8, 16, 16, 1, 2, S<2, 16, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 1, 8, 1, S<2, 16, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 1, 8, 1, 1, 1, S<1, 16, 1, 2>, 8, Intrawave, BlockGemmPipelineVersion::v3>, + DeviceGemm_Wmma_CShuffleV3< Col, Row, Row, BF16, BF16, BF16, F32, BF16, PassThrough, PassThrough, PassThrough, GemmSpec, 32, 16, 16, 64, 8, 8, 16, 16, 1, 1, S<2, 16, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 1, 8, 1, S<2, 16, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 1, 8, 1, 1, 1, S<1, 16, 1, 2>, 8, Intrawave, BlockGemmPipelineVersion::v3> + // clang-format on + >; +} // namespace instance +} // namespace device +} // namespace tensor_operation +} // namespace ck diff --git 
a/library/src/tensor_operation_instance/gpu/gemm_universal/device_gemm_wmma_universal_bf16_bf16_bf16/device_gemm_wmma_universal_bf16_bf16_bf16_km_kn_mn_comp_default_instance.cpp b/library/src/tensor_operation_instance/gpu/gemm_universal/device_gemm_wmma_universal_bf16_bf16_bf16/device_gemm_wmma_universal_bf16_bf16_bf16_km_kn_mn_comp_default_instance.cpp new file mode 100644 index 0000000000..c9a730de68 --- /dev/null +++ b/library/src/tensor_operation_instance/gpu/gemm_universal/device_gemm_wmma_universal_bf16_bf16_bf16/device_gemm_wmma_universal_bf16_bf16_bf16_km_kn_mn_comp_default_instance.cpp @@ -0,0 +1,25 @@ +// SPDX-License-Identifier: MIT +// Copyright (c) 2025, Advanced Micro Devices, Inc. All rights reserved. + +#include "device_gemm_wmma_universal_bf16_bf16_bf16_km_kn_mn.hpp" +#include "ck/host_utility/device_prop.hpp" + +namespace ck { +namespace tensor_operation { +namespace device { +namespace instance { + +void add_device_gemm_wmma_universal_bf16_bf16_bf16_km_kn_mn_comp_default_instances( + std::vector>>& + instances) +{ + add_device_operation_instances( + instances, + device_gemm_wmma_universal_bf16_bf16_bf16_km_kn_mn_comp_instances{}); +} + +} // namespace instance +} // namespace device +} // namespace tensor_operation +} // namespace ck diff --git a/library/src/tensor_operation_instance/gpu/gemm_universal/device_gemm_wmma_universal_bf16_bf16_bf16/device_gemm_wmma_universal_bf16_bf16_bf16_km_nk_mn.hpp b/library/src/tensor_operation_instance/gpu/gemm_universal/device_gemm_wmma_universal_bf16_bf16_bf16/device_gemm_wmma_universal_bf16_bf16_bf16_km_nk_mn.hpp new file mode 100644 index 0000000000..6c3a641f9f --- /dev/null +++ b/library/src/tensor_operation_instance/gpu/gemm_universal/device_gemm_wmma_universal_bf16_bf16_bf16/device_gemm_wmma_universal_bf16_bf16_bf16_km_nk_mn.hpp @@ -0,0 +1,64 @@ +// SPDX-License-Identifier: MIT +// Copyright (c) 2025, Advanced Micro Devices, Inc. All rights reserved. 
+ +#include "ck/ck.hpp" +#include "ck/tensor_operation/gpu/device/tensor_layout.hpp" +#include "ck/tensor_operation/gpu/device/gemm_specialization.hpp" +#include "ck/tensor_operation/gpu/device/impl/device_gemm_wmma_cshuffle_v3.hpp" + +#include "ck/library/tensor_operation_instance/add_device_operation_instance.hpp" + +namespace ck { +namespace tensor_operation { +namespace device { +namespace instance { + +using BF16 = bhalf_t; +using F32 = float; + +using Row = tensor_layout::gemm::RowMajor; +using Col = tensor_layout::gemm::ColumnMajor; + +template +using S = Sequence; + +using PassThrough = element_wise::PassThrough; + +static constexpr auto GemmDefault = GemmSpecialization::Default; +static constexpr auto GemmKPadding = GemmSpecialization::KPadding; +static constexpr auto GemmMNPadding = GemmSpecialization::MNPadding; +static constexpr auto GemmMNKPadding = GemmSpecialization::MNKPadding; + +static constexpr auto Intrawave = BlockGemmPipelineScheduler::Intrawave; +static constexpr auto Interwave = BlockGemmPipelineScheduler::Interwave; + +template +using device_gemm_wmma_universal_bf16_bf16_bf16_km_nk_mn_comp_instances = + std::tuple< + // clang-format off + //#########################| ALayout| BLayout| CLayout|AData| BData| CData| AccData| CShuffle| A| B| C| GemmSpec| Block| MPer| NPer| KPer| AK1| BK1| MPer| NPer| MRepeat| NRepeat| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockLds| BBlockTransfer| BBlockTransfer| BBlockTransfer| BBlockTransfer| BBlockTransfer| BBlockTransfer| BBlockLds| CShuffle| CShuffle| CShuffleBlockTransfer| CShuffleBlockTransfer| BlkGemm| BlkGemm| + //#########################| | | | Type| Type| Type| Type| DataType| Elementwise| Elementwise| Elementwise| | Size| Block| Block| Block| | | Wmma| Wmma| | | ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| ExtraM| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| ExtraN| 
MRepeat| NRepeat| ClusterLengths| ScalarPerVector| PipeSched| PipelineVer| + //#########################| | | | | | | | | Operation| Operation| Operation| | | | | | | | | | | | Lengths_AK0_M_AK1| ArrangeOrder| | | PerVector| PerVector_AK1| | Lengths_BK0_N_BK1| ArrangeOrder| | | PerVector| PerVector_BK1| | PerShuffle| PerShuffle| _MBlock_MPerBlock| _NPerBlock| | | + //#########################| | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | _NBlock_NPerBlock| | | | + DeviceGemm_Wmma_CShuffleV3< Col, Col, Row, BF16, BF16, BF16, F32, BF16, PassThrough, PassThrough, PassThrough, GemmSpec, 256, 128, 128, 32, 8, 8, 16, 16, 4, 2, S<4, 64, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 1, 8, 1, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, 1, 1, S<1, 32, 1, 8>, 8, Intrawave, BlockGemmPipelineVersion::v3>, + DeviceGemm_Wmma_CShuffleV3< Col, Col, Row, BF16, BF16, BF16, F32, BF16, PassThrough, PassThrough, PassThrough, GemmSpec, 128, 128, 64, 64, 8, 8, 16, 16, 4, 2, S<4, 32, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 1, 8, 1, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, 1, 1, S<1, 32, 1, 4>, 8, Intrawave, BlockGemmPipelineVersion::v3>, + DeviceGemm_Wmma_CShuffleV3< Col, Col, Row, BF16, BF16, BF16, F32, BF16, PassThrough, PassThrough, PassThrough, GemmSpec, 64, 64, 32, 32, 8, 8, 16, 16, 4, 1, S<4, 16, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 1, 8, 1, S<4, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, 1, 1, S<1, 16, 1, 4>, 8, Intrawave, BlockGemmPipelineVersion::v3>, + DeviceGemm_Wmma_CShuffleV3< Col, Col, Row, BF16, BF16, BF16, F32, BF16, PassThrough, PassThrough, PassThrough, GemmSpec, 32, 16, 16, 32, 8, 8, 16, 16, 1, 1, S<2, 16, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 1, 8, 1, S<2, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, 1, 1, S<1, 16, 1, 2>, 8, Intrawave, BlockGemmPipelineVersion::v3>, + DeviceGemm_Wmma_CShuffleV3< Col, Col, Row, BF16, BF16, BF16, F32, BF16, PassThrough, PassThrough, PassThrough, GemmSpec, 256, 128, 256, 64, 8, 8, 16, 16, 4, 4, S<4, 64, 1>, S<0, 2, 1>, S<0, 2, 1>, 
1, 1, 8, 1, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, 1, 1, S<1, 32, 1, 8>, 8, Intrawave, BlockGemmPipelineVersion::v3>, + DeviceGemm_Wmma_CShuffleV3< Col, Col, Row, BF16, BF16, BF16, F32, BF16, PassThrough, PassThrough, PassThrough, GemmSpec, 256, 128, 128, 64, 8, 8, 16, 16, 4, 2, S<4, 64, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 1, 8, 1, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, 1, 1, S<1, 32, 1, 8>, 8, Intrawave, BlockGemmPipelineVersion::v3>, + DeviceGemm_Wmma_CShuffleV3< Col, Col, Row, BF16, BF16, BF16, F32, BF16, PassThrough, PassThrough, PassThrough, GemmSpec, 256, 128, 160, 64, 8, 8, 16, 16, 2, 5, S<4, 64, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 1, 8, 1, S<8, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, 1, 1, S<1, 64, 1, 4>, 8, Intrawave, BlockGemmPipelineVersion::v3>, + DeviceGemm_Wmma_CShuffleV3< Col, Col, Row, BF16, BF16, BF16, F32, BF16, PassThrough, PassThrough, PassThrough, GemmSpec, 128, 128, 128, 32, 8, 8, 16, 16, 4, 4, S<4, 32, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 1, 8, 1, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, 1, 1, S<1, 32, 1, 4>, 8, Intrawave, BlockGemmPipelineVersion::v3>, + DeviceGemm_Wmma_CShuffleV3< Col, Col, Row, BF16, BF16, BF16, F32, BF16, PassThrough, PassThrough, PassThrough, GemmSpec, 128, 256, 64, 64, 8, 8, 16, 16, 8, 2, S<4, 32, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 1, 8, 1, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, 1, 1, S<1, 32, 1, 4>, 8, Intrawave, BlockGemmPipelineVersion::v3>, + DeviceGemm_Wmma_CShuffleV3< Col, Col, Row, BF16, BF16, BF16, F32, BF16, PassThrough, PassThrough, PassThrough, GemmSpec, 128, 64, 256, 64, 8, 8, 16, 16, 2, 8, S<4, 32, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 1, 8, 1, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, 1, 1, S<1, 32, 1, 4>, 8, Intrawave, BlockGemmPipelineVersion::v3>, + DeviceGemm_Wmma_CShuffleV3< Col, Col, Row, BF16, BF16, BF16, F32, BF16, PassThrough, PassThrough, PassThrough, GemmSpec, 128, 64, 80, 64, 8, 8, 16, 16, 1, 5, S<4, 32, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 1, 8, 1, S<8, 16, 1>, S<1, 0, 2>, S<1, 
0, 2>, 2, 8, 8, 1, 1, 1, S<1, 64, 1, 2>, 8, Intrawave, BlockGemmPipelineVersion::v3>, + DeviceGemm_Wmma_CShuffleV3< Col, Col, Row, BF16, BF16, BF16, F32, BF16, PassThrough, PassThrough, PassThrough, GemmSpec, 64, 16, 64, 64, 8, 8, 16, 16, 1, 2, S<4, 16, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 1, 8, 1, S<4, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, 1, 1, S<1, 16, 1, 4>, 8, Intrawave, BlockGemmPipelineVersion::v3>, + DeviceGemm_Wmma_CShuffleV3< Col, Col, Row, BF16, BF16, BF16, F32, BF16, PassThrough, PassThrough, PassThrough, GemmSpec, 64, 64, 32, 64, 8, 8, 16, 16, 4, 1, S<4, 16, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 1, 8, 1, S<4, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, 1, 1, S<1, 16, 1, 4>, 8, Intrawave, BlockGemmPipelineVersion::v3>, + DeviceGemm_Wmma_CShuffleV3< Col, Col, Row, BF16, BF16, BF16, F32, BF16, PassThrough, PassThrough, PassThrough, GemmSpec, 64, 32, 64, 64, 8, 8, 16, 16, 2, 2, S<4, 16, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 1, 8, 1, S<4, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, 1, 1, S<1, 16, 1, 4>, 8, Intrawave, BlockGemmPipelineVersion::v3>, + DeviceGemm_Wmma_CShuffleV3< Col, Col, Row, BF16, BF16, BF16, F32, BF16, PassThrough, PassThrough, PassThrough, GemmSpec, 32, 16, 32, 64, 8, 8, 16, 16, 1, 2, S<2, 16, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 1, 8, 1, S<2, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, 1, 1, S<1, 16, 1, 2>, 8, Intrawave, BlockGemmPipelineVersion::v3>, + DeviceGemm_Wmma_CShuffleV3< Col, Col, Row, BF16, BF16, BF16, F32, BF16, PassThrough, PassThrough, PassThrough, GemmSpec, 32, 16, 16, 64, 8, 8, 16, 16, 1, 1, S<2, 16, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 1, 8, 1, S<2, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, 1, 1, S<1, 16, 1, 2>, 8, Intrawave, BlockGemmPipelineVersion::v3> + // clang-format on + >; +} // namespace instance +} // namespace device +} // namespace tensor_operation +} // namespace ck diff --git 
a/library/src/tensor_operation_instance/gpu/gemm_universal/device_gemm_wmma_universal_bf16_bf16_bf16/device_gemm_wmma_universal_bf16_bf16_bf16_km_nk_mn_comp_default_instance.cpp b/library/src/tensor_operation_instance/gpu/gemm_universal/device_gemm_wmma_universal_bf16_bf16_bf16/device_gemm_wmma_universal_bf16_bf16_bf16_km_nk_mn_comp_default_instance.cpp new file mode 100644 index 0000000000..cd88edec59 --- /dev/null +++ b/library/src/tensor_operation_instance/gpu/gemm_universal/device_gemm_wmma_universal_bf16_bf16_bf16/device_gemm_wmma_universal_bf16_bf16_bf16_km_nk_mn_comp_default_instance.cpp @@ -0,0 +1,25 @@ +// SPDX-License-Identifier: MIT +// Copyright (c) 2025, Advanced Micro Devices, Inc. All rights reserved. + +#include "device_gemm_wmma_universal_bf16_bf16_bf16_km_nk_mn.hpp" +#include "ck/host_utility/device_prop.hpp" + +namespace ck { +namespace tensor_operation { +namespace device { +namespace instance { + +void add_device_gemm_wmma_universal_bf16_bf16_bf16_km_nk_mn_comp_default_instances( + std::vector>>& + instances) +{ + add_device_operation_instances( + instances, + device_gemm_wmma_universal_bf16_bf16_bf16_km_nk_mn_comp_instances{}); +} + +} // namespace instance +} // namespace device +} // namespace tensor_operation +} // namespace ck diff --git a/library/src/tensor_operation_instance/gpu/gemm_universal/device_gemm_wmma_universal_bf16_bf16_bf16/device_gemm_wmma_universal_bf16_bf16_bf16_mk_kn_mn.hpp b/library/src/tensor_operation_instance/gpu/gemm_universal/device_gemm_wmma_universal_bf16_bf16_bf16/device_gemm_wmma_universal_bf16_bf16_bf16_mk_kn_mn.hpp new file mode 100644 index 0000000000..b700e78d3d --- /dev/null +++ b/library/src/tensor_operation_instance/gpu/gemm_universal/device_gemm_wmma_universal_bf16_bf16_bf16/device_gemm_wmma_universal_bf16_bf16_bf16_mk_kn_mn.hpp @@ -0,0 +1,67 @@ +// SPDX-License-Identifier: MIT +// Copyright (c) 2025, Advanced Micro Devices, Inc. All rights reserved. 
+ +#include "ck/ck.hpp" +#include "ck/tensor_operation/gpu/device/tensor_layout.hpp" +#include "ck/tensor_operation/gpu/device/gemm_specialization.hpp" +#include "ck/tensor_operation/gpu/device/impl/device_gemm_wmma_cshuffle_v3.hpp" + +#include "ck/library/tensor_operation_instance/add_device_operation_instance.hpp" + +namespace ck { +namespace tensor_operation { +namespace device { +namespace instance { + +using BF16 = bhalf_t; +using F32 = float; + +using Row = tensor_layout::gemm::RowMajor; +using Col = tensor_layout::gemm::ColumnMajor; + +template +using S = Sequence; + +using PassThrough = element_wise::PassThrough; + +static constexpr auto GemmDefault = GemmSpecialization::Default; +static constexpr auto GemmKPadding = GemmSpecialization::KPadding; +static constexpr auto GemmMNPadding = GemmSpecialization::MNPadding; +static constexpr auto GemmMNKPadding = GemmSpecialization::MNKPadding; + +static constexpr auto Intrawave = BlockGemmPipelineScheduler::Intrawave; +static constexpr auto Interwave = BlockGemmPipelineScheduler::Interwave; + +template +using device_gemm_wmma_universal_bf16_bf16_bf16_mk_kn_mn_comp_instances = + std::tuple< + // clang-format off + //#########################| ALayout| BLayout| CLayout|AData| BData| CData| AccData| CShuffle| A| B| C| GemmSpec| Block| MPer| NPer| KPer| AK1| BK1| MPer| NPer| MRepeat| NRepeat| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockLds| BBlockTransfer| BBlockTransfer| BBlockTransfer| BBlockTransfer| BBlockTransfer| BBlockTransfer| BBlockLds| CShuffle| CShuffle| CShuffleBlockTransfer| CShuffleBlockTransfer| BlkGemm| BlkGemm| + //#########################| | | | Type| Type| Type| Type| DataType| Elementwise| Elementwise| Elementwise| | Size| Block| Block| Block| | | Wmma| Wmma| | | ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| ExtraM| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| ExtraN| 
MRepeat| NRepeat| ClusterLengths| ScalarPerVector| PipeSched| PipelineVer| + //#########################| | | | | | | | | Operation| Operation| Operation| | | | | | | | | | | | Lengths_AK0_M_AK1| ArrangeOrder| | | PerVector| PerVector_AK1| | Lengths_BK0_N_BK1| ArrangeOrder| | | PerVector| PerVector_BK1| | PerShuffle| PerShuffle| _MBlock_MPerBlock| _NPerBlock| | | + //#########################| | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | _NBlock_NPerBlock| | | | + DeviceGemm_Wmma_CShuffleV3< Row, Row, Row, BF16, BF16, BF16, F32, BF16, PassThrough, PassThrough, PassThrough, GemmSpec, 256, 128, 128, 32, 8, 8, 16, 16, 4, 2, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, S<4, 64, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 1, 8, 1, 1, 1, S<1, 32, 1, 8>, 8, Intrawave, BlockGemmPipelineVersion::v3>, + DeviceGemm_Wmma_CShuffleV3< Row, Row, Row, BF16, BF16, BF16, F32, BF16, PassThrough, PassThrough, PassThrough, GemmSpec, 128, 128, 64, 64, 8, 8, 16, 16, 4, 2, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, S<4, 32, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 1, 8, 1, 1, 1, S<1, 32, 1, 4>, 8, Intrawave, BlockGemmPipelineVersion::v3>, + DeviceGemm_Wmma_CShuffleV3< Row, Row, Row, BF16, BF16, BF16, F32, BF16, PassThrough, PassThrough, PassThrough, GemmSpec, 64, 64, 32, 32, 8, 8, 16, 16, 4, 1, S<4, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, S<4, 16, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 1, 8, 1, 1, 1, S<1, 16, 1, 4>, 8, Intrawave, BlockGemmPipelineVersion::v3>, + DeviceGemm_Wmma_CShuffleV3< Row, Row, Row, BF16, BF16, BF16, F32, BF16, PassThrough, PassThrough, PassThrough, GemmSpec, 32, 16, 16, 32, 8, 8, 16, 16, 1, 1, S<2, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, S<2, 16, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 1, 8, 1, 1, 1, S<1, 16, 1, 2>, 8, Intrawave, BlockGemmPipelineVersion::v3>, + DeviceGemm_Wmma_CShuffleV3< Row, Row, Row, BF16, BF16, BF16, F32, BF16, PassThrough, PassThrough, PassThrough, GemmSpec, 256, 128, 256, 64, 8, 8, 16, 16, 4, 4, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 
2, 8, 8, 1, S<4, 64, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 1, 8, 1, 1, 1, S<1, 32, 1, 8>, 8, Intrawave, BlockGemmPipelineVersion::v3>, + DeviceGemm_Wmma_CShuffleV3< Row, Row, Row, BF16, BF16, BF16, F32, BF16, PassThrough, PassThrough, PassThrough, GemmSpec, 256, 128, 128, 64, 8, 8, 16, 16, 4, 2, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, S<4, 64, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 1, 8, 1, 1, 1, S<1, 32, 1, 8>, 8, Intrawave, BlockGemmPipelineVersion::v3>, + DeviceGemm_Wmma_CShuffleV3< Row, Row, Row, BF16, BF16, BF16, F32, BF16, PassThrough, PassThrough, PassThrough, GemmSpec, 256, 128, 160, 64, 8, 8, 16, 16, 2, 5, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, S<8, 32, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 1, 8, 1, 1, 1, S<1, 64, 1, 4>, 8, Intrawave, BlockGemmPipelineVersion::v3>, + DeviceGemm_Wmma_CShuffleV3< Row, Row, Row, BF16, BF16, BF16, F32, BF16, PassThrough, PassThrough, PassThrough, GemmSpec, 128, 128, 128, 32, 8, 8, 16, 16, 4, 4, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, S<4, 32, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 1, 8, 1, 1, 1, S<1, 32, 1, 4>, 8, Intrawave, BlockGemmPipelineVersion::v3>, + DeviceGemm_Wmma_CShuffleV3< Row, Row, Row, BF16, BF16, BF16, F32, BF16, PassThrough, PassThrough, PassThrough, GemmSpec, 128, 256, 64, 64, 8, 8, 16, 16, 8, 2, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, S<4, 32, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 1, 8, 1, 1, 1, S<1, 32, 1, 4>, 8, Intrawave, BlockGemmPipelineVersion::v3>, + DeviceGemm_Wmma_CShuffleV3< Row, Row, Row, BF16, BF16, BF16, F32, BF16, PassThrough, PassThrough, PassThrough, GemmSpec, 128, 64, 256, 64, 8, 8, 16, 16, 2, 8, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, S<4, 32, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 1, 8, 1, 1, 1, S<1, 32, 1, 4>, 8, Intrawave, BlockGemmPipelineVersion::v3>, + DeviceGemm_Wmma_CShuffleV3< Row, Row, Row, BF16, BF16, BF16, F32, BF16, PassThrough, PassThrough, PassThrough, GemmSpec, 128, 64, 80, 64, 8, 8, 16, 16, 1, 5, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, S<8, 16, 1>, S<0, 2, 1>, S<0, 
2, 1>, 1, 1, 8, 1, 1, 1, S<1, 64, 1, 2>, 8, Intrawave, BlockGemmPipelineVersion::v3>, + DeviceGemm_Wmma_CShuffleV3< Row, Row, Row, BF16, BF16, BF16, F32, BF16, PassThrough, PassThrough, PassThrough, GemmSpec, 64, 16, 64, 64, 8, 8, 16, 16, 1, 2, S<4, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, S<4, 16, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 1, 8, 1, 1, 1, S<1, 16, 1, 4>, 8, Intrawave, BlockGemmPipelineVersion::v3>, + DeviceGemm_Wmma_CShuffleV3< Row, Row, Row, BF16, BF16, BF16, F32, BF16, PassThrough, PassThrough, PassThrough, GemmSpec, 64, 64, 32, 64, 8, 8, 16, 16, 4, 1, S<4, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, S<4, 16, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 1, 8, 1, 1, 1, S<1, 16, 1, 4>, 8, Intrawave, BlockGemmPipelineVersion::v3>, + DeviceGemm_Wmma_CShuffleV3< Row, Row, Row, BF16, BF16, BF16, F32, BF16, PassThrough, PassThrough, PassThrough, GemmSpec, 64, 32, 64, 64, 8, 8, 16, 16, 2, 2, S<4, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, S<4, 16, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 1, 8, 1, 1, 1, S<1, 16, 1, 4>, 8, Intrawave, BlockGemmPipelineVersion::v3>, + DeviceGemm_Wmma_CShuffleV3< Row, Row, Row, BF16, BF16, BF16, F32, BF16, PassThrough, PassThrough, PassThrough, GemmSpec, 32, 16, 32, 64, 8, 8, 16, 16, 1, 2, S<2, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, S<2, 16, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 1, 8, 1, 1, 1, S<1, 16, 1, 2>, 8, Intrawave, BlockGemmPipelineVersion::v3>, + DeviceGemm_Wmma_CShuffleV3< Row, Row, Row, BF16, BF16, BF16, F32, BF16, PassThrough, PassThrough, PassThrough, GemmSpec, 32, 16, 16, 64, 8, 8, 16, 16, 1, 1, S<2, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, S<2, 16, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 1, 8, 1, 1, 1, S<1, 16, 1, 2>, 8, Intrawave, BlockGemmPipelineVersion::v3>, + // Configurations used during development, mainly for testing + DeviceGemm_Wmma_CShuffleV3< Row, Row, Row, BF16, BF16, BF16, F32, BF16, PassThrough, PassThrough, PassThrough, GemmSpec, 32, 16, 16, 64, 8, 8, 16, 16, 1, 1, S<2, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, S<2, 16, 1>, S<0, 2, 
1>, S<0, 2, 1>, 1, 1, 4, 0, 1, 1, S<1, 16, 1, 2>, 8, Intrawave, BlockGemmPipelineVersion::v3>, + DeviceGemm_Wmma_CShuffleV3< Row, Row, Row, BF16, BF16, BF16, F32, BF16, PassThrough, PassThrough, PassThrough, GemmSpec, 128, 64, 64, 32, 8, 8, 16, 16, 2, 2, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 0, S<4, 32, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 2, 4, 1, 1, 1, S<1, 32, 1, 2>, 8, Intrawave, BlockGemmPipelineVersion::v3> + // clang-format on + >; +} // namespace instance +} // namespace device +} // namespace tensor_operation +} // namespace ck diff --git a/library/src/tensor_operation_instance/gpu/gemm_universal/device_gemm_wmma_universal_bf16_bf16_bf16/device_gemm_wmma_universal_bf16_bf16_bf16_mk_kn_mn_comp_default_instance.cpp b/library/src/tensor_operation_instance/gpu/gemm_universal/device_gemm_wmma_universal_bf16_bf16_bf16/device_gemm_wmma_universal_bf16_bf16_bf16_mk_kn_mn_comp_default_instance.cpp new file mode 100644 index 0000000000..9951c02251 --- /dev/null +++ b/library/src/tensor_operation_instance/gpu/gemm_universal/device_gemm_wmma_universal_bf16_bf16_bf16/device_gemm_wmma_universal_bf16_bf16_bf16_mk_kn_mn_comp_default_instance.cpp @@ -0,0 +1,25 @@ +// SPDX-License-Identifier: MIT +// Copyright (c) 2025, Advanced Micro Devices, Inc. All rights reserved. 
+ +#include "device_gemm_wmma_universal_bf16_bf16_bf16_mk_kn_mn.hpp" +#include "ck/host_utility/device_prop.hpp" + +namespace ck { +namespace tensor_operation { +namespace device { +namespace instance { + +void add_device_gemm_wmma_universal_bf16_bf16_bf16_mk_kn_mn_comp_default_instances( + std::vector>>& + instances) +{ + add_device_operation_instances( + instances, + device_gemm_wmma_universal_bf16_bf16_bf16_mk_kn_mn_comp_instances{}); +} + +} // namespace instance +} // namespace device +} // namespace tensor_operation +} // namespace ck diff --git a/library/src/tensor_operation_instance/gpu/gemm_universal/device_gemm_wmma_universal_bf16_bf16_bf16/device_gemm_wmma_universal_bf16_bf16_bf16_mk_nk_mn.hpp b/library/src/tensor_operation_instance/gpu/gemm_universal/device_gemm_wmma_universal_bf16_bf16_bf16/device_gemm_wmma_universal_bf16_bf16_bf16_mk_nk_mn.hpp new file mode 100644 index 0000000000..7b4cd64d33 --- /dev/null +++ b/library/src/tensor_operation_instance/gpu/gemm_universal/device_gemm_wmma_universal_bf16_bf16_bf16/device_gemm_wmma_universal_bf16_bf16_bf16_mk_nk_mn.hpp @@ -0,0 +1,64 @@ +// SPDX-License-Identifier: MIT +// Copyright (c) 2025, Advanced Micro Devices, Inc. All rights reserved. 
+ +#include "ck/ck.hpp" +#include "ck/tensor_operation/gpu/device/tensor_layout.hpp" +#include "ck/tensor_operation/gpu/device/gemm_specialization.hpp" +#include "ck/tensor_operation/gpu/device/impl/device_gemm_wmma_cshuffle_v3.hpp" + +#include "ck/library/tensor_operation_instance/add_device_operation_instance.hpp" + +namespace ck { +namespace tensor_operation { +namespace device { +namespace instance { + +using BF16 = bhalf_t; +using F32 = float; + +using Row = tensor_layout::gemm::RowMajor; +using Col = tensor_layout::gemm::ColumnMajor; + +template +using S = Sequence; + +using PassThrough = element_wise::PassThrough; + +static constexpr auto GemmDefault = GemmSpecialization::Default; +static constexpr auto GemmKPadding = GemmSpecialization::KPadding; +static constexpr auto GemmMNPadding = GemmSpecialization::MNPadding; +static constexpr auto GemmMNKPadding = GemmSpecialization::MNKPadding; + +static constexpr auto Intrawave = BlockGemmPipelineScheduler::Intrawave; +static constexpr auto Interwave = BlockGemmPipelineScheduler::Interwave; + +template +using device_gemm_wmma_universal_bf16_bf16_bf16_mk_nk_mn_comp_instances = + std::tuple< + // clang-format off + //#########################| ALayout| BLayout| CLayout|AData| BData| CData| AccData| CShuffle| A| B| C| GemmSpec| Block| MPer| NPer| KPer| AK1| BK1| MPer| NPer| MRepeat| NRepeat| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockLds| BBlockTransfer| BBlockTransfer| BBlockTransfer| BBlockTransfer| BBlockTransfer| BBlockTransfer| BBlockLds| CShuffle| CShuffle| CShuffleBlockTransfer| CShuffleBlockTransfer| BlkGemm| BlkGemm| + //#########################| | | | Type| Type| Type| Type| DataType| Elementwise| Elementwise| Elementwise| | Size| Block| Block| Block| | | Wmma| Wmma| | | ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| ExtraM| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| ExtraN| 
MRepeat| NRepeat| ClusterLengths| ScalarPerVector| PipeSched| PipelineVer| + //#########################| | | | | | | | | Operation| Operation| Operation| | | | | | | | | | | | Lengths_AK0_M_AK1| ArrangeOrder| | | PerVector| PerVector_AK1| | Lengths_BK0_N_BK1| ArrangeOrder| | | PerVector| PerVector_BK1| | PerShuffle| PerShuffle| _MBlock_MPerBlock| _NPerBlock| | | + //#########################| | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | _NBlock_NPerBlock| | | | + DeviceGemm_Wmma_CShuffleV3< Row, Col, Row, BF16, BF16, BF16, F32, BF16, PassThrough, PassThrough, PassThrough, GemmSpec, 256, 128, 128, 32, 8, 8, 16, 16, 4, 2, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, 1, 1, S<1, 32, 1, 8>, 8, Intrawave, BlockGemmPipelineVersion::v3>, + DeviceGemm_Wmma_CShuffleV3< Row, Col, Row, BF16, BF16, BF16, F32, BF16, PassThrough, PassThrough, PassThrough, GemmSpec, 128, 128, 64, 64, 8, 8, 16, 16, 4, 2, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, 1, 1, S<1, 32, 1, 4>, 8, Intrawave, BlockGemmPipelineVersion::v3>, + DeviceGemm_Wmma_CShuffleV3< Row, Col, Row, BF16, BF16, BF16, F32, BF16, PassThrough, PassThrough, PassThrough, GemmSpec, 64, 64, 32, 32, 8, 8, 16, 16, 4, 1, S<4, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, S<4, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, 1, 1, S<1, 16, 1, 4>, 8, Intrawave, BlockGemmPipelineVersion::v3>, + DeviceGemm_Wmma_CShuffleV3< Row, Col, Row, BF16, BF16, BF16, F32, BF16, PassThrough, PassThrough, PassThrough, GemmSpec, 32, 16, 16, 32, 8, 8, 16, 16, 1, 1, S<2, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, S<2, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, 1, 1, S<1, 16, 1, 2>, 8, Intrawave, BlockGemmPipelineVersion::v3>, + DeviceGemm_Wmma_CShuffleV3< Row, Col, Row, BF16, BF16, BF16, F32, BF16, PassThrough, PassThrough, PassThrough, GemmSpec, 256, 128, 256, 64, 8, 8, 16, 16, 4, 4, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 
2, 8, 8, 1, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, 1, 1, S<1, 32, 1, 8>, 8, Intrawave, BlockGemmPipelineVersion::v3>, + DeviceGemm_Wmma_CShuffleV3< Row, Col, Row, BF16, BF16, BF16, F32, BF16, PassThrough, PassThrough, PassThrough, GemmSpec, 256, 128, 128, 64, 8, 8, 16, 16, 4, 2, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, 1, 1, S<1, 32, 1, 8>, 8, Intrawave, BlockGemmPipelineVersion::v3>, + DeviceGemm_Wmma_CShuffleV3< Row, Col, Row, BF16, BF16, BF16, F32, BF16, PassThrough, PassThrough, PassThrough, GemmSpec, 256, 128, 160, 64, 8, 8, 16, 16, 2, 5, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, S<8, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, 1, 1, S<1, 64, 1, 4>, 8, Intrawave, BlockGemmPipelineVersion::v3>, + DeviceGemm_Wmma_CShuffleV3< Row, Col, Row, BF16, BF16, BF16, F32, BF16, PassThrough, PassThrough, PassThrough, GemmSpec, 128, 128, 128, 32, 8, 8, 16, 16, 4, 4, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, 1, 1, S<1, 32, 1, 4>, 8, Intrawave, BlockGemmPipelineVersion::v3>, + DeviceGemm_Wmma_CShuffleV3< Row, Col, Row, BF16, BF16, BF16, F32, BF16, PassThrough, PassThrough, PassThrough, GemmSpec, 128, 256, 64, 64, 8, 8, 16, 16, 8, 2, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, 1, 1, S<1, 32, 1, 4>, 8, Intrawave, BlockGemmPipelineVersion::v3>, + DeviceGemm_Wmma_CShuffleV3< Row, Col, Row, BF16, BF16, BF16, F32, BF16, PassThrough, PassThrough, PassThrough, GemmSpec, 128, 64, 256, 64, 8, 8, 16, 16, 2, 8, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, 1, 1, S<1, 32, 1, 4>, 8, Intrawave, BlockGemmPipelineVersion::v3>, + DeviceGemm_Wmma_CShuffleV3< Row, Col, Row, BF16, BF16, BF16, F32, BF16, PassThrough, PassThrough, PassThrough, GemmSpec, 128, 64, 80, 64, 8, 8, 16, 16, 1, 5, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, S<8, 16, 1>, S<1, 0, 2>, S<1, 
0, 2>, 2, 8, 8, 1, 1, 1, S<1, 64, 1, 2>, 8, Intrawave, BlockGemmPipelineVersion::v3>, + DeviceGemm_Wmma_CShuffleV3< Row, Col, Row, BF16, BF16, BF16, F32, BF16, PassThrough, PassThrough, PassThrough, GemmSpec, 64, 16, 64, 64, 8, 8, 16, 16, 1, 2, S<4, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, S<4, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, 1, 1, S<1, 16, 1, 4>, 8, Intrawave, BlockGemmPipelineVersion::v3>, + DeviceGemm_Wmma_CShuffleV3< Row, Col, Row, BF16, BF16, BF16, F32, BF16, PassThrough, PassThrough, PassThrough, GemmSpec, 64, 64, 32, 64, 8, 8, 16, 16, 4, 1, S<4, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, S<4, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, 1, 1, S<1, 16, 1, 4>, 8, Intrawave, BlockGemmPipelineVersion::v3>, + DeviceGemm_Wmma_CShuffleV3< Row, Col, Row, BF16, BF16, BF16, F32, BF16, PassThrough, PassThrough, PassThrough, GemmSpec, 64, 32, 64, 64, 8, 8, 16, 16, 2, 2, S<4, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, S<4, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, 1, 1, S<1, 16, 1, 4>, 8, Intrawave, BlockGemmPipelineVersion::v3>, + DeviceGemm_Wmma_CShuffleV3< Row, Col, Row, BF16, BF16, BF16, F32, BF16, PassThrough, PassThrough, PassThrough, GemmSpec, 32, 16, 32, 64, 8, 8, 16, 16, 1, 2, S<2, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, S<2, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, 1, 1, S<1, 16, 1, 2>, 8, Intrawave, BlockGemmPipelineVersion::v3>, + DeviceGemm_Wmma_CShuffleV3< Row, Col, Row, BF16, BF16, BF16, F32, BF16, PassThrough, PassThrough, PassThrough, GemmSpec, 32, 16, 16, 64, 8, 8, 16, 16, 1, 1, S<2, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, S<2, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, 1, 1, S<1, 16, 1, 2>, 8, Intrawave, BlockGemmPipelineVersion::v3> + // clang-format on + >; +} // namespace instance +} // namespace device +} // namespace tensor_operation +} // namespace ck diff --git 
a/library/src/tensor_operation_instance/gpu/gemm_universal/device_gemm_wmma_universal_bf16_bf16_bf16/device_gemm_wmma_universal_bf16_bf16_bf16_mk_nk_mn_comp_default_instance.cpp b/library/src/tensor_operation_instance/gpu/gemm_universal/device_gemm_wmma_universal_bf16_bf16_bf16/device_gemm_wmma_universal_bf16_bf16_bf16_mk_nk_mn_comp_default_instance.cpp new file mode 100644 index 0000000000..3a607c4178 --- /dev/null +++ b/library/src/tensor_operation_instance/gpu/gemm_universal/device_gemm_wmma_universal_bf16_bf16_bf16/device_gemm_wmma_universal_bf16_bf16_bf16_mk_nk_mn_comp_default_instance.cpp @@ -0,0 +1,25 @@ +// SPDX-License-Identifier: MIT +// Copyright (c) 2025, Advanced Micro Devices, Inc. All rights reserved. + +#include "device_gemm_wmma_universal_bf16_bf16_bf16_mk_nk_mn.hpp" +#include "ck/host_utility/device_prop.hpp" + +namespace ck { +namespace tensor_operation { +namespace device { +namespace instance { + +void add_device_gemm_wmma_universal_bf16_bf16_bf16_mk_nk_mn_comp_default_instances( + std::vector>>& + instances) +{ + add_device_operation_instances( + instances, + device_gemm_wmma_universal_bf16_bf16_bf16_mk_nk_mn_comp_instances{}); +} + +} // namespace instance +} // namespace device +} // namespace tensor_operation +} // namespace ck diff --git a/library/src/tensor_operation_instance/gpu/gemm_universal/device_gemm_wmma_universal_f16_f16_f16/device_gemm_wmma_universal_f16_f16_f16_km_kn_mn.hpp b/library/src/tensor_operation_instance/gpu/gemm_universal/device_gemm_wmma_universal_f16_f16_f16/device_gemm_wmma_universal_f16_f16_f16_km_kn_mn.hpp new file mode 100644 index 0000000000..3751dc5a11 --- /dev/null +++ b/library/src/tensor_operation_instance/gpu/gemm_universal/device_gemm_wmma_universal_f16_f16_f16/device_gemm_wmma_universal_f16_f16_f16_km_kn_mn.hpp @@ -0,0 +1,64 @@ +// SPDX-License-Identifier: MIT +// Copyright (c) 2025, Advanced Micro Devices, Inc. All rights reserved. 
+ +#include "ck/ck.hpp" +#include "ck/tensor_operation/gpu/device/tensor_layout.hpp" +#include "ck/tensor_operation/gpu/device/gemm_specialization.hpp" +#include "ck/tensor_operation/gpu/device/impl/device_gemm_wmma_cshuffle_v3.hpp" + +#include "ck/library/tensor_operation_instance/add_device_operation_instance.hpp" + +namespace ck { +namespace tensor_operation { +namespace device { +namespace instance { + +using F16 = half_t; +using F32 = float; + +using Row = tensor_layout::gemm::RowMajor; +using Col = tensor_layout::gemm::ColumnMajor; + +template +using S = Sequence; + +using PassThrough = element_wise::PassThrough; + +static constexpr auto GemmDefault = GemmSpecialization::Default; +static constexpr auto GemmKPadding = GemmSpecialization::KPadding; +static constexpr auto GemmMNPadding = GemmSpecialization::MNPadding; +static constexpr auto GemmMNKPadding = GemmSpecialization::MNKPadding; + +static constexpr auto Intrawave = BlockGemmPipelineScheduler::Intrawave; +static constexpr auto Interwave = BlockGemmPipelineScheduler::Interwave; + +template +using device_gemm_wmma_universal_f16_f16_f16_km_kn_mn_comp_instances = + std::tuple< + // clang-format off + //#########################| ALayout| BLayout| CLayout|AData| BData| CData| AccData| CShuffle| A| B| C| GemmSpec| Block| MPer| NPer| KPer| AK1| BK1| MPer| NPer| MRepeat| NRepeat| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockLds| BBlockTransfer| BBlockTransfer| BBlockTransfer| BBlockTransfer| BBlockTransfer| BBlockTransfer| BBlockLds| CShuffle| CShuffle| CShuffleBlockTransfer| CShuffleBlockTransfer| BlkGemm| BlkGemm| + //#########################| | | | Type| Type| Type| Type| DataType| Elementwise| Elementwise| Elementwise| | Size| Block| Block| Block| | | Wmma| Wmma| | | ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| ExtraM| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| ExtraN| 
MRepeat| NRepeat| ClusterLengths| ScalarPerVector| PipeSched| PipelineVer| + //#########################| | | | | | | | | Operation| Operation| Operation| | | | | | | | | | | | Lengths_AK0_M_AK1| ArrangeOrder| | | PerVector| PerVector_AK1| | Lengths_BK0_N_BK1| ArrangeOrder| | | PerVector| PerVector_BK1| | PerShuffle| PerShuffle| _MBlock_MPerBlock| _NPerBlock| | | + //#########################| | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | _NBlock_NPerBlock| | | | + DeviceGemm_Wmma_CShuffleV3< Col, Row, Row, F16, F16, F16, F32, F16, PassThrough, PassThrough, PassThrough, GemmSpec, 256, 128, 128, 32, 8, 8, 16, 16, 4, 2, S<4, 64, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 1, 8, 1, S<4, 64, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 1, 8, 1, 1, 1, S<1, 32, 1, 8>, 8, Intrawave, BlockGemmPipelineVersion::v3>, + DeviceGemm_Wmma_CShuffleV3< Col, Row, Row, F16, F16, F16, F32, F16, PassThrough, PassThrough, PassThrough, GemmSpec, 128, 128, 64, 64, 8, 8, 16, 16, 4, 2, S<4, 32, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 1, 8, 1, S<4, 32, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 1, 8, 1, 1, 1, S<1, 32, 1, 4>, 8, Intrawave, BlockGemmPipelineVersion::v3>, + DeviceGemm_Wmma_CShuffleV3< Col, Row, Row, F16, F16, F16, F32, F16, PassThrough, PassThrough, PassThrough, GemmSpec, 64, 64, 32, 32, 8, 8, 16, 16, 4, 1, S<4, 16, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 1, 8, 1, S<4, 16, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 1, 8, 1, 1, 1, S<1, 16, 1, 4>, 8, Intrawave, BlockGemmPipelineVersion::v3>, + DeviceGemm_Wmma_CShuffleV3< Col, Row, Row, F16, F16, F16, F32, F16, PassThrough, PassThrough, PassThrough, GemmSpec, 32, 16, 16, 32, 8, 8, 16, 16, 1, 1, S<2, 16, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 1, 8, 1, S<2, 16, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 1, 8, 1, 1, 1, S<1, 16, 1, 2>, 8, Intrawave, BlockGemmPipelineVersion::v3>, + DeviceGemm_Wmma_CShuffleV3< Col, Row, Row, F16, F16, F16, F32, F16, PassThrough, PassThrough, PassThrough, GemmSpec, 256, 128, 256, 64, 8, 8, 16, 16, 4, 4, S<4, 64, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 1, 8, 1, S<4, 64, 
1>, S<0, 2, 1>, S<0, 2, 1>, 1, 1, 8, 1, 1, 1, S<1, 32, 1, 8>, 8, Intrawave, BlockGemmPipelineVersion::v3>, + DeviceGemm_Wmma_CShuffleV3< Col, Row, Row, F16, F16, F16, F32, F16, PassThrough, PassThrough, PassThrough, GemmSpec, 256, 128, 128, 64, 8, 8, 16, 16, 4, 2, S<4, 64, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 1, 8, 1, S<4, 64, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 1, 8, 1, 1, 1, S<1, 32, 1, 8>, 8, Intrawave, BlockGemmPipelineVersion::v3>, + DeviceGemm_Wmma_CShuffleV3< Col, Row, Row, F16, F16, F16, F32, F16, PassThrough, PassThrough, PassThrough, GemmSpec, 256, 128, 160, 64, 8, 8, 16, 16, 2, 5, S<4, 64, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 1, 8, 1, S<8, 32, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 1, 8, 1, 1, 1, S<1, 64, 1, 4>, 8, Intrawave, BlockGemmPipelineVersion::v3>, + DeviceGemm_Wmma_CShuffleV3< Col, Row, Row, F16, F16, F16, F32, F16, PassThrough, PassThrough, PassThrough, GemmSpec, 128, 128, 128, 32, 8, 8, 16, 16, 4, 4, S<4, 32, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 1, 8, 1, S<4, 32, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 1, 8, 1, 1, 1, S<1, 32, 1, 4>, 8, Intrawave, BlockGemmPipelineVersion::v3>, + DeviceGemm_Wmma_CShuffleV3< Col, Row, Row, F16, F16, F16, F32, F16, PassThrough, PassThrough, PassThrough, GemmSpec, 128, 256, 64, 64, 8, 8, 16, 16, 8, 2, S<4, 32, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 1, 8, 1, S<4, 32, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 1, 8, 1, 1, 1, S<1, 32, 1, 4>, 8, Intrawave, BlockGemmPipelineVersion::v3>, + DeviceGemm_Wmma_CShuffleV3< Col, Row, Row, F16, F16, F16, F32, F16, PassThrough, PassThrough, PassThrough, GemmSpec, 128, 64, 256, 64, 8, 8, 16, 16, 2, 8, S<4, 32, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 1, 8, 1, S<4, 32, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 1, 8, 1, 1, 1, S<1, 32, 1, 4>, 8, Intrawave, BlockGemmPipelineVersion::v3>, + DeviceGemm_Wmma_CShuffleV3< Col, Row, Row, F16, F16, F16, F32, F16, PassThrough, PassThrough, PassThrough, GemmSpec, 128, 64, 80, 64, 8, 8, 16, 16, 1, 5, S<4, 32, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 1, 8, 1, S<8, 16, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 1, 8, 1, 1, 1, S<1, 64, 1, 2>, 8, 
Intrawave, BlockGemmPipelineVersion::v3>, + DeviceGemm_Wmma_CShuffleV3< Col, Row, Row, F16, F16, F16, F32, F16, PassThrough, PassThrough, PassThrough, GemmSpec, 64, 16, 64, 64, 8, 8, 16, 16, 1, 2, S<4, 16, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 1, 8, 1, S<4, 16, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 1, 8, 1, 1, 1, S<1, 16, 1, 4>, 8, Intrawave, BlockGemmPipelineVersion::v3>, + DeviceGemm_Wmma_CShuffleV3< Col, Row, Row, F16, F16, F16, F32, F16, PassThrough, PassThrough, PassThrough, GemmSpec, 64, 64, 32, 64, 8, 8, 16, 16, 4, 1, S<4, 16, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 1, 8, 1, S<4, 16, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 1, 8, 1, 1, 1, S<1, 16, 1, 4>, 8, Intrawave, BlockGemmPipelineVersion::v3>, + DeviceGemm_Wmma_CShuffleV3< Col, Row, Row, F16, F16, F16, F32, F16, PassThrough, PassThrough, PassThrough, GemmSpec, 64, 32, 64, 64, 8, 8, 16, 16, 2, 2, S<4, 16, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 1, 8, 1, S<4, 16, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 1, 8, 1, 1, 1, S<1, 16, 1, 4>, 8, Intrawave, BlockGemmPipelineVersion::v3>, + DeviceGemm_Wmma_CShuffleV3< Col, Row, Row, F16, F16, F16, F32, F16, PassThrough, PassThrough, PassThrough, GemmSpec, 32, 16, 32, 64, 8, 8, 16, 16, 1, 2, S<2, 16, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 1, 8, 1, S<2, 16, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 1, 8, 1, 1, 1, S<1, 16, 1, 2>, 8, Intrawave, BlockGemmPipelineVersion::v3>, + DeviceGemm_Wmma_CShuffleV3< Col, Row, Row, F16, F16, F16, F32, F16, PassThrough, PassThrough, PassThrough, GemmSpec, 32, 16, 16, 64, 8, 8, 16, 16, 1, 1, S<2, 16, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 1, 8, 1, S<2, 16, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 1, 8, 1, 1, 1, S<1, 16, 1, 2>, 8, Intrawave, BlockGemmPipelineVersion::v3> + // clang-format on + >; +} // namespace instance +} // namespace device +} // namespace tensor_operation +} // namespace ck diff --git a/library/src/tensor_operation_instance/gpu/gemm_universal/device_gemm_wmma_universal_f16_f16_f16/device_gemm_wmma_universal_f16_f16_f16_km_kn_mn_comp_default_instance.cpp 
b/library/src/tensor_operation_instance/gpu/gemm_universal/device_gemm_wmma_universal_f16_f16_f16/device_gemm_wmma_universal_f16_f16_f16_km_kn_mn_comp_default_instance.cpp new file mode 100644 index 0000000000..3971802415 --- /dev/null +++ b/library/src/tensor_operation_instance/gpu/gemm_universal/device_gemm_wmma_universal_f16_f16_f16/device_gemm_wmma_universal_f16_f16_f16_km_kn_mn_comp_default_instance.cpp @@ -0,0 +1,24 @@ +// SPDX-License-Identifier: MIT +// Copyright (c) 2025, Advanced Micro Devices, Inc. All rights reserved. + +#include "device_gemm_wmma_universal_f16_f16_f16_km_kn_mn.hpp" +#include "ck/host_utility/device_prop.hpp" + +namespace ck { +namespace tensor_operation { +namespace device { +namespace instance { + +void add_device_gemm_wmma_universal_f16_f16_f16_km_kn_mn_comp_default_instances( + std::vector>>& + instances) +{ + add_device_operation_instances( + instances, device_gemm_wmma_universal_f16_f16_f16_km_kn_mn_comp_instances{}); +} + +} // namespace instance +} // namespace device +} // namespace tensor_operation +} // namespace ck diff --git a/library/src/tensor_operation_instance/gpu/gemm_universal/device_gemm_wmma_universal_f16_f16_f16/device_gemm_wmma_universal_f16_f16_f16_km_nk_mn.hpp b/library/src/tensor_operation_instance/gpu/gemm_universal/device_gemm_wmma_universal_f16_f16_f16/device_gemm_wmma_universal_f16_f16_f16_km_nk_mn.hpp new file mode 100644 index 0000000000..222b49eb7d --- /dev/null +++ b/library/src/tensor_operation_instance/gpu/gemm_universal/device_gemm_wmma_universal_f16_f16_f16/device_gemm_wmma_universal_f16_f16_f16_km_nk_mn.hpp @@ -0,0 +1,64 @@ +// SPDX-License-Identifier: MIT +// Copyright (c) 2025, Advanced Micro Devices, Inc. All rights reserved. 
+ +#include "ck/ck.hpp" +#include "ck/tensor_operation/gpu/device/tensor_layout.hpp" +#include "ck/tensor_operation/gpu/device/gemm_specialization.hpp" +#include "ck/tensor_operation/gpu/device/impl/device_gemm_wmma_cshuffle_v3.hpp" + +#include "ck/library/tensor_operation_instance/add_device_operation_instance.hpp" + +namespace ck { +namespace tensor_operation { +namespace device { +namespace instance { + +using F16 = half_t; +using F32 = float; + +using Row = tensor_layout::gemm::RowMajor; +using Col = tensor_layout::gemm::ColumnMajor; + +template +using S = Sequence; + +using PassThrough = element_wise::PassThrough; + +static constexpr auto GemmDefault = GemmSpecialization::Default; +static constexpr auto GemmKPadding = GemmSpecialization::KPadding; +static constexpr auto GemmMNPadding = GemmSpecialization::MNPadding; +static constexpr auto GemmMNKPadding = GemmSpecialization::MNKPadding; + +static constexpr auto Intrawave = BlockGemmPipelineScheduler::Intrawave; +static constexpr auto Interwave = BlockGemmPipelineScheduler::Interwave; + +template +using device_gemm_wmma_universal_f16_f16_f16_km_nk_mn_comp_instances = + std::tuple< + // clang-format off + //#########################| ALayout| BLayout| CLayout|AData| BData| CData| AccData| CShuffle| A| B| C| GemmSpec| Block| MPer| NPer| KPer| AK1| BK1| MPer| NPer| MRepeat| NRepeat| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockLds| BBlockTransfer| BBlockTransfer| BBlockTransfer| BBlockTransfer| BBlockTransfer| BBlockTransfer| BBlockLds| CShuffle| CShuffle| CShuffleBlockTransfer| CShuffleBlockTransfer| BlkGemm| BlkGemm| + //#########################| | | | Type| Type| Type| Type| DataType| Elementwise| Elementwise| Elementwise| | Size| Block| Block| Block| | | Wmma| Wmma| | | ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| ExtraM| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| ExtraN| 
MRepeat| NRepeat| ClusterLengths| ScalarPerVector| PipeSched| PipelineVer| + //#########################| | | | | | | | | Operation| Operation| Operation| | | | | | | | | | | | Lengths_AK0_M_AK1| ArrangeOrder| | | PerVector| PerVector_AK1| | Lengths_BK0_N_BK1| ArrangeOrder| | | PerVector| PerVector_BK1| | PerShuffle| PerShuffle| _MBlock_MPerBlock| _NPerBlock| | | + //#########################| | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | _NBlock_NPerBlock| | | | + DeviceGemm_Wmma_CShuffleV3< Col, Col, Row, F16, F16, F16, F32, F16, PassThrough, PassThrough, PassThrough, GemmSpec, 256, 128, 128, 32, 8, 8, 16, 16, 4, 2, S<4, 64, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 1, 8, 1, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, 1, 1, S<1, 32, 1, 8>, 8, Intrawave, BlockGemmPipelineVersion::v3>, + DeviceGemm_Wmma_CShuffleV3< Col, Col, Row, F16, F16, F16, F32, F16, PassThrough, PassThrough, PassThrough, GemmSpec, 128, 128, 64, 64, 8, 8, 16, 16, 4, 2, S<4, 32, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 1, 8, 1, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, 1, 1, S<1, 32, 1, 4>, 8, Intrawave, BlockGemmPipelineVersion::v3>, + DeviceGemm_Wmma_CShuffleV3< Col, Col, Row, F16, F16, F16, F32, F16, PassThrough, PassThrough, PassThrough, GemmSpec, 64, 64, 32, 32, 8, 8, 16, 16, 4, 1, S<4, 16, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 1, 8, 1, S<4, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, 1, 1, S<1, 16, 1, 4>, 8, Intrawave, BlockGemmPipelineVersion::v3>, + DeviceGemm_Wmma_CShuffleV3< Col, Col, Row, F16, F16, F16, F32, F16, PassThrough, PassThrough, PassThrough, GemmSpec, 32, 16, 16, 32, 8, 8, 16, 16, 1, 1, S<2, 16, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 1, 8, 1, S<2, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, 1, 1, S<1, 16, 1, 2>, 8, Intrawave, BlockGemmPipelineVersion::v3>, + DeviceGemm_Wmma_CShuffleV3< Col, Col, Row, F16, F16, F16, F32, F16, PassThrough, PassThrough, PassThrough, GemmSpec, 256, 128, 256, 64, 8, 8, 16, 16, 4, 4, S<4, 64, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 1, 8, 1, S<4, 64, 
1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, 1, 1, S<1, 32, 1, 8>, 8, Intrawave, BlockGemmPipelineVersion::v3>, + DeviceGemm_Wmma_CShuffleV3< Col, Col, Row, F16, F16, F16, F32, F16, PassThrough, PassThrough, PassThrough, GemmSpec, 256, 128, 128, 64, 8, 8, 16, 16, 4, 2, S<4, 64, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 1, 8, 1, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, 1, 1, S<1, 32, 1, 8>, 8, Intrawave, BlockGemmPipelineVersion::v3>, + DeviceGemm_Wmma_CShuffleV3< Col, Col, Row, F16, F16, F16, F32, F16, PassThrough, PassThrough, PassThrough, GemmSpec, 256, 128, 160, 64, 8, 8, 16, 16, 2, 5, S<4, 64, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 1, 8, 1, S<8, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, 1, 1, S<1, 64, 1, 4>, 8, Intrawave, BlockGemmPipelineVersion::v3>, + DeviceGemm_Wmma_CShuffleV3< Col, Col, Row, F16, F16, F16, F32, F16, PassThrough, PassThrough, PassThrough, GemmSpec, 128, 128, 128, 32, 8, 8, 16, 16, 4, 4, S<4, 32, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 1, 8, 1, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, 1, 1, S<1, 32, 1, 4>, 8, Intrawave, BlockGemmPipelineVersion::v3>, + DeviceGemm_Wmma_CShuffleV3< Col, Col, Row, F16, F16, F16, F32, F16, PassThrough, PassThrough, PassThrough, GemmSpec, 128, 256, 64, 64, 8, 8, 16, 16, 8, 2, S<4, 32, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 1, 8, 1, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, 1, 1, S<1, 32, 1, 4>, 8, Intrawave, BlockGemmPipelineVersion::v3>, + DeviceGemm_Wmma_CShuffleV3< Col, Col, Row, F16, F16, F16, F32, F16, PassThrough, PassThrough, PassThrough, GemmSpec, 128, 64, 256, 64, 8, 8, 16, 16, 2, 8, S<4, 32, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 1, 8, 1, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, 1, 1, S<1, 32, 1, 4>, 8, Intrawave, BlockGemmPipelineVersion::v3>, + DeviceGemm_Wmma_CShuffleV3< Col, Col, Row, F16, F16, F16, F32, F16, PassThrough, PassThrough, PassThrough, GemmSpec, 128, 64, 80, 64, 8, 8, 16, 16, 1, 5, S<4, 32, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 1, 8, 1, S<8, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, 1, 1, S<1, 64, 1, 2>, 8, 
Intrawave, BlockGemmPipelineVersion::v3>, + DeviceGemm_Wmma_CShuffleV3< Col, Col, Row, F16, F16, F16, F32, F16, PassThrough, PassThrough, PassThrough, GemmSpec, 64, 16, 64, 64, 8, 8, 16, 16, 1, 2, S<4, 16, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 1, 8, 1, S<4, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, 1, 1, S<1, 16, 1, 4>, 8, Intrawave, BlockGemmPipelineVersion::v3>, + DeviceGemm_Wmma_CShuffleV3< Col, Col, Row, F16, F16, F16, F32, F16, PassThrough, PassThrough, PassThrough, GemmSpec, 64, 64, 32, 64, 8, 8, 16, 16, 4, 1, S<4, 16, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 1, 8, 1, S<4, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, 1, 1, S<1, 16, 1, 4>, 8, Intrawave, BlockGemmPipelineVersion::v3>, + DeviceGemm_Wmma_CShuffleV3< Col, Col, Row, F16, F16, F16, F32, F16, PassThrough, PassThrough, PassThrough, GemmSpec, 64, 32, 64, 64, 8, 8, 16, 16, 2, 2, S<4, 16, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 1, 8, 1, S<4, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, 1, 1, S<1, 16, 1, 4>, 8, Intrawave, BlockGemmPipelineVersion::v3>, + DeviceGemm_Wmma_CShuffleV3< Col, Col, Row, F16, F16, F16, F32, F16, PassThrough, PassThrough, PassThrough, GemmSpec, 32, 16, 32, 64, 8, 8, 16, 16, 1, 2, S<2, 16, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 1, 8, 1, S<2, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, 1, 1, S<1, 16, 1, 2>, 8, Intrawave, BlockGemmPipelineVersion::v3>, + DeviceGemm_Wmma_CShuffleV3< Col, Col, Row, F16, F16, F16, F32, F16, PassThrough, PassThrough, PassThrough, GemmSpec, 32, 16, 16, 64, 8, 8, 16, 16, 1, 1, S<2, 16, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 1, 8, 1, S<2, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, 1, 1, S<1, 16, 1, 2>, 8, Intrawave, BlockGemmPipelineVersion::v3> + // clang-format on + >; +} // namespace instance +} // namespace device +} // namespace tensor_operation +} // namespace ck diff --git a/library/src/tensor_operation_instance/gpu/gemm_universal/device_gemm_wmma_universal_f16_f16_f16/device_gemm_wmma_universal_f16_f16_f16_km_nk_mn_comp_default_instance.cpp 
b/library/src/tensor_operation_instance/gpu/gemm_universal/device_gemm_wmma_universal_f16_f16_f16/device_gemm_wmma_universal_f16_f16_f16_km_nk_mn_comp_default_instance.cpp new file mode 100644 index 0000000000..36901b4f38 --- /dev/null +++ b/library/src/tensor_operation_instance/gpu/gemm_universal/device_gemm_wmma_universal_f16_f16_f16/device_gemm_wmma_universal_f16_f16_f16_km_nk_mn_comp_default_instance.cpp @@ -0,0 +1,24 @@ +// SPDX-License-Identifier: MIT +// Copyright (c) 2025, Advanced Micro Devices, Inc. All rights reserved. + +#include "device_gemm_wmma_universal_f16_f16_f16_km_nk_mn.hpp" +#include "ck/host_utility/device_prop.hpp" + +namespace ck { +namespace tensor_operation { +namespace device { +namespace instance { + +void add_device_gemm_wmma_universal_f16_f16_f16_km_nk_mn_comp_default_instances( + std::vector>>& + instances) +{ + add_device_operation_instances( + instances, device_gemm_wmma_universal_f16_f16_f16_km_nk_mn_comp_instances{}); +} + +} // namespace instance +} // namespace device +} // namespace tensor_operation +} // namespace ck diff --git a/library/src/tensor_operation_instance/gpu/gemm_universal/device_gemm_wmma_universal_f16_f16_f16/device_gemm_wmma_universal_f16_f16_f16_mk_kn_mn.hpp b/library/src/tensor_operation_instance/gpu/gemm_universal/device_gemm_wmma_universal_f16_f16_f16/device_gemm_wmma_universal_f16_f16_f16_mk_kn_mn.hpp new file mode 100644 index 0000000000..6960375ed6 --- /dev/null +++ b/library/src/tensor_operation_instance/gpu/gemm_universal/device_gemm_wmma_universal_f16_f16_f16/device_gemm_wmma_universal_f16_f16_f16_mk_kn_mn.hpp @@ -0,0 +1,67 @@ +// SPDX-License-Identifier: MIT +// Copyright (c) 2025, Advanced Micro Devices, Inc. All rights reserved. 
+ +#include "ck/ck.hpp" +#include "ck/tensor_operation/gpu/device/tensor_layout.hpp" +#include "ck/tensor_operation/gpu/device/gemm_specialization.hpp" +#include "ck/tensor_operation/gpu/device/impl/device_gemm_wmma_cshuffle_v3.hpp" + +#include "ck/library/tensor_operation_instance/add_device_operation_instance.hpp" + +namespace ck { +namespace tensor_operation { +namespace device { +namespace instance { + +using F16 = half_t; +using F32 = float; + +using Row = tensor_layout::gemm::RowMajor; +using Col = tensor_layout::gemm::ColumnMajor; + +template +using S = Sequence; + +using PassThrough = element_wise::PassThrough; + +static constexpr auto GemmDefault = GemmSpecialization::Default; +static constexpr auto GemmKPadding = GemmSpecialization::KPadding; +static constexpr auto GemmMNPadding = GemmSpecialization::MNPadding; +static constexpr auto GemmMNKPadding = GemmSpecialization::MNKPadding; + +static constexpr auto Intrawave = BlockGemmPipelineScheduler::Intrawave; +static constexpr auto Interwave = BlockGemmPipelineScheduler::Interwave; + +template +using device_gemm_wmma_universal_f16_f16_f16_mk_kn_mn_comp_instances = + std::tuple< + // clang-format off + //#########################| ALayout| BLayout| CLayout|AData| BData| CData| AccData| CShuffle| A| B| C| GemmSpec| Block| MPer| NPer| KPer| AK1| BK1| MPer| NPer| MRepeat| NRepeat| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockLds| BBlockTransfer| BBlockTransfer| BBlockTransfer| BBlockTransfer| BBlockTransfer| BBlockTransfer| BBlockLds| CShuffle| CShuffle| CShuffleBlockTransfer| CShuffleBlockTransfer| BlkGemm| BlkGemm| + //#########################| | | | Type| Type| Type| Type| DataType| Elementwise| Elementwise| Elementwise| | Size| Block| Block| Block| | | Wmma| Wmma| | | ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| ExtraM| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| ExtraN| 
MRepeat| NRepeat| ClusterLengths| ScalarPerVector| PipeSched| PipelineVer| + //#########################| | | | | | | | | Operation| Operation| Operation| | | | | | | | | | | | Lengths_AK0_M_AK1| ArrangeOrder| | | PerVector| PerVector_AK1| | Lengths_BK0_N_BK1| ArrangeOrder| | | PerVector| PerVector_BK1| | PerShuffle| PerShuffle| _MBlock_MPerBlock| _NPerBlock| | | + //#########################| | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | _NBlock_NPerBlock| | | | + DeviceGemm_Wmma_CShuffleV3< Row, Row, Row, F16, F16, F16, F32, F16, PassThrough, PassThrough, PassThrough, GemmSpec, 256, 128, 128, 32, 8, 8, 16, 16, 4, 2, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, S<4, 64, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 1, 8, 1, 1, 1, S<1, 32, 1, 8>, 8, Intrawave, BlockGemmPipelineVersion::v3>, + DeviceGemm_Wmma_CShuffleV3< Row, Row, Row, F16, F16, F16, F32, F16, PassThrough, PassThrough, PassThrough, GemmSpec, 128, 128, 64, 64, 8, 8, 16, 16, 4, 2, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, S<4, 32, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 1, 8, 1, 1, 1, S<1, 32, 1, 4>, 8, Intrawave, BlockGemmPipelineVersion::v3>, + DeviceGemm_Wmma_CShuffleV3< Row, Row, Row, F16, F16, F16, F32, F16, PassThrough, PassThrough, PassThrough, GemmSpec, 64, 64, 32, 32, 8, 8, 16, 16, 4, 1, S<4, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, S<4, 16, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 1, 8, 1, 1, 1, S<1, 16, 1, 4>, 8, Intrawave, BlockGemmPipelineVersion::v3>, + DeviceGemm_Wmma_CShuffleV3< Row, Row, Row, F16, F16, F16, F32, F16, PassThrough, PassThrough, PassThrough, GemmSpec, 32, 16, 16, 32, 8, 8, 16, 16, 1, 1, S<2, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, S<2, 16, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 1, 8, 1, 1, 1, S<1, 16, 1, 2>, 8, Intrawave, BlockGemmPipelineVersion::v3>, + DeviceGemm_Wmma_CShuffleV3< Row, Row, Row, F16, F16, F16, F32, F16, PassThrough, PassThrough, PassThrough, GemmSpec, 256, 128, 256, 64, 8, 8, 16, 16, 4, 4, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, S<4, 64, 
1>, S<0, 2, 1>, S<0, 2, 1>, 1, 1, 8, 1, 1, 1, S<1, 32, 1, 8>, 8, Intrawave, BlockGemmPipelineVersion::v3>, + DeviceGemm_Wmma_CShuffleV3< Row, Row, Row, F16, F16, F16, F32, F16, PassThrough, PassThrough, PassThrough, GemmSpec, 256, 128, 128, 64, 8, 8, 16, 16, 4, 2, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, S<4, 64, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 1, 8, 1, 1, 1, S<1, 32, 1, 8>, 8, Intrawave, BlockGemmPipelineVersion::v3>, + DeviceGemm_Wmma_CShuffleV3< Row, Row, Row, F16, F16, F16, F32, F16, PassThrough, PassThrough, PassThrough, GemmSpec, 256, 128, 160, 64, 8, 8, 16, 16, 2, 5, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, S<8, 32, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 1, 8, 1, 1, 1, S<1, 64, 1, 4>, 8, Intrawave, BlockGemmPipelineVersion::v3>, + DeviceGemm_Wmma_CShuffleV3< Row, Row, Row, F16, F16, F16, F32, F16, PassThrough, PassThrough, PassThrough, GemmSpec, 128, 128, 128, 32, 8, 8, 16, 16, 4, 4, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, S<4, 32, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 1, 8, 1, 1, 1, S<1, 32, 1, 4>, 8, Intrawave, BlockGemmPipelineVersion::v3>, + DeviceGemm_Wmma_CShuffleV3< Row, Row, Row, F16, F16, F16, F32, F16, PassThrough, PassThrough, PassThrough, GemmSpec, 128, 256, 64, 64, 8, 8, 16, 16, 8, 2, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, S<4, 32, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 1, 8, 1, 1, 1, S<1, 32, 1, 4>, 8, Intrawave, BlockGemmPipelineVersion::v3>, + DeviceGemm_Wmma_CShuffleV3< Row, Row, Row, F16, F16, F16, F32, F16, PassThrough, PassThrough, PassThrough, GemmSpec, 128, 64, 256, 64, 8, 8, 16, 16, 2, 8, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, S<4, 32, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 1, 8, 1, 1, 1, S<1, 32, 1, 4>, 8, Intrawave, BlockGemmPipelineVersion::v3>, + DeviceGemm_Wmma_CShuffleV3< Row, Row, Row, F16, F16, F16, F32, F16, PassThrough, PassThrough, PassThrough, GemmSpec, 128, 64, 80, 64, 8, 8, 16, 16, 1, 5, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, S<8, 16, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 1, 8, 1, 1, 1, S<1, 64, 1, 2>, 8, 
Intrawave, BlockGemmPipelineVersion::v3>, + DeviceGemm_Wmma_CShuffleV3< Row, Row, Row, F16, F16, F16, F32, F16, PassThrough, PassThrough, PassThrough, GemmSpec, 64, 16, 64, 64, 8, 8, 16, 16, 1, 2, S<4, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, S<4, 16, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 1, 8, 1, 1, 1, S<1, 16, 1, 4>, 8, Intrawave, BlockGemmPipelineVersion::v3>, + DeviceGemm_Wmma_CShuffleV3< Row, Row, Row, F16, F16, F16, F32, F16, PassThrough, PassThrough, PassThrough, GemmSpec, 64, 64, 32, 64, 8, 8, 16, 16, 4, 1, S<4, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, S<4, 16, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 1, 8, 1, 1, 1, S<1, 16, 1, 4>, 8, Intrawave, BlockGemmPipelineVersion::v3>, + DeviceGemm_Wmma_CShuffleV3< Row, Row, Row, F16, F16, F16, F32, F16, PassThrough, PassThrough, PassThrough, GemmSpec, 64, 32, 64, 64, 8, 8, 16, 16, 2, 2, S<4, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, S<4, 16, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 1, 8, 1, 1, 1, S<1, 16, 1, 4>, 8, Intrawave, BlockGemmPipelineVersion::v3>, + DeviceGemm_Wmma_CShuffleV3< Row, Row, Row, F16, F16, F16, F32, F16, PassThrough, PassThrough, PassThrough, GemmSpec, 32, 16, 32, 64, 8, 8, 16, 16, 1, 2, S<2, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, S<2, 16, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 1, 8, 1, 1, 1, S<1, 16, 1, 2>, 8, Intrawave, BlockGemmPipelineVersion::v3>, + DeviceGemm_Wmma_CShuffleV3< Row, Row, Row, F16, F16, F16, F32, F16, PassThrough, PassThrough, PassThrough, GemmSpec, 32, 16, 16, 64, 8, 8, 16, 16, 1, 1, S<2, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, S<2, 16, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 1, 8, 1, 1, 1, S<1, 16, 1, 2>, 8, Intrawave, BlockGemmPipelineVersion::v3>, + // Configurations used during development, mainly for testing + DeviceGemm_Wmma_CShuffleV3< Row, Row, Row, F16, F16, F16, F32, F16, PassThrough, PassThrough, PassThrough, GemmSpec, 32, 16, 16, 64, 8, 8, 16, 16, 1, 1, S<2, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, S<2, 16, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 1, 4, 0, 1, 1, S<1, 16, 1, 2>, 8, Intrawave, 
BlockGemmPipelineVersion::v3>, + DeviceGemm_Wmma_CShuffleV3< Row, Row, Row, F16, F16, F16, F32, F16, PassThrough, PassThrough, PassThrough, GemmSpec, 128, 64, 64, 32, 8, 8, 16, 16, 2, 2, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 0, S<4, 32, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 2, 4, 1, 1, 1, S<1, 32, 1, 2>, 8, Intrawave, BlockGemmPipelineVersion::v3> + // clang-format on + >; +} // namespace instance +} // namespace device +} // namespace tensor_operation +} // namespace ck diff --git a/library/src/tensor_operation_instance/gpu/gemm_universal/device_gemm_wmma_universal_f16_f16_f16/device_gemm_wmma_universal_f16_f16_f16_mk_kn_mn_comp_default_instance.cpp b/library/src/tensor_operation_instance/gpu/gemm_universal/device_gemm_wmma_universal_f16_f16_f16/device_gemm_wmma_universal_f16_f16_f16_mk_kn_mn_comp_default_instance.cpp new file mode 100644 index 0000000000..bbc8b92217 --- /dev/null +++ b/library/src/tensor_operation_instance/gpu/gemm_universal/device_gemm_wmma_universal_f16_f16_f16/device_gemm_wmma_universal_f16_f16_f16_mk_kn_mn_comp_default_instance.cpp @@ -0,0 +1,24 @@ +// SPDX-License-Identifier: MIT +// Copyright (c) 2025, Advanced Micro Devices, Inc. All rights reserved. 
+ +#include "device_gemm_wmma_universal_f16_f16_f16_mk_kn_mn.hpp" +#include "ck/host_utility/device_prop.hpp" + +namespace ck { +namespace tensor_operation { +namespace device { +namespace instance { + +void add_device_gemm_wmma_universal_f16_f16_f16_mk_kn_mn_comp_default_instances( + std::vector>>& + instances) +{ + add_device_operation_instances( + instances, device_gemm_wmma_universal_f16_f16_f16_mk_kn_mn_comp_instances{}); +} + +} // namespace instance +} // namespace device +} // namespace tensor_operation +} // namespace ck diff --git a/library/src/tensor_operation_instance/gpu/gemm_universal/device_gemm_wmma_universal_f16_f16_f16/device_gemm_wmma_universal_f16_f16_f16_mk_nk_mn.hpp b/library/src/tensor_operation_instance/gpu/gemm_universal/device_gemm_wmma_universal_f16_f16_f16/device_gemm_wmma_universal_f16_f16_f16_mk_nk_mn.hpp new file mode 100644 index 0000000000..7f71cf6f59 --- /dev/null +++ b/library/src/tensor_operation_instance/gpu/gemm_universal/device_gemm_wmma_universal_f16_f16_f16/device_gemm_wmma_universal_f16_f16_f16_mk_nk_mn.hpp @@ -0,0 +1,64 @@ +// SPDX-License-Identifier: MIT +// Copyright (c) 2025, Advanced Micro Devices, Inc. All rights reserved. 
+ +#include "ck/ck.hpp" +#include "ck/tensor_operation/gpu/device/tensor_layout.hpp" +#include "ck/tensor_operation/gpu/device/gemm_specialization.hpp" +#include "ck/tensor_operation/gpu/device/impl/device_gemm_wmma_cshuffle_v3.hpp" + +#include "ck/library/tensor_operation_instance/add_device_operation_instance.hpp" + +namespace ck { +namespace tensor_operation { +namespace device { +namespace instance { + +using F16 = half_t; +using F32 = float; + +using Row = tensor_layout::gemm::RowMajor; +using Col = tensor_layout::gemm::ColumnMajor; + +template +using S = Sequence; + +using PassThrough = element_wise::PassThrough; + +static constexpr auto GemmDefault = GemmSpecialization::Default; +static constexpr auto GemmKPadding = GemmSpecialization::KPadding; +static constexpr auto GemmMNPadding = GemmSpecialization::MNPadding; +static constexpr auto GemmMNKPadding = GemmSpecialization::MNKPadding; + +static constexpr auto Intrawave = BlockGemmPipelineScheduler::Intrawave; +static constexpr auto Interwave = BlockGemmPipelineScheduler::Interwave; + +template +using device_gemm_wmma_universal_f16_f16_f16_mk_nk_mn_comp_instances = + std::tuple< + // clang-format off + //#########################| ALayout| BLayout| CLayout|AData| BData| CData| AccData| CShuffle| A| B| C| GemmSpec| Block| MPer| NPer| KPer| AK1| BK1| MPer| NPer| MRepeat| NRepeat| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockLds| BBlockTransfer| BBlockTransfer| BBlockTransfer| BBlockTransfer| BBlockTransfer| BBlockTransfer| BBlockLds| CShuffle| CShuffle| CShuffleBlockTransfer| CShuffleBlockTransfer| BlkGemm| BlkGemm| + //#########################| | | | Type| Type| Type| Type| DataType| Elementwise| Elementwise| Elementwise| | Size| Block| Block| Block| | | Wmma| Wmma| | | ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| ExtraM| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| ExtraN| 
MRepeat| NRepeat| ClusterLengths| ScalarPerVector| PipeSched| PipelineVer| + //#########################| | | | | | | | | Operation| Operation| Operation| | | | | | | | | | | | Lengths_AK0_M_AK1| ArrangeOrder| | | PerVector| PerVector_AK1| | Lengths_BK0_N_BK1| ArrangeOrder| | | PerVector| PerVector_BK1| | PerShuffle| PerShuffle| _MBlock_MPerBlock| _NPerBlock| | | + //#########################| | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | _NBlock_NPerBlock| | | | + DeviceGemm_Wmma_CShuffleV3< Row, Col, Row, F16, F16, F16, F32, F16, PassThrough, PassThrough, PassThrough, GemmSpec, 256, 128, 128, 32, 8, 8, 16, 16, 4, 2, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, 1, 1, S<1, 32, 1, 8>, 8, Intrawave, BlockGemmPipelineVersion::v3>, + DeviceGemm_Wmma_CShuffleV3< Row, Col, Row, F16, F16, F16, F32, F16, PassThrough, PassThrough, PassThrough, GemmSpec, 128, 128, 64, 64, 8, 8, 16, 16, 4, 2, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, 1, 1, S<1, 32, 1, 4>, 8, Intrawave, BlockGemmPipelineVersion::v3>, + DeviceGemm_Wmma_CShuffleV3< Row, Col, Row, F16, F16, F16, F32, F16, PassThrough, PassThrough, PassThrough, GemmSpec, 64, 64, 32, 32, 8, 8, 16, 16, 4, 1, S<4, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, S<4, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, 1, 1, S<1, 16, 1, 4>, 8, Intrawave, BlockGemmPipelineVersion::v3>, + DeviceGemm_Wmma_CShuffleV3< Row, Col, Row, F16, F16, F16, F32, F16, PassThrough, PassThrough, PassThrough, GemmSpec, 32, 16, 16, 32, 8, 8, 16, 16, 1, 1, S<2, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, S<2, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, 1, 1, S<1, 16, 1, 2>, 8, Intrawave, BlockGemmPipelineVersion::v3>, + DeviceGemm_Wmma_CShuffleV3< Row, Col, Row, F16, F16, F16, F32, F16, PassThrough, PassThrough, PassThrough, GemmSpec, 256, 128, 256, 64, 8, 8, 16, 16, 4, 4, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, S<4, 64, 
1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, 1, 1, S<1, 32, 1, 8>, 8, Intrawave, BlockGemmPipelineVersion::v3>, + DeviceGemm_Wmma_CShuffleV3< Row, Col, Row, F16, F16, F16, F32, F16, PassThrough, PassThrough, PassThrough, GemmSpec, 256, 128, 128, 64, 8, 8, 16, 16, 4, 2, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, 1, 1, S<1, 32, 1, 8>, 8, Intrawave, BlockGemmPipelineVersion::v3>, + DeviceGemm_Wmma_CShuffleV3< Row, Col, Row, F16, F16, F16, F32, F16, PassThrough, PassThrough, PassThrough, GemmSpec, 256, 128, 160, 64, 8, 8, 16, 16, 2, 5, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, S<8, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, 1, 1, S<1, 64, 1, 4>, 8, Intrawave, BlockGemmPipelineVersion::v3>, + DeviceGemm_Wmma_CShuffleV3< Row, Col, Row, F16, F16, F16, F32, F16, PassThrough, PassThrough, PassThrough, GemmSpec, 128, 128, 128, 32, 8, 8, 16, 16, 4, 4, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, 1, 1, S<1, 32, 1, 4>, 8, Intrawave, BlockGemmPipelineVersion::v3>, + DeviceGemm_Wmma_CShuffleV3< Row, Col, Row, F16, F16, F16, F32, F16, PassThrough, PassThrough, PassThrough, GemmSpec, 128, 256, 64, 64, 8, 8, 16, 16, 8, 2, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, 1, 1, S<1, 32, 1, 4>, 8, Intrawave, BlockGemmPipelineVersion::v3>, + DeviceGemm_Wmma_CShuffleV3< Row, Col, Row, F16, F16, F16, F32, F16, PassThrough, PassThrough, PassThrough, GemmSpec, 128, 64, 256, 64, 8, 8, 16, 16, 2, 8, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, 1, 1, S<1, 32, 1, 4>, 8, Intrawave, BlockGemmPipelineVersion::v3>, + DeviceGemm_Wmma_CShuffleV3< Row, Col, Row, F16, F16, F16, F32, F16, PassThrough, PassThrough, PassThrough, GemmSpec, 128, 64, 80, 64, 8, 8, 16, 16, 1, 5, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, S<8, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, 1, 1, S<1, 64, 1, 2>, 8, 
Intrawave, BlockGemmPipelineVersion::v3>, + DeviceGemm_Wmma_CShuffleV3< Row, Col, Row, F16, F16, F16, F32, F16, PassThrough, PassThrough, PassThrough, GemmSpec, 64, 16, 64, 64, 8, 8, 16, 16, 1, 2, S<4, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, S<4, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, 1, 1, S<1, 16, 1, 4>, 8, Intrawave, BlockGemmPipelineVersion::v3>, + DeviceGemm_Wmma_CShuffleV3< Row, Col, Row, F16, F16, F16, F32, F16, PassThrough, PassThrough, PassThrough, GemmSpec, 64, 64, 32, 64, 8, 8, 16, 16, 4, 1, S<4, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, S<4, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, 1, 1, S<1, 16, 1, 4>, 8, Intrawave, BlockGemmPipelineVersion::v3>, + DeviceGemm_Wmma_CShuffleV3< Row, Col, Row, F16, F16, F16, F32, F16, PassThrough, PassThrough, PassThrough, GemmSpec, 64, 32, 64, 64, 8, 8, 16, 16, 2, 2, S<4, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, S<4, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, 1, 1, S<1, 16, 1, 4>, 8, Intrawave, BlockGemmPipelineVersion::v3>, + DeviceGemm_Wmma_CShuffleV3< Row, Col, Row, F16, F16, F16, F32, F16, PassThrough, PassThrough, PassThrough, GemmSpec, 32, 16, 32, 64, 8, 8, 16, 16, 1, 2, S<2, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, S<2, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, 1, 1, S<1, 16, 1, 2>, 8, Intrawave, BlockGemmPipelineVersion::v3>, + DeviceGemm_Wmma_CShuffleV3< Row, Col, Row, F16, F16, F16, F32, F16, PassThrough, PassThrough, PassThrough, GemmSpec, 32, 16, 16, 64, 8, 8, 16, 16, 1, 1, S<2, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, S<2, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, 1, 1, S<1, 16, 1, 2>, 8, Intrawave, BlockGemmPipelineVersion::v3> + // clang-format on + >; +} // namespace instance +} // namespace device +} // namespace tensor_operation +} // namespace ck diff --git a/library/src/tensor_operation_instance/gpu/gemm_universal/device_gemm_wmma_universal_f16_f16_f16/device_gemm_wmma_universal_f16_f16_f16_mk_nk_mn_comp_default_instance.cpp 
b/library/src/tensor_operation_instance/gpu/gemm_universal/device_gemm_wmma_universal_f16_f16_f16/device_gemm_wmma_universal_f16_f16_f16_mk_nk_mn_comp_default_instance.cpp new file mode 100644 index 0000000000..331ca8b2ff --- /dev/null +++ b/library/src/tensor_operation_instance/gpu/gemm_universal/device_gemm_wmma_universal_f16_f16_f16/device_gemm_wmma_universal_f16_f16_f16_mk_nk_mn_comp_default_instance.cpp @@ -0,0 +1,24 @@ +// SPDX-License-Identifier: MIT +// Copyright (c) 2025, Advanced Micro Devices, Inc. All rights reserved. + +#include "device_gemm_wmma_universal_f16_f16_f16_mk_nk_mn.hpp" +#include "ck/host_utility/device_prop.hpp" + +namespace ck { +namespace tensor_operation { +namespace device { +namespace instance { + +void add_device_gemm_wmma_universal_f16_f16_f16_mk_nk_mn_comp_default_instances( + std::vector>>& + instances) +{ + add_device_operation_instances( + instances, device_gemm_wmma_universal_f16_f16_f16_mk_nk_mn_comp_instances{}); +} + +} // namespace instance +} // namespace device +} // namespace tensor_operation +} // namespace ck diff --git a/library/src/tensor_operation_instance/gpu/gemm_universal/device_gemm_wmma_universal_f8_f8_bf16/device_gemm_wmma_universal_f8_f8_bf16_mk_kn_mn.hpp b/library/src/tensor_operation_instance/gpu/gemm_universal/device_gemm_wmma_universal_f8_f8_bf16/device_gemm_wmma_universal_f8_f8_bf16_mk_kn_mn.hpp new file mode 100644 index 0000000000..2fca3551b4 --- /dev/null +++ b/library/src/tensor_operation_instance/gpu/gemm_universal/device_gemm_wmma_universal_f8_f8_bf16/device_gemm_wmma_universal_f8_f8_bf16_mk_kn_mn.hpp @@ -0,0 +1,51 @@ +// SPDX-License-Identifier: MIT +// Copyright (c) 2025, Advanced Micro Devices, Inc. All rights reserved. 
+ +#include "ck/ck.hpp" +#include "ck/tensor_operation/gpu/device/tensor_layout.hpp" +#include "ck/tensor_operation/gpu/device/gemm_specialization.hpp" +#include "ck/tensor_operation/gpu/device/impl/device_gemm_wmma_cshuffle_v3.hpp" + +#include "ck/library/tensor_operation_instance/add_device_operation_instance.hpp" + +namespace ck { +namespace tensor_operation { +namespace device { +namespace instance { + +using F8 = f8_t; +using BF16 = bhalf_t; +using F32 = float; + +using Row = tensor_layout::gemm::RowMajor; +using Col = tensor_layout::gemm::ColumnMajor; + +template +using S = Sequence; + +using PassThrough = element_wise::PassThrough; + +static constexpr auto GemmDefault = GemmSpecialization::Default; +static constexpr auto GemmKPadding = GemmSpecialization::KPadding; +static constexpr auto GemmMNPadding = GemmSpecialization::MNPadding; +static constexpr auto GemmMNKPadding = GemmSpecialization::MNKPadding; + +static constexpr auto Intrawave = BlockGemmPipelineScheduler::Intrawave; +static constexpr auto Interwave = BlockGemmPipelineScheduler::Interwave; + +template +using device_gemm_wmma_universal_f8_f8_bf16_mk_kn_mn_comp_instances = + std::tuple< + // clang-format off + //#########################| ALayout| BLayout| CLayout|AData| BData| CData| AccData| CShuffle| A| B| C| GemmSpec| Block| MPer| NPer| KPer| AK1| BK1| MPer| NPer| MRepeat| NRepeat| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockLds| BBlockTransfer| BBlockTransfer| BBlockTransfer| BBlockTransfer| BBlockTransfer| BBlockTransfer| BBlockLds| CShuffle| CShuffle| CShuffleBlockTransfer| CShuffleBlockTransfer| BlkGemm| BlkGemm| Compute| Compute| + //#########################| | | | Type| Type| Type| Type| DataType| Elementwise| Elementwise| Elementwise| | Size| Block| Block| Block| | | Wmma| Wmma| | | ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| ExtraM| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| 
SrcScalar| DstScalar| ExtraN| MRepeat| NRepeat| ClusterLengths| ScalarPerVector| PipeSched| PipelineVer| TypeA| TypeB| + //#########################| | | | | | | | | Operation| Operation| Operation| | | | | | | | | | | | Lengths_AK0_M_AK1| ArrangeOrder| | | PerVector| PerVector_AK1| | Lengths_BK0_N_BK1| ArrangeOrder| | | PerVector| PerVector_BK1| | PerShuffle| PerShuffle| _MBlock_MPerBlock| _NPerBlock| | | | | + //#########################| | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | _NBlock_NPerBlock| | | | | | + DeviceGemm_Wmma_CShuffleV3< Row, Row, Row, F8, F8, BF16, F32, BF16, PassThrough, PassThrough, PassThrough, GemmSpec, 32, 16, 16, 64, 8, 8, 16, 16, 1, 1, S<2, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, S<2, 16, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 1, 4, 0, 1, 1, S<1, 16, 1, 2>, 8, Intrawave, BlockGemmPipelineVersion::v3, F8, F8>, + DeviceGemm_Wmma_CShuffleV3< Row, Row, Row, F8, F8, BF16, F32, BF16, PassThrough, PassThrough, PassThrough, GemmSpec, 128, 64, 64, 32, 8, 8, 16, 16, 2, 2, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 0, S<4, 32, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 2, 4, 1, 1, 1, S<1, 32, 1, 2>, 8, Intrawave, BlockGemmPipelineVersion::v3, F8, F8> + // clang-format on + >; +} // namespace instance +} // namespace device +} // namespace tensor_operation +} // namespace ck diff --git a/library/src/tensor_operation_instance/gpu/gemm_universal/device_gemm_wmma_universal_f8_f8_bf16/device_gemm_wmma_universal_f8_f8_bf16_mk_kn_mn_comp_default_instance.cpp b/library/src/tensor_operation_instance/gpu/gemm_universal/device_gemm_wmma_universal_f8_f8_bf16/device_gemm_wmma_universal_f8_f8_bf16_mk_kn_mn_comp_default_instance.cpp new file mode 100644 index 0000000000..5087a9d719 --- /dev/null +++ b/library/src/tensor_operation_instance/gpu/gemm_universal/device_gemm_wmma_universal_f8_f8_bf16/device_gemm_wmma_universal_f8_f8_bf16_mk_kn_mn_comp_default_instance.cpp @@ -0,0 +1,27 @@ +// SPDX-License-Identifier: MIT +// Copyright (c) 2025, 
Advanced Micro Devices, Inc. All rights reserved. + +#include "device_gemm_wmma_universal_f8_f8_bf16_mk_kn_mn.hpp" +#include "ck/host_utility/device_prop.hpp" + +namespace ck { +namespace tensor_operation { +namespace device { +namespace instance { + +void add_device_gemm_wmma_universal_f8_f8_bf16_mk_kn_mn_comp_default_instances( + std::vector>>& + instances) +{ + if(ck::is_gfx11_supported()) + return; + + add_device_operation_instances( + instances, device_gemm_wmma_universal_f8_f8_bf16_mk_kn_mn_comp_instances{}); +} + +} // namespace instance +} // namespace device +} // namespace tensor_operation +} // namespace ck diff --git a/library/src/tensor_operation_instance/gpu/gemm_universal/device_gemm_wmma_universal_f8_f8_bf16/device_gemm_wmma_universal_f8_f8_bf16_mk_nk_mn.hpp b/library/src/tensor_operation_instance/gpu/gemm_universal/device_gemm_wmma_universal_f8_f8_bf16/device_gemm_wmma_universal_f8_f8_bf16_mk_nk_mn.hpp new file mode 100644 index 0000000000..244eb69190 --- /dev/null +++ b/library/src/tensor_operation_instance/gpu/gemm_universal/device_gemm_wmma_universal_f8_f8_bf16/device_gemm_wmma_universal_f8_f8_bf16_mk_nk_mn.hpp @@ -0,0 +1,51 @@ +// SPDX-License-Identifier: MIT +// Copyright (c) 2025, Advanced Micro Devices, Inc. All rights reserved. 
+ +#include "ck/ck.hpp" +#include "ck/tensor_operation/gpu/device/tensor_layout.hpp" +#include "ck/tensor_operation/gpu/device/gemm_specialization.hpp" +#include "ck/tensor_operation/gpu/device/impl/device_gemm_wmma_cshuffle_v3.hpp" + +#include "ck/library/tensor_operation_instance/add_device_operation_instance.hpp" + +namespace ck { +namespace tensor_operation { +namespace device { +namespace instance { + +using F8 = f8_t; +using BF16 = bhalf_t; +using F32 = float; + +using Row = tensor_layout::gemm::RowMajor; +using Col = tensor_layout::gemm::ColumnMajor; + +template +using S = Sequence; + +using PassThrough = element_wise::PassThrough; + +static constexpr auto GemmDefault = GemmSpecialization::Default; +static constexpr auto GemmKPadding = GemmSpecialization::KPadding; +static constexpr auto GemmMNPadding = GemmSpecialization::MNPadding; +static constexpr auto GemmMNKPadding = GemmSpecialization::MNKPadding; + +static constexpr auto Intrawave = BlockGemmPipelineScheduler::Intrawave; +static constexpr auto Interwave = BlockGemmPipelineScheduler::Interwave; + +template +using device_gemm_wmma_universal_f8_f8_bf16_mk_nk_mn_comp_instances = + std::tuple< + // clang-format off + //#########################| ALayout| BLayout| CLayout|AData| BData| CData| AccData| CShuffle| A| B| C| GemmSpec| Block| MPer| NPer| KPer| AK1| BK1| MPer| NPer| MRepeat| NRepeat| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockLds| BBlockTransfer| BBlockTransfer| BBlockTransfer| BBlockTransfer| BBlockTransfer| BBlockTransfer| BBlockLds| CShuffle| CShuffle| CShuffleBlockTransfer| CShuffleBlockTransfer| BlkGemm| BlkGemm| Compute| Compute| + //#########################| | | | Type| Type| Type| Type| DataType| Elementwise| Elementwise| Elementwise| | Size| Block| Block| Block| | | Wmma| Wmma| | | ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| ExtraM| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| 
SrcScalar| DstScalar| ExtraN| MRepeat| NRepeat| ClusterLengths| ScalarPerVector| PipeSched| PipelineVer| TypeA| TypeB| + //#########################| | | | | | | | | Operation| Operation| Operation| | | | | | | | | | | | Lengths_AK0_M_AK1| ArrangeOrder| | | PerVector| PerVector_AK1| | Lengths_BK0_N_BK1| ArrangeOrder| | | PerVector| PerVector_BK1| | PerShuffle| PerShuffle| _MBlock_MPerBlock| _NPerBlock| | | | | + //#########################| | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | _NBlock_NPerBlock| | | | | | + DeviceGemm_Wmma_CShuffleV3< Row, Col, Row, F8, F8, BF16, F32, BF16, PassThrough, PassThrough, PassThrough, GemmSpec, 32, 16, 16, 32, 8, 8, 16, 16, 1, 1, S<2, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, S<2, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, 1, 1, S<1, 16, 1, 2>, 8, Intrawave, BlockGemmPipelineVersion::v3, F8, F8>, + DeviceGemm_Wmma_CShuffleV3< Row, Col, Row, F8, F8, BF16, F32, BF16, PassThrough, PassThrough, PassThrough, GemmSpec, 256, 128, 128, 32, 8, 8, 16, 16, 4, 2, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, 1, 1, S<1, 32, 1, 8>, 8, Intrawave, BlockGemmPipelineVersion::v3, F8, F8> + // clang-format on + >; +} // namespace instance +} // namespace device +} // namespace tensor_operation +} // namespace ck diff --git a/library/src/tensor_operation_instance/gpu/gemm_universal/device_gemm_wmma_universal_f8_f8_bf16/device_gemm_wmma_universal_f8_f8_bf16_mk_nk_mn_comp_default_instance.cpp b/library/src/tensor_operation_instance/gpu/gemm_universal/device_gemm_wmma_universal_f8_f8_bf16/device_gemm_wmma_universal_f8_f8_bf16_mk_nk_mn_comp_default_instance.cpp new file mode 100644 index 0000000000..89df765517 --- /dev/null +++ b/library/src/tensor_operation_instance/gpu/gemm_universal/device_gemm_wmma_universal_f8_f8_bf16/device_gemm_wmma_universal_f8_f8_bf16_mk_nk_mn_comp_default_instance.cpp @@ -0,0 +1,27 @@ +// SPDX-License-Identifier: MIT +// Copyright (c) 2025, 
Advanced Micro Devices, Inc. All rights reserved. + +#include "device_gemm_wmma_universal_f8_f8_bf16_mk_nk_mn.hpp" +#include "ck/host_utility/device_prop.hpp" + +namespace ck { +namespace tensor_operation { +namespace device { +namespace instance { + +void add_device_gemm_wmma_universal_f8_f8_bf16_mk_nk_mn_comp_default_instances( + std::vector>>& + instances) +{ + if(ck::is_gfx11_supported()) + return; + + add_device_operation_instances( + instances, device_gemm_wmma_universal_f8_f8_bf16_mk_nk_mn_comp_instances{}); +} + +} // namespace instance +} // namespace device +} // namespace tensor_operation +} // namespace ck diff --git a/profiler/include/profiler/profile_gemm_universal_impl.hpp b/profiler/include/profiler/profile_gemm_universal_impl.hpp index 2054ffbbb3..f7b1d5f1f8 100644 --- a/profiler/include/profiler/profile_gemm_universal_impl.hpp +++ b/profiler/include/profiler/profile_gemm_universal_impl.hpp @@ -9,7 +9,7 @@ #include "ck/ck.hpp" #include "ck/tensor_operation/gpu/device/tensor_layout.hpp" -#include "ck/tensor_operation/gpu/device/impl/device_gemm_xdl_cshuffle_v3.hpp" +#include "ck/tensor_operation/gpu/device/device_gemm_v2.hpp" #include "ck/tensor_operation/gpu/element/element_wise_operation.hpp" #include "ck/library/tensor_operation_instance/gpu/gemm_universal.hpp" diff --git a/profiler/src/CMakeLists.txt b/profiler/src/CMakeLists.txt index 9cb70e4670..17c8c277eb 100644 --- a/profiler/src/CMakeLists.txt +++ b/profiler/src/CMakeLists.txt @@ -58,7 +58,6 @@ if(SUPPORTED_GPU_TARGETS MATCHES "gfx9") list(APPEND PROFILER_SOURCES profile_gemm_add_multiply.cpp) list(APPEND PROFILER_SOURCES profile_gemm_bias_add_reduce.cpp) list(APPEND PROFILER_SOURCES profile_gemm_splitk.cpp) - list(APPEND PROFILER_SOURCES profile_gemm_universal.cpp) list(APPEND PROFILER_SOURCES profile_gemm_b_scale.cpp) list(APPEND PROFILER_SOURCES profile_batched_gemm_b_scale.cpp) list(APPEND PROFILER_SOURCES profile_gemm_universal_batched.cpp) @@ -76,6 +75,7 @@ if(SUPPORTED_GPU_TARGETS 
MATCHES "gfx11" OR SUPPORTED_GPU_TARGETS MATCHES "gfx12 if(DTYPES MATCHES "fp16" OR NOT DEFINED DTYPES) list(APPEND PROFILER_SOURCES profile_gemm_bilinear.cpp) endif() + list(APPEND PROFILER_SOURCES profile_gemm_universal.cpp) list(APPEND PROFILER_SOURCES profile_grouped_conv_fwd.cpp) list(APPEND PROFILER_SOURCES profile_grouped_conv_bwd_data.cpp) list(APPEND PROFILER_SOURCES profile_grouped_conv_bwd_weight.cpp) @@ -144,7 +144,6 @@ if(SUPPORTED_GPU_TARGETS MATCHES "gfx9") target_link_libraries(${PROFILER_EXECUTABLE} PRIVATE device_gemm_ab_scale_instance) endif() target_link_libraries(${PROFILER_EXECUTABLE} PRIVATE device_gemm_splitk_instance) - target_link_libraries(${PROFILER_EXECUTABLE} PRIVATE device_gemm_universal_instance) target_link_libraries(${PROFILER_EXECUTABLE} PRIVATE device_gemm_b_scale_instance) target_link_libraries(${PROFILER_EXECUTABLE} PRIVATE device_batched_gemm_b_scale_instance) target_link_libraries(${PROFILER_EXECUTABLE} PRIVATE device_gemm_universal_batched_instance) @@ -170,6 +169,7 @@ if(SUPPORTED_GPU_TARGETS MATCHES "gfx9" OR SUPPORTED_GPU_TARGETS MATCHES "gfx11" if(DTYPES MATCHES "fp16" OR NOT DEFINED DTYPES) target_link_libraries(${PROFILER_EXECUTABLE} PRIVATE device_gemm_bilinear_instance) endif() + target_link_libraries(${PROFILER_EXECUTABLE} PRIVATE device_gemm_universal_instance) target_link_libraries(${PROFILER_EXECUTABLE} PRIVATE device_grouped_conv3d_fwd_instance) target_link_libraries(${PROFILER_EXECUTABLE} PRIVATE device_grouped_conv2d_bwd_data_instance) target_link_libraries(${PROFILER_EXECUTABLE} PRIVATE device_grouped_conv3d_bwd_data_instance) diff --git a/profiler/src/profile_gemm_universal.cpp b/profiler/src/profile_gemm_universal.cpp index a22d983da5..7f2393a7e6 100644 --- a/profiler/src/profile_gemm_universal.cpp +++ b/profiler/src/profile_gemm_universal.cpp @@ -1,5 +1,5 @@ // SPDX-License-Identifier: MIT -// Copyright (c) 2023-2024, Advanced Micro Devices, Inc. All rights reserved. 
+// Copyright (c) 2023-2025, Advanced Micro Devices, Inc. All rights reserved. #include #include @@ -103,8 +103,10 @@ int profile_gemm_universal(int argc, char* argv[]) using F32 = float; using F16 = ck::half_t; using BF16 = ck::bhalf_t; -#if defined(CK_USE_FP8_ON_UNSUPPORTED_ARCH) || defined(CK_USE_GFX94) +#if defined(CK_USE_FP8_ON_UNSUPPORTED_ARCH) || defined(CK_USE_GFX94) || defined(CK_USE_WMMA_FP8) using F8 = ck::f8_t; +#endif +#if defined(CK_USE_FP8_ON_UNSUPPORTED_ARCH) || defined(CK_USE_GFX94) using I4 = ck::pk_i4_t; #endif @@ -201,7 +203,7 @@ int profile_gemm_universal(int argc, char* argv[]) { return profile(BF16{}, BF16{}, BF16{}, F32{}, BF16{}, Col{}, Row{}, Row{}); } -#if defined(CK_USE_FP8_ON_UNSUPPORTED_ARCH) || defined(CK_USE_GFX94) +#if defined(CK_USE_FP8_ON_UNSUPPORTED_ARCH) || defined(CK_USE_GFX94) || defined(CK_USE_WMMA_FP8) else if(data_type == GemmDataType::F8_F8_BF16 && layout == GemmMatrixLayout::MK_KN_MN) { return profile(F8{}, F8{}, F8{}, F32{}, BF16{}, Row{}, Row{}, Row{}); @@ -210,6 +212,8 @@ int profile_gemm_universal(int argc, char* argv[]) { return profile(F8{}, F8{}, F8{}, F32{}, BF16{}, Row{}, Col{}, Row{}); } +#endif +#if defined(CK_USE_FP8_ON_UNSUPPORTED_ARCH) || defined(CK_USE_GFX94) else if(data_type == GemmDataType::F16_I4_F16 && layout == GemmMatrixLayout::MK_NK_MN) { return profile(F16{}, I4{}, F16{}, F32{}, F16{}, Row{}, Col{}, Row{}); diff --git a/test/gemm_universal/CMakeLists.txt b/test/gemm_universal/CMakeLists.txt index cf5c68e220..0a68622ebe 100755 --- a/test/gemm_universal/CMakeLists.txt +++ b/test/gemm_universal/CMakeLists.txt @@ -1,15 +1,29 @@ -add_gtest_executable(test_gemm_universal_fp16 test_gemm_universal_xdl_fp16.cpp) +add_gtest_executable(test_gemm_universal_wmma_fp16 test_gemm_universal_wmma_fp16.cpp) if(result EQUAL 0) - target_link_libraries(test_gemm_universal_fp16 PRIVATE utility device_gemm_universal_instance) - endif() - -add_gtest_executable(test_gemm_universal_fp8 test_gemm_universal_xdl_fp8.cpp) 
-if(result EQUAL 0) - target_link_libraries(test_gemm_universal_fp8 PRIVATE utility device_gemm_universal_instance) + target_link_libraries(test_gemm_universal_wmma_fp16 PRIVATE utility device_gemm_universal_instance) endif() -add_gtest_executable(test_gemm_universal_bf16 test_gemm_universal_xdl_bf16.cpp) +add_gtest_executable(test_gemm_universal_wmma_bf16 test_gemm_universal_wmma_bf16.cpp) if(result EQUAL 0) - target_link_libraries(test_gemm_universal_bf16 PRIVATE utility device_gemm_universal_instance) + target_link_libraries(test_gemm_universal_wmma_bf16 PRIVATE utility device_gemm_universal_instance) endif() +add_gtest_executable(test_gemm_universal_wmma_fp8 test_gemm_universal_wmma_fp8.cpp) +if(result EQUAL 0) + target_link_libraries(test_gemm_universal_wmma_fp8 PRIVATE utility device_gemm_universal_instance) +endif() + +add_gtest_executable(test_gemm_universal_xdl_fp16 test_gemm_universal_xdl_fp16.cpp) +if(result EQUAL 0) + target_link_libraries(test_gemm_universal_xdl_fp16 PRIVATE utility device_gemm_universal_instance) +endif() + +add_gtest_executable(test_gemm_universal_xdl_fp8 test_gemm_universal_xdl_fp8.cpp) +if(result EQUAL 0) + target_link_libraries(test_gemm_universal_xdl_fp8 PRIVATE utility device_gemm_universal_instance) +endif() + +add_gtest_executable(test_gemm_universal_xdl_bf16 test_gemm_universal_xdl_bf16.cpp) +if(result EQUAL 0) + target_link_libraries(test_gemm_universal_xdl_bf16 PRIVATE utility device_gemm_universal_instance) +endif() diff --git a/test/gemm_universal/test_gemm_universal_wmma_bf16.cpp b/test/gemm_universal/test_gemm_universal_wmma_bf16.cpp new file mode 100644 index 0000000000..22376a8599 --- /dev/null +++ b/test/gemm_universal/test_gemm_universal_wmma_bf16.cpp @@ -0,0 +1,80 @@ +// SPDX-License-Identifier: MIT +// Copyright (c) 2025, Advanced Micro Devices, Inc. All rights reserved. 
+ +#include + +#include "gtest/gtest.h" +#include "ck/tensor_operation/gpu/device/tensor_layout.hpp" +#include "test_gemm_universal_util.hpp" + +using BF16 = ck::bhalf_t; +using F32 = float; + +using Row = ck::tensor_layout::gemm::RowMajor; +using Col = ck::tensor_layout::gemm::ColumnMajor; + +namespace { + +template +struct tuple_concat; + +template +struct tuple_concat, std::tuple> +{ + using type = std::tuple; +}; + +} // namespace + +template +class TestGemmUniversal_BF16_MK_KN + : public ck::test::TestGemmUniversal, Tuple>::type> +{ +}; + +template +class TestGemmUniversal_BF16_MK_NK + : public ck::test::TestGemmUniversal, Tuple>::type> +{ +}; + +template +class TestGemmUniversal_BF16_KM_KN + : public ck::test::TestGemmUniversal, Tuple>::type> +{ +}; + +template +class TestGemmUniversal_BF16_KM_NK + : public ck::test::TestGemmUniversal, Tuple>::type> +{ +}; + +// clang-format off +using KernelTypes_MK_KN = ::testing::Types< + // ADataType, BDataType, ComputeDataType, CDataType + std::tuple< BF16, BF16, BF16, BF16> + >; + +using KernelTypes_MK_NK = ::testing::Types< + // ADataType, BDataType, ComputeDataType, CDataType + std::tuple< BF16, BF16, BF16, BF16> + >; + +using KernelTypes_KM_KN = ::testing::Types< + // ADataType, BDataType, ComputeDataType, CDataType + std::tuple< BF16, BF16, BF16, BF16> + >; + +using KernelTypes_KM_NK = ::testing::Types< + // ADataType, BDataType, ComputeDataType, CDataType + std::tuple< BF16, BF16, BF16, BF16> + >; +// clang-format on + +TYPED_TEST_SUITE(TestGemmUniversal_BF16_MK_KN, KernelTypes_MK_KN); +TYPED_TEST_SUITE(TestGemmUniversal_BF16_MK_NK, KernelTypes_MK_NK); +TYPED_TEST_SUITE(TestGemmUniversal_BF16_KM_KN, KernelTypes_KM_KN); +TYPED_TEST_SUITE(TestGemmUniversal_BF16_KM_NK, KernelTypes_KM_NK); + +#include "test_gemm_universal_ut_cases_bf16.inc" diff --git a/test/gemm_universal/test_gemm_universal_wmma_fp16.cpp b/test/gemm_universal/test_gemm_universal_wmma_fp16.cpp new file mode 100644 index 0000000000..1adee41ed2 --- 
/dev/null +++ b/test/gemm_universal/test_gemm_universal_wmma_fp16.cpp @@ -0,0 +1,57 @@ +// SPDX-License-Identifier: MIT +// Copyright (c) 2025, Advanced Micro Devices, Inc. All rights reserved. + +#include + +#include "gtest/gtest.h" +#include "ck/tensor_operation/gpu/device/tensor_layout.hpp" +#include "test_gemm_universal_util.hpp" + +using F16 = ck::half_t; + +using F32 = float; + +using Row = ck::tensor_layout::gemm::RowMajor; +using Col = ck::tensor_layout::gemm::ColumnMajor; + +namespace { + +template +struct tuple_concat; + +template +struct tuple_concat, std::tuple> +{ + using type = std::tuple; +}; + +} // namespace + +template +class TestGemmUniversal_FP16_MK_KN + : public ck::test::TestGemmUniversal, Tuple>::type> +{ +}; + +template +class TestGemmUniversal_FP16_MK_NK + : public ck::test::TestGemmUniversal, Tuple>::type> +{ +}; + +// clang-format off +using KernelTypes_MK_KN = ::testing::Types< + // ADataType, BDataType, ComputeDataType, CDataType + std::tuple< F16, F16, F16, F16> + >; + +using KernelTypes_MK_NK = ::testing::Types< + // ADataType, BDataType, ComputeDataType, CDataType + std::tuple< F16, F16, F16, F16> + >; +// clang-format on + +TYPED_TEST_SUITE(TestGemmUniversal_FP16_MK_KN, KernelTypes_MK_KN); +TYPED_TEST_SUITE(TestGemmUniversal_FP16_MK_NK, KernelTypes_MK_NK); + +#include "test_gemm_universal_ut_cases_fp16.inc" diff --git a/test/gemm_universal/test_gemm_universal_wmma_fp8.cpp b/test/gemm_universal/test_gemm_universal_wmma_fp8.cpp new file mode 100644 index 0000000000..3579424496 --- /dev/null +++ b/test/gemm_universal/test_gemm_universal_wmma_fp8.cpp @@ -0,0 +1,61 @@ +// SPDX-License-Identifier: MIT +// Copyright (c) 2025, Advanced Micro Devices, Inc. All rights reserved. 
+ +#include + +#include "gtest/gtest.h" +#include "ck/tensor_operation/gpu/device/tensor_layout.hpp" +#include "test_gemm_universal_util.hpp" + +#if CK_USE_WMMA_FP8 + +using F8 = ck::f8_t; +using BF16 = ck::bhalf_t; +using F32 = float; + +using Row = ck::tensor_layout::gemm::RowMajor; +using Col = ck::tensor_layout::gemm::ColumnMajor; + +namespace { + +template +struct tuple_concat; + +template +struct tuple_concat, std::tuple> +{ + using type = std::tuple; +}; + +} // namespace + +template +class TestGemmUniversal_FP8_MK_KN + : public ck::test::TestGemmUniversal, Tuple>::type> +{ +}; + +template +class TestGemmUniversal_FP8_MK_NK + : public ck::test::TestGemmUniversal, Tuple>::type> +{ +}; + +// clang-format off +using KernelTypes_MK_KN = ::testing::Types< + // ADataType, BDataType, ComputeDataType, CDataType + std::tuple< F8, F8, F8, BF16> + >; + +using KernelTypes_MK_NK = ::testing::Types< + // ADataType, BDataType, ComputeDataType, CDataType + std::tuple< F8, F8, F8, BF16> + >; +// clang-format on + +TYPED_TEST_SUITE(TestGemmUniversal_FP8_MK_KN, KernelTypes_MK_KN); +TYPED_TEST_SUITE(TestGemmUniversal_FP8_MK_NK, KernelTypes_MK_NK); + +#include "test_gemm_universal_ut_cases_fp8.inc" + +#endif // CK_USE_WMMA_FP8 From 83394e40d2452d32701bed4ed85bea1bfa50cfc2 Mon Sep 17 00:00:00 2001 From: lalala-sh Date: Tue, 29 Apr 2025 00:49:31 +0800 Subject: [PATCH 074/443] fix moe i4 example bug (#2139) --- example/65_gemm_multiply_multiply/moe_gemm1_xdl_pk_i4.cpp | 5 +++-- 1 file changed, 3 insertions(+), 2 deletions(-) diff --git a/example/65_gemm_multiply_multiply/moe_gemm1_xdl_pk_i4.cpp b/example/65_gemm_multiply_multiply/moe_gemm1_xdl_pk_i4.cpp index 3c3ef16198..9e80a2ca35 100644 --- a/example/65_gemm_multiply_multiply/moe_gemm1_xdl_pk_i4.cpp +++ b/example/65_gemm_multiply_multiply/moe_gemm1_xdl_pk_i4.cpp @@ -233,7 +233,7 @@ int main(int argc, char* argv[]) ck::index_t StrideB = K; ck::index_t StrideE = N; constexpr ck::index_t NumDTensor = DsDataType::Size(); - constexpr 
auto StrideDs = std::array{0, 0, 0}; + constexpr auto StrideDs = std::array{1, 1, 1}; ck::index_t KBatch = 1; @@ -266,7 +266,8 @@ int main(int argc, char* argv[]) Tensor b0_e_n_k(HostTensorDescriptor({experts, K, N * 2}, {N * 2 * K, 1, K})); Tensor b0_preshuffled(HostTensorDescriptor({experts, K, N * 2}, {N * 2 * K, 1, K})); Tensor d0_t_n(HostTensorDescriptor({tokens, N}, {StrideDs[0], 0})); - Tensor d1_e_n(HostTensorDescriptor({experts, N * 2}, {1, StrideDs[1]})); + Tensor d1_e_n( + HostTensorDescriptor({experts, N * 2}, {StrideDs[1] * N * 2, StrideDs[1]})); Tensor d2_e_n(HostTensorDescriptor({sorted_size, N}, {1, 0})); Tensor e_t_n_host_result(HostTensorDescriptor({tokens, topk, N}, {topk * N, N, 1})); Tensor e_t_n_device_result( From 434d19f696da62c12b5372b32cbc9ba968588d7e Mon Sep 17 00:00:00 2001 From: jakpiase Date: Mon, 28 Apr 2025 18:53:19 +0200 Subject: [PATCH 075/443] Add ck tile examples to package (#1880) * add ck tile examples to package * Update jenkinsfile * fix for jenkinsfile * fix for building ck tile code on non gfx9 * compile ck tile examples only for gfx94 * include ck tile examples in all target * fix for basic gemm UseStructuredSparsity * Update CMakeLists.txt * Update gemm_pipeline_problem.hpp * add targets to rocm install --------- Co-authored-by: Illia Silin <98187287+illsilin@users.noreply.github.com> --- example/CMakeLists.txt | 4 +++- example/ck_tile/01_fmha/CMakeLists.txt | 6 ++++-- example/ck_tile/02_layernorm2d/CMakeLists.txt | 3 ++- example/ck_tile/03_gemm/CMakeLists.txt | 7 +++++-- example/ck_tile/03_gemm/stript.sh | 1 + example/ck_tile/04_img2col/CMakeLists.txt | 3 ++- example/ck_tile/05_reduce/CMakeLists.txt | 4 +++- example/ck_tile/06_permute/CMakeLists.txt | 3 ++- .../ck_tile/09_topk_softmax/CMakeLists.txt | 5 +++-- example/ck_tile/10_rmsnorm2d/CMakeLists.txt | 6 ++++-- .../11_add_rmsnorm2d_rdquant/CMakeLists.txt | 6 ++++-- .../add_rmsnorm2d_rdquant_fwd.cpp | 21 +++++++++++-------- .../example_add_rmsnorm2d_rdquant_fwd.cpp | 21 
+++++++++++-------- example/ck_tile/12_smoothquant/CMakeLists.txt | 3 ++- example/ck_tile/13_moe_sorting/CMakeLists.txt | 3 ++- .../ck_tile/14_moe_smoothquant/CMakeLists.txt | 3 ++- example/ck_tile/15_fused_moe/CMakeLists.txt | 3 ++- .../ck_tile/16_batched_gemm/CMakeLists.txt | 3 ++- .../ck_tile/17_grouped_gemm/CMakeLists.txt | 4 ++-- example/ck_tile/18_flatmm/CMakeLists.txt | 4 +++- .../35_batched_transpose/CMakeLists.txt | 4 ++-- example/ck_tile/CMakeLists.txt | 5 ++++- .../flatmm_pipeline_agmem_bgmem_creg_v1.hpp | 1 + .../gemm/pipeline/gemm_pipeline_problem.hpp | 3 +-- .../ops/gemm/pipeline/tile_gemm_traits.hpp | 5 +++-- 25 files changed, 83 insertions(+), 48 deletions(-) create mode 100644 example/ck_tile/03_gemm/stript.sh mode change 100644 => 100755 example/ck_tile/11_add_rmsnorm2d_rdquant/example_add_rmsnorm2d_rdquant_fwd.cpp diff --git a/example/CMakeLists.txt b/example/CMakeLists.txt index 996a543ecc..0e61fd33ef 100644 --- a/example/CMakeLists.txt +++ b/example/CMakeLists.txt @@ -5,7 +5,6 @@ include_directories(BEFORE add_custom_target(examples) - # list of examples that are labelled as REGRESSION_EXAMPLE for make regression (runtime more than 30 seconds) # all other tests are labelled as SMOKE_EXAMPLE set(REGRESSION_EXAMPLES @@ -232,6 +231,9 @@ endfunction(add_example_executable_no_testing EXAMPLE_NAME) # add all example subdir file(GLOB dir_list LIST_DIRECTORIES true *) +if (NOT SUPPORTED_GPU_TARGETS MATCHES "gfx9") + list(FILTER dir_list EXCLUDE REGEX ".*/ck_tile") +endif() FOREACH(subdir ${dir_list}) if(IS_DIRECTORY "${subdir}" AND EXISTS "${subdir}/CMakeLists.txt") add_subdirectory(${subdir}) diff --git a/example/ck_tile/01_fmha/CMakeLists.txt b/example/ck_tile/01_fmha/CMakeLists.txt index 9ba3a453fc..ce3c8b3978 100644 --- a/example/ck_tile/01_fmha/CMakeLists.txt +++ b/example/ck_tile/01_fmha/CMakeLists.txt @@ -58,7 +58,8 @@ set(EXAMPLE_FMHA_FWD "tile_example_fmha_fwd") # not using add_example_executable() to add this target, since we don't want this 
to have # to be included in "make all/install/check" message("adding example ${EXAMPLE_FMHA_FWD}") -add_executable(${EXAMPLE_FMHA_FWD} EXCLUDE_FROM_ALL fmha_fwd.cpp) +add_executable(${EXAMPLE_FMHA_FWD} fmha_fwd.cpp) +rocm_install(TARGETS ${EXAMPLE_FMHA_FWD} COMPONENT examples) target_include_directories(${EXAMPLE_FMHA_FWD} PRIVATE ${CMAKE_CURRENT_LIST_DIR}) target_sources(${EXAMPLE_FMHA_FWD} PRIVATE ${FMHA_FWD_GEN_BLOBS}) @@ -66,7 +67,8 @@ set(EXAMPLE_FMHA_BWD "tile_example_fmha_bwd") # not using add_example_executable() to add this target, since we don't want this to have # to be included in "make all/install/check" message("adding example ${EXAMPLE_FMHA_BWD}") -add_executable(${EXAMPLE_FMHA_BWD} EXCLUDE_FROM_ALL fmha_bwd.cpp) +add_executable(${EXAMPLE_FMHA_BWD} fmha_bwd.cpp) +rocm_install(TARGETS ${EXAMPLE_FMHA_BWD} COMPONENT examples) target_include_directories(${EXAMPLE_FMHA_BWD} PRIVATE ${CMAKE_CURRENT_LIST_DIR}) target_sources(${EXAMPLE_FMHA_BWD} PRIVATE ${FMHA_BWD_GEN_BLOBS}) diff --git a/example/ck_tile/02_layernorm2d/CMakeLists.txt b/example/ck_tile/02_layernorm2d/CMakeLists.txt index fa69ac0f7a..74f195a9db 100644 --- a/example/ck_tile/02_layernorm2d/CMakeLists.txt +++ b/example/ck_tile/02_layernorm2d/CMakeLists.txt @@ -26,7 +26,8 @@ add_custom_command( set(EXAMPLE_LAYERNORM2D_FWD "tile_example_layernorm2d_fwd") message("adding example ${EXAMPLE_LAYERNORM2D_FWD}") -add_executable(${EXAMPLE_LAYERNORM2D_FWD} EXCLUDE_FROM_ALL layernorm2d_fwd.cpp) +add_executable(${EXAMPLE_LAYERNORM2D_FWD} layernorm2d_fwd.cpp) +rocm_install(TARGETS ${EXAMPLE_LAYERNORM2D_FWD} COMPONENT examples) target_include_directories(${EXAMPLE_LAYERNORM2D_FWD} PRIVATE ${CMAKE_CURRENT_LIST_DIR}) target_sources(${EXAMPLE_LAYERNORM2D_FWD} PRIVATE ${LAYERNORM2D_FWD_GEN_BLOBS}) diff --git a/example/ck_tile/03_gemm/CMakeLists.txt b/example/ck_tile/03_gemm/CMakeLists.txt index 411db2e317..deccb71d23 100644 --- a/example/ck_tile/03_gemm/CMakeLists.txt +++ b/example/ck_tile/03_gemm/CMakeLists.txt 
@@ -1,5 +1,8 @@ -add_executable(tile_example_gemm_basic EXCLUDE_FROM_ALL gemm_basic.cpp) -add_executable(tile_example_gemm_universal EXCLUDE_FROM_ALL universal_gemm.cpp) +add_executable(tile_example_gemm_basic gemm_basic.cpp) +rocm_install(TARGETS tile_example_gemm_basic COMPONENT examples) +add_executable(tile_example_gemm_universal universal_gemm.cpp) +rocm_install(TARGETS tile_example_gemm_universal COMPONENT examples) + set(EXAMPLE_GEMM_COMPILE_OPTIONS) if(CK_USE_OCP_FP8) list(APPEND EXAMPLE_GEMM_COMPILE_OPTIONS -DCK_TILE_USE_OCP_FP8) diff --git a/example/ck_tile/03_gemm/stript.sh b/example/ck_tile/03_gemm/stript.sh new file mode 100644 index 0000000000..4b91cb36ce --- /dev/null +++ b/example/ck_tile/03_gemm/stript.sh @@ -0,0 +1 @@ +for file in gemm_universal_*; do mv "$file" "${file/f16_f16_f16/fp16_fp16_fp16}"; done diff --git a/example/ck_tile/04_img2col/CMakeLists.txt b/example/ck_tile/04_img2col/CMakeLists.txt index 3864c9ed9d..d3737467d8 100644 --- a/example/ck_tile/04_img2col/CMakeLists.txt +++ b/example/ck_tile/04_img2col/CMakeLists.txt @@ -1,3 +1,4 @@ # not using add_example_executable() to add this target, since we don't want this to have # to be included in "make all/install/check" -add_executable(tile_example_img2col EXCLUDE_FROM_ALL image_to_column.cpp) +add_executable(tile_example_img2col image_to_column.cpp) +rocm_install(TARGETS tile_example_img2col COMPONENT examples) diff --git a/example/ck_tile/05_reduce/CMakeLists.txt b/example/ck_tile/05_reduce/CMakeLists.txt index 6caa38d50d..855e59c48e 100644 --- a/example/ck_tile/05_reduce/CMakeLists.txt +++ b/example/ck_tile/05_reduce/CMakeLists.txt @@ -3,7 +3,9 @@ set(EXAMPLE_REDUCE "tile_example_reduce") # to be included in "make all/install/check" message("adding example ${EXAMPLE_REDUCE}") -add_executable(${EXAMPLE_REDUCE} EXCLUDE_FROM_ALL reduce.cpp) +add_executable(${EXAMPLE_REDUCE} reduce.cpp) +rocm_install(TARGETS ${EXAMPLE_REDUCE} COMPONENT examples) + 
target_include_directories(${EXAMPLE_REDUCE} PRIVATE ${CMAKE_CURRENT_LIST_DIR}) set(EXAMPLE_REDUCE_COMPILE_OPTIONS) diff --git a/example/ck_tile/06_permute/CMakeLists.txt b/example/ck_tile/06_permute/CMakeLists.txt index 327fceb685..22483a4295 100644 --- a/example/ck_tile/06_permute/CMakeLists.txt +++ b/example/ck_tile/06_permute/CMakeLists.txt @@ -1,6 +1,7 @@ # not using add_example_executable() to add this target, since we don't want this to have # to be included in "make all/install/check" -add_executable(tile_example_permute EXCLUDE_FROM_ALL permute.cpp) +add_executable(tile_example_permute permute.cpp) +rocm_install(TARGETS tile_example_permute COMPONENT examples) if(NOT DEFINED PERMUTE_USE_ALTERNATIVE_IMPL) # set(PERMUTE_USE_ALTERNATIVE_IMPL false) diff --git a/example/ck_tile/09_topk_softmax/CMakeLists.txt b/example/ck_tile/09_topk_softmax/CMakeLists.txt index b43b989792..fc2a4d3fe0 100644 --- a/example/ck_tile/09_topk_softmax/CMakeLists.txt +++ b/example/ck_tile/09_topk_softmax/CMakeLists.txt @@ -1,6 +1,7 @@ -add_executable(tile_example_topk_softmax EXCLUDE_FROM_ALL topk_softmax.cpp topk_softmax_api.cpp) -target_include_directories(tile_example_topk_softmax PRIVATE ${CMAKE_CURRENT_SOURCE_DIR}/) +add_executable(tile_example_topk_softmax topk_softmax.cpp topk_softmax_api.cpp) +rocm_install(TARGETS tile_example_topk_softmax COMPONENT examples) +target_include_directories(tile_example_topk_softmax PRIVATE ${CMAKE_CURRENT_SOURCE_DIR}/) set(EXAMPLE_TOPK_SOFTMAX_COMPILE_OPTIONS) # NOTE: we turn off undefined-func-template to let source compile without explicit declare function specializations list(APPEND EXAMPLE_TOPK_SOFTMAX_COMPILE_OPTIONS -Wno-undefined-func-template -Wno-float-equal) diff --git a/example/ck_tile/10_rmsnorm2d/CMakeLists.txt b/example/ck_tile/10_rmsnorm2d/CMakeLists.txt index 5684c9b2e0..731ff639a4 100644 --- a/example/ck_tile/10_rmsnorm2d/CMakeLists.txt +++ b/example/ck_tile/10_rmsnorm2d/CMakeLists.txt @@ -26,7 +26,8 @@ add_custom_command( 
set(TILE_RMSNORM2D_FWD "tile_rmsnorm2d_fwd") message("adding ${TILE_RMSNORM2D_FWD}") -add_executable(${TILE_RMSNORM2D_FWD} EXCLUDE_FROM_ALL rmsnorm2d_fwd.cpp) +add_executable(${TILE_RMSNORM2D_FWD} rmsnorm2d_fwd.cpp) +rocm_install(TARGETS ${TILE_RMSNORM2D_FWD} COMPONENT examples) target_include_directories(${TILE_RMSNORM2D_FWD} PRIVATE ${CMAKE_CURRENT_LIST_DIR}) target_sources(${TILE_RMSNORM2D_FWD} PRIVATE ${RMSNORM2D_FWD_GEN_BLOBS}) @@ -38,7 +39,8 @@ list(APPEND TILE_RMSNORM2D_FWD_COMPILE_OPTIONS -Wno-undefined-func-template -Wno target_compile_options(${TILE_RMSNORM2D_FWD} PRIVATE ${TILE_RMSNORM2D_FWD_COMPILE_OPTIONS}) set(EXAMPLE_RMSNORM2D_FWD "tile_example_rmsnorm2d_fwd") -add_executable(${EXAMPLE_RMSNORM2D_FWD} EXCLUDE_FROM_ALL example_rmsnorm2d_fwd.cpp) +add_executable(${EXAMPLE_RMSNORM2D_FWD} example_rmsnorm2d_fwd.cpp) +rocm_install(TARGETS ${EXAMPLE_RMSNORM2D_FWD} COMPONENT examples) target_compile_options(${EXAMPLE_RMSNORM2D_FWD} PRIVATE ${TILE_RMSNORM2D_FWD_COMPILE_OPTIONS}) # TODO: we have to turn off this global prop, otherwise the progress bar generated diff --git a/example/ck_tile/11_add_rmsnorm2d_rdquant/CMakeLists.txt b/example/ck_tile/11_add_rmsnorm2d_rdquant/CMakeLists.txt index 6b0c3cef7a..7071127e01 100644 --- a/example/ck_tile/11_add_rmsnorm2d_rdquant/CMakeLists.txt +++ b/example/ck_tile/11_add_rmsnorm2d_rdquant/CMakeLists.txt @@ -3,7 +3,8 @@ set(TILE_ADD_RMSNORM2D_RDQUANT_FWD "tile_add_rmsnorm2d_rdquant_fwd") # to be included in "make all/install/check" message("adding ${TILE_ADD_RMSNORM2D_RDQUANT_FWD}") file(GLOB INSTANCE_SRCS instances/*.cpp) -add_executable(${TILE_ADD_RMSNORM2D_RDQUANT_FWD} EXCLUDE_FROM_ALL add_rmsnorm2d_rdquant_fwd.cpp) +add_executable(${TILE_ADD_RMSNORM2D_RDQUANT_FWD} add_rmsnorm2d_rdquant_fwd.cpp) +rocm_install(TARGETS ${TILE_ADD_RMSNORM2D_RDQUANT_FWD} COMPONENT examples) target_include_directories(${TILE_ADD_RMSNORM2D_RDQUANT_FWD} PRIVATE ${CMAKE_CURRENT_LIST_DIR}) target_sources(${TILE_ADD_RMSNORM2D_RDQUANT_FWD} PRIVATE 
${INSTANCE_SRCS}) @@ -15,7 +16,8 @@ list(APPEND TILE_ADD_RMSNORM2D_RDQUANT_FWD_COMPILE_OPTIONS -Wno-undefined-func-t target_compile_options(${TILE_ADD_RMSNORM2D_RDQUANT_FWD} PRIVATE ${TILE_ADD_RMSNORM2D_RDQUANT_FWD_COMPILE_OPTIONS}) set(EXAMPLE_ADD_RMSNORM2D_RDQUANT_FWD "tile_example_add_rmsnorm2d_rdquant_fwd") -add_executable(${EXAMPLE_ADD_RMSNORM2D_RDQUANT_FWD} EXCLUDE_FROM_ALL example_add_rmsnorm2d_rdquant_fwd.cpp) +add_executable(${EXAMPLE_ADD_RMSNORM2D_RDQUANT_FWD} example_add_rmsnorm2d_rdquant_fwd.cpp) +rocm_install(TARGETS ${EXAMPLE_ADD_RMSNORM2D_RDQUANT_FWD} COMPONENT examples) target_compile_options(${EXAMPLE_ADD_RMSNORM2D_RDQUANT_FWD} PRIVATE ${TILE_ADD_RMSNORM2D_RDQUANT_FWD_COMPILE_OPTIONS}) # TODO: we have to turn off this global prop, otherwise the progress bar generated diff --git a/example/ck_tile/11_add_rmsnorm2d_rdquant/add_rmsnorm2d_rdquant_fwd.cpp b/example/ck_tile/11_add_rmsnorm2d_rdquant/add_rmsnorm2d_rdquant_fwd.cpp index 574edf64d3..7d82a16aa9 100644 --- a/example/ck_tile/11_add_rmsnorm2d_rdquant/add_rmsnorm2d_rdquant_fwd.cpp +++ b/example/ck_tile/11_add_rmsnorm2d_rdquant/add_rmsnorm2d_rdquant_fwd.cpp @@ -67,13 +67,14 @@ bool run(const ck_tile::ArgParser& arg_parser) using TypeConfig = AddRmsnormRdquantTypeConfig; - using ADataType = typename TypeConfig::ADataType; - using BDataType = typename TypeConfig::BDataType; - using GammaDataType = typename TypeConfig::GammaDataType; - using XDataType = typename TypeConfig::XDataType; - using YScaleDataType = typename TypeConfig::YScaleDataType; - using QYDataType = typename TypeConfig::QYDataType; - using ComputeDataType = float; + using ADataType = typename TypeConfig::ADataType; + using BDataType = typename TypeConfig::BDataType; + using GammaDataType = typename TypeConfig::GammaDataType; + using XDataType = typename TypeConfig::XDataType; + using UnquantYDataType = ck_tile::null_type; + using YScaleDataType = typename TypeConfig::YScaleDataType; + using QYDataType = typename 
TypeConfig::QYDataType; + using ComputeDataType = float; // host verify ck_tile::HostTensor a_host({m, n}, {stride, 1}); @@ -88,6 +89,7 @@ bool run(const ck_tile::ArgParser& arg_parser) ck_tile::HostTensor qy_host_ref({m, n}, {stride, 1}); ck_tile::HostTensor qy_host_dev({m, n}, {stride, 1}); + ck_tile::HostTensor unquant_y_host_ref({m, n}, {stride, 1}); ck_tile::FillUniformDistribution{-.5f, .5f}(a_host); ck_tile::FillUniformDistribution{-.5f, .5f}(b_host); @@ -191,8 +193,9 @@ bool run(const ck_tile::ArgParser& arg_parser) GammaDataType, ComputeDataType, YDataType, - InvRmsDataType>( - x_host_ref, gamma_host, y_host, invRms_host_ref, epsilon); + InvRmsDataType, + UnquantYDataType>( + x_host_ref, gamma_host, y_host, invRms_host_ref, unquant_y_host_ref, epsilon); } // yscale diff --git a/example/ck_tile/11_add_rmsnorm2d_rdquant/example_add_rmsnorm2d_rdquant_fwd.cpp b/example/ck_tile/11_add_rmsnorm2d_rdquant/example_add_rmsnorm2d_rdquant_fwd.cpp old mode 100644 new mode 100755 index ada4c6f2da..3aab357909 --- a/example/ck_tile/11_add_rmsnorm2d_rdquant/example_add_rmsnorm2d_rdquant_fwd.cpp +++ b/example/ck_tile/11_add_rmsnorm2d_rdquant/example_add_rmsnorm2d_rdquant_fwd.cpp @@ -62,13 +62,14 @@ bool run(const ck_tile::ArgParser& arg_parser) assert(stride >= n); - using ADataType = DataType; - using BDataType = DataType; - using GammaDataType = DataType; - using XDataType = DataType; - using YScaleDataType = float; - using QYDataType = ck_tile::int8_t; - using ComputeDataType = float; + using ADataType = DataType; + using BDataType = DataType; + using GammaDataType = DataType; + using XDataType = DataType; + using UnquantYDataType = ck_tile::null_type; + using YScaleDataType = float; + using QYDataType = ck_tile::int8_t; + using ComputeDataType = float; // host verify ck_tile::HostTensor a_host({m, n}, {stride, 1}); @@ -81,6 +82,7 @@ bool run(const ck_tile::ArgParser& arg_parser) ck_tile::HostTensor yscale_host_dev({m}, {1}); ck_tile::HostTensor qy_host_ref({m, n}, 
{stride, 1}); ck_tile::HostTensor qy_host_dev({m, n}, {stride, 1}); + ck_tile::HostTensor unquant_y_host_ref({m, n}, {stride, 1}); ck_tile::FillUniformDistribution{-.5f, .5f}(a_host); ck_tile::FillUniformDistribution{-.5f, .5f}(b_host); @@ -193,8 +195,9 @@ bool run(const ck_tile::ArgParser& arg_parser) GammaDataType, ComputeDataType, YDataType, - InvRmsDataType>( - x_host_ref, gamma_host, y_host, invRms_host_ref, epsilon); + InvRmsDataType, + UnquantYDataType>( + x_host_ref, gamma_host, y_host, invRms_host_ref, unquant_y_host_ref, epsilon); } // yscale diff --git a/example/ck_tile/12_smoothquant/CMakeLists.txt b/example/ck_tile/12_smoothquant/CMakeLists.txt index 3849833aca..daeeb827bd 100644 --- a/example/ck_tile/12_smoothquant/CMakeLists.txt +++ b/example/ck_tile/12_smoothquant/CMakeLists.txt @@ -2,7 +2,8 @@ function (add_smoothquant_example TARGET_NAME MAIN_SRC) message("adding ${TARGET_NAME}") # not using add_example_executable() to add target, since we don't want this to have # to be included in "make all/install/check" - add_executable(${TARGET_NAME} EXCLUDE_FROM_ALL ${MAIN_SRC}) + add_executable(${TARGET_NAME} ${MAIN_SRC}) + rocm_install(TARGETS ${TARGET_NAME} COMPONENT examples) target_include_directories(${TARGET_NAME} PRIVATE ${CMAKE_CURRENT_LIST_DIR}) foreach(source IN LISTS ARGN) diff --git a/example/ck_tile/13_moe_sorting/CMakeLists.txt b/example/ck_tile/13_moe_sorting/CMakeLists.txt index 09f3e4ac4e..662e16f0d3 100644 --- a/example/ck_tile/13_moe_sorting/CMakeLists.txt +++ b/example/ck_tile/13_moe_sorting/CMakeLists.txt @@ -1,4 +1,5 @@ -add_executable(tile_example_moe_sorting EXCLUDE_FROM_ALL moe_sorting.cpp moe_sorting_api.cpp) +add_executable(tile_example_moe_sorting moe_sorting.cpp moe_sorting_api.cpp) +rocm_install(TARGETS tile_example_moe_sorting COMPONENT examples) target_include_directories(tile_example_moe_sorting PRIVATE ${CMAKE_CURRENT_SOURCE_DIR}/) set(EXAMPLE_MOE_SORTING_COMPILE_OPTIONS) diff --git 
a/example/ck_tile/14_moe_smoothquant/CMakeLists.txt b/example/ck_tile/14_moe_smoothquant/CMakeLists.txt index 12224a39a2..9acb27552a 100644 --- a/example/ck_tile/14_moe_smoothquant/CMakeLists.txt +++ b/example/ck_tile/14_moe_smoothquant/CMakeLists.txt @@ -2,7 +2,8 @@ function (add_moe_smoothquant_example TARGET_NAME MAIN_SRC) message("adding ${TARGET_NAME}") # not using add_example_executable() to add target, since we don't want this to have # to be included in "make all/install/check" - add_executable(${TARGET_NAME} EXCLUDE_FROM_ALL ${MAIN_SRC}) + add_executable(${TARGET_NAME} ${MAIN_SRC}) + rocm_install(TARGETS ${TARGET_NAME} COMPONENT examples) target_include_directories(${TARGET_NAME} PRIVATE ${CMAKE_CURRENT_LIST_DIR}) foreach(source IN LISTS ARGN) diff --git a/example/ck_tile/15_fused_moe/CMakeLists.txt b/example/ck_tile/15_fused_moe/CMakeLists.txt index a716eef19e..bb25a55c7d 100644 --- a/example/ck_tile/15_fused_moe/CMakeLists.txt +++ b/example/ck_tile/15_fused_moe/CMakeLists.txt @@ -3,7 +3,8 @@ set(TILE_EXAPMLE_FUSED_MOE "tile_example_fused_moe") # to be included in "make all/install/check" message("adding ${TILE_EXAPMLE_FUSED_MOE}") file(GLOB INSTANCE_SRCS instances/*.cpp) -add_executable(${TILE_EXAPMLE_FUSED_MOE} EXCLUDE_FROM_ALL main.cpp) +add_executable(${TILE_EXAPMLE_FUSED_MOE} main.cpp) +rocm_install(TARGETS ${TILE_EXAPMLE_FUSED_MOE} COMPONENT examples) target_include_directories(${TILE_EXAPMLE_FUSED_MOE} PRIVATE ${CMAKE_CURRENT_LIST_DIR}) target_sources(${TILE_EXAPMLE_FUSED_MOE} PRIVATE ${INSTANCE_SRCS}) diff --git a/example/ck_tile/16_batched_gemm/CMakeLists.txt b/example/ck_tile/16_batched_gemm/CMakeLists.txt index 78e78c6b04..9eb7a45d80 100644 --- a/example/ck_tile/16_batched_gemm/CMakeLists.txt +++ b/example/ck_tile/16_batched_gemm/CMakeLists.txt @@ -1 +1,2 @@ -add_executable(tile_example_batched_gemm EXCLUDE_FROM_ALL batched_gemm.cpp) +add_executable(tile_example_batched_gemm batched_gemm.cpp) +rocm_install(TARGETS tile_example_batched_gemm 
COMPONENT examples) diff --git a/example/ck_tile/17_grouped_gemm/CMakeLists.txt b/example/ck_tile/17_grouped_gemm/CMakeLists.txt index d34013dd6c..80d688125b 100644 --- a/example/ck_tile/17_grouped_gemm/CMakeLists.txt +++ b/example/ck_tile/17_grouped_gemm/CMakeLists.txt @@ -1,2 +1,2 @@ -add_executable(tile_example_grouped_gemm EXCLUDE_FROM_ALL grouped_gemm.cpp) - +add_executable(tile_example_grouped_gemm grouped_gemm.cpp) +rocm_install(TARGETS tile_example_grouped_gemm COMPONENT examples) diff --git a/example/ck_tile/18_flatmm/CMakeLists.txt b/example/ck_tile/18_flatmm/CMakeLists.txt index 9fbe65e3a7..3a70f0447d 100644 --- a/example/ck_tile/18_flatmm/CMakeLists.txt +++ b/example/ck_tile/18_flatmm/CMakeLists.txt @@ -1,4 +1,6 @@ -add_executable(tile_example_flatmm_basic EXCLUDE_FROM_ALL flatmm_basic.cpp) +add_executable(tile_example_flatmm_basic flatmm_basic.cpp) +rocm_install(TARGETS tile_example_flatmm_basic COMPONENT examples) + set(EXAMPLE_FLATMM_COMPILE_OPTIONS) # list(APPEND EXAMPLE_FLATMM_COMPILE_OPTIONS -Wno-undefined-func-template -Wno-float-equal) diff --git a/example/ck_tile/35_batched_transpose/CMakeLists.txt b/example/ck_tile/35_batched_transpose/CMakeLists.txt index a08fcebb74..10101e4d2e 100644 --- a/example/ck_tile/35_batched_transpose/CMakeLists.txt +++ b/example/ck_tile/35_batched_transpose/CMakeLists.txt @@ -1,9 +1,9 @@ set(TARGET_NAME tile_example_batched_transpose) -add_executable(${TARGET_NAME} EXCLUDE_FROM_ALL batched_transpose_example.cpp batched_transpose_api.cpp) +add_executable(${TARGET_NAME} batched_transpose_example.cpp batched_transpose_api.cpp) +rocm_install(TARGETS ${TARGET_NAME} COMPONENT examples) target_include_directories(${TARGET_NAME} PRIVATE ${CMAKE_CURRENT_SOURCE_DIR}/) # NOTE: we turn off undefined-func-template to let source compile without explicit declare function specializations list(APPEND EXAMPLE_BATCHED_TRANSPOSE_COMPILE_OPTIONS -Wno-undefined-func-template -Wno-float-equal) # list(APPEND 
EXAMPLE_BATCHED_TRANSPOSE_COMPILE_OPTIONS -v --save-temps -Wno-gnu-line-marker) target_compile_options(tile_example_batched_transpose PRIVATE ${EXAMPLE_BATCHED_TRANSPOSE_COMPILE_OPTIONS}) - diff --git a/example/ck_tile/CMakeLists.txt b/example/ck_tile/CMakeLists.txt index 88efe0d8d9..16f68c6255 100644 --- a/example/ck_tile/CMakeLists.txt +++ b/example/ck_tile/CMakeLists.txt @@ -14,8 +14,11 @@ add_subdirectory(11_add_rmsnorm2d_rdquant) add_subdirectory(12_smoothquant) add_subdirectory(13_moe_sorting) add_subdirectory(14_moe_smoothquant) -add_subdirectory(15_fused_moe) add_subdirectory(16_batched_gemm) add_subdirectory(17_grouped_gemm) add_subdirectory(18_flatmm) add_subdirectory(35_batched_transpose) + +if (SUPPORTED_GPU_TARGETS MATCHES "gfx94") + add_subdirectory(15_fused_moe) +endif() diff --git a/include/ck_tile/ops/flatmm/pipeline/flatmm_pipeline_agmem_bgmem_creg_v1.hpp b/include/ck_tile/ops/flatmm/pipeline/flatmm_pipeline_agmem_bgmem_creg_v1.hpp index 611aff318f..ad6641bc13 100644 --- a/include/ck_tile/ops/flatmm/pipeline/flatmm_pipeline_agmem_bgmem_creg_v1.hpp +++ b/include/ck_tile/ops/flatmm/pipeline/flatmm_pipeline_agmem_bgmem_creg_v1.hpp @@ -6,6 +6,7 @@ #include "ck_tile/core.hpp" #include "ck_tile/host/concat.hpp" #include "ck_tile/ops/flatmm/pipeline/flatmm_pipeline_agmem_bgmem_creg_v1_policy.hpp" +#include "ck_tile/host/concat.hpp" namespace ck_tile { diff --git a/include/ck_tile/ops/gemm/pipeline/gemm_pipeline_problem.hpp b/include/ck_tile/ops/gemm/pipeline/gemm_pipeline_problem.hpp index 0b38e7789e..893c9d1ad3 100644 --- a/include/ck_tile/ops/gemm/pipeline/gemm_pipeline_problem.hpp +++ b/include/ck_tile/ops/gemm/pipeline/gemm_pipeline_problem.hpp @@ -30,8 +30,7 @@ struct GemmPipelineProblemBase using BLayout = remove_cvref_t; using CLayout = remove_cvref_t; - static constexpr bool TransposeC = Traits::TransposeC; - + static constexpr bool TransposeC = Traits::TransposeC; static constexpr bool UseStructuredSparsity = Traits::UseStructuredSparsity; 
static constexpr index_t kBlockSize = BlockGemmShape::NumWarps * get_warp_size(); diff --git a/include/ck_tile/ops/gemm/pipeline/tile_gemm_traits.hpp b/include/ck_tile/ops/gemm/pipeline/tile_gemm_traits.hpp index a31004b425..ecf861e4e8 100644 --- a/include/ck_tile/ops/gemm/pipeline/tile_gemm_traits.hpp +++ b/include/ck_tile/ops/gemm/pipeline/tile_gemm_traits.hpp @@ -12,7 +12,8 @@ template + typename CLayout_, + bool UseStructuredSparsity_ = false> struct TileGemmTraits { static constexpr bool kPadM = kPadM_; @@ -27,7 +28,7 @@ struct TileGemmTraits using CLayout = CLayout_; static constexpr bool TransposeC = false; - static constexpr bool UseStructuredSparsity = false; + static constexpr bool UseStructuredSparsity = UseStructuredSparsity_; }; template Date: Mon, 28 Apr 2025 16:40:22 -0400 Subject: [PATCH 076/443] Check max-ilp-scheduling compiler option for moe_gemm examples (#2127) --- example/65_gemm_multiply_multiply/CMakeLists.txt | 5 ++++- 1 file changed, 4 insertions(+), 1 deletion(-) diff --git a/example/65_gemm_multiply_multiply/CMakeLists.txt b/example/65_gemm_multiply_multiply/CMakeLists.txt index 3c1947c058..5d2a097576 100644 --- a/example/65_gemm_multiply_multiply/CMakeLists.txt +++ b/example/65_gemm_multiply_multiply/CMakeLists.txt @@ -15,7 +15,10 @@ foreach(gpu IN LISTS GPU_TARGETS) add_example_executable(example_moe_gemm2_xdl_pk_i4 moe_gemm2_xdl_pk_i4.cpp) if(CK_hip_VERSION VERSION_LESS_EQUAL 6.3.42132) set(EXAMPLE_COMPILE_OPTIONS) - list(APPEND EXAMPLE_COMPILE_OPTIONS -mllvm --amdgpu-enable-max-ilp-scheduling-strategy=1) + check_cxx_compiler_flag("-mllvm --amdgpu-enable-max-ilp-scheduling-strategy=1" HAS_MAX_ILP_SCHEDULING_STRATEGY) + if(HAS_MAX_ILP_SCHEDULING_STRATEGY) + list(APPEND EXAMPLE_COMPILE_OPTIONS -mllvm --amdgpu-enable-max-ilp-scheduling-strategy=1) + endif() target_compile_options(example_moe_gemm1_xdl_pk_i4 PRIVATE ${EXAMPLE_COMPILE_OPTIONS}) target_compile_options(example_moe_gemm2_xdl_pk_i4 PRIVATE ${EXAMPLE_COMPILE_OPTIONS}) endif() 
From 4094ad158a81a6c4fa0681e6d1481fb18c0d2257 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Bart=C5=82omiej=20Kocot?= Date: Mon, 28 Apr 2025 23:54:49 +0200 Subject: [PATCH 077/443] Integrate universal gemm with conv bwd data and add SplitK (#1315) * Integrate universal gemm with conv bwd data * Fix multi d kernel * Add splitK support * instances refactor * instances refactor * refactor * fixeS * fixes * 16x16 instnaces * Fixes * Fix * Fix * Fix * Fix * Fix * Fixes * fix * fix --- CHANGELOG.md | 1 + Jenkinsfile | 4 +- ...evice_grouped_conv_bwd_data_multiple_d.hpp | 5 +- ...conv_bwd_data_multiple_d_wmma_cshuffle.hpp | 27 +- ...nv_bwd_data_multiple_d_xdl_cshuffle_v1.hpp | 1110 +++++++++++++++-- ...conv_bwd_weight_two_stage_xdl_cshuffle.hpp | 2 +- ...rouped_conv_bwd_weight_xdl_cshuffle_v3.hpp | 2 +- ...=> gridwise_gemm_xdl_cshuffle_conv_v3.hpp} | 2 +- .../transform_conv_bwd_data_to_gemm_v1.hpp | 68 +- ...ice_grouped_conv_bwd_data_xdl_instance.hpp | 75 +- .../gpu/grouped_convolution_backward_data.hpp | 24 + .../grouped_convolution_backward_data_xdl.inc | 168 +++ .../grouped_conv2d_bwd_data/CMakeLists.txt | 6 + ...ta_xdl_gnhwc_gkyxc_gnhwk_bf16_instance.cpp | 28 +- ...ata_xdl_gnhwc_gkyxc_gnhwk_f16_instance.cpp | 28 +- ...ata_xdl_gnhwc_gkyxc_gnhwk_f32_instance.cpp | 28 +- ..._ngchw_gkcyx_ngkhw_bf16_16_16_instance.cpp | 40 + ...ta_xdl_ngchw_gkcyx_ngkhw_bf16_instance.cpp | 2 +- ...kcyx_ngkhw_bf16_vec_transpose_instance.cpp | 2 +- ...l_ngchw_gkcyx_ngkhw_f16_16_16_instance.cpp | 40 + ...ata_xdl_ngchw_gkcyx_ngkhw_f16_instance.cpp | 2 +- ...gkcyx_ngkhw_f16_vec_transpose_instance.cpp | 2 +- ...l_ngchw_gkcyx_ngkhw_f32_16_16_instance.cpp | 40 + ...ata_xdl_ngchw_gkcyx_ngkhw_f32_instance.cpp | 2 +- ...gkcyx_ngkhw_f32_vec_transpose_instance.cpp | 2 +- ...ta_xdl_ngchw_gkyxc_ngkhw_bf16_instance.cpp | 2 +- ...ata_xdl_ngchw_gkyxc_ngkhw_f16_instance.cpp | 2 +- ...ata_xdl_ngchw_gkyxc_ngkhw_f32_instance.cpp | 2 +- ..._nhwgc_gkyxc_nhwgk_bf16_16_16_instance.cpp | 49 + 
...ta_xdl_nhwgc_gkyxc_nhwgk_bf16_instance.cpp | 2 +- ...l_nhwgc_gkyxc_nhwgk_f16_16_16_instance.cpp | 49 + ...ata_xdl_nhwgc_gkyxc_nhwgk_f16_instance.cpp | 2 +- ...l_nhwgc_gkyxc_nhwgk_f32_16_16_instance.cpp | 49 + ...ata_xdl_nhwgc_gkyxc_nhwgk_f32_instance.cpp | 2 +- .../grouped_conv3d_bwd_data/CMakeLists.txt | 7 + ...xdl_gndhwc_gkzyxc_gndhwk_bf16_instance.cpp | 28 +- ..._xdl_gndhwc_gkzyxc_gndhwk_f16_instance.cpp | 28 +- ..._xdl_gndhwc_gkzyxc_gndhwk_f32_instance.cpp | 28 +- ...hwgc_gkzyxc_ndhwgk_bf16_16_16_instance.cpp | 49 + ...xdl_ndhwgc_gkzyxc_ndhwgk_bf16_instance.cpp | 5 +- ...dhwgc_gkzyxc_ndhwgk_f16_16_16_instance.cpp | 49 + ..._xdl_ndhwgc_gkzyxc_ndhwgk_f16_instance.cpp | 5 +- ...dhwgc_gkzyxc_ndhwgk_f32_16_16_instance.cpp | 49 + ..._xdl_ndhwgc_gkzyxc_ndhwgk_f32_instance.cpp | 5 +- ..._ndhwgk_input_f16_comp_bf8_f8_instance.cpp | 5 +- ...cdhw_gkczyx_ngkdhw_bf16_16_16_instance.cpp | 40 + ...xdl_ngcdhw_gkczyx_ngkdhw_bf16_instance.cpp | 3 +- ...zyx_ngkdhw_bf16_vec_transpose_instance.cpp | 3 +- ...gcdhw_gkczyx_ngkdhw_f16_16_16_instance.cpp | 40 + ..._xdl_ngcdhw_gkczyx_ngkdhw_f16_instance.cpp | 3 +- ...czyx_ngkdhw_f16_vec_transpose_instance.cpp | 3 +- ...gcdhw_gkczyx_ngkdhw_f32_16_16_instance.cpp | 40 + ..._xdl_ngcdhw_gkczyx_ngkdhw_f32_instance.cpp | 3 +- ...czyx_ngkdhw_f32_vec_transpose_instance.cpp | 3 +- ...xdl_ngcdhw_gkzyxc_ngkdhw_bf16_instance.cpp | 3 +- ..._xdl_ngcdhw_gkzyxc_ngkdhw_f16_instance.cpp | 3 +- ..._xdl_ngcdhw_gkzyxc_ngkdhw_f32_instance.cpp | 3 +- ...ear_ndhwgc_gkzyxc_ndhwgk_bf16_instance.cpp | 5 +- ...near_ndhwgc_gkzyxc_ndhwgk_f16_instance.cpp | 5 +- ...near_ndhwgc_gkzyxc_ndhwgk_f32_instance.cpp | 5 +- ...ale_ndhwgc_gkzyxc_ndhwgk_bf16_instance.cpp | 5 +- ...cale_ndhwgc_gkzyxc_ndhwgk_f16_instance.cpp | 5 +- ...cale_ndhwgc_gkzyxc_ndhwgk_f32_instance.cpp | 5 +- .../profile_grouped_conv_bwd_data_impl.hpp | 101 +- .../src/profile_grouped_conv_bwd_data.cpp | 8 +- script/convert_miopen_driver_to_profiler.py | 3 + test/grouped_convnd_bwd_data/CMakeLists.txt | 5 
+ .../test_grouped_convnd_bwd_data_xdl.cpp | 70 +- ...rouped_convnd_bwd_data_xdl_large_cases.cpp | 120 ++ 69 files changed, 2262 insertions(+), 349 deletions(-) rename include/ck/tensor_operation/gpu/grid/{gridwise_gemm_xdl_cshuffle_bwd_weight_v3.hpp => gridwise_gemm_xdl_cshuffle_conv_v3.hpp} (99%) create mode 100644 library/src/tensor_operation_instance/gpu/grouped_conv2d_bwd_data/xdl/device_grouped_conv2d_bwd_data_xdl_ngchw_gkcyx_ngkhw_bf16_16_16_instance.cpp create mode 100644 library/src/tensor_operation_instance/gpu/grouped_conv2d_bwd_data/xdl/device_grouped_conv2d_bwd_data_xdl_ngchw_gkcyx_ngkhw_f16_16_16_instance.cpp create mode 100644 library/src/tensor_operation_instance/gpu/grouped_conv2d_bwd_data/xdl/device_grouped_conv2d_bwd_data_xdl_ngchw_gkcyx_ngkhw_f32_16_16_instance.cpp create mode 100644 library/src/tensor_operation_instance/gpu/grouped_conv2d_bwd_data/xdl/device_grouped_conv2d_bwd_data_xdl_nhwgc_gkyxc_nhwgk_bf16_16_16_instance.cpp create mode 100644 library/src/tensor_operation_instance/gpu/grouped_conv2d_bwd_data/xdl/device_grouped_conv2d_bwd_data_xdl_nhwgc_gkyxc_nhwgk_f16_16_16_instance.cpp create mode 100644 library/src/tensor_operation_instance/gpu/grouped_conv2d_bwd_data/xdl/device_grouped_conv2d_bwd_data_xdl_nhwgc_gkyxc_nhwgk_f32_16_16_instance.cpp create mode 100644 library/src/tensor_operation_instance/gpu/grouped_conv3d_bwd_data/xdl/device_grouped_conv3d_bwd_data_xdl_ndhwgc_gkzyxc_ndhwgk_bf16_16_16_instance.cpp create mode 100644 library/src/tensor_operation_instance/gpu/grouped_conv3d_bwd_data/xdl/device_grouped_conv3d_bwd_data_xdl_ndhwgc_gkzyxc_ndhwgk_f16_16_16_instance.cpp create mode 100644 library/src/tensor_operation_instance/gpu/grouped_conv3d_bwd_data/xdl/device_grouped_conv3d_bwd_data_xdl_ndhwgc_gkzyxc_ndhwgk_f32_16_16_instance.cpp create mode 100644 library/src/tensor_operation_instance/gpu/grouped_conv3d_bwd_data/xdl/device_grouped_conv3d_bwd_data_xdl_ngcdhw_gkczyx_ngkdhw_bf16_16_16_instance.cpp create mode 100644 
library/src/tensor_operation_instance/gpu/grouped_conv3d_bwd_data/xdl/device_grouped_conv3d_bwd_data_xdl_ngcdhw_gkczyx_ngkdhw_f16_16_16_instance.cpp create mode 100644 library/src/tensor_operation_instance/gpu/grouped_conv3d_bwd_data/xdl/device_grouped_conv3d_bwd_data_xdl_ngcdhw_gkczyx_ngkdhw_f32_16_16_instance.cpp create mode 100644 test/grouped_convnd_bwd_data/test_grouped_convnd_bwd_data_xdl_large_cases.cpp diff --git a/CHANGELOG.md b/CHANGELOG.md index b9012c0a77..e0ec214c69 100644 --- a/CHANGELOG.md +++ b/CHANGELOG.md @@ -15,6 +15,7 @@ Documentation for Composable Kernel available at [https://rocm.docs.amd.com/proj * Added support for Stream-K version of mixed fp8/bf16 GEMM * Added GEMM pipeline for microscaling (MX) data types * Added support for FP16 2:4 structured sparsity to universal GEMM. +* Added support for Split K for grouped convolution backward data. ### Optimized diff --git a/Jenkinsfile b/Jenkinsfile index f8043ba918..a18374509e 100644 --- a/Jenkinsfile +++ b/Jenkinsfile @@ -937,8 +937,8 @@ pipeline { environment{ setup_args = "NO_CK_BUILD" execute_args = """ ../script/cmake-ck-dev.sh ../ gfx90a && \ - make -j64 test_grouped_convnd_fwd_large_cases_xdl && \ - ./bin/test_grouped_convnd_fwd_large_cases_xdl""" + make -j64 test_grouped_convnd_fwd_large_cases_xdl test_grouped_convnd_bwd_data_xdl_large_cases && \ + ./bin/test_grouped_convnd_fwd_large_cases_xdl && ./bin/test_grouped_convnd_bwd_data_xdl_large_cases""" } steps{ buildHipClangJobAndReboot(setup_args:setup_args, no_reboot:true, build_type: 'Release', execute_cmd: execute_args) diff --git a/include/ck/tensor_operation/gpu/device/device_grouped_conv_bwd_data_multiple_d.hpp b/include/ck/tensor_operation/gpu/device/device_grouped_conv_bwd_data_multiple_d.hpp index 2abf1d5a10..9c44bda5ca 100644 --- a/include/ck/tensor_operation/gpu/device/device_grouped_conv_bwd_data_multiple_d.hpp +++ b/include/ck/tensor_operation/gpu/device/device_grouped_conv_bwd_data_multiple_d.hpp @@ -1,5 +1,5 @@ // 
SPDX-License-Identifier: MIT -// Copyright (c) 2018-2023, Advanced Micro Devices, Inc. All rights reserved. +// Copyright (c) 2018-2025, Advanced Micro Devices, Inc. All rights reserved. #pragma once @@ -59,7 +59,8 @@ struct DeviceGroupedConvBwdDataMultipleD : public BaseOperator const std::array& input_right_pads, const AElementwiseOperation& a_element_op, const BElementwiseOperation& b_element_op, - const CDEElementwiseOperation& cde_element_op) = 0; + const CDEElementwiseOperation& cde_element_op, + const ck::index_t split_k = 1) = 0; virtual std::unique_ptr MakeInvokerPointer() = 0; }; diff --git a/include/ck/tensor_operation/gpu/device/impl/device_grouped_conv_bwd_data_multiple_d_wmma_cshuffle.hpp b/include/ck/tensor_operation/gpu/device/impl/device_grouped_conv_bwd_data_multiple_d_wmma_cshuffle.hpp index 359711e5c4..5e41c96dfc 100644 --- a/include/ck/tensor_operation/gpu/device/impl/device_grouped_conv_bwd_data_multiple_d_wmma_cshuffle.hpp +++ b/include/ck/tensor_operation/gpu/device/impl/device_grouped_conv_bwd_data_multiple_d_wmma_cshuffle.hpp @@ -1,5 +1,5 @@ // SPDX-License-Identifier: MIT -// Copyright (c) 2023-2024, Advanced Micro Devices, Inc. All rights reserved. +// Copyright (c) 2023-2025, Advanced Micro Devices, Inc. All rights reserved. 
#pragma once @@ -227,7 +227,8 @@ struct DeviceGroupedConvBwdDataMultipleD_Wmma_CShuffle const std::array& input_right_pads, const AElementwiseOp& a_element_op, const BElementwiseOp& b_element_op, - const CDEElementwiseOp& cde_element_op) + const CDEElementwiseOp& cde_element_op, + const ck::index_t split_k = 1) : p_a_grid_{static_cast(p_a)}, p_b_grid_{static_cast(p_b)}, p_ds_grid_{}, @@ -240,7 +241,8 @@ struct DeviceGroupedConvBwdDataMultipleD_Wmma_CShuffle b_g_k_c_xs_lengths_{b_g_k_c_xs_lengths}, conv_filter_strides_{conv_filter_strides}, input_left_pads_{input_left_pads}, - input_right_pads_{input_right_pads} + input_right_pads_{input_right_pads}, + k_batch_{split_k} { // populate Ds pointer static_for<0, NumDTensor, 1>{}([&](auto i) { @@ -445,6 +447,8 @@ struct DeviceGroupedConvBwdDataMultipleD_Wmma_CShuffle std::array conv_filter_strides_; std::array input_left_pads_; std::array input_right_pads_; + + const index_t k_batch_; }; // Invoker @@ -534,6 +538,11 @@ struct DeviceGroupedConvBwdDataMultipleD_Wmma_CShuffle static bool IsSupportedArgument(const Argument& arg) { + if(arg.k_batch_ != 1) + { + return false; + } + // check device if(ck::is_gfx11_supported() || ck::is_gfx12_supported()) { @@ -691,7 +700,8 @@ struct DeviceGroupedConvBwdDataMultipleD_Wmma_CShuffle const std::array& input_right_pads, const AElementwiseOp& a_element_op, const BElementwiseOp& b_element_op, - const CDEElementwiseOp& cde_element_op) + const CDEElementwiseOp& cde_element_op, + const ck::index_t split_k = 1) { return Argument{p_a, p_b, @@ -711,7 +721,8 @@ struct DeviceGroupedConvBwdDataMultipleD_Wmma_CShuffle input_right_pads, a_element_op, b_element_op, - cde_element_op}; + cde_element_op, + split_k}; } static auto MakeInvoker() { return Invoker{}; } @@ -737,7 +748,8 @@ struct DeviceGroupedConvBwdDataMultipleD_Wmma_CShuffle const std::array& input_right_pads, const AElementwiseOp& a_element_op, const BElementwiseOp& b_element_op, - const CDEElementwiseOp& cde_element_op) override + 
const CDEElementwiseOp& cde_element_op, + const ck::index_t split_k = 1) override { return std::make_unique(p_a, p_b, @@ -757,7 +769,8 @@ struct DeviceGroupedConvBwdDataMultipleD_Wmma_CShuffle input_right_pads, a_element_op, b_element_op, - cde_element_op); + cde_element_op, + split_k); } std::unique_ptr MakeInvokerPointer() override diff --git a/include/ck/tensor_operation/gpu/device/impl/device_grouped_conv_bwd_data_multiple_d_xdl_cshuffle_v1.hpp b/include/ck/tensor_operation/gpu/device/impl/device_grouped_conv_bwd_data_multiple_d_xdl_cshuffle_v1.hpp index 08edddf107..3028cd7cbc 100644 --- a/include/ck/tensor_operation/gpu/device/impl/device_grouped_conv_bwd_data_multiple_d_xdl_cshuffle_v1.hpp +++ b/include/ck/tensor_operation/gpu/device/impl/device_grouped_conv_bwd_data_multiple_d_xdl_cshuffle_v1.hpp @@ -15,12 +15,15 @@ #include "ck/tensor_operation/gpu/device/device_grouped_conv_bwd_data_multiple_d.hpp" #include "ck/tensor_operation/gpu/device/convolution_backward_data_specialization.hpp" #include "ck/tensor_operation/operator_transform/transform_conv_bwd_data_to_gemm_v1.hpp" +#include "ck/tensor_operation/gpu/grid/gridwise_gemm_xdl_cshuffle_conv_v3.hpp" +#include "ck/tensor_operation/gpu/device/gemm_specialization.hpp" #include "ck/tensor_operation/operator_transform/transform_conv_ngchw_to_nhwgc.hpp" #include "ck/tensor_operation/gpu/grid/gridwise_gemm_multiple_d_xdl_cshuffle.hpp" #include "ck/tensor_operation/gpu/grid/gridwise_elementwise_2d.hpp" #include "ck/tensor_operation/gpu/device/impl/device_grouped_conv_utils.hpp" #include "ck/host_utility/device_prop.hpp" #include "ck/host_utility/kernel_launch.hpp" +#include "ck/host_utility/flush_cache.hpp" #include "ck/host_utility/io.hpp" namespace ck { @@ -151,6 +154,153 @@ __global__ void #endif } +template +__global__ void +#if CK_USE_LAUNCH_BOUNDS + __launch_bounds__(CK_MAX_THREAD_PER_BLOCK, MinimumOccupancy) +#endif + // __attribute__((amdgpu_waves_per_eu(1, 1))) + 
kernel_grouped_conv_bwd_data_xdl_cshuffle_v3( + typename GridwiseGemm::Argument karg, + const AGridDesc_AK0_M_K1 a_grid_desc_ak0_m_ak1, + const BGridDesc_BK0_N_K1 b_grid_desc_bk0_n_bk1, + const CGridDesc_MBlock_MPerBlock_NBlock_NPerBlock + c_grid_desc_mblock_mperblock_nblock_nperblock, + const ComputePtrOffsetOfBatch compute_ptr_offset_of_batch, + const ComputePtrOffsetOfN compute_ptr_offset_of_n, + const index_t num_k_per_block) +{ +#if(!defined(__HIP_DEVICE_COMPILE__) || defined(__gfx908__) || defined(__gfx90a__) || \ + defined(__gfx940__) || defined(__gfx941__) || defined(__gfx942__)) + // offset base pointer for each work-group + const index_t g_idx = __builtin_amdgcn_readfirstlane(blockIdx.z); + const index_t n_idx = __builtin_amdgcn_readfirstlane(blockIdx.y / karg.KBatch); + const index_t k_idx = + __builtin_amdgcn_readfirstlane((blockIdx.y - n_idx * karg.KBatch) * num_k_per_block); + + const long_index_t a_batch_offset = __builtin_amdgcn_readfirstlane( + static_cast(compute_ptr_offset_of_batch.GetAPtrOffset(g_idx))); + const long_index_t b_batch_offset = __builtin_amdgcn_readfirstlane( + static_cast(compute_ptr_offset_of_batch.GetBPtrOffset(g_idx))); + const long_index_t e_batch_offset = __builtin_amdgcn_readfirstlane( + static_cast(compute_ptr_offset_of_batch.GetEPtrOffset(g_idx))); + + const long_index_t a_n_offset = + amd_wave_read_first_lane(compute_ptr_offset_of_n.GetAPtrOffset(n_idx)); + const long_index_t e_n_offset = + amd_wave_read_first_lane(compute_ptr_offset_of_n.GetEPtrOffset(n_idx)); + + __shared__ char p_shared[GridwiseGemm::GetSharedMemoryNumberOfByte()]; + + GridwiseGemm::template Run(karg.p_a_grid + a_batch_offset + a_n_offset, + karg.p_b_grid + b_batch_offset, + karg.p_c_grid + e_batch_offset + e_n_offset, + p_shared, + karg, + a_grid_desc_ak0_m_ak1, + b_grid_desc_bk0_n_bk1, + c_grid_desc_mblock_mperblock_nblock_nperblock, + k_idx); +#else + ignore = karg; + ignore = a_grid_desc_ak0_m_ak1; + ignore = b_grid_desc_bk0_n_bk1; + ignore = 
c_grid_desc_mblock_mperblock_nblock_nperblock; + ignore = compute_ptr_offset_of_batch; + ignore = compute_ptr_offset_of_n; + ignore = num_k_per_block; +#endif // end of if (defined(__gfx908__) || defined(__gfx90a__)) +} + +template +__global__ void +#if CK_USE_LAUNCH_BOUNDS + __launch_bounds__(CK_MAX_THREAD_PER_BLOCK, MinimumOccupancy) +#endif + // __attribute__((amdgpu_waves_per_eu(1, 1))) + kernel_grouped_conv_bwd_data_xdl_cshuffle_v3_2lds( + typename GridwiseGemm::Argument karg, + const AGridDesc_AK0_M_K1 a_grid_desc_ak0_m_ak1, + const BGridDesc_BK0_N_K1 b_grid_desc_bk0_n_bk1, + const CGridDesc_MBlock_MPerBlock_NBlock_NPerBlock + c_grid_desc_mblock_mperblock_nblock_nperblock, + const ComputePtrOffsetOfBatch compute_ptr_offset_of_batch, + const ComputePtrOffsetOfN compute_ptr_offset_of_n, + const index_t num_k_per_block) +{ +#if(!defined(__HIP_DEVICE_COMPILE__) || defined(__gfx908__) || defined(__gfx90a__) || \ + defined(__gfx940__) || defined(__gfx941__) || defined(__gfx942__)) + const index_t g_idx = __builtin_amdgcn_readfirstlane(blockIdx.z); + const index_t n_idx = __builtin_amdgcn_readfirstlane(blockIdx.y / karg.KBatch); + const index_t k_idx = + __builtin_amdgcn_readfirstlane((blockIdx.y - n_idx * karg.KBatch) * num_k_per_block); + + const long_index_t a_batch_offset = __builtin_amdgcn_readfirstlane( + static_cast(compute_ptr_offset_of_batch.GetAPtrOffset(g_idx))); + const long_index_t b_batch_offset = __builtin_amdgcn_readfirstlane( + static_cast(compute_ptr_offset_of_batch.GetBPtrOffset(g_idx))); + const long_index_t e_batch_offset = __builtin_amdgcn_readfirstlane( + static_cast(compute_ptr_offset_of_batch.GetEPtrOffset(g_idx))); + + const long_index_t a_n_offset = + amd_wave_read_first_lane(compute_ptr_offset_of_n.GetAPtrOffset(n_idx)); + const long_index_t e_n_offset = + amd_wave_read_first_lane(compute_ptr_offset_of_n.GetEPtrOffset(n_idx)); + + // Pass two lds pointer is the key to tell compiler that ds_read/write + // operate on different lds chunk at 
same time without order dependecy + __shared__ char p_shared_0[GridwiseGemm::GetSharedMemoryNumberOfByte()]; + __shared__ char p_shared_1[GridwiseGemm::GetSharedMemoryNumberOfByte()]; + + GridwiseGemm::template Run_2Lds(karg.p_a_grid + a_batch_offset + a_n_offset, + karg.p_b_grid + b_batch_offset, + karg.p_c_grid + e_batch_offset + e_n_offset, + p_shared_0, + p_shared_1, + karg, + a_grid_desc_ak0_m_ak1, + b_grid_desc_bk0_n_bk1, + c_grid_desc_mblock_mperblock_nblock_nperblock, + k_idx); +#else + ignore = karg; + ignore = a_grid_desc_ak0_m_ak1; + ignore = b_grid_desc_bk0_n_bk1; + ignore = c_grid_desc_mblock_mperblock_nblock_nperblock; + ignore = compute_ptr_offset_of_batch; + ignore = compute_ptr_offset_of_n; + ignore = num_k_per_block; +#endif // end of if (defined(__gfx908__) || defined(__gfx90a__)) +} + } // namespace // Conv backward data multiple D: @@ -210,7 +360,9 @@ template + index_t MaxTransposeTransferOutScalarPerVector = 1, + BlockGemmPipelineScheduler BlkGemmPipeSched = BlockGemmPipelineScheduler::Intrawave, + BlockGemmPipelineVersion BlkGemmPipelineVer = BlockGemmPipelineVersion::v1> struct DeviceGroupedConvBwdDataMultipleD_Xdl_CShuffle_v1 : public DeviceGroupedConvBwdDataMultipleD 0; + static constexpr GemmSpecialization GemmSpec = GemmSpecialization::MNKPadding; + static constexpr bool IsSplitKSupported = + (CDEBlockTransferScalarPerVector_NPerBlock % 2 == 0 || sizeof(EDataType) % 4 == 0) && + std::is_same_v, element_wise::PassThrough>; // TODO: Add support for different A and B data types. 
using ABDataType = ADataType; @@ -315,53 +472,63 @@ struct DeviceGroupedConvBwdDataMultipleD_Xdl_CShuffle_v1 a_grid_desc_ak0_m_ak1, b_grid_desc_bk0_n_bk1, ds_grid_desc_m_n, e_grid_desc_m_n); } - // GridwiseGemm - using GridwiseGemm = GridwiseGemmMultipleD_xdl_cshuffle< - ABDataType, - ABDataType, - AComputeType, - AccDataType, - CShuffleDataType, - DsDataType, - EDataType, - AElementwiseOp, - BElementwiseOp, - CDEElementwiseOp, - InMemoryDataOperationEnum::Set, - NumGemmKPrefetchStage, - BlockSize, - MPerBlock, - NPerBlock, - KPerBlock, - AK1, - BK1, - MPerXDL, - NPerXDL, - MXdlPerWave, - NXdlPerWave, - ABlockTransferThreadClusterLengths_AK0_M_AK1, - ABlockTransferThreadClusterArrangeOrder, - ABlockTransferSrcAccessOrder, - ABlockTransferSrcVectorDim, - ABlockTransferSrcScalarPerVector, - ABlockTransferDstScalarPerVector_AK1, - false, - ABlockLdsExtraM, - BBlockTransferThreadClusterLengths_BK0_N_BK1, - BBlockTransferThreadClusterArrangeOrder, - BBlockTransferSrcAccessOrder, - BBlockTransferSrcVectorDim, - BBlockTransferSrcScalarPerVector, - BBlockTransferDstScalarPerVector_BK1, - false, - BBlockLdsExtraN, - CShuffleMXdlPerWavePerShuffle, - CShuffleNXdlPerWavePerShuffle, - CDEBlockTransferClusterLengths_MBlock_MPerBlock_NBlock_NPerBlock, - CDEBlockTransferScalarPerVector_NPerBlock, - LoopSched, - PipelineVersion::v1, - BComputeType>; +// GridwiseGemm +#define GridwiseGemmMultiDTemplateParams \ + ABDataType, ABDataType, AComputeType, AccDataType, CShuffleDataType, DsDataType, EDataType, \ + AElementwiseOp, BElementwiseOp, CDEElementwiseOp, InMemoryDataOperationEnum::Set, \ + NumGemmKPrefetchStage, BlockSize, MPerBlock, NPerBlock, KPerBlock, AK1, BK1, MPerXDL, \ + NPerXDL, MXdlPerWave, NXdlPerWave, ABlockTransferThreadClusterLengths_AK0_M_AK1, \ + ABlockTransferThreadClusterArrangeOrder, ABlockTransferSrcAccessOrder, \ + ABlockTransferSrcVectorDim, ABlockTransferSrcScalarPerVector, \ + ABlockTransferDstScalarPerVector_AK1, false, ABlockLdsExtraM, \ + 
BBlockTransferThreadClusterLengths_BK0_N_BK1, BBlockTransferThreadClusterArrangeOrder, \ + BBlockTransferSrcAccessOrder, BBlockTransferSrcVectorDim, \ + BBlockTransferSrcScalarPerVector, BBlockTransferDstScalarPerVector_BK1, false, \ + BBlockLdsExtraN, CShuffleMXdlPerWavePerShuffle, CShuffleNXdlPerWavePerShuffle, \ + CDEBlockTransferClusterLengths_MBlock_MPerBlock_NBlock_NPerBlock, \ + CDEBlockTransferScalarPerVector_NPerBlock, LoopSched, PipelineVersion::v1, BComputeType + +#define GridwiseGemmTemplateParams \ + tensor_layout::gemm::RowMajor, tensor_layout::gemm::RowMajor, tensor_layout::gemm::RowMajor, \ + ADataType, BDataType, AccDataType, CShuffleDataType, EDataType, AElementwiseOp, \ + BElementwiseOp, CDEElementwiseOp, GemmSpec, BlockSize, MPerBlock, NPerBlock, KPerBlock, \ + AK1, BK1, MPerXDL, NPerXDL, MXdlPerWave, NXdlPerWave, \ + ABlockTransferThreadClusterLengths_AK0_M_AK1, ABlockTransferThreadClusterArrangeOrder, \ + ABlockTransferSrcAccessOrder, ABlockTransferSrcVectorDim, \ + ABlockTransferSrcScalarPerVector, ABlockTransferDstScalarPerVector_AK1, false, \ + ABlockLdsExtraM, BBlockTransferThreadClusterLengths_BK0_N_BK1, \ + BBlockTransferThreadClusterArrangeOrder, BBlockTransferSrcAccessOrder, \ + BBlockTransferSrcVectorDim, BBlockTransferSrcScalarPerVector, \ + BBlockTransferDstScalarPerVector_BK1, false, BBlockLdsExtraN, \ + CShuffleMXdlPerWavePerShuffle, CShuffleNXdlPerWavePerShuffle, \ + CDEBlockTransferClusterLengths_MBlock_MPerBlock_NBlock_NPerBlock, \ + CDEBlockTransferScalarPerVector_NPerBlock, BlkGemmPipeSched, BlkGemmPipelineVer, \ + AComputeType, BComputeType + + using GridwiseGemm = + std::conditional_t, + GridwiseGemm_xdl_cshuffle_v3>; + + template + static auto + MakeEGridDescriptor_MBlock_MPerBlock_NBlock_NPerBlock(const EGridDesc_M_N e_grid_desc_m_n) + { + if constexpr(isMultiD) + { + return GridwiseGemm::MakeEGridDescriptor_MBlock_MPerBlock_NBlock_NPerBlock( + e_grid_desc_m_n); + } + else + { + const index_t M = 
e_grid_desc_m_n.GetLength(I0); + const index_t N = e_grid_desc_m_n.GetLength(I1); + return GridwiseGemm::MakeCGridDescriptor_MBlock_MPerBlock_NBlock_NPerBlock( + e_grid_desc_m_n, + GridwiseGemm::CalculateMBlock(M), + GridwiseGemm::CalculateNBlock(N)); + } + } template static auto transform_k0_m_k1_to_m_k(const Desc_K0_M_K1& desc_k0_m_k1) @@ -390,15 +557,15 @@ struct DeviceGroupedConvBwdDataMultipleD_Xdl_CShuffle_v1 using BGridDesc_N_K = decltype(transform_k0_m_k1_to_m_k(BGridDesc_BK0_N_BK1{})); using DsGridDesc_MBlock_MPerBlock_NBlock_NPerBlock = - decltype(GridwiseGemm::MakeDsGridDescriptor_MBlock_MPerBlock_NBlock_NPerBlock( - DsGridDesc_M_N{})); + decltype(GridwiseGemmMultipleD_xdl_cshuffle:: + MakeDsGridDescriptor_MBlock_MPerBlock_NBlock_NPerBlock(DsGridDesc_M_N{})); using EGridDesc_MBlock_MPerBlock_NBlock_NPerBlock = - decltype(GridwiseGemm::MakeEGridDescriptor_MBlock_MPerBlock_NBlock_NPerBlock( - EGridDesc_M_N{})); + decltype(MakeEGridDescriptor_MBlock_MPerBlock_NBlock_NPerBlock(EGridDesc_M_N{})); // block-to-e-tile map - using Block2ETileMap = - remove_cvref_t; + using Block2ETileMap = remove_cvref_t< + decltype(GridwiseGemmMultipleD_xdl_cshuffle< + GridwiseGemmMultiDTemplateParams>::MakeDefaultBlock2ETileMap(EGridDesc_M_N{}))>; using Block2TileMapInOutElementwise = BlockToCTileMap_M00_N0_M01Adapt; using Block2TileMapWeiElementwise = BlockToCTileMap_M00_N0_M01Adapt; @@ -511,7 +678,8 @@ struct DeviceGroupedConvBwdDataMultipleD_Xdl_CShuffle_v1 const std::array& input_right_pads, const AElementwiseOp& a_element_op, const BElementwiseOp& b_element_op, - const CDEElementwiseOp& cde_element_op) + const CDEElementwiseOp& cde_element_op, + ck::index_t split_k = 1) : p_a_grid_{static_cast(p_a)}, p_b_grid_{static_cast(p_b)}, p_ds_grid_{}, @@ -525,7 +693,8 @@ struct DeviceGroupedConvBwdDataMultipleD_Xdl_CShuffle_v1 e_g_n_c_wis_lengths_{e_g_n_c_wis_lengths}, conv_filter_strides_{conv_filter_strides}, input_left_pads_{input_left_pads}, - 
input_right_pads_{input_right_pads} + input_right_pads_{input_right_pads}, + k_batch_{split_k} { std::array a_g_n_k_wos_strides_transposed = conv_ngchw_to_nhwgc_transformer.TransposeInOutStrides(a_g_n_k_wos_lengths, @@ -626,7 +795,8 @@ struct DeviceGroupedConvBwdDataMultipleD_Xdl_CShuffle_v1 conv_filter_dilations, input_left_pads, input_right_pads, - tildes}; + tildes, + k_batch_}; conv_N_per_block_ = conv_to_gemm_transform_.N_; @@ -682,34 +852,48 @@ struct DeviceGroupedConvBwdDataMultipleD_Xdl_CShuffle_v1 const auto b_grid_desc_n_k = transform_k0_m_k1_to_m_k(b_grid_desc_bk0_n_bk1); - a_grid_desc_m_k_container_.push_back(a_grid_desc_m_k); - b_grid_desc_n_k_container_.push_back(b_grid_desc_n_k); - ds_grid_desc_m_n_container_.push_back(ds_grid_desc_m_n); - e_grid_desc_m_n_container_.push_back(e_grid_desc_m_n); + if constexpr(isMultiD) + { + a_grid_desc_m_k_container_.push_back(a_grid_desc_m_k); + b_grid_desc_n_k_container_.push_back(b_grid_desc_n_k); + ds_grid_desc_m_n_container_.push_back(ds_grid_desc_m_n); + e_grid_desc_m_n_container_.push_back(e_grid_desc_m_n); + } // desc for blockwise copy a_grid_desc_ak0_m_ak1_container_.push_back(a_grid_desc_ak0_m_ak1); b_grid_desc_bk0_n_bk1_container_.push_back(b_grid_desc_bk0_n_bk1); - // block-to-e-tile-map - auto block_2_etile_map = - GridwiseGemm::MakeDefaultBlock2ETileMap(e_grid_desc_m_n); - - block_2_etile_map_container_.push_back(block_2_etile_map); - - if(GridwiseGemm::CheckValidity(a_grid_desc_m_k, - b_grid_desc_n_k, - ds_grid_desc_m_n, - e_grid_desc_m_n, - block_2_etile_map)) + if constexpr(isMultiD) { - ds_grid_desc_mblock_mperblock_nblock_nperblock_container_.push_back( - GridwiseGemm:: - MakeDsGridDescriptor_MBlock_MPerBlock_NBlock_NPerBlock( - ds_grid_desc_m_n)); + // block-to-e-tile-map + auto block_2_etile_map = + GridwiseGemm::MakeDefaultBlock2ETileMap(e_grid_desc_m_n); + block_2_etile_map_container_.push_back(block_2_etile_map); + + if(GridwiseGemm::CheckValidity(a_grid_desc_m_k, + b_grid_desc_n_k, + 
ds_grid_desc_m_n, + e_grid_desc_m_n, + block_2_etile_map)) + { + ds_grid_desc_mblock_mperblock_nblock_nperblock_container_.push_back( + + GridwiseGemm:: + MakeDsGridDescriptor_MBlock_MPerBlock_NBlock_NPerBlock( + ds_grid_desc_m_n)); + + e_grid_desc_mblock_mperblock_nblock_nperblock_container_.push_back( + MakeEGridDescriptor_MBlock_MPerBlock_NBlock_NPerBlock( + e_grid_desc_m_n)); + } + } + else + { + // there is no need to check since M, N, K are padded e_grid_desc_mblock_mperblock_nblock_nperblock_container_.push_back( - GridwiseGemm::MakeEGridDescriptor_MBlock_MPerBlock_NBlock_NPerBlock( + MakeEGridDescriptor_MBlock_MPerBlock_NBlock_NPerBlock( e_grid_desc_m_n)); } } @@ -844,7 +1028,8 @@ struct DeviceGroupedConvBwdDataMultipleD_Xdl_CShuffle_v1 // pointers const ADataType* p_a_grid_; const BDataType* p_b_grid_; - typename GridwiseGemm::DsGridPointer p_ds_grid_; + typename GridwiseGemmMultipleD_xdl_cshuffle::DsGridPointer + p_ds_grid_; EDataType* p_e_grid_; // tensor descriptor for problem definition @@ -891,6 +1076,7 @@ struct DeviceGroupedConvBwdDataMultipleD_Xdl_CShuffle_v1 std::array input_left_pads_; std::array input_right_pads_; + const index_t k_batch_; index_t num_workgroups_per_Conv_N_; }; @@ -899,7 +1085,7 @@ struct DeviceGroupedConvBwdDataMultipleD_Xdl_CShuffle_v1 { using Argument = DeviceOp::Argument; - float RunGemm(const Argument& arg, const StreamConfig& stream_config = StreamConfig{}) + float RunMultiDGemm(const Argument& arg, const StreamConfig& stream_config = StreamConfig{}) { float ave_time = 0; @@ -998,6 +1184,678 @@ struct DeviceGroupedConvBwdDataMultipleD_Xdl_CShuffle_v1 return ave_time; } + float RunGemmV3(const Argument& arg, const StreamConfig& stream_config = StreamConfig{}) + { + float ave_time = 0; + + const ADataType* p_a_grid = arg.p_a_grid_; + const BDataType* p_b_grid = arg.p_b_grid_; + EDataType* p_e_grid = arg.p_e_grid_; + + if constexpr(is_NGCHW_NGKHW() || + is_NGCDHW_NGKDHW()) + { + p_a_grid = type_convert(arg.p_workspace_); + 
p_e_grid = + type_convert(arg.p_workspace_) + + (arg.GetWorkspaceATensorSizeBytes() + arg.GetWorkspaceBTensorSizeBytes()) / + sizeof(EDataType); + } + + if constexpr(is_NGCHW_GKCYX_NGKHW() || + is_NGCDHW_GKCZYX_NGKDHW()) + { + p_b_grid = type_convert(arg.p_workspace_) + + arg.GetWorkspaceATensorSizeBytes() / sizeof(BDataType); + } + + constexpr index_t minimum_occupancy = + BlkGemmPipeSched == BlockGemmPipelineScheduler::Intrawave ? 1 : 2; + + for(std::size_t i = 0; i < arg.a_grid_desc_ak0_m_ak1_container_.size(); i++) + { + const index_t GemmM = arg.a_grid_desc_ak0_m_ak1_container_[i].GetLength(I1); + const index_t GemmN = arg.b_grid_desc_bk0_n_bk1_container_[i].GetLength(I1); + const index_t GemmK = arg.a_grid_desc_ak0_m_ak1_container_[i].GetLength(I0) * + arg.a_grid_desc_ak0_m_ak1_container_[i].GetLength(I2); + + const auto num_k_per_block = + arg.a_grid_desc_ak0_m_ak1_container_[i].GetLength(Number<0>{}) / arg.k_batch_; + + // gdy is for the kbatch and num_workgrups_per_Conv_N + index_t gdx, gdy, gdz; + std::tie(gdx, gdy, gdz) = GridwiseGemm::CalculateGridSize( + GemmM, GemmN, arg.k_batch_ * arg.num_workgroups_per_Conv_N_, arg.num_group_); + + index_t k_grain = arg.k_batch_ * KPerBlock; + index_t K_split = (GemmK + k_grain - 1) / k_grain * KPerBlock; + const bool has_main_k_block_loop = + GridwiseGemm::CalculateHasMainKBlockLoop(K_split); + + typename GridwiseGemm::Argument gemm_arg{ + p_a_grid, p_b_grid, p_e_grid, GemmM, GemmN, GemmK, I0, I0, I0, arg.k_batch_}; + + const auto Run = [&](const auto& kernel) { + if(stream_config.flush_cache) + { + typename GridwiseGemm::Argument gemm_arg_ = gemm_arg; + ck::utility::RotatingMemWrapper + rotating_mem(gemm_arg_, + stream_config.rotating_count, + gemm_arg_.M * gemm_arg_.K * sizeof(ADataType), + gemm_arg_.K * gemm_arg_.N * sizeof(BDataType)); + rotating_mem.Print(); + + auto run_flush_cache = [&]() { + // flush icache + ck::utility::flush_icache(); + // rotating mem + rotating_mem.Next(); + }; + + ave_time += 
ck::utility::launch_and_time_kernel_with_preprocess( + stream_config, + run_flush_cache, + kernel, + dim3(gdx, gdy, gdz), + dim3(BlockSize), + 0, + gemm_arg_, + arg.a_grid_desc_ak0_m_ak1_container_[i], + arg.b_grid_desc_bk0_n_bk1_container_[i], + arg.e_grid_desc_mblock_mperblock_nblock_nperblock_container_[i], + arg.compute_ptr_offset_of_batch_, + arg.compute_ptr_offset_of_n_, + num_k_per_block); + } + else + { + ave_time += launch_and_time_kernel( + stream_config, + kernel, + dim3(gdx, gdy, gdz), + dim3(BlockSize), + 0, + gemm_arg, + arg.a_grid_desc_ak0_m_ak1_container_[i], + arg.b_grid_desc_bk0_n_bk1_container_[i], + arg.e_grid_desc_mblock_mperblock_nblock_nperblock_container_[i], + arg.compute_ptr_offset_of_batch_, + arg.compute_ptr_offset_of_n_, + num_k_per_block); + } + }; + + if(has_main_k_block_loop) + { + // Tail number always full + if constexpr(BlkGemmPipelineVer == BlockGemmPipelineVersion::v1 || + BlkGemmPipelineVer == BlockGemmPipelineVersion::v3) + { + if(gemm_arg.KBatch > 1) + { + if constexpr(IsSplitKSupported) + { + const auto kernel = kernel_grouped_conv_bwd_data_xdl_cshuffle_v3< + GridwiseGemm, + DeviceOp::AGridDesc_AK0_M_AK1, + DeviceOp::BGridDesc_BK0_N_BK1, + DeviceOp::EGridDesc_MBlock_MPerBlock_NBlock_NPerBlock, + ComputePtrOffsetOfStridedBatch, + ComputePtrOffsetOfStridedBatch, + true, + InMemoryDataOperationEnum::AtomicAdd, + minimum_occupancy>; + Run(kernel); + } + } + else + { + const auto kernel = kernel_grouped_conv_bwd_data_xdl_cshuffle_v3< + GridwiseGemm, + DeviceOp::AGridDesc_AK0_M_AK1, + DeviceOp::BGridDesc_BK0_N_BK1, + DeviceOp::EGridDesc_MBlock_MPerBlock_NBlock_NPerBlock, + ComputePtrOffsetOfStridedBatch, + ComputePtrOffsetOfStridedBatch, + true, + InMemoryDataOperationEnum::Set, + minimum_occupancy>; + Run(kernel); + } + } + // Tail number could be One to Seven + else if constexpr(BlkGemmPipelineVer == BlockGemmPipelineVersion::v2) + { + if(gemm_arg.KBatch > 1) + { + if constexpr(IsSplitKSupported) + { + 
if(GridwiseGemm::CalculateKBlockLoopTailNum(K_split) == + TailNumber::One) + { + const auto kernel = + kernel_grouped_conv_bwd_data_xdl_cshuffle_v3< + GridwiseGemm, + DeviceOp::AGridDesc_AK0_M_AK1, + DeviceOp::BGridDesc_BK0_N_BK1, + DeviceOp::EGridDesc_MBlock_MPerBlock_NBlock_NPerBlock, + ComputePtrOffsetOfStridedBatch, + ComputePtrOffsetOfStridedBatch, + true, + InMemoryDataOperationEnum::AtomicAdd, + minimum_occupancy, + TailNumber::One>; + Run(kernel); + } + else if(GridwiseGemm::CalculateKBlockLoopTailNum(K_split) == + TailNumber::Full) + { + const auto kernel = + kernel_grouped_conv_bwd_data_xdl_cshuffle_v3< + GridwiseGemm, + DeviceOp::AGridDesc_AK0_M_AK1, + DeviceOp::BGridDesc_BK0_N_BK1, + DeviceOp::EGridDesc_MBlock_MPerBlock_NBlock_NPerBlock, + ComputePtrOffsetOfStridedBatch, + ComputePtrOffsetOfStridedBatch, + true, + InMemoryDataOperationEnum::AtomicAdd, + minimum_occupancy, + TailNumber::Full>; + Run(kernel); + } + + if constexpr(GridwiseGemm::BlockwiseGemmPipe::PrefetchStages > 2) + { + if(GridwiseGemm::CalculateKBlockLoopTailNum(K_split) == + TailNumber::Two) + { + const auto kernel = + kernel_grouped_conv_bwd_data_xdl_cshuffle_v3< + GridwiseGemm, + DeviceOp::AGridDesc_AK0_M_AK1, + DeviceOp::BGridDesc_BK0_N_BK1, + DeviceOp:: + EGridDesc_MBlock_MPerBlock_NBlock_NPerBlock, + ComputePtrOffsetOfStridedBatch, + ComputePtrOffsetOfStridedBatch, + true, + InMemoryDataOperationEnum::AtomicAdd, + minimum_occupancy, + TailNumber::Two>; + Run(kernel); + } + } + + if constexpr(GridwiseGemm::BlockwiseGemmPipe::PrefetchStages > 3) + { + if(GridwiseGemm::CalculateKBlockLoopTailNum(K_split) == + TailNumber::Three) + { + const auto kernel = + kernel_grouped_conv_bwd_data_xdl_cshuffle_v3< + GridwiseGemm, + DeviceOp::AGridDesc_AK0_M_AK1, + DeviceOp::BGridDesc_BK0_N_BK1, + DeviceOp:: + EGridDesc_MBlock_MPerBlock_NBlock_NPerBlock, + ComputePtrOffsetOfStridedBatch, + ComputePtrOffsetOfStridedBatch, + true, + InMemoryDataOperationEnum::AtomicAdd, + minimum_occupancy, + 
TailNumber::Three>; + Run(kernel); + } + } + + if constexpr(GridwiseGemm::BlockwiseGemmPipe::PrefetchStages > 4) + { + if(GridwiseGemm::CalculateKBlockLoopTailNum(K_split) == + TailNumber::Four) + { + const auto kernel = + kernel_grouped_conv_bwd_data_xdl_cshuffle_v3< + GridwiseGemm, + DeviceOp::AGridDesc_AK0_M_AK1, + DeviceOp::BGridDesc_BK0_N_BK1, + DeviceOp:: + EGridDesc_MBlock_MPerBlock_NBlock_NPerBlock, + ComputePtrOffsetOfStridedBatch, + ComputePtrOffsetOfStridedBatch, + true, + InMemoryDataOperationEnum::AtomicAdd, + minimum_occupancy, + TailNumber::Four>; + Run(kernel); + } + } + + if constexpr(GridwiseGemm::BlockwiseGemmPipe::PrefetchStages > 5) + { + if(GridwiseGemm::CalculateKBlockLoopTailNum(K_split) == + TailNumber::Five) + { + const auto kernel = + kernel_grouped_conv_bwd_data_xdl_cshuffle_v3< + GridwiseGemm, + DeviceOp::AGridDesc_AK0_M_AK1, + DeviceOp::BGridDesc_BK0_N_BK1, + DeviceOp:: + EGridDesc_MBlock_MPerBlock_NBlock_NPerBlock, + ComputePtrOffsetOfStridedBatch, + ComputePtrOffsetOfStridedBatch, + true, + InMemoryDataOperationEnum::AtomicAdd, + minimum_occupancy, + TailNumber::Five>; + Run(kernel); + } + } + + if constexpr(GridwiseGemm::BlockwiseGemmPipe::PrefetchStages > 6) + { + if(GridwiseGemm::CalculateKBlockLoopTailNum(K_split) == + TailNumber::Six) + { + const auto kernel = + kernel_grouped_conv_bwd_data_xdl_cshuffle_v3< + GridwiseGemm, + DeviceOp::AGridDesc_AK0_M_AK1, + DeviceOp::BGridDesc_BK0_N_BK1, + DeviceOp:: + EGridDesc_MBlock_MPerBlock_NBlock_NPerBlock, + ComputePtrOffsetOfStridedBatch, + ComputePtrOffsetOfStridedBatch, + true, + InMemoryDataOperationEnum::AtomicAdd, + minimum_occupancy, + TailNumber::Six>; + Run(kernel); + } + } + + if constexpr(GridwiseGemm::BlockwiseGemmPipe::PrefetchStages > 7) + { + if(GridwiseGemm::CalculateKBlockLoopTailNum(K_split) == + TailNumber::Seven) + { + const auto kernel = + kernel_grouped_conv_bwd_data_xdl_cshuffle_v3< + GridwiseGemm, + DeviceOp::AGridDesc_AK0_M_AK1, + DeviceOp::BGridDesc_BK0_N_BK1, + 
DeviceOp:: + EGridDesc_MBlock_MPerBlock_NBlock_NPerBlock, + ComputePtrOffsetOfStridedBatch, + ComputePtrOffsetOfStridedBatch, + true, + InMemoryDataOperationEnum::AtomicAdd, + minimum_occupancy, + TailNumber::Seven>; + Run(kernel); + } + } + } + } + else + { + if(GridwiseGemm::CalculateKBlockLoopTailNum(K_split) == TailNumber::One) + { + const auto kernel = kernel_grouped_conv_bwd_data_xdl_cshuffle_v3< + GridwiseGemm, + DeviceOp::AGridDesc_AK0_M_AK1, + DeviceOp::BGridDesc_BK0_N_BK1, + DeviceOp::EGridDesc_MBlock_MPerBlock_NBlock_NPerBlock, + ComputePtrOffsetOfStridedBatch, + ComputePtrOffsetOfStridedBatch, + true, + InMemoryDataOperationEnum::Set, + minimum_occupancy, + TailNumber::One>; + Run(kernel); + } + else if(GridwiseGemm::CalculateKBlockLoopTailNum(K_split) == + TailNumber::Full) + { + const auto kernel = kernel_grouped_conv_bwd_data_xdl_cshuffle_v3< + GridwiseGemm, + DeviceOp::AGridDesc_AK0_M_AK1, + DeviceOp::BGridDesc_BK0_N_BK1, + DeviceOp::EGridDesc_MBlock_MPerBlock_NBlock_NPerBlock, + ComputePtrOffsetOfStridedBatch, + ComputePtrOffsetOfStridedBatch, + true, + InMemoryDataOperationEnum::Set, + minimum_occupancy, + TailNumber::Full>; + Run(kernel); + } + + if constexpr(GridwiseGemm::BlockwiseGemmPipe::PrefetchStages > 2) + { + if(GridwiseGemm::CalculateKBlockLoopTailNum(K_split) == + TailNumber::Two) + { + const auto kernel = + kernel_grouped_conv_bwd_data_xdl_cshuffle_v3< + GridwiseGemm, + DeviceOp::AGridDesc_AK0_M_AK1, + DeviceOp::BGridDesc_BK0_N_BK1, + DeviceOp::EGridDesc_MBlock_MPerBlock_NBlock_NPerBlock, + ComputePtrOffsetOfStridedBatch, + ComputePtrOffsetOfStridedBatch, + true, + InMemoryDataOperationEnum::Set, + minimum_occupancy, + TailNumber::Two>; + Run(kernel); + } + } + + if constexpr(GridwiseGemm::BlockwiseGemmPipe::PrefetchStages > 3) + { + if(GridwiseGemm::CalculateKBlockLoopTailNum(K_split) == + TailNumber::Three) + { + const auto kernel = + kernel_grouped_conv_bwd_data_xdl_cshuffle_v3< + GridwiseGemm, + DeviceOp::AGridDesc_AK0_M_AK1, + 
DeviceOp::BGridDesc_BK0_N_BK1, + DeviceOp::EGridDesc_MBlock_MPerBlock_NBlock_NPerBlock, + ComputePtrOffsetOfStridedBatch, + ComputePtrOffsetOfStridedBatch, + true, + InMemoryDataOperationEnum::Set, + minimum_occupancy, + TailNumber::Three>; + Run(kernel); + } + } + + if constexpr(GridwiseGemm::BlockwiseGemmPipe::PrefetchStages > 4) + { + if(GridwiseGemm::CalculateKBlockLoopTailNum(K_split) == + TailNumber::Four) + { + const auto kernel = + kernel_grouped_conv_bwd_data_xdl_cshuffle_v3< + GridwiseGemm, + DeviceOp::AGridDesc_AK0_M_AK1, + DeviceOp::BGridDesc_BK0_N_BK1, + DeviceOp::EGridDesc_MBlock_MPerBlock_NBlock_NPerBlock, + ComputePtrOffsetOfStridedBatch, + ComputePtrOffsetOfStridedBatch, + true, + InMemoryDataOperationEnum::Set, + minimum_occupancy, + TailNumber::Four>; + Run(kernel); + } + } + + if constexpr(GridwiseGemm::BlockwiseGemmPipe::PrefetchStages > 5) + { + if(GridwiseGemm::CalculateKBlockLoopTailNum(K_split) == + TailNumber::Five) + { + const auto kernel = + kernel_grouped_conv_bwd_data_xdl_cshuffle_v3< + GridwiseGemm, + DeviceOp::AGridDesc_AK0_M_AK1, + DeviceOp::BGridDesc_BK0_N_BK1, + DeviceOp::EGridDesc_MBlock_MPerBlock_NBlock_NPerBlock, + ComputePtrOffsetOfStridedBatch, + ComputePtrOffsetOfStridedBatch, + true, + InMemoryDataOperationEnum::Set, + minimum_occupancy, + TailNumber::Five>; + Run(kernel); + } + } + + if constexpr(GridwiseGemm::BlockwiseGemmPipe::PrefetchStages > 6) + { + if(GridwiseGemm::CalculateKBlockLoopTailNum(K_split) == + TailNumber::Six) + { + const auto kernel = + kernel_grouped_conv_bwd_data_xdl_cshuffle_v3< + GridwiseGemm, + DeviceOp::AGridDesc_AK0_M_AK1, + DeviceOp::BGridDesc_BK0_N_BK1, + DeviceOp::EGridDesc_MBlock_MPerBlock_NBlock_NPerBlock, + ComputePtrOffsetOfStridedBatch, + ComputePtrOffsetOfStridedBatch, + true, + InMemoryDataOperationEnum::Set, + minimum_occupancy, + TailNumber::Six>; + Run(kernel); + } + } + + if constexpr(GridwiseGemm::BlockwiseGemmPipe::PrefetchStages > 7) + { + 
if(GridwiseGemm::CalculateKBlockLoopTailNum(K_split) == + TailNumber::Seven) + { + const auto kernel = + kernel_grouped_conv_bwd_data_xdl_cshuffle_v3< + GridwiseGemm, + DeviceOp::AGridDesc_AK0_M_AK1, + DeviceOp::BGridDesc_BK0_N_BK1, + DeviceOp::EGridDesc_MBlock_MPerBlock_NBlock_NPerBlock, + ComputePtrOffsetOfStridedBatch, + ComputePtrOffsetOfStridedBatch, + true, + InMemoryDataOperationEnum::Set, + minimum_occupancy, + TailNumber::Seven>; + Run(kernel); + } + } + } + } + // Tail number could be Odd or Even + else if constexpr(BlkGemmPipelineVer == BlockGemmPipelineVersion::v4) + { + if(gemm_arg.KBatch > 1) + { + if constexpr(IsSplitKSupported) + { + if(GridwiseGemm::CalculateKBlockLoopTailNum(K_split) == + TailNumber::Odd) + { + const auto kernel = + kernel_grouped_conv_bwd_data_xdl_cshuffle_v3_2lds< + GridwiseGemm, + DeviceOp::AGridDesc_AK0_M_AK1, + DeviceOp::BGridDesc_BK0_N_BK1, + DeviceOp::EGridDesc_MBlock_MPerBlock_NBlock_NPerBlock, + ComputePtrOffsetOfStridedBatch, + ComputePtrOffsetOfStridedBatch, + true, + InMemoryDataOperationEnum::AtomicAdd, + minimum_occupancy, + TailNumber::Odd>; + Run(kernel); + } + else + { + const auto kernel = + kernel_grouped_conv_bwd_data_xdl_cshuffle_v3_2lds< + GridwiseGemm, + DeviceOp::AGridDesc_AK0_M_AK1, + DeviceOp::BGridDesc_BK0_N_BK1, + DeviceOp::EGridDesc_MBlock_MPerBlock_NBlock_NPerBlock, + ComputePtrOffsetOfStridedBatch, + ComputePtrOffsetOfStridedBatch, + true, + InMemoryDataOperationEnum::AtomicAdd, + minimum_occupancy, + TailNumber::Even>; + Run(kernel); + } + } + } + else + { + if(GridwiseGemm::CalculateKBlockLoopTailNum(K_split) == TailNumber::Odd) + { + const auto kernel = + kernel_grouped_conv_bwd_data_xdl_cshuffle_v3_2lds< + GridwiseGemm, + DeviceOp::AGridDesc_AK0_M_AK1, + DeviceOp::BGridDesc_BK0_N_BK1, + DeviceOp::EGridDesc_MBlock_MPerBlock_NBlock_NPerBlock, + ComputePtrOffsetOfStridedBatch, + ComputePtrOffsetOfStridedBatch, + true, + InMemoryDataOperationEnum::Set, + minimum_occupancy, + TailNumber::Odd>; + 
Run(kernel); + } + else + { + const auto kernel = + kernel_grouped_conv_bwd_data_xdl_cshuffle_v3_2lds< + GridwiseGemm, + DeviceOp::AGridDesc_AK0_M_AK1, + DeviceOp::BGridDesc_BK0_N_BK1, + DeviceOp::EGridDesc_MBlock_MPerBlock_NBlock_NPerBlock, + ComputePtrOffsetOfStridedBatch, + ComputePtrOffsetOfStridedBatch, + true, + InMemoryDataOperationEnum::Set, + minimum_occupancy, + TailNumber::Even>; + Run(kernel); + } + } + } + else + { + if(gemm_arg.KBatch > 1) + { + if constexpr(IsSplitKSupported) + { + if(GridwiseGemm::CalculateKBlockLoopTailNum(K_split) == + TailNumber::Odd) + { + const auto kernel = + kernel_grouped_conv_bwd_data_xdl_cshuffle_v3< + GridwiseGemm, + DeviceOp::AGridDesc_AK0_M_AK1, + DeviceOp::BGridDesc_BK0_N_BK1, + DeviceOp::EGridDesc_MBlock_MPerBlock_NBlock_NPerBlock, + ComputePtrOffsetOfStridedBatch, + ComputePtrOffsetOfStridedBatch, + true, + InMemoryDataOperationEnum::AtomicAdd, + minimum_occupancy, + TailNumber::Odd>; + Run(kernel); + } + else + { + const auto kernel = + kernel_grouped_conv_bwd_data_xdl_cshuffle_v3< + GridwiseGemm, + DeviceOp::AGridDesc_AK0_M_AK1, + DeviceOp::BGridDesc_BK0_N_BK1, + DeviceOp::EGridDesc_MBlock_MPerBlock_NBlock_NPerBlock, + ComputePtrOffsetOfStridedBatch, + ComputePtrOffsetOfStridedBatch, + true, + InMemoryDataOperationEnum::AtomicAdd, + minimum_occupancy, + TailNumber::Even>; + Run(kernel); + } + } + } + else + { + if(GridwiseGemm::CalculateKBlockLoopTailNum(K_split) == TailNumber::Odd) + { + const auto kernel = kernel_grouped_conv_bwd_data_xdl_cshuffle_v3< + GridwiseGemm, + DeviceOp::AGridDesc_AK0_M_AK1, + DeviceOp::BGridDesc_BK0_N_BK1, + DeviceOp::EGridDesc_MBlock_MPerBlock_NBlock_NPerBlock, + ComputePtrOffsetOfStridedBatch, + ComputePtrOffsetOfStridedBatch, + true, + InMemoryDataOperationEnum::Set, + minimum_occupancy, + TailNumber::Odd>; + Run(kernel); + } + else + { + const auto kernel = kernel_grouped_conv_bwd_data_xdl_cshuffle_v3< + GridwiseGemm, + DeviceOp::AGridDesc_AK0_M_AK1, + DeviceOp::BGridDesc_BK0_N_BK1, 
+ DeviceOp::EGridDesc_MBlock_MPerBlock_NBlock_NPerBlock, + ComputePtrOffsetOfStridedBatch, + ComputePtrOffsetOfStridedBatch, + true, + InMemoryDataOperationEnum::Set, + minimum_occupancy, + TailNumber::Even>; + Run(kernel); + } + } + } + } + else + { + // Tail number always 1 + if constexpr(BlkGemmPipelineVer == BlockGemmPipelineVersion::v1) + { + if(gemm_arg.KBatch > 1) + { + if constexpr(IsSplitKSupported) + { + const auto kernel = kernel_grouped_conv_bwd_data_xdl_cshuffle_v3< + GridwiseGemm, + DeviceOp::AGridDesc_AK0_M_AK1, + DeviceOp::BGridDesc_BK0_N_BK1, + DeviceOp::EGridDesc_MBlock_MPerBlock_NBlock_NPerBlock, + ComputePtrOffsetOfStridedBatch, + ComputePtrOffsetOfStridedBatch, + false, + InMemoryDataOperationEnum::AtomicAdd, + minimum_occupancy>; + Run(kernel); + } + } + else + { + const auto kernel = kernel_grouped_conv_bwd_data_xdl_cshuffle_v3< + GridwiseGemm, + DeviceOp::AGridDesc_AK0_M_AK1, + DeviceOp::BGridDesc_BK0_N_BK1, + DeviceOp::EGridDesc_MBlock_MPerBlock_NBlock_NPerBlock, + ComputePtrOffsetOfStridedBatch, + ComputePtrOffsetOfStridedBatch, + false, + InMemoryDataOperationEnum::Set, + minimum_occupancy>; + Run(kernel); + } + } + } + } + return ave_time; + } + float Run(const Argument& arg, const StreamConfig& stream_config = StreamConfig{}) { float ave_time = 0; @@ -1084,7 +1942,16 @@ struct DeviceGroupedConvBwdDataMultipleD_Xdl_CShuffle_v1 static_cast(arg.compute_ptr_offset_of_n_.BatchStrideA_)}, std::array{0}); } - ave_time += RunGemm(arg, stream_config); + + if constexpr(isMultiD) + { + ave_time += RunMultiDGemm(arg, stream_config); + } + else + { + ave_time += RunGemmV3(arg, stream_config); + } + // Transpose from NHWGC to NGCHW if constexpr(is_NGCHW_NGKHW() || is_NGCDHW_NGKDHW()) @@ -1148,10 +2015,47 @@ struct DeviceGroupedConvBwdDataMultipleD_Xdl_CShuffle_v1 return false; } + if(!is_bf16_atomic_supported() && std::is_same_v && + arg.k_batch_ > 1) + { + return false; + } + + if constexpr(!IsSplitKSupported) + { + if(arg.k_batch_ != 1) + { + 
return false; + } + } + const index_t ConvG = arg.b_g_k_c_xs_lengths_[0]; const index_t ConvK = arg.b_g_k_c_xs_lengths_[1]; const index_t ConvC = arg.b_g_k_c_xs_lengths_[2]; + if constexpr(!isMultiD) + { + for(std::size_t i = 0; i < arg.a_grid_desc_ak0_m_ak1_container_.size(); i++) + { + const index_t GemmM = arg.a_grid_desc_ak0_m_ak1_container_[i].GetLength(I1); + const index_t GemmN = arg.b_grid_desc_bk0_n_bk1_container_[i].GetLength(I1); + const index_t GemmK = arg.a_grid_desc_ak0_m_ak1_container_[i].GetLength(I0) * + arg.a_grid_desc_ak0_m_ak1_container_[i].GetLength(I2); + + typename GridwiseGemm::Argument gemm_arg{ + nullptr, nullptr, nullptr, GemmM, GemmN, GemmK, I0, I0, I0, arg.k_batch_}; + + const auto num_k_loop = gemm_arg.AK0 / (KPerBlock / AK1); + if constexpr(BlkGemmPipelineVer != BlockGemmPipelineVersion::v1) + { + if(num_k_loop <= GridwiseGemm::BlockwiseGemmPipe::PrefetchStages) + { + return false; + } + } + } + } + // Specifialization if constexpr(ConvBackwardDataSpecialization == ConvolutionBackwardDataSpecialization::Filter1x1Stride1Pad0) @@ -1254,13 +2158,16 @@ struct DeviceGroupedConvBwdDataMultipleD_Xdl_CShuffle_v1 // Gridwise GEMM size for(std::size_t i = 0; i < arg.a_grid_desc_ak0_m_ak1_container_.size(); i++) { - if(!GridwiseGemm::CheckValidity(arg.a_grid_desc_m_k_container_[i], - arg.b_grid_desc_n_k_container_[i], - arg.ds_grid_desc_m_n_container_[i], - arg.e_grid_desc_m_n_container_[i], - arg.block_2_etile_map_container_[i])) + if constexpr(isMultiD) { - return false; + if(!GridwiseGemm::CheckValidity(arg.a_grid_desc_m_k_container_[i], + arg.b_grid_desc_n_k_container_[i], + arg.ds_grid_desc_m_n_container_[i], + arg.e_grid_desc_m_n_container_[i], + arg.block_2_etile_map_container_[i])) + { + return false; + } } } @@ -1335,7 +2242,8 @@ struct DeviceGroupedConvBwdDataMultipleD_Xdl_CShuffle_v1 const std::array& input_right_pads, const AElementwiseOp& a_element_op, const BElementwiseOp& b_element_op, - const CDEElementwiseOp& cde_element_op) + 
const CDEElementwiseOp& cde_element_op, + const ck::index_t split_k = 1) { return Argument{p_a, p_b, @@ -1355,7 +2263,8 @@ struct DeviceGroupedConvBwdDataMultipleD_Xdl_CShuffle_v1 input_right_pads, a_element_op, b_element_op, - cde_element_op}; + cde_element_op, + split_k}; } static auto MakeInvoker() { return Invoker{}; } @@ -1381,7 +2290,8 @@ struct DeviceGroupedConvBwdDataMultipleD_Xdl_CShuffle_v1 const std::array& input_right_pads, const AElementwiseOp& a_element_op, const BElementwiseOp& b_element_op, - const CDEElementwiseOp& cde_element_op) override + const CDEElementwiseOp& cde_element_op, + const ck::index_t split_k = 1) override { return std::make_unique(p_a, p_b, @@ -1401,7 +2311,8 @@ struct DeviceGroupedConvBwdDataMultipleD_Xdl_CShuffle_v1 input_right_pads, a_element_op, b_element_op, - cde_element_op); + cde_element_op, + split_k); } std::unique_ptr MakeInvokerPointer() override @@ -1413,6 +2324,17 @@ struct DeviceGroupedConvBwdDataMultipleD_Xdl_CShuffle_v1 { auto str = std::stringstream(); + std::map BlkGemmPipelineSchedulerToString{ + {BlockGemmPipelineScheduler::Intrawave, "Intrawave"}, + {BlockGemmPipelineScheduler::Interwave, "Interwave"}}; + + std::map BlkGemmPipelineVersionToString{ + {BlockGemmPipelineVersion::v1, "v1"}, + {BlockGemmPipelineVersion::v2, "v2"}, + {BlockGemmPipelineVersion::v3, "v3"}, + {BlockGemmPipelineVersion::v4, "v4"}, + {BlockGemmPipelineVersion::v5, "v5"}}; + // clang-format off str << "DeviceGroupedConvBwdDataMultipleD_Xdl_CShuffle_v1" << "<" @@ -1430,7 +2352,11 @@ struct DeviceGroupedConvBwdDataMultipleD_Xdl_CShuffle_v1 << ABlockTransferSrcScalarPerVector << ", " << BBlockTransferSrcScalarPerVector << ", " << CShuffleMXdlPerWavePerShuffle << ", " - << CShuffleNXdlPerWavePerShuffle; + << CShuffleNXdlPerWavePerShuffle << ", " + << "BlkGemmPipelineScheduler: " + << BlkGemmPipelineSchedulerToString[BlkGemmPipeSched] << ", " + << "BlkGemmPipelineVersion: " + << BlkGemmPipelineVersionToString[BlkGemmPipelineVer]; if 
constexpr(is_NGCHW_NGKHW() || is_NGCDHW_NGKDHW()) { diff --git a/include/ck/tensor_operation/gpu/device/impl/device_grouped_conv_bwd_weight_two_stage_xdl_cshuffle.hpp b/include/ck/tensor_operation/gpu/device/impl/device_grouped_conv_bwd_weight_two_stage_xdl_cshuffle.hpp index da7c4f759b..c7d95254c5 100644 --- a/include/ck/tensor_operation/gpu/device/impl/device_grouped_conv_bwd_weight_two_stage_xdl_cshuffle.hpp +++ b/include/ck/tensor_operation/gpu/device/impl/device_grouped_conv_bwd_weight_two_stage_xdl_cshuffle.hpp @@ -19,7 +19,7 @@ #include "ck/tensor_operation/gpu/device/convolution_backward_weight_specialization.hpp" #include "ck/tensor_operation/gpu/grid/gridwise_elementwise_2d.hpp" #include "ck/tensor_operation/gpu/device/gemm_specialization.hpp" -#include "ck/tensor_operation/gpu/grid/gridwise_gemm_xdl_cshuffle_bwd_weight_v3.hpp" +#include "ck/tensor_operation/gpu/grid/gridwise_gemm_xdl_cshuffle_conv_v3.hpp" #include #include "ck/tensor_operation/gpu/device/impl/device_grouped_conv_utils.hpp" #include "ck/tensor_operation/gpu/element/element_wise_operation.hpp" diff --git a/include/ck/tensor_operation/gpu/device/impl/device_grouped_conv_bwd_weight_xdl_cshuffle_v3.hpp b/include/ck/tensor_operation/gpu/device/impl/device_grouped_conv_bwd_weight_xdl_cshuffle_v3.hpp index d56c7abcde..dd5b97096d 100644 --- a/include/ck/tensor_operation/gpu/device/impl/device_grouped_conv_bwd_weight_xdl_cshuffle_v3.hpp +++ b/include/ck/tensor_operation/gpu/device/impl/device_grouped_conv_bwd_weight_xdl_cshuffle_v3.hpp @@ -17,7 +17,7 @@ #include "ck/tensor_operation/operator_transform/transform_conv_bwd_weight_to_gemm_v2.hpp" #include "ck/tensor_operation/gpu/device/convolution_backward_weight_specialization.hpp" #include "ck/tensor_operation/gpu/device/gemm_specialization.hpp" -#include "ck/tensor_operation/gpu/grid/gridwise_gemm_xdl_cshuffle_bwd_weight_v3.hpp" +#include "ck/tensor_operation/gpu/grid/gridwise_gemm_xdl_cshuffle_conv_v3.hpp" #include #include 
"ck/tensor_operation/gpu/device/impl/device_grouped_conv_utils.hpp" #include "ck/tensor_operation/gpu/element/element_wise_operation.hpp" diff --git a/include/ck/tensor_operation/gpu/grid/gridwise_gemm_xdl_cshuffle_bwd_weight_v3.hpp b/include/ck/tensor_operation/gpu/grid/gridwise_gemm_xdl_cshuffle_conv_v3.hpp similarity index 99% rename from include/ck/tensor_operation/gpu/grid/gridwise_gemm_xdl_cshuffle_bwd_weight_v3.hpp rename to include/ck/tensor_operation/gpu/grid/gridwise_gemm_xdl_cshuffle_conv_v3.hpp index 4f5fedcd83..d37b3cd38e 100644 --- a/include/ck/tensor_operation/gpu/grid/gridwise_gemm_xdl_cshuffle_bwd_weight_v3.hpp +++ b/include/ck/tensor_operation/gpu/grid/gridwise_gemm_xdl_cshuffle_conv_v3.hpp @@ -1,5 +1,5 @@ // SPDX-License-Identifier: MIT -// Copyright (c) 2024, Advanced Micro Devices, Inc. All rights reserved. +// Copyright (c) 2024-2025, Advanced Micro Devices, Inc. All rights reserved. #pragma once diff --git a/include/ck/tensor_operation/operator_transform/transform_conv_bwd_data_to_gemm_v1.hpp b/include/ck/tensor_operation/operator_transform/transform_conv_bwd_data_to_gemm_v1.hpp index 0ddfd0a7c8..a191c75099 100644 --- a/include/ck/tensor_operation/operator_transform/transform_conv_bwd_data_to_gemm_v1.hpp +++ b/include/ck/tensor_operation/operator_transform/transform_conv_bwd_data_to_gemm_v1.hpp @@ -187,7 +187,8 @@ struct TransformConvBwdDataToGemm_v1 WTilde_{static_cast(transform_conv_bwd_data_to_gemm_base.WTilde_)}, ZDot_{static_cast(transform_conv_bwd_data_to_gemm_base.ZDot_)}, YDot_{static_cast(transform_conv_bwd_data_to_gemm_base.YDot_)}, - XDot_{static_cast(transform_conv_bwd_data_to_gemm_base.XDot_)} + XDot_{static_cast(transform_conv_bwd_data_to_gemm_base.XDot_)}, + batch_k_{transform_conv_bwd_data_to_gemm_base.batch_k_} { } @@ -203,7 +204,8 @@ struct TransformConvBwdDataToGemm_v1 const ConvSpatialDimsType& conv_filter_dilations, const ConvSpatialDimsType& input_left_pads, const ConvSpatialDimsType& input_right_pads, - const 
ConvSpatialDimsType& tildes) + const ConvSpatialDimsType& tildes, + const index_t batch_k = 1) : Hi_{c_g_n_c_wis_lengths[HIdx]}, Wi_{c_g_n_c_wis_lengths[WIdx]}, Ho_{a_g_n_k_wos_lengths[HIdx]}, @@ -231,7 +233,8 @@ struct TransformConvBwdDataToGemm_v1 InRightPadH_{input_right_pads[HIdx - NonSpatialDimsNum]}, InRightPadW_{input_right_pads[WIdx - NonSpatialDimsNum]}, IdxYTilde_{tildes[YIdx - NonSpatialDimsNum]}, - IdxXTilde_{tildes[XIdx - NonSpatialDimsNum]} + IdxXTilde_{tildes[XIdx - NonSpatialDimsNum]}, + batch_k_{batch_k} { static_assert(is_same_v> || is_same_v>); @@ -616,20 +619,22 @@ struct TransformConvBwdDataToGemm_v1 ck::tensor_operation::device::ConvolutionBackwardDataSpecialization:: Filter1x1Stride1Pad0) { - const index_t AK0 = math::integer_divide_ceil(K_, AK1); + const index_t K0PerBlock = GemmKPerBlock / AK1; + const index_t AK0 = + math::integer_divide_ceil(K_, AK1 * K0PerBlock * batch_k_) * K0PerBlock; // A: output tensor const auto out_gemmak0_gemmmraw_gemmak1_grid_desc = transform_tensor_descriptor( out_grid_desc, make_tuple(make_pass_through_transform(N_ * Do_ * Ho_ * Wo_), - make_unmerge_transform(make_tuple(AK0, AK1))), + make_unmerge_transform(make_tuple(AK0 * batch_k_, AK1))), make_tuple(Sequence<0>{}, Sequence<1>{}), make_tuple(Sequence<1>{}, Sequence<0, 2>{})); const auto out_gemmak0_gemmm_gemmak1_grid_desc = ck::tensor_operation::device::PadTensorDescriptor( out_gemmak0_gemmmraw_gemmak1_grid_desc, - make_tuple(AK0, GemmMPerBlock, AK1), + make_tuple(AK0 * batch_k_, GemmMPerBlock, AK1), Sequence{}); return out_gemmak0_gemmm_gemmak1_grid_desc; @@ -719,11 +724,15 @@ struct TransformConvBwdDataToGemm_v1 make_tuple(GemmKPerBlock, GemmMPerBlock), Sequence{}); - const index_t AK0 = out_gemmk_gemmm_padded_grid_desc.GetLength(I0) / AK1; + const index_t K0PerBlock = GemmKPerBlock / AK1; + const index_t AK0 = + math::integer_divide_ceil(out_gemmk_gemmm_padded_grid_desc.GetLength(I0), + AK1 * K0PerBlock * batch_k_) * + K0PerBlock; const auto 
out_gemmak0_gemmm_gemmak1_grid_desc = transform_tensor_descriptor( out_gemmk_gemmm_padded_grid_desc, - make_tuple(make_unmerge_transform(make_tuple(AK0, AK1)), + make_tuple(make_unmerge_transform(make_tuple(AK0 * batch_k_, AK1)), make_pass_through_transform( out_gemmk_gemmm_padded_grid_desc.GetLength(I1))), make_tuple(Sequence<0>{}, Sequence<1>{}), @@ -816,11 +825,15 @@ struct TransformConvBwdDataToGemm_v1 make_tuple(GemmKPerBlock, GemmMPerBlock), Sequence{}); - const index_t AK0 = out_gemmk_gemmm_padded_grid_desc.GetLength(I0) / AK1; + const index_t K0PerBlock = GemmKPerBlock / AK1; + const index_t AK0 = + math::integer_divide_ceil(out_gemmk_gemmm_padded_grid_desc.GetLength(I0), + AK1 * K0PerBlock * batch_k_) * + K0PerBlock; const auto out_gemmak0_gemmm_gemmak1_grid_desc = transform_tensor_descriptor( out_gemmk_gemmm_padded_grid_desc, - make_tuple(make_unmerge_transform(make_tuple(AK0, AK1)), + make_tuple(make_unmerge_transform(make_tuple(AK0 * batch_k_, AK1)), make_pass_through_transform( out_gemmk_gemmm_padded_grid_desc.GetLength(I1))), make_tuple(Sequence<0>{}, Sequence<1>{}), @@ -850,21 +863,23 @@ struct TransformConvBwdDataToGemm_v1 ck::tensor_operation::device::ConvolutionBackwardDataSpecialization:: Filter1x1Stride1Pad0) { - const index_t BK0 = math::integer_divide_ceil(K_, BK1); + const index_t K0PerBlock = GemmKPerBlock / BK1; + const index_t BK0 = + math::integer_divide_ceil(K_, BK1 * K0PerBlock * batch_k_) * K0PerBlock; // B: weight tensor - const auto wei_gemmbk0_gemmnraw_gemmbk1_grid_desc = - transform_tensor_descriptor(make_naive_tensor_descriptor_packed(make_tuple(K_, C_)), - make_tuple(make_unmerge_transform(make_tuple(BK0, BK1)), - make_pass_through_transform(C_)), - make_tuple(Sequence<0>{}, Sequence<1>{}), - make_tuple(Sequence<0, 2>{}, Sequence<1>{})); + const auto wei_gemmbk0_gemmnraw_gemmbk1_grid_desc = transform_tensor_descriptor( + make_naive_tensor_descriptor_packed(make_tuple(K_, C_)), + make_tuple(make_unmerge_transform(make_tuple(BK0 * 
batch_k_, BK1)), + make_pass_through_transform(C_)), + make_tuple(Sequence<0>{}, Sequence<1>{}), + make_tuple(Sequence<0, 2>{}, Sequence<1>{})); make_naive_tensor_descriptor(make_tuple(N_ * Do_ * Ho_ * Wo_, C_), make_tuple(I0, I1)); const auto wei_gemmbk0_gemmn_gemmbk1_grid_desc = ck::tensor_operation::device::PadTensorDescriptor( wei_gemmbk0_gemmnraw_gemmbk1_grid_desc, - make_tuple(BK0, GemmNPerBlock, BK1), + make_tuple(BK0 * batch_k_, GemmNPerBlock, BK1), Sequence{}); return wei_gemmbk0_gemmn_gemmbk1_grid_desc; @@ -925,11 +940,15 @@ struct TransformConvBwdDataToGemm_v1 make_tuple(GemmKPerBlock, GemmNPerBlock), Sequence{}); - const index_t BK0 = wei_gemmk_gemmn_padded_grid_desc.GetLength(I0) / BK1; + const index_t K0PerBlock = GemmKPerBlock / BK1; + const index_t BK0 = + math::integer_divide_ceil(wei_gemmk_gemmn_padded_grid_desc.GetLength(I0), + BK1 * K0PerBlock * batch_k_) * + K0PerBlock; const auto wei_gemmbk0_gemmn_gemmbk1_grid_desc = transform_tensor_descriptor( wei_gemmk_gemmn_padded_grid_desc, - make_tuple(make_unmerge_transform(make_tuple(BK0, BK1)), + make_tuple(make_unmerge_transform(make_tuple(BK0 * batch_k_, BK1)), make_pass_through_transform( wei_gemmk_gemmn_padded_grid_desc.GetLength(I1))), make_tuple(Sequence<0>{}, Sequence<1>{}), @@ -1006,11 +1025,15 @@ struct TransformConvBwdDataToGemm_v1 make_tuple(GemmKPerBlock, GemmNPerBlock), Sequence{}); - const index_t BK0 = wei_gemmk_gemmn_padded_grid_desc.GetLength(I0) / BK1; + const index_t K0PerBlock = GemmKPerBlock / BK1; + const index_t BK0 = + math::integer_divide_ceil(wei_gemmk_gemmn_padded_grid_desc.GetLength(I0), + BK1 * K0PerBlock * batch_k_) * + K0PerBlock; const auto wei_gemmbk0_gemm_gemmbk1_grid_desc = transform_tensor_descriptor( wei_gemmk_gemmn_padded_grid_desc, - make_tuple(make_unmerge_transform(make_tuple(BK0, BK1)), + make_tuple(make_unmerge_transform(make_tuple(BK0 * batch_k_, BK1)), make_pass_through_transform( wei_gemmk_gemmn_padded_grid_desc.GetLength(I1))), make_tuple(Sequence<0>{}, 
Sequence<1>{}), @@ -1355,6 +1378,7 @@ struct TransformConvBwdDataToGemm_v1 IndexType ZTilde_, YTilde_, XTilde_; IndexType DTilde_, HTilde_, WTilde_; IndexType ZDot_, YDot_, XDot_; + index_t batch_k_; }; } // namespace tensor_operation diff --git a/library/include/ck/library/tensor_operation_instance/gpu/grouped_conv_bwd_data/device_grouped_conv_bwd_data_xdl_instance.hpp b/library/include/ck/library/tensor_operation_instance/gpu/grouped_conv_bwd_data/device_grouped_conv_bwd_data_xdl_instance.hpp index ae6fabd0bd..5c0d7283f2 100644 --- a/library/include/ck/library/tensor_operation_instance/gpu/grouped_conv_bwd_data/device_grouped_conv_bwd_data_xdl_instance.hpp +++ b/library/include/ck/library/tensor_operation_instance/gpu/grouped_conv_bwd_data/device_grouped_conv_bwd_data_xdl_instance.hpp @@ -54,6 +54,28 @@ using device_grouped_conv_bwd_data_xdl_f16_generic_instances = // clang-format on >; +template +using device_grouped_conv_bwd_data_xdl_f16_16_16_instances = + std::tuple< + // clang-format off + // ##############################################| NDim| ALayout| BLayout| DsLayout| ELayout| AData| BData| AccData| CShuffle| DsData| EData| AElementwise| BElementwise| CDEElementwise| ConvolutionBackward| DoPad| DoPad| NumGemmK| Block| MPer| NPer| KPer| AK1| BK1| MPer| NPer| MXdl| NXdl| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockLds| BBlockTransfer| BBlockTransfer| BBlockTransfer| BBlockTransfer| BBlockTransfer| BBlockTransfer| BBlockLds| CShuffleMXdl| CShuffleNXdl| CDEBlockTransfer| CDEBlockTransfer| + // ##############################################| Spatial| | | | | Type| Type| Type| DataType| Type| Type| Operation| Operation| Operation| DataSpecialization| GemmM| GemmN| PrefetchStage| Size| Block| Block| Block| | | XDL| XDL| PerWave| PerWave| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| ExtraM| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| 
DstScalar| ExtraN| PerWave| PerWave| _MBlock_MPerBlock| ScalarPerVector| + // ##############################################| | | | | | | | | | | | | | | | | | | | | | | | | | | | | Lengths_AK0_M_AK1| ArrangeOrder| | | PerVector| PerVector_AK1| | Lengths_BK0_N_BK1| ArrangeOrder| | | PerVector| PerVector_BK1| | PerShuffle| PerShuffle| _NBlock_NPerBlock| _NPerBlock| + // ##############################################| | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | + DeviceGroupedConvBwdDataMultipleD_Xdl_CShuffle_v1< NDimSpatial, ALayout, BLayout, DsLayout, ELayout, F16, F16, F32, F16, Empty_Tuple, F16, PassThrough, PassThrough, PassThrough, ConvSpec, true, true, 1, 64, 16, 64, 32, 8, 8, 16, 16, 1, 4, S<4, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, S<4, 8, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 8, 8, 1, 1, 1, S<1, 16, 1, 4>, 4>, + DeviceGroupedConvBwdDataMultipleD_Xdl_CShuffle_v1< NDimSpatial, ALayout, BLayout, DsLayout, ELayout, F16, F16, F32, F16, Empty_Tuple, F16, PassThrough, PassThrough, PassThrough, ConvSpec, true, true, 1, 64, 16, 64, 32, 8, 8, 16, 16, 1, 4, S<4, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, S<4, 16, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 1, 8, 1, 1, 1, S<1, 16, 1, 4>, 1>, + DeviceGroupedConvBwdDataMultipleD_Xdl_CShuffle_v1< NDimSpatial, ALayout, BLayout, DsLayout, ELayout, F16, F16, F32, F16, Empty_Tuple, F16, PassThrough, PassThrough, PassThrough, ConvSpec, true, true, 1, 64, 16, 64, 32, 8, 8, 16, 16, 1, 4, S<4, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 1, 8, 1, S<4, 8, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 8, 8, 1, 1, 1, S<1, 16, 1, 4>, 4>, + DeviceGroupedConvBwdDataMultipleD_Xdl_CShuffle_v1< NDimSpatial, ALayout, BLayout, DsLayout, ELayout, F16, F16, F32, F16, Empty_Tuple, F16, PassThrough, PassThrough, PassThrough, ConvSpec, true, true, 1, 64, 64, 16, 32, 8, 8, 16, 16, 4, 1, S<4, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, S<4, 2, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 8, 8, 1, 1, 1, S<1, 16, 1, 4>, 4>, + 
DeviceGroupedConvBwdDataMultipleD_Xdl_CShuffle_v1< NDimSpatial, ALayout, BLayout, DsLayout, ELayout, F16, F16, F32, F16, Empty_Tuple, F16, PassThrough, PassThrough, PassThrough, ConvSpec, true, true, 1, 64, 64, 16, 32, 8, 8, 16, 16, 4, 1, S<4, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, S<4, 16, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 1, 8, 1, 1, 1, S<1, 16, 1, 4>, 1>, + DeviceGroupedConvBwdDataMultipleD_Xdl_CShuffle_v1< NDimSpatial, ALayout, BLayout, DsLayout, ELayout, F16, F16, F32, F16, Empty_Tuple, F16, PassThrough, PassThrough, PassThrough, ConvSpec, true, true, 1, 64, 64, 16, 32, 8, 8, 16, 16, 4, 1, S<4, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 1, 8, 1, S<4, 2, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 8, 8, 1, 1, 1, S<1, 16, 1, 4>, 4> + // clang-format on + >; + template , S<1, 0, 2>, S<1, 0, 2>, 2, 1, 8, 1, S<4, 16, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 8, 8, 1, 1, 1, S<1, 32, 1, 8>, 8>, DeviceGroupedConvBwdDataMultipleD_Xdl_CShuffle_v1, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, S<4, 4, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 1, 8, 1, 1, 1, S<1, 32, 1, 4>, 1>, - DeviceGroupedConvBwdDataMultipleD_Xdl_CShuffle_v1, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, S<4, 64, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 4, 2, 0, 1, 1, S<1, 32, 1, 8>, 8>, + DeviceGroupedConvBwdDataMultipleD_Xdl_CShuffle_v1, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, S<4, 64, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 4, 4, 0, 1, 1, S<1, 32, 1, 8>, 8>, DeviceGroupedConvBwdDataMultipleD_Xdl_CShuffle_v1, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, S<4, 16, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 8, 8, 1, 1, 1, S<1, 32, 1, 8>, 8>, DeviceGroupedConvBwdDataMultipleD_Xdl_CShuffle_v1, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, S<4, 32, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 8, 8, 1, 1, 1, S<1, 32, 1, 8>, 8>, DeviceGroupedConvBwdDataMultipleD_Xdl_CShuffle_v1, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, S<4, 16, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 8, 8, 1, 1, 1, S<1, 16, 1, 8>, 8>, @@ -108,6 +130,27 @@ using device_grouped_conv_bwd_data_xdl_bf16_generic_instances = std::tuple< // clang-format on >; +template +using 
device_grouped_conv_bwd_data_xdl_bf16_16_16_instances = std::tuple< + // clang-format off + // ##############################################| NDim| ALayout| BLayout| DsLayout| ELayout| AData| BData| AccData| CShuffle| DsData| EData| AElementwise| BElementwise| CDEElementwise| ConvolutionBackward| DoPad| DoPad| NumGemmK| Block| MPer| NPer| KPer| AK1| BK1| MPer| NPer| MXdl| NXdl| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockLds| BBlockTransfer| BBlockTransfer| BBlockTransfer| BBlockTransfer| BBlockTransfer| BBlockTransfer| BBlockLds| CShuffleMXdl| CShuffleNXdl| CDEBlockTransfer| CDEBlockTransfer| + // ##############################################| Spatial| | | | | Type| Type| Type| DataType| Type| Type| Operation| Operation| Operation| DataSpecialization| GemmM| GemmN| PrefetchStage| Size| Block| Block| Block| | | XDL| XDL| PerWave| PerWave| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| ExtraM| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| ExtraN| PerWave| PerWave| _MBlock_MPerBlock| ScalarPerVector| + // ##############################################| | | | | | | | | | | | | | | | | | | | | | | | | | | | | Lengths_AK0_M_AK1| ArrangeOrder| | | PerVector| PerVector_AK1| | Lengths_BK0_N_BK1| ArrangeOrder| | | PerVector| PerVector_BK1| | PerShuffle| PerShuffle| _NBlock_NPerBlock| _NPerBlock| + // ##############################################| | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | + DeviceGroupedConvBwdDataMultipleD_Xdl_CShuffle_v1< NDimSpatial, ALayout, BLayout, DsLayout, ELayout, BF16, BF16, F32, BF16, Empty_Tuple, BF16, PassThrough, PassThrough, PassThrough, ConvSpec, true, true, 1, 64, 16, 64, 32, 8, 8, 16, 16, 1, 4, S<4, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, S<4, 8, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 8, 8, 1, 1, 1, S<1, 16, 1, 4>, 4>, + 
DeviceGroupedConvBwdDataMultipleD_Xdl_CShuffle_v1< NDimSpatial, ALayout, BLayout, DsLayout, ELayout, BF16, BF16, F32, BF16, Empty_Tuple, BF16, PassThrough, PassThrough, PassThrough, ConvSpec, true, true, 1, 64, 16, 64, 32, 8, 8, 16, 16, 1, 4, S<4, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, S<4, 16, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 1, 8, 1, 1, 1, S<1, 16, 1, 4>, 1>, + DeviceGroupedConvBwdDataMultipleD_Xdl_CShuffle_v1< NDimSpatial, ALayout, BLayout, DsLayout, ELayout, BF16, BF16, F32, BF16, Empty_Tuple, BF16, PassThrough, PassThrough, PassThrough, ConvSpec, true, true, 1, 64, 16, 64, 32, 8, 8, 16, 16, 1, 4, S<4, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 1, 8, 1, S<4, 8, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 8, 8, 1, 1, 1, S<1, 16, 1, 4>, 4>, + DeviceGroupedConvBwdDataMultipleD_Xdl_CShuffle_v1< NDimSpatial, ALayout, BLayout, DsLayout, ELayout, BF16, BF16, F32, BF16, Empty_Tuple, BF16, PassThrough, PassThrough, PassThrough, ConvSpec, true, true, 1, 64, 64, 16, 32, 8, 8, 16, 16, 4, 1, S<4, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, S<4, 2, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 8, 8, 1, 1, 1, S<1, 16, 1, 4>, 4>, + DeviceGroupedConvBwdDataMultipleD_Xdl_CShuffle_v1< NDimSpatial, ALayout, BLayout, DsLayout, ELayout, BF16, BF16, F32, BF16, Empty_Tuple, BF16, PassThrough, PassThrough, PassThrough, ConvSpec, true, true, 1, 64, 64, 16, 32, 8, 8, 16, 16, 4, 1, S<4, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, S<4, 16, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 1, 8, 1, 1, 1, S<1, 16, 1, 4>, 1>, + DeviceGroupedConvBwdDataMultipleD_Xdl_CShuffle_v1< NDimSpatial, ALayout, BLayout, DsLayout, ELayout, BF16, BF16, F32, BF16, Empty_Tuple, BF16, PassThrough, PassThrough, PassThrough, ConvSpec, true, true, 1, 64, 64, 16, 32, 8, 8, 16, 16, 4, 1, S<4, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 1, 8, 1, S<4, 2, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 8, 8, 1, 1, 1, S<1, 16, 1, 4>, 4> + // clang-format on + >; + template , S<1, 0, 2>, S<1, 0, 2>, 2, 1, 8, 1, S<4, 16, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 8, 8, 1, 1, 1, S<1, 32, 1, 8>, 8>, 
DeviceGroupedConvBwdDataMultipleD_Xdl_CShuffle_v1, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, S<4, 4, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 1, 8, 1, 1, 1, S<1, 32, 1, 4>, 1>, - DeviceGroupedConvBwdDataMultipleD_Xdl_CShuffle_v1, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, S<4, 64, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 4, 2, 0, 1, 1, S<1, 32, 1, 8>, 8>, + DeviceGroupedConvBwdDataMultipleD_Xdl_CShuffle_v1, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, S<4, 64, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 4, 4, 0, 1, 1, S<1, 32, 1, 8>, 8>, DeviceGroupedConvBwdDataMultipleD_Xdl_CShuffle_v1, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, S<4, 16, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 8, 8, 1, 1, 1, S<1, 32, 1, 8>, 8>, DeviceGroupedConvBwdDataMultipleD_Xdl_CShuffle_v1, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, S<4, 32, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 8, 8, 1, 1, 1, S<1, 32, 1, 8>, 8>, DeviceGroupedConvBwdDataMultipleD_Xdl_CShuffle_v1, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, S<4, 16, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 8, 8, 1, 1, 1, S<1, 16, 1, 8>, 8>, @@ -162,6 +205,28 @@ using device_grouped_conv_bwd_data_xdl_f32_generic_instances = // clang-format on >; +template +using device_grouped_conv_bwd_data_xdl_f32_16_16_instances = + std::tuple< + // clang-format off + // ##############################################| NDim| ALayout| BLayout| DsLayout| ELayout| AData| BData| AccData| CShuffle| DsData| EData| AElementwise| BElementwise| CDEElementwise| ConvolutionBackward| DoPad| DoPad| NumGemmK| Block| MPer| NPer| KPer| AK1| BK1| MPer| NPer| MXdl| NXdl| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockLds| BBlockTransfer| BBlockTransfer| BBlockTransfer| BBlockTransfer| BBlockTransfer| BBlockTransfer| BBlockLds| CShuffleMXdl| CShuffleNXdl| CDEBlockTransfer| CDEBlockTransfer| + // ##############################################| Spatial| | | | | Type| Type| Type| DataType| Type| Type| Operation| Operation| Operation| DataSpecialization| GemmM| GemmN| PrefetchStage| Size| Block| Block| Block| | | XDL| XDL| 
PerWave| PerWave| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| ExtraM| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| ExtraN| PerWave| PerWave| _MBlock_MPerBlock| ScalarPerVector| + // ##############################################| | | | | | | | | | | | | | | | | | | | | | | | | | | | | Lengths_AK0_M_AK1| ArrangeOrder| | | PerVector| PerVector_AK1| | Lengths_BK0_N_BK1| ArrangeOrder| | | PerVector| PerVector_BK1| | PerShuffle| PerShuffle| _NBlock_NPerBlock| _NPerBlock| + // ##############################################| | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | + DeviceGroupedConvBwdDataMultipleD_Xdl_CShuffle_v1< NDimSpatial, ALayout, BLayout, DsLayout, ELayout, F32, F32, F32, F32, Empty_Tuple, F32, PassThrough, PassThrough, PassThrough, ConvSpec, true, true, 1, 64, 16, 64, 32, 8, 8, 16, 16, 1, 4, S<4, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 4, 8, 1, S<4, 16, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 4, 8, 1, 1, 1, S<1, 16, 1, 4>, 4>, + DeviceGroupedConvBwdDataMultipleD_Xdl_CShuffle_v1< NDimSpatial, ALayout, BLayout, DsLayout, ELayout, F32, F32, F32, F32, Empty_Tuple, F32, PassThrough, PassThrough, PassThrough, ConvSpec, true, true, 1, 64, 16, 64, 32, 8, 8, 16, 16, 1, 4, S<4, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 4, 8, 1, S<4, 16, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 1, 8, 1, 1, 1, S<1, 16, 1, 4>, 1>, + DeviceGroupedConvBwdDataMultipleD_Xdl_CShuffle_v1< NDimSpatial, ALayout, BLayout, DsLayout, ELayout, F32, F32, F32, F32, Empty_Tuple, F32, PassThrough, PassThrough, PassThrough, ConvSpec, true, true, 1, 64, 16, 64, 32, 8, 8, 16, 16, 1, 4, S<4, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 1, 8, 1, S<4, 16, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 4, 8, 1, 1, 1, S<1, 16, 1, 4>, 4>, + DeviceGroupedConvBwdDataMultipleD_Xdl_CShuffle_v1< NDimSpatial, ALayout, BLayout, DsLayout, ELayout, F32, F32, F32, F32, Empty_Tuple, F32, PassThrough, PassThrough, PassThrough, ConvSpec, true, true, 1, 64, 64, 
16, 32, 8, 8, 16, 16, 4, 1, S<4, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 4, 8, 1, S<4, 4, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 4, 8, 1, 1, 1, S<1, 16, 1, 4>, 4>, + DeviceGroupedConvBwdDataMultipleD_Xdl_CShuffle_v1< NDimSpatial, ALayout, BLayout, DsLayout, ELayout, F32, F32, F32, F32, Empty_Tuple, F32, PassThrough, PassThrough, PassThrough, ConvSpec, true, true, 1, 64, 64, 16, 32, 8, 8, 16, 16, 4, 1, S<4, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 4, 8, 1, S<4, 16, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 1, 8, 1, 1, 1, S<1, 16, 1, 4>, 1>, + DeviceGroupedConvBwdDataMultipleD_Xdl_CShuffle_v1< NDimSpatial, ALayout, BLayout, DsLayout, ELayout, F32, F32, F32, F32, Empty_Tuple, F32, PassThrough, PassThrough, PassThrough, ConvSpec, true, true, 1, 64, 64, 16, 32, 8, 8, 16, 16, 4, 1, S<4, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 1, 8, 1, S<4, 4, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 4, 8, 1, 1, 1, S<1, 16, 1, 4>, 4> + // clang-format on + >; + template , S<1, 0, 2>, S<1, 0, 2>, 2, 1, 4, 1, S<4, 32, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 4, 4, 1, 1, 1, S<1, 32, 1, 8>, 4>, DeviceGroupedConvBwdDataMultipleD_Xdl_CShuffle_v1< NDimSpatial, ALayout, BLayout, DsLayout, ELayout, F32, F32, F32, F32, Empty_Tuple, F32, PassThrough, PassThrough, PassThrough, ConvSpec, true, true, 1, 128, 128, 32, 32, 8, 8, 32, 32, 2, 1, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 4, 4, 1, S<4, 8, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 1, 4, 1, 1, 1, S<1, 32, 1, 4>, 1>, - DeviceGroupedConvBwdDataMultipleD_Xdl_CShuffle_v1< NDimSpatial, ALayout, BLayout, DsLayout, ELayout, F32, F32, F32, F32, Empty_Tuple, F32, PassThrough, PassThrough, PassThrough, ConvSpec, true, true, 1, 256, 128, 256, 32, 8, 2, 32, 32, 2, 4, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 4, 4, 1, S<4, 64, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 4, 2, 0, 1, 1, S<1, 32, 1, 8>, 4>, + DeviceGroupedConvBwdDataMultipleD_Xdl_CShuffle_v1< NDimSpatial, ALayout, BLayout, DsLayout, ELayout, F32, F32, F32, F32, Empty_Tuple, F32, PassThrough, PassThrough, PassThrough, ConvSpec, true, true, 1, 256, 128, 256, 32, 8, 8, 32, 32, 
2, 4, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 4, 4, 1, S<4, 64, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 4, 4, 0, 1, 1, S<1, 32, 1, 8>, 4>, DeviceGroupedConvBwdDataMultipleD_Xdl_CShuffle_v1< NDimSpatial, ALayout, BLayout, DsLayout, ELayout, F32, F32, F32, F32, Empty_Tuple, F32, PassThrough, PassThrough, PassThrough, ConvSpec, true, true, 1, 256, 256, 128, 32, 8, 8, 32, 32, 4, 2, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 4, 4, 1, S<4, 32, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 4, 4, 1, 1, 1, S<1, 32, 1, 8>, 4>, DeviceGroupedConvBwdDataMultipleD_Xdl_CShuffle_v1< NDimSpatial, ALayout, BLayout, DsLayout, ELayout, F32, F32, F32, F32, Empty_Tuple, F32, PassThrough, PassThrough, PassThrough, ConvSpec, true, true, 1, 256, 128, 256, 32, 8, 8, 32, 32, 2, 4, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 4, 4, 1, S<4, 64, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 4, 4, 1, 1, 1, S<1, 32, 1, 8>, 4>, DeviceGroupedConvBwdDataMultipleD_Xdl_CShuffle_v1< NDimSpatial, ALayout, BLayout, DsLayout, ELayout, F32, F32, F32, F32, Empty_Tuple, F32, PassThrough, PassThrough, PassThrough, ConvSpec, true, true, 1, 128, 128, 128, 32, 8, 8, 32, 32, 4, 2, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 4, 4, 1, S<4, 32, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 4, 4, 1, 1, 1, S<1, 16, 1, 8>, 4>, @@ -194,7 +259,7 @@ using device_grouped_conv_bwd_data_xdl_f32_instances = DeviceGroupedConvBwdDataMultipleD_Xdl_CShuffle_v1< NDimSpatial, ALayout, BLayout, DsLayout, ELayout, F32, F32, F32, F32, Empty_Tuple, F32, PassThrough, PassThrough, PassThrough, ConvSpec, true, true, 1, 128, 128, 32, 32, 8, 8, 32, 32, 2, 1, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 4, 4, 1, S<4, 8, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 4, 4, 1, 1, 1, S<1, 32, 1, 4>, 4>, DeviceGroupedConvBwdDataMultipleD_Xdl_CShuffle_v1< NDimSpatial, ALayout, BLayout, DsLayout, ELayout, F32, F32, F32, F32, Empty_Tuple, F32, PassThrough, PassThrough, PassThrough, ConvSpec, true, true, 1, 128, 32, 128, 32, 8, 8, 32, 32, 1, 2, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 4, 4, 1, S<4, 32, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 4, 
4, 1, 1, 1, S<1, 16, 1, 8>, 4>, DeviceGroupedConvBwdDataMultipleD_Xdl_CShuffle_v1< NDimSpatial, ALayout, BLayout, DsLayout, ELayout, F32, F32, F32, F32, Empty_Tuple, F32, PassThrough, PassThrough, PassThrough, ConvSpec, true, true, 1, 64, 64, 32, 32, 8, 8, 32, 32, 2, 1, S<4, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 4, 4, 1, S<4, 8, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 4, 4, 1, 1, 1, S<1, 16, 1, 4>, 4>, - DeviceGroupedConvBwdDataMultipleD_Xdl_CShuffle_v1< NDimSpatial, ALayout, BLayout, DsLayout, ELayout, F32, F32, F32, F32, Empty_Tuple, F32, PassThrough, PassThrough, PassThrough, ConvSpec, true, true, 1, 64, 32, 64, 32, 8, 8, 32, 32, 1, 2, S<4, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, S<4, 8, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 8, 8, 1, 1, 1, S<1, 16, 1, 4>, 8> + DeviceGroupedConvBwdDataMultipleD_Xdl_CShuffle_v1< NDimSpatial, ALayout, BLayout, DsLayout, ELayout, F32, F32, F32, F32, Empty_Tuple, F32, PassThrough, PassThrough, PassThrough, ConvSpec, true, true, 1, 64, 32, 64, 32, 8, 8, 32, 32, 1, 2, S<4, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 4, 8, 1, S<4, 8, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 4, 8, 1, 1, 1, S<1, 16, 1, 4>, 4> // clang-format on >; @@ -218,7 +283,7 @@ using device_grouped_conv_bwd_data_xdl_input_fp16_comp_bf8f8_instances = DeviceGroupedConvBwdDataMultipleD_Xdl_CShuffle_v1< NDimSpatial, ALayout, BLayout, DsLayout, ELayout, F16, F16, F32, F32, Empty_Tuple, F16, PassThrough, PassThrough, PassThrough, ConvSpec, true, true, 1, 256, 64, 128, 32, 8, 8, 32, 32, 1, 2, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 1, 4, 1, S<4, 32, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 4, 4, 1, 1, 1, S<1, 32, 1, 8>, 4, LoopScheduler::Default, BF8, F8>, DeviceGroupedConvBwdDataMultipleD_Xdl_CShuffle_v1< NDimSpatial, ALayout, BLayout, DsLayout, ELayout, F16, F16, F32, F32, Empty_Tuple, F16, PassThrough, PassThrough, PassThrough, ConvSpec, true, true, 1, 128, 128, 32, 32, 8, 8, 32, 32, 2, 1, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 4, 4, 1, S<4, 8, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 1, 4, 1, 1, 1, S<1, 32, 1, 4>, 1, 
LoopScheduler::Default, BF8, F8>, - DeviceGroupedConvBwdDataMultipleD_Xdl_CShuffle_v1< NDimSpatial, ALayout, BLayout, DsLayout, ELayout, F16, F16, F32, F32, Empty_Tuple, F16, PassThrough, PassThrough, PassThrough, ConvSpec, true, true, 1, 256, 128, 256, 32, 8, 2, 32, 32, 2, 4, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 4, 4, 1, S<4, 64, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 4, 2, 0, 1, 1, S<1, 32, 1, 8>, 4, LoopScheduler::Default, BF8, F8>, + DeviceGroupedConvBwdDataMultipleD_Xdl_CShuffle_v1< NDimSpatial, ALayout, BLayout, DsLayout, ELayout, F16, F16, F32, F32, Empty_Tuple, F16, PassThrough, PassThrough, PassThrough, ConvSpec, true, true, 1, 256, 128, 256, 32, 8, 8, 32, 32, 2, 4, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 4, 4, 1, S<4, 64, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 4, 2, 0, 1, 1, S<1, 32, 1, 8>, 4, LoopScheduler::Default, BF8, F8>, DeviceGroupedConvBwdDataMultipleD_Xdl_CShuffle_v1< NDimSpatial, ALayout, BLayout, DsLayout, ELayout, F16, F16, F32, F32, Empty_Tuple, F16, PassThrough, PassThrough, PassThrough, ConvSpec, true, true, 1, 256, 256, 128, 32, 8, 8, 32, 32, 4, 2, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 4, 4, 1, S<4, 32, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 4, 4, 1, 1, 1, S<1, 32, 1, 8>, 4, LoopScheduler::Default, BF8, F8>, DeviceGroupedConvBwdDataMultipleD_Xdl_CShuffle_v1< NDimSpatial, ALayout, BLayout, DsLayout, ELayout, F16, F16, F32, F32, Empty_Tuple, F16, PassThrough, PassThrough, PassThrough, ConvSpec, true, true, 1, 256, 128, 256, 32, 8, 8, 32, 32, 2, 4, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 4, 4, 1, S<4, 64, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 4, 4, 1, 1, 1, S<1, 32, 1, 8>, 4, LoopScheduler::Default, BF8, F8>, DeviceGroupedConvBwdDataMultipleD_Xdl_CShuffle_v1< NDimSpatial, ALayout, BLayout, DsLayout, ELayout, F16, F16, F32, F32, Empty_Tuple, F16, PassThrough, PassThrough, PassThrough, ConvSpec, true, true, 1, 128, 128, 128, 32, 8, 8, 32, 32, 4, 2, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 4, 4, 1, S<4, 32, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 4, 4, 1, 1, 1, S<1, 16, 1, 8>, 4, 
LoopScheduler::Default, BF8, F8>, diff --git a/library/include/ck/library/tensor_operation_instance/gpu/grouped_convolution_backward_data.hpp b/library/include/ck/library/tensor_operation_instance/gpu/grouped_convolution_backward_data.hpp index 12695f4f16..e9ff75a91d 100644 --- a/library/include/ck/library/tensor_operation_instance/gpu/grouped_convolution_backward_data.hpp +++ b/library/include/ck/library/tensor_operation_instance/gpu/grouped_convolution_backward_data.hpp @@ -109,6 +109,8 @@ struct DeviceOperationInstanceFactory< is_same_v) { add_device_grouped_conv2d_bwd_data_xdl_nhwgk_gkyxc_nhwgc_f16_instances(op_ptrs); + add_device_grouped_conv2d_bwd_data_xdl_nhwgk_gkyxc_nhwgc_f16_16_16_instances( + op_ptrs); } #endif #ifdef CK_ENABLE_FP32 @@ -117,6 +119,8 @@ struct DeviceOperationInstanceFactory< is_same_v) { add_device_grouped_conv2d_bwd_data_xdl_nhwgk_gkyxc_nhwgc_f32_instances(op_ptrs); + add_device_grouped_conv2d_bwd_data_xdl_nhwgk_gkyxc_nhwgc_f32_16_16_instances( + op_ptrs); } #endif #ifdef CK_ENABLE_BF16 @@ -126,6 +130,8 @@ struct DeviceOperationInstanceFactory< { add_device_grouped_conv2d_bwd_data_xdl_nhwgk_gkyxc_nhwgc_bf16_instances( op_ptrs); + add_device_grouped_conv2d_bwd_data_xdl_nhwgk_gkyxc_nhwgc_bf16_16_16_instances( + op_ptrs); } #endif } @@ -167,6 +173,8 @@ struct DeviceOperationInstanceFactory< is_same_v) { add_device_grouped_conv2d_bwd_data_xdl_ngkhw_gkcyx_ngchw_f16_instances(op_ptrs); + add_device_grouped_conv2d_bwd_data_xdl_ngkhw_gkcyx_ngchw_f16_16_16_instances( + op_ptrs); add_device_grouped_conv2d_bwd_data_xdl_ngkhw_gkcyx_ngchw_f16_vec_transpose_instances( op_ptrs); } @@ -177,6 +185,8 @@ struct DeviceOperationInstanceFactory< is_same_v) { add_device_grouped_conv2d_bwd_data_xdl_ngkhw_gkcyx_ngchw_f32_instances(op_ptrs); + add_device_grouped_conv2d_bwd_data_xdl_ngkhw_gkcyx_ngchw_f32_16_16_instances( + op_ptrs); add_device_grouped_conv2d_bwd_data_xdl_ngkhw_gkcyx_ngchw_f32_vec_transpose_instances( op_ptrs); } @@ -188,6 +198,8 @@ struct 
DeviceOperationInstanceFactory< { add_device_grouped_conv2d_bwd_data_xdl_ngkhw_gkcyx_ngchw_bf16_instances( op_ptrs); + add_device_grouped_conv2d_bwd_data_xdl_ngkhw_gkcyx_ngchw_bf16_16_16_instances( + op_ptrs); add_device_grouped_conv2d_bwd_data_xdl_ngkhw_gkcyx_ngchw_bf16_vec_transpose_instances( op_ptrs); } @@ -237,6 +249,8 @@ struct DeviceOperationInstanceFactory< { add_device_grouped_conv3d_bwd_data_xdl_ndhwgk_gkzyxc_ndhwgc_f16_instances( op_ptrs); + add_device_grouped_conv3d_bwd_data_xdl_ndhwgk_gkzyxc_ndhwgc_f16_16_16_instances( + op_ptrs); } #endif #if defined CK_ENABLE_FP16 && defined CK_ENABLE_FP8 && defined CK_ENABLE_BF8 @@ -255,6 +269,8 @@ struct DeviceOperationInstanceFactory< { add_device_grouped_conv3d_bwd_data_xdl_ndhwgk_gkzyxc_ndhwgc_f32_instances( op_ptrs); + add_device_grouped_conv3d_bwd_data_xdl_ndhwgk_gkzyxc_ndhwgc_f32_16_16_instances( + op_ptrs); } #endif #ifdef CK_ENABLE_BF16 @@ -264,6 +280,8 @@ struct DeviceOperationInstanceFactory< { add_device_grouped_conv3d_bwd_data_xdl_ndhwgk_gkzyxc_ndhwgc_bf16_instances( op_ptrs); + add_device_grouped_conv3d_bwd_data_xdl_ndhwgk_gkzyxc_ndhwgc_bf16_16_16_instances( + op_ptrs); } #endif } @@ -308,6 +326,8 @@ struct DeviceOperationInstanceFactory< { add_device_grouped_conv3d_bwd_data_xdl_ngkdhw_gkczyx_ngcdhw_f16_instances( op_ptrs); + add_device_grouped_conv3d_bwd_data_xdl_ngkdhw_gkczyx_ngcdhw_f16_16_16_instances( + op_ptrs); add_device_grouped_conv3d_bwd_data_xdl_ngkdhw_gkczyx_ngcdhw_f16_vec_transpose_instances( op_ptrs); } @@ -319,6 +339,8 @@ struct DeviceOperationInstanceFactory< { add_device_grouped_conv3d_bwd_data_xdl_ngkdhw_gkczyx_ngcdhw_f32_instances( op_ptrs); + add_device_grouped_conv3d_bwd_data_xdl_ngkdhw_gkczyx_ngcdhw_f32_16_16_instances( + op_ptrs); add_device_grouped_conv3d_bwd_data_xdl_ngkdhw_gkczyx_ngcdhw_f32_vec_transpose_instances( op_ptrs); } @@ -330,6 +352,8 @@ struct DeviceOperationInstanceFactory< { add_device_grouped_conv3d_bwd_data_xdl_ngkdhw_gkczyx_ngcdhw_bf16_instances( op_ptrs); + 
add_device_grouped_conv3d_bwd_data_xdl_ngkdhw_gkczyx_ngcdhw_bf16_16_16_instances( + op_ptrs); add_device_grouped_conv3d_bwd_data_xdl_ngkdhw_gkczyx_ngcdhw_bf16_vec_transpose_instances( op_ptrs); } diff --git a/library/include/ck/library/tensor_operation_instance/gpu/grouped_convolution_backward_data_xdl.inc b/library/include/ck/library/tensor_operation_instance/gpu/grouped_convolution_backward_data_xdl.inc index 5be8f29e99..c723be0db8 100644 --- a/library/include/ck/library/tensor_operation_instance/gpu/grouped_convolution_backward_data_xdl.inc +++ b/library/include/ck/library/tensor_operation_instance/gpu/grouped_convolution_backward_data_xdl.inc @@ -69,6 +69,20 @@ void add_device_grouped_conv2d_bwd_data_xdl_nhwgk_gkyxc_nhwgc_f16_instances( PassThrough, PassThrough, PassThrough>>>& instances); + +void add_device_grouped_conv2d_bwd_data_xdl_nhwgk_gkyxc_nhwgc_f16_16_16_instances( + std::vector>>& instances); #endif #ifdef CK_ENABLE_FP32 void add_device_grouped_conv2d_bwd_data_xdl_nhwgk_gkyxc_nhwgc_f32_instances( @@ -84,6 +98,20 @@ void add_device_grouped_conv2d_bwd_data_xdl_nhwgk_gkyxc_nhwgc_f32_instances( PassThrough, PassThrough, PassThrough>>>& instances); + +void add_device_grouped_conv2d_bwd_data_xdl_nhwgk_gkyxc_nhwgc_f32_16_16_instances( + std::vector>>& instances); #endif #ifdef CK_ENABLE_BF16 void add_device_grouped_conv2d_bwd_data_xdl_nhwgk_gkyxc_nhwgc_bf16_instances( @@ -99,6 +127,20 @@ void add_device_grouped_conv2d_bwd_data_xdl_nhwgk_gkyxc_nhwgc_bf16_instances( PassThrough, PassThrough, PassThrough>>>& instances); + +void add_device_grouped_conv2d_bwd_data_xdl_nhwgk_gkyxc_nhwgc_bf16_16_16_instances( + std::vector>>& instances); #endif #ifdef CK_ENABLE_FP16 @@ -162,6 +204,20 @@ void add_device_grouped_conv2d_bwd_data_xdl_ngkhw_gkcyx_ngchw_f16_instances( PassThrough, PassThrough>>>& instances); +void add_device_grouped_conv2d_bwd_data_xdl_ngkhw_gkcyx_ngchw_f16_16_16_instances( + std::vector>>& instances); + void 
add_device_grouped_conv2d_bwd_data_xdl_ngkhw_gkcyx_ngchw_f16_vec_transpose_instances( std::vector>>& instances); +void add_device_grouped_conv2d_bwd_data_xdl_ngkhw_gkcyx_ngchw_f32_16_16_instances( + std::vector>>& instances); + void add_device_grouped_conv2d_bwd_data_xdl_ngkhw_gkcyx_ngchw_f32_vec_transpose_instances( std::vector>>& instances); +void add_device_grouped_conv2d_bwd_data_xdl_ngkhw_gkcyx_ngchw_bf16_16_16_instances( + std::vector>>& instances); + void add_device_grouped_conv2d_bwd_data_xdl_ngkhw_gkcyx_ngchw_bf16_vec_transpose_instances( std::vector>>& instances); + +void add_device_grouped_conv3d_bwd_data_xdl_ndhwgk_gkzyxc_ndhwgc_f16_16_16_instances( + std::vector>>& instances); #endif #ifdef CK_ENABLE_FP32 void add_device_grouped_conv3d_bwd_data_xdl_ndhwgk_gkzyxc_ndhwgc_f32_instances( @@ -310,6 +408,20 @@ void add_device_grouped_conv3d_bwd_data_xdl_ndhwgk_gkzyxc_ndhwgc_f32_instances( PassThrough, PassThrough, PassThrough>>>& instances); + +void add_device_grouped_conv3d_bwd_data_xdl_ndhwgk_gkzyxc_ndhwgc_f32_16_16_instances( + std::vector>>& instances); #endif #ifdef CK_ENABLE_BF16 void add_device_grouped_conv3d_bwd_data_xdl_ndhwgk_gkzyxc_ndhwgc_bf16_instances( @@ -325,6 +437,20 @@ void add_device_grouped_conv3d_bwd_data_xdl_ndhwgk_gkzyxc_ndhwgc_bf16_instances( PassThrough, PassThrough, PassThrough>>>& instances); + +void add_device_grouped_conv3d_bwd_data_xdl_ndhwgk_gkzyxc_ndhwgc_bf16_16_16_instances( + std::vector>>& instances); #endif #if defined CK_ENABLE_FP16 && defined CK_ENABLE_FP8 && defined CK_ENABLE_BF8 void add_device_grouped_conv3d_bwd_data_xdl_ndhwgk_gkzyxc_ndhwgc_input_f16_comp_bf8f8_instances( @@ -403,6 +529,20 @@ void add_device_grouped_conv3d_bwd_data_xdl_ngkdhw_gkczyx_ngcdhw_f16_instances( PassThrough, PassThrough>>>& instances); +void add_device_grouped_conv3d_bwd_data_xdl_ngkdhw_gkczyx_ngcdhw_f16_16_16_instances( + std::vector>>& instances); + void 
add_device_grouped_conv3d_bwd_data_xdl_ngkdhw_gkczyx_ngcdhw_f16_vec_transpose_instances( std::vector>>& instances); +void add_device_grouped_conv3d_bwd_data_xdl_ngkdhw_gkczyx_ngcdhw_f32_16_16_instances( + std::vector>>& instances); + void add_device_grouped_conv3d_bwd_data_xdl_ngkdhw_gkczyx_ngcdhw_f32_vec_transpose_instances( std::vector>>& instances); +void add_device_grouped_conv3d_bwd_data_xdl_ngkdhw_gkczyx_ngcdhw_bf16_16_16_instances( + std::vector>>& instances); + void add_device_grouped_conv3d_bwd_data_xdl_ngkdhw_gkczyx_ngcdhw_bf16_vec_transpose_instances( std::vector{}); + device_grouped_conv_bwd_data_xdl_bf16_16_16_instances<2, + GNHWK, + GKYXC, + Empty_Tuple, + GNHWC, + ConvBwdDataDefault>{}); // 2. Filter1x1Stride1Pad0 add_device_operation_instances( instances, - device_grouped_conv_bwd_data_xdl_bf16_instances<2, - GNHWK, - GKYXC, - Empty_Tuple, - GNHWC, - ConvBwdDataFilter1x1Stride1Pad0>{}); + device_grouped_conv_bwd_data_xdl_bf16_16_16_instances<2, + GNHWK, + GKYXC, + Empty_Tuple, + GNHWC, + ConvBwdDataFilter1x1Stride1Pad0>{}); } } // namespace instance diff --git a/library/src/tensor_operation_instance/gpu/grouped_conv2d_bwd_data/xdl/device_grouped_conv2d_bwd_data_xdl_gnhwc_gkyxc_gnhwk_f16_instance.cpp b/library/src/tensor_operation_instance/gpu/grouped_conv2d_bwd_data/xdl/device_grouped_conv2d_bwd_data_xdl_gnhwc_gkyxc_gnhwk_f16_instance.cpp index 64fbf8bbf2..1a3c80e5cf 100644 --- a/library/src/tensor_operation_instance/gpu/grouped_conv2d_bwd_data/xdl/device_grouped_conv2d_bwd_data_xdl_gnhwc_gkyxc_gnhwk_f16_instance.cpp +++ b/library/src/tensor_operation_instance/gpu/grouped_conv2d_bwd_data/xdl/device_grouped_conv2d_bwd_data_xdl_gnhwc_gkyxc_gnhwk_f16_instance.cpp @@ -1,5 +1,5 @@ // SPDX-License-Identifier: MIT -// Copyright (c) 2018-2023, Advanced Micro Devices, Inc. All rights reserved. +// Copyright (c) 2018-2025, Advanced Micro Devices, Inc. All rights reserved. 
#include "ck/library/tensor_operation_instance/add_device_operation_instance.hpp" #include "ck/library/tensor_operation_instance/gpu/grouped_conv_bwd_data/device_grouped_conv_bwd_data_xdl_instance.hpp" @@ -8,7 +8,7 @@ namespace ck { namespace tensor_operation { namespace device { namespace instance { -// Compilation parameters for out[g, n, hi, wi, c] * wei[g, k, y, x, c] = in[g, n, ho, wo, k] + void add_device_grouped_conv2d_bwd_data_xdl_gnhwk_gkyxc_gnhwc_f16_instances( std::vector{}); + device_grouped_conv_bwd_data_xdl_f16_16_16_instances<2, + GNHWK, + GKYXC, + Empty_Tuple, + GNHWC, + ConvBwdDataDefault>{}); // 2. Filter1x1Stride1Pad0 add_device_operation_instances( instances, - device_grouped_conv_bwd_data_xdl_f16_instances<2, - GNHWK, - GKYXC, - Empty_Tuple, - GNHWC, - ConvBwdDataFilter1x1Stride1Pad0>{}); + device_grouped_conv_bwd_data_xdl_f16_16_16_instances<2, + GNHWK, + GKYXC, + Empty_Tuple, + GNHWC, + ConvBwdDataFilter1x1Stride1Pad0>{}); } } // namespace instance diff --git a/library/src/tensor_operation_instance/gpu/grouped_conv2d_bwd_data/xdl/device_grouped_conv2d_bwd_data_xdl_gnhwc_gkyxc_gnhwk_f32_instance.cpp b/library/src/tensor_operation_instance/gpu/grouped_conv2d_bwd_data/xdl/device_grouped_conv2d_bwd_data_xdl_gnhwc_gkyxc_gnhwk_f32_instance.cpp index f9351d96f2..96623a5161 100644 --- a/library/src/tensor_operation_instance/gpu/grouped_conv2d_bwd_data/xdl/device_grouped_conv2d_bwd_data_xdl_gnhwc_gkyxc_gnhwk_f32_instance.cpp +++ b/library/src/tensor_operation_instance/gpu/grouped_conv2d_bwd_data/xdl/device_grouped_conv2d_bwd_data_xdl_gnhwc_gkyxc_gnhwk_f32_instance.cpp @@ -1,5 +1,5 @@ // SPDX-License-Identifier: MIT -// Copyright (c) 2018-2023, Advanced Micro Devices, Inc. All rights reserved. +// Copyright (c) 2018-2025, Advanced Micro Devices, Inc. All rights reserved. 
#include "ck/library/tensor_operation_instance/add_device_operation_instance.hpp" #include "ck/library/tensor_operation_instance/gpu/grouped_conv_bwd_data/device_grouped_conv_bwd_data_xdl_instance.hpp" @@ -8,7 +8,7 @@ namespace ck { namespace tensor_operation { namespace device { namespace instance { -// Compilation parameters for out[g, n, hi, wi, c] * wei[g, k, y, x, c] = in[g, n, ho, wo, k] + void add_device_grouped_conv2d_bwd_data_xdl_gnhwk_gkyxc_gnhwc_f32_instances( std::vector{}); + device_grouped_conv_bwd_data_xdl_f32_16_16_instances<2, + GNHWK, + GKYXC, + Empty_Tuple, + GNHWC, + ConvBwdDataDefault>{}); // 2. Filter1x1Stride1Pad0 add_device_operation_instances( instances, - device_grouped_conv_bwd_data_xdl_f32_instances<2, - GNHWK, - GKYXC, - Empty_Tuple, - GNHWC, - ConvBwdDataFilter1x1Stride1Pad0>{}); + device_grouped_conv_bwd_data_xdl_f32_16_16_instances<2, + GNHWK, + GKYXC, + Empty_Tuple, + GNHWC, + ConvBwdDataFilter1x1Stride1Pad0>{}); } } // namespace instance diff --git a/library/src/tensor_operation_instance/gpu/grouped_conv2d_bwd_data/xdl/device_grouped_conv2d_bwd_data_xdl_ngchw_gkcyx_ngkhw_bf16_16_16_instance.cpp b/library/src/tensor_operation_instance/gpu/grouped_conv2d_bwd_data/xdl/device_grouped_conv2d_bwd_data_xdl_ngchw_gkcyx_ngkhw_bf16_16_16_instance.cpp new file mode 100644 index 0000000000..f3aded5043 --- /dev/null +++ b/library/src/tensor_operation_instance/gpu/grouped_conv2d_bwd_data/xdl/device_grouped_conv2d_bwd_data_xdl_ngchw_gkcyx_ngkhw_bf16_16_16_instance.cpp @@ -0,0 +1,40 @@ +// SPDX-License-Identifier: MIT +// Copyright (c) 2025, Advanced Micro Devices, Inc. All rights reserved. 
+ +#include "ck/library/tensor_operation_instance/add_device_operation_instance.hpp" +#include "ck/library/tensor_operation_instance/gpu/grouped_conv_bwd_data/device_grouped_conv_bwd_data_xdl_instance.hpp" +#include "ck/library/tensor_operation_instance/gpu/grouped_conv_bwd_data/device_grouped_conv_bwd_data_transpose_xdl_instance.hpp" + +namespace ck { +namespace tensor_operation { +namespace device { +namespace instance { + +void add_device_grouped_conv2d_bwd_data_xdl_ngkhw_gkcyx_ngchw_bf16_16_16_instances( + std::vector>>& instances) +{ + add_device_operation_instances( + instances, + device_grouped_conv_bwd_data_xdl_bf16_16_16_instances<2, + NGKHW, + GKCYX, + Empty_Tuple, + NGCHW, + ConvBwdDataDefault>{}); +} + +} // namespace instance +} // namespace device +} // namespace tensor_operation +} // namespace ck diff --git a/library/src/tensor_operation_instance/gpu/grouped_conv2d_bwd_data/xdl/device_grouped_conv2d_bwd_data_xdl_ngchw_gkcyx_ngkhw_bf16_instance.cpp b/library/src/tensor_operation_instance/gpu/grouped_conv2d_bwd_data/xdl/device_grouped_conv2d_bwd_data_xdl_ngchw_gkcyx_ngkhw_bf16_instance.cpp index 23aeeaf505..e8c6bc7cbe 100644 --- a/library/src/tensor_operation_instance/gpu/grouped_conv2d_bwd_data/xdl/device_grouped_conv2d_bwd_data_xdl_ngchw_gkcyx_ngkhw_bf16_instance.cpp +++ b/library/src/tensor_operation_instance/gpu/grouped_conv2d_bwd_data/xdl/device_grouped_conv2d_bwd_data_xdl_ngchw_gkcyx_ngkhw_bf16_instance.cpp @@ -9,7 +9,7 @@ namespace ck { namespace tensor_operation { namespace device { namespace instance { -// Compilation parameters for out[n, ho, wo, g, c] * wei[g, k, y, x, c] = in[n, hi, wi, g, k] + void add_device_grouped_conv2d_bwd_data_xdl_ngkhw_gkcyx_ngchw_bf16_instances( std::vector>>& instances) +{ + add_device_operation_instances( + instances, + device_grouped_conv_bwd_data_xdl_f16_16_16_instances<2, + NGKHW, + GKCYX, + Empty_Tuple, + NGCHW, + ConvBwdDataDefault>{}); +} + +} // namespace instance +} // namespace device +} // namespace 
tensor_operation +} // namespace ck diff --git a/library/src/tensor_operation_instance/gpu/grouped_conv2d_bwd_data/xdl/device_grouped_conv2d_bwd_data_xdl_ngchw_gkcyx_ngkhw_f16_instance.cpp b/library/src/tensor_operation_instance/gpu/grouped_conv2d_bwd_data/xdl/device_grouped_conv2d_bwd_data_xdl_ngchw_gkcyx_ngkhw_f16_instance.cpp index beeda26690..3f94d30a55 100644 --- a/library/src/tensor_operation_instance/gpu/grouped_conv2d_bwd_data/xdl/device_grouped_conv2d_bwd_data_xdl_ngchw_gkcyx_ngkhw_f16_instance.cpp +++ b/library/src/tensor_operation_instance/gpu/grouped_conv2d_bwd_data/xdl/device_grouped_conv2d_bwd_data_xdl_ngchw_gkcyx_ngkhw_f16_instance.cpp @@ -9,7 +9,7 @@ namespace ck { namespace tensor_operation { namespace device { namespace instance { -// Compilation parameters for out[n, ho, wo, g, c] * wei[g, k, y, x, c] = in[n, hi, wi, g, k] + void add_device_grouped_conv2d_bwd_data_xdl_ngkhw_gkcyx_ngchw_f16_instances( std::vector>>& instances) +{ + add_device_operation_instances( + instances, + device_grouped_conv_bwd_data_xdl_f32_16_16_instances<2, + NGKHW, + GKCYX, + Empty_Tuple, + NGCHW, + ConvBwdDataDefault>{}); +} + +} // namespace instance +} // namespace device +} // namespace tensor_operation +} // namespace ck diff --git a/library/src/tensor_operation_instance/gpu/grouped_conv2d_bwd_data/xdl/device_grouped_conv2d_bwd_data_xdl_ngchw_gkcyx_ngkhw_f32_instance.cpp b/library/src/tensor_operation_instance/gpu/grouped_conv2d_bwd_data/xdl/device_grouped_conv2d_bwd_data_xdl_ngchw_gkcyx_ngkhw_f32_instance.cpp index a1d768f4eb..b5e89c9b7c 100644 --- a/library/src/tensor_operation_instance/gpu/grouped_conv2d_bwd_data/xdl/device_grouped_conv2d_bwd_data_xdl_ngchw_gkcyx_ngkhw_f32_instance.cpp +++ b/library/src/tensor_operation_instance/gpu/grouped_conv2d_bwd_data/xdl/device_grouped_conv2d_bwd_data_xdl_ngchw_gkcyx_ngkhw_f32_instance.cpp @@ -9,7 +9,7 @@ namespace ck { namespace tensor_operation { namespace device { namespace instance { -// Compilation parameters for 
out[n, ho, wo, g, c] * wei[g, k, y, x, c] = in[n, hi, wi, g, k] + void add_device_grouped_conv2d_bwd_data_xdl_ngkhw_gkcyx_ngchw_f32_instances( std::vector>>& instances) +{ + // 1. Default + add_device_operation_instances( + instances, + device_grouped_conv_bwd_data_xdl_bf16_16_16_instances<2, + NHWGK, + GKYXC, + Empty_Tuple, + NHWGC, + ConvBwdDataDefault>{}); + // 2. Filter1x1Stride1Pad0 + add_device_operation_instances( + instances, + device_grouped_conv_bwd_data_xdl_bf16_16_16_instances<2, + NHWGK, + GKYXC, + Empty_Tuple, + NHWGC, + ConvBwdDataFilter1x1Stride1Pad0>{}); +} + +} // namespace instance +} // namespace device +} // namespace tensor_operation +} // namespace ck diff --git a/library/src/tensor_operation_instance/gpu/grouped_conv2d_bwd_data/xdl/device_grouped_conv2d_bwd_data_xdl_nhwgc_gkyxc_nhwgk_bf16_instance.cpp b/library/src/tensor_operation_instance/gpu/grouped_conv2d_bwd_data/xdl/device_grouped_conv2d_bwd_data_xdl_nhwgc_gkyxc_nhwgk_bf16_instance.cpp index 75e7f61f8a..11e0fc6073 100644 --- a/library/src/tensor_operation_instance/gpu/grouped_conv2d_bwd_data/xdl/device_grouped_conv2d_bwd_data_xdl_nhwgc_gkyxc_nhwgk_bf16_instance.cpp +++ b/library/src/tensor_operation_instance/gpu/grouped_conv2d_bwd_data/xdl/device_grouped_conv2d_bwd_data_xdl_nhwgc_gkyxc_nhwgk_bf16_instance.cpp @@ -8,7 +8,7 @@ namespace ck { namespace tensor_operation { namespace device { namespace instance { -// Compilation parameters for out[n, ho, wo, g, c] * wei[g, k, y, x, c] = in[n, hi, wi, g, k] + void add_device_grouped_conv2d_bwd_data_xdl_nhwgk_gkyxc_nhwgc_bf16_instances( std::vector>>& instances) +{ + // 1. Default + add_device_operation_instances( + instances, + device_grouped_conv_bwd_data_xdl_f16_16_16_instances<2, + NHWGK, + GKYXC, + Empty_Tuple, + NHWGC, + ConvBwdDataDefault>{}); + // 2. 
Filter1x1Stride1Pad0 + add_device_operation_instances( + instances, + device_grouped_conv_bwd_data_xdl_f16_16_16_instances<2, + NHWGK, + GKYXC, + Empty_Tuple, + NHWGC, + ConvBwdDataFilter1x1Stride1Pad0>{}); +} + +} // namespace instance +} // namespace device +} // namespace tensor_operation +} // namespace ck diff --git a/library/src/tensor_operation_instance/gpu/grouped_conv2d_bwd_data/xdl/device_grouped_conv2d_bwd_data_xdl_nhwgc_gkyxc_nhwgk_f16_instance.cpp b/library/src/tensor_operation_instance/gpu/grouped_conv2d_bwd_data/xdl/device_grouped_conv2d_bwd_data_xdl_nhwgc_gkyxc_nhwgk_f16_instance.cpp index 231e894be0..a63dd712b6 100644 --- a/library/src/tensor_operation_instance/gpu/grouped_conv2d_bwd_data/xdl/device_grouped_conv2d_bwd_data_xdl_nhwgc_gkyxc_nhwgk_f16_instance.cpp +++ b/library/src/tensor_operation_instance/gpu/grouped_conv2d_bwd_data/xdl/device_grouped_conv2d_bwd_data_xdl_nhwgc_gkyxc_nhwgk_f16_instance.cpp @@ -8,7 +8,7 @@ namespace ck { namespace tensor_operation { namespace device { namespace instance { -// Compilation parameters for out[n, ho, wo, g, c] * wei[g, k, y, x, c] = in[n, hi, wi, g, k] + void add_device_grouped_conv2d_bwd_data_xdl_nhwgk_gkyxc_nhwgc_f16_instances( std::vector>>& instances) +{ + // 1. Default + add_device_operation_instances( + instances, + device_grouped_conv_bwd_data_xdl_f32_16_16_instances<2, + NHWGK, + GKYXC, + Empty_Tuple, + NHWGC, + ConvBwdDataDefault>{}); + // 2. 
Filter1x1Stride1Pad0 + add_device_operation_instances( + instances, + device_grouped_conv_bwd_data_xdl_f32_16_16_instances<2, + NHWGK, + GKYXC, + Empty_Tuple, + NHWGC, + ConvBwdDataFilter1x1Stride1Pad0>{}); +} + +} // namespace instance +} // namespace device +} // namespace tensor_operation +} // namespace ck diff --git a/library/src/tensor_operation_instance/gpu/grouped_conv2d_bwd_data/xdl/device_grouped_conv2d_bwd_data_xdl_nhwgc_gkyxc_nhwgk_f32_instance.cpp b/library/src/tensor_operation_instance/gpu/grouped_conv2d_bwd_data/xdl/device_grouped_conv2d_bwd_data_xdl_nhwgc_gkyxc_nhwgk_f32_instance.cpp index dbaece1123..e4b4165928 100644 --- a/library/src/tensor_operation_instance/gpu/grouped_conv2d_bwd_data/xdl/device_grouped_conv2d_bwd_data_xdl_nhwgc_gkyxc_nhwgk_f32_instance.cpp +++ b/library/src/tensor_operation_instance/gpu/grouped_conv2d_bwd_data/xdl/device_grouped_conv2d_bwd_data_xdl_nhwgc_gkyxc_nhwgk_f32_instance.cpp @@ -8,7 +8,7 @@ namespace ck { namespace tensor_operation { namespace device { namespace instance { -// Compilation parameters for out[n, ho, wo, g, c] * wei[g, k, y, x, c] = in[n, hi, wi, g, k] + void add_device_grouped_conv2d_bwd_data_xdl_nhwgk_gkyxc_nhwgc_f32_instances( std::vector{}); + device_grouped_conv_bwd_data_xdl_bf16_16_16_instances<3, + GNDHWK, + GKZYXC, + Empty_Tuple, + GNDHWC, + ConvBwdDataDefault>{}); // 2. 
Filter1x1Stride1Pad0 add_device_operation_instances( instances, - device_grouped_conv_bwd_data_xdl_bf16_instances<3, - GNDHWK, - GKZYXC, - Empty_Tuple, - GNDHWC, - ConvBwdDataFilter1x1Stride1Pad0>{}); + device_grouped_conv_bwd_data_xdl_bf16_16_16_instances<3, + GNDHWK, + GKZYXC, + Empty_Tuple, + GNDHWC, + ConvBwdDataFilter1x1Stride1Pad0>{}); } } // namespace instance diff --git a/library/src/tensor_operation_instance/gpu/grouped_conv3d_bwd_data/xdl/device_grouped_conv3d_bwd_data_xdl_gndhwc_gkzyxc_gndhwk_f16_instance.cpp b/library/src/tensor_operation_instance/gpu/grouped_conv3d_bwd_data/xdl/device_grouped_conv3d_bwd_data_xdl_gndhwc_gkzyxc_gndhwk_f16_instance.cpp index 1885d49c81..03b8285631 100644 --- a/library/src/tensor_operation_instance/gpu/grouped_conv3d_bwd_data/xdl/device_grouped_conv3d_bwd_data_xdl_gndhwc_gkzyxc_gndhwk_f16_instance.cpp +++ b/library/src/tensor_operation_instance/gpu/grouped_conv3d_bwd_data/xdl/device_grouped_conv3d_bwd_data_xdl_gndhwc_gkzyxc_gndhwk_f16_instance.cpp @@ -1,5 +1,5 @@ // SPDX-License-Identifier: MIT -// Copyright (c) 2018-2023, Advanced Micro Devices, Inc. All rights reserved. +// Copyright (c) 2018-2025, Advanced Micro Devices, Inc. All rights reserved. #include "ck/library/tensor_operation_instance/add_device_operation_instance.hpp" #include "ck/library/tensor_operation_instance/gpu/grouped_conv_bwd_data/device_grouped_conv_bwd_data_xdl_instance.hpp" @@ -8,7 +8,7 @@ namespace ck { namespace tensor_operation { namespace device { namespace instance { -// Compilation parameters for out[g, n, di, hi, wi, c] * wei[g, k, z, y, x, c] = in[g, n, do, ho, + // wo, k] void add_device_grouped_conv3d_bwd_data_xdl_gndhwk_gkzyxc_gndhwc_f16_instances( std::vector{}); + device_grouped_conv_bwd_data_xdl_f16_16_16_instances<3, + GNDHWK, + GKZYXC, + Empty_Tuple, + GNDHWC, + ConvBwdDataDefault>{}); // 2. 
Filter1x1Stride1Pad0 add_device_operation_instances( instances, - device_grouped_conv_bwd_data_xdl_f16_instances<3, - GNDHWK, - GKZYXC, - Empty_Tuple, - GNDHWC, - ConvBwdDataFilter1x1Stride1Pad0>{}); + device_grouped_conv_bwd_data_xdl_f16_16_16_instances<3, + GNDHWK, + GKZYXC, + Empty_Tuple, + GNDHWC, + ConvBwdDataFilter1x1Stride1Pad0>{}); } } // namespace instance diff --git a/library/src/tensor_operation_instance/gpu/grouped_conv3d_bwd_data/xdl/device_grouped_conv3d_bwd_data_xdl_gndhwc_gkzyxc_gndhwk_f32_instance.cpp b/library/src/tensor_operation_instance/gpu/grouped_conv3d_bwd_data/xdl/device_grouped_conv3d_bwd_data_xdl_gndhwc_gkzyxc_gndhwk_f32_instance.cpp index 77135fcc05..59526ba9bc 100644 --- a/library/src/tensor_operation_instance/gpu/grouped_conv3d_bwd_data/xdl/device_grouped_conv3d_bwd_data_xdl_gndhwc_gkzyxc_gndhwk_f32_instance.cpp +++ b/library/src/tensor_operation_instance/gpu/grouped_conv3d_bwd_data/xdl/device_grouped_conv3d_bwd_data_xdl_gndhwc_gkzyxc_gndhwk_f32_instance.cpp @@ -1,5 +1,5 @@ // SPDX-License-Identifier: MIT -// Copyright (c) 2018-2023, Advanced Micro Devices, Inc. All rights reserved. +// Copyright (c) 2018-2025, Advanced Micro Devices, Inc. All rights reserved. #include "ck/library/tensor_operation_instance/add_device_operation_instance.hpp" #include "ck/library/tensor_operation_instance/gpu/grouped_conv_bwd_data/device_grouped_conv_bwd_data_xdl_instance.hpp" @@ -8,7 +8,7 @@ namespace ck { namespace tensor_operation { namespace device { namespace instance { -// Compilation parameters for out[g, n, di, hi, wi, c] * wei[g, k, z, y, x, c] = in[g, n, do, ho, + // wo, k] void add_device_grouped_conv3d_bwd_data_xdl_gndhwk_gkzyxc_gndhwc_f32_instances( std::vector{}); + device_grouped_conv_bwd_data_xdl_f32_16_16_instances<3, + GNDHWK, + GKZYXC, + Empty_Tuple, + GNDHWC, + ConvBwdDataDefault>{}); // 2. 
Filter1x1Stride1Pad0 add_device_operation_instances( instances, - device_grouped_conv_bwd_data_xdl_f32_instances<3, - GNDHWK, - GKZYXC, - Empty_Tuple, - GNDHWC, - ConvBwdDataFilter1x1Stride1Pad0>{}); + device_grouped_conv_bwd_data_xdl_f32_16_16_instances<3, + GNDHWK, + GKZYXC, + Empty_Tuple, + GNDHWC, + ConvBwdDataFilter1x1Stride1Pad0>{}); } } // namespace instance diff --git a/library/src/tensor_operation_instance/gpu/grouped_conv3d_bwd_data/xdl/device_grouped_conv3d_bwd_data_xdl_ndhwgc_gkzyxc_ndhwgk_bf16_16_16_instance.cpp b/library/src/tensor_operation_instance/gpu/grouped_conv3d_bwd_data/xdl/device_grouped_conv3d_bwd_data_xdl_ndhwgc_gkzyxc_ndhwgk_bf16_16_16_instance.cpp new file mode 100644 index 0000000000..3f90c8b907 --- /dev/null +++ b/library/src/tensor_operation_instance/gpu/grouped_conv3d_bwd_data/xdl/device_grouped_conv3d_bwd_data_xdl_ndhwgc_gkzyxc_ndhwgk_bf16_16_16_instance.cpp @@ -0,0 +1,49 @@ +// SPDX-License-Identifier: MIT +// Copyright (c) 2018-2025, Advanced Micro Devices, Inc. All rights reserved. + +#include "ck/library/tensor_operation_instance/add_device_operation_instance.hpp" +#include "ck/library/tensor_operation_instance/gpu/grouped_conv_bwd_data/device_grouped_conv_bwd_data_xdl_instance.hpp" + +namespace ck { +namespace tensor_operation { +namespace device { +namespace instance { + +void add_device_grouped_conv3d_bwd_data_xdl_ndhwgk_gkzyxc_ndhwgc_bf16_16_16_instances( + std::vector>>& instances) +{ + // 1. Default + add_device_operation_instances( + instances, + device_grouped_conv_bwd_data_xdl_bf16_16_16_instances<3, + NDHWGK, + GKZYXC, + Empty_Tuple, + NDHWGC, + ConvBwdDataDefault>{}); + // 2. 
Filter1x1Stride1Pad0 + add_device_operation_instances( + instances, + device_grouped_conv_bwd_data_xdl_bf16_16_16_instances<3, + NDHWGK, + GKZYXC, + Empty_Tuple, + NDHWGC, + ConvBwdDataFilter1x1Stride1Pad0>{}); +} + +} // namespace instance +} // namespace device +} // namespace tensor_operation +} // namespace ck diff --git a/library/src/tensor_operation_instance/gpu/grouped_conv3d_bwd_data/xdl/device_grouped_conv3d_bwd_data_xdl_ndhwgc_gkzyxc_ndhwgk_bf16_instance.cpp b/library/src/tensor_operation_instance/gpu/grouped_conv3d_bwd_data/xdl/device_grouped_conv3d_bwd_data_xdl_ndhwgc_gkzyxc_ndhwgk_bf16_instance.cpp index 663d41fe0b..f9989dec13 100644 --- a/library/src/tensor_operation_instance/gpu/grouped_conv3d_bwd_data/xdl/device_grouped_conv3d_bwd_data_xdl_ndhwgc_gkzyxc_ndhwgk_bf16_instance.cpp +++ b/library/src/tensor_operation_instance/gpu/grouped_conv3d_bwd_data/xdl/device_grouped_conv3d_bwd_data_xdl_ndhwgc_gkzyxc_ndhwgk_bf16_instance.cpp @@ -1,5 +1,5 @@ // SPDX-License-Identifier: MIT -// Copyright (c) 2018-2023, Advanced Micro Devices, Inc. All rights reserved. +// Copyright (c) 2018-2025, Advanced Micro Devices, Inc. All rights reserved. #include "ck/library/tensor_operation_instance/add_device_operation_instance.hpp" #include "ck/library/tensor_operation_instance/gpu/grouped_conv_bwd_data/device_grouped_conv_bwd_data_xdl_instance.hpp" @@ -8,8 +8,7 @@ namespace ck { namespace tensor_operation { namespace device { namespace instance { -// Compilation parameters for out[n, di, hi, wi, g, c] * wei[g, k, z, y, x, c] = in[n, do, ho, wo, -// g, k] + void add_device_grouped_conv3d_bwd_data_xdl_ndhwgk_gkzyxc_ndhwgc_bf16_instances( std::vector>>& instances) +{ + // 1. Default + add_device_operation_instances( + instances, + device_grouped_conv_bwd_data_xdl_f16_16_16_instances<3, + NDHWGK, + GKZYXC, + Empty_Tuple, + NDHWGC, + ConvBwdDataDefault>{}); + // 2. 
Filter1x1Stride1Pad0 + add_device_operation_instances( + instances, + device_grouped_conv_bwd_data_xdl_f16_16_16_instances<3, + NDHWGK, + GKZYXC, + Empty_Tuple, + NDHWGC, + ConvBwdDataFilter1x1Stride1Pad0>{}); +} + +} // namespace instance +} // namespace device +} // namespace tensor_operation +} // namespace ck diff --git a/library/src/tensor_operation_instance/gpu/grouped_conv3d_bwd_data/xdl/device_grouped_conv3d_bwd_data_xdl_ndhwgc_gkzyxc_ndhwgk_f16_instance.cpp b/library/src/tensor_operation_instance/gpu/grouped_conv3d_bwd_data/xdl/device_grouped_conv3d_bwd_data_xdl_ndhwgc_gkzyxc_ndhwgk_f16_instance.cpp index ac0ab44ce3..071d34b94a 100644 --- a/library/src/tensor_operation_instance/gpu/grouped_conv3d_bwd_data/xdl/device_grouped_conv3d_bwd_data_xdl_ndhwgc_gkzyxc_ndhwgk_f16_instance.cpp +++ b/library/src/tensor_operation_instance/gpu/grouped_conv3d_bwd_data/xdl/device_grouped_conv3d_bwd_data_xdl_ndhwgc_gkzyxc_ndhwgk_f16_instance.cpp @@ -1,5 +1,5 @@ // SPDX-License-Identifier: MIT -// Copyright (c) 2018-2023, Advanced Micro Devices, Inc. All rights reserved. +// Copyright (c) 2018-2025, Advanced Micro Devices, Inc. All rights reserved. #include "ck/library/tensor_operation_instance/add_device_operation_instance.hpp" #include "ck/library/tensor_operation_instance/gpu/grouped_conv_bwd_data/device_grouped_conv_bwd_data_xdl_instance.hpp" @@ -8,8 +8,7 @@ namespace ck { namespace tensor_operation { namespace device { namespace instance { -// Compilation parameters for out[n, di, hi, wi, g, c] * wei[g, k, z, y, x, c] = in[n, do, ho, wo, -// g, k] + void add_device_grouped_conv3d_bwd_data_xdl_ndhwgk_gkzyxc_ndhwgc_f16_instances( std::vector>>& instances) +{ + // 1. Default + add_device_operation_instances( + instances, + device_grouped_conv_bwd_data_xdl_f32_16_16_instances<3, + NDHWGK, + GKZYXC, + Empty_Tuple, + NDHWGC, + ConvBwdDataDefault>{}); + // 2. 
Filter1x1Stride1Pad0 + add_device_operation_instances( + instances, + device_grouped_conv_bwd_data_xdl_f32_16_16_instances<3, + NDHWGK, + GKZYXC, + Empty_Tuple, + NDHWGC, + ConvBwdDataFilter1x1Stride1Pad0>{}); +} + +} // namespace instance +} // namespace device +} // namespace tensor_operation +} // namespace ck diff --git a/library/src/tensor_operation_instance/gpu/grouped_conv3d_bwd_data/xdl/device_grouped_conv3d_bwd_data_xdl_ndhwgc_gkzyxc_ndhwgk_f32_instance.cpp b/library/src/tensor_operation_instance/gpu/grouped_conv3d_bwd_data/xdl/device_grouped_conv3d_bwd_data_xdl_ndhwgc_gkzyxc_ndhwgk_f32_instance.cpp index 50d5cce73d..77127bf7f9 100644 --- a/library/src/tensor_operation_instance/gpu/grouped_conv3d_bwd_data/xdl/device_grouped_conv3d_bwd_data_xdl_ndhwgc_gkzyxc_ndhwgk_f32_instance.cpp +++ b/library/src/tensor_operation_instance/gpu/grouped_conv3d_bwd_data/xdl/device_grouped_conv3d_bwd_data_xdl_ndhwgc_gkzyxc_ndhwgk_f32_instance.cpp @@ -1,5 +1,5 @@ // SPDX-License-Identifier: MIT -// Copyright (c) 2018-2023, Advanced Micro Devices, Inc. All rights reserved. +// Copyright (c) 2018-2025, Advanced Micro Devices, Inc. All rights reserved. 
#include "ck/library/tensor_operation_instance/add_device_operation_instance.hpp" #include "ck/library/tensor_operation_instance/gpu/grouped_conv_bwd_data/device_grouped_conv_bwd_data_xdl_instance.hpp" @@ -8,8 +8,7 @@ namespace ck { namespace tensor_operation { namespace device { namespace instance { -// Compilation parameters for out[n, di, hi, wi, g, c] * wei[g, k, z, y, x, c] = in[n, do, ho, wo, -// g, k] + void add_device_grouped_conv3d_bwd_data_xdl_ndhwgk_gkzyxc_ndhwgc_f32_instances( std::vector>>& instances) +{ + add_device_operation_instances( + instances, + device_grouped_conv_bwd_data_xdl_bf16_16_16_instances<3, + NGKDHW, + GKCZYX, + Empty_Tuple, + NGCDHW, + ConvBwdDataDefault>{}); +} + +} // namespace instance +} // namespace device +} // namespace tensor_operation +} // namespace ck diff --git a/library/src/tensor_operation_instance/gpu/grouped_conv3d_bwd_data/xdl/device_grouped_conv3d_bwd_data_xdl_ngcdhw_gkczyx_ngkdhw_bf16_instance.cpp b/library/src/tensor_operation_instance/gpu/grouped_conv3d_bwd_data/xdl/device_grouped_conv3d_bwd_data_xdl_ngcdhw_gkczyx_ngkdhw_bf16_instance.cpp index a9a6b4d281..943c5bab26 100644 --- a/library/src/tensor_operation_instance/gpu/grouped_conv3d_bwd_data/xdl/device_grouped_conv3d_bwd_data_xdl_ngcdhw_gkczyx_ngkdhw_bf16_instance.cpp +++ b/library/src/tensor_operation_instance/gpu/grouped_conv3d_bwd_data/xdl/device_grouped_conv3d_bwd_data_xdl_ngcdhw_gkczyx_ngkdhw_bf16_instance.cpp @@ -9,8 +9,7 @@ namespace ck { namespace tensor_operation { namespace device { namespace instance { -// Compilation parameters for out[n, di, hi, wi, g, c] * wei[g, k, z, y, x, c] = in[n, do, ho, wo, -// g, k] + void add_device_grouped_conv3d_bwd_data_xdl_ngkdhw_gkczyx_ngcdhw_bf16_instances( std::vector>>& instances) +{ + add_device_operation_instances( + instances, + device_grouped_conv_bwd_data_xdl_f16_16_16_instances<3, + NGKDHW, + GKCZYX, + Empty_Tuple, + NGCDHW, + ConvBwdDataDefault>{}); +} + +} // namespace instance +} // namespace device +} 
// namespace tensor_operation +} // namespace ck diff --git a/library/src/tensor_operation_instance/gpu/grouped_conv3d_bwd_data/xdl/device_grouped_conv3d_bwd_data_xdl_ngcdhw_gkczyx_ngkdhw_f16_instance.cpp b/library/src/tensor_operation_instance/gpu/grouped_conv3d_bwd_data/xdl/device_grouped_conv3d_bwd_data_xdl_ngcdhw_gkczyx_ngkdhw_f16_instance.cpp index eec3944078..bada2507c2 100644 --- a/library/src/tensor_operation_instance/gpu/grouped_conv3d_bwd_data/xdl/device_grouped_conv3d_bwd_data_xdl_ngcdhw_gkczyx_ngkdhw_f16_instance.cpp +++ b/library/src/tensor_operation_instance/gpu/grouped_conv3d_bwd_data/xdl/device_grouped_conv3d_bwd_data_xdl_ngcdhw_gkczyx_ngkdhw_f16_instance.cpp @@ -9,8 +9,7 @@ namespace ck { namespace tensor_operation { namespace device { namespace instance { -// Compilation parameters for out[n, di, hi, wi, g, c] * wei[g, k, z, y, x, c] = in[n, do, ho, wo, -// g, k] + void add_device_grouped_conv3d_bwd_data_xdl_ngkdhw_gkczyx_ngcdhw_f16_instances( std::vector>>& instances) +{ + add_device_operation_instances( + instances, + device_grouped_conv_bwd_data_xdl_f32_16_16_instances<3, + NGKDHW, + GKCZYX, + Empty_Tuple, + NGCDHW, + ConvBwdDataDefault>{}); +} + +} // namespace instance +} // namespace device +} // namespace tensor_operation +} // namespace ck diff --git a/library/src/tensor_operation_instance/gpu/grouped_conv3d_bwd_data/xdl/device_grouped_conv3d_bwd_data_xdl_ngcdhw_gkczyx_ngkdhw_f32_instance.cpp b/library/src/tensor_operation_instance/gpu/grouped_conv3d_bwd_data/xdl/device_grouped_conv3d_bwd_data_xdl_ngcdhw_gkczyx_ngkdhw_f32_instance.cpp index a596482ca8..f1c6f53bf3 100644 --- a/library/src/tensor_operation_instance/gpu/grouped_conv3d_bwd_data/xdl/device_grouped_conv3d_bwd_data_xdl_ngcdhw_gkczyx_ngkdhw_f32_instance.cpp +++ b/library/src/tensor_operation_instance/gpu/grouped_conv3d_bwd_data/xdl/device_grouped_conv3d_bwd_data_xdl_ngcdhw_gkczyx_ngkdhw_f32_instance.cpp @@ -9,8 +9,7 @@ namespace ck { namespace tensor_operation { namespace device { 
namespace instance { -// Compilation parameters for out[n, di, hi, wi, g, c] * wei[g, k, z, y, x, c] = in[n, do, ho, wo, -// g, k] + void add_device_grouped_conv3d_bwd_data_xdl_ngkdhw_gkczyx_ngcdhw_f32_instances( std::vectorGetWorkSpaceSize(argument_ptr.get()); DeviceMem workspace_dev(workspace_sz); @@ -150,7 +154,8 @@ bool profile_grouped_conv_bwd_data_impl(int do_verification, float gb_per_sec = num_btype / 1.E6 / avg_time; std::cout << "Perf: " << std::setw(10) << avg_time << " ms, " << tflops << " TFlops, " - << gb_per_sec << " GB/s, " << op_name << std::endl; + << gb_per_sec << " GB/s, " << op_name << ", SplitK " << split_k_for_run + << std::endl; if(tflops > best_tflops) { @@ -158,13 +163,39 @@ bool profile_grouped_conv_bwd_data_impl(int do_verification, best_tflops = tflops; best_avg_time = avg_time; best_gb_per_sec = gb_per_sec; + best_split_k = split_k_for_run; } if(do_verification) { in_device_buf.FromDevice(in_device.mData.data()); - pass = pass & ck::utils::check_err(in_device, in_host); + using ComputeType = std::conditional_t; + using AccDataType = + std::conditional_t, int32_t, float>; + const index_t num_accums = conv_param.K_; + // Calculate thresholds + auto rtol = ck::utils::get_relative_threshold( + num_accums / split_k_for_run); + auto atol = ck::utils::get_absolute_threshold( + max_accumulated_value / split_k_for_run, num_accums / split_k_for_run); + // Calculate error due to split_k accumulation + auto rtol_split_k = + ck::utils::get_relative_threshold( + split_k_for_run); + auto atol_split_k = + ck::utils::get_absolute_threshold( + max_accumulated_value, split_k_for_run); + // Use higher threshold + rtol = std::max(rtol, rtol_split_k); + atol = std::max(atol, atol_split_k); + + pass = pass & ck::utils::check_err( + in_device, in_host, "Error: Incorrect results!", rtol, atol); + std::cout << "Relative error threshold: " << rtol + << " Absolute error threshold: " << atol << std::endl; if(do_log) { @@ -225,35 +256,47 @@ bool 
profile_grouped_conv_bwd_data_impl(int do_verification, copy(conv_param.input_left_pads_, input_left_pads); copy(conv_param.input_right_pads_, input_right_pads); + std::vector split_k_list = {1, 2, 4, 8, 16, 32, 64, 128}; + + if(split_k > 0) + { + split_k_list = {split_k}; + } + for(auto& op_ptr : op_ptrs) { - auto argument_ptr = - op_ptr->MakeArgumentPointer(static_cast(out_device_buf.GetDeviceBuffer()), - static_cast(wei_device_buf.GetDeviceBuffer()), - {}, - static_cast(in_device_buf.GetDeviceBuffer()), - out_lengths, - out_strides, - wei_lengths, - wei_strides, - {}, - {}, - in_lengths, - in_strides, - conv_filter_strides, - conv_filter_dilations, - input_left_pads, - input_right_pads, - out_element_op, - wei_element_op, - in_element_op); + for(std::size_t split_k_id = 0; split_k_id < split_k_list.size(); split_k_id++) + { + auto argument_ptr = op_ptr->MakeArgumentPointer( + static_cast(out_device_buf.GetDeviceBuffer()), + static_cast(wei_device_buf.GetDeviceBuffer()), + {}, + static_cast(in_device_buf.GetDeviceBuffer()), + out_lengths, + out_strides, + wei_lengths, + wei_strides, + {}, + {}, + in_lengths, + in_strides, + conv_filter_strides, + conv_filter_dilations, + input_left_pads, + input_right_pads, + out_element_op, + wei_element_op, + in_element_op, + split_k_list[split_k_id]); - run_impl(op_ptr, argument_ptr); + run_impl(op_ptr, argument_ptr, split_k_list[split_k_id]); + } } std::cout << "Best configuration parameters:" << "\nname: " << best_op_name << "\navg_time: " << best_avg_time - << "\ntflops: " << best_tflops << "\nGB/s: " << best_gb_per_sec << std::endl; + << "\ntflops: " << best_tflops << "\nGB/s: " << best_gb_per_sec << ", SplitK " + << best_split_k << std::endl; return pass; } diff --git a/profiler/src/profile_grouped_conv_bwd_data.cpp b/profiler/src/profile_grouped_conv_bwd_data.cpp index 1515f1105f..5cdece499e 100644 --- a/profiler/src/profile_grouped_conv_bwd_data.cpp +++ b/profiler/src/profile_grouped_conv_bwd_data.cpp @@ -68,8 +68,8 @@ 
int profile_grouped_conv_bwd_data(int argc, char* argv[]) const bool time_kernel = std::stoi(argv[7]); const int num_dim_spatial = std::stoi(argv[8]); - // 8 for control, 1 for num_dim_spatial, 4 for G/N/K/C, and 6 * num_dim_spatial - if(argc != 8 + 1 + 4 + 6 * num_dim_spatial) + // 8 for control, 1 for num_dim_spatial, 4 for G/N/K/C, and 6 * num_dim_spatial, 1 for split-K + if(argc != 8 + 1 + 4 + 6 * num_dim_spatial + 1) { print_helper_msg(); return 1; @@ -77,6 +77,8 @@ int profile_grouped_conv_bwd_data(int argc, char* argv[]) const auto params = ck::utils::conv::parse_conv_param(num_dim_spatial, 9, argv); + ck::index_t split_k = std::stoi(argv[8 + 1 + 4 + 6 * num_dim_spatial]); + using F32 = float; using F16 = ck::half_t; using BF16 = ck::bhalf_t; @@ -110,7 +112,7 @@ int profile_grouped_conv_bwd_data(int argc, char* argv[]) OutDataType, WeiDataType, InDataType>( - do_verification, init_method, do_log, time_kernel, params); + do_verification, init_method, do_log, time_kernel, params, split_k); return pass ? 
0 : 1; }; diff --git a/script/convert_miopen_driver_to_profiler.py b/script/convert_miopen_driver_to_profiler.py index 1278b6744d..2ddcbb67cd 100644 --- a/script/convert_miopen_driver_to_profiler.py +++ b/script/convert_miopen_driver_to_profiler.py @@ -126,6 +126,8 @@ def run_ck_grouped_conv_bwd_data(args): args.ck_profier_op = "grouped_conv_bwd_data" parse_data_type(args) parse_layouts(args) + # Test all split K value from the list {1, 2, 4, 8, 32, 64, 128} + args.split_k_value = -1 cmd = [str(args.ck_profiler_cmd), str(args.ck_profier_op)] cmd += [str(args.data_type), str(args.layout)] @@ -136,6 +138,7 @@ def run_ck_grouped_conv_bwd_data(args): cmd += [str(args.in_channels)] add_conv_params_to_cmd(args, cmd) + cmd += [str(args.split_k_value)] run_ck_profiler_cmd(cmd) diff --git a/test/grouped_convnd_bwd_data/CMakeLists.txt b/test/grouped_convnd_bwd_data/CMakeLists.txt index 6d78da8db7..5c816da416 100644 --- a/test/grouped_convnd_bwd_data/CMakeLists.txt +++ b/test/grouped_convnd_bwd_data/CMakeLists.txt @@ -2,6 +2,11 @@ add_gtest_executable(test_grouped_convnd_bwd_data_xdl test_grouped_convnd_bwd_da if(result EQUAL 0) target_link_libraries(test_grouped_convnd_bwd_data_xdl PRIVATE utility device_grouped_conv2d_bwd_data_instance device_grouped_conv3d_bwd_data_instance) endif() +if(GPU_TARGETS MATCHES "gfx9") + add_executable(test_grouped_convnd_bwd_data_xdl_large_cases test_grouped_convnd_bwd_data_xdl_large_cases.cpp) + target_compile_options(test_grouped_convnd_bwd_data_xdl_large_cases PRIVATE -Wno-global-constructors -Wno-undef) + target_link_libraries(test_grouped_convnd_bwd_data_xdl_large_cases PRIVATE gtest_main getopt::getopt utility device_grouped_conv2d_bwd_data_instance device_grouped_conv3d_bwd_data_instance) +endif() add_gtest_executable(test_grouped_convnd_bwd_data_wmma test_grouped_convnd_bwd_data_wmma.cpp) if(result EQUAL 0) target_link_libraries(test_grouped_convnd_bwd_data_wmma PRIVATE utility device_grouped_conv2d_bwd_data_instance 
device_grouped_conv3d_bwd_data_instance) diff --git a/test/grouped_convnd_bwd_data/test_grouped_convnd_bwd_data_xdl.cpp b/test/grouped_convnd_bwd_data/test_grouped_convnd_bwd_data_xdl.cpp index eb6083c521..c4404b95ba 100644 --- a/test/grouped_convnd_bwd_data/test_grouped_convnd_bwd_data_xdl.cpp +++ b/test/grouped_convnd_bwd_data/test_grouped_convnd_bwd_data_xdl.cpp @@ -21,26 +21,31 @@ class TestGroupedConvndBwdDataXdl : public ::testing::Test using InLayout = std::tuple_element_t<3, Tuple>; std::vector conv_params; + std::vector split_ks{1, 2}; template void Run() { EXPECT_FALSE(conv_params.empty()); bool pass = true; - for(auto& param : conv_params) + for(auto split_k : split_ks) { - pass = pass && ck::profiler::profile_grouped_conv_bwd_data_impl( - true, // do_verification - 1, // init_method: integer value - false, // do_log - false, // time_kernel - param); + for(auto& param : conv_params) + { + pass = pass && ck::profiler::profile_grouped_conv_bwd_data_impl( + true, // do_verification + 1, // init_method: integer value + false, // do_log + false, // time_kernel + param, + split_k); + } } EXPECT_TRUE(pass); } @@ -92,19 +97,16 @@ TYPED_TEST(TestGroupedConvndBwdDataXdl2d, Test2D) this->conv_params.clear(); this->conv_params.push_back( - {2, 2, 4, 192, 192, {3, 3}, {28, 28}, {1, 1}, {1, 1}, {1, 1}, {1, 1}}); + {2, 2, 2, 192, 192, {3, 3}, {28, 28}, {1, 1}, {1, 1}, {1, 1}, {1, 1}}); this->conv_params.push_back( - {2, 2, 128, 128, 256, {3, 3}, {14, 14}, {1, 1}, {1, 1}, {1, 1}, {1, 1}}); + {2, 2, 2, 128, 256, {3, 3}, {14, 14}, {1, 1}, {1, 1}, {1, 1}, {1, 1}}); this->conv_params.push_back( - {2, 2, 128, 128, 256, {1, 1}, {7, 7}, {2, 2}, {1, 1}, {0, 0}, {0, 0}}); + {2, 2, 2, 128, 256, {1, 1}, {7, 7}, {2, 2}, {1, 1}, {0, 0}, {0, 0}}); this->conv_params.push_back( - {2, 2, 128, 128, 256, {1, 1}, {3, 3}, {1, 1}, {1, 1}, {0, 0}, {0, 0}}); - this->conv_params.push_back({2, 1, 1, 1, 32, {8, 8}, {32, 32}, {1, 1}, {1, 1}, {1, 1}, {1, 1}}); - this->conv_params.push_back({2, 1, 
1, 64, 3, {8, 8}, {32, 32}, {1, 1}, {1, 1}, {1, 1}, {1, 1}}); - this->conv_params.push_back({2, 1, 1, 1, 1, {8, 8}, {32, 32}, {1, 1}, {1, 1}, {1, 1}, {1, 1}}); - // SplitN case - this->conv_params.push_back( - {2, 1, 128, 4, 192, {2, 2}, {224, 224}, {224, 224}, {1, 1}, {0, 0}, {0, 0}}); + {2, 2, 2, 128, 256, {1, 1}, {3, 3}, {1, 1}, {1, 1}, {0, 0}, {0, 0}}); + this->conv_params.push_back({2, 1, 1, 1, 32, {8, 8}, {16, 16}, {1, 1}, {1, 1}, {1, 1}, {1, 1}}); + this->conv_params.push_back({2, 1, 1, 64, 3, {8, 8}, {16, 16}, {1, 1}, {1, 1}, {1, 1}, {1, 1}}); + this->conv_params.push_back({2, 1, 1, 1, 1, {8, 8}, {16, 16}, {1, 1}, {1, 1}, {1, 1}, {1, 1}}); this->template Run<2>(); } @@ -112,28 +114,16 @@ TYPED_TEST(TestGroupedConvndBwdDataXdl3d, Test3D) { this->conv_params.clear(); this->conv_params.push_back( - {3, 2, 16, 128, 256, {1, 1, 1}, {7, 7, 7}, {2, 2, 2}, {1, 1, 1}, {0, 0, 0}, {0, 0, 0}}); + {3, 2, 2, 128, 256, {1, 1, 1}, {7, 7, 7}, {2, 2, 2}, {1, 1, 1}, {0, 0, 0}, {0, 0, 0}}); this->conv_params.push_back( {3, 2, 2, 128, 256, {3, 3, 3}, {14, 14, 3}, {1, 1, 1}, {1, 1, 1}, {1, 1, 1}, {1, 1, 1}}); this->conv_params.push_back( - {3, 2, 32, 128, 256, {1, 1, 1}, {3, 3, 3}, {1, 1, 1}, {1, 1, 1}, {0, 0, 0}, {0, 0, 0}}); + {3, 2, 2, 128, 256, {1, 1, 1}, {3, 3, 3}, {1, 1, 1}, {1, 1, 1}, {0, 0, 0}, {0, 0, 0}}); this->conv_params.push_back( - {3, 1, 1, 1, 32, {3, 3, 3}, {32, 32, 32}, {1, 1, 1}, {1, 1, 1}, {1, 1, 1}, {1, 1, 1}}); + {3, 1, 1, 1, 32, {3, 3, 3}, {4, 16, 16}, {1, 1, 1}, {1, 1, 1}, {1, 1, 1}, {1, 1, 1}}); this->conv_params.push_back( - {3, 1, 1, 64, 3, {3, 3, 3}, {32, 32, 32}, {1, 1, 1}, {1, 1, 1}, {1, 1, 1}, {1, 1, 1}}); + {3, 1, 1, 64, 3, {3, 3, 3}, {4, 16, 16}, {1, 1, 1}, {1, 1, 1}, {1, 1, 1}, {1, 1, 1}}); this->conv_params.push_back( - {3, 1, 1, 1, 1, {3, 3, 3}, {32, 32, 32}, {1, 1, 1}, {1, 1, 1}, {1, 1, 1}, {1, 1, 1}}); - // SplitN case - this->conv_params.push_back({3, - 1, - 128, - 4, - 192, - {2, 2, 2}, - {2, 224, 224}, - {1, 224, 224}, - {1, 1, 1}, - {0, 
0, 0}, - {0, 0, 0}}); + {3, 1, 1, 1, 1, {3, 3, 3}, {4, 16, 16}, {1, 1, 1}, {1, 1, 1}, {1, 1, 1}, {1, 1, 1}}); this->template Run<3>(); } diff --git a/test/grouped_convnd_bwd_data/test_grouped_convnd_bwd_data_xdl_large_cases.cpp b/test/grouped_convnd_bwd_data/test_grouped_convnd_bwd_data_xdl_large_cases.cpp new file mode 100644 index 0000000000..73d793cc5f --- /dev/null +++ b/test/grouped_convnd_bwd_data/test_grouped_convnd_bwd_data_xdl_large_cases.cpp @@ -0,0 +1,120 @@ +// SPDX-License-Identifier: MIT +// Copyright (c) 2025, Advanced Micro Devices, Inc. All rights reserved. + +#include +#include +#include +#include +#include + +#include + +#include "profiler/profile_grouped_conv_bwd_data_impl.hpp" + +template +class TestGroupedConvndBwdDataXdl : public ::testing::Test +{ + protected: + using DataType = std::tuple_element_t<0, Tuple>; + using OutLayout = std::tuple_element_t<1, Tuple>; + using WeiLayout = std::tuple_element_t<2, Tuple>; + using InLayout = std::tuple_element_t<3, Tuple>; + + std::vector conv_params; + std::vector split_ks{1, 2}; + + template + void Run() + { + EXPECT_FALSE(conv_params.empty()); + bool pass = true; + for(auto split_k : split_ks) + { + for(auto& param : conv_params) + { + pass = pass && ck::profiler::profile_grouped_conv_bwd_data_impl( + true, // do_verification + 1, // init_method: integer value + false, // do_log + false, // time_kernel + param, + split_k); + } + } + EXPECT_TRUE(pass); + } +}; + +using namespace ck::tensor_layout::convolution; + +using KernelTypes2d = ::testing::Types, + std::tuple, + std::tuple, + std::tuple, + std::tuple, + std::tuple, + std::tuple, + std::tuple, + std::tuple, + std::tuple, + std::tuple, + std::tuple>; + +using KernelTypes3d = ::testing::Types, + std::tuple, + std::tuple, + std::tuple, + std::tuple, + std::tuple, + std::tuple, + std::tuple, + std::tuple, + std::tuple, + std::tuple, + std::tuple>; + +template +class TestGroupedConvndBwdDataXdl2d : public TestGroupedConvndBwdDataXdl +{ +}; + 
+template +class TestGroupedConvndBwdDataXdl3d : public TestGroupedConvndBwdDataXdl +{ +}; + +TYPED_TEST_SUITE(TestGroupedConvndBwdDataXdl2d, KernelTypes2d); +TYPED_TEST_SUITE(TestGroupedConvndBwdDataXdl3d, KernelTypes3d); + +TYPED_TEST(TestGroupedConvndBwdDataXdl2d, Test2D) +{ + this->conv_params.clear(); + // SplitN case + this->conv_params.push_back( + {2, 1, 128, 4, 192, {2, 2}, {224, 224}, {224, 224}, {1, 1}, {0, 0}, {0, 0}}); + this->template Run<2>(); +} + +TYPED_TEST(TestGroupedConvndBwdDataXdl3d, Test3D) +{ + this->conv_params.clear(); + // SplitN case + this->conv_params.push_back({3, + 1, + 128, + 4, + 192, + {2, 2, 2}, + {2, 224, 224}, + {1, 224, 224}, + {1, 1, 1}, + {0, 0, 0}, + {0, 0, 0}}); + this->template Run<3>(); +} From 768c99eca9e6a4e4edc4e6b920939933eafb4aea Mon Sep 17 00:00:00 2001 From: Khushbu Agarwal Date: Mon, 28 Apr 2025 18:19:23 -0700 Subject: [PATCH 078/443] [TileEngine] Support for sparsity in codegen (#2128) * Added sparsity flag in codegen * remove comments * clan formatted * added sparsity as runtime argument * updated README * updated stream config variable * fix typo for tail_num in hot loop --- tile_engine/ops/gemm/README.md | 37 ++++++++++--------- tile_engine/ops/gemm/gemm_host_api.cpp | 20 ++++++++-- tile_engine/ops/gemm/gemm_host_api.hpp | 1 + tile_engine/ops/gemm/gemm_instance_builder.py | 33 ++++++++++------- 4 files changed, 56 insertions(+), 35 deletions(-) mode change 100644 => 100755 tile_engine/ops/gemm/gemm_host_api.cpp mode change 100644 => 100755 tile_engine/ops/gemm/gemm_host_api.hpp diff --git a/tile_engine/ops/gemm/README.md b/tile_engine/ops/gemm/README.md index 495232f19b..08456a1675 100644 --- a/tile_engine/ops/gemm/README.md +++ b/tile_engine/ops/gemm/README.md @@ -20,24 +20,25 @@ make tile_engine_gemm -j ## tile_engine_gemm inputs ``` - -m m dimension (default:3840) - -n n dimension (default:4096) - -k k dimension (default:2048) - -stride_a Tensor A stride (default:0) - -stride_b Tensor B stride (default:0) 
- -stride_c Tensor C stride (default:0) - -split_k SplitK value (default:1) - -v No validation: 0, Validation on CPU: 1, Validation on GPU: 2 (default:2) - -warmup Number of iterations before benchmark the kernel (default:50) - -repeat Number of iterations to benchmark the kernel (default:100) - -timer gpu:gpu timer, cpu:cpu timer (default:gpu) - -init Value for initializing tensor- random: 0, linear: 1, constant(1): 2 (default:0) - -pipeline possible values are: compv3, compv4, mem (default:compv3) - -scheduler possible values are: intrawave, interwave (default:intrawave) - -epilogue possible values are: cshuffle, default (default:cshuffle) - -pad_m Pad in m direction - true/false (default:false) - -pad_n Pad in n direction - true/false (default:false) - -pad_k Pad in k direction - true/false (default:false) + -m m dimension (default:3840) + -n n dimension (default:4096) + -k k dimension (default:2048) + -stride_a Tensor A stride (default:0) + -stride_b Tensor B stride (default:0) + -stride_c Tensor C stride (default:0) + -split_k SplitK value (default:1) + -v No validation: 0, Validation on CPU: 1, Validation on GPU: 2 (default:2) + -warmup Number of iterations before benchmark the kernel (default:50) + -repeat Number of iterations to benchmark the kernel (default:100) + -timer gpu:gpu timer, cpu:cpu timer (default:gpu) + -init Value for initializing tensor- random: 0, linear: 1, constant(1): 2 (default:0) +-structured_sparsity Sparsity for tensor - 0:false, 1:true (default: 0) + -pipeline possible values are: compv3, compv4, mem (default:compv3) + -scheduler possible values are: intrawave, interwave (default:intrawave) + -epilogue possible values are: cshuffle, default (default:cshuffle) + -pad_m Pad in m direction - true/false (default:false) + -pad_n Pad in n direction - true/false (default:false) + -pad_k Pad in k direction - true/false (default:false) Note: pipeline, scheduler, epilogue, pad_m, pad_n, pad_k should be one of the options specified in 
instance_combination.json ``` diff --git a/tile_engine/ops/gemm/gemm_host_api.cpp b/tile_engine/ops/gemm/gemm_host_api.cpp old mode 100644 new mode 100755 index 3cef425a51..a5447cd658 --- a/tile_engine/ops/gemm/gemm_host_api.cpp +++ b/tile_engine/ops/gemm/gemm_host_api.cpp @@ -10,12 +10,19 @@ void gemm_kernel_launch(ck_tile::DeviceMem& c_m_n_dev_buf, ck_tile::HostTensor& c_m_n_host_result, ck_tile::HostTensor& c_m_n_dev_result, int verify, + bool structured_sparsity, KernelTraits& trait, ck_tile::GemmHostArgs& args, - const ck_tile::stream_config& s) + const ck_tile::stream_config& stream) { - return GemmDispatcher::dispatch( - c_m_n_dev_buf, c_m_n_host_result, c_m_n_dev_result, verify, trait, args, s); + return GemmDispatcher::dispatch(c_m_n_dev_buf, + c_m_n_host_result, + c_m_n_dev_result, + verify, + structured_sparsity, + trait, + args, + stream); } template {}(a_m_k); + } + ck_tile::DeviceMem a_m_k_dev_buf(a_m_k.get_element_space_size_in_bytes()); ck_tile::DeviceMem b_k_n_dev_buf(b_k_n.get_element_space_size_in_bytes()); ck_tile::DeviceMem c_m_n_dev_buf(c_m_n_dev_result.get_element_space_size_in_bytes()); @@ -153,6 +166,7 @@ void run(const ck_tile::ArgParser& arg_parser) c_m_n_host_result, c_m_n_dev_result, verify, + structured_sparsity, trait, gemm_args, ck_tile::stream_config{nullptr, true, 1, n_warmup, n_repeat}); diff --git a/tile_engine/ops/gemm/gemm_host_api.hpp b/tile_engine/ops/gemm/gemm_host_api.hpp old mode 100644 new mode 100755 index c1e1e1dc4f..579d2770db --- a/tile_engine/ops/gemm/gemm_host_api.hpp +++ b/tile_engine/ops/gemm/gemm_host_api.hpp @@ -118,6 +118,7 @@ inline auto create_args(int argc, char* argv[]) .insert("repeat", "100", "number of iterations to benchmark the kernel") .insert("timer", "gpu", "gpu:gpu timer, cpu:cpu timer") .insert("init", "0", "0:random, 1:linear, 2:constant(1)") + .insert("structured_sparsity", "0", "0:false, 1:true") .insert("pipeline", "compv3", "compv3, compv4, mem") .insert("scheduler", "intrawave", "intrawave, 
interwave") .insert("epilogue", "cshuffle", "cshuffle, default") diff --git a/tile_engine/ops/gemm/gemm_instance_builder.py b/tile_engine/ops/gemm/gemm_instance_builder.py index cfefd38cd2..b6c7685fb2 100755 --- a/tile_engine/ops/gemm/gemm_instance_builder.py +++ b/tile_engine/ops/gemm/gemm_instance_builder.py @@ -69,7 +69,7 @@ HOT_LOOP_FALSE = """ else if(tail_num == ck_tile::TailNumber::Even) { Run(ck_tile::bool_constant{}, - ck_tile::integral_constant{}); + ck_tile::integral_constant{}); } else { @@ -347,7 +347,8 @@ namespace {group_name} {{ return f""" template + int WarpTileM, int WarpTileN, int WarpTileK, + bool structured_sparsity> struct GemmKernel {{ static constexpr bool kPadM = {BOOL_MAP(kPadM)}; static constexpr bool kPadN = {BOOL_MAP(kPadN)}; @@ -356,7 +357,7 @@ struct GemmKernel {{ static float launch(ck_tile::GemmHostArgs& args, const ck_tile::stream_config& s) {{ static constexpr bool permuteA = false; static constexpr bool permuteB = false; - static constexpr bool DoubleSmemBuffer = false; + static constexpr bool DoubleSmemBuffer ={"true" if pipeline == "compv4" else "false"}; static constexpr bool TransposeC = false; static constexpr int kBlockPerCu = 1; @@ -381,7 +382,7 @@ struct GemmKernel {{ using GemmUniversalTraits = ck_tile::TileGemmUniversalTraits; + ALayout, BLayout, CLayout, TransposeC, structured_sparsity>; using GemmPipelineProblem = ck_tile::GemmPipelineProblem; @@ -494,7 +495,7 @@ struct GemmDispatcher { return kernel_map; } - static void init() { + static void init(bool structured_sparsity) { auto& kernel_map = get_kernel_map(); if(!kernel_map.empty()) return; \n""" @@ -513,11 +514,11 @@ struct GemmDispatcher { for group in self.all_kernels: - content += f""" kernel_map["{group}"] = [](ck_tile::DeviceMem& c_m_n_dev_buf, + content += f""" kernel_map["{group}"] = [=](ck_tile::DeviceMem& c_m_n_dev_buf, ck_tile::HostTensor& c_m_n_host_result, ck_tile::HostTensor& c_m_n_dev_result, int verify, ck_tile::GemmHostArgs& args, - const 
ck_tile::stream_config& s) {{ + const ck_tile::stream_config& stream) {{ """ for tile in tile_params: # Check if we have valid tile/warp combinations @@ -526,7 +527,11 @@ struct GemmDispatcher { ((tile[1]/(tile[4] * tile[8]) * tile[4] * tile[8]) != tile[1]): continue content += f""" - run_kernel<{group}::GemmKernel<{tile[0]}, {tile[1]}, {tile[2]}, {tile[3]}, {tile[4]}, {tile[5]}, {tile[6]}, {tile[7]}, {tile[8]}>>(c_m_n_dev_buf, c_m_n_host_result, c_m_n_dev_result, verify, args, s);""" + if(structured_sparsity) {{ + run_kernel<{group}::GemmKernel<{tile[0]}, {tile[1]}, {tile[2]}, {tile[3]}, {tile[4]}, {tile[5]}, {tile[6]}, {tile[7]}, {tile[8]}, {1}>>(c_m_n_dev_buf, c_m_n_host_result, c_m_n_dev_result, verify, args, stream); + }} else {{ + run_kernel<{group}::GemmKernel<{tile[0]}, {tile[1]}, {tile[2]}, {tile[3]}, {tile[4]}, {tile[5]}, {tile[6]}, {tile[7]}, {tile[8]}, {0}>>(c_m_n_dev_buf, c_m_n_host_result, c_m_n_dev_result, verify, args, stream); + }}""" content += f""" }};\n""" @@ -536,9 +541,9 @@ struct GemmDispatcher { static void run_kernel(ck_tile::DeviceMem& c_m_n_dev_buf, ck_tile::HostTensor& c_m_n_host_result, ck_tile::HostTensor& c_m_n_dev_result, - int verify, ck_tile::GemmHostArgs& args, const ck_tile::stream_config& s) + int verify, ck_tile::GemmHostArgs& args, const ck_tile::stream_config& stream) { - float avg_time = Kernel::launch(args, s); + float avg_time = Kernel::launch(args, stream); std::string description = Kernel::get_name(); c_m_n_dev_buf.FromDevice(c_m_n_dev_result.data()); @@ -559,13 +564,13 @@ struct GemmDispatcher { static auto dispatch(ck_tile::DeviceMem& c_m_n_dev_buf, ck_tile::HostTensor& c_m_n_host_result, ck_tile::HostTensor& c_m_n_dev_result, - int verify, const KernelTraits &trait, ck_tile::GemmHostArgs& gemm_args, - const ck_tile::stream_config& s) { - init(); + int verify, bool structured_sparsity, const KernelTraits &trait, ck_tile::GemmHostArgs& gemm_args, + const ck_tile::stream_config& stream) { + init(structured_sparsity); 
const std::string key = assemble_key(trait); auto& kernel_map = get_kernel_map(); if(auto it = kernel_map.find(key); it != kernel_map.end()) { - return it->second(c_m_n_dev_buf, c_m_n_host_result, c_m_n_dev_result, verify,gemm_args, s); + return it->second(c_m_n_dev_buf, c_m_n_host_result, c_m_n_dev_result, verify, gemm_args, stream); } throw std::runtime_error("No suitable kernel found: " + key); } From d107f3c3a53b6582a073e906133a9b05502352e8 Mon Sep 17 00:00:00 2001 From: Khushbu Agarwal Date: Mon, 28 Apr 2025 18:19:50 -0700 Subject: [PATCH 079/443] Support for MFMA_16x16x128 for fp8/bf8 (#2125) * Adding 16x16x128 support for gfx950 * Support for fp8 and bf8 * fix input arguments for MFMA scale instruction * clang-formatted * Fixes for lwpck-3145 (#2138) * Fix lds tile & cmake dep & default epilogue * Fallback BTypeToUse to ADataType in WOQ cases * reverting instance json file * reverting instance json file --------- Co-authored-by: Yi DING --- .../ops/epilogue/cshuffle_epilogue.hpp | 3 +- .../ops/epilogue/default_2d_epilogue.hpp | 21 ++-- .../block/block_universal_gemm_as_bs_cr.hpp | 2 +- include/ck_tile/ops/gemm/warp/warp_gemm.hpp | 12 +++ .../warp/warp_gemm_attribute_mfma_impl.hpp | 98 +++++++++++++++++++ .../ops/gemm/warp/warp_gemm_dispatcher.hpp | 5 + tile_engine/ops/gemm/CMakeLists.txt | 8 +- tile_engine/ops/gemm/gemm_instance_builder.py | 4 +- 8 files changed, 143 insertions(+), 10 deletions(-) diff --git a/include/ck_tile/ops/epilogue/cshuffle_epilogue.hpp b/include/ck_tile/ops/epilogue/cshuffle_epilogue.hpp index 0081edcb2e..225997439e 100644 --- a/include/ck_tile/ops/epilogue/cshuffle_epilogue.hpp +++ b/include/ck_tile/ops/epilogue/cshuffle_epilogue.hpp @@ -49,8 +49,9 @@ struct CShuffleEpilogue using BDataType = remove_cvref_t; using AccDataType = remove_cvref_t; using ODataType = remove_cvref_t; + // Used for weight-only quantization kernel, B would be dequantized to the same data type as A using BTypeToUse = - std::conditional_t, ODataType, 
BDataType>; + std::conditional_t, ADataType, BDataType>; using CLayout = remove_cvref_t; static constexpr index_t kBlockSize = Problem::kBlockSize; static constexpr index_t kMPerBlock = Problem::kMPerBlock; diff --git a/include/ck_tile/ops/epilogue/default_2d_epilogue.hpp b/include/ck_tile/ops/epilogue/default_2d_epilogue.hpp index 6e290fe6d7..1d6a99eb4b 100644 --- a/include/ck_tile/ops/epilogue/default_2d_epilogue.hpp +++ b/include/ck_tile/ops/epilogue/default_2d_epilogue.hpp @@ -25,7 +25,9 @@ struct Default2DEpilogueProblem static constexpr bool UseRawStore = UseRawStore_; }; -template { + using ADataType = remove_cvref_t; + using BDataType = remove_cvref_t; using CLayout = remove_cvref_t; static constexpr index_t kMPerXdl = kMPerXdl_; static constexpr index_t kNPerXdl = kNPerXdl_; @@ -96,17 +100,22 @@ struct Default2DEpilogue template struct DefaultGemm2DEpilogue : public Default2DEpilogue { - using Problem = remove_cvref_t; - using AccDataType = remove_cvref_t; - using ODataType = remove_cvref_t; + using Problem = remove_cvref_t; + using ADataType = remove_cvref_t; + using BDataType = remove_cvref_t; + using AccDataType = remove_cvref_t; + using ODataType = remove_cvref_t; + // Used for weight-only quantization kernel, B would be dequantized to the same data type as A + using BTypeToUse = + std::conditional_t, ADataType, BDataType>; using CLayout = remove_cvref_t; static constexpr index_t kMPerXdl = Problem::kMPerXdl; static constexpr index_t kNPerXdl = Problem::kNPerXdl; static constexpr index_t kKPerXdl = Problem::kKPerXdl; static constexpr index_t isCTransposed = Problem::isCTransposed; - using WG = WarpGemmMfmaDispatcher(BLdsTileDistr)); ALdsTile a_warp_tile_; - ALdsTile b_warp_tile_; + BLdsTile b_warp_tile_; template CK_TILE_DEVICE void LocalPrefetch(const ASmemBlockWindow& a_block_window, diff --git a/include/ck_tile/ops/gemm/warp/warp_gemm.hpp b/include/ck_tile/ops/gemm/warp/warp_gemm.hpp index 4732027e57..22962b9404 100644 --- 
a/include/ck_tile/ops/gemm/warp/warp_gemm.hpp +++ b/include/ck_tile/ops/gemm/warp/warp_gemm.hpp @@ -216,6 +216,18 @@ using WarpGemmMfma_f32_16x16x64_bf8_bf8 = WarpGemmImpl>>; +using WarpGemmMfma_f32_16x16x128_fp8_fp8 = WarpGemmImpl>>; + +using WarpGemmMfma_f32_16x16x128_fp8_bf8 = WarpGemmImpl>>; + +using WarpGemmMfma_f32_16x16x128_bf8_fp8 = WarpGemmImpl>>; + +using WarpGemmMfma_f32_16x16x128_bf8_bf8 = WarpGemmImpl>>; + using WarpGemmMfma_f32_32x32x16_fp8_fp8_CTransposed = WarpGemmImpl>>; diff --git a/include/ck_tile/ops/gemm/warp/warp_gemm_attribute_mfma_impl.hpp b/include/ck_tile/ops/gemm/warp/warp_gemm_attribute_mfma_impl.hpp index 08f813a1e3..cd32f35180 100644 --- a/include/ck_tile/ops/gemm/warp/warp_gemm_attribute_mfma_impl.hpp +++ b/include/ck_tile/ops/gemm/warp/warp_gemm_attribute_mfma_impl.hpp @@ -1342,6 +1342,104 @@ template using WarpGemmAttributeMfmaImpl_f32_32x32x16_bf8_bf8 = WarpGemmAttributeMfmaImpl_f32_32x32x16_f8_base; +template +struct WarpGemmAttributeMfmaImpl_f32_16x16x128_f8_bf8_base +{ + static constexpr WGAttrCtlEnum Ctrl = Ctrl_; + using ADataType = AType_; + using BDataType = BType_; + using CDataType = float; + + using AVecType = ext_vector_t; + using BVecType = ext_vector_t; + using CVecType = ext_vector_t; + + static constexpr index_t kM = 16; + static constexpr index_t kN = 16; + static constexpr index_t kK = 128; + + static constexpr index_t kAMBlock = 1; + static constexpr index_t kBNBlock = 1; + + static constexpr index_t kAMLane = 16; + static constexpr index_t kBNLane = 16; + static constexpr index_t kABKLane = 4; + static constexpr index_t kABKPerLane = 32; + + static constexpr index_t kCMLane = 4; + static constexpr index_t kCNLane = 16; + static constexpr index_t kCM0PerLane = 1; + static constexpr index_t kCM1PerLane = 4; + + // c_vec += a_vec * b_vec + template + CK_TILE_DEVICE void operator()(CVecType& c_vec, + const AVecType& a_vec, + const BVecType& b_vec, + bool_constant = {}) const + { + 
//__builtin_amdgcn_mfma_scale_f32_16x16x128_f8f6f4(a, b, c, cbsz, blgp, opsel, scale_a, + // opsel, scale_b) +#if defined(__gfx950__) + if constexpr(std::is_same_v && std::is_same_v) + c_vec = __builtin_amdgcn_mfma_scale_f32_16x16x128_f8f6f4( + a_vec, b_vec, c_vec, 0, 0, 0, 0, 0, 0); + else if constexpr(std::is_same_v && std::is_same_v) + c_vec = __builtin_amdgcn_mfma_scale_f32_16x16x128_f8f6f4( + a_vec, b_vec, c_vec, 0, 1, 0, 0, 0, 0); + else if constexpr(std::is_same_v && std::is_same_v) + c_vec = __builtin_amdgcn_mfma_scale_f32_16x16x128_f8f6f4( + a_vec, b_vec, c_vec, 1, 0, 0, 0, 0, 0); + else if constexpr(std::is_same_v && std::is_same_v) + c_vec = __builtin_amdgcn_mfma_scale_f32_16x16x128_f8f6f4( + a_vec, b_vec, c_vec, 1, 1, 0, 0, 0, 0); +#else + ck_tile::ignore = c_vec; + ck_tile::ignore = a_vec; + ck_tile::ignore = b_vec; +#endif + } + + // c_vec = a_vec * b_vec + CK_TILE_DEVICE CVecType operator()(const AVecType& a_vec, const BVecType& b_vec) const + { +#if defined(__gfx950__) + if constexpr(std::is_same_v && std::is_same_v) + return bit_cast(__builtin_amdgcn_mfma_scale_f32_16x16x128_f8f6f4( + a_vec, b_vec, CVecType{0.f}, 0, 0, 0, 0, 0, 0)); + else if constexpr(std::is_same_v && std::is_same_v) + return bit_cast(__builtin_amdgcn_mfma_scale_f32_16x16x128_f8f6f4( + a_vec, b_vec, CVecType{0.f}, 0, 1, 0, 0, 0, 0)); + else if constexpr(std::is_same_v && std::is_same_v) + return bit_cast(__builtin_amdgcn_mfma_scale_f32_16x16x128_f8f6f4( + a_vec, b_vec, CVecType{0.f}, 1, 0, 0, 0, 0, 0)); + else if constexpr(std::is_same_v && std::is_same_v) + return bit_cast(__builtin_amdgcn_mfma_scale_f32_16x16x128_f8f6f4( + a_vec, b_vec, CVecType{0.f}, 1, 1, 0, 0, 0, 0)); +#else + ck_tile::ignore = a_vec; + ck_tile::ignore = b_vec; + return CVecType{0.f}; +#endif + } +}; + +template +using WarpGemmAttributeMfmaImpl_f32_16x16x128_fp8_fp8 = + WarpGemmAttributeMfmaImpl_f32_16x16x128_f8_bf8_base; + +template +using WarpGemmAttributeMfmaImpl_f32_16x16x128_fp8_bf8 = + 
WarpGemmAttributeMfmaImpl_f32_16x16x128_f8_bf8_base; + +template +using WarpGemmAttributeMfmaImpl_f32_16x16x128_bf8_fp8 = + WarpGemmAttributeMfmaImpl_f32_16x16x128_f8_bf8_base; + +template +using WarpGemmAttributeMfmaImpl_f32_16x16x128_bf8_bf8 = + WarpGemmAttributeMfmaImpl_f32_16x16x128_f8_bf8_base; + // int8 template struct WarpGemmAttributeMfmaImpl_i32_32x32x16_i8 diff --git a/include/ck_tile/ops/gemm/warp/warp_gemm_dispatcher.hpp b/include/ck_tile/ops/gemm/warp/warp_gemm_dispatcher.hpp index f437ee10c5..0e3342c479 100644 --- a/include/ck_tile/ops/gemm/warp/warp_gemm_dispatcher.hpp +++ b/include/ck_tile/ops/gemm/warp/warp_gemm_dispatcher.hpp @@ -69,6 +69,11 @@ template<> struct WarpGemmMfmaDispatcher struct WarpGemmMfmaDispatcher { using Type = WarpGemmMfma_f32_16x16x64_bf8_bf8; }; template<> struct WarpGemmMfmaDispatcher { using Type = WarpGemmMfma_f32_32x32x16_bf8_bf8_CTransposed; }; +template<> struct WarpGemmMfmaDispatcher { using Type = WarpGemmMfma_f32_16x16x128_fp8_fp8; }; +template<> struct WarpGemmMfmaDispatcher { using Type = WarpGemmMfma_f32_16x16x128_fp8_bf8; }; +template<> struct WarpGemmMfmaDispatcher { using Type = WarpGemmMfma_f32_16x16x128_bf8_fp8; }; +template<> struct WarpGemmMfmaDispatcher { using Type = WarpGemmMfma_f32_16x16x128_bf8_bf8; }; + // clang-format on } // namespace impl diff --git a/tile_engine/ops/gemm/CMakeLists.txt b/tile_engine/ops/gemm/CMakeLists.txt index d28017ca0c..bc613a931e 100644 --- a/tile_engine/ops/gemm/CMakeLists.txt +++ b/tile_engine/ops/gemm/CMakeLists.txt @@ -8,6 +8,10 @@ execute_process( --list_blobs RESULT_VARIABLE ret ) +set_property(DIRECTORY APPEND PROPERTY CMAKE_CONFIGURE_DEPENDS + ${CMAKE_CURRENT_LIST_DIR}/gemm_instance_builder.py + ${CMAKE_CURRENT_LIST_DIR}/configs/instance_combination.json +) if(ret AND NOT ret EQUAL 0) message( FATAL_ERROR "Fail to generate kernels via Python. 
${ret}") @@ -21,7 +25,9 @@ add_custom_command( --working_path ${CMAKE_CURRENT_BINARY_DIR} --json ${CMAKE_CURRENT_LIST_DIR}/configs/instance_combination.json --gen_blobs - DEPENDS ${GEMM_CODEGEN_BLOBS} + DEPENDS ${CMAKE_CURRENT_LIST_DIR}/gemm_instance_builder.py + ${CMAKE_CURRENT_BINARY_DIR}/gemm_instance_blobs.txt + ${CMAKE_CURRENT_LIST_DIR}/configs/instance_combination.json ) set(EXECUTABLE_GEMM_INSTANCE "tile_engine_gemm") diff --git a/tile_engine/ops/gemm/gemm_instance_builder.py b/tile_engine/ops/gemm/gemm_instance_builder.py index b6c7685fb2..b441bdd2d6 100755 --- a/tile_engine/ops/gemm/gemm_instance_builder.py +++ b/tile_engine/ops/gemm/gemm_instance_builder.py @@ -27,7 +27,9 @@ LAYOUT_MAP = {'r' : 'ck_tile::tensor_layout::gemm::RowMajor', DEFAULT_EPILOGUE = """ using GemmEpilogue = ck_tile::DefaultGemm2DEpilogue< - ck_tile::DefaultGemm2DEpilogueProblem Date: Mon, 28 Apr 2025 20:20:47 -0500 Subject: [PATCH 080/443] Add Matrix A and Matrix B Swizzle for LDS in Computev4 policy (#2136) * fixed computev4 policy bug for lds swizzle * added swizzle for input matrix B * Improved ComputeV4 policy and pipeline by swizzling A and B * consolidated LDS descriptor functions in parent struct --- .../gemm_pipeline_ag_bg_cr_comp_v4.hpp | 48 +- ...peline_ag_bg_cr_comp_v4_default_policy.hpp | 50 -- ...emm_universal_pipeline_ag_bg_cr_policy.hpp | 482 +++++++++--------- 3 files changed, 265 insertions(+), 315 deletions(-) diff --git a/include/ck_tile/ops/gemm/pipeline/gemm_pipeline_ag_bg_cr_comp_v4.hpp b/include/ck_tile/ops/gemm/pipeline/gemm_pipeline_ag_bg_cr_comp_v4.hpp index 0e0ee9dbd8..667bb80ce9 100644 --- a/include/ck_tile/ops/gemm/pipeline/gemm_pipeline_ag_bg_cr_comp_v4.hpp +++ b/include/ck_tile/ops/gemm/pipeline/gemm_pipeline_ag_bg_cr_comp_v4.hpp @@ -217,17 +217,17 @@ struct GemmPipelineAgBgCrCompV4 : public BaseGemmPipelineAgBgCrCompV4 ////////////// global window & register ///////////////// // A DRAM tile window for load auto a_copy_dram_window = - 
make_tile_window_linear(a_dram_block_window_tmp.get_bottom_tensor_view(), - make_tuple(number{}, number{}), - a_dram_block_window_tmp.get_window_origin(), - Policy::template MakeADramTileDistribution()); + make_tile_window(a_dram_block_window_tmp.get_bottom_tensor_view(), + make_tuple(number{}, number{}), + a_dram_block_window_tmp.get_window_origin(), + Policy::template MakeADramTileDistribution()); // B DRAM tile window for load auto b_copy_dram_window = - make_tile_window_linear(b_dram_block_window_tmp.get_bottom_tensor_view(), - make_tuple(number{}, number{}), - b_dram_block_window_tmp.get_window_origin(), - Policy::template MakeBDramTileDistribution()); + make_tile_window(b_dram_block_window_tmp.get_bottom_tensor_view(), + make_tuple(number{}, number{}), + b_dram_block_window_tmp.get_window_origin(), + Policy::template MakeBDramTileDistribution()); // A register tile for global load constexpr auto ABlockTileDistr = a_copy_dram_window.get_tile_distribution(); @@ -317,25 +317,25 @@ struct GemmPipelineAgBgCrCompV4 : public BaseGemmPipelineAgBgCrCompV4 BLdsTile b_block_tile1; auto a_lds_ld_window0 = - make_tile_window_linear(a_lds_block0, - make_tuple(number{}, number{}), - {0, 0}, - ALdsTileDistr); + make_tile_window(a_lds_block0, + make_tuple(number{}, number{}), + {0, 0}, + ALdsTileDistr); auto a_lds_ld_window1 = - make_tile_window_linear(a_lds_block1, - make_tuple(number{}, number{}), - {0, 0}, - ALdsTileDistr); + make_tile_window(a_lds_block1, + make_tuple(number{}, number{}), + {0, 0}, + ALdsTileDistr); auto b_lds_ld_window0 = - make_tile_window_linear(b_lds_block0, - make_tuple(number{}, number{}), - {0, 0}, - BLdsTileDistr); + make_tile_window(b_lds_block0, + make_tuple(number{}, number{}), + {0, 0}, + BLdsTileDistr); auto b_lds_ld_window1 = - make_tile_window_linear(b_lds_block1, - make_tuple(number{}, number{}), - {0, 0}, - BLdsTileDistr); + make_tile_window(b_lds_block1, + make_tuple(number{}, number{}), + {0, 0}, + BLdsTileDistr); 
Base::LocalPrefetch(a_block_tile0, a_lds_ld_window0); Base::LocalPrefetch(b_block_tile0, b_lds_ld_window0); diff --git a/include/ck_tile/ops/gemm/pipeline/gemm_pipeline_ag_bg_cr_comp_v4_default_policy.hpp b/include/ck_tile/ops/gemm/pipeline/gemm_pipeline_ag_bg_cr_comp_v4_default_policy.hpp index e528847438..f6920f1c57 100644 --- a/include/ck_tile/ops/gemm/pipeline/gemm_pipeline_ag_bg_cr_comp_v4_default_policy.hpp +++ b/include/ck_tile/ops/gemm/pipeline/gemm_pipeline_ag_bg_cr_comp_v4_default_policy.hpp @@ -17,56 +17,6 @@ namespace ck_tile { struct GemmPipelineAgBgCrCompV4DefaultPolicy : public UniversalGemmBasePolicy { - template - CK_TILE_HOST_DEVICE static constexpr auto MakeALdsBlockDescriptor() - { - using namespace ck_tile; - - constexpr index_t kMPerBlock = Problem::BlockGemmShape::kM; - constexpr index_t kKPerBlock = Problem::BlockGemmShape::kK; - constexpr index_t KPack = GetSmemPackA(); - - constexpr auto a_lds_block_desc_0 = make_naive_tensor_descriptor( - make_tuple(number{}, number{}, number{}), - make_tuple(number{}, number{}, number<1>{}), - number{}, - number<1>{}); - - constexpr auto a_lds_block_desc = transform_tensor_descriptor( - a_lds_block_desc_0, - make_tuple( - make_pass_through_transform(number{}), - make_merge_transform(make_tuple(number{} / KPack, number{}))), - make_tuple(sequence<1>{}, sequence<0, 2>{}), - make_tuple(sequence<0>{}, sequence<1>{})); - - return a_lds_block_desc; - } - - template - CK_TILE_HOST_DEVICE static constexpr auto MakeBLdsBlockDescriptor() - { - constexpr index_t kNPerBlock = Problem::BlockGemmShape::kN; - constexpr index_t kKPerBlock = Problem::BlockGemmShape::kK; - constexpr index_t KPack = GetSmemPackB(); - - constexpr auto b_lds_block_desc_0 = make_naive_tensor_descriptor( - make_tuple(number{}, number{}, number{}), - make_tuple(number<(kNPerBlock)*KPack>{}, number{}, number<1>{}), - number{}, - number<1>{}); - - constexpr auto b_lds_block_desc = transform_tensor_descriptor( - b_lds_block_desc_0, - make_tuple( - 
make_pass_through_transform(number{}), - make_merge_transform(make_tuple(number{}, number{}))), - make_tuple(sequence<1>{}, sequence<0, 2>{}), - make_tuple(sequence<0>{}, sequence<1>{})); - - return b_lds_block_desc; - } - template CK_TILE_HOST_DEVICE static constexpr auto GetBlockGemm() { diff --git a/include/ck_tile/ops/gemm/pipeline/gemm_universal_pipeline_ag_bg_cr_policy.hpp b/include/ck_tile/ops/gemm/pipeline/gemm_universal_pipeline_ag_bg_cr_policy.hpp index b555cf75e0..6890cf2f64 100644 --- a/include/ck_tile/ops/gemm/pipeline/gemm_universal_pipeline_ag_bg_cr_policy.hpp +++ b/include/ck_tile/ops/gemm/pipeline/gemm_universal_pipeline_ag_bg_cr_policy.hpp @@ -19,6 +19,245 @@ struct UniversalGemmBasePolicy static constexpr auto ATileAccessPattern = tile_distribution_pattern::thread_raked; static constexpr auto BTileAccessPattern = tile_distribution_pattern::thread_raked; + template + CK_TILE_HOST_DEVICE static constexpr auto MakeALdsBlockDescriptor() + { + using ADataType = remove_cvref_t; + + constexpr index_t MPerBlock = Problem::BlockGemmShape::kM; + constexpr index_t KPerBlock = Problem::BlockGemmShape::kK; + constexpr index_t KPack = GetSmemPackA(); + + constexpr auto DataTypeSize = sizeof(ADataType); + constexpr auto MLdsLayer = + (32 * 4 / KPerBlock / DataTypeSize) < 1 ? 
1 : (32 * 4 / KPerBlock / DataTypeSize); + + constexpr auto a_lds_block_desc_0 = make_naive_tensor_descriptor( + make_tuple(number{}, + number{}, + number{}), + make_tuple(number{}, number{}, number<1>{}), + number{}, + number<1>{}); + + constexpr auto a_lds_block_desc_permuted = transform_tensor_descriptor( + a_lds_block_desc_0, + make_tuple(make_xor_transform(make_tuple(number{}, + number{})), + make_pass_through_transform(number{})), + make_tuple(sequence<1, 0>{}, sequence<2>{}), + make_tuple(sequence<1, 0>{}, sequence<2>{})); + + constexpr auto a_lds_block_desc_xk0_mnldslayer_mn_xk1 = transform_tensor_descriptor( + a_lds_block_desc_permuted, + make_tuple(make_unmerge_transform( + make_tuple(number{}, number{})), + make_pass_through_transform(number{}), + make_pass_through_transform(number{})), + make_tuple(sequence<0>{}, sequence<1>{}, sequence<2>{}), + make_tuple(sequence<0, 2>{}, sequence<1>{}, sequence<3>{})); + + constexpr auto a_lds_block_desc = transform_tensor_descriptor( + a_lds_block_desc_xk0_mnldslayer_mn_xk1, + make_tuple(make_merge_transform_v3_division_mod( + make_tuple(number{}, number{})), + make_merge_transform_v3_division_mod( + make_tuple(number{}, number{}))), + make_tuple(sequence<1, 0>{}, sequence<2, 3>{}), + make_tuple(sequence<0>{}, sequence<1>{})); + + return a_lds_block_desc; + } + + /** + * @brief Create LDS block descriptor for B tensor. + * + * @tparam Problem Gemm pipeline problem. + * @return B tensor LDS block descriptor. 
+ */ + template + CK_TILE_HOST_DEVICE static constexpr auto MakeBLdsBlockDescriptor() + { + // using BLayout = remove_cvref_t; + using BDataType = remove_cvref_t; + + constexpr index_t NPerBlock = Problem::BlockGemmShape::kN; + constexpr index_t KPerBlock = Problem::BlockGemmShape::kK; + +#if 1 + // if constexpr(std::is_same_v) + { + constexpr index_t KPack = GetSmemPackB(); + constexpr auto BK0 = number{}; + constexpr auto DataTypeSize = sizeof(BDataType); + constexpr auto NLdsLayer = + (32 * 4 / KPerBlock / DataTypeSize) < 1 ? 1 : (32 * 4 / KPerBlock / DataTypeSize); + + constexpr auto b_lds_block_desc_0 = make_naive_tensor_descriptor( + make_tuple( + BK0 * number{}, number{}, number{}), + make_tuple(number{}, number{}, number<1>{}), + number{}, + number<1>{}); + + constexpr auto b_lds_block_desc_permuted = transform_tensor_descriptor( + b_lds_block_desc_0, + make_tuple(make_xor_transform(make_tuple(number{}, + BK0 * number{})), + make_pass_through_transform(number{})), + make_tuple(sequence<1, 0>{}, sequence<2>{}), + make_tuple(sequence<1, 0>{}, sequence<2>{})); + + constexpr auto b_lds_block_desc_bk0_nldslayer_n_bk1 = transform_tensor_descriptor( + b_lds_block_desc_permuted, + make_tuple(make_unmerge_transform(make_tuple(number{}, BK0)), + make_pass_through_transform(number{}), + make_pass_through_transform(number{})), + make_tuple(sequence<0>{}, sequence<1>{}, sequence<2>{}), + make_tuple(sequence<0, 2>{}, sequence<1>{}, sequence<3>{})); + + constexpr auto b_lds_block_desc = transform_tensor_descriptor( + b_lds_block_desc_bk0_nldslayer_n_bk1, + make_tuple(make_merge_transform_v3_division_mod( + make_tuple(number{}, number{})), + make_merge_transform_v3_division_mod(make_tuple(BK0, number{}))), + make_tuple(sequence<1, 0>{}, sequence<2, 3>{}), + make_tuple(sequence<0>{}, sequence<1>{})); + return b_lds_block_desc; + } +#else + else // B is Row Major + { + constexpr index_t BlockSize = Problem::kBlockSize; + constexpr index_t VecLoadSize = GetVectorSizeB(); + 
using TileEncodingPattern = TileDistributionEncodingPattern2D; + + constexpr auto BK0 = number{}; + constexpr auto BK1 = number{}; + // constexpr auto N0 = BBlockTransferThreadClusterLengths_BK0_N_BK1{}.At(I1); + constexpr auto N0 = TileEncodingPattern::X0; + constexpr auto N1 = NPerBlock / N0; + + using WarpTile = typename Problem::BlockGemmShape::WarpTile; + constexpr auto NPerXdl = number{}; + + // constexpr auto KThreadWrite = + // BBlockTransferThreadClusterLengths_BK0_N_BK1{}.At(I0); + constexpr auto KThreadWrite = TileEncodingPattern::Y2; + constexpr auto K0PerThreadWrite = BK0 / KThreadWrite; + constexpr auto KThreadRead = 64 / NPerXdl; + constexpr auto K0PerThreadRead = BK0 / KThreadRead; + + constexpr auto kfold = + (BK1 * N0 * sizeof(BDataType) > 128) ? 1 : 128 / (BK1 * N0 * sizeof(BDataType)); + constexpr auto KThreadReadPerm = + (kfold * K0PerThreadWrite / K0PerThreadRead) > 1 + ? KThreadRead / (kfold * K0PerThreadWrite / K0PerThreadRead) + : KThreadRead; + + // 1<=npair<=n0 + constexpr auto npair = (BK1 * NPerXdl * sizeof(BDataType) > 128) + ? 1 + : ((128 / (BK1 * NPerXdl * sizeof(BDataType))) > N0 + ? 
N0 + : 128 / (BK1 * NPerXdl * sizeof(BDataType))); + + constexpr auto b_lds_block_desc = make_naive_tensor_descriptor_packed( + make_tuple(number{}, + number{}, + number{}, + number{}, + number{}, + BK1)); + + constexpr auto b_lds_block_desc_permuted = transform_tensor_descriptor( + b_lds_block_desc, + make_tuple( + make_pass_through_transform(number{}), + make_pass_through_transform(number{}), + make_xor_transform( + make_tuple(number{}, number{})), + make_pass_through_transform(number{}), + make_pass_through_transform(BK1)), + make_tuple( + sequence<0>{}, sequence<1>{}, sequence<2, 3>{}, sequence<4>{}, sequence<5>{}), + make_tuple( + sequence<0>{}, sequence<1>{}, sequence<2, 3>{}, sequence<4>{}, sequence<5>{})); + + constexpr auto b_lds_block_desc_unmerged = transform_tensor_descriptor( + b_lds_block_desc_permuted, + make_tuple( + make_pass_through_transform(number{}), + make_pass_through_transform(number{}), + make_unmerge_transform(make_tuple(number{}, number{})), + make_unmerge_transform(make_tuple(number{}, number{})), + make_pass_through_transform(number{}), + make_pass_through_transform(BK1)), + make_tuple(sequence<0>{}, + sequence<1>{}, + sequence<2>{}, + sequence<3>{}, + sequence<4>{}, + sequence<5>{}), + make_tuple(sequence<1>{}, + sequence<2>{}, + sequence<0, 3>{}, + sequence<4, 5>{}, + sequence<6>{}, + sequence<7>{})); + + // constexpr auto b_lds_block_desc_bk0_n_bk1 = transform_tensor_descriptor( + // b_lds_block_desc_unmerged, + // make_tuple(make_merge_transform_v3_division_mod( + // make_tuple(number{}, + // number{}, + // number{}, + // number{})), + // make_merge_transform_v3_division_mod( + // make_tuple(number{}, number{}, number{})), + // make_pass_through_transform(BK1)), + // make_tuple(sequence<0, 1, 4, 2>{}, sequence<5, 6, 3>{}, sequence<7>{}), + // make_tuple(sequence<0>{}, sequence<1>{}, sequence<2>{})); + + constexpr auto b_lds_block_desc_kn = transform_tensor_descriptor( + b_lds_block_desc_unmerged, + 
make_tuple(make_merge_transform_v3_division_mod( + make_tuple(number{}, + number{}, + number{}, + number{}, + BK1)), + make_merge_transform_v3_division_mod( + make_tuple(number{}, number{}, number{}))), + make_tuple(sequence<0, 1, 4, 2, 7>{}, sequence<5, 6, 3>{}), + make_tuple(sequence<1>{}, sequence<0>{})); + + // return b_lds_block_desc_bk0_n_bk1; + return b_lds_block_desc_kn; + + // constexpr auto b_lds_block_desc_bk0_n_bk1 = make_naive_tensor_descriptor( + // make_tuple(BK0, number{}, number{}), + // make_tuple(number{}, number{}, number<1>{}), + // number{}, + // number<1>{}); + + // constexpr auto b_lds_block_desc = transform_tensor_descriptor( + // b_lds_block_desc_bk0_n_bk1, + // make_tuple(make_pass_through_transform(number{}), + // make_merge_transform_v3_division_mod(make_tuple(BK0, + // number{}))), + // make_tuple(sequence<1>{}, sequence<0, 2>{}), + // make_tuple(sequence<0>{}, sequence<1>{})); + + // return b_lds_block_desc; + } +#endif + } + /** * @brief Get the maximum global memory vector load size. 
* @@ -301,7 +540,7 @@ struct UniversalGemmBasePolicy template CK_TILE_HOST_DEVICE static constexpr index_t GetSmemSizeA() { - constexpr auto a_lds_desc = Derived::template MakeALdsBlockDescriptor(); + constexpr auto a_lds_desc = MakeALdsBlockDescriptor(); constexpr index_t smem_size_a = integer_least_multiple( sizeof(typename Problem::ADataType) * a_lds_desc.get_element_space_size(), 16); return smem_size_a; @@ -310,7 +549,7 @@ struct UniversalGemmBasePolicy template CK_TILE_HOST_DEVICE static constexpr index_t GetSmemSizeB() { - constexpr auto b_lds_desc = Derived::template MakeBLdsBlockDescriptor(); + constexpr auto b_lds_desc = MakeBLdsBlockDescriptor(); constexpr index_t smem_size_b = integer_least_multiple( sizeof(typename Problem::BDataType) * b_lds_desc.get_element_space_size(), 16); return smem_size_b; @@ -330,245 +569,6 @@ struct UniversalGemmBasePolicy struct UniversalGemmPipelineAgBgCrPolicy : public UniversalGemmBasePolicy { - template - CK_TILE_HOST_DEVICE static constexpr auto MakeALdsBlockDescriptor() - { - using ADataType = remove_cvref_t; - - constexpr index_t MPerBlock = Problem::BlockGemmShape::kM; - constexpr index_t KPerBlock = Problem::BlockGemmShape::kK; - constexpr index_t KPack = GetSmemPackA(); - - constexpr auto DataTypeSize = sizeof(ADataType); - constexpr auto MLdsLayer = - (32 * 4 / KPerBlock / DataTypeSize) < 1 ? 
1 : (32 * 4 / KPerBlock / DataTypeSize); - - constexpr auto a_lds_block_desc_0 = make_naive_tensor_descriptor( - make_tuple(number{}, - number{}, - number{}), - make_tuple(number{}, number{}, number<1>{}), - number{}, - number<1>{}); - - constexpr auto a_lds_block_desc_permuted = transform_tensor_descriptor( - a_lds_block_desc_0, - make_tuple(make_xor_transform(make_tuple(number{}, - number{})), - make_pass_through_transform(number{})), - make_tuple(sequence<1, 0>{}, sequence<2>{}), - make_tuple(sequence<1, 0>{}, sequence<2>{})); - - constexpr auto a_lds_block_desc_xk0_mnldslayer_mn_xk1 = transform_tensor_descriptor( - a_lds_block_desc_permuted, - make_tuple(make_unmerge_transform( - make_tuple(number{}, number{})), - make_pass_through_transform(number{}), - make_pass_through_transform(number{})), - make_tuple(sequence<0>{}, sequence<1>{}, sequence<2>{}), - make_tuple(sequence<0, 2>{}, sequence<1>{}, sequence<3>{})); - - constexpr auto a_lds_block_desc = transform_tensor_descriptor( - a_lds_block_desc_xk0_mnldslayer_mn_xk1, - make_tuple(make_merge_transform_v3_division_mod( - make_tuple(number{}, number{})), - make_merge_transform_v3_division_mod( - make_tuple(number{}, number{}))), - make_tuple(sequence<1, 0>{}, sequence<2, 3>{}), - make_tuple(sequence<0>{}, sequence<1>{})); - - return a_lds_block_desc; - } - - /** - * @brief Create LDS block descriptor for B tensor. - * - * @tparam Problem Gemm pipeline problem. - * @return B tensor LDS block descriptor. 
- */ - template - CK_TILE_HOST_DEVICE static constexpr auto MakeBLdsBlockDescriptor() - { - // using BLayout = remove_cvref_t; - using BDataType = remove_cvref_t; - - constexpr index_t NPerBlock = Problem::BlockGemmShape::kN; - constexpr index_t KPerBlock = Problem::BlockGemmShape::kK; - -#if 1 - // if constexpr(std::is_same_v) - { - constexpr index_t KPack = GetSmemPackB(); - constexpr auto BK0 = number{}; - constexpr auto DataTypeSize = sizeof(BDataType); - constexpr auto NLdsLayer = - (32 * 4 / KPerBlock / DataTypeSize) < 1 ? 1 : (32 * 4 / KPerBlock / DataTypeSize); - - constexpr auto b_lds_block_desc_0 = make_naive_tensor_descriptor( - make_tuple( - BK0 * number{}, number{}, number{}), - make_tuple(number{}, number{}, number<1>{}), - number{}, - number<1>{}); - - constexpr auto b_lds_block_desc_permuted = transform_tensor_descriptor( - b_lds_block_desc_0, - make_tuple(make_xor_transform(make_tuple(number{}, - BK0 * number{})), - make_pass_through_transform(number{})), - make_tuple(sequence<1, 0>{}, sequence<2>{}), - make_tuple(sequence<1, 0>{}, sequence<2>{})); - - constexpr auto b_lds_block_desc_bk0_nldslayer_n_bk1 = transform_tensor_descriptor( - b_lds_block_desc_permuted, - make_tuple(make_unmerge_transform(make_tuple(number{}, BK0)), - make_pass_through_transform(number{}), - make_pass_through_transform(number{})), - make_tuple(sequence<0>{}, sequence<1>{}, sequence<2>{}), - make_tuple(sequence<0, 2>{}, sequence<1>{}, sequence<3>{})); - - constexpr auto b_lds_block_desc = transform_tensor_descriptor( - b_lds_block_desc_bk0_nldslayer_n_bk1, - make_tuple(make_merge_transform_v3_division_mod( - make_tuple(number{}, number{})), - make_merge_transform_v3_division_mod(make_tuple(BK0, number{}))), - make_tuple(sequence<1, 0>{}, sequence<2, 3>{}), - make_tuple(sequence<0>{}, sequence<1>{})); - return b_lds_block_desc; - } -#else - else // B is Row Major - { - constexpr index_t BlockSize = Problem::kBlockSize; - constexpr index_t VecLoadSize = GetVectorSizeB(); - 
using TileEncodingPattern = TileDistributionEncodingPattern2D; - - constexpr auto BK0 = number{}; - constexpr auto BK1 = number{}; - // constexpr auto N0 = BBlockTransferThreadClusterLengths_BK0_N_BK1{}.At(I1); - constexpr auto N0 = TileEncodingPattern::X0; - constexpr auto N1 = NPerBlock / N0; - - using WarpTile = typename Problem::BlockGemmShape::WarpTile; - constexpr auto NPerXdl = number{}; - - // constexpr auto KThreadWrite = - // BBlockTransferThreadClusterLengths_BK0_N_BK1{}.At(I0); - constexpr auto KThreadWrite = TileEncodingPattern::Y2; - constexpr auto K0PerThreadWrite = BK0 / KThreadWrite; - constexpr auto KThreadRead = 64 / NPerXdl; - constexpr auto K0PerThreadRead = BK0 / KThreadRead; - - constexpr auto kfold = - (BK1 * N0 * sizeof(BDataType) > 128) ? 1 : 128 / (BK1 * N0 * sizeof(BDataType)); - constexpr auto KThreadReadPerm = - (kfold * K0PerThreadWrite / K0PerThreadRead) > 1 - ? KThreadRead / (kfold * K0PerThreadWrite / K0PerThreadRead) - : KThreadRead; - - // 1<=npair<=n0 - constexpr auto npair = (BK1 * NPerXdl * sizeof(BDataType) > 128) - ? 1 - : ((128 / (BK1 * NPerXdl * sizeof(BDataType))) > N0 - ? 
N0 - : 128 / (BK1 * NPerXdl * sizeof(BDataType))); - - constexpr auto b_lds_block_desc = make_naive_tensor_descriptor_packed( - make_tuple(number{}, - number{}, - number{}, - number{}, - number{}, - BK1)); - - constexpr auto b_lds_block_desc_permuted = transform_tensor_descriptor( - b_lds_block_desc, - make_tuple( - make_pass_through_transform(number{}), - make_pass_through_transform(number{}), - make_xor_transform( - make_tuple(number{}, number{})), - make_pass_through_transform(number{}), - make_pass_through_transform(BK1)), - make_tuple( - sequence<0>{}, sequence<1>{}, sequence<2, 3>{}, sequence<4>{}, sequence<5>{}), - make_tuple( - sequence<0>{}, sequence<1>{}, sequence<2, 3>{}, sequence<4>{}, sequence<5>{})); - - constexpr auto b_lds_block_desc_unmerged = transform_tensor_descriptor( - b_lds_block_desc_permuted, - make_tuple( - make_pass_through_transform(number{}), - make_pass_through_transform(number{}), - make_unmerge_transform(make_tuple(number{}, number{})), - make_unmerge_transform(make_tuple(number{}, number{})), - make_pass_through_transform(number{}), - make_pass_through_transform(BK1)), - make_tuple(sequence<0>{}, - sequence<1>{}, - sequence<2>{}, - sequence<3>{}, - sequence<4>{}, - sequence<5>{}), - make_tuple(sequence<1>{}, - sequence<2>{}, - sequence<0, 3>{}, - sequence<4, 5>{}, - sequence<6>{}, - sequence<7>{})); - - // constexpr auto b_lds_block_desc_bk0_n_bk1 = transform_tensor_descriptor( - // b_lds_block_desc_unmerged, - // make_tuple(make_merge_transform_v3_division_mod( - // make_tuple(number{}, - // number{}, - // number{}, - // number{})), - // make_merge_transform_v3_division_mod( - // make_tuple(number{}, number{}, number{})), - // make_pass_through_transform(BK1)), - // make_tuple(sequence<0, 1, 4, 2>{}, sequence<5, 6, 3>{}, sequence<7>{}), - // make_tuple(sequence<0>{}, sequence<1>{}, sequence<2>{})); - - constexpr auto b_lds_block_desc_kn = transform_tensor_descriptor( - b_lds_block_desc_unmerged, - 
make_tuple(make_merge_transform_v3_division_mod( - make_tuple(number{}, - number{}, - number{}, - number{}, - BK1)), - make_merge_transform_v3_division_mod( - make_tuple(number{}, number{}, number{}))), - make_tuple(sequence<0, 1, 4, 2, 7>{}, sequence<5, 6, 3>{}), - make_tuple(sequence<1>{}, sequence<0>{})); - - // return b_lds_block_desc_bk0_n_bk1; - return b_lds_block_desc_kn; - - // constexpr auto b_lds_block_desc_bk0_n_bk1 = make_naive_tensor_descriptor( - // make_tuple(BK0, number{}, number{}), - // make_tuple(number{}, number{}, number<1>{}), - // number{}, - // number<1>{}); - - // constexpr auto b_lds_block_desc = transform_tensor_descriptor( - // b_lds_block_desc_bk0_n_bk1, - // make_tuple(make_pass_through_transform(number{}), - // make_merge_transform_v3_division_mod(make_tuple(BK0, - // number{}))), - // make_tuple(sequence<1>{}, sequence<0, 2>{}), - // make_tuple(sequence<0>{}, sequence<1>{})); - - // return b_lds_block_desc; - } -#endif - } - template CK_TILE_HOST_DEVICE static constexpr auto GetBlockGemm() { From 8fcb4dff1af2c44581a01607626927dd23297163 Mon Sep 17 00:00:00 2001 From: Illia Silin <98187287+illsilin@users.noreply.github.com> Date: Tue, 29 Apr 2025 07:35:10 -0700 Subject: [PATCH 081/443] Run CI jobs as user jenkins (#2141) * run CI as jenkins * remove user jenkins from docker image * move inductor installation to a writeable path * add a switch for inductor tests --- Dockerfile | 1 - Jenkinsfile | 16 ++++++++-------- 2 files changed, 8 insertions(+), 9 deletions(-) diff --git a/Dockerfile b/Dockerfile index f77c685000..3cac1dde4c 100644 --- a/Dockerfile +++ b/Dockerfile @@ -9,7 +9,6 @@ ENV APT_KEY_DONT_WARN_ON_DANGEROUS_USAGE=DontWarn # Add rocm repository RUN set -xe && \ - useradd -rm -d /home/jenkins -s /bin/bash -u 1004 jenkins && \ apt-get update && apt-get install -y --allow-unauthenticated apt-utils wget gnupg2 curl && \ curl -fsSL https://repo.radeon.com/rocm/rocm.gpg.key | gpg --dearmor -o 
/etc/apt/trusted.gpg.d/rocm-keyring.gpg diff --git a/Jenkinsfile b/Jenkinsfile index a18374509e..c46e2d53ef 100644 --- a/Jenkinsfile +++ b/Jenkinsfile @@ -395,7 +395,7 @@ def buildHipClangJob(Map conf=[:]){ def prefixpath = conf.get("prefixpath", "/opt/rocm") // Jenkins is complaining about the render group - def dockerOpts="-u root --device=/dev/kfd --device=/dev/dri --group-add video --group-add render --cap-add=SYS_PTRACE --security-opt seccomp=unconfined" + def dockerOpts="--device=/dev/kfd --device=/dev/dri --group-add video --group-add render --cap-add=SYS_PTRACE --security-opt seccomp=unconfined" if (conf.get("enforce_xnack_on", false)) { dockerOpts = dockerOpts + " --env HSA_XNACK=1 " } @@ -464,7 +464,7 @@ def Build_CK(Map conf=[:]){ def prefixpath = conf.get("prefixpath", "/opt/rocm") // Jenkins is complaining about the render group - def dockerOpts="-u root --device=/dev/kfd --device=/dev/dri --group-add video --group-add render --cap-add=SYS_PTRACE --security-opt seccomp=unconfined" + def dockerOpts="--device=/dev/kfd --device=/dev/dri --group-add video --group-add render --cap-add=SYS_PTRACE --security-opt seccomp=unconfined" if (conf.get("enforce_xnack_on", false)) { dockerOpts = dockerOpts + " --env HSA_XNACK=1 " } @@ -527,10 +527,10 @@ def Build_CK(Map conf=[:]){ arch_type = 6 } cmake_build(conf) - if ( !params.BUILD_LEGACY_OS && arch_type == 1 ){ + if ( params.RUN_INDUCTOR_TESTS && !params.BUILD_LEGACY_OS && arch_type == 1 ){ echo "Run inductor codegen tests" sh """ - pip install --break-system-packages --verbose . + pip install --target ${env.WORKSPACE} --break-system-packages --verbose . 
pytest python/test/test_gen_instances.py """ } @@ -625,10 +625,6 @@ def Build_CK(Map conf=[:]){ """ } } - // set ownership of all files and folders to jenkins after all steps completed - dir("build"){ - sh "sudo chown -R jenkins:jenkins ../*" - } } } } @@ -843,6 +839,10 @@ pipeline { name: "BUILD_LEGACY_OS", defaultValue: false, description: "Try building CK with legacy OS dockers: RHEL8 and SLES15 (default: OFF)") + booleanParam( + name: "RUN_INDUCTOR_TESTS", + defaultValue: false, + description: "Run inductor codegen tests (default: OFF)") } environment{ dbuser = "${dbuser}" From 6601931949dc385f78d24c4688369535d0f5315c Mon Sep 17 00:00:00 2001 From: Max Podkorytov <4273004+tenpercent@users.noreply.github.com> Date: Tue, 29 Apr 2025 17:22:38 -0700 Subject: [PATCH 082/443] try building ck4inductor and testing it inside a virtual environment (#2142) use system virtualenv use python-full ubuntu package in docker image --------- Co-authored-by: illsilin Co-authored-by: Illia Silin <98187287+illsilin@users.noreply.github.com> --- Dockerfile | 4 +--- Jenkinsfile | 7 +++++-- 2 files changed, 6 insertions(+), 5 deletions(-) diff --git a/Dockerfile b/Dockerfile index 3cac1dde4c..c629bd034c 100644 --- a/Dockerfile +++ b/Dockerfile @@ -49,9 +49,7 @@ RUN apt-get update && DEBIAN_FRONTEND=noninteractive apt-get install -y --allow- mpich \ net-tools \ pkg-config \ - python3 \ - python3-dev \ - python3-pip \ + python3-full \ redis \ rocm-llvm-dev \ sshpass \ diff --git a/Jenkinsfile b/Jenkinsfile index c46e2d53ef..3e22eb2f01 100644 --- a/Jenkinsfile +++ b/Jenkinsfile @@ -530,8 +530,11 @@ def Build_CK(Map conf=[:]){ if ( params.RUN_INDUCTOR_TESTS && !params.BUILD_LEGACY_OS && arch_type == 1 ){ echo "Run inductor codegen tests" sh """ - pip install --target ${env.WORKSPACE} --break-system-packages --verbose . - pytest python/test/test_gen_instances.py + python3 -m venv ${env.WORKSPACE} + . 
${env.WORKSPACE}/bin/activate + python3 -m pip install pytest build setuptools setuptools_scm + python3 -m pip install . + python3 -m pytest python/test/test_gen_instances.py """ } dir("build"){ From 1aea51d34eb17507b141ac9d6b36516bcc4bc584 Mon Sep 17 00:00:00 2001 From: Aviral Goel Date: Tue, 29 Apr 2025 19:37:07 -0500 Subject: [PATCH 083/443] [Tile Engine] Improved README.md (#2134) * improved tile_engine readme * changed ck tile explanation and json * further improved readme * fixed typo --- tile_engine/ops/gemm/README.md | 58 ++++++++++++++++--- .../gemm/configs/instance_combination.json | 6 +- 2 files changed, 52 insertions(+), 12 deletions(-) diff --git a/tile_engine/ops/gemm/README.md b/tile_engine/ops/gemm/README.md index 08456a1675..f7d86e90fe 100644 --- a/tile_engine/ops/gemm/README.md +++ b/tile_engine/ops/gemm/README.md @@ -1,22 +1,30 @@ # GEMM Matrix Multiplication -Use the files in this folder to generate and build applications that run Matrix multiplications using ck_tile programming based on the kernel parameters mentioned in the config file `./configs/instance_combination.json`. +CK Tile Engine GEMM is used to generate and run GEMM kernels with different combinations of BlockTile sizes, WarpTile sizes, WarpTile mapping for all valid pipelines, schedulers and epilogues. # Kernel Configurations -User needs to provide kernel configuration such as datatype, layout, tile size, warp size, padding, pipeline, scheduler and epilogue in the config file. For reference please see `./configs/instance_combination.json` +Kernel parameters are specified in the `instance_combination.json` file, including matrix layouts, data types, padding settings, pipelines, schedulers, epilogues, and numerical values for tile and warp sizes. 
-## Build -``` -# in the root of ck_tile +Given a valid set of values, tile_engine_gemm will automatically iterate over all possible combinations of BlockTile and WarpTile sizes, as well as the specified pipelines, schedulers, and epilogues from `./configs/instance_combination.json`, and build the corresponding kernels. + + +## Build Instructions +``` bash +# in the root of composable kernel create build directory mkdir build && cd build -# you can replace with the appropriate architecture (for example gfx90a or gfx942) or leave it blank -sh ../script/cmake-ck-dev.sh ../ -# To generate the executable +# build composable kernel +sh ../script/cmake-ck-dev.sh ../ # replace with the appropriate architecture (example gfx942) or leave blank +# generate the executable make tile_engine_gemm -j ``` `tile_engine_gemm` will be located in the `./bin/` directory. +_`tile_engine_gemm` must be rebuilt everytime `instance_combination.json` is modified._ +``` bash +rm -rf tile_engine/ && make tile_engine_gemm -j # rebuild +``` + ## tile_engine_gemm inputs ``` @@ -42,11 +50,43 @@ make tile_engine_gemm -j Note: pipeline, scheduler, epilogue, pad_m, pad_n, pad_k should be one of the options specified in instance_combination.json ``` +Note: In `./configs/instance_combination.json` pipeline, scheduler, epilogue, pad_m, pad_n, pad_k should be from one of the values specified above. ## Example -Below example will run gemm kernel with default dimensions of matrices, for compv3 pipeline, intrawave scheduler and default epilogue with all possible tile sizes mentioned in Config file. +The following JSON file specifies parameters used to generate and build GEMM kernels across all possible combinations of pipelines, schedulers, epilogues with different tile and warp sizes. 
+```json +{ + /// other parameters /// + + "tile_m": { + "values": [256] + }, + "tile_n": { + "values": [256] + }, + "tile_k": { + "values": [64, 32] + }, + + /// other parameters /// + + "pipeline": { + "values": ["compv3", "compv4", "mem"] + }, + "scheduler": { + "values": ["intrawave", "interwave"] + }, + "epilogue": { + "values": ["default", "cshuffle"] + } +} ``` + +At runtime, a specific subset of the generated kernels can be selected using command-line arguments. +``` bash ./bin/tile_engine_gemm -pipeline=compv3 -scheduler=intrawave -epilogue=default ``` +The above command runs kernels configured with the compv3 pipeline, intrawave scheduler, and default epilogue, while sweeping over different BlockTile sizes, WarpTile sizes, and WarpTile mappings. + diff --git a/tile_engine/ops/gemm/configs/instance_combination.json b/tile_engine/ops/gemm/configs/instance_combination.json index e23df11500..66dbdafa11 100644 --- a/tile_engine/ops/gemm/configs/instance_combination.json +++ b/tile_engine/ops/gemm/configs/instance_combination.json @@ -7,10 +7,10 @@ "values": ["c"] }, "layout_c": { - "values": ["r"] + "values": ["r"] }, "datatype": { - "values": ["fp16"] + "values": ["fp16"] }, "tile_m": { "values": [256] @@ -49,7 +49,7 @@ "values": [false] }, "pipeline": { - "values": ["compv3", "mem"] + "values": ["compv3", "compv4", "mem"] }, "scheduler": { "values": ["intrawave", "interwave"] From 23de234dbeb23f9f304b33e6b2de91639da62941 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Bart=C5=82omiej=20Kocot?= Date: Wed, 30 Apr 2025 09:49:37 +0200 Subject: [PATCH 084/443] Add grouped conv fwd 16x16 mfma instruction instances (#2140) * Add grouped conv fwd 16x16 mfma instruction instances * fix * remove oddc * fix * fix --- ...ice_grouped_conv_fwd_xdl_comp_instance.hpp | 23 +++- .../gpu/grouped_convolution_forward.hpp | 8 -- .../gpu/grouped_convolution_forward_wmma.inc | 111 ------------------ .../gpu/grouped_conv2d_fwd/CMakeLists.txt | 4 - 
...ma_gnhwc_gkyxc_gnhwk_f16_oddc_instance.cpp | 40 ------- ...mma_gnhwc_gkyxc_gnhwk_i8_oddc_instance.cpp | 40 ------- ...ma_nhwgc_gkyxc_nhwgk_f16_oddc_instance.cpp | 40 ------- ...mma_nhwgc_gkyxc_nhwgk_i8_oddc_instance.cpp | 40 ------- ...hwgc_gkyxc_nhwgk_bf16_comp_2x_instance.cpp | 9 -- ...l_nhwgc_gkyxc_nhwgk_bf16_comp_instance.cpp | 8 -- ...c_gkyxc_nhwgk_bf16_comp_part2_instance.cpp | 9 -- ...nhwgc_gkyxc_nhwgk_f16_comp_2x_instance.cpp | 9 -- ...dl_nhwgc_gkyxc_nhwgk_f16_comp_instance.cpp | 8 -- ...gc_gkyxc_nhwgk_f16_comp_part2_instance.cpp | 9 -- ...dl_nhwgc_gkyxc_nhwgk_f32_comp_instance.cpp | 10 +- ...l_nhwgc_gkyxc_nhwgk_int8_comp_instance.cpp | 28 +---- ...wd_xdl_gnhwc_gkyxc_gnhwk_bf16_instance.cpp | 10 +- ...fwd_xdl_gnhwc_gkyxc_gnhwk_f16_instance.cpp | 10 +- ...fwd_xdl_gnhwc_gkyxc_gnhwk_f32_instance.cpp | 10 +- ...wd_xdl_nhwgc_gkyxc_nhwgk_bf16_instance.cpp | 10 +- ...fwd_xdl_nhwgc_gkyxc_nhwgk_f16_instance.cpp | 10 +- ...fwd_xdl_nhwgc_gkyxc_nhwgk_f32_instance.cpp | 10 +- ...wd_xdl_nhwgc_gkyxc_nhwgk_int8_instance.cpp | 10 +- ...gc_gkyxc_nhwgk_bf16_mem_inter_instance.cpp | 11 +- ...gc_gkyxc_nhwgk_bf16_mem_intra_instance.cpp | 11 +- ...wgc_gkyxc_nhwgk_f16_mem_inter_instance.cpp | 11 +- ...wgc_gkyxc_nhwgk_f16_mem_intra_instance.cpp | 11 +- ...wgc_gkyxc_nhwgk_f32_mem_inter_instance.cpp | 11 +- ...wgc_gkyxc_nhwgk_f32_mem_intra_instance.cpp | 11 +- ...gc_gkyxc_nhwgk_int8_mem_inter_instance.cpp | 11 +- ...gc_gkyxc_nhwgk_int8_mem_intra_instance.cpp | 11 +- .../gpu/grouped_conv3d_fwd/CMakeLists.txt | 4 - ...gndhwc_gkzyxc_gndhwk_f16_oddc_instance.cpp | 41 ------- ..._gndhwc_gkzyxc_gndhwk_i8_oddc_instance.cpp | 41 ------- ...ndhwgc_gkzyxc_ndhwgk_f16_oddc_instance.cpp | 41 ------- ..._ndhwgc_gkzyxc_ndhwgk_i8_oddc_instance.cpp | 41 ------- 36 files changed, 36 insertions(+), 686 deletions(-) delete mode 100644 library/src/tensor_operation_instance/gpu/grouped_conv2d_fwd/wmma/device_grouped_conv2d_fwd_wmma_gnhwc_gkyxc_gnhwk_f16_oddc_instance.cpp delete mode 100644 
library/src/tensor_operation_instance/gpu/grouped_conv2d_fwd/wmma/device_grouped_conv2d_fwd_wmma_gnhwc_gkyxc_gnhwk_i8_oddc_instance.cpp delete mode 100644 library/src/tensor_operation_instance/gpu/grouped_conv2d_fwd/wmma/device_grouped_conv2d_fwd_wmma_nhwgc_gkyxc_nhwgk_f16_oddc_instance.cpp delete mode 100644 library/src/tensor_operation_instance/gpu/grouped_conv2d_fwd/wmma/device_grouped_conv2d_fwd_wmma_nhwgc_gkyxc_nhwgk_i8_oddc_instance.cpp delete mode 100644 library/src/tensor_operation_instance/gpu/grouped_conv3d_fwd/wmma/device_grouped_conv3d_fwd_wmma_gndhwc_gkzyxc_gndhwk_f16_oddc_instance.cpp delete mode 100644 library/src/tensor_operation_instance/gpu/grouped_conv3d_fwd/wmma/device_grouped_conv3d_fwd_wmma_gndhwc_gkzyxc_gndhwk_i8_oddc_instance.cpp delete mode 100644 library/src/tensor_operation_instance/gpu/grouped_conv3d_fwd/wmma/device_grouped_conv3d_fwd_wmma_ndhwgc_gkzyxc_ndhwgk_f16_oddc_instance.cpp delete mode 100644 library/src/tensor_operation_instance/gpu/grouped_conv3d_fwd/wmma/device_grouped_conv3d_fwd_wmma_ndhwgc_gkzyxc_ndhwgk_i8_oddc_instance.cpp diff --git a/library/include/ck/library/tensor_operation_instance/gpu/grouped_conv_fwd/device_grouped_conv_fwd_xdl_comp_instance.hpp b/library/include/ck/library/tensor_operation_instance/gpu/grouped_conv_fwd/device_grouped_conv_fwd_xdl_comp_instance.hpp index f491474d38..6c0ba2f932 100644 --- a/library/include/ck/library/tensor_operation_instance/gpu/grouped_conv_fwd/device_grouped_conv_fwd_xdl_comp_instance.hpp +++ b/library/include/ck/library/tensor_operation_instance/gpu/grouped_conv_fwd/device_grouped_conv_fwd_xdl_comp_instance.hpp @@ -1,9 +1,10 @@ // SPDX-License-Identifier: MIT -// Copyright (c) 2018-2024, Advanced Micro Devices, Inc. All rights reserved. +// Copyright (c) 2018-2025, Advanced Micro Devices, Inc. All rights reserved. 
#include "ck/ck.hpp" #include "ck/tensor_operation/gpu/device/tensor_layout.hpp" #include "ck/tensor_operation/gpu/device/impl/device_grouped_conv_fwd_multiple_abd_xdl_cshuffle_v3.hpp" +#include "ck/tensor_operation/gpu/device/impl/device_grouped_conv_fwd_multiple_abd_xdl_cshuffle.hpp" #include "ck/tensor_operation/gpu/device/convolution_forward_specialization.hpp" #include "ck/tensor_operation/gpu/device/gemm_specialization.hpp" #include "ck/tensor_operation/gpu/element/element_wise_operation.hpp" @@ -89,7 +90,12 @@ using device_grouped_conv_fwd_xdl_bf16_comp_instances = std::tuple< DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle_V3, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 0, S<8, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 0, 1, 1, S<1, 32, 1, 8>, 8, BlockGemmPipelineScheduler::Interwave, BlockGemmPipelineVersion::v1>, DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle_V3, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 0, S<8, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 0, 1, 1, S<1, 32, 1, 8>, 8, BlockGemmPipelineScheduler::Intrawave, BlockGemmPipelineVersion::v3>, DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle_V3, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 0, S<8, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 0, 1, 1, S<1, 32, 1, 8>, 8, BlockGemmPipelineScheduler::Intrawave, BlockGemmPipelineVersion::v3>, - DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle_V3, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 0, S<8, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 0, 1, 1, S<1, 32, 1, 8>, 8, BlockGemmPipelineScheduler::Intrawave, BlockGemmPipelineVersion::v3> + DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle_V3, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 0, S<8, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 0, 1, 1, S<1, 32, 1, 8>, 8, BlockGemmPipelineScheduler::Intrawave, BlockGemmPipelineVersion::v3>, + // mfma 16x16 + DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle< NDimSpatial,ALayout,BLayout, DsLayout,ELayout, BF16, BF16, F32, BF16, DsLayout, BF16, PassThrough, PassThrough, PassThrough, ConvSpec, GemmMNKPadding, 1,256, 64, 64, 32, 8, 8, 16, 16, 2, 2, S<4, 
64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 1, 8, 1, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 2, 8, 1, 1, 1, S<1, 32, 1, 4>, 1>, + DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle< NDimSpatial,ALayout,BLayout, DsLayout,ELayout, BF16, BF16, F32, BF16, DsLayout, BF16, PassThrough, PassThrough, PassThrough, ConvSpec, GemmMNKPadding, 1,256, 64, 64, 32, 8, 8, 16, 16, 2, 2, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 2, 8, 1, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 1, 8, 1, 1, 1, S<1, 32, 1, 4>, 2>, + DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle< NDimSpatial,ALayout,BLayout, DsLayout,ELayout, BF16, BF16, F32, BF16, DsLayout, BF16, PassThrough, PassThrough, PassThrough, ConvSpec, GemmMNKPadding, 1,256, 64, 64, 32, 8, 8, 16, 16, 2, 2, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 4, 8, 1, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 4, 8, 1, 1, 1, S<1, 32, 1, 4>, 4>, + DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle< NDimSpatial,ALayout,BLayout, DsLayout,ELayout, BF16, BF16, F32, BF16, DsLayout, BF16, PassThrough, PassThrough, PassThrough, ConvSpec, GemmMNKPadding, 1,256, 64, 64, 32, 8, 8, 16, 16, 2, 2, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, 1, 1, S<1, 32, 1, 4>, 8> // clang-format on >; @@ -140,7 +146,12 @@ using device_grouped_conv_fwd_xdl_f16_comp_instances = std::tuple< //########################################| Spatial| Layout| Layout| Layout| Layout| Type| Type| Type| DataType| DataType| Type| Elementwise| Elementwise| Elementwise| Specialization| Specialization| Prefetch| Size| Block| Block| Block| | | XDL| XDL| Per| Per| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraM| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraN| MXdlPerWave| NXdlPerWave| _MBlock_MWaveMPerXdl| ScalarPerVector| //########################################| | | | | | | | | | | | Operation| Operation| Operation| | | Stage| | | | | | | | | Wave| Wave| Lengths_K0_M_K1| ArrangeOrder| | | 
PerVector| PerVector_K1| | Lengths_K0_N_K1| ArrangeOrder| | | PerVector| PerVector_K1| | PerShuffle| PerShuffle| _NBlock_NWaveNPerXdl| _NWaveNPerXdl| //########################################| | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | - DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle_V3, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 0, S<8, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 0, 1, 1, S<1, 32, 1, 8>, 8, BlockGemmPipelineScheduler::Intrawave, BlockGemmPipelineVersion::v4> + DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle_V3, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 0, S<8, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 0, 1, 1, S<1, 32, 1, 8>, 8, BlockGemmPipelineScheduler::Intrawave, BlockGemmPipelineVersion::v4>, + // mfma 16x16 + DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle< NDimSpatial,ALayout,BLayout, DsLayout,ELayout, F16, F16, F32, F16, DsLayout, F16, PassThrough, PassThrough, PassThrough, ConvSpec, GemmMNKPadding, 1,256, 64, 64, 32, 8, 8, 16, 16, 2, 2, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 1, 8, 1, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 2, 8, 1, 1, 1, S<1, 32, 1, 4>, 1>, + DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle< NDimSpatial,ALayout,BLayout, DsLayout,ELayout, F16, F16, F32, F16, DsLayout, F16, PassThrough, PassThrough, PassThrough, ConvSpec, GemmMNKPadding, 1,256, 64, 64, 32, 8, 8, 16, 16, 2, 2, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 2, 8, 1, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 1, 8, 1, 1, 1, S<1, 32, 1, 4>, 2>, + DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle< NDimSpatial,ALayout,BLayout, DsLayout,ELayout, F16, F16, F32, F16, DsLayout, F16, PassThrough, PassThrough, PassThrough, ConvSpec, GemmMNKPadding, 1,256, 64, 64, 32, 8, 8, 16, 16, 2, 2, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 4, 8, 1, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 4, 8, 1, 1, 1, S<1, 32, 1, 4>, 4>, + DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle< NDimSpatial,ALayout,BLayout, DsLayout,ELayout, F16, F16, F32, F16, DsLayout, F16, PassThrough, PassThrough, 
PassThrough, ConvSpec, GemmMNKPadding, 1,256, 64, 64, 32, 8, 8, 16, 16, 2, 2, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, 1, 1, S<1, 32, 1, 4>, 8> // clang-format on >; @@ -184,7 +195,11 @@ using device_grouped_conv_fwd_xdl_f32_comp_instances = std::tuple< DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle_V3, S<1, 0, 2>, S<1, 0, 2>, 2, 4, 4, 0, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 4, 4, 0, 1, 1, S<1, 32, 1, 8>, 8, BlockGemmPipelineScheduler::Intrawave, BlockGemmPipelineVersion::v4>, DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle_V3, S<1, 0, 2>, S<1, 0, 2>, 2, 4, 4, 0, S<8, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 4, 4, 0, 1, 1, S<1, 32, 1, 8>, 8, BlockGemmPipelineScheduler::Intrawave, BlockGemmPipelineVersion::v3>, DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle_V3, S<1, 0, 2>, S<1, 0, 2>, 2, 4, 4, 0, S<8, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 4, 4, 0, 1, 1, S<1, 32, 1, 8>, 8, BlockGemmPipelineScheduler::Intrawave, BlockGemmPipelineVersion::v5>, - DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle_V3, S<1, 0, 2>, S<1, 0, 2>, 2, 4, 4, 0, S<8, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 4, 4, 0, 1, 1, S<1, 32, 1, 8>, 8, BlockGemmPipelineScheduler::Interwave, BlockGemmPipelineVersion::v1> + DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle_V3, S<1, 0, 2>, S<1, 0, 2>, 2, 4, 4, 0, S<8, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 4, 4, 0, 1, 1, S<1, 32, 1, 8>, 8, BlockGemmPipelineScheduler::Interwave, BlockGemmPipelineVersion::v1>, + // mfma 16x16 + DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle< NDimSpatial,ALayout,BLayout, DsLayout,ELayout, F32, F32, F32, F32, DsLayout, F32, PassThrough, PassThrough, PassThrough, ConvSpec, GemmMNKPadding,1,256, 64, 64, 32, 8, 8, 16, 16, 2, 2, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 1, 8, 1, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 2, 8, 1, 1, 1, S<1, 32, 1, 4>, 1>, + DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle< NDimSpatial,ALayout,BLayout, DsLayout,ELayout, F32, F32, F32, F32, DsLayout, F32, PassThrough, PassThrough, 
PassThrough, ConvSpec, GemmMNKPadding,1,256, 64, 64, 32, 8, 8, 16, 16, 2, 2, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 2, 8, 1, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 1, 8, 1, 1, 1, S<1, 32, 1, 4>, 2>, + DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle< NDimSpatial,ALayout,BLayout, DsLayout,ELayout, F32, F32, F32, F32, DsLayout, F32, PassThrough, PassThrough, PassThrough, ConvSpec, GemmMNKPadding,1,256, 64, 64, 32, 8, 8, 16, 16, 2, 2, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 4, 8, 1, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 4, 8, 1, 1, 1, S<1, 32, 1, 4>, 4> // clang-format on >; diff --git a/library/include/ck/library/tensor_operation_instance/gpu/grouped_convolution_forward.hpp b/library/include/ck/library/tensor_operation_instance/gpu/grouped_convolution_forward.hpp index 638a3f98a3..d5eed7592e 100644 --- a/library/include/ck/library/tensor_operation_instance/gpu/grouped_convolution_forward.hpp +++ b/library/include/ck/library/tensor_operation_instance/gpu/grouped_convolution_forward.hpp @@ -599,7 +599,6 @@ struct DeviceOperationInstanceFactory>>& instances); -void add_device_grouped_conv2d_fwd_wmma_nhwgc_gkyxc_nhwgk_i8_oddc_instances( - std::vector>>& instances); - void add_device_grouped_conv2d_fwd_wmma_gnhwc_gkyxc_gnhwk_i8_instances( std::vector>>& instances); -void add_device_grouped_conv2d_fwd_wmma_gnhwc_gkyxc_gnhwk_i8_oddc_instances( - std::vector>>& instances); - void add_device_grouped_conv3d_fwd_wmma_gndhwc_gkzyxc_gndhwk_i8_instances( std::vector>>& instances); -void add_device_grouped_conv3d_fwd_wmma_gndhwc_gkzyxc_gndhwk_i8_oddc_instances( - std::vector>>& instances); - void add_device_grouped_conv3d_fwd_wmma_ndhwgc_gkzyxc_ndhwgk_i8_instances( std::vector>>& instances); -void add_device_grouped_conv3d_fwd_wmma_ndhwgc_gkzyxc_ndhwgk_i8_oddc_instances( - std::vector>>& instances); #endif #ifdef CK_ENABLE_FP16 @@ -291,20 +236,6 @@ void add_device_grouped_conv2d_fwd_wmma_nhwgc_gkyxc_nhwgk_f16_1x1s1p0_instances( PassThrough, PassThrough>>>& instances); -void 
add_device_grouped_conv2d_fwd_wmma_nhwgc_gkyxc_nhwgk_f16_oddc_instances( - std::vector>>& instances); - void add_device_grouped_conv3d_fwd_wmma_gndhwc_gkzyxc_gndhwk_f16_instances( std::vector>>& instances); -void add_device_grouped_conv3d_fwd_wmma_gndhwc_gkzyxc_gndhwk_f16_oddc_instances( - std::vector>>& instances); - void add_device_grouped_conv3d_fwd_wmma_ndhwgc_gkzyxc_ndhwgk_f16_instances( std::vector>>& instances); -void add_device_grouped_conv3d_fwd_wmma_ndhwgc_gkzyxc_ndhwgk_f16_oddc_instances( - std::vector>>& instances); - void add_device_grouped_conv2d_fwd_wmma_gnhwc_gkyxc_gnhwk_f16_instances( std::vector>>& instances); - -void add_device_grouped_conv2d_fwd_wmma_gnhwc_gkyxc_gnhwk_f16_oddc_instances( - std::vector>>& instances); #endif } // namespace instance diff --git a/library/src/tensor_operation_instance/gpu/grouped_conv2d_fwd/CMakeLists.txt b/library/src/tensor_operation_instance/gpu/grouped_conv2d_fwd/CMakeLists.txt index c1790901ec..3a101baac0 100644 --- a/library/src/tensor_operation_instance/gpu/grouped_conv2d_fwd/CMakeLists.txt +++ b/library/src/tensor_operation_instance/gpu/grouped_conv2d_fwd/CMakeLists.txt @@ -87,8 +87,6 @@ add_instance_library(device_grouped_conv2d_fwd_instance wmma/device_grouped_conv2d_fwd_wmma_gnhwc_gkyxc_gnhwk_i8_1x1p0_instance.cpp wmma/device_grouped_conv2d_fwd_wmma_gnhwc_gkyxc_gnhwk_f16_1x1s1p0_instance.cpp wmma/device_grouped_conv2d_fwd_wmma_gnhwc_gkyxc_gnhwk_i8_1x1s1p0_instance.cpp - wmma/device_grouped_conv2d_fwd_wmma_gnhwc_gkyxc_gnhwk_f16_oddc_instance.cpp - wmma/device_grouped_conv2d_fwd_wmma_gnhwc_gkyxc_gnhwk_i8_oddc_instance.cpp ## NHWGC, GKYXC, NHWGK wmma/device_grouped_conv2d_fwd_wmma_nhwgc_gkyxc_nhwgk_f16_instance.cpp wmma/device_grouped_conv2d_fwd_wmma_nhwgc_gkyxc_nhwgk_i8_instance.cpp @@ -96,6 +94,4 @@ add_instance_library(device_grouped_conv2d_fwd_instance wmma/device_grouped_conv2d_fwd_wmma_nhwgc_gkyxc_nhwgk_i8_1x1p0_instance.cpp wmma/device_grouped_conv2d_fwd_wmma_nhwgc_gkyxc_nhwgk_f16_1x1s1p0_instance.cpp 
wmma/device_grouped_conv2d_fwd_wmma_nhwgc_gkyxc_nhwgk_i8_1x1s1p0_instance.cpp - wmma/device_grouped_conv2d_fwd_wmma_nhwgc_gkyxc_nhwgk_f16_oddc_instance.cpp - wmma/device_grouped_conv2d_fwd_wmma_nhwgc_gkyxc_nhwgk_i8_oddc_instance.cpp ) diff --git a/library/src/tensor_operation_instance/gpu/grouped_conv2d_fwd/wmma/device_grouped_conv2d_fwd_wmma_gnhwc_gkyxc_gnhwk_f16_oddc_instance.cpp b/library/src/tensor_operation_instance/gpu/grouped_conv2d_fwd/wmma/device_grouped_conv2d_fwd_wmma_gnhwc_gkyxc_gnhwk_f16_oddc_instance.cpp deleted file mode 100644 index a8f723dfec..0000000000 --- a/library/src/tensor_operation_instance/gpu/grouped_conv2d_fwd/wmma/device_grouped_conv2d_fwd_wmma_gnhwc_gkyxc_gnhwk_f16_oddc_instance.cpp +++ /dev/null @@ -1,40 +0,0 @@ -// SPDX-License-Identifier: MIT -// Copyright (c) 2018-2023, Advanced Micro Devices, Inc. All rights reserved. - -#include "ck/library/tensor_operation_instance/add_device_operation_instance.hpp" -#include "ck/library/tensor_operation_instance/gpu/grouped_conv_fwd/device_grouped_conv_fwd_wmma_instance.hpp" - -namespace ck { -namespace tensor_operation { -namespace device { -namespace instance { -// Compilation parameters for in[g, n, hi, wi, c] * wei[g, k, y, x, c] = out[g, n, ho, wo, k] -void add_device_grouped_conv2d_fwd_wmma_gnhwc_gkyxc_gnhwk_f16_oddc_instances( - std::vector>>& instances) -{ - add_device_operation_instances(instances, - device_grouped_conv_fwd_wmma_f16_instances<2, - GNHWC, - GKYXC, - Empty_Tuple, - GNHWK, - Empty_Tuple, - PassThrough, - ConvFwdOddC>{}); -} - -} // namespace instance -} // namespace device -} // namespace tensor_operation -} // namespace ck diff --git a/library/src/tensor_operation_instance/gpu/grouped_conv2d_fwd/wmma/device_grouped_conv2d_fwd_wmma_gnhwc_gkyxc_gnhwk_i8_oddc_instance.cpp b/library/src/tensor_operation_instance/gpu/grouped_conv2d_fwd/wmma/device_grouped_conv2d_fwd_wmma_gnhwc_gkyxc_gnhwk_i8_oddc_instance.cpp deleted file mode 100644 index 784a118897..0000000000 --- 
a/library/src/tensor_operation_instance/gpu/grouped_conv2d_fwd/wmma/device_grouped_conv2d_fwd_wmma_gnhwc_gkyxc_gnhwk_i8_oddc_instance.cpp +++ /dev/null @@ -1,40 +0,0 @@ -// SPDX-License-Identifier: MIT -// Copyright (c) 2018-2023, Advanced Micro Devices, Inc. All rights reserved. - -#include "ck/library/tensor_operation_instance/add_device_operation_instance.hpp" -#include "ck/library/tensor_operation_instance/gpu/grouped_conv_fwd/device_grouped_conv_fwd_wmma_instance.hpp" - -namespace ck { -namespace tensor_operation { -namespace device { -namespace instance { -// Compilation parameters for in[g, n, hi, wi, c] * wei[g, k, y, x, c] = out[g, n, ho, wo, k] -void add_device_grouped_conv2d_fwd_wmma_gnhwc_gkyxc_gnhwk_i8_oddc_instances( - std::vector>>& instances) -{ - add_device_operation_instances(instances, - device_grouped_conv_fwd_wmma_i8_instances<2, - GNHWC, - GKYXC, - Empty_Tuple, - GNHWK, - Empty_Tuple, - PassThrough, - ConvFwdOddC>{}); -} - -} // namespace instance -} // namespace device -} // namespace tensor_operation -} // namespace ck diff --git a/library/src/tensor_operation_instance/gpu/grouped_conv2d_fwd/wmma/device_grouped_conv2d_fwd_wmma_nhwgc_gkyxc_nhwgk_f16_oddc_instance.cpp b/library/src/tensor_operation_instance/gpu/grouped_conv2d_fwd/wmma/device_grouped_conv2d_fwd_wmma_nhwgc_gkyxc_nhwgk_f16_oddc_instance.cpp deleted file mode 100644 index 8c621543a9..0000000000 --- a/library/src/tensor_operation_instance/gpu/grouped_conv2d_fwd/wmma/device_grouped_conv2d_fwd_wmma_nhwgc_gkyxc_nhwgk_f16_oddc_instance.cpp +++ /dev/null @@ -1,40 +0,0 @@ -// SPDX-License-Identifier: MIT -// Copyright (c) 2018-2023, Advanced Micro Devices, Inc. All rights reserved. 
- -#include "ck/library/tensor_operation_instance/add_device_operation_instance.hpp" -#include "ck/library/tensor_operation_instance/gpu/grouped_conv_fwd/device_grouped_conv_fwd_wmma_instance.hpp" - -namespace ck { -namespace tensor_operation { -namespace device { -namespace instance { -// Compilation parameters for in[n, hi, wi, g, c] * wei[g, k, y, x, c] = out[n, ho, wo, g, k] -void add_device_grouped_conv2d_fwd_wmma_nhwgc_gkyxc_nhwgk_f16_oddc_instances( - std::vector>>& instances) -{ - add_device_operation_instances(instances, - device_grouped_conv_fwd_wmma_f16_instances<2, - NHWGC, - GKYXC, - Empty_Tuple, - NHWGK, - Empty_Tuple, - PassThrough, - ConvFwdOddC>{}); -} - -} // namespace instance -} // namespace device -} // namespace tensor_operation -} // namespace ck diff --git a/library/src/tensor_operation_instance/gpu/grouped_conv2d_fwd/wmma/device_grouped_conv2d_fwd_wmma_nhwgc_gkyxc_nhwgk_i8_oddc_instance.cpp b/library/src/tensor_operation_instance/gpu/grouped_conv2d_fwd/wmma/device_grouped_conv2d_fwd_wmma_nhwgc_gkyxc_nhwgk_i8_oddc_instance.cpp deleted file mode 100644 index 5cb313b3ca..0000000000 --- a/library/src/tensor_operation_instance/gpu/grouped_conv2d_fwd/wmma/device_grouped_conv2d_fwd_wmma_nhwgc_gkyxc_nhwgk_i8_oddc_instance.cpp +++ /dev/null @@ -1,40 +0,0 @@ -// SPDX-License-Identifier: MIT -// Copyright (c) 2018-2023, Advanced Micro Devices, Inc. All rights reserved. 
- -#include "ck/library/tensor_operation_instance/add_device_operation_instance.hpp" -#include "ck/library/tensor_operation_instance/gpu/grouped_conv_fwd/device_grouped_conv_fwd_wmma_instance.hpp" - -namespace ck { -namespace tensor_operation { -namespace device { -namespace instance { -// Compilation parameters for in[n, hi, wi, g, c] * wei[g, k, y, x, c] = out[n, ho, wo, g, k] -void add_device_grouped_conv2d_fwd_wmma_nhwgc_gkyxc_nhwgk_i8_oddc_instances( - std::vector>>& instances) -{ - add_device_operation_instances(instances, - device_grouped_conv_fwd_wmma_i8_instances<2, - NHWGC, - GKYXC, - Empty_Tuple, - NHWGK, - Empty_Tuple, - PassThrough, - ConvFwdOddC>{}); -} - -} // namespace instance -} // namespace device -} // namespace tensor_operation -} // namespace ck diff --git a/library/src/tensor_operation_instance/gpu/grouped_conv2d_fwd/xdl/comp/device_grouped_conv2d_fwd_xdl_nhwgc_gkyxc_nhwgk_bf16_comp_2x_instance.cpp b/library/src/tensor_operation_instance/gpu/grouped_conv2d_fwd/xdl/comp/device_grouped_conv2d_fwd_xdl_nhwgc_gkyxc_nhwgk_bf16_comp_2x_instance.cpp index c078f8ed04..f5df7278d0 100644 --- a/library/src/tensor_operation_instance/gpu/grouped_conv2d_fwd/xdl/comp/device_grouped_conv2d_fwd_xdl_nhwgc_gkyxc_nhwgk_bf16_comp_2x_instance.cpp +++ b/library/src/tensor_operation_instance/gpu/grouped_conv2d_fwd/xdl/comp/device_grouped_conv2d_fwd_xdl_nhwgc_gkyxc_nhwgk_bf16_comp_2x_instance.cpp @@ -52,15 +52,6 @@ void add_device_grouped_conv2d_fwd_xdl_nhwgc_gkyxc_nhwgk_bf16_comp_2x_instances( Empty_Tuple, NHWGK, ConvFwd1x1S1P0>{}); - - add_device_operation_instances( - instances, - device_grouped_conv_fwd_xdl_bf16_comp_instances_2x<2, - NHWGC, - GKYXC, - Empty_Tuple, - NHWGK, - ConvFwdOddC>{}); } } diff --git a/library/src/tensor_operation_instance/gpu/grouped_conv2d_fwd/xdl/comp/device_grouped_conv2d_fwd_xdl_nhwgc_gkyxc_nhwgk_bf16_comp_instance.cpp 
b/library/src/tensor_operation_instance/gpu/grouped_conv2d_fwd/xdl/comp/device_grouped_conv2d_fwd_xdl_nhwgc_gkyxc_nhwgk_bf16_comp_instance.cpp index a67b11f1cf..db048679bd 100644 --- a/library/src/tensor_operation_instance/gpu/grouped_conv2d_fwd/xdl/comp/device_grouped_conv2d_fwd_xdl_nhwgc_gkyxc_nhwgk_bf16_comp_instance.cpp +++ b/library/src/tensor_operation_instance/gpu/grouped_conv2d_fwd/xdl/comp/device_grouped_conv2d_fwd_xdl_nhwgc_gkyxc_nhwgk_bf16_comp_instance.cpp @@ -49,14 +49,6 @@ void add_device_grouped_conv2d_fwd_xdl_nhwgc_gkyxc_nhwgk_bf16_comp_instances( Empty_Tuple, NHWGK, ConvFwd1x1S1P0>{}); - - add_device_operation_instances(instances, - device_grouped_conv_fwd_xdl_bf16_comp_instances<2, - NHWGC, - GKYXC, - Empty_Tuple, - NHWGK, - ConvFwdOddC>{}); } } // namespace instance diff --git a/library/src/tensor_operation_instance/gpu/grouped_conv2d_fwd/xdl/comp/device_grouped_conv2d_fwd_xdl_nhwgc_gkyxc_nhwgk_bf16_comp_part2_instance.cpp b/library/src/tensor_operation_instance/gpu/grouped_conv2d_fwd/xdl/comp/device_grouped_conv2d_fwd_xdl_nhwgc_gkyxc_nhwgk_bf16_comp_part2_instance.cpp index 5c0391a25f..ee9507a80a 100644 --- a/library/src/tensor_operation_instance/gpu/grouped_conv2d_fwd/xdl/comp/device_grouped_conv2d_fwd_xdl_nhwgc_gkyxc_nhwgk_bf16_comp_part2_instance.cpp +++ b/library/src/tensor_operation_instance/gpu/grouped_conv2d_fwd/xdl/comp/device_grouped_conv2d_fwd_xdl_nhwgc_gkyxc_nhwgk_bf16_comp_part2_instance.cpp @@ -52,15 +52,6 @@ void add_device_grouped_conv2d_fwd_xdl_nhwgc_gkyxc_nhwgk_bf16_comp_part2_instanc Empty_Tuple, NHWGK, ConvFwd1x1S1P0>{}); - - add_device_operation_instances( - instances, - device_grouped_conv_fwd_xdl_bf16_comp_instances_part2<2, - NHWGC, - GKYXC, - Empty_Tuple, - NHWGK, - ConvFwdOddC>{}); } } diff --git a/library/src/tensor_operation_instance/gpu/grouped_conv2d_fwd/xdl/comp/device_grouped_conv2d_fwd_xdl_nhwgc_gkyxc_nhwgk_f16_comp_2x_instance.cpp 
b/library/src/tensor_operation_instance/gpu/grouped_conv2d_fwd/xdl/comp/device_grouped_conv2d_fwd_xdl_nhwgc_gkyxc_nhwgk_f16_comp_2x_instance.cpp index 726276c461..132d3c8411 100644 --- a/library/src/tensor_operation_instance/gpu/grouped_conv2d_fwd/xdl/comp/device_grouped_conv2d_fwd_xdl_nhwgc_gkyxc_nhwgk_f16_comp_2x_instance.cpp +++ b/library/src/tensor_operation_instance/gpu/grouped_conv2d_fwd/xdl/comp/device_grouped_conv2d_fwd_xdl_nhwgc_gkyxc_nhwgk_f16_comp_2x_instance.cpp @@ -52,15 +52,6 @@ void add_device_grouped_conv2d_fwd_xdl_nhwgc_gkyxc_nhwgk_f16_comp_2x_instances( Empty_Tuple, NHWGK, ConvFwd1x1S1P0>{}); - - add_device_operation_instances( - instances, - device_grouped_conv_fwd_xdl_f16_comp_instances_2x<2, - NHWGC, - GKYXC, - Empty_Tuple, - NHWGK, - ConvFwdOddC>{}); } } diff --git a/library/src/tensor_operation_instance/gpu/grouped_conv2d_fwd/xdl/comp/device_grouped_conv2d_fwd_xdl_nhwgc_gkyxc_nhwgk_f16_comp_instance.cpp b/library/src/tensor_operation_instance/gpu/grouped_conv2d_fwd/xdl/comp/device_grouped_conv2d_fwd_xdl_nhwgc_gkyxc_nhwgk_f16_comp_instance.cpp index 8b7bdec2a8..a7deb969ba 100644 --- a/library/src/tensor_operation_instance/gpu/grouped_conv2d_fwd/xdl/comp/device_grouped_conv2d_fwd_xdl_nhwgc_gkyxc_nhwgk_f16_comp_instance.cpp +++ b/library/src/tensor_operation_instance/gpu/grouped_conv2d_fwd/xdl/comp/device_grouped_conv2d_fwd_xdl_nhwgc_gkyxc_nhwgk_f16_comp_instance.cpp @@ -49,14 +49,6 @@ void add_device_grouped_conv2d_fwd_xdl_nhwgc_gkyxc_nhwgk_f16_comp_instances( Empty_Tuple, NHWGK, ConvFwd1x1S1P0>{}); - - add_device_operation_instances(instances, - device_grouped_conv_fwd_xdl_f16_comp_instances<2, - NHWGC, - GKYXC, - Empty_Tuple, - NHWGK, - ConvFwdOddC>{}); } } // namespace instance diff --git a/library/src/tensor_operation_instance/gpu/grouped_conv2d_fwd/xdl/comp/device_grouped_conv2d_fwd_xdl_nhwgc_gkyxc_nhwgk_f16_comp_part2_instance.cpp 
b/library/src/tensor_operation_instance/gpu/grouped_conv2d_fwd/xdl/comp/device_grouped_conv2d_fwd_xdl_nhwgc_gkyxc_nhwgk_f16_comp_part2_instance.cpp index c66114b9a3..d2732547fa 100644 --- a/library/src/tensor_operation_instance/gpu/grouped_conv2d_fwd/xdl/comp/device_grouped_conv2d_fwd_xdl_nhwgc_gkyxc_nhwgk_f16_comp_part2_instance.cpp +++ b/library/src/tensor_operation_instance/gpu/grouped_conv2d_fwd/xdl/comp/device_grouped_conv2d_fwd_xdl_nhwgc_gkyxc_nhwgk_f16_comp_part2_instance.cpp @@ -52,15 +52,6 @@ void add_device_grouped_conv2d_fwd_xdl_nhwgc_gkyxc_nhwgk_f16_comp_part2_instance Empty_Tuple, NHWGK, ConvFwd1x1S1P0>{}); - - add_device_operation_instances( - instances, - device_grouped_conv_fwd_xdl_f16_comp_instances_part2<2, - NHWGC, - GKYXC, - Empty_Tuple, - NHWGK, - ConvFwdOddC>{}); } } diff --git a/library/src/tensor_operation_instance/gpu/grouped_conv2d_fwd/xdl/comp/device_grouped_conv2d_fwd_xdl_nhwgc_gkyxc_nhwgk_f32_comp_instance.cpp b/library/src/tensor_operation_instance/gpu/grouped_conv2d_fwd/xdl/comp/device_grouped_conv2d_fwd_xdl_nhwgc_gkyxc_nhwgk_f32_comp_instance.cpp index 93e07e08fb..8a0caebc9f 100644 --- a/library/src/tensor_operation_instance/gpu/grouped_conv2d_fwd/xdl/comp/device_grouped_conv2d_fwd_xdl_nhwgc_gkyxc_nhwgk_f32_comp_instance.cpp +++ b/library/src/tensor_operation_instance/gpu/grouped_conv2d_fwd/xdl/comp/device_grouped_conv2d_fwd_xdl_nhwgc_gkyxc_nhwgk_f32_comp_instance.cpp @@ -1,5 +1,5 @@ // SPDX-License-Identifier: MIT -// Copyright (c) 2018-2024, Advanced Micro Devices, Inc. All rights reserved. +// Copyright (c) 2018-2025, Advanced Micro Devices, Inc. All rights reserved. 
#include "ck/library/tensor_operation_instance/add_device_operation_instance.hpp" #include "ck/library/tensor_operation_instance/gpu/grouped_conv_fwd/device_grouped_conv_fwd_xdl_comp_instance.hpp" @@ -48,14 +48,6 @@ void add_device_grouped_conv2d_fwd_xdl_nhwgc_gkyxc_nhwgk_f32_comp_instances( Empty_Tuple, NHWGK, ConvFwd1x1S1P0>{}); - - add_device_operation_instances(instances, - device_grouped_conv_fwd_xdl_f32_comp_instances<2, - NHWGC, - GKYXC, - Empty_Tuple, - NHWGK, - ConvFwdOddC>{}); } } // namespace instance diff --git a/library/src/tensor_operation_instance/gpu/grouped_conv2d_fwd/xdl/comp/device_grouped_conv2d_fwd_xdl_nhwgc_gkyxc_nhwgk_int8_comp_instance.cpp b/library/src/tensor_operation_instance/gpu/grouped_conv2d_fwd/xdl/comp/device_grouped_conv2d_fwd_xdl_nhwgc_gkyxc_nhwgk_int8_comp_instance.cpp index 6acbb7475c..e45df1e107 100644 --- a/library/src/tensor_operation_instance/gpu/grouped_conv2d_fwd/xdl/comp/device_grouped_conv2d_fwd_xdl_nhwgc_gkyxc_nhwgk_int8_comp_instance.cpp +++ b/library/src/tensor_operation_instance/gpu/grouped_conv2d_fwd/xdl/comp/device_grouped_conv2d_fwd_xdl_nhwgc_gkyxc_nhwgk_int8_comp_instance.cpp @@ -1,5 +1,5 @@ // SPDX-License-Identifier: MIT -// Copyright (c) 2024, Advanced Micro Devices, Inc. All rights reserved. +// Copyright (c) 2024-2025, Advanced Micro Devices, Inc. All rights reserved. 
#include "ck/library/tensor_operation_instance/add_device_operation_instance.hpp" #include "ck/library/tensor_operation_instance/gpu/grouped_conv_fwd/device_grouped_conv_fwd_xdl_comp_instance.hpp" @@ -50,14 +50,6 @@ void add_device_grouped_conv2d_fwd_xdl_nhwgc_gkyxc_nhwgk_int8_comp_instances( NHWGK, ConvFwd1x1S1P0>{}); - add_device_operation_instances(instances, - device_grouped_conv_fwd_xdl_int8_comp_instances<2, - NHWGC, - GKYXC, - Empty_Tuple, - NHWGK, - ConvFwdOddC>{}); - if(ck::get_device_name() != "gfx950") { add_device_operation_instances( @@ -86,15 +78,6 @@ void add_device_grouped_conv2d_fwd_xdl_nhwgc_gkyxc_nhwgk_int8_comp_instances( Empty_Tuple, NHWGK, ConvFwd1x1S1P0>{}); - - add_device_operation_instances( - instances, - device_grouped_conv_fwd_xdl_int8_comp_instances_part2<2, - NHWGC, - GKYXC, - Empty_Tuple, - NHWGK, - ConvFwdOddC>{}); } if(ck::get_device_name() == "gfx950") @@ -125,15 +108,6 @@ void add_device_grouped_conv2d_fwd_xdl_nhwgc_gkyxc_nhwgk_int8_comp_instances( Empty_Tuple, NHWGK, ConvFwd1x1S1P0>{}); - - add_device_operation_instances( - instances, - device_grouped_conv_fwd_xdl_int8_comp_instances_2x<2, - NHWGC, - GKYXC, - Empty_Tuple, - NHWGK, - ConvFwdOddC>{}); } } diff --git a/library/src/tensor_operation_instance/gpu/grouped_conv2d_fwd/xdl/device_grouped_conv2d_fwd_xdl_gnhwc_gkyxc_gnhwk_bf16_instance.cpp b/library/src/tensor_operation_instance/gpu/grouped_conv2d_fwd/xdl/device_grouped_conv2d_fwd_xdl_gnhwc_gkyxc_gnhwk_bf16_instance.cpp index 2afbfdc386..078221f89f 100644 --- a/library/src/tensor_operation_instance/gpu/grouped_conv2d_fwd/xdl/device_grouped_conv2d_fwd_xdl_gnhwc_gkyxc_gnhwk_bf16_instance.cpp +++ b/library/src/tensor_operation_instance/gpu/grouped_conv2d_fwd/xdl/device_grouped_conv2d_fwd_xdl_gnhwc_gkyxc_gnhwk_bf16_instance.cpp @@ -1,5 +1,5 @@ // SPDX-License-Identifier: MIT -// Copyright (c) 2018-2024, Advanced Micro Devices, Inc. All rights reserved. +// Copyright (c) 2018-2025, Advanced Micro Devices, Inc. 
All rights reserved. #include "ck/library/tensor_operation_instance/add_device_operation_instance.hpp" #include "ck/library/tensor_operation_instance/gpu/grouped_conv_fwd/device_grouped_conv_fwd_xdl_instance.hpp" @@ -46,14 +46,6 @@ void add_device_grouped_conv2d_fwd_xdl_gnhwc_gkyxc_gnhwk_bf16_instances( Empty_Tuple, GNHWK, ConvFwd1x1S1P0>{}); - - add_device_operation_instances(instances, - device_grouped_conv_fwd_xdl_bf16_instances<2, - GNHWC, - GKYXC, - Empty_Tuple, - GNHWK, - ConvFwdOddC>{}); } } // namespace instance diff --git a/library/src/tensor_operation_instance/gpu/grouped_conv2d_fwd/xdl/device_grouped_conv2d_fwd_xdl_gnhwc_gkyxc_gnhwk_f16_instance.cpp b/library/src/tensor_operation_instance/gpu/grouped_conv2d_fwd/xdl/device_grouped_conv2d_fwd_xdl_gnhwc_gkyxc_gnhwk_f16_instance.cpp index 822ef51e00..3a481dd204 100644 --- a/library/src/tensor_operation_instance/gpu/grouped_conv2d_fwd/xdl/device_grouped_conv2d_fwd_xdl_gnhwc_gkyxc_gnhwk_f16_instance.cpp +++ b/library/src/tensor_operation_instance/gpu/grouped_conv2d_fwd/xdl/device_grouped_conv2d_fwd_xdl_gnhwc_gkyxc_gnhwk_f16_instance.cpp @@ -1,5 +1,5 @@ // SPDX-License-Identifier: MIT -// Copyright (c) 2018-2023, Advanced Micro Devices, Inc. All rights reserved. +// Copyright (c) 2018-2025, Advanced Micro Devices, Inc. All rights reserved. 
#include "ck/library/tensor_operation_instance/add_device_operation_instance.hpp" #include "ck/library/tensor_operation_instance/gpu/grouped_conv_fwd/device_grouped_conv_fwd_xdl_instance.hpp" @@ -46,14 +46,6 @@ void add_device_grouped_conv2d_fwd_xdl_gnhwc_gkyxc_gnhwk_f16_instances( Empty_Tuple, GNHWK, ConvFwd1x1S1P0>{}); - - add_device_operation_instances(instances, - device_grouped_conv_fwd_xdl_f16_instances<2, - GNHWC, - GKYXC, - Empty_Tuple, - GNHWK, - ConvFwdOddC>{}); } } // namespace instance diff --git a/library/src/tensor_operation_instance/gpu/grouped_conv2d_fwd/xdl/device_grouped_conv2d_fwd_xdl_gnhwc_gkyxc_gnhwk_f32_instance.cpp b/library/src/tensor_operation_instance/gpu/grouped_conv2d_fwd/xdl/device_grouped_conv2d_fwd_xdl_gnhwc_gkyxc_gnhwk_f32_instance.cpp index 79a1fb99a8..5add0f8add 100644 --- a/library/src/tensor_operation_instance/gpu/grouped_conv2d_fwd/xdl/device_grouped_conv2d_fwd_xdl_gnhwc_gkyxc_gnhwk_f32_instance.cpp +++ b/library/src/tensor_operation_instance/gpu/grouped_conv2d_fwd/xdl/device_grouped_conv2d_fwd_xdl_gnhwc_gkyxc_gnhwk_f32_instance.cpp @@ -1,5 +1,5 @@ // SPDX-License-Identifier: MIT -// Copyright (c) 2018-2023, Advanced Micro Devices, Inc. All rights reserved. +// Copyright (c) 2018-2025, Advanced Micro Devices, Inc. All rights reserved. 
#include "ck/library/tensor_operation_instance/add_device_operation_instance.hpp" #include "ck/library/tensor_operation_instance/gpu/grouped_conv_fwd/device_grouped_conv_fwd_xdl_instance.hpp" @@ -46,14 +46,6 @@ void add_device_grouped_conv2d_fwd_xdl_gnhwc_gkyxc_gnhwk_f32_instances( Empty_Tuple, GNHWK, ConvFwd1x1S1P0>{}); - - add_device_operation_instances(instances, - device_grouped_conv_fwd_xdl_f32_instances<2, - GNHWC, - GKYXC, - Empty_Tuple, - GNHWK, - ConvFwdOddC>{}); } } // namespace instance diff --git a/library/src/tensor_operation_instance/gpu/grouped_conv2d_fwd/xdl/device_grouped_conv2d_fwd_xdl_nhwgc_gkyxc_nhwgk_bf16_instance.cpp b/library/src/tensor_operation_instance/gpu/grouped_conv2d_fwd/xdl/device_grouped_conv2d_fwd_xdl_nhwgc_gkyxc_nhwgk_bf16_instance.cpp index e567c0df75..0257c7d315 100644 --- a/library/src/tensor_operation_instance/gpu/grouped_conv2d_fwd/xdl/device_grouped_conv2d_fwd_xdl_nhwgc_gkyxc_nhwgk_bf16_instance.cpp +++ b/library/src/tensor_operation_instance/gpu/grouped_conv2d_fwd/xdl/device_grouped_conv2d_fwd_xdl_nhwgc_gkyxc_nhwgk_bf16_instance.cpp @@ -1,5 +1,5 @@ // SPDX-License-Identifier: MIT -// Copyright (c) 2018-2023, Advanced Micro Devices, Inc. All rights reserved. +// Copyright (c) 2018-2025, Advanced Micro Devices, Inc. All rights reserved. 
#include "ck/library/tensor_operation_instance/add_device_operation_instance.hpp" #include "ck/library/tensor_operation_instance/gpu/grouped_conv_fwd/device_grouped_conv_fwd_xdl_instance.hpp" @@ -46,14 +46,6 @@ void add_device_grouped_conv2d_fwd_xdl_nhwgc_gkyxc_nhwgk_bf16_instances( Empty_Tuple, NHWGK, ConvFwd1x1S1P0>{}); - - add_device_operation_instances(instances, - device_grouped_conv_fwd_xdl_bf16_instances<2, - NHWGC, - GKYXC, - Empty_Tuple, - NHWGK, - ConvFwdOddC>{}); } } // namespace instance diff --git a/library/src/tensor_operation_instance/gpu/grouped_conv2d_fwd/xdl/device_grouped_conv2d_fwd_xdl_nhwgc_gkyxc_nhwgk_f16_instance.cpp b/library/src/tensor_operation_instance/gpu/grouped_conv2d_fwd/xdl/device_grouped_conv2d_fwd_xdl_nhwgc_gkyxc_nhwgk_f16_instance.cpp index 3e42184996..2715506fe2 100644 --- a/library/src/tensor_operation_instance/gpu/grouped_conv2d_fwd/xdl/device_grouped_conv2d_fwd_xdl_nhwgc_gkyxc_nhwgk_f16_instance.cpp +++ b/library/src/tensor_operation_instance/gpu/grouped_conv2d_fwd/xdl/device_grouped_conv2d_fwd_xdl_nhwgc_gkyxc_nhwgk_f16_instance.cpp @@ -1,5 +1,5 @@ // SPDX-License-Identifier: MIT -// Copyright (c) 2018-2023, Advanced Micro Devices, Inc. All rights reserved. +// Copyright (c) 2018-2025, Advanced Micro Devices, Inc. All rights reserved. 
#include "ck/library/tensor_operation_instance/add_device_operation_instance.hpp" #include "ck/library/tensor_operation_instance/gpu/grouped_conv_fwd/device_grouped_conv_fwd_xdl_instance.hpp" @@ -46,14 +46,6 @@ void add_device_grouped_conv2d_fwd_xdl_nhwgc_gkyxc_nhwgk_f16_instances( Empty_Tuple, NHWGK, ConvFwd1x1S1P0>{}); - - add_device_operation_instances(instances, - device_grouped_conv_fwd_xdl_f16_instances<2, - NHWGC, - GKYXC, - Empty_Tuple, - NHWGK, - ConvFwdOddC>{}); } } // namespace instance diff --git a/library/src/tensor_operation_instance/gpu/grouped_conv2d_fwd/xdl/device_grouped_conv2d_fwd_xdl_nhwgc_gkyxc_nhwgk_f32_instance.cpp b/library/src/tensor_operation_instance/gpu/grouped_conv2d_fwd/xdl/device_grouped_conv2d_fwd_xdl_nhwgc_gkyxc_nhwgk_f32_instance.cpp index c035d4c3da..8d3e4d91b1 100644 --- a/library/src/tensor_operation_instance/gpu/grouped_conv2d_fwd/xdl/device_grouped_conv2d_fwd_xdl_nhwgc_gkyxc_nhwgk_f32_instance.cpp +++ b/library/src/tensor_operation_instance/gpu/grouped_conv2d_fwd/xdl/device_grouped_conv2d_fwd_xdl_nhwgc_gkyxc_nhwgk_f32_instance.cpp @@ -1,5 +1,5 @@ // SPDX-License-Identifier: MIT -// Copyright (c) 2018-2023, Advanced Micro Devices, Inc. All rights reserved. +// Copyright (c) 2018-2025, Advanced Micro Devices, Inc. All rights reserved. 
#include "ck/library/tensor_operation_instance/add_device_operation_instance.hpp" #include "ck/library/tensor_operation_instance/gpu/grouped_conv_fwd/device_grouped_conv_fwd_xdl_instance.hpp" @@ -46,14 +46,6 @@ void add_device_grouped_conv2d_fwd_xdl_nhwgc_gkyxc_nhwgk_f32_instances( Empty_Tuple, NHWGK, ConvFwd1x1S1P0>{}); - - add_device_operation_instances(instances, - device_grouped_conv_fwd_xdl_f32_instances<2, - NHWGC, - GKYXC, - Empty_Tuple, - NHWGK, - ConvFwdOddC>{}); } } // namespace instance diff --git a/library/src/tensor_operation_instance/gpu/grouped_conv2d_fwd/xdl/device_grouped_conv2d_fwd_xdl_nhwgc_gkyxc_nhwgk_int8_instance.cpp b/library/src/tensor_operation_instance/gpu/grouped_conv2d_fwd/xdl/device_grouped_conv2d_fwd_xdl_nhwgc_gkyxc_nhwgk_int8_instance.cpp index 5c425effd8..465fa927a5 100644 --- a/library/src/tensor_operation_instance/gpu/grouped_conv2d_fwd/xdl/device_grouped_conv2d_fwd_xdl_nhwgc_gkyxc_nhwgk_int8_instance.cpp +++ b/library/src/tensor_operation_instance/gpu/grouped_conv2d_fwd/xdl/device_grouped_conv2d_fwd_xdl_nhwgc_gkyxc_nhwgk_int8_instance.cpp @@ -1,5 +1,5 @@ // SPDX-License-Identifier: MIT -// Copyright (c) 2024, Advanced Micro Devices, Inc. All rights reserved. +// Copyright (c) 2024-2025, Advanced Micro Devices, Inc. All rights reserved. 
#include "ck/library/tensor_operation_instance/add_device_operation_instance.hpp" #include "ck/library/tensor_operation_instance/gpu/grouped_conv_fwd/device_grouped_conv_fwd_xdl_instance.hpp" @@ -46,14 +46,6 @@ void add_device_grouped_conv2d_fwd_xdl_nhwgc_gkyxc_nhwgk_int8_instances( Empty_Tuple, NHWGK, ConvFwd1x1S1P0>{}); - - add_device_operation_instances(instances, - device_grouped_conv_fwd_xdl_int8_instances<2, - NHWGC, - GKYXC, - Empty_Tuple, - NHWGK, - ConvFwdOddC>{}); } } // namespace instance diff --git a/library/src/tensor_operation_instance/gpu/grouped_conv2d_fwd/xdl/mem/device_grouped_conv2d_fwd_xdl_nhwgc_gkyxc_nhwgk_bf16_mem_inter_instance.cpp b/library/src/tensor_operation_instance/gpu/grouped_conv2d_fwd/xdl/mem/device_grouped_conv2d_fwd_xdl_nhwgc_gkyxc_nhwgk_bf16_mem_inter_instance.cpp index e8a763c527..87423801cb 100644 --- a/library/src/tensor_operation_instance/gpu/grouped_conv2d_fwd/xdl/mem/device_grouped_conv2d_fwd_xdl_nhwgc_gkyxc_nhwgk_bf16_mem_inter_instance.cpp +++ b/library/src/tensor_operation_instance/gpu/grouped_conv2d_fwd/xdl/mem/device_grouped_conv2d_fwd_xdl_nhwgc_gkyxc_nhwgk_bf16_mem_inter_instance.cpp @@ -1,5 +1,5 @@ // SPDX-License-Identifier: MIT -// Copyright (c) 2018-2024, Advanced Micro Devices, Inc. All rights reserved. +// Copyright (c) 2018-2025, Advanced Micro Devices, Inc. All rights reserved. 
#include "ck/library/tensor_operation_instance/add_device_operation_instance.hpp" #include "ck/library/tensor_operation_instance/gpu/grouped_conv_fwd/device_grouped_conv_fwd_xdl_mem_instance.hpp" @@ -49,15 +49,6 @@ void add_device_grouped_conv2d_fwd_xdl_nhwgc_gkyxc_nhwgk_bf16_mem_inter_instance NHWGK, ConvFwd1x1S1P0, Interwave>{}); - - add_device_operation_instances(instances, - device_grouped_conv_fwd_xdl_bf16_mem_instances<2, - NHWGC, - GKYXC, - Empty_Tuple, - NHWGK, - ConvFwdOddC, - Interwave>{}); } } // namespace instance diff --git a/library/src/tensor_operation_instance/gpu/grouped_conv2d_fwd/xdl/mem/device_grouped_conv2d_fwd_xdl_nhwgc_gkyxc_nhwgk_bf16_mem_intra_instance.cpp b/library/src/tensor_operation_instance/gpu/grouped_conv2d_fwd/xdl/mem/device_grouped_conv2d_fwd_xdl_nhwgc_gkyxc_nhwgk_bf16_mem_intra_instance.cpp index 3ae3fb5186..ebb213461a 100644 --- a/library/src/tensor_operation_instance/gpu/grouped_conv2d_fwd/xdl/mem/device_grouped_conv2d_fwd_xdl_nhwgc_gkyxc_nhwgk_bf16_mem_intra_instance.cpp +++ b/library/src/tensor_operation_instance/gpu/grouped_conv2d_fwd/xdl/mem/device_grouped_conv2d_fwd_xdl_nhwgc_gkyxc_nhwgk_bf16_mem_intra_instance.cpp @@ -1,5 +1,5 @@ // SPDX-License-Identifier: MIT -// Copyright (c) 2018-2024, Advanced Micro Devices, Inc. All rights reserved. +// Copyright (c) 2018-2025, Advanced Micro Devices, Inc. All rights reserved. 
#include "ck/library/tensor_operation_instance/add_device_operation_instance.hpp" #include "ck/library/tensor_operation_instance/gpu/grouped_conv_fwd/device_grouped_conv_fwd_xdl_mem_instance.hpp" @@ -49,15 +49,6 @@ void add_device_grouped_conv2d_fwd_xdl_nhwgc_gkyxc_nhwgk_bf16_mem_intra_instance NHWGK, ConvFwd1x1S1P0, Intrawave>{}); - - add_device_operation_instances(instances, - device_grouped_conv_fwd_xdl_bf16_mem_instances<2, - NHWGC, - GKYXC, - Empty_Tuple, - NHWGK, - ConvFwdOddC, - Intrawave>{}); } } // namespace instance diff --git a/library/src/tensor_operation_instance/gpu/grouped_conv2d_fwd/xdl/mem/device_grouped_conv2d_fwd_xdl_nhwgc_gkyxc_nhwgk_f16_mem_inter_instance.cpp b/library/src/tensor_operation_instance/gpu/grouped_conv2d_fwd/xdl/mem/device_grouped_conv2d_fwd_xdl_nhwgc_gkyxc_nhwgk_f16_mem_inter_instance.cpp index cb7e912936..c2c8a099b2 100644 --- a/library/src/tensor_operation_instance/gpu/grouped_conv2d_fwd/xdl/mem/device_grouped_conv2d_fwd_xdl_nhwgc_gkyxc_nhwgk_f16_mem_inter_instance.cpp +++ b/library/src/tensor_operation_instance/gpu/grouped_conv2d_fwd/xdl/mem/device_grouped_conv2d_fwd_xdl_nhwgc_gkyxc_nhwgk_f16_mem_inter_instance.cpp @@ -1,5 +1,5 @@ // SPDX-License-Identifier: MIT -// Copyright (c) 2018-2024, Advanced Micro Devices, Inc. All rights reserved. +// Copyright (c) 2018-2025, Advanced Micro Devices, Inc. All rights reserved. 
#include "ck/library/tensor_operation_instance/add_device_operation_instance.hpp" #include "ck/library/tensor_operation_instance/gpu/grouped_conv_fwd/device_grouped_conv_fwd_xdl_mem_instance.hpp" @@ -49,15 +49,6 @@ void add_device_grouped_conv2d_fwd_xdl_nhwgc_gkyxc_nhwgk_f16_mem_inter_instances NHWGK, ConvFwd1x1S1P0, Interwave>{}); - - add_device_operation_instances(instances, - device_grouped_conv_fwd_xdl_f16_mem_instances<2, - NHWGC, - GKYXC, - Empty_Tuple, - NHWGK, - ConvFwdOddC, - Interwave>{}); } } // namespace instance diff --git a/library/src/tensor_operation_instance/gpu/grouped_conv2d_fwd/xdl/mem/device_grouped_conv2d_fwd_xdl_nhwgc_gkyxc_nhwgk_f16_mem_intra_instance.cpp b/library/src/tensor_operation_instance/gpu/grouped_conv2d_fwd/xdl/mem/device_grouped_conv2d_fwd_xdl_nhwgc_gkyxc_nhwgk_f16_mem_intra_instance.cpp index d787f4b048..11cb853f0d 100644 --- a/library/src/tensor_operation_instance/gpu/grouped_conv2d_fwd/xdl/mem/device_grouped_conv2d_fwd_xdl_nhwgc_gkyxc_nhwgk_f16_mem_intra_instance.cpp +++ b/library/src/tensor_operation_instance/gpu/grouped_conv2d_fwd/xdl/mem/device_grouped_conv2d_fwd_xdl_nhwgc_gkyxc_nhwgk_f16_mem_intra_instance.cpp @@ -1,5 +1,5 @@ // SPDX-License-Identifier: MIT -// Copyright (c) 2018-2024, Advanced Micro Devices, Inc. All rights reserved. +// Copyright (c) 2018-2025, Advanced Micro Devices, Inc. All rights reserved. 
#include "ck/library/tensor_operation_instance/add_device_operation_instance.hpp" #include "ck/library/tensor_operation_instance/gpu/grouped_conv_fwd/device_grouped_conv_fwd_xdl_mem_instance.hpp" @@ -49,15 +49,6 @@ void add_device_grouped_conv2d_fwd_xdl_nhwgc_gkyxc_nhwgk_f16_mem_intra_instances NHWGK, ConvFwd1x1S1P0, Intrawave>{}); - - add_device_operation_instances(instances, - device_grouped_conv_fwd_xdl_f16_mem_instances<2, - NHWGC, - GKYXC, - Empty_Tuple, - NHWGK, - ConvFwdOddC, - Intrawave>{}); } } // namespace instance diff --git a/library/src/tensor_operation_instance/gpu/grouped_conv2d_fwd/xdl/mem/device_grouped_conv2d_fwd_xdl_nhwgc_gkyxc_nhwgk_f32_mem_inter_instance.cpp b/library/src/tensor_operation_instance/gpu/grouped_conv2d_fwd/xdl/mem/device_grouped_conv2d_fwd_xdl_nhwgc_gkyxc_nhwgk_f32_mem_inter_instance.cpp index 5644289790..1992d7f7c1 100644 --- a/library/src/tensor_operation_instance/gpu/grouped_conv2d_fwd/xdl/mem/device_grouped_conv2d_fwd_xdl_nhwgc_gkyxc_nhwgk_f32_mem_inter_instance.cpp +++ b/library/src/tensor_operation_instance/gpu/grouped_conv2d_fwd/xdl/mem/device_grouped_conv2d_fwd_xdl_nhwgc_gkyxc_nhwgk_f32_mem_inter_instance.cpp @@ -1,5 +1,5 @@ // SPDX-License-Identifier: MIT -// Copyright (c) 2018-2024, Advanced Micro Devices, Inc. All rights reserved. +// Copyright (c) 2018-2025, Advanced Micro Devices, Inc. All rights reserved. 
#include "ck/library/tensor_operation_instance/add_device_operation_instance.hpp" #include "ck/library/tensor_operation_instance/gpu/grouped_conv_fwd/device_grouped_conv_fwd_xdl_mem_instance.hpp" @@ -49,15 +49,6 @@ void add_device_grouped_conv2d_fwd_xdl_nhwgc_gkyxc_nhwgk_f32_mem_inter_instances NHWGK, ConvFwd1x1S1P0, Interwave>{}); - - add_device_operation_instances(instances, - device_grouped_conv_fwd_xdl_f32_mem_instances<2, - NHWGC, - GKYXC, - Empty_Tuple, - NHWGK, - ConvFwdOddC, - Interwave>{}); } } // namespace instance diff --git a/library/src/tensor_operation_instance/gpu/grouped_conv2d_fwd/xdl/mem/device_grouped_conv2d_fwd_xdl_nhwgc_gkyxc_nhwgk_f32_mem_intra_instance.cpp b/library/src/tensor_operation_instance/gpu/grouped_conv2d_fwd/xdl/mem/device_grouped_conv2d_fwd_xdl_nhwgc_gkyxc_nhwgk_f32_mem_intra_instance.cpp index 5b12dad5a3..2b8fd3d9db 100644 --- a/library/src/tensor_operation_instance/gpu/grouped_conv2d_fwd/xdl/mem/device_grouped_conv2d_fwd_xdl_nhwgc_gkyxc_nhwgk_f32_mem_intra_instance.cpp +++ b/library/src/tensor_operation_instance/gpu/grouped_conv2d_fwd/xdl/mem/device_grouped_conv2d_fwd_xdl_nhwgc_gkyxc_nhwgk_f32_mem_intra_instance.cpp @@ -1,5 +1,5 @@ // SPDX-License-Identifier: MIT -// Copyright (c) 2018-2024, Advanced Micro Devices, Inc. All rights reserved. +// Copyright (c) 2018-2025, Advanced Micro Devices, Inc. All rights reserved. 
#include "ck/library/tensor_operation_instance/add_device_operation_instance.hpp" #include "ck/library/tensor_operation_instance/gpu/grouped_conv_fwd/device_grouped_conv_fwd_xdl_mem_instance.hpp" @@ -49,15 +49,6 @@ void add_device_grouped_conv2d_fwd_xdl_nhwgc_gkyxc_nhwgk_f32_mem_intra_instances NHWGK, ConvFwd1x1S1P0, Intrawave>{}); - - add_device_operation_instances(instances, - device_grouped_conv_fwd_xdl_f32_mem_instances<2, - NHWGC, - GKYXC, - Empty_Tuple, - NHWGK, - ConvFwdOddC, - Intrawave>{}); } } // namespace instance diff --git a/library/src/tensor_operation_instance/gpu/grouped_conv2d_fwd/xdl/mem/device_grouped_conv2d_fwd_xdl_nhwgc_gkyxc_nhwgk_int8_mem_inter_instance.cpp b/library/src/tensor_operation_instance/gpu/grouped_conv2d_fwd/xdl/mem/device_grouped_conv2d_fwd_xdl_nhwgc_gkyxc_nhwgk_int8_mem_inter_instance.cpp index f667481fa4..5579ec62cc 100644 --- a/library/src/tensor_operation_instance/gpu/grouped_conv2d_fwd/xdl/mem/device_grouped_conv2d_fwd_xdl_nhwgc_gkyxc_nhwgk_int8_mem_inter_instance.cpp +++ b/library/src/tensor_operation_instance/gpu/grouped_conv2d_fwd/xdl/mem/device_grouped_conv2d_fwd_xdl_nhwgc_gkyxc_nhwgk_int8_mem_inter_instance.cpp @@ -1,5 +1,5 @@ // SPDX-License-Identifier: MIT -// Copyright (c) 2018-2024, Advanced Micro Devices, Inc. All rights reserved. +// Copyright (c) 2018-2025, Advanced Micro Devices, Inc. All rights reserved. 
#include "ck/library/tensor_operation_instance/add_device_operation_instance.hpp" #include "ck/library/tensor_operation_instance/gpu/grouped_conv_fwd/device_grouped_conv_fwd_xdl_mem_instance.hpp" @@ -49,15 +49,6 @@ void add_device_grouped_conv2d_fwd_xdl_nhwgc_gkyxc_nhwgk_int8_mem_inter_instance NHWGK, ConvFwd1x1S1P0, Interwave>{}); - - add_device_operation_instances(instances, - device_grouped_conv_fwd_xdl_int8_mem_instances<2, - NHWGC, - GKYXC, - Empty_Tuple, - NHWGK, - ConvFwdOddC, - Interwave>{}); } } // namespace instance diff --git a/library/src/tensor_operation_instance/gpu/grouped_conv2d_fwd/xdl/mem/device_grouped_conv2d_fwd_xdl_nhwgc_gkyxc_nhwgk_int8_mem_intra_instance.cpp b/library/src/tensor_operation_instance/gpu/grouped_conv2d_fwd/xdl/mem/device_grouped_conv2d_fwd_xdl_nhwgc_gkyxc_nhwgk_int8_mem_intra_instance.cpp index 2ff2c7f51f..77f3df2c11 100644 --- a/library/src/tensor_operation_instance/gpu/grouped_conv2d_fwd/xdl/mem/device_grouped_conv2d_fwd_xdl_nhwgc_gkyxc_nhwgk_int8_mem_intra_instance.cpp +++ b/library/src/tensor_operation_instance/gpu/grouped_conv2d_fwd/xdl/mem/device_grouped_conv2d_fwd_xdl_nhwgc_gkyxc_nhwgk_int8_mem_intra_instance.cpp @@ -1,5 +1,5 @@ // SPDX-License-Identifier: MIT -// Copyright (c) 2018-2024, Advanced Micro Devices, Inc. All rights reserved. +// Copyright (c) 2018-2025, Advanced Micro Devices, Inc. All rights reserved. 
#include "ck/library/tensor_operation_instance/add_device_operation_instance.hpp" #include "ck/library/tensor_operation_instance/gpu/grouped_conv_fwd/device_grouped_conv_fwd_xdl_mem_instance.hpp" @@ -49,15 +49,6 @@ void add_device_grouped_conv2d_fwd_xdl_nhwgc_gkyxc_nhwgk_int8_mem_intra_instance NHWGK, ConvFwd1x1S1P0, Intrawave>{}); - - add_device_operation_instances(instances, - device_grouped_conv_fwd_xdl_int8_mem_instances<2, - NHWGC, - GKYXC, - Empty_Tuple, - NHWGK, - ConvFwdOddC, - Intrawave>{}); } } // namespace instance diff --git a/library/src/tensor_operation_instance/gpu/grouped_conv3d_fwd/CMakeLists.txt b/library/src/tensor_operation_instance/gpu/grouped_conv3d_fwd/CMakeLists.txt index 7b9ccf6609..eeea4aae6d 100644 --- a/library/src/tensor_operation_instance/gpu/grouped_conv3d_fwd/CMakeLists.txt +++ b/library/src/tensor_operation_instance/gpu/grouped_conv3d_fwd/CMakeLists.txt @@ -60,10 +60,6 @@ set(GROUPED_CONV3D_FWD wmma/device_grouped_conv3d_fwd_wmma_gndhwc_gkzyxc_gndhwk_i8_1x1s1p0_instance.cpp wmma/device_grouped_conv3d_fwd_wmma_ndhwgc_gkzyxc_ndhwgk_f16_1x1s1p0_instance.cpp wmma/device_grouped_conv3d_fwd_wmma_ndhwgc_gkzyxc_ndhwgk_i8_1x1s1p0_instance.cpp - wmma/device_grouped_conv3d_fwd_wmma_gndhwc_gkzyxc_gndhwk_f16_oddc_instance.cpp - wmma/device_grouped_conv3d_fwd_wmma_gndhwc_gkzyxc_gndhwk_i8_oddc_instance.cpp - wmma/device_grouped_conv3d_fwd_wmma_ndhwgc_gkzyxc_ndhwgk_f16_oddc_instance.cpp - wmma/device_grouped_conv3d_fwd_wmma_ndhwgc_gkzyxc_ndhwgk_i8_oddc_instance.cpp ) if((DTYPES MATCHES "fp8" AND DTYPES MATCHES "fp16") OR NOT DEFINED DTYPES) diff --git a/library/src/tensor_operation_instance/gpu/grouped_conv3d_fwd/wmma/device_grouped_conv3d_fwd_wmma_gndhwc_gkzyxc_gndhwk_f16_oddc_instance.cpp b/library/src/tensor_operation_instance/gpu/grouped_conv3d_fwd/wmma/device_grouped_conv3d_fwd_wmma_gndhwc_gkzyxc_gndhwk_f16_oddc_instance.cpp deleted file mode 100644 index fa378af1ee..0000000000 --- 
a/library/src/tensor_operation_instance/gpu/grouped_conv3d_fwd/wmma/device_grouped_conv3d_fwd_wmma_gndhwc_gkzyxc_gndhwk_f16_oddc_instance.cpp +++ /dev/null @@ -1,41 +0,0 @@ -// SPDX-License-Identifier: MIT -// Copyright (c) 2018-2023, Advanced Micro Devices, Inc. All rights reserved. - -#include "ck/library/tensor_operation_instance/add_device_operation_instance.hpp" -#include "ck/library/tensor_operation_instance/gpu/grouped_conv_fwd/device_grouped_conv_fwd_wmma_instance.hpp" - -namespace ck { -namespace tensor_operation { -namespace device { -namespace instance { -// Compilation parameters for in[g, n, di, hi, wi, c] * wei[g, k, z, y, x, c] = out[g, n, do, ho, -// wo, k] -void add_device_grouped_conv3d_fwd_wmma_gndhwc_gkzyxc_gndhwk_f16_oddc_instances( - std::vector>>& instances) -{ - add_device_operation_instances(instances, - device_grouped_conv_fwd_wmma_f16_instances<3, - GNDHWC, - GKZYXC, - Empty_Tuple, - GNDHWK, - Empty_Tuple, - PassThrough, - ConvFwdOddC>{}); -} - -} // namespace instance -} // namespace device -} // namespace tensor_operation -} // namespace ck diff --git a/library/src/tensor_operation_instance/gpu/grouped_conv3d_fwd/wmma/device_grouped_conv3d_fwd_wmma_gndhwc_gkzyxc_gndhwk_i8_oddc_instance.cpp b/library/src/tensor_operation_instance/gpu/grouped_conv3d_fwd/wmma/device_grouped_conv3d_fwd_wmma_gndhwc_gkzyxc_gndhwk_i8_oddc_instance.cpp deleted file mode 100644 index d41416fd4a..0000000000 --- a/library/src/tensor_operation_instance/gpu/grouped_conv3d_fwd/wmma/device_grouped_conv3d_fwd_wmma_gndhwc_gkzyxc_gndhwk_i8_oddc_instance.cpp +++ /dev/null @@ -1,41 +0,0 @@ -// SPDX-License-Identifier: MIT -// Copyright (c) 2018-2023, Advanced Micro Devices, Inc. All rights reserved. 
- -#include "ck/library/tensor_operation_instance/add_device_operation_instance.hpp" -#include "ck/library/tensor_operation_instance/gpu/grouped_conv_fwd/device_grouped_conv_fwd_wmma_instance.hpp" - -namespace ck { -namespace tensor_operation { -namespace device { -namespace instance { -// Compilation parameters for in[g, n, di, hi, wi, c] * wei[g, k, z, y, x, c] = out[g, n, do, ho, -// wo, k] -void add_device_grouped_conv3d_fwd_wmma_gndhwc_gkzyxc_gndhwk_i8_oddc_instances( - std::vector>>& instances) -{ - add_device_operation_instances(instances, - device_grouped_conv_fwd_wmma_i8_instances<3, - GNDHWC, - GKZYXC, - Empty_Tuple, - GNDHWK, - Empty_Tuple, - PassThrough, - ConvFwdOddC>{}); -} - -} // namespace instance -} // namespace device -} // namespace tensor_operation -} // namespace ck diff --git a/library/src/tensor_operation_instance/gpu/grouped_conv3d_fwd/wmma/device_grouped_conv3d_fwd_wmma_ndhwgc_gkzyxc_ndhwgk_f16_oddc_instance.cpp b/library/src/tensor_operation_instance/gpu/grouped_conv3d_fwd/wmma/device_grouped_conv3d_fwd_wmma_ndhwgc_gkzyxc_ndhwgk_f16_oddc_instance.cpp deleted file mode 100644 index 8a7bc26178..0000000000 --- a/library/src/tensor_operation_instance/gpu/grouped_conv3d_fwd/wmma/device_grouped_conv3d_fwd_wmma_ndhwgc_gkzyxc_ndhwgk_f16_oddc_instance.cpp +++ /dev/null @@ -1,41 +0,0 @@ -// SPDX-License-Identifier: MIT -// Copyright (c) 2018-2023, Advanced Micro Devices, Inc. All rights reserved. 
- -#include "ck/library/tensor_operation_instance/add_device_operation_instance.hpp" -#include "ck/library/tensor_operation_instance/gpu/grouped_conv_fwd/device_grouped_conv_fwd_wmma_instance.hpp" - -namespace ck { -namespace tensor_operation { -namespace device { -namespace instance { -// Compilation parameters for in[n, di, hi, wi, g, c] * wei[g, k, z, y, x, c] = out[n, do, ho, wo, -// g, k] -void add_device_grouped_conv3d_fwd_wmma_ndhwgc_gkzyxc_ndhwgk_f16_oddc_instances( - std::vector>>& instances) -{ - add_device_operation_instances(instances, - device_grouped_conv_fwd_wmma_f16_instances<3, - NDHWGC, - GKZYXC, - Empty_Tuple, - NDHWGK, - Empty_Tuple, - PassThrough, - ConvFwdOddC>{}); -} - -} // namespace instance -} // namespace device -} // namespace tensor_operation -} // namespace ck diff --git a/library/src/tensor_operation_instance/gpu/grouped_conv3d_fwd/wmma/device_grouped_conv3d_fwd_wmma_ndhwgc_gkzyxc_ndhwgk_i8_oddc_instance.cpp b/library/src/tensor_operation_instance/gpu/grouped_conv3d_fwd/wmma/device_grouped_conv3d_fwd_wmma_ndhwgc_gkzyxc_ndhwgk_i8_oddc_instance.cpp deleted file mode 100644 index 7649f86971..0000000000 --- a/library/src/tensor_operation_instance/gpu/grouped_conv3d_fwd/wmma/device_grouped_conv3d_fwd_wmma_ndhwgc_gkzyxc_ndhwgk_i8_oddc_instance.cpp +++ /dev/null @@ -1,41 +0,0 @@ -// SPDX-License-Identifier: MIT -// Copyright (c) 2018-2023, Advanced Micro Devices, Inc. All rights reserved. 
- -#include "ck/library/tensor_operation_instance/add_device_operation_instance.hpp" -#include "ck/library/tensor_operation_instance/gpu/grouped_conv_fwd/device_grouped_conv_fwd_wmma_instance.hpp" - -namespace ck { -namespace tensor_operation { -namespace device { -namespace instance { -// Compilation parameters for in[n, di, hi, wi, g, c] * wei[g, k, z, y, x, c] = out[n, do, ho, wo, -// g, k] -void add_device_grouped_conv3d_fwd_wmma_ndhwgc_gkzyxc_ndhwgk_i8_oddc_instances( - std::vector>>& instances) -{ - add_device_operation_instances(instances, - device_grouped_conv_fwd_wmma_i8_instances<3, - NDHWGC, - GKZYXC, - Empty_Tuple, - NDHWGK, - Empty_Tuple, - PassThrough, - ConvFwdOddC>{}); -} - -} // namespace instance -} // namespace device -} // namespace tensor_operation -} // namespace ck From 9a9f59ae69a619e2d6ce3c8ff343f3c4b0ada413 Mon Sep 17 00:00:00 2001 From: Illia Silin <98187287+illsilin@users.noreply.github.com> Date: Wed, 30 Apr 2025 10:20:16 -0700 Subject: [PATCH 085/443] Revert "Add ck tile examples to package (#1880)" (#2150) --- example/CMakeLists.txt | 4 +--- example/ck_tile/01_fmha/CMakeLists.txt | 6 ++---- example/ck_tile/02_layernorm2d/CMakeLists.txt | 3 +-- example/ck_tile/03_gemm/CMakeLists.txt | 7 ++----- example/ck_tile/03_gemm/stript.sh | 1 - example/ck_tile/04_img2col/CMakeLists.txt | 3 +-- example/ck_tile/05_reduce/CMakeLists.txt | 4 +--- example/ck_tile/06_permute/CMakeLists.txt | 3 +-- .../ck_tile/09_topk_softmax/CMakeLists.txt | 5 ++--- example/ck_tile/10_rmsnorm2d/CMakeLists.txt | 6 ++---- .../11_add_rmsnorm2d_rdquant/CMakeLists.txt | 6 ++---- .../add_rmsnorm2d_rdquant_fwd.cpp | 21 ++++++++----------- .../example_add_rmsnorm2d_rdquant_fwd.cpp | 21 ++++++++----------- example/ck_tile/12_smoothquant/CMakeLists.txt | 3 +-- example/ck_tile/13_moe_sorting/CMakeLists.txt | 3 +-- .../ck_tile/14_moe_smoothquant/CMakeLists.txt | 3 +-- example/ck_tile/15_fused_moe/CMakeLists.txt | 3 +-- .../ck_tile/16_batched_gemm/CMakeLists.txt | 3 +-- 
.../ck_tile/17_grouped_gemm/CMakeLists.txt | 4 ++-- example/ck_tile/18_flatmm/CMakeLists.txt | 4 +--- .../35_batched_transpose/CMakeLists.txt | 4 ++-- example/ck_tile/CMakeLists.txt | 5 +---- .../flatmm_pipeline_agmem_bgmem_creg_v1.hpp | 1 - .../gemm/pipeline/gemm_pipeline_problem.hpp | 3 ++- .../ops/gemm/pipeline/tile_gemm_traits.hpp | 5 ++--- 25 files changed, 48 insertions(+), 83 deletions(-) delete mode 100644 example/ck_tile/03_gemm/stript.sh mode change 100755 => 100644 example/ck_tile/11_add_rmsnorm2d_rdquant/example_add_rmsnorm2d_rdquant_fwd.cpp diff --git a/example/CMakeLists.txt b/example/CMakeLists.txt index 0e61fd33ef..996a543ecc 100644 --- a/example/CMakeLists.txt +++ b/example/CMakeLists.txt @@ -5,6 +5,7 @@ include_directories(BEFORE add_custom_target(examples) + # list of examples that are labelled as REGRESSION_EXAMPLE for make regression (runtime more than 30 seconds) # all other tests are labelled as SMOKE_EXAMPLE set(REGRESSION_EXAMPLES @@ -231,9 +232,6 @@ endfunction(add_example_executable_no_testing EXAMPLE_NAME) # add all example subdir file(GLOB dir_list LIST_DIRECTORIES true *) -if (NOT SUPPORTED_GPU_TARGETS MATCHES "gfx9") - list(FILTER dir_list EXCLUDE REGEX ".*/ck_tile") -endif() FOREACH(subdir ${dir_list}) if(IS_DIRECTORY "${subdir}" AND EXISTS "${subdir}/CMakeLists.txt") add_subdirectory(${subdir}) diff --git a/example/ck_tile/01_fmha/CMakeLists.txt b/example/ck_tile/01_fmha/CMakeLists.txt index ce3c8b3978..9ba3a453fc 100644 --- a/example/ck_tile/01_fmha/CMakeLists.txt +++ b/example/ck_tile/01_fmha/CMakeLists.txt @@ -58,8 +58,7 @@ set(EXAMPLE_FMHA_FWD "tile_example_fmha_fwd") # not using add_example_executable() to add this target, since we don't want this to have # to be included in "make all/install/check" message("adding example ${EXAMPLE_FMHA_FWD}") -add_executable(${EXAMPLE_FMHA_FWD} fmha_fwd.cpp) -rocm_install(TARGETS ${EXAMPLE_FMHA_FWD} COMPONENT examples) +add_executable(${EXAMPLE_FMHA_FWD} EXCLUDE_FROM_ALL fmha_fwd.cpp) 
target_include_directories(${EXAMPLE_FMHA_FWD} PRIVATE ${CMAKE_CURRENT_LIST_DIR}) target_sources(${EXAMPLE_FMHA_FWD} PRIVATE ${FMHA_FWD_GEN_BLOBS}) @@ -67,8 +66,7 @@ set(EXAMPLE_FMHA_BWD "tile_example_fmha_bwd") # not using add_example_executable() to add this target, since we don't want this to have # to be included in "make all/install/check" message("adding example ${EXAMPLE_FMHA_BWD}") -add_executable(${EXAMPLE_FMHA_BWD} fmha_bwd.cpp) -rocm_install(TARGETS ${EXAMPLE_FMHA_BWD} COMPONENT examples) +add_executable(${EXAMPLE_FMHA_BWD} EXCLUDE_FROM_ALL fmha_bwd.cpp) target_include_directories(${EXAMPLE_FMHA_BWD} PRIVATE ${CMAKE_CURRENT_LIST_DIR}) target_sources(${EXAMPLE_FMHA_BWD} PRIVATE ${FMHA_BWD_GEN_BLOBS}) diff --git a/example/ck_tile/02_layernorm2d/CMakeLists.txt b/example/ck_tile/02_layernorm2d/CMakeLists.txt index 74f195a9db..fa69ac0f7a 100644 --- a/example/ck_tile/02_layernorm2d/CMakeLists.txt +++ b/example/ck_tile/02_layernorm2d/CMakeLists.txt @@ -26,8 +26,7 @@ add_custom_command( set(EXAMPLE_LAYERNORM2D_FWD "tile_example_layernorm2d_fwd") message("adding example ${EXAMPLE_LAYERNORM2D_FWD}") -add_executable(${EXAMPLE_LAYERNORM2D_FWD} layernorm2d_fwd.cpp) -rocm_install(TARGETS ${EXAMPLE_LAYERNORM2D_FWD} COMPONENT examples) +add_executable(${EXAMPLE_LAYERNORM2D_FWD} EXCLUDE_FROM_ALL layernorm2d_fwd.cpp) target_include_directories(${EXAMPLE_LAYERNORM2D_FWD} PRIVATE ${CMAKE_CURRENT_LIST_DIR}) target_sources(${EXAMPLE_LAYERNORM2D_FWD} PRIVATE ${LAYERNORM2D_FWD_GEN_BLOBS}) diff --git a/example/ck_tile/03_gemm/CMakeLists.txt b/example/ck_tile/03_gemm/CMakeLists.txt index deccb71d23..411db2e317 100644 --- a/example/ck_tile/03_gemm/CMakeLists.txt +++ b/example/ck_tile/03_gemm/CMakeLists.txt @@ -1,8 +1,5 @@ -add_executable(tile_example_gemm_basic gemm_basic.cpp) -rocm_install(TARGETS tile_example_gemm_basic COMPONENT examples) -add_executable(tile_example_gemm_universal universal_gemm.cpp) -rocm_install(TARGETS tile_example_gemm_universal COMPONENT examples) - 
+add_executable(tile_example_gemm_basic EXCLUDE_FROM_ALL gemm_basic.cpp) +add_executable(tile_example_gemm_universal EXCLUDE_FROM_ALL universal_gemm.cpp) set(EXAMPLE_GEMM_COMPILE_OPTIONS) if(CK_USE_OCP_FP8) list(APPEND EXAMPLE_GEMM_COMPILE_OPTIONS -DCK_TILE_USE_OCP_FP8) diff --git a/example/ck_tile/03_gemm/stript.sh b/example/ck_tile/03_gemm/stript.sh deleted file mode 100644 index 4b91cb36ce..0000000000 --- a/example/ck_tile/03_gemm/stript.sh +++ /dev/null @@ -1 +0,0 @@ -for file in gemm_universal_*; do mv "$file" "${file/f16_f16_f16/fp16_fp16_fp16}"; done diff --git a/example/ck_tile/04_img2col/CMakeLists.txt b/example/ck_tile/04_img2col/CMakeLists.txt index d3737467d8..3864c9ed9d 100644 --- a/example/ck_tile/04_img2col/CMakeLists.txt +++ b/example/ck_tile/04_img2col/CMakeLists.txt @@ -1,4 +1,3 @@ # not using add_example_executable() to add this target, since we don't want this to have # to be included in "make all/install/check" -add_executable(tile_example_img2col image_to_column.cpp) -rocm_install(TARGETS tile_example_img2col COMPONENT examples) +add_executable(tile_example_img2col EXCLUDE_FROM_ALL image_to_column.cpp) diff --git a/example/ck_tile/05_reduce/CMakeLists.txt b/example/ck_tile/05_reduce/CMakeLists.txt index 855e59c48e..6caa38d50d 100644 --- a/example/ck_tile/05_reduce/CMakeLists.txt +++ b/example/ck_tile/05_reduce/CMakeLists.txt @@ -3,9 +3,7 @@ set(EXAMPLE_REDUCE "tile_example_reduce") # to be included in "make all/install/check" message("adding example ${EXAMPLE_REDUCE}") -add_executable(${EXAMPLE_REDUCE} reduce.cpp) -rocm_install(TARGETS ${EXAMPLE_REDUCE} COMPONENT examples) - +add_executable(${EXAMPLE_REDUCE} EXCLUDE_FROM_ALL reduce.cpp) target_include_directories(${EXAMPLE_REDUCE} PRIVATE ${CMAKE_CURRENT_LIST_DIR}) set(EXAMPLE_REDUCE_COMPILE_OPTIONS) diff --git a/example/ck_tile/06_permute/CMakeLists.txt b/example/ck_tile/06_permute/CMakeLists.txt index 22483a4295..327fceb685 100644 --- a/example/ck_tile/06_permute/CMakeLists.txt +++ 
b/example/ck_tile/06_permute/CMakeLists.txt @@ -1,7 +1,6 @@ # not using add_example_executable() to add this target, since we don't want this to have # to be included in "make all/install/check" -add_executable(tile_example_permute permute.cpp) -rocm_install(TARGETS tile_example_permute COMPONENT examples) +add_executable(tile_example_permute EXCLUDE_FROM_ALL permute.cpp) if(NOT DEFINED PERMUTE_USE_ALTERNATIVE_IMPL) # set(PERMUTE_USE_ALTERNATIVE_IMPL false) diff --git a/example/ck_tile/09_topk_softmax/CMakeLists.txt b/example/ck_tile/09_topk_softmax/CMakeLists.txt index fc2a4d3fe0..b43b989792 100644 --- a/example/ck_tile/09_topk_softmax/CMakeLists.txt +++ b/example/ck_tile/09_topk_softmax/CMakeLists.txt @@ -1,7 +1,6 @@ -add_executable(tile_example_topk_softmax topk_softmax.cpp topk_softmax_api.cpp) -rocm_install(TARGETS tile_example_topk_softmax COMPONENT examples) - +add_executable(tile_example_topk_softmax EXCLUDE_FROM_ALL topk_softmax.cpp topk_softmax_api.cpp) target_include_directories(tile_example_topk_softmax PRIVATE ${CMAKE_CURRENT_SOURCE_DIR}/) + set(EXAMPLE_TOPK_SOFTMAX_COMPILE_OPTIONS) # NOTE: we turn off undefined-func-template to let source compile without explicit declare function specializations list(APPEND EXAMPLE_TOPK_SOFTMAX_COMPILE_OPTIONS -Wno-undefined-func-template -Wno-float-equal) diff --git a/example/ck_tile/10_rmsnorm2d/CMakeLists.txt b/example/ck_tile/10_rmsnorm2d/CMakeLists.txt index 731ff639a4..5684c9b2e0 100644 --- a/example/ck_tile/10_rmsnorm2d/CMakeLists.txt +++ b/example/ck_tile/10_rmsnorm2d/CMakeLists.txt @@ -26,8 +26,7 @@ add_custom_command( set(TILE_RMSNORM2D_FWD "tile_rmsnorm2d_fwd") message("adding ${TILE_RMSNORM2D_FWD}") -add_executable(${TILE_RMSNORM2D_FWD} rmsnorm2d_fwd.cpp) -rocm_install(TARGETS ${TILE_RMSNORM2D_FWD} COMPONENT examples) +add_executable(${TILE_RMSNORM2D_FWD} EXCLUDE_FROM_ALL rmsnorm2d_fwd.cpp) target_include_directories(${TILE_RMSNORM2D_FWD} PRIVATE ${CMAKE_CURRENT_LIST_DIR}) 
target_sources(${TILE_RMSNORM2D_FWD} PRIVATE ${RMSNORM2D_FWD_GEN_BLOBS}) @@ -39,8 +38,7 @@ list(APPEND TILE_RMSNORM2D_FWD_COMPILE_OPTIONS -Wno-undefined-func-template -Wno target_compile_options(${TILE_RMSNORM2D_FWD} PRIVATE ${TILE_RMSNORM2D_FWD_COMPILE_OPTIONS}) set(EXAMPLE_RMSNORM2D_FWD "tile_example_rmsnorm2d_fwd") -add_executable(${EXAMPLE_RMSNORM2D_FWD} example_rmsnorm2d_fwd.cpp) -rocm_install(TARGETS ${EXAMPLE_RMSNORM2D_FWD} COMPONENT examples) +add_executable(${EXAMPLE_RMSNORM2D_FWD} EXCLUDE_FROM_ALL example_rmsnorm2d_fwd.cpp) target_compile_options(${EXAMPLE_RMSNORM2D_FWD} PRIVATE ${TILE_RMSNORM2D_FWD_COMPILE_OPTIONS}) # TODO: we have to turn off this global prop, otherwise the progress bar generated diff --git a/example/ck_tile/11_add_rmsnorm2d_rdquant/CMakeLists.txt b/example/ck_tile/11_add_rmsnorm2d_rdquant/CMakeLists.txt index 7071127e01..6b0c3cef7a 100644 --- a/example/ck_tile/11_add_rmsnorm2d_rdquant/CMakeLists.txt +++ b/example/ck_tile/11_add_rmsnorm2d_rdquant/CMakeLists.txt @@ -3,8 +3,7 @@ set(TILE_ADD_RMSNORM2D_RDQUANT_FWD "tile_add_rmsnorm2d_rdquant_fwd") # to be included in "make all/install/check" message("adding ${TILE_ADD_RMSNORM2D_RDQUANT_FWD}") file(GLOB INSTANCE_SRCS instances/*.cpp) -add_executable(${TILE_ADD_RMSNORM2D_RDQUANT_FWD} add_rmsnorm2d_rdquant_fwd.cpp) -rocm_install(TARGETS ${TILE_ADD_RMSNORM2D_RDQUANT_FWD} COMPONENT examples) +add_executable(${TILE_ADD_RMSNORM2D_RDQUANT_FWD} EXCLUDE_FROM_ALL add_rmsnorm2d_rdquant_fwd.cpp) target_include_directories(${TILE_ADD_RMSNORM2D_RDQUANT_FWD} PRIVATE ${CMAKE_CURRENT_LIST_DIR}) target_sources(${TILE_ADD_RMSNORM2D_RDQUANT_FWD} PRIVATE ${INSTANCE_SRCS}) @@ -16,8 +15,7 @@ list(APPEND TILE_ADD_RMSNORM2D_RDQUANT_FWD_COMPILE_OPTIONS -Wno-undefined-func-t target_compile_options(${TILE_ADD_RMSNORM2D_RDQUANT_FWD} PRIVATE ${TILE_ADD_RMSNORM2D_RDQUANT_FWD_COMPILE_OPTIONS}) set(EXAMPLE_ADD_RMSNORM2D_RDQUANT_FWD "tile_example_add_rmsnorm2d_rdquant_fwd") 
-add_executable(${EXAMPLE_ADD_RMSNORM2D_RDQUANT_FWD} example_add_rmsnorm2d_rdquant_fwd.cpp) -rocm_install(TARGETS ${EXAMPLE_ADD_RMSNORM2D_RDQUANT_FWD} COMPONENT examples) +add_executable(${EXAMPLE_ADD_RMSNORM2D_RDQUANT_FWD} EXCLUDE_FROM_ALL example_add_rmsnorm2d_rdquant_fwd.cpp) target_compile_options(${EXAMPLE_ADD_RMSNORM2D_RDQUANT_FWD} PRIVATE ${TILE_ADD_RMSNORM2D_RDQUANT_FWD_COMPILE_OPTIONS}) # TODO: we have to turn off this global prop, otherwise the progress bar generated diff --git a/example/ck_tile/11_add_rmsnorm2d_rdquant/add_rmsnorm2d_rdquant_fwd.cpp b/example/ck_tile/11_add_rmsnorm2d_rdquant/add_rmsnorm2d_rdquant_fwd.cpp index 7d82a16aa9..574edf64d3 100644 --- a/example/ck_tile/11_add_rmsnorm2d_rdquant/add_rmsnorm2d_rdquant_fwd.cpp +++ b/example/ck_tile/11_add_rmsnorm2d_rdquant/add_rmsnorm2d_rdquant_fwd.cpp @@ -67,14 +67,13 @@ bool run(const ck_tile::ArgParser& arg_parser) using TypeConfig = AddRmsnormRdquantTypeConfig; - using ADataType = typename TypeConfig::ADataType; - using BDataType = typename TypeConfig::BDataType; - using GammaDataType = typename TypeConfig::GammaDataType; - using XDataType = typename TypeConfig::XDataType; - using UnquantYDataType = ck_tile::null_type; - using YScaleDataType = typename TypeConfig::YScaleDataType; - using QYDataType = typename TypeConfig::QYDataType; - using ComputeDataType = float; + using ADataType = typename TypeConfig::ADataType; + using BDataType = typename TypeConfig::BDataType; + using GammaDataType = typename TypeConfig::GammaDataType; + using XDataType = typename TypeConfig::XDataType; + using YScaleDataType = typename TypeConfig::YScaleDataType; + using QYDataType = typename TypeConfig::QYDataType; + using ComputeDataType = float; // host verify ck_tile::HostTensor a_host({m, n}, {stride, 1}); @@ -89,7 +88,6 @@ bool run(const ck_tile::ArgParser& arg_parser) ck_tile::HostTensor qy_host_ref({m, n}, {stride, 1}); ck_tile::HostTensor qy_host_dev({m, n}, {stride, 1}); - ck_tile::HostTensor 
unquant_y_host_ref({m, n}, {stride, 1}); ck_tile::FillUniformDistribution{-.5f, .5f}(a_host); ck_tile::FillUniformDistribution{-.5f, .5f}(b_host); @@ -193,9 +191,8 @@ bool run(const ck_tile::ArgParser& arg_parser) GammaDataType, ComputeDataType, YDataType, - InvRmsDataType, - UnquantYDataType>( - x_host_ref, gamma_host, y_host, invRms_host_ref, unquant_y_host_ref, epsilon); + InvRmsDataType>( + x_host_ref, gamma_host, y_host, invRms_host_ref, epsilon); } // yscale diff --git a/example/ck_tile/11_add_rmsnorm2d_rdquant/example_add_rmsnorm2d_rdquant_fwd.cpp b/example/ck_tile/11_add_rmsnorm2d_rdquant/example_add_rmsnorm2d_rdquant_fwd.cpp old mode 100755 new mode 100644 index 3aab357909..ada4c6f2da --- a/example/ck_tile/11_add_rmsnorm2d_rdquant/example_add_rmsnorm2d_rdquant_fwd.cpp +++ b/example/ck_tile/11_add_rmsnorm2d_rdquant/example_add_rmsnorm2d_rdquant_fwd.cpp @@ -62,14 +62,13 @@ bool run(const ck_tile::ArgParser& arg_parser) assert(stride >= n); - using ADataType = DataType; - using BDataType = DataType; - using GammaDataType = DataType; - using XDataType = DataType; - using UnquantYDataType = ck_tile::null_type; - using YScaleDataType = float; - using QYDataType = ck_tile::int8_t; - using ComputeDataType = float; + using ADataType = DataType; + using BDataType = DataType; + using GammaDataType = DataType; + using XDataType = DataType; + using YScaleDataType = float; + using QYDataType = ck_tile::int8_t; + using ComputeDataType = float; // host verify ck_tile::HostTensor a_host({m, n}, {stride, 1}); @@ -82,7 +81,6 @@ bool run(const ck_tile::ArgParser& arg_parser) ck_tile::HostTensor yscale_host_dev({m}, {1}); ck_tile::HostTensor qy_host_ref({m, n}, {stride, 1}); ck_tile::HostTensor qy_host_dev({m, n}, {stride, 1}); - ck_tile::HostTensor unquant_y_host_ref({m, n}, {stride, 1}); ck_tile::FillUniformDistribution{-.5f, .5f}(a_host); ck_tile::FillUniformDistribution{-.5f, .5f}(b_host); @@ -195,9 +193,8 @@ bool run(const ck_tile::ArgParser& arg_parser) GammaDataType, 
ComputeDataType, YDataType, - InvRmsDataType, - UnquantYDataType>( - x_host_ref, gamma_host, y_host, invRms_host_ref, unquant_y_host_ref, epsilon); + InvRmsDataType>( + x_host_ref, gamma_host, y_host, invRms_host_ref, epsilon); } // yscale diff --git a/example/ck_tile/12_smoothquant/CMakeLists.txt b/example/ck_tile/12_smoothquant/CMakeLists.txt index daeeb827bd..3849833aca 100644 --- a/example/ck_tile/12_smoothquant/CMakeLists.txt +++ b/example/ck_tile/12_smoothquant/CMakeLists.txt @@ -2,8 +2,7 @@ function (add_smoothquant_example TARGET_NAME MAIN_SRC) message("adding ${TARGET_NAME}") # not using add_example_executable() to add target, since we don't want this to have # to be included in "make all/install/check" - add_executable(${TARGET_NAME} ${MAIN_SRC}) - rocm_install(TARGETS ${TARGET_NAME} COMPONENT examples) + add_executable(${TARGET_NAME} EXCLUDE_FROM_ALL ${MAIN_SRC}) target_include_directories(${TARGET_NAME} PRIVATE ${CMAKE_CURRENT_LIST_DIR}) foreach(source IN LISTS ARGN) diff --git a/example/ck_tile/13_moe_sorting/CMakeLists.txt b/example/ck_tile/13_moe_sorting/CMakeLists.txt index 662e16f0d3..09f3e4ac4e 100644 --- a/example/ck_tile/13_moe_sorting/CMakeLists.txt +++ b/example/ck_tile/13_moe_sorting/CMakeLists.txt @@ -1,5 +1,4 @@ -add_executable(tile_example_moe_sorting moe_sorting.cpp moe_sorting_api.cpp) -rocm_install(TARGETS tile_example_moe_sorting COMPONENT examples) +add_executable(tile_example_moe_sorting EXCLUDE_FROM_ALL moe_sorting.cpp moe_sorting_api.cpp) target_include_directories(tile_example_moe_sorting PRIVATE ${CMAKE_CURRENT_SOURCE_DIR}/) set(EXAMPLE_MOE_SORTING_COMPILE_OPTIONS) diff --git a/example/ck_tile/14_moe_smoothquant/CMakeLists.txt b/example/ck_tile/14_moe_smoothquant/CMakeLists.txt index 9acb27552a..12224a39a2 100644 --- a/example/ck_tile/14_moe_smoothquant/CMakeLists.txt +++ b/example/ck_tile/14_moe_smoothquant/CMakeLists.txt @@ -2,8 +2,7 @@ function (add_moe_smoothquant_example TARGET_NAME MAIN_SRC) message("adding ${TARGET_NAME}") 
# not using add_example_executable() to add target, since we don't want this to have # to be included in "make all/install/check" - add_executable(${TARGET_NAME} ${MAIN_SRC}) - rocm_install(TARGETS ${TARGET_NAME} COMPONENT examples) + add_executable(${TARGET_NAME} EXCLUDE_FROM_ALL ${MAIN_SRC}) target_include_directories(${TARGET_NAME} PRIVATE ${CMAKE_CURRENT_LIST_DIR}) foreach(source IN LISTS ARGN) diff --git a/example/ck_tile/15_fused_moe/CMakeLists.txt b/example/ck_tile/15_fused_moe/CMakeLists.txt index bb25a55c7d..a716eef19e 100644 --- a/example/ck_tile/15_fused_moe/CMakeLists.txt +++ b/example/ck_tile/15_fused_moe/CMakeLists.txt @@ -3,8 +3,7 @@ set(TILE_EXAPMLE_FUSED_MOE "tile_example_fused_moe") # to be included in "make all/install/check" message("adding ${TILE_EXAPMLE_FUSED_MOE}") file(GLOB INSTANCE_SRCS instances/*.cpp) -add_executable(${TILE_EXAPMLE_FUSED_MOE} main.cpp) -rocm_install(TARGETS ${TILE_EXAPMLE_FUSED_MOE} COMPONENT examples) +add_executable(${TILE_EXAPMLE_FUSED_MOE} EXCLUDE_FROM_ALL main.cpp) target_include_directories(${TILE_EXAPMLE_FUSED_MOE} PRIVATE ${CMAKE_CURRENT_LIST_DIR}) target_sources(${TILE_EXAPMLE_FUSED_MOE} PRIVATE ${INSTANCE_SRCS}) diff --git a/example/ck_tile/16_batched_gemm/CMakeLists.txt b/example/ck_tile/16_batched_gemm/CMakeLists.txt index 9eb7a45d80..78e78c6b04 100644 --- a/example/ck_tile/16_batched_gemm/CMakeLists.txt +++ b/example/ck_tile/16_batched_gemm/CMakeLists.txt @@ -1,2 +1 @@ -add_executable(tile_example_batched_gemm batched_gemm.cpp) -rocm_install(TARGETS tile_example_batched_gemm COMPONENT examples) +add_executable(tile_example_batched_gemm EXCLUDE_FROM_ALL batched_gemm.cpp) diff --git a/example/ck_tile/17_grouped_gemm/CMakeLists.txt b/example/ck_tile/17_grouped_gemm/CMakeLists.txt index 80d688125b..d34013dd6c 100644 --- a/example/ck_tile/17_grouped_gemm/CMakeLists.txt +++ b/example/ck_tile/17_grouped_gemm/CMakeLists.txt @@ -1,2 +1,2 @@ -add_executable(tile_example_grouped_gemm grouped_gemm.cpp) 
-rocm_install(TARGETS tile_example_grouped_gemm COMPONENT examples) +add_executable(tile_example_grouped_gemm EXCLUDE_FROM_ALL grouped_gemm.cpp) + diff --git a/example/ck_tile/18_flatmm/CMakeLists.txt b/example/ck_tile/18_flatmm/CMakeLists.txt index 3a70f0447d..9fbe65e3a7 100644 --- a/example/ck_tile/18_flatmm/CMakeLists.txt +++ b/example/ck_tile/18_flatmm/CMakeLists.txt @@ -1,6 +1,4 @@ -add_executable(tile_example_flatmm_basic flatmm_basic.cpp) -rocm_install(TARGETS tile_example_flatmm_basic COMPONENT examples) - +add_executable(tile_example_flatmm_basic EXCLUDE_FROM_ALL flatmm_basic.cpp) set(EXAMPLE_FLATMM_COMPILE_OPTIONS) # list(APPEND EXAMPLE_FLATMM_COMPILE_OPTIONS -Wno-undefined-func-template -Wno-float-equal) diff --git a/example/ck_tile/35_batched_transpose/CMakeLists.txt b/example/ck_tile/35_batched_transpose/CMakeLists.txt index 10101e4d2e..a08fcebb74 100644 --- a/example/ck_tile/35_batched_transpose/CMakeLists.txt +++ b/example/ck_tile/35_batched_transpose/CMakeLists.txt @@ -1,9 +1,9 @@ set(TARGET_NAME tile_example_batched_transpose) -add_executable(${TARGET_NAME} batched_transpose_example.cpp batched_transpose_api.cpp) -rocm_install(TARGETS ${TARGET_NAME} COMPONENT examples) +add_executable(${TARGET_NAME} EXCLUDE_FROM_ALL batched_transpose_example.cpp batched_transpose_api.cpp) target_include_directories(${TARGET_NAME} PRIVATE ${CMAKE_CURRENT_SOURCE_DIR}/) # NOTE: we turn off undefined-func-template to let source compile without explicit declare function specializations list(APPEND EXAMPLE_BATCHED_TRANSPOSE_COMPILE_OPTIONS -Wno-undefined-func-template -Wno-float-equal) # list(APPEND EXAMPLE_BATCHED_TRANSPOSE_COMPILE_OPTIONS -v --save-temps -Wno-gnu-line-marker) target_compile_options(tile_example_batched_transpose PRIVATE ${EXAMPLE_BATCHED_TRANSPOSE_COMPILE_OPTIONS}) + diff --git a/example/ck_tile/CMakeLists.txt b/example/ck_tile/CMakeLists.txt index 16f68c6255..88efe0d8d9 100644 --- a/example/ck_tile/CMakeLists.txt +++ b/example/ck_tile/CMakeLists.txt 
@@ -14,11 +14,8 @@ add_subdirectory(11_add_rmsnorm2d_rdquant) add_subdirectory(12_smoothquant) add_subdirectory(13_moe_sorting) add_subdirectory(14_moe_smoothquant) +add_subdirectory(15_fused_moe) add_subdirectory(16_batched_gemm) add_subdirectory(17_grouped_gemm) add_subdirectory(18_flatmm) add_subdirectory(35_batched_transpose) - -if (SUPPORTED_GPU_TARGETS MATCHES "gfx94") - add_subdirectory(15_fused_moe) -endif() diff --git a/include/ck_tile/ops/flatmm/pipeline/flatmm_pipeline_agmem_bgmem_creg_v1.hpp b/include/ck_tile/ops/flatmm/pipeline/flatmm_pipeline_agmem_bgmem_creg_v1.hpp index ad6641bc13..611aff318f 100644 --- a/include/ck_tile/ops/flatmm/pipeline/flatmm_pipeline_agmem_bgmem_creg_v1.hpp +++ b/include/ck_tile/ops/flatmm/pipeline/flatmm_pipeline_agmem_bgmem_creg_v1.hpp @@ -6,7 +6,6 @@ #include "ck_tile/core.hpp" #include "ck_tile/host/concat.hpp" #include "ck_tile/ops/flatmm/pipeline/flatmm_pipeline_agmem_bgmem_creg_v1_policy.hpp" -#include "ck_tile/host/concat.hpp" namespace ck_tile { diff --git a/include/ck_tile/ops/gemm/pipeline/gemm_pipeline_problem.hpp b/include/ck_tile/ops/gemm/pipeline/gemm_pipeline_problem.hpp index 893c9d1ad3..0b38e7789e 100644 --- a/include/ck_tile/ops/gemm/pipeline/gemm_pipeline_problem.hpp +++ b/include/ck_tile/ops/gemm/pipeline/gemm_pipeline_problem.hpp @@ -30,7 +30,8 @@ struct GemmPipelineProblemBase using BLayout = remove_cvref_t; using CLayout = remove_cvref_t; - static constexpr bool TransposeC = Traits::TransposeC; + static constexpr bool TransposeC = Traits::TransposeC; + static constexpr bool UseStructuredSparsity = Traits::UseStructuredSparsity; static constexpr index_t kBlockSize = BlockGemmShape::NumWarps * get_warp_size(); diff --git a/include/ck_tile/ops/gemm/pipeline/tile_gemm_traits.hpp b/include/ck_tile/ops/gemm/pipeline/tile_gemm_traits.hpp index ecf861e4e8..a31004b425 100644 --- a/include/ck_tile/ops/gemm/pipeline/tile_gemm_traits.hpp +++ b/include/ck_tile/ops/gemm/pipeline/tile_gemm_traits.hpp @@ -12,8 +12,7 @@ 
template + typename CLayout_> struct TileGemmTraits { static constexpr bool kPadM = kPadM_; @@ -28,7 +27,7 @@ struct TileGemmTraits using CLayout = CLayout_; static constexpr bool TransposeC = false; - static constexpr bool UseStructuredSparsity = UseStructuredSparsity_; + static constexpr bool UseStructuredSparsity = false; }; template Date: Wed, 30 Apr 2025 17:58:40 -0400 Subject: [PATCH 086/443] updated Doxyfile and added the class list (#2147) * updated Doxyfile and added the class list * Update Doxyfile --- docs/doxygen/Doxyfile | 15 +++---- docs/index.rst | 6 +-- .../Composable-Kernel-API-reference.rst | 42 ------------------- docs/sphinx/_toc.yml.in | 6 +-- 4 files changed, 14 insertions(+), 55 deletions(-) delete mode 100644 docs/reference/Composable-Kernel-API-reference.rst diff --git a/docs/doxygen/Doxyfile b/docs/doxygen/Doxyfile index d6f38e0ca9..4367aabc95 100644 --- a/docs/doxygen/Doxyfile +++ b/docs/doxygen/Doxyfile @@ -42,19 +42,19 @@ DOXYFILE_ENCODING = UTF-8 # title of most generated pages and in a few other places. # The default value is: My Project. -PROJECT_NAME = "ck" +PROJECT_NAME = "Composable Kernel" # The PROJECT_NUMBER tag can be used to enter a project or revision number. This # could be handy for archiving the generated documentation or if some version # control system is used. -PROJECT_NUMBER = v3.0.1.0 +PROJECT_NUMBER = # Using the PROJECT_BRIEF tag one can provide an optional one line description # for a project that appears at the top of each page and should give viewer a # quick idea about the purpose of the project. Keep the description short. -PROJECT_BRIEF = "prototype interfaces compatible with ROCm platform and HIP" +PROJECT_BRIEF = "Prototype interfaces compatible with ROCm platform and HiP" # With the PROJECT_LOGO tag one can specify a logo or an icon that is included # in the documentation. 
The maximum height of the logo should not exceed 55 @@ -949,8 +949,8 @@ INPUT = ../../include/ck/tensor_operation/gpu/grid \ ../../include/ck/tensor_operation/gpu/block \ ../../include/ck/tensor_operation/gpu/thread \ ../../library/include/ck/library/utility \ - ../../include/ck/wrapper - + ../../include/ck/wrapper \ + ../../include/ck_tile # This tag can be used to specify the character encoding of the source files # that doxygen parses. Internally doxygen uses the UTF-8 encoding. Doxygen uses @@ -1161,7 +1161,8 @@ FILTER_SOURCE_PATTERNS = # (index.html). This can be useful if you have a project on for instance GitHub # and want to reuse the introduction page also for the doxygen output. -USE_MDFILE_AS_MAINPAGE = ../../README.md + +USE_MDFILE_AS_MAINPAGE = # The Fortran standard specifies that for fixed formatted Fortran code all # characters from position 72 are to be considered as comment. A common @@ -1370,7 +1371,7 @@ HTML_EXTRA_STYLESHEET = ../_doxygen/extra_stylesheet.css # files will be copied as-is; there are no commands or markers available. # This tag requires that the tag GENERATE_HTML is set to YES. -HTML_EXTRA_FILES = +HTML_EXTRA_FILES = ../_doxygen/extra_stylesheet.css # The HTML_COLORSTYLE tag can be used to specify if the generated HTML output # should be rendered with a dark or light theme. 
diff --git a/docs/index.rst b/docs/index.rst index 6d46eb49b1..4cc26a1d3e 100644 --- a/docs/index.rst +++ b/docs/index.rst @@ -35,9 +35,9 @@ The Composable Kernel repository is located at `https://github.com/ROCm/composab * :doc:`Composable Kernel supported scalar types <./reference/Composable_Kernel_supported_scalar_types>` * :doc:`Composable Kernel custom types <./reference/Composable_Kernel_custom_types>` * :doc:`Composable Kernel vector utilities <./reference/Composable_Kernel_vector_utilities>` - * :ref:`api-reference` - * :ref:`wrapper` - + * :ref:`wrapper` + * :doc:`Composable Kernel complete class list <./doxygen/html/annotated>` + To contribute to the documentation refer to `Contributing to ROCm `_. You can find licensing information on the `Licensing `_ page. diff --git a/docs/reference/Composable-Kernel-API-reference.rst b/docs/reference/Composable-Kernel-API-reference.rst deleted file mode 100644 index b6ee9f7790..0000000000 --- a/docs/reference/Composable-Kernel-API-reference.rst +++ /dev/null @@ -1,42 +0,0 @@ -.. meta:: - :description: Composable Kernel documentation and API reference library - :keywords: composable kernel, CK, ROCm, API, documentation - -.. _api-reference: - -******************************************************************** -Composable Kernel API reference guide -******************************************************************** - -This document contains details of the APIs for the Composable Kernel library and introduces some of the key design principles that are used to write new classes that extend the functionality of the Composable Kernel library. - -================= -DeviceMem -================= - -.. doxygenstruct:: DeviceMem - -============================= -Kernels For Flashattention -============================= - -The Flashattention algorithm is defined in :cite:t:`dao2022flashattention`. This section lists -the classes that are used in the CK GPU implementation of Flashattention. - -**Gridwise classes** - -.. 
doxygenstruct:: ck::GridwiseBatchedGemmSoftmaxGemm_Xdl_CShuffle - -**Blockwise classes** - -.. doxygenstruct:: ck::ThreadGroupTensorSliceTransfer_v4r1 - -.. doxygenstruct:: ck::BlockwiseGemmXdlops_v2 - -.. doxygenstruct:: ck::BlockwiseSoftmax - -**Threadwise classes** - -.. doxygenstruct:: ck::ThreadwiseTensorSliceTransfer_StaticToStatic - -.. bibliography:: diff --git a/docs/sphinx/_toc.yml.in b/docs/sphinx/_toc.yml.in index df98998224..2ef3383d84 100644 --- a/docs/sphinx/_toc.yml.in +++ b/docs/sphinx/_toc.yml.in @@ -32,10 +32,10 @@ subtrees: title: Composable Kernel custom types - file: reference/Composable_Kernel_vector_utilities.rst title: Composable Kernel vector utilities - - file: reference/Composable-Kernel-API-reference.rst - title: Composable Kernel API reference - file: reference/Composable-Kernel-wrapper.rst - title: Composable Kernel Wrapper + title: Composable Kernel wrapper + - file: doxygen/html/annotated.rst + title: Composable Kernel class list - caption: About entries: From 1d8ef407604882b03857ba75d71be29ccd0ed592 Mon Sep 17 00:00:00 2001 From: Aviral Goel Date: Wed, 30 Apr 2025 18:43:36 -0500 Subject: [PATCH 087/443] Add documentation for ck_tile::array (#2078) * addded documentation for ck_tile::array * clang format fix * spelling errros Co-authored-by: Adam Osewski <19374865+aosewski@users.noreply.github.com> * spelling errros Co-authored-by: Adam Osewski <19374865+aosewski@users.noreply.github.com> * Apply suggestions from code review Co-authored-by: spolifroni-amd * Revert "spelling errros" This reverts commit 4179e7d193e27b0b0b500ad50a87ae9f8dba8334. * Revert "spelling errros" This reverts commit 3f90733dbe27dffb9cb113a007059cf149cafb48. 
--------- Co-authored-by: Adam Osewski <19374865+aosewski@users.noreply.github.com> Co-authored-by: spolifroni-amd Co-authored-by: John Afaganis --- include/ck_tile/core/container/array.hpp | 27 ++++++++++++++++++++++++ 1 file changed, 27 insertions(+) diff --git a/include/ck_tile/core/container/array.hpp b/include/ck_tile/core/container/array.hpp index fa63597db4..94aa40e278 100644 --- a/include/ck_tile/core/container/array.hpp +++ b/include/ck_tile/core/container/array.hpp @@ -19,6 +19,25 @@ namespace ck_tile { // array buf {3, 2}; => {3, 2, 2, 2} (not {3,2,0,0}) // use make_array_with({...}) to construct an array with compatible behavior as old ck // TODO: manually added constructor same as old ck +/** + * @brief A fixed-size array container similar to std::array with additional utilities. + * + * This template class provides a lightweight fixed-size array with value semantics, + * supporting both host and device functionality for GPU programming. It includes + * specialized initialization methods and type punning capabilities. + * + * @tparam T_ The type of elements in the array + * @tparam N_ The fixed number of elements in the array + * + * @note This implementation provides additional features beyond std::array: + * - GPU compatibility via CK_TILE_HOST_DEVICE macros + * - Type punning via get_as() and set_as() methods + * - Various specialized access methods + * - Specialized initialization behaviors + * + * The initializer_list constructor fills remaining elements with the last value + * provided if the list size is smaller than N, which is different than std::array. + */ template struct array { @@ -142,6 +161,14 @@ struct array // empty Array +/// @brief Specialization of array container for zero elements. +/// +/// This is a specialization of the array container template for the case where the number of +/// elements is 0. It provides the same interface as the general array template, but with operations +/// appropriate for an empty array. 
+/// +/// @tparam T The type of elements stored in the array (not used in this specialization but +/// maintained for API consistency). template struct array { From b9d17bdb115c034e9a1028b3adca63762784d9b2 Mon Sep 17 00:00:00 2001 From: Illia Silin <98187287+illsilin@users.noreply.github.com> Date: Thu, 1 May 2025 07:04:57 -0700 Subject: [PATCH 088/443] add write permissions in workspace (#2154) --- Jenkinsfile | 1 + 1 file changed, 1 insertion(+) diff --git a/Jenkinsfile b/Jenkinsfile index 3e22eb2f01..68999d8aa6 100644 --- a/Jenkinsfile +++ b/Jenkinsfile @@ -76,6 +76,7 @@ def check_host() { if ("${env.CK_SCCACHE}" != "null"){ def SCCACHE_SERVER="${env.CK_SCCACHE.split(':')[0]}" echo "sccache server: ${SCCACHE_SERVER}" + sh "chmod +w -R ${env.WORKSPACE}" sh '''ping -c 1 -p 6379 "${SCCACHE_SERVER}" | echo $? > tmp.txt''' def output = readFile(file: "tmp.txt") echo "tmp.txt contents: \$output" From 79b0bfeb41db45de0cb65fdf24d27201ea0ae0e6 Mon Sep 17 00:00:00 2001 From: Andriy Roshchenko <107577548+andriy-ca@users.noreply.github.com> Date: Thu, 1 May 2025 11:55:48 -0600 Subject: [PATCH 089/443] MX GEMM - Add FP8 GEMM Tests for Different Layouts (#2152) * Add gemm_mx_fp8_bf8 example with row-major B * Add more overloads of MX MFMA instructions * Add MK_KN (RRR) tests * Add KM_NK (CCR) tests * Add more problem sizes to Large tests * Add test_gemm_mx to the list of regression tests --- example/67_gemm_microscaling/CMakeLists.txt | 3 + .../67_gemm_microscaling/gemm_mx_common.hpp | 2 +- .../67_gemm_microscaling/gemm_mx_fp8_bf8.cpp | 97 ++++++++++ .../grid/gridwise_gemm_xdl_cshuffle_v3_mx.hpp | 14 +- .../tensor_operation/gpu/warp/xdlops_gemm.hpp | 18 ++ include/ck/utility/amd_xdlops.hpp | 101 +++++++++- .../tensor_operation_instance/gpu/gemm_mx.hpp | 50 ++++- .../gpu/gemm_mx/CMakeLists.txt | 6 +- ...device_gemm_mx_xdl_bf8_f8_f16_mk_kn_mn.hpp | 61 ++++++ ...l_bf8_f8_f16_mk_kn_mn_default_instance.cpp | 32 ++++ ...device_gemm_mx_xdl_f8_f8_bf16_km_nk_mn.hpp | 62 ++++++ 
...l_f8_f8_bf16_km_nk_mn_default_instance.cpp | 32 ++++ test/CMakeLists.txt | 1 + test/gemm_mx/test_gemm_mx.cpp | 179 +++++++++++++++++- test/gemm_mx/test_gemm_mx_util.hpp | 2 +- 15 files changed, 642 insertions(+), 18 deletions(-) create mode 100644 example/67_gemm_microscaling/gemm_mx_fp8_bf8.cpp create mode 100644 library/src/tensor_operation_instance/gpu/gemm_mx/device_gemm_mx_xdl_bf8_f8_f16/device_gemm_mx_xdl_bf8_f8_f16_mk_kn_mn.hpp create mode 100644 library/src/tensor_operation_instance/gpu/gemm_mx/device_gemm_mx_xdl_bf8_f8_f16/device_gemm_mx_xdl_bf8_f8_f16_mk_kn_mn_default_instance.cpp create mode 100644 library/src/tensor_operation_instance/gpu/gemm_mx/device_gemm_mx_xdl_f8_f8_bf16/device_gemm_mx_xdl_f8_f8_bf16_km_nk_mn.hpp create mode 100644 library/src/tensor_operation_instance/gpu/gemm_mx/device_gemm_mx_xdl_f8_f8_bf16/device_gemm_mx_xdl_f8_f8_bf16_km_nk_mn_default_instance.cpp diff --git a/example/67_gemm_microscaling/CMakeLists.txt b/example/67_gemm_microscaling/CMakeLists.txt index 34125465a9..1a1db51c37 100644 --- a/example/67_gemm_microscaling/CMakeLists.txt +++ b/example/67_gemm_microscaling/CMakeLists.txt @@ -6,3 +6,6 @@ add_example_dependencies(example_gemm_mx example_gemm_mx_fp8) add_example_executable(example_gemm_mx_bf8 gemm_mx_bf8.cpp) add_example_dependencies(example_gemm_mx example_gemm_mx_bf8) +add_example_executable(example_gemm_mx_fp8_bf8 gemm_mx_fp8_bf8.cpp) +add_example_dependencies(example_gemm_mx example_gemm_mx_fp8_bf8) + diff --git a/example/67_gemm_microscaling/gemm_mx_common.hpp b/example/67_gemm_microscaling/gemm_mx_common.hpp index 32ef975192..99ed2a23b9 100644 --- a/example/67_gemm_microscaling/gemm_mx_common.hpp +++ b/example/67_gemm_microscaling/gemm_mx_common.hpp @@ -235,7 +235,7 @@ bool run_mx_gemm(const ProblemSizeSplitK& problem_size, const ExecutionConfig& c break; case 2: - a_m_k.GenerateTensorValue(GeneratorTensor_3{-2.0, 2.0}); + a_m_k.GenerateTensorValue(GeneratorTensor_3{-2.0, 2.0}); 
a_m_k_scale.GenerateTensorValue(GeneratorTensor_3{powf(2.0f, -125.0f), 1.0f}); b_k_n.GenerateTensorValue(GeneratorTensor_3{-2.0, 2.0}); diff --git a/example/67_gemm_microscaling/gemm_mx_fp8_bf8.cpp b/example/67_gemm_microscaling/gemm_mx_fp8_bf8.cpp new file mode 100644 index 0000000000..ce4ebc0a40 --- /dev/null +++ b/example/67_gemm_microscaling/gemm_mx_fp8_bf8.cpp @@ -0,0 +1,97 @@ +// SPDX-License-Identifier: MIT +// Copyright (c) 2025, Advanced Micro Devices, Inc. All rights reserved. + +#include "gemm_mx_common.hpp" + +using ADataType = ck::f8_t; +using BDataType = ck::bf8_t; + +using XDataType = ck::e8m0_bexp_t; + +using CDataType = ck::bhalf_t; +using AccDataType = float; +using CShuffleDataType = CDataType; + +using ALayout = Row; +using BLayout = Row; +using CLayout = Row; + +using AElementOp = PassThrough; // elementwise transformation for A matrix +using BElementOp = PassThrough; // elementwise transformation for B matrix +using CElementOp = PassThrough; // elementwise transformation for C matrix + +constexpr ck::index_t ScaleBlockSize = 32; // scaling block size + +constexpr auto GemmSpec = ck::tensor_operation::device::GemmSpecialization::Default; +constexpr auto BlkGemmPSched = ck::BlockGemmPipelineScheduler::Intrawave; +constexpr auto BlkGemmPVer = ck::BlockGemmPipelineVersion::v1; + +using DeviceOpInstance = ck::tensor_operation::device::DeviceGemmMX_Xdl_CShuffleV3< + ALayout, // ALayout + BLayout, // BLayout + CLayout, // CLayout + ADataType, // ADataType + XDataType, // AScaleDataType + BDataType, // BDataType + XDataType, // BScaleDataType + CDataType, // CDataType + AccDataType, // GemmAccDataType + CShuffleDataType, // CShuffleDataType + AElementOp, // AElementwiseOperation + BElementOp, // BElementwiseOperation + CElementOp, // CElementwiseOperation + GemmSpec, // GemmSpec + ScaleBlockSize, // ScaleBlockSize: Scaling block size + 256, // BlockSize: Thread block size + 256, // MPerBlock + 256, // NPerBlock + 128, // KPerBlock + 16, // AK1 + 8, // 
BK1 + 16, // MPerXDL + 16, // NPerXDL + 8, // MXdlPerWave + 8, // NXdlPerWave + S<8, 32, 1>, // ABlockTransferThreadClusterLengths_AK0_M_AK1 + S<1, 0, 2>, // ABlockTransferThreadClusterArrangeOrder + S<1, 0, 2>, // ABlockTransferSrcAccessOrder + 2, // ABlockTransferSrcVectorDim + 16, // ABlockTransferSrcScalarPerVector + 16, // ABlockTransferDstScalarPerVector_AK1 + false, // ABlockLdsExtraM + S<16, 16, 1>, // BBlockTransferThreadClusterLengths_BK0_N_BK1 + S<0, 2, 1>, // BBlockTransferThreadClusterArrangeOrder + S<0, 2, 1>, // BBlockTransferSrcAccessOrder + 1, // BBlockTransferSrcVectorDim + 16, // BBlockTransferSrcScalarPerVector + 8, // BBlockTransferDstScalarPerVector_BK1 + false, // BBlockLdsExtraN + 1, // CShuffleMXdlPerWavePerShuffle + 2, // CShuffleNXdlPerWavePerShuffle + S<1, 32, 1, 8>, // CShuffleBlockTransferClusterLengths_MBlock_MPerBlock_NBlock_NPerBlock + 8, // CShuffleBlockTransferScalarPerVector_NPerBlock + BlkGemmPSched, // BlkGemmPipeSched + BlkGemmPVer, // BlkGemmPipelineVer + ADataType, // ComputeTypeA + BDataType // ComputeTypeB + >; + +int main(int argc, char* argv[]) +{ + return run_mx_gemm_example(argc, argv) + ? 0 + : -1; +} diff --git a/include/ck/tensor_operation/gpu/grid/gridwise_gemm_xdl_cshuffle_v3_mx.hpp b/include/ck/tensor_operation/gpu/grid/gridwise_gemm_xdl_cshuffle_v3_mx.hpp index 44d515e76c..1154fa2aa3 100644 --- a/include/ck/tensor_operation/gpu/grid/gridwise_gemm_xdl_cshuffle_v3_mx.hpp +++ b/include/ck/tensor_operation/gpu/grid/gridwise_gemm_xdl_cshuffle_v3_mx.hpp @@ -797,12 +797,13 @@ struct GridwiseGemmMX_xdl_cshuffle_v3 // kfold and mpair dimension is not always required. // more dimension in merge_transform increase the difficulty of generating immarg offset // for compiler. 
- constexpr auto M0 = ABlockTransferThreadClusterLengths_AK0_M_AK1{}.At(I1); - constexpr auto M1 = MPerBlock / M0; + constexpr auto WaveSize = 64; + constexpr auto M0 = ABlockTransferThreadClusterLengths_AK0_M_AK1{}.At(I1); + constexpr auto M1 = MPerBlock / M0; constexpr auto KThreadWrite = ABlockTransferThreadClusterLengths_AK0_M_AK1{}.At(I0); constexpr auto K0PerThreadWrite = AK0Number / KThreadWrite; - constexpr auto KThreadRead = BlockwiseGemmPipe::WaveSize / MPerXdl; + constexpr auto KThreadRead = WaveSize / MPerXdl; constexpr auto K0PerThreadRead = AK0Number / KThreadRead; constexpr auto kfold = (AK1Number * M0 * sizeof(ADataType) > 128) @@ -929,12 +930,13 @@ struct GridwiseGemmMX_xdl_cshuffle_v3 } else // RowMajor B { - constexpr auto N0 = BBlockTransferThreadClusterLengths_BK0_N_BK1{}.At(I1); - constexpr auto N1 = NPerBlock / N0; + constexpr auto WaveSize = 64; + constexpr auto N0 = BBlockTransferThreadClusterLengths_BK0_N_BK1{}.At(I1); + constexpr auto N1 = NPerBlock / N0; constexpr auto KThreadWrite = BBlockTransferThreadClusterLengths_BK0_N_BK1{}.At(I0); constexpr auto K0PerThreadWrite = BK0Number / KThreadWrite; - constexpr auto KThreadRead = BlockwiseGemmPipe::WaveSize / NPerXdl; + constexpr auto KThreadRead = WaveSize / NPerXdl; constexpr auto K0PerThreadRead = BK0Number / KThreadRead; constexpr auto kfold = (BK1Number * N0 * sizeof(BDataType) > 128) diff --git a/include/ck/tensor_operation/gpu/warp/xdlops_gemm.hpp b/include/ck/tensor_operation/gpu/warp/xdlops_gemm.hpp index 08c4e4ba6e..06268f3cfb 100644 --- a/include/ck/tensor_operation/gpu/warp/xdlops_gemm.hpp +++ b/include/ck/tensor_operation/gpu/warp/xdlops_gemm.hpp @@ -1129,6 +1129,12 @@ struct MfmaSelector return MfmaInstr::mfma_scale_f32_32x32x64f8f6f4; } + template <> + constexpr auto GetMfma() + { + return MfmaInstr::mfma_scale_f32_32x32x64f8f6f4; + } + template <> constexpr auto GetMfma() { @@ -1147,6 +1153,18 @@ struct MfmaSelector return MfmaInstr::mfma_scale_f32_16x16x128f8f6f4; } + 
template <> + constexpr auto GetMfma() + { + return MfmaInstr::mfma_scale_f32_16x16x128f8f6f4; + } + + template <> + constexpr auto GetMfma() + { + return MfmaInstr::mfma_scale_f32_16x16x128f8f6f4; + } + template <> constexpr auto GetMfma() { diff --git a/include/ck/utility/amd_xdlops.hpp b/include/ck/utility/amd_xdlops.hpp index a8c3baa31b..71e1937a23 100644 --- a/include/ck/utility/amd_xdlops.hpp +++ b/include/ck/utility/amd_xdlops.hpp @@ -532,7 +532,44 @@ struct intrin_mfma_scale_f32_32x32x64f8f6f4<32, 32> reg_a, reg_b, reg_c.template AsType()[Number<0>{}], - 0, // cbsz + 0, // cbsz {0 FP8 E4M3; 1 FP8 E5M2; 2 FP6 E2M3; 3 FP6 E3M2; 4 FP4 E2M1} + 0, // blgp + 0, // OPSEL + scale_a, + 0, // OPSEL + scale_b); + // XXX: Note on the scale_a and scale_b parameters: + // If compiler detects that one or both scales are constant values, it will treat that + // constant as F32 constant. I.e., if scale_a at some point was declared as + // `e8m0_bexp_t a_scale{1.0f}`, the instruction would only work if scale_a parameter is + // assigned value `bit_cast(static_cast(a_scale))`. + + // XXX: Note on the OPSEL parameters: Instruction always takes byte0 as a scale value even + // when OPSEL is set otherwise. 
+#else + ignore = reg_a; + ignore = scale_a; + ignore = reg_b; + ignore = scale_b; + ignore = reg_c; +#endif + } + + template + __device__ static void Run(const bf8x32_t& reg_a, + const int32_t& scale_a, + const f8x32_t& reg_b, + const int32_t& scale_b, + FloatC& reg_c) + { +#if defined(__gfx950__) + // https://github.com/ROCm/llvm-project/blob/656552edc693e2bb4abc9258399c39d190fce2b3/llvm/test/Verifier/AMDGPU/mfma-scale.ll#L10 + reg_c.template AsType()(Number<0>{}) = + __builtin_amdgcn_mfma_scale_f32_32x32x64_f8f6f4( + reg_a, + reg_b, + reg_c.template AsType()[Number<0>{}], + 1, // cbsz {0 FP8 E4M3; 1 FP8 E5M2; 2 FP6 E2M3; 3 FP6 E3M2; 4 FP4 E2M1} 0, // blgp 0, // OPSEL scale_a, @@ -576,7 +613,7 @@ struct intrin_mfma_scale_f32_16x16x128f8f6f4<16, 16> reg_a, reg_b, reg_c.template AsType()[Number<0>{}], - 0, // cbsz + 0, // cbsz {0 FP8 E4M3; 1 FP8 E5M2; 2 FP6 E2M3; 3 FP6 E3M2; 4 FP4 E2M1} 0, // blgp 0, // OPSEL scale_a, @@ -605,7 +642,7 @@ struct intrin_mfma_scale_f32_16x16x128f8f6f4<16, 16> reg_a, reg_b, reg_c.template AsType()[Number<0>{}], - 1, // cbsz + 1, // cbsz {0 FP8 E4M3; 1 FP8 E5M2; 2 FP6 E2M3; 3 FP6 E3M2; 4 FP4 E2M1} 1, // blgp 0, // OPSEL scale_a, @@ -617,6 +654,64 @@ struct intrin_mfma_scale_f32_16x16x128f8f6f4<16, 16> ignore = reg_b; ignore = scale_b; ignore = reg_c; +#endif + } + + template + __device__ static void Run(const f8x32_t& reg_a, + const int32_t& scale_a, + const bf8x32_t& reg_b, + const int32_t& scale_b, + FloatC& reg_c) + { +#if defined(__gfx950__) + // https://github.com/ROCm/llvm-project/blob/656552edc693e2bb4abc9258399c39d190fce2b3/llvm/test/Verifier/AMDGPU/mfma-scale.ll#L10 + reg_c.template AsType()(Number<0>{}) = + __builtin_amdgcn_mfma_scale_f32_16x16x128_f8f6f4( + reg_a, + reg_b, + reg_c.template AsType()[Number<0>{}], + 0, // cbsz {0 FP8 E4M3; 1 FP8 E5M2; 2 FP6 E2M3; 3 FP6 E3M2; 4 FP4 E2M1} + 1, // blgp + 0, // OPSEL + scale_a, + 0, // OPSEL + scale_b); +#else + ignore = reg_a; + ignore = scale_a; + ignore = reg_b; + ignore = 
scale_b; + ignore = reg_c; +#endif + } + + template + __device__ static void Run(const bf8x32_t& reg_a, + const int32_t& scale_a, + const f8x32_t& reg_b, + const int32_t& scale_b, + FloatC& reg_c) + { +#if defined(__gfx950__) + // https://github.com/ROCm/llvm-project/blob/656552edc693e2bb4abc9258399c39d190fce2b3/llvm/test/Verifier/AMDGPU/mfma-scale.ll#L10 + reg_c.template AsType()(Number<0>{}) = + __builtin_amdgcn_mfma_scale_f32_16x16x128_f8f6f4( + reg_a, + reg_b, + reg_c.template AsType()[Number<0>{}], + 1, // cbsz {0 FP8 E4M3; 1 FP8 E5M2; 2 FP6 E2M3; 3 FP6 E3M2; 4 FP4 E2M1} + 0, // blgp + 0, // OPSEL + scale_a, + 0, // OPSEL + scale_b); +#else + ignore = reg_a; + ignore = scale_a; + ignore = reg_b; + ignore = scale_b; + ignore = reg_c; #endif } }; diff --git a/library/include/ck/library/tensor_operation_instance/gpu/gemm_mx.hpp b/library/include/ck/library/tensor_operation_instance/gpu/gemm_mx.hpp index 1c40ccec5d..4af5143f45 100644 --- a/library/include/ck/library/tensor_operation_instance/gpu/gemm_mx.hpp +++ b/library/include/ck/library/tensor_operation_instance/gpu/gemm_mx.hpp @@ -45,6 +45,34 @@ void add_device_gemm_mx_xdl_f8_f8_bf16_mk_nk_mn_default_instances( PassThrough, PassThrough>>>& instances); +void add_device_gemm_mx_xdl_bf8_f8_f16_mk_kn_mn_default_instances( + std::vector>>& instances); + +void add_device_gemm_mx_xdl_f8_f8_bf16_km_nk_mn_default_instances( + std::vector>>& instances); + template && is_same_v && + is_same_v) + { + + add_device_gemm_mx_xdl_f8_f8_bf16_mk_nk_mn_default_instances(op_ptrs); + } + } + else if constexpr(is_same_v && is_same_v && + is_same_v) + { + if constexpr(is_same_v && is_same_v && + is_same_v) + { + + add_device_gemm_mx_xdl_bf8_f8_f16_mk_kn_mn_default_instances(op_ptrs); + } + } + else if constexpr(is_same_v && is_same_v && + is_same_v) + { if constexpr(is_same_v && is_same_v && is_same_v) { - add_device_gemm_mx_xdl_f8_f8_bf16_mk_nk_mn_default_instances(op_ptrs); + 
add_device_gemm_mx_xdl_f8_f8_bf16_km_nk_mn_default_instances(op_ptrs); } } diff --git a/library/src/tensor_operation_instance/gpu/gemm_mx/CMakeLists.txt b/library/src/tensor_operation_instance/gpu/gemm_mx/CMakeLists.txt index a166fc4ce4..0442bed130 100644 --- a/library/src/tensor_operation_instance/gpu/gemm_mx/CMakeLists.txt +++ b/library/src/tensor_operation_instance/gpu/gemm_mx/CMakeLists.txt @@ -1,14 +1,18 @@ # ONLY MX_KERNELS set(GEMM_MX_INSTANCES) -list(APPEND GEMM_MX_INSTANCES +list(APPEND GEMM_MX_INSTANCES device_gemm_mx_xdl_f8_f8_f16/device_gemm_mx_xdl_f8_f8_f16_mk_nk_mn_default_instance.cpp device_gemm_mx_xdl_f8_f8_bf16/device_gemm_mx_xdl_f8_f8_bf16_mk_nk_mn_default_instance.cpp + device_gemm_mx_xdl_f8_f8_bf16/device_gemm_mx_xdl_f8_f8_bf16_km_nk_mn_default_instance.cpp + device_gemm_mx_xdl_bf8_f8_f16/device_gemm_mx_xdl_bf8_f8_f16_mk_kn_mn_default_instance.cpp ) set_source_files_properties(device_gemm_mx_xdl_f8_f8_f16/device_gemm_mx_xdl_f8_f8_f16_mk_nk_mn_default_instance.cpp PROPERTIES COMPILE_OPTIONS ";-mllvm;-greedy-reverse-local-assignment=1") set_source_files_properties(device_gemm_mx_xdl_f8_f8_bf16/device_gemm_mx_xdl_f8_f8_bf16_mk_nk_mn_default_instance.cpp PROPERTIES COMPILE_OPTIONS ";-mllvm;-greedy-reverse-local-assignment=1") +set_source_files_properties(device_gemm_mx_xdl_f8_f8_bf16/device_gemm_mx_xdl_f8_f8_bf16_km_nk_mn_default_instance.cpp PROPERTIES COMPILE_OPTIONS ";-mllvm;-greedy-reverse-local-assignment=1") +set_source_files_properties(device_gemm_mx_xdl_bf8_f8_f16/device_gemm_mx_xdl_bf8_f8_f16_mk_kn_mn_default_instance.cpp PROPERTIES COMPILE_OPTIONS ";-mllvm;-greedy-reverse-local-assignment=1") add_instance_library(device_gemm_mx_instance ${GEMM_MX_INSTANCES}) diff --git a/library/src/tensor_operation_instance/gpu/gemm_mx/device_gemm_mx_xdl_bf8_f8_f16/device_gemm_mx_xdl_bf8_f8_f16_mk_kn_mn.hpp b/library/src/tensor_operation_instance/gpu/gemm_mx/device_gemm_mx_xdl_bf8_f8_f16/device_gemm_mx_xdl_bf8_f8_f16_mk_kn_mn.hpp new file mode 100644 
index 0000000000..25dd68a207 --- /dev/null +++ b/library/src/tensor_operation_instance/gpu/gemm_mx/device_gemm_mx_xdl_bf8_f8_f16/device_gemm_mx_xdl_bf8_f8_f16_mk_kn_mn.hpp @@ -0,0 +1,61 @@ +// SPDX-License-Identifier: MIT +// Copyright (c) 2025, Advanced Micro Devices, Inc. All rights reserved. + +#include "ck/ck.hpp" +#include "ck/tensor_operation/gpu/device/tensor_layout.hpp" +#include "ck/tensor_operation/gpu/device/gemm_specialization.hpp" +#include "ck/tensor_operation/gpu/device/impl/device_gemm_xdl_cshuffle_v3_mx.hpp" + +#include "ck/library/tensor_operation_instance/add_device_operation_instance.hpp" + +namespace ck { +namespace tensor_operation { +namespace device { +namespace instance { + +using F8 = f8_t; +using BF8 = bf8_t; +using F16 = half_t; +using BF16 = bhalf_t; +using F32 = float; +using E8M0 = ck::e8m0_bexp_t; + +using Row = tensor_layout::gemm::RowMajor; +using Col = tensor_layout::gemm::ColumnMajor; + +template +using S = Sequence; + +using PassThrough = element_wise::PassThrough; + +static constexpr auto GemmDefault = GemmSpecialization::Default; +static constexpr auto GemmKPadding = GemmSpecialization::KPadding; +static constexpr auto GemmMNPadding = GemmSpecialization::MNPadding; +static constexpr auto GemmMNKPadding = GemmSpecialization::MNKPadding; + +static constexpr auto Intrawave = BlockGemmPipelineScheduler::Intrawave; +static constexpr auto Interwave = BlockGemmPipelineScheduler::Interwave; + +static constexpr auto ScaleBlockSize = 32; + +template +using device_gemm_mx_xdl_bf8_f8_f16_mk_kn_mn_instances = std::tuple< +// clang-format off + //#########################| ALayout| BLayout| CLayout|AData|AScale|BData|BScale| CData| AccData| Cshuffle| A| B| C| GEMM| Scale Block| Block| MPer| NPer| KPer| AK1| BK1|MPer| NPer| MXdl| NXdl| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockLds| BBlockTransfer| BBlockTransfer| BBlockTransfer| BlockTransfer| BBlockTransfer| BBlockTransfer| 
BBlockLds| CShuffle| CShuffle| CBlockTransferClusterLengths| CBlockTransfer| Block-wiseGemm| Block-wiseGemm| + //#########################| | | | Type| Data| Type| Data| Type| Type| Type| Elementwise| Elementwise| Elementwise|Specialization| Size| Size| Block| Block| Block| | | XDL| XDL| Per| Per| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraM| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraN| MXdlPerWave| NXdlPerWave| _MBlock_MXdlPerWave_MWaveMPerXdl| ScalarPerVector| Pipeline| Pipeline| + //#########################| | | | | Type| | Type| | | | Operation| Operation| Operation| | | | | | | | | | | Wave| Wave| Lengths_K0_M_K1| ArrangeOrder| | | PerVector| PerVector_K1| | Lengths_K0_N_K1| ArrangeOrder| | | PerVector| PerVector_K1| | PerShuffle| PerShuffle| _NBlock_NXdlPerWave_NWaveNPerXdl| _NWaveNPerXdl| Scheduler| Verision| + //#########################| | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | +#if defined(__gfx950__) || defined(CK_USE_NATIVE_MX_SUPPORT) + DeviceGemmMX_Xdl_CShuffleV3< Row, Row, Row, BF8, E8M0, F8, E8M0, F16, F32, F16, PassThrough, PassThrough, PassThrough, GemmSpec, ScaleBlockSize, 128, 64, 16, 128, 16, 4, 16, 16, 2, 1, S<8, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, 0, S<32, 4, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 4, 4, 0, 1, 1, S<1, 16, 1, 8>, 2, BlkGemmPipeSched, BlockGemmPipelineVersion::v1>, + DeviceGemmMX_Xdl_CShuffleV3< Row, Row, Row, BF8, E8M0, F8, E8M0, F16, F32, F16, PassThrough, PassThrough, PassThrough, GemmSpec, ScaleBlockSize, 256, 256, 256, 256, 16, 4, 32, 32, 4, 4, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, false, S<16, 16, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 16, 4, false, 1, 1, S<1, 32, 1, 8>, 8, BlkGemmPipeSched, BlockGemmPipelineVersion::v1>, + DeviceGemmMX_Xdl_CShuffleV3< Row, Row, Row, BF8, E8M0, F8, E8M0, F16, F32, F16, PassThrough, PassThrough, PassThrough, GemmSpec, ScaleBlockSize, 256, 64, 64, 
256, 16, 4, 32, 32, 1, 1, S<8, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, 0, S<32, 8, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 8, 4, 0, 1, 1, S<1, 32, 1, 8>, 8, BlkGemmPipeSched, BlockGemmPipelineVersion::v1>, + DeviceGemmMX_Xdl_CShuffleV3< Row, Row, Row, BF8, E8M0, F8, E8M0, F16, F32, F16, PassThrough, PassThrough, PassThrough, GemmSpec, ScaleBlockSize, 256, 128, 128, 128, 16, 4, 32, 32, 2, 2, S<8, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, 0, S<32, 8, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 16, 4, 0, 1, 1, S<1, 32, 1, 8>, 8, BlkGemmPipeSched, BlockGemmPipelineVersion::v1>, + DeviceGemmMX_Xdl_CShuffleV3< Row, Row, Row, BF8, E8M0, F8, E8M0, F16, F32, F16, PassThrough, PassThrough, PassThrough, GemmSpec, ScaleBlockSize, 128, 16, 32, 512, 16, 8, 16, 16, 1, 1, S<8, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, 0, S<64, 2, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 16, 8, 0, 1, 1, S<1, 16, 1, 8>, 4, BlkGemmPipeSched, BlockGemmPipelineVersion::v1> +#endif + // clang-format on + >; + +} // namespace instance +} // namespace device +} // namespace tensor_operation +} // namespace ck diff --git a/library/src/tensor_operation_instance/gpu/gemm_mx/device_gemm_mx_xdl_bf8_f8_f16/device_gemm_mx_xdl_bf8_f8_f16_mk_kn_mn_default_instance.cpp b/library/src/tensor_operation_instance/gpu/gemm_mx/device_gemm_mx_xdl_bf8_f8_f16/device_gemm_mx_xdl_bf8_f8_f16_mk_kn_mn_default_instance.cpp new file mode 100644 index 0000000000..2b6ccdbeda --- /dev/null +++ b/library/src/tensor_operation_instance/gpu/gemm_mx/device_gemm_mx_xdl_bf8_f8_f16/device_gemm_mx_xdl_bf8_f8_f16_mk_kn_mn_default_instance.cpp @@ -0,0 +1,32 @@ +// SPDX-License-Identifier: MIT +// Copyright (c) 2025, Advanced Micro Devices, Inc. All rights reserved. 
+ +#include "device_gemm_mx_xdl_bf8_f8_f16_mk_kn_mn.hpp" + +namespace ck { +namespace tensor_operation { +namespace device { +namespace instance { + +void add_device_gemm_mx_xdl_bf8_f8_f16_mk_kn_mn_default_instances( + std::vector>>& instances) +{ + add_device_operation_instances( + instances, device_gemm_mx_xdl_bf8_f8_f16_mk_kn_mn_instances{}); +} + +} // namespace instance +} // namespace device +} // namespace tensor_operation +} // namespace ck diff --git a/library/src/tensor_operation_instance/gpu/gemm_mx/device_gemm_mx_xdl_f8_f8_bf16/device_gemm_mx_xdl_f8_f8_bf16_km_nk_mn.hpp b/library/src/tensor_operation_instance/gpu/gemm_mx/device_gemm_mx_xdl_f8_f8_bf16/device_gemm_mx_xdl_f8_f8_bf16_km_nk_mn.hpp new file mode 100644 index 0000000000..0df018bf1d --- /dev/null +++ b/library/src/tensor_operation_instance/gpu/gemm_mx/device_gemm_mx_xdl_f8_f8_bf16/device_gemm_mx_xdl_f8_f8_bf16_km_nk_mn.hpp @@ -0,0 +1,62 @@ +// SPDX-License-Identifier: MIT +// Copyright (c) 2025, Advanced Micro Devices, Inc. All rights reserved. 
+ +#include "ck/ck.hpp" +#include "ck/tensor_operation/gpu/device/tensor_layout.hpp" +#include "ck/tensor_operation/gpu/device/gemm_specialization.hpp" +#include "ck/tensor_operation/gpu/device/impl/device_gemm_xdl_cshuffle_v3_mx.hpp" + +#include "ck/library/tensor_operation_instance/add_device_operation_instance.hpp" + +namespace ck { +namespace tensor_operation { +namespace device { +namespace instance { + +using F8 = f8_t; +using F16 = half_t; +using BF16 = bhalf_t; +using F32 = float; +using E8M0 = ck::e8m0_bexp_t; + +using Row = tensor_layout::gemm::RowMajor; +using Col = tensor_layout::gemm::ColumnMajor; + +template +using S = Sequence; + +using PassThrough = element_wise::PassThrough; + +static constexpr auto GemmDefault = GemmSpecialization::Default; +static constexpr auto GemmKPadding = GemmSpecialization::KPadding; +static constexpr auto GemmMNPadding = GemmSpecialization::MNPadding; +static constexpr auto GemmMNKPadding = GemmSpecialization::MNKPadding; + +static constexpr auto Intrawave = BlockGemmPipelineScheduler::Intrawave; +static constexpr auto Interwave = BlockGemmPipelineScheduler::Interwave; + +static constexpr auto ScaleBlockSize = 32; + +template +using device_gemm_mx_xdl_f8_f8_bf16_km_nk_mn_instances = std::tuple< +// clang-format off + //#########################| ALayout| BLayout| CLayout|AData|AScale|BData|BScale| CData| AccData| Cshuffle| A| B| C| GEMM| Scale Block| Block| MPer| NPer| KPer| AK1| BK1|MPer| NPer| MXdl| NXdl| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockLds| BBlockTransfer| BBlockTransfer| BBlockTransfer| BlockTransfer| BBlockTransfer| BBlockTransfer| BBlockLds| CShuffle| CShuffle| CBlockTransferClusterLengths| CBlockTransfer| Block-wiseGemm| Block-wiseGemm| + //#########################| | | | Type| Data| Type| Data| Type| Type| Type| Elementwise| Elementwise| Elementwise|Specialization| Size| Size| Block| Block| Block| | | XDL| XDL| Per| Per| ThreadCluster| 
ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraM| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraN| MXdlPerWave| NXdlPerWave| _MBlock_MXdlPerWave_MWaveMPerXdl| ScalarPerVector| Pipeline| Pipeline| + //#########################| | | | | Type| | Type| | | | Operation| Operation| Operation| | | | | | | | | | | Wave| Wave| Lengths_K0_M_K1| ArrangeOrder| | | PerVector| PerVector_K1| | Lengths_K0_N_K1| ArrangeOrder| | | PerVector| PerVector_K1| | PerShuffle| PerShuffle| _NBlock_NXdlPerWave_NWaveNPerXdl| _NWaveNPerXdl| Scheduler| Verision| + //#########################| | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | +#if defined(__gfx950__) || defined(CK_USE_NATIVE_MX_SUPPORT) + DeviceGemmMX_Xdl_CShuffleV3< Col, Col, Row, F8, E8M0, F8, E8M0, BF16, F32, BF16, PassThrough, PassThrough, PassThrough, GemmSpec, ScaleBlockSize, 256, 128, 128, 128, 4, 16, 32, 32, 2, 2, S<32, 8, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 16, 4, 0, S<8, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, 0, 1, 1, S<1, 32, 1, 8>, 8, BlkGemmPipeSched, BlockGemmPipelineVersion::v1>, + DeviceGemmMX_Xdl_CShuffleV3< Col, Col, Row, F8, E8M0, F8, E8M0, BF16, F32, BF16, PassThrough, PassThrough, PassThrough, GemmSpec, ScaleBlockSize, 256, 16, 256, 128, 4, 16, 16, 16, 1, 4, S<32, 4, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 4, 4, 0, S<8, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, 0, 1, 1, S<1, 16, 1, 16>, 4, BlkGemmPipeSched, BlockGemmPipelineVersion::v1>, + DeviceGemmMX_Xdl_CShuffleV3< Col, Col, Row, F8, E8M0, F8, E8M0, BF16, F32, BF16, PassThrough, PassThrough, PassThrough, GemmSpec, ScaleBlockSize, 256, 128, 128, 64, 4, 16, 32, 32, 2, 2, S<16, 16, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 8, 4, 0, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, 0, 1, 1, S<1, 32, 1, 8>, 8, BlkGemmPipeSched, BlockGemmPipelineVersion::v1>, + DeviceGemmMX_Xdl_CShuffleV3< Col, Col, Row, F8, E8M0, F8, E8M0, BF16, F32, BF16, PassThrough, PassThrough, 
PassThrough, GemmSpec, ScaleBlockSize, 64, 16, 16, 512, 8, 16, 16, 16, 1, 1, S<64, 1, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 16, 8, 0, S<8, 8, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, 0, 1, 1, S<1, 16, 1, 4>, 4, BlkGemmPipeSched, BlockGemmPipelineVersion::v1>, + DeviceGemmMX_Xdl_CShuffleV3< Col, Col, Row, F8, E8M0, F8, E8M0, BF16, F32, BF16, PassThrough, PassThrough, PassThrough, GemmSpec, ScaleBlockSize, 256, 256, 256, 128, 8, 16, 16, 16, 8, 8, S<16, 16, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 16, 8, 0, S<8, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, 0, 1, 2, S<1, 32, 1, 8>, 8, BlkGemmPipeSched, BlockGemmPipelineVersion::v1>, + DeviceGemmMX_Xdl_CShuffleV3< Col, Col, Row, F8, E8M0, F8, E8M0, BF16, F32, BF16, PassThrough, PassThrough, PassThrough, GemmSpec, ScaleBlockSize, 256, 256, 256, 64, 4, 16, 32, 32, 4, 4, S<16, 16, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 16, 4, 0, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, 0, 1, 1, S<1, 32, 1, 8>, 8, BlkGemmPipeSched, BlockGemmPipelineVersion::v1>, + DeviceGemmMX_Xdl_CShuffleV3< Col, Col, Row, F8, E8M0, F8, E8M0, BF16, F32, BF16, PassThrough, PassThrough, PassThrough, GemmSpec, ScaleBlockSize, 128, 128, 128, 128, 4, 16, 16, 16, 4, 8, S<32, 4, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 16, 4, 0, S<8, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, 0, 1, 1, S<1, 16, 1, 8>, 2, BlkGemmPipeSched, BlockGemmPipelineVersion::v1> +#endif + // clang-format on + >; + +} // namespace instance +} // namespace device +} // namespace tensor_operation +} // namespace ck diff --git a/library/src/tensor_operation_instance/gpu/gemm_mx/device_gemm_mx_xdl_f8_f8_bf16/device_gemm_mx_xdl_f8_f8_bf16_km_nk_mn_default_instance.cpp b/library/src/tensor_operation_instance/gpu/gemm_mx/device_gemm_mx_xdl_f8_f8_bf16/device_gemm_mx_xdl_f8_f8_bf16_km_nk_mn_default_instance.cpp new file mode 100644 index 0000000000..c75e779fea --- /dev/null +++ b/library/src/tensor_operation_instance/gpu/gemm_mx/device_gemm_mx_xdl_f8_f8_bf16/device_gemm_mx_xdl_f8_f8_bf16_km_nk_mn_default_instance.cpp @@ -0,0 +1,32 
@@ +// SPDX-License-Identifier: MIT +// Copyright (c) 2025, Advanced Micro Devices, Inc. All rights reserved. + +#include "device_gemm_mx_xdl_f8_f8_bf16_km_nk_mn.hpp" + +namespace ck { +namespace tensor_operation { +namespace device { +namespace instance { + +void add_device_gemm_mx_xdl_f8_f8_bf16_km_nk_mn_default_instances( + std::vector>>& instances) +{ + add_device_operation_instances( + instances, device_gemm_mx_xdl_f8_f8_bf16_km_nk_mn_instances{}); +} + +} // namespace instance +} // namespace device +} // namespace tensor_operation +} // namespace ck diff --git a/test/CMakeLists.txt b/test/CMakeLists.txt index 72c51823be..6bde1140d9 100755 --- a/test/CMakeLists.txt +++ b/test/CMakeLists.txt @@ -36,6 +36,7 @@ set(REGRESSION_TESTS test_batchnorm_bwd_rank_4 test_grouped_convnd_bwd_data_xdl test_conv_tensor_rearrange + test_gemm_mx ) function(add_test_executable TEST_NAME) diff --git a/test/gemm_mx/test_gemm_mx.cpp b/test/gemm_mx/test_gemm_mx.cpp index 6e1957e60a..2c976a217f 100644 --- a/test/gemm_mx/test_gemm_mx.cpp +++ b/test/gemm_mx/test_gemm_mx.cpp @@ -39,17 +39,49 @@ class TestGemmMX_MK_NK { }; +template +class TestGemmMX_MK_KN + : public ck::test::TestGemmMX, Tuple>::type> +{ +}; + +template +class TestGemmMX_KM_NK + : public ck::test::TestGemmMX, Tuple>::type> +{ +}; + // clang-format off -using KernelTypes_MK_NK = ::testing::Types< +using KernelTypes_F8_MK_NK = ::testing::Types< #if defined(CK_ENABLE_FP8) // ADataType, BDataType, CDataType, ScaleBlockSize std::tuple< F8, F8, F16, ck::Number<32> >, std::tuple< F8, F8, BF16, ck::Number<32> > #endif >; + +using KernelTypes_BF8_F8_MK_KN = ::testing::Types< +#if defined(CK_ENABLE_FP8) + // ADataType, BDataType, CDataType, ScaleBlockSize + std::tuple< BF8, F8, F16, ck::Number<32> > +#endif + >; + +using KernelTypes_F8_KM_NK = ::testing::Types< +#if defined(CK_ENABLE_FP8) + // ADataType, BDataType, CDataType, ScaleBlockSize + std::tuple< F8, F8, BF16, ck::Number<32> > +#endif + >; // clang-format on 
-TYPED_TEST_SUITE(TestGemmMX_MK_NK, KernelTypes_MK_NK); +TYPED_TEST_SUITE(TestGemmMX_MK_NK, KernelTypes_F8_MK_NK); +TYPED_TEST_SUITE(TestGemmMX_MK_KN, KernelTypes_BF8_F8_MK_KN); +TYPED_TEST_SUITE(TestGemmMX_KM_NK, KernelTypes_F8_KM_NK); + +/// A: RowMajor +/// B: ColMajor +/// C: RowMajor TYPED_TEST(TestGemmMX_MK_NK, SmallM) { @@ -95,14 +127,151 @@ TYPED_TEST(TestGemmMX_MK_NK, Regular) TYPED_TEST(TestGemmMX_MK_NK, Large) { - std::vector Ms{4096}; - constexpr int N = 3840; - constexpr int K = 4096; + std::vector> test_sizes{{5120, 5120}, {3840, 5120}, {4096, 4096}}; + constexpr int K = 4096; constexpr int StrideA = K; constexpr int StrideB = K; + + for(auto test_size : test_sizes) + { + auto M = test_size.first; + auto N = test_size.second; + + const auto StrideC = N; + this->Run(M, N, K, StrideA, StrideB, StrideC); + } +} + +/// A: RowMajor +/// B: RowMajor +/// C: RowMajor + +TYPED_TEST(TestGemmMX_MK_KN, SmallM) +{ + std::vector Ms{1, 2, 3, 4, 5, 6}; + constexpr int N = 256; + constexpr int K = 512; + + constexpr int StrideA = K; + constexpr int StrideB = N; constexpr int StrideC = N; for(int M : Ms) this->Run(M, N, K, StrideA, StrideB, StrideC); } + +TYPED_TEST(TestGemmMX_MK_KN, MidLargeM) +{ + std::vector Ms{127, 255, 312, 799, 1573}; + constexpr int N = 256; + constexpr int K = 512; + + constexpr int StrideA = K; + constexpr int StrideB = N; + constexpr int StrideC = N; + + for(int M : Ms) + this->Run(M, N, K, StrideA, StrideB, StrideC); +} + +TYPED_TEST(TestGemmMX_MK_KN, Regular) +{ + std::vector Ms{3840}; + constexpr int N = 512; + constexpr int K = 1024; + + constexpr int StrideA = K; + constexpr int StrideB = N; + constexpr int StrideC = N; + + for(int M : Ms) + this->Run(M, N, K, StrideA, StrideB, StrideC); +} + +TYPED_TEST(TestGemmMX_MK_KN, Large) +{ + std::vector> test_sizes{{5120, 5120}, {3840, 5120}, {4096, 4096}}; + + constexpr int K = 4096; + constexpr int StrideA = K; + + for(auto test_size : test_sizes) + { + auto M = test_size.first; + auto N = 
test_size.second; + + const auto StrideB = N; + const auto StrideC = N; + this->Run(M, N, K, StrideA, StrideB, StrideC); + } +} + +/// A: ColMajor +/// B: ColMajor +/// C: RowMajor + +TYPED_TEST(TestGemmMX_KM_NK, SmallN) +{ + constexpr int M = 256; + std::vector Ns{1, 2, 3, 4, 5, 6}; + constexpr int K = 512; + + constexpr int StrideA = M; + constexpr int StrideB = K; + + for(int N : Ns) + { + const auto new_N = N * 8; + const auto StrideC = new_N; + this->Run(M, new_N, K, StrideA, StrideB, StrideC); + } +} + +TYPED_TEST(TestGemmMX_KM_NK, MidLargeN) +{ + constexpr int M = 256; + std::vector Ns{127, 255, 312, 799, 1573}; + constexpr int K = 512; + + constexpr int StrideA = M; + constexpr int StrideB = K; + + for(int N : Ns) + { + const auto new_N = (N + 7) / 8 * 8; + const auto StrideC = new_N; + this->Run(M, new_N, K, StrideA, StrideB, StrideC); + } +} + +TYPED_TEST(TestGemmMX_KM_NK, Regular) +{ + std::vector Ms{3840}; + constexpr int N = 512; + constexpr int K = 1024; + + constexpr int StrideB = K; + constexpr int StrideC = N; + + for(int M : Ms) + this->Run(M, N, K, M, StrideB, StrideC); +} + +TYPED_TEST(TestGemmMX_KM_NK, Large) +{ + std::vector> test_sizes{{5120, 5120}, {3840, 5120}, {4096, 4096}}; + + constexpr int K = 4096; + constexpr int StrideB = K; + + for(auto test_size : test_sizes) + { + auto M = test_size.first; + auto N = test_size.second; + + const auto StrideA = M; + const auto StrideC = N; + this->Run(M, N, K, StrideA, StrideB, StrideC); + } +} diff --git a/test/gemm_mx/test_gemm_mx_util.hpp b/test/gemm_mx/test_gemm_mx_util.hpp index 3bca4ceded..02833daeb4 100644 --- a/test/gemm_mx/test_gemm_mx_util.hpp +++ b/test/gemm_mx/test_gemm_mx_util.hpp @@ -150,7 +150,7 @@ bool profile_gemm_mx_impl(int do_verification, break; default: - a_m_k.GenerateTensorValue(GeneratorTensor_3{-2.0, 2.0}); + a_m_k.GenerateTensorValue(GeneratorTensor_3{-2.0, 2.0}); a_m_k_scale.GenerateTensorValue( GeneratorTensor_3{powf(2.0f, -125.0f), 1.0f}); // R[2^-125, 1] From 
619fba3134641e4a08950a3bea385c16dbb74b64 Mon Sep 17 00:00:00 2001 From: Illia Silin <98187287+illsilin@users.noreply.github.com> Date: Thu, 1 May 2025 12:37:27 -0700 Subject: [PATCH 090/443] re-enable ck4inductor tests by default (#2155) --- Jenkinsfile | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/Jenkinsfile b/Jenkinsfile index 68999d8aa6..a9d30d9f71 100644 --- a/Jenkinsfile +++ b/Jenkinsfile @@ -845,8 +845,8 @@ pipeline { description: "Try building CK with legacy OS dockers: RHEL8 and SLES15 (default: OFF)") booleanParam( name: "RUN_INDUCTOR_TESTS", - defaultValue: false, - description: "Run inductor codegen tests (default: OFF)") + defaultValue: true, + description: "Run inductor codegen tests (default: ON)") } environment{ dbuser = "${dbuser}" From d58f2b8bd0c2adad65a731403673d545d8483acb Mon Sep 17 00:00:00 2001 From: Khushbu Agarwal Date: Thu, 1 May 2025 13:36:24 -0700 Subject: [PATCH 091/443] mfma_32x32x64_fp8/bf8 (#2148) * support for mfma_32x32x64_fp8 * clang-formatted * Fixing sparsity in codegen --- include/ck_tile/ops/gemm/warp/warp_gemm.hpp | 12 +++ .../warp/warp_gemm_attribute_mfma_impl.hpp | 98 +++++++++++++++++++ .../ops/gemm/warp/warp_gemm_dispatcher.hpp | 5 + tile_engine/ops/gemm/gemm_instance_builder.py | 54 +++++----- 4 files changed, 147 insertions(+), 22 deletions(-) diff --git a/include/ck_tile/ops/gemm/warp/warp_gemm.hpp b/include/ck_tile/ops/gemm/warp/warp_gemm.hpp index 22962b9404..e75aca1d91 100644 --- a/include/ck_tile/ops/gemm/warp/warp_gemm.hpp +++ b/include/ck_tile/ops/gemm/warp/warp_gemm.hpp @@ -228,6 +228,18 @@ using WarpGemmMfma_f32_16x16x128_bf8_fp8 = WarpGemmImpl>>; +using WarpGemmMfma_f32_32x32x64_fp8_fp8 = WarpGemmImpl< + WarpGemmAtrributeMfma>>; + +using WarpGemmMfma_f32_32x32x64_fp8_bf8 = WarpGemmImpl< + WarpGemmAtrributeMfma>>; + +using WarpGemmMfma_f32_32x32x64_bf8_fp8 = WarpGemmImpl< + WarpGemmAtrributeMfma>>; + +using WarpGemmMfma_f32_32x32x64_bf8_bf8 = WarpGemmImpl< + WarpGemmAtrributeMfma>>; + 
using WarpGemmMfma_f32_32x32x16_fp8_fp8_CTransposed = WarpGemmImpl>>; diff --git a/include/ck_tile/ops/gemm/warp/warp_gemm_attribute_mfma_impl.hpp b/include/ck_tile/ops/gemm/warp/warp_gemm_attribute_mfma_impl.hpp index cd32f35180..96c3c3d29f 100644 --- a/include/ck_tile/ops/gemm/warp/warp_gemm_attribute_mfma_impl.hpp +++ b/include/ck_tile/ops/gemm/warp/warp_gemm_attribute_mfma_impl.hpp @@ -1440,6 +1440,104 @@ template using WarpGemmAttributeMfmaImpl_f32_16x16x128_bf8_bf8 = WarpGemmAttributeMfmaImpl_f32_16x16x128_f8_bf8_base; +template +struct WarpGemmAttributeMfmaImpl_f32_32x32x64_f8_bf8_base +{ + static constexpr WGAttrCtlEnum Ctrl = Ctrl_; + using ADataType = AType_; + using BDataType = BType_; + using CDataType = float; + + using AVecType = ext_vector_t; + using BVecType = ext_vector_t; + using CVecType = ext_vector_t; + + static constexpr index_t kM = 32; + static constexpr index_t kN = 32; + static constexpr index_t kK = 64; + + static constexpr index_t kAMBlock = 1; + static constexpr index_t kBNBlock = 1; + + static constexpr index_t kAMLane = 32; + static constexpr index_t kBNLane = 32; + static constexpr index_t kABKLane = 2; + static constexpr index_t kABKPerLane = 32; + + static constexpr index_t kCMLane = 2; + static constexpr index_t kCNLane = 32; + static constexpr index_t kCM0PerLane = 4; + static constexpr index_t kCM1PerLane = 4; + + // c_vec += a_vec * b_vec + template + CK_TILE_DEVICE void operator()(CVecType& c_vec, + const AVecType& a_vec, + const BVecType& b_vec, + bool_constant = {}) const + { + //__builtin_amdgcn_mfma_scale_f32_32x32x64_f8f6f4(a, b, c, cbsz, blgp, opsel, scale_a, + // opsel, scale_b) +#if defined(__gfx950__) + if constexpr(std::is_same_v && std::is_same_v) + c_vec = __builtin_amdgcn_mfma_scale_f32_32x32x64_f8f6f4( + a_vec, b_vec, c_vec, 0, 0, 0, 0, 0, 0); + else if constexpr(std::is_same_v && std::is_same_v) + c_vec = __builtin_amdgcn_mfma_scale_f32_32x32x64_f8f6f4( + a_vec, b_vec, c_vec, 0, 1, 0, 0, 0, 0); + else if 
constexpr(std::is_same_v && std::is_same_v) + c_vec = __builtin_amdgcn_mfma_scale_f32_32x32x64_f8f6f4( + a_vec, b_vec, c_vec, 1, 0, 0, 0, 0, 0); + else if constexpr(std::is_same_v && std::is_same_v) + c_vec = __builtin_amdgcn_mfma_scale_f32_32x32x64_f8f6f4( + a_vec, b_vec, c_vec, 1, 1, 0, 0, 0, 0); +#else + ck_tile::ignore = c_vec; + ck_tile::ignore = a_vec; + ck_tile::ignore = b_vec; +#endif + } + + // c_vec = a_vec * b_vec + CK_TILE_DEVICE CVecType operator()(const AVecType& a_vec, const BVecType& b_vec) const + { +#if defined(__gfx950__) + if constexpr(std::is_same_v && std::is_same_v) + return bit_cast(__builtin_amdgcn_mfma_scale_f32_32x32x64_f8f6f4( + a_vec, b_vec, CVecType{0.f}, 0, 0, 0, 0, 0, 0)); + else if constexpr(std::is_same_v && std::is_same_v) + return bit_cast(__builtin_amdgcn_mfma_scale_f32_32x32x64_f8f6f4( + a_vec, b_vec, CVecType{0.f}, 0, 1, 0, 0, 0, 0)); + else if constexpr(std::is_same_v && std::is_same_v) + return bit_cast(__builtin_amdgcn_mfma_scale_f32_32x32x64_f8f6f4( + a_vec, b_vec, CVecType{0.f}, 1, 0, 0, 0, 0, 0)); + else if constexpr(std::is_same_v && std::is_same_v) + return bit_cast(__builtin_amdgcn_mfma_scale_f32_32x32x64_f8f6f4( + a_vec, b_vec, CVecType{0.f}, 1, 1, 0, 0, 0, 0)); +#else + ck_tile::ignore = a_vec; + ck_tile::ignore = b_vec; + return CVecType{0.f}; +#endif + } +}; + +template +using WarpGemmAttributeMfmaImpl_f32_32x32x64_fp8_fp8 = + WarpGemmAttributeMfmaImpl_f32_32x32x64_f8_bf8_base; + +template +using WarpGemmAttributeMfmaImpl_f32_32x32x64_fp8_bf8 = + WarpGemmAttributeMfmaImpl_f32_32x32x64_f8_bf8_base; + +template +using WarpGemmAttributeMfmaImpl_f32_32x32x64_bf8_fp8 = + WarpGemmAttributeMfmaImpl_f32_32x32x64_f8_bf8_base; + +template +using WarpGemmAttributeMfmaImpl_f32_32x32x64_bf8_bf8 = + WarpGemmAttributeMfmaImpl_f32_32x32x64_f8_bf8_base; + // int8 template struct WarpGemmAttributeMfmaImpl_i32_32x32x16_i8 diff --git a/include/ck_tile/ops/gemm/warp/warp_gemm_dispatcher.hpp 
b/include/ck_tile/ops/gemm/warp/warp_gemm_dispatcher.hpp index 0e3342c479..64bd61a3dc 100644 --- a/include/ck_tile/ops/gemm/warp/warp_gemm_dispatcher.hpp +++ b/include/ck_tile/ops/gemm/warp/warp_gemm_dispatcher.hpp @@ -74,6 +74,11 @@ template<> struct WarpGemmMfmaDispatcher struct WarpGemmMfmaDispatcher { using Type = WarpGemmMfma_f32_16x16x128_bf8_fp8; }; template<> struct WarpGemmMfmaDispatcher { using Type = WarpGemmMfma_f32_16x16x128_bf8_bf8; }; +template<> struct WarpGemmMfmaDispatcher { using Type = WarpGemmMfma_f32_32x32x64_fp8_fp8; }; +template<> struct WarpGemmMfmaDispatcher { using Type = WarpGemmMfma_f32_32x32x64_fp8_bf8; }; +template<> struct WarpGemmMfmaDispatcher { using Type = WarpGemmMfma_f32_32x32x64_bf8_fp8; }; +template<> struct WarpGemmMfmaDispatcher { using Type = WarpGemmMfma_f32_32x32x64_bf8_bf8; }; + // clang-format on } // namespace impl diff --git a/tile_engine/ops/gemm/gemm_instance_builder.py b/tile_engine/ops/gemm/gemm_instance_builder.py index b441bdd2d6..a748c35feb 100755 --- a/tile_engine/ops/gemm/gemm_instance_builder.py +++ b/tile_engine/ops/gemm/gemm_instance_builder.py @@ -282,14 +282,14 @@ class GemmCodeGenerator: def _generate_common_header(self): """Generate common header with datatypes and layout""" - ctype = self.config.datatype - atype = self.config.datatype - btype = self.config.datatype + self.ctype = self.config.datatype + self.atype = self.config.datatype + self.btype = self.config.datatype if self.config.datatype in ['fp8', 'bf8']: - ctype = 'fp16' + self.ctype = 'fp16' elif self.config.datatype in ['int4']: - atype = 'fp16' - ctype = 'fp16' + self.atype = 'fp16' + self.ctype = 'fp16' content = f"""// SPDX-License-Identifier: MIT // Copyright (c) 2025, Advanced Micro Devices, Inc. All rights reserved. 
@@ -298,10 +298,10 @@ class GemmCodeGenerator: #include "ck_tile/core.hpp" // Data types -using ADataType = {DATA_TYPE_MAP[atype]}; -using BDataType = {DATA_TYPE_MAP[btype]}; +using ADataType = {DATA_TYPE_MAP[self.atype]}; +using BDataType = {DATA_TYPE_MAP[self.btype]}; using AccDataType = float; -using CDataType = {DATA_TYPE_MAP[ctype]}; +using CDataType = {DATA_TYPE_MAP[self.ctype]}; // Layout configurations using ALayout = {LAYOUT_MAP[self.config.layouts[0]]}; @@ -499,7 +499,7 @@ struct GemmDispatcher { static void init(bool structured_sparsity) { auto& kernel_map = get_kernel_map(); - if(!kernel_map.empty()) return; + if(!kernel_map.empty()) return; \n""" # Add tile/warp instantiations tile_params = set(itertools.product( @@ -516,12 +516,25 @@ struct GemmDispatcher { for group in self.all_kernels: - content += f""" kernel_map["{group}"] = [=](ck_tile::DeviceMem& c_m_n_dev_buf, - ck_tile::HostTensor& c_m_n_host_result, - ck_tile::HostTensor& c_m_n_dev_result, - int verify, ck_tile::GemmHostArgs& args, - const ck_tile::stream_config& stream) {{ - """ + content += f""" kernel_map["{group}"] = [=](ck_tile::DeviceMem& c_m_n_dev_buf, + ck_tile::HostTensor& c_m_n_host_result, + ck_tile::HostTensor& c_m_n_dev_result, + int verify, ck_tile::GemmHostArgs& args, + const ck_tile::stream_config& stream) {{ + if(structured_sparsity){{ // SMFMA""" + for tile in tile_params: + # Check if we have valid tile/warp combinations + # (tile_m/(warp_m*warp_tile_m)) * warp_m * warp_tile_m == tile_m + if ((tile[0]/(tile[3] * tile[7]) * tile[3] * tile[7]) != tile[0]) or \ + ((tile[1]/(tile[4] * tile[8]) * tile[4] * tile[8]) != tile[1]): + continue + sparse = self.atype == 'fp16' and \ + ((tile[6] == 32 and tile[7] == 32 and tile[8] == 16) or + (tile[6] == 16 and tile[7] == 16 and tile[8] == 32)) + content += f""" + run_kernel<{group}::GemmKernel<{tile[0]}, {tile[1]}, {tile[2]}, {tile[3]}, {tile[4]}, {tile[5]}, {tile[6]}, {tile[7]}, {tile[8]}, {BOOL_MAP(sparse)}>>(c_m_n_dev_buf, 
c_m_n_host_result, c_m_n_dev_result, verify, args, stream);""" + content += f""" + }} else {{""" for tile in tile_params: # Check if we have valid tile/warp combinations # (tile_m/(warp_m*warp_tile_m)) * warp_m * warp_tile_m == tile_m @@ -529,13 +542,10 @@ struct GemmDispatcher { ((tile[1]/(tile[4] * tile[8]) * tile[4] * tile[8]) != tile[1]): continue content += f""" - if(structured_sparsity) {{ - run_kernel<{group}::GemmKernel<{tile[0]}, {tile[1]}, {tile[2]}, {tile[3]}, {tile[4]}, {tile[5]}, {tile[6]}, {tile[7]}, {tile[8]}, {1}>>(c_m_n_dev_buf, c_m_n_host_result, c_m_n_dev_result, verify, args, stream); - }} else {{ - run_kernel<{group}::GemmKernel<{tile[0]}, {tile[1]}, {tile[2]}, {tile[3]}, {tile[4]}, {tile[5]}, {tile[6]}, {tile[7]}, {tile[8]}, {0}>>(c_m_n_dev_buf, c_m_n_host_result, c_m_n_dev_result, verify, args, stream); - }}""" + run_kernel<{group}::GemmKernel<{tile[0]}, {tile[1]}, {tile[2]}, {tile[3]}, {tile[4]}, {tile[5]}, {tile[6]}, {tile[7]}, {tile[8]}, {BOOL_MAP(False)}>>(c_m_n_dev_buf, c_m_n_host_result, c_m_n_dev_result, verify, args, stream);""" content += f""" - }};\n""" + }} + }};\n""" content += """ } From c4e4e592c13168a9cf053039a447b31714b92c55 Mon Sep 17 00:00:00 2001 From: "dependabot[bot]" <49699333+dependabot[bot]@users.noreply.github.com> Date: Mon, 5 May 2025 07:29:07 -0700 Subject: [PATCH 092/443] Bump rocm-docs-core[api_reference] from 1.18.2 to 1.18.4 in /docs/sphinx (#2161) Bumps [rocm-docs-core[api_reference]](https://github.com/ROCm/rocm-docs-core) from 1.18.2 to 1.18.4. - [Release notes](https://github.com/ROCm/rocm-docs-core/releases) - [Changelog](https://github.com/ROCm/rocm-docs-core/blob/develop/CHANGELOG.md) - [Commits](https://github.com/ROCm/rocm-docs-core/compare/v1.18.2...v1.18.4) --- updated-dependencies: - dependency-name: rocm-docs-core[api_reference] dependency-version: 1.18.4 dependency-type: direct:production update-type: version-update:semver-patch ... 
Signed-off-by: dependabot[bot] Co-authored-by: dependabot[bot] <49699333+dependabot[bot]@users.noreply.github.com> --- docs/sphinx/requirements.in | 2 +- docs/sphinx/requirements.txt | 2 +- 2 files changed, 2 insertions(+), 2 deletions(-) diff --git a/docs/sphinx/requirements.in b/docs/sphinx/requirements.in index ac03e40939..6c48b2de09 100644 --- a/docs/sphinx/requirements.in +++ b/docs/sphinx/requirements.in @@ -1,2 +1,2 @@ -rocm-docs-core[api_reference]==1.18.2 +rocm-docs-core[api_reference]==1.18.4 sphinxcontrib-bibtex==2.6.3 diff --git a/docs/sphinx/requirements.txt b/docs/sphinx/requirements.txt index 3742eeebba..62c3ea8ff8 100644 --- a/docs/sphinx/requirements.txt +++ b/docs/sphinx/requirements.txt @@ -237,7 +237,7 @@ requests==2.32.3 # via # pygithub # sphinx -rocm-docs-core[api-reference]==1.18.2 +rocm-docs-core[api-reference]==1.18.4 # via -r requirements.in rpds-py==0.24.0 # via From 79beaacdd17928d77d6622498b734cd2d3d3c6d6 Mon Sep 17 00:00:00 2001 From: Andriy Roshchenko <107577548+andriy-ca@users.noreply.github.com> Date: Mon, 5 May 2025 09:18:22 -0600 Subject: [PATCH 093/443] Restrict MX GEMM instantiation to GFX950 arch (#2157) --- .../device_gemm_mx_xdl_bf8_f8_f16_mk_kn_mn.hpp | 2 +- .../device_gemm_mx_xdl_f8_f8_bf16_km_nk_mn.hpp | 2 +- .../device_gemm_mx_xdl_f8_f8_bf16_mk_nk_mn.hpp | 2 +- .../device_gemm_mx_xdl_f8_f8_f16_mk_nk_mn.hpp | 2 +- 4 files changed, 4 insertions(+), 4 deletions(-) diff --git a/library/src/tensor_operation_instance/gpu/gemm_mx/device_gemm_mx_xdl_bf8_f8_f16/device_gemm_mx_xdl_bf8_f8_f16_mk_kn_mn.hpp b/library/src/tensor_operation_instance/gpu/gemm_mx/device_gemm_mx_xdl_bf8_f8_f16/device_gemm_mx_xdl_bf8_f8_f16_mk_kn_mn.hpp index 25dd68a207..3713ebae0e 100644 --- a/library/src/tensor_operation_instance/gpu/gemm_mx/device_gemm_mx_xdl_bf8_f8_f16/device_gemm_mx_xdl_bf8_f8_f16_mk_kn_mn.hpp +++ b/library/src/tensor_operation_instance/gpu/gemm_mx/device_gemm_mx_xdl_bf8_f8_f16/device_gemm_mx_xdl_bf8_f8_f16_mk_kn_mn.hpp @@ -45,7 +45,7 
@@ using device_gemm_mx_xdl_bf8_f8_f16_mk_kn_mn_instances = std::tuple< //#########################| | | | Type| Data| Type| Data| Type| Type| Type| Elementwise| Elementwise| Elementwise|Specialization| Size| Size| Block| Block| Block| | | XDL| XDL| Per| Per| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraM| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraN| MXdlPerWave| NXdlPerWave| _MBlock_MXdlPerWave_MWaveMPerXdl| ScalarPerVector| Pipeline| Pipeline| //#########################| | | | | Type| | Type| | | | Operation| Operation| Operation| | | | | | | | | | | Wave| Wave| Lengths_K0_M_K1| ArrangeOrder| | | PerVector| PerVector_K1| | Lengths_K0_N_K1| ArrangeOrder| | | PerVector| PerVector_K1| | PerShuffle| PerShuffle| _NBlock_NXdlPerWave_NWaveNPerXdl| _NWaveNPerXdl| Scheduler| Verision| //#########################| | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | -#if defined(__gfx950__) || defined(CK_USE_NATIVE_MX_SUPPORT) +#if defined(__gfx950__) DeviceGemmMX_Xdl_CShuffleV3< Row, Row, Row, BF8, E8M0, F8, E8M0, F16, F32, F16, PassThrough, PassThrough, PassThrough, GemmSpec, ScaleBlockSize, 128, 64, 16, 128, 16, 4, 16, 16, 2, 1, S<8, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, 0, S<32, 4, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 4, 4, 0, 1, 1, S<1, 16, 1, 8>, 2, BlkGemmPipeSched, BlockGemmPipelineVersion::v1>, DeviceGemmMX_Xdl_CShuffleV3< Row, Row, Row, BF8, E8M0, F8, E8M0, F16, F32, F16, PassThrough, PassThrough, PassThrough, GemmSpec, ScaleBlockSize, 256, 256, 256, 256, 16, 4, 32, 32, 4, 4, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, false, S<16, 16, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 16, 4, false, 1, 1, S<1, 32, 1, 8>, 8, BlkGemmPipeSched, BlockGemmPipelineVersion::v1>, DeviceGemmMX_Xdl_CShuffleV3< Row, Row, Row, BF8, E8M0, F8, E8M0, F16, F32, F16, PassThrough, PassThrough, PassThrough, GemmSpec, ScaleBlockSize, 256, 64, 64, 256, 16, 4, 32, 32, 1, 
1, S<8, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, 0, S<32, 8, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 8, 4, 0, 1, 1, S<1, 32, 1, 8>, 8, BlkGemmPipeSched, BlockGemmPipelineVersion::v1>, diff --git a/library/src/tensor_operation_instance/gpu/gemm_mx/device_gemm_mx_xdl_f8_f8_bf16/device_gemm_mx_xdl_f8_f8_bf16_km_nk_mn.hpp b/library/src/tensor_operation_instance/gpu/gemm_mx/device_gemm_mx_xdl_f8_f8_bf16/device_gemm_mx_xdl_f8_f8_bf16_km_nk_mn.hpp index 0df018bf1d..5b0c5137b3 100644 --- a/library/src/tensor_operation_instance/gpu/gemm_mx/device_gemm_mx_xdl_f8_f8_bf16/device_gemm_mx_xdl_f8_f8_bf16_km_nk_mn.hpp +++ b/library/src/tensor_operation_instance/gpu/gemm_mx/device_gemm_mx_xdl_f8_f8_bf16/device_gemm_mx_xdl_f8_f8_bf16_km_nk_mn.hpp @@ -44,7 +44,7 @@ using device_gemm_mx_xdl_f8_f8_bf16_km_nk_mn_instances = std::tuple< //#########################| | | | Type| Data| Type| Data| Type| Type| Type| Elementwise| Elementwise| Elementwise|Specialization| Size| Size| Block| Block| Block| | | XDL| XDL| Per| Per| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraM| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraN| MXdlPerWave| NXdlPerWave| _MBlock_MXdlPerWave_MWaveMPerXdl| ScalarPerVector| Pipeline| Pipeline| //#########################| | | | | Type| | Type| | | | Operation| Operation| Operation| | | | | | | | | | | Wave| Wave| Lengths_K0_M_K1| ArrangeOrder| | | PerVector| PerVector_K1| | Lengths_K0_N_K1| ArrangeOrder| | | PerVector| PerVector_K1| | PerShuffle| PerShuffle| _NBlock_NXdlPerWave_NWaveNPerXdl| _NWaveNPerXdl| Scheduler| Verision| //#########################| | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | -#if defined(__gfx950__) || defined(CK_USE_NATIVE_MX_SUPPORT) +#if defined(__gfx950__) DeviceGemmMX_Xdl_CShuffleV3< Col, Col, Row, F8, E8M0, F8, E8M0, BF16, F32, BF16, PassThrough, PassThrough, PassThrough, GemmSpec, ScaleBlockSize, 256, 128, 128, 
128, 4, 16, 32, 32, 2, 2, S<32, 8, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 16, 4, 0, S<8, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, 0, 1, 1, S<1, 32, 1, 8>, 8, BlkGemmPipeSched, BlockGemmPipelineVersion::v1>, DeviceGemmMX_Xdl_CShuffleV3< Col, Col, Row, F8, E8M0, F8, E8M0, BF16, F32, BF16, PassThrough, PassThrough, PassThrough, GemmSpec, ScaleBlockSize, 256, 16, 256, 128, 4, 16, 16, 16, 1, 4, S<32, 4, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 4, 4, 0, S<8, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, 0, 1, 1, S<1, 16, 1, 16>, 4, BlkGemmPipeSched, BlockGemmPipelineVersion::v1>, DeviceGemmMX_Xdl_CShuffleV3< Col, Col, Row, F8, E8M0, F8, E8M0, BF16, F32, BF16, PassThrough, PassThrough, PassThrough, GemmSpec, ScaleBlockSize, 256, 128, 128, 64, 4, 16, 32, 32, 2, 2, S<16, 16, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 8, 4, 0, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, 0, 1, 1, S<1, 32, 1, 8>, 8, BlkGemmPipeSched, BlockGemmPipelineVersion::v1>, diff --git a/library/src/tensor_operation_instance/gpu/gemm_mx/device_gemm_mx_xdl_f8_f8_bf16/device_gemm_mx_xdl_f8_f8_bf16_mk_nk_mn.hpp b/library/src/tensor_operation_instance/gpu/gemm_mx/device_gemm_mx_xdl_f8_f8_bf16/device_gemm_mx_xdl_f8_f8_bf16_mk_nk_mn.hpp index 1e979f69ca..8e25bcc25f 100644 --- a/library/src/tensor_operation_instance/gpu/gemm_mx/device_gemm_mx_xdl_f8_f8_bf16/device_gemm_mx_xdl_f8_f8_bf16_mk_nk_mn.hpp +++ b/library/src/tensor_operation_instance/gpu/gemm_mx/device_gemm_mx_xdl_f8_f8_bf16/device_gemm_mx_xdl_f8_f8_bf16_mk_nk_mn.hpp @@ -44,7 +44,7 @@ using device_gemm_mx_xdl_f8_f8_bf16_mk_nk_mn_instances = std::tuple< //#########################| | | | Type| Data| Type| Data| Type| Type| Type| Elementwise| Elementwise| Elementwise|Specialization| Size| Size| Block| Block| Block| | | XDL| XDL| Per| Per| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraM| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraN| MXdlPerWave| NXdlPerWave| _MBlock_MXdlPerWave_MWaveMPerXdl| 
ScalarPerVector| Pipeline| Pipeline| //#########################| | | | | Type| | Type| | | | Operation| Operation| Operation| | | | | | | | | | | Wave| Wave| Lengths_K0_M_K1| ArrangeOrder| | | PerVector| PerVector_K1| | Lengths_K0_N_K1| ArrangeOrder| | | PerVector| PerVector_K1| | PerShuffle| PerShuffle| _NBlock_NXdlPerWave_NWaveNPerXdl| _NWaveNPerXdl| Scheduler| Verision| //#########################| | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | -#if defined(__gfx950__) || defined(CK_USE_NATIVE_MX_SUPPORT) +#if defined(__gfx950__) DeviceGemmMX_Xdl_CShuffleV3< Row, Col, Row, F8, E8M0, F8, E8M0, BF16, F32, BF16, PassThrough, PassThrough, PassThrough, GemmSpec, ScaleBlockSize, 128, 128, 16, 128, 16, 16, 16, 16, 4, 1, S<8, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, 0, S<8, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, 0, 1, 1, S<1, 16, 1, 8>, 2, BlkGemmPipeSched, BlockGemmPipelineVersion::v1>, DeviceGemmMX_Xdl_CShuffleV3< Row, Col, Row, F8, E8M0, F8, E8M0, BF16, F32, BF16, PassThrough, PassThrough, PassThrough, GemmSpec, ScaleBlockSize, 256, 128, 128, 256, 16, 16, 32, 32, 2, 2, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, false, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, false, 1, 1, S<1, 32, 1, 8>, 8, BlkGemmPipeSched, BlockGemmPipelineVersion::v1>, DeviceGemmMX_Xdl_CShuffleV3< Row, Col, Row, F8, E8M0, F8, E8M0, BF16, F32, BF16, PassThrough, PassThrough, PassThrough, GemmSpec, ScaleBlockSize, 256, 256, 128, 64, 16, 16, 32, 32, 4, 2, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, 0, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, 0, 1, 1, S<1, 32, 1, 8>, 8, BlkGemmPipeSched, BlockGemmPipelineVersion::v1>, diff --git a/library/src/tensor_operation_instance/gpu/gemm_mx/device_gemm_mx_xdl_f8_f8_f16/device_gemm_mx_xdl_f8_f8_f16_mk_nk_mn.hpp b/library/src/tensor_operation_instance/gpu/gemm_mx/device_gemm_mx_xdl_f8_f8_f16/device_gemm_mx_xdl_f8_f8_f16_mk_nk_mn.hpp index 0ca4f2a3ce..5fefb57257 100644 --- 
a/library/src/tensor_operation_instance/gpu/gemm_mx/device_gemm_mx_xdl_f8_f8_f16/device_gemm_mx_xdl_f8_f8_f16_mk_nk_mn.hpp +++ b/library/src/tensor_operation_instance/gpu/gemm_mx/device_gemm_mx_xdl_f8_f8_f16/device_gemm_mx_xdl_f8_f8_f16_mk_nk_mn.hpp @@ -44,7 +44,7 @@ using device_gemm_mx_xdl_f8_f8_f16_mk_nk_mn_instances = std::tuple< //#########################| | | | Type| Data| Type| Data| Type| Type| Type| Elementwise| Elementwise| Elementwise|Specialization| Size| Size| Block| Block| Block| | | XDL| XDL| Per| Per| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraM| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraN| MXdlPerWave| NXdlPerWave| _MBlock_MXdlPerWave_MWaveMPerXdl| ScalarPerVector| Pipeline| Pipeline| //#########################| | | | | Type| | Type| | | | Operation| Operation| Operation| | | | | | | | | | | Wave| Wave| Lengths_K0_M_K1| ArrangeOrder| | | PerVector| PerVector_K1| | Lengths_K0_N_K1| ArrangeOrder| | | PerVector| PerVector_K1| | PerShuffle| PerShuffle| _NBlock_NXdlPerWave_NWaveNPerXdl| _NWaveNPerXdl| Scheduler| Verision| //#########################| | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | -#if defined(__gfx950__) || defined(CK_USE_NATIVE_MX_SUPPORT) +#if defined(__gfx950__) DeviceGemmMX_Xdl_CShuffleV3< Row, Col, Row, F8, E8M0, F8, E8M0, F16, F32, F16, PassThrough, PassThrough, PassThrough, GemmSpec, ScaleBlockSize, 128, 128, 16, 128, 16, 16, 16, 16, 4, 1, S<8, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, 0, S<8, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, 0, 1, 1, S<1, 16, 1, 8>, 2, BlkGemmPipeSched, BlockGemmPipelineVersion::v1>, DeviceGemmMX_Xdl_CShuffleV3< Row, Col, Row, F8, E8M0, F8, E8M0, F16, F32, F16, PassThrough, PassThrough, PassThrough, GemmSpec, ScaleBlockSize, 256, 128, 128, 256, 16, 16, 32, 32, 2, 2, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, false, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 
16, false, 1, 1, S<1, 32, 1, 8>, 8, BlkGemmPipeSched, BlockGemmPipelineVersion::v1>, DeviceGemmMX_Xdl_CShuffleV3< Row, Col, Row, F8, E8M0, F8, E8M0, F16, F32, F16, PassThrough, PassThrough, PassThrough, GemmSpec, ScaleBlockSize, 256, 256, 128, 64, 16, 16, 32, 32, 4, 2, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, 0, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, 0, 1, 1, S<1, 32, 1, 8>, 8, BlkGemmPipeSched, BlockGemmPipelineVersion::v1>, From 0bcb804ad079f8b427786cc701675b3c535a180b Mon Sep 17 00:00:00 2001 From: jakpiase Date: Mon, 5 May 2025 18:46:44 +0200 Subject: [PATCH 094/443] [CK_TILE] Remove scratch usage from universal gemm (#2001) * moves kbatch condition outside of kernel * add reviewer comments * fixes * fix tests * fixes after review --------- Co-authored-by: Adam Osewski <19374865+aosewski@users.noreply.github.com> --- example/ck_tile/03_gemm/gemm_basic.cpp | 91 ++++++---- example/ck_tile/03_gemm/universal_gemm.cpp | 88 ++++++--- .../ck_tile/16_batched_gemm/batched_gemm.cpp | 171 ++++++++++-------- .../ck_tile/17_grouped_gemm/grouped_gemm.cpp | 171 ++++++++++-------- .../ops/epilogue/cshuffle_epilogue.hpp | 63 +++---- .../ops/gemm/kernel/batched_gemm_kernel.hpp | 10 +- .../ck_tile/ops/gemm/kernel/gemm_kernel.hpp | 53 ++---- .../batched_gemm/test_batched_gemm_util.hpp | 34 +++- test/ck_tile/gemm/test_gemm_pipeline_util.hpp | 98 ++++++---- .../grouped_gemm/test_grouped_gemm_util.hpp | 34 +++- 10 files changed, 473 insertions(+), 340 deletions(-) diff --git a/example/ck_tile/03_gemm/gemm_basic.cpp b/example/ck_tile/03_gemm/gemm_basic.cpp index 69051423fb..1edb3da947 100644 --- a/example/ck_tile/03_gemm/gemm_basic.cpp +++ b/example/ck_tile/03_gemm/gemm_basic.cpp @@ -53,50 +53,67 @@ float gemm_calc(const ck_tile::GemmHostArgs& args, const ck_tile::stream_config& using CodegenPipelineProblem = ck_tile:: GemmPipelineProblem; using CodegenGemmPipeline = ck_tile::GemmPipelineAGmemBGmemCRegV1; - using GemmEpilogue = ck_tile::CShuffleEpilogue< - 
ck_tile::CShuffleEpilogueProblem>; - // ToDo: Will add the codegen part to test different pipeline policies in GEMM. - // Now we only use the BlockGemmASmemBSmemCRegV1DefaultPolicy. - using Kernel = ck_tile::GemmKernel; - auto kargs = Kernel::MakeKernelArgs(args); + const auto Run = [&](const auto memory_operation_) { + constexpr auto memory_operation = memory_operation_.value; - const dim3 grids = Kernel::GridSize(args.M, args.N, args.k_batch); - constexpr dim3 blocks = Kernel::BlockSize(); + using GemmEpilogue = ck_tile::CShuffleEpilogue< + ck_tile::CShuffleEpilogueProblem>; - if(!Kernel::IsSupportedArgument(kargs)) + // ToDo: Will add the codegen part to test different pipeline policies in GEMM. + // Now we only use the BlockGemmASmemBSmemCRegV1DefaultPolicy. + using Kernel = ck_tile::GemmKernel; + auto kargs = Kernel::MakeKernelArgs(args); + + const dim3 grids = Kernel::GridSize(args.M, args.N, args.k_batch); + constexpr dim3 blocks = Kernel::BlockSize(); + + if(!Kernel::IsSupportedArgument(kargs)) + { + throw std::runtime_error("Wrong! Arguments not supported! Skipping gemm!\n"); + } + + if(s.log_level_ > 0) + { + std::cout << "Launching kernel with args: " << Kernel::GetName() << '\n' + << "shape: " << CodegenGemmShape::GetName() << '\n' + << "problem: " << CodegenPipelineProblem::GetName() << '\n' + << "pipeline: " << CodegenGemmPipeline::GetName() << '\n' + << "grid: {" << grids.x << ", " << grids.y << ", " << grids.z << "}" + << ", blocks: {" << blocks.x << ", " << blocks.y << ", " << blocks.z << "}" + << std::endl; + } + + float ave_time = ck_tile::launch_kernel( + s, ck_tile::make_kernel(Kernel{}, grids, blocks, 0, kargs)); + + return ave_time; + }; + + if(args.k_batch == 1) { - throw std::runtime_error("Wrong! Arguments not supported! 
Skipping gemm!\n"); + return Run(ck_tile::integral_constant{}); } - - if(s.log_level_ > 0) + else { - std::cout << "Launching kernel with args: " << Kernel::GetName() << '\n' - << "shape: " << CodegenGemmShape::GetName() << '\n' - << "problem: " << CodegenPipelineProblem::GetName() << '\n' - << "pipeline: " << CodegenGemmPipeline::GetName() << '\n' - << "grid: {" << grids.x << ", " << grids.y << ", " << grids.z << "}" - << ", blocks: {" << blocks.x << ", " << blocks.y << ", " << blocks.z << "}" - << std::endl; + return Run(ck_tile::integral_constant{}); } - - float ave_time = ck_tile::launch_kernel( - s, ck_tile::make_kernel(Kernel{}, grids, blocks, 0, kargs)); - - return ave_time; } #include "run_gemm_example.inc" diff --git a/example/ck_tile/03_gemm/universal_gemm.cpp b/example/ck_tile/03_gemm/universal_gemm.cpp index 2ba16ca89d..e6a2811918 100644 --- a/example/ck_tile/03_gemm/universal_gemm.cpp +++ b/example/ck_tile/03_gemm/universal_gemm.cpp @@ -61,10 +61,13 @@ float gemm_calc(const ck_tile::GemmHostArgs& args, const ck_tile::stream_config& float ave_time{0}; - const auto Run = [&](const auto has_hot_loop_, const auto tail_number_) { - constexpr bool has_hot_loop_v = has_hot_loop_.value; - constexpr auto tail_number_v = tail_number_.value; - constexpr auto scheduler = GEMM_PIPELINE_SCHEDULER; + const auto Run = [&](const auto has_hot_loop_, + const auto tail_number_, + const auto memory_operation_) { + constexpr bool has_hot_loop_v = has_hot_loop_.value; + constexpr auto tail_number_v = tail_number_.value; + constexpr auto scheduler = GEMM_PIPELINE_SCHEDULER; + constexpr auto memory_operation = memory_operation_.value; using UniversalGemmProblem = ck_tile::UniversalGemmPipelineProblem>; + UniversalGemmProblem::TransposeC, + memory_operation>>; using Kernel = ck_tile::GemmKernel; auto kargs = Kernel::MakeKernelArgs(args); @@ -116,23 +120,40 @@ float gemm_calc(const ck_tile::GemmHostArgs& args, const ck_tile::stream_config& return ave_time; }; + const auto 
RunSplitk = [&](const auto has_hot_loop_, const auto tail_number_) { + if(args.k_batch == 1) + { + Run(has_hot_loop_, + tail_number_, + ck_tile::integral_constant{}); + } + else + { + Run(has_hot_loop_, + tail_number_, + ck_tile::integral_constant{}); + } + }; + if(has_hot_loop) { #if(CK_TILE_PIPELINE_DEFAULT == CK_TILE_PIPELINE_COMPUTE_V3) if(tail_num == ck_tile::TailNumber::Full) { - Run(ck_tile::bool_constant{}, - ck_tile::integral_constant{}); + RunSplitk(ck_tile::bool_constant{}, + ck_tile::integral_constant{}); } else if(tail_num == ck_tile::TailNumber::Odd) { - Run(ck_tile::bool_constant{}, - ck_tile::integral_constant{}); + RunSplitk(ck_tile::bool_constant{}, + ck_tile::integral_constant{}); } else if(tail_num == ck_tile::TailNumber::Even) { - Run(ck_tile::bool_constant{}, - ck_tile::integral_constant{}); + RunSplitk(ck_tile::bool_constant{}, + ck_tile::integral_constant{}); } else { @@ -146,20 +167,21 @@ float gemm_calc(const ck_tile::GemmHostArgs& args, const ck_tile::stream_config& // Tail pipeline One to Seven if(tail_num == ck_tile::TailNumber::One) { - Run(ck_tile::bool_constant{}, - ck_tile::integral_constant{}); + RunSplitk(ck_tile::bool_constant{}, + ck_tile::integral_constant{}); } else if(tail_num == ck_tile::TailNumber::Full) { - Run(ck_tile::bool_constant{}, - ck_tile::integral_constant{}); + RunSplitk(ck_tile::bool_constant{}, + ck_tile::integral_constant{}); } if constexpr(BaseGemmPipeline::PrefetchStages > 2) { if(tail_num == ck_tile::TailNumber::Two) { - Run(ck_tile::bool_constant{}, + RunSplitk( + ck_tile::bool_constant{}, ck_tile::integral_constant{}); } } @@ -167,7 +189,8 @@ float gemm_calc(const ck_tile::GemmHostArgs& args, const ck_tile::stream_config& { if(tail_num == ck_tile::TailNumber::Three) { - Run(ck_tile::bool_constant{}, + RunSplitk( + ck_tile::bool_constant{}, ck_tile::integral_constant{}); } } @@ -175,7 +198,8 @@ float gemm_calc(const ck_tile::GemmHostArgs& args, const ck_tile::stream_config& { if(tail_num == 
ck_tile::TailNumber::Four) { - Run(ck_tile::bool_constant{}, + RunSplitk( + ck_tile::bool_constant{}, ck_tile::integral_constant{}); } } @@ -183,7 +207,8 @@ float gemm_calc(const ck_tile::GemmHostArgs& args, const ck_tile::stream_config& { if(tail_num == ck_tile::TailNumber::Five) { - Run(ck_tile::bool_constant{}, + RunSplitk( + ck_tile::bool_constant{}, ck_tile::integral_constant{}); } } @@ -191,7 +216,8 @@ float gemm_calc(const ck_tile::GemmHostArgs& args, const ck_tile::stream_config& { if(tail_num == ck_tile::TailNumber::Six) { - Run(ck_tile::bool_constant{}, + RunSplitk( + ck_tile::bool_constant{}, ck_tile::integral_constant{}); } } @@ -199,20 +225,22 @@ float gemm_calc(const ck_tile::GemmHostArgs& args, const ck_tile::stream_config& { if(tail_num == ck_tile::TailNumber::Seven) { - Run(ck_tile::bool_constant{}, + RunSplitk( + ck_tile::bool_constant{}, ck_tile::integral_constant{}); } } #elif(CK_TILE_PIPELINE_DEFAULT == CK_TILE_PIPELINE_COMPUTE_V4) if(tail_num == ck_tile::TailNumber::Three) { - Run(ck_tile::bool_constant{}, + RunSplitk( + ck_tile::bool_constant{}, ck_tile::integral_constant{}); } else { - Run(ck_tile::bool_constant{}, - ck_tile::integral_constant{}); + RunSplitk(ck_tile::bool_constant{}, + ck_tile::integral_constant{}); } #endif } @@ -220,18 +248,18 @@ float gemm_calc(const ck_tile::GemmHostArgs& args, const ck_tile::stream_config& { if(tail_num == ck_tile::TailNumber::Full) { - Run(ck_tile::bool_constant{}, - ck_tile::integral_constant{}); + RunSplitk(ck_tile::bool_constant{}, + ck_tile::integral_constant{}); } else if(tail_num == ck_tile::TailNumber::Odd) { - Run(ck_tile::bool_constant{}, - ck_tile::integral_constant{}); + RunSplitk(ck_tile::bool_constant{}, + ck_tile::integral_constant{}); } else if(tail_num == ck_tile::TailNumber::Even) { - Run(ck_tile::bool_constant{}, - ck_tile::integral_constant{}); + RunSplitk(ck_tile::bool_constant{}, + ck_tile::integral_constant{}); } else { diff --git 
a/example/ck_tile/16_batched_gemm/batched_gemm.cpp b/example/ck_tile/16_batched_gemm/batched_gemm.cpp index a0cd18ec74..0219c67305 100644 --- a/example/ck_tile/16_batched_gemm/batched_gemm.cpp +++ b/example/ck_tile/16_batched_gemm/batched_gemm.cpp @@ -106,61 +106,81 @@ float batched_gemm(const ck_tile::BatchedGemmHostArgs& args, const ck_tile::stre float ave_time{0}; - const auto Run = [&](const auto has_hot_loop_, const auto tail_number_) { - constexpr bool has_hot_loop_v = has_hot_loop_.value; - constexpr auto tail_number_v = tail_number_.value; - constexpr auto scheduler = GEMM_PIPELINE_SCHEDULER; + const auto Run = + [&](const auto has_hot_loop_, const auto tail_number_, const auto memory_operation_) { + constexpr bool has_hot_loop_v = has_hot_loop_.value; + constexpr auto tail_number_v = tail_number_.value; + constexpr auto scheduler = GEMM_PIPELINE_SCHEDULER; + constexpr auto memory_operation = memory_operation_.value; - using UniversalGemmProblem = ck_tile::UniversalGemmPipelineProblem; + using UniversalGemmProblem = ck_tile::UniversalGemmPipelineProblem; - using GemmPipeline = GEMM_PIPELINE; - using GemmEpilogue = ck_tile::CShuffleEpilogue< - ck_tile::CShuffleEpilogueProblem>; - using Kernel = ck_tile::BatchedGemmKernel; - auto kargs = Kernel::MakeKernelArgs(args); + using GemmPipeline = GEMM_PIPELINE; + using GemmEpilogue = ck_tile::CShuffleEpilogue< + ck_tile::CShuffleEpilogueProblem>; + using Kernel = ck_tile::BatchedGemmKernel; + auto kargs = Kernel::MakeKernelArgs(args); - const dim3 grids = Kernel::GridSize(args.M, args.N, args.k_batch, args.batch_count); - constexpr dim3 blocks = Kernel::BlockSize(); + const dim3 grids = Kernel::GridSize(args.M, args.N, args.k_batch, args.batch_count); + constexpr dim3 blocks = Kernel::BlockSize(); - if(!Kernel::IsSupportedArgument(kargs)) + if(!Kernel::IsSupportedArgument(kargs)) + { + throw std::runtime_error("Wrong! Arguments not supported! 
Skipping gemm!\n"); + } + + if(s.log_level_ > 0) + { + std::cout << "Launching kernel with args: " << Kernel::GetName() << '\n' + << "shape: " << GemmShape::GetName() << '\n' + << "problem: " << GemmPipelineProblem::GetName() << '\n' + << "pipeline: " << GemmPipeline::GetName() << '\n' + << "grid: {" << grids.x << ", " << grids.y << ", " << grids.z << "}" + << ", blocks: {" << blocks.x << ", " << blocks.y << ", " << blocks.z + << "}" << std::endl; + } + + ave_time = ck_tile::launch_kernel( + s, ck_tile::make_kernel(Kernel{}, grids, blocks, 0, kargs)); + return ave_time; + }; + + const auto RunSplitk = [&](const auto has_hot_loop_, const auto tail_number_) { + if(args.k_batch == 1) { - throw std::runtime_error("Wrong! Arguments not supported! Skipping gemm!\n"); + Run(has_hot_loop_, + tail_number_, + ck_tile::integral_constant{}); } - - if(s.log_level_ > 0) + else { - std::cout << "Launching kernel with args: " << Kernel::GetName() << '\n' - << "shape: " << GemmShape::GetName() << '\n' - << "problem: " << GemmPipelineProblem::GetName() << '\n' - << "pipeline: " << GemmPipeline::GetName() << '\n' - << "grid: {" << grids.x << ", " << grids.y << ", " << grids.z << "}" - << ", blocks: {" << blocks.x << ", " << blocks.y << ", " << blocks.z << "}" - << std::endl; + Run(has_hot_loop_, + tail_number_, + ck_tile::integral_constant{}); } - - ave_time = ck_tile::launch_kernel( - s, ck_tile::make_kernel(Kernel{}, grids, blocks, 0, kargs)); - return ave_time; }; if(has_hot_loop) @@ -168,18 +188,18 @@ float batched_gemm(const ck_tile::BatchedGemmHostArgs& args, const ck_tile::stre #if(CK_TILE_PIPELINE_DEFAULT == CK_TILE_PIPELINE_COMPUTE_V3) if(tail_num == ck_tile::TailNumber::Full) { - Run(ck_tile::bool_constant{}, - ck_tile::integral_constant{}); + RunSplitk(ck_tile::bool_constant{}, + ck_tile::integral_constant{}); } else if(tail_num == ck_tile::TailNumber::Odd) { - Run(ck_tile::bool_constant{}, - ck_tile::integral_constant{}); + RunSplitk(ck_tile::bool_constant{}, + 
ck_tile::integral_constant{}); } else if(tail_num == ck_tile::TailNumber::Even) { - Run(ck_tile::bool_constant{}, - ck_tile::integral_constant{}); + RunSplitk(ck_tile::bool_constant{}, + ck_tile::integral_constant{}); } else { @@ -193,20 +213,21 @@ float batched_gemm(const ck_tile::BatchedGemmHostArgs& args, const ck_tile::stre // Tail pipeline One to Seven if(tail_num == ck_tile::TailNumber::One) { - Run(ck_tile::bool_constant{}, - ck_tile::integral_constant{}); + RunSplitk(ck_tile::bool_constant{}, + ck_tile::integral_constant{}); } else if(tail_num == ck_tile::TailNumber::Full) { - Run(ck_tile::bool_constant{}, - ck_tile::integral_constant{}); + RunSplitk(ck_tile::bool_constant{}, + ck_tile::integral_constant{}); } if constexpr(BaseGemmPipeline::PrefetchStages > 2) { if(tail_num == ck_tile::TailNumber::Two) { - Run(ck_tile::bool_constant{}, + RunSplitk( + ck_tile::bool_constant{}, ck_tile::integral_constant{}); } } @@ -214,7 +235,8 @@ float batched_gemm(const ck_tile::BatchedGemmHostArgs& args, const ck_tile::stre { if(tail_num == ck_tile::TailNumber::Three) { - Run(ck_tile::bool_constant{}, + RunSplitk( + ck_tile::bool_constant{}, ck_tile::integral_constant{}); } } @@ -222,7 +244,8 @@ float batched_gemm(const ck_tile::BatchedGemmHostArgs& args, const ck_tile::stre { if(tail_num == ck_tile::TailNumber::Four) { - Run(ck_tile::bool_constant{}, + RunSplitk( + ck_tile::bool_constant{}, ck_tile::integral_constant{}); } } @@ -230,7 +253,8 @@ float batched_gemm(const ck_tile::BatchedGemmHostArgs& args, const ck_tile::stre { if(tail_num == ck_tile::TailNumber::Five) { - Run(ck_tile::bool_constant{}, + RunSplitk( + ck_tile::bool_constant{}, ck_tile::integral_constant{}); } } @@ -238,7 +262,8 @@ float batched_gemm(const ck_tile::BatchedGemmHostArgs& args, const ck_tile::stre { if(tail_num == ck_tile::TailNumber::Six) { - Run(ck_tile::bool_constant{}, + RunSplitk( + ck_tile::bool_constant{}, ck_tile::integral_constant{}); } } @@ -246,20 +271,22 @@ float batched_gemm(const 
ck_tile::BatchedGemmHostArgs& args, const ck_tile::stre { if(tail_num == ck_tile::TailNumber::Seven) { - Run(ck_tile::bool_constant{}, + RunSplitk( + ck_tile::bool_constant{}, ck_tile::integral_constant{}); } } #elif(CK_TILE_PIPELINE_DEFAULT == CK_TILE_PIPELINE_COMPUTE_V4) if(tail_num == ck_tile::TailNumber::Three) { - Run(ck_tile::bool_constant{}, + RunSplitk( + ck_tile::bool_constant{}, ck_tile::integral_constant{}); } else { - Run(ck_tile::bool_constant{}, - ck_tile::integral_constant{}); + RunSplitk(ck_tile::bool_constant{}, + ck_tile::integral_constant{}); } #endif } @@ -267,18 +294,18 @@ float batched_gemm(const ck_tile::BatchedGemmHostArgs& args, const ck_tile::stre { if(tail_num == ck_tile::TailNumber::Full) { - Run(ck_tile::bool_constant{}, - ck_tile::integral_constant{}); + RunSplitk(ck_tile::bool_constant{}, + ck_tile::integral_constant{}); } else if(tail_num == ck_tile::TailNumber::Odd) { - Run(ck_tile::bool_constant{}, - ck_tile::integral_constant{}); + RunSplitk(ck_tile::bool_constant{}, + ck_tile::integral_constant{}); } else if(tail_num == ck_tile::TailNumber::Even) { - Run(ck_tile::bool_constant{}, - ck_tile::integral_constant{}); + RunSplitk(ck_tile::bool_constant{}, + ck_tile::integral_constant{}); } std::ostringstream err; err << "Incorrect tail_num for pipeline without hotloop, expected Full, Odd or Even, but " diff --git a/example/ck_tile/17_grouped_gemm/grouped_gemm.cpp b/example/ck_tile/17_grouped_gemm/grouped_gemm.cpp index 2a9903362d..9b134ff779 100644 --- a/example/ck_tile/17_grouped_gemm/grouped_gemm.cpp +++ b/example/ck_tile/17_grouped_gemm/grouped_gemm.cpp @@ -114,66 +114,86 @@ float grouped_gemm(const std::vector& gemm_descs, float ave_time{0}; - const auto Run = [&](const auto has_hot_loop_, const auto tail_number_) { - constexpr bool has_hot_loop_v = has_hot_loop_.value; - constexpr auto tail_number_v = tail_number_.value; - constexpr auto scheduler = GEMM_PIPELINE_SCHEDULER; + const auto Run = + [&](const auto has_hot_loop_, const 
auto tail_number_, const auto memory_operation_) { + constexpr bool has_hot_loop_v = has_hot_loop_.value; + constexpr auto tail_number_v = tail_number_.value; + constexpr auto scheduler = GEMM_PIPELINE_SCHEDULER; + constexpr auto memory_operation = memory_operation_.value; - using UniversalGemmProblem = ck_tile::UniversalGemmPipelineProblem; + using UniversalGemmProblem = ck_tile::UniversalGemmPipelineProblem; - using GemmPipeline = GEMM_PIPELINE; - using GemmEpilogue = ck_tile::CShuffleEpilogue< - ck_tile::CShuffleEpilogueProblem>; - using Kernel = ck_tile::GroupedGemmKernel; - auto kargs = Kernel::MakeKargs(gemm_descs); + using GemmPipeline = GEMM_PIPELINE; + using GemmEpilogue = ck_tile::CShuffleEpilogue< + ck_tile::CShuffleEpilogueProblem>; + using Kernel = ck_tile::GroupedGemmKernel; + auto kargs = Kernel::MakeKargs(gemm_descs); - const dim3 grids = Kernel::GridSize(gemm_descs); - constexpr dim3 blocks = Kernel::BlockSize(); + const dim3 grids = Kernel::GridSize(gemm_descs); + constexpr dim3 blocks = Kernel::BlockSize(); - ck_tile::hip_check_error(hipMemcpyWithStream(p_workspace_, - kargs.data(), - get_workspace_size(gemm_descs), - hipMemcpyHostToDevice, - s.stream_id_)); + ck_tile::hip_check_error(hipMemcpyWithStream(p_workspace_, + kargs.data(), + get_workspace_size(gemm_descs), + hipMemcpyHostToDevice, + s.stream_id_)); - if(s.log_level_ > 0) + if(s.log_level_ > 0) + { + std::cout << "Launching kernel: " << Kernel::GetName() << " with args:" + << " grid: {" << grids.x << ", " << grids.y << ", " << grids.z << "}" + << ", blocks: {" << blocks.x << ", " << blocks.y << ", " << blocks.z + << "}" << std::endl; + } + + ave_time = ck_tile::launch_kernel( + s, + ck_tile::make_kernel( + Kernel{}, + grids, + blocks, + 0, + ck_tile::cast_pointer_to_constant_address_space(p_workspace_), + gemm_descs.size())); + return ave_time; + }; + + const auto RunSplitk = [&](const auto has_hot_loop_, const auto tail_number_) { + if(gemm_descs[0].k_batch == 1) { - std::cout << 
"Launching kernel: " << Kernel::GetName() << " with args:" - << " grid: {" << grids.x << ", " << grids.y << ", " << grids.z << "}" - << ", blocks: {" << blocks.x << ", " << blocks.y << ", " << blocks.z << "}" - << std::endl; + Run(has_hot_loop_, + tail_number_, + ck_tile::integral_constant{}); + } + else + { + Run(has_hot_loop_, + tail_number_, + ck_tile::integral_constant{}); } - - ave_time = ck_tile::launch_kernel( - s, - ck_tile::make_kernel( - Kernel{}, - grids, - blocks, - 0, - ck_tile::cast_pointer_to_constant_address_space(p_workspace_), - gemm_descs.size())); - return ave_time; }; if(has_hot_loop) @@ -181,18 +201,18 @@ float grouped_gemm(const std::vector& gemm_descs, #if(CK_TILE_PIPELINE_DEFAULT == CK_TILE_PIPELINE_COMPUTE_V3) if(tail_num == ck_tile::TailNumber::Full) { - Run(ck_tile::bool_constant{}, - ck_tile::integral_constant{}); + RunSplitk(ck_tile::bool_constant{}, + ck_tile::integral_constant{}); } else if(tail_num == ck_tile::TailNumber::Odd) { - Run(ck_tile::bool_constant{}, - ck_tile::integral_constant{}); + RunSplitk(ck_tile::bool_constant{}, + ck_tile::integral_constant{}); } else if(tail_num == ck_tile::TailNumber::Even) { - Run(ck_tile::bool_constant{}, - ck_tile::integral_constant{}); + RunSplitk(ck_tile::bool_constant{}, + ck_tile::integral_constant{}); } else { @@ -206,20 +226,21 @@ float grouped_gemm(const std::vector& gemm_descs, // Tail pipeline One to Seven if(tail_num == ck_tile::TailNumber::One) { - Run(ck_tile::bool_constant{}, - ck_tile::integral_constant{}); + RunSplitk(ck_tile::bool_constant{}, + ck_tile::integral_constant{}); } else if(tail_num == ck_tile::TailNumber::Full) { - Run(ck_tile::bool_constant{}, - ck_tile::integral_constant{}); + RunSplitk(ck_tile::bool_constant{}, + ck_tile::integral_constant{}); } if constexpr(BaseGemmPipeline::PrefetchStages > 2) { if(tail_num == ck_tile::TailNumber::Two) { - Run(ck_tile::bool_constant{}, + RunSplitk( + ck_tile::bool_constant{}, ck_tile::integral_constant{}); } } @@ -227,7 +248,8 
@@ float grouped_gemm(const std::vector& gemm_descs, { if(tail_num == ck_tile::TailNumber::Three) { - Run(ck_tile::bool_constant{}, + RunSplitk( + ck_tile::bool_constant{}, ck_tile::integral_constant{}); } } @@ -235,7 +257,8 @@ float grouped_gemm(const std::vector& gemm_descs, { if(tail_num == ck_tile::TailNumber::Four) { - Run(ck_tile::bool_constant{}, + RunSplitk( + ck_tile::bool_constant{}, ck_tile::integral_constant{}); } } @@ -243,7 +266,8 @@ float grouped_gemm(const std::vector& gemm_descs, { if(tail_num == ck_tile::TailNumber::Five) { - Run(ck_tile::bool_constant{}, + RunSplitk( + ck_tile::bool_constant{}, ck_tile::integral_constant{}); } } @@ -251,7 +275,8 @@ float grouped_gemm(const std::vector& gemm_descs, { if(tail_num == ck_tile::TailNumber::Six) { - Run(ck_tile::bool_constant{}, + RunSplitk( + ck_tile::bool_constant{}, ck_tile::integral_constant{}); } } @@ -259,20 +284,22 @@ float grouped_gemm(const std::vector& gemm_descs, { if(tail_num == ck_tile::TailNumber::Seven) { - Run(ck_tile::bool_constant{}, + RunSplitk( + ck_tile::bool_constant{}, ck_tile::integral_constant{}); } } #elif(CK_TILE_PIPELINE_DEFAULT == CK_TILE_PIPELINE_COMPUTE_V4) if(tail_num == ck_tile::TailNumber::Three) { - Run(ck_tile::bool_constant{}, + RunSplitk( + ck_tile::bool_constant{}, ck_tile::integral_constant{}); } else { - Run(ck_tile::bool_constant{}, - ck_tile::integral_constant{}); + RunSplitk(ck_tile::bool_constant{}, + ck_tile::integral_constant{}); } #endif } diff --git a/include/ck_tile/ops/epilogue/cshuffle_epilogue.hpp b/include/ck_tile/ops/epilogue/cshuffle_epilogue.hpp index 225997439e..9b8dde1905 100644 --- a/include/ck_tile/ops/epilogue/cshuffle_epilogue.hpp +++ b/include/ck_tile/ops/epilogue/cshuffle_epilogue.hpp @@ -22,23 +22,25 @@ template + bool isCTransposed_, + memory_operation_enum MemoryOperation_> struct CShuffleEpilogueProblem { - using ADataType = remove_cvref_t; - using BDataType = remove_cvref_t; - using AccDataType = remove_cvref_t; - using ODataType = 
remove_cvref_t; - using CLayout = remove_cvref_t; - static constexpr index_t kBlockSize = kBlockSize_; - static constexpr index_t kMPerBlock = kM_; - static constexpr index_t kNPerBlock = kN_; - static constexpr index_t kMWave = kMWave_; - static constexpr index_t kNWave = kNWave_; - static constexpr index_t kMPerXdl = kMPerXdl_; - static constexpr index_t kNPerXdl = kNPerXdl_; - static constexpr index_t kKPerXdl = kKPerXdl_; - static constexpr index_t isCTransposed = isCTransposed_; + using ADataType = remove_cvref_t; + using BDataType = remove_cvref_t; + using AccDataType = remove_cvref_t; + using ODataType = remove_cvref_t; + using CLayout = remove_cvref_t; + static constexpr index_t kBlockSize = kBlockSize_; + static constexpr index_t kMPerBlock = kM_; + static constexpr index_t kNPerBlock = kN_; + static constexpr index_t kMWave = kMWave_; + static constexpr index_t kNWave = kNWave_; + static constexpr index_t kMPerXdl = kMPerXdl_; + static constexpr index_t kNPerXdl = kNPerXdl_; + static constexpr index_t kKPerXdl = kKPerXdl_; + static constexpr index_t isCTransposed = isCTransposed_; + static constexpr memory_operation_enum MemoryOperation = MemoryOperation_; }; template @@ -52,18 +54,19 @@ struct CShuffleEpilogue // Used for weight-only quantization kernel, B would be dequantized to the same data type as A using BTypeToUse = std::conditional_t, ADataType, BDataType>; - using CLayout = remove_cvref_t; - static constexpr index_t kBlockSize = Problem::kBlockSize; - static constexpr index_t kMPerBlock = Problem::kMPerBlock; - static constexpr index_t kNPerBlock = Problem::kNPerBlock; - static constexpr index_t kMWave = Problem::kMWave; - static constexpr index_t kNWave = Problem::kNWave; - static constexpr index_t kMPerXdl = Problem::kMPerXdl; - static constexpr index_t kNPerXdl = Problem::kNPerXdl; - static constexpr index_t kKPerXdl = Problem::kKPerXdl; - static constexpr index_t isCTransposed = Problem::isCTransposed; - static constexpr index_t 
kMPerIteration = kMPerXdl * kMWave; - static constexpr index_t kNPerIteration = kNPerXdl * kNWave; + using CLayout = remove_cvref_t; + static constexpr memory_operation_enum MemoryOperation = Problem::MemoryOperation; + static constexpr index_t kBlockSize = Problem::kBlockSize; + static constexpr index_t kMPerBlock = Problem::kMPerBlock; + static constexpr index_t kNPerBlock = Problem::kNPerBlock; + static constexpr index_t kMWave = Problem::kMWave; + static constexpr index_t kNWave = Problem::kNWave; + static constexpr index_t kMPerXdl = Problem::kMPerXdl; + static constexpr index_t kNPerXdl = Problem::kNPerXdl; + static constexpr index_t kKPerXdl = Problem::kKPerXdl; + static constexpr index_t isCTransposed = Problem::isCTransposed; + static constexpr index_t kMPerIteration = kMPerXdl * kMWave; + static constexpr index_t kNPerIteration = kNPerXdl * kNWave; using WG = WarpGemmMfmaDispatcher + template CK_TILE_DEVICE auto operator()(ODramWindow& out_dram_window, const OAccTile& o_acc_tile, void* p_smem) { @@ -179,7 +180,7 @@ struct CShuffleEpilogue const auto c_out_tensor = load_tile(make_tile_window(out_lds_window, dram_tile_distribution)); - if constexpr(out_memory_data_op == memory_operation_enum::set) + if constexpr(MemoryOperation == memory_operation_enum::set) { store_tile(out_dram_window, c_out_tensor); } diff --git a/include/ck_tile/ops/gemm/kernel/batched_gemm_kernel.hpp b/include/ck_tile/ops/gemm/kernel/batched_gemm_kernel.hpp index dfb6bfae58..d495c0d950 100644 --- a/include/ck_tile/ops/gemm/kernel/batched_gemm_kernel.hpp +++ b/include/ck_tile/ops/gemm/kernel/batched_gemm_kernel.hpp @@ -142,15 +142,7 @@ struct BatchedGemmKernel : public GemmKernelRunGemm(a_ptr, b_ptr, c_ptr, smem_ptr, kargs, splitk_batch_offset, i_m, i_n); - } - else - { - this->template RunGemm( - a_ptr, b_ptr, c_ptr, smem_ptr, kargs, splitk_batch_offset, i_m, i_n); - } + this->RunGemm(a_ptr, b_ptr, c_ptr, smem_ptr, kargs, splitk_batch_offset, i_m, i_n); } }; diff --git 
a/include/ck_tile/ops/gemm/kernel/gemm_kernel.hpp b/include/ck_tile/ops/gemm/kernel/gemm_kernel.hpp index bc41f680f2..9c25104cd7 100644 --- a/include/ck_tile/ops/gemm/kernel/gemm_kernel.hpp +++ b/include/ck_tile/ops/gemm/kernel/gemm_kernel.hpp @@ -608,9 +608,7 @@ struct GemmKernel * @param block_idx_m The GEMM's output M dimension tile index processed by this workgroup. * @param block_idx_n The GEMM's output N dimension tile index processed by this workgroup. * - * @tparam DstInMemOp Destination memory operation (default: set). */ - template CK_TILE_DEVICE static void RunGemm(const ADataType* a_ptr, const BDataType* b_ptr, CDataType* c_ptr, @@ -622,7 +620,8 @@ struct GemmKernel { // Create Gemm tensor views, pad views and tile windows const auto& gemm_tensor_views_tuple = - MakeGemmTensorViews(a_ptr, b_ptr, c_ptr, kargs, splitk_batch_offset); + MakeGemmTensorViews( + a_ptr, b_ptr, c_ptr, kargs, splitk_batch_offset); const auto& gemm_pad_views = MakeGemmPadViews(gemm_tensor_views_tuple); auto gemm_tile_windows = MakeGemmTileWindows(gemm_pad_views, block_idx_m, block_idx_n); @@ -640,9 +639,8 @@ struct GemmKernel // Run Epilogue Pipeline auto& c_block_window = gemm_tile_windows.at(I2); - EpiloguePipeline{} - .template operator()( - c_block_window, c_block_tile, smem_ptr_0); + EpiloguePipeline{}.template operator()( + c_block_window, c_block_tile, smem_ptr_0); } /** @@ -660,9 +658,7 @@ struct GemmKernel * @param block_idx_m The GEMM's output M dimension tile index processed by this workgroup. * @param block_idx_n The GEMM's output N dimension tile index processed by this workgroup. * - * @tparam DstInMemOp Destination memory operation (default: set). 
*/ - template CK_TILE_DEVICE static void RunGemm2LDS(const ADataType* a_ptr, const BDataType* b_ptr, CDataType* c_ptr, @@ -675,7 +671,8 @@ struct GemmKernel { // Create Gemm tensor views, pad views and tile windows const auto& gemm_tensor_views_tuple = - MakeGemmTensorViews(a_ptr, b_ptr, c_ptr, kargs, splitk_batch_offset); + MakeGemmTensorViews( + a_ptr, b_ptr, c_ptr, kargs, splitk_batch_offset); const auto& gemm_pad_views = MakeGemmPadViews(gemm_tensor_views_tuple); auto gemm_tile_windows = MakeGemmTileWindows(gemm_pad_views, block_idx_m, block_idx_n); @@ -692,9 +689,8 @@ struct GemmKernel // Run Epilogue Pipeline auto& c_block_window = gemm_tile_windows.at(I2); - EpiloguePipeline{} - .template operator()( - c_block_window, c_block_tile, smem_ptr_0); + EpiloguePipeline{}.template operator()( + c_block_window, c_block_tile, smem_ptr_0); } CK_TILE_DEVICE void operator()(GemmKernelArgs kargs) const @@ -718,7 +714,9 @@ struct GemmKernel if constexpr(GemmPipeline::DoubleSmemBuffer == true) { __shared__ char smem_ptr_1[GetSmemSize()]; - if(kargs.k_batch == 1) + if constexpr(!(EpiloguePipeline::MemoryOperation == memory_operation_enum::atomic_add && + EpiloguePipeline::GetVectorSizeC() % 2 != 0 && + is_any_of::value)) { RunGemm2LDS(a_ptr, b_ptr, @@ -730,38 +728,15 @@ struct GemmKernel i_m, i_n); } - else - { - if constexpr(!(EpiloguePipeline::GetVectorSizeC() % 2 != 0 && - is_any_of::value)) - { - RunGemm2LDS(a_ptr, - b_ptr, - c_ptr, - smem_ptr_0, - smem_ptr_1, - kargs, - splitk_batch_offset, - i_m, - i_n); - } - } } else { - if(kargs.k_batch == 1) + if constexpr(!(EpiloguePipeline::MemoryOperation == memory_operation_enum::atomic_add && + EpiloguePipeline::GetVectorSizeC() % 2 != 0 && + is_any_of::value)) { RunGemm(a_ptr, b_ptr, c_ptr, smem_ptr_0, kargs, splitk_batch_offset, i_m, i_n); } - else - { - if constexpr(!(EpiloguePipeline::GetVectorSizeC() % 2 != 0 && - is_any_of::value)) - { - RunGemm( - a_ptr, b_ptr, c_ptr, smem_ptr_0, kargs, splitk_batch_offset, i_m, i_n); 
- } - } } } }; diff --git a/test/ck_tile/batched_gemm/test_batched_gemm_util.hpp b/test/ck_tile/batched_gemm/test_batched_gemm_util.hpp index 0af3ef3b34..4633f23ded 100644 --- a/test/ck_tile/batched_gemm/test_batched_gemm_util.hpp +++ b/test/ck_tile/batched_gemm/test_batched_gemm_util.hpp @@ -81,10 +81,13 @@ class TestCkTileBatchedGemm : public ::testing::Test float ave_time{0}; - const auto Run = [&](const auto has_hot_loop_, const auto tail_number_) { - constexpr bool has_hot_loop_v = has_hot_loop_.value; - constexpr auto tail_number_v = tail_number_.value; - constexpr auto scheduler = ck_tile::GemmPipelineScheduler::Intrawave; + const auto Run = [&](const auto has_hot_loop_, + const auto tail_number_, + const auto memory_operation_) { + constexpr bool has_hot_loop_v = has_hot_loop_.value; + constexpr auto tail_number_v = tail_number_.value; + constexpr auto scheduler = ck_tile::GemmPipelineScheduler::Intrawave; + constexpr auto memory_operation = memory_operation_.value; using UniversalGemmProblem = ck_tile::UniversalGemmPipelineProblem>; + UniversalGemmProblem::TransposeC, + memory_operation>>; using Kernel = ck_tile::BatchedGemmKernel; auto kargs = Kernel::MakeKernelArgs(args); @@ -138,11 +142,29 @@ class TestCkTileBatchedGemm : public ::testing::Test return ave_time; }; + const auto RunSplitk = [&](const auto has_hot_loop_, const auto tail_number_) { + if(args.k_batch == 1) + { + Run(has_hot_loop_, + tail_number_, + ck_tile::integral_constant{}); + } + else + { + Run(has_hot_loop_, + tail_number_, + ck_tile::integral_constant{}); + } + }; + if(has_hot_loop) { if(tail_num == ck_tile::TailNumber::Full) { - Run(ck_tile::bool_constant{}, + RunSplitk( + ck_tile::bool_constant{}, ck_tile::integral_constant{}); } else diff --git a/test/ck_tile/gemm/test_gemm_pipeline_util.hpp b/test/ck_tile/gemm/test_gemm_pipeline_util.hpp index 1b997ddbce..0329f16416 100644 --- a/test/ck_tile/gemm/test_gemm_pipeline_util.hpp +++ b/test/ck_tile/gemm/test_gemm_pipeline_util.hpp @@ 
-138,9 +138,12 @@ class TestCkTileGemmPipeline : public ::testing::Test const bool has_hot_loop = BaseGemmPipeline::BlockHasHotloop(num_loop); const ck_tile::TailNumber tail_num = BaseGemmPipeline::GetBlockLoopTailNum(num_loop); - const auto Run = [&](const auto has_hot_loop_, const auto tail_number_) { - constexpr bool has_hot_loop_v = has_hot_loop_.value; - constexpr auto tail_number_v = tail_number_.value; + const auto Run = [&](const auto has_hot_loop_, + const auto tail_number_, + const auto memory_operation_) { + constexpr bool has_hot_loop_v = has_hot_loop_.value; + constexpr auto tail_number_v = tail_number_.value; + constexpr auto memory_operation = memory_operation_.value; using UniversalGemmProblem = ck_tile::UniversalGemmPipelineProblem>; + UniversalGemmProblem::TransposeC, + memory_operation>>; using Kernel = ck_tile::GemmKernel; auto kargs = Kernel::MakeKernelArgs(args); @@ -193,15 +197,32 @@ class TestCkTileGemmPipeline : public ::testing::Test s, ck_tile::make_kernel(Kernel{}, grids, blocks, 0, kargs)); }; + const auto RunSplitk = [&](const auto has_hot_loop_, const auto tail_number_) { + if(args.k_batch == 1) + { + Run(has_hot_loop_, + tail_number_, + ck_tile::integral_constant{}); + } + else + { + Run(has_hot_loop_, + tail_number_, + ck_tile::integral_constant{}); + } + }; + if(has_hot_loop) { if constexpr(PipelineType == GemmPipelineType::CompV3) { if(tail_num == ck_tile::TailNumber::Full) { - Run(ck_tile::bool_constant{}, - ck_tile::integral_constant{}); + RunSplitk(ck_tile::bool_constant{}, + ck_tile::integral_constant{}); } else { @@ -219,69 +240,69 @@ class TestCkTileGemmPipeline : public ::testing::Test // Tail pipeline One to Seven if(tail_num == ck_tile::TailNumber::One) { - Run(ck_tile::bool_constant{}, - ck_tile::integral_constant{}); + RunSplitk(ck_tile::bool_constant{}, + ck_tile::integral_constant{}); } else if(tail_num == ck_tile::TailNumber::Full) { - Run(ck_tile::bool_constant{}, - ck_tile::integral_constant{}); + 
RunSplitk(ck_tile::bool_constant{}, + ck_tile::integral_constant{}); } if constexpr(BaseGemmPipeline::PrefetchStages > 2) { if(tail_num == ck_tile::TailNumber::Two) { - Run(ck_tile::bool_constant{}, - ck_tile::integral_constant{}); + RunSplitk(ck_tile::bool_constant{}, + ck_tile::integral_constant{}); } } if constexpr(BaseGemmPipeline::PrefetchStages > 3) { if(tail_num == ck_tile::TailNumber::Three) { - Run(ck_tile::bool_constant{}, - ck_tile::integral_constant{}); + RunSplitk(ck_tile::bool_constant{}, + ck_tile::integral_constant{}); } } if constexpr(BaseGemmPipeline::PrefetchStages > 4) { if(tail_num == ck_tile::TailNumber::Four) { - Run(ck_tile::bool_constant{}, - ck_tile::integral_constant{}); + RunSplitk(ck_tile::bool_constant{}, + ck_tile::integral_constant{}); } } if constexpr(BaseGemmPipeline::PrefetchStages > 5) { if(tail_num == ck_tile::TailNumber::Five) { - Run(ck_tile::bool_constant{}, - ck_tile::integral_constant{}); + RunSplitk(ck_tile::bool_constant{}, + ck_tile::integral_constant{}); } } if constexpr(BaseGemmPipeline::PrefetchStages > 6) { if(tail_num == ck_tile::TailNumber::Six) { - Run(ck_tile::bool_constant{}, - ck_tile::integral_constant{}); + RunSplitk(ck_tile::bool_constant{}, + ck_tile::integral_constant{}); } } if constexpr(BaseGemmPipeline::PrefetchStages > 7) { if(tail_num == ck_tile::TailNumber::Seven) { - Run(ck_tile::bool_constant{}, - ck_tile::integral_constant{}); + RunSplitk(ck_tile::bool_constant{}, + ck_tile::integral_constant{}); } } } @@ -290,15 +311,15 @@ class TestCkTileGemmPipeline : public ::testing::Test { if(tail_num == ck_tile::TailNumber::Three) { - Run(ck_tile::bool_constant{}, - ck_tile::integral_constant{}); + RunSplitk(ck_tile::bool_constant{}, + ck_tile::integral_constant{}); } else { - Run(ck_tile::bool_constant{}, - ck_tile::integral_constant{}); + RunSplitk(ck_tile::bool_constant{}, + ck_tile::integral_constant{}); } } } @@ -307,7 +328,8 @@ class TestCkTileGemmPipeline : public ::testing::Test // Tail number 
always Full - #PrefetchStages if(tail_num == ck_tile::TailNumber::Full) { - Run(ck_tile::bool_constant{}, + RunSplitk( + ck_tile::bool_constant{}, ck_tile::integral_constant{}); } else diff --git a/test/ck_tile/grouped_gemm/test_grouped_gemm_util.hpp b/test/ck_tile/grouped_gemm/test_grouped_gemm_util.hpp index b125d19762..3dec229643 100644 --- a/test/ck_tile/grouped_gemm/test_grouped_gemm_util.hpp +++ b/test/ck_tile/grouped_gemm/test_grouped_gemm_util.hpp @@ -102,10 +102,13 @@ class TestCkTileGroupedGemm : public ::testing::Test float ave_time{0}; - const auto Run = [&](const auto has_hot_loop_, const auto tail_number_) { - constexpr bool has_hot_loop_v = has_hot_loop_.value; - constexpr auto tail_number_v = tail_number_.value; - constexpr auto scheduler = ck_tile::GemmPipelineScheduler::Intrawave; + const auto Run = [&](const auto has_hot_loop_, + const auto tail_number_, + const auto memory_operation_) { + constexpr bool has_hot_loop_v = has_hot_loop_.value; + constexpr auto tail_number_v = tail_number_.value; + constexpr auto scheduler = ck_tile::GemmPipelineScheduler::Intrawave; + constexpr auto memory_operation = memory_operation_.value; using UniversalGemmProblem = ck_tile::UniversalGemmPipelineProblem>; + UniversalGemmProblem::TransposeC, + memory_operation>>; using Kernel = ck_tile::GroupedGemmKernel; auto kargs = Kernel::MakeKargs(gemm_descs); @@ -164,11 +168,29 @@ class TestCkTileGroupedGemm : public ::testing::Test return ave_time; }; + const auto RunSplitk = [&](const auto has_hot_loop_, const auto tail_number_) { + if(gemm_descs[0].k_batch == 1) + { + Run(has_hot_loop_, + tail_number_, + ck_tile::integral_constant{}); + } + else + { + Run(has_hot_loop_, + tail_number_, + ck_tile::integral_constant{}); + } + }; + if(has_hot_loop) { if(tail_num == ck_tile::TailNumber::Full) { - Run(ck_tile::bool_constant{}, + RunSplitk( + ck_tile::bool_constant{}, ck_tile::integral_constant{}); } else From b8fa27bfef7b1d2df3984e1fd01e9c5df72f8b33 Mon Sep 17 00:00:00 2001 
From: Muhammed Emin Ozturk Date: Mon, 5 May 2025 13:12:22 -0700 Subject: [PATCH 095/443] Fix failure in test_batched_gemm_softmax_gemm_permute for lower resource devices (#2117) * Problematic test case are analyzed and turned off for lower resource GPUs * update device info * Update test_batched_gemm_softmax_gemm_permute_bf16_xdl.cpp * Update test_batched_gemm_softmax_gemm_permute_fp16_xdl.cpp * Update test/batched_gemm_softmax_gemm_permute/test_batched_gemm_device_utils.hpp Co-authored-by: John Afaganis --- ...ed_gemm_bias_softmax_gemm_permute_util.hpp | 2 + .../test_batched_gemm_device_utils.hpp | 67 ++++++++++++++ ...hed_gemm_softmax_gemm_permute_bf16_xdl.cpp | 87 +++++++++++++++---- ...hed_gemm_softmax_gemm_permute_fp16_xdl.cpp | 11 +++ 4 files changed, 150 insertions(+), 17 deletions(-) create mode 100644 test/batched_gemm_softmax_gemm_permute/test_batched_gemm_device_utils.hpp diff --git a/test/batched_gemm_softmax_gemm_permute/test_batched_gemm_bias_softmax_gemm_permute_util.hpp b/test/batched_gemm_softmax_gemm_permute/test_batched_gemm_bias_softmax_gemm_permute_util.hpp index d7c39367c8..1464eacfa5 100644 --- a/test/batched_gemm_softmax_gemm_permute/test_batched_gemm_bias_softmax_gemm_permute_util.hpp +++ b/test/batched_gemm_softmax_gemm_permute/test_batched_gemm_bias_softmax_gemm_permute_util.hpp @@ -9,6 +9,8 @@ #include "ck/tensor_operation/gpu/device/impl/device_batched_gemm_softmax_gemm_permute_xdl_cshuffle.hpp" #include "profiler/profile_batched_gemm_bias_softmax_gemm_permute_impl.hpp" +#include + using ck::tensor_operation::device::GemmSpecialization; using ck::tensor_operation::device::MaskingSpecialization; using ck::tensor_operation::device::TensorSpecialization; diff --git a/test/batched_gemm_softmax_gemm_permute/test_batched_gemm_device_utils.hpp b/test/batched_gemm_softmax_gemm_permute/test_batched_gemm_device_utils.hpp new file mode 100644 index 0000000000..7d20ee4827 --- /dev/null +++ 
b/test/batched_gemm_softmax_gemm_permute/test_batched_gemm_device_utils.hpp @@ -0,0 +1,67 @@ +#pragma once + +#include +#include + +namespace ck { +namespace test { + +struct DeviceResources +{ + int computeUnits; + size_t totalMemory; + std::string deviceName; + // Add other relevant properties as needed +}; + +inline DeviceResources GetDeviceResources() +{ + DeviceResources res; + hipDeviceProp_t props; + + hipError_t status = hipGetDeviceProperties(&props, 0); + if(status != hipSuccess) + { + props.multiProcessorCount = 0; + res.computeUnits = 0; + res.totalMemory = 0; + res.deviceName = "Unknown"; + return res; + } + + res.computeUnits = props.multiProcessorCount; + res.totalMemory = props.totalGlobalMem; + res.deviceName = props.name; + + return res; +} + +// Device capability tiers +enum class DeviceCapabilityTier +{ + LOW, // Low resources devices (CU less than 80) + MEDIUM, // Mid-range devices + HIGH // High resources devices (CU hiher than 100) +}; + +inline DeviceCapabilityTier DetermineDeviceTier() +{ + DeviceResources res = GetDeviceResources(); + + // Adjust these thresholds based on your device specifics + if(res.computeUnits < 80) + { + return DeviceCapabilityTier::LOW; + } + else if(res.computeUnits < 100) + { + return DeviceCapabilityTier::MEDIUM; + } + else + { + return DeviceCapabilityTier::HIGH; + } +} + +} // namespace test +} // namespace ck diff --git a/test/batched_gemm_softmax_gemm_permute/test_batched_gemm_softmax_gemm_permute_bf16_xdl.cpp b/test/batched_gemm_softmax_gemm_permute/test_batched_gemm_softmax_gemm_permute_bf16_xdl.cpp index 8136257a24..8d894576c4 100644 --- a/test/batched_gemm_softmax_gemm_permute/test_batched_gemm_softmax_gemm_permute_bf16_xdl.cpp +++ b/test/batched_gemm_softmax_gemm_permute/test_batched_gemm_softmax_gemm_permute_bf16_xdl.cpp @@ -3,6 +3,7 @@ #include "gtest/gtest.h" #include "test_batched_gemm_softmax_gemm_permute_util.hpp" +#include "test_batched_gemm_device_utils.hpp" template class 
TestBatchedGemmMaskingScaleSoftmaxGemmPermuteBF16 @@ -110,14 +111,45 @@ TYPED_TEST(TestBatchedGemmMaskingScaleSoftmaxGemmPermuteBF16, Bench_BF16_Irregul TYPED_TEST(TestBatchedGemmMaskingScaleSoftmaxGemmPermuteBF16, Bench_BF16) { - this->lengths_ = std::vector>{ - {256, 256, 64, 64, 48, 16}, - {256, 256, 128, 128, 48, 16}, - {512, 512, 64, 64, 48, 16}, - {512, 512, 128, 128, 48, 16}, - {1024, 1024, 64, 64, 48, 16}, - {1024, 1024, 128, 128, 48, 16}, - }; + + // Get device capability tier + auto deviceTier = ck::test::DetermineDeviceTier(); + + // Configure test sizes based on device tier + if(deviceTier == ck::test::DeviceCapabilityTier::LOW) + { + // Minimal test sizes for low resource devices + this->lengths_ = std::vector>{ + {256, 256, 64, 64, 16, 8}, {256, 256, 128, 128, 16, 8}, {512, 512, 64, 64, 8, 4}}; + std::cout << "Running reduced benchmarks for low-resource device" << std::endl; + } + else if(deviceTier == ck::test::DeviceCapabilityTier::MEDIUM) + { + // Medium test sizes + this->lengths_ = std::vector>{{256, 256, 64, 64, 24, 12}, + {256, 256, 128, 128, 24, 12}, + {512, 512, 64, 64, 16, 8}, + {512, 512, 128, 128, 16, 8}, + {1024, 1024, 64, 64, 8, 4}, + {1024, 1024, 128, 128, 8, 4}}; + std::cout << "Running medium benchmarks for mid-tier device" << std::endl; + } + else + { + // Full test sizes for high resource devices + this->lengths_ = std::vector>{{256, 256, 64, 64, 48, 16}, + {256, 256, 128, 128, 48, 16}, + {512, 512, 64, 64, 48, 16}, + {512, 512, 128, 128, 48, 16}, + {1024, 1024, 64, 64, 48, 16}, + {1024, 1024, 128, 128, 48, 16}, + {2048, 2048, 64, 64, 48, 16}, + {2048, 2048, 128, 128, 48, 16}, + {4096, 4096, 64, 64, 48, 16}, + {4096, 4096, 128, 128, 48, 16}}; + std::cout << "Running full benchmarks for high-performance device" << std::endl; + } + this->bench_ = true; this->verify_ = false; this->Run(); @@ -127,9 +159,20 @@ using ck::tensor_operation::device::GemmSpecialization; TEST(TestBatchedGemmMaskingScaleSoftmaxGemmPermuteInterface, 
GemmSpecializationSizeMatch) { + + // Get device capability tier + auto deviceTier = ck::test::DetermineDeviceTier(); + int P = 120; // requires padding int Q = 128; // do not require padding + // For lower-end devices, we might need to skip some tests + if(deviceTier == ck::test::DeviceCapabilityTier::LOW) + { + std::cout << "Skipping GemmSpecialization tests for low-resource device" << std::endl; + return; + } + // IsSupported(M, N, K, O) // clang-format off EXPECT_TRUE(DeviceInstanceWrapper_G2M1N1K1O1_TNTT_BF16_M128_N128_K32_O128{}.IsSupported(Q, Q, Q, Q)); @@ -153,15 +196,25 @@ TEST(TestBatchedGemmMaskingScaleSoftmaxGemmPermuteInterface, GemmSpecializationS TEST(TestBatchedGemmMaskingScaleSoftmaxGemmPermuteInterface, GemmSpecializationSizeMismatch) { - // IsSupported(M, N, K, O) - // clang-format off - EXPECT_FALSE(DeviceInstanceWrapper_G2M1N1K1O1_TNTT_BF16_M128_N128_K32_O128{}.IsSupported(128, 128, 120, 128)); - EXPECT_FALSE(DeviceInstanceWrapper_G2M1N1K1O1_TNTT_BF16_M128_N128_K32_O128{}.IsSupported(128, 128, 128, 120)); - // Kernel can't support odd K size because SrcVectorDim == KDim and must satisfy SizeKRaw % ABSrcScalarPerVector == 0 - EXPECT_FALSE(DeviceInstanceWrapper_G2M1N1K1O1_TNTT_BF16_M128_N128_K32_O128{}.IsSupported(128, 128, 129, 128)); - EXPECT_FALSE(DeviceInstanceWrapper_G2M1N1K1O1_TNTT_BF16_M128_N128_K32_O128{}.IsSupported(128, 128, 130, 128)); - // Kernel can't support odd O size because SrcVectorDim == ODim and must satisfy SizeORaw % B1SrcScalarPerVector == 0 - EXPECT_FALSE(DeviceInstanceWrapper_G2M1N1K1O1_TNTT_BF16_M128_N128_K32_O128{}.IsSupported(128, 128, 128, 129)); + EXPECT_FALSE( + DeviceInstanceWrapper_G2M1N1K1O1_TNTT_BF16_M128_N128_K32_O128{} + .IsSupported(128, 128, 120, 128)); + EXPECT_FALSE(DeviceInstanceWrapper_G2M1N1K1O1_TNTT_BF16_M128_N128_K32_O128< + GemmSpecialization::MNKPadding>{} + .IsSupported(128, 128, 128, 120)); + // Kernel can't support odd K size because SrcVectorDim == KDim and must satisfy SizeKRaw % + // 
ABSrcScalarPerVector == 0 + EXPECT_FALSE(DeviceInstanceWrapper_G2M1N1K1O1_TNTT_BF16_M128_N128_K32_O128< + GemmSpecialization::MNKOPadding>{} + .IsSupported(128, 128, 129, 128)); + EXPECT_FALSE(DeviceInstanceWrapper_G2M1N1K1O1_TNTT_BF16_M128_N128_K32_O128< + GemmSpecialization::MNKOPadding>{} + .IsSupported(128, 128, 130, 128)); + // Kernel can't support odd O size because SrcVectorDim == ODim and must satisfy SizeORaw % + // B1SrcScalarPerVector == 0 + EXPECT_FALSE(DeviceInstanceWrapper_G2M1N1K1O1_TNTT_BF16_M128_N128_K32_O128< + GemmSpecialization::MNKOPadding>{} + .IsSupported(128, 128, 128, 129)); // clang-format on } diff --git a/test/batched_gemm_softmax_gemm_permute/test_batched_gemm_softmax_gemm_permute_fp16_xdl.cpp b/test/batched_gemm_softmax_gemm_permute/test_batched_gemm_softmax_gemm_permute_fp16_xdl.cpp index 81d404109f..3a86736f44 100644 --- a/test/batched_gemm_softmax_gemm_permute/test_batched_gemm_softmax_gemm_permute_fp16_xdl.cpp +++ b/test/batched_gemm_softmax_gemm_permute/test_batched_gemm_softmax_gemm_permute_fp16_xdl.cpp @@ -3,6 +3,7 @@ #include "gtest/gtest.h" #include "test_batched_gemm_softmax_gemm_permute_util.hpp" +#include "test_batched_gemm_device_utils.hpp" template class TestBatchedGemmMaskingScaleSoftmaxGemmPermuteFP16 @@ -132,9 +133,19 @@ using ck::tensor_operation::device::GemmSpecialization; TEST(TestBatchedGemmMaskingScaleSoftmaxGemmPermuteInterface, GemmSpecializationSizeMatch) { + // Get device capability tier + auto deviceTier = ck::test::DetermineDeviceTier(); + int P = 120; // requires padding int Q = 128; // do not require padding + // For lower-end devices, we might need to skip some tests + if(deviceTier == ck::test::DeviceCapabilityTier::LOW) + { + std::cout << "Skipping GemmSpecialization tests for low-resource device" << std::endl; + return; + } + // IsSupported(M, N, K, O) // clang-format off EXPECT_TRUE(DeviceInstanceWrapper_G2M1N1K1O1_TNTT_FP16_M128_N128_K32_O128{}.IsSupported(Q, Q, Q, Q)); From 
4e9b76f88c572a6c54f34cc6467b96279c0e86e4 Mon Sep 17 00:00:00 2001 From: carlushuang Date: Tue, 6 May 2025 17:32:07 +0800 Subject: [PATCH 096/443] [CK_TILE] optimize moe sorting kernel, boost large context case up to 20x (#2153) * combine 2-3 as single stage * support zeroing * improve long tokens * update specialization * b16 ws * 8bit topk optimize * update 15 example --- .../ck_tile/13_moe_sorting/moe_sorting.cpp | 3 +- .../13_moe_sorting/moe_sorting_api.cpp | 225 +++-- .../13_moe_sorting/moe_sorting_api.hpp | 2 +- .../13_moe_sorting/script/smoke_test.sh | 6 + example/ck_tile/15_fused_moe/fused_moe.hpp | 2 + .../ck_tile/15_fused_moe/fused_moesorting.hpp | 1 + .../15_fused_moe/instances/fused_moe_api.cpp | 6 + .../instances/fused_moesorting_api.cpp | 208 ++++- example/ck_tile/15_fused_moe/main.cpp | 3 +- include/ck_tile/core.hpp | 1 + include/ck_tile/core/arch/arch.hpp | 9 + .../ck_tile/core/arch/workgroup_barrier.hpp | 65 ++ include/ck_tile/core/config.hpp | 2 +- .../fused_moe/kernel/moe_sorting_kernel.hpp | 789 +++++++++++++++++- .../fused_moe/kernel/moe_sorting_problem.hpp | 9 +- 15 files changed, 1216 insertions(+), 115 deletions(-) create mode 100644 include/ck_tile/core/arch/workgroup_barrier.hpp diff --git a/example/ck_tile/13_moe_sorting/moe_sorting.cpp b/example/ck_tile/13_moe_sorting/moe_sorting.cpp index e59fcaedad..ce689a370c 100644 --- a/example/ck_tile/13_moe_sorting/moe_sorting.cpp +++ b/example/ck_tile/13_moe_sorting/moe_sorting.cpp @@ -153,9 +153,8 @@ bool test_moe_sorting(ck_tile::ArgParser args) local_expert_masking_dev.ToDevice(local_expert_masking_host.data()); // if return zero, means no need workspace, can set moe_sorting_args.p_ws to nullptr - ck_tile::index_t workspace_size = moe_sorting_get_workspace_size(tokens, num_experts); + ck_tile::index_t workspace_size = moe_sorting_get_workspace_size(tokens, num_experts, topk); ck_tile::DeviceMem moe_sorting_ws(workspace_size != 0 ? 
workspace_size : 0); - if(workspace_size != 0) moe_sorting_ws.SetZero(); // note, clear here!!!! diff --git a/example/ck_tile/13_moe_sorting/moe_sorting_api.cpp b/example/ck_tile/13_moe_sorting/moe_sorting_api.cpp index 109ec1b157..305cf118d2 100644 --- a/example/ck_tile/13_moe_sorting/moe_sorting_api.cpp +++ b/example/ck_tile/13_moe_sorting/moe_sorting_api.cpp @@ -7,6 +7,14 @@ #define MOE_SORTING_USE_EX_KERNEL 1 #endif +#ifndef MOE_SORTING_SUPPORT_LARGE_EXPERT +#define MOE_SORTING_SUPPORT_LARGE_EXPERT 0 +#endif + +#ifndef MOE_SORTING_SUPPORT_LARGE_TOPK +#define MOE_SORTING_SUPPORT_LARGE_TOPK 0 +#endif + #if !MOE_SORTING_USE_EX_KERNEL #define MOE_SORTING_DISPATCH_ETILE(unroll_num_, expert_tile_) \ @@ -153,7 +161,7 @@ float moe_sorting(moe_sorting_trait t, moe_sorting_args a, ck_tile::stream_confi } } #else - if(moe_sorting_get_workspace_size(a.tokens, a.num_experts) != 0) + if(moe_sorting_get_workspace_size(a.tokens, a.num_experts, a.topk) != 0) { return moe_sorting_mp(t, a, s); } @@ -171,57 +179,107 @@ float moe_sorting(moe_sorting_trait t, moe_sorting_args a, ck_tile::stream_confi return -1; } -#define MOE_SORTING_MP_0(unroll_num_, expert_masking_) \ - [&]() { \ - constexpr ck_tile::index_t unroll_num = unroll_num_; \ - constexpr bool expert_masking = expert_masking_; \ - using ms_problem = \ - ck_tile::MoeSortingProblemMp; \ - using kernel = ck_tile::MoeSortingMultiPhaseKernel_P0; \ - auto kargs = kernel::MakeKargs(a); \ - const dim3 grids = kernel::GridSize(a); \ - const dim3 blocks = kernel::BlockSize(a); \ - return ck_tile::make_kernel(kernel{}, grids, blocks, 0, kargs); \ +#define MOE_SORTING_MP_0(mesh_type_, unroll_num_, expert_masking_) \ + [&]() { \ + constexpr ck_tile::index_t unroll_num = unroll_num_; \ + constexpr bool expert_masking = expert_masking_; \ + using ms_problem = ck_tile::MoeSortingProblemMp; \ + using kernel = ck_tile::MoeSortingMultiPhaseKernel_P0; \ + auto kargs = kernel::MakeKargs(a); \ + const dim3 grids = kernel::GridSize(a); \ + 
const dim3 blocks = kernel::BlockSize(a); \ + return ck_tile::make_kernel(kernel{}, grids, blocks, 0, kargs); \ }() -#define MOE_SORTING_MP_1(unroll_num_, expert_masking_) \ - [&]() { \ - constexpr ck_tile::index_t unroll_num = unroll_num_; \ - constexpr bool expert_masking = expert_masking_; \ - using ms_problem = \ - ck_tile::MoeSortingProblemMp; \ - using kernel = ck_tile::MoeSortingMultiPhaseKernel_P1; \ - auto kargs = kernel::MakeKargs(a); \ - const dim3 grids = kernel::GridSize(a); \ - const dim3 blocks = kernel::BlockSize(a); \ - return ck_tile::make_kernel(kernel{}, grids, blocks, 0, kargs); \ +#define MOE_SORTING_MP_1(mesh_type_, unroll_num_, expert_masking_) \ + [&]() { \ + constexpr ck_tile::index_t unroll_num = unroll_num_; \ + constexpr bool expert_masking = expert_masking_; \ + using ms_problem = ck_tile::MoeSortingProblemMp; \ + using kernel = ck_tile::MoeSortingMultiPhaseKernel_P1; \ + auto kargs = kernel::MakeKargs(a); \ + const dim3 grids = kernel::GridSize(a); \ + const dim3 blocks = kernel::BlockSize(a); \ + return ck_tile::make_kernel(kernel{}, grids, blocks, 0, kargs); \ + }() +#if MOE_SORTING_SUPPORT_LARGE_EXPERT +#define MOE_SORTING_MP_2(mesh_type_, unroll_num_, expert_masking_) \ + [&]() { \ + constexpr ck_tile::index_t unroll_num = unroll_num_; \ + constexpr bool expert_masking = expert_masking_; \ + using ms_problem = ck_tile::MoeSortingProblemMp; \ + using kernel = ck_tile::MoeSortingMultiPhaseKernel_P2; \ + auto kargs = kernel::MakeKargs(a); \ + const dim3 grids = kernel::GridSize(a); \ + const dim3 blocks = kernel::BlockSize(a); \ + return ck_tile::make_kernel(kernel{}, grids, blocks, 0, kargs); \ }() -#define MOE_SORTING_MP_2(unroll_num_, expert_masking_) \ - [&]() { \ - constexpr ck_tile::index_t unroll_num = unroll_num_; \ - constexpr bool expert_masking = expert_masking_; \ - using ms_problem = \ - ck_tile::MoeSortingProblemMp; \ - using kernel = ck_tile::MoeSortingMultiPhaseKernel_P2; \ - auto kargs = kernel::MakeKargs(a); \ - 
const dim3 grids = kernel::GridSize(a); \ - const dim3 blocks = kernel::BlockSize(a); \ - return ck_tile::make_kernel(kernel{}, grids, blocks, 0, kargs); \ +#define MOE_SORTING_MP_3(mesh_type_, unroll_num_, expert_masking_) \ + [&]() { \ + constexpr ck_tile::index_t unroll_num = unroll_num_; \ + constexpr bool expert_masking = expert_masking_; \ + using ms_problem = ck_tile::MoeSortingProblemMp; \ + using kernel = ck_tile::MoeSortingMultiPhaseKernel_P3; \ + auto kargs = kernel::MakeKargs(a); \ + const dim3 grids = kernel::GridSize(a); \ + const dim3 blocks = kernel::BlockSize(a); \ + return ck_tile::make_kernel(kernel{}, grids, blocks, 0, kargs); \ + }() +#endif + +#define MOE_SORTING_MP_23(mesh_type_, unroll_num_, expert_masking_) \ + [&]() { \ + constexpr ck_tile::index_t unroll_num = unroll_num_; \ + constexpr bool expert_masking = expert_masking_; \ + using ms_problem = ck_tile::MoeSortingProblemMp; \ + using kernel = ck_tile::MoeSortingMultiPhaseKernel_P23; \ + auto kargs = kernel::MakeKargs(a); \ + const dim3 grids = kernel::GridSize(a); \ + const dim3 blocks = kernel::BlockSize(a); \ + const auto lds_size = kernel::GetSmemSize(a); \ + return ck_tile::make_kernel(kernel{}, grids, blocks, lds_size, kargs); \ }() -#define MOE_SORTING_MP_3(unroll_num_, expert_masking_) \ - [&]() { \ - constexpr ck_tile::index_t unroll_num = unroll_num_; \ - constexpr bool expert_masking = expert_masking_; \ - using ms_problem = \ - ck_tile::MoeSortingProblemMp; \ - using kernel = ck_tile::MoeSortingMultiPhaseKernel_P3; \ - auto kargs = kernel::MakeKargs(a); \ - const dim3 grids = kernel::GridSize(a); \ - const dim3 blocks = kernel::BlockSize(a); \ - return ck_tile::make_kernel(kernel{}, grids, blocks, 0, kargs); \ - }() +#define MOR_SORTING_MP_DISPATCH_(mesh_type_, token_vec_0_, token_vec_1_, token_vec_23_) \ + if(t.local_expert_masking) \ + { \ + float ave_time = \ + ck_tile::launch_kernel(s, \ + MOE_SORTING_MP_0(mesh_type_, token_vec_0_, true), \ + MOE_SORTING_MP_1(mesh_type_, 
token_vec_1_, true), \ + MOE_SORTING_MP_23(mesh_type_, token_vec_23_, true)); \ + return ave_time; \ + } \ + else \ + { \ + float ave_time = \ + ck_tile::launch_kernel(s, \ + MOE_SORTING_MP_0(mesh_type_, token_vec_0_, false), \ + MOE_SORTING_MP_1(mesh_type_, token_vec_1_, false), \ + MOE_SORTING_MP_23(mesh_type_, token_vec_23_, false)); \ + return ave_time; \ + } float moe_sorting_mp(moe_sorting_trait t, moe_sorting_args a, ck_tile::stream_config s) { @@ -230,29 +288,74 @@ float moe_sorting_mp(moe_sorting_trait t, moe_sorting_args a, ck_tile::stream_co using ms_index_t = ck_tile::index_t; using ms_weight_type = float; - if(t.local_expert_masking) + if(ck_tile::impl::moe_sorting_get_smem_size_p23(a.num_experts) > + ck_tile::get_smem_capacity()) { - float ave_time = ck_tile::launch_kernel(s, - MOE_SORTING_MP_0(1, true), - MOE_SORTING_MP_1(1, true), - MOE_SORTING_MP_2(1, true), - MOE_SORTING_MP_3(1, true)); - return ave_time; +#if MOE_SORTING_SUPPORT_LARGE_EXPERT + if(t.local_expert_masking) + { + float ave_time = ck_tile::launch_kernel(s, + MOE_SORTING_MP_0(ms_index_t, 1, true), + MOE_SORTING_MP_1(ms_index_t, 1, true), + MOE_SORTING_MP_2(ms_index_t, 1, true), + MOE_SORTING_MP_3(ms_index_t, 1, true)); + return ave_time; + } + else + { + float ave_time = ck_tile::launch_kernel(s, + MOE_SORTING_MP_0(ms_index_t, 1, false), + MOE_SORTING_MP_1(ms_index_t, 1, false), + MOE_SORTING_MP_2(ms_index_t, 1, false), + MOE_SORTING_MP_3(ms_index_t, 1, false)); + return ave_time; + } +#else + printf("do not support large expert %d\n", a.num_experts); + return -1; +#endif } else { - float ave_time = ck_tile::launch_kernel(s, - MOE_SORTING_MP_0(1, false), - MOE_SORTING_MP_1(1, false), - MOE_SORTING_MP_2(1, false), - MOE_SORTING_MP_3(1, false)); - return ave_time; + ck_tile::index_t mesh_byte_size = + ck_tile::impl::moe_sorting_mesh_byte_size(a.tokens, a.num_experts, a.topk); + if(mesh_byte_size == 1) + { + if(a.tokens * a.topk % 4 == 0) + { + MOR_SORTING_MP_DISPATCH_(uint8_t, 4, 16, 16) 
+ } + else + { + MOR_SORTING_MP_DISPATCH_(uint8_t, 1, 16, 16) + } + } + else if(mesh_byte_size == 2) + { +#if MOE_SORTING_SUPPORT_LARGE_TOPK + if(a.tokens * a.topk % 4 == 0) + { + MOR_SORTING_MP_DISPATCH_(uint16_t, 4, 8, 8) + } + else + { + MOR_SORTING_MP_DISPATCH_(uint16_t, 1, 8, 8) + } +#else + printf("do not support large topk %d\n", a.topk); + return -1; +#endif + } + else + { + MOR_SORTING_MP_DISPATCH_(ck_tile::index_t, 1, 1, 1) + } } } return -1; } -int moe_sorting_get_workspace_size(int tokens, int num_experts) +int moe_sorting_get_workspace_size(int tokens, int num_experts, int topk) { - return ck_tile::moe_sorting_get_workspace_size(tokens, num_experts); + return ck_tile::moe_sorting_get_workspace_size(tokens, num_experts, topk); } diff --git a/example/ck_tile/13_moe_sorting/moe_sorting_api.hpp b/example/ck_tile/13_moe_sorting/moe_sorting_api.hpp index b47ae9013b..0fe8d81e70 100644 --- a/example/ck_tile/13_moe_sorting/moe_sorting_api.hpp +++ b/example/ck_tile/13_moe_sorting/moe_sorting_api.hpp @@ -22,6 +22,6 @@ struct moe_sorting_args : public ck_tile::MoeSortingHostArgs // if return non zero, means need workspace, you need to allocate a GPU buffer // and set to moe_sorting_args.p_ws // NOTE: workspace size are required to clear zero before use the API -int moe_sorting_get_workspace_size(int tokens, int num_experts); +int moe_sorting_get_workspace_size(int tokens, int num_experts, int topk); float moe_sorting(moe_sorting_trait t, moe_sorting_args a, ck_tile::stream_config s); float moe_sorting_mp(moe_sorting_trait t, moe_sorting_args a, ck_tile::stream_config s); diff --git a/example/ck_tile/13_moe_sorting/script/smoke_test.sh b/example/ck_tile/13_moe_sorting/script/smoke_test.sh index cf2c2e164b..fbfb10822c 100644 --- a/example/ck_tile/13_moe_sorting/script/smoke_test.sh +++ b/example/ck_tile/13_moe_sorting/script/smoke_test.sh @@ -26,3 +26,9 @@ $EXE -t=13 -e=64 -k=3 -local_eid=4,5,6,7,8,9,10,11 $EXE -t=99 -e=33 -k=9 -local_eid=6,10,11,15,19 $EXE -t=80 
-e=99 -k=10 -local_eid=0,8,12,33 $EXE -t=11 -e=256 -k=5 -local_eid=99,110,129 +$EXE -t=128 -e=128 -k=6 -moe_buf_size=163840 +$EXE -t=8192 -e=32 -k=5 -moe_buf_size=163840 +$EXE -t=8192 -e=32 -k=8 -moe_buf_size=163840 +$EXE -t=8192 -e=256 -k=5 -moe_buf_size=163840 +$EXE -t=8192 -e=256 -k=8 -moe_buf_size=163840 +$EXE -t=163840 -e=256 -k=8 -moe_buf_size=163840 \ No newline at end of file diff --git a/example/ck_tile/15_fused_moe/fused_moe.hpp b/example/ck_tile/15_fused_moe/fused_moe.hpp index b354d1d347..46425384cc 100644 --- a/example/ck_tile/15_fused_moe/fused_moe.hpp +++ b/example/ck_tile/15_fused_moe/fused_moe.hpp @@ -56,4 +56,6 @@ struct fused_moe_traits bool local_expert_masking; // if mask experts as local expert }; +// if return zero, no ws needed +int fused_moe_get_workspace_size(int tokens, int num_experts, int topk); float fused_moe(fused_moe_traits, fused_moe_args, const ck_tile::stream_config&); diff --git a/example/ck_tile/15_fused_moe/fused_moesorting.hpp b/example/ck_tile/15_fused_moe/fused_moesorting.hpp index a3ff8c5bf7..11e1c6e531 100644 --- a/example/ck_tile/15_fused_moe/fused_moesorting.hpp +++ b/example/ck_tile/15_fused_moe/fused_moesorting.hpp @@ -18,4 +18,5 @@ struct fused_moesorting_args : public ck_tile::MoeSortingHostArgs { }; +int fused_moe_get_workspace_size(int tokens, int num_experts, int topk); float fused_moesorting(fused_moesorting_trait t, fused_moesorting_args a, ck_tile::stream_config s); diff --git a/example/ck_tile/15_fused_moe/instances/fused_moe_api.cpp b/example/ck_tile/15_fused_moe/instances/fused_moe_api.cpp index f887d57aa9..b3515b1bec 100644 --- a/example/ck_tile/15_fused_moe/instances/fused_moe_api.cpp +++ b/example/ck_tile/15_fused_moe/instances/fused_moe_api.cpp @@ -2,6 +2,12 @@ // Copyright (c) 2024, Advanced Micro Devices, Inc. All rights reserved. 
#include "fused_moe.hpp" +#include "ck_tile/ops/fused_moe.hpp" + +int fused_moe_get_workspace_size(int tokens, int num_experts, int topk) +{ + return ck_tile::moe_sorting_get_workspace_size(tokens, num_experts, topk); +} float fused_moe(fused_moe_traits t, fused_moe_args a, const ck_tile::stream_config& s) { diff --git a/example/ck_tile/15_fused_moe/instances/fused_moesorting_api.cpp b/example/ck_tile/15_fused_moe/instances/fused_moesorting_api.cpp index 7aedaa9317..0d83c48d02 100644 --- a/example/ck_tile/15_fused_moe/instances/fused_moesorting_api.cpp +++ b/example/ck_tile/15_fused_moe/instances/fused_moesorting_api.cpp @@ -7,6 +7,14 @@ #define MOE_SORTING_USE_EX_KERNEL 1 #endif +#ifndef MOE_SORTING_SUPPORT_LARGE_EXPERT +#define MOE_SORTING_SUPPORT_LARGE_EXPERT 0 +#endif + +#ifndef MOE_SORTING_SUPPORT_LARGE_TOPK +#define MOE_SORTING_SUPPORT_LARGE_TOPK 0 +#endif + #if !MOE_SORTING_USE_EX_KERNEL #define MOE_SORTING_DISPATCH_ETILE(unroll_num_, expert_tile_) \ @@ -107,6 +115,10 @@ } #endif +float fused_moesorting_mp(fused_moesorting_trait t, + fused_moesorting_args a, + ck_tile::stream_config s); + float fused_moesorting(fused_moesorting_trait t, fused_moesorting_args a, ck_tile::stream_config s) { if(t.weight_type == "fp32" && t.index_type == "int32") @@ -153,18 +165,198 @@ float fused_moesorting(fused_moesorting_trait t, fused_moesorting_args a, ck_til } } #else - using index_t = ck_tile::index_t; - using ms_weight_type = float; - auto [r_, c_] = ck_tile::moe_sorting_get_smem_row_col(a.tokens, a.num_experts); - auto sub_token_ = r_ - 2; - r_ = (r_ - 2) / 8; - bool is_sub_token_onshot = a.tokens <= sub_token_; + if(fused_moe_get_workspace_size(a.tokens, a.num_experts, a.topk) != 0) + { + return fused_moesorting_mp(t, a, s); + } + using index_t = ck_tile::index_t; + using ms_weight_type = float; + auto sub_token_ = ck_tile::moe_sorting_get_sub_token(a.tokens, a.num_experts); + auto row_ = sub_token_ / 8; + bool is_sub_token_onshot = a.tokens <= sub_token_; bool 
is_local_expert_masking = t.local_expert_masking; - (void)c_; - MOE_SORTING_DISPATCH_EMASK_(r_); + MOE_SORTING_DISPATCH_EMASK_(row_); // MOE_SORTING_DISPATCH_ETILE(0, 0); #endif } return -1; } + +#define MOE_SORTING_MP_0(mesh_type_, unroll_num_, expert_masking_) \ + [&]() { \ + constexpr ck_tile::index_t unroll_num = unroll_num_; \ + constexpr bool expert_masking = expert_masking_; \ + using ms_problem = ck_tile::MoeSortingProblemMp; \ + using kernel = ck_tile::MoeSortingMultiPhaseKernel_P0; \ + auto kargs = kernel::MakeKargs(a); \ + const dim3 grids = kernel::GridSize(a); \ + const dim3 blocks = kernel::BlockSize(a); \ + return ck_tile::make_kernel(kernel{}, grids, blocks, 0, kargs); \ + }() + +#define MOE_SORTING_MP_1(mesh_type_, unroll_num_, expert_masking_) \ + [&]() { \ + constexpr ck_tile::index_t unroll_num = unroll_num_; \ + constexpr bool expert_masking = expert_masking_; \ + using ms_problem = ck_tile::MoeSortingProblemMp; \ + using kernel = ck_tile::MoeSortingMultiPhaseKernel_P1; \ + auto kargs = kernel::MakeKargs(a); \ + const dim3 grids = kernel::GridSize(a); \ + const dim3 blocks = kernel::BlockSize(a); \ + return ck_tile::make_kernel(kernel{}, grids, blocks, 0, kargs); \ + }() +#if MOE_SORTING_SUPPORT_LARGE_EXPERT +#define MOE_SORTING_MP_2(mesh_type_, unroll_num_, expert_masking_) \ + [&]() { \ + constexpr ck_tile::index_t unroll_num = unroll_num_; \ + constexpr bool expert_masking = expert_masking_; \ + using ms_problem = ck_tile::MoeSortingProblemMp; \ + using kernel = ck_tile::MoeSortingMultiPhaseKernel_P2; \ + auto kargs = kernel::MakeKargs(a); \ + const dim3 grids = kernel::GridSize(a); \ + const dim3 blocks = kernel::BlockSize(a); \ + return ck_tile::make_kernel(kernel{}, grids, blocks, 0, kargs); \ + }() + +#define MOE_SORTING_MP_3(mesh_type_, unroll_num_, expert_masking_) \ + [&]() { \ + constexpr ck_tile::index_t unroll_num = unroll_num_; \ + constexpr bool expert_masking = expert_masking_; \ + using ms_problem = 
ck_tile::MoeSortingProblemMp; \ + using kernel = ck_tile::MoeSortingMultiPhaseKernel_P3; \ + auto kargs = kernel::MakeKargs(a); \ + const dim3 grids = kernel::GridSize(a); \ + const dim3 blocks = kernel::BlockSize(a); \ + return ck_tile::make_kernel(kernel{}, grids, blocks, 0, kargs); \ + }() +#endif + +#define MOE_SORTING_MP_23(mesh_type_, unroll_num_, expert_masking_) \ + [&]() { \ + constexpr ck_tile::index_t unroll_num = unroll_num_; \ + constexpr bool expert_masking = expert_masking_; \ + using ms_problem = ck_tile::MoeSortingProblemMp; \ + using kernel = ck_tile::MoeSortingMultiPhaseKernel_P23; \ + auto kargs = kernel::MakeKargs(a); \ + const dim3 grids = kernel::GridSize(a); \ + const dim3 blocks = kernel::BlockSize(a); \ + const auto lds_size = kernel::GetSmemSize(a); \ + return ck_tile::make_kernel(kernel{}, grids, blocks, lds_size, kargs); \ + }() + +#define MOR_SORTING_MP_DISPATCH_(mesh_type_, token_vec_0_, token_vec_1_, token_vec_23_) \ + if(t.local_expert_masking) \ + { \ + float ave_time = \ + ck_tile::launch_kernel(s, \ + MOE_SORTING_MP_0(mesh_type_, token_vec_0_, true), \ + MOE_SORTING_MP_1(mesh_type_, token_vec_1_, true), \ + MOE_SORTING_MP_23(mesh_type_, token_vec_23_, true)); \ + return ave_time; \ + } \ + else \ + { \ + float ave_time = \ + ck_tile::launch_kernel(s, \ + MOE_SORTING_MP_0(mesh_type_, token_vec_0_, false), \ + MOE_SORTING_MP_1(mesh_type_, token_vec_1_, false), \ + MOE_SORTING_MP_23(mesh_type_, token_vec_23_, false)); \ + return ave_time; \ + } + +float fused_moesorting_mp(fused_moesorting_trait t, + fused_moesorting_args a, + ck_tile::stream_config s) +{ + if(t.weight_type == "fp32" && t.index_type == "int32") + { + using ms_index_t = ck_tile::index_t; + using ms_weight_type = float; + + if(ck_tile::impl::moe_sorting_get_smem_size_p23(a.num_experts) > + ck_tile::get_smem_capacity()) + { +#if MOE_SORTING_SUPPORT_LARGE_EXPERT + if(t.local_expert_masking) + { + float ave_time = ck_tile::launch_kernel(s, + MOE_SORTING_MP_0(ms_index_t, 
1, true), + MOE_SORTING_MP_1(ms_index_t, 1, true), + MOE_SORTING_MP_2(ms_index_t, 1, true), + MOE_SORTING_MP_3(ms_index_t, 1, true)); + return ave_time; + } + else + { + float ave_time = ck_tile::launch_kernel(s, + MOE_SORTING_MP_0(ms_index_t, 1, false), + MOE_SORTING_MP_1(ms_index_t, 1, false), + MOE_SORTING_MP_2(ms_index_t, 1, false), + MOE_SORTING_MP_3(ms_index_t, 1, false)); + return ave_time; + } +#else + printf("do not support large expert %d\n", a.num_experts); + return -1; +#endif + } + else + { + ck_tile::index_t mesh_byte_size = + ck_tile::impl::moe_sorting_mesh_byte_size(a.tokens, a.num_experts, a.topk); + if(mesh_byte_size == 1) + { + if(a.tokens * a.topk % 4 == 0) + { + MOR_SORTING_MP_DISPATCH_(uint8_t, 4, 16, 16) + } + else + { + MOR_SORTING_MP_DISPATCH_(uint8_t, 1, 16, 16) + } + } + else if(mesh_byte_size == 2) + { +#if MOE_SORTING_SUPPORT_LARGE_TOPK + if(a.tokens * a.topk % 4 == 0) + { + MOR_SORTING_MP_DISPATCH_(uint16_t, 4, 8, 8) + } + else + { + MOR_SORTING_MP_DISPATCH_(uint16_t, 1, 8, 8) + } +#else + printf("do not support large topk %d\n", a.topk); + return -1; +#endif + } + else + { + MOR_SORTING_MP_DISPATCH_(ck_tile::index_t, 1, 1, 1) + } + } + } + return -1; +} diff --git a/example/ck_tile/15_fused_moe/main.cpp b/example/ck_tile/15_fused_moe/main.cpp index cb93ce8907..da843891ce 100644 --- a/example/ck_tile/15_fused_moe/main.cpp +++ b/example/ck_tile/15_fused_moe/main.cpp @@ -372,7 +372,8 @@ bool run(const ck_tile::ArgParser& arg_parser) num_sorted_tiles_host.get_element_space_size_in_bytes()); // if return zero, means no need workspace, can set moe_sorting_args.p_ws to nullptr - ck_tile::index_t workspace_size = ck_tile::moe_sorting_get_workspace_size(tokens, experts); + ck_tile::index_t workspace_size = + ck_tile::moe_sorting_get_workspace_size(tokens, experts, topk); ck_tile::DeviceMem moe_sorting_ws(workspace_size != 0 ? workspace_size : 0); if(workspace_size != 0) moe_sorting_ws.SetZero(); // note, clear here!!!! 
diff --git a/include/ck_tile/core.hpp b/include/ck_tile/core.hpp index 821b3a8e84..b94157eaec 100644 --- a/include/ck_tile/core.hpp +++ b/include/ck_tile/core.hpp @@ -13,6 +13,7 @@ #include "ck_tile/core/arch/arch.hpp" #include "ck_tile/core/arch/generic_memory_space_atomic.hpp" #include "ck_tile/core/arch/utility.hpp" +#include "ck_tile/core/arch/workgroup_barrier.hpp" #include "ck_tile/core/config.hpp" #include "ck_tile/core/container/array.hpp" #include "ck_tile/core/container/container_helper.hpp" diff --git a/include/ck_tile/core/arch/arch.hpp b/include/ck_tile/core/arch/arch.hpp index 09de5f325f..1d3cf5c010 100644 --- a/include/ck_tile/core/arch/arch.hpp +++ b/include/ck_tile/core/arch/arch.hpp @@ -154,4 +154,13 @@ __host__ __device__ T CK_CONSTANT_ADDRESS_SPACE* cast_pointer_to_constant_addres #pragma clang diagnostic pop } +CK_TILE_HOST_DEVICE constexpr index_t get_smem_capacity() +{ +#if defined(__gfx950__) + return 163840; +#else + return 65536; +#endif +} + } // namespace ck_tile diff --git a/include/ck_tile/core/arch/workgroup_barrier.hpp b/include/ck_tile/core/arch/workgroup_barrier.hpp new file mode 100644 index 0000000000..827a490fcb --- /dev/null +++ b/include/ck_tile/core/arch/workgroup_barrier.hpp @@ -0,0 +1,65 @@ +// SPDX-License-Identifier: MIT +// Copyright (c) 2018-2025, Advanced Micro Devices, Inc. All rights reserved. 
+ +#pragma once + +#include "ck_tile/core/config.hpp" +#include "ck_tile/core/numeric/integer.hpp" + +namespace ck_tile { + +struct workgroup_barrier +{ + CK_TILE_DEVICE workgroup_barrier(uint32_t* ptr) : base_ptr(ptr) {} + + CK_TILE_DEVICE uint32_t ld(uint32_t offset = 0) + { + return __atomic_load_n(base_ptr + offset, __ATOMIC_RELAXED); + } + + CK_TILE_DEVICE void wait_eq(uint32_t value, uint32_t offset = 0) + { + if(threadIdx.x == 0) + { + while(ld(offset) != value) {} + } + __syncthreads(); + } + + CK_TILE_DEVICE void wait_lt(uint32_t value, uint32_t offset = 0) + { + if(threadIdx.x == 0) + { + while(ld(offset) < value) {} + } + __syncthreads(); + } + + CK_TILE_DEVICE void wait_set(uint32_t compare, uint32_t value, uint32_t offset = 0) + { + if(threadIdx.x == 0) + { + while(atomicCAS(base_ptr + offset, compare, value) != compare) {} + } + __syncthreads(); + } + + // enter critical zoon, assume buffer is zero when launch kernel + CK_TILE_DEVICE void aquire(uint32_t offset = 0) { wait_set(offset, 0, 1); } + + // exit critical zoon, assume buffer is zero when launch kernel + CK_TILE_DEVICE void release(uint32_t offset = 0) { wait_set(offset, 1, 0); } + + CK_TILE_DEVICE void inc(uint32_t offset = 0) + { + __syncthreads(); + if(threadIdx.x == 0) + { + atomicAdd(base_ptr + offset, 1); + } + } + + uint32_t* base_ptr; +}; + +} // namespace ck_tile diff --git a/include/ck_tile/core/config.hpp b/include/ck_tile/core/config.hpp index 414509e479..27133fa847 100644 --- a/include/ck_tile/core/config.hpp +++ b/include/ck_tile/core/config.hpp @@ -257,5 +257,5 @@ #endif #ifndef CK_TILE_WA_ISSUE_2028 -#define CK_TILE_WA_ISSUE_2028 1 +#define CK_TILE_WA_ISSUE_2028 0 #endif diff --git a/include/ck_tile/ops/fused_moe/kernel/moe_sorting_kernel.hpp b/include/ck_tile/ops/fused_moe/kernel/moe_sorting_kernel.hpp index 6a7ccd2472..664294fe18 100644 --- a/include/ck_tile/ops/fused_moe/kernel/moe_sorting_kernel.hpp +++ b/include/ck_tile/ops/fused_moe/kernel/moe_sorting_kernel.hpp @@ -19,6 
+19,10 @@ namespace ck_tile { #define MOE_SORTING_USE_EX_KERNEL 1 #endif +#ifndef MOE_SORTING_FUSE_MP_01 +#define MOE_SORTING_FUSE_MP_01 0 +#endif + // clang-format off // [indexing implementation-1] // using M_a as constexpr block_size to partition all tokens into different slices @@ -118,7 +122,7 @@ CK_TILE_HOST constexpr auto moe_sorting_get_smem_row_col(int tokens_, int num_ex int smem_cols = num_experts_ + 1; // usually experts is power of 2. padding here int smem_rows = [&](){ index_t target_occupancy_ = 2; - constexpr index_t total_ = 65536 / sizeof(int); + constexpr index_t total_ = get_smem_capacity() / sizeof(index_t); constexpr index_t sub_unroll = 8; constexpr index_t cumsum_bufs = 2; // 1 for cumsum, 1 for cnt // at lease 2 lines, one for sub_token unroll, one for cumsum @@ -250,7 +254,7 @@ struct MoeSortingKernel { #if MOE_SORTING_USE_EX_KERNEL auto [smem_rows, smem_cols] = moe_sorting_get_smem_row_col(h.tokens, h.num_experts); - return smem_rows * smem_cols * sizeof(int); + return smem_rows * smem_cols * sizeof(index_t); #else const auto blocks = BlockSize(h); // usually num_experts is power of 2, we pad 1 dword here for the row-size @@ -1063,17 +1067,43 @@ CK_TILE_HOST_DEVICE index_t moe_sorting_mp_mesh_stride(index_t tokens) return (tokens + chunk - 1) / chunk * chunk; }; -CK_TILE_HOST_DEVICE index_t moe_sorting_mp_mesh_elem(index_t tokens, index_t num_experts) +// 4-i32 mesh, 2-i16 mseh, 1-i8 mesh +CK_TILE_HOST index_t moe_sorting_mesh_byte_size(index_t tokens_, + index_t /*num_experts_*/, + index_t topk_) +{ + // small token case, let's run mesh with dword score board + if(tokens_ < 512) + return 4; + else + { + if(topk_ >= 255) + return 2; // 16bit mesh + else + return 1; // 8bit mesh if small enough + } +} + +CK_TILE_HOST_DEVICE index_t moe_sorting_mp_mesh_smem_size(index_t tokens, + index_t num_experts, + index_t topk) { index_t row_size = moe_sorting_mp_mesh_stride(tokens); - return num_experts * row_size; + index_t elem = num_experts * 
row_size; + return elem * moe_sorting_mesh_byte_size(tokens, num_experts, topk); }; -CK_TILE_HOST_DEVICE index_t moe_sorting_mp_cumsum_elem(index_t num_experts) +CK_TILE_HOST_DEVICE index_t moe_sorting_mp_cumsum_smem_size(index_t num_experts) { constexpr index_t chunk = 32; index_t row_size = num_experts + 1; - return (row_size + chunk - 1) / chunk * chunk; + return (row_size + chunk - 1) / chunk * chunk * sizeof(index_t); +}; + +CK_TILE_HOST_DEVICE index_t moe_sorting_mp_sem_smem_size() +{ + constexpr index_t chunk = 32; + return chunk * sizeof(index_t); }; template @@ -1245,15 +1275,20 @@ CK_TILE_HOST bool moe_sorting_is_oneshot(int tokens_, int num_experts_) } // return size in byte -CK_TILE_HOST index_t moe_sorting_mp_get_workspace_size(int tokens_, int num_experts_) +CK_TILE_HOST index_t moe_sorting_mp_get_workspace_size(int tokens_, int num_experts_, int topk_) { - index_t elem = impl::moe_sorting_mp_mesh_elem(tokens_, num_experts_) + - impl::moe_sorting_mp_cumsum_elem(num_experts_); - return elem * sizeof(index_t); + index_t s_ = impl::moe_sorting_mp_mesh_smem_size(tokens_, num_experts_, topk_) + + impl::moe_sorting_mp_cumsum_smem_size(num_experts_) +#if MOE_SORTING_FUSE_MP_01 + + impl::moe_sorting_mp_sem_smem_size(); +#else + ; +#endif + return s_; } // return size in byte -CK_TILE_HOST index_t moe_sorting_get_workspace_size(int tokens_, int num_experts_) +CK_TILE_HOST index_t moe_sorting_get_workspace_size(int tokens_, int num_experts_, int topk_) { #if 1 if(moe_sorting_is_oneshot(tokens_, num_experts_)) @@ -1262,10 +1297,10 @@ CK_TILE_HOST index_t moe_sorting_get_workspace_size(int tokens_, int num_experts } else { - return moe_sorting_mp_get_workspace_size(tokens_, num_experts_); + return moe_sorting_mp_get_workspace_size(tokens_, num_experts_, topk_); } #else - return moe_sorting_mp_get_workspace_size(tokens_, num_experts_); + return moe_sorting_mp_get_workspace_size(tokens_, num_experts_, topk_); #endif } @@ -1320,6 +1355,7 @@ struct 
MoeSortingMultiPhaseKernel_P0 using IndexType = typename Problem::IndexType; using WeightType = typename Problem::WeightType; + using MeshType = typename Problem::MeshType; static constexpr index_t BLOCK_SIZE = 256; static constexpr index_t OCCUPANCY = 2; // hard coded @@ -1371,22 +1407,21 @@ struct MoeSortingMultiPhaseKernel_P0 { using topk_id_t = ext_vector_t; - static_assert(Problem::SubTokenTile == 1 || Problem::SubTokenTile == 2 || - Problem::SubTokenTile == 4); - const topk_id_t* p_topk_ids = reinterpret_cast(kargs.p_topk_ids); - IndexType* p_expert_mesh = reinterpret_cast(kargs.p_expert_mesh); + MeshType* p_expert_mesh = reinterpret_cast(kargs.p_expert_mesh); index_t total_elem = kargs.tokens * kargs.topk_mdiv.divisor / Problem::SubTokenTile; #pragma unroll Problem::SubTokenTile - for(index_t i = blockIdx.x * BLOCK_SIZE + threadIdx.x; i < total_elem; i += blockDim.x) + for(index_t i = blockIdx.x * BLOCK_SIZE + threadIdx.x; i < total_elem; + i += gridDim.x * BLOCK_SIZE) { auto x = p_topk_ids[i]; static_for<0, Problem::SubTokenTile, 1>{}([&](auto j) { IndexType eid = x[j.value]; // ext_vector_type must use int to [] uint32_t curr_token_id, curr_topk_id; kargs.topk_mdiv.divmod(i * Problem::SubTokenTile + j, curr_token_id, curr_topk_id); - p_expert_mesh[eid * kargs.mesh_stride + curr_token_id] = curr_topk_id + 1; + p_expert_mesh[eid * kargs.mesh_stride + curr_token_id] = + (curr_topk_id + 1) & 0xffff; }); } } @@ -1400,6 +1435,7 @@ struct MoeSortingMultiPhaseKernel_P1 using IndexType = typename Problem::IndexType; using WeightType = typename Problem::WeightType; + using MeshType = typename Problem::MeshType; static constexpr index_t BLOCK_SIZE = 256; static constexpr index_t OCCUPANCY = 2; // hard coded @@ -1420,9 +1456,9 @@ struct MoeSortingMultiPhaseKernel_P1 Kargs k; k.p_local_expert_mask = h.p_local_expert_mask; k.p_expert_mesh = h.p_ws; - k.p_expert_cumsum = - reinterpret_cast(reinterpret_cast(h.p_ws) + - impl::moe_sorting_mp_mesh_elem(h.tokens, 
h.num_experts)); + k.p_expert_cumsum = reinterpret_cast( + reinterpret_cast(h.p_ws) + + impl::moe_sorting_mp_mesh_smem_size(h.tokens, h.num_experts, h.topk)); k.mesh_stride = impl::moe_sorting_mp_mesh_stride(h.tokens); return k; @@ -1444,13 +1480,11 @@ struct MoeSortingMultiPhaseKernel_P1 int eid = blockIdx.x; - constexpr index_t index_pack = 4; // always packed - using r_t = ext_vector_t; // always use int32x4 + constexpr index_t index_pack = Problem::SubTokenTile; // always packed + using r_t = ext_vector_t; // always use int32x4 r_t* p_expert_mesh = reinterpret_cast( - reinterpret_cast(kargs.p_expert_mesh) + eid * kargs.mesh_stride); + reinterpret_cast(kargs.p_expert_mesh) + eid * kargs.mesh_stride); - static_assert(Problem::SubTokenTile == 1 || Problem::SubTokenTile == 2 || - Problem::SubTokenTile == 4); const IndexType* p_local_expert_mask = static_cast(kargs.p_local_expert_mask); IndexType* p_expert_cumsum = reinterpret_cast(kargs.p_expert_cumsum); @@ -1502,6 +1536,197 @@ struct MoeSortingMultiPhaseKernel_P1 } }; +#if MOE_SORTING_FUSE_MP_01 +template +struct MoeSortingMultiPhaseKernel_P01 +{ + using Problem = remove_cvref_t; + + using IndexType = typename Problem::IndexType; + using WeightType = typename Problem::WeightType; + using MeshType = typename Problem::MeshType; + + static constexpr index_t BLOCK_SIZE = 256; + static constexpr index_t OCCUPANCY = 2; // hard coded + + typedef MoeSortingHostArgs MoeSortingKargs; + + using Hargs = MoeSortingHostArgs; + + struct Kargs + { + const void* p_topk_ids; // [tokens, topk] + const void* p_local_expert_mask; // [expert] + void* p_expert_mesh; // [expert, tokens] + void* p_expert_cumsum; // [expert + 1] + void* p_expert_sem; // [1] + index_t tokens; + index_t num_experts; + index_t mesh_stride; // mesh_stride for p_expert_mesh + index_t wg_count; // used for semaphore + mdiv topk_mdiv; + }; + + CK_TILE_HOST static constexpr auto get_num_cu() + { + index_t num_cu = [&]() { + hipDeviceProp_t dev_prop; + hipDevice_t 
dev; + HIP_CHECK_ERROR(hipGetDevice(&dev)); + HIP_CHECK_ERROR(hipGetDeviceProperties(&dev_prop, dev)); + return dev_prop.multiProcessorCount; + }(); + return num_cu; + } + + CK_TILE_HOST static constexpr auto MakeKargs(const Hargs& h) + { + Kargs k; + k.p_topk_ids = h.p_topk_ids; + k.p_local_expert_mask = h.p_local_expert_mask; + k.p_expert_mesh = h.p_ws; + k.p_expert_cumsum = reinterpret_cast( + reinterpret_cast(h.p_ws) + + impl::moe_sorting_mp_mesh_smem_size(h.tokens, h.num_experts, h.topk)); + k.p_expert_sem = reinterpret_cast( + reinterpret_cast(h.p_ws) + + impl::moe_sorting_mp_mesh_smem_size(h.tokens, h.num_experts, h.topk) + + impl::moe_sorting_mp_cumsum_smem_size(h.num_experts)); + k.tokens = h.tokens; + k.num_experts = h.num_experts; + k.mesh_stride = impl::moe_sorting_mp_mesh_stride(h.tokens); + k.wg_count = WGCounts(h); + k.topk_mdiv = mdiv{static_cast(h.topk)}; + return k; + } + + CK_TILE_HOST static constexpr auto GridSize(const Hargs&) { return get_num_cu() * OCCUPANCY; } + + CK_TILE_HOST static constexpr auto BlockSize(const Hargs&) { return dim3(BLOCK_SIZE); } + + CK_TILE_HOST static constexpr auto WGCounts(const Hargs& h) + { + index_t total_elem = h.tokens * h.topk / Problem::SubTokenTile; + index_t elem_cnt = (total_elem + BLOCK_SIZE - 1) / BLOCK_SIZE; + + // no more than grid_size + return min(elem_cnt, GridSize(h)); + } + + // in byte + CK_TILE_HOST static constexpr auto GetSmemSize() + { + return BLOCK_SIZE / warpSize * sizeof(IndexType); + } + + CK_TILE_DEVICE void operator()(Kargs kargs) const + { + workgroup_barrier wb{reinterpret_cast(kargs.p_expert_sem)}; + + { + using topk_id_t = ext_vector_t; + + const topk_id_t* p_topk_ids = reinterpret_cast(kargs.p_topk_ids); + IndexType* p_expert_mesh = reinterpret_cast(kargs.p_expert_mesh); + index_t total_elem = kargs.tokens * kargs.topk_mdiv.divisor / Problem::SubTokenTile; + +#pragma unroll Problem::SubTokenTile + for(index_t i = blockIdx.x * BLOCK_SIZE + threadIdx.x; i < total_elem; + i += 
BLOCK_SIZE * gridDim.x) + { + auto x = p_topk_ids[i]; + static_for<0, Problem::SubTokenTile, 1>{}([&](auto j) { + IndexType eid = x[j.value]; // ext_vector_type must use int to [] + uint32_t curr_token_id, curr_topk_id; + kargs.topk_mdiv.divmod( + i * Problem::SubTokenTile + j, curr_token_id, curr_topk_id); + p_expert_mesh[eid * kargs.mesh_stride + curr_token_id] = curr_topk_id + 1; + }); + } + if(static_cast(blockIdx.x) < kargs.wg_count) + { + wb.inc(); + } + } + + { + __shared__ char smem[GetSmemSize()]; + int eid = blockIdx.x; + + // early exist in case of extra atomic wait + if(eid >= kargs.num_experts) + return; + + wb.wait_lt(kargs.wg_count); + + for(; eid < kargs.num_experts; eid += gridDim.x) + { + // if(threadIdx.x == 0) + // printf("!!! bid:%d, eid:%d (%d, %d)\n", + // static_cast(blockIdx.x), + // eid, + // kargs.num_experts, + // static_cast(blockDim.x)); + constexpr index_t index_pack = 4; // always packed + using r_t = ext_vector_t; // always use int32x4 + r_t* p_expert_mesh = reinterpret_cast( + reinterpret_cast(kargs.p_expert_mesh) + eid * kargs.mesh_stride); + + const IndexType* p_local_expert_mask = + static_cast(kargs.p_local_expert_mask); + IndexType* p_expert_cumsum = reinterpret_cast(kargs.p_expert_cumsum); + + auto f_sum = [](auto x_, auto y_) { return x_ + y_; }; + + int loops = (kargs.mesh_stride / index_pack + BLOCK_SIZE - 1) / BLOCK_SIZE; + + if constexpr(Problem::LocalExpertMasking) + { + IndexType mask = p_local_expert_mask[eid]; + if(mask == 0) + continue; // skip + } + + index_t cnt = 0; // per-wave cnt + for(int i = 0; i < loops; i++) + { + int position = i * BLOCK_SIZE + threadIdx.x; + r_t v{0}; + if(position < (kargs.mesh_stride / index_pack)) + v = p_expert_mesh[position]; + index_t local_sum = 0; + static_for<0, index_pack, 1>{}( + [&](auto i_vec) { local_sum += v[i_vec.value] != 0 ? 
1 : 0; }); + cnt += impl::moe_sorting_wave_reduce(local_sum, f_sum); + } + + index_t lane_id = threadIdx.x % warpSize; + index_t wave_id = threadIdx.x / warpSize; + + // reduce cross wave + IndexType* s = reinterpret_cast(smem); + __syncthreads(); + if(lane_id == 0) + { + s[wave_id] = cnt; + } + __syncthreads(); + + if(threadIdx.x == 0) + { + index_t c = 0; + for(auto i = 0; i < (BLOCK_SIZE / warpSize); i++) + { + c += s[i]; + } + p_expert_cumsum[eid] = c; + } + } + } + } +}; +#endif + // token count cumsum template struct MoeSortingMultiPhaseKernel_P2 @@ -1510,6 +1735,7 @@ struct MoeSortingMultiPhaseKernel_P2 using IndexType = typename Problem::IndexType; using WeightType = typename Problem::WeightType; + using MeshType = typename Problem::MeshType; static constexpr index_t BLOCK_SIZE = 256; static constexpr index_t OCCUPANCY = 2; // hard coded @@ -1536,10 +1762,9 @@ struct MoeSortingMultiPhaseKernel_P2 { Kargs k; k.p_local_expert_mask = h.p_local_expert_mask; - // k.p_expert_mesh = h.p_ws; - k.p_expert_cumsum = - reinterpret_cast(reinterpret_cast(h.p_ws) + - impl::moe_sorting_mp_mesh_elem(h.tokens, h.num_experts)); + k.p_expert_cumsum = reinterpret_cast( + reinterpret_cast(h.p_ws) + + impl::moe_sorting_mp_mesh_smem_size(h.tokens, h.num_experts, h.topk)); k.p_total_tokens_post_pad = h.p_total_tokens_post_pad; k.p_sorted_expert_ids = h.p_sorted_expert_ids; @@ -1566,7 +1791,8 @@ struct MoeSortingMultiPhaseKernel_P2 // in byte CK_TILE_HOST_DEVICE static constexpr auto GetSmemSize() { - return 2 * BLOCK_SIZE * sizeof(IndexType); + // return 2 * BLOCK_SIZE * sizeof(IndexType); + return (4 + 2 * BLOCK_SIZE / warpSize) * sizeof(IndexType); } // reduce single pixel within a wave @@ -1718,6 +1944,7 @@ struct MoeSortingMultiPhaseKernel_P3 using IndexType = typename Problem::IndexType; using WeightType = typename Problem::WeightType; + using MeshType = typename Problem::MeshType; static constexpr index_t BLOCK_SIZE = 256; static constexpr index_t OCCUPANCY = 2; // hard coded 
@@ -1749,9 +1976,9 @@ struct MoeSortingMultiPhaseKernel_P3 k.p_sorted_token_ids = h.p_sorted_token_ids; k.p_sorted_weights = h.p_sorted_weights; k.p_expert_mesh = h.p_ws; - k.p_expert_cumsum = - reinterpret_cast(reinterpret_cast(h.p_ws) + - impl::moe_sorting_mp_mesh_elem(h.tokens, h.num_experts)); + k.p_expert_cumsum = reinterpret_cast( + reinterpret_cast(h.p_ws) + + impl::moe_sorting_mp_mesh_smem_size(h.tokens, h.num_experts, h.topk)); k.tokens = h.tokens; k.num_experts = h.num_experts; k.topk_mdiv = mdiv{static_cast(h.topk)}; @@ -1782,9 +2009,6 @@ struct MoeSortingMultiPhaseKernel_P3 const WeightType* p_weights = static_cast(kargs.p_weights); WeightType* p_sorted_weights = reinterpret_cast(kargs.p_sorted_weights); - static_assert(Problem::SubTokenTile == 1 || Problem::SubTokenTile == 2 || - Problem::SubTokenTile == 4); - int eid = blockIdx.x; int wave_id = threadIdx.x / warpSize; int lane_id = threadIdx.x % warpSize; @@ -1866,6 +2090,495 @@ struct MoeSortingMultiPhaseKernel_P3 } }; +namespace impl { +// we use dynamic LDS size here +CK_TILE_HOST constexpr auto moe_sorting_get_smem_size_p23(int num_experts_) +{ + constexpr index_t BLOCK_SIZE = 256; // hardcoded 256 + const index_t expert_cumsum_elem = num_experts_ + 1; + return (4 + 2 * BLOCK_SIZE / warpSize + expert_cumsum_elem) * sizeof(int); +} +} // namespace impl + +// token count cumsum +template +struct MoeSortingMultiPhaseKernel_P23 +{ + using Problem = remove_cvref_t; + + using IndexType = typename Problem::IndexType; + using WeightType = typename Problem::WeightType; + using MeshType = typename Problem::MeshType; + + static constexpr index_t BLOCK_SIZE = 256; + static constexpr index_t OCCUPANCY = 2; // hard coded + + typedef MoeSortingHostArgs MoeSortingKargs; + + using Hargs = MoeSortingHostArgs; + struct Kargs + { + const void* p_weights; + const void* p_local_expert_mask; // [expert] + void* p_expert_mesh; // [expert, tokens] + void* p_expert_cumsum; // [expert + 1] + void* p_total_tokens_post_pad; 
// [1] + void* p_sorted_expert_ids; + + void* p_sorted_token_ids; + void* p_sorted_weights; + void* p_moe_buf; + + index_t tokens; + index_t num_experts; + index_t mesh_stride; // mesh_stride for p_expert_mesh + mdiv unit_size_mdiv; + mdiv topk_mdiv; + long_index_t moe_buf_bytes; + }; + + CK_TILE_HOST static constexpr auto MakeKargs(const Hargs& h) + { + Kargs k; + k.p_weights = h.p_weights; + k.p_local_expert_mask = h.p_local_expert_mask; + k.p_expert_mesh = h.p_ws; + k.p_expert_cumsum = reinterpret_cast( + reinterpret_cast(h.p_ws) + + impl::moe_sorting_mp_mesh_smem_size(h.tokens, h.num_experts, h.topk)); + k.p_total_tokens_post_pad = h.p_total_tokens_post_pad; + k.p_sorted_expert_ids = h.p_sorted_expert_ids; + + k.p_sorted_token_ids = h.p_sorted_token_ids; + k.p_sorted_weights = h.p_sorted_weights; + + k.p_moe_buf = h.p_moe_buf; + + k.tokens = h.tokens; + k.num_experts = h.num_experts; + k.mesh_stride = impl::moe_sorting_mp_mesh_stride(h.tokens); + k.unit_size_mdiv = mdiv{static_cast(h.unit_size)}; + k.topk_mdiv = mdiv{static_cast(h.topk)}; + + k.moe_buf_bytes = h.moe_buf_bytes; + + return k; + } + + CK_TILE_HOST static constexpr auto GridSize(const Hargs& h) + { + // use 1 block to cumsum + // return dim3(1 + ck_tile::integer_divide_ceil(h.moe_buf_bytes, BLOCK_SIZE * 16)); + return dim3(h.num_experts + ck_tile::integer_divide_ceil(h.moe_buf_bytes, BLOCK_SIZE * 16)); + } + + CK_TILE_HOST static constexpr auto BlockSize(const Hargs&) { return dim3(BLOCK_SIZE); } + + // only use this at host ! 
+ CK_TILE_HOST static constexpr auto GetSmemSize(const Hargs& h) + { + const auto smem_23 = impl::moe_sorting_get_smem_size_p23(h.num_experts); + const auto smem_sf = BLOCK_SIZE * 4 * sizeof(IndexType); + return max(smem_23, smem_sf); + } + + // reduce single pixel within a wave + CK_TILE_DEVICE void operator()(Kargs kargs) const + { + if(static_cast(blockIdx.x) >= kargs.num_experts) + { + impl::moe_buf_set_zero_kernel( + reinterpret_cast(kargs.p_moe_buf), + kargs.moe_buf_bytes, + blockIdx.x - kargs.num_experts); + return; + } + + extern __shared__ char smem[]; + { + IndexType* s = reinterpret_cast(smem); + + const IndexType* p_local_expert_mask = + static_cast(kargs.p_local_expert_mask); + IndexType* p_expert_cumsum = reinterpret_cast(kargs.p_expert_cumsum); + IndexType* p_expert_cumsum_smem = s + 4 + 2 * BLOCK_SIZE / warpSize; + IndexType* p_total_tokens_post_pad = + reinterpret_cast(kargs.p_total_tokens_post_pad); + IndexType* p_sorted_expert_ids = + reinterpret_cast(kargs.p_sorted_expert_ids); + + const index_t loops = (kargs.num_experts + BLOCK_SIZE - 1) / BLOCK_SIZE; + index_t wave_id = threadIdx.x / warpSize; + index_t lane_id = threadIdx.x % warpSize; + + IndexType prev_cumsum_a = 0; + IndexType prev_cumsum_b = 0; + + for(index_t i = 0; i < loops; i++) + { + index_t position = i * BLOCK_SIZE + threadIdx.x; + IndexType a_ = 0; // token count for a expert + IndexType b_ = 0; // mask for a expert + if(position < kargs.num_experts) + { + a_ = p_expert_cumsum[position]; + if constexpr(Problem::LocalExpertMasking) + b_ = p_local_expert_mask[position]; + } + + int blocks_pers_expert = + kargs.unit_size_mdiv.div(a_ + kargs.unit_size_mdiv.divisor - 1); + // pad token + int padded_blocks_per_expert = [&]() { + int x_ = [&]() { + if constexpr(Problem::SkipExpertsWithZeroTokens) + { + // if local_cnt is zero, blocks_pers_expert will be zero + // this is what we want to achieve + return blocks_pers_expert; // * kargs.unit_size_mdiv.divisor; + } + else + { + return 
max(blocks_pers_expert, 1); + } + }(); + if constexpr(Problem::LocalExpertMasking) + { + return b_ ? x_ : 0; + } + else + return x_; + }(); + + IndexType cumsum_a = padded_blocks_per_expert; + IndexType cumsum_b = b_; + + // Note: we first cumsum local round, then add previous cumsum + impl::moe_sorting_wave_cumsum(cumsum_a); + impl::moe_sorting_wave_cumsum(cumsum_b); + + __syncthreads(); + if(lane_id == warpSize - 1) + { + s[4 + wave_id] = cumsum_a; + s[4 + wave_id + BLOCK_SIZE / warpSize] = cumsum_b; + } + + __syncthreads(); + + // reduce cross wave + static_for<0, BLOCK_SIZE / warpSize - 1, 1>{}([&](auto i_w) { + IndexType prev_a = s[4 + i_w]; + IndexType prev_b = s[4 + i_w + BLOCK_SIZE / warpSize]; + prev_a = wave_id > i_w ? prev_a : 0; // mask out + prev_b = wave_id > i_w ? prev_b : 0; // mask out + cumsum_a += prev_a; + cumsum_b += prev_b; + }); + + // Now let's add previous cumsum + cumsum_a += prev_cumsum_a; + cumsum_b += prev_cumsum_b; + + if(threadIdx.x == BLOCK_SIZE - 1) + { + s[2] = cumsum_a; // store the last cumsum + s[3] = cumsum_b; + } + + IndexType out_0 = cumsum_a - padded_blocks_per_expert; // exclusive cumsum tok cnt + IndexType out_1 = cumsum_b - b_; // exclusive cumsum mask cnt + + __syncthreads(); + prev_cumsum_a = s[2]; + prev_cumsum_b = s[3]; + + if(position < kargs.num_experts) + { + p_expert_cumsum_smem[position] = out_0 * kargs.unit_size_mdiv.divisor; + } + + { + if(blockIdx.x == 0) + { + if constexpr(Problem::LocalExpertMasking) + { + if(b_) + { + for(int j = 0; j < blocks_pers_expert; j++) + { + p_sorted_expert_ids[out_0 + j] = out_1; + } + } + } + else + { + for(int j = 0; j < blocks_pers_expert; j++) + { + p_sorted_expert_ids[out_0 + j] = position; + } + } + } + } + } + + if(threadIdx.x == 0) + { + auto total_tokens_post_pad = prev_cumsum_a * kargs.unit_size_mdiv.divisor; + if(blockIdx.x == 0) + p_total_tokens_post_pad[0] = total_tokens_post_pad; + p_expert_cumsum_smem[kargs.num_experts] = total_tokens_post_pad; + } + } + + 
__syncthreads(); + + { + const IndexType* p_local_expert_mask = + static_cast(kargs.p_local_expert_mask); + IndexType* s = reinterpret_cast(smem); + MeshType* p_expert_mesh = reinterpret_cast(kargs.p_expert_mesh); + IndexType* p_sorted_token_ids = reinterpret_cast(kargs.p_sorted_token_ids); + IndexType* p_expert_cumsum_smem = s + 4 + 2 * BLOCK_SIZE / warpSize; + const WeightType* p_weights = static_cast(kargs.p_weights); + WeightType* p_sorted_weights = reinterpret_cast(kargs.p_sorted_weights); + + int eid = blockIdx.x; + int wave_id = threadIdx.x / warpSize; + int lane_id = threadIdx.x % warpSize; + int e_start = p_expert_cumsum_smem[eid]; + int e_end = p_expert_cumsum_smem[eid + 1]; + if constexpr(Problem::SkipExpertsWithZeroTokens) + { + if(e_start == e_end) + return; + } + + if constexpr(Problem::LocalExpertMasking) + { + int e_mask = p_local_expert_mask[eid]; + if(e_mask == 0) + return; // skip empty expert + } + + // cumsum one by one + constexpr index_t index_pack = Problem::SubTokenTile; // always packed + using r_t = ext_vector_t; // always use int32x4 + using d_t = ext_vector_t; + int loops = (kargs.mesh_stride / index_pack + BLOCK_SIZE - 1) / BLOCK_SIZE; + int prev_cumsum = 0; + + for(int i = 0; i < loops; i++) + { + int i_token_pack = i * BLOCK_SIZE + threadIdx.x; + r_t x_v = 0; + if(i_token_pack < (kargs.tokens + index_pack - 1) / index_pack) + { + x_v = reinterpret_cast(p_expert_mesh + + eid * kargs.mesh_stride)[i_token_pack]; + } + + r_t x_r; +#if 0 + if constexpr(index_pack != 1) + { + // shuffle, we must have contiguout thread holds contiguout token + __syncthreads(); + reinterpret_cast(s)[threadIdx.x] = x_v; + __syncthreads(); + + static_for<0, index_pack, 1>{}([&](auto j_) { + constexpr auto j = j_.value; + x_r[j] = reinterpret_cast(s)[threadIdx.x + j * BLOCK_SIZE]; + }); + } +#else + x_r = x_v; +#endif + { +#if 0 +#pragma unroll + for(int j = 0; j < index_pack / 2; j++) + { + int i_token = i * BLOCK_SIZE * index_pack + threadIdx.x + j * 
BLOCK_SIZE; + index_t x = x_d[j]; + int i_topk = x - 1; // topk of this token + int i_show = x != 0 ? 1 : 0; // has this token or not + int cumsum = i_show; + impl::moe_sorting_wave_cumsum(cumsum); + + __syncthreads(); + if(lane_id == warpSize - 1) + { + s[4 + wave_id] = cumsum; + } + __syncthreads(); + + // reduce cross wave + static_for<0, BLOCK_SIZE / warpSize - 1, 1>{}([&](auto i_w) { + IndexType prev = s[4 + i_w]; + prev = wave_id > i_w ? prev : 0; // mask out + cumsum += prev; + }); + cumsum += prev_cumsum; // add previous round cumsum + if(threadIdx.x == BLOCK_SIZE - 1) + { + s[0] = cumsum; + } + __syncthreads(); + + int position = cumsum - i_show; + prev_cumsum = s[0]; // update the last cumsum + + if(i_show) + { +#if CK_TILE_REFERENCE_MOE_SORTING_MOCK_ID + p_sorted_token_ids[e_start + position] = + MOE_SORTING_MOCK_ID(i_token, i_topk); +#else + p_sorted_token_ids[e_start + position] = i_token; +#endif + p_sorted_weights[e_start + position] = + p_weights[i_token * kargs.topk_mdiv.divisor + i_topk]; + } + } +#endif + { + d_t i_topk; + d_t i_show; + // = 0; + int cumsum_store = 0; + + static_for<0, index_pack, 1>{}([&](auto j_) { + constexpr auto j = j_.value; + i_topk[j] = static_cast(x_r[j] - 1); + i_show[j] = static_cast(x_r[j] != 0 ? 1 : 0); + cumsum_store += i_show[j]; + }); + int cumsum = cumsum_store; + impl::moe_sorting_wave_cumsum(cumsum); + + __syncthreads(); + if(lane_id == warpSize - 1) + { + s[4 + wave_id] = cumsum; + } + __syncthreads(); + + // reduce cross wave + static_for<0, BLOCK_SIZE / warpSize - 1, 1>{}([&](auto i_w) { + IndexType prev = s[4 + i_w]; + prev = wave_id > i_w ? 
prev : 0; // mask out + cumsum += prev; + }); + cumsum += prev_cumsum; // add previous round cumsum + if(threadIdx.x == BLOCK_SIZE - 1) + { + s[0] = cumsum; + } + __syncthreads(); + prev_cumsum = s[0]; // update the last cumsum + + int position = cumsum - cumsum_store; + static_for<0, index_pack, 1>{}([&](auto j_) { + constexpr auto j = j_.value; + // int i_token = i * BLOCK_SIZE * index_pack + threadIdx.x + j * + // BLOCK_SIZE; + int i_token = + i * BLOCK_SIZE * index_pack + threadIdx.x * index_pack + j; + + if(i_show[j]) + { +#if CK_TILE_REFERENCE_MOE_SORTING_MOCK_ID + p_sorted_token_ids[e_start + position] = + MOE_SORTING_MOCK_ID(i_token, i_topk[j]); +#else + p_sorted_token_ids[e_start + position] = i_token; +#endif + p_sorted_weights[e_start + position] = + p_weights[i_token * kargs.topk_mdiv.divisor + i_topk[j]]; + } + position += i_show[j]; + }); + +#if 0 + int i_token = i * BLOCK_SIZE * index_pack + threadIdx.x * 2 + j * BLOCK_SIZE * 2; + index_t x = x_d[j]; + index_t x0 = static_cast(x & 0xffff); + index_t x1 = static_cast(x >> 16); + int i_topk_0 = x0 - 1; // topk of this token + int i_show_0 = x0 != 0 ? 1 : 0; // has this token or not + int i_topk_1 = x1 - 1; // topk of this token + int i_show_1 = x1 != 0 ? 1 : 0; // has this token or not + int cumsum = i_show_0 + i_show_1; + impl::moe_sorting_wave_cumsum(cumsum); + + __syncthreads(); + if(lane_id == warpSize - 1) + { + s[4 + wave_id] = cumsum; + } + __syncthreads(); + + // reduce cross wave + static_for<0, BLOCK_SIZE / warpSize - 1, 1>{}([&](auto i_w) { + IndexType prev = s[4 + i_w]; + prev = wave_id > i_w ? 
prev : 0; // mask out + cumsum += prev; + }); + cumsum += prev_cumsum; // add previous round cumsum + if(threadIdx.x == BLOCK_SIZE - 1) + { + s[0] = cumsum; + } + __syncthreads(); + + int position_0 = cumsum - i_show_0 - i_show_1; + prev_cumsum = s[0]; // update the last cumsum + + if(i_show_0) + { +#if CK_TILE_REFERENCE_MOE_SORTING_MOCK_ID + p_sorted_token_ids[e_start + position_0] = + MOE_SORTING_MOCK_ID(i_token, i_topk_0); +#else + p_sorted_token_ids[e_start + position_0] = i_token; +#endif + p_sorted_weights[e_start + position_0] = + p_weights[i_token * kargs.topk_mdiv.divisor + i_topk_0]; + } + + int position_1 = cumsum - i_show_1; + + if(i_show_1) + { +#if CK_TILE_REFERENCE_MOE_SORTING_MOCK_ID + p_sorted_token_ids[e_start + position_1] = + MOE_SORTING_MOCK_ID(i_token + 1, i_topk_1); +#else + p_sorted_token_ids[e_start + position_1] = i_token + 1; +#endif + p_sorted_weights[e_start + position_1] = + p_weights[(i_token + 1) * kargs.topk_mdiv.divisor + i_topk_1]; + } +#endif + } + } + } + + for(index_t i = e_start + prev_cumsum + threadIdx.x; i < e_end; i += BLOCK_SIZE) + { +#if CK_TILE_REFERENCE_MOE_SORTING_MOCK_ID + p_sorted_token_ids[i] = MOE_SORTING_MOCK_ID(kargs.tokens, kargs.topk_mdiv.divisor); +#else + p_sorted_token_ids[i] = tokens; +#endif + p_sorted_weights[i] = static_cast(0.0); + } + } + } +}; + #undef MOE_SORTING_MOCK_ID } // namespace ck_tile diff --git a/include/ck_tile/ops/fused_moe/kernel/moe_sorting_problem.hpp b/include/ck_tile/ops/fused_moe/kernel/moe_sorting_problem.hpp index a98e0d7652..39bc6ca93e 100644 --- a/include/ck_tile/ops/fused_moe/kernel/moe_sorting_problem.hpp +++ b/include/ck_tile/ops/fused_moe/kernel/moe_sorting_problem.hpp @@ -50,20 +50,23 @@ struct MoeSortingProblemEx }; template struct MoeSortingProblemMp { // TODO: this kernel only support warp per row using WeightType = remove_cvref_t; + using MeshType = remove_cvref_t; using IndexType = remove_cvref_t; static constexpr index_t SubTokenTile = SubTokenTile_; static constexpr 
bool LocalExpertMasking = LocalExpertMasking_; static constexpr bool SkipExpertsWithZeroTokens = SkipExpertsWithZeroTokens_; - static_assert(SubTokenTile == 1 || SubTokenTile == 2 || SubTokenTile == 4); + static_assert(SubTokenTile == 1 || SubTokenTile == 2 || SubTokenTile == 4 || + SubTokenTile == 8 || SubTokenTile == 16); }; } // namespace ck_tile From 8a0d659f92897e1ae99e4dc0ea4842a2c78170ab Mon Sep 17 00:00:00 2001 From: Rostyslav Geyyer <46627076+geyyer@users.noreply.github.com> Date: Tue, 6 May 2025 09:24:00 -0500 Subject: [PATCH 097/443] Add FP4 MX MFMA tests (#2151) * Add conversion tests * Fix ctor * Fix nan logic * Fix conversion logic * Permute packed f4_t values * Fix conversion to float, repack vector elements * Fix device tests * Permute elements in a vector * Add a repro test * Add a conversion for a repro test * Update test vectors * Update conversion * Fix the test * Update test vector generator * Fix vector sr conversion * Permute conversion args * Update conversion * Test * Fix packing * Simplify conversion function * Pack conversion in a loop * Pack conversion in a loop * Pack another conversion in a loop * Pack one more conversion in a loop * Pack the last conversion in a loop * Clean up * Add ops * Add tests * Add missing utils * Update reference mx gemm * Add f4x2 init mode * Update host tensor utils * Update chunk size for f4x2 * Add non scaled ops * Add a type utility * Update non scaled reference kernel * Add non scaled tests * Debug mfma arguments * Add more debug info * Update chunk size * Update data layout * Add more debugging * Fix B stride * Fix reference gemm * Fix build * One more reference fix * Add more debug info * Disable some tests * Enable tests * Add fp4 dimensions * Update reference kernels * Temp edits * Remove leftovers * Fix conflicts * Clean up * More clean up * Revert "More clean up" This reverts commit d8d35a0846a8c2f0ccc7defe5f4fc7cc4ef36760. 
* Add layouts to tests --------- Co-authored-by: Andriy Roshchenko <107577548+andriy-ca@users.noreply.github.com> --- include/ck/library/utility/host_tensor.hpp | 21 +- .../library/utility/host_tensor_generator.hpp | 48 ++- include/ck/utility/amd_xdlops.hpp | 122 +++++++ include/ck/utility/data_type.hpp | 7 + .../cpu/reference_gemm.hpp | 20 ++ .../cpu/reference_mx_gemm.hpp | 50 ++- test/mx_mfma_op/mx_mfma_op.cpp | 114 ++++++- test/mx_mfma_op/mx_mfma_op.hpp | 307 +++++++++++++++--- 8 files changed, 610 insertions(+), 79 deletions(-) diff --git a/include/ck/library/utility/host_tensor.hpp b/include/ck/library/utility/host_tensor.hpp index 2cbca29afc..71417ce7bf 100644 --- a/include/ck/library/utility/host_tensor.hpp +++ b/include/ck/library/utility/host_tensor.hpp @@ -51,7 +51,8 @@ std::ostream& LogRangeAsType(std::ostream& os, Range&& range, std::string delim) { os << ck::type_convert(v); } - else if constexpr(std::is_same_v) + else if constexpr(std::is_same_v || + std::is_same_v) { const auto packed_floats = ck::type_convert(v); const ck::vector_type vector_of_floats{packed_floats}; @@ -359,7 +360,8 @@ struct Tensor std::size_t GetElementSpaceSize() const { - if constexpr(ck::is_same_v, ck::pk_i4_t>) + if constexpr(ck::is_same_v, ck::pk_i4_t> || + ck::is_same_v, ck::f4x2_pk_t>) { return (mDesc.GetElementSpaceSize() + 1) / 2; } @@ -514,7 +516,8 @@ struct Tensor template std::size_t GetOffsetFromMultiIndex(Is... is) const { - if constexpr(ck::is_same_v, ck::pk_i4_t>) + if constexpr(ck::is_same_v, ck::pk_i4_t> || + ck::is_same_v, ck::f4x2_pk_t>) { return mDesc.GetOffsetFromMultiIndex(is...) / 2; } @@ -527,7 +530,8 @@ struct Tensor template T& operator()(Is... is) { - if constexpr(ck::is_same_v, ck::pk_i4_t>) + if constexpr(ck::is_same_v, ck::pk_i4_t> || + ck::is_same_v, ck::f4x2_pk_t>) { return mData[mDesc.GetOffsetFromMultiIndex(is...) / 2]; } @@ -540,7 +544,8 @@ struct Tensor template const T& operator()(Is... 
is) const { - if constexpr(ck::is_same_v, ck::pk_i4_t>) + if constexpr(ck::is_same_v, ck::pk_i4_t> || + ck::is_same_v, ck::f4x2_pk_t>) { return mData[mDesc.GetOffsetFromMultiIndex(is...) / 2]; } @@ -552,7 +557,8 @@ struct Tensor T& operator()(std::vector idx) { - if constexpr(ck::is_same_v, ck::pk_i4_t>) + if constexpr(ck::is_same_v, ck::pk_i4_t> || + ck::is_same_v, ck::f4x2_pk_t>) { return mData[mDesc.GetOffsetFromMultiIndex(idx) / 2]; } @@ -564,7 +570,8 @@ struct Tensor const T& operator()(std::vector idx) const { - if constexpr(ck::is_same_v, ck::pk_i4_t>) + if constexpr(ck::is_same_v, ck::pk_i4_t> || + ck::is_same_v, ck::f4x2_pk_t>) { return mData[mDesc.GetOffsetFromMultiIndex(idx) / 2]; } diff --git a/include/ck/library/utility/host_tensor_generator.hpp b/include/ck/library/utility/host_tensor_generator.hpp index 274051da83..785f74a3c0 100644 --- a/include/ck/library/utility/host_tensor_generator.hpp +++ b/include/ck/library/utility/host_tensor_generator.hpp @@ -1,5 +1,5 @@ // SPDX-License-Identifier: MIT -// Copyright (c) 2018-2024, Advanced Micro Devices, Inc. All rights reserved. +// Copyright (c) 2018-2025, Advanced Micro Devices, Inc. All rights reserved. #pragma once @@ -81,6 +81,18 @@ struct GeneratorTensor_1 } }; +template <> +struct GeneratorTensor_1 +{ + float value = 1.0; + + template + ck::f4x2_pk_t operator()(Is...) + { + return ck::f4x2_pk_t{ck::type_convert(ck::float2_t{value, value})}; + } +}; + template <> struct GeneratorTensor_1 { @@ -209,6 +221,21 @@ struct GeneratorTensor_2 } }; +template <> +struct GeneratorTensor_2 +{ + int min_value = 0; + int max_value = 1; + + template + ck::f4x2_pk_t operator()(Is...) 
+ { + float tmp0 = (std::rand() % (max_value - min_value)) + min_value; + float tmp1 = (std::rand() % (max_value - min_value)) + min_value; + return ck::f4x2_pk_t{ck::type_convert(ck::float2_t{tmp0, tmp1})}; + } +}; + template struct GeneratorTensor_3 { @@ -296,6 +323,25 @@ struct GeneratorTensor_3 } }; +template <> +struct GeneratorTensor_3 +{ + float min_value = 0; + float max_value = 1; + + template + ck::f4x2_pk_t operator()(Is...) + { + float tmp0 = float(std::rand()) / float(RAND_MAX); + float tmp1 = float(std::rand()) / float(RAND_MAX); + + float fp32_tmp0 = min_value + tmp0 * (max_value - min_value); + float fp32_tmp1 = min_value + tmp1 * (max_value - min_value); + + return ck::f4x2_pk_t{ck::type_convert(ck::float2_t{fp32_tmp0, fp32_tmp1})}; + } +}; + template struct GeneratorTensor_4 { diff --git a/include/ck/utility/amd_xdlops.hpp b/include/ck/utility/amd_xdlops.hpp index 71e1937a23..66c4958e1d 100644 --- a/include/ck/utility/amd_xdlops.hpp +++ b/include/ck/utility/amd_xdlops.hpp @@ -508,6 +508,34 @@ struct intrin_mfma_f32_32x32x64f8f6f4<32, 32> ignore = reg_a; ignore = reg_b; ignore = reg_c; +#endif + } + + template + __device__ static void Run(const f4x32_t& reg_a, const f4x32_t& reg_b, FloatC& reg_c) + { +#if defined(__gfx950__) + + int32x4_t arg_a = bit_cast(reg_a); + int32x4_t arg_b = bit_cast(reg_b); + + using arg_type = int32x8_t; + + reg_c.template AsType()(Number<0>{}) = + __builtin_amdgcn_mfma_scale_f32_32x32x64_f8f6f4( + arg_type{arg_a[0], arg_a[1], arg_a[2], arg_a[3], 0, 0, 0, 0}, + arg_type{arg_b[0], arg_b[1], arg_b[2], arg_b[3], 0, 0, 0, 0}, + reg_c.template AsType()[Number<0>{}], + 4, // cbsz + 4, // blgp + 0, // OPSEL + 0, + 0, // OPSEL + 0); +#else + ignore = reg_a; + ignore = reg_b; + ignore = reg_c; #endif } }; @@ -589,6 +617,40 @@ struct intrin_mfma_scale_f32_32x32x64f8f6f4<32, 32> ignore = reg_b; ignore = scale_b; ignore = reg_c; +#endif + } + + template + __device__ static void Run(const f4x32_t& reg_a, + const int32_t scale_a, + 
const f4x32_t& reg_b, + const int32_t scale_b, + FloatC& reg_c) + { +#if defined(__gfx950__) + + int32x4_t arg_a = bit_cast(reg_a); + int32x4_t arg_b = bit_cast(reg_b); + + using arg_type = int32x8_t; + + reg_c.template AsType()(Number<0>{}) = + __builtin_amdgcn_mfma_scale_f32_32x32x64_f8f6f4( + arg_type{arg_a[0], arg_a[1], arg_a[2], arg_a[3], 0, 0, 0, 0}, + arg_type{arg_b[0], arg_b[1], arg_b[2], arg_b[3], 0, 0, 0, 0}, + reg_c.template AsType()[Number<0>{}], + 4, // cbsz + 4, // blgp + 0, // OPSEL + scale_a, + 0, // OPSEL + scale_b); +#else + ignore = reg_a; + ignore = scale_a; + ignore = reg_b; + ignore = scale_b; + ignore = reg_c; #endif } }; @@ -686,6 +748,39 @@ struct intrin_mfma_scale_f32_16x16x128f8f6f4<16, 16> #endif } + template + __device__ static void Run(const f4x32_t& reg_a, + const int32_t scale_a, + const f4x32_t& reg_b, + const int32_t scale_b, + FloatC& reg_c) + { +#if defined(__gfx950__) + int32x4_t arg_a = bit_cast(reg_a); + int32x4_t arg_b = bit_cast(reg_b); + + using arg_type = int32x8_t; + + reg_c.template AsType()(Number<0>{}) = + __builtin_amdgcn_mfma_scale_f32_16x16x128_f8f6f4( + arg_type{arg_a[0], arg_a[1], arg_a[2], arg_a[3], 0, 0, 0, 0}, + arg_type{arg_b[0], arg_b[1], arg_b[2], arg_b[3], 0, 0, 0, 0}, + reg_c.template AsType()[Number<0>{}], + 4, // cbsz + 4, // blgp + 0, // OPSEL + scale_a, + 0, // OPSEL + scale_b); +#else + ignore = reg_a; + ignore = scale_a; + ignore = reg_b; + ignore = scale_b; + ignore = reg_c; +#endif + } + template __device__ static void Run(const bf8x32_t& reg_a, const int32_t& scale_a, @@ -748,6 +843,33 @@ struct intrin_mfma_f32_16x16x128f8f6f4<16, 16> ignore = reg_a; ignore = reg_b; ignore = reg_c; +#endif + } + + template + __device__ static void Run(const f4x32_t& reg_a, const f4x32_t& reg_b, FloatC& reg_c) + { +#if defined(__gfx950__) + int32x4_t arg_a = bit_cast(reg_a); + int32x4_t arg_b = bit_cast(reg_b); + + using arg_type = int32x8_t; + + reg_c.template AsType()(Number<0>{}) = + 
__builtin_amdgcn_mfma_scale_f32_16x16x128_f8f6f4( + arg_type{arg_a[0], arg_a[1], arg_a[2], arg_a[3], 0, 0, 0, 0}, + arg_type{arg_b[0], arg_b[1], arg_b[2], arg_b[3], 0, 0, 0, 0}, + reg_c.template AsType()[Number<0>{}], + 4, // cbsz + 4, // blgp + 0, // OPSEL + 0, + 0, // OPSEL + 0); +#else + ignore = reg_a; + ignore = reg_b; + ignore = reg_c; #endif } }; diff --git a/include/ck/utility/data_type.hpp b/include/ck/utility/data_type.hpp index 79bd717501..a6106bb146 100644 --- a/include/ck/utility/data_type.hpp +++ b/include/ck/utility/data_type.hpp @@ -470,6 +470,13 @@ struct scalar_type static constexpr index_t vector_size = 1; }; +template <> +struct scalar_type +{ + using type = f4x2_pk_t::type; + static constexpr index_t vector_size = 1; +}; + template <> struct scalar_type { diff --git a/library/include/ck/library/reference_tensor_operation/cpu/reference_gemm.hpp b/library/include/ck/library/reference_tensor_operation/cpu/reference_gemm.hpp index 7e2482807d..c8d284a1d7 100644 --- a/library/include/ck/library/reference_tensor_operation/cpu/reference_gemm.hpp +++ b/library/include/ck/library/reference_tensor_operation/cpu/reference_gemm.hpp @@ -79,6 +79,16 @@ struct ReferenceGemm : public device::BaseOperator i4 = i4 - 8; v_a = type_convert(i4); } + else if constexpr(is_same_v) + { + // TODO: add support for ColMajor layout as well + if(k % 2 == 1) + v_a = type_convert( + f4_t(arg.a_m_k_(m, k).template unpack<>(Number<1>{}))); + else + v_a = type_convert( + f4_t(arg.a_m_k_(m, k).template unpack<>(Number<0>{}))); + } else { arg.a_element_op_(v_a, arg.a_m_k_(m, k)); @@ -95,6 +105,16 @@ struct ReferenceGemm : public device::BaseOperator i4 = i4 - 8; v_b = type_convert(i4); } + else if constexpr(is_same_v) + { + // TODO: add support for RowMajor layout as well + if(k % 2 == 1) + v_b = type_convert( + f4_t(arg.b_k_n_(k, n).template unpack<>(Number<1>{}))); + else + v_b = type_convert( + f4_t(arg.b_k_n_(k, n).template unpack<>(Number<0>{}))); + } else { 
arg.b_element_op_(v_b, arg.b_k_n_(k, n)); diff --git a/library/include/ck/library/reference_tensor_operation/cpu/reference_mx_gemm.hpp b/library/include/ck/library/reference_tensor_operation/cpu/reference_mx_gemm.hpp index 649f130c41..e8fdcf1acd 100644 --- a/library/include/ck/library/reference_tensor_operation/cpu/reference_mx_gemm.hpp +++ b/library/include/ck/library/reference_tensor_operation/cpu/reference_mx_gemm.hpp @@ -89,9 +89,28 @@ struct ReferenceMXGemm : public device::BaseOperator { for(size_t k = 0; k < K; k++) { - a_m_k_scaled(m, k) = - type_convert(arg.a_m_k_(m, k)) * - type_convert(arg.a_m_kblock_scales_(m, k / SCALE_BLOCK)); + if constexpr(is_same_v) + { + // TODO: add support for ColMajor layout as well + if(k % 2 == 1) + a_m_k_scaled(m, k) = + type_convert( + f4_t(arg.a_m_k_(m, k).template unpack<>(Number<1>{}))) * + type_convert( + arg.a_m_kblock_scales_(m, k / SCALE_BLOCK)); + else + a_m_k_scaled(m, k) = + type_convert( + f4_t(arg.a_m_k_(m, k).template unpack<>(Number<0>{}))) * + type_convert( + arg.a_m_kblock_scales_(m, k / SCALE_BLOCK)); + } + else + { + a_m_k_scaled(m, k) = + type_convert(arg.a_m_k_(m, k)) * + type_convert(arg.a_m_kblock_scales_(m, k / SCALE_BLOCK)); + } } } @@ -99,9 +118,28 @@ struct ReferenceMXGemm : public device::BaseOperator { for(size_t k = 0; k < K; k++) { - b_k_n_scaled(k, n) = - type_convert(arg.b_k_n_(k, n)) * - type_convert(arg.b_kblock_n_scales_(k / SCALE_BLOCK, n)); + if constexpr(is_same_v) + { + // TODO: add support for RowMajor layout as well + if(k % 2 == 1) + b_k_n_scaled(k, n) = + type_convert( + f4_t(arg.b_k_n_(k, n).template unpack<>(Number<1>{}))) * + type_convert( + arg.b_kblock_n_scales_(k / SCALE_BLOCK, n)); + else + b_k_n_scaled(k, n) = + type_convert( + f4_t(arg.b_k_n_(k, n).template unpack<>(Number<0>{}))) * + type_convert( + arg.b_kblock_n_scales_(k / SCALE_BLOCK, n)); + } + else + { + b_k_n_scaled(k, n) = + type_convert(arg.b_k_n_(k, n)) * + type_convert(arg.b_kblock_n_scales_(k / SCALE_BLOCK, 
n)); + } } } diff --git a/test/mx_mfma_op/mx_mfma_op.cpp b/test/mx_mfma_op/mx_mfma_op.cpp index f65e89bb82..fddb8288a6 100644 --- a/test/mx_mfma_op/mx_mfma_op.cpp +++ b/test/mx_mfma_op/mx_mfma_op.cpp @@ -6,6 +6,8 @@ #include "mx_mfma_op.hpp" using ck::e8m0_bexp_t; +using ck::f4_t; +using ck::f4x2_pk_t; using ck::f8_t; using ck::half_t; using ck::type_convert; @@ -16,7 +18,7 @@ using ck::type_convert; * @param init - selects initialization algorithm for A and B tensors */ template -bool run_mfma_test(ck::index_t init) +bool run_mfma_km_kn_nm_test(ck::index_t init) { using ALayout = ck::tensor_layout::gemm::ColumnMajor; using BLayout = ck::tensor_layout::gemm::ColumnMajor; @@ -30,7 +32,8 @@ bool run_mfma_test(ck::index_t init) constexpr auto BLOCK_N = mfma_instr.n_per_blk; constexpr auto BLOCK_K = mfma_instr.num_input_blks * mfma_instr.k_per_blk; - const auto mfma_kernel = ck::matmul; + const auto mfma_kernel = ck:: + matmul; bool pass = true; @@ -52,15 +55,72 @@ bool run_mfma_test(ck::index_t init) TEST(MFMA, FP8MFMA16x16x128) { - auto AB_init = 4; - auto pass = run_mfma_test(AB_init); + auto AB_init = 5; + auto pass = run_mfma_km_kn_nm_test(AB_init); EXPECT_TRUE(pass); } TEST(MFMA, FP8MFMA32x32x64) +{ + auto AB_init = 5; + auto pass = run_mfma_km_kn_nm_test(AB_init); + EXPECT_TRUE(pass); +} + +/** + * @brief Run the test for the given MFMA instruction + * + * @param init - selects initialization algorithm for A and B tensors + */ +template +bool run_mfma_mk_kn_mn_test(ck::index_t init) +{ + using ALayout = ck::tensor_layout::gemm::RowMajor; + using BLayout = ck::tensor_layout::gemm::ColumnMajor; + using CLayout = ck::tensor_layout::gemm::RowMajor; + + using AccType = float; // only MFMA_F32 instructions supported + using CPUAccType = AccType; + + ck::mfma_type(mfma)> mfma_instr; + constexpr auto BLOCK_M = mfma_instr.m_per_blk; + constexpr auto BLOCK_N = mfma_instr.n_per_blk; + constexpr auto BLOCK_K = mfma_instr.num_input_blks * mfma_instr.k_per_blk; + + const auto 
mfma_kernel = ck:: + matmul; + + bool pass = true; + + pass = ck::mfma_test::TestMFMA{}(mfma_kernel, init); + + return pass; +} + +TEST(MFMA, FP4MFMA16x16x128) { auto AB_init = 4; - auto pass = run_mfma_test(AB_init); + auto pass = run_mfma_mk_kn_mn_test( + AB_init); + EXPECT_TRUE(pass); +} + +TEST(MFMA, FP4MFMA32x32x64) +{ + auto AB_init = 4; + auto pass = run_mfma_mk_kn_mn_test( + AB_init); EXPECT_TRUE(pass); } @@ -70,7 +130,7 @@ TEST(MFMA, FP8MFMA32x32x64) * @param init - selects initialization algorithm for A and B tensors */ template -bool run_mxmfma_test(ck::index_t init) +bool run_mxmfma_mk_kn_mn_test(ck::index_t init) { static_assert(mfma == ck::MFMA_F8F6F4::SCALE_F32_16x16x128 || mfma == ck::MFMA_F8F6F4::SCALE_F32_32x32x64, @@ -88,8 +148,18 @@ bool run_mxmfma_test(ck::index_t init) constexpr auto BLOCK_K = mfma_instr.num_input_blks * mfma_instr.k_per_blk; constexpr auto BLOCK_X = 32; // scaling vector size - const auto mx_mfma_kernel = - ck::matmul; + const auto mx_mfma_kernel = ck::matmul; bool pass = true; @@ -111,14 +181,34 @@ bool run_mxmfma_test(ck::index_t init) TEST(MXMFMA, MXFP8MFMA16x16x128) { - auto AB_init = 7; - auto pass = run_mxmfma_test(AB_init); + auto AB_init = 5; + auto pass = + run_mxmfma_mk_kn_mn_test(AB_init); EXPECT_TRUE(pass); } TEST(MXMFMA, MXFP8MFMA32x32x64) { - auto AB_init = 7; - auto pass = run_mxmfma_test(AB_init); + auto AB_init = 5; + auto pass = + run_mxmfma_mk_kn_mn_test(AB_init); + EXPECT_TRUE(pass); +} + +TEST(MXMFMA, MXFP4MFMA16x16x128) +{ + auto AB_init = 4; + auto pass = + run_mxmfma_mk_kn_mn_test( + AB_init); + EXPECT_TRUE(pass); +} + +TEST(MXMFMA, MXFP4MFMA32x32x64) +{ + auto AB_init = 4; + auto pass = + run_mxmfma_mk_kn_mn_test( + AB_init); EXPECT_TRUE(pass); } diff --git a/test/mx_mfma_op/mx_mfma_op.hpp b/test/mx_mfma_op/mx_mfma_op.hpp index d22157c3b3..9ce871cfb1 100644 --- a/test/mx_mfma_op/mx_mfma_op.hpp +++ b/test/mx_mfma_op/mx_mfma_op.hpp @@ -5,6 +5,7 @@ #include "ck/ck.hpp" +#include 
"ck/utility/data_type.hpp" #include "ck/library/utility/host_tensor.hpp" #include "ck/library/utility/device_memory.hpp" #include "ck/tensor_operation/gpu/device/tensor_layout.hpp" @@ -111,7 +112,7 @@ template __device__ AFragT load_A_col_major(AType const* input_ptr) { // clang-format off - // Register Mapping for 16x128: || Register Mapping for 32x64: + // Register Mapping for 16x128 for FP8: || Register Mapping for 32x64 for FP8: // Size | BLOCK_M | BLOCK_M | BLOCK_M | BLOCK_M | || Size | BLOCK_M | BLOCK_M | | // M | 0 ... 15 | 0 ... 15 | 0 ... 15 | 0 ... 15 | Vector || M | 0 ... 31 | 0 ... 31 | Vector | // Thread Id | 0 ... 15 | 16 ... 31 | 32 ... 47 | 48 ... 63 | Element || Thread Id | 0 ... 31 | 32 ... 63 | Element| @@ -176,13 +177,19 @@ __device__ AFragT load_A_col_major(AType const* input_ptr) auto kMinorOffset = col_major(minorStepCoord2D, BLOCK_M); auto kMajorOffset = col_major(majorStepCoord2D, BLOCK_M); - using ARawT = typename scalar_type::type; - using AScalarFragT = vector_type::type; + using ARawT = typename scalar_type::type; + using AScalarFragT = + vector_type, ck::f4x2_pk_t> ? 2 : 1)>::type; AScalarFragT fragA{}; + constexpr index_t num_chunks = + (ck::is_same_v, ck::f4x2_pk_t> ? 1 : 2); + #pragma unroll - for(int chunk = 0; chunk < 2; chunk++) + for(int chunk = 0; chunk < num_chunks; chunk++) { #pragma unroll for(uint32_t i = 0; i < chunk_size; i++) @@ -241,6 +248,28 @@ __device__ AFragT load_A_row_major(AType const* input_ptr) // Reg 7 [8:15] | K77 | K93 | K109 | K125 | v[29] || Reg 7 [8:15] | K45 | K61 | v[29] | // Reg 7 [16:23] | K78 | K94 | K110 | K126 | v[30] || Reg 7 [16:23] | K46 | K62 | v[30] | // Reg 7 [24:31] | K79 | K95 | K111 | K127 | v[31] || Reg 7 [24:31] | K47 | K63 | v[31] | + + // Register Mapping for 16x128 for FP4: || Register Mapping for 32x64 for FP4: + // Size | BLOCK_M | BLOCK_M | BLOCK_M | BLOCK_M | || Size | BLOCK_M | BLOCK_M | | + // M | 0 ... 15 | 0 ... 15 | 0 ... 15 | 0 ... 15 | Vector || M | 0 ... 31 | 0 ... 
31 | Vector | + // Thread Id | 0 ... 15 | 16 ... 31 | 32 ... 47 | 48 ... 63 | Element || Thread Id | 0 ... 31 | 32 ... 63 | Element| + // Register Element |------------|-------------|------------|-------------|-----------|| Register Element |------------|-------------|--------| + // Reg 0 [0:7] | K0K1 | K32K33 | K64K65 | K96K97 | v[0] || Reg 0 [0:7] | K0K1 | K32K33 | v[0] | + // Reg 0 [8:15] | K2K3 | K34K35 | K66K67 | K98K99 | v[1] || Reg 0 [8:15] | K2K3 | K34K35 | v[1] | + // Reg 0 [16:23] | K4K5 | K36K37 | K68K69 | K100K101 | v[2] || Reg 0 [16:23] | K4K5 | K36K37 | v[2] | + // Reg 0 [24:31] | K6K7 | K38K39 | K70K71 | K102K103 | v[3] || Reg 0 [24:31] | K6K7 | K38K39 | v[3] | + // Reg 1 [0:7] | K8K9 | K40K41 | K72K73 | K104K105 | v[4] || Reg 1 [0:7] | K8K9 | K40K41 | v[4] | + // Reg 1 [8:15] | K10K11 | K42K43 | K74K75 | K106K107 | v[5] || Reg 1 [8:15] | K10K11 | K42K43 | v[5] | + // Reg 1 [16:23] | K12K13 | K44K45 | K76K77 | K108K109 | v[6] || Reg 1 [16:23] | K12K13 | K44K45 | v[6] | + // Reg 1 [24:31] | K14K15 | K46K47 | K78K79 | K110K111 | v[7] || Reg 1 [24:31] | K14K15 | K46K47 | v[7] | + // Reg 2 [0:7] | K16K17 | K48K49 | K80K81 | K112K113 | v[8] || Reg 2 [0:7] | K16K17 | K48K49 | v[8] | + // Reg 2 [8:15] | K18K19 | K50K51 | K82K83 | K114K115 | v[9] || Reg 2 [8:15] | K18K19 | K50K51 | v[9] | + // Reg 2 [16:23] | K20K21 | K52K53 | K84K85 | K116K117 | v[10] || Reg 2 [16:23] | K20K21 | K52K53 | v[10] | + // Reg 2 [24:31] | K22K23 | K54K55 | K86K87 | K118K119 | v[11] || Reg 2 [24:31] | K22K23 | K54K55 | v[11] | + // Reg 3 [0:7] | K24K25 | K56K57 | K88K89 | K120K121 | v[12] || Reg 3 [0:7] | K24K25 | K56K57 | v[12] | + // Reg 3 [8:15] | K26K27 | K58K59 | K90K91 | K122K123 | v[13] || Reg 3 [8:15] | K26K27 | K58K59 | v[13] | + // Reg 3 [16:23] | K28K29 | K60K61 | K92K93 | K124K125 | v[14] || Reg 3 [16:23] | K28K29 | K60K61 | v[14] | + // Reg 3 [24:31] | K30K31 | K62K63 | K94K95 | K126K127 | v[15] || Reg 3 [24:31] | K30K31 | K62K63 | v[15] | // clang-format on static 
constexpr int32_t WAVE_SIZE = 64; @@ -265,23 +294,34 @@ __device__ AFragT load_A_row_major(AType const* input_ptr) auto row_major = [](auto const& coord, auto ld) { return coord.first * ld + coord.second; }; // BLOCK_K is a stride in A matrix - auto startOffset = row_major(startCoord2D, BLOCK_K); - // auto kMinorOffset = row_major(minorStepCoord2D, BLOCK_K); - auto kMajorOffset = row_major(majorStepCoord2D, BLOCK_K); + auto startOffset = row_major( + startCoord2D, BLOCK_K / (ck::is_same_v, ck::f4x2_pk_t> ? 2 : 1)); + // auto kMinorOffset = row_major(minorStepCoord2D, BLOCK_K / + // (ck::is_same_v, ck::f4x2_pk_t> ? 2 : 1)); + auto kMajorOffset = + row_major(majorStepCoord2D, + BLOCK_K / (ck::is_same_v, ck::f4x2_pk_t> ? 2 : 1)); using ARawT = typename scalar_type::type; using AScalarFragT = vector_type::type; + constexpr index_t num_chunks = + (ck::is_same_v, ck::f4x2_pk_t> ? 1 : 2); + union { AFragT frag; - AScalarFragT chunks[2]; + AScalarFragT chunks[num_chunks]; } fragA{}; - auto* fragPtr = reinterpret_cast(input_ptr + startOffset); - fragA.chunks[0] = *fragPtr; - fragPtr = reinterpret_cast(input_ptr + startOffset + kMajorOffset); - fragA.chunks[1] = *fragPtr; + const AScalarFragT* fragPtr; + + for(index_t chunk_idx = 0; chunk_idx < num_chunks; chunk_idx++) + { + fragPtr = reinterpret_cast(input_ptr + startOffset + + chunk_idx * kMajorOffset); + fragA.chunks[chunk_idx] = *fragPtr; + } return fragA.frag; } @@ -339,15 +379,35 @@ __device__ AFragT load_mx_A_row_major(AType const* input_ptr, // Reg 7 [8:15] | K77 | K93 | x(M,2) | K109 | K125 | x(M,3) | v[29] || Reg 7 [8:15] | K45 | K61 | v[29] | x(M,1) | // Reg 7 [16:23] | K78 | K94 | x(M,2) | K110 | K126 | x(M,3) | v[30] || Reg 7 [16:23] | K46 | K62 | v[30] | x(M,1) | // Reg 7 [24:31] | K79 | K95 | x(M,2) | K111 | K127 | x(M,3) | v[31] || Reg 7 [24:31] | K47 | K63 | v[31] | x(M,1) | + + // Register Mapping for 16x128 for FP4: || Register Mapping for 32x64 for FP4: + // Size | BLOCK_M | | BLOCK_M | | BLOCK_M | | 
BLOCK_M | | || Size | BLOCK_M | | BLOCK_M | | | + // M | 0 ... 15 | | 0 ... 15 | | 0 ... 15 | | 0 ... 15 | | Vector || M | 0 ... 31 | | 0 ... 31 | | Vector | + // Thread Id | 0 ... 15 | Scale | 16 ... 31 | Scale | 32 ... 47 | Scale | 48 ... 63 | Scale | Element || Thread Id | 0 ... 31 | Scale | 32 ... 63 | Scale | Element| + // Register Element |------------ ----------|------------- ----------|------------ ----------|------------- ----------|-----------|| Register Element |------------|----------|-------------|----------|--------| + // Reg 0 [0:7] | K0K1 | x(M,0) | K32K33 | x(M,1) | K64K65 | x(M,2) | K96K97 | x(M,3) | v[0] || Reg 0 [0:7] | K0K1 | x(M,0) | K32K33 | x(M,1) | v[0] | + // Reg 0 [8:15] | K2K3 | x(M,0) | K34K35 | x(M,1) | K66K67 | x(M,2) | K98K99 | x(M,3) | v[1] || Reg 0 [8:15] | K2K3 | x(M,0) | K34K35 | x(M,1) | v[1] | + // Reg 0 [16:23] | K4K5 | x(M,0) | K36K37 | x(M,1) | K68K69 | x(M,2) | K100K101 | x(M,3) | v[2] || Reg 0 [16:23] | K4K5 | x(M,0) | K36K37 | x(M,1) | v[2] | + // Reg 0 [24:31] | K6K7 | x(M,0) | K38K39 | x(M,1) | K70K71 | x(M,2) | K102K103 | x(M,3) | v[3] || Reg 0 [24:31] | K6K7 | x(M,0) | K38K39 | x(M,1) | v[3] | + // Reg 1 [0:7] | K8K9 | x(M,0) | K40K41 | x(M,1) | K72K73 | x(M,2) | K104K105 | x(M,3) | v[4] || Reg 1 [0:7] | K8K9 | x(M,0) | K40K41 | x(M,1) | v[4] | + // Reg 1 [8:15] | K10K11 | x(M,0) | K42K43 | x(M,1) | K74K75 | x(M,2) | K106K107 | x(M,3) | v[5] || Reg 1 [8:15] | K10K11 | x(M,0) | K42K43 | x(M,1) | v[5] | + // Reg 1 [16:23] | K12K13 | x(M,0) | K44K45 | x(M,1) | K76K77 | x(M,2) | K108K109 | x(M,3) | v[6] || Reg 1 [16:23] | K12K13 | x(M,0) | K44K45 | x(M,1) | v[6] | + // Reg 1 [24:31] | K14K15 | x(M,0) | K46K47 | x(M,1) | K78K79 | x(M,2) | K110K111 | x(M,3) | v[7] || Reg 1 [24:31] | K14K15 | x(M,0) | K46K47 | x(M,1) | v[7] | + // Reg 2 [0:7] | K16K17 | x(M,0) | K48K49 | x(M,1) | K80K81 | x(M,2) | K112K113 | x(M,3) | v[8] || Reg 2 [0:7] | K16K17 | x(M,0) | K48K49 | x(M,1) | v[8] | + // Reg 2 [8:15] | K18K19 | x(M,0) | K50K51 
| x(M,1) | K82K83 | x(M,2) | K114K115 | x(M,3) | v[9] || Reg 2 [8:15] | K18K19 | x(M,0) | K50K51 | x(M,1) | v[9] | + // Reg 2 [16:23] | K20K21 | x(M,0) | K52K53 | x(M,1) | K84K85 | x(M,2) | K116K117 | x(M,3) | v[10] || Reg 2 [16:23] | K20K21 | x(M,0) | K52K53 | x(M,1) | v[10] | + // Reg 2 [24:31] | K22K23 | x(M,0) | K54K55 | x(M,1) | K86K87 | x(M,2) | K118K119 | x(M,3) | v[11] || Reg 2 [24:31] | K22K23 | x(M,0) | K54K55 | x(M,1) | v[11] | + // Reg 3 [0:7] | K24K25 | x(M,0) | K56K57 | x(M,1) | K88K89 | x(M,2) | K120K121 | x(M,3) | v[12] || Reg 3 [0:7] | K24K25 | x(M,0) | K56K57 | x(M,1) | v[12] | + // Reg 3 [8:15] | K26K27 | x(M,0) | K58K59 | x(M,1) | K90K91 | x(M,2) | K122K123 | x(M,3) | v[13] || Reg 3 [8:15] | K26K27 | x(M,0) | K58K59 | x(M,1) | v[13] | + // Reg 3 [16:23] | K28K29 | x(M,0) | K60K61 | x(M,1) | K92K93 | x(M,2) | K124K125 | x(M,3) | v[14] || Reg 3 [16:23] | K28K29 | x(M,0) | K60K61 | x(M,1) | v[14] | + // Reg 3 [24:31] | K30K31 | x(M,0) | K62K63 | x(M,1) | K94K95 | x(M,2) | K126K127 | x(M,3) | v[15] || Reg 3 [24:31] | K30K31 | x(M,0) | K62K63 | x(M,1) | v[15] | // clang-format on - static constexpr uint32_t VW = vectorSize(AFragT{}); - static_assert(VW == BLOCK_X, "Fragment size must be equal to BLOCK_X"); // To start the loading process, let's visualize in 2D coords. // Each thread will load 1 element // We need to know where they start - auto startCoord2D = std::make_pair(threadIdx.x % BLOCK_M, // Row - (threadIdx.x / BLOCK_M) * VW / BLOCK_X); // Col + auto startCoord2D = std::make_pair(threadIdx.x % BLOCK_M, // Row + (threadIdx.x / BLOCK_M)); // Col // Flatten to 1D row_major offsets. 
auto row_major = [](auto const& coord, auto ld) { return coord.first * ld + coord.second; }; @@ -369,7 +429,7 @@ template __device__ BFragT load_B_col_major(BType const* input_ptr) { // clang-format off - // Register Mapping for 128x16: || Register Mapping for 64x32: + // Register Mapping for 128x16 for FP8: || Register Mapping for 64x32 for FP8: // Size | BLOCK_N | BLOCK_N | BLOCK_N | BLOCK_N | || Size | BLOCK_N | BLOCK_N | | // N | 0 ... 15 | 0 ... 15 | 0 ... 15 | 0 ... 15 | Vector || N | 0 ... 31 | 0 ... 31 | Vector | // Thread Id | 0 ... 15 | 16 ... 31 | 32 ... 47 | 48 ... 63 | Element || Thread Id | 0 ... 31 | 32 ... 63 | Element| @@ -406,6 +466,28 @@ __device__ BFragT load_B_col_major(BType const* input_ptr) // Reg 7 [8:15] | K77 | K93 | K109 | K125 | v[29] || Reg 7 [8:15] | K45 | K61 | v[29] | // Reg 7 [16:23] | K78 | K94 | K110 | K126 | v[30] || Reg 7 [16:23] | K46 | K62 | v[30] | // Reg 7 [24:31] | K79 | K95 | K111 | K127 | v[31] || Reg 7 [24:31] | K47 | K63 | v[31] | + + // Register Mapping for 128x16 for FP4: || Register Mapping for 64x32 for FP4: + // Size | BLOCK_N | BLOCK_N | BLOCK_N | BLOCK_N | || Size | BLOCK_N | BLOCK_N | | + // N | 0 ... 15 | 0 ... 15 | 0 ... 15 | 0 ... 15 | Vector || N | 0 ... 31 | 0 ... 31 | Vector | + // Thread Id | 0 ... 15 | 16 ... 31 | 32 ... 47 | 48 ... 63 | Element || Thread Id | 0 ... 31 | 32 ... 
63 | Element| + // Register Element |------------|-------------|------------|-------------|-----------|| Register Element |------------|-------------|--------| + // Reg 0 [0:7] | K0K1 | K32K33 | K64K65 | K96K97 | v[0] || Reg 0 [0:7] | K0K1 | K32K33 | v[0] | + // Reg 0 [8:15] | K2K3 | K34K35 | K66K67 | K98K99 | v[1] || Reg 0 [8:15] | K2K3 | K34K35 | v[1] | + // Reg 0 [16:23] | K4K5 | K36K37 | K68K69 | K100K101 | v[2] || Reg 0 [16:23] | K4K5 | K36K37 | v[2] | + // Reg 0 [24:31] | K6K7 | K38K39 | K70K71 | K102K103 | v[3] || Reg 0 [24:31] | K6K7 | K38K39 | v[3] | + // Reg 1 [0:7] | K8K9 | K40K41 | K72K73 | K104K105 | v[4] || Reg 1 [0:7] | K8K9 | K40K41 | v[4] | + // Reg 1 [8:15] | K10K11 | K42K43 | K74K75 | K106K107 | v[5] || Reg 1 [8:15] | K10K11 | K42K43 | v[5] | + // Reg 1 [16:23] | K12K13 | K44K45 | K76K77 | K108K109 | v[6] || Reg 1 [16:23] | K12K13 | K44K45 | v[6] | + // Reg 1 [24:31] | K14K15 | K46K47 | K78K79 | K110K111 | v[7] || Reg 1 [24:31] | K14K15 | K46K47 | v[7] | + // Reg 2 [0:7] | K16K17 | K48K49 | K80K81 | K112K113 | v[8] || Reg 2 [0:7] | K16K17 | K48K49 | v[8] | + // Reg 2 [8:15] | K18K19 | K50K51 | K82K83 | K114K115 | v[9] || Reg 2 [8:15] | K18K19 | K50K51 | v[9] | + // Reg 2 [16:23] | K20K21 | K52K53 | K84K85 | K116K117 | v[10] || Reg 2 [16:23] | K20K21 | K52K53 | v[10] | + // Reg 2 [24:31] | K22K23 | K54K55 | K86K87 | K118K119 | v[11] || Reg 2 [24:31] | K22K23 | K54K55 | v[11] | + // Reg 3 [0:7] | K24K25 | K56K57 | K88K89 | K120K121 | v[12] || Reg 3 [0:7] | K24K25 | K56K57 | v[12] | + // Reg 3 [8:15] | K26K27 | K58K59 | K90K91 | K122K123 | v[13] || Reg 3 [8:15] | K26K27 | K58K59 | v[13] | + // Reg 3 [16:23] | K28K29 | K60K61 | K92K93 | K124K125 | v[14] || Reg 3 [16:23] | K28K29 | K60K61 | v[14] | + // Reg 3 [24:31] | K30K31 | K62K63 | K94K95 | K126K127 | v[15] || Reg 3 [24:31] | K30K31 | K62K63 | v[15] | // clang-format on static constexpr int32_t WAVE_SIZE = 64; @@ -430,23 +512,34 @@ __device__ BFragT load_B_col_major(BType const* input_ptr) auto 
majorStepCoord2D = std::make_pair(chunk_offset, 0); // read a chunk from a col // BLOCK_K is a stride in B matrix - auto startOffset = col_major(startCoord2D, BLOCK_K); - // auto kMinorOffset = col_major(minorStepCoord2D, BLOCK_K); - auto kMajorOffset = col_major(majorStepCoord2D, BLOCK_K); + auto startOffset = col_major( + startCoord2D, BLOCK_K / (ck::is_same_v, ck::f4x2_pk_t> ? 2 : 1)); + // auto kMinorOffset = col_major(minorStepCoord2D, BLOCK_K / + // (ck::is_same_v, ck::f4x2_pk_t> ? 2 : 1)); + auto kMajorOffset = + col_major(majorStepCoord2D, + BLOCK_K / (ck::is_same_v, ck::f4x2_pk_t> ? 2 : 1)); using BRawT = typename scalar_type::type; using BScalarFragT = vector_type::type; + constexpr index_t num_chunks = + (ck::is_same_v, ck::f4x2_pk_t> ? 1 : 2); + union { BFragT frag; - BScalarFragT chunks[2]; + BScalarFragT chunks[num_chunks]; } fragB{}; - auto* fragPtr = reinterpret_cast(input_ptr + startOffset); - fragB.chunks[0] = *fragPtr; - fragPtr = reinterpret_cast(input_ptr + startOffset + kMajorOffset); - fragB.chunks[1] = *fragPtr; + const BScalarFragT* fragPtr; + + for(index_t chunk = 0; chunk < num_chunks; chunk++) + { + fragPtr = + reinterpret_cast(input_ptr + startOffset + chunk * kMajorOffset); + fragB.chunks[chunk] = *fragPtr; + } return fragB.frag; } @@ -506,15 +599,56 @@ __device__ BFragT load_mx_B_col_major(BType const* input_ptr, // Reg 7 [16:23] | K78 | K94 | x(2,N) | K110 | K126 | x(3,N) | v[30] || Reg 7 [16:23] | K46 | K62 | v[30] | x(1,N) | // Reg 7 [24:31] | K79 | K95 | x(2,N) | K111 | K127 | x(3,N) | v[31] || Reg 7 [24:31] | K47 | K63 | v[31] | x(1,N) | + // Register Mapping for 128x16 for FP4: || Register Mapping for 64x32 for FP4: + // Size | BLOCK_N | BLOCK_N | BLOCK_N | BLOCK_N | || Size | BLOCK_N | BLOCK_N | | + // N | 0 ... 15 | 0 ... 15 | 0 ... 15 | 0 ... 15 | Vector || N | 0 ... 31 | 0 ... 31 | Vector | + // Thread Id | 0 ... 15 | 16 ... 31 | 32 ... 47 | 48 ... 63 | Element || Thread Id | 0 ... 31 | 32 ... 
63 | Element| + // Register Element |------------|-------------|------------|-------------|-----------|| Register Element |------------|-------------|--------| + // Reg 0 [0:7] | K0K1 | K32K33 | K64K65 | K96K97 | v[0] || Reg 0 [0:7] | K0K1 | K32K33 | v[0] | + // Reg 0 [8:15] | K2K3 | K34K35 | K66K67 | K98K99 | v[1] || Reg 0 [8:15] | K2K3 | K34K35 | v[1] | + // Reg 0 [16:23] | K4K5 | K36K37 | K68K69 | K100K101 | v[2] || Reg 0 [16:23] | K4K5 | K36K37 | v[2] | + // Reg 0 [24:31] | K6K7 | K38K39 | K70K71 | K102K103 | v[3] || Reg 0 [24:31] | K6K7 | K38K39 | v[3] | + // Reg 1 [0:7] | K8K9 | K40K41 | K72K73 | K104K105 | v[4] || Reg 1 [0:7] | K8K9 | K40K41 | v[4] | + // Reg 1 [8:15] | K10K11 | K42K43 | K74K75 | K106K107 | v[5] || Reg 1 [8:15] | K10K11 | K42K43 | v[5] | + // Reg 1 [16:23] | K12K13 | K44K45 | K76K77 | K108K109 | v[6] || Reg 1 [16:23] | K12K13 | K44K45 | v[6] | + // Reg 1 [24:31] | K14K15 | K46K47 | K78K79 | K110K111 | v[7] || Reg 1 [24:31] | K14K15 | K46K47 | v[7] | + // Reg 2 [0:7] | K16K17 | K48K49 | K80K81 | K112K113 | v[8] || Reg 2 [0:7] | K16K17 | K48K49 | v[8] | + // Reg 2 [8:15] | K18K19 | K50K51 | K82K83 | K114K115 | v[9] || Reg 2 [8:15] | K18K19 | K50K51 | v[9] | + // Reg 2 [16:23] | K20K21 | K52K53 | K84K85 | K116K117 | v[10] || Reg 2 [16:23] | K20K21 | K52K53 | v[10] | + // Reg 2 [24:31] | K22K23 | K54K55 | K86K87 | K118K119 | v[11] || Reg 2 [24:31] | K22K23 | K54K55 | v[11] | + // Reg 3 [0:7] | K24K25 | K56K57 | K88K89 | K120K121 | v[12] || Reg 3 [0:7] | K24K25 | K56K57 | v[12] | + // Reg 3 [8:15] | K26K27 | K58K59 | K90K91 | K122K123 | v[13] || Reg 3 [8:15] | K26K27 | K58K59 | v[13] | + // Reg 3 [16:23] | K28K29 | K60K61 | K92K93 | K124K125 | v[14] || Reg 3 [16:23] | K28K29 | K60K61 | v[14] | + // Reg 3 [24:31] | K30K31 | K62K63 | K94K95 | K126K127 | v[15] || Reg 3 [24:31] | K30K31 | K62K63 | v[15] | + + // Register Mapping for 128x16 for FP4: || Register Mapping for 64x32 for FP4: + // Size | BLOCK_N | | BLOCK_N | | BLOCK_N | | BLOCK_N | | || 
Size | BLOCK_N | | BLOCK_N | | | + // N | 0 ... 15 | | 0 ... 15 | | 0 ... 15 | | 0 ... 15 | | Vector || N | 0 ... 31 | | 0 ... 31 | | Vector | + // Thread Id | 0 ... 15 | Scale | 16 ... 31 | Scale | 32 ... 47 | Scale | 48 ... 63 | Scale | Element || Thread Id | 0 ... 31 | Scale | 32 ... 63 | Scale | Element| + // Register Element |------------ ----------|------------- ----------|------------ ----------|------------- ----------|-----------|| Register Element |------------|----------|-------------|----------|--------| + // Reg 0 [0:7] | K0K1 | x(0,N) | K32K33 | x(M,1) | K64K65 | x(M,2) | K96K97 | x(M,3) | v[0] || Reg 0 [0:7] | K0K1 | x(M,0) | K32K33 | x(M,1) | v[0] | + // Reg 0 [8:15] | K2K3 | x(0,N) | K34K35 | x(M,1) | K66K67 | x(M,2) | K98K99 | x(M,3) | v[1] || Reg 0 [8:15] | K2K3 | x(M,0) | K34K35 | x(M,1) | v[1] | + // Reg 0 [16:23] | K4K5 | x(0,N) | K36K37 | x(M,1) | K68K69 | x(M,2) | K100K101 | x(M,3) | v[2] || Reg 0 [16:23] | K4K5 | x(M,0) | K36K37 | x(M,1) | v[2] | + // Reg 0 [24:31] | K6K7 | x(0,N) | K38K39 | x(M,1) | K70K71 | x(M,2) | K102K103 | x(M,3) | v[3] || Reg 0 [24:31] | K6K7 | x(M,0) | K38K39 | x(M,1) | v[3] | + // Reg 1 [0:7] | K8K9 | x(0,N) | K40K41 | x(M,1) | K72K73 | x(M,2) | K104K105 | x(M,3) | v[4] || Reg 1 [0:7] | K8K9 | x(M,0) | K40K41 | x(M,1) | v[4] | + // Reg 1 [8:15] | K10K11 | x(0,N) | K42K43 | x(M,1) | K74K75 | x(M,2) | K106K107 | x(M,3) | v[5] || Reg 1 [8:15] | K10K11 | x(M,0) | K42K43 | x(M,1) | v[5] | + // Reg 1 [16:23] | K12K13 | x(0,N) | K44K45 | x(M,1) | K76K77 | x(M,2) | K108K109 | x(M,3) | v[6] || Reg 1 [16:23] | K12K13 | x(M,0) | K44K45 | x(M,1) | v[6] | + // Reg 1 [24:31] | K14K15 | x(0,N) | K46K47 | x(M,1) | K78K79 | x(M,2) | K110K111 | x(M,3) | v[7] || Reg 1 [24:31] | K14K15 | x(M,0) | K46K47 | x(M,1) | v[7] | + // Reg 2 [0:7] | K16K17 | x(0,N) | K48K49 | x(M,1) | K80K81 | x(M,2) | K112K113 | x(M,3) | v[8] || Reg 2 [0:7] | K16K17 | x(M,0) | K48K49 | x(M,1) | v[8] | + // Reg 2 [8:15] | K18K19 | x(0,N) | K50K51 | x(M,1) | 
K82K83 | x(M,2) | K114K115 | x(M,3) | v[9] || Reg 2 [8:15] | K18K19 | x(M,0) | K50K51 | x(M,1) | v[9] | + // Reg 2 [16:23] | K20K21 | x(0,N) | K52K53 | x(M,1) | K84K85 | x(M,2) | K116K117 | x(M,3) | v[10] || Reg 2 [16:23] | K20K21 | x(M,0) | K52K53 | x(M,1) | v[10] | + // Reg 2 [24:31] | K22K23 | x(0,N) | K54K55 | x(M,1) | K86K87 | x(M,2) | K118K119 | x(M,3) | v[11] || Reg 2 [24:31] | K22K23 | x(M,0) | K54K55 | x(M,1) | v[11] | + // Reg 3 [0:7] | K24K25 | x(0,N) | K56K57 | x(M,1) | K88K89 | x(M,2) | K120K121 | x(M,3) | v[12] || Reg 3 [0:7] | K24K25 | x(M,0) | K56K57 | x(M,1) | v[12] | + // Reg 3 [8:15] | K26K27 | x(0,N) | K58K59 | x(M,1) | K90K91 | x(M,2) | K122K123 | x(M,3) | v[13] || Reg 3 [8:15] | K26K27 | x(M,0) | K58K59 | x(M,1) | v[13] | + // Reg 3 [16:23] | K28K29 | x(0,N) | K60K61 | x(M,1) | K92K93 | x(M,2) | K124K125 | x(M,3) | v[14] || Reg 3 [16:23] | K28K29 | x(M,0) | K60K61 | x(M,1) | v[14] | + // Reg 3 [24:31] | K30K31 | x(0,N) | K62K63 | x(M,1) | K94K95 | x(M,2) | K126K127 | x(M,3) | v[15] || Reg 3 [24:31] | K30K31 | x(M,0) | K62K63 | x(M,1) | v[15] | // clang-format on - static constexpr uint32_t VW = vectorSize(BFragT{}); - static_assert(VW == BLOCK_X, "Fragment size must be equal to BLOCK_X"); // To start the loading process, let's visualize in 2D coords. // Each thread will load 1 element // We need to know where to start - auto startCoord2D = std::make_pair((threadIdx.x / BLOCK_N) * VW / BLOCK_X, // Row - threadIdx.x % BLOCK_N); // Col + auto startCoord2D = std::make_pair((threadIdx.x / BLOCK_N), // Row + threadIdx.x % BLOCK_N); // Col // Flatten to 1D col_major offsets. 
auto col_major = [](auto const& coord, auto ld) { return coord.first + coord.second * ld; }; @@ -766,15 +900,24 @@ template + int32_t BLOCK_K, + typename ALayout, + typename BLayout, + typename CLayout> __global__ void matmul(const AType* a, const BType* b, CType* c) { constexpr int WAVE_SIZE = 64; assert(threadIdx.x < WAVE_SIZE); assert(blockDim.x == 1 && blockDim.y == 1 && blockDim.z == 1); - using AFragT = vector_type::type; - using BFragT = vector_type::type; + using AFragT = + vector_type, ck::f4x2_pk_t> ? 2 : 1)>::type; + using BFragT = + vector_type, ck::f4x2_pk_t> ? 2 : 1)>::type; using CFragT = vector_type::type; using AccumFragT = vector_type; using RawAccumFragT = vector_type::type; @@ -786,10 +929,23 @@ __global__ void matmul(const AType* a, const BType* b, CType* c) auto fragAcc = AccumFragT{0}; // Load the inputs. - // A = col major, BLOCK_M x BLOCK_K - fragA = load_A_col_major(a); - // B = col major, BLOCK_K x BLOCK_N - fragB = load_B_col_major(b); + if constexpr(is_same_v) + { + fragA = load_A_row_major(a); + } + else + { + fragA = load_A_col_major(a); + } + + if constexpr(is_same_v) + { + printf("This layout is not implemented\n"); + } + else + { + fragB = load_B_col_major(b); + } // Matrix multiply-accumulate using MFMA units // Accumulation intermediate = BLOCK_M x BLOCK_N @@ -801,8 +957,14 @@ __global__ void matmul(const AType* a, const BType* b, CType* c) fragC[i] = type_convert(fragAcc.template AsType()[Number<0>{}][i]); } - auto storeC = store_C_col_major{}; - storeC(c, fragC); + if constexpr(is_same_v) + { + store_C_row_major{}(c, fragC); + } + else + { + store_C_col_major{}(c, fragC); + } } template + int32_t BLOCK_X, + typename ALayout, + typename BLayout, + typename CLayout> __global__ void matmul(const AType* a, const ScaleType* xa, const BType* b, const ScaleType* xb, CType* c) { @@ -821,8 +986,14 @@ matmul(const AType* a, const ScaleType* xa, const BType* b, const ScaleType* xb, assert(threadIdx.x < WAVE_SIZE); assert(blockDim.x == 1 
&& blockDim.y == 1 && blockDim.z == 1); - using AFragT = vector_type::type; - using BFragT = vector_type::type; + using AFragT = + vector_type, ck::f4x2_pk_t> ? 2 : 1)>::type; + using BFragT = + vector_type, ck::f4x2_pk_t> ? 2 : 1)>::type; using CFragT = vector_type::type; using AccumFragT = vector_type; using RawAccumFragT = vector_type::type; @@ -838,13 +1009,27 @@ matmul(const AType* a, const ScaleType* xa, const BType* b, const ScaleType* xb, auto fragXb = BScaleFragT{}; // Load the inputs. - // A = col major, BLOCK_M x BLOCK_K - fragA = load_mx_A_row_major( - a, xa, fragXa); + if constexpr(is_same_v) + { + fragA = + load_mx_A_row_major( + a, xa, fragXa); + } + else + { + printf("This layout is not implemented\n"); + } - // B = col major, BLOCK_K x BLOCK_N - fragB = load_mx_B_col_major( - b, xb, fragXb); + if constexpr(is_same_v) + { + printf("This layout is not implemented\n"); + } + else + { + fragB = + load_mx_B_col_major( + b, xb, fragXb); + } // Scaled Matrix multiply-accumulate using MFMA units // Accumulation intermediate = BLOCK_M x BLOCK_N @@ -860,8 +1045,14 @@ matmul(const AType* a, const ScaleType* xa, const BType* b, const ScaleType* xb, fragC[i] = type_convert(fragAcc.template AsType()[Number<0>{}][i]); } - auto storeC = store_C_row_major{}; - storeC(c, fragC); + if constexpr(is_same_v) + { + store_C_row_major{}(c, fragC); + } + else + { + store_C_col_major{}(c, fragC); + } } /** @@ -993,8 +1184,7 @@ struct TestMXMFMA { case 0: a_m_k.GenerateTensorValue(GeneratorTensor_1{1.0f}); - a_scales.GenerateTensorValue( - GeneratorTensor_1{ScaleType{0.015625f}}); // 1/64 + a_scales.GenerateTensorValue(GeneratorTensor_1{ScaleType{0.015625f}}); // 1/6 // NOTE: not all numbers are representable in FP8, BF8, etc. 
// 0 1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 16 18 20 20 20 22 24 24 24 26 28 28 28 30 32 b_n_k.GenerateTensorValue(GeneratorTensor_Sequential{}); @@ -1012,11 +1202,9 @@ struct TestMXMFMA a_m_k.GenerateTensorValue(GeneratorTensor_3{-2.0, 2.0}); a_scales.GenerateTensorValue( GeneratorTensor_2{126, 129}); // scales: {0.5, 1, 2} - b_n_k.GenerateTensorValue(GeneratorTensor_3{-2.0, 2.0}); b_scales.GenerateTensorValue(GeneratorTensor_2{126, 129}); break; - case 3: // expect small round off errors a_m_k.GenerateTensorValue(GeneratorTensor_4(0, 1)); @@ -1026,6 +1214,14 @@ struct TestMXMFMA b_scales.GenerateTensorValue( GeneratorTensor_2{126, 129}); // scales: {0.5, 1, 2} break; + case 4: + a_m_k.GenerateTensorValue(GeneratorTensor_3{-1., 1.}); + a_scales.GenerateTensorValue( + GeneratorTensor_2{126, 129}); // scales: {0.5, 1, 2} + b_n_k.GenerateTensorValue(GeneratorTensor_3{-1., 1.}); + b_scales.GenerateTensorValue( + GeneratorTensor_2{126, 129}); // scales: {0.5, 1, 2} + break; default: // all initial values are representable in FP8, BF8 a_m_k.GenerateTensorValue(GeneratorTensor_2{-5, 6}); // Z[-5,5] @@ -1207,6 +1403,11 @@ struct TestMFMA a_m_k.GenerateTensorValue(GeneratorTensor_4(-1, 3)); b_n_k.GenerateTensorValue(GeneratorTensor_4(1, 3)); break; + case 4: + // FP4 values case + a_m_k.GenerateTensorValue(GeneratorTensor_2{-4, 5}); + b_n_k.GenerateTensorValue(GeneratorTensor_2{-4, 5}); + break; default: // all initial values are representable in FP8, BF8 a_m_k.GenerateTensorValue(GeneratorTensor_2{-5, 6}); From 769336b6404d36ee6e7ef39baa8fccd3f583a8e7 Mon Sep 17 00:00:00 2001 From: Aviral Goel Date: Wed, 7 May 2025 02:00:39 -0500 Subject: [PATCH 098/443] [CK_TILE] Add type traits to detect tile window types at compile time (#2158) * added WindowType enum to tile_window_structs and static assert checks in computev4 pipeline * added type traits instead of enum to tile_window() and tile_window_linear() with debug comments * removed comments, added documentation and clang 
format --- include/ck_tile/core/tensor/tile_window.hpp | 78 +++++++++++++++++++ .../core/tensor/tile_window_linear.hpp | 46 +++++++++++ .../gemm_pipeline_ag_bg_cr_comp_v4.hpp | 6 ++ 3 files changed, 130 insertions(+) diff --git a/include/ck_tile/core/tensor/tile_window.hpp b/include/ck_tile/core/tensor/tile_window.hpp index 3bb728df23..716b1f4ecb 100644 --- a/include/ck_tile/core/tensor/tile_window.hpp +++ b/include/ck_tile/core/tensor/tile_window.hpp @@ -1164,4 +1164,82 @@ CK_TILE_DEVICE void move_tile_window( window.move(step); } +/** + * @brief Type trait to determine if a type is a tile window with static distribution. + * + * Defaults to `false_type`. Specializations define when the trait evaluates to `true`. + * + * @tparam T The type to check. + */ +template +struct is_tile_window_with_static_distribution : std::false_type +{ +}; + +/** + * @brief Specialization for `tile_window_with_static_distribution` to evaluate to `true_type`. + * + * @tparam BottomTensorView_ Bottom tensor view type of the tile window. + * @tparam WindowLengths_ Static window lengths. + * @tparam StaticTileDistribution_ Tile distribution policy. + * @tparam NumCoord Number of coordinate dimensions. + */ +template +struct is_tile_window_with_static_distribution< + tile_window_with_static_distribution> : std::true_type +{ +}; + +/** + * @brief Helper variable template to check if a type is a tile window with static distribution. + * + * Equivalent to `is_tile_window_with_static_distribution::value`. + * + * @tparam T The type to check. + */ +template +inline constexpr bool is_tile_window_with_static_distribution_v = + is_tile_window_with_static_distribution::value; + +/** + * @brief Type trait to determine if a type is a tile window with static lengths. + * + * Defaults to `false_type`. Specializations define when the trait evaluates to `true`. + * + * @tparam T The type to check. 
+ */ +template +struct is_tile_window_with_static_lengths : std::false_type +{ +}; + +/** + * @brief Specialization for `tile_window_with_static_lengths` to evaluate to `true_type`. + * + * @tparam BottomTensorView_ Bottom tensor view type of the tile window. + * @tparam WindowLengths_ Static window lengths. + */ +template +struct is_tile_window_with_static_lengths< + tile_window_with_static_lengths> : std::true_type +{ +}; + +/** + * @brief Helper variable template to check if a type is a tile window with static lengths. + * + * Equivalent to `is_tile_window_with_static_lengths::value`. + * + * @tparam T The type to check. + */ +template +inline constexpr bool is_tile_window_with_static_lengths_v = + is_tile_window_with_static_lengths::value; + } // namespace ck_tile diff --git a/include/ck_tile/core/tensor/tile_window_linear.hpp b/include/ck_tile/core/tensor/tile_window_linear.hpp index 1e24e660f6..5ecaf5ca17 100644 --- a/include/ck_tile/core/tensor/tile_window_linear.hpp +++ b/include/ck_tile/core/tensor/tile_window_linear.hpp @@ -44,6 +44,7 @@ template struct tile_window_linear { + using BottomTensorView = remove_reference_t; using WindowLengths = remove_cvref_t; using TileDstr = remove_cvref_t; @@ -1215,4 +1216,49 @@ CK_TILE_DEVICE void move_tile_window( window.move(step); } +/** + * @brief Type trait to determine if a type is a linear tile window. + * + * Defaults to `false_type`. Specialized to `true_type` for types that match + * `tile_window_linear<...>`. + * + * @tparam T The type to check. + */ +template +struct is_tile_window_linear : std::false_type +{ +}; + +/** + * @brief Specialization of `is_tile_window_linear` for `tile_window_linear`. + * + * Evaluates to `true_type` if the type is a `tile_window_linear` with the given template + * parameters. + * + * @tparam BottomTensorView_ Bottom tensor view type of the tile window. + * @tparam WindowLengths_ Static window lengths. + * @tparam StaticTileDistribution_ Tile distribution policy. 
+ * @tparam LinearBottomDims_ Dimensions of the bottom tensor view that participate in linearization. + */ +template +struct is_tile_window_linear> : std::true_type +{ +}; + +/** + * @brief Helper variable template to check if a type is a linear tile window. + * + * Equivalent to `is_tile_window_linear::value`. + * + * @tparam T The type to check. + */ +template +inline constexpr bool is_tile_window_linear_v = is_tile_window_linear::value; + } // namespace ck_tile diff --git a/include/ck_tile/ops/gemm/pipeline/gemm_pipeline_ag_bg_cr_comp_v4.hpp b/include/ck_tile/ops/gemm/pipeline/gemm_pipeline_ag_bg_cr_comp_v4.hpp index 667bb80ce9..6535f612f1 100644 --- a/include/ck_tile/ops/gemm/pipeline/gemm_pipeline_ag_bg_cr_comp_v4.hpp +++ b/include/ck_tile/ops/gemm/pipeline/gemm_pipeline_ag_bg_cr_comp_v4.hpp @@ -337,6 +337,12 @@ struct GemmPipelineAgBgCrCompV4 : public BaseGemmPipelineAgBgCrCompV4 {0, 0}, BLdsTileDistr); + static_assert( + !(is_tile_window_linear_v)&&!(is_tile_window_linear_v)&&!( + is_tile_window_linear_v< + decltype(b_lds_ld_window0)>)&&!(is_tile_window_linear_v), + "LDS windows must not be linear"); + Base::LocalPrefetch(a_block_tile0, a_lds_ld_window0); Base::LocalPrefetch(b_block_tile0, b_lds_ld_window0); From 956fe8f75118de688b1ee9ca8619b2c1dbe35ea1 Mon Sep 17 00:00:00 2001 From: kylasa Date: Wed, 7 May 2025 00:02:59 -0700 Subject: [PATCH 099/443] Simple copy kernel, which can be a tool to experiment with CK_Tile API with minimal code. (#2156) * Test Copy kernel code for testing tile distribution logic * Fix the error * Solved the problem * Updated comments and document formatting * Removed unused tile distribution and code cleanup * Added README.md and formatting for CI/CD. 
--------- Co-authored-by: ThomasNing --- example/ck_tile/36_copy/CMakeLists.txt | 4 + example/ck_tile/36_copy/README.md | 31 +++++ example/ck_tile/36_copy/test_copy.cpp | 117 ++++++++++++++++ example/ck_tile/36_copy/test_copy.hpp | 178 +++++++++++++++++++++++++ example/ck_tile/CMakeLists.txt | 1 + 5 files changed, 331 insertions(+) create mode 100644 example/ck_tile/36_copy/CMakeLists.txt create mode 100644 example/ck_tile/36_copy/README.md create mode 100644 example/ck_tile/36_copy/test_copy.cpp create mode 100644 example/ck_tile/36_copy/test_copy.hpp diff --git a/example/ck_tile/36_copy/CMakeLists.txt b/example/ck_tile/36_copy/CMakeLists.txt new file mode 100644 index 0000000000..d1b9ba923c --- /dev/null +++ b/example/ck_tile/36_copy/CMakeLists.txt @@ -0,0 +1,4 @@ +add_executable(test_copy_kernel EXCLUDE_FROM_ALL test_copy.cpp) +target_compile_options(test_copy_kernel PRIVATE + -mllvm -enable-noalias-to-md-conversion=0 +) \ No newline at end of file diff --git a/example/ck_tile/36_copy/README.md b/example/ck_tile/36_copy/README.md new file mode 100644 index 0000000000..7856f0b4bd --- /dev/null +++ b/example/ck_tile/36_copy/README.md @@ -0,0 +1,31 @@ +# Copy Kernel +This folder contains basic setup code designed to provide a platform for novice +CK_Tile kernel developers to test basic functionality with minimal additional +code compared to the functional code. Sample functional code for a simple +tile distribution for DRAM window and LDS window are provided and data is moved +from DRAM to registers, registers to LDS, LDS to registers and finally data +is moved to output DRAM window for a simple copy operation. 
+ +## build +``` +# in the root of ck_tile +mkdir build && cd build +# you can replace with the appropriate architecture +# (for example gfx90a or gfx942) or leave it blank +sh ../script/cmake-ck-dev.sh ../ +# Make the copy kernel executable +make test_copy_kernel -j +``` +This will result in an executable `build/bin/test_copy_kernel` + +## example +``` +args: + -m input matrix rows. (default 64) + -n input matrix cols. (default 8) + -id warp to use for computation. (default 0) + -v validation flag to check device results. (default 1) + -prec datatype precision to use. (default fp16) + -warmup no. of warmup iterations. (default 50) + -repeat no. of iterations for kernel execution time. (default 100) +``` \ No newline at end of file diff --git a/example/ck_tile/36_copy/test_copy.cpp new file mode 100644 index 0000000000..81ea5255fc --- /dev/null +++ b/example/ck_tile/36_copy/test_copy.cpp @@ -0,0 +1,117 @@ +// SPDX-License-Identifier: MIT +// Copyright (c) 2025, Advanced Micro Devices, Inc. All rights reserved.
+ +#include "ck_tile/host.hpp" +#include +#include "test_copy.hpp" + +auto create_args(int argc, char* argv[]) +{ + ck_tile::ArgParser arg_parser; + arg_parser.insert("m", "64", "m dimension") + .insert("n", "8", "n dimension") + .insert("id", "0", "warp to use") + .insert("v", "1", "cpu validation or not") + .insert("prec", "fp16", "precision") + .insert("warmup", "50", "cold iter") + .insert("repeat", "100", "hot iter"); + + bool result = arg_parser.parse(argc, argv); + return std::make_tuple(result, arg_parser); +} + +template +bool run(const ck_tile::ArgParser& arg_parser) +{ + using XDataType = DataType; + using YDataType = DataType; + + ck_tile::index_t m = arg_parser.get_int("m"); + ck_tile::index_t n = arg_parser.get_int("n"); + ck_tile::index_t warp_id = arg_parser.get_int("id"); + int do_validation = arg_parser.get_int("v"); + int warmup = arg_parser.get_int("warmup"); + int repeat = arg_parser.get_int("repeat"); + + ck_tile::HostTensor x_host({m, n}); + ck_tile::HostTensor y_host_ref({m, n}); + ck_tile::HostTensor y_host_dev({m, n}); + + // ck_tile::FillConstant{1.f}(x_host); + ck_tile::half_t value = 1; + for(int i = 0; i < m; i++) + { + value = 1; + for(int j = 0; j < n; j++) + { + x_host(i, j) = value++; + } + } + + ck_tile::DeviceMem x_buf(x_host.get_element_space_size_in_bytes()); + ck_tile::DeviceMem y_buf(y_host_dev.get_element_space_size_in_bytes()); + + x_buf.ToDevice(x_host.data()); + + using BlockWaves = ck_tile::sequence<2, 1>; + using BlockTile = ck_tile::sequence<64, 8>; + using WaveTile = ck_tile::sequence<64, 8>; + using Vector = ck_tile::sequence<1, 4>; + + ck_tile::index_t kGridSize = (m / BlockTile::at(ck_tile::number<0>{})); + std::cout << "grid size " << kGridSize << std::endl; + + using Shape = ck_tile::TileCopyShape; + using Problem = ck_tile::TileCopyProblem; + using Kernel = ck_tile::TileCopy; + + constexpr ck_tile::index_t kBlockSize = 128; + constexpr ck_tile::index_t kBlockPerCu = 1; + std::cout << "block size " << kBlockSize 
<< std::endl; + std::cout << "warp SIze " << ck_tile::get_warp_size() << std::endl; + std::cout << "warps per block _M " << Shape::WarpPerBlock_M << " " << Shape::WarpPerBlock_N + << std::endl; + std::cout << "Block waves: " << BlockWaves::at(ck_tile::number<0>{}) << " " + << BlockWaves::at(ck_tile::number<1>{}) << std::endl; + std::cout << " Wave Groups: " << Shape::WaveGroups << std::endl; + + float ave_time = launch_kernel(ck_tile::stream_config{nullptr, true, 0, warmup, repeat}, + ck_tile::make_kernel( + Kernel{}, + kGridSize, + kBlockSize, + 0, + static_cast(x_buf.GetDeviceBuffer()), + static_cast(y_buf.GetDeviceBuffer()), + m, + n, + warp_id)); + + std::size_t num_btype = sizeof(XDataType) * m * n + sizeof(YDataType) * m; + + float gb_per_sec = num_btype / 1.E6 / ave_time; + std::cout << "Perf: " << ave_time << " ms, " << gb_per_sec << " GB/s" << std::endl; + + bool pass = true; + + if(do_validation) + { + // reference + y_buf.FromDevice(y_host_dev.mData.data()); + pass = ck_tile::check_err(y_host_dev, x_host); + + std::cout << "valid:" << (pass ? "y" : "n") << std::flush << std::endl; + } + + return pass; +} + +int main(int argc, char* argv[]) +{ + auto [result, arg_parser] = create_args(argc, argv); + if(!result) + return -1; + + const std::string data_type = arg_parser.get_str("prec"); + return run(arg_parser) ? 0 : -2; +} diff --git a/example/ck_tile/36_copy/test_copy.hpp b/example/ck_tile/36_copy/test_copy.hpp new file mode 100644 index 0000000000..8fed22a3d0 --- /dev/null +++ b/example/ck_tile/36_copy/test_copy.hpp @@ -0,0 +1,178 @@ +// SPDX-License-Identifier: MIT +// Copyright (c) 2025, Advanced Micro Devices, Inc. All rights reserved. 
+ +#pragma once + +#include "ck_tile/core.hpp" +#include "ck_tile/ops/common.hpp" +#include "ck_tile/ops/common/tensor_layout.hpp" +#include "ck_tile/host.hpp" +#include "ck_tile/host/kernel_launch.hpp" + +namespace ck_tile { + +template + typename BlockTile, // block size, seq + typename WaveTile, // warp size, seq + typename Vector> // contiguous elements(vector size) along seq +struct TileCopyShape +{ + // We split Workgroup waves into two specialized groups. + // One for reading data from global -> LDS, the other is doing reduction + static constexpr index_t WaveGroups = 2; + static constexpr index_t MWarps = BlockWaves::at(number<0>{}); + static constexpr index_t NWarps = BlockWaves::at(number<0>{}); + + static constexpr index_t Block_M = BlockTile::at(number<0>{}); + static constexpr index_t Block_N = BlockTile::at(number<1>{}); + + static constexpr index_t Warp_M = WaveTile::at(number<0>{}); + static constexpr index_t Warp_N = WaveTile::at(number<1>{}); + + static constexpr index_t Vector_M = Vector::at(number<0>{}); + static constexpr index_t Vector_N = Vector::at(number<1>{}); + + static constexpr index_t ThreadPerWarp_M = Warp_M / Vector_M; + static constexpr index_t ThreadPerWarp_N = Warp_N / Vector_N; + + static constexpr index_t WarpPerBlock_M = + integer_divide_ceil(BlockWaves::at(number<0>{}), WaveGroups); + static constexpr index_t WarpPerBlock_N = + integer_divide_ceil(BlockWaves::at(number<1>{}), WaveGroups); + + static constexpr index_t Repeat_M = Block_M / (WarpPerBlock_M * Warp_M); + static constexpr index_t Repeat_N = Block_N / (WarpPerBlock_N * Warp_N); + + static constexpr index_t WaveNum = reduce_on_sequence(BlockWaves{}, multiplies{}, number<1>{}); + + static constexpr index_t BlockSize = get_warp_size() * WaveNum; + static constexpr index_t WaveGroupSize = WaveNum / WaveGroups; + static_assert(WaveGroupSize == WarpPerBlock_M * WarpPerBlock_N, "Inconsisten wave group size!"); +}; + +template +struct TileCopyProblem +{ + using XDataType = 
remove_cvref_t; + using BlockShape = remove_cvref_t; +}; + +template +struct TileCopy +{ + using Problem = ck_tile::remove_cvref_t; + using XDataType = typename Problem::XDataType; + + template + CK_TILE_DEVICE static constexpr auto MakeDRAMDistribution() + { + using S = typename Problem::BlockShape; + + constexpr index_t warp_size = get_warp_size(); + constexpr index_t X0 = S::ThreadPerWarp_N; // threads needed along N dimension, fastest + // changing with given vector size. + constexpr index_t X1 = + S::Vector_N; // no. of elements along N dimensions to be read by each thread. + + constexpr index_t Y0 = + S::WaveNum / S::WaveGroups; // no. of active warps working in this thread block. + constexpr index_t Y1 = warp_size / X0; // no. of threads in a warp needed along M dimension. + constexpr index_t Y2 = + S::Warp_M / + (Y1 * + Y0); // no. of iterations each warp needs to perform to cover the entire tile window. + + constexpr auto outer_encoding = + tile_distribution_encoding, + tuple, sequence>, + tuple, sequence<1, 2>>, + tuple, sequence<0, 0>>, + sequence<1, 2>, + sequence<1, 1>>{}; + return make_static_tile_distribution(outer_encoding); + } + + CK_TILE_DEVICE void + operator()(const XDataType* p_x, XDataType* p_y, index_t M, index_t N, index_t warp_id) const + { + using S = typename Problem::BlockShape; + + // LDS Data. 
+ __shared__ XDataType x_lds[number{} * number{}]; + XDataType* __restrict__ p_x_lds = static_cast(x_lds); + + const auto x_lds_desc = make_naive_tensor_descriptor( + make_tuple(number{}, number{}, number{}), + make_tuple(number{}, number{}, 1), + number{}, + number<1>{}); + + auto x_lds_block_desc = transform_tensor_descriptor( + x_lds_desc, + make_tuple(make_pass_through_transform(number{}), + make_merge_transform( + make_tuple(number{} / S::Vector_N, number{}))), + make_tuple(sequence<1>{}, sequence<0, 2>{}), + make_tuple(sequence<0>{}, sequence<1>{})); + + auto x_lds_view = make_tensor_view(p_x_lds, x_lds_block_desc); + + auto x_block_lds_window = + make_tile_window(x_lds_view, + make_tuple(number{}, number{}), + {0, 0}, + MakeDRAMDistribution()); + auto x_block_lds_window_no_dist = make_tile_window( + x_lds_view, make_tuple(number{}, number{}), {0, 0}); + + // Input tensor + const auto iM = get_block_id() * S::Block_M; + const auto x_m_n = make_naive_tensor_view( + p_x, make_tuple(M, N), make_tuple(N, 1), number{}, number<1>{}); + auto x_block_window = + make_tile_window(x_m_n, + make_tuple(number{}, number{}), + {iM, 0}, + MakeDRAMDistribution()); + + // Output tensor + const auto y_m = make_naive_tensor_view( + p_y, make_tuple(M, N), make_tuple(N, 1), number{}, number<1>{}); + + auto y_block_window = + make_tile_window(y_m, make_tuple(number{}, number{}), {iM, 0}); + + // Programming logic + index_t num_n_tile_iteration = + __builtin_amdgcn_readfirstlane(integer_divide_ceil(N, S::Block_N)); + auto my_id = get_warp_id(); + + auto DramTileDist = x_block_window.get_tile_distribution(); + using dram_reg_tile = decltype(make_static_distributed_tensor(DramTileDist)); + + for(int iN = __builtin_amdgcn_readfirstlane(0); iN < num_n_tile_iteration; ++iN) + { + dram_reg_tile dram_tile; + + if(my_id == warp_id) + { + // load from DRAM to registers + load_tile(dram_tile, x_block_window); + + // store in lds + store_tile(x_block_lds_window_no_dist, dram_tile); + + // read 
from lds to registers + load_tile(dram_tile, x_block_lds_window); + + // store from registers to DRAM + store_tile(y_block_window, dram_tile); + } + __syncthreads(); + move_tile_window(x_block_window, {0, S::Block_N}); + move_tile_window(y_block_window, {0, S::Block_N}); + } + } +}; + +} // namespace ck_tile diff --git a/example/ck_tile/CMakeLists.txt b/example/ck_tile/CMakeLists.txt index 88efe0d8d9..d479cd35f6 100644 --- a/example/ck_tile/CMakeLists.txt +++ b/example/ck_tile/CMakeLists.txt @@ -19,3 +19,4 @@ add_subdirectory(16_batched_gemm) add_subdirectory(17_grouped_gemm) add_subdirectory(18_flatmm) add_subdirectory(35_batched_transpose) +add_subdirectory(36_copy) From 397b9080a217633f3f35d632329b16f4fababdf2 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Bart=C5=82omiej=20Kocot?= Date: Wed, 7 May 2025 17:04:31 +0200 Subject: [PATCH 100/443] Move 16x16 grouped conv fwd instances from comp header (#2165) * Move 16x16 grouped conv fwd instances from comp header * Improvements --- ...ice_grouped_conv_fwd_xdl_comp_instance.hpp | 21 +-- .../device_grouped_conv_fwd_xdl_instance.hpp | 57 ++++++ .../gpu/grouped_convolution_forward.hpp | 14 ++ .../gpu/grouped_convolution_forward_xdl.inc | 168 ++++++++++++++++++ .../gpu/grouped_conv2d_fwd/CMakeLists.txt | 6 + ..._ngchw_gkcyx_ngkhw_bf16_16x16_instance.cpp | 55 ++++++ ...wd_xdl_ngchw_gkcyx_ngkhw_bf16_instance.cpp | 14 ++ ...l_ngchw_gkcyx_ngkhw_f16_16x16_instance.cpp | 54 ++++++ ...fwd_xdl_ngchw_gkcyx_ngkhw_f16_instance.cpp | 14 ++ ...l_ngchw_gkcyx_ngkhw_f32_16x16_instance.cpp | 54 ++++++ ...fwd_xdl_ngchw_gkcyx_ngkhw_f32_instance.cpp | 14 ++ ..._nhwgc_gkyxc_nhwgk_bf16_16x16_instance.cpp | 57 ++++++ ...l_nhwgc_gkyxc_nhwgk_f16_16x16_instance.cpp | 56 ++++++ ...l_nhwgc_gkyxc_nhwgk_f32_16x16_instance.cpp | 56 ++++++ .../gpu/grouped_conv3d_fwd/CMakeLists.txt | 6 + ...hwgc_gkzyxc_ndhwgk_bf16_16x16_instance.cpp | 55 ++++++ ...dhwgc_gkzyxc_ndhwgk_f16_16x16_instance.cpp | 54 ++++++ ...dhwgc_gkzyxc_ndhwgk_f32_16x16_instance.cpp | 54 ++++++ 
...cdhw_gkczyx_ngkdhw_bf16_16x16_instance.cpp | 56 ++++++ ...gcdhw_gkczyx_ngkdhw_f16_16x16_instance.cpp | 55 ++++++ ...gcdhw_gkczyx_ngkdhw_f32_16x16_instance.cpp | 55 ++++++ 21 files changed, 957 insertions(+), 18 deletions(-) create mode 100644 library/src/tensor_operation_instance/gpu/grouped_conv2d_fwd/xdl/device_grouped_conv2d_fwd_xdl_ngchw_gkcyx_ngkhw_bf16_16x16_instance.cpp create mode 100644 library/src/tensor_operation_instance/gpu/grouped_conv2d_fwd/xdl/device_grouped_conv2d_fwd_xdl_ngchw_gkcyx_ngkhw_f16_16x16_instance.cpp create mode 100644 library/src/tensor_operation_instance/gpu/grouped_conv2d_fwd/xdl/device_grouped_conv2d_fwd_xdl_ngchw_gkcyx_ngkhw_f32_16x16_instance.cpp create mode 100644 library/src/tensor_operation_instance/gpu/grouped_conv2d_fwd/xdl/device_grouped_conv2d_fwd_xdl_nhwgc_gkyxc_nhwgk_bf16_16x16_instance.cpp create mode 100644 library/src/tensor_operation_instance/gpu/grouped_conv2d_fwd/xdl/device_grouped_conv2d_fwd_xdl_nhwgc_gkyxc_nhwgk_f16_16x16_instance.cpp create mode 100644 library/src/tensor_operation_instance/gpu/grouped_conv2d_fwd/xdl/device_grouped_conv2d_fwd_xdl_nhwgc_gkyxc_nhwgk_f32_16x16_instance.cpp create mode 100644 library/src/tensor_operation_instance/gpu/grouped_conv3d_fwd/xdl/device_grouped_conv3d_fwd_xdl_ndhwgc_gkzyxc_ndhwgk_bf16_16x16_instance.cpp create mode 100644 library/src/tensor_operation_instance/gpu/grouped_conv3d_fwd/xdl/device_grouped_conv3d_fwd_xdl_ndhwgc_gkzyxc_ndhwgk_f16_16x16_instance.cpp create mode 100644 library/src/tensor_operation_instance/gpu/grouped_conv3d_fwd/xdl/device_grouped_conv3d_fwd_xdl_ndhwgc_gkzyxc_ndhwgk_f32_16x16_instance.cpp create mode 100644 library/src/tensor_operation_instance/gpu/grouped_conv3d_fwd/xdl/device_grouped_conv3d_fwd_xdl_ngcdhw_gkczyx_ngkdhw_bf16_16x16_instance.cpp create mode 100644 library/src/tensor_operation_instance/gpu/grouped_conv3d_fwd/xdl/device_grouped_conv3d_fwd_xdl_ngcdhw_gkczyx_ngkdhw_f16_16x16_instance.cpp create mode 100644 
library/src/tensor_operation_instance/gpu/grouped_conv3d_fwd/xdl/device_grouped_conv3d_fwd_xdl_ngcdhw_gkczyx_ngkdhw_f32_16x16_instance.cpp diff --git a/library/include/ck/library/tensor_operation_instance/gpu/grouped_conv_fwd/device_grouped_conv_fwd_xdl_comp_instance.hpp b/library/include/ck/library/tensor_operation_instance/gpu/grouped_conv_fwd/device_grouped_conv_fwd_xdl_comp_instance.hpp index 6c0ba2f932..158ed26ec4 100644 --- a/library/include/ck/library/tensor_operation_instance/gpu/grouped_conv_fwd/device_grouped_conv_fwd_xdl_comp_instance.hpp +++ b/library/include/ck/library/tensor_operation_instance/gpu/grouped_conv_fwd/device_grouped_conv_fwd_xdl_comp_instance.hpp @@ -4,7 +4,6 @@ #include "ck/ck.hpp" #include "ck/tensor_operation/gpu/device/tensor_layout.hpp" #include "ck/tensor_operation/gpu/device/impl/device_grouped_conv_fwd_multiple_abd_xdl_cshuffle_v3.hpp" -#include "ck/tensor_operation/gpu/device/impl/device_grouped_conv_fwd_multiple_abd_xdl_cshuffle.hpp" #include "ck/tensor_operation/gpu/device/convolution_forward_specialization.hpp" #include "ck/tensor_operation/gpu/device/gemm_specialization.hpp" #include "ck/tensor_operation/gpu/element/element_wise_operation.hpp" @@ -90,12 +89,7 @@ using device_grouped_conv_fwd_xdl_bf16_comp_instances = std::tuple< DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle_V3, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 0, S<8, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 0, 1, 1, S<1, 32, 1, 8>, 8, BlockGemmPipelineScheduler::Interwave, BlockGemmPipelineVersion::v1>, DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle_V3, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 0, S<8, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 0, 1, 1, S<1, 32, 1, 8>, 8, BlockGemmPipelineScheduler::Intrawave, BlockGemmPipelineVersion::v3>, DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle_V3, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 0, S<8, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 0, 1, 1, S<1, 32, 1, 8>, 8, BlockGemmPipelineScheduler::Intrawave, BlockGemmPipelineVersion::v3>, - 
DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle_V3, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 0, S<8, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 0, 1, 1, S<1, 32, 1, 8>, 8, BlockGemmPipelineScheduler::Intrawave, BlockGemmPipelineVersion::v3>, - // mfma 16x16 - DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle< NDimSpatial,ALayout,BLayout, DsLayout,ELayout, BF16, BF16, F32, BF16, DsLayout, BF16, PassThrough, PassThrough, PassThrough, ConvSpec, GemmMNKPadding, 1,256, 64, 64, 32, 8, 8, 16, 16, 2, 2, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 1, 8, 1, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 2, 8, 1, 1, 1, S<1, 32, 1, 4>, 1>, - DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle< NDimSpatial,ALayout,BLayout, DsLayout,ELayout, BF16, BF16, F32, BF16, DsLayout, BF16, PassThrough, PassThrough, PassThrough, ConvSpec, GemmMNKPadding, 1,256, 64, 64, 32, 8, 8, 16, 16, 2, 2, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 2, 8, 1, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 1, 8, 1, 1, 1, S<1, 32, 1, 4>, 2>, - DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle< NDimSpatial,ALayout,BLayout, DsLayout,ELayout, BF16, BF16, F32, BF16, DsLayout, BF16, PassThrough, PassThrough, PassThrough, ConvSpec, GemmMNKPadding, 1,256, 64, 64, 32, 8, 8, 16, 16, 2, 2, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 4, 8, 1, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 4, 8, 1, 1, 1, S<1, 32, 1, 4>, 4>, - DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle< NDimSpatial,ALayout,BLayout, DsLayout,ELayout, BF16, BF16, F32, BF16, DsLayout, BF16, PassThrough, PassThrough, PassThrough, ConvSpec, GemmMNKPadding, 1,256, 64, 64, 32, 8, 8, 16, 16, 2, 2, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, 1, 1, S<1, 32, 1, 4>, 8> + DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle_V3, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 0, S<8, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 0, 1, 1, S<1, 32, 1, 8>, 8, BlockGemmPipelineScheduler::Intrawave, BlockGemmPipelineVersion::v3> // clang-format on >; @@ -146,12 +140,7 @@ using 
device_grouped_conv_fwd_xdl_f16_comp_instances = std::tuple< //########################################| Spatial| Layout| Layout| Layout| Layout| Type| Type| Type| DataType| DataType| Type| Elementwise| Elementwise| Elementwise| Specialization| Specialization| Prefetch| Size| Block| Block| Block| | | XDL| XDL| Per| Per| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraM| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraN| MXdlPerWave| NXdlPerWave| _MBlock_MWaveMPerXdl| ScalarPerVector| //########################################| | | | | | | | | | | | Operation| Operation| Operation| | | Stage| | | | | | | | | Wave| Wave| Lengths_K0_M_K1| ArrangeOrder| | | PerVector| PerVector_K1| | Lengths_K0_N_K1| ArrangeOrder| | | PerVector| PerVector_K1| | PerShuffle| PerShuffle| _NBlock_NWaveNPerXdl| _NWaveNPerXdl| //########################################| | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | - DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle_V3, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 0, S<8, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 0, 1, 1, S<1, 32, 1, 8>, 8, BlockGemmPipelineScheduler::Intrawave, BlockGemmPipelineVersion::v4>, - // mfma 16x16 - DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle< NDimSpatial,ALayout,BLayout, DsLayout,ELayout, F16, F16, F32, F16, DsLayout, F16, PassThrough, PassThrough, PassThrough, ConvSpec, GemmMNKPadding, 1,256, 64, 64, 32, 8, 8, 16, 16, 2, 2, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 1, 8, 1, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 2, 8, 1, 1, 1, S<1, 32, 1, 4>, 1>, - DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle< NDimSpatial,ALayout,BLayout, DsLayout,ELayout, F16, F16, F32, F16, DsLayout, F16, PassThrough, PassThrough, PassThrough, ConvSpec, GemmMNKPadding, 1,256, 64, 64, 32, 8, 8, 16, 16, 2, 2, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 2, 8, 1, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 1, 8, 1, 1, 1, S<1, 32, 1, 4>, 2>, - 
DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle< NDimSpatial,ALayout,BLayout, DsLayout,ELayout, F16, F16, F32, F16, DsLayout, F16, PassThrough, PassThrough, PassThrough, ConvSpec, GemmMNKPadding, 1,256, 64, 64, 32, 8, 8, 16, 16, 2, 2, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 4, 8, 1, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 4, 8, 1, 1, 1, S<1, 32, 1, 4>, 4>, - DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle< NDimSpatial,ALayout,BLayout, DsLayout,ELayout, F16, F16, F32, F16, DsLayout, F16, PassThrough, PassThrough, PassThrough, ConvSpec, GemmMNKPadding, 1,256, 64, 64, 32, 8, 8, 16, 16, 2, 2, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, 1, 1, S<1, 32, 1, 4>, 8> + DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle_V3, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 0, S<8, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 0, 1, 1, S<1, 32, 1, 8>, 8, BlockGemmPipelineScheduler::Intrawave, BlockGemmPipelineVersion::v4> // clang-format on >; @@ -195,11 +184,7 @@ using device_grouped_conv_fwd_xdl_f32_comp_instances = std::tuple< DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle_V3, S<1, 0, 2>, S<1, 0, 2>, 2, 4, 4, 0, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 4, 4, 0, 1, 1, S<1, 32, 1, 8>, 8, BlockGemmPipelineScheduler::Intrawave, BlockGemmPipelineVersion::v4>, DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle_V3, S<1, 0, 2>, S<1, 0, 2>, 2, 4, 4, 0, S<8, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 4, 4, 0, 1, 1, S<1, 32, 1, 8>, 8, BlockGemmPipelineScheduler::Intrawave, BlockGemmPipelineVersion::v3>, DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle_V3, S<1, 0, 2>, S<1, 0, 2>, 2, 4, 4, 0, S<8, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 4, 4, 0, 1, 1, S<1, 32, 1, 8>, 8, BlockGemmPipelineScheduler::Intrawave, BlockGemmPipelineVersion::v5>, - DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle_V3, S<1, 0, 2>, S<1, 0, 2>, 2, 4, 4, 0, S<8, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 4, 4, 0, 1, 1, S<1, 32, 1, 8>, 8, BlockGemmPipelineScheduler::Interwave, BlockGemmPipelineVersion::v1>, - // mfma 16x16 - 
DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle< NDimSpatial,ALayout,BLayout, DsLayout,ELayout, F32, F32, F32, F32, DsLayout, F32, PassThrough, PassThrough, PassThrough, ConvSpec, GemmMNKPadding,1,256, 64, 64, 32, 8, 8, 16, 16, 2, 2, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 1, 8, 1, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 2, 8, 1, 1, 1, S<1, 32, 1, 4>, 1>, - DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle< NDimSpatial,ALayout,BLayout, DsLayout,ELayout, F32, F32, F32, F32, DsLayout, F32, PassThrough, PassThrough, PassThrough, ConvSpec, GemmMNKPadding,1,256, 64, 64, 32, 8, 8, 16, 16, 2, 2, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 2, 8, 1, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 1, 8, 1, 1, 1, S<1, 32, 1, 4>, 2>, - DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle< NDimSpatial,ALayout,BLayout, DsLayout,ELayout, F32, F32, F32, F32, DsLayout, F32, PassThrough, PassThrough, PassThrough, ConvSpec, GemmMNKPadding,1,256, 64, 64, 32, 8, 8, 16, 16, 2, 2, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 4, 8, 1, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 4, 8, 1, 1, 1, S<1, 32, 1, 4>, 4> + DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle_V3, S<1, 0, 2>, S<1, 0, 2>, 2, 4, 4, 0, S<8, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 4, 4, 0, 1, 1, S<1, 32, 1, 8>, 8, BlockGemmPipelineScheduler::Interwave, BlockGemmPipelineVersion::v1> // clang-format on >; diff --git a/library/include/ck/library/tensor_operation_instance/gpu/grouped_conv_fwd/device_grouped_conv_fwd_xdl_instance.hpp b/library/include/ck/library/tensor_operation_instance/gpu/grouped_conv_fwd/device_grouped_conv_fwd_xdl_instance.hpp index c9ea462316..f5397308dc 100644 --- a/library/include/ck/library/tensor_operation_instance/gpu/grouped_conv_fwd/device_grouped_conv_fwd_xdl_instance.hpp +++ b/library/include/ck/library/tensor_operation_instance/gpu/grouped_conv_fwd/device_grouped_conv_fwd_xdl_instance.hpp @@ -97,6 +97,25 @@ using device_grouped_conv_fwd_xdl_bf16_instances = std::tuple< // clang-format on >; +template +using 
device_grouped_conv_fwd_xdl_bf16_16x16_instances = std::tuple< + // clang-format off + //########################################| NumDim| A| B| Ds| E| AData| BData| AccData| CShuffle| Ds| EData| A| B| CDE| ConvForward| GEMM| NumGemmK| Block| MPer| NPer| KPer| AK1| BK1| MPer| NPer| MXdl| NXdl| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockLds| BBlockTransfer| BBlockTransfer| BBlockTransfer| BlockTransfer| BBlockTransfer| BBlockTransfer| BBlockLds| CShuffle| CShuffle| CBlockTransferClusterLengths| CBlockTransfer| + //########################################| Spatial| Layout| Layout| Layout| Layout| Type| Type| Type| DataType| DataType| Type| Elementwise| Elementwise| Elementwise| Specialization| Specialization| Prefetch| Size| Block| Block| Block| | | XDL| XDL| Per| Per| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraM| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraN| MXdlPerWave| NXdlPerWave| _MBlock_MWaveMPerXdl| ScalarPerVector| + //########################################| | | | | | | | | | | | Operation| Operation| Operation| | | Stage| | | | | | | | | Wave| Wave| Lengths_K0_M_K1| ArrangeOrder| | | PerVector| PerVector_K1| | Lengths_K0_N_K1| ArrangeOrder| | | PerVector| PerVector_K1| | PerShuffle| PerShuffle| _NBlock_NWaveNPerXdl| _NWaveNPerXdl| + //########################################| | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | + DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle, S<1, 0, 2>, S<1, 0, 2>, 2, 1, 8, 1, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 2, 8, 1, 1, 1, S<1, 32, 1, 4>, 1>, + DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle, S<1, 0, 2>, S<1, 0, 2>, 2, 2, 8, 1, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 1, 8, 1, 1, 1, S<1, 32, 1, 4>, 2>, + DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle, S<1, 0, 2>, S<1, 0, 2>, 2, 4, 8, 1, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 4, 8, 1, 
1, 1, S<1, 32, 1, 4>, 4>, + DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, 1, 1, S<1, 32, 1, 4>, 8> + // clang-format on + >; + template ; +template +using device_grouped_conv_fwd_xdl_f16_16x16_instances = std::tuple< + // clang-format off + //########################################| NumDim| A| B| Ds| E| AData| BData| AccData| CShuffle| Ds| EData| A| B| CDE| ConvForward| GEMM| NumGemmK| Block| MPer| NPer| KPer| AK1| BK1| MPer| NPer| MXdl| NXdl| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockLds| BBlockTransfer| BBlockTransfer| BBlockTransfer| BlockTransfer| BBlockTransfer| BBlockTransfer| BBlockLds| CShuffle| CShuffle| CBlockTransferClusterLengths| CBlockTransfer| + //########################################| Spatial| Layout| Layout| Layout| Layout| Type| Type| Type| DataType| DataType| Type| Elementwise| Elementwise| Elementwise| Specialization| Specialization| Prefetch| Size| Block| Block| Block| | | XDL| XDL| Per| Per| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraM| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraN| MXdlPerWave| NXdlPerWave| _MBlock_MWaveMPerXdl| ScalarPerVector| + //########################################| | | | | | | | | | | | Operation| Operation| Operation| | | Stage| | | | | | | | | Wave| Wave| Lengths_K0_M_K1| ArrangeOrder| | | PerVector| PerVector_K1| | Lengths_K0_N_K1| ArrangeOrder| | | PerVector| PerVector_K1| | PerShuffle| PerShuffle| _NBlock_NWaveNPerXdl| _NWaveNPerXdl| + //########################################| | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | + DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle, S<1, 0, 2>, S<1, 0, 2>, 2, 1, 8, 1, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 2, 8, 1, 1, 1, S<1, 32, 1, 4>, 1>, + DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle, S<1, 
0, 2>, S<1, 0, 2>, 2, 2, 8, 1, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 1, 8, 1, 1, 1, S<1, 32, 1, 4>, 2>, + DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle, S<1, 0, 2>, S<1, 0, 2>, 2, 4, 8, 1, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 4, 8, 1, 1, 1, S<1, 32, 1, 4>, 4>, + DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, 1, 1, S<1, 32, 1, 4>, 8> + // clang-format on + >; + template ; +template +using device_grouped_conv_fwd_xdl_f32_16x16_instances = std::tuple< + // clang-format off + //########################################| NumDim| A| B| Ds| E| AData| BData| AccData| CShuffle| Ds| EData| A| B| CDE| ConvForward| GEMM| NumGemmK| Block| MPer| NPer| KPer| AK1| BK1| MPer| NPer| MXdl| NXdl| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockLds| BBlockTransfer| BBlockTransfer| BBlockTransfer| BlockTransfer| BBlockTransfer| BBlockTransfer| BBlockLds| CShuffle| CShuffle| CBlockTransferClusterLengths| CBlockTransfer| + //########################################| Spatial| Layout| Layout| Layout| Layout| Type| Type| Type| DataType| DataType| Type| Elementwise| Elementwise| Elementwise| Specialization| Specialization| Prefetch| Size| Block| Block| Block| | | XDL| XDL| Per| Per| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraM| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraN| MXdlPerWave| NXdlPerWave| _MBlock_MWaveMPerXdl| ScalarPerVector| + //########################################| | | | | | | | | | | | Operation| Operation| Operation| | | Stage| | | | | | | | | Wave| Wave| Lengths_K0_M_K1| ArrangeOrder| | | PerVector| PerVector_K1| | Lengths_K0_N_K1| ArrangeOrder| | | PerVector| PerVector_K1| | PerShuffle| PerShuffle| _NBlock_NWaveNPerXdl| _NWaveNPerXdl| + //########################################| | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | 
| | | | | | | | | | | | | | + // generic instance + DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle, S<1, 0, 2>, S<1, 0, 2>, 2, 1, 8, 1, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 2, 8, 1, 1, 1, S<1, 32, 1, 4>, 1>, + DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle, S<1, 0, 2>, S<1, 0, 2>, 2, 2, 8, 1, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 1, 8, 1, 1, 1, S<1, 32, 1, 4>, 2>, + DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle, S<1, 0, 2>, S<1, 0, 2>, 2, 4, 8, 1, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 4, 8, 1, 1, 1, S<1, 32, 1, 4>, 4> + // clang-format on + >; + template ) { add_device_grouped_conv2d_fwd_xdl_nhwgc_gkyxc_nhwgk_f32_instances(op_ptrs); + add_device_grouped_conv2d_fwd_xdl_nhwgc_gkyxc_nhwgk_f32_16x16_instances(op_ptrs); add_device_grouped_conv2d_fwd_xdl_large_tensor_nhwgc_gkyxc_nhwgk_f32_instances( op_ptrs); add_device_grouped_conv2d_fwd_xdl_merged_groups_nhwgc_gkyxc_nhwgk_f32_instances( @@ -221,6 +222,7 @@ struct DeviceOperationInstanceFactory) { add_device_grouped_conv2d_fwd_xdl_nhwgc_gkyxc_nhwgk_f16_instances(op_ptrs); + add_device_grouped_conv2d_fwd_xdl_nhwgc_gkyxc_nhwgk_f16_16x16_instances(op_ptrs); add_device_grouped_conv2d_fwd_xdl_large_tensor_nhwgc_gkyxc_nhwgk_f16_instances( op_ptrs); add_device_grouped_conv2d_fwd_xdl_merged_groups_nhwgc_gkyxc_nhwgk_f16_instances( @@ -243,6 +245,7 @@ struct DeviceOperationInstanceFactory) { add_device_grouped_conv2d_fwd_xdl_nhwgc_gkyxc_nhwgk_bf16_instances(op_ptrs); + add_device_grouped_conv2d_fwd_xdl_nhwgc_gkyxc_nhwgk_bf16_16x16_instances(op_ptrs); add_device_grouped_conv2d_fwd_xdl_large_tensor_nhwgc_gkyxc_nhwgk_bf16_instances( op_ptrs); add_device_grouped_conv2d_fwd_xdl_merged_groups_nhwgc_gkyxc_nhwgk_bf16_instances( @@ -288,6 +291,7 @@ struct DeviceOperationInstanceFactory) { add_device_grouped_conv3d_fwd_xdl_ndhwgc_gkzyxc_ndhwgk_f32_instances(op_ptrs); + add_device_grouped_conv3d_fwd_xdl_ndhwgc_gkzyxc_ndhwgk_f32_16x16_instances(op_ptrs); add_device_grouped_conv3d_fwd_xdl_large_tensor_ndhwgc_gkzyxc_ndhwgk_f32_instances( 
op_ptrs); add_device_grouped_conv3d_fwd_xdl_merged_groups_ndhwgc_gkzyxc_ndhwgk_f32_instances( @@ -484,6 +491,7 @@ struct DeviceOperationInstanceFactory) { add_device_grouped_conv3d_fwd_xdl_ndhwgc_gkzyxc_ndhwgk_f16_instances(op_ptrs); + add_device_grouped_conv3d_fwd_xdl_ndhwgc_gkzyxc_ndhwgk_f16_16x16_instances(op_ptrs); add_device_grouped_conv3d_fwd_xdl_large_tensor_ndhwgc_gkzyxc_ndhwgk_f16_instances( op_ptrs); add_device_grouped_conv3d_fwd_xdl_merged_groups_ndhwgc_gkzyxc_ndhwgk_f16_instances( @@ -503,6 +511,8 @@ struct DeviceOperationInstanceFactory) { add_device_grouped_conv3d_fwd_xdl_ndhwgc_gkzyxc_ndhwgk_bf16_instances(op_ptrs); + add_device_grouped_conv3d_fwd_xdl_ndhwgc_gkzyxc_ndhwgk_bf16_16x16_instances( + op_ptrs); add_device_grouped_conv3d_fwd_xdl_large_tensor_ndhwgc_gkzyxc_ndhwgk_bf16_instances( op_ptrs); add_device_grouped_conv3d_fwd_xdl_merged_groups_ndhwgc_gkzyxc_ndhwgk_bf16_instances( @@ -536,6 +546,7 @@ struct DeviceOperationInstanceFactory>>& instances); + +void add_device_grouped_conv2d_fwd_xdl_nhwgc_gkyxc_nhwgk_bf16_16x16_instances( + std::vector>>& instances); #endif #ifdef CK_ENABLE_FP16 @@ -153,6 +167,20 @@ void add_device_grouped_conv2d_fwd_xdl_nhwgc_gkyxc_nhwgk_f16_instances( PassThrough, PassThrough, PassThrough>>>& instances); + +void add_device_grouped_conv2d_fwd_xdl_nhwgc_gkyxc_nhwgk_f16_16x16_instances( + std::vector>>& instances); #endif #ifdef CK_ENABLE_FP32 @@ -169,6 +197,20 @@ void add_device_grouped_conv2d_fwd_xdl_nhwgc_gkyxc_nhwgk_f32_instances( PassThrough, PassThrough, PassThrough>>>& instances); + +void add_device_grouped_conv2d_fwd_xdl_nhwgc_gkyxc_nhwgk_f32_16x16_instances( + std::vector>>& instances); #endif #ifdef CK_ENABLE_INT8 @@ -267,6 +309,20 @@ void add_device_grouped_conv2d_fwd_xdl_ngchw_gkcyx_ngkhw_f16_instances( PassThrough, PassThrough, PassThrough>>>& instances); + +void add_device_grouped_conv2d_fwd_xdl_ngchw_gkcyx_ngkhw_f16_16x16_instances( + std::vector>>& instances); #endif #ifdef CK_ENABLE_BF16 @@ -283,6 +339,20 
@@ void add_device_grouped_conv2d_fwd_xdl_ngchw_gkcyx_ngkhw_bf16_instances( PassThrough, PassThrough, PassThrough>>>& instances); + +void add_device_grouped_conv2d_fwd_xdl_ngchw_gkcyx_ngkhw_bf16_16x16_instances( + std::vector>>& instances); #endif #ifdef CK_ENABLE_FP32 @@ -299,6 +369,20 @@ void add_device_grouped_conv2d_fwd_xdl_ngchw_gkcyx_ngkhw_f32_instances( PassThrough, PassThrough, PassThrough>>>& instances); + +void add_device_grouped_conv2d_fwd_xdl_ngchw_gkcyx_ngkhw_f32_16x16_instances( + std::vector>>& instances); #endif #ifdef CK_ENABLE_BF16 @@ -382,6 +466,20 @@ void add_device_grouped_conv3d_fwd_xdl_ndhwgc_gkzyxc_ndhwgk_bf16_instances( PassThrough, PassThrough, PassThrough>>>& instances); + +void add_device_grouped_conv3d_fwd_xdl_ndhwgc_gkzyxc_ndhwgk_bf16_16x16_instances( + std::vector>>& instances); #endif #ifdef CK_ENABLE_FP16 @@ -398,6 +496,20 @@ void add_device_grouped_conv3d_fwd_xdl_ndhwgc_gkzyxc_ndhwgk_f16_instances( PassThrough, PassThrough, PassThrough>>>& instances); + +void add_device_grouped_conv3d_fwd_xdl_ndhwgc_gkzyxc_ndhwgk_f16_16x16_instances( + std::vector>>& instances); #endif #ifdef CK_ENABLE_FP8 @@ -446,6 +558,20 @@ void add_device_grouped_conv3d_fwd_xdl_ndhwgc_gkzyxc_ndhwgk_f32_instances( PassThrough, PassThrough, PassThrough>>>& instances); + +void add_device_grouped_conv3d_fwd_xdl_ndhwgc_gkzyxc_ndhwgk_f32_16x16_instances( + std::vector>>& instances); #endif #ifdef CK_ENABLE_INT8 @@ -532,6 +658,20 @@ void add_device_grouped_conv3d_fwd_xdl_ngcdhw_gkczyx_ngkdhw_f16_instances( PassThrough, PassThrough, PassThrough>>>& instances); + +void add_device_grouped_conv3d_fwd_xdl_ngcdhw_gkczyx_ngkdhw_f16_16x16_instances( + std::vector>>& instances); #endif #ifdef CK_ENABLE_BF16 @@ -548,6 +688,20 @@ void add_device_grouped_conv3d_fwd_xdl_ngcdhw_gkczyx_ngkdhw_bf16_instances( PassThrough, PassThrough, PassThrough>>>& instances); + +void add_device_grouped_conv3d_fwd_xdl_ngcdhw_gkczyx_ngkdhw_bf16_16x16_instances( + std::vector>>& instances); #endif 
#ifdef CK_ENABLE_FP32 @@ -564,6 +718,20 @@ void add_device_grouped_conv3d_fwd_xdl_ngcdhw_gkczyx_ngkdhw_f32_instances( PassThrough, PassThrough, PassThrough>>>& instances); + +void add_device_grouped_conv3d_fwd_xdl_ngcdhw_gkczyx_ngkdhw_f32_16x16_instances( + std::vector>>& instances); #endif } // namespace instance diff --git a/library/src/tensor_operation_instance/gpu/grouped_conv2d_fwd/CMakeLists.txt b/library/src/tensor_operation_instance/gpu/grouped_conv2d_fwd/CMakeLists.txt index 3a101baac0..eba6fd789e 100644 --- a/library/src/tensor_operation_instance/gpu/grouped_conv2d_fwd/CMakeLists.txt +++ b/library/src/tensor_operation_instance/gpu/grouped_conv2d_fwd/CMakeLists.txt @@ -9,6 +9,9 @@ add_instance_library(device_grouped_conv2d_fwd_instance xdl/device_grouped_conv2d_fwd_xdl_nhwgc_gkyxc_nhwgk_bf16_instance.cpp xdl/device_grouped_conv2d_fwd_xdl_nhwgc_gkyxc_nhwgk_f16_instance.cpp xdl/device_grouped_conv2d_fwd_xdl_nhwgc_gkyxc_nhwgk_f32_instance.cpp + xdl/device_grouped_conv2d_fwd_xdl_nhwgc_gkyxc_nhwgk_bf16_16x16_instance.cpp + xdl/device_grouped_conv2d_fwd_xdl_nhwgc_gkyxc_nhwgk_f16_16x16_instance.cpp + xdl/device_grouped_conv2d_fwd_xdl_nhwgc_gkyxc_nhwgk_f32_16x16_instance.cpp xdl/device_grouped_conv2d_fwd_xdl_nhwgc_gkyxc_nhwgk_int8_instance.cpp # NGCHW, GKYXC, NGKHW xdl/device_grouped_conv2d_fwd_xdl_ngchw_gkyxc_ngkhw_bf16_instance.cpp @@ -19,6 +22,9 @@ add_instance_library(device_grouped_conv2d_fwd_instance xdl/device_grouped_conv2d_fwd_xdl_ngchw_gkcyx_ngkhw_bf16_instance.cpp xdl/device_grouped_conv2d_fwd_xdl_ngchw_gkcyx_ngkhw_f16_instance.cpp xdl/device_grouped_conv2d_fwd_xdl_ngchw_gkcyx_ngkhw_f32_instance.cpp + xdl/device_grouped_conv2d_fwd_xdl_ngchw_gkcyx_ngkhw_bf16_16x16_instance.cpp + xdl/device_grouped_conv2d_fwd_xdl_ngchw_gkcyx_ngkhw_f16_16x16_instance.cpp + xdl/device_grouped_conv2d_fwd_xdl_ngchw_gkcyx_ngkhw_f32_16x16_instance.cpp # large tensor # NHWGC, GKYXC, NHWGK 
xdl/large_tensor/device_grouped_conv2d_fwd_xdl_large_tensor_nhwgc_gkyxc_nhwgk_bf16_instance.cpp diff --git a/library/src/tensor_operation_instance/gpu/grouped_conv2d_fwd/xdl/device_grouped_conv2d_fwd_xdl_ngchw_gkcyx_ngkhw_bf16_16x16_instance.cpp b/library/src/tensor_operation_instance/gpu/grouped_conv2d_fwd/xdl/device_grouped_conv2d_fwd_xdl_ngchw_gkcyx_ngkhw_bf16_16x16_instance.cpp new file mode 100644 index 0000000000..0843325287 --- /dev/null +++ b/library/src/tensor_operation_instance/gpu/grouped_conv2d_fwd/xdl/device_grouped_conv2d_fwd_xdl_ngchw_gkcyx_ngkhw_bf16_16x16_instance.cpp @@ -0,0 +1,55 @@ +// SPDX-License-Identifier: MIT +// Copyright (c) 2025, Advanced Micro Devices, Inc. All rights reserved. + +#include "ck/library/tensor_operation_instance/add_device_operation_instance.hpp" +#include "ck/library/tensor_operation_instance/gpu/grouped_conv_fwd/device_grouped_conv_fwd_xdl_instance.hpp" + +namespace ck { +namespace tensor_operation { +namespace device { +namespace instance { +// Compilation parameters for in[n, hi, wi, g, c] * wei[g, k, y, x, c] = out[n, ho, wo, g, k] +void add_device_grouped_conv2d_fwd_xdl_ngchw_gkcyx_ngkhw_bf16_16x16_instances( + std::vector>>& instances) +{ + add_device_operation_instances( + instances, + device_grouped_conv_fwd_xdl_bf16_16x16_instances<2, + NGCHW, + GKCYX, + Empty_Tuple, + NGKHW, + ConvFwdDefault>{}); + add_device_operation_instances( + instances, + device_grouped_conv_fwd_xdl_bf16_16x16_instances<2, + NGCHW, + GKCYX, + Empty_Tuple, + NGKHW, + ConvFwd1x1P0>{}); + add_device_operation_instances( + instances, + device_grouped_conv_fwd_xdl_bf16_16x16_instances<2, + NGCHW, + GKCYX, + Empty_Tuple, + NGKHW, + ConvFwd1x1S1P0>{}); +} + +} // namespace instance +} // namespace device +} // namespace tensor_operation +} // namespace ck diff --git a/library/src/tensor_operation_instance/gpu/grouped_conv2d_fwd/xdl/device_grouped_conv2d_fwd_xdl_ngchw_gkcyx_ngkhw_bf16_instance.cpp 
b/library/src/tensor_operation_instance/gpu/grouped_conv2d_fwd/xdl/device_grouped_conv2d_fwd_xdl_ngchw_gkcyx_ngkhw_bf16_instance.cpp index 6c5d9b5b94..4ca1b2b85e 100644 --- a/library/src/tensor_operation_instance/gpu/grouped_conv2d_fwd/xdl/device_grouped_conv2d_fwd_xdl_ngchw_gkcyx_ngkhw_bf16_instance.cpp +++ b/library/src/tensor_operation_instance/gpu/grouped_conv2d_fwd/xdl/device_grouped_conv2d_fwd_xdl_ngchw_gkcyx_ngkhw_bf16_instance.cpp @@ -30,6 +30,20 @@ void add_device_grouped_conv2d_fwd_xdl_ngchw_gkcyx_ngkhw_bf16_instances( Empty_Tuple, NGKHW, ConvFwdDefault>{}); + add_device_operation_instances(instances, + device_grouped_conv_fwd_xdl_bf16_instances<2, + NGCHW, + GKCYX, + Empty_Tuple, + NGKHW, + ConvFwd1x1P0>{}); + add_device_operation_instances(instances, + device_grouped_conv_fwd_xdl_bf16_instances<2, + NGCHW, + GKCYX, + Empty_Tuple, + NGKHW, + ConvFwd1x1S1P0>{}); } } // namespace instance diff --git a/library/src/tensor_operation_instance/gpu/grouped_conv2d_fwd/xdl/device_grouped_conv2d_fwd_xdl_ngchw_gkcyx_ngkhw_f16_16x16_instance.cpp b/library/src/tensor_operation_instance/gpu/grouped_conv2d_fwd/xdl/device_grouped_conv2d_fwd_xdl_ngchw_gkcyx_ngkhw_f16_16x16_instance.cpp new file mode 100644 index 0000000000..a82e800bb1 --- /dev/null +++ b/library/src/tensor_operation_instance/gpu/grouped_conv2d_fwd/xdl/device_grouped_conv2d_fwd_xdl_ngchw_gkcyx_ngkhw_f16_16x16_instance.cpp @@ -0,0 +1,54 @@ +// SPDX-License-Identifier: MIT +// Copyright (c) 2025, Advanced Micro Devices, Inc. All rights reserved. 
+ +#include "ck/library/tensor_operation_instance/add_device_operation_instance.hpp" +#include "ck/library/tensor_operation_instance/gpu/grouped_conv_fwd/device_grouped_conv_fwd_xdl_instance.hpp" + +namespace ck { +namespace tensor_operation { +namespace device { +namespace instance { +// Compilation parameters for in[n, hi, wi, g, c] * wei[g, k, y, x, c] = out[n, ho, wo, g, k] +void add_device_grouped_conv2d_fwd_xdl_ngchw_gkcyx_ngkhw_f16_16x16_instances( + std::vector>>& instances) +{ + add_device_operation_instances( + instances, + device_grouped_conv_fwd_xdl_f16_16x16_instances<2, + NGCHW, + GKCYX, + Empty_Tuple, + NGKHW, + ConvFwdDefault>{}); + add_device_operation_instances(instances, + device_grouped_conv_fwd_xdl_f16_16x16_instances<2, + NGCHW, + GKCYX, + Empty_Tuple, + NGKHW, + ConvFwd1x1P0>{}); + add_device_operation_instances( + instances, + device_grouped_conv_fwd_xdl_f16_16x16_instances<2, + NGCHW, + GKCYX, + Empty_Tuple, + NGKHW, + ConvFwd1x1S1P0>{}); +} + +} // namespace instance +} // namespace device +} // namespace tensor_operation +} // namespace ck diff --git a/library/src/tensor_operation_instance/gpu/grouped_conv2d_fwd/xdl/device_grouped_conv2d_fwd_xdl_ngchw_gkcyx_ngkhw_f16_instance.cpp b/library/src/tensor_operation_instance/gpu/grouped_conv2d_fwd/xdl/device_grouped_conv2d_fwd_xdl_ngchw_gkcyx_ngkhw_f16_instance.cpp index f1ccad2add..e3a12fd5f4 100644 --- a/library/src/tensor_operation_instance/gpu/grouped_conv2d_fwd/xdl/device_grouped_conv2d_fwd_xdl_ngchw_gkcyx_ngkhw_f16_instance.cpp +++ b/library/src/tensor_operation_instance/gpu/grouped_conv2d_fwd/xdl/device_grouped_conv2d_fwd_xdl_ngchw_gkcyx_ngkhw_f16_instance.cpp @@ -30,6 +30,20 @@ void add_device_grouped_conv2d_fwd_xdl_ngchw_gkcyx_ngkhw_f16_instances( Empty_Tuple, NGKHW, ConvFwdDefault>{}); + add_device_operation_instances(instances, + device_grouped_conv_fwd_xdl_f16_instances<2, + NGCHW, + GKCYX, + Empty_Tuple, + NGKHW, + ConvFwd1x1P0>{}); + add_device_operation_instances(instances, + 
device_grouped_conv_fwd_xdl_f16_instances<2, + NGCHW, + GKCYX, + Empty_Tuple, + NGKHW, + ConvFwd1x1S1P0>{}); } } // namespace instance diff --git a/library/src/tensor_operation_instance/gpu/grouped_conv2d_fwd/xdl/device_grouped_conv2d_fwd_xdl_ngchw_gkcyx_ngkhw_f32_16x16_instance.cpp b/library/src/tensor_operation_instance/gpu/grouped_conv2d_fwd/xdl/device_grouped_conv2d_fwd_xdl_ngchw_gkcyx_ngkhw_f32_16x16_instance.cpp new file mode 100644 index 0000000000..5918f2479f --- /dev/null +++ b/library/src/tensor_operation_instance/gpu/grouped_conv2d_fwd/xdl/device_grouped_conv2d_fwd_xdl_ngchw_gkcyx_ngkhw_f32_16x16_instance.cpp @@ -0,0 +1,54 @@ +// SPDX-License-Identifier: MIT +// Copyright (c) 2025, Advanced Micro Devices, Inc. All rights reserved. + +#include "ck/library/tensor_operation_instance/add_device_operation_instance.hpp" +#include "ck/library/tensor_operation_instance/gpu/grouped_conv_fwd/device_grouped_conv_fwd_xdl_instance.hpp" + +namespace ck { +namespace tensor_operation { +namespace device { +namespace instance { +// Compilation parameters for in[n, hi, wi, g, c] * wei[g, k, y, x, c] = out[n, ho, wo, g, k] +void add_device_grouped_conv2d_fwd_xdl_ngchw_gkcyx_ngkhw_f32_16x16_instances( + std::vector>>& instances) +{ + add_device_operation_instances( + instances, + device_grouped_conv_fwd_xdl_f32_16x16_instances<2, + NGCHW, + GKCYX, + Empty_Tuple, + NGKHW, + ConvFwdDefault>{}); + add_device_operation_instances(instances, + device_grouped_conv_fwd_xdl_f32_16x16_instances<2, + NGCHW, + GKCYX, + Empty_Tuple, + NGKHW, + ConvFwd1x1P0>{}); + add_device_operation_instances( + instances, + device_grouped_conv_fwd_xdl_f32_16x16_instances<2, + NGCHW, + GKCYX, + Empty_Tuple, + NGKHW, + ConvFwd1x1S1P0>{}); +} + +} // namespace instance +} // namespace device +} // namespace tensor_operation +} // namespace ck diff --git a/library/src/tensor_operation_instance/gpu/grouped_conv2d_fwd/xdl/device_grouped_conv2d_fwd_xdl_ngchw_gkcyx_ngkhw_f32_instance.cpp 
b/library/src/tensor_operation_instance/gpu/grouped_conv2d_fwd/xdl/device_grouped_conv2d_fwd_xdl_ngchw_gkcyx_ngkhw_f32_instance.cpp index de7e416e48..467a33deb3 100644 --- a/library/src/tensor_operation_instance/gpu/grouped_conv2d_fwd/xdl/device_grouped_conv2d_fwd_xdl_ngchw_gkcyx_ngkhw_f32_instance.cpp +++ b/library/src/tensor_operation_instance/gpu/grouped_conv2d_fwd/xdl/device_grouped_conv2d_fwd_xdl_ngchw_gkcyx_ngkhw_f32_instance.cpp @@ -30,6 +30,20 @@ void add_device_grouped_conv2d_fwd_xdl_ngchw_gkcyx_ngkhw_f32_instances( Empty_Tuple, NGKHW, ConvFwdDefault>{}); + add_device_operation_instances(instances, + device_grouped_conv_fwd_xdl_f32_instances<2, + NGCHW, + GKCYX, + Empty_Tuple, + NGKHW, + ConvFwd1x1P0>{}); + add_device_operation_instances(instances, + device_grouped_conv_fwd_xdl_f32_instances<2, + NGCHW, + GKCYX, + Empty_Tuple, + NGKHW, + ConvFwd1x1S1P0>{}); } } // namespace instance diff --git a/library/src/tensor_operation_instance/gpu/grouped_conv2d_fwd/xdl/device_grouped_conv2d_fwd_xdl_nhwgc_gkyxc_nhwgk_bf16_16x16_instance.cpp b/library/src/tensor_operation_instance/gpu/grouped_conv2d_fwd/xdl/device_grouped_conv2d_fwd_xdl_nhwgc_gkyxc_nhwgk_bf16_16x16_instance.cpp new file mode 100644 index 0000000000..5b8b62010a --- /dev/null +++ b/library/src/tensor_operation_instance/gpu/grouped_conv2d_fwd/xdl/device_grouped_conv2d_fwd_xdl_nhwgc_gkyxc_nhwgk_bf16_16x16_instance.cpp @@ -0,0 +1,57 @@ +// SPDX-License-Identifier: MIT +// Copyright (c) 2025, Advanced Micro Devices, Inc. All rights reserved. 
+ +#include "ck/library/tensor_operation_instance/add_device_operation_instance.hpp" +#include "ck/library/tensor_operation_instance/gpu/grouped_conv_fwd/device_grouped_conv_fwd_xdl_instance.hpp" + +namespace ck { +namespace tensor_operation { +namespace device { +namespace instance { +// Compilation parameters for in[n, hi, wi, g, c] * wei[g, k, y, x, c] = out[n, ho, wo, g, k] +void add_device_grouped_conv2d_fwd_xdl_nhwgc_gkyxc_nhwgk_bf16_16x16_instances( + std::vector>>& instances) +{ + add_device_operation_instances( + instances, + device_grouped_conv_fwd_xdl_bf16_16x16_instances<2, + NHWGC, + GKYXC, + Empty_Tuple, + NHWGK, + ConvFwdDefault>{}); + + add_device_operation_instances( + instances, + device_grouped_conv_fwd_xdl_bf16_16x16_instances<2, + NHWGC, + GKYXC, + Empty_Tuple, + NHWGK, + ConvFwd1x1P0>{}); + + add_device_operation_instances( + instances, + device_grouped_conv_fwd_xdl_bf16_16x16_instances<2, + NHWGC, + GKYXC, + Empty_Tuple, + NHWGK, + ConvFwd1x1S1P0>{}); +} + +} // namespace instance +} // namespace device +} // namespace tensor_operation +} // namespace ck diff --git a/library/src/tensor_operation_instance/gpu/grouped_conv2d_fwd/xdl/device_grouped_conv2d_fwd_xdl_nhwgc_gkyxc_nhwgk_f16_16x16_instance.cpp b/library/src/tensor_operation_instance/gpu/grouped_conv2d_fwd/xdl/device_grouped_conv2d_fwd_xdl_nhwgc_gkyxc_nhwgk_f16_16x16_instance.cpp new file mode 100644 index 0000000000..7ca27e21a7 --- /dev/null +++ b/library/src/tensor_operation_instance/gpu/grouped_conv2d_fwd/xdl/device_grouped_conv2d_fwd_xdl_nhwgc_gkyxc_nhwgk_f16_16x16_instance.cpp @@ -0,0 +1,56 @@ +// SPDX-License-Identifier: MIT +// Copyright (c) 2025, Advanced Micro Devices, Inc. All rights reserved. 
+ +#include "ck/library/tensor_operation_instance/add_device_operation_instance.hpp" +#include "ck/library/tensor_operation_instance/gpu/grouped_conv_fwd/device_grouped_conv_fwd_xdl_instance.hpp" + +namespace ck { +namespace tensor_operation { +namespace device { +namespace instance { +// Compilation parameters for in[n, hi, wi, g, c] * wei[g, k, y, x, c] = out[n, ho, wo, g, k] +void add_device_grouped_conv2d_fwd_xdl_nhwgc_gkyxc_nhwgk_f16_16x16_instances( + std::vector>>& instances) +{ + add_device_operation_instances( + instances, + device_grouped_conv_fwd_xdl_f16_16x16_instances<2, + NHWGC, + GKYXC, + Empty_Tuple, + NHWGK, + ConvFwdDefault>{}); + + add_device_operation_instances(instances, + device_grouped_conv_fwd_xdl_f16_16x16_instances<2, + NHWGC, + GKYXC, + Empty_Tuple, + NHWGK, + ConvFwd1x1P0>{}); + + add_device_operation_instances( + instances, + device_grouped_conv_fwd_xdl_f16_16x16_instances<2, + NHWGC, + GKYXC, + Empty_Tuple, + NHWGK, + ConvFwd1x1S1P0>{}); +} + +} // namespace instance +} // namespace device +} // namespace tensor_operation +} // namespace ck diff --git a/library/src/tensor_operation_instance/gpu/grouped_conv2d_fwd/xdl/device_grouped_conv2d_fwd_xdl_nhwgc_gkyxc_nhwgk_f32_16x16_instance.cpp b/library/src/tensor_operation_instance/gpu/grouped_conv2d_fwd/xdl/device_grouped_conv2d_fwd_xdl_nhwgc_gkyxc_nhwgk_f32_16x16_instance.cpp new file mode 100644 index 0000000000..74cdbde0ba --- /dev/null +++ b/library/src/tensor_operation_instance/gpu/grouped_conv2d_fwd/xdl/device_grouped_conv2d_fwd_xdl_nhwgc_gkyxc_nhwgk_f32_16x16_instance.cpp @@ -0,0 +1,56 @@ +// SPDX-License-Identifier: MIT +// Copyright (c) 2025, Advanced Micro Devices, Inc. All rights reserved. 
+ +#include "ck/library/tensor_operation_instance/add_device_operation_instance.hpp" +#include "ck/library/tensor_operation_instance/gpu/grouped_conv_fwd/device_grouped_conv_fwd_xdl_instance.hpp" + +namespace ck { +namespace tensor_operation { +namespace device { +namespace instance { +// Compilation parameters for in[n, hi, wi, g, c] * wei[g, k, y, x, c] = out[n, ho, wo, g, k] +void add_device_grouped_conv2d_fwd_xdl_nhwgc_gkyxc_nhwgk_f32_16x16_instances( + std::vector>>& instances) +{ + add_device_operation_instances( + instances, + device_grouped_conv_fwd_xdl_f32_16x16_instances<2, + NHWGC, + GKYXC, + Empty_Tuple, + NHWGK, + ConvFwdDefault>{}); + + add_device_operation_instances(instances, + device_grouped_conv_fwd_xdl_f32_16x16_instances<2, + NHWGC, + GKYXC, + Empty_Tuple, + NHWGK, + ConvFwd1x1P0>{}); + + add_device_operation_instances( + instances, + device_grouped_conv_fwd_xdl_f32_16x16_instances<2, + NHWGC, + GKYXC, + Empty_Tuple, + NHWGK, + ConvFwd1x1S1P0>{}); +} + +} // namespace instance +} // namespace device +} // namespace tensor_operation +} // namespace ck diff --git a/library/src/tensor_operation_instance/gpu/grouped_conv3d_fwd/CMakeLists.txt b/library/src/tensor_operation_instance/gpu/grouped_conv3d_fwd/CMakeLists.txt index eeea4aae6d..f55bdd45c9 100644 --- a/library/src/tensor_operation_instance/gpu/grouped_conv3d_fwd/CMakeLists.txt +++ b/library/src/tensor_operation_instance/gpu/grouped_conv3d_fwd/CMakeLists.txt @@ -7,10 +7,16 @@ set(GROUPED_CONV3D_FWD xdl/device_grouped_conv3d_fwd_xdl_ndhwgc_gkzyxc_ndhwgk_bf16_instance.cpp xdl/device_grouped_conv3d_fwd_xdl_ndhwgc_gkzyxc_ndhwgk_f16_instance.cpp xdl/device_grouped_conv3d_fwd_xdl_ndhwgc_gkzyxc_ndhwgk_f32_instance.cpp + xdl/device_grouped_conv3d_fwd_xdl_ndhwgc_gkzyxc_ndhwgk_bf16_16x16_instance.cpp + xdl/device_grouped_conv3d_fwd_xdl_ndhwgc_gkzyxc_ndhwgk_f16_16x16_instance.cpp + xdl/device_grouped_conv3d_fwd_xdl_ndhwgc_gkzyxc_ndhwgk_f32_16x16_instance.cpp 
xdl/device_grouped_conv3d_fwd_xdl_ndhwgc_gkzyxc_ndhwgk_int8_instance.cpp xdl/device_grouped_conv3d_fwd_xdl_ngcdhw_gkczyx_ngkdhw_bf16_instance.cpp xdl/device_grouped_conv3d_fwd_xdl_ngcdhw_gkczyx_ngkdhw_f16_instance.cpp xdl/device_grouped_conv3d_fwd_xdl_ngcdhw_gkczyx_ngkdhw_f32_instance.cpp + xdl/device_grouped_conv3d_fwd_xdl_ngcdhw_gkczyx_ngkdhw_bf16_16x16_instance.cpp + xdl/device_grouped_conv3d_fwd_xdl_ngcdhw_gkczyx_ngkdhw_f16_16x16_instance.cpp + xdl/device_grouped_conv3d_fwd_xdl_ngcdhw_gkczyx_ngkdhw_f32_16x16_instance.cpp xdl/large_tensor/device_grouped_conv3d_fwd_xdl_large_tensor_ndhwgc_gkzyxc_ndhwgk_bf16_instance.cpp xdl/large_tensor/device_grouped_conv3d_fwd_xdl_large_tensor_ndhwgc_gkzyxc_ndhwgk_f16_instance.cpp diff --git a/library/src/tensor_operation_instance/gpu/grouped_conv3d_fwd/xdl/device_grouped_conv3d_fwd_xdl_ndhwgc_gkzyxc_ndhwgk_bf16_16x16_instance.cpp b/library/src/tensor_operation_instance/gpu/grouped_conv3d_fwd/xdl/device_grouped_conv3d_fwd_xdl_ndhwgc_gkzyxc_ndhwgk_bf16_16x16_instance.cpp new file mode 100644 index 0000000000..8f113b5234 --- /dev/null +++ b/library/src/tensor_operation_instance/gpu/grouped_conv3d_fwd/xdl/device_grouped_conv3d_fwd_xdl_ndhwgc_gkzyxc_ndhwgk_bf16_16x16_instance.cpp @@ -0,0 +1,55 @@ +// SPDX-License-Identifier: MIT +// Copyright (c) 2025, Advanced Micro Devices, Inc. All rights reserved. 
+ +#include "ck/library/tensor_operation_instance/gpu/grouped_conv_fwd/device_grouped_conv_fwd_xdl_instance.hpp" +#include "ck/library/tensor_operation_instance/add_device_operation_instance.hpp" + +namespace ck { +namespace tensor_operation { +namespace device { +namespace instance { + +void add_device_grouped_conv3d_fwd_xdl_ndhwgc_gkzyxc_ndhwgk_bf16_16x16_instances( + std::vector>>& instances) +{ + add_device_operation_instances( + instances, + device_grouped_conv_fwd_xdl_bf16_16x16_instances<3, + NDHWGC, + GKZYXC, + Empty_Tuple, + NDHWGK, + ConvFwdDefault>{}); + add_device_operation_instances( + instances, + device_grouped_conv_fwd_xdl_bf16_16x16_instances<3, + NDHWGC, + GKZYXC, + Empty_Tuple, + NDHWGK, + ConvFwd1x1P0>{}); + add_device_operation_instances( + instances, + device_grouped_conv_fwd_xdl_bf16_16x16_instances<3, + NDHWGC, + GKZYXC, + Empty_Tuple, + NDHWGK, + ConvFwd1x1S1P0>{}); +} + +} // namespace instance +} // namespace device +} // namespace tensor_operation +} // namespace ck diff --git a/library/src/tensor_operation_instance/gpu/grouped_conv3d_fwd/xdl/device_grouped_conv3d_fwd_xdl_ndhwgc_gkzyxc_ndhwgk_f16_16x16_instance.cpp b/library/src/tensor_operation_instance/gpu/grouped_conv3d_fwd/xdl/device_grouped_conv3d_fwd_xdl_ndhwgc_gkzyxc_ndhwgk_f16_16x16_instance.cpp new file mode 100644 index 0000000000..1395447660 --- /dev/null +++ b/library/src/tensor_operation_instance/gpu/grouped_conv3d_fwd/xdl/device_grouped_conv3d_fwd_xdl_ndhwgc_gkzyxc_ndhwgk_f16_16x16_instance.cpp @@ -0,0 +1,54 @@ +// SPDX-License-Identifier: MIT +// Copyright (c) 2025, Advanced Micro Devices, Inc. All rights reserved. 
+ +#include "ck/library/tensor_operation_instance/gpu/grouped_conv_fwd/device_grouped_conv_fwd_xdl_instance.hpp" +#include "ck/library/tensor_operation_instance/add_device_operation_instance.hpp" + +namespace ck { +namespace tensor_operation { +namespace device { +namespace instance { + +void add_device_grouped_conv3d_fwd_xdl_ndhwgc_gkzyxc_ndhwgk_f16_16x16_instances( + std::vector>>& instances) +{ + add_device_operation_instances( + instances, + device_grouped_conv_fwd_xdl_f16_16x16_instances<3, + NDHWGC, + GKZYXC, + Empty_Tuple, + NDHWGK, + ConvFwdDefault>{}); + add_device_operation_instances(instances, + device_grouped_conv_fwd_xdl_f16_16x16_instances<3, + NDHWGC, + GKZYXC, + Empty_Tuple, + NDHWGK, + ConvFwd1x1P0>{}); + add_device_operation_instances( + instances, + device_grouped_conv_fwd_xdl_f16_16x16_instances<3, + NDHWGC, + GKZYXC, + Empty_Tuple, + NDHWGK, + ConvFwd1x1S1P0>{}); +} + +} // namespace instance +} // namespace device +} // namespace tensor_operation +} // namespace ck diff --git a/library/src/tensor_operation_instance/gpu/grouped_conv3d_fwd/xdl/device_grouped_conv3d_fwd_xdl_ndhwgc_gkzyxc_ndhwgk_f32_16x16_instance.cpp b/library/src/tensor_operation_instance/gpu/grouped_conv3d_fwd/xdl/device_grouped_conv3d_fwd_xdl_ndhwgc_gkzyxc_ndhwgk_f32_16x16_instance.cpp new file mode 100644 index 0000000000..43b3565c74 --- /dev/null +++ b/library/src/tensor_operation_instance/gpu/grouped_conv3d_fwd/xdl/device_grouped_conv3d_fwd_xdl_ndhwgc_gkzyxc_ndhwgk_f32_16x16_instance.cpp @@ -0,0 +1,54 @@ +// SPDX-License-Identifier: MIT +// Copyright (c) 2025, Advanced Micro Devices, Inc. All rights reserved. 
+ +#include "ck/library/tensor_operation_instance/gpu/grouped_conv_fwd/device_grouped_conv_fwd_xdl_instance.hpp" +#include "ck/library/tensor_operation_instance/add_device_operation_instance.hpp" + +namespace ck { +namespace tensor_operation { +namespace device { +namespace instance { + +void add_device_grouped_conv3d_fwd_xdl_ndhwgc_gkzyxc_ndhwgk_f32_16x16_instances( + std::vector>>& instances) +{ + add_device_operation_instances( + instances, + device_grouped_conv_fwd_xdl_f32_16x16_instances<3, + NDHWGC, + GKZYXC, + Empty_Tuple, + NDHWGK, + ConvFwdDefault>{}); + add_device_operation_instances(instances, + device_grouped_conv_fwd_xdl_f32_16x16_instances<3, + NDHWGC, + GKZYXC, + Empty_Tuple, + NDHWGK, + ConvFwd1x1P0>{}); + add_device_operation_instances( + instances, + device_grouped_conv_fwd_xdl_f32_16x16_instances<3, + NDHWGC, + GKZYXC, + Empty_Tuple, + NDHWGK, + ConvFwd1x1S1P0>{}); +} + +} // namespace instance +} // namespace device +} // namespace tensor_operation +} // namespace ck diff --git a/library/src/tensor_operation_instance/gpu/grouped_conv3d_fwd/xdl/device_grouped_conv3d_fwd_xdl_ngcdhw_gkczyx_ngkdhw_bf16_16x16_instance.cpp b/library/src/tensor_operation_instance/gpu/grouped_conv3d_fwd/xdl/device_grouped_conv3d_fwd_xdl_ngcdhw_gkczyx_ngkdhw_bf16_16x16_instance.cpp new file mode 100644 index 0000000000..3b5068d605 --- /dev/null +++ b/library/src/tensor_operation_instance/gpu/grouped_conv3d_fwd/xdl/device_grouped_conv3d_fwd_xdl_ngcdhw_gkczyx_ngkdhw_bf16_16x16_instance.cpp @@ -0,0 +1,56 @@ +// SPDX-License-Identifier: MIT +// Copyright (c) 2025, Advanced Micro Devices, Inc. All rights reserved. 
+ +#include "ck/library/tensor_operation_instance/add_device_operation_instance.hpp" +#include "ck/library/tensor_operation_instance/gpu/grouped_conv_fwd/device_grouped_conv_fwd_xdl_instance.hpp" + +namespace ck { +namespace tensor_operation { +namespace device { +namespace instance { +// Compilation parameters for in[n, hi, wi, g, c] * wei[g, k, y, x, c] = out[n, ho, wo, g, k] +void add_device_grouped_conv3d_fwd_xdl_ngcdhw_gkczyx_ngkdhw_bf16_16x16_instances( + std::vector>>& instances) +{ + add_device_operation_instances( + instances, + device_grouped_conv_fwd_xdl_bf16_16x16_instances<3, + NGCDHW, + GKCZYX, + Empty_Tuple, + NGKDHW, + ConvFwdDefault>{}); + + add_device_operation_instances( + instances, + device_grouped_conv_fwd_xdl_bf16_16x16_instances<3, + NGCDHW, + GKCZYX, + Empty_Tuple, + NGKDHW, + ConvFwd1x1P0>{}); + add_device_operation_instances( + instances, + device_grouped_conv_fwd_xdl_bf16_16x16_instances<3, + NGCDHW, + GKCZYX, + Empty_Tuple, + NGKDHW, + ConvFwd1x1S1P0>{}); +} + +} // namespace instance +} // namespace device +} // namespace tensor_operation +} // namespace ck diff --git a/library/src/tensor_operation_instance/gpu/grouped_conv3d_fwd/xdl/device_grouped_conv3d_fwd_xdl_ngcdhw_gkczyx_ngkdhw_f16_16x16_instance.cpp b/library/src/tensor_operation_instance/gpu/grouped_conv3d_fwd/xdl/device_grouped_conv3d_fwd_xdl_ngcdhw_gkczyx_ngkdhw_f16_16x16_instance.cpp new file mode 100644 index 0000000000..0ddf5bfa48 --- /dev/null +++ b/library/src/tensor_operation_instance/gpu/grouped_conv3d_fwd/xdl/device_grouped_conv3d_fwd_xdl_ngcdhw_gkczyx_ngkdhw_f16_16x16_instance.cpp @@ -0,0 +1,55 @@ +// SPDX-License-Identifier: MIT +// Copyright (c) 2025, Advanced Micro Devices, Inc. All rights reserved. 
+ +#include "ck/library/tensor_operation_instance/add_device_operation_instance.hpp" +#include "ck/library/tensor_operation_instance/gpu/grouped_conv_fwd/device_grouped_conv_fwd_xdl_instance.hpp" + +namespace ck { +namespace tensor_operation { +namespace device { +namespace instance { +// Compilation parameters for in[n, hi, wi, g, c] * wei[g, k, y, x, c] = out[n, ho, wo, g, k] +void add_device_grouped_conv3d_fwd_xdl_ngcdhw_gkczyx_ngkdhw_f16_16x16_instances( + std::vector>>& instances) +{ + add_device_operation_instances( + instances, + device_grouped_conv_fwd_xdl_f16_16x16_instances<3, + NGCDHW, + GKCZYX, + Empty_Tuple, + NGKDHW, + ConvFwdDefault>{}); + + add_device_operation_instances(instances, + device_grouped_conv_fwd_xdl_f16_16x16_instances<3, + NGCDHW, + GKCZYX, + Empty_Tuple, + NGKDHW, + ConvFwd1x1P0>{}); + add_device_operation_instances( + instances, + device_grouped_conv_fwd_xdl_f16_16x16_instances<3, + NGCDHW, + GKCZYX, + Empty_Tuple, + NGKDHW, + ConvFwd1x1S1P0>{}); +} + +} // namespace instance +} // namespace device +} // namespace tensor_operation +} // namespace ck diff --git a/library/src/tensor_operation_instance/gpu/grouped_conv3d_fwd/xdl/device_grouped_conv3d_fwd_xdl_ngcdhw_gkczyx_ngkdhw_f32_16x16_instance.cpp b/library/src/tensor_operation_instance/gpu/grouped_conv3d_fwd/xdl/device_grouped_conv3d_fwd_xdl_ngcdhw_gkczyx_ngkdhw_f32_16x16_instance.cpp new file mode 100644 index 0000000000..dc4f7be9c0 --- /dev/null +++ b/library/src/tensor_operation_instance/gpu/grouped_conv3d_fwd/xdl/device_grouped_conv3d_fwd_xdl_ngcdhw_gkczyx_ngkdhw_f32_16x16_instance.cpp @@ -0,0 +1,55 @@ +// SPDX-License-Identifier: MIT +// Copyright (c) 2025, Advanced Micro Devices, Inc. All rights reserved. 
+ +#include "ck/library/tensor_operation_instance/add_device_operation_instance.hpp" +#include "ck/library/tensor_operation_instance/gpu/grouped_conv_fwd/device_grouped_conv_fwd_xdl_instance.hpp" + +namespace ck { +namespace tensor_operation { +namespace device { +namespace instance { +// Compilation parameters for in[n, hi, wi, g, c] * wei[g, k, y, x, c] = out[n, ho, wo, g, k] +void add_device_grouped_conv3d_fwd_xdl_ngcdhw_gkczyx_ngkdhw_f32_16x16_instances( + std::vector>>& instances) +{ + add_device_operation_instances( + instances, + device_grouped_conv_fwd_xdl_f32_16x16_instances<3, + NGCDHW, + GKCZYX, + Empty_Tuple, + NGKDHW, + ConvFwdDefault>{}); + + add_device_operation_instances(instances, + device_grouped_conv_fwd_xdl_f32_16x16_instances<3, + NGCDHW, + GKCZYX, + Empty_Tuple, + NGKDHW, + ConvFwd1x1P0>{}); + add_device_operation_instances( + instances, + device_grouped_conv_fwd_xdl_f32_16x16_instances<3, + NGCDHW, + GKCZYX, + Empty_Tuple, + NGKDHW, + ConvFwd1x1S1P0>{}); +} + +} // namespace instance +} // namespace device +} // namespace tensor_operation +} // namespace ck From cb07ad84d5b8a6a796dff34c5d990476b6693b16 Mon Sep 17 00:00:00 2001 From: jakpiase Date: Wed, 7 May 2025 19:46:53 +0200 Subject: [PATCH 101/443] fix for default epilogue (#2167) --- .../ops/epilogue/default_2d_epilogue.hpp | 34 +++++++++++-------- 1 file changed, 20 insertions(+), 14 deletions(-) diff --git a/include/ck_tile/ops/epilogue/default_2d_epilogue.hpp b/include/ck_tile/ops/epilogue/default_2d_epilogue.hpp index 1d6a99eb4b..a2915f5c8f 100644 --- a/include/ck_tile/ops/epilogue/default_2d_epilogue.hpp +++ b/include/ck_tile/ops/epilogue/default_2d_epilogue.hpp @@ -15,14 +15,16 @@ template + bool UseRawStore_ = true, + memory_operation_enum MemoryOperation_ = memory_operation_enum::set> struct Default2DEpilogueProblem { - using AccDataType = remove_cvref_t; - using ODataType = remove_cvref_t; - static constexpr bool kPadM = kPadM_; - static constexpr bool kPadN = kPadN_; - static 
constexpr bool UseRawStore = UseRawStore_; + using AccDataType = remove_cvref_t; + using ODataType = remove_cvref_t; + static constexpr bool kPadM = kPadM_; + static constexpr bool kPadN = kPadN_; + static constexpr bool UseRawStore = UseRawStore_; + static constexpr memory_operation_enum MemoryOperation = MemoryOperation_; }; template -struct DefaultGemm2DEpilogueProblem - : public Default2DEpilogueProblem + bool UseRawStore_ = true, + memory_operation_enum MemoryOperation_ = memory_operation_enum::set> +struct DefaultGemm2DEpilogueProblem : public Default2DEpilogueProblem { using ADataType = remove_cvref_t; using BDataType = remove_cvref_t; @@ -58,14 +65,13 @@ struct Default2DEpilogue static constexpr bool kPadM = Problem::kPadM; static constexpr bool kPadN = Problem::kPadN; static constexpr bool UseRawStore = Problem::UseRawStore; + static constexpr memory_operation_enum MemoryOperation = Problem::MemoryOperation; CK_TILE_HOST_DEVICE static constexpr index_t GetSmemSize() { return 0; } // TODO: this function assume store out vector size is the same as OAccTile last dimension size // how do we fix this ? 
- template + template CK_TILE_DEVICE auto operator()(ODramWindowTmp& o_dram_window_tmp, const OAccTile& o_acc_tile, void* = nullptr) { @@ -73,7 +79,7 @@ struct Default2DEpilogue // TODO: this is ugly if constexpr(UseRawStore && (kPadM || kPadN)) { - if constexpr(out_memory_data_op == memory_operation_enum::set) + if constexpr(MemoryOperation == memory_operation_enum::set) { store_tile_raw(o_dram_window_tmp, cast_tile(o_acc_tile)); } @@ -85,7 +91,7 @@ struct Default2DEpilogue } else { - if constexpr(out_memory_data_op == memory_operation_enum::set) + if constexpr(MemoryOperation == memory_operation_enum::set) { store_tile(o_dram_window_tmp, cast_tile(o_acc_tile)); } From c7b8e86e342a77f9176b0f4688282fad03eb863b Mon Sep 17 00:00:00 2001 From: Khushbu Agarwal Date: Wed, 7 May 2025 18:37:31 -0700 Subject: [PATCH 102/443] [CK_Tile] Simplified Mem pipeline (#2159) * simplify code * compiled the code * Simplified example and codegen for mem pipeline * Reveting config and universal gemm example * clang formatted * remove comments * clang formatted * Add memory operation changes for defualt pipeline * fix config file --------- Co-authored-by: ThomasNing --- example/ck_tile/03_gemm/universal_gemm.cpp | 81 ++++--------- test/ck_tile/gemm/test_gemm_pipeline_util.hpp | 78 ++++-------- .../gemm/configs/instance_combination.json | 2 +- tile_engine/ops/gemm/gemm_instance_builder.py | 111 +++++++++--------- 4 files changed, 107 insertions(+), 165 deletions(-) diff --git a/example/ck_tile/03_gemm/universal_gemm.cpp b/example/ck_tile/03_gemm/universal_gemm.cpp index e6a2811918..b60a3b274b 100644 --- a/example/ck_tile/03_gemm/universal_gemm.cpp +++ b/example/ck_tile/03_gemm/universal_gemm.cpp @@ -12,6 +12,19 @@ #include "ck_tile/host.hpp" #include "gemm_utils.hpp" +template +void try_run(ck_tile::TailNumber tn) +{ + if constexpr(Pipeline::PrefetchStages > static_cast(TN)) + { + if(tn == TN) + { + RunSplitk(ck_tile::bool_constant{}, + ck_tile::integral_constant{}); + } + } +} + 
template {}, @@ -176,60 +188,17 @@ float gemm_calc(const ck_tile::GemmHostArgs& args, const ck_tile::stream_config& ck_tile::integral_constant{}); } - if constexpr(BaseGemmPipeline::PrefetchStages > 2) - { - if(tail_num == ck_tile::TailNumber::Two) - { - RunSplitk( - ck_tile::bool_constant{}, - ck_tile::integral_constant{}); - } - } - if constexpr(BaseGemmPipeline::PrefetchStages > 3) - { - if(tail_num == ck_tile::TailNumber::Three) - { - RunSplitk( - ck_tile::bool_constant{}, - ck_tile::integral_constant{}); - } - } - if constexpr(BaseGemmPipeline::PrefetchStages > 4) - { - if(tail_num == ck_tile::TailNumber::Four) - { - RunSplitk( - ck_tile::bool_constant{}, - ck_tile::integral_constant{}); - } - } - if constexpr(BaseGemmPipeline::PrefetchStages > 5) - { - if(tail_num == ck_tile::TailNumber::Five) - { - RunSplitk( - ck_tile::bool_constant{}, - ck_tile::integral_constant{}); - } - } - if constexpr(BaseGemmPipeline::PrefetchStages > 6) - { - if(tail_num == ck_tile::TailNumber::Six) - { - RunSplitk( - ck_tile::bool_constant{}, - ck_tile::integral_constant{}); - } - } - if constexpr(BaseGemmPipeline::PrefetchStages > 7) - { - if(tail_num == ck_tile::TailNumber::Seven) - { - RunSplitk( - ck_tile::bool_constant{}, - ck_tile::integral_constant{}); - } - } + auto check_tail = [&](auto... 
TNs) { + (try_run(tail_num), ...); + }; + + check_tail(ck_tile::integral_constant{}, + ck_tile::integral_constant{}, + ck_tile::integral_constant{}, + ck_tile::integral_constant{}, + ck_tile::integral_constant{}, + ck_tile::integral_constant{}); + #elif(CK_TILE_PIPELINE_DEFAULT == CK_TILE_PIPELINE_COMPUTE_V4) if(tail_num == ck_tile::TailNumber::Three) { @@ -259,7 +228,7 @@ float gemm_calc(const ck_tile::GemmHostArgs& args, const ck_tile::stream_config& else if(tail_num == ck_tile::TailNumber::Even) { RunSplitk(ck_tile::bool_constant{}, - ck_tile::integral_constant{}); + ck_tile::integral_constant{}); } else { diff --git a/test/ck_tile/gemm/test_gemm_pipeline_util.hpp b/test/ck_tile/gemm/test_gemm_pipeline_util.hpp index 0329f16416..85742cb3de 100644 --- a/test/ck_tile/gemm/test_gemm_pipeline_util.hpp +++ b/test/ck_tile/gemm/test_gemm_pipeline_util.hpp @@ -63,6 +63,19 @@ struct GemmPipelineTypeSelector using pipeline = ck_tile::GemmPipelineAgBgCrCompV4; }; +template +void try_run(ck_tile::TailNumber tn) +{ + if constexpr(Pipeline::PrefetchStages > static_cast(TN)) + { + if(tn == TN) + { + RunSplitk(ck_tile::bool_constant{}, + ck_tile::integral_constant{}); + } + } +} + template class TestCkTileGemmPipeline : public ::testing::Test { @@ -251,60 +264,17 @@ class TestCkTileGemmPipeline : public ::testing::Test ck_tile::TailNumber::Full>{}); } - if constexpr(BaseGemmPipeline::PrefetchStages > 2) - { - if(tail_num == ck_tile::TailNumber::Two) - { - RunSplitk(ck_tile::bool_constant{}, - ck_tile::integral_constant{}); - } - } - if constexpr(BaseGemmPipeline::PrefetchStages > 3) - { - if(tail_num == ck_tile::TailNumber::Three) - { - RunSplitk(ck_tile::bool_constant{}, - ck_tile::integral_constant{}); - } - } - if constexpr(BaseGemmPipeline::PrefetchStages > 4) - { - if(tail_num == ck_tile::TailNumber::Four) - { - RunSplitk(ck_tile::bool_constant{}, - ck_tile::integral_constant{}); - } - } - if constexpr(BaseGemmPipeline::PrefetchStages > 5) - { - if(tail_num == 
ck_tile::TailNumber::Five) - { - RunSplitk(ck_tile::bool_constant{}, - ck_tile::integral_constant{}); - } - } - if constexpr(BaseGemmPipeline::PrefetchStages > 6) - { - if(tail_num == ck_tile::TailNumber::Six) - { - RunSplitk(ck_tile::bool_constant{}, - ck_tile::integral_constant{}); - } - } - if constexpr(BaseGemmPipeline::PrefetchStages > 7) - { - if(tail_num == ck_tile::TailNumber::Seven) - { - RunSplitk(ck_tile::bool_constant{}, - ck_tile::integral_constant{}); - } - } + auto check_tail = [&](auto... TNs) { + (try_run(tail_num), ...); + }; + + check_tail( + ck_tile::integral_constant{}, + ck_tile::integral_constant{}, + ck_tile::integral_constant{}, + ck_tile::integral_constant{}, + ck_tile::integral_constant{}, + ck_tile::integral_constant{}); } if constexpr(PipelineType == GemmPipelineType::CompV4) diff --git a/tile_engine/ops/gemm/configs/instance_combination.json b/tile_engine/ops/gemm/configs/instance_combination.json index 66dbdafa11..53197ada6c 100644 --- a/tile_engine/ops/gemm/configs/instance_combination.json +++ b/tile_engine/ops/gemm/configs/instance_combination.json @@ -19,7 +19,7 @@ "values": [256] }, "tile_k": { - "values": [64, 32] + "values": [32] }, "warp_m": { "values": [2] diff --git a/tile_engine/ops/gemm/gemm_instance_builder.py b/tile_engine/ops/gemm/gemm_instance_builder.py index a748c35feb..3839523e3d 100755 --- a/tile_engine/ops/gemm/gemm_instance_builder.py +++ b/tile_engine/ops/gemm/gemm_instance_builder.py @@ -37,7 +37,9 @@ DEFAULT_EPILOGUE = """ WarpTileM, WarpTileN, WarpTileK, - UniversalGemmProblem::TransposeC>>; + UniversalGemmProblem::TransposeC, + true, + memory_operation>>; """ CSHUFFLE_EPILOGUE = """ @@ -55,22 +57,23 @@ CSHUFFLE_EPILOGUE = """ WarpTileM, WarpTileN, WarpTileK, - UniversalGemmProblem::TransposeC>>; + UniversalGemmProblem::TransposeC, + memory_operation>>; """ HOT_LOOP_FALSE = """ if(tail_num == ck_tile::TailNumber::Full) { - Run(ck_tile::bool_constant{}, + RunSplitk(ck_tile::bool_constant{}, 
ck_tile::integral_constant{}); } else if(tail_num == ck_tile::TailNumber::Odd) { - Run(ck_tile::bool_constant{}, + RunSplitk(ck_tile::bool_constant{}, ck_tile::integral_constant{}); } else if(tail_num == ck_tile::TailNumber::Even) { - Run(ck_tile::bool_constant{}, + RunSplitk(ck_tile::bool_constant{}, ck_tile::integral_constant{}); } else @@ -79,68 +82,43 @@ HOT_LOOP_FALSE = """ } """ RUN_MEM = """ - if(tail_num == ck_tile::TailNumber::One) - { - Run(ck_tile::bool_constant{}, + // Handle One and Full cases directly + if (tail_num == ck_tile::TailNumber::One) { + RunSplitk(ck_tile::bool_constant{}, ck_tile::integral_constant{}); - } - else if(tail_num == ck_tile::TailNumber::Full) - { - Run(ck_tile::bool_constant{}, + } else if (tail_num == ck_tile::TailNumber::Full) { + RunSplitk(ck_tile::bool_constant{}, ck_tile::integral_constant{}); } + // Variadic call using fold expression + auto check_tail = [&](auto... TNs) { + (try_run< BaseGemmPipeline, decltype(TNs)::value>(tail_num), ...); + }; - if constexpr(BaseGemmPipeline::PrefetchStages > 2) - { - if(tail_num == ck_tile::TailNumber::Two) - { - Run(ck_tile::bool_constant{}, - ck_tile::integral_constant{}); - } - - if(tail_num == ck_tile::TailNumber::Three) - { - Run(ck_tile::bool_constant{}, - ck_tile::integral_constant{}); - } - if(tail_num == ck_tile::TailNumber::Four) - { - Run(ck_tile::bool_constant{}, - ck_tile::integral_constant{}); - } - if(tail_num == ck_tile::TailNumber::Five) - { - Run(ck_tile::bool_constant{}, - ck_tile::integral_constant{}); - } - if(tail_num == ck_tile::TailNumber::Six) - { - Run(ck_tile::bool_constant{}, - ck_tile::integral_constant{}); - } - if(tail_num == ck_tile::TailNumber::Seven) - { - Run(ck_tile::bool_constant{}, - ck_tile::integral_constant{}); - } - throw std::runtime_error("The tile number is wrong! 
It should not exceed the prefetch stage numbers"); - } + check_tail( + ck_tile::integral_constant{}, + ck_tile::integral_constant{}, + ck_tile::integral_constant{}, + ck_tile::integral_constant{}, + ck_tile::integral_constant{}, + ck_tile::integral_constant{} + ); """ RUN_COMPV3 = """ if(tail_num == ck_tile::TailNumber::Full) { - Run(ck_tile::bool_constant{}, + RunSplitk(ck_tile::bool_constant{}, ck_tile::integral_constant{}); } else if(tail_num == ck_tile::TailNumber::Odd) { - Run(ck_tile::bool_constant{}, + RunSplitk(ck_tile::bool_constant{}, ck_tile::integral_constant{}); } else if(tail_num == ck_tile::TailNumber::Even) { - Run(ck_tile::bool_constant{}, + RunSplitk(ck_tile::bool_constant{}, ck_tile::integral_constant{}); } else @@ -152,12 +130,12 @@ RUN_COMPV3 = """ RUN_COMPV4 = """ if(tail_num == ck_tile::TailNumber::Three) { - Run(ck_tile::bool_constant{}, + RunSplitk(ck_tile::bool_constant{}, ck_tile::integral_constant{}); } else { - Run(ck_tile::bool_constant{}, + RunSplitk(ck_tile::bool_constant{}, ck_tile::integral_constant{}); } """ @@ -347,6 +325,15 @@ namespace {group_name} {{ kPadM: bool, kPadN: bool, kPadK: bool) -> str: """Generate kernel struct template""" return f""" +template +void try_run(ck_tile::TailNumber tn) {{ + if constexpr (Pipeline::PrefetchStages > static_cast(TN)) {{ + if (tn == TN) {{ + RunSplitk(ck_tile::bool_constant{{}}, + ck_tile::integral_constant{{}}); + }} + }} +}} template {{}}); + }} else {{ + Run(has_hot_loop_, + tail_number_, + ck_tile::integral_constant{{}}); + }} + }}; + if(has_hot_loop) {{ {HOT_LOOP_TRUE[pipeline]} }} else {{ @@ -450,6 +452,7 @@ struct GemmKernel {{ return ave_time; }} + static std::string get_name() {{ return std::string("GemmKernel Date: Thu, 8 May 2025 12:59:57 +0800 Subject: [PATCH 103/443] Flatmm merge (#2168) * sync with function interface of cshuffleepiloge,fix flatmm build fail * move code from solin/flatmm which add mfma16*16*32fp8 and optimize flatmm --------- Co-authored-by: solin --- 
example/ck_tile/18_flatmm/CMakeLists.txt | 3 +- example/ck_tile/18_flatmm/flatmm_basic.cpp | 162 ++++++++---- example/ck_tile/18_flatmm/flatmm_basic.hpp | 52 +++- .../ck_tile/18_flatmm/run_flatmm_example.inc | 79 +++--- .../block_flatmm_asmem_bsmem_creg_v1.hpp | 77 +----- .../ops/flatmm/kernel/flatmm_kernel.hpp | 28 +-- .../flatmm_pipeline_agmem_bgmem_creg_v1.hpp | 234 +++++++++++++++++- ...mm_pipeline_agmem_bgmem_creg_v1_policy.hpp | 97 +++++++- include/ck_tile/ops/gemm/warp/warp_gemm.hpp | 8 + .../warp/warp_gemm_attribute_mfma_impl.hpp | 2 +- .../ops/gemm/warp/warp_gemm_dispatcher.hpp | 2 + 11 files changed, 552 insertions(+), 192 deletions(-) diff --git a/example/ck_tile/18_flatmm/CMakeLists.txt b/example/ck_tile/18_flatmm/CMakeLists.txt index 9fbe65e3a7..f4d823e91a 100644 --- a/example/ck_tile/18_flatmm/CMakeLists.txt +++ b/example/ck_tile/18_flatmm/CMakeLists.txt @@ -3,5 +3,6 @@ add_executable(tile_example_flatmm_basic EXCLUDE_FROM_ALL flatmm_basic.cpp) set(EXAMPLE_FLATMM_COMPILE_OPTIONS) # list(APPEND EXAMPLE_FLATMM_COMPILE_OPTIONS -Wno-undefined-func-template -Wno-float-equal) # list(APPEND EXAMPLE_FLATMM_COMPILE_OPTIONS -Wno-unused-variable -Wno-unused-parameter) -# list(APPEND EXAMPLE_FLATMM_COMPILE_OPTIONS -Wno-unused-local-typedef) +list(APPEND EXAMPLE_FLATMM_COMPILE_OPTIONS -DUSING_MFMA_16x16x32=1 -DENABLE_FP8=1 -Wno-unused-local-typedef) +#list(APPEND EXAMPLE_FLATMM_COMPILE_OPTIONS -DUSING_MFMA_32x32x16=1 -DENABLE_FP8=1 -Wno-unused-local-typedef) target_compile_options(tile_example_flatmm_basic PRIVATE ${EXAMPLE_FLATMM_COMPILE_OPTIONS}) diff --git a/example/ck_tile/18_flatmm/flatmm_basic.cpp b/example/ck_tile/18_flatmm/flatmm_basic.cpp index 05d0c73b7e..5f2c2a5aab 100644 --- a/example/ck_tile/18_flatmm/flatmm_basic.cpp +++ b/example/ck_tile/18_flatmm/flatmm_basic.cpp @@ -12,7 +12,13 @@ #include "ck_tile/host.hpp" #include "flatmm_basic.hpp" -template +template float flatmm_calc(const ck_tile::FlatmmHostArgs& args, const ck_tile::stream_config& s) { // 
The kPadM, kPadN, kPadK & kBlockPerCu should also come from the Codegen part. @@ -23,18 +29,32 @@ float flatmm_calc(const ck_tile::FlatmmHostArgs& args, const ck_tile::stream_con constexpr int kBlockPerCu = 2; // This part comes from the Codegen +#if defined(USING_MFMA_16x16x32) || defined(ENABLE_FP16) constexpr ck_tile::index_t M_Tile = 128; constexpr ck_tile::index_t N_Tile = 128; - constexpr ck_tile::index_t K_Tile = 64; + constexpr ck_tile::index_t K_Tile = 128; constexpr ck_tile::index_t M_Warp = 1; constexpr ck_tile::index_t N_Warp = 4; constexpr ck_tile::index_t K_Warp = 1; - constexpr ck_tile::index_t M_Warp_Tile = 32; - constexpr ck_tile::index_t N_Warp_Tile = 32; - constexpr ck_tile::index_t K_Warp_Tile = 16; + constexpr ck_tile::index_t M_Warp_Tile = is_8bit_type::value ? 16 : 32; + constexpr ck_tile::index_t N_Warp_Tile = is_8bit_type::value ? 16 : 32; + constexpr ck_tile::index_t K_Warp_Tile = is_8bit_type::value ? 64 : 16; +#elif defined(USING_MFMA_32x32x16) && defined(ENABLE_FP8) + constexpr ck_tile::index_t M_Tile = 128; + constexpr ck_tile::index_t N_Tile = 256; + constexpr ck_tile::index_t K_Tile = 128; + + constexpr ck_tile::index_t M_Warp = 1; + constexpr ck_tile::index_t N_Warp = 8; + constexpr ck_tile::index_t K_Warp = 1; + + constexpr ck_tile::index_t M_Warp_Tile = is_8bit_type::value ? 32 : 32; + constexpr ck_tile::index_t N_Warp_Tile = is_8bit_type::value ? 32 : 32; + constexpr ck_tile::index_t K_Warp_Tile = is_8bit_type::value ? 
32 : 16; +#endif using CodegenFlatmmShape = ck_tile::TileFlatmmShape, ck_tile::sequence, @@ -49,54 +69,112 @@ float flatmm_calc(const ck_tile::FlatmmHostArgs& args, const ck_tile::stream_con AccDataType, CodegenFlatmmShape, CodegenGemmTraits>; - using GemmEpilogue = ck_tile::CShuffleEpilogue< - ck_tile::CShuffleEpilogueProblem>; + const auto Run = [&](const auto memory_operation_) { + constexpr auto memory_operation = memory_operation_.value; - using CodegenFlatmmPolicy = ck_tile::UniversalFlatmmPipelineAgBgCrPolicy; - using CodegenFlatmmPipeline = - ck_tile::FlatmmPipelineAGmemBGmemCRegV1; + using GemmEpilogue = ck_tile::CShuffleEpilogue< + ck_tile::CShuffleEpilogueProblem>; - // ToDo: Will add the codegen part to test different pipeline policies in GEMM. - // Now we only use the BlockGemmASmemBSmemCRegV1DefaultPolicy. - using Kernel = ck_tile::FlatmmKernel; + using CodegenFlatmmPolicy = ck_tile::UniversalFlatmmPipelineAgBgCrPolicy; + using CodegenFlatmmPipeline = + ck_tile::FlatmmPipelineAGmemBGmemCRegV1; - auto kargs = Kernel::MakeKernelArgs(args); + // ToDo: Will add the codegen part to test different pipeline policies in GEMM. + // Now we only use the BlockGemmASmemBSmemCRegV1DefaultPolicy. + using Kernel = ck_tile::FlatmmKernel; - const dim3 grids = Kernel::GridSize(args.M, args.N, args.k_batch); - constexpr dim3 blocks = Kernel::BlockSize(); + auto kargs = Kernel::MakeKernelArgs(args); - if(!Kernel::IsSupportedArgument(kargs)) + const dim3 grids = Kernel::GridSize(args.M, args.N, args.k_batch); + constexpr dim3 blocks = Kernel::BlockSize(); + + if(!Kernel::IsSupportedArgument(kargs)) + { + throw std::runtime_error("Wrong! Arguments not supported! 
Skipping gemm!\n"); + } + + if(s.log_level_ > 0) + { + std::cout << "Launching kernel with args:" + << " grid: {" << grids.x << ", " << grids.y << ", " << grids.z << "}" + << ", blocks: {" << blocks.x << ", " << blocks.y << ", " << blocks.z << "}" + << std::endl; + } + + float ave_time = ck_tile::launch_kernel( + s, ck_tile::make_kernel(Kernel{}, grids, blocks, 0, kargs)); + + return ave_time; + }; + if(args.k_batch == 1) { - throw std::runtime_error("Wrong! Arguments not supported! Skipping gemm!\n"); + return Run(ck_tile::integral_constant{}); } - - if(s.log_level_ > 0) + else { - std::cout << "Launching kernel with args:" - << " grid: {" << grids.x << ", " << grids.y << ", " << grids.z << "}" - << ", blocks: {" << blocks.x << ", " << blocks.y << ", " << blocks.z << "}" - << std::endl; + return Run(ck_tile::integral_constant{}); } - - float ave_time = ck_tile::launch_kernel( - s, ck_tile::make_kernel(Kernel{}, grids, blocks, 0, kargs)); - - return ave_time; } #include "run_flatmm_example.inc" +int run_flatmm_example(int argc, char* argv[]) +{ + auto [result, arg_parser] = create_args(argc, argv); + if(!result) + return -1; + + using Row = ck_tile::tensor_layout::gemm::RowMajor; + using Col = ck_tile::tensor_layout::gemm::ColumnMajor; + + std::string data_type = arg_parser.get_str("prec"); + std::string a_layout = arg_parser.get_str("a_layout"); + std::string b_layout = arg_parser.get_str("b_layout"); + + if(a_layout == "R" && b_layout == "C") + { + if(data_type == "fp16") + { + run_flatmm_example_with_layouts(argc, argv, Row{}, Col{}, Row{}); + } + else if(data_type == "bf16") + { + run_flatmm_example_with_layouts(argc, argv, Row{}, Col{}, Row{}); + } + else if(data_type == "fp8") + { + run_flatmm_example_with_layouts(argc, argv, Row{}, Col{}, Row{}); + } + else if(data_type == "bf8") + { + run_flatmm_example_with_layouts(argc, argv, Row{}, Col{}, Row{}); + } + else + { + throw std::runtime_error("Unsupported data_type!"); + } + } + else + { + throw 
std::runtime_error("Unsupported data layout configuration for A,B and C tensors!"); + } + return -1; +} + int main(int argc, char* argv[]) { return !run_flatmm_example(argc, argv); } diff --git a/example/ck_tile/18_flatmm/flatmm_basic.hpp b/example/ck_tile/18_flatmm/flatmm_basic.hpp index 355ac45ebe..bbce978724 100644 --- a/example/ck_tile/18_flatmm/flatmm_basic.hpp +++ b/example/ck_tile/18_flatmm/flatmm_basic.hpp @@ -31,7 +31,7 @@ #error "unsupported CK_TILE_PIPELINE_DEFAULT value" #endif -template +template struct GemmBasicTypeConfig; template <> @@ -44,9 +44,47 @@ struct GemmBasicTypeConfig // ToDo: Add more bias config to support different categories of GEMM. }; +template <> +struct GemmBasicTypeConfig +{ + using ADataType = ck_tile::bf16_t; + using BDataType = ck_tile::bf16_t; + using AccDataType = float; + using CDataType = ck_tile::bf16_t; +}; +template <> +struct GemmBasicTypeConfig +{ + using ADataType = ck_tile::fp8_t; + using BDataType = ck_tile::fp8_t; + using AccDataType = float; + using CDataType = ck_tile::half_t; + // ToDo: Add more bias config to support different categories of GEMM. 
+}; + +template <> +struct GemmBasicTypeConfig +{ + using ADataType = ck_tile::bf8_t; + using BDataType = ck_tile::bf8_t; + using AccDataType = float; + using CDataType = ck_tile::half_t; +}; + template struct DataTypeTraits; +template <> +struct DataTypeTraits +{ + static constexpr const char* name = "fp8"; +}; + +template <> +struct DataTypeTraits +{ + static constexpr const char* name = "bf8"; +}; template <> struct DataTypeTraits { @@ -65,13 +103,11 @@ struct DataTypeTraits static constexpr const char* name = "fp16"; }; -using Types = GemmBasicTypeConfig; - -// Specific type aliases for easy access -using ADataType = Types::ADataType; -using BDataType = Types::BDataType; -using AccDataType = Types::AccDataType; -using CDataType = Types::CDataType; +template +struct is_8bit_type + : std::bool_constant || std::is_same_v> +{ +}; auto create_args(int argc, char* argv[]) { diff --git a/example/ck_tile/18_flatmm/run_flatmm_example.inc b/example/ck_tile/18_flatmm/run_flatmm_example.inc index 864d888074..15a9df2c0c 100644 --- a/example/ck_tile/18_flatmm/run_flatmm_example.inc +++ b/example/ck_tile/18_flatmm/run_flatmm_example.inc @@ -1,6 +1,20 @@ // SPDX-License-Identifier: MIT // Copyright (c) 2024-2025, Advanced Micro Devices, Inc. All rights reserved. 
#pragma once +#include + +template +constexpr const char* DataTypeToString() { + if constexpr (std::is_same_v) { + return "fp16"; + } else if constexpr (std::is_same_v) { + return "fp8"; + } else if constexpr (std::is_same_v) { + return "bf8"; + } else { + return "unknown"; + } +} template static constexpr inline auto is_row_major(Layout layout_) @@ -11,7 +25,7 @@ static constexpr inline auto is_row_major(Layout layout_) // mfma_type, 0:32x32, 1:16x16 template -auto shuffle_b(const ck_tile::HostTensor& t, std::string mfma_dtype, int mfma_type = 0) +auto shuffle_b(const ck_tile::HostTensor& t, std::string mfma_dtype, int mfma_type) { assert(t.get_lengths().size() == 2); int n_ = t.get_lengths()[1]; @@ -29,13 +43,13 @@ auto shuffle_b(const ck_tile::HostTensor& t, std::string mfma_dtype, int mfma std::copy(t.begin(), t.end(), t_view.begin()); return ck_tile::reference_permute(t_view, {0, 2, 3, 1, 4}); } - else if((mfma_dtype == "int8" || mfma_dtype == "fp8") && mfma_type == 0) + else if((mfma_dtype == "int8" || mfma_dtype == "fp8" || mfma_dtype == "bf8") && mfma_type == 0) { ck_tile::HostTensor t_view({n_ / 32, 32, k_ / 32, 2, 16}); std::copy(t.begin(), t.end(), t_view.begin()); return ck_tile::reference_permute(t_view, {0, 2, 3, 1, 4}); } - else if((mfma_dtype == "int8" || mfma_dtype == "fp8") && mfma_type == 1) + else if((mfma_dtype == "int8" || mfma_dtype == "fp8" || mfma_dtype == "bf8") && mfma_type == 1) { ck_tile::HostTensor t_view({n_ / 16, 16, k_ / 64, 4, 16}); std::copy(t.begin(), t.end(), t_view.begin()); @@ -44,6 +58,7 @@ auto shuffle_b(const ck_tile::HostTensor& t, std::string mfma_dtype, int mfma return t; } +template auto calculate_rtol_atol(const ck_tile::index_t K, const ck_tile::index_t kbatch, const float max_accumulated_value) @@ -64,7 +79,13 @@ auto calculate_rtol_atol(const ck_tile::index_t K, return ck_tile::make_tuple(std::max(rtol, rtol_split_k), std::max(atol, atol_split_k)); } -template +template float invoke_flatmm(ck_tile::DeviceMem& 
a_dev_buf, ck_tile::DeviceMem& b_shuffle_dev_buf, ck_tile::DeviceMem& c_dev_buf, @@ -91,7 +112,7 @@ float invoke_flatmm(ck_tile::DeviceMem& a_dev_buf, args.stride_B = stride_B; args.stride_C = stride_C; - float ave_time = flatmm_calc( + float ave_time = flatmm_calc( args, ck_tile::stream_config{nullptr, true, 1, n_warmup, n_repeat}); std::size_t flop = std::size_t(2) * M * N * K; @@ -100,7 +121,7 @@ float invoke_flatmm(ck_tile::DeviceMem& a_dev_buf, float tflops = static_cast(flop) / 1.E9 / ave_time; float gb_per_sec = num_byte / 1.E6 / ave_time; - std::cout << "Run Flatmm kernel with M =" << M << " N =" << N << " K =" << K + std::cout << "Run Flatmm kernel with DataType = " << DataTypeToString() << " M =" << M << " N =" << N << " K =" << K << " StrideA =" << stride_A << " StrideB =" << stride_B << " StrideC =" << stride_C << " : " << ave_time << " ms, " << tflops << " TFlops, " << gb_per_sec << " GB/s, " << std::endl; @@ -108,7 +129,10 @@ float invoke_flatmm(ck_tile::DeviceMem& a_dev_buf, return ave_time; } -template +template int run_flatmm_example_with_layouts(int argc, char* argv[], const ALayout a_layout = ALayout{}, @@ -119,6 +143,11 @@ int run_flatmm_example_with_layouts(int argc, if(!result) return -1; + using ADataType = typename GemmBasicTypeConfig::ADataType; + using BDataType = typename GemmBasicTypeConfig::BDataType; + using CDataType = typename GemmBasicTypeConfig::CDataType; + using AccDataType = typename GemmBasicTypeConfig::AccDataType; + ck_tile::index_t M = arg_parser.get_int("m"); ck_tile::index_t N = arg_parser.get_int("n"); ck_tile::index_t K = arg_parser.get_int("k"); @@ -154,11 +183,17 @@ int run_flatmm_example_with_layouts(int argc, // do pre-shuffle std::string mfma = arg_parser.get_str("prec"); - ck_tile::HostTensor b_shuffle_host = shuffle_b(b_origin_host, mfma, 0); +#if defined(USING_MFMA_16x16x32) && defined(ENABLE_FP8) + ck_tile::index_t mfma_type = 1; +#else + ck_tile::index_t mfma_type = 0; +#endif + ck_tile::HostTensor 
b_shuffle_host = shuffle_b(b_origin_host, mfma, mfma_type); ck_tile::DeviceMem b_shuffle_dev_buf(b_shuffle_host.get_element_space_size_in_bytes()); b_shuffle_dev_buf.ToDevice(b_shuffle_host.data()); - invoke_flatmm(a_dev_buf, + invoke_flatmm( + a_dev_buf, b_shuffle_dev_buf, c_dev_buf, M, @@ -184,7 +219,7 @@ int run_flatmm_example_with_layouts(int argc, a_host, b_origin_host, c_ref_host); const float max_accumulated_value = *std::max_element(c_ref_host.mData.begin(), c_ref_host.mData.end()); - const auto rtol_atol = calculate_rtol_atol(K, kbatch, max_accumulated_value); + const auto rtol_atol = calculate_rtol_atol(K, kbatch, max_accumulated_value); pass = ck_tile::check_err(c_rslt_host, c_ref_host, "Error: Incorrect results!", @@ -242,7 +277,7 @@ int run_flatmm_example_with_layouts(int argc, c_gpu_ref_dev_buf.FromDevice(c_gpu_ref_host.data()); const float max_accumulated_value = *std::max_element(c_gpu_ref_host.mData.begin(), c_gpu_ref_host.mData.end()); - const auto rtol_atol = calculate_rtol_atol(K, kbatch, max_accumulated_value); + const auto rtol_atol = calculate_rtol_atol(K, kbatch, max_accumulated_value); pass = ck_tile::check_err(c_rslt_host, c_gpu_ref_host, "Error: Incorrect results!", @@ -257,25 +292,3 @@ int run_flatmm_example_with_layouts(int argc, return pass; } - -int run_flatmm_example(int argc, char* argv[]) -{ - auto [result, arg_parser] = create_args(argc, argv); - if(!result) - return -1; - - using Row = ck_tile::tensor_layout::gemm::RowMajor; - using Col = ck_tile::tensor_layout::gemm::ColumnMajor; - - std::string a_layout = arg_parser.get_str("a_layout"); - std::string b_layout = arg_parser.get_str("b_layout"); - - if(a_layout == "R" && b_layout == "C") - { - return run_flatmm_example_with_layouts(argc, argv, Row{}, Col{}, Row{}); - } - else - { - throw std::runtime_error("Unsupported data layout configuration for A,B and C tensors!"); - } -} diff --git a/include/ck_tile/ops/flatmm/block/block_flatmm_asmem_bsmem_creg_v1.hpp 
b/include/ck_tile/ops/flatmm/block/block_flatmm_asmem_bsmem_creg_v1.hpp index 935eb2c028..18b2fe6483 100644 --- a/include/ck_tile/ops/flatmm/block/block_flatmm_asmem_bsmem_creg_v1.hpp +++ b/include/ck_tile/ops/flatmm/block/block_flatmm_asmem_bsmem_creg_v1.hpp @@ -66,76 +66,24 @@ struct BlockFlatmmASmemBSmemCRegV1 } // C += A * B - template + template CK_TILE_DEVICE void operator()(CBlockTensor& c_block_tensor, - const ABlockWindow& a_block_window, - const BFlatBlockWindow& b_flat_block_window) const + ABlockWindow& a_warp_windows, + BFlatBlockTensor& b_warp_tensor) const { - static_assert(std::is_same_v && - std::is_same_v && - std::is_same_v, - "wrong!"); - constexpr index_t MPerBlock = ABlockWindow{}.get_window_lengths()[number<0>{}]; - constexpr index_t KPerBlock = ABlockWindow{}.get_window_lengths()[number<1>{}]; - - static_assert(MPerBlock == BlockGemmShape::kM && KPerBlock == BlockGemmShape::kK, "wrong!"); + constexpr index_t MPerBlock = BlockGemmShape::kM; + constexpr index_t KPerBlock = BlockGemmShape::kK; constexpr auto config = BlockPolicy::template GetWarpGemmMWarpNWarp(); using WG = remove_cvref_t())>; constexpr index_t MWarp = config.template at<1>(); - constexpr index_t NWarp = config.template at<2>(); constexpr index_t MIterPerWarp = MPerBlock / (MWarp * WG::kM); constexpr index_t NIterPerWarp = BlockTile::at(idxN) / (WarpTile::at(idxN) * BlockWarps::at(idxN)); constexpr index_t KIterPerWarp = KPerBlock / WG::kK; - constexpr index_t MPerBlockPerIter = MPerBlock / MIterPerWarp; - constexpr index_t KPerBlockPerIter = KPerBlock / KIterPerWarp; - - constexpr index_t NFlatPerBlockPerIter = BlockGemmShape::flatNPerWarp; - constexpr index_t KFlatPerBlockPerIter = BlockGemmShape::flatKPerWarp; - - const index_t iMWarp = get_warp_id() / NWarp; - - // construct A-warp-window - auto a_warp_window_tmp = make_tile_window( - a_block_window.get_bottom_tensor_view(), - make_tuple(number{}, number{}), - a_block_window.get_window_origin() + multi_index<2>{iMWarp * 
WG::kM, 0}, - make_static_tile_distribution(typename WG::AWarpDstrEncoding{})); - statically_indexed_array< - statically_indexed_array, - MIterPerWarp> - a_warp_windows; - static_for<0, MIterPerWarp, 1>{}([&](auto mIter) { - static_for<0, KIterPerWarp, 1>{}([&](auto kIter) { - a_warp_windows(mIter)(kIter) = a_warp_window_tmp; - - move_tile_window(a_warp_windows(mIter)(kIter), - {mIter * MPerBlockPerIter, kIter * KPerBlockPerIter}); - }); - }); - - // construct Bflat-warp-window - auto b_flat_warp_windows_tmp = b_flat_block_window; - statically_indexed_array< - statically_indexed_array, - NIterPerWarp> - b_flat_warp_windows; - static_for<0, NIterPerWarp, 1>{}([&](auto nIter) { - static_for<0, KIterPerWarp, 1>{}([&](auto kIter) { - b_flat_warp_windows(nIter)(kIter) = b_flat_warp_windows_tmp; - - move_tile_window(b_flat_warp_windows(nIter)(kIter), - {nIter * NFlatPerBlockPerIter, kIter * KFlatPerBlockPerIter}); - }); - }); - - // auto b_warp_windows = b_origin_warp_windows; - auto b_warp_windows = b_flat_warp_windows; - using CWarpDstr = typename WG::CWarpDstr; using CWarpTensor = typename WG::CWarpTensor; @@ -150,9 +98,6 @@ struct BlockFlatmmASmemBSmemCRegV1 const auto a_warp_tensor = load_tile(a_warp_windows(mIter)(kIter)); static_for<0, NIterPerWarp, 1>{}([&](auto nIter) { - // read B warp tensor from B Block window - const auto b_warp_tensor = load_tile(b_warp_windows(nIter)(kIter)); - // read C warp tensor from C block tensor CWarpTensor c_warp_tensor; @@ -161,7 +106,7 @@ struct BlockFlatmmASmemBSmemCRegV1 merge_sequences(sequence<1, 1>{}, c_warp_y_lengths)); // warp GEMM - WG{}(c_warp_tensor, a_warp_tensor, b_warp_tensor); + WG{}(c_warp_tensor, a_warp_tensor, b_warp_tensor(nIter)(kIter)); // write C warp tensor into C block tensor c_block_tensor.set_y_sliced_thread_data( @@ -172,16 +117,6 @@ struct BlockFlatmmASmemBSmemCRegV1 }); }); } - - // C = A * B - template - CK_TILE_DEVICE auto operator()(const ABlockTensorTmp& a_block_tensor_tmp, - const 
BFlatBlockWindow& b_flat_block_window) const - { - auto c_block_tensor = MakeCBlockTile(); - operator()(c_block_tensor, a_block_tensor_tmp, b_flat_block_window); - return c_block_tensor; - } }; } // namespace ck_tile diff --git a/include/ck_tile/ops/flatmm/kernel/flatmm_kernel.hpp b/include/ck_tile/ops/flatmm/kernel/flatmm_kernel.hpp index eb45e6c0bd..a9ed1519e6 100644 --- a/include/ck_tile/ops/flatmm/kernel/flatmm_kernel.hpp +++ b/include/ck_tile/ops/flatmm/kernel/flatmm_kernel.hpp @@ -321,7 +321,7 @@ struct FlatmmKernel const auto& c_tensor_view = [&]() { if constexpr(std::is_same_v) { - return make_naive_tensor_view( + return make_naive_tensor_view( c_ptr, make_tuple(kargs.M, kargs.N), make_tuple(kargs.stride_C, 1), @@ -330,7 +330,7 @@ struct FlatmmKernel } else { - return make_naive_tensor_view( + return make_naive_tensor_view( c_ptr, make_tuple(kargs.M, kargs.N), make_tuple(1, kargs.stride_C), @@ -426,7 +426,6 @@ struct FlatmmKernel return make_tuple(a_block_window, b_flat_block_window, c_block_window); } - template CK_TILE_DEVICE static void RunFlatmm(const ADataType* a_ptr, const BDataType* b_flat_ptr, CDataType* c_ptr, @@ -438,7 +437,8 @@ struct FlatmmKernel { // Create Gemm tensor views, pad views and tile windows const auto& gemm_tensor_views_tuple = - MakeGemmTensorViews(a_ptr, b_flat_ptr, c_ptr, kargs, splitk_batch_offset); + MakeGemmTensorViews( + a_ptr, b_flat_ptr, c_ptr, kargs, splitk_batch_offset); const auto& gemm_pad_views = MakeGemmPadViews(gemm_tensor_views_tuple); auto gemm_tile_windows = MakeGemmTileWindows(gemm_pad_views, block_idx_m, block_idx_n); @@ -453,9 +453,8 @@ struct FlatmmKernel // Run Epilogue Pipeline auto& c_block_window = gemm_tile_windows.at(I2); - EpiloguePipeline{} - .template operator()( - c_block_window, c_block_tile, smem_ptr); + EpiloguePipeline{}.template operator()( + c_block_window, c_block_tile, smem_ptr); } CK_TILE_DEVICE void operator()(FlatmmKernelArgs kargs) const @@ -475,21 +474,12 @@ struct FlatmmKernel // 
allocate LDS __shared__ char smem_ptr[GetSmemSize()]; - if(kargs.k_batch == 1) + if constexpr(!(EpiloguePipeline::MemoryOperation == memory_operation_enum::atomic_add && + EpiloguePipeline::GetVectorSizeC() % 2 != 0 && + is_any_of::value)) { RunFlatmm(a_ptr, b_flat_ptr, c_ptr, smem_ptr, kargs, splitk_batch_offset, i_m, i_n); } - else - { - // Do not compile in case where we have unsupported - // VectorSizeC & data type configuration. - if constexpr(!(EpiloguePipeline::GetVectorSizeC() % 2 != 0 && - is_any_of::value)) - { - RunFlatmm( - a_ptr, b_flat_ptr, c_ptr, smem_ptr, kargs, splitk_batch_offset, i_m, i_n); - } - } } }; diff --git a/include/ck_tile/ops/flatmm/pipeline/flatmm_pipeline_agmem_bgmem_creg_v1.hpp b/include/ck_tile/ops/flatmm/pipeline/flatmm_pipeline_agmem_bgmem_creg_v1.hpp index 611aff318f..2ff9d1ebf0 100644 --- a/include/ck_tile/ops/flatmm/pipeline/flatmm_pipeline_agmem_bgmem_creg_v1.hpp +++ b/include/ck_tile/ops/flatmm/pipeline/flatmm_pipeline_agmem_bgmem_creg_v1.hpp @@ -73,6 +73,83 @@ struct FlatmmPipelineAGmemBGmemCRegV1 return PipelinePolicy::template GetSmemSize(); } + CK_TILE_HOST_DEVICE static constexpr auto HotLoopScheduler() + { + constexpr auto config = BlockFlatmm::BlockPolicy::template GetWarpGemmMWarpNWarp(); + + using WG = remove_cvref_t())>; + + constexpr index_t MWarp = config.template at<1>(); + constexpr index_t NWarp = config.template at<2>(); + + constexpr index_t KIterPerWarp = kKPerBlock / WG::kK; + constexpr index_t MIterPerWarp = kMPerBlock / (MWarp * WG::kM); + constexpr index_t NIterPerWarp = kNPerBlock / (NWarp * WG::kN); + + constexpr index_t KPerLoad = Problem::VectorLoadSize / sizeof(ADataType); + constexpr index_t A_Buffer_Load_Inst_Num = kMPerBlock * kKPerBlock / BlockSize / KPerLoad; + constexpr index_t A_LDS_Read_Inst_Num = MIterPerWarp * KIterPerWarp; + constexpr index_t B_Buffer_Load_Inst_Num = NIterPerWarp * KIterPerWarp; + // constexpr index_t A_LDS_Read_Inst_Remain = A_LDS_Read_Inst_Num - A_Buffer_Load_Inst_Num; 
+#if defined(USING_MFMA_16x16x32) && defined(ENABLE_FP8) + static_for<0, A_Buffer_Load_Inst_Num, 1>{}([&](auto i) { + ignore = i; + __builtin_amdgcn_sched_group_barrier(0x100, 1, 0); // DS read + __builtin_amdgcn_sched_group_barrier(0x020, 1, 0); // VMEM read + __builtin_amdgcn_sched_group_barrier(0x008, 1, 0); // MFMA + }); + static_for<0, A_LDS_Read_Inst_Num - A_Buffer_Load_Inst_Num, 1>{}([&](auto i) { + ignore = i; + __builtin_amdgcn_sched_group_barrier(0x100, 1, 0); // DS read + __builtin_amdgcn_sched_group_barrier(0x008, 3, 0); // MFMA + }); + static_for<0, B_Buffer_Load_Inst_Num, 1>{}([&](auto i) { + ignore = i; + __builtin_amdgcn_sched_group_barrier(0x020, 1, 0); // VMEM read + __builtin_amdgcn_sched_group_barrier(0x008, 2, 0); // MFMA + }); + static_for<0, A_Buffer_Load_Inst_Num, 1>{}([&](auto i) { + ignore = i; + __builtin_amdgcn_sched_group_barrier(0x200, 1, 0); // DS write + __builtin_amdgcn_sched_group_barrier(0x008, 4, 0); // MFMA + }); + +#elif defined(USING_MFMA_32x32x16) + static_for<0, + A_LDS_Read_Inst_Num / 2 - A_Buffer_Load_Inst_Num - B_Buffer_Load_Inst_Num, + 1>{}([&](auto i) { + ignore = i; + __builtin_amdgcn_sched_group_barrier(0x100, 1, 0); // DS read + __builtin_amdgcn_sched_group_barrier(0x008, 1, 0); // MFMA + }); + static_for<0, A_Buffer_Load_Inst_Num, 1>{}([&](auto i) { + ignore = i; + __builtin_amdgcn_sched_group_barrier(0x020, 1, 0); // VMEM read + __builtin_amdgcn_sched_group_barrier(0x008, 1, 0); // MFMA + __builtin_amdgcn_sched_group_barrier(0x100, 1, 0); // DS read + __builtin_amdgcn_sched_group_barrier(0x008, 1, 0); // MFMA + }); + static_for<0, A_LDS_Read_Inst_Num / 2, 1>{}([&](auto i) { + ignore = i; + __builtin_amdgcn_sched_group_barrier(0x100, 1, 0); // DS read + __builtin_amdgcn_sched_group_barrier(0x008, 1, 0); // MFMA + }); + static_for<0, B_Buffer_Load_Inst_Num, 1>{}([&](auto i) { + ignore = i; + __builtin_amdgcn_sched_group_barrier(0x020, 1, 0); // VMEM read + __builtin_amdgcn_sched_group_barrier(0x008, 1, 0); // MFMA + 
__builtin_amdgcn_sched_group_barrier(0x100, 1, 0); // DS read + __builtin_amdgcn_sched_group_barrier(0x008, 1, 0); // MFMA + }); + static_for<0, A_Buffer_Load_Inst_Num, 1>{}([&](auto i) { + ignore = i; + __builtin_amdgcn_sched_group_barrier(0x200, 1, 0); // DS write + __builtin_amdgcn_sched_group_barrier(0x008, 3, 0); // MFMA + }); + __builtin_amdgcn_sched_group_barrier(0x008, 4, 0); // MFMA +#endif + } + template CK_TILE_HOST_DEVICE auto operator()(const ADramBlockWindowTmp& a_dram_block_window_tmp, const AElementFunction& a_element_func, @@ -89,6 +166,25 @@ struct FlatmmPipelineAGmemBGmemCRegV1 static_assert(kKPerBlock == ADramBlockWindowTmp{}.get_window_lengths()[number<1>{}], "wrong!"); + constexpr auto config = BlockFlatmm::BlockPolicy::template GetWarpGemmMWarpNWarp(); + + using WG = remove_cvref_t())>; + + constexpr index_t MWarp = config.template at<1>(); + constexpr index_t NWarp = config.template at<2>(); + + constexpr index_t MIterPerWarp = kMPerBlock / (MWarp * WG::kM); + constexpr index_t NIterPerWarp = kNPerBlock / (NWarp * WG::kN); + constexpr index_t KIterPerWarp = kKPerBlock / WG::kK; + + constexpr index_t KFlatPerBlockPerIter = flatKPerWarp; + constexpr index_t NFlatPerBlockPerIter = flatNPerWarp; + + constexpr index_t MPerBlockPerIter = kMPerBlock / MIterPerWarp; + constexpr index_t KPerBlockPerIter = kKPerBlock / KIterPerWarp; + + const index_t iMWarp = get_warp_id() / NWarp; + // A tile in LDS ADataType* p_a_lds = static_cast(p_smem); @@ -112,6 +208,25 @@ struct FlatmmPipelineAGmemBGmemCRegV1 auto a_lds_gemm_window = make_tile_window( a_lds_block, make_tuple(number{}, number{}), {0, 0}); + auto a_warp_window_tmp = make_tile_window( + a_lds_gemm_window.get_bottom_tensor_view(), + make_tuple(number{}, number{}), + a_lds_gemm_window.get_window_origin() + multi_index<2>{iMWarp * WG::kM, 0}, + make_static_tile_distribution(typename WG::AWarpDstrEncoding{})); + + statically_indexed_array< + statically_indexed_array, + MIterPerWarp> + a_warp_windows; 
+ static_for<0, MIterPerWarp, 1>{}([&](auto mIter) { + static_for<0, KIterPerWarp, 1>{}([&](auto kIter) { + a_warp_windows(mIter)(kIter) = a_warp_window_tmp; + + move_tile_window(a_warp_windows(mIter)(kIter), + {mIter * MPerBlockPerIter, kIter * KPerBlockPerIter}); + }); + }); + // Block GEMM auto block_flatmm = BlockFlatmm(); @@ -126,16 +241,45 @@ struct FlatmmPipelineAGmemBGmemCRegV1 b_flat_distribution); // Acc register tile - auto c_block_tile = decltype(block_flatmm(a_lds_gemm_window, b_flat_dram_window)){}; + auto c_block_tile = block_flatmm.MakeCBlockTile(); // prefetch // global read 0 auto a_block_tile = load_tile(a_copy_dram_window); + statically_indexed_array< + statically_indexed_array, + NIterPerWarp> + b_flat_dram_windows; + + statically_indexed_array< + statically_indexed_array, + NIterPerWarp> + b_warp_tensor; + + statically_indexed_array< + statically_indexed_array, + NIterPerWarp> + b_warp_tensor_2; + + static_for<0, NIterPerWarp, 1>{}([&](auto nIter) { + static_for<0, KIterPerWarp, 1>{}([&](auto kIter) { + b_flat_dram_windows(nIter)(kIter) = b_flat_dram_window; + + move_tile_window(b_flat_dram_windows(nIter)(kIter), + {nIter * NFlatPerBlockPerIter, kIter * KFlatPerBlockPerIter}); + + b_warp_tensor(nIter)(kIter) = load_tile(b_flat_dram_windows(nIter)(kIter)); + }); + }); + { // move to 1 move_tile_window(a_copy_dram_window, {0, kKPerBlock}); + // move to next flat K + move_tile_window(b_flat_dram_window, {0, BlockGemmShape::flatKPerBlock}); + // initialize C tile_elementwise_inout([](auto& c) { c = 0; }, c_block_tile); @@ -152,40 +296,116 @@ struct FlatmmPipelineAGmemBGmemCRegV1 { store_tile(a_copy_lds_window, tile_elementwise_in(a_element_func, a_block_tile)); } + block_sync_lds(); } - index_t iCounter = num_loop - 1; + index_t iCounter = num_loop / 2 - 1; while(iCounter > 0) { // global read i + 1 a_block_tile = load_tile(a_copy_dram_window); - block_sync_lds(); - // GEMM i - block_flatmm(c_block_tile, a_lds_gemm_window, b_flat_dram_window); + 
block_flatmm(c_block_tile, a_warp_windows, b_warp_tensor); block_sync_lds(); + static_for<0, NIterPerWarp, 1>{}([&](auto nIter) { + static_for<0, KIterPerWarp, 1>{}([&](auto kIter) { + b_flat_dram_windows(nIter)(kIter) = b_flat_dram_window; + + move_tile_window(b_flat_dram_windows(nIter)(kIter), + {nIter * NFlatPerBlockPerIter, kIter * KFlatPerBlockPerIter}); + + b_warp_tensor_2(nIter)(kIter) = load_tile(b_flat_dram_windows(nIter)(kIter)); + }); + }); + // move to i + 2 move_tile_window(a_copy_dram_window, {0, kKPerBlock}); + // move to next flat K + move_tile_window(b_flat_dram_window, {0, BlockGemmShape::flatKPerBlock}); + // LDS write i + 1 - const auto a_block_tile_tmp = tile_elementwise_in(a_element_func, a_block_tile); + auto a_block_tile_tmp = tile_elementwise_in(a_element_func, a_block_tile); store_tile(a_copy_lds_window, a_block_tile_tmp); + HotLoopScheduler(); + block_sync_lds(); + + // iCounter--; + + // global read i + 1 + a_block_tile = load_tile(a_copy_dram_window); + + // GEMM i + block_flatmm(c_block_tile, a_warp_windows, b_warp_tensor_2); + + block_sync_lds(); + + static_for<0, NIterPerWarp, 1>{}([&](auto nIter) { + static_for<0, KIterPerWarp, 1>{}([&](auto kIter) { + b_flat_dram_windows(nIter)(kIter) = b_flat_dram_window; + + move_tile_window(b_flat_dram_windows(nIter)(kIter), + {nIter * NFlatPerBlockPerIter, kIter * KFlatPerBlockPerIter}); + + b_warp_tensor(nIter)(kIter) = load_tile(b_flat_dram_windows(nIter)(kIter)); + }); + }); + + // move to i + 2 + move_tile_window(a_copy_dram_window, {0, kKPerBlock}); // move to next flat K move_tile_window(b_flat_dram_window, {0, BlockGemmShape::flatKPerBlock}); + // LDS write i + 1 + a_block_tile_tmp = tile_elementwise_in(a_element_func, a_block_tile); + store_tile(a_copy_lds_window, a_block_tile_tmp); + + HotLoopScheduler(); + block_sync_lds(); + iCounter--; } // tail { + // global read i + 1 + a_block_tile = load_tile(a_copy_dram_window); + + // GEMM i + block_flatmm(c_block_tile, a_warp_windows, 
b_warp_tensor); + + block_sync_lds(); + + static_for<0, NIterPerWarp, 1>{}([&](auto nIter) { + static_for<0, KIterPerWarp, 1>{}([&](auto kIter) { + b_flat_dram_windows(nIter)(kIter) = b_flat_dram_window; + + move_tile_window(b_flat_dram_windows(nIter)(kIter), + {nIter * NFlatPerBlockPerIter, kIter * KFlatPerBlockPerIter}); + + b_warp_tensor_2(nIter)(kIter) = load_tile(b_flat_dram_windows(nIter)(kIter)); + }); + }); + + // move to i + 2 + // move_tile_window(a_copy_dram_window, {0, kKPerBlock}); + + // LDS write i + 1 + const auto a_block_tile_tmp = tile_elementwise_in(a_element_func, a_block_tile); + store_tile(a_copy_lds_window, a_block_tile_tmp); + + // move to next flat K + // move_tile_window(b_flat_dram_window, {0, BlockGemmShape::flatKPerBlock}); + + HotLoopScheduler(); block_sync_lds(); // GEMM num_loop - 1 - block_flatmm(c_block_tile, a_lds_gemm_window, b_flat_dram_window); + block_flatmm(c_block_tile, a_warp_windows, b_warp_tensor_2); } return c_block_tile; diff --git a/include/ck_tile/ops/flatmm/pipeline/flatmm_pipeline_agmem_bgmem_creg_v1_policy.hpp b/include/ck_tile/ops/flatmm/pipeline/flatmm_pipeline_agmem_bgmem_creg_v1_policy.hpp index d1aac07d54..474924ec84 100644 --- a/include/ck_tile/ops/flatmm/pipeline/flatmm_pipeline_agmem_bgmem_creg_v1_policy.hpp +++ b/include/ck_tile/ops/flatmm/pipeline/flatmm_pipeline_agmem_bgmem_creg_v1_policy.hpp @@ -19,23 +19,100 @@ struct UniversalFlatmmPipelineAgBgCrPolicy CK_TILE_HOST_DEVICE static constexpr auto MakeALdsBlockDescriptor() { using namespace ck_tile; - - constexpr index_t kMPerBlock = Problem::BlockGemmShape::kM; - constexpr index_t kKPerBlock = Problem::BlockGemmShape::kK; +#if defined(USING_MFMA_16x16x32) && defined(ENABLE_FP8) + /*reduce transform layers,compare with old ck*/ + constexpr index_t MPerBlock = Problem::BlockGemmShape::kM; + constexpr index_t KPerBlock = Problem::BlockGemmShape::kK; + constexpr index_t KPack = GetSmemPackA(); constexpr auto a_lds_block_desc_0 = make_naive_tensor_descriptor( 
- make_tuple(number{}, number{}, number<8>{}), - make_tuple(number<(kMPerBlock + 1) * 8>{}, number<8>{}, number<1>{}), - number<8>{}, + make_tuple(number{}, number{}, number{}), + make_tuple(number{}, number{}, number<1>{}), + number{}, + number<1>{}); + + constexpr auto a_lds_block_desc_permuted = transform_tensor_descriptor( + a_lds_block_desc_0, + make_tuple( + make_xor_transform(make_tuple(number{}, number{})), + make_pass_through_transform(number{})), + make_tuple(sequence<1, 0>{}, sequence<2>{}), + make_tuple(sequence<1, 0>{}, sequence<2>{})); + + constexpr auto a_lds_block_desc = transform_tensor_descriptor( + a_lds_block_desc_permuted, + make_tuple(make_pass_through_transform(number{}), + make_merge_transform_v3_division_mod( + make_tuple(number{}, number{}))), + make_tuple(sequence<1>{}, sequence<0, 2>{}), + make_tuple(sequence<0>{}, sequence<1>{})); + + return a_lds_block_desc; +#elif defined(USING_MFMA_32x32x16) + constexpr index_t kMPerBlock = Problem::BlockGemmShape::kM; + constexpr index_t kKPerBlock = Problem::BlockGemmShape::kK; + constexpr index_t kKPack = GetSmemPackA(); + + constexpr auto a_lds_block_desc_0 = make_naive_tensor_descriptor( + make_tuple(number{}, number{}, number{}), + make_tuple(number<(kMPerBlock + 1) * kKPack>{}, number{}, number<1>{}), + number{}, number<1>{}); constexpr auto a_lds_block_desc = transform_tensor_descriptor( a_lds_block_desc_0, make_tuple(make_pass_through_transform(kMPerBlock), - make_merge_transform(make_tuple(kKPerBlock / 8, 8))), + make_merge_transform(make_tuple(kKPerBlock / kKPack, kKPack))), make_tuple(sequence<1>{}, sequence<0, 2>{}), make_tuple(sequence<0>{}, sequence<1>{})); + return a_lds_block_desc; +#endif +/*xor*/ +#if 0 + constexpr index_t kMPerBlock = Problem::BlockGemmShape::kM; + constexpr index_t kKPerBlock = Problem::BlockGemmShape::kK; + constexpr index_t kKPack = GetSmemPackA(); + using ADataType = remove_cvref_t; + + constexpr auto DataTypeSize = sizeof(ADataType); + constexpr auto 
MLdsLayer = + (32 * 4 / kKPerBlock / DataTypeSize) < 1 ? 1 : (32 * 4 / kKPerBlock / DataTypeSize); + + constexpr auto a_lds_block_desc_0 = make_naive_tensor_descriptor( + make_tuple(number{}, + number{}, + number{}), + make_tuple(number{}, number{}, number<1>{}), + number{}, + number<1>{}); + + constexpr auto a_lds_block_desc_permuted = transform_tensor_descriptor( + a_lds_block_desc_0, + make_tuple(make_xor_transform(make_tuple(number{}, + number{})), + make_pass_through_transform(number{})), + make_tuple(sequence<1, 0>{}, sequence<2>{}), + make_tuple(sequence<1, 0>{}, sequence<2>{})); + + constexpr auto a_lds_block_desc_xk0_mnldslayer_mn_xk1 = transform_tensor_descriptor( + a_lds_block_desc_permuted, + make_tuple(make_unmerge_transform( + make_tuple(number{}, number{})), + make_pass_through_transform(number{}), + make_pass_through_transform(number{})), + make_tuple(sequence<0>{}, sequence<1>{}, sequence<2>{}), + make_tuple(sequence<0, 2>{}, sequence<1>{}, sequence<3>{})); + + constexpr auto a_lds_block_desc = transform_tensor_descriptor( + a_lds_block_desc_xk0_mnldslayer_mn_xk1, + make_tuple(make_merge_transform( + make_tuple(number{}, number{})), + make_merge_transform( + make_tuple(number{}, number{}))), + make_tuple(sequence<1, 0>{}, sequence<2, 3>{}), + make_tuple(sequence<0>{}, sequence<1>{})); +#endif return a_lds_block_desc; } @@ -58,7 +135,7 @@ struct UniversalFlatmmPipelineAgBgCrPolicy template CK_TILE_HOST_DEVICE static constexpr auto GetSmemPackA() { - return Problem::VectorLoadSize; + return Problem::VectorLoadSize / sizeof(typename Problem::ADataType); } template @@ -82,7 +159,7 @@ struct UniversalFlatmmPipelineAgBgCrPolicy constexpr index_t KPack = GetSmemPackA(); static_assert(KPack % K3 == 0); constexpr index_t K2 = KPack / K3; - if constexpr(get_warp_size() % (K2 * M0)) + if constexpr(get_warp_size() >= (K2 * M0)) { constexpr index_t K1 = get_warp_size() / (K2 * M0); constexpr index_t K0 = BlockSize / get_warp_size(); @@ -209,7 +286,7 @@ struct 
UniversalFlatmmPipelineAgBgCrPolicy static_assert(kKPack % K3 == 0); constexpr index_t K2 = kKPack / K3; // TODO: this dimention could be outside single wave constexpr index_t warp_size = get_warp_size(); - if constexpr(warp_size % (K2 * M0) == 0) + if constexpr(warp_size >= (K2 * M0)) { constexpr index_t K1 = warp_size / (K2 * M0); constexpr index_t K0 = kBlockSize / warp_size; diff --git a/include/ck_tile/ops/gemm/warp/warp_gemm.hpp b/include/ck_tile/ops/gemm/warp/warp_gemm.hpp index e75aca1d91..c98d46e3a0 100644 --- a/include/ck_tile/ops/gemm/warp/warp_gemm.hpp +++ b/include/ck_tile/ops/gemm/warp/warp_gemm.hpp @@ -193,6 +193,14 @@ using WarpGemmMfmaBf16Bf16F32M64N4K16 = WarpGemmImpl>>; +using WarpGemmMfma_f32_32x32x32_fp8_fp8 = WarpGemmImpl, + 2>>; + +using WarpGemmMfma_f32_32x32x32_bf8_bf8 = WarpGemmImpl, + 2>>; + using WarpGemmMfma_f32_32x32x16_fp8_bf8 = WarpGemmImpl< WarpGemmAtrributeMfma>>; diff --git a/include/ck_tile/ops/gemm/warp/warp_gemm_attribute_mfma_impl.hpp b/include/ck_tile/ops/gemm/warp/warp_gemm_attribute_mfma_impl.hpp index 96c3c3d29f..69d22496f1 100644 --- a/include/ck_tile/ops/gemm/warp/warp_gemm_attribute_mfma_impl.hpp +++ b/include/ck_tile/ops/gemm/warp/warp_gemm_attribute_mfma_impl.hpp @@ -1022,7 +1022,7 @@ struct WarpGemmAttributeMfmaImpl_f32_16x16x32_f8_base } else if constexpr(std::is_same_v && std::is_same_v) { - DISPATCH_MFMA_("mfma_f32_116x16x32_fp8_bf8", "+v", "v", "v", "v") + DISPATCH_MFMA_("mfma_f32_16x16x32_fp8_bf8", "+v", "v", "v", "v") } else if constexpr(std::is_same_v && std::is_same_v) { diff --git a/include/ck_tile/ops/gemm/warp/warp_gemm_dispatcher.hpp b/include/ck_tile/ops/gemm/warp/warp_gemm_dispatcher.hpp index 64bd61a3dc..b2f5d56d01 100644 --- a/include/ck_tile/ops/gemm/warp/warp_gemm_dispatcher.hpp +++ b/include/ck_tile/ops/gemm/warp/warp_gemm_dispatcher.hpp @@ -57,6 +57,7 @@ template<> struct WarpGemmMfmaDispatcher struct WarpGemmMfmaDispatcher { using Type = WarpGemmMfma_f32_32x32x16_fp8_fp8; }; +template<> struct 
WarpGemmMfmaDispatcher { using Type = WarpGemmMfma_f32_32x32x32_fp8_fp8; }; template<> struct WarpGemmMfmaDispatcher { using Type = WarpGemmMfma_f32_16x16x32_fp8_fp8; }; template<> struct WarpGemmMfmaDispatcher { using Type = WarpGemmMfma_f32_16x16x64_fp8_fp8; }; template<> struct WarpGemmMfmaDispatcher { using Type = WarpGemmMfma_f32_32x32x16_fp8_fp8_CTransposed; }; @@ -65,6 +66,7 @@ template<> struct WarpGemmMfmaDispatcher struct WarpGemmMfmaDispatcher { using Type = WarpGemmMfma_f32_32x32x16_bf8_fp8; }; template<> struct WarpGemmMfmaDispatcher { using Type = WarpGemmMfma_f32_32x32x16_bf8_fp8_CTransposed; }; template<> struct WarpGemmMfmaDispatcher { using Type = WarpGemmMfma_f32_32x32x16_bf8_bf8; }; +template<> struct WarpGemmMfmaDispatcher { using Type = WarpGemmMfma_f32_32x32x32_bf8_bf8; }; template<> struct WarpGemmMfmaDispatcher { using Type = WarpGemmMfma_f32_16x16x32_bf8_bf8; }; template<> struct WarpGemmMfmaDispatcher { using Type = WarpGemmMfma_f32_16x16x64_bf8_bf8; }; template<> struct WarpGemmMfmaDispatcher { using Type = WarpGemmMfma_f32_32x32x16_bf8_bf8_CTransposed; }; From a32d9077710d8c99283be86565a1e9f9a5aa1671 Mon Sep 17 00:00:00 2001 From: Khushbu Agarwal Date: Wed, 7 May 2025 23:09:22 -0700 Subject: [PATCH 104/443] Disable the SMFMA instruction for gfx90a. 
(#2174) * remove smfma for gfx90a * clang formatted --- include/ck_tile/ops/gemm/warp/warp_gemm.hpp | 3 ++- tile_engine/ops/gemm/gemm_instance_builder.py | 6 +++++- 2 files changed, 7 insertions(+), 2 deletions(-) diff --git a/include/ck_tile/ops/gemm/warp/warp_gemm.hpp b/include/ck_tile/ops/gemm/warp/warp_gemm.hpp index c98d46e3a0..61c61c2d9a 100644 --- a/include/ck_tile/ops/gemm/warp/warp_gemm.hpp +++ b/include/ck_tile/ops/gemm/warp/warp_gemm.hpp @@ -97,12 +97,13 @@ using WarpGemmMfmaF16F16F32M64N4K16 = WarpGemmImpl>; // fp16 2:4 structured sparsity - +#if defined(__gfx94__) || defined(__gfx950__) using WarpGemmSmfmacF16F16F32M32N32K16 = WarpGemmSmfmacImpl>>; using WarpGemmSmfmacF16F16F32M16N16K32 = WarpGemmSmfmacImpl>>; +#endif // bf16 using WarpGemmMfmaBf16Bf16F32M32N32K8 = WarpGemmImpl< diff --git a/tile_engine/ops/gemm/gemm_instance_builder.py b/tile_engine/ops/gemm/gemm_instance_builder.py index 3839523e3d..c00554df8f 100755 --- a/tile_engine/ops/gemm/gemm_instance_builder.py +++ b/tile_engine/ops/gemm/gemm_instance_builder.py @@ -535,7 +535,11 @@ struct GemmDispatcher { ((tile[6] == 32 and tile[7] == 32 and tile[8] == 16) or (tile[6] == 16 and tile[7] == 16 and tile[8] == 32)) content += f""" - run_kernel<{group}::GemmKernel<{tile[0]}, {tile[1]}, {tile[2]}, {tile[3]}, {tile[4]}, {tile[5]}, {tile[6]}, {tile[7]}, {tile[8]}, {BOOL_MAP(sparse)}>>(c_m_n_dev_buf, c_m_n_host_result, c_m_n_dev_result, verify, args, stream);""" +#if defined(__gfx908__) + run_kernel<{group}::GemmKernel<{tile[0]}, {tile[1]}, {tile[2]}, {tile[3]}, {tile[4]}, {tile[5]}, {tile[6]}, {tile[7]}, {tile[8]}, {BOOL_MAP(False)}>>(c_m_n_dev_buf, c_m_n_host_result, c_m_n_dev_result, verify, args, stream); +#else + run_kernel<{group}::GemmKernel<{tile[0]}, {tile[1]}, {tile[2]}, {tile[3]}, {tile[4]}, {tile[5]}, {tile[6]}, {tile[7]}, {tile[8]}, {BOOL_MAP(sparse)}>>(c_m_n_dev_buf, c_m_n_host_result, c_m_n_dev_result, verify, args, stream); +#endif""" content += f""" }} else {{""" for tile in 
tile_params: From c757046d49e5e5bbd3b3c9bfda95cd093e70f0e8 Mon Sep 17 00:00:00 2001 From: Thomas Ning Date: Thu, 8 May 2025 00:07:03 -0700 Subject: [PATCH 105/443] Revert "Disable the SMFMA instruction for gfx90a. (#2174)" (#2175) This reverts commit a32d9077710d8c99283be86565a1e9f9a5aa1671. --- include/ck_tile/ops/gemm/warp/warp_gemm.hpp | 3 +-- tile_engine/ops/gemm/gemm_instance_builder.py | 6 +----- 2 files changed, 2 insertions(+), 7 deletions(-) diff --git a/include/ck_tile/ops/gemm/warp/warp_gemm.hpp b/include/ck_tile/ops/gemm/warp/warp_gemm.hpp index 61c61c2d9a..c98d46e3a0 100644 --- a/include/ck_tile/ops/gemm/warp/warp_gemm.hpp +++ b/include/ck_tile/ops/gemm/warp/warp_gemm.hpp @@ -97,13 +97,12 @@ using WarpGemmMfmaF16F16F32M64N4K16 = WarpGemmImpl>; // fp16 2:4 structured sparsity -#if defined(__gfx94__) || defined(__gfx950__) + using WarpGemmSmfmacF16F16F32M32N32K16 = WarpGemmSmfmacImpl>>; using WarpGemmSmfmacF16F16F32M16N16K32 = WarpGemmSmfmacImpl>>; -#endif // bf16 using WarpGemmMfmaBf16Bf16F32M32N32K8 = WarpGemmImpl< diff --git a/tile_engine/ops/gemm/gemm_instance_builder.py b/tile_engine/ops/gemm/gemm_instance_builder.py index c00554df8f..3839523e3d 100755 --- a/tile_engine/ops/gemm/gemm_instance_builder.py +++ b/tile_engine/ops/gemm/gemm_instance_builder.py @@ -535,11 +535,7 @@ struct GemmDispatcher { ((tile[6] == 32 and tile[7] == 32 and tile[8] == 16) or (tile[6] == 16 and tile[7] == 16 and tile[8] == 32)) content += f""" -#if defined(__gfx908__) - run_kernel<{group}::GemmKernel<{tile[0]}, {tile[1]}, {tile[2]}, {tile[3]}, {tile[4]}, {tile[5]}, {tile[6]}, {tile[7]}, {tile[8]}, {BOOL_MAP(False)}>>(c_m_n_dev_buf, c_m_n_host_result, c_m_n_dev_result, verify, args, stream); -#else - run_kernel<{group}::GemmKernel<{tile[0]}, {tile[1]}, {tile[2]}, {tile[3]}, {tile[4]}, {tile[5]}, {tile[6]}, {tile[7]}, {tile[8]}, {BOOL_MAP(sparse)}>>(c_m_n_dev_buf, c_m_n_host_result, c_m_n_dev_result, verify, args, stream); -#endif""" + 
run_kernel<{group}::GemmKernel<{tile[0]}, {tile[1]}, {tile[2]}, {tile[3]}, {tile[4]}, {tile[5]}, {tile[6]}, {tile[7]}, {tile[8]}, {BOOL_MAP(sparse)}>>(c_m_n_dev_buf, c_m_n_host_result, c_m_n_dev_result, verify, args, stream);""" content += f""" }} else {{""" for tile in tile_params: From cb27e7c77fe807dbdc763feb128bbd127f49b4c8 Mon Sep 17 00:00:00 2001 From: Andriy Roshchenko <107577548+andriy-ca@users.noreply.github.com> Date: Thu, 8 May 2025 13:26:03 -0600 Subject: [PATCH 106/443] Ensure MX GEMM Instances can be Cross-Compiled for Multiple Architectures (#2171) * Re-enable MX GEMM instances * Fix compilation error when building MX GEMM for multiple architectures --- .../gpu/device/impl/device_gemm_xdl_cshuffle_v3_mx.hpp | 2 +- .../gpu/grid/gridwise_gemm_xdl_cshuffle_v3_mx.hpp | 4 ++-- library/src/tensor_operation_instance/gpu/CMakeLists.txt | 5 ++++- .../device_gemm_mx_xdl_bf8_f8_f16_mk_kn_mn.hpp | 4 +--- .../device_gemm_mx_xdl_f8_f8_bf16_km_nk_mn.hpp | 4 +--- .../device_gemm_mx_xdl_f8_f8_bf16_mk_nk_mn.hpp | 7 +------ .../device_gemm_mx_xdl_f8_f8_f16_mk_nk_mn.hpp | 7 +------ 7 files changed, 11 insertions(+), 22 deletions(-) diff --git a/include/ck/tensor_operation/gpu/device/impl/device_gemm_xdl_cshuffle_v3_mx.hpp b/include/ck/tensor_operation/gpu/device/impl/device_gemm_xdl_cshuffle_v3_mx.hpp index c37af49387..2c34be9007 100644 --- a/include/ck/tensor_operation/gpu/device/impl/device_gemm_xdl_cshuffle_v3_mx.hpp +++ b/include/ck/tensor_operation/gpu/device/impl/device_gemm_xdl_cshuffle_v3_mx.hpp @@ -714,7 +714,7 @@ struct DeviceGemmMX_Xdl_CShuffleV3 : public DeviceGemmMX using device_gemm_mx_xdl_bf8_f8_f16_mk_kn_mn_instances = std::tuple< -// clang-format off + // clang-format off //#########################| ALayout| BLayout| CLayout|AData|AScale|BData|BScale| CData| AccData| Cshuffle| A| B| C| GEMM| Scale Block| Block| MPer| NPer| KPer| AK1| BK1|MPer| NPer| MXdl| NXdl| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| 
ABlockTransfer| ABlockLds| BBlockTransfer| BBlockTransfer| BBlockTransfer| BlockTransfer| BBlockTransfer| BBlockTransfer| BBlockLds| CShuffle| CShuffle| CBlockTransferClusterLengths| CBlockTransfer| Block-wiseGemm| Block-wiseGemm| //#########################| | | | Type| Data| Type| Data| Type| Type| Type| Elementwise| Elementwise| Elementwise|Specialization| Size| Size| Block| Block| Block| | | XDL| XDL| Per| Per| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraM| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraN| MXdlPerWave| NXdlPerWave| _MBlock_MXdlPerWave_MWaveMPerXdl| ScalarPerVector| Pipeline| Pipeline| //#########################| | | | | Type| | Type| | | | Operation| Operation| Operation| | | | | | | | | | | Wave| Wave| Lengths_K0_M_K1| ArrangeOrder| | | PerVector| PerVector_K1| | Lengths_K0_N_K1| ArrangeOrder| | | PerVector| PerVector_K1| | PerShuffle| PerShuffle| _NBlock_NXdlPerWave_NWaveNPerXdl| _NWaveNPerXdl| Scheduler| Verision| //#########################| | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | -#if defined(__gfx950__) DeviceGemmMX_Xdl_CShuffleV3< Row, Row, Row, BF8, E8M0, F8, E8M0, F16, F32, F16, PassThrough, PassThrough, PassThrough, GemmSpec, ScaleBlockSize, 128, 64, 16, 128, 16, 4, 16, 16, 2, 1, S<8, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, 0, S<32, 4, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 4, 4, 0, 1, 1, S<1, 16, 1, 8>, 2, BlkGemmPipeSched, BlockGemmPipelineVersion::v1>, DeviceGemmMX_Xdl_CShuffleV3< Row, Row, Row, BF8, E8M0, F8, E8M0, F16, F32, F16, PassThrough, PassThrough, PassThrough, GemmSpec, ScaleBlockSize, 256, 256, 256, 256, 16, 4, 32, 32, 4, 4, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, false, S<16, 16, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 16, 4, false, 1, 1, S<1, 32, 1, 8>, 8, BlkGemmPipeSched, BlockGemmPipelineVersion::v1>, DeviceGemmMX_Xdl_CShuffleV3< Row, Row, Row, BF8, E8M0, F8, E8M0, F16, F32, F16, 
PassThrough, PassThrough, PassThrough, GemmSpec, ScaleBlockSize, 256, 64, 64, 256, 16, 4, 32, 32, 1, 1, S<8, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, 0, S<32, 8, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 8, 4, 0, 1, 1, S<1, 32, 1, 8>, 8, BlkGemmPipeSched, BlockGemmPipelineVersion::v1>, DeviceGemmMX_Xdl_CShuffleV3< Row, Row, Row, BF8, E8M0, F8, E8M0, F16, F32, F16, PassThrough, PassThrough, PassThrough, GemmSpec, ScaleBlockSize, 256, 128, 128, 128, 16, 4, 32, 32, 2, 2, S<8, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, 0, S<32, 8, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 16, 4, 0, 1, 1, S<1, 32, 1, 8>, 8, BlkGemmPipeSched, BlockGemmPipelineVersion::v1>, DeviceGemmMX_Xdl_CShuffleV3< Row, Row, Row, BF8, E8M0, F8, E8M0, F16, F32, F16, PassThrough, PassThrough, PassThrough, GemmSpec, ScaleBlockSize, 128, 16, 32, 512, 16, 8, 16, 16, 1, 1, S<8, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, 0, S<64, 2, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 16, 8, 0, 1, 1, S<1, 16, 1, 8>, 4, BlkGemmPipeSched, BlockGemmPipelineVersion::v1> -#endif // clang-format on >; diff --git a/library/src/tensor_operation_instance/gpu/gemm_mx/device_gemm_mx_xdl_f8_f8_bf16/device_gemm_mx_xdl_f8_f8_bf16_km_nk_mn.hpp b/library/src/tensor_operation_instance/gpu/gemm_mx/device_gemm_mx_xdl_f8_f8_bf16/device_gemm_mx_xdl_f8_f8_bf16_km_nk_mn.hpp index 5b0c5137b3..d3f74b2907 100644 --- a/library/src/tensor_operation_instance/gpu/gemm_mx/device_gemm_mx_xdl_f8_f8_bf16/device_gemm_mx_xdl_f8_f8_bf16_km_nk_mn.hpp +++ b/library/src/tensor_operation_instance/gpu/gemm_mx/device_gemm_mx_xdl_f8_f8_bf16/device_gemm_mx_xdl_f8_f8_bf16_km_nk_mn.hpp @@ -39,12 +39,11 @@ static constexpr auto ScaleBlockSize = 32; template using device_gemm_mx_xdl_f8_f8_bf16_km_nk_mn_instances = std::tuple< -// clang-format off + // clang-format off //#########################| ALayout| BLayout| CLayout|AData|AScale|BData|BScale| CData| AccData| Cshuffle| A| B| C| GEMM| Scale Block| Block| MPer| NPer| KPer| AK1| BK1|MPer| NPer| MXdl| NXdl| ABlockTransfer| ABlockTransfer| 
ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockLds| BBlockTransfer| BBlockTransfer| BBlockTransfer| BlockTransfer| BBlockTransfer| BBlockTransfer| BBlockLds| CShuffle| CShuffle| CBlockTransferClusterLengths| CBlockTransfer| Block-wiseGemm| Block-wiseGemm| //#########################| | | | Type| Data| Type| Data| Type| Type| Type| Elementwise| Elementwise| Elementwise|Specialization| Size| Size| Block| Block| Block| | | XDL| XDL| Per| Per| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraM| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraN| MXdlPerWave| NXdlPerWave| _MBlock_MXdlPerWave_MWaveMPerXdl| ScalarPerVector| Pipeline| Pipeline| //#########################| | | | | Type| | Type| | | | Operation| Operation| Operation| | | | | | | | | | | Wave| Wave| Lengths_K0_M_K1| ArrangeOrder| | | PerVector| PerVector_K1| | Lengths_K0_N_K1| ArrangeOrder| | | PerVector| PerVector_K1| | PerShuffle| PerShuffle| _NBlock_NXdlPerWave_NWaveNPerXdl| _NWaveNPerXdl| Scheduler| Verision| //#########################| | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | -#if defined(__gfx950__) DeviceGemmMX_Xdl_CShuffleV3< Col, Col, Row, F8, E8M0, F8, E8M0, BF16, F32, BF16, PassThrough, PassThrough, PassThrough, GemmSpec, ScaleBlockSize, 256, 128, 128, 128, 4, 16, 32, 32, 2, 2, S<32, 8, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 16, 4, 0, S<8, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, 0, 1, 1, S<1, 32, 1, 8>, 8, BlkGemmPipeSched, BlockGemmPipelineVersion::v1>, DeviceGemmMX_Xdl_CShuffleV3< Col, Col, Row, F8, E8M0, F8, E8M0, BF16, F32, BF16, PassThrough, PassThrough, PassThrough, GemmSpec, ScaleBlockSize, 256, 16, 256, 128, 4, 16, 16, 16, 1, 4, S<32, 4, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 4, 4, 0, S<8, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, 0, 1, 1, S<1, 16, 1, 16>, 4, BlkGemmPipeSched, BlockGemmPipelineVersion::v1>, DeviceGemmMX_Xdl_CShuffleV3< Col, Col, 
Row, F8, E8M0, F8, E8M0, BF16, F32, BF16, PassThrough, PassThrough, PassThrough, GemmSpec, ScaleBlockSize, 256, 128, 128, 64, 4, 16, 32, 32, 2, 2, S<16, 16, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 8, 4, 0, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, 0, 1, 1, S<1, 32, 1, 8>, 8, BlkGemmPipeSched, BlockGemmPipelineVersion::v1>, @@ -52,7 +51,6 @@ using device_gemm_mx_xdl_f8_f8_bf16_km_nk_mn_instances = std::tuple< DeviceGemmMX_Xdl_CShuffleV3< Col, Col, Row, F8, E8M0, F8, E8M0, BF16, F32, BF16, PassThrough, PassThrough, PassThrough, GemmSpec, ScaleBlockSize, 256, 256, 256, 128, 8, 16, 16, 16, 8, 8, S<16, 16, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 16, 8, 0, S<8, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, 0, 1, 2, S<1, 32, 1, 8>, 8, BlkGemmPipeSched, BlockGemmPipelineVersion::v1>, DeviceGemmMX_Xdl_CShuffleV3< Col, Col, Row, F8, E8M0, F8, E8M0, BF16, F32, BF16, PassThrough, PassThrough, PassThrough, GemmSpec, ScaleBlockSize, 256, 256, 256, 64, 4, 16, 32, 32, 4, 4, S<16, 16, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 16, 4, 0, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, 0, 1, 1, S<1, 32, 1, 8>, 8, BlkGemmPipeSched, BlockGemmPipelineVersion::v1>, DeviceGemmMX_Xdl_CShuffleV3< Col, Col, Row, F8, E8M0, F8, E8M0, BF16, F32, BF16, PassThrough, PassThrough, PassThrough, GemmSpec, ScaleBlockSize, 128, 128, 128, 128, 4, 16, 16, 16, 4, 8, S<32, 4, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 16, 4, 0, S<8, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, 0, 1, 1, S<1, 16, 1, 8>, 2, BlkGemmPipeSched, BlockGemmPipelineVersion::v1> -#endif // clang-format on >; diff --git a/library/src/tensor_operation_instance/gpu/gemm_mx/device_gemm_mx_xdl_f8_f8_bf16/device_gemm_mx_xdl_f8_f8_bf16_mk_nk_mn.hpp b/library/src/tensor_operation_instance/gpu/gemm_mx/device_gemm_mx_xdl_f8_f8_bf16/device_gemm_mx_xdl_f8_f8_bf16_mk_nk_mn.hpp index 8e25bcc25f..ac09df7ea2 100644 --- a/library/src/tensor_operation_instance/gpu/gemm_mx/device_gemm_mx_xdl_f8_f8_bf16/device_gemm_mx_xdl_f8_f8_bf16_mk_nk_mn.hpp +++ 
b/library/src/tensor_operation_instance/gpu/gemm_mx/device_gemm_mx_xdl_f8_f8_bf16/device_gemm_mx_xdl_f8_f8_bf16_mk_nk_mn.hpp @@ -39,21 +39,16 @@ static constexpr auto ScaleBlockSize = 32; template using device_gemm_mx_xdl_f8_f8_bf16_mk_nk_mn_instances = std::tuple< -// clang-format off + // clang-format off //#########################| ALayout| BLayout| CLayout|AData|AScale|BData|BScale| CData| AccData| Cshuffle| A| B| C| GEMM| Scale Block| Block| MPer| NPer| KPer| AK1| BK1|MPer| NPer| MXdl| NXdl| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockLds| BBlockTransfer| BBlockTransfer| BBlockTransfer| BlockTransfer| BBlockTransfer| BBlockTransfer| BBlockLds| CShuffle| CShuffle| CBlockTransferClusterLengths| CBlockTransfer| Block-wiseGemm| Block-wiseGemm| //#########################| | | | Type| Data| Type| Data| Type| Type| Type| Elementwise| Elementwise| Elementwise|Specialization| Size| Size| Block| Block| Block| | | XDL| XDL| Per| Per| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraM| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraN| MXdlPerWave| NXdlPerWave| _MBlock_MXdlPerWave_MWaveMPerXdl| ScalarPerVector| Pipeline| Pipeline| //#########################| | | | | Type| | Type| | | | Operation| Operation| Operation| | | | | | | | | | | Wave| Wave| Lengths_K0_M_K1| ArrangeOrder| | | PerVector| PerVector_K1| | Lengths_K0_N_K1| ArrangeOrder| | | PerVector| PerVector_K1| | PerShuffle| PerShuffle| _NBlock_NXdlPerWave_NWaveNPerXdl| _NWaveNPerXdl| Scheduler| Verision| //#########################| | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | -#if defined(__gfx950__) DeviceGemmMX_Xdl_CShuffleV3< Row, Col, Row, F8, E8M0, F8, E8M0, BF16, F32, BF16, PassThrough, PassThrough, PassThrough, GemmSpec, ScaleBlockSize, 128, 128, 16, 128, 16, 16, 16, 16, 4, 1, S<8, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 
16, 0, S<8, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, 0, 1, 1, S<1, 16, 1, 8>, 2, BlkGemmPipeSched, BlockGemmPipelineVersion::v1>, DeviceGemmMX_Xdl_CShuffleV3< Row, Col, Row, F8, E8M0, F8, E8M0, BF16, F32, BF16, PassThrough, PassThrough, PassThrough, GemmSpec, ScaleBlockSize, 256, 128, 128, 256, 16, 16, 32, 32, 2, 2, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, false, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, false, 1, 1, S<1, 32, 1, 8>, 8, BlkGemmPipeSched, BlockGemmPipelineVersion::v1>, DeviceGemmMX_Xdl_CShuffleV3< Row, Col, Row, F8, E8M0, F8, E8M0, BF16, F32, BF16, PassThrough, PassThrough, PassThrough, GemmSpec, ScaleBlockSize, 256, 256, 128, 64, 16, 16, 32, 32, 4, 2, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, 0, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, 0, 1, 1, S<1, 32, 1, 8>, 8, BlkGemmPipeSched, BlockGemmPipelineVersion::v1>, DeviceGemmMX_Xdl_CShuffleV3< Row, Col, Row, F8, E8M0, F8, E8M0, BF16, F32, BF16, PassThrough, PassThrough, PassThrough, GemmSpec, ScaleBlockSize, 256, 128, 128, 128, 16, 16, 32, 32, 2, 2, S<8, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, 0, S<8, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, 0, 1, 1, S<1, 32, 1, 8>, 8, BlkGemmPipeSched, BlockGemmPipelineVersion::v1>, DeviceGemmMX_Xdl_CShuffleV3< Row, Col, Row, F8, E8M0, F8, E8M0, BF16, F32, BF16, PassThrough, PassThrough, PassThrough, GemmSpec, ScaleBlockSize, 64, 16, 16, 512, 16, 16, 16, 16, 1, 1, S<8, 8, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, 0, S<8, 8, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, 0, 1, 1, S<1, 16, 1, 4>, 4, BlkGemmPipeSched, BlockGemmPipelineVersion::v1> - -//Require verification - //DeviceGemmMX_Xdl_CShuffleV3< Row, Col, Row, F8, E8M0, F8, E8M0, BF16, F32, BF16, PassThrough, PassThrough, PassThrough, GemmSpec, ScaleBlockSize, 256, 256, 256, 128, 16, 16, 16, 16, 8, 8, S<8, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, 0, S<8, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, 0, 1, 2, S<1, 32, 1, 8>, 8, BlkGemmPipeSched, BlockGemmPipelineVersion::v1> -#endif // 
clang-format on >; diff --git a/library/src/tensor_operation_instance/gpu/gemm_mx/device_gemm_mx_xdl_f8_f8_f16/device_gemm_mx_xdl_f8_f8_f16_mk_nk_mn.hpp b/library/src/tensor_operation_instance/gpu/gemm_mx/device_gemm_mx_xdl_f8_f8_f16/device_gemm_mx_xdl_f8_f8_f16_mk_nk_mn.hpp index 5fefb57257..68363de523 100644 --- a/library/src/tensor_operation_instance/gpu/gemm_mx/device_gemm_mx_xdl_f8_f8_f16/device_gemm_mx_xdl_f8_f8_f16_mk_nk_mn.hpp +++ b/library/src/tensor_operation_instance/gpu/gemm_mx/device_gemm_mx_xdl_f8_f8_f16/device_gemm_mx_xdl_f8_f8_f16_mk_nk_mn.hpp @@ -39,21 +39,16 @@ static constexpr auto ScaleBlockSize = 32; template using device_gemm_mx_xdl_f8_f8_f16_mk_nk_mn_instances = std::tuple< -// clang-format off + // clang-format off //#########################| ALayout| BLayout| CLayout|AData|AScale|BData|BScale| CData| AccData| Cshuffle| A| B| C| GEMM| Scale Block| Block| MPer| NPer| KPer| AK1| BK1|MPer| NPer| MXdl| NXdl| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockLds| BBlockTransfer| BBlockTransfer| BBlockTransfer| BlockTransfer| BBlockTransfer| BBlockTransfer| BBlockLds| CShuffle| CShuffle| CBlockTransferClusterLengths| CBlockTransfer| Block-wiseGemm| Block-wiseGemm| //#########################| | | | Type| Data| Type| Data| Type| Type| Type| Elementwise| Elementwise| Elementwise|Specialization| Size| Size| Block| Block| Block| | | XDL| XDL| Per| Per| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraM| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraN| MXdlPerWave| NXdlPerWave| _MBlock_MXdlPerWave_MWaveMPerXdl| ScalarPerVector| Pipeline| Pipeline| //#########################| | | | | Type| | Type| | | | Operation| Operation| Operation| | | | | | | | | | | Wave| Wave| Lengths_K0_M_K1| ArrangeOrder| | | PerVector| PerVector_K1| | Lengths_K0_N_K1| ArrangeOrder| | | PerVector| PerVector_K1| | PerShuffle| PerShuffle| 
_NBlock_NXdlPerWave_NWaveNPerXdl| _NWaveNPerXdl| Scheduler| Verision| //#########################| | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | -#if defined(__gfx950__) DeviceGemmMX_Xdl_CShuffleV3< Row, Col, Row, F8, E8M0, F8, E8M0, F16, F32, F16, PassThrough, PassThrough, PassThrough, GemmSpec, ScaleBlockSize, 128, 128, 16, 128, 16, 16, 16, 16, 4, 1, S<8, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, 0, S<8, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, 0, 1, 1, S<1, 16, 1, 8>, 2, BlkGemmPipeSched, BlockGemmPipelineVersion::v1>, DeviceGemmMX_Xdl_CShuffleV3< Row, Col, Row, F8, E8M0, F8, E8M0, F16, F32, F16, PassThrough, PassThrough, PassThrough, GemmSpec, ScaleBlockSize, 256, 128, 128, 256, 16, 16, 32, 32, 2, 2, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, false, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, false, 1, 1, S<1, 32, 1, 8>, 8, BlkGemmPipeSched, BlockGemmPipelineVersion::v1>, DeviceGemmMX_Xdl_CShuffleV3< Row, Col, Row, F8, E8M0, F8, E8M0, F16, F32, F16, PassThrough, PassThrough, PassThrough, GemmSpec, ScaleBlockSize, 256, 256, 128, 64, 16, 16, 32, 32, 4, 2, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, 0, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, 0, 1, 1, S<1, 32, 1, 8>, 8, BlkGemmPipeSched, BlockGemmPipelineVersion::v1>, DeviceGemmMX_Xdl_CShuffleV3< Row, Col, Row, F8, E8M0, F8, E8M0, F16, F32, F16, PassThrough, PassThrough, PassThrough, GemmSpec, ScaleBlockSize, 256, 128, 128, 128, 16, 16, 32, 32, 2, 2, S<8, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, 0, S<8, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, 0, 1, 1, S<1, 32, 1, 8>, 8, BlkGemmPipeSched, BlockGemmPipelineVersion::v1>, DeviceGemmMX_Xdl_CShuffleV3< Row, Col, Row, F8, E8M0, F8, E8M0, F16, F32, F16, PassThrough, PassThrough, PassThrough, GemmSpec, ScaleBlockSize, 64, 16, 16, 512, 16, 16, 16, 16, 1, 1, S<8, 8, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, 0, S<8, 8, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, 0, 1, 1, S<1, 16, 1, 4>, 4, BlkGemmPipeSched, 
BlockGemmPipelineVersion::v1> - - //Require verification - //DeviceGemmMX_Xdl_CShuffleV3< Row, Col, Row, F8, E8M0, F8, E8M0, F16, F32, F16, PassThrough, PassThrough, PassThrough, GemmSpec, ScaleBlockSize, 256, 256, 256, 128, 16, 16, 16, 16, 8, 8, S<8, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, 0, S<8, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, 0, 1, 2, S<1, 32, 1, 8>, 8, BlkGemmPipeSched, BlockGemmPipelineVersion::v1>, -#endif // clang-format on >; From 3448e12609f9c8a623e31e3eadc2617928f2780c Mon Sep 17 00:00:00 2001 From: Illia Silin <98187287+illsilin@users.noreply.github.com> Date: Thu, 8 May 2025 13:29:14 -0700 Subject: [PATCH 107/443] Generate ckProfiler package for gfx942 only. (#2180) * build CI for gfx942 exclusively * run the last stage in a docker with user jenkins * update the image for the last stage * ignore perf_log if not found * archive and store all packages * use ccache for building packages --- Jenkinsfile | 50 ++++++++++++++++++++++++++++++++------------------ 1 file changed, 32 insertions(+), 18 deletions(-) diff --git a/Jenkinsfile b/Jenkinsfile index a9d30d9f71..2ad96ed44b 100644 --- a/Jenkinsfile +++ b/Jenkinsfile @@ -116,7 +116,7 @@ def getDockerImage(Map conf=[:]){ def retimage try { - echo "Pulling down image: ${image}" + echo "Pulling image: ${image}" retimage = docker.image("${image}") withDockerRegistry([ credentialsId: "ck_docker_cred", url: "" ]) { retimage.pull() @@ -335,8 +335,8 @@ def cmake_build(Map conf=[:]){ } } - // Only archive from master or develop - if (package_build == true && (env.BRANCH_NAME == "develop" || env.BRANCH_NAME == "amd-master")) { + // Only archive from develop + if (package_build == true && env.BRANCH_NAME == "develop") { archiveArtifacts artifacts: "build/*.deb", allowEmptyArchive: true, fingerprint: true } //check the node gpu architecture @@ -539,13 +539,16 @@ def Build_CK(Map conf=[:]){ """ } dir("build"){ - if (params.RUN_FULL_QA && arch_type == 1 ){ + if (params.RUN_FULL_QA && arch_type == 2 ){ // 
build deb packages for all gfx9 targets on gfx90a system and prepare to export echo "Build ckProfiler package" sh 'make -j package' - archiveArtifacts artifacts: 'composablekernel-ckprofiler_*.deb' - sh 'mv composablekernel-ckprofiler_*.deb ckprofiler_0.2.0_amd64.deb' - stash includes: "ckprofiler_0.2.0_amd64.deb", name: "ckprofiler_0.2.0_amd64.deb" + archiveArtifacts artifacts: 'composablekernel*.deb' + sh 'mv composablekernel-ckprofiler_*.deb composablekernel-ckprofiler_1.1.0_amd64.deb' + sh 'mv composablekernel-dev_*.deb composablekernel-dev_1.1.0_amd64.deb' + sh 'mv composablekernel-examples_*.deb composablekernel-examples_1.1.0_amd64.deb' + sh 'mv composablekernel-tests_*.deb composablekernel-tests_1.1.0_amd64.deb' + stash includes: "composablekernel-**.deb", name: "packages" } } // run performance tests, stash the logs, results will be processed on the master node @@ -654,7 +657,8 @@ def Build_CK_and_Reboot(Map conf=[:]){ def process_results(Map conf=[:]){ env.HSA_ENABLE_SDMA=0 checkout scm - def image = getDockerImageName() + //use older image that has user jenkins + def image = "rocm/composable_kernel:ck_ub22.04_rocm6.3" def prefixpath = "/opt/rocm" // Jenkins is complaining about the render group @@ -667,12 +671,17 @@ def process_results(Map conf=[:]){ def retimage gitStatusWrapper(credentialsId: "${env.ck_git_creds}", gitHubContext: "Jenkins - ${variant}", account: 'ROCm', repo: 'composable_kernel') { - try { - (retimage, image) = getDockerImage(conf) + try + { + echo "Pulling image: ${image}" + retimage = docker.image("${image}") + withDockerRegistry([ credentialsId: "ck_docker_cred", url: "" ]) { + retimage.pull() + } } - catch (org.jenkinsci.plugins.workflow.steps.FlowInterruptedException e){ - echo "The job was cancelled or aborted" - throw e + catch(Exception ex) + { + error "Unable to locate image: ${image}" } } @@ -700,9 +709,14 @@ def process_results(Map conf=[:]){ } if (params.RUN_FULL_QA){ // unstash perf files to master - unstash 
"ckprofiler_0.2.0_amd64.deb" - sh "sshpass -p ${env.ck_deb_pw} scp -o StrictHostKeyChecking=no ckprofiler_0.2.0_amd64.deb ${env.ck_deb_user}@${env.ck_deb_ip}:/var/www/html/composable_kernel/" - unstash "perf_log" + unstash "packages" + sh "sshpass -p ${env.ck_deb_pw} scp -o StrictHostKeyChecking=no composablekernel-*.deb ${env.ck_deb_user}@${env.ck_deb_ip}:/var/www/html/composable_kernel/" + try{ + unstash "perf_log" + } + catch(Exception err){ + echo "could not locate perf_log: ${err.getMessage()}." + } try{ unstash "perf_log_gfx11" unstash "perf_log_gfx12" @@ -1114,11 +1128,11 @@ pipeline { agent{ label rocmnode("gfx942") } environment{ setup_args = """ -DCMAKE_INSTALL_PREFIX=../install \ - -DGPU_TARGETS="gfx90a;gfx942" \ + -DGPU_TARGETS="gfx942" \ -DCMAKE_CXX_FLAGS=" -O3 " """ execute_args = """ cd ../client_example && rm -rf build && mkdir build && cd build && \ cmake -DCMAKE_PREFIX_PATH="${env.WORKSPACE}/install;/opt/rocm" \ - -DGPU_TARGETS="gfx90a;gfx942" \ + -DGPU_TARGETS="gfx942" \ -DCMAKE_CXX_COMPILER="${build_compiler()}" \ -DCMAKE_CXX_FLAGS=" -O3 " .. 
&& make -j """ } From ef72a4b9bc2e5ddc63d9138cae4e5eba23d35b16 Mon Sep 17 00:00:00 2001 From: Khushbu Agarwal Date: Fri, 9 May 2025 00:18:07 -0700 Subject: [PATCH 108/443] Disable SMFMA for gfx90a (#2182) --- include/ck_tile/ops/gemm/warp/warp_gemm.hpp | 10 +++++++++- tile_engine/ops/gemm/gemm_instance_builder.py | 6 +++++- 2 files changed, 14 insertions(+), 2 deletions(-) diff --git a/include/ck_tile/ops/gemm/warp/warp_gemm.hpp b/include/ck_tile/ops/gemm/warp/warp_gemm.hpp index c98d46e3a0..5cc5ddc70e 100644 --- a/include/ck_tile/ops/gemm/warp/warp_gemm.hpp +++ b/include/ck_tile/ops/gemm/warp/warp_gemm.hpp @@ -97,12 +97,20 @@ using WarpGemmMfmaF16F16F32M64N4K16 = WarpGemmImpl>; // fp16 2:4 structured sparsity - +#if defined(__gfx94__) || defined(__gfx95__) using WarpGemmSmfmacF16F16F32M32N32K16 = WarpGemmSmfmacImpl>>; using WarpGemmSmfmacF16F16F32M16N16K32 = WarpGemmSmfmacImpl>>; +#else // gfx 90a does not support smfmac +using WarpGemmSmfmacF16F16F32M32N32K16 = WarpGemmImpl, + 2>>; +using WarpGemmSmfmacF16F16F32M16N16K32 = WarpGemmImpl, + 2>>; +#endif // bf16 using WarpGemmMfmaBf16Bf16F32M32N32K8 = WarpGemmImpl< diff --git a/tile_engine/ops/gemm/gemm_instance_builder.py b/tile_engine/ops/gemm/gemm_instance_builder.py index 3839523e3d..c00554df8f 100755 --- a/tile_engine/ops/gemm/gemm_instance_builder.py +++ b/tile_engine/ops/gemm/gemm_instance_builder.py @@ -535,7 +535,11 @@ struct GemmDispatcher { ((tile[6] == 32 and tile[7] == 32 and tile[8] == 16) or (tile[6] == 16 and tile[7] == 16 and tile[8] == 32)) content += f""" - run_kernel<{group}::GemmKernel<{tile[0]}, {tile[1]}, {tile[2]}, {tile[3]}, {tile[4]}, {tile[5]}, {tile[6]}, {tile[7]}, {tile[8]}, {BOOL_MAP(sparse)}>>(c_m_n_dev_buf, c_m_n_host_result, c_m_n_dev_result, verify, args, stream);""" +#if defined(__gfx908__) + run_kernel<{group}::GemmKernel<{tile[0]}, {tile[1]}, {tile[2]}, {tile[3]}, {tile[4]}, {tile[5]}, {tile[6]}, {tile[7]}, {tile[8]}, {BOOL_MAP(False)}>>(c_m_n_dev_buf, c_m_n_host_result, 
c_m_n_dev_result, verify, args, stream); +#else + run_kernel<{group}::GemmKernel<{tile[0]}, {tile[1]}, {tile[2]}, {tile[3]}, {tile[4]}, {tile[5]}, {tile[6]}, {tile[7]}, {tile[8]}, {BOOL_MAP(sparse)}>>(c_m_n_dev_buf, c_m_n_host_result, c_m_n_dev_result, verify, args, stream); +#endif""" content += f""" }} else {{""" for tile in tile_params: From a23390163d604d9f00ea43e920822ac7cfb0884f Mon Sep 17 00:00:00 2001 From: Mingtao Gu <145657261+mtgu0705@users.noreply.github.com> Date: Fri, 9 May 2025 23:25:31 +0800 Subject: [PATCH 109/443] fix moe gemm2 for gfx950 (#2164) Co-authored-by: mtgu0705 --- example/65_gemm_multiply_multiply/CMakeLists.txt | 2 +- example/65_gemm_multiply_multiply/moe_gemm2_xdl_pk_i4.cpp | 2 +- 2 files changed, 2 insertions(+), 2 deletions(-) diff --git a/example/65_gemm_multiply_multiply/CMakeLists.txt b/example/65_gemm_multiply_multiply/CMakeLists.txt index 5d2a097576..8d51d43c65 100644 --- a/example/65_gemm_multiply_multiply/CMakeLists.txt +++ b/example/65_gemm_multiply_multiply/CMakeLists.txt @@ -7,7 +7,7 @@ add_example_executable(example_gemm_multiply_multiply_xdl_int8 gemm_multiply_mul add_example_executable(example_moe_gemm1_xdl_fp8 moe_gemm1_xdl_fp8.cpp) add_example_executable(example_moe_gemm2_xdl_fp8 moe_gemm2_xdl_fp8.cpp) -list(APPEND gpu_list gfx942) +list(APPEND gpu_list gfx942 gfx950) set(target 0) foreach(gpu IN LISTS GPU_TARGETS) if(gpu IN_LIST gpu_list AND target EQUAL 0) diff --git a/example/65_gemm_multiply_multiply/moe_gemm2_xdl_pk_i4.cpp b/example/65_gemm_multiply_multiply/moe_gemm2_xdl_pk_i4.cpp index b9621cc9b3..3745e3d0af 100644 --- a/example/65_gemm_multiply_multiply/moe_gemm2_xdl_pk_i4.cpp +++ b/example/65_gemm_multiply_multiply/moe_gemm2_xdl_pk_i4.cpp @@ -281,7 +281,7 @@ int main(int argc, char* argv[]) break; case 4: a0_t_k_k.GenerateTensorValue(GeneratorTensor_1{}); - b0_e_n_k.GenerateTensorValue(GeneratorTensor_2{-2, 2}); + b0_e_n_k.GenerateTensorValue(GeneratorTensor_2{-2, 2}); 
d0_t_n.GenerateTensorValue(GeneratorTensor_1{}); d1_e_n.GenerateTensorValue(GeneratorTensor_1{}); d2_e_n.GenerateTensorValue(GeneratorTensor_1{}); From 6b1a339b6faca7e423fdbce67a40a8fca7445abd Mon Sep 17 00:00:00 2001 From: jefyang1 <146495389+jefyang1@users.noreply.github.com> Date: Fri, 9 May 2025 09:01:06 -0700 Subject: [PATCH 110/443] Fix grouped conv bwd data tests on gfx950 (#2173) --- ...ice_grouped_conv_bwd_data_multiple_d_xdl_cshuffle_v1.hpp | 6 ++---- 1 file changed, 2 insertions(+), 4 deletions(-) diff --git a/include/ck/tensor_operation/gpu/device/impl/device_grouped_conv_bwd_data_multiple_d_xdl_cshuffle_v1.hpp b/include/ck/tensor_operation/gpu/device/impl/device_grouped_conv_bwd_data_multiple_d_xdl_cshuffle_v1.hpp index 3028cd7cbc..41f596d160 100644 --- a/include/ck/tensor_operation/gpu/device/impl/device_grouped_conv_bwd_data_multiple_d_xdl_cshuffle_v1.hpp +++ b/include/ck/tensor_operation/gpu/device/impl/device_grouped_conv_bwd_data_multiple_d_xdl_cshuffle_v1.hpp @@ -179,8 +179,7 @@ __global__ void const ComputePtrOffsetOfN compute_ptr_offset_of_n, const index_t num_k_per_block) { -#if(!defined(__HIP_DEVICE_COMPILE__) || defined(__gfx908__) || defined(__gfx90a__) || \ - defined(__gfx940__) || defined(__gfx941__) || defined(__gfx942__)) +#if(!defined(__HIP_DEVICE_COMPILE__) || defined(__gfx9__)) // offset base pointer for each work-group const index_t g_idx = __builtin_amdgcn_readfirstlane(blockIdx.z); const index_t n_idx = __builtin_amdgcn_readfirstlane(blockIdx.y / karg.KBatch); @@ -251,8 +250,7 @@ __global__ void const ComputePtrOffsetOfN compute_ptr_offset_of_n, const index_t num_k_per_block) { -#if(!defined(__HIP_DEVICE_COMPILE__) || defined(__gfx908__) || defined(__gfx90a__) || \ - defined(__gfx940__) || defined(__gfx941__) || defined(__gfx942__)) +#if(!defined(__HIP_DEVICE_COMPILE__) || defined(__gfx9__)) const index_t g_idx = __builtin_amdgcn_readfirstlane(blockIdx.z); const index_t n_idx = __builtin_amdgcn_readfirstlane(blockIdx.y / 
karg.KBatch); const index_t k_idx = From 6fddb5708ca28a84519675ffd3f0ca5c25442706 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Bart=C5=82omiej=20Kocot?= Date: Fri, 9 May 2025 22:52:34 +0200 Subject: [PATCH 111/443] Add grouped conv fwd bias relu instances (#2179) * Add grouped conv fwd bias relu instances * fixes * fix --- ..._conv_fwd_multiple_abd_xdl_cshuffle_v3.hpp | 174 ++++---- ...d_multiple_d_xdl_large_tensor_cshuffle.hpp | 250 ++++++------ .../element/binary_element_wise_operation.hpp | 8 + ...ice_grouped_conv_fwd_xdl_comp_instance.hpp | 131 +++--- .../device_grouped_conv_fwd_xdl_instance.hpp | 383 ++++++++++-------- ...ped_conv_fwd_xdl_large_tensor_instance.hpp | 37 +- ...vice_grouped_conv_fwd_xdl_mem_instance.hpp | 171 ++++---- ...ed_conv_fwd_xdl_merged_groups_instance.hpp | 63 +-- .../grouped_convolution_forward_bias_relu.hpp | 141 +++++++ ...uped_convolution_forward_bias_relu_xdl.inc | 242 +++++++++++ .../CMakeLists.txt | 16 + ...hwgc_gkyxc_nhwgk_bf16_comp_2x_instance.cpp | 67 +++ ...l_nhwgc_gkyxc_nhwgk_bf16_comp_instance.cpp | 61 +++ ...c_gkyxc_nhwgk_bf16_comp_part2_instance.cpp | 67 +++ ..._nhwgc_gkyxc_nhwgk_bf16_16x16_instance.cpp | 60 +++ ...lu_xdl_nhwgc_gkyxc_nhwgk_bf16_instance.cpp | 60 +++ ...tensor_nhwgc_gkyxc_nhwgk_bf16_instance.cpp | 41 ++ ...gc_gkyxc_nhwgk_bf16_mem_inter_instance.cpp | 63 +++ ...gc_gkyxc_nhwgk_bf16_mem_intra_instance.cpp | 63 +++ ...groups_nhwgc_gkyxc_nhwgk_bf16_instance.cpp | 80 ++++ .../CMakeLists.txt | 16 + ...dhwgc_gkzyxc_ndhwgk_bf16_comp_instance.cpp | 127 ++++++ ...hwgc_gkzyxc_ndhwgk_bf16_16x16_instance.cpp | 58 +++ ...xdl_ndhwgc_gkzyxc_ndhwgk_bf16_instance.cpp | 58 +++ ...sor_ndhwgc_gkzyxc_ndhwgk_bf16_instance.cpp | 41 ++ ..._gkzyxc_ndhwgk_bf16_mem_inter_instance.cpp | 61 +++ ..._gkzyxc_ndhwgk_bf16_mem_intra_instance.cpp | 61 +++ ...ups_ndhwgc_gkzyxc_ndhwgk_bf16_instance.cpp | 51 +++ ...rofile_grouped_conv_fwd_bias_relu_impl.hpp | 277 +++++++++++++ .../profile_grouped_conv_fwd_impl.hpp | 2 +- test/CMakeLists.txt | 1 + 
.../CMakeLists.txt | 4 + .../test_grouped_convnd_fwd_bias_relu.cpp | 92 +++++ 33 files changed, 2477 insertions(+), 550 deletions(-) create mode 100644 library/include/ck/library/tensor_operation_instance/gpu/grouped_convolution_forward_bias_relu.hpp create mode 100644 library/include/ck/library/tensor_operation_instance/gpu/grouped_convolution_forward_bias_relu_xdl.inc create mode 100644 library/src/tensor_operation_instance/gpu/grouped_conv2d_fwd_bias_relu/CMakeLists.txt create mode 100644 library/src/tensor_operation_instance/gpu/grouped_conv2d_fwd_bias_relu/xdl/comp/device_grouped_conv2d_fwd_bias_relu_xdl_nhwgc_gkyxc_nhwgk_bf16_comp_2x_instance.cpp create mode 100644 library/src/tensor_operation_instance/gpu/grouped_conv2d_fwd_bias_relu/xdl/comp/device_grouped_conv2d_fwd_bias_relu_xdl_nhwgc_gkyxc_nhwgk_bf16_comp_instance.cpp create mode 100644 library/src/tensor_operation_instance/gpu/grouped_conv2d_fwd_bias_relu/xdl/comp/device_grouped_conv2d_fwd_bias_relu_xdl_nhwgc_gkyxc_nhwgk_bf16_comp_part2_instance.cpp create mode 100644 library/src/tensor_operation_instance/gpu/grouped_conv2d_fwd_bias_relu/xdl/device_grouped_conv2d_fwd_bias_relu_xdl_nhwgc_gkyxc_nhwgk_bf16_16x16_instance.cpp create mode 100644 library/src/tensor_operation_instance/gpu/grouped_conv2d_fwd_bias_relu/xdl/device_grouped_conv2d_fwd_bias_relu_xdl_nhwgc_gkyxc_nhwgk_bf16_instance.cpp create mode 100644 library/src/tensor_operation_instance/gpu/grouped_conv2d_fwd_bias_relu/xdl/large_tensor/device_grouped_conv2d_fwd_bias_relu_xdl_large_tensor_nhwgc_gkyxc_nhwgk_bf16_instance.cpp create mode 100644 library/src/tensor_operation_instance/gpu/grouped_conv2d_fwd_bias_relu/xdl/mem/device_grouped_conv2d_fwd_bias_relu_xdl_nhwgc_gkyxc_nhwgk_bf16_mem_inter_instance.cpp create mode 100644 library/src/tensor_operation_instance/gpu/grouped_conv2d_fwd_bias_relu/xdl/mem/device_grouped_conv2d_fwd_bias_relu_xdl_nhwgc_gkyxc_nhwgk_bf16_mem_intra_instance.cpp create mode 100644 
library/src/tensor_operation_instance/gpu/grouped_conv2d_fwd_bias_relu/xdl/merged_groups/device_grouped_conv2d_fwd_bias_relu_xdl_merged_groups_nhwgc_gkyxc_nhwgk_bf16_instance.cpp create mode 100644 library/src/tensor_operation_instance/gpu/grouped_conv3d_fwd_bias_relu/CMakeLists.txt create mode 100644 library/src/tensor_operation_instance/gpu/grouped_conv3d_fwd_bias_relu/xdl/comp/device_grouped_conv3d_fwd_bias_relu_xdl_ndhwgc_gkzyxc_ndhwgk_bf16_comp_instance.cpp create mode 100644 library/src/tensor_operation_instance/gpu/grouped_conv3d_fwd_bias_relu/xdl/device_grouped_conv3d_fwd_bias_relu_xdl_ndhwgc_gkzyxc_ndhwgk_bf16_16x16_instance.cpp create mode 100644 library/src/tensor_operation_instance/gpu/grouped_conv3d_fwd_bias_relu/xdl/device_grouped_conv3d_fwd_bias_relu_xdl_ndhwgc_gkzyxc_ndhwgk_bf16_instance.cpp create mode 100644 library/src/tensor_operation_instance/gpu/grouped_conv3d_fwd_bias_relu/xdl/large_tensor/device_grouped_conv3d_fwd_bias_relu_xdl_large_tensor_ndhwgc_gkzyxc_ndhwgk_bf16_instance.cpp create mode 100644 library/src/tensor_operation_instance/gpu/grouped_conv3d_fwd_bias_relu/xdl/mem/device_grouped_conv3d_fwd_bias_relu_xdl_ndhwgc_gkzyxc_ndhwgk_bf16_mem_inter_instance.cpp create mode 100644 library/src/tensor_operation_instance/gpu/grouped_conv3d_fwd_bias_relu/xdl/mem/device_grouped_conv3d_fwd_bias_relu_xdl_ndhwgc_gkzyxc_ndhwgk_bf16_mem_intra_instance.cpp create mode 100644 library/src/tensor_operation_instance/gpu/grouped_conv3d_fwd_bias_relu/xdl/merged_groups/device_grouped_conv3d_fwd_bias_relu_xdl_merged_groups_ndhwgc_gkzyxc_ndhwgk_bf16_instance.cpp create mode 100644 profiler/include/profiler/profile_grouped_conv_fwd_bias_relu_impl.hpp create mode 100644 test/grouped_convnd_fwd_bias_relu/CMakeLists.txt create mode 100644 test/grouped_convnd_fwd_bias_relu/test_grouped_convnd_fwd_bias_relu.cpp diff --git a/include/ck/tensor_operation/gpu/device/impl/device_grouped_conv_fwd_multiple_abd_xdl_cshuffle_v3.hpp 
b/include/ck/tensor_operation/gpu/device/impl/device_grouped_conv_fwd_multiple_abd_xdl_cshuffle_v3.hpp index a93e6ded96..bebcd72ceb 100644 --- a/include/ck/tensor_operation/gpu/device/impl/device_grouped_conv_fwd_multiple_abd_xdl_cshuffle_v3.hpp +++ b/include/ck/tensor_operation/gpu/device/impl/device_grouped_conv_fwd_multiple_abd_xdl_cshuffle_v3.hpp @@ -279,9 +279,6 @@ struct DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle_V3 static constexpr bool isMultiD = DsDataType::Size() > 0; static constexpr bool isMultiABD = isMultiA || isMultiB || isMultiD; - // multi ABD not supported - static_assert(!isMultiABD, "Multi A, Mutli B and Multi D are not supported"); - static constexpr index_t NumATensor = GetNumABTensors(); static constexpr index_t NumBTensor = GetNumABTensors(); static constexpr index_t NumDTensor = DsDataType::Size(); @@ -1080,91 +1077,96 @@ struct DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle_V3 float Run(const Argument& arg, const StreamConfig& stream_config = StreamConfig{}) { float avg_time = 0.f; - - if constexpr(is_NGCHW_GKCYX_NGKHW() || - is_NGCDHW_GKCZYX_NGKDHW()) + if constexpr(!isMultiABD) { - const index_t a_grid_size = - arg.elementwise_block_2_ctile_map_transpose_a_.CalculateGridSize( - arg.a_in_transpose_desc_); - const index_t b_grid_size = - arg.elementwise_block_2_ctile_map_transpose_b_.CalculateGridSize( - arg.b_in_transpose_desc_); + if constexpr(is_NGCHW_GKCYX_NGKHW() || + is_NGCDHW_GKCZYX_NGKDHW()) + { + const index_t a_grid_size = + arg.elementwise_block_2_ctile_map_transpose_a_.CalculateGridSize( + arg.a_in_transpose_desc_); + const index_t b_grid_size = + arg.elementwise_block_2_ctile_map_transpose_b_.CalculateGridSize( + arg.b_in_transpose_desc_); - ADataType* p_a_out_grid = type_convert(arg.p_workspace_); - BDataType* p_b_out_grid = type_convert(arg.p_workspace_) + - arg.GetWorkspaceATensorSizeBytes() / sizeof(BDataType); + ADataType* p_a_out_grid = type_convert(arg.p_workspace_); + BDataType* p_b_out_grid = + 
type_convert(arg.p_workspace_) + + arg.GetWorkspaceATensorSizeBytes() / sizeof(BDataType); - auto kernel_transpose = kernel_elementwise_dual, - ck::Tuple, - ck::Tuple, - ck::Tuple, - ck::Tuple, - ck::Tuple, - ck::Tuple, - ck::Tuple, - Block2TileMapElementwise, - Block2TileMapElementwise, - element_wise::PassThrough>; + auto kernel_transpose = + kernel_elementwise_dual, + ck::Tuple, + ck::Tuple, + ck::Tuple, + ck::Tuple, + ck::Tuple, + ck::Tuple, + ck::Tuple, + Block2TileMapElementwise, + Block2TileMapElementwise, + element_wise::PassThrough>; - avg_time += launch_and_time_kernel(stream_config, - kernel_transpose, - dim3(a_grid_size + b_grid_size), - dim3(ElementwiseBlocksize), - 0, - make_tuple(arg.a_in_transpose_desc_), - make_tuple(arg.b_in_transpose_desc_), - make_tuple(arg.a_out_transpose_desc_), - make_tuple(arg.b_out_transpose_desc_), - make_tuple(arg.p_a_grid_), - make_tuple(arg.p_b_grid_), - make_tuple(p_a_out_grid), - make_tuple(p_b_out_grid), - arg.elementwise_block_2_ctile_map_transpose_a_, - arg.elementwise_block_2_ctile_map_transpose_b_, - element_wise::PassThrough{}, - a_grid_size); + avg_time += + launch_and_time_kernel(stream_config, + kernel_transpose, + dim3(a_grid_size + b_grid_size), + dim3(ElementwiseBlocksize), + 0, + make_tuple(arg.a_in_transpose_desc_), + make_tuple(arg.b_in_transpose_desc_), + make_tuple(arg.a_out_transpose_desc_), + make_tuple(arg.b_out_transpose_desc_), + make_tuple(arg.p_a_grid_), + make_tuple(arg.p_b_grid_), + make_tuple(p_a_out_grid), + make_tuple(p_b_out_grid), + arg.elementwise_block_2_ctile_map_transpose_a_, + arg.elementwise_block_2_ctile_map_transpose_b_, + element_wise::PassThrough{}, + a_grid_size); + } + + avg_time += RunGemm(arg, stream_config); + + if constexpr(is_NGCHW_GKCYX_NGKHW() || + is_NGCDHW_GKCZYX_NGKDHW()) + { + const index_t grid_size = + arg.elementwise_block_2_ctile_map_transpose_e_.CalculateGridSize( + arg.e_in_transpose_desc_); + + const EDataType* p_e_in_grid = + type_convert(arg.p_workspace_) 
+ + (arg.GetWorkspaceATensorSizeBytes() + arg.GetWorkspaceBTensorSizeBytes()) / + sizeof(EDataType); + + EDataType* p_e_out_grid = arg.p_e_grid_; + + auto kernel_transpose = kernel_elementwise, + ck::Tuple, + ck::Tuple, + ck::Tuple, + Block2TileMapElementwise, + element_wise::PassThrough>; + + avg_time += + launch_and_time_kernel(stream_config, + kernel_transpose, + dim3(grid_size), + dim3(ElementwiseBlocksize), + 0, + make_tuple(arg.e_in_transpose_desc_), + make_tuple(arg.e_out_transpose_desc_), + make_tuple(p_e_in_grid), + make_tuple(p_e_out_grid), + arg.elementwise_block_2_ctile_map_transpose_e_, + element_wise::PassThrough{}); + } } - - avg_time += RunGemm(arg, stream_config); - - if constexpr(is_NGCHW_GKCYX_NGKHW() || - is_NGCDHW_GKCZYX_NGKDHW()) - { - const index_t grid_size = - arg.elementwise_block_2_ctile_map_transpose_e_.CalculateGridSize( - arg.e_in_transpose_desc_); - - const EDataType* p_e_in_grid = - type_convert(arg.p_workspace_) + - (arg.GetWorkspaceATensorSizeBytes() + arg.GetWorkspaceBTensorSizeBytes()) / - sizeof(EDataType); - - EDataType* p_e_out_grid = arg.p_e_grid_; - - auto kernel_transpose = kernel_elementwise, - ck::Tuple, - ck::Tuple, - ck::Tuple, - Block2TileMapElementwise, - element_wise::PassThrough>; - - avg_time += launch_and_time_kernel(stream_config, - kernel_transpose, - dim3(grid_size), - dim3(ElementwiseBlocksize), - 0, - make_tuple(arg.e_in_transpose_desc_), - make_tuple(arg.e_out_transpose_desc_), - make_tuple(p_e_in_grid), - make_tuple(p_e_out_grid), - arg.elementwise_block_2_ctile_map_transpose_e_, - element_wise::PassThrough{}); - } - return avg_time; } @@ -1182,6 +1184,12 @@ struct DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle_V3 const index_t G = arg.b_g_k_c_xs_lengths_[I0]; const index_t K = arg.b_g_k_c_xs_lengths_[I1]; const index_t C = arg.b_g_k_c_xs_lengths_[I2]; + // Move this to runtime check to align Conv instances + // with Conv Multiple D instances + if constexpr(isMultiABD) + { + return false; + } // check device 
if(get_device_name() == "gfx908") diff --git a/include/ck/tensor_operation/gpu/device/impl/device_grouped_conv_fwd_multiple_d_xdl_large_tensor_cshuffle.hpp b/include/ck/tensor_operation/gpu/device/impl/device_grouped_conv_fwd_multiple_d_xdl_large_tensor_cshuffle.hpp index b2903121b1..3c34d77cc9 100644 --- a/include/ck/tensor_operation/gpu/device/impl/device_grouped_conv_fwd_multiple_d_xdl_large_tensor_cshuffle.hpp +++ b/include/ck/tensor_operation/gpu/device/impl/device_grouped_conv_fwd_multiple_d_xdl_large_tensor_cshuffle.hpp @@ -1,5 +1,5 @@ // SPDX-License-Identifier: MIT -// Copyright (c) 2024, Advanced Micro Devices, Inc. All rights reserved. +// Copyright (c) 2024-2025, Advanced Micro Devices, Inc. All rights reserved. #pragma once @@ -192,7 +192,6 @@ struct DeviceGroupedConvFwdMultipleD_Xdl_CShuffle_Large_Tensor static constexpr index_t NumDTensor = DsDataType::Size(); static constexpr index_t MaxGemmsNum = 32; - static_assert(NumDTensor == 0, "MultiD not supported."); static constexpr auto I0 = Number<0>{}; static constexpr auto I1 = Number<1>{}; @@ -440,89 +439,94 @@ struct DeviceGroupedConvFwdMultipleD_Xdl_CShuffle_Large_Tensor input_left_pads_{input_left_pads}, input_right_pads_{input_right_pads} { - // Perform grouped gemm, generate array of tranformer for convolution - Array conv_to_gemm_transformer_arr; - Array a_grid_ptrs; - Array c_grid_ptrs; - - ck::tie(conv_to_gemm_transformer_arr, - a_grid_ptrs, - c_grid_ptrs, - gemms_count_, - is_split_valid_) = - GenerateConvToGemmTransforms( - ConvToGemmFwdTransformerLongIndexT{a_g_n_c_wis_lengths_, - a_g_n_c_wis_strides_, - b_g_k_c_xs_lengths_, - b_g_k_c_xs_strides_, - e_g_n_k_wos_lengths_, - e_g_n_k_wos_strides_, - conv_filter_strides_, - conv_filter_dilations_, - input_left_pads_, - input_right_pads_}, - static_cast(p_a), - static_cast(p_e)); - - grid_size_ = 0; - valid_gemms_count_ = 0; - - if(is_split_valid_) + if constexpr(NumDTensor == 0) { - // Create GemmArg for each gemm(conv) - for(index_t i = 0; i < 
gemms_count_; i++) + // Perform grouped gemm, generate array of tranformer for convolution + Array conv_to_gemm_transformer_arr; + Array a_grid_ptrs; + Array c_grid_ptrs; + + ck::tie(conv_to_gemm_transformer_arr, + a_grid_ptrs, + c_grid_ptrs, + gemms_count_, + is_split_valid_) = + GenerateConvToGemmTransforms( + ConvToGemmFwdTransformerLongIndexT{a_g_n_c_wis_lengths_, + a_g_n_c_wis_strides_, + b_g_k_c_xs_lengths_, + b_g_k_c_xs_strides_, + e_g_n_k_wos_lengths_, + e_g_n_k_wos_strides_, + conv_filter_strides_, + conv_filter_dilations_, + input_left_pads_, + input_right_pads_}, + static_cast(p_a), + static_cast(p_e)); + + grid_size_ = 0; + valid_gemms_count_ = 0; + + if(is_split_valid_) { - const AGridDesc_M_K a_grid_desc_m_k{DeviceOp::MakeAGridDescriptor_M_K( - conv_to_gemm_transformer_arr[i])}; - const BGridDesc_N_K b_grid_desc_n_k{DeviceOp::MakeBGridDescriptor_N_K( - conv_to_gemm_transformer_arr[i])}; - const auto e_grid_desc_m_n = - DeviceOp::MakeEGridDescriptor_M_N(conv_to_gemm_transformer_arr[i]); - - const auto block_2_etile_map = - GridwiseGemm::MakeDefaultBlock2ETileMap(e_grid_desc_m_n); - - const index_t grid_size_grp = - block_2_etile_map.CalculateGridSize(e_grid_desc_m_n); - - const index_t BlockStart = grid_size_; - const index_t BlockEnd = grid_size_ + grid_size_grp; - - grid_size_ += grid_size_grp; - - if(GridwiseGemm::CheckValidity(a_grid_desc_m_k, - b_grid_desc_n_k, - Tuple<>{}, - e_grid_desc_m_n, - block_2_etile_map)) + // Create GemmArg for each gemm(conv) + for(index_t i = 0; i < gemms_count_; i++) { + const AGridDesc_M_K a_grid_desc_m_k{ + DeviceOp::MakeAGridDescriptor_M_K( + conv_to_gemm_transformer_arr[i])}; + const BGridDesc_N_K b_grid_desc_n_k{ + DeviceOp::MakeBGridDescriptor_N_K( + conv_to_gemm_transformer_arr[i])}; + const auto e_grid_desc_m_n = DeviceOp::MakeEGridDescriptor_M_N( + conv_to_gemm_transformer_arr[i]); - gemm_desc_kernel_args_(valid_gemms_count_) = GemmArgs{ - a_grid_ptrs[i], - static_cast(p_b), - c_grid_ptrs[i], - 
GridwiseGemm::MakeDefaultAGridDescriptor_AK0_M_AK1(a_grid_desc_m_k), - GridwiseGemm::MakeDefaultBGridDescriptor_BK0_N_BK1(b_grid_desc_n_k), - GridwiseGemm::MakeEGridDescriptor_MBlock_MPerBlock_NBlock_NPerBlock( - e_grid_desc_m_n), - block_2_etile_map, - BlockStart, - BlockEnd}; + const auto block_2_etile_map = + GridwiseGemm::MakeDefaultBlock2ETileMap(e_grid_desc_m_n); - valid_gemms_count_++; + const index_t grid_size_grp = + block_2_etile_map.CalculateGridSize(e_grid_desc_m_n); + + const index_t BlockStart = grid_size_; + const index_t BlockEnd = grid_size_ + grid_size_grp; + + grid_size_ += grid_size_grp; + + if(GridwiseGemm::CheckValidity(a_grid_desc_m_k, + b_grid_desc_n_k, + Tuple<>{}, + e_grid_desc_m_n, + block_2_etile_map)) + { + + gemm_desc_kernel_args_(valid_gemms_count_) = GemmArgs{ + a_grid_ptrs[i], + static_cast(p_b), + c_grid_ptrs[i], + GridwiseGemm::MakeDefaultAGridDescriptor_AK0_M_AK1(a_grid_desc_m_k), + GridwiseGemm::MakeDefaultBGridDescriptor_BK0_N_BK1(b_grid_desc_n_k), + GridwiseGemm::MakeEGridDescriptor_MBlock_MPerBlock_NBlock_NPerBlock( + e_grid_desc_m_n), + block_2_etile_map, + BlockStart, + BlockEnd}; + + valid_gemms_count_++; + } } + // N is the same for all convs + conv_N_per_block_ = static_cast(conv_to_gemm_transformer_arr[I0].N_); } - // N is the same for all convs - conv_N_per_block_ = static_cast(conv_to_gemm_transformer_arr[I0].N_); + + // Strides for G and N remain the same + compute_ptr_offset_of_groups_.BatchStrideA_ = a_g_n_c_wis_strides[0]; + compute_ptr_offset_of_groups_.BatchStrideB_ = b_g_k_c_xs_strides[0]; + compute_ptr_offset_of_groups_.BatchStrideE_ = e_g_n_k_wos_strides[0]; + + compute_ptr_offset_of_n_.BatchStrideA_ = a_g_n_c_wis_strides[1] * conv_N_per_block_; + compute_ptr_offset_of_n_.BatchStrideE_ = e_g_n_k_wos_strides[1] * conv_N_per_block_; } - - // Strides for G and N remain the same - compute_ptr_offset_of_groups_.BatchStrideA_ = a_g_n_c_wis_strides[0]; - compute_ptr_offset_of_groups_.BatchStrideB_ = 
b_g_k_c_xs_strides[0]; - compute_ptr_offset_of_groups_.BatchStrideE_ = e_g_n_k_wos_strides[0]; - - compute_ptr_offset_of_n_.BatchStrideA_ = a_g_n_c_wis_strides[1] * conv_N_per_block_; - compute_ptr_offset_of_n_.BatchStrideE_ = e_g_n_k_wos_strides[1] * conv_N_per_block_; } void Print() const @@ -578,55 +582,63 @@ struct DeviceGroupedConvFwdMultipleD_Xdl_CShuffle_Large_Tensor { float Run(const DeviceOp::Argument& arg, const StreamConfig& stream_config = StreamConfig{}) { - if(stream_config.log_level_ > 0) + if constexpr(NumDTensor == 0) { - arg.Print(); - } + if(stream_config.log_level_ > 0) + { + arg.Print(); + } - const index_t num_workgroups_per_Conv_N = - arg.a_g_n_c_wis_lengths_[I1] / arg.conv_N_per_block_; + const index_t num_workgroups_per_Conv_N = + arg.a_g_n_c_wis_lengths_[I1] / arg.conv_N_per_block_; - const index_t gdx = arg.grid_size_; - const index_t gdy = arg.num_group_; - const index_t gdz = num_workgroups_per_Conv_N; + const index_t gdx = arg.grid_size_; + const index_t gdy = arg.num_group_; + const index_t gdz = num_workgroups_per_Conv_N; - // K is constant for all gemms - const auto K = arg.gemm_desc_kernel_args_[I0].a_grid_desc_ak0_m_ak1_.GetLength(I0) * - arg.gemm_desc_kernel_args_[I0].a_grid_desc_ak0_m_ak1_.GetLength(I2); + // K is constant for all gemms + const auto K = arg.gemm_desc_kernel_args_[I0].a_grid_desc_ak0_m_ak1_.GetLength(I0) * + arg.gemm_desc_kernel_args_[I0].a_grid_desc_ak0_m_ak1_.GetLength(I2); - auto launch_kernel = [&](auto has_main_k_block_loop) { - constexpr bool has_main_loop = has_main_k_block_loop.value; - const auto kernel = kernel_grouped_conv_fwd_multiple_d_grouped_gemm_xdl_cshuffle< - GridwiseGemm, - MaxGemmsNum, - GemmArgs, - AElementwiseOperation, - BElementwiseOperation, - CDEElementwiseOperation, - ComputePtrOffsetOfStridedBatch, - has_main_loop>; + auto launch_kernel = [&](auto has_main_k_block_loop) { + constexpr bool has_main_loop = has_main_k_block_loop.value; + const auto kernel = + 
kernel_grouped_conv_fwd_multiple_d_grouped_gemm_xdl_cshuffle< + GridwiseGemm, + MaxGemmsNum, + GemmArgs, + AElementwiseOperation, + BElementwiseOperation, + CDEElementwiseOperation, + ComputePtrOffsetOfStridedBatch, + has_main_loop>; - return launch_and_time_kernel(stream_config, - kernel, - dim3(gdx, gdy, gdz), - dim3(BlockSize), - 0, - arg.gemm_desc_kernel_args_, - arg.gemms_count_, - arg.a_element_op_, - arg.b_element_op_, - arg.cde_element_op_, - arg.compute_ptr_offset_of_groups_, - arg.compute_ptr_offset_of_n_); - }; + return launch_and_time_kernel(stream_config, + kernel, + dim3(gdx, gdy, gdz), + dim3(BlockSize), + 0, + arg.gemm_desc_kernel_args_, + arg.gemms_count_, + arg.a_element_op_, + arg.b_element_op_, + arg.cde_element_op_, + arg.compute_ptr_offset_of_groups_, + arg.compute_ptr_offset_of_n_); + }; - if(GridwiseGemm::CalculateHasMainKBlockLoop(K)) - { - return launch_kernel(integral_constant{}); + if(GridwiseGemm::CalculateHasMainKBlockLoop(K)) + { + return launch_kernel(integral_constant{}); + } + else + { + return launch_kernel(integral_constant{}); + } } else { - return launch_kernel(integral_constant{}); + return 0.f; } } @@ -643,6 +655,12 @@ struct DeviceGroupedConvFwdMultipleD_Xdl_CShuffle_Large_Tensor const long_index_t K = arg.b_g_k_c_xs_lengths_[I1]; const long_index_t C = arg.b_g_k_c_xs_lengths_[I2]; + // Move this to runtime check to align Conv instances + // with Conv Multiple D instances + if constexpr(NumDTensor != 0) + { + return false; + } // Check if all descs are valid if(!(arg.is_split_valid_ && arg.gemms_count_ == arg.valid_gemms_count_)) diff --git a/include/ck/tensor_operation/gpu/element/binary_element_wise_operation.hpp b/include/ck/tensor_operation/gpu/element/binary_element_wise_operation.hpp index 530876650e..0e58d5acb4 100644 --- a/include/ck/tensor_operation/gpu/element/binary_element_wise_operation.hpp +++ b/include/ck/tensor_operation/gpu/element/binary_element_wise_operation.hpp @@ -404,6 +404,14 @@ struct AddRelu y = a > 
type_convert(0.0f) ? a : type_convert(0.0f); }; + template <> + __host__ __device__ constexpr void + operator()(bhalf_t& y, const bhalf_t& x0, const bhalf_t& x1) const + { + const float a = type_convert(x0) + type_convert(x1); + y = a > type_convert(0.0f) ? a : type_convert(0.0f); + }; + template <> __host__ __device__ constexpr void operator()(int& y, const int& x0, const int8_t& x1) const diff --git a/library/include/ck/library/tensor_operation_instance/gpu/grouped_conv_fwd/device_grouped_conv_fwd_xdl_comp_instance.hpp b/library/include/ck/library/tensor_operation_instance/gpu/grouped_conv_fwd/device_grouped_conv_fwd_xdl_comp_instance.hpp index 158ed26ec4..17ffa65d1c 100644 --- a/library/include/ck/library/tensor_operation_instance/gpu/grouped_conv_fwd/device_grouped_conv_fwd_xdl_comp_instance.hpp +++ b/library/include/ck/library/tensor_operation_instance/gpu/grouped_conv_fwd/device_grouped_conv_fwd_xdl_comp_instance.hpp @@ -33,6 +33,7 @@ using Empty_Tuple = ck::Tuple<>; using namespace ck::tensor_layout::convolution; using PassThrough = ck::tensor_operation::element_wise::PassThrough; +using AddRelu = ck::tensor_operation::element_wise::AddRelu; static constexpr auto ConvFwdDefault = ck::tensor_operation::device::ConvolutionForwardSpecialization::Default; @@ -55,14 +56,16 @@ template + ConvolutionForwardSpecialization ConvSpec, + typename DsDataTypes = Tuple<>, + typename OutElementOp = PassThrough> using device_grouped_conv_fwd_xdl_bf16_comp_instances_2x = std::tuple< // clang-format off //########################################| NumDim| A| B| Ds| E| AData| BData| AccData| CShuffle| Ds| EData| A| B| CDE| ConvForward| GEMM| NumGemmK| Block| MPer| NPer| KPer| AK1| BK1| MPer| NPer| MXdl| NXdl| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockLds| BBlockTransfer| BBlockTransfer| BBlockTransfer| BlockTransfer| BBlockTransfer| BBlockTransfer| BBlockLds| CShuffle| CShuffle| CBlockTransferClusterLengths| 
CBlockTransfer| //########################################| Spatial| Layout| Layout| Layout| Layout| Type| Type| Type| DataType| DataType| Type| Elementwise| Elementwise| Elementwise| Specialization| Specialization| Prefetch| Size| Block| Block| Block| | | XDL| XDL| Per| Per| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraM| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraN| MXdlPerWave| NXdlPerWave| _MBlock_MWaveMPerXdl| ScalarPerVector| //########################################| | | | | | | | | | | | Operation| Operation| Operation| | | Stage| | | | | | | | | Wave| Wave| Lengths_K0_M_K1| ArrangeOrder| | | PerVector| PerVector_K1| | Lengths_K0_N_K1| ArrangeOrder| | | PerVector| PerVector_K1| | PerShuffle| PerShuffle| _NBlock_NWaveNPerXdl| _NWaveNPerXdl| //########################################| | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | - DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle_V3, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, 1, 1, S<1, 32, 1, 8>, 8, BlockGemmPipelineScheduler::Interwave, BlockGemmPipelineVersion::v1> + DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle_V3, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, 1, 1, S<1, 32, 1, 8>, 8, BlockGemmPipelineScheduler::Interwave, BlockGemmPipelineVersion::v1> // clang-format on >; @@ -71,7 +74,9 @@ template + ConvolutionForwardSpecialization ConvSpec, + typename DsDataTypes = Tuple<>, + typename OutElementOp = PassThrough> using device_grouped_conv_fwd_xdl_bf16_comp_instances = std::tuple< // clang-format off //########################################| NumDim| A| B| Ds| E| AData| BData| AccData| CShuffle| Ds| EData| A| B| CDE| ConvForward| GEMM| NumGemmK| Block| MPer| NPer| KPer| AK1| BK1| MPer| NPer| MXdl| NXdl| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| 
ABlockTransfer| ABlockLds| BBlockTransfer| BBlockTransfer| BBlockTransfer| BlockTransfer| BBlockTransfer| BBlockTransfer| BBlockLds| CShuffle| CShuffle| CBlockTransferClusterLengths| CBlockTransfer| @@ -79,17 +84,17 @@ using device_grouped_conv_fwd_xdl_bf16_comp_instances = std::tuple< //########################################| | | | | | | | | | | | Operation| Operation| Operation| | | Stage| | | | | | | | | Wave| Wave| Lengths_K0_M_K1| ArrangeOrder| | | PerVector| PerVector_K1| | Lengths_K0_N_K1| ArrangeOrder| | | PerVector| PerVector_K1| | PerShuffle| PerShuffle| _NBlock_NWaveNPerXdl| _NWaveNPerXdl| //########################################| | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | // Compute friendly - DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle_V3, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 0, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 0, 1, 1, S<1, 32, 1, 8>, 8, BlockGemmPipelineScheduler::Intrawave, BlockGemmPipelineVersion::v4>, - DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle_V3, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 0, S<8, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 0, 1, 1, S<1, 32, 1, 8>, 8, BlockGemmPipelineScheduler::Intrawave, BlockGemmPipelineVersion::v4>, - DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle_V3, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 0, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 0, 1, 1, S<1, 32, 1, 8>, 8, BlockGemmPipelineScheduler::Intrawave, BlockGemmPipelineVersion::v4>, - DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle_V3, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 0, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 0, 1, 1, S<1, 32, 1, 8>, 8, BlockGemmPipelineScheduler::Intrawave, BlockGemmPipelineVersion::v3>, - DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle_V3, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 0, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 0, 1, 1, S<1, 32, 1, 8>, 8, BlockGemmPipelineScheduler::Intrawave, BlockGemmPipelineVersion::v5>, - DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle_V3, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 
8, 0, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 0, 1, 1, S<1, 32, 1, 8>, 8, BlockGemmPipelineScheduler::Interwave, BlockGemmPipelineVersion::v1>, - DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle_V3, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 0, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 0, 1, 1, S<1, 32, 1, 8>, 8, BlockGemmPipelineScheduler::Interwave, BlockGemmPipelineVersion::v1>, - DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle_V3, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 0, S<8, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 0, 1, 1, S<1, 32, 1, 8>, 8, BlockGemmPipelineScheduler::Interwave, BlockGemmPipelineVersion::v1>, - DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle_V3, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 0, S<8, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 0, 1, 1, S<1, 32, 1, 8>, 8, BlockGemmPipelineScheduler::Intrawave, BlockGemmPipelineVersion::v3>, - DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle_V3, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 0, S<8, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 0, 1, 1, S<1, 32, 1, 8>, 8, BlockGemmPipelineScheduler::Intrawave, BlockGemmPipelineVersion::v3>, - DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle_V3, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 0, S<8, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 0, 1, 1, S<1, 32, 1, 8>, 8, BlockGemmPipelineScheduler::Intrawave, BlockGemmPipelineVersion::v3> + DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle_V3, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 0, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 0, 1, 1, S<1, 32, 1, 8>, 8, BlockGemmPipelineScheduler::Intrawave, BlockGemmPipelineVersion::v4>, + DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle_V3, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 0, S<8, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 0, 1, 1, S<1, 32, 1, 8>, 8, BlockGemmPipelineScheduler::Intrawave, BlockGemmPipelineVersion::v4>, + DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle_V3, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 0, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 0, 1, 1, S<1, 32, 1, 8>, 8, BlockGemmPipelineScheduler::Intrawave, BlockGemmPipelineVersion::v4>, + 
DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle_V3, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 0, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 0, 1, 1, S<1, 32, 1, 8>, 8, BlockGemmPipelineScheduler::Intrawave, BlockGemmPipelineVersion::v3>, + DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle_V3, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 0, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 0, 1, 1, S<1, 32, 1, 8>, 8, BlockGemmPipelineScheduler::Intrawave, BlockGemmPipelineVersion::v5>, + DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle_V3, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 0, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 0, 1, 1, S<1, 32, 1, 8>, 8, BlockGemmPipelineScheduler::Interwave, BlockGemmPipelineVersion::v1>, + DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle_V3, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 0, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 0, 1, 1, S<1, 32, 1, 8>, 8, BlockGemmPipelineScheduler::Interwave, BlockGemmPipelineVersion::v1>, + DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle_V3, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 0, S<8, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 0, 1, 1, S<1, 32, 1, 8>, 8, BlockGemmPipelineScheduler::Interwave, BlockGemmPipelineVersion::v1>, + DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle_V3, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 0, S<8, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 0, 1, 1, S<1, 32, 1, 8>, 8, BlockGemmPipelineScheduler::Intrawave, BlockGemmPipelineVersion::v3>, + DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle_V3, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 0, S<8, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 0, 1, 1, S<1, 32, 1, 8>, 8, BlockGemmPipelineScheduler::Intrawave, BlockGemmPipelineVersion::v3>, + DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle_V3, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 0, S<8, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 0, 1, 1, S<1, 32, 1, 8>, 8, BlockGemmPipelineScheduler::Intrawave, BlockGemmPipelineVersion::v3> // clang-format on >; @@ -99,15 +104,17 @@ template + ConvolutionForwardSpecialization ConvSpec, + typename DsDataTypes = Tuple<>, + typename OutElementOp 
= PassThrough> using device_grouped_conv_fwd_xdl_bf16_comp_instances_part2 = std::tuple< // clang-format off - DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle_V3, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 0, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 0, 1, 2, S<1, 32, 1, 8>, 8, BlockGemmPipelineScheduler::Intrawave, BlockGemmPipelineVersion::v3>, + DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle_V3, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 0, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 0, 1, 2, S<1, 32, 1, 8>, 8, BlockGemmPipelineScheduler::Intrawave, BlockGemmPipelineVersion::v3>, // AGPR Spill when use permuted lds layout. so, use padding for these two. - DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle_V3, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 0, S<8, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 0, 1, 2, S<1, 32, 1, 8>, 8, BlockGemmPipelineScheduler::Intrawave, BlockGemmPipelineVersion::v3>, - DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle_V3, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 0, S<8, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 0, 2, 1, S<1, 64, 1, 4>, 8, BlockGemmPipelineScheduler::Intrawave, BlockGemmPipelineVersion::v3>, - DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle_V3, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 0, S<8, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 0, 1, 1, S<1, 32, 1, 8>, 8, BlockGemmPipelineScheduler::Intrawave, BlockGemmPipelineVersion::v3>, - DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle_V3, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 0, S<8, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 0, 1, 1, S<1, 32, 1, 8>, 8, BlockGemmPipelineScheduler::Intrawave, BlockGemmPipelineVersion::v5> + DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle_V3, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 0, S<8, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 0, 1, 2, S<1, 32, 1, 8>, 8, BlockGemmPipelineScheduler::Intrawave, BlockGemmPipelineVersion::v3>, + DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle_V3, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 0, S<8, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 0, 2, 1, S<1, 64, 1, 4>, 8, BlockGemmPipelineScheduler::Intrawave, 
BlockGemmPipelineVersion::v3>, + DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle_V3, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 0, S<8, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 0, 1, 1, S<1, 32, 1, 8>, 8, BlockGemmPipelineScheduler::Intrawave, BlockGemmPipelineVersion::v3>, + DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle_V3, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 0, S<8, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 0, 1, 1, S<1, 32, 1, 8>, 8, BlockGemmPipelineScheduler::Intrawave, BlockGemmPipelineVersion::v5> // clang-format on >; @@ -117,14 +124,16 @@ template + ConvolutionForwardSpecialization ConvSpec, + typename DsDataTypes = Tuple<>, + typename OutElementOp = PassThrough> using device_grouped_conv_fwd_xdl_f16_comp_instances_2x = std::tuple< // clang-format off //########################################| NumDim| A| B| Ds| E| AData| BData| AccData| CShuffle| Ds| EData| A| B| CDE| ConvForward| GEMM| NumGemmK| Block| MPer| NPer| KPer| AK1| BK1| MPer| NPer| MXdl| NXdl| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockLds| BBlockTransfer| BBlockTransfer| BBlockTransfer| BlockTransfer| BBlockTransfer| BBlockTransfer| BBlockLds| CShuffle| CShuffle| CBlockTransferClusterLengths| CBlockTransfer| //########################################| Spatial| Layout| Layout| Layout| Layout| Type| Type| Type| DataType| DataType| Type| Elementwise| Elementwise| Elementwise| Specialization| Specialization| Prefetch| Size| Block| Block| Block| | | XDL| XDL| Per| Per| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraM| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraN| MXdlPerWave| NXdlPerWave| _MBlock_MWaveMPerXdl| ScalarPerVector| //########################################| | | | | | | | | | | | Operation| Operation| Operation| | | Stage| | | | | | | | | Wave| Wave| Lengths_K0_M_K1| ArrangeOrder| | | PerVector| PerVector_K1| | Lengths_K0_N_K1| ArrangeOrder| | | PerVector| 
PerVector_K1| | PerShuffle| PerShuffle| _NBlock_NWaveNPerXdl| _NWaveNPerXdl| //########################################| | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | - DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle_V3, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, 1, 1, S<1, 32, 1, 8>, 8, BlockGemmPipelineScheduler::Interwave, BlockGemmPipelineVersion::v1> + DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle_V3, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, 1, 1, S<1, 32, 1, 8>, 8, BlockGemmPipelineScheduler::Interwave, BlockGemmPipelineVersion::v1> // clang-format on >; @@ -133,14 +142,16 @@ template + ConvolutionForwardSpecialization ConvSpec, + typename DsDataTypes = Tuple<>, + typename OutElementOp = PassThrough> using device_grouped_conv_fwd_xdl_f16_comp_instances = std::tuple< // clang-format off //########################################| NumDim| A| B| Ds| E| AData| BData| AccData| CShuffle| Ds| EData| A| B| CDE| ConvForward| GEMM| NumGemmK| Block| MPer| NPer| KPer| AK1| BK1| MPer| NPer| MXdl| NXdl| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockLds| BBlockTransfer| BBlockTransfer| BBlockTransfer| BlockTransfer| BBlockTransfer| BBlockTransfer| BBlockLds| CShuffle| CShuffle| CBlockTransferClusterLengths| CBlockTransfer| //########################################| Spatial| Layout| Layout| Layout| Layout| Type| Type| Type| DataType| DataType| Type| Elementwise| Elementwise| Elementwise| Specialization| Specialization| Prefetch| Size| Block| Block| Block| | | XDL| XDL| Per| Per| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraM| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraN| MXdlPerWave| NXdlPerWave| _MBlock_MWaveMPerXdl| ScalarPerVector| //########################################| | | | | | | | | | | | 
Operation| Operation| Operation| | | Stage| | | | | | | | | Wave| Wave| Lengths_K0_M_K1| ArrangeOrder| | | PerVector| PerVector_K1| | Lengths_K0_N_K1| ArrangeOrder| | | PerVector| PerVector_K1| | PerShuffle| PerShuffle| _NBlock_NWaveNPerXdl| _NWaveNPerXdl| //########################################| | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | - DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle_V3, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 0, S<8, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 0, 1, 1, S<1, 32, 1, 8>, 8, BlockGemmPipelineScheduler::Intrawave, BlockGemmPipelineVersion::v4> + DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle_V3, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 0, S<8, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 0, 1, 1, S<1, 32, 1, 8>, 8, BlockGemmPipelineScheduler::Intrawave, BlockGemmPipelineVersion::v4> // clang-format on >; @@ -150,22 +161,24 @@ template + ConvolutionForwardSpecialization ConvSpec, + typename DsDataTypes = Tuple<>, + typename OutElementOp = PassThrough> using device_grouped_conv_fwd_xdl_f16_comp_instances_part2 = std::tuple< // clang-format off - DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle_V3, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 0, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 0, 1, 1, S<1, 32, 1, 8>, 8, BlockGemmPipelineScheduler::Intrawave, BlockGemmPipelineVersion::v4>, - DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle_V3, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 0, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 0, 1, 1, S<1, 32, 1, 8>, 8, BlockGemmPipelineScheduler::Intrawave, BlockGemmPipelineVersion::v4>, - DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle_V3, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 0, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 0, 1, 1, S<1, 32, 1, 8>, 8, BlockGemmPipelineScheduler::Intrawave, BlockGemmPipelineVersion::v3>, - DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle_V3, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 0, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 0, 1, 1, S<1, 32, 1, 8>, 8, 
BlockGemmPipelineScheduler::Intrawave, BlockGemmPipelineVersion::v5>, - DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle_V3, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 0, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 0, 1, 2, S<1, 32, 1, 8>, 8, BlockGemmPipelineScheduler::Intrawave, BlockGemmPipelineVersion::v3>, + DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle_V3, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 0, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 0, 1, 1, S<1, 32, 1, 8>, 8, BlockGemmPipelineScheduler::Intrawave, BlockGemmPipelineVersion::v4>, + DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle_V3, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 0, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 0, 1, 1, S<1, 32, 1, 8>, 8, BlockGemmPipelineScheduler::Intrawave, BlockGemmPipelineVersion::v4>, + DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle_V3, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 0, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 0, 1, 1, S<1, 32, 1, 8>, 8, BlockGemmPipelineScheduler::Intrawave, BlockGemmPipelineVersion::v3>, + DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle_V3, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 0, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 0, 1, 1, S<1, 32, 1, 8>, 8, BlockGemmPipelineScheduler::Intrawave, BlockGemmPipelineVersion::v5>, + DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle_V3, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 0, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 0, 1, 2, S<1, 32, 1, 8>, 8, BlockGemmPipelineScheduler::Intrawave, BlockGemmPipelineVersion::v3>, // AGPR Spill when use permuted lds layout. so, use padding for these two. 
- DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle_V3, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 0, S<8, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 0, 1, 2, S<1, 32, 1, 8>, 8, BlockGemmPipelineScheduler::Intrawave, BlockGemmPipelineVersion::v3>, - DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle_V3, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 0, S<8, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 0, 2, 1, S<1, 64, 1, 4>, 8, BlockGemmPipelineScheduler::Intrawave, BlockGemmPipelineVersion::v3>, - DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle_V3, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 0, S<8, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 0, 1, 1, S<1, 32, 1, 8>, 8, BlockGemmPipelineScheduler::Intrawave, BlockGemmPipelineVersion::v3>, - DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle_V3, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 0, S<8, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 0, 1, 1, S<1, 32, 1, 8>, 8, BlockGemmPipelineScheduler::Intrawave, BlockGemmPipelineVersion::v5>, - DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle_V3, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 0, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 0, 1, 1, S<1, 32, 1, 8>, 8, BlockGemmPipelineScheduler::Interwave, BlockGemmPipelineVersion::v1>, - DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle_V3, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 0, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 0, 1, 1, S<1, 32, 1, 8>, 8, BlockGemmPipelineScheduler::Interwave, BlockGemmPipelineVersion::v1>, - DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle_V3, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 0, S<8, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 0, 1, 1, S<1, 32, 1, 8>, 8, BlockGemmPipelineScheduler::Interwave, BlockGemmPipelineVersion::v1> + DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle_V3, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 0, S<8, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 0, 1, 2, S<1, 32, 1, 8>, 8, BlockGemmPipelineScheduler::Intrawave, BlockGemmPipelineVersion::v3>, + DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle_V3, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 0, S<8, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 0, 2, 1, S<1, 64, 
1, 4>, 8, BlockGemmPipelineScheduler::Intrawave, BlockGemmPipelineVersion::v3>, + DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle_V3, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 0, S<8, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 0, 1, 1, S<1, 32, 1, 8>, 8, BlockGemmPipelineScheduler::Intrawave, BlockGemmPipelineVersion::v3>, + DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle_V3, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 0, S<8, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 0, 1, 1, S<1, 32, 1, 8>, 8, BlockGemmPipelineScheduler::Intrawave, BlockGemmPipelineVersion::v5>, + DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle_V3, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 0, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 0, 1, 1, S<1, 32, 1, 8>, 8, BlockGemmPipelineScheduler::Interwave, BlockGemmPipelineVersion::v1>, + DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle_V3, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 0, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 0, 1, 1, S<1, 32, 1, 8>, 8, BlockGemmPipelineScheduler::Interwave, BlockGemmPipelineVersion::v1>, + DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle_V3, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 0, S<8, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 0, 1, 1, S<1, 32, 1, 8>, 8, BlockGemmPipelineScheduler::Interwave, BlockGemmPipelineVersion::v1> // clang-format on >; @@ -174,17 +187,19 @@ template + ConvolutionForwardSpecialization ConvSpec, + typename DsDataTypes = Tuple<>, + typename OutElementOp = PassThrough> using device_grouped_conv_fwd_xdl_f32_comp_instances = std::tuple< // clang-format off //########################################| NumDim| A| B| Ds| E| AData| BData| AccData| CShuffle| Ds| EData| A| B| CDE| ConvForward| GEMM| NumGemmK| Block| MPer| NPer| KPer| AK1| BK1| MPer| NPer| MXdl| NXdl| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockLds| BBlockTransfer| BBlockTransfer| BBlockTransfer| BlockTransfer| BBlockTransfer| BBlockTransfer| BBlockLds| CShuffle| CShuffle| CBlockTransferClusterLengths| CBlockTransfer| 
//########################################| Spatial| Layout| Layout| Layout| Layout| Type| Type| Type| DataType| DataType| Type| Elementwise| Elementwise| Elementwise| Specialization| Specialization| Prefetch| Size| Block| Block| Block| | | XDL| XDL| Per| Per| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraM| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraN| MXdlPerWave| NXdlPerWave| _MBlock_MWaveMPerXdl| ScalarPerVector| //########################################| | | | | | | | | | | | Operation| Operation| Operation| | | Stage| | | | | | | | | Wave| Wave| Lengths_K0_M_K1| ArrangeOrder| | | PerVector| PerVector_K1| | Lengths_K0_N_K1| ArrangeOrder| | | PerVector| PerVector_K1| | PerShuffle| PerShuffle| _NBlock_NWaveNPerXdl| _NWaveNPerXdl| //########################################| | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | - DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle_V3, S<1, 0, 2>, S<1, 0, 2>, 2, 4, 4, 0, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 4, 4, 0, 1, 1, S<1, 32, 1, 8>, 8, BlockGemmPipelineScheduler::Intrawave, BlockGemmPipelineVersion::v4>, - DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle_V3, S<1, 0, 2>, S<1, 0, 2>, 2, 4, 4, 0, S<8, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 4, 4, 0, 1, 1, S<1, 32, 1, 8>, 8, BlockGemmPipelineScheduler::Intrawave, BlockGemmPipelineVersion::v3>, - DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle_V3, S<1, 0, 2>, S<1, 0, 2>, 2, 4, 4, 0, S<8, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 4, 4, 0, 1, 1, S<1, 32, 1, 8>, 8, BlockGemmPipelineScheduler::Intrawave, BlockGemmPipelineVersion::v5>, - DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle_V3, S<1, 0, 2>, S<1, 0, 2>, 2, 4, 4, 0, S<8, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 4, 4, 0, 1, 1, S<1, 32, 1, 8>, 8, BlockGemmPipelineScheduler::Interwave, BlockGemmPipelineVersion::v1> + DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle_V3, S<1, 0, 2>, S<1, 0, 2>, 2, 4, 4, 0, S<4, 64, 1>, S<1, 0, 
2>, S<1, 0, 2>, 2, 4, 4, 0, 1, 1, S<1, 32, 1, 8>, 8, BlockGemmPipelineScheduler::Intrawave, BlockGemmPipelineVersion::v4>, + DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle_V3, S<1, 0, 2>, S<1, 0, 2>, 2, 4, 4, 0, S<8, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 4, 4, 0, 1, 1, S<1, 32, 1, 8>, 8, BlockGemmPipelineScheduler::Intrawave, BlockGemmPipelineVersion::v3>, + DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle_V3, S<1, 0, 2>, S<1, 0, 2>, 2, 4, 4, 0, S<8, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 4, 4, 0, 1, 1, S<1, 32, 1, 8>, 8, BlockGemmPipelineScheduler::Intrawave, BlockGemmPipelineVersion::v5>, + DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle_V3, S<1, 0, 2>, S<1, 0, 2>, 2, 4, 4, 0, S<8, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 4, 4, 0, 1, 1, S<1, 32, 1, 8>, 8, BlockGemmPipelineScheduler::Interwave, BlockGemmPipelineVersion::v1> // clang-format on >; @@ -194,14 +209,16 @@ template + ConvolutionForwardSpecialization ConvSpec, + typename DsDataTypes = Tuple<>, + typename OutElementOp = PassThrough> using device_grouped_conv_fwd_xdl_int8_comp_instances_2x = std::tuple< // clang-format off //########################################| NumDim| A| B| Ds| E| AData| BData| AccData| CShuffle| Ds| EData| A| B| CDE| ConvForward| GEMM| NumGemmK| Block| MPer| NPer| KPer| AK1| BK1| MPer| NPer| MXdl| NXdl| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockLds| BBlockTransfer| BBlockTransfer| BBlockTransfer| BlockTransfer| BBlockTransfer| BBlockTransfer| BBlockLds| CShuffle| CShuffle| CBlockTransferClusterLengths| CBlockTransfer| //########################################| Spatial| Layout| Layout| Layout| Layout| Type| Type| Type| DataType| DataType| Type| Elementwise| Elementwise| Elementwise| Specialization| Specialization| Prefetch| Size| Block| Block| Block| | | XDL| XDL| Per| Per| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraM| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| 
AddExtraN| MXdlPerWave| NXdlPerWave| _MBlock_MWaveMPerXdl| ScalarPerVector| //########################################| | | | | | | | | | | | Operation| Operation| Operation| | | Stage| | | | | | | | | Wave| Wave| Lengths_K0_M_K1| ArrangeOrder| | | PerVector| PerVector_K1| | Lengths_K0_N_K1| ArrangeOrder| | | PerVector| PerVector_K1| | PerShuffle| PerShuffle| _NBlock_NWaveNPerXdl| _NWaveNPerXdl| //########################################| | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | - DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle_V3, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, 1, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, 1, 1, 1, S<1, 64, 1, 4>, 16, BlockGemmPipelineScheduler::Interwave, BlockGemmPipelineVersion::v1> + DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle_V3, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, 1, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, 1, 1, 1, S<1, 64, 1, 4>, 16, BlockGemmPipelineScheduler::Interwave, BlockGemmPipelineVersion::v1> // clang-format on >; @@ -210,14 +227,16 @@ template + ConvolutionForwardSpecialization ConvSpec, + typename DsDataTypes = Tuple<>, + typename OutElementOp = PassThrough> using device_grouped_conv_fwd_xdl_int8_comp_instances = std::tuple< // clang-format off //########################################| NumDim| A| B| Ds| E| AData| BData| AccData| CShuffle| Ds| EData| A| B| CDE| ConvForward| GEMM| NumGemmK| Block| MPer| NPer| KPer| AK1| BK1| MPer| NPer| MXdl| NXdl| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockLds| BBlockTransfer| BBlockTransfer| BBlockTransfer| BlockTransfer| BBlockTransfer| BBlockTransfer| BBlockLds| CShuffle| CShuffle| CBlockTransferClusterLengths| CBlockTransfer| //########################################| Spatial| Layout| Layout| Layout| Layout| Type| Type| Type| DataType| DataType| Type| Elementwise| Elementwise| Elementwise| Specialization| Specialization| Prefetch| Size| Block| Block| Block| | | XDL| XDL| 
Per| Per| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraM| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraN| MXdlPerWave| NXdlPerWave| _MBlock_MWaveMPerXdl| ScalarPerVector| //########################################| | | | | | | | | | | | Operation| Operation| Operation| | | Stage| | | | | | | | | Wave| Wave| Lengths_K0_M_K1| ArrangeOrder| | | PerVector| PerVector_K1| | Lengths_K0_N_K1| ArrangeOrder| | | PerVector| PerVector_K1| | PerShuffle| PerShuffle| _NBlock_NWaveNPerXdl| _NWaveNPerXdl| //########################################| | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | - DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle_V3, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 0, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 0, 1, 1, S<1, 32, 1, 8>, 8, BlockGemmPipelineScheduler::Intrawave, BlockGemmPipelineVersion::v4> + DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle_V3, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 0, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 0, 1, 1, S<1, 32, 1, 8>, 8, BlockGemmPipelineScheduler::Intrawave, BlockGemmPipelineVersion::v4> // clang-format on >; @@ -227,18 +246,20 @@ template + ConvolutionForwardSpecialization ConvSpec, + typename DsDataTypes = Tuple<>, + typename OutElementOp = PassThrough> using device_grouped_conv_fwd_xdl_int8_comp_instances_part2 = std::tuple< // clang-format off - DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle_V3, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 0, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 0, 1, 1, S<1, 32, 1, 8>, 8, BlockGemmPipelineScheduler::Intrawave, BlockGemmPipelineVersion::v4>, - DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle_V3, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 0, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 0, 1, 1, S<1, 32, 1, 8>, 8, BlockGemmPipelineScheduler::Intrawave, BlockGemmPipelineVersion::v3>, - DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle_V3, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 0, S<4, 
64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 0, 1, 1, S<1, 32, 1, 8>, 8, BlockGemmPipelineScheduler::Intrawave, BlockGemmPipelineVersion::v5>, - DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle_V3, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 0, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 0, 1, 2, S<1, 32, 1, 8>, 8, BlockGemmPipelineScheduler::Intrawave, BlockGemmPipelineVersion::v3>, + DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle_V3, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 0, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 0, 1, 1, S<1, 32, 1, 8>, 8, BlockGemmPipelineScheduler::Intrawave, BlockGemmPipelineVersion::v4>, + DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle_V3, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 0, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 0, 1, 1, S<1, 32, 1, 8>, 8, BlockGemmPipelineScheduler::Intrawave, BlockGemmPipelineVersion::v3>, + DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle_V3, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 0, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 0, 1, 1, S<1, 32, 1, 8>, 8, BlockGemmPipelineScheduler::Intrawave, BlockGemmPipelineVersion::v5>, + DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle_V3, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 0, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 0, 1, 2, S<1, 32, 1, 8>, 8, BlockGemmPipelineScheduler::Intrawave, BlockGemmPipelineVersion::v3>, // AGPR Spill when use permuted lds layout. so, use padding for these two. 
- DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle_V3, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 0, S<8, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 0, 1, 1, S<1, 32, 1, 8>, 8, BlockGemmPipelineScheduler::Intrawave, BlockGemmPipelineVersion::v3>, - DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle_V3, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 0, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 0, 1, 1, S<1, 32, 1, 8>, 8, BlockGemmPipelineScheduler::Interwave, BlockGemmPipelineVersion::v1>, - DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle_V3, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 0, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 0, 1, 1, S<1, 32, 1, 8>, 8, BlockGemmPipelineScheduler::Interwave, BlockGemmPipelineVersion::v1>, - DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle_V3, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 0, S<8, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 0, 1, 1, S<1, 32, 1, 8>, 8, BlockGemmPipelineScheduler::Interwave, BlockGemmPipelineVersion::v1> + DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle_V3, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 0, S<8, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 0, 1, 1, S<1, 32, 1, 8>, 8, BlockGemmPipelineScheduler::Intrawave, BlockGemmPipelineVersion::v3>, + DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle_V3, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 0, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 0, 1, 1, S<1, 32, 1, 8>, 8, BlockGemmPipelineScheduler::Interwave, BlockGemmPipelineVersion::v1>, + DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle_V3, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 0, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 0, 1, 1, S<1, 32, 1, 8>, 8, BlockGemmPipelineScheduler::Interwave, BlockGemmPipelineVersion::v1>, + DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle_V3, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 0, S<8, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 0, 1, 1, S<1, 32, 1, 8>, 8, BlockGemmPipelineScheduler::Interwave, BlockGemmPipelineVersion::v1> // clang-format on >; } // namespace instance diff --git 
a/library/include/ck/library/tensor_operation_instance/gpu/grouped_conv_fwd/device_grouped_conv_fwd_xdl_instance.hpp b/library/include/ck/library/tensor_operation_instance/gpu/grouped_conv_fwd/device_grouped_conv_fwd_xdl_instance.hpp index f5397308dc..df24b4cbcb 100644 --- a/library/include/ck/library/tensor_operation_instance/gpu/grouped_conv_fwd/device_grouped_conv_fwd_xdl_instance.hpp +++ b/library/include/ck/library/tensor_operation_instance/gpu/grouped_conv_fwd/device_grouped_conv_fwd_xdl_instance.hpp @@ -33,6 +33,7 @@ using Empty_Tuple = ck::Tuple<>; using namespace ck::tensor_layout::convolution; using PassThrough = ck::tensor_operation::element_wise::PassThrough; +using AddRelu = ck::tensor_operation::element_wise::AddRelu; static constexpr auto ConvFwdDefault = ck::tensor_operation::device::ConvolutionForwardSpecialization::Default; @@ -51,7 +52,9 @@ template + ConvolutionForwardSpecialization ConvSpec, + typename DsDataTypes = Tuple<>, + typename OutElementOp = PassThrough> using device_grouped_conv_fwd_xdl_bf16_generic_instances = std::tuple< // clang-format off //########################################| NumDim| A| B| Ds| E| AData| BData| AccData| CShuffle| Ds| EData| A| B| CDE| ConvForward| GEMM| NumGemmK| Block| MPer| NPer| KPer| AK1| BK1| MPer| NPer| MXdl| NXdl| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockLds| BBlockTransfer| BBlockTransfer| BBlockTransfer| BlockTransfer| BBlockTransfer| BBlockTransfer| BBlockLds| CShuffle| CShuffle| CBlockTransferClusterLengths| CBlockTransfer| @@ -59,7 +62,7 @@ using device_grouped_conv_fwd_xdl_bf16_generic_instances = std::tuple< //########################################| | | | | | | | | | | | Operation| Operation| Operation| | | Stage| | | | | | | | | Wave| Wave| Lengths_K0_M_K1| ArrangeOrder| | | PerVector| PerVector_K1| | Lengths_K0_N_K1| ArrangeOrder| | | PerVector| PerVector_K1| | PerShuffle| PerShuffle| _NBlock_NWaveNPerXdl| _NWaveNPerXdl| 
//########################################| | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | // generic instance - DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle, S<1, 0, 2>, S<1, 0, 2>, 2, 1, 8, 1, S<4, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 1, 8, 1, 1, 1, S<1, 16, 1, 4>, 1> + DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle, S<1, 0, 2>, S<1, 0, 2>, 2, 1, 8, 1, S<4, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 1, 8, 1, 1, 1, S<1, 16, 1, 4>, 1> // clang-format on >; @@ -68,7 +71,9 @@ template + ConvolutionForwardSpecialization ConvSpec, + typename DsDataTypes = Tuple<>, + typename OutElementOp = PassThrough> using device_grouped_conv_fwd_xdl_bf16_instances = std::tuple< // clang-format off //########################################| NumDim| A| B| Ds| E| AData| BData| AccData| CShuffle| Ds| EData| A| B| CDE| ConvForward| GEMM| NumGemmK| Block| MPer| NPer| KPer| AK1| BK1| MPer| NPer| MXdl| NXdl| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockLds| BBlockTransfer| BBlockTransfer| BBlockTransfer| BlockTransfer| BBlockTransfer| BBlockTransfer| BBlockLds| CShuffle| CShuffle| CBlockTransferClusterLengths| CBlockTransfer| @@ -76,24 +81,24 @@ using device_grouped_conv_fwd_xdl_bf16_instances = std::tuple< //########################################| | | | | | | | | | | | Operation| Operation| Operation| | | Stage| | | | | | | | | Wave| Wave| Lengths_K0_M_K1| ArrangeOrder| | | PerVector| PerVector_K1| | Lengths_K0_N_K1| ArrangeOrder| | | PerVector| PerVector_K1| | PerShuffle| PerShuffle| _NBlock_NWaveNPerXdl| _NWaveNPerXdl| //########################################| | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | // generic instance - DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle, S<1, 0, 2>, S<1, 0, 2>, 2, 1, 8, 1, S<4, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 1, 8, 1, 1, 1, S<1, 16, 1, 4>, 1>, + DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle, S<1, 0, 2>, S<1, 0, 2>, 
2, 1, 8, 1, S<4, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 1, 8, 1, 1, 1, S<1, 16, 1, 4>, 1>, // instances for small conv.K and conv.C - DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, S<4, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, 1, 1, S<1, 16, 1, 4>, 1>, - DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle, S<1, 0, 2>, S<1, 0, 2>, 2, 1, 8, 1, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 1, 8, 1, 1, 1, S<1, 32, 1, 8>, 8>, + DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, S<4, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, 1, 1, S<1, 16, 1, 4>, 1>, + DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle, S<1, 0, 2>, S<1, 0, 2>, 2, 1, 8, 1, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 1, 8, 1, 1, 1, S<1, 32, 1, 8>, 8>, - DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, 1, 1, S<1, 32, 1, 8>, 8>, - DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, 1, 1, S<1, 32, 1, 8>, 8>, - DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, 1, 1, S<1, 16, 1, 8>, 8>, - DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, 1, 1, S<1, 32, 1, 8>, 8>, - DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, 1, 1, S<1, 32, 1, 4>, 8>, - DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, 1, 1, S<1, 16, 1, 8>, 8>, - DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, S<4, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, 1, 1, S<1, 16, 1, 4>, 8>, - DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 
2, 8, 8, 1, 1, 1, S<1, 32, 1, 8>, 8>, - DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, 1, 1, S<1, 32, 1, 8>, 8>, - DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, 1, 1, S<1, 32, 1, 4>, 8>, - DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, 1, 1, S<1, 16, 1, 8>, 8>, - DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, S<4, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, 1, 1, S<1, 16, 1, 4>, 8>, - DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, S<4, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, 1, 1, S<1, 16, 1, 4>, 8> + DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, 1, 1, S<1, 32, 1, 8>, 8>, + DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, 1, 1, S<1, 32, 1, 8>, 8>, + DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, 1, 1, S<1, 16, 1, 8>, 8>, + DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, 1, 1, S<1, 32, 1, 8>, 8>, + DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, 1, 1, S<1, 32, 1, 4>, 8>, + DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, 1, 1, S<1, 16, 1, 8>, 8>, + DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, S<4, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, 1, 1, S<1, 16, 1, 4>, 8>, + DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle, S<1, 
0, 2>, S<1, 0, 2>, 2, 8, 8, 1, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, 1, 1, S<1, 32, 1, 8>, 8>, + DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, 1, 1, S<1, 32, 1, 8>, 8>, + DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, 1, 1, S<1, 32, 1, 4>, 8>, + DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, 1, 1, S<1, 16, 1, 8>, 8>, + DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, S<4, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, 1, 1, S<1, 16, 1, 4>, 8>, + DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, S<4, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, 1, 1, S<1, 16, 1, 4>, 8> // clang-format on >; @@ -102,17 +107,19 @@ template + ConvolutionForwardSpecialization ConvSpec, + typename DsDataTypes = Tuple<>, + typename OutElementOp = PassThrough> using device_grouped_conv_fwd_xdl_bf16_16x16_instances = std::tuple< // clang-format off //########################################| NumDim| A| B| Ds| E| AData| BData| AccData| CShuffle| Ds| EData| A| B| CDE| ConvForward| GEMM| NumGemmK| Block| MPer| NPer| KPer| AK1| BK1| MPer| NPer| MXdl| NXdl| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockLds| BBlockTransfer| BBlockTransfer| BBlockTransfer| BlockTransfer| BBlockTransfer| BBlockTransfer| BBlockLds| CShuffle| CShuffle| CBlockTransferClusterLengths| CBlockTransfer| //########################################| Spatial| Layout| Layout| Layout| Layout| Type| Type| Type| DataType| DataType| Type| Elementwise| Elementwise| Elementwise| Specialization| Specialization| Prefetch| Size| Block| Block| Block| | | XDL| XDL| Per| Per| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| 
AddExtraM| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraN| MXdlPerWave| NXdlPerWave| _MBlock_MWaveMPerXdl| ScalarPerVector| //########################################| | | | | | | | | | | | Operation| Operation| Operation| | | Stage| | | | | | | | | Wave| Wave| Lengths_K0_M_K1| ArrangeOrder| | | PerVector| PerVector_K1| | Lengths_K0_N_K1| ArrangeOrder| | | PerVector| PerVector_K1| | PerShuffle| PerShuffle| _NBlock_NWaveNPerXdl| _NWaveNPerXdl| //########################################| | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | - DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle, S<1, 0, 2>, S<1, 0, 2>, 2, 1, 8, 1, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 2, 8, 1, 1, 1, S<1, 32, 1, 4>, 1>, - DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle, S<1, 0, 2>, S<1, 0, 2>, 2, 2, 8, 1, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 1, 8, 1, 1, 1, S<1, 32, 1, 4>, 2>, - DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle, S<1, 0, 2>, S<1, 0, 2>, 2, 4, 8, 1, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 4, 8, 1, 1, 1, S<1, 32, 1, 4>, 4>, - DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, 1, 1, S<1, 32, 1, 4>, 8> + DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle, S<1, 0, 2>, S<1, 0, 2>, 2, 1, 8, 1, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 2, 8, 1, 1, 1, S<1, 32, 1, 4>, 1>, + DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle, S<1, 0, 2>, S<1, 0, 2>, 2, 2, 8, 1, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 1, 8, 1, 1, 1, S<1, 32, 1, 4>, 2>, + DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle, S<1, 0, 2>, S<1, 0, 2>, 2, 4, 8, 1, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 4, 8, 1, 1, 1, S<1, 32, 1, 4>, 4>, + DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, 1, 1, S<1, 32, 1, 4>, 8> // clang-format on >; @@ -121,7 +128,9 @@ template + ConvolutionForwardSpecialization ConvSpec, + 
typename DsDataTypes = Tuple<>, + typename OutElementOp = PassThrough> using device_grouped_conv_fwd_xdl_f16_generic_instances = std::tuple< // clang-format off //########################################| NumDim| A| B| Ds| E| AData| BData| AccData| CShuffle| Ds| EData| A| B| CDE| ConvForward| GEMM| NumGemmK| Block| MPer| NPer| KPer| AK1| BK1| MPer| NPer| MXdl| NXdl| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockLds| BBlockTransfer| BBlockTransfer| BBlockTransfer| BlockTransfer| BBlockTransfer| BBlockTransfer| BBlockLds| CShuffle| CShuffle| CBlockTransferClusterLengths| CBlockTransfer| @@ -129,7 +138,7 @@ using device_grouped_conv_fwd_xdl_f16_generic_instances = std::tuple< //########################################| | | | | | | | | | | | Operation| Operation| Operation| | | Stage| | | | | | | | | Wave| Wave| Lengths_K0_M_K1| ArrangeOrder| | | PerVector| PerVector_K1| | Lengths_K0_N_K1| ArrangeOrder| | | PerVector| PerVector_K1| | PerShuffle| PerShuffle| _NBlock_NWaveNPerXdl| _NWaveNPerXdl| //########################################| | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | // generic instance - DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle, S<1, 0, 2>, S<1, 0, 2>, 2, 1, 8, 1, S<4, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 1, 8, 1, 1, 1, S<1, 16, 1, 4>, 1> + DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle, S<1, 0, 2>, S<1, 0, 2>, 2, 1, 8, 1, S<4, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 1, 8, 1, 1, 1, S<1, 16, 1, 4>, 1> // clang-format on >; @@ -138,7 +147,9 @@ template + ConvolutionForwardSpecialization ConvSpec, + typename DsDataTypes = Tuple<>, + typename OutElementOp = PassThrough> using device_grouped_conv_fwd_xdl_f16_instances = std::tuple< // clang-format off //########################################| NumDim| A| B| Ds| E| AData| BData| AccData| CShuffle| Ds| EData| A| B| CDE| ConvForward| GEMM| NumGemmK| Block| MPer| NPer| KPer| AK1| BK1| MPer| NPer| MXdl| NXdl| 
ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockLds| BBlockTransfer| BBlockTransfer| BBlockTransfer| BlockTransfer| BBlockTransfer| BBlockTransfer| BBlockLds| CShuffle| CShuffle| CBlockTransferClusterLengths| CBlockTransfer| @@ -146,24 +157,24 @@ using device_grouped_conv_fwd_xdl_f16_instances = std::tuple< //########################################| | | | | | | | | | | | Operation| Operation| Operation| | | Stage| | | | | | | | | Wave| Wave| Lengths_K0_M_K1| ArrangeOrder| | | PerVector| PerVector_K1| | Lengths_K0_N_K1| ArrangeOrder| | | PerVector| PerVector_K1| | PerShuffle| PerShuffle| _NBlock_NWaveNPerXdl| _NWaveNPerXdl| //########################################| | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | // generic instance - DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle, S<1, 0, 2>, S<1, 0, 2>, 2, 1, 8, 1, S<4, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 1, 8, 1, 1, 1, S<1, 16, 1, 4>, 1>, + DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle, S<1, 0, 2>, S<1, 0, 2>, 2, 1, 8, 1, S<4, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 1, 8, 1, 1, 1, S<1, 16, 1, 4>, 1>, // instances for small conv.K and conv.C - DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, S<4, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, 1, 1, S<1, 16, 1, 4>, 1>, - DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle, S<1, 0, 2>, S<1, 0, 2>, 2, 1, 8, 1, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 1, 8, 1, 1, 1, S<1, 32, 1, 8>, 8>, + DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, S<4, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, 1, 1, S<1, 16, 1, 4>, 1>, + DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle, S<1, 0, 2>, S<1, 0, 2>, 2, 1, 8, 1, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 1, 8, 1, 1, 1, S<1, 32, 1, 8>, 8>, - DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, 1, 1, S<1, 32, 1, 8>, 8>, - 
DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, 1, 1, S<1, 32, 1, 8>, 8>, - DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, 1, 1, S<1, 16, 1, 8>, 8>, - DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, 1, 1, S<1, 32, 1, 8>, 8>, - DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, 1, 1, S<1, 32, 1, 4>, 8>, - DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, 1, 1, S<1, 16, 1, 8>, 8>, - DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, S<4, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, 1, 1, S<1, 16, 1, 4>, 8>, - DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, 1, 1, S<1, 32, 1, 8>, 8>, - DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, 1, 1, S<1, 32, 1, 8>, 8>, - DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, 1, 1, S<1, 32, 1, 4>, 8>, - DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, 1, 1, S<1, 16, 1, 8>, 8>, - DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, S<4, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, 1, 1, S<1, 16, 1, 4>, 8>, - DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, S<4, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, 1, 1, S<1, 16, 1, 4>, 8> + DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, S<4, 64, 
1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, 1, 1, S<1, 32, 1, 8>, 8>, + DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, 1, 1, S<1, 32, 1, 8>, 8>, + DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, 1, 1, S<1, 16, 1, 8>, 8>, + DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, 1, 1, S<1, 32, 1, 8>, 8>, + DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, 1, 1, S<1, 32, 1, 4>, 8>, + DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, 1, 1, S<1, 16, 1, 8>, 8>, + DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, S<4, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, 1, 1, S<1, 16, 1, 4>, 8>, + DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, 1, 1, S<1, 32, 1, 8>, 8>, + DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, 1, 1, S<1, 32, 1, 8>, 8>, + DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, 1, 1, S<1, 32, 1, 4>, 8>, + DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, 1, 1, S<1, 16, 1, 8>, 8>, + DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, S<4, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, 1, 1, S<1, 16, 1, 4>, 8>, + DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, S<4, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, 1, 1, S<1, 16, 1, 4>, 8> // clang-format on >; @@ 
-172,17 +183,19 @@ template + ConvolutionForwardSpecialization ConvSpec, + typename DsDataTypes = Tuple<>, + typename OutElementOp = PassThrough> using device_grouped_conv_fwd_xdl_f16_16x16_instances = std::tuple< // clang-format off //########################################| NumDim| A| B| Ds| E| AData| BData| AccData| CShuffle| Ds| EData| A| B| CDE| ConvForward| GEMM| NumGemmK| Block| MPer| NPer| KPer| AK1| BK1| MPer| NPer| MXdl| NXdl| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockLds| BBlockTransfer| BBlockTransfer| BBlockTransfer| BlockTransfer| BBlockTransfer| BBlockTransfer| BBlockLds| CShuffle| CShuffle| CBlockTransferClusterLengths| CBlockTransfer| //########################################| Spatial| Layout| Layout| Layout| Layout| Type| Type| Type| DataType| DataType| Type| Elementwise| Elementwise| Elementwise| Specialization| Specialization| Prefetch| Size| Block| Block| Block| | | XDL| XDL| Per| Per| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraM| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraN| MXdlPerWave| NXdlPerWave| _MBlock_MWaveMPerXdl| ScalarPerVector| //########################################| | | | | | | | | | | | Operation| Operation| Operation| | | Stage| | | | | | | | | Wave| Wave| Lengths_K0_M_K1| ArrangeOrder| | | PerVector| PerVector_K1| | Lengths_K0_N_K1| ArrangeOrder| | | PerVector| PerVector_K1| | PerShuffle| PerShuffle| _NBlock_NWaveNPerXdl| _NWaveNPerXdl| //########################################| | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | - DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle, S<1, 0, 2>, S<1, 0, 2>, 2, 1, 8, 1, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 2, 8, 1, 1, 1, S<1, 32, 1, 4>, 1>, - DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle, S<1, 0, 2>, S<1, 0, 2>, 2, 2, 8, 1, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 1, 8, 1, 1, 1, S<1, 32, 1, 
4>, 2>, - DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle, S<1, 0, 2>, S<1, 0, 2>, 2, 4, 8, 1, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 4, 8, 1, 1, 1, S<1, 32, 1, 4>, 4>, - DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, 1, 1, S<1, 32, 1, 4>, 8> + DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle, S<1, 0, 2>, S<1, 0, 2>, 2, 1, 8, 1, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 2, 8, 1, 1, 1, S<1, 32, 1, 4>, 1>, + DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle, S<1, 0, 2>, S<1, 0, 2>, 2, 2, 8, 1, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 1, 8, 1, 1, 1, S<1, 32, 1, 4>, 2>, + DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle, S<1, 0, 2>, S<1, 0, 2>, 2, 4, 8, 1, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 4, 8, 1, 1, 1, S<1, 32, 1, 4>, 4>, + DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, 1, 1, S<1, 32, 1, 4>, 8> // clang-format on >; @@ -191,7 +204,9 @@ template + ConvolutionForwardSpecialization ConvSpec, + typename DsDataTypes = Tuple<>, + typename OutElementOp = PassThrough> using device_grouped_conv_fwd_xdl_f32_generic_instances = std::tuple< // clang-format off //########################################| NumDim| A| B| Ds| E| AData| BData| AccData| CShuffle| Ds| EData| A| B| CDE| ConvForward| GEMM| NumGemmK| Block| MPer| NPer| KPer| AK1| BK1| MPer| NPer| MXdl| NXdl| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockLds| BBlockTransfer| BBlockTransfer| BBlockTransfer| BlockTransfer| BBlockTransfer| BBlockTransfer| BBlockLds| CShuffle| CShuffle| CBlockTransferClusterLengths| CBlockTransfer| @@ -199,7 +214,7 @@ using device_grouped_conv_fwd_xdl_f32_generic_instances = std::tuple< //########################################| | | | | | | | | | | | Operation| Operation| Operation| | | Stage| | | | | | | | | Wave| Wave| Lengths_K0_M_K1| ArrangeOrder| | | PerVector| PerVector_K1| | 
Lengths_K0_N_K1| ArrangeOrder| | | PerVector| PerVector_K1| | PerShuffle| PerShuffle| _NBlock_NWaveNPerXdl| _NWaveNPerXdl| //########################################| | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | // generic instance - DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle, S<1, 0, 2>, S<1, 0, 2>, 2, 1, 4, 1, S<4, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 1, 4, 1, 1, 1, S<1, 8, 1, 8>, 1> + DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle, S<1, 0, 2>, S<1, 0, 2>, 2, 1, 4, 1, S<4, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 1, 4, 1, 1, 1, S<1, 8, 1, 8>, 1> // clang-format on >; @@ -208,7 +223,9 @@ template + ConvolutionForwardSpecialization ConvSpec, + typename DsDataTypes = Tuple<>, + typename OutElementOp = PassThrough> using device_grouped_conv_fwd_xdl_f32_instances = std::tuple< // clang-format off //########################################| NumDim| A| B| Ds| E| AData| BData| AccData| CShuffle| Ds| EData| A| B| CDE| ConvForward| GEMM| NumGemmK| Block| MPer| NPer| KPer| AK1| BK1| MPer| NPer| MXdl| NXdl| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockLds| BBlockTransfer| BBlockTransfer| BBlockTransfer| BlockTransfer| BBlockTransfer| BBlockTransfer| BBlockLds| CShuffle| CShuffle| CBlockTransferClusterLengths| CBlockTransfer| @@ -216,24 +233,24 @@ using device_grouped_conv_fwd_xdl_f32_instances = std::tuple< //########################################| | | | | | | | | | | | Operation| Operation| Operation| | | Stage| | | | | | | | | Wave| Wave| Lengths_K0_M_K1| ArrangeOrder| | | PerVector| PerVector_K1| | Lengths_K0_N_K1| ArrangeOrder| | | PerVector| PerVector_K1| | PerShuffle| PerShuffle| _NBlock_NWaveNPerXdl| _NWaveNPerXdl| //########################################| | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | // generic instance - DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle, S<1, 0, 2>, S<1, 0, 2>, 2, 1, 4, 1, S<4, 16, 1>, S<1, 0, 2>, 
S<1, 0, 2>, 2, 1, 4, 1, 1, 1, S<1, 8, 1, 8>, 1>, + DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle, S<1, 0, 2>, S<1, 0, 2>, 2, 1, 4, 1, S<4, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 1, 4, 1, 1, 1, S<1, 8, 1, 8>, 1>, // instances for small conv.K and conv.C - DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle, S<1, 0, 2>, S<1, 0, 2>, 2, 4, 4, 1, S<4, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 4, 4, 1, 1, 1, S<1, 8, 1, 8>, 1>, - DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle, S<1, 0, 2>, S<1, 0, 2>, 2, 1, 4, 1, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 1, 4, 1, 1, 1, S<1, 16, 1, 16>, 4>, + DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle, S<1, 0, 2>, S<1, 0, 2>, 2, 4, 4, 1, S<4, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 4, 4, 1, 1, 1, S<1, 8, 1, 8>, 1>, + DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle, S<1, 0, 2>, S<1, 0, 2>, 2, 1, 4, 1, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 1, 4, 1, 1, 1, S<1, 16, 1, 16>, 4>, - DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle, S<1, 0, 2>, S<1, 0, 2>, 2, 4, 4, 1, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 4, 4, 1, 1, 1, S<1, 16, 1, 16>, 4>, - DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle, S<1, 0, 2>, S<1, 0, 2>, 2, 4, 4, 1, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 4, 4, 1, 1, 1, S<1, 16, 1, 16>, 4>, - DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle, S<1, 0, 2>, S<1, 0, 2>, 2, 4, 4, 1, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 4, 4, 1, 1, 1, S<1, 8, 1, 16>, 4>, - DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle, S<1, 0, 2>, S<1, 0, 2>, 2, 4, 4, 1, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 4, 4, 1, 1, 1, S<1, 16, 1, 16>, 4>, - DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle, S<1, 0, 2>, S<1, 0, 2>, 2, 4, 4, 1, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 4, 4, 1, 1, 1, S<1, 16, 1, 8>, 4>, - DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle, S<1, 0, 2>, S<1, 0, 2>, 2, 4, 4, 1, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 4, 4, 1, 1, 1, S<1, 8, 1, 16>, 4>, - DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle, S<1, 0, 2>, S<1, 0, 2>, 2, 4, 4, 1, S<4, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 4, 4, 1, 1, 1, S<1, 8, 1, 8>, 4>, 
- DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle, S<1, 0, 2>, S<1, 0, 2>, 2, 4, 4, 1, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 4, 4, 1, 1, 1, S<1, 16, 1, 16>, 4>, - DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle, S<1, 0, 2>, S<1, 0, 2>, 2, 4, 4, 1, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 4, 4, 1, 1, 1, S<1, 16, 1, 16>, 4>, - DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle, S<1, 0, 2>, S<1, 0, 2>, 2, 4, 4, 1, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 4, 4, 1, 1, 1, S<1, 16, 1, 8>, 4>, - DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle, S<1, 0, 2>, S<1, 0, 2>, 2, 4, 4, 1, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 4, 4, 1, 1, 1, S<1, 8, 1, 16>, 4>, - DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle, S<1, 0, 2>, S<1, 0, 2>, 2, 4, 4, 1, S<4, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 4, 4, 1, 1, 1, S<1, 8, 1, 8>, 4>, - DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle, S<1, 0, 2>, S<1, 0, 2>, 2, 4, 4, 1, S<4, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 4, 4, 1, 1, 1, S<1, 8, 1, 8>, 4> + DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle, S<1, 0, 2>, S<1, 0, 2>, 2, 4, 4, 1, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 4, 4, 1, 1, 1, S<1, 16, 1, 16>, 4>, + DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle, S<1, 0, 2>, S<1, 0, 2>, 2, 4, 4, 1, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 4, 4, 1, 1, 1, S<1, 16, 1, 16>, 4>, + DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle, S<1, 0, 2>, S<1, 0, 2>, 2, 4, 4, 1, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 4, 4, 1, 1, 1, S<1, 8, 1, 16>, 4>, + DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle, S<1, 0, 2>, S<1, 0, 2>, 2, 4, 4, 1, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 4, 4, 1, 1, 1, S<1, 16, 1, 16>, 4>, + DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle, S<1, 0, 2>, S<1, 0, 2>, 2, 4, 4, 1, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 4, 4, 1, 1, 1, S<1, 16, 1, 8>, 4>, + DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle, S<1, 0, 2>, S<1, 0, 2>, 2, 4, 4, 1, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 4, 4, 1, 1, 1, S<1, 8, 1, 16>, 4>, + DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle, S<1, 0, 2>, S<1, 0, 2>, 2, 4, 4, 1, S<4, 
16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 4, 4, 1, 1, 1, S<1, 8, 1, 8>, 4>, + DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle, S<1, 0, 2>, S<1, 0, 2>, 2, 4, 4, 1, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 4, 4, 1, 1, 1, S<1, 16, 1, 16>, 4>, + DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle, S<1, 0, 2>, S<1, 0, 2>, 2, 4, 4, 1, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 4, 4, 1, 1, 1, S<1, 16, 1, 16>, 4>, + DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle, S<1, 0, 2>, S<1, 0, 2>, 2, 4, 4, 1, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 4, 4, 1, 1, 1, S<1, 16, 1, 8>, 4>, + DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle, S<1, 0, 2>, S<1, 0, 2>, 2, 4, 4, 1, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 4, 4, 1, 1, 1, S<1, 8, 1, 16>, 4>, + DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle, S<1, 0, 2>, S<1, 0, 2>, 2, 4, 4, 1, S<4, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 4, 4, 1, 1, 1, S<1, 8, 1, 8>, 4>, + DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle, S<1, 0, 2>, S<1, 0, 2>, 2, 4, 4, 1, S<4, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 4, 4, 1, 1, 1, S<1, 8, 1, 8>, 4> // clang-format on >; @@ -242,7 +259,9 @@ template + ConvolutionForwardSpecialization ConvSpec, + typename DsDataTypes = Tuple<>, + typename OutElementOp = PassThrough> using device_grouped_conv_fwd_xdl_f32_16x16_instances = std::tuple< // clang-format off //########################################| NumDim| A| B| Ds| E| AData| BData| AccData| CShuffle| Ds| EData| A| B| CDE| ConvForward| GEMM| NumGemmK| Block| MPer| NPer| KPer| AK1| BK1| MPer| NPer| MXdl| NXdl| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockLds| BBlockTransfer| BBlockTransfer| BBlockTransfer| BlockTransfer| BBlockTransfer| BBlockTransfer| BBlockLds| CShuffle| CShuffle| CBlockTransferClusterLengths| CBlockTransfer| @@ -250,9 +269,9 @@ using device_grouped_conv_fwd_xdl_f32_16x16_instances = std::tuple< //########################################| | | | | | | | | | | | Operation| Operation| Operation| | | Stage| | | | | | | | | Wave| Wave| 
Lengths_K0_M_K1| ArrangeOrder| | | PerVector| PerVector_K1| | Lengths_K0_N_K1| ArrangeOrder| | | PerVector| PerVector_K1| | PerShuffle| PerShuffle| _NBlock_NWaveNPerXdl| _NWaveNPerXdl| //########################################| | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | // generic instance - DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle, S<1, 0, 2>, S<1, 0, 2>, 2, 1, 8, 1, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 2, 8, 1, 1, 1, S<1, 32, 1, 4>, 1>, - DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle, S<1, 0, 2>, S<1, 0, 2>, 2, 2, 8, 1, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 1, 8, 1, 1, 1, S<1, 32, 1, 4>, 2>, - DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle, S<1, 0, 2>, S<1, 0, 2>, 2, 4, 8, 1, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 4, 8, 1, 1, 1, S<1, 32, 1, 4>, 4> + DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle, S<1, 0, 2>, S<1, 0, 2>, 2, 1, 8, 1, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 2, 8, 1, 1, 1, S<1, 32, 1, 4>, 1>, + DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle, S<1, 0, 2>, S<1, 0, 2>, 2, 2, 8, 1, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 1, 8, 1, 1, 1, S<1, 32, 1, 4>, 2>, + DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle, S<1, 0, 2>, S<1, 0, 2>, 2, 4, 8, 1, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 4, 8, 1, 1, 1, S<1, 32, 1, 4>, 4> // clang-format on >; @@ -261,7 +280,9 @@ template + ConvolutionForwardSpecialization ConvSpec, + typename DsDataTypes = Tuple<>, + typename OutElementOp = PassThrough> using device_grouped_conv_fwd_xdl_int8_generic_instances = std::tuple< // clang-format off //########################################| NumDim| A| B| Ds| E| AData| BData| AccData| CShuffle| Ds| EData| A| B| CDE| ConvForward| GEMM| NumGemmK| Block| MPer| NPer| KPer| AK1| BK1| MPer| NPer| MXdl| NXdl| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockLds| BBlockTransfer| BBlockTransfer| BBlockTransfer| BlockTransfer| BBlockTransfer| BBlockTransfer| BBlockLds| CShuffle| CShuffle| 
CBlockTransferClusterLengths| CBlockTransfer| @@ -269,7 +290,7 @@ using device_grouped_conv_fwd_xdl_int8_generic_instances = std::tuple< //########################################| | | | | | | | | | | | Operation| Operation| Operation| | | Stage| | | | | | | | | Wave| Wave| Lengths_K0_M_K1| ArrangeOrder| | | PerVector| PerVector_K1| | Lengths_K0_N_K1| ArrangeOrder| | | PerVector| PerVector_K1| | PerShuffle| PerShuffle| _NBlock_NWaveNPerXdl| _NWaveNPerXdl| //########################################| | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | // generic instance - DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle, S<1, 0, 2>, S<1, 0, 2>, 2, 1, 8, 1, S<4, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 1, 8, 1, 1, 1, S<1, 16, 1, 4>, 1> + DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle, S<1, 0, 2>, S<1, 0, 2>, 2, 1, 8, 1, S<4, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 1, 8, 1, 1, 1, S<1, 16, 1, 4>, 1> // clang-format on >; @@ -278,7 +299,9 @@ template + ConvolutionForwardSpecialization ConvSpec, + typename DsDataTypes = Tuple<>, + typename OutElementOp = PassThrough> using device_grouped_conv_fwd_xdl_int8_instances = std::tuple< // clang-format off //########################################| NumDim| A| B| Ds| E| AData| BData| AccData| CShuffle| Ds| EData| A| B| CDE| ConvForward| GEMM| NumGemmK| Block| MPer| NPer| KPer| AK1| BK1| MPer| NPer| MXdl| NXdl| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockLds| BBlockTransfer| BBlockTransfer| BBlockTransfer| BlockTransfer| BBlockTransfer| BBlockTransfer| BBlockLds| CShuffle| CShuffle| CBlockTransferClusterLengths| CBlockTransfer| @@ -286,24 +309,24 @@ using device_grouped_conv_fwd_xdl_int8_instances = std::tuple< //########################################| | | | | | | | | | | | Operation| Operation| Operation| | | Stage| | | | | | | | | Wave| Wave| Lengths_K0_M_K1| ArrangeOrder| | | PerVector| PerVector_K1| | Lengths_K0_N_K1| ArrangeOrder| | | PerVector| 
PerVector_K1| | PerShuffle| PerShuffle| _NBlock_NWaveNPerXdl| _NWaveNPerXdl| //########################################| | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | // generic instance - DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle, S<1, 0, 2>, S<1, 0, 2>, 2, 1, 8, 1, S<4, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 1, 8, 1, 1, 1, S<1, 16, 1, 4>, 1>, + DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle, S<1, 0, 2>, S<1, 0, 2>, 2, 1, 8, 1, S<4, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 1, 8, 1, 1, 1, S<1, 16, 1, 4>, 1>, // instances for small conv.K and conv.C - DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, S<4, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, 1, 1, S<1, 16, 1, 4>, 1>, - DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle, S<1, 0, 2>, S<1, 0, 2>, 2, 1, 8, 1, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 1, 8, 1, 1, 1, S<1, 32, 1, 8>, 8>, + DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, S<4, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, 1, 1, S<1, 16, 1, 4>, 1>, + DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle, S<1, 0, 2>, S<1, 0, 2>, 2, 1, 8, 1, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 1, 8, 1, 1, 1, S<1, 32, 1, 8>, 8>, - DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, 1, 1, S<1, 32, 1, 8>, 8>, - DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, 1, 1, S<1, 32, 1, 8>, 8>, - DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, 1, 1, S<1, 16, 1, 8>, 8>, - DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, 1, 1, S<1, 32, 1, 8>, 8>, - DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, 1, 
1, S<1, 32, 1, 4>, 8>, - DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, 1, 1, S<1, 16, 1, 8>, 8>, - DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, S<4, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, 1, 1, S<1, 16, 1, 4>, 8>, - DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, 1, 1, S<1, 32, 1, 8>, 8>, - DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, 1, 1, S<1, 32, 1, 8>, 8>, - DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, 1, 1, S<1, 32, 1, 4>, 8>, - DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, 1, 1, S<1, 16, 1, 8>, 8>, - DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, S<4, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, 1, 1, S<1, 16, 1, 4>, 8>, - DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, S<4, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, 1, 1, S<1, 16, 1, 4>, 8> + DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, 1, 1, S<1, 32, 1, 8>, 8>, + DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, 1, 1, S<1, 32, 1, 8>, 8>, + DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, 1, 1, S<1, 16, 1, 8>, 8>, + DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, 1, 1, S<1, 32, 1, 8>, 8>, + DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle, S<1, 0, 2>, S<1, 0, 
2>, 2, 8, 8, 1, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, 1, 1, S<1, 32, 1, 4>, 8>, + DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, 1, 1, S<1, 16, 1, 8>, 8>, + DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, S<4, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, 1, 1, S<1, 16, 1, 4>, 8>, + DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, 1, 1, S<1, 32, 1, 8>, 8>, + DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, 1, 1, S<1, 32, 1, 8>, 8>, + DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, 1, 1, S<1, 32, 1, 4>, 8>, + DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, 1, 1, S<1, 16, 1, 8>, 8>, + DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, S<4, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, 1, 1, S<1, 16, 1, 4>, 8>, + DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, S<4, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, 1, 1, S<1, 16, 1, 4>, 8> // clang-format on >; @@ -312,7 +335,9 @@ template + ConvolutionForwardSpecialization ConvSpec, + typename DsDataTypes = Tuple<>, + typename OutElementOp = PassThrough> using device_grouped_conv_fwd_xdl_f16_comp_f8_instances = std::tuple< // clang-format off //########################################| NumDim| A| B| Ds| E| AData| BData| AccData| CShuffle| Ds| EData| A| B| CDE| ConvForward| GEMM| NumGemmK| Block| MPer| NPer| KPer| AK1| BK1| MPer| NPer| MXdl| NXdl| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockLds| BBlockTransfer| BBlockTransfer| BBlockTransfer| BlockTransfer| 
BBlockTransfer| BBlockTransfer| BBlockLds| CShuffle| CShuffle| CBlockTransferClusterLengths| CBlockTransfer| ComputeType| @@ -321,24 +346,24 @@ using device_grouped_conv_fwd_xdl_f16_comp_f8_instances = std::tuple< //########################################| | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | #ifdef CK_ENABLE_FP8 // generic instance - DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle, S<1, 0, 2>, S<1, 0, 2>, 2, 1, 8, 1, S<4, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 1, 8, 1, 1, 1, S<1, 16, 1, 4>, 1, F8>, + DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle, S<1, 0, 2>, S<1, 0, 2>, 2, 1, 8, 1, S<4, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 1, 8, 1, 1, 1, S<1, 16, 1, 4>, 1, F8>, // instances for small conv.K and conv.C - DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, S<4, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, 1, 1, S<1, 16, 1, 4>, 1, F8>, - DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle, S<1, 0, 2>, S<1, 0, 2>, 2, 1, 8, 1, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 1, 8, 1, 1, 1, S<1, 32, 1, 8>, 8, F8>, + DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, S<4, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, 1, 1, S<1, 16, 1, 4>, 1, F8>, + DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle, S<1, 0, 2>, S<1, 0, 2>, 2, 1, 8, 1, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 1, 8, 1, 1, 1, S<1, 32, 1, 8>, 8, F8>, - DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, 1, 1, S<1, 32, 1, 8>, 8, F8>, - DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, 1, 1, S<1, 32, 1, 8>, 8, F8>, - DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, 1, 1, S<1, 16, 1, 8>, 8, F8>, - DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, S<4, 64, 1>, 
S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, 1, 1, S<1, 32, 1, 8>, 8, F8>, - DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, 1, 1, S<1, 32, 1, 4>, 8, F8>, - DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, 1, 1, S<1, 16, 1, 8>, 8, F8>, - DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, S<4, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, 1, 1, S<1, 16, 1, 4>, 8, F8>, - DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, 1, 1, S<1, 32, 1, 8>, 8, F8>, - DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, 1, 1, S<1, 32, 1, 8>, 8, F8>, - DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, 1, 1, S<1, 32, 1, 4>, 8, F8>, - DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, 1, 1, S<1, 16, 1, 8>, 8, F8>, - DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, S<4, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, 1, 1, S<1, 16, 1, 4>, 8, F8>, - DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, S<4, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, 1, 1, S<1, 16, 1, 4>, 8, F8> + DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, 1, 1, S<1, 32, 1, 8>, 8, F8>, + DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, 1, 1, S<1, 32, 1, 8>, 8, F8>, + DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, 1, 1, 
S<1, 16, 1, 8>, 8, F8>, + DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, 1, 1, S<1, 32, 1, 8>, 8, F8>, + DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, 1, 1, S<1, 32, 1, 4>, 8, F8>, + DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, 1, 1, S<1, 16, 1, 8>, 8, F8>, + DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, S<4, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, 1, 1, S<1, 16, 1, 4>, 8, F8>, + DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, 1, 1, S<1, 32, 1, 8>, 8, F8>, + DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, 1, 1, S<1, 32, 1, 8>, 8, F8>, + DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, 1, 1, S<1, 32, 1, 4>, 8, F8>, + DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, 1, 1, S<1, 16, 1, 8>, 8, F8>, + DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, S<4, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, 1, 1, S<1, 16, 1, 4>, 8, F8>, + DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, S<4, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, 1, 1, S<1, 16, 1, 4>, 8, F8> #endif // clang-format on >; @@ -348,7 +373,9 @@ template + ConvolutionForwardSpecialization ConvSpec, + typename DsDataTypes = Tuple<>, + typename OutElementOp = PassThrough> using device_grouped_conv_fwd_xdl_f8_instances = std::tuple< // clang-format off //########################################| NumDim| A| B| Ds| E| AData| BData| AccData| 
CShuffle| Ds| EData| A| B| CDE| ConvForward| GEMM| NumGemmK| Block| MPer| NPer| KPer| AK1| BK1| MPer| NPer| MXdl| NXdl| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockLds| BBlockTransfer| BBlockTransfer| BBlockTransfer| BlockTransfer| BBlockTransfer| BBlockTransfer| BBlockLds| CShuffle| CShuffle| CBlockTransferClusterLengths| CBlockTransfer| ComputeType| @@ -357,24 +384,24 @@ using device_grouped_conv_fwd_xdl_f8_instances = std::tuple< //########################################| | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | #ifdef CK_ENABLE_FP8 // generic instance - DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle, S<1, 0, 2>, S<1, 0, 2>, 2, 1, 8, 1, S<4, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 1, 8, 1, 1, 1, S<1, 16, 1, 4>, 1, F8>, + DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle, S<1, 0, 2>, S<1, 0, 2>, 2, 1, 8, 1, S<4, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 1, 8, 1, 1, 1, S<1, 16, 1, 4>, 1, F8>, // instances for small conv.K and conv.C - DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, S<4, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, 1, 1, S<1, 16, 1, 4>, 1, F8>, - DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle, S<1, 0, 2>, S<1, 0, 2>, 2, 1, 8, 1, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 1, 8, 1, 1, 1, S<1, 32, 1, 8>, 8, F8>, + DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, S<4, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, 1, 1, S<1, 16, 1, 4>, 1, F8>, + DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle, S<1, 0, 2>, S<1, 0, 2>, 2, 1, 8, 1, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 1, 8, 1, 1, 1, S<1, 32, 1, 8>, 8, F8>, - DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, 1, 1, S<1, 32, 1, 8>, 8, F8>, - DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, 1, 1, S<1, 
32, 1, 8>, 8, F8>, - DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, 1, 1, S<1, 16, 1, 8>, 8, F8>, - DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, 1, 1, S<1, 32, 1, 8>, 8, F8>, - DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, 1, 1, S<1, 32, 1, 4>, 8, F8>, - DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, 1, 1, S<1, 16, 1, 8>, 8, F8>, - DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, S<4, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, 1, 1, S<1, 16, 1, 4>, 8, F8>, - DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, 1, 1, S<1, 32, 1, 8>, 8, F8>, - DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, 1, 1, S<1, 32, 1, 8>, 8, F8>, - DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, 1, 1, S<1, 32, 1, 4>, 8, F8>, - DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, 1, 1, S<1, 16, 1, 8>, 8, F8>, - DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, S<4, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, 1, 1, S<1, 16, 1, 4>, 8, F8>, - DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, S<4, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, 1, 1, S<1, 16, 1, 4>, 8, F8> + DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, 1, 1, S<1, 32, 1, 8>, 8, F8>, + 
DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, 1, 1, S<1, 32, 1, 8>, 8, F8>, + DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, 1, 1, S<1, 16, 1, 8>, 8, F8>, + DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, 1, 1, S<1, 32, 1, 8>, 8, F8>, + DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, 1, 1, S<1, 32, 1, 4>, 8, F8>, + DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, 1, 1, S<1, 16, 1, 8>, 8, F8>, + DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, S<4, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, 1, 1, S<1, 16, 1, 4>, 8, F8>, + DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, 1, 1, S<1, 32, 1, 8>, 8, F8>, + DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, 1, 1, S<1, 32, 1, 8>, 8, F8>, + DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, 1, 1, S<1, 32, 1, 4>, 8, F8>, + DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, 1, 1, S<1, 16, 1, 8>, 8, F8>, + DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, S<4, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, 1, 1, S<1, 16, 1, 4>, 8, F8>, + DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, S<4, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, 1, 1, S<1, 16, 1, 4>, 8, F8> #endif // clang-format on >; @@ -384,7 +411,9 
@@ template + ConvolutionForwardSpecialization ConvSpec, + typename DsDataTypes = Tuple<>, + typename OutElementOp = PassThrough> using device_grouped_conv_fwd_xdl_bf8_instances = std::tuple< // clang-format off //########################################| NumDim| A| B| Ds| E| AData| BData| AccData| CShuffle| Ds| EData| A| B| CDE| ConvForward| GEMM| NumGemmK| Block| MPer| NPer| KPer| AK1| BK1| MPer| NPer| MXdl| NXdl| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockLds| BBlockTransfer| BBlockTransfer| BBlockTransfer| BlockTransfer| BBlockTransfer| BBlockTransfer| BBlockLds| CShuffle| CShuffle| CBlockTransferClusterLengths| CBlockTransfer| ComputeType| @@ -393,24 +422,24 @@ using device_grouped_conv_fwd_xdl_bf8_instances = std::tuple< //########################################| | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | #ifdef CK_ENABLE_BF8 // generic instance - DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle, S<1, 0, 2>, S<1, 0, 2>, 2, 1, 8, 1, S<4, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 1, 8, 1, 1, 1, S<1, 16, 1, 4>, 1, BF8>, + DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle, S<1, 0, 2>, S<1, 0, 2>, 2, 1, 8, 1, S<4, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 1, 8, 1, 1, 1, S<1, 16, 1, 4>, 1, BF8>, // instances for small conv.K and conv.C - DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, S<4, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, 1, 1, S<1, 16, 1, 4>, 1, BF8>, - DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle, S<1, 0, 2>, S<1, 0, 2>, 2, 1, 8, 1, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 1, 8, 1, 1, 1, S<1, 32, 1, 8>, 8, BF8>, + DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, S<4, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, 1, 1, S<1, 16, 1, 4>, 1, BF8>, + DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle, S<1, 0, 2>, S<1, 0, 2>, 2, 1, 8, 1, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 1, 8, 1, 1, 1, S<1, 32, 1, 8>, 8, BF8>, - 
DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, 1, 1, S<1, 32, 1, 8>, 8, BF8>, - DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, 1, 1, S<1, 32, 1, 8>, 8, BF8>, - DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, 1, 1, S<1, 16, 1, 8>, 8, BF8>, - DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, 1, 1, S<1, 32, 1, 8>, 8, BF8>, - DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, 1, 1, S<1, 32, 1, 4>, 8, BF8>, - DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, 1, 1, S<1, 16, 1, 8>, 8, BF8>, - DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, S<4, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, 1, 1, S<1, 16, 1, 4>, 8, BF8>, - DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, 1, 1, S<1, 32, 1, 8>, 8, BF8>, - DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, 1, 1, S<1, 32, 1, 8>, 8, BF8>, - DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, 1, 1, S<1, 32, 1, 4>, 8, BF8>, - DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, 1, 1, S<1, 16, 1, 8>, 8, BF8>, - DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, S<4, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, 1, 1, S<1, 16, 1, 4>, 8, BF8>, - 
DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, S<4, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, 1, 1, S<1, 16, 1, 4>, 8, BF8> + DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, 1, 1, S<1, 32, 1, 8>, 8, BF8>, + DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, 1, 1, S<1, 32, 1, 8>, 8, BF8>, + DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, 1, 1, S<1, 16, 1, 8>, 8, BF8>, + DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, 1, 1, S<1, 32, 1, 8>, 8, BF8>, + DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, 1, 1, S<1, 32, 1, 4>, 8, BF8>, + DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, 1, 1, S<1, 16, 1, 8>, 8, BF8>, + DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, S<4, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, 1, 1, S<1, 16, 1, 4>, 8, BF8>, + DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, 1, 1, S<1, 32, 1, 8>, 8, BF8>, + DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, 1, 1, S<1, 32, 1, 8>, 8, BF8>, + DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, 1, 1, S<1, 32, 1, 4>, 8, BF8>, + DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, 1, 1, S<1, 16, 1, 8>, 8, BF8>, + 
DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, S<4, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, 1, 1, S<1, 16, 1, 4>, 8, BF8>, + DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, S<4, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, 1, 1, S<1, 16, 1, 4>, 8, BF8> #endif // clang-format on >; @@ -420,7 +449,9 @@ template + ConvolutionForwardSpecialization ConvSpec, + typename DsDataTypes = Tuple<>, + typename OutElementOp = PassThrough> using device_grouped_conv_fwd_xdl_f8_bf8_instances = std::tuple< // clang-format off //########################################| NumDim| A| B| Ds| E| AData| BData| AccData| CShuffle| Ds| EData| A| B| CDE| ConvForward| GEMM| NumGemmK| Block| MPer| NPer| KPer| AK1| BK1| MPer| NPer| MXdl| NXdl| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockLds| BBlockTransfer| BBlockTransfer| BBlockTransfer| BlockTransfer| BBlockTransfer| BBlockTransfer| BBlockLds| CShuffle| CShuffle| CBlockTransferClusterLengths| CBlockTransfer|AComputeType|BComputeType| @@ -429,24 +460,24 @@ using device_grouped_conv_fwd_xdl_f8_bf8_instances = std::tuple< //########################################| | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | #if(defined(CK_ENABLE_FP8) && defined(CK_ENABLE_BF8)) // generic instance - DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle, S<1, 0, 2>, S<1, 0, 2>, 2, 1, 8, 1, S<4, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 1, 8, 1, 1, 1, S<1, 16, 1, 4>, 1, F8, BF8>, + DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle, S<1, 0, 2>, S<1, 0, 2>, 2, 1, 8, 1, S<4, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 1, 8, 1, 1, 1, S<1, 16, 1, 4>, 1, F8, BF8>, // instances for small conv.K and conv.C - DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, S<4, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, 1, 1, S<1, 16, 1, 4>, 1, F8, BF8>, - DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle, S<1, 0, 
2>, S<1, 0, 2>, 2, 1, 8, 1, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 1, 8, 1, 1, 1, S<1, 32, 1, 8>, 8, F8, BF8>, + DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, S<4, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, 1, 1, S<1, 16, 1, 4>, 1, F8, BF8>, + DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle, S<1, 0, 2>, S<1, 0, 2>, 2, 1, 8, 1, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 1, 8, 1, 1, 1, S<1, 32, 1, 8>, 8, F8, BF8>, - DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, 1, 1, S<1, 32, 1, 8>, 8, F8, BF8>, - DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, 1, 1, S<1, 32, 1, 8>, 8, F8, BF8>, - DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, 1, 1, S<1, 16, 1, 8>, 8, F8, BF8>, - DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, 1, 1, S<1, 32, 1, 8>, 8, F8, BF8>, - DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, 1, 1, S<1, 32, 1, 4>, 8, F8, BF8>, - DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, 1, 1, S<1, 16, 1, 8>, 8, F8, BF8>, - DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, S<4, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, 1, 1, S<1, 16, 1, 4>, 8, F8, BF8>, - DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, 1, 1, S<1, 32, 1, 8>, 8, F8, BF8>, - DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, 1, 1, S<1, 32, 1, 8>, 8, F8, BF8>, - 
DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, 1, 1, S<1, 32, 1, 4>, 8, F8, BF8>, - DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, 1, 1, S<1, 16, 1, 8>, 8, F8, BF8>, - DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, S<4, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, 1, 1, S<1, 16, 1, 4>, 8, F8, BF8>, - DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, S<4, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, 1, 1, S<1, 16, 1, 4>, 8, F8, BF8> + DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, 1, 1, S<1, 32, 1, 8>, 8, F8, BF8>, + DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, 1, 1, S<1, 32, 1, 8>, 8, F8, BF8>, + DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, 1, 1, S<1, 16, 1, 8>, 8, F8, BF8>, + DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, 1, 1, S<1, 32, 1, 8>, 8, F8, BF8>, + DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, 1, 1, S<1, 32, 1, 4>, 8, F8, BF8>, + DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, 1, 1, S<1, 16, 1, 8>, 8, F8, BF8>, + DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, S<4, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, 1, 1, S<1, 16, 1, 4>, 8, F8, BF8>, + DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, 1, 1, S<1, 32, 1, 8>, 
8, F8, BF8>, + DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, 1, 1, S<1, 32, 1, 8>, 8, F8, BF8>, + DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, 1, 1, S<1, 32, 1, 4>, 8, F8, BF8>, + DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, 1, 1, S<1, 16, 1, 8>, 8, F8, BF8>, + DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, S<4, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, 1, 1, S<1, 16, 1, 4>, 8, F8, BF8>, + DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, S<4, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, 1, 1, S<1, 16, 1, 4>, 8, F8, BF8> #endif // clang-format on >; @@ -456,7 +487,9 @@ template + ConvolutionForwardSpecialization ConvSpec, + typename DsDataTypes = Tuple<>, + typename OutElementOp = PassThrough> using device_grouped_conv_fwd_xdl_bf8_f8_instances = std::tuple< // clang-format off //########################################| NumDim| A| B| Ds| E| AData| BData| AccData| CShuffle| Ds| EData| A| B| CDE| ConvForward| GEMM| NumGemmK| Block| MPer| NPer| KPer| AK1| BK1| MPer| NPer| MXdl| NXdl| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockLds| BBlockTransfer| BBlockTransfer| BBlockTransfer| BlockTransfer| BBlockTransfer| BBlockTransfer| BBlockLds| CShuffle| CShuffle| CBlockTransferClusterLengths| CBlockTransfer|AComputeType|BComputeType| @@ -465,24 +498,24 @@ using device_grouped_conv_fwd_xdl_bf8_f8_instances = std::tuple< //########################################| | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | #if(defined(CK_ENABLE_FP8) && defined(CK_ENABLE_BF8)) // generic instance - DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle, S<1, 0, 2>, S<1, 0, 2>, 2, 
1, 8, 1, S<4, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 1, 8, 1, 1, 1, S<1, 16, 1, 4>, 1, BF8, F8>, + DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle, S<1, 0, 2>, S<1, 0, 2>, 2, 1, 8, 1, S<4, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 1, 8, 1, 1, 1, S<1, 16, 1, 4>, 1, BF8, F8>, // instances for small conv.K and conv.C - DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, S<4, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, 1, 1, S<1, 16, 1, 4>, 1, BF8, F8>, - DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle, S<1, 0, 2>, S<1, 0, 2>, 2, 1, 8, 1, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 1, 8, 1, 1, 1, S<1, 32, 1, 8>, 8, BF8, F8>, + DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, S<4, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, 1, 1, S<1, 16, 1, 4>, 1, BF8, F8>, + DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle, S<1, 0, 2>, S<1, 0, 2>, 2, 1, 8, 1, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 1, 8, 1, 1, 1, S<1, 32, 1, 8>, 8, BF8, F8>, - DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, 1, 1, S<1, 32, 1, 8>, 8, BF8, F8>, - DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, 1, 1, S<1, 32, 1, 8>, 8, BF8, F8>, - DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, 1, 1, S<1, 16, 1, 8>, 8, BF8, F8>, - DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, 1, 1, S<1, 32, 1, 8>, 8, BF8, F8>, - DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, 1, 1, S<1, 32, 1, 4>, 8, BF8, F8>, - DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, 1, 1, S<1, 16, 1, 8>, 8, BF8, F8>, - 
DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, S<4, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, 1, 1, S<1, 16, 1, 4>, 8, BF8, F8>, - DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, 1, 1, S<1, 32, 1, 8>, 8, BF8, F8>, - DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, 1, 1, S<1, 32, 1, 8>, 8, BF8, F8>, - DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, 1, 1, S<1, 32, 1, 4>, 8, BF8, F8>, - DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, 1, 1, S<1, 16, 1, 8>, 8, BF8, F8>, - DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, S<4, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, 1, 1, S<1, 16, 1, 4>, 8, BF8, F8>, - DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, S<4, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, 1, 1, S<1, 16, 1, 4>, 8, BF8, F8> + DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, 1, 1, S<1, 32, 1, 8>, 8, BF8, F8>, + DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, 1, 1, S<1, 32, 1, 8>, 8, BF8, F8>, + DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, 1, 1, S<1, 16, 1, 8>, 8, BF8, F8>, + DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, 1, 1, S<1, 32, 1, 8>, 8, BF8, F8>, + DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, 1, 1, S<1, 32, 1, 4>, 
8, BF8, F8>, + DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, 1, 1, S<1, 16, 1, 8>, 8, BF8, F8>, + DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, S<4, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, 1, 1, S<1, 16, 1, 4>, 8, BF8, F8>, + DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, 1, 1, S<1, 32, 1, 8>, 8, BF8, F8>, + DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, 1, 1, S<1, 32, 1, 8>, 8, BF8, F8>, + DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, 1, 1, S<1, 32, 1, 4>, 8, BF8, F8>, + DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, 1, 1, S<1, 16, 1, 8>, 8, BF8, F8>, + DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, S<4, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, 1, 1, S<1, 16, 1, 4>, 8, BF8, F8>, + DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, S<4, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, 1, 1, S<1, 16, 1, 4>, 8, BF8, F8> #endif // clang-format on >; diff --git a/library/include/ck/library/tensor_operation_instance/gpu/grouped_conv_fwd/device_grouped_conv_fwd_xdl_large_tensor_instance.hpp b/library/include/ck/library/tensor_operation_instance/gpu/grouped_conv_fwd/device_grouped_conv_fwd_xdl_large_tensor_instance.hpp index 0a85cde3bc..6bb6d255f3 100644 --- a/library/include/ck/library/tensor_operation_instance/gpu/grouped_conv_fwd/device_grouped_conv_fwd_xdl_large_tensor_instance.hpp +++ b/library/include/ck/library/tensor_operation_instance/gpu/grouped_conv_fwd/device_grouped_conv_fwd_xdl_large_tensor_instance.hpp @@ -25,6 +25,7 @@ using Empty_Tuple = 
ck::Tuple<>; using namespace ck::tensor_layout::convolution; using PassThrough = ck::tensor_operation::element_wise::PassThrough; +using AddRelu = ck::tensor_operation::element_wise::AddRelu; static constexpr auto ConvFwdDefault = ck::tensor_operation::device::ConvolutionForwardSpecialization::Default; @@ -36,7 +37,9 @@ template + ConvolutionForwardSpecialization ConvSpec, + typename DsDataTypes = Tuple<>, + typename OutElementOp = PassThrough> using device_grouped_conv_fwd_xdl_large_tensor_bf16_instances = std::tuple< // clang-format off //########################################| NumDim| A| B| Ds| E| AData| BData| AccData| CShuffle| Ds| EData| A| B| CDE| ConvForward| GEMM| NumGemmK| Block| MPer| NPer| KPer| AK1| BK1| MPer| NPer| MXdl| NXdl| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockLds| BBlockTransfer| BBlockTransfer| BBlockTransfer| BlockTransfer| BBlockTransfer| BBlockTransfer| BBlockLds| CShuffle| CShuffle| CBlockTransferClusterLengths| CBlockTransfer| @@ -44,10 +47,10 @@ using device_grouped_conv_fwd_xdl_large_tensor_bf16_instances = std::tuple< //########################################| | | | | | | | | | | | Operation| Operation| Operation| | | Stage| | | | | | | | | Wave| Wave| Lengths_K0_M_K1| ArrangeOrder| | | PerVector| PerVector_K1| | Lengths_K0_N_K1| ArrangeOrder| | | PerVector| PerVector_K1| | PerShuffle| PerShuffle| _NBlock_NWaveNPerXdl| _NWaveNPerXdl| //########################################| | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | // generic instance - DeviceGroupedConvFwdMultipleD_Xdl_CShuffle_Large_Tensor, S<1, 0, 2>, S<1, 0, 2>, 2, 1, 8, 1, S<4, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 1, 8, 1, 1, 1, S<1, 16, 1, 4>, 1>, + DeviceGroupedConvFwdMultipleD_Xdl_CShuffle_Large_Tensor, S<1, 0, 2>, S<1, 0, 2>, 2, 1, 8, 1, S<4, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 1, 8, 1, 1, 1, S<1, 16, 1, 4>, 1>, - 
DeviceGroupedConvFwdMultipleD_Xdl_CShuffle_Large_Tensor, S<1, 0, 2>, S<1, 0, 2>, 2, 2, 8, 1, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 2, 8, 1, 1, 1, S<1, 32, 1, 8>, 2>, - DeviceGroupedConvFwdMultipleD_Xdl_CShuffle_Large_Tensor, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, 1, 1, S<1, 32, 1, 8>, 8> + DeviceGroupedConvFwdMultipleD_Xdl_CShuffle_Large_Tensor, S<1, 0, 2>, S<1, 0, 2>, 2, 2, 8, 1, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 2, 8, 1, 1, 1, S<1, 32, 1, 8>, 2>, + DeviceGroupedConvFwdMultipleD_Xdl_CShuffle_Large_Tensor, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, 1, 1, S<1, 32, 1, 8>, 8> // clang-format on >; @@ -56,7 +59,9 @@ template + ConvolutionForwardSpecialization ConvSpec, + typename DsDataTypes = Tuple<>, + typename OutElementOp = PassThrough> using device_grouped_conv_fwd_xdl_large_tensor_f16_instances = std::tuple< // clang-format off //########################################| NumDim| A| B| Ds| E| AData| BData| AccData| CShuffle| Ds| EData| A| B| CDE| ConvForward| GEMM| NumGemmK| Block| MPer| NPer| KPer| AK1| BK1| MPer| NPer| MXdl| NXdl| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockLds| BBlockTransfer| BBlockTransfer| BBlockTransfer| BlockTransfer| BBlockTransfer| BBlockTransfer| BBlockLds| CShuffle| CShuffle| CBlockTransferClusterLengths| CBlockTransfer| @@ -64,10 +69,10 @@ using device_grouped_conv_fwd_xdl_large_tensor_f16_instances = std::tuple< //########################################| | | | | | | | | | | | Operation| Operation| Operation| | | Stage| | | | | | | | | Wave| Wave| Lengths_K0_M_K1| ArrangeOrder| | | PerVector| PerVector_K1| | Lengths_K0_N_K1| ArrangeOrder| | | PerVector| PerVector_K1| | PerShuffle| PerShuffle| _NBlock_NWaveNPerXdl| _NWaveNPerXdl| //########################################| | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | // generic instance - 
DeviceGroupedConvFwdMultipleD_Xdl_CShuffle_Large_Tensor, S<1, 0, 2>, S<1, 0, 2>, 2, 1, 8, 1, S<4, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 1, 8, 1, 1, 1, S<1, 16, 1, 4>, 1>, + DeviceGroupedConvFwdMultipleD_Xdl_CShuffle_Large_Tensor, S<1, 0, 2>, S<1, 0, 2>, 2, 1, 8, 1, S<4, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 1, 8, 1, 1, 1, S<1, 16, 1, 4>, 1>, - DeviceGroupedConvFwdMultipleD_Xdl_CShuffle_Large_Tensor, S<1, 0, 2>, S<1, 0, 2>, 2, 2, 8, 1, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 2, 8, 1, 1, 1, S<1, 32, 1, 8>, 2>, - DeviceGroupedConvFwdMultipleD_Xdl_CShuffle_Large_Tensor, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, 1, 1, S<1, 32, 1, 8>, 8> + DeviceGroupedConvFwdMultipleD_Xdl_CShuffle_Large_Tensor, S<1, 0, 2>, S<1, 0, 2>, 2, 2, 8, 1, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 2, 8, 1, 1, 1, S<1, 32, 1, 8>, 2>, + DeviceGroupedConvFwdMultipleD_Xdl_CShuffle_Large_Tensor, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, 1, 1, S<1, 32, 1, 8>, 8> // clang-format on >; @@ -76,7 +81,9 @@ template + ConvolutionForwardSpecialization ConvSpec, + typename DsDataTypes = Tuple<>, + typename OutElementOp = PassThrough> using device_grouped_conv_fwd_xdl_large_tensor_f32_instances = std::tuple< // clang-format off //########################################| NumDim| A| B| Ds| E| AData| BData| AccData| CShuffle| Ds| EData| A| B| CDE| ConvForward| GEMM| NumGemmK| Block| MPer| NPer| KPer| AK1| BK1| MPer| NPer| MXdl| NXdl| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockLds| BBlockTransfer| BBlockTransfer| BBlockTransfer| BlockTransfer| BBlockTransfer| BBlockTransfer| BBlockLds| CShuffle| CShuffle| CBlockTransferClusterLengths| CBlockTransfer| @@ -84,9 +91,9 @@ using device_grouped_conv_fwd_xdl_large_tensor_f32_instances = std::tuple< //########################################| | | | | | | | | | | | Operation| Operation| Operation| | | Stage| | | | | | | | | Wave| Wave| 
Lengths_K0_M_K1| ArrangeOrder| | | PerVector| PerVector_K1| | Lengths_K0_N_K1| ArrangeOrder| | | PerVector| PerVector_K1| | PerShuffle| PerShuffle| _NBlock_NWaveNPerXdl| _NWaveNPerXdl| //########################################| | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | // generic instance - DeviceGroupedConvFwdMultipleD_Xdl_CShuffle_Large_Tensor, S<1, 0, 2>, S<1, 0, 2>, 2, 1, 4, 1, S<4, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 1, 4, 1, 1, 1, S<1, 8, 1, 8>, 1>, + DeviceGroupedConvFwdMultipleD_Xdl_CShuffle_Large_Tensor, S<1, 0, 2>, S<1, 0, 2>, 2, 1, 4, 1, S<4, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 1, 4, 1, 1, 1, S<1, 8, 1, 8>, 1>, - DeviceGroupedConvFwdMultipleD_Xdl_CShuffle_Large_Tensor, S<1, 0, 2>, S<1, 0, 2>, 2, 4, 4, 1, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 4, 4, 1, 1, 1, S<1, 16, 1, 16>, 4> + DeviceGroupedConvFwdMultipleD_Xdl_CShuffle_Large_Tensor, S<1, 0, 2>, S<1, 0, 2>, 2, 4, 4, 1, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 4, 4, 1, 1, 1, S<1, 16, 1, 16>, 4> // clang-format on >; @@ -95,7 +102,9 @@ template + ConvolutionForwardSpecialization ConvSpec, + typename DsDataTypes = Tuple<>, + typename OutElementOp = PassThrough> using device_grouped_conv_fwd_xdl_large_tensor_int8_instances = std::tuple< // clang-format off //########################################| NumDim| A| B| Ds| E| AData| BData| AccData| CShuffle| Ds| EData| A| B| CDE| ConvForward| GEMM| NumGemmK| Block| MPer| NPer| KPer| AK1| BK1| MPer| NPer| MXdl| NXdl| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockLds| BBlockTransfer| BBlockTransfer| BBlockTransfer| BlockTransfer| BBlockTransfer| BBlockTransfer| BBlockLds| CShuffle| CShuffle| CBlockTransferClusterLengths| CBlockTransfer| @@ -103,9 +112,9 @@ using device_grouped_conv_fwd_xdl_large_tensor_int8_instances = std::tuple< //########################################| | | | | | | | | | | | Operation| Operation| Operation| | | Stage| | | | | | | | | Wave| 
Wave| Lengths_K0_M_K1| ArrangeOrder| | | PerVector| PerVector_K1| | Lengths_K0_N_K1| ArrangeOrder| | | PerVector| PerVector_K1| | PerShuffle| PerShuffle| _NBlock_NWaveNPerXdl| _NWaveNPerXdl| //########################################| | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | // generic instance - DeviceGroupedConvFwdMultipleD_Xdl_CShuffle_Large_Tensor, S<1, 0, 2>, S<1, 0, 2>, 2, 1, 8, 1, S<4, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 1, 8, 1, 1, 1, S<1, 16, 1, 4>, 1>, + DeviceGroupedConvFwdMultipleD_Xdl_CShuffle_Large_Tensor, S<1, 0, 2>, S<1, 0, 2>, 2, 1, 8, 1, S<4, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 1, 8, 1, 1, 1, S<1, 16, 1, 4>, 1>, - DeviceGroupedConvFwdMultipleD_Xdl_CShuffle_Large_Tensor, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, 1, 1, S<1, 32, 1, 8>, 8> + DeviceGroupedConvFwdMultipleD_Xdl_CShuffle_Large_Tensor, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, 1, 1, S<1, 32, 1, 8>, 8> // clang-format on >; } // namespace instance diff --git a/library/include/ck/library/tensor_operation_instance/gpu/grouped_conv_fwd/device_grouped_conv_fwd_xdl_mem_instance.hpp b/library/include/ck/library/tensor_operation_instance/gpu/grouped_conv_fwd/device_grouped_conv_fwd_xdl_mem_instance.hpp index 1f381af08c..195367ffd7 100644 --- a/library/include/ck/library/tensor_operation_instance/gpu/grouped_conv_fwd/device_grouped_conv_fwd_xdl_mem_instance.hpp +++ b/library/include/ck/library/tensor_operation_instance/gpu/grouped_conv_fwd/device_grouped_conv_fwd_xdl_mem_instance.hpp @@ -1,5 +1,5 @@ // SPDX-License-Identifier: MIT -// Copyright (c) 2018-2024, Advanced Micro Devices, Inc. All rights reserved. +// Copyright (c) 2018-2025, Advanced Micro Devices, Inc. All rights reserved. 
#include "ck/ck.hpp" #include "ck/tensor_operation/gpu/device/tensor_layout.hpp" @@ -33,6 +33,7 @@ using Empty_Tuple = ck::Tuple<>; using namespace ck::tensor_layout::convolution; using PassThrough = ck::tensor_operation::element_wise::PassThrough; +using AddRelu = ck::tensor_operation::element_wise::AddRelu; static constexpr auto ConvFwdDefault = ck::tensor_operation::device::ConvolutionForwardSpecialization::Default; @@ -52,7 +53,9 @@ template + BlockGemmPipelineScheduler BlkGemmPipeSched, + typename DsDataTypes = Tuple<>, + typename OutElementOp = PassThrough> using device_grouped_conv_fwd_xdl_bf16_mem_instances = std::tuple< // clang-format off //########################################| NumDim| A| B| Ds| E| AData| BData| AccData| CShuffle| Ds| EData| A| B| CDE| ConvForward| GEMM| NumGemmK| Block| MPer| NPer| KPer| AK1| BK1| MPer| NPer| MXdl| NXdl| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockLds| BBlockTransfer| BBlockTransfer| BBlockTransfer| BlockTransfer| BBlockTransfer| BBlockTransfer| BBlockLds| CShuffle| CShuffle| CBlockTransferClusterLengths| CBlockTransfer| @@ -60,27 +63,27 @@ using device_grouped_conv_fwd_xdl_bf16_mem_instances = std::tuple< //########################################| | | | | | | | | | | | Operation| Operation| Operation| | | Stage| | | | | | | | | Wave| Wave| Lengths_K0_M_K1| ArrangeOrder| | | PerVector| PerVector_K1| | Lengths_K0_N_K1| ArrangeOrder| | | PerVector| PerVector_K1| | PerShuffle| PerShuffle| _NBlock_NWaveNPerXdl| _NWaveNPerXdl| //########################################| | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | // Latency friendly - DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle_V3, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 0, S<8, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 0, 1, 1, S<1, 16, 1, 8>, 2, BlkGemmPipeSched, BlockGemmPipelineVersion::v1>, - DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle_V3, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 
8, 0, S<16, 4, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 0, 1, 1, S<1, 16, 1, 4>, 4, BlkGemmPipeSched, BlockGemmPipelineVersion::v1>, - DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle_V3, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 0, S<8, 8, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 0, 1, 1, S<1, 16, 1, 4>, 4, BlkGemmPipeSched, BlockGemmPipelineVersion::v1>, - DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle_V3, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 0, S<8, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 0, 1, 1, S<1, 16, 1, 8>, 4, BlkGemmPipeSched, BlockGemmPipelineVersion::v1>, + DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle_V3, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 0, S<8, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 0, 1, 1, S<1, 16, 1, 8>, 2, BlkGemmPipeSched, BlockGemmPipelineVersion::v1>, + DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle_V3, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 0, S<16, 4, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 0, 1, 1, S<1, 16, 1, 4>, 4, BlkGemmPipeSched, BlockGemmPipelineVersion::v1>, + DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle_V3, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 0, S<8, 8, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 0, 1, 1, S<1, 16, 1, 4>, 4, BlkGemmPipeSched, BlockGemmPipelineVersion::v1>, + DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle_V3, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 0, S<8, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 0, 1, 1, S<1, 16, 1, 8>, 4, BlkGemmPipeSched, BlockGemmPipelineVersion::v1>, // Memory friendly - DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle_V3, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 0, S<8, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 0, 1, 1, S<1, 32, 1, 8>, 4, BlkGemmPipeSched, BlockGemmPipelineVersion::v2>, - DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle_V3, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 0, S<8, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 0, 1, 1, S<1, 32, 1, 8>, 2, BlkGemmPipeSched, BlockGemmPipelineVersion::v2>, - DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle_V3, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 0, S<8, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 0, 1, 1, S<1, 16, 1, 8>, 4, 
BlkGemmPipeSched, BlockGemmPipelineVersion::v2>, - DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle_V3, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 0, S<8, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 0, 1, 1, S<1, 16, 1, 8>, 2, BlkGemmPipeSched, BlockGemmPipelineVersion::v2>, - DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle_V3, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 0, S<8, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 0, 1, 1, S<1, 16, 1, 8>, 4, BlkGemmPipeSched, BlockGemmPipelineVersion::v2>, - DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle_V3, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 0, S<8, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 0, 1, 1, S<1, 16, 1, 8>, 2, BlkGemmPipeSched, BlockGemmPipelineVersion::v2>, - DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle_V3, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 0, S<8, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 0, 1, 1, S<1, 16, 1, 8>, 2, BlkGemmPipeSched, BlockGemmPipelineVersion::v2>, - DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle_V3, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 0, S<16, 4, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 0, 1, 1, S<1, 16, 1, 4>, 4, BlkGemmPipeSched, BlockGemmPipelineVersion::v2>, - DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle_V3, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 0, S<8, 8, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 0, 1, 1, S<1, 16, 1, 4>, 4, BlkGemmPipeSched, BlockGemmPipelineVersion::v2>, - DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle_V3, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 0, S<8, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 0, 1, 1, S<1, 16, 1, 8>, 4, BlkGemmPipeSched, BlockGemmPipelineVersion::v2>, - DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle_V3, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 0, S<8, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 0, 1, 1, S<1, 16, 1, 8>, 4, BlkGemmPipeSched, BlockGemmPipelineVersion::v2>, - DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle_V3, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 0, S<8, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 0, 1, 1, S<1, 16, 1, 8>, 8, BlkGemmPipeSched, BlockGemmPipelineVersion::v2>, - DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle_V3, S<1, 0, 
2>, S<1, 0, 2>, 2, 8, 8, 0, S<8, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 0, 1, 1, S<1, 16, 1, 8>, 4, BlkGemmPipeSched, BlockGemmPipelineVersion::v2>, - DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle_V3, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 0, S<8, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 0, 1, 1, S<1, 16, 1, 8>, 8, BlkGemmPipeSched, BlockGemmPipelineVersion::v2>, - DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle_V3, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 0, S<8, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 0, 1, 1, S<1, 16, 1, 16>, 4, BlkGemmPipeSched, BlockGemmPipelineVersion::v2>, - DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle_V3, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 0, S<8, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 0, 1, 1, S<1, 16, 1, 16>, 8, BlkGemmPipeSched, BlockGemmPipelineVersion::v2> + DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle_V3, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 0, S<8, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 0, 1, 1, S<1, 32, 1, 8>, 4, BlkGemmPipeSched, BlockGemmPipelineVersion::v2>, + DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle_V3, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 0, S<8, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 0, 1, 1, S<1, 32, 1, 8>, 2, BlkGemmPipeSched, BlockGemmPipelineVersion::v2>, + DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle_V3, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 0, S<8, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 0, 1, 1, S<1, 16, 1, 8>, 4, BlkGemmPipeSched, BlockGemmPipelineVersion::v2>, + DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle_V3, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 0, S<8, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 0, 1, 1, S<1, 16, 1, 8>, 2, BlkGemmPipeSched, BlockGemmPipelineVersion::v2>, + DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle_V3, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 0, S<8, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 0, 1, 1, S<1, 16, 1, 8>, 4, BlkGemmPipeSched, BlockGemmPipelineVersion::v2>, + DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle_V3, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 0, S<8, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 0, 1, 1, S<1, 16, 1, 8>, 2, 
BlkGemmPipeSched, BlockGemmPipelineVersion::v2>, + DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle_V3, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 0, S<8, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 0, 1, 1, S<1, 16, 1, 8>, 2, BlkGemmPipeSched, BlockGemmPipelineVersion::v2>, + DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle_V3, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 0, S<16, 4, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 0, 1, 1, S<1, 16, 1, 4>, 4, BlkGemmPipeSched, BlockGemmPipelineVersion::v2>, + DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle_V3, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 0, S<8, 8, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 0, 1, 1, S<1, 16, 1, 4>, 4, BlkGemmPipeSched, BlockGemmPipelineVersion::v2>, + DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle_V3, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 0, S<8, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 0, 1, 1, S<1, 16, 1, 8>, 4, BlkGemmPipeSched, BlockGemmPipelineVersion::v2>, + DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle_V3, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 0, S<8, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 0, 1, 1, S<1, 16, 1, 8>, 4, BlkGemmPipeSched, BlockGemmPipelineVersion::v2>, + DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle_V3, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 0, S<8, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 0, 1, 1, S<1, 16, 1, 8>, 8, BlkGemmPipeSched, BlockGemmPipelineVersion::v2>, + DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle_V3, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 0, S<8, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 0, 1, 1, S<1, 16, 1, 8>, 4, BlkGemmPipeSched, BlockGemmPipelineVersion::v2>, + DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle_V3, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 0, S<8, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 0, 1, 1, S<1, 16, 1, 8>, 8, BlkGemmPipeSched, BlockGemmPipelineVersion::v2>, + DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle_V3, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 0, S<8, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 0, 1, 1, S<1, 16, 1, 16>, 4, BlkGemmPipeSched, BlockGemmPipelineVersion::v2>, + DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle_V3, S<1, 0, 
2>, S<1, 0, 2>, 2, 8, 8, 0, S<8, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 0, 1, 1, S<1, 16, 1, 16>, 8, BlkGemmPipeSched, BlockGemmPipelineVersion::v2> // clang-format on >; @@ -90,34 +93,36 @@ template + BlockGemmPipelineScheduler BlkGemmPipeSched, + typename DsDataTypes = Tuple<>, + typename OutElementOp = PassThrough> using device_grouped_conv_fwd_xdl_f16_mem_instances = std::tuple< // clang-format off //########################################| NumDim| A| B| Ds| E| AData| BData| AccData| CShuffle| Ds| EData| A| B| CDE| ConvForward| GEMM| NumGemmK| Block| MPer| NPer| KPer| AK1| BK1| MPer| NPer| MXdl| NXdl| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockLds| BBlockTransfer| BBlockTransfer| BBlockTransfer| BlockTransfer| BBlockTransfer| BBlockTransfer| BBlockLds| CShuffle| CShuffle| CBlockTransferClusterLengths| CBlockTransfer| //########################################| Spatial| Layout| Layout| Layout| Layout| Type| Type| Type| DataType| DataType| Type| Elementwise| Elementwise| Elementwise| Specialization| Specialization| Prefetch| Size| Block| Block| Block| | | XDL| XDL| Per| Per| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraM| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraN| MXdlPerWave| NXdlPerWave| _MBlock_MWaveMPerXdl| ScalarPerVector| //########################################| | | | | | | | | | | | Operation| Operation| Operation| | | Stage| | | | | | | | | Wave| Wave| Lengths_K0_M_K1| ArrangeOrder| | | PerVector| PerVector_K1| | Lengths_K0_N_K1| ArrangeOrder| | | PerVector| PerVector_K1| | PerShuffle| PerShuffle| _NBlock_NWaveNPerXdl| _NWaveNPerXdl| //########################################| | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | - DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle_V3, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 0, S<8, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 0, 
1, 1, S<1, 16, 1, 8>, 2, BlkGemmPipeSched, BlockGemmPipelineVersion::v1>, - DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle_V3, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 0, S<16, 4, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 0, 1, 1, S<1, 16, 1, 4>, 4, BlkGemmPipeSched, BlockGemmPipelineVersion::v1>, - DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle_V3, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 0, S<8, 8, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 0, 1, 1, S<1, 16, 1, 4>, 4, BlkGemmPipeSched, BlockGemmPipelineVersion::v1>, - DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle_V3, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 0, S<8, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 0, 1, 1, S<1, 16, 1, 8>, 4, BlkGemmPipeSched, BlockGemmPipelineVersion::v1>, + DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle_V3, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 0, S<8, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 0, 1, 1, S<1, 16, 1, 8>, 2, BlkGemmPipeSched, BlockGemmPipelineVersion::v1>, + DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle_V3, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 0, S<16, 4, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 0, 1, 1, S<1, 16, 1, 4>, 4, BlkGemmPipeSched, BlockGemmPipelineVersion::v1>, + DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle_V3, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 0, S<8, 8, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 0, 1, 1, S<1, 16, 1, 4>, 4, BlkGemmPipeSched, BlockGemmPipelineVersion::v1>, + DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle_V3, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 0, S<8, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 0, 1, 1, S<1, 16, 1, 8>, 4, BlkGemmPipeSched, BlockGemmPipelineVersion::v1>, // Memory friendly - DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle_V3, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 0, S<8, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 0, 1, 1, S<1, 32, 1, 8>, 4, BlkGemmPipeSched, BlockGemmPipelineVersion::v2>, - DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle_V3, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 0, S<8, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 0, 1, 1, S<1, 32, 1, 8>, 2, BlkGemmPipeSched, BlockGemmPipelineVersion::v2>, - 
DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle_V3, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 0, S<8, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 0, 1, 1, S<1, 16, 1, 8>, 4, BlkGemmPipeSched, BlockGemmPipelineVersion::v2>, - DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle_V3, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 0, S<8, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 0, 1, 1, S<1, 16, 1, 8>, 2, BlkGemmPipeSched, BlockGemmPipelineVersion::v2>, - DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle_V3, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 0, S<8, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 0, 1, 1, S<1, 16, 1, 8>, 4, BlkGemmPipeSched, BlockGemmPipelineVersion::v2>, - DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle_V3, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 0, S<8, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 0, 1, 1, S<1, 16, 1, 8>, 2, BlkGemmPipeSched, BlockGemmPipelineVersion::v2>, - DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle_V3, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 0, S<8, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 0, 1, 1, S<1, 16, 1, 8>, 2, BlkGemmPipeSched, BlockGemmPipelineVersion::v2>, - DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle_V3, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 0, S<16, 4, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 0, 1, 1, S<1, 16, 1, 4>, 4, BlkGemmPipeSched, BlockGemmPipelineVersion::v2>, - DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle_V3, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 0, S<8, 8, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 0, 1, 1, S<1, 16, 1, 4>, 4, BlkGemmPipeSched, BlockGemmPipelineVersion::v2>, - DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle_V3, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 0, S<8, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 0, 1, 1, S<1, 16, 1, 8>, 4, BlkGemmPipeSched, BlockGemmPipelineVersion::v2>, - DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle_V3, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 0, S<8, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 0, 1, 1, S<1, 16, 1, 8>, 4, BlkGemmPipeSched, BlockGemmPipelineVersion::v2>, - DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle_V3, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 0, S<8, 16, 1>, S<1, 0, 2>, 
S<1, 0, 2>, 2, 8, 8, 0, 1, 1, S<1, 16, 1, 8>, 8, BlkGemmPipeSched, BlockGemmPipelineVersion::v2>, - DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle_V3, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 0, S<8, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 0, 1, 1, S<1, 16, 1, 8>, 4, BlkGemmPipeSched, BlockGemmPipelineVersion::v2>, - DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle_V3, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 0, S<8, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 0, 1, 1, S<1, 16, 1, 8>, 8, BlkGemmPipeSched, BlockGemmPipelineVersion::v2>, - DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle_V3, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 0, S<8, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 0, 1, 1, S<1, 16, 1, 16>, 4, BlkGemmPipeSched, BlockGemmPipelineVersion::v2>, - DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle_V3, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 0, S<8, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 0, 1, 1, S<1, 16, 1, 16>, 8, BlkGemmPipeSched, BlockGemmPipelineVersion::v2> + DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle_V3, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 0, S<8, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 0, 1, 1, S<1, 32, 1, 8>, 4, BlkGemmPipeSched, BlockGemmPipelineVersion::v2>, + DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle_V3, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 0, S<8, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 0, 1, 1, S<1, 32, 1, 8>, 2, BlkGemmPipeSched, BlockGemmPipelineVersion::v2>, + DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle_V3, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 0, S<8, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 0, 1, 1, S<1, 16, 1, 8>, 4, BlkGemmPipeSched, BlockGemmPipelineVersion::v2>, + DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle_V3, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 0, S<8, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 0, 1, 1, S<1, 16, 1, 8>, 2, BlkGemmPipeSched, BlockGemmPipelineVersion::v2>, + DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle_V3, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 0, S<8, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 0, 1, 1, S<1, 16, 1, 8>, 4, BlkGemmPipeSched, BlockGemmPipelineVersion::v2>, + 
DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle_V3, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 0, S<8, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 0, 1, 1, S<1, 16, 1, 8>, 2, BlkGemmPipeSched, BlockGemmPipelineVersion::v2>, + DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle_V3, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 0, S<8, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 0, 1, 1, S<1, 16, 1, 8>, 2, BlkGemmPipeSched, BlockGemmPipelineVersion::v2>, + DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle_V3, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 0, S<16, 4, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 0, 1, 1, S<1, 16, 1, 4>, 4, BlkGemmPipeSched, BlockGemmPipelineVersion::v2>, + DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle_V3, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 0, S<8, 8, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 0, 1, 1, S<1, 16, 1, 4>, 4, BlkGemmPipeSched, BlockGemmPipelineVersion::v2>, + DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle_V3, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 0, S<8, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 0, 1, 1, S<1, 16, 1, 8>, 4, BlkGemmPipeSched, BlockGemmPipelineVersion::v2>, + DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle_V3, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 0, S<8, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 0, 1, 1, S<1, 16, 1, 8>, 4, BlkGemmPipeSched, BlockGemmPipelineVersion::v2>, + DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle_V3, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 0, S<8, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 0, 1, 1, S<1, 16, 1, 8>, 8, BlkGemmPipeSched, BlockGemmPipelineVersion::v2>, + DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle_V3, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 0, S<8, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 0, 1, 1, S<1, 16, 1, 8>, 4, BlkGemmPipeSched, BlockGemmPipelineVersion::v2>, + DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle_V3, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 0, S<8, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 0, 1, 1, S<1, 16, 1, 8>, 8, BlkGemmPipeSched, BlockGemmPipelineVersion::v2>, + DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle_V3, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 0, S<8, 16, 1>, S<1, 0, 2>, 
S<1, 0, 2>, 2, 8, 8, 0, 1, 1, S<1, 16, 1, 16>, 4, BlkGemmPipeSched, BlockGemmPipelineVersion::v2>, + DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle_V3, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 0, S<8, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 0, 1, 1, S<1, 16, 1, 16>, 8, BlkGemmPipeSched, BlockGemmPipelineVersion::v2> // clang-format on >; @@ -127,30 +132,32 @@ template + BlockGemmPipelineScheduler BlkGemmPipeSched, + typename DsDataTypes = Tuple<>, + typename OutElementOp = PassThrough> using device_grouped_conv_fwd_xdl_f32_mem_instances = std::tuple< // clang-format off //########################################| NumDim| A| B| Ds| E| AData| BData| AccData| CShuffle| Ds| EData| A| B| CDE| ConvForward| GEMM| NumGemmK| Block| MPer| NPer| KPer| AK1| BK1| MPer| NPer| MXdl| NXdl| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockLds| BBlockTransfer| BBlockTransfer| BBlockTransfer| BlockTransfer| BBlockTransfer| BBlockTransfer| BBlockLds| CShuffle| CShuffle| CBlockTransferClusterLengths| CBlockTransfer| //########################################| Spatial| Layout| Layout| Layout| Layout| Type| Type| Type| DataType| DataType| Type| Elementwise| Elementwise| Elementwise| Specialization| Specialization| Prefetch| Size| Block| Block| Block| | | XDL| XDL| Per| Per| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraM| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraN| MXdlPerWave| NXdlPerWave| _MBlock_MWaveMPerXdl| ScalarPerVector| //########################################| | | | | | | | | | | | Operation| Operation| Operation| | | Stage| | | | | | | | | Wave| Wave| Lengths_K0_M_K1| ArrangeOrder| | | PerVector| PerVector_K1| | Lengths_K0_N_K1| ArrangeOrder| | | PerVector| PerVector_K1| | PerShuffle| PerShuffle| _NBlock_NWaveNPerXdl| _NWaveNPerXdl| //########################################| | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | 
| | | | | | | | | | | | - DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle_V3, S<1, 0, 2>, S<1, 0, 2>, 2, 4, 4, 0, S<8, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 4, 4, 0, 1, 1, S<1, 16, 1, 8>, 2, BlkGemmPipeSched, BlockGemmPipelineVersion::v1>, - DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle_V3, S<1, 0, 2>, S<1, 0, 2>, 2, 4, 4, 0, S<16, 4, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 4, 4, 0, 1, 1, S<1, 16, 1, 4>, 4, BlkGemmPipeSched, BlockGemmPipelineVersion::v1>, - DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle_V3, S<1, 0, 2>, S<1, 0, 2>, 2, 4, 4, 0, S<8, 8, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 4, 4, 0, 1, 1, S<1, 16, 1, 4>, 4, BlkGemmPipeSched, BlockGemmPipelineVersion::v1>, - DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle_V3, S<1, 0, 2>, S<1, 0, 2>, 2, 4, 4, 0, S<8, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 4, 4, 0, 1, 1, S<1, 16, 1, 8>, 4, BlkGemmPipeSched, BlockGemmPipelineVersion::v1>, + DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle_V3, S<1, 0, 2>, S<1, 0, 2>, 2, 4, 4, 0, S<8, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 4, 4, 0, 1, 1, S<1, 16, 1, 8>, 2, BlkGemmPipeSched, BlockGemmPipelineVersion::v1>, + DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle_V3, S<1, 0, 2>, S<1, 0, 2>, 2, 4, 4, 0, S<16, 4, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 4, 4, 0, 1, 1, S<1, 16, 1, 4>, 4, BlkGemmPipeSched, BlockGemmPipelineVersion::v1>, + DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle_V3, S<1, 0, 2>, S<1, 0, 2>, 2, 4, 4, 0, S<8, 8, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 4, 4, 0, 1, 1, S<1, 16, 1, 4>, 4, BlkGemmPipeSched, BlockGemmPipelineVersion::v1>, + DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle_V3, S<1, 0, 2>, S<1, 0, 2>, 2, 4, 4, 0, S<8, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 4, 4, 0, 1, 1, S<1, 16, 1, 8>, 4, BlkGemmPipeSched, BlockGemmPipelineVersion::v1>, // Memory friendly - DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle_V3, S<1, 0, 2>, S<1, 0, 2>, 2, 4, 4, 0, S<8, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 4, 4, 0, 1, 1, S<1, 16, 1, 8>, 4, BlkGemmPipeSched, BlockGemmPipelineVersion::v2>, - DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle_V3, S<1, 0, 2>, S<1, 
0, 2>, 2, 4, 4, 0, S<8, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 4, 4, 0, 1, 1, S<1, 16, 1, 8>, 2, BlkGemmPipeSched, BlockGemmPipelineVersion::v2>, - DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle_V3, S<1, 0, 2>, S<1, 0, 2>, 2, 4, 4, 0, S<8, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 4, 4, 0, 1, 1, S<1, 16, 1, 8>, 4, BlkGemmPipeSched, BlockGemmPipelineVersion::v2>, - DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle_V3, S<1, 0, 2>, S<1, 0, 2>, 2, 4, 4, 0, S<8, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 4, 4, 0, 1, 1, S<1, 16, 1, 8>, 2, BlkGemmPipeSched, BlockGemmPipelineVersion::v2>, - DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle_V3, S<1, 0, 2>, S<1, 0, 2>, 2, 4, 4, 0, S<8, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 4, 4, 0, 1, 1, S<1, 16, 1, 8>, 2, BlkGemmPipeSched, BlockGemmPipelineVersion::v2>, - DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle_V3, S<1, 0, 2>, S<1, 0, 2>, 2, 4, 4, 0, S<16, 4, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 4, 4, 0, 1, 1, S<1, 16, 1, 4>, 4, BlkGemmPipeSched, BlockGemmPipelineVersion::v2>, - DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle_V3, S<1, 0, 2>, S<1, 0, 2>, 2, 4, 4, 0, S<8, 8, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 4, 4, 0, 1, 1, S<1, 16, 1, 4>, 4, BlkGemmPipeSched, BlockGemmPipelineVersion::v2>, - DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle_V3, S<1, 0, 2>, S<1, 0, 2>, 2, 4, 4, 0, S<8, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 4, 4, 0, 1, 1, S<1, 16, 1, 8>, 4, BlkGemmPipeSched, BlockGemmPipelineVersion::v2>, - DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle_V3, S<1, 0, 2>, S<1, 0, 2>, 2, 4, 4, 0, S<8, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 4, 4, 0, 1, 1, S<1, 16, 1, 8>, 4, BlkGemmPipeSched, BlockGemmPipelineVersion::v2>, - DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle_V3, S<1, 0, 2>, S<1, 0, 2>, 2, 4, 4, 0, S<8, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 4, 4, 0, 1, 1, S<1, 16, 1, 8>, 8, BlkGemmPipeSched, BlockGemmPipelineVersion::v2>, - DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle_V3, S<1, 0, 2>, S<1, 0, 2>, 2, 4, 4, 0, S<8, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 4, 4, 0, 1, 1, S<1, 16, 1, 8>, 4, BlkGemmPipeSched, 
BlockGemmPipelineVersion::v2>, - DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle_V3, S<1, 0, 2>, S<1, 0, 2>, 2, 4, 4, 0, S<8, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 4, 4, 0, 1, 1, S<1, 16, 1, 8>, 8, BlkGemmPipeSched, BlockGemmPipelineVersion::v2> + DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle_V3, S<1, 0, 2>, S<1, 0, 2>, 2, 4, 4, 0, S<8, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 4, 4, 0, 1, 1, S<1, 16, 1, 8>, 4, BlkGemmPipeSched, BlockGemmPipelineVersion::v2>, + DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle_V3, S<1, 0, 2>, S<1, 0, 2>, 2, 4, 4, 0, S<8, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 4, 4, 0, 1, 1, S<1, 16, 1, 8>, 2, BlkGemmPipeSched, BlockGemmPipelineVersion::v2>, + DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle_V3, S<1, 0, 2>, S<1, 0, 2>, 2, 4, 4, 0, S<8, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 4, 4, 0, 1, 1, S<1, 16, 1, 8>, 4, BlkGemmPipeSched, BlockGemmPipelineVersion::v2>, + DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle_V3, S<1, 0, 2>, S<1, 0, 2>, 2, 4, 4, 0, S<8, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 4, 4, 0, 1, 1, S<1, 16, 1, 8>, 2, BlkGemmPipeSched, BlockGemmPipelineVersion::v2>, + DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle_V3, S<1, 0, 2>, S<1, 0, 2>, 2, 4, 4, 0, S<8, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 4, 4, 0, 1, 1, S<1, 16, 1, 8>, 2, BlkGemmPipeSched, BlockGemmPipelineVersion::v2>, + DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle_V3, S<1, 0, 2>, S<1, 0, 2>, 2, 4, 4, 0, S<16, 4, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 4, 4, 0, 1, 1, S<1, 16, 1, 4>, 4, BlkGemmPipeSched, BlockGemmPipelineVersion::v2>, + DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle_V3, S<1, 0, 2>, S<1, 0, 2>, 2, 4, 4, 0, S<8, 8, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 4, 4, 0, 1, 1, S<1, 16, 1, 4>, 4, BlkGemmPipeSched, BlockGemmPipelineVersion::v2>, + DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle_V3, S<1, 0, 2>, S<1, 0, 2>, 2, 4, 4, 0, S<8, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 4, 4, 0, 1, 1, S<1, 16, 1, 8>, 4, BlkGemmPipeSched, BlockGemmPipelineVersion::v2>, + DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle_V3, S<1, 0, 2>, S<1, 0, 2>, 2, 4, 
4, 0, S<8, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 4, 4, 0, 1, 1, S<1, 16, 1, 8>, 4, BlkGemmPipeSched, BlockGemmPipelineVersion::v2>, + DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle_V3, S<1, 0, 2>, S<1, 0, 2>, 2, 4, 4, 0, S<8, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 4, 4, 0, 1, 1, S<1, 16, 1, 8>, 8, BlkGemmPipeSched, BlockGemmPipelineVersion::v2>, + DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle_V3, S<1, 0, 2>, S<1, 0, 2>, 2, 4, 4, 0, S<8, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 4, 4, 0, 1, 1, S<1, 16, 1, 8>, 4, BlkGemmPipeSched, BlockGemmPipelineVersion::v2>, + DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle_V3, S<1, 0, 2>, S<1, 0, 2>, 2, 4, 4, 0, S<8, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 4, 4, 0, 1, 1, S<1, 16, 1, 8>, 8, BlkGemmPipeSched, BlockGemmPipelineVersion::v2> // clang-format on >; @@ -160,34 +167,36 @@ template + BlockGemmPipelineScheduler BlkGemmPipeSched, + typename DsDataTypes = Tuple<>, + typename OutElementOp = PassThrough> using device_grouped_conv_fwd_xdl_int8_mem_instances = std::tuple< // clang-format off //########################################| NumDim| A| B| Ds| E| AData| BData| AccData| CShuffle| Ds| EData| A| B| CDE| ConvForward| GEMM| NumGemmK| Block| MPer| NPer| KPer| AK1| BK1| MPer| NPer| MXdl| NXdl| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockLds| BBlockTransfer| BBlockTransfer| BBlockTransfer| BlockTransfer| BBlockTransfer| BBlockTransfer| BBlockLds| CShuffle| CShuffle| CBlockTransferClusterLengths| CBlockTransfer| //########################################| Spatial| Layout| Layout| Layout| Layout| Type| Type| Type| DataType| DataType| Type| Elementwise| Elementwise| Elementwise| Specialization| Specialization| Prefetch| Size| Block| Block| Block| | | XDL| XDL| Per| Per| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraM| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraN| MXdlPerWave| NXdlPerWave| _MBlock_MWaveMPerXdl| 
ScalarPerVector| //########################################| | | | | | | | | | | | Operation| Operation| Operation| | | Stage| | | | | | | | | Wave| Wave| Lengths_K0_M_K1| ArrangeOrder| | | PerVector| PerVector_K1| | Lengths_K0_N_K1| ArrangeOrder| | | PerVector| PerVector_K1| | PerShuffle| PerShuffle| _NBlock_NWaveNPerXdl| _NWaveNPerXdl| //########################################| | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | - DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle_V3, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 0, S<8, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 0, 1, 1, S<1, 16, 1, 8>, 2, BlkGemmPipeSched, BlockGemmPipelineVersion::v1>, - DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle_V3, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 0, S<16, 4, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 0, 1, 1, S<1, 16, 1, 4>, 4, BlkGemmPipeSched, BlockGemmPipelineVersion::v1>, - DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle_V3, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 0, S<8, 8, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 0, 1, 1, S<1, 16, 1, 4>, 4, BlkGemmPipeSched, BlockGemmPipelineVersion::v1>, - DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle_V3, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 0, S<8, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 0, 1, 1, S<1, 16, 1, 8>, 4, BlkGemmPipeSched, BlockGemmPipelineVersion::v1>, + DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle_V3, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 0, S<8, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 0, 1, 1, S<1, 16, 1, 8>, 2, BlkGemmPipeSched, BlockGemmPipelineVersion::v1>, + DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle_V3, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 0, S<16, 4, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 0, 1, 1, S<1, 16, 1, 4>, 4, BlkGemmPipeSched, BlockGemmPipelineVersion::v1>, + DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle_V3, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 0, S<8, 8, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 0, 1, 1, S<1, 16, 1, 4>, 4, BlkGemmPipeSched, BlockGemmPipelineVersion::v1>, + DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle_V3, S<1, 
0, 2>, S<1, 0, 2>, 2, 8, 8, 0, S<8, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 0, 1, 1, S<1, 16, 1, 8>, 4, BlkGemmPipeSched, BlockGemmPipelineVersion::v1>, // Memory friendly - DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle_V3, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 0, S<8, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 0, 1, 1, S<1, 32, 1, 8>, 4, BlkGemmPipeSched, BlockGemmPipelineVersion::v2>, - DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle_V3, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 0, S<8, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 0, 1, 1, S<1, 32, 1, 8>, 2, BlkGemmPipeSched, BlockGemmPipelineVersion::v2>, - DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle_V3, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 0, S<8, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 0, 1, 1, S<1, 16, 1, 8>, 4, BlkGemmPipeSched, BlockGemmPipelineVersion::v2>, - DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle_V3, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 0, S<8, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 0, 1, 1, S<1, 16, 1, 8>, 2, BlkGemmPipeSched, BlockGemmPipelineVersion::v2>, - DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle_V3, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 0, S<8, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 0, 1, 1, S<1, 16, 1, 8>, 4, BlkGemmPipeSched, BlockGemmPipelineVersion::v2>, - DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle_V3, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 0, S<8, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 0, 1, 1, S<1, 16, 1, 8>, 2, BlkGemmPipeSched, BlockGemmPipelineVersion::v2>, - DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle_V3, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 0, S<8, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 0, 1, 1, S<1, 16, 1, 8>, 2, BlkGemmPipeSched, BlockGemmPipelineVersion::v2>, - DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle_V3, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 0, S<16, 4, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 0, 1, 1, S<1, 16, 1, 4>, 4, BlkGemmPipeSched, BlockGemmPipelineVersion::v2>, - DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle_V3, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 0, S<8, 8, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 0, 1, 1, S<1, 
16, 1, 4>, 4, BlkGemmPipeSched, BlockGemmPipelineVersion::v2>, - DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle_V3, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 0, S<8, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 0, 1, 1, S<1, 16, 1, 8>, 4, BlkGemmPipeSched, BlockGemmPipelineVersion::v2>, - DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle_V3, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 0, S<8, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 0, 1, 1, S<1, 16, 1, 8>, 4, BlkGemmPipeSched, BlockGemmPipelineVersion::v2>, - DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle_V3, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 0, S<8, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 0, 1, 1, S<1, 16, 1, 8>, 8, BlkGemmPipeSched, BlockGemmPipelineVersion::v2>, - DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle_V3, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 0, S<8, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 0, 1, 1, S<1, 16, 1, 8>, 4, BlkGemmPipeSched, BlockGemmPipelineVersion::v2>, - DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle_V3, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 0, S<8, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 0, 1, 1, S<1, 16, 1, 8>, 8, BlkGemmPipeSched, BlockGemmPipelineVersion::v2>, - DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle_V3, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 0, S<8, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 0, 1, 1, S<1, 16, 1, 16>, 4, BlkGemmPipeSched, BlockGemmPipelineVersion::v2>, - DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle_V3, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 0, S<8, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 0, 1, 1, S<1, 16, 1, 16>, 8, BlkGemmPipeSched, BlockGemmPipelineVersion::v2> + DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle_V3, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 0, S<8, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 0, 1, 1, S<1, 32, 1, 8>, 4, BlkGemmPipeSched, BlockGemmPipelineVersion::v2>, + DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle_V3, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 0, S<8, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 0, 1, 1, S<1, 32, 1, 8>, 2, BlkGemmPipeSched, BlockGemmPipelineVersion::v2>, + 
DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle_V3, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 0, S<8, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 0, 1, 1, S<1, 16, 1, 8>, 4, BlkGemmPipeSched, BlockGemmPipelineVersion::v2>, + DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle_V3, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 0, S<8, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 0, 1, 1, S<1, 16, 1, 8>, 2, BlkGemmPipeSched, BlockGemmPipelineVersion::v2>, + DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle_V3, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 0, S<8, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 0, 1, 1, S<1, 16, 1, 8>, 4, BlkGemmPipeSched, BlockGemmPipelineVersion::v2>, + DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle_V3, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 0, S<8, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 0, 1, 1, S<1, 16, 1, 8>, 2, BlkGemmPipeSched, BlockGemmPipelineVersion::v2>, + DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle_V3, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 0, S<8, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 0, 1, 1, S<1, 16, 1, 8>, 2, BlkGemmPipeSched, BlockGemmPipelineVersion::v2>, + DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle_V3, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 0, S<16, 4, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 0, 1, 1, S<1, 16, 1, 4>, 4, BlkGemmPipeSched, BlockGemmPipelineVersion::v2>, + DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle_V3, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 0, S<8, 8, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 0, 1, 1, S<1, 16, 1, 4>, 4, BlkGemmPipeSched, BlockGemmPipelineVersion::v2>, + DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle_V3, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 0, S<8, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 0, 1, 1, S<1, 16, 1, 8>, 4, BlkGemmPipeSched, BlockGemmPipelineVersion::v2>, + DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle_V3, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 0, S<8, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 0, 1, 1, S<1, 16, 1, 8>, 4, BlkGemmPipeSched, BlockGemmPipelineVersion::v2>, + DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle_V3, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 0, S<8, 16, 1>, S<1, 0, 2>, 
S<1, 0, 2>, 2, 8, 8, 0, 1, 1, S<1, 16, 1, 8>, 8, BlkGemmPipeSched, BlockGemmPipelineVersion::v2>, + DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle_V3, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 0, S<8, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 0, 1, 1, S<1, 16, 1, 8>, 4, BlkGemmPipeSched, BlockGemmPipelineVersion::v2>, + DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle_V3, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 0, S<8, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 0, 1, 1, S<1, 16, 1, 8>, 8, BlkGemmPipeSched, BlockGemmPipelineVersion::v2>, + DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle_V3, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 0, S<8, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 0, 1, 1, S<1, 16, 1, 16>, 4, BlkGemmPipeSched, BlockGemmPipelineVersion::v2>, + DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle_V3, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 0, S<8, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 0, 1, 1, S<1, 16, 1, 16>, 8, BlkGemmPipeSched, BlockGemmPipelineVersion::v2> // clang-format on >; diff --git a/library/include/ck/library/tensor_operation_instance/gpu/grouped_conv_fwd/device_grouped_conv_fwd_xdl_merged_groups_instance.hpp b/library/include/ck/library/tensor_operation_instance/gpu/grouped_conv_fwd/device_grouped_conv_fwd_xdl_merged_groups_instance.hpp index 153cc61b09..182c785978 100644 --- a/library/include/ck/library/tensor_operation_instance/gpu/grouped_conv_fwd/device_grouped_conv_fwd_xdl_merged_groups_instance.hpp +++ b/library/include/ck/library/tensor_operation_instance/gpu/grouped_conv_fwd/device_grouped_conv_fwd_xdl_merged_groups_instance.hpp @@ -1,5 +1,5 @@ // SPDX-License-Identifier: MIT -// Copyright (c) 2024, Advanced Micro Devices, Inc. All rights reserved. +// Copyright (c) 2024-2025, Advanced Micro Devices, Inc. All rights reserved. 
#include "ck/ck.hpp" #include "ck/tensor_operation/gpu/device/tensor_layout.hpp" @@ -25,6 +25,7 @@ using Empty_Tuple = ck::Tuple<>; using namespace ck::tensor_layout::convolution; using PassThrough = ck::tensor_operation::element_wise::PassThrough; +using AddRelu = ck::tensor_operation::element_wise::AddRelu; static constexpr auto ConvFwdDefault = ck::tensor_operation::device::ConvolutionForwardSpecialization::Default; @@ -38,7 +39,9 @@ template + ConvolutionForwardSpecialization ConvSpec, + typename DsDataTypes = Tuple<>, + typename OutElementOp = PassThrough> using device_grouped_conv_fwd_xdl_merged_groups_bf16_instances = std::tuple< // clang-format off //########################################| NumDim| A| B| Ds| E| AData| BData| AccData| CShuffle| Ds| EData| A| B| CDE| ConvForward| GEMM| NumGemmK| Block| MPer| NPer| KPer| AK1| BK1| MPer| NPer| MXdl| NXdl| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockLds| BBlockTransfer| BBlockTransfer| BBlockTransfer| BlockTransfer| BBlockTransfer| BBlockTransfer| BBlockLds| CShuffle| CShuffle| CBlockTransferClusterLengths| CBlockTransfer| ACompute| BCompute| BlockGemm| NumGroups| @@ -46,9 +49,9 @@ using device_grouped_conv_fwd_xdl_merged_groups_bf16_instances = std::tuple< //########################################| | | | | | | | | | | | Operation| Operation| Operation| | | Stage| | | | | | | | | Wave| Wave| Lengths_K0_M_K1| ArrangeOrder| | | PerVector| PerVector_K1| | Lengths_K0_N_K1| ArrangeOrder| | | PerVector| PerVector_K1| | PerShuffle| PerShuffle| _NBlock_NWaveNPerXdl| _NWaveNPerXdl| | | Scheduler| | //########################################| | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | // Instances with NumGroupsPerBatch > 1 - DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle, S<0, 2, 1>, S<0, 2, 1>, 1, 4, 4, 1, S< 4, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 1, 4, 1, 1, 1, S<1, 16, 1, 4>, 1, BF16, BF16, 
LoopScheduler::Default, 8>, - DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle, S<0, 2, 1>, S<0, 2, 1>, 1, 4, 4, 1, S< 4, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 1, 4, 1, 1, 1, S<1, 16, 1, 4>, 1, BF16, BF16, LoopScheduler::Default, 16>, - DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle, S<0, 2, 1>, S<0, 2, 1>, 1, 4, 4, 1, S< 4, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 1, 4, 1, 1, 1, S<1, 16, 1, 4>, 1, BF16, BF16, LoopScheduler::Default, 32> + DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle, S<0, 2, 1>, S<0, 2, 1>, 1, 4, 4, 1, S< 4, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 1, 4, 1, 1, 1, S<1, 16, 1, 4>, 1, BF16, BF16, LoopScheduler::Default, 8>, + DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle, S<0, 2, 1>, S<0, 2, 1>, 1, 4, 4, 1, S< 4, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 1, 4, 1, 1, 1, S<1, 16, 1, 4>, 1, BF16, BF16, LoopScheduler::Default, 16>, + DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle, S<0, 2, 1>, S<0, 2, 1>, 1, 4, 4, 1, S< 4, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 1, 4, 1, 1, 1, S<1, 16, 1, 4>, 1, BF16, BF16, LoopScheduler::Default, 32> // clang-format on >; @@ -58,16 +61,18 @@ template + ConvolutionForwardSpecialization ConvSpec, + typename DsDataTypes = Tuple<>, + typename OutElementOp = PassThrough> using device_grouped_conv_fwd_xdl_merged_groups_bf16_instances_2x = std::tuple< // clang-format off //########################################| NumDim| A| B| Ds| E| AData| BData| AccData| CShuffle| Ds| EData| A| B| CDE| ConvForward| GEMM| NumGemmK| Block| MPer| NPer| KPer| AK1| BK1| MPer| NPer| MXdl| NXdl| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockLds| BBlockTransfer| BBlockTransfer| BBlockTransfer| BlockTransfer| BBlockTransfer| BBlockTransfer| BBlockLds| CShuffle| CShuffle| CBlockTransferClusterLengths| CBlockTransfer| ACompute| BCompute| BlockGemm| NumGroups| //########################################| Spatial| Layout| Layout| Layout| Layout| Type| Type| Type| DataType| DataType| Type| Elementwise| Elementwise| Elementwise| 
Specialization| Specialization| Prefetch| Size| Block| Block| Block| | | XDL| XDL| Per| Per| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraM| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraN| MXdlPerWave| NXdlPerWave| _MBlock_MWaveMPerXdl| ScalarPerVector| Type| Type| Pipeline| ToMerge| //########################################| | | | | | | | | | | | Operation| Operation| Operation| | | Stage| | | | | | | | | Wave| Wave| Lengths_K0_M_K1| ArrangeOrder| | | PerVector| PerVector_K1| | Lengths_K0_N_K1| ArrangeOrder| | | PerVector| PerVector_K1| | PerShuffle| PerShuffle| _NBlock_NWaveNPerXdl| _NWaveNPerXdl| | | Scheduler| | //########################################| | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | - DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle, S<0, 2, 1>, S<0, 2, 1>, 1, 4, 4, 1, S< 4, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 1, 4, 1, 1, 1, S<1, 16, 1, 4>, 1, BF16, BF16, LoopScheduler::Default, 8>, - DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle, S<0, 2, 1>, S<0, 2, 1>, 1, 4, 4, 1, S< 4, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 1, 4, 1, 1, 1, S<1, 16, 1, 4>, 1, BF16, BF16, LoopScheduler::Default, 16>, - DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle, S<0, 2, 1>, S<0, 2, 1>, 1, 4, 4, 1, S< 4, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 1, 4, 1, 1, 1, S<1, 16, 1, 4>, 1, BF16, BF16, LoopScheduler::Default, 32> + DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle, S<0, 2, 1>, S<0, 2, 1>, 1, 4, 4, 1, S< 4, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 1, 4, 1, 1, 1, S<1, 16, 1, 4>, 1, BF16, BF16, LoopScheduler::Default, 8>, + DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle, S<0, 2, 1>, S<0, 2, 1>, 1, 4, 4, 1, S< 4, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 1, 4, 1, 1, 1, S<1, 16, 1, 4>, 1, BF16, BF16, LoopScheduler::Default, 16>, + DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle, S<0, 2, 1>, S<0, 2, 1>, 1, 4, 4, 1, S< 4, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 1, 4, 1, 1, 1, S<1, 
16, 1, 4>, 1, BF16, BF16, LoopScheduler::Default, 32> // clang-format on >; @@ -76,7 +81,9 @@ template + ConvolutionForwardSpecialization ConvSpec, + typename DsDataTypes = Tuple<>, + typename OutElementOp = PassThrough> using device_grouped_conv_fwd_xdl_merged_groups_f16_instances = std::tuple< // clang-format off //########################################| NumDim| A| B| Ds| E| AData| BData| AccData| CShuffle| Ds| EData| A| B| CDE| ConvForward| GEMM| NumGemmK| Block| MPer| NPer| KPer| AK1| BK1| MPer| NPer| MXdl| NXdl| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockLds| BBlockTransfer| BBlockTransfer| BBlockTransfer| BlockTransfer| BBlockTransfer| BBlockTransfer| BBlockLds| CShuffle| CShuffle| CBlockTransferClusterLengths| CBlockTransfer| @@ -84,9 +91,9 @@ using device_grouped_conv_fwd_xdl_merged_groups_f16_instances = std::tuple< //########################################| | | | | | | | | | | | Operation| Operation| Operation| | | Stage| | | | | | | | | Wave| Wave| Lengths_K0_M_K1| ArrangeOrder| | | PerVector| PerVector_K1| | Lengths_K0_N_K1| ArrangeOrder| | | PerVector| PerVector_K1| | PerShuffle| PerShuffle| _NBlock_NWaveNPerXdl| _NWaveNPerXdl| //########################################| | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | // Instances with NumGroupsPerBatch > 1 - DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle, S<0, 2, 1>, S<0, 2, 1>, 1, 4, 4, 1, S< 4, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 1, 4, 1, 1, 1, S<1, 16, 1, 4>, 1, F16, F16, LoopScheduler::Default, 8>, - DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle, S<0, 2, 1>, S<0, 2, 1>, 1, 4, 4, 1, S< 4, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 1, 4, 1, 1, 1, S<1, 16, 1, 4>, 1, F16, F16, LoopScheduler::Default, 16>, - DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle, S<0, 2, 1>, S<0, 2, 1>, 1, 4, 4, 1, S< 4, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 1, 4, 1, 1, 1, S<1, 16, 1, 4>, 1, F16, F16, LoopScheduler::Default, 32> + 
DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle, S<0, 2, 1>, S<0, 2, 1>, 1, 4, 4, 1, S< 4, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 1, 4, 1, 1, 1, S<1, 16, 1, 4>, 1, F16, F16, LoopScheduler::Default, 8>, + DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle, S<0, 2, 1>, S<0, 2, 1>, 1, 4, 4, 1, S< 4, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 1, 4, 1, 1, 1, S<1, 16, 1, 4>, 1, F16, F16, LoopScheduler::Default, 16>, + DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle, S<0, 2, 1>, S<0, 2, 1>, 1, 4, 4, 1, S< 4, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 1, 4, 1, 1, 1, S<1, 16, 1, 4>, 1, F16, F16, LoopScheduler::Default, 32> // clang-format on >; @@ -96,7 +103,9 @@ template + ConvolutionForwardSpecialization ConvSpec, + typename DsDataTypes = Tuple<>, + typename OutElementOp = PassThrough> using device_grouped_conv_fwd_xdl_merged_groups_f16_instances_2x = std::tuple< // clang-format off //########################################| NumDim| A| B| Ds| E| AData| BData| AccData| CShuffle| Ds| EData| A| B| CDE| ConvForward| GEMM| NumGemmK| Block| MPer| NPer| KPer| AK1| BK1| MPer| NPer| MXdl| NXdl| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockLds| BBlockTransfer| BBlockTransfer| BBlockTransfer| BlockTransfer| BBlockTransfer| BBlockTransfer| BBlockLds| CShuffle| CShuffle| CBlockTransferClusterLengths| CBlockTransfer| @@ -104,9 +113,9 @@ using device_grouped_conv_fwd_xdl_merged_groups_f16_instances_2x = std::tuple< //########################################| | | | | | | | | | | | Operation| Operation| Operation| | | Stage| | | | | | | | | Wave| Wave| Lengths_K0_M_K1| ArrangeOrder| | | PerVector| PerVector_K1| | Lengths_K0_N_K1| ArrangeOrder| | | PerVector| PerVector_K1| | PerShuffle| PerShuffle| _NBlock_NWaveNPerXdl| _NWaveNPerXdl| //########################################| | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | // Instances with NumGroupsPerBatch > 1 - DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle, S<0, 2, 1>, 
S<0, 2, 1>, 1, 4, 4, 1, S< 4, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 1, 4, 1, 1, 1, S<1, 16, 1, 4>, 1, F16, F16, LoopScheduler::Default, 8>, - DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle, S<0, 2, 1>, S<0, 2, 1>, 1, 4, 4, 1, S< 4, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 1, 4, 1, 1, 1, S<1, 16, 1, 4>, 1, F16, F16, LoopScheduler::Default, 16>, - DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle, S<0, 2, 1>, S<0, 2, 1>, 1, 4, 4, 1, S< 4, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 1, 4, 1, 1, 1, S<1, 16, 1, 4>, 1, F16, F16, LoopScheduler::Default, 32> + DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle, S<0, 2, 1>, S<0, 2, 1>, 1, 4, 4, 1, S< 4, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 1, 4, 1, 1, 1, S<1, 16, 1, 4>, 1, F16, F16, LoopScheduler::Default, 8>, + DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle, S<0, 2, 1>, S<0, 2, 1>, 1, 4, 4, 1, S< 4, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 1, 4, 1, 1, 1, S<1, 16, 1, 4>, 1, F16, F16, LoopScheduler::Default, 16>, + DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle, S<0, 2, 1>, S<0, 2, 1>, 1, 4, 4, 1, S< 4, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 1, 4, 1, 1, 1, S<1, 16, 1, 4>, 1, F16, F16, LoopScheduler::Default, 32> // clang-format on >; @@ -115,7 +124,9 @@ template + ConvolutionForwardSpecialization ConvSpec, + typename DsDataTypes = Tuple<>, + typename OutElementOp = PassThrough> using device_grouped_conv_fwd_xdl_merged_groups_f32_instances = std::tuple< // clang-format off //########################################| NumDim| A| B| Ds| E| AData| BData| AccData| CShuffle| Ds| EData| A| B| CDE| ConvForward| GEMM| NumGemmK| Block| MPer| NPer| KPer| AK1| BK1| MPer| NPer| MXdl| NXdl| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockLds| BBlockTransfer| BBlockTransfer| BBlockTransfer| BlockTransfer| BBlockTransfer| BBlockTransfer| BBlockLds| CShuffle| CShuffle| CBlockTransferClusterLengths| CBlockTransfer| @@ -123,9 +134,9 @@ using device_grouped_conv_fwd_xdl_merged_groups_f32_instances = std::tuple< 
//########################################| | | | | | | | | | | | Operation| Operation| Operation| | | Stage| | | | | | | | | Wave| Wave| Lengths_K0_M_K1| ArrangeOrder| | | PerVector| PerVector_K1| | Lengths_K0_N_K1| ArrangeOrder| | | PerVector| PerVector_K1| | PerShuffle| PerShuffle| _NBlock_NWaveNPerXdl| _NWaveNPerXdl| //########################################| | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | // Instances with NumGroupsPerBatch > 1 - DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle, S<0, 2, 1>, S<0, 2, 1>, 1, 4, 4, 1, S< 4, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 1, 4, 1, 1, 1, S<1, 16, 1, 4>, 1, F32, F32, LoopScheduler::Default, 8>, - DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle, S<0, 2, 1>, S<0, 2, 1>, 1, 4, 4, 1, S< 4, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 1, 4, 1, 1, 1, S<1, 16, 1, 4>, 1, F32, F32, LoopScheduler::Default, 16>, - DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle, S<0, 2, 1>, S<0, 2, 1>, 1, 4, 4, 1, S< 4, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 1, 4, 1, 1, 1, S<1, 16, 1, 4>, 1, F32, F32, LoopScheduler::Default, 32> + DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle, S<0, 2, 1>, S<0, 2, 1>, 1, 4, 4, 1, S< 4, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 1, 4, 1, 1, 1, S<1, 16, 1, 4>, 1, F32, F32, LoopScheduler::Default, 8>, + DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle, S<0, 2, 1>, S<0, 2, 1>, 1, 4, 4, 1, S< 4, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 1, 4, 1, 1, 1, S<1, 16, 1, 4>, 1, F32, F32, LoopScheduler::Default, 16>, + DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle, S<0, 2, 1>, S<0, 2, 1>, 1, 4, 4, 1, S< 4, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 1, 4, 1, 1, 1, S<1, 16, 1, 4>, 1, F32, F32, LoopScheduler::Default, 32> // clang-format on >; @@ -134,7 +145,9 @@ template + ConvolutionForwardSpecialization ConvSpec, + typename DsDataTypes = Tuple<>, + typename OutElementOp = PassThrough> using device_grouped_conv_fwd_xdl_merged_groups_int8_instances = std::tuple< // clang-format off //########################################| NumDim| 
A| B| Ds| E| AData| BData| AccData| CShuffle| Ds| EData| A| B| CDE| ConvForward| GEMM| NumGemmK| Block| MPer| NPer| KPer| AK1| BK1| MPer| NPer| MXdl| NXdl| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockLds| BBlockTransfer| BBlockTransfer| BBlockTransfer| BlockTransfer| BBlockTransfer| BBlockTransfer| BBlockLds| CShuffle| CShuffle| CBlockTransferClusterLengths| CBlockTransfer| @@ -142,9 +155,9 @@ using device_grouped_conv_fwd_xdl_merged_groups_int8_instances = std::tuple< //########################################| | | | | | | | | | | | Operation| Operation| Operation| | | Stage| | | | | | | | | Wave| Wave| Lengths_K0_M_K1| ArrangeOrder| | | PerVector| PerVector_K1| | Lengths_K0_N_K1| ArrangeOrder| | | PerVector| PerVector_K1| | PerShuffle| PerShuffle| _NBlock_NWaveNPerXdl| _NWaveNPerXdl| //########################################| | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | // Instances with NumGroupsPerBatch > 1 - DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, S< 4, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, 1, 1, S<1, 16, 1, 4>, 1, int8_t, int8_t, LoopScheduler::Default, 8>, - DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, S< 4, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, 1, 1, S<1, 16, 1, 4>, 1, int8_t, int8_t, LoopScheduler::Default, 16>, - DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, S< 4, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, 1, 1, S<1, 16, 1, 4>, 1, int8_t, int8_t, LoopScheduler::Default, 32> + DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, S< 4, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, 1, 1, S<1, 16, 1, 4>, 1, int8_t, int8_t, LoopScheduler::Default, 8>, + DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, S< 4, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, 1, 1, S<1, 16, 1, 4>, 
1, int8_t, int8_t, LoopScheduler::Default, 16>, + DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, S< 4, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, 1, 1, S<1, 16, 1, 4>, 1, int8_t, int8_t, LoopScheduler::Default, 32> // clang-format on >; diff --git a/library/include/ck/library/tensor_operation_instance/gpu/grouped_convolution_forward_bias_relu.hpp b/library/include/ck/library/tensor_operation_instance/gpu/grouped_convolution_forward_bias_relu.hpp new file mode 100644 index 0000000000..d873edadba --- /dev/null +++ b/library/include/ck/library/tensor_operation_instance/gpu/grouped_convolution_forward_bias_relu.hpp @@ -0,0 +1,141 @@ +// SPDX-License-Identifier: MIT +// Copyright (c) 2025, Advanced Micro Devices, Inc. All rights reserved. + +#pragma once + +#include +#include +#include "ck/ck.hpp" +#include "ck/tensor_operation/gpu/device/device_grouped_conv_fwd_multiple_abd.hpp" +#include "ck/tensor_operation/gpu/device/tensor_layout.hpp" +#include "ck/tensor_operation/gpu/element/element_wise_operation.hpp" + +#include "ck/library/tensor_operation_instance/device_operation_instance_factory.hpp" + +#ifdef CK_USE_XDL +#include "grouped_convolution_forward_bias_relu_xdl.inc" +#endif + +namespace ck { +namespace tensor_operation { +namespace device { +namespace instance { + +template +struct DeviceOperationInstanceFactory> +{ + using DeviceOp = + DeviceGroupedConvFwdMultipleABD; + + static auto GetInstances() + { + std::vector> op_ptrs; + +#ifdef CK_USE_XDL + // layout NHWGC/GKYXC/NHWGK + if constexpr(NumDimSpatial == 2 && is_same_v && + is_same_v && is_same_v) + { +#ifdef CK_ENABLE_BF16 + if constexpr(is_same_v && + is_same_v && + is_same_v && + is_same_v && + is_same_v) + { + add_device_grouped_conv2d_fwd_bias_relu_xdl_nhwgc_gkyxc_nhwgk_bf16_instances( + op_ptrs); + add_device_grouped_conv2d_fwd_bias_relu_xdl_nhwgc_gkyxc_nhwgk_bf16_16x16_instances( + op_ptrs); + 
add_device_grouped_conv2d_fwd_bias_relu_xdl_large_tensor_nhwgc_gkyxc_nhwgk_bf16_instances( + op_ptrs); + add_device_grouped_conv2d_fwd_bias_relu_xdl_merged_groups_nhwgc_gkyxc_nhwgk_bf16_instances( + op_ptrs); + add_device_grouped_conv2d_fwd_bias_relu_xdl_nhwgc_gkyxc_nhwgk_bf16_comp_instances( + op_ptrs); + add_device_grouped_conv2d_fwd_bias_relu_xdl_nhwgc_gkyxc_nhwgk_bf16_comp_2x_instances( + op_ptrs); + add_device_grouped_conv2d_fwd_bias_relu_xdl_nhwgc_gkyxc_nhwgk_bf16_comp_part2_instances( + op_ptrs); + add_device_grouped_conv2d_fwd_bias_relu_xdl_nhwgc_gkyxc_nhwgk_bf16_mem_intra_instances( + op_ptrs); + add_device_grouped_conv2d_fwd_bias_relu_xdl_nhwgc_gkyxc_nhwgk_bf16_mem_inter_instances( + op_ptrs); + } +#endif + } + // layout NDHWGC/GKZYXC/NDHWGK + if constexpr(NumDimSpatial == 3 && is_same_v && + is_same_v && is_same_v) + { +#ifdef CK_ENABLE_BF16 + if constexpr(is_same_v && + is_same_v && + is_same_v && + is_same_v && + is_same_v) + { + add_device_grouped_conv3d_fwd_bias_relu_xdl_ndhwgc_gkzyxc_ndhwgk_bf16_instances( + op_ptrs); + add_device_grouped_conv3d_fwd_bias_relu_xdl_ndhwgc_gkzyxc_ndhwgk_bf16_16x16_instances( + op_ptrs); + add_device_grouped_conv3d_fwd_bias_relu_xdl_large_tensor_ndhwgc_gkzyxc_ndhwgk_bf16_instances( + op_ptrs); + add_device_grouped_conv3d_fwd_bias_relu_xdl_merged_groups_ndhwgc_gkzyxc_ndhwgk_bf16_instances( + op_ptrs); + add_device_grouped_conv3d_fwd_bias_relu_xdl_ndhwgc_gkzyxc_ndhwgk_bf16_comp_instances( + op_ptrs); + add_device_grouped_conv3d_fwd_bias_relu_xdl_ndhwgc_gkzyxc_ndhwgk_bf16_mem_intra_instances( + op_ptrs); + add_device_grouped_conv3d_fwd_bias_relu_xdl_ndhwgc_gkzyxc_ndhwgk_bf16_mem_inter_instances( + op_ptrs); + } +#endif + } +#endif // CK_USE_XDL + + return op_ptrs; + } +}; + +} // namespace instance +} // namespace device +} // namespace tensor_operation +} // namespace ck diff --git a/library/include/ck/library/tensor_operation_instance/gpu/grouped_convolution_forward_bias_relu_xdl.inc 
b/library/include/ck/library/tensor_operation_instance/gpu/grouped_convolution_forward_bias_relu_xdl.inc new file mode 100644 index 0000000000..1935f123a8 --- /dev/null +++ b/library/include/ck/library/tensor_operation_instance/gpu/grouped_convolution_forward_bias_relu_xdl.inc @@ -0,0 +1,242 @@ +// SPDX-License-Identifier: MIT +// Copyright (c) 2025, Advanced Micro Devices, Inc. All rights reserved. + +#pragma once + +namespace ck { +namespace tensor_operation { +namespace device { +namespace instance { + +#ifdef CK_ENABLE_BF16 + +void add_device_grouped_conv2d_fwd_bias_relu_xdl_nhwgc_gkyxc_nhwgk_bf16_instances( + std::vector, + NHWGK, + BF16, + BF16, + Tuple, + BF16, + PassThrough, + PassThrough, + AddRelu>>>& instances); + +void add_device_grouped_conv2d_fwd_bias_relu_xdl_nhwgc_gkyxc_nhwgk_bf16_16x16_instances( + std::vector, + NHWGK, + BF16, + BF16, + Tuple, + BF16, + PassThrough, + PassThrough, + AddRelu>>>& instances); + +void add_device_grouped_conv2d_fwd_bias_relu_xdl_large_tensor_nhwgc_gkyxc_nhwgk_bf16_instances( + std::vector, + NHWGK, + BF16, + BF16, + Tuple, + BF16, + PassThrough, + PassThrough, + AddRelu>>>& instances); + +void add_device_grouped_conv2d_fwd_bias_relu_xdl_merged_groups_nhwgc_gkyxc_nhwgk_bf16_instances( + std::vector, + NHWGK, + BF16, + BF16, + Tuple, + BF16, + PassThrough, + PassThrough, + AddRelu>>>& instances); + +void add_device_grouped_conv2d_fwd_bias_relu_xdl_nhwgc_gkyxc_nhwgk_bf16_comp_instances( + std::vector, + NHWGK, + BF16, + BF16, + Tuple, + BF16, + PassThrough, + PassThrough, + AddRelu>>>& instances); + +void add_device_grouped_conv2d_fwd_bias_relu_xdl_nhwgc_gkyxc_nhwgk_bf16_comp_2x_instances( + std::vector, + NHWGK, + BF16, + BF16, + Tuple, + BF16, + PassThrough, + PassThrough, + AddRelu>>>& instances); + +void add_device_grouped_conv2d_fwd_bias_relu_xdl_nhwgc_gkyxc_nhwgk_bf16_comp_part2_instances( + std::vector, + NHWGK, + BF16, + BF16, + Tuple, + BF16, + PassThrough, + PassThrough, + AddRelu>>>& instances); + +void 
add_device_grouped_conv2d_fwd_bias_relu_xdl_nhwgc_gkyxc_nhwgk_bf16_mem_intra_instances( + std::vector, + NHWGK, + BF16, + BF16, + Tuple, + BF16, + PassThrough, + PassThrough, + AddRelu>>>& instances); + +void add_device_grouped_conv2d_fwd_bias_relu_xdl_nhwgc_gkyxc_nhwgk_bf16_mem_inter_instances( + std::vector, + NHWGK, + BF16, + BF16, + Tuple, + BF16, + PassThrough, + PassThrough, + AddRelu>>>& instances); + +void add_device_grouped_conv3d_fwd_bias_relu_xdl_ndhwgc_gkzyxc_ndhwgk_bf16_instances( + std::vector, + NDHWGK, + BF16, + BF16, + Tuple, + BF16, + PassThrough, + PassThrough, + AddRelu>>>& instances); + +void add_device_grouped_conv3d_fwd_bias_relu_xdl_ndhwgc_gkzyxc_ndhwgk_bf16_16x16_instances( + std::vector, + NDHWGK, + BF16, + BF16, + Tuple, + BF16, + PassThrough, + PassThrough, + AddRelu>>>& instances); + +void add_device_grouped_conv3d_fwd_bias_relu_xdl_large_tensor_ndhwgc_gkzyxc_ndhwgk_bf16_instances( + std::vector, + NDHWGK, + BF16, + BF16, + Tuple, + BF16, + PassThrough, + PassThrough, + AddRelu>>>& instances); + +void add_device_grouped_conv3d_fwd_bias_relu_xdl_merged_groups_ndhwgc_gkzyxc_ndhwgk_bf16_instances( + std::vector, + NDHWGK, + BF16, + BF16, + Tuple, + BF16, + PassThrough, + PassThrough, + AddRelu>>>& instances); + +void add_device_grouped_conv3d_fwd_bias_relu_xdl_ndhwgc_gkzyxc_ndhwgk_bf16_comp_instances( + std::vector, + NDHWGK, + BF16, + BF16, + Tuple, + BF16, + PassThrough, + PassThrough, + AddRelu>>>& instances); + +void add_device_grouped_conv3d_fwd_bias_relu_xdl_ndhwgc_gkzyxc_ndhwgk_bf16_mem_intra_instances( + std::vector, + NDHWGK, + BF16, + BF16, + Tuple, + BF16, + PassThrough, + PassThrough, + AddRelu>>>& instances); + +void add_device_grouped_conv3d_fwd_bias_relu_xdl_ndhwgc_gkzyxc_ndhwgk_bf16_mem_inter_instances( + std::vector, + NDHWGK, + BF16, + BF16, + Tuple, + BF16, + PassThrough, + PassThrough, + AddRelu>>>& instances); + +#endif + +} // namespace instance +} // namespace device +} // namespace tensor_operation +} // namespace 
ck diff --git a/library/src/tensor_operation_instance/gpu/grouped_conv2d_fwd_bias_relu/CMakeLists.txt b/library/src/tensor_operation_instance/gpu/grouped_conv2d_fwd_bias_relu/CMakeLists.txt new file mode 100644 index 0000000000..98b0b1c4cb --- /dev/null +++ b/library/src/tensor_operation_instance/gpu/grouped_conv2d_fwd_bias_relu/CMakeLists.txt @@ -0,0 +1,16 @@ +# ONLY XDL_KERNELS +add_instance_library(device_grouped_conv2d_fwd_bias_relu_instance + xdl/device_grouped_conv2d_fwd_bias_relu_xdl_nhwgc_gkyxc_nhwgk_bf16_instance.cpp + xdl/device_grouped_conv2d_fwd_bias_relu_xdl_nhwgc_gkyxc_nhwgk_bf16_16x16_instance.cpp + + xdl/large_tensor/device_grouped_conv2d_fwd_bias_relu_xdl_large_tensor_nhwgc_gkyxc_nhwgk_bf16_instance.cpp + + xdl/merged_groups/device_grouped_conv2d_fwd_bias_relu_xdl_merged_groups_nhwgc_gkyxc_nhwgk_bf16_instance.cpp + + xdl/mem/device_grouped_conv2d_fwd_bias_relu_xdl_nhwgc_gkyxc_nhwgk_bf16_mem_intra_instance.cpp + xdl/mem/device_grouped_conv2d_fwd_bias_relu_xdl_nhwgc_gkyxc_nhwgk_bf16_mem_inter_instance.cpp + + xdl/comp/device_grouped_conv2d_fwd_bias_relu_xdl_nhwgc_gkyxc_nhwgk_bf16_comp_instance.cpp + xdl/comp/device_grouped_conv2d_fwd_bias_relu_xdl_nhwgc_gkyxc_nhwgk_bf16_comp_2x_instance.cpp + xdl/comp/device_grouped_conv2d_fwd_bias_relu_xdl_nhwgc_gkyxc_nhwgk_bf16_comp_part2_instance.cpp +) diff --git a/library/src/tensor_operation_instance/gpu/grouped_conv2d_fwd_bias_relu/xdl/comp/device_grouped_conv2d_fwd_bias_relu_xdl_nhwgc_gkyxc_nhwgk_bf16_comp_2x_instance.cpp b/library/src/tensor_operation_instance/gpu/grouped_conv2d_fwd_bias_relu/xdl/comp/device_grouped_conv2d_fwd_bias_relu_xdl_nhwgc_gkyxc_nhwgk_bf16_comp_2x_instance.cpp new file mode 100644 index 0000000000..75acd604ee --- /dev/null +++ b/library/src/tensor_operation_instance/gpu/grouped_conv2d_fwd_bias_relu/xdl/comp/device_grouped_conv2d_fwd_bias_relu_xdl_nhwgc_gkyxc_nhwgk_bf16_comp_2x_instance.cpp @@ -0,0 +1,67 @@ +// SPDX-License-Identifier: MIT +// Copyright (c) 2025, Advanced Micro 
Devices, Inc. All rights reserved. + +#include "ck/library/tensor_operation_instance/add_device_operation_instance.hpp" +#include "ck/library/tensor_operation_instance/gpu/grouped_conv_fwd/device_grouped_conv_fwd_xdl_comp_instance.hpp" +#include "ck/host_utility/device_prop.hpp" + +namespace ck { +namespace tensor_operation { +namespace device { +namespace instance { +// Compilation parameters for in[n, hi, wi, g, c] * wei[g, k, y, x, c] = out[n, ho, wo, g, k] +void add_device_grouped_conv2d_fwd_bias_relu_xdl_nhwgc_gkyxc_nhwgk_bf16_comp_2x_instances( + std::vector, + NHWGK, + BF16, + BF16, + Tuple, + BF16, + PassThrough, + PassThrough, + AddRelu>>>& instances) +{ + if(ck::get_device_name() == "gfx950") + { + add_device_operation_instances( + instances, + device_grouped_conv_fwd_xdl_bf16_comp_instances_2x<2, + NHWGC, + GKYXC, + Tuple, + NHWGK, + ConvFwdDefault, + Tuple, + AddRelu>{}); + + add_device_operation_instances( + instances, + device_grouped_conv_fwd_xdl_bf16_comp_instances_2x<2, + NHWGC, + GKYXC, + Tuple, + NHWGK, + ConvFwd1x1P0, + Tuple, + AddRelu>{}); + + add_device_operation_instances( + instances, + device_grouped_conv_fwd_xdl_bf16_comp_instances_2x<2, + NHWGC, + GKYXC, + Tuple, + NHWGK, + ConvFwd1x1S1P0, + Tuple, + AddRelu>{}); + } +} + +} // namespace instance +} // namespace device +} // namespace tensor_operation +} // namespace ck diff --git a/library/src/tensor_operation_instance/gpu/grouped_conv2d_fwd_bias_relu/xdl/comp/device_grouped_conv2d_fwd_bias_relu_xdl_nhwgc_gkyxc_nhwgk_bf16_comp_instance.cpp b/library/src/tensor_operation_instance/gpu/grouped_conv2d_fwd_bias_relu/xdl/comp/device_grouped_conv2d_fwd_bias_relu_xdl_nhwgc_gkyxc_nhwgk_bf16_comp_instance.cpp new file mode 100644 index 0000000000..69a8a4bd9d --- /dev/null +++ b/library/src/tensor_operation_instance/gpu/grouped_conv2d_fwd_bias_relu/xdl/comp/device_grouped_conv2d_fwd_bias_relu_xdl_nhwgc_gkyxc_nhwgk_bf16_comp_instance.cpp @@ -0,0 +1,61 @@ +// SPDX-License-Identifier: MIT +// 
Copyright (c) 2025, Advanced Micro Devices, Inc. All rights reserved. + +#include "ck/library/tensor_operation_instance/add_device_operation_instance.hpp" +#include "ck/library/tensor_operation_instance/gpu/grouped_conv_fwd/device_grouped_conv_fwd_xdl_comp_instance.hpp" +#include "ck/host_utility/device_prop.hpp" + +namespace ck { +namespace tensor_operation { +namespace device { +namespace instance { +// Compilation parameters for in[n, hi, wi, g, c] * wei[g, k, y, x, c] = out[n, ho, wo, g, k] +void add_device_grouped_conv2d_fwd_bias_relu_xdl_nhwgc_gkyxc_nhwgk_bf16_comp_instances( + std::vector, + NHWGK, + BF16, + BF16, + Tuple, + BF16, + PassThrough, + PassThrough, + AddRelu>>>& instances) +{ + add_device_operation_instances(instances, + device_grouped_conv_fwd_xdl_bf16_comp_instances<2, + NHWGC, + GKYXC, + Tuple, + NHWGK, + ConvFwdDefault, + Tuple, + AddRelu>{}); + + add_device_operation_instances(instances, + device_grouped_conv_fwd_xdl_bf16_comp_instances<2, + NHWGC, + GKYXC, + Tuple, + NHWGK, + ConvFwd1x1P0, + Tuple, + AddRelu>{}); + + add_device_operation_instances(instances, + device_grouped_conv_fwd_xdl_bf16_comp_instances<2, + NHWGC, + GKYXC, + Tuple, + NHWGK, + ConvFwd1x1S1P0, + Tuple, + AddRelu>{}); +} + +} // namespace instance +} // namespace device +} // namespace tensor_operation +} // namespace ck diff --git a/library/src/tensor_operation_instance/gpu/grouped_conv2d_fwd_bias_relu/xdl/comp/device_grouped_conv2d_fwd_bias_relu_xdl_nhwgc_gkyxc_nhwgk_bf16_comp_part2_instance.cpp b/library/src/tensor_operation_instance/gpu/grouped_conv2d_fwd_bias_relu/xdl/comp/device_grouped_conv2d_fwd_bias_relu_xdl_nhwgc_gkyxc_nhwgk_bf16_comp_part2_instance.cpp new file mode 100644 index 0000000000..043c724e4a --- /dev/null +++ b/library/src/tensor_operation_instance/gpu/grouped_conv2d_fwd_bias_relu/xdl/comp/device_grouped_conv2d_fwd_bias_relu_xdl_nhwgc_gkyxc_nhwgk_bf16_comp_part2_instance.cpp @@ -0,0 +1,67 @@ +// SPDX-License-Identifier: MIT +// Copyright (c) 2025, 
Advanced Micro Devices, Inc. All rights reserved. + +#include "ck/library/tensor_operation_instance/add_device_operation_instance.hpp" +#include "ck/library/tensor_operation_instance/gpu/grouped_conv_fwd/device_grouped_conv_fwd_xdl_comp_instance.hpp" +#include "ck/host_utility/device_prop.hpp" + +namespace ck { +namespace tensor_operation { +namespace device { +namespace instance { +// Compilation parameters for in[n, hi, wi, g, c] * wei[g, k, y, x, c] = out[n, ho, wo, g, k] +void add_device_grouped_conv2d_fwd_bias_relu_xdl_nhwgc_gkyxc_nhwgk_bf16_comp_part2_instances( + std::vector, + NHWGK, + BF16, + BF16, + Tuple, + BF16, + PassThrough, + PassThrough, + AddRelu>>>& instances) +{ + if(ck::get_device_name() != "gfx950") + { + add_device_operation_instances( + instances, + device_grouped_conv_fwd_xdl_bf16_comp_instances_part2<2, + NHWGC, + GKYXC, + Tuple, + NHWGK, + ConvFwdDefault, + Tuple, + AddRelu>{}); + + add_device_operation_instances( + instances, + device_grouped_conv_fwd_xdl_bf16_comp_instances_part2<2, + NHWGC, + GKYXC, + Tuple, + NHWGK, + ConvFwd1x1P0, + Tuple, + AddRelu>{}); + + add_device_operation_instances( + instances, + device_grouped_conv_fwd_xdl_bf16_comp_instances_part2<2, + NHWGC, + GKYXC, + Tuple, + NHWGK, + ConvFwd1x1S1P0, + Tuple, + AddRelu>{}); + } +} + +} // namespace instance +} // namespace device +} // namespace tensor_operation +} // namespace ck diff --git a/library/src/tensor_operation_instance/gpu/grouped_conv2d_fwd_bias_relu/xdl/device_grouped_conv2d_fwd_bias_relu_xdl_nhwgc_gkyxc_nhwgk_bf16_16x16_instance.cpp b/library/src/tensor_operation_instance/gpu/grouped_conv2d_fwd_bias_relu/xdl/device_grouped_conv2d_fwd_bias_relu_xdl_nhwgc_gkyxc_nhwgk_bf16_16x16_instance.cpp new file mode 100644 index 0000000000..c58631e169 --- /dev/null +++ b/library/src/tensor_operation_instance/gpu/grouped_conv2d_fwd_bias_relu/xdl/device_grouped_conv2d_fwd_bias_relu_xdl_nhwgc_gkyxc_nhwgk_bf16_16x16_instance.cpp @@ -0,0 +1,60 @@ +// SPDX-License-Identifier: 
MIT +// Copyright (c) 2025, Advanced Micro Devices, Inc. All rights reserved. + +#include "ck/library/tensor_operation_instance/add_device_operation_instance.hpp" +#include "ck/library/tensor_operation_instance/gpu/grouped_conv_fwd/device_grouped_conv_fwd_xdl_instance.hpp" + +namespace ck { +namespace tensor_operation { +namespace device { +namespace instance { +// Compilation parameters for in[n, hi, wi, g, c] * wei[g, k, y, x, c] = out[n, ho, wo, g, k] +void add_device_grouped_conv2d_fwd_bias_relu_xdl_nhwgc_gkyxc_nhwgk_bf16_16x16_instances( + std::vector, + NHWGK, + BF16, + BF16, + Tuple, + BF16, + PassThrough, + PassThrough, + AddRelu>>>& instances) +{ + add_device_operation_instances(instances, + device_grouped_conv_fwd_xdl_bf16_16x16_instances<2, + NHWGC, + GKYXC, + Tuple, + NHWGK, + ConvFwdDefault, + Tuple, + AddRelu>{}); + + add_device_operation_instances(instances, + device_grouped_conv_fwd_xdl_bf16_16x16_instances<2, + NHWGC, + GKYXC, + Tuple, + NHWGK, + ConvFwd1x1P0, + Tuple, + AddRelu>{}); + + add_device_operation_instances(instances, + device_grouped_conv_fwd_xdl_bf16_16x16_instances<2, + NHWGC, + GKYXC, + Tuple, + NHWGK, + ConvFwd1x1S1P0, + Tuple, + AddRelu>{}); +} + +} // namespace instance +} // namespace device +} // namespace tensor_operation +} // namespace ck diff --git a/library/src/tensor_operation_instance/gpu/grouped_conv2d_fwd_bias_relu/xdl/device_grouped_conv2d_fwd_bias_relu_xdl_nhwgc_gkyxc_nhwgk_bf16_instance.cpp b/library/src/tensor_operation_instance/gpu/grouped_conv2d_fwd_bias_relu/xdl/device_grouped_conv2d_fwd_bias_relu_xdl_nhwgc_gkyxc_nhwgk_bf16_instance.cpp new file mode 100644 index 0000000000..cd80f2875f --- /dev/null +++ b/library/src/tensor_operation_instance/gpu/grouped_conv2d_fwd_bias_relu/xdl/device_grouped_conv2d_fwd_bias_relu_xdl_nhwgc_gkyxc_nhwgk_bf16_instance.cpp @@ -0,0 +1,60 @@ +// SPDX-License-Identifier: MIT +// Copyright (c) 2025, Advanced Micro Devices, Inc. All rights reserved. 
+ +#include "ck/library/tensor_operation_instance/add_device_operation_instance.hpp" +#include "ck/library/tensor_operation_instance/gpu/grouped_conv_fwd/device_grouped_conv_fwd_xdl_instance.hpp" + +namespace ck { +namespace tensor_operation { +namespace device { +namespace instance { +// Compilation parameters for in[n, hi, wi, g, c] * wei[g, k, y, x, c] = out[n, ho, wo, g, k] +void add_device_grouped_conv2d_fwd_bias_relu_xdl_nhwgc_gkyxc_nhwgk_bf16_instances( + std::vector, + NHWGK, + BF16, + BF16, + Tuple, + BF16, + PassThrough, + PassThrough, + AddRelu>>>& instances) +{ + add_device_operation_instances(instances, + device_grouped_conv_fwd_xdl_bf16_instances<2, + NHWGC, + GKYXC, + Tuple, + NHWGK, + ConvFwdDefault, + Tuple, + AddRelu>{}); + + add_device_operation_instances(instances, + device_grouped_conv_fwd_xdl_bf16_instances<2, + NHWGC, + GKYXC, + Tuple, + NHWGK, + ConvFwd1x1P0, + Tuple, + AddRelu>{}); + + add_device_operation_instances(instances, + device_grouped_conv_fwd_xdl_bf16_instances<2, + NHWGC, + GKYXC, + Tuple, + NHWGK, + ConvFwd1x1S1P0, + Tuple, + AddRelu>{}); +} + +} // namespace instance +} // namespace device +} // namespace tensor_operation +} // namespace ck diff --git a/library/src/tensor_operation_instance/gpu/grouped_conv2d_fwd_bias_relu/xdl/large_tensor/device_grouped_conv2d_fwd_bias_relu_xdl_large_tensor_nhwgc_gkyxc_nhwgk_bf16_instance.cpp b/library/src/tensor_operation_instance/gpu/grouped_conv2d_fwd_bias_relu/xdl/large_tensor/device_grouped_conv2d_fwd_bias_relu_xdl_large_tensor_nhwgc_gkyxc_nhwgk_bf16_instance.cpp new file mode 100644 index 0000000000..a6286b55e8 --- /dev/null +++ b/library/src/tensor_operation_instance/gpu/grouped_conv2d_fwd_bias_relu/xdl/large_tensor/device_grouped_conv2d_fwd_bias_relu_xdl_large_tensor_nhwgc_gkyxc_nhwgk_bf16_instance.cpp @@ -0,0 +1,41 @@ +// SPDX-License-Identifier: MIT +// Copyright (c) 2025, Advanced Micro Devices, Inc. All rights reserved. 
+ +#include "ck/library/tensor_operation_instance/add_device_operation_instance.hpp" +#include "ck/library/tensor_operation_instance/gpu/grouped_conv_fwd/device_grouped_conv_fwd_xdl_large_tensor_instance.hpp" + +namespace ck { +namespace tensor_operation { +namespace device { +namespace instance { +// Compilation parameters for in[n, hi, wi, g, c] * wei[g, k, y, x, c] = out[n, ho, wo, g, k] +void add_device_grouped_conv2d_fwd_bias_relu_xdl_large_tensor_nhwgc_gkyxc_nhwgk_bf16_instances( + std::vector, + NHWGK, + BF16, + BF16, + Tuple, + BF16, + PassThrough, + PassThrough, + AddRelu>>>& instances) +{ + add_device_operation_instances( + instances, + device_grouped_conv_fwd_xdl_large_tensor_bf16_instances<2, + NHWGC, + GKYXC, + Tuple, + NHWGK, + ConvFwdDefault, + Tuple, + AddRelu>{}); +} + +} // namespace instance +} // namespace device +} // namespace tensor_operation +} // namespace ck diff --git a/library/src/tensor_operation_instance/gpu/grouped_conv2d_fwd_bias_relu/xdl/mem/device_grouped_conv2d_fwd_bias_relu_xdl_nhwgc_gkyxc_nhwgk_bf16_mem_inter_instance.cpp b/library/src/tensor_operation_instance/gpu/grouped_conv2d_fwd_bias_relu/xdl/mem/device_grouped_conv2d_fwd_bias_relu_xdl_nhwgc_gkyxc_nhwgk_bf16_mem_inter_instance.cpp new file mode 100644 index 0000000000..0736325b05 --- /dev/null +++ b/library/src/tensor_operation_instance/gpu/grouped_conv2d_fwd_bias_relu/xdl/mem/device_grouped_conv2d_fwd_bias_relu_xdl_nhwgc_gkyxc_nhwgk_bf16_mem_inter_instance.cpp @@ -0,0 +1,63 @@ +// SPDX-License-Identifier: MIT +// Copyright (c) 2025, Advanced Micro Devices, Inc. All rights reserved. 
+ +#include "ck/library/tensor_operation_instance/add_device_operation_instance.hpp" +#include "ck/library/tensor_operation_instance/gpu/grouped_conv_fwd/device_grouped_conv_fwd_xdl_mem_instance.hpp" + +namespace ck { +namespace tensor_operation { +namespace device { +namespace instance { +// Compilation parameters for in[n, hi, wi, g, c] * wei[g, k, y, x, c] = out[n, ho, wo, g, k] +void add_device_grouped_conv2d_fwd_bias_relu_xdl_nhwgc_gkyxc_nhwgk_bf16_mem_inter_instances( + std::vector, + NHWGK, + BF16, + BF16, + Tuple, + BF16, + PassThrough, + PassThrough, + AddRelu>>>& instances) +{ + add_device_operation_instances(instances, + device_grouped_conv_fwd_xdl_bf16_mem_instances<2, + NHWGC, + GKYXC, + Tuple, + NHWGK, + ConvFwdDefault, + Interwave, + Tuple, + AddRelu>{}); + + add_device_operation_instances(instances, + device_grouped_conv_fwd_xdl_bf16_mem_instances<2, + NHWGC, + GKYXC, + Tuple, + NHWGK, + ConvFwd1x1P0, + Interwave, + Tuple, + AddRelu>{}); + + add_device_operation_instances(instances, + device_grouped_conv_fwd_xdl_bf16_mem_instances<2, + NHWGC, + GKYXC, + Tuple, + NHWGK, + ConvFwd1x1S1P0, + Interwave, + Tuple, + AddRelu>{}); +} + +} // namespace instance +} // namespace device +} // namespace tensor_operation +} // namespace ck diff --git a/library/src/tensor_operation_instance/gpu/grouped_conv2d_fwd_bias_relu/xdl/mem/device_grouped_conv2d_fwd_bias_relu_xdl_nhwgc_gkyxc_nhwgk_bf16_mem_intra_instance.cpp b/library/src/tensor_operation_instance/gpu/grouped_conv2d_fwd_bias_relu/xdl/mem/device_grouped_conv2d_fwd_bias_relu_xdl_nhwgc_gkyxc_nhwgk_bf16_mem_intra_instance.cpp new file mode 100644 index 0000000000..0d35ab1b05 --- /dev/null +++ b/library/src/tensor_operation_instance/gpu/grouped_conv2d_fwd_bias_relu/xdl/mem/device_grouped_conv2d_fwd_bias_relu_xdl_nhwgc_gkyxc_nhwgk_bf16_mem_intra_instance.cpp @@ -0,0 +1,63 @@ +// SPDX-License-Identifier: MIT +// Copyright (c) 2025, Advanced Micro Devices, Inc. All rights reserved. 
+ +#include "ck/library/tensor_operation_instance/add_device_operation_instance.hpp" +#include "ck/library/tensor_operation_instance/gpu/grouped_conv_fwd/device_grouped_conv_fwd_xdl_mem_instance.hpp" + +namespace ck { +namespace tensor_operation { +namespace device { +namespace instance { +// Compilation parameters for in[n, hi, wi, g, c] * wei[g, k, y, x, c] = out[n, ho, wo, g, k] +void add_device_grouped_conv2d_fwd_bias_relu_xdl_nhwgc_gkyxc_nhwgk_bf16_mem_intra_instances( + std::vector, + NHWGK, + BF16, + BF16, + Tuple, + BF16, + PassThrough, + PassThrough, + AddRelu>>>& instances) +{ + add_device_operation_instances(instances, + device_grouped_conv_fwd_xdl_bf16_mem_instances<2, + NHWGC, + GKYXC, + Tuple, + NHWGK, + ConvFwdDefault, + Intrawave, + Tuple, + AddRelu>{}); + + add_device_operation_instances(instances, + device_grouped_conv_fwd_xdl_bf16_mem_instances<2, + NHWGC, + GKYXC, + Tuple, + NHWGK, + ConvFwd1x1P0, + Intrawave, + Tuple, + AddRelu>{}); + + add_device_operation_instances(instances, + device_grouped_conv_fwd_xdl_bf16_mem_instances<2, + NHWGC, + GKYXC, + Tuple, + NHWGK, + ConvFwd1x1S1P0, + Intrawave, + Tuple, + AddRelu>{}); +} + +} // namespace instance +} // namespace device +} // namespace tensor_operation +} // namespace ck diff --git a/library/src/tensor_operation_instance/gpu/grouped_conv2d_fwd_bias_relu/xdl/merged_groups/device_grouped_conv2d_fwd_bias_relu_xdl_merged_groups_nhwgc_gkyxc_nhwgk_bf16_instance.cpp b/library/src/tensor_operation_instance/gpu/grouped_conv2d_fwd_bias_relu/xdl/merged_groups/device_grouped_conv2d_fwd_bias_relu_xdl_merged_groups_nhwgc_gkyxc_nhwgk_bf16_instance.cpp new file mode 100644 index 0000000000..253e8b196e --- /dev/null +++ b/library/src/tensor_operation_instance/gpu/grouped_conv2d_fwd_bias_relu/xdl/merged_groups/device_grouped_conv2d_fwd_bias_relu_xdl_merged_groups_nhwgc_gkyxc_nhwgk_bf16_instance.cpp @@ -0,0 +1,80 @@ +// SPDX-License-Identifier: MIT +// Copyright (c) 2025, Advanced Micro Devices, Inc. 
All rights reserved. + +#include "ck/library/tensor_operation_instance/add_device_operation_instance.hpp" +#include "ck/library/tensor_operation_instance/gpu/grouped_conv_fwd/device_grouped_conv_fwd_xdl_merged_groups_instance.hpp" +#include "ck/host_utility/device_prop.hpp" + +namespace ck { +namespace tensor_operation { +namespace device { +namespace instance { +// Compilation parameters for in[n, hi, wi, g, c] * wei[g, k, y, x, c] = out[n, ho, wo, g, k] +void add_device_grouped_conv2d_fwd_bias_relu_xdl_merged_groups_nhwgc_gkyxc_nhwgk_bf16_instances( + std::vector, + NHWGK, + BF16, + BF16, + Tuple, + BF16, + PassThrough, + PassThrough, + AddRelu>>>& instances) +{ + if(ck::get_device_name() == "gfx950") + { + add_device_operation_instances( + instances, + device_grouped_conv_fwd_xdl_merged_groups_bf16_instances_2x<2, + NHWGC, + GKYXC, + Tuple, + NHWGK, + ConvFwdDefault, + Tuple, + AddRelu>{}); + + add_device_operation_instances( + instances, + device_grouped_conv_fwd_xdl_merged_groups_bf16_instances_2x<2, + NHWGC, + GKYXC, + Tuple, + NHWGK, + ConvFwd3x3, + Tuple, + AddRelu>{}); + } + else + { + add_device_operation_instances( + instances, + device_grouped_conv_fwd_xdl_merged_groups_bf16_instances<2, + NHWGC, + GKYXC, + Tuple, + NHWGK, + ConvFwdDefault, + Tuple, + AddRelu>{}); + + add_device_operation_instances( + instances, + device_grouped_conv_fwd_xdl_merged_groups_bf16_instances<2, + NHWGC, + GKYXC, + Tuple, + NHWGK, + ConvFwd3x3, + Tuple, + AddRelu>{}); + } +} + +} // namespace instance +} // namespace device +} // namespace tensor_operation +} // namespace ck diff --git a/library/src/tensor_operation_instance/gpu/grouped_conv3d_fwd_bias_relu/CMakeLists.txt b/library/src/tensor_operation_instance/gpu/grouped_conv3d_fwd_bias_relu/CMakeLists.txt new file mode 100644 index 0000000000..afdddfec70 --- /dev/null +++ b/library/src/tensor_operation_instance/gpu/grouped_conv3d_fwd_bias_relu/CMakeLists.txt @@ -0,0 +1,16 @@ +# ONLY XDL_KERNELS +set(GROUPED_CONV3D_FWD + 
xdl/device_grouped_conv3d_fwd_bias_relu_xdl_ndhwgc_gkzyxc_ndhwgk_bf16_instance.cpp + xdl/device_grouped_conv3d_fwd_bias_relu_xdl_ndhwgc_gkzyxc_ndhwgk_bf16_16x16_instance.cpp + + xdl/large_tensor/device_grouped_conv3d_fwd_bias_relu_xdl_large_tensor_ndhwgc_gkzyxc_ndhwgk_bf16_instance.cpp + + xdl/merged_groups/device_grouped_conv3d_fwd_bias_relu_xdl_merged_groups_ndhwgc_gkzyxc_ndhwgk_bf16_instance.cpp + + xdl/mem/device_grouped_conv3d_fwd_bias_relu_xdl_ndhwgc_gkzyxc_ndhwgk_bf16_mem_inter_instance.cpp + xdl/mem/device_grouped_conv3d_fwd_bias_relu_xdl_ndhwgc_gkzyxc_ndhwgk_bf16_mem_intra_instance.cpp + + xdl/comp/device_grouped_conv3d_fwd_bias_relu_xdl_ndhwgc_gkzyxc_ndhwgk_bf16_comp_instance.cpp +) + +add_instance_library(device_grouped_conv3d_fwd_bias_relu_instance ${GROUPED_CONV3D_FWD}) diff --git a/library/src/tensor_operation_instance/gpu/grouped_conv3d_fwd_bias_relu/xdl/comp/device_grouped_conv3d_fwd_bias_relu_xdl_ndhwgc_gkzyxc_ndhwgk_bf16_comp_instance.cpp b/library/src/tensor_operation_instance/gpu/grouped_conv3d_fwd_bias_relu/xdl/comp/device_grouped_conv3d_fwd_bias_relu_xdl_ndhwgc_gkzyxc_ndhwgk_bf16_comp_instance.cpp new file mode 100644 index 0000000000..9819f0ea0b --- /dev/null +++ b/library/src/tensor_operation_instance/gpu/grouped_conv3d_fwd_bias_relu/xdl/comp/device_grouped_conv3d_fwd_bias_relu_xdl_ndhwgc_gkzyxc_ndhwgk_bf16_comp_instance.cpp @@ -0,0 +1,127 @@ +// SPDX-License-Identifier: MIT +// Copyright (c) 2025, Advanced Micro Devices, Inc. All rights reserved. 
+ +#include "ck/library/tensor_operation_instance/gpu/grouped_conv_fwd/device_grouped_conv_fwd_xdl_comp_instance.hpp" +#include "ck/library/tensor_operation_instance/add_device_operation_instance.hpp" +#include "ck/host_utility/device_prop.hpp" + +namespace ck { +namespace tensor_operation { +namespace device { +namespace instance { + +void add_device_grouped_conv3d_fwd_bias_relu_xdl_ndhwgc_gkzyxc_ndhwgk_bf16_comp_instances( + std::vector, + NDHWGK, + BF16, + BF16, + Tuple, + BF16, + PassThrough, + PassThrough, + AddRelu>>>& instances) +{ + add_device_operation_instances(instances, + device_grouped_conv_fwd_xdl_bf16_comp_instances<3, + NDHWGC, + GKZYXC, + Tuple, + NDHWGK, + ConvFwdDefault, + Tuple, + AddRelu>{}); + add_device_operation_instances(instances, + device_grouped_conv_fwd_xdl_bf16_comp_instances<3, + NDHWGC, + GKZYXC, + Tuple, + NDHWGK, + ConvFwd1x1P0, + Tuple, + AddRelu>{}); + add_device_operation_instances(instances, + device_grouped_conv_fwd_xdl_bf16_comp_instances<3, + NDHWGC, + GKZYXC, + Tuple, + NDHWGK, + ConvFwd1x1S1P0, + Tuple, + AddRelu>{}); + + if(ck::get_device_name() != "gfx950") + { + add_device_operation_instances( + instances, + device_grouped_conv_fwd_xdl_bf16_comp_instances_part2<3, + NDHWGC, + GKZYXC, + Tuple, + NDHWGK, + ConvFwdDefault, + Tuple, + AddRelu>{}); + add_device_operation_instances( + instances, + device_grouped_conv_fwd_xdl_bf16_comp_instances_part2<3, + NDHWGC, + GKZYXC, + Tuple, + NDHWGK, + ConvFwd1x1P0, + Tuple, + AddRelu>{}); + add_device_operation_instances( + instances, + device_grouped_conv_fwd_xdl_bf16_comp_instances_part2<3, + NDHWGC, + GKZYXC, + Tuple, + NDHWGK, + ConvFwd1x1S1P0, + Tuple, + AddRelu>{}); + } + + if(ck::get_device_name() == "gfx950") + { + add_device_operation_instances( + instances, + device_grouped_conv_fwd_xdl_bf16_comp_instances_2x<3, + NDHWGC, + GKZYXC, + Tuple, + NDHWGK, + ConvFwdDefault, + Tuple, + AddRelu>{}); + add_device_operation_instances( + instances, + 
device_grouped_conv_fwd_xdl_bf16_comp_instances_2x<3, + NDHWGC, + GKZYXC, + Tuple, + NDHWGK, + ConvFwd1x1P0, + Tuple, + AddRelu>{}); + add_device_operation_instances( + instances, + device_grouped_conv_fwd_xdl_bf16_comp_instances_2x<3, + NDHWGC, + GKZYXC, + Tuple, + NDHWGK, + ConvFwd1x1S1P0, + Tuple, + AddRelu>{}); + } +} + +} // namespace instance +} // namespace device +} // namespace tensor_operation +} // namespace ck diff --git a/library/src/tensor_operation_instance/gpu/grouped_conv3d_fwd_bias_relu/xdl/device_grouped_conv3d_fwd_bias_relu_xdl_ndhwgc_gkzyxc_ndhwgk_bf16_16x16_instance.cpp b/library/src/tensor_operation_instance/gpu/grouped_conv3d_fwd_bias_relu/xdl/device_grouped_conv3d_fwd_bias_relu_xdl_ndhwgc_gkzyxc_ndhwgk_bf16_16x16_instance.cpp new file mode 100644 index 0000000000..dc3fc7a4bf --- /dev/null +++ b/library/src/tensor_operation_instance/gpu/grouped_conv3d_fwd_bias_relu/xdl/device_grouped_conv3d_fwd_bias_relu_xdl_ndhwgc_gkzyxc_ndhwgk_bf16_16x16_instance.cpp @@ -0,0 +1,58 @@ +// SPDX-License-Identifier: MIT +// Copyright (c) 2025, Advanced Micro Devices, Inc. All rights reserved. 
+ +#include "ck/library/tensor_operation_instance/gpu/grouped_conv_fwd/device_grouped_conv_fwd_xdl_instance.hpp" +#include "ck/library/tensor_operation_instance/add_device_operation_instance.hpp" + +namespace ck { +namespace tensor_operation { +namespace device { +namespace instance { + +void add_device_grouped_conv3d_fwd_bias_relu_xdl_ndhwgc_gkzyxc_ndhwgk_bf16_16x16_instances( + std::vector, + NDHWGK, + BF16, + BF16, + Tuple, + BF16, + PassThrough, + PassThrough, + AddRelu>>>& instances) +{ + add_device_operation_instances(instances, + device_grouped_conv_fwd_xdl_bf16_16x16_instances<3, + NDHWGC, + GKZYXC, + Tuple, + NDHWGK, + ConvFwdDefault, + Tuple, + AddRelu>{}); + add_device_operation_instances(instances, + device_grouped_conv_fwd_xdl_bf16_16x16_instances<3, + NDHWGC, + GKZYXC, + Tuple, + NDHWGK, + ConvFwd1x1P0, + Tuple, + AddRelu>{}); + add_device_operation_instances(instances, + device_grouped_conv_fwd_xdl_bf16_16x16_instances<3, + NDHWGC, + GKZYXC, + Tuple, + NDHWGK, + ConvFwd1x1S1P0, + Tuple, + AddRelu>{}); +} + +} // namespace instance +} // namespace device +} // namespace tensor_operation +} // namespace ck diff --git a/library/src/tensor_operation_instance/gpu/grouped_conv3d_fwd_bias_relu/xdl/device_grouped_conv3d_fwd_bias_relu_xdl_ndhwgc_gkzyxc_ndhwgk_bf16_instance.cpp b/library/src/tensor_operation_instance/gpu/grouped_conv3d_fwd_bias_relu/xdl/device_grouped_conv3d_fwd_bias_relu_xdl_ndhwgc_gkzyxc_ndhwgk_bf16_instance.cpp new file mode 100644 index 0000000000..a9a8ff8459 --- /dev/null +++ b/library/src/tensor_operation_instance/gpu/grouped_conv3d_fwd_bias_relu/xdl/device_grouped_conv3d_fwd_bias_relu_xdl_ndhwgc_gkzyxc_ndhwgk_bf16_instance.cpp @@ -0,0 +1,58 @@ +// SPDX-License-Identifier: MIT +// Copyright (c) 2025, Advanced Micro Devices, Inc. All rights reserved. 
+ +#include "ck/library/tensor_operation_instance/gpu/grouped_conv_fwd/device_grouped_conv_fwd_xdl_instance.hpp" +#include "ck/library/tensor_operation_instance/add_device_operation_instance.hpp" + +namespace ck { +namespace tensor_operation { +namespace device { +namespace instance { + +void add_device_grouped_conv3d_fwd_bias_relu_xdl_ndhwgc_gkzyxc_ndhwgk_bf16_instances( + std::vector, + NDHWGK, + BF16, + BF16, + Tuple, + BF16, + PassThrough, + PassThrough, + AddRelu>>>& instances) +{ + add_device_operation_instances(instances, + device_grouped_conv_fwd_xdl_bf16_instances<3, + NDHWGC, + GKZYXC, + Tuple, + NDHWGK, + ConvFwdDefault, + Tuple, + AddRelu>{}); + add_device_operation_instances(instances, + device_grouped_conv_fwd_xdl_bf16_instances<3, + NDHWGC, + GKZYXC, + Tuple, + NDHWGK, + ConvFwd1x1P0, + Tuple, + AddRelu>{}); + add_device_operation_instances(instances, + device_grouped_conv_fwd_xdl_bf16_instances<3, + NDHWGC, + GKZYXC, + Tuple, + NDHWGK, + ConvFwd1x1S1P0, + Tuple, + AddRelu>{}); +} + +} // namespace instance +} // namespace device +} // namespace tensor_operation +} // namespace ck diff --git a/library/src/tensor_operation_instance/gpu/grouped_conv3d_fwd_bias_relu/xdl/large_tensor/device_grouped_conv3d_fwd_bias_relu_xdl_large_tensor_ndhwgc_gkzyxc_ndhwgk_bf16_instance.cpp b/library/src/tensor_operation_instance/gpu/grouped_conv3d_fwd_bias_relu/xdl/large_tensor/device_grouped_conv3d_fwd_bias_relu_xdl_large_tensor_ndhwgc_gkzyxc_ndhwgk_bf16_instance.cpp new file mode 100644 index 0000000000..e58e879973 --- /dev/null +++ b/library/src/tensor_operation_instance/gpu/grouped_conv3d_fwd_bias_relu/xdl/large_tensor/device_grouped_conv3d_fwd_bias_relu_xdl_large_tensor_ndhwgc_gkzyxc_ndhwgk_bf16_instance.cpp @@ -0,0 +1,41 @@ +// SPDX-License-Identifier: MIT +// Copyright (c) 2025, Advanced Micro Devices, Inc. All rights reserved. 
+ +#include "ck/library/tensor_operation_instance/gpu/grouped_conv_fwd/device_grouped_conv_fwd_xdl_large_tensor_instance.hpp" +#include "ck/library/tensor_operation_instance/add_device_operation_instance.hpp" + +namespace ck { +namespace tensor_operation { +namespace device { +namespace instance { + +void add_device_grouped_conv3d_fwd_bias_relu_xdl_large_tensor_ndhwgc_gkzyxc_ndhwgk_bf16_instances( + std::vector, + NDHWGK, + BF16, + BF16, + Tuple, + BF16, + PassThrough, + PassThrough, + AddRelu>>>& instances) +{ + add_device_operation_instances( + instances, + device_grouped_conv_fwd_xdl_large_tensor_bf16_instances<3, + NDHWGC, + GKZYXC, + Tuple, + NDHWGK, + ConvFwdDefault, + Tuple, + AddRelu>{}); +} + +} // namespace instance +} // namespace device +} // namespace tensor_operation +} // namespace ck diff --git a/library/src/tensor_operation_instance/gpu/grouped_conv3d_fwd_bias_relu/xdl/mem/device_grouped_conv3d_fwd_bias_relu_xdl_ndhwgc_gkzyxc_ndhwgk_bf16_mem_inter_instance.cpp b/library/src/tensor_operation_instance/gpu/grouped_conv3d_fwd_bias_relu/xdl/mem/device_grouped_conv3d_fwd_bias_relu_xdl_ndhwgc_gkzyxc_ndhwgk_bf16_mem_inter_instance.cpp new file mode 100644 index 0000000000..e76052c6e0 --- /dev/null +++ b/library/src/tensor_operation_instance/gpu/grouped_conv3d_fwd_bias_relu/xdl/mem/device_grouped_conv3d_fwd_bias_relu_xdl_ndhwgc_gkzyxc_ndhwgk_bf16_mem_inter_instance.cpp @@ -0,0 +1,61 @@ +// SPDX-License-Identifier: MIT +// Copyright (c) 2025, Advanced Micro Devices, Inc. All rights reserved. 
+ +#include "ck/library/tensor_operation_instance/gpu/grouped_conv_fwd/device_grouped_conv_fwd_xdl_mem_instance.hpp" +#include "ck/library/tensor_operation_instance/add_device_operation_instance.hpp" + +namespace ck { +namespace tensor_operation { +namespace device { +namespace instance { + +void add_device_grouped_conv3d_fwd_bias_relu_xdl_ndhwgc_gkzyxc_ndhwgk_bf16_mem_inter_instances( + std::vector, + NDHWGK, + BF16, + BF16, + Tuple, + BF16, + PassThrough, + PassThrough, + AddRelu>>>& instances) +{ + add_device_operation_instances(instances, + device_grouped_conv_fwd_xdl_bf16_mem_instances<3, + NDHWGC, + GKZYXC, + Tuple, + NDHWGK, + ConvFwdDefault, + Interwave, + Tuple, + AddRelu>{}); + add_device_operation_instances(instances, + device_grouped_conv_fwd_xdl_bf16_mem_instances<3, + NDHWGC, + GKZYXC, + Tuple, + NDHWGK, + ConvFwd1x1P0, + Interwave, + Tuple, + AddRelu>{}); + add_device_operation_instances(instances, + device_grouped_conv_fwd_xdl_bf16_mem_instances<3, + NDHWGC, + GKZYXC, + Tuple, + NDHWGK, + ConvFwd1x1S1P0, + Interwave, + Tuple, + AddRelu>{}); +} + +} // namespace instance +} // namespace device +} // namespace tensor_operation +} // namespace ck diff --git a/library/src/tensor_operation_instance/gpu/grouped_conv3d_fwd_bias_relu/xdl/mem/device_grouped_conv3d_fwd_bias_relu_xdl_ndhwgc_gkzyxc_ndhwgk_bf16_mem_intra_instance.cpp b/library/src/tensor_operation_instance/gpu/grouped_conv3d_fwd_bias_relu/xdl/mem/device_grouped_conv3d_fwd_bias_relu_xdl_ndhwgc_gkzyxc_ndhwgk_bf16_mem_intra_instance.cpp new file mode 100644 index 0000000000..0593f3f46a --- /dev/null +++ b/library/src/tensor_operation_instance/gpu/grouped_conv3d_fwd_bias_relu/xdl/mem/device_grouped_conv3d_fwd_bias_relu_xdl_ndhwgc_gkzyxc_ndhwgk_bf16_mem_intra_instance.cpp @@ -0,0 +1,61 @@ +// SPDX-License-Identifier: MIT +// Copyright (c) 2025, Advanced Micro Devices, Inc. All rights reserved. 
+ +#include "ck/library/tensor_operation_instance/gpu/grouped_conv_fwd/device_grouped_conv_fwd_xdl_mem_instance.hpp" +#include "ck/library/tensor_operation_instance/add_device_operation_instance.hpp" + +namespace ck { +namespace tensor_operation { +namespace device { +namespace instance { + +void add_device_grouped_conv3d_fwd_bias_relu_xdl_ndhwgc_gkzyxc_ndhwgk_bf16_mem_intra_instances( + std::vector, + NDHWGK, + BF16, + BF16, + Tuple, + BF16, + PassThrough, + PassThrough, + AddRelu>>>& instances) +{ + add_device_operation_instances(instances, + device_grouped_conv_fwd_xdl_bf16_mem_instances<3, + NDHWGC, + GKZYXC, + Tuple, + NDHWGK, + ConvFwdDefault, + Intrawave, + Tuple, + AddRelu>{}); + add_device_operation_instances(instances, + device_grouped_conv_fwd_xdl_bf16_mem_instances<3, + NDHWGC, + GKZYXC, + Tuple, + NDHWGK, + ConvFwd1x1P0, + Intrawave, + Tuple, + AddRelu>{}); + add_device_operation_instances(instances, + device_grouped_conv_fwd_xdl_bf16_mem_instances<3, + NDHWGC, + GKZYXC, + Tuple, + NDHWGK, + ConvFwd1x1S1P0, + Intrawave, + Tuple, + AddRelu>{}); +} + +} // namespace instance +} // namespace device +} // namespace tensor_operation +} // namespace ck diff --git a/library/src/tensor_operation_instance/gpu/grouped_conv3d_fwd_bias_relu/xdl/merged_groups/device_grouped_conv3d_fwd_bias_relu_xdl_merged_groups_ndhwgc_gkzyxc_ndhwgk_bf16_instance.cpp b/library/src/tensor_operation_instance/gpu/grouped_conv3d_fwd_bias_relu/xdl/merged_groups/device_grouped_conv3d_fwd_bias_relu_xdl_merged_groups_ndhwgc_gkzyxc_ndhwgk_bf16_instance.cpp new file mode 100644 index 0000000000..6552f26f88 --- /dev/null +++ b/library/src/tensor_operation_instance/gpu/grouped_conv3d_fwd_bias_relu/xdl/merged_groups/device_grouped_conv3d_fwd_bias_relu_xdl_merged_groups_ndhwgc_gkzyxc_ndhwgk_bf16_instance.cpp @@ -0,0 +1,51 @@ +// SPDX-License-Identifier: MIT +// Copyright (c) 2025, Advanced Micro Devices, Inc. All rights reserved. 
+ +#include "ck/library/tensor_operation_instance/gpu/grouped_conv_fwd/device_grouped_conv_fwd_xdl_merged_groups_instance.hpp" +#include "ck/library/tensor_operation_instance/add_device_operation_instance.hpp" + +namespace ck { +namespace tensor_operation { +namespace device { +namespace instance { + +void add_device_grouped_conv3d_fwd_bias_relu_xdl_merged_groups_ndhwgc_gkzyxc_ndhwgk_bf16_instances( + std::vector, + NDHWGK, + BF16, + BF16, + Tuple, + BF16, + PassThrough, + PassThrough, + AddRelu>>>& instances) +{ + add_device_operation_instances( + instances, + device_grouped_conv_fwd_xdl_merged_groups_bf16_instances<3, + NDHWGC, + GKZYXC, + Tuple, + NDHWGK, + ConvFwdDefault, + Tuple, + AddRelu>{}); + add_device_operation_instances( + instances, + device_grouped_conv_fwd_xdl_merged_groups_bf16_instances<3, + NDHWGC, + GKZYXC, + Tuple, + NDHWGK, + ConvFwd3x3, + Tuple, + AddRelu>{}); +} + +} // namespace instance +} // namespace device +} // namespace tensor_operation +} // namespace ck diff --git a/profiler/include/profiler/profile_grouped_conv_fwd_bias_relu_impl.hpp b/profiler/include/profiler/profile_grouped_conv_fwd_bias_relu_impl.hpp new file mode 100644 index 0000000000..9d38263d4e --- /dev/null +++ b/profiler/include/profiler/profile_grouped_conv_fwd_bias_relu_impl.hpp @@ -0,0 +1,277 @@ +// SPDX-License-Identifier: MIT +// Copyright (c) 2025, Advanced Micro Devices, Inc. All rights reserved. 
+ +#pragma once + +#include +#include +#include + +#include "ck/ck.hpp" +#include "ck/tensor_operation/gpu/device/tensor_layout.hpp" +#include "ck/tensor_operation/gpu/element/element_wise_operation.hpp" + +#include "ck/library/tensor_operation_instance/gpu/grouped_convolution_forward_bias_relu.hpp" + +#include "ck/library/utility/algorithm.hpp" +#include "ck/library/utility/check_err.hpp" +#include "ck/library/utility/device_memory.hpp" +#include "ck/library/utility/host_tensor.hpp" +#include "ck/library/utility/host_tensor_generator.hpp" +#include "ck/library/utility/convolution_parameter.hpp" +#include "ck/library/utility/convolution_host_tensor_descriptor_helper.hpp" +#include "ck/library/reference_tensor_operation/cpu/reference_conv_fwd.hpp" + +namespace ck { +namespace profiler { + +template +bool profile_grouped_conv_fwd_bias_relu_impl(int do_verification, + int init_method, + bool do_log, + bool time_kernel, + const ck::utils::conv::ConvParam& conv_param) +{ + using InElementOp = ck::tensor_operation::element_wise::PassThrough; + using WeiElementOp = ck::tensor_operation::element_wise::PassThrough; + using OutElementOp = ck::tensor_operation::element_wise::AddRelu; + + const auto in_element_op = InElementOp{}; + const auto wei_element_op = WeiElementOp{}; + const auto out_element_op = OutElementOp{}; + + const auto in_g_n_c_wis_desc = + ck::utils::conv::make_input_host_tensor_descriptor_g_n_c_wis_packed(conv_param); + + const auto wei_g_k_c_xs_desc = + ck::utils::conv::make_weight_host_tensor_descriptor_g_k_c_xs_packed(conv_param); + + const auto out_g_n_k_wos_desc = + ck::utils::conv::make_output_host_tensor_descriptor_g_n_k_wos_packed(conv_param); + + std::array a_g_n_c_wis_lengths{}; + std::array a_g_n_c_wis_strides{}; + std::array b_g_k_c_xs_lengths{}; + std::array b_g_k_c_xs_strides{}; + std::array e_g_n_k_wos_lengths{}; + std::array e_g_n_k_wos_strides{}; + std::array conv_filter_strides{}; + std::array conv_filter_dilations{}; + std::array 
input_left_pads{}; + std::array input_right_pads{}; + + auto copy = [](const auto& x, auto& y) { ck::ranges::copy(x, y.begin()); }; + + copy(in_g_n_c_wis_desc.GetLengths(), a_g_n_c_wis_lengths); + copy(in_g_n_c_wis_desc.GetStrides(), a_g_n_c_wis_strides); + copy(wei_g_k_c_xs_desc.GetLengths(), b_g_k_c_xs_lengths); + copy(wei_g_k_c_xs_desc.GetStrides(), b_g_k_c_xs_strides); + copy(out_g_n_k_wos_desc.GetLengths(), e_g_n_k_wos_lengths); + copy(out_g_n_k_wos_desc.GetStrides(), e_g_n_k_wos_strides); + copy(conv_param.conv_filter_strides_, conv_filter_strides); + copy(conv_param.conv_filter_dilations_, conv_filter_dilations); + copy(conv_param.input_left_pads_, input_left_pads); + copy(conv_param.input_right_pads_, input_right_pads); + + Tensor input(in_g_n_c_wis_desc); + Tensor weight(wei_g_k_c_xs_desc); + Tensor host_output(out_g_n_k_wos_desc); + Tensor device_output(out_g_n_k_wos_desc); + Tensor bias(out_g_n_k_wos_desc); + + std::cout << "input: " << input.mDesc << std::endl; + std::cout << "weight: " << weight.mDesc << std::endl; + std::cout << "output: " << host_output.mDesc << std::endl; + std::cout << "bias: " << bias.mDesc << std::endl; + + switch(init_method) + { + case 0: break; + case 1: + input.GenerateTensorValue(GeneratorTensor_2{-5, 5}); + weight.GenerateTensorValue(GeneratorTensor_2{-5, 5}); + bias.GenerateTensorValue(GeneratorTensor_2{-5, 5}); + break; + default: + input.GenerateTensorValue(GeneratorTensor_3{0.0, 1.0}); + weight.GenerateTensorValue(GeneratorTensor_3{-0.5, 0.5}); + bias.GenerateTensorValue(GeneratorTensor_3{-0.5, 0.5}); + } + + DeviceMem in_device_buf(sizeof(InDataType) * input.mDesc.GetElementSpaceSize()); + DeviceMem wei_device_buf(sizeof(WeiDataType) * weight.mDesc.GetElementSpaceSize()); + DeviceMem out_device_buf(sizeof(OutDataType) * device_output.mDesc.GetElementSpaceSize()); + DeviceMem bias_device_buf(sizeof(OutDataType) * bias.mDesc.GetElementSpaceSize()); + + in_device_buf.ToDevice(input.mData.data()); + 
wei_device_buf.ToDevice(weight.mData.data()); + bias_device_buf.ToDevice(bias.mData.data()); + + // run reference op + if(do_verification) + { + auto ref_conv = ck::tensor_operation::host::ReferenceConvFwd{}; + + std::array, 1> d_tensors = {bias}; + auto ref_invoker = ref_conv.MakeInvoker(); + auto ref_argument = ref_conv.MakeArgument(input, + weight, + host_output, + conv_param.conv_filter_strides_, + conv_param.conv_filter_dilations_, + conv_param.input_left_pads_, + conv_param.input_right_pads_, + in_element_op, + wei_element_op, + out_element_op, + {}, + {}, + d_tensors); + + // init host output to zero + host_output.SetZero(); + + ref_invoker.Run(ref_argument); + } + + std::string best_op_name; + float best_avg_time = 0; + float best_tflops = 0; + float best_gb_per_sec = 0; + + // profile device op instances + bool pass = true; + + auto run_impl = [&](auto& op_ptr, auto& argument_ptr) { + // workspace_sz will be equal to 0 for other layout than NGCHW + const std::size_t workspace_sz = op_ptr->GetWorkSpaceSize(argument_ptr.get()); + DeviceMem workspace_dev(workspace_sz); + op_ptr->SetWorkSpacePointer(argument_ptr.get(), workspace_dev.GetDeviceBuffer()); + + if(op_ptr->IsSupportedArgument(argument_ptr.get())) + { + // re-init output to zero before profiling next kernel + out_device_buf.SetZero(); + + std::string op_name = op_ptr->GetTypeString(); + + auto invoker_ptr = op_ptr->MakeInvokerPointer(); + + float avg_time = + invoker_ptr->Run(argument_ptr.get(), StreamConfig{nullptr, time_kernel}); + + std::size_t flop = conv_param.GetFlops(); + std::size_t num_btype = conv_param.GetByte(); + + float tflops = static_cast(flop) / 1.E9 / avg_time; + + float gb_per_sec = num_btype / 1.E6 / avg_time; + + std::cout << "Perf: " << std::setw(10) << avg_time << " ms, " << tflops << " TFlops, " + << gb_per_sec << " GB/s, " << op_name << std::endl; + + if(tflops > best_tflops) + { + best_op_name = op_name; + best_tflops = tflops; + best_avg_time = avg_time; + best_gb_per_sec = 
gb_per_sec; + } + + if(do_verification) + { + out_device_buf.FromDevice(device_output.mData.data()); + + pass = pass & ck::utils::check_err(device_output, host_output); + + if(do_log) + { + LogRangeAsType(std::cout << "input : ", input.mData, ",") << std::endl; + LogRangeAsType(std::cout << "weight: ", weight.mData, ",") << std::endl; + LogRangeAsType(std::cout << "host_output : ", host_output.mData, ",") + << std::endl; + LogRangeAsType(std::cout << "device_output: ", device_output.mData, ",") + << std::endl; + } + } + } + else + { + std::cout << op_ptr->GetTypeString() << " does not support this problem" << std::endl; + } + }; + + using DeviceOp = + ck::tensor_operation::device::DeviceGroupedConvFwdMultipleABD, + OutLayout, + InDataType, + WeiDataType, + ck::Tuple, + OutDataType, + InElementOp, + WeiElementOp, + OutElementOp, + AComputeType, + BComputeType>; + + // get device op instances + const auto op_ptrs = ck::tensor_operation::device::instance::DeviceOperationInstanceFactory< + DeviceOp>::GetInstances(); + + std::cout << "ckProfiler found " << op_ptrs.size() << " instances" << std::endl; + + for(auto& op_ptr : op_ptrs) + { + auto argument_ptr = op_ptr->MakeArgumentPointer(in_device_buf.GetDeviceBuffer(), + wei_device_buf.GetDeviceBuffer(), + {bias_device_buf.GetDeviceBuffer()}, + out_device_buf.GetDeviceBuffer(), + a_g_n_c_wis_lengths, + a_g_n_c_wis_strides, + b_g_k_c_xs_lengths, + b_g_k_c_xs_strides, + {e_g_n_k_wos_lengths}, + {e_g_n_k_wos_strides}, + e_g_n_k_wos_lengths, + e_g_n_k_wos_strides, + conv_filter_strides, + conv_filter_dilations, + input_left_pads, + input_right_pads, + in_element_op, + wei_element_op, + out_element_op); + + run_impl(op_ptr, argument_ptr); + } + + std::cout << "Best configuration parameters:" + << "\nname: " << best_op_name << "\navg_time: " << best_avg_time + << "\ntflops: " << best_tflops << "\nGB/s: " << best_gb_per_sec << std::endl; + + return pass; +} + +} // namespace profiler +} // namespace ck diff --git 
a/profiler/include/profiler/profile_grouped_conv_fwd_impl.hpp b/profiler/include/profiler/profile_grouped_conv_fwd_impl.hpp index 4bfbca5437..dfa6bc1edd 100644 --- a/profiler/include/profiler/profile_grouped_conv_fwd_impl.hpp +++ b/profiler/include/profiler/profile_grouped_conv_fwd_impl.hpp @@ -1,5 +1,5 @@ // SPDX-License-Identifier: MIT -// Copyright (c) 2018-2024, Advanced Micro Devices, Inc. All rights reserved. +// Copyright (c) 2018-2025, Advanced Micro Devices, Inc. All rights reserved. #pragma once diff --git a/test/CMakeLists.txt b/test/CMakeLists.txt index 6bde1140d9..69ffb94488 100755 --- a/test/CMakeLists.txt +++ b/test/CMakeLists.txt @@ -255,6 +255,7 @@ add_subdirectory(reduce) add_subdirectory(convnd_fwd) add_subdirectory(convnd_bwd_data) add_subdirectory(grouped_convnd_fwd) +add_subdirectory(grouped_convnd_fwd_bias_relu) add_subdirectory(grouped_convnd_bwd_weight) add_subdirectory(block_to_ctile_map) add_subdirectory(softmax) diff --git a/test/grouped_convnd_fwd_bias_relu/CMakeLists.txt b/test/grouped_convnd_fwd_bias_relu/CMakeLists.txt new file mode 100644 index 0000000000..680a92b19c --- /dev/null +++ b/test/grouped_convnd_fwd_bias_relu/CMakeLists.txt @@ -0,0 +1,4 @@ +if(GPU_TARGETS MATCHES "gfx9") + add_gtest_executable(test_grouped_convnd_fwd_bias_relu test_grouped_convnd_fwd_bias_relu.cpp) + target_link_libraries(test_grouped_convnd_fwd_bias_relu PRIVATE utility device_grouped_conv2d_fwd_bias_relu_instance device_grouped_conv3d_fwd_bias_relu_instance) +endif() diff --git a/test/grouped_convnd_fwd_bias_relu/test_grouped_convnd_fwd_bias_relu.cpp b/test/grouped_convnd_fwd_bias_relu/test_grouped_convnd_fwd_bias_relu.cpp new file mode 100644 index 0000000000..c508235d9c --- /dev/null +++ b/test/grouped_convnd_fwd_bias_relu/test_grouped_convnd_fwd_bias_relu.cpp @@ -0,0 +1,92 @@ +// SPDX-License-Identifier: MIT +// Copyright (c) 2018-2025, Advanced Micro Devices, Inc. All rights reserved. 
+ +#include +#include +#include +#include +#include + +#include "profiler/profile_grouped_conv_fwd_bias_relu_impl.hpp" + +#include "ck/tensor_operation/gpu/element/element_wise_operation.hpp" + +using AddRelu = ck::tensor_operation::element_wise::AddRelu; + +template +class TestGroupedConvndFwd : public ::testing::Test +{ + protected: + using DataType = std::tuple_element_t<0, Tuple>; + using InLayout = std::tuple_element_t<1, Tuple>; + using WeiLayout = std::tuple_element_t<2, Tuple>; + using OutLayout = std::tuple_element_t<3, Tuple>; + using IndexType = ck::index_t; + + std::vector conv_params; + + template + void Run() + { + EXPECT_FALSE(conv_params.empty()); + bool pass = true; + for(auto& param : conv_params) + { + pass = pass && ck::profiler::profile_grouped_conv_fwd_bias_relu_impl( + true, // do_verification + 1, // init_method: integer value + false, // do_log + false, // time_kernel + param); + } + EXPECT_TRUE(pass); + } +}; + +using namespace ck::tensor_layout::convolution; + +using KernelTypes2d = ::testing::Types>; + +using KernelTypes3d = ::testing::Types>; + +template +class TestGroupedConvndFwd2d : public TestGroupedConvndFwd +{ +}; + +template +class TestGroupedConvndFwd3d : public TestGroupedConvndFwd +{ +}; + +TYPED_TEST_SUITE(TestGroupedConvndFwd2d, KernelTypes2d); +TYPED_TEST_SUITE(TestGroupedConvndFwd3d, KernelTypes3d); + +TYPED_TEST(TestGroupedConvndFwd2d, Test2D) +{ + this->conv_params.clear(); + this->conv_params.push_back( + {2, 2, 32, 128, 256, {1, 1}, {7, 7}, {2, 2}, {1, 1}, {0, 0}, {0, 0}}); + this->conv_params.push_back( + {2, 2, 32, 128, 256, {3, 3}, {14, 14}, {1, 1}, {1, 1}, {1, 1}, {1, 1}}); + this->template Run<2>(); +} + +TYPED_TEST(TestGroupedConvndFwd3d, Test3D) +{ + this->conv_params.clear(); + this->conv_params.push_back( + {3, 2, 32, 128, 256, {1, 1, 1}, {7, 7, 7}, {2, 2, 2}, {1, 1, 1}, {0, 0, 0}, {0, 0, 0}}); + this->conv_params.push_back( + {3, 2, 32, 128, 256, {3, 3, 3}, {14, 14, 3}, {1, 1, 1}, {1, 1, 1}, {1, 1, 1}, {1, 1, 
1}}); + this->template Run<3>(); +} From d8faf1c6a161ddcee98e9dfca3cc00941eec9f61 Mon Sep 17 00:00:00 2001 From: Khushbu Agarwal Date: Sat, 10 May 2025 22:40:05 -0700 Subject: [PATCH 112/443] Support for swizzle and transpose for MFMA_16x16x32_F16/BF16 (#2172) * Changes for updating tile distribution for shuffle and transpose * Fixed swizzle and transpose, removed comments * clang formatted * Adding support for bf16 type * Addressing review comments --- include/ck_tile/ops/gemm/warp/warp_gemm.hpp | 12 ++++++++++++ 1 file changed, 12 insertions(+) diff --git a/include/ck_tile/ops/gemm/warp/warp_gemm.hpp b/include/ck_tile/ops/gemm/warp/warp_gemm.hpp index 5cc5ddc70e..5ed97dc05c 100644 --- a/include/ck_tile/ops/gemm/warp/warp_gemm.hpp +++ b/include/ck_tile/ops/gemm/warp/warp_gemm.hpp @@ -77,6 +77,18 @@ using WarpGemmMfmaF16F16F32M16N16K32TransposedCDistribution = 2>>; #endif +#if defined(__gfx950__) +using WarpGemmMfmaF16F16F32M16N16K32SwizzleBTransposedCDistribution = + WarpGemmImpl, + 1>>; + +using WarpGemmMfmaBf16Bf16F32M16N16K32SwizzleBTransposedCDistribution = + WarpGemmImpl, + 1>>; +#endif + #if defined(__gfx950__) using WarpGemmMfmaF16F16F32M32N32K16SwizzleBTransposedCDistribution = WarpGemmImpl Date: Mon, 12 May 2025 00:41:45 -0700 Subject: [PATCH 113/443] Vectorized Transpose for Batched Transpose CK Tile Operator (#2131) * Shared Memory for single data point * CKTile Transpose vectorize CP1 * CKTile Transpose vectorize CP2 * CKTile Transpose vectorize CP2.1 * fixed the compile error of the transpose tile 2d * Have the correct result for the current test sample * Changes to printing tensor * fp8 support added * Debugging for transpose * solving the corner issue * Changed padding flag * Intermideate Debugging * Intermidiate Debugging * Intermediate Debugging * Finished debugging of the transpose op * Code Cleanup * Adding edge case smoke tests * Adding Transpose test to CI/CD * Adding Transpose test to CI/CD * Adding Transpose test to CI/CD * Addressing Review 
Comment * Addressing Comments * Addressing Comments * Measuring Perf Tests * Code Cleanup * Changlog * Added the running iterations * clang format * Fix the changelog * Fix the compilation error * change the printing factor --------- Co-authored-by: ThruptiRajLakshmanaGowda --- CHANGELOG.md | 2 +- Jenkinsfile | 73 ++++++++++++++- .../ck_tile/35_batched_transpose/README.md | 2 + .../batched_transpose_api.cpp | 89 +++++++++++++------ .../batched_transpose_example.cpp | 43 ++++----- .../35_batched_transpose/script/perf_test.sh | 11 +++ .../script/run_full_test.sh | 38 ++++++++ .../35_batched_transpose/script/smoke_test.sh | 20 ++++- include/ck_tile/core/tensor/tensor_view.hpp | 17 +--- .../ck_tile/core/tensor/transpose_tile.hpp | 23 +++-- .../kernel/batched_transpose_kernel.hpp | 63 ++++++------- .../pipeline/batched_transpose_pipeline.hpp | 24 ++--- .../pipeline/batched_transpose_policy.hpp | 43 ++++----- .../pipeline/batched_transpose_problem.hpp | 15 ++-- 14 files changed, 311 insertions(+), 152 deletions(-) create mode 100755 example/ck_tile/35_batched_transpose/script/perf_test.sh create mode 100755 example/ck_tile/35_batched_transpose/script/run_full_test.sh diff --git a/CHANGELOG.md b/CHANGELOG.md index e0ec214c69..60fe2df99d 100644 --- a/CHANGELOG.md +++ b/CHANGELOG.md @@ -19,7 +19,7 @@ Documentation for Composable Kernel available at [https://rocm.docs.amd.com/proj ### Optimized -None +* Added Vectorize Transpose optimization for CK Tile (#2131) ### Fixes diff --git a/Jenkinsfile b/Jenkinsfile index 2ad96ed44b..68e0fa1246 100644 --- a/Jenkinsfile +++ b/Jenkinsfile @@ -362,6 +362,20 @@ def cmake_build(Map conf=[:]){ echo "could not locate the requested artifacts: ${err.getMessage()}. will skip the stashing." 
} } + if (params.RUN_CK_TILE_TRANSPOSE_TESTS){ + try{ + archiveArtifacts "perf_transpose_*.log" + if (arch_type == 1){ + stash includes: "perf_transpose_**_gfx90a.log", name: "perf_transpose_log_gfx90a" + } + else if (arch_type == 2){ + stash includes: "perf_transpose_**_gfx942.log", name: "perf_transpose_log_gfx942" + } + } + catch(Exception err){ + echo "could not locate the requested artifacts: ${err.getMessage()}. will skip the stashing." + } + } if (params.RUN_CK_TILE_GEMM_TESTS){ try{ archiveArtifacts "perf_tile_gemm_**.log" @@ -698,6 +712,15 @@ def process_results(Map conf=[:]){ echo "could not locate the FMHA performance logs: ${err.getMessage()}." } } + if (params.RUN_CK_TILE_TRANSPOSE_TESTS){ + try{ + unstash "perf_transpose_log_gfx942" + unstash "perf_transpose_log_gfx90a" + } + catch(Exception err){ + echo "could not locate the Transpose performance logs: ${err.getMessage()}." + } + } if (params.RUN_CK_TILE_GEMM_TESTS){ try{ unstash "perf_tile_gemm_log_gfx942" @@ -753,7 +776,7 @@ def process_results(Map conf=[:]){ } //launch develop branch daily at 23:00 UT in FULL_QA mode and at 19:00 UT with latest staging compiler version -CRON_SETTINGS = BRANCH_NAME == "develop" ? '''0 23 * * * % RUN_FULL_QA=true;DISABLE_DL_KERNELS=true;ROCMVERSION=6.4;RUN_CK_TILE_FMHA_TESTS=true;RUN_CK_TILE_GEMM_TESTS=true +CRON_SETTINGS = BRANCH_NAME == "develop" ? 
'''0 23 * * * % RUN_FULL_QA=true;DISABLE_DL_KERNELS=true;ROCMVERSION=6.4;RUN_CK_TILE_FMHA_TESTS=true;RUN_CK_TILE_TRANSPOSE_TESTS=true;RUN_CK_TILE_GEMM_TESTS=true 0 21 * * * % ROCMVERSION=6.4;hipTensor_test=true;RUN_CODEGEN_TESTS=true;BUILD_GFX908=true 0 19 * * * % BUILD_DOCKER=true;COMPILER_VERSION=amd-staging;BUILD_COMPILER=/llvm-project/build/bin/clang++;USE_SCCACHE=false;NINJA_BUILD_TRACE=true 0 17 * * * % BUILD_DOCKER=true;COMPILER_VERSION=amd-mainline;BUILD_COMPILER=/llvm-project/build/bin/clang++;USE_SCCACHE=false;NINJA_BUILD_TRACE=true @@ -833,6 +856,10 @@ pipeline { name: "RUN_CK_TILE_FMHA_TESTS", defaultValue: false, description: "Run the ck_tile FMHA tests (default: OFF)") + booleanParam( + name: "RUN_CK_TILE_TRANSPOSE_TESTS", + defaultValue: false, + description: "Run the ck_tile Transpose tests (default: OFF)") booleanParam( name: "RUN_CK_TILE_GEMM_TESTS", defaultValue: false, @@ -1032,6 +1059,50 @@ pipeline { } } } + stage("Run CK_TILE_TRANSPOSE Tests") + { + parallel + { + stage("Run CK_TILE_TRANSPOSE Tests on gfx90a") + { + when { + beforeAgent true + expression { params.RUN_CK_TILE_TRANSPOSE_TESTS.toBoolean() } + } + agent{ label rocmnode("gfx90a") } + environment{ + setup_args = "NO_CK_BUILD" + execute_args = """ ../script/cmake-ck-dev.sh ../ gfx90a && \ + make -j64 tile_example_batched_transpose && \ + cd ../ && + example/ck_tile/35_batched_transpose/script/run_full_test.sh "CI_${params.COMPILER_VERSION}" "${env.BRANCH_NAME}" "${NODE_NAME}" gfx90a """ + } + steps{ + buildHipClangJobAndReboot(setup_args:setup_args, no_reboot:true, build_type: 'Release', execute_cmd: execute_args) + cleanWs() + } + } + stage("Run CK_TILE_TRANSPOSE Tests on gfx942") + { + when { + beforeAgent true + expression { params.RUN_CK_TILE_TRANSPOSE_TESTS.toBoolean() } + } + agent{ label rocmnode("gfx942") } + environment{ + setup_args = "NO_CK_BUILD" + execute_args = """ ../script/cmake-ck-dev.sh ../ gfx942 && \ + make -j64 tile_example_batched_transpose && \ + cd ../ && + 
example/ck_tile/35_batched_transpose/script/run_full_test.sh "CI_${params.COMPILER_VERSION}" "${env.BRANCH_NAME}" "${NODE_NAME}" gfx942 """ + } + steps{ + buildHipClangJobAndReboot(setup_args:setup_args, no_reboot:true, build_type: 'Release', execute_cmd: execute_args) + cleanWs() + } + } + } + } stage("Run CK_TILE_GEMM Tests") { parallel diff --git a/example/ck_tile/35_batched_transpose/README.md b/example/ck_tile/35_batched_transpose/README.md index d0583e7529..38bb2b32e4 100644 --- a/example/ck_tile/35_batched_transpose/README.md +++ b/example/ck_tile/35_batched_transpose/README.md @@ -24,4 +24,6 @@ args: -layout_out output tensor data layout - NHWC by default -seed seed to be used, -1 means random every time (default:-1) -k_name t to 1 will print kernel name (default:0) + -warmup warmup iterations to run this kernel (default:50) + -repeat number of iterations to run this kernel (default:100) ``` \ No newline at end of file diff --git a/example/ck_tile/35_batched_transpose/batched_transpose_api.cpp b/example/ck_tile/35_batched_transpose/batched_transpose_api.cpp index 77d768fe3f..1eb0445c84 100644 --- a/example/ck_tile/35_batched_transpose/batched_transpose_api.cpp +++ b/example/ck_tile/35_batched_transpose/batched_transpose_api.cpp @@ -1,7 +1,6 @@ // SPDX-License-Identifier: MIT // Copyright (c) 2018-2025, Advanced Micro Devices, Inc. All rights reserved. 
#include "batched_transpose_example.hpp" -#include template + ck_tile::index_t thread_y, + bool kPadM, + bool kPadN> float batched_transpose_dispatch(batched_transpose_kargs& a, ck_tile::stream_config& s) { - uint32_t dim_block_h = (a.height + block_y - 1) / block_y; - uint32_t dim_block_w = (a.width + block_x - 1) / block_x; - uint32_t dim_stride = a.height * a.width; + uint32_t dim_stride = a.height * a.width; a.dim_stride = dim_stride; - a.dim_block_h = dim_block_h; - a.dim_block_w = dim_block_w; + a.dim_block_h = block_y; + a.dim_block_w = block_x; using block_tile = ck_tile::sequence; using warp_tile = ck_tile::sequence; using thread_tile = ck_tile::sequence; using ts_problem = - ck_tile::BatchedTransposeProblem; + ck_tile::BatchedTransposeProblem; using ts_pipeline = ck_tile::BatchedTransposePipeline; using kernel = ck_tile::BatchedTransposeKernel; @@ -35,25 +34,40 @@ float batched_transpose_dispatch(batched_transpose_kargs& a, ck_tile::stream_con const dim3 grids = kernel::GridSize(a); constexpr dim3 blocks = kernel::BlockSize(); + printf("Grid: %u %u %u\n", grids.x, grids.y, grids.z); + printf("Block: %u %u %u\n", blocks.x, blocks.y, blocks.z); + printf("kargs: kargs.batch %d kargs.height %d kargs.width %d kargs.dim_strid %d\n", + kargs.batch, + kargs.height, + kargs.width, + kargs.dim_stride); + + printf("Launching Kernel...\n"); + float ave_time = ck_tile::launch_kernel( s, ck_tile::make_kernel(kernel{}, grids, blocks, 0, kargs)); + printf("Kernel finished...\n"); + return ave_time; } // Param Comb: type_size, block_x & y, warp_x & y, thread_x & y -#define FOREACH_TRANSPOSE_PARAM(F) \ - F(fp16, ck_tile::fp16_t, 16, 16, 8, 8, 1, 1) \ - F(bf16, ck_tile::bf16_t, 16, 16, 8, 8, 1, 1) \ - F(fp32, ck_tile::fp32_t, 16, 16, 8, 8, 1, 1) \ - F(int8, ck_tile::int8_t, 16, 16, 8, 8, 1, 1) +#define FOREACH_TRANSPOSE_PARAM(F) \ + F(fp8, ck_tile::fp8_t, 64, 64, 64, 64, 8, 8, true, true) \ + F(fp8, ck_tile::fp8_t, 64, 64, 64, 64, 8, 8, false, false) \ + F(fp16, 
ck_tile::fp16_t, 64, 64, 64, 64, 8, 8, true, true) \ + F(fp16, ck_tile::fp16_t, 64, 64, 64, 64, 8, 8, false, false) \ + F(bf16, ck_tile::bf16_t, 64, 64, 64, 64, 8, 8, true, true) \ + F(bf16, ck_tile::bf16_t, 64, 64, 64, 64, 8, 8, false, false) // Macro that defines one static function per line -#define GEN_TRANSPOSE_FN(SHORT_NAME, REAL_TYPE, BX, BY, WX, WY, TX, TY) \ - static float transpose_fn_##SHORT_NAME##_##BX##_##BY##_##WX##_##WY##_##TX##_##TY( \ - batched_transpose_kargs& a, ck_tile::stream_config& s) \ - { \ - return batched_transpose_dispatch(a, s); \ +#define GEN_TRANSPOSE_FN(SHORT_NAME, REAL_TYPE, BX, BY, WX, WY, TX, TY, PADM, PADN) \ + static float \ + transpose_fn_##SHORT_NAME##_##BX##_##BY##_##WX##_##WY##_##TX##_##TY##_##PADM##_##PADN( \ + batched_transpose_kargs& a, ck_tile::stream_config& s) \ + { \ + return batched_transpose_dispatch(a, s); \ } FOREACH_TRANSPOSE_PARAM(GEN_TRANSPOSE_FN) @@ -62,21 +76,38 @@ float batched_transpose(batched_transpose_trait t, batched_transpose_kargs a, ck_tile::stream_config s) { - if(t.type == "fp16") + if(t.type == "fp8") { - return transpose_fn_fp16_16_16_8_8_1_1(a, s); + if(a.height % 64 == 0 && a.width % 64 == 0) + { + return transpose_fn_fp8_64_64_64_64_8_8_false_false(a, s); + } + else + { + return transpose_fn_fp8_64_64_64_64_8_8_true_true(a, s); + } + } + else if(t.type == "fp16") + { + if(a.height % 64 == 0 && a.width % 64 == 0) + { + return transpose_fn_fp16_64_64_64_64_8_8_false_false(a, s); + } + else + { + return transpose_fn_fp16_64_64_64_64_8_8_true_true(a, s); + } } else if(t.type == "bf16") { - return transpose_fn_bf16_16_16_8_8_1_1(a, s); - } - else if(t.type == "fp32") - { - return transpose_fn_fp32_16_16_8_8_1_1(a, s); - } - else if(t.type == "int8") - { - return transpose_fn_int8_16_16_8_8_1_1(a, s); + if(a.height % 64 == 0 && a.width % 64 == 0) + { + return transpose_fn_bf16_64_64_64_64_8_8_false_false(a, s); + } + else + { + return transpose_fn_bf16_64_64_64_64_8_8_true_true(a, s); + } } return 
-1; } diff --git a/example/ck_tile/35_batched_transpose/batched_transpose_example.cpp b/example/ck_tile/35_batched_transpose/batched_transpose_example.cpp index 48fc2859bf..33b6f0eacf 100644 --- a/example/ck_tile/35_batched_transpose/batched_transpose_example.cpp +++ b/example/ck_tile/35_batched_transpose/batched_transpose_example.cpp @@ -21,13 +21,13 @@ void dump_host_tensor_4d(const ck_tile::HostTensor& x) std::cout << "["; for(size_t i = 0; i < len[0]; i++) { - std::cout << i << ": ["; + std::cout << "Batch " << i << ":" << std::endl; for(size_t j = 0; j < len[1]; j++) { - std::cout << j << ": ["; + std::cout << " Channel " << j << ":" << std::endl; for(size_t k = 0; k < len[2]; k++) { - std::cout << k << ": ["; + std::cout << " Row " << k << ": "; for(size_t v = 0; v < len[3]; v++) { if constexpr(std::is_same_v) @@ -41,15 +41,15 @@ void dump_host_tensor_4d(const ck_tile::HostTensor& x) } else { - std::cout << x(std::vector{i, j, k, v}) << " "; + std::cout << static_cast(x(std::vector{i, j, k, v})) + << " "; } } - std::cout << "]" << std::endl; + std::cout << std::endl; } - std::cout << "]" << std::endl; } - std::cout << std::endl; } + std::cout << "]" << std::endl; std::cout << "--------------------" << std::endl; } #endif @@ -93,12 +93,14 @@ auto create_args(int argc, char* argv[]) ck_tile::ArgParser arg_parser; arg_parser.insert("v", "1", "whether do CPU validation or not") .insert("pr", "fp16", "input data type. fp16/fp32 (representing 8/16/32 bit data)") - .insert("N", "2", "input batch size. ") - .insert("C", "16", "input channel size.") - .insert("H", "1", "input height size.") - .insert("W", "16", "input width size. ") + .insert("N", "1", "input batch size. ") + .insert("C", "64", "input channel size.") + .insert("H", "18", "input height size.") + .insert("W", "64", "input width size. 
") .insert("layout_in", "NCHW", "input tensor data layout - NCHW by default") .insert("layout_out", "NHWC", "output tensor data layout - NHWC by default ") + .insert("warmup", "50", "number of iterations before benchmark the kernel") + .insert("repeat", "100", "number of iterations to benchmark the kernel") .insert("seed", "-1", "seed to be used, -1 means random every time") .insert("kname", "0", "t to 1 will print kernel name"); @@ -115,6 +117,8 @@ bool run_batched_transpose(ck_tile::ArgParser args) int C = args.get_int("C"); int H = args.get_int("H"); int W = args.get_int("W"); + int n_warmup = args.get_int("warmup"); + int n_repeat = args.get_int("repeat"); std::string layout_in = args.get_str("layout_in"); std::string layout_out = args.get_str("layout_out"); int seed = args.get_int("seed"); @@ -177,7 +181,7 @@ bool run_batched_transpose(ck_tile::ArgParser args) return a_; }(); - ck_tile::stream_config sc{nullptr, true}; + ck_tile::stream_config sc{nullptr, true, n_warmup, n_repeat}; auto ms = batched_transpose(trait, karg, sc); @@ -202,7 +206,8 @@ bool run_batched_transpose(ck_tile::ArgParser args) layout_in.c_str(), ms); if(ms < 0) - printf("not supported\n"); + printf("------------------------------------not " + "supported-------------------------------------\n"); fflush(stdout); if(ms < 0) @@ -227,7 +232,9 @@ bool run_batched_transpose(ck_tile::ArgParser args) rtn &= ck_tile::check_err( y_host, y_ref, std::string("y Error: Incorrect results!"), rtol, atol); } - printf("valid:%s\n", rtn ? "y" : "n"); + printf("-----------------------------------------------------------------------valid:%s--------" + "--------------------------------------------------------------------\n", + rtn ? 
"y" : "n"); fflush(stdout); return rtn; } @@ -240,9 +247,9 @@ int main(int argc, char** argv) std::string prec = args.get_str("pr"); bool r = true; - if(prec.compare("fp32") == 0) + if(prec.compare("fp8") == 0) { - r &= run_batched_transpose(args); + r &= run_batched_transpose(args); } else if(prec.compare("fp16") == 0) { @@ -252,10 +259,6 @@ int main(int argc, char** argv) { r &= run_batched_transpose(args); } - else if(prec.compare("int8") == 0) - { - r &= run_batched_transpose(args); - } return r ? 0 : -1; } diff --git a/example/ck_tile/35_batched_transpose/script/perf_test.sh b/example/ck_tile/35_batched_transpose/script/perf_test.sh new file mode 100755 index 0000000000..7ecfefc580 --- /dev/null +++ b/example/ck_tile/35_batched_transpose/script/perf_test.sh @@ -0,0 +1,11 @@ +#!/bin/sh + +EXE=./build/bin/tile_example_batched_transpose + +for pr in "fp8" "fp16" "bf16"; do +$EXE -pr=$pr -N=1 -C=64 -H=1 -W=64 -layout_in='NCHW' -layout_out='NHWC' +$EXE -pr=$pr -N=1 -C=1024 -H=1 -W=1024 -layout_in='NCHW' -layout_out='NHWC' +$EXE -pr=$pr -N=1 -C=1024 -H=1 -W=2048 -layout_in='NCHW' -layout_out='NHWC' +$EXE -pr=$pr -N=1 -C=4096 -H=1 -W=2048 -layout_in='NCHW' -layout_out='NHWC' + +done \ No newline at end of file diff --git a/example/ck_tile/35_batched_transpose/script/run_full_test.sh b/example/ck_tile/35_batched_transpose/script/run_full_test.sh new file mode 100755 index 0000000000..4d0c988912 --- /dev/null +++ b/example/ck_tile/35_batched_transpose/script/run_full_test.sh @@ -0,0 +1,38 @@ +#!/bin/bash +# +# in order to run this script you'd first need to build the tile_example_batched_transpose executables in ../build/bin/ +# +# run the script as "./run_full_test.sh +# input arguments: +# environment tag : a string describing the specifics of your test environment +# branch name : name of the branch in git repo (git status | grep -e 'On branch') +# host name : $hostname +# gpu architecture: e.g., gfx90a, or gfx942, etc. 
+ +#get the command line arguments: +export env_type=$1 +echo 'Environment type: ' $env_type +export branch=$2 +echo 'Branch name: ' $branch +export host_name=$3 +echo 'Host name: ' $host_name +export GPU_arch=$4 +echo 'GPU_arch: ' $GPU_arch + +function print_log_header(){ + rm -f $1; + echo 'On branch ' $3 &> $1; + echo 'Node name: ' $4 >> $1; + #get GPU_arch and number of compute units from rocminfo + echo -n "GPU_arch: " >> $1; rocminfo | grep "Name:" | grep "gfx" >> $1; + rocminfo | grep "Compute Unit:" >> $1; + hipcc --version | grep -e 'HIP version' >> $1; + echo 'Environment type: ' $2 >> $1; + /opt/rocm/bin/amdclang++ --version | grep -e 'InstalledDir' >> $1; +} + +#run verification tests +example/ck_tile/35_batched_transpose/script/smoke_test.sh + +#run performance benchmarks + diff --git a/example/ck_tile/35_batched_transpose/script/smoke_test.sh b/example/ck_tile/35_batched_transpose/script/smoke_test.sh index fdfef2cea8..fdc01a2eb4 100755 --- a/example/ck_tile/35_batched_transpose/script/smoke_test.sh +++ b/example/ck_tile/35_batched_transpose/script/smoke_test.sh @@ -2,10 +2,26 @@ EXE=./build/bin/tile_example_batched_transpose -for pr in "fp32" "fp16" "int8" ; do +for pr in "fp8" "fp16" "bf16"; do $EXE -pr=$pr -N=1 -C=32 -H=1 -W=32 -layout_in='NCHW' -layout_out='NHWC' +$EXE -pr=$pr -N=1 -C=64 -H=1 -W=64 -layout_in='NCHW' -layout_out='NHWC' $EXE -pr=$pr -N=2 -C=12 -H=1 -W=32 -layout_in='NHWC' -layout_out='NCHW' $EXE -pr=$pr -N=3 -C=1334 -H=1 -W=37 -layout_in='NHWC' -layout_out='NCHW' $EXE -pr=$pr -N=4 -C=27 -H=1 -W=32 -layout_in='NCHW' -layout_out='NHWC' $EXE -pr=$pr -N=5 -C=1234 -H=1 -W=12 -layout_in='NCHW' -layout_out='NHWC' -done +$EXE -pr=$pr -N=1 -C=1 -H=1 -W=1 -layout_in='NCHW' -layout_out='NHWC' +$EXE -pr=$pr -N=1 -C=1 -H=1 -W=1 -layout_in='NHWC' -layout_out='NCHW' +$EXE -pr=$pr -N=128 -C=1024 -H=64 -W=64 -layout_in='NCHW' -layout_out='NHWC' +$EXE -pr=$pr -N=128 -C=1024 -H=64 -W=64 -layout_in='NHWC' -layout_out='NCHW' +$EXE -pr=$pr -N=16 -C=64 
-H=32 -W=128 -layout_in='NCHW' -layout_out='NHWC' +$EXE -pr=$pr -N=16 -C=64 -H=128 -W=32 -layout_in='NHWC' -layout_out='NCHW' +$EXE -pr=$pr -N=1 -C=2048 -H=1 -W=1 -layout_in='NCHW' -layout_out='NHWC' +$EXE -pr=$pr -N=1 -C=2048 -H=1 -W=1 -layout_in='NHWC' -layout_out='NCHW' +$EXE -pr=$pr -N=1 -C=1 -H=1024 -W=1024 -layout_in='NCHW' -layout_out='NHWC' +$EXE -pr=$pr -N=1 -C=1 -H=1024 -W=1024 -layout_in='NHWC' -layout_out='NCHW' +$EXE -pr=$pr -N=8 -C=16 -H=8 -W=16 -layout_in='NCHW' -layout_out='NHWC' +$EXE -pr=$pr -N=8 -C=16 -H=8 -W=16 -layout_in='NHWC' -layout_out='NCHW' +$EXE -pr=$pr -N=1 -C=64 -H=1 -W=1024 -layout_in='NCHW' -layout_out='NHWC' +$EXE -pr=$pr -N=1 -C=64 -H=1024 -W=1 -layout_in='NHWC' -layout_out='NCHW' + +done \ No newline at end of file diff --git a/include/ck_tile/core/tensor/tensor_view.hpp b/include/ck_tile/core/tensor/tensor_view.hpp index 32de227b52..29db5e1fca 100644 --- a/include/ck_tile/core/tensor/tensor_view.hpp +++ b/include/ck_tile/core/tensor/tensor_view.hpp @@ -384,22 +384,6 @@ struct tensor_view coord.get_offset() / PackedSize, linear_offset / PackedSize, is_valid_element, x); } - CK_TILE_HOST_DEVICE void print() const - { - printf("tensor_view{"); - - // buf_ - printf("buf_: "); - print(buf_); - printf(", "); - - // desc_ - printf("desc_: "); - print(desc_); - - printf("}"); - } - // member buffer_view buf_; TensorDesc desc_; @@ -494,6 +478,7 @@ template {}); constexpr auto scalars_per_access = TO_SEQUENCE(scalars_per_access_arr, NDimY); @@ -103,13 +108,19 @@ CK_TILE_DEVICE void transpose_tile2d_impl_in_thread(OutTensor& out_tensor, // loop over SFC static_for<0, num_access, 1>{}([&](auto iAccess) { // data index [y0, y1, ...] 
in the order of input tensor - constexpr auto idx_y = SFC_Y::get_index(iAccess); - - constexpr index_t in_offset = y_in_desc.calculate_offset(idx_y); - constexpr index_t out_offset = y_out_desc.calculate_offset(idx_y); - + constexpr auto idx_y_start = SFC_Y::get_index(iAccess); + constexpr auto idx_y_in = + generate_tuple([&](auto ii) { return idx_y_start[ii].value; }, number{}); + constexpr index_t in_offset = y_in_desc.calculate_offset(idx_y_in); + static_assert(in_offset % vec_length_in == 0); + constexpr auto idx_y_out_tmp = + generate_array([&](auto ii) { return idx_y_start[ii].value; }, number{}); + constexpr auto idx_y_out = + container_reorder_given_new2old(idx_y_out_tmp, y_dim_out_to_in); + constexpr index_t out_offset = y_out_desc.calculate_offset(idx_y_out); if constexpr(vec_length_in == 1) { + out_tensor.get_thread_buffer()[number{}] = in_tensor.get_thread_buffer()[number{}]; } diff --git a/include/ck_tile/ops/batched_transpose/kernel/batched_transpose_kernel.hpp b/include/ck_tile/ops/batched_transpose/kernel/batched_transpose_kernel.hpp index 7e7dd03c6a..4c3aa2ba29 100644 --- a/include/ck_tile/ops/batched_transpose/kernel/batched_transpose_kernel.hpp +++ b/include/ck_tile/ops/batched_transpose/kernel/batched_transpose_kernel.hpp @@ -19,7 +19,6 @@ struct BatchedTransposeHostArgs index_t batch; index_t height; index_t width; - // index_t dim_blocks; index_t dim_stride; index_t dim_block_h; index_t dim_block_w; @@ -28,8 +27,10 @@ struct BatchedTransposeHostArgs template struct BatchedTransposeKernel { - using Pipeline = remove_cvref_t; - using Problem = remove_cvref_t; + + CK_TILE_DEVICE static index_t counter = 0; + using Pipeline = remove_cvref_t; + using Problem = remove_cvref_t; using Type = typename Problem::InputType; @@ -46,11 +47,11 @@ struct BatchedTransposeKernel using Kargs = BatchedTransposeKargs; using Hargs = BatchedTransposeHostArgs; - CK_TILE_HOST static constexpr auto GridSize(const Hargs& h) + CK_TILE_HOST static constexpr auto 
GridSize(const Hargs& host_args) { - size_t grid_size_x = (h.width + h.dim_block_w - 1) / h.dim_block_w; - size_t grid_size_y = (h.height + h.dim_block_h - 1) / h.dim_block_h; - size_t grid_size_z = h.batch; + size_t grid_size_x = (host_args.height + host_args.dim_block_h - 1) / host_args.dim_block_h; + size_t grid_size_y = (host_args.width + host_args.dim_block_w - 1) / host_args.dim_block_w; + size_t grid_size_z = host_args.batch; return dim3(grid_size_x, grid_size_y, grid_size_z); } @@ -70,58 +71,52 @@ struct BatchedTransposeKernel CK_TILE_DEVICE void operator()(Kargs kargs) const { + static constexpr ck_tile::index_t kMPerBlock = Problem::kMPerBlock; + static constexpr ck_tile::index_t kNPerBlock = Problem::kNPerBlock; + static constexpr bool kPadM = Problem::kPadM; + static constexpr bool kPadN = Problem::kPadN; + static constexpr ck_tile::index_t VectorSizeInput = Problem::VectorSizeInput; + static constexpr ck_tile::index_t VectorSizeOutput = Problem::VectorSizeOutput; - static constexpr ck_tile::index_t kMPerBlock = Problem::kMPerBlock; - static constexpr ck_tile::index_t kNPerBlock = Problem::kNPerBlock; - static constexpr bool kPadM = Problem::kPadM; - static constexpr bool kPadN = Problem::kPadN; + const auto iM = __builtin_amdgcn_readfirstlane(blockIdx.x * kMPerBlock); + const auto iN = __builtin_amdgcn_readfirstlane(blockIdx.y * kNPerBlock); + const auto iDim = blockIdx.z; - static constexpr ck_tile::index_t kMPerThread = Problem::kMPerThread; - static constexpr ck_tile::index_t kNPerThread = Problem::kNPerThread; - - static_assert(kMPerThread == 1 && kNPerThread == 1); - - const auto iDim = blockIdx.z; const auto x_m_n = [&]() { const auto x_dram_naive = make_naive_tensor_view( static_cast(kargs.p_input) + iDim * kargs.dim_stride, make_tuple(kargs.height, kargs.width), make_tuple(kargs.width, 1), - number{}, // TODO thread load value + number{}, number<1>{}); return pad_tensor_view(x_dram_naive, make_tuple(number{}, number{}), - sequence{}); + 
sequence{}); }(); - const auto iM = __builtin_amdgcn_readfirstlane(blockIdx.x * kMPerBlock); - const auto iN = __builtin_amdgcn_readfirstlane(blockIdx.y * kNPerBlock); - const auto y_n_m = [&]() { const auto y_dram_naive = make_naive_tensor_view( static_cast(kargs.p_output) + iDim * kargs.dim_stride, make_tuple(kargs.width, kargs.height), make_tuple(kargs.height, 1), - number{}, + number{}, number<1>{}); return pad_tensor_view(y_dram_naive, make_tuple(number{}, number{}), - sequence{}); + sequence{}); }(); - auto x_block_window = - make_tile_window(x_m_n, - make_tuple(number{}, number{}), - {static_cast(iM * kMPerBlock), - static_cast(iN * kNPerBlock)}); + auto x_block_window = make_tile_window( + x_m_n, + make_tuple(number{}, number{}), + {static_cast(iM), static_cast(iN)}); - auto y_block_window = - make_tile_window(y_n_m, - make_tuple(number{}, number{}), - {static_cast(iN * kNPerBlock), - static_cast(iM * kMPerBlock)}); + auto y_block_window = make_tile_window( + y_n_m, + make_tuple(number{}, number{}), + {static_cast(iN), static_cast(iM)}); Pipeline{}(x_block_window, y_block_window); } diff --git a/include/ck_tile/ops/batched_transpose/pipeline/batched_transpose_pipeline.hpp b/include/ck_tile/ops/batched_transpose/pipeline/batched_transpose_pipeline.hpp index aa62333918..e815313c06 100644 --- a/include/ck_tile/ops/batched_transpose/pipeline/batched_transpose_pipeline.hpp +++ b/include/ck_tile/ops/batched_transpose/pipeline/batched_transpose_pipeline.hpp @@ -29,24 +29,18 @@ struct BatchedTransposePipeline { auto inp_win = make_tile_window(input_window, Policy::template MakeInputDistribution()); + + auto input_tile = load_tile(inp_win); + + auto output_tile = make_static_distributed_tensor( + Policy::template MakeOutputDistribution()); + + transpose_tile2d(output_tile, input_tile); + auto out_win = make_tile_window(out_window, Policy::template MakeOutputDistribution()); - auto x = load_tile(inp_win); // x->thread input_win->block - - auto y = 
make_static_distributed_tensor( - Policy::template MakeOutputDistribution()); - - constexpr auto span_2d_x = decltype(x)::get_distributed_spans(); - - sweep_tile_span(span_2d_x[number<0>{}], [&](auto idx0) { - sweep_tile_span(span_2d_x[number<1>{}], [&](auto idx1) { - constexpr auto i_j_idx = make_tuple(idx1, idx0); - y(i_j_idx) = x(i_j_idx); - }); - }); - - store_tile(out_win, y); + store_tile(out_win, output_tile); } }; } // namespace ck_tile diff --git a/include/ck_tile/ops/batched_transpose/pipeline/batched_transpose_policy.hpp b/include/ck_tile/ops/batched_transpose/pipeline/batched_transpose_policy.hpp index 9953e8b8bf..dd9a6d79a8 100644 --- a/include/ck_tile/ops/batched_transpose/pipeline/batched_transpose_policy.hpp +++ b/include/ck_tile/ops/batched_transpose/pipeline/batched_transpose_policy.hpp @@ -14,31 +14,34 @@ struct BatchedTransposePolicy template CK_TILE_HOST_DEVICE static constexpr auto MakeInputDistribution() { - using S = Problem; - return make_static_tile_distribution( - tile_distribution_encoding< - sequence<>, - tuple, - sequence>, - tuple, sequence<1, 2>>, - tuple, sequence<1, 1>>, - sequence<1, 2>, - sequence<2, 2>>{}); + constexpr index_t BlockSize = Problem::kBlockSize; + constexpr index_t MPerBlock = Problem::kMPerBlock; + constexpr index_t NPerBlock = Problem::kNPerBlock; + constexpr index_t VecLoadSize = Problem::VectorSizeInput; + using TileEncodingPattern = + TileDistributionEncodingPattern2D; + return TileEncodingPattern::Make2DStaticTileDistribution(); } template CK_TILE_HOST_DEVICE static constexpr auto MakeOutputDistribution() { - using S = Problem; - return make_static_tile_distribution( - tile_distribution_encoding< - sequence<>, - tuple, - sequence>, - tuple, sequence<2, 1>>, - tuple, sequence<1, 1>>, - sequence<2, 1>, - sequence<2, 2>>{}); + constexpr index_t BlockSize = Problem::kBlockSize; + constexpr index_t MPerBlock = Problem::kMPerBlock; + constexpr index_t NPerBlock = Problem::kNPerBlock; + constexpr index_t VecLoadSize 
= Problem::VectorSizeOutput; + + using TileEncodingPattern = + TileDistributionEncodingPattern2D; + return TileEncodingPattern::MakeShuffled2DStaticTileDistribution(); } }; } // namespace ck_tile diff --git a/include/ck_tile/ops/batched_transpose/pipeline/batched_transpose_problem.hpp b/include/ck_tile/ops/batched_transpose/pipeline/batched_transpose_problem.hpp index af6b2d51aa..fd5ea004b6 100644 --- a/include/ck_tile/ops/batched_transpose/pipeline/batched_transpose_problem.hpp +++ b/include/ck_tile/ops/batched_transpose/pipeline/batched_transpose_problem.hpp @@ -4,7 +4,6 @@ #pragma once #include "ck_tile/core.hpp" -#include #include #define VectorLoadSize 16 @@ -12,11 +11,11 @@ namespace ck_tile { template + typename BlockTile, // Sequence<... + typename WarpTile, // Sequence<... + typename ThreadTile, + bool kPadM_ = false, + bool kPadN_ = false> // Sequence<... struct BatchedTransposeProblem { using InputType = remove_cvref_t; @@ -42,7 +41,7 @@ struct BatchedTransposeProblem static constexpr bool kPadM = kPadM_; static constexpr bool kPadN = kPadN_; - static constexpr index_t AlignmentM = kPadM ? VectorLoadSize / sizeof(InputType) : 1; // TODO - static constexpr index_t AlignmentN = kPadN ? VectorLoadSize / sizeof(InputType) : 1; + static constexpr index_t VectorSizeInput = kPadM ? 1 : VectorLoadSize / sizeof(InputType); + static constexpr index_t VectorSizeOutput = kPadN ? 
1 : VectorLoadSize / sizeof(InputType); }; } // namespace ck_tile From b49f7de81f35610c93129eadd2103e78bd0257d4 Mon Sep 17 00:00:00 2001 From: Thomas Ning Date: Mon, 12 May 2025 09:52:58 -0700 Subject: [PATCH 114/443] Improve the general performance of the Preshuffled GEMM V3 & delete the unnecessary instances (#2166) * make the work compiled * Solved the example code, but still have the profiler error * Finished the feature * Clang format and update the CHANGELOG * solve the preshuffle v1 & v2 problem * Comment Addressed * Comment Addressed --- CHANGELOG.md | 3 + ..._multiply_multiply_xdl_fp8_bpreshuffle.cpp | 9 +- ...e_gemm_pipeline_xdlops_b_preshuffle_v1.hpp | 53 +- ...e_gemm_pipeline_xdlops_b_preshuffle_v2.hpp | 68 +- ...e_gemm_pipeline_xdlops_b_preshuffle_v3.hpp | 708 ++++++++---------- .../blockwise_gemm_pipeline_xdlops_base.hpp | 5 + ...m_xdl_cshuffle_v3_multi_d_b_preshuffle.hpp | 16 +- include/ck/utility/blkgemmpipe_scheduler.hpp | 20 + .../gpu/gemm_multiply_multiply_wp.hpp | 389 ---------- .../gemm_multiply_multiply_wp/CMakeLists.txt | 48 -- ..._multiply_wp_xdl_f8_f8_bf16_mk_mfma_mn.hpp | 44 +- 11 files changed, 445 insertions(+), 918 deletions(-) diff --git a/CHANGELOG.md b/CHANGELOG.md index 60fe2df99d..4be173dd85 100644 --- a/CHANGELOG.md +++ b/CHANGELOG.md @@ -19,8 +19,11 @@ Documentation for Composable Kernel available at [https://rocm.docs.amd.com/proj ### Optimized + +* Optimize the gemm multiply multiply preshuffle & lds bypass with Pack of KGroup and better instruction layout. 
(#2166) * Added Vectorize Transpose optimization for CK Tile (#2131) + ### Fixes None diff --git a/example/65_gemm_multiply_multiply/gemm_multiply_multiply_xdl_fp8_bpreshuffle.cpp b/example/65_gemm_multiply_multiply/gemm_multiply_multiply_xdl_fp8_bpreshuffle.cpp index e4e6a4f1a7..9f758d5fc5 100644 --- a/example/65_gemm_multiply_multiply/gemm_multiply_multiply_xdl_fp8_bpreshuffle.cpp +++ b/example/65_gemm_multiply_multiply/gemm_multiply_multiply_xdl_fp8_bpreshuffle.cpp @@ -9,7 +9,6 @@ #include "ck/ck.hpp" #include "ck/tensor_operation/gpu/device/gemm_specialization.hpp" #include "ck/tensor_operation/gpu/device/impl/device_gemm_multiple_d_xdl_cshuffle_v3_b_preshuffle.hpp" -#include "ck/tensor_operation/gpu/device/impl/device_gemm_multiple_d_xdl_cshuffle_v3.hpp" #include "ck/tensor_operation/gpu/element/element_wise_operation.hpp" #include "ck/tensor_operation/gpu/element/unary_element_wise_operation.hpp" @@ -142,12 +141,12 @@ using DeviceOpInstance = ck::tensor_operation::device::DeviceGemmMultiD_Xdl_CShu AElementOp, BElementOp, CDEElementOp, GemmSpec, 256, 128, 128, 128, 16, 16, - 32, 32, - 2, 2, + 16, 16, + 8, 2, S<8, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, 0, S<8, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, 0, - 1, 1, S<1, 32, 1, 8>, S<8, 8, 1>, - ck::BlockGemmPipelineScheduler::Intrawave, ck::BlockGemmPipelineVersion::v1, FP8>; + 2, 1, S<1, 32, 1, 8>, S<8, 8, 1>, + ck::BlockGemmPipelineScheduler::Intrawave, ck::BlockGemmPipelineVersion::v3, FP8>; // clang-format on int main(int argc, char* argv[]) diff --git a/include/ck/tensor_operation/gpu/block/blockwise_gemm_pipeline_xdlops_b_preshuffle_v1.hpp b/include/ck/tensor_operation/gpu/block/blockwise_gemm_pipeline_xdlops_b_preshuffle_v1.hpp index d751543175..1d27a74bd7 100644 --- a/include/ck/tensor_operation/gpu/block/blockwise_gemm_pipeline_xdlops_b_preshuffle_v1.hpp +++ b/include/ck/tensor_operation/gpu/block/blockwise_gemm_pipeline_xdlops_b_preshuffle_v1.hpp @@ -122,6 +122,7 @@ struct 
BlockwiseGemmXdlops_pipeline_bpreshuffle_v1{}); constexpr index_t M1 = TileDesc_M0_M1_M2_K{}.GetLength(Number<1>{}); constexpr index_t M2 = TileDesc_M0_M1_M2_K{}.GetLength(Number<2>{}); - constexpr index_t K2 = KPack; + constexpr index_t K2 = KPack / KGroup; constexpr index_t K1 = 64 / NPerXDL; - constexpr index_t K0 = KRepeat; + constexpr index_t K0 = KRepeat * KGroup; return transform_tensor_descriptor( TileDesc_M0_M1_M2_K{}, @@ -280,12 +281,14 @@ struct BlockwiseGemmXdlops_pipeline_bpreshuffle_v1{}([&](auto m0) { static_for<0, KRepeat, 1>{}([&](auto k0) { - a_thread_copy_.Run(a_block_desc_m0_m1_m2_k0_k1_k2, - make_tuple(m0, I0, I0, k0, I0, I0), - a_block_buf, - a_thread_desc_, - make_tuple(m0, I0, I0, k0, I0, I0), - a_thread_buf); + static_for<0, KGroup, 1>{}([&](auto kg0) { + a_thread_copy_.Run(a_block_desc_m0_m1_m2_k0_k1_k2, + make_tuple(m0, I0, I0, Number{}, I0, I0), + a_block_buf, + a_thread_desc_, + make_tuple(m0, I0, I0, k0, I0, Number{}), + a_thread_buf); + }); }); }); @@ -346,14 +349,18 @@ struct BlockwiseGemmXdlops_pipeline_bpreshuffle_v1{}([&](auto m0) { static_for<0, KRepeat, 1>{}([&](auto k0) { - a_thread_copy_.Run(a_block_desc_m0_m1_m2_k0_k1_k2, - make_tuple(m0, I0, I0, k0, I0, I0), - a_block_buf, - a_thread_desc_, - make_tuple(m0, I0, I0, k0, I0, I0), - a_thread_buf); + static_for<0, KGroup, 1>{}([&](auto kg0) { + a_thread_copy_.Run( + a_block_desc_m0_m1_m2_k0_k1_k2, + make_tuple(m0, I0, I0, Number{}, I0, I0), + a_block_buf, + a_thread_desc_, + make_tuple(m0, I0, I0, k0, I0, Number{}), + a_thread_buf); + }); }); }); @@ -409,14 +416,18 @@ struct BlockwiseGemmXdlops_pipeline_bpreshuffle_v1{}([&](auto m0) { static_for<0, KRepeat, 1>{}([&](auto k0) { - a_thread_copy_.Run(a_block_desc_m0_m1_m2_k0_k1_k2, - make_tuple(m0, I0, I0, k0, I0, I0), - a_block_buf, - a_thread_desc_, - make_tuple(m0, I0, I0, k0, I0, I0), - a_thread_buf); + static_for<0, KGroup, 1>{}([&](auto kg0) { + a_thread_copy_.Run( + a_block_desc_m0_m1_m2_k0_k1_k2, + make_tuple(m0, I0, I0, 
Number{}, I0, I0), + a_block_buf, + a_thread_desc_, + make_tuple(m0, I0, I0, k0, I0, Number{}), + a_thread_buf); + }); }); }); @@ -495,7 +506,7 @@ struct BlockwiseGemmXdlops_pipeline_bpreshuffle_v1, + Sequence<1, 1, 1, 1, 1, KPack / KGroup>, Sequence<0, 1, 2, 3, 4, 5>, 5, A_K1, diff --git a/include/ck/tensor_operation/gpu/block/blockwise_gemm_pipeline_xdlops_b_preshuffle_v2.hpp b/include/ck/tensor_operation/gpu/block/blockwise_gemm_pipeline_xdlops_b_preshuffle_v2.hpp index 4c019a41a4..7bbaaca5b6 100644 --- a/include/ck/tensor_operation/gpu/block/blockwise_gemm_pipeline_xdlops_b_preshuffle_v2.hpp +++ b/include/ck/tensor_operation/gpu/block/blockwise_gemm_pipeline_xdlops_b_preshuffle_v2.hpp @@ -122,6 +122,7 @@ struct BlockwiseGemmXdlops_pipeline_bpreshuffle_v2{}); constexpr index_t M1 = TileDesc_M0_M1_M2_K{}.GetLength(Number<1>{}); constexpr index_t M2 = TileDesc_M0_M1_M2_K{}.GetLength(Number<2>{}); - constexpr index_t K2 = KPack; + constexpr index_t K2 = KPack / KGroup; constexpr index_t K1 = 64 / NPerXDL; - constexpr index_t K0 = KRepeat; + constexpr index_t K0 = KRepeat * KGroup; return transform_tensor_descriptor( TileDesc_M0_M1_M2_K{}, @@ -281,12 +282,14 @@ struct BlockwiseGemmXdlops_pipeline_bpreshuffle_v2{}([&](auto m0) { static_for<0, KRepeat, 1>{}([&](auto k0) { - a_thread_copy_.Run(a_block_desc_m0_m1_m2_k0_k1_k2, - make_tuple(m0, I0, I0, k0, I0, I0), - a_block_buf.At(I0), - a_thread_desc_, - make_tuple(m0, I0, I0, k0, I0, I0), - a_thread_bufs(I0)); + static_for<0, KGroup, 1>{}([&](auto kg0) { + a_thread_copy_.Run(a_block_desc_m0_m1_m2_k0_k1_k2, + make_tuple(m0, I0, I0, Number{}, I0, I0), + a_block_buf, + a_thread_desc_, + make_tuple(m0, I0, I0, k0, I0, Number{}), + a_thread_buf); + }); }); }); @@ -318,14 +321,18 @@ struct BlockwiseGemmXdlops_pipeline_bpreshuffle_v2{}([&](auto m0) { static_for<0, KRepeat, 1>{}([&](auto k0) { - a_thread_copy_.Run(a_block_desc_m0_m1_m2_k0_k1_k2, - make_tuple(m0, I0, I0, k0, I0, I0), - a_block_buf.At(local_read_buf), - 
a_thread_desc_, - make_tuple(m0, I0, I0, k0, I0, I0), - a_thread_bufs(local_read_buf)); + static_for<0, KGroup, 1>{}([&](auto kg0) { + a_thread_copy_.Run( + a_block_desc_m0_m1_m2_k0_k1_k2, + make_tuple(m0, I0, I0, Number{}, I0, I0), + a_block_buf, + a_thread_desc_, + make_tuple(m0, I0, I0, k0, I0, Number{}), + a_thread_buf); + }); }); }); @@ -389,14 +396,18 @@ struct BlockwiseGemmXdlops_pipeline_bpreshuffle_v2{}([&](auto m0) { static_for<0, KRepeat, 1>{}([&](auto k0) { - a_thread_copy_.Run(a_block_desc_m0_m1_m2_k0_k1_k2, - make_tuple(m0, I0, I0, k0, I0, I0), - a_block_buf.At(local_read_reg), - a_thread_desc_, - make_tuple(m0, I0, I0, k0, I0, I0), - a_thread_bufs(local_read_reg)); + static_for<0, KGroup, 1>{}([&](auto kg0) { + a_thread_copy_.Run( + a_block_desc_m0_m1_m2_k0_k1_k2, + make_tuple(m0, I0, I0, Number{}, I0, I0), + a_block_buf, + a_thread_desc_, + make_tuple(m0, I0, I0, k0, I0, Number{}), + a_thread_buf); + }); }); }); @@ -445,12 +456,15 @@ struct BlockwiseGemmXdlops_pipeline_bpreshuffle_v2{}([&](auto m0) { static_for<0, KRepeat, 1>{}([&](auto k0) { - a_thread_copy_.Run(a_block_desc_m0_m1_m2_k0_k1_k2, - make_tuple(m0, I0, I0, k0, I0, I0), - a_block_buf.At(local_read_reg), - a_thread_desc_, - make_tuple(m0, I0, I0, k0, I0, I0), - a_thread_bufs(local_read_reg)); + static_for<0, KGroup, 1>{}([&](auto kg0) { + a_thread_copy_.Run( + a_block_desc_m0_m1_m2_k0_k1_k2, + make_tuple(m0, I0, I0, Number{}, I0, I0), + a_block_buf, + a_thread_desc_, + make_tuple(m0, I0, I0, k0, I0, Number{}), + a_thread_buf); + }); }); }); @@ -539,7 +553,7 @@ struct BlockwiseGemmXdlops_pipeline_bpreshuffle_v2, + Sequence<1, 1, 1, 1, 1, KPack / KGroup>, Sequence<0, 1, 2, 3, 4, 5>, 5, A_K1, diff --git a/include/ck/tensor_operation/gpu/block/blockwise_gemm_pipeline_xdlops_b_preshuffle_v3.hpp b/include/ck/tensor_operation/gpu/block/blockwise_gemm_pipeline_xdlops_b_preshuffle_v3.hpp index 6d115e7620..6f3a7e6357 100644 --- 
a/include/ck/tensor_operation/gpu/block/blockwise_gemm_pipeline_xdlops_b_preshuffle_v3.hpp +++ b/include/ck/tensor_operation/gpu/block/blockwise_gemm_pipeline_xdlops_b_preshuffle_v3.hpp @@ -5,6 +5,16 @@ #include "ck/tensor_operation/gpu/block/blockwise_gemm_pipeline_xdlops_base.hpp" +#define DS_READ_A_PREFETCH_STAGES 2 + +template +constexpr auto compute_stage_loads(T total_loads, T stages) +{ + return std::make_pair((total_loads + stages - 1) / stages, // ceil + total_loads / stages // floor + ); +} + namespace ck { // Compute optimized pipeline @@ -123,6 +133,7 @@ struct BlockwiseGemmXdlops_pipeline_bpreshuffle_v3{}); constexpr index_t M1 = TileDesc_M0_M1_M2_K{}.GetLength(Number<1>{}); constexpr index_t M2 = TileDesc_M0_M1_M2_K{}.GetLength(Number<2>{}); - constexpr index_t K2 = KPack; + constexpr index_t K2 = KPack / KGroup; constexpr index_t K1 = 64 / NPerXDL; - constexpr index_t K0 = KRepeat; + constexpr index_t K0 = KRepeat * KGroup; return transform_tensor_descriptor( TileDesc_M0_M1_M2_K{}, @@ -184,298 +191,132 @@ struct BlockwiseGemmXdlops_pipeline_bpreshuffle_v3 - __device__ static constexpr auto HotLoopScheduler(Stage stage) + __device__ static constexpr auto HotLoopScheduler() { - constexpr auto num_ds_read_inst_a = HotLoopInstList::A_LDS_Read_Inst_Num; - constexpr auto num_ds_write_inst_a = HotLoopInstList::A_LDS_Write_Inst_Num; + constexpr auto num_ds_read_inst_a = + HotLoopInstList::A_LDS_Read_Width * sizeof(ADataType) == 16 + ? 
HotLoopInstList::A_LDS_Read_Inst_Num + : HotLoopInstList::A_LDS_Read_Inst_Num / 2; + + constexpr auto num_ds_write_inst_a = HotLoopInstList::A_LDS_Write_Inst_Num; + constexpr auto num_buffer_load_inst_a = HotLoopInstList::A_Buffer_Load_Inst_Num; - constexpr auto num_buffer_load_inst_b = MWaves * HotLoopInstList::B_Buffer_Load_Inst_Num; + constexpr auto num_buffer_load_inst_b = HotLoopInstList::B_Buffer_Load_Inst_Num; - constexpr auto num_mfma = HotLoopInstList::C_MFMA_Inst_Num; + static_assert(num_buffer_load_inst_a == num_ds_write_inst_a); - constexpr auto staged_num_ds_read_inst_a = num_ds_read_inst_a / MRepeat; - constexpr auto staged_num_mfma = num_mfma / MRepeat; + constexpr auto num_mfma_inst = HotLoopInstList::C_MFMA_Inst_Num; + constexpr auto mfma_cycle = HotLoopInstList::C_MFMA_Inst_Cycle; - constexpr auto staged_num_mfma_per_ds_read_a = staged_num_mfma / staged_num_ds_read_inst_a; + constexpr auto ds_read_a_issue_cycle = + HotLoopInstList::A_LDS_Read_Width * sizeof(ADataType) == 16 ? 
8 : 4; + constexpr auto ds_read_a_mfma_rate = + math::integer_divide_ceil(mfma_cycle - 4, 2 * ds_read_a_issue_cycle); - if constexpr(stage.value == 0) - { - constexpr auto staged_num_buffer_load_b_per_ds_read_a = - num_buffer_load_inst_b / staged_num_ds_read_inst_a; - constexpr auto staged_num_mfma_per_buffer_load_b = - staged_num_mfma / num_buffer_load_inst_b; - // B global - static_for<0, staged_num_ds_read_inst_a, 1>{}([&](auto i_inst) { - ignore = i_inst; + constexpr auto num_total_stages = MRepeat; - static_for<0, staged_num_buffer_load_b_per_ds_read_a - 1, 1>{}([&](auto ibuf_inst) { - ignore = ibuf_inst; + // Group num_mfma_perstage num_ds_read_a_perstage + // since we want to reuse a local register buffer + constexpr auto num_mfma_perstage = num_mfma_inst / MRepeat; + constexpr auto num_ds_read_a_perstage = num_ds_read_inst_a / MRepeat; + + constexpr auto num_ds_read_a_mfma_perstage = + math::integer_divide_ceil(num_ds_read_a_perstage, ds_read_a_mfma_rate); + + constexpr auto total_buffer_loads = num_buffer_load_inst_a + num_buffer_load_inst_b; + constexpr auto stages_available = MRepeat - DS_READ_A_PREFETCH_STAGES; + + constexpr auto stage_loads = compute_stage_loads(total_buffer_loads, stages_available); + + constexpr auto buffer_load_perstage_more = stage_loads.first; + constexpr auto buffer_load_perstage_less = stage_loads.second; + + constexpr auto buffer_load_stages_more = total_buffer_loads % stages_available; + + constexpr auto buffer_b_heavy_loads = buffer_load_perstage_more * buffer_load_stages_more; + constexpr auto buffer_b_remaining = + num_buffer_load_inst_b - buffer_load_perstage_more * buffer_load_stages_more; + + constexpr auto buffer_load_b_stages = + buffer_b_heavy_loads > num_buffer_load_inst_b + ? 
num_buffer_load_inst_b / buffer_load_perstage_more + : (buffer_load_stages_more + buffer_b_remaining / buffer_load_perstage_less); + + constexpr auto buffer_load_a_stages = + num_total_stages - DS_READ_A_PREFETCH_STAGES - buffer_load_b_stages; + + static_assert(buffer_load_a_stages > 0, + "The buffer load a stages should always have a value over 0."); + + constexpr auto buffer_load_issue_point_interval_more = + math::integer_divide_ceil(num_mfma_perstage, buffer_load_perstage_more); + constexpr auto buffer_load_issue_point_interval_less = + buffer_load_perstage_less == 0 + ? INT32_MAX + : math::integer_divide_ceil(num_mfma_perstage, buffer_load_perstage_less); + constexpr auto buffer_load_issue_point_a = num_mfma_perstage >= 3 ? 1 : 0; + + // B global read + static_for<0, buffer_load_b_stages, 1>{}([&](auto i) { + static_for<0, num_mfma_perstage, 1>{}([&](auto imfma) { + __builtin_amdgcn_sched_group_barrier(SCHED_GROUP_MFMA, 1, 0); + + if constexpr(((i < buffer_load_stages_more) && + (imfma % buffer_load_issue_point_interval_more == 0)) || + ((i >= buffer_load_stages_more) && + (imfma % buffer_load_issue_point_interval_less == 0))) + { + __builtin_amdgcn_sched_group_barrier(SCHED_GROUP_VMEM, 1, 0); + } + + if constexpr(imfma >= (num_mfma_perstage - num_ds_read_a_mfma_perstage)) + { __builtin_amdgcn_sched_group_barrier( - 0x008, staged_num_mfma_per_buffer_load_b, 0); // MFMA - __builtin_amdgcn_sched_group_barrier(0x020, 1, 0); // VMEM read - }); - - __builtin_amdgcn_sched_group_barrier(0x008, 1, 0); // MFMA - __builtin_amdgcn_sched_group_barrier(0x100, 1, 0); // DS read - __builtin_amdgcn_sched_group_barrier( - 0x008, staged_num_mfma_per_buffer_load_b - 1, 0); // MFMA - __builtin_amdgcn_sched_group_barrier(0x020, 1, 0); // VMEM read - }); - - __builtin_amdgcn_sched_barrier(0); - } - else if constexpr(stage.value == 1) - { - constexpr auto staged_num_mfma_per_ds_write_a = - math::integer_divide_ceil(staged_num_mfma, num_ds_write_inst_a); - - constexpr auto 
stage_more_mfma = - staged_num_mfma - (staged_num_mfma_per_ds_write_a - 1) * num_ds_write_inst_a; - - // A local write - static_for<0, num_ds_write_inst_a, 1>{}([&](auto i_inst) { - if constexpr(i_inst.value < stage_more_mfma) - { - if(i_inst.value < staged_num_ds_read_inst_a) - { - __builtin_amdgcn_sched_group_barrier( - 0x008, staged_num_mfma_per_ds_write_a - 1, 0); // MFMA - __builtin_amdgcn_sched_group_barrier(0x200, 1, 0); // DS Write - __builtin_amdgcn_sched_group_barrier(0x008, 1, 0); // MFMA - __builtin_amdgcn_sched_group_barrier(0x100, 1, 0); // DS read - } - else - { - __builtin_amdgcn_sched_group_barrier( - 0x008, staged_num_mfma_per_ds_write_a, 0); // MFMA - __builtin_amdgcn_sched_group_barrier(0x200, 1, 0); // DS Write - } - } - else - { - if(i_inst.value < staged_num_ds_read_inst_a) - { - __builtin_amdgcn_sched_group_barrier( - 0x008, staged_num_mfma_per_ds_write_a - 2, 0); // MFMA - __builtin_amdgcn_sched_group_barrier(0x200, 1, 0); // DS Write - __builtin_amdgcn_sched_group_barrier(0x008, 1, 0); // MFMA - __builtin_amdgcn_sched_group_barrier(0x100, 1, 0); // DS read - } - else - { - __builtin_amdgcn_sched_group_barrier( - 0x008, staged_num_mfma_per_ds_write_a - 1, 0); // MFMA - __builtin_amdgcn_sched_group_barrier(0x200, 1, 0); // DS Write - } + SCHED_GROUP_LDS_READ, ds_read_a_mfma_rate, 0); } }); - - __builtin_amdgcn_sched_barrier(0); - } - else if constexpr(stage.value == 2) - { - constexpr auto staged_num_mfma_per_buffer_load_a = - math::integer_divide_ceil(staged_num_mfma, num_buffer_load_inst_a); - - constexpr auto stage_more_mfma = - staged_num_mfma - (staged_num_mfma_per_buffer_load_a - 1) * num_buffer_load_inst_a; - - // A global - static_for<0, num_buffer_load_inst_a, 1>{}([&](auto i_inst) { - if constexpr(i_inst.value < stage_more_mfma) - { - if(i_inst.value < staged_num_ds_read_inst_a) - { - __builtin_amdgcn_sched_group_barrier( - 0x008, staged_num_mfma_per_buffer_load_a - 1, 0); // MFMA - __builtin_amdgcn_sched_group_barrier(0x020, 1, 
0); // VMEM read - __builtin_amdgcn_sched_group_barrier(0x008, 1, 0); // MFMA - __builtin_amdgcn_sched_group_barrier(0x100, 1, 0); // DS read - } - else - { - __builtin_amdgcn_sched_group_barrier( - 0x008, staged_num_mfma_per_buffer_load_a, 0); // MFMA - __builtin_amdgcn_sched_group_barrier(0x020, 1, 0); // VMEM read - } - } - else - { - if(i_inst.value < staged_num_ds_read_inst_a) - { - __builtin_amdgcn_sched_group_barrier( - 0x008, staged_num_mfma_per_buffer_load_a - 2, 0); // MFMA - __builtin_amdgcn_sched_group_barrier(0x020, 1, 0); // VMEM read - __builtin_amdgcn_sched_group_barrier(0x008, 1, 0); // MFMA - __builtin_amdgcn_sched_group_barrier(0x100, 1, 0); // DS read - } - else - { - __builtin_amdgcn_sched_group_barrier( - 0x008, staged_num_mfma_per_buffer_load_a - 1, 0); // MFMA - __builtin_amdgcn_sched_group_barrier(0x020, 1, 0); // VMEM read - } - } - }); - - __builtin_amdgcn_sched_barrier(0); - } - else - { - // A local Read - static_for<0, staged_num_ds_read_inst_a, 1>{}([&](auto i_inst) { - ignore = i_inst; - __builtin_amdgcn_sched_group_barrier( - 0x008, staged_num_mfma_per_ds_read_a, 0); // MFMA - __builtin_amdgcn_sched_group_barrier(0x100, 1, 0); // DS read - }); - - __builtin_amdgcn_sched_barrier(0); - } - } - - template - __device__ static constexpr auto EpilogueScheduler_1(Stage stage) - { - constexpr auto num_ds_read_inst_a = HotLoopInstList::A_LDS_Read_Inst_Num; - constexpr auto num_ds_write_inst_a = HotLoopInstList::A_LDS_Write_Inst_Num; - constexpr auto num_buffer_load_inst_b = MWaves * HotLoopInstList::B_Buffer_Load_Inst_Num; - - constexpr auto num_mfma = HotLoopInstList::C_MFMA_Inst_Num; - - constexpr auto staged_num_ds_read_inst_a = num_ds_read_inst_a / MRepeat; - constexpr auto staged_num_mfma = num_mfma / MRepeat; - - constexpr auto staged_num_mfma_per_ds_read_a = staged_num_mfma / staged_num_ds_read_inst_a; - - if constexpr(stage.value == 0) - { - constexpr auto staged_num_buffer_load_b_per_ds_read_a = - num_buffer_load_inst_b / 
staged_num_ds_read_inst_a; - constexpr auto staged_num_mfma_per_buffer_load_b = - staged_num_mfma / num_buffer_load_inst_b; - // B global - static_for<0, staged_num_ds_read_inst_a, 1>{}([&](auto i_inst) { - ignore = i_inst; - - static_for<0, staged_num_buffer_load_b_per_ds_read_a, 1>{}([&](auto ibuf_inst) { - ignore = ibuf_inst; - __builtin_amdgcn_sched_group_barrier( - 0x008, staged_num_mfma_per_buffer_load_b, 0); // MFMA - __builtin_amdgcn_sched_group_barrier(0x020, 1, 0); // VMEM read - }); - - __builtin_amdgcn_sched_group_barrier(0x008, 1, 0); // MFMA - __builtin_amdgcn_sched_group_barrier(0x100, 1, 0); // DS read - __builtin_amdgcn_sched_group_barrier( - 0x008, staged_num_mfma_per_buffer_load_b - 1, 0); // MFMA - __builtin_amdgcn_sched_group_barrier(0x020, 1, 0); // VMEM read - }); - - __builtin_amdgcn_sched_barrier(0); - } - else if constexpr(stage.value == 1) - { -#if 0 - constexpr auto staged_num_ds_write_a_per_ds_read_a = - num_ds_write_inst_a / staged_num_ds_read_inst_a; - constexpr auto staged_num_mfma_per_ds_write_a = staged_num_mfma / num_ds_write_inst_a; - // A local write - static_for<0, staged_num_ds_read_inst_a, 1>{}([&](auto i_inst) { - ignore = i_inst; - - static_for<0, staged_num_ds_write_a_per_ds_read_a, 1>{}([&](auto idswrite_inst) { - ignore = idswrite_inst; - __builtin_amdgcn_sched_group_barrier( - 0x008, staged_num_mfma_per_ds_write_a - 1, 0); // MFMA - __builtin_amdgcn_sched_group_barrier(0x200, 1, 0); // DS Write - }); - - __builtin_amdgcn_sched_group_barrier( - 0x008, staged_num_ds_write_a_per_ds_read_a, 0); // MFMA - __builtin_amdgcn_sched_group_barrier(0x100, 1, 0); // DS read - }); -#elif 1 - constexpr auto staged_num_mfma_per_ds_write_a = - math::integer_divide_ceil(staged_num_mfma, num_ds_write_inst_a); - - constexpr auto stage_more_mfma = - staged_num_mfma - (staged_num_mfma_per_ds_write_a - 1) * num_ds_write_inst_a; - - // A local write - static_for<0, num_ds_write_inst_a, 1>{}([&](auto i_inst) { - if constexpr(i_inst.value < 
stage_more_mfma) - { - if(i_inst.value < staged_num_ds_read_inst_a) - { - __builtin_amdgcn_sched_group_barrier( - 0x008, staged_num_mfma_per_ds_write_a - 1, 0); // MFMA - __builtin_amdgcn_sched_group_barrier(0x200, 1, 0); // DS Write - __builtin_amdgcn_sched_group_barrier(0x008, 1, 0); // MFMA - __builtin_amdgcn_sched_group_barrier(0x100, 1, 0); // DS read - } - else - { - __builtin_amdgcn_sched_group_barrier( - 0x008, staged_num_mfma_per_ds_write_a, 0); // MFMA - __builtin_amdgcn_sched_group_barrier(0x200, 1, 0); // DS Write - } - } - else - { - if(i_inst.value < staged_num_ds_read_inst_a) - { - __builtin_amdgcn_sched_group_barrier( - 0x008, staged_num_mfma_per_ds_write_a - 2, 0); // MFMA - __builtin_amdgcn_sched_group_barrier(0x200, 1, 0); // DS Write - __builtin_amdgcn_sched_group_barrier(0x008, 1, 0); // MFMA - __builtin_amdgcn_sched_group_barrier(0x100, 1, 0); // DS read - } - else - { - __builtin_amdgcn_sched_group_barrier( - 0x008, staged_num_mfma_per_ds_write_a - 1, 0); // MFMA - __builtin_amdgcn_sched_group_barrier(0x200, 1, 0); // DS Write - } - } - }); -#endif - __builtin_amdgcn_sched_barrier(0); - } - else - { - // A local Read - static_for<0, staged_num_ds_read_inst_a, 1>{}([&](auto i_inst) { - ignore = i_inst; - __builtin_amdgcn_sched_group_barrier( - 0x008, staged_num_mfma_per_ds_read_a, 0); // MFMA - __builtin_amdgcn_sched_group_barrier(0x100, 1, 0); // DS read - }); - - __builtin_amdgcn_sched_barrier(0); - } - } - - __device__ static constexpr auto EpilogueScheduler_2() - { - constexpr auto num_ds_read_inst_a = HotLoopInstList::A_LDS_Read_Inst_Num; - - constexpr auto num_mfma = HotLoopInstList::C_MFMA_Inst_Num; - - constexpr auto staged_num_ds_read_inst_a = num_ds_read_inst_a / MRepeat; - constexpr auto staged_num_mfma = num_mfma / MRepeat; - - constexpr auto staged_num_mfma_per_ds_read_a = staged_num_mfma / staged_num_ds_read_inst_a; - - // A local Read - static_for<0, staged_num_ds_read_inst_a, 1>{}([&](auto i_inst) { - ignore = i_inst; - 
__builtin_amdgcn_sched_group_barrier(0x008, staged_num_mfma_per_ds_read_a, 0); // MFMA - __builtin_amdgcn_sched_group_barrier(0x100, 1, 0); // DS read }); - __builtin_amdgcn_sched_barrier(0); + // A global read + A local write + static_for<0, buffer_load_a_stages, 1>{}([&](auto i) { + static_for<0, num_mfma_perstage, 1>{}([&](auto imfma) { + __builtin_amdgcn_sched_group_barrier(SCHED_GROUP_MFMA, 1, 0); + if constexpr((((i + buffer_load_b_stages) < buffer_load_stages_more) && + (imfma % buffer_load_issue_point_interval_more == 0)) || + (((i + buffer_load_b_stages) >= buffer_load_stages_more) && + (imfma % buffer_load_issue_point_interval_less == 0))) + { + __builtin_amdgcn_sched_group_barrier(SCHED_GROUP_LDS_WRITE, 1, 0); + } + if constexpr((((i + buffer_load_b_stages) < buffer_load_stages_more) && + (imfma % buffer_load_issue_point_interval_more == + buffer_load_issue_point_a)) || + (((i + buffer_load_b_stages) >= buffer_load_stages_more) && + (imfma % buffer_load_issue_point_interval_less == + buffer_load_issue_point_a))) + { + __builtin_amdgcn_sched_group_barrier(SCHED_GROUP_VMEM, 1, 0); + } + if constexpr(imfma >= (num_mfma_perstage - num_ds_read_a_mfma_perstage)) + { + __builtin_amdgcn_sched_group_barrier( + SCHED_GROUP_LDS_READ, ds_read_a_mfma_rate, 0); + } + }); + }); + + // lds synchronization, prefetch next loop local A + static_for<0, DS_READ_A_PREFETCH_STAGES, 1>{}([&](auto i) { + ignore = i; + static_for<0, num_mfma_perstage, 1>{}([&](auto imfma) { + __builtin_amdgcn_sched_group_barrier(SCHED_GROUP_MFMA, 1, 0); + if constexpr(imfma >= (num_mfma_perstage - num_ds_read_a_mfma_perstage)) + { + __builtin_amdgcn_sched_group_barrier( + SCHED_GROUP_LDS_READ, ds_read_a_mfma_rate, 0); + } + }); + }); } template {}([&](auto k0) { - a_thread_copy_.Run(a_block_desc_m0_m1_m2_k0_k1_k2, - make_tuple(I0, I0, I0, k0, I0, I0), - a_block_buf.At(I0), - a_thread_desc_, - make_tuple(I0, I0, I0, k0, I0, I0), - a_thread_buf); + static_for<0, DS_READ_A_PREFETCH_STAGES, 
1>{}([&](auto m0) { + static_for<0, KRepeat, 1>{}([&](auto k0) { + static_for<0, KGroup, 1>{}([&](auto kg0) { + // K = k0 × KGroup × k1 = k0 × kg0 × A_K1 + a_thread_copy_.Run(a_block_desc_m0_m1_m2_k0_k1_k2, + make_tuple(m0, I0, I0, Number{}, I0, I0), + a_block_buf.At(I0), + a_thread_desc_, + make_tuple(m0, I0, I0, k0, I0, Number{}), + a_thread_buf); + }); + }); }); // Initialize C @@ -558,26 +404,18 @@ struct BlockwiseGemmXdlops_pipeline_bpreshuffle_v3{}([&](auto m0) { - if constexpr(m0.value == 0) - { - b_blockwise_copy.Run(b_grid_desc, - b_grid_buf, - b_block_desc_n0_n1_k0_k1, - b_block_origin_idx, - b_thread_bufs(local_read_buf)); - b_blockwise_copy.MoveSrcSliceWindow(b_grid_desc, b_block_copy_step); - } - else if constexpr(m0.value == 1) - { - a_blockwise_copy.RunWrite(a_block_desc, a_block_buf.At(local_read_buf)); - } - else if constexpr(m0.value == 2) - { - a_blockwise_copy.RunRead(a_grid_desc, a_grid_buf); - a_blockwise_copy.MoveSrcSliceWindow(a_grid_desc, a_block_copy_step); - } + b_blockwise_copy.Run(b_grid_desc, + b_grid_buf, + b_block_desc_n0_n1_k0_k1, + b_block_origin_idx, + b_thread_bufs(local_read_buf)); + b_blockwise_copy.MoveSrcSliceWindow(b_grid_desc, b_block_copy_step); + a_blockwise_copy.RunWrite(a_block_desc, a_block_buf.At(local_read_buf)); + a_blockwise_copy.RunRead(a_grid_desc, a_grid_buf); + a_blockwise_copy.MoveSrcSliceWindow(a_grid_desc, a_block_copy_step); + + static_for<0, MRepeat, 1>{}([&](auto m0) { static_for<0, KRepeat, 1>{}([&](auto k0) { static_for<0, NRepeat, 1>{}([&](auto n0) { vector_type a_thread_vec; @@ -613,49 +451,88 @@ struct BlockwiseGemmXdlops_pipeline_bpreshuffle_v3{}([&](auto k0) { - a_thread_copy_.Run( - a_block_desc_m0_m1_m2_k0_k1_k2, - make_tuple(Number<(m0 + 1) % MRepeat>{}, I0, I0, k0, I0, I0), - a_block_buf.At(local_read_buf), - a_thread_desc_, - make_tuple( - Number<(m0 + 1 + HotloopLocalBufSwitch * mfma_reg_buf) % - 2>{}, - I0, - I0, - k0, - I0, - I0), - a_thread_buf); + static_for<0, KGroup, 1>{}([&](auto kg0) 
{ + a_thread_copy_.Run( + a_block_desc_m0_m1_m2_k0_k1_k2, + make_tuple(Number<0>{}, + I0, + I0, + Number{}, + I0, + I0), + a_block_buf.At(local_read_buf), + a_thread_desc_, + make_tuple( + Number<(m0 + 2 + HotloopLocalBufSwitch * mfma_reg_buf) % + 2>{}, + I0, + I0, + k0, + I0, + Number{}), + a_thread_buf); + }); + }); + } + else if constexpr(m0.value == (MRepeat - 1)) + { + static_for<0, KRepeat, 1>{}([&](auto k0) { + static_for<0, KGroup, 1>{}([&](auto kg0) { + a_thread_copy_.Run( + a_block_desc_m0_m1_m2_k0_k1_k2, + make_tuple(Number<(m0 + 2) % MRepeat>{}, + I0, + I0, + Number{}, + I0, + I0), + a_block_buf.At(local_read_buf), + a_thread_desc_, + make_tuple( + Number<(m0 + 2 + HotloopLocalBufSwitch * mfma_reg_buf) % + 2>{}, + I0, + I0, + k0, + I0, + Number{}), + a_thread_buf); + }); }); } else { static_for<0, KRepeat, 1>{}([&](auto k0) { - a_thread_copy_.Run( - a_block_desc_m0_m1_m2_k0_k1_k2, - make_tuple(Number<(m0 + 1) % MRepeat>{}, I0, I0, k0, I0, I0), - a_block_buf.At(mfma_reg_buf), - a_thread_desc_, - make_tuple( - Number<(m0 + 1 + HotloopLocalBufSwitch * mfma_reg_buf) % - 2>{}, - I0, - I0, - k0, - I0, - I0), - a_thread_buf); + static_for<0, KGroup, 1>{}([&](auto kg0) { + a_thread_copy_.Run( + a_block_desc_m0_m1_m2_k0_k1_k2, + make_tuple(Number<(m0 + 2) % MRepeat>{}, + I0, + I0, + Number{}, + I0, + I0), + a_block_buf.At(mfma_reg_buf), + a_thread_desc_, + make_tuple( + Number<(m0 + 2 + HotloopLocalBufSwitch * mfma_reg_buf) % + 2>{}, + I0, + I0, + k0, + I0, + Number{}), + a_thread_buf); + }); }); } - - HotLoopScheduler(m0); }); + HotLoopScheduler(); }; LoopFunc(I0, I1); @@ -667,20 +544,14 @@ struct BlockwiseGemmXdlops_pipeline_bpreshuffle_v3{}([&](auto m0) { - if constexpr(m0.value == 0) - { - b_blockwise_copy.Run(b_grid_desc, - b_grid_buf, - b_block_desc_n0_n1_k0_k1, - b_block_origin_idx, - b_thread_bufs(I1)); - } - else if constexpr(m0.value == MRepeat - 1) - { - a_blockwise_copy.RunWrite(a_block_desc, a_block_buf.At(I1)); - } + 
b_blockwise_copy.Run(b_grid_desc, + b_grid_buf, + b_block_desc_n0_n1_k0_k1, + b_block_origin_idx, + b_thread_bufs(I1)); + a_blockwise_copy.RunWrite(a_block_desc, a_block_buf.At(I1)); + static_for<0, MRepeat, 1>{}([&](auto m0) { static_for<0, KRepeat, 1>{}([&](auto k0) { static_for<0, NRepeat, 1>{}([&](auto n0) { vector_type a_thread_vec; @@ -707,36 +578,68 @@ struct BlockwiseGemmXdlops_pipeline_bpreshuffle_v3{}([&](auto k0) { - a_thread_copy_.Run( - a_block_desc_m0_m1_m2_k0_k1_k2, - make_tuple(Number<(m0 + 1) % MRepeat>{}, I0, I0, k0, I0, I0), - a_block_buf.At(I1), - a_thread_desc_, - make_tuple(Number<(m0 + 1) % 2>{}, I0, I0, k0, I0, I0), - a_thread_buf); + static_for<0, KGroup, 1>{}([&](auto kg0) { + a_thread_copy_.Run( + a_block_desc_m0_m1_m2_k0_k1_k2, + make_tuple( + Number<0>{}, I0, I0, Number{}, I0, I0), + a_block_buf.At(I1), + a_thread_desc_, + make_tuple( + Number<(m0 + 2) % 2>{}, I0, I0, k0, I0, Number{}), + a_thread_buf); + }); + }); + } + else if constexpr(m0.value == (MRepeat - 1)) + { + static_for<0, KRepeat, 1>{}([&](auto k0) { + static_for<0, KGroup, 1>{}([&](auto kg0) { + a_thread_copy_.Run( + a_block_desc_m0_m1_m2_k0_k1_k2, + make_tuple(Number<(m0 + 2) % MRepeat>{}, + I0, + I0, + Number{}, + I0, + I0), + a_block_buf.At(I1), + a_thread_desc_, + make_tuple( + Number<(m0 + 2) % 2>{}, I0, I0, k0, I0, Number{}), + a_thread_buf); + }); }); } else { static_for<0, KRepeat, 1>{}([&](auto k0) { - a_thread_copy_.Run( - a_block_desc_m0_m1_m2_k0_k1_k2, - make_tuple(Number<(m0 + 1) % MRepeat>{}, I0, I0, k0, I0, I0), - a_block_buf.At(I0), - a_thread_desc_, - make_tuple(Number<(m0 + 1) % 2>{}, I0, I0, k0, I0, I0), - a_thread_buf); + static_for<0, KGroup, 1>{}([&](auto kg0) { + a_thread_copy_.Run( + a_block_desc_m0_m1_m2_k0_k1_k2, + make_tuple(Number<(m0 + 2) % MRepeat>{}, + I0, + I0, + Number{}, + I0, + I0), + a_block_buf.At(I0), + a_thread_desc_, + make_tuple( + Number<(m0 + 2) % 2>{}, I0, I0, k0, I0, Number{}), + a_thread_buf); + }); }); } - - 
EpilogueScheduler_1(m0); }); + HotLoopScheduler(); + static_for<0, MRepeat, 1>{}([&](auto m0) { static_for<0, KRepeat, 1>{}([&](auto k0) { static_for<0, NRepeat, 1>{}([&](auto n0) { @@ -764,25 +667,29 @@ struct BlockwiseGemmXdlops_pipeline_bpreshuffle_v3{}([&](auto k0) { - a_thread_copy_.Run( - a_block_desc_m0_m1_m2_k0_k1_k2, - make_tuple(Number{}, I0, I0, k0, I0, I0), - a_block_buf.At(I1), - a_thread_desc_, - make_tuple( - Number<(m0 + 1 + HotloopLocalBufSwitch) % 2>{}, I0, I0, k0, I0, I0), - a_thread_buf); + static_for<0, KGroup, 1>{}([&](auto kg0) { + a_thread_copy_.Run( + a_block_desc_m0_m1_m2_k0_k1_k2, + make_tuple( + Number{}, I0, I0, Number{}, I0, I0), + a_block_buf.At(I1), + a_thread_desc_, + make_tuple(Number<(m0 + 2 + HotloopLocalBufSwitch) % 2>{}, + I0, + I0, + k0, + I0, + Number{}), + a_thread_buf); + }); }); - - EpilogueScheduler_2(); } }); - // Let's leak last MFMA block to epilogue region, cover the potential lds-shuffle - // latency - // __builtin_amdgcn_sched_barrier(0); + + HotLoopScheduler(); } else if constexpr(TailNum == TailNumber::Odd) { @@ -813,18 +720,21 @@ struct BlockwiseGemmXdlops_pipeline_bpreshuffle_v3{}([&](auto k0) { - a_thread_copy_.Run(a_block_desc_m0_m1_m2_k0_k1_k2, - make_tuple(Number{}, I0, I0, k0, I0, I0), - a_block_buf.At(I0), - a_thread_desc_, - make_tuple(Number<(m0 + 1) % 2>{}, I0, I0, k0, I0, I0), - a_thread_buf); + static_for<0, KGroup, 1>{}([&](auto kg0) { + a_thread_copy_.Run( + a_block_desc_m0_m1_m2_k0_k1_k2, + make_tuple( + Number{}, I0, I0, Number{}, I0, I0), + a_block_buf.At(I0), + a_thread_desc_, + make_tuple( + Number<(m0 + 2) % 2>{}, I0, I0, k0, I0, Number{}), + a_thread_buf); + }); }); - - EpilogueScheduler_2(); } }); } @@ -841,7 +751,7 @@ struct BlockwiseGemmXdlops_pipeline_bpreshuffle_v3, + Sequence<1, 1, 1, 1, 1, KPack / KGroup>, Sequence<0, 1, 2, 3, 4, 5>, 5, A_K1, diff --git a/include/ck/tensor_operation/gpu/block/blockwise_gemm_pipeline_xdlops_base.hpp 
b/include/ck/tensor_operation/gpu/block/blockwise_gemm_pipeline_xdlops_base.hpp index ce507ca8d3..6c1c5b1c4d 100644 --- a/include/ck/tensor_operation/gpu/block/blockwise_gemm_pipeline_xdlops_base.hpp +++ b/include/ck/tensor_operation/gpu/block/blockwise_gemm_pipeline_xdlops_base.hpp @@ -58,6 +58,11 @@ struct BlockwiseGemmXdlops_pipeline_base static constexpr index_t KPerThread = KPerBlock / xdlops_gemm.K0PerXdlops; static constexpr index_t KRepeat = KPerThread / KPack; static constexpr index_t KPerInnerLoop = KPack; + static constexpr index_t KGroup = + ((MPerXDL == 16 && MPerXDL == 16 && xdlops_gemm.KPerXdlops == 128) || + (MPerXDL == 32 && MPerXDL == 32 && xdlops_gemm.KPerXdlops == 64)) + ? 2 + : 1; static constexpr index_t MWaves = MPerBlock / (MRepeat * MPerXDL); static constexpr index_t NWaves = NPerBlock / (NRepeat * NPerXDL); diff --git a/include/ck/tensor_operation/gpu/grid/gridwise_gemm_xdl_cshuffle_v3_multi_d_b_preshuffle.hpp b/include/ck/tensor_operation/gpu/grid/gridwise_gemm_xdl_cshuffle_v3_multi_d_b_preshuffle.hpp index 238ab14606..c0d9464136 100644 --- a/include/ck/tensor_operation/gpu/grid/gridwise_gemm_xdl_cshuffle_v3_multi_d_b_preshuffle.hpp +++ b/include/ck/tensor_operation/gpu/grid/gridwise_gemm_xdl_cshuffle_v3_multi_d_b_preshuffle.hpp @@ -167,11 +167,13 @@ struct GridwiseGemmMultiD_xdl_cshuffle_v3_b_preshuffle using mfma_selector = MfmaSelector; static constexpr index_t KPack = math::max(math::lcm(AK1Number, BK1Number), mfma_selector::selected_mfma.k_per_blk); + static constexpr index_t KGroup = mfma_selector::selected_mfma.k_per_blk == 32 ? 
2 : 1; static constexpr index_t KLane = mfma_selector::GetKPerXdlops() / mfma_selector::GetK1PerXdlops(); - static constexpr index_t KRepeat = KPerBlock / KLane / KPack; - static constexpr index_t NLane = NPerXdl; - static constexpr index_t NWave = NPerBlock / NPerXdl / NXdlPerWave; + static constexpr index_t KPackPerGroup = KPack / KGroup; + static constexpr index_t KRepeat = KPerBlock / KLane / KPackPerGroup; + static constexpr index_t NLane = NPerXdl; + static constexpr index_t NWave = NPerBlock / NPerXdl / NXdlPerWave; static constexpr auto MakeDsGridPointer() { @@ -209,7 +211,7 @@ struct GridwiseGemmMultiD_xdl_cshuffle_v3_b_preshuffle } __host__ __device__ static auto CalculateBK0Shuffled(index_t K) { - return math::integer_divide_ceil(K, KLane * KPack); + return math::integer_divide_ceil(K, KLane * KPackPerGroup); } __host__ __device__ static auto CalculateKPadded(index_t K) @@ -351,7 +353,7 @@ struct GridwiseGemmMultiD_xdl_cshuffle_v3_b_preshuffle __host__ __device__ static auto MakeBGridDescriptor_Preshuffled(index_t N0, index_t K0) { - constexpr index_t NkSwizzleNumber = Number{}; + constexpr index_t NkSwizzleNumber = Number{}; return make_naive_tensor_descriptor( make_tuple(N0 / NWave, NWave, K0, NkSwizzleNumber), make_tuple(NWave * K0 * NkSwizzleNumber, K0 * NkSwizzleNumber, NkSwizzleNumber, I1)); @@ -1228,7 +1230,7 @@ struct GridwiseGemmMultiD_xdl_cshuffle_v3_b_preshuffle make_multi_index(n_block_data_idx_on_grid, get_warp_local_1d_id() % NWave, 0, - KPack * (get_thread_local_1d_id() % warpSize))); + KPackPerGroup * (get_thread_local_1d_id() % warpSize))); // LDS allocation for A and B: be careful of alignment // Cast after lds @@ -1668,7 +1670,7 @@ struct GridwiseGemmMultiD_xdl_cshuffle_v3_b_preshuffle make_multi_index(n_block_data_idx_on_grid, get_warp_local_1d_id() % NWave, 0, - KPack * (get_thread_local_1d_id() % warpSize))); + KPackPerGroup * (get_thread_local_1d_id() % warpSize))); // LDS allocation for A and B: be careful of alignment // Cast 
after lds diff --git a/include/ck/utility/blkgemmpipe_scheduler.hpp b/include/ck/utility/blkgemmpipe_scheduler.hpp index 39407cb8f6..6c788fb41e 100644 --- a/include/ck/utility/blkgemmpipe_scheduler.hpp +++ b/include/ck/utility/blkgemmpipe_scheduler.hpp @@ -48,6 +48,15 @@ enum struct TailNumber // prefetchstages Full, }; + +enum SchedulerGroup : uint32_t +{ + SCHED_GROUP_MFMA = 0x008, // Matrix FMA instructions + SCHED_GROUP_VMEM = 0x020, // Global memory operations + SCHED_GROUP_LDS_READ = 0x100, // LDS read operations + SCHED_GROUP_LDS_WRITE = 0x200 // LDS write operations +}; + template , - Row, - F8, - F8, - Tuple, - F16, - PassThrough, - PassThrough, - MultiplyMultiply>>>& - instances); - -void add_device_gemm_multiply_multiply_weight_preshuffle_xdl_f8_f8_f16_mk_mfma_mn_p2_default_instances( - std::vector, - Row, - F8, - F8, - Tuple, - F16, - PassThrough, - PassThrough, - MultiplyMultiply>>>& - instances); - -void add_device_gemm_multiply_multiply_weight_preshuffle_xdl_f8_f8_f16_mk_mfma_mn_p3_default_instances( - std::vector, - Row, - F8, - F8, - Tuple, - F16, - PassThrough, - PassThrough, - MultiplyMultiply>>>& - instances); - -void add_device_gemm_multiply_multiply_weight_preshuffle_xdl_f8_f8_f16_mk_mfma_mn_p4_default_instances( - std::vector, - Row, - F8, - F8, - Tuple, - F16, - PassThrough, - PassThrough, - MultiplyMultiply>>>& - instances); - -void add_device_gemm_multiply_multiply_weight_preshuffle_xdl_f8_f8_f16_mk_mfma_mn_p5_default_instances( - std::vector, - Row, - F8, - F8, - Tuple, - F16, - PassThrough, - PassThrough, - MultiplyMultiply>>>& - instances); - -void add_device_gemm_multiply_multiply_weight_preshuffle_xdl_f8_f8_f16_mk_mfma_mn_p1_default_instances_v2( - std::vector, - Row, - F8, - F8, - Tuple, - F16, - PassThrough, - PassThrough, - MultiplyMultiply>>>& - instances); - -void add_device_gemm_multiply_multiply_weight_preshuffle_xdl_f8_f8_f16_mk_mfma_mn_p2_default_instances_v2( - std::vector, - Row, - F8, - F8, - Tuple, - F16, - PassThrough, - 
PassThrough, - MultiplyMultiply>>>& - instances); - -void add_device_gemm_multiply_multiply_weight_preshuffle_xdl_f8_f8_f16_mk_mfma_mn_p3_default_instances_v2( - std::vector, - Row, - F8, - F8, - Tuple, - F16, - PassThrough, - PassThrough, - MultiplyMultiply>>>& - instances); - -void add_device_gemm_multiply_multiply_weight_preshuffle_xdl_f8_f8_f16_mk_mfma_mn_p4_default_instances_v2( - std::vector, - Row, - F8, - F8, - Tuple, - F16, - PassThrough, - PassThrough, - MultiplyMultiply>>>& - instances); - -void add_device_gemm_multiply_multiply_weight_preshuffle_xdl_f8_f8_f16_mk_mfma_mn_p5_default_instances_v2( - std::vector, - Row, - F8, - F8, - Tuple, - F16, - PassThrough, - PassThrough, - MultiplyMultiply>>>& - instances); - -void add_device_gemm_multiply_multiply_weight_preshuffle_xdl_f8_f8_f16_mk_mfma_mn_compute_default_instances_p1( - std::vector, - Row, - F8, - F8, - Tuple, - F16, - PassThrough, - PassThrough, - MultiplyMultiply>>>& - instances); - -void add_device_gemm_multiply_multiply_weight_preshuffle_xdl_f8_f8_f16_mk_mfma_mn_compute_default_instances_p2( - std::vector, - Row, - F8, - F8, - Tuple, - F16, - PassThrough, - PassThrough, - MultiplyMultiply>>>& - instances); void add_device_gemm_multiply_multiply_weight_preshuffle_xdl_f8_f8_f16_mk_mfma16x16_mn_compute_default_instances_p1( std::vector, - Row, - F8, - F8, - Tuple, - BF16, - PassThrough, - PassThrough, - MultiplyMultiply>>>& - instances); - -void add_device_gemm_multiply_multiply_weight_preshuffle_xdl_f8_f8_bf16_mk_mfma_mn_p2_default_instances( - std::vector, - Row, - F8, - F8, - Tuple, - BF16, - PassThrough, - PassThrough, - MultiplyMultiply>>>& - instances); - -void add_device_gemm_multiply_multiply_weight_preshuffle_xdl_f8_f8_bf16_mk_mfma_mn_p3_default_instances( - std::vector, - Row, - F8, - F8, - Tuple, - BF16, - PassThrough, - PassThrough, - MultiplyMultiply>>>& - instances); - -void add_device_gemm_multiply_multiply_weight_preshuffle_xdl_f8_f8_bf16_mk_mfma_mn_p4_default_instances( - 
std::vector, - Row, - F8, - F8, - Tuple, - BF16, - PassThrough, - PassThrough, - MultiplyMultiply>>>& - instances); - -void add_device_gemm_multiply_multiply_weight_preshuffle_xdl_f8_f8_bf16_mk_mfma_mn_p5_default_instances( - std::vector, - Row, - F8, - F8, - Tuple, - BF16, - PassThrough, - PassThrough, - MultiplyMultiply>>>& - instances); - -void add_device_gemm_multiply_multiply_weight_preshuffle_xdl_f8_f8_bf16_mk_mfma_mn_p1_default_instances_v2( - std::vector, - Row, - F8, - F8, - Tuple, - BF16, - PassThrough, - PassThrough, - MultiplyMultiply>>>& - instances); - -void add_device_gemm_multiply_multiply_weight_preshuffle_xdl_f8_f8_bf16_mk_mfma_mn_p2_default_instances_v2( - std::vector, - Row, - F8, - F8, - Tuple, - BF16, - PassThrough, - PassThrough, - MultiplyMultiply>>>& - instances); - -void add_device_gemm_multiply_multiply_weight_preshuffle_xdl_f8_f8_bf16_mk_mfma_mn_p3_default_instances_v2( - std::vector, - Row, - F8, - F8, - Tuple, - BF16, - PassThrough, - PassThrough, - MultiplyMultiply>>>& - instances); - -void add_device_gemm_multiply_multiply_weight_preshuffle_xdl_f8_f8_bf16_mk_mfma_mn_p4_default_instances_v2( - std::vector, - Row, - F8, - F8, - Tuple, - BF16, - PassThrough, - PassThrough, - MultiplyMultiply>>>& - instances); - -void add_device_gemm_multiply_multiply_weight_preshuffle_xdl_f8_f8_bf16_mk_mfma_mn_p5_default_instances_v2( - std::vector, - Row, - F8, - F8, - Tuple, - BF16, - PassThrough, - PassThrough, - MultiplyMultiply>>>& - instances); - -void add_device_gemm_multiply_multiply_weight_preshuffle_xdl_f8_f8_bf16_mk_mfma_mn_compute_default_instances_p1( - std::vector, - Row, - F8, - F8, - Tuple, - BF16, - PassThrough, - PassThrough, - MultiplyMultiply>>>& - instances); - -void add_device_gemm_multiply_multiply_weight_preshuffle_xdl_f8_f8_bf16_mk_mfma_mn_compute_default_instances_p2( - std::vector, - Row, - F8, - F8, - Tuple, - BF16, - PassThrough, - PassThrough, - MultiplyMultiply>>>& - instances); - void 
add_device_gemm_multiply_multiply_weight_preshuffle_xdl_f8_f8_bf16_mk_mfma16x16_mn_compute_default_instances_p1( std::vector && is_same_v && is_same_v) { - add_device_gemm_multiply_multiply_weight_preshuffle_xdl_f8_f8_f16_mk_mfma_mn_p1_default_instances( - op_ptrs); - add_device_gemm_multiply_multiply_weight_preshuffle_xdl_f8_f8_f16_mk_mfma_mn_p2_default_instances( - op_ptrs); - add_device_gemm_multiply_multiply_weight_preshuffle_xdl_f8_f8_f16_mk_mfma_mn_p3_default_instances( - op_ptrs); - add_device_gemm_multiply_multiply_weight_preshuffle_xdl_f8_f8_f16_mk_mfma_mn_p4_default_instances( - op_ptrs); - add_device_gemm_multiply_multiply_weight_preshuffle_xdl_f8_f8_f16_mk_mfma_mn_p5_default_instances( - op_ptrs); - - add_device_gemm_multiply_multiply_weight_preshuffle_xdl_f8_f8_f16_mk_mfma_mn_p1_default_instances_v2( - op_ptrs); - add_device_gemm_multiply_multiply_weight_preshuffle_xdl_f8_f8_f16_mk_mfma_mn_p2_default_instances_v2( - op_ptrs); - add_device_gemm_multiply_multiply_weight_preshuffle_xdl_f8_f8_f16_mk_mfma_mn_p3_default_instances_v2( - op_ptrs); - add_device_gemm_multiply_multiply_weight_preshuffle_xdl_f8_f8_f16_mk_mfma_mn_p4_default_instances_v2( - op_ptrs); - add_device_gemm_multiply_multiply_weight_preshuffle_xdl_f8_f8_f16_mk_mfma_mn_p5_default_instances_v2( - op_ptrs); - - add_device_gemm_multiply_multiply_weight_preshuffle_xdl_f8_f8_f16_mk_mfma_mn_compute_default_instances_p1( - op_ptrs); - add_device_gemm_multiply_multiply_weight_preshuffle_xdl_f8_f8_f16_mk_mfma_mn_compute_default_instances_p2( - op_ptrs); - add_device_gemm_multiply_multiply_weight_preshuffle_xdl_f8_f8_f16_mk_mfma16x16_mn_compute_default_instances_p1( op_ptrs); add_device_gemm_multiply_multiply_weight_preshuffle_xdl_f8_f8_f16_mk_mfma16x16_mn_compute_default_instances_p2( @@ -612,33 +250,6 @@ struct DeviceOperationInstanceFactory< if constexpr(is_same_v && is_same_v && is_same_v) { - add_device_gemm_multiply_multiply_weight_preshuffle_xdl_f8_f8_bf16_mk_mfma_mn_p1_default_instances( - 
op_ptrs); - add_device_gemm_multiply_multiply_weight_preshuffle_xdl_f8_f8_bf16_mk_mfma_mn_p2_default_instances( - op_ptrs); - add_device_gemm_multiply_multiply_weight_preshuffle_xdl_f8_f8_bf16_mk_mfma_mn_p3_default_instances( - op_ptrs); - add_device_gemm_multiply_multiply_weight_preshuffle_xdl_f8_f8_bf16_mk_mfma_mn_p4_default_instances( - op_ptrs); - add_device_gemm_multiply_multiply_weight_preshuffle_xdl_f8_f8_bf16_mk_mfma_mn_p5_default_instances( - op_ptrs); - - add_device_gemm_multiply_multiply_weight_preshuffle_xdl_f8_f8_bf16_mk_mfma_mn_p1_default_instances_v2( - op_ptrs); - add_device_gemm_multiply_multiply_weight_preshuffle_xdl_f8_f8_bf16_mk_mfma_mn_p2_default_instances_v2( - op_ptrs); - add_device_gemm_multiply_multiply_weight_preshuffle_xdl_f8_f8_bf16_mk_mfma_mn_p3_default_instances_v2( - op_ptrs); - add_device_gemm_multiply_multiply_weight_preshuffle_xdl_f8_f8_bf16_mk_mfma_mn_p4_default_instances_v2( - op_ptrs); - add_device_gemm_multiply_multiply_weight_preshuffle_xdl_f8_f8_bf16_mk_mfma_mn_p5_default_instances_v2( - op_ptrs); - - add_device_gemm_multiply_multiply_weight_preshuffle_xdl_f8_f8_bf16_mk_mfma_mn_compute_default_instances_p1( - op_ptrs); - add_device_gemm_multiply_multiply_weight_preshuffle_xdl_f8_f8_bf16_mk_mfma_mn_compute_default_instances_p2( - op_ptrs); - add_device_gemm_multiply_multiply_weight_preshuffle_xdl_f8_f8_bf16_mk_mfma16x16_mn_compute_default_instances_p1( op_ptrs); add_device_gemm_multiply_multiply_weight_preshuffle_xdl_f8_f8_bf16_mk_mfma16x16_mn_compute_default_instances_p2( diff --git a/library/src/tensor_operation_instance/gpu/gemm_multiply_multiply_wp/CMakeLists.txt b/library/src/tensor_operation_instance/gpu/gemm_multiply_multiply_wp/CMakeLists.txt index 37233ac5b4..743a0272f7 100644 --- a/library/src/tensor_operation_instance/gpu/gemm_multiply_multiply_wp/CMakeLists.txt +++ b/library/src/tensor_operation_instance/gpu/gemm_multiply_multiply_wp/CMakeLists.txt @@ -2,18 +2,6 @@ 
set(GEMM_MULTIPLY_MULTIPLY_WEIGHT_PRESHUFFLE_INSTANCES) list(APPEND GEMM_MULTIPLY_MULTIPLY_WEIGHT_PRESHUFFLE_INSTANCES - f8_f8_bf16/device_gemm_multiply_multiply_wp_xdl_f8_f8_bf16_mk_mfma_mn_p1_default_instance.cpp - f8_f8_bf16/device_gemm_multiply_multiply_wp_xdl_f8_f8_bf16_mk_mfma_mn_p2_default_instance.cpp - f8_f8_bf16/device_gemm_multiply_multiply_wp_xdl_f8_f8_bf16_mk_mfma_mn_p3_default_instance.cpp - f8_f8_bf16/device_gemm_multiply_multiply_wp_xdl_f8_f8_bf16_mk_mfma_mn_p4_default_instance.cpp - f8_f8_bf16/device_gemm_multiply_multiply_wp_xdl_f8_f8_bf16_mk_mfma_mn_p5_default_instance.cpp - f8_f8_bf16/device_gemm_multiply_multiply_wp_xdl_f8_f8_bf16_mk_mfma_mn_p1_default_instance_v2.cpp - f8_f8_bf16/device_gemm_multiply_multiply_wp_xdl_f8_f8_bf16_mk_mfma_mn_p2_default_instance_v2.cpp - f8_f8_bf16/device_gemm_multiply_multiply_wp_xdl_f8_f8_bf16_mk_mfma_mn_p3_default_instance_v2.cpp - f8_f8_bf16/device_gemm_multiply_multiply_wp_xdl_f8_f8_bf16_mk_mfma_mn_p4_default_instance_v2.cpp - f8_f8_bf16/device_gemm_multiply_multiply_wp_xdl_f8_f8_bf16_mk_mfma_mn_p5_default_instance_v2.cpp - f8_f8_bf16/device_gemm_multiply_multiply_wp_xdl_f8_f8_bf16_mk_mfma_mn_compute_default_instance_p1.cpp - f8_f8_bf16/device_gemm_multiply_multiply_wp_xdl_f8_f8_bf16_mk_mfma_mn_compute_default_instance_p2.cpp f8_f8_bf16/device_gemm_multiply_multiply_wp_xdl_f8_f8_bf16_mk_mfma16x16_mn_compute_default_instance_p1.cpp f8_f8_bf16/device_gemm_multiply_multiply_wp_xdl_f8_f8_bf16_mk_mfma16x16_mn_compute_default_instance_p2.cpp f8_f8_bf16/device_gemm_multiply_multiply_wp_xdl_f8_f8_bf16_mk_mfma16x16_mn_compute_default_instance_p3.cpp @@ -21,18 +9,6 @@ list(APPEND GEMM_MULTIPLY_MULTIPLY_WEIGHT_PRESHUFFLE_INSTANCES f8_f8_bf16/device_gemm_multiply_multiply_wp_xdl_f8_f8_bf16_mk_mfma16x16_mn_compute_default_instance_p5.cpp f8_f8_bf16/device_gemm_multiply_multiply_wp_xdl_f8_f8_bf16_mk_mfma16x16_mn_compute_default_instance_p6.cpp - 
f8_f8_f16/device_gemm_multiply_multiply_wp_xdl_f8_f8_f16_mk_mfma_mn_p1_default_instance.cpp - f8_f8_f16/device_gemm_multiply_multiply_wp_xdl_f8_f8_f16_mk_mfma_mn_p2_default_instance.cpp - f8_f8_f16/device_gemm_multiply_multiply_wp_xdl_f8_f8_f16_mk_mfma_mn_p3_default_instance.cpp - f8_f8_f16/device_gemm_multiply_multiply_wp_xdl_f8_f8_f16_mk_mfma_mn_p4_default_instance.cpp - f8_f8_f16/device_gemm_multiply_multiply_wp_xdl_f8_f8_f16_mk_mfma_mn_p5_default_instance.cpp - f8_f8_f16/device_gemm_multiply_multiply_wp_xdl_f8_f8_f16_mk_mfma_mn_p1_default_instance_v2.cpp - f8_f8_f16/device_gemm_multiply_multiply_wp_xdl_f8_f8_f16_mk_mfma_mn_p2_default_instance_v2.cpp - f8_f8_f16/device_gemm_multiply_multiply_wp_xdl_f8_f8_f16_mk_mfma_mn_p3_default_instance_v2.cpp - f8_f8_f16/device_gemm_multiply_multiply_wp_xdl_f8_f8_f16_mk_mfma_mn_p4_default_instance_v2.cpp - f8_f8_f16/device_gemm_multiply_multiply_wp_xdl_f8_f8_f16_mk_mfma_mn_p5_default_instance_v2.cpp - f8_f8_f16/device_gemm_multiply_multiply_wp_xdl_f8_f8_f16_mk_mfma_mn_compute_default_instance_p1.cpp - f8_f8_f16/device_gemm_multiply_multiply_wp_xdl_f8_f8_f16_mk_mfma_mn_compute_default_instance_p2.cpp f8_f8_f16/device_gemm_multiply_multiply_wp_xdl_f8_f8_f16_mk_mfma16x16_mn_compute_default_instance_p1.cpp f8_f8_f16/device_gemm_multiply_multiply_wp_xdl_f8_f8_f16_mk_mfma16x16_mn_compute_default_instance_p2.cpp f8_f8_f16/device_gemm_multiply_multiply_wp_xdl_f8_f8_f16_mk_mfma16x16_mn_compute_default_instance_p3.cpp @@ -41,18 +17,6 @@ list(APPEND GEMM_MULTIPLY_MULTIPLY_WEIGHT_PRESHUFFLE_INSTANCES f8_f8_f16/device_gemm_multiply_multiply_wp_xdl_f8_f8_f16_mk_mfma16x16_mn_compute_default_instance_p6.cpp ) -set_source_files_properties(f8_f8_bf16/device_gemm_multiply_multiply_wp_xdl_f8_f8_bf16_mk_mfma_mn_p1_default_instance.cpp PROPERTIES COMPILE_OPTIONS ";-mllvm;-greedy-reverse-local-assignment=1") -set_source_files_properties(f8_f8_bf16/device_gemm_multiply_multiply_wp_xdl_f8_f8_bf16_mk_mfma_mn_p2_default_instance.cpp PROPERTIES 
COMPILE_OPTIONS ";-mllvm;-greedy-reverse-local-assignment=1") -set_source_files_properties(f8_f8_bf16/device_gemm_multiply_multiply_wp_xdl_f8_f8_bf16_mk_mfma_mn_p3_default_instance.cpp PROPERTIES COMPILE_OPTIONS ";-mllvm;-greedy-reverse-local-assignment=1") -set_source_files_properties(f8_f8_bf16/device_gemm_multiply_multiply_wp_xdl_f8_f8_bf16_mk_mfma_mn_p4_default_instance.cpp PROPERTIES COMPILE_OPTIONS ";-mllvm;-greedy-reverse-local-assignment=1") -set_source_files_properties(f8_f8_bf16/device_gemm_multiply_multiply_wp_xdl_f8_f8_bf16_mk_mfma_mn_p5_default_instance.cpp PROPERTIES COMPILE_OPTIONS ";-mllvm;-greedy-reverse-local-assignment=1") -set_source_files_properties(f8_f8_bf16/device_gemm_multiply_multiply_wp_xdl_f8_f8_bf16_mk_mfma_mn_p1_default_instance_v2.cpp PROPERTIES COMPILE_OPTIONS ";-mllvm;-greedy-reverse-local-assignment=1") -set_source_files_properties(f8_f8_bf16/device_gemm_multiply_multiply_wp_xdl_f8_f8_bf16_mk_mfma_mn_p2_default_instance_v2.cpp PROPERTIES COMPILE_OPTIONS ";-mllvm;-greedy-reverse-local-assignment=1") -set_source_files_properties(f8_f8_bf16/device_gemm_multiply_multiply_wp_xdl_f8_f8_bf16_mk_mfma_mn_p3_default_instance_v2.cpp PROPERTIES COMPILE_OPTIONS ";-mllvm;-greedy-reverse-local-assignment=1") -set_source_files_properties(f8_f8_bf16/device_gemm_multiply_multiply_wp_xdl_f8_f8_bf16_mk_mfma_mn_p4_default_instance_v2.cpp PROPERTIES COMPILE_OPTIONS ";-mllvm;-greedy-reverse-local-assignment=1") -set_source_files_properties(f8_f8_bf16/device_gemm_multiply_multiply_wp_xdl_f8_f8_bf16_mk_mfma_mn_p5_default_instance_v2.cpp PROPERTIES COMPILE_OPTIONS ";-mllvm;-greedy-reverse-local-assignment=1") -set_source_files_properties(f8_f8_bf16/device_gemm_multiply_multiply_wp_xdl_f8_f8_bf16_mk_mfma_mn_compute_default_instance_p1.cpp PROPERTIES COMPILE_OPTIONS ";-mllvm;-greedy-reverse-local-assignment=1") -set_source_files_properties(f8_f8_bf16/device_gemm_multiply_multiply_wp_xdl_f8_f8_bf16_mk_mfma_mn_compute_default_instance_p2.cpp PROPERTIES 
COMPILE_OPTIONS ";-mllvm;-greedy-reverse-local-assignment=1") set_source_files_properties(f8_f8_bf16/device_gemm_multiply_multiply_wp_xdl_f8_f8_bf16_mk_mfma16x16_mn_compute_default_instance_p1.cpp PROPERTIES COMPILE_OPTIONS ";-mllvm;-greedy-reverse-local-assignment=1") set_source_files_properties(f8_f8_bf16/device_gemm_multiply_multiply_wp_xdl_f8_f8_bf16_mk_mfma16x16_mn_compute_default_instance_p2.cpp PROPERTIES COMPILE_OPTIONS ";-mllvm;-greedy-reverse-local-assignment=1") set_source_files_properties(f8_f8_bf16/device_gemm_multiply_multiply_wp_xdl_f8_f8_bf16_mk_mfma16x16_mn_compute_default_instance_p3.cpp PROPERTIES COMPILE_OPTIONS ";-mllvm;-greedy-reverse-local-assignment=1") @@ -60,18 +24,6 @@ set_source_files_properties(f8_f8_bf16/device_gemm_multiply_multiply_wp_xdl_f8_f set_source_files_properties(f8_f8_bf16/device_gemm_multiply_multiply_wp_xdl_f8_f8_bf16_mk_mfma16x16_mn_compute_default_instance_p5.cpp PROPERTIES COMPILE_OPTIONS ";-mllvm;-greedy-reverse-local-assignment=1") set_source_files_properties(f8_f8_bf16/device_gemm_multiply_multiply_wp_xdl_f8_f8_bf16_mk_mfma16x16_mn_compute_default_instance_p6.cpp PROPERTIES COMPILE_OPTIONS ";-mllvm;-greedy-reverse-local-assignment=1") -set_source_files_properties(f8_f8_f16/device_gemm_multiply_multiply_wp_xdl_f8_f8_f16_mk_mfma_mn_p1_default_instance.cpp PROPERTIES COMPILE_OPTIONS ";-mllvm;-greedy-reverse-local-assignment=1") -set_source_files_properties(f8_f8_f16/device_gemm_multiply_multiply_wp_xdl_f8_f8_f16_mk_mfma_mn_p2_default_instance.cpp PROPERTIES COMPILE_OPTIONS ";-mllvm;-greedy-reverse-local-assignment=1") -set_source_files_properties(f8_f8_f16/device_gemm_multiply_multiply_wp_xdl_f8_f8_f16_mk_mfma_mn_p3_default_instance.cpp PROPERTIES COMPILE_OPTIONS ";-mllvm;-greedy-reverse-local-assignment=1") -set_source_files_properties(f8_f8_f16/device_gemm_multiply_multiply_wp_xdl_f8_f8_f16_mk_mfma_mn_p4_default_instance.cpp PROPERTIES COMPILE_OPTIONS ";-mllvm;-greedy-reverse-local-assignment=1") 
-set_source_files_properties(f8_f8_f16/device_gemm_multiply_multiply_wp_xdl_f8_f8_f16_mk_mfma_mn_p5_default_instance.cpp PROPERTIES COMPILE_OPTIONS ";-mllvm;-greedy-reverse-local-assignment=1") -set_source_files_properties(f8_f8_f16/device_gemm_multiply_multiply_wp_xdl_f8_f8_f16_mk_mfma_mn_p1_default_instance_v2.cpp PROPERTIES COMPILE_OPTIONS ";-mllvm;-greedy-reverse-local-assignment=1") -set_source_files_properties(f8_f8_f16/device_gemm_multiply_multiply_wp_xdl_f8_f8_f16_mk_mfma_mn_p2_default_instance_v2.cpp PROPERTIES COMPILE_OPTIONS ";-mllvm;-greedy-reverse-local-assignment=1") -set_source_files_properties(f8_f8_f16/device_gemm_multiply_multiply_wp_xdl_f8_f8_f16_mk_mfma_mn_p3_default_instance_v2.cpp PROPERTIES COMPILE_OPTIONS ";-mllvm;-greedy-reverse-local-assignment=1") -set_source_files_properties(f8_f8_f16/device_gemm_multiply_multiply_wp_xdl_f8_f8_f16_mk_mfma_mn_p4_default_instance_v2.cpp PROPERTIES COMPILE_OPTIONS ";-mllvm;-greedy-reverse-local-assignment=1") -set_source_files_properties(f8_f8_f16/device_gemm_multiply_multiply_wp_xdl_f8_f8_f16_mk_mfma_mn_p5_default_instance_v2.cpp PROPERTIES COMPILE_OPTIONS ";-mllvm;-greedy-reverse-local-assignment=1") -set_source_files_properties(f8_f8_f16/device_gemm_multiply_multiply_wp_xdl_f8_f8_f16_mk_mfma_mn_compute_default_instance_p1.cpp PROPERTIES COMPILE_OPTIONS ";-mllvm;-greedy-reverse-local-assignment=1") -set_source_files_properties(f8_f8_f16/device_gemm_multiply_multiply_wp_xdl_f8_f8_f16_mk_mfma_mn_compute_default_instance_p2.cpp PROPERTIES COMPILE_OPTIONS ";-mllvm;-greedy-reverse-local-assignment=1") set_source_files_properties(f8_f8_f16/device_gemm_multiply_multiply_wp_xdl_f8_f8_f16_mk_mfma16x16_mn_compute_default_instance_p1.cpp PROPERTIES COMPILE_OPTIONS ";-mllvm;-greedy-reverse-local-assignment=1") set_source_files_properties(f8_f8_f16/device_gemm_multiply_multiply_wp_xdl_f8_f8_f16_mk_mfma16x16_mn_compute_default_instance_p2.cpp PROPERTIES COMPILE_OPTIONS ";-mllvm;-greedy-reverse-local-assignment=1") 
set_source_files_properties(f8_f8_f16/device_gemm_multiply_multiply_wp_xdl_f8_f8_f16_mk_mfma16x16_mn_compute_default_instance_p3.cpp PROPERTIES COMPILE_OPTIONS ";-mllvm;-greedy-reverse-local-assignment=1") diff --git a/library/src/tensor_operation_instance/gpu/gemm_multiply_multiply_wp/f8_f8_bf16/device_gemm_multiply_multiply_wp_xdl_f8_f8_bf16_mk_mfma_mn.hpp b/library/src/tensor_operation_instance/gpu/gemm_multiply_multiply_wp/f8_f8_bf16/device_gemm_multiply_multiply_wp_xdl_f8_f8_bf16_mk_mfma_mn.hpp index e5ada03a46..4613a0f24d 100644 --- a/library/src/tensor_operation_instance/gpu/gemm_multiply_multiply_wp/f8_f8_bf16/device_gemm_multiply_multiply_wp_xdl_f8_f8_bf16_mk_mfma_mn.hpp +++ b/library/src/tensor_operation_instance/gpu/gemm_multiply_multiply_wp/f8_f8_bf16/device_gemm_multiply_multiply_wp_xdl_f8_f8_bf16_mk_mfma_mn.hpp @@ -171,13 +171,13 @@ using device_gemm_multiply_multiply_weight_preshuffle_xdl_f8_f8_bf16_mk_mfma16x1 //############################################| | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | // Compute friendly // 256x[64, 256, 32]x128 - DeviceGemmMultiD_Xdl_CShuffle_V3_BPreshuffle< Row, Col, Tuple, Row, F8, F8, Tuple, BF16, F32, F32, PassThrough, PassThrough, MultiplyMultiply, GemmSpec, 256, 256, 256, 128, 16, 16, 16, 16, 8, 8, S<8, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, 0, S<8, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, 0, 1, 2, S<1, 32, 1, 8>, S<8, 8, 1>, BlockGemmPipelineScheduler::Intrawave, BlockGemmPipelineVersion::v3, F8>, + DeviceGemmMultiD_Xdl_CShuffle_V3_BPreshuffle< Row, Col, Tuple, Row, F8, F8, Tuple, BF16, F32, F32, PassThrough, PassThrough, MultiplyMultiply, GemmSpec, 256, 256, 256, 128, 16, 16, 16, 16, 16, 4, S<8, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, 0, S<8, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, 0, 2, 1, S<1, 32, 1, 8>, S<8, 8, 1>, BlockGemmPipelineScheduler::Intrawave, BlockGemmPipelineVersion::v3, F8>, DeviceGemmMultiD_Xdl_CShuffle_V3_BPreshuffle< Row, Col, Tuple, 
Row, F8, F8, Tuple, BF16, F32, F32, PassThrough, PassThrough, MultiplyMultiply, GemmSpec, 256, 256, 224, 128, 16, 16, 16, 16, 8, 7, S<8, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, 0, S<8, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, 0, 2, 1, S<1, 64, 1, 4>, S<8, 8, 1>, BlockGemmPipelineScheduler::Intrawave, BlockGemmPipelineVersion::v3, F8>, - DeviceGemmMultiD_Xdl_CShuffle_V3_BPreshuffle< Row, Col, Tuple, Row, F8, F8, Tuple, BF16, F32, F32, PassThrough, PassThrough, MultiplyMultiply, GemmSpec, 256, 256, 192, 128, 16, 16, 16, 16, 8, 6, S<8, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, 0, S<8, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, 0, 1, 2, S<1, 32, 1, 8>, S<8, 8, 1>, BlockGemmPipelineScheduler::Intrawave, BlockGemmPipelineVersion::v3, F8>, + DeviceGemmMultiD_Xdl_CShuffle_V3_BPreshuffle< Row, Col, Tuple, Row, F8, F8, Tuple, BF16, F32, F32, PassThrough, PassThrough, MultiplyMultiply, GemmSpec, 256, 256, 192, 128, 16, 16, 16, 16, 16, 3, S<8, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, 0, S<8, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, 0, 2, 1, S<1, 32, 1, 8>, S<8, 8, 1>, BlockGemmPipelineScheduler::Intrawave, BlockGemmPipelineVersion::v3, F8>, DeviceGemmMultiD_Xdl_CShuffle_V3_BPreshuffle< Row, Col, Tuple, Row, F8, F8, Tuple, BF16, F32, F32, PassThrough, PassThrough, MultiplyMultiply, GemmSpec, 256, 256, 160, 128, 16, 16, 16, 16, 8, 5, S<8, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, 0, S<8, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, 0, 2, 1, S<1, 64, 1, 4>, S<8, 8, 1>, BlockGemmPipelineScheduler::Intrawave, BlockGemmPipelineVersion::v3, F8>, - DeviceGemmMultiD_Xdl_CShuffle_V3_BPreshuffle< Row, Col, Tuple, Row, F8, F8, Tuple, BF16, F32, F32, PassThrough, PassThrough, MultiplyMultiply, GemmSpec, 256, 256, 128, 128, 16, 16, 16, 16, 8, 4, S<8, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, 0, S<8, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, 0, 1, 2, S<1, 32, 1, 8>, S<8, 8, 1>, BlockGemmPipelineScheduler::Intrawave, BlockGemmPipelineVersion::v3, F8>, + 
DeviceGemmMultiD_Xdl_CShuffle_V3_BPreshuffle< Row, Col, Tuple, Row, F8, F8, Tuple, BF16, F32, F32, PassThrough, PassThrough, MultiplyMultiply, GemmSpec, 256, 256, 128, 128, 16, 16, 16, 16, 16, 2, S<8, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, 0, S<8, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, 0, 2, 1, S<1, 32, 1, 8>, S<8, 8, 1>, BlockGemmPipelineScheduler::Intrawave, BlockGemmPipelineVersion::v3, F8>, DeviceGemmMultiD_Xdl_CShuffle_V3_BPreshuffle< Row, Col, Tuple, Row, F8, F8, Tuple, BF16, F32, F32, PassThrough, PassThrough, MultiplyMultiply, GemmSpec, 256, 256, 96, 128, 16, 16, 16, 16, 8, 3, S<8, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, 0, S<8, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, 0, 2, 1, S<1, 64, 1, 4>, S<8, 8, 1>, BlockGemmPipelineScheduler::Intrawave, BlockGemmPipelineVersion::v3, F8>, - DeviceGemmMultiD_Xdl_CShuffle_V3_BPreshuffle< Row, Col, Tuple, Row, F8, F8, Tuple, BF16, F32, F32, PassThrough, PassThrough, MultiplyMultiply, GemmSpec, 256, 256, 64, 128, 16, 16, 16, 16, 8, 2, S<8, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, 0, S<8, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, 0, 1, 2, S<1, 32, 1, 8>, S<8, 8, 1>, BlockGemmPipelineScheduler::Intrawave, BlockGemmPipelineVersion::v3, F8> + DeviceGemmMultiD_Xdl_CShuffle_V3_BPreshuffle< Row, Col, Tuple, Row, F8, F8, Tuple, BF16, F32, F32, PassThrough, PassThrough, MultiplyMultiply, GemmSpec, 256, 256, 64, 128, 16, 16, 16, 16, 16, 1, S<8, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, 0, S<8, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, 0, 2, 1, S<1, 32, 1, 8>, S<8, 8, 1>, BlockGemmPipelineScheduler::Intrawave, BlockGemmPipelineVersion::v3, F8> // clang-format on >; @@ -190,13 +190,13 @@ using device_gemm_multiply_multiply_weight_preshuffle_xdl_f8_f8_bf16_mk_mfma16x1 //############################################| | | | | | | | | | | Operation| Operation| Operation| | | | | | | | | | Wave| Wave| Lengths_K0_M_K1| ArrangeOrder| | | PerVector| PerVector_K1| | Lengths_K0_N_K1| ArrangeOrder| | | PerVector| PerVector_K1| 
| PerShuffle| PerShuffle| _NBlock_NXdlPerWave_NWaveNPerXdl| _NWaveNPerXdl| Scheduler| Verision| //############################################| | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | // 224x[64, 256, 32]x128 - DeviceGemmMultiD_Xdl_CShuffle_V3_BPreshuffle< Row, Col, Tuple, Row, F8, F8, Tuple, BF16, F32, F32, PassThrough, PassThrough, MultiplyMultiply, GemmSpec, 256, 224, 256, 128, 16, 16, 16, 16, 7, 8, S<8, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, 0, S<8, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, 0, 1, 2, S<1, 32, 1, 8>, S<8, 8, 1>, BlockGemmPipelineScheduler::Intrawave, BlockGemmPipelineVersion::v3, F8>, + DeviceGemmMultiD_Xdl_CShuffle_V3_BPreshuffle< Row, Col, Tuple, Row, F8, F8, Tuple, BF16, F32, F32, PassThrough, PassThrough, MultiplyMultiply, GemmSpec, 256, 224, 256, 128, 16, 16, 16, 16, 14, 4, S<8, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, 0, S<8, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, 0, 2, 1, S<1, 32, 1, 8>, S<8, 8, 1>, BlockGemmPipelineScheduler::Intrawave, BlockGemmPipelineVersion::v3, F8>, DeviceGemmMultiD_Xdl_CShuffle_V3_BPreshuffle< Row, Col, Tuple, Row, F8, F8, Tuple, BF16, F32, F32, PassThrough, PassThrough, MultiplyMultiply, GemmSpec, 256, 224, 224, 128, 16, 16, 16, 16, 7, 7, S<8, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, 0, S<8, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, 0, 1, 1, S<1, 32, 1, 8>, S<4, 4, 1>, BlockGemmPipelineScheduler::Intrawave, BlockGemmPipelineVersion::v3, F8>, - DeviceGemmMultiD_Xdl_CShuffle_V3_BPreshuffle< Row, Col, Tuple, Row, F8, F8, Tuple, BF16, F32, F32, PassThrough, PassThrough, MultiplyMultiply, GemmSpec, 256, 224, 192, 128, 16, 16, 16, 16, 7, 6, S<8, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, 0, S<8, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, 0, 1, 2, S<1, 32, 1, 8>, S<8, 8, 1>, BlockGemmPipelineScheduler::Intrawave, BlockGemmPipelineVersion::v3, F8>, + DeviceGemmMultiD_Xdl_CShuffle_V3_BPreshuffle< Row, Col, Tuple, Row, F8, F8, Tuple, BF16, F32, F32, PassThrough, 
PassThrough, MultiplyMultiply, GemmSpec, 256, 224, 192, 128, 16, 16, 16, 16, 14, 3, S<8, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, 0, S<8, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, 0, 2, 1, S<1, 32, 1, 8>, S<8, 8, 1>, BlockGemmPipelineScheduler::Intrawave, BlockGemmPipelineVersion::v3, F8>, DeviceGemmMultiD_Xdl_CShuffle_V3_BPreshuffle< Row, Col, Tuple, Row, F8, F8, Tuple, BF16, F32, F32, PassThrough, PassThrough, MultiplyMultiply, GemmSpec, 256, 224, 160, 128, 16, 16, 16, 16, 7, 5, S<8, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, 0, S<8, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, 0, 1, 1, S<1, 32, 1, 8>, S<4, 4, 1>, BlockGemmPipelineScheduler::Intrawave, BlockGemmPipelineVersion::v3, F8>, - DeviceGemmMultiD_Xdl_CShuffle_V3_BPreshuffle< Row, Col, Tuple, Row, F8, F8, Tuple, BF16, F32, F32, PassThrough, PassThrough, MultiplyMultiply, GemmSpec, 256, 224, 128, 128, 16, 16, 16, 16, 7, 4, S<8, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, 0, S<8, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, 0, 1, 2, S<1, 32, 1, 8>, S<8, 8, 1>, BlockGemmPipelineScheduler::Intrawave, BlockGemmPipelineVersion::v3, F8>, + DeviceGemmMultiD_Xdl_CShuffle_V3_BPreshuffle< Row, Col, Tuple, Row, F8, F8, Tuple, BF16, F32, F32, PassThrough, PassThrough, MultiplyMultiply, GemmSpec, 256, 224, 128, 128, 16, 16, 16, 16, 14, 2, S<8, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, 0, S<8, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, 0, 2, 1, S<1, 32, 1, 8>, S<8, 8, 1>, BlockGemmPipelineScheduler::Intrawave, BlockGemmPipelineVersion::v3, F8>, DeviceGemmMultiD_Xdl_CShuffle_V3_BPreshuffle< Row, Col, Tuple, Row, F8, F8, Tuple, BF16, F32, F32, PassThrough, PassThrough, MultiplyMultiply, GemmSpec, 256, 224, 96, 128, 16, 16, 16, 16, 7, 3, S<8, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, 0, S<8, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, 0, 1, 1, S<1, 32, 1, 8>, S<4, 4, 1>, BlockGemmPipelineScheduler::Intrawave, BlockGemmPipelineVersion::v3, F8>, - DeviceGemmMultiD_Xdl_CShuffle_V3_BPreshuffle< Row, Col, Tuple, Row, F8, F8, 
Tuple, BF16, F32, F32, PassThrough, PassThrough, MultiplyMultiply, GemmSpec, 256, 224, 64, 128, 16, 16, 16, 16, 7, 2, S<8, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, 0, S<8, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, 0, 1, 2, S<1, 32, 1, 8>, S<8, 8, 1>, BlockGemmPipelineScheduler::Intrawave, BlockGemmPipelineVersion::v3, F8> + DeviceGemmMultiD_Xdl_CShuffle_V3_BPreshuffle< Row, Col, Tuple, Row, F8, F8, Tuple, BF16, F32, F32, PassThrough, PassThrough, MultiplyMultiply, GemmSpec, 256, 224, 64, 128, 16, 16, 16, 16, 14, 1, S<8, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, 0, S<8, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, 0, 2, 1, S<1, 32, 1, 8>, S<8, 8, 1>, BlockGemmPipelineScheduler::Intrawave, BlockGemmPipelineVersion::v3, F8> // clang-format on >; template @@ -208,13 +208,13 @@ using device_gemm_multiply_multiply_weight_preshuffle_xdl_f8_f8_bf16_mk_mfma16x1 //############################################| | | | | | | | | | | Operation| Operation| Operation| | | | | | | | | | Wave| Wave| Lengths_K0_M_K1| ArrangeOrder| | | PerVector| PerVector_K1| | Lengths_K0_N_K1| ArrangeOrder| | | PerVector| PerVector_K1| | PerShuffle| PerShuffle| _NBlock_NXdlPerWave_NWaveNPerXdl| _NWaveNPerXdl| Scheduler| Verision| //############################################| | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | // 192x[64, 256, 32]x128, 192x[64]x256 - DeviceGemmMultiD_Xdl_CShuffle_V3_BPreshuffle< Row, Col, Tuple, Row, F8, F8, Tuple, BF16, F32, F32, PassThrough, PassThrough, MultiplyMultiply, GemmSpec, 256, 192, 256, 128, 16, 16, 16, 16, 6, 8, S<8, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, 0, S<8, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, 0, 1, 2, S<1, 32, 1, 8>, S<8, 8, 1>, BlockGemmPipelineScheduler::Intrawave, BlockGemmPipelineVersion::v3, F8>, + DeviceGemmMultiD_Xdl_CShuffle_V3_BPreshuffle< Row, Col, Tuple, Row, F8, F8, Tuple, BF16, F32, F32, PassThrough, PassThrough, MultiplyMultiply, GemmSpec, 256, 192, 256, 128, 16, 16, 16, 16, 12, 4, 
S<8, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, 0, S<8, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, 0, 2, 1, S<1, 32, 1, 8>, S<8, 8, 1>, BlockGemmPipelineScheduler::Intrawave, BlockGemmPipelineVersion::v3, F8>, DeviceGemmMultiD_Xdl_CShuffle_V3_BPreshuffle< Row, Col, Tuple, Row, F8, F8, Tuple, BF16, F32, F32, PassThrough, PassThrough, MultiplyMultiply, GemmSpec, 256, 192, 224, 128, 16, 16, 16, 16, 6, 7, S<8, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, 0, S<8, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, 0, 2, 1, S<1, 64, 1, 4>, S<8, 8, 1>, BlockGemmPipelineScheduler::Intrawave, BlockGemmPipelineVersion::v3, F8>, - DeviceGemmMultiD_Xdl_CShuffle_V3_BPreshuffle< Row, Col, Tuple, Row, F8, F8, Tuple, BF16, F32, F32, PassThrough, PassThrough, MultiplyMultiply, GemmSpec, 256, 192, 192, 128, 16, 16, 16, 16, 6, 6, S<8, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, 0, S<8, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, 0, 1, 2, S<1, 32, 1, 8>, S<8, 8, 1>, BlockGemmPipelineScheduler::Intrawave, BlockGemmPipelineVersion::v3, F8>, + DeviceGemmMultiD_Xdl_CShuffle_V3_BPreshuffle< Row, Col, Tuple, Row, F8, F8, Tuple, BF16, F32, F32, PassThrough, PassThrough, MultiplyMultiply, GemmSpec, 256, 192, 192, 128, 16, 16, 16, 16, 12, 3, S<8, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, 0, S<8, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, 0, 2, 1, S<1, 32, 1, 8>, S<8, 8, 1>, BlockGemmPipelineScheduler::Intrawave, BlockGemmPipelineVersion::v3, F8>, DeviceGemmMultiD_Xdl_CShuffle_V3_BPreshuffle< Row, Col, Tuple, Row, F8, F8, Tuple, BF16, F32, F32, PassThrough, PassThrough, MultiplyMultiply, GemmSpec, 256, 192, 160, 128, 16, 16, 16, 16, 6, 5, S<8, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, 0, S<8, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, 0, 2, 1, S<1, 64, 1, 4>, S<8, 8, 1>, BlockGemmPipelineScheduler::Intrawave, BlockGemmPipelineVersion::v3, F8>, - DeviceGemmMultiD_Xdl_CShuffle_V3_BPreshuffle< Row, Col, Tuple, Row, F8, F8, Tuple, BF16, F32, F32, PassThrough, PassThrough, MultiplyMultiply, GemmSpec, 256, 192, 
128, 128, 16, 16, 16, 16, 6, 4, S<8, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, 0, S<8, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, 0, 1, 2, S<1, 32, 1, 8>, S<8, 8, 1>, BlockGemmPipelineScheduler::Intrawave, BlockGemmPipelineVersion::v3, F8>, + DeviceGemmMultiD_Xdl_CShuffle_V3_BPreshuffle< Row, Col, Tuple, Row, F8, F8, Tuple, BF16, F32, F32, PassThrough, PassThrough, MultiplyMultiply, GemmSpec, 256, 192, 128, 128, 16, 16, 16, 16, 12, 2, S<8, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, 0, S<8, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, 0, 2, 1, S<1, 32, 1, 8>, S<8, 8, 1>, BlockGemmPipelineScheduler::Intrawave, BlockGemmPipelineVersion::v3, F8>, DeviceGemmMultiD_Xdl_CShuffle_V3_BPreshuffle< Row, Col, Tuple, Row, F8, F8, Tuple, BF16, F32, F32, PassThrough, PassThrough, MultiplyMultiply, GemmSpec, 256, 192, 96, 128, 16, 16, 16, 16, 6, 3, S<8, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, 0, S<8, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, 0, 2, 1, S<1, 64, 1, 4>, S<8, 8, 1>, BlockGemmPipelineScheduler::Intrawave, BlockGemmPipelineVersion::v3, F8>, - DeviceGemmMultiD_Xdl_CShuffle_V3_BPreshuffle< Row, Col, Tuple, Row, F8, F8, Tuple, BF16, F32, F32, PassThrough, PassThrough, MultiplyMultiply, GemmSpec, 256, 192, 64, 128, 16, 16, 16, 16, 6, 2, S<8, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, 0, S<8, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, 0, 1, 2, S<1, 32, 1, 8>, S<8, 8, 1>, BlockGemmPipelineScheduler::Intrawave, BlockGemmPipelineVersion::v3, F8> + DeviceGemmMultiD_Xdl_CShuffle_V3_BPreshuffle< Row, Col, Tuple, Row, F8, F8, Tuple, BF16, F32, F32, PassThrough, PassThrough, MultiplyMultiply, GemmSpec, 256, 192, 64, 128, 16, 16, 16, 16, 12, 1, S<8, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, 0, S<8, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, 0, 2, 1, S<1, 32, 1, 8>, S<8, 8, 1>, BlockGemmPipelineScheduler::Intrawave, BlockGemmPipelineVersion::v3, F8> // clang-format on >; template @@ -226,13 +226,13 @@ using device_gemm_multiply_multiply_weight_preshuffle_xdl_f8_f8_bf16_mk_mfma16x1 
//############################################| | | | | | | | | | | Operation| Operation| Operation| | | | | | | | | | Wave| Wave| Lengths_K0_M_K1| ArrangeOrder| | | PerVector| PerVector_K1| | Lengths_K0_N_K1| ArrangeOrder| | | PerVector| PerVector_K1| | PerShuffle| PerShuffle| _NBlock_NXdlPerWave_NWaveNPerXdl| _NWaveNPerXdl| Scheduler| Verision| //############################################| | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | // 160x[64, 256, 32]x128, 160x[64, 96, 32]x256 - DeviceGemmMultiD_Xdl_CShuffle_V3_BPreshuffle< Row, Col, Tuple, Row, F8, F8, Tuple, BF16, F32, F32, PassThrough, PassThrough, MultiplyMultiply, GemmSpec, 256, 160, 256, 128, 16, 16, 16, 16, 5, 8, S<8, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, 0, S<8, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, 0, 1, 2, S<1, 32, 1, 8>, S<8, 8, 1>, BlockGemmPipelineScheduler::Intrawave, BlockGemmPipelineVersion::v3, F8>, + DeviceGemmMultiD_Xdl_CShuffle_V3_BPreshuffle< Row, Col, Tuple, Row, F8, F8, Tuple, BF16, F32, F32, PassThrough, PassThrough, MultiplyMultiply, GemmSpec, 256, 160, 256, 128, 16, 16, 16, 16, 10, 4, S<8, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, 0, S<8, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, 0, 2, 1, S<1, 32, 1, 8>, S<8, 8, 1>, BlockGemmPipelineScheduler::Intrawave, BlockGemmPipelineVersion::v3, F8>, DeviceGemmMultiD_Xdl_CShuffle_V3_BPreshuffle< Row, Col, Tuple, Row, F8, F8, Tuple, BF16, F32, F32, PassThrough, PassThrough, MultiplyMultiply, GemmSpec, 256, 160, 224, 128, 16, 16, 16, 16, 5, 7, S<8, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, 0, S<8, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, 0, 1, 1, S<1, 32, 1, 8>, S<4, 4, 1>, BlockGemmPipelineScheduler::Intrawave, BlockGemmPipelineVersion::v3, F8>, - DeviceGemmMultiD_Xdl_CShuffle_V3_BPreshuffle< Row, Col, Tuple, Row, F8, F8, Tuple, BF16, F32, F32, PassThrough, PassThrough, MultiplyMultiply, GemmSpec, 256, 160, 192, 128, 16, 16, 16, 16, 5, 6, S<8, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, 0, 
S<8, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, 0, 1, 2, S<1, 32, 1, 8>, S<8, 8, 1>, BlockGemmPipelineScheduler::Intrawave, BlockGemmPipelineVersion::v3, F8>, + DeviceGemmMultiD_Xdl_CShuffle_V3_BPreshuffle< Row, Col, Tuple, Row, F8, F8, Tuple, BF16, F32, F32, PassThrough, PassThrough, MultiplyMultiply, GemmSpec, 256, 160, 192, 128, 16, 16, 16, 16, 10, 3, S<8, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, 0, S<8, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, 0, 2, 1, S<1, 32, 1, 8>, S<8, 8, 1>, BlockGemmPipelineScheduler::Intrawave, BlockGemmPipelineVersion::v3, F8>, DeviceGemmMultiD_Xdl_CShuffle_V3_BPreshuffle< Row, Col, Tuple, Row, F8, F8, Tuple, BF16, F32, F32, PassThrough, PassThrough, MultiplyMultiply, GemmSpec, 256, 160, 160, 128, 16, 16, 16, 16, 5, 5, S<8, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, 0, S<8, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, 0, 1, 1, S<1, 32, 1, 8>, S<4, 4, 1>, BlockGemmPipelineScheduler::Intrawave, BlockGemmPipelineVersion::v3, F8>, - DeviceGemmMultiD_Xdl_CShuffle_V3_BPreshuffle< Row, Col, Tuple, Row, F8, F8, Tuple, BF16, F32, F32, PassThrough, PassThrough, MultiplyMultiply, GemmSpec, 256, 160, 128, 128, 16, 16, 16, 16, 5, 4, S<8, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, 0, S<8, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, 0, 1, 2, S<1, 32, 1, 8>, S<8, 8, 1>, BlockGemmPipelineScheduler::Intrawave, BlockGemmPipelineVersion::v3, F8>, + DeviceGemmMultiD_Xdl_CShuffle_V3_BPreshuffle< Row, Col, Tuple, Row, F8, F8, Tuple, BF16, F32, F32, PassThrough, PassThrough, MultiplyMultiply, GemmSpec, 256, 160, 128, 128, 16, 16, 16, 16, 10, 2, S<8, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, 0, S<8, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, 0, 2, 1, S<1, 32, 1, 8>, S<8, 8, 1>, BlockGemmPipelineScheduler::Intrawave, BlockGemmPipelineVersion::v3, F8>, DeviceGemmMultiD_Xdl_CShuffle_V3_BPreshuffle< Row, Col, Tuple, Row, F8, F8, Tuple, BF16, F32, F32, PassThrough, PassThrough, MultiplyMultiply, GemmSpec, 256, 160, 96, 128, 16, 16, 16, 16, 5, 3, S<8, 32, 1>, S<1, 
0, 2>, S<1, 0, 2>, 2, 16, 16, 0, S<8, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, 0, 1, 1, S<1, 32, 1, 8>, S<4, 4, 1>, BlockGemmPipelineScheduler::Intrawave, BlockGemmPipelineVersion::v3, F8>, - DeviceGemmMultiD_Xdl_CShuffle_V3_BPreshuffle< Row, Col, Tuple, Row, F8, F8, Tuple, BF16, F32, F32, PassThrough, PassThrough, MultiplyMultiply, GemmSpec, 256, 160, 64, 128, 16, 16, 16, 16, 5, 2, S<8, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, 0, S<8, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, 0, 1, 2, S<1, 32, 1, 8>, S<8, 8, 1>, BlockGemmPipelineScheduler::Intrawave, BlockGemmPipelineVersion::v3, F8> + DeviceGemmMultiD_Xdl_CShuffle_V3_BPreshuffle< Row, Col, Tuple, Row, F8, F8, Tuple, BF16, F32, F32, PassThrough, PassThrough, MultiplyMultiply, GemmSpec, 256, 160, 64, 128, 16, 16, 16, 16, 10, 1, S<8, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, 0, S<8, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, 0, 2, 1, S<1, 32, 1, 8>, S<8, 8, 1>, BlockGemmPipelineScheduler::Intrawave, BlockGemmPipelineVersion::v3, F8> // clang-format on >; template @@ -244,10 +244,10 @@ using device_gemm_multiply_multiply_weight_preshuffle_xdl_f8_f8_bf16_mk_mfma16x1 //############################################| | | | | | | | | | | Operation| Operation| Operation| | | | | | | | | | Wave| Wave| Lengths_K0_M_K1| ArrangeOrder| | | PerVector| PerVector_K1| | Lengths_K0_N_K1| ArrangeOrder| | | PerVector| PerVector_K1| | PerShuffle| PerShuffle| _NBlock_NXdlPerWave_NWaveNPerXdl| _NWaveNPerXdl| Scheduler| Verision| //############################################| | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | DeviceGemmMultiD_Xdl_CShuffle_V3_BPreshuffle< Row, Col, Tuple, Row, F8, F8, Tuple, BF16, F32, F32, PassThrough, PassThrough, MultiplyMultiply, GemmSpec, 256, 128, 96, 128, 16, 16, 16, 16, 4, 3, S<8, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, 0, S<8, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, 0, 2, 1, S<1, 64, 1, 4>, S<8, 8, 1>, BlockGemmPipelineScheduler::Intrawave, 
BlockGemmPipelineVersion::v3, F8>, - DeviceGemmMultiD_Xdl_CShuffle_V3_BPreshuffle< Row, Col, Tuple, Row, F8, F8, Tuple, BF16, F32, F32, PassThrough, PassThrough, MultiplyMultiply, GemmSpec, 256, 128, 64, 128, 16, 16, 16, 16, 4, 2, S<8, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, 0, S<8, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, 0, 1, 2, S<1, 32, 1, 8>, S<8, 8, 1>, BlockGemmPipelineScheduler::Intrawave, BlockGemmPipelineVersion::v3, F8>, - DeviceGemmMultiD_Xdl_CShuffle_V3_BPreshuffle< Row, Col, Tuple, Row, F8, F8, Tuple, BF16, F32, F32, PassThrough, PassThrough, MultiplyMultiply, GemmSpec, 256, 128, 128, 256, 16, 16, 16, 16, 4, 4, S<16, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, 0, S<16, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, 0, 1, 2, S<1, 32, 1, 8>, S<8, 8, 1>, BlockGemmPipelineScheduler::Intrawave, BlockGemmPipelineVersion::v3, F8>, + DeviceGemmMultiD_Xdl_CShuffle_V3_BPreshuffle< Row, Col, Tuple, Row, F8, F8, Tuple, BF16, F32, F32, PassThrough, PassThrough, MultiplyMultiply, GemmSpec, 256, 128, 64, 128, 16, 16, 16, 16, 8, 1, S<8, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, 0, S<8, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, 0, 2, 1, S<1, 32, 1, 8>, S<8, 8, 1>, BlockGemmPipelineScheduler::Intrawave, BlockGemmPipelineVersion::v3, F8>, + DeviceGemmMultiD_Xdl_CShuffle_V3_BPreshuffle< Row, Col, Tuple, Row, F8, F8, Tuple, BF16, F32, F32, PassThrough, PassThrough, MultiplyMultiply, GemmSpec, 256, 128, 128, 256, 16, 16, 16, 16, 8, 2, S<16, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, 0, S<16, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, 0, 2, 1, S<1, 32, 1, 8>, S<8, 8, 1>, BlockGemmPipelineScheduler::Intrawave, BlockGemmPipelineVersion::v3, F8>, DeviceGemmMultiD_Xdl_CShuffle_V3_BPreshuffle< Row, Col, Tuple, Row, F8, F8, Tuple, BF16, F32, F32, PassThrough, PassThrough, MultiplyMultiply, GemmSpec, 256, 128, 96, 256, 16, 16, 16, 16, 4, 3, S<16, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, 0, S<16, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, 0, 2, 1, S<1, 64, 1, 4>, S<8, 8, 1>, 
BlockGemmPipelineScheduler::Intrawave, BlockGemmPipelineVersion::v3, F8>, - DeviceGemmMultiD_Xdl_CShuffle_V3_BPreshuffle< Row, Col, Tuple, Row, F8, F8, Tuple, BF16, F32, F32, PassThrough, PassThrough, MultiplyMultiply, GemmSpec, 256, 128, 64, 256, 16, 16, 16, 16, 4, 2, S<16, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, 0, S<16, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, 0, 1, 2, S<1, 32, 1, 8>, S<8, 8, 1>, BlockGemmPipelineScheduler::Intrawave, BlockGemmPipelineVersion::v3, F8> + DeviceGemmMultiD_Xdl_CShuffle_V3_BPreshuffle< Row, Col, Tuple, Row, F8, F8, Tuple, BF16, F32, F32, PassThrough, PassThrough, MultiplyMultiply, GemmSpec, 256, 128, 64, 256, 16, 16, 16, 16, 8, 1, S<16, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, 0, S<16, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, 0, 2, 1, S<1, 32, 1, 8>, S<8, 8, 1>, BlockGemmPipelineScheduler::Intrawave, BlockGemmPipelineVersion::v3, F8> // clang-format on >; @@ -259,11 +259,11 @@ using device_gemm_multiply_multiply_weight_preshuffle_xdl_f8_f8_bf16_mk_mfma16x1 //############################################| | | | | Type| Type| Type| Type| Type| Type| Elementwise| Elementwise| Elementwise|Specialization| Size| Block| Block| Block| | | XDL| XDL| Per| Per| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraM| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraN| MXdlPerWave| NXdlPerWave| _MBlock_MXdlPerWave_MWaveMPerXdl| ScalarPerVector| Pipeline| Pipeline| //############################################| | | | | | | | | | | Operation| Operation| Operation| | | | | | | | | | Wave| Wave| Lengths_K0_M_K1| ArrangeOrder| | | PerVector| PerVector_K1| | Lengths_K0_N_K1| ArrangeOrder| | | PerVector| PerVector_K1| | PerShuffle| PerShuffle| _NBlock_NXdlPerWave_NWaveNPerXdl| _NWaveNPerXdl| Scheduler| Verision| //############################################| | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | - 
DeviceGemmMultiD_Xdl_CShuffle_V3_BPreshuffle< Row, Col, Tuple, Row, F8, F8, Tuple, BF16, F32, F32, PassThrough, PassThrough, MultiplyMultiply, GemmSpec, 256, 128, 256, 128, 16, 16, 16, 16, 4, 8, S<8, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, 0, S<8, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, 0, 1, 2, S<1, 32, 1, 8>, S<8, 8, 1>, BlockGemmPipelineScheduler::Intrawave, BlockGemmPipelineVersion::v3, F8>, + DeviceGemmMultiD_Xdl_CShuffle_V3_BPreshuffle< Row, Col, Tuple, Row, F8, F8, Tuple, BF16, F32, F32, PassThrough, PassThrough, MultiplyMultiply, GemmSpec, 256, 128, 256, 128, 16, 16, 16, 16, 8, 4, S<8, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, 0, S<8, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, 0, 2, 1, S<1, 32, 1, 8>, S<8, 8, 1>, BlockGemmPipelineScheduler::Intrawave, BlockGemmPipelineVersion::v3, F8>, DeviceGemmMultiD_Xdl_CShuffle_V3_BPreshuffle< Row, Col, Tuple, Row, F8, F8, Tuple, BF16, F32, F32, PassThrough, PassThrough, MultiplyMultiply, GemmSpec, 256, 128, 224, 128, 16, 16, 16, 16, 4, 7, S<8, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, 0, S<8, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, 0, 2, 1, S<1, 64, 1, 4>, S<8, 8, 1>, BlockGemmPipelineScheduler::Intrawave, BlockGemmPipelineVersion::v3, F8>, - DeviceGemmMultiD_Xdl_CShuffle_V3_BPreshuffle< Row, Col, Tuple, Row, F8, F8, Tuple, BF16, F32, F32, PassThrough, PassThrough, MultiplyMultiply, GemmSpec, 256, 128, 192, 128, 16, 16, 16, 16, 4, 6, S<8, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, 0, S<8, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, 0, 1, 2, S<1, 32, 1, 8>, S<8, 8, 1>, BlockGemmPipelineScheduler::Intrawave, BlockGemmPipelineVersion::v3, F8>, + DeviceGemmMultiD_Xdl_CShuffle_V3_BPreshuffle< Row, Col, Tuple, Row, F8, F8, Tuple, BF16, F32, F32, PassThrough, PassThrough, MultiplyMultiply, GemmSpec, 256, 128, 192, 128, 16, 16, 16, 16, 8, 3, S<8, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, 0, S<8, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, 0, 2, 1, S<1, 32, 1, 8>, S<8, 8, 1>, BlockGemmPipelineScheduler::Intrawave, 
BlockGemmPipelineVersion::v3, F8>, DeviceGemmMultiD_Xdl_CShuffle_V3_BPreshuffle< Row, Col, Tuple, Row, F8, F8, Tuple, BF16, F32, F32, PassThrough, PassThrough, MultiplyMultiply, GemmSpec, 256, 128, 160, 128, 16, 16, 16, 16, 4, 5, S<8, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, 0, S<8, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, 0, 2, 1, S<1, 64, 1, 4>, S<8, 8, 1>, BlockGemmPipelineScheduler::Intrawave, BlockGemmPipelineVersion::v3, F8>, - DeviceGemmMultiD_Xdl_CShuffle_V3_BPreshuffle< Row, Col, Tuple, Row, F8, F8, Tuple, BF16, F32, F32, PassThrough, PassThrough, MultiplyMultiply, GemmSpec, 256, 128, 128, 128, 16, 16, 16, 16, 4, 4, S<8, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, 0, S<8, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, 0, 1, 2, S<1, 32, 1, 8>, S<8, 8, 1>, BlockGemmPipelineScheduler::Intrawave, BlockGemmPipelineVersion::v3, F8> + DeviceGemmMultiD_Xdl_CShuffle_V3_BPreshuffle< Row, Col, Tuple, Row, F8, F8, Tuple, BF16, F32, F32, PassThrough, PassThrough, MultiplyMultiply, GemmSpec, 256, 128, 128, 128, 16, 16, 16, 16, 8, 2, S<8, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, 0, S<8, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, 0, 2, 1, S<1, 32, 1, 8>, S<8, 8, 1>, BlockGemmPipelineScheduler::Intrawave, BlockGemmPipelineVersion::v3, F8> // clang-format on >; From f05e45ba59b76cb6ea83c471860ded65d5fc623f Mon Sep 17 00:00:00 2001 From: Khushbu Agarwal Date: Mon, 12 May 2025 09:56:23 -0700 Subject: [PATCH 115/443] Disable SMFMA gfx90a (#2184) * sparsity fix for gfx90a * reverting tile_engine changes --- include/ck_tile/ops/gemm/warp/warp_gemm.hpp | 9 --------- .../ops/gemm/warp/warp_gemm_attribute_smfmac_impl.hpp | 4 ++-- tile_engine/ops/gemm/gemm_instance_builder.py | 6 +----- 3 files changed, 3 insertions(+), 16 deletions(-) diff --git a/include/ck_tile/ops/gemm/warp/warp_gemm.hpp b/include/ck_tile/ops/gemm/warp/warp_gemm.hpp index 5ed97dc05c..f050a8e382 100644 --- a/include/ck_tile/ops/gemm/warp/warp_gemm.hpp +++ b/include/ck_tile/ops/gemm/warp/warp_gemm.hpp @@ -109,20 
+109,11 @@ using WarpGemmMfmaF16F16F32M64N4K16 = WarpGemmImpl>; // fp16 2:4 structured sparsity -#if defined(__gfx94__) || defined(__gfx95__) using WarpGemmSmfmacF16F16F32M32N32K16 = WarpGemmSmfmacImpl>>; using WarpGemmSmfmacF16F16F32M16N16K32 = WarpGemmSmfmacImpl>>; -#else // gfx 90a does not support smfmac -using WarpGemmSmfmacF16F16F32M32N32K16 = WarpGemmImpl, - 2>>; -using WarpGemmSmfmacF16F16F32M16N16K32 = WarpGemmImpl, - 2>>; -#endif // bf16 using WarpGemmMfmaBf16Bf16F32M32N32K8 = WarpGemmImpl< diff --git a/include/ck_tile/ops/gemm/warp/warp_gemm_attribute_smfmac_impl.hpp b/include/ck_tile/ops/gemm/warp/warp_gemm_attribute_smfmac_impl.hpp index 97fd2a8742..cd6cd3a399 100644 --- a/include/ck_tile/ops/gemm/warp/warp_gemm_attribute_smfmac_impl.hpp +++ b/include/ck_tile/ops/gemm/warp/warp_gemm_attribute_smfmac_impl.hpp @@ -49,7 +49,7 @@ struct WarpGemmAttributeSmfmacImplF16F16F32M32N32K16 const int32_t& idx, bool_constant = {}) const { -#if defined(__gfx9__) +#if defined(__gfx94_) or defined(__gfx95_) c_vec = __builtin_amdgcn_smfmac_f32_32x32x16_f16(a_vec, b_vec, c_vec, idx, 0, 0); #else ck_tile::ignore = c_vec; @@ -100,7 +100,7 @@ struct WarpGemmAttributeSmfmacImplF16F16F32M16N16K32 const int32_t& idx, bool_constant = {}) const { -#if defined(__gfx9__) +#if defined(__gfx94_) or defined(__gfx95_) c_vec = __builtin_amdgcn_smfmac_f32_16x16x32_f16(a_vec, b_vec, c_vec, idx, 0, 0); #else ck_tile::ignore = c_vec; diff --git a/tile_engine/ops/gemm/gemm_instance_builder.py b/tile_engine/ops/gemm/gemm_instance_builder.py index c00554df8f..3839523e3d 100755 --- a/tile_engine/ops/gemm/gemm_instance_builder.py +++ b/tile_engine/ops/gemm/gemm_instance_builder.py @@ -535,11 +535,7 @@ struct GemmDispatcher { ((tile[6] == 32 and tile[7] == 32 and tile[8] == 16) or (tile[6] == 16 and tile[7] == 16 and tile[8] == 32)) content += f""" -#if defined(__gfx908__) - run_kernel<{group}::GemmKernel<{tile[0]}, {tile[1]}, {tile[2]}, {tile[3]}, {tile[4]}, {tile[5]}, {tile[6]}, {tile[7]}, 
{tile[8]}, {BOOL_MAP(False)}>>(c_m_n_dev_buf, c_m_n_host_result, c_m_n_dev_result, verify, args, stream); -#else - run_kernel<{group}::GemmKernel<{tile[0]}, {tile[1]}, {tile[2]}, {tile[3]}, {tile[4]}, {tile[5]}, {tile[6]}, {tile[7]}, {tile[8]}, {BOOL_MAP(sparse)}>>(c_m_n_dev_buf, c_m_n_host_result, c_m_n_dev_result, verify, args, stream); -#endif""" + run_kernel<{group}::GemmKernel<{tile[0]}, {tile[1]}, {tile[2]}, {tile[3]}, {tile[4]}, {tile[5]}, {tile[6]}, {tile[7]}, {tile[8]}, {BOOL_MAP(sparse)}>>(c_m_n_dev_buf, c_m_n_host_result, c_m_n_dev_result, verify, args, stream);""" content += f""" }} else {{""" for tile in tile_params: From 29206047868b5a3eda88aa33ff5b997ba4e008b4 Mon Sep 17 00:00:00 2001 From: Po Yen Chen Date: Tue, 13 May 2025 12:19:25 +0800 Subject: [PATCH 116/443] [CK_TILE] Add logits soft-capping & customization support to the FMHA forward kernel/pipelines (#2163) * hack for cap logits * fix bug * Re-format files * Allow specifying logits_soft_cap through APIs * Support turn on/off logits_soft_cap in async pipeline * Do not generate non-verified kernels * Align receipt used in Aiter * Sync logits soft-capping across pipelines * Re-enable some hdim pipelines * fix perf * Add attention variant for logits_soft_cap * Add newline at end-of-file * Fix performance * Add comment to explain logits_soft_cap pre-processing * Unify code * Unify floating-point literal style * Use class data member to slience the compilation error * [CK_TILE] Update attention customizaton interface: add LogitsMask() (#2133) * Send 'mask' along with variant params to the LogitsMask() * Send block indices to the variant * Add indices parameters in variant interface * Fix fmha bwd codegen error * Allow switch logits_soft_cap impl * Eliminate register spills * Fix compilation errors * Fix wrong LSE * Fix LSE for splitkv kernel * Sync splitkv pipeline changes * Add batch_prefill kernel/pipeline * Fix codegen error * Undo changes in CMakeLists.txt * Merge pipeline filtering check * Use 
different code path if kHasLogitsSoftCap=false * Remove [[maybe_unused]] attribute * Use pre-existing compile-time flag to instantiate templates * Sync pipeline changes * Update CHANGELOG.md --------- Co-authored-by: Bernard Co-authored-by: coderfeli --- CHANGELOG.md | 1 + .../ck_tile/01_fmha/codegen/cpp_symbol_map.py | 2 + .../01_fmha/codegen/ops/fmha_batch_prefill.py | 595 +++++++++ .../ck_tile/01_fmha/codegen/ops/fmha_bwd.py | 1 + .../ck_tile/01_fmha/codegen/ops/fmha_fwd.py | 72 +- .../01_fmha/codegen/ops/fmha_fwd_splitkv.py | 57 +- example/ck_tile/01_fmha/fmha_fwd.cpp | 17 + example/ck_tile/01_fmha/fmha_fwd.hpp | 212 +++ example/ck_tile/01_fmha/generate.py | 3 +- include/ck_tile/core.hpp | 1 + include/ck_tile/core/numeric/math.hpp | 41 + include/ck_tile/core/tensor/load_tile.hpp | 90 +- include/ck_tile/core/tensor/tensor_view.hpp | 21 + .../core/tensor/tile_scatter_gather.hpp | 731 +++++++++++ .../ck_tile/core/tensor/tile_window_utils.hpp | 7 + include/ck_tile/ops/fmha.hpp | 4 + include/ck_tile/ops/fmha/block/variants.hpp | 274 ++++ .../fmha/kernel/fmha_batch_prefill_kernel.hpp | 1134 +++++++++++++++++ .../ops/fmha/kernel/fmha_fwd_kernel.hpp | 83 +- .../fmha/kernel/fmha_fwd_splitkv_kernel.hpp | 77 +- ..._batch_prefill_pipeline_qr_ks_vs_async.hpp | 900 +++++++++++++ ...pipeline_qr_ks_vs_async_default_policy.hpp | 18 + ...litkv_pipeline_nwarp_sshuffle_qr_ks_vs.hpp | 98 +- ...ock_fmha_fwd_splitkv_pipeline_qr_ks_vs.hpp | 93 +- .../pipeline/block_fmha_pipeline_problem.hpp | 6 + .../pipeline/block_fmha_pipeline_qr_ks_vs.hpp | 102 +- .../block_fmha_pipeline_qr_ks_vs_async.hpp | 101 +- .../pipeline/block_fmha_pipeline_qs_ks_vs.hpp | 102 +- .../ops/fmha/pipeline/tile_fmha_traits.hpp | 4 + 29 files changed, 4621 insertions(+), 226 deletions(-) create mode 100644 example/ck_tile/01_fmha/codegen/ops/fmha_batch_prefill.py create mode 100644 include/ck_tile/core/tensor/tile_scatter_gather.hpp create mode 100644 include/ck_tile/ops/fmha/block/variants.hpp create mode 100644 
include/ck_tile/ops/fmha/kernel/fmha_batch_prefill_kernel.hpp create mode 100644 include/ck_tile/ops/fmha/pipeline/block_fmha_batch_prefill_pipeline_qr_ks_vs_async.hpp create mode 100644 include/ck_tile/ops/fmha/pipeline/block_fmha_batch_prefill_pipeline_qr_ks_vs_async_default_policy.hpp diff --git a/CHANGELOG.md b/CHANGELOG.md index 4be173dd85..a1163f059c 100644 --- a/CHANGELOG.md +++ b/CHANGELOG.md @@ -16,6 +16,7 @@ Documentation for Composable Kernel available at [https://rocm.docs.amd.com/proj * Added GEMM pipeline for microscaling (MX) data types * Added support for FP16 2:4 structured sparsity to universal GEMM. * Added support for Split K for grouped convolution backward data. +* Added logit soft-capping support for fMHA forward kernels. ### Optimized diff --git a/example/ck_tile/01_fmha/codegen/cpp_symbol_map.py b/example/ck_tile/01_fmha/codegen/cpp_symbol_map.py index 332707eafd..5b9d5742b4 100644 --- a/example/ck_tile/01_fmha/codegen/cpp_symbol_map.py +++ b/example/ck_tile/01_fmha/codegen/cpp_symbol_map.py @@ -114,12 +114,14 @@ LAYOUT_MAP = { PIPELINE_MAP = { "qr" : "ck_tile::BlockFmhaPipelineQRKSVS", "qr_async" : "ck_tile::BlockFmhaPipelineQRKSVSAsync", + "qs" : "ck_tile::BlockFmhaPipelineQSKSVS", } PIPELINE_ENUM_MAP = { "qr" : "ck_tile::BlockFmhaPipelineEnum::QRKSVS", "qr_async" : "ck_tile::BlockFmhaPipelineEnum::QRKSVS_ASYNC", "qr_nwarp_sshuffle" : "ck_tile::BlockFmhaPipelineEnum::QRKSVS", + "qs" : "ck_tile::BlockFmhaPipelineEnum::QSKSVS", } BOOL_MAP = { diff --git a/example/ck_tile/01_fmha/codegen/ops/fmha_batch_prefill.py b/example/ck_tile/01_fmha/codegen/ops/fmha_batch_prefill.py new file mode 100644 index 0000000000..30b9299963 --- /dev/null +++ b/example/ck_tile/01_fmha/codegen/ops/fmha_batch_prefill.py @@ -0,0 +1,595 @@ +# SPDX-License-Identifier: MIT +# Copyright (c) 2018-2024, Advanced Micro Devices, Inc. All rights reserved. 
+# generate kernel instances to speed up compilation + +import copy +from dataclasses import dataclass +import fnmatch +import itertools +from pathlib import Path +from typing import List, Optional, Tuple + +from codegen.cmake_config import * +from codegen.cpp_symbol_map import * + + +DTYPE_BITS = { + "fp32": 32, + "fp16": 16, + "bf16": 16, + "fp8" : 8, + "bf8" : 8 +} + +K0_MAX_SUBMAX_MAP = { + 32 : 32, + 64 : 64, + 96 : 128, + 128: 128, + 256: 256 +} + +FMHA_BATCH_PREFILL_PIPELINE_MAP = { + "qr_async" : "ck_tile::BlockFmhaBatchPrefillPipelineQRKSVSAsync", +} + +FMHA_FWD_KERNEL_HEADER = """// SPDX-License-Identifier: MIT +// Copyright (c) 2018-2024, Advanced Micro Devices, Inc. All rights reserved.\n +// auto generated by generate.py +#include "ck_tile/ops/fmha/block/variants.hpp" +#include "fmha_fwd.hpp" +""" + +FMHA_FWD_KERNEL_BODY=""" +using fmha_dtype_{F_idx} = {F_dtype}; + +using fmha_block_tile_{F_idx} = ck_tile::sequence<{F_bm0}, {F_bn0}, {F_bk0}, {F_bn1}, {F_bk1}, {F_bk0max}>; + +using fmha_shape_{F_idx} = ck_tile::TileFmhaShape, + ck_tile::sequence<{F_wm0}, {F_wn0}, {F_wk0}>, + ck_tile::sequence<{F_rm1}, {F_rn1}, {F_rk1}>, + ck_tile::sequence<{F_wm1}, {F_wn1}, {F_wk1}>, + {F_vlayout}>; + +using fmha_trait_{F_idx} = ck_tile::TileFmhaTraits<{F_spad}, + {F_skpad}, + {F_dpad}, + {F_dvpad}, + {F_logits}, + {F_bias}, + false, + {F_lse}, + {F_dropout}, + {F_squant}, + {F_occupancy}>; + +using fmha_variant_{F_idx} = ck_tile::ComposedAttention<{F_logits} * ck_tile::LOGITS_SOFT_CAP, CK_TILE_FMHA_FWD_FAST_EXP2>; + +using fmha_mask_{F_idx} = {F_mask}; + +using fmha_pipeline_problem_{F_idx} = ck_tile::BlockFmhaPipelineProblem< + typename FmhaFwdTypeConfig::QDataType, + typename FmhaFwdTypeConfig::KDataType, + typename FmhaFwdTypeConfig::VDataType, + typename FmhaFwdTypeConfig::SaccDataType, + typename FmhaFwdTypeConfig::SMPLComputeDataType, + typename FmhaFwdTypeConfig::BiasDataType, + typename FmhaFwdTypeConfig::RandValOutputDataType, + typename 
FmhaFwdTypeConfig::LSEDataType, + typename FmhaFwdTypeConfig::PDataType, + typename FmhaFwdTypeConfig::OaccDataType, + typename FmhaFwdTypeConfig::ODataType, + fmha_shape_{F_idx}, + {F_mode}, + fmha_variant_{F_idx}, + fmha_mask_{F_idx}, + fmha_trait_{F_idx}>; + +using fmha_pipeline_{F_idx} = {F_pipeline}< + fmha_pipeline_problem_{F_idx}>; + +using fmha_epilogue_{F_idx} = + ck_tile::Default2DEpilogue::OaccDataType, + typename FmhaFwdTypeConfig<{F_dtype}>::ODataType, + {F_spad}, {F_dvpad}>>; + +using fmha_kernel_{F_idx} = + ck_tile::FmhaBatchPrefillWithPagedKVCacheKernel; + +using trait_{F_idx} = fmha_fwd_traits_<{F_hdim}, {F_dtype}, {F_mode},{F_bm0}, {F_bn0}, {F_bk0}, {F_bn1}, {F_bk1}, {F_bk0max}, {F_vlayout}, + {F_pipeline_enum}, {F_logits}, fmha_mask_{F_idx}, {F_bias}, {F_lse}, {F_dropout}, {F_squant}, {F_spad}, {F_skpad}, {F_dpad}, {F_dvpad}>; + +#include + +template<> +float fmha_batch_prefill_(const ck_tile::stream_config& s, fmha_batch_prefill_args a) +{{ + using k_ = fmha_kernel_{F_idx}; + if(s.log_level_ > 0) + std::cout << ", " << k_::GetName() << std::flush; + auto [kargs, grids] = fmha_batch_prefill_create_kargs_and_grids(a); + constexpr dim3 blocks = k_::BlockSize(); + constexpr ck_tile::index_t kBlockPerCu = k_::kBlockPerCu; + return ck_tile::launch_kernel(s, ck_tile::make_kernel(k_{{}}, grids, blocks, 0, kargs)); +}} +""" + +FMHA_FWD_API_FILENAME="fmha_batch_prefill_api.cpp" +FMHA_FWD_API=""" +float fmha_batch_prefill(fmha_batch_prefill_traits t, fmha_batch_prefill_args a, const ck_tile::stream_config& s){{ + float r = -1; +{F_dispatch} + return r; +}} +""" + +FMHA_FWD_API_PER_DTYPE=""" {F_if}(t.data_type.compare(\"{F_dtype}\") == 0){{ +{F_hdim_case} + }} +""" +FMHA_FWD_API_PER_HDIM_CASE=""" {F_if} (t.hdim_q <= {F_hdim} && t.hdim_v <= {F_hdim_v}) {{ +{F_inner_dispatch} + }} +""" + +FMHA_FWD_API_INNER_DISPATCH=""" {F_if}((t.is_group_mode == {F_mode}) && (t.is_v_rowmajor == {F_vlayout}) && (t.has_logits_soft_cap == {F_logits}) && ({F_mask_check}) && 
(t.bias_type == {F_bias_check}) && (t.has_lse == {F_lse}) && (t.has_dropout == {F_dropout}) && (t.do_fp8_static_quant == {F_squant}) && + ({F_scheck}) && ({F_skcheck}) && ({F_dcheck}) && ({F_dvcheck})) {{ + using trait_ = fmha_fwd_traits_<{F_hdim}, {F_dtype}, {F_mode}, {F_bm0}, {F_bn0}, {F_bk0}, {F_bn1}, {F_bk1}, {F_bk0max}, {F_vlayout}, {F_pipeline_enum}, {F_logits}, {F_mask}, {F_bias}, {F_lse}, {F_dropout}, {F_squant}, {F_spad}, {F_skpad}, {F_dpad}, {F_dvpad}>; + return fmha_batch_prefill_(s, a); + }} +""" + +@dataclass +class FmhaFwdApiTrait: + pipeline_tag : str + # sync with fmha_fwd_traits<>, to generate fallback calls + hdim : str + dtype : str # data type + mode : str # value from MODE_MAP + bm0 : int # tile size along q seqlen (block size) + bn0 : int # tile size along qk seqlen + bk0 : int # tile size along qk gemm unroll + bn1 : int # tile size along v head_dim + bk1 : int # tile size along kv gemm unroll + bk0max : int + vlayout : str + logits : str + mask : str + bias : str # + lse : str # + dropout : str + squant : str # + spad : str + skpad : str + dpad : str + dvpad : str + + @property + def name(self) -> str: + return f'{self.hdim}-{self.dtype}-{self.mode}-{self.bm0}-{self.bn0}-{self.bk0}-{self.bn0}-{self.bk1}-{self.bk0max}-'+\ + f'{self.vlayout}-{self.logits}-{self.mask}-{self.bias}-{self.lse}-{self.dropout}-{self.squant}-{self.spad}-{self.skpad}-{self.dpad}-{self.dvpad}' + + @property + def scheck(self) -> str: + if self.mode == 'group': return 'true/*group mode spad always true*/' # group mode only generate spad/skpad == true + if self.pipeline_tag == 'qr_async': + if self.spad == 't' : return 'true' # always support + else : return 'true' + elif self.pipeline_tag in ['qr']: + if self.spad == 't' : return f'true /*a.seqlen_q % {self.bm0} != 0*/' # TODO: order of get_pipelines() matters! 
(ugly) + else : return f'a.seqlen_q % {self.bm0} == 0' + else: assert False + + @property + def skcheck(self) -> str: + if self.mode == 'group': return 'true/*group mode skpad always true*/' # group mode only generate spad/skpad == true + if self.pipeline_tag == 'qr_async': + if self.skpad == 't' : return f'a.seqlen_k == 0 || a.seqlen_k % {self.bn0} != 0' + else : return f'a.seqlen_k != 0 && a.seqlen_k % {self.bn0} == 0' + elif self.pipeline_tag in ['qr', 'qr_fp8']: + if self.skpad == 't' : return f'true /*a.seqlen_k % {self.bn0} != 0*/' # TODO: order of get_pipelines() matters! (ugly) + else : return f'a.seqlen_k % {self.bn0} == 0' + else: assert False + + @property + def dcheck(self) -> str: + if self.pipeline_tag == 'qr_async': + vec = int((32 * 4) / DTYPE_BITS[self.dtype]) + if self.dpad == 't': return f'a.hdim_q % {vec} == 0' + else : assert False + elif self.pipeline_tag in ['qr']: + bk0submax = K0_MAX_SUBMAX_MAP[self.bk0max] + if self.dpad == 't': return f'true /*a.hdim_q % {bk0submax} != 0*/' # TODO: order of get_pipelines() matters! (ugly) + else : return f'a.hdim_q % {bk0submax} == 0' + else: assert False + + @property + def dvcheck(self) -> str: + if self.pipeline_tag == 'qr_async': + vec = int((32 * 4) / DTYPE_BITS[self.dtype]) + if self.dvpad == 't': return f'a.hdim_v % {vec} == 0' + else : assert False + elif self.pipeline_tag in ['qr']: + bk0submax = K0_MAX_SUBMAX_MAP[self.bk0max] + if self.dvpad == 't': return f'true /*a.hdim_v % {bk0submax} != 0*/' # TODO: order of get_pipelines() matters! 
(ugly) + else : return f'a.hdim_v % {bk0submax} == 0' + else: assert False + +@dataclass +class FmhaFwdPipeline: + tag : str + + F_vlayout : str # row/col + F_spad : str # true/false + F_skpad : str # + F_dpad : str # + F_dvpad : str # + F_logits : str # t/f + F_bias : str # true/false + F_lse : str # + F_dropout : str # + F_squant : str # + F_mask : str # value from MASK_MAP + + @property + def name(self) -> str: + def pad_name() -> str: + n = '' + if self.F_spad == 't': n += 's' + if self.F_skpad == 't' : n += 'sk' + if self.F_dpad == 't' : n += 'd' + if self.F_dvpad == 't' : n += 'dv' + if n != '' : n = 'p' + n + return n + pn = pad_name() + n = f'{self.tag}_v{self.F_vlayout[0]}' + if pn != '' : n += f'_{pn}' + else: n += '_npad' + + if self.F_logits == 't' : n += '_logits' + else: n += '_nlogits' + + if self.F_bias != 'no' : n += f'_{self.F_bias}' + else: n += '_nbias' + + if self.F_mask[0:2] == 's_': + if self.F_mask == 's_mask': n += f'_mask' + else: n += '_nmask' + else: + if self.F_mask != 'no' : n += f'_m{self.F_mask[0]}' + else: n += '_nmask' + + if self.F_lse == 't' : n += '_lse' + else: n += '_nlse' + + if self.F_dropout == 't' : n += '_dropout' + else: n += '_ndropout' + + if self.F_squant == 't' : n += '_squant' + else: n += '_nsquant' + return n + +class FmhaFwdApiPool: + def __init__(self, mask_impl): + self.pool = dict() + self.mask_impl = mask_impl + + def register_traits(self, trait : FmhaFwdApiTrait) -> None: + # TODO: do we need to check duplication? 
+ if trait.dtype not in self.pool.keys(): + self.pool[trait.dtype] = dict() + if trait.hdim not in self.pool[trait.dtype].keys(): + self.pool[trait.dtype][trait.hdim] = list() + + self.pool[trait.dtype][trait.hdim].append(copy.copy(trait)) + + @property + def api(self) -> str: + per_dtypes=str() + for i, dtype in enumerate(self.pool.keys()): + per_hdim_case=str() + for j, hdim in enumerate(self.pool[dtype].keys()): + traits=self.pool[dtype][hdim] + inners=str() + for k, trait in enumerate(traits): + if_k = 'if' if k == 0 else 'else if' + inners = inners + FMHA_FWD_API_INNER_DISPATCH.format(F_if=if_k, F_mode=MODE_MAP[trait.mode], F_vlayout=LAYOUT_MAP[trait.vlayout], + F_pipeline_enum=PIPELINE_ENUM_MAP[trait.pipeline_tag], F_logits=BOOL_MAP[trait.logits], F_mask=get_mask_map(self.mask_impl)[trait.mask], + F_mask_check=get_mask_check_map(self.mask_impl)[trait.mask], F_bias_check=BIAS_CHECK_MAP[trait.bias], F_bias=BIAS_MAP[trait.bias], + F_lse=BOOL_MAP[trait.lse], F_dropout=BOOL_MAP[trait.dropout] , + F_squant=BOOL_MAP[trait.squant], F_scheck=trait.scheck, F_skcheck=trait.skcheck, F_dcheck=trait.dcheck, F_dvcheck=trait.dvcheck, + F_spad=BOOL_MAP[trait.spad], F_skpad=BOOL_MAP[trait.skpad], F_dpad=BOOL_MAP[trait.dpad], F_dvpad=BOOL_MAP[trait.dvpad], + F_bm0=trait.bm0, F_bn0=trait.bn0, F_bk0=trait.bk0, F_bn1=trait.bn1, F_bk1=trait.bk1, F_bk0max=trait.bk0max, + F_hdim=hdim, F_dtype=FWD_DTYPE_MAP[dtype]) + if_j = 'if' if j == 0 else 'else if' + per_hdim_case = per_hdim_case + FMHA_FWD_API_PER_HDIM_CASE.format(F_if=if_j, F_hdim=hdim, F_hdim_v=trait.bn1, F_inner_dispatch=inners) + if_i = 'if' if i == 0 else 'else if' + per_dtypes = per_dtypes + FMHA_FWD_API_PER_DTYPE.format(F_if=if_i, F_dtype=dtype, F_hdim_case=per_hdim_case) + if not per_dtypes: + # empty string we add some ignore to suppress warning in api + per_dtypes += ' (void)t ; (void)s ; (void)a;' + return FMHA_FWD_KERNEL_HEADER + FMHA_FWD_API.format(F_dispatch = per_dtypes) + +@dataclass +class FmhaFwdTileSize: + 
F_bm0 : int # tile size along q seqlen (block size) + F_bn0 : int # tile size along k seqlen + F_bk0 : int # tile size along qk gemm unroll + F_bn1 : int # tile size along v head_dim + F_bk1 : int # tile size along kv gemm unroll + F_bk0max : int # total length of K0, used for pipeline that need load Q at once (or repeately load Q as a whole tile) + F_rm0 : int # number of warps for gemm0 along q seqlen + F_rn0 : int # number of warps for gemm0 along k seqlen + F_rk0 : int # number of warps for gemm0 along head dim q (not used) + F_rm1 : int # number of warps for gemm1 along q seqlen + F_rn1 : int # number of warps for gemm1 along head dim v + F_rk1 : int # number of warps for gemm1 along k seqlen (not used) + F_wm0 : int # gemm0 warp size along m + F_wn0 : int # gemm0 warp size along n + F_wk0 : int # gemm0 warp size along k + F_wm1 : int # gemm1 warp size along m + F_wn1 : int # gemm1 warp size along n + F_wk1 : int # gemm1 warp size along k + F_occupancy : int # occupancy, -1 will let pipeline decide the occupancy, other value will overwrite occupancy + @property + def name(self) -> str: + return f"b{self.F_bm0}x{self.F_bn0}x{self.F_bk0}x{self.F_bn1}x{self.F_bk1}x{self.F_bk0max}" +\ + f"_r{self.F_rm0}x{self.F_rn0}x{self.F_rk0}_r{self.F_rm1}x{self.F_rn1}x{self.F_rk1}" +\ + f"_w{self.F_wm0}x{self.F_wn0}x{self.F_wk0}_w{self.F_wm1}x{self.F_wn1}x{self.F_wk1}" +\ + ("" if self.F_occupancy == -1 else f"_o{self.F_occupancy}") + +@dataclass +class FmhaFwdKernel: + F_idx : int # this is not a tunable, but a counter to differentiate symbol + F_hdim : int # hdim + F_dtype : str # data type + F_mode : str # value from MODE_MAP + F_tile : FmhaFwdTileSize + F_pipeline : FmhaFwdPipeline + mask_impl : str + + @property + def template(self) -> str: + kernel_body = str() + return FMHA_FWD_KERNEL_HEADER + \ + FMHA_FWD_KERNEL_BODY.format( + F_idx = self.F_idx, + F_hdim = self.F_hdim, + F_dtype = FWD_DTYPE_MAP[self.F_dtype], + F_bm0 = self.F_tile.F_bm0, + F_bn0 = self.F_tile.F_bn0, + 
F_bk0 = self.F_tile.F_bk0, + F_bn1 = self.F_tile.F_bn1, + F_bk1 = self.F_tile.F_bk1, + F_bk0max = self.F_tile.F_bk0max, + F_rm0 = self.F_tile.F_rm0, + F_rn0 = self.F_tile.F_rn0, + F_rk0 = self.F_tile.F_rk0, + F_rm1 = self.F_tile.F_rm1, + F_rn1 = self.F_tile.F_rn1, + F_rk1 = self.F_tile.F_rk1, + F_wm0 = self.F_tile.F_wm0, + F_wn0 = self.F_tile.F_wn0, + F_wk0 = self.F_tile.F_wk0, + F_wm1 = self.F_tile.F_wm1, + F_wn1 = self.F_tile.F_wn1, + F_wk1 = self.F_tile.F_wk1, + F_vlayout = LAYOUT_MAP[self.F_pipeline.F_vlayout], + F_spad = BOOL_MAP[self.F_pipeline.F_spad], + F_skpad = BOOL_MAP[self.F_pipeline.F_skpad], + F_dpad = BOOL_MAP[self.F_pipeline.F_dpad], + F_dvpad = BOOL_MAP[self.F_pipeline.F_dvpad], + F_logits = BOOL_MAP[self.F_pipeline.F_logits], + F_bias = BIAS_MAP[self.F_pipeline.F_bias], + F_lse = BOOL_MAP[self.F_pipeline.F_lse], + F_dropout = BOOL_MAP[self.F_pipeline.F_dropout], + F_squant = BOOL_MAP[self.F_pipeline.F_squant], + F_occupancy = self.F_tile.F_occupancy, + F_pipeline_enum = PIPELINE_ENUM_MAP[self.F_pipeline.tag], + F_mask = get_mask_map(self.mask_impl)[self.F_pipeline.F_mask], + F_mode = MODE_MAP[self.F_mode], + F_pipeline = FMHA_BATCH_PREFILL_PIPELINE_MAP[self.F_pipeline.tag]) + + @property + def name(self) -> str: + # TODO: we don't encode idx here + return f"fmha_batch_prefill_d{self.F_hdim}_{self.F_dtype}_{self.F_mode}_" + \ + self.F_tile.name + '_' + self.F_pipeline.name + + @property + def filename(self) -> str: + return self.name + ".cpp" + + def api_trait(self) -> FmhaFwdApiTrait: + return FmhaFwdApiTrait( + pipeline_tag=self.F_pipeline.tag, + hdim=str(self.F_hdim), + dtype=self.F_dtype, + mode=self.F_mode, + bm0=self.F_tile.F_bm0, + bn0=self.F_tile.F_bn0, + bk0=self.F_tile.F_bk0, + bn1=self.F_tile.F_bn1, + bk1=self.F_tile.F_bk1, + bk0max=self.F_tile.F_bk0max, + vlayout=self.F_pipeline.F_vlayout, + mask=self.F_pipeline.F_mask, + logits=self.F_pipeline.F_logits, + bias=self.F_pipeline.F_bias, + lse=self.F_pipeline.F_lse, + 
dropout=self.F_pipeline.F_dropout, + squant=self.F_pipeline.F_squant, + spad=self.F_pipeline.F_spad, + skpad=self.F_pipeline.F_skpad, + dpad=self.F_pipeline.F_dpad, + dvpad=self.F_pipeline.F_dvpad) + +# TODO: design a more practical way to do it +# this is current supported tile size per hdim +def get_fmha_fwd_tile_dict_from_dtype(dtype : str) -> Optional[dict]: + if dtype == 'fp16' or dtype == 'bf16': + return { + ### '32' : FmhaFwdTileSize(128, 64, 16, 32, 32, 32, 2, 1, 1, 2, 1, 1, 32, 32, 16, 32, 32, 16, -1), + ### '64' : FmhaFwdTileSize(128, 64, 32, 64, 32, 64, 4, 1, 1, 4, 1, 1, 32, 32, 16, 32, 32, 16, -1), + ### '96' : FmhaFwdTileSize(128, 128, 32, 128, 32, 96, 4, 1, 1, 4, 1, 1, 32, 32, 16, 32, 32, 16, -1), + '128' : FmhaFwdTileSize(128, 128, 32, 128, 32, 128, 4, 1, 1, 4, 1, 1, 32, 32, 16, 32, 32, 16, -1), + ### '192' : FmhaFwdTileSize(128, 128, 32, 128, 32, 192, 4, 1, 1, 4, 1, 1, 32, 32, 16, 32, 32, 16, -1), + ### '256' : FmhaFwdTileSize(128, 128, 32, 256, 32, 256, 4, 1, 1, 4, 1, 1, 32, 32, 16, 32, 32, 16, -1), + } + elif dtype == 'fp8' or dtype == 'bf8': + return { + ### '64' : FmhaFwdTileSize(128, 64, 32, 64, 32, 64, 2, 1, 1, 2, 1, 1, 32, 32, 32, 32, 32, 32, -1), + ### '128' : FmhaFwdTileSize(128, 128, 32, 128, 32, 128, 4, 1, 1, 4, 1, 1, 32, 32, 32, 32, 32, 32, -1), + ### '256' : FmhaFwdTileSize(128, 128, 32, 256, 32, 256, 4, 1, 1, 4, 1, 1, 32, 32, 32, 32, 32, 32, -1), + } + else: + return None + +def get_fwd_blobs(kernel_filter : Optional[str], receipt, optdim_list, mask_impl) -> Tuple[FmhaFwdApiPool, List[FmhaFwdKernel]]: + # TODO: we don't support tuning yet, so pick up one value for vlayout/pipeline/pad + # support this in future + def get_pipelines(dtype, hdim) -> List[FmhaFwdPipeline]: + # this function will populate a list possible pipelines + # TODO: the order of List matters! the later in this list will be also be checked later + # TODO: currently for qr pipeline, let 't' padding to appear later!! + # TODO: how to design this more generic? 
+ squant = 't' if dtype == 'fp8' else 'f' + pipelines = [] + if dtype in ['fp16', 'bf16']: + for logits, mask, bias, lse, dropout in itertools.product(["t", "f"], get_mask_map(mask_impl).keys(), BIAS_MAP.keys(), ["t", "f"], ["t", "f"]): + if hdim == 256: + # if True: + pipelines.append(FmhaFwdPipeline('qr', 'row', 'f', 'f', 'f', 'f', logits, bias, lse, dropout, squant, mask)) + pipelines.append(FmhaFwdPipeline('qr', 'col', 'f', 'f', 'f', 'f', logits, bias, lse, dropout, squant, mask)) + # the below two is used for hdim vectorize load + pipelines.append(FmhaFwdPipeline('qr', 'row', 't', 't', 'f', 'f', logits, bias, lse, dropout, squant, mask)) + pipelines.append(FmhaFwdPipeline('qr', 'col', 't', 't', 'f', 'f', logits, bias, lse, dropout, squant, mask)) + + pipelines.append(FmhaFwdPipeline('qr', 'row', 't', 't', 't', 't', logits, bias, lse, dropout, squant, mask)) + pipelines.append(FmhaFwdPipeline('qr', 'col', 't', 't', 't', 't', logits, bias, lse, dropout, squant, mask)) + else: + if bias == "bias": + # TODO: rocm 6.2 compiler problem if using qr_async for bias case + pipelines.append(FmhaFwdPipeline('qr', 'row', 'f', 'f', 'f', 'f', logits, bias, lse, dropout, squant, mask)) + pipelines.append(FmhaFwdPipeline('qr', 'row', 't', 't', 't', 't', logits, bias, lse, dropout, squant, mask)) + pipelines.append(FmhaFwdPipeline('qr', 'col', 'f', 'f', 'f', 'f', logits, bias, lse, dropout, squant, mask)) + pipelines.append(FmhaFwdPipeline('qr', 'col', 't', 't', 't', 't', logits, bias, lse, dropout, squant, mask)) + else: + pipelines.append(FmhaFwdPipeline('qr_async', 'row', 't', 'f', 't', 't', logits, bias, lse, dropout, squant, mask)) + pipelines.append(FmhaFwdPipeline('qr_async', 'row', 't', 't', 't', 't', logits, bias, lse, dropout, squant, mask)) + pipelines.append(FmhaFwdPipeline('qr_async', 'col', 't', 'f', 't', 't', logits, bias, lse, dropout, squant, mask)) + pipelines.append(FmhaFwdPipeline('qr_async', 'col', 't', 't', 't', 't', logits, bias, lse, dropout, squant, 
mask)) + if receipt == 1 and bias != "bias": + pipelines.append(FmhaFwdPipeline('qr', 'row', 't', 't', 't', 't', logits, bias, lse, dropout, squant, mask)) # TODO: cover arbitraty hdim + pipelines.append(FmhaFwdPipeline('qr', 'col', 't', 'f', 't', 't', logits, bias, lse, dropout, squant, mask)) # TODO: cover arbitraty hdim + elif dtype in ['fp8', 'bf8']: + # no need lse/dropout kernels + for logits, mask, bias in itertools.product(["t", "f"], get_mask_map(mask_impl).keys(), BIAS_MAP.keys()): + pipelines.append(FmhaFwdPipeline('qr', 'col', 'f', 'f', 'f', 'f', logits, bias, 'f', 'f', squant, mask)) + elif dtype in ['fp8fp16', 'fp8bf16']: + # TODO + None + else: + assert False + return pipelines + + gen = list() + api_pool = FmhaFwdApiPool(mask_impl) + + for dtype in FWD_DTYPE_MAP.keys(): + d = get_fmha_fwd_tile_dict_from_dtype(dtype) + if d == None: + continue + #for hdim_str, mode, mask, bias, lse in itertools.product(d.keys(), MODE_MAP.keys(), MASK_MAP.keys(), ["t", "f"], ["t", "f"]): + for hdim_str, mode in itertools.product(d.keys(), MODE_MAP.keys()): + tile = d[hdim_str] + hdim = int(hdim_str) + for pipeline in get_pipelines(dtype, hdim): + if mode == "group": + if pipeline.F_spad != 't' or pipeline.F_skpad != 't': + # in group mode, spad/skpad must be true, since we can't predict if seqlen of current batch need pad or not + continue + if hdim == 192 and tile.F_bn1 == 128: + # NOTE: this is used to speedup deepseek prefill case, we don't gen training + if pipeline.F_bias != 'no' or pipeline.F_lse == 't' or pipeline.F_dropout == 't': + continue + # logits_soft_cap is only allowed if no bias + if not ((pipeline.F_logits == 't' and pipeline.F_bias == 'no') or pipeline.F_logits == 'f'): + continue + k = FmhaFwdKernel(F_idx=0, + F_hdim=hdim, + F_dtype=dtype, + F_mode=mode, + F_tile=tile, + F_pipeline=pipeline, + mask_impl=mask_impl) + if kernel_filter != '': + if not fnmatch.fnmatch(k.name, kernel_filter): + continue + if optdim_list != [-1]: + if hdim not in 
optdim_list: + continue + # 2 - Flash attention integration + if receipt in (2, 3): + cond = dtype in ['fp16', 'bf16'] + cond &= pipeline.F_vlayout == 'row' + cond &= pipeline.F_bias in ['no', 'alibi'] + cond &= pipeline.F_squant == 'f' + if not cond: + continue + # PyTorch integration + elif receipt == 4: + cond = dtype in ['fp16', 'bf16'] + cond &= pipeline.F_vlayout == 'row' + cond &= pipeline.F_bias in ['no', 'bias'] + cond &= pipeline.F_squant == 'f' + if not cond: + continue + # Aiter(mha_fwd) integration + elif receipt == 100: + cond = dtype in ['fp16', 'bf16'] + cond &= mode == 'batch' + cond &= pipeline.F_vlayout == 'row' + cond &= pipeline.F_squant == 'f' + if not cond: + continue + # Aiter(mha_batch_prefill) integration + elif receipt == 200: + cond = dtype in ['fp16', 'bf16'] + cond &= mode == 'group' + cond &= pipeline.F_vlayout == 'row' + cond &= pipeline.F_squant == 'f' + if not cond: + continue + # aiter::mha_batch_prefill C++ api integration + elif receipt == 600: + cond = dtype in ['fp16', 'bf16'] + cond &= mode == 'group' + cond &= pipeline.F_vlayout == 'row' + cond &= pipeline.F_squant == 'f' + if not cond: + continue + api_pool.register_traits(k.api_trait()) + gen.append(k) + + return (api_pool, gen) + +def write_single_fwd_kernel(kernel: FmhaFwdKernel, autogen_dir: Path) -> None: + (autogen_dir / kernel.filename).write_text(kernel.template) + +def write_fwd_api(api_pool : FmhaFwdApiPool, autogen_dir: Path) -> None: + (autogen_dir / FMHA_FWD_API_FILENAME).write_text(api_pool.api) + +def write_blobs(output_dir : Path, kernel_filter : str, receipt, optdim_list, mask_impl) -> None: + api_pool, kernels = get_fwd_blobs(kernel_filter, receipt, optdim_list, mask_impl) + for kernel in kernels: + write_single_fwd_kernel(kernel, output_dir) + write_fwd_api(api_pool, output_dir) + +def list_blobs(file_path : Path, kernel_filter : str, receipt, optdim_list, mask_impl) -> None: + with file_path.open('a') as f: + _, kernels = get_fwd_blobs(kernel_filter, 
receipt, optdim_list, mask_impl) + for kernel in kernels: + f.write(str(file_path.parent / GEN_DIR / kernel.filename) + "\n") + f.write(str(file_path.parent / GEN_DIR / FMHA_FWD_API_FILENAME) + "\n") diff --git a/example/ck_tile/01_fmha/codegen/ops/fmha_bwd.py b/example/ck_tile/01_fmha/codegen/ops/fmha_bwd.py index 932f6020b6..80b64f918a 100644 --- a/example/ck_tile/01_fmha/codegen/ops/fmha_bwd.py +++ b/example/ck_tile/01_fmha/codegen/ops/fmha_bwd.py @@ -60,6 +60,7 @@ using fmha_bwd_trait_{F_idx} = ck_tile::TileFmhaTraits<{F_spad}, {F_skpad}, {F_dpad}, {F_dvpad}, + false, {F_bias}, {F_dbias}, false, diff --git a/example/ck_tile/01_fmha/codegen/ops/fmha_fwd.py b/example/ck_tile/01_fmha/codegen/ops/fmha_fwd.py index c31a0ce954..2f1287c87a 100644 --- a/example/ck_tile/01_fmha/codegen/ops/fmha_fwd.py +++ b/example/ck_tile/01_fmha/codegen/ops/fmha_fwd.py @@ -32,6 +32,7 @@ K0_MAX_SUBMAX_MAP = { FMHA_FWD_KERNEL_HEADER = """// SPDX-License-Identifier: MIT // Copyright (c) 2018-2024, Advanced Micro Devices, Inc. 
All rights reserved.\n // auto generated by generate.py +#include "ck_tile/ops/fmha/block/variants.hpp" #include "fmha_fwd.hpp" """ @@ -51,12 +52,16 @@ using fmha_trait_{F_idx} = ck_tile::TileFmhaTraits<{F_spad}, {F_skpad}, {F_dpad}, {F_dvpad}, + {F_logits}, {F_bias}, false, {F_lse}, {F_dropout}, {F_squant}, {F_occupancy}>; + +using fmha_variant_{F_idx} = ck_tile::ComposedAttention<{F_logits} * ck_tile::LOGITS_SOFT_CAP, CK_TILE_FMHA_FWD_FAST_EXP2>; + using fmha_mask_{F_idx} = {F_mask}; using fmha_pipeline_problem_{F_idx} = ck_tile::BlockFmhaPipelineProblem< @@ -73,6 +78,7 @@ using fmha_pipeline_problem_{F_idx} = ck_tile::BlockFmhaPipelineProblem< typename FmhaFwdTypeConfig::ODataType, fmha_shape_{F_idx}, {F_mode}, + fmha_variant_{F_idx}, fmha_mask_{F_idx}, fmha_trait_{F_idx}>; @@ -88,7 +94,7 @@ using fmha_kernel_{F_idx} = ck_tile::FmhaFwdKernel; using trait_{F_idx} = fmha_fwd_traits_<{F_hdim}, {F_dtype}, {F_mode},{F_bm0}, {F_bn0}, {F_bk0}, {F_bn1}, {F_bk1}, {F_bk0max}, {F_vlayout}, - {F_pipeline_enum}, fmha_mask_{F_idx}, {F_bias}, {F_lse}, {F_dropout}, {F_squant}, {F_spad}, {F_skpad}, {F_dpad}, {F_dvpad}>; + {F_pipeline_enum}, {F_logits}, fmha_mask_{F_idx}, {F_bias}, {F_lse}, {F_dropout}, {F_squant}, {F_spad}, {F_skpad}, {F_dpad}, {F_dvpad}>; #include @@ -123,9 +129,9 @@ FMHA_FWD_API_PER_HDIM_CASE=""" {F_if} (t.hdim_q <= {F_hdim} && t.hdim_v < }} """ -FMHA_FWD_API_INNER_DISPATCH=""" {F_if}((t.is_group_mode == {F_mode}) && (t.is_v_rowmajor == {F_vlayout}) && ({F_mask_check}) && (t.bias_type == {F_bias_check}) && (t.has_lse == {F_lse}) && (t.has_dropout == {F_dropout}) && (t.do_fp8_static_quant == {F_squant}) && +FMHA_FWD_API_INNER_DISPATCH=""" {F_if}((t.is_group_mode == {F_mode}) && (t.is_v_rowmajor == {F_vlayout}) && (t.has_logits_soft_cap == {F_logits}) && ({F_mask_check}) && (t.bias_type == {F_bias_check}) && (t.has_lse == {F_lse}) && (t.has_dropout == {F_dropout}) && (t.do_fp8_static_quant == {F_squant}) && ({F_scheck}) && ({F_skcheck}) && ({F_dcheck}) && 
({F_dvcheck})) {{ - using trait_ = fmha_fwd_traits_<{F_hdim}, {F_dtype}, {F_mode}, {F_bm0}, {F_bn0}, {F_bk0}, {F_bn1}, {F_bk1}, {F_bk0max}, {F_vlayout}, {F_pipeline_enum}, {F_mask}, {F_bias}, {F_lse}, {F_dropout}, {F_squant}, {F_spad}, {F_skpad}, {F_dpad}, {F_dvpad}>; + using trait_ = fmha_fwd_traits_<{F_hdim}, {F_dtype}, {F_mode}, {F_bm0}, {F_bn0}, {F_bk0}, {F_bn1}, {F_bk1}, {F_bk0max}, {F_vlayout}, {F_pipeline_enum}, {F_logits}, {F_mask}, {F_bias}, {F_lse}, {F_dropout}, {F_squant}, {F_spad}, {F_skpad}, {F_dpad}, {F_dvpad}>; return fmha_fwd_(s, a); }} """ @@ -144,6 +150,7 @@ class FmhaFwdApiTrait: bk1 : int # tile size along kv gemm unroll bk0max : int vlayout : str + logits : str mask : str bias : str # lse : str # @@ -157,7 +164,7 @@ class FmhaFwdApiTrait: @property def name(self) -> str: return f'{self.hdim}-{self.dtype}-{self.mode}-{self.bm0}-{self.bn0}-{self.bk0}-{self.bn0}-{self.bk1}-{self.bk0max}-'+\ - f'{self.vlayout}-{self.mask}-{self.bias}-{self.lse}-{self.dropout}-{self.squant}-{self.spad}-{self.skpad}-{self.dpad}-{self.dvpad}' + f'{self.vlayout}-{self.logits}-{self.mask}-{self.bias}-{self.lse}-{self.dropout}-{self.squant}-{self.spad}-{self.skpad}-{self.dpad}-{self.dvpad}' @property def scheck(self) -> str: @@ -165,7 +172,7 @@ class FmhaFwdApiTrait: if self.pipeline_tag == 'qr_async': if self.spad == 't' : return 'true' # always support else : return 'true' - elif self.pipeline_tag in ['qr']: + elif self.pipeline_tag in ['qr', 'qs']: if self.spad == 't' : return f'true /*a.seqlen_q % {self.bm0} != 0*/' # TODO: order of get_pipelines() matters! 
(ugly) else : return f'a.seqlen_q % {self.bm0} == 0' else: assert False @@ -176,7 +183,7 @@ class FmhaFwdApiTrait: if self.pipeline_tag == 'qr_async': if self.skpad == 't' : return f'a.seqlen_k == 0 || a.seqlen_k % {self.bn0} != 0' else : return f'a.seqlen_k != 0 && a.seqlen_k % {self.bn0} == 0' - elif self.pipeline_tag in ['qr', 'qr_fp8']: + elif self.pipeline_tag in ['qr', 'qs']: if self.skpad == 't' : return f'true /*a.seqlen_k % {self.bn0} != 0*/' # TODO: order of get_pipelines() matters! (ugly) else : return f'a.seqlen_k % {self.bn0} == 0' else: assert False @@ -187,7 +194,7 @@ class FmhaFwdApiTrait: vec = int((32 * 4) / DTYPE_BITS[self.dtype]) if self.dpad == 't': return f'a.hdim_q % {vec} == 0' else : assert False - elif self.pipeline_tag in ['qr']: + elif self.pipeline_tag in ['qr', 'qs']: bk0submax = K0_MAX_SUBMAX_MAP[self.bk0max] if self.dpad == 't': return f'true /*a.hdim_q % {bk0submax} != 0*/' # TODO: order of get_pipelines() matters! (ugly) else : return f'a.hdim_q % {bk0submax} == 0' @@ -199,7 +206,7 @@ class FmhaFwdApiTrait: vec = int((32 * 4) / DTYPE_BITS[self.dtype]) if self.dvpad == 't': return f'a.hdim_v % {vec} == 0' else : assert False - elif self.pipeline_tag in ['qr']: + elif self.pipeline_tag in ['qr', 'qs']: bk0submax = K0_MAX_SUBMAX_MAP[self.bk0max] if self.dvpad == 't': return f'true /*a.hdim_v % {bk0submax} != 0*/' # TODO: order of get_pipelines() matters! 
(ugly) else : return f'a.hdim_v % {bk0submax} == 0' @@ -214,6 +221,7 @@ class FmhaFwdPipeline: F_skpad : str # F_dpad : str # F_dvpad : str # + F_logits : str # t/f F_bias : str # true/false F_lse : str # F_dropout : str # @@ -235,6 +243,9 @@ class FmhaFwdPipeline: if pn != '' : n += f'_{pn}' else: n += '_npad' + if self.F_logits == 't' : n += '_logits' + else: n += '_nlogits' + if self.F_bias != 'no' : n += f'_{self.F_bias}' else: n += '_nbias' @@ -280,7 +291,7 @@ class FmhaFwdApiPool: for k, trait in enumerate(traits): if_k = 'if' if k == 0 else 'else if' inners = inners + FMHA_FWD_API_INNER_DISPATCH.format(F_if=if_k, F_mode=MODE_MAP[trait.mode], F_vlayout=LAYOUT_MAP[trait.vlayout], - F_pipeline_enum=PIPELINE_ENUM_MAP[trait.pipeline_tag], F_mask=get_mask_map(self.mask_impl)[trait.mask], + F_pipeline_enum=PIPELINE_ENUM_MAP[trait.pipeline_tag], F_logits=BOOL_MAP[trait.logits], F_mask=get_mask_map(self.mask_impl)[trait.mask], F_mask_check=get_mask_check_map(self.mask_impl)[trait.mask], F_bias_check=BIAS_CHECK_MAP[trait.bias], F_bias=BIAS_MAP[trait.bias], F_lse=BOOL_MAP[trait.lse], F_dropout=BOOL_MAP[trait.dropout] , F_squant=BOOL_MAP[trait.squant], F_scheck=trait.scheck, F_skcheck=trait.skcheck, F_dcheck=trait.dcheck, F_dvcheck=trait.dvcheck, @@ -365,6 +376,7 @@ class FmhaFwdKernel: F_skpad = BOOL_MAP[self.F_pipeline.F_skpad], F_dpad = BOOL_MAP[self.F_pipeline.F_dpad], F_dvpad = BOOL_MAP[self.F_pipeline.F_dvpad], + F_logits = BOOL_MAP[self.F_pipeline.F_logits], F_bias = BIAS_MAP[self.F_pipeline.F_bias], F_lse = BOOL_MAP[self.F_pipeline.F_lse], F_dropout = BOOL_MAP[self.F_pipeline.F_dropout], @@ -399,6 +411,7 @@ class FmhaFwdKernel: bk0max=self.F_tile.F_bk0max, vlayout=self.F_pipeline.F_vlayout, mask=self.F_pipeline.F_mask, + logits=self.F_pipeline.F_logits, bias=self.F_pipeline.F_bias, lse=self.F_pipeline.F_lse, dropout=self.F_pipeline.F_dropout, @@ -440,36 +453,36 @@ def get_fwd_blobs(kernel_filter : Optional[str], receipt, optdim_list, mask_impl squant = 't' if 
dtype == 'fp8' else 'f' pipelines = [] if dtype in ['fp16', 'bf16']: - for mask, bias, lse, dropout in itertools.product(get_mask_map(mask_impl).keys(), BIAS_MAP.keys(), ["t", "f"], ["t", "f"]): + for logits, mask, bias, lse, dropout in itertools.product(["t", "f"], get_mask_map(mask_impl).keys(), BIAS_MAP.keys(), ["t", "f"], ["t", "f"]): if hdim == 256: # if True: - pipelines.append(FmhaFwdPipeline('qr', 'row', 'f', 'f', 'f', 'f', bias, lse, dropout, squant, mask)) - pipelines.append(FmhaFwdPipeline('qr', 'col', 'f', 'f', 'f', 'f', bias, lse, dropout, squant, mask)) + pipelines.append(FmhaFwdPipeline('qr', 'row', 'f', 'f', 'f', 'f', logits, bias, lse, dropout, squant, mask)) + pipelines.append(FmhaFwdPipeline('qr', 'col', 'f', 'f', 'f', 'f', logits, bias, lse, dropout, squant, mask)) # the below two is used for hdim vectorize load - pipelines.append(FmhaFwdPipeline('qr', 'row', 't', 't', 'f', 'f', bias, lse, dropout, squant, mask)) - pipelines.append(FmhaFwdPipeline('qr', 'col', 't', 't', 'f', 'f', bias, lse, dropout, squant, mask)) + pipelines.append(FmhaFwdPipeline('qr', 'row', 't', 't', 'f', 'f', logits, bias, lse, dropout, squant, mask)) + pipelines.append(FmhaFwdPipeline('qr', 'col', 't', 't', 'f', 'f', logits, bias, lse, dropout, squant, mask)) - pipelines.append(FmhaFwdPipeline('qr', 'row', 't', 't', 't', 't', bias, lse, dropout, squant, mask)) - pipelines.append(FmhaFwdPipeline('qr', 'col', 't', 't', 't', 't', bias, lse, dropout, squant, mask)) + pipelines.append(FmhaFwdPipeline('qr', 'row', 't', 't', 't', 't', logits, bias, lse, dropout, squant, mask)) + pipelines.append(FmhaFwdPipeline('qr', 'col', 't', 't', 't', 't', logits, bias, lse, dropout, squant, mask)) else: if bias == "bias": # TODO: rocm 6.2 compiler problem if using qr_async for bias case - pipelines.append(FmhaFwdPipeline('qr', 'row', 'f', 'f', 'f', 'f', bias, lse, dropout, squant, mask)) - pipelines.append(FmhaFwdPipeline('qr', 'row', 't', 't', 't', 't', bias, lse, dropout, squant, mask)) - 
pipelines.append(FmhaFwdPipeline('qr', 'col', 'f', 'f', 'f', 'f', bias, lse, dropout, squant, mask)) - pipelines.append(FmhaFwdPipeline('qr', 'col', 't', 't', 't', 't', bias, lse, dropout, squant, mask)) + pipelines.append(FmhaFwdPipeline('qr', 'row', 'f', 'f', 'f', 'f', logits, bias, lse, dropout, squant, mask)) + pipelines.append(FmhaFwdPipeline('qr', 'row', 't', 't', 't', 't', logits, bias, lse, dropout, squant, mask)) + pipelines.append(FmhaFwdPipeline('qr', 'col', 'f', 'f', 'f', 'f', logits, bias, lse, dropout, squant, mask)) + pipelines.append(FmhaFwdPipeline('qr', 'col', 't', 't', 't', 't', logits, bias, lse, dropout, squant, mask)) else: - pipelines.append(FmhaFwdPipeline('qr_async', 'row', 't', 'f', 't', 't', bias, lse, dropout, squant, mask)) - pipelines.append(FmhaFwdPipeline('qr_async', 'row', 't', 't', 't', 't', bias, lse, dropout, squant, mask)) - pipelines.append(FmhaFwdPipeline('qr_async', 'col', 't', 'f', 't', 't', bias, lse, dropout, squant, mask)) - pipelines.append(FmhaFwdPipeline('qr_async', 'col', 't', 't', 't', 't', bias, lse, dropout, squant, mask)) + pipelines.append(FmhaFwdPipeline('qr_async', 'row', 't', 'f', 't', 't', logits, bias, lse, dropout, squant, mask)) + pipelines.append(FmhaFwdPipeline('qr_async', 'row', 't', 't', 't', 't', logits, bias, lse, dropout, squant, mask)) + pipelines.append(FmhaFwdPipeline('qr_async', 'col', 't', 'f', 't', 't', logits, bias, lse, dropout, squant, mask)) + pipelines.append(FmhaFwdPipeline('qr_async', 'col', 't', 't', 't', 't', logits, bias, lse, dropout, squant, mask)) if receipt == 1 and bias != "bias": - pipelines.append(FmhaFwdPipeline('qr', 'row', 't', 't', 't', 't', bias, lse, dropout, squant, mask)) # TODO: cover arbitraty hdim - pipelines.append(FmhaFwdPipeline('qr', 'col', 't', 'f', 't', 't', bias, lse, dropout, squant, mask)) # TODO: cover arbitraty hdim + pipelines.append(FmhaFwdPipeline('qr', 'row', 't', 't', 't', 't', logits, bias, lse, dropout, squant, mask)) # TODO: cover arbitraty hdim + 
pipelines.append(FmhaFwdPipeline('qr', 'col', 't', 'f', 't', 't', logits, bias, lse, dropout, squant, mask)) # TODO: cover arbitraty hdim elif dtype in ['fp8', 'bf8']: # no need lse/dropout kernels - for mask, bias in itertools.product(get_mask_map(mask_impl).keys(), BIAS_MAP.keys()): - pipelines.append(FmhaFwdPipeline('qr', 'col', 'f', 'f', 'f', 'f', bias, 'f', 'f', squant, mask)) + for logits, mask, bias in itertools.product(["t", "f"], get_mask_map(mask_impl).keys(), BIAS_MAP.keys()): + pipelines.append(FmhaFwdPipeline('qr', 'col', 'f', 'f', 'f', 'f', logits, bias, 'f', 'f', squant, mask)) elif dtype in ['fp8fp16', 'fp8bf16']: # TODO None @@ -497,6 +510,9 @@ def get_fwd_blobs(kernel_filter : Optional[str], receipt, optdim_list, mask_impl # NOTE: this is used to speedup deepseek prefill case, we don't gen training if pipeline.F_bias != 'no' or pipeline.F_lse == 't' or pipeline.F_dropout == 't': continue + # logits_soft_cap is only allowed if no bias + if not ((pipeline.F_logits == 't' and pipeline.F_bias == 'no') or pipeline.F_logits == 'f'): + continue k = FmhaFwdKernel(F_idx=0, F_hdim=hdim, F_dtype=dtype, diff --git a/example/ck_tile/01_fmha/codegen/ops/fmha_fwd_splitkv.py b/example/ck_tile/01_fmha/codegen/ops/fmha_fwd_splitkv.py index 5ad118fd1a..3ae0e28be3 100644 --- a/example/ck_tile/01_fmha/codegen/ops/fmha_fwd_splitkv.py +++ b/example/ck_tile/01_fmha/codegen/ops/fmha_fwd_splitkv.py @@ -45,6 +45,7 @@ FMHA_FWD_SPLITKV_PIPELINE_MAP = { FMHA_FWD_SPLITKV_KERNEL_BODY=""" using fmha_dtype_{F_idx} = {F_dtype}; +using fmha_variant_{F_idx} = ck_tile::ComposedAttention<{F_logits} * ck_tile::LOGITS_SOFT_CAP, CK_TILE_FMHA_FWD_FAST_EXP2>; using fmha_mask_{F_idx} = {F_mask}; namespace {{ @@ -63,6 +64,7 @@ using fmha_trait = ck_tile::TileFmhaFwdSplitKVTraits<{F_spad}, {F_skpad}, {F_dpad}, {F_dvpad}, + {F_logits}, {F_bias}, /*kHasBiasGrad=*/false, {F_lse}, @@ -85,6 +87,7 @@ using fmha_pipeline_problem = ck_tile::BlockFmhaFwdSplitKVPipelineProblem< typename 
FmhaFwdTypeConfig::OaccDataType, fmha_shape, {F_mode}, + fmha_variant_{F_idx}, fmha_mask_{F_idx}, fmha_trait>; @@ -113,7 +116,7 @@ static void run(const ck_tile::stream_config& s, fmha_fwd_splitkv_args a) }} using trait_{F_idx} = fmha_fwd_splitkv_traits_<{F_hdim}, {F_dtype}, {F_mode}, {F_bm0}, {F_bn0}, {F_bk0}, {F_bn1}, {F_bk1}, {F_bk0max}, {F_vlayout}, - {F_pipeline_enum}, fmha_mask_{F_idx}, {F_bias}, {F_lse}, {F_squant}, {F_pagedkv}, {F_spad}, {F_skpad}, {F_dpad}, + {F_pipeline_enum}, {F_logits}, fmha_mask_{F_idx}, {F_bias}, {F_lse}, {F_squant}, {F_pagedkv}, {F_spad}, {F_skpad}, {F_dpad}, {F_dvpad}>; #include @@ -267,9 +270,9 @@ float fmha_fwd_splitkv(fmha_fwd_splitkv_traits t, fmha_fwd_splitkv_args a, const }} """ -FMHA_FWD_SPLITKV_API_INNER_DISPATCH=""" {F_if}((t.is_group_mode == {F_mode}) && (t.is_v_rowmajor == {F_vlayout}) && ({F_mask_check}) && (t.bias_type == {F_bias_check}) && (t.do_fp8_static_quant == {F_squant}) && +FMHA_FWD_SPLITKV_API_INNER_DISPATCH=""" {F_if}((t.is_group_mode == {F_mode}) && (t.is_v_rowmajor == {F_vlayout}) && (t.has_logits_soft_cap == {F_logits}) && ({F_mask_check}) && (t.bias_type == {F_bias_check}) && (t.do_fp8_static_quant == {F_squant}) && ((a.block_table_ptr != nullptr) == {F_pagedkv}) && ({F_scheck}) && ({F_skcheck}) && ({F_dcheck}) && ({F_dvcheck})) {{ - using traits_ = fmha_fwd_splitkv_traits_<{F_hdim}, {F_dtype}, {F_mode}, {F_bm0}, {F_bn0}, {F_bk0}, {F_bn1}, {F_bk1}, {F_bk0max}, {F_vlayout}, {F_pipeline_enum}, {F_mask}, {F_bias}, true, {F_squant}, {F_pagedkv}, {F_spad}, {F_skpad}, {F_dpad}, {F_dvpad}>; + using traits_ = fmha_fwd_splitkv_traits_<{F_hdim}, {F_dtype}, {F_mode}, {F_bm0}, {F_bn0}, {F_bk0}, {F_bn1}, {F_bk1}, {F_bk0max}, {F_vlayout}, {F_pipeline_enum}, {F_logits}, {F_mask}, {F_bias}, true, {F_squant}, {F_pagedkv}, {F_spad}, {F_skpad}, {F_dpad}, {F_dvpad}>; // get combine kernel tile sizes using OaccDataType = typename FmhaFwdTypeConfig<{F_dtype}>::OaccDataType; @@ -310,6 +313,7 @@ class FmhaFwdSplitKVApiTrait: 
bk0max : int vlayout : str mask : str + logits : str bias : str # lse : str # squant : str # @@ -322,7 +326,7 @@ class FmhaFwdSplitKVApiTrait: @property def name(self) -> str: return f'{self.hdim}-{self.dtype}-{self.mode}-{self.bm0}-{self.bn0}-{self.bk0}-{self.bn0}-{self.bk1}-{self.bk0max}-'+\ - f'{self.vlayout}-{self.mask}-{self.bias}-{self.lse}-{self.squant}-{self.spad}-{self.skpad}-{self.dpad}-'+\ + f'{self.vlayout}-{self.logits}-{self.mask}-{self.bias}-{self.lse}-{self.squant}-{self.spad}-{self.skpad}-{self.dpad}-'+\ f'{self.dvpad}-{self.pagedkv}' @property @@ -380,6 +384,7 @@ class FmhaFwdSplitKVPipeline: F_skpad : str # F_dpad : str # F_dvpad : str # + F_logits : str # t/f F_bias : str # true/false F_lse : str # F_squant : str # @@ -401,6 +406,9 @@ class FmhaFwdSplitKVPipeline: if pn != '' : n += f'_{pn}' else: n += '_npad' + if self.F_logits == 't' : n += '_logits' + else: n += '_nlogits' + if self.F_bias != 'no' : n += f'_{self.F_bias}' else: n += '_nbias' @@ -475,7 +483,7 @@ class FmhaFwdSplitKVApiPool: for k, trait in enumerate(traits): if_k = 'if' if k == 0 else 'else if' inners = inners + FMHA_FWD_SPLITKV_API_INNER_DISPATCH.format(F_if=if_k, F_mode=MODE_MAP[trait.mode], F_vlayout=LAYOUT_MAP[trait.vlayout], - F_pipeline_enum=PIPELINE_ENUM_MAP[trait.pipeline_tag], F_mask=get_mask_map(self.mask_impl)[trait.mask], + F_pipeline_enum=PIPELINE_ENUM_MAP[trait.pipeline_tag], F_logits=BOOL_MAP[trait.logits], F_mask=get_mask_map(self.mask_impl)[trait.mask], F_mask_check=get_mask_check_map(self.mask_impl)[trait.mask], F_bias_check=BIAS_CHECK_MAP[trait.bias], F_bias=BIAS_MAP[trait.bias], F_lse=BOOL_MAP[trait.lse], F_squant=BOOL_MAP[trait.squant], F_pagedkv=BOOL_MAP[trait.pagedkv], F_scheck=trait.scheck, F_skcheck=trait.skcheck, F_dcheck=trait.dcheck, F_dvcheck=trait.dvcheck, @@ -541,6 +549,7 @@ class FmhaFwdSplitKVKernel: F_skpad = BOOL_MAP[self.F_pipeline.F_skpad], F_dpad = BOOL_MAP[self.F_pipeline.F_dpad], F_dvpad = BOOL_MAP[self.F_pipeline.F_dvpad], + F_logits = 
BOOL_MAP[self.F_pipeline.F_logits], F_bias = BIAS_MAP[self.F_pipeline.F_bias], F_lse = BOOL_MAP[self.F_pipeline.F_lse], F_squant = BOOL_MAP[self.F_pipeline.F_squant], @@ -574,6 +583,7 @@ class FmhaFwdSplitKVKernel: bk1=self.F_tile.F_bk1, bk0max=self.F_tile.F_bk0max, vlayout=self.F_pipeline.F_vlayout, + logits=self.F_pipeline.F_logits, mask=self.F_pipeline.F_mask, bias=self.F_pipeline.F_bias, lse=self.F_pipeline.F_lse, @@ -671,32 +681,32 @@ def get_fwd_splitkv_blobs(kernel_filter : Optional[str], receipt, mask_impl) -> squant = 't' if dtype == 'fp8' else 'f' pipelines = [] if dtype in ['fp16', 'bf16']: - for mask, bias, pagedkv in itertools.product(get_mask_map(mask_impl).keys(), BIAS_MAP.keys(), ["t", "f"]): + for logits, mask, bias, pagedkv in itertools.product(["t", "f"], get_mask_map(mask_impl).keys(), BIAS_MAP.keys(), ["t", "f"]): # TODO: use async pipeline when compiler is more stable if hdim == 256 or hdim in [32, 64, 128]: ### [32, 64, 96, 128]: # if True: - pipelines.append(Pipeline('qr', 'row', 'f', 't', 'f', 'f', bias, 't', squant, pagedkv, mask)) - pipelines.append(Pipeline('qr', 'col', 'f', 't', 'f', 'f', bias, 't', squant, pagedkv, mask)) + pipelines.append(Pipeline('qr', 'row', 'f', 't', 'f', 'f', logits, bias, 't', squant, pagedkv, mask)) + pipelines.append(Pipeline('qr', 'col', 'f', 't', 'f', 'f', logits, bias, 't', squant, pagedkv, mask)) - pipelines.append(Pipeline('qr', 'row', 't', 'f', 'f', 'f', bias, 't', squant, pagedkv, mask)) - pipelines.append(Pipeline('qr', 'col', 't', 'f', 'f', 'f', bias, 't', squant, pagedkv, mask)) + pipelines.append(Pipeline('qr', 'row', 't', 'f', 'f', 'f', logits, bias, 't', squant, pagedkv, mask)) + pipelines.append(Pipeline('qr', 'col', 't', 'f', 'f', 'f', logits, bias, 't', squant, pagedkv, mask)) - pipelines.append(Pipeline('qr', 'row', 't', 't', 'f', 'f', bias, 't', squant, pagedkv, mask)) - pipelines.append(Pipeline('qr', 'col', 't', 't', 'f', 'f', bias, 't', squant, pagedkv, mask)) + 
pipelines.append(Pipeline('qr', 'row', 't', 't', 'f', 'f', logits, bias, 't', squant, pagedkv, mask)) + pipelines.append(Pipeline('qr', 'col', 't', 't', 'f', 'f', logits, bias, 't', squant, pagedkv, mask)) - pipelines.append(Pipeline('qr', 'row', 't', 't', 't', 't', bias, 't', squant, pagedkv, mask)) - pipelines.append(Pipeline('qr', 'col', 't', 't', 't', 't', bias, 't', squant, pagedkv, mask)) + pipelines.append(Pipeline('qr', 'row', 't', 't', 't', 't', logits, bias, 't', squant, pagedkv, mask)) + pipelines.append(Pipeline('qr', 'col', 't', 't', 't', 't', logits, bias, 't', squant, pagedkv, mask)) else: - pipelines.append(Pipeline('qr_async', 'row', 't', 'f', 't', 't', bias, 't', squant, pagedkv, mask)) - pipelines.append(Pipeline('qr_async', 'row', 't', 't', 't', 't', bias, 't', squant, pagedkv, mask)) - pipelines.append(Pipeline('qr_async', 'col', 't', 'f', 't', 't', bias, 't', squant, pagedkv, mask)) - pipelines.append(Pipeline('qr_async', 'col', 't', 't', 't', 't', bias, 't', squant, pagedkv, mask)) + pipelines.append(Pipeline('qr_async', 'row', 't', 'f', 't', 't', logits, bias, 't', squant, pagedkv, mask)) + pipelines.append(Pipeline('qr_async', 'row', 't', 't', 't', 't', logits, bias, 't', squant, pagedkv, mask)) + pipelines.append(Pipeline('qr_async', 'col', 't', 'f', 't', 't', logits, bias, 't', squant, pagedkv, mask)) + pipelines.append(Pipeline('qr_async', 'col', 't', 't', 't', 't', logits, bias, 't', squant, pagedkv, mask)) if receipt == 1: - pipelines.append(Pipeline('qr', 'row', 't', 't', 't', 't', bias, 't', squant, pagedkv, mask)) # TODO: cover arbitraty hdim - pipelines.append(Pipeline('qr', 'col', 't', 'f', 't', 't', bias, 't', squant, pagedkv, mask)) # TODO: cover arbitraty hdim + pipelines.append(Pipeline('qr', 'row', 't', 't', 't', 't', logits, bias, 't', squant, pagedkv, mask)) # TODO: cover arbitraty hdim + pipelines.append(Pipeline('qr', 'col', 't', 'f', 't', 't', logits, bias, 't', squant, pagedkv, mask)) # TODO: cover arbitraty hdim elif 
dtype in ['fp8', 'bf8']: - for mask, bias in itertools.product(get_mask_map(mask_impl).keys(), BIAS_MAP.keys()): - pipelines.append(Pipeline('qr', 'col', 'f', 'f', 'f', 'f', bias, 't', squant, 'f', mask)) + for logits, mask, bias in itertools.product(["t", "f"], get_mask_map(mask_impl).keys(), BIAS_MAP.keys()): + pipelines.append(Pipeline('qr', 'col', 'f', 'f', 'f', 'f', logits, bias, 't', squant, 'f', mask)) elif dtype in ['fp8fp16', 'fp8bf16']: # TODO None @@ -720,6 +730,9 @@ def get_fwd_splitkv_blobs(kernel_filter : Optional[str], receipt, mask_impl) -> if pipeline.F_spad != 't' or pipeline.F_skpad != 't': # in group mode, spad/skpad must be true, since we can't predict if seqlen of current batch need pad or not continue + # logits_soft_cap is only allowed if no bias + if not ((pipeline.F_logits == 't' and pipeline.F_bias == 'no') or pipeline.F_logits == 'f'): + continue k = Kernel(F_idx=0, F_hdim=hdim, F_dtype=dtype, diff --git a/example/ck_tile/01_fmha/fmha_fwd.cpp b/example/ck_tile/01_fmha/fmha_fwd.cpp index 8f6fb8df54..bb1f495c4e 100644 --- a/example/ck_tile/01_fmha/fmha_fwd.cpp +++ b/example/ck_tile/01_fmha/fmha_fwd.cpp @@ -11,6 +11,7 @@ #include #include #include +#include #include #include #include @@ -72,6 +73,7 @@ auto create_args(int argc, char* argv[]) "0", "scale factor of S. 0 means equal to 1/sqrt(hdim).\n" "note when squant=1, this value will be modified by range_q/k") + .insert("logits_soft_cap", "0", "attention logits soft capping value.") .insert("range_q", "16", "per-tensor quantization range of q. used if squant=1.") .insert("range_k", "16", "per-tensor quantization range of k. used if squant=1.") .insert("range_v", "16", "per-tensor quantization range of v. used if squant=1.") @@ -416,6 +418,8 @@ bool run(const ck_tile::ArgParser& arg_parser) if(scale_s == .0f) scale_s = 1.0 / ck_tile::sqrt(static_cast(hdim_q)); // TODO: q ? v ? 
+ const float logits_soft_cap = arg_parser.get_float("logits_soft_cap"); + std::string squant_str = arg_parser.get_str("squant"); bool squant = [&]() { if(squant_str == "auto") @@ -850,6 +854,7 @@ bool run(const ck_tile::ArgParser& arg_parser) else // fmha_fwd_traits or fmha_splitkv_traits { traits.is_group_mode = (mode == mode_enum::group); + traits.has_logits_soft_cap = 0.f < logits_soft_cap; traits.mask_type = mask.type; traits.bias_type = bias.type; traits.has_lse = lse; @@ -1007,6 +1012,8 @@ bool run(const ck_tile::ArgParser& arg_parser) args.scale_p = scale_p; args.scale_o = scale_o; + args.logits_soft_cap = logits_soft_cap; + args.stride_bias = (bias.type == bias_enum::alibi ? (bias.rank_info == 0 ? 0 : nhead) : stride_bias); args.stride_o = stride_o; @@ -1375,6 +1382,16 @@ bool run(const ck_tile::ArgParser& arg_parser) ck_tile::identity{}, ck_tile::scales(scale_s)); + if(0.f < logits_soft_cap) + { + ck_tile::reference_unary_elementwise( + s_host_ref, s_host_ref, [logits_soft_cap](SaccDataType logits) { + return ck_tile::type_convert( + logits_soft_cap * + std::tanhf(ck_tile::type_convert(logits / logits_soft_cap))); + }); + } + if(bias.type == bias_enum::elementwise_bias) { // elementwise bias diff --git a/example/ck_tile/01_fmha/fmha_fwd.hpp b/example/ck_tile/01_fmha/fmha_fwd.hpp index 765c221a7b..1838ee5bd9 100644 --- a/example/ck_tile/01_fmha/fmha_fwd.hpp +++ b/example/ck_tile/01_fmha/fmha_fwd.hpp @@ -143,6 +143,8 @@ struct fmha_fwd_args float scale_p; float scale_o; + float logits_soft_cap; + ck_tile::index_t stride_q; ck_tile::index_t stride_k; ck_tile::index_t stride_v; @@ -232,6 +234,8 @@ struct fmha_fwd_splitkv_args float scale_p; float scale_o; + float logits_soft_cap; + ck_tile::index_t stride_q; ck_tile::index_t stride_k; ck_tile::index_t stride_v; @@ -308,6 +312,85 @@ struct fmha_fwd_appendkv_args ck_tile::index_t batch_stride_vnew; }; +struct fmha_batch_prefill_args +{ + const void* q_ptr; + const void* k_ptr; + const void* v_ptr; + const void* 
bias_ptr; // bias or alibi_slope pointer + void* rand_val_ptr; + void* lse_ptr; + void* o_ptr; + + // the real seqlen_q & seqlen_k are decided by following: + // batch mode (kvcache): + // seqlen_q = kargs.seqlen_q + // seqlen_k = kargs.page_block_size * (kargs.kv_indptr[b + 1] - kargs.kv_indptr[b] - + // 1) + + // kargs.kv_last_page_lens[b] + // group mode (kvcache): + // seqlen_q = kargs.seqstart_q_ptr[b + 1] - kargs.seqstart_q_ptr[b] + // seqlen_k = kargs.page_block_size * (kargs.kv_indptr[b + 1] - kargs.kv_indptr[b] - + // 1) + + // kargs.kv_last_page_lens[b] + const void* seqstart_q_ptr; + + ck_tile::index_t seqlen_q; + ck_tile::index_t seqlen_k; + ck_tile::index_t batch; + ck_tile::index_t max_seqlen_q; + ck_tile::index_t hdim_q; + ck_tile::index_t hdim_v; + ck_tile::index_t nhead_q; + ck_tile::index_t nhead_k; + + // SGLang-style page table + int32_t num_total_pages; + void* kv_indptr; + void* kv_page_indices; +#if 0 // we assume page_block_size=1 for now + void* kv_last_page_lens; + ck_tile::index_t page_block_size; +#endif + + float scale_s; + float scale_p; + float scale_o; + + float logits_soft_cap; + + ck_tile::index_t stride_q; + ck_tile::index_t stride_k; + ck_tile::index_t stride_v; + ck_tile::index_t stride_bias; // if alibi, b*h need set this to h, 1*h need set this to 0 + ck_tile::index_t stride_randval; + ck_tile::index_t stride_o; + ck_tile::index_t nhead_stride_q; + ck_tile::index_t nhead_stride_k; + ck_tile::index_t nhead_stride_v; + ck_tile::index_t nhead_stride_bias; + ck_tile::index_t nhead_stride_randval; + ck_tile::index_t nhead_stride_lse; + ck_tile::index_t nhead_stride_o; + ck_tile::index_t batch_stride_q; + ck_tile::index_t batch_stride_k; + ck_tile::index_t batch_stride_v; + ck_tile::index_t batch_stride_bias; + ck_tile::index_t batch_stride_randval; + ck_tile::index_t batch_stride_lse; + ck_tile::index_t batch_stride_o; + + ck_tile::index_t window_size_left; + ck_tile::index_t window_size_right; + ck_tile::index_t mask_type; + + 
float p_drop; + bool s_randval; + + std::variant, std::pair> + drop_seed_offset; +}; + template auto fmha_fwd_create_kargs_and_grids(fmha_fwd_args args) { @@ -333,6 +416,7 @@ auto fmha_fwd_create_kargs_and_grids(fmha_fwd_args args) args.scale_s, args.scale_p, args.scale_o, + args.logits_soft_cap, args.stride_q, args.stride_k, args.stride_v, @@ -371,6 +455,7 @@ auto fmha_fwd_create_kargs_and_grids(fmha_fwd_args args) args.scale_s, args.scale_p, args.scale_o, + args.logits_soft_cap, args.stride_q, args.stride_k, args.stride_v, @@ -443,6 +528,7 @@ auto fmha_fwd_splitkv_create_kargs_and_grids(fmha_fwd_splitkv_args args) args.is_gappy, args.scale_s, args.scale_p, + args.logits_soft_cap, args.stride_q, args.stride_k, args.stride_v, @@ -485,6 +571,7 @@ auto fmha_fwd_splitkv_create_kargs_and_grids(fmha_fwd_splitkv_args args) args.cache_batch_idx, args.scale_s, args.scale_p, + args.logits_soft_cap, args.stride_q, args.stride_k, args.stride_v, @@ -618,6 +705,117 @@ auto fmha_fwd_appendkv_create_kargs_and_grids(fmha_fwd_appendkv_args args) return ck_tile::make_tuple(kargs, grids); } +template +auto fmha_batch_prefill_create_kargs_and_grids(fmha_batch_prefill_args args) +{ + assert(args.nhead_q % args.nhead_k == 0); + auto kargs = [&] { + // create group mode kernel arguments + if constexpr(FmhaKernel::kIsGroupMode) + { + return FmhaKernel::MakeKargsImpl(args.q_ptr, + args.k_ptr, + args.v_ptr, + args.bias_ptr, + args.rand_val_ptr, + args.lse_ptr, + args.o_ptr, + args.seqstart_q_ptr, + args.hdim_q, + args.hdim_v, + args.nhead_q, + args.nhead_q / args.nhead_k, + args.num_total_pages, + args.kv_indptr, + args.kv_page_indices, +#if 0 // we assume page_block_size=1 for now + args.kv_last_page_lens, + args.page_block_size, +#endif + args.scale_s, + args.scale_p, + args.scale_o, + args.logits_soft_cap, + args.stride_q, + args.stride_k, + args.stride_v, + args.stride_bias, + args.stride_randval, + args.stride_o, + args.nhead_stride_q, + args.nhead_stride_k, + args.nhead_stride_v, + 
args.nhead_stride_bias, + args.nhead_stride_randval, + args.nhead_stride_lse, + args.nhead_stride_o, + args.batch_stride_k, + args.batch_stride_v, + args.window_size_left, + args.window_size_right, + args.mask_type, + args.p_drop, + args.s_randval, + args.drop_seed_offset); + } + else + { // create batch mode kernel arguments + return FmhaKernel::MakeKargsImpl(args.q_ptr, + args.k_ptr, + args.v_ptr, + args.bias_ptr, + args.rand_val_ptr, + args.lse_ptr, + args.o_ptr, + args.seqlen_q, + args.hdim_q, + args.hdim_v, + args.nhead_q, + args.nhead_q / args.nhead_k, + args.num_total_pages, + args.kv_indptr, + args.kv_page_indices, +#if 0 // we assume page_block_size=1 for now + args.kv_last_page_lens, + args.page_block_size, +#endif + args.scale_s, + args.scale_p, + args.scale_o, + args.logits_soft_cap, + args.stride_q, + args.stride_k, + args.stride_v, + args.stride_bias, + args.stride_randval, + args.stride_o, + args.nhead_stride_q, + args.nhead_stride_k, + args.nhead_stride_v, + args.nhead_stride_bias, + args.nhead_stride_randval, + args.nhead_stride_lse, + args.nhead_stride_o, + args.batch_stride_q, + args.batch_stride_k, + args.batch_stride_v, + args.batch_stride_bias, + args.batch_stride_randval, + args.batch_stride_lse, + args.batch_stride_o, + args.window_size_left, + args.window_size_right, + args.mask_type, + args.p_drop, + args.s_randval, + args.drop_seed_offset); + } + }(); + + dim3 grids = FmhaKernel::GridSize(args.batch, args.nhead_q, args.max_seqlen_q, args.hdim_v); + return ck_tile::make_tuple(kargs, grids); +} + // this is used to pattern-match internl kernel implementation, not to instantiate kernel template ; static constexpr auto BiasEnum = BiasEnum_; static constexpr bool kStoreLse = kStoreLse_; @@ -677,6 +877,7 @@ template ; static constexpr auto BiasEnum = BiasEnum_; static constexpr bool kStoreLse = kStoreLse_; @@ -776,6 +978,9 @@ struct fmha_fwd_appendkv_traits_ template float fmha_fwd_appendkv_(const ck_tile::stream_config&, 
fmha_fwd_appendkv_args); +template +float fmha_batch_prefill_(const ck_tile::stream_config&, fmha_batch_prefill_args); + // This is the public API, will be generated by script struct fmha_fwd_traits { @@ -784,6 +989,7 @@ struct fmha_fwd_traits std::string data_type; bool is_group_mode; bool is_v_rowmajor; + bool has_logits_soft_cap; mask_enum mask_type; bias_enum bias_type; // 0:no bias, 1:elementwise bias, 2:alibi. sync with BlockAttentionBiasEnum bool has_lse; @@ -800,6 +1006,7 @@ struct fmha_fwd_splitkv_traits std::string data_type; bool is_group_mode; bool is_v_rowmajor; + bool has_logits_soft_cap; mask_enum mask_type; bias_enum bias_type; // 0:no bias, 1:elementwise bias, 2:alibi. sync with BlockAttentionBiasEnum bool has_lse; @@ -821,3 +1028,8 @@ struct fmha_fwd_appendkv_traits float fmha_fwd_appendkv(fmha_fwd_appendkv_traits, fmha_fwd_appendkv_args, const ck_tile::stream_config&); + +using fmha_batch_prefill_traits = fmha_fwd_traits; +float fmha_batch_prefill(fmha_batch_prefill_traits, + fmha_batch_prefill_args, + const ck_tile::stream_config&); diff --git a/example/ck_tile/01_fmha/generate.py b/example/ck_tile/01_fmha/generate.py index c2b0924eb3..c611618824 100644 --- a/example/ck_tile/01_fmha/generate.py +++ b/example/ck_tile/01_fmha/generate.py @@ -21,8 +21,7 @@ class HandlerId(IntEnum): ops = [] for importer, module_name, _ in pkgutil.iter_modules(codegen.ops.__path__): full_module_name = '%s.%s' % (codegen.ops.__name__, module_name) - if full_module_name not in sys.modules: - ops.append(importer.find_spec(module_name).loader.load_module(module_name)) + ops.append(importer.find_spec(module_name).loader.load_module(module_name)) unwanted_prefix = 'fmha_' handlers = dict( [(op.__name__[len(unwanted_prefix):] if op.__name__.startswith(unwanted_prefix) else op.__name__, diff --git a/include/ck_tile/core.hpp b/include/ck_tile/core.hpp index b94157eaec..b9791f0b55 100644 --- a/include/ck_tile/core.hpp +++ b/include/ck_tile/core.hpp @@ -54,6 +54,7 @@ #include 
"ck_tile/core/tensor/tile_distribution.hpp" #include "ck_tile/core/tensor/tile_distribution_encoding.hpp" #include "ck_tile/core/tensor/tile_elementwise.hpp" +#include "ck_tile/core/tensor/tile_scatter_gather.hpp" #include "ck_tile/core/tensor/tile_window.hpp" #include "ck_tile/core/tensor/tile_window_linear.hpp" #include "ck_tile/core/tensor/tile_window_utils.hpp" diff --git a/include/ck_tile/core/numeric/math.hpp b/include/ck_tile/core/numeric/math.hpp index 6bdcb509b0..8176fe551c 100644 --- a/include/ck_tile/core/numeric/math.hpp +++ b/include/ck_tile/core/numeric/math.hpp @@ -487,6 +487,9 @@ struct log2e template constexpr T log2e_v = log2e::value; +template +constexpr T log2e_rcp_v = 1. / log2e::value; + CK_TILE_DEVICE float exp2(float x) { return exp2f(x); }; @@ -1380,6 +1383,44 @@ CK_TILE_DEVICE double exp(double x) return exp(x); }; +template +CK_TILE_DEVICE T tanh_fast(T x) +{ + return type_convert((exp(2.0 * type_convert(x)) - 1.0) / + (exp(2.0 * type_convert(x)) + 1.0)); +}; + +template <> +CK_TILE_DEVICE float tanh_fast(float x) +{ + // float a = __builtin_amdgcn_sinh(x); + // float b = __builtin_amdgcn_cosh(x); + // float e = a * __builtin_amdgcn_rcpf(b); + // return e; + + float a = 2.0f * log2e_v * x; + a = __builtin_amdgcn_exp2f(a); + a = __builtin_amdgcn_rcpf(a + 1.0f); + a = 2 * a; + a = 1 - a; + return a; + + // float e, r, s, t, d; + // float a = x; + // s = abs(a); + // t = -log2e_v * 2.0f * s; + // e = __builtin_amdgcn_exp2f(t); + // d = e + 1.0f; + // r = __builtin_amdgcn_rcpf(d); + // r = e * (-r) + r; + // if (s < 4.997253418e-3f) r = a; + // union fipnr {float f; unsigned int i;}; + // fipnr r_; r_.f = r; + // fipnr a_; a_.f = a; + // { r_.i = (r_.i|(a_.i&0x80000000)); r = r_.f; } + // return r; +}; + template CK_TILE_DEVICE T log(T x) { diff --git a/include/ck_tile/core/tensor/load_tile.hpp b/include/ck_tile/core/tensor/load_tile.hpp index b280a1725d..4601261197 100644 --- a/include/ck_tile/core/tensor/load_tile.hpp +++ 
b/include/ck_tile/core/tensor/load_tile.hpp @@ -18,32 +18,8 @@ namespace ck_tile { -template -CK_TILE_DEVICE auto load_tile(const tile_window_with_static_distribution& tile_window, - number = {}, - bool_constant = {}) -{ - return tile_window.load(number{}, bool_constant{}); -} - -template -CK_TILE_DEVICE auto load_tile(const tile_window_linear& tile_window, +template +CK_TILE_DEVICE auto load_tile(const TileWindow_& tile_window, number = {}, bool_constant = {}) { @@ -51,35 +27,11 @@ CK_TILE_DEVICE auto load_tile(const tile_window_linear CK_TILE_DEVICE auto load_tile(DistributedTensor_& dst_tile, - const tile_window_with_static_distribution& tile_window, - number = {}, - bool_constant = {}) -{ - return tile_window.load(dst_tile, number{}, bool_constant{}); -} - -template -CK_TILE_DEVICE auto load_tile(DistributedTensor_& dst_tile, - const tile_window_linear& tile_window, + const TileWindow_& tile_window, number = {}, bool_constant = {}) { @@ -138,42 +90,12 @@ CK_TILE_DEVICE auto load_tile_raw(T& tile, } template -CK_TILE_DEVICE auto -async_load_tile_raw(LdsTileWindow_&& lds_tile, - const tile_window_with_static_distribution& tile_window, - number = {}, - bool_constant = {}, - bool_constant = {}) -{ - return tile_window.async_load_raw(lds_tile, - number{}, - bool_constant{}, - bool_constant{}); -} - -template CK_TILE_DEVICE auto async_load_tile_raw(LdsTileWindow_&& lds_tile, - const tile_window_linear& tile_window, + const TileWindow_& tile_window, number = {}, bool_constant = {}, bool_constant = {}) diff --git a/include/ck_tile/core/tensor/tensor_view.hpp b/include/ck_tile/core/tensor/tensor_view.hpp index 29db5e1fca..656ce8d20d 100644 --- a/include/ck_tile/core/tensor/tensor_view.hpp +++ b/include/ck_tile/core/tensor/tensor_view.hpp @@ -210,6 +210,27 @@ struct tensor_view bool_constant{}); } + template >::scalar_type, + typename vector_traits>::scalar_type>, + bool>::type = false> + CK_TILE_HOST_DEVICE constexpr void + 
async_get_vectorized_elements_raw(remove_cvref_t* smem, + const TensorCoord& coord, + index_t coord_extra_offset, + index_t linear_offset, + bool_constant = {}) const + { + return buf_.template async_get_raw( + smem, + (coord.get_offset() + coord_extra_offset) / PackedSize, + linear_offset / PackedSize, + coordinate_has_valid_offset_assuming_top_index_is_valid(desc_, coord), + bool_constant{}); + } + template +struct tile_scatter_gather +{ + using BottomTensorView = remove_reference_t; + using WindowLengths = remove_cvref_t; + using TileDstr = remove_cvref_t; + using PageIdxArray = remove_cvref_t; + using WindowAdaptor = typename TileDstr::PsYs2XsAdaptor; + using BottomTensorDesc = typename BottomTensorView::TensorDesc; + + using DataType = remove_cvref_t; + + static constexpr index_t NDimWindowAdaptorTop = WindowAdaptor::get_num_of_top_dimension(); + static constexpr index_t NDimBottomTensor = BottomTensorDesc::get_num_of_dimension(); + + static constexpr index_t NDimP = TileDstr::get_num_of_dimension_p(); + static constexpr index_t NDimY = TileDstr::get_num_of_dimension_y(); + + static constexpr auto I0 = number<0>{}; + static constexpr auto I1 = number<1>{}; + static_assert(NumCoord == 1); + + // TODO: check WindowLengths and StaticTileDistribution are consistent + + static_assert(ck_tile::is_known_at_compile_time::value, + "wrong! lengths should be static"); + static_assert(TileDstr::is_static(), "wrong!"); + + static_assert(NDimBottomTensor == WindowAdaptor::get_num_of_bottom_dimension(), + "wrong! 
inconsistent # of dimensions"); + + using AdaptorTopIndex = array; + using BottomTensorIndex = array; + + using WindowAdaptorCoord = + decltype(make_tensor_adaptor_coordinate(WindowAdaptor{}, AdaptorTopIndex{})); + + using BottomTensorCoord = + decltype(make_tensor_coordinate(BottomTensorDesc{}, BottomTensorIndex{})); + + struct load_store_traits + { + private: + static constexpr auto get_vector_dim_y_scalar_per_vector() + { + const auto [ys_vector_lengths, ys_vector_strides] = + tile_scatter_gather::get_window_adaptor_ys_safe_vector_length_strides(); + + index_t VectorDimY_ = 0; + index_t ScalarPerVector_ = 1; + + for(index_t i = 0; i < NDimY; ++i) + { + if(ys_vector_strides[i] == 1 && ys_vector_lengths[i] > ScalarPerVector_) + { + ScalarPerVector_ = ys_vector_lengths[i]; + VectorDimY_ = i; + } + } + + return make_tuple(VectorDimY_, ScalarPerVector_); + } + + public: + static constexpr index_t PackedSize = + ck_tile::numeric_traits>::PackedSize; + static constexpr index_t VectorDimY = get_vector_dim_y_scalar_per_vector().template at<0>(); + static constexpr index_t ScalarPerVector = + get_vector_dim_y_scalar_per_vector().template at<1>(); + + // using vector_type_t = vector_type_maker_t; + // using vector_t = typename vector_type_t::type; + using vector_t = thread_buffer; + + private: + static constexpr auto scalars_per_access_ = [] { + constexpr auto scalars_per_access_arr = generate_array( + [&](auto i) { return (i == VectorDimY) ? 
ScalarPerVector : 1; }, number{}); + + /// TODO: add non-automatic storage argument support to macro TO_SEQUENCE() + constexpr auto NDimY_ = NDimY; + + return TO_SEQUENCE(scalars_per_access_arr, NDimY_); + }(); + + static constexpr auto get_space_filling_curve() + { + constexpr auto tile_dstr = TileDstr{}; + + constexpr auto thread_tensor_lengths_ys = + to_sequence(tile_dstr.get_ys_to_d_descriptor().get_lengths()); + + // FIXME: need logic to judge dim access order + using DimAccessOrder = typename arithmetic_sequence_gen<0, NDimY, 1>::type; + + return space_filling_curve{}; + } + + public: + using SFC_Ys = decltype(get_space_filling_curve()); + + static constexpr index_t NumAccess = SFC_Ys::get_num_of_access(); + + static_assert(0 < NumAccess, "Wrong! NumAccess should be larger than 0"); + static_assert(NumAccess % NumCoord == 0, "wrong! # of access is not divisible by NumCoord"); + }; + + static constexpr index_t NumAccessPerCoord = load_store_traits::NumAccess / NumCoord; + + CK_TILE_DEVICE constexpr tile_scatter_gather() = default; + + CK_TILE_DEVICE constexpr tile_scatter_gather(const BottomTensorView& bottom_tensor_view, + const WindowLengths& window_lengths, + const BottomTensorIndex& window_origin, + const TileDstr& tile_distribution, + const PageIdxArray& page_idx) + : bottom_tensor_view_{bottom_tensor_view}, + window_lengths_{window_lengths}, + window_origin_{window_origin}, + tile_dstr_{tile_distribution}, + page_idx_{page_idx}, + pre_computed_coords_{} + { +#if 0 // debug + // TODO: this use more register for FA, but less register for GEMM + // need investigation + // only support warp-tile and block-tile + static_assert(NDimP == 1 or NDimP == 2, "wrong!"); + + WindowAdaptorCoord window_adaptor_thread_coord_tmp; + + if constexpr(NDimP == 1) + { + window_adaptor_thread_coord_tmp = make_tensor_adaptor_coordinate( + tile_distribution.get_ps_ys_to_xs_adaptor(), AdaptorTopIndex{get_lane_id(), 0}); + } + else if constexpr(NDimP == 2) + { + 
window_adaptor_thread_coord_tmp = + make_tensor_adaptor_coordinate(tile_distribution.get_ps_ys_to_xs_adaptor(), + AdaptorTopIndex{get_warp_id(), get_lane_id(), 0}); + } +#else + // TODO: this use less register for FA, but more register for GEMM + // need investigation + const auto window_adaptor_thread_coord_tmp = make_tensor_adaptor_coordinate( + tile_distribution.get_ps_ys_to_xs_adaptor(), + container_concat(detail::get_partition_index(tile_distribution), + array{0})); +#endif + + BottomTensorIndex bottom_tensor_thread_origin_idx_tmp = + window_origin + window_adaptor_thread_coord_tmp.get_bottom_index(); + bottom_tensor_thread_origin_idx_tmp(HsGatherDim) = 0; + const auto bottom_tensor_thread_coord_tmp = make_tensor_coordinate( + bottom_tensor_view_.get_tensor_descriptor(), bottom_tensor_thread_origin_idx_tmp); + + // pre-compute NumCoord (WindowAdaptorCoord, BottomTensorCoord) bundles to speed up + // future load/store() calls (might allocate more registers) + using Traits = load_store_traits; + using SFC_Ys = typename Traits::SFC_Ys; + + static_for<0, NumCoord, 1>{}([&](auto iCoord) { + auto window_adaptor_thread_coord = window_adaptor_thread_coord_tmp; + auto bottom_tensor_thread_coord = bottom_tensor_thread_coord_tmp; + + constexpr auto idx_diff_ys = + SFC_Ys::get_step_between(number<0>{}, number{}); + + constexpr auto idx_diff_ps_ys = container_concat( + generate_tuple([&](auto) { return number<0>{}; }, number{}), idx_diff_ys); + + move_window_adaptor_and_bottom_tensor_thread_coordinate( + window_adaptor_thread_coord, bottom_tensor_thread_coord, idx_diff_ps_ys); + + pre_computed_coords_(iCoord) = + make_tuple(window_adaptor_thread_coord, bottom_tensor_thread_coord); + }); + } + + CK_TILE_DEVICE static constexpr index_t get_num_of_dimension() { return NDimBottomTensor; } + + CK_TILE_DEVICE static constexpr bool has_static_tile_distribution() + { + return TileDstr::is_static(); + } + + CK_TILE_DEVICE constexpr auto get_window_lengths() const { return 
window_lengths_; } + + CK_TILE_DEVICE constexpr auto get_tile_distribution() const { return tile_dstr_; } + + CK_TILE_DEVICE constexpr auto get_bottom_tensor_view() const { return bottom_tensor_view_; } + + CK_TILE_DEVICE constexpr auto get_window_origin() const { return window_origin_; } + + CK_TILE_DEVICE constexpr void + set_bottom_tensor_view_data_ptr(typename BottomTensorView::DataType* data) + { + bottom_tensor_view_.buf_.p_data_ = data; + } + + // move thread's window adaptor coordinate and bottom tensor coordinate + // [p0, p1, ..., y0, y1, ...] ==> [x0, x1, ...] ==> [x0', x1', ...] ==> [offset] + template + CK_TILE_DEVICE void move_window_adaptor_and_bottom_tensor_thread_coordinate( + WindowAdaptorCoord& window_adaptor_thread_coord, + BottomTensorCoord& bottom_tensor_thread_coord, + const ATopIndex& idx_diff_adaptor_top) const + { + array idx_diff_adaptor_bottom; + + move_tensor_adaptor_coordinate(tile_dstr_.get_ps_ys_to_xs_adaptor(), + window_adaptor_thread_coord, + idx_diff_adaptor_top, + idx_diff_adaptor_bottom); + + move_tensor_coordinate(bottom_tensor_view_.get_tensor_descriptor(), + bottom_tensor_thread_coord, + idx_diff_adaptor_bottom); + } + + // return vector dimension among [y0, y1, ...] + CK_TILE_DEVICE static constexpr auto get_window_adaptor_ys_safe_vector_length_strides() + { + // bottom tensor top dimension vector lengths and strides + const auto [bottom_tensor_top_dim_vector_lengths, bottom_tensor_top_dim_vector_strides] = + BottomTensorDesc::get_top_dimension_safe_vector_length_strides(); + + // window vector lengths/strides + const auto window_adaptor_bottom_dim_vector_lengths = bottom_tensor_top_dim_vector_lengths; + const auto window_adaptor_bottom_dim_vector_strides = bottom_tensor_top_dim_vector_strides; + + // window adaptor [p0, p1, ..., y0, y1, ...] 
+ array window_adaptor_vector_lengths{ + -1}; + array window_adaptor_vector_strides{ + -1}; + + constexpr auto window_adaptor_bottom_dims = + WindowAdaptor::get_bottom_dimension_hidden_ids(); + + set_container_subset(window_adaptor_vector_lengths, + window_adaptor_bottom_dims, + window_adaptor_bottom_dim_vector_lengths); + set_container_subset(window_adaptor_vector_strides, + window_adaptor_bottom_dims, + window_adaptor_bottom_dim_vector_strides); + + const auto [window_adaptor_ps_ys_vector_lengths, window_adaptor_ps_ys_vector_strides] = + WindowAdaptor{}.get_top_dimension_safe_vector_length_strides( + window_adaptor_vector_lengths, window_adaptor_vector_strides); + + // [y0, y1, ...] + constexpr auto y_dims = typename arithmetic_sequence_gen::type{}; + + return make_tuple(get_container_subset(window_adaptor_ps_ys_vector_lengths, y_dims), + get_container_subset(window_adaptor_ps_ys_vector_strides, y_dims)); + } + + CK_TILE_DEVICE constexpr auto get_num_of_access() const { return load_store_traits::NumAccess; } + + template + CK_TILE_DEVICE auto load(number = {}, + bool_constant = {}) const + { + constexpr auto tile_dstr = TileDstr{}; + auto dst_tensor = make_static_distributed_tensor(tile_dstr); + load(dst_tensor, number{}, bool_constant{}); + return dst_tensor; + } + + template + CK_TILE_DEVICE auto load(DistributedTensor& dst_tensor, + number = {}, + bool_constant = {}) const + { + using Traits = load_store_traits; + using vector_t = typename Traits::vector_t; + using SFC_Ys = typename Traits::SFC_Ys; + + constexpr auto tile_dstr = TileDstr{}; + + // loop over thread tensor space [y0, y1, ...] 
+ static_for<0, NumCoord, 1>{}([&](auto iCoord) { + /// TODO: use structure binding (to be captured later) if compiled in C++20 + auto window_adaptor_thread_coord = pre_computed_coords_[iCoord][I0]; + auto bottom_tensor_thread_coord = pre_computed_coords_[iCoord][I1]; + + static_for<0, NumAccessPerCoord, 1>{}([&](auto iCoordAccess) { + constexpr auto iAccess = number{}; + + // data index [y0, y1, ...] + constexpr auto idx_ys_start = SFC_Ys::get_index(iAccess); + constexpr auto idx_gather = idx_ys_start[number{}]; + const auto page_offset = page_idx_[idx_gather]; + // read from bottom tensor + const vector_t vec_value = + get_bottom_tensor_view().template get_vectorized_elements( + bottom_tensor_thread_coord, + page_offset, + bool_constant{}); +#if 1 + // write into distributed tensor + static_for<0, Traits::ScalarPerVector, Traits::PackedSize>{}([&](auto j) { + constexpr auto idx_ys = generate_tuple( + [&](auto jj) { + return jj == Traits::VectorDimY ? (idx_ys_start[jj] + j) + : idx_ys_start[jj]; + }, + number{}); + + constexpr index_t d = + tile_dstr.get_ys_to_d_descriptor().calculate_offset(idx_ys) / + Traits::PackedSize; + + dst_tensor.get_thread_buffer().template at() = + vec_value.template get_as()[j / Traits::PackedSize]; + }); +#else + constexpr index_t d = + tile_dstr.get_ys_to_d_descriptor().calculate_offset(idx_ys_start); + static_assert(d % Traits::ScalarPerVector == 0); + + dst_tensor.get_thread_buffer().template get_as()( + number{}) = bit_cast(vec_value); +#endif + // move thread coordinate + if constexpr(iCoordAccess != (NumAccessPerCoord - 1)) + { + constexpr auto idx_diff_ys = SFC_Ys::get_forward_step(iAccess); + + constexpr auto forward_step_scatter = generate_tuple( + [&](auto i) { return i == YsGatherDim ? 
0 : idx_diff_ys[i]; }, + number{}); + + constexpr auto idx_diff_ps_ys = container_concat( + generate_tuple([&](auto) { return number<0>{}; }, number{}), + forward_step_scatter); + + move_window_adaptor_and_bottom_tensor_thread_coordinate( + window_adaptor_thread_coord, bottom_tensor_thread_coord, idx_diff_ps_ys); + } + }); + }); + } + + // TODO: currently async load only implemented in inline asm + template + CK_TILE_DEVICE auto async_load_raw(LdsTileWindow_&& lds_tile, + number = {}, + bool_constant = {}, + bool_constant = {}) const + { + using LdsTileWindow = remove_cvref_t; + // using LdsTensorView = typename LdsTileWindow::BottomTensorView; + using LdsDataType = typename LdsTileWindow::DataType; + // using LdsDescriptor = typename LdsTileWindow::BottomTensorDesc; + + // issues * warps * lanes + static_assert(LdsTileWindow::get_num_of_dimension() == 3); // TODO: hard coded + + const index_t size_per_buf = + lds_tile.get_bottom_tensor_view().get_tensor_descriptor().calculate_offset( + make_tuple(number<0>{}, number<0>{}, number<0>{})) * + sizeof(LdsDataType); + + const index_t size_per_wave = + lds_tile.get_bottom_tensor_view().get_tensor_descriptor().calculate_offset( + make_tuple(number<0>{}, number<1>{}, number<0>{})) * + sizeof(LdsDataType) - + size_per_buf; + + const index_t size_per_issue = + lds_tile.get_bottom_tensor_view().get_tensor_descriptor().calculate_offset( + make_tuple(number<1>{}, number<0>{}, number<0>{})) * + sizeof(LdsDataType) - + size_per_buf; + + const index_t m0_init_value = size_per_buf + size_per_wave * get_warp_id(); + m0_set_with_memory(m0_init_value); // This should be wave independent + + using Traits = load_store_traits; + + // using vector_type_t = typename Traits::vector_type_t; + using vector_t = typename Traits::vector_t; + using SFC_Ys = typename Traits::SFC_Ys; + + LdsDataType* smem = lds_tile.get_bottom_tensor_view().get_buffer_view().p_data_; + + // loop over thread tensor space [y0, y1, ...] 
+ static_for<0, NumCoord, 1>{}([&](auto iCoord) { + /// TODO: use structure binding (to be captured later) if compiled in C++20 + auto window_adaptor_thread_coord = pre_computed_coords_[iCoord][I0]; + auto bottom_tensor_thread_coord = pre_computed_coords_[iCoord][I1]; + + static_for<0, NumAccessPerCoord, 1>{}([&](auto iCoordAccess) { + constexpr auto iAccess = number{}; + constexpr auto pre_nop_ = [&]() { + if constexpr(pre_nop && iCoord == 0 && iCoordAccess == 0) + return bool_constant{}; + else + return bool_constant{}; + }(); + + constexpr auto idx_ys_start = SFC_Ys::get_index(iAccess); + constexpr auto idx_gather = idx_ys_start[number{}]; + const auto page_offset = page_idx_[idx_gather]; + // read from bottom tensor + get_bottom_tensor_view().template async_get_vectorized_elements_raw( + smem, bottom_tensor_thread_coord, page_offset, 0, pre_nop_); + + // move thread coordinate + if constexpr(iCoordAccess != (NumAccessPerCoord - 1)) + { + constexpr auto idx_diff_ys = SFC_Ys::get_forward_step(iAccess); + + constexpr auto forward_step_scatter = generate_tuple( + [&](auto i) { return i == YsGatherDim ? 0 : idx_diff_ys[i]; }, + number{}); + + constexpr auto idx_diff_ps_ys = container_concat( + generate_tuple([&](auto) { return number<0>{}; }, number{}), + forward_step_scatter); + + move_window_adaptor_and_bottom_tensor_thread_coordinate( + window_adaptor_thread_coord, bottom_tensor_thread_coord, idx_diff_ps_ys); + + m0_inc_with_memory(size_per_issue); + } + }); + }); + } + + template + CK_TILE_DEVICE void store(const static_distributed_tensor& dstr_tensor, + number = {}, + bool_constant = {}) const + { + using Traits = load_store_traits; + + // using vector_type_t = typename Traits::vector_type_t; + using vector_t = typename Traits::vector_t; + using SFC_Ys = typename Traits::SFC_Ys; + + constexpr auto tile_dstr = TileDstr{}; + // printf("off %d\n", page_idx_[I0]); + // loop over thread tensor space [y0, y1, ...] 
+ static_for<0, NumCoord, 1>{}([&](auto iCoord) { + auto window_adaptor_thread_coord = pre_computed_coords_[iCoord][I0]; + auto bottom_tensor_thread_coord = pre_computed_coords_[iCoord][I1]; + + static_for<0, NumAccessPerCoord, 1>{}([&](auto iCoordAccess) { + constexpr auto iAccess = number{}; + + // data index [y0, y1, ...] + constexpr auto idx_ys_start = SFC_Ys::get_index(iAccess); + constexpr auto idx_gather = idx_ys_start[number<0>{}]; + const auto page_offset = page_idx_[idx_gather]; + + // printf("idx_ys_start[0], idx_ys_start[1](%d, %d) \n", + // idx_ys_start[number<0>{}]+0, idx_ys_start[number<1>{}]+0); + + // read from distributed tensor + // vector_type_t vec; + vector_t vec_value; + + static_for<0, Traits::ScalarPerVector, Traits::PackedSize>{}([&](auto j) { + constexpr auto idx_ys = generate_tuple( + [&](auto jj) { + return jj == Traits::VectorDimY ? (idx_ys_start[jj] + j) + : idx_ys_start[jj]; + }, + number{}); + + constexpr index_t d = + tile_dstr.get_ys_to_d_descriptor().calculate_offset(idx_ys) / + Traits::PackedSize; + // printf("thread_idx_m: %d j: %d\n", idx_ys[number<0>{}] + 0, 0+j); + vec_value.template get_as()(j / Traits::PackedSize) = + dstr_tensor.get_thread_buffer().template at(); + }); + + // const vector_t vec_value = vec.template get_as().template at<0>(); + + // write into bottom tensor + get_bottom_tensor_view().template set_vectorized_elements( + bottom_tensor_thread_coord, + page_offset, + vec_value, + bool_constant{}); + // printf("coord_offset:%d, scatter_offset:%d \n", + // bottom_tensor_thread_coord.get_offset(), offset); move thread coordinate + if constexpr(iCoordAccess != (NumAccessPerCoord - 1)) + { + constexpr auto idx_diff_ys = SFC_Ys::get_forward_step(iAccess); + + constexpr auto forward_step_scatter = generate_tuple( + [&](auto i) { return i == YsGatherDim ? 
0 : idx_diff_ys[i]; }, + number{}); + + constexpr auto idx_diff_ps_ys = container_concat( + generate_tuple([&](auto) { return number<0>{}; }, number{}), + forward_step_scatter); + + move_window_adaptor_and_bottom_tensor_thread_coordinate( + window_adaptor_thread_coord, bottom_tensor_thread_coord, idx_diff_ps_ys); + } + }); + }); + } + + // move thread's bottom tensor coordinate + // [x0', x1', ... ] ==> [offset] + // also move window-origin + CK_TILE_DEVICE void move(const BottomTensorIndex& step) + { + window_origin_ += step; + BottomTensorIndex step_new = step; + step_new(HsGatherDim) = 0; + static_for<0, NumCoord, 1>{}([&](auto iCoord) { + move_tensor_coordinate(bottom_tensor_view_.get_tensor_descriptor(), + pre_computed_coords_(iCoord)(I1), + step_new); + }); + } + + CK_TILE_DEVICE void update_page_idx(const PageIdxArray& new_idx) + { + page_idx_ = new_idx; + + // static_for<0, 2, 1>{}([&](auto k0) { + // printf("update tid %d %d \n", threadIdx.x, page_idx_[k0]); + // }); + } + CK_TILE_DEVICE void set_window_origin(const BottomTensorIndex& new_window_origin) + { + window_origin_ = new_window_origin; + +#if 0 // debug + // TODO: this use more register for FA, but less register for GEMM + // need investigation + // only support warp-tile and block-tile + static_assert(NDimP == 1 or NDimP == 2, "wrong!"); + + WindowAdaptorCoord window_adaptor_thread_coord_tmp; + + if constexpr(NDimP == 1) + { + window_adaptor_thread_coord_tmp = make_tensor_adaptor_coordinate( + tile_dstr_.get_ps_ys_to_xs_adaptor(), AdaptorTopIndex{get_lane_id(), 0}); + } + else if constexpr(NDimP == 2) + { + window_adaptor_thread_coord_tmp = + make_tensor_adaptor_coordinate(tile_dstr_.get_ps_ys_to_xs_adaptor(), + AdaptorTopIndex{get_warp_id(), get_lane_id(), 0}); + } +#else + // TODO: this use less register for FA, but more register for GEMM + // need investigation + const auto window_adaptor_thread_coord_tmp = make_tensor_adaptor_coordinate( + tile_dstr_.get_ps_ys_to_xs_adaptor(), + 
container_concat(detail::get_partition_index(tile_dstr_), array{0})); +#endif + + BottomTensorIndex bottom_tensor_thread_origin_idx_tmp = + window_origin_ + window_adaptor_thread_coord_tmp.get_bottom_index(); + + bottom_tensor_thread_origin_idx_tmp(HsGatherDim) = 0; + const auto bottom_tensor_thread_coord_tmp = make_tensor_coordinate( + bottom_tensor_view_.get_tensor_descriptor(), bottom_tensor_thread_origin_idx_tmp); + + // pre-compute NumCoord (WindowAdaptorCoord, BottomTensorCoord) bundles to speed up + // future load/store() calls (might allocate more registers) + using Traits = load_store_traits; + using SFC_Ys = typename Traits::SFC_Ys; + + static_for<0, NumCoord, 1>{}([&](auto iCoord) { + auto window_adaptor_thread_coord = window_adaptor_thread_coord_tmp; + auto bottom_tensor_thread_coord = bottom_tensor_thread_coord_tmp; + + constexpr auto idx_diff_ys = + SFC_Ys::get_step_between(number<0>{}, number{}); + + constexpr auto idx_diff_ps_ys = container_concat( + generate_tuple([&](auto) { return number<0>{}; }, number{}), idx_diff_ys); + + move_window_adaptor_and_bottom_tensor_thread_coordinate( + window_adaptor_thread_coord, bottom_tensor_thread_coord, idx_diff_ps_ys); + + pre_computed_coords_(iCoord) = + make_tuple(window_adaptor_thread_coord, bottom_tensor_thread_coord); + }); + } + + CK_TILE_HOST_DEVICE void init_raw() { bottom_tensor_view_.init_raw(); } + + // this is the bottom tensor view + // [x0', x1', ...] ==> [offset] + BottomTensorView bottom_tensor_view_; + + // + WindowLengths window_lengths_; + + // origin ([x0', x1', ...]) of window on bottom tensor + BottomTensorIndex window_origin_; + + // Tile tensor distribution, which contains: + // 1. adaptor for window: [p0, p1, ..., y0, y1, ...] ==> [x0, x1, ...] + // 2. thread descriptor for thread tensor in register: [y0, y1, ...] 
==> [d] + TileDstr tile_dstr_; + + PageIdxArray page_idx_; + + // this contains: + // per-thread coordinate for window adaptor + // per-thread coordinate for bottom tensor + array, NumCoord> pre_computed_coords_; +}; + +// TODO: use strategy +template +CK_TILE_DEVICE constexpr auto +make_tile_scatter_gather(const TensorView_& tensor_view, + const WindowLengths_& window_lengths, + const multi_index& origin, + const StaticTileDistribution_& tile_distribution, + const StaticPageIndexArray_& page_idx, + number = {}, + number = {}) +{ + return tile_scatter_gather, + remove_cvref_t, + remove_cvref_t, + remove_cvref_t, + HsGatherDim, + NumCoord>{ + tensor_view, window_lengths, origin, tile_distribution, page_idx}; +} + +template +CK_TILE_DEVICE constexpr auto make_tile_scatter_gather( + const tile_window_with_static_lengths& tile_window, + const multi_index& origin, + const StaticTileDistribution& tile_distribution, + const StaticPageIndexArray& page_idx, + number = {}) +{ + return make_tile_scatter_gather(tile_window.get_bottom_tensor_view(), + tile_window.get_window_lengths(), + origin, + tile_distribution, + page_idx, + number{}); +} + +template +CK_TILE_DEVICE constexpr auto make_tile_scatter_gather( + const tile_window_with_static_lengths& tile_window, + const StaticTileDistribution& tile_distribution, + const StaticPageIndexArray& page_idx, + number = {}) +{ + return make_tile_scatter_gather(tile_window.get_bottom_tensor_view(), + tile_window.get_window_lengths(), + tile_window.get_window_origin(), + tile_distribution, + page_idx, + number{}); +} + +} // namespace ck_tile diff --git a/include/ck_tile/core/tensor/tile_window_utils.hpp b/include/ck_tile/core/tensor/tile_window_utils.hpp index 71a72329f8..f8b232a7af 100644 --- a/include/ck_tile/core/tensor/tile_window_utils.hpp +++ b/include/ck_tile/core/tensor/tile_window_utils.hpp @@ -18,6 +18,13 @@ #pragma once namespace ck_tile { +template +CK_TILE_DEVICE void move_tile_window(TileWindow_& window, + const typename 
TileWindow_::BottomTensorIndex& step) +{ + window.move(step); +} + // input a lds store tile, extract some information from it // used to set m0 value for gfx9 serious template diff --git a/include/ck_tile/ops/fmha.hpp b/include/ck_tile/ops/fmha.hpp index a28b63f813..ac6ef9cae3 100644 --- a/include/ck_tile/ops/fmha.hpp +++ b/include/ck_tile/ops/fmha.hpp @@ -9,12 +9,16 @@ #include "ck_tile/ops/fmha/block/block_position_encoding.hpp" #include "ck_tile/ops/fmha/block/block_rotary_embedding.hpp" #include "ck_tile/ops/fmha/block/page_block_navigator.hpp" +#include "ck_tile/ops/fmha/block/variants.hpp" +#include "ck_tile/ops/fmha/kernel/fmha_batch_prefill_kernel.hpp" #include "ck_tile/ops/fmha/kernel/fmha_bwd_kernel.hpp" #include "ck_tile/ops/fmha/kernel/fmha_fwd_appendkv_kernel.hpp" #include "ck_tile/ops/fmha/kernel/fmha_fwd_appendkv_tile_partitioner.hpp" #include "ck_tile/ops/fmha/kernel/fmha_fwd_kernel.hpp" #include "ck_tile/ops/fmha/kernel/fmha_fwd_splitkv_combine_kernel.hpp" #include "ck_tile/ops/fmha/kernel/fmha_fwd_splitkv_kernel.hpp" +#include "ck_tile/ops/fmha/pipeline/block_fmha_batch_prefill_pipeline_qr_ks_vs_async.hpp" +#include "ck_tile/ops/fmha/pipeline/block_fmha_batch_prefill_pipeline_qr_ks_vs_async_default_policy.hpp" #include "ck_tile/ops/fmha/pipeline/block_fmha_bwd_convert_dq.hpp" #include "ck_tile/ops/fmha/pipeline/block_fmha_bwd_dot_do_o.hpp" #include "ck_tile/ops/fmha/pipeline/block_fmha_bwd_dq_dk_dv_pipeline_kr_ktr_vr.hpp" diff --git a/include/ck_tile/ops/fmha/block/variants.hpp b/include/ck_tile/ops/fmha/block/variants.hpp new file mode 100644 index 0000000000..90fc5656fc --- /dev/null +++ b/include/ck_tile/ops/fmha/block/variants.hpp @@ -0,0 +1,274 @@ +// SPDX-License-Identifier: MIT +// Copyright (c) 2018-2025, Advanced Micro Devices, Inc. All rights reserved. 
+ +#pragma once + +#include + +#include +#include + +#define CK_TILE_ATTENTION_LOGITS_SOFT_CAP_TANH 0 +#define CK_TILE_ATTENTION_LOGITS_SOFT_CAP_SOFTSIGN 1 + +#ifndef CK_TILE_ATTENTION_LOGITS_SOFT_CAP_DEFAULT +#define CK_TILE_ATTENTION_LOGITS_SOFT_CAP_DEFAULT CK_TILE_ATTENTION_LOGITS_SOFT_CAP_TANH +#endif + +namespace ck_tile { + +template +struct StandardAttentionParams +{ + __device__ __host__ StandardAttentionParams(const ImplMask& impl_mask_, float sm_scale_) + : impl_mask(impl_mask_), sm_scale(sm_scale_) + { + } + + const ImplMask& impl_mask; + float sm_scale; +}; + +template +struct LogitsSoftCapParams +{ + __device__ + LogitsSoftCapParams(const ImplMask& impl_mask_, float sm_scale_, float logits_soft_cap_) + : impl_mask(impl_mask_), sm_scale(sm_scale_), logits_soft_cap(logits_soft_cap_) + { + if(0.f < logits_soft_cap) + { + logits_soft_cap_rcp = __builtin_amdgcn_rcpf(logits_soft_cap); + } + else + { + logits_soft_cap_rcp = 0.f; + } + + // move computation here to prevent compiler from generating inefficient instruction + // sequence + if constexpr(UseExp2) + { + logits_soft_cap = log2e_v * logits_soft_cap; + logits_soft_cap_rcp = sm_scale * log2e_rcp_v * logits_soft_cap_rcp; + } + } + + __host__ + LogitsSoftCapParams(const ImplMask& impl_mask_, float sm_scale_, float logits_soft_cap_) + : impl_mask(impl_mask_), sm_scale(sm_scale_), logits_soft_cap(logits_soft_cap_) + { + if(0.f < logits_soft_cap) + { + logits_soft_cap_rcp = 1.f / logits_soft_cap; + } + else + { + logits_soft_cap_rcp = 0.f; + } + + // move computation here to prevent compiler from generating inefficient instruction + // sequence + if constexpr(UseExp2) + { + logits_soft_cap = log2e_v * logits_soft_cap; + logits_soft_cap_rcp = sm_scale * log2e_rcp_v * logits_soft_cap_rcp; + } + } + + __device__ __host__ LogitsSoftCapParams(const ImplMask& impl_mask_, + float sm_scale_, + float logits_soft_cap_, + float logits_soft_cap_rcp_) + : impl_mask(impl_mask_), + sm_scale(sm_scale_), + 
logits_soft_cap(logits_soft_cap_), + logits_soft_cap_rcp(logits_soft_cap_rcp_) + { + // move computation here to prevent compiler from generating inefficient instruction + // sequence + if constexpr(UseExp2) + { + logits_soft_cap = log2e_v * logits_soft_cap; + logits_soft_cap_rcp = sm_scale * log2e_rcp_v * logits_soft_cap_rcp; + } + } + + const ImplMask& impl_mask; + float sm_scale; + float logits_soft_cap; + float logits_soft_cap_rcp; +}; + +struct StandardAttention +{ + __device__ __host__ StandardAttention() = default; + + template + __device__ __forceinline__ T QueryTransform(const Params& params, T q) const + { + return type_convert(q) * params.sm_scale; + } + + /// NOTICE: For better performance, we simply transform thread buffer without calculating + /// qo_idx/kv_idx. + template + __device__ __forceinline__ T LogitsTransform([[maybe_unused]] const Params& params, + T logits, + [[maybe_unused]] uint32_t batch_idx, + /*uint32_t qo_idx, uint32_t kv_idx,*/ + [[maybe_unused]] uint32_t qo_head_idx, + [[maybe_unused]] uint32_t kv_head_idx) const + { + return logits; + } + + template + __device__ __forceinline__ bool LogitsMask(const Params& params, + [[maybe_unused]] uint32_t batch_idx, + uint32_t qo_idx, + uint32_t kv_idx, + [[maybe_unused]] uint32_t qo_head_idx, + [[maybe_unused]] uint32_t kv_head_idx) const + { + return !params.impl_mask.IsOutOfBound(qo_idx, kv_idx); + } +}; + +template +struct LogitsSoftCap +{ + __device__ __host__ LogitsSoftCap() = default; + + template + __device__ __forceinline__ T QueryTransform(const Params& params, T q) const + { + if constexpr(UseExp2) + { + return q; + } + else + { + return type_convert(q) * params.sm_scale; + } + } + + /// NOTICE: For better performance, we simply transform thread buffer without calculating + /// qo_idx/kv_idx. 
+ template + __device__ __forceinline__ T LogitsTransform(const Params& params, + T logits, + [[maybe_unused]] uint32_t batch_idx, + /*uint32_t qo_idx, uint32_t kv_idx,*/ + [[maybe_unused]] uint32_t qo_head_idx, + [[maybe_unused]] uint32_t kv_head_idx) const + { + if constexpr(UseExp2) + { +#if CK_TILE_ATTENTION_LOGITS_SOFT_CAP_DEFAULT == CK_TILE_ATTENTION_LOGITS_SOFT_CAP_TANH + return params.logits_soft_cap * + tanh_fast(type_convert(logits) * params.logits_soft_cap_rcp); +#elif CK_TILE_ATTENTION_LOGITS_SOFT_CAP_DEFAULT == CK_TILE_ATTENTION_LOGITS_SOFT_CAP_SOFTSIGN + return params.sm_scale * type_convert(logits) * + rcp(1.f + abs(type_convert(logits) * params.logits_soft_cap_rcp)); +#endif + } + else + { +#if CK_TILE_ATTENTION_LOGITS_SOFT_CAP_DEFAULT == CK_TILE_ATTENTION_LOGITS_SOFT_CAP_TANH + return params.logits_soft_cap * + tanhf(type_convert(logits) * params.logits_soft_cap_rcp); +#elif CK_TILE_ATTENTION_LOGITS_SOFT_CAP_DEFAULT == CK_TILE_ATTENTION_LOGITS_SOFT_CAP_SOFTSIGN + return type_convert(logits) * + rcp(1.f + abs(type_convert(logits) * params.logits_soft_cap_rcp)); +#endif + } + } + + template + __device__ __forceinline__ bool LogitsMask(const Params& params, + [[maybe_unused]] uint32_t batch_idx, + uint32_t qo_idx, + uint32_t kv_idx, + [[maybe_unused]] uint32_t qo_head_idx, + [[maybe_unused]] uint32_t kv_head_idx) const + { + return !params.impl_mask.IsOutOfBound(qo_idx, kv_idx); + } +}; + +constexpr uint32_t CUSTOM_MASK = 1U; +constexpr uint32_t SLIDING_WINDOW = 2U; +constexpr uint32_t LOGITS_SOFT_CAP = 4U; +constexpr uint32_t ALIBI = 8U; + +template +struct ComposedAttention +{ + static constexpr bool use_exp2 = UseExp2; + + static constexpr bool use_logits_soft_cap = (VARIANT_CODE & LOGITS_SOFT_CAP) != 0; + + __device__ __host__ ComposedAttention() = default; + + template + __device__ __forceinline__ T QueryTransform(const Params& params, T q) const + { + if constexpr(use_logits_soft_cap && UseExp2) + { + return q; + } + return type_convert(q) * 
params.sm_scale; + } + + /// NOTICE: For better performance, we simply transform thread buffer without calculating + /// qo_idx/kv_idx. + template + __device__ __forceinline__ T LogitsTransform(const Params& params, + T logits, + [[maybe_unused]] uint32_t batch_idx, + /*uint32_t qo_idx, uint32_t kv_idx,*/ + [[maybe_unused]] uint32_t qo_head_idx, + [[maybe_unused]] uint32_t kv_head_idx) const + { + if constexpr(use_logits_soft_cap) + { + if constexpr(UseExp2) + { +#if CK_TILE_ATTENTION_LOGITS_SOFT_CAP_DEFAULT == CK_TILE_ATTENTION_LOGITS_SOFT_CAP_TANH + return params.logits_soft_cap * + tanh_fast(type_convert(logits) * params.logits_soft_cap_rcp); +#elif CK_TILE_ATTENTION_LOGITS_SOFT_CAP_DEFAULT == CK_TILE_ATTENTION_LOGITS_SOFT_CAP_SOFTSIGN + return params.sm_scale * type_convert(logits) * + rcp(1.f + + abs(type_convert(logits) * params.logits_soft_cap_rcp)); +#endif + } + else + { +#if CK_TILE_ATTENTION_LOGITS_SOFT_CAP_DEFAULT == CK_TILE_ATTENTION_LOGITS_SOFT_CAP_TANH + return params.logits_soft_cap * + tanhf(type_convert(logits) * params.logits_soft_cap_rcp); +#elif CK_TILE_ATTENTION_LOGITS_SOFT_CAP_DEFAULT == CK_TILE_ATTENTION_LOGITS_SOFT_CAP_SOFTSIGN + return type_convert(logits) * + rcp(1.f + + abs(type_convert(logits) * params.logits_soft_cap_rcp)); +#endif + } + } + return logits; + } + + template + __device__ __forceinline__ bool LogitsMask(const Params& params, + [[maybe_unused]] uint32_t batch_idx, + uint32_t qo_idx, + uint32_t kv_idx, + [[maybe_unused]] uint32_t qo_head_idx, + [[maybe_unused]] uint32_t kv_head_idx) const + { + return !params.impl_mask.IsOutOfBound(qo_idx, kv_idx); + } +}; + +} // namespace ck_tile diff --git a/include/ck_tile/ops/fmha/kernel/fmha_batch_prefill_kernel.hpp b/include/ck_tile/ops/fmha/kernel/fmha_batch_prefill_kernel.hpp new file mode 100644 index 0000000000..ba327ee511 --- /dev/null +++ b/include/ck_tile/ops/fmha/kernel/fmha_batch_prefill_kernel.hpp @@ -0,0 +1,1134 @@ +// SPDX-License-Identifier: MIT +// Copyright (c) 
2018-2024, Advanced Micro Devices, Inc. All rights reserved. + +#pragma once + +#include "ck_tile/core.hpp" +#include "ck_tile/ops/common.hpp" +#include "ck_tile/ops/fmha/block/block_attention_bias_enum.hpp" +#include "ck_tile/ops/fmha/block/variants.hpp" + +#include +#include +#include +#include + +// S[seqlen_q, seqlen_k] = Q[seqlen_q, hdim_q] @ K[seqlen_k, hdim_q] +// S'[seqlen_q, seqlen_k] = S[seqlen_q, seqlen_k] * Scale[1] +// S''[seqlen_q, seqlen_k] = S'[seqlen_q, seqlen_k] + Bias[seqlen_q, seqlen_k] +// P[seqlen_q, seqlen_k] = Softmax(S''[seqlen_q, seqlen_k]) +// O[seqlen_q, hdim_v] = P[seqlen_q, seqlen_k] @ V^T[hdim_v, seqlen_k] + +namespace ck_tile { + +template +struct FmhaBatchPrefillWithPagedKVCacheKernel +{ + using FmhaPipeline = ck_tile::remove_cvref_t; + using EpiloguePipeline = ck_tile::remove_cvref_t; + static constexpr ck_tile::index_t kBlockSize = FmhaPipeline::kBlockSize; + static constexpr ck_tile::index_t kBlockPerCu = FmhaPipeline::kBlockPerCu; + static_assert(kBlockPerCu > 0); + static constexpr ck_tile::index_t kBlockPerCuInput = FmhaPipeline::Problem::kBlockPerCu; + + using QDataType = ck_tile::remove_cvref_t; + using KDataType = ck_tile::remove_cvref_t; + using VDataType = ck_tile::remove_cvref_t; + using BiasDataType = ck_tile::remove_cvref_t; + using RandValOutputDataType = + ck_tile::remove_cvref_t; + using LSEDataType = ck_tile::remove_cvref_t; + using ODataType = ck_tile::remove_cvref_t; + using SaccDataType = ck_tile::remove_cvref_t; + + using VLayout = ck_tile::remove_cvref_t; + + static constexpr bool kIsGroupMode = FmhaPipeline::kIsGroupMode; + static constexpr bool kPadSeqLenQ = FmhaPipeline::kPadSeqLenQ; + static constexpr bool kPadSeqLenK = FmhaPipeline::kPadSeqLenK; + static constexpr bool kPadHeadDimQ = FmhaPipeline::kPadHeadDimQ; + static constexpr bool kPadHeadDimV = FmhaPipeline::kPadHeadDimV; + static constexpr bool kHasLogitsSoftCap = FmhaPipeline::kHasLogitsSoftCap; + static constexpr auto BiasEnum = 
FmhaPipeline::BiasEnum; + static constexpr bool kStoreLSE = FmhaPipeline::kStoreLSE; + static constexpr bool kHasDropout = FmhaPipeline::kHasDropout; + static constexpr bool kDoFp8StaticQuant = FmhaPipeline::Problem::kDoFp8StaticQuant; + using AttentionVariant = ck_tile::remove_cvref_t; + using FmhaMask = ck_tile::remove_cvref_t; + static constexpr bool kHasMask = FmhaMask::IsMasking; + + static constexpr bool kUseAsyncCopy = FmhaPipeline::Policy::AsyncCopy; + + // clang-format off + template struct t2s; + template <> struct t2s { static constexpr const char * name = "fp32"; }; + template <> struct t2s { static constexpr const char * name = "fp16"; }; + template <> struct t2s { static constexpr const char * name = "bf16"; }; + template <> struct t2s { static constexpr const char * name = "fp8"; }; + template <> struct t2s { static constexpr const char * name = "bf8"; }; + // clang-format on + + CK_TILE_HOST static std::string GetName() + { + // sync with generate.py + // clang-format off + using bfs = typename FmhaPipeline::BlockFmhaShape; + using g0br = typename bfs::Gemm0BlockWarps; + using g1br = typename bfs::Gemm1BlockWarps; + using g0wt = typename bfs::Gemm0WarpTile; + using g1wt = typename bfs::Gemm1WarpTile; + #define _SS_ std::string + #define _TS_ std::to_string + auto pn = [&] () { + std::string n; + if (kPadSeqLenQ) n += "s"; + if (kPadSeqLenK) n += "sk"; + if (kPadHeadDimQ) n += "d"; + if (kPadHeadDimV) n += "dv"; + return n.empty() ? n : std::string("p") + n; }(); + return + _SS_("fmha_batch_prefill_d") + _TS_(bfs::kQKHeaddim) + "_" + _SS_(t2s::name) + + "_" + (kIsGroupMode ? 
"group" : "batch") + "_" + "b" + _TS_(bfs::kM0) + "x" + _TS_(bfs::kN0) + "x" + _TS_(bfs::kK0) + "x" + + _TS_(bfs::kN1) + "x" + _TS_(bfs::kK1) + "x" + _TS_(bfs::kQKHeaddim) + "_" + + "r" + _TS_(g0br::at(ck_tile::number<0>{})) + "x" + _TS_(g0br::at(ck_tile::number<1>{})) + "x" + _TS_(g0br::at(ck_tile::number<2>{})) + "_" + + "r" + _TS_(g1br::at(ck_tile::number<0>{})) + "x" + _TS_(g1br::at(ck_tile::number<1>{})) + "x" + _TS_(g1br::at(ck_tile::number<2>{})) + "_" + + "w" + _TS_(g0wt::at(ck_tile::number<0>{})) + "x" + _TS_(g0wt::at(ck_tile::number<1>{})) + "x" + _TS_(g0wt::at(ck_tile::number<2>{})) + "_" + + "w" + _TS_(g1wt::at(ck_tile::number<0>{})) + "x" + _TS_(g1wt::at(ck_tile::number<1>{})) + "x" + _TS_(g1wt::at(ck_tile::number<2>{})) + "_" + + (kBlockPerCuInput == -1 ? "" : ("o" + _TS_(kBlockPerCu) + "_")) + _SS_(FmhaPipeline::name) + "_" + + "v" + (std::is_same_v ? "r" : "c") + (pn.empty() ? "_npad" : "_" + pn) + + (kHasLogitsSoftCap ? "_logits" : "_nlogits" ) + (BiasEnum == BlockAttentionBiasEnum::NO_BIAS ? _SS_("_nbias") : (_SS_("_") + BlockAttentionBiasEnumToStr::name)) + + (kHasMask ? "_" + _SS_(FmhaMask::name) : "_nmask") + (kStoreLSE ? "_lse" : "_nlse" ) + (kHasDropout ? "_dropout" : "_ndropout" ) + (kDoFp8StaticQuant ? "_squant" : "_nsquant" ); + #undef _SS_ + #undef _TS_ + // clang-format on + } + + template // to avoid duplicated base class prblem, introduce an template + // arg + struct FmhaFwdEmptyKargs + { + }; + + // kargs use aggregate initializer, so no constructor will provided + // use inheritance to minimize karg size + // user need to use MakeKargs() function to create kargs. + struct FmhaFwdCommonKargs + { + const void* q_ptr; + const void* k_ptr; + const void* v_ptr; + void* o_ptr; + + ck_tile::index_t seqlen_q; + ck_tile::index_t seqlen_k; + ck_tile::index_t hdim_q; + ck_tile::index_t hdim_v; + + ck_tile::index_t num_head_q; + // for MQA/GQA, nhead could be different. 
This parameter is nhead_q / nhead_k + // if this param is larger than 1, indicate MQA/GQA case + ck_tile::index_t nhead_ratio_qk; + + int32_t num_total_pages; + const int32_t* kv_indptr; + const int32_t* kv_page_indices; +#if 0 // we assume page_block_size=1 for now + const int32_t* kv_last_page_lens; + ck_tile::index_t page_block_size; +#else + static constexpr ck_tile::index_t page_block_size = 1; +#endif + + float scale_s; + + ck_tile::index_t stride_q; + ck_tile::index_t stride_k; + ck_tile::index_t stride_v; + ck_tile::index_t stride_o; + + ck_tile::index_t nhead_stride_q; + ck_tile::index_t nhead_stride_k; + ck_tile::index_t nhead_stride_v; + ck_tile::index_t nhead_stride_o; + }; + + struct FmhaFwdLogitsSoftCapKargs + { + FmhaFwdLogitsSoftCapKargs() = default; + + void init_logits_soft_cap(float logits_soft_cap_) + { + if(0 < logits_soft_cap_) + { + logits_soft_cap = logits_soft_cap_; + logits_soft_cap_rcp = 1.f / logits_soft_cap; + } + else + { + logits_soft_cap = 0.f; + logits_soft_cap_rcp = 0.f; + } + } + + float logits_soft_cap; + float logits_soft_cap_rcp; + }; + + struct FmhaFwdCommonBiasKargs + { + const void* bias_ptr = nullptr; + ck_tile::index_t stride_bias = 0; + ck_tile::index_t nhead_stride_bias = 0; + }; + + struct FmhaFwdBatchModeBiasKargs : FmhaFwdCommonBiasKargs + { + ck_tile::index_t batch_stride_bias = 0; + }; + + struct FmhaFwdAlibiKargs + { + // alibi is batch*nhead*1, no matter in batch/group mode, they are the same + const void* alibi_slope_ptr; + ck_tile::index_t alibi_slope_stride; // stride in batch, or 0 for all batch share same slope + }; + + struct FmhaFwdMaskKargs + { + // ck_tile::index_t window_size_left, window_size_right; + ck_tile::index_t window_size_left, window_size_right; + ck_tile::GenericAttentionMaskEnum mask_type; + }; + + struct FmhaFwdFp8StaticQuantKargs + { + float scale_p; + float scale_o; + }; + + struct FmhaFwdCommonLSEKargs + { + void* lse_ptr = nullptr; + ck_tile::index_t nhead_stride_lse = 0; + 
ck_tile::index_t batch_stride_lse = 0; + }; + + struct FmhaFwdDropoutSeedOffset + { + template + union ValueOrPointer + { + T val; + const T* ptr; + }; + + ValueOrPointer drop_seed; + ValueOrPointer drop_offset; + bool is_drop_seed_offset_from_host; + }; + + struct FmhaFwdCommonDropoutKargs : FmhaFwdDropoutSeedOffset + { + void init_dropout(float p_drop, uint64_t seed, uint64_t offset) + { + float p_undrop = 1.0 - p_drop; + p_undrop_in_uint8_t = + uint8_t(std::floor(p_undrop * std::numeric_limits::max())); + rp_undrop = 1.0 / p_undrop; + + this->drop_seed.val = seed; + this->drop_offset.val = offset; + this->is_drop_seed_offset_from_host = true; + } + + void init_dropout(float p_drop, const uint64_t* seed_ptr, const uint64_t* offset_ptr) + { + float p_undrop = 1.0 - p_drop; + p_undrop_in_uint8_t = + uint8_t(std::floor(p_undrop * std::numeric_limits::max())); + rp_undrop = 1.0 / p_undrop; + + this->drop_seed.ptr = seed_ptr; + this->drop_offset.ptr = offset_ptr; + this->is_drop_seed_offset_from_host = false; + } + + float rp_undrop = 1; + uint8_t p_undrop_in_uint8_t = std::numeric_limits::max(); + bool is_store_randval = false; + void* rand_val_ptr = nullptr; + + ck_tile::index_t stride_randval = 0; + ck_tile::index_t nhead_stride_randval = 0; + }; + + struct FmhaFwdBatchModeDropoutKargs : FmhaFwdCommonDropoutKargs + { + ck_tile::index_t batch_stride_randval = 0; + }; + + struct FmhaFwdBatchModeKargs + : FmhaFwdCommonKargs, + std::conditional_t>>, + std::conditional_t>, + std::conditional_t>, + std::conditional_t>, + std::conditional_t>, + std::conditional_t> + { + ck_tile::index_t batch_stride_q; + ck_tile::index_t batch_stride_k; + ck_tile::index_t batch_stride_v; + ck_tile::index_t batch_stride_o; + }; + + struct FmhaFwdGroupModeKargs + : FmhaFwdCommonKargs, + std::conditional_t>>, + std::conditional_t>, + std::conditional_t>, + std::conditional_t>, + std::conditional_t>, + std::conditional_t> + { + const int32_t* seqstart_q_ptr; + ck_tile::index_t batch_stride_k; 
+ ck_tile::index_t batch_stride_v; + }; + + using Kargs = std::conditional_t; + + struct BlockIndices + { + ck_tile::index_t batch_idx; + ck_tile::index_t qo_head_idx; + ck_tile::index_t kv_head_idx; + }; + + template + CK_TILE_HOST static constexpr std::enable_if_t + MakeKargsImpl(const void* q_ptr, + const void* k_ptr, + const void* v_ptr, + const void* bias_ptr, + void* rand_val_ptr, + void* lse_ptr, + void* o_ptr, + ck_tile::index_t seqlen_q, + ck_tile::index_t hdim_q, + ck_tile::index_t hdim_v, + ck_tile::index_t num_head_q, + ck_tile::index_t nhead_ratio_qk, + int32_t num_total_pages, + const void* kv_indptr, + const void* kv_page_indices, +#if 0 // we assume page_block_size=1 for now + const void* kv_last_page_lens, + ck_tile::index_t page_block_size, +#endif + float scale_s, + float scale_p, + float scale_o, + float logits_soft_cap, + ck_tile::index_t stride_q, + ck_tile::index_t stride_k, + ck_tile::index_t stride_v, + ck_tile::index_t stride_bias, + ck_tile::index_t stride_randval, + ck_tile::index_t stride_o, + ck_tile::index_t nhead_stride_q, + ck_tile::index_t nhead_stride_k, + ck_tile::index_t nhead_stride_v, + ck_tile::index_t nhead_stride_bias, + ck_tile::index_t nhead_stride_randval, + ck_tile::index_t nhead_stride_lse, + ck_tile::index_t nhead_stride_o, + ck_tile::index_t batch_stride_q, + ck_tile::index_t batch_stride_k, + ck_tile::index_t batch_stride_v, + ck_tile::index_t batch_stride_bias, + ck_tile::index_t batch_stride_randval, + ck_tile::index_t batch_stride_lse, + ck_tile::index_t batch_stride_o, + ck_tile::index_t window_size_left, + ck_tile::index_t window_size_right, + ck_tile::index_t mask_type, + float p_drop, + bool s_randval, + std::variant, std::pair> + drop_seed_offset) + { + Kargs kargs{{q_ptr, + k_ptr, + v_ptr, + o_ptr, + seqlen_q, + -1, + hdim_q, + hdim_v, + num_head_q, + nhead_ratio_qk, + num_total_pages, + reinterpret_cast(kv_indptr), + reinterpret_cast(kv_page_indices), +#if 0 // we assume page_block_size=1 for now + 
reinterpret_cast(kv_last_page_lens), + page_block_size, +#endif +#if CK_TILE_FMHA_FWD_FAST_EXP2 + static_cast(scale_s * ck_tile::log2e_v<>), +#else + scale_s, +#endif + stride_q, + stride_k, + stride_v, + stride_o, + nhead_stride_q, + nhead_stride_k, + nhead_stride_v, + nhead_stride_o}, // args for common karg + {}, // placeholder for bias + {}, // placeholder for mask + {}, // placeholder for lse + {}, // placeholder for fp8_static_quant args + {}, // placeholder for dropout + {}, // placeholder for logits_soft_cap + batch_stride_q, + batch_stride_k, + batch_stride_v, + batch_stride_o}; + + if constexpr(BiasEnum == BlockAttentionBiasEnum::ELEMENTWISE_BIAS) + { + kargs.bias_ptr = bias_ptr; + kargs.stride_bias = stride_bias; + kargs.nhead_stride_bias = nhead_stride_bias; + kargs.batch_stride_bias = batch_stride_bias; + } + else if constexpr(BiasEnum == BlockAttentionBiasEnum::ALIBI) + { + kargs.alibi_slope_ptr = bias_ptr; + kargs.alibi_slope_stride = stride_bias; + } + if constexpr(kHasMask) + { + kargs.window_size_left = window_size_left; + kargs.window_size_right = window_size_right; + kargs.mask_type = static_cast(mask_type); + } + if constexpr(kStoreLSE) + { + kargs.lse_ptr = lse_ptr; + kargs.nhead_stride_lse = nhead_stride_lse; + kargs.batch_stride_lse = batch_stride_lse; + } + if constexpr(kDoFp8StaticQuant) + { + kargs.scale_p = scale_p; + kargs.scale_o = scale_o; + } + if constexpr(kHasDropout) + { + if(drop_seed_offset.index() == 0) // seed & offset come from host + { + const auto& [seed, offset] = std::get<0>(drop_seed_offset); + kargs.init_dropout(p_drop, seed, offset); + } + else // seed & offset come from device + { + const auto& [seed_ptr, offset_ptr] = std::get<1>(drop_seed_offset); + kargs.init_dropout(p_drop, + reinterpret_cast(seed_ptr), + reinterpret_cast(offset_ptr)); + } + + kargs.rand_val_ptr = rand_val_ptr; + kargs.stride_randval = stride_randval; + kargs.nhead_stride_randval = nhead_stride_randval; + kargs.batch_stride_randval = 
batch_stride_randval; + kargs.is_store_randval = s_randval; + } + if constexpr(kHasLogitsSoftCap) + { + kargs.init_logits_soft_cap(logits_soft_cap); + } + + return kargs; + } + + template + CK_TILE_HOST static constexpr std::enable_if_t + MakeKargsImpl(const void* q_ptr, + const void* k_ptr, + const void* v_ptr, + const void* bias_ptr, + void* rand_val_ptr, + void* lse_ptr, + void* o_ptr, + const void* seqstart_q_ptr, + ck_tile::index_t hdim_q, + ck_tile::index_t hdim_v, + ck_tile::index_t num_head_q, + ck_tile::index_t nhead_ratio_qk, + int32_t num_total_pages, + const void* kv_indptr, + const void* kv_page_indices, +#if 0 // we assume page_block_size=1 for now + const void* kv_last_page_lens, + ck_tile::index_t page_block_size, +#endif + float scale_s, + float scale_p, + float scale_o, + float logits_soft_cap, + ck_tile::index_t stride_q, + ck_tile::index_t stride_k, + ck_tile::index_t stride_v, + ck_tile::index_t stride_bias, + ck_tile::index_t stride_randval, + ck_tile::index_t stride_o, + ck_tile::index_t nhead_stride_q, + ck_tile::index_t nhead_stride_k, + ck_tile::index_t nhead_stride_v, + ck_tile::index_t nhead_stride_bias, + ck_tile::index_t nhead_stride_randval, + ck_tile::index_t nhead_stride_lse, + ck_tile::index_t nhead_stride_o, + ck_tile::index_t batch_stride_k, + ck_tile::index_t batch_stride_v, + ck_tile::index_t window_size_left, + ck_tile::index_t window_size_right, + ck_tile::index_t mask_type, + float p_drop, + bool s_randval, + std::variant, std::pair> + drop_seed_offset) + { + Kargs kargs{{q_ptr, + k_ptr, + v_ptr, + o_ptr, + -1, // seqlen will be updated by another pointer + -1, // + hdim_q, + hdim_v, + num_head_q, + nhead_ratio_qk, + num_total_pages, + reinterpret_cast(kv_indptr), + reinterpret_cast(kv_page_indices), +#if 0 // we assume page_block_size=1 for now + reinterpret_cast(kv_last_page_lens), + page_block_size, +#endif +#if CK_TILE_FMHA_FWD_FAST_EXP2 + static_cast(scale_s * ck_tile::log2e_v<>), +#else + scale_s, +#endif + stride_q, + 
stride_k, + stride_v, + stride_o, + nhead_stride_q, + nhead_stride_k, + nhead_stride_v, + nhead_stride_o}, // args for common karg + {}, // placeholder for bias + {}, // placeholder for mask + {}, // placeholder for lse + {}, // placeholder for fp8_static_quant args + {}, // placeholder for dropout + {}, // placeholder for logits_soft_cap + reinterpret_cast(seqstart_q_ptr), + batch_stride_k, + batch_stride_v}; + + if constexpr(BiasEnum == BlockAttentionBiasEnum::ELEMENTWISE_BIAS) + { + kargs.bias_ptr = bias_ptr; + kargs.stride_bias = stride_bias; + kargs.nhead_stride_bias = nhead_stride_bias; + } + else if constexpr(BiasEnum == BlockAttentionBiasEnum::ALIBI) + { + kargs.alibi_slope_ptr = bias_ptr; + kargs.alibi_slope_stride = stride_bias; + } + if constexpr(kHasMask) + { + kargs.window_size_left = window_size_left; + kargs.window_size_right = window_size_right; + kargs.mask_type = static_cast(mask_type); + } + if constexpr(kStoreLSE) + { + kargs.lse_ptr = lse_ptr; + kargs.nhead_stride_lse = nhead_stride_lse; + } + if constexpr(kDoFp8StaticQuant) + { + kargs.scale_p = scale_p; + kargs.scale_o = scale_o; + } + if constexpr(kHasDropout) + { + if(drop_seed_offset.index() == 0) // seed & offset come from host + { + const auto& [seed, offset] = std::get<0>(drop_seed_offset); + kargs.init_dropout(p_drop, seed, offset); + } + else // seed & offset come from device + { + const auto& [seed_ptr, offset_ptr] = std::get<1>(drop_seed_offset); + kargs.init_dropout(p_drop, + reinterpret_cast(seed_ptr), + reinterpret_cast(offset_ptr)); + } + + kargs.rand_val_ptr = rand_val_ptr; + kargs.stride_randval = stride_randval; + kargs.nhead_stride_randval = nhead_stride_randval; + kargs.is_store_randval = s_randval; + } + if constexpr(kHasLogitsSoftCap) + { + kargs.init_logits_soft_cap(logits_soft_cap); + } + + return kargs; + } + + CK_TILE_HOST static constexpr auto GridSize(ck_tile::index_t batch_size_, + ck_tile::index_t nhead_, + ck_tile::index_t seqlen_q_, + ck_tile::index_t hdim_v_) + 
{ + if constexpr(kIsGroupMode) + { + // TODO: this may need tuning + return dim3(nhead_, + batch_size_, + ck_tile::integer_divide_ceil(seqlen_q_, FmhaPipeline::kM0) * + ck_tile::integer_divide_ceil(hdim_v_, FmhaPipeline::kN1)); + } + else + { + // TODO: this may need tuning + return dim3(ck_tile::integer_divide_ceil(seqlen_q_, FmhaPipeline::kM0) * + ck_tile::integer_divide_ceil(hdim_v_, FmhaPipeline::kN1), + nhead_, + batch_size_); + } + } + + CK_TILE_DEVICE static constexpr auto GetTileIndex(const Kargs& kargs) + { + if constexpr(kIsGroupMode) + { + // const index_t num_tile_m0 = seqlen_q / kM0; + const index_t num_tile_n1 = + ck_tile::integer_divide_ceil(kargs.hdim_v, FmhaPipeline::kN1); + + const index_t i_block = blockIdx.z; + const index_t i_nhead = blockIdx.x; + const index_t i_batch = blockIdx.y; + + const auto f = [](index_t dividend, index_t divisor) { + index_t quotient = dividend / divisor; + index_t modulus = dividend - quotient * divisor; + return ck_tile::make_tuple(quotient, modulus); + }; + + const auto [i_tile_m, i_tile_n] = f(i_block, num_tile_n1); + + return ck_tile::make_tuple(i_tile_m, i_tile_n, i_nhead, i_batch); + } + else + { + // const index_t num_tile_m0 = seqlen_q / kM0; + const index_t num_tile_n1 = + ck_tile::integer_divide_ceil(kargs.hdim_v, FmhaPipeline::kN1); + + const index_t i_block = blockIdx.x; + const index_t i_nhead = blockIdx.y; + const index_t i_batch = blockIdx.z; + + const auto f = [](index_t dividend, index_t divisor) { + index_t quotient = dividend / divisor; + index_t modulus = dividend - quotient * divisor; + return ck_tile::make_tuple(quotient, modulus); + }; + + const auto [i_tile_m, i_tile_n] = f(i_block, num_tile_n1); + + return ck_tile::make_tuple(i_tile_m, i_tile_n, i_nhead, i_batch); + } + } + + CK_TILE_HOST static constexpr auto BlockSize() { return dim3(kBlockSize); } + + CK_TILE_HOST_DEVICE static constexpr ck_tile::index_t GetSmemSize() + { + return ck_tile::max(FmhaPipeline::GetSmemSize(), 
EpiloguePipeline::GetSmemSize()); + } + + CK_TILE_DEVICE void operator()(Kargs kargs) const + { + // allocate LDS + __shared__ char smem_ptr[GetSmemSize()]; + + // divide problem + const auto [i_tile_m, i_tile_n, i_nhead, i_batch] = GetTileIndex(kargs); + + const index_t i_m0 = __builtin_amdgcn_readfirstlane(i_tile_m * FmhaPipeline::kM0); + const index_t i_n1 = __builtin_amdgcn_readfirstlane(i_tile_n * FmhaPipeline::kN1); + + long_index_t batch_offset_q = 0; + long_index_t batch_offset_bias = 0; + long_index_t batch_offset_randval = 0; + long_index_t batch_offset_lse = 0; + long_index_t batch_offset_o = 0; + + const int32_t num_page_blocks = kargs.kv_indptr[i_batch + 1] - kargs.kv_indptr[i_batch]; +#if 0 // we assume page_block_size=1 for now + const int32_t last_page_len = kargs.kv_last_page_lens[i_batch]; +#endif + if constexpr(kIsGroupMode) + { + // get starting offset for each batch + const long_index_t query_start = kargs.seqstart_q_ptr[i_batch]; + + batch_offset_q = query_start * kargs.stride_q; + + kargs.kv_page_indices += kargs.kv_indptr[i_batch]; + + if constexpr(BiasEnum == BlockAttentionBiasEnum::ELEMENTWISE_BIAS) + { + batch_offset_bias = query_start * kargs.stride_bias; + } + if constexpr(kStoreLSE) + { + batch_offset_lse = query_start; + } + if constexpr(kHasDropout) + { + batch_offset_randval = query_start * kargs.stride_randval; + } + batch_offset_o = query_start * kargs.stride_o; + + // get real # queries & # keys under group mode + kargs.seqlen_q = kargs.seqstart_q_ptr[i_batch + 1] - query_start; + + // # of required blocks is different in each groups, terminate unnecessary blocks + // earlier + if(kargs.seqlen_q <= i_m0) + { + return; + } + +#if 0 // we assume page_block_size=1 for now + kargs.seqlen_k = (num_page_blocks - 1) * kargs.page_block_size + last_page_len; +#else + kargs.seqlen_k = num_page_blocks; +#endif + } + else + { + batch_offset_q = static_cast(i_batch) * kargs.batch_stride_q; + + kargs.kv_page_indices += 
kargs.kv_indptr[i_batch]; + + if constexpr(BiasEnum == BlockAttentionBiasEnum::ELEMENTWISE_BIAS) + { + batch_offset_bias = static_cast(i_batch) * kargs.batch_stride_bias; + } + if constexpr(kStoreLSE) + { + batch_offset_lse = static_cast(i_batch) * kargs.batch_stride_lse; + } + if constexpr(kHasDropout) + { + batch_offset_randval = + static_cast(i_batch) * kargs.batch_stride_randval; + } + batch_offset_o = static_cast(i_batch) * kargs.batch_stride_o; + +#if 0 // we assume page_block_size=1 for now + kargs.seqlen_k = (num_page_blocks - 1) * kargs.page_block_size + last_page_len; +#else + kargs.seqlen_k = num_page_blocks; +#endif + } + + // for simplicity, batch stride we just modify the pointer + const QDataType* q_ptr = reinterpret_cast(kargs.q_ptr) + + static_cast(i_nhead) * kargs.nhead_stride_q + + batch_offset_q; + const KDataType* k_ptr = + reinterpret_cast(kargs.k_ptr) + + static_cast(i_nhead / kargs.nhead_ratio_qk) * kargs.nhead_stride_k; + const VDataType* v_ptr = + reinterpret_cast(kargs.v_ptr) + + static_cast(i_nhead / kargs.nhead_ratio_qk) * kargs.nhead_stride_v; + ODataType* o_ptr = reinterpret_cast(kargs.o_ptr) + + static_cast(i_nhead) * kargs.nhead_stride_o + + batch_offset_o; + + // Q/K/V DRAM and DRAM window + const auto q_dram = [&]() { + const auto q_dram_naive = make_naive_tensor_view( + q_ptr, + make_tuple(kargs.seqlen_q, kargs.hdim_q), + make_tuple(kargs.stride_q, 1), + number{}, + number<1>{}); + if constexpr(FmhaPipeline::kQLoadOnce) + { + return pad_tensor_view( + q_dram_naive, + make_tuple(number{}, number{}), + sequence{}); + } + else + { + return pad_tensor_view( + q_dram_naive, + make_tuple(number{}, number{}), + sequence{}); + } + }(); + const auto k_dram = [&]() { + const auto k_dram_naive = make_naive_tensor_view( + k_ptr, + make_tuple(kargs.num_total_pages * kargs.page_block_size, kargs.hdim_q), + make_tuple(kargs.stride_k, 1), + number{}, + number<1>{}); + + constexpr bool kPadSeqLenK_ = kUseAsyncCopy ? 
kPadSeqLenK : true; + return pad_tensor_view( + k_dram_naive, + make_tuple(number{}, number{}), + sequence{}); + }(); + const auto v_dram = [&]() { + if constexpr(std::is_same_v) + { + const auto v_dram_naive = make_naive_tensor_view( + v_ptr, + make_tuple(kargs.num_total_pages * kargs.page_block_size, kargs.hdim_v), + make_tuple(kargs.stride_v, 1), + number{}, + number<1>{}); + + const auto v_dram_transposed = transform_tensor_view( + v_dram_naive, + make_tuple( + make_pass_through_transform(kargs.hdim_v), + make_pass_through_transform(kargs.num_total_pages * kargs.page_block_size)), + make_tuple(sequence<1>{}, sequence<0>{}), + make_tuple(sequence<0>{}, sequence<1>{})); + + constexpr bool kPadSeqLenK_ = kUseAsyncCopy ? kPadSeqLenK : true; + return pad_tensor_view( + v_dram_transposed, + make_tuple(number{}, number{}), + sequence{}); + } + else + { + const auto v_dram_naive = make_naive_tensor_view( + v_ptr, + make_tuple(kargs.hdim_v, kargs.num_total_pages * kargs.page_block_size), + make_tuple(kargs.stride_v, 1), + number{}, + number<1>{}); + + constexpr bool kPadHeadDimV_ = kUseAsyncCopy ? kPadHeadDimV : false; + return pad_tensor_view( + v_dram_naive, + make_tuple(number{}, number{}), + sequence{}); + } + }(); + + auto q_dram_window = make_tile_window( + q_dram, + [&]() { + if constexpr(FmhaPipeline::kQLoadOnce) + return make_tuple(number{}, + number{}); + else + return make_tuple(number{}, number{}); + }(), + {i_m0, 0}); + + auto k_dram_window = make_tile_window( + k_dram, make_tuple(number{}, number{}), {0, 0}); + + auto v_dram_window = + make_tile_window(v_dram, + make_tuple(number{}, number{}), + {i_n1, 0}); + /// FIXME: Before C++20, capturing structured binding variables are not supported. 
Remove + /// following copy capture of the 'i_nhead' if in C++20 + const auto bias_dram_window = [&, i_nhead_ = i_nhead]() { + constexpr auto bias_dram_window_lengths = + make_tuple(number{}, number{}); + if constexpr(BiasEnum == BlockAttentionBiasEnum::ELEMENTWISE_BIAS) + { + const BiasDataType* bias_ptr = + reinterpret_cast(kargs.bias_ptr) + + static_cast(i_nhead_) * kargs.nhead_stride_bias + + batch_offset_bias; + + const auto bias_dram = [&]() { + const auto bias_dram_naive = make_naive_tensor_view( + bias_ptr, + make_tuple(kargs.seqlen_q, kargs.seqlen_k), + make_tuple(kargs.stride_bias, 1), + number{}, + number<1>{}); + + return pad_tensor_view(bias_dram_naive, + bias_dram_window_lengths, + sequence{}); + }(); + + return make_tile_window(bias_dram, bias_dram_window_lengths, {i_m0, 0}); + } + else + { + return make_null_tile_window(bias_dram_window_lengths); + } + }(); + + // lse + auto lse_dram_window = [&, i_nhead_ = i_nhead]() { + constexpr auto lse_dram_window_lengths = make_tuple(number{}); + if constexpr(kStoreLSE) + { + LSEDataType* lse_ptr = + reinterpret_cast(kargs.lse_ptr) + + static_cast(i_nhead_) * kargs.nhead_stride_lse + batch_offset_lse; + + const auto lse_dram = [&]() { + const auto lse_dram_naive = make_naive_tensor_view( + lse_ptr, + make_tuple(kargs.seqlen_q), + make_tuple(1), + number<1>{}, + number<1>{}); + + return pad_tensor_view( + lse_dram_naive, lse_dram_window_lengths, sequence{}); + }(); + + return make_tile_window(lse_dram, lse_dram_window_lengths, {i_m0}); + } + else + { + return make_null_tile_window(lse_dram_window_lengths); + } + }(); + + auto dropout = [&, i_nhead_ = i_nhead, i_batch_ = i_batch]() { + if constexpr(kHasDropout) + { + return BlockDropout{i_batch_, + i_nhead_, + kargs.num_head_q, + kargs.is_drop_seed_offset_from_host ? kargs.drop_seed.val + : *kargs.drop_seed.ptr, + kargs.is_drop_seed_offset_from_host ? 
kargs.drop_offset.val + : *kargs.drop_offset.ptr, + kargs.rp_undrop, + kargs.p_undrop_in_uint8_t, + kargs.is_store_randval}; + } + else + { + return NullBlockDropout{}; + }; + }(); + + auto randval_dram_window = [&, i_nhead_ = i_nhead]() { + constexpr auto randval_dram_window_lengths = + make_tuple(number{}, number{}); + if constexpr(kHasDropout) + { + RandValOutputDataType* rand_val_ptr = + reinterpret_cast(kargs.rand_val_ptr) + + static_cast(i_nhead_) * kargs.nhead_stride_randval + + batch_offset_randval; + + const auto randval_dram = [&]() { + const auto randval_dram_naive = + make_naive_tensor_view( + rand_val_ptr, + make_tuple(kargs.seqlen_q, kargs.seqlen_k), + make_tuple(kargs.stride_randval, 1), + number<1>{}, + number<1>{}); + + return pad_tensor_view(randval_dram_naive, + randval_dram_window_lengths, + sequence{}); + }(); + + return make_tile_window(randval_dram, randval_dram_window_lengths, {i_m0, 0}); + } + else + { + return make_null_tile_window(randval_dram_window_lengths); + } + }(); + + FmhaMask mask = [&]() { + if constexpr(kHasMask) + return ck_tile::make_generic_attention_mask_from_lr_window( + kargs.window_size_left, + kargs.window_size_right, + kargs.seqlen_q, + kargs.seqlen_k, + kargs.mask_type == GenericAttentionMaskEnum::MASK_FROM_TOP_LEFT); + else + return FmhaMask{kargs.seqlen_q, kargs.seqlen_k}; + }(); + + // WA i_batch capture structure binding before c++20 + auto position_encoding = [&, i_batch_ = i_batch, i_nhead_ = i_nhead]() { + if constexpr(BiasEnum == BlockAttentionBiasEnum::ALIBI) + { + // data loading, shared by entire wg + // TODO: how to use s_read? 
+ SaccDataType slope = + *(reinterpret_cast(kargs.alibi_slope_ptr) + + i_batch_ * kargs.alibi_slope_stride + i_nhead_); +#if CK_TILE_FMHA_FWD_FAST_EXP2 + slope *= ck_tile::log2e_v<>; +#endif + if constexpr(kHasMask) + { + return make_alibi_from_lr_mask(slope, + kargs.window_size_left, + kargs.window_size_right, + kargs.seqlen_q, + kargs.seqlen_k, + kargs.mask_type); + } + else + { + return Alibi{ + slope, kargs.seqlen_q, kargs.seqlen_k, AlibiMode::FROM_BOTTOM_RIGHT}; + } + } + else + { + return EmptyPositionEncoding{}; + } + }(); + + AttentionVariant variant; + const auto variant_params = [&] { + if constexpr(kHasLogitsSoftCap) + { + return ck_tile::LogitsSoftCapParams{ + mask, kargs.scale_s, kargs.logits_soft_cap, kargs.logits_soft_cap_rcp}; + } + else + { + return ck_tile::StandardAttentionParams{mask, kargs.scale_s}; + } + }(); + + BlockIndices block_indices{i_batch, i_nhead, i_nhead / kargs.nhead_ratio_qk}; + + auto o_acc_tile = [&]() { + if constexpr(kDoFp8StaticQuant) + { + return FmhaPipeline{}( + q_dram_window, + identity{}, // q_element_func + k_dram_window, + identity{}, // k_element_func + v_dram_window, + identity{}, // v_element_func + bias_dram_window, + identity{}, // bias_element_func + randval_dram_window, + lse_dram_window, + identity{}, // lse_element_func + identity{}, // s_acc_element_func + scales{kargs.scale_p}, // p_compute_element_func + composes(saturates{}, scales{kargs.scale_o}), // o_acc_element_func + mask, + position_encoding, + kargs.scale_s, + variant, + variant_params, + block_indices, + smem_ptr, + kargs.kv_page_indices, + kargs.stride_k, + kargs.stride_v, + dropout); + } + else + { + return FmhaPipeline{}(q_dram_window, + k_dram_window, + v_dram_window, + bias_dram_window, + randval_dram_window, + lse_dram_window, + mask, + position_encoding, + kargs.scale_s, + variant, + variant_params, + block_indices, + smem_ptr, + kargs.kv_page_indices, + kargs.stride_k, + kargs.stride_v, + dropout); + } + }(); + + // O DRAM and O DRAM window 
+ auto o_dram = [&]() { + const auto o_dram_naive = make_naive_tensor_view( + o_ptr, + make_tuple(kargs.seqlen_q, kargs.hdim_v), + make_tuple(kargs.stride_o, 1), + number{}, + number<1>{}); + + return pad_tensor_view( + o_dram_naive, + make_tuple(number{}, number{}), + sequence{}); + }(); + + auto o_dram_window = + make_tile_window(o_dram, + make_tuple(number{}, number{}), + {i_m0, i_n1}); + + EpiloguePipeline{}(o_dram_window, o_acc_tile); + } +}; + +} // namespace ck_tile diff --git a/include/ck_tile/ops/fmha/kernel/fmha_fwd_kernel.hpp b/include/ck_tile/ops/fmha/kernel/fmha_fwd_kernel.hpp index 1202524950..a4b3765455 100644 --- a/include/ck_tile/ops/fmha/kernel/fmha_fwd_kernel.hpp +++ b/include/ck_tile/ops/fmha/kernel/fmha_fwd_kernel.hpp @@ -6,6 +6,7 @@ #include "ck_tile/core.hpp" #include "ck_tile/ops/common.hpp" #include "ck_tile/ops/fmha/block/block_attention_bias_enum.hpp" +#include "ck_tile/ops/fmha/block/variants.hpp" #include #include @@ -47,11 +48,13 @@ struct FmhaFwdKernel static constexpr bool kPadSeqLenK = FmhaPipeline::kPadSeqLenK; static constexpr bool kPadHeadDimQ = FmhaPipeline::kPadHeadDimQ; static constexpr bool kPadHeadDimV = FmhaPipeline::kPadHeadDimV; + static constexpr bool kHasLogitsSoftCap = FmhaPipeline::kHasLogitsSoftCap; static constexpr auto BiasEnum = FmhaPipeline::BiasEnum; static constexpr bool kStoreLSE = FmhaPipeline::kStoreLSE; static constexpr bool kHasDropout = FmhaPipeline::kHasDropout; static constexpr bool kDoFp8StaticQuant = FmhaPipeline::Problem::kDoFp8StaticQuant; - using FmhaMask = ck_tile::remove_cvref_t; + using AttentionVariant = ck_tile::remove_cvref_t; + using FmhaMask = ck_tile::remove_cvref_t; static constexpr bool kHasMask = FmhaMask::IsMasking; static constexpr bool kUseAsyncCopy = FmhaPipeline::Policy::AsyncCopy; @@ -94,7 +97,7 @@ struct FmhaFwdKernel "w" + _TS_(g1wt::at(ck_tile::number<0>{})) + "x" + _TS_(g1wt::at(ck_tile::number<1>{})) + "x" + _TS_(g1wt::at(ck_tile::number<2>{})) + "_" + (kBlockPerCuInput == -1 
? "" : ("o" + _TS_(kBlockPerCu) + "_")) + _SS_(FmhaPipeline::name) + "_" + "v" + (std::is_same_v ? "r" : "c") + (pn.empty() ? "_npad" : "_" + pn) + - (BiasEnum == BlockAttentionBiasEnum::NO_BIAS ? _SS_("_nbias") : (_SS_("_") + BlockAttentionBiasEnumToStr::name)) + + (kHasLogitsSoftCap ? "_logits" : "_nlogits" ) + (BiasEnum == BlockAttentionBiasEnum::NO_BIAS ? _SS_("_nbias") : (_SS_("_") + BlockAttentionBiasEnumToStr::name)) + (kHasMask ? "_" + _SS_(FmhaMask::name) : "_nmask") + (kStoreLSE ? "_lse" : "_nlse" ) + (kHasDropout ? "_dropout" : "_ndropout" ) + (kDoFp8StaticQuant ? "_squant" : "_nsquant" ); #undef _SS_ #undef _TS_ @@ -139,6 +142,28 @@ struct FmhaFwdKernel ck_tile::index_t nhead_stride_o; }; + struct FmhaFwdLogitsSoftCapKargs + { + FmhaFwdLogitsSoftCapKargs() = default; + + void init_logits_soft_cap(float logits_soft_cap_) + { + if(0 < logits_soft_cap_) + { + logits_soft_cap = logits_soft_cap_; + logits_soft_cap_rcp = 1.f / logits_soft_cap; + } + else + { + logits_soft_cap = 0.f; + logits_soft_cap_rcp = 0.f; + } + } + + float logits_soft_cap; + float logits_soft_cap_rcp; + }; + struct FmhaFwdCommonBiasKargs { const void* bias_ptr = nullptr; @@ -242,7 +267,8 @@ struct FmhaFwdKernel std::conditional_t>, std::conditional_t>, std::conditional_t>, - std::conditional_t> + std::conditional_t>, + std::conditional_t> { ck_tile::index_t batch_stride_q; ck_tile::index_t batch_stride_k; @@ -260,7 +286,8 @@ struct FmhaFwdKernel std::conditional_t>, std::conditional_t>, std::conditional_t>, - std::conditional_t> + std::conditional_t>, + std::conditional_t> { const int32_t* seqstart_q_ptr; const int32_t* seqstart_k_ptr; @@ -269,6 +296,13 @@ struct FmhaFwdKernel using Kargs = std::conditional_t; + struct BlockIndices + { + ck_tile::index_t batch_idx; + ck_tile::index_t qo_head_idx; + ck_tile::index_t kv_head_idx; + }; + template CK_TILE_HOST static constexpr std::enable_if_t MakeKargsImpl(const void* q_ptr, @@ -287,6 +321,7 @@ struct FmhaFwdKernel float scale_s, float 
scale_p, float scale_o, + float logits_soft_cap, ck_tile::index_t stride_q, ck_tile::index_t stride_k, ck_tile::index_t stride_v, @@ -343,6 +378,7 @@ struct FmhaFwdKernel {}, // placeholder for lse {}, // placeholder for fp8_static_quant args {}, // placeholder for dropout + {}, // placeholder for logits_soft_cap batch_stride_q, batch_stride_k, batch_stride_v, @@ -398,6 +434,10 @@ struct FmhaFwdKernel kargs.batch_stride_randval = batch_stride_randval; kargs.is_store_randval = s_randval; } + if constexpr(kHasLogitsSoftCap) + { + kargs.init_logits_soft_cap(logits_soft_cap); + } return kargs; } @@ -421,6 +461,7 @@ struct FmhaFwdKernel float scale_s, float scale_p, float scale_o, + float logits_soft_cap, ck_tile::index_t stride_q, ck_tile::index_t stride_k, ck_tile::index_t stride_v, @@ -465,6 +506,7 @@ struct FmhaFwdKernel scale_s, scale_p, scale_o, + logits_soft_cap, stride_q, stride_k, stride_v, @@ -512,6 +554,7 @@ struct FmhaFwdKernel float scale_s, float scale_p, float scale_o, + float logits_soft_cap, ck_tile::index_t stride_q, ck_tile::index_t stride_k, ck_tile::index_t stride_v, @@ -556,6 +599,7 @@ struct FmhaFwdKernel scale_s, scale_p, scale_o, + logits_soft_cap, stride_q, stride_k, stride_v, @@ -603,6 +647,7 @@ struct FmhaFwdKernel float scale_s, float scale_p, float scale_o, + float logits_soft_cap, ck_tile::index_t stride_q, ck_tile::index_t stride_k, ck_tile::index_t stride_v, @@ -652,6 +697,7 @@ struct FmhaFwdKernel {}, // placeholder for lse {}, // placeholder for fp8_static_quant args {}, // placeholder for dropout + {}, // placeholder for logits_soft_cap reinterpret_cast(seqstart_q_ptr), reinterpret_cast(seqstart_k_ptr), reinterpret_cast(seqlen_k_ptr)}; @@ -703,6 +749,10 @@ struct FmhaFwdKernel kargs.nhead_stride_randval = nhead_stride_randval; kargs.is_store_randval = s_randval; } + if constexpr(kHasLogitsSoftCap) + { + kargs.init_logits_soft_cap(logits_soft_cap); + } return kargs; } @@ -727,6 +777,7 @@ struct FmhaFwdKernel float scale_s, float 
scale_p, float scale_o, + float logits_soft_cap, ck_tile::index_t stride_q, ck_tile::index_t stride_k, ck_tile::index_t stride_v, @@ -765,6 +816,7 @@ struct FmhaFwdKernel scale_s, scale_p, scale_o, + logits_soft_cap, stride_q, stride_k, stride_v, @@ -806,6 +858,7 @@ struct FmhaFwdKernel float scale_s, float scale_p, float scale_o, + float logits_soft_cap, ck_tile::index_t stride_q, ck_tile::index_t stride_k, ck_tile::index_t stride_v, @@ -844,6 +897,7 @@ struct FmhaFwdKernel scale_s, scale_p, scale_o, + logits_soft_cap, stride_q, stride_k, stride_v, @@ -1307,6 +1361,21 @@ struct FmhaFwdKernel } }(); + AttentionVariant variant; + const auto variant_params = [&] { + if constexpr(kHasLogitsSoftCap) + { + return ck_tile::LogitsSoftCapParams{ + mask, kargs.scale_s, kargs.logits_soft_cap, kargs.logits_soft_cap_rcp}; + } + else + { + return ck_tile::StandardAttentionParams{mask, kargs.scale_s}; + } + }(); + + BlockIndices block_indices{i_batch, i_nhead, i_nhead / kargs.nhead_ratio_qk}; + auto o_acc_tile = [&]() { if constexpr(kDoFp8StaticQuant) { @@ -1328,6 +1397,9 @@ struct FmhaFwdKernel mask, position_encoding, kargs.scale_s, + variant, + variant_params, + block_indices, smem_ptr, dropout); } @@ -1342,6 +1414,9 @@ struct FmhaFwdKernel mask, position_encoding, kargs.scale_s, + variant, + variant_params, + block_indices, smem_ptr, dropout); } diff --git a/include/ck_tile/ops/fmha/kernel/fmha_fwd_splitkv_kernel.hpp b/include/ck_tile/ops/fmha/kernel/fmha_fwd_splitkv_kernel.hpp index ea1762abc1..63011d2ba9 100644 --- a/include/ck_tile/ops/fmha/kernel/fmha_fwd_splitkv_kernel.hpp +++ b/include/ck_tile/ops/fmha/kernel/fmha_fwd_splitkv_kernel.hpp @@ -6,6 +6,8 @@ #include "ck_tile/core.hpp" #include "ck_tile/ops/common.hpp" #include "ck_tile/ops/fmha/block/block_attention_bias_enum.hpp" +#include "ck_tile/ops/fmha/block/variants.hpp" + #include #include @@ -43,14 +45,15 @@ struct FmhaFwdSplitKVKernel static constexpr bool kPadSeqLenK = FmhaPipeline::kPadSeqLenK; static constexpr 
bool kPadHeadDimQ = FmhaPipeline::kPadHeadDimQ; static constexpr bool kPadHeadDimV = FmhaPipeline::kPadHeadDimV; + static constexpr bool kHasLogitsSoftCap = FmhaPipeline::kHasLogitsSoftCap; static constexpr auto BiasEnum = FmhaPipeline::BiasEnum; static constexpr bool kStoreLSE = FmhaPipeline::kStoreLSE; static constexpr bool kDoFp8StaticQuant = FmhaPipeline::Problem::kDoFp8StaticQuant; static constexpr bool kIsPagedKV = FmhaPipeline::Problem::kIsPagedKV; static constexpr bool kMergeNumHeadGroupsSeqLenQ = FmhaPipeline::Problem::kMergeNumHeadGroupsSeqLenQ; - - using FmhaMask = ck_tile::remove_cvref_t; + using AttentionVariant = ck_tile::remove_cvref_t; + using FmhaMask = ck_tile::remove_cvref_t; static constexpr bool kHasMask = FmhaMask::IsMasking; static_assert(!kMergeNumHeadGroupsSeqLenQ || @@ -95,7 +98,7 @@ struct FmhaFwdSplitKVKernel "w" + _TS_(g1wt::at(ck_tile::number<0>{})) + "x" + _TS_(g1wt::at(ck_tile::number<1>{})) + "x" + _TS_(g1wt::at(ck_tile::number<2>{})) + "_" + (kBlockPerCuInput == -1 ? "" : ("o" + _TS_(kBlockPerCu) + "_")) + _SS_(FmhaPipeline::name) + "_" + "v" + (std::is_same_v ? "r" : "c") + (pn.empty() ? "_npad" : "_" + pn) + - (BiasEnum == BlockAttentionBiasEnum::NO_BIAS ? _SS_("_nbias") : (_SS_("_") + BlockAttentionBiasEnumToStr::name)) + + (kHasLogitsSoftCap ? "_logits" : "_nlogits" ) + (BiasEnum == BlockAttentionBiasEnum::NO_BIAS ? _SS_("_nbias") : (_SS_("_") + BlockAttentionBiasEnumToStr::name)) + (kHasMask ? "_" + _SS_(FmhaMask::name) : "_nmask") + (kStoreLSE ? "_lse" : "_nlse" ) + (kDoFp8StaticQuant ? "_squant" : "_nsquant") + (kIsPagedKV ? 
"_pagedkv" : "_npagedkv" ); #undef _SS_ @@ -150,6 +153,28 @@ struct FmhaFwdSplitKVKernel ck_tile::index_t split_stride_o_acc; }; + struct LogitsSoftCapKargs + { + LogitsSoftCapKargs() = default; + + void init_logits_soft_cap(float logits_soft_cap_) + { + if(0 < logits_soft_cap_) + { + logits_soft_cap = logits_soft_cap_; + logits_soft_cap_rcp = 1.f / logits_soft_cap; + } + else + { + logits_soft_cap = 0.f; + logits_soft_cap_rcp = 0.f; + } + } + + float logits_soft_cap; + float logits_soft_cap_rcp; + }; + struct CommonBiasKargs { const void* bias_ptr = nullptr; @@ -207,7 +232,8 @@ struct FmhaFwdSplitKVKernel EmptyKargs<0>>>, std::conditional_t>, std::conditional_t>, - std::conditional_t + std::conditional_t, + std::conditional_t> { const int32_t* seqlen_k_ptr; @@ -229,7 +255,8 @@ struct FmhaFwdSplitKVKernel EmptyKargs<0>>>, std::conditional_t>, std::conditional_t>, - std::conditional_t> + std::conditional_t>, + std::conditional_t> { const int32_t* seqstart_q_ptr; const int32_t* seqstart_k_ptr; @@ -243,6 +270,13 @@ struct FmhaFwdSplitKVKernel using Kargs = std::conditional_t; + struct BlockIndices + { + ck_tile::index_t batch_idx; + ck_tile::index_t qo_head_idx; + ck_tile::index_t kv_head_idx; + }; + template __host__ static constexpr std::enable_if_t MakeKargs(const void* q_ptr, @@ -268,6 +302,7 @@ struct FmhaFwdSplitKVKernel const void* cache_batch_idx, float scale_s, float scale_p, + float logits_soft_cap, ck_tile::index_t stride_q, ck_tile::index_t stride_k, ck_tile::index_t stride_v, @@ -324,6 +359,7 @@ struct FmhaFwdSplitKVKernel {}, // placeholder for mask {}, // placeholder for fp8_static_quant args {}, // placeholder for paged-block table or cache_batch_idx + {}, // placeholder for logits_soft_cap reinterpret_cast(seqlen_k_ptr), batch_stride_q, batch_stride_k, @@ -363,6 +399,10 @@ struct FmhaFwdSplitKVKernel { kargs.cache_batch_idx = reinterpret_cast(cache_batch_idx); } + if constexpr(kHasLogitsSoftCap) + { + kargs.init_logits_soft_cap(logits_soft_cap); + } 
return kargs; } @@ -392,6 +432,7 @@ struct FmhaFwdSplitKVKernel bool is_gappy, float scale_s, float scale_p, + float logits_soft_cap, ck_tile::index_t stride_q, ck_tile::index_t stride_k, ck_tile::index_t stride_v, @@ -444,6 +485,7 @@ struct FmhaFwdSplitKVKernel {}, // placeholder for mask {}, // placeholder for fp8_static_quant args {}, // placeholder for paged-block table + {}, // placeholder for logits_soft_cap reinterpret_cast(seqstart_q_ptr), reinterpret_cast(seqstart_k_ptr), reinterpret_cast(seqlen_k_ptr), @@ -478,6 +520,10 @@ struct FmhaFwdSplitKVKernel kargs.page_block_size = page_block_size; kargs.is_gappy = is_gappy; } + if constexpr(kHasLogitsSoftCap) + { + kargs.init_logits_soft_cap(logits_soft_cap); + } return kargs; } @@ -968,6 +1014,21 @@ struct FmhaFwdSplitKVKernel } }(); + AttentionVariant variant; + const auto variant_params = [&] { + if constexpr(kHasLogitsSoftCap) + { + return ck_tile::LogitsSoftCapParams{ + mask, kargs.scale_s, kargs.logits_soft_cap, kargs.logits_soft_cap_rcp}; + } + else + { + return ck_tile::StandardAttentionParams{mask, kargs.scale_s}; + } + }(); + + BlockIndices block_indices{i_batch, i_nhead, i_nhead_k}; + auto o_acc_tile = [&, i_split_ = i_split]() { if constexpr(kDoFp8StaticQuant) { @@ -991,6 +1052,9 @@ struct FmhaFwdSplitKVKernel mask, position_encoding, kargs.scale_s, + variant, + variant_params, + block_indices, kv_l2p_offset, smem_ptr); } @@ -1008,6 +1072,9 @@ struct FmhaFwdSplitKVKernel mask, position_encoding, kargs.scale_s, + variant, + variant_params, + block_indices, kv_l2p_offset, smem_ptr); } diff --git a/include/ck_tile/ops/fmha/pipeline/block_fmha_batch_prefill_pipeline_qr_ks_vs_async.hpp b/include/ck_tile/ops/fmha/pipeline/block_fmha_batch_prefill_pipeline_qr_ks_vs_async.hpp new file mode 100644 index 0000000000..e07cf1c94e --- /dev/null +++ b/include/ck_tile/ops/fmha/pipeline/block_fmha_batch_prefill_pipeline_qr_ks_vs_async.hpp @@ -0,0 +1,900 @@ +// SPDX-License-Identifier: MIT +// Copyright (c) 2018-2024, 
Advanced Micro Devices, Inc. All rights reserved. + +#pragma once + +#include "ck_tile/core.hpp" +#include "ck_tile/ops/common/tensor_layout.hpp" +#include "ck_tile/ops/fmha/block/block_attention_bias_enum.hpp" +#include "ck_tile/ops/fmha/pipeline/block_fmha_batch_prefill_pipeline_qr_ks_vs_async_default_policy.hpp" +#include "ck_tile/ops/fmha/block/block_dropout.hpp" +#include "ck_tile/ops/reduce/block/block_reduce.hpp" + +namespace ck_tile { + +// a variation of qr/ks/vs, where we use async copy to load k (potentially v in the future) +template +struct BlockFmhaBatchPrefillPipelineQRKSVSAsync +{ + using Problem = remove_cvref_t; + using Policy = remove_cvref_t; + using QDataType = remove_cvref_t; + using KDataType = remove_cvref_t; + using VDataType = remove_cvref_t; + using SaccDataType = remove_cvref_t; + using SMPLComputeDataType = remove_cvref_t; + using BiasDataType = remove_cvref_t; + using RandValOutputDataType = remove_cvref_t; + using LSEDataType = remove_cvref_t; + using PDataType = remove_cvref_t; + using OaccDataType = remove_cvref_t; + using ODataType = remove_cvref_t; + using AttentionVariant = remove_cvref_t; + using FmhaMask = remove_cvref_t; + + using BlockFmhaShape = remove_cvref_t; + using VLayout = remove_cvref_t; + static constexpr bool kQLoadOnce = true; // if q_tile load whole block length (hdim) at once + static_assert(kQLoadOnce == Policy::QLoadOnce); + + static constexpr index_t kBlockSize = Problem::kBlockSize; + + static constexpr index_t kM0 = BlockFmhaShape::kM0; + static constexpr index_t kN0 = BlockFmhaShape::kN0; + static constexpr index_t kK0 = BlockFmhaShape::kK0; + static constexpr index_t kN1 = BlockFmhaShape::kN1; + static constexpr index_t kK1 = BlockFmhaShape::kK1; + static constexpr index_t kQKHeaddim = BlockFmhaShape::kQKHeaddim; + static constexpr index_t kSubQKHeaddim = BlockFmhaShape::kSubQKHeaddim; + static constexpr auto I0 = number<0>{}; + static constexpr auto I1 = number<1>{}; + static constexpr auto I2 = 
number<2>{}; + static constexpr auto I3 = number<3>{}; + + static_assert(kSubQKHeaddim <= 256, "hdim bigger than 256 is not suitable for this pipeline!"); + + static constexpr bool kIsGroupMode = Problem::kIsGroupMode; + // TODO: seq_q always support padding, hdim_q/v support multiple of vector(like 8x) + // only need special care about seq_k padding (oob need set -INF of p instead of zero) + static_assert(Problem::kPadSeqLenQ == true && Problem::kPadHeadDimQ == true && + Problem::kPadHeadDimV == true); + static constexpr bool kPadSeqLenQ = true; + static constexpr bool kPadSeqLenK = Problem::kPadSeqLenK; + static constexpr bool kPadHeadDimQ = true; // support multiple of vector(like 8x) + static constexpr bool kPadHeadDimV = true; // support multiple of vector(like 8x) + static constexpr bool kHasLogitsSoftCap = Problem::kHasLogitsSoftCap; + static constexpr auto BiasEnum = Problem::BiasEnum; + static constexpr bool kStoreLSE = Problem::kStoreLSE; + static constexpr bool kHasDropout = Problem::kHasDropout; + + static_assert((CK_TILE_FMHA_FWD_FAST_EXP2 && + (kHasLogitsSoftCap && Problem::BiasEnum == BlockAttentionBiasEnum::NO_BIAS || + !kHasLogitsSoftCap)) || + (!CK_TILE_FMHA_FWD_FAST_EXP2 && !kHasLogitsSoftCap)); + + // last dimension vector length used to create tensor view(and decide buffer_load vector length) + // ... together with tensor distribution. tensor dist should able to overwrite this + static constexpr index_t kAlignmentQ = Policy::template GetAlignmentQ(); + static constexpr index_t kAlignmentK = Policy::template GetAlignmentK(); + static constexpr index_t kAlignmentV = []() { + if constexpr(std::is_same_v) + return Policy::template GetAlignmentV(); + else + return kPadSeqLenK ? 1 : Policy::template GetAlignmentV(); + }(); + static constexpr index_t kAlignmentO = Policy::template GetAlignmentO(); + static constexpr index_t kAlignmentBias = + kPadSeqLenK ? 
1 : Policy::template GetAlignmentBias(); + +#if CK_TILE_FMHA_FWD_FAST_EXP2 + static constexpr auto R_LOG2E = 1.0 / log2e_v; +#endif + + static constexpr index_t kBlockPerCu = []() { + if constexpr(Problem::kBlockPerCu != -1) + return Problem::kBlockPerCu; + else + { + // minimize occupancy + if constexpr(BiasEnum != BlockAttentionBiasEnum::NO_BIAS && kHasDropout) + { + return 1; + } + + if constexpr(kQKHeaddim <= 32) + { + if constexpr(kPadSeqLenK && BiasEnum == BlockAttentionBiasEnum::ELEMENTWISE_BIAS && + FmhaMask::IsMasking) + return 1; + else + return 2; + } + else if constexpr(kQKHeaddim <= 64) + { + if constexpr(kPadSeqLenK && BiasEnum == BlockAttentionBiasEnum::ELEMENTWISE_BIAS) + return 2; + else + return 3; + } + else if constexpr(kQKHeaddim <= 128) + { + if constexpr(kPadSeqLenK && BiasEnum == BlockAttentionBiasEnum::ELEMENTWISE_BIAS) + return 1; + else + return 2; + } + else if constexpr(kQKHeaddim <= 192) + { + if constexpr(kPadSeqLenK && BiasEnum == BlockAttentionBiasEnum::ELEMENTWISE_BIAS) + return 1; + else + return 2; + } + else if constexpr(kQKHeaddim <= 256) + { + return 1; + } + else + { + return 1; + }; + } + }(); + + static constexpr const char* name = "qr_async"; + + using DropoutType = std::conditional_t; + + CK_TILE_HOST_DEVICE static constexpr ck_tile::index_t GetSmemSize() + { + return Policy::template GetSmemSize(); + } + + template + CK_TILE_HOST_DEVICE auto + operator()(const QDramBlockWindowTmp& q_dram_block_window_tmp, // M0*K0 tile + const QElementFunction& q_element_func, + const KDramBlockWindowTmp& k_dram_block_window_tmp, // N0*K0 tile + const KElementFunction& /*k_element_func*/, + const VDramBlockWindowTmp& v_dram_block_window_tmp, // N1*K1 tile + const VElementFunction& v_element_func, + const BiasDramBlockWindowTmp& bias_dram_block_window_tmp, // M0*N0 tile + const BiasElementFunction& bias_element_func, + RandValDramBlockWindowTmp& randval_dram_block_window_tmp, + LSEDramBlockWindowTmp& lse_dram_window_tmp, // M0*1 tile + 
const LSEElementFunction& lse_element_func, + const SAccElementFunction& s_acc_element_func, + const PComputeElementFunction& p_compute_element_func, + const OAccElementFunction& o_acc_element_func, + FmhaMask mask, + PositionEncoding position_encoding, + float scale_s, + const AttentionVariant& variant, + const AttentionVariantParams& variant_params, + const BlockIndices& block_indices, + void* smem_ptr, + const index_t* page_idx, + const index_t stride_k, + const index_t stride_v, + DropoutType& dropout) const + { + static_assert( + std::is_same_v> && + std::is_same_v> && + std::is_same_v>, + "wrong!"); + + static_assert(kM0 == QDramBlockWindowTmp{}.get_window_lengths()[number<0>{}] && + kN0 == KDramBlockWindowTmp{}.get_window_lengths()[number<0>{}] && + kK0 == KDramBlockWindowTmp{}.get_window_lengths()[number<1>{}] && + kN1 == VDramBlockWindowTmp{}.get_window_lengths()[number<0>{}] && + kK1 == VDramBlockWindowTmp{}.get_window_lengths()[number<1>{}] && + kM0 == BiasDramBlockWindowTmp{}.get_window_lengths()[number<0>{}] && + kN0 == BiasDramBlockWindowTmp{}.get_window_lengths()[number<1>{}], + "wrong!"); + + constexpr auto LdsSeq = Policy::template GetLdsBufferSequence(); + + // K tile in LDS + auto k_lds_ptr = reinterpret_cast(smem_ptr); + auto k_lds_store = generate_tuple( + [&](auto i_buf) { + return make_tile_window( + make_tensor_view( + k_lds_ptr, Policy::template MakeKLdsStoreBlockDescriptor(i_buf)), + Policy::template MakeKLdsStoreBlockDescriptor(i_buf).get_lengths(), + {0, 0, 0}); + }, + number{}); + + auto k_lds_Load_view = make_tensor_view( + k_lds_ptr, Policy::template MakeKLdsLoadBlockDescriptor()); + + auto k_lds_load = + make_tile_window(k_lds_Load_view, + Policy::template MakeKLdsLoadBlockDescriptor().get_lengths(), + {0, 0}); + + // V tile in LDS + auto v_lds = make_tensor_view( + reinterpret_cast(smem_ptr), + Policy::template MakeVLdsBlockDescriptor()); + auto v_lds_window = make_tile_window( + v_lds, Policy::template 
MakeVLdsBlockDescriptor().get_lengths(), {0, 0}); + + // Block GEMM + constexpr auto gemm_0 = Policy::template GetQKBlockGemm(); + constexpr auto gemm_1 = Policy::template GetKVBlockGemm(); + + auto q_dram_window = make_tile_window(q_dram_block_window_tmp.get_bottom_tensor_view(), + q_dram_block_window_tmp.get_window_lengths(), + q_dram_block_window_tmp.get_window_origin(), + Policy::template MakeQRegTileDistribution()); + q_dram_window.init_raw(); + + // TODO: we use async Copy for K, which is inline asm + // a side effect is we have to use inline asm for q as well + auto q = decltype(load_tile(q_dram_window)){}; + // TODO: start from rocm-6.2, compiler will have problem if manually set clear of q. + // however, q would be cleared in the constructor of static distributed tensor + // set_tile(q, number<0>{}); // use per-dword clear to avoid scratch + load_tile_raw(q, q_dram_window); + __builtin_amdgcn_sched_barrier(0); + + using SaccBlockTileType = decltype(gemm_0.MakeCBlockTile()); + auto s_acc = SaccBlockTileType{}; + + // reduction function for softmax + const auto f_max = [](auto e0, auto e1) { return max(e0, e1); }; + const auto f_sum = [](auto e0, auto e1) { return e0 + e1; }; + + // infer Sacc, S, P, M, L, Oacc type + using SBlockTileType = decltype(cast_tile(s_acc)); + + using MLBlockTileType = decltype(block_tile_reduce( + SBlockTileType{}, sequence<1>{}, f_max, SMPLComputeDataType{0})); + + using OaccBlockTileType = decltype(gemm_1.MakeCBlockTile()); + + // init Oacc, M, L + auto o_acc = OaccBlockTileType{}; + auto m = MLBlockTileType{}; + auto l = MLBlockTileType{}; + + clear_tile(o_acc); + set_tile(m, -numeric::infinity()); + clear_tile(l); + + __builtin_amdgcn_sched_barrier(0); + const auto q_origin = q_dram_window.get_window_origin(); + const auto [seqlen_k_start, seqlen_k_end] = + mask.GetTileRangeAlongX(q_origin.at(number<0>{}), number{}, number{}); + + const auto num_total_loop = integer_divide_ceil(seqlen_k_end - seqlen_k_start, kN0); + + // check 
early exit if no work to do + if constexpr(FmhaMask::IsMasking || kPadSeqLenK) + { + if(num_total_loop <= 0) + { + if constexpr(kStoreLSE) + { + auto lse = + make_static_distributed_tensor(m.get_tile_distribution()); + + set_tile(lse, -numeric::infinity()); + + store_tile(lse_dram_window_tmp, tile_elementwise_in(lse_element_func, lse)); + } + buffer_load_fence(0); // rocm-6.1, if whole tile is masked out, need to fence(0) + // otherwise will have compute error(maybe compiler bug?) + + // Note: here occ are all cleard, return it + return o_acc; + } + __builtin_amdgcn_sched_barrier(0); // make sure sched_barrier(0) for this check + } + + auto k_dram_block_window = + make_tile_window(k_dram_block_window_tmp.get_bottom_tensor_view(), + k_dram_block_window_tmp.get_window_lengths(), + {seqlen_k_start, 0}); + + auto k_dist = Policy::template MakeKDramTileDistribution(); + auto k_coord = k_dist.calculate_index(); + using KDstrEncode = typename decltype(k_dist)::DstrEncode; + constexpr index_t NRepeat = KDstrEncode::hs_lengthss_[I0][I0]; + statically_indexed_array k_offsets; + static_for<0, NRepeat, 1>{}([&](auto n0) { + k_offsets[n0] = page_idx[k_coord[0] + kN0 / NRepeat * n0.value] * stride_k; + }); + auto k_dram_window = make_tile_scatter_gather(k_dram_block_window.get_bottom_tensor_view(), + k_dram_block_window.get_window_lengths(), + k_dram_block_window.get_window_origin(), + k_dist, + k_offsets); // K DRAM tile window for + k_dram_window.init_raw(); + constexpr auto k_oob_ck = bool_constant{}; + constexpr auto k_pre_np = [&]() { + if constexpr(kPadSeqLenK && + (BiasEnum == BlockAttentionBiasEnum::ELEMENTWISE_BIAS || + (BiasEnum != BlockAttentionBiasEnum::NO_BIAS && kHasDropout))) + return bool_constant{}; + else + return bool_constant{}; + }(); + + const auto bias_origin = bias_dram_block_window_tmp.get_window_origin(); + auto bias_dram_window = + make_tile_window(bias_dram_block_window_tmp.get_bottom_tensor_view(), + bias_dram_block_window_tmp.get_window_lengths(), + 
{bias_origin.at(number<0>{}), seqlen_k_start}, // M/N + Policy::template MakeBiasDramTileDistribution()); + + auto randval_dram_window = dropout.template MakeRandvalDramWindow( + randval_dram_block_window_tmp, seqlen_k_start); + + auto v_dist = Policy::template MakeVDramTileDistribution(); + auto v_coord = v_dist.calculate_index(); + const auto VPageIndexDim = I1; + using VDstrEncode = typename decltype(v_dist)::DstrEncode; + constexpr index_t V_KRepeat = VDstrEncode::hs_lengthss_[I1][I3]; + statically_indexed_array v_offsets; + (void)stride_k; + static_for<0, V_KRepeat, 1>{}([&](auto k0) { + v_offsets[k0] = page_idx[v_coord[VPageIndexDim] + k0.value] * stride_v; + }); + + auto v_dram_window = + make_tile_scatter_gather(v_dram_block_window_tmp.get_bottom_tensor_view(), + v_dram_block_window_tmp.get_window_lengths(), + {0, seqlen_k_start}, // TODO: hdim split? + v_dist, + v_offsets, + VPageIndexDim); + + // prefetch K tile + async_load_tile_raw( + k_lds_store(LdsSeq.at(number<0>{})), k_dram_window, number<-1>{}, k_oob_ck, k_pre_np); + move_tile_window(k_dram_window, {0, kK0}); + __builtin_amdgcn_sched_barrier(0); + + buffer_load_fence(k_dram_window.get_num_of_access(), q.get_thread_buffer()); + (void)q_element_func; // ??? 
rocm-6.x if use q element func will have scratch on hdim=64/32 + // auto q_tile = q; // tile_elementwise_in(q_element_func, q); + + index_t i_total_loops = 0; + constexpr index_t k0_loops = kQKHeaddim / kK0; + constexpr index_t k1_loops = kN0 / kK1; + + static_assert(1 <= k0_loops); + static_assert(1 <= k1_loops); + // main loop + do + { + // STAGE 1, QK gemm + clear_tile(s_acc); // initialize C + if constexpr(k0_loops > 1) + { + static_for<0, k0_loops - 1, 1>{}([&](auto i_k0) { + async_load_tile_raw(k_lds_store(number{})>{}), + k_dram_window, + number<-1>{}, + k_oob_ck, + k_pre_np); + if constexpr(i_k0 < k0_loops - 1) + move_tile_window(k_dram_window, {0, kK0}); + + async_load_fence(k_dram_window.get_num_of_access()); + __builtin_amdgcn_s_barrier(); + __builtin_amdgcn_sched_barrier(0); + gemm_0(s_acc, + get_slice_tile( + q, sequence<0, i_k0 * kK0>{}, sequence{}), + get_slice_tile(k_lds_load, + sequence<(LdsSeq.at(number{})) * kN0, 0>{}, + sequence<(LdsSeq.at(number{}) + 1) * kN0, kK0>{})); + }); + } + + // TODO: this to fix a bug when loop smaller than 2, + // the following fence/barrier will be scheduled inside 1st loop + if constexpr(k0_loops <= 2) + __builtin_amdgcn_sched_barrier(0); + + async_load_fence(); + __builtin_amdgcn_s_barrier(); + + const auto bias_tile = load_tile(bias_dram_window); // load bias tile + auto v_buf = load_tile(v_dram_window, number<-1>{}, bool_constant{}); + static_for<0, V_KRepeat, 1>{}([&](auto k0) { + v_offsets[k0] = page_idx[kK1 + v_coord[VPageIndexDim] + k0.value] * stride_v; + }); + v_dram_window.update_page_idx(v_offsets); + + __builtin_amdgcn_sched_barrier(0); + { // tail + gemm_0( + s_acc, + get_slice_tile( + q, sequence<0, (k0_loops - 1) * kK0>{}, sequence{}), + get_slice_tile(k_lds_load, + sequence<(LdsSeq.at(number{})) * kN0, 0>{}, + sequence<(LdsSeq.at(number{}) + 1) * kN0, kK0>{})); + } + __builtin_amdgcn_sched_barrier(1); + + // STAGE 2, scale_s, add bias, mask, softmax + if constexpr(BiasEnum == 
BlockAttentionBiasEnum::ELEMENTWISE_BIAS) + { + s_acc = tile_elementwise_in(s_acc_element_func, s_acc); + tile_elementwise_inout([&scale_s](auto& x) { x = x * scale_s; }, s_acc); + tile_elementwise_inout( + [&](auto& x, const auto& y) { +#if !CK_TILE_FMHA_FWD_FAST_EXP2 + x += type_convert(bias_element_func(y)); +#else + x += log2e_v * + type_convert(bias_element_func(y)); +#endif + }, + s_acc, + bias_tile); + } + else if constexpr(BiasEnum == BlockAttentionBiasEnum::ALIBI) + { + const auto k_origin = k_dram_block_window.get_window_origin(); + constexpr auto s_spans = decltype(s_acc)::get_distributed_spans(); + s_acc = tile_elementwise_in(s_acc_element_func, s_acc); + sweep_tile_span(s_spans[number<0>{}], [&](auto idx0) { + sweep_tile_span(s_spans[number<1>{}], [&](auto idx1) { + const auto tile_idx = get_x_indices_from_distributed_indices( + s_acc.get_tile_distribution(), make_tuple(idx0, idx1)); + + const auto row = q_origin.at(number<0>{}) + tile_idx.at(number<0>{}); + const auto col = k_origin.at(number<0>{}) + tile_idx.at(number<1>{}); + constexpr auto i_j_idx = make_tuple(idx0, idx1); + + s_acc(i_j_idx) *= scale_s; + position_encoding.update(s_acc(i_j_idx), row, col); + }); + }); + } + else + { + s_acc = tile_elementwise_in(s_acc_element_func, s_acc); + if constexpr(kHasLogitsSoftCap) + { + auto apply_logits_transform = + [&variant, &variant_params, &block_indices](auto& x) { + x = variant.LogitsTransform(variant_params, + variant.QueryTransform(variant_params, x), + block_indices.batch_idx, + block_indices.qo_head_idx, + block_indices.kv_head_idx); + }; +#if !CK_TILE_FMHA_FWD_FAST_EXP2 + for(index_t i = 0; i < s_acc.thread_buf_.size(); ++i) + { + apply_logits_transform(s_acc.thread_buf_[i]); + } +#else + for(index_t i = 0; i < s_acc.thread_buf_.size(); ++i) + { + apply_logits_transform(s_acc.thread_buf_[i]); + } +#endif + } + else + { +#if !CK_TILE_FMHA_FWD_FAST_EXP2 + tile_elementwise_inout([&scale_s](auto& x) { x = x * scale_s; }, s_acc); +#endif + } + } + 
move_tile_window(bias_dram_window, {0, kN0}); + if constexpr(kPadSeqLenK || FmhaMask::IsMasking) + { + const auto k_origin = k_dram_block_window.get_window_origin(); + bool need_perpixel_check = mask.IsEdgeTile(q_origin.at(number<0>{}), + k_origin.at(number<0>{}), + number{}, + number{}); + + if(need_perpixel_check) + { + set_tile_if( + s_acc, -numeric::infinity(), [&](auto tile_idx) { + const auto row = q_origin.at(number<0>{}) + tile_idx.at(number<0>{}); + const auto col = k_origin.at(number<0>{}) + tile_idx.at(number<1>{}); + return !variant.LogitsMask(variant_params, + block_indices.batch_idx, + row, + col, + block_indices.qo_head_idx, + block_indices.kv_head_idx); + }); + } + } + + const auto s = cast_tile(s_acc); // S{j} + auto m_local = block_tile_reduce( + s, + sequence<1>{}, + f_max, + -numeric::infinity()); // m_local = rowmax(S{j}) + block_tile_reduce_sync(m_local, f_max, bool_constant{}); + + const auto m_old = m; // m{j-1} + tile_elementwise_inout( + [](auto& e0, auto e1, auto e2) { e0 = max(e1, e2); }, m, m_old, m_local); // m{j} + + auto p_compute = make_static_distributed_tensor( + s.get_tile_distribution()); // Pcompute{j} + + __builtin_amdgcn_sched_barrier(0x7F); + // store & prefetch next v, after the max reduction + if constexpr(std::is_same_v) + { + auto v_shuffle_tmp = make_static_distributed_tensor( + Policy::template MakeShuffledVRegBlockDescriptor()); + shuffle_tile(v_shuffle_tmp, v_buf); + + auto v_lds_window_tmp = + get_slice_tile(v_lds_window, + sequence<(LdsSeq.at(number{})) * kN1, 0>{}, + sequence<(LdsSeq.at(number{}) + 1) * kN1, kK1>{}); + + store_tile( + v_lds_window_tmp, + tile_elementwise_in(v_element_func, v_shuffle_tmp)); // store the prefetch + } + else + { + auto v_lds_window_tmp = + get_slice_tile(v_lds_window, + sequence<(LdsSeq.at(number{})) * kN1, 0>{}, + sequence<(LdsSeq.at(number{}) + 1) * kN1, kK1>{}); + store_tile(v_lds_window_tmp, + tile_elementwise_in(v_element_func, v_buf)); // store the prefetch + } + + if 
constexpr(k1_loops > 1) + { + move_tile_window( + v_dram_window, + {0, kK1}); // will have scratch if move this right after load_tile(v_dram)... + v_buf = load_tile( + v_dram_window, number<-1>{}, bool_constant{}); // load next v_buf + static_for<0, V_KRepeat, 1>{}([&](auto k0) { + v_offsets[k0] = + page_idx[kK1 * 2 + v_coord[VPageIndexDim] + k0.value] * stride_v; + }); + v_dram_window.update_page_idx(v_offsets); + } + __builtin_amdgcn_sched_barrier(0); + + static const auto get_validated_m = [](SMPLComputeDataType raw_m) { + /// NOTICE: bias might be materialized mask including -inf values, need + /// consideration. alibi does not have this problem + if constexpr(BiasEnum == BlockAttentionBiasEnum::ELEMENTWISE_BIAS || + FmhaMask::IsMasking) + { + return raw_m == -numeric::infinity() + ? type_convert(0.f) + : raw_m; + } + else + { + return raw_m; + } + }; + + constexpr auto p_spans = decltype(p_compute)::get_distributed_spans(); + sweep_tile_span(p_spans[number<0>{}], [&](auto idx0) { + constexpr auto i_idx = make_tuple(idx0); +#if CK_TILE_FMHA_FWD_FAST_EXP2 + auto row_max = scale_s * get_validated_m(m[i_idx]); +#endif + sweep_tile_span(p_spans[number<1>{}], [&](auto idx1) { + constexpr auto i_j_idx = make_tuple(idx0, idx1); +#if CK_TILE_FMHA_FWD_FAST_EXP2 + if constexpr(BiasEnum == BlockAttentionBiasEnum::ELEMENTWISE_BIAS || + BiasEnum == BlockAttentionBiasEnum::ALIBI) + { + p_compute(i_j_idx) = exp2(s[i_j_idx] - get_validated_m(m[i_idx])); + } + else + { + if constexpr(kHasLogitsSoftCap) + { + p_compute(i_j_idx) = exp2(s[i_j_idx] - get_validated_m(m[i_idx])); + } + else + { + p_compute(i_j_idx) = exp2(scale_s * s[i_j_idx] - row_max); + } + } +#else + p_compute(i_j_idx) = exp(s[i_j_idx] - get_validated_m(m[i_idx])); +#endif + }); + }); + + auto rowsum_p = block_tile_reduce( + p_compute, sequence<1>{}, f_sum, SMPLComputeDataType{0}); // rowsum(Pcompute{j}) + + block_tile_reduce_sync(rowsum_p, f_sum, bool_constant{}); + // l{j}, Oacc{j} + constexpr auto o_spans = 
decltype(o_acc)::get_distributed_spans(); + sweep_tile_span(o_spans[number<0>{}], [&](auto idx0) { + constexpr auto i_idx = make_tuple(idx0); +#if CK_TILE_FMHA_FWD_FAST_EXP2 + const auto tmp = [&]() { + if constexpr(BiasEnum == BlockAttentionBiasEnum::ELEMENTWISE_BIAS || + BiasEnum == BlockAttentionBiasEnum::ALIBI) + { + return exp2(m_old[i_idx] - get_validated_m(m[i_idx])); + } + else + { + if constexpr(kHasLogitsSoftCap) + { + return exp2(m_old[i_idx] - get_validated_m(m[i_idx])); + } + else + { + auto row_max = scale_s * get_validated_m(m[i_idx]); + return exp2(scale_s * m_old[i_idx] - row_max); + } + } + }(); +#else + const auto tmp = exp(m_old[i_idx] - get_validated_m(m[i_idx])); +#endif + l(i_idx) = tmp * l[i_idx] + rowsum_p[i_idx]; + sweep_tile_span(o_spans[number<1>{}], [&](auto idx1) { + constexpr auto i_j_idx = make_tuple(idx0, idx1); + // FIXME: this uses a different equation from the FA v2 paper, + // but produces the correct result. + // Is the equation wrong? + o_acc(i_j_idx) *= tmp; + }); + }); + + if constexpr(kHasDropout) + { + auto randval_ptr = + reinterpret_cast(smem_ptr) + Policy::template GetSmemSizeKV(); + dropout.template Run( + randval_ptr, + seqlen_k_start + i_total_loops * kN0, + p_compute, + randval_dram_window); + } + + const auto p = [&]() { + if constexpr(std::is_same_v) + return impl::cast_tile_pk_fp16_fp32( + tile_elementwise_in(p_compute_element_func, p_compute)); + else + return cast_tile( + tile_elementwise_in(p_compute_element_func, p_compute)); + }(); + + // STAGE 3, KV gemm + if constexpr(k1_loops > 1) + { + static_for<0, k1_loops - 1, 1>{}([&](auto i_k1) { + if constexpr(i_k1 != 0 && i_k1 < k1_loops - 1) + { + v_buf = load_tile( + v_dram_window, number<-1>{}, bool_constant{}); // load next v_buf + static_for<0, V_KRepeat, 1>{}([&](auto k0) { + v_offsets[k0] = page_idx[kK1 * 2 + i_k1.value * kK1 + + v_coord[VPageIndexDim] + k0.value] * + stride_v; + }); + v_dram_window.update_page_idx(v_offsets); + } + block_sync_lds(); + gemm_1(o_acc, + 
get_slice_tile( + p, sequence<0, i_k1 * kK1>{}, sequence{}), + get_slice_tile( + v_lds_window, + sequence<(LdsSeq.at(number{})) * kN1, 0>{}, + sequence<(LdsSeq.at(number{}) + 1) * kN1, kK1>{})); + + if constexpr(std::is_same_v) + { + auto v_shuffle_tmp = make_static_distributed_tensor( + Policy::template MakeShuffledVRegBlockDescriptor()); + shuffle_tile(v_shuffle_tmp, v_buf); + auto v_lds_window_tmp = get_slice_tile( + v_lds_window, + sequence<(LdsSeq.at(number{})) * kN1, 0>{}, + sequence<(LdsSeq.at(number{}) + 1) * kN1, kK1>{}); + store_tile(v_lds_window_tmp, + tile_elementwise_in(v_element_func, + v_shuffle_tmp)); // store the prefetch + } + else + { + auto v_lds_window_tmp = get_slice_tile( + v_lds_window, + sequence<(LdsSeq.at(number{})) * kN1, 0>{}, + sequence<(LdsSeq.at(number{}) + 1) * kN1, kK1>{}); + store_tile(v_lds_window_tmp, + tile_elementwise_in(v_element_func, v_buf)); // store next v_buf + } + if constexpr(i_k1 < k1_loops - 1) + move_tile_window(v_dram_window, {0, kK1}); + }); + } + i_total_loops++; + if(i_total_loops < num_total_loop) + { + page_idx += kN0; + // move K tile windows + move_tile_window(k_dram_block_window, {kN0, 0}); + k_dram_window.set_window_origin(k_dram_block_window.get_window_origin()); + + static_for<0, NRepeat, 1>{}([&](auto n0) { + k_offsets[n0] = page_idx[k_coord[0] + kN0 / NRepeat * n0.value] * stride_k; + }); + k_dram_window.update_page_idx(k_offsets); + if constexpr(k1_loops >= 2 && + LdsSeq.at(number<0>{}) == LdsSeq.at(number{})) + __builtin_amdgcn_s_barrier(); + async_load_tile_raw(k_lds_store(LdsSeq.at(number<0>{})), + k_dram_window, + number<-1>{}, + k_oob_ck, + k_pre_np); + move_tile_window(k_dram_window, {0, kK0}); + } + // tail + { + block_sync_lds(); + gemm_1( + o_acc, + get_slice_tile(p, sequence<0, (k1_loops - 1) * kK1>{}, sequence{}), + get_slice_tile( + v_lds_window, + sequence<(LdsSeq.at(number{})) * kN1, 0>{}, + sequence<(LdsSeq.at(number{}) + 1) * kN1, kK1>{})); + } + } while(i_total_loops < 
num_total_loop); + + // store lse + if constexpr(kStoreLSE) + { + auto lse = make_static_distributed_tensor(m.get_tile_distribution()); + + constexpr auto lse_spans = decltype(lse)::get_distributed_spans(); + sweep_tile_span(lse_spans[number<0>{}], [&, m_ = m, l_ = l](auto idx0) { + constexpr auto i_idx = make_tuple(idx0); +#if CK_TILE_FMHA_FWD_FAST_EXP2 + if constexpr(BiasEnum == BlockAttentionBiasEnum::ELEMENTWISE_BIAS || + BiasEnum == BlockAttentionBiasEnum::ALIBI) + { + lse(i_idx) = m_[i_idx] * R_LOG2E + log(l_[i_idx]); + } + else + { + if constexpr(kHasLogitsSoftCap) + { + lse(i_idx) = m_[i_idx] * R_LOG2E + log(l_[i_idx]); + } + else + { + lse(i_idx) = m_[i_idx] * scale_s * R_LOG2E + log(l_[i_idx]); + } + } +#else + lse(i_idx) = m_[i_idx] + log(l_[i_idx]); +#endif + }); + + store_tile(lse_dram_window_tmp, tile_elementwise_in(lse_element_func, lse)); + } + + // finally, O + constexpr auto o_spans = decltype(o_acc)::get_distributed_spans(); + + sweep_tile_span(o_spans[number<0>{}], [&](auto idx0) { + constexpr auto i_idx = make_tuple(idx0); + const auto tmp = [&]() { + if constexpr(FmhaMask::IsMasking) + { + return l[i_idx] == 0.f ? 
0.f : 1 / l[i_idx]; + } + else + return 1 / l[i_idx]; + }(); + sweep_tile_span(o_spans[number<1>{}], [&](auto idx1) { + constexpr auto i_j_idx = make_tuple(idx0, idx1); + o_acc(i_j_idx) *= tmp; + }); + }); + + o_acc = tile_elementwise_in(o_acc_element_func, o_acc); + + return o_acc; + } + + template + CK_TILE_HOST_DEVICE auto + operator()(const QDramBlockWindowTmp& q_dram_block_window_tmp, // M0*K0 tile + const KDramBlockWindowTmp& k_dram_block_window_tmp, // N0*K0 tile + const VDramBlockWindowTmp& v_dram_block_window_tmp, // N1*K1 tile + const BiasDramBlockWindowTmp& bias_dram_block_window_tmp, // M0*N0 tile + RandValDramBlockWindowTmp& randval_dram_block_window_tmp, // M0*N0 tile + LSEDramBlockWindowTmp& lse_dram_block_window_tmp, // M0*1 tile + FmhaMask mask, + PositionEncoding position_encoding, + float scale_s, + const AttentionVariant& variant, + const AttentionVariantParams& variant_params, + const BlockIndices& block_indices, + void* smem_ptr, + const index_t* page_idx, + const index_t stride_k, + const index_t stride_v, + DropoutType& dropout) const + { + return operator()(q_dram_block_window_tmp, + identity{}, + k_dram_block_window_tmp, + identity{}, + v_dram_block_window_tmp, + identity{}, + bias_dram_block_window_tmp, + identity{}, + randval_dram_block_window_tmp, + lse_dram_block_window_tmp, + identity{}, + identity{}, + identity{}, + identity{}, + mask, + position_encoding, + scale_s, + variant, + variant_params, + block_indices, + smem_ptr, + page_idx, + stride_k, + stride_v, + dropout); + } +}; + +} // namespace ck_tile diff --git a/include/ck_tile/ops/fmha/pipeline/block_fmha_batch_prefill_pipeline_qr_ks_vs_async_default_policy.hpp b/include/ck_tile/ops/fmha/pipeline/block_fmha_batch_prefill_pipeline_qr_ks_vs_async_default_policy.hpp new file mode 100644 index 0000000000..02731ca8f8 --- /dev/null +++ b/include/ck_tile/ops/fmha/pipeline/block_fmha_batch_prefill_pipeline_qr_ks_vs_async_default_policy.hpp @@ -0,0 +1,18 @@ +// SPDX-License-Identifier: 
MIT +// Copyright (c) 2018-2023, Advanced Micro Devices, Inc. All rights reserved. + +#pragma once + +#include "ck_tile/core.hpp" +#include "ck_tile/ops/fmha/pipeline/block_fmha_pipeline_qx_ks_vs_custom_policy.hpp" + +namespace ck_tile { + +// This pipeline is qkv all located in LDS +using BlockFmhaBatchPrefillPipelineQRKSVSAsyncDefaultPolicy = + BlockFmhaPipelineQXKSVSCustomPolicy; + +} // namespace ck_tile diff --git a/include/ck_tile/ops/fmha/pipeline/block_fmha_fwd_splitkv_pipeline_nwarp_sshuffle_qr_ks_vs.hpp b/include/ck_tile/ops/fmha/pipeline/block_fmha_fwd_splitkv_pipeline_nwarp_sshuffle_qr_ks_vs.hpp index 809c58f1d1..4d1c38e079 100644 --- a/include/ck_tile/ops/fmha/pipeline/block_fmha_fwd_splitkv_pipeline_nwarp_sshuffle_qr_ks_vs.hpp +++ b/include/ck_tile/ops/fmha/pipeline/block_fmha_fwd_splitkv_pipeline_nwarp_sshuffle_qr_ks_vs.hpp @@ -27,6 +27,7 @@ struct BlockFmhaFwdSplitKVPipelineNWarpSShuffleQRKSVS using PDataType = remove_cvref_t; using OaccDataType = remove_cvref_t; using ODataType = remove_cvref_t; + using AttentionVariant = remove_cvref_t; using FmhaMask = remove_cvref_t; using BlockFmhaShape = remove_cvref_t; @@ -46,15 +47,21 @@ struct BlockFmhaFwdSplitKVPipelineNWarpSShuffleQRKSVS static_assert(kSubQKHeaddim <= 256, "hdim bigger than 256 is not suitable for this pipeline!"); - static constexpr bool kIsGroupMode = Problem::kIsGroupMode; - static constexpr bool kPadSeqLenQ = Problem::kPadSeqLenQ; - static constexpr bool kPadSeqLenK = Problem::kPadSeqLenK; - static constexpr bool kPadHeadDimQ = Problem::kPadHeadDimQ; - static constexpr bool kPadHeadDimV = Problem::kPadHeadDimV; - static constexpr auto BiasEnum = Problem::BiasEnum; - static constexpr bool kStoreLSE = Problem::kStoreLSE; - static constexpr bool kIsPagedKV = Problem::kIsPagedKV; - static constexpr bool kHasUnevenSplits = Problem::kHasUnevenSplits; + static constexpr bool kIsGroupMode = Problem::kIsGroupMode; + static constexpr bool kPadSeqLenQ = Problem::kPadSeqLenQ; + static constexpr 
bool kPadSeqLenK = Problem::kPadSeqLenK; + static constexpr bool kPadHeadDimQ = Problem::kPadHeadDimQ; + static constexpr bool kPadHeadDimV = Problem::kPadHeadDimV; + static constexpr bool kHasLogitsSoftCap = Problem::kHasLogitsSoftCap; + static constexpr auto BiasEnum = Problem::BiasEnum; + static constexpr bool kStoreLSE = Problem::kStoreLSE; + static constexpr bool kIsPagedKV = Problem::kIsPagedKV; + static constexpr bool kHasUnevenSplits = Problem::kHasUnevenSplits; + + static_assert((CK_TILE_FMHA_FWD_FAST_EXP2 && + (kHasLogitsSoftCap && Problem::BiasEnum == BlockAttentionBiasEnum::NO_BIAS || + !kHasLogitsSoftCap)) || + (!CK_TILE_FMHA_FWD_FAST_EXP2 && !kHasLogitsSoftCap)); // last dimension vector length used to create tensor view(and decide buffer_load vector length) // ... together with tensor distribution. tensor dist should able to overwrite this @@ -128,7 +135,9 @@ struct BlockFmhaFwdSplitKVPipelineNWarpSShuffleQRKSVS typename SAccElementFunction, typename PComputeElementFunction, typename OAccElementFunction, - typename PositionEncoding> + typename PositionEncoding, + typename AttentionVariantParams, + typename BlockIndices> CK_TILE_HOST_DEVICE auto operator()(const QDramBlockWindowTmp& q_dram_block_window_tmp, // M0*K0 tile const QElementFunction& q_element_func, @@ -150,6 +159,9 @@ struct BlockFmhaFwdSplitKVPipelineNWarpSShuffleQRKSVS FmhaMask mask, PositionEncoding position_encoding, float scale_s, + const AttentionVariant& variant, + const AttentionVariantParams& variant_params, + const BlockIndices& block_indices, index_t kv_l2p_offset, // logical-to-physical offset of seqlen_k coordinate void* smem_ptr) const { @@ -453,9 +465,34 @@ struct BlockFmhaFwdSplitKVPipelineNWarpSShuffleQRKSVS else { s_acc = tile_elementwise_in(s_acc_element_func, s_acc); + if constexpr(kHasLogitsSoftCap) + { + auto apply_logits_transform = + [&variant, &variant_params, &block_indices](auto& x) { + x = variant.LogitsTransform(variant_params, + 
variant.QueryTransform(variant_params, x), + block_indices.batch_idx, + block_indices.qo_head_idx, + block_indices.kv_head_idx); + }; #if !CK_TILE_FMHA_FWD_FAST_EXP2 - tile_elementwise_inout([&scale_s](auto& x) { x = x * scale_s; }, s_acc); + for(index_t i = 0; i < s_acc.thread_buf_.size(); ++i) + { + apply_logits_transform(s_acc.thread_buf_[i]); + } +#else + for(index_t i = 0; i < s_acc.thread_buf_.size(); ++i) + { + apply_logits_transform(s_acc.thread_buf_[i]); + } #endif + } + else + { +#if !CK_TILE_FMHA_FWD_FAST_EXP2 + tile_elementwise_inout([&scale_s](auto& x) { x = x * scale_s; }, s_acc); +#endif + } } move_tile_window(bias_dram_window, {0, kN0}); @@ -574,7 +611,14 @@ struct BlockFmhaFwdSplitKVPipelineNWarpSShuffleQRKSVS } else { - p_compute(i_j_idx) = exp2(scale_s * s_new[i_j_idx] - row_max); + if constexpr(kHasLogitsSoftCap) + { + p_compute(i_j_idx) = exp2(s_new[i_j_idx] - get_validated_m(m[i_idx])); + } + else + { + p_compute(i_j_idx) = exp2(scale_s * s_new[i_j_idx] - row_max); + } } #else p_compute(i_j_idx) = exp(s_new[i_j_idx] - get_validated_m(m[i_idx])); @@ -603,8 +647,15 @@ struct BlockFmhaFwdSplitKVPipelineNWarpSShuffleQRKSVS } else { - auto row_max = scale_s * get_validated_m(m[i_idx]); - return exp2(scale_s * m_old[i_idx] - row_max); + if constexpr(kHasLogitsSoftCap) + { + return exp2(m_old[i_idx] - get_validated_m(m[i_idx])); + } + else + { + auto row_max = scale_s * get_validated_m(m[i_idx]); + return exp2(scale_s * m_old[i_idx] - row_max); + } } }(); #else @@ -711,7 +762,14 @@ struct BlockFmhaFwdSplitKVPipelineNWarpSShuffleQRKSVS } else { - lse_acc(i_idx) = m_[i_idx] * scale_s / C_LOG2E + log(l_[i_idx]); + if constexpr(kHasLogitsSoftCap) + { + lse_acc(i_idx) = m_[i_idx] / C_LOG2E + log(l_[i_idx]); + } + else + { + lse_acc(i_idx) = m_[i_idx] * scale_s / C_LOG2E + log(l_[i_idx]); + } } #else lse_acc(i_idx) = m_[i_idx] + log(l_[i_idx]); @@ -757,7 +815,9 @@ struct BlockFmhaFwdSplitKVPipelineNWarpSShuffleQRKSVS typename VPageBlockNavigator, typename 
BiasDramBlockWindowTmp, typename LSEaccDramBlockWindowTmp, - typename PositionEncoding> + typename PositionEncoding, + typename AttentionVariantParams, + typename BlockIndices> CK_TILE_HOST_DEVICE auto operator()(const QDramBlockWindowTmp& q_dram_block_window_tmp, // M0*K0 tile const KDramBlockWindowLengths& k_dram_block_window_lengths, // N0*K0 tile @@ -771,6 +831,9 @@ struct BlockFmhaFwdSplitKVPipelineNWarpSShuffleQRKSVS FmhaMask mask, PositionEncoding position_encoding, float scale_s, + const AttentionVariant& variant, + const AttentionVariantParams& variant_params, + const BlockIndices& block_indices, index_t kv_l2p_offset, // logical-to-physical offset of seqlen_k coordinate void* smem_ptr) const { @@ -794,6 +857,9 @@ struct BlockFmhaFwdSplitKVPipelineNWarpSShuffleQRKSVS mask, position_encoding, scale_s, + variant, + variant_params, + block_indices, kv_l2p_offset, smem_ptr); } diff --git a/include/ck_tile/ops/fmha/pipeline/block_fmha_fwd_splitkv_pipeline_qr_ks_vs.hpp b/include/ck_tile/ops/fmha/pipeline/block_fmha_fwd_splitkv_pipeline_qr_ks_vs.hpp index ce80dba5eb..7f5f79d7a7 100644 --- a/include/ck_tile/ops/fmha/pipeline/block_fmha_fwd_splitkv_pipeline_qr_ks_vs.hpp +++ b/include/ck_tile/ops/fmha/pipeline/block_fmha_fwd_splitkv_pipeline_qr_ks_vs.hpp @@ -26,6 +26,7 @@ struct BlockFmhaFwdSplitKVPipelineQRKSVS using PDataType = remove_cvref_t; using OaccDataType = remove_cvref_t; using ODataType = remove_cvref_t; + using AttentionVariant = remove_cvref_t; using FmhaMask = remove_cvref_t; using BlockFmhaShape = remove_cvref_t; @@ -45,15 +46,21 @@ struct BlockFmhaFwdSplitKVPipelineQRKSVS static_assert(kSubQKHeaddim <= 256, "hdim bigger than 256 is not suitable for this pipeline!"); - static constexpr bool kIsGroupMode = Problem::kIsGroupMode; - static constexpr bool kPadSeqLenQ = Problem::kPadSeqLenQ; - static constexpr bool kPadSeqLenK = Problem::kPadSeqLenK; - static constexpr bool kPadHeadDimQ = Problem::kPadHeadDimQ; - static constexpr bool kPadHeadDimV = 
Problem::kPadHeadDimV; - static constexpr auto BiasEnum = Problem::BiasEnum; - static constexpr bool kStoreLSE = Problem::kStoreLSE; - static constexpr bool kIsPagedKV = Problem::kIsPagedKV; - static constexpr bool kHasUnevenSplits = Problem::kHasUnevenSplits; + static constexpr bool kIsGroupMode = Problem::kIsGroupMode; + static constexpr bool kPadSeqLenQ = Problem::kPadSeqLenQ; + static constexpr bool kPadSeqLenK = Problem::kPadSeqLenK; + static constexpr bool kPadHeadDimQ = Problem::kPadHeadDimQ; + static constexpr bool kPadHeadDimV = Problem::kPadHeadDimV; + static constexpr bool kHasLogitsSoftCap = Problem::kHasLogitsSoftCap; + static constexpr auto BiasEnum = Problem::BiasEnum; + static constexpr bool kStoreLSE = Problem::kStoreLSE; + static constexpr bool kIsPagedKV = Problem::kIsPagedKV; + static constexpr bool kHasUnevenSplits = Problem::kHasUnevenSplits; + + static_assert((CK_TILE_FMHA_FWD_FAST_EXP2 && + (kHasLogitsSoftCap && Problem::BiasEnum == BlockAttentionBiasEnum::NO_BIAS || + !kHasLogitsSoftCap)) || + (!CK_TILE_FMHA_FWD_FAST_EXP2 && !kHasLogitsSoftCap)); // last dimension vector length used to create tensor view(and decide buffer_load vector length) // ... together with tensor distribution. 
tensor dist should able to overwrite this @@ -127,7 +134,9 @@ struct BlockFmhaFwdSplitKVPipelineQRKSVS typename SAccElementFunction, typename PComputeElementFunction, typename OAccElementFunction, - typename PositionEncoding> + typename PositionEncoding, + typename AttentionVariantParams, + typename BlockIndices> CK_TILE_HOST_DEVICE auto operator()(const QDramBlockWindowTmp& q_dram_block_window_tmp, // M0*K0 tile const QElementFunction& q_element_func, @@ -149,6 +158,9 @@ struct BlockFmhaFwdSplitKVPipelineQRKSVS FmhaMask mask, PositionEncoding position_encoding, float scale_s, + const AttentionVariant& variant, + const AttentionVariantParams& variant_params, + const BlockIndices& block_indices, index_t kv_l2p_offset, // logical-to-physical offset of seqlen_k coordinate void* smem_ptr) const { @@ -401,9 +413,28 @@ struct BlockFmhaFwdSplitKVPipelineQRKSVS else { s_acc = tile_elementwise_in(s_acc_element_func, s_acc); + if constexpr(kHasLogitsSoftCap) + { + auto apply_logits_transform = + [&variant, &variant_params, &block_indices](auto& x) { + x = variant.LogitsTransform(variant_params, + variant.QueryTransform(variant_params, x), + block_indices.batch_idx, + block_indices.qo_head_idx, + block_indices.kv_head_idx); + }; #if !CK_TILE_FMHA_FWD_FAST_EXP2 - tile_elementwise_inout([&scale_s](auto& x) { x = x * scale_s; }, s_acc); + tile_elementwise_inout(apply_logits_transform, s_acc); +#else + tile_elementwise_inout(apply_logits_transform, s_acc); #endif + } + else + { +#if !CK_TILE_FMHA_FWD_FAST_EXP2 + tile_elementwise_inout([&scale_s](auto& x) { x = x * scale_s; }, s_acc); +#endif + } } move_tile_window(bias_dram_window, {0, kN0}); @@ -497,7 +528,14 @@ struct BlockFmhaFwdSplitKVPipelineQRKSVS } else { - p_compute(i_j_idx) = exp2(scale_s * s[i_j_idx] - row_max); + if constexpr(kHasLogitsSoftCap) + { + p_compute(i_j_idx) = exp2(s[i_j_idx] - get_validated_m(m[i_idx])); + } + else + { + p_compute(i_j_idx) = exp2(scale_s * s[i_j_idx] - row_max); + } } #else 
p_compute(i_j_idx) = exp(s[i_j_idx] - get_validated_m(m[i_idx])); @@ -522,8 +560,16 @@ struct BlockFmhaFwdSplitKVPipelineQRKSVS } else { - auto row_max = scale_s * get_validated_m(m[i_idx]); - return exp2(scale_s * m_old[i_idx] - row_max); + if constexpr(kHasLogitsSoftCap) + { + + return exp2(m_old[i_idx] - get_validated_m(m[i_idx])); + } + else + { + auto row_max = scale_s * get_validated_m(m[i_idx]); + return exp2(scale_s * m_old[i_idx] - row_max); + } } }(); #else @@ -620,7 +666,14 @@ struct BlockFmhaFwdSplitKVPipelineQRKSVS } else { - lse_acc(i_idx) = m_[i_idx] * scale_s / C_LOG2E + log(l_[i_idx]); + if constexpr(kHasLogitsSoftCap) + { + lse_acc(i_idx) = m_[i_idx] / C_LOG2E + log(l_[i_idx]); + } + else + { + lse_acc(i_idx) = m_[i_idx] * scale_s / C_LOG2E + log(l_[i_idx]); + } } #else lse_acc(i_idx) = m_[i_idx] + log(l_[i_idx]); @@ -662,7 +715,9 @@ struct BlockFmhaFwdSplitKVPipelineQRKSVS typename VPageBlockNavigator, typename BiasDramBlockWindowTmp, typename LSEaccDramBlockWindowTmp, - typename PositionEncoding> + typename PositionEncoding, + typename AttentionVariantParams, + typename BlockIndices> CK_TILE_HOST_DEVICE auto operator()(const QDramBlockWindowTmp& q_dram_block_window_tmp, // M0*K0 tile const KDramBlockWindowLengths& k_dram_block_window_lengths, // N0*K0 tile @@ -676,6 +731,9 @@ struct BlockFmhaFwdSplitKVPipelineQRKSVS FmhaMask mask, PositionEncoding position_encoding, float scale_s, + const AttentionVariant& variant, + const AttentionVariantParams& variant_params, + const BlockIndices& block_indices, index_t kv_l2p_offset, // logical-to-physical offset of seqlen_k coordinate void* smem_ptr) const { @@ -699,6 +757,9 @@ struct BlockFmhaFwdSplitKVPipelineQRKSVS mask, position_encoding, scale_s, + variant, + variant_params, + block_indices, kv_l2p_offset, smem_ptr); } diff --git a/include/ck_tile/ops/fmha/pipeline/block_fmha_pipeline_problem.hpp b/include/ck_tile/ops/fmha/pipeline/block_fmha_pipeline_problem.hpp index 9a5208c025..f35c00c268 100644 --- 
a/include/ck_tile/ops/fmha/pipeline/block_fmha_pipeline_problem.hpp +++ b/include/ck_tile/ops/fmha/pipeline/block_fmha_pipeline_problem.hpp @@ -20,6 +20,7 @@ template struct BlockFmhaPipelineProblem @@ -36,6 +37,7 @@ struct BlockFmhaPipelineProblem using OaccDataType = remove_cvref_t; using ODataType = remove_cvref_t; using BlockFmhaShape = remove_cvref_t; + using AttentionVariant = remove_cvref_t; using FmhaMask = remove_cvref_t; using Traits = remove_cvref_t; @@ -50,6 +52,7 @@ struct BlockFmhaPipelineProblem static constexpr bool kPadSeqLenK = Traits::kPadSeqLenK; static constexpr bool kPadHeadDimQ = Traits::kPadHeadDimQ; static constexpr bool kPadHeadDimV = Traits::kPadHeadDimV; + static constexpr bool kHasLogitsSoftCap = Traits::kHasLogitsSoftCap; static constexpr auto BiasEnum = Traits::BiasEnum; static constexpr bool kStoreLSE = Traits::kStoreLSE; static constexpr bool kHasDropout = Traits::kHasDropout; @@ -69,6 +72,7 @@ template struct BlockFmhaFwdSplitKVPipelineProblem @@ -84,6 +88,7 @@ struct BlockFmhaFwdSplitKVPipelineProblem using OaccDataType = remove_cvref_t; using ODataType = remove_cvref_t; using BlockFmhaShape = remove_cvref_t; + using AttentionVariant = remove_cvref_t; using FmhaMask = remove_cvref_t; using Traits = remove_cvref_t; @@ -98,6 +103,7 @@ struct BlockFmhaFwdSplitKVPipelineProblem static constexpr bool kPadSeqLenK = Traits::kPadSeqLenK; static constexpr bool kPadHeadDimQ = Traits::kPadHeadDimQ; static constexpr bool kPadHeadDimV = Traits::kPadHeadDimV; + static constexpr bool kHasLogitsSoftCap = Traits::kHasLogitsSoftCap; static constexpr auto BiasEnum = Traits::BiasEnum; static constexpr bool kStoreLSE = Traits::kStoreLSE; static constexpr bool kDoFp8StaticQuant = Traits::kDoFp8StaticQuant; diff --git a/include/ck_tile/ops/fmha/pipeline/block_fmha_pipeline_qr_ks_vs.hpp b/include/ck_tile/ops/fmha/pipeline/block_fmha_pipeline_qr_ks_vs.hpp index 8a4a925b81..29f183c613 100644 --- 
a/include/ck_tile/ops/fmha/pipeline/block_fmha_pipeline_qr_ks_vs.hpp +++ b/include/ck_tile/ops/fmha/pipeline/block_fmha_pipeline_qr_ks_vs.hpp @@ -5,8 +5,8 @@ #include "ck_tile/core.hpp" #include "ck_tile/ops/fmha/block/block_attention_bias_enum.hpp" -#include "ck_tile/ops/fmha/pipeline/block_fmha_pipeline_qr_ks_vs_default_policy.hpp" #include "ck_tile/ops/fmha/block/block_dropout.hpp" +#include "ck_tile/ops/fmha/pipeline/block_fmha_pipeline_qr_ks_vs_default_policy.hpp" #include "ck_tile/ops/reduce/block/block_reduce.hpp" namespace ck_tile { @@ -28,6 +28,7 @@ struct BlockFmhaPipelineQRKSVS using PDataType = remove_cvref_t; using OaccDataType = remove_cvref_t; using ODataType = remove_cvref_t; + using AttentionVariant = remove_cvref_t; using FmhaMask = remove_cvref_t; using BlockFmhaShape = remove_cvref_t; @@ -47,14 +48,20 @@ struct BlockFmhaPipelineQRKSVS static_assert(kSubQKHeaddim <= 256, "hdim bigger than 256 is not suitable for this pipeline!"); - static constexpr bool kIsGroupMode = Problem::kIsGroupMode; - static constexpr bool kPadSeqLenQ = Problem::kPadSeqLenQ; - static constexpr bool kPadSeqLenK = Problem::kPadSeqLenK; - static constexpr bool kPadHeadDimQ = Problem::kPadHeadDimQ; - static constexpr bool kPadHeadDimV = Problem::kPadHeadDimV; - static constexpr auto BiasEnum = Problem::BiasEnum; - static constexpr bool kStoreLSE = Problem::kStoreLSE; - static constexpr bool kHasDropout = Problem::kHasDropout; + static constexpr bool kIsGroupMode = Problem::kIsGroupMode; + static constexpr bool kPadSeqLenQ = Problem::kPadSeqLenQ; + static constexpr bool kPadSeqLenK = Problem::kPadSeqLenK; + static constexpr bool kPadHeadDimQ = Problem::kPadHeadDimQ; + static constexpr bool kPadHeadDimV = Problem::kPadHeadDimV; + static constexpr bool kHasLogitsSoftCap = Problem::kHasLogitsSoftCap; + static constexpr auto BiasEnum = Problem::BiasEnum; + static constexpr bool kStoreLSE = Problem::kStoreLSE; + static constexpr bool kHasDropout = Problem::kHasDropout; + + 
static_assert((CK_TILE_FMHA_FWD_FAST_EXP2 && + (kHasLogitsSoftCap && Problem::BiasEnum == BlockAttentionBiasEnum::NO_BIAS || + !kHasLogitsSoftCap)) || + (!CK_TILE_FMHA_FWD_FAST_EXP2 && !kHasLogitsSoftCap)); // last dimension vector length used to create tensor view(and decide buffer_load vector length) // ... together with tensor distribution. tensor dist should able to overwrite this @@ -101,7 +108,7 @@ struct BlockFmhaPipelineQRKSVS else { return 1; - }; + } } }(); @@ -128,7 +135,9 @@ struct BlockFmhaPipelineQRKSVS typename SAccElementFunction, typename PComputeElementFunction, typename OAccElementFunction, - typename PositionEncoding> + typename PositionEncoding, + typename AttentionVariantParams, + typename BlockIndices> CK_TILE_HOST_DEVICE auto operator()(const QDramBlockWindowTmp& q_dram_block_window_tmp, // M0*K0 tile const QElementFunction& q_element_func, @@ -147,6 +156,9 @@ struct BlockFmhaPipelineQRKSVS FmhaMask mask, PositionEncoding position_encoding, float scale_s, + const AttentionVariant& variant, + const AttentionVariantParams& variant_params, + const BlockIndices& block_indices, void* smem_ptr, DropoutType& dropout) const { @@ -380,9 +392,28 @@ struct BlockFmhaPipelineQRKSVS else { s_acc = tile_elementwise_in(s_acc_element_func, s_acc); + if constexpr(kHasLogitsSoftCap) + { + auto apply_logits_transform = + [&variant, &variant_params, &block_indices](auto& x) { + x = variant.LogitsTransform(variant_params, + variant.QueryTransform(variant_params, x), + block_indices.batch_idx, + block_indices.qo_head_idx, + block_indices.kv_head_idx); + }; #if !CK_TILE_FMHA_FWD_FAST_EXP2 - tile_elementwise_inout([&scale_s](auto& x) { x = x * scale_s; }, s_acc); + tile_elementwise_inout(apply_logits_transform, s_acc); +#else + tile_elementwise_inout(apply_logits_transform, s_acc); #endif + } + else + { +#if !CK_TILE_FMHA_FWD_FAST_EXP2 + tile_elementwise_inout([&scale_s](auto& x) { x = x * scale_s; }, s_acc); +#endif + } } move_tile_window(bias_dram_window, {0, 
kN0}); if constexpr(kPadSeqLenK || FmhaMask::IsMasking) @@ -398,7 +429,12 @@ struct BlockFmhaPipelineQRKSVS s_acc, -numeric::infinity(), [&](auto tile_idx) { const auto row = q_origin.at(number<0>{}) + tile_idx.at(number<0>{}); const auto col = k_origin.at(number<0>{}) + tile_idx.at(number<1>{}); - return mask.IsOutOfBound(row, col); + return !variant.LogitsMask(variant_params, + block_indices.batch_idx, + row, + col, + block_indices.qo_head_idx, + block_indices.kv_head_idx); }); } } @@ -450,7 +486,14 @@ struct BlockFmhaPipelineQRKSVS } else { - p_compute(i_j_idx) = exp2(scale_s * s[i_j_idx] - row_max); + if constexpr(kHasLogitsSoftCap) + { + p_compute(i_j_idx) = exp2(s[i_j_idx] - get_validated_m(m[i_idx])); + } + else + { + p_compute(i_j_idx) = exp2(scale_s * s[i_j_idx] - row_max); + } } #else p_compute(i_j_idx) = exp(s[i_j_idx] - get_validated_m(m[i_idx])); @@ -475,8 +518,16 @@ struct BlockFmhaPipelineQRKSVS } else { - auto row_max = scale_s * get_validated_m(m[i_idx]); - return exp2(scale_s * m_old[i_idx] - row_max); + if constexpr(kHasLogitsSoftCap) + { + + return exp2(m_old[i_idx] - get_validated_m(m[i_idx])); + } + else + { + auto row_max = scale_s * get_validated_m(m[i_idx]); + return exp2(scale_s * m_old[i_idx] - row_max); + } } }(); #else @@ -574,7 +625,14 @@ struct BlockFmhaPipelineQRKSVS } else { - lse(i_idx) = m_[i_idx] * scale_s / C_LOG2E + log(l_[i_idx]); + if constexpr(kHasLogitsSoftCap) + { + lse(i_idx) = m_[i_idx] / C_LOG2E + log(l_[i_idx]); + } + else + { + lse(i_idx) = m_[i_idx] * scale_s / C_LOG2E + log(l_[i_idx]); + } } #else lse(i_idx) = m_[i_idx] + log(l_[i_idx]); @@ -614,7 +672,9 @@ struct BlockFmhaPipelineQRKSVS typename BiasDramBlockWindowTmp, typename RandValDramBlockWindowTmp, typename LSEDramBlockWindowTmp, - typename PositionEncoding> + typename PositionEncoding, + typename AttentionVariantParams, + typename BlockIndices> CK_TILE_HOST_DEVICE auto operator()(const QDramBlockWindowTmp& q_dram_block_window_tmp, // M0*K0 tile const 
KDramBlockWindowTmp& k_dram_block_window_tmp, // N0*K0 tile @@ -625,6 +685,9 @@ struct BlockFmhaPipelineQRKSVS FmhaMask mask, PositionEncoding position_encoding, float scale_s, + const AttentionVariant& variant, + const AttentionVariantParams& variant_params, + const BlockIndices& block_indices, void* smem_ptr, DropoutType& dropout) const { @@ -645,6 +708,9 @@ struct BlockFmhaPipelineQRKSVS mask, position_encoding, scale_s, + variant, + variant_params, + block_indices, smem_ptr, dropout); } diff --git a/include/ck_tile/ops/fmha/pipeline/block_fmha_pipeline_qr_ks_vs_async.hpp b/include/ck_tile/ops/fmha/pipeline/block_fmha_pipeline_qr_ks_vs_async.hpp index 67354fc72d..7af3902dc5 100644 --- a/include/ck_tile/ops/fmha/pipeline/block_fmha_pipeline_qr_ks_vs_async.hpp +++ b/include/ck_tile/ops/fmha/pipeline/block_fmha_pipeline_qr_ks_vs_async.hpp @@ -29,6 +29,7 @@ struct BlockFmhaPipelineQRKSVSAsync using PDataType = remove_cvref_t; using OaccDataType = remove_cvref_t; using ODataType = remove_cvref_t; + using AttentionVariant = remove_cvref_t; using FmhaMask = remove_cvref_t; using BlockFmhaShape = remove_cvref_t; @@ -53,13 +54,19 @@ struct BlockFmhaPipelineQRKSVSAsync // only need special care about seq_k padding (oob need set -INF of p instead of zero) static_assert(Problem::kPadSeqLenQ == true && Problem::kPadHeadDimQ == true && Problem::kPadHeadDimV == true); - static constexpr bool kPadSeqLenQ = true; - static constexpr bool kPadSeqLenK = Problem::kPadSeqLenK; - static constexpr bool kPadHeadDimQ = true; // support multiple of vector(like 8x) - static constexpr bool kPadHeadDimV = true; // support multiple of vector(like 8x) - static constexpr auto BiasEnum = Problem::BiasEnum; - static constexpr bool kStoreLSE = Problem::kStoreLSE; - static constexpr bool kHasDropout = Problem::kHasDropout; + static constexpr bool kPadSeqLenQ = true; + static constexpr bool kPadSeqLenK = Problem::kPadSeqLenK; + static constexpr bool kPadHeadDimQ = true; // support multiple of 
vector(like 8x) + static constexpr bool kPadHeadDimV = true; // support multiple of vector(like 8x) + static constexpr bool kHasLogitsSoftCap = Problem::kHasLogitsSoftCap; + static constexpr auto BiasEnum = Problem::BiasEnum; + static constexpr bool kStoreLSE = Problem::kStoreLSE; + static constexpr bool kHasDropout = Problem::kHasDropout; + + static_assert((CK_TILE_FMHA_FWD_FAST_EXP2 && + (kHasLogitsSoftCap && Problem::BiasEnum == BlockAttentionBiasEnum::NO_BIAS || + !kHasLogitsSoftCap)) || + (!CK_TILE_FMHA_FWD_FAST_EXP2 && !kHasLogitsSoftCap)); // last dimension vector length used to create tensor view(and decide buffer_load vector length) // ... together with tensor distribution. tensor dist should able to overwrite this @@ -153,7 +160,9 @@ struct BlockFmhaPipelineQRKSVSAsync typename SAccElementFunction, typename PComputeElementFunction, typename OAccElementFunction, - typename PositionEncoding> + typename PositionEncoding, + typename AttentionVariantParams, + typename BlockIndices> CK_TILE_HOST_DEVICE auto operator()(const QDramBlockWindowTmp& q_dram_block_window_tmp, // M0*K0 tile const QElementFunction& q_element_func, @@ -172,6 +181,9 @@ struct BlockFmhaPipelineQRKSVSAsync FmhaMask mask, PositionEncoding position_encoding, float scale_s, + const AttentionVariant& variant, + const AttentionVariantParams& variant_params, + const BlockIndices& block_indices, void* smem_ptr, DropoutType& dropout) const { @@ -435,9 +447,34 @@ struct BlockFmhaPipelineQRKSVSAsync else { s_acc = tile_elementwise_in(s_acc_element_func, s_acc); + if constexpr(kHasLogitsSoftCap) + { + auto apply_logits_transform = + [&variant, &variant_params, &block_indices](auto& x) { + x = variant.LogitsTransform(variant_params, + variant.QueryTransform(variant_params, x), + block_indices.batch_idx, + block_indices.qo_head_idx, + block_indices.kv_head_idx); + }; #if !CK_TILE_FMHA_FWD_FAST_EXP2 - tile_elementwise_inout([&scale_s](auto& x) { x = x * scale_s; }, s_acc); + for(index_t i = 0; i < 
s_acc.thread_buf_.size(); ++i) + { + apply_logits_transform(s_acc.thread_buf_[i]); + } +#else + for(index_t i = 0; i < s_acc.thread_buf_.size(); ++i) + { + apply_logits_transform(s_acc.thread_buf_[i]); + } #endif + } + else + { +#if !CK_TILE_FMHA_FWD_FAST_EXP2 + tile_elementwise_inout([&scale_s](auto& x) { x = x * scale_s; }, s_acc); +#endif + } } move_tile_window(bias_dram_window, {0, kN0}); if constexpr(kPadSeqLenK || FmhaMask::IsMasking) @@ -454,7 +491,12 @@ struct BlockFmhaPipelineQRKSVSAsync s_acc, -numeric::infinity(), [&](auto tile_idx) { const auto row = q_origin.at(number<0>{}) + tile_idx.at(number<0>{}); const auto col = k_origin.at(number<0>{}) + tile_idx.at(number<1>{}); - return mask.IsOutOfBound(row, col); + return !variant.LogitsMask(variant_params, + block_indices.batch_idx, + row, + col, + block_indices.qo_head_idx, + block_indices.kv_head_idx); }); } } @@ -543,7 +585,14 @@ struct BlockFmhaPipelineQRKSVSAsync } else { - p_compute(i_j_idx) = exp2(scale_s * s[i_j_idx] - row_max); + if constexpr(kHasLogitsSoftCap) + { + p_compute(i_j_idx) = exp2(s[i_j_idx] - get_validated_m(m[i_idx])); + } + else + { + p_compute(i_j_idx) = exp2(scale_s * s[i_j_idx] - row_max); + } } #else p_compute(i_j_idx) = exp(s[i_j_idx] - get_validated_m(m[i_idx])); @@ -568,8 +617,15 @@ struct BlockFmhaPipelineQRKSVSAsync } else { - auto row_max = scale_s * get_validated_m(m[i_idx]); - return exp2(scale_s * m_old[i_idx] - row_max); + if constexpr(kHasLogitsSoftCap) + { + return exp2(m_old[i_idx] - get_validated_m(m[i_idx])); + } + else + { + auto row_max = scale_s * get_validated_m(m[i_idx]); + return exp2(scale_s * m_old[i_idx] - row_max); + } } }(); #else @@ -695,7 +751,14 @@ struct BlockFmhaPipelineQRKSVSAsync } else { - lse(i_idx) = m_[i_idx] * scale_s * R_LOG2E + log(l_[i_idx]); + if constexpr(kHasLogitsSoftCap) + { + lse(i_idx) = m_[i_idx] * R_LOG2E + log(l_[i_idx]); + } + else + { + lse(i_idx) = m_[i_idx] * scale_s * R_LOG2E + log(l_[i_idx]); + } } #else lse(i_idx) = 
m_[i_idx] + log(l_[i_idx]); @@ -735,7 +798,9 @@ struct BlockFmhaPipelineQRKSVSAsync typename BiasDramBlockWindowTmp, typename RandValDramBlockWindowTmp, typename LSEDramBlockWindowTmp, - typename PositionEncoding> + typename PositionEncoding, + typename AttentionVariantParams, + typename BlockIndices> CK_TILE_HOST_DEVICE auto operator()(const QDramBlockWindowTmp& q_dram_block_window_tmp, // M0*K0 tile const KDramBlockWindowTmp& k_dram_block_window_tmp, // N0*K0 tile @@ -746,6 +811,9 @@ struct BlockFmhaPipelineQRKSVSAsync FmhaMask mask, PositionEncoding position_encoding, float scale_s, + const AttentionVariant& variant, + const AttentionVariantParams& variant_params, + const BlockIndices& block_indices, void* smem_ptr, DropoutType& dropout) const { @@ -766,6 +834,9 @@ struct BlockFmhaPipelineQRKSVSAsync mask, position_encoding, scale_s, + variant, + variant_params, + block_indices, smem_ptr, dropout); } diff --git a/include/ck_tile/ops/fmha/pipeline/block_fmha_pipeline_qs_ks_vs.hpp b/include/ck_tile/ops/fmha/pipeline/block_fmha_pipeline_qs_ks_vs.hpp index 7be6a347f5..4efcd871dc 100644 --- a/include/ck_tile/ops/fmha/pipeline/block_fmha_pipeline_qs_ks_vs.hpp +++ b/include/ck_tile/ops/fmha/pipeline/block_fmha_pipeline_qs_ks_vs.hpp @@ -7,6 +7,7 @@ #include "ck_tile/ops/fmha/block/block_attention_bias_enum.hpp" #include "ck_tile/ops/fmha/block/block_dropout.hpp" #include "ck_tile/ops/fmha/pipeline/block_fmha_pipeline_qs_ks_vs_default_policy.hpp" +#include "ck_tile/ops/reduce/block/block_reduce.hpp" namespace ck_tile { @@ -27,6 +28,7 @@ struct BlockFmhaPipelineQSKSVS using PDataType = remove_cvref_t; using OaccDataType = remove_cvref_t; using ODataType = remove_cvref_t; + using AttentionVariant = remove_cvref_t; using FmhaMask = remove_cvref_t; using BlockFmhaShape = remove_cvref_t; @@ -44,14 +46,21 @@ struct BlockFmhaPipelineQSKSVS static constexpr index_t kQKHeaddim = BlockFmhaShape::kQKHeaddim; static constexpr index_t kSubQKHeaddim = BlockFmhaShape::kSubQKHeaddim; - 
static constexpr bool kIsGroupMode = Problem::kIsGroupMode; - static constexpr bool kPadSeqLenQ = Problem::kPadSeqLenQ; - static constexpr bool kPadSeqLenK = Problem::kPadSeqLenK; - static constexpr bool kPadHeadDimQ = Problem::kPadHeadDimQ; - static constexpr bool kPadHeadDimV = Problem::kPadHeadDimV; - static constexpr auto BiasEnum = Problem::BiasEnum; - static constexpr bool kStoreLSE = Problem::kStoreLSE; - static constexpr bool kHasDropout = Problem::kHasDropout; + static constexpr bool kIsGroupMode = Problem::kIsGroupMode; + static constexpr bool kPadSeqLenQ = Problem::kPadSeqLenQ; + static constexpr bool kPadSeqLenK = Problem::kPadSeqLenK; + static constexpr bool kPadHeadDimQ = Problem::kPadHeadDimQ; + static constexpr bool kPadHeadDimV = Problem::kPadHeadDimV; + static constexpr bool kHasLogitsSoftCap = Problem::kHasLogitsSoftCap; + static constexpr auto BiasEnum = Problem::BiasEnum; + static constexpr bool kStoreLSE = Problem::kStoreLSE; + static constexpr bool kHasDropout = Problem::kHasDropout; + + static_assert((CK_TILE_FMHA_FWD_FAST_EXP2 && + (kHasLogitsSoftCap && Problem::BiasEnum == BlockAttentionBiasEnum::NO_BIAS || + !kHasLogitsSoftCap)) || + (!CK_TILE_FMHA_FWD_FAST_EXP2 && !kHasLogitsSoftCap)); + // last dimension vector length used to create tensor view(and decide buffer_load vector length) // ... together with tensor distribution. 
tensor dist should able to overwrite this static constexpr index_t kAlignmentQ = @@ -95,7 +104,9 @@ struct BlockFmhaPipelineQSKSVS return 1; } else + { return 1; + } } }(); @@ -122,7 +133,9 @@ struct BlockFmhaPipelineQSKSVS typename SAccElementFunction, typename PComputeElementFunction, typename OAccElementFunction, - typename PositionEncoding> + typename PositionEncoding, + typename AttentionVariantParams, + typename BlockIndices> CK_TILE_HOST_DEVICE auto operator()(const QDramBlockWindowTmp& q_dram_block_window_tmp, // M0*K0 tile const QElementFunction& q_element_func, @@ -141,6 +154,9 @@ struct BlockFmhaPipelineQSKSVS FmhaMask mask, PositionEncoding position_encoding, float scale_s, + const AttentionVariant& variant, + const AttentionVariantParams& variant_params, + const BlockIndices& block_indices, void* smem_ptr, DropoutType& /* unused_dropout */) const { @@ -380,9 +396,28 @@ struct BlockFmhaPipelineQSKSVS else { s_acc = tile_elementwise_in(s_acc_element_func, s_acc); + if constexpr(kHasLogitsSoftCap) + { + auto apply_logits_transform = + [&variant, &variant_params, &block_indices](auto& x) { + x = variant.LogitsTransform(variant_params, + variant.QueryTransform(variant_params, x), + block_indices.batch_idx, + block_indices.qo_head_idx, + block_indices.kv_head_idx); + }; #if !CK_TILE_FMHA_FWD_FAST_EXP2 - tile_elementwise_inout([&scale_s](auto& x) { x = x * scale_s; }, s_acc); + tile_elementwise_inout(apply_logits_transform, s_acc); +#else + tile_elementwise_inout(apply_logits_transform, s_acc); #endif + } + else + { +#if !CK_TILE_FMHA_FWD_FAST_EXP2 + tile_elementwise_inout([&scale_s](auto& x) { x = x * scale_s; }, s_acc); +#endif + } } move_tile_window(bias_dram_window, {0, kN0}); if constexpr(kPadSeqLenK || FmhaMask::IsMasking) @@ -398,7 +433,12 @@ struct BlockFmhaPipelineQSKSVS s_acc, -numeric::infinity(), [&](auto tile_idx) { const auto row = q_origin.at(number<0>{}) + tile_idx.at(number<0>{}); const auto col = k_origin.at(number<0>{}) + 
tile_idx.at(number<1>{}); - return mask.IsOutOfBound(row, col); + return !variant.LogitsMask(variant_params, + block_indices.batch_idx, + row, + col, + block_indices.qo_head_idx, + block_indices.kv_head_idx); }); } } @@ -450,7 +490,14 @@ struct BlockFmhaPipelineQSKSVS } else { - p_compute(i_j_idx) = exp2(scale_s * s[i_j_idx] - row_max); + if constexpr(kHasLogitsSoftCap) + { + p_compute(i_j_idx) = exp2(s[i_j_idx] - get_validated_m(m[i_idx])); + } + else + { + p_compute(i_j_idx) = exp2(scale_s * s[i_j_idx] - row_max); + } } #else p_compute(i_j_idx) = exp(s[i_j_idx] - get_validated_m(m[i_idx])); @@ -481,8 +528,16 @@ struct BlockFmhaPipelineQSKSVS } else { - auto row_max = scale_s * get_validated_m(m[i_idx]); - return exp2(scale_s * m_old[i_idx] - row_max); + if constexpr(kHasLogitsSoftCap) + { + + return exp2(m_old[i_idx] - get_validated_m(m[i_idx])); + } + else + { + auto row_max = scale_s * get_validated_m(m[i_idx]); + return exp2(scale_s * m_old[i_idx] - row_max); + } } }(); #else @@ -571,7 +626,14 @@ struct BlockFmhaPipelineQSKSVS } else { - lse(i_idx) = m_[i_idx] * scale_s / C_LOG2E + log(l_[i_idx]); + if constexpr(kHasLogitsSoftCap) + { + lse(i_idx) = m_[i_idx] / C_LOG2E + log(l_[i_idx]); + } + else + { + lse(i_idx) = m_[i_idx] * scale_s / C_LOG2E + log(l_[i_idx]); + } } #else lse(i_idx) = m_[i_idx] + log(l_[i_idx]); @@ -611,7 +673,9 @@ struct BlockFmhaPipelineQSKSVS typename BiasDramBlockWindowTmp, typename RandValDramBlockWindowTmp, typename LSEDramBlockWindowTmp, - typename PositionEncoding> + typename PositionEncoding, + typename AttentionVariantParams, + typename BlockIndices> CK_TILE_HOST_DEVICE auto operator()(const QDramBlockWindowTmp& q_dram_block_window_tmp, // M0*K0 tile const KDramBlockWindowTmp& k_dram_block_window_tmp, // N0*K0 tile @@ -622,6 +686,9 @@ struct BlockFmhaPipelineQSKSVS FmhaMask mask, PositionEncoding position_encoding, float scale_s, + const AttentionVariant& variant, + const AttentionVariantParams& variant_params, + const 
BlockIndices& block_indices, void* smem_ptr, DropoutType& dropout) const { @@ -642,6 +709,9 @@ struct BlockFmhaPipelineQSKSVS mask, position_encoding, scale_s, + variant, + variant_params, + block_indices, smem_ptr, dropout); } diff --git a/include/ck_tile/ops/fmha/pipeline/tile_fmha_traits.hpp b/include/ck_tile/ops/fmha/pipeline/tile_fmha_traits.hpp index 8d2d848558..4530b58d85 100644 --- a/include/ck_tile/ops/fmha/pipeline/tile_fmha_traits.hpp +++ b/include/ck_tile/ops/fmha/pipeline/tile_fmha_traits.hpp @@ -13,6 +13,7 @@ template 1 or fwd training is running */ @@ -51,6 +54,7 @@ struct TileFmhaFwdSplitKVTraits static constexpr bool kPadSeqLenK = kPadSeqLenK_; static constexpr bool kPadHeadDimQ = kPadHeadDimQ_; static constexpr bool kPadHeadDimV = kPadHeadDimV_; + static constexpr bool kHasLogitsSoftCap = kHasLogitsSoftCap_; static constexpr auto BiasEnum = BiasEnum_; static constexpr bool kHasBiasGrad = kHasBiasGrad_; static constexpr bool kStoreLSE = kStoreLSE_; From c53b7bd22e75c69beddb6ffefc22b5f95354ffba Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Bart=C5=82omiej=20Kocot?= Date: Tue, 13 May 2025 10:14:30 +0200 Subject: [PATCH 117/443] Switch to v2 pipeline for grouped conv bwd data (#2181) * Change to old pipeline for grouped conv bwd data * fix * fix * fix * fix * fix * fix * Fix --- ...ped_conv_fwd_multiple_abd_xdl_cshuffle.hpp | 29 +- ...ed_contraction_multiple_d_xdl_cshuffle.hpp | 30 +- .../device_batched_gemm_e_permute_xdl.hpp | 28 +- .../impl/device_batched_gemm_multi_d_xdl.hpp | 30 +- ...ce_contraction_multiple_d_xdl_cshuffle.hpp | 30 +- .../device_gemm_multiple_d_xdl_cshuffle.hpp | 30 +- ...ed_contraction_multiple_d_xdl_cshuffle.hpp | 5 +- ...nv_bwd_data_multiple_d_xdl_cshuffle_v1.hpp | 1052 ++--------------- ...ped_conv_fwd_multiple_abd_xdl_cshuffle.hpp | 28 +- ...d_multiple_d_xdl_large_tensor_cshuffle.hpp | 21 +- .../device/impl/device_grouped_gemm_xdl.hpp | 3 +- .../gridwise_gemm_multiple_d_xdl_cshuffle.hpp | 81 +- 12 files changed, 256 insertions(+), 
1111 deletions(-) diff --git a/include/ck/tensor_operation/gpu/device/impl/codegen_device_grouped_conv_fwd_multiple_abd_xdl_cshuffle.hpp b/include/ck/tensor_operation/gpu/device/impl/codegen_device_grouped_conv_fwd_multiple_abd_xdl_cshuffle.hpp index 00518b369f..72c011bfb2 100644 --- a/include/ck/tensor_operation/gpu/device/impl/codegen_device_grouped_conv_fwd_multiple_abd_xdl_cshuffle.hpp +++ b/include/ck/tensor_operation/gpu/device/impl/codegen_device_grouped_conv_fwd_multiple_abd_xdl_cshuffle.hpp @@ -1,5 +1,5 @@ // SPDX-License-Identifier: MIT -// Copyright (c) 2024, Advanced Micro Devices, Inc. All rights reserved. +// Copyright (c) 2024-2025, Advanced Micro Devices, Inc. All rights reserved. #pragma once @@ -153,7 +153,7 @@ __device__ void device_grouped_conv_fwd_multiple_abd_xdl_cshuffle( const long_index_t b_batch_offset = __builtin_amdgcn_readfirstlane( static_cast(compute_ptr_offset_of_batch.GetBPtrOffset(g_idx))); - GridwiseGemm::template Run( + GridwiseGemm::template Run( p_as_grid + a_batch_offset, p_bs_grid + b_batch_offset, p_ds_grid_grp, @@ -439,7 +439,7 @@ struct CodegenDeviceGroupedConvFwdMultipleABD_Xdl_CShuffle using GemmADataType = ck::conditional_t, ADataType>; using GemmBDataType = ck::conditional_t, BDataType>; -#define GridwiseGemmTemplateParameters \ +#define GridwiseGemmMultiABDTemplateParameters \ GemmADataType, GemmBDataType, ComputeDataType, AccDataType, CShuffleDataType, DsDataType, \ EDataType, AElementwiseOperation, BElementwiseOperation, CDEElementwiseOperation, \ InMemoryDataOperationEnum::Set, NumGemmKPrefetchStage, BlockSize, MPerBlock, NPerBlock, \ @@ -454,11 +454,26 @@ struct CodegenDeviceGroupedConvFwdMultipleABD_Xdl_CShuffle CShuffleMXdlPerWavePerShuffle, CShuffleNXdlPerWavePerShuffle, \ CDEBlockTransferClusterLengths_MBlock_MPerBlock_NBlock_NPerBlock, \ CDEBlockTransferScalarPerVector_NPerBlock, LoopSched + +#define GridwiseGemmTemplateParameters \ + GemmADataType, GemmBDataType, ComputeDataType, AccDataType, 
CShuffleDataType, DsDataType, \ + EDataType, AElementwiseOperation, BElementwiseOperation, CDEElementwiseOperation, \ + NumGemmKPrefetchStage, BlockSize, MPerBlock, NPerBlock, KPerBlock, AK1, BK1, MPerXDL, \ + NPerXDL, MXdlPerWave, NXdlPerWave, ABlockTransferThreadClusterLengths_AK0_M_AK1, \ + ABlockTransferThreadClusterArrangeOrder, ABlockTransferSrcAccessOrder, \ + ABlockTransferSrcVectorDim, ABlockTransferSrcScalarPerVector, \ + ABlockTransferDstScalarPerVector_AK1, false, ABlockLdsExtraM, \ + BBlockTransferThreadClusterLengths_BK0_N_BK1, BBlockTransferThreadClusterArrangeOrder, \ + BBlockTransferSrcAccessOrder, BBlockTransferSrcVectorDim, \ + BBlockTransferSrcScalarPerVector, BBlockTransferDstScalarPerVector_BK1, false, \ + BBlockLdsExtraN, CShuffleMXdlPerWavePerShuffle, CShuffleNXdlPerWavePerShuffle, \ + CDEBlockTransferClusterLengths_MBlock_MPerBlock_NBlock_NPerBlock, \ + CDEBlockTransferScalarPerVector_NPerBlock, LoopSched // Use appropriate gridwise gemm - using GridwiseGemm = - ck::conditional_t, - GridwiseGemmMultipleD_xdl_cshuffle>; + using GridwiseGemm = ck::conditional_t< + isMultiA || isMultiB, + GridwiseGemmMultipleABD_xdl_cshuffle, + GridwiseGemmMultipleD_xdl_cshuffle>; // If ADataTypes or BDataTypes is tuple, user has to pass ck::Array with pointers. using APointers = ck::conditional_t&, const void*>; diff --git a/include/ck/tensor_operation/gpu/device/impl/device_batched_contraction_multiple_d_xdl_cshuffle.hpp b/include/ck/tensor_operation/gpu/device/impl/device_batched_contraction_multiple_d_xdl_cshuffle.hpp index d53fbca4ea..fc1a2b995a 100644 --- a/include/ck/tensor_operation/gpu/device/impl/device_batched_contraction_multiple_d_xdl_cshuffle.hpp +++ b/include/ck/tensor_operation/gpu/device/impl/device_batched_contraction_multiple_d_xdl_cshuffle.hpp @@ -1,5 +1,5 @@ // SPDX-License-Identifier: MIT -// Copyright (c) 2018-2023, Advanced Micro Devices, Inc. All rights reserved. +// Copyright (c) 2018-2025, Advanced Micro Devices, Inc. 
All rights reserved. #pragma once @@ -80,19 +80,20 @@ __global__ void static_for<0, NumDTensor, 1>{}( [&](auto i) { p_ds_grid_grp(i) = p_ds_grid[i] + ds_batch_offset[i]; }); - GridwiseGemm::template Run(p_a_grid + a_batch_offset, - p_b_grid + b_batch_offset, - p_ds_grid_grp, - p_e_grid + e_batch_offset, - p_shared, - a_element_op, - b_element_op, - cde_element_op, - a_grid_desc_ak0_m_ak1, - b_grid_desc_bk0_n_bk1, - ds_grid_desc_mblock_mperblock_nblock_nperblock, - e_grid_desc_mblock_mperblock_nblock_nperblock, - block_2_etile_map); + GridwiseGemm::template Run( + p_a_grid + a_batch_offset, + p_b_grid + b_batch_offset, + p_ds_grid_grp, + p_e_grid + e_batch_offset, + p_shared, + a_element_op, + b_element_op, + cde_element_op, + a_grid_desc_ak0_m_ak1, + b_grid_desc_bk0_n_bk1, + ds_grid_desc_mblock_mperblock_nblock_nperblock, + e_grid_desc_mblock_mperblock_nblock_nperblock, + block_2_etile_map); #else ignore = p_a_grid; ignore = p_b_grid; @@ -556,7 +557,6 @@ struct DeviceBatchedContractionMultipleD_Xdl_CShuffle AElementwiseOperation, BElementwiseOperation, CDEElementwiseOperation, - InMemoryDataOperationEnum::Set, NumGemmKPrefetchStage, BlockSize, MPerBlock, diff --git a/include/ck/tensor_operation/gpu/device/impl/device_batched_gemm_e_permute_xdl.hpp b/include/ck/tensor_operation/gpu/device/impl/device_batched_gemm_e_permute_xdl.hpp index 25a9d7f96d..0cd1d84a43 100644 --- a/include/ck/tensor_operation/gpu/device/impl/device_batched_gemm_e_permute_xdl.hpp +++ b/include/ck/tensor_operation/gpu/device/impl/device_batched_gemm_e_permute_xdl.hpp @@ -88,19 +88,20 @@ __global__ void __shared__ char p_shared[GridwiseGemm::GetSharedMemoryNumberOfByte()]; - GridwiseGemm::template Run(p_a_grid + a_batch_offset, - p_b_grid + b_batch_offset, - ck::Tuple<>{}, - p_e_grid + e_batch_offset, - p_shared, - a_element_op, - b_element_op, - cde_element_op, - a_grid_desc_ak0_m_ak1, - b_grid_desc_bk0_n_bk1, - ck::Tuple<>{}, - e_grid_desc_mblock_mperblock_nblock_nperblock, - 
block_2_etile_map); + GridwiseGemm::template Run( + p_a_grid + a_batch_offset, + p_b_grid + b_batch_offset, + ck::Tuple<>{}, + p_e_grid + e_batch_offset, + p_shared, + a_element_op, + b_element_op, + cde_element_op, + a_grid_desc_ak0_m_ak1, + b_grid_desc_bk0_n_bk1, + ck::Tuple<>{}, + e_grid_desc_mblock_mperblock_nblock_nperblock, + block_2_etile_map); #else ignore = p_a_grid; ignore = p_b_grid; @@ -344,7 +345,6 @@ struct DeviceBatchedGemmEPermuteXdl : public DeviceBatchedGemmEPermute, diff --git a/include/ck/tensor_operation/gpu/device/impl/device_batched_gemm_multi_d_xdl.hpp b/include/ck/tensor_operation/gpu/device/impl/device_batched_gemm_multi_d_xdl.hpp index 630f143260..12085edaae 100644 --- a/include/ck/tensor_operation/gpu/device/impl/device_batched_gemm_multi_d_xdl.hpp +++ b/include/ck/tensor_operation/gpu/device/impl/device_batched_gemm_multi_d_xdl.hpp @@ -1,5 +1,5 @@ // SPDX-License-Identifier: MIT -// Copyright (c) 2018-2023, Advanced Micro Devices, Inc. All rights reserved. +// Copyright (c) 2018-2025, Advanced Micro Devices, Inc. All rights reserved. 
#pragma once @@ -107,19 +107,20 @@ __global__ void static_for<0, NumDTensor, 1>{}( [&](auto i) { p_ds_grid_grp(i) = p_ds_grid[i] + ds_batch_offset[i]; }); - GridwiseGemm::template Run(p_a_grid + a_batch_offset, - p_b_grid + b_batch_offset, - p_ds_grid_grp, - p_e_grid + e_batch_offset, - p_shared, - a_element_op, - b_element_op, - cde_element_op, - a_grid_desc_k0_m_k1, - b_grid_desc_k0_n_k1, - ds_grid_desc_mblock_mperblock_nblock_nperblock, - e_grid_desc_mblock_mperblock_nblock_nperblock_, - block_2_etile_map); + GridwiseGemm::template Run( + p_a_grid + a_batch_offset, + p_b_grid + b_batch_offset, + p_ds_grid_grp, + p_e_grid + e_batch_offset, + p_shared, + a_element_op, + b_element_op, + cde_element_op, + a_grid_desc_k0_m_k1, + b_grid_desc_k0_n_k1, + ds_grid_desc_mblock_mperblock_nblock_nperblock, + e_grid_desc_mblock_mperblock_nblock_nperblock_, + block_2_etile_map); #else ignore = p_a_grid; ignore = p_b_grid; @@ -336,7 +337,6 @@ struct DeviceBatchedGemmMultiD_Xdl : public DeviceBatchedGemmMultiD(p_a_grid, - p_b_grid, - p_ds_grid, - p_e_grid, - p_shared, - a_element_op, - b_element_op, - cde_element_op, - a_grid_desc_ak0_m_ak1, - b_grid_desc_bk0_n_bk1, - ds_grid_desc_mblock_mperblock_nblock_nperblock, - e_grid_desc_mblock_mperblock_nblock_nperblock, - block_2_etile_map); + GridwiseGemm::template Run( + p_a_grid, + p_b_grid, + p_ds_grid, + p_e_grid, + p_shared, + a_element_op, + b_element_op, + cde_element_op, + a_grid_desc_ak0_m_ak1, + b_grid_desc_bk0_n_bk1, + ds_grid_desc_mblock_mperblock_nblock_nperblock, + e_grid_desc_mblock_mperblock_nblock_nperblock, + block_2_etile_map); #else ignore = p_a_grid; ignore = p_b_grid; @@ -324,7 +325,6 @@ struct DeviceContractionMultipleD_Xdl_CShuffle AElementwiseOperation, BElementwiseOperation, CDEElementwiseOperation, - InMemoryDataOperationEnum::Set, NumGemmKPrefetchStage, BlockSize, MPerBlock, diff --git a/include/ck/tensor_operation/gpu/device/impl/device_gemm_multiple_d_xdl_cshuffle.hpp 
b/include/ck/tensor_operation/gpu/device/impl/device_gemm_multiple_d_xdl_cshuffle.hpp index 3fae3a3765..6c4195e75d 100644 --- a/include/ck/tensor_operation/gpu/device/impl/device_gemm_multiple_d_xdl_cshuffle.hpp +++ b/include/ck/tensor_operation/gpu/device/impl/device_gemm_multiple_d_xdl_cshuffle.hpp @@ -1,5 +1,5 @@ // SPDX-License-Identifier: MIT -// Copyright (c) 2018-2023, Advanced Micro Devices, Inc. All rights reserved. +// Copyright (c) 2018-2025, Advanced Micro Devices, Inc. All rights reserved. #pragma once @@ -57,19 +57,20 @@ __global__ void #if(!defined(__HIP_DEVICE_COMPILE__) || defined(__gfx9__)) __shared__ char p_shared[GridwiseGemm::GetSharedMemoryNumberOfByte()]; - GridwiseGemm::template Run(p_a_grid, - p_b_grid, - p_ds_grid, - p_e_grid, - p_shared, - a_element_op, - b_element_op, - cde_element_op, - a_grid_desc_ak0_m_ak1, - b_grid_desc_bk0_n_bk1, - ds_grid_desc_mblock_mperblock_nblock_nperblock, - e_grid_desc_mblock_mperblock_nblock_nperblock, - block_2_etile_map); + GridwiseGemm::template Run( + p_a_grid, + p_b_grid, + p_ds_grid, + p_e_grid, + p_shared, + a_element_op, + b_element_op, + cde_element_op, + a_grid_desc_ak0_m_ak1, + b_grid_desc_bk0_n_bk1, + ds_grid_desc_mblock_mperblock_nblock_nperblock, + e_grid_desc_mblock_mperblock_nblock_nperblock, + block_2_etile_map); #else ignore = p_a_grid; ignore = p_b_grid; @@ -257,7 +258,6 @@ struct DeviceGemmMultipleD_Xdl_CShuffle : public DeviceGemmMultipleD( + GridwiseGemm::template Run( contraction_arg_ptr[group_id].p_a_grid_, contraction_arg_ptr[group_id].p_b_grid_, contraction_arg_ptr[group_id].p_ds_grid_, @@ -368,7 +368,6 @@ struct DeviceGroupedContractionMultipleD_Xdl_CShuffle AElementwiseOperation, BElementwiseOperation, CDEElementwiseOperation, - InMemoryDataOperationEnum::Set, NumGemmKPrefetchStage, BlockSize, MPerBlock, diff --git a/include/ck/tensor_operation/gpu/device/impl/device_grouped_conv_bwd_data_multiple_d_xdl_cshuffle_v1.hpp 
b/include/ck/tensor_operation/gpu/device/impl/device_grouped_conv_bwd_data_multiple_d_xdl_cshuffle_v1.hpp index 41f596d160..f18ce40fc5 100644 --- a/include/ck/tensor_operation/gpu/device/impl/device_grouped_conv_bwd_data_multiple_d_xdl_cshuffle_v1.hpp +++ b/include/ck/tensor_operation/gpu/device/impl/device_grouped_conv_bwd_data_multiple_d_xdl_cshuffle_v1.hpp @@ -15,7 +15,6 @@ #include "ck/tensor_operation/gpu/device/device_grouped_conv_bwd_data_multiple_d.hpp" #include "ck/tensor_operation/gpu/device/convolution_backward_data_specialization.hpp" #include "ck/tensor_operation/operator_transform/transform_conv_bwd_data_to_gemm_v1.hpp" -#include "ck/tensor_operation/gpu/grid/gridwise_gemm_xdl_cshuffle_conv_v3.hpp" #include "ck/tensor_operation/gpu/device/gemm_specialization.hpp" #include "ck/tensor_operation/operator_transform/transform_conv_ngchw_to_nhwgc.hpp" #include "ck/tensor_operation/gpu/grid/gridwise_gemm_multiple_d_xdl_cshuffle.hpp" @@ -71,7 +70,8 @@ template + bool HasMainKBlockLoop, + InMemoryDataOperationEnum OutElementOp> __global__ void #if CK_USE_LAUNCH_BOUNDS __launch_bounds__(CK_MAX_THREAD_PER_BLOCK, CK_MIN_BLOCK_PER_CU) @@ -92,12 +92,14 @@ __global__ void e_grid_desc_mblock_mperblock_nblock_nperblock_, const Block2ETileMap block_2_ctile_map, const ComputePtrOffsetOfBatch compute_ptr_offset_of_batch, - const ComputePtrOffsetOfN compute_ptr_offset_of_n) + const ComputePtrOffsetOfN compute_ptr_offset_of_n, + const index_t KBatch) { #if(!defined(__HIP_DEVICE_COMPILE__) || defined(__gfx9__)) // offset base pointer for each work-group - const index_t n_idx = __builtin_amdgcn_readfirstlane(blockIdx.z); const index_t g_idx = __builtin_amdgcn_readfirstlane(blockIdx.y); + const index_t n_idx = __builtin_amdgcn_readfirstlane(blockIdx.z / KBatch); + const index_t k_idx = __builtin_amdgcn_readfirstlane(blockIdx.z - n_idx * KBatch); const long_index_t a_batch_offset = amd_wave_read_first_lane(compute_ptr_offset_of_batch.GetAPtrOffset(g_idx)); @@ -123,19 +125,22 
@@ __global__ void static_for<0, NumDTensor, 1>{}( [&](auto i) { p_ds_grid_grp(i) = p_ds_grid[i] + ds_batch_offset[i]; }); - GridwiseGemm::template Run(p_a_grid + a_batch_offset + a_n_offset, - p_b_grid + b_batch_offset, - p_ds_grid_grp, - p_e_grid + e_batch_offset + e_n_offset, - p_shared, - a_element_op, - b_element_op, - cde_element_op, - a_grid_desc_ak0_m_ak1, - b_grid_desc_bk0_n_bk1, - ds_grid_desc_mblock_mperblock_nblock_nperblock, - e_grid_desc_mblock_mperblock_nblock_nperblock_, - block_2_ctile_map); + GridwiseGemm::template Run( + p_a_grid + a_batch_offset + a_n_offset, + p_b_grid + b_batch_offset, + p_ds_grid_grp, + p_e_grid + e_batch_offset + e_n_offset, + p_shared, + a_element_op, + b_element_op, + cde_element_op, + a_grid_desc_ak0_m_ak1, + b_grid_desc_bk0_n_bk1, + ds_grid_desc_mblock_mperblock_nblock_nperblock, + e_grid_desc_mblock_mperblock_nblock_nperblock_, + block_2_ctile_map, + KBatch, + k_idx); #else ignore = p_a_grid; ignore = p_b_grid; @@ -154,151 +159,6 @@ __global__ void #endif } -template -__global__ void -#if CK_USE_LAUNCH_BOUNDS - __launch_bounds__(CK_MAX_THREAD_PER_BLOCK, MinimumOccupancy) -#endif - // __attribute__((amdgpu_waves_per_eu(1, 1))) - kernel_grouped_conv_bwd_data_xdl_cshuffle_v3( - typename GridwiseGemm::Argument karg, - const AGridDesc_AK0_M_K1 a_grid_desc_ak0_m_ak1, - const BGridDesc_BK0_N_K1 b_grid_desc_bk0_n_bk1, - const CGridDesc_MBlock_MPerBlock_NBlock_NPerBlock - c_grid_desc_mblock_mperblock_nblock_nperblock, - const ComputePtrOffsetOfBatch compute_ptr_offset_of_batch, - const ComputePtrOffsetOfN compute_ptr_offset_of_n, - const index_t num_k_per_block) -{ -#if(!defined(__HIP_DEVICE_COMPILE__) || defined(__gfx9__)) - // offset base pointer for each work-group - const index_t g_idx = __builtin_amdgcn_readfirstlane(blockIdx.z); - const index_t n_idx = __builtin_amdgcn_readfirstlane(blockIdx.y / karg.KBatch); - const index_t k_idx = - __builtin_amdgcn_readfirstlane((blockIdx.y - n_idx * karg.KBatch) * num_k_per_block); - - 
const long_index_t a_batch_offset = __builtin_amdgcn_readfirstlane( - static_cast(compute_ptr_offset_of_batch.GetAPtrOffset(g_idx))); - const long_index_t b_batch_offset = __builtin_amdgcn_readfirstlane( - static_cast(compute_ptr_offset_of_batch.GetBPtrOffset(g_idx))); - const long_index_t e_batch_offset = __builtin_amdgcn_readfirstlane( - static_cast(compute_ptr_offset_of_batch.GetEPtrOffset(g_idx))); - - const long_index_t a_n_offset = - amd_wave_read_first_lane(compute_ptr_offset_of_n.GetAPtrOffset(n_idx)); - const long_index_t e_n_offset = - amd_wave_read_first_lane(compute_ptr_offset_of_n.GetEPtrOffset(n_idx)); - - __shared__ char p_shared[GridwiseGemm::GetSharedMemoryNumberOfByte()]; - - GridwiseGemm::template Run(karg.p_a_grid + a_batch_offset + a_n_offset, - karg.p_b_grid + b_batch_offset, - karg.p_c_grid + e_batch_offset + e_n_offset, - p_shared, - karg, - a_grid_desc_ak0_m_ak1, - b_grid_desc_bk0_n_bk1, - c_grid_desc_mblock_mperblock_nblock_nperblock, - k_idx); -#else - ignore = karg; - ignore = a_grid_desc_ak0_m_ak1; - ignore = b_grid_desc_bk0_n_bk1; - ignore = c_grid_desc_mblock_mperblock_nblock_nperblock; - ignore = compute_ptr_offset_of_batch; - ignore = compute_ptr_offset_of_n; - ignore = num_k_per_block; -#endif // end of if (defined(__gfx908__) || defined(__gfx90a__)) -} - -template -__global__ void -#if CK_USE_LAUNCH_BOUNDS - __launch_bounds__(CK_MAX_THREAD_PER_BLOCK, MinimumOccupancy) -#endif - // __attribute__((amdgpu_waves_per_eu(1, 1))) - kernel_grouped_conv_bwd_data_xdl_cshuffle_v3_2lds( - typename GridwiseGemm::Argument karg, - const AGridDesc_AK0_M_K1 a_grid_desc_ak0_m_ak1, - const BGridDesc_BK0_N_K1 b_grid_desc_bk0_n_bk1, - const CGridDesc_MBlock_MPerBlock_NBlock_NPerBlock - c_grid_desc_mblock_mperblock_nblock_nperblock, - const ComputePtrOffsetOfBatch compute_ptr_offset_of_batch, - const ComputePtrOffsetOfN compute_ptr_offset_of_n, - const index_t num_k_per_block) -{ -#if(!defined(__HIP_DEVICE_COMPILE__) || defined(__gfx9__)) - const 
index_t g_idx = __builtin_amdgcn_readfirstlane(blockIdx.z); - const index_t n_idx = __builtin_amdgcn_readfirstlane(blockIdx.y / karg.KBatch); - const index_t k_idx = - __builtin_amdgcn_readfirstlane((blockIdx.y - n_idx * karg.KBatch) * num_k_per_block); - - const long_index_t a_batch_offset = __builtin_amdgcn_readfirstlane( - static_cast(compute_ptr_offset_of_batch.GetAPtrOffset(g_idx))); - const long_index_t b_batch_offset = __builtin_amdgcn_readfirstlane( - static_cast(compute_ptr_offset_of_batch.GetBPtrOffset(g_idx))); - const long_index_t e_batch_offset = __builtin_amdgcn_readfirstlane( - static_cast(compute_ptr_offset_of_batch.GetEPtrOffset(g_idx))); - - const long_index_t a_n_offset = - amd_wave_read_first_lane(compute_ptr_offset_of_n.GetAPtrOffset(n_idx)); - const long_index_t e_n_offset = - amd_wave_read_first_lane(compute_ptr_offset_of_n.GetEPtrOffset(n_idx)); - - // Pass two lds pointer is the key to tell compiler that ds_read/write - // operate on different lds chunk at same time without order dependecy - __shared__ char p_shared_0[GridwiseGemm::GetSharedMemoryNumberOfByte()]; - __shared__ char p_shared_1[GridwiseGemm::GetSharedMemoryNumberOfByte()]; - - GridwiseGemm::template Run_2Lds(karg.p_a_grid + a_batch_offset + a_n_offset, - karg.p_b_grid + b_batch_offset, - karg.p_c_grid + e_batch_offset + e_n_offset, - p_shared_0, - p_shared_1, - karg, - a_grid_desc_ak0_m_ak1, - b_grid_desc_bk0_n_bk1, - c_grid_desc_mblock_mperblock_nblock_nperblock, - k_idx); -#else - ignore = karg; - ignore = a_grid_desc_ak0_m_ak1; - ignore = b_grid_desc_bk0_n_bk1; - ignore = c_grid_desc_mblock_mperblock_nblock_nperblock; - ignore = compute_ptr_offset_of_batch; - ignore = compute_ptr_offset_of_n; - ignore = num_k_per_block; -#endif // end of if (defined(__gfx908__) || defined(__gfx90a__)) -} - } // namespace // Conv backward data multiple D: @@ -358,9 +218,7 @@ template + index_t MaxTransposeTransferOutScalarPerVector = 1> struct 
DeviceGroupedConvBwdDataMultipleD_Xdl_CShuffle_v1 : public DeviceGroupedConvBwdDataMultipleD 0; static constexpr GemmSpecialization GemmSpec = GemmSpecialization::MNKPadding; static constexpr bool IsSplitKSupported = (CDEBlockTransferScalarPerVector_NPerBlock % 2 == 0 || sizeof(EDataType) % 4 == 0) && @@ -473,59 +330,25 @@ struct DeviceGroupedConvBwdDataMultipleD_Xdl_CShuffle_v1 // GridwiseGemm #define GridwiseGemmMultiDTemplateParams \ ABDataType, ABDataType, AComputeType, AccDataType, CShuffleDataType, DsDataType, EDataType, \ - AElementwiseOp, BElementwiseOp, CDEElementwiseOp, InMemoryDataOperationEnum::Set, \ - NumGemmKPrefetchStage, BlockSize, MPerBlock, NPerBlock, KPerBlock, AK1, BK1, MPerXDL, \ - NPerXDL, MXdlPerWave, NXdlPerWave, ABlockTransferThreadClusterLengths_AK0_M_AK1, \ - ABlockTransferThreadClusterArrangeOrder, ABlockTransferSrcAccessOrder, \ - ABlockTransferSrcVectorDim, ABlockTransferSrcScalarPerVector, \ - ABlockTransferDstScalarPerVector_AK1, false, ABlockLdsExtraM, \ - BBlockTransferThreadClusterLengths_BK0_N_BK1, BBlockTransferThreadClusterArrangeOrder, \ - BBlockTransferSrcAccessOrder, BBlockTransferSrcVectorDim, \ - BBlockTransferSrcScalarPerVector, BBlockTransferDstScalarPerVector_BK1, false, \ - BBlockLdsExtraN, CShuffleMXdlPerWavePerShuffle, CShuffleNXdlPerWavePerShuffle, \ + AElementwiseOp, BElementwiseOp, CDEElementwiseOp, NumGemmKPrefetchStage, BlockSize, \ + MPerBlock, NPerBlock, KPerBlock, AK1, BK1, MPerXDL, NPerXDL, MXdlPerWave, NXdlPerWave, \ + ABlockTransferThreadClusterLengths_AK0_M_AK1, ABlockTransferThreadClusterArrangeOrder, \ + ABlockTransferSrcAccessOrder, ABlockTransferSrcVectorDim, \ + ABlockTransferSrcScalarPerVector, ABlockTransferDstScalarPerVector_AK1, false, \ + ABlockLdsExtraM, BBlockTransferThreadClusterLengths_BK0_N_BK1, \ + BBlockTransferThreadClusterArrangeOrder, BBlockTransferSrcAccessOrder, \ + BBlockTransferSrcVectorDim, BBlockTransferSrcScalarPerVector, \ + BBlockTransferDstScalarPerVector_BK1, false, 
BBlockLdsExtraN, \ + CShuffleMXdlPerWavePerShuffle, CShuffleNXdlPerWavePerShuffle, \ CDEBlockTransferClusterLengths_MBlock_MPerBlock_NBlock_NPerBlock, \ CDEBlockTransferScalarPerVector_NPerBlock, LoopSched, PipelineVersion::v1, BComputeType - -#define GridwiseGemmTemplateParams \ - tensor_layout::gemm::RowMajor, tensor_layout::gemm::RowMajor, tensor_layout::gemm::RowMajor, \ - ADataType, BDataType, AccDataType, CShuffleDataType, EDataType, AElementwiseOp, \ - BElementwiseOp, CDEElementwiseOp, GemmSpec, BlockSize, MPerBlock, NPerBlock, KPerBlock, \ - AK1, BK1, MPerXDL, NPerXDL, MXdlPerWave, NXdlPerWave, \ - ABlockTransferThreadClusterLengths_AK0_M_AK1, ABlockTransferThreadClusterArrangeOrder, \ - ABlockTransferSrcAccessOrder, ABlockTransferSrcVectorDim, \ - ABlockTransferSrcScalarPerVector, ABlockTransferDstScalarPerVector_AK1, false, \ - ABlockLdsExtraM, BBlockTransferThreadClusterLengths_BK0_N_BK1, \ - BBlockTransferThreadClusterArrangeOrder, BBlockTransferSrcAccessOrder, \ - BBlockTransferSrcVectorDim, BBlockTransferSrcScalarPerVector, \ - BBlockTransferDstScalarPerVector_BK1, false, BBlockLdsExtraN, \ - CShuffleMXdlPerWavePerShuffle, CShuffleNXdlPerWavePerShuffle, \ - CDEBlockTransferClusterLengths_MBlock_MPerBlock_NBlock_NPerBlock, \ - CDEBlockTransferScalarPerVector_NPerBlock, BlkGemmPipeSched, BlkGemmPipelineVer, \ - AComputeType, BComputeType - - using GridwiseGemm = - std::conditional_t, - GridwiseGemm_xdl_cshuffle_v3>; + using GridwiseGemm = GridwiseGemmMultipleD_xdl_cshuffle; template static auto MakeEGridDescriptor_MBlock_MPerBlock_NBlock_NPerBlock(const EGridDesc_M_N e_grid_desc_m_n) { - if constexpr(isMultiD) - { - return GridwiseGemm::MakeEGridDescriptor_MBlock_MPerBlock_NBlock_NPerBlock( - e_grid_desc_m_n); - } - else - { - const index_t M = e_grid_desc_m_n.GetLength(I0); - const index_t N = e_grid_desc_m_n.GetLength(I1); - return GridwiseGemm::MakeCGridDescriptor_MBlock_MPerBlock_NBlock_NPerBlock( - e_grid_desc_m_n, - 
GridwiseGemm::CalculateMBlock(M), - GridwiseGemm::CalculateNBlock(N)); - } + return GridwiseGemm::MakeEGridDescriptor_MBlock_MPerBlock_NBlock_NPerBlock(e_grid_desc_m_n); } template @@ -850,46 +673,34 @@ struct DeviceGroupedConvBwdDataMultipleD_Xdl_CShuffle_v1 const auto b_grid_desc_n_k = transform_k0_m_k1_to_m_k(b_grid_desc_bk0_n_bk1); - if constexpr(isMultiD) - { - a_grid_desc_m_k_container_.push_back(a_grid_desc_m_k); - b_grid_desc_n_k_container_.push_back(b_grid_desc_n_k); - ds_grid_desc_m_n_container_.push_back(ds_grid_desc_m_n); - e_grid_desc_m_n_container_.push_back(e_grid_desc_m_n); - } + a_grid_desc_m_k_container_.push_back(a_grid_desc_m_k); + b_grid_desc_n_k_container_.push_back(b_grid_desc_n_k); + ds_grid_desc_m_n_container_.push_back(ds_grid_desc_m_n); + e_grid_desc_m_n_container_.push_back(e_grid_desc_m_n); // desc for blockwise copy a_grid_desc_ak0_m_ak1_container_.push_back(a_grid_desc_ak0_m_ak1); b_grid_desc_bk0_n_bk1_container_.push_back(b_grid_desc_bk0_n_bk1); - if constexpr(isMultiD) + // block-to-e-tile-map + auto block_2_etile_map = + GridwiseGemm::MakeDefaultBlock2ETileMap(e_grid_desc_m_n); + + block_2_etile_map_container_.push_back(block_2_etile_map); + + if(GridwiseGemm::CheckValidity(a_grid_desc_m_k, + b_grid_desc_n_k, + ds_grid_desc_m_n, + e_grid_desc_m_n, + block_2_etile_map, + k_batch_)) { - // block-to-e-tile-map - auto block_2_etile_map = - GridwiseGemm::MakeDefaultBlock2ETileMap(e_grid_desc_m_n); + ds_grid_desc_mblock_mperblock_nblock_nperblock_container_.push_back( - block_2_etile_map_container_.push_back(block_2_etile_map); + GridwiseGemm:: + MakeDsGridDescriptor_MBlock_MPerBlock_NBlock_NPerBlock( + ds_grid_desc_m_n)); - if(GridwiseGemm::CheckValidity(a_grid_desc_m_k, - b_grid_desc_n_k, - ds_grid_desc_m_n, - e_grid_desc_m_n, - block_2_etile_map)) - { - ds_grid_desc_mblock_mperblock_nblock_nperblock_container_.push_back( - - GridwiseGemm:: - MakeDsGridDescriptor_MBlock_MPerBlock_NBlock_NPerBlock( - ds_grid_desc_m_n)); - - 
e_grid_desc_mblock_mperblock_nblock_nperblock_container_.push_back( - MakeEGridDescriptor_MBlock_MPerBlock_NBlock_NPerBlock( - e_grid_desc_m_n)); - } - } - else - { - // there is no need to check since M, N, K are padded e_grid_desc_mblock_mperblock_nblock_nperblock_container_.push_back( MakeEGridDescriptor_MBlock_MPerBlock_NBlock_NPerBlock( e_grid_desc_m_n)); @@ -1083,12 +894,13 @@ struct DeviceGroupedConvBwdDataMultipleD_Xdl_CShuffle_v1 { using Argument = DeviceOp::Argument; + template float RunMultiDGemm(const Argument& arg, const StreamConfig& stream_config = StreamConfig{}) { float ave_time = 0; const index_t gdy = arg.num_group_; - const index_t gdz = arg.num_workgroups_per_Conv_N_; + const index_t gdz = arg.num_workgroups_per_Conv_N_ * arg.k_batch_; const ADataType* p_a_grid = arg.p_a_grid_; const BDataType* p_b_grid = arg.p_b_grid_; @@ -1117,7 +929,8 @@ struct DeviceGroupedConvBwdDataMultipleD_Xdl_CShuffle_v1 arg.b_grid_desc_n_k_container_[i], arg.ds_grid_desc_m_n_container_[i], arg.e_grid_desc_m_n_container_[i], - arg.block_2_etile_map_container_[i])) + arg.block_2_etile_map_container_[i], + arg.k_batch_)) { throw std::runtime_error("wrong! 
device_op has invalid setting"); } @@ -1145,7 +958,8 @@ struct DeviceGroupedConvBwdDataMultipleD_Xdl_CShuffle_v1 Block2ETileMap, ComputePtrOffsetOfStridedBatch, ComputePtrOffsetOfStridedBatch, - has_main_loop>; + has_main_loop, + ElementOp>; return launch_and_time_kernel( stream_config, @@ -1166,10 +980,11 @@ struct DeviceGroupedConvBwdDataMultipleD_Xdl_CShuffle_v1 arg.e_grid_desc_mblock_mperblock_nblock_nperblock_container_[i], arg.block_2_etile_map_container_[i], arg.compute_ptr_offset_of_batch_, - arg.compute_ptr_offset_of_n_); + arg.compute_ptr_offset_of_n_, + arg.k_batch_); }; - if(GridwiseGemm::CalculateHasMainKBlockLoop(GemmK)) + if(GridwiseGemm::CalculateHasMainKBlockLoop(GemmK, arg.k_batch_)) { ave_time += launch_kernel(integral_constant{}); } @@ -1182,678 +997,6 @@ struct DeviceGroupedConvBwdDataMultipleD_Xdl_CShuffle_v1 return ave_time; } - float RunGemmV3(const Argument& arg, const StreamConfig& stream_config = StreamConfig{}) - { - float ave_time = 0; - - const ADataType* p_a_grid = arg.p_a_grid_; - const BDataType* p_b_grid = arg.p_b_grid_; - EDataType* p_e_grid = arg.p_e_grid_; - - if constexpr(is_NGCHW_NGKHW() || - is_NGCDHW_NGKDHW()) - { - p_a_grid = type_convert(arg.p_workspace_); - p_e_grid = - type_convert(arg.p_workspace_) + - (arg.GetWorkspaceATensorSizeBytes() + arg.GetWorkspaceBTensorSizeBytes()) / - sizeof(EDataType); - } - - if constexpr(is_NGCHW_GKCYX_NGKHW() || - is_NGCDHW_GKCZYX_NGKDHW()) - { - p_b_grid = type_convert(arg.p_workspace_) + - arg.GetWorkspaceATensorSizeBytes() / sizeof(BDataType); - } - - constexpr index_t minimum_occupancy = - BlkGemmPipeSched == BlockGemmPipelineScheduler::Intrawave ? 
1 : 2; - - for(std::size_t i = 0; i < arg.a_grid_desc_ak0_m_ak1_container_.size(); i++) - { - const index_t GemmM = arg.a_grid_desc_ak0_m_ak1_container_[i].GetLength(I1); - const index_t GemmN = arg.b_grid_desc_bk0_n_bk1_container_[i].GetLength(I1); - const index_t GemmK = arg.a_grid_desc_ak0_m_ak1_container_[i].GetLength(I0) * - arg.a_grid_desc_ak0_m_ak1_container_[i].GetLength(I2); - - const auto num_k_per_block = - arg.a_grid_desc_ak0_m_ak1_container_[i].GetLength(Number<0>{}) / arg.k_batch_; - - // gdy is for the kbatch and num_workgrups_per_Conv_N - index_t gdx, gdy, gdz; - std::tie(gdx, gdy, gdz) = GridwiseGemm::CalculateGridSize( - GemmM, GemmN, arg.k_batch_ * arg.num_workgroups_per_Conv_N_, arg.num_group_); - - index_t k_grain = arg.k_batch_ * KPerBlock; - index_t K_split = (GemmK + k_grain - 1) / k_grain * KPerBlock; - const bool has_main_k_block_loop = - GridwiseGemm::CalculateHasMainKBlockLoop(K_split); - - typename GridwiseGemm::Argument gemm_arg{ - p_a_grid, p_b_grid, p_e_grid, GemmM, GemmN, GemmK, I0, I0, I0, arg.k_batch_}; - - const auto Run = [&](const auto& kernel) { - if(stream_config.flush_cache) - { - typename GridwiseGemm::Argument gemm_arg_ = gemm_arg; - ck::utility::RotatingMemWrapper - rotating_mem(gemm_arg_, - stream_config.rotating_count, - gemm_arg_.M * gemm_arg_.K * sizeof(ADataType), - gemm_arg_.K * gemm_arg_.N * sizeof(BDataType)); - rotating_mem.Print(); - - auto run_flush_cache = [&]() { - // flush icache - ck::utility::flush_icache(); - // rotating mem - rotating_mem.Next(); - }; - - ave_time += ck::utility::launch_and_time_kernel_with_preprocess( - stream_config, - run_flush_cache, - kernel, - dim3(gdx, gdy, gdz), - dim3(BlockSize), - 0, - gemm_arg_, - arg.a_grid_desc_ak0_m_ak1_container_[i], - arg.b_grid_desc_bk0_n_bk1_container_[i], - arg.e_grid_desc_mblock_mperblock_nblock_nperblock_container_[i], - arg.compute_ptr_offset_of_batch_, - arg.compute_ptr_offset_of_n_, - num_k_per_block); - } - else - { - ave_time += 
launch_and_time_kernel( - stream_config, - kernel, - dim3(gdx, gdy, gdz), - dim3(BlockSize), - 0, - gemm_arg, - arg.a_grid_desc_ak0_m_ak1_container_[i], - arg.b_grid_desc_bk0_n_bk1_container_[i], - arg.e_grid_desc_mblock_mperblock_nblock_nperblock_container_[i], - arg.compute_ptr_offset_of_batch_, - arg.compute_ptr_offset_of_n_, - num_k_per_block); - } - }; - - if(has_main_k_block_loop) - { - // Tail number always full - if constexpr(BlkGemmPipelineVer == BlockGemmPipelineVersion::v1 || - BlkGemmPipelineVer == BlockGemmPipelineVersion::v3) - { - if(gemm_arg.KBatch > 1) - { - if constexpr(IsSplitKSupported) - { - const auto kernel = kernel_grouped_conv_bwd_data_xdl_cshuffle_v3< - GridwiseGemm, - DeviceOp::AGridDesc_AK0_M_AK1, - DeviceOp::BGridDesc_BK0_N_BK1, - DeviceOp::EGridDesc_MBlock_MPerBlock_NBlock_NPerBlock, - ComputePtrOffsetOfStridedBatch, - ComputePtrOffsetOfStridedBatch, - true, - InMemoryDataOperationEnum::AtomicAdd, - minimum_occupancy>; - Run(kernel); - } - } - else - { - const auto kernel = kernel_grouped_conv_bwd_data_xdl_cshuffle_v3< - GridwiseGemm, - DeviceOp::AGridDesc_AK0_M_AK1, - DeviceOp::BGridDesc_BK0_N_BK1, - DeviceOp::EGridDesc_MBlock_MPerBlock_NBlock_NPerBlock, - ComputePtrOffsetOfStridedBatch, - ComputePtrOffsetOfStridedBatch, - true, - InMemoryDataOperationEnum::Set, - minimum_occupancy>; - Run(kernel); - } - } - // Tail number could be One to Seven - else if constexpr(BlkGemmPipelineVer == BlockGemmPipelineVersion::v2) - { - if(gemm_arg.KBatch > 1) - { - if constexpr(IsSplitKSupported) - { - if(GridwiseGemm::CalculateKBlockLoopTailNum(K_split) == - TailNumber::One) - { - const auto kernel = - kernel_grouped_conv_bwd_data_xdl_cshuffle_v3< - GridwiseGemm, - DeviceOp::AGridDesc_AK0_M_AK1, - DeviceOp::BGridDesc_BK0_N_BK1, - DeviceOp::EGridDesc_MBlock_MPerBlock_NBlock_NPerBlock, - ComputePtrOffsetOfStridedBatch, - ComputePtrOffsetOfStridedBatch, - true, - InMemoryDataOperationEnum::AtomicAdd, - minimum_occupancy, - TailNumber::One>; - 
Run(kernel); - } - else if(GridwiseGemm::CalculateKBlockLoopTailNum(K_split) == - TailNumber::Full) - { - const auto kernel = - kernel_grouped_conv_bwd_data_xdl_cshuffle_v3< - GridwiseGemm, - DeviceOp::AGridDesc_AK0_M_AK1, - DeviceOp::BGridDesc_BK0_N_BK1, - DeviceOp::EGridDesc_MBlock_MPerBlock_NBlock_NPerBlock, - ComputePtrOffsetOfStridedBatch, - ComputePtrOffsetOfStridedBatch, - true, - InMemoryDataOperationEnum::AtomicAdd, - minimum_occupancy, - TailNumber::Full>; - Run(kernel); - } - - if constexpr(GridwiseGemm::BlockwiseGemmPipe::PrefetchStages > 2) - { - if(GridwiseGemm::CalculateKBlockLoopTailNum(K_split) == - TailNumber::Two) - { - const auto kernel = - kernel_grouped_conv_bwd_data_xdl_cshuffle_v3< - GridwiseGemm, - DeviceOp::AGridDesc_AK0_M_AK1, - DeviceOp::BGridDesc_BK0_N_BK1, - DeviceOp:: - EGridDesc_MBlock_MPerBlock_NBlock_NPerBlock, - ComputePtrOffsetOfStridedBatch, - ComputePtrOffsetOfStridedBatch, - true, - InMemoryDataOperationEnum::AtomicAdd, - minimum_occupancy, - TailNumber::Two>; - Run(kernel); - } - } - - if constexpr(GridwiseGemm::BlockwiseGemmPipe::PrefetchStages > 3) - { - if(GridwiseGemm::CalculateKBlockLoopTailNum(K_split) == - TailNumber::Three) - { - const auto kernel = - kernel_grouped_conv_bwd_data_xdl_cshuffle_v3< - GridwiseGemm, - DeviceOp::AGridDesc_AK0_M_AK1, - DeviceOp::BGridDesc_BK0_N_BK1, - DeviceOp:: - EGridDesc_MBlock_MPerBlock_NBlock_NPerBlock, - ComputePtrOffsetOfStridedBatch, - ComputePtrOffsetOfStridedBatch, - true, - InMemoryDataOperationEnum::AtomicAdd, - minimum_occupancy, - TailNumber::Three>; - Run(kernel); - } - } - - if constexpr(GridwiseGemm::BlockwiseGemmPipe::PrefetchStages > 4) - { - if(GridwiseGemm::CalculateKBlockLoopTailNum(K_split) == - TailNumber::Four) - { - const auto kernel = - kernel_grouped_conv_bwd_data_xdl_cshuffle_v3< - GridwiseGemm, - DeviceOp::AGridDesc_AK0_M_AK1, - DeviceOp::BGridDesc_BK0_N_BK1, - DeviceOp:: - EGridDesc_MBlock_MPerBlock_NBlock_NPerBlock, - ComputePtrOffsetOfStridedBatch, - 
ComputePtrOffsetOfStridedBatch, - true, - InMemoryDataOperationEnum::AtomicAdd, - minimum_occupancy, - TailNumber::Four>; - Run(kernel); - } - } - - if constexpr(GridwiseGemm::BlockwiseGemmPipe::PrefetchStages > 5) - { - if(GridwiseGemm::CalculateKBlockLoopTailNum(K_split) == - TailNumber::Five) - { - const auto kernel = - kernel_grouped_conv_bwd_data_xdl_cshuffle_v3< - GridwiseGemm, - DeviceOp::AGridDesc_AK0_M_AK1, - DeviceOp::BGridDesc_BK0_N_BK1, - DeviceOp:: - EGridDesc_MBlock_MPerBlock_NBlock_NPerBlock, - ComputePtrOffsetOfStridedBatch, - ComputePtrOffsetOfStridedBatch, - true, - InMemoryDataOperationEnum::AtomicAdd, - minimum_occupancy, - TailNumber::Five>; - Run(kernel); - } - } - - if constexpr(GridwiseGemm::BlockwiseGemmPipe::PrefetchStages > 6) - { - if(GridwiseGemm::CalculateKBlockLoopTailNum(K_split) == - TailNumber::Six) - { - const auto kernel = - kernel_grouped_conv_bwd_data_xdl_cshuffle_v3< - GridwiseGemm, - DeviceOp::AGridDesc_AK0_M_AK1, - DeviceOp::BGridDesc_BK0_N_BK1, - DeviceOp:: - EGridDesc_MBlock_MPerBlock_NBlock_NPerBlock, - ComputePtrOffsetOfStridedBatch, - ComputePtrOffsetOfStridedBatch, - true, - InMemoryDataOperationEnum::AtomicAdd, - minimum_occupancy, - TailNumber::Six>; - Run(kernel); - } - } - - if constexpr(GridwiseGemm::BlockwiseGemmPipe::PrefetchStages > 7) - { - if(GridwiseGemm::CalculateKBlockLoopTailNum(K_split) == - TailNumber::Seven) - { - const auto kernel = - kernel_grouped_conv_bwd_data_xdl_cshuffle_v3< - GridwiseGemm, - DeviceOp::AGridDesc_AK0_M_AK1, - DeviceOp::BGridDesc_BK0_N_BK1, - DeviceOp:: - EGridDesc_MBlock_MPerBlock_NBlock_NPerBlock, - ComputePtrOffsetOfStridedBatch, - ComputePtrOffsetOfStridedBatch, - true, - InMemoryDataOperationEnum::AtomicAdd, - minimum_occupancy, - TailNumber::Seven>; - Run(kernel); - } - } - } - } - else - { - if(GridwiseGemm::CalculateKBlockLoopTailNum(K_split) == TailNumber::One) - { - const auto kernel = kernel_grouped_conv_bwd_data_xdl_cshuffle_v3< - GridwiseGemm, - 
DeviceOp::AGridDesc_AK0_M_AK1, - DeviceOp::BGridDesc_BK0_N_BK1, - DeviceOp::EGridDesc_MBlock_MPerBlock_NBlock_NPerBlock, - ComputePtrOffsetOfStridedBatch, - ComputePtrOffsetOfStridedBatch, - true, - InMemoryDataOperationEnum::Set, - minimum_occupancy, - TailNumber::One>; - Run(kernel); - } - else if(GridwiseGemm::CalculateKBlockLoopTailNum(K_split) == - TailNumber::Full) - { - const auto kernel = kernel_grouped_conv_bwd_data_xdl_cshuffle_v3< - GridwiseGemm, - DeviceOp::AGridDesc_AK0_M_AK1, - DeviceOp::BGridDesc_BK0_N_BK1, - DeviceOp::EGridDesc_MBlock_MPerBlock_NBlock_NPerBlock, - ComputePtrOffsetOfStridedBatch, - ComputePtrOffsetOfStridedBatch, - true, - InMemoryDataOperationEnum::Set, - minimum_occupancy, - TailNumber::Full>; - Run(kernel); - } - - if constexpr(GridwiseGemm::BlockwiseGemmPipe::PrefetchStages > 2) - { - if(GridwiseGemm::CalculateKBlockLoopTailNum(K_split) == - TailNumber::Two) - { - const auto kernel = - kernel_grouped_conv_bwd_data_xdl_cshuffle_v3< - GridwiseGemm, - DeviceOp::AGridDesc_AK0_M_AK1, - DeviceOp::BGridDesc_BK0_N_BK1, - DeviceOp::EGridDesc_MBlock_MPerBlock_NBlock_NPerBlock, - ComputePtrOffsetOfStridedBatch, - ComputePtrOffsetOfStridedBatch, - true, - InMemoryDataOperationEnum::Set, - minimum_occupancy, - TailNumber::Two>; - Run(kernel); - } - } - - if constexpr(GridwiseGemm::BlockwiseGemmPipe::PrefetchStages > 3) - { - if(GridwiseGemm::CalculateKBlockLoopTailNum(K_split) == - TailNumber::Three) - { - const auto kernel = - kernel_grouped_conv_bwd_data_xdl_cshuffle_v3< - GridwiseGemm, - DeviceOp::AGridDesc_AK0_M_AK1, - DeviceOp::BGridDesc_BK0_N_BK1, - DeviceOp::EGridDesc_MBlock_MPerBlock_NBlock_NPerBlock, - ComputePtrOffsetOfStridedBatch, - ComputePtrOffsetOfStridedBatch, - true, - InMemoryDataOperationEnum::Set, - minimum_occupancy, - TailNumber::Three>; - Run(kernel); - } - } - - if constexpr(GridwiseGemm::BlockwiseGemmPipe::PrefetchStages > 4) - { - if(GridwiseGemm::CalculateKBlockLoopTailNum(K_split) == - TailNumber::Four) - { - const 
auto kernel = - kernel_grouped_conv_bwd_data_xdl_cshuffle_v3< - GridwiseGemm, - DeviceOp::AGridDesc_AK0_M_AK1, - DeviceOp::BGridDesc_BK0_N_BK1, - DeviceOp::EGridDesc_MBlock_MPerBlock_NBlock_NPerBlock, - ComputePtrOffsetOfStridedBatch, - ComputePtrOffsetOfStridedBatch, - true, - InMemoryDataOperationEnum::Set, - minimum_occupancy, - TailNumber::Four>; - Run(kernel); - } - } - - if constexpr(GridwiseGemm::BlockwiseGemmPipe::PrefetchStages > 5) - { - if(GridwiseGemm::CalculateKBlockLoopTailNum(K_split) == - TailNumber::Five) - { - const auto kernel = - kernel_grouped_conv_bwd_data_xdl_cshuffle_v3< - GridwiseGemm, - DeviceOp::AGridDesc_AK0_M_AK1, - DeviceOp::BGridDesc_BK0_N_BK1, - DeviceOp::EGridDesc_MBlock_MPerBlock_NBlock_NPerBlock, - ComputePtrOffsetOfStridedBatch, - ComputePtrOffsetOfStridedBatch, - true, - InMemoryDataOperationEnum::Set, - minimum_occupancy, - TailNumber::Five>; - Run(kernel); - } - } - - if constexpr(GridwiseGemm::BlockwiseGemmPipe::PrefetchStages > 6) - { - if(GridwiseGemm::CalculateKBlockLoopTailNum(K_split) == - TailNumber::Six) - { - const auto kernel = - kernel_grouped_conv_bwd_data_xdl_cshuffle_v3< - GridwiseGemm, - DeviceOp::AGridDesc_AK0_M_AK1, - DeviceOp::BGridDesc_BK0_N_BK1, - DeviceOp::EGridDesc_MBlock_MPerBlock_NBlock_NPerBlock, - ComputePtrOffsetOfStridedBatch, - ComputePtrOffsetOfStridedBatch, - true, - InMemoryDataOperationEnum::Set, - minimum_occupancy, - TailNumber::Six>; - Run(kernel); - } - } - - if constexpr(GridwiseGemm::BlockwiseGemmPipe::PrefetchStages > 7) - { - if(GridwiseGemm::CalculateKBlockLoopTailNum(K_split) == - TailNumber::Seven) - { - const auto kernel = - kernel_grouped_conv_bwd_data_xdl_cshuffle_v3< - GridwiseGemm, - DeviceOp::AGridDesc_AK0_M_AK1, - DeviceOp::BGridDesc_BK0_N_BK1, - DeviceOp::EGridDesc_MBlock_MPerBlock_NBlock_NPerBlock, - ComputePtrOffsetOfStridedBatch, - ComputePtrOffsetOfStridedBatch, - true, - InMemoryDataOperationEnum::Set, - minimum_occupancy, - TailNumber::Seven>; - Run(kernel); - } - } - } 
- } - // Tail number could be Odd or Even - else if constexpr(BlkGemmPipelineVer == BlockGemmPipelineVersion::v4) - { - if(gemm_arg.KBatch > 1) - { - if constexpr(IsSplitKSupported) - { - if(GridwiseGemm::CalculateKBlockLoopTailNum(K_split) == - TailNumber::Odd) - { - const auto kernel = - kernel_grouped_conv_bwd_data_xdl_cshuffle_v3_2lds< - GridwiseGemm, - DeviceOp::AGridDesc_AK0_M_AK1, - DeviceOp::BGridDesc_BK0_N_BK1, - DeviceOp::EGridDesc_MBlock_MPerBlock_NBlock_NPerBlock, - ComputePtrOffsetOfStridedBatch, - ComputePtrOffsetOfStridedBatch, - true, - InMemoryDataOperationEnum::AtomicAdd, - minimum_occupancy, - TailNumber::Odd>; - Run(kernel); - } - else - { - const auto kernel = - kernel_grouped_conv_bwd_data_xdl_cshuffle_v3_2lds< - GridwiseGemm, - DeviceOp::AGridDesc_AK0_M_AK1, - DeviceOp::BGridDesc_BK0_N_BK1, - DeviceOp::EGridDesc_MBlock_MPerBlock_NBlock_NPerBlock, - ComputePtrOffsetOfStridedBatch, - ComputePtrOffsetOfStridedBatch, - true, - InMemoryDataOperationEnum::AtomicAdd, - minimum_occupancy, - TailNumber::Even>; - Run(kernel); - } - } - } - else - { - if(GridwiseGemm::CalculateKBlockLoopTailNum(K_split) == TailNumber::Odd) - { - const auto kernel = - kernel_grouped_conv_bwd_data_xdl_cshuffle_v3_2lds< - GridwiseGemm, - DeviceOp::AGridDesc_AK0_M_AK1, - DeviceOp::BGridDesc_BK0_N_BK1, - DeviceOp::EGridDesc_MBlock_MPerBlock_NBlock_NPerBlock, - ComputePtrOffsetOfStridedBatch, - ComputePtrOffsetOfStridedBatch, - true, - InMemoryDataOperationEnum::Set, - minimum_occupancy, - TailNumber::Odd>; - Run(kernel); - } - else - { - const auto kernel = - kernel_grouped_conv_bwd_data_xdl_cshuffle_v3_2lds< - GridwiseGemm, - DeviceOp::AGridDesc_AK0_M_AK1, - DeviceOp::BGridDesc_BK0_N_BK1, - DeviceOp::EGridDesc_MBlock_MPerBlock_NBlock_NPerBlock, - ComputePtrOffsetOfStridedBatch, - ComputePtrOffsetOfStridedBatch, - true, - InMemoryDataOperationEnum::Set, - minimum_occupancy, - TailNumber::Even>; - Run(kernel); - } - } - } - else - { - if(gemm_arg.KBatch > 1) - { - if 
constexpr(IsSplitKSupported) - { - if(GridwiseGemm::CalculateKBlockLoopTailNum(K_split) == - TailNumber::Odd) - { - const auto kernel = - kernel_grouped_conv_bwd_data_xdl_cshuffle_v3< - GridwiseGemm, - DeviceOp::AGridDesc_AK0_M_AK1, - DeviceOp::BGridDesc_BK0_N_BK1, - DeviceOp::EGridDesc_MBlock_MPerBlock_NBlock_NPerBlock, - ComputePtrOffsetOfStridedBatch, - ComputePtrOffsetOfStridedBatch, - true, - InMemoryDataOperationEnum::AtomicAdd, - minimum_occupancy, - TailNumber::Odd>; - Run(kernel); - } - else - { - const auto kernel = - kernel_grouped_conv_bwd_data_xdl_cshuffle_v3< - GridwiseGemm, - DeviceOp::AGridDesc_AK0_M_AK1, - DeviceOp::BGridDesc_BK0_N_BK1, - DeviceOp::EGridDesc_MBlock_MPerBlock_NBlock_NPerBlock, - ComputePtrOffsetOfStridedBatch, - ComputePtrOffsetOfStridedBatch, - true, - InMemoryDataOperationEnum::AtomicAdd, - minimum_occupancy, - TailNumber::Even>; - Run(kernel); - } - } - } - else - { - if(GridwiseGemm::CalculateKBlockLoopTailNum(K_split) == TailNumber::Odd) - { - const auto kernel = kernel_grouped_conv_bwd_data_xdl_cshuffle_v3< - GridwiseGemm, - DeviceOp::AGridDesc_AK0_M_AK1, - DeviceOp::BGridDesc_BK0_N_BK1, - DeviceOp::EGridDesc_MBlock_MPerBlock_NBlock_NPerBlock, - ComputePtrOffsetOfStridedBatch, - ComputePtrOffsetOfStridedBatch, - true, - InMemoryDataOperationEnum::Set, - minimum_occupancy, - TailNumber::Odd>; - Run(kernel); - } - else - { - const auto kernel = kernel_grouped_conv_bwd_data_xdl_cshuffle_v3< - GridwiseGemm, - DeviceOp::AGridDesc_AK0_M_AK1, - DeviceOp::BGridDesc_BK0_N_BK1, - DeviceOp::EGridDesc_MBlock_MPerBlock_NBlock_NPerBlock, - ComputePtrOffsetOfStridedBatch, - ComputePtrOffsetOfStridedBatch, - true, - InMemoryDataOperationEnum::Set, - minimum_occupancy, - TailNumber::Even>; - Run(kernel); - } - } - } - } - else - { - // Tail number always 1 - if constexpr(BlkGemmPipelineVer == BlockGemmPipelineVersion::v1) - { - if(gemm_arg.KBatch > 1) - { - if constexpr(IsSplitKSupported) - { - const auto kernel = 
kernel_grouped_conv_bwd_data_xdl_cshuffle_v3< - GridwiseGemm, - DeviceOp::AGridDesc_AK0_M_AK1, - DeviceOp::BGridDesc_BK0_N_BK1, - DeviceOp::EGridDesc_MBlock_MPerBlock_NBlock_NPerBlock, - ComputePtrOffsetOfStridedBatch, - ComputePtrOffsetOfStridedBatch, - false, - InMemoryDataOperationEnum::AtomicAdd, - minimum_occupancy>; - Run(kernel); - } - } - else - { - const auto kernel = kernel_grouped_conv_bwd_data_xdl_cshuffle_v3< - GridwiseGemm, - DeviceOp::AGridDesc_AK0_M_AK1, - DeviceOp::BGridDesc_BK0_N_BK1, - DeviceOp::EGridDesc_MBlock_MPerBlock_NBlock_NPerBlock, - ComputePtrOffsetOfStridedBatch, - ComputePtrOffsetOfStridedBatch, - false, - InMemoryDataOperationEnum::Set, - minimum_occupancy>; - Run(kernel); - } - } - } - } - return ave_time; - } - float Run(const Argument& arg, const StreamConfig& stream_config = StreamConfig{}) { float ave_time = 0; @@ -1940,14 +1083,17 @@ struct DeviceGroupedConvBwdDataMultipleD_Xdl_CShuffle_v1 static_cast(arg.compute_ptr_offset_of_n_.BatchStrideA_)}, std::array{0}); } - - if constexpr(isMultiD) + if(arg.k_batch_ > 1) { - ave_time += RunMultiDGemm(arg, stream_config); + if constexpr(IsSplitKSupported) + { + ave_time += + RunMultiDGemm(arg, stream_config); + } } else { - ave_time += RunGemmV3(arg, stream_config); + ave_time += RunMultiDGemm(arg, stream_config); } // Transpose from NHWGC to NGCHW @@ -2031,29 +1177,6 @@ struct DeviceGroupedConvBwdDataMultipleD_Xdl_CShuffle_v1 const index_t ConvK = arg.b_g_k_c_xs_lengths_[1]; const index_t ConvC = arg.b_g_k_c_xs_lengths_[2]; - if constexpr(!isMultiD) - { - for(std::size_t i = 0; i < arg.a_grid_desc_ak0_m_ak1_container_.size(); i++) - { - const index_t GemmM = arg.a_grid_desc_ak0_m_ak1_container_[i].GetLength(I1); - const index_t GemmN = arg.b_grid_desc_bk0_n_bk1_container_[i].GetLength(I1); - const index_t GemmK = arg.a_grid_desc_ak0_m_ak1_container_[i].GetLength(I0) * - arg.a_grid_desc_ak0_m_ak1_container_[i].GetLength(I2); - - typename GridwiseGemm::Argument gemm_arg{ - nullptr, 
nullptr, nullptr, GemmM, GemmN, GemmK, I0, I0, I0, arg.k_batch_}; - - const auto num_k_loop = gemm_arg.AK0 / (KPerBlock / AK1); - if constexpr(BlkGemmPipelineVer != BlockGemmPipelineVersion::v1) - { - if(num_k_loop <= GridwiseGemm::BlockwiseGemmPipe::PrefetchStages) - { - return false; - } - } - } - } - // Specifialization if constexpr(ConvBackwardDataSpecialization == ConvolutionBackwardDataSpecialization::Filter1x1Stride1Pad0) @@ -2156,16 +1279,14 @@ struct DeviceGroupedConvBwdDataMultipleD_Xdl_CShuffle_v1 // Gridwise GEMM size for(std::size_t i = 0; i < arg.a_grid_desc_ak0_m_ak1_container_.size(); i++) { - if constexpr(isMultiD) + if(!GridwiseGemm::CheckValidity(arg.a_grid_desc_m_k_container_[i], + arg.b_grid_desc_n_k_container_[i], + arg.ds_grid_desc_m_n_container_[i], + arg.e_grid_desc_m_n_container_[i], + arg.block_2_etile_map_container_[i], + arg.k_batch_)) { - if(!GridwiseGemm::CheckValidity(arg.a_grid_desc_m_k_container_[i], - arg.b_grid_desc_n_k_container_[i], - arg.ds_grid_desc_m_n_container_[i], - arg.e_grid_desc_m_n_container_[i], - arg.block_2_etile_map_container_[i])) - { - return false; - } + return false; } } @@ -2322,17 +1443,6 @@ struct DeviceGroupedConvBwdDataMultipleD_Xdl_CShuffle_v1 { auto str = std::stringstream(); - std::map BlkGemmPipelineSchedulerToString{ - {BlockGemmPipelineScheduler::Intrawave, "Intrawave"}, - {BlockGemmPipelineScheduler::Interwave, "Interwave"}}; - - std::map BlkGemmPipelineVersionToString{ - {BlockGemmPipelineVersion::v1, "v1"}, - {BlockGemmPipelineVersion::v2, "v2"}, - {BlockGemmPipelineVersion::v3, "v3"}, - {BlockGemmPipelineVersion::v4, "v4"}, - {BlockGemmPipelineVersion::v5, "v5"}}; - // clang-format off str << "DeviceGroupedConvBwdDataMultipleD_Xdl_CShuffle_v1" << "<" @@ -2350,11 +1460,7 @@ struct DeviceGroupedConvBwdDataMultipleD_Xdl_CShuffle_v1 << ABlockTransferSrcScalarPerVector << ", " << BBlockTransferSrcScalarPerVector << ", " << CShuffleMXdlPerWavePerShuffle << ", " - << CShuffleNXdlPerWavePerShuffle << ", 
" - << "BlkGemmPipelineScheduler: " - << BlkGemmPipelineSchedulerToString[BlkGemmPipeSched] << ", " - << "BlkGemmPipelineVersion: " - << BlkGemmPipelineVersionToString[BlkGemmPipelineVer]; + << CShuffleNXdlPerWavePerShuffle; if constexpr(is_NGCHW_NGKHW() || is_NGCDHW_NGKDHW()) { diff --git a/include/ck/tensor_operation/gpu/device/impl/device_grouped_conv_fwd_multiple_abd_xdl_cshuffle.hpp b/include/ck/tensor_operation/gpu/device/impl/device_grouped_conv_fwd_multiple_abd_xdl_cshuffle.hpp index c0148c3b9c..27da1d91a3 100644 --- a/include/ck/tensor_operation/gpu/device/impl/device_grouped_conv_fwd_multiple_abd_xdl_cshuffle.hpp +++ b/include/ck/tensor_operation/gpu/device/impl/device_grouped_conv_fwd_multiple_abd_xdl_cshuffle.hpp @@ -179,7 +179,7 @@ __global__ void const long_index_t a_n_offset = amd_wave_read_first_lane(compute_ptr_offset_of_n.GetAPtrOffset(n_idx)); - GridwiseGemm::template Run( + GridwiseGemm::template Run( p_as_grid + a_group_offset + a_n_offset, p_bs_grid + b_group_offset, p_ds_grid_grp, @@ -434,7 +434,7 @@ struct DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle using GemmADataType = std::conditional_t, ADataType>; using GemmBDataType = std::conditional_t, BDataType>; -#define GridwiseGemmTemplateParameters \ +#define GridwiseGemmMultiABDTemplateParameters \ GemmADataType, GemmBDataType, AComputeDataType, AccDataType, CShuffleDataType, DsDataType, \ EDataType, AElementwiseOperation, BElementwiseOperation, CDEElementwiseOperation, \ InMemoryDataOperationEnum::Set, NumGemmKPrefetchStage, BlockSize, MPerBlock, NPerBlock, \ @@ -450,11 +450,27 @@ struct DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle CDEBlockTransferClusterLengths_MBlock_MPerBlock_NBlock_NPerBlock, \ CDEBlockTransferScalarPerVector_NPerBlock, LoopSched, PipelineVersion::v1, \ BComputeDataType + +#define GridwiseGemmTemplateParameters \ + GemmADataType, GemmBDataType, AComputeDataType, AccDataType, CShuffleDataType, DsDataType, \ + EDataType, AElementwiseOperation, BElementwiseOperation, 
CDEElementwiseOperation, \ + NumGemmKPrefetchStage, BlockSize, MPerBlock, NPerBlock, KPerBlock, AK1, BK1, MPerXDL, \ + NPerXDL, MXdlPerWave, NXdlPerWave, ABlockTransferThreadClusterLengths_AK0_M_AK1, \ + ABlockTransferThreadClusterArrangeOrder, ABlockTransferSrcAccessOrder, \ + ABlockTransferSrcVectorDim, ABlockTransferSrcScalarPerVector, \ + ABlockTransferDstScalarPerVector_AK1, false, ABlockLdsExtraM, \ + BBlockTransferThreadClusterLengths_BK0_N_BK1, BBlockTransferThreadClusterArrangeOrder, \ + BBlockTransferSrcAccessOrder, BBlockTransferSrcVectorDim, \ + BBlockTransferSrcScalarPerVector, BBlockTransferDstScalarPerVector_BK1, false, \ + BBlockLdsExtraN, CShuffleMXdlPerWavePerShuffle, CShuffleNXdlPerWavePerShuffle, \ + CDEBlockTransferClusterLengths_MBlock_MPerBlock_NBlock_NPerBlock, \ + CDEBlockTransferScalarPerVector_NPerBlock, LoopSched, PipelineVersion::v1, \ + BComputeDataType // Use appropriate gridwise gemm - using GridwiseGemm = - std::conditional_t, - GridwiseGemmMultipleD_xdl_cshuffle>; + using GridwiseGemm = std::conditional_t< + isMultiA || isMultiB, + GridwiseGemmMultipleABD_xdl_cshuffle, + GridwiseGemmMultipleD_xdl_cshuffle>; // If ADataTypes or BDataTypes is tuple, user has to pass std::array with pointers. 
using APointers = diff --git a/include/ck/tensor_operation/gpu/device/impl/device_grouped_conv_fwd_multiple_d_xdl_large_tensor_cshuffle.hpp b/include/ck/tensor_operation/gpu/device/impl/device_grouped_conv_fwd_multiple_d_xdl_large_tensor_cshuffle.hpp index 3c34d77cc9..94a4e0da4c 100644 --- a/include/ck/tensor_operation/gpu/device/impl/device_grouped_conv_fwd_multiple_d_xdl_large_tensor_cshuffle.hpp +++ b/include/ck/tensor_operation/gpu/device/impl/device_grouped_conv_fwd_multiple_d_xdl_large_tensor_cshuffle.hpp @@ -89,7 +89,7 @@ __global__ void group_id = index_t((left + right) / 2); } - GridwiseGemm::template Run( + GridwiseGemm::template Run( gemm_desc_kernel_args[group_id].a_ptr_ + a_group_offset + a_n_offset, gemm_desc_kernel_args[group_id].b_ptr_ + b_group_offset, Tuple<>{}, @@ -350,16 +350,15 @@ struct DeviceGroupedConvFwdMultipleD_Xdl_CShuffle_Large_Tensor #define GridwiseGemmTemplateParameters \ ADataType, BDataType, AComputeDataType, AccDataType, CShuffleDataType, DsDataType, EDataType, \ AElementwiseOperation, BElementwiseOperation, CDEElementwiseOperation, \ - InMemoryDataOperationEnum::Set, NumGemmKPrefetchStage, BlockSize, MPerBlock, NPerBlock, \ - KPerBlock, AK1, BK1, MPerXDL, NPerXDL, MXdlPerWave, NXdlPerWave, \ - ABlockTransferThreadClusterLengths_AK0_M_AK1, ABlockTransferThreadClusterArrangeOrder, \ - ABlockTransferSrcAccessOrder, ABlockTransferSrcVectorDim, \ - ABlockTransferSrcScalarPerVector, ABlockTransferDstScalarPerVector_AK1, false, \ - ABlockLdsExtraM, BBlockTransferThreadClusterLengths_BK0_N_BK1, \ - BBlockTransferThreadClusterArrangeOrder, BBlockTransferSrcAccessOrder, \ - BBlockTransferSrcVectorDim, BBlockTransferSrcScalarPerVector, \ - BBlockTransferDstScalarPerVector_BK1, false, BBlockLdsExtraN, \ - CShuffleMXdlPerWavePerShuffle, CShuffleNXdlPerWavePerShuffle, \ + NumGemmKPrefetchStage, BlockSize, MPerBlock, NPerBlock, KPerBlock, AK1, BK1, MPerXDL, \ + NPerXDL, MXdlPerWave, NXdlPerWave, ABlockTransferThreadClusterLengths_AK0_M_AK1, \ + 
ABlockTransferThreadClusterArrangeOrder, ABlockTransferSrcAccessOrder, \ + ABlockTransferSrcVectorDim, ABlockTransferSrcScalarPerVector, \ + ABlockTransferDstScalarPerVector_AK1, false, ABlockLdsExtraM, \ + BBlockTransferThreadClusterLengths_BK0_N_BK1, BBlockTransferThreadClusterArrangeOrder, \ + BBlockTransferSrcAccessOrder, BBlockTransferSrcVectorDim, \ + BBlockTransferSrcScalarPerVector, BBlockTransferDstScalarPerVector_BK1, false, \ + BBlockLdsExtraN, CShuffleMXdlPerWavePerShuffle, CShuffleNXdlPerWavePerShuffle, \ CDEBlockTransferClusterLengths_MBlock_MPerBlock_NBlock_NPerBlock, \ CDEBlockTransferScalarPerVector_NPerBlock, LoopSched, PipelineVersion::v1, \ AComputeDataType diff --git a/include/ck/tensor_operation/gpu/device/impl/device_grouped_gemm_xdl.hpp b/include/ck/tensor_operation/gpu/device/impl/device_grouped_gemm_xdl.hpp index aa70a24fc1..cbee4e09f4 100644 --- a/include/ck/tensor_operation/gpu/device/impl/device_grouped_gemm_xdl.hpp +++ b/include/ck/tensor_operation/gpu/device/impl/device_grouped_gemm_xdl.hpp @@ -65,7 +65,7 @@ __global__ void group_id = index_t((left + right) / 2); } - GridwiseGemm::template Run( + GridwiseGemm::template Run( gemm_desc_ptr[group_id].a_ptr_, gemm_desc_ptr[group_id].b_ptr_, gemm_desc_ptr[group_id].ds_ptr_, @@ -242,7 +242,6 @@ struct DeviceGroupedGemm_Xdl : public DeviceGroupedGemm( p_a_grid, a_grid_desc_ak0_m_ak1.GetElementSpaceSize()); @@ -550,6 +554,9 @@ struct GridwiseGemmMultipleD_xdl_cshuffle return; } + const index_t num_k_per_block = + __builtin_amdgcn_readfirstlane(a_grid_desc_ak0_m_ak1.GetLength(I0) / k_batch); + // HACK: this force m/n_block_data_idx_on_grid into SGPR const index_t m_block_data_idx_on_grid = __builtin_amdgcn_readfirstlane(block_work_idx[I0] * MPerBlock); @@ -591,7 +598,7 @@ struct GridwiseGemmMultipleD_xdl_cshuffle true, NumGemmKPrefetchStage>( a_grid_desc_ak0_m_ak1, - make_multi_index(0, m_block_data_idx_on_grid, 0), + make_multi_index(num_k_per_block * k_idx, m_block_data_idx_on_grid, 0), 
a_element_op, a_block_desc_ak0_m_ak1, make_multi_index(0, 0, 0), @@ -622,7 +629,7 @@ struct GridwiseGemmMultipleD_xdl_cshuffle true, NumGemmKPrefetchStage>( b_grid_desc_bk0_n_bk1, - make_multi_index(0, n_block_data_idx_on_grid, 0), + make_multi_index(num_k_per_block * k_idx, n_block_data_idx_on_grid, 0), b_element_op, b_block_desc_bk0_n_bk1, make_multi_index(0, 0, 0), @@ -688,7 +695,7 @@ struct GridwiseGemmMultipleD_xdl_cshuffle const index_t num_k_block_main_loop = __builtin_amdgcn_readfirstlane( (a_grid_desc_ak0_m_ak1.GetLength(I0) * a_grid_desc_ak0_m_ak1.GetLength(I2)) / - KPerBlock); + (KPerBlock * k_batch)); gridwise_gemm_pipeline.template Run(a_grid_desc_ak0_m_ak1, a_block_desc_ak0_m_ak1, @@ -943,6 +950,7 @@ struct GridwiseGemmMultipleD_xdl_cshuffle } template (p_a_grid, - p_b_grid, - p_ds_grid, - p_e_grid, - p_shared, - a_element_op, - b_element_op, - cde_element_op, - a_grid_desc_ak0_m_ak1, - b_grid_desc_bk0_n_bk1, - ds_grid_desc_mblock_mperblock_nblock_nperblock, - e_grid_desc_mblock_mperblock_nblock_nperblock, - block_2_etile_map); + Run( + p_a_grid, + p_b_grid, + p_ds_grid, + p_e_grid, + p_shared, + a_element_op, + b_element_op, + cde_element_op, + a_grid_desc_ak0_m_ak1, + b_grid_desc_bk0_n_bk1, + ds_grid_desc_mblock_mperblock_nblock_nperblock, + e_grid_desc_mblock_mperblock_nblock_nperblock, + block_2_etile_map); } template (p_a_grid, - p_b_grid, - p_ds_grid, - p_e_grid, - p_shared, - a_element_op, - b_element_op, - cde_element_op, - a_grid_desc_ak0_m_ak1, - b_grid_desc_bk0_n_bk1, - ds_grid_desc_mblock_mperblock_nblock_nperblock, - e_grid_desc_mblock_mperblock_nblock_nperblock, - block_2_etile_map); + Run( + p_a_grid, + p_b_grid, + p_ds_grid, + p_e_grid, + p_shared, + a_element_op, + b_element_op, + cde_element_op, + a_grid_desc_ak0_m_ak1, + b_grid_desc_bk0_n_bk1, + ds_grid_desc_mblock_mperblock_nblock_nperblock, + e_grid_desc_mblock_mperblock_nblock_nperblock, + block_2_etile_map); } }; From 58f9e9ffbc190188f85895deb952cb47cc89c403 Mon Sep 17 00:00:00 
2001 From: Illia Silin <98187287+illsilin@users.noreply.github.com> Date: Tue, 13 May 2025 10:18:14 -0700 Subject: [PATCH 118/443] Update the buffer load/store intrinsic names for clang>=20. (#2192) * fix the buffer load/store intrinsic names * fix clang format --- .../ck_tile/18_flatmm/run_flatmm_example.inc | 74 +- .../amd_buffer_addressing_builtins.hpp | 20 +- include/ck_tile/core.hpp | 1 - .../arch/amd_buffer_addressing_builtins.hpp | 2559 ----------------- include/ck_tile/core/tensor/buffer_view.hpp | 4 - 5 files changed, 51 insertions(+), 2607 deletions(-) delete mode 100644 include/ck_tile/core/arch/amd_buffer_addressing_builtins.hpp diff --git a/example/ck_tile/18_flatmm/run_flatmm_example.inc b/example/ck_tile/18_flatmm/run_flatmm_example.inc index 15a9df2c0c..c191fff7d0 100644 --- a/example/ck_tile/18_flatmm/run_flatmm_example.inc +++ b/example/ck_tile/18_flatmm/run_flatmm_example.inc @@ -4,14 +4,22 @@ #include template -constexpr const char* DataTypeToString() { - if constexpr (std::is_same_v) { +constexpr const char* DataTypeToString() +{ + if constexpr(std::is_same_v) + { return "fp16"; - } else if constexpr (std::is_same_v) { + } + else if constexpr(std::is_same_v) + { return "fp8"; - } else if constexpr (std::is_same_v) { + } + else if constexpr(std::is_same_v) + { return "bf8"; - } else { + } + else + { return "unknown"; } } @@ -112,8 +120,9 @@ float invoke_flatmm(ck_tile::DeviceMem& a_dev_buf, args.stride_B = stride_B; args.stride_C = stride_C; - float ave_time = flatmm_calc( - args, ck_tile::stream_config{nullptr, true, 1, n_warmup, n_repeat}); + float ave_time = + flatmm_calc( + args, ck_tile::stream_config{nullptr, true, 1, n_warmup, n_repeat}); std::size_t flop = std::size_t(2) * M * N * K; std::size_t num_byte = @@ -121,18 +130,15 @@ float invoke_flatmm(ck_tile::DeviceMem& a_dev_buf, float tflops = static_cast(flop) / 1.E9 / ave_time; float gb_per_sec = num_byte / 1.E6 / ave_time; - std::cout << "Run Flatmm kernel with DataType = " << 
DataTypeToString() << " M =" << M << " N =" << N << " K =" << K - << " StrideA =" << stride_A << " StrideB =" << stride_B << " StrideC =" << stride_C - << " : " << ave_time << " ms, " << tflops << " TFlops, " << gb_per_sec << " GB/s, " - << std::endl; + std::cout << "Run Flatmm kernel with DataType = " << DataTypeToString() + << " M =" << M << " N =" << N << " K =" << K << " StrideA =" << stride_A + << " StrideB =" << stride_B << " StrideC =" << stride_C << " : " << ave_time + << " ms, " << tflops << " TFlops, " << gb_per_sec << " GB/s, " << std::endl; return ave_time; } -template +template int run_flatmm_example_with_layouts(int argc, char* argv[], const ALayout a_layout = ALayout{}, @@ -147,7 +153,7 @@ int run_flatmm_example_with_layouts(int argc, using BDataType = typename GemmBasicTypeConfig::BDataType; using CDataType = typename GemmBasicTypeConfig::CDataType; using AccDataType = typename GemmBasicTypeConfig::AccDataType; - + ck_tile::index_t M = arg_parser.get_int("m"); ck_tile::index_t N = arg_parser.get_int("n"); ck_tile::index_t K = arg_parser.get_int("k"); @@ -182,7 +188,7 @@ int run_flatmm_example_with_layouts(int argc, c_rslt_host.SetZero(); // do pre-shuffle - std::string mfma = arg_parser.get_str("prec"); + std::string mfma = arg_parser.get_str("prec"); #if defined(USING_MFMA_16x16x32) && defined(ENABLE_FP8) ck_tile::index_t mfma_type = 1; #else @@ -193,18 +199,18 @@ int run_flatmm_example_with_layouts(int argc, b_shuffle_dev_buf.ToDevice(b_shuffle_host.data()); invoke_flatmm( - a_dev_buf, - b_shuffle_dev_buf, - c_dev_buf, - M, - N, - K, - stride_A, - stride_B, - stride_C, - kbatch, - n_warmup, - n_repeat); + a_dev_buf, + b_shuffle_dev_buf, + c_dev_buf, + M, + N, + K, + stride_A, + stride_B, + stride_C, + kbatch, + n_warmup, + n_repeat); c_dev_buf.FromDevice(c_rslt_host.data()); bool pass = true; @@ -219,8 +225,9 @@ int run_flatmm_example_with_layouts(int argc, a_host, b_origin_host, c_ref_host); const float max_accumulated_value = 
*std::max_element(c_ref_host.mData.begin(), c_ref_host.mData.end()); - const auto rtol_atol = calculate_rtol_atol(K, kbatch, max_accumulated_value); - pass = ck_tile::check_err(c_rslt_host, + const auto rtol_atol = calculate_rtol_atol( + K, kbatch, max_accumulated_value); + pass = ck_tile::check_err(c_rslt_host, c_ref_host, "Error: Incorrect results!", rtol_atol.at(ck_tile::number<0>{}), @@ -277,8 +284,9 @@ int run_flatmm_example_with_layouts(int argc, c_gpu_ref_dev_buf.FromDevice(c_gpu_ref_host.data()); const float max_accumulated_value = *std::max_element(c_gpu_ref_host.mData.begin(), c_gpu_ref_host.mData.end()); - const auto rtol_atol = calculate_rtol_atol(K, kbatch, max_accumulated_value); - pass = ck_tile::check_err(c_rslt_host, + const auto rtol_atol = calculate_rtol_atol( + K, kbatch, max_accumulated_value); + pass = ck_tile::check_err(c_rslt_host, c_gpu_ref_host, "Error: Incorrect results!", rtol_atol.at(ck_tile::number<0>{}), diff --git a/include/ck/utility/amd_buffer_addressing_builtins.hpp b/include/ck/utility/amd_buffer_addressing_builtins.hpp index 19869906dc..296c1d44d7 100644 --- a/include/ck/utility/amd_buffer_addressing_builtins.hpp +++ b/include/ck/utility/amd_buffer_addressing_builtins.hpp @@ -80,7 +80,7 @@ __device__ half2_t llvm_amdgcn_raw_buffer_atomic_add_fp16x2( int32x4_t rsrc, index_t voffset, index_t soffset, - index_t glc_slc) __asm("llvm.amdgcn.raw.buffer.atomic.fadd.v2f16.v4i32"); + index_t glc_slc) __asm("llvm.amdgcn.raw.buffer.atomic.fadd.v2f16"); // buffer atomic-add i32 __device__ int32_t llvm_amdgcn_raw_buffer_atomic_add_i32( @@ -88,7 +88,7 @@ __device__ int32_t llvm_amdgcn_raw_buffer_atomic_add_i32( int32x4_t rsrc, index_t voffset, index_t soffset, - index_t glc_slc) __asm("llvm.amdgcn.raw.buffer.atomic.add.i32.v4i32"); + index_t glc_slc) __asm("llvm.amdgcn.raw.buffer.atomic.add.i32"); // buffer atomic-add fp32 __device__ float llvm_amdgcn_raw_buffer_atomic_add_fp32( @@ -96,15 +96,15 @@ __device__ float 
llvm_amdgcn_raw_buffer_atomic_add_fp32( int32x4_t rsrc, index_t voffset, index_t soffset, - index_t glc_slc) __asm("llvm.amdgcn.raw.buffer.atomic.fadd.f32.v4i32"); + index_t glc_slc) __asm("llvm.amdgcn.raw.buffer.atomic.fadd.f32"); // buffer atomic-add fp32 -__device__ double llvm_amdgcn_raw_buffer_atomic_max_fp64( - double vdata, - int32x4_t rsrc, // dst_wave_buffer_resource - int voffset, // dst_thread_addr_offset - int soffset, // dst_wave_addr_offset - int glc_slc) __asm("llvm.amdgcn.raw.buffer.atomic.fmax.f64.v4i32"); +__device__ double +llvm_amdgcn_raw_buffer_atomic_max_fp64(double vdata, + int32x4_t rsrc, // dst_wave_buffer_resource + int voffset, // dst_thread_addr_offset + int soffset, // dst_wave_addr_offset + int glc_slc) __asm("llvm.amdgcn.raw.buffer.atomic.fmax.f64"); // memory coherency bit for buffer store/load instruction // check ISA manual for each GFX target @@ -827,7 +827,7 @@ llvm_amdgcn_raw_buffer_load_lds(int32x4_t rsrc, index_t voffset, index_t soffset, index_t offset, - index_t aux) __asm("llvm.amdgcn.raw.buffer.load.lds.v4i32"); + index_t aux) __asm("llvm.amdgcn.raw.buffer.load.lds"); #ifndef __HIPCC_RTC__ template diff --git a/include/ck_tile/core.hpp b/include/ck_tile/core.hpp index b9791f0b55..2ea8bf15a7 100644 --- a/include/ck_tile/core.hpp +++ b/include/ck_tile/core.hpp @@ -9,7 +9,6 @@ #include "ck_tile/core/algorithm/space_filling_curve.hpp" #include "ck_tile/core/algorithm/static_encoding_pattern.hpp" #include "ck_tile/core/arch/amd_buffer_addressing.hpp" -#include "ck_tile/core/arch/amd_buffer_addressing_builtins.hpp" #include "ck_tile/core/arch/arch.hpp" #include "ck_tile/core/arch/generic_memory_space_atomic.hpp" #include "ck_tile/core/arch/utility.hpp" diff --git a/include/ck_tile/core/arch/amd_buffer_addressing_builtins.hpp b/include/ck_tile/core/arch/amd_buffer_addressing_builtins.hpp deleted file mode 100644 index 0b9956cd01..0000000000 --- a/include/ck_tile/core/arch/amd_buffer_addressing_builtins.hpp +++ /dev/null @@ 
-1,2559 +0,0 @@ -// SPDX-License-Identifier: MIT -// Copyright (c) 2018-2025, Advanced Micro Devices, Inc. All rights reserved. - -#pragma once - -#if CK_TILE_USE_BUFFER_ADDRESSING_BUILTIN - -#include "ck_tile/core/numeric/integer.hpp" -#include "ck_tile/core/numeric/integral_constant.hpp" -#include "ck_tile/core/numeric/vector_type.hpp" -#include "ck_tile/core/container/container_helper.hpp" -#include "ck_tile/core/container/thread_buffer.hpp" -#include "ck_tile/core/utility/type_traits.hpp" -#include "ck_tile/core/utility/bit_cast.hpp" -#include "ck_tile/core/utility/functional.hpp" - -namespace ck_tile { - -// 128 bit SGPRs to supply buffer resource in buffer instructions -// https://rocm-documentation.readthedocs.io/en/latest/GCN_ISA_Manuals/testdocbook.html#vector-memory-buffer-instructions -struct __attribute__((packed)) buffer_resource -{ - const void* ptr; - uint32_t range; - uint32_t config; -}; - -CK_TILE_DEVICE int32x4_t make_wave_buffer_resource(const void* ptr, uint32_t size = 0xffffffff) -{ - buffer_resource res{ptr, size, CK_TILE_BUFFER_RESOURCE_3RD_DWORD}; - int32x4_t r = __builtin_bit_cast(int32x4_t, res); - r.x = __builtin_amdgcn_readfirstlane(r.x); - r.y = __builtin_amdgcn_readfirstlane(r.y); - r.z = __builtin_amdgcn_readfirstlane(r.z); - r.w = __builtin_amdgcn_readfirstlane(r.w); - return r; -} - -namespace impl { -// below type indicate the data type used for buffer load inline asm -// clang-format off -template struct buffer_load_trait; - -template struct buffer_load_trait<16, T> { using payload_t = fp32x4_t; }; -template struct buffer_load_trait<8 , T> { using payload_t = fp32x2_t; }; -template struct buffer_load_trait<4 , T> { using payload_t = float; }; -template struct buffer_load_trait<2 , T> { using payload_t = float; }; -template struct buffer_load_trait<1 , T> { using payload_t = float; }; - -#if CK_TILE_BUFFER_LOAD_RAW_BF16_WA -template<> struct buffer_load_trait<16, thread_buffer> { using payload_t = bf16x8_t; }; -template<> struct 
buffer_load_trait<8 , thread_buffer> { using payload_t = bf16x4_t; }; -template<> struct buffer_load_trait<4 , thread_buffer> { using payload_t = bf16x2_t; }; -#endif -// clang-format on -} // namespace impl - -// TODO: glc/slc/... -template -struct buffer_load; -#pragma clang diagnostic push -#pragma clang diagnostic ignored "-Wundefined-reinterpret-cast" -// TODO: strict aliasing rule seems fail when reinterpret_cast between vector type -// (exp_vector_type(xxx)) -template -struct buffer_load<16, pre_nop> -{ - template - CK_TILE_DEVICE void operator()(T& value, - int32x4_t res /*buffer resource*/, - index_t v_offset, - index_t /*s_offset*/, - index_t i_offset /*max 0xFFF*/, - index_t /*flag*/ = 0, - bool_constant = {}) - { - static_assert(sizeof(T) == 16); - using mbuf_t = typename impl::buffer_load_trait<16, T>::payload_t; - if constexpr(pre_nop) - asm volatile("s_nop 4\n" - "buffer_load_dwordx4 %0, %1, %2, 0 offen offset:%3" - : "+v"(reinterpret_cast(value)) - : "v"(v_offset), "s"(res), "n"(i_offset) - : "memory"); - else - asm volatile("buffer_load_dwordx4 %0, %1, %2, 0 offen offset:%3" - : "+v"(reinterpret_cast(value)) - : "v"(v_offset), "s"(res), "n"(i_offset) - : "memory"); - } -}; - -template -struct buffer_load<8, pre_nop> -{ - template - CK_TILE_DEVICE void operator()(T& value, - int32x4_t res /*buffer resource*/, - index_t v_offset, - index_t /*s_offset*/, - index_t i_offset /*max 0xFFF*/, - index_t /*flag*/ = 0, - bool_constant = {}) - { - static_assert(sizeof(T) == 8); - using mbuf_t = typename impl::buffer_load_trait<8, T>::payload_t; - if constexpr(pre_nop) - asm volatile("s_nop 4\n" - "buffer_load_dwordx2 %0, %1, %2, 0 offen offset:%3" - : "+v"(reinterpret_cast(value)) - : "v"(v_offset), "s"(res), "n"(i_offset) - : "memory"); - else - asm volatile("buffer_load_dwordx2 %0, %1, %2, 0 offen offset:%3" - : "+v"(reinterpret_cast(value)) - : "v"(v_offset), "s"(res), "n"(i_offset) - : "memory"); - } -}; - -template -struct buffer_load<4, pre_nop> -{ - 
template - CK_TILE_DEVICE void operator()(T& value, - int32x4_t res /*buffer resource*/, - index_t v_offset, - index_t /*s_offset*/, - index_t i_offset /*max 0xFFF*/, - index_t /*flag*/ = 0, - bool_constant = {}) - { - static_assert(sizeof(T) == 4); - using mbuf_t = typename impl::buffer_load_trait<4, T>::payload_t; - if constexpr(pre_nop) - asm volatile("s_nop 4\n" - "buffer_load_dword %0, %1, %2, 0 offen offset:%3" - : "+v"(reinterpret_cast(value)) - : "v"(v_offset), "s"(res), "n"(i_offset) - : "memory"); - else - asm volatile("buffer_load_dword %0, %1, %2, 0 offen offset:%3" - : "+v"(reinterpret_cast(value)) - : "v"(v_offset), "s"(res), "n"(i_offset) - : "memory"); - } -}; - -template -struct buffer_load<2, pre_nop> -{ - template - CK_TILE_DEVICE void operator()(T& value, - int32x4_t res /*buffer resource*/, - index_t v_offset, - index_t /*s_offset*/, - index_t i_offset /*max 0xFFF*/, - index_t /*flag*/ = 0, - bool_constant = {}) - { - static_assert(sizeof(T) == 4); // subdword is buggy, use dword buf and convert manually - using mbuf_t = typename impl::buffer_load_trait<2, T>::payload_t; - if constexpr(pre_nop) - asm volatile("s_nop 4\n" - "buffer_load_ushort %0, %1, %2, 0 offen offset:%3" - : "+v"(reinterpret_cast(value)) - : "v"(v_offset), "s"(res), "n"(i_offset) - : "memory"); - else - asm volatile("buffer_load_ushort %0, %1, %2, 0 offen offset:%3" - : "+v"(reinterpret_cast(value)) - : "v"(v_offset), "s"(res), "n"(i_offset) - : "memory"); - } -}; - -template -struct buffer_load<1, pre_nop> -{ - template - CK_TILE_DEVICE void operator()(T& value, - int32x4_t res /*buffer resource*/, - index_t v_offset, - index_t /*s_offset*/, - index_t i_offset /*max 0xFFF*/, - index_t /*flag*/ = 0, - bool_constant = {}) - { - static_assert(sizeof(T) == 4); - using mbuf_t = typename impl::buffer_load_trait<1, T>::payload_t; - if constexpr(pre_nop) - asm volatile("s_nop 4\n" - "buffer_load_ubyte %0, %1, %2, 0 offen offset:%3" - : "+v"(reinterpret_cast(value)) - : 
"v"(v_offset), "s"(res), "n"(i_offset) - : "memory"); - else - asm volatile("buffer_load_ubyte %0, %1, %2, 0 offen offset:%3" - : "+v"(reinterpret_cast(value)) - : "v"(v_offset), "s"(res), "n"(i_offset) - : "memory"); - } -}; - -template -struct buffer_load_if; - -template -struct buffer_load_if<16, pre_nop> -{ - template - CK_TILE_DEVICE void operator()(T& value, - int32x4_t res /*buffer resource*/, - index_t v_offset, - index_t /*s_offset*/, - index_t i_offset /*max 0xFFF*/, - index_t flag = 0, - bool_constant = {}) - { - static_assert(sizeof(T) == 16); - auto saved_exec = __builtin_amdgcn_read_exec(); - using mbuf_t = typename impl::buffer_load_trait<16, T>::payload_t; - static_assert(sizeof(mbuf_t) == sizeof(T)); - if constexpr(pre_nop) - asm volatile("s_nop 4\n" - "v_cmpx_le_u32 exec, 1, %4\n" - "buffer_load_dwordx4 %0, %1, %2, 0 offen offset:%3\n" - "s_mov_b64 exec %5" - : "+v"(reinterpret_cast(value)) - : "v"(v_offset), "s"(res), "n"(i_offset), "v"(flag), "s"(saved_exec) - : "memory"); - else - asm volatile("v_cmpx_le_u32 exec, 1, %4\n" - "buffer_load_dwordx4 %0, %1, %2, 0 offen offset:%3\n" - "s_mov_b64 exec %5" - : "+v"(reinterpret_cast(value)) - : "v"(v_offset), "s"(res), "n"(i_offset), "v"(flag), "s"(saved_exec) - : "memory"); - } -}; - -template -struct buffer_load_if<8, pre_nop> -{ - template - CK_TILE_DEVICE void operator()(T& value, - int32x4_t res /*buffer resource*/, - index_t v_offset, - index_t /*s_offset*/, - index_t i_offset /*max 0xFFF*/, - index_t flag = 0, - bool_constant = {}) - { - static_assert(sizeof(T) == 8); - auto saved_exec = __builtin_amdgcn_read_exec(); - using mbuf_t = typename impl::buffer_load_trait<8, T>::payload_t; - if constexpr(pre_nop) - asm volatile("s_nop 4\n" - "v_cmpx_le_u32 exec, 1, %4\n" - "buffer_load_dwordx2 %0, %1, %2, 0 offen offset:%3\n" - "s_mov_b64 exec %5" - : "+v"(reinterpret_cast(value)) - : "v"(v_offset), "s"(res), "n"(i_offset), "v"(flag), "s"(saved_exec) - : "memory"); - else - asm volatile("v_cmpx_le_u32 
exec, 1, %4\n" - "buffer_load_dwordx2 %0, %1, %2, 0 offen offset:%3\n" - "s_mov_b64 exec %5" - : "+v"(reinterpret_cast(value)) - : "v"(v_offset), "s"(res), "n"(i_offset), "v"(flag), "s"(saved_exec) - : "memory"); - } -}; - -template -struct buffer_load_if<4, pre_nop> -{ - template - CK_TILE_DEVICE void operator()(T& value, - int32x4_t res /*buffer resource*/, - index_t v_offset, - index_t /*s_offset*/, - index_t i_offset /*max 0xFFF*/, - index_t flag = 0, - bool_constant = {}) - { - static_assert(sizeof(T) == 4); - auto saved_exec = __builtin_amdgcn_read_exec(); - using mbuf_t = typename impl::buffer_load_trait<4, T>::payload_t; - if constexpr(pre_nop) - asm volatile("s_nop 4\n" - "v_cmpx_le_u32 exec, 1, %4\n" - "buffer_load_dword %0, %1, %2, 0 offen offset:%3\n" - "s_mov_b64 exec %5" - : "+v"(reinterpret_cast(value)) - : "v"(v_offset), "s"(res), "n"(i_offset), "v"(flag), "s"(saved_exec) - : "memory"); - else - asm volatile("v_cmpx_le_u32 exec, 1, %4\n" - "buffer_load_dword %0, %1, %2, 0 offen offset:%3\n" - "s_mov_b64 exec %5" - : "+v"(reinterpret_cast(value)) - : "v"(v_offset), "s"(res), "n"(i_offset), "v"(flag), "s"(saved_exec) - : "memory"); - } -}; - -template -struct buffer_load_if<2, pre_nop> -{ - template - CK_TILE_DEVICE void operator()(T& value, - int32x4_t res /*buffer resource*/, - index_t v_offset, - index_t /*s_offset*/, - index_t i_offset /*max 0xFFF*/, - index_t flag = 0, - bool_constant = {}) - { - static_assert(sizeof(T) == 4); - auto saved_exec = __builtin_amdgcn_read_exec(); - using mbuf_t = typename impl::buffer_load_trait<2, T>::payload_t; - if constexpr(pre_nop) - asm volatile("s_nop 4\n" - "v_cmpx_le_u32 exec, 1, %4\n" - "buffer_load_ushort %0, %1, %2, 0 offen offset:%3\n" - "s_mov_b64 exec %5" - : "+v"(reinterpret_cast(value)) - : "v"(v_offset), "s"(res), "n"(i_offset), "v"(flag), "s"(saved_exec) - : "memory"); - else - asm volatile("v_cmpx_le_u32 exec, 1, %4\n" - "buffer_load_ushort %0, %1, %2, 0 offen offset:%3\n" - "s_mov_b64 exec %5" - 
: "+v"(reinterpret_cast(value)) - : "v"(v_offset), "s"(res), "n"(i_offset), "v"(flag), "s"(saved_exec) - : "memory"); - } -}; - -template -struct buffer_load_if<1, pre_nop> -{ - template - CK_TILE_DEVICE void operator()(T& value, - int32x4_t res /*buffer resource*/, - index_t v_offset, - index_t /*s_offset*/, - index_t i_offset /*max 0xFFF*/, - index_t flag = 0, - bool_constant = {}) - { - static_assert(sizeof(T) == 4); - auto saved_exec = __builtin_amdgcn_read_exec(); - using mbuf_t = typename impl::buffer_load_trait<1, T>::payload_t; - if constexpr(pre_nop) - asm volatile("s_nop 4\n" - "v_cmpx_le_u32 exec, 1, %4\n" - "buffer_load_ubyte %0, %1, %2, 0 offen offset:%3\n" - "s_mov_b64 exec %5" - : "+v"(reinterpret_cast(value)) - : "v"(v_offset), "s"(res), "n"(i_offset), "v"(flag), "s"(saved_exec) - : "memory"); - else - asm volatile("v_cmpx_le_u32 exec, 1, %4\n" - "buffer_load_ubyte %0, %1, %2, 0 offen offset:%3\n" - "s_mov_b64 exec %5" - : "+v"(reinterpret_cast(value)) - : "v"(v_offset), "s"(res), "n"(i_offset), "v"(flag), "s"(saved_exec) - : "memory"); - } -}; -#pragma clang diagnostic pop // "-Wundefined-reinterpret-cast" -template -struct buffer_store; - -template <> -struct buffer_store<16> -{ - template - CK_TILE_DEVICE void operator()(const T& value, - int32x4_t res /*buffer resource*/, - index_t v_offset, - index_t /*s_offset*/, - index_t i_offset /*max 0xFFF*/, - index_t /*flag*/ = 1) - { - static_assert(sizeof(T) == 16); - using mbuf_t = fp32x4_t; - asm volatile("buffer_store_dwordx4 %0, %1, %2, 0 offen offset:%3" - : - : "v"(bit_cast(value)), "v"(v_offset), "s"(res), "n"(i_offset) - : "memory"); - } -}; - -template <> -struct buffer_store<8> -{ - template - CK_TILE_DEVICE void operator()(const T& value, - int32x4_t res /*buffer resource*/, - index_t v_offset, - index_t /*s_offset*/, - index_t i_offset /*max 0xFFF*/, - index_t /*flag*/ = 1) - { - static_assert(sizeof(T) == 8); - using mbuf_t = fp32x2_t; - asm volatile("buffer_store_dwordx2 %0, %1, %2, 0 
offen offset:%3" - : - : "v"(bit_cast(value)), "v"(v_offset), "s"(res), "n"(i_offset) - : "memory"); - } -}; - -template <> -struct buffer_store<4> -{ - template - CK_TILE_DEVICE void operator()(const T& value, - int32x4_t res /*buffer resource*/, - index_t v_offset, - index_t /*s_offset*/, - index_t i_offset /*max 0xFFF*/, - index_t /*flag*/ = 1) - { - static_assert(sizeof(T) == 4); - using mbuf_t = float; - asm volatile("buffer_store_dword %0, %1, %2, 0 offen offset:%3" - : - : "v"(bit_cast(value)), "v"(v_offset), "s"(res), "n"(i_offset) - : "memory"); - } -}; - -template <> -struct buffer_store<2> -{ - template - CK_TILE_DEVICE void operator()(const T& value, - int32x4_t res /*buffer resource*/, - index_t v_offset, - index_t /*s_offset*/, - index_t i_offset /*max 0xFFF*/, - index_t /*flag*/ = 1) - { - static_assert(sizeof(T) == 2); - using mbuf_t = short; - asm volatile("buffer_store_short %0, %1, %2, 0 offen offset:%3" - : - : "v"(bit_cast(value)), "v"(v_offset), "s"(res), "n"(i_offset) - : "memory"); - } -}; - -template <> -struct buffer_store<1> -{ - template - CK_TILE_DEVICE void operator()(const T& value, - int32x4_t res /*buffer resource*/, - index_t v_offset, - index_t /*s_offset*/, - index_t i_offset /*max 0xFFF*/, - index_t /*flag*/ = 1) - { - static_assert(sizeof(T) == 4); - using mbuf_t = float; - asm volatile("buffer_store_byte %0, %1, %2, 0 offen offset:%3" - : - : "v"(bit_cast(value)), "v"(v_offset), "s"(res), "n"(i_offset) - : "memory"); - } -}; - -template -struct buffer_store_if; - -template <> -struct buffer_store_if<16> -{ - template - CK_TILE_DEVICE void operator()(const T& value, - int32x4_t res /*buffer resource*/, - index_t v_offset, - index_t /*s_offset*/, - index_t i_offset /*max 0xFFF*/, - index_t flag = 1) - { - static_assert(sizeof(T) == 16); - auto save_exec = __builtin_amdgcn_read_exec(); - using mbuf_t = fp32x4_t; - asm volatile("v_cmpx_le_u32 exec, 1, %4\n" - "buffer_store_dwordx4 %0, %1, %2, 0 offen offset:%3\n" - "s_mov_b64 exec 
%5" - : - : "v"(bit_cast(value)), - "v"(v_offset), - "s"(res), - "n"(i_offset), - "v"(flag), - "s"(save_exec) - : "memory"); - } -}; - -template <> -struct buffer_store_if<8> -{ - template - CK_TILE_DEVICE void operator()(const T& value, - int32x4_t res /*buffer resource*/, - index_t v_offset, - index_t /*s_offset*/, - index_t i_offset /*max 0xFFF*/, - index_t flag = 1) - { - static_assert(sizeof(T) == 8); - auto save_exec = __builtin_amdgcn_read_exec(); - // TODO: ugly. rocm-6.0/6.1 seems neet bit_cast to same base type to avoid scratch - using mbuf_t = ext_vector_t; - asm volatile("v_cmpx_le_u32 exec, 1, %4\n" - "buffer_store_dwordx2 %0, %1, %2, 0 offen offset:%3\n" - "s_mov_b64 exec %5" - : - : "v"(bit_cast(value)), - "v"(v_offset), - "s"(res), - "n"(i_offset), - "v"(flag), - "s"(save_exec) - : "memory"); - } -}; - -template <> -struct buffer_store_if<4> -{ - template - CK_TILE_DEVICE void operator()(const T& value, - int32x4_t res /*buffer resource*/, - index_t v_offset, - index_t /*s_offset*/, - index_t i_offset /*max 0xFFF*/, - index_t flag = 1) - { - static_assert(sizeof(T) == 4); - auto save_exec = __builtin_amdgcn_read_exec(); - using mbuf_t = float; - asm volatile("v_cmpx_le_u32 exec, 1, %4\n" - "buffer_store_dword %0, %1, %2, 0 offen offset:%3\n" - "s_mov_b64 exec %5" - : - : "v"(bit_cast(value)), - "v"(v_offset), - "s"(res), - "n"(i_offset), - "v"(flag), - "s"(save_exec) - : "memory"); - } -}; - -template <> -struct buffer_store_if<2> -{ - template - CK_TILE_DEVICE void operator()(const T& value, - int32x4_t res /*buffer resource*/, - index_t v_offset, - index_t /*s_offset*/, - index_t i_offset /*max 0xFFF*/, - index_t flag = 1) - { - static_assert(sizeof(T) == 2); - auto save_exec = __builtin_amdgcn_read_exec(); - using mbuf_t = short; - asm volatile("v_cmpx_le_u32 exec, 1, %4\n" - "buffer_store_short %0, %1, %2, 0 offen offset:%3\n" - "s_mov_b64 exec %5" - : - : "v"(bit_cast(value)), - "v"(v_offset), - "s"(res), - "n"(i_offset), - "v"(flag), - 
"s"(save_exec) - : "memory"); - } -}; - -template <> -struct buffer_store_if<1> -{ - template - CK_TILE_DEVICE void operator()(const T& value, - int32x4_t res /*buffer resource*/, - index_t v_offset, - index_t /*s_offset*/, - index_t i_offset /*max 0xFFF*/, - index_t flag = 1) - { - static_assert(sizeof(T) == 4); - auto save_exec = __builtin_amdgcn_read_exec(); - using mbuf_t = float; - asm volatile("v_cmpx_le_u32 exec, 1, %4\n" - "buffer_store_byte %0, %1, %2, 0 offen offset:%3\n" - "s_mov_b64 exec %5" - : - : "v"(bit_cast(value)), - "v"(v_offset), - "s"(res), - "n"(i_offset), - "v"(flag), - "s"(save_exec) - : "memory"); - } -}; - -CK_TILE_DEVICE void buffer_load_fence(index_t cnt = 0) -{ - asm volatile("s_waitcnt vmcnt(%0)" : : "n"(cnt) : "memory"); -} - -CK_TILE_DEVICE void lds_load_fence(index_t cnt = 0) -{ - asm volatile("s_waitcnt lgkmcnt(%0)" : : "n"(cnt) : "memory"); -} - -template -struct buffer_atomic_add_if; - -template -struct buffer_atomic_add_if -{ - template - CK_TILE_DEVICE void operator()(const T& value, - int32x4_t res /*buffer resource*/, - index_t v_offset, - index_t /*s_offset*/, - index_t i_offset /*max 0xFFF*/, - index_t flag = 1) - { - static_assert(sizeof(T) == 4); - auto save_exec = __builtin_amdgcn_read_exec(); - using mbuf_t = float; - asm volatile("v_cmpx_le_u32 exec, 1, %4\n" - "global_atomic_pk_add_bf16 %0, %1, %2 offset:%3\n" - "s_mov_b64 exec %5" - : - : "v"(v_offset), - "v"(bit_cast(value)), - "s"(res.xy), - "n"(i_offset), - "v"(flag), - "s"(save_exec) - : "memory"); - } -}; - -template -struct buffer_atomic_add; - -template -struct buffer_atomic_add -{ - template - CK_TILE_DEVICE void operator()(const T& value, - int32x4_t res /*buffer resource*/, - index_t v_offset, - index_t /*s_offset*/, - index_t i_offset /*max 0xFFF*/, - index_t /*flag = 1*/) - { - static_assert(sizeof(T) == 4); - using mbuf_t = float; - asm volatile("global_atomic_pk_add_bf16 %0, %1, %2 offset:%3" - : - : "v"(v_offset), "v"(bit_cast(value)), "s"(res.xy), 
"n"(i_offset) - : "memory"); - } -}; - -namespace impl { -// below type indicate the data type used for buffer load inline asm -// clang-format off -template struct smem_load_trait; - -template struct smem_load_trait<16, T> { using payload_t = fp32x4_t; }; -template struct smem_load_trait<8 , T> { using payload_t = fp32x2_t; }; -template struct smem_load_trait<4 , T> { using payload_t = float; }; -template struct smem_load_trait<2 , T> { using payload_t = float; }; -template struct smem_load_trait<1 , T> { using payload_t = float; }; - -// clang-format on -} // namespace impl - -// NOTE: smem load/store no need pre_nop to make sure dependency by sw, happy :) -template -struct smem_load; - -template <> -struct smem_load<16> -{ - template - CK_TILE_DEVICE void operator()(T& value, index_t v_offset, index_t i_offset) - { - static_assert(sizeof(T) == 16); - using mbuf_t = typename impl::smem_load_trait<16, T>::payload_t; - asm volatile("ds_read_b128 %0, %1 offset:%2" - : "=v"(reinterpret_cast(value)) // ! direct write - : "v"(v_offset), "n"(i_offset) - : "memory"); - } -}; - -template <> -struct smem_load<8> -{ - template - CK_TILE_DEVICE void operator()(T& value, index_t v_offset, index_t i_offset) - { - static_assert(sizeof(T) == 8); - using mbuf_t = typename impl::smem_load_trait<8, T>::payload_t; - asm volatile("ds_read_b64 %0, %1 offset:%2" - : "=v"(reinterpret_cast(value)) // ! direct write - : "v"(v_offset), "n"(i_offset) - : "memory"); - } -}; - -template <> -struct smem_load<4> -{ - template - CK_TILE_DEVICE void operator()(T& value, index_t v_offset, index_t i_offset) - { - static_assert(sizeof(T) == 4); - using mbuf_t = typename impl::smem_load_trait<4, T>::payload_t; - asm volatile("ds_read_b32 %0, %1 offset:%2" - : "=v"(reinterpret_cast(value)) // ! 
direct write - : "v"(v_offset), "n"(i_offset) - : "memory"); - } -}; - -template <> -struct smem_load<2> -{ - template - CK_TILE_DEVICE void operator()(T& value, index_t v_offset, index_t i_offset) - { - static_assert(sizeof(T) == 4); // subdword is buggy, use dword buf and convert manually - using mbuf_t = typename impl::smem_load_trait<1, T>::payload_t; - asm volatile("ds_read_u16 %0, %1 offset:%2" - : "=v"(reinterpret_cast(value)) // ! direct write - : "v"(v_offset), "n"(i_offset) - : "memory"); - } -}; - -template <> -struct smem_load<1> -{ - template - CK_TILE_DEVICE void operator()(T& value, index_t v_offset, index_t i_offset) - { - static_assert(sizeof(T) == 4); - using mbuf_t = typename impl::smem_load_trait<1, T>::payload_t; - asm volatile("ds_read_u8 %0, %1 offset:%2" - : "=v"(reinterpret_cast(value)) // ! direct write - : "v"(v_offset), "n"(i_offset) - : "memory"); - } -}; - -// clang-format off -namespace impl{ - -// can't use "+v" since there could be potential extra move(read/write) -// use "v" can help remove such duplicated moves -// besides, fake this as "memory" operation to force later valu after this fence -// TODO: may have scratch (because this is memory?) 
-// need to reduce extra move inside compiler -template -CK_TILE_DEVICE void insert_dummy_dep_per_dword(array& b) -{ - constexpr auto kSize = remove_cvref_t::size(); - static_for<0, kSize, 1>{}([&](auto i){ - asm volatile(" " : : "v"(b.get(number{})) : "memory"); - }); -} -#if 1 -// below specialization just merge size() of dwords into single section -template<> -CK_TILE_DEVICE void insert_dummy_dep_per_dword<2>(array& b) -{ - asm volatile(" " : : "v"(b.get(number<0>{})), "v"(b.get(number<1>{})) : "memory"); -} - -template<> -CK_TILE_DEVICE void insert_dummy_dep_per_dword<3>(array& b) -{ - asm volatile(" " : : "v"(b.get(number<0>{})), "v"(b.get(number<1>{})), "v"(b.get(number<2>{})) : "memory"); -} - -template<> -CK_TILE_DEVICE void insert_dummy_dep_per_dword<4>(array& b) -{ - asm volatile(" " : : "v"(b.get(number<0>{})), "v"(b.get(number<1>{})), "v"(b.get(number<2>{})), "v"(b.get(number<3>{})) : "memory"); -} - -template<> -CK_TILE_DEVICE void insert_dummy_dep_per_dword<8>(array& b) -{ - asm volatile(" " : : "v"(b.get(number<0>{})), "v"(b.get(number<1>{})), "v"(b.get(number<2>{})), "v"(b.get(number<3>{})), - "v"(b.get(number<4>{})), "v"(b.get(number<5>{})), "v"(b.get(number<6>{})), "v"(b.get(number<7>{})) : "memory"); -} - -template<> -CK_TILE_DEVICE void insert_dummy_dep_per_dword<16>(array& b) -{ - asm volatile(" " : : "v"(b.get(number<0>{})), "v"(b.get(number<1>{})), "v"(b.get(number<2>{})), "v"(b.get(number<3>{})), - "v"(b.get(number<4>{})), "v"(b.get(number<5>{})), "v"(b.get(number<6>{})), "v"(b.get(number<7>{})), - "v"(b.get(number<8>{})), "v"(b.get(number<9>{})), "v"(b.get(number<10>{})), "v"(b.get(number<11>{})), - "v"(b.get(number<12>{})), "v"(b.get(number<13>{})), "v"(b.get(number<14>{})), "v"(b.get(number<15>{})) : "memory"); -} - -template<> -CK_TILE_DEVICE void insert_dummy_dep_per_dword<32>(array& b) -{ - asm volatile(" " : : "v"(b.get(number<0>{})), "v"(b.get(number<1>{})), "v"(b.get(number<2>{})), "v"(b.get(number<3>{})), - "v"(b.get(number<4>{})), 
"v"(b.get(number<5>{})), "v"(b.get(number<6>{})), "v"(b.get(number<7>{})), - "v"(b.get(number<8>{})), "v"(b.get(number<9>{})), "v"(b.get(number<10>{})), "v"(b.get(number<11>{})), - "v"(b.get(number<12>{})), "v"(b.get(number<13>{})), "v"(b.get(number<14>{})), "v"(b.get(number<15>{})), - "v"(b.get(number<16>{})), "v"(b.get(number<17>{})), "v"(b.get(number<18>{})), "v"(b.get(number<19>{})), - "v"(b.get(number<20>{})), "v"(b.get(number<21>{})), "v"(b.get(number<22>{})), "v"(b.get(number<23>{})), - "v"(b.get(number<24>{})), "v"(b.get(number<25>{})), "v"(b.get(number<26>{})), "v"(b.get(number<27>{})), - "v"(b.get(number<28>{})), "v"(b.get(number<29>{})), "v"(b.get(number<30>{})), "v"(b.get(number<31>{})) : "memory"); -} -#endif -CK_TILE_DEVICE void insert_dummy_dep() {} - -template -CK_TILE_DEVICE void insert_dummy_dep(T & buffer) -{ - // TODO: indeed we expect T to be multiple of dword. subdword is always buggy - using da_type = array; - auto & dummy = reinterpret_cast(buffer); - insert_dummy_dep_per_dword(dummy); -} - -template -CK_TILE_DEVICE void insert_dummy_dep(Tx& bx, Ty&... by) -{ - insert_dummy_dep(bx); - insert_dummy_dep(by...); -} -} -// clang-format on -template -CK_TILE_DEVICE void buffer_load_fence(index_t cnt = 0, T&... 
o) -{ - asm volatile("s_waitcnt vmcnt(%0)" : : "n"(cnt) : "memory"); - impl::insert_dummy_dep(o...); -} - -CK_TILE_DEVICE void buffer_store_fence(index_t cnt = 0) -{ - asm volatile("s_waitcnt vmcnt(%0)" : : "n"(cnt) : "memory"); -} - -CK_TILE_DEVICE auto async_load_fence_raw(index_t cnt = 0) -{ - asm volatile("s_waitcnt vmcnt(%0)" : : "n"(cnt) : "memory"); -} - -// buffer load i8 -CK_TILE_DEVICE_EXTERN int8_t -llvm_amdgcn_raw_buffer_load_i8(int32x4_t srsrc, - index_t voffset, - index_t soffset, - index_t glc_slc) __asm("llvm.amdgcn.raw.buffer.load.i8.v4i32"); - -CK_TILE_DEVICE_EXTERN int8x2_t -llvm_amdgcn_raw_buffer_load_i8x2(int32x4_t srsrc, - index_t voffset, - index_t soffset, - index_t glc_slc) __asm("llvm.amdgcn.raw.buffer.load.v2i8.v4i32"); - -CK_TILE_DEVICE_EXTERN int8x4_t -llvm_amdgcn_raw_buffer_load_i8x4(int32x4_t srsrc, - index_t voffset, - index_t soffset, - index_t glc_slc) __asm("llvm.amdgcn.raw.buffer.load.v4i8.v4i32"); - -// buffer load i16 -CK_TILE_DEVICE_EXTERN int16_t -llvm_amdgcn_raw_buffer_load_i16(int32x4_t srsrc, - index_t voffset, - index_t soffset, - index_t glc_slc) __asm("llvm.amdgcn.raw.buffer.load.i16.v4i32"); - -CK_TILE_DEVICE_EXTERN int16x2_t -llvm_amdgcn_raw_buffer_load_i16x2(int32x4_t srsrc, - index_t voffset, - index_t soffset, - index_t glc_slc) __asm("llvm.amdgcn.raw.buffer.load.v2i16.v4i32"); - -CK_TILE_DEVICE_EXTERN int16x4_t -llvm_amdgcn_raw_buffer_load_i16x4(int32x4_t srsrc, - index_t voffset, - index_t soffset, - index_t glc_slc) __asm("llvm.amdgcn.raw.buffer.load.v4i16.v4i32"); - -// buffer load i32 -CK_TILE_DEVICE_EXTERN int32_t -llvm_amdgcn_raw_buffer_load_i32(int32x4_t srsrc, - index_t voffset, - index_t soffset, - index_t glc_slc) __asm("llvm.amdgcn.raw.buffer.load.i32.v4i32"); - -CK_TILE_DEVICE_EXTERN int32x2_t -llvm_amdgcn_raw_buffer_load_i32x2(int32x4_t srsrc, - index_t voffset, - index_t soffset, - index_t glc_slc) __asm("llvm.amdgcn.raw.buffer.load.v2i32.v4i32"); - -CK_TILE_DEVICE_EXTERN int32x4_t 
-llvm_amdgcn_raw_buffer_load_i32x4(int32x4_t srsrc, - index_t voffset, - index_t soffset, - index_t glc_slc) __asm("llvm.amdgcn.raw.buffer.load.v4i32.v4i32"); - -// buffer load fp16 -CK_TILE_DEVICE_EXTERN _Float16 -llvm_amdgcn_raw_buffer_load_fp16(int32x4_t srsrc, - index_t voffset, - index_t soffset, - index_t glc_slc) __asm("llvm.amdgcn.raw.buffer.load.f16.v4i32"); - -CK_TILE_DEVICE_EXTERN fp16x2_t llvm_amdgcn_raw_buffer_load_fp16x2( - int32x4_t srsrc, - index_t voffset, - index_t soffset, - index_t glc_slc) __asm("llvm.amdgcn.raw.buffer.load.v2f16.v4i32"); - -CK_TILE_DEVICE_EXTERN fp16x4_t llvm_amdgcn_raw_buffer_load_fp16x4( - int32x4_t srsrc, - index_t voffset, - index_t soffset, - index_t glc_slc) __asm("llvm.amdgcn.raw.buffer.load.v4f16.v4i32"); - -// buffer load fp32 -CK_TILE_DEVICE_EXTERN float -llvm_amdgcn_raw_buffer_load_fp32(int32x4_t srsrc, - index_t voffset, - index_t soffset, - index_t glc_slc) __asm("llvm.amdgcn.raw.buffer.load.f32.v4i32"); - -CK_TILE_DEVICE_EXTERN fp32x2_t llvm_amdgcn_raw_buffer_load_fp32x2( - int32x4_t srsrc, - index_t voffset, - index_t soffset, - index_t glc_slc) __asm("llvm.amdgcn.raw.buffer.load.v2f32.v4i32"); - -CK_TILE_DEVICE_EXTERN fp32x4_t llvm_amdgcn_raw_buffer_load_fp32x4( - int32x4_t srsrc, - index_t voffset, - index_t soffset, - index_t glc_slc) __asm("llvm.amdgcn.raw.buffer.load.v4f32.v4i32"); - -// buffer store i8 -CK_TILE_DEVICE_EXTERN void -llvm_amdgcn_raw_buffer_store_i8(int8_t vdata, - int32x4_t rsrc, - index_t voffset, - index_t soffset, - index_t glc_slc) __asm("llvm.amdgcn.raw.buffer.store.i8.v4i32"); - -CK_TILE_DEVICE_EXTERN void -llvm_amdgcn_raw_buffer_store_i8x2(int8x2_t vdata, - int32x4_t rsrc, - index_t voffset, - index_t soffset, - index_t glc_slc) __asm("llvm.amdgcn.raw.buffer.store.v2i8.v4i32"); - -CK_TILE_DEVICE_EXTERN void -llvm_amdgcn_raw_buffer_store_i8x4(int8x4_t vdata, - int32x4_t rsrc, - index_t voffset, - index_t soffset, - index_t glc_slc) __asm("llvm.amdgcn.raw.buffer.store.v4i8.v4i32"); - -// 
buffer store i16 -CK_TILE_DEVICE_EXTERN void -llvm_amdgcn_raw_buffer_store_i16(int16_t vdata, - int32x4_t rsrc, - index_t voffset, - index_t soffset, - index_t glc_slc) __asm("llvm.amdgcn.raw.buffer.store.i16.v4i32"); - -CK_TILE_DEVICE_EXTERN void llvm_amdgcn_raw_buffer_store_i16x2( - int16x2_t vdata, - int32x4_t rsrc, - index_t voffset, - index_t soffset, - index_t glc_slc) __asm("llvm.amdgcn.raw.buffer.store.v2i16.v4i32"); - -CK_TILE_DEVICE_EXTERN void llvm_amdgcn_raw_buffer_store_i16x4( - int16x4_t vdata, - int32x4_t rsrc, - index_t voffset, - index_t soffset, - index_t glc_slc) __asm("llvm.amdgcn.raw.buffer.store.v4i16.v4i32"); - -// buffer store i32 -CK_TILE_DEVICE_EXTERN void -llvm_amdgcn_raw_buffer_store_i32(int32_t vdata, - int32x4_t rsrc, - index_t voffset, - index_t soffset, - index_t glc_slc) __asm("llvm.amdgcn.raw.buffer.store.i32.v4i32"); - -// buffer store ui16 -CK_TILE_DEVICE_EXTERN void -llvm_amdgcn_raw_buffer_store_ui16(uint16_t vdata, - int32x4_t rsrc, - index_t voffset, - index_t soffset, - index_t glc_slc) __asm("llvm.amdgcn.raw.buffer.store.i16.v4i32"); - -CK_TILE_DEVICE_EXTERN void llvm_amdgcn_raw_buffer_store_ui16x2( - uint16x2_t vdata, - int32x4_t rsrc, - index_t voffset, - index_t soffset, - index_t glc_slc) __asm("llvm.amdgcn.raw.buffer.store.v2i16.v4i32"); - -CK_TILE_DEVICE_EXTERN void llvm_amdgcn_raw_buffer_store_ui16x4( - uint16x4_t vdata, - int32x4_t rsrc, - index_t voffset, - index_t soffset, - index_t glc_slc) __asm("llvm.amdgcn.raw.buffer.store.v4i16.v4i32"); - -CK_TILE_DEVICE_EXTERN void llvm_amdgcn_raw_buffer_store_i32x2( - int32x2_t vdata, - int32x4_t rsrc, - index_t voffset, - index_t soffset, - index_t glc_slc) __asm("llvm.amdgcn.raw.buffer.store.v2i32.v4i32"); - -CK_TILE_DEVICE_EXTERN void llvm_amdgcn_raw_buffer_store_i32x4( - int32x4_t vdata, - int32x4_t rsrc, - index_t voffset, - index_t soffset, - index_t glc_slc) __asm("llvm.amdgcn.raw.buffer.store.v4i32.v4i32"); - -// buffer store fp16 -CK_TILE_DEVICE_EXTERN void 
-llvm_amdgcn_raw_buffer_store_fp16(_Float16 vdata, - int32x4_t rsrc, - index_t voffset, - index_t soffset, - index_t glc_slc) __asm("llvm.amdgcn.raw.buffer.store.f16.v4i32"); - -CK_TILE_DEVICE_EXTERN void llvm_amdgcn_raw_buffer_store_fp16x2( - fp16x2_t vdata, - int32x4_t rsrc, - index_t voffset, - index_t soffset, - index_t glc_slc) __asm("llvm.amdgcn.raw.buffer.store.v2f16.v4i32"); - -CK_TILE_DEVICE_EXTERN void llvm_amdgcn_raw_buffer_store_fp16x4( - fp16x4_t vdata, - int32x4_t rsrc, - index_t voffset, - index_t soffset, - index_t glc_slc) __asm("llvm.amdgcn.raw.buffer.store.v4f16.v4i32"); - -// buffer store fp32 -CK_TILE_DEVICE_EXTERN void -llvm_amdgcn_raw_buffer_store_fp32(float vdata, - int32x4_t rsrc, - index_t voffset, - index_t soffset, - index_t glc_slc) __asm("llvm.amdgcn.raw.buffer.store.f32.v4i32"); - -CK_TILE_DEVICE_EXTERN void llvm_amdgcn_raw_buffer_store_fp32x2( - fp32x2_t vdata, - int32x4_t rsrc, - index_t voffset, - index_t soffset, - index_t glc_slc) __asm("llvm.amdgcn.raw.buffer.store.v2f32.v4i32"); - -CK_TILE_DEVICE_EXTERN void llvm_amdgcn_raw_buffer_store_fp32x4( - fp32x4_t vdata, - int32x4_t rsrc, - index_t voffset, - index_t soffset, - index_t glc_slc) __asm("llvm.amdgcn.raw.buffer.store.v4f32.v4i32"); - -// buffer atomic-add fp16 -CK_TILE_DEVICE_EXTERN fp16x2_t llvm_amdgcn_raw_buffer_atomic_add_fp16x2( - fp16x2_t vdata, - int32x4_t rsrc, - index_t voffset, - index_t soffset, - index_t glc_slc) __asm("llvm.amdgcn.raw.buffer.atomic.fadd.v2f16.v4i32"); - -// buffer atomic-add i32 -CK_TILE_DEVICE_EXTERN int32_t llvm_amdgcn_raw_buffer_atomic_add_i32( - int32_t vdata, - int32x4_t rsrc, - index_t voffset, - index_t soffset, - index_t glc_slc) __asm("llvm.amdgcn.raw.buffer.atomic.add.i32.v4i32"); - -// buffer atomic-add fp32 -CK_TILE_DEVICE_EXTERN float llvm_amdgcn_raw_buffer_atomic_add_fp32( - float vdata, - int32x4_t rsrc, - index_t voffset, - index_t soffset, - index_t glc_slc) __asm("llvm.amdgcn.raw.buffer.atomic.fadd.f32.v4i32"); - -// buffer 
atomic-max fp64 -CK_TILE_DEVICE_EXTERN double llvm_amdgcn_raw_buffer_atomic_max_fp64( - double vdata, - int32x4_t rsrc, // dst_wave_buffer_resource - int voffset, // dst_thread_addr_offset - int soffset, // dst_wave_addr_offset - int glc_slc) __asm("llvm.amdgcn.raw.buffer.atomic.fmax.f64.v4i32"); - -// Direct loads from global to LDS. -CK_TILE_DEVICE_EXTERN void -llvm_amdgcn_raw_buffer_load_lds(int32x4_t rsrc, - __attribute__((address_space(3))) uint32_t* lds_ptr, - index_t size, - index_t voffset, - index_t soffset, - index_t offset, - index_t aux) __asm("llvm.amdgcn.raw.buffer.load.lds.v4i32"); - -template -CK_TILE_DEVICE void async_buffer_load_dword_v(void* smem, - int32x4_t rsrc, - index_t voffset, - index_t /*soffset*/, - index_t ioffset /*max 0xFFF*/, - index_t /*flag*/ = 0, - bool_constant = {}) -{ - if constexpr(pre_nop) - asm volatile("s_nop 4\n" - "buffer_load_dword %1, %2, 0 offen offset:%3 lds" - : "=r"(smem) /*dummy dependency for smem*/ - : "v"(voffset), "s"(rsrc), "n"(ioffset) - : "memory"); - else - asm volatile("buffer_load_dword %1, %2, 0 offen offset:%3 lds" - : "=r"(smem) /*dummy dependency for smem*/ - : "v"(voffset), "s"(rsrc), "n"(ioffset) - : "memory"); -} - -CK_TILE_DEVICE void async_buffer_load_fence(index_t cnt = 0) -{ - asm volatile("s_waitcnt vmcnt(%0)" : : "n"(cnt) : "memory"); -} - -// memory coherency bit for buffer store/load instruction -// check ISA manual for each GFX target -// e.g. for -// https://www.amd.com/system/files/TechDocs/instinct-mi200-cdna2-instruction-set-architecture.pdf, -// page 67~68 -enum struct amd_buffer_coherence_enum -{ - coherence_default = 0, // default value - glc = 1, - slc = 2, - glc_slc = 3, -}; - -template -CK_TILE_DEVICE thread_buffer -amd_buffer_load_impl_with_bytes(int32x4_t src_wave_buffer_resource, - index_t src_thread_addr_offset, - index_t src_wave_addr_offset) -{ - static_assert(N == 1 || N == 2 || N == 4 || N == 8 || N == 16 || N == 32 || N == 64, - "wrong! 
not implemented"); - - using rtn_type = thread_buffer; - - if constexpr(N == 1) - { - return bit_cast(llvm_amdgcn_raw_buffer_load_i8(src_wave_buffer_resource, - src_thread_addr_offset, - src_wave_addr_offset, - static_cast(coherence))); - } - else if constexpr(N == 2) - { - - int16_t tmp = llvm_amdgcn_raw_buffer_load_i16(src_wave_buffer_resource, - src_thread_addr_offset, - src_wave_addr_offset, - static_cast(coherence)); - - return bit_cast(tmp); - } - else if constexpr(N == 4) - { - int32_t tmp = llvm_amdgcn_raw_buffer_load_i32(src_wave_buffer_resource, - src_thread_addr_offset, - src_wave_addr_offset, - static_cast(coherence)); - - return bit_cast(tmp); - } - else if constexpr(N == 8) - { - int32x2_t tmp = llvm_amdgcn_raw_buffer_load_i32x2(src_wave_buffer_resource, - src_thread_addr_offset, - src_wave_addr_offset, - static_cast(coherence)); - - return bit_cast(tmp); - } - else if constexpr(N == 16) - { - int32x4_t tmp = llvm_amdgcn_raw_buffer_load_i32x4(src_wave_buffer_resource, - src_thread_addr_offset, - src_wave_addr_offset, - static_cast(coherence)); - return bit_cast(tmp); - } - else if constexpr(N == 32) - { - int32x4_t tmp0 = llvm_amdgcn_raw_buffer_load_i32x4(src_wave_buffer_resource, - src_thread_addr_offset, - src_wave_addr_offset, - static_cast(coherence)); - int32x4_t tmp1 = - llvm_amdgcn_raw_buffer_load_i32x4(src_wave_buffer_resource, - src_thread_addr_offset, - src_wave_addr_offset + 4 * sizeof(int32_t), - static_cast(coherence)); - thread_buffer tmp; - - tmp.template get_as()(number<0>{}) = tmp0; - tmp.template get_as()(number<1>{}) = tmp1; - - return bit_cast(tmp); - } - else if constexpr(N == 64) - { - int32x4_t tmp0 = llvm_amdgcn_raw_buffer_load_i32x4(src_wave_buffer_resource, - src_thread_addr_offset, - src_wave_addr_offset, - static_cast(coherence)); - int32x4_t tmp1 = - llvm_amdgcn_raw_buffer_load_i32x4(src_wave_buffer_resource, - src_thread_addr_offset, - src_wave_addr_offset + 4 * sizeof(int32_t), - static_cast(coherence)); - int32x4_t tmp2 
= - llvm_amdgcn_raw_buffer_load_i32x4(src_wave_buffer_resource, - src_thread_addr_offset, - src_wave_addr_offset + 8 * sizeof(int32_t), - static_cast(coherence)); - int32x4_t tmp3 = - llvm_amdgcn_raw_buffer_load_i32x4(src_wave_buffer_resource, - src_thread_addr_offset, - src_wave_addr_offset + 12 * sizeof(int32_t), - static_cast(coherence)); - - thread_buffer tmp; - - tmp.template get_as()(number<0>{}) = tmp0; - tmp.template get_as()(number<1>{}) = tmp1; - tmp.template get_as()(number<2>{}) = tmp2; - tmp.template get_as()(number<3>{}) = tmp3; - - return bit_cast(tmp); - } -} - -#ifndef BUFFER_LOAD_USE_INLINEASM -#define BUFFER_LOAD_USE_INLINEASM 0 -#endif - -template -CK_TILE_DEVICE thread_buffer amd_buffer_load_impl(int32x4_t src_wave_buffer_resource, - index_t src_thread_addr_offset, - index_t src_wave_addr_offset) -{ - static_assert( - (std::is_same::value && (N == 1 || N == 2 || N == 4 || N == 8)) || - (std::is_same::value && (N == 1 || N == 2 || N == 4 || N == 8 || N == 16)) || - (std::is_same::value && (N == 1 || N == 2 || N == 4 || N == 8)) || - (std::is_same::value && (N == 1 || N == 2 || N == 4 || N == 8)) || - (std::is_same::value && - (N == 1 || N == 2 || N == 4 || N == 8 || N == 16)) || - (std::is_same::value && (N == 1 || N == 2 || N == 4 || N == 8 || N == 16)) || - (std::is_same::value && (N == 1 || N == 2 || N == 4 || N == 8 || N == 16)) || - (std::is_same::value && (N == 1 || N == 2 || N == 4 || N == 8 || N == 16)) || - (std::is_same::value && - (N == 1 || N == 2 || N == 4 || N == 8 || N == 16 || N == 32)), - "wrong! 
not implemented"); - - using rtn_type = thread_buffer; - - if constexpr(std::is_same::value) // fp32 - { - if constexpr(N == 1) - { - return bit_cast( - llvm_amdgcn_raw_buffer_load_fp32(src_wave_buffer_resource, - src_thread_addr_offset, - src_wave_addr_offset, - static_cast(coherence))); - } - else if constexpr(N == 2) - { - return bit_cast( - llvm_amdgcn_raw_buffer_load_fp32x2(src_wave_buffer_resource, - src_thread_addr_offset, - src_wave_addr_offset, - static_cast(coherence))); - } - else if constexpr(N == 4) - { - return bit_cast( - llvm_amdgcn_raw_buffer_load_fp32x4(src_wave_buffer_resource, - src_thread_addr_offset, - src_wave_addr_offset, - static_cast(coherence))); - } - else if constexpr(N == 8) - { - thread_buffer tmp; - - tmp.template get_as()(number<0>{}) = - llvm_amdgcn_raw_buffer_load_fp32x4(src_wave_buffer_resource, - src_thread_addr_offset, - src_wave_addr_offset, - static_cast(coherence)); - - tmp.template get_as()(number<1>{}) = - llvm_amdgcn_raw_buffer_load_fp32x4(src_wave_buffer_resource, - src_thread_addr_offset, - src_wave_addr_offset + 4 * sizeof(float), - static_cast(coherence)); - - return tmp; - } - else if constexpr(N == 16) - { - thread_buffer tmp; - - tmp.template get_as()(number<0>{}) = - llvm_amdgcn_raw_buffer_load_fp32x4(src_wave_buffer_resource, - src_thread_addr_offset, - src_wave_addr_offset, - static_cast(coherence)); - - tmp.template get_as()(number<1>{}) = - llvm_amdgcn_raw_buffer_load_fp32x4(src_wave_buffer_resource, - src_thread_addr_offset, - src_wave_addr_offset + 4 * sizeof(float), - static_cast(coherence)); - - tmp.template get_as()(number<2>{}) = - llvm_amdgcn_raw_buffer_load_fp32x4(src_wave_buffer_resource, - src_thread_addr_offset, - src_wave_addr_offset + 8 * sizeof(float), - static_cast(coherence)); - - tmp.template get_as()(number<3>{}) = - llvm_amdgcn_raw_buffer_load_fp32x4(src_wave_buffer_resource, - src_thread_addr_offset, - src_wave_addr_offset + 12 * sizeof(float), - static_cast(coherence)); - - return tmp; - } 
- } - else if constexpr(std::is_same::value) // fp16 - { - if constexpr(N == 1) - { - return bit_cast( - llvm_amdgcn_raw_buffer_load_fp16(src_wave_buffer_resource, - src_thread_addr_offset, - src_wave_addr_offset, - static_cast(coherence))); - } - else if constexpr(N == 2) - { - return bit_cast( - llvm_amdgcn_raw_buffer_load_fp16x2(src_wave_buffer_resource, - src_thread_addr_offset, - src_wave_addr_offset, - static_cast(coherence))); - } - else if constexpr(N == 4) - { - return bit_cast( - llvm_amdgcn_raw_buffer_load_fp16x4(src_wave_buffer_resource, - src_thread_addr_offset, - src_wave_addr_offset, - static_cast(coherence))); - } - else if constexpr(N == 8) - { - // use fp32 load to mimic fp16 load - fp32x4_t tmp = llvm_amdgcn_raw_buffer_load_fp32x4(src_wave_buffer_resource, - src_thread_addr_offset, - src_wave_addr_offset, - static_cast(coherence)); - - return bit_cast(tmp); - } - } - else if constexpr(std::is_same::value) // bf16 - { - if constexpr(N == 1) - { - return bit_cast( - llvm_amdgcn_raw_buffer_load_i16(src_wave_buffer_resource, - src_thread_addr_offset, - src_wave_addr_offset, - static_cast(coherence))); - } - else if constexpr(N == 2) - { - return bit_cast( - llvm_amdgcn_raw_buffer_load_i16x2(src_wave_buffer_resource, - src_thread_addr_offset, - src_wave_addr_offset, - static_cast(coherence))); - } - else if constexpr(N == 4) - { - return bit_cast( - llvm_amdgcn_raw_buffer_load_i16x4(src_wave_buffer_resource, - src_thread_addr_offset, - src_wave_addr_offset, - static_cast(coherence))); - } - else if constexpr(N == 8) - { - int32x4_t tmp = llvm_amdgcn_raw_buffer_load_i32x4(src_wave_buffer_resource, - src_thread_addr_offset, - src_wave_addr_offset, - static_cast(coherence)); - - return bit_cast(tmp); - } - } - else // other datatype - { - auto raw_data = amd_buffer_load_impl_with_bytes( - src_wave_buffer_resource, src_thread_addr_offset, src_wave_addr_offset); - - return bit_cast(raw_data); - } -} - -template -CK_TILE_DEVICE void 
amd_buffer_load_raw_impl(thread_buffer& dst, - int32x4_t src_wave_buffer_resource, - index_t src_thread_addr_offset, - index_t src_wave_addr_offset, - index_t src_linear_addr_offset, - index_t flag = 0, - bool_constant = {}) -{ - constexpr index_t bytes = sizeof(T) * N; - static_assert(bytes == 1 || bytes == 2 || bytes == 4 || bytes == 8 || bytes == 16, - "wrong! not supported by buffer_load instruction"); - - using type = thread_buffer; - if constexpr(oob_conditional_check) - { - buffer_load_if{}(dst, - src_wave_buffer_resource, - src_thread_addr_offset, - src_wave_addr_offset, - src_linear_addr_offset, - flag, - bool_constant{}); - } - else - { - buffer_load{}(dst, - src_wave_buffer_resource, - src_thread_addr_offset, - src_wave_addr_offset, - src_linear_addr_offset, - flag, - bool_constant{}); - } -} - -template -CK_TILE_DEVICE void amd_async_buffer_load_impl(T* smem, - int32x4_t src_wave_buffer_resource, - index_t src_thread_addr_offset, - index_t src_wave_addr_offset, - index_t src_immediate_addr_offset = 0, - bool_constant = {}) -{ - static_assert(sizeof(T) * N == 4, "wrong! not implemented vector size"); - - async_buffer_load_dword_v(smem, - src_wave_buffer_resource, - src_thread_addr_offset, - src_wave_addr_offset, - src_immediate_addr_offset, - 0, - bool_constant{}); -} - -template -CK_TILE_DEVICE void amd_async_buffer_load(CK_TILE_LDS_ADDR T* smem, - int32x4_t src_wave_buffer_resource, - index_t src_thread_addr_offset, - index_t src_wave_addr_offset, - index_t src_immediate_addr_offset = 0, - index_t flag = 0, - bool_constant = {}) -{ - static_assert(sizeof(T) * N == 4, "wrong! not implemented vector size"); - - if constexpr(oob_conditional_check) - { - index_t v_offset = flag ? 
v_offset : src_wave_buffer_resource[2]; - llvm_amdgcn_raw_buffer_load_lds(src_wave_buffer_resource, - smem, - sizeof(uint32_t), - v_offset, - src_wave_addr_offset, - src_immediate_addr_offset, - static_cast(coherence)); - } - else - { - llvm_amdgcn_raw_buffer_load_lds(src_wave_buffer_resource, - smem, - sizeof(uint32_t), - src_thread_addr_offset, - src_wave_addr_offset, - src_immediate_addr_offset, - static_cast(coherence)); - } -} - -template -CK_TILE_DEVICE void amd_buffer_store_impl_with_bytes(const thread_buffer src_thread_data, - int32x4_t dst_wave_buffer_resource, - index_t dst_thread_addr_offset, - index_t dst_wave_addr_offset) -{ - static_assert(N == 1 || N == 2 || N == 4 || N == 8 || N == 16 || N == 32 || N == 64, - "wrong! not implemented"); - - if constexpr(N == 1) - { - llvm_amdgcn_raw_buffer_store_i8(bit_cast(src_thread_data), - dst_wave_buffer_resource, - dst_thread_addr_offset, - dst_wave_addr_offset, - static_cast(coherence)); - } - else if constexpr(N == 2) - { - - llvm_amdgcn_raw_buffer_store_i16(bit_cast(src_thread_data), - dst_wave_buffer_resource, - dst_thread_addr_offset, - dst_wave_addr_offset, - static_cast(coherence)); - } - else if constexpr(N == 4) - { - llvm_amdgcn_raw_buffer_store_i32(bit_cast(src_thread_data), - dst_wave_buffer_resource, - dst_thread_addr_offset, - dst_wave_addr_offset, - static_cast(coherence)); - } - else if constexpr(N == 8) - { - llvm_amdgcn_raw_buffer_store_i32x2(bit_cast(src_thread_data), - dst_wave_buffer_resource, - dst_thread_addr_offset, - dst_wave_addr_offset, - static_cast(coherence)); - } - else if constexpr(N == 16) - { - llvm_amdgcn_raw_buffer_store_i32x4(bit_cast(src_thread_data), - dst_wave_buffer_resource, - dst_thread_addr_offset, - dst_wave_addr_offset, - static_cast(coherence)); - } - else if constexpr(N == 32) - { - llvm_amdgcn_raw_buffer_store_i32x4( - src_thread_data.template get_as()[number<0>{}], - dst_wave_buffer_resource, - dst_thread_addr_offset, - dst_wave_addr_offset, - 
static_cast(coherence)); - - llvm_amdgcn_raw_buffer_store_i32x4( - src_thread_data.template get_as()[number<1>{}], - dst_wave_buffer_resource, - dst_thread_addr_offset, - dst_wave_addr_offset + sizeof(int32_t) * 4, - static_cast(coherence)); - } - else if constexpr(N == 64) - { - llvm_amdgcn_raw_buffer_store_i32x4( - src_thread_data.template get_as()[number<0>{}], - dst_wave_buffer_resource, - dst_thread_addr_offset, - dst_wave_addr_offset, - static_cast(coherence)); - - llvm_amdgcn_raw_buffer_store_i32x4( - src_thread_data.template get_as()[number<1>{}], - dst_wave_buffer_resource, - dst_thread_addr_offset, - dst_wave_addr_offset + sizeof(int32_t) * 4, - static_cast(coherence)); - - llvm_amdgcn_raw_buffer_store_i32x4( - src_thread_data.template get_as()[number<2>{}], - dst_wave_buffer_resource, - dst_thread_addr_offset, - dst_wave_addr_offset + sizeof(int32_t) * 8, - static_cast(coherence)); - - llvm_amdgcn_raw_buffer_store_i32x4( - src_thread_data.template get_as()[number<3>{}], - dst_wave_buffer_resource, - dst_thread_addr_offset, - dst_wave_addr_offset + sizeof(int32_t) * 12, - static_cast(coherence)); - } -} - -template -CK_TILE_DEVICE void amd_buffer_store_impl(const thread_buffer src_thread_data, - int32x4_t dst_wave_buffer_resource, - index_t dst_thread_addr_offset, - index_t dst_wave_addr_offset) -{ - static_assert( - (std::is_same::value && (N == 1 || N == 2 || N == 4 || N == 8)) || - (std::is_same::value && (N == 1 || N == 2 || N == 4 || N == 8 || N == 16)) || - (std::is_same::value && (N == 1 || N == 2 || N == 4 || N == 8 || N == 16)) || - (std::is_same::value && (N == 1 || N == 2 || N == 4 || N == 8 || N == 16)) || - (std::is_same::value && - (N == 1 || N == 2 || N == 4 || N == 8 || N == 16)) || - (std::is_same::value && (N == 1 || N == 2 || N == 4 || N == 8 || N == 16)) || - (std::is_same::value && (N == 1 || N == 2 || N == 4 || N == 8 || N == 16)) || - (std::is_same::value && (N == 1 || N == 2 || N == 4 || N == 8 || N == 16)) || - 
(std::is_same::value && - (N == 1 || N == 2 || N == 4 || N == 8 || N == 16)) || - (std::is_same::value && (N == 1 || N == 2 || N == 4 || N == 8 || N == 16)), - "wrong! not implemented"); - - if constexpr(std::is_same::value) // fp32 - { - if constexpr(N == 1) - { - llvm_amdgcn_raw_buffer_store_fp32(bit_cast(src_thread_data), - dst_wave_buffer_resource, - dst_thread_addr_offset, - dst_wave_addr_offset, - static_cast(coherence)); - } - else if constexpr(N == 2) - { - llvm_amdgcn_raw_buffer_store_fp32x2(bit_cast(src_thread_data), - dst_wave_buffer_resource, - dst_thread_addr_offset, - dst_wave_addr_offset, - static_cast(coherence)); - } - else if constexpr(N == 4) - { - llvm_amdgcn_raw_buffer_store_fp32x4(bit_cast(src_thread_data), - dst_wave_buffer_resource, - dst_thread_addr_offset, - dst_wave_addr_offset, - static_cast(coherence)); - } - else if constexpr(N == 8) - { - llvm_amdgcn_raw_buffer_store_fp32x4( - src_thread_data.template get_as()[number<0>{}], - dst_wave_buffer_resource, - dst_thread_addr_offset, - dst_wave_addr_offset, - static_cast(coherence)); - llvm_amdgcn_raw_buffer_store_fp32x4( - src_thread_data.template get_as()[number<1>{}], - dst_wave_buffer_resource, - dst_thread_addr_offset, - dst_wave_addr_offset + 4 * sizeof(float), - static_cast(coherence)); - } - } - else if constexpr(std::is_same::value) // fp16 - { - if constexpr(N == 1) - { - llvm_amdgcn_raw_buffer_store_fp16(bit_cast<_Float16>(src_thread_data), - dst_wave_buffer_resource, - dst_thread_addr_offset, - dst_wave_addr_offset, - static_cast(coherence)); - } - else if constexpr(N == 2) - { - llvm_amdgcn_raw_buffer_store_fp16x2(bit_cast(src_thread_data), - dst_wave_buffer_resource, - dst_thread_addr_offset, - dst_wave_addr_offset, - static_cast(coherence)); - } - else if constexpr(N == 4) - { - llvm_amdgcn_raw_buffer_store_fp16x4(bit_cast(src_thread_data), - dst_wave_buffer_resource, - dst_thread_addr_offset, - dst_wave_addr_offset, - static_cast(coherence)); - } - else if constexpr(N == 8) - 
{ -#if 0 - thread_buffer tmp{src_thread_data}; - - llvm_amdgcn_raw_buffer_store_fp16x4(tmp.template get_as()[number<0>{}], - dst_wave_buffer_resource, - dst_thread_addr_offset, - dst_wave_addr_offset, - static_cast(coherence)); - - llvm_amdgcn_raw_buffer_store_fp16x4(tmp.template get_as()[number<1>{}], - dst_wave_buffer_resource, - dst_thread_addr_offset, - dst_wave_addr_offset + 4 * sizeof(fp16_t), - static_cast(coherence)); -#else - llvm_amdgcn_raw_buffer_store_fp32x4(bit_cast(src_thread_data), - dst_wave_buffer_resource, - dst_thread_addr_offset, - dst_wave_addr_offset, - static_cast(coherence)); -#endif - } - } - else if constexpr(std::is_same::value) // bf16 - { - if constexpr(N == 1) - { - llvm_amdgcn_raw_buffer_store_i16(bit_cast(src_thread_data), - dst_wave_buffer_resource, - dst_thread_addr_offset, - dst_wave_addr_offset, - static_cast(coherence)); - } - else if constexpr(N == 2) - { - llvm_amdgcn_raw_buffer_store_i16x2(bit_cast(src_thread_data), - dst_wave_buffer_resource, - dst_thread_addr_offset, - dst_wave_addr_offset, - static_cast(coherence)); - } - else if constexpr(N == 4) - { - llvm_amdgcn_raw_buffer_store_i16x4(bit_cast(src_thread_data), - dst_wave_buffer_resource, - dst_thread_addr_offset, - dst_wave_addr_offset, - static_cast(coherence)); - } - else if constexpr(N == 8) - { - llvm_amdgcn_raw_buffer_store_i16x4( - src_thread_data.template get_as()[number<0>{}], - dst_wave_buffer_resource, - dst_thread_addr_offset, - dst_wave_addr_offset, - static_cast(coherence)); - - llvm_amdgcn_raw_buffer_store_i16x4( - src_thread_data.template get_as()[number<1>{}], - dst_wave_buffer_resource, - dst_thread_addr_offset, - dst_wave_addr_offset + 4 * sizeof(bf16_t), - static_cast(coherence)); - } - } - else if constexpr(std::is_same::value) - { - if constexpr(N == 1) - { - llvm_amdgcn_raw_buffer_store_ui16(bit_cast(src_thread_data), - dst_wave_buffer_resource, - dst_thread_addr_offset, - dst_wave_addr_offset, - static_cast(coherence)); - } - else if constexpr(N 
== 2) - { - llvm_amdgcn_raw_buffer_store_ui16x2(bit_cast(src_thread_data), - dst_wave_buffer_resource, - dst_thread_addr_offset, - dst_wave_addr_offset, - static_cast(coherence)); - } - else if constexpr(N == 4) - { - llvm_amdgcn_raw_buffer_store_ui16x4(bit_cast(src_thread_data), - dst_wave_buffer_resource, - dst_thread_addr_offset, - dst_wave_addr_offset, - static_cast(coherence)); - } - else if constexpr(N == 8) - { - llvm_amdgcn_raw_buffer_store_ui16x4( - src_thread_data.template get_as()[number<0>{}], - dst_wave_buffer_resource, - dst_thread_addr_offset, - dst_wave_addr_offset, - static_cast(coherence)); - - llvm_amdgcn_raw_buffer_store_ui16x4( - src_thread_data.template get_as()[number<1>{}], - dst_wave_buffer_resource, - dst_thread_addr_offset, - dst_wave_addr_offset + 4 * sizeof(uint16_t), - static_cast(coherence)); - } - } - else - { - using r_t = thread_buffer; - - amd_buffer_store_impl_with_bytes(bit_cast(src_thread_data), - dst_wave_buffer_resource, - dst_thread_addr_offset, - dst_wave_addr_offset); - } -} - -template -CK_TILE_DEVICE void amd_buffer_store_raw_impl(const thread_buffer& dst_thread_data, - int32x4_t dst_wave_buffer_resource, - index_t dst_thread_addr_offset, - index_t dst_wave_addr_offset, - index_t dst_linear_addr_offset, - index_t is_valid_element = 1) -{ - constexpr index_t bytes = sizeof(T) * N; - static_assert(bytes == 1 || bytes == 2 || bytes == 4 || bytes == 8 || bytes == 16, - "wrong! 
not supported by buffer_store instruction"); - - using type = thread_buffer; - if constexpr(oob_conditional_check) - { - buffer_store_if{}(dst_thread_data, - dst_wave_buffer_resource, - dst_thread_addr_offset, - dst_wave_addr_offset, - dst_linear_addr_offset, - is_valid_element); - } - else - { - buffer_store{}(dst_thread_data, - dst_wave_buffer_resource, - dst_thread_addr_offset, - dst_wave_addr_offset, - dst_linear_addr_offset); - } -} - -template -CK_TILE_DEVICE void amd_buffer_atomic_add_impl(const thread_buffer& src_thread_data, - int32x4_t dst_wave_buffer_resource, - index_t dst_thread_addr_offset, - index_t dst_wave_addr_offset) -{ - static_assert((std::is_same::value && (N == 1 || N == 2 || N == 4)) || - (std::is_same::value && (N == 2 || N == 4 || N == 8)) || - (std::is_same::value && (N == 1 || N == 2 || N == 4)), - "wrong! not implemented"); - - if constexpr(std::is_same::value) - { - if constexpr(N == 1) - { - llvm_amdgcn_raw_buffer_atomic_add_fp32(bit_cast(src_thread_data), - dst_wave_buffer_resource, - dst_thread_addr_offset, - dst_wave_addr_offset, - 0); - } - else if constexpr(N == 2) - { - llvm_amdgcn_raw_buffer_atomic_add_fp32( - src_thread_data.template get_as()[number<0>{}], - dst_wave_buffer_resource, - dst_thread_addr_offset, - dst_wave_addr_offset, - 0); - - llvm_amdgcn_raw_buffer_atomic_add_fp32( - src_thread_data.template get_as()[number<1>{}], - dst_wave_buffer_resource, - dst_thread_addr_offset, - dst_wave_addr_offset + sizeof(float), - 0); - } - else if constexpr(N == 4) - { - llvm_amdgcn_raw_buffer_atomic_add_fp32( - src_thread_data.template get_as()[number<0>{}], - dst_wave_buffer_resource, - dst_thread_addr_offset, - dst_wave_addr_offset, - 0); - - llvm_amdgcn_raw_buffer_atomic_add_fp32( - src_thread_data.template get_as()[number<1>{}], - dst_wave_buffer_resource, - dst_thread_addr_offset, - dst_wave_addr_offset + sizeof(float), - 0); - - llvm_amdgcn_raw_buffer_atomic_add_fp32( - src_thread_data.template get_as()[number<2>{}], - 
dst_wave_buffer_resource, - dst_thread_addr_offset, - dst_wave_addr_offset + 2 * sizeof(float), - 0); - - llvm_amdgcn_raw_buffer_atomic_add_fp32( - src_thread_data.template get_as()[number<3>{}], - dst_wave_buffer_resource, - dst_thread_addr_offset, - dst_wave_addr_offset + 3 * sizeof(float), - 0); - } - } - else if constexpr(std::is_same::value) - { - if constexpr(N == 2) - { - llvm_amdgcn_raw_buffer_atomic_add_fp16x2(bit_cast(src_thread_data), - dst_wave_buffer_resource, - dst_thread_addr_offset, - dst_wave_addr_offset, - 0); - } - else if constexpr(N == 4) - { - static_for<0, 2, 1>{}([&](auto i) { - llvm_amdgcn_raw_buffer_atomic_add_fp16x2( - src_thread_data.template get_as()[i], - dst_wave_buffer_resource, - dst_thread_addr_offset, - dst_wave_addr_offset + i * sizeof(fp16x2_t), - 0); - }); - } - else if constexpr(N == 8) - { - static_for<0, 4, 1>{}([&](auto i) { - llvm_amdgcn_raw_buffer_atomic_add_fp16x2( - src_thread_data.template get_as()[i], - dst_wave_buffer_resource, - dst_thread_addr_offset, - dst_wave_addr_offset + i * sizeof(fp16x2_t), - 0); - }); - } - } - else if constexpr(std::is_same::value) - { - if constexpr(N == 1) - { - llvm_amdgcn_raw_buffer_atomic_add_i32(bit_cast(src_thread_data), - dst_wave_buffer_resource, - dst_thread_addr_offset, - dst_wave_addr_offset, - 0); - } - else if constexpr(N == 2) - { - llvm_amdgcn_raw_buffer_atomic_add_i32( - src_thread_data.template get_as()[number<0>{}], - dst_wave_buffer_resource, - dst_thread_addr_offset, - dst_wave_addr_offset, - 0); - - llvm_amdgcn_raw_buffer_atomic_add_i32( - src_thread_data.template get_as()[number<1>{}], - dst_wave_buffer_resource, - dst_thread_addr_offset, - dst_wave_addr_offset + sizeof(int32_t), - 0); - } - else if constexpr(N == 4) - { - llvm_amdgcn_raw_buffer_atomic_add_i32( - src_thread_data.template get_as()[number<0>{}], - dst_wave_buffer_resource, - dst_thread_addr_offset, - dst_wave_addr_offset, - 0); - - llvm_amdgcn_raw_buffer_atomic_add_i32( - src_thread_data.template 
get_as()[number<1>{}], - dst_wave_buffer_resource, - dst_thread_addr_offset, - dst_wave_addr_offset + sizeof(int32_t), - 0); - - llvm_amdgcn_raw_buffer_atomic_add_i32( - src_thread_data.template get_as()[number<2>{}], - dst_wave_buffer_resource, - dst_thread_addr_offset, - dst_wave_addr_offset + 2 * sizeof(int32_t), - 0); - - llvm_amdgcn_raw_buffer_atomic_add_i32( - src_thread_data.template get_as()[number<3>{}], - dst_wave_buffer_resource, - dst_thread_addr_offset, - dst_wave_addr_offset + 3 * sizeof(int32_t), - 0); - } - } -} - -template -CK_TILE_DEVICE void amd_buffer_atomic_max_impl(const thread_buffer src_thread_data, - int32x4_t dst_wave_buffer_resource, - index_t dst_thread_addr_offset, - index_t dst_wave_addr_offset) -{ - static_assert((std::is_same::value && (N == 1 || N == 2 || N == 4)), - "wrong! not implemented"); - if constexpr(std::is_same::value) - { - if constexpr(N == 1) - { - llvm_amdgcn_raw_buffer_atomic_max_fp64(bit_cast(src_thread_data), - dst_wave_buffer_resource, - dst_thread_addr_offset, - dst_wave_addr_offset, - 0); - } - else if constexpr(N == 2) - { - llvm_amdgcn_raw_buffer_atomic_max_fp64( - src_thread_data.template get_as()[number<0>{}], - dst_wave_buffer_resource, - dst_thread_addr_offset, - dst_wave_addr_offset, - 0); - - llvm_amdgcn_raw_buffer_atomic_max_fp64( - src_thread_data.template get_as()[number<1>{}], - dst_wave_buffer_resource, - dst_thread_addr_offset, - dst_wave_addr_offset + sizeof(double), - 0); - } - else if constexpr(N == 4) - { - llvm_amdgcn_raw_buffer_atomic_max_fp64( - src_thread_data.template get_as()[number<0>{}], - dst_wave_buffer_resource, - dst_thread_addr_offset, - dst_wave_addr_offset, - 0); - - llvm_amdgcn_raw_buffer_atomic_max_fp64( - src_thread_data.template get_as()[number<1>{}], - dst_wave_buffer_resource, - dst_thread_addr_offset, - dst_wave_addr_offset + sizeof(double), - 0); - - llvm_amdgcn_raw_buffer_atomic_max_fp64( - src_thread_data.template get_as()[number<2>{}], - dst_wave_buffer_resource, - 
dst_thread_addr_offset, - dst_wave_addr_offset + 2 * sizeof(double), - 0); - - llvm_amdgcn_raw_buffer_atomic_max_fp64( - src_thread_data.template get_as()[number<3>{}], - dst_wave_buffer_resource, - dst_thread_addr_offset, - dst_wave_addr_offset + 3 * sizeof(double), - 0); - } - } -} - -// buffer_load requires: -// 1) p_src_wave must point to global memory space -// 2) p_src_wave must be a wavewise pointer. -// It is user's responsibility to make sure that is true. -// oob_conditional_check : dynamic check if out-of-bound -template -CK_TILE_DEVICE thread_buffer -amd_buffer_load_invalid_element_return_zero(const T* p_src_wave, - index_t src_thread_element_offset, - bool src_thread_element_valid, - index_t src_element_space_size) -{ - const int32x4_t src_wave_buffer_resource = - make_wave_buffer_resource(p_src_wave, src_element_space_size * sizeof(T)); - - index_t src_thread_addr_offset = src_thread_element_offset * sizeof(T); - -#if CK_TILE_EXPERIMENTAL_USE_BUFFER_LOAD_OOB_CHECK_OFFSET_TRICK - uint32_t src_addr_shift = [&]() { - if constexpr(oob_conditional_check) - return src_thread_element_valid ? 0 : 0x80000000; - else - return 0; - }(); - return amd_buffer_load_impl( - src_wave_buffer_resource, src_addr_shift + src_thread_addr_offset, 0); -#else - thread_buffer tmp = - amd_buffer_load_impl(src_wave_buffer_resource, src_thread_addr_offset, 0); - if constexpr(oob_conditional_check) - return src_thread_element_valid ? tmp : thread_buffer{numeric::zero()}; - else - return tmp; -#endif -} - -// buffer_load requires: -// 1) p_src_wave must point to global memory space -// 2) p_src_wave must be a wavewise pointer. -// It is user's responsibility to make sure that is true. 
-template -CK_TILE_DEVICE thread_buffer -amd_buffer_load_invalid_element_return_customized_value(const T* p_src_wave, - index_t src_thread_element_offset, - bool src_thread_element_valid, - index_t src_element_space_size, - T customized_value) -{ - const int32x4_t src_wave_buffer_resource = - make_wave_buffer_resource(p_src_wave, src_element_space_size * sizeof(T)); - - index_t src_thread_addr_offset = src_thread_element_offset * sizeof(T); - - thread_buffer tmp = - amd_buffer_load_impl(src_wave_buffer_resource, src_thread_addr_offset, 0); - - if constexpr(oob_conditional_check) - return src_thread_element_valid ? tmp : thread_buffer{customized_value}; - else - return tmp; -} - -template -CK_TILE_DEVICE void amd_buffer_load_raw(thread_buffer& dst, - const T* p_src_wave, - index_t src_thread_element_offset, - index_t src_linear_element_offset, - index_t src_element_space_size, - index_t is_valid_element = 0, - bool_constant = {}) -{ - const int32x4_t src_wave_buffer_resource = - make_wave_buffer_resource(p_src_wave, src_element_space_size * sizeof(T)); - - index_t src_thread_addr_offset = src_thread_element_offset * sizeof(T); - index_t src_linear_addr_offset = src_linear_element_offset * sizeof(T); - - amd_buffer_load_raw_impl( - dst, - src_wave_buffer_resource, - src_thread_addr_offset, - 0, - src_linear_addr_offset, - is_valid_element, - bool_constant{}); -} - -// This version support buffer resource as input arg -template -CK_TILE_DEVICE void amd_buffer_load_raw(thread_buffer& dst, - const int32x4_t src_wave_buffer_resource, - index_t src_thread_element_offset, - index_t src_linear_element_offset, - index_t is_valid_element = 0, - bool_constant = {}) -{ - index_t src_thread_addr_offset = src_thread_element_offset * sizeof(T); - index_t src_linear_addr_offset = src_linear_element_offset * sizeof(T); - - amd_buffer_load_raw_impl( - dst, - src_wave_buffer_resource, - src_thread_addr_offset, - 0, - src_linear_addr_offset, - is_valid_element, - bool_constant{}); -} - 
-// unfortunately async copy can not make sure invalid data is zero inside LDS -// ... unless people manually write zero to LDS at the proper address. -// so not support invalid_element check for now. -// buffer_load OOB still working. -template -CK_TILE_DEVICE void amd_async_buffer_load_with_oob_raw(T* smem, - const T* p_src_wave, - index_t src_thread_element_offset, - index_t src_linear_element_offset, - index_t src_element_space_size, - bool_constant = {}) -{ - const int32x4_t src_wave_buffer_resource = - make_wave_buffer_resource(p_src_wave, src_element_space_size * sizeof(T)); - - index_t src_thread_addr_offset = src_thread_element_offset * sizeof(T); - index_t src_linear_addr_offset = src_linear_element_offset * sizeof(T); - - amd_async_buffer_load_impl(smem, - src_wave_buffer_resource, - src_thread_addr_offset, - 0, - src_linear_addr_offset, - bool_constant{}); -} - -// This version support buffer resource as input arg -template -CK_TILE_DEVICE void amd_async_buffer_load_with_oob_raw(T* smem, - const int32x4_t src_wave_buffer_resource, - index_t src_thread_element_offset, - index_t src_linear_element_offset, - bool_constant = {}) -{ - index_t src_thread_addr_offset = src_thread_element_offset * sizeof(T); - index_t src_linear_addr_offset = src_linear_element_offset * sizeof(T); - - amd_async_buffer_load_impl(smem, - src_wave_buffer_resource, - src_thread_addr_offset, - 0, - src_linear_addr_offset, - bool_constant{}); -} - -// This version support buffer resource as input arg -template -CK_TILE_DEVICE void amd_async_buffer_load_with_oob(CK_TILE_LDS_ADDR T* smem, - const int32x4_t src_wave_buffer_resource, - index_t src_thread_element_offset, - index_t src_linear_element_offset, - bool is_valid_element, - bool_constant = {}) -{ - index_t src_thread_addr_offset = src_thread_element_offset * sizeof(T); - index_t src_linear_addr_offset = src_linear_element_offset * sizeof(T); - - amd_async_buffer_load(smem, - src_wave_buffer_resource, - src_thread_addr_offset, - 
0, - src_linear_addr_offset, - is_valid_element, - bool_constant{}); -} - -// buffer_store requires: -// 1) p_dst_wave must point to global memory -// 2) p_dst_wave must be a wavewise pointer. -// It is user's responsibility to make sure that is true. -template -CK_TILE_DEVICE void amd_buffer_store(const thread_buffer& src_thread_data, - T* p_dst_wave, - const index_t dst_thread_element_offset, - const bool dst_thread_element_valid, - const index_t dst_element_space_size) -{ - const int32x4_t dst_wave_buffer_resource = - make_wave_buffer_resource(p_dst_wave, dst_element_space_size * sizeof(T)); - - index_t dst_thread_addr_offset = dst_thread_element_offset * sizeof(T); - -#if CK_TILE_EXPERIMENTAL_USE_BUFFER_STORE_OOB_CHECK_OFFSET_TRICK - uint32_t dst_addr_shift = [&]() { - if constexpr(oob_conditional_check) - return dst_thread_element_valid ? 0 : 0x80000000; - else - return 0; - }(); - amd_buffer_store_impl( - src_thread_data, dst_wave_buffer_resource, dst_addr_shift + dst_thread_addr_offset, 0); -#else - if constexpr(oob_conditional_check) - { - if(dst_thread_element_valid) - { - amd_buffer_store_impl( - src_thread_data, dst_wave_buffer_resource, dst_thread_addr_offset, 0); - } - } - else - { - amd_buffer_store_impl( - src_thread_data, dst_wave_buffer_resource, dst_thread_addr_offset, 0); - } -#endif -} - -template -CK_TILE_DEVICE void amd_buffer_store_raw(const thread_buffer& src_thread_data, - T* p_dst_wave, - const index_t dst_thread_element_offset, - const index_t dst_linear_element_offset, - const bool dst_thread_element_valid, - const index_t dst_element_space_size) -{ - const int32x4_t dst_wave_buffer_resource = - make_wave_buffer_resource(p_dst_wave, dst_element_space_size * sizeof(T)); - - index_t dst_thread_addr_offset = dst_thread_element_offset * sizeof(T); - index_t dst_linear_addr_offset = dst_linear_element_offset * sizeof(T); - - amd_buffer_store_raw_impl(src_thread_data, - dst_wave_buffer_resource, - dst_thread_addr_offset, - 0, - 
dst_linear_addr_offset, - dst_thread_element_valid); -} - -// buffer_atomic_add requires: -// 1) p_dst_wave must point to global memory -// 2) p_dst_wave must be a wavewise pointer. -// It is user's responsibility to make sure that is true. -template -CK_TILE_DEVICE void amd_buffer_atomic_add(const thread_buffer& src_thread_data, - T* p_dst_wave, - const index_t dst_thread_element_offset, - const bool dst_thread_element_valid, - const index_t dst_element_space_size) -{ - const int32x4_t dst_wave_buffer_resource = - make_wave_buffer_resource(p_dst_wave, dst_element_space_size * sizeof(T)); - - index_t dst_thread_addr_offset = dst_thread_element_offset * sizeof(T); - -#if CK_TILE_EXPERIMENTAL_USE_BUFFER_ATOMIC_ADD_OOB_CHECK_OFFSET_TRICK - uint32_t dst_addr_shift = dst_thread_element_valid ? 0 : 0x80000000; - - amd_buffer_atomic_add_impl( - src_thread_data, dst_wave_buffer_resource, dst_addr_shift + dst_thread_addr_offset, 0); -#else - if(dst_thread_element_valid) - { - amd_buffer_atomic_add_impl( - src_thread_data, dst_wave_buffer_resource, dst_thread_addr_offset, 0); - } -#endif -} - -template -CK_TILE_DEVICE void amd_buffer_atomic_add_raw(const thread_buffer& src_thread_data, - T* p_dst_wave, - const index_t dst_thread_element_offset, - const index_t dst_linear_element_offset, - const bool dst_thread_element_valid, - const index_t dst_element_space_size, - bool_constant = {}) -{ - const int32x4_t dst_wave_buffer_resource = - make_wave_buffer_resource(p_dst_wave, dst_element_space_size * sizeof(T)); - - index_t dst_thread_addr_offset = dst_thread_element_offset * sizeof(T); - index_t dst_linear_addr_offset = dst_linear_element_offset * sizeof(T); - - if constexpr(oob_conditional_check) - { - buffer_atomic_add_if{}(src_thread_data, - dst_wave_buffer_resource, - dst_thread_addr_offset, - 0, - dst_linear_addr_offset, - dst_thread_element_valid); - } - else - { - buffer_atomic_add{}(src_thread_data, - dst_wave_buffer_resource, - dst_thread_addr_offset, - 0, - 
dst_linear_addr_offset, - 1); - } -} - -// buffer_atomic_max requires: -// 1) p_dst_wave must point to global memory -// 2) p_dst_wave must be a wavewise pointer. -// It is user's responsibility to make sure that is true. -template -CK_TILE_DEVICE void amd_buffer_atomic_max(const thread_buffer& src_thread_data, - T* p_dst_wave, - const index_t dst_thread_element_offset, - const bool dst_thread_element_valid, - const index_t dst_element_space_size) -{ - const int32x4_t dst_wave_buffer_resource = - make_wave_buffer_resource(p_dst_wave, dst_element_space_size * sizeof(T)); - - index_t dst_thread_addr_offset = dst_thread_element_offset * sizeof(T); - -#if CK_TILE_EXPERIMENTAL_USE_BUFFER_ATOMIC_MAX_OOB_CHECK_OFFSET_TRICK - uint32_t dst_addr_shift = dst_thread_element_valid ? 0 : 0x80000000; - - amd_buffer_atomic_max_impl( - src_thread_data, dst_wave_buffer_resource, dst_addr_shift + dst_thread_addr_offset, 0); -#else - if(dst_thread_element_valid) - { - amd_buffer_atomic_max_impl( - src_thread_data, dst_wave_buffer_resource, dst_thread_addr_offset, 0); - } -#endif -} - -template -CK_TILE_DEVICE void amd_direct_load_global_to_lds(const T* global_base_ptr, - const index_t global_offset, - T* lds_base_ptr, - const index_t lds_offset, - const bool is_valid, - const index_t src_element_space_size) -{ - // Direct loads require that each thread reads and writes exactly a single DWORD. - constexpr auto dword_bytes = 4; - constexpr auto bytes_per_thread = sizeof(T) * NumElemsPerThread; - static_assert(bytes_per_thread == dword_bytes); - - const uint32_t* global_ptr = - reinterpret_cast(reinterpret_cast(global_base_ptr)); - const int32x4_t src_resource = - make_wave_buffer_resource(global_ptr, src_element_space_size * sizeof(T)); - const index_t global_offset_bytes = is_valid ? 
global_offset * sizeof(T) : 0x80000000; - -#if CK_TILE_USE_AMD_LDS_DIRECT_LOAD_INLINE_ASM - T* lds_ptr = lds_base_ptr + lds_offset; - auto const lds_ptr_sgpr = - __builtin_amdgcn_readfirstlane((reinterpret_cast(lds_ptr))); - asm volatile("s_mov_b32 m0, %0; \n\t" - "buffer_load_dword %1, %2, 0 offen lds;\n\t" ::"s"(lds_ptr_sgpr), - "v"(global_offset_bytes), - "s"(src_resource) - : "memory"); -#else - // LDS pointer must be attributed with the LDS address space. - __attribute__((address_space(3))) uint32_t* lds_ptr = - reinterpret_cast<__attribute__((address_space(3))) uint32_t*>( - reinterpret_cast(lds_base_ptr + lds_offset)); - - llvm_amdgcn_raw_buffer_load_lds( - src_resource, lds_ptr, sizeof(uint32_t), global_offset_bytes, 0, 0, 0); -#endif -} - -} // namespace ck_tile - -#endif // CK_TILE_USE_BUFFER_ADDRESSING_BUILTIN diff --git a/include/ck_tile/core/tensor/buffer_view.hpp b/include/ck_tile/core/tensor/buffer_view.hpp index bdcfbdd920..c2a093f1ab 100644 --- a/include/ck_tile/core/tensor/buffer_view.hpp +++ b/include/ck_tile/core/tensor/buffer_view.hpp @@ -5,11 +5,7 @@ #include "ck_tile/core/config.hpp" #include "ck_tile/core/arch/arch.hpp" -#if __clang_major__ == 20 -#include "ck_tile/core/arch/amd_buffer_addressing_builtins.hpp" -#else #include "ck_tile/core/arch/amd_buffer_addressing.hpp" -#endif #include "ck_tile/core/arch/generic_memory_space_atomic.hpp" #include "ck_tile/core/container/array.hpp" #include "ck_tile/core/numeric/integer.hpp" From 41c17d0a953a5c399a3cf15ff283d1b57992f06d Mon Sep 17 00:00:00 2001 From: "BingYuan.Zhou" Date: Wed, 14 May 2025 09:31:26 +0800 Subject: [PATCH 119/443] fix moe sorting build fail (#2190) * fix moe sorting build fail * refile code --------- Co-authored-by: solin --- .../flatmm/pipeline/flatmm_pipeline_agmem_bgmem_creg_v1.hpp | 3 ++- .../pipeline/flatmm_pipeline_agmem_bgmem_creg_v1_policy.hpp | 2 +- 2 files changed, 3 insertions(+), 2 deletions(-) diff --git 
a/include/ck_tile/ops/flatmm/pipeline/flatmm_pipeline_agmem_bgmem_creg_v1.hpp b/include/ck_tile/ops/flatmm/pipeline/flatmm_pipeline_agmem_bgmem_creg_v1.hpp index 2ff9d1ebf0..cbd20a6ea3 100644 --- a/include/ck_tile/ops/flatmm/pipeline/flatmm_pipeline_agmem_bgmem_creg_v1.hpp +++ b/include/ck_tile/ops/flatmm/pipeline/flatmm_pipeline_agmem_bgmem_creg_v1.hpp @@ -75,6 +75,7 @@ struct FlatmmPipelineAGmemBGmemCRegV1 CK_TILE_HOST_DEVICE static constexpr auto HotLoopScheduler() { +#if defined(USING_MFMA_16x16x32) && defined(ENABLE_FP8) || defined(USING_MFMA_32x32x16) constexpr auto config = BlockFlatmm::BlockPolicy::template GetWarpGemmMWarpNWarp(); using WG = remove_cvref_t())>; @@ -90,7 +91,7 @@ struct FlatmmPipelineAGmemBGmemCRegV1 constexpr index_t A_Buffer_Load_Inst_Num = kMPerBlock * kKPerBlock / BlockSize / KPerLoad; constexpr index_t A_LDS_Read_Inst_Num = MIterPerWarp * KIterPerWarp; constexpr index_t B_Buffer_Load_Inst_Num = NIterPerWarp * KIterPerWarp; - // constexpr index_t A_LDS_Read_Inst_Remain = A_LDS_Read_Inst_Num - A_Buffer_Load_Inst_Num; +#endif #if defined(USING_MFMA_16x16x32) && defined(ENABLE_FP8) static_for<0, A_Buffer_Load_Inst_Num, 1>{}([&](auto i) { ignore = i; diff --git a/include/ck_tile/ops/flatmm/pipeline/flatmm_pipeline_agmem_bgmem_creg_v1_policy.hpp b/include/ck_tile/ops/flatmm/pipeline/flatmm_pipeline_agmem_bgmem_creg_v1_policy.hpp index 474924ec84..1a1b729394 100644 --- a/include/ck_tile/ops/flatmm/pipeline/flatmm_pipeline_agmem_bgmem_creg_v1_policy.hpp +++ b/include/ck_tile/ops/flatmm/pipeline/flatmm_pipeline_agmem_bgmem_creg_v1_policy.hpp @@ -112,8 +112,8 @@ struct UniversalFlatmmPipelineAgBgCrPolicy make_tuple(number{}, number{}))), make_tuple(sequence<1, 0>{}, sequence<2, 3>{}), make_tuple(sequence<0>{}, sequence<1>{})); + return a_lds_block_desc; #endif - return a_lds_block_desc; } template From 7c0e29cc0f6f60ab66b48e324b2481d167722dd9 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Bart=C5=82omiej=20Kocot?= Date: Thu, 15 May 2025 16:21:34 +0200 
Subject: [PATCH 120/443] Extend 64x64 with 4 waves instances for grouped conv bwd wei (#2187) * Extend 64x64 with 4 waves instnaces for grouped conv bwd wei * Fix * fix * fix --- ...conv_bwd_weight_two_stage_xdl_instance.hpp | 29 ++++++++++++++++--- ...e_grouped_conv_bwd_weight_xdl_instance.hpp | 7 ++++- 2 files changed, 31 insertions(+), 5 deletions(-) diff --git a/library/include/ck/library/tensor_operation_instance/gpu/grouped_conv_bwd_weight/device_grouped_conv_bwd_weight_two_stage_xdl_instance.hpp b/library/include/ck/library/tensor_operation_instance/gpu/grouped_conv_bwd_weight/device_grouped_conv_bwd_weight_two_stage_xdl_instance.hpp index 1c4dc8a445..0ed12b984b 100644 --- a/library/include/ck/library/tensor_operation_instance/gpu/grouped_conv_bwd_weight/device_grouped_conv_bwd_weight_two_stage_xdl_instance.hpp +++ b/library/include/ck/library/tensor_operation_instance/gpu/grouped_conv_bwd_weight/device_grouped_conv_bwd_weight_two_stage_xdl_instance.hpp @@ -72,7 +72,14 @@ using device_grouped_conv_bwd_weight_two_stage_nhwgc_xdl_c_shuffle_f16_instances DeviceGroupedConvBwdWeightTwoStage_Xdl_CShuffle< NDimSpatial, ALayout, BLayout, ELayout, F16, F16, F16, F32, PassThrough, PassThrough, PassThrough, ConvSpec, 64, 32, 32, 32, 8, 32, 32, 1, 1, S<4, 16, 1>, S<2, 0, 1>, S<1, 0, 2>, 1, 2, 2, false, S<4, 8, 1>, S<2, 0, 1>, S<1, 0, 2>, 1, 2, 2, false, 1, 1, S<1, 8, 1, 8>, 1, Scheduler, PipelineVersion, 2>, DeviceGroupedConvBwdWeightTwoStage_Xdl_CShuffle< NDimSpatial, ALayout, BLayout, ELayout, F16, F16, F16, F32, PassThrough, PassThrough, PassThrough, ConvSpec, 64, 64, 32, 32, 8, 32, 32, 2, 1, S<4, 16, 1>, S<2, 0, 1>, S<1, 0, 2>, 1, 4, 4, false, S<4, 8, 1>, S<2, 0, 1>, S<1, 0, 2>, 1, 4, 4, false, 1, 1, S<1, 8, 1, 8>, 1, Scheduler, PipelineVersion, 4>, - DeviceGroupedConvBwdWeightTwoStage_Xdl_CShuffle< NDimSpatial, ALayout, BLayout, ELayout, F16, F16, F16, F32, PassThrough, PassThrough, PassThrough, ConvSpec, 64, 128, 32, 32, 8, 32, 32, 4, 1, S<4, 16, 1>, S<2, 0, 1>, 
S<1, 0, 2>, 1, 8, 8, false, S<4, 4, 1>, S<2, 0, 1>, S<1, 0, 2>, 1, 8, 8, false, 1, 1, S<1, 8, 1, 8>, 1, Scheduler, PipelineVersion, 8> + DeviceGroupedConvBwdWeightTwoStage_Xdl_CShuffle< NDimSpatial, ALayout, BLayout, ELayout, F16, F16, F16, F32, PassThrough, PassThrough, PassThrough, ConvSpec, 64, 128, 32, 32, 8, 32, 32, 4, 1, S<4, 16, 1>, S<2, 0, 1>, S<1, 0, 2>, 1, 8, 8, false, S<4, 4, 1>, S<2, 0, 1>, S<1, 0, 2>, 1, 8, 8, false, 1, 1, S<1, 8, 1, 8>, 1, Scheduler, PipelineVersion, 8>, + + DeviceGroupedConvBwdWeightTwoStage_Xdl_CShuffle< NDimSpatial, ALayout, BLayout, ELayout, F16, F16, F16, F32, PassThrough, PassThrough, PassThrough, ConvSpec, 256, 64, 64, 64, 8, 32, 32, 1, 1, S<8, 8, 1>, S<2, 0, 1>, S<2, 0, 1>, 1, 8, 8, false, S<8, 8, 1>, S<2, 0, 1>, S<2, 0, 1>, 1, 8, 8, false, 1, 1, S<1, 16, 1, 16>, 4, Scheduler, PipelineVersion, 1>, + DeviceGroupedConvBwdWeightTwoStage_Xdl_CShuffle< NDimSpatial, ALayout, BLayout, ELayout, F16, F16, F16, F32, PassThrough, PassThrough, PassThrough, ConvSpec, 256, 64, 64, 64, 8, 32, 32, 1, 1, S<8, 16, 1>, S<2, 0, 1>, S<2, 0, 1>, 1, 4, 8, false, S<8, 32, 1>, S<2, 0, 1>, S<2, 0, 1>, 1, 1, 8, false, 1, 1, S<1, 4, 1, 64>, 1, Scheduler, PipelineVersion, 1>, + DeviceGroupedConvBwdWeightTwoStage_Xdl_CShuffle< NDimSpatial, ALayout, BLayout, ELayout, F16, F16, F16, F32, PassThrough, PassThrough, PassThrough, ConvSpec, 256, 64, 64, 64, 8, 32, 32, 1, 1, S<8, 32, 1>, S<2, 0, 1>, S<2, 0, 1>, 1, 1, 8, false, S<8, 16, 1>, S<2, 0, 1>, S<2, 0, 1>, 1, 4, 8, false, 1, 1, S<1, 16, 1, 16>, 4, Scheduler, PipelineVersion, 1>, + DeviceGroupedConvBwdWeightTwoStage_Xdl_CShuffle< NDimSpatial, ALayout, BLayout, ELayout, F16, F16, F16, F32, PassThrough, PassThrough, PassThrough, ConvSpec, 256, 64, 64, 64, 8, 32, 32, 1, 1, S<8, 32, 1>, S<2, 0, 1>, S<2, 0, 1>, 1, 2, 8, false, S<8, 32, 1>, S<2, 0, 1>, S<2, 0, 1>, 1, 2, 8, false, 1, 1, S<1, 8, 1, 32>, 2, Scheduler, PipelineVersion, 1>, + DeviceGroupedConvBwdWeightTwoStage_Xdl_CShuffle< NDimSpatial, ALayout, 
BLayout, ELayout, F16, F16, F16, F32, PassThrough, PassThrough, PassThrough, ConvSpec, 256, 64, 64, 64, 8, 32, 32, 1, 1, S<8, 32, 1>, S<2, 0, 1>, S<2, 0, 1>, 1, 1, 8, false, S<8, 32, 1>, S<2, 0, 1>, S<2, 0, 1>, 1, 1, 8, false, 1, 1, S<1, 4, 1, 64>, 1, Scheduler, PipelineVersion, 1> + // clang-format on >; @@ -138,7 +145,13 @@ using device_grouped_conv_bwd_weight_two_stage_nhwgc_xdl_c_shuffle_bf16_instance DeviceGroupedConvBwdWeightTwoStage_Xdl_CShuffle< NDimSpatial, ALayout, BLayout, ELayout, BF16, BF16, BF16, F32, PassThrough, PassThrough, PassThrough, ConvSpec, 64, 32, 32, 32, 8, 32, 32, 1, 1, S<4, 16, 1>, S<2, 0, 1>, S<1, 0, 2>, 1, 2, 2, false, S<4, 8, 1>, S<2, 0, 1>, S<1, 0, 2>, 1, 2, 2, false, 1, 1, S<1, 8, 1, 8>, 1, Scheduler, PipelineVersion, 2>, DeviceGroupedConvBwdWeightTwoStage_Xdl_CShuffle< NDimSpatial, ALayout, BLayout, ELayout, BF16, BF16, BF16, F32, PassThrough, PassThrough, PassThrough, ConvSpec, 64, 64, 32, 32, 8, 32, 32, 2, 1, S<4, 16, 1>, S<2, 0, 1>, S<1, 0, 2>, 1, 4, 4, false, S<4, 8, 1>, S<2, 0, 1>, S<1, 0, 2>, 1, 4, 4, false, 1, 1, S<1, 8, 1, 8>, 1, Scheduler, PipelineVersion, 4>, - DeviceGroupedConvBwdWeightTwoStage_Xdl_CShuffle< NDimSpatial, ALayout, BLayout, ELayout, BF16, BF16, BF16, F32, PassThrough, PassThrough, PassThrough, ConvSpec, 64, 128, 32, 32, 8, 32, 32, 4, 1, S<4, 16, 1>, S<2, 0, 1>, S<1, 0, 2>, 1, 8, 8, false, S<4, 4, 1>, S<2, 0, 1>, S<1, 0, 2>, 1, 8, 8, false, 1, 1, S<1, 8, 1, 8>, 1, Scheduler, PipelineVersion, 8> + DeviceGroupedConvBwdWeightTwoStage_Xdl_CShuffle< NDimSpatial, ALayout, BLayout, ELayout, BF16, BF16, BF16, F32, PassThrough, PassThrough, PassThrough, ConvSpec, 64, 128, 32, 32, 8, 32, 32, 4, 1, S<4, 16, 1>, S<2, 0, 1>, S<1, 0, 2>, 1, 8, 8, false, S<4, 4, 1>, S<2, 0, 1>, S<1, 0, 2>, 1, 8, 8, false, 1, 1, S<1, 8, 1, 8>, 1, Scheduler, PipelineVersion, 8>, + + DeviceGroupedConvBwdWeightTwoStage_Xdl_CShuffle< NDimSpatial, ALayout, BLayout, ELayout, BF16, BF16, BF16, F32, PassThrough, PassThrough, PassThrough, ConvSpec, 
256, 64, 64, 64, 8, 32, 32, 1, 1, S<8, 4, 1>, S<2, 0, 1>, S<2, 0, 1>, 1, 8, 8, false, S<8, 4, 1>, S<2, 0, 1>, S<2, 0, 1>, 1, 8, 8, false, 1, 1, S<1, 16, 1, 16>, 4, Scheduler, PipelineVersion, 1>, + DeviceGroupedConvBwdWeightTwoStage_Xdl_CShuffle< NDimSpatial, ALayout, BLayout, ELayout, BF16, BF16, BF16, F32, PassThrough, PassThrough, PassThrough, ConvSpec, 256, 64, 64, 64, 8, 32, 32, 1, 1, S<8, 16, 1>, S<2, 0, 1>, S<2, 0, 1>, 1, 4, 8, false, S<8, 32, 1>, S<2, 0, 1>, S<2, 0, 1>, 1, 1, 8, false, 1, 1, S<1, 4, 1, 64>, 1, Scheduler, PipelineVersion, 1>, + DeviceGroupedConvBwdWeightTwoStage_Xdl_CShuffle< NDimSpatial, ALayout, BLayout, ELayout, BF16, BF16, BF16, F32, PassThrough, PassThrough, PassThrough, ConvSpec, 256, 64, 64, 64, 8, 32, 32, 1, 1, S<8, 32, 1>, S<2, 0, 1>, S<2, 0, 1>, 1, 1, 8, false, S<8, 16, 1>, S<2, 0, 1>, S<2, 0, 1>, 1, 4, 8, false, 1, 1, S<1, 16, 1, 16>, 4, Scheduler, PipelineVersion, 1>, + DeviceGroupedConvBwdWeightTwoStage_Xdl_CShuffle< NDimSpatial, ALayout, BLayout, ELayout, BF16, BF16, BF16, F32, PassThrough, PassThrough, PassThrough, ConvSpec, 256, 64, 64, 64, 8, 32, 32, 1, 1, S<8, 32, 1>, S<2, 0, 1>, S<2, 0, 1>, 1, 2, 8, false, S<8, 32, 1>, S<2, 0, 1>, S<2, 0, 1>, 1, 2, 8, false, 1, 1, S<1, 8, 1, 32>, 2, Scheduler, PipelineVersion, 1>, + DeviceGroupedConvBwdWeightTwoStage_Xdl_CShuffle< NDimSpatial, ALayout, BLayout, ELayout, BF16, BF16, BF16, F32, PassThrough, PassThrough, PassThrough, ConvSpec, 256, 64, 64, 64, 8, 32, 32, 1, 1, S<8, 32, 1>, S<2, 0, 1>, S<2, 0, 1>, 1, 1, 8, false, S<8, 32, 1>, S<2, 0, 1>, S<2, 0, 1>, 1, 1, 8, false, 1, 1, S<1, 4, 1, 64>, 1, Scheduler, PipelineVersion, 1> // clang-format on >; @@ -218,7 +231,11 @@ using device_grouped_conv_bwd_weight_two_stage_ngchw_xdl_c_shuffle_f16_instances DeviceGroupedConvBwdWeightTwoStage_Xdl_CShuffle< NDimSpatial, ALayout, BLayout, ELayout, F16, F16, F16, F32, PassThrough, PassThrough, PassThrough, ConvSpec, 64, 32, 128, 32, 8, 32, 32, 1, 4, S<4, 4, 1>, S<2, 0, 1>, S<1, 0, 2>, 1, 8, 8, 
false, S<4, 16, 1>, S<2, 0, 1>, S<1, 0, 2>, 1, 8, 8, false, 1, 1, S<1, 4, 1, 8>, 1, Scheduler, PipelineVersion, 8, F16, F16, 8 ,1>, DeviceGroupedConvBwdWeightTwoStage_Xdl_CShuffle< NDimSpatial, ALayout, BLayout, ELayout, F16, F16, F16, F32, PassThrough, PassThrough, PassThrough, ConvSpec, 64, 64, 32, 32, 8, 32, 32, 2, 1, S<4, 16, 1>, S<2, 0, 1>, S<1, 0, 2>, 1, 4, 4, false, S<4, 8, 1>, S<2, 0, 1>, S<1, 0, 2>, 1, 4, 4, false, 1, 1, S<1, 8, 1, 8>, 1, Scheduler, PipelineVersion, 4, F16, F16, 4, 1>, - DeviceGroupedConvBwdWeightTwoStage_Xdl_CShuffle< NDimSpatial, ALayout, BLayout, ELayout, F16, F16, F16, F32, PassThrough, PassThrough, PassThrough, ConvSpec, 64, 128, 32, 32, 8, 32, 32, 4, 1, S<4, 16, 1>, S<2, 0, 1>, S<1, 0, 2>, 1, 8, 8, false, S<4, 4, 1>, S<2, 0, 1>, S<1, 0, 2>, 1, 8, 8, false, 1, 1, S<1, 8, 1, 4>, 1, Scheduler, PipelineVersion, 8, F16, F16, 8, 1> + DeviceGroupedConvBwdWeightTwoStage_Xdl_CShuffle< NDimSpatial, ALayout, BLayout, ELayout, F16, F16, F16, F32, PassThrough, PassThrough, PassThrough, ConvSpec, 64, 128, 32, 32, 8, 32, 32, 4, 1, S<4, 16, 1>, S<2, 0, 1>, S<1, 0, 2>, 1, 8, 8, false, S<4, 4, 1>, S<2, 0, 1>, S<1, 0, 2>, 1, 8, 8, false, 1, 1, S<1, 8, 1, 4>, 1, Scheduler, PipelineVersion, 8, F16, F16, 8, 1>, + + DeviceGroupedConvBwdWeightTwoStage_Xdl_CShuffle< NDimSpatial, ALayout, BLayout, ELayout, F16, F16, F16, F32, PassThrough, PassThrough, PassThrough, ConvSpec, 256, 64, 64, 64, 8, 32, 32, 1, 1, S<8, 8, 1>, S<2, 0, 1>, S<2, 0, 1>, 1, 8, 8, false, S<8, 8, 1>, S<2, 0, 1>, S<2, 0, 1>, 1, 8, 8, false, 1, 1, S<1, 16, 1, 16>, 4, Scheduler, PipelineVersion, 1, F16, F16, 4, 4>, + DeviceGroupedConvBwdWeightTwoStage_Xdl_CShuffle< NDimSpatial, ALayout, BLayout, ELayout, F16, F16, F16, F32, PassThrough, PassThrough, PassThrough, ConvSpec, 256, 64, 64, 64, 8, 32, 32, 1, 1, S<8, 32, 1>, S<2, 0, 1>, S<2, 0, 1>, 1, 2, 8, false, S<8, 32, 1>, S<2, 0, 1>, S<2, 0, 1>, 1, 2, 8, false, 1, 1, S<1, 8, 1, 32>, 2, Scheduler, PipelineVersion, 1, F16, F16, 2, 2>, + 
DeviceGroupedConvBwdWeightTwoStage_Xdl_CShuffle< NDimSpatial, ALayout, BLayout, ELayout, F16, F16, F16, F32, PassThrough, PassThrough, PassThrough, ConvSpec, 256, 64, 64, 64, 8, 32, 32, 1, 1, S<8, 32, 1>, S<2, 0, 1>, S<2, 0, 1>, 1, 1, 8, false, S<8, 32, 1>, S<2, 0, 1>, S<2, 0, 1>, 1, 1, 8, false, 1, 1, S<1, 4, 1, 64>, 1, Scheduler, PipelineVersion, 1, F16, F16, 1, 1> // clang-format on >; @@ -275,7 +292,11 @@ using device_grouped_conv_bwd_weight_two_stage_ngchw_xdl_c_shuffle_bf16_instance DeviceGroupedConvBwdWeightTwoStage_Xdl_CShuffle< NDimSpatial, ALayout, BLayout, ELayout, BF16, BF16, BF16, F32, PassThrough, PassThrough, PassThrough, ConvSpec, 64, 32, 128, 32, 8, 32, 32, 1, 4, S<4, 4, 1>, S<2, 0, 1>, S<1, 0, 2>, 1, 8, 8, false, S<4, 16, 1>, S<2, 0, 1>, S<1, 0, 2>, 1, 8, 8, false, 1, 1, S<1, 4, 1, 8>, 1, Scheduler, PipelineVersion, 8, BF16, BF16, 8 ,1>, DeviceGroupedConvBwdWeightTwoStage_Xdl_CShuffle< NDimSpatial, ALayout, BLayout, ELayout, BF16, BF16, BF16, F32, PassThrough, PassThrough, PassThrough, ConvSpec, 64, 64, 32, 32, 8, 32, 32, 2, 1, S<4, 16, 1>, S<2, 0, 1>, S<1, 0, 2>, 1, 4, 4, false, S<4, 8, 1>, S<2, 0, 1>, S<1, 0, 2>, 1, 4, 4, false, 1, 1, S<1, 8, 1, 8>, 1, Scheduler, PipelineVersion, 4, BF16, BF16, 4, 1>, - DeviceGroupedConvBwdWeightTwoStage_Xdl_CShuffle< NDimSpatial, ALayout, BLayout, ELayout, BF16, BF16, BF16, F32, PassThrough, PassThrough, PassThrough, ConvSpec, 64, 128, 32, 32, 8, 32, 32, 4, 1, S<4, 16, 1>, S<2, 0, 1>, S<1, 0, 2>, 1, 8, 8, false, S<4, 4, 1>, S<2, 0, 1>, S<1, 0, 2>, 1, 8, 8, false, 1, 1, S<1, 8, 1, 4>, 1, Scheduler, PipelineVersion, 8, BF16, BF16, 8, 1> + DeviceGroupedConvBwdWeightTwoStage_Xdl_CShuffle< NDimSpatial, ALayout, BLayout, ELayout, BF16, BF16, BF16, F32, PassThrough, PassThrough, PassThrough, ConvSpec, 64, 128, 32, 32, 8, 32, 32, 4, 1, S<4, 16, 1>, S<2, 0, 1>, S<1, 0, 2>, 1, 8, 8, false, S<4, 4, 1>, S<2, 0, 1>, S<1, 0, 2>, 1, 8, 8, false, 1, 1, S<1, 8, 1, 4>, 1, Scheduler, PipelineVersion, 8, BF16, BF16, 8, 1>, + + 
DeviceGroupedConvBwdWeightTwoStage_Xdl_CShuffle< NDimSpatial, ALayout, BLayout, ELayout, BF16, BF16, BF16, F32, PassThrough, PassThrough, PassThrough, ConvSpec, 256, 64, 64, 64, 8, 32, 32, 1, 1, S<8, 8, 1>, S<2, 0, 1>, S<2, 0, 1>, 1, 8, 8, false, S<8, 8, 1>, S<2, 0, 1>, S<2, 0, 1>, 1, 8, 8, false, 1, 1, S<1, 16, 1, 16>, 4, Scheduler, PipelineVersion, 1, BF16, BF16, 4, 4>, + DeviceGroupedConvBwdWeightTwoStage_Xdl_CShuffle< NDimSpatial, ALayout, BLayout, ELayout, BF16, BF16, BF16, F32, PassThrough, PassThrough, PassThrough, ConvSpec, 256, 64, 64, 64, 8, 32, 32, 1, 1, S<8, 32, 1>, S<2, 0, 1>, S<2, 0, 1>, 1, 2, 8, false, S<8, 32, 1>, S<2, 0, 1>, S<2, 0, 1>, 1, 2, 8, false, 1, 1, S<1, 8, 1, 32>, 2, Scheduler, PipelineVersion, 1, BF16, BF16, 2, 2>, + DeviceGroupedConvBwdWeightTwoStage_Xdl_CShuffle< NDimSpatial, ALayout, BLayout, ELayout, BF16, BF16, BF16, F32, PassThrough, PassThrough, PassThrough, ConvSpec, 256, 64, 64, 64, 8, 32, 32, 1, 1, S<8, 32, 1>, S<2, 0, 1>, S<2, 0, 1>, 1, 1, 8, false, S<8, 32, 1>, S<2, 0, 1>, S<2, 0, 1>, 1, 1, 8, false, 1, 1, S<1, 4, 1, 64>, 1, Scheduler, PipelineVersion, 1, BF16, BF16, 1, 1> // clang-format on >; diff --git a/library/include/ck/library/tensor_operation_instance/gpu/grouped_conv_bwd_weight/device_grouped_conv_bwd_weight_xdl_instance.hpp b/library/include/ck/library/tensor_operation_instance/gpu/grouped_conv_bwd_weight/device_grouped_conv_bwd_weight_xdl_instance.hpp index a493719637..3587570e42 100644 --- a/library/include/ck/library/tensor_operation_instance/gpu/grouped_conv_bwd_weight/device_grouped_conv_bwd_weight_xdl_instance.hpp +++ b/library/include/ck/library/tensor_operation_instance/gpu/grouped_conv_bwd_weight/device_grouped_conv_bwd_weight_xdl_instance.hpp @@ -87,7 +87,12 @@ using device_grouped_conv_bwd_weight_xdl_c_shuffle_f32_instances = std::tuple< DeviceGroupedConvBwdWeight_Xdl_CShuffle< NDimSpatial, ALayout, BLayout, ELayout, F32, F32, F32, F32, PassThrough, PassThrough, PassThrough, ConvSpec, 128, 128, 32, 4, 4, 
32, 32, 2, 1, S<1, 4, 32, 1>, S<0, 3, 1, 2>, S<0, 2, 1, 3>, 2, 4, 4, true, S<1, 4, 8, 4>, S<0, 3, 1, 2>, S<0, 2, 1, 3>, 2, 4, 1, true, 1, 1, S<1, 32, 1, 4>, 4, F32, F32, MaxTransposeTransferSrcScalarPerVector, MaxTransposeTransferDstScalarPerVector>, DeviceGroupedConvBwdWeight_Xdl_CShuffle< NDimSpatial, ALayout, BLayout, ELayout, F32, F32, F32, F32, PassThrough, PassThrough, PassThrough, ConvSpec, 128, 32, 128, 4, 4, 32, 32, 1, 2, S<1, 4, 8, 4>, S<0, 3, 1, 2>, S<0, 2, 1, 3>, 2, 4, 1, true, S<1, 4, 32, 1>, S<0, 3, 1, 2>, S<0, 2, 1, 3>, 2, 4, 4, true, 1, 1, S<1, 32, 1, 4>, 4, F32, F32, MaxTransposeTransferSrcScalarPerVector, MaxTransposeTransferDstScalarPerVector>, DeviceGroupedConvBwdWeight_Xdl_CShuffle< NDimSpatial, ALayout, BLayout, ELayout, F32, F32, F32, F32, PassThrough, PassThrough, PassThrough, ConvSpec, 64, 64, 32, 4, 4, 32, 32, 2, 1, S<1, 4, 16, 1>, S<0, 3, 1, 2>, S<0, 2, 1, 3>, 2, 4, 4, true, S<1, 4, 8, 2>, S<0, 3, 1, 2>, S<0, 2, 1, 3>, 2, 4, 2, true, 1, 1, S<1, 16, 1, 4>, 4, F32, F32, MaxTransposeTransferSrcScalarPerVector, MaxTransposeTransferDstScalarPerVector>, - DeviceGroupedConvBwdWeight_Xdl_CShuffle< NDimSpatial, ALayout, BLayout, ELayout, F32, F32, F32, F32, PassThrough, PassThrough, PassThrough, ConvSpec, 64, 32, 64, 4, 4, 32, 32, 1, 2, S<1, 4, 8, 2>, S<0, 3, 1, 2>, S<0, 2, 1, 3>, 2, 4, 2, true, S<1, 4, 16, 1>, S<0, 3, 1, 2>, S<0, 2, 1, 3>, 2, 4, 4, true, 1, 1, S<1, 16, 1, 4>, 4, F32, F32, MaxTransposeTransferSrcScalarPerVector, MaxTransposeTransferDstScalarPerVector> + DeviceGroupedConvBwdWeight_Xdl_CShuffle< NDimSpatial, ALayout, BLayout, ELayout, F32, F32, F32, F32, PassThrough, PassThrough, PassThrough, ConvSpec, 64, 32, 64, 4, 4, 32, 32, 1, 2, S<1, 4, 8, 2>, S<0, 3, 1, 2>, S<0, 2, 1, 3>, 2, 4, 2, true, S<1, 4, 16, 1>, S<0, 3, 1, 2>, S<0, 2, 1, 3>, 2, 4, 4, true, 1, 1, S<1, 16, 1, 4>, 4, F32, F32, MaxTransposeTransferSrcScalarPerVector, MaxTransposeTransferDstScalarPerVector>, + + DeviceGroupedConvBwdWeight_Xdl_CShuffle< NDimSpatial, ALayout, 
BLayout, ELayout, F32, F32, F32, F32, PassThrough, PassThrough, PassThrough, ConvSpec, 256, 64, 64, 8, 8, 32, 32, 1, 1, S<1, 8, 16, 1>, S<0, 3, 1, 2>, S<0, 3, 1, 2>, 2, 4, 4, true, S<1, 8, 16, 1>, S<0, 3, 1, 2>, S<0, 3, 1, 2>, 2, 4, 4, true, 1, 1, S<1, 16, 1, 16>, 4, F32, F32, MaxTransposeTransferSrcScalarPerVector, MaxTransposeTransferDstScalarPerVector>, + DeviceGroupedConvBwdWeight_Xdl_CShuffle< NDimSpatial, ALayout, BLayout, ELayout, F32, F32, F32, F32, PassThrough, PassThrough, PassThrough, ConvSpec, 256, 64, 64, 8, 8, 32, 32, 1, 1, S<1, 8, 16, 1>, S<0, 3, 1, 2>, S<0, 3, 1, 2>, 2, 4, 4, true, S<1, 8, 32, 1>, S<0, 3, 1, 2>, S<0, 3, 1, 2>, 2, 1, 4, true, 1, 1, S<1, 4, 1, 64>, 1, F32, F32, MaxTransposeTransferSrcScalarPerVector, MaxTransposeTransferDstScalarPerVector>, + DeviceGroupedConvBwdWeight_Xdl_CShuffle< NDimSpatial, ALayout, BLayout, ELayout, F32, F32, F32, F32, PassThrough, PassThrough, PassThrough, ConvSpec, 256, 64, 64, 8, 8, 32, 32, 1, 1, S<1, 8, 32, 1>, S<0, 3, 1, 2>, S<0, 3, 1, 2>, 2, 1, 4, true, S<1, 8, 16, 1>, S<0, 3, 1, 2>, S<0, 3, 1, 2>, 2, 4, 4, true, 1, 1, S<1, 16, 1, 16>, 4, F32, F32, MaxTransposeTransferSrcScalarPerVector, MaxTransposeTransferDstScalarPerVector>, + DeviceGroupedConvBwdWeight_Xdl_CShuffle< NDimSpatial, ALayout, BLayout, ELayout, F32, F32, F32, F32, PassThrough, PassThrough, PassThrough, ConvSpec, 256, 64, 64, 8, 8, 32, 32, 1, 1, S<1, 8, 32, 1>, S<0, 3, 1, 2>, S<0, 3, 1, 2>, 2, 1, 4, true, S<1, 8, 32, 1>, S<0, 3, 1, 2>, S<0, 3, 1, 2>, 2, 1, 4, true, 1, 1, S<1, 4, 1, 64>, 1, F32, F32, MaxTransposeTransferSrcScalarPerVector, MaxTransposeTransferDstScalarPerVector> // clang-format on >; From 3d8d6e75e485f5811df0ca37272f119392727726 Mon Sep 17 00:00:00 2001 From: Khushbu Agarwal Date: Thu, 15 May 2025 10:28:31 -0700 Subject: [PATCH 121/443] Adding validation for tile sizes in Tile Engine (#2189) * Adding validation for tile sizes * Add architecture in config, and shuffle lines of code in warp_gemm.hpp * Enable MFMA for gfx950, and 
invalid tile handling --- include/ck_tile/ops/gemm/warp/warp_gemm.hpp | 26 ++--- .../warp/warp_gemm_attribute_mfma_impl.hpp | 8 +- .../gemm/configs/instance_combination.json | 4 +- tile_engine/ops/gemm/gemm_instance_builder.py | 96 +++++++++++++++---- 4 files changed, 96 insertions(+), 38 deletions(-) diff --git a/include/ck_tile/ops/gemm/warp/warp_gemm.hpp b/include/ck_tile/ops/gemm/warp/warp_gemm.hpp index f050a8e382..be5d5690ff 100644 --- a/include/ck_tile/ops/gemm/warp/warp_gemm.hpp +++ b/include/ck_tile/ops/gemm/warp/warp_gemm.hpp @@ -204,14 +204,6 @@ using WarpGemmMfmaBf16Bf16F32M64N4K16 = WarpGemmImpl>>; -using WarpGemmMfma_f32_32x32x32_fp8_fp8 = WarpGemmImpl, - 2>>; - -using WarpGemmMfma_f32_32x32x32_bf8_bf8 = WarpGemmImpl, - 2>>; - using WarpGemmMfma_f32_32x32x16_fp8_bf8 = WarpGemmImpl< WarpGemmAtrributeMfma>>; @@ -221,20 +213,28 @@ using WarpGemmMfma_f32_32x32x16_bf8_fp8 = WarpGemmImpl< using WarpGemmMfma_f32_32x32x16_bf8_bf8 = WarpGemmImpl< WarpGemmAtrributeMfma>>; -using WarpGemmMfma_f32_16x16x64_fp8_fp8 = WarpGemmImpl, +using WarpGemmMfma_f32_32x32x32_fp8_fp8 = WarpGemmImpl, + 2>>; + +using WarpGemmMfma_f32_32x32x32_bf8_bf8 = WarpGemmImpl, 2>>; using WarpGemmMfma_f32_16x16x32_fp8_fp8 = WarpGemmImpl< WarpGemmAtrributeMfma>>; +using WarpGemmMfma_f32_16x16x32_bf8_bf8 = WarpGemmImpl< + WarpGemmAtrributeMfma>>; + +using WarpGemmMfma_f32_16x16x64_fp8_fp8 = WarpGemmImpl, + 2>>; + using WarpGemmMfma_f32_16x16x64_bf8_bf8 = WarpGemmImpl, 2>>; -using WarpGemmMfma_f32_16x16x32_bf8_bf8 = WarpGemmImpl< - WarpGemmAtrributeMfma>>; - using WarpGemmMfma_f32_16x16x128_fp8_fp8 = WarpGemmImpl>>; diff --git a/include/ck_tile/ops/gemm/warp/warp_gemm_attribute_mfma_impl.hpp b/include/ck_tile/ops/gemm/warp/warp_gemm_attribute_mfma_impl.hpp index 69d22496f1..4bc4884beb 100644 --- a/include/ck_tile/ops/gemm/warp/warp_gemm_attribute_mfma_impl.hpp +++ b/include/ck_tile/ops/gemm/warp/warp_gemm_attribute_mfma_impl.hpp @@ -1092,7 +1092,7 @@ struct 
WarpGemmAttributeMfmaImpl_f32_16x16x32_f8_base } else { -#if defined(__gfx94__) +#if defined(__gfx94__) or defined(__gfx95__) if constexpr(std::is_same_v && std::is_same_v) c_vec = __builtin_amdgcn_mfma_f32_16x16x32_fp8_fp8( bit_cast(a_vec), bit_cast(b_vec), c_vec, 0, 0, 0); @@ -1116,7 +1116,7 @@ struct WarpGemmAttributeMfmaImpl_f32_16x16x32_f8_base // c_vec = a_vec * b_vec CK_TILE_DEVICE CVecType operator()(const AVecType& a_vec, const BVecType& b_vec) const { -#if defined(__gfx94__) +#if defined(__gfx94__) or defined(__gfx95__) if constexpr(std::is_same_v && std::is_same_v) return bit_cast(__builtin_amdgcn_mfma_f32_16x16x32_fp8_fp8( bit_cast(a_vec), bit_cast(b_vec), CVecType{0.f}, 0, 0, 0)); @@ -1251,7 +1251,7 @@ struct WarpGemmAttributeMfmaImpl_f32_32x32x16_f8_base } else { -#if defined(__gfx94__) +#if defined(__gfx94__) or defined(__gfx95__) if constexpr(std::is_same_v && std::is_same_v) c_vec = __builtin_amdgcn_mfma_f32_32x32x16_fp8_fp8( bit_cast(a_vec), bit_cast(b_vec), c_vec, 0, 0, 0); @@ -1286,7 +1286,7 @@ struct WarpGemmAttributeMfmaImpl_f32_32x32x16_f8_base // c_vec = a_vec * b_vec CK_TILE_DEVICE CVecType operator()(const AVecType& a_vec, const BVecType& b_vec) const { -#if defined(__gfx94__) +#if defined(__gfx94__) or defined(__gfx95__) if constexpr(std::is_same_v && std::is_same_v) return bit_cast(__builtin_amdgcn_mfma_f32_32x32x16_fp8_fp8( bit_cast(a_vec), bit_cast(b_vec), CVecType{0.f}, 0, 0, 0)); diff --git a/tile_engine/ops/gemm/configs/instance_combination.json b/tile_engine/ops/gemm/configs/instance_combination.json index 53197ada6c..b497513efa 100644 --- a/tile_engine/ops/gemm/configs/instance_combination.json +++ b/tile_engine/ops/gemm/configs/instance_combination.json @@ -1,5 +1,7 @@ { - + "architecture": { + "values": ["gfx90a"] + }, "layout_a": { "values": ["r"] }, diff --git a/tile_engine/ops/gemm/gemm_instance_builder.py b/tile_engine/ops/gemm/gemm_instance_builder.py index 3839523e3d..dd8b4d1157 100755 --- 
a/tile_engine/ops/gemm/gemm_instance_builder.py +++ b/tile_engine/ops/gemm/gemm_instance_builder.py @@ -23,7 +23,39 @@ DATA_TYPE_MAP = {'fp32' : 'float', } LAYOUT_MAP = {'r' : 'ck_tile::tensor_layout::gemm::RowMajor', - 'c' : 'ck_tile::tensor_layout::gemm::ColumnMajor'} + 'c' : 'ck_tile::tensor_layout::gemm::ColumnMajor'} + + +warp_tile_combinations_map = { + "gfx90a": { + 'fp16': [[32, 32, 8], [16, 16, 16], [32, 32, 16], [16, 16, 32], [4, 64, 16], [64, 4, 16]], + 'bf16': [[32, 32, 8], [16, 16, 16], [32, 32, 16], [16, 16, 32], [4, 64, 16], [64, 4, 16]], + 'fp8': [[32, 32, 16], [32, 32, 32]], + 'bf8': [[32, 32, 16], [32, 32, 32]] + }, + "gfx942": { + 'fp16': [[32, 32, 8], [16, 16, 16], [32, 32, 16], [16, 16, 32], [4, 64, 16], [64, 4, 16]], + 'bf16': [[32, 32, 8], [16, 16, 16], [32, 32, 16], [16, 16, 32], [4, 64, 16], [64, 4, 16]], + 'fp8': [[32, 32, 16], [32, 32, 32], [16, 16, 32], [16, 16, 64]], + 'bf8': [[32, 32, 16], [32, 32, 32], [16, 16, 64], [16, 16, 32]] + }, + "gfx950": { + 'fp16': [[32, 32, 8], [16, 16, 16], [32, 32, 16], [16, 16, 32], [4, 64, 16], [64, 4, 16]], + 'bf16': [[32, 32, 8], [16, 16, 16], [32, 32, 16], [16, 16, 32], [4, 64, 16], [64, 4, 16]], + 'fp8': [[32, 32, 16], [32, 32, 32], [16, 16, 32], [16, 16, 64], [16, 16, 128], [32, 32, 64]], + 'bf8': [[32, 32, 16], [32, 32, 32], [16, 16, 64], [16, 16, 32], [16, 16, 128], [32, 32, 64]] + } + } + +def sizeOf(data_type): + if data_type == 'fp16' or data_type == 'bf16': + return 2 + elif data_type == 'int8' or data_type == 'fp8' or data_type == 'bf8': + return 1 + elif data_type == 'int4': ## TODO:: needs to confirm + return 0.5 + else: + return 4 DEFAULT_EPILOGUE = """ using GemmEpilogue = ck_tile::DefaultGemm2DEpilogue< @@ -168,11 +200,15 @@ class GemmConfig: self.matrix_cfg : Dict[str, Any] = {} self.impl_cfg : Dict[str, Any] = {} for key, value in config_data.items(): - if key in ["datatype", "layout_a", "layout_b", "layout_c"]: + if key in ["architecture", "datatype", "layout_a", "layout_b", 
"layout_c"]: self.matrix_cfg[key] = value else: self.impl_cfg[key] = value + @property + def architecture(self) -> str: + return self.matrix_cfg["architecture"]["values"][0] + @property def datatype(self) -> str: return self.matrix_cfg["datatype"]["values"][0] @@ -201,7 +237,7 @@ class GemmCodeGenerator: def _validate_config(self): """Validate matrix and implementation configurations""" # Matrix config validation - for param in ["datatype", "layout_a", "layout_b", "layout_c"]: + for param in ["architecture", "datatype", "layout_a", "layout_b", "layout_c"]: if len(self.config.matrix_cfg[param]["values"]) != 1: raise ValueError(f"Matrix config {param} must have exactly one value") @@ -327,7 +363,7 @@ namespace {group_name} {{ return f""" template void try_run(ck_tile::TailNumber tn) {{ - if constexpr (Pipeline::PrefetchStages > static_cast(TN)) {{ + if constexpr (Pipeline::PrefetchStages > static_cast(TN) - 1) {{ if (tn == TN) {{ RunSplitk(ck_tile::bool_constant{{}}, ck_tile::integral_constant{{}}); @@ -477,6 +513,30 @@ struct GemmKernel {{ content += f"#include \"gemm_{group}.hpp\"\n" (self.output_dir / "gemm_instances.hpp").write_text(content) + def is_tile_valid(self, tile: tuple, group: str) -> bool: + """Check if the tile configuration is valid for the given group""" + # Extract tile parameters + tile_m, tile_n, tile_k, warp_m, warp_n, warp_k, warp_tile_m, warp_tile_n, warp_tile_k = tile + + # Extract the pipeline and epilogue from the group name + _, pipeline, epilogue, scheduler, *_ = group.split("_") + + if tile_m % (warp_m * warp_tile_m) == 0 and \ + tile_n % (warp_n * warp_tile_n) == 0 and \ + tile_k % (warp_k * warp_tile_k) == 0: + total_tile_in_lds = (tile_m * tile_k + tile_n * tile_k ) * sizeOf(self.config.datatype) + # Validate and append valid tile parameters + is_compv4 = pipeline == "compv4" + max_tile_size = pow(2, 16) if is_compv4 else pow(2, 15) + + if total_tile_in_lds > max_tile_size: + raise ValueError(f'Total tile size should not exceed 
{max_tile_size / 1024}KB of LDS. ' + f'{tile_m} * {tile_n} * {tile_k} > {max_tile_size / 1024}KB') + arch = self.config.architecture + if [warp_tile_m, warp_tile_n, warp_tile_k] in warp_tile_combinations_map[arch][self.config.datatype]: + return True + return False + def _generate_dispatcher(self): """Generate dispatch mechanism""" content = """// SPDX-License-Identifier: MIT @@ -517,7 +577,7 @@ struct GemmDispatcher { self.config.impl_cfg["warp_tile_k"]["values"] )) - + for group in self.all_kernels: content += f""" kernel_map["{group}"] = [=](ck_tile::DeviceMem& c_m_n_dev_buf, ck_tile::HostTensor& c_m_n_host_result, @@ -526,26 +586,22 @@ struct GemmDispatcher { const ck_tile::stream_config& stream) {{ if(structured_sparsity){{ // SMFMA""" for tile in tile_params: - # Check if we have valid tile/warp combinations - # (tile_m/(warp_m*warp_tile_m)) * warp_m * warp_tile_m == tile_m - if ((tile[0]/(tile[3] * tile[7]) * tile[3] * tile[7]) != tile[0]) or \ - ((tile[1]/(tile[4] * tile[8]) * tile[4] * tile[8]) != tile[1]): - continue - sparse = self.atype == 'fp16' and \ - ((tile[6] == 32 and tile[7] == 32 and tile[8] == 16) or - (tile[6] == 16 and tile[7] == 16 and tile[8] == 32)) - content += f""" + if self.is_tile_valid(tile, group): + sparse = self.atype == 'fp16' and \ + ((tile[6] == 32 and tile[7] == 32 and tile[8] == 16) or + (tile[6] == 16 and tile[7] == 16 and tile[8] == 32)) + content += f""" run_kernel<{group}::GemmKernel<{tile[0]}, {tile[1]}, {tile[2]}, {tile[3]}, {tile[4]}, {tile[5]}, {tile[6]}, {tile[7]}, {tile[8]}, {BOOL_MAP(sparse)}>>(c_m_n_dev_buf, c_m_n_host_result, c_m_n_dev_result, verify, args, stream);""" + else: + raise ValueError(f"Invalid tile configuration for group {group}: {tile}") content += f""" }} else {{""" for tile in tile_params: - # Check if we have valid tile/warp combinations - # (tile_m/(warp_m*warp_tile_m)) * warp_m * warp_tile_m == tile_m - if ((tile[0]/(tile[3] * tile[7]) * tile[3] * tile[7]) != tile[0]) or \ - ((tile[1]/(tile[4] * 
tile[8]) * tile[4] * tile[8]) != tile[1]): - continue - content += f""" + if self.is_tile_valid(tile, group): + content += f""" run_kernel<{group}::GemmKernel<{tile[0]}, {tile[1]}, {tile[2]}, {tile[3]}, {tile[4]}, {tile[5]}, {tile[6]}, {tile[7]}, {tile[8]}, {BOOL_MAP(False)}>>(c_m_n_dev_buf, c_m_n_host_result, c_m_n_dev_result, verify, args, stream);""" + else: + raise ValueError(f"Invalid tile configuration for group {group}: {tile}") content += f""" }} }};\n""" From 8cb0474b3d880abe55bca977856a4be104aac337 Mon Sep 17 00:00:00 2001 From: Po Yen Chen Date: Fri, 16 May 2025 02:47:29 +0800 Subject: [PATCH 122/443] Use only qr_async pipeline for batch_prefill (#2195) --- .../ck_tile/01_fmha/codegen/ops/fmha_batch_prefill.py | 9 ++++----- 1 file changed, 4 insertions(+), 5 deletions(-) diff --git a/example/ck_tile/01_fmha/codegen/ops/fmha_batch_prefill.py b/example/ck_tile/01_fmha/codegen/ops/fmha_batch_prefill.py index 30b9299963..76b9429b2e 100644 --- a/example/ck_tile/01_fmha/codegen/ops/fmha_batch_prefill.py +++ b/example/ck_tile/01_fmha/codegen/ops/fmha_batch_prefill.py @@ -470,11 +470,10 @@ def get_fwd_blobs(kernel_filter : Optional[str], receipt, optdim_list, mask_impl pipelines.append(FmhaFwdPipeline('qr', 'col', 't', 't', 't', 't', logits, bias, lse, dropout, squant, mask)) else: if bias == "bias": - # TODO: rocm 6.2 compiler problem if using qr_async for bias case - pipelines.append(FmhaFwdPipeline('qr', 'row', 'f', 'f', 'f', 'f', logits, bias, lse, dropout, squant, mask)) - pipelines.append(FmhaFwdPipeline('qr', 'row', 't', 't', 't', 't', logits, bias, lse, dropout, squant, mask)) - pipelines.append(FmhaFwdPipeline('qr', 'col', 'f', 'f', 'f', 'f', logits, bias, lse, dropout, squant, mask)) - pipelines.append(FmhaFwdPipeline('qr', 'col', 't', 't', 't', 't', logits, bias, lse, dropout, squant, mask)) + pipelines.append(FmhaFwdPipeline('qr_async', 'row', 'f', 'f', 'f', 'f', logits, bias, lse, dropout, squant, mask)) + 
pipelines.append(FmhaFwdPipeline('qr_async', 'row', 't', 't', 't', 't', logits, bias, lse, dropout, squant, mask)) + pipelines.append(FmhaFwdPipeline('qr_async', 'col', 'f', 'f', 'f', 'f', logits, bias, lse, dropout, squant, mask)) + pipelines.append(FmhaFwdPipeline('qr_async', 'col', 't', 't', 't', 't', logits, bias, lse, dropout, squant, mask)) else: pipelines.append(FmhaFwdPipeline('qr_async', 'row', 't', 'f', 't', 't', logits, bias, lse, dropout, squant, mask)) pipelines.append(FmhaFwdPipeline('qr_async', 'row', 't', 't', 't', 't', logits, bias, lse, dropout, squant, mask)) From 791802b381c99e47966cbf4a987b91ab3d56bcfc Mon Sep 17 00:00:00 2001 From: Po Yen Chen Date: Fri, 16 May 2025 15:14:46 +0800 Subject: [PATCH 123/443] [CK_TILE] fMHA batch_prefill block index & logits soft-capping optimizations (#2198) * Write soft-sign in inline asm * Change tile idx computation * Add macro to turn off soft-sign asm opt * Use simple for loop to avoid register spill * Only do block id transform for masking cases --- include/ck_tile/ops/fmha/block/variants.hpp | 38 ++++++++++++++++--- .../fmha/kernel/fmha_batch_prefill_kernel.hpp | 21 ++++++++-- ..._batch_prefill_pipeline_qr_ks_vs_async.hpp | 13 ++++++- 3 files changed, 63 insertions(+), 9 deletions(-) diff --git a/include/ck_tile/ops/fmha/block/variants.hpp b/include/ck_tile/ops/fmha/block/variants.hpp index 90fc5656fc..d8b0cdbb86 100644 --- a/include/ck_tile/ops/fmha/block/variants.hpp +++ b/include/ck_tile/ops/fmha/block/variants.hpp @@ -15,7 +15,36 @@ #define CK_TILE_ATTENTION_LOGITS_SOFT_CAP_DEFAULT CK_TILE_ATTENTION_LOGITS_SOFT_CAP_TANH #endif +#ifndef CK_TILE_ATTENTION_USE_SOFTSIGN_ASM +#define CK_TILE_ATTENTION_USE_SOFTSIGN_ASM 0 +#endif + namespace ck_tile { +namespace internal { +__device__ inline float +exp2_soft_sign_impl(float softmax_scale, float logits, float logits_soft_cap_rcp) +{ +#if(defined(__gfx90a__) || defined(__gfx94__)) && \ + (CK_TILE_ATTENTION_LOGITS_SOFT_CAP_DEFAULT == 
CK_TILE_ATTENTION_LOGITS_SOFT_CAP_SOFTSIGN && \ + CK_TILE_ATTENTION_USE_SOFTSIGN_ASM) + /// NOTICE: Make sure softmax_scale is stored in SGPR + float result, numerator, denominator; + asm volatile( + "v_mul_f32_e32 %[denominator], %[logits], %[logits_soft_cap_rcp]\n" + "v_add_f32_e64 %[denominator], |%[denominator]|, 1.0\n" + "v_rcp_f32_e32 %[denominator], %[denominator]\n" + "v_mul_f32_e32 %[numerator], %[softmax_scale], %[logits]\n" + "v_mul_f32_e32 %[result], %[numerator], %[denominator]" + : [numerator] "=&v"(numerator), [denominator] "=&v"(denominator), [result] "=v"(result) + : [softmax_scale] "s"(softmax_scale), + [logits] "v"(logits), + [logits_soft_cap_rcp] "v"(logits_soft_cap_rcp)); + return result; +#else + return softmax_scale * logits * rcp(1.f + abs(logits * logits_soft_cap_rcp)); +#endif +} +} // namespace internal template struct StandardAttentionParams @@ -169,8 +198,8 @@ struct LogitsSoftCap return params.logits_soft_cap * tanh_fast(type_convert(logits) * params.logits_soft_cap_rcp); #elif CK_TILE_ATTENTION_LOGITS_SOFT_CAP_DEFAULT == CK_TILE_ATTENTION_LOGITS_SOFT_CAP_SOFTSIGN - return params.sm_scale * type_convert(logits) * - rcp(1.f + abs(type_convert(logits) * params.logits_soft_cap_rcp)); + return internal::exp2_soft_sign_impl( + params.sm_scale, type_convert(logits), params.logits_soft_cap_rcp); #endif } else @@ -239,9 +268,8 @@ struct ComposedAttention return params.logits_soft_cap * tanh_fast(type_convert(logits) * params.logits_soft_cap_rcp); #elif CK_TILE_ATTENTION_LOGITS_SOFT_CAP_DEFAULT == CK_TILE_ATTENTION_LOGITS_SOFT_CAP_SOFTSIGN - return params.sm_scale * type_convert(logits) * - rcp(1.f + - abs(type_convert(logits) * params.logits_soft_cap_rcp)); + return internal::exp2_soft_sign_impl( + params.sm_scale, type_convert(logits), params.logits_soft_cap_rcp); #endif } else diff --git a/include/ck_tile/ops/fmha/kernel/fmha_batch_prefill_kernel.hpp b/include/ck_tile/ops/fmha/kernel/fmha_batch_prefill_kernel.hpp index ba327ee511..7472c82114 
100644 --- a/include/ck_tile/ops/fmha/kernel/fmha_batch_prefill_kernel.hpp +++ b/include/ck_tile/ops/fmha/kernel/fmha_batch_prefill_kernel.hpp @@ -651,8 +651,15 @@ struct FmhaBatchPrefillWithPagedKVCacheKernel }; const auto [i_tile_m, i_tile_n] = f(i_block, num_tile_n1); - - return ck_tile::make_tuple(i_tile_m, i_tile_n, i_nhead, i_batch); + if constexpr(kHasMask) + { + // assume that num_tile_n1 is always 1 + return ck_tile::make_tuple(gridDim.z - 1 - i_tile_m, i_tile_n, i_nhead, i_batch); + } + else + { + return ck_tile::make_tuple(i_tile_m, i_tile_n, i_nhead, i_batch); + } } else { @@ -672,7 +679,15 @@ struct FmhaBatchPrefillWithPagedKVCacheKernel const auto [i_tile_m, i_tile_n] = f(i_block, num_tile_n1); - return ck_tile::make_tuple(i_tile_m, i_tile_n, i_nhead, i_batch); + if constexpr(kHasMask) + { + // assume that num_tile_n1 is always 1 + return ck_tile::make_tuple(gridDim.x - 1 - i_tile_m, i_tile_n, i_nhead, i_batch); + } + else + { + return ck_tile::make_tuple(i_tile_m, i_tile_n, i_nhead, i_batch); + } } } diff --git a/include/ck_tile/ops/fmha/pipeline/block_fmha_batch_prefill_pipeline_qr_ks_vs_async.hpp b/include/ck_tile/ops/fmha/pipeline/block_fmha_batch_prefill_pipeline_qr_ks_vs_async.hpp index e07cf1c94e..8691622bb0 100644 --- a/include/ck_tile/ops/fmha/pipeline/block_fmha_batch_prefill_pipeline_qr_ks_vs_async.hpp +++ b/include/ck_tile/ops/fmha/pipeline/block_fmha_batch_prefill_pipeline_qr_ks_vs_async.hpp @@ -6,8 +6,9 @@ #include "ck_tile/core.hpp" #include "ck_tile/ops/common/tensor_layout.hpp" #include "ck_tile/ops/fmha/block/block_attention_bias_enum.hpp" -#include "ck_tile/ops/fmha/pipeline/block_fmha_batch_prefill_pipeline_qr_ks_vs_async_default_policy.hpp" #include "ck_tile/ops/fmha/block/block_dropout.hpp" +#include "ck_tile/ops/fmha/block/variants.hpp" +#include "ck_tile/ops/fmha/pipeline/block_fmha_batch_prefill_pipeline_qr_ks_vs_async_default_policy.hpp" #include "ck_tile/ops/reduce/block/block_reduce.hpp" namespace ck_tile { @@ -498,6 
+499,16 @@ struct BlockFmhaBatchPrefillPipelineQRKSVSAsync #else for(index_t i = 0; i < s_acc.thread_buf_.size(); ++i) { +#if(defined(__gfx90a__) || defined(__gfx94__)) && \ + (CK_TILE_ATTENTION_LOGITS_SOFT_CAP_DEFAULT == CK_TILE_ATTENTION_LOGITS_SOFT_CAP_SOFTSIGN && \ + CK_TILE_ATTENTION_USE_SOFTSIGN_ASM) + // Avoid data hazard if v_mfma is followed by inline asm consumer + // instructions. In this case, compiler won't add s_nop for us + if(i == s_acc.thread_buf_.size() / 2) + { + __builtin_amdgcn_sched_barrier(0); + } +#endif apply_logits_transform(s_acc.thread_buf_[i]); } #endif From fa3c6811d8e81096f52779bf0877777bf405d241 Mon Sep 17 00:00:00 2001 From: Mateusz Ozga <110818320+mozga-amd@users.noreply.github.com> Date: Fri, 16 May 2025 10:18:47 +0200 Subject: [PATCH 124/443] Disable conv for Filter1x1Stride1Pad0 when K or C is even (#2186) --- include/ck/ck.hpp | 3 +++ .../device_grouped_conv_bwd_weight_xdl_cshuffle_v3.hpp | 7 +++++++ .../test_grouped_convnd_bwd_weight.cpp | 1 + 3 files changed, 11 insertions(+) diff --git a/include/ck/ck.hpp b/include/ck/ck.hpp index e38f166c1a..26e4787949 100644 --- a/include/ck/ck.hpp +++ b/include/ck/ck.hpp @@ -222,6 +222,9 @@ // TODO: separate index calculation into "compile-time", "global", "block", "wave", "thread" #define CK_HACK_MERGE_CALCULATE_IDX_DIFF_LOW_CONST_USE_AMD_GCN_READ_FIRST_LANE 0 +// workaround: conv crash when K, C is even +#define CK_WORKAROUND_DISABLE_FILTER1x1STRIDE1PAD0_WHEN_K_C_IS_EVEN 1 + // workaround: compiler crash when compiling recursive lambda #define CK_WORKAROUND_SWDEV_275126 1 diff --git a/include/ck/tensor_operation/gpu/device/impl/device_grouped_conv_bwd_weight_xdl_cshuffle_v3.hpp b/include/ck/tensor_operation/gpu/device/impl/device_grouped_conv_bwd_weight_xdl_cshuffle_v3.hpp index dd5b97096d..869457a99e 100644 --- a/include/ck/tensor_operation/gpu/device/impl/device_grouped_conv_bwd_weight_xdl_cshuffle_v3.hpp +++ 
b/include/ck/tensor_operation/gpu/device/impl/device_grouped_conv_bwd_weight_xdl_cshuffle_v3.hpp @@ -1206,6 +1206,13 @@ struct DeviceGroupedConvBwdWeight_Xdl_CShuffleV3 if constexpr(ConvBackwardWeightSpecialization == ConvolutionBackwardWeightSpecialization::Filter1x1Stride1Pad0) { +// workaround: disable when K, C is even +#if CK_WORKAROUND_DISABLE_FILTER1x1STRIDE1PAD0_WHEN_K_C_IS_EVEN + if(arg.Conv_C_ % 2 == 0 || arg.Conv_K_ % 2 == 0) + { + return false; + } +#endif // check if it's 1x1, stride=1 pad = 0 conv for(int i = 0; i < NDimSpatial; i++) { diff --git a/test/grouped_convnd_bwd_weight/test_grouped_convnd_bwd_weight.cpp b/test/grouped_convnd_bwd_weight/test_grouped_convnd_bwd_weight.cpp index 21f2cb5ce6..95a0a09414 100644 --- a/test/grouped_convnd_bwd_weight/test_grouped_convnd_bwd_weight.cpp +++ b/test/grouped_convnd_bwd_weight/test_grouped_convnd_bwd_weight.cpp @@ -188,6 +188,7 @@ TYPED_TEST(TestGroupedConvndBwdWeight1d, Test1D) TYPED_TEST(TestGroupedConvndBwdWeight2d, Test2D) { this->conv_params.clear(); + this->conv_params.push_back({2, 2, 64, 4, 4, {1, 1}, {7, 7}, {1, 1}, {1, 1}, {0, 0}, {0, 0}}); this->conv_params.push_back( {2, 2, 64, 128, 256, {1, 1}, {7, 7}, {2, 2}, {1, 1}, {0, 0}, {0, 0}}); this->conv_params.push_back({2, 2, 64, 3, 3, {1, 1}, {7, 7}, {1, 1}, {1, 1}, {0, 0}, {0, 0}}); From 40668c9a993ca9391eb628bbb4be3ca3fb4e7e56 Mon Sep 17 00:00:00 2001 From: Illia Silin <98187287+illsilin@users.noreply.github.com> Date: Fri, 16 May 2025 07:40:53 -0700 Subject: [PATCH 125/443] Build and store CK library deb package for all targets daily. 
(#2196) * generate and store library package for all targets * use ninja to build packages for all targets * make sure to use ftime-trace when using ninja * make sure build trace only runs on gfx9 * archive lib package and stash only library package --- Jenkinsfile | 135 +++++++++--------- .../gpu/CMakeLists.txt | 2 +- 2 files changed, 67 insertions(+), 70 deletions(-) diff --git a/Jenkinsfile b/Jenkinsfile index 68e0fa1246..c26350f120 100644 --- a/Jenkinsfile +++ b/Jenkinsfile @@ -93,6 +93,30 @@ def build_compiler(){ return compiler } +def check_arch(){ + def arch_type = 0 + sh 'rocminfo | tee rocminfo.log' + if ( runShell('grep -n "gfx90a" rocminfo.log') ){ + arch_type = 1 + } + else if ( runShell('grep -n "gfx942" rocminfo.log') ) { + arch_type = 2 + } + else if ( runShell('grep -n "gfx10" rocminfo.log') ) { + arch_type = 3 + } + else if ( runShell('grep -n "gfx11" rocminfo.log') ) { + arch_type = 4 + } + else if ( runShell('grep -n "gfx12" rocminfo.log') ) { + arch_type = 5 + } + else if ( runShell('grep -n "gfx908" rocminfo.log') ) { + arch_type = 6 + } + return arch_type +} + def getDockerImage(Map conf=[:]){ env.DOCKER_BUILDKIT=1 def prefixpath = conf.get("prefixpath", "/opt/rocm") @@ -287,7 +311,7 @@ def cmake_build(Map conf=[:]){ def build_cmd def execute_cmd = conf.get("execute_cmd", "") if(!setup_args.contains("NO_CK_BUILD")){ - if (setup_args.contains("gfx90a") && params.NINJA_BUILD_TRACE){ + if (setup_args.contains("gfx9") && params.NINJA_BUILD_TRACE){ echo "running ninja build trace" setup_cmd = conf.get("setup_cmd", """${cmake_envs} cmake -G Ninja ${setup_args} -DCMAKE_CXX_FLAGS=" -O3 -ftime-trace " .. 
""") build_cmd = conf.get("build_cmd", "${build_envs} ninja -j${nt} ${config_targets}") @@ -315,7 +339,7 @@ def cmake_build(Map conf=[:]){ sh cmd //run tests except when NO_CK_BUILD or BUILD_LEGACY_OS are set if(!setup_args.contains("NO_CK_BUILD") && !params.BUILD_LEGACY_OS){ - if (setup_args.contains("gfx90a") && params.NINJA_BUILD_TRACE){ + if ((setup_args.contains("gfx9") && params.NINJA_BUILD_TRACE) || params.BUILD_INSTANCES_ONLY){ sh "/ninjatracing/ninjatracing .ninja_log > ck_build_trace.json" sh "/ClangBuildAnalyzer/build/ClangBuildAnalyzer --all . clang_build.log" sh "/ClangBuildAnalyzer/build/ClangBuildAnalyzer --analyze clang_build.log > clang_build_analysis.log" @@ -323,7 +347,15 @@ def cmake_build(Map conf=[:]){ archiveArtifacts "clang_build_analysis.log" // do not run unit tests when building instances only if(!params.BUILD_INSTANCES_ONLY){ - sh "ninja test" + sh "ninja check" + } + if(params.BUILD_INSTANCES_ONLY){ + // build deb packages + echo "Build packages" + sh 'ninja -j64 package' + archiveArtifacts artifacts: 'composablekernel-dev*.deb' + sh 'mv composablekernel-dev_*.deb composablekernel-dev_all_targets_1.1.0_amd64.deb' + stash includes: "composablekernel-dev_all_targets_1.1.0_amd64.deb", name: "packages" } } else{ @@ -340,21 +372,14 @@ def cmake_build(Map conf=[:]){ archiveArtifacts artifacts: "build/*.deb", allowEmptyArchive: true, fingerprint: true } //check the node gpu architecture - def arch_type = 0 - sh 'rocminfo | tee rocminfo.log' - if ( runShell('grep -n "gfx90a" rocminfo.log') ){ - arch_type = 1 - } - else if ( runShell('grep -n "gfx942" rocminfo.log') ) { - arch_type = 2 - } + def arch = check_arch() if (params.RUN_CK_TILE_FMHA_TESTS){ try{ archiveArtifacts "perf_fmha_*.log" - if (arch_type == 1){ + if (arch == 1){ stash includes: "perf_fmha_**_gfx90a.log", name: "perf_fmha_log_gfx90a" } - else if (arch_type == 2){ + else if (arch == 2){ stash includes: "perf_fmha_**_gfx942.log", name: "perf_fmha_log_gfx942" } } @@ -379,10 +404,10 
@@ def cmake_build(Map conf=[:]){ if (params.RUN_CK_TILE_GEMM_TESTS){ try{ archiveArtifacts "perf_tile_gemm_**.log" - if (arch_type == 1){ + if (arch == 1){ stash includes: "perf_tile_gemm_**_gfx90a.log", name: "perf_tile_gemm_log_gfx90a" } - else if (arch_type == 2){ + else if (arch == 2){ stash includes: "perf_tile_gemm_**_gfx942.log", name: "perf_tile_gemm_log_gfx942" } } @@ -410,7 +435,13 @@ def buildHipClangJob(Map conf=[:]){ def prefixpath = conf.get("prefixpath", "/opt/rocm") // Jenkins is complaining about the render group - def dockerOpts="--device=/dev/kfd --device=/dev/dri --group-add video --group-add render --cap-add=SYS_PTRACE --security-opt seccomp=unconfined" + def dockerOpts + if ( params.BUILD_INSTANCES_ONLY ){ + dockerOpts = "--group-add video --group-add render --cap-add=SYS_PTRACE --security-opt seccomp=unconfined" + } + else{ + dockerOpts = "--device=/dev/kfd --device=/dev/dri --group-add video --group-add render --cap-add=SYS_PTRACE --security-opt seccomp=unconfined" + } if (conf.get("enforce_xnack_on", false)) { dockerOpts = dockerOpts + " --env HSA_XNACK=1 " } @@ -521,28 +552,9 @@ def Build_CK(Map conf=[:]){ timeout(time: 20, unit: 'HOURS') { //check whether to run performance tests on this node - def arch_type = 0 - sh 'rocminfo | tee rocminfo.log' - if ( runShell('grep -n "gfx90a" rocminfo.log') ){ - arch_type = 1 - } - else if ( runShell('grep -n "gfx942" rocminfo.log') ) { - arch_type = 2 - } - else if ( runShell('grep -n "gfx10" rocminfo.log') ) { - arch_type = 3 - } - else if ( runShell('grep -n "gfx11" rocminfo.log') ) { - arch_type = 4 - } - else if ( runShell('grep -n "gfx12" rocminfo.log') ) { - arch_type = 5 - } - else if ( runShell('grep -n "gfx908" rocminfo.log') ) { - arch_type = 6 - } + def arch = check_arch() cmake_build(conf) - if ( params.RUN_INDUCTOR_TESTS && !params.BUILD_LEGACY_OS && arch_type == 1 ){ + if ( params.RUN_INDUCTOR_TESTS && !params.BUILD_LEGACY_OS && arch == 1 ){ echo "Run inductor codegen tests" sh """ 
python3 -m venv ${env.WORKSPACE} @@ -553,9 +565,9 @@ def Build_CK(Map conf=[:]){ """ } dir("build"){ - if (params.RUN_FULL_QA && arch_type == 2 ){ - // build deb packages for all gfx9 targets on gfx90a system and prepare to export - echo "Build ckProfiler package" + if (params.RUN_FULL_QA && arch == 2 ){ + // build deb packages + echo "Build packages" sh 'make -j package' archiveArtifacts artifacts: 'composablekernel*.deb' sh 'mv composablekernel-ckprofiler_*.deb composablekernel-ckprofiler_1.1.0_amd64.deb' @@ -568,7 +580,7 @@ def Build_CK(Map conf=[:]){ // run performance tests, stash the logs, results will be processed on the master node dir("script"){ if (params.RUN_PERFORMANCE_TESTS){ - if (params.RUN_FULL_QA && arch_type == 1){ + if (params.RUN_FULL_QA && arch == 1){ // run full tests on gfx90a echo "Run full performance tests" sh "./run_full_performance_tests.sh 0 QA_${params.COMPILER_VERSION} ${env.BRANCH_NAME} ${NODE_NAME}" @@ -587,7 +599,7 @@ def Build_CK(Map conf=[:]){ archiveArtifacts "perf_mixed_gemm.log" stash includes: "perf_**.log", name: "perf_log" } - else if ( arch_type == 1 ){ + else if ( arch == 1 ){ // run standard tests on gfx90a echo "Run performance tests" sh "./run_performance_tests.sh 0 CI_${params.COMPILER_VERSION} ${env.BRANCH_NAME} ${NODE_NAME}" @@ -598,28 +610,28 @@ def Build_CK(Map conf=[:]){ stash includes: "perf_**.log", name: "perf_log" } // disable performance tests on gfx1030 for now. 
- //else if ( arch_type == 3){ + //else if ( arch == 3){ // run basic tests on gfx1030 // echo "Run gemm performance tests" // sh "./run_gemm_performance_tests.sh 0 CI_${params.COMPILER_VERSION} ${env.BRANCH_NAME} ${NODE_NAME} gfx10" // archiveArtifacts "perf_onnx_gemm_gfx10.log" // stash includes: "perf_onnx_gemm_gfx10.log", name: "perf_log_gfx10" //} - else if ( arch_type == 4){ + else if ( arch == 4){ // run basic tests on gfx11 echo "Run gemm performance tests" sh "./run_gemm_performance_tests.sh 0 CI_${params.COMPILER_VERSION} ${env.BRANCH_NAME} ${NODE_NAME} gfx11" archiveArtifacts "perf_onnx_gemm_gfx11.log" stash includes: "perf_onnx_gemm_gfx11.log", name: "perf_log_gfx11" } - else if ( arch_type == 5 ){ + else if ( arch == 5 ){ // run basic tests on gfx12 echo "Run gemm performance tests" sh "./run_gemm_performance_tests.sh 0 CI_${params.COMPILER_VERSION} ${env.BRANCH_NAME} ${NODE_NAME} gfx12" archiveArtifacts "perf_onnx_gemm_gfx12.log" stash includes: "perf_onnx_gemm_gfx12.log", name: "perf_log_gfx12" } - else if ( arch_type == 6 ){ + else if ( arch == 6 ){ // run basic tests on gfx908 echo "Run performance tests" sh "./run_gemm_performance_tests.sh 0 CI_${params.COMPILER_VERSION} ${env.BRANCH_NAME} ${NODE_NAME} gfx908" @@ -628,7 +640,7 @@ def Build_CK(Map conf=[:]){ } } } - if (params.hipTensor_test && arch_type == 1 ){ + if (params.hipTensor_test && arch == 1 ){ // build and test hipTensor on gfx90a node sh """#!/bin/bash rm -rf "${params.hipTensor_branch}".zip @@ -730,24 +742,10 @@ def process_results(Map conf=[:]){ echo "could not locate the GEMM performance logs: ${err.getMessage()}." 
} } - if (params.RUN_FULL_QA){ - // unstash perf files to master + if (params.RUN_FULL_QA || params.BUILD_INSTANCES_ONLY){ + // unstash deb packages unstash "packages" sh "sshpass -p ${env.ck_deb_pw} scp -o StrictHostKeyChecking=no composablekernel-*.deb ${env.ck_deb_user}@${env.ck_deb_ip}:/var/www/html/composable_kernel/" - try{ - unstash "perf_log" - } - catch(Exception err){ - echo "could not locate perf_log: ${err.getMessage()}." - } - try{ - unstash "perf_log_gfx11" - unstash "perf_log_gfx12" - } - catch(Exception err){ - echo "could not locate the GEMM gfx11/gfx12 performance logs: ${err.getMessage()}." - } - sh "./process_qa_data.sh" } else{ // unstash perf files to master @@ -775,12 +773,12 @@ def process_results(Map conf=[:]){ } } -//launch develop branch daily at 23:00 UT in FULL_QA mode and at 19:00 UT with latest staging compiler version -CRON_SETTINGS = BRANCH_NAME == "develop" ? '''0 23 * * * % RUN_FULL_QA=true;DISABLE_DL_KERNELS=true;ROCMVERSION=6.4;RUN_CK_TILE_FMHA_TESTS=true;RUN_CK_TILE_TRANSPOSE_TESTS=true;RUN_CK_TILE_GEMM_TESTS=true - 0 21 * * * % ROCMVERSION=6.4;hipTensor_test=true;RUN_CODEGEN_TESTS=true;BUILD_GFX908=true +//launch develop branch daily jobs +CRON_SETTINGS = BRANCH_NAME == "develop" ? 
'''0 23 * * * % RUN_FULL_QA=true;DISABLE_DL_KERNELS=true;RUN_CK_TILE_FMHA_TESTS=true;RUN_CK_TILE_TRANSPOSE_TESTS=true;RUN_CK_TILE_GEMM_TESTS=true + 0 21 * * * % RUN_GROUPED_CONV_LARGE_CASES_TESTS=true;hipTensor_test=true;RUN_CODEGEN_TESTS=true;BUILD_GFX908=true 0 19 * * * % BUILD_DOCKER=true;COMPILER_VERSION=amd-staging;BUILD_COMPILER=/llvm-project/build/bin/clang++;USE_SCCACHE=false;NINJA_BUILD_TRACE=true 0 17 * * * % BUILD_DOCKER=true;COMPILER_VERSION=amd-mainline;BUILD_COMPILER=/llvm-project/build/bin/clang++;USE_SCCACHE=false;NINJA_BUILD_TRACE=true - 0 15 * * * % BUILD_INSTANCES_ONLY=true;RUN_PERFORMANCE_TESTS=false;USE_SCCACHE=false + 0 15 * * * % BUILD_INSTANCES_ONLY=true;USE_SCCACHE=false;NINJA_BUILD_TRACE=true 0 13 * * * % BUILD_LEGACY_OS=true;USE_SCCACHE=false;RUN_PERFORMANCE_TESTS=false''' : "" pipeline { @@ -1263,8 +1261,7 @@ pipeline { execute_args = """ cmake -G Ninja -D CMAKE_PREFIX_PATH=/opt/rocm \ -D CMAKE_CXX_COMPILER="${build_compiler()}" \ -D CMAKE_BUILD_TYPE=Release \ - -D GPU_ARCHS="gfx908;gfx90a;gfx942;gfx950;gfx1030;gfx1100;gfx1151;gfx1201" \ - -D CMAKE_CXX_FLAGS=" -O3 " .. && ninja -j64 """ + -D CMAKE_CXX_FLAGS=" -O3 -ftime-trace" .. 
&& ninja -j64 """ } steps{ buildHipClangJobAndReboot(setup_cmd: "", build_cmd: "", no_reboot:true, build_type: 'Release', execute_cmd: execute_args) diff --git a/library/src/tensor_operation_instance/gpu/CMakeLists.txt b/library/src/tensor_operation_instance/gpu/CMakeLists.txt index 25ea3b2ae4..97946207a1 100755 --- a/library/src/tensor_operation_instance/gpu/CMakeLists.txt +++ b/library/src/tensor_operation_instance/gpu/CMakeLists.txt @@ -103,7 +103,7 @@ function(add_instance_library INSTANCE_NAME) list(REMOVE_ITEM ARGN "${source}") endif() endforeach() - message("remaining instances: ${ARGN}") + #message("remaining instances: ${ARGN}") #only continue if there are some source files left on the list if(ARGN) set(INST_OBJ) From 5b3430b868766068dabcc92394f0da65d9206099 Mon Sep 17 00:00:00 2001 From: arai713 <67439843+arai713@users.noreply.github.com> Date: Fri, 16 May 2025 11:11:54 -0700 Subject: [PATCH 126/443] Narrowing error fix for codegen compilation (#2194) * removed comment with special characters * fix for arg/template change after merge from develop --------- Co-authored-by: Thomas Ning --- ...e_gemm_pipeline_xdlops_b_preshuffle_v3.hpp | 1 - .../device_gemm_multiple_d_xdl_cshuffle.hpp | 54 ++++++++++--------- 2 files changed, 28 insertions(+), 27 deletions(-) diff --git a/include/ck/tensor_operation/gpu/block/blockwise_gemm_pipeline_xdlops_b_preshuffle_v3.hpp b/include/ck/tensor_operation/gpu/block/blockwise_gemm_pipeline_xdlops_b_preshuffle_v3.hpp index 6f3a7e6357..6f0404a1ca 100644 --- a/include/ck/tensor_operation/gpu/block/blockwise_gemm_pipeline_xdlops_b_preshuffle_v3.hpp +++ b/include/ck/tensor_operation/gpu/block/blockwise_gemm_pipeline_xdlops_b_preshuffle_v3.hpp @@ -381,7 +381,6 @@ struct BlockwiseGemmXdlops_pipeline_bpreshuffle_v3{}([&](auto m0) { static_for<0, KRepeat, 1>{}([&](auto k0) { static_for<0, KGroup, 1>{}([&](auto kg0) { - // K = k0 × KGroup × k1 = k0 × kg0 × A_K1 a_thread_copy_.Run(a_block_desc_m0_m1_m2_k0_k1_k2, make_tuple(m0, I0, I0, 
Number{}, I0, I0), a_block_buf.At(I0), diff --git a/include/ck/tensor_operation/gpu/device/impl/device_gemm_multiple_d_xdl_cshuffle.hpp b/include/ck/tensor_operation/gpu/device/impl/device_gemm_multiple_d_xdl_cshuffle.hpp index 6c4195e75d..f193b093d1 100644 --- a/include/ck/tensor_operation/gpu/device/impl/device_gemm_multiple_d_xdl_cshuffle.hpp +++ b/include/ck/tensor_operation/gpu/device/impl/device_gemm_multiple_d_xdl_cshuffle.hpp @@ -860,35 +860,37 @@ struct DeviceGemmMultipleD_Xdl_CShuffle : public DeviceGemmMultipleD(p_a_grid, - p_b_grid, - p_ds_grid, - p_e_grid, - p_shared_block, - desc.a_element_op, - desc.b_element_op, - desc.cde_element_op, - desc.a_grid_desc_ak0_m_ak1, - desc.b_grid_desc_bk0_n_bk1, - desc.ds_grid_desc_mblock_mperblock_nblock_nperblock, - desc.e_grid_desc_mblock_mperblock_nblock_nperblock, - desc.block_2_etile_map); + GridwiseGemm::template Run( + p_a_grid, + p_b_grid, + p_ds_grid, + p_e_grid, + p_shared_block, + desc.a_element_op, + desc.b_element_op, + desc.cde_element_op, + desc.a_grid_desc_ak0_m_ak1, + desc.b_grid_desc_bk0_n_bk1, + desc.ds_grid_desc_mblock_mperblock_nblock_nperblock, + desc.e_grid_desc_mblock_mperblock_nblock_nperblock, + desc.block_2_etile_map); } else { - GridwiseGemm::template Run(p_a_grid, - p_b_grid, - p_ds_grid, - p_e_grid, - p_shared_block, - desc.a_element_op, - desc.b_element_op, - desc.cde_element_op, - desc.a_grid_desc_ak0_m_ak1, - desc.b_grid_desc_bk0_n_bk1, - desc.ds_grid_desc_mblock_mperblock_nblock_nperblock, - desc.e_grid_desc_mblock_mperblock_nblock_nperblock, - desc.block_2_etile_map); + GridwiseGemm::template Run( + p_a_grid, + p_b_grid, + p_ds_grid, + p_e_grid, + p_shared_block, + desc.a_element_op, + desc.b_element_op, + desc.cde_element_op, + desc.a_grid_desc_ak0_m_ak1, + desc.b_grid_desc_bk0_n_bk1, + desc.ds_grid_desc_mblock_mperblock_nblock_nperblock, + desc.e_grid_desc_mblock_mperblock_nblock_nperblock, + desc.block_2_etile_map); } } }; From 6342f6b5e8bbb9f2b4cefa33d2a863a8bb35329b Mon Sep 17 
00:00:00 2001 From: =?UTF-8?q?Bart=C5=82omiej=20Kocot?= Date: Sat, 17 May 2025 03:42:02 +0200 Subject: [PATCH 127/443] Restore oddc instances (#2201) --- .../gpu/grouped_convolution_forward.hpp | 8 ++ .../gpu/grouped_convolution_forward_wmma.inc | 111 ++++++++++++++++++ .../gpu/grouped_conv2d_fwd/CMakeLists.txt | 4 + ...ma_gnhwc_gkyxc_gnhwk_f16_oddc_instance.cpp | 40 +++++++ ...mma_gnhwc_gkyxc_gnhwk_i8_oddc_instance.cpp | 40 +++++++ ...ma_nhwgc_gkyxc_nhwgk_f16_oddc_instance.cpp | 40 +++++++ ...mma_nhwgc_gkyxc_nhwgk_i8_oddc_instance.cpp | 40 +++++++ ...hwgc_gkyxc_nhwgk_bf16_comp_2x_instance.cpp | 9 ++ ...l_nhwgc_gkyxc_nhwgk_bf16_comp_instance.cpp | 8 ++ ...c_gkyxc_nhwgk_bf16_comp_part2_instance.cpp | 9 ++ ...nhwgc_gkyxc_nhwgk_f16_comp_2x_instance.cpp | 9 ++ ...dl_nhwgc_gkyxc_nhwgk_f16_comp_instance.cpp | 8 ++ ...gc_gkyxc_nhwgk_f16_comp_part2_instance.cpp | 9 ++ ...dl_nhwgc_gkyxc_nhwgk_f32_comp_instance.cpp | 10 +- ...l_nhwgc_gkyxc_nhwgk_int8_comp_instance.cpp | 28 ++++- ...wd_xdl_gnhwc_gkyxc_gnhwk_bf16_instance.cpp | 10 +- ...fwd_xdl_gnhwc_gkyxc_gnhwk_f16_instance.cpp | 10 +- ...fwd_xdl_gnhwc_gkyxc_gnhwk_f32_instance.cpp | 10 +- ...wd_xdl_nhwgc_gkyxc_nhwgk_bf16_instance.cpp | 10 +- ...fwd_xdl_nhwgc_gkyxc_nhwgk_f16_instance.cpp | 10 +- ...fwd_xdl_nhwgc_gkyxc_nhwgk_f32_instance.cpp | 10 +- ...wd_xdl_nhwgc_gkyxc_nhwgk_int8_instance.cpp | 10 +- ...gc_gkyxc_nhwgk_bf16_mem_inter_instance.cpp | 11 +- ...gc_gkyxc_nhwgk_bf16_mem_intra_instance.cpp | 11 +- ...wgc_gkyxc_nhwgk_f16_mem_inter_instance.cpp | 11 +- ...wgc_gkyxc_nhwgk_f16_mem_intra_instance.cpp | 11 +- ...wgc_gkyxc_nhwgk_f32_mem_inter_instance.cpp | 11 +- ...wgc_gkyxc_nhwgk_f32_mem_intra_instance.cpp | 11 +- ...gc_gkyxc_nhwgk_int8_mem_inter_instance.cpp | 11 +- ...gc_gkyxc_nhwgk_int8_mem_intra_instance.cpp | 11 +- .../gpu/grouped_conv3d_fwd/CMakeLists.txt | 4 + ...gndhwc_gkzyxc_gndhwk_f16_oddc_instance.cpp | 41 +++++++ ..._gndhwc_gkzyxc_gndhwk_i8_oddc_instance.cpp | 41 +++++++ 
...ndhwgc_gkzyxc_ndhwgk_f16_oddc_instance.cpp | 41 +++++++ ..._ndhwgc_gkzyxc_ndhwgk_i8_oddc_instance.cpp | 41 +++++++ 35 files changed, 682 insertions(+), 17 deletions(-) create mode 100644 library/src/tensor_operation_instance/gpu/grouped_conv2d_fwd/wmma/device_grouped_conv2d_fwd_wmma_gnhwc_gkyxc_gnhwk_f16_oddc_instance.cpp create mode 100644 library/src/tensor_operation_instance/gpu/grouped_conv2d_fwd/wmma/device_grouped_conv2d_fwd_wmma_gnhwc_gkyxc_gnhwk_i8_oddc_instance.cpp create mode 100644 library/src/tensor_operation_instance/gpu/grouped_conv2d_fwd/wmma/device_grouped_conv2d_fwd_wmma_nhwgc_gkyxc_nhwgk_f16_oddc_instance.cpp create mode 100644 library/src/tensor_operation_instance/gpu/grouped_conv2d_fwd/wmma/device_grouped_conv2d_fwd_wmma_nhwgc_gkyxc_nhwgk_i8_oddc_instance.cpp create mode 100644 library/src/tensor_operation_instance/gpu/grouped_conv3d_fwd/wmma/device_grouped_conv3d_fwd_wmma_gndhwc_gkzyxc_gndhwk_f16_oddc_instance.cpp create mode 100644 library/src/tensor_operation_instance/gpu/grouped_conv3d_fwd/wmma/device_grouped_conv3d_fwd_wmma_gndhwc_gkzyxc_gndhwk_i8_oddc_instance.cpp create mode 100644 library/src/tensor_operation_instance/gpu/grouped_conv3d_fwd/wmma/device_grouped_conv3d_fwd_wmma_ndhwgc_gkzyxc_ndhwgk_f16_oddc_instance.cpp create mode 100644 library/src/tensor_operation_instance/gpu/grouped_conv3d_fwd/wmma/device_grouped_conv3d_fwd_wmma_ndhwgc_gkzyxc_ndhwgk_i8_oddc_instance.cpp diff --git a/library/include/ck/library/tensor_operation_instance/gpu/grouped_convolution_forward.hpp b/library/include/ck/library/tensor_operation_instance/gpu/grouped_convolution_forward.hpp index cf5dbaa323..545826650c 100644 --- a/library/include/ck/library/tensor_operation_instance/gpu/grouped_convolution_forward.hpp +++ b/library/include/ck/library/tensor_operation_instance/gpu/grouped_convolution_forward.hpp @@ -613,6 +613,7 @@ struct DeviceOperationInstanceFactory>>& instances); +void add_device_grouped_conv2d_fwd_wmma_nhwgc_gkyxc_nhwgk_i8_oddc_instances( + 
std::vector>>& instances); + void add_device_grouped_conv2d_fwd_wmma_gnhwc_gkyxc_gnhwk_i8_instances( std::vector>>& instances); +void add_device_grouped_conv2d_fwd_wmma_gnhwc_gkyxc_gnhwk_i8_oddc_instances( + std::vector>>& instances); + void add_device_grouped_conv3d_fwd_wmma_gndhwc_gkzyxc_gndhwk_i8_instances( std::vector>>& instances); +void add_device_grouped_conv3d_fwd_wmma_gndhwc_gkzyxc_gndhwk_i8_oddc_instances( + std::vector>>& instances); + void add_device_grouped_conv3d_fwd_wmma_ndhwgc_gkzyxc_ndhwgk_i8_instances( std::vector>>& instances); +void add_device_grouped_conv3d_fwd_wmma_ndhwgc_gkzyxc_ndhwgk_i8_oddc_instances( + std::vector>>& instances); #endif #ifdef CK_ENABLE_FP16 @@ -236,6 +291,20 @@ void add_device_grouped_conv2d_fwd_wmma_nhwgc_gkyxc_nhwgk_f16_1x1s1p0_instances( PassThrough, PassThrough>>>& instances); +void add_device_grouped_conv2d_fwd_wmma_nhwgc_gkyxc_nhwgk_f16_oddc_instances( + std::vector>>& instances); + void add_device_grouped_conv3d_fwd_wmma_gndhwc_gkzyxc_gndhwk_f16_instances( std::vector>>& instances); +void add_device_grouped_conv3d_fwd_wmma_gndhwc_gkzyxc_gndhwk_f16_oddc_instances( + std::vector>>& instances); + void add_device_grouped_conv3d_fwd_wmma_ndhwgc_gkzyxc_ndhwgk_f16_instances( std::vector>>& instances); +void add_device_grouped_conv3d_fwd_wmma_ndhwgc_gkzyxc_ndhwgk_f16_oddc_instances( + std::vector>>& instances); + void add_device_grouped_conv2d_fwd_wmma_gnhwc_gkyxc_gnhwk_f16_instances( std::vector>>& instances); + +void add_device_grouped_conv2d_fwd_wmma_gnhwc_gkyxc_gnhwk_f16_oddc_instances( + std::vector>>& instances); #endif } // namespace instance diff --git a/library/src/tensor_operation_instance/gpu/grouped_conv2d_fwd/CMakeLists.txt b/library/src/tensor_operation_instance/gpu/grouped_conv2d_fwd/CMakeLists.txt index eba6fd789e..22e9d726b0 100644 --- a/library/src/tensor_operation_instance/gpu/grouped_conv2d_fwd/CMakeLists.txt +++ b/library/src/tensor_operation_instance/gpu/grouped_conv2d_fwd/CMakeLists.txt @@ -93,6 +93,8 
@@ add_instance_library(device_grouped_conv2d_fwd_instance wmma/device_grouped_conv2d_fwd_wmma_gnhwc_gkyxc_gnhwk_i8_1x1p0_instance.cpp wmma/device_grouped_conv2d_fwd_wmma_gnhwc_gkyxc_gnhwk_f16_1x1s1p0_instance.cpp wmma/device_grouped_conv2d_fwd_wmma_gnhwc_gkyxc_gnhwk_i8_1x1s1p0_instance.cpp + wmma/device_grouped_conv2d_fwd_wmma_gnhwc_gkyxc_gnhwk_f16_oddc_instance.cpp + wmma/device_grouped_conv2d_fwd_wmma_gnhwc_gkyxc_gnhwk_i8_oddc_instance.cpp ## NHWGC, GKYXC, NHWGK wmma/device_grouped_conv2d_fwd_wmma_nhwgc_gkyxc_nhwgk_f16_instance.cpp wmma/device_grouped_conv2d_fwd_wmma_nhwgc_gkyxc_nhwgk_i8_instance.cpp @@ -100,4 +102,6 @@ add_instance_library(device_grouped_conv2d_fwd_instance wmma/device_grouped_conv2d_fwd_wmma_nhwgc_gkyxc_nhwgk_i8_1x1p0_instance.cpp wmma/device_grouped_conv2d_fwd_wmma_nhwgc_gkyxc_nhwgk_f16_1x1s1p0_instance.cpp wmma/device_grouped_conv2d_fwd_wmma_nhwgc_gkyxc_nhwgk_i8_1x1s1p0_instance.cpp + wmma/device_grouped_conv2d_fwd_wmma_nhwgc_gkyxc_nhwgk_f16_oddc_instance.cpp + wmma/device_grouped_conv2d_fwd_wmma_nhwgc_gkyxc_nhwgk_i8_oddc_instance.cpp ) diff --git a/library/src/tensor_operation_instance/gpu/grouped_conv2d_fwd/wmma/device_grouped_conv2d_fwd_wmma_gnhwc_gkyxc_gnhwk_f16_oddc_instance.cpp b/library/src/tensor_operation_instance/gpu/grouped_conv2d_fwd/wmma/device_grouped_conv2d_fwd_wmma_gnhwc_gkyxc_gnhwk_f16_oddc_instance.cpp new file mode 100644 index 0000000000..a8f723dfec --- /dev/null +++ b/library/src/tensor_operation_instance/gpu/grouped_conv2d_fwd/wmma/device_grouped_conv2d_fwd_wmma_gnhwc_gkyxc_gnhwk_f16_oddc_instance.cpp @@ -0,0 +1,40 @@ +// SPDX-License-Identifier: MIT +// Copyright (c) 2018-2023, Advanced Micro Devices, Inc. All rights reserved. 
+ +#include "ck/library/tensor_operation_instance/add_device_operation_instance.hpp" +#include "ck/library/tensor_operation_instance/gpu/grouped_conv_fwd/device_grouped_conv_fwd_wmma_instance.hpp" + +namespace ck { +namespace tensor_operation { +namespace device { +namespace instance { +// Compilation parameters for in[g, n, hi, wi, c] * wei[g, k, y, x, c] = out[g, n, ho, wo, k] +void add_device_grouped_conv2d_fwd_wmma_gnhwc_gkyxc_gnhwk_f16_oddc_instances( + std::vector>>& instances) +{ + add_device_operation_instances(instances, + device_grouped_conv_fwd_wmma_f16_instances<2, + GNHWC, + GKYXC, + Empty_Tuple, + GNHWK, + Empty_Tuple, + PassThrough, + ConvFwdOddC>{}); +} + +} // namespace instance +} // namespace device +} // namespace tensor_operation +} // namespace ck diff --git a/library/src/tensor_operation_instance/gpu/grouped_conv2d_fwd/wmma/device_grouped_conv2d_fwd_wmma_gnhwc_gkyxc_gnhwk_i8_oddc_instance.cpp b/library/src/tensor_operation_instance/gpu/grouped_conv2d_fwd/wmma/device_grouped_conv2d_fwd_wmma_gnhwc_gkyxc_gnhwk_i8_oddc_instance.cpp new file mode 100644 index 0000000000..784a118897 --- /dev/null +++ b/library/src/tensor_operation_instance/gpu/grouped_conv2d_fwd/wmma/device_grouped_conv2d_fwd_wmma_gnhwc_gkyxc_gnhwk_i8_oddc_instance.cpp @@ -0,0 +1,40 @@ +// SPDX-License-Identifier: MIT +// Copyright (c) 2018-2023, Advanced Micro Devices, Inc. All rights reserved. 
+ +#include "ck/library/tensor_operation_instance/add_device_operation_instance.hpp" +#include "ck/library/tensor_operation_instance/gpu/grouped_conv_fwd/device_grouped_conv_fwd_wmma_instance.hpp" + +namespace ck { +namespace tensor_operation { +namespace device { +namespace instance { +// Compilation parameters for in[g, n, hi, wi, c] * wei[g, k, y, x, c] = out[g, n, ho, wo, k] +void add_device_grouped_conv2d_fwd_wmma_gnhwc_gkyxc_gnhwk_i8_oddc_instances( + std::vector>>& instances) +{ + add_device_operation_instances(instances, + device_grouped_conv_fwd_wmma_i8_instances<2, + GNHWC, + GKYXC, + Empty_Tuple, + GNHWK, + Empty_Tuple, + PassThrough, + ConvFwdOddC>{}); +} + +} // namespace instance +} // namespace device +} // namespace tensor_operation +} // namespace ck diff --git a/library/src/tensor_operation_instance/gpu/grouped_conv2d_fwd/wmma/device_grouped_conv2d_fwd_wmma_nhwgc_gkyxc_nhwgk_f16_oddc_instance.cpp b/library/src/tensor_operation_instance/gpu/grouped_conv2d_fwd/wmma/device_grouped_conv2d_fwd_wmma_nhwgc_gkyxc_nhwgk_f16_oddc_instance.cpp new file mode 100644 index 0000000000..8c621543a9 --- /dev/null +++ b/library/src/tensor_operation_instance/gpu/grouped_conv2d_fwd/wmma/device_grouped_conv2d_fwd_wmma_nhwgc_gkyxc_nhwgk_f16_oddc_instance.cpp @@ -0,0 +1,40 @@ +// SPDX-License-Identifier: MIT +// Copyright (c) 2018-2023, Advanced Micro Devices, Inc. All rights reserved. 
+ +#include "ck/library/tensor_operation_instance/add_device_operation_instance.hpp" +#include "ck/library/tensor_operation_instance/gpu/grouped_conv_fwd/device_grouped_conv_fwd_wmma_instance.hpp" + +namespace ck { +namespace tensor_operation { +namespace device { +namespace instance { +// Compilation parameters for in[n, hi, wi, g, c] * wei[g, k, y, x, c] = out[n, ho, wo, g, k] +void add_device_grouped_conv2d_fwd_wmma_nhwgc_gkyxc_nhwgk_f16_oddc_instances( + std::vector>>& instances) +{ + add_device_operation_instances(instances, + device_grouped_conv_fwd_wmma_f16_instances<2, + NHWGC, + GKYXC, + Empty_Tuple, + NHWGK, + Empty_Tuple, + PassThrough, + ConvFwdOddC>{}); +} + +} // namespace instance +} // namespace device +} // namespace tensor_operation +} // namespace ck diff --git a/library/src/tensor_operation_instance/gpu/grouped_conv2d_fwd/wmma/device_grouped_conv2d_fwd_wmma_nhwgc_gkyxc_nhwgk_i8_oddc_instance.cpp b/library/src/tensor_operation_instance/gpu/grouped_conv2d_fwd/wmma/device_grouped_conv2d_fwd_wmma_nhwgc_gkyxc_nhwgk_i8_oddc_instance.cpp new file mode 100644 index 0000000000..5cb313b3ca --- /dev/null +++ b/library/src/tensor_operation_instance/gpu/grouped_conv2d_fwd/wmma/device_grouped_conv2d_fwd_wmma_nhwgc_gkyxc_nhwgk_i8_oddc_instance.cpp @@ -0,0 +1,40 @@ +// SPDX-License-Identifier: MIT +// Copyright (c) 2018-2023, Advanced Micro Devices, Inc. All rights reserved. 
+ +#include "ck/library/tensor_operation_instance/add_device_operation_instance.hpp" +#include "ck/library/tensor_operation_instance/gpu/grouped_conv_fwd/device_grouped_conv_fwd_wmma_instance.hpp" + +namespace ck { +namespace tensor_operation { +namespace device { +namespace instance { +// Compilation parameters for in[n, hi, wi, g, c] * wei[g, k, y, x, c] = out[n, ho, wo, g, k] +void add_device_grouped_conv2d_fwd_wmma_nhwgc_gkyxc_nhwgk_i8_oddc_instances( + std::vector>>& instances) +{ + add_device_operation_instances(instances, + device_grouped_conv_fwd_wmma_i8_instances<2, + NHWGC, + GKYXC, + Empty_Tuple, + NHWGK, + Empty_Tuple, + PassThrough, + ConvFwdOddC>{}); +} + +} // namespace instance +} // namespace device +} // namespace tensor_operation +} // namespace ck diff --git a/library/src/tensor_operation_instance/gpu/grouped_conv2d_fwd/xdl/comp/device_grouped_conv2d_fwd_xdl_nhwgc_gkyxc_nhwgk_bf16_comp_2x_instance.cpp b/library/src/tensor_operation_instance/gpu/grouped_conv2d_fwd/xdl/comp/device_grouped_conv2d_fwd_xdl_nhwgc_gkyxc_nhwgk_bf16_comp_2x_instance.cpp index f5df7278d0..c078f8ed04 100644 --- a/library/src/tensor_operation_instance/gpu/grouped_conv2d_fwd/xdl/comp/device_grouped_conv2d_fwd_xdl_nhwgc_gkyxc_nhwgk_bf16_comp_2x_instance.cpp +++ b/library/src/tensor_operation_instance/gpu/grouped_conv2d_fwd/xdl/comp/device_grouped_conv2d_fwd_xdl_nhwgc_gkyxc_nhwgk_bf16_comp_2x_instance.cpp @@ -52,6 +52,15 @@ void add_device_grouped_conv2d_fwd_xdl_nhwgc_gkyxc_nhwgk_bf16_comp_2x_instances( Empty_Tuple, NHWGK, ConvFwd1x1S1P0>{}); + + add_device_operation_instances( + instances, + device_grouped_conv_fwd_xdl_bf16_comp_instances_2x<2, + NHWGC, + GKYXC, + Empty_Tuple, + NHWGK, + ConvFwdOddC>{}); } } diff --git a/library/src/tensor_operation_instance/gpu/grouped_conv2d_fwd/xdl/comp/device_grouped_conv2d_fwd_xdl_nhwgc_gkyxc_nhwgk_bf16_comp_instance.cpp 
b/library/src/tensor_operation_instance/gpu/grouped_conv2d_fwd/xdl/comp/device_grouped_conv2d_fwd_xdl_nhwgc_gkyxc_nhwgk_bf16_comp_instance.cpp index db048679bd..a67b11f1cf 100644 --- a/library/src/tensor_operation_instance/gpu/grouped_conv2d_fwd/xdl/comp/device_grouped_conv2d_fwd_xdl_nhwgc_gkyxc_nhwgk_bf16_comp_instance.cpp +++ b/library/src/tensor_operation_instance/gpu/grouped_conv2d_fwd/xdl/comp/device_grouped_conv2d_fwd_xdl_nhwgc_gkyxc_nhwgk_bf16_comp_instance.cpp @@ -49,6 +49,14 @@ void add_device_grouped_conv2d_fwd_xdl_nhwgc_gkyxc_nhwgk_bf16_comp_instances( Empty_Tuple, NHWGK, ConvFwd1x1S1P0>{}); + + add_device_operation_instances(instances, + device_grouped_conv_fwd_xdl_bf16_comp_instances<2, + NHWGC, + GKYXC, + Empty_Tuple, + NHWGK, + ConvFwdOddC>{}); } } // namespace instance diff --git a/library/src/tensor_operation_instance/gpu/grouped_conv2d_fwd/xdl/comp/device_grouped_conv2d_fwd_xdl_nhwgc_gkyxc_nhwgk_bf16_comp_part2_instance.cpp b/library/src/tensor_operation_instance/gpu/grouped_conv2d_fwd/xdl/comp/device_grouped_conv2d_fwd_xdl_nhwgc_gkyxc_nhwgk_bf16_comp_part2_instance.cpp index ee9507a80a..5c0391a25f 100644 --- a/library/src/tensor_operation_instance/gpu/grouped_conv2d_fwd/xdl/comp/device_grouped_conv2d_fwd_xdl_nhwgc_gkyxc_nhwgk_bf16_comp_part2_instance.cpp +++ b/library/src/tensor_operation_instance/gpu/grouped_conv2d_fwd/xdl/comp/device_grouped_conv2d_fwd_xdl_nhwgc_gkyxc_nhwgk_bf16_comp_part2_instance.cpp @@ -52,6 +52,15 @@ void add_device_grouped_conv2d_fwd_xdl_nhwgc_gkyxc_nhwgk_bf16_comp_part2_instanc Empty_Tuple, NHWGK, ConvFwd1x1S1P0>{}); + + add_device_operation_instances( + instances, + device_grouped_conv_fwd_xdl_bf16_comp_instances_part2<2, + NHWGC, + GKYXC, + Empty_Tuple, + NHWGK, + ConvFwdOddC>{}); } } diff --git a/library/src/tensor_operation_instance/gpu/grouped_conv2d_fwd/xdl/comp/device_grouped_conv2d_fwd_xdl_nhwgc_gkyxc_nhwgk_f16_comp_2x_instance.cpp 
b/library/src/tensor_operation_instance/gpu/grouped_conv2d_fwd/xdl/comp/device_grouped_conv2d_fwd_xdl_nhwgc_gkyxc_nhwgk_f16_comp_2x_instance.cpp index 132d3c8411..726276c461 100644 --- a/library/src/tensor_operation_instance/gpu/grouped_conv2d_fwd/xdl/comp/device_grouped_conv2d_fwd_xdl_nhwgc_gkyxc_nhwgk_f16_comp_2x_instance.cpp +++ b/library/src/tensor_operation_instance/gpu/grouped_conv2d_fwd/xdl/comp/device_grouped_conv2d_fwd_xdl_nhwgc_gkyxc_nhwgk_f16_comp_2x_instance.cpp @@ -52,6 +52,15 @@ void add_device_grouped_conv2d_fwd_xdl_nhwgc_gkyxc_nhwgk_f16_comp_2x_instances( Empty_Tuple, NHWGK, ConvFwd1x1S1P0>{}); + + add_device_operation_instances( + instances, + device_grouped_conv_fwd_xdl_f16_comp_instances_2x<2, + NHWGC, + GKYXC, + Empty_Tuple, + NHWGK, + ConvFwdOddC>{}); } } diff --git a/library/src/tensor_operation_instance/gpu/grouped_conv2d_fwd/xdl/comp/device_grouped_conv2d_fwd_xdl_nhwgc_gkyxc_nhwgk_f16_comp_instance.cpp b/library/src/tensor_operation_instance/gpu/grouped_conv2d_fwd/xdl/comp/device_grouped_conv2d_fwd_xdl_nhwgc_gkyxc_nhwgk_f16_comp_instance.cpp index a7deb969ba..8b7bdec2a8 100644 --- a/library/src/tensor_operation_instance/gpu/grouped_conv2d_fwd/xdl/comp/device_grouped_conv2d_fwd_xdl_nhwgc_gkyxc_nhwgk_f16_comp_instance.cpp +++ b/library/src/tensor_operation_instance/gpu/grouped_conv2d_fwd/xdl/comp/device_grouped_conv2d_fwd_xdl_nhwgc_gkyxc_nhwgk_f16_comp_instance.cpp @@ -49,6 +49,14 @@ void add_device_grouped_conv2d_fwd_xdl_nhwgc_gkyxc_nhwgk_f16_comp_instances( Empty_Tuple, NHWGK, ConvFwd1x1S1P0>{}); + + add_device_operation_instances(instances, + device_grouped_conv_fwd_xdl_f16_comp_instances<2, + NHWGC, + GKYXC, + Empty_Tuple, + NHWGK, + ConvFwdOddC>{}); } } // namespace instance diff --git a/library/src/tensor_operation_instance/gpu/grouped_conv2d_fwd/xdl/comp/device_grouped_conv2d_fwd_xdl_nhwgc_gkyxc_nhwgk_f16_comp_part2_instance.cpp 
b/library/src/tensor_operation_instance/gpu/grouped_conv2d_fwd/xdl/comp/device_grouped_conv2d_fwd_xdl_nhwgc_gkyxc_nhwgk_f16_comp_part2_instance.cpp index d2732547fa..c66114b9a3 100644 --- a/library/src/tensor_operation_instance/gpu/grouped_conv2d_fwd/xdl/comp/device_grouped_conv2d_fwd_xdl_nhwgc_gkyxc_nhwgk_f16_comp_part2_instance.cpp +++ b/library/src/tensor_operation_instance/gpu/grouped_conv2d_fwd/xdl/comp/device_grouped_conv2d_fwd_xdl_nhwgc_gkyxc_nhwgk_f16_comp_part2_instance.cpp @@ -52,6 +52,15 @@ void add_device_grouped_conv2d_fwd_xdl_nhwgc_gkyxc_nhwgk_f16_comp_part2_instance Empty_Tuple, NHWGK, ConvFwd1x1S1P0>{}); + + add_device_operation_instances( + instances, + device_grouped_conv_fwd_xdl_f16_comp_instances_part2<2, + NHWGC, + GKYXC, + Empty_Tuple, + NHWGK, + ConvFwdOddC>{}); } } diff --git a/library/src/tensor_operation_instance/gpu/grouped_conv2d_fwd/xdl/comp/device_grouped_conv2d_fwd_xdl_nhwgc_gkyxc_nhwgk_f32_comp_instance.cpp b/library/src/tensor_operation_instance/gpu/grouped_conv2d_fwd/xdl/comp/device_grouped_conv2d_fwd_xdl_nhwgc_gkyxc_nhwgk_f32_comp_instance.cpp index 8a0caebc9f..93e07e08fb 100644 --- a/library/src/tensor_operation_instance/gpu/grouped_conv2d_fwd/xdl/comp/device_grouped_conv2d_fwd_xdl_nhwgc_gkyxc_nhwgk_f32_comp_instance.cpp +++ b/library/src/tensor_operation_instance/gpu/grouped_conv2d_fwd/xdl/comp/device_grouped_conv2d_fwd_xdl_nhwgc_gkyxc_nhwgk_f32_comp_instance.cpp @@ -1,5 +1,5 @@ // SPDX-License-Identifier: MIT -// Copyright (c) 2018-2025, Advanced Micro Devices, Inc. All rights reserved. +// Copyright (c) 2018-2024, Advanced Micro Devices, Inc. All rights reserved. 
#include "ck/library/tensor_operation_instance/add_device_operation_instance.hpp" #include "ck/library/tensor_operation_instance/gpu/grouped_conv_fwd/device_grouped_conv_fwd_xdl_comp_instance.hpp" @@ -48,6 +48,14 @@ void add_device_grouped_conv2d_fwd_xdl_nhwgc_gkyxc_nhwgk_f32_comp_instances( Empty_Tuple, NHWGK, ConvFwd1x1S1P0>{}); + + add_device_operation_instances(instances, + device_grouped_conv_fwd_xdl_f32_comp_instances<2, + NHWGC, + GKYXC, + Empty_Tuple, + NHWGK, + ConvFwdOddC>{}); } } // namespace instance diff --git a/library/src/tensor_operation_instance/gpu/grouped_conv2d_fwd/xdl/comp/device_grouped_conv2d_fwd_xdl_nhwgc_gkyxc_nhwgk_int8_comp_instance.cpp b/library/src/tensor_operation_instance/gpu/grouped_conv2d_fwd/xdl/comp/device_grouped_conv2d_fwd_xdl_nhwgc_gkyxc_nhwgk_int8_comp_instance.cpp index e45df1e107..6acbb7475c 100644 --- a/library/src/tensor_operation_instance/gpu/grouped_conv2d_fwd/xdl/comp/device_grouped_conv2d_fwd_xdl_nhwgc_gkyxc_nhwgk_int8_comp_instance.cpp +++ b/library/src/tensor_operation_instance/gpu/grouped_conv2d_fwd/xdl/comp/device_grouped_conv2d_fwd_xdl_nhwgc_gkyxc_nhwgk_int8_comp_instance.cpp @@ -1,5 +1,5 @@ // SPDX-License-Identifier: MIT -// Copyright (c) 2024-2025, Advanced Micro Devices, Inc. All rights reserved. +// Copyright (c) 2024, Advanced Micro Devices, Inc. All rights reserved. 
#include "ck/library/tensor_operation_instance/add_device_operation_instance.hpp" #include "ck/library/tensor_operation_instance/gpu/grouped_conv_fwd/device_grouped_conv_fwd_xdl_comp_instance.hpp" @@ -50,6 +50,14 @@ void add_device_grouped_conv2d_fwd_xdl_nhwgc_gkyxc_nhwgk_int8_comp_instances( NHWGK, ConvFwd1x1S1P0>{}); + add_device_operation_instances(instances, + device_grouped_conv_fwd_xdl_int8_comp_instances<2, + NHWGC, + GKYXC, + Empty_Tuple, + NHWGK, + ConvFwdOddC>{}); + if(ck::get_device_name() != "gfx950") { add_device_operation_instances( @@ -78,6 +86,15 @@ void add_device_grouped_conv2d_fwd_xdl_nhwgc_gkyxc_nhwgk_int8_comp_instances( Empty_Tuple, NHWGK, ConvFwd1x1S1P0>{}); + + add_device_operation_instances( + instances, + device_grouped_conv_fwd_xdl_int8_comp_instances_part2<2, + NHWGC, + GKYXC, + Empty_Tuple, + NHWGK, + ConvFwdOddC>{}); } if(ck::get_device_name() == "gfx950") @@ -108,6 +125,15 @@ void add_device_grouped_conv2d_fwd_xdl_nhwgc_gkyxc_nhwgk_int8_comp_instances( Empty_Tuple, NHWGK, ConvFwd1x1S1P0>{}); + + add_device_operation_instances( + instances, + device_grouped_conv_fwd_xdl_int8_comp_instances_2x<2, + NHWGC, + GKYXC, + Empty_Tuple, + NHWGK, + ConvFwdOddC>{}); } } diff --git a/library/src/tensor_operation_instance/gpu/grouped_conv2d_fwd/xdl/device_grouped_conv2d_fwd_xdl_gnhwc_gkyxc_gnhwk_bf16_instance.cpp b/library/src/tensor_operation_instance/gpu/grouped_conv2d_fwd/xdl/device_grouped_conv2d_fwd_xdl_gnhwc_gkyxc_gnhwk_bf16_instance.cpp index 078221f89f..2afbfdc386 100644 --- a/library/src/tensor_operation_instance/gpu/grouped_conv2d_fwd/xdl/device_grouped_conv2d_fwd_xdl_gnhwc_gkyxc_gnhwk_bf16_instance.cpp +++ b/library/src/tensor_operation_instance/gpu/grouped_conv2d_fwd/xdl/device_grouped_conv2d_fwd_xdl_gnhwc_gkyxc_gnhwk_bf16_instance.cpp @@ -1,5 +1,5 @@ // SPDX-License-Identifier: MIT -// Copyright (c) 2018-2025, Advanced Micro Devices, Inc. All rights reserved. +// Copyright (c) 2018-2024, Advanced Micro Devices, Inc. 
All rights reserved. #include "ck/library/tensor_operation_instance/add_device_operation_instance.hpp" #include "ck/library/tensor_operation_instance/gpu/grouped_conv_fwd/device_grouped_conv_fwd_xdl_instance.hpp" @@ -46,6 +46,14 @@ void add_device_grouped_conv2d_fwd_xdl_gnhwc_gkyxc_gnhwk_bf16_instances( Empty_Tuple, GNHWK, ConvFwd1x1S1P0>{}); + + add_device_operation_instances(instances, + device_grouped_conv_fwd_xdl_bf16_instances<2, + GNHWC, + GKYXC, + Empty_Tuple, + GNHWK, + ConvFwdOddC>{}); } } // namespace instance diff --git a/library/src/tensor_operation_instance/gpu/grouped_conv2d_fwd/xdl/device_grouped_conv2d_fwd_xdl_gnhwc_gkyxc_gnhwk_f16_instance.cpp b/library/src/tensor_operation_instance/gpu/grouped_conv2d_fwd/xdl/device_grouped_conv2d_fwd_xdl_gnhwc_gkyxc_gnhwk_f16_instance.cpp index 3a481dd204..822ef51e00 100644 --- a/library/src/tensor_operation_instance/gpu/grouped_conv2d_fwd/xdl/device_grouped_conv2d_fwd_xdl_gnhwc_gkyxc_gnhwk_f16_instance.cpp +++ b/library/src/tensor_operation_instance/gpu/grouped_conv2d_fwd/xdl/device_grouped_conv2d_fwd_xdl_gnhwc_gkyxc_gnhwk_f16_instance.cpp @@ -1,5 +1,5 @@ // SPDX-License-Identifier: MIT -// Copyright (c) 2018-2025, Advanced Micro Devices, Inc. All rights reserved. +// Copyright (c) 2018-2023, Advanced Micro Devices, Inc. All rights reserved. 
#include "ck/library/tensor_operation_instance/add_device_operation_instance.hpp" #include "ck/library/tensor_operation_instance/gpu/grouped_conv_fwd/device_grouped_conv_fwd_xdl_instance.hpp" @@ -46,6 +46,14 @@ void add_device_grouped_conv2d_fwd_xdl_gnhwc_gkyxc_gnhwk_f16_instances( Empty_Tuple, GNHWK, ConvFwd1x1S1P0>{}); + + add_device_operation_instances(instances, + device_grouped_conv_fwd_xdl_f16_instances<2, + GNHWC, + GKYXC, + Empty_Tuple, + GNHWK, + ConvFwdOddC>{}); } } // namespace instance diff --git a/library/src/tensor_operation_instance/gpu/grouped_conv2d_fwd/xdl/device_grouped_conv2d_fwd_xdl_gnhwc_gkyxc_gnhwk_f32_instance.cpp b/library/src/tensor_operation_instance/gpu/grouped_conv2d_fwd/xdl/device_grouped_conv2d_fwd_xdl_gnhwc_gkyxc_gnhwk_f32_instance.cpp index 5add0f8add..79a1fb99a8 100644 --- a/library/src/tensor_operation_instance/gpu/grouped_conv2d_fwd/xdl/device_grouped_conv2d_fwd_xdl_gnhwc_gkyxc_gnhwk_f32_instance.cpp +++ b/library/src/tensor_operation_instance/gpu/grouped_conv2d_fwd/xdl/device_grouped_conv2d_fwd_xdl_gnhwc_gkyxc_gnhwk_f32_instance.cpp @@ -1,5 +1,5 @@ // SPDX-License-Identifier: MIT -// Copyright (c) 2018-2025, Advanced Micro Devices, Inc. All rights reserved. +// Copyright (c) 2018-2023, Advanced Micro Devices, Inc. All rights reserved. 
#include "ck/library/tensor_operation_instance/add_device_operation_instance.hpp" #include "ck/library/tensor_operation_instance/gpu/grouped_conv_fwd/device_grouped_conv_fwd_xdl_instance.hpp" @@ -46,6 +46,14 @@ void add_device_grouped_conv2d_fwd_xdl_gnhwc_gkyxc_gnhwk_f32_instances( Empty_Tuple, GNHWK, ConvFwd1x1S1P0>{}); + + add_device_operation_instances(instances, + device_grouped_conv_fwd_xdl_f32_instances<2, + GNHWC, + GKYXC, + Empty_Tuple, + GNHWK, + ConvFwdOddC>{}); } } // namespace instance diff --git a/library/src/tensor_operation_instance/gpu/grouped_conv2d_fwd/xdl/device_grouped_conv2d_fwd_xdl_nhwgc_gkyxc_nhwgk_bf16_instance.cpp b/library/src/tensor_operation_instance/gpu/grouped_conv2d_fwd/xdl/device_grouped_conv2d_fwd_xdl_nhwgc_gkyxc_nhwgk_bf16_instance.cpp index 0257c7d315..e567c0df75 100644 --- a/library/src/tensor_operation_instance/gpu/grouped_conv2d_fwd/xdl/device_grouped_conv2d_fwd_xdl_nhwgc_gkyxc_nhwgk_bf16_instance.cpp +++ b/library/src/tensor_operation_instance/gpu/grouped_conv2d_fwd/xdl/device_grouped_conv2d_fwd_xdl_nhwgc_gkyxc_nhwgk_bf16_instance.cpp @@ -1,5 +1,5 @@ // SPDX-License-Identifier: MIT -// Copyright (c) 2018-2025, Advanced Micro Devices, Inc. All rights reserved. +// Copyright (c) 2018-2023, Advanced Micro Devices, Inc. All rights reserved. 
#include "ck/library/tensor_operation_instance/add_device_operation_instance.hpp" #include "ck/library/tensor_operation_instance/gpu/grouped_conv_fwd/device_grouped_conv_fwd_xdl_instance.hpp" @@ -46,6 +46,14 @@ void add_device_grouped_conv2d_fwd_xdl_nhwgc_gkyxc_nhwgk_bf16_instances( Empty_Tuple, NHWGK, ConvFwd1x1S1P0>{}); + + add_device_operation_instances(instances, + device_grouped_conv_fwd_xdl_bf16_instances<2, + NHWGC, + GKYXC, + Empty_Tuple, + NHWGK, + ConvFwdOddC>{}); } } // namespace instance diff --git a/library/src/tensor_operation_instance/gpu/grouped_conv2d_fwd/xdl/device_grouped_conv2d_fwd_xdl_nhwgc_gkyxc_nhwgk_f16_instance.cpp b/library/src/tensor_operation_instance/gpu/grouped_conv2d_fwd/xdl/device_grouped_conv2d_fwd_xdl_nhwgc_gkyxc_nhwgk_f16_instance.cpp index 2715506fe2..3e42184996 100644 --- a/library/src/tensor_operation_instance/gpu/grouped_conv2d_fwd/xdl/device_grouped_conv2d_fwd_xdl_nhwgc_gkyxc_nhwgk_f16_instance.cpp +++ b/library/src/tensor_operation_instance/gpu/grouped_conv2d_fwd/xdl/device_grouped_conv2d_fwd_xdl_nhwgc_gkyxc_nhwgk_f16_instance.cpp @@ -1,5 +1,5 @@ // SPDX-License-Identifier: MIT -// Copyright (c) 2018-2025, Advanced Micro Devices, Inc. All rights reserved. +// Copyright (c) 2018-2023, Advanced Micro Devices, Inc. All rights reserved. 
#include "ck/library/tensor_operation_instance/add_device_operation_instance.hpp" #include "ck/library/tensor_operation_instance/gpu/grouped_conv_fwd/device_grouped_conv_fwd_xdl_instance.hpp" @@ -46,6 +46,14 @@ void add_device_grouped_conv2d_fwd_xdl_nhwgc_gkyxc_nhwgk_f16_instances( Empty_Tuple, NHWGK, ConvFwd1x1S1P0>{}); + + add_device_operation_instances(instances, + device_grouped_conv_fwd_xdl_f16_instances<2, + NHWGC, + GKYXC, + Empty_Tuple, + NHWGK, + ConvFwdOddC>{}); } } // namespace instance diff --git a/library/src/tensor_operation_instance/gpu/grouped_conv2d_fwd/xdl/device_grouped_conv2d_fwd_xdl_nhwgc_gkyxc_nhwgk_f32_instance.cpp b/library/src/tensor_operation_instance/gpu/grouped_conv2d_fwd/xdl/device_grouped_conv2d_fwd_xdl_nhwgc_gkyxc_nhwgk_f32_instance.cpp index 8d3e4d91b1..c035d4c3da 100644 --- a/library/src/tensor_operation_instance/gpu/grouped_conv2d_fwd/xdl/device_grouped_conv2d_fwd_xdl_nhwgc_gkyxc_nhwgk_f32_instance.cpp +++ b/library/src/tensor_operation_instance/gpu/grouped_conv2d_fwd/xdl/device_grouped_conv2d_fwd_xdl_nhwgc_gkyxc_nhwgk_f32_instance.cpp @@ -1,5 +1,5 @@ // SPDX-License-Identifier: MIT -// Copyright (c) 2018-2025, Advanced Micro Devices, Inc. All rights reserved. +// Copyright (c) 2018-2023, Advanced Micro Devices, Inc. All rights reserved. 
#include "ck/library/tensor_operation_instance/add_device_operation_instance.hpp" #include "ck/library/tensor_operation_instance/gpu/grouped_conv_fwd/device_grouped_conv_fwd_xdl_instance.hpp" @@ -46,6 +46,14 @@ void add_device_grouped_conv2d_fwd_xdl_nhwgc_gkyxc_nhwgk_f32_instances( Empty_Tuple, NHWGK, ConvFwd1x1S1P0>{}); + + add_device_operation_instances(instances, + device_grouped_conv_fwd_xdl_f32_instances<2, + NHWGC, + GKYXC, + Empty_Tuple, + NHWGK, + ConvFwdOddC>{}); } } // namespace instance diff --git a/library/src/tensor_operation_instance/gpu/grouped_conv2d_fwd/xdl/device_grouped_conv2d_fwd_xdl_nhwgc_gkyxc_nhwgk_int8_instance.cpp b/library/src/tensor_operation_instance/gpu/grouped_conv2d_fwd/xdl/device_grouped_conv2d_fwd_xdl_nhwgc_gkyxc_nhwgk_int8_instance.cpp index 465fa927a5..5c425effd8 100644 --- a/library/src/tensor_operation_instance/gpu/grouped_conv2d_fwd/xdl/device_grouped_conv2d_fwd_xdl_nhwgc_gkyxc_nhwgk_int8_instance.cpp +++ b/library/src/tensor_operation_instance/gpu/grouped_conv2d_fwd/xdl/device_grouped_conv2d_fwd_xdl_nhwgc_gkyxc_nhwgk_int8_instance.cpp @@ -1,5 +1,5 @@ // SPDX-License-Identifier: MIT -// Copyright (c) 2024-2025, Advanced Micro Devices, Inc. All rights reserved. +// Copyright (c) 2024, Advanced Micro Devices, Inc. All rights reserved. 
#include "ck/library/tensor_operation_instance/add_device_operation_instance.hpp" #include "ck/library/tensor_operation_instance/gpu/grouped_conv_fwd/device_grouped_conv_fwd_xdl_instance.hpp" @@ -46,6 +46,14 @@ void add_device_grouped_conv2d_fwd_xdl_nhwgc_gkyxc_nhwgk_int8_instances( Empty_Tuple, NHWGK, ConvFwd1x1S1P0>{}); + + add_device_operation_instances(instances, + device_grouped_conv_fwd_xdl_int8_instances<2, + NHWGC, + GKYXC, + Empty_Tuple, + NHWGK, + ConvFwdOddC>{}); } } // namespace instance diff --git a/library/src/tensor_operation_instance/gpu/grouped_conv2d_fwd/xdl/mem/device_grouped_conv2d_fwd_xdl_nhwgc_gkyxc_nhwgk_bf16_mem_inter_instance.cpp b/library/src/tensor_operation_instance/gpu/grouped_conv2d_fwd/xdl/mem/device_grouped_conv2d_fwd_xdl_nhwgc_gkyxc_nhwgk_bf16_mem_inter_instance.cpp index 87423801cb..e8a763c527 100644 --- a/library/src/tensor_operation_instance/gpu/grouped_conv2d_fwd/xdl/mem/device_grouped_conv2d_fwd_xdl_nhwgc_gkyxc_nhwgk_bf16_mem_inter_instance.cpp +++ b/library/src/tensor_operation_instance/gpu/grouped_conv2d_fwd/xdl/mem/device_grouped_conv2d_fwd_xdl_nhwgc_gkyxc_nhwgk_bf16_mem_inter_instance.cpp @@ -1,5 +1,5 @@ // SPDX-License-Identifier: MIT -// Copyright (c) 2018-2025, Advanced Micro Devices, Inc. All rights reserved. +// Copyright (c) 2018-2024, Advanced Micro Devices, Inc. All rights reserved. 
#include "ck/library/tensor_operation_instance/add_device_operation_instance.hpp" #include "ck/library/tensor_operation_instance/gpu/grouped_conv_fwd/device_grouped_conv_fwd_xdl_mem_instance.hpp" @@ -49,6 +49,15 @@ void add_device_grouped_conv2d_fwd_xdl_nhwgc_gkyxc_nhwgk_bf16_mem_inter_instance NHWGK, ConvFwd1x1S1P0, Interwave>{}); + + add_device_operation_instances(instances, + device_grouped_conv_fwd_xdl_bf16_mem_instances<2, + NHWGC, + GKYXC, + Empty_Tuple, + NHWGK, + ConvFwdOddC, + Interwave>{}); } } // namespace instance diff --git a/library/src/tensor_operation_instance/gpu/grouped_conv2d_fwd/xdl/mem/device_grouped_conv2d_fwd_xdl_nhwgc_gkyxc_nhwgk_bf16_mem_intra_instance.cpp b/library/src/tensor_operation_instance/gpu/grouped_conv2d_fwd/xdl/mem/device_grouped_conv2d_fwd_xdl_nhwgc_gkyxc_nhwgk_bf16_mem_intra_instance.cpp index ebb213461a..3ae3fb5186 100644 --- a/library/src/tensor_operation_instance/gpu/grouped_conv2d_fwd/xdl/mem/device_grouped_conv2d_fwd_xdl_nhwgc_gkyxc_nhwgk_bf16_mem_intra_instance.cpp +++ b/library/src/tensor_operation_instance/gpu/grouped_conv2d_fwd/xdl/mem/device_grouped_conv2d_fwd_xdl_nhwgc_gkyxc_nhwgk_bf16_mem_intra_instance.cpp @@ -1,5 +1,5 @@ // SPDX-License-Identifier: MIT -// Copyright (c) 2018-2025, Advanced Micro Devices, Inc. All rights reserved. +// Copyright (c) 2018-2024, Advanced Micro Devices, Inc. All rights reserved. 
#include "ck/library/tensor_operation_instance/add_device_operation_instance.hpp" #include "ck/library/tensor_operation_instance/gpu/grouped_conv_fwd/device_grouped_conv_fwd_xdl_mem_instance.hpp" @@ -49,6 +49,15 @@ void add_device_grouped_conv2d_fwd_xdl_nhwgc_gkyxc_nhwgk_bf16_mem_intra_instance NHWGK, ConvFwd1x1S1P0, Intrawave>{}); + + add_device_operation_instances(instances, + device_grouped_conv_fwd_xdl_bf16_mem_instances<2, + NHWGC, + GKYXC, + Empty_Tuple, + NHWGK, + ConvFwdOddC, + Intrawave>{}); } } // namespace instance diff --git a/library/src/tensor_operation_instance/gpu/grouped_conv2d_fwd/xdl/mem/device_grouped_conv2d_fwd_xdl_nhwgc_gkyxc_nhwgk_f16_mem_inter_instance.cpp b/library/src/tensor_operation_instance/gpu/grouped_conv2d_fwd/xdl/mem/device_grouped_conv2d_fwd_xdl_nhwgc_gkyxc_nhwgk_f16_mem_inter_instance.cpp index c2c8a099b2..cb7e912936 100644 --- a/library/src/tensor_operation_instance/gpu/grouped_conv2d_fwd/xdl/mem/device_grouped_conv2d_fwd_xdl_nhwgc_gkyxc_nhwgk_f16_mem_inter_instance.cpp +++ b/library/src/tensor_operation_instance/gpu/grouped_conv2d_fwd/xdl/mem/device_grouped_conv2d_fwd_xdl_nhwgc_gkyxc_nhwgk_f16_mem_inter_instance.cpp @@ -1,5 +1,5 @@ // SPDX-License-Identifier: MIT -// Copyright (c) 2018-2025, Advanced Micro Devices, Inc. All rights reserved. +// Copyright (c) 2018-2024, Advanced Micro Devices, Inc. All rights reserved. 
#include "ck/library/tensor_operation_instance/add_device_operation_instance.hpp" #include "ck/library/tensor_operation_instance/gpu/grouped_conv_fwd/device_grouped_conv_fwd_xdl_mem_instance.hpp" @@ -49,6 +49,15 @@ void add_device_grouped_conv2d_fwd_xdl_nhwgc_gkyxc_nhwgk_f16_mem_inter_instances NHWGK, ConvFwd1x1S1P0, Interwave>{}); + + add_device_operation_instances(instances, + device_grouped_conv_fwd_xdl_f16_mem_instances<2, + NHWGC, + GKYXC, + Empty_Tuple, + NHWGK, + ConvFwdOddC, + Interwave>{}); } } // namespace instance diff --git a/library/src/tensor_operation_instance/gpu/grouped_conv2d_fwd/xdl/mem/device_grouped_conv2d_fwd_xdl_nhwgc_gkyxc_nhwgk_f16_mem_intra_instance.cpp b/library/src/tensor_operation_instance/gpu/grouped_conv2d_fwd/xdl/mem/device_grouped_conv2d_fwd_xdl_nhwgc_gkyxc_nhwgk_f16_mem_intra_instance.cpp index 11cb853f0d..d787f4b048 100644 --- a/library/src/tensor_operation_instance/gpu/grouped_conv2d_fwd/xdl/mem/device_grouped_conv2d_fwd_xdl_nhwgc_gkyxc_nhwgk_f16_mem_intra_instance.cpp +++ b/library/src/tensor_operation_instance/gpu/grouped_conv2d_fwd/xdl/mem/device_grouped_conv2d_fwd_xdl_nhwgc_gkyxc_nhwgk_f16_mem_intra_instance.cpp @@ -1,5 +1,5 @@ // SPDX-License-Identifier: MIT -// Copyright (c) 2018-2025, Advanced Micro Devices, Inc. All rights reserved. +// Copyright (c) 2018-2024, Advanced Micro Devices, Inc. All rights reserved. 
#include "ck/library/tensor_operation_instance/add_device_operation_instance.hpp" #include "ck/library/tensor_operation_instance/gpu/grouped_conv_fwd/device_grouped_conv_fwd_xdl_mem_instance.hpp" @@ -49,6 +49,15 @@ void add_device_grouped_conv2d_fwd_xdl_nhwgc_gkyxc_nhwgk_f16_mem_intra_instances NHWGK, ConvFwd1x1S1P0, Intrawave>{}); + + add_device_operation_instances(instances, + device_grouped_conv_fwd_xdl_f16_mem_instances<2, + NHWGC, + GKYXC, + Empty_Tuple, + NHWGK, + ConvFwdOddC, + Intrawave>{}); } } // namespace instance diff --git a/library/src/tensor_operation_instance/gpu/grouped_conv2d_fwd/xdl/mem/device_grouped_conv2d_fwd_xdl_nhwgc_gkyxc_nhwgk_f32_mem_inter_instance.cpp b/library/src/tensor_operation_instance/gpu/grouped_conv2d_fwd/xdl/mem/device_grouped_conv2d_fwd_xdl_nhwgc_gkyxc_nhwgk_f32_mem_inter_instance.cpp index 1992d7f7c1..5644289790 100644 --- a/library/src/tensor_operation_instance/gpu/grouped_conv2d_fwd/xdl/mem/device_grouped_conv2d_fwd_xdl_nhwgc_gkyxc_nhwgk_f32_mem_inter_instance.cpp +++ b/library/src/tensor_operation_instance/gpu/grouped_conv2d_fwd/xdl/mem/device_grouped_conv2d_fwd_xdl_nhwgc_gkyxc_nhwgk_f32_mem_inter_instance.cpp @@ -1,5 +1,5 @@ // SPDX-License-Identifier: MIT -// Copyright (c) 2018-2025, Advanced Micro Devices, Inc. All rights reserved. +// Copyright (c) 2018-2024, Advanced Micro Devices, Inc. All rights reserved. 
#include "ck/library/tensor_operation_instance/add_device_operation_instance.hpp" #include "ck/library/tensor_operation_instance/gpu/grouped_conv_fwd/device_grouped_conv_fwd_xdl_mem_instance.hpp" @@ -49,6 +49,15 @@ void add_device_grouped_conv2d_fwd_xdl_nhwgc_gkyxc_nhwgk_f32_mem_inter_instances NHWGK, ConvFwd1x1S1P0, Interwave>{}); + + add_device_operation_instances(instances, + device_grouped_conv_fwd_xdl_f32_mem_instances<2, + NHWGC, + GKYXC, + Empty_Tuple, + NHWGK, + ConvFwdOddC, + Interwave>{}); } } // namespace instance diff --git a/library/src/tensor_operation_instance/gpu/grouped_conv2d_fwd/xdl/mem/device_grouped_conv2d_fwd_xdl_nhwgc_gkyxc_nhwgk_f32_mem_intra_instance.cpp b/library/src/tensor_operation_instance/gpu/grouped_conv2d_fwd/xdl/mem/device_grouped_conv2d_fwd_xdl_nhwgc_gkyxc_nhwgk_f32_mem_intra_instance.cpp index 2b8fd3d9db..5b12dad5a3 100644 --- a/library/src/tensor_operation_instance/gpu/grouped_conv2d_fwd/xdl/mem/device_grouped_conv2d_fwd_xdl_nhwgc_gkyxc_nhwgk_f32_mem_intra_instance.cpp +++ b/library/src/tensor_operation_instance/gpu/grouped_conv2d_fwd/xdl/mem/device_grouped_conv2d_fwd_xdl_nhwgc_gkyxc_nhwgk_f32_mem_intra_instance.cpp @@ -1,5 +1,5 @@ // SPDX-License-Identifier: MIT -// Copyright (c) 2018-2025, Advanced Micro Devices, Inc. All rights reserved. +// Copyright (c) 2018-2024, Advanced Micro Devices, Inc. All rights reserved. 
#include "ck/library/tensor_operation_instance/add_device_operation_instance.hpp" #include "ck/library/tensor_operation_instance/gpu/grouped_conv_fwd/device_grouped_conv_fwd_xdl_mem_instance.hpp" @@ -49,6 +49,15 @@ void add_device_grouped_conv2d_fwd_xdl_nhwgc_gkyxc_nhwgk_f32_mem_intra_instances NHWGK, ConvFwd1x1S1P0, Intrawave>{}); + + add_device_operation_instances(instances, + device_grouped_conv_fwd_xdl_f32_mem_instances<2, + NHWGC, + GKYXC, + Empty_Tuple, + NHWGK, + ConvFwdOddC, + Intrawave>{}); } } // namespace instance diff --git a/library/src/tensor_operation_instance/gpu/grouped_conv2d_fwd/xdl/mem/device_grouped_conv2d_fwd_xdl_nhwgc_gkyxc_nhwgk_int8_mem_inter_instance.cpp b/library/src/tensor_operation_instance/gpu/grouped_conv2d_fwd/xdl/mem/device_grouped_conv2d_fwd_xdl_nhwgc_gkyxc_nhwgk_int8_mem_inter_instance.cpp index 5579ec62cc..f667481fa4 100644 --- a/library/src/tensor_operation_instance/gpu/grouped_conv2d_fwd/xdl/mem/device_grouped_conv2d_fwd_xdl_nhwgc_gkyxc_nhwgk_int8_mem_inter_instance.cpp +++ b/library/src/tensor_operation_instance/gpu/grouped_conv2d_fwd/xdl/mem/device_grouped_conv2d_fwd_xdl_nhwgc_gkyxc_nhwgk_int8_mem_inter_instance.cpp @@ -1,5 +1,5 @@ // SPDX-License-Identifier: MIT -// Copyright (c) 2018-2025, Advanced Micro Devices, Inc. All rights reserved. +// Copyright (c) 2018-2024, Advanced Micro Devices, Inc. All rights reserved. 
#include "ck/library/tensor_operation_instance/add_device_operation_instance.hpp" #include "ck/library/tensor_operation_instance/gpu/grouped_conv_fwd/device_grouped_conv_fwd_xdl_mem_instance.hpp" @@ -49,6 +49,15 @@ void add_device_grouped_conv2d_fwd_xdl_nhwgc_gkyxc_nhwgk_int8_mem_inter_instance NHWGK, ConvFwd1x1S1P0, Interwave>{}); + + add_device_operation_instances(instances, + device_grouped_conv_fwd_xdl_int8_mem_instances<2, + NHWGC, + GKYXC, + Empty_Tuple, + NHWGK, + ConvFwdOddC, + Interwave>{}); } } // namespace instance diff --git a/library/src/tensor_operation_instance/gpu/grouped_conv2d_fwd/xdl/mem/device_grouped_conv2d_fwd_xdl_nhwgc_gkyxc_nhwgk_int8_mem_intra_instance.cpp b/library/src/tensor_operation_instance/gpu/grouped_conv2d_fwd/xdl/mem/device_grouped_conv2d_fwd_xdl_nhwgc_gkyxc_nhwgk_int8_mem_intra_instance.cpp index 77f3df2c11..2ff2c7f51f 100644 --- a/library/src/tensor_operation_instance/gpu/grouped_conv2d_fwd/xdl/mem/device_grouped_conv2d_fwd_xdl_nhwgc_gkyxc_nhwgk_int8_mem_intra_instance.cpp +++ b/library/src/tensor_operation_instance/gpu/grouped_conv2d_fwd/xdl/mem/device_grouped_conv2d_fwd_xdl_nhwgc_gkyxc_nhwgk_int8_mem_intra_instance.cpp @@ -1,5 +1,5 @@ // SPDX-License-Identifier: MIT -// Copyright (c) 2018-2025, Advanced Micro Devices, Inc. All rights reserved. +// Copyright (c) 2018-2024, Advanced Micro Devices, Inc. All rights reserved. 
#include "ck/library/tensor_operation_instance/add_device_operation_instance.hpp" #include "ck/library/tensor_operation_instance/gpu/grouped_conv_fwd/device_grouped_conv_fwd_xdl_mem_instance.hpp" @@ -49,6 +49,15 @@ void add_device_grouped_conv2d_fwd_xdl_nhwgc_gkyxc_nhwgk_int8_mem_intra_instance NHWGK, ConvFwd1x1S1P0, Intrawave>{}); + + add_device_operation_instances(instances, + device_grouped_conv_fwd_xdl_int8_mem_instances<2, + NHWGC, + GKYXC, + Empty_Tuple, + NHWGK, + ConvFwdOddC, + Intrawave>{}); } } // namespace instance diff --git a/library/src/tensor_operation_instance/gpu/grouped_conv3d_fwd/CMakeLists.txt b/library/src/tensor_operation_instance/gpu/grouped_conv3d_fwd/CMakeLists.txt index f55bdd45c9..f8efa5a7c1 100644 --- a/library/src/tensor_operation_instance/gpu/grouped_conv3d_fwd/CMakeLists.txt +++ b/library/src/tensor_operation_instance/gpu/grouped_conv3d_fwd/CMakeLists.txt @@ -66,6 +66,10 @@ set(GROUPED_CONV3D_FWD wmma/device_grouped_conv3d_fwd_wmma_gndhwc_gkzyxc_gndhwk_i8_1x1s1p0_instance.cpp wmma/device_grouped_conv3d_fwd_wmma_ndhwgc_gkzyxc_ndhwgk_f16_1x1s1p0_instance.cpp wmma/device_grouped_conv3d_fwd_wmma_ndhwgc_gkzyxc_ndhwgk_i8_1x1s1p0_instance.cpp + wmma/device_grouped_conv3d_fwd_wmma_gndhwc_gkzyxc_gndhwk_f16_oddc_instance.cpp + wmma/device_grouped_conv3d_fwd_wmma_gndhwc_gkzyxc_gndhwk_i8_oddc_instance.cpp + wmma/device_grouped_conv3d_fwd_wmma_ndhwgc_gkzyxc_ndhwgk_f16_oddc_instance.cpp + wmma/device_grouped_conv3d_fwd_wmma_ndhwgc_gkzyxc_ndhwgk_i8_oddc_instance.cpp ) if((DTYPES MATCHES "fp8" AND DTYPES MATCHES "fp16") OR NOT DEFINED DTYPES) diff --git a/library/src/tensor_operation_instance/gpu/grouped_conv3d_fwd/wmma/device_grouped_conv3d_fwd_wmma_gndhwc_gkzyxc_gndhwk_f16_oddc_instance.cpp b/library/src/tensor_operation_instance/gpu/grouped_conv3d_fwd/wmma/device_grouped_conv3d_fwd_wmma_gndhwc_gkzyxc_gndhwk_f16_oddc_instance.cpp new file mode 100644 index 0000000000..fa378af1ee --- /dev/null +++ 
b/library/src/tensor_operation_instance/gpu/grouped_conv3d_fwd/wmma/device_grouped_conv3d_fwd_wmma_gndhwc_gkzyxc_gndhwk_f16_oddc_instance.cpp @@ -0,0 +1,41 @@ +// SPDX-License-Identifier: MIT +// Copyright (c) 2018-2023, Advanced Micro Devices, Inc. All rights reserved. + +#include "ck/library/tensor_operation_instance/add_device_operation_instance.hpp" +#include "ck/library/tensor_operation_instance/gpu/grouped_conv_fwd/device_grouped_conv_fwd_wmma_instance.hpp" + +namespace ck { +namespace tensor_operation { +namespace device { +namespace instance { +// Compilation parameters for in[g, n, di, hi, wi, c] * wei[g, k, z, y, x, c] = out[g, n, do, ho, +// wo, k] +void add_device_grouped_conv3d_fwd_wmma_gndhwc_gkzyxc_gndhwk_f16_oddc_instances( + std::vector>>& instances) +{ + add_device_operation_instances(instances, + device_grouped_conv_fwd_wmma_f16_instances<3, + GNDHWC, + GKZYXC, + Empty_Tuple, + GNDHWK, + Empty_Tuple, + PassThrough, + ConvFwdOddC>{}); +} + +} // namespace instance +} // namespace device +} // namespace tensor_operation +} // namespace ck diff --git a/library/src/tensor_operation_instance/gpu/grouped_conv3d_fwd/wmma/device_grouped_conv3d_fwd_wmma_gndhwc_gkzyxc_gndhwk_i8_oddc_instance.cpp b/library/src/tensor_operation_instance/gpu/grouped_conv3d_fwd/wmma/device_grouped_conv3d_fwd_wmma_gndhwc_gkzyxc_gndhwk_i8_oddc_instance.cpp new file mode 100644 index 0000000000..d41416fd4a --- /dev/null +++ b/library/src/tensor_operation_instance/gpu/grouped_conv3d_fwd/wmma/device_grouped_conv3d_fwd_wmma_gndhwc_gkzyxc_gndhwk_i8_oddc_instance.cpp @@ -0,0 +1,41 @@ +// SPDX-License-Identifier: MIT +// Copyright (c) 2018-2023, Advanced Micro Devices, Inc. All rights reserved. 
+ +#include "ck/library/tensor_operation_instance/add_device_operation_instance.hpp" +#include "ck/library/tensor_operation_instance/gpu/grouped_conv_fwd/device_grouped_conv_fwd_wmma_instance.hpp" + +namespace ck { +namespace tensor_operation { +namespace device { +namespace instance { +// Compilation parameters for in[g, n, di, hi, wi, c] * wei[g, k, z, y, x, c] = out[g, n, do, ho, +// wo, k] +void add_device_grouped_conv3d_fwd_wmma_gndhwc_gkzyxc_gndhwk_i8_oddc_instances( + std::vector>>& instances) +{ + add_device_operation_instances(instances, + device_grouped_conv_fwd_wmma_i8_instances<3, + GNDHWC, + GKZYXC, + Empty_Tuple, + GNDHWK, + Empty_Tuple, + PassThrough, + ConvFwdOddC>{}); +} + +} // namespace instance +} // namespace device +} // namespace tensor_operation +} // namespace ck diff --git a/library/src/tensor_operation_instance/gpu/grouped_conv3d_fwd/wmma/device_grouped_conv3d_fwd_wmma_ndhwgc_gkzyxc_ndhwgk_f16_oddc_instance.cpp b/library/src/tensor_operation_instance/gpu/grouped_conv3d_fwd/wmma/device_grouped_conv3d_fwd_wmma_ndhwgc_gkzyxc_ndhwgk_f16_oddc_instance.cpp new file mode 100644 index 0000000000..8a7bc26178 --- /dev/null +++ b/library/src/tensor_operation_instance/gpu/grouped_conv3d_fwd/wmma/device_grouped_conv3d_fwd_wmma_ndhwgc_gkzyxc_ndhwgk_f16_oddc_instance.cpp @@ -0,0 +1,41 @@ +// SPDX-License-Identifier: MIT +// Copyright (c) 2018-2023, Advanced Micro Devices, Inc. All rights reserved. 
+ +#include "ck/library/tensor_operation_instance/add_device_operation_instance.hpp" +#include "ck/library/tensor_operation_instance/gpu/grouped_conv_fwd/device_grouped_conv_fwd_wmma_instance.hpp" + +namespace ck { +namespace tensor_operation { +namespace device { +namespace instance { +// Compilation parameters for in[n, di, hi, wi, g, c] * wei[g, k, z, y, x, c] = out[n, do, ho, wo, +// g, k] +void add_device_grouped_conv3d_fwd_wmma_ndhwgc_gkzyxc_ndhwgk_f16_oddc_instances( + std::vector>>& instances) +{ + add_device_operation_instances(instances, + device_grouped_conv_fwd_wmma_f16_instances<3, + NDHWGC, + GKZYXC, + Empty_Tuple, + NDHWGK, + Empty_Tuple, + PassThrough, + ConvFwdOddC>{}); +} + +} // namespace instance +} // namespace device +} // namespace tensor_operation +} // namespace ck diff --git a/library/src/tensor_operation_instance/gpu/grouped_conv3d_fwd/wmma/device_grouped_conv3d_fwd_wmma_ndhwgc_gkzyxc_ndhwgk_i8_oddc_instance.cpp b/library/src/tensor_operation_instance/gpu/grouped_conv3d_fwd/wmma/device_grouped_conv3d_fwd_wmma_ndhwgc_gkzyxc_ndhwgk_i8_oddc_instance.cpp new file mode 100644 index 0000000000..7649f86971 --- /dev/null +++ b/library/src/tensor_operation_instance/gpu/grouped_conv3d_fwd/wmma/device_grouped_conv3d_fwd_wmma_ndhwgc_gkzyxc_ndhwgk_i8_oddc_instance.cpp @@ -0,0 +1,41 @@ +// SPDX-License-Identifier: MIT +// Copyright (c) 2018-2023, Advanced Micro Devices, Inc. All rights reserved. 
+ +#include "ck/library/tensor_operation_instance/add_device_operation_instance.hpp" +#include "ck/library/tensor_operation_instance/gpu/grouped_conv_fwd/device_grouped_conv_fwd_wmma_instance.hpp" + +namespace ck { +namespace tensor_operation { +namespace device { +namespace instance { +// Compilation parameters for in[n, di, hi, wi, g, c] * wei[g, k, z, y, x, c] = out[n, do, ho, wo, +// g, k] +void add_device_grouped_conv3d_fwd_wmma_ndhwgc_gkzyxc_ndhwgk_i8_oddc_instances( + std::vector>>& instances) +{ + add_device_operation_instances(instances, + device_grouped_conv_fwd_wmma_i8_instances<3, + NDHWGC, + GKZYXC, + Empty_Tuple, + NDHWGK, + Empty_Tuple, + PassThrough, + ConvFwdOddC>{}); +} + +} // namespace instance +} // namespace device +} // namespace tensor_operation +} // namespace ck From b8b12bb81e1b370d39ab7b17b0c13654a6e54721 Mon Sep 17 00:00:00 2001 From: jefyang1 <146495389+jefyang1@users.noreply.github.com> Date: Mon, 19 May 2025 14:25:50 -0700 Subject: [PATCH 128/443] Fix example_grouped_gemm_multiple_d_xdl_fp16 on gfx950 (#2203) * Fix example_grouped_gemm_multiple_d_xdl_fp16 on gfx950 * Run clang format --- example/15_grouped_gemm/grouped_gemm_multiple_d_xdl_fp16.cpp | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/example/15_grouped_gemm/grouped_gemm_multiple_d_xdl_fp16.cpp b/example/15_grouped_gemm/grouped_gemm_multiple_d_xdl_fp16.cpp index db162fe444..63a2aea0b3 100644 --- a/example/15_grouped_gemm/grouped_gemm_multiple_d_xdl_fp16.cpp +++ b/example/15_grouped_gemm/grouped_gemm_multiple_d_xdl_fp16.cpp @@ -141,8 +141,8 @@ bool run_grouped_gemm(const ProblemSize& problem_size, const ExecutionConfig& co a_tensors_device.reserve(group_count); b_tensors_device.reserve(group_count); - d_tensors_device.reserve(group_count); c_tensors_device.reserve(group_count); + d_tensors_device.resize(group_count); // reserve and update vector size std::size_t flop = 0, num_btype = 0; From 57e0f5df29abefd919c334c994628a994ba2868c Mon Sep 17 00:00:00 2001 
From: Andriy Roshchenko <107577548+andriy-ca@users.noreply.github.com> Date: Mon, 19 May 2025 15:52:51 -0600 Subject: [PATCH 129/443] MX GEMM - Expand MX MFMA Testing to BF8, FP6, and BF6 Data Types (#2199) * Unify test interface for different layouts. * WIP: Introducing FP4/FP6/FP8 abstractions * WIP: Introducing packed storage abstraction * WIP: Introducing packed storage abstraction * WIP: Improved support for FP6 data type * Refactor packed storage for f6_t * WIP: FP6 MFMA test * Test if we correctly represent all FP6/FP4 numbers * Additional output for failed FP4 test. * More failing conversion tests * Even more failing conversion tests * Working FP6 MFMA tests * Expand MX MFMA testing to BF8/6 * Update and verify MX MFMA test for packed types * Fix fp4 and fp6 conversions on host * Working MX MFMA tests for FP8/6/4 * Cleanup * Add missing type * Cleanup * Final cleanup * Restrict FP6/4 values output to CK_LOGGING=1 * Use CHAR_BIT instead of number 8 * Fix typo * Remove FP6 and FP4 from the list of native types --------- Co-authored-by: Rostyslav Geyyer --- include/ck/library/utility/host_tensor.hpp | 57 +-- .../library/utility/host_tensor_generator.hpp | 232 ++++++++++ include/ck/utility/amd_xdlops.hpp | 390 ++++++++++++++-- include/ck/utility/data_type.hpp | 428 +++++++----------- include/ck/utility/dtype_vector.hpp | 104 ++++- include/ck/utility/mxf4_utils.hpp | 12 +- include/ck/utility/mxf6_utils.hpp | 8 +- .../cpu/reference_gemm.hpp | 16 + .../cpu/reference_mx_gemm.hpp | 20 + test/data_type/test_bf6.cpp | 111 ++++- test/data_type/test_fp4.cpp | 57 +++ test/data_type/test_fp6.cpp | 106 ++++- test/mx_mfma_op/mx_mfma_op.cpp | 365 ++++++++++++--- test/mx_mfma_op/mx_mfma_op.hpp | 282 ++++++------ 14 files changed, 1601 insertions(+), 587 deletions(-) diff --git a/include/ck/library/utility/host_tensor.hpp b/include/ck/library/utility/host_tensor.hpp index 71417ce7bf..257636d956 100644 --- a/include/ck/library/utility/host_tensor.hpp +++ 
b/include/ck/library/utility/host_tensor.hpp @@ -360,10 +360,9 @@ struct Tensor std::size_t GetElementSpaceSize() const { - if constexpr(ck::is_same_v, ck::pk_i4_t> || - ck::is_same_v, ck::f4x2_pk_t>) + if constexpr(ck::is_packed_type_v>) { - return (mDesc.GetElementSpaceSize() + 1) / 2; + return (mDesc.GetElementSpaceSize() + 1) / ck::packed_size_v>; } else { @@ -516,69 +515,31 @@ struct Tensor template std::size_t GetOffsetFromMultiIndex(Is... is) const { - if constexpr(ck::is_same_v, ck::pk_i4_t> || - ck::is_same_v, ck::f4x2_pk_t>) - { - return mDesc.GetOffsetFromMultiIndex(is...) / 2; - } - else - { - return mDesc.GetOffsetFromMultiIndex(is...); - } + return mDesc.GetOffsetFromMultiIndex(is...) / ck::packed_size_v>; } template T& operator()(Is... is) { - if constexpr(ck::is_same_v, ck::pk_i4_t> || - ck::is_same_v, ck::f4x2_pk_t>) - { - return mData[mDesc.GetOffsetFromMultiIndex(is...) / 2]; - } - else - { - return mData[mDesc.GetOffsetFromMultiIndex(is...)]; - } + return mData[mDesc.GetOffsetFromMultiIndex(is...) / + ck::packed_size_v>]; } template const T& operator()(Is... is) const { - if constexpr(ck::is_same_v, ck::pk_i4_t> || - ck::is_same_v, ck::f4x2_pk_t>) - { - return mData[mDesc.GetOffsetFromMultiIndex(is...) / 2]; - } - else - { - return mData[mDesc.GetOffsetFromMultiIndex(is...)]; - } + return mData[mDesc.GetOffsetFromMultiIndex(is...) 
/ + ck::packed_size_v>]; } T& operator()(std::vector idx) { - if constexpr(ck::is_same_v, ck::pk_i4_t> || - ck::is_same_v, ck::f4x2_pk_t>) - { - return mData[mDesc.GetOffsetFromMultiIndex(idx) / 2]; - } - else - { - return mData[mDesc.GetOffsetFromMultiIndex(idx)]; - } + return mData[mDesc.GetOffsetFromMultiIndex(idx) / ck::packed_size_v>]; } const T& operator()(std::vector idx) const { - if constexpr(ck::is_same_v, ck::pk_i4_t> || - ck::is_same_v, ck::f4x2_pk_t>) - { - return mData[mDesc.GetOffsetFromMultiIndex(idx) / 2]; - } - else - { - return mData[mDesc.GetOffsetFromMultiIndex(idx)]; - } + return mData[mDesc.GetOffsetFromMultiIndex(idx) / ck::packed_size_v>]; } typename Data::iterator begin() { return mData.begin(); } diff --git a/include/ck/library/utility/host_tensor_generator.hpp b/include/ck/library/utility/host_tensor_generator.hpp index 785f74a3c0..f48ba49bbf 100644 --- a/include/ck/library/utility/host_tensor_generator.hpp +++ b/include/ck/library/utility/host_tensor_generator.hpp @@ -67,6 +67,18 @@ struct GeneratorTensor_1 return ck::type_convert(value); } }; + +template <> +struct GeneratorTensor_1 +{ + float value = 1.0; + + template + ck::bf8_t operator()(Is...) + { + return ck::type_convert(value); + } +}; #endif template <> @@ -93,6 +105,38 @@ struct GeneratorTensor_1 } }; +template <> +struct GeneratorTensor_1 +{ + float value = 1.0; + + template + ck::f6x32_pk_t operator()(Is...) + { + ck::f6x32_pk_t r; + ck::static_for<0, 32, 1>{}([&](auto i) { + r.pack(ck::type_convert(value), static_cast(i)); + }); + return r; + } +}; + +template <> +struct GeneratorTensor_1 +{ + float value = 1.0; + + template + ck::bf6x32_pk_t operator()(Is...) 
+ { + ck::bf6x32_pk_t r; + ck::static_for<0, 32, 1>{}([&](auto i) { + r.pack(ck::type_convert(value), static_cast(i)); + }); + return r; + } +}; + template <> struct GeneratorTensor_1 { @@ -132,6 +176,44 @@ struct GeneratorTensor_2 } }; +template <> +struct GeneratorTensor_2 +{ + int min_value = 0; + int max_value = 1; + + template + ck::f6x32_pk_t operator()(Is...) + { + ck::f6x32_pk_t r; + ck::static_for<0, 32, 1>{}([&](auto i) { + float tmp = (std::rand() % (max_value - min_value)) + min_value; + r.pack(ck::type_convert(tmp), static_cast(i)); + }); + + return r; + } +}; + +template <> +struct GeneratorTensor_2 +{ + int min_value = 0; + int max_value = 1; + + template + ck::bf6x32_pk_t operator()(Is...) + { + ck::bf6x32_pk_t r; + ck::static_for<0, 32, 1>{}([&](auto i) { + float tmp = (std::rand() % (max_value - min_value)) + min_value; + r.pack(ck::type_convert(tmp), static_cast(i)); + }); + + return r; + } +}; + template <> struct GeneratorTensor_2 { @@ -342,6 +424,46 @@ struct GeneratorTensor_3 } }; +template <> +struct GeneratorTensor_3 +{ + float min_value = 0; + float max_value = 1; + + template + ck::f6x32_pk_t operator()(Is...) + { + ck::f6x32_pk_t r; + ck::static_for<0, 32, 1>{}([&](auto i) { + float rnd = float(std::rand()) / float(RAND_MAX); + float fp32 = min_value + rnd * (max_value - min_value); + r.pack(ck::type_convert(fp32), static_cast(i)); + }); + + return r; + } +}; + +template <> +struct GeneratorTensor_3 +{ + float min_value = 0; + float max_value = 1; + + template + ck::bf6x32_pk_t operator()(Is...) 
+ { + ck::bf6x32_pk_t r; + ck::static_for<0, 32, 1>{}([&](auto i) { + float rnd = float(std::rand()) / float(RAND_MAX); + float fp32 = min_value + rnd * (max_value - min_value); + r.pack(ck::type_convert(fp32), static_cast(i)); + }); + + return r; + } +}; + template struct GeneratorTensor_4 { @@ -360,6 +482,69 @@ struct GeneratorTensor_4 } }; +template <> +struct GeneratorTensor_4 +{ + std::mt19937 generator; + std::normal_distribution distribution; + + GeneratorTensor_4(float mean, float stddev, unsigned int seed = 1) + : generator(seed), distribution(mean, stddev){}; + + template + ck::f4x2_pk_t operator()(Is...) + { + float fp32_tmp0 = distribution(generator); + float fp32_tmp1 = distribution(generator); + + return ck::f4x2_pk_t{ck::type_convert(ck::float2_t{fp32_tmp0, fp32_tmp1})}; + } +}; + +template <> +struct GeneratorTensor_4 +{ + std::mt19937 generator; + std::normal_distribution distribution; + + GeneratorTensor_4(float mean, float stddev, unsigned int seed = 1) + : generator(seed), distribution(mean, stddev){}; + + template + ck::f6x32_pk_t operator()(Is...) + { + ck::f6x32_pk_t r; + ck::static_for<0, 32, 1>{}([&](auto i) { + r.pack(ck::type_convert(distribution(generator)), + static_cast(i)); + }); + + return r; + } +}; + +template <> +struct GeneratorTensor_4 +{ + std::mt19937 generator; + std::normal_distribution distribution; + + GeneratorTensor_4(float mean, float stddev, unsigned int seed = 1) + : generator(seed), distribution(mean, stddev){}; + + template + ck::bf6x32_pk_t operator()(Is...) + { + ck::bf6x32_pk_t r; + ck::static_for<0, 32, 1>{}([&](auto i) { + r.pack(ck::type_convert(distribution(generator)), + static_cast(i)); + }); + + return r; + } +}; + struct GeneratorTensor_Checkboard { template @@ -405,6 +590,53 @@ struct GeneratorTensor_Sequential } }; +template +struct GeneratorTensor_Sequential +{ + template + ck::f4x2_pk_t operator()(Ts... 
Xs) const + { + std::array dims = {{static_cast(Xs)...}}; + + float tmp = dims[Dim]; + return ck::type_convert(ck::float2_t(tmp)); + } +}; + +template +struct GeneratorTensor_Sequential +{ + template + ck::f6x32_pk_t operator()(Ts... Xs) const + { + std::array dims = {{static_cast(Xs)...}}; + + float tmp = dims[Dim]; + + ck::f6x32_pk_t r; + ck::static_for<0, 32, 1>{}( + [&](auto i) { r.pack(ck::type_convert(tmp), static_cast(i)); }); + return r; + } +}; + +template +struct GeneratorTensor_Sequential +{ + template + ck::bf6x32_pk_t operator()(Ts... Xs) const + { + std::array dims = {{static_cast(Xs)...}}; + + float tmp = dims[Dim]; + + ck::bf6x32_pk_t r; + ck::static_for<0, 32, 1>{}( + [&](auto i) { r.pack(ck::type_convert(tmp), static_cast(i)); }); + return r; + } +}; + template struct GeneratorTensor_Diagonal { diff --git a/include/ck/utility/amd_xdlops.hpp b/include/ck/utility/amd_xdlops.hpp index 66c4958e1d..ad48389625 100644 --- a/include/ck/utility/amd_xdlops.hpp +++ b/include/ck/utility/amd_xdlops.hpp @@ -498,7 +498,7 @@ struct intrin_mfma_f32_32x32x64f8f6f4<32, 32> reg_a, reg_b, reg_c.template AsType()[Number<0>{}], - 0, // cbsz + 0, // cbsz {0 FP8 E4M3; 1 FP8 E5M2; 2 FP6 E2M3; 3 FP6 E3M2; 4 FP4 E2M1} 0, // blgp 0, 0, @@ -511,6 +511,28 @@ struct intrin_mfma_f32_32x32x64f8f6f4<32, 32> #endif } + template + __device__ static void Run(const bf8x32_t& reg_a, const bf8x32_t& reg_b, FloatC& reg_c) + { +#if defined(__gfx950__) + reg_c.template AsType()(Number<0>{}) = + __builtin_amdgcn_mfma_scale_f32_32x32x64_f8f6f4( + reg_a, + reg_b, + reg_c.template AsType()[Number<0>{}], + 1, // cbsz {0 FP8 E4M3; 1 FP8 E5M2; 2 FP6 E2M3; 3 FP6 E3M2; 4 FP4 E2M1} + 1, // blgp + 0, + 0, + 0, + 0); +#else + ignore = reg_a; + ignore = reg_b; + ignore = reg_c; +#endif + } + template __device__ static void Run(const f4x32_t& reg_a, const f4x32_t& reg_b, FloatC& reg_c) { @@ -536,6 +558,62 @@ struct intrin_mfma_f32_32x32x64f8f6f4<32, 32> ignore = reg_a; ignore = reg_b; ignore = reg_c; 
+#endif + } + + template + __device__ static void Run(const f6x32_t& reg_a, const f6x32_t& reg_b, FloatC& reg_c) + { +#if defined(__gfx950__) + + int32x6_t arg_a = bit_cast(reg_a); + int32x6_t arg_b = bit_cast(reg_b); + + using arg_type = int32x8_t; + + reg_c.template AsType()(Number<0>{}) = + __builtin_amdgcn_mfma_scale_f32_32x32x64_f8f6f4( + arg_type{arg_a[0], arg_a[1], arg_a[2], arg_a[3], arg_a[4], arg_a[5], 0, 0}, + arg_type{arg_b[0], arg_b[1], arg_b[2], arg_b[3], arg_b[4], arg_b[5], 0, 0}, + reg_c.template AsType()[Number<0>{}], + 2, // cbsz {0 FP8 E4M3; 1 FP8 E5M2; 2 FP6 E2M3; 3 FP6 E3M2; 4 FP4 E2M1} + 2, // blgp + 0, // OPSEL + 0, + 0, // OPSEL + 0); +#else + ignore = reg_a; + ignore = reg_b; + ignore = reg_c; +#endif + } + + template + __device__ static void Run(const bf6x32_t& reg_a, const bf6x32_t& reg_b, FloatC& reg_c) + { +#if defined(__gfx950__) + + int32x6_t arg_a = bit_cast(reg_a); + int32x6_t arg_b = bit_cast(reg_b); + + using arg_type = int32x8_t; + + reg_c.template AsType()(Number<0>{}) = + __builtin_amdgcn_mfma_scale_f32_32x32x64_f8f6f4( + arg_type{arg_a[0], arg_a[1], arg_a[2], arg_a[3], arg_a[4], arg_a[5], 0, 0}, + arg_type{arg_b[0], arg_b[1], arg_b[2], arg_b[3], arg_b[4], arg_b[5], 0, 0}, + reg_c.template AsType()[Number<0>{}], + 3, // cbsz {0 FP8 E4M3; 1 FP8 E5M2; 2 FP6 E2M3; 3 FP6 E3M2; 4 FP4 E2M1} + 3, // blgp + 0, // OPSEL + 0, + 0, // OPSEL + 0); +#else + ignore = reg_a; + ignore = reg_b; + ignore = reg_c; #endif } }; @@ -583,6 +661,43 @@ struct intrin_mfma_scale_f32_32x32x64f8f6f4<32, 32> #endif } + template + __device__ static void Run(const bf8x32_t& reg_a, + const int32_t& scale_a, + const bf8x32_t& reg_b, + const int32_t& scale_b, + FloatC& reg_c) + { +#if defined(__gfx950__) + // https://github.com/ROCm/llvm-project/blob/656552edc693e2bb4abc9258399c39d190fce2b3/llvm/test/Verifier/AMDGPU/mfma-scale.ll#L10 + reg_c.template AsType()(Number<0>{}) = + __builtin_amdgcn_mfma_scale_f32_32x32x64_f8f6f4( + reg_a, + reg_b, + reg_c.template 
AsType()[Number<0>{}], + 1, // cbsz {0 FP8 E4M3; 1 FP8 E5M2; 2 FP6 E2M3; 3 FP6 E3M2; 4 FP4 E2M1} + 1, // blgp + 0, // OPSEL + scale_a, + 0, // OPSEL + scale_b); + // XXX: Note on the scale_a and scale_b parameters: + // If compiler detects that one or both scales are constant values, it will treat that + // constant as F32 constant. I.e., if scale_a at some point was declared as + // `e8m0_bexp_t a_scale{1.0f}`, the instruction would only work if scale_a parameter is + // assigned value `bit_cast(static_cast(a_scale))`. + + // XXX: Note on the OPSEL parameters: Instruction always takes byte0 as a scale value even + // when OPSEL is set otherwise. +#else + ignore = reg_a; + ignore = scale_a; + ignore = reg_b; + ignore = scale_b; + ignore = reg_c; +#endif + } + template __device__ static void Run(const bf8x32_t& reg_a, const int32_t& scale_a, @@ -620,6 +735,74 @@ struct intrin_mfma_scale_f32_32x32x64f8f6f4<32, 32> #endif } + template + __device__ static void Run(const f6x32_t& reg_a, + const int32_t scale_a, + const f6x32_t& reg_b, + const int32_t scale_b, + FloatC& reg_c) + { +#if defined(__gfx950__) + + int32x6_t arg_a = bit_cast(reg_a); + int32x6_t arg_b = bit_cast(reg_b); + + using arg_type = int32x8_t; + + reg_c.template AsType()(Number<0>{}) = + __builtin_amdgcn_mfma_scale_f32_32x32x64_f8f6f4( + arg_type{arg_a[0], arg_a[1], arg_a[2], arg_a[3], arg_a[4], arg_a[5], 0, 0}, + arg_type{arg_b[0], arg_b[1], arg_b[2], arg_b[3], arg_b[4], arg_b[5], 0, 0}, + reg_c.template AsType()[Number<0>{}], + 2, // cbsz {0 FP8 E4M3; 1 FP8 E5M2; 2 FP6 E2M3; 3 FP6 E3M2; 4 FP4 E2M1} + 2, // blgp + 0, // OPSEL + scale_a, + 0, // OPSEL + scale_b); +#else + ignore = reg_a; + ignore = scale_a; + ignore = reg_b; + ignore = scale_b; + ignore = reg_c; +#endif + } + + template + __device__ static void Run(const bf6x32_t& reg_a, + const int32_t scale_a, + const bf6x32_t& reg_b, + const int32_t scale_b, + FloatC& reg_c) + { +#if defined(__gfx950__) + + int32x6_t arg_a = bit_cast(reg_a); + 
int32x6_t arg_b = bit_cast(reg_b); + + using arg_type = int32x8_t; + + reg_c.template AsType()(Number<0>{}) = + __builtin_amdgcn_mfma_scale_f32_32x32x64_f8f6f4( + arg_type{arg_a[0], arg_a[1], arg_a[2], arg_a[3], arg_a[4], arg_a[5], 0, 0}, + arg_type{arg_b[0], arg_b[1], arg_b[2], arg_b[3], arg_b[4], arg_b[5], 0, 0}, + reg_c.template AsType()[Number<0>{}], + 3, // cbsz {0 FP8 E4M3; 1 FP8 E5M2; 2 FP6 E2M3; 3 FP6 E3M2; 4 FP4 E2M1} + 3, // blgp + 0, // OPSEL + scale_a, + 0, // OPSEL + scale_b); +#else + ignore = reg_a; + ignore = scale_a; + ignore = reg_b; + ignore = scale_b; + ignore = reg_c; +#endif + } + template __device__ static void Run(const f4x32_t& reg_a, const int32_t scale_a, @@ -639,7 +822,7 @@ struct intrin_mfma_scale_f32_32x32x64f8f6f4<32, 32> arg_type{arg_a[0], arg_a[1], arg_a[2], arg_a[3], 0, 0, 0, 0}, arg_type{arg_b[0], arg_b[1], arg_b[2], arg_b[3], 0, 0, 0, 0}, reg_c.template AsType()[Number<0>{}], - 4, // cbsz + 4, // cbsz {0 FP8 E4M3; 1 FP8 E5M2; 2 FP6 E2M3; 3 FP6 E3M2; 4 FP4 E2M1} 4, // blgp 0, // OPSEL scale_a, @@ -748,6 +931,101 @@ struct intrin_mfma_scale_f32_16x16x128f8f6f4<16, 16> #endif } + template + __device__ static void Run(const bf8x32_t& reg_a, + const int32_t& scale_a, + const f8x32_t& reg_b, + const int32_t& scale_b, + FloatC& reg_c) + { +#if defined(__gfx950__) + // https://github.com/ROCm/llvm-project/blob/656552edc693e2bb4abc9258399c39d190fce2b3/llvm/test/Verifier/AMDGPU/mfma-scale.ll#L10 + reg_c.template AsType()(Number<0>{}) = + __builtin_amdgcn_mfma_scale_f32_16x16x128_f8f6f4( + reg_a, + reg_b, + reg_c.template AsType()[Number<0>{}], + 1, // cbsz {0 FP8 E4M3; 1 FP8 E5M2; 2 FP6 E2M3; 3 FP6 E3M2; 4 FP4 E2M1} + 0, // blgp + 0, // OPSEL + scale_a, + 0, // OPSEL + scale_b); +#else + ignore = reg_a; + ignore = scale_a; + ignore = reg_b; + ignore = scale_b; + ignore = reg_c; +#endif + } + + template + __device__ static void Run(const f6x32_t& reg_a, + const int32_t scale_a, + const f6x32_t& reg_b, + const int32_t scale_b, + FloatC& 
reg_c) + { +#if defined(__gfx950__) + int32x6_t arg_a = bit_cast(reg_a); + int32x6_t arg_b = bit_cast(reg_b); + + using arg_type = int32x8_t; + + reg_c.template AsType()(Number<0>{}) = + __builtin_amdgcn_mfma_scale_f32_16x16x128_f8f6f4( + arg_type{arg_a[0], arg_a[1], arg_a[2], arg_a[3], arg_a[4], arg_a[5], 0, 0}, + arg_type{arg_b[0], arg_b[1], arg_b[2], arg_b[3], arg_b[4], arg_b[5], 0, 0}, + reg_c.template AsType()[Number<0>{}], + 2, // cbsz {0 FP8 E4M3; 1 FP8 E5M2; 2 FP6 E2M3; 3 FP6 E3M2; 4 FP4 E2M1} + 2, // blgp + 0, // OPSEL + scale_a, + 0, // OPSEL + scale_b); +#else + ignore = reg_a; + ignore = scale_a; + ignore = reg_b; + ignore = scale_b; + ignore = reg_c; +#endif + } + + template + __device__ static void Run(const bf6x32_t& reg_a, + const int32_t scale_a, + const bf6x32_t& reg_b, + const int32_t scale_b, + FloatC& reg_c) + { +#if defined(__gfx950__) + int32x6_t arg_a = bit_cast(reg_a); + int32x6_t arg_b = bit_cast(reg_b); + + using arg_type = int32x8_t; + + reg_c.template AsType()(Number<0>{}) = + __builtin_amdgcn_mfma_scale_f32_16x16x128_f8f6f4( + arg_type{arg_a[0], arg_a[1], arg_a[2], arg_a[3], arg_a[4], arg_a[5], 0, 0}, + arg_type{arg_b[0], arg_b[1], arg_b[2], arg_b[3], arg_b[4], arg_b[5], 0, 0}, + reg_c.template AsType()[Number<0>{}], + 3, // cbsz {0 FP8 E4M3; 1 FP8 E5M2; 2 FP6 E2M3; 3 FP6 E3M2; 4 FP4 E2M1} + 3, // blgp + 0, // OPSEL + scale_a, + 0, // OPSEL + scale_b); +#else + ignore = reg_a; + ignore = scale_a; + ignore = reg_b; + ignore = scale_b; + ignore = reg_c; +#endif + } + template __device__ static void Run(const f4x32_t& reg_a, const int32_t scale_a, @@ -778,35 +1056,6 @@ struct intrin_mfma_scale_f32_16x16x128f8f6f4<16, 16> ignore = reg_b; ignore = scale_b; ignore = reg_c; -#endif - } - - template - __device__ static void Run(const bf8x32_t& reg_a, - const int32_t& scale_a, - const f8x32_t& reg_b, - const int32_t& scale_b, - FloatC& reg_c) - { -#if defined(__gfx950__) - // 
https://github.com/ROCm/llvm-project/blob/656552edc693e2bb4abc9258399c39d190fce2b3/llvm/test/Verifier/AMDGPU/mfma-scale.ll#L10 - reg_c.template AsType()(Number<0>{}) = - __builtin_amdgcn_mfma_scale_f32_16x16x128_f8f6f4( - reg_a, - reg_b, - reg_c.template AsType()[Number<0>{}], - 1, // cbsz {0 FP8 E4M3; 1 FP8 E5M2; 2 FP6 E2M3; 3 FP6 E3M2; 4 FP4 E2M1} - 0, // blgp - 0, // OPSEL - scale_a, - 0, // OPSEL - scale_b); -#else - ignore = reg_a; - ignore = scale_a; - ignore = reg_b; - ignore = scale_b; - ignore = reg_c; #endif } }; @@ -833,7 +1082,7 @@ struct intrin_mfma_f32_16x16x128f8f6f4<16, 16> reg_a, reg_b, reg_c.template AsType()[Number<0>{}], - 0, // cbsz + 0, // cbsz {0 FP8 E4M3; 1 FP8 E5M2; 2 FP6 E2M3; 3 FP6 E3M2; 4 FP4 E2M1} 0, // blgp 0, 0, @@ -846,6 +1095,29 @@ struct intrin_mfma_f32_16x16x128f8f6f4<16, 16> #endif } + template + __device__ static void Run(const bf8x32_t& reg_a, const bf8x32_t& reg_b, FloatC& reg_c) + { +#if defined(__gfx950__) + // https://github.com/ROCm/llvm-project/blob/656552edc693e2bb4abc9258399c39d190fce2b3/llvm/test/Verifier/AMDGPU/mfma-scale.ll#L10 + reg_c.template AsType()(Number<0>{}) = + __builtin_amdgcn_mfma_scale_f32_16x16x128_f8f6f4( + reg_a, + reg_b, + reg_c.template AsType()[Number<0>{}], + 1, // cbsz {0 FP8 E4M3; 1 FP8 E5M2; 2 FP6 E2M3; 3 FP6 E3M2; 4 FP4 E2M1} + 1, // blgp + 0, + 0, + 0, + 0); +#else + ignore = reg_a; + ignore = reg_b; + ignore = reg_c; +#endif + } + template __device__ static void Run(const f4x32_t& reg_a, const f4x32_t& reg_b, FloatC& reg_c) { @@ -870,6 +1142,60 @@ struct intrin_mfma_f32_16x16x128f8f6f4<16, 16> ignore = reg_a; ignore = reg_b; ignore = reg_c; +#endif + } + + template + __device__ static void Run(const f6x32_t& reg_a, const f6x32_t& reg_b, FloatC& reg_c) + { +#if defined(__gfx950__) + int32x6_t arg_a = bit_cast(reg_a); + int32x6_t arg_b = bit_cast(reg_b); + + using arg_type = int32x8_t; + + reg_c.template AsType()(Number<0>{}) = + __builtin_amdgcn_mfma_scale_f32_16x16x128_f8f6f4( + 
arg_type{arg_a[0], arg_a[1], arg_a[2], arg_a[3], arg_a[4], arg_a[5], 0, 0}, + arg_type{arg_b[0], arg_b[1], arg_b[2], arg_b[3], arg_b[4], arg_b[5], 0, 0}, + reg_c.template AsType()[Number<0>{}], + 2, // cbsz {0 FP8 E4M3; 1 FP8 E5M2; 2 FP6 E2M3; 3 FP6 E3M2; 4 FP4 E2M1} + 2, // blgp + 0, // OPSEL + 0, + 0, // OPSEL + 0); +#else + ignore = reg_a; + ignore = reg_b; + ignore = reg_c; +#endif + } + + template + __device__ static void Run(const bf6x32_t& reg_a, const bf6x32_t& reg_b, FloatC& reg_c) + { +#if defined(__gfx950__) + int32x6_t arg_a = bit_cast(reg_a); + int32x6_t arg_b = bit_cast(reg_b); + + using arg_type = int32x8_t; + + reg_c.template AsType()(Number<0>{}) = + __builtin_amdgcn_mfma_scale_f32_16x16x128_f8f6f4( + arg_type{arg_a[0], arg_a[1], arg_a[2], arg_a[3], arg_a[4], arg_a[5], 0, 0}, + arg_type{arg_b[0], arg_b[1], arg_b[2], arg_b[3], arg_b[4], arg_b[5], 0, 0}, + reg_c.template AsType()[Number<0>{}], + 3, // cbsz {0 FP8 E4M3; 1 FP8 E5M2; 2 FP6 E2M3; 3 FP6 E3M2; 4 FP4 E2M1} + 3, // blgp + 0, // OPSEL + 0, + 0, // OPSEL + 0); +#else + ignore = reg_a; + ignore = reg_b; + ignore = reg_c; #endif } }; diff --git a/include/ck/utility/data_type.hpp b/include/ck/utility/data_type.hpp index a6106bb146..c11b9c0272 100644 --- a/include/ck/utility/data_type.hpp +++ b/include/ck/utility/data_type.hpp @@ -32,8 +32,14 @@ using f4_t = unsigned _BitInt(4); using f6_t = _BitInt(6); // e2m3 format using bf6_t = unsigned _BitInt(6); // e3m2 format +// scalar_type +template +struct scalar_type; + struct f4x2_pk_t { + static constexpr int packed_size = 2; + using type = uint8_t; type data; __host__ __device__ f4x2_pk_t() : data{type{}} {} @@ -55,269 +61,82 @@ struct f4x2_pk_t } }; -struct f6x16_pk_t +template +struct f6_pk_t { - // store 16 elements of f6_t in an array of 3 uint32_t - using element_type = uint32_t; - using type = StaticallyIndexedArray_v2; - type data; - typedef int8_t test_vec_t __attribute__((ext_vector_type(16))); - f6x16_pk_t() : data{type{}} {} - 
f6x16_pk_t(type init) : data{init} {} + using element_type = uint32_t; // element storage fundamental type - template - __host__ __device__ inline f6_t unpack(Number) + static constexpr index_t packed_size = pk_size; + static constexpr index_t num_bits_elem = 6; + static constexpr index_t num_bits_vec_elem = sizeof(element_type) * CHAR_BIT; + static_assert((packed_size * num_bits_elem) % num_bits_vec_elem == 0, + "Packed elements must fit exactly into the element storage."); + static constexpr index_t vector_size = (packed_size * num_bits_elem) / num_bits_vec_elem; + + using storage_type = StaticallyIndexedArray_v2; + storage_type data; // packed data + + using type = f6_pk_t; + + __host__ __device__ constexpr f6_pk_t() : data{} {} + __host__ __device__ constexpr f6_pk_t(storage_type init) : data{init} {} + template ::vector_size == packed_size>> + __host__ __device__ f6_pk_t(const T& v) : data{} { - static_assert(I < 16, "Index out of range for 16 f6_t elements."); + static_for<0, packed_size, 1>{}( + [&](auto i) { pack(v[static_cast(i)], static_cast(i)); }); + } - constexpr int num_bits_elem = 6; - constexpr int num_bits_vec_elem = 32; - constexpr int vector_size = 3; - constexpr int bit_pos = I * num_bits_elem; - constexpr int arr_idx = bit_pos / num_bits_vec_elem; - constexpr int bit_offset = bit_pos % num_bits_vec_elem; - uint32_t bits = data.At(Number{}) >> bit_offset; - constexpr int overhang = bit_offset + num_bits_elem - num_bits_vec_elem; + template + __host__ __device__ void pack(const T x, const index_t i) + { + static_assert(is_integral::value || is_same_v, + "T must be an integral type."); - if constexpr(overhang > 0 && (arr_idx + 1) < vector_size) + uint32_t bits = static_cast(x) & 0x3F; + const int bit_pos = i * num_bits_elem; + const int arr_index = bit_pos / num_bits_vec_elem; + const int bit_offset = bit_pos % num_bits_vec_elem; + const int overhang = bit_offset + num_bits_elem - num_bits_vec_elem; + uint32_t old_value = data.data_[arr_index]; + 
+ // insert bits into the current 32-bit block + old_value |= (bits << bit_offset); + data.data_[arr_index] = old_value; + + // if it crosses into the next block, shift the remainder + if(overhang > 0 && (arr_index + 1) < vector_size) { - bits |= (data.At(Number{}) & ((1u << overhang) - 1)) + uint32_t next_value = data.data_[arr_index + 1]; + next_value |= (bits >> (num_bits_elem - overhang)); + data.data_[arr_index + 1] = next_value; + } + } + + __host__ __device__ static inline BitType unpack(const type& pk, const index_t i) + { + const int bit_pos = i * num_bits_elem; + const int arr_idx = bit_pos / num_bits_vec_elem; + const int bit_offset = bit_pos % num_bits_vec_elem; + const int overhang = bit_offset + num_bits_elem - num_bits_vec_elem; + + uint32_t bits = pk.data.data_[arr_idx] >> bit_offset; + if(overhang > 0 && (arr_idx + 1) < vector_size) + { + bits |= (pk.data.data_[arr_idx + 1] & ((1u << overhang) - 1)) << (num_bits_elem - overhang); } - return static_cast(bits & 0x3F); + return static_cast(bits & 0x3F); } - __host__ __device__ inline type pack(const test_vec_t& x) - { - type packed{}; - - // for each of the 16 f6_t values, place its 6 bits in the correct position - ck::static_for<0, 16, 1>{}([&](auto i) { - uint32_t bits = static_cast(x[static_cast(i)]) & 0x3F; - constexpr int num_bits_elem = 6; - constexpr int num_bits_vec_elem = 32; - constexpr int vector_size = 3; - constexpr int bit_pos = i * num_bits_elem; - constexpr int arr_index = bit_pos / num_bits_vec_elem; - constexpr int bit_offset = bit_pos % num_bits_vec_elem; - constexpr int overhang = bit_offset + num_bits_elem - num_bits_vec_elem; - uint32_t old_value = packed.At(Number{}); - - // insert bits into the current 32-bit block - old_value |= (bits << bit_offset); - packed.At(Number{}) = old_value; - - // if it crosses into the next block, shift the remainder - if constexpr(overhang > 0 && (arr_index + 1) < vector_size) - { - uint32_t next_value = packed.At(Number{}); - next_value |= (bits 
>> (num_bits_elem - overhang)); - packed.At(Number{}) = next_value; - } - }); - - return packed; - } + __host__ __device__ inline BitType unpack(const index_t i) const { return unpack(*this, i); } }; -struct f6x32_pk_t -{ - // store 32 elements of f6_t in an array of 6 uint32_t - using element_type = uint32_t; - using type = StaticallyIndexedArray_v2; - type data; - typedef int8_t test_vec_t __attribute__((ext_vector_type(32))); - f6x32_pk_t() : data{type{}} {} - f6x32_pk_t(type init) : data{init} {} - - template - __host__ __device__ inline f6_t unpack(Number) - { - static_assert(I < 32, "Index out of range for 32 f6_t elements."); - - constexpr int num_bits_elem = 6; - constexpr int num_bits_vec_elem = 32; - constexpr int vector_size = 6; - constexpr int bit_pos = I * num_bits_elem; - constexpr int arr_idx = bit_pos / num_bits_vec_elem; - constexpr int bit_offset = bit_pos % num_bits_vec_elem; - uint32_t bits = data.At(Number{}) >> bit_offset; - constexpr int overhang = bit_offset + num_bits_elem - num_bits_vec_elem; - - if constexpr(overhang > 0 && (arr_idx + 1) < vector_size) - { - bits |= (data.At(Number{}) & ((1u << overhang) - 1)) - << (num_bits_elem - overhang); - } - - return static_cast(bits & 0x3F); - } - - __host__ __device__ inline type pack(const test_vec_t& x) - { - type packed{}; - - // for each of the 32 f6_t values, place its 6 bits in the correct position - ck::static_for<0, 32, 1>{}([&](auto i) { - uint32_t bits = static_cast(x[static_cast(i)]) & 0x3F; - constexpr int num_bits_elem = 6; - constexpr int num_bits_vec_elem = 32; - constexpr int vector_size = 6; - constexpr int bit_pos = i * num_bits_elem; - constexpr int arr_index = bit_pos / num_bits_vec_elem; - constexpr int bit_offset = bit_pos % num_bits_vec_elem; - constexpr int overhang = bit_offset + num_bits_elem - num_bits_vec_elem; - uint32_t old_value = packed.At(Number{}); - - // insert bits into the current 32-bit block - old_value |= (bits << bit_offset); - packed.At(Number{}) = 
old_value; - - // if it crosses into the next block, shift the remainder - if constexpr(overhang > 0 && (arr_index + 1) < vector_size) - { - uint32_t next_value = packed.At(Number{}); - next_value |= (bits >> (num_bits_elem - overhang)); - packed.At(Number{}) = next_value; - } - }); - - return packed; - } -}; - -struct bf6x16_pk_t -{ - // store 16 elements of bf6_t in an array of 3 uint32_t - using element_type = uint32_t; - using type = StaticallyIndexedArray_v2; - type data; - typedef int8_t test_vec_t __attribute__((ext_vector_type(16))); - bf6x16_pk_t() : data{type{}} {} - bf6x16_pk_t(type init) : data{init} {} - - template - __host__ __device__ inline bf6_t unpack(Number) - { - static_assert(I < 16, "Index out of range for 16 f6_t elements."); - - constexpr int num_bits_elem = 6; - constexpr int num_bits_vec_elem = 32; - constexpr int vector_size = 3; - constexpr int bit_pos = I * num_bits_elem; - constexpr int arr_idx = bit_pos / num_bits_vec_elem; - constexpr int bit_offset = bit_pos % num_bits_vec_elem; - uint32_t bits = data.At(Number{}) >> bit_offset; - constexpr int overhang = bit_offset + num_bits_elem - num_bits_vec_elem; - - if constexpr(overhang > 0 && (arr_idx + 1) < vector_size) - { - bits |= (data.At(Number{}) & ((1u << overhang) - 1)) - << (num_bits_elem - overhang); - } - - return static_cast(bits & 0x3F); - } - - __host__ __device__ inline type pack(const test_vec_t& x) - { - type packed{}; - - // for each of the 16 bf6_t values, place its 6 bits in the correct position - ck::static_for<0, 16, 1>{}([&](auto i) { - uint32_t bits = static_cast(x[static_cast(i)]) & 0x3F; - constexpr int num_bits_elem = 6; - constexpr int num_bits_vec_elem = 32; - constexpr int vector_size = 3; - constexpr int bit_pos = i * num_bits_elem; - constexpr int arr_index = bit_pos / num_bits_vec_elem; - constexpr int bit_offset = bit_pos % num_bits_vec_elem; - constexpr int overhang = bit_offset + num_bits_elem - num_bits_vec_elem; - uint32_t old_value = 
packed.At(Number{}); - - // insert bits into the current 32-bit block - old_value |= (bits << bit_offset); - packed.At(Number{}) = old_value; - - // if it crosses into the next block, shift the remainder - if constexpr(overhang > 0 && (arr_index + 1) < vector_size) - { - uint32_t next_value = packed.At(Number{}); - next_value |= (bits >> (num_bits_elem - overhang)); - packed.At(Number{}) = next_value; - } - }); - - return packed; - } -}; - -struct bf6x32_pk_t -{ - // store 32 elements of bf6_t in an array of 6 uint32_t - using element_type = uint32_t; - using type = StaticallyIndexedArray_v2; - type data; - typedef int8_t test_vec_t __attribute__((ext_vector_type(32))); - bf6x32_pk_t() : data{type{}} {} - bf6x32_pk_t(type init) : data{init} {} - - template - __host__ __device__ inline bf6_t unpack(Number) - { - static_assert(I < 32, "Index out of range for 32 f6_t elements."); - - constexpr int num_bits_elem = 6; - constexpr int num_bits_vec_elem = 32; - constexpr int vector_size = 6; - constexpr int bit_pos = I * num_bits_elem; - constexpr int arr_idx = bit_pos / num_bits_vec_elem; - constexpr int bit_offset = bit_pos % num_bits_vec_elem; - uint32_t bits = data.At(Number{}) >> bit_offset; - constexpr int overhang = bit_offset + num_bits_elem - num_bits_vec_elem; - - if constexpr(overhang > 0 && (arr_idx + 1) < vector_size) - { - bits |= (data.At(Number{}) & ((1u << overhang) - 1)) - << (num_bits_elem - overhang); - } - - return static_cast(bits & 0x3F); - } - - __host__ __device__ inline type pack(const test_vec_t& x) - { - type packed{}; - - // for each of the 32 bf6_t values, place its 6 bits in the correct position - ck::static_for<0, 32, 1>{}([&](auto i) { - uint32_t bits = static_cast(x[static_cast(i)]) & 0x3F; - constexpr int num_bits_elem = 6; - constexpr int num_bits_vec_elem = 32; - constexpr int vector_size = 6; - constexpr int bit_pos = i * num_bits_elem; - constexpr int arr_index = bit_pos / num_bits_vec_elem; - constexpr int bit_offset = bit_pos % 
num_bits_vec_elem; - constexpr int overhang = bit_offset + num_bits_elem - num_bits_vec_elem; - uint32_t old_value = packed.At(Number{}); - - // insert bits into the current 32-bit block - old_value |= (bits << bit_offset); - packed.At(Number{}) = old_value; - - // if it crosses into the next block, shift the remainder - if constexpr(overhang > 0 && (arr_index + 1) < vector_size) - { - uint32_t next_value = packed.At(Number{}); - next_value |= (bits >> (num_bits_elem - overhang)); - packed.At(Number{}) = next_value; - } - }); - - return packed; - } -}; +using f6x16_pk_t = f6_pk_t; +using f6x32_pk_t = f6_pk_t; +using bf6x16_pk_t = f6_pk_t; +using bf6x32_pk_t = f6_pk_t; // custom data type - pack int4 data struct pk_i4_t @@ -335,15 +154,14 @@ inline constexpr auto next_pow2(uint32_t x) } // native types: double, float, _Float16, ushort, int32_t, int8_t, uint8_t, f8_fnuz_t, bf8_fnuz_t, -// native types: bool, f4_t, f6_t, bf6_t +// native types: bool template inline constexpr bool is_native_type() { return is_same::value || is_same::value || is_same::value || - is_same::value || is_same::value || is_same::value || - is_same::value || is_same::value || - is_same::value || is_same::value || is_same::value || - is_same::value || is_same::value; + is_same::value || is_same::value || + is_same::value || is_same::value || is_same::value || + is_same::value || is_same::value || is_same::value; } // scalar_type @@ -484,6 +302,106 @@ struct scalar_type static constexpr index_t vector_size = 1; }; +// Default behavior for types that do not need special handling +template +struct packed_type +{ + using type = T; + static constexpr index_t packed_size = 1; // number of packed elements +}; + +template <> +struct packed_type +{ + using type = pk_i4_t; + static constexpr index_t packed_size = 2; // number of packed elements +}; + +template <> +struct packed_type +{ + using type = f4x2_pk_t; + static constexpr index_t packed_size = 2; // number of packed elements +}; + +template <> 
+struct packed_type +{ + using type = f6x32_pk_t; + static constexpr index_t packed_size = f6x32_pk_t::packed_size; // number of packed elements +}; + +template <> +struct packed_type +{ + using type = bf6x32_pk_t; + static constexpr index_t packed_size = bf6x32_pk_t::packed_size; // number of packed elements +}; + +template +using packed_type_t = typename packed_type::type; + +// Check if the type has packed type specialization +template +inline constexpr bool has_packed_type_v = !is_same_v, T>; + +template +struct element_type +{ + private: + static constexpr auto get_element_type() + { + using U = remove_cvref_t; + if constexpr(is_same_v) + return int4_t{}; + else if constexpr(is_same_v) + return f4_t{}; + else if constexpr(is_same_v) + return f6_t{}; + else if constexpr(is_same_v) + return bf6_t{}; + else if constexpr(is_same_v) + return f6_t{}; + else if constexpr(is_same_v) + return bf6_t{}; + else + return T{}; + } + + public: + using type = decltype(get_element_type()); +}; +template +using element_type_t = typename element_type::type; + +template +inline constexpr bool is_packed_type_v = + has_packed_type_v>&& is_same_v>>; + +template +struct packed_size +{ + private: + static constexpr auto get_packed_size() + { + using U = remove_cvref_t; + if constexpr(is_packed_type_v) + return Number>::packed_size>{}; + else + return Number::packed_size>{}; + } + + public: + using type = decltype(get_packed_size()); + static constexpr auto value = get_packed_size(); +}; + +template +using packed_size_t = typename packed_size::type; + +template +inline constexpr index_t packed_size_v = packed_size::value; + #if defined(_WIN32) using int64_t = long long; #else diff --git a/include/ck/utility/dtype_vector.hpp b/include/ck/utility/dtype_vector.hpp index 9c40d923d3..65eed0624c 100644 --- a/include/ck/utility/dtype_vector.hpp +++ b/include/ck/utility/dtype_vector.hpp @@ -365,6 +365,88 @@ struct vector_type()>> } }; +template +struct vector_type()>> +{ + using d1_t = T; + 
typedef T d2_t __attribute__((ext_vector_type(2))); + typedef T d3_t __attribute__((ext_vector_type(3))); + typedef T d6_t __attribute__((ext_vector_type(6))); + + using type = d6_t; + + union + { + d6_t d6_; + StaticallyIndexedArray d1x6_; + StaticallyIndexedArray d2x3_; + StaticallyIndexedArray d3x2_; + StaticallyIndexedArray d6x1_; + } data_; + + __host__ __device__ constexpr vector_type() : data_{type{0}} {} + + __host__ __device__ constexpr vector_type(type v) : data_{v} {} + + template + __host__ __device__ constexpr const auto& AsType() const + { + static_assert(is_same::value || is_same::value || + is_same::value || is_same::value, + "Something went wrong, please check src and dst types."); + + if constexpr(is_same::value) + { + return data_.d1x6_; + } + else if constexpr(is_same::value) + { + return data_.d2x3_; + } + else if constexpr(is_same::value) + { + return data_.d3x2_; + } + else if constexpr(is_same::value) + { + return data_.d6x1_; + } + else + { + return err; + } + } + + template + __host__ __device__ constexpr auto& AsType() + { + static_assert(is_same::value || is_same::value || + is_same::value || is_same::value, + "Something went wrong, please check src and dst types."); + + if constexpr(is_same::value) + { + return data_.d1x6_; + } + else if constexpr(is_same::value) + { + return data_.d2x3_; + } + else if constexpr(is_same::value) + { + return data_.d3x2_; + } + else if constexpr(is_same::value) + { + return data_.d6x1_; + } + else + { + return err; + } + } +}; + template struct vector_type()>> { @@ -1221,25 +1303,25 @@ struct nnvb_data_t_selector template <> struct nnvb_data_t_selector { - using type = f6x16_pk_t::type; + using type = f6x16_pk_t::storage_type; }; template <> struct nnvb_data_t_selector { - using type = f6x32_pk_t::type; + using type = f6x32_pk_t::storage_type; }; template <> struct nnvb_data_t_selector { - using type = bf6x16_pk_t::type; + using type = bf6x16_pk_t::storage_type; }; template <> struct nnvb_data_t_selector 
{ - using type = bf6x32_pk_t::type; + using type = bf6x32_pk_t::storage_type; }; template <> @@ -1406,12 +1488,23 @@ struct non_native_vector_base -struct scalar_type> +struct scalar_type>> { using type = typename non_native_vector_base::data_t; static constexpr index_t vector_size = N; }; +template +struct scalar_type< + non_native_vector_base>> +{ + using type = typename non_native_vector_base::element_t; + static constexpr index_t vector_size = N * non_native_vector_base::size_factor; +}; + // non-native vector_type implementation template struct vector_type()>> @@ -2025,6 +2118,7 @@ using bhalf32_t = typename vector_type::type; // i32 using int32x2_t = typename vector_type::type; using int32x4_t = typename vector_type::type; +using int32x6_t = typename vector_type::type; using int32x8_t = typename vector_type::type; using int32x16_t = typename vector_type::type; using int32x32_t = typename vector_type::type; diff --git a/include/ck/utility/mxf4_utils.hpp b/include/ck/utility/mxf4_utils.hpp index b0b5297f77..53edb6e182 100644 --- a/include/ck/utility/mxf4_utils.hpp +++ b/include/ck/utility/mxf4_utils.hpp @@ -66,7 +66,7 @@ __host__ __device__ inline f4_t sat_convert_to_type(float value) : NumericUtils::data_max_positive_normal_mask; } - if(std::abs(value) > NumericLimits::Max()) // covers inf case as well + if(std::abs(value) > NumericLimits::DataMaxNorm()) // covers inf case as well return sign ? NumericUtils::data_max_negative_normal_mask : NumericUtils::data_max_positive_normal_mask; @@ -74,8 +74,8 @@ __host__ __device__ inline f4_t sat_convert_to_type(float value) if(std::abs(to_float(NumericLimits::Binary_1(), res)) < NumericLimits::DataMinSubnorm()) - return value < 0 ? NumericUtils::negative_zero_mask - : NumericUtils::positive_zero_mask; + return sign ? NumericUtils::negative_zero_mask + : NumericUtils::positive_zero_mask; return res; } @@ -91,7 +91,7 @@ __host__ __device__ inline f4_t sat_convert_to_type_sr(float value, uint32 return sign ? 
NumericUtils::data_max_negative_normal_mask : NumericUtils::data_max_positive_normal_mask; - if(std::abs(value) > NumericLimits::Max()) // covers inf case as well + if(std::abs(value) > NumericLimits::DataMaxNorm()) // covers inf case as well return sign ? NumericUtils::data_max_negative_normal_mask : NumericUtils::data_max_positive_normal_mask; @@ -99,8 +99,8 @@ __host__ __device__ inline f4_t sat_convert_to_type_sr(float value, uint32 if(std::abs(to_float(NumericLimits::Binary_1(), res)) < NumericLimits::DataMinSubnorm()) - return value < 0 ? NumericUtils::negative_zero_mask - : NumericUtils::positive_zero_mask; + return sign ? NumericUtils::negative_zero_mask + : NumericUtils::positive_zero_mask; return res; } diff --git a/include/ck/utility/mxf6_utils.hpp b/include/ck/utility/mxf6_utils.hpp index cf68188b3e..a840c520a9 100644 --- a/include/ck/utility/mxf6_utils.hpp +++ b/include/ck/utility/mxf6_utils.hpp @@ -201,7 +201,7 @@ __host__ __device__ inline f6_t sat_convert_to_type(float value) : NumericUtils::data_max_positive_normal_mask; } - if(std::abs(value) > NumericLimits::Max()) // covers inf case as well + if(std::abs(value) > NumericLimits::DataMaxNorm()) // covers inf case as well return sign ? NumericUtils::data_max_negative_normal_mask : NumericUtils::data_max_positive_normal_mask; @@ -239,7 +239,7 @@ __host__ __device__ inline bf6_t sat_convert_to_type(float value) : NumericUtils::data_max_positive_normal_mask; } - if(std::abs(value) > NumericLimits::Max()) // covers inf case as well + if(std::abs(value) > NumericLimits::DataMaxNorm()) // covers inf case as well return sign ? NumericUtils::data_max_negative_normal_mask : NumericUtils::data_max_positive_normal_mask; @@ -274,7 +274,7 @@ __host__ __device__ inline f6_t sat_convert_to_type_sr(float value, uint32 return sign ? 
NumericUtils::data_max_negative_normal_mask : NumericUtils::data_max_positive_normal_mask; - if(std::abs(value) > NumericLimits::Max()) // covers inf case as well + if(std::abs(value) > NumericLimits::DataMaxNorm()) // covers inf case as well return sign ? NumericUtils::data_max_negative_normal_mask : NumericUtils::data_max_positive_normal_mask; @@ -308,7 +308,7 @@ __host__ __device__ inline bf6_t sat_convert_to_type_sr(float value, uint if(std::isnan(value)) return sign ? NumericUtils::data_max_negative_normal_mask : NumericUtils::data_max_positive_normal_mask; - if(std::abs(value) > NumericLimits::Max()) // covers inf case as well + if(std::abs(value) > NumericLimits::DataMaxNorm()) // covers inf case as well return sign ? NumericUtils::data_max_negative_normal_mask : NumericUtils::data_max_positive_normal_mask; diff --git a/library/include/ck/library/reference_tensor_operation/cpu/reference_gemm.hpp b/library/include/ck/library/reference_tensor_operation/cpu/reference_gemm.hpp index c8d284a1d7..ed07e53e6d 100644 --- a/library/include/ck/library/reference_tensor_operation/cpu/reference_gemm.hpp +++ b/library/include/ck/library/reference_tensor_operation/cpu/reference_gemm.hpp @@ -89,6 +89,14 @@ struct ReferenceGemm : public device::BaseOperator v_a = type_convert( f4_t(arg.a_m_k_(m, k).template unpack<>(Number<0>{}))); } + else if constexpr(is_same_v || + is_same_v || + is_same_v || + is_same_v) + { + v_a = type_convert( + arg.a_m_k_(m, k).unpack(k % ADataType::packed_size)); + } else { arg.a_element_op_(v_a, arg.a_m_k_(m, k)); @@ -115,6 +123,14 @@ struct ReferenceGemm : public device::BaseOperator v_b = type_convert( f4_t(arg.b_k_n_(k, n).template unpack<>(Number<0>{}))); } + else if constexpr(is_same_v || + is_same_v || + is_same_v || + is_same_v) + { + v_b = type_convert( + arg.b_k_n_(k, n).unpack(k % BDataType::packed_size)); + } else { arg.b_element_op_(v_b, arg.b_k_n_(k, n)); diff --git 
a/library/include/ck/library/reference_tensor_operation/cpu/reference_mx_gemm.hpp b/library/include/ck/library/reference_tensor_operation/cpu/reference_mx_gemm.hpp index e8fdcf1acd..3fc39911dd 100644 --- a/library/include/ck/library/reference_tensor_operation/cpu/reference_mx_gemm.hpp +++ b/library/include/ck/library/reference_tensor_operation/cpu/reference_mx_gemm.hpp @@ -105,6 +105,16 @@ struct ReferenceMXGemm : public device::BaseOperator type_convert( arg.a_m_kblock_scales_(m, k / SCALE_BLOCK)); } + else if constexpr(is_same_v || + is_same_v || + is_same_v || + is_same_v) + { + a_m_k_scaled(m, k) = + type_convert( + arg.a_m_k_(m, k).unpack(k % ADataType::packed_size)) * + type_convert(arg.a_m_kblock_scales_(m, k / SCALE_BLOCK)); + } else { a_m_k_scaled(m, k) = @@ -134,6 +144,16 @@ struct ReferenceMXGemm : public device::BaseOperator type_convert( arg.b_kblock_n_scales_(k / SCALE_BLOCK, n)); } + else if constexpr(is_same_v || + is_same_v || + is_same_v || + is_same_v) + { + b_k_n_scaled(k, n) = + type_convert( + arg.b_k_n_(k, n).unpack(k % BDataType::packed_size)) * + type_convert(arg.b_kblock_n_scales_(k / SCALE_BLOCK, n)); + } else { b_k_n_scaled(k, n) = diff --git a/test/data_type/test_bf6.cpp b/test/data_type/test_bf6.cpp index a260f81d16..9dbb77454c 100644 --- a/test/data_type/test_bf6.cpp +++ b/test/data_type/test_bf6.cpp @@ -4,6 +4,7 @@ #include "gtest/gtest.h" #include "ck/utility/data_type.hpp" #include "ck/utility/type_convert.hpp" +#include "ck/utility/env.hpp" #include "ck/utility/scaled_type_convert.hpp" using ck::bf6_convert_rne; @@ -41,6 +42,11 @@ TEST(BF6, ConvertFP32Nearest) ASSERT_NEAR(max_bf6, type_convert(bf6_convert_rne(std::numeric_limits::infinity())), 0.0f); + + // convert float +/-30 to bf6 and back, check if clipped to +/-max_bf6 + ASSERT_NEAR(-max_bf6, type_convert(bf6_convert_rne(-30.0f)), 0.0f); + ASSERT_NEAR(max_bf6, type_convert(bf6_convert_rne(30.0f)), 0.0f); + // convert float value less than bf6 subnorm to bf6 and back, check if 
equal to 0.0 float less_than_subnorm = 0.03125f; ASSERT_NEAR(0.0f, type_convert(bf6_convert_rne(less_than_subnorm)), 0.0f); @@ -266,21 +272,18 @@ TEST(BF6, TestAsType16x1) vector_type right_vec; // check default CTOR ck::static_for<0, packed_size, 1>{}([&](auto i) { - ASSERT_EQ( - right_vec.template AsType()(Number<0>{}).template unpack<>(Number{}), - 0); + ASSERT_EQ(right_vec.template AsType()(Number<0>{}).unpack(i), 0); }); // assign test values to the vector ck::static_for<0, vector_size, 1>{}([&](auto i) { - right_vec.template AsType()(Number{}) = bf6x16_pk_t{}.pack(test_vec); + right_vec.template AsType()(Number{}) = bf6x16_pk_t{test_vec}; }); // copy the vector vector_type left_vec{right_vec}; // check if values were copied correctly ck::static_for<0, packed_size, 1>{}([&](auto i) { - ASSERT_EQ( - left_vec.template AsType()(Number<0>{}).template unpack<>(Number{}), - static_cast(test_vec[static_cast(i)])); + ASSERT_EQ(left_vec.template AsType()(Number<0>{}).unpack(i), + static_cast(test_vec[static_cast(i)])); }); } @@ -329,23 +332,23 @@ TEST(BF6, TestAsType16x2) // check default CTOR ck::static_for<0, vector_size, 1>{}([&](auto idx_vector) { ck::static_for<0, packed_size, 1>{}([&](auto idx_element) { - ASSERT_EQ(right_vec.template AsType()(Number{}) - .template unpack<>(Number{}), - 0); + ASSERT_EQ( + right_vec.template AsType()(Number{}).unpack(idx_element), + 0); }); }); // assign test values to the vector ck::static_for<0, vector_size, 1>{}([&](auto i) { - right_vec.template AsType()(Number{}) = bf6x16_pk_t{}.pack(test_vec[i]); + right_vec.template AsType()(Number{}) = bf6x16_pk_t{test_vec[i]}; }); // copy the vector vector_type left_vec{right_vec}; // check if values were copied correctly ck::static_for<0, vector_size, 1>{}([&](auto idx_vector) { ck::static_for<0, packed_size, 1>{}([&](auto idx_element) { - ASSERT_EQ(left_vec.template AsType()(Number{}) - .template unpack<>(Number{}), - static_cast(test_vec[idx_vector][static_cast(idx_element)])); + 
ASSERT_EQ( + left_vec.template AsType()(Number{}).unpack(idx_element), + static_cast(test_vec[idx_vector][static_cast(idx_element)])); }); }); } @@ -369,20 +372,86 @@ TEST(BF6, TestAsType32x1) vector_type right_vec; // check default CTOR ck::static_for<0, packed_size, 1>{}([&](auto i) { - ASSERT_EQ( - right_vec.template AsType()(Number<0>{}).template unpack<>(Number{}), - 0); + ASSERT_EQ(right_vec.template AsType()(Number<0>{}).unpack(i), 0); }); // assign test values to the vector ck::static_for<0, vector_size, 1>{}([&](auto i) { - right_vec.template AsType()(Number{}) = bf6x32_pk_t{}.pack(test_vec); + right_vec.template AsType()(Number{}) = bf6x32_pk_t{test_vec}; }); // copy the vector vector_type left_vec{right_vec}; // check if values were copied correctly ck::static_for<0, packed_size, 1>{}([&](auto i) { - ASSERT_EQ( - left_vec.template AsType()(Number<0>{}).template unpack<>(Number{}), - static_cast(test_vec[static_cast(i)])); + ASSERT_EQ(left_vec.template AsType()(Number<0>{}).unpack(i), + static_cast(test_vec[static_cast(i)])); + }); +} + +TEST(BF6, TestAllValues) +{ + + constexpr std::array e3m2ValuesOCP = { + // clang-format off + 0.0000000000, 0.0625000000, 0.1250000000, 0.1875000000, + 0.2500000000, 0.3125000000, 0.3750000000, 0.4375000000, + 0.5000000000, 0.6250000000, 0.7500000000, 0.8750000000, + 1.0000000000, 1.2500000000, 1.5000000000, 1.7500000000, + 2.0000000000, 2.5000000000, 3.0000000000, 3.5000000000, + 4.0000000000, 5.0000000000, 6.0000000000, 7.0000000000, + 8.0000000000, 10.0000000000, 12.0000000000, 14.0000000000, + 16.0000000000, 20.0000000000, 24.0000000000, 28.0000000000, + -0.0000000000, -0.0625000000, -0.1250000000, -0.1875000000, + -0.2500000000, -0.3125000000, -0.3750000000, -0.4375000000, + -0.5000000000, -0.6250000000, -0.7500000000, -0.8750000000, + -1.0000000000, -1.2500000000, -1.5000000000, -1.7500000000, + -2.0000000000, -2.5000000000, -3.0000000000, -3.5000000000, + -4.0000000000, -5.0000000000, -6.0000000000, -7.0000000000, 
+ -8.0000000000, -10.0000000000, -12.0000000000, -14.0000000000, + -16.0000000000, -20.0000000000, -24.0000000000, -28.0000000000 + // clang-format on + }; + + constexpr uint8_t e3m2BitsOCP[] = { + // clang-format off + 0b000000, 0b000001, 0b000010, 0b000011, + 0b000100, 0b000101, 0b000110, 0b000111, + 0b001000, 0b001001, 0b001010, 0b001011, + 0b001100, 0b001101, 0b001110, 0b001111, + 0b010000, 0b010001, 0b010010, 0b010011, + 0b010100, 0b010101, 0b010110, 0b010111, + 0b011000, 0b011001, 0b011010, 0b011011, + 0b011100, 0b011101, 0b011110, 0b011111, + 0b100000, 0b100001, 0b100010, 0b100011, + 0b100100, 0b100101, 0b100110, 0b100111, + 0b101000, 0b101001, 0b101010, 0b101011, + 0b101100, 0b101101, 0b101110, 0b101111, + 0b110000, 0b110001, 0b110010, 0b110011, + 0b110100, 0b110101, 0b110110, 0b110111, + 0b111000, 0b111001, 0b111010, 0b111011, + 0b111100, 0b111101, 0b111110, 0b111111 + // clang-format on + }; + + const bool ck_logging = ck::EnvIsEnabled(CK_ENV(CK_LOGGING)); + + if(ck_logging) + printf("BF6 Table\n"); + ck::static_for<0, 64, 1>{}([&](auto i) { + float fp = type_convert(bf6_t(e3m2BitsOCP[i])); + ASSERT_EQ(fp, e3m2ValuesOCP[i]); + + bf6_t bf6 = type_convert(e3m2ValuesOCP[i]); + ASSERT_EQ(bf6 & 0x3F, e3m2BitsOCP[i] & 0x3F); + + if(ck_logging) + { + // Print the binary representation + printf("Bits: 0b"); + for(int j = 5; j >= 0; --j) + { + printf("%c", (e3m2BitsOCP[i] & (1 << j)) ? 
'1' : '0'); + } + printf(", 0x%02X, Value: %f\n", e3m2BitsOCP[i], e3m2ValuesOCP[i]); + } }); } diff --git a/test/data_type/test_fp4.cpp b/test/data_type/test_fp4.cpp index f4b2bf3358..3fc74a2ef3 100644 --- a/test/data_type/test_fp4.cpp +++ b/test/data_type/test_fp4.cpp @@ -5,6 +5,7 @@ #include "ck/utility/data_type.hpp" #include "ck/utility/type_convert.hpp" #include "ck/utility/scaled_type_convert.hpp" +#include "ck/utility/env.hpp" using ck::e8m0_bexp_t; using ck::f4_convert_rne; @@ -38,6 +39,11 @@ TEST(FP4, ConvertFP32Nearest) // convert maximal float to fp4 and back, check if clipped to 6.0 ASSERT_NEAR( max_fp4, type_convert(f4_convert_rne(std::numeric_limits::max())), abs_tol); + + // convert +/-7.0 to fp4 and back, check if clipped to +/-6.0 + ASSERT_NEAR(-max_fp4, type_convert(f4_convert_rne(-7.0f)), 0.0); + ASSERT_NEAR(max_fp4, type_convert(f4_convert_rne(7.0f)), 0.0); + // positive norm float value to fp4 and back, check if holds float pos_float = 1.0f; ASSERT_NEAR(pos_float, type_convert(f4_convert_rne(pos_float)), abs_tol); @@ -468,3 +474,54 @@ TEST(FP4, TestAsType32) test_vec.at(i + 1)); }); } + +TEST(FP4, TestAllValues) +{ + constexpr std::array e2m1ValuesOCP = { + // clang-format off + 0.0000000000, 0.5000000000, + 1.0000000000, 1.5000000000, + 2.0000000000, 3.0000000000, + 4.0000000000, 6.0000000000, + -0.0000000000, -0.5000000000, + -1.0000000000, -1.5000000000, + -2.0000000000, -3.0000000000, + -4.0000000000, -6.0000000000 + // clang-format on + }; + + constexpr uint8_t e2m1BitsOCP[] = { + // clang-format off + 0b0000, 0b0001, + 0b0010, 0b0011, + 0b0100, 0b0101, + 0b0110, 0b0111, + 0b1000, 0b1001, + 0b1010, 0b1011, + 0b1100, 0b1101, + 0b1110, 0b1111 + // clang-format on + }; + + const bool ck_logging = ck::EnvIsEnabled(CK_ENV(CK_LOGGING)); + + if(ck_logging) + printf("FP4 Table\n"); + ck::static_for<0, 16, 1>{}([&](auto i) { + float fp = type_convert(f4_t(e2m1BitsOCP[i])); + ASSERT_EQ(fp, e2m1ValuesOCP[i]); + + f4_t fp4 = 
type_convert(e2m1ValuesOCP[i]); + ASSERT_EQ(fp4 & 0xF, e2m1BitsOCP[i] & 0xF); + if(ck_logging) + { + // Print the binary representation + printf("Bits: 0b"); + for(int j = 3; j >= 0; --j) + { + printf("%c", (e2m1BitsOCP[i] & (1 << j)) ? '1' : '0'); + } + printf(", 0x%02X, Value: %f\n", e2m1BitsOCP[i], e2m1ValuesOCP[i]); + } + }); +} diff --git a/test/data_type/test_fp6.cpp b/test/data_type/test_fp6.cpp index cf91e69db3..6d4aec1d9a 100644 --- a/test/data_type/test_fp6.cpp +++ b/test/data_type/test_fp6.cpp @@ -4,6 +4,7 @@ #include "gtest/gtest.h" #include "ck/utility/data_type.hpp" #include "ck/utility/type_convert.hpp" +#include "ck/utility/env.hpp" #include "ck/utility/scaled_type_convert.hpp" using ck::e8m0_bexp_t; @@ -34,6 +35,11 @@ TEST(FP6, ConvertFP32Nearest) ASSERT_NEAR(0.0f, type_convert(f6_convert_rne(0.0f)), 0.0f); // convert maximal f6_t to float and check if equal to max_fp6 ASSERT_NEAR(max_fp6, type_convert(f6_convert_rne(max_fp6)), 0.0f); + + // convert maximal +/-8.0 to fp6 and check if equal to +/-max_fp6 + ASSERT_NEAR(-max_fp6, type_convert(f6_convert_rne(-8.0f)), 0.0f); + ASSERT_NEAR(max_fp6, type_convert(f6_convert_rne(8.0f)), 0.0f); + // convert maximal float to fp6 and back, check if clipped to max_fp6 ASSERT_NEAR( max_fp6, type_convert(f6_convert_rne(std::numeric_limits::max())), 0.0f); @@ -265,20 +271,24 @@ TEST(FP6, TestAsType16x1) vector_type right_vec; // check default CTOR ck::static_for<0, packed_size, 1>{}([&](auto i) { - ASSERT_EQ( - right_vec.template AsType()(Number<0>{}).template unpack<>(Number{}), 0); + ASSERT_EQ(right_vec.template AsType()(Number<0>{}).unpack(i), 0); }); // assign test values to the vector ck::static_for<0, vector_size, 1>{}([&](auto i) { - right_vec.template AsType()(Number{}) = f6x16_pk_t{}.pack(test_vec); + right_vec.template AsType()(Number{}) = f6x16_pk_t{test_vec}; }); + // copy the vector vector_type left_vec{right_vec}; // check if values were copied correctly ck::static_for<0, packed_size, 1>{}([&](auto 
i) { - ASSERT_EQ( - left_vec.template AsType()(Number<0>{}).template unpack<>(Number{}), - static_cast(test_vec[static_cast(i)])); + ASSERT_EQ(left_vec.template AsType()(Number<0>{}).unpack(i), + static_cast(test_vec[static_cast(i)])) + << " i = " << i << "; left = " + << type_convert(left_vec.template AsType()(Number<0>{}).unpack(i)) + << " -- right = " + << type_convert(static_cast(test_vec[static_cast(i)])) << " (" + << static_cast(test_vec[static_cast(i)]) << ")" << std::endl; }); } @@ -327,23 +337,23 @@ TEST(FP6, TestAsType16x2) // check default CTOR ck::static_for<0, vector_size, 1>{}([&](auto idx_vector) { ck::static_for<0, packed_size, 1>{}([&](auto idx_element) { - ASSERT_EQ(right_vec.template AsType()(Number{}) - .template unpack<>(Number{}), - 0); + ASSERT_EQ( + right_vec.template AsType()(Number{}).unpack(idx_element), + 0); }); }); // assign test values to the vector ck::static_for<0, vector_size, 1>{}([&](auto i) { - right_vec.template AsType()(Number{}) = f6x16_pk_t{}.pack(test_vec[i]); + right_vec.template AsType()(Number{}) = f6x16_pk_t{test_vec[i]}; }); // copy the vector vector_type left_vec{right_vec}; // check if values were copied correctly ck::static_for<0, vector_size, 1>{}([&](auto idx_vector) { ck::static_for<0, packed_size, 1>{}([&](auto idx_element) { - ASSERT_EQ(left_vec.template AsType()(Number{}) - .template unpack<>(Number{}), - static_cast(test_vec[idx_vector][static_cast(idx_element)])); + ASSERT_EQ( + left_vec.template AsType()(Number{}).unpack(idx_element), + static_cast(test_vec[idx_vector][static_cast(idx_element)])); }); }); } @@ -367,19 +377,77 @@ TEST(FP6, TestAsType32x1) vector_type right_vec; // check default CTOR ck::static_for<0, packed_size, 1>{}([&](auto i) { - ASSERT_EQ( - right_vec.template AsType()(Number<0>{}).template unpack<>(Number{}), 0); + ASSERT_EQ(right_vec.template AsType()(Number<0>{}).unpack(i), 0); }); // assign test values to the vector ck::static_for<0, vector_size, 1>{}([&](auto i) { - 
right_vec.template AsType()(Number{}) = f6x32_pk_t{}.pack(test_vec); + right_vec.template AsType()(Number{}) = f6x32_pk_t{test_vec}; }); // copy the vector vector_type left_vec{right_vec}; // check if values were copied correctly ck::static_for<0, packed_size, 1>{}([&](auto i) { - ASSERT_EQ( - left_vec.template AsType()(Number<0>{}).template unpack<>(Number{}), - static_cast(test_vec[static_cast(i)])); + ASSERT_EQ(left_vec.template AsType()(Number<0>{}).unpack(i), + static_cast(test_vec[static_cast(i)])); + }); +} + +TEST(FP6, TestAllValues) +{ + constexpr std::array e2m3ValuesOCP = { + // clang-format off + 0.0000000000, 0.1250000000, 0.2500000000, 0.3750000000, 0.5000000000, 0.6250000000, 0.7500000000, 0.8750000000, + 1.0000000000, 1.1250000000, 1.2500000000, 1.3750000000, 1.5000000000, 1.6250000000, 1.7500000000, 1.8750000000, + 2.0000000000, 2.2500000000, 2.5000000000, 2.7500000000, 3.0000000000, 3.2500000000, 3.5000000000, 3.7500000000, + 4.0000000000, 4.5000000000, 5.0000000000, 5.5000000000, 6.0000000000, 6.5000000000, 7.0000000000, 7.5000000000, + -0.0000000000, -0.1250000000, -0.2500000000, -0.3750000000, -0.5000000000, -0.6250000000, -0.7500000000, -0.8750000000, + -1.0000000000, -1.1250000000, -1.2500000000, -1.3750000000, -1.5000000000, -1.6250000000, -1.7500000000, -1.8750000000, + -2.0000000000, -2.2500000000, -2.5000000000, -2.7500000000, -3.0000000000, -3.2500000000, -3.5000000000, -3.7500000000, + -4.0000000000, -4.5000000000, -5.0000000000, -5.5000000000, -6.0000000000, -6.5000000000, -7.0000000000, -7.5000000000 + // clang-format on + }; + + constexpr uint8_t e2m3BitsOCP[] = { + // clang-format off + 0b000000, 0b000001, 0b000010, 0b000011, + 0b000100, 0b000101, 0b000110, 0b000111, + 0b001000, 0b001001, 0b001010, 0b001011, + 0b001100, 0b001101, 0b001110, 0b001111, + 0b010000, 0b010001, 0b010010, 0b010011, + 0b010100, 0b010101, 0b010110, 0b010111, + 0b011000, 0b011001, 0b011010, 0b011011, + 0b011100, 0b011101, 0b011110, 0b011111, + 0b100000, 
0b100001, 0b100010, 0b100011, + 0b100100, 0b100101, 0b100110, 0b100111, + 0b101000, 0b101001, 0b101010, 0b101011, + 0b101100, 0b101101, 0b101110, 0b101111, + 0b110000, 0b110001, 0b110010, 0b110011, + 0b110100, 0b110101, 0b110110, 0b110111, + 0b111000, 0b111001, 0b111010, 0b111011, + 0b111100, 0b111101, 0b111110, 0b111111 + // clang-format on + }; + + const bool ck_logging = ck::EnvIsEnabled(CK_ENV(CK_LOGGING)); + + if(ck_logging) + printf("FP6 Table\n"); + ck::static_for<0, 64, 1>{}([&](auto i) { + float fp = type_convert(f6_t(e2m3BitsOCP[i])); + ASSERT_EQ(fp, e2m3ValuesOCP[i]); + + f6_t fp6 = type_convert(e2m3ValuesOCP[i]); + ASSERT_EQ(fp6 & 0x3F, e2m3BitsOCP[i] & 0x3F); + + if(ck_logging) + { + // Print the binary representation + printf("Bits: 0b"); + for(int j = 5; j >= 0; --j) + { + printf("%c", (e2m3BitsOCP[i] & (1 << j)) ? '1' : '0'); + } + printf(", 0x%02X, Value: %f\n", e2m3BitsOCP[i], e2m3ValuesOCP[i]); + } }); } diff --git a/test/mx_mfma_op/mx_mfma_op.cpp b/test/mx_mfma_op/mx_mfma_op.cpp index fddb8288a6..5e2aedd35e 100644 --- a/test/mx_mfma_op/mx_mfma_op.cpp +++ b/test/mx_mfma_op/mx_mfma_op.cpp @@ -5,9 +5,12 @@ #include "mx_mfma_op.hpp" +using ck::bf6_t; +using ck::bf8_t; using ck::e8m0_bexp_t; using ck::f4_t; using ck::f4x2_pk_t; +using ck::f6_t; using ck::f8_t; using ck::half_t; using ck::type_convert; @@ -17,13 +20,15 @@ using ck::type_convert; * * @param init - selects initialization algorithm for A and B tensors */ -template -bool run_mfma_km_kn_nm_test(ck::index_t init) +template +bool run_mfma_test(ck::index_t init) { - using ALayout = ck::tensor_layout::gemm::ColumnMajor; - using BLayout = ck::tensor_layout::gemm::ColumnMajor; - using CLayout = ck::tensor_layout::gemm::ColumnMajor; - using AccType = float; // only MFMA_F32 instructions supported using CPUAccType = AccType; @@ -53,74 +58,153 @@ bool run_mfma_km_kn_nm_test(ck::index_t init) return pass; } +const ck::index_t common_init = -4; // set to "< 0" for test-specific initializations + 
TEST(MFMA, FP8MFMA16x16x128) { - auto AB_init = 5; - auto pass = run_mfma_km_kn_nm_test(AB_init); + using ALayout = ck::tensor_layout::gemm::ColumnMajor; + using BLayout = ck::tensor_layout::gemm::ColumnMajor; + using CLayout = ck::tensor_layout::gemm::ColumnMajor; + auto AB_init = (common_init < 0) ? 5 : common_init; + auto pass = run_mfma_test(AB_init); EXPECT_TRUE(pass); } -TEST(MFMA, FP8MFMA32x32x64) +TEST(MFMA, BF8MFMA16x16x128) { - auto AB_init = 5; - auto pass = run_mfma_km_kn_nm_test(AB_init); + using ALayout = ck::tensor_layout::gemm::ColumnMajor; + using BLayout = ck::tensor_layout::gemm::ColumnMajor; + using CLayout = ck::tensor_layout::gemm::ColumnMajor; + auto AB_init = (common_init < 0) ? 5 : common_init; + auto pass = run_mfma_test(AB_init); EXPECT_TRUE(pass); } -/** - * @brief Run the test for the given MFMA instruction - * - * @param init - selects initialization algorithm for A and B tensors - */ -template -bool run_mfma_mk_kn_mn_test(ck::index_t init) +TEST(MFMA, FP4MFMA16x16x128) { using ALayout = ck::tensor_layout::gemm::RowMajor; using BLayout = ck::tensor_layout::gemm::ColumnMajor; using CLayout = ck::tensor_layout::gemm::RowMajor; - using AccType = float; // only MFMA_F32 instructions supported - using CPUAccType = AccType; - - ck::mfma_type(mfma)> mfma_instr; - constexpr auto BLOCK_M = mfma_instr.m_per_blk; - constexpr auto BLOCK_N = mfma_instr.n_per_blk; - constexpr auto BLOCK_K = mfma_instr.num_input_blks * mfma_instr.k_per_blk; - - const auto mfma_kernel = ck:: - matmul; - - bool pass = true; - - pass = ck::mfma_test::TestMFMA{}(mfma_kernel, init); - - return pass; + auto AB_init = (common_init < 0) ? 
5 : common_init; + auto pass = + run_mfma_test( + AB_init); + EXPECT_TRUE(pass); } -TEST(MFMA, FP4MFMA16x16x128) +TEST(MFMA, FP6MFMA16x16x128) { - auto AB_init = 4; - auto pass = run_mfma_mk_kn_mn_test( - AB_init); + using ALayout = ck::tensor_layout::gemm::RowMajor; + using BLayout = ck::tensor_layout::gemm::ColumnMajor; + using CLayout = ck::tensor_layout::gemm::RowMajor; + + auto AB_init = (common_init < 0) ? 5 : common_init; + auto pass = + run_mfma_test( + AB_init); + EXPECT_TRUE(pass); +} + +TEST(MFMA, BF6MFMA16x16x128) +{ + using ALayout = ck::tensor_layout::gemm::RowMajor; + using BLayout = ck::tensor_layout::gemm::ColumnMajor; + using CLayout = ck::tensor_layout::gemm::RowMajor; + + auto AB_init = (common_init < 0) ? 5 : common_init; + auto pass = run_mfma_test(AB_init); + EXPECT_TRUE(pass); +} + +TEST(MFMA, FP8MFMA32x32x64) +{ + using ALayout = ck::tensor_layout::gemm::ColumnMajor; + using BLayout = ck::tensor_layout::gemm::ColumnMajor; + using CLayout = ck::tensor_layout::gemm::ColumnMajor; + + auto AB_init = (common_init < 0) ? 5 : common_init; + auto pass = + run_mfma_test( + AB_init); + EXPECT_TRUE(pass); +} + +TEST(MFMA, BF8MFMA32x32x64) +{ + using ALayout = ck::tensor_layout::gemm::ColumnMajor; + using BLayout = ck::tensor_layout::gemm::ColumnMajor; + using CLayout = ck::tensor_layout::gemm::ColumnMajor; + + auto AB_init = (common_init < 0) ? 5 : common_init; + auto pass = run_mfma_test(AB_init); EXPECT_TRUE(pass); } TEST(MFMA, FP4MFMA32x32x64) { - auto AB_init = 4; - auto pass = run_mfma_mk_kn_mn_test( - AB_init); + using ALayout = ck::tensor_layout::gemm::RowMajor; + using BLayout = ck::tensor_layout::gemm::ColumnMajor; + using CLayout = ck::tensor_layout::gemm::RowMajor; + + auto AB_init = (common_init < 0) ? 
5 : common_init; + auto pass = + run_mfma_test( + AB_init); + EXPECT_TRUE(pass); +} + +TEST(MFMA, FP6MFMA32x32x64) +{ + using ALayout = ck::tensor_layout::gemm::RowMajor; + using BLayout = ck::tensor_layout::gemm::ColumnMajor; + using CLayout = ck::tensor_layout::gemm::RowMajor; + + auto AB_init = (common_init < 0) ? 5 : common_init; + auto pass = + run_mfma_test( + AB_init); + EXPECT_TRUE(pass); +} + +TEST(MFMA, BF6MFMA32x32x64) +{ + using ALayout = ck::tensor_layout::gemm::RowMajor; + using BLayout = ck::tensor_layout::gemm::ColumnMajor; + using CLayout = ck::tensor_layout::gemm::RowMajor; + + auto AB_init = (common_init < 0) ? 5 : common_init; + auto pass = run_mfma_test(AB_init); EXPECT_TRUE(pass); } @@ -129,15 +213,18 @@ TEST(MFMA, FP4MFMA32x32x64) * * @param init - selects initialization algorithm for A and B tensors */ -template -bool run_mxmfma_mk_kn_mn_test(ck::index_t init) +template +bool run_mxmfma_test(ck::index_t init) { static_assert(mfma == ck::MFMA_F8F6F4::SCALE_F32_16x16x128 || mfma == ck::MFMA_F8F6F4::SCALE_F32_32x32x64, "Only SCALE_F32_16x16x128 and SCALE_F32_32x32x64 are supported"); - using ALayout = ck::tensor_layout::gemm::RowMajor; - using BLayout = ck::tensor_layout::gemm::ColumnMajor; - using CLayout = ck::tensor_layout::gemm::RowMajor; using AccType = float; // only MFMA_F32 instructions supported using ScaleType = ck::e8m0_bexp_t; // biased exponent type @@ -181,34 +268,170 @@ bool run_mxmfma_mk_kn_mn_test(ck::index_t init) TEST(MXMFMA, MXFP8MFMA16x16x128) { - auto AB_init = 5; - auto pass = - run_mxmfma_mk_kn_mn_test(AB_init); + using ALayout = ck::tensor_layout::gemm::RowMajor; + using BLayout = ck::tensor_layout::gemm::ColumnMajor; + using CLayout = ck::tensor_layout::gemm::RowMajor; + + auto AB_init = (common_init < 0) ? 
5 : common_init; + auto pass = run_mxmfma_test(AB_init); EXPECT_TRUE(pass); } TEST(MXMFMA, MXFP8MFMA32x32x64) { - auto AB_init = 5; - auto pass = - run_mxmfma_mk_kn_mn_test(AB_init); + using ALayout = ck::tensor_layout::gemm::RowMajor; + using BLayout = ck::tensor_layout::gemm::ColumnMajor; + using CLayout = ck::tensor_layout::gemm::RowMajor; + + auto AB_init = (common_init < 0) ? 5 : common_init; + auto pass = run_mxmfma_test(AB_init); + EXPECT_TRUE(pass); +} + +TEST(MXMFMA, MXBF8MFMA16x16x128) +{ + using ALayout = ck::tensor_layout::gemm::RowMajor; + using BLayout = ck::tensor_layout::gemm::ColumnMajor; + using CLayout = ck::tensor_layout::gemm::RowMajor; + + auto AB_init = (common_init < 0) ? 5 : common_init; + auto pass = run_mxmfma_test(AB_init); + EXPECT_TRUE(pass); +} + +TEST(MXMFMA, MXBF8MFMA32x32x64) +{ + using ALayout = ck::tensor_layout::gemm::RowMajor; + using BLayout = ck::tensor_layout::gemm::ColumnMajor; + using CLayout = ck::tensor_layout::gemm::RowMajor; + + auto AB_init = (common_init < 0) ? 5 : common_init; + auto pass = run_mxmfma_test(AB_init); + EXPECT_TRUE(pass); +} + +TEST(MXMFMA, MXFP6MFMA16x16x128) +{ + using ALayout = ck::tensor_layout::gemm::RowMajor; + using BLayout = ck::tensor_layout::gemm::ColumnMajor; + using CLayout = ck::tensor_layout::gemm::RowMajor; + + auto AB_init = (common_init < 0) ? 5 : common_init; + auto pass = run_mxmfma_test(AB_init); + EXPECT_TRUE(pass); +} + +TEST(MXMFMA, MXFP6MFMA32x32x64) +{ + using ALayout = ck::tensor_layout::gemm::RowMajor; + using BLayout = ck::tensor_layout::gemm::ColumnMajor; + using CLayout = ck::tensor_layout::gemm::RowMajor; + + auto AB_init = (common_init < 0) ? 5 : common_init; + auto pass = run_mxmfma_test(AB_init); + EXPECT_TRUE(pass); +} + +TEST(MXMFMA, MXBF6MFMA16x16x128) +{ + using ALayout = ck::tensor_layout::gemm::RowMajor; + using BLayout = ck::tensor_layout::gemm::ColumnMajor; + using CLayout = ck::tensor_layout::gemm::RowMajor; + + auto AB_init = (common_init < 0) ? 
5 : common_init; + auto pass = run_mxmfma_test(AB_init); + EXPECT_TRUE(pass); +} + +TEST(MXMFMA, MXBF6MFMA32x32x64) +{ + using ALayout = ck::tensor_layout::gemm::RowMajor; + using BLayout = ck::tensor_layout::gemm::ColumnMajor; + using CLayout = ck::tensor_layout::gemm::RowMajor; + + auto AB_init = (common_init < 0) ? 5 : common_init; + auto pass = run_mxmfma_test(AB_init); EXPECT_TRUE(pass); } TEST(MXMFMA, MXFP4MFMA16x16x128) { - auto AB_init = 4; - auto pass = - run_mxmfma_mk_kn_mn_test( - AB_init); + using ALayout = ck::tensor_layout::gemm::RowMajor; + using BLayout = ck::tensor_layout::gemm::ColumnMajor; + using CLayout = ck::tensor_layout::gemm::RowMajor; + + auto AB_init = (common_init < 0) ? 5 : common_init; + auto pass = run_mxmfma_test(AB_init); EXPECT_TRUE(pass); } TEST(MXMFMA, MXFP4MFMA32x32x64) { - auto AB_init = 4; - auto pass = - run_mxmfma_mk_kn_mn_test( - AB_init); + using ALayout = ck::tensor_layout::gemm::RowMajor; + using BLayout = ck::tensor_layout::gemm::ColumnMajor; + using CLayout = ck::tensor_layout::gemm::RowMajor; + + auto AB_init = (common_init < 0) ? 5 : common_init; + auto pass = run_mxmfma_test(AB_init); EXPECT_TRUE(pass); } diff --git a/test/mx_mfma_op/mx_mfma_op.hpp b/test/mx_mfma_op/mx_mfma_op.hpp index 9ce871cfb1..4cab411cb4 100644 --- a/test/mx_mfma_op/mx_mfma_op.hpp +++ b/test/mx_mfma_op/mx_mfma_op.hpp @@ -151,6 +151,8 @@ __device__ AFragT load_A_col_major(AType const* input_ptr) // Reg 7 [24:31] | K79 | K95 | K111 | K127 | v[31] || Reg 7 [24:31] | K47 | K63 | v[31] | // clang-format on + static_assert(!is_packed_type_v, "Packed type is not supported"); + static constexpr int32_t WAVE_SIZE = 64; // Here we want to load from rows of A in chunks of 16 elements each. 
@@ -270,12 +272,28 @@ __device__ AFragT load_A_row_major(AType const* input_ptr) // Reg 3 [8:15] | K26K27 | K58K59 | K90K91 | K122K123 | v[13] || Reg 3 [8:15] | K26K27 | K58K59 | v[13] | // Reg 3 [16:23] | K28K29 | K60K61 | K92K93 | K124K125 | v[14] || Reg 3 [16:23] | K28K29 | K60K61 | v[14] | // Reg 3 [24:31] | K30K31 | K62K63 | K94K95 | K126K127 | v[15] || Reg 3 [24:31] | K30K31 | K62K63 | v[15] | + + + // Register Mapping for 16x128 for FP6: || Register Mapping for 32x64 for FP6: + // Size | BLOCK_M | BLOCK_M | BLOCK_M | BLOCK_M | || Size | BLOCK_M | BLOCK_M | | + // M | 0 ... 15 | 0 ... 15 | 0 ... 15 | 0 ... 15 | Vector || M | 0 ... 31 | 0 ... 31 | Vector | + // Thread Id | 0 ... 15 | 16 ... 31 | 32 ... 47 | 48 ... 63 | Element || Thread Id | 0 ... 31 | 32 ... 63 | Element| + // Register Element |------------|-------------|------------|-------------|-----------|| Register Element |------------|-------------|--------| + // Reg 0-2 [0:95] | K = 0-15 | K = 32-47 | K = 64-79 | K = 96-111 | v[0] || Reg 0-2 [0:95] | K = 0-15 | K = 32-47 | v[0] | + // Reg 3-5 [0:95] | K = 16-31 | K = 48-63 | K = 80-95 | K = 112-127 | v[0] || Reg 3-5 [0:95] | K = 16-31 | K = 48-63 | v[0] | + // clang-format on static constexpr int32_t WAVE_SIZE = 64; + // FP8 chunk_size = 16, num_chunks = 2, packed_size = 1 + // FP4 chunk_size = 32, num_chunks = 1, packed_size = 2 + // FP6 chunk_size = 32, num_chunks = 1, packed_size = 32 + + constexpr index_t num_chunks = is_packed_type_v ? 1 : 2; + // Here we want to load from rows of A in chunks of 16 elements each. - static constexpr uint32_t chunk_size = 16; + constexpr uint32_t chunk_size = is_packed_type_v ? 32 : 16; // each chunk is separated by offset static constexpr uint32_t chunk_offset = chunk_size * WAVE_SIZE / BLOCK_M; @@ -283,43 +301,35 @@ __device__ AFragT load_A_row_major(AType const* input_ptr) // To start the loading process, let's visualize in 2D coords. // Each thread will load 32 elements. 
// We need to know where they start, and where the next elements are. - auto startCoord2D = - std::make_pair(threadIdx.x % BLOCK_M, // Row {0-31} | {0-15} - (threadIdx.x / BLOCK_M) * chunk_size); // Col {0, 16} | {0, 16, 32, 48} + // FP8/6/4 Row {0-31} | {0-15} + // FP8 Col {0, 16} | {0, 16, 32, 48} + // FP6/4 Col {0, 32} | {0, 32, 64, 96} + auto startCoord2D = std::make_pair(threadIdx.x % BLOCK_M, (threadIdx.x / BLOCK_M) * chunk_size); - // auto minorStepCoord2D = std::make_pair(0u, 1u); // read rows auto majorStepCoord2D = std::make_pair(0, chunk_offset); // read a chunk from a row // Flatten to 1D row_major offsets. auto row_major = [](auto const& coord, auto ld) { return coord.first * ld + coord.second; }; - // BLOCK_K is a stride in A matrix - auto startOffset = row_major( - startCoord2D, BLOCK_K / (ck::is_same_v, ck::f4x2_pk_t> ? 2 : 1)); - // auto kMinorOffset = row_major(minorStepCoord2D, BLOCK_K / - // (ck::is_same_v, ck::f4x2_pk_t> ? 2 : 1)); - auto kMajorOffset = - row_major(majorStepCoord2D, - BLOCK_K / (ck::is_same_v, ck::f4x2_pk_t> ? 2 : 1)); - - using ARawT = typename scalar_type::type; - using AScalarFragT = vector_type::type; - - constexpr index_t num_chunks = - (ck::is_same_v, ck::f4x2_pk_t> ? 
1 : 2); + using ARawT = typename scalar_type::type; + using AScalarChunkT = vector_type::vector_size / num_chunks>::type; union { AFragT frag; - AScalarFragT chunks[num_chunks]; + AScalarChunkT chunks[num_chunks]; } fragA{}; - const AScalarFragT* fragPtr; + const AScalarChunkT* fragPtr; + + // BLOCK_K is a stride in A matrix + auto startOffset = row_major(startCoord2D, BLOCK_K) / packed_size_v; + auto kMajorOffset = row_major(majorStepCoord2D, BLOCK_K) / packed_size_v; for(index_t chunk_idx = 0; chunk_idx < num_chunks; chunk_idx++) { - fragPtr = reinterpret_cast(input_ptr + startOffset + - chunk_idx * kMajorOffset); + fragPtr = reinterpret_cast(input_ptr + startOffset + + chunk_idx * kMajorOffset); fragA.chunks[chunk_idx] = *fragPtr; } @@ -488,12 +498,27 @@ __device__ BFragT load_B_col_major(BType const* input_ptr) // Reg 3 [8:15] | K26K27 | K58K59 | K90K91 | K122K123 | v[13] || Reg 3 [8:15] | K26K27 | K58K59 | v[13] | // Reg 3 [16:23] | K28K29 | K60K61 | K92K93 | K124K125 | v[14] || Reg 3 [16:23] | K28K29 | K60K61 | v[14] | // Reg 3 [24:31] | K30K31 | K62K63 | K94K95 | K126K127 | v[15] || Reg 3 [24:31] | K30K31 | K62K63 | v[15] | + + // Register Mapping for 16x128 for FP6: || Register Mapping for 32x64 for FP6: + // Size | BLOCK_N | BLOCK_N | BLOCK_N | BLOCK_N | || Size | BLOCK_N | BLOCK_N | | + // N | 0 ... 15 | 0 ... 15 | 0 ... 15 | 0 ... 15 | Vector || N | 0 ... 31 | 0 ... 31 | Vector | + // Thread Id | 0 ... 15 | 16 ... 31 | 32 ... 47 | 48 ... 63 | Element || Thread Id | 0 ... 31 | 32 ... 
63 | Element| + // Register Element |------------|-------------|------------|-------------|-----------|| Register Element |------------|-------------|--------| + // Reg 0-2 [0:95] | K = 0-15 | K = 32-47 | K = 64-79 | K = 96-111 | v[0] || Reg 0-2 [0:95] | K = 0-15 | K = 32-47 | v[0] | + // Reg 3-5 [0:95] | K = 16-31 | K = 48-63 | K = 80-95 | K = 112-127 | v[0] || Reg 3-5 [0:95] | K = 16-31 | K = 48-63 | v[0] | + // clang-format on static constexpr int32_t WAVE_SIZE = 64; + // FP8 chunk_size = 16, num_chunks = 2, packed_size = 1 + // FP4 chunk_size = 32, num_chunks = 1, packed_size = 2 + // FP6 chunk_size = 32, num_chunks = 1, packed_size = 32 + + constexpr index_t num_chunks = is_packed_type_v ? 1 : 2; + // Here we want to load from cols of B in chunks of 16 elements each. - static constexpr uint32_t chunk_size = 16; + constexpr uint32_t chunk_size = is_packed_type_v ? 32 : 16; // each chunk is separated by an offset static constexpr uint32_t chunk_offset = chunk_size * WAVE_SIZE / BLOCK_N; // 32 or 64 @@ -501,44 +526,36 @@ __device__ BFragT load_B_col_major(BType const* input_ptr) // To start the loading process, let's visualize in 2D coords. // Each thread will load 32 elements. // We need to know where they start, and where the next elements are. - auto startCoord2D = - std::make_pair((threadIdx.x / BLOCK_N) * chunk_size, // Row {0, 16} | {0, 16, 32, 48} - threadIdx.x % BLOCK_N); // Col {0-31} | {0-15} + // FP8/6/4 Col {0-31} | {0-15} + // FP8 Row {0, 16} | {0, 16, 32, 48} + // FP6/4 Row {0, 32} | {0, 32, 64, 96} + auto startCoord2D = std::make_pair((threadIdx.x / BLOCK_N) * chunk_size, threadIdx.x % BLOCK_N); // Flatten to 1D col_major offsets. 
auto col_major = [](auto const& coord, auto ld) { return coord.first + coord.second * ld; }; - // auto minorStepCoord2D = std::make_pair(1u, 0u); // read cols auto majorStepCoord2D = std::make_pair(chunk_offset, 0); // read a chunk from a col - // BLOCK_K is a stride in B matrix - auto startOffset = col_major( - startCoord2D, BLOCK_K / (ck::is_same_v, ck::f4x2_pk_t> ? 2 : 1)); - // auto kMinorOffset = col_major(minorStepCoord2D, BLOCK_K / - // (ck::is_same_v, ck::f4x2_pk_t> ? 2 : 1)); - auto kMajorOffset = - col_major(majorStepCoord2D, - BLOCK_K / (ck::is_same_v, ck::f4x2_pk_t> ? 2 : 1)); - - using BRawT = typename scalar_type::type; - using BScalarFragT = vector_type::type; - - constexpr index_t num_chunks = - (ck::is_same_v, ck::f4x2_pk_t> ? 1 : 2); + using BRawT = typename scalar_type::type; + using BScalarChunkT = vector_type::vector_size / num_chunks>::type; union { BFragT frag; - BScalarFragT chunks[num_chunks]; + BScalarChunkT chunks[num_chunks]; } fragB{}; - const BScalarFragT* fragPtr; + const BScalarChunkT* fragPtr; - for(index_t chunk = 0; chunk < num_chunks; chunk++) + // BLOCK_K is a stride in B matrix + auto startOffset = col_major(startCoord2D, BLOCK_K) / packed_size_v; + auto kMajorOffset = col_major(majorStepCoord2D, BLOCK_K) / packed_size_v; + + for(index_t chunk_idx = 0; chunk_idx < num_chunks; chunk_idx++) { - fragPtr = - reinterpret_cast(input_ptr + startOffset + chunk * kMajorOffset); - fragB.chunks[chunk] = *fragPtr; + fragPtr = reinterpret_cast(input_ptr + startOffset + + chunk_idx * kMajorOffset); + fragB.chunks[chunk_idx] = *fragPtr; } return fragB.frag; @@ -904,20 +921,22 @@ template -__global__ void matmul(const AType* a, const BType* b, CType* c) +__global__ void matmul(const typename packed_type::type* a, + const typename packed_type::type* b, + CType* c) { + using PackedAType = typename packed_type::type; + constexpr auto packed_size_a = packed_type::packed_size; + using PackedBType = typename packed_type::type; + constexpr auto 
packed_size_b = packed_type::packed_size; + constexpr int WAVE_SIZE = 64; assert(threadIdx.x < WAVE_SIZE); assert(blockDim.x == 1 && blockDim.y == 1 && blockDim.z == 1); - using AFragT = - vector_type, ck::f4x2_pk_t> ? 2 : 1)>::type; - using BFragT = - vector_type, ck::f4x2_pk_t> ? 2 : 1)>::type; + using AFragT = vector_type::type; + using BFragT = vector_type::type; + using CFragT = vector_type::type; using AccumFragT = vector_type; using RawAccumFragT = vector_type::type; @@ -931,11 +950,11 @@ __global__ void matmul(const AType* a, const BType* b, CType* c) // Load the inputs. if constexpr(is_same_v) { - fragA = load_A_row_major(a); + fragA = load_A_row_major(a); } else { - fragA = load_A_col_major(a); + fragA = load_A_col_major(a); } if constexpr(is_same_v) @@ -944,7 +963,7 @@ __global__ void matmul(const AType* a, const BType* b, CType* c) } else { - fragB = load_B_col_major(b); + fragB = load_B_col_major(b); } // Matrix multiply-accumulate using MFMA units @@ -979,21 +998,24 @@ template -__global__ void -matmul(const AType* a, const ScaleType* xa, const BType* b, const ScaleType* xb, CType* c) +__global__ void matmul(const packed_type_t* a, + const ScaleType* xa, + const packed_type_t* b, + const ScaleType* xb, + CType* c) { + using PackedAType = packed_type_t; + constexpr auto packed_size_a = packed_size_v; + using PackedBType = packed_type_t; + constexpr auto packed_size_b = packed_size_v; + constexpr int WAVE_SIZE = 64; assert(threadIdx.x < WAVE_SIZE); assert(blockDim.x == 1 && blockDim.y == 1 && blockDim.z == 1); - using AFragT = - vector_type, ck::f4x2_pk_t> ? 2 : 1)>::type; - using BFragT = - vector_type, ck::f4x2_pk_t> ? 2 : 1)>::type; + using AFragT = vector_type::type; + using BFragT = vector_type::type; + using CFragT = vector_type::type; using AccumFragT = vector_type; using RawAccumFragT = vector_type::type; @@ -1011,9 +1033,13 @@ matmul(const AType* a, const ScaleType* xa, const BType* b, const ScaleType* xb, // Load the inputs. 
if constexpr(is_same_v) { - fragA = - load_mx_A_row_major( - a, xa, fragXa); + fragA = load_mx_A_row_major(a, xa, fragXa); } else { @@ -1026,9 +1052,13 @@ matmul(const AType* a, const ScaleType* xa, const BType* b, const ScaleType* xb, } else { - fragB = - load_mx_B_col_major( - b, xb, fragXb); + fragB = load_mx_B_col_major(b, xb, fragXb); } // Scaled Matrix multiply-accumulate using MFMA units @@ -1151,6 +1181,11 @@ template struct TestMXMFMA { + using PackedAType = typename packed_type::type; + static constexpr auto packed_size_a = packed_type::packed_size; + using PackedBType = typename packed_type::type; + static constexpr auto packed_size_b = packed_type::packed_size; + auto PrepareGemmTensors(const GemmParams& params, index_t init) { auto f_host_tensor_descriptor = @@ -1167,11 +1202,11 @@ struct TestMXMFMA } }; - Tensor a_m_k( + Tensor a_m_k( f_host_tensor_descriptor(params.M, params.K, params.StrideA, ALayout{})); Tensor a_scales( f_host_tensor_descriptor(params.M, params.K / BLOCK_X, params.K / BLOCK_X, ALayout{})); - Tensor b_n_k( + Tensor b_n_k( f_host_tensor_descriptor(params.K, params.N, params.StrideB, BLayout{})); Tensor b_scales( f_host_tensor_descriptor(params.K / BLOCK_X, params.N, params.K / BLOCK_X, BLayout{})); @@ -1183,51 +1218,44 @@ struct TestMXMFMA switch(init) { case 0: - a_m_k.GenerateTensorValue(GeneratorTensor_1{1.0f}); - a_scales.GenerateTensorValue(GeneratorTensor_1{ScaleType{0.015625f}}); // 1/6 + a_m_k.GenerateTensorValue(GeneratorTensor_1{1.0f}); + a_scales.GenerateTensorValue(GeneratorTensor_1{ScaleType{0.5f}}); // NOTE: not all numbers are representable in FP8, BF8, etc. 
// 0 1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 16 18 20 20 20 22 24 24 24 26 28 28 28 30 32 - b_n_k.GenerateTensorValue(GeneratorTensor_Sequential{}); + b_n_k.GenerateTensorValue(GeneratorTensor_Sequential{}); b_scales.GenerateTensorValue(GeneratorTensor_1{ScaleType{1.0f}}); break; case 1: // results in C = {K} - a_m_k.GenerateTensorValue(GeneratorTensor_1{1.0f}); + a_m_k.GenerateTensorValue(GeneratorTensor_1{1.0f}); a_scales.GenerateTensorValue(GeneratorTensor_1{ScaleType{512.0f}}); - b_n_k.GenerateTensorValue(GeneratorTensor_1{1.0f}); + b_n_k.GenerateTensorValue(GeneratorTensor_1{1.0f}); b_scales.GenerateTensorValue(GeneratorTensor_1{ScaleType{1.0f / 512}}); break; case 2: // expect small round off errors - a_m_k.GenerateTensorValue(GeneratorTensor_3{-2.0, 2.0}); + a_m_k.GenerateTensorValue(GeneratorTensor_3{-2.0, 2.0}); a_scales.GenerateTensorValue( GeneratorTensor_2{126, 129}); // scales: {0.5, 1, 2} - b_n_k.GenerateTensorValue(GeneratorTensor_3{-2.0, 2.0}); + b_n_k.GenerateTensorValue(GeneratorTensor_3{-2.0, 2.0}); b_scales.GenerateTensorValue(GeneratorTensor_2{126, 129}); break; case 3: // expect small round off errors - a_m_k.GenerateTensorValue(GeneratorTensor_4(0, 1)); + a_m_k.GenerateTensorValue(GeneratorTensor_4(0, 1, time(nullptr))); a_scales.GenerateTensorValue( GeneratorTensor_2{126, 129}); // scales: {0.5, 1, 2} - b_n_k.GenerateTensorValue(GeneratorTensor_4(0, 1)); - b_scales.GenerateTensorValue( - GeneratorTensor_2{126, 129}); // scales: {0.5, 1, 2} - break; - case 4: - a_m_k.GenerateTensorValue(GeneratorTensor_3{-1., 1.}); - a_scales.GenerateTensorValue( - GeneratorTensor_2{126, 129}); // scales: {0.5, 1, 2} - b_n_k.GenerateTensorValue(GeneratorTensor_3{-1., 1.}); + b_n_k.GenerateTensorValue(GeneratorTensor_4(0, 1, time(nullptr) / 2)); b_scales.GenerateTensorValue( GeneratorTensor_2{126, 129}); // scales: {0.5, 1, 2} break; + default: // all initial values are representable in FP8, BF8 - a_m_k.GenerateTensorValue(GeneratorTensor_2{-5, 6}); // Z[-5,5] 
+ a_m_k.GenerateTensorValue(GeneratorTensor_2{-6, 7}); // Z[-6,6] a_scales.GenerateTensorValue( - GeneratorTensor_2{122, 129}); // scales: [1/32,..., 2] - b_n_k.GenerateTensorValue(GeneratorTensor_2{-5, 6}); // Z[-5,5] + GeneratorTensor_2{122, 129}); // scales: [1/32,..., 2] + b_n_k.GenerateTensorValue(GeneratorTensor_2{-6, 7}); // Z[-6,6] b_scales.GenerateTensorValue( GeneratorTensor_2{122, 129}); // scales: [1/32,..., 2] @@ -1272,9 +1300,9 @@ struct TestMXMFMA auto host_tensors = PrepareGemmTensors(params, init); - const Tensor& a = std::get<0>(host_tensors); + const Tensor& a = std::get<0>(host_tensors); const Tensor& a_scales = std::get<1>(host_tensors); - const Tensor& b = std::get<2>(host_tensors); + const Tensor& b = std::get<2>(host_tensors); const Tensor& b_scales = std::get<3>(host_tensors); Tensor& c_host = std::get<4>(host_tensors); Tensor& c_device = std::get<5>(host_tensors); @@ -1356,6 +1384,12 @@ template struct TestMFMA { + + using PackedAType = typename packed_type::type; + static constexpr auto packed_size_a = packed_type::packed_size; + using PackedBType = typename packed_type::type; + static constexpr auto packed_size_b = packed_type::packed_size; + auto PrepareGemmTensors(const GemmParams& params, index_t init) { auto f_host_tensor_descriptor = @@ -1372,9 +1406,9 @@ struct TestMFMA } }; - Tensor a_m_k( + Tensor a_m_k( f_host_tensor_descriptor(params.M, params.K, params.StrideA, ALayout{})); - Tensor b_n_k( + Tensor b_n_k( f_host_tensor_descriptor(params.K, params.N, params.StrideB, BLayout{})); Tensor c_m_n_host_result( f_host_tensor_descriptor(params.M, params.N, params.StrideC, CLayout{})); @@ -1384,34 +1418,30 @@ struct TestMFMA switch(init) { case 0: - a_m_k.GenerateTensorValue(GeneratorTensor_1{0.015625f}); + a_m_k.GenerateTensorValue(GeneratorTensor_1{0.625f}); // NOTE: not all numbers are representable in FP8, BF8, etc. 
- b_n_k.GenerateTensorValue(GeneratorTensor_Sequential{}); + b_n_k.GenerateTensorValue(GeneratorTensor_Sequential{}); break; case 1: // results in C = {K} - a_m_k.GenerateTensorValue(GeneratorTensor_1{1.0f}); - b_n_k.GenerateTensorValue(GeneratorTensor_1{1.0f}); + a_m_k.GenerateTensorValue(GeneratorTensor_1{1.0f}); + b_n_k.GenerateTensorValue(GeneratorTensor_1{1.0f}); break; case 2: - // expect small round off errors - a_m_k.GenerateTensorValue(GeneratorTensor_3{-5, 5}); - b_n_k.GenerateTensorValue(GeneratorTensor_3{-5, 5}); + // expect small round off errors that lead to FP8MFMA32x32x64 failures + a_m_k.GenerateTensorValue(GeneratorTensor_3{-5, 5}); + b_n_k.GenerateTensorValue(GeneratorTensor_3{-5, 5}); break; case 3: - // expect small round off errors - a_m_k.GenerateTensorValue(GeneratorTensor_4(-1, 3)); - b_n_k.GenerateTensorValue(GeneratorTensor_4(1, 3)); - break; - case 4: - // FP4 values case - a_m_k.GenerateTensorValue(GeneratorTensor_2{-4, 5}); - b_n_k.GenerateTensorValue(GeneratorTensor_2{-4, 5}); + // expect small round off errors that lead to FP8MFMA32x32x64 failures + a_m_k.GenerateTensorValue(GeneratorTensor_4(-1, 3)); + b_n_k.GenerateTensorValue(GeneratorTensor_4(1, 3)); break; + default: - // all initial values are representable in FP8, BF8 - a_m_k.GenerateTensorValue(GeneratorTensor_2{-5, 6}); - b_n_k.GenerateTensorValue(GeneratorTensor_2{-5, 6}); + // all initial values are representable in FP8/6, BF8/6 FP4 is missing 5 + a_m_k.GenerateTensorValue(GeneratorTensor_2{-6, 7}); // Z[-6,6] + b_n_k.GenerateTensorValue(GeneratorTensor_2{-6, 7}); break; } @@ -1453,10 +1483,10 @@ struct TestMFMA auto host_tensors = PrepareGemmTensors(params, init); - const Tensor& a = std::get<0>(host_tensors); - const Tensor& b = std::get<1>(host_tensors); - Tensor& c_host = std::get<2>(host_tensors); - Tensor& c_device = std::get<3>(host_tensors); + const Tensor& a = std::get<0>(host_tensors); + const Tensor& b = std::get<1>(host_tensors); + Tensor& c_host = 
std::get<2>(host_tensors); + Tensor& c_device = std::get<3>(host_tensors); using PassThrough = ck::tensor_operation::element_wise::PassThrough; @@ -1464,8 +1494,8 @@ struct TestMFMA auto b_element_op = PassThrough{}; auto c_element_op = PassThrough{}; - using ReferenceGemmInstance = ck::tensor_operation::host::ReferenceGemm Date: Mon, 19 May 2025 17:29:51 -0700 Subject: [PATCH 130/443] Use new mfma instructions for FP8 on gfx950 (#2202) * Add logic to use new mfma instructions for fp8 bf8 * Fix example_gemm_xdl_fp8_pk_i4_bpreshuffle_v3 on gfx950 and run clang format * Update include/ck/tensor_operation/gpu/warp/xdlops_gemm.hpp Co-authored-by: Andriy Roshchenko <107577548+andriy-ca@users.noreply.github.com> * Fix intrin_mfma f8 calls due to merge mistake --------- Co-authored-by: Andriy Roshchenko <107577548+andriy-ca@users.noreply.github.com> --- example/01_gemm/gemm_xdl_fp8.cpp | 2 + ...ipeline_xdlops_b_preshuffle_dequant_v3.hpp | 4 +- ...iple_d_welford_first_half_xdl_cshuffle.hpp | 16 ++- ...wise_batched_gemm_gemm_xdl_cshuffle_v1.hpp | 11 +- ...iple_d_gemm_multiple_d_xdl_cshuffle_v1.hpp | 33 +++-- ...ultiple_d_softmax_gemm_xdl_cshuffle_v1.hpp | 19 ++- ...ched_gemm_softmax_gemm_xdl_cshuffle_v1.hpp | 11 +- ...e_gemm_bias_add_reduce_xdl_cshuffle_v1.hpp | 11 +- ...ridwise_gemm_multiple_abd_xdl_cshuffle.hpp | 20 +-- ...emm_multiple_d_multiple_r_xdl_cshuffle.hpp | 11 +- .../gridwise_gemm_multiple_d_xdl_cshuffle.hpp | 20 +-- ...ultiple_d_xdl_cshuffle_lds_direct_load.hpp | 19 +-- ...se_gemm_multiple_d_xdl_splitk_cshuffle.hpp | 16 ++- .../gridwise_gemm_reduce_xdl_cshuffle_v1.hpp | 11 +- ...e_gemm_split_k_multiple_d_xdl_cshuffle.hpp | 32 +++-- ...emm_split_k_multiple_d_xdl_cshuffle_v2.hpp | 16 ++- .../gridwise_gemm_xdl_cshuffle_conv_v3.hpp | 13 +- .../gridwise_gemm_xdl_cshuffle_streamk_v3.hpp | 13 +- .../grid/gridwise_gemm_xdl_cshuffle_v1.hpp | 16 ++- .../grid/gridwise_gemm_xdl_cshuffle_v2.hpp | 16 ++- .../grid/gridwise_gemm_xdl_cshuffle_v3.hpp | 13 +- 
...wise_gemm_xdl_cshuffle_v3_b_preshuffle.hpp | 20 ++- .../gridwise_gemm_xdl_cshuffle_v3_b_scale.hpp | 13 +- ...ridwise_gemm_xdl_cshuffle_v3_multi_abd.hpp | 13 +- .../gridwise_gemm_xdl_cshuffle_v3_multi_d.hpp | 14 +- ..._gemm_xdl_cshuffle_v3_multi_d_ab_scale.hpp | 14 +- ...m_xdl_cshuffle_v3_multi_d_b_preshuffle.hpp | 25 +++- ...ridwise_gemm_xdl_layernorm_cshuffle_v1.hpp | 11 +- ...ridwise_gemm_xdl_waveletmodel_cshuffle.hpp | 17 ++- .../grid/gridwise_gemm_xdlops_bwd_weight.hpp | 17 ++- .../gpu/grid/gridwise_gemm_xdlops_v3r1.hpp | 11 +- .../gpu/grid/gridwise_moe_gemm.hpp | 29 ++-- .../tensor_operation/gpu/warp/xdlops_gemm.hpp | 131 +++++++++++++++--- include/ck/utility/amd_xdlops.hpp | 90 ++++++++++++ 34 files changed, 548 insertions(+), 180 deletions(-) diff --git a/example/01_gemm/gemm_xdl_fp8.cpp b/example/01_gemm/gemm_xdl_fp8.cpp index 3c75a44d21..0c51a58037 100644 --- a/example/01_gemm/gemm_xdl_fp8.cpp +++ b/example/01_gemm/gemm_xdl_fp8.cpp @@ -32,6 +32,8 @@ using DeviceGemmInstance = ck::tensor_operation::device::DeviceGemm_Xdl_CShuffle // ######| | | | | | | | | Operation| Operation| Operation| | Stage| | | | | | | | | Wave| Wave| Lengths_K0_M_K1| ArrangeOrder| | | PerVector| PerVector_K1| | Lengths_K0_N_K1| ArrangeOrder| | | PerVector| PerVector_K1| | PerShuffle| PerShuffle| _NBlock_NWaveNPerXdl| _NWaveNPerXdl| | | | | // ######| | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | < ALayout, BLayout, CLayout, ADataType, BDataType, CDataType, AccDataType, CShuffleDataType, AElementOp, BElementOp, CElementOp, GemmDefault, 1, 256, 256, 128, 64, 16, 16, 32, 32, 4, 2, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, 1, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, 1, 1, S<1, 64, 1, 4>, 8, LoopSched, PipelineVer, ComputeTypeA, ComputeTypeB>; + // this instance has been tested working on gfx950 + // < ALayout, BLayout, CLayout, ADataType, BDataType, CDataType, AccDataType, CShuffleDataType, AElementOp, BElementOp, CElementOp, 
GemmDefault, 1, 256, 256, 128, 128, 32, 32, 32, 32, 4, 2, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, 1, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, 1, 1, S<1, 64, 1, 4>, 8, LoopSched, PipelineVer, ComputeTypeA, ComputeTypeB>; // clang-format on using ReferenceGemmInstance = ck::tensor_operation::host:: diff --git a/include/ck/tensor_operation/gpu/block/blockwise_gemm_pipeline_xdlops_b_preshuffle_dequant_v3.hpp b/include/ck/tensor_operation/gpu/block/blockwise_gemm_pipeline_xdlops_b_preshuffle_dequant_v3.hpp index 4be4e321d3..e5fe92a50d 100644 --- a/include/ck/tensor_operation/gpu/block/blockwise_gemm_pipeline_xdlops_b_preshuffle_dequant_v3.hpp +++ b/include/ck/tensor_operation/gpu/block/blockwise_gemm_pipeline_xdlops_b_preshuffle_dequant_v3.hpp @@ -124,7 +124,6 @@ struct BlockwiseGemmXdlops_pipeline_bpreshuffle_bdequant_v3{}; + static constexpr index_t PrefetchStages = 2; static constexpr index_t PrefillStages = 1; static constexpr index_t GlobalBufferNum = 1; diff --git a/include/ck/tensor_operation/gpu/grid/gemm_layernorm/gridwise_gemm_multiple_d_welford_first_half_xdl_cshuffle.hpp b/include/ck/tensor_operation/gpu/grid/gemm_layernorm/gridwise_gemm_multiple_d_welford_first_half_xdl_cshuffle.hpp index d728360c55..02dba97430 100644 --- a/include/ck/tensor_operation/gpu/grid/gemm_layernorm/gridwise_gemm_multiple_d_welford_first_half_xdl_cshuffle.hpp +++ b/include/ck/tensor_operation/gpu/grid/gemm_layernorm/gridwise_gemm_multiple_d_welford_first_half_xdl_cshuffle.hpp @@ -519,13 +519,19 @@ struct GridwiseGemmMultipleDWelfordFirstHalf_xdl_cshuffle constexpr bool is_single_rate_mfma = (((is_same::value || is_same::value) && lcm_AK1_BK1 <= 4) || - (is_same::value && lcm_AK1_BK1 <= 8)) + (is_same::value && lcm_AK1_BK1 <= 8) || + ((is_same::value || is_same::value) && + lcm_AK1_BK1 < 32)) ? 
true : false; - constexpr index_t KPack = - math::max(lcm_AK1_BK1, - MfmaSelector:: - selected_mfma.k_per_blk); + constexpr auto is_scale_mfma = false; + constexpr index_t KPack = math::max(lcm_AK1_BK1, + MfmaSelector::selected_mfma.k_per_blk); auto blockwise_gemm = BlockwiseGemmXdlops_k0mk1_k0nk1_m0n0m1n1m2m3m4n2_Selector< BlockSize, diff --git a/include/ck/tensor_operation/gpu/grid/gridwise_batched_gemm_gemm_xdl_cshuffle_v1.hpp b/include/ck/tensor_operation/gpu/grid/gridwise_batched_gemm_gemm_xdl_cshuffle_v1.hpp index 50b4a734fa..258d0ad0ca 100644 --- a/include/ck/tensor_operation/gpu/grid/gridwise_batched_gemm_gemm_xdl_cshuffle_v1.hpp +++ b/include/ck/tensor_operation/gpu/grid/gridwise_batched_gemm_gemm_xdl_cshuffle_v1.hpp @@ -452,13 +452,16 @@ struct GridwiseBatchedGemmGemm_Xdl_CShuffle constexpr bool is_single_rate_mfma = (((is_same::value || is_same::value) && lcm_AK1_BK1 <= 4) || - (is_same::value && lcm_AK1_BK1 <= 8)) + (is_same::value && lcm_AK1_BK1 <= 8) || + ((is_same::value || is_same::value) && + lcm_AK1_BK1 < 32)) ? 
true : false; - constexpr index_t KPack = math::max( + constexpr auto is_scale_mfma = false; + constexpr index_t KPack = math::max( lcm_AK1_BK1, - MfmaSelector::selected_mfma - .k_per_blk); + MfmaSelector:: + selected_mfma.k_per_blk); auto blockwise_gemm = BlockwiseGemmXdlops_v2< BlockSize, diff --git a/include/ck/tensor_operation/gpu/grid/gridwise_batched_gemm_multiple_d_gemm_multiple_d_xdl_cshuffle_v1.hpp b/include/ck/tensor_operation/gpu/grid/gridwise_batched_gemm_multiple_d_gemm_multiple_d_xdl_cshuffle_v1.hpp index 79a9410898..53a45c7f16 100644 --- a/include/ck/tensor_operation/gpu/grid/gridwise_batched_gemm_multiple_d_gemm_multiple_d_xdl_cshuffle_v1.hpp +++ b/include/ck/tensor_operation/gpu/grid/gridwise_batched_gemm_multiple_d_gemm_multiple_d_xdl_cshuffle_v1.hpp @@ -365,16 +365,20 @@ struct GridwiseBatchedGemmMultipleDGemmMultipleD_Xdl_CShuffle constexpr bool is_single_rate_mfma = (((is_same::value || is_same::value) && lcm_A0K1_B0K1 <= 4) || - (is_same::value && lcm_A0K1_B0K1 <= 8)) + (is_same::value && lcm_A0K1_B0K1 <= 8) || + ((is_same::value || is_same::value) && + lcm_A0K1_B0K1 < 32)) ? true : false; - constexpr auto mfma = MfmaSelector::selected_mfma; - constexpr auto N3 = mfma.num_groups_per_blk; - constexpr auto N5 = mfma.group_size; + is_single_rate_mfma, + is_scale_mfma>::selected_mfma; + constexpr auto N3 = mfma.num_groups_per_blk; + constexpr auto N5 = mfma.group_size; return transform_tensor_descriptor( d0_grid_desc_m_n, make_tuple(make_unmerge_transform(make_tuple( @@ -657,16 +661,19 @@ struct GridwiseBatchedGemmMultipleDGemmMultipleD_Xdl_CShuffle constexpr bool is_single_rate_mfma = (((is_same::value || is_same::value) && lcm_A0K1_B0K1 <= 4) || - (is_same::value && lcm_A0K1_B0K1 <= 8)) + (is_same::value && lcm_A0K1_B0K1 <= 8) || + ((is_same::value || is_same::value) && + lcm_A0K1_B0K1 < 32)) ? 
true : false; - constexpr index_t KPack = - math::max(lcm_A0K1_B0K1, - MfmaSelector::selected_mfma.k_per_blk); + constexpr auto is_scale_mfma = false; + constexpr index_t KPack = math::max(lcm_A0K1_B0K1, + MfmaSelector::selected_mfma.k_per_blk); auto blockwise_gemm0 = BlockwiseGemmXdlops_v2< BlockSize, diff --git a/include/ck/tensor_operation/gpu/grid/gridwise_batched_gemm_multiple_d_softmax_gemm_xdl_cshuffle_v1.hpp b/include/ck/tensor_operation/gpu/grid/gridwise_batched_gemm_multiple_d_softmax_gemm_xdl_cshuffle_v1.hpp index d15767f658..0f2085525f 100644 --- a/include/ck/tensor_operation/gpu/grid/gridwise_batched_gemm_multiple_d_softmax_gemm_xdl_cshuffle_v1.hpp +++ b/include/ck/tensor_operation/gpu/grid/gridwise_batched_gemm_multiple_d_softmax_gemm_xdl_cshuffle_v1.hpp @@ -347,11 +347,15 @@ struct GridwiseBatchedGemmMultipleDSoftmaxGemm_Xdl_CShuffle constexpr bool is_single_rate_mfma = (((is_same::value || is_same::value) && lcm_AK1_BK1 <= 4) || - (is_same::value && lcm_AK1_BK1 <= 8)) + (is_same::value && lcm_AK1_BK1 <= 8) || + ((is_same::value || is_same::value) && + lcm_AK1_BK1 < 32)) ? true : false; + constexpr auto is_scale_mfma = false; constexpr auto mfma = - MfmaSelector::selected_mfma; + MfmaSelector:: + selected_mfma; constexpr auto N3 = mfma.num_groups_per_blk; constexpr auto N4 = mfma.num_input_blks; constexpr auto N5 = mfma.group_size; @@ -564,13 +568,16 @@ struct GridwiseBatchedGemmMultipleDSoftmaxGemm_Xdl_CShuffle constexpr bool is_single_rate_mfma = (((is_same::value || is_same::value) && lcm_AK1_BK1 <= 4) || - (is_same::value && lcm_AK1_BK1 <= 8)) + (is_same::value && lcm_AK1_BK1 <= 8) || + ((is_same::value || is_same::value) && + lcm_AK1_BK1 < 32)) ? 
true : false; - constexpr index_t KPack = math::max( + constexpr auto is_scale_mfma = false; + constexpr index_t KPack = math::max( lcm_AK1_BK1, - MfmaSelector::selected_mfma - .k_per_blk); + MfmaSelector:: + selected_mfma.k_per_blk); auto blockwise_gemm = BlockwiseGemmXdlops_v2< BlockSize, diff --git a/include/ck/tensor_operation/gpu/grid/gridwise_batched_gemm_softmax_gemm_xdl_cshuffle_v1.hpp b/include/ck/tensor_operation/gpu/grid/gridwise_batched_gemm_softmax_gemm_xdl_cshuffle_v1.hpp index a11d696019..33b9199ea5 100644 --- a/include/ck/tensor_operation/gpu/grid/gridwise_batched_gemm_softmax_gemm_xdl_cshuffle_v1.hpp +++ b/include/ck/tensor_operation/gpu/grid/gridwise_batched_gemm_softmax_gemm_xdl_cshuffle_v1.hpp @@ -473,13 +473,16 @@ struct GridwiseBatchedGemmSoftmaxGemm_Xdl_CShuffle constexpr bool is_single_rate_mfma = (((is_same::value || is_same::value) && lcm_AK1_BK1 <= 4) || - (is_same::value && lcm_AK1_BK1 <= 8)) + (is_same::value && lcm_AK1_BK1 <= 8) || + ((is_same::value || is_same::value) && + lcm_AK1_BK1 < 32)) ? 
true : false; - constexpr index_t KPack = math::max( + constexpr auto is_scale_mfma = false; + constexpr index_t KPack = math::max( lcm_AK1_BK1, - MfmaSelector::selected_mfma - .k_per_blk); + MfmaSelector:: + selected_mfma.k_per_blk); auto blockwise_gemm = BlockwiseGemmXdlops_v2< BlockSize, diff --git a/include/ck/tensor_operation/gpu/grid/gridwise_gemm_bias_add_reduce_xdl_cshuffle_v1.hpp b/include/ck/tensor_operation/gpu/grid/gridwise_gemm_bias_add_reduce_xdl_cshuffle_v1.hpp index ab97a940a8..f406bfb95a 100644 --- a/include/ck/tensor_operation/gpu/grid/gridwise_gemm_bias_add_reduce_xdl_cshuffle_v1.hpp +++ b/include/ck/tensor_operation/gpu/grid/gridwise_gemm_bias_add_reduce_xdl_cshuffle_v1.hpp @@ -502,13 +502,16 @@ struct GridwiseGemmBiasAddReduce_k0mk1_k0nk1_mn_xdl_cshuffle_v1 constexpr bool is_single_rate_mfma = (((is_same::value || is_same::value) && lcm_AK1_BK1 <= 4) || - (is_same::value && lcm_AK1_BK1 <= 8)) + (is_same::value && lcm_AK1_BK1 <= 8) || + ((is_same::value || is_same::value) && + lcm_AK1_BK1 < 32)) ? 
true : false; - constexpr index_t KPack = math::max( + constexpr auto is_scale_mfma = false; + constexpr index_t KPack = math::max( lcm_AK1_BK1, - MfmaSelector::selected_mfma - .k_per_blk); + MfmaSelector:: + selected_mfma.k_per_blk); auto blockwise_gemm = BlockwiseGemmXdlops_k0mk1_k0nk1_m0n0m1n1m2m3m4n2_Selector< BlockSize, diff --git a/include/ck/tensor_operation/gpu/grid/gridwise_gemm_multiple_abd_xdl_cshuffle.hpp b/include/ck/tensor_operation/gpu/grid/gridwise_gemm_multiple_abd_xdl_cshuffle.hpp index 79ab3acd92..054aca2936 100644 --- a/include/ck/tensor_operation/gpu/grid/gridwise_gemm_multiple_abd_xdl_cshuffle.hpp +++ b/include/ck/tensor_operation/gpu/grid/gridwise_gemm_multiple_abd_xdl_cshuffle.hpp @@ -679,17 +679,19 @@ struct GridwiseGemmMultipleABD_xdl_cshuffle (((is_same::value || is_same::value) && lcm_AK1_BK1 <= 4) || - (is_same::value && lcm_AK1_BK1 <= 8)) + (is_same::value && lcm_AK1_BK1 <= 8) || + ((is_same::value || is_same::value) && + lcm_AK1_BK1 < 32)) ? true : false; - - constexpr index_t KPack = - math::max(lcm_AK1_BK1, - MfmaSelector::selected_mfma.k_per_blk); + static constexpr auto is_scale_mfma = false; + constexpr index_t KPack = math::max(lcm_AK1_BK1, + MfmaSelector::selected_mfma.k_per_blk); auto blockwise_gemm = BlockwiseGemmXdlops_k0mk1_k0nk1_m0n0m1n1m2m3m4n2_Selector< BlockSize, diff --git a/include/ck/tensor_operation/gpu/grid/gridwise_gemm_multiple_d_multiple_r_xdl_cshuffle.hpp b/include/ck/tensor_operation/gpu/grid/gridwise_gemm_multiple_d_multiple_r_xdl_cshuffle.hpp index 0e51c6904c..127d889572 100644 --- a/include/ck/tensor_operation/gpu/grid/gridwise_gemm_multiple_d_multiple_r_xdl_cshuffle.hpp +++ b/include/ck/tensor_operation/gpu/grid/gridwise_gemm_multiple_d_multiple_r_xdl_cshuffle.hpp @@ -468,13 +468,16 @@ struct GridwiseGemmMultipleDMultipleR_k0mk1_k0nk1_mn_xdl_cshuffle_v1 constexpr bool is_single_rate_mfma = (((is_same::value || is_same::value) && lcm_AK1_BK1 <= 4) || - (is_same::value && lcm_AK1_BK1 <= 8)) + (is_same::value 
&& lcm_AK1_BK1 <= 8) || + ((is_same::value || is_same::value) && + lcm_AK1_BK1 < 32)) ? true : false; - constexpr index_t KPack = math::max( + constexpr auto is_scale_mfma = false; + constexpr index_t KPack = math::max( lcm_AK1_BK1, - MfmaSelector::selected_mfma - .k_per_blk); + MfmaSelector:: + selected_mfma.k_per_blk); auto blockwise_gemm = BlockwiseGemmXdlops_k0mk1_k0nk1_m0n0m1n1m2m3m4n2_Selector< BlockSize, diff --git a/include/ck/tensor_operation/gpu/grid/gridwise_gemm_multiple_d_xdl_cshuffle.hpp b/include/ck/tensor_operation/gpu/grid/gridwise_gemm_multiple_d_xdl_cshuffle.hpp index a3301dd932..be0fff087e 100644 --- a/include/ck/tensor_operation/gpu/grid/gridwise_gemm_multiple_d_xdl_cshuffle.hpp +++ b/include/ck/tensor_operation/gpu/grid/gridwise_gemm_multiple_d_xdl_cshuffle.hpp @@ -647,17 +647,19 @@ struct GridwiseGemmMultipleD_xdl_cshuffle (((is_same::value || is_same::value) && lcm_AK1_BK1 <= 4) || - (is_same::value && lcm_AK1_BK1 <= 8)) + (is_same::value && lcm_AK1_BK1 <= 8) || + ((is_same::value || is_same::value) && + lcm_AK1_BK1 < 32)) ? 
true : false; - - constexpr index_t KPack = - math::max(lcm_AK1_BK1, - MfmaSelector::selected_mfma.k_per_blk); + constexpr auto is_scale_mfma = false; + constexpr index_t KPack = math::max(lcm_AK1_BK1, + MfmaSelector::selected_mfma.k_per_blk); auto blockwise_gemm = BlockwiseGemmXdlops_k0mk1_k0nk1_m0n0m1n1m2m3m4n2_Selector< BlockSize, diff --git a/include/ck/tensor_operation/gpu/grid/gridwise_gemm_multiple_d_xdl_cshuffle_lds_direct_load.hpp b/include/ck/tensor_operation/gpu/grid/gridwise_gemm_multiple_d_xdl_cshuffle_lds_direct_load.hpp index 57b9b02548..7781d1def3 100644 --- a/include/ck/tensor_operation/gpu/grid/gridwise_gemm_multiple_d_xdl_cshuffle_lds_direct_load.hpp +++ b/include/ck/tensor_operation/gpu/grid/gridwise_gemm_multiple_d_xdl_cshuffle_lds_direct_load.hpp @@ -605,17 +605,20 @@ struct GridwiseGemmMultipleD_Xdl_CShuffle_LdsDirectLoad (((is_same::value || is_same::value) && lcm_AK1_BK1 <= 4) || - (is_same::value && lcm_AK1_BK1 <= 8)) + (is_same::value && lcm_AK1_BK1 <= 8) || + ((is_same::value || is_same::value) && + lcm_AK1_BK1 < 32)) ? 
true : false; + constexpr auto is_scale_mfma = false; - constexpr index_t KPack = - math::max(lcm_AK1_BK1, - MfmaSelector::selected_mfma.k_per_blk); + constexpr index_t KPack = math::max(lcm_AK1_BK1, + MfmaSelector::selected_mfma.k_per_blk); auto blockwise_gemm = BlockwiseGemmXdlops_k0mk1_k0nk1_m0n0m1n1m2m3m4n2_Selector< BlockSize, diff --git a/include/ck/tensor_operation/gpu/grid/gridwise_gemm_multiple_d_xdl_splitk_cshuffle.hpp b/include/ck/tensor_operation/gpu/grid/gridwise_gemm_multiple_d_xdl_splitk_cshuffle.hpp index 88d6be234c..5815eb5b0b 100644 --- a/include/ck/tensor_operation/gpu/grid/gridwise_gemm_multiple_d_xdl_splitk_cshuffle.hpp +++ b/include/ck/tensor_operation/gpu/grid/gridwise_gemm_multiple_d_xdl_splitk_cshuffle.hpp @@ -603,13 +603,19 @@ struct GridwiseGemmMultipleD_xdl_splitk_cshuffle constexpr bool is_single_rate_mfma = (((is_same::value || is_same::value) && lcm_AK1_BK1 <= 4) || - (is_same::value && lcm_AK1_BK1 <= 8)) + (is_same::value && lcm_AK1_BK1 <= 8) || + ((is_same::value || is_same::value) && + lcm_AK1_BK1 < 32)) ? 
true : false; - constexpr index_t KPack = math::max( - lcm_AK1_BK1, - MfmaSelector:: - selected_mfma.k_per_blk); + constexpr auto is_scale_mfma = false; + constexpr index_t KPack = math::max(lcm_AK1_BK1, + MfmaSelector::selected_mfma.k_per_blk); auto blockwise_gemm = BlockwiseGemmXdlops_k0mk1_k0nk1_m0n0m1n1m2m3m4n2_Selector< BlockSize, diff --git a/include/ck/tensor_operation/gpu/grid/gridwise_gemm_reduce_xdl_cshuffle_v1.hpp b/include/ck/tensor_operation/gpu/grid/gridwise_gemm_reduce_xdl_cshuffle_v1.hpp index 56581256dc..db227bb7ef 100644 --- a/include/ck/tensor_operation/gpu/grid/gridwise_gemm_reduce_xdl_cshuffle_v1.hpp +++ b/include/ck/tensor_operation/gpu/grid/gridwise_gemm_reduce_xdl_cshuffle_v1.hpp @@ -455,13 +455,16 @@ struct GridwiseGemmReduce_k0mk1_k0nk1_mn_xdl_cshuffle_v1 constexpr bool is_single_rate_mfma = (((is_same::value || is_same::value) && lcm_AK1_BK1 <= 4) || - (is_same::value && lcm_AK1_BK1 <= 8)) + (is_same::value && lcm_AK1_BK1 <= 8) || + ((is_same::value || is_same::value) && + lcm_AK1_BK1 < 32)) ? 
true : false; - constexpr index_t KPack = math::max( + constexpr auto is_scale_mfma = false; + constexpr index_t KPack = math::max( lcm_AK1_BK1, - MfmaSelector::selected_mfma - .k_per_blk); + MfmaSelector:: + selected_mfma.k_per_blk); auto blockwise_gemm = BlockwiseGemmXdlops_k0mk1_k0nk1_m0n0m1n1m2m3m4n2_Selector< BlockSize, diff --git a/include/ck/tensor_operation/gpu/grid/gridwise_gemm_split_k_multiple_d_xdl_cshuffle.hpp b/include/ck/tensor_operation/gpu/grid/gridwise_gemm_split_k_multiple_d_xdl_cshuffle.hpp index 23b4aec3b0..70301c326a 100644 --- a/include/ck/tensor_operation/gpu/grid/gridwise_gemm_split_k_multiple_d_xdl_cshuffle.hpp +++ b/include/ck/tensor_operation/gpu/grid/gridwise_gemm_split_k_multiple_d_xdl_cshuffle.hpp @@ -585,13 +585,19 @@ struct GridwiseGemmSplitKMultipleD_xdl_cshuffle constexpr bool is_single_rate_mfma = (((is_same::value || is_same::value) && lcm_AK1_BK1 <= 4) || - (is_same::value && lcm_AK1_BK1 <= 8)) + (is_same::value && lcm_AK1_BK1 <= 8) || + ((is_same::value || is_same::value) && + lcm_AK1_BK1 < 32)) ? true : false; - constexpr index_t KPack = - math::max(lcm_AK1_BK1, - MfmaSelector:: - selected_mfma.k_per_blk); + constexpr auto is_scale_mfma = false; + constexpr index_t KPack = math::max(lcm_AK1_BK1, + MfmaSelector::selected_mfma.k_per_blk); auto blockwise_gemm = BlockwiseGemmXdlops_k0mk1_k0nk1_m0n0m1n1m2m3m4n2_Selector< BlockSize, @@ -1018,13 +1024,19 @@ struct GridwiseGemmSplitKMultipleD_xdl_cshuffle constexpr bool is_single_rate_mfma = (((is_same::value || is_same::value) && lcm_AK1_BK1 <= 4) || - (is_same::value && lcm_AK1_BK1 <= 8)) + (is_same::value && lcm_AK1_BK1 <= 8) || + ((is_same::value || is_same::value) && + lcm_AK1_BK1 < 32)) ? 
true : false; - constexpr index_t KPack = - math::max(lcm_AK1_BK1, - MfmaSelector:: - selected_mfma.k_per_blk); + constexpr auto is_scale_mfma = false; + constexpr index_t KPack = math::max(lcm_AK1_BK1, + MfmaSelector::selected_mfma.k_per_blk); auto blockwise_gemm = BlockwiseGemmXdlops_k0mk1_k0nk1_m0n0m1n1m2m3m4n2_Selector< BlockSize, diff --git a/include/ck/tensor_operation/gpu/grid/gridwise_gemm_split_k_multiple_d_xdl_cshuffle_v2.hpp b/include/ck/tensor_operation/gpu/grid/gridwise_gemm_split_k_multiple_d_xdl_cshuffle_v2.hpp index 44c1e936bd..f64838ea4e 100644 --- a/include/ck/tensor_operation/gpu/grid/gridwise_gemm_split_k_multiple_d_xdl_cshuffle_v2.hpp +++ b/include/ck/tensor_operation/gpu/grid/gridwise_gemm_split_k_multiple_d_xdl_cshuffle_v2.hpp @@ -599,13 +599,19 @@ struct GridwiseGemmMultipleD_xdl_splitk_cshuffle constexpr bool is_single_rate_mfma = (((is_same::value || is_same::value) && lcm_AK1_BK1 <= 4) || - (is_same::value && lcm_AK1_BK1 <= 8)) + (is_same::value && lcm_AK1_BK1 <= 8) || + ((is_same::value || is_same::value) && + lcm_AK1_BK1 < 32)) ? 
true : false; - constexpr index_t KPack = math::max( - lcm_AK1_BK1, - MfmaSelector:: - selected_mfma.k_per_blk); + constexpr auto is_scale_mfma = false; + constexpr index_t KPack = math::max(lcm_AK1_BK1, + MfmaSelector::selected_mfma.k_per_blk); auto blockwise_gemm = BlockwiseGemmXdlops_k0mk1_k0nk1_m0n0m1n1m2m3m4n2_Selector< BlockSize, diff --git a/include/ck/tensor_operation/gpu/grid/gridwise_gemm_xdl_cshuffle_conv_v3.hpp b/include/ck/tensor_operation/gpu/grid/gridwise_gemm_xdl_cshuffle_conv_v3.hpp index d37b3cd38e..4d3ae93659 100644 --- a/include/ck/tensor_operation/gpu/grid/gridwise_gemm_xdl_cshuffle_conv_v3.hpp +++ b/include/ck/tensor_operation/gpu/grid/gridwise_gemm_xdl_cshuffle_conv_v3.hpp @@ -83,13 +83,20 @@ struct GridwiseGemm_xdl_cshuffle_v3 static constexpr bool is_single_rate_mfma = (((is_same::value || is_same::value) && lcm_AK1_BK1 <= 4) || - (is_same::value && lcm_AK1_BK1 <= 8)) + (is_same::value && lcm_AK1_BK1 <= 8) || + ((is_same::value || is_same::value) && + lcm_AK1_BK1 < 32)) ? true : false; + static constexpr auto is_scale_mfma = false; static constexpr index_t KPack = math::max(lcm_AK1_BK1, - MfmaSelector:: - selected_mfma.k_per_blk); + MfmaSelector::selected_mfma.k_per_blk); using ThisThreadBlock = ThisThreadBlock; diff --git a/include/ck/tensor_operation/gpu/grid/gridwise_gemm_xdl_cshuffle_streamk_v3.hpp b/include/ck/tensor_operation/gpu/grid/gridwise_gemm_xdl_cshuffle_streamk_v3.hpp index e5e32a8535..4e72255d31 100644 --- a/include/ck/tensor_operation/gpu/grid/gridwise_gemm_xdl_cshuffle_streamk_v3.hpp +++ b/include/ck/tensor_operation/gpu/grid/gridwise_gemm_xdl_cshuffle_streamk_v3.hpp @@ -144,13 +144,20 @@ struct GridwiseGemm_xdl_cshuffle_streamk_v3 static constexpr bool is_single_rate_mfma = (((is_same::value || is_same::value) && lcm_AK1_BK1 <= 4) || - (is_same::value && lcm_AK1_BK1 <= 8)) + (is_same::value && lcm_AK1_BK1 <= 8) || + ((is_same::value || is_same::value) && + lcm_AK1_BK1 < 32)) ? 
true : false; + static constexpr auto is_scale_mfma = false; static constexpr index_t KPack = math::max(lcm_AK1_BK1, - MfmaSelector:: - selected_mfma.k_per_blk); + MfmaSelector::selected_mfma.k_per_blk); using ThisThreadBlock = ThisThreadBlock; __host__ static auto CalculateMPadded(index_t M) diff --git a/include/ck/tensor_operation/gpu/grid/gridwise_gemm_xdl_cshuffle_v1.hpp b/include/ck/tensor_operation/gpu/grid/gridwise_gemm_xdl_cshuffle_v1.hpp index 240bc464e1..7edcd7270f 100644 --- a/include/ck/tensor_operation/gpu/grid/gridwise_gemm_xdl_cshuffle_v1.hpp +++ b/include/ck/tensor_operation/gpu/grid/gridwise_gemm_xdl_cshuffle_v1.hpp @@ -814,13 +814,19 @@ struct GridwiseGemm_k0mk1_k0nk1_mn_xdl_cshuffle_v1 constexpr bool is_single_rate_mfma = (((is_same::value || is_same::value) && lcm_AK1_BK1 <= 4) || - (is_same::value && lcm_AK1_BK1 <= 8)) + (is_same::value && lcm_AK1_BK1 <= 8) || + ((is_same::value || is_same::value) && + lcm_AK1_BK1 < 32)) ? true : false; - constexpr index_t KPack = math::max( - lcm_AK1_BK1, - MfmaSelector:: - selected_mfma.k_per_blk); + constexpr auto is_scale_mfma = false; + constexpr index_t KPack = math::max(lcm_AK1_BK1, + MfmaSelector::selected_mfma.k_per_blk); auto blockwise_gemm = BlockwiseGemmXdlops_k0mk1_k0nk1_m0n0m1n1m2m3m4n2_Selector< BlockSize, diff --git a/include/ck/tensor_operation/gpu/grid/gridwise_gemm_xdl_cshuffle_v2.hpp b/include/ck/tensor_operation/gpu/grid/gridwise_gemm_xdl_cshuffle_v2.hpp index c7d44e842d..f92268265f 100644 --- a/include/ck/tensor_operation/gpu/grid/gridwise_gemm_xdl_cshuffle_v2.hpp +++ b/include/ck/tensor_operation/gpu/grid/gridwise_gemm_xdl_cshuffle_v2.hpp @@ -873,13 +873,19 @@ struct GridwiseGemm_xdl_cshuffle_v2 constexpr bool is_single_rate_mfma = (((is_same::value || is_same::value) && lcm_AK1_BK1 <= 4) || - (is_same::value && lcm_AK1_BK1 <= 8)) + (is_same::value && lcm_AK1_BK1 <= 8) || + ((is_same::value || is_same::value) && + lcm_AK1_BK1 < 32)) ? 
true : false; - constexpr index_t KPack = math::max( - lcm_AK1_BK1, - MfmaSelector:: - selected_mfma.k_per_blk); + constexpr auto is_scale_mfma = false; + constexpr index_t KPack = math::max(lcm_AK1_BK1, + MfmaSelector::selected_mfma.k_per_blk); // auto blockwise_gemm = BlockwiseGemmXdlops_k0mk1_k0nk1_m0n0m1n1m2m3m4n2_Selector< // BlockSize, diff --git a/include/ck/tensor_operation/gpu/grid/gridwise_gemm_xdl_cshuffle_v3.hpp b/include/ck/tensor_operation/gpu/grid/gridwise_gemm_xdl_cshuffle_v3.hpp index 29150c0688..0dbbc2a5e9 100644 --- a/include/ck/tensor_operation/gpu/grid/gridwise_gemm_xdl_cshuffle_v3.hpp +++ b/include/ck/tensor_operation/gpu/grid/gridwise_gemm_xdl_cshuffle_v3.hpp @@ -255,13 +255,20 @@ struct GridwiseGemm_xdl_cshuffle_v3 static constexpr bool is_single_rate_mfma = (((is_same::value || is_same::value) && lcm_AK1_BK1 <= 4) || - (is_same::value && lcm_AK1_BK1 <= 8)) + (is_same::value && lcm_AK1_BK1 <= 8) || + ((is_same::value || is_same::value) && + lcm_AK1_BK1 < 32)) ? true : false; + static constexpr auto is_scale_mfma = false; static constexpr index_t KPack = math::max(lcm_AK1_BK1, - MfmaSelector:: - selected_mfma.k_per_blk); + MfmaSelector::selected_mfma.k_per_blk); using ThisThreadBlock = ThisThreadBlock; diff --git a/include/ck/tensor_operation/gpu/grid/gridwise_gemm_xdl_cshuffle_v3_b_preshuffle.hpp b/include/ck/tensor_operation/gpu/grid/gridwise_gemm_xdl_cshuffle_v3_b_preshuffle.hpp index a22fc06a50..cfa8bfeb2a 100644 --- a/include/ck/tensor_operation/gpu/grid/gridwise_gemm_xdl_cshuffle_v3_b_preshuffle.hpp +++ b/include/ck/tensor_operation/gpu/grid/gridwise_gemm_xdl_cshuffle_v3_b_preshuffle.hpp @@ -148,13 +148,21 @@ struct GridwiseGemm_xdl_cshuffle_v3_b_preshuffle static constexpr auto AK1Number = Number{}; static constexpr auto BK1Number = Number{}; - using mfma_selector = MfmaSelector; + // Use singal rate mfma instruction for this special case A (f8_t) * B (pk_i4_t) + // See example gemm_xdl_fp8_pk_i4_bpreshuffle_v3 + // TODO: explore 
optimization opportunity by using new mfma instructions on gfx950 + static constexpr auto lcm_AK1_BK1 = math::lcm(AK1Number, BK1Number); + static constexpr bool is_single_rate_mfma = true; + static constexpr auto is_scale_mfma = false; + static constexpr auto mfma = MfmaSelector{}; + static constexpr index_t KPack = math::max(lcm_AK1_BK1, mfma.selected_mfma.k_per_blk); + static constexpr index_t KLane = mfma.GetKPerXdlops() / mfma.GetK1PerXdlops(); - static constexpr index_t KPack = - math::max(math::lcm(AK1Number, BK1Number), mfma_selector::selected_mfma.k_per_blk); - - static constexpr index_t KLane = - mfma_selector::GetKPerXdlops() / mfma_selector::GetK1PerXdlops(); static constexpr index_t KRepeat = KPerBlock / KLane / KPack; static constexpr index_t NLane = NPerXdl; static constexpr index_t NWave = NPerBlock / NPerXdl / NXdlPerWave; diff --git a/include/ck/tensor_operation/gpu/grid/gridwise_gemm_xdl_cshuffle_v3_b_scale.hpp b/include/ck/tensor_operation/gpu/grid/gridwise_gemm_xdl_cshuffle_v3_b_scale.hpp index 7124687d5d..93c1779a80 100644 --- a/include/ck/tensor_operation/gpu/grid/gridwise_gemm_xdl_cshuffle_v3_b_scale.hpp +++ b/include/ck/tensor_operation/gpu/grid/gridwise_gemm_xdl_cshuffle_v3_b_scale.hpp @@ -160,13 +160,20 @@ struct GridwiseGemm_xdl_cshuffle_v3 static constexpr bool is_single_rate_mfma = (((is_same::value || is_same::value) && lcm_AK1_BK1 <= 4) || - (is_same::value && lcm_AK1_BK1 <= 8)) + (is_same::value && lcm_AK1_BK1 <= 8) || + ((is_same::value || is_same::value) && + lcm_AK1_BK1 < 32)) ? 
true : false; + static constexpr auto is_scale_mfma = false; static constexpr index_t KPack = math::max(lcm_AK1_BK1, - MfmaSelector:: - selected_mfma.k_per_blk); + MfmaSelector::selected_mfma.k_per_blk); using ThisThreadBlock = ThisThreadBlock; diff --git a/include/ck/tensor_operation/gpu/grid/gridwise_gemm_xdl_cshuffle_v3_multi_abd.hpp b/include/ck/tensor_operation/gpu/grid/gridwise_gemm_xdl_cshuffle_v3_multi_abd.hpp index ac3e821340..97d0e2a4eb 100644 --- a/include/ck/tensor_operation/gpu/grid/gridwise_gemm_xdl_cshuffle_v3_multi_abd.hpp +++ b/include/ck/tensor_operation/gpu/grid/gridwise_gemm_xdl_cshuffle_v3_multi_abd.hpp @@ -198,13 +198,20 @@ struct GridwiseGemm_xdl_cshuffle_v3 static constexpr bool is_single_rate_mfma = (((is_same::value || is_same::value) && lcm_AK1_BK1 <= 4) || - (is_same::value && lcm_AK1_BK1 <= 8)) + (is_same::value && lcm_AK1_BK1 <= 8) || + ((is_same::value || is_same::value) && + lcm_AK1_BK1 < 32)) ? true : false; + static constexpr auto is_scale_mfma = false; static constexpr index_t KPack = math::max(lcm_AK1_BK1, - MfmaSelector:: - selected_mfma.k_per_blk); + MfmaSelector::selected_mfma.k_per_blk); using ThisThreadBlock = ThisThreadBlock; diff --git a/include/ck/tensor_operation/gpu/grid/gridwise_gemm_xdl_cshuffle_v3_multi_d.hpp b/include/ck/tensor_operation/gpu/grid/gridwise_gemm_xdl_cshuffle_v3_multi_d.hpp index 4163d1d01a..38ce9536ab 100644 --- a/include/ck/tensor_operation/gpu/grid/gridwise_gemm_xdl_cshuffle_v3_multi_d.hpp +++ b/include/ck/tensor_operation/gpu/grid/gridwise_gemm_xdl_cshuffle_v3_multi_d.hpp @@ -183,14 +183,20 @@ struct GridwiseGemmMultiD_xdl_cshuffle_v3 static constexpr bool is_single_rate_mfma = (((is_same::value || is_same::value) && lcm_AK1_BK1 <= 4) || - (is_same::value && lcm_AK1_BK1 <= 8)) + (is_same::value && lcm_AK1_BK1 <= 8) || + ((is_same::value || is_same::value) && + lcm_AK1_BK1 < 32)) ? 
true : false; - + static constexpr auto is_scale_mfma = false; static constexpr index_t KPack = math::max(lcm_AK1_BK1, - MfmaSelector:: - selected_mfma.k_per_blk); + MfmaSelector::selected_mfma.k_per_blk); using ThisThreadBlock = ThisThreadBlock; diff --git a/include/ck/tensor_operation/gpu/grid/gridwise_gemm_xdl_cshuffle_v3_multi_d_ab_scale.hpp b/include/ck/tensor_operation/gpu/grid/gridwise_gemm_xdl_cshuffle_v3_multi_d_ab_scale.hpp index 21812380c2..ef84dd182a 100644 --- a/include/ck/tensor_operation/gpu/grid/gridwise_gemm_xdl_cshuffle_v3_multi_d_ab_scale.hpp +++ b/include/ck/tensor_operation/gpu/grid/gridwise_gemm_xdl_cshuffle_v3_multi_d_ab_scale.hpp @@ -153,14 +153,20 @@ struct GridwiseGemmMultiD_ABScale_xdl_cshuffle_v3 static constexpr bool is_single_rate_mfma = (((is_same::value || is_same::value) && lcm_AK1_BK1 <= 4) || - (is_same::value && lcm_AK1_BK1 <= 8)) + (is_same::value && lcm_AK1_BK1 <= 8) || + ((is_same::value || is_same::value) && + lcm_AK1_BK1 < 32)) ? true : false; - + static constexpr auto is_scale_mfma = false; static constexpr index_t KPack = math::max(lcm_AK1_BK1, - MfmaSelector:: - selected_mfma.k_per_blk); + MfmaSelector::selected_mfma.k_per_blk); using ThisThreadBlock = ThisThreadBlock; diff --git a/include/ck/tensor_operation/gpu/grid/gridwise_gemm_xdl_cshuffle_v3_multi_d_b_preshuffle.hpp b/include/ck/tensor_operation/gpu/grid/gridwise_gemm_xdl_cshuffle_v3_multi_d_b_preshuffle.hpp index c0d9464136..8fb955c561 100644 --- a/include/ck/tensor_operation/gpu/grid/gridwise_gemm_xdl_cshuffle_v3_multi_d_b_preshuffle.hpp +++ b/include/ck/tensor_operation/gpu/grid/gridwise_gemm_xdl_cshuffle_v3_multi_d_b_preshuffle.hpp @@ -164,12 +164,25 @@ struct GridwiseGemmMultiD_xdl_cshuffle_v3_b_preshuffle static constexpr index_t NumDTensor = DsDataType::Size(); - using mfma_selector = MfmaSelector; - static constexpr index_t KPack = - math::max(math::lcm(AK1Number, BK1Number), mfma_selector::selected_mfma.k_per_blk); - static constexpr index_t KGroup = 
mfma_selector::selected_mfma.k_per_blk == 32 ? 2 : 1; - static constexpr index_t KLane = - mfma_selector::GetKPerXdlops() / mfma_selector::GetK1PerXdlops(); + static constexpr auto lcm_AK1_BK1 = math::lcm(AK1Number, BK1Number); + static constexpr bool is_single_rate_mfma = + (((is_same::value || is_same::value) && + lcm_AK1_BK1 <= 4) || + (is_same::value && lcm_AK1_BK1 <= 8) || + ((is_same::value || is_same::value) && + lcm_AK1_BK1 < 32)) + ? true + : false; + static constexpr auto is_scale_mfma = false; + static constexpr auto mfma = MfmaSelector{}; + static constexpr index_t KPack = math::max(lcm_AK1_BK1, mfma.selected_mfma.k_per_blk); + static constexpr index_t KGroup = mfma.selected_mfma.k_per_blk == 32 ? 2 : 1; + static constexpr index_t KLane = mfma.GetKPerXdlops() / mfma.GetK1PerXdlops(); static constexpr index_t KPackPerGroup = KPack / KGroup; static constexpr index_t KRepeat = KPerBlock / KLane / KPackPerGroup; static constexpr index_t NLane = NPerXdl; diff --git a/include/ck/tensor_operation/gpu/grid/gridwise_gemm_xdl_layernorm_cshuffle_v1.hpp b/include/ck/tensor_operation/gpu/grid/gridwise_gemm_xdl_layernorm_cshuffle_v1.hpp index b435fd5d5a..67fb4d651e 100644 --- a/include/ck/tensor_operation/gpu/grid/gridwise_gemm_xdl_layernorm_cshuffle_v1.hpp +++ b/include/ck/tensor_operation/gpu/grid/gridwise_gemm_xdl_layernorm_cshuffle_v1.hpp @@ -493,13 +493,16 @@ struct GridwiseGemmLayernorm_k0mk1_k0nk1_mn_xdl_cshuffle_v1 constexpr bool is_single_rate_mfma = (((is_same::value || is_same::value) && lcm_AK1_BK1 <= 4) || - (is_same::value && lcm_AK1_BK1 <= 8)) + (is_same::value && lcm_AK1_BK1 <= 8) || + ((is_same::value || is_same::value) && + lcm_AK1_BK1 < 32)) ? 
true : false; - constexpr index_t KPack = math::max( + constexpr auto is_scale_mfma = false; + constexpr index_t KPack = math::max( lcm_AK1_BK1, - MfmaSelector::selected_mfma - .k_per_blk); + MfmaSelector:: + selected_mfma.k_per_blk); auto blockwise_gemm = BlockwiseGemmXdlops_k0mk1_k0nk1_m0n0m1n1m2m3m4n2_Selector< BlockSize, diff --git a/include/ck/tensor_operation/gpu/grid/gridwise_gemm_xdl_waveletmodel_cshuffle.hpp b/include/ck/tensor_operation/gpu/grid/gridwise_gemm_xdl_waveletmodel_cshuffle.hpp index ad65e75ef9..50363d832e 100644 --- a/include/ck/tensor_operation/gpu/grid/gridwise_gemm_xdl_waveletmodel_cshuffle.hpp +++ b/include/ck/tensor_operation/gpu/grid/gridwise_gemm_xdl_waveletmodel_cshuffle.hpp @@ -491,13 +491,20 @@ struct GridwiseGemm_k0mk1_k0nk1_mn_xdl_waveletmodel_cshuffle constexpr bool is_single_rate_mfma = (((is_same::value || is_same::value) && lcm_AK1_BK1 <= 4) || - (is_same::value && lcm_AK1_BK1 <= 8)) + (is_same::value && lcm_AK1_BK1 <= 8) || + ((is_same::value || is_same::value) && + lcm_AK1_BK1 < 32)) ? 
true : false; - constexpr index_t KPack = math::max( - lcm_AK1_BK1, - MfmaSelector:: - selected_mfma.k_per_blk); + constexpr auto is_scale_mfma = false; + constexpr index_t KPack = + math::max(lcm_AK1_BK1, + MfmaSelector::selected_mfma.k_per_blk); auto blockwise_gemm = BlockwiseGemmXdlops_k0mk1_k0nk1_m0n0m1n1m2m3m4n2_v1< TileMathThreadGroupSize, diff --git a/include/ck/tensor_operation/gpu/grid/gridwise_gemm_xdlops_bwd_weight.hpp b/include/ck/tensor_operation/gpu/grid/gridwise_gemm_xdlops_bwd_weight.hpp index 168c553180..b7947309e4 100644 --- a/include/ck/tensor_operation/gpu/grid/gridwise_gemm_xdlops_bwd_weight.hpp +++ b/include/ck/tensor_operation/gpu/grid/gridwise_gemm_xdlops_bwd_weight.hpp @@ -744,14 +744,19 @@ struct GridwiseGemm_bk0mk1_bk0nk1_mn_xdlops_bwd_weight constexpr bool is_single_rate_mfma = (((is_same::value || is_same::value) && K1 <= 4) || - (is_same::value && K1 <= 8)) + (is_same::value && K1 <= 8) || + ((is_same::value || is_same::value) && + K1 < 32)) ? true : false; - - constexpr index_t KPack = math::max( - K1, - MfmaSelector:: - selected_mfma.k_per_blk); + constexpr auto is_scale_mfma = false; + constexpr index_t KPack = math::max(K1, + MfmaSelector::selected_mfma.k_per_blk); auto blockwise_gemm = BlockwiseGemmXdlops_k0mk1_k0nk1_m0n0m1n1m2m3m4n2_v1::value || is_same::value) && lcm_AK1_BK1 <= 4) || - (is_same::value && lcm_AK1_BK1 <= 8)) + (is_same::value && lcm_AK1_BK1 <= 8) || + ((is_same::value || is_same::value) && + lcm_AK1_BK1 < 32)) ? 
true : false; - constexpr index_t k_pack = math::max( + constexpr auto is_scale_mfma = false; + constexpr index_t k_pack = math::max( lcm_AK1_BK1, - MfmaSelector::selected_mfma - .k_per_blk); + MfmaSelector:: + selected_mfma.k_per_blk); auto blockwise_gemm = BlockwiseGemmXdlops_k0mk1_k0nk1_m0n0m1n1m2m3m4n2_v1; - static constexpr index_t KPack = - math::max(math::lcm(AK1Number, BK1Number), mfma_selector::selected_mfma.k_per_blk); - static constexpr index_t KLane = - mfma_selector::GetKPerXdlops() / mfma_selector::GetK1PerXdlops(); - static constexpr index_t KRepeat = KPerBlock / KLane / KPack; - static constexpr index_t NLane = NPerXdl; - static constexpr index_t NWave = NPerBlock / NPerXdl / NXdlPerWave; + static constexpr auto lcm_AK1_BK1 = math::lcm(AK1Number, BK1Number); + static constexpr bool is_single_rate_mfma = + (((is_same::value || is_same::value) && + lcm_AK1_BK1 <= 4) || + (is_same::value && lcm_AK1_BK1 <= 8) || + ((is_same::value || is_same::value) && + lcm_AK1_BK1 < 32)) + ? 
true + : false; + static constexpr auto is_scale_mfma = false; + static constexpr auto mfma = MfmaSelector{}; + static constexpr index_t KPack = math::max(lcm_AK1_BK1, mfma.selected_mfma.k_per_blk); + static constexpr index_t KLane = mfma.GetKPerXdlops() / mfma.GetK1PerXdlops(); + static constexpr index_t KRepeat = KPerBlock / KLane / KPack; + static constexpr index_t NLane = NPerXdl; + static constexpr index_t NWave = NPerBlock / NPerXdl / NXdlPerWave; // static constexpr index_t NumTokens = 1; static constexpr index_t SortedTileSize = MPerBlock; diff --git a/include/ck/tensor_operation/gpu/warp/xdlops_gemm.hpp b/include/ck/tensor_operation/gpu/warp/xdlops_gemm.hpp index 06268f3cfb..b825d7ab69 100644 --- a/include/ck/tensor_operation/gpu/warp/xdlops_gemm.hpp +++ b/include/ck/tensor_operation/gpu/warp/xdlops_gemm.hpp @@ -1117,12 +1117,31 @@ struct MfmaSelector #endif } + // Use singal rate mfma instruction for this special case A (f8_t) * B (pk_i4_t) + // See example gemm_xdl_fp8_pk_i4_bpreshuffle_v3 + // TODO: explore optimization opportunity by using new mfma instructions on gfx950 template <> - constexpr auto GetMfma() + constexpr auto GetMfma() { return MfmaInstr::mfma_f32_32x32x16f8f8; } + template <> + constexpr auto GetMfma() + { + return MfmaInstr::mfma_f32_32x32x16f8f8; + } + + template <> + constexpr auto GetMfma() + { +#if defined(__gfx950__) + return MfmaInstr::mfma_f32_32x32x64f8f6f4; +#else + return MfmaInstr::mfma_f32_32x32x16f8f8; +#endif + } + template <> constexpr auto GetMfma() { @@ -1136,11 +1155,21 @@ struct MfmaSelector } template <> - constexpr auto GetMfma() + constexpr auto GetMfma() { return MfmaInstr::mfma_f32_16x16x32f8f8; } + template <> + constexpr auto GetMfma() + { +#if defined(__gfx950__) + return MfmaInstr::mfma_f32_16x16x128f8f6f4; +#else + return MfmaInstr::mfma_f32_16x16x32f8f8; +#endif + } + template <> constexpr auto GetMfma() { @@ -1166,41 +1195,101 @@ struct MfmaSelector } template <> - constexpr auto GetMfma() + constexpr 
auto GetMfma() { return MfmaInstr::mfma_f32_32x32x16bf8bf8; } template <> - constexpr auto GetMfma() + constexpr auto GetMfma() + { +#if defined(__gfx950__) + return MfmaInstr::mfma_f32_32x32x64f8f6f4; +#else + return MfmaInstr::mfma_f32_32x32x16bf8bf8; +#endif + } + + template <> + constexpr auto GetMfma() { return MfmaInstr::mfma_f32_16x16x32bf8bf8; } template <> - constexpr auto GetMfma() + constexpr auto GetMfma() + { +#if defined(__gfx950__) + return MfmaInstr::mfma_f32_16x16x128f8f6f4; +#else + return MfmaInstr::mfma_f32_16x16x32bf8bf8; +#endif + } + + template <> + constexpr auto GetMfma() { return MfmaInstr::mfma_f32_32x32x16f8bf8; } template <> - constexpr auto GetMfma() + constexpr auto GetMfma() + { +#if defined(__gfx950__) + return MfmaInstr::mfma_f32_32x32x64f8f6f4; +#else + return MfmaInstr::mfma_f32_32x32x16f8bf8; +#endif + } + + template <> + constexpr auto GetMfma() { return MfmaInstr::mfma_f32_16x16x32f8bf8; } template <> - constexpr auto GetMfma() + constexpr auto GetMfma() + { +#if defined(__gfx950__) + return MfmaInstr::mfma_f32_16x16x128f8f6f4; +#else + return MfmaInstr::mfma_f32_16x16x32f8bf8; +#endif + } + + template <> + constexpr auto GetMfma() { return MfmaInstr::mfma_f32_32x32x16bf8f8; } template <> - constexpr auto GetMfma() + constexpr auto GetMfma() + { +#if defined(__gfx950__) + return MfmaInstr::mfma_f32_32x32x64f8f6f4; +#else + return MfmaInstr::mfma_f32_32x32x16bf8f8; +#endif + } + + template <> + constexpr auto GetMfma() { return MfmaInstr::mfma_f32_16x16x32bf8f8; } + template <> + constexpr auto GetMfma() + { +#if defined(__gfx950__) + return MfmaInstr::mfma_f32_16x16x128f8f6f4; +#else + return MfmaInstr::mfma_f32_16x16x32bf8f8; +#endif + } + static constexpr auto selected_mfma = mfma_type::value || - is_same::value) && - KPack <= 4) || - (is_same::value && KPack <= 8)) - ? 
true - : false, - is_scale_mfma > {}; + // Falls back to single rate instruction on gfx950 if KPack is single rate; no change on gfx942- + // when base_type is either f8_t or bf8_t, additional_type will always be either f8_t or bf8_t, + // except Use single rate mfma instruction for this special case A (f8_t) * B (pk_i4_t) + static constexpr bool is_single_rate_mfma = + (((is_same::value || is_same::value) && + KPack <= 4) || + (is_same::value && KPack <= 8) || + ((is_same::value || is_same::value) && KPack < 32) || + is_same::value) + ? true + : false; + static constexpr auto mfma = MfmaSelector{}; static constexpr auto mfma_instr = mfma.selected_mfma; diff --git a/include/ck/utility/amd_xdlops.hpp b/include/ck/utility/amd_xdlops.hpp index ad48389625..ed3354dfb5 100644 --- a/include/ck/utility/amd_xdlops.hpp +++ b/include/ck/utility/amd_xdlops.hpp @@ -533,6 +533,50 @@ struct intrin_mfma_f32_32x32x64f8f6f4<32, 32> #endif } + template + __device__ static void Run(const bf8x32_t& reg_a, const f8x32_t& reg_b, FloatC& reg_c) + { +#if defined(__gfx950__) + reg_c.template AsType()(Number<0>{}) = + __builtin_amdgcn_mfma_scale_f32_32x32x64_f8f6f4( + reg_a, + reg_b, + reg_c.template AsType()[Number<0>{}], + 1, // cbsz {0 FP8 E4M3; 1 FP8 E5M2; 2 FP6 E2M3; 3 FP6 E3M2; 4 FP4 E2M1} + 0, // blgp + 0, + 0, + 0, + 0); +#else + ignore = reg_a; + ignore = reg_b; + ignore = reg_c; +#endif + } + + template + __device__ static void Run(const f8x32_t& reg_a, const bf8x32_t& reg_b, FloatC& reg_c) + { +#if defined(__gfx950__) + reg_c.template AsType()(Number<0>{}) = + __builtin_amdgcn_mfma_scale_f32_32x32x64_f8f6f4( + reg_a, + reg_b, + reg_c.template AsType()[Number<0>{}], + 0, // cbsz {0 FP8 E4M3; 1 FP8 E5M2; 2 FP6 E2M3; 3 FP6 E3M2; 4 FP4 E2M1} + 1, // blgp + 0, + 0, + 0, + 0); +#else + ignore = reg_a; + ignore = reg_b; + ignore = reg_c; +#endif + } + template __device__ static void Run(const f4x32_t& reg_a, const f4x32_t& reg_b, FloatC& reg_c) { @@ -1118,6 +1162,52 @@ struct 
intrin_mfma_f32_16x16x128f8f6f4<16, 16> #endif } + template + __device__ static void Run(const bf8x32_t& reg_a, const f8x32_t& reg_b, FloatC& reg_c) + { +#if defined(__gfx950__) + // https://github.com/ROCm/llvm-project/blob/656552edc693e2bb4abc9258399c39d190fce2b3/llvm/test/Verifier/AMDGPU/mfma-scale.ll#L10 + reg_c.template AsType()(Number<0>{}) = + __builtin_amdgcn_mfma_scale_f32_16x16x128_f8f6f4( + reg_a, + reg_b, + reg_c.template AsType()[Number<0>{}], + 1, // cbsz {0 FP8 E4M3; 1 FP8 E5M2; 2 FP6 E2M3; 3 FP6 E3M2; 4 FP4 E2M1} + 0, // blgp + 0, + 0, + 0, + 0); +#else + ignore = reg_a; + ignore = reg_b; + ignore = reg_c; +#endif + } + + template + __device__ static void Run(const f8x32_t& reg_a, const bf8x32_t& reg_b, FloatC& reg_c) + { +#if defined(__gfx950__) + // https://github.com/ROCm/llvm-project/blob/656552edc693e2bb4abc9258399c39d190fce2b3/llvm/test/Verifier/AMDGPU/mfma-scale.ll#L10 + reg_c.template AsType()(Number<0>{}) = + __builtin_amdgcn_mfma_scale_f32_16x16x128_f8f6f4( + reg_a, + reg_b, + reg_c.template AsType()[Number<0>{}], + 0, // cbsz {0 FP8 E4M3; 1 FP8 E5M2; 2 FP6 E2M3; 3 FP6 E3M2; 4 FP4 E2M1} + 1, // blgp + 0, + 0, + 0, + 0); +#else + ignore = reg_a; + ignore = reg_b; + ignore = reg_c; +#endif + } + template __device__ static void Run(const f4x32_t& reg_a, const f4x32_t& reg_b, FloatC& reg_c) { From 0970f22221917d559d0bd72c1065303f08a70450 Mon Sep 17 00:00:00 2001 From: Jan Patrick Lehr Date: Tue, 20 May 2025 02:30:15 +0200 Subject: [PATCH 131/443] [CMake] Disable newly added compiler warning -Wnrvo (#2210) Recently a new warning was added to Clang to warn when no copy-elision on return happens. That prevents our CK build. This disables the warning. 
--- CMakeLists.txt | 3 +++ 1 file changed, 3 insertions(+) diff --git a/CMakeLists.txt b/CMakeLists.txt index 4e12462a41..13606975c0 100644 --- a/CMakeLists.txt +++ b/CMakeLists.txt @@ -94,6 +94,9 @@ add_compile_options(-Wno-pass-failed) add_compile_options(-Wno-switch-default) add_compile_options(-Wno-unique-object-duplication) +# Recent change in compiler makes this warning ON by default, which led to compile errors. +add_compile_options(-Wno-nrvo) + if(NOT DISABLE_DL_KERNELS) add_definitions(-DDL_KERNELS) set(DL_KERNELS "ON") From c4929225f60f56e3a9320547dcfdff30a77f0aa3 Mon Sep 17 00:00:00 2001 From: Aviral Goel Date: Mon, 19 May 2025 19:31:04 -0500 Subject: [PATCH 132/443] remove debug statements from CMakeLists (#2204) --- example/CMakeLists.txt | 2 -- test/CMakeLists.txt | 4 ---- 2 files changed, 6 deletions(-) diff --git a/example/CMakeLists.txt b/example/CMakeLists.txt index 996a543ecc..9c30a2e255 100644 --- a/example/CMakeLists.txt +++ b/example/CMakeLists.txt @@ -135,11 +135,9 @@ function(add_example_executable EXAMPLE_NAME FILE_NAME) endif() #message("add_example returns ${result}") if(result EQUAL 0 AND NOT "${EXAMPLE_NAME}" IN_LIST REGRESSION_EXAMPLES) - #message("adding to SMOKE EXAMPLE FILTER ${EXAMPLE_NAME}") set_tests_properties(${EXAMPLE_NAME} PROPERTIES LABELS "SMOKE_TEST") add_dependencies(smoke ${EXAMPLE_NAME}) elseif(result EQUAL 0 AND "${EXAMPLE_NAME}" IN_LIST REGRESSION_EXAMPLES) - #message("Adding to REGRESSION EXAMPLE FILTER ${EXAMPLE_NAME}") set_tests_properties(${EXAMPLE_NAME} PROPERTIES LABELS "REGRESSION_TEST") add_dependencies(regression ${EXAMPLE_NAME}) endif() diff --git a/test/CMakeLists.txt b/test/CMakeLists.txt index 69ffb94488..5ea61d2dfc 100755 --- a/test/CMakeLists.txt +++ b/test/CMakeLists.txt @@ -121,11 +121,9 @@ function(add_test_executable TEST_NAME) #message("add_test returns ${result}") set(result ${result} PARENT_SCOPE) if(result EQUAL 0 AND NOT "${TEST_NAME}" IN_LIST REGRESSION_TESTS) - message("adding to SMOKE TEST 
FILTER ${TEST_NAME}") set_tests_properties(${TEST_NAME} PROPERTIES LABELS "SMOKE_TEST") add_dependencies(smoke ${TEST_NAME}) elseif(result EQUAL 0 AND "${TEST_NAME}" IN_LIST REGRESSION_TESTS) - message("Adding to REGRESSION TEST FILTER ${TEST_NAME}") set_tests_properties(${TEST_NAME} PROPERTIES LABELS "REGRESSION_TEST") add_dependencies(regression ${TEST_NAME}) endif() @@ -222,11 +220,9 @@ function(add_gtest_executable TEST_NAME) #message("add_gtest returns ${result}") set(result ${result} PARENT_SCOPE) if(result EQUAL 0 AND NOT "${TEST_NAME}" IN_LIST REGRESSION_TESTS) - #message("adding to smoke test FILTER ${TEST_NAME}") set_tests_properties(${TEST_NAME} PROPERTIES LABELS "SMOKE_TEST") add_dependencies(smoke ${TEST_NAME}) elseif(result EQUAL 0 AND "${TEST_NAME}" IN_LIST REGRESSION_TESTS) - #message("Adding to REGRESSION TEST FILTER ${TEST_NAME}") set_tests_properties(${TEST_NAME} PROPERTIES LABELS "REGRESSION_TEST") add_dependencies(regression ${TEST_NAME}) endif() From d1e6f0982d04d6b356f001da731dd5e315f78812 Mon Sep 17 00:00:00 2001 From: Sami Remes Date: Tue, 20 May 2025 17:18:57 +0300 Subject: [PATCH 133/443] [CK_TILE] Grouped GEMM tile loop (#2146) * Add trait to use a persistent kernel and split the entrypoints in grouped gemm * Some helper functions for persistent kernel case * Get max occupancy grid using device properties * Implement tile loop in main entry point to grouped gemm * Enable GridSize() on device * Handle offset tile index using real current block index * Add persistent kernel choice to grouped gemm example * Use a for-loop for iterating over the group * Reduce VGPR spills by early-exit * Enable persistent kernel choice in grouped_gemm example * Add persistent kernel option to grouped_gemm test * Fix formatting with remod.py * Remove GridUpdateBlocks as blocks are now iteratively computed * Add comment about VGPR spilling * Fix formatting * Use CK_TILE_HOST instead of __host__ * Enable all Row/Col combinations in grouped gemm unit test * Add 
some KBatch=2 cases to grouped gemm tests * Fix SplitK for grouped gemm * Enable pipeline hotloop/tailnumber selection in-kernel for grouped gemm * Add type traits * Split examples to regular and tileloop * Formatting * Use hipExtStreamGetCUMask to get current active CUs for the given stream * Align test and example kernel config, and disable validation for splitk repeats * Remove debug options from CMakeLists.txt * Separate the code paths for persistent/non-persistent in test * Fix formatting * Address review comments --------- Co-authored-by: Adam Osewski <19374865+aosewski@users.noreply.github.com> --- .../ck_tile/17_grouped_gemm/CMakeLists.txt | 2 +- .../ck_tile/17_grouped_gemm/grouped_gemm.cpp | 132 ++++----- .../ck_tile/17_grouped_gemm/grouped_gemm.hpp | 17 +- .../17_grouped_gemm/grouped_gemm_tileloop.cpp | 174 ++++++++++++ .../run_grouped_gemm_example.inc | 98 +++++-- include/ck_tile/core/utility/type_traits.hpp | 11 + include/ck_tile/host/stream_utils.hpp | 45 ++++ .../ops/gemm/kernel/gemm_tile_partitioner.hpp | 18 +- .../ops/gemm/kernel/grouped_gemm_kernel.hpp | 252 ++++++++++++++++-- .../gemm_pipeline_ag_bg_cr_comp_v3.hpp | 12 +- ...ine_agmem_bgmem_creg_v1_default_policy.hpp | 0 .../ops/gemm/pipeline/tile_gemm_traits.hpp | 24 +- .../grouped_gemm/test_grouped_gemm.cpp | 30 ++- .../test_grouped_gemm_ut_cases.inc | 30 ++- .../grouped_gemm/test_grouped_gemm_util.hpp | 209 +++++++++++++-- 15 files changed, 908 insertions(+), 146 deletions(-) create mode 100644 example/ck_tile/17_grouped_gemm/grouped_gemm_tileloop.cpp create mode 100644 include/ck_tile/host/stream_utils.hpp mode change 100755 => 100644 include/ck_tile/ops/gemm/pipeline/gemm_pipeline_agmem_bgmem_creg_v1_default_policy.hpp diff --git a/example/ck_tile/17_grouped_gemm/CMakeLists.txt b/example/ck_tile/17_grouped_gemm/CMakeLists.txt index d34013dd6c..79df4e624d 100644 --- a/example/ck_tile/17_grouped_gemm/CMakeLists.txt +++ b/example/ck_tile/17_grouped_gemm/CMakeLists.txt @@ -1,2 +1,2 @@ 
add_executable(tile_example_grouped_gemm EXCLUDE_FROM_ALL grouped_gemm.cpp) - +add_executable(tile_example_grouped_gemm_tileloop EXCLUDE_FROM_ALL grouped_gemm_tileloop.cpp) diff --git a/example/ck_tile/17_grouped_gemm/grouped_gemm.cpp b/example/ck_tile/17_grouped_gemm/grouped_gemm.cpp index 9b134ff779..61193e2e29 100644 --- a/example/ck_tile/17_grouped_gemm/grouped_gemm.cpp +++ b/example/ck_tile/17_grouped_gemm/grouped_gemm.cpp @@ -16,15 +16,10 @@ #include "ck_tile/host.hpp" #include "grouped_gemm.hpp" -std::size_t get_workspace_size(const std::vector& gemm_descs) -{ - return gemm_descs.size() * sizeof(ck_tile::GemmTransKernelArg); -} - template float grouped_gemm(const std::vector& gemm_descs, const ck_tile::stream_config& s, - void* p_workspace_) + void* kargs_ptr) { #if(CK_TILE_PIPELINE_DEFAULT == CK_TILE_PIPELINE_MEMORY) // Memory friendly for Interwave scheduler @@ -114,70 +109,76 @@ float grouped_gemm(const std::vector& gemm_descs, float ave_time{0}; - const auto Run = - [&](const auto has_hot_loop_, const auto tail_number_, const auto memory_operation_) { - constexpr bool has_hot_loop_v = has_hot_loop_.value; - constexpr auto tail_number_v = tail_number_.value; - constexpr auto scheduler = GEMM_PIPELINE_SCHEDULER; - constexpr auto memory_operation = memory_operation_.value; + const auto Run = [&](const auto has_hot_loop_, + const auto tail_number_, + const auto memory_operation_) { + constexpr bool has_hot_loop_v = has_hot_loop_.value; + constexpr auto tail_number_v = tail_number_.value; + constexpr auto scheduler = GEMM_PIPELINE_SCHEDULER; + constexpr auto memory_operation = memory_operation_.value; - using UniversalGemmProblem = ck_tile::UniversalGemmPipelineProblem; + using UniversalGemmProblem = ck_tile::UniversalGemmPipelineProblem; - using GemmPipeline = GEMM_PIPELINE; - using GemmEpilogue = ck_tile::CShuffleEpilogue< - ck_tile::CShuffleEpilogueProblem>; - using Kernel = ck_tile::GroupedGemmKernel; - auto kargs = Kernel::MakeKargs(gemm_descs); + using 
GemmPipeline = GEMM_PIPELINE; + using GemmEpilogue = ck_tile::CShuffleEpilogue< + ck_tile::CShuffleEpilogueProblem>; + using Kernel = ck_tile::GroupedGemmKernel; + auto kargs = Kernel::MakeKargs(gemm_descs); + if(!Kernel::IsSupportedArgument(kargs)) + { + throw std::runtime_error("Kernel arguments not supported!"); + } - const dim3 grids = Kernel::GridSize(gemm_descs); - constexpr dim3 blocks = Kernel::BlockSize(); + constexpr dim3 blocks = Kernel::BlockSize(); + const dim3 grids = Kernel::GridSize(gemm_descs); - ck_tile::hip_check_error(hipMemcpyWithStream(p_workspace_, - kargs.data(), - get_workspace_size(gemm_descs), - hipMemcpyHostToDevice, - s.stream_id_)); + HIP_CHECK_ERROR(hipMemcpyWithStream(kargs_ptr, + kargs.data(), + get_workspace_size(gemm_descs), + hipMemcpyHostToDevice, + s.stream_id_)); - if(s.log_level_ > 0) - { - std::cout << "Launching kernel: " << Kernel::GetName() << " with args:" - << " grid: {" << grids.x << ", " << grids.y << ", " << grids.z << "}" - << ", blocks: {" << blocks.x << ", " << blocks.y << ", " << blocks.z - << "}" << std::endl; - } + if(s.log_level_ > 0) + { + std::cout << "Launching kernel: " << Kernel::GetName() << " with args:" + << " grid: {" << grids.x << ", " << grids.y << ", " << grids.z << "}" + << ", blocks: {" << blocks.x << ", " << blocks.y << ", " << blocks.z << "}" + << std::endl; + } - ave_time = ck_tile::launch_kernel( - s, - ck_tile::make_kernel( - Kernel{}, - grids, - blocks, - 0, - ck_tile::cast_pointer_to_constant_address_space(p_workspace_), - gemm_descs.size())); - return ave_time; - }; + ave_time = + ck_tile::launch_kernel(s, + ck_tile::make_kernel( + Kernel{}, + grids, + blocks, + 0, + ck_tile::cast_pointer_to_constant_address_space(kargs_ptr), + gemm_descs.size())); + + return ave_time; + }; const auto RunSplitk = [&](const auto has_hot_loop_, const auto tail_number_) { if(gemm_descs[0].k_batch == 1) @@ -317,4 +318,5 @@ float grouped_gemm(const std::vector& gemm_descs, #include 
"run_grouped_gemm_example.inc" -int main(int argc, char* argv[]) { return !run_grouped_gemm_example(argc, argv); } +constexpr bool Persistent = false; +int main(int argc, char* argv[]) { return !run_grouped_gemm_example(argc, argv); } diff --git a/example/ck_tile/17_grouped_gemm/grouped_gemm.hpp b/example/ck_tile/17_grouped_gemm/grouped_gemm.hpp index 4fec329c2f..77db182c72 100644 --- a/example/ck_tile/17_grouped_gemm/grouped_gemm.hpp +++ b/example/ck_tile/17_grouped_gemm/grouped_gemm.hpp @@ -70,14 +70,25 @@ auto create_args(int argc, char* argv[]) .insert("validate", "1", "0. No validation, 1. Validation on CPU.") .insert("warmup", "10", "number of iterations before benchmark the kernel.") .insert("repeat", "100", "number of iterations to benchmark the kernel.") - .insert("group_count", "8", "group count."); + .insert("group_count", "8", "group count.") + .insert("kbatch", "1", "kbatch for SplitK"); bool result = arg_parser.parse(argc, argv); return std::make_tuple(result, arg_parser); } -std::size_t get_workspace_size(const std::vector& gemm_descs); +inline std::size_t get_workspace_size(const std::vector& gemm_descs) +{ + return gemm_descs.size() * sizeof(ck_tile::GemmTransKernelArg); +} +template float grouped_gemm(const std::vector& gemm_descs, const ck_tile::stream_config& s, - void* p_workspace_); + void* kargs_ptr); + +template +float grouped_gemm_tileloop(const ck_tile::stream_config& s, + const ck_tile::index_t num_groups, + void* kargs_ptr, + bool splitk = false); diff --git a/example/ck_tile/17_grouped_gemm/grouped_gemm_tileloop.cpp b/example/ck_tile/17_grouped_gemm/grouped_gemm_tileloop.cpp new file mode 100644 index 0000000000..5c0cb92683 --- /dev/null +++ b/example/ck_tile/17_grouped_gemm/grouped_gemm_tileloop.cpp @@ -0,0 +1,174 @@ +// SPDX-License-Identifier: MIT +// Copyright (c) 2025, Advanced Micro Devices, Inc. All rights reserved. 
+ +#include + +#include +#include +#include +#include +#include +#include + +#include "ck_tile/core.hpp" +#include "ck_tile/ops/epilogue.hpp" +#include "ck_tile/ops/gemm.hpp" +#include "ck_tile/host.hpp" +#include "grouped_gemm.hpp" + +template +float grouped_gemm_tileloop(const ck_tile::stream_config& s, + const ck_tile::index_t num_groups, + void* kargs_ptr, + bool splitk) +{ +#if(CK_TILE_PIPELINE_DEFAULT == CK_TILE_PIPELINE_MEMORY) + // Memory friendly for Interwave scheduler + constexpr ck_tile::index_t M_Tile = 128; + constexpr ck_tile::index_t N_Tile = 32; + constexpr ck_tile::index_t K_Tile = 64; + + constexpr ck_tile::index_t M_Warp = 4; + constexpr ck_tile::index_t N_Warp = 1; + constexpr ck_tile::index_t K_Warp = 1; + + constexpr ck_tile::index_t M_Warp_Tile = 32; + constexpr ck_tile::index_t N_Warp_Tile = 32; + constexpr ck_tile::index_t K_Warp_Tile = 8; + + constexpr bool DoubleSmemBuffer = false; +#endif +#if(CK_TILE_PIPELINE_DEFAULT == CK_TILE_PIPELINE_COMPUTE_V3) + // Compute friendly for Intrawave scheduler + constexpr ck_tile::index_t M_Tile = 256; + constexpr ck_tile::index_t N_Tile = 256; + constexpr ck_tile::index_t K_Tile = 64; + + constexpr ck_tile::index_t M_Warp = 2; + constexpr ck_tile::index_t N_Warp = 2; + constexpr ck_tile::index_t K_Warp = 1; + + constexpr ck_tile::index_t M_Warp_Tile = 32; + constexpr ck_tile::index_t N_Warp_Tile = 32; + constexpr ck_tile::index_t K_Warp_Tile = 16; + + constexpr bool DoubleSmemBuffer = false; +#elif(CK_TILE_PIPELINE_DEFAULT == CK_TILE_PIPELINE_COMPUTE_V4) + // Compute friendly for Intrawave scheduler + // Using the ping pong reader in the lds level + constexpr ck_tile::index_t M_Tile = 256; + constexpr ck_tile::index_t N_Tile = 256; + constexpr ck_tile::index_t K_Tile = 32; + + constexpr ck_tile::index_t M_Warp = 2; + constexpr ck_tile::index_t N_Warp = 2; + constexpr ck_tile::index_t K_Warp = 1; + + constexpr ck_tile::index_t M_Warp_Tile = 32; + constexpr ck_tile::index_t N_Warp_Tile = 32; + constexpr 
ck_tile::index_t K_Warp_Tile = 16; + + constexpr bool DoubleSmemBuffer = true; +#endif + + constexpr bool kPadM = false; + constexpr bool kPadN = false; + constexpr bool kPadK = false; + + constexpr int kBlockPerCu = 1; + constexpr ck_tile::index_t TileParitionerGroupNum = 8; + constexpr ck_tile::index_t TileParitionerM01 = 4; + + using GemmShape = + ck_tile::TileGemmShape, + ck_tile::sequence, + ck_tile::sequence>; + using TilePartitioner = ck_tile:: + GemmSpatiallyLocalTilePartitioner; + + using Traits = ck_tile::TileGemmTraits; + using GemmUniversalTraits = ck_tile::PersistentTileGemmUniversalTraits; + using GemmPipelineProblem = + ck_tile::GemmPipelineProblem; + + float ave_time{0}; + + const auto Run = [&](const auto memory_operation_) { + constexpr auto scheduler = GEMM_PIPELINE_SCHEDULER; + constexpr auto memory_operation = memory_operation_.value; + + // We create the GEMM pipeline without specifying hotloop or tailnumber. + // These are automatically run inside the kernel based on the given input data. 
+ using UniversalGemmProblem = ck_tile::UniversalGemmPipelineProblem; + + using GemmPipeline = GEMM_PIPELINE; + using GemmEpilogue = ck_tile::CShuffleEpilogue< + ck_tile::CShuffleEpilogueProblem>; + using Kernel = ck_tile::GroupedGemmKernel; + constexpr dim3 blocks = Kernel::BlockSize(); + const dim3 grids = Kernel::MaxOccupancyGridSize(s); + + if(s.log_level_ > 0) + { + std::cout << "Launching kernel: " << Kernel::GetName() << " with args:" + << " grid: {" << grids.x << ", " << grids.y << ", " << grids.z << "}" + << ", blocks: {" << blocks.x << ", " << blocks.y << ", " << blocks.z << "}" + << std::endl; + } + + ave_time = + ck_tile::launch_kernel(s, + ck_tile::make_kernel( + Kernel{}, + grids, + blocks, + 0, + ck_tile::cast_pointer_to_constant_address_space(kargs_ptr), + num_groups)); + + return ave_time; + }; + + if(!splitk) + { + Run(ck_tile::integral_constant{}); + } + else + { + Run(ck_tile::integral_constant{}); + } + + return ave_time; +} + +#include "run_grouped_gemm_example.inc" + +constexpr bool Persistent = true; +int main(int argc, char* argv[]) { return !run_grouped_gemm_example(argc, argv); } diff --git a/example/ck_tile/17_grouped_gemm/run_grouped_gemm_example.inc b/example/ck_tile/17_grouped_gemm/run_grouped_gemm_example.inc index f068510d26..a01d8178cc 100644 --- a/example/ck_tile/17_grouped_gemm/run_grouped_gemm_example.inc +++ b/example/ck_tile/17_grouped_gemm/run_grouped_gemm_example.inc @@ -30,20 +30,60 @@ auto calculate_rtol_atol(const ck_tile::index_t K, return ck_tile::make_tuple(std::max(rtol, rtol_split_k), std::max(atol, atol_split_k)); } -template +template float invoke_gemm(int n_warmup, int n_repeat, int group_count, const std::vector& args) { - + // Workspace memory allocated to hold the gemm descriptions. 
ck_tile::DeviceMem gemm_workspace; gemm_workspace.Realloc(get_workspace_size(args)); - float ave_time = grouped_gemm( - args, - ck_tile::stream_config{nullptr, true, 1, n_warmup, n_repeat}, - gemm_workspace.GetDeviceBuffer()); + float ave_time = 0; + if constexpr(!Persistent) + { + // Regular version of grouped gemm + ave_time = grouped_gemm( + args, + ck_tile::stream_config{nullptr, true, 1, n_warmup, n_repeat}, + gemm_workspace.GetDeviceBuffer()); + } + else + { + // NOTE: With the persistent TileLoop kernel, we do not necessarily need to have + // the gemm problems known on the host. Instead, we can just pass the pointer + // to the kernel and let the workgroups figure out which tiles to work on. + // This is useful when the gemm problems are generated dynamically. + // In this example however, we generate the `kargs` using the known gemm_descs, + // and copy the gemm descriptions to the device memory. + // The contents of the memory pointed to by `kargs_ptr` pointer could be + // written by e.g. another kernel from earlier stage. 
+ std::vector kargs; + void* kargs_ptr = gemm_workspace.GetDeviceBuffer(); + const bool splitk = args[0].k_batch > 1; + for(const auto& arg : args) + { + kargs.emplace_back(ck_tile::GemmKernelArgs{arg.a_ptr, + arg.b_ptr, + arg.c_ptr, + arg.M, + arg.N, + arg.K, + arg.stride_A, + arg.stride_B, + arg.stride_C, + arg.k_batch}); + } + const auto stream = ck_tile::stream_config{nullptr, true, 1, n_warmup, n_repeat}; + HIP_CHECK_ERROR(hipMemcpyWithStream(kargs_ptr, + kargs.data(), + kargs.size() * sizeof(ck_tile::GemmTransKernelArg), + hipMemcpyHostToDevice, + stream.stream_id_)); + ave_time = grouped_gemm_tileloop( + stream, group_count, kargs_ptr, splitk); + } std::string op_name{"Grouped Gemm"}; @@ -66,7 +106,7 @@ float invoke_gemm(int n_warmup, return ave_time; } -template +template int run_grouped_gemm_example_with_layouts(int argc, char* argv[], const ALayout a_layout = ALayout{}, @@ -87,6 +127,15 @@ int run_grouped_gemm_example_with_layouts(int argc, const int group_count = arg_parser.get_int("group_count"); const int repeat = arg_parser.get_int("repeat"); const int warmup = arg_parser.get_int("warmup"); + const int kbatch = arg_parser.get_int("kbatch"); + bool validate = arg_parser.get_bool("validate"); + + if(kbatch > 1 && validate && warmup + repeat > 1) + { + std::cout << "WARNING: Data validation enabled with SplitK and more than" + << "1 warmup/repeat. Disabling validation." 
<< std::endl; + validate = false; + } std::vector Ms = arg_parser.get_int_vec("Ms"); std::vector Ns = arg_parser.get_int_vec("Ns"); @@ -102,7 +151,7 @@ int run_grouped_gemm_example_with_layouts(int argc, { Ms.push_back(256 + 256 * i); Ns.push_back(256 + 512 * i); - Ks.push_back(256 + 64 * i); + Ks.push_back(512 + 128 * i); stride_As.push_back(Ks[i]); stride_Bs.push_back(Ks[i]); @@ -150,8 +199,8 @@ int run_grouped_gemm_example_with_layouts(int argc, << " a_m_k: " << a_m_k_tensors[i].mDesc << " b_k_n: " << b_k_n_tensors[i].mDesc << " c_m_n: " << c_m_n_tensors[i].mDesc << std::endl; - ck_tile::FillUniformDistribution{-5.f, 5.f}(a_m_k_tensors[i]); - ck_tile::FillUniformDistribution{-5.f, 5.f}(b_k_n_tensors[i]); + ck_tile::FillUniformDistribution{-1.f, 1.f}(a_m_k_tensors[i]); + ck_tile::FillUniformDistribution{-1.f, 1.f}(b_k_n_tensors[i]); a_m_k_dev_buf.push_back(std::make_unique( a_m_k_tensors[i].get_element_space_size_in_bytes())); @@ -169,13 +218,11 @@ int run_grouped_gemm_example_with_layouts(int argc, const void* p_b = b_k_n_dev_buf[i]->GetDeviceBuffer(); void* p_c = c_m_n_dev_buf[i]->GetDeviceBuffer(); - // TODO Add support for kbatch > 1 in grouped gemm - static constexpr ck_tile::index_t k_batch = 1; gemm_descs.push_back( - {p_a, p_b, p_c, k_batch, M, N, K, stride_As[i], stride_Bs[i], stride_Cs[i]}); + {p_a, p_b, p_c, kbatch, M, N, K, stride_As[i], stride_Bs[i], stride_Cs[i]}); } - invoke_gemm(warmup, repeat, group_count, gemm_descs); + invoke_gemm(warmup, repeat, group_count, gemm_descs); for(int i = 0; i < group_count; i++) { @@ -183,7 +230,7 @@ int run_grouped_gemm_example_with_layouts(int argc, } bool pass{true}; - if(arg_parser.get_int("validate")) + if(validate) { for(int i = 0; i < group_count; ++i) { @@ -194,7 +241,7 @@ int run_grouped_gemm_example_with_layouts(int argc, a_m_k_tensors[i], b_k_n_tensors[i], c_m_n_host_ref); const float max_accumulated_value = *std::max_element(c_m_n_host_ref.mData.begin(), c_m_n_host_ref.mData.end()); - const auto 
rtol_atol = calculate_rtol_atol(Ks[i], 1 /*kbatch*/, max_accumulated_value); + const auto rtol_atol = calculate_rtol_atol(Ks[i], kbatch, max_accumulated_value); pass &= ck_tile::check_err(c_m_n_tensors[i], c_m_n_host_ref, "Error: Incorrect results!", @@ -211,6 +258,7 @@ int run_grouped_gemm_example_with_layouts(int argc, return pass; } +template int run_grouped_gemm_example(int argc, char* argv[]) { auto [result, arg_parser] = create_args(argc, argv); @@ -227,12 +275,20 @@ int run_grouped_gemm_example(int argc, char* argv[]) if(a_layout == "R" && b_layout == "C") { - return run_grouped_gemm_example_with_layouts(argc, argv, Row{}, Col{}, Row{}); + return run_grouped_gemm_example_with_layouts(argc, argv, Row{}, Col{}, Row{}); + } + else if(a_layout == "R" && b_layout == "R") + { + return run_grouped_gemm_example_with_layouts(argc, argv, Row{}, Row{}, Row{}); + } + else if(a_layout == "C" && b_layout == "R") + { + return run_grouped_gemm_example_with_layouts(argc, argv, Col{}, Row{}, Row{}); + } + else if(a_layout == "C" && b_layout == "C") + { + return run_grouped_gemm_example_with_layouts(argc, argv, Col{}, Col{}, Row{}); } - // else if(a_layout == "R" && b_layout == "R") - // { - // return run_grouped_gemm_example_with_layouts(argc, argv, Row{}, Row{}, Row{}); - // } else { throw std::runtime_error("Unsupported data layout configuration for A,B and C tensors!"); diff --git a/include/ck_tile/core/utility/type_traits.hpp b/include/ck_tile/core/utility/type_traits.hpp index b432cfcef7..2e82e21ba1 100644 --- a/include/ck_tile/core/utility/type_traits.hpp +++ b/include/ck_tile/core/utility/type_traits.hpp @@ -127,4 +127,15 @@ struct is_any_of { }; +// Helper to check if a type is a specialization of a given template +template class RefTemplate> +struct is_specialization_of : std::false_type +{ +}; + +template