Merge branch 'develop' into ginolu/add_wgmfma_dispatcher

This commit is contained in:
Gino Lu
2025-09-08 19:10:23 -05:00
147 changed files with 10722 additions and 1484 deletions

View File

@@ -5,7 +5,7 @@ Documentation for Composable Kernel available at [https://rocm.docs.amd.com/proj
## Composable Kernel 1.2.0 for ROCm 7.0.0
### Added
* Added support for B Tensor Preshuffle in CK TILE Grouped GEMM.
* Added a basic copy kernel example and supporting documentation for new CK Tile developers.
* Added support for bf16, f32, and f16 for 2D and 3D NGCHW grouped convolution backward data
* Added a fully asynchronous HOST (CPU) arguments copy flow for CK grouped GEMM kernels.

View File

@@ -22,6 +22,9 @@ Xiaoyan Zhou, 2020
[Jianfeng Yan](https://github.com/j4yan), 2021-2022
[Jun Liu](https://github.com/junliume), 2021-2024
[John Shumway](https://github.com/shumway), [Vidyasagar Ananthan](https://github.com/vidyasagar-amd), [Christopher Millette](https://github.com/cgmillette), [Maksim Podkorytov](https://github.com/tenpercent), [Thomas Ning](https://github.com/ThomasNing),[Andriy Roshchenko](https://github.com/andriy-ca), [Aviral Goel](https://github.com/AviralGoelAMD), [Cong Ma](https://github.com/CongMa13),[Thrupti Raj Lakshmana Gowda](https://github.com/ThruptiRajLakshmanaGowda), [Emily Martins](https://github.com/ecamartins), [Khushbu Agarwal](https://github.com/amd-khushbu), [Sudhir Kylasa](https://github.com/kylasa), [Jia Luo](https://github.com/JiaLuo-CAN), 2025-
## Product Manager
[John Afaganis](https://github.com/afagaj)

41
Jenkinsfile vendored
View File

@@ -33,9 +33,6 @@ def nthreads() {
def nproc = sh(returnStdout: true, script: 'nproc')
echo "Number of cores: ${nproc}"
def n = nproc.toInteger()
if (n > 32){
n /= 2
}
if (n > 64){
n = 64
}
@@ -1357,22 +1354,9 @@ pipeline {
-D GEMM_MULTI_D_DATATYPE="fp16" \
-D GEMM_MULTI_D_LAYOUT="rcrr;rrrr;crrr;ccrr" \
-DCMAKE_CXX_FLAGS=" -O3 " .. && \
ninja -j64 benchmark_gemm_fp8_rcr && \
./bin/benchmark_gemm_fp8_rcr && \
ninja -j64 benchmark_gemm_fp16_rcr && \
./bin/benchmark_gemm_fp16_rcr && \
ninja -j64 benchmark_gemm_fp8_crr && \
./bin/benchmark_gemm_fp8_crr && \
ninja -j64 benchmark_gemm_fp16_crr && \
./bin/benchmark_gemm_fp16_crr && \
ninja -j64 benchmark_gemm_fp8_ccr && \
./bin/benchmark_gemm_fp8_ccr && \
ninja -j64 benchmark_gemm_fp16_ccr && \
./bin/benchmark_gemm_fp16_ccr && \
ninja -j64 benchmark_gemm_fp8_rrr && \
./bin/benchmark_gemm_fp8_rrr && \
ninja -j64 benchmark_gemm_fp16_rrr && \
./bin/benchmark_gemm_fp16_rrr && \
ninja -j64 benchmark_gemm_all && \
python3 ../tile_engine/ops/gemm/gemm_benchmark.py . --problem-sizes "1024,1024,1024" \
--warmup 5 --repeat 5 --verbose --json results.json && \
ninja -j64 benchmark_gemm_multi_d_fp16_rrrr && \
./bin/benchmark_gemm_multi_d_fp16_rrrr && \
ninja -j64 benchmark_gemm_multi_d_fp16_ccrr && \
@@ -1405,22 +1389,9 @@ pipeline {
-D GEMM_MULTI_D_DATATYPE="fp16" \
-D GEMM_MULTI_D_LAYOUT="rcrr;rrrr;crrr;ccrr" \
-DCMAKE_CXX_FLAGS=" -O3 " .. && \
ninja -j64 benchmark_gemm_fp8_rcr && \
./bin/benchmark_gemm_fp8_rcr && \
ninja -j64 benchmark_gemm_fp16_rcr && \
./bin/benchmark_gemm_fp16_rcr && \
ninja -j64 benchmark_gemm_fp8_crr && \
./bin/benchmark_gemm_fp8_crr && \
ninja -j64 benchmark_gemm_fp16_crr && \
./bin/benchmark_gemm_fp16_crr && \
ninja -j64 benchmark_gemm_fp8_ccr && \
./bin/benchmark_gemm_fp8_ccr && \
ninja -j64 benchmark_gemm_fp16_ccr && \
./bin/benchmark_gemm_fp16_ccr && \
ninja -j64 benchmark_gemm_fp8_rrr && \
./bin/benchmark_gemm_fp8_rrr && \
ninja -j64 benchmark_gemm_fp16_rrr && \
./bin/benchmark_gemm_fp16_rrr && \
ninja -j64 benchmark_gemm_all && \
python3 ../tile_engine/ops/gemm/gemm_benchmark.py . --problem-sizes "1024,1024,1024" \
--warmup 5 --repeat 5 --verbose --json results.json && \
ninja -j64 benchmark_gemm_multi_d_fp16_rrrr && \
./bin/benchmark_gemm_multi_d_fp16_rrrr && \
ninja -j64 benchmark_gemm_multi_d_fp16_ccrr && \

View File

@@ -39,6 +39,7 @@ The Composable Kernel repository is located at `https://github.com/ROCm/composab
* :doc:`Composable Kernel API reference <./doxygen/html/namespace_c_k>`
* :doc:`CK Tile API reference <./doxygen/html/namespaceck__tile>`
* :doc:`Composable Kernel complete API class list <./doxygen/html/annotated>`
* :doc:`Composable Kernel glossary <./reference/Composable-Kernel-Glossary>`
To contribute to the documentation refer to `Contributing to ROCm <https://rocm.docs.amd.com/en/latest/contribute/contributing.html>`_.

View File

@@ -0,0 +1,256 @@
.. meta::
:description: Composable Kernel glossary of terms
:keywords: composable kernel, glossary
***************************************************
Composable Kernel glossary
***************************************************
.. glossary::
:sorted:
arithmetic logic unit
The arithmetic logic unit (ALU) is the GPU component responsible for arithmetic and logic operations.
compute unit
The compute unit (CU) is the parallel vector processor in an AMD GPU with multiple :term:`ALUs<arithmetic logic unit>`. Each compute unit will run all the :term:`wavefronts<wavefront>` in a :term:`work group>`. A compute unit is equivalent to NVIDIA's streaming multiprocessor.
matrix core
A matrix core is a specialized GPU unit that accelerate matrix operations for AI and deep learning tasks. A GPU contains multiple matrix cores.
register
Registers are the fastest tier of memory. They're used for storing temporary values during computations and are private to the :term:`work-items<work-item>` that use them.
VGPR
See :term:`vector general purpose register`.
vector general purpose register
A vector general purpose register (VGPR) is a :term:`register` that stores individual thread data. Each thread in a :term:`wave<wavefront>` has its own set of VGPRs for private variables and calculations.
SGPR
See :term:`scalar general purpose register`.
scalar general purpose register
A scalar general purpose register (SGPR) is a :term:`register` shared by all the :term:`work items<work item>` in a :term:`wave<wavefront>`. SGPRs are used for constants, addresses, and control flow common across the entire wave.
LDS
See :term:`local data share`.
local data share
Local data share (LDS) is high-bandwidth, low-latency on-chip memory accessible to all the :term:`work-items<work-item>` in a :term:`work group`. LDS is equivalent to NVIDIA's shared memory.
LDS banks
LDS banks are a type of memory organization where consecutive addresses are distributed across multiple memory banks for parallel access. LDS banks are used to prevent memory access conflicts and improve bandwidth when LDS is used.
global memory
The main device memory accessible by all threads, offering high capacity but higher latency than shared memory.
pinned memory
Pinned memory is :term:`host` memory that is page-locked to accelerate transfers between the CPU and GPU.
dense tensor
A dense tensor is a tensor where most of its elements are non-zero. Dense tensors are typically stored in a contiguous block of memory.
sparse tensor
A sparse tensor is a tensor where most of its elements are zero. Typically only the non-zero elements of a sparse tensor and their indices are stored.
host
Host refers to the CPU and the main memory system that manages GPU execution. The host is responsible for launching kernels, transferring data, and coordinating overall computation.
device
Device refers to the GPU hardware that runs parallel kernels. The device contains the :term:`compute units<compute unit>`, memory hierarchy, and specialized accelerators.
work-item
A work-item is the smallest unit of parallel execution. A work-item runs a single independent instruction stream on a single data element. A work-item is equivalent to an NVIDIA thread.
wavefront
Also referred to as a wave, a wavefront is a group of :term:`work-items<work-item>` that run the same instruction. A wavefront is equivalent to an NVIDIA warp.
work group
A work group is a collection of :term:`work-items<work-item>` that can synchronize and share memory. A work group is equivalent to NVIDIA's thread block.
grid
A grid is a collection of :term:`work groups<work group>` that run a kernel. Each work group within the grid operates independently and can be scheduled on a different :term:`compute unit`. A grid can be organized into one, two, or three dimensions. A grid is equivalent to an NVIDIA thread block.
block Size
The block size is the number of :term:`work-items<work-item>` in a :term:`compute unit`.
SIMT
See :term:`single-instruction, multi-thread`
single-instruction, multi-thread
Single-instruction, multi-thread (SIMT) is a parallel computing model where all the :term:`work-items<work-item>` within a :term:`wavefront` run the same instruction on different data.
SIMD
See :term:`single-instruction, multi-data`
single-instruction, multi-data
Single-instruction, multi-data (SIMD) is a parallel computing model where the same instruction is run with different data simultaneously.
occupancy
The ratio of active :term:`wavefronts<wavefront>` to the maximum possible number of wavefronts.
kernel
A kernel is a function that runs an :term:`operation` or a collection of operations. A kernel will run in parallel on several :term:`work-items<work-item>` across the GPU. In Composable Kernel, kernels require :term:`pipelines<pipeline>`.
operation
An operation is a computation on input data.
pipeline
A Composable Kernel pipeline schedules the sequence of operations for a :term:`kernel`, such as the data loading, computation, and storage phases. A pipeline consists of a :term:`problem` and a :term:`policy`.
tile partitioner
The tile partitioner defines the mapping between the :term:`problem` dimensions and GPU hierarchy. It specifies :term:`workgroup`-level :term:`tile` sizes and determines :term:`grid` dimensions by dividing the problem size by the tile sizes.
problem
The problem is the part of the :term:`pipeline` that defines input and output shapes, data types, and mathematical :term:`operations<operation>`.
policy
The policy is the part of the :term:`pipeline` that defines memory access patterns and hardware-specific optimizations.
user customized tile pipeline
A customized :term:`tile` :term:`pipeline` that combines custom :term:`problem` and :term:`policy` components for specialized computations.
user customized tile pipeline optimization
The process of tuning the :term:`tile` size, memory access pattern, and hardware utilization for specific workloads.
tile programming API
The :term:`tile` programming API is Composable Kernel's high-level interface for defining tile-based computations with predefined hardware mappings for data loading and storing.
coordinate transformation primitives
Coordinate transformation primitives are Composable Kernel utilities for converting between different coordinate systems.
reference kernel
A reference :term:`kernel` is a baseline kernel implementation used to verify correctness and performance. Composable Kernel makes two reference kernels, one for CPU and one for GPU, available.
launch parameters
Launch parameters are the configuration values, such as :term:`grid` and :term:`block size`, that determine how a :term:`kernel` is mapped to hardware resources.
memory coalescing
Memory coalescing is an optimization strategy where consecutive :term:`work-items<work-item>` access consecutive memory addresses in such a way that a single memory transaction serves multiple work-items.
alignment
Alignment is a memory management strategy where data structures are stored at addresses that are multiples of a specific value.
bank conflict
A bank conflict occurs when multiple :term:`work-items<work-item>` in a :term:`wavefront` access different addresses that map to the same shared memory bank.
padding
Padding is the addition of extra elements, often zeros, to tensor edges in order to control output size in convolution and pooling, or to align data for memory access.
transpose
Transpose is an :term:`operation` that rearranges the order of tensor axes, often for the purposes of matching :term:`kernel` input formats or optimize memory access patterns.
permute
Permute is an :term:`operation` that rearranges the order of tensor axes, often for the purposes of matching :term:`kernel` input formats or optimize memory access patterns.
host-device transfer
A host-device transfer is the process of moving data between :term:`host` and :term:`device` memory.
stride
A stride is the step size to move from one element to the next in a specific dimension of a tensor or matrix. In convolution and pooling, the stride determines how far the :term:`kernel` moves at each step.
dilation
Dilation is the spacing between :term:`kernel` elements in convolution :term:`operations<operation>`, allowing the receptive field to grow without increasing kernel size.
Im2Col
Im2Col is a data transformation technique that converts image data to column format.
Col2Im
Col2Im is a data transformation technique that converts column data to image format.
fast changing dimension
The fast changing dimension is the innermost dimension in memory layout.
outer dimension
The outer dimension is the slower-changing dimension in memory layout.
inner dimension
The inner dimension is the faster-changing dimension in memory layout.
tile
A tile is a sub-region of a tensor or matrix that is processed by a :term:`work group` or :term:`work-item`. Rectangular data blocks are the unit of computation and memory transfer in Composable Kernel, and are the basis for tiled algorithms.
block tile
A block tile is a memory :term:`tile` processed by a :term:`work group`.
wave tile
A wave :term:`tile` is a sub-tile processed by a single :term:`wavefront` within a :term:`work group`. The wave tile is the base level granularity of a :term:`single-instruction, multi-thread (SIMD)<single-instruction, multi-thread>` model.
tile distribution
The tile distribution is the hierarchical data mapping from :term:`work-items<work-item>` to data in memory.
tile window
Viewport into a larger tensor that defines the current tile's position and boundaries for computation.
load tile
Load tile is an operation that transfers data from :term:`global memory` or the :term:`load data share` to :term:`vector general purpose registers<vector general purpose register>`.
store tile
Store tile is an operation that transfers data from :term:`vector general purpose registers<vector general purpose register>` to :term:`global memory` or the :term:`load data share`.
descriptor
Metadata structure that defines :term:`tile` properties, memory layouts, and coordinate transformations for Composable Kernel :term:`operations<operation>`.
input
See :term:`problem shape`.
problem shape
The problem shape defines the dimensions and data types of input tensors that define the :term:`problem`.
vector
The vector is the smallest data unit processed by an individual :term:`work-item`. A vectors is typically four to sixteen elements, depending on data type and hardware.
elementwise
An elementwise :term:`operation` is an operation applied to each tensor element independently.
epilogue
The epilogue is the final stage of a kernel. Activation functions, bias, and other post-processing steps are applied in the epilogue.
Add+Multiply
See :term:`fused add multiply`.
fused add multiply
A common fused :term:`operation` in machine language and linear algebra, where an :term:`elementwise` addition is immediately followed by a multiplication. Fused add multiply is often used for bias and scaling in neural network layers.
MFMA
See :term:`matrix fused multiply-add`.
matrix fused multiply-add
Matrix fused multiply-add (MFMA) is a :term:`matrix core` instruction for GEMM :term:`operations<operation>`.
GEMM
See :term:`general matrix multiply`.
general matrix multiply
A general matrix multiply (GEMM) is a Core matrix :term:`operation` in linear algebra and deep learning. A GEMM is defined as :math:`C = {\alpha}AB + {\beta}C`, where :math:`A`, :math:`B`, and :math:`C` are matrices, and :math:`\alpha` and :math:`\beta` are scalars.
VGEMM
See :term:`naive GEMM`.
vanilla GEMM
See :term:`naive GEMM`.
naive GEMM
The naive GEMM, sometimes referred to as a vanilla GEMM or VGEMM, is the simplest form of :term:`GEMM` in Composable Kernel. The naive GEMM is defined as :math:`C = AB`, where :math:`A`, :math:`B`, and :math:`C` are matrices. The naive GEMM is the baseline GEMM that all other GEMM :term:`operations<operation>` build on.
GGEMM
See :term:`grouped GEMM`.
grouped GEMM
A :term:`kernel` that calls multiple :term:`VGEMMs<naive GEMM>`. Each call can have a different :term:`problem shape`.
batched GEMM
A :term:`kernel` that calls :term:`VGEMMs<naive GEMM>` with different batches of data. All the data batches have the same :term:`problem shape`.
Split-K GEMM
Split-K GEMM is a parallelization strategy that partitions the reduction dimension (K) of a :term:`GEMM` across multiple :term:`compute units<compute unit>`, increasing parallelism for large matrix multiplications.
GEMV
See :term:`general matrix vector multiplication`
general matrix vector multiplication
General matrix vector multiplication (GEMV) is an :term:`operation` where a matrix is multiplied by a vector, producing another vector.

View File

@@ -34,8 +34,14 @@ subtrees:
title: Composable Kernel vector utilities
- file: reference/Composable-Kernel-wrapper.rst
title: Composable Kernel wrapper
- file: doxygen/html/namespace_c_k.rst
title: CK API reference
- file: doxygen/html/namespaceck__tile.rst
title: CK Tile API reference
- file: doxygen/html/annotated.rst
title: Composable Kernel class list
title: Full API class list
- file: reference/Composable-Kernel-Glossary.rst
title: Glossary
- caption: About
entries:

View File

@@ -70,3 +70,5 @@ example_compile_options(example_gemm_multiply_multiply_xdl_fp8_blockscale_bpresh
example_compile_options(example_moe_gemm2_xdl_fp8_blockscale PRIVATE ${BLOCKSCALE_GEMM_OPTIONS})
example_compile_options(example_moe_gemm1_xdl_fp8_blockscale PRIVATE ${BLOCKSCALE_GEMM_OPTIONS})
add_example_executable(example_gemm_add_add_wmma_fp16 gemm_add_add_wmma_fp16.cpp)

View File

@@ -0,0 +1,267 @@
// SPDX-License-Identifier: MIT
// Copyright (c) 2018-2025, Advanced Micro Devices, Inc. All rights reserved.
#include <iostream>
#include <numeric>
#include <initializer_list>
#include <cstdlib>
#include "ck/ck.hpp"
#include "ck/tensor_operation/gpu/device/gemm_specialization.hpp"
#include "ck/tensor_operation/gpu/device/impl/device_gemm_multiple_d_wmma_cshuffle_v3.hpp"
#include "ck/tensor_operation/gpu/element/element_wise_operation.hpp"
#include "ck/tensor_operation/gpu/element/unary_element_wise_operation.hpp"
#include "ck/library/utility/device_memory.hpp"
#include "ck/library/utility/host_tensor.hpp"
#include "ck/library/utility/host_tensor_generator.hpp"
#include "ck/library/utility/literals.hpp"
#include "ck/library/reference_tensor_operation/cpu/reference_gemm.hpp"
#include "ck/library/utility/check_err.hpp"
#include "ck/utility/blkgemmpipe_scheduler.hpp"
template <ck::index_t... Is>
using S = ck::Sequence<Is...>;
using F16 = ck::half_t;
using F32 = float;
using Row = ck::tensor_layout::gemm::RowMajor;
using Col = ck::tensor_layout::gemm::ColumnMajor;
using A0DataType = F16;
using B0DataType = F16;
using AccDataType = F32;
using CShuffleDataType = F32;
using D0DataType = F32;
using D1DataType = F32;
using DsDataType = ck::Tuple<D0DataType, D1DataType>;
using EDataType = F16;
using A0Layout = Row;
using B0Layout = Col;
using D0Layout = Row;
using D1Layout = Row;
using DsLayout = ck::Tuple<D0Layout, D1Layout>;
using ELayout = Row;
struct AddAdd
{
template <typename E, typename C, typename D0, typename D1>
__host__ __device__ constexpr void
operator()(E& e, const C& c, const D0& d0, const D1& d1) const;
template <>
__host__ __device__ constexpr void operator()<ck::half_t, float, float, float>(
ck::half_t& e, const float& c, const float& d0, const float& d1) const
{
const float x0_f = c + d0 + d1;
e = ck::type_convert<ck::half_t>(x0_f);
}
};
using PassThrough = ck::tensor_operation::element_wise::PassThrough;
using AElementOp = PassThrough;
using BElementOp = PassThrough;
using CDEElementOp = AddAdd;
static constexpr auto GemmSpec = ck::tensor_operation::device::GemmSpecialization::Default;
using DeviceOpInstance = ck::tensor_operation::device::DeviceGemmMultipleD_Wmma_CShuffleV3
// clang-format off
//#########################| ALayout| BLayout| DsLayout| ELayout| AData| BData| DsData| EData| AccData| CShuffle| A| B| CDE| GemmSpec| Block| MPer| NPer| KPer| AK1| BK1| MPer| NPer| MRepeat| NRepeat| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockLds| BBlockTransfer| BBlockTransfer| BBlockTransfer| BBlockTransfer| BBlockTransfer| BBlockTransfer| BBlockLds| CShuffle| CShuffle| CShuffleBlockTransfer| CDEShuffleBlockTransfer| BlkGemm| BlkGemm|
//#########################| | | | | Type| Type| Type| Type| Type| DataType| Elementwise| Elementwise| Elementwise| | Size| Block| Block| Block| | | Wmma| Wmma| | | ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| ExtraM| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| ExtraN| MRepeat| NRepeat| ClusterLengths| ScalarPerVectors| PipeSched| PipelineVer|
//#########################| | | | | | | | | | | Operation| Operation| Operation| | | | | | | | | | | | Lengths_AK0_M_AK1| ArrangeOrder| | | PerVector| PerVector_AK1| | Lengths_BK0_N_BK1| ArrangeOrder| | | PerVector| PerVector_BK1| | PerShuffle| PerShuffle| _MBlock_MPerBlock| | | |
//#########################| | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | _NBlock_NPerBlock| S<C, D..>| | |
< A0Layout, B0Layout, DsLayout, ELayout, A0DataType, B0DataType, DsDataType, EDataType, AccDataType, CShuffleDataType, AElementOp, BElementOp, CDEElementOp, GemmSpec, 128, 128, 64, 64, 8, 8, 16, 16, 4, 2, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 0, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 0, 1, 1, S<1, 32, 1, 4>, S<8, 8, 8>, ck::BlockGemmPipelineScheduler::Interwave, ck::BlockGemmPipelineVersion::v1>;
// clang-format on
int main(int argc, char* argv[])
{
bool do_verification = true;
int init_method = 1;
bool time_kernel = false;
// GEMM shape
ck::index_t M = 3840;
ck::index_t N = 4096;
ck::index_t K = 4096;
ck::index_t StrideA = K;
ck::index_t StrideB = K;
ck::index_t StrideD = K;
ck::index_t StrideE = N;
if(argc == 1)
{
// use default case
}
else if(argc == 4)
{
do_verification = std::stoi(argv[1]);
init_method = std::stoi(argv[2]);
time_kernel = std::stoi(argv[3]);
}
else if(argc == 11)
{
do_verification = std::stoi(argv[1]);
init_method = std::stoi(argv[2]);
time_kernel = std::stoi(argv[3]);
M = std::stoi(argv[4]);
N = std::stoi(argv[5]);
K = std::stoi(argv[6]);
StrideA = std::stoi(argv[7]);
StrideB = std::stoi(argv[8]);
StrideD = std::stoi(argv[9]);
StrideE = std::stoi(argv[10]);
}
else
{
printf("arg1: verification (0=no, 1=yes)\n");
printf("arg2: initialization (0=no init, 1=integer value, 2=decimal value)\n");
printf("arg3: time kernel (0=no, 1=yes)\n");
printf("arg4 to 9: M (256x), N(128x), K(32x), StrideA, StrideB, StrideD, StrideE\n");
exit(0);
}
auto f_host_tensor_descriptor =
[](std::size_t row, std::size_t col, std::size_t stride, auto layout) {
using namespace ck::literals;
if(std::is_same<decltype(layout), ck::tensor_layout::gemm::RowMajor>::value)
{
return HostTensorDescriptor({row, col}, {stride, 1_uz});
}
else
{
return HostTensorDescriptor({row, col}, {1_uz, stride});
}
};
Tensor<A0DataType> a0_m_k(f_host_tensor_descriptor(M, K, StrideA, A0Layout{}));
Tensor<B0DataType> b0_k_n(f_host_tensor_descriptor(K, N, StrideB, B0Layout{}));
Tensor<D0DataType> d0_m_n(f_host_tensor_descriptor(M, N, StrideD, D0Layout{}));
Tensor<D1DataType> d1_m_n(f_host_tensor_descriptor(M, N, StrideD, D1Layout{}));
Tensor<EDataType> e_m_n_host_result(f_host_tensor_descriptor(M, N, StrideE, ELayout{}));
Tensor<EDataType> e_m_n_device_result(f_host_tensor_descriptor(M, N, StrideE, ELayout{}));
std::cout << "a0_m_k: " << a0_m_k.mDesc << std::endl;
std::cout << "b0_k_n: " << b0_k_n.mDesc << std::endl;
std::cout << "d1_m_n: " << d1_m_n.mDesc << std::endl;
std::cout << "d0_m_n: " << d0_m_n.mDesc << std::endl;
std::cout << "e_m_n: " << e_m_n_host_result.mDesc << std::endl;
switch(init_method)
{
case 0: break;
case 1:
a0_m_k.GenerateTensorValue(GeneratorTensor_2<A0DataType>{-2, 2});
b0_k_n.GenerateTensorValue(GeneratorTensor_2<B0DataType>{0, 2});
d0_m_n.GenerateTensorValue(GeneratorTensor_2<D0DataType>{0, 2});
d1_m_n.GenerateTensorValue(GeneratorTensor_2<D1DataType>{0, 2});
break;
default:
a0_m_k.GenerateTensorValue(GeneratorTensor_3<A0DataType>{0.0, 1.0});
b0_k_n.GenerateTensorValue(GeneratorTensor_3<B0DataType>{-0.5, 0.5});
d0_m_n.GenerateTensorValue(GeneratorTensor_3<D0DataType>{-0.5, 0.5});
d1_m_n.GenerateTensorValue(GeneratorTensor_3<D1DataType>{-0.5, 0.5});
}
DeviceMem a0_device_buf(sizeof(A0DataType) * a0_m_k.mDesc.GetElementSpaceSize());
DeviceMem b0_device_buf(sizeof(B0DataType) * b0_k_n.mDesc.GetElementSpaceSize());
DeviceMem d0_device_buf(sizeof(D0DataType) * d0_m_n.mDesc.GetElementSpaceSize());
DeviceMem d1_device_buf(sizeof(D1DataType) * d1_m_n.mDesc.GetElementSpaceSize());
DeviceMem e_device_buf(sizeof(EDataType) * e_m_n_device_result.mDesc.GetElementSpaceSize());
a0_device_buf.ToDevice(a0_m_k.mData.data());
b0_device_buf.ToDevice(b0_k_n.mData.data());
d0_device_buf.ToDevice(d0_m_n.mData.data());
d1_device_buf.ToDevice(d1_m_n.mData.data());
auto a_element_op = AElementOp{};
auto b_element_op = BElementOp{};
auto cde_element_op = CDEElementOp{};
constexpr ck::index_t NumDTensor = DsDataType::Size();
// do GEMM
auto device_op = DeviceOpInstance{};
auto invoker = device_op.MakeInvoker();
auto argument =
device_op.MakeArgument(a0_device_buf.GetDeviceBuffer(),
b0_device_buf.GetDeviceBuffer(),
std::array<const void*, NumDTensor>{d0_device_buf.GetDeviceBuffer(),
d1_device_buf.GetDeviceBuffer()},
e_device_buf.GetDeviceBuffer(),
M,
N,
K,
StrideA,
StrideB,
std::array<ck::index_t, NumDTensor>{StrideD, StrideD},
StrideE,
1,
a_element_op,
b_element_op,
cde_element_op);
if(!device_op.IsSupportedArgument(argument))
{
throw std::runtime_error(
"wrong! device_gemm with the specified compilation parameters does "
"not support this GEMM problem");
}
float ave_time = invoker.Run(argument, StreamConfig{nullptr, time_kernel, 0, 20, 50});
std::size_t flop = std::size_t(2) * M * N * K;
std::size_t num_btype = sizeof(A0DataType) * M * K + sizeof(B0DataType) * K * N +
sizeof(D0DataType) * M * N + sizeof(D1DataType) * M * N +
sizeof(EDataType) * M * N;
float tflops = static_cast<float>(flop) / 1.E9 / ave_time;
float gb_per_sec = num_btype / 1.E6 / ave_time;
std::cout << "Perf: " << ave_time << " ms, " << tflops << " TFlops, " << gb_per_sec << " GB/s"
<< std::endl;
if(do_verification)
{
Tensor<CShuffleDataType> c_m_n({M, N});
using ReferenceGemmInstance = ck::tensor_operation::host::ReferenceGemm<A0DataType,
B0DataType,
CShuffleDataType,
AccDataType,
PassThrough,
PassThrough,
PassThrough>;
auto ref_gemm = ReferenceGemmInstance{};
auto ref_invoker = ref_gemm.MakeInvoker();
auto ref_argument = ref_gemm.MakeArgument(
a0_m_k, b0_k_n, c_m_n, PassThrough{}, PassThrough{}, PassThrough{});
ref_invoker.Run(ref_argument);
for(int m = 0; m < M; ++m)
{
for(int n = 0; n < N; ++n)
{
cde_element_op(e_m_n_host_result(m, n), c_m_n(m, n), d0_m_n(m, n), d1_m_n(m, n));
}
}
e_device_buf.FromDevice(e_m_n_device_result.mData.data());
return ck::utils::check_err(e_m_n_device_result, e_m_n_host_result) ? 0 : 1;
}
return 0;
}

View File

@@ -1,5 +1,5 @@
// SPDX-License-Identifier: MIT
// Copyright (c) 2018-2023, Advanced Micro Devices, Inc. All rights reserved.
// Copyright (c) 2018-2025, Advanced Micro Devices, Inc. All rights reserved.
#include <iostream>
#include <numeric>
@@ -184,7 +184,6 @@ int main(int argc, char* argv[])
b0_device_buf.ToDevice(b0_k_n.mData.data());
d0_device_buf.ToDevice(d0_m_n.mData.data());
d1_device_buf.ToDevice(d1_m_n.mData.data());
e_device_buf.ToDevice(e_m_n_device_result.mData.data());
auto a_element_op = AElementOp{};
auto b_element_op = BElementOp{};
@@ -220,11 +219,12 @@ int main(int argc, char* argv[])
"not support this GEMM problem");
}
float ave_time = invoker.Run(argument, StreamConfig{nullptr, time_kernel, 20, 50});
float ave_time = invoker.Run(argument, StreamConfig{nullptr, time_kernel, 0, 20, 50});
std::size_t flop = std::size_t(2) * M * N * K;
std::size_t num_btype =
sizeof(A0DataType) * M * K + sizeof(B0DataType) * K * N + sizeof(EDataType) * M * N;
std::size_t flop = std::size_t(2) * M * N * K;
std::size_t num_btype = sizeof(A0DataType) * M * K + sizeof(B0DataType) * K * N +
sizeof(D0DataType) * M * N + sizeof(D1DataType) * M * N +
sizeof(EDataType) * M * N;
float tflops = static_cast<float>(flop) / 1.E9 / ave_time;
@@ -233,8 +233,6 @@ int main(int argc, char* argv[])
std::cout << "Perf: " << ave_time << " ms, " << tflops << " TFlops, " << gb_per_sec << " GB/s"
<< std::endl;
e_device_buf.FromDevice(e_m_n_device_result.mData.data());
if(do_verification)
{
Tensor<CShuffleDataType> c_m_n({M, N});

View File

@@ -0,0 +1,22 @@
add_custom_target(example_gemm_add_xdl)
add_example_executable(example_gemm_add_xdl_fp16 gemm_add_xdl_fp16.cpp)
add_example_dependencies(example_gemm_add_xdl example_gemm_add_xdl_fp16)
add_example_executable(example_gemm_add_xdl_bf16 gemm_add_xdl_bf16.cpp)
add_example_dependencies(example_gemm_add_xdl example_gemm_add_xdl_bf16)
add_custom_target(example_gemm_add_wmma)
add_example_executable(example_gemm_add_wmma_bf16 gemm_add_wmma_bf16.cpp)
add_example_dependencies(example_gemm_add_wmma example_gemm_add_wmma_bf16)
add_example_executable(example_gemm_add_wmma_fp16 gemm_add_wmma_fp16.cpp)
add_example_dependencies(example_gemm_add_wmma example_gemm_add_wmma_fp16)

View File

@@ -0,0 +1,114 @@
// SPDX-License-Identifier: MIT
// Copyright (c) 2025, Advanced Micro Devices, Inc. All rights reserved.
#pragma once
#include <algorithm>
#include <cstddef>
#include <iostream>
#include <stdexcept>
#include <string>
#include "ck/ck.hpp"
#include "ck/tensor_operation/gpu/device/tensor_layout.hpp"
#include "ck/tensor_operation/gpu/device/gemm_specialization.hpp"
#include "ck/tensor_operation/gpu/device/impl/device_gemm_multiple_d_xdl_cshuffle.hpp"
#include "ck/tensor_operation/gpu/device/impl/device_gemm_multiple_d_wmma_cshuffle.hpp"
#include "ck/tensor_operation/gpu/device/impl/device_gemm_multiple_d_wmma_cshuffle_v3.hpp"
#include "ck/tensor_operation/gpu/element/element_wise_operation.hpp"
#include "ck/utility/data_type.hpp"
#include "ck/library/reference_tensor_operation/cpu/reference_gemm.hpp"
#include "ck/library/utility/check_err.hpp"
#include "ck/library/utility/device_memory.hpp"
#include "ck/library/utility/host_tensor.hpp"
#include "ck/library/utility/host_tensor_generator.hpp"
#include "ck/library/utility/literals.hpp"
template <ck::index_t... Is>
using S = ck::Sequence<Is...>;
using Row = ck::tensor_layout::gemm::RowMajor;
using Col = ck::tensor_layout::gemm::ColumnMajor;
using PassThrough = ck::tensor_operation::element_wise::PassThrough;
using Add = ck::tensor_operation::element_wise::Add;
using BF16 = ck::bhalf_t;
using F16 = ck::half_t;
using F32 = float;
using Row_Tuple = ck::Tuple<Row>;
using F16_Tuple = ck::Tuple<F16>;
using BF16_Tuple = ck::Tuple<BF16>;
struct ProblemSize final
{
ck::index_t M = 3840;
ck::index_t N = 4096;
ck::index_t K = 4096;
ck::index_t StrideA = 4096;
ck::index_t StrideB = 4096;
ck::index_t StrideD = 4096;
ck::index_t StrideE = 4096;
};
struct ExecutionConfig final
{
bool do_verification = true;
int init_method = 1;
bool time_kernel = false;
};
inline bool
parse_cmd_args(int argc, char* argv[], ProblemSize& problem_size, ExecutionConfig& config)
{
if(argc == 1)
{
// use default case
}
else if(argc == 4)
{
config.do_verification = std::stoi(argv[1]);
config.init_method = std::stoi(argv[2]);
config.time_kernel = std::stoi(argv[3]);
}
else if(argc == 6)
{
config.do_verification = std::stoi(argv[1]);
config.init_method = std::stoi(argv[2]);
config.time_kernel = std::stoi(argv[3]);
}
else if(argc == 13)
{
config.do_verification = std::stoi(argv[1]);
config.init_method = std::stoi(argv[2]);
config.time_kernel = std::stoi(argv[3]);
problem_size.M = std::stoi(argv[4]);
problem_size.N = std::stoi(argv[5]);
problem_size.K = std::stoi(argv[6]);
problem_size.StrideA = std::stoi(argv[7]);
problem_size.StrideB = std::stoi(argv[8]);
problem_size.StrideD = std::stoi(argv[9]);
problem_size.StrideE = std::stoi(argv[10]);
}
else
{
std::cerr << "arg1: verification (0=no, 1=yes)" << std::endl
<< "arg2: initialization (0=no init, 1=integer value, 2=decimal value)"
<< std::endl
<< "arg3: time kernel (0=no, 1=yes)" << std::endl
<< "arg4 to 10: M (256x), N(128x), K(32x), StrideA, StrideB, StrideD,"
"StrideE"
<< std::endl;
return false;
}
return true;
}

View File

@@ -0,0 +1,78 @@
// SPDX-License-Identifier: MIT
// Copyright (c) 2025, Advanced Micro Devices, Inc. All rights reserved.
#include "common.hpp"
using ADataType = BF16;
using BDataType = BF16;
using AccDataType = F32;
using CShuffleDataType = F32;
using DDataType = BF16;
using DsDataType = BF16_Tuple;
using EDataType = BF16;
using Row_Tuple = ck::Tuple<Row>;
using ALayout = Row;
using BLayout = Row;
using DLayout = Row;
using DsLayout = Row_Tuple;
using ELayout = Row;
using AElementOp = PassThrough;
using BElementOp = PassThrough;
using CDEElementOp = Add;
static constexpr auto GemmSpec = ck::tensor_operation::device::GemmSpecialization::MNKPadding;
using DeviceOpInstance = ck::tensor_operation::device::DeviceGemmMultipleD_Wmma_CShuffleV3<
Row,
Row,
Row_Tuple,
Row,
BF16,
BF16,
BF16_Tuple,
BF16,
F32,
F32,
PassThrough,
PassThrough,
Add,
GemmSpec,
128,
128,
64,
64,
8,
8,
16,
16,
4,
2,
S<4, 32, 1>,
S<1, 0, 2>,
S<1, 0, 2>,
2,
8,
8,
0,
S<4, 32, 1>,
S<0, 2, 1>,
S<0, 2, 1>,
1,
1,
8,
0,
1,
1,
S<1, 32, 1, 4>,
S<8, 8, 8>,
ck::BlockGemmPipelineScheduler::Intrawave,
ck::BlockGemmPipelineVersion::v1>;
// clang-format on
#include "run_gemm_add_example_wmma.inc"
int main(int argc, char* argv[]) { return !run_gemm_add_example(argc, argv); }

View File

@@ -0,0 +1,76 @@
// SPDX-License-Identifier: MIT
// Copyright (c) 2025, Advanced Micro Devices, Inc. All rights reserved.
#include "common.hpp"
using ADataType = F16;
using BDataType = F16;
using AccDataType = F32;
using CShuffleDataType = F32;
using DDataType = F16;
using DsDataType = F16_Tuple;
using EDataType = F16;
using ALayout = Row;
using BLayout = Row;
using DLayout = Row;
using DsLayout = Row_Tuple;
using ELayout = Row;
using AElementOp = PassThrough;
using BElementOp = PassThrough;
using CDEElementOp = Add;
static constexpr auto GemmSpec = ck::tensor_operation::device::GemmSpecialization::MNKPadding;
using DeviceOpInstance = ck::tensor_operation::device::DeviceGemmMultipleD_Wmma_CShuffleV3<
Row,
Row,
Row_Tuple,
Row,
F16,
F16,
F16_Tuple,
F16,
F32,
F32,
PassThrough,
PassThrough,
Add,
GemmSpec,
128,
128,
64,
64,
8,
8,
16,
16,
4,
2,
S<4, 32, 1>,
S<1, 0, 2>,
S<1, 0, 2>,
2,
8,
8,
0,
S<4, 32, 1>,
S<0, 2, 1>,
S<0, 2, 1>,
1,
1,
8,
0,
1,
1,
S<1, 32, 1, 4>,
S<8, 8, 8>,
ck::BlockGemmPipelineScheduler::Intrawave,
ck::BlockGemmPipelineVersion::v1>;
// clang-format on
#include "run_gemm_add_example_wmma.inc"
int main(int argc, char* argv[]) { return !run_gemm_add_example(argc, argv); }

View File

@@ -0,0 +1,82 @@
// SPDX-License-Identifier: MIT
// Copyright (c) 2025, Advanced Micro Devices, Inc. All rights reserved.
#include "common.hpp"
template <ck::index_t... Is>
using S = ck::Sequence<Is...>;
using BF16 = ck::bhalf_t;
using F32 = float;
using Row = ck::tensor_layout::gemm::RowMajor;
using Col = ck::tensor_layout::gemm::ColumnMajor;
using PassThrough = ck::tensor_operation::element_wise::PassThrough;
using ADataType = BF16;
using BDataType = BF16;
using AccDataType = F32;
using CShuffleDataType = F32;
using DDataType = BF16;
using EDataType = BF16;
using ALayout = Row;
using BLayout = Col;
using DLayout = Row;
using ELayout = Row;
using AElementOp = PassThrough;
using BElementOp = PassThrough;
using CDEElementOp = Add;
static constexpr auto GemmSpec = ck::tensor_operation::device::GemmSpecialization::MNKPadding;
using DeviceOpInstance =
ck::tensor_operation::device::DeviceGemmMultipleD_Xdl_CShuffle<ALayout,
BLayout,
ck::Tuple<DLayout>,
ELayout,
ADataType,
BDataType,
AccDataType,
CShuffleDataType,
ck::Tuple<DDataType>,
EDataType,
AElementOp,
BElementOp,
CDEElementOp,
GemmSpec,
1,
256,
256,
128,
32,
8,
8,
32,
32,
4,
2,
S<4, 64, 1>,
S<1, 0, 2>,
S<1, 0, 2>,
2,
8,
8,
1,
S<4, 64, 1>,
S<1, 0, 2>,
S<1, 0, 2>,
2,
8,
8,
1,
1,
1,
S<1, 32, 1, 8>,
8>;
#include "run_gemm_add_example_xdl.inc"
int main(int argc, char* argv[]) { return !run_gemm_add_example(argc, argv); }

View File

@@ -0,0 +1,82 @@
// SPDX-License-Identifier: MIT
// Copyright (c) 2025, Advanced Micro Devices, Inc. All rights reserved.
#include "common.hpp"
template <ck::index_t... Is>
using S = ck::Sequence<Is...>;
using F16 = ck::half_t;
using F32 = float;
using Row = ck::tensor_layout::gemm::RowMajor;
using Col = ck::tensor_layout::gemm::ColumnMajor;
using PassThrough = ck::tensor_operation::element_wise::PassThrough;
using ADataType = F16;
using BDataType = F16;
using AccDataType = F32;
using CShuffleDataType = F32;
using DDataType = F16;
using EDataType = F16;
using ALayout = Row;
using BLayout = Col;
using DLayout = Row;
using ELayout = Row;
using AElementOp = PassThrough;
using BElementOp = PassThrough;
using CDEElementOp = Add;
static constexpr auto GemmSpec = ck::tensor_operation::device::GemmSpecialization::MNKPadding;
using DeviceOpInstance =
ck::tensor_operation::device::DeviceGemmMultipleD_Xdl_CShuffle<ALayout,
BLayout,
ck::Tuple<DLayout>,
ELayout,
ADataType,
BDataType,
AccDataType,
CShuffleDataType,
ck::Tuple<DDataType>,
EDataType,
AElementOp,
BElementOp,
CDEElementOp,
GemmSpec,
1,
256,
256,
128,
32,
8,
8,
32,
32,
4,
2,
S<4, 64, 1>,
S<1, 0, 2>,
S<1, 0, 2>,
2,
8,
8,
1,
S<4, 64, 1>,
S<1, 0, 2>,
S<1, 0, 2>,
2,
8,
8,
1,
1,
1,
S<1, 32, 1, 8>,
8>;
#include "run_gemm_add_example_xdl.inc"
int main(int argc, char* argv[]) { return !run_gemm_add_example(argc, argv); }

View File

@@ -0,0 +1,145 @@
#pragma once
bool run_gemm_add(const ProblemSize& problem_size, const ExecutionConfig& config)
{
using namespace ck::literals;
auto& [M, N, K, StrideA, StrideB, StrideD, StrideE] = problem_size;
auto f_host_tensor_descriptor =
[](std::size_t row, std::size_t col, std::size_t stride, auto layout) {
if(std::is_same<decltype(layout), ck::tensor_layout::gemm::RowMajor>::value)
{
return HostTensorDescriptor({row, col}, {stride, 1_uz});
}
else
{
return HostTensorDescriptor({row, col}, {1_uz, stride});
}
};
Tensor<ADataType> a_m_k(f_host_tensor_descriptor(M, K, StrideA, ALayout{}));
Tensor<BDataType> b_k_n(f_host_tensor_descriptor(K, N, StrideB, BLayout{}));
Tensor<DDataType> d_m_n(f_host_tensor_descriptor(M, N, StrideD, DLayout{}));
Tensor<EDataType> e_m_n_host_result(f_host_tensor_descriptor(M, N, StrideE, ELayout{}));
Tensor<EDataType> e_m_n_device_result(f_host_tensor_descriptor(M, N, StrideE, ELayout{}));
std::cout << "a_m_k: " << a_m_k.mDesc << std::endl;
std::cout << "b_k_n: " << b_k_n.mDesc << std::endl;
std::cout << "d_m_n: " << d_m_n.mDesc << std::endl;
std::cout << "e_m_n: " << e_m_n_host_result.mDesc << std::endl;
switch(config.init_method)
{
case 0: break;
case 1:
a_m_k.GenerateTensorValue(GeneratorTensor_2<ADataType>{-5, 5});
b_k_n.GenerateTensorValue(GeneratorTensor_2<BDataType>{-5, 5});
d_m_n.GenerateTensorValue(GeneratorTensor_2<DDataType>{-5, 5});
break;
default:
a_m_k.GenerateTensorValue(GeneratorTensor_3<ADataType>{0.0, 1.0});
b_k_n.GenerateTensorValue(GeneratorTensor_3<BDataType>{-0.5, 0.5});
d_m_n.GenerateTensorValue(GeneratorTensor_3<DDataType>{-0.5, 0.5});
}
DeviceMem a_device_buf(sizeof(ADataType) * a_m_k.mDesc.GetElementSpaceSize());
DeviceMem b_device_buf(sizeof(BDataType) * b_k_n.mDesc.GetElementSpaceSize());
DeviceMem d_device_buf(sizeof(DDataType) * d_m_n.mDesc.GetElementSpaceSize());
DeviceMem e_device_buf(sizeof(EDataType) * e_m_n_device_result.mDesc.GetElementSpaceSize());
a_device_buf.ToDevice(a_m_k.mData.data());
b_device_buf.ToDevice(b_k_n.mData.data());
d_device_buf.ToDevice(d_m_n.mData.data());
e_device_buf.ToDevice(e_m_n_device_result.mData.data());
auto a_element_op = AElementOp{};
auto b_element_op = BElementOp{};
auto cde_element_op = CDEElementOp{};
// do GEMM
auto device_op = DeviceOpInstance{};
auto invoker = device_op.MakeInvoker();
auto argument =
device_op.MakeArgument(a_device_buf.GetDeviceBuffer(),
b_device_buf.GetDeviceBuffer(),
std::array<const void*, 1>{d_device_buf.GetDeviceBuffer()},
e_device_buf.GetDeviceBuffer(),
M,
N,
K,
StrideA,
StrideB,
std::array<ck::index_t, 1>{StrideD},
StrideE,
1,
a_element_op,
b_element_op,
cde_element_op);
if(!device_op.IsSupportedArgument(argument))
{
throw std::runtime_error(
"wrong! device_gemm with the specified compilation parameters does "
"not support this GEMM problem");
}
float ave_time = invoker.Run(argument, StreamConfig{nullptr, config.time_kernel});
std::size_t flop = std::size_t(2) * M * N * K;
std::size_t num_btype =
sizeof(ADataType) * M * K + sizeof(BDataType) * K * N + sizeof(EDataType) * M * N;
float tflops = static_cast<float>(flop) / 1.E9 / ave_time;
float gb_per_sec = num_btype / 1.E6 / ave_time;
std::cout << "Perf: " << ave_time << " ms, " << tflops << " TFlops, " << gb_per_sec << " GB/s"
<< device_op.GetTypeString() << std::endl;
e_device_buf.FromDevice(e_m_n_device_result.mData.data());
bool pass = true;
if(config.do_verification)
{
Tensor<CShuffleDataType> c_m_n({M, N});
using ReferenceGemmInstance = ck::tensor_operation::host::ReferenceGemm<ADataType,
BDataType,
CShuffleDataType,
AccDataType,
AElementOp,
BElementOp,
PassThrough>;
auto ref_gemm = ReferenceGemmInstance{};
auto ref_invoker = ref_gemm.MakeInvoker();
auto ref_argument =
ref_gemm.MakeArgument(a_m_k, b_k_n, c_m_n, a_element_op, b_element_op, PassThrough{});
ref_invoker.Run(ref_argument);
for(int m = 0; m < M; ++m)
{
for(int n = 0; n < N; ++n)
{
cde_element_op(e_m_n_host_result(m, n), c_m_n(m, n), d_m_n(m, n));
}
}
e_device_buf.FromDevice(e_m_n_device_result.mData.data());
pass &= ck::utils::check_err(e_m_n_device_result, e_m_n_host_result);
}
return pass;
}
bool run_gemm_add_example(int argc, char* argv[])
{
ProblemSize problem_size;
ExecutionConfig config;
return parse_cmd_args(argc, argv, problem_size, config) && run_gemm_add(problem_size, config);
}

View File

@@ -0,0 +1,144 @@
#pragma once
bool run_gemm_add(const ProblemSize& problem_size, const ExecutionConfig& config)
{
using namespace ck::literals;
auto& [M, N, K, StrideA, StrideB, StrideD, StrideE] = problem_size;
auto f_host_tensor_descriptor =
[](std::size_t row, std::size_t col, std::size_t stride, auto layout) {
if(std::is_same<decltype(layout), ck::tensor_layout::gemm::RowMajor>::value)
{
return HostTensorDescriptor({row, col}, {stride, 1_uz});
}
else
{
return HostTensorDescriptor({row, col}, {1_uz, stride});
}
};
Tensor<ADataType> a_m_k(f_host_tensor_descriptor(M, K, StrideA, ALayout{}));
Tensor<BDataType> b_k_n(f_host_tensor_descriptor(K, N, StrideB, BLayout{}));
Tensor<DDataType> d_m_n(f_host_tensor_descriptor(M, N, StrideD, DLayout{}));
Tensor<EDataType> e_m_n_host_result(f_host_tensor_descriptor(M, N, StrideE, ELayout{}));
Tensor<EDataType> e_m_n_device_result(f_host_tensor_descriptor(M, N, StrideE, ELayout{}));
std::cout << "a_m_k: " << a_m_k.mDesc << std::endl;
std::cout << "b_k_n: " << b_k_n.mDesc << std::endl;
std::cout << "d_m_n: " << d_m_n.mDesc << std::endl;
std::cout << "e_m_n: " << e_m_n_host_result.mDesc << std::endl;
switch(config.init_method)
{
case 0: break;
case 1:
a_m_k.GenerateTensorValue(GeneratorTensor_2<ADataType>{-5, 5});
b_k_n.GenerateTensorValue(GeneratorTensor_2<BDataType>{-5, 5});
d_m_n.GenerateTensorValue(GeneratorTensor_2<DDataType>{-5, 5});
break;
default:
a_m_k.GenerateTensorValue(GeneratorTensor_3<ADataType>{0.0, 1.0});
b_k_n.GenerateTensorValue(GeneratorTensor_3<BDataType>{-0.5, 0.5});
d_m_n.GenerateTensorValue(GeneratorTensor_3<DDataType>{-0.5, 0.5});
}
DeviceMem a_device_buf(sizeof(ADataType) * a_m_k.mDesc.GetElementSpaceSize());
DeviceMem b_device_buf(sizeof(BDataType) * b_k_n.mDesc.GetElementSpaceSize());
DeviceMem d_device_buf(sizeof(DDataType) * d_m_n.mDesc.GetElementSpaceSize());
DeviceMem e_device_buf(sizeof(EDataType) * e_m_n_device_result.mDesc.GetElementSpaceSize());
a_device_buf.ToDevice(a_m_k.mData.data());
b_device_buf.ToDevice(b_k_n.mData.data());
d_device_buf.ToDevice(d_m_n.mData.data());
e_device_buf.ToDevice(e_m_n_device_result.mData.data());
auto a_element_op = AElementOp{};
auto b_element_op = BElementOp{};
auto cde_element_op = CDEElementOp{};
// do GEMM
auto device_op = DeviceOpInstance{};
auto invoker = device_op.MakeInvoker();
auto argument =
device_op.MakeArgument(a_device_buf.GetDeviceBuffer(),
b_device_buf.GetDeviceBuffer(),
std::array<const void*, 1>{d_device_buf.GetDeviceBuffer()},
e_device_buf.GetDeviceBuffer(),
M,
N,
K,
StrideA,
StrideB,
std::array<ck::index_t, 1>{StrideD},
StrideE,
a_element_op,
b_element_op,
cde_element_op);
if(!device_op.IsSupportedArgument(argument))
{
throw std::runtime_error(
"wrong! device_gemm with the specified compilation parameters does "
"not support this GEMM problem");
}
float ave_time = invoker.Run(argument, StreamConfig{nullptr, config.time_kernel});
std::size_t flop = std::size_t(2) * M * N * K;
std::size_t num_btype =
sizeof(ADataType) * M * K + sizeof(BDataType) * K * N + sizeof(EDataType) * M * N;
float tflops = static_cast<float>(flop) / 1.E9 / ave_time;
float gb_per_sec = num_btype / 1.E6 / ave_time;
std::cout << "Perf: " << ave_time << " ms, " << tflops << " TFlops, " << gb_per_sec << " GB/s"
<< device_op.GetTypeString() << std::endl;
e_device_buf.FromDevice(e_m_n_device_result.mData.data());
bool pass = true;
if(config.do_verification)
{
Tensor<CShuffleDataType> c_m_n({M, N});
using ReferenceGemmInstance = ck::tensor_operation::host::ReferenceGemm<ADataType,
BDataType,
CShuffleDataType,
AccDataType,
AElementOp,
BElementOp,
PassThrough>;
auto ref_gemm = ReferenceGemmInstance{};
auto ref_invoker = ref_gemm.MakeInvoker();
auto ref_argument =
ref_gemm.MakeArgument(a_m_k, b_k_n, c_m_n, a_element_op, b_element_op, PassThrough{});
ref_invoker.Run(ref_argument);
for(int m = 0; m < M; ++m)
{
for(int n = 0; n < N; ++n)
{
cde_element_op(e_m_n_host_result(m, n), c_m_n(m, n), d_m_n(m, n));
}
}
e_device_buf.FromDevice(e_m_n_device_result.mData.data());
pass &= ck::utils::check_err(e_m_n_device_result, e_m_n_host_result);
}
return pass;
}
bool run_gemm_add_example(int argc, char* argv[])
{
ProblemSize problem_size;
ExecutionConfig config;
return parse_cmd_args(argc, argv, problem_size, config) && run_gemm_add(problem_size, config);
}

View File

@@ -0,0 +1,15 @@
add_custom_target(example_gemm_add_relu_xdl)
add_example_executable(example_gemm_add_relu_xdl_fp16 gemm_add_relu_xdl_fp16.cpp)
add_example_dependencies(example_gemm_add_relu_xdl example_gemm_add_relu_xdl_fp16)
add_example_executable(example_gemm_add_relu_xdl_bf16 gemm_add_relu_xdl_bf16.cpp)
add_example_dependencies(example_gemm_add_relu_xdl example_gemm_add_relu_xdl_bf16)
add_custom_target(example_gemm_add_relu_wmma)
add_example_executable(example_gemm_add_relu_wmma_bf16 gemm_add_relu_wmma_bf16.cpp)
add_example_dependencies(example_gemm_add_relu_wmma example_gemm_add_relu_wmma_bf16)
add_example_executable(example_gemm_add_relu_wmma_fp16 gemm_add_relu_wmma_fp16.cpp)
add_example_dependencies(example_gemm_add_relu_wmma example_gemm_add_relu_wmma_fp16)

View File

@@ -0,0 +1,114 @@
// SPDX-License-Identifier: MIT
// Copyright (c) 2025, Advanced Micro Devices, Inc. All rights reserved.
#pragma once
#include <algorithm>
#include <cstddef>
#include <iostream>
#include <stdexcept>
#include <string>
#include "ck/ck.hpp"
#include "ck/tensor_operation/gpu/device/tensor_layout.hpp"
#include "ck/tensor_operation/gpu/device/gemm_specialization.hpp"
#include "ck/tensor_operation/gpu/device/impl/device_gemm_multiple_d_xdl_cshuffle.hpp"
#include "ck/tensor_operation/gpu/device/impl/device_gemm_multiple_d_wmma_cshuffle.hpp"
#include "ck/tensor_operation/gpu/device/impl/device_gemm_multiple_d_wmma_cshuffle_v3.hpp"
#include "ck/tensor_operation/gpu/element/element_wise_operation.hpp"
#include "ck/utility/data_type.hpp"
#include "ck/library/reference_tensor_operation/cpu/reference_gemm.hpp"
#include "ck/library/utility/check_err.hpp"
#include "ck/library/utility/device_memory.hpp"
#include "ck/library/utility/host_tensor.hpp"
#include "ck/library/utility/host_tensor_generator.hpp"
#include "ck/library/utility/literals.hpp"
template <ck::index_t... Is>
using S = ck::Sequence<Is...>;
using Row = ck::tensor_layout::gemm::RowMajor;
using Col = ck::tensor_layout::gemm::ColumnMajor;
using PassThrough = ck::tensor_operation::element_wise::PassThrough;
using AddRelu = ck::tensor_operation::element_wise::AddRelu;
using BF16 = ck::bhalf_t;
using F16 = ck::half_t;
using F32 = float;
using Row_Tuple = ck::Tuple<Row>;
using F16_Tuple = ck::Tuple<F16>;
using BF16_Tuple = ck::Tuple<BF16>;
struct ProblemSize final
{
ck::index_t M = 3840;
ck::index_t N = 4096;
ck::index_t K = 4096;
ck::index_t StrideA = 4096;
ck::index_t StrideB = 4096;
ck::index_t StrideD = 4096;
ck::index_t StrideE = 4096;
};
struct ExecutionConfig final
{
bool do_verification = true;
int init_method = 1;
bool time_kernel = false;
};
inline bool
parse_cmd_args(int argc, char* argv[], ProblemSize& problem_size, ExecutionConfig& config)
{
if(argc == 1)
{
// use default case
}
else if(argc == 4)
{
config.do_verification = std::stoi(argv[1]);
config.init_method = std::stoi(argv[2]);
config.time_kernel = std::stoi(argv[3]);
}
else if(argc == 6)
{
config.do_verification = std::stoi(argv[1]);
config.init_method = std::stoi(argv[2]);
config.time_kernel = std::stoi(argv[3]);
}
else if(argc == 13)
{
config.do_verification = std::stoi(argv[1]);
config.init_method = std::stoi(argv[2]);
config.time_kernel = std::stoi(argv[3]);
problem_size.M = std::stoi(argv[4]);
problem_size.N = std::stoi(argv[5]);
problem_size.K = std::stoi(argv[6]);
problem_size.StrideA = std::stoi(argv[7]);
problem_size.StrideB = std::stoi(argv[8]);
problem_size.StrideD = std::stoi(argv[9]);
problem_size.StrideE = std::stoi(argv[10]);
}
else
{
std::cerr << "arg1: verification (0=no, 1=yes)" << std::endl
<< "arg2: initialization (0=no init, 1=integer value, 2=decimal value)"
<< std::endl
<< "arg3: time kernel (0=no, 1=yes)" << std::endl
<< "arg4 to 10: M (256x), N(128x), K(32x), StrideA, StrideB, StrideD,"
"StrideE"
<< std::endl;
return false;
}
return true;
}

View File

@@ -0,0 +1,78 @@
// SPDX-License-Identifier: MIT
// Copyright (c) 2025, Advanced Micro Devices, Inc. All rights reserved.
#include "common.hpp"
using ADataType = BF16;
using BDataType = BF16;
using AccDataType = F32;
using CShuffleDataType = F32;
using DDataType = BF16;
using DsDataType = BF16_Tuple;
using EDataType = BF16;
using Row_Tuple = ck::Tuple<Row>;
using ALayout = Row;
using BLayout = Row;
using DLayout = Row;
using DsLayout = Row_Tuple;
using ELayout = Row;
using AElementOp = PassThrough;
using BElementOp = PassThrough;
using CDEElementOp = AddRelu;
static constexpr auto GemmSpec = ck::tensor_operation::device::GemmSpecialization::MNKPadding;
using DeviceOpInstance = ck::tensor_operation::device::DeviceGemmMultipleD_Wmma_CShuffleV3<
Row,
Row,
Row_Tuple,
Row,
BF16,
BF16,
BF16_Tuple,
BF16,
F32,
F32,
PassThrough,
PassThrough,
AddRelu,
GemmSpec,
128,
128,
64,
64,
8,
8,
16,
16,
4,
2,
S<4, 32, 1>,
S<1, 0, 2>,
S<1, 0, 2>,
2,
8,
8,
0,
S<4, 32, 1>,
S<0, 2, 1>,
S<0, 2, 1>,
1,
1,
8,
0,
1,
1,
S<1, 32, 1, 4>,
S<8, 8, 8>,
ck::BlockGemmPipelineScheduler::Intrawave,
ck::BlockGemmPipelineVersion::v1>;
// clang-format on
#include "run_gemm_add_relu_example_wmma.inc"
int main(int argc, char* argv[]) { return !run_gemm_add_relu_example(argc, argv); }

View File

@@ -0,0 +1,76 @@
// SPDX-License-Identifier: MIT
// Copyright (c) 2025, Advanced Micro Devices, Inc. All rights reserved.
#include "common.hpp"
using ADataType = F16;
using BDataType = F16;
using AccDataType = F32;
using CShuffleDataType = F32;
using DDataType = F16;
using DsDataType = F16_Tuple;
using EDataType = F16;
using ALayout = Row;
using BLayout = Row;
using DLayout = Row;
using DsLayout = Row_Tuple;
using ELayout = Row;
using AElementOp = PassThrough;
using BElementOp = PassThrough;
using CDEElementOp = AddRelu;
static constexpr auto GemmSpec = ck::tensor_operation::device::GemmSpecialization::MNKPadding;
using DeviceOpInstance = ck::tensor_operation::device::DeviceGemmMultipleD_Wmma_CShuffleV3<
Row,
Row,
Row_Tuple,
Row,
F16,
F16,
F16_Tuple,
F16,
F32,
F32,
PassThrough,
PassThrough,
AddRelu,
GemmSpec,
128,
128,
64,
64,
8,
8,
16,
16,
4,
2,
S<4, 32, 1>,
S<1, 0, 2>,
S<1, 0, 2>,
2,
8,
8,
0,
S<4, 32, 1>,
S<0, 2, 1>,
S<0, 2, 1>,
1,
1,
8,
0,
1,
1,
S<1, 32, 1, 4>,
S<8, 8, 8>,
ck::BlockGemmPipelineScheduler::Intrawave,
ck::BlockGemmPipelineVersion::v1>;
// clang-format on
#include "run_gemm_add_relu_example_wmma.inc"
int main(int argc, char* argv[]) { return !run_gemm_add_relu_example(argc, argv); }

View File

@@ -0,0 +1,82 @@
// SPDX-License-Identifier: MIT
// Copyright (c) 2025, Advanced Micro Devices, Inc. All rights reserved.
#include "common.hpp"
template <ck::index_t... Is>
using S = ck::Sequence<Is...>;
using BF16 = ck::bhalf_t;
using F32 = float;
using Row = ck::tensor_layout::gemm::RowMajor;
using Col = ck::tensor_layout::gemm::ColumnMajor;
using PassThrough = ck::tensor_operation::element_wise::PassThrough;
using ADataType = BF16;
using BDataType = BF16;
using AccDataType = F32;
using CShuffleDataType = F32;
using DDataType = BF16;
using EDataType = BF16;
using ALayout = Row;
using BLayout = Col;
using DLayout = Row;
using ELayout = Row;
using AElementOp = PassThrough;
using BElementOp = PassThrough;
using CDEElementOp = AddRelu;
static constexpr auto GemmSpec = ck::tensor_operation::device::GemmSpecialization::MNKPadding;
using DeviceOpInstance =
ck::tensor_operation::device::DeviceGemmMultipleD_Xdl_CShuffle<ALayout,
BLayout,
ck::Tuple<DLayout>,
ELayout,
ADataType,
BDataType,
AccDataType,
CShuffleDataType,
ck::Tuple<DDataType>,
EDataType,
AElementOp,
BElementOp,
CDEElementOp,
GemmSpec,
1,
256,
256,
128,
32,
8,
8,
32,
32,
4,
2,
S<4, 64, 1>,
S<1, 0, 2>,
S<1, 0, 2>,
2,
8,
8,
1,
S<4, 64, 1>,
S<1, 0, 2>,
S<1, 0, 2>,
2,
8,
8,
1,
1,
1,
S<1, 32, 1, 8>,
8>;
#include "run_gemm_add_relu_example_xdl.inc"
int main(int argc, char* argv[]) { return !run_gemm_add_relu_example(argc, argv); }

View File

@@ -0,0 +1,82 @@
// SPDX-License-Identifier: MIT
// Copyright (c) 2025, Advanced Micro Devices, Inc. All rights reserved.
#include "common.hpp"
template <ck::index_t... Is>
using S = ck::Sequence<Is...>;
using F16 = ck::half_t;
using F32 = float;
using Row = ck::tensor_layout::gemm::RowMajor;
using Col = ck::tensor_layout::gemm::ColumnMajor;
using PassThrough = ck::tensor_operation::element_wise::PassThrough;
using ADataType = F16;
using BDataType = F16;
using AccDataType = F32;
using CShuffleDataType = F32;
using DDataType = F16;
using EDataType = F16;
using ALayout = Row;
using BLayout = Col;
using DLayout = Row;
using ELayout = Row;
using AElementOp = PassThrough;
using BElementOp = PassThrough;
using CDEElementOp = AddRelu;
static constexpr auto GemmSpec = ck::tensor_operation::device::GemmSpecialization::MNKPadding;
using DeviceOpInstance =
ck::tensor_operation::device::DeviceGemmMultipleD_Xdl_CShuffle<ALayout,
BLayout,
ck::Tuple<DLayout>,
ELayout,
ADataType,
BDataType,
AccDataType,
CShuffleDataType,
ck::Tuple<DDataType>,
EDataType,
AElementOp,
BElementOp,
CDEElementOp,
GemmSpec,
1,
256,
256,
128,
32,
8,
8,
32,
32,
4,
2,
S<4, 64, 1>,
S<1, 0, 2>,
S<1, 0, 2>,
2,
8,
8,
1,
S<4, 64, 1>,
S<1, 0, 2>,
S<1, 0, 2>,
2,
8,
8,
1,
1,
1,
S<1, 32, 1, 8>,
8>;
#include "run_gemm_add_relu_example_xdl.inc"
int main(int argc, char* argv[]) { return !run_gemm_add_relu_example(argc, argv); }

View File

@@ -0,0 +1,146 @@
#pragma once
bool run_gemm_add_relu(const ProblemSize& problem_size, const ExecutionConfig& config)
{
using namespace ck::literals;
auto& [M, N, K, StrideA, StrideB, StrideD, StrideE] = problem_size;
auto f_host_tensor_descriptor =
[](std::size_t row, std::size_t col, std::size_t stride, auto layout) {
if(std::is_same<decltype(layout), ck::tensor_layout::gemm::RowMajor>::value)
{
return HostTensorDescriptor({row, col}, {stride, 1_uz});
}
else
{
return HostTensorDescriptor({row, col}, {1_uz, stride});
}
};
Tensor<ADataType> a_m_k(f_host_tensor_descriptor(M, K, StrideA, ALayout{}));
Tensor<BDataType> b_k_n(f_host_tensor_descriptor(K, N, StrideB, BLayout{}));
Tensor<DDataType> d_m_n(f_host_tensor_descriptor(M, N, StrideD, DLayout{}));
Tensor<EDataType> e_m_n_host_result(f_host_tensor_descriptor(M, N, StrideE, ELayout{}));
Tensor<EDataType> e_m_n_device_result(f_host_tensor_descriptor(M, N, StrideE, ELayout{}));
std::cout << "a_m_k: " << a_m_k.mDesc << std::endl;
std::cout << "b_k_n: " << b_k_n.mDesc << std::endl;
std::cout << "d_m_n: " << d_m_n.mDesc << std::endl;
std::cout << "e_m_n: " << e_m_n_host_result.mDesc << std::endl;
switch(config.init_method)
{
case 0: break;
case 1:
a_m_k.GenerateTensorValue(GeneratorTensor_2<ADataType>{-5, 5});
b_k_n.GenerateTensorValue(GeneratorTensor_2<BDataType>{-5, 5});
d_m_n.GenerateTensorValue(GeneratorTensor_2<DDataType>{-5, 5});
break;
default:
a_m_k.GenerateTensorValue(GeneratorTensor_3<ADataType>{0.0, 1.0});
b_k_n.GenerateTensorValue(GeneratorTensor_3<BDataType>{-0.5, 0.5});
d_m_n.GenerateTensorValue(GeneratorTensor_3<DDataType>{-0.5, 0.5});
}
DeviceMem a_device_buf(sizeof(ADataType) * a_m_k.mDesc.GetElementSpaceSize());
DeviceMem b_device_buf(sizeof(BDataType) * b_k_n.mDesc.GetElementSpaceSize());
DeviceMem d_device_buf(sizeof(DDataType) * d_m_n.mDesc.GetElementSpaceSize());
DeviceMem e_device_buf(sizeof(EDataType) * e_m_n_device_result.mDesc.GetElementSpaceSize());
a_device_buf.ToDevice(a_m_k.mData.data());
b_device_buf.ToDevice(b_k_n.mData.data());
d_device_buf.ToDevice(d_m_n.mData.data());
e_device_buf.ToDevice(e_m_n_device_result.mData.data());
auto a_element_op = AElementOp{};
auto b_element_op = BElementOp{};
auto cde_element_op = CDEElementOp{};
// do GEMM
auto device_op = DeviceOpInstance{};
auto invoker = device_op.MakeInvoker();
auto argument =
device_op.MakeArgument(a_device_buf.GetDeviceBuffer(),
b_device_buf.GetDeviceBuffer(),
std::array<const void*, 1>{d_device_buf.GetDeviceBuffer()},
e_device_buf.GetDeviceBuffer(),
M,
N,
K,
StrideA,
StrideB,
std::array<ck::index_t, 1>{StrideD},
StrideE,
1,
a_element_op,
b_element_op,
cde_element_op);
if(!device_op.IsSupportedArgument(argument))
{
throw std::runtime_error(
"wrong! device_gemm with the specified compilation parameters does "
"not support this GEMM problem");
}
float ave_time = invoker.Run(argument, StreamConfig{nullptr, config.time_kernel});
std::size_t flop = std::size_t(2) * M * N * K;
std::size_t num_btype =
sizeof(ADataType) * M * K + sizeof(BDataType) * K * N + sizeof(EDataType) * M * N;
float tflops = static_cast<float>(flop) / 1.E9 / ave_time;
float gb_per_sec = num_btype / 1.E6 / ave_time;
std::cout << "Perf: " << ave_time << " ms, " << tflops << " TFlops, " << gb_per_sec << " GB/s"
<< device_op.GetTypeString() << std::endl;
e_device_buf.FromDevice(e_m_n_device_result.mData.data());
bool pass = true;
if(config.do_verification)
{
Tensor<CShuffleDataType> c_m_n({M, N});
using ReferenceGemmInstance = ck::tensor_operation::host::ReferenceGemm<ADataType,
BDataType,
CShuffleDataType,
AccDataType,
AElementOp,
BElementOp,
PassThrough>;
auto ref_gemm = ReferenceGemmInstance{};
auto ref_invoker = ref_gemm.MakeInvoker();
auto ref_argument =
ref_gemm.MakeArgument(a_m_k, b_k_n, c_m_n, a_element_op, b_element_op, PassThrough{});
ref_invoker.Run(ref_argument);
for(int m = 0; m < M; ++m)
{
for(int n = 0; n < N; ++n)
{
cde_element_op(e_m_n_host_result(m, n), c_m_n(m, n), d_m_n(m, n));
}
}
e_device_buf.FromDevice(e_m_n_device_result.mData.data());
pass &= ck::utils::check_err(e_m_n_device_result, e_m_n_host_result);
}
return pass;
}
bool run_gemm_add_relu_example(int argc, char* argv[])
{
ProblemSize problem_size;
ExecutionConfig config;
return parse_cmd_args(argc, argv, problem_size, config) &&
run_gemm_add_relu(problem_size, config);
}

View File

@@ -0,0 +1,145 @@
#pragma once
bool run_gemm_add_relu(const ProblemSize& problem_size, const ExecutionConfig& config)
{
using namespace ck::literals;
auto& [M, N, K, StrideA, StrideB, StrideD, StrideE] = problem_size;
auto f_host_tensor_descriptor =
[](std::size_t row, std::size_t col, std::size_t stride, auto layout) {
if(std::is_same<decltype(layout), ck::tensor_layout::gemm::RowMajor>::value)
{
return HostTensorDescriptor({row, col}, {stride, 1_uz});
}
else
{
return HostTensorDescriptor({row, col}, {1_uz, stride});
}
};
Tensor<ADataType> a_m_k(f_host_tensor_descriptor(M, K, StrideA, ALayout{}));
Tensor<BDataType> b_k_n(f_host_tensor_descriptor(K, N, StrideB, BLayout{}));
Tensor<DDataType> d_m_n(f_host_tensor_descriptor(M, N, StrideD, DLayout{}));
Tensor<EDataType> e_m_n_host_result(f_host_tensor_descriptor(M, N, StrideE, ELayout{}));
Tensor<EDataType> e_m_n_device_result(f_host_tensor_descriptor(M, N, StrideE, ELayout{}));
std::cout << "a_m_k: " << a_m_k.mDesc << std::endl;
std::cout << "b_k_n: " << b_k_n.mDesc << std::endl;
std::cout << "d_m_n: " << d_m_n.mDesc << std::endl;
std::cout << "e_m_n: " << e_m_n_host_result.mDesc << std::endl;
switch(config.init_method)
{
case 0: break;
case 1:
a_m_k.GenerateTensorValue(GeneratorTensor_2<ADataType>{-5, 5});
b_k_n.GenerateTensorValue(GeneratorTensor_2<BDataType>{-5, 5});
d_m_n.GenerateTensorValue(GeneratorTensor_2<DDataType>{-5, 5});
break;
default:
a_m_k.GenerateTensorValue(GeneratorTensor_3<ADataType>{0.0, 1.0});
b_k_n.GenerateTensorValue(GeneratorTensor_3<BDataType>{-0.5, 0.5});
d_m_n.GenerateTensorValue(GeneratorTensor_3<DDataType>{-0.5, 0.5});
}
DeviceMem a_device_buf(sizeof(ADataType) * a_m_k.mDesc.GetElementSpaceSize());
DeviceMem b_device_buf(sizeof(BDataType) * b_k_n.mDesc.GetElementSpaceSize());
DeviceMem d_device_buf(sizeof(DDataType) * d_m_n.mDesc.GetElementSpaceSize());
DeviceMem e_device_buf(sizeof(EDataType) * e_m_n_device_result.mDesc.GetElementSpaceSize());
a_device_buf.ToDevice(a_m_k.mData.data());
b_device_buf.ToDevice(b_k_n.mData.data());
d_device_buf.ToDevice(d_m_n.mData.data());
e_device_buf.ToDevice(e_m_n_device_result.mData.data());
auto a_element_op = AElementOp{};
auto b_element_op = BElementOp{};
auto cde_element_op = CDEElementOp{};
// do GEMM
auto device_op = DeviceOpInstance{};
auto invoker = device_op.MakeInvoker();
auto argument =
device_op.MakeArgument(a_device_buf.GetDeviceBuffer(),
b_device_buf.GetDeviceBuffer(),
std::array<const void*, 1>{d_device_buf.GetDeviceBuffer()},
e_device_buf.GetDeviceBuffer(),
M,
N,
K,
StrideA,
StrideB,
std::array<ck::index_t, 1>{StrideD},
StrideE,
a_element_op,
b_element_op,
cde_element_op);
if(!device_op.IsSupportedArgument(argument))
{
throw std::runtime_error(
"wrong! device_gemm with the specified compilation parameters does "
"not support this GEMM problem");
}
float ave_time = invoker.Run(argument, StreamConfig{nullptr, config.time_kernel});
std::size_t flop = std::size_t(2) * M * N * K;
std::size_t num_btype =
sizeof(ADataType) * M * K + sizeof(BDataType) * K * N + sizeof(EDataType) * M * N;
float tflops = static_cast<float>(flop) / 1.E9 / ave_time;
float gb_per_sec = num_btype / 1.E6 / ave_time;
std::cout << "Perf: " << ave_time << " ms, " << tflops << " TFlops, " << gb_per_sec << " GB/s"
<< device_op.GetTypeString() << std::endl;
e_device_buf.FromDevice(e_m_n_device_result.mData.data());
bool pass = true;
if(config.do_verification)
{
Tensor<CShuffleDataType> c_m_n({M, N});
using ReferenceGemmInstance = ck::tensor_operation::host::ReferenceGemm<ADataType,
BDataType,
CShuffleDataType,
AccDataType,
AElementOp,
BElementOp,
PassThrough>;
auto ref_gemm = ReferenceGemmInstance{};
auto ref_invoker = ref_gemm.MakeInvoker();
auto ref_argument =
ref_gemm.MakeArgument(a_m_k, b_k_n, c_m_n, a_element_op, b_element_op, PassThrough{});
ref_invoker.Run(ref_argument);
for(int m = 0; m < M; ++m)
{
for(int n = 0; n < N; ++n)
{
cde_element_op(e_m_n_host_result(m, n), c_m_n(m, n), d_m_n(m, n));
}
}
e_device_buf.FromDevice(e_m_n_device_result.mData.data());
pass &= ck::utils::check_err(e_m_n_device_result, e_m_n_host_result);
}
return pass;
}
bool run_gemm_add_relu_example(int argc, char* argv[])
{
ProblemSize problem_size;
ExecutionConfig config;
return parse_cmd_args(argc, argv, problem_size, config) &&
run_gemm_add_relu(problem_size, config);
}

View File

@@ -125,7 +125,8 @@ using dq_dk_dv_trait_{F_idx} = fmha_bwd_dq_dk_dv_traits_<{F_hdim},
{F_dvpad},
{F_deterministic},
{F_trload},
{F_maxq}>;
{F_maxq},
{F_bn0}>;
#include <iostream>
@@ -218,10 +219,10 @@ def FMHA_BWD_API_COND_STATEMENT(F_cond: str, F_body: str, *, indent=0, if_ = 0)
FMHA_BWD_API_INNER_DISPATCH="""
{F_if}((t.is_group_mode == {F_mode}) && ({F_mask_check}) && (t.bias_type == {F_bias_check}) && (t.has_dbias == {F_dbias}) && ({F_dropout_check}) &&
({F_scheck}) && ({F_dcheck}) && ({F_dvcheck}) && (t.is_deterministic == {F_deterministic})) {{
({F_scheck}) && ({F_dcheck}) && ({F_dvcheck}) && (t.is_deterministic == {F_deterministic}){F_cond_extra}) {{
using dot_do_o_trait_ = fmha_bwd_dot_do_o_traits_<{F_hdim}, {F_dtype}, {F_mode}, {F_spad1d}, {F_dvpad}>;
using dq_dk_dv_trait_ = fmha_bwd_dq_dk_dv_traits_<{F_hdim}, {F_dtype}, {F_mode}, {F_mask}, {F_dropout}, {F_bias}, {F_dbias}, {F_dpad}, {F_dvpad}, {F_deterministic}, {F_trload}, {F_maxq}>;
using convert_dq_trait_ = fmha_bwd_convert_dq_traits_<{F_hdim}, {F_dtype}, {F_mode}, {F_spad1d}, {F_dpad}, {F_deterministic}>;
using dq_dk_dv_trait_ = fmha_bwd_dq_dk_dv_traits_<{F_hdim}, {F_dtype}, {F_mode}, {F_mask}, {F_dropout}, {F_bias}, {F_dbias}, {F_dpad}, {F_dvpad}, {F_deterministic}, {F_trload}, {F_maxq}, {F_bn0}>;
using convert_dq_trait_ = fmha_bwd_convert_dq_traits_<{F_hdim}, {F_dtype}, {F_mode}, {F_spad1d}, {F_dpad}, {F_deterministic}, {F_convert_dq_bn0}>;
r = fmha_bwd_<dot_do_o_trait_, dq_dk_dv_trait_, std::conditional_t<{F_convert_dq_enabled}, convert_dq_trait_, void>>(s, a);
return r;
}}
@@ -386,6 +387,7 @@ def get_dq_dk_dv_tiles(dtype : str, tr_load: str) -> List[FmhaBwdDQDKDVTileSize]
elif (dtype == 'fp16' or dtype == 'bf16') and tr_load == 't':
return [
FmhaBwdDQDKDVTileSize( 32, 128, 128, 32, 128, 32, 32, 128, 128, 1, 4, 1, 4, 1, 1, 1, 4, 1, 16, 16, 32, 16, 16, 32, 1),
FmhaBwdDQDKDVTileSize( 16, 192, 128, 16, 128, 16, 32, 128, 128, 1, 4, 1, 4, 1, 1, 1, 4, 1, 16, 16, 32, 16, 16, 16, 1),
# FmhaBwdDQDKDVTileSize( 16, 32, 128, 16, 128, 16, 32, 128, 128, 1, 1, 1, 1, 1, 1, 1, 1, 1, 16, 16, 32, 16, 16, 16, 1, 16),
FmhaBwdDQDKDVTileSize( 16, 16, 128, 16, 128, 16, 16, 128, 128, 1, 1, 1, 1, 1, 1, 1, 1, 1, 16, 16, 32, 16, 16, 16, 2, 16),
]
@@ -519,7 +521,8 @@ using convert_dq_trait_{F_idx} = fmha_bwd_convert_dq_traits_<{F_hdim},
{F_mode},
{F_spad},
{F_dpad},
{F_deterministic}>;
{F_deterministic},
{F_bn0}>;
#include <iostream>
@@ -656,6 +659,17 @@ class FmhaBwdApiTrait:
if self.dvpad == 't': return f'a.hdim_v % {self.bhdv} != 0'
else : return f'a.hdim_v % {self.bhdv} == 0'
@property
def extra_cond(self) -> str:
if self.tr_load == 't' and self.tile.max_seq_q == 0 and self.tile.F_bn0 == 128:
return "&& (a.seqlen_k <= 256)"
else:
return ""
@property
def convert_dq_bn0(self) -> int:
return self.tile.F_bn0 if self.deterministic == 't' else 0
@property
def dot_do_o_kernel(self) -> FmhaBwdOGradDotOKernel:
# TODO: we don't support tuning yet, so pick up one value for pad/occupancy
@@ -680,7 +694,7 @@ class FmhaBwdApiTrait:
return 2
return FmhaBwdConvertQGradKernel(F_idx=self.idx, F_hdim=self.hdim, F_dtype=self.dtype,
F_bm0=M0_1D, F_bn0=self.tile.F_bn0, F_spad=self.spad1d, F_dpad=self.dpad,
F_bm0=M0_1D, F_bn0=self.convert_dq_bn0, F_spad=self.spad1d, F_dpad=self.dpad,
F_mode=self.mode, F_occupancy=get_occupancy(self.dtype, self.hdim),
F_deterministic=self.deterministic, disabled=self.tile.max_seq_q != 0)
@@ -708,7 +722,8 @@ class FmhaBwdApiPool:
F_scheck=trait.scheck, F_dcheck=trait.dcheck, F_dvcheck=trait.dvcheck, F_hdim=trait.hdim, F_dtype=BWD_DTYPE_MAP[trait.dtype],
F_spad1d=BOOL_MAP[trait.spad1d], F_dpad=BOOL_MAP[trait.dpad], F_dvpad=BOOL_MAP[trait.dvpad],
F_deterministic=BOOL_MAP[trait.deterministic], F_trload=BOOL_MAP[trait.tr_load], F_maxq=trait.tile.max_seq_q,
F_convert_dq_enabled=BOOL_MAP[not trait.convert_dq_kernel.disabled])
F_convert_dq_enabled=BOOL_MAP[not trait.convert_dq_kernel.disabled], F_bn0=trait.tile.F_bn0, F_cond_extra=trait.extra_cond,
F_convert_dq_bn0=trait.convert_dq_bn0)
i += 1
return inners
@@ -791,6 +806,9 @@ def get_bwd_blobs(filter_list: str, receipt, mask_impl, optdim_list) -> Tuple[Fm
continue
if tr_load == "t" and (dpad == "t" or dvpad == "t"):
continue # tr_load cannot work with dpad or dvpad
if optdim_list != [-1]:
if hdim not in optdim_list:
continue
t = FmhaBwdApiTrait(idx=0, hdim=hdim, dtype=dtype, mode=mode,tile=tile,mask=mask, bias=bias, dbias=dbias, dropout=dropout, spad1d=spad1d, dpad=dpad, dvpad=dvpad, deterministic=deterministic, mask_impl=mask_impl, tr_load=tr_load)
if not fnmatch.fnmatch(t.dot_do_o_kernel.name, filter_dot_do_o):
@@ -799,9 +817,6 @@ def get_bwd_blobs(filter_list: str, receipt, mask_impl, optdim_list) -> Tuple[Fm
continue
if not fnmatch.fnmatch(t.convert_dq_kernel.name, filter_convert_dq):
continue
if optdim_list != [-1]:
if hdim not in optdim_list:
continue
# Flash attention integration
if receipt == 2:

View File

@@ -372,7 +372,8 @@ template <ck_tile::index_t HDim_,
bool kPadDv_,
bool kIsDeterministic_,
bool kUseTrLoad_,
ck_tile::index_t MaxSeqLenQ_>
ck_tile::index_t MaxSeqLenQ_,
ck_tile::index_t kN0>
struct fmha_bwd_dq_dk_dv_traits_
{
};
@@ -412,15 +413,10 @@ template <ck_tile::index_t HDim_,
bool kIsGroupMode_,
bool kPadS_,
bool kPadD_,
bool kIsDeterministic_>
bool kIsDeterministic_,
ck_tile::index_t kN0>
struct fmha_bwd_convert_dq_traits_
{
static constexpr ck_tile::index_t HDim = HDim_;
using DataType = ck_tile::remove_cvref_t<DataType_>;
static constexpr bool kIsGroupMode = kIsGroupMode_;
static constexpr bool kPadS = kPadS_;
static constexpr bool kPadD = kPadD_;
static constexpr bool kIsDeterministic = kIsDeterministic_;
};
template <typename Traits_>

2
example/ck_tile/01_fmha/mask.hpp Executable file → Normal file
View File

@@ -1,5 +1,5 @@
// SPDX-License-Identifier: MIT
// Copyright (c) 2018-2024, Advanced Micro Devices, Inc. All rights reserved.
// Copyright (c) 2018-2025, Advanced Micro Devices, Inc. All rights reserved.
#pragma once

View File

@@ -91,7 +91,7 @@ int main(int argc, char* argv[])
try
{
return !run_gemm_example<GemmConfigPreshuffleDecode>(arg_parser);
return !run_gemm_example<GemmConfigPreshufflePrefill>(arg_parser);
}
catch(const std::runtime_error& e)
{

View File

@@ -1 +1,3 @@
add_executable(tile_example_grouped_gemm EXCLUDE_FROM_ALL grouped_gemm.cpp)
add_executable(tile_example_quant_grouped_gemm EXCLUDE_FROM_ALL quant_grouped_gemm.cpp)
add_executable(tile_example_grouped_gemm_preshuffle EXCLUDE_FROM_ALL grouped_gemm_preshuffle.cpp)

View File

@@ -8,11 +8,11 @@ The `Grouped GEMM` operators are versions of GEMM that run multiple GEMM operati
Let's now break the example into the following parts: parsing arguments, preparing host and device buffers, preparing data, invoking GEMM, and building the example, while explaining each function.
### Parsing Arguments
The example takes three arguments: `group_count`, `repeat`, and `warmup`:
- `group_count`: the number of GEMM operations in the group,
### Key Arguments
The example takes several arguments including `group_count`, `repeat`, and `warmup`:
- `group_count`: the number of GEMM operations in the group
- `repeat`: the number of times to repeat the kernel for benchmarking
- `warmup`: the number of iterations before the actual kernel run time measure.
- `warmup`: the number of iterations before the actual kernel run time measure
```cpp
// Example
@@ -133,6 +133,28 @@ float invoke_gemm(int n_warmup,
ck_tile::DeviceMem gemm_workspace;
gemm_workspace.Realloc(GetWorkspaceSize(args));
```
### Advanced Features: Preshuffle and Persistence
The grouped GEMM examples include two advanced optimization features:
#### Weight Preshuffle
Weight preshuffle is an optimization technique that reorganizes the B matrix (weights) in memory to improve data access patterns and reduce memory bandwidth requirements. This is particularly beneficial for inference workloads where the same weights are reused across multiple batches.
- **Implementation**: Available in `grouped_gemm_preshuffle.cpp`
- **Configuration**: Uses `GemmConfigPreshuffleDecode` template configuration
- **Constraints**: Currently supports only A(Row major) + B(Column major) → C(Row major) layouts
- **Benefits**: Improved memory efficiency and reduced data movement
#### Persistence Mode
Persistence mode is a GPU optimization where thread blocks remain active on the compute units to process multiple work items sequentially, reducing kernel launch overhead and improving occupancy.
- **Template Parameter**: Controlled by the `Persistent` boolean template parameter in `invoke_gemm`
- **Usage**: `invoke_gemm<ALayout, BLayout, CLayout, true>` enables persistence
- **Benefits**: Reduced kernel launch overhead, better resource utilization for small matrix sizes
Both features can be combined with different data types (fp16, fp8) and layout configurations to optimize performance for specific workloads.
Finally the arguments are passed to group_gemm and the kernel is launched.
```cpp
// API
@@ -151,26 +173,44 @@ mkdir build && cd build
../script/cmake-ck-dev.sh ../ <arch>
# The basic pipeline method on the gemm calculation
make tile_example_grouped_gemm -j
# The preshuffle example
make tile_example_grouped_gemm_preshuffle -j
# The quant grouped gemm fp8 example
make tile_example_quant_grouped_gemm -j
```
This will result in an executable `build/bin/tile_example_grouped_gemm`
## example
```
args:
-Ms M dimensions - empty by default. (default:)
-Ns N dimensions - empty by default. (default:)
-Ks K dimensions - empty by default. (default:)
-stride_As Tensor A strides - it is empty by default. (default:)
-stride_Bs Tensor B strides - it is empty by default. (default:)
-stride_Cs Tensor C strides - it is empty by default. (default:)
-a_layout A tensor data layout - Row by default. (default:R)
-b_layout B tensor data layout - Row by default. (default:C)
-c_layout C tensor data layout - Row by default. (default:R)
-validate 0. No validation, 1. Validation on CPU. (default:1)
-warmup number of iterations before benchmark the kernel. (default:10)
-repeat number of iterations to benchmark the kernel. (default:100)
-group_count group count. (default:8)
-kbatch kbatch for SplitK (default:1)
-json 0: No Json, 1: Dump Results in Json format (default:0)
-jsonfile json file name to dump results (default:grouped_gemm.json)
-Ms M dimensions - (Default: empty).
-Ns N dimensions - (Default: empty).
-Ks K dimensions - (Default: empty).
-stride_As Tensor A strides - (Default: empty).
-stride_Bs Tensor B strides - (Default: empty).
-stride_Cs Tensor C strides - (Default: empty).
-a_layout A tensor data layout - (Default: Row).
-b_layout B tensor data layout - (Default: Col).
-c_layout C tensor data layout - (Default: Row).
-prec data type. fp16/fp8 - (Default: fp16).
-validate 0. No validation, 1. Validation on CPU. (Default: 1).
-warmup Number of iterations before benchmark the kernel. (Default: 10).
-repeat Number of iterations to benchmark the kernel. (Default: 100).
-group_count Group count. (Default: 16).
-kbatch kbatch for SplitK (Default: 1).
-json 0: No Json, 1: Dump Results in Json format (Default: 0).
-jsonfile json file name to dump results (Default: grouped_gemm.json).
```
If any of `Ms`, `Ns`, `Ks`, `stride_As`, `stride_Bs`, or `stride_Cs` are missing or their sizes
don't match `group_count`, the example generates defaults per group index `i` (0-based):
```text
M[i] = 256 + 256 * i
N[i] = 256 + 512 * i
K[i] = 512 + 384 * i
stride_A[i] = K[i]
stride_B[i] = K[i]
stride_C[i] = N[i]
```

View File

@@ -16,6 +16,155 @@
#include "ck_tile/host.hpp"
#include "grouped_gemm.hpp"
template <typename GemmConfig,
typename ADataType,
typename BDataType,
typename DsDataType,
typename AccDataType,
typename CDataType,
typename ALayout,
typename BLayout,
typename DsLayout,
typename CLayout,
typename CDEElementWise = ck_tile::element_wise::PassThrough>
float grouped_gemm(const std::vector<grouped_gemm_kargs>& gemm_descs,
const ck_tile::stream_config& s,
void* kargs_ptr)
{
using GemmShape = ck_tile::TileGemmShape<
ck_tile::sequence<GemmConfig::M_Tile, GemmConfig::N_Tile, GemmConfig::K_Tile>,
ck_tile::sequence<GemmConfig::M_Warp, GemmConfig::N_Warp, GemmConfig::K_Warp>,
ck_tile::
sequence<GemmConfig::M_Warp_Tile, GemmConfig::N_Warp_Tile, GemmConfig::K_Warp_Tile>>;
using TilePartitioner =
ck_tile::GemmSpatiallyLocalTilePartitioner<GemmShape,
GemmConfig::TileParitionerGroupNum,
GemmConfig::TileParitionerM01>;
using Traits = ck_tile::TileGemmTraits<GemmConfig::kPadM,
GemmConfig::kPadN,
GemmConfig::kPadK,
ALayout,
BLayout,
CLayout>;
using GemmUniversalTraits = ck_tile::TileGemmUniversalTraits<GemmConfig::kPadM,
GemmConfig::kPadN,
GemmConfig::kPadK,
GemmConfig::DoubleSmemBuffer,
ALayout,
BLayout,
CLayout,
GemmConfig::TransposeC>;
using GemmPipelineProblem =
ck_tile::GemmPipelineProblem<ADataType, BDataType, AccDataType, GemmShape, Traits>;
using BaseGemmPipeline = typename PipelineTypeTraits<
GemmConfig::Pipeline>::template UniversalGemmPipeline<GemmPipelineProblem>;
const ck_tile::index_t k_grain = gemm_descs[0].k_batch * GemmConfig::K_Tile;
const ck_tile::index_t K_split = (gemm_descs[0].K + k_grain - 1) / k_grain * GemmConfig::K_Tile;
const ck_tile::index_t num_loop = TilePartitioner::GetLoopNum(K_split);
const bool has_hot_loop = BaseGemmPipeline::BlockHasHotloop(num_loop);
const ck_tile::TailNumber tail_num = BaseGemmPipeline::GetBlockLoopTailNum(num_loop);
float ave_time{0};
const auto Run = [&](const auto has_hot_loop_,
const auto tail_number_,
const auto memory_operation_) {
constexpr bool has_hot_loop_v = has_hot_loop_.value;
constexpr auto tail_number_v = tail_number_.value;
constexpr auto scheduler = GemmConfig::Scheduler;
constexpr auto memory_operation = memory_operation_.value;
using UniversalGemmProblem = ck_tile::UniversalGemmPipelineProblem<ADataType,
BDataType,
AccDataType,
GemmShape,
GemmUniversalTraits,
scheduler,
has_hot_loop_v,
tail_number_v>;
using GemmPipeline = typename PipelineTypeTraits<
GemmConfig::Pipeline>::template GemmPipeline<UniversalGemmProblem>;
using GemmEpilogue = ck_tile::CShuffleEpilogue<
ck_tile::CShuffleEpilogueProblem<ADataType,
BDataType,
DsDataType,
AccDataType,
CDataType,
DsLayout,
CLayout,
CDEElementWise,
TilePartitioner::MPerBlock,
TilePartitioner::NPerBlock,
GemmConfig::M_Warp,
GemmConfig::N_Warp,
GemmConfig::M_Warp_Tile,
GemmConfig::N_Warp_Tile,
GemmConfig::K_Warp_Tile,
UniversalGemmProblem::TransposeC,
memory_operation>>;
using Kernel = ck_tile::GroupedGemmKernel<TilePartitioner, GemmPipeline, GemmEpilogue>;
auto kargs = Kernel::MakeKargs(gemm_descs);
if(!Kernel::IsSupportedArgument(kargs))
{
throw std::runtime_error("Kernel arguments not supported!");
}
const dim3 blocks = Kernel::BlockSize();
const dim3 grids = Kernel::GridSize(gemm_descs);
HIP_CHECK_ERROR(hipMemcpyWithStream(kargs_ptr,
kargs.data(),
get_workspace_size(gemm_descs),
hipMemcpyHostToDevice,
s.stream_id_));
if(s.log_level_ > 0)
{
std::cout << "Launching kernel: " << Kernel::GetName() << " with args:" << " grid: {"
<< grids.x << ", " << grids.y << ", " << grids.z << "}" << ", blocks: {"
<< blocks.x << ", " << blocks.y << ", " << blocks.z << "}" << std::endl;
}
ave_time =
ck_tile::launch_kernel(s,
ck_tile::make_kernel<GemmConfig::kBlockPerCu>(
Kernel{},
grids,
blocks,
0,
ck_tile::cast_pointer_to_constant_address_space(kargs_ptr),
gemm_descs.size()));
return ave_time;
};
const auto RunSplitk = [&](const auto has_hot_loop_, const auto tail_number_) {
if(gemm_descs[0].k_batch == 1)
{
Run(has_hot_loop_,
tail_number_,
ck_tile::integral_constant<ck_tile::memory_operation_enum,
ck_tile::memory_operation_enum::set>{});
}
else
{
Run(has_hot_loop_,
tail_number_,
ck_tile::integral_constant<ck_tile::memory_operation_enum,
ck_tile::memory_operation_enum::atomic_add>{});
}
};
BaseGemmPipeline::TailHandler(RunSplitk, has_hot_loop, tail_num);
return ave_time;
}
template <typename GemmConfig,
typename ALayout,
typename BLayout,
@@ -29,16 +178,15 @@ float grouped_gemm_tileloop(const ck_tile::stream_config& s,
void* kargs_ptr,
bool splitk)
{
constexpr ck_tile::index_t TileParitionerGroupNum = 8;
constexpr ck_tile::index_t TileParitionerM01 = 4;
using GemmShape = ck_tile::TileGemmShape<
ck_tile::sequence<GemmConfig::M_Tile, GemmConfig::N_Tile, GemmConfig::K_Tile>,
ck_tile::sequence<GemmConfig::M_Warp, GemmConfig::N_Warp, GemmConfig::K_Warp>,
ck_tile::
sequence<GemmConfig::M_Warp_Tile, GemmConfig::N_Warp_Tile, GemmConfig::K_Warp_Tile>>;
using TilePartitioner = ck_tile::
GemmSpatiallyLocalTilePartitioner<GemmShape, TileParitionerGroupNum, TileParitionerM01>;
using TilePartitioner =
ck_tile::GemmSpatiallyLocalTilePartitioner<GemmShape,
GemmConfig::TileParitionerGroupNum,
GemmConfig::TileParitionerM01>;
using GemmUniversalTraits =
ck_tile::PersistentTileGemmUniversalTraits<GemmConfig::kPadM,
@@ -124,8 +272,86 @@ float grouped_gemm_tileloop(const ck_tile::stream_config& s,
#include "run_grouped_gemm_example.inc"
constexpr bool Persistent = true;
template <typename GemmConfig, typename PrecType>
int run_gemm_example_prec_type(std::string a_layout, std::string b_layout, int argc, char* argv[])
{
using Row = ck_tile::tensor_layout::gemm::RowMajor;
using Col = ck_tile::tensor_layout::gemm::ColumnMajor;
using Types = GemmTypeConfig<PrecType>;
// Specific type aliases for easy access
using ADataType = typename Types::ADataType;
using BDataType = typename Types::BDataType;
using AccDataType = typename Types::AccDataType;
using CDataType = typename Types::CDataType;
if(a_layout == "R" && b_layout == "C")
{
return run_grouped_gemm_example_with_layouts<GemmConfig,
ADataType,
BDataType,
CDataType,
AccDataType>(argc, argv, Row{}, Col{}, Row{});
}
else if(a_layout == "R" && b_layout == "R")
{
return run_grouped_gemm_example_with_layouts<GemmConfig,
ADataType,
BDataType,
CDataType,
AccDataType>(argc, argv, Row{}, Row{}, Row{});
}
else if(a_layout == "C" && b_layout == "R")
{
return run_grouped_gemm_example_with_layouts<GemmConfig,
ADataType,
BDataType,
CDataType,
AccDataType>(argc, argv, Col{}, Row{}, Row{});
}
else if(a_layout == "C" && b_layout == "C")
{
return run_grouped_gemm_example_with_layouts<GemmConfig,
ADataType,
BDataType,
CDataType,
AccDataType>(argc, argv, Col{}, Col{}, Row{});
}
else
{
throw std::runtime_error("Unsupported data layout configuration for A and B tensors!");
}
}
template <template <typename PrecType> typename GemmConfig>
int run_grouped_gemm_example(int argc, char* argv[])
{
auto [result, arg_parser] = create_args(argc, argv);
if(!result)
{
return -1;
}
const std::string a_layout = arg_parser.get_str("a_layout");
const std::string b_layout = arg_parser.get_str("b_layout");
const std::string data_type = arg_parser.get_str("prec");
if(data_type == "fp16")
{
return run_gemm_example_prec_type<GemmConfig<ck_tile::half_t>, ck_tile::half_t>(
a_layout, b_layout, argc, argv);
}
else if(data_type == "fp8")
{
return run_gemm_example_prec_type<GemmConfig<ck_tile::fp8_t>, ck_tile::fp8_t>(
a_layout, b_layout, argc, argv);
}
else
{
throw std::runtime_error("Unsupported data type configuration.");
}
}
int main(int argc, char* argv[])
{
return !run_grouped_gemm_example<Persistent, GemmConfigComputeV4>(argc, argv);
return !run_grouped_gemm_example<GemmConfigComputeV4>(argc, argv);
}

View File

@@ -4,6 +4,7 @@
#pragma once
#include <string>
#include <tuple>
#include "ck_tile/core.hpp"
#include "ck_tile/host/kernel_launch.hpp"
@@ -14,6 +15,7 @@
#define CK_TILE_PIPELINE_COMPUTE_V3 1
#define CK_TILE_PIPELINE_MEMORY 2
#define CK_TILE_PIPELINE_COMPUTE_V4 3
#define CK_TILE_PIPELINE_PRESHUFFLE_V2 4
#ifndef CK_TILE_PIPELINE_DEFAULT
#define CK_TILE_PIPELINE_DEFAULT CK_TILE_PIPELINE_COMPUTE_V3
@@ -37,6 +39,22 @@ constexpr ck_tile::index_t get_k_warp_tile()
#endif
}
template <typename PrecType, ck_tile::index_t M_Warp_Tile>
constexpr ck_tile::index_t get_k_warp_tile_flatmm()
{
#if defined(CK_GFX950_SUPPORT)
if constexpr(M_Warp_Tile == 32)
return sizeof(PrecType) == 2 ? 16 : 64;
else
return sizeof(PrecType) == 2 ? 32 : 128;
#else
if constexpr(M_Warp_Tile == 32)
return sizeof(PrecType) == 2 ? 16 : 32;
else
return sizeof(PrecType) == 2 ? 32 : 64;
#endif
}
template <typename DataType>
struct GemmTypeConfig;
@@ -77,6 +95,8 @@ struct GemmConfigBase
static constexpr ck_tile::index_t Pipeline = CK_TILE_PIPELINE_COMPUTE_V3;
static constexpr ck_tile::index_t NumWaveGroups = 1;
static constexpr bool Preshuffle = false;
static constexpr bool Persistent = false;
static constexpr bool DoubleSmemBuffer = false;
};
template <typename PrecType>
@@ -123,6 +143,53 @@ struct GemmConfigComputeV4 : public GemmConfigBase
static constexpr int kBlockPerCu = 2;
};
template <typename PrecType>
struct GemmConfigPreshuffleDecode : public GemmConfigBase
{
static constexpr ck_tile::index_t M_Tile = 16;
static constexpr ck_tile::index_t N_Tile = 64;
static constexpr ck_tile::index_t K_Tile = 256 / sizeof(PrecType);
static constexpr ck_tile::index_t M_Warp = 1;
static constexpr ck_tile::index_t N_Warp = 4;
static constexpr ck_tile::index_t K_Warp = 1;
static constexpr ck_tile::index_t M_Warp_Tile = 16;
static constexpr ck_tile::index_t N_Warp_Tile = 16;
static constexpr ck_tile::index_t K_Warp_Tile = get_k_warp_tile_flatmm<PrecType, M_Warp_Tile>();
static constexpr bool kPadK = true;
static constexpr int kBlockPerCu = 1;
static constexpr auto Scheduler = ck_tile::GemmPipelineScheduler::Default;
static constexpr ck_tile::index_t Pipeline = CK_TILE_PIPELINE_PRESHUFFLE_V2;
static constexpr bool Preshuffle = true;
static constexpr bool DoubleSmemBuffer = true;
};
template <typename PrecType>
struct GemmConfigPreshufflePrefill : public GemmConfigBase
{
static constexpr ck_tile::index_t M_Tile = 128;
static constexpr ck_tile::index_t N_Tile = 128;
static constexpr ck_tile::index_t K_Tile = 128 / sizeof(PrecType);
static constexpr ck_tile::index_t M_Warp = 1;
static constexpr ck_tile::index_t N_Warp = 4;
static constexpr ck_tile::index_t K_Warp = 1;
static constexpr ck_tile::index_t M_Warp_Tile = 16;
static constexpr ck_tile::index_t N_Warp_Tile = 16;
static constexpr ck_tile::index_t K_Warp_Tile = get_k_warp_tile_flatmm<PrecType, M_Warp_Tile>();
static constexpr int kBlockPerCu = 2;
static constexpr auto Scheduler = ck_tile::GemmPipelineScheduler::Default;
static constexpr ck_tile::index_t Pipeline = CK_TILE_PIPELINE_PRESHUFFLE_V2;
static constexpr bool Preshuffle = true;
static constexpr bool DoubleSmemBuffer = true;
static constexpr bool kPadK = true;
};
template <ck_tile::index_t PipelineId>
struct PipelineTypeTraits;
@@ -153,9 +220,19 @@ struct PipelineTypeTraits<CK_TILE_PIPELINE_COMPUTE_V4>
using UniversalGemmPipeline = ck_tile::BaseGemmPipelineAgBgCrCompV4<PipelineProblem>;
};
template <>
struct PipelineTypeTraits<CK_TILE_PIPELINE_PRESHUFFLE_V2>
{
template <typename PipelineProblem>
using GemmPipeline = ck_tile::WeightPreshufflePipelineAGmemBGmemCRegV2<PipelineProblem>;
template <typename PipelineProblem>
using UniversalGemmPipeline =
ck_tile::BaseWeightPreshufflePipelineAGmemBGmemCRegV2<PipelineProblem>;
};
using grouped_gemm_kargs = ck_tile::GroupedGemmHostArgs;
auto create_args(int argc, char* argv[])
std::pair<bool, ck_tile::ArgParser> create_args(int argc, char* argv[])
{
ck_tile::ArgParser arg_parser;
arg_parser.insert("Ms", "", "M dimensions - empty by default.")
@@ -177,7 +254,7 @@ auto create_args(int argc, char* argv[])
.insert("jsonfile", "grouped_gemm.json", "json file name to dump results");
bool result = arg_parser.parse(argc, argv);
return std::make_tuple(result, arg_parser);
return std::make_pair(result, arg_parser);
}
inline std::size_t get_workspace_size(const std::vector<grouped_gemm_kargs>& gemm_descs)
@@ -185,7 +262,24 @@ inline std::size_t get_workspace_size(const std::vector<grouped_gemm_kargs>& gem
return gemm_descs.size() * sizeof(ck_tile::GemmTransKernelArg);
}
template <typename ADataType,
template <typename GemmConfig, typename T>
auto shuffle_b(const ck_tile::HostTensor<T>& t)
{
assert(t.get_lengths().size() == 2);
int n_ = t.get_lengths()[1];
int k_ = t.get_lengths()[0];
constexpr int divisor = GemmConfig::N_Warp_Tile == 32 ? 2 : 4;
ck_tile::HostTensor<T> t_view({n_ / GemmConfig::N_Warp_Tile,
GemmConfig::N_Warp_Tile,
k_ / GemmConfig::K_Warp_Tile,
divisor,
GemmConfig::K_Warp_Tile / divisor});
std::copy(t.begin(), t.end(), t_view.begin());
return ck_tile::reference_permute(t_view, {0, 2, 3, 1, 4});
}
template <typename GemmConfig,
typename ADataType,
typename BDataType,
typename DsDataType,
typename AccDataType,
@@ -194,7 +288,6 @@ template <typename ADataType,
typename BLayout,
typename DsLayout,
typename CLayout,
bool Persistent,
typename CDEElementWise>
float grouped_gemm(const std::vector<grouped_gemm_kargs>& gemm_descs,
const ck_tile::stream_config& s,

View File

@@ -0,0 +1,234 @@
// SPDX-License-Identifier: MIT
// Copyright (c) 2025, Advanced Micro Devices, Inc. All rights reserved.
#include <hip/hip_runtime.h>
#include <cstring>
#include <iostream>
#include <ostream>
#include <string>
#include <tuple>
#include <memory>
#include "ck_tile/core.hpp"
#include "ck_tile/ops/epilogue.hpp"
#include "ck_tile/ops/gemm.hpp"
#include "ck_tile/host.hpp"
#include "grouped_gemm.hpp"
template <typename GemmConfig,
typename ADataType,
typename BDataType,
typename DsDataType,
typename AccDataType,
typename CDataType,
typename ALayout,
typename BLayout,
typename DsLayout,
typename CLayout,
typename CDEElementWise = ck_tile::element_wise::PassThrough>
float grouped_gemm(const std::vector<grouped_gemm_kargs>& gemm_descs,
const ck_tile::stream_config& s,
void* kargs_ptr)
{
using GemmShape = ck_tile::TileGemmShape<
ck_tile::sequence<GemmConfig::M_Tile, GemmConfig::N_Tile, GemmConfig::K_Tile>,
ck_tile::sequence<GemmConfig::M_Warp, GemmConfig::N_Warp, GemmConfig::K_Warp>,
ck_tile::
sequence<GemmConfig::M_Warp_Tile, GemmConfig::N_Warp_Tile, GemmConfig::K_Warp_Tile>>;
using TilePartitioner =
ck_tile::GemmSpatiallyLocalTilePartitioner<GemmShape,
GemmConfig::TileParitionerGroupNum,
GemmConfig::TileParitionerM01>;
using Traits = ck_tile::TileGemmTraits<GemmConfig::kPadM,
GemmConfig::kPadN,
GemmConfig::kPadK,
ALayout,
BLayout,
CLayout,
GemmConfig::NumWaveGroups>;
using GemmUniversalTraits = ck_tile::TileGemmUniversalTraits<GemmConfig::kPadM,
GemmConfig::kPadN,
GemmConfig::kPadK,
GemmConfig::DoubleSmemBuffer,
ALayout,
BLayout,
CLayout,
GemmConfig::TransposeC,
GemmConfig::UseStructuredSparsity,
GemmConfig::Persistent,
GemmConfig::NumWaveGroups,
GemmConfig::Preshuffle>;
using GemmPipelineProblem =
ck_tile::GemmPipelineProblem<ADataType, BDataType, AccDataType, GemmShape, Traits>;
using BaseGemmPipeline = typename PipelineTypeTraits<
GemmConfig::Pipeline>::template UniversalGemmPipeline<GemmPipelineProblem>;
const ck_tile::index_t k_grain = gemm_descs[0].k_batch * GemmConfig::K_Tile;
const ck_tile::index_t K_split = (gemm_descs[0].K + k_grain - 1) / k_grain * GemmConfig::K_Tile;
const ck_tile::index_t num_loop =
// if preshuffle == true then num_loop is recalculated for each group in the kernel code
TilePartitioner::GetLoopNum(K_split);
const bool has_hot_loop = BaseGemmPipeline::BlockHasHotloop(num_loop);
const ck_tile::TailNumber tail_num = BaseGemmPipeline::GetBlockLoopTailNum(num_loop);
float ave_time{0};
const auto Run = [&](const auto has_hot_loop_,
const auto tail_number_,
const auto memory_operation_) {
constexpr bool has_hot_loop_v = has_hot_loop_.value;
constexpr auto tail_number_v = tail_number_.value;
constexpr auto scheduler = GemmConfig::Scheduler;
constexpr auto memory_operation = memory_operation_.value;
using UniversalGemmProblem = ck_tile::UniversalGemmPipelineProblem<ADataType,
BDataType,
AccDataType,
GemmShape,
GemmUniversalTraits,
scheduler,
has_hot_loop_v,
tail_number_v>;
using GemmPipeline = typename PipelineTypeTraits<
GemmConfig::Pipeline>::template GemmPipeline<UniversalGemmProblem>;
using GemmEpilogue = ck_tile::CShuffleEpilogue<
ck_tile::CShuffleEpilogueProblem<ADataType,
BDataType,
DsDataType,
AccDataType,
CDataType,
DsLayout,
CLayout,
CDEElementWise,
TilePartitioner::MPerBlock,
TilePartitioner::NPerBlock,
GemmConfig::M_Warp,
GemmConfig::N_Warp,
GemmConfig::M_Warp_Tile,
GemmConfig::N_Warp_Tile,
GemmConfig::K_Warp_Tile,
UniversalGemmProblem::TransposeC,
memory_operation>>;
using Kernel = ck_tile::GroupedGemmKernel<TilePartitioner, GemmPipeline, GemmEpilogue>;
auto kargs = Kernel::MakeKargs(gemm_descs);
if(!Kernel::IsSupportedArgument(kargs))
{
throw std::runtime_error("Kernel arguments not supported!");
}
const dim3 blocks = Kernel::BlockSize();
const dim3 grids = Kernel::GridSize(gemm_descs);
HIP_CHECK_ERROR(hipMemcpyWithStream(kargs_ptr,
kargs.data(),
get_workspace_size(gemm_descs),
hipMemcpyHostToDevice,
s.stream_id_));
if(s.log_level_ > 0)
{
std::cout << "Launching kernel: " << Kernel::GetName() << " with args:" << " grid: {"
<< grids.x << ", " << grids.y << ", " << grids.z << "}" << ", blocks: {"
<< blocks.x << ", " << blocks.y << ", " << blocks.z << "}" << std::endl;
}
ave_time =
ck_tile::launch_kernel(s,
ck_tile::make_kernel<GemmConfig::kBlockPerCu>(
Kernel{},
grids,
blocks,
0,
ck_tile::cast_pointer_to_constant_address_space(kargs_ptr),
gemm_descs.size()));
return ave_time;
};
const auto RunSplitk = [&](const auto has_hot_loop_, const auto tail_number_) {
if(gemm_descs[0].k_batch == 1)
{
Run(has_hot_loop_,
tail_number_,
ck_tile::integral_constant<ck_tile::memory_operation_enum,
ck_tile::memory_operation_enum::set>{});
}
else
{
Run(has_hot_loop_,
tail_number_,
ck_tile::integral_constant<ck_tile::memory_operation_enum,
ck_tile::memory_operation_enum::atomic_add>{});
}
};
BaseGemmPipeline::TailHandler(RunSplitk, has_hot_loop, tail_num);
return ave_time;
}
#include "run_grouped_gemm_example.inc"
template <typename GemmConfig, typename PrecType>
int run_gemm_example_prec_type(std::string a_layout, std::string b_layout, int argc, char* argv[])
{
using Row = ck_tile::tensor_layout::gemm::RowMajor;
using Col = ck_tile::tensor_layout::gemm::ColumnMajor;
using Types = GemmTypeConfig<PrecType>;
// Specific type aliases for easy access
using ADataType = typename Types::ADataType;
using BDataType = typename Types::BDataType;
using AccDataType = typename Types::AccDataType;
using CDataType = typename Types::CDataType;
// Preshuffle is supported only for A(Row major), B(column major) input matrices!
if(a_layout == "R" && b_layout == "C")
{
return run_grouped_gemm_example_with_layouts<GemmConfig,
ADataType,
BDataType,
CDataType,
AccDataType>(argc, argv, Row{}, Col{}, Row{});
}
else
{
throw std::runtime_error(
"Preshuffle is supported only for A(Row major), B(column major) input matrices!");
}
}
template <template <typename PrecType> typename GemmConfig>
int run_grouped_gemm_example(int argc, char* argv[])
{
auto [result, arg_parser] = create_args(argc, argv);
if(!result)
{
return -1;
}
const std::string a_layout = arg_parser.get_str("a_layout");
const std::string b_layout = arg_parser.get_str("b_layout");
const std::string data_type = arg_parser.get_str("prec");
if(data_type == "fp16")
{
return run_gemm_example_prec_type<GemmConfig<ck_tile::half_t>, ck_tile::half_t>(
a_layout, b_layout, argc, argv);
}
else if(data_type == "fp8")
{
return run_gemm_example_prec_type<GemmConfig<ck_tile::fp8_t>, ck_tile::fp8_t>(
a_layout, b_layout, argc, argv);
}
else
{
throw std::runtime_error("Unsupported data type configuration.");
}
}
int main(int argc, char* argv[])
{
return !run_grouped_gemm_example<GemmConfigPreshuffleDecode>(argc, argv);
}

View File

@@ -0,0 +1,136 @@
// SPDX-License-Identifier: MIT
// Copyright (c) 2025, Advanced Micro Devices, Inc. All rights reserved.
#include <hip/hip_runtime.h>
#include <cstring>
#include <iostream>
#include <ostream>
#include <string>
#include <tuple>
#include <memory>
#include "ck_tile/core.hpp"
#include "ck_tile/ops/epilogue.hpp"
#include "ck_tile/ops/gemm.hpp"
#include "ck_tile/ops/gemm_group_quant.hpp"
#include "ck_tile/host.hpp"
#include "quant_grouped_gemm.hpp"
template <typename GemmConfig,
typename ALayout,
typename AQLayout,
typename BLayout,
typename BQLayout,
typename CLayout,
typename ADataType,
typename AQDataType,
typename BDataType,
typename BQDataType,
typename AccDataType,
typename CDataType>
float grouped_gemm_tileloop(const ck_tile::stream_config& s,
const ck_tile::index_t num_groups,
void* kargs_ptr)
{
constexpr ck_tile::index_t TileParitionerGroupNum = 8;
constexpr ck_tile::index_t TileParitionerM01 = 4;
using GemmShape = ck_tile::TileGemmShape<
ck_tile::sequence<GemmConfig::M_Tile, GemmConfig::N_Tile, GemmConfig::K_Tile>,
ck_tile::sequence<GemmConfig::M_Warp, GemmConfig::N_Warp, GemmConfig::K_Warp>,
ck_tile::
sequence<GemmConfig::M_Warp_Tile, GemmConfig::N_Warp_Tile, GemmConfig::K_Warp_Tile>>;
using TilePartitioner = ck_tile::
GemmSpatiallyLocalTilePartitioner<GemmShape, TileParitionerGroupNum, TileParitionerM01>;
constexpr ck_tile::QuantType QuantMode = ck_tile::QuantType::RowColQuant;
using GemmUniversalTraits = ck_tile::TileGemmQuantTraits<GemmConfig::kPadM,
GemmConfig::kPadN,
GemmConfig::kPadK,
false,
ALayout,
BLayout,
CLayout,
QuantMode,
AQLayout,
BQLayout,
GemmConfig::DoubleSmemBuffer,
true>;
float ave_time{0};
const auto Run = [&](const auto memory_operation_) {
constexpr auto scheduler = GemmConfig::Scheduler;
constexpr auto memory_operation = memory_operation_.value;
constexpr bool transpose_c = false;
using QuantGemmProblem = ck_tile::GemmRowColQuantPipelineProblem<ADataType,
BDataType,
AccDataType,
AccDataType,
GemmShape,
GemmUniversalTraits,
transpose_c,
BDataType,
scheduler>;
using GemmPipeline = typename PipelineTypeTraits<
GemmConfig::Pipeline>::template GemmPipeline<QuantGemmProblem>;
using GemmEpilogue = ck_tile::CShuffleEpilogue<
ck_tile::CShuffleEpilogueProblem<ADataType,
BDataType,
ck_tile::tuple<>,
AccDataType,
CDataType,
ck_tile::tuple<>,
CLayout,
ck_tile::element_wise::PassThrough,
TilePartitioner::MPerBlock,
TilePartitioner::NPerBlock,
GemmConfig::M_Warp,
GemmConfig::N_Warp,
GemmConfig::M_Warp_Tile,
GemmConfig::N_Warp_Tile,
GemmConfig::K_Warp_Tile,
QuantGemmProblem::TransposeC,
memory_operation>>;
using Kernel = ck_tile::QuantGroupedGemmKernel<TilePartitioner,
GemmPipeline,
GemmEpilogue,
GemmUniversalTraits::kQuantType>;
const dim3 blocks = Kernel::BlockSize();
const dim3 grids = Kernel::MaxOccupancyGridSize(s);
if(s.log_level_ > 0)
{
std::cout << "Launching kernel: " << Kernel::GetName() << " with args:" << " grid: {"
<< grids.x << ", " << grids.y << ", " << grids.z << "}" << ", blocks: {"
<< blocks.x << ", " << blocks.y << ", " << blocks.z << "}" << std::endl;
}
ave_time =
ck_tile::launch_kernel(s,
ck_tile::make_kernel<GemmConfig::kBlockPerCu>(
Kernel{},
grids,
blocks,
0,
ck_tile::cast_pointer_to_constant_address_space(kargs_ptr),
num_groups));
return ave_time;
};
Run(ck_tile::integral_constant<ck_tile::memory_operation_enum,
ck_tile::memory_operation_enum::set>{});
return ave_time;
}
#include "quant_run_grouped_gemm_example.inc"
int main(int argc, char* argv[])
{
return !run_grouped_gemm_example<GemmConfigComputeV3_2>(argc, argv);
}

View File

@@ -0,0 +1,157 @@
// SPDX-License-Identifier: MIT
// Copyright (c) 2024-2025, Advanced Micro Devices, Inc. All rights reserved.
#pragma once
#include <string>
#include "ck_tile/core.hpp"
#include "ck_tile/host/kernel_launch.hpp"
#include "ck_tile/ops/gemm.hpp"
#include "ck_tile/ops/elementwise/unary_element_wise_operation.hpp"
#define CK_TILE_PIPELINE_COMPUTE_V3 1
#define CK_TILE_PIPELINE_MEMORY 2
#define CK_TILE_PIPELINE_COMPUTE_V4 3
#ifndef CK_TILE_PIPELINE_DEFAULT
#define CK_TILE_PIPELINE_DEFAULT CK_TILE_PIPELINE_COMPUTE_V3
#endif
template <typename PrecType, ck_tile::index_t M_Warp_Tile>
constexpr ck_tile::index_t get_k_warp_tile()
{
#if defined(CK_GFX950_SUPPORT)
constexpr bool is_8bit_float =
std::is_same_v<PrecType, ck_tile::fp8_t> || std::is_same_v<PrecType, ck_tile::bf8_t>;
if constexpr(M_Warp_Tile == 32)
return is_8bit_float ? 64 : 16;
else
return is_8bit_float ? 128 : 32;
#else
if constexpr(M_Warp_Tile == 32)
return 16;
else
return 32;
#endif
}
template <typename DataType>
struct GemmTypeConfig;
template <>
struct GemmTypeConfig<ck_tile::fp8_t>
{
using ADataType = ck_tile::fp8_t;
using BDataType = ck_tile::fp8_t;
using AccDataType = float;
using CDataType = ck_tile::half_t;
};
struct GemmConfigBase
{
static constexpr bool kPadM = false;
static constexpr bool kPadN = false;
static constexpr bool kPadK = false;
static constexpr bool PermuteA = false;
static constexpr bool PermuteB = false;
static constexpr bool TransposeC = false;
static constexpr bool UseStructuredSparsity = false;
static constexpr int kBlockPerCu = 1;
static constexpr ck_tile::index_t TileParitionerGroupNum = 8;
static constexpr ck_tile::index_t TileParitionerM01 = 4;
static constexpr auto Scheduler = ck_tile::GemmPipelineScheduler::Intrawave;
static constexpr ck_tile::index_t Pipeline = CK_TILE_PIPELINE_COMPUTE_V3;
static constexpr ck_tile::index_t NumWaveGroups = 1;
static constexpr bool Preshuffle = false;
};
template <typename PrecType>
struct GemmConfigComputeV3_2 : public GemmConfigBase
{
static constexpr ck_tile::index_t M_Tile = 128;
static constexpr ck_tile::index_t N_Tile = 128;
static constexpr ck_tile::index_t K_Tile = 128 / sizeof(PrecType);
static constexpr ck_tile::index_t M_Warp = 2;
static constexpr ck_tile::index_t N_Warp = 2;
static constexpr ck_tile::index_t K_Warp = 1;
static constexpr ck_tile::index_t M_Warp_Tile = 32;
static constexpr ck_tile::index_t N_Warp_Tile = 32;
static constexpr ck_tile::index_t K_Warp_Tile = get_k_warp_tile<PrecType, M_Warp_Tile>();
static constexpr bool DoubleSmemBuffer = false;
static constexpr ck_tile::index_t Pipeline = CK_TILE_PIPELINE_COMPUTE_V3;
static constexpr int kBlockPerCu = 1;
};
template <ck_tile::index_t PipelineId>
struct PipelineTypeTraits;
template <>
struct PipelineTypeTraits<CK_TILE_PIPELINE_COMPUTE_V3>
{
template <typename PipelineProblem>
using GemmPipeline = ck_tile::GemmPipelineAgBgCrCompV3<PipelineProblem>;
template <typename PipelineProblem>
using UniversalGemmPipeline = ck_tile::BaseGemmPipelineAgBgCrCompV3<PipelineProblem>;
};
template <>
struct PipelineTypeTraits<CK_TILE_PIPELINE_COMPUTE_V4>
{
template <typename PipelineProblem>
using GemmPipeline = ck_tile::GemmPipelineAgBgCrCompV4<PipelineProblem>;
template <typename PipelineProblem>
using UniversalGemmPipeline = ck_tile::BaseGemmPipelineAgBgCrCompV4<PipelineProblem>;
};
using grouped_gemm_kargs = ck_tile::QuantGroupedGemmHostArgs;
auto create_args(int argc, char* argv[])
{
ck_tile::ArgParser arg_parser;
arg_parser.insert("Ms", "", "M dimensions - empty by default.")
.insert("Ns", "", "N dimensions - empty by default.")
.insert("Ks", "", "K dimensions - empty by default.")
.insert("stride_As", "", "Tensor A strides - it is empty by default.")
.insert("stride_Bs", "", "Tensor B strides - it is empty by default.")
.insert("stride_Cs", "", "Tensor C strides - it is empty by default.")
.insert("stride_AQs", "", "Tensor AQ strides - it is empty by default.")
.insert("stride_BQs", "", "Tensor BQ strides - it is empty by default.")
.insert("a_layout", "R", "A tensor data layout - Row by default.")
.insert("b_layout", "C", "B tensor data layout - Row by default.")
.insert("c_layout", "R", "C tensor data layout - Row by default.")
.insert("validate", "1", "0. No validation, 1. Validation on CPU.")
.insert("prec", "fp8", "data type. fp16/bf16/fp8/bf8")
.insert("warmup", "10", "number of iterations before benchmark the kernel.")
.insert("repeat", "100", "number of iterations to benchmark the kernel.")
.insert("group_count", "8", "group count.")
.insert("kbatch", "1", "kbatch for SplitK");
bool result = arg_parser.parse(argc, argv);
return std::make_tuple(result, arg_parser);
}
inline std::size_t get_workspace_size(const std::vector<grouped_gemm_kargs>& gemm_descs)
{
return gemm_descs.size() * sizeof(ck_tile::QuantGemmTransKernelArg);
}
template <typename GemmConfig,
typename ALayout,
typename BLayout,
typename CLayout,
typename ADataType,
typename BDataType,
typename AccDataType,
typename CDataType>
float grouped_gemm_tileloop(const ck_tile::stream_config& s,
const ck_tile::index_t num_groups,
void* kargs_ptr,
bool splitk = false);

View File

@@ -0,0 +1,443 @@
// SPDX-License-Identifier: MIT
// Copyright (c) 2024-2025, Advanced Micro Devices, Inc. All rights reserved.
#pragma once
template <typename Layout>
static constexpr inline auto is_row_major(Layout layout_)
{
return ck_tile::bool_constant<std::is_same_v<ck_tile::remove_cvref_t<decltype(layout_)>,
ck_tile::tensor_layout::gemm::RowMajor>>{};
}
template <typename ADataType, typename BDataType, typename AccDataType, typename CDataType>
auto calculate_rtol_atol(const ck_tile::index_t K,
const ck_tile::index_t kbatch,
const float max_accumulated_value)
{
using ComputeType =
std::conditional_t<sizeof(ADataType) < sizeof(BDataType), ADataType, BDataType>;
// Calculate thresholds
const auto rtol = ck_tile::get_relative_threshold<ComputeType, CDataType, AccDataType>(
ck_tile::integer_divide_ceil(K, kbatch));
const auto atol = ck_tile::get_absolute_threshold<ComputeType, CDataType, AccDataType>(
max_accumulated_value / kbatch, ck_tile::integer_divide_ceil(K, kbatch));
// Calculate error due to split_k accumulation
const auto rtol_split_k =
ck_tile::get_relative_threshold<CDataType, CDataType, CDataType>(kbatch);
const auto atol_split_k = ck_tile::get_absolute_threshold<CDataType, CDataType, CDataType>(
max_accumulated_value, kbatch);
// Use higher threshold
return ck_tile::make_tuple(std::max(rtol, rtol_split_k), std::max(atol, atol_split_k));
}
template <typename GemmConfig,
typename ADataType,
typename AQDataType,
typename BDataType,
typename BQDataType,
typename AccDataType,
typename CDataType,
typename ALayout,
typename AQLayout,
typename BLayout,
typename BQLayout,
typename CLayout,
typename CDEElementWise = ck_tile::element_wise::PassThrough>
float invoke_gemm(int n_warmup,
int n_repeat,
int group_count,
const std::vector<grouped_gemm_kargs>& args)
{
// Workspace memory allocated to hold the gemm descriptions.
ck_tile::DeviceMem gemm_workspace;
gemm_workspace.Realloc(get_workspace_size(args));
float ave_time = 0;
// NOTE: With the persistent TileLoop kernel, we do not necessarily need to have
// the gemm problems known on the host. Instead, we can just pass the pointer
// to the kernel and let the workgroups figure out which tiles to work on.
// This is useful when the gemm problems are generated dynamically.
// In this example however, we generate the `kargs` using the known gemm_descs,
// and copy the gemm descriptions to the device memory.
// The contents of the memory pointed to by `kargs_ptr` pointer could be
// written by e.g. another kernel from earlier stage.
std::vector<ck_tile::QuantGemmTransKernelArg> kargs;
void* kargs_ptr = gemm_workspace.GetDeviceBuffer();
assert(args[0].k_batch == 1);
for(const auto& arg : args)
{
kargs.emplace_back(ck_tile::QuantGroupedGemmKernelArgs{arg.a_ptr,
arg.b_ptr,
arg.aq_ptr,
arg.bq_ptr,
arg.e_ptr,
arg.M,
arg.N,
arg.K,
arg.QK_A,
arg.QK_B,
arg.stride_A,
arg.stride_B,
arg.stride_E,
arg.stride_AQ,
arg.stride_BQ,
arg.k_batch});
}
const auto stream = ck_tile::stream_config{nullptr, true, 1, n_warmup, n_repeat};
HIP_CHECK_ERROR(hipMemcpyWithStream(kargs_ptr,
kargs.data(),
kargs.size() * sizeof(ck_tile::QuantGemmTransKernelArg),
hipMemcpyHostToDevice,
stream.stream_id_));
ave_time = grouped_gemm_tileloop<GemmConfig,
ALayout,
AQLayout,
BLayout,
BQLayout,
CLayout,
ADataType,
AQDataType,
BDataType,
BQDataType,
AccDataType,
CDataType>(stream, group_count, kargs_ptr);
std::string op_name{"Grouped Gemm"};
std::size_t flop = 0, num_btype = 0;
for(int j = 0; j < group_count; ++j)
{
flop += std::size_t(2) * args[j].M * args[j].N * args[j].K;
num_btype += sizeof(ADataType) * args[j].M * args[j].K +
sizeof(BDataType) * args[j].K * args[j].N +
sizeof(CDataType) * args[j].M * args[j].N;
}
float tflops = static_cast<float>(flop) / 1.E9 / ave_time;
float gb_per_sec = num_btype / 1.E6 / ave_time;
std::cout << "Perf: " << std::setw(10) << ave_time << " ms, " << tflops << " TFlops, "
<< gb_per_sec << " GB/s, " << op_name << std::endl;
return ave_time;
}
template <typename GemmConfig,
typename ADataType,
typename AQDataType,
typename BDataType,
typename BQDataType,
typename CDataType,
typename AccDataType,
typename ALayout,
typename AQLayout,
typename BLayout,
typename BQLayout,
typename CLayout>
int run_grouped_gemm_example_with_layouts(int argc,
char* argv[],
const ALayout a_layout = ALayout{},
const AQLayout aq_layout = AQLayout{},
const BLayout b_layout = BLayout{},
const BQLayout bq_layout = BQLayout{},
[[maybe_unused]] const CLayout c_layout = CLayout{})
{
auto [result, arg_parser] = create_args(argc, argv);
if(!result)
{
return -1;
};
auto valid_input_data = [&](int group_count, const auto&... args) {
return !(args.empty() || ...) && group_count == (args.size() == ...);
};
const int group_count = arg_parser.get_int("group_count");
const int repeat = arg_parser.get_int("repeat");
const int warmup = arg_parser.get_int("warmup");
const int kbatch = arg_parser.get_int("kbatch");
bool validate = arg_parser.get_bool("validate");
if(kbatch > 1 && validate && warmup + repeat > 1)
{
std::cout << "WARNING: Data validation enabled with SplitK and more than"
<< "1 warmup/repeat. Disabling validation." << std::endl;
validate = false;
}
std::vector<ck_tile::index_t> Ms = arg_parser.get_int_vec("Ms");
std::vector<ck_tile::index_t> Ns = arg_parser.get_int_vec("Ns");
std::vector<ck_tile::index_t> Ks = arg_parser.get_int_vec("Ks");
std::vector<ck_tile::index_t> stride_As = arg_parser.get_int_vec("stride_As");
std::vector<ck_tile::index_t> stride_Bs = arg_parser.get_int_vec("stride_Bs");
std::vector<ck_tile::index_t> stride_Cs = arg_parser.get_int_vec("stride_Cs");
std::vector<ck_tile::index_t> stride_AQs = arg_parser.get_int_vec("stride_AQs");
std::vector<ck_tile::index_t> stride_BQs = arg_parser.get_int_vec("stride_BQs");
ck_tile::index_t AQK, BQK;
if(!valid_input_data(group_count, Ms, Ns, Ks, stride_As, stride_Bs, stride_Cs))
{
std::cout << "Please check the input data. Default values will be used." << std::endl;
for(int i = 0; i < group_count; i++)
{
Ms.push_back(256 + 256 * i);
Ns.push_back(256 + 512 * i);
Ks.push_back(512 + 128 * i);
stride_As.push_back(0);
stride_Bs.push_back(0);
stride_Cs.push_back(0);
stride_AQs.push_back(0);
stride_BQs.push_back(0);
}
}
std::vector<ck_tile::HostTensor<ADataType>> a_m_k_tensors;
std::vector<ck_tile::HostTensor<BDataType>> b_k_n_tensors;
std::vector<ck_tile::HostTensor<CDataType>> c_m_n_tensors;
std::vector<ck_tile::HostTensor<AQDataType>> aq_tensors;
std::vector<ck_tile::HostTensor<BQDataType>> bq_tensors;
a_m_k_tensors.reserve(group_count);
b_k_n_tensors.reserve(group_count);
c_m_n_tensors.reserve(group_count);
aq_tensors.reserve(group_count);
bq_tensors.reserve(group_count);
std::vector<std::unique_ptr<ck_tile::DeviceMem>> a_m_k_dev_buf;
std::vector<std::unique_ptr<ck_tile::DeviceMem>> b_k_n_dev_buf;
std::vector<std::unique_ptr<ck_tile::DeviceMem>> c_m_n_dev_buf;
std::vector<std::unique_ptr<ck_tile::DeviceMem>> aq_dev_buf;
std::vector<std::unique_ptr<ck_tile::DeviceMem>> bq_dev_buf;
a_m_k_dev_buf.reserve(group_count);
b_k_n_dev_buf.reserve(group_count);
c_m_n_dev_buf.reserve(group_count);
aq_dev_buf.reserve(group_count);
bq_dev_buf.reserve(group_count);
std::vector<grouped_gemm_kargs> gemm_descs;
gemm_descs.reserve(group_count);
for(int i = 0; i < group_count; ++i)
{
const ck_tile::index_t M = Ms[i];
const ck_tile::index_t N = Ns[i];
const ck_tile::index_t K = Ks[i];
AQK = 1; // Row quantization: tensor shape [M, 1]. Only for NT
BQK = N; // Column quantization: tensor shape [1, N]. Only for NT
stride_As[i] = ck_tile::get_default_stride(M, K, stride_As[i], is_row_major(a_layout));
stride_Bs[i] = ck_tile::get_default_stride(K, N, stride_Bs[i], is_row_major(b_layout));
stride_Cs[i] = ck_tile::get_default_stride(M, N, stride_Cs[i], is_row_major(CLayout{}));
stride_AQs[i] = ck_tile::get_default_stride(M, AQK, stride_AQs[i], is_row_major(aq_layout));
stride_BQs[i] = ck_tile::get_default_stride(1, N, stride_BQs[i], is_row_major(bq_layout));
a_m_k_tensors.push_back(ck_tile::HostTensor<ADataType>(
ck_tile::host_tensor_descriptor(M, K, stride_As[i], is_row_major(a_layout))));
b_k_n_tensors.push_back(ck_tile::HostTensor<BDataType>(
ck_tile::host_tensor_descriptor(K, N, stride_Bs[i], is_row_major(b_layout))));
c_m_n_tensors.push_back(ck_tile::HostTensor<CDataType>(
ck_tile::host_tensor_descriptor(M, N, stride_Cs[i], is_row_major(CLayout{}))));
aq_tensors.push_back(ck_tile::HostTensor<AQDataType>(
ck_tile::host_tensor_descriptor(M, AQK, stride_AQs[i], is_row_major(aq_layout))));
bq_tensors.push_back(ck_tile::HostTensor<BQDataType>(
ck_tile::host_tensor_descriptor(1, N, stride_BQs[i], is_row_major(bq_layout))));
std::cout << "gemm[" << i << "]" << " a_m_k: " << a_m_k_tensors[i].mDesc
<< " b_k_n: " << b_k_n_tensors[i].mDesc << " c_m_n: " << c_m_n_tensors[i].mDesc
<< " aq: " << aq_tensors[i].mDesc << " bq: " << bq_tensors[i].mDesc << std::endl;
ck_tile::FillUniformDistribution<ADataType>{-1.f, 1.f}(a_m_k_tensors[i]);
ck_tile::FillUniformDistribution<BDataType>{-1.f, 1.f}(b_k_n_tensors[i]);
ck_tile::FillUniformDistribution<AQDataType>{-1.f, 1.f}(aq_tensors[i]);
ck_tile::FillUniformDistribution<BQDataType>{-1.f, 1.f}(bq_tensors[i]);
a_m_k_dev_buf.push_back(std::make_unique<ck_tile::DeviceMem>(
a_m_k_tensors[i].get_element_space_size_in_bytes()));
b_k_n_dev_buf.push_back(std::make_unique<ck_tile::DeviceMem>(
b_k_n_tensors[i].get_element_space_size_in_bytes()));
c_m_n_dev_buf.push_back(std::make_unique<ck_tile::DeviceMem>(
c_m_n_tensors[i].get_element_space_size_in_bytes()));
aq_dev_buf.push_back(
std::make_unique<ck_tile::DeviceMem>(aq_tensors[i].get_element_space_size_in_bytes()));
bq_dev_buf.push_back(
std::make_unique<ck_tile::DeviceMem>(bq_tensors[i].get_element_space_size_in_bytes()));
a_m_k_dev_buf[i]->ToDevice(a_m_k_tensors[i].data());
b_k_n_dev_buf[i]->ToDevice(b_k_n_tensors[i].data());
aq_dev_buf[i]->ToDevice(aq_tensors[i].data());
bq_dev_buf[i]->ToDevice(bq_tensors[i].data());
c_m_n_dev_buf[i]->SetZero();
c_m_n_tensors[i].SetZero();
const void* p_a = a_m_k_dev_buf[i]->GetDeviceBuffer();
const void* p_b = b_k_n_dev_buf[i]->GetDeviceBuffer();
void* p_c = c_m_n_dev_buf[i]->GetDeviceBuffer();
const void* p_aq = aq_dev_buf[i]->GetDeviceBuffer();
const void* p_bq = bq_dev_buf[i]->GetDeviceBuffer();
gemm_descs.push_back({p_a,
p_b,
p_c,
p_aq,
p_bq,
kbatch,
M,
N,
K,
AQK,
BQK,
stride_As[i],
stride_Bs[i],
stride_Cs[i],
stride_AQs[i],
stride_BQs[i]});
}
invoke_gemm<GemmConfig,
ADataType,
AQDataType,
BDataType,
BQDataType,
AccDataType,
CDataType,
ALayout,
AQLayout,
BLayout,
BQLayout,
CLayout>(warmup, repeat, group_count, gemm_descs);
for(int i = 0; i < group_count; i++)
{
c_m_n_dev_buf[i]->FromDevice(c_m_n_tensors[i].data());
}
bool pass{true};
if(validate)
{
for(int i = 0; i < group_count; ++i)
{
ck_tile::HostTensor<CDataType> c_m_n_host_ref(ck_tile::host_tensor_descriptor(
Ms[i], Ns[i], stride_Cs[i], is_row_major(CLayout{})));
c_m_n_host_ref.SetZero();
ck_tile::reference_gemm_rowcol_quant<ADataType,
AQDataType,
BDataType,
BQDataType,
AccDataType,
CDataType>(
a_m_k_tensors[i], aq_tensors[i], b_k_n_tensors[i], bq_tensors[i], c_m_n_host_ref);
const float max_accumulated_value =
*std::max_element(c_m_n_host_ref.mData.begin(), c_m_n_host_ref.mData.end());
const auto rtol_atol =
calculate_rtol_atol<ADataType, BDataType, AccDataType, CDataType>(
Ks[i], kbatch, max_accumulated_value);
pass &= ck_tile::check_err(c_m_n_tensors[i],
c_m_n_host_ref,
"Error: Incorrect results!",
rtol_atol.at(ck_tile::number<0>{}),
rtol_atol.at(ck_tile::number<1>{}));
std::cout << "gemm[" << i
<< "] Relative error threshold: " << rtol_atol.at(ck_tile::number<0>{})
<< " Absolute error threshold: " << rtol_atol.at(ck_tile::number<1>{})
<< std::endl;
}
std::cout << "The CPU verification result is:" << (pass ? "correct" : "fail") << std::endl;
}
return pass;
}
template <typename GemmConfig, typename PrecType>
int run_gemm_example_prec_type(std::string a_layout, std::string b_layout, int argc, char* argv[])
{
using Row = ck_tile::tensor_layout::gemm::RowMajor;
using Col = ck_tile::tensor_layout::gemm::ColumnMajor;
using Types = GemmTypeConfig<PrecType>;
// Specific type aliases for easy access
using ADataType = typename Types::ADataType;
using BDataType = typename Types::BDataType;
using AccDataType = typename Types::AccDataType;
using CDataType = typename Types::CDataType;
using AQDataType = typename Types::AccDataType;
using BQDataType = typename Types::AccDataType;
if(a_layout == "R" && b_layout == "C")
{
return run_grouped_gemm_example_with_layouts<GemmConfig,
ADataType,
AQDataType,
BDataType,
BQDataType,
CDataType,
AccDataType>(
argc, argv, Row{}, Row{}, Col{}, Col{}, Row{});
}
else if(a_layout == "R" && b_layout == "R")
{
return run_grouped_gemm_example_with_layouts<GemmConfig,
ADataType,
AQDataType,
BDataType,
BQDataType,
CDataType,
AccDataType>(
argc, argv, Row{}, Row{}, Row{}, Row{}, Row{});
}
else if(a_layout == "C" && b_layout == "R")
{
return run_grouped_gemm_example_with_layouts<GemmConfig,
ADataType,
AQDataType,
BDataType,
BQDataType,
CDataType,
AccDataType>(
argc, argv, Row{}, Row{}, Col{}, Col{}, Row{});
}
else if(a_layout == "C" && b_layout == "C")
{
return run_grouped_gemm_example_with_layouts<GemmConfig,
ADataType,
AQDataType,
BDataType,
BQDataType,
CDataType,
AccDataType>(
argc, argv, Col{}, Col{}, Col{}, Col{}, Row{});
}
else
{
throw std::runtime_error("Unsupported data layout configuration for A,B and C tensors!");
}
}
template <template <typename PrecType> typename GemmConfig>
int run_grouped_gemm_example(int argc, char* argv[])
{
auto [result, arg_parser] = create_args(argc, argv);
if(!result)
{
return -1;
}
const std::string a_layout = arg_parser.get_str("a_layout");
const std::string b_layout = arg_parser.get_str("b_layout");
const std::string data_type = arg_parser.get_str("prec");
if(data_type == "fp8")
{
return run_gemm_example_prec_type<GemmConfig<ck_tile::fp8_t>, ck_tile::fp8_t>(
a_layout, b_layout, argc, argv);
}
else
{
throw std::runtime_error("Unsupported data type configuration.");
}
}

View File

@@ -40,7 +40,6 @@ template <typename GemmConfig,
typename BLayout,
typename DsLayout,
typename CLayout,
bool Persistent,
typename CDEElementWise = ck_tile::element_wise::PassThrough>
float invoke_gemm(int n_warmup,
int n_repeat,
@@ -52,10 +51,10 @@ float invoke_gemm(int n_warmup,
gemm_workspace.Realloc(get_workspace_size(args));
float ave_time = 0;
if constexpr(!Persistent)
if constexpr(!GemmConfig::Persistent)
{
// Regular version of grouped gemm
ave_time = grouped_gemm<ADataType,
ave_time = grouped_gemm<GemmConfig,
ADataType,
BDataType,
DsDataType,
AccDataType,
@@ -71,14 +70,24 @@ float invoke_gemm(int n_warmup,
}
else
{
// NOTE: With the persistent TileLoop kernel, we do not necessarily need to have
// the gemm problems known on the host. Instead, we can just pass the pointer
// to the kernel and let the workgroups figure out which tiles to work on.
// This is useful when the gemm problems are generated dynamically.
if(GemmConfig::Preshuffle)
{
// not supported yet
throw std::runtime_error(
"Persistent grouped gemm with preshuffle is not supported yet");
}
// NOTE: With the persistent TileLoop kernel, we do not necessarily need to haveCollapse
// commentComment on line L74tenpercent commented on Sep 5, 2025 tenpercenton Sep 5,
// 2025ContributorMore actionsdid you intend to remove the comment?Write a replyResolve
// commentCode has comments. Press enter to view. the gemm problems known on the host.
// Instead, we can just pass the pointer to the kernel and let the workgroups figure out
// which tiles to work on. This is useful when the gemm problems are generated dynamically.
// In this example however, we generate the `kargs` using the known gemm_descs,
// and copy the gemm descriptions to the device memory.
// The contents of the memory pointed to by `kargs_ptr` pointer could be
// written by e.g. another kernel from earlier stage.
std::vector<ck_tile::GemmTransKernelArg> kargs;
void* kargs_ptr = gemm_workspace.GetDeviceBuffer();
const bool splitk = args[0].k_batch > 1;
@@ -116,8 +125,7 @@ float invoke_gemm(int n_warmup,
return ave_time;
}
template <bool Persistent,
typename GemmConfig,
template <typename GemmConfig,
typename ADataType,
typename BDataType,
typename CDataType,
@@ -131,12 +139,8 @@ int run_grouped_gemm_example_with_layouts(int argc,
const BLayout b_layout = BLayout{},
[[maybe_unused]] const CLayout c_layout = CLayout{})
{
auto [result, arg_parser] = create_args(argc, argv);
if(!result)
{
return -1;
};
auto [result, arg_parser] = create_args(argc, argv);
auto valid_input_data = [&](int group_count, const auto&... args) {
return !(args.empty() || ...) && group_count == (args.size() == ...);
@@ -165,11 +169,14 @@ int run_grouped_gemm_example_with_layouts(int argc,
if(!valid_input_data(group_count, Ms, Ns, Ks, stride_As, stride_Bs, stride_Cs))
{
std::cout << "Please check the input data. Default values will be used." << std::endl;
std::cout << "Default values: Ms (256, 512, 768, 1024..), Ns (256, 768, 1280..), Ks (512, "
"896, 1280..)"
<< std::endl;
for(int i = 0; i < group_count; i++)
{
Ms.push_back(256 + 256 * i);
Ns.push_back(256 + 512 * i);
Ks.push_back(512 + 128 * i);
Ks.push_back(512 + 384 * i);
stride_As.push_back(Ks[i]);
stride_Bs.push_back(Ks[i]);
@@ -198,11 +205,12 @@ int run_grouped_gemm_example_with_layouts(int argc,
for(int i = 0; i < group_count; ++i)
{
const ck_tile::index_t M = Ms[i];
const ck_tile::index_t N = Ns[i];
const ck_tile::index_t K = Ks[i];
stride_As[i] = ck_tile::get_default_stride(M, N, stride_As[i], is_row_major(a_layout));
stride_As[i] = ck_tile::get_default_stride(M, K, stride_As[i], is_row_major(a_layout));
stride_Bs[i] = ck_tile::get_default_stride(K, N, stride_Bs[i], is_row_major(b_layout));
stride_Cs[i] = ck_tile::get_default_stride(M, N, stride_Cs[i], is_row_major(CLayout{}));
@@ -220,15 +228,21 @@ int run_grouped_gemm_example_with_layouts(int argc,
ck_tile::FillUniformDistribution<ADataType>{-1.f, 1.f}(a_m_k_tensors[i]);
ck_tile::FillUniformDistribution<BDataType>{-1.f, 1.f}(b_k_n_tensors[i]);
a_m_k_dev_buf.push_back(std::make_unique<ck_tile::DeviceMem>(
a_m_k_tensors[i].get_element_space_size_in_bytes()));
b_k_n_dev_buf.push_back(std::make_unique<ck_tile::DeviceMem>(
b_k_n_tensors[i].get_element_space_size_in_bytes()));
c_m_n_dev_buf.push_back(std::make_unique<ck_tile::DeviceMem>(
c_m_n_tensors[i].get_element_space_size_in_bytes()));
a_m_k_dev_buf.push_back(std::make_unique<ck_tile::DeviceMem>(a_m_k_tensors[i]));
// Perform preshuffle for B tensor
if constexpr(GemmConfig::Preshuffle)
{
ck_tile::HostTensor<BDataType> b_shuffle_host = shuffle_b<GemmConfig>(b_k_n_tensors[i]);
b_k_n_dev_buf.push_back(std::make_unique<ck_tile::DeviceMem>(b_shuffle_host));
}
else
{
b_k_n_dev_buf.push_back(std::make_unique<ck_tile::DeviceMem>(b_k_n_tensors[i]));
}
c_m_n_dev_buf.push_back(std::make_unique<ck_tile::DeviceMem>(c_m_n_tensors[i]));
a_m_k_dev_buf[i]->ToDevice(a_m_k_tensors[i].data());
b_k_n_dev_buf[i]->ToDevice(b_k_n_tensors[i].data());
c_m_n_dev_buf[i]->SetZero();
c_m_n_tensors[i].SetZero();
@@ -240,7 +254,8 @@ int run_grouped_gemm_example_with_layouts(int argc,
{p_a, p_b, p_c, kbatch, M, N, K, stride_As[i], stride_Bs[i], stride_Cs[i]});
}
float ave_time = invoke_gemm<ADataType,
float ave_time = invoke_gemm<GemmConfig,
ADataType,
BDataType,
ck_tile::tuple<>,
AccDataType,
@@ -248,8 +263,7 @@ int run_grouped_gemm_example_with_layouts(int argc,
ALayout,
BLayout,
ck_tile::tuple<>,
CLayout,
Persistent>(warmup, repeat, group_count, gemm_descs);
CLayout>(warmup, repeat, group_count, gemm_descs);
std::string op_name{"Grouped Gemm"};
@@ -289,11 +303,12 @@ int run_grouped_gemm_example_with_layouts(int argc,
const auto rtol_atol =
calculate_rtol_atol<ADataType, BDataType, AccDataType, CDataType>(
Ks[i], kbatch, max_accumulated_value);
pass &= ck_tile::check_err(c_m_n_tensors[i],
c_m_n_host_ref,
"Error: Incorrect results!",
rtol_atol.at(ck_tile::number<0>{}),
rtol_atol.at(ck_tile::number<1>{}));
pass &=
ck_tile::check_err(c_m_n_tensors[i],
c_m_n_host_ref,
"Error: Incorrect results! in group [" + std::to_string(i) + "]",
rtol_atol.at(ck_tile::number<0>{}),
rtol_atol.at(ck_tile::number<1>{}));
std::cout << "gemm[" << i
<< "] Relative error threshold: " << rtol_atol.at(ck_tile::number<0>{})
<< " Absolute error threshold: " << rtol_atol.at(ck_tile::number<1>{})
@@ -315,86 +330,3 @@ int run_grouped_gemm_example_with_layouts(int argc,
return pass;
}
template <bool Persistent, typename GemmConfig, typename PrecType>
int run_gemm_example_prec_type(std::string a_layout, std::string b_layout, int argc, char* argv[])
{
using Row = ck_tile::tensor_layout::gemm::RowMajor;
using Col = ck_tile::tensor_layout::gemm::ColumnMajor;
using Types = GemmTypeConfig<PrecType>;
// Specific type aliases for easy access
using ADataType = typename Types::ADataType;
using BDataType = typename Types::BDataType;
using AccDataType = typename Types::AccDataType;
using CDataType = typename Types::CDataType;
if(a_layout == "R" && b_layout == "C")
{
return run_grouped_gemm_example_with_layouts<Persistent,
GemmConfig,
ADataType,
BDataType,
CDataType,
AccDataType>(argc, argv, Row{}, Col{}, Row{});
}
else if(a_layout == "R" && b_layout == "R")
{
return run_grouped_gemm_example_with_layouts<Persistent,
GemmConfig,
ADataType,
BDataType,
CDataType,
AccDataType>(argc, argv, Row{}, Row{}, Row{});
}
else if(a_layout == "C" && b_layout == "R")
{
return run_grouped_gemm_example_with_layouts<Persistent,
GemmConfig,
ADataType,
BDataType,
CDataType,
AccDataType>(argc, argv, Col{}, Row{}, Row{});
}
else if(a_layout == "C" && b_layout == "C")
{
return run_grouped_gemm_example_with_layouts<Persistent,
GemmConfig,
ADataType,
BDataType,
CDataType,
AccDataType>(argc, argv, Col{}, Col{}, Row{});
}
else
{
throw std::runtime_error("Unsupported data layout configuration for A,B and C tensors!");
}
}
template <bool Persistent, template <typename PrecType> typename GemmConfig>
int run_grouped_gemm_example(int argc, char* argv[])
{
auto [result, arg_parser] = create_args(argc, argv);
if(!result)
{
return -1;
}
const std::string a_layout = arg_parser.get_str("a_layout");
const std::string b_layout = arg_parser.get_str("b_layout");
const std::string data_type = arg_parser.get_str("prec");
if(data_type == "fp16")
{
return run_gemm_example_prec_type<Persistent, GemmConfig<ck_tile::half_t>, ck_tile::half_t>(
a_layout, b_layout, argc, argv);
}
else if(data_type == "fp8")
{
return run_gemm_example_prec_type<Persistent, GemmConfig<ck_tile::fp8_t>, ck_tile::fp8_t>(
a_layout, b_layout, argc, argv);
}
else
{
throw std::runtime_error("Unsupported data type configuration.");
}
}

View File

@@ -1,6 +1,7 @@
// SPDX-License-Identifier: MIT
// Copyright (c) 2025, Advanced Micro Devices, Inc. All rights reserved.
#include <variant>
#include "ck_tile/core/arch/arch.hpp"
auto string_to_datatype(const std::string& datatype)

View File

@@ -158,7 +158,14 @@ bool filter_then_run(const ck_tile::ArgParser& arg_parser)
bool pass = true;
if constexpr(std::is_same_v<XElementwiseOperation, ck_tile::element_wise::UnarySquare> &&
std::is_same_v<XDataType, ck_tile::bf16_t>)
(std::is_same_v<XDataType, ck_tile::bf16_t> ||
std::is_same_v<YDataType, ck_tile::bf16_t>))
{
throw_unsupported();
}
else if constexpr(std::is_same_v<XElementwiseOperation, ck_tile::element_wise::UnaryConvert> &&
(std::is_same_v<XDataType, ck_tile::bf16_t> ||
std::is_same_v<YDataType, ck_tile::bf16_t>))
{
throw_unsupported();
}

View File

@@ -149,50 +149,105 @@ struct DeviceGemmMultipleDSplitKBPreShuffle : public BaseOperator
#endif
};
/// @brief Wrapper for backward compatibility that allows to use instances of
/// DeviceGemmMultipleDSplitK in contexts where DeviceGemmMultipleD is expected.
///
/// @note The main area where it can be used is DeviceOperationInstanceFactory::GetInstances().
/// The only difference between API of DeviceGemmMultipleD and DeviceGemmMultipleDSplitK is
/// that DeviceGemmMultipleDSplitK::MakeArgumentPointer requires an additional parameter
/// KBatch which is explicitly passed as 1 by this wrapper.
template <typename ALayout,
typename BLayout,
typename DsLayout,
typename ELayout,
typename ADataType,
typename AScaleDataType,
typename BDataType,
typename BScaleDataType,
typename DsDataType,
typename EDataType,
index_t ScaleBlockSize,
typename AElementwiseOperation,
typename BElementwiseOperation,
typename CDEElementwiseOperation>
struct DeviceMoEGemmMXBPreShuffle : public BaseOperator
struct DeviceGemmMultipleDSplitKWrapper : public DeviceGemmMultipleD<ALayout,
BLayout,
DsLayout,
ELayout,
ADataType,
BDataType,
DsDataType,
EDataType,
AElementwiseOperation,
BElementwiseOperation,
CDEElementwiseOperation>
{
using DeviceOp = DeviceGemmMultipleDSplitK<ALayout,
BLayout,
DsLayout,
ELayout,
ADataType,
BDataType,
DsDataType,
EDataType,
AElementwiseOperation,
BElementwiseOperation,
CDEElementwiseOperation>;
static constexpr index_t NumDTensor = DsDataType::Size();
#ifndef CK_CODE_GEN_RTC
virtual std::unique_ptr<BaseArgument>
#ifndef __HIPCC_RTC__
explicit DeviceGemmMultipleDSplitKWrapper(std::unique_ptr<DeviceOp> p_op)
: p_op_(std::move(p_op))
{
}
bool IsSupportedArgument(const BaseArgument* p_arg) override
{
return p_op_->IsSupportedArgument(p_arg);
}
std::unique_ptr<BaseArgument>
MakeArgumentPointer(const void* p_a,
const void* p_a_scale,
const void* p_b,
const void* p_b_scale,
std::array<const void*, NumDTensor> p_ds,
void* p_e,
ck::index_t M,
ck::index_t N,
ck::index_t K,
ck::index_t StrideA,
ck::index_t StrideAScale,
ck::index_t StrideB,
ck::index_t StrideBScale,
std::array<ck::index_t, NumDTensor> StrideDs,
ck::index_t StrideE,
ck::index_t KBatch,
AElementwiseOperation a_element_op,
BElementwiseOperation b_element_op,
CDEElementwiseOperation cde_element_op) = 0;
CDEElementwiseOperation cde_element_op) override
{
return p_op_->MakeArgumentPointer(p_a,
p_b,
p_ds,
p_e,
M,
N,
K,
StrideA,
StrideB,
StrideDs,
StrideE,
1, // KBatch
a_element_op,
b_element_op,
cde_element_op);
}
virtual std::unique_ptr<BaseInvoker> MakeInvokerPointer() = 0;
std::unique_ptr<BaseInvoker> MakeInvokerPointer() override
{
return p_op_->MakeInvokerPointer();
}
virtual int GetPreShuffleParameters() = 0;
#endif
std::string GetTypeString() const override { return p_op_->GetTypeString(); }
private:
std::unique_ptr<DeviceOp> p_op_;
#endif // __HIPCC_RTC__
};
} // namespace device

View File

@@ -40,7 +40,7 @@ __launch_bounds__(CK_MAX_THREAD_PER_BLOCK, MinimumOccupancy)
#if(defined(__gfx11__) || defined(__gfx12__))
#if defined(__gfx11__)
// gfx11 does not support *_atomic_pk_add_f16/bf16 instructions
using c_data_type = remove_cvref_t<remove_pointer_t<decltype(karg.p_c_grid)>>;
using c_data_type = remove_cvref_t<remove_pointer_t<decltype(karg.p_e_grid)>>;
if constexpr(!(CGlobalMemoryDataOperation == InMemoryDataOperationEnum::AtomicAdd &&
(std::is_same_v<c_data_type, ck::half_t> ||
std::is_same_v<c_data_type, ck::bhalf_t>)))
@@ -62,14 +62,18 @@ __launch_bounds__(CK_MAX_THREAD_PER_BLOCK, MinimumOccupancy)
__shared__ char p_shared[GridwiseGemm::GetSharedMemoryNumberOfByte()];
auto splitk_batch_offset = typename GridwiseGemm::SplitKBatchOffset(karg);
auto splitk_batch_offset = typename GridwiseGemm::SplitKBatchOffset(karg, blockIdx.z);
GridwiseGemm::template Run<HasMainKBlockLoop, CGlobalMemoryDataOperation, TailNum>(
karg.p_a_grid + splitk_batch_offset.a_k_split_offset + a_batch_offset,
karg.p_b_grid + splitk_batch_offset.b_k_split_offset + b_batch_offset,
karg.p_c_grid + splitk_batch_offset.c_reduce_offset + c_batch_offset,
karg.p_ds_grid,
karg.p_e_grid + splitk_batch_offset.c_reduce_offset + c_batch_offset,
p_shared,
karg);
karg,
karg.a_element_op,
karg.b_element_op,
karg.cde_element_op);
#if defined(__gfx11__)
}
#endif
@@ -272,11 +276,13 @@ struct DeviceBatchedGemm_Wmma_CShuffleV3 : public DeviceBatchedGemm<ALayout,
using GridwiseGemm = GridwiseGemm_wmma_cshuffle_v3<
ALayout,
BLayout,
Tuple<>, // DsLayout
CLayout,
ADataType,
BDataType,
AccDataType,
CShuffleDataType,
Tuple<>, // DsDataType
CDataType,
AElementwiseOperation,
BElementwiseOperation,
@@ -311,7 +317,7 @@ struct DeviceBatchedGemm_Wmma_CShuffleV3 : public DeviceBatchedGemm<ALayout,
CShuffleMRepeatPerShuffle,
CShuffleNRepeatPerShuffle,
CShuffleBlockTransferClusterLengths_MBlock_MPerBlock_NBlock_NPerBlock,
CShuffleBlockTransferScalarPerVector_NPerBlock,
Sequence<CShuffleBlockTransferScalarPerVector_NPerBlock>,
BlkGemmPipeSched,
BlkGemmPipelineVer,
ComputeTypeA,
@@ -336,17 +342,25 @@ struct DeviceBatchedGemm_Wmma_CShuffleV3 : public DeviceBatchedGemm<ALayout,
index_t BatchStrideC_,
index_t Batch_,
index_t k_batch_,
AElementwiseOperation a_element_op_,
BElementwiseOperation b_element_op_,
CElementwiseOperation cde_element_op_,
bool is_reduce_ = false)
: GridwiseGemm::Argument(p_a_grid_,
p_b_grid_,
std::array<const void*, 0>{}, // p_ds_grid_
p_c_grid_,
M_,
N_,
K_,
StrideA_,
StrideB_,
std::array<index_t, 0>{}, // StrideDs_
StrideC_,
k_batch_,
a_element_op_,
b_element_op_,
cde_element_op_,
is_reduce_),
Batch(Batch_),
compute_ptr_offset_of_batch{BatchStrideA_, BatchStrideB_, BatchStrideC_}
@@ -443,7 +457,7 @@ struct DeviceBatchedGemm_Wmma_CShuffleV3 : public DeviceBatchedGemm<ALayout,
// Note: This seems incorrect for non-contiguous memory layouts for C
// (padding, gaps).
HIP_CHECK_ERROR(
hipMemsetAsync(arg_.p_c_grid,
hipMemsetAsync(arg_.p_e_grid,
0,
arg_.Batch * arg_.M * arg_.N * sizeof(CDataType),
stream_config.stream_id_));
@@ -469,7 +483,7 @@ struct DeviceBatchedGemm_Wmma_CShuffleV3 : public DeviceBatchedGemm<ALayout,
// Note: This seems incorrect for non-contiguous memory layouts for C
// (padding, gaps).
HIP_CHECK_ERROR(
hipMemsetAsync(arg.p_c_grid,
hipMemsetAsync(arg.p_e_grid,
0,
arg.Batch * arg.M * arg.N * sizeof(CDataType),
stream_config.stream_id_));
@@ -658,7 +672,10 @@ struct DeviceBatchedGemm_Wmma_CShuffleV3 : public DeviceBatchedGemm<ALayout,
BatchStrideB,
BatchStrideC,
Batch,
1 /* KBatch */};
1, /* KBatch */
AElementwiseOperation{},
BElementwiseOperation{},
CElementwiseOperation{}};
}
static auto MakeInvoker() { return Invoker{}; }
@@ -694,7 +711,10 @@ struct DeviceBatchedGemm_Wmma_CShuffleV3 : public DeviceBatchedGemm<ALayout,
BatchStrideB,
BatchStrideC,
Batch,
1); // KBatch
1,
AElementwiseOperation{},
BElementwiseOperation{},
CElementwiseOperation{}); // KBatch
}
// polymorphic

View File

@@ -0,0 +1,410 @@
// SPDX-License-Identifier: MIT
// Copyright (c) 2025, Advanced Micro Devices, Inc. All rights reserved.
#pragma once
#include <iostream>
#include <sstream>
#include "ck/utility/common_header.hpp"
#include "ck/tensor_description/tensor_descriptor.hpp"
#include "ck/tensor_description/tensor_descriptor_helper.hpp"
#include "ck/tensor_operation/gpu/device/tensor_layout.hpp"
#include "ck/tensor_operation/gpu/device/device_gemm_multiple_d.hpp"
#include "ck/tensor_operation/gpu/device/gemm_specialization.hpp"
#include "ck/tensor_operation/gpu/grid/gridwise_gemm_wmma_cshuffle_v3.hpp"
#include "ck/host_utility/device_prop.hpp"
#include "ck/host_utility/kernel_launch.hpp"
#include "ck/host_utility/flush_cache.hpp"
#include "ck/tensor_operation/gpu/device/impl/device_gemm_wmma_cshuffle_v3_common.hpp"
namespace ck {
namespace tensor_operation {
namespace device {
/// @brief \"Universal\" GEMM operation with SplitK support and multiple D tensors.
///
/// @par Overview
/// This GEMM operation implements the following mathematical equation:
/// E{M,N} = CDE_op(A_op(A{M,K}) * B_op(B{K,N}), Ds{M,N}...)
/// Where A, B, Ds are input tensors and E is the output tensor. The A/B are elementwise
// operations that could be applied on each tensor respectively. The CDE_op is an
// elementwise operation applied to the C and all D tensors.
/// The \"universal\" gemm comes with multiple pipelines optimized for different usage
/// scenarios. That's why it's called \"universal\". It's universal through it's design
/// and versatilty.
///
/// @note This Kernel implementation supports SplitK algorithm. It can be configured
/// to split the dot product accumulated over the K dimension into multiple working groups.
/// The partial products of different workgroups are then reduced using the AtomicAdd
/// operation.
///
/// @tparam ALayout A tensor data layout.
/// @tparam BLayout B tensor data layout.
/// @tparam DsLayout D tensors data layouts.
/// @tparam ELayout E tensor data layout.
/// @tparam ADataType A tensor data type.
/// @tparam BDataType B tensor data type.
/// @tparam DsDataType D tensors data types.
/// @tparam EDataType E tensor data type.
/// @tparam AccDataType The accumulation data type related to the hardware
/// matrix-multiplication instruction.
/// @tparam CShuffleDataType The data type used to store matrix-multiplication results into
/// LDS memory during \"CShuffle\" data layout optimization.
/// @tparam AElementwiseOperation Elementwise operation applied to the A input tensor elements.
/// @tparam BElementwiseOperation Elementwise operation applied to the B input tensor elements.
/// @tparam CDEElementwiseOperation Elementwise operation applied to the C output tensor (after
/// GEMM) and D input tensors.
/// @tparam GemmSpec Determines used "padding" version.
/// @tparam BlockSize The number of threads within workgroup.
/// @tparam MPerBlock The input/output data tile size in the M dimension.
/// @tparam NPerBlock The input/output data tile size in the N dimension.
/// @tparam KPerBlock The input data tile size in the K dimension.
/// @tparam AK1 The vector load size from global memory for A tensor.
/// @tparam BK1 The vector load size from global memory for B tensor.
/// @tparam MPerWmma M size of Wave Matrix Multiply Accumulate (WMMA) instruction.
/// @tparam NPerWmma N size of Wave Matrix Multiply Accumulate (WMMA) instruction.
/// @tparam MRepeat The number of iterations in the M dimension over output tile per wavefront.
/// @tparam NRepeat The number of iterations in the N dimension over output tile per wavefront.
/// @tparam ABlockTransferThreadClusterLengths_AK0_M_AK1 Spatial thread distribution over the input
/// data. Can be interpreted as the answer
/// to the question, "How many threads can be
/// arranged on each input data axis?"
/// @tparam ABlockTransferThreadClusterArrangeOrder The order of thread spatial distribution over
/// the input tensor dimension. Can be interpreted
/// as the answer to the question: "In which
/// order to spread threads through tensor axes?".
/// @tparam ABlockTransferSrcAccessOrder The order of accessing input tensor axes. Can be
/// interpreted as the answer to the question "Which dimension
/// to read first? And which next?" etc.
/// @tparam ABlockTransferSrcVectorDim The index of axis on which we could do vectorized memory
/// access - the one with contiguous memory.
/// @tparam ABlockTransferSrcScalarPerVector The size of vector access instruction - the number of
/// elements accessed per thread per instruction.
/// @tparam ABlockTransferDstScalarPerVector_AK1 The size of vectorized store into LDS memory.
/// @tparam ABlockLdsExtraM Whether to use padding for LDS or not. With
/// universal GEMM there's no need for padding.
/// @tparam BBlockTransferThreadClusterLengths_BK0_N_BK1 Spatial thread distribution over the input
/// data. Can be interpreted as the answer
/// to the question: "How many threads to
/// arrange on each input data axis?"
/// @tparam BBlockTransferThreadClusterArrangeOrder The order of thread spatial distribution over
/// the input tensor dimension. Can be interpreted
/// as the answer to the question: "In which
/// order to spread threads through tensor axes?".
/// @tparam BBlockTransferSrcAccessOrder he order of accessing input tensor axes. Can be
/// interpreted as the answer to the question "Which dimension
/// to read first? And which next?" etc.
/// @tparam BBlockTransferSrcVectorDim The index of axis on which we could do vectorized memory
/// access - the one with contiguous memory.
/// @tparam BBlockTransferSrcScalarPerVector The size of vector access instruction - the number of
/// elements accessed per thread per instruction.
/// @tparam BBlockTransferDstScalarPerVector_BK1 The size of vectorized store into LDS memory.
/// @tparam BBlockLdsExtraN Whether to use padding for LDS or not. With
/// universal GEMM there's no need for padding.
/// @tparam CShuffleMRepeatPerShuffle The number of matrix-multiplication instructions
/// results to process per wave per iteration of CShuffle
/// in M dimension.
/// @tparam CShuffleNRepeatPerShuffle The number of matrix-multiplication instructions
/// results to process per wave per iteration of CShuffle
/// in N dimension.
/// @tparam CDEShuffleBlockTransferClusterLengths_MBlock_MPerBlock_NBlock_NPerBlock The spatial
/// thread distribution used for storing data into output
/// tensor across output data layout dimensions.
/// @tparam CDEShuffleBlockTransferScalarPerVectors The size of vectorized memory access.
/// Used when loading data from D tensors and storing data
/// to output tensor.
/// @tparam BlkGemmPipeSched The version of blockwise-gemm pipeline scheduler (interwave or
/// intrawave).
/// @tparam BlkGemmPipelineVer The version of blockwise-gemm pipeline.
/// @tparam ComputeTypeA Data type used for A input of hardware matrix-multiplication
/// instructions.
/// @tparam ComputeTypeB Data type used for B input of hardware matrix-multiplication
/// instructions.
/// @tparam PermuteA Whether the A input tensor has gridwise-gemm friendly data layout
/// in global memory. Currently not supported!
/// @tparam PermuteB Whether the B input tensor has gridwise-gemm friendly data layout
/// in global memory (pre-shuffled).
template <typename ALayout,
typename BLayout,
typename DsLayout,
typename ELayout,
typename ADataType,
typename BDataType,
typename DsDataType,
typename EDataType,
typename AccDataType,
typename CShuffleDataType,
typename AElementwiseOperation,
typename BElementwiseOperation,
typename CDEElementwiseOperation,
GemmSpecialization GemmSpec,
index_t BlockSize,
index_t MPerBlock,
index_t NPerBlock,
index_t KPerBlock,
index_t AK1,
index_t BK1,
index_t MPerWmma,
index_t NPerWmma,
index_t MRepeat,
index_t NRepeat,
typename ABlockTransferThreadClusterLengths_AK0_M_AK1,
typename ABlockTransferThreadClusterArrangeOrder,
typename ABlockTransferSrcAccessOrder,
index_t ABlockTransferSrcVectorDim,
index_t ABlockTransferSrcScalarPerVector,
index_t ABlockTransferDstScalarPerVector_AK1,
bool ABlockLdsExtraM,
typename BBlockTransferThreadClusterLengths_BK0_N_BK1,
typename BBlockTransferThreadClusterArrangeOrder,
typename BBlockTransferSrcAccessOrder,
index_t BBlockTransferSrcVectorDim,
index_t BBlockTransferSrcScalarPerVector,
index_t BBlockTransferDstScalarPerVector_BK1,
bool BBlockLdsExtraN,
index_t CShuffleMRepeatPerShuffle,
index_t CShuffleNRepeatPerShuffle,
typename CDEShuffleBlockTransferClusterLengths_MBlock_MPerBlock_NBlock_NPerBlock,
typename CDEShuffleBlockTransferScalarPerVectors,
BlockGemmPipelineScheduler BlkGemmPipeSched = BlockGemmPipelineScheduler::Intrawave,
BlockGemmPipelineVersion BlkGemmPipelineVer = BlockGemmPipelineVersion::v1,
typename ComputeTypeA = EDataType,
typename ComputeTypeB = ComputeTypeA,
bool PermuteA = false,
bool PermuteB = false>
struct DeviceGemmMultipleD_Wmma_CShuffleV3
: public DeviceGemmMultipleDSplitK<ALayout,
BLayout,
DsLayout,
ELayout,
ADataType,
BDataType,
DsDataType,
EDataType,
AElementwiseOperation,
BElementwiseOperation,
CDEElementwiseOperation>
{
static constexpr index_t NumDTensor = DsDataType::Size();
using GridwiseGemm = GridwiseGemm_wmma_cshuffle_v3<
ALayout,
BLayout,
DsLayout,
ELayout,
ADataType,
BDataType,
AccDataType,
CShuffleDataType,
DsDataType,
EDataType,
AElementwiseOperation,
BElementwiseOperation,
CDEElementwiseOperation,
GemmSpec,
BlockSize,
MPerBlock,
NPerBlock,
KPerBlock,
AK1,
BK1,
MPerWmma,
NPerWmma,
MRepeat,
NRepeat,
ABlockTransferThreadClusterLengths_AK0_M_AK1,
ABlockTransferThreadClusterArrangeOrder,
ABlockTransferSrcAccessOrder,
ABlockTransferSrcVectorDim,
ABlockTransferSrcScalarPerVector,
ABlockTransferDstScalarPerVector_AK1,
false,
ABlockLdsExtraM,
BBlockTransferThreadClusterLengths_BK0_N_BK1,
BBlockTransferThreadClusterArrangeOrder,
BBlockTransferSrcAccessOrder,
BBlockTransferSrcVectorDim,
BBlockTransferSrcScalarPerVector,
BBlockTransferDstScalarPerVector_BK1,
false,
BBlockLdsExtraN,
CShuffleMRepeatPerShuffle,
CShuffleNRepeatPerShuffle,
CDEShuffleBlockTransferClusterLengths_MBlock_MPerBlock_NBlock_NPerBlock,
CDEShuffleBlockTransferScalarPerVectors,
BlkGemmPipeSched,
BlkGemmPipelineVer,
ComputeTypeA,
ComputeTypeB,
PermuteA,
PermuteB>;
using Argument = typename GridwiseGemm::Argument;
using DeviceGemmCommon =
DeviceGemm_Wmma_CShuffleV3_Common<GridwiseGemm,
ADataType,
BDataType,
DsDataType,
EDataType,
MPerBlock,
NPerBlock,
KPerBlock,
BlockSize,
AK1,
BK1,
GemmSpec,
CDEShuffleBlockTransferScalarPerVectors,
BlkGemmPipeSched,
BlkGemmPipelineVer,
ComputeTypeA,
ComputeTypeB>;
// Invoker
using Invoker = typename DeviceGemmCommon::Invoker;
static bool IsSupportedArgument(const Argument& arg)
{
return DeviceGemmCommon::IsSupportedArgument(arg);
}
// polymorphic
bool IsSupportedArgument(const BaseArgument* p_arg) override
{
return IsSupportedArgument(*dynamic_cast<const Argument*>(p_arg));
}
static auto MakeArgument(const void* p_a,
const void* p_b,
std::array<const void*, NumDTensor> p_ds,
void* p_e,
index_t M,
index_t N,
index_t K,
index_t StrideA,
index_t StrideB,
std::array<index_t, NumDTensor> StrideDs,
index_t StrideE,
index_t KBatch,
AElementwiseOperation a_element_op,
BElementwiseOperation b_element_op,
CDEElementwiseOperation cde_element_op)
{
return Argument{static_cast<const ADataType*>(p_a),
static_cast<const BDataType*>(p_b),
p_ds,
static_cast<EDataType*>(p_e),
M,
N,
K,
StrideA,
StrideB,
StrideDs,
StrideE,
KBatch,
a_element_op,
b_element_op,
cde_element_op};
}
static auto MakeInvoker() { return Invoker{}; }
// polymorphic
std::unique_ptr<BaseArgument>
MakeArgumentPointer(const void* p_a,
const void* p_b,
std::array<const void*, NumDTensor> p_ds,
void* p_e,
index_t M,
index_t N,
index_t K,
index_t StrideA,
index_t StrideB,
std::array<ck::index_t, NumDTensor> StrideDs,
index_t StrideE,
index_t KBatch,
AElementwiseOperation a_element_op,
BElementwiseOperation b_element_op,
CDEElementwiseOperation cde_element_op) override
{
return std::make_unique<Argument>(static_cast<const ADataType*>(p_a),
static_cast<const BDataType*>(p_b),
p_ds,
static_cast<EDataType*>(p_e),
M,
N,
K,
StrideA,
StrideB,
StrideDs,
StrideE,
KBatch,
a_element_op,
b_element_op,
cde_element_op);
}
// polymorphic
std::unique_ptr<BaseInvoker> MakeInvokerPointer() override
{
return std::make_unique<Invoker>(Invoker{});
}
// polymorphic
std::string GetTypeString() const override
{
auto str = std::stringstream();
std::map<BlockGemmPipelineScheduler, std::string> BlkGemmPipelineSchedulerToString{
{BlockGemmPipelineScheduler::Intrawave, "Intrawave"},
{BlockGemmPipelineScheduler::Interwave, "Interwave"}};
std::map<BlockGemmPipelineVersion, std::string> BlkGemmPipelineVersionToString{
{BlockGemmPipelineVersion::v1, "v1"},
{BlockGemmPipelineVersion::v2, "v2"},
{BlockGemmPipelineVersion::v3, "v3"},
{BlockGemmPipelineVersion::v4, "v4"},
{BlockGemmPipelineVersion::v5, "v5"}};
// clang-format off
str << "DeviceGemmMultipleD_Wmma_CShuffleV3"
<< "<"
<< getGemmSpecializationString(GemmSpec) << ", "
<< std::string(ALayout::name)[0]
<< std::string(BLayout::name)[0];
static_for<0, NumDTensor, 1>{}([&](auto i) {
using DLayout = remove_cvref_t<tuple_element_t<i.value, DsLayout>>;
str << std::string(DLayout::name)[0];
});
str << std::string(ELayout::name)[0]
<< ">"
<< " BlkSize: "
<< BlockSize << ", "
<< "BlkTile: "
<< MPerBlock << "x" << NPerBlock << "x" << KPerBlock << ", "
<< "WaveTile: "
<< MPerWmma << "x"<<NPerWmma << ", "
<< "WaveMap: "
<< MRepeat << "x" << NRepeat << ", "
<< "VmemReadVec: "
<< ABlockTransferSrcScalarPerVector << "x" << BBlockTransferSrcScalarPerVector << ", "
<< "BlkGemmPipelineScheduler: "
<< BlkGemmPipelineSchedulerToString[BlkGemmPipeSched] << ", "
<< "BlkGemmPipelineVersion: "
<< BlkGemmPipelineVersionToString[BlkGemmPipelineVer] << ", "
<< "BlkGemmPipelinePrefetchStages: "
<< GridwiseGemm::BlockwiseGemmPipe::PrefetchStages << ", "
<< "KPack: "
<< GridwiseGemm::KPack;
// clang-format on
return str.str();
}
REGISTER_EXTRA_PRINTING_METHODS
};
} // namespace device
} // namespace tensor_operation
} // namespace ck

View File

@@ -177,15 +177,16 @@ struct DeviceGemm_Wmma_CShuffleV3 : public DeviceGemmV2<ALayout,
BElementwiseOperation,
CElementwiseOperation>
{
// GridwiseGemm
using GridwiseGemm = GridwiseGemm_wmma_cshuffle_v3<
ALayout,
BLayout,
Tuple<>, // DsLayout
CLayout,
ADataType,
BDataType,
AccDataType,
CShuffleDataType,
Tuple<>, // DsDataType
CDataType,
AElementwiseOperation,
BElementwiseOperation,
@@ -220,7 +221,7 @@ struct DeviceGemm_Wmma_CShuffleV3 : public DeviceGemmV2<ALayout,
CShuffleMRepeatPerShuffle,
CShuffleNRepeatPerShuffle,
CShuffleBlockTransferClusterLengths_MBlock_MPerBlock_NBlock_NPerBlock,
CShuffleBlockTransferScalarPerVector_NPerBlock,
Sequence<CShuffleBlockTransferScalarPerVector_NPerBlock>,
BlkGemmPipeSched,
BlkGemmPipelineVer,
ComputeTypeA,
@@ -230,21 +231,24 @@ struct DeviceGemm_Wmma_CShuffleV3 : public DeviceGemmV2<ALayout,
using Argument = typename GridwiseGemm::Argument;
using DeviceGemmCommon = DeviceGemm_Wmma_CShuffleV3_Common<GridwiseGemm,
ADataType,
BDataType,
CDataType,
MPerBlock,
NPerBlock,
KPerBlock,
BlockSize,
AK1,
BK1,
GemmSpec,
BlkGemmPipeSched,
BlkGemmPipelineVer,
ComputeTypeA,
ComputeTypeB>;
using DeviceGemmCommon =
DeviceGemm_Wmma_CShuffleV3_Common<GridwiseGemm,
ADataType,
BDataType,
Tuple<>,
CDataType,
MPerBlock,
NPerBlock,
KPerBlock,
BlockSize,
AK1,
BK1,
GemmSpec,
Sequence<CShuffleBlockTransferScalarPerVector_NPerBlock>,
BlkGemmPipeSched,
BlkGemmPipelineVer,
ComputeTypeA,
ComputeTypeB>;
// Invoker
using Invoker = typename DeviceGemmCommon::Invoker;
@@ -275,11 +279,25 @@ struct DeviceGemm_Wmma_CShuffleV3 : public DeviceGemmV2<ALayout,
index_t StrideB,
index_t StrideC,
index_t KBatch,
AElementwiseOperation,
BElementwiseOperation,
CElementwiseOperation)
AElementwiseOperation a_element_op,
BElementwiseOperation b_element_op,
CElementwiseOperation cde_element_op)
{
return Argument{p_a, p_b, p_c, M, N, K, StrideA, StrideB, StrideC, KBatch};
return Argument{p_a,
p_b,
std::array<const void*, 0>{}, // p_ds_grid_
p_c,
M,
N,
K,
StrideA,
StrideB,
std::array<index_t, 0>{}, // StrideDs_
StrideC,
KBatch,
a_element_op,
b_element_op,
cde_element_op};
}
static auto MakeInvoker() { return Invoker{}; }
@@ -295,20 +313,25 @@ struct DeviceGemm_Wmma_CShuffleV3 : public DeviceGemmV2<ALayout,
index_t StrideB,
index_t StrideC,
index_t KBatch,
AElementwiseOperation,
BElementwiseOperation,
CElementwiseOperation) override
AElementwiseOperation a_element_op,
BElementwiseOperation b_element_op,
CElementwiseOperation c_element_op) override
{
return std::make_unique<Argument>(static_cast<const ADataType*>(p_a),
static_cast<const BDataType*>(p_b),
std::array<const void*, 0>{}, // p_ds_grid_
static_cast<CDataType*>(p_c),
M,
N,
K,
StrideA,
StrideB,
std::array<index_t, 0>{}, // StrideDs_
StrideC,
KBatch);
KBatch,
a_element_op,
b_element_op,
c_element_op);
}
// polymorphic

View File

@@ -89,11 +89,13 @@ struct DeviceGemm_BScale_Wmma_CShuffleV3 : public DeviceGemmV2BScale<ALayout,
using GridwiseGemm = GridwiseGemm_wmma_cshuffle_v3_b_scale<
ALayout,
BLayout,
Tuple<>, // DsLayout
CLayout,
ADataType,
BDataType,
AccDataType,
CShuffleDataType,
Tuple<>, // DsDataType
CDataType,
AElementwiseOperation,
BElementwiseOperation,
@@ -130,7 +132,7 @@ struct DeviceGemm_BScale_Wmma_CShuffleV3 : public DeviceGemmV2BScale<ALayout,
CShuffleMRepeatPerShuffle,
CShuffleNRepeatPerShuffle,
CShuffleBlockTransferClusterLengths_MBlock_MPerBlock_NBlock_NPerBlock,
CShuffleBlockTransferScalarPerVector_NPerBlock,
Sequence<CShuffleBlockTransferScalarPerVector_NPerBlock>,
BlkGemmPipeSched,
BlkGemmPipelineVer,
ComputeTypeA,
@@ -140,21 +142,24 @@ struct DeviceGemm_BScale_Wmma_CShuffleV3 : public DeviceGemmV2BScale<ALayout,
using Argument = typename GridwiseGemm::Argument;
using DeviceGemmCommon = DeviceGemm_Wmma_CShuffleV3_Common<GridwiseGemm,
ADataType,
BDataType,
CDataType,
MPerBlock,
NPerBlock,
KPerBlock,
BlockSize,
AK1,
BK1,
GemmSpec,
BlkGemmPipeSched,
BlkGemmPipelineVer,
ComputeTypeA,
ComputeTypeB>;
using DeviceGemmCommon =
DeviceGemm_Wmma_CShuffleV3_Common<GridwiseGemm,
ADataType,
BDataType,
Tuple<>,
CDataType,
MPerBlock,
NPerBlock,
KPerBlock,
BlockSize,
AK1,
BK1,
GemmSpec,
Sequence<CShuffleBlockTransferScalarPerVector_NPerBlock>,
BlkGemmPipeSched,
BlkGemmPipelineVer,
ComputeTypeA,
ComputeTypeB>;
// Invoker
using Invoker = typename DeviceGemmCommon::Invoker;
@@ -188,23 +193,25 @@ struct DeviceGemm_BScale_Wmma_CShuffleV3 : public DeviceGemmV2BScale<ALayout,
index_t KBatch,
AElementwiseOperation a_element_op,
BElementwiseOperation b_element_op,
CElementwiseOperation c_element_op)
CElementwiseOperation cde_element_op)
{
return Argument{p_a,
p_b,
std::array<const void*, 0>{}, // p_ds_grid_
p_c,
M,
N,
K,
StrideA,
StrideB,
std::array<index_t, 0>{}, // StrideDs_
StrideC,
StrideScaleB,
p_b_scale,
KBatch,
a_element_op,
b_element_op,
c_element_op};
cde_element_op};
}
static auto MakeInvoker() { return Invoker{}; }
@@ -228,12 +235,14 @@ struct DeviceGemm_BScale_Wmma_CShuffleV3 : public DeviceGemmV2BScale<ALayout,
{
return std::make_unique<Argument>(static_cast<const ADataType*>(p_a),
static_cast<const BDataType*>(p_b),
std::array<const void*, 0>{}, // p_ds_grid_
static_cast<CDataType*>(p_c),
M,
N,
K,
StrideA,
StrideB,
std::array<index_t, 0>{}, // StrideDs_
StrideC,
StrideScaleB,
static_cast<const BScaleDataType*>(p_b_scale),

View File

@@ -24,7 +24,8 @@ namespace device {
template <typename GridwiseGemm,
typename ADataType,
typename BDataType,
typename CDataType,
typename DsDataType,
typename EDataType,
index_t MPerBlock,
index_t NPerBlock,
index_t KPerBlock,
@@ -32,6 +33,7 @@ template <typename GridwiseGemm,
index_t AK1,
index_t BK1,
GemmSpecialization GemmSpec,
typename CDEShuffleBlockTransferScalarPerVectors,
BlockGemmPipelineScheduler BlkGemmPipeSched,
BlockGemmPipelineVersion BlkGemmPipelineVer,
typename ComputeTypeA,
@@ -95,8 +97,22 @@ struct DeviceGemm_Wmma_CShuffleV3_Common
auto size_b_buffer = b_grid_desc_bk0_n_bk1.GetElementSpaceSize() *
sizeof(BDataType) / GridwiseGemm::BPackedSize;
ck::utility::RotatingMemWrapper<Argument> rotating_mem(
arg_, stream_config.rotating_count, size_a_buffer, size_b_buffer);
const auto ds_grid_desc_m_n = GridwiseGemm::MakeDsGridDescriptor_M_N(
arg_.M, arg_.MPadded, arg_.N, arg_.NPadded, arg_.StrideDs);
std::array<std::size_t, GridwiseGemm::NumDTensor> size_ds_buffers;
static_for<0, GridwiseGemm::NumDTensor, 1>{}([&](auto i) {
using DDataType = remove_cvref_t<tuple_element_t<i.value, DsDataType>>;
size_ds_buffers[i] =
ds_grid_desc_m_n[i].GetElementSpaceSize() * sizeof(DDataType);
});
ck::utility::RotatingMemWrapperMultiD<Argument, DsDataType> rotating_mem(
arg_,
stream_config.rotating_count,
size_a_buffer,
size_b_buffer,
size_ds_buffers);
rotating_mem.Print();
auto run_flush_cache = [&]() {
@@ -106,9 +122,9 @@ struct DeviceGemm_Wmma_CShuffleV3_Common
rotating_mem.Next();
// clear c mem
if(arg_.KBatch > 1)
HIP_CHECK_ERROR(hipMemsetAsync(arg_.p_c_grid,
HIP_CHECK_ERROR(hipMemsetAsync(arg_.p_e_grid,
0,
arg_.M * arg_.N * sizeof(CDataType),
arg_.M * arg_.N * sizeof(EDataType),
stream_config.stream_id_));
};
@@ -124,9 +140,9 @@ struct DeviceGemm_Wmma_CShuffleV3_Common
else
{
if(arg.KBatch > 1)
HIP_CHECK_ERROR(hipMemsetAsync(arg.p_c_grid,
HIP_CHECK_ERROR(hipMemsetAsync(arg.p_e_grid,
0,
arg.M * arg.N * sizeof(CDataType),
arg.M * arg.N * sizeof(EDataType),
stream_config.stream_id_));
ave_time = launch_and_time_kernel(
@@ -149,6 +165,16 @@ struct DeviceGemm_Wmma_CShuffleV3_Common
}
}();
// ThreadwiseTensorSliceTransfer_v7r3 (used in ThreadGroupTensorSliceTransfer_v7r3) is
// currently implemented in such a way that all SrcScalarPerVectors must be the same, so
// if one of D matrices is column-major, then all SrcScalarPerVectors must be 1. On the
// other hand, Split K for 16-bit outputs uses packed atomics so ScalarPerVectors cannot
// be odd.
constexpr bool AtomicsImplementationExists =
!(std::is_same_v<EDataType, ck::half_t> ||
std::is_same_v<EDataType, ck::bhalf_t>) ||
(CDEShuffleBlockTransferScalarPerVectors{}[0] % 2 == 0);
if(has_main_k_block_loop)
{
// Tail number always full
@@ -157,12 +183,15 @@ struct DeviceGemm_Wmma_CShuffleV3_Common
{
if(arg.KBatch > 1)
{
const auto kernel =
kernel_gemm_wmma_cshuffle_v3<GridwiseGemm,
true,
InMemoryDataOperationEnum::AtomicAdd,
minimum_occupancy>;
Run(kernel);
if constexpr(AtomicsImplementationExists)
{
const auto kernel =
kernel_gemm_wmma_cshuffle_v3<GridwiseGemm,
true,
InMemoryDataOperationEnum::AtomicAdd,
minimum_occupancy>;
Run(kernel);
}
}
else
{
@@ -186,12 +215,15 @@ struct DeviceGemm_Wmma_CShuffleV3_Common
{
if(arg.KBatch > 1)
{
const auto kernel =
kernel_gemm_wmma_cshuffle_v3<GridwiseGemm,
false,
InMemoryDataOperationEnum::AtomicAdd,
minimum_occupancy>;
Run(kernel);
if constexpr(AtomicsImplementationExists)
{
const auto kernel =
kernel_gemm_wmma_cshuffle_v3<GridwiseGemm,
false,
InMemoryDataOperationEnum::AtomicAdd,
minimum_occupancy>;
Run(kernel);
}
}
else
{
@@ -229,8 +261,8 @@ struct DeviceGemm_Wmma_CShuffleV3_Common
return false;
}
if constexpr(std::is_same_v<CDataType, ck::half_t> ||
std::is_same_v<CDataType, ck::bhalf_t>)
if constexpr(std::is_same_v<EDataType, ck::half_t> ||
std::is_same_v<EDataType, ck::bhalf_t>)
{
if(arg.KBatch > 1 && ck::is_gfx11_supported())
{

View File

@@ -47,7 +47,7 @@ struct Add
__host__ __device__ constexpr void
operator()<half_t>(half_t& y, const float& x0, const half_t& x1) const
{
y = type_convert<half_t>(x0) + x1;
y = x0 + type_convert<float>(x1);
};
template <>

View File

@@ -11,7 +11,7 @@
#include "ck/tensor_operation/gpu/grid/block_to_ctile_map.hpp"
#include "ck/tensor_operation/gpu/block/blockwise_gemm_pipeline_wmma_selector.hpp"
#include "ck/tensor_operation/gpu/block/thread_group_tensor_slice_transfer_v4r1.hpp"
#include "ck/tensor_operation/gpu/block/thread_group_tensor_slice_transfer_v6r1.hpp"
#include "ck/tensor_operation/gpu/block/thread_group_tensor_slice_transfer_v7r3.hpp"
#include "ck/tensor_operation/gpu/thread/threadwise_tensor_slice_transfer.hpp"
#include "ck/tensor_operation/gpu/element/element_wise_operation.hpp"
#include "ck/tensor_operation/gpu/grid/gridwise_gemm_wmma_cshuffle_v3_common.hpp"
@@ -22,9 +22,10 @@ namespace ck {
///
/// @par Overview
/// This GEMM kernel is carrying out following mathematical equation:
/// C{M,N} = C_op(A_op(A{M,K}) * B_op(B{K,N}))
/// Where A, B are input tensors and C is the output tensor. The A/B/C_op are
/// elementwise operations that could be applied on each tensor respectively.
/// E{M,N} = CDE_op(A_op(A{M,K}) * B_op(B{K,N}), Ds{M,N}...)
/// Where A, B, Ds are input tensors and E is the output tensor. The A/B are elementwise
// operations that could be applied on each tensor respectively. The CDE_op is an
// elementwise operation applied to the C and all D tensors.
/// The \"universal\" gemm comes with multiple pipelines optimized for different usage
/// scenarios. That's why it's called \"universal\". It's universal through it's design
/// and versatilty.
@@ -36,18 +37,20 @@ namespace ck {
///
/// @tparam ALayout A tensor data layout.
/// @tparam BLayout B tensor data layout.
/// @tparam CLayout C tensor data layout.
/// @tparam DsLayout D tensors data layouts.
/// @tparam ELayout E tensor data layout.
/// @tparam ADataType A tensor data type.
/// @tparam BDataType B tensor data type.
/// @tparam AccDataType The accumulation data type related to the hardware
/// matrix-multiplication instruction.
/// @tparam CShuffleDataType The data type used to store matrix-multiplication results into
/// LDS memory during \"CShuffle\" data layout optimization.
/// @tparam CDataType C tensor data type.
/// @tparam AElementwiseOperation Elementwise operation applied to the A input tensor elements.
/// @tparam BElementwiseOperation Elementwise operation applied to the B input tensor elements.
/// @tparam CElementwiseOperation Elementwise operation applied to the C output tensor
/// (after GEMM).
/// @tparam DsDataType D tensors data types.
/// @tparam EDataType E tensor data type.
/// @tparam AElementwiseOperation Elementwise operation applied to the A input tensor elements.
/// @tparam BElementwiseOperation Elementwise operation applied to the B input tensor elements.
/// @tparam CDEElementwiseOperation Elementwise operation applied to the C output tensor (after
/// GEMM) and D input tensors.
/// @tparam GemmSpec Determines used "padding" version.
/// @tparam BlockSize The number of threads within workgroup.
/// @tparam MPerBlock The input/output data tile size in the M dimension.
@@ -105,11 +108,12 @@ namespace ck {
/// @tparam CShuffleNRepeatPerShuffle The number of matrix-multiplication instructions
/// results to process per wave per iteration of CShuffle
/// in N dimension.
/// @tparam CShuffleBlockTransferClusterLengths_MBlock_MPerBlock_NBlock_NPerBlock The spatial
/// @tparam CDEShuffleBlockTransferClusterLengths_MBlock_MPerBlock_NBlock_NPerBlock The spatial
/// thread distribution used for storing data into output
/// tensor across output data layout dimensions.
/// @tparam CShuffleBlockTransferScalarPerVector_NPerBlock The size of vectorized memory access.
/// Used when storing data to output tensor.
/// @tparam CDEShuffleBlockTransferScalarPerVectors The size of vectorized memory access.
/// Used when loading data from D tensors and storing data
/// to output tensor.
/// @tparam BlkGemmPipeSched The version of blockwise-gemm pipeline scheduler (interwave or
/// intrawave).
/// @tparam BlkGemmPipelineVer The version of blockwise-gemm pipeline.
@@ -123,15 +127,17 @@ namespace ck {
/// in global memory (pre-shuffled).
template <typename ALayout,
typename BLayout,
typename CLayout,
typename DsLayout,
typename ELayout,
typename ADataType,
typename BDataType,
typename AccDataType,
typename CShuffleDataType,
typename CDataType,
typename DsDataType,
typename EDataType,
typename AElementwiseOperation,
typename BElementwiseOperation,
typename CElementwiseOperation,
typename CDEElementwiseOperation,
tensor_operation::device::GemmSpecialization GemmSpec,
index_t BlockSize,
index_t MPerBlock,
@@ -161,8 +167,8 @@ template <typename ALayout,
index_t BBlockLdsExtraN,
index_t CShuffleMRepeatPerShuffle,
index_t CShuffleNRepeatPerShuffle,
typename CShuffleBlockTransferClusterLengths_MBlock_MPerBlock_NBlock_NPerBlock,
index_t CShuffleBlockTransferScalarPerVector_NPerBlock,
typename CDEShuffleBlockTransferClusterLengths_MBlock_MPerBlock_NBlock_NPerBlock,
typename CDEShuffleBlockTransferScalarPerVectors,
BlockGemmPipelineScheduler BlkGemmPipeSched,
BlockGemmPipelineVersion BlkGemmPipelineVer,
typename ComputeTypeA,
@@ -173,15 +179,17 @@ struct GridwiseGemm_wmma_cshuffle_v3
: GridwiseGemm_wmma_cshuffle_v3_base<
ALayout,
BLayout,
CLayout,
DsLayout,
ELayout,
ADataType,
BDataType,
AccDataType,
CShuffleDataType,
CDataType,
DsDataType,
EDataType,
AElementwiseOperation,
BElementwiseOperation,
CElementwiseOperation,
CDEElementwiseOperation,
GemmSpec,
BlockSize,
MPerBlock,
@@ -211,8 +219,8 @@ struct GridwiseGemm_wmma_cshuffle_v3
BBlockLdsExtraN,
CShuffleMRepeatPerShuffle,
CShuffleNRepeatPerShuffle,
CShuffleBlockTransferClusterLengths_MBlock_MPerBlock_NBlock_NPerBlock,
CShuffleBlockTransferScalarPerVector_NPerBlock,
CDEShuffleBlockTransferClusterLengths_MBlock_MPerBlock_NBlock_NPerBlock,
CDEShuffleBlockTransferScalarPerVectors,
BlkGemmPipeSched,
BlkGemmPipelineVer,
ComputeTypeA,
@@ -223,15 +231,17 @@ struct GridwiseGemm_wmma_cshuffle_v3
using Base = GridwiseGemm_wmma_cshuffle_v3_base<
ALayout,
BLayout,
CLayout,
DsLayout,
ELayout,
ADataType,
BDataType,
AccDataType,
CShuffleDataType,
CDataType,
DsDataType,
EDataType,
AElementwiseOperation,
BElementwiseOperation,
CElementwiseOperation,
CDEElementwiseOperation,
GemmSpec,
BlockSize,
MPerBlock,
@@ -261,8 +271,8 @@ struct GridwiseGemm_wmma_cshuffle_v3
BBlockLdsExtraN,
CShuffleMRepeatPerShuffle,
CShuffleNRepeatPerShuffle,
CShuffleBlockTransferClusterLengths_MBlock_MPerBlock_NBlock_NPerBlock,
CShuffleBlockTransferScalarPerVector_NPerBlock,
CDEShuffleBlockTransferClusterLengths_MBlock_MPerBlock_NBlock_NPerBlock,
CDEShuffleBlockTransferScalarPerVectors,
BlkGemmPipeSched,
BlkGemmPipelineVer,
ComputeTypeA,
@@ -297,17 +307,22 @@ struct GridwiseGemm_wmma_cshuffle_v3
using Base::CalculateNPadded;
using Base::MakeAGridDescriptor_AK0_M_AK1;
using Base::MakeBGridDescriptor_BK0_N_BK1;
using Base::MakeCGridDescriptor_M_N;
using Base::MakeDEGridDescriptor_M_N;
using Base::MakeDsGridDescriptor_M_N;
using Base::MakeDsGridDescriptor_MBlock_MPerBlock_NBlock_NPerBlock;
using Base::GetCShuffleBlockDescriptor_MShRepeat_MPerShRepeat_NShRepeat_NPerShRepeat;
using Base::MakeCGridDescriptor_MBlock_MPerBlock_NBlock_NPerBlock;
using Base::MakeDEGridDescriptor_MBlock_MPerBlock_NBlock_NPerBlock;
using ThisThreadBlock = ThisThreadBlock<BlockSize>;
using Base::GetABlockDescriptor_AK0PerBlock_MPerBlock_AK1;
using Base::GetBBlockDescriptor_BK0PerBlock_NPerBlock_BK1;
using Base::NumDTensor;
using typename Base::DsGridPointer;
struct Problem
{
__host__ Problem(index_t M_,
@@ -315,14 +330,16 @@ struct GridwiseGemm_wmma_cshuffle_v3
index_t K_,
index_t StrideA_,
index_t StrideB_,
index_t StrideC_,
std::array<index_t, NumDTensor> StrideDs_,
index_t StrideE_,
index_t KBatch_)
: M{M_},
N{N_},
K{K_},
StrideA{StrideA_},
StrideB{StrideB_},
StrideC{StrideC_},
StrideDs{StrideDs_},
StrideE{StrideE_},
KBatch{KBatch_},
MPadded{CalculateMPadded(M_)},
NPadded{CalculateNPadded(N_)},
@@ -338,11 +355,19 @@ struct GridwiseGemm_wmma_cshuffle_v3
__host__ void Print() const
{
std::cout << "problem {" << "M:" << M << ", " << "N:" << N << ", " << "K:" << K << ", "
<< "SA:" << StrideA << ", " << "SB:" << StrideB << ", " << "SC:" << StrideC
<< ", " << "MP:" << MPadded << ", " << "NP:" << NPadded << ", "
<< "KRead:" << KRead << ", " << "KP:" << KPadded << ", " << "AK0:" << AK0
<< ", " << "BK0:" << BK0 << ", " << "MBlock: " << MBlock << ", "
<< "NBlock: " << NBlock << "}" << std::endl;
<< "SA:" << StrideA << ", " << "SB:" << StrideB << ", ";
if constexpr(NumDTensor > 0)
{
std::cout << "SDs: { ";
static_for<0, NumDTensor, 1>{}([&](auto i) {
std::cout << StrideDs[i] << (i.value < NumDTensor - 1 ? ", " : "");
});
std::cout << " }, ";
}
std::cout << "SE:" << StrideE << ", " << "MP:" << MPadded << ", " << "NP:" << NPadded
<< ", " << "KRead:" << KRead << ", " << "KP:" << KPadded << ", "
<< "AK0:" << AK0 << ", " << "BK0:" << BK0 << ", " << "MBlock: " << MBlock
<< ", " << "NBlock: " << NBlock << "}" << std::endl;
}
index_t M;
@@ -350,7 +375,8 @@ struct GridwiseGemm_wmma_cshuffle_v3
index_t K;
index_t StrideA;
index_t StrideB;
index_t StrideC;
std::array<index_t, NumDTensor> StrideDs;
index_t StrideE;
index_t KBatch;
index_t MPadded;
index_t NPadded;
@@ -367,21 +393,35 @@ struct GridwiseGemm_wmma_cshuffle_v3
{
__host__ Argument(const ADataType* p_a_grid_,
const BDataType* p_b_grid_,
CDataType* p_c_grid_,
std::array<const void*, NumDTensor> p_ds_grid_,
EDataType* p_e_grid_,
index_t M_,
index_t N_,
index_t K_,
index_t StrideA_,
index_t StrideB_,
index_t StrideC_,
std::array<index_t, NumDTensor> StrideDs_,
index_t StrideE_,
index_t k_batch_,
AElementwiseOperation a_element_op_,
BElementwiseOperation b_element_op_,
CDEElementwiseOperation cde_element_op_,
bool is_reduce_ = false)
: Problem{M_, N_, K_, StrideA_, StrideB_, StrideC_, k_batch_},
: Problem{M_, N_, K_, StrideA_, StrideB_, StrideDs_, StrideE_, k_batch_},
p_a_grid{p_a_grid_},
p_b_grid{p_b_grid_},
p_c_grid{p_c_grid_},
p_ds_grid{},
p_e_grid{p_e_grid_},
a_element_op{a_element_op_},
b_element_op{b_element_op_},
cde_element_op{cde_element_op_},
is_reduce(is_reduce_)
{
static_for<0, NumDTensor, 1>{}([&](auto i) {
using DDataType = remove_cvref_t<tuple_element_t<i.value, DsDataType>>;
p_ds_grid(i) = static_cast<const DDataType*>(p_ds_grid_[i]);
});
}
__host__ __device__ inline bool IsReduceAdd() const
@@ -396,42 +436,49 @@ struct GridwiseGemm_wmma_cshuffle_v3
const ADataType* p_a_grid;
const BDataType* p_b_grid;
CDataType* p_c_grid;
DsGridPointer p_ds_grid;
EDataType* p_e_grid;
const AElementwiseOperation a_element_op;
const BElementwiseOperation b_element_op;
const CDEElementwiseOperation cde_element_op;
// TODO: it can be used with SplitK+reduction but currently only used with SplitK+atomicAdd
bool is_reduce;
};
struct SplitKBatchOffset
{
__device__ SplitKBatchOffset(Argument& karg)
__device__ SplitKBatchOffset(Argument& karg, index_t k_id)
{
if constexpr(is_same_v<tensor_layout::gemm::RowMajor, ALayout>)
{
a_k_split_offset = blockIdx.z * karg.KRead / APackedSize;
a_k_split_offset = k_id * karg.KRead / APackedSize;
}
else if constexpr(is_same_v<tensor_layout::gemm::ColumnMajor, ALayout>)
{
a_k_split_offset = blockIdx.z * karg.KRead * karg.StrideA;
a_k_split_offset = k_id * karg.KRead * karg.StrideA;
}
if constexpr(is_same_v<tensor_layout::gemm::RowMajor, BLayout>)
{
b_k_split_offset = blockIdx.z * karg.KRead * karg.StrideB;
b_k_split_offset = k_id * karg.KRead * karg.StrideB;
}
else if constexpr(is_same_v<tensor_layout::gemm::ColumnMajor, BLayout>)
{
if constexpr(!PermuteB)
{
b_k_split_offset = blockIdx.z * karg.KRead / BPackedSize;
b_k_split_offset = k_id * karg.KRead / BPackedSize;
}
else
{
const int k0_offset = karg.KRead * karg.N;
b_k_split_offset = blockIdx.z * k0_offset / BPackedSize;
b_k_split_offset = k_id * k0_offset / BPackedSize;
}
}
if(blockIdx.z < static_cast<uint32_t>(karg.KBatch - 1))
if(k_id < karg.KBatch - 1)
{
karg.K = karg.KRead;
}
@@ -442,7 +489,7 @@ struct GridwiseGemm_wmma_cshuffle_v3
if(karg.IsReduceAdd())
{
c_reduce_offset = blockIdx.z * karg.M * karg.N;
c_reduce_offset = k_id * karg.M * karg.N;
}
else
{
@@ -465,23 +512,32 @@ struct GridwiseGemm_wmma_cshuffle_v3
__device__ static index_t GetKBlockPerScale() { return 1; }
template <bool HasMainKBlockLoop,
InMemoryDataOperationEnum CGlobalMemoryDataOperation,
TailNumber TailNum = TailNumber::Odd>
InMemoryDataOperationEnum EGlobalMemoryDataOperation,
TailNumber TailNum>
__device__ static void Run(const ADataType* p_a_grid,
const BDataType* p_b_grid,
CDataType* p_c_grid,
DsGridPointer& p_ds_grid,
EDataType* p_e_grid,
void* p_shared,
const Problem& problem)
const Problem& problem,
AElementwiseOperation a_element_op,
BElementwiseOperation b_element_op,
CDEElementwiseOperation cde_element_op)
{
const auto a_grid_desc_ak0_m_ak1 = MakeAGridDescriptor_AK0_M_AK1(
problem.M, problem.MPadded, problem.K, problem.KPadded, problem.StrideA, problem.AK0);
const auto b_grid_desc_bk0_n_bk1 = MakeBGridDescriptor_BK0_N_BK1(
problem.K, problem.KPadded, problem.N, problem.NPadded, problem.StrideB, problem.BK0);
const auto c_grid_desc_m_n = MakeCGridDescriptor_M_N(
problem.M, problem.MPadded, problem.N, problem.NPadded, problem.StrideC);
const auto c_grid_desc_mblock_mperblock_nblock_nperblock =
MakeCGridDescriptor_MBlock_MPerBlock_NBlock_NPerBlock(
c_grid_desc_m_n, problem.MBlock, problem.NBlock);
const auto ds_grid_desc_m_n = MakeDsGridDescriptor_M_N(
problem.M, problem.MPadded, problem.N, problem.NPadded, problem.StrideDs);
const auto e_grid_desc_m_n = Base::template MakeDEGridDescriptor_M_N<ELayout>(
problem.M, problem.MPadded, problem.N, problem.NPadded, problem.StrideE);
const auto ds_grid_desc_mblock_mperblock_nblock_nperblock =
MakeDsGridDescriptor_MBlock_MPerBlock_NBlock_NPerBlock(
ds_grid_desc_m_n, problem.MBlock, problem.NBlock);
const auto e_grid_desc_mblock_mperblock_nblock_nperblock =
MakeDEGridDescriptor_MBlock_MPerBlock_NBlock_NPerBlock(
e_grid_desc_m_n, problem.MBlock, problem.NBlock);
// divide block work by [M, N]
const auto block_2_ctile_map = Block2CTileMap{problem.M, problem.N, 4};
@@ -491,8 +547,8 @@ struct GridwiseGemm_wmma_cshuffle_v3
if(!block_2_ctile_map.ValidCTileIndex(
block_work_idx,
make_tuple(c_grid_desc_mblock_mperblock_nblock_nperblock.GetLength(I0),
c_grid_desc_mblock_mperblock_nblock_nperblock.GetLength(I2))))
make_tuple(e_grid_desc_mblock_mperblock_nblock_nperblock.GetLength(I0),
e_grid_desc_mblock_mperblock_nblock_nperblock.GetLength(I2))))
{
return;
}
@@ -508,17 +564,23 @@ struct GridwiseGemm_wmma_cshuffle_v3
Base::template Run<decltype(a_grid_desc_ak0_m_ak1),
decltype(b_grid_desc_bk0_n_bk1),
decltype(c_grid_desc_mblock_mperblock_nblock_nperblock),
decltype(ds_grid_desc_mblock_mperblock_nblock_nperblock),
decltype(e_grid_desc_mblock_mperblock_nblock_nperblock),
decltype(b_scale_struct),
HasMainKBlockLoop,
CGlobalMemoryDataOperation,
EGlobalMemoryDataOperation,
TailNum>(p_a_grid,
p_b_grid,
p_c_grid,
p_ds_grid,
p_e_grid,
p_shared,
a_grid_desc_ak0_m_ak1,
b_grid_desc_bk0_n_bk1,
c_grid_desc_mblock_mperblock_nblock_nperblock,
ds_grid_desc_mblock_mperblock_nblock_nperblock,
e_grid_desc_mblock_mperblock_nblock_nperblock,
a_element_op,
b_element_op,
cde_element_op,
block_m_id,
block_n_id,
num_k_block_per_scale,
@@ -528,17 +590,21 @@ struct GridwiseGemm_wmma_cshuffle_v3
// Wrapper function to have __global__ function in common
// between gemm_universal, b_scale, ab_scale, etc.
template <bool HasMainKBlockLoop,
InMemoryDataOperationEnum CGlobalMemoryDataOperation,
TailNumber TailNum = TailNumber::Odd>
InMemoryDataOperationEnum EGlobalMemoryDataOperation,
TailNumber TailNum>
__device__ static void
Run(void* p_shared, const SplitKBatchOffset& splitk_batch_offset, const Argument& karg)
Run(void* p_shared, const SplitKBatchOffset& splitk_batch_offset, Argument& karg)
{
Run<HasMainKBlockLoop, CGlobalMemoryDataOperation, TailNum>(
Run<HasMainKBlockLoop, EGlobalMemoryDataOperation, TailNum>(
karg.p_a_grid + splitk_batch_offset.a_k_split_offset,
karg.p_b_grid + splitk_batch_offset.b_k_split_offset,
karg.p_c_grid + splitk_batch_offset.c_reduce_offset,
karg.p_ds_grid, //; + splitk_batch_offset.c_reduce_offset,
karg.p_e_grid + splitk_batch_offset.c_reduce_offset,
p_shared,
karg);
karg,
karg.a_element_op,
karg.b_element_op,
karg.cde_element_op);
}
};

View File

@@ -20,15 +20,17 @@ namespace ck {
template <typename ALayout,
typename BLayout,
typename CLayout,
typename DsLayout,
typename ELayout,
typename ADataType,
typename BDataType,
typename AccDataType,
typename CShuffleDataType,
typename CDataType,
typename DsDataType,
typename EDataType,
typename AElementwiseOperation,
typename BElementwiseOperation,
typename CElementwiseOperation,
typename CDEElementwiseOperation,
tensor_operation::device::GemmSpecialization GemmSpec,
index_t BlockSize,
index_t ScaleBlockN, // scale N
@@ -60,11 +62,11 @@ template <typename ALayout,
index_t BBlockLdsExtraN,
index_t CShuffleMRepeatPerShuffle,
index_t CShuffleNRepeatPerShuffle,
typename CShuffleBlockTransferClusterLengths_MBlock_MPerBlock_NBlock_NPerBlock,
index_t CShuffleBlockTransferScalarPerVector_NPerBlock,
typename CDEShuffleBlockTransferClusterLengths_MBlock_MPerBlock_NBlock_NPerBlock,
typename CDEShuffleBlockTransferScalarPerVectors,
BlockGemmPipelineScheduler BlkGemmPipeSched = BlockGemmPipelineScheduler::Intrawave,
BlockGemmPipelineVersion BlkGemmPipelineVer = BlockGemmPipelineVersion::v4,
typename ComputeTypeA = CDataType,
typename ComputeTypeA = EDataType,
typename ComputeTypeB = ComputeTypeA,
bool PermuteA = false,
bool PermuteB = false>
@@ -72,15 +74,17 @@ struct GridwiseGemm_wmma_cshuffle_v3_b_scale
: GridwiseGemm_wmma_cshuffle_v3_base<
ALayout,
BLayout,
CLayout,
DsLayout,
ELayout,
ADataType,
BDataType,
AccDataType,
CShuffleDataType,
CDataType,
DsDataType,
EDataType,
AElementwiseOperation,
BElementwiseOperation,
CElementwiseOperation,
CDEElementwiseOperation,
GemmSpec,
BlockSize,
MPerBlock,
@@ -110,8 +114,8 @@ struct GridwiseGemm_wmma_cshuffle_v3_b_scale
BBlockLdsExtraN,
CShuffleMRepeatPerShuffle,
CShuffleNRepeatPerShuffle,
CShuffleBlockTransferClusterLengths_MBlock_MPerBlock_NBlock_NPerBlock,
CShuffleBlockTransferScalarPerVector_NPerBlock,
CDEShuffleBlockTransferClusterLengths_MBlock_MPerBlock_NBlock_NPerBlock,
CDEShuffleBlockTransferScalarPerVectors,
BlkGemmPipeSched,
BlkGemmPipelineVer,
ComputeTypeA,
@@ -124,15 +128,17 @@ struct GridwiseGemm_wmma_cshuffle_v3_b_scale
using Base = GridwiseGemm_wmma_cshuffle_v3_base<
ALayout,
BLayout,
CLayout,
DsLayout,
ELayout,
ADataType,
BDataType,
AccDataType,
CShuffleDataType,
CDataType,
DsDataType,
EDataType,
AElementwiseOperation,
BElementwiseOperation,
CElementwiseOperation,
CDEElementwiseOperation,
GemmSpec,
BlockSize,
MPerBlock,
@@ -162,8 +168,8 @@ struct GridwiseGemm_wmma_cshuffle_v3_b_scale
BBlockLdsExtraN,
CShuffleMRepeatPerShuffle,
CShuffleNRepeatPerShuffle,
CShuffleBlockTransferClusterLengths_MBlock_MPerBlock_NBlock_NPerBlock,
CShuffleBlockTransferScalarPerVector_NPerBlock,
CDEShuffleBlockTransferClusterLengths_MBlock_MPerBlock_NBlock_NPerBlock,
CDEShuffleBlockTransferScalarPerVectors,
BlkGemmPipeSched,
BlkGemmPipelineVer,
ComputeTypeA,
@@ -198,17 +204,22 @@ struct GridwiseGemm_wmma_cshuffle_v3_b_scale
using Base::CalculateNPadded;
using Base::MakeAGridDescriptor_AK0_M_AK1;
using Base::MakeBGridDescriptor_BK0_N_BK1;
using Base::MakeCGridDescriptor_M_N;
using Base::MakeDEGridDescriptor_M_N;
using Base::MakeDsGridDescriptor_M_N;
using Base::MakeDsGridDescriptor_MBlock_MPerBlock_NBlock_NPerBlock;
using Base::GetCShuffleBlockDescriptor_MShRepeat_MPerShRepeat_NShRepeat_NPerShRepeat;
using Base::MakeCGridDescriptor_MBlock_MPerBlock_NBlock_NPerBlock;
using Base::MakeDEGridDescriptor_MBlock_MPerBlock_NBlock_NPerBlock;
using ThisThreadBlock = ThisThreadBlock<BlockSize>;
using Base::GetABlockDescriptor_AK0PerBlock_MPerBlock_AK1;
using Base::GetBBlockDescriptor_BK0PerBlock_NPerBlock_BK1;
using Base::NumDTensor;
using typename Base::DsGridPointer;
struct Problem
{
__host__ Problem(index_t M_,
@@ -216,7 +227,8 @@ struct GridwiseGemm_wmma_cshuffle_v3_b_scale
index_t K_,
index_t StrideA_,
index_t StrideB_,
index_t StrideC_,
std::array<index_t, NumDTensor> StrideDs_,
index_t StrideE_,
index_t StrideScaleB_,
index_t KBatch_)
: M{M_},
@@ -224,7 +236,8 @@ struct GridwiseGemm_wmma_cshuffle_v3_b_scale
K{K_},
StrideA{StrideA_},
StrideB{StrideB_},
StrideC{StrideC_},
StrideDs{StrideDs_},
StrideE{StrideE_},
StrideScaleB{StrideScaleB_},
KBatch{KBatch_},
MPadded{CalculateMPadded(M_)},
@@ -241,11 +254,20 @@ struct GridwiseGemm_wmma_cshuffle_v3_b_scale
__host__ void Print() const
{
std::cout << "problem {" << "M:" << M << ", " << "N:" << N << ", " << "K:" << K << ", "
<< "SA:" << StrideA << ", " << "SB:" << StrideB << ", " << "SC:" << StrideC
<< ", " << "SScaleB:" << StrideScaleB << ", " << "MP:" << MPadded << ", "
<< "NP:" << NPadded << ", " << "KRead:" << KRead << ", " << "KP:" << KPadded
<< ", " << "AK0:" << AK0 << ", " << "BK0:" << BK0 << ", "
<< "MBlock: " << MBlock << ", " << "NBlock: " << NBlock << "}" << std::endl;
<< "SA:" << StrideA << ", " << "SB:" << StrideB << ", ";
if constexpr(NumDTensor > 0)
{
std::cout << "SDs: { ";
static_for<0, NumDTensor, 1>{}([&](auto i) {
std::cout << StrideDs[i] << (i.value < NumDTensor - 1 ? ", " : "");
});
std::cout << " }, ";
}
std::cout << "SE:" << StrideE << ", " << "SScaleB:" << StrideScaleB << ", "
<< "MP:" << MPadded << ", " << "NP:" << NPadded << ", " << "KRead:" << KRead
<< ", " << "KP:" << KPadded << ", " << "AK0:" << AK0 << ", " << "BK0:" << BK0
<< ", " << "MBlock: " << MBlock << ", " << "NBlock: " << NBlock << "}"
<< std::endl;
}
index_t M;
@@ -253,7 +275,8 @@ struct GridwiseGemm_wmma_cshuffle_v3_b_scale
index_t K;
index_t StrideA;
index_t StrideB;
index_t StrideC;
std::array<index_t, NumDTensor> StrideDs;
index_t StrideE;
index_t StrideScaleB;
index_t KBatch;
index_t MPadded;
@@ -271,30 +294,38 @@ struct GridwiseGemm_wmma_cshuffle_v3_b_scale
{
__host__ Argument(const ADataType* p_a_grid_,
const BDataType* p_b_grid_,
CDataType* p_c_grid_,
std::array<const void*, NumDTensor> p_ds_grid_,
EDataType* p_e_grid_,
index_t M_,
index_t N_,
index_t K_,
index_t StrideA_,
index_t StrideB_,
index_t StrideC_,
std::array<index_t, NumDTensor> StrideDs_,
index_t StrideE_,
index_t StrideScaleB_,
const BScaleType* p_b_scale_grid_,
index_t k_batch_,
AElementwiseOperation a_element_op_,
BElementwiseOperation b_element_op_,
CElementwiseOperation c_element_op_,
CDEElementwiseOperation cde_element_op_,
bool is_reduce_ = false)
: Problem{M_, N_, K_, StrideA_, StrideB_, StrideC_, StrideScaleB_, k_batch_},
: Problem{M_, N_, K_, StrideA_, StrideB_, StrideDs_, StrideE_, StrideScaleB_, k_batch_},
p_a_grid{p_a_grid_},
p_b_grid{p_b_grid_},
p_c_grid{p_c_grid_},
p_ds_grid{},
p_e_grid{p_e_grid_},
p_b_scale_grid{p_b_scale_grid_},
a_element_op{a_element_op_},
b_element_op{b_element_op_},
c_element_op{c_element_op_},
cde_element_op{cde_element_op_},
is_reduce(is_reduce_)
{
static_for<0, NumDTensor, 1>{}([&](auto i) {
using DDataType = remove_cvref_t<tuple_element_t<i.value, DsDataType>>;
p_ds_grid(i) = static_cast<const DDataType*>(p_ds_grid_[i]);
});
}
__host__ __device__ inline bool IsReduceAdd() const
@@ -309,57 +340,58 @@ struct GridwiseGemm_wmma_cshuffle_v3_b_scale
const ADataType* p_a_grid;
const BDataType* p_b_grid;
CDataType* p_c_grid;
DsGridPointer p_ds_grid;
EDataType* p_e_grid;
const BScaleType* p_b_scale_grid;
const AElementwiseOperation a_element_op;
const BElementwiseOperation b_element_op;
const CElementwiseOperation c_element_op;
const CDEElementwiseOperation cde_element_op;
bool is_reduce;
};
struct SplitKBatchOffset
{
__device__ SplitKBatchOffset(Argument& karg)
__device__ SplitKBatchOffset(Argument& karg, index_t k_id)
{
if constexpr(is_same_v<tensor_layout::gemm::RowMajor, ALayout>)
{
a_k_split_offset = blockIdx.z * karg.KRead / APackedSize;
a_k_split_offset = k_id * karg.KRead / APackedSize;
}
else if constexpr(is_same_v<tensor_layout::gemm::ColumnMajor, ALayout>)
{
a_k_split_offset = blockIdx.z * karg.KRead * karg.StrideA;
a_k_split_offset = k_id * karg.KRead * karg.StrideA;
}
if constexpr(is_same_v<tensor_layout::gemm::RowMajor, BLayout>)
{
b_k_split_offset = blockIdx.z * karg.KRead * karg.StrideB;
b_k_split_offset = k_id * karg.KRead * karg.StrideB;
}
else if constexpr(is_same_v<tensor_layout::gemm::ColumnMajor, BLayout>)
{
if constexpr(!PermuteB)
{
b_k_split_offset = blockIdx.z * karg.KRead / BPackedSize;
b_k_split_offset = k_id * karg.KRead / BPackedSize;
}
else
{
const int k0_offset = karg.KRead * karg.N;
b_k_split_offset = blockIdx.z * k0_offset / BPackedSize;
b_k_split_offset = k_id * k0_offset / BPackedSize;
}
}
// Calculate B scale offset
if constexpr(is_same_v<tensor_layout::gemm::RowMajor, BLayout>)
{
scale_k_split_offset = blockIdx.z * (karg.KRead / ScaleBlockK) * karg.StrideB;
scale_k_split_offset = k_id * (karg.KRead / ScaleBlockK) * karg.StrideB;
}
else if constexpr(is_same_v<tensor_layout::gemm::ColumnMajor, BLayout>)
{
scale_k_split_offset = blockIdx.z * (karg.KRead / ScaleBlockK);
scale_k_split_offset = k_id * (karg.KRead / ScaleBlockK);
}
if(blockIdx.z < static_cast<uint32_t>(karg.KBatch - 1))
if(k_id < karg.KBatch - 1)
{
karg.K = karg.KRead;
}
@@ -370,7 +402,7 @@ struct GridwiseGemm_wmma_cshuffle_v3_b_scale
if(karg.IsReduceAdd())
{
c_reduce_offset = blockIdx.z * karg.M * karg.N;
c_reduce_offset = k_id * karg.M * karg.N;
}
else
{
@@ -454,24 +486,33 @@ struct GridwiseGemm_wmma_cshuffle_v3_b_scale
}
template <bool HasMainKBlockLoop,
InMemoryDataOperationEnum CGlobalMemoryDataOperation,
TailNumber TailNum = TailNumber::Odd>
InMemoryDataOperationEnum EGlobalMemoryDataOperation,
TailNumber TailNum>
__device__ static void Run(const ADataType* p_a_grid,
const BDataType* p_b_grid,
CDataType* p_c_grid,
DsGridPointer& p_ds_grid,
EDataType* p_e_grid,
const BScaleType* p_b_scale_grid,
void* p_shared,
const Problem& problem)
const Problem& problem,
AElementwiseOperation a_element_op,
BElementwiseOperation b_element_op,
CDEElementwiseOperation cde_element_op)
{
const auto a_grid_desc_ak0_m_ak1 = MakeAGridDescriptor_AK0_M_AK1(
problem.M, problem.MPadded, problem.K, problem.KPadded, problem.StrideA, problem.AK0);
const auto b_grid_desc_bk0_n_bk1 = MakeBGridDescriptor_BK0_N_BK1(
problem.K, problem.KPadded, problem.N, problem.NPadded, problem.StrideB, problem.BK0);
const auto c_grid_desc_m_n = MakeCGridDescriptor_M_N(
problem.M, problem.MPadded, problem.N, problem.NPadded, problem.StrideC);
const auto c_grid_desc_mblock_mperblock_nblock_nperblock =
MakeCGridDescriptor_MBlock_MPerBlock_NBlock_NPerBlock(
c_grid_desc_m_n, problem.MBlock, problem.NBlock);
const auto ds_grid_desc_m_n = MakeDsGridDescriptor_M_N(
problem.M, problem.MPadded, problem.N, problem.NPadded, problem.StrideDs);
const auto e_grid_desc_m_n = Base::template MakeDEGridDescriptor_M_N<ELayout>(
problem.M, problem.MPadded, problem.N, problem.NPadded, problem.StrideE);
const auto ds_grid_desc_mblock_mperblock_nblock_nperblock =
MakeDsGridDescriptor_MBlock_MPerBlock_NBlock_NPerBlock(
ds_grid_desc_m_n, problem.MBlock, problem.NBlock);
const auto e_grid_desc_mblock_mperblock_nblock_nperblock =
MakeDEGridDescriptor_MBlock_MPerBlock_NBlock_NPerBlock(
e_grid_desc_m_n, problem.MBlock, problem.NBlock);
// B Scale grid
const auto b_scale_grid_desc_bn_ak = make_naive_tensor_descriptor(
@@ -487,8 +528,8 @@ struct GridwiseGemm_wmma_cshuffle_v3_b_scale
if(!block_2_ctile_map.ValidCTileIndex(
block_work_idx,
make_tuple(c_grid_desc_mblock_mperblock_nblock_nperblock.GetLength(I0),
c_grid_desc_mblock_mperblock_nblock_nperblock.GetLength(I2))))
make_tuple(e_grid_desc_mblock_mperblock_nblock_nperblock.GetLength(I0),
e_grid_desc_mblock_mperblock_nblock_nperblock.GetLength(I2))))
{
return;
}
@@ -503,17 +544,23 @@ struct GridwiseGemm_wmma_cshuffle_v3_b_scale
Base::template Run<decltype(a_grid_desc_ak0_m_ak1),
decltype(b_grid_desc_bk0_n_bk1),
decltype(c_grid_desc_mblock_mperblock_nblock_nperblock),
decltype(ds_grid_desc_mblock_mperblock_nblock_nperblock),
decltype(e_grid_desc_mblock_mperblock_nblock_nperblock),
decltype(b_scale_struct),
HasMainKBlockLoop,
CGlobalMemoryDataOperation,
EGlobalMemoryDataOperation,
TailNum>(p_a_grid,
p_b_grid,
p_c_grid,
p_ds_grid,
p_e_grid,
p_shared,
a_grid_desc_ak0_m_ak1,
b_grid_desc_bk0_n_bk1,
c_grid_desc_mblock_mperblock_nblock_nperblock,
ds_grid_desc_mblock_mperblock_nblock_nperblock,
e_grid_desc_mblock_mperblock_nblock_nperblock,
a_element_op,
b_element_op,
cde_element_op,
block_m_id,
block_n_id,
num_k_block_per_scale,
@@ -523,18 +570,22 @@ struct GridwiseGemm_wmma_cshuffle_v3_b_scale
// NOTE: Wrapper function to have __global__ function in common
// between gemm_universal, b_scale, ab_scale, etc.
template <bool HasMainKBlockLoop,
InMemoryDataOperationEnum CGlobalMemoryDataOperation,
TailNumber TailNum = TailNumber::Odd>
InMemoryDataOperationEnum EGlobalMemoryDataOperation,
TailNumber TailNum>
__device__ static void
Run(void* p_shared, const SplitKBatchOffset& splitk_batch_offset, const Argument& karg)
Run(void* p_shared, const SplitKBatchOffset& splitk_batch_offset, Argument& karg)
{
Run<HasMainKBlockLoop, CGlobalMemoryDataOperation, TailNum>(
Run<HasMainKBlockLoop, EGlobalMemoryDataOperation, TailNum>(
karg.p_a_grid + splitk_batch_offset.a_k_split_offset,
karg.p_b_grid + splitk_batch_offset.b_k_split_offset,
karg.p_c_grid + splitk_batch_offset.c_reduce_offset,
karg.p_ds_grid, //; + splitk_batch_offset.c_reduce_offset,
karg.p_e_grid + splitk_batch_offset.c_reduce_offset,
karg.p_b_scale_grid + splitk_batch_offset.scale_k_split_offset,
p_shared,
karg);
karg,
karg.a_element_op,
karg.b_element_op,
karg.cde_element_op);
}
};

View File

@@ -11,7 +11,7 @@
#include "ck/tensor_operation/gpu/grid/block_to_ctile_map.hpp"
#include "ck/tensor_operation/gpu/block/blockwise_gemm_pipeline_wmma_selector.hpp"
#include "ck/tensor_operation/gpu/block/thread_group_tensor_slice_transfer_v4r1.hpp"
#include "ck/tensor_operation/gpu/block/thread_group_tensor_slice_transfer_v6r1.hpp"
#include "ck/tensor_operation/gpu/block/thread_group_tensor_slice_transfer_v7r3.hpp"
#include "ck/tensor_operation/gpu/thread/threadwise_tensor_slice_transfer.hpp"
#include "ck/tensor_operation/gpu/element/element_wise_operation.hpp"
@@ -19,7 +19,7 @@ namespace ck {
template <typename GridwiseGemm,
bool HasMainKBlockLoop,
InMemoryDataOperationEnum CGlobalMemoryDataOperation,
InMemoryDataOperationEnum EGlobalMemoryDataOperation,
index_t MinimumOccupancy = 1,
TailNumber TailNum = TailNumber::Full>
__global__ void
@@ -31,17 +31,17 @@ __launch_bounds__(CK_MAX_THREAD_PER_BLOCK, MinimumOccupancy)
#if(defined(__gfx11__) || defined(__gfx12__))
#if defined(__gfx11__)
// gfx11 does not support *_atomic_pk_add_f16/bf16 instructions
using c_data_type = remove_cvref_t<remove_pointer_t<decltype(karg.p_c_grid)>>;
if constexpr(!(CGlobalMemoryDataOperation == InMemoryDataOperationEnum::AtomicAdd &&
(std::is_same_v<c_data_type, ck::half_t> ||
std::is_same_v<c_data_type, ck::bhalf_t>)))
using e_data_type = remove_cvref_t<remove_pointer_t<decltype(karg.p_e_grid)>>;
if constexpr(!(EGlobalMemoryDataOperation == InMemoryDataOperationEnum::AtomicAdd &&
(std::is_same_v<e_data_type, ck::half_t> ||
std::is_same_v<e_data_type, ck::bhalf_t>)))
{
#endif
__shared__ char p_shared[GridwiseGemm::GetSharedMemoryNumberOfByte()];
auto splitk_batch_offset = typename GridwiseGemm::SplitKBatchOffset(karg);
auto splitk_batch_offset = typename GridwiseGemm::SplitKBatchOffset(karg, blockIdx.z);
GridwiseGemm::template Run<HasMainKBlockLoop, CGlobalMemoryDataOperation, TailNum>(
GridwiseGemm::template Run<HasMainKBlockLoop, EGlobalMemoryDataOperation, TailNum>(
p_shared, splitk_batch_offset, karg);
#if defined(__gfx11__)
@@ -54,15 +54,17 @@ __launch_bounds__(CK_MAX_THREAD_PER_BLOCK, MinimumOccupancy)
template <typename ALayout,
typename BLayout,
typename CLayout,
typename DsLayout,
typename ELayout,
typename ADataType,
typename BDataType,
typename AccDataType,
typename CShuffleDataType,
typename CDataType,
typename DsDataType,
typename EDataType,
typename AElementwiseOperation,
typename BElementwiseOperation,
typename CElementwiseOperation,
typename CDEElementwiseOperation,
tensor_operation::device::GemmSpecialization GemmSpec,
index_t BlockSize,
index_t MPerBlock,
@@ -92,8 +94,8 @@ template <typename ALayout,
index_t BBlockLdsExtraN,
index_t CShuffleMRepeatPerShuffle,
index_t CShuffleNRepeatPerShuffle,
typename CShuffleBlockTransferClusterLengths_MBlock_MPerBlock_NBlock_NPerBlock,
index_t CShuffleBlockTransferScalarPerVector_NPerBlock,
typename CDEShuffleBlockTransferClusterLengths_MBlock_MPerBlock_NBlock_NPerBlock,
typename CDEShuffleBlockTransferScalarPerVectors,
BlockGemmPipelineScheduler BlkGemmPipeSched,
BlockGemmPipelineVersion BlkGemmPipelineVer,
typename ComputeTypeA,
@@ -112,6 +114,9 @@ struct GridwiseGemm_wmma_cshuffle_v3_base
static constexpr auto I6 = Number<6>{};
static constexpr auto I7 = Number<7>{};
static constexpr auto EShuffleBlockTransferScalarPerVector =
CDEShuffleBlockTransferScalarPerVectors{}[I0];
// K1 should be Number<...>
static constexpr auto AK0Number = Number<KPerBlock / AK1Value>{};
static constexpr auto BK0Number = Number<KPerBlock / BK1Value>{};
@@ -430,17 +435,18 @@ struct GridwiseGemm_wmma_cshuffle_v3_base
return MakeWmmaTileDescriptor<NRepeat, NWaves, NPerWmma>(BBlockDesc_BK0_N_BK1{});
}
template <typename DELayout>
__host__ __device__ static auto
MakeCGridDescriptor_M_N(index_t M, index_t MPad, index_t N, index_t NPad, index_t StrideC)
MakeDEGridDescriptor_M_N(index_t M, index_t MPad, index_t N, index_t NPad, index_t StrideDE)
{
const auto c_grid_desc_mraw_nraw = [&]() {
if constexpr(is_same<tensor_layout::gemm::RowMajor, CLayout>::value)
if constexpr(is_same<tensor_layout::gemm::RowMajor, DELayout>::value)
{
return make_naive_tensor_descriptor(make_tuple(M, N), make_tuple(StrideC, I1));
return make_naive_tensor_descriptor(make_tuple(M, N), make_tuple(StrideDE, I1));
}
else if constexpr(is_same<tensor_layout::gemm::ColumnMajor, CLayout>::value)
else if constexpr(is_same<tensor_layout::gemm::ColumnMajor, DELayout>::value)
{
return make_naive_tensor_descriptor(make_tuple(M, N), make_tuple(I1, StrideC));
return make_naive_tensor_descriptor(make_tuple(M, N), make_tuple(I1, StrideDE));
}
}();
@@ -493,6 +499,44 @@ struct GridwiseGemm_wmma_cshuffle_v3_base
#endif
}
static constexpr index_t NumDTensor = DsDataType::Size();
static constexpr auto MakeDsGridPointer()
{
return generate_tuple(
[&](auto i) {
using DDataType = remove_cvref_t<tuple_element_t<i.value, DsDataType>>;
return static_cast<const DDataType*>(nullptr);
},
Number<NumDTensor>{});
}
using DsGridPointer = decltype(MakeDsGridPointer());
__host__ __device__ static auto MakeDsGridDescriptor_M_N(
index_t M, index_t MPad, index_t N, index_t NPad, std::array<index_t, NumDTensor> StrideDs)
{
return generate_tuple(
[&](auto i) {
using DLayout = remove_cvref_t<tuple_element_t<i.value, DsLayout>>;
return MakeDEGridDescriptor_M_N<DLayout>(M, MPad, N, NPad, StrideDs[i]);
},
Number<NumDTensor>{});
}
template <typename DsGridDesc>
__device__ static constexpr auto MakeDsGridDescriptor_MBlock_MPerBlock_NBlock_NPerBlock(
const DsGridDesc& ds_grid_desc_m_n, index_t MBlock, index_t NBlock)
{
return generate_tuple(
[&](auto i) {
return MakeDEGridDescriptor_MBlock_MPerBlock_NBlock_NPerBlock(
ds_grid_desc_m_n[i], MBlock, NBlock);
},
Number<NumDTensor>{});
}
__device__ static constexpr auto GetABlockDescriptor_AK0PerBlock_MPerBlock_AK1()
{
// A matrix in LDS memory, dst of blockwise copy
@@ -805,18 +849,18 @@ struct GridwiseGemm_wmma_cshuffle_v3_base
NRepeat,
KPack>())>;
template <typename CGridDesc>
__host__ __device__ static constexpr auto MakeCGridDescriptor_MBlock_MPerBlock_NBlock_NPerBlock(
const CGridDesc& c_grid_desc_m_n, index_t MBlock, index_t NBlock)
template <typename DEGridDesc>
__device__ static constexpr auto MakeDEGridDescriptor_MBlock_MPerBlock_NBlock_NPerBlock(
const DEGridDesc& de_grid_desc_m_n, index_t MBlock, index_t NBlock)
{
const auto c_grid_desc_mblock_mperblock_nblock_nperblock = transform_tensor_descriptor(
c_grid_desc_m_n,
const auto de_grid_desc_mblock_mperblock_nblock_nperblock = transform_tensor_descriptor(
de_grid_desc_m_n,
make_tuple(make_unmerge_transform(make_tuple(MBlock, Number<MPerBlock>{})),
make_unmerge_transform(make_tuple(NBlock, Number<NPerBlock>{}))),
make_tuple(Sequence<0>{}, Sequence<1>{}),
make_tuple(Sequence<0, 1>{}, Sequence<2, 3>{}));
return c_grid_desc_mblock_mperblock_nblock_nperblock;
return de_grid_desc_mblock_mperblock_nblock_nperblock;
}
// block_id to matrix tile idx (m0, n0) mapping are controlled by {M01, N01}
@@ -950,56 +994,51 @@ struct GridwiseGemm_wmma_cshuffle_v3_base
}
}
if constexpr(is_same<tensor_layout::gemm::RowMajor, CLayout>::value)
if constexpr(is_same<tensor_layout::gemm::RowMajor, ELayout>::value)
{
if(karg.N % CShuffleBlockTransferScalarPerVector_NPerBlock != 0)
if(karg.N % EShuffleBlockTransferScalarPerVector != 0)
{
if(ck::EnvIsEnabled(CK_ENV(CK_LOGGING)))
{
std::cout << "Arg N (" << karg.N
<< ") value is not a multiple of "
"CShuffleBlockTransferScalarPerVector_NPerBlock ("
<< CShuffleBlockTransferScalarPerVector_NPerBlock << " )! "
<< __FILE__ << ":" << __LINE__ << ", in function: " << __func__
<< std::endl;
"EShuffleBlockTransferScalarPerVector ("
<< EShuffleBlockTransferScalarPerVector << " )! " << __FILE__ << ":"
<< __LINE__ << ", in function: " << __func__ << std::endl;
}
return false;
}
}
else
{
if(karg.M % CShuffleBlockTransferScalarPerVector_NPerBlock != 0)
if(karg.M % EShuffleBlockTransferScalarPerVector != 0)
{
if(ck::EnvIsEnabled(CK_ENV(CK_LOGGING)))
{
std::cout << "Arg M (" << karg.M
<< ") value is not a multiple of "
"CShuffleBlockTransferScalarPerVector_NPerBlock ("
<< CShuffleBlockTransferScalarPerVector_NPerBlock << " )! "
<< __FILE__ << ":" << __LINE__ << ", in function: " << __func__
<< std::endl;
"EShuffleBlockTransferScalarPerVector ("
<< EShuffleBlockTransferScalarPerVector << " )! " << __FILE__ << ":"
<< __LINE__ << ", in function: " << __func__ << std::endl;
}
return false;
}
}
if constexpr(!(is_same<remove_cvref_t<CDataType>, half_t>::value ||
is_same<remove_cvref_t<CDataType>, float>::value ||
is_same<remove_cvref_t<CDataType>, bhalf_t>::value ||
is_same<remove_cvref_t<CDataType>, int32_t>::value))
if constexpr(!(is_same<remove_cvref_t<EDataType>, half_t>::value ||
is_same<remove_cvref_t<EDataType>, float>::value ||
is_same<remove_cvref_t<EDataType>, bhalf_t>::value ||
is_same<remove_cvref_t<EDataType>, int32_t>::value))
{
if(!karg.IsReduceAdd())
if(karg.IsAtomicAdd() && karg.KBatch > 1)
{
if(ck::EnvIsEnabled(CK_ENV(CK_LOGGING)))
{
std::cout << " KBatch: " << karg.KBatch << " > 1 is not supported yet"
<< __FILE__ << ":" << __LINE__ << ", in function: " << __func__
<< std::endl;
}
if(karg.KBatch > 1)
{
return false;
std::cout << " KBatch: " << karg.KBatch << " > 1 is not supported for this "
<< "destination type (EDataType) " << __FILE__ << ":" << __LINE__
<< ", in function: " << __func__ << std::endl;
}
return false;
}
}
@@ -1062,19 +1101,26 @@ struct GridwiseGemm_wmma_cshuffle_v3_base
template <typename AGridDesc_AK0_M_K1,
typename BGridDesc_BK0_N_K1,
typename CGridDesc_MBlock_MPerBlock_NBlock_NPerBlock,
typename DsGridDesc_MBlock_MPerBlock_NBlock_NPerBlock,
typename EGridDesc_MBlock_MPerBlock_NBlock_NPerBlock,
typename BScaleStruct,
bool HasMainKBlockLoop,
InMemoryDataOperationEnum CGlobalMemoryDataOperation,
InMemoryDataOperationEnum EGlobalMemoryDataOperation,
TailNumber TailNum = TailNumber::Odd>
__device__ static void Run(const ADataType* p_a_grid,
const BDataType* p_b_grid,
CDataType* p_c_grid,
DsGridPointer p_ds_grid,
EDataType* p_e_grid,
void* p_shared,
const AGridDesc_AK0_M_K1& a_grid_desc_ak0_m_ak1,
const BGridDesc_BK0_N_K1& b_grid_desc_bk0_n_bk1,
const CGridDesc_MBlock_MPerBlock_NBlock_NPerBlock&
c_grid_desc_mblock_mperblock_nblock_nperblock,
const DsGridDesc_MBlock_MPerBlock_NBlock_NPerBlock&
ds_grid_desc_mblock_mperblock_nblock_nperblock,
const EGridDesc_MBlock_MPerBlock_NBlock_NPerBlock&
e_grid_desc_mblock_mperblock_nblock_nperblock,
AElementwiseOperation a_element_op,
BElementwiseOperation b_element_op,
CDEElementwiseOperation cde_element_op,
const index_t& block_m_id,
const index_t& block_n_id,
const index_t& num_k_block_per_scale,
@@ -1084,12 +1130,15 @@ struct GridwiseGemm_wmma_cshuffle_v3_base
p_a_grid, a_grid_desc_ak0_m_ak1.GetElementSpaceSize());
const auto b_grid_buf = make_dynamic_buffer<AddressSpaceEnum::Global>(
p_b_grid, b_grid_desc_bk0_n_bk1.GetElementSpaceSize());
auto c_grid_buf = make_dynamic_buffer<AddressSpaceEnum::Global>(
p_c_grid, c_grid_desc_mblock_mperblock_nblock_nperblock.GetElementSpaceSize());
const AElementwiseOperation a_element_op{};
const BElementwiseOperation b_element_op{};
const CElementwiseOperation c_element_op{};
const auto ds_grid_buf = generate_tuple(
[&](auto i) {
return make_dynamic_buffer<AddressSpaceEnum::Global>(
p_ds_grid[i],
ds_grid_desc_mblock_mperblock_nblock_nperblock[i].GetElementSpaceSize());
},
Number<NumDTensor>{});
auto e_grid_buf = make_dynamic_buffer<AddressSpaceEnum::Global>(
p_e_grid, e_grid_desc_mblock_mperblock_nblock_nperblock.GetElementSpaceSize());
// HACK: this force m/n_block_data_idx_on_grid into SGPR
const index_t m_block_data_idx_on_grid =
@@ -1330,31 +1379,58 @@ struct GridwiseGemm_wmma_cshuffle_v3_base
m_thread_data_on_block_idx[I3]),
ck::tensor_operation::element_wise::PassThrough{}};
// shuffle: blockwise copy C from LDS to global
auto c_shuffle_block_copy_lds_to_global = ThreadGroupTensorSliceTransfer_v6r1<
ThisThreadBlock, // ThreadGroup
CElementwiseOperation, // ElementwiseOperation,
CGlobalMemoryDataOperation, // DstInMemOp,
// tuple of reference to C/Ds tensor descriptors
const auto c_ds_desc_refs = concat_tuple_of_reference(
tie(c_shuffle_block_desc_mshrepeat_mpershrepeat_nshrepeat_npershrepeat),
generate_tie([&](auto i) -> const auto& // return type should be reference
{ return ds_grid_desc_mblock_mperblock_nblock_nperblock[i]; },
Number<NumDTensor>{}));
// tuple of reference to C/Ds tensor buffers
const auto c_ds_buf_refs = concat_tuple_of_reference(
tie(c_shuffle_block_buf),
generate_tie([&](auto i) -> const auto& // return type should be reference
{ return ds_grid_buf[i]; },
Number<NumDTensor>{}));
// tuple of starting index of C/Ds blockwise copy
const auto idx_c_ds_block_begin = container_concat(
make_tuple(make_multi_index(0, 0, 0, 0)),
generate_tuple([&](auto) { return make_multi_index(block_m_id, 0, block_n_id, 0); },
Number<NumDTensor>{}));
// blockwise copy which loads C from LDS, D from global, applies elementwise
// operation and stores result E to global
auto cde_shuffle_block_copy_lds_and_global = ThreadGroupTensorSliceTransfer_v7r3<
ThisThreadBlock, // ThreadGroup
decltype(container_concat(make_tuple(CShuffleDataType{}), DsDataType{})),
Tuple<EDataType>,
decltype(c_ds_desc_refs),
decltype(tie(e_grid_desc_mblock_mperblock_nblock_nperblock)),
CDEElementwiseOperation, // ElementwiseOperation,
Sequence<static_cast<index_t>(EGlobalMemoryDataOperation)>, // DstInMemOps,
Sequence<1,
CShuffleMRepeatPerShuffle * MWave * MPerWmma,
1,
CShuffleNRepeatPerShuffle * NWave * NPerWmma>, // BlockSliceLengths,
CShuffleBlockTransferClusterLengths_MBlock_MPerBlock_NBlock_NPerBlock,
Sequence<0, 1, 2, 3>, // typename ThreadClusterArrangeOrder,
CShuffleDataType, // typename SrcData,
CDataType, // typename DstData,
decltype(c_shuffle_block_desc_mshrepeat_mpershrepeat_nshrepeat_npershrepeat),
decltype(c_grid_desc_mblock_mperblock_nblock_nperblock),
Sequence<0, 1, 2, 3>, // typename DimAccessOrder,
3, // index_t VectorDim,
CShuffleBlockTransferScalarPerVector_NPerBlock, // index_t ScalarPerVector,
true, // bool ThreadTransferSrcResetCoordinateAfterRun,
false> // bool ThreadTransferDstResetCoordinateAfterRun>
{c_shuffle_block_desc_mshrepeat_mpershrepeat_nshrepeat_npershrepeat,
make_multi_index(0, 0, 0, 0),
c_grid_desc_mblock_mperblock_nblock_nperblock,
make_multi_index(block_m_id, 0, block_n_id, 0),
c_element_op};
CDEShuffleBlockTransferClusterLengths_MBlock_MPerBlock_NBlock_NPerBlock,
Sequence<0, 1, 2, 3>, // ThreadClusterArrangeOrder,
Sequence<0, 1, 2, 3>, // SrcDimAccessOrder,
Sequence<0, 1, 2, 3>, // DstDimAccessOrder,
3, // SrcVectorDim,
3, // DstVectorDim,
CDEShuffleBlockTransferScalarPerVectors, // SrcScalarPerVectors
EShuffleBlockTransferScalarPerVector, // DstScalarPerVector
sequence_merge_t<
Sequence<true>,
uniform_sequence_gen_t<NumDTensor,
false>>, // ThreadTransferSrcResetCoordinateAfterRunFlags
Sequence<false>> // ThreadTransferDstResetCoordinateAfterRunFlags
{c_ds_desc_refs,
idx_c_ds_block_begin,
tie(e_grid_desc_mblock_mperblock_nblock_nperblock),
make_tuple(make_multi_index(block_m_id, 0, block_n_id, 0)),
cde_element_op};
// space filling curve for local reg & global memory
// space filling curve for threadwise C in VGPR
@@ -1370,7 +1446,7 @@ struct GridwiseGemm_wmma_cshuffle_v3_base
MAccVgprs>>{};
// space filling curve for shuffled blockwise C in global mem
constexpr auto sfc_c_global =
constexpr auto sfc_cde_global =
SpaceFillingCurve<Sequence<1, MPerBlock, 1, NPerBlock>,
Sequence<0, 2, 1, 3>,
Sequence<1,
@@ -1380,7 +1456,7 @@ struct GridwiseGemm_wmma_cshuffle_v3_base
constexpr index_t num_access = sfc_c_vgpr.GetNumOfAccess();
static_assert(num_access == sfc_c_global.GetNumOfAccess(), "wrong!");
static_assert(num_access == sfc_cde_global.GetNumOfAccess(), "wrong!");
static_for<0, num_access, 1>{}([&](auto access_id) {
// make sure it's safe to write to LDS
@@ -1397,20 +1473,26 @@ struct GridwiseGemm_wmma_cshuffle_v3_base
// make sure it's safe to read from LDS
block_sync_lds();
// each block copy its data from LDS to global
c_shuffle_block_copy_lds_to_global.Run(
c_shuffle_block_desc_mshrepeat_mpershrepeat_nshrepeat_npershrepeat,
c_shuffle_block_buf,
c_grid_desc_mblock_mperblock_nblock_nperblock,
c_grid_buf);
// each block loads its C data from LDS, D from global, applies elementwise
// operation and stores result E to global
cde_shuffle_block_copy_lds_and_global.Run(
c_ds_desc_refs,
c_ds_buf_refs,
tie(e_grid_desc_mblock_mperblock_nblock_nperblock),
tie(e_grid_buf));
if constexpr(access_id < num_access - 1)
{
constexpr auto c_global_step = sfc_c_global.GetForwardStep(access_id);
constexpr auto cde_global_step = sfc_cde_global.GetForwardStep(access_id);
// move on Ds
static_for<0, NumDTensor, 1>{}([&](auto i) {
cde_shuffle_block_copy_lds_and_global.MoveSrcSliceWindow(
c_ds_desc_refs, i + I1, cde_global_step);
});
// move on C
c_shuffle_block_copy_lds_to_global.MoveDstSliceWindow(
c_grid_desc_mblock_mperblock_nblock_nperblock, c_global_step);
// move on E
cde_shuffle_block_copy_lds_and_global.MoveDstSliceWindow(
tie(e_grid_desc_mblock_mperblock_nblock_nperblock), cde_global_step);
}
});
}

View File

@@ -165,6 +165,9 @@ struct ThreadwiseTensorSliceTransfer_v7r3
oob_val = oob_val & is_src_valid;
// TODO: With column-major matrices this step restricts the transferred tensor slice
// to just one element, which consequently prevents using atomic operations if the
// matrix data type is on 16 bits.
if constexpr(SrcScalarPerVectors{}[i] == 1)
{
auto data_types = SrcDatas{};

View File

@@ -270,8 +270,8 @@ struct wmma_type<WmmaInstr::wmma_i32_16x16x16_iu8,
class FloatA,
class FloatB,
class FloatC,
bool neg_a = false,
bool neg_b = false,
bool neg_a = true,
bool neg_b = true,
bool clamp = false>
__device__ void run(const FloatA& a, const FloatB& b, FloatC& reg_c) const
{
@@ -390,8 +390,8 @@ struct wmma_type<WmmaInstr::wmma_i32_16x16x16_iu8_gfx12,
class FloatA,
class FloatB,
class FloatC,
bool neg_a = false,
bool neg_b = false,
bool neg_a = true,
bool neg_b = true,
bool clamp = false>
__device__ void run(const FloatA& a, const FloatB& b, FloatC& reg_c) const
{
@@ -793,6 +793,9 @@ struct WmmaGemm
"base type couple must be (half, float), (bhalf, float), (half, half), (bhalf, bhalf), "
"((f8 or bf8, f8 or bf8), float), (int8, int32) or (int4, int32)!");
static_for<0, KPack / wmma_instr.k_per_wmma, 1>{}([&](auto k) {
// Integer wmma operators need extra input flags to indicate if the input is signed or
// unsigned. At the moment CK supports only signed integer inputs, so these flags are
// hardcoded.
if constexpr(!TransposeC)
{
wmma_instr.template run<MPerWmma, NPerWmma>(p_a_wave[k], p_b_wave[k], p_c_thread);

View File

@@ -162,9 +162,15 @@ CK_TILE_HOST_DEVICE static void print(const tensor_descriptor<Transforms,
{
printf("tensor_descriptor{\n");
// first print the tensor adaptor part of the descriptor using the base class print
print(static_cast<const typename decltype(descriptor)::Base&>(descriptor));
printf("element_space_size_: %ld,\n",
static_cast<long>(descriptor.get_element_space_size().value));
using Base = typename tensor_descriptor<Transforms,
LowerDimensionHiddenIdss,
UpperDimensionHiddenIdss,
TopDimensionHiddenIds,
ElementSpaceSize,
GuaranteedVectorLengths,
GuaranteedVectorStrides>::Base;
print(static_cast<const Base&>(descriptor));
printf("element_space_size_: %ld,\n", static_cast<long>(descriptor.get_element_space_size()));
printf("guaranteed_vector_lengths: ");
print(GuaranteedVectorLengths{});
printf(",\nguaranteed_vector_strides: ");

View File

@@ -91,7 +91,7 @@ struct Default2DEpilogue
CK_TILE_DEVICE auto operator()(ODramWindowTmp& o_dram_window_tmp,
const OAccTile& o_acc_tile,
const DsDramWindows& ds_dram_windows,
void* = nullptr)
void* = nullptr) const
{
const auto storeOrUpdateTile = [&](const auto& o_tile) {
// TODO: this is ugly

View File

@@ -103,27 +103,41 @@ struct BlockFmhaBwdDQDKDVPipelineTrLoadKRKTRVR
const auto do_lds_ptr0 = reinterpret_cast<OGradDataType*>(smem_ptr_);
const auto do_lds_ptr1 = reinterpret_cast<OGradDataType*>(
smem_ptr_ + Policy::template GetSmemSizeOGrad<Problem>());
const auto q_lds_ptr0 = reinterpret_cast<QDataType*>( //
const auto q_lds_ptr0 = reinterpret_cast<QDataType*>( //
smem_ptr_ + Policy::template GetSmemSizeOGrad<Problem>() +
Policy::template GetSmemSizeOGrad<Problem>());
const auto q_lds_ptr1 = reinterpret_cast<QDataType*>( //
const auto q_lds_ptr1 = reinterpret_cast<QDataType*>( //
smem_ptr_ + Policy::template GetSmemSizeOGrad<Problem>() +
Policy::template GetSmemSizeOGrad<Problem>() +
Policy::template GetSmemSizeQ<Problem>());
const auto lse_lds_ptr = reinterpret_cast<LSEDataType*>(
const auto lse_lds_ptr0 = reinterpret_cast<LSEDataType*>(
smem_ptr_ + Policy::template GetSmemSizeOGrad<Problem>() +
Policy::template GetSmemSizeOGrad<Problem>() +
Policy::template GetSmemSizeQ<Problem>() + Policy::template GetSmemSizeQ<Problem>());
const auto d_lds_ptr = reinterpret_cast<DDataType*>(
const auto lse_lds_ptr1 = reinterpret_cast<LSEDataType*>(
smem_ptr_ + Policy::template GetSmemSizeOGrad<Problem>() +
Policy::template GetSmemSizeOGrad<Problem>() +
Policy::template GetSmemSizeQ<Problem>() + Policy::template GetSmemSizeQ<Problem>() +
Policy::template GetSmemSizeLSE<Problem>());
const auto d_lds_ptr0 = reinterpret_cast<DDataType*>(
smem_ptr_ + Policy::template GetSmemSizeOGrad<Problem>() +
Policy::template GetSmemSizeOGrad<Problem>() +
Policy::template GetSmemSizeQ<Problem>() + Policy::template GetSmemSizeQ<Problem>() +
Policy::template GetSmemSizeLSE<Problem>() +
Policy::template GetSmemSizeLSE<Problem>());
const auto d_lds_ptr1 = reinterpret_cast<DDataType*>(
smem_ptr_ + Policy::template GetSmemSizeOGrad<Problem>() +
Policy::template GetSmemSizeOGrad<Problem>() +
Policy::template GetSmemSizeQ<Problem>() + Policy::template GetSmemSizeQ<Problem>() +
Policy::template GetSmemSizeLSE<Problem>() +
Policy::template GetSmemSizeLSE<Problem>() + Policy::template GetSmemSizeD<Problem>());
const auto ds_lds_ptr = reinterpret_cast<GemmDataType*>(
smem_ptr_ + Policy::template GetSmemSizeOGrad<Problem>() +
Policy::template GetSmemSizeOGrad<Problem>() +
Policy::template GetSmemSizeQ<Problem>() + Policy::template GetSmemSizeQ<Problem>() +
Policy::template GetSmemSizeLSE<Problem>() + Policy::template GetSmemSizeD<Problem>());
Policy::template GetSmemSizeLSE<Problem>() +
Policy::template GetSmemSizeLSE<Problem>() + Policy::template GetSmemSizeD<Problem>() +
Policy::template GetSmemSizeD<Problem>());
const auto bias_lds_ptr = reinterpret_cast<BiasDataType*>(ds_lds_ptr);
return run(k_lds_ptr,
v_lds_ptr,
@@ -131,8 +145,10 @@ struct BlockFmhaBwdDQDKDVPipelineTrLoadKRKTRVR
do_lds_ptr1,
q_lds_ptr0,
q_lds_ptr1,
lse_lds_ptr,
d_lds_ptr,
lse_lds_ptr0,
lse_lds_ptr1,
d_lds_ptr0,
d_lds_ptr1,
ds_lds_ptr,
bias_lds_ptr,
std::forward<Ts>(args)...);
@@ -156,8 +172,10 @@ struct BlockFmhaBwdDQDKDVPipelineTrLoadKRKTRVR
OGradDataType* __restrict__ do_lds_ptr1,
QDataType* __restrict__ q_lds_ptr0,
QDataType* __restrict__ q_lds_ptr1,
LSEDataType* __restrict__ lse_lds_ptr,
DDataType* __restrict__ d_lds_ptr,
LSEDataType* __restrict__ lse_lds_ptr0,
LSEDataType* __restrict__ lse_lds_ptr1,
DDataType* __restrict__ d_lds_ptr0,
DDataType* __restrict__ d_lds_ptr1,
GemmDataType* __restrict__ ds_lds_ptr,
BiasDataType* __restrict__ bias_lds_ptr,
const QDramBlockWindowTmp& q_dram_block_window_tmp,
@@ -389,38 +407,38 @@ struct BlockFmhaBwdDQDKDVPipelineTrLoadKRKTRVR
"BiasDataType and BiasGradDataType should be the same!");
// LSE: HBM -> LDS ->Reg
auto lse_dram_window = make_tile_window(
lse_dram_block_window_tmp.get_bottom_tensor_view(),
lse_dram_block_window_tmp.get_window_lengths(),
{seqlen_q_start},
Policy::template MakeLSEDDramTileDistribution<Problem, decltype(gemm_0)>());
auto lse_dram_window =
make_tile_window(lse_dram_block_window_tmp.get_bottom_tensor_view(),
lse_dram_block_window_tmp.get_window_lengths(),
{seqlen_q_start},
Policy::template MakeLSEDDramTileDistribution<Problem>());
auto lse_lds = make_tensor_view<address_space_enum::lds>(
lse_lds_ptr, Policy::template MakeLSEDLdsWriteBlockDescriptor<Problem>());
lse_lds_ptr0, Policy::template MakeLSEDLdsWriteBlockDescriptor<Problem>());
auto lse_lds_write_window = make_tile_window(lse_lds, make_tuple(number<kM0>{}), {0});
auto lse_lds_read_window = make_tile_window(
lse_lds,
make_tuple(number<kM0>{}),
{0},
Policy::template MakeLSEDLdsReadBlockDescriptor<Problem, decltype(gemm_0)>());
auto lse_lds_read_window =
make_tile_window(lse_lds,
make_tuple(number<kM0>{}),
{0},
Policy::template MakeLSEDLdsReadBlockDescriptor<Problem>());
// D: HBM ->Reg
auto d_dram_window = make_tile_window(
d_dram_block_window_tmp.get_bottom_tensor_view(),
d_dram_block_window_tmp.get_window_lengths(),
{seqlen_q_start},
Policy::template MakeLSEDDramTileDistribution<Problem, decltype(gemm_0)>());
auto d_dram_window =
make_tile_window(d_dram_block_window_tmp.get_bottom_tensor_view(),
d_dram_block_window_tmp.get_window_lengths(),
{seqlen_q_start},
Policy::template MakeLSEDDramTileDistribution<Problem>());
auto d_lds = make_tensor_view<address_space_enum::lds>(
d_lds_ptr, Policy::template MakeLSEDLdsWriteBlockDescriptor<Problem>());
d_lds_ptr0, Policy::template MakeLSEDLdsWriteBlockDescriptor<Problem>());
auto d_lds_write_window = make_tile_window(d_lds, make_tuple(number<kM0>{}), {0});
auto d_lds_read_window = make_tile_window(
d_lds,
make_tuple(number<kM0>{}),
{0},
Policy::template MakeLSEDLdsReadBlockDescriptor<Problem, decltype(gemm_0)>());
auto d_lds_read_window =
make_tile_window(d_lds,
make_tuple(number<kM0>{}),
{0},
Policy::template MakeLSEDLdsReadBlockDescriptor<Problem>());
// RandVal: HBM ->Reg
auto randval_dram_window = dropout.template MakeRandvalDramWindow<decltype(gemm_0), false>(
@@ -471,27 +489,31 @@ struct BlockFmhaBwdDQDKDVPipelineTrLoadKRKTRVR
decltype(gemm_2.MakeCBlockTile()) dp_acc, ds;
decltype(gemm_4.MakeCBlockTile()) dq_acc;
decltype(load_tile(lse_dram_window)) lse_block_tile;
decltype(load_tile(d_dram_window)) d_block_tile;
index_t i_total_bodys = 0;
auto main_body_impl = [&](auto is_prologue_,
auto is_epilogue_,
QDataType* const __restrict__ q_lds_ptr_curr,
QDataType* const __restrict__ q_lds_ptr_next,
OGradDataType* const __restrict__ do_lds_ptr_curr,
OGradDataType* const __restrict__ do_lds_ptr_next) mutable {
OGradDataType* const __restrict__ do_lds_ptr_next,
LSEDataType* const __restrict__ lse_lds_ptr_curr,
LSEDataType* const __restrict__ lse_lds_ptr_next,
DDataType* const __restrict__ d_lds_ptr_curr,
DDataType* const __restrict__ d_lds_ptr_next
) mutable {
constexpr bool is_prologue = is_prologue_.value;
constexpr bool is_epilogue = is_epilogue_.value;
static_assert(is_prologue || is_epilogue, "is_prologue or is_epilogue should be true");
constexpr bool is_main_body = is_prologue && is_epilogue;
if constexpr(is_prologue)
{
lse_block_tile = load_tile(lse_dram_window);
lse_lds_write_window.set_bottom_tensor_view_data_ptr(lse_lds_ptr_next);
async_load_tile(lse_lds_write_window, lse_dram_window);
move_tile_window(lse_dram_window, {kM0});
d_block_tile = load_tile(d_dram_window);
d_lds_write_window.set_bottom_tensor_view_data_ptr(d_lds_ptr_next);
async_load_tile(d_lds_write_window, d_dram_window);
move_tile_window(d_dram_window, {kM0});
q_lds_write_window.set_bottom_tensor_view_data_ptr(q_lds_ptr_next);
@@ -510,6 +532,13 @@ struct BlockFmhaBwdDQDKDVPipelineTrLoadKRKTRVR
dot_lds_read_window.set_bottom_tensor_view_data_ptr(do_lds_ptr_curr);
dot_reg_tensor = load_tile_transpose(dot_lds_read_window);
}
if constexpr(is_epilogue)
{
lse_lds_read_window.set_bottom_tensor_view_data_ptr(lse_lds_ptr_curr);
lse = load_tile(lse_lds_read_window);
d_lds_read_window.set_bottom_tensor_view_data_ptr(d_lds_ptr_curr);
d = load_tile(d_lds_read_window);
}
if constexpr(is_main_body)
Policy::template HotLoopScheduler<Problem>::SchedulerGemm0();
__builtin_amdgcn_sched_barrier(0);
@@ -617,11 +646,6 @@ struct BlockFmhaBwdDQDKDVPipelineTrLoadKRKTRVR
if constexpr(is_main_body)
Policy::template HotLoopScheduler<Problem>::SchedulerGemm12();
__builtin_amdgcn_sched_barrier(0);
if constexpr(is_prologue)
{
store_tile(lse_lds_write_window, lse_block_tile);
store_tile(d_lds_write_window, d_block_tile);
}
if constexpr(is_epilogue)
{
// STAGE 5, P^T(PGrad^T - D)
@@ -676,13 +700,12 @@ struct BlockFmhaBwdDQDKDVPipelineTrLoadKRKTRVR
store_tile(ds_lds_window, ds_gemm);
}
__builtin_amdgcn_s_waitcnt(3952);
s_waitcnt</*vmcnt=*/0>();
block_sync_lds();
if constexpr(is_prologue)
{
q_lds_read_window.set_bottom_tensor_view_data_ptr(q_lds_ptr_next);
q_reg_tensor = load_tile(q_lds_read_window);
lse = load_tile(lse_lds_read_window);
}
if constexpr(is_epilogue)
{
@@ -720,7 +743,6 @@ struct BlockFmhaBwdDQDKDVPipelineTrLoadKRKTRVR
{
do_lds_read_window.set_bottom_tensor_view_data_ptr(do_lds_ptr_next);
do_reg_tensor = load_tile(do_lds_read_window);
d = load_tile(d_lds_read_window);
}
if constexpr(is_main_body)
Policy::template HotLoopScheduler<Problem>::SchedulerGemm4();
@@ -749,17 +771,25 @@ struct BlockFmhaBwdDQDKDVPipelineTrLoadKRKTRVR
};
auto main_body = [&](auto is_prologue_, auto is_epilogue_) mutable {
const bool is_even = (i_total_bodys % 2 == 0);
const auto q_lds_ptr_curr = is_even ? q_lds_ptr1 : q_lds_ptr0;
const auto q_lds_ptr_next = is_even ? q_lds_ptr0 : q_lds_ptr1;
const auto do_lds_ptr_curr = is_even ? do_lds_ptr1 : do_lds_ptr0;
const auto do_lds_ptr_next = is_even ? do_lds_ptr0 : do_lds_ptr1;
const bool is_even = (i_total_bodys % 2 == 0);
const auto q_lds_ptr_curr = is_even ? q_lds_ptr1 : q_lds_ptr0;
const auto q_lds_ptr_next = is_even ? q_lds_ptr0 : q_lds_ptr1;
const auto do_lds_ptr_curr = is_even ? do_lds_ptr1 : do_lds_ptr0;
const auto do_lds_ptr_next = is_even ? do_lds_ptr0 : do_lds_ptr1;
const auto lse_lds_ptr_curr = is_even ? lse_lds_ptr1 : lse_lds_ptr0;
const auto lse_lds_ptr_next = is_even ? lse_lds_ptr0 : lse_lds_ptr1;
const auto d_lds_ptr_curr = is_even ? d_lds_ptr1 : d_lds_ptr0;
const auto d_lds_ptr_next = is_even ? d_lds_ptr0 : d_lds_ptr1;
main_body_impl(is_prologue_,
is_epilogue_,
q_lds_ptr_curr,
q_lds_ptr_next,
do_lds_ptr_curr,
do_lds_ptr_next);
do_lds_ptr_next,
lse_lds_ptr_curr,
lse_lds_ptr_next,
d_lds_ptr_curr,
d_lds_ptr_next);
i_total_bodys += 1;
};

View File

@@ -363,38 +363,38 @@ struct BlockFmhaBwdDQDKDVPipelineTrLoadQRQTRDOR
"BiasDataType and BiasGradDataType should be the same!");
// LSE: HBM -> LDS ->Reg
auto lse_dram_window = make_tile_window(
lse_dram_block_window_tmp.get_bottom_tensor_view(),
lse_dram_block_window_tmp.get_window_lengths(),
{0},
Policy::template MakeLSEDDramTileDistribution<Problem, decltype(gemm_0)>());
auto lse_dram_window =
make_tile_window(lse_dram_block_window_tmp.get_bottom_tensor_view(),
lse_dram_block_window_tmp.get_window_lengths(),
{0},
Policy::template MakeLSEDDramTileDistribution<Problem>());
auto lse_lds = make_tensor_view<address_space_enum::lds>(
lse_lds_ptr, Policy::template MakeLSEDLdsWriteBlockDescriptor<Problem>());
auto lse_lds_write_window = make_tile_window(lse_lds, make_tuple(number<kM0>{}), {0});
auto lse_lds_read_window = make_tile_window(
lse_lds,
make_tuple(number<kM0>{}),
{0},
Policy::template MakeLSEDLdsReadBlockDescriptor<Problem, decltype(gemm_0)>());
auto lse_lds_read_window =
make_tile_window(lse_lds,
make_tuple(number<kM0>{}),
{0},
Policy::template MakeLSEDLdsReadBlockDescriptor<Problem>());
// D: HBM ->Reg
auto d_dram_window = make_tile_window(
d_dram_block_window_tmp.get_bottom_tensor_view(),
d_dram_block_window_tmp.get_window_lengths(),
{0},
Policy::template MakeLSEDDramTileDistribution<Problem, decltype(gemm_0)>());
auto d_dram_window =
make_tile_window(d_dram_block_window_tmp.get_bottom_tensor_view(),
d_dram_block_window_tmp.get_window_lengths(),
{0},
Policy::template MakeLSEDDramTileDistribution<Problem>());
auto d_lds = make_tensor_view<address_space_enum::lds>(
d_lds_ptr, Policy::template MakeLSEDLdsWriteBlockDescriptor<Problem>());
auto d_lds_write_window = make_tile_window(d_lds, make_tuple(number<kM0>{}), {0});
auto d_lds_read_window = make_tile_window(
d_lds,
make_tuple(number<kM0>{}),
{0},
Policy::template MakeLSEDLdsReadBlockDescriptor<Problem, decltype(gemm_0)>());
auto d_lds_read_window =
make_tile_window(d_lds,
make_tuple(number<kM0>{}),
{0},
Policy::template MakeLSEDLdsReadBlockDescriptor<Problem>());
// RandVal: HBM ->Reg
auto randval_dram_window = dropout.template MakeRandvalDramWindow<decltype(gemm_0), true>(
@@ -707,18 +707,18 @@ struct BlockFmhaBwdDQDKDVPipelineTrLoadQRQTRDOR
tile_elementwise_inout([&raw_scale](auto& x) { x = x * raw_scale; }, dk_acc);
}
dk_epilogue(dk_dram_window, dk_acc);
dk_epilogue(dk_dram_window, dk_acc, nullptr);
move_tile_window(dk_dram_window, {kN0, 0});
dv_epilogue(dv_dram_window, dv_acc);
dv_epilogue(dv_dram_window, dv_acc, nullptr);
move_tile_window(dv_dram_window, {kN0, 0});
}
};
for(index_t i = 0; i < seqlen_kv_start; i += kN0)
{
dk_epilogue(dk_dram_window, decltype(gemm_3.MakeCBlockTile()){0});
dk_epilogue(dk_dram_window, decltype(gemm_3.MakeCBlockTile()){0}, nullptr);
move_tile_window(dk_dram_window, {kN0, 0});
dv_epilogue(dv_dram_window, decltype(gemm_1.MakeCBlockTile()){0});
dv_epilogue(dv_dram_window, decltype(gemm_1.MakeCBlockTile()){0}, nullptr);
move_tile_window(dv_dram_window, {kN0, 0});
}
@@ -740,9 +740,9 @@ struct BlockFmhaBwdDQDKDVPipelineTrLoadQRQTRDOR
const auto seqlen_kv_length = k_length.at(number<0>{});
for(; seqlen_kv_step < seqlen_kv_length; seqlen_kv_step += kN0)
{
dk_epilogue(dk_dram_window, decltype(gemm_3.MakeCBlockTile()){0});
dk_epilogue(dk_dram_window, decltype(gemm_3.MakeCBlockTile()){0}, nullptr);
move_tile_window(dk_dram_window, {kN0, 0});
dv_epilogue(dv_dram_window, decltype(gemm_1.MakeCBlockTile()){0});
dv_epilogue(dv_dram_window, decltype(gemm_1.MakeCBlockTile()){0}, nullptr);
move_tile_window(dv_dram_window, {kN0, 0});
}
@@ -752,8 +752,7 @@ struct BlockFmhaBwdDQDKDVPipelineTrLoadQRQTRDOR
dq_acc);
else
tile_elementwise_inout([&raw_scale](auto& x) { x = x * raw_scale; }, dq_acc);
// static_assert(kIsDeterministic);
dq_epilogue(dq_dram_window, dq_acc);
dq_epilogue(dq_dram_window, dq_acc, nullptr);
return;
}
};

View File

@@ -194,13 +194,7 @@ struct BlockFmhaBwdPipelineTrLoadDefaultPolicy
template <typename Problem>
CK_TILE_HOST_DEVICE static constexpr auto GetTransposedAlignmentOGrad()
{
constexpr index_t kBlockSize = Problem::kBlockSize;
constexpr index_t kNPerBlock = Problem::BlockFmhaShape::kM0;
constexpr index_t kKPerBlock = Problem::BlockFmhaShape::kVHeaddim;
constexpr index_t total_pixels = kNPerBlock * kKPerBlock / kBlockSize;
return total_pixels / GetAlignmentOGrad<Problem>();
return GetTransposedAlignmentX<typename Problem::OGradDataType>();
}
template <typename Problem>
@@ -358,11 +352,30 @@ struct BlockFmhaBwdPipelineTrLoadDefaultPolicy
Problem::BlockFmhaShape::kVHeaddim>();
}
template <typename Problem, typename BlockGemm>
template <typename Problem>
CK_TILE_HOST_DEVICE static constexpr auto MakeLSEDDramTileDistribution()
{
return BlockFmhaBwdPipelineDefaultPolicy::MakeLSEDDramTileDistribution<Problem,
BlockGemm>();
using BlockGemm = remove_cvref_t<decltype(GetQKBlockGemm<Problem>())>;
constexpr auto config = BlockGemm::Policy::template GetWarpGemmMWarpNWarp<Problem>();
constexpr index_t MWarp = config.template at<1>();
constexpr index_t NWarp = config.template at<2>();
constexpr index_t kMPerBlock = Problem::BlockFmhaShape::kM0;
constexpr index_t N0 = MWarp * NWarp;
constexpr index_t M1 = kMPerBlock;
constexpr index_t M0 = get_warp_size() / M1;
static_assert(M1 <= get_warp_size() && get_warp_size() % M1 == 0,
"M1 must be a factor of warp size");
return make_static_tile_distribution(
tile_distribution_encoding<sequence<N0, M0>,
tuple<sequence<M1, 1>>,
tuple<sequence<0>, sequence<0, 1>>,
tuple<sequence<0>, sequence<1, 0>>,
sequence<1>,
sequence<1>>{});
}
template <typename Problem>
@@ -793,9 +806,10 @@ struct BlockFmhaBwdPipelineTrLoadDefaultPolicy
return lsed_lds_block_desc;
}
template <typename Problem, typename BlockGemm>
template <typename Problem>
CK_TILE_HOST_DEVICE static constexpr auto MakeLSEDLdsReadBlockDescriptor()
{
using BlockGemm = remove_cvref_t<decltype(GetQKBlockGemm<Problem>())>;
constexpr auto config = BlockGemm::Policy::template GetWarpGemmMWarpNWarp<Problem>();
using WG = remove_cvref_t<decltype(config.template at<0>())>;
constexpr index_t MWarp = config.template at<1>();
@@ -984,15 +998,16 @@ struct BlockFmhaBwdPipelineTrLoadDefaultPolicy
template <typename Problem>
CK_TILE_HOST_DEVICE static constexpr index_t GetSmemSizeLSE()
{
return sizeof(typename Problem::LSEDataType) *
MakeLSEDLdsWriteBlockDescriptor<Problem>().get_element_space_size();
return static_cast<index_t>(max( //
sizeof(int) * get_warp_size(),
sizeof(typename Problem::LSEDataType) *
MakeLSEDLdsWriteBlockDescriptor<Problem>().get_element_space_size()));
}
template <typename Problem>
CK_TILE_HOST_DEVICE static constexpr index_t GetSmemSizeD()
{
return sizeof(typename Problem::DDataType) *
MakeLSEDLdsWriteBlockDescriptor<Problem>().get_element_space_size();
return GetSmemSizeLSE<Problem>();
}
template <typename Problem>
@@ -1039,8 +1054,9 @@ struct BlockFmhaBwdPipelineTrLoadDefaultPolicy
constexpr index_t smem_size_bias = GetSmemSizeBias<Problem>();
constexpr index_t smem_size_stage0 = smem_size_k + smem_size_v;
constexpr index_t smem_size_stage1 = smem_size_q * 2 + smem_size_do * 2 + smem_size_lse +
smem_size_d + max(smem_size_bias, smem_size_ds);
constexpr index_t smem_size_stage1 = smem_size_q * 2 + smem_size_do * 2 +
smem_size_lse * 2 + smem_size_d * 2 +
max(smem_size_bias, smem_size_ds);
return max(smem_size_stage0, smem_size_stage1);
}
@@ -1090,6 +1106,8 @@ struct BlockFmhaBwdPipelineTrLoadDefaultPolicy
static constexpr index_t LSE_VMEM_READ = 1;
static constexpr index_t D_VMEM_READ = 1;
static constexpr index_t DQ_VMEM_WRITE = kM0 * kQKHeaddim / kBlockSize; // atomic add
// LDS Read
static constexpr index_t OGradT_LDS_READ =
kM0 * kVHeaddim / get_warp_size() / GetTransposedAlignmentOGrad<Problem>();
@@ -1116,11 +1134,12 @@ struct BlockFmhaBwdPipelineTrLoadDefaultPolicy
kM0 * kVHeaddim / kBlockSize / GetAlignmentOGrad<Problem>();
static constexpr index_t OGradT_LDS_WRITE =
kM0 * kVHeaddim / kBlockSize / GetTransposedAlignmentOGrad<Problem>();
static constexpr index_t LSE_LDS_WRITE = 1;
static constexpr index_t D_LDS_WRITE = 1;
static constexpr index_t SGradT_LDS_WRITE = kM0 * kN0 / kBlockSize;
public:
static constexpr index_t TOTAL_VMEM_READ =
Q_VMEM_READ + OGrad_VMEM_READ + LSE_VMEM_READ + D_VMEM_READ + DQ_VMEM_WRITE;
CK_TILE_DEVICE static constexpr void SchedulerGemm0()
{
// Mem: Q, LSE, OGrad, D global load, OGrad^T LDS load
@@ -1128,7 +1147,7 @@ struct BlockFmhaBwdPipelineTrLoadDefaultPolicy
constexpr index_t VMEM_READ_INST =
Q_VMEM_READ + OGrad_VMEM_READ + LSE_VMEM_READ + D_VMEM_READ;
constexpr index_t MFMA_INST = Gemm0MFMA;
constexpr index_t LDS_READ_INST = OGradT_LDS_READ;
constexpr index_t LDS_READ_INST = OGradT_LDS_READ + LSE_LDS_READ + D_LDS_READ;
constexpr index_t lcm_inst = lcm(VMEM_READ_INST, MFMA_INST, LDS_READ_INST);
static_for<0, lcm_inst, 1>{}([&](auto i) {
@@ -1161,8 +1180,8 @@ struct BlockFmhaBwdPipelineTrLoadDefaultPolicy
{
// Mem: LSE/D LDS store, SGradT LDS store, SGrad, Q, LSE LDS load.
// Comp: SGradT x QT
constexpr index_t LDS_WRITE_INST = LSE_LDS_WRITE + D_LDS_WRITE + SGradT_LDS_WRITE;
constexpr index_t LDS_READ_INST = SGradT_LDS_READ_P1 + Q_LDS_READ + LSE_LDS_READ;
constexpr index_t LDS_WRITE_INST = SGradT_LDS_WRITE;
constexpr index_t LDS_READ_INST = SGradT_LDS_READ_P1 + Q_LDS_READ;
constexpr index_t MFMA_INST = Gemm3MFMA;
constexpr index_t lds_rw_inst = LDS_WRITE_INST + LDS_READ_INST;
@@ -1185,7 +1204,7 @@ struct BlockFmhaBwdPipelineTrLoadDefaultPolicy
{
// Mem: SGrad, OGrad, D LDS load.
// Comp: SGrad x KT
constexpr index_t LDS_READ_INST = SGradT_LDS_READ_P2 + OGrad_LDS_READ + D_LDS_READ;
constexpr index_t LDS_READ_INST = SGradT_LDS_READ_P2 + OGrad_LDS_READ;
constexpr index_t MFMA_INST = Gemm4MFMA;
constexpr index_t lcm_inst = lcm(MFMA_INST, LDS_READ_INST);

View File

@@ -33,15 +33,14 @@ struct BlockWeightPreshuffleASmemBSmemCRegV1
static constexpr index_t kBlockSize = Problem::kBlockSize;
static constexpr auto config = BlockPolicy::template GetWarpGemmMWarpNWarp<Problem>();
using WG = remove_cvref_t<decltype(config.template at<0>())>;
CK_TILE_DEVICE static constexpr auto MakeCBlockTile()
{
constexpr index_t MPerBlock = BlockGemmShape::kM;
constexpr index_t NPerBlock = BlockGemmShape::kN;
constexpr auto config = BlockPolicy::template GetWarpGemmMWarpNWarp<Problem>();
using WG = remove_cvref_t<decltype(config.template at<0>())>;
constexpr index_t MWarp = config.template at<1>();
constexpr index_t NWarp = config.template at<2>();
@@ -74,9 +73,6 @@ struct BlockWeightPreshuffleASmemBSmemCRegV1
constexpr index_t MPerBlock = BlockGemmShape::kM;
constexpr index_t KPerBlock = BlockGemmShape::kK;
constexpr auto config = BlockPolicy::template GetWarpGemmMWarpNWarp<Problem>();
using WG = remove_cvref_t<decltype(config.template at<0>())>;
constexpr index_t MWarp = config.template at<1>();
constexpr index_t MIterPerWarp = MPerBlock / (MWarp * WG::kM);

View File

@@ -266,6 +266,10 @@ struct GroupedGemmKernel
const tuple<index_t, index_t>& block_idx_2d,
const index_t block_idx_z) const
{
static_assert(GemmPipeline::DoubleSmemBuffer || !GemmPipeline::Preshuffle,
"SingleSmemBuffer and Preshuffle cannot both be enabled simultaneously!");
const auto [iM, iN] = block_idx_2d;
const index_t i_m = __builtin_amdgcn_readfirstlane(iM * TilePartitioner::MPerBlock);
@@ -282,11 +286,15 @@ struct GroupedGemmKernel
// allocate LDS
__shared__ char smem_ptr_0[GetSmemSize()];
// TO DO:
// Can we simplify this branching logic?
if constexpr(GemmPipeline::DoubleSmemBuffer == true)
{
__shared__ char smem_ptr_1[GetSmemSize()];
if constexpr(UsePersistentKernel)
if constexpr(UsePersistentKernel || GemmPipeline::Preshuffle)
{
RunGemmWithPipelineSelection2LDS(a_ptr,
b_ptr,
c_ptr,
@@ -296,9 +304,11 @@ struct GroupedGemmKernel
splitk_batch_offset,
i_m,
i_n);
return;
}
else
{
Base::RunGemm2LDS({a_ptr},
{b_ptr},
{/*ds_ptr*/},
@@ -311,14 +321,14 @@ struct GroupedGemmKernel
i_n);
}
}
else
else // SingleSmemBuffer
{
if constexpr(UsePersistentKernel)
{
RunGemmWithPipelineSelection(
a_ptr, b_ptr, c_ptr, smem_ptr_0, kargs, splitk_batch_offset, i_m, i_n);
}
else
else // Non-persistent kernel
{
Base::RunGemm({a_ptr},
{b_ptr},
@@ -438,17 +448,34 @@ struct GroupedGemmKernel
// Get hot-loop and tail configuration
const index_t num_loop = __builtin_amdgcn_readfirstlane(
TilePartitioner::GetLoopNum(splitk_batch_offset.splitted_k));
const bool has_hot_loop = GemmPipeline::BlockHasHotloop(num_loop);
const TailNumber tail_num = GemmPipeline::GetBlockLoopTailNum(num_loop);
// Run GEMM pipeline
const auto& c_block_tile = GemmPipeline{}.template operator()(a_block_window[Base::I0],
b_block_window[Base::I0],
num_loop,
has_hot_loop,
tail_num,
smem_ptr_0,
smem_ptr_1);
// Run GEMM pipeline with compile-time branching
const auto& c_block_tile = [&]() {
if constexpr(GemmPipeline::Preshuffle)
{
// Preshuffle version - without has_hot_loop parameter
return GemmPipeline{}.template operator()(a_block_window[Base::I0],
b_block_window[Base::I0],
num_loop,
tail_num,
smem_ptr_0,
smem_ptr_1);
}
else
{
// Regular version - with has_hot_loop parameter
const bool has_hot_loop = GemmPipeline::BlockHasHotloop(num_loop);
return GemmPipeline{}.template operator()(a_block_window[Base::I0],
b_block_window[Base::I0],
num_loop,
has_hot_loop,
tail_num,
smem_ptr_0,
smem_ptr_1);
}
}();
// Run Epilogue Pipeline
auto& c_block_window = gemm_tile_windows.at(Base::I3);
EpiloguePipeline{}.template
@@ -491,8 +518,9 @@ struct GroupedGemmKernel
const auto gemm_desc_ptr = reinterpret_cast<const GemmTransKernelArg*>(
cast_pointer_to_generic_address_space(gemm_descs_const));
const index_t group_id = FindGroupId(gemm_desc_ptr, block_id, group_count);
const auto& kargs = gemm_desc_ptr[group_id];
const index_t group_id = FindGroupId(gemm_desc_ptr, block_id, group_count);
const auto& kargs = gemm_desc_ptr[group_id];
const auto grid_size_2d = TilePartitioner::GridSize(kargs.group_karg.M, kargs.group_karg.N);
const auto block_idx_2d = OffsetTile1DPartitioner::GetOffsetedTileIndex(
0,

View File

@@ -43,7 +43,7 @@ template <bool kPadM_,
bool UseStructuredSparsity_ = false,
bool UsePersistentKernel_ = false,
index_t NumWaveGroups_ = 1,
bool Preshuffle_ = 0>
bool Preshuffle_ = false>
struct TileGemmUniversalTraits
{
static constexpr bool kPadM = kPadM_;

View File

@@ -296,6 +296,73 @@ struct UniversalWeightPreshufflePipelineAgBgCrPolicy
WarpGemm>;
return BlockWeightPreshuffleASmemBSmemCRegV1<Problem, BlockWeightPreshufflePolicy>{};
}
/**
* @brief Get the vector store size for C tensor.
*
* @tparam Problem - Gemm pipeline problem class.
*
* @note The vector store size for output C tensor would depend on multiple factors
* like its data layout and warp gemm C transposition. In general it would
* be the number of consecutive elements in contiguous C dimension hold by
* single thread.
*
* @return The vector store size for C tensor.
*/
template <typename Problem>
CK_TILE_HOST_DEVICE static constexpr auto GetVectorSizeC()
{
using BlockGemm = remove_cvref_t<decltype(GetBlockWeightPreshuffle<Problem>())>;
using WG_ = typename BlockGemm::WG;
constexpr bool TransposeC = Problem::TransposeC;
using CLayout = typename Problem::CLayout;
using CWarpDstr = typename WG_::CWarpDstr;
// N is contiguous dimension
if constexpr(std::is_same_v<CLayout, tensor_layout::gemm::RowMajor>)
{
if constexpr(TransposeC)
{
// In this case each thread has multiple consecutive elements in
// N dimension, however consecutive threads' elements have stride.
constexpr index_t NDimY = CWarpDstr::NDimY;
constexpr auto c_warp_y_lengths =
CWarpDstr{}.get_ys_to_d_descriptor().get_lengths();
static_assert(WG_::WarpGemmAttribute::Impl::kCM1PerLane ==
c_warp_y_lengths.get(number<NDimY - 1>{}));
return c_warp_y_lengths.get(number<NDimY - 1>{});
}
else
{
// In this case each thread has just a single item in Ndim
return WG_::WarpGemmAttribute::Impl::kCNLane / WG_::kN;
}
}
// M is contiguous dimension
else if constexpr(std::is_same_v<CLayout, tensor_layout::gemm::ColumnMajor>)
{
if constexpr(TransposeC)
{
// In this case each thread has just a single item in Mdim
return WG_::WarpGemmAttribute::Impl::kCNLane / WG_::kN;
}
else
{
// In this case each thread has multiple consecutive elements in
// M dimension, however consecutive threads' elements have stride.
constexpr index_t NDimY = CWarpDstr::NDimY;
constexpr auto c_warp_y_lengths =
CWarpDstr{}.get_ys_to_d_descriptor().get_lengths();
static_assert(WG_::WarpGemmAttribute::Impl::kCM1PerLane ==
c_warp_y_lengths.get(number<NDimY - 1>{}));
return c_warp_y_lengths.get(number<NDimY - 1>{});
}
}
else
{
static_assert(false, "Unsupported CLayout!");
}
}
};
} // namespace ck_tile

View File

@@ -6,6 +6,7 @@
#include "ck_tile/ops/gemm_group_quant/block/block_universal_gemm_as_aquant_bs_cr.hpp"
#include "ck_tile/ops/gemm_group_quant/block/block_universal_gemm_as_bs_bquant_cr.hpp"
#include "ck_tile/ops/gemm_group_quant/kernel/gemm_quant_kernel.hpp"
#include "ck_tile/ops/gemm_group_quant/kernel/grouped_gemm_quant_kernel.hpp"
#include "ck_tile/ops/gemm_group_quant/pipeline/gemm_aquant_pipeline_ag_bg_cr_base.hpp"
#include "ck_tile/ops/gemm_group_quant/pipeline/gemm_aquant_pipeline_ag_bg_cr_policy.hpp"
#include "ck_tile/ops/gemm_group_quant/pipeline/gemm_aquant_pipeline_ag_bg_cr_v3.hpp"

View File

@@ -769,12 +769,11 @@ struct QuantGemmKernel
CK_TILE_DEVICE static auto
MakeGemmTileWindows(const PadView& views, const index_t i_m, const index_t i_n)
{
const auto& a_pad_view = views.at(I0);
const auto& aq_pad_view = views.at(I1);
const auto& b_pad_view = views.at(I2);
const auto& bq_pad_view = views.at(I3);
const auto& c_pad_view = views.at(I4);
const auto& a_pad_view = views.at(I0);
const auto& aq_pad_view = views.at(I1);
const auto& b_pad_view = views.at(I2);
const auto& bq_pad_view = views.at(I3);
const auto& c_pad_view = views.at(I4);
const auto& a_block_window = [&]() {
if constexpr(std::is_same_v<ALayout, tensor_layout::gemm::RowMajor>)
{

View File

@@ -0,0 +1,433 @@
// SPDX-License-Identifier: MIT
// Copyright (c) 2024-2025, Advanced Micro Devices, Inc. All rights reserved.
#pragma once
#include "ck_tile/core/numeric/math.hpp"
#include "ck_tile/core/utility/literals.hpp"
#include "ck_tile/core/utility/type_traits.hpp"
#include "ck_tile/host/stream_utils.hpp"
#include "ck_tile/ops/gemm/pipeline/gemm_pipeline_ag_bg_cr_comp_v3.hpp"
#include "ck_tile/ops/gemm/pipeline/gemm_pipeline_ag_bg_cr_scheduler.hpp"
#include "ck_tile/ops/gemm_group_quant/kernel/gemm_quant_kernel.hpp"
#include "ck_tile/host.hpp"
#include <hip/hip_runtime.h>
namespace ck_tile {
/// @brief The Grouped GEMM kernel host arguments.
///
/// @par Overview
/// This structure is passed to @ref GroupedGemmKernel "GroupedGemmKernel" when creating kernel
/// arguments object. It contain all necessary information required to build proper kernel
/// argument and launch kernel on GPU. This structure defines the GEMM problem configuration by
/// stating all required information like M,N,K sizes and respective strides.
struct QuantGroupedGemmHostArgs
{
CK_TILE_HOST QuantGroupedGemmHostArgs(const void* a_ptr_,
const void* b_ptr_,
void* e_ptr_,
const void* aq_ptr_,
const void* bq_ptr_,
index_t k_batch_,
index_t M_,
index_t N_,
index_t K_,
index_t QK_A_,
index_t QK_B_,
index_t stride_A_,
index_t stride_B_,
index_t stride_E_,
index_t stride_AQ_,
index_t stride_BQ_)
: a_ptr(a_ptr_),
b_ptr(b_ptr_),
aq_ptr(aq_ptr_),
bq_ptr(bq_ptr_),
e_ptr(e_ptr_),
M(M_),
N(N_),
K(K_),
QK_A(QK_A_),
QK_B(QK_B_),
stride_A(stride_A_),
stride_B(stride_B_),
stride_AQ(stride_AQ_),
stride_BQ(stride_BQ_),
stride_E(stride_E_),
k_batch(k_batch_)
{
}
const void* a_ptr;
const void* b_ptr;
const void* aq_ptr;
const void* bq_ptr;
union
{
void* e_ptr;
void* c_ptr;
};
index_t M;
index_t N;
index_t K;
index_t QK_A;
index_t QK_B;
index_t stride_A;
index_t stride_B;
index_t stride_AQ;
index_t stride_BQ;
union
{
index_t stride_E;
index_t stride_C;
};
index_t k_batch;
};
using QuantGroupedGemmKernelArgs = QuantGemmKernelArgs;
struct QuantGemmTransKernelArg
{
QuantGroupedGemmKernelArgs group_karg;
ck_tile::index_t block_start;
ck_tile::index_t block_end;
QuantGemmTransKernelArg() = delete;
QuantGemmTransKernelArg(QuantGroupedGemmKernelArgs&& karg, index_t bl_start, index_t bl_end)
: group_karg{karg}, block_start{bl_start}, block_end{bl_end}
{
}
QuantGemmTransKernelArg(QuantGroupedGemmKernelArgs&& karg)
: group_karg{karg}, block_start{0}, block_end{0}
{
}
};
template <typename TilePartitioner_,
typename GemmPipeline_,
typename EpiloguePipeline_,
QuantType QuantType_>
struct QuantGroupedGemmKernel
{
/// @brief Inject the UniversalGemmKernel base class to support execution of all necessary
/// functions.
using Base = QuantGemmKernel<TilePartitioner_, GemmPipeline_, EpiloguePipeline_, QuantType_>;
using TilePartitioner = remove_cvref_t<TilePartitioner_>;
using GemmPipeline = remove_cvref_t<GemmPipeline_>;
using EpiloguePipeline = remove_cvref_t<EpiloguePipeline_>;
//// @brief Specify the layout configurations for A, B, C/E
using ALayout = remove_cvref_t<typename GemmPipeline::ALayout>;
using BLayout = remove_cvref_t<typename GemmPipeline::BLayout>;
using CLayout = remove_cvref_t<typename GemmPipeline::CLayout>;
/// @brief Specify the data type configurations for A, B, C/E
using ADataType = remove_cvref_t<typename GemmPipeline::ADataType>;
using BDataType = remove_cvref_t<typename GemmPipeline::BDataType>;
using CDataType = remove_cvref_t<typename EpiloguePipeline::ODataType>;
using AccDataType = remove_cvref_t<typename EpiloguePipeline::AccDataType>;
using AQDataType =
remove_cvref_t<typename detail::get_aq_data_type_or<GemmPipeline, AccDataType>::type>;
using BQDataType =
remove_cvref_t<typename detail::get_bq_data_type_or<GemmPipeline, AccDataType>::type>;
static constexpr auto kQuantType = QuantType_;
/// @brief ALayout and ADataType are expected to be scalars, not a tuple.
static_assert(
!is_detected<is_tuple, ALayout>::value && !is_detected<is_tuple, ADataType>::value,
"ALayout and ADataType must be scalars. Multiple parameters are not currently supported.");
/// @brief BLayout and BDataType are expected to be scalars, not a tuple.
static_assert(
!is_detected<is_tuple, BLayout>::value && !is_detected<is_tuple, BDataType>::value,
"BLayout and BDataType must be scalars. Multiple parameters are not currently supported.");
/// @brief C/ELayout and C/EDataType are expected to be scalars, not a tuple.
static_assert(!is_detected<is_tuple, CLayout>::value &&
!is_detected<is_tuple, CDataType>::value,
"C/ELayout and C/EDataType must be scalars.");
using OffsetTile1DPartitioner = OffsettedTile1DPartitioner<TilePartitioner>;
using Kernel =
QuantGroupedGemmKernel<TilePartitioner, GemmPipeline, EpiloguePipeline, kQuantType>;
static constexpr index_t kBlockSize = GemmPipeline::BlockSize;
static constexpr bool UsePersistentKernel = GemmPipeline::UsePersistentKernel;
static_assert(UsePersistentKernel == true, "UsePersistentKernel must be true");
[[nodiscard]] CK_TILE_HOST static const std::string GetName()
{
// clang-format off
using P_ = GemmPipeline;
return concat('_', "gemm_grouped", gemm_prec_str<ADataType, BDataType>(),
concat('x', P_::MPerBlock, P_::NPerBlock, P_::KPerBlock),
concat('x', P_::GetVectorSizeA(), P_::GetVectorSizeB(), P_::GetVectorSizeC()),
concat('x', P_::kPadM, P_::kPadN, P_::kPadK),
(UsePersistentKernel ? "Persistent" : "NonPersistent"));
// clang-format on
}
CK_TILE_HOST static auto
GetWorkSpaceSize(const std::vector<QuantGroupedGemmHostArgs>& gemm_descs) -> std::size_t
{
return gemm_descs.size() * sizeof(QuantGemmTransKernelArg);
}
CK_TILE_HOST static auto GetWorkSpaceSize(index_t group_count) -> std::size_t
{
return group_count * sizeof(QuantGemmTransKernelArg);
}
CK_TILE_HOST static auto BlockSize() -> dim3
{
if(is_wave32())
{
return dim3(kBlockSize / 2);
}
else
{
return dim3(kBlockSize);
}
}
/**
* @brief Get the maximum occupancy grid size for the persistent kernel on the current device.
* @return The maximum occupancy grid size.
* @note This function queries the maximum occupancy of the kernel using
* `hipOccupancyMaxActiveBlocksPerMultiprocessor`.
*/
CK_TILE_HOST static auto MaxOccupancyGridSize(const stream_config& s) -> dim3
{
using ConstantPointer = const void CK_CONSTANT_ADDRESS_SPACE*;
const auto kernel_func = kentry<1, Kernel, ConstantPointer, index_t>;
int occupancy;
HIP_CHECK_ERROR(
hipOccupancyMaxActiveBlocksPerMultiprocessor(&occupancy, kernel_func, kBlockSize, 0));
const int grid_size = get_available_compute_units(s) * occupancy;
return dim3(grid_size, 1, 1);
}
CK_TILE_HOST static auto GridSize(const std::vector<QuantGroupedGemmHostArgs>& gemm_descs)
{
index_t grid_size = 0;
for(const auto& it_desc : gemm_descs)
{
const auto local_grid_size = TilePartitioner::GridSize(it_desc.M, it_desc.N);
grid_size += local_grid_size * it_desc.k_batch;
}
return dim3(grid_size, 1, 1);
}
CK_TILE_HOST static auto MakeKargs(const std::vector<QuantGroupedGemmHostArgs>& gemm_descs)
-> std::vector<QuantGemmTransKernelArg>
{
std::vector<QuantGemmTransKernelArg> gemm_kernel_args_;
index_t group_count = ck_tile::type_convert<ck_tile::index_t>(gemm_descs.size());
index_t grid_size = 0;
gemm_kernel_args_.reserve(group_count);
for(std::size_t i = 0; i < gemm_descs.size(); ++i)
{
const index_t M = gemm_descs[i].M;
const index_t N = gemm_descs[i].N;
const index_t K = gemm_descs[i].K;
if(M == 0 || N == 0 || K == 0)
{
continue;
}
const index_t stride_a = gemm_descs[i].stride_A;
const index_t stride_b = gemm_descs[i].stride_B;
const index_t stride_e = gemm_descs[i].stride_C;
const index_t grid_size_grp = TilePartitioner::GridSize(M, N) * gemm_descs[i].k_batch;
const index_t block_start = grid_size;
const index_t block_end = grid_size + grid_size_grp;
grid_size += grid_size_grp;
auto karg =
QuantGroupedGemmKernelArgs{type_convert<const ADataType*>(gemm_descs[i].a_ptr),
type_convert<const BDataType*>(gemm_descs[i].b_ptr),
type_convert<CDataType*>(gemm_descs[i].e_ptr),
type_convert<const AQDataType*>(gemm_descs[i].aq_ptr),
type_convert<const BQDataType*>(gemm_descs[i].bq_ptr),
gemm_descs[i].k_batch,
M,
N,
K,
gemm_descs[i].QK_A,
gemm_descs[i].QK_B,
stride_a,
stride_b,
stride_e,
gemm_descs[i].stride_AQ,
gemm_descs[i].stride_BQ};
gemm_kernel_args_.emplace_back(std::move(karg), block_start, block_end);
}
return gemm_kernel_args_;
}
CK_TILE_HOST static bool IsSupportedArgument(const std::vector<QuantGemmTransKernelArg>& kargs)
{
for(const auto& karg : kargs)
{
if(!Base::IsSupportedArgument(karg.group_karg))
{
return false;
}
}
return true;
}
CK_TILE_HOST_DEVICE static constexpr auto GetSmemSize() -> index_t
{
return max(GemmPipeline::GetSmemSize(), EpiloguePipeline::GetSmemSize());
}
CK_TILE_DEVICE void Run(const QuantGroupedGemmKernelArgs& kargs,
const tuple<index_t, index_t>& block_idx_2d,
const index_t block_idx_z) const
{
const auto [iM, iN] = block_idx_2d;
const index_t i_m = __builtin_amdgcn_readfirstlane(iM * TilePartitioner::MPerBlock);
const index_t i_n = __builtin_amdgcn_readfirstlane(iN * TilePartitioner::NPerBlock);
const typename Base::SplitKBatchOffset splitk_batch_offset(kargs, block_idx_z);
// options
const ADataType* a_ptr = static_cast<const ADataType*>(kargs.a_ptr);
const BDataType* b_ptr = static_cast<const BDataType*>(kargs.b_ptr);
const AQDataType* aq_ptr = static_cast<const AQDataType*>(kargs.aq_ptr);
const BQDataType* bq_ptr = static_cast<const BQDataType*>(kargs.bq_ptr);
CDataType* c_ptr = static_cast<CDataType*>(kargs.c_ptr);
static_assert(GemmPipeline::DoubleSmemBuffer == false,
"DoubleSmemBuffer needs to be false");
// allocate LDS
__shared__ char smem_ptr_0[GetSmemSize()];
RunGemmWithPipelineSelection(
a_ptr, b_ptr, aq_ptr, bq_ptr, c_ptr, smem_ptr_0, kargs, splitk_batch_offset, i_m, i_n);
}
/**
* @brief Runs single GEMM problem cooperatively by whole workgroup.
*
* @note The GEMM pipeline is selected in-kernel based on the number of K-loops
* and the tail-number. This is needed for the persistent tile-loop when
* we didn't have access to the K dimension on the host.
*
* @param a_ptr input A pointer
* @param b_ptr input B pointer
* @param aq_ptr input AQ pointer
* @param bq_ptr input BQ pointer
* @param c_ptr output C pointer
* @param smem_ptr_0 The start memory pointer of the shared memory block.
* @param kargs GEMM kernel arguments
* @param splitk_batch_offset splitk_batch_offset Utility structure used to calculate k
* batch.
* @param block_idx_m The GEMM's output M dimension tile index processed by this workgroup.
* @param block_idx_n The GEMM's output N dimension tile index processed by this workgroup.
*
*/
CK_TILE_DEVICE static void
RunGemmWithPipelineSelection(const ADataType* a_ptr,
const BDataType* b_ptr,
const AQDataType* aq_ptr,
const BQDataType* bq_ptr,
CDataType* c_ptr,
void* smem_ptr_0,
const QuantGroupedGemmKernelArgs& kargs,
const typename Base::SplitKBatchOffset& splitk_batch_offset,
const index_t block_idx_m,
const index_t block_idx_n)
{
// Create Gemm tensor views, pad views and tile windows
const auto& gemm_tensor_views_tuple =
Base::template MakeGemmTensorViews<EpiloguePipeline::MemoryOperation>(
a_ptr, b_ptr, aq_ptr, bq_ptr, c_ptr, kargs, splitk_batch_offset);
const auto& gemm_pad_views = Base::MakeGemmPadViews(gemm_tensor_views_tuple);
auto gemm_tile_windows =
Base::MakeGemmTileWindows(gemm_pad_views, block_idx_m, block_idx_n);
const auto& a_block_window = gemm_tile_windows.at(Base::I0);
const auto& b_block_window = gemm_tile_windows.at(Base::I2);
// Get hot-loop and tail configuration
const index_t num_loop = __builtin_amdgcn_readfirstlane(
TilePartitioner::GetLoopNum(splitk_batch_offset.splitted_k));
const bool has_hot_loop = GemmPipeline::BlockHasHotloop(num_loop);
const TailNumber tail_num = GemmPipeline::GetBlockLoopTailNum(num_loop);
// Run GEMM pipeline
const auto& c_block_tile = GemmPipeline{}.template operator()(
a_block_window, b_block_window, num_loop, has_hot_loop, tail_num, smem_ptr_0);
// Run Epilogue Pipeline
auto& c_block_window = gemm_tile_windows.at(Base::I4);
if constexpr(kQuantType == QuantType::RowColQuant)
{
const auto& aq_block_window = gemm_tile_windows.at(Base::I1);
const auto& bq_block_window = gemm_tile_windows.at(Base::I3);
EpiloguePipeline{}.template
operator()<decltype(c_block_window), decltype(c_block_tile), decltype(c_block_window)>(
c_block_window,
c_block_tile,
c_block_window,
smem_ptr_0,
aq_block_window,
bq_block_window);
}
}
// For persistent kernels
template <bool U = UsePersistentKernel,
typename = std::enable_if_t<U>,
typename = void> // extra template parameter to avoid redefinition
CK_TILE_DEVICE void operator()(const void CK_CONSTANT_ADDRESS_SPACE* gemm_descs_const,
const index_t group_count) const
{
const index_t grid_size = ck_tile::get_grid_size();
const auto gemm_desc_ptr = reinterpret_cast<const QuantGemmTransKernelArg*>(
cast_pointer_to_generic_address_space(gemm_descs_const));
index_t block_id = ck_tile::get_block_1d_id(); // initial block_id
index_t cum_grid_size = 0;
for(index_t group_id = 0; group_id < group_count; ++group_id)
{
const auto& kargs = gemm_desc_ptr[group_id].group_karg;
const auto& k_batch = kargs.k_batch;
const auto block_start = cum_grid_size;
cum_grid_size += TilePartitioner::GridSize(kargs.M, kargs.N) * k_batch;
while(block_id < cum_grid_size)
{
const auto grid_size_2d = TilePartitioner::GridSize(kargs.M, kargs.N);
const auto block_idx_2d = OffsetTile1DPartitioner::GetOffsetedTileIndex(
0, kargs.M, kargs.N, (block_id - block_start) % grid_size_2d);
Run(kargs, block_idx_2d, (block_id - block_start) / grid_size_2d);
block_id = block_id + grid_size; // advance to next block
// NOTE: this check is redundant but helps the compiler avoid spilling some VGPR
if(block_id >= cum_grid_size)
{
break; // exit the loop if all blocks are processed
}
}
}
}
};
} // namespace ck_tile

View File

@@ -23,8 +23,10 @@ template <bool kPadM_,
typename BLayout_,
typename CLayout_,
QuantType QuantType_,
typename AQLayout_ = ALayout_,
typename BQLayout_ = BLayout_>
typename AQLayout_ = ALayout_,
typename BQLayout_ = BLayout_,
bool DoubleSmemBuffer_ = false,
bool UsePersistentKernel_ = false>
struct TileGemmQuantTraits
{
static constexpr bool kPadM = kPadM_;
@@ -33,7 +35,8 @@ struct TileGemmQuantTraits
static constexpr QuantType kQuantType = QuantType_;
static constexpr int _VectorSize = 16;
static constexpr int _VectorSize = 16;
static constexpr bool DoubleSmemBuffer = DoubleSmemBuffer_;
using ALayout = ALayout_;
using BLayout = BLayout_;
@@ -44,6 +47,7 @@ struct TileGemmQuantTraits
static constexpr bool TransposeC = false;
static constexpr bool UseStructuredSparsity = false;
static constexpr index_t NumWaveGroups = 1;
static constexpr bool UsePersistentKernel = UsePersistentKernel_;
static constexpr bool PreshuffleQuant = PreshuffleQuant_;
};

View File

@@ -54,6 +54,7 @@ using MFMA = ck::tensor_layout::gemm::MFMA;
using Row_Tuple = ck::Tuple<Row>;
using Row_Row_Tuple = ck::Tuple<Row, Row>;
using Row_Col_Tuple = ck::Tuple<Row, Col>;
// Conv layout
//

View File

@@ -1,5 +1,5 @@
// SPDX-License-Identifier: MIT
// Copyright (c) 2024, Advanced Micro Devices, Inc. All rights reserved.
// Copyright (c) 2024-2025, Advanced Micro Devices, Inc. All rights reserved.
#pragma once
@@ -16,6 +16,7 @@ namespace tensor_operation {
namespace device {
namespace instance {
#if defined(CK_USE_XDL)
void add_device_gemm_add_xdl_c_shuffle_f16_i8_f16_f16_mk_kn_mn_mn_instances(
std::vector<std::unique_ptr<DeviceGemmMultipleD<Row,
Row,
@@ -41,8 +42,37 @@ void add_device_gemm_add_xdl_c_shuffle_bf16_i8_bf16_bf16_mk_kn_mn_mn_instances(
PassThrough,
PassThrough,
Add>>>&);
#endif
// GEMM + Add +
#if defined(CK_USE_WMMA)
void add_device_gemm_add_wmma_c_shuffle_f16_f16_f16_f16_mk_kn_mn_mn_instances(
std::vector<std::unique_ptr<DeviceGemmMultipleDSplitK<Row,
Row,
Row_Tuple,
Row,
F16,
F16,
F16_Tuple,
F16,
PassThrough,
PassThrough,
Add>>>&);
void add_device_gemm_add_wmma_c_shuffle_bf16_bf16_bf16_bf16_mk_kn_mn_mn_instances(
std::vector<std::unique_ptr<DeviceGemmMultipleDSplitK<Row,
Row,
Row_Tuple,
Row,
BF16,
BF16,
BF16_Tuple,
BF16,
PassThrough,
PassThrough,
Add>>>&);
#endif
// GEMM + Add
template <typename ALayout,
typename BLayout,
typename D0Layout,
@@ -52,17 +82,91 @@ template <typename ALayout,
typename D0DataType,
typename EDataType>
struct DeviceOperationInstanceFactory<
ck::tensor_operation::device::DeviceGemmMultipleD<ALayout,
BLayout,
ck::Tuple<D0Layout>,
ELayout,
ADataType,
BDataType,
ck::Tuple<D0DataType>,
EDataType,
PassThrough,
PassThrough,
Add>>
ck::tensor_operation::device::DeviceGemmMultipleDSplitK<ALayout,
BLayout,
ck::Tuple<D0Layout>,
ELayout,
ADataType,
BDataType,
ck::Tuple<D0DataType>,
EDataType,
PassThrough,
PassThrough,
Add>>
{
using DeviceOp = DeviceGemmMultipleDSplitK<ALayout,
BLayout,
ck::Tuple<D0Layout>,
ELayout,
ADataType,
BDataType,
ck::Tuple<D0DataType>,
EDataType,
PassThrough,
PassThrough,
Add>;
static auto GetInstances()
{
std::vector<std::unique_ptr<DeviceOp>> op_ptrs;
#if defined(CK_USE_XDL)
// No XDL instances for DeviceGemmMultipleDSplitK with Add at the moment
#endif // CK_USE_XDL
#if defined(CK_USE_WMMA)
#if defined(CK_ENABLE_FP16)
if constexpr(is_same_v<ADataType, half_t> && is_same_v<BDataType, half_t> &&
is_same_v<D0DataType, half_t> && is_same_v<EDataType, half_t>)
{
if constexpr(is_same_v<ALayout, Row> && is_same_v<BLayout, Row> &&
is_same_v<D0Layout, Row> && is_same_v<ELayout, Row>)
{
add_device_gemm_add_wmma_c_shuffle_f16_f16_f16_f16_mk_kn_mn_mn_instances(op_ptrs);
}
}
#endif
#if defined(CK_ENABLE_BF16)
if constexpr(is_same_v<ADataType, ck::bhalf_t> && is_same_v<BDataType, ck::bhalf_t> &&
is_same_v<D0DataType, ck::bhalf_t> && is_same_v<EDataType, ck::bhalf_t>)
{
if constexpr(is_same_v<ALayout, Row> && is_same_v<BLayout, Row> &&
is_same_v<D0Layout, Row> && is_same_v<ELayout, Row>)
{
add_device_gemm_add_wmma_c_shuffle_bf16_bf16_bf16_bf16_mk_kn_mn_mn_instances(
op_ptrs);
}
}
#endif
#endif
return op_ptrs;
}
};
// GEMM + Add
// DeviceGemmMultipleD specialization
template <typename ALayout,
typename BLayout,
typename D0Layout,
typename ELayout,
typename ADataType,
typename BDataType,
typename D0DataType,
typename EDataType>
struct DeviceOperationInstanceFactory<DeviceGemmMultipleD<ALayout,
BLayout,
ck::Tuple<D0Layout>,
ELayout,
ADataType,
BDataType,
ck::Tuple<D0DataType>,
EDataType,
PassThrough,
PassThrough,
Add>>
{
using DeviceOp = DeviceGemmMultipleD<ALayout,
BLayout,
@@ -80,6 +184,7 @@ struct DeviceOperationInstanceFactory<
{
std::vector<std::unique_ptr<DeviceOp>> op_ptrs;
#ifdef CK_USE_XDL
#if defined(CK_ENABLE_INT8) && defined(CK_ENABLE_FP16)
if constexpr(is_same_v<ADataType, half_t> && is_same_v<BDataType, int8_t> &&
is_same_v<D0DataType, half_t> && is_same_v<EDataType, half_t>)
@@ -104,10 +209,32 @@ struct DeviceOperationInstanceFactory<
}
#endif
#endif // CK_USE_XDL
#if defined(CK_USE_WMMA)
// Reuse DeviceGemmMultipleDSplitK instances
using Wrapper = DeviceGemmMultipleDSplitKWrapper<ALayout,
BLayout,
ck::Tuple<D0Layout>,
ELayout,
ADataType,
BDataType,
ck::Tuple<D0DataType>,
EDataType,
PassThrough,
PassThrough,
Add>;
auto new_op_ptrs =
DeviceOperationInstanceFactory<typename Wrapper::DeviceOp>::GetInstances();
for(auto& op_ptr : new_op_ptrs)
{
op_ptrs.emplace_back(std::make_unique<Wrapper>(std::move(op_ptr)));
}
#endif // CK_USE_WMMA
return op_ptrs;
}
};
} // namespace instance
} // namespace device
} // namespace tensor_operation

View File

@@ -1,5 +1,5 @@
// SPDX-License-Identifier: MIT
// Copyright (c) 2018-2023, Advanced Micro Devices, Inc. All rights reserved.
// Copyright (c) 2018-2025, Advanced Micro Devices, Inc. All rights reserved.
#pragma once
@@ -11,11 +11,13 @@
#include "ck/library/tensor_operation_instance/device_operation_instance_factory.hpp"
#if defined(CK_ENABLE_FP16)
namespace ck {
namespace tensor_operation {
namespace device {
namespace instance {
#if defined(CK_USE_XDL)
void add_device_gemm_add_add_fastgelu_xdl_c_shuffle_f16_f16_f16_f16_f16_mk_kn_mn_mn_mn_instances(
std::vector<std::unique_ptr<DeviceGemmMultipleD<Row,
Row,
@@ -67,8 +69,64 @@ void add_device_gemm_add_add_fastgelu_xdl_c_shuffle_f16_f16_f16_f16_f16_km_nk_mn
PassThrough,
PassThrough,
AddAddFastGelu>>>&);
#endif // CK_USE_XDL
// GEMM + Add + Add + FastGelu
#if defined(CK_USE_WMMA)
void add_device_gemm_add_add_fastgelu_wmma_c_shuffle_f16_f16_f16_f16_f16_mk_kn_mn_mn_mn_instances(
std::vector<std::unique_ptr<DeviceGemmMultipleDSplitK<Row,
Row,
Row_Row_Tuple,
Row,
F16,
F16,
F16_F16_Tuple,
F16,
PassThrough,
PassThrough,
AddAddFastGelu>>>&);
void add_device_gemm_add_add_fastgelu_wmma_c_shuffle_f16_f16_f16_f16_f16_mk_nk_mn_mn_mn_instances(
std::vector<std::unique_ptr<DeviceGemmMultipleDSplitK<Row,
Col,
Row_Row_Tuple,
Row,
F16,
F16,
F16_F16_Tuple,
F16,
PassThrough,
PassThrough,
AddAddFastGelu>>>&);
void add_device_gemm_add_add_fastgelu_wmma_c_shuffle_f16_f16_f16_f16_f16_km_kn_mn_mn_mn_instances(
std::vector<std::unique_ptr<DeviceGemmMultipleDSplitK<Col,
Row,
Row_Row_Tuple,
Row,
F16,
F16,
F16_F16_Tuple,
F16,
PassThrough,
PassThrough,
AddAddFastGelu>>>&);
void add_device_gemm_add_add_fastgelu_wmma_c_shuffle_f16_f16_f16_f16_f16_km_nk_mn_mn_mn_instances(
std::vector<std::unique_ptr<DeviceGemmMultipleDSplitK<Col,
Col,
Row_Row_Tuple,
Row,
F16,
F16,
F16_F16_Tuple,
F16,
PassThrough,
PassThrough,
AddAddFastGelu>>>&);
#endif // CK_USE_WMMA
// GEMM + Add + FastGelu
// DeviceGemmMultipleDSplitK specialization
template <typename ALayout,
typename BLayout,
typename D0Layout,
@@ -79,18 +137,100 @@ template <typename ALayout,
typename D0DataType,
typename D1DataType,
typename EDataType>
struct DeviceOperationInstanceFactory<ck::tensor_operation::device::DeviceGemmMultipleD<
ALayout,
BLayout,
ck::Tuple<D0Layout, D1Layout>,
ELayout,
ADataType,
BDataType,
ck::Tuple<D0DataType, D1DataType>,
EDataType,
ck::tensor_operation::element_wise::PassThrough,
ck::tensor_operation::element_wise::PassThrough,
ck::tensor_operation::element_wise::AddAddFastGelu>>
struct DeviceOperationInstanceFactory<DeviceGemmMultipleDSplitK<ALayout,
BLayout,
ck::Tuple<D0Layout, D1Layout>,
ELayout,
ADataType,
BDataType,
ck::Tuple<D0DataType, D1DataType>,
EDataType,
PassThrough,
PassThrough,
AddAddFastGelu>>
{
using DeviceOp = DeviceGemmMultipleDSplitK<ALayout,
BLayout,
ck::Tuple<D0Layout, D1Layout>,
ELayout,
ADataType,
BDataType,
ck::Tuple<D0DataType, D1DataType>,
EDataType,
PassThrough,
PassThrough,
AddAddFastGelu>;
static auto GetInstances()
{
std::vector<std::unique_ptr<DeviceOp>> op_ptrs;
#if defined(CK_USE_XDL)
// No XDL instances for DeviceGemmMultipleDSplitK with AddFastGelu at the moment
#endif // CK_USE_XDL
#if defined(CK_USE_WMMA)
constexpr bool IsAllDRowLayout = is_same_v<D0Layout, Row> && is_same_v<D1Layout, Row>;
constexpr bool IsAllDFloat16 =
is_same_v<D0DataType, half_t> && is_same_v<D1DataType, half_t>;
if constexpr(is_same_v<ADataType, half_t> && is_same_v<BDataType, half_t> &&
is_same_v<EDataType, half_t> && IsAllDRowLayout && IsAllDFloat16)
{
if constexpr(is_same_v<ALayout, Row> && is_same_v<BLayout, Row> &&
is_same_v<ELayout, Row>)
{
add_device_gemm_add_add_fastgelu_wmma_c_shuffle_f16_f16_f16_f16_f16_mk_kn_mn_mn_mn_instances(
op_ptrs);
}
else if constexpr(is_same_v<ALayout, Row> && is_same_v<BLayout, Col> &&
is_same_v<ELayout, Row>)
{
add_device_gemm_add_add_fastgelu_wmma_c_shuffle_f16_f16_f16_f16_f16_mk_nk_mn_mn_mn_instances(
op_ptrs);
}
else if constexpr(is_same_v<ALayout, Col> && is_same_v<BLayout, Row> &&
is_same_v<ELayout, Row>)
{
add_device_gemm_add_add_fastgelu_wmma_c_shuffle_f16_f16_f16_f16_f16_km_kn_mn_mn_mn_instances(
op_ptrs);
}
else if constexpr(is_same_v<ALayout, Col> && is_same_v<BLayout, Col> &&
is_same_v<ELayout, Row>)
{
add_device_gemm_add_add_fastgelu_wmma_c_shuffle_f16_f16_f16_f16_f16_km_nk_mn_mn_mn_instances(
op_ptrs);
}
}
#endif // CK_USE_WMMA
return op_ptrs;
}
};
// GEMM + Add + Add + FastGelu
// DeviceGemmMultipleD specialization
template <typename ALayout,
typename BLayout,
typename D0Layout,
typename D1Layout,
typename ELayout,
typename ADataType,
typename BDataType,
typename D0DataType,
typename D1DataType,
typename EDataType>
struct DeviceOperationInstanceFactory<DeviceGemmMultipleD<ALayout,
BLayout,
ck::Tuple<D0Layout, D1Layout>,
ELayout,
ADataType,
BDataType,
ck::Tuple<D0DataType, D1DataType>,
EDataType,
PassThrough,
PassThrough,
AddAddFastGelu>>
{
using DeviceOp = DeviceGemmMultipleD<ALayout,
BLayout,
@@ -100,47 +240,69 @@ struct DeviceOperationInstanceFactory<ck::tensor_operation::device::DeviceGemmMu
BDataType,
ck::Tuple<D0DataType, D1DataType>,
EDataType,
ck::tensor_operation::element_wise::PassThrough,
ck::tensor_operation::element_wise::PassThrough,
ck::tensor_operation::element_wise::AddAddFastGelu>;
PassThrough,
PassThrough,
AddAddFastGelu>;
static auto GetInstances()
{
std::vector<std::unique_ptr<DeviceOp>> op_ptrs;
#if defined(CK_USE_XDL)
constexpr bool IsAllDRowLayout = is_same_v<D0Layout, Row> && is_same_v<D1Layout, Row>;
constexpr bool IsAllDFloat16 =
is_same_v<D0DataType, half_t> && is_same_v<D1DataType, half_t>;
if constexpr(is_same_v<ADataType, half_t> && is_same_v<BDataType, half_t> &&
is_same_v<D0DataType, half_t> && is_same_v<D1DataType, half_t> &&
is_same_v<EDataType, half_t>)
is_same_v<EDataType, half_t> && IsAllDRowLayout && IsAllDFloat16)
{
if constexpr(is_same_v<ALayout, Row> && is_same_v<BLayout, Row> &&
is_same_v<D0Layout, Row> && is_same_v<D1Layout, Row> &&
is_same_v<ELayout, Row>)
{
add_device_gemm_add_add_fastgelu_xdl_c_shuffle_f16_f16_f16_f16_f16_mk_kn_mn_mn_mn_instances(
op_ptrs);
}
else if constexpr(is_same_v<ALayout, Row> && is_same_v<BLayout, Col> &&
is_same_v<D0Layout, Row> && is_same_v<D1Layout, Row> &&
is_same_v<ELayout, Row>)
{
add_device_gemm_add_add_fastgelu_xdl_c_shuffle_f16_f16_f16_f16_f16_mk_nk_mn_mn_mn_instances(
op_ptrs);
}
else if constexpr(is_same_v<ALayout, Col> && is_same_v<BLayout, Row> &&
is_same_v<D0Layout, Row> && is_same_v<D1Layout, Row> &&
is_same_v<ELayout, Row>)
{
add_device_gemm_add_add_fastgelu_xdl_c_shuffle_f16_f16_f16_f16_f16_km_kn_mn_mn_mn_instances(
op_ptrs);
}
else if constexpr(is_same_v<ALayout, Col> && is_same_v<BLayout, Col> &&
is_same_v<D0Layout, Row> && is_same_v<D1Layout, Row> &&
is_same_v<ELayout, Row>)
{
add_device_gemm_add_add_fastgelu_xdl_c_shuffle_f16_f16_f16_f16_f16_km_nk_mn_mn_mn_instances(
op_ptrs);
}
}
#endif // CK_USE_XDL
#if defined(CK_USE_WMMA)
// Reuse DeviceGemmMultipleDSplitK instances
using Wrapper = DeviceGemmMultipleDSplitKWrapper<ALayout,
BLayout,
ck::Tuple<D0Layout, D1Layout>,
ELayout,
ADataType,
BDataType,
ck::Tuple<D0DataType, D1DataType>,
EDataType,
PassThrough,
PassThrough,
AddAddFastGelu>;
auto new_op_ptrs =
DeviceOperationInstanceFactory<typename Wrapper::DeviceOp>::GetInstances();
for(auto& op_ptr : new_op_ptrs)
{
op_ptrs.emplace_back(std::make_unique<Wrapper>(std::move(op_ptr)));
}
#endif // CK_USE_WMMA
return op_ptrs;
}
@@ -150,3 +312,4 @@ struct DeviceOperationInstanceFactory<ck::tensor_operation::device::DeviceGemmMu
} // namespace device
} // namespace tensor_operation
} // namespace ck
#endif // CK_ENABLE_FP16

View File

@@ -1,5 +1,5 @@
// SPDX-License-Identifier: MIT
// Copyright (c) 2018-2024, Advanced Micro Devices, Inc. All rights reserved.
// Copyright (c) 2018-2025, Advanced Micro Devices, Inc. All rights reserved.
#pragma once
@@ -16,6 +16,7 @@ namespace tensor_operation {
namespace device {
namespace instance {
#if defined(CK_USE_XDL)
void add_device_gemm_add_fastgelu_xdl_c_shuffle_f16_f16_f16_f16_mk_kn_mn_mn_instances(
std::vector<std::unique_ptr<DeviceGemmMultipleD<Row,
Row,
@@ -93,8 +94,64 @@ void add_device_gemm_add_fastgelu_xdl_c_shuffle_bf16_i8_bf16_bf16_mk_kn_mn_mn_in
PassThrough,
PassThrough,
AddFastGelu>>>&);
#endif // CK_USE_XDL
#if defined(CK_USE_WMMA)
void add_device_gemm_add_fastgelu_wmma_c_shuffle_f16_f16_f16_f16_mk_kn_mn_mn_instances(
std::vector<std::unique_ptr<DeviceGemmMultipleDSplitK<Row,
Row,
Row_Tuple,
Row,
F16,
F16,
F16_Tuple,
F16,
PassThrough,
PassThrough,
AddFastGelu>>>&);
void add_device_gemm_add_fastgelu_wmma_c_shuffle_f16_f16_f16_f16_mk_nk_mn_mn_instances(
std::vector<std::unique_ptr<DeviceGemmMultipleDSplitK<Row,
Col,
Row_Tuple,
Row,
F16,
F16,
F16_Tuple,
F16,
PassThrough,
PassThrough,
AddFastGelu>>>&);
void add_device_gemm_add_fastgelu_wmma_c_shuffle_f16_f16_f16_f16_km_kn_mn_mn_instances(
std::vector<std::unique_ptr<DeviceGemmMultipleDSplitK<Col,
Row,
Row_Tuple,
Row,
F16,
F16,
F16_Tuple,
F16,
PassThrough,
PassThrough,
AddFastGelu>>>&);
void add_device_gemm_add_fastgelu_wmma_c_shuffle_f16_f16_f16_f16_km_nk_mn_mn_instances(
std::vector<std::unique_ptr<DeviceGemmMultipleDSplitK<Col,
Col,
Row_Tuple,
Row,
F16,
F16,
F16_Tuple,
F16,
PassThrough,
PassThrough,
AddFastGelu>>>&);
#endif // CK_USE_WMMA
// GEMM + Add + FastGelu
// DeviceGemmMultipleDSplitK specialization
template <typename ALayout,
typename BLayout,
typename D0Layout,
@@ -103,18 +160,97 @@ template <typename ALayout,
typename BDataType,
typename D0DataType,
typename EDataType>
struct DeviceOperationInstanceFactory<
ck::tensor_operation::device::DeviceGemmMultipleD<ALayout,
BLayout,
ck::Tuple<D0Layout>,
ELayout,
ADataType,
BDataType,
ck::Tuple<D0DataType>,
EDataType,
PassThrough,
PassThrough,
AddFastGelu>>
struct DeviceOperationInstanceFactory<DeviceGemmMultipleDSplitK<ALayout,
BLayout,
ck::Tuple<D0Layout>,
ELayout,
ADataType,
BDataType,
ck::Tuple<D0DataType>,
EDataType,
PassThrough,
PassThrough,
AddFastGelu>>
{
using DeviceOp = DeviceGemmMultipleDSplitK<ALayout,
BLayout,
ck::Tuple<D0Layout>,
ELayout,
ADataType,
BDataType,
ck::Tuple<D0DataType>,
EDataType,
PassThrough,
PassThrough,
AddFastGelu>;
static auto GetInstances()
{
std::vector<std::unique_ptr<DeviceOp>> op_ptrs;
#if defined(CK_USE_XDL)
// No XDL instances for DeviceGemmMultipleDSplitK with AddFastGelu at the moment
#endif // CK_USE_XDL
#if defined(CK_USE_WMMA)
#if defined(CK_ENABLE_FP16)
if constexpr(is_same_v<ADataType, half_t> && is_same_v<BDataType, half_t> &&
is_same_v<D0DataType, half_t> && is_same_v<EDataType, half_t>)
{
if constexpr(is_same_v<ALayout, Row> && is_same_v<BLayout, Row> &&
is_same_v<D0Layout, Row> && is_same_v<ELayout, Row>)
{
add_device_gemm_add_fastgelu_wmma_c_shuffle_f16_f16_f16_f16_mk_kn_mn_mn_instances(
op_ptrs);
}
else if constexpr(is_same_v<ALayout, Row> && is_same_v<BLayout, Col> &&
is_same_v<D0Layout, Row> && is_same_v<ELayout, Row>)
{
add_device_gemm_add_fastgelu_wmma_c_shuffle_f16_f16_f16_f16_mk_nk_mn_mn_instances(
op_ptrs);
}
else if constexpr(is_same_v<ALayout, Col> && is_same_v<BLayout, Row> &&
is_same_v<D0Layout, Row> && is_same_v<ELayout, Row>)
{
add_device_gemm_add_fastgelu_wmma_c_shuffle_f16_f16_f16_f16_km_kn_mn_mn_instances(
op_ptrs);
}
else if constexpr(is_same_v<ALayout, Col> && is_same_v<BLayout, Col> &&
is_same_v<D0Layout, Row> && is_same_v<ELayout, Row>)
{
add_device_gemm_add_fastgelu_wmma_c_shuffle_f16_f16_f16_f16_km_nk_mn_mn_instances(
op_ptrs);
}
}
#endif // CK_ENABLE_FP16
#endif // CK_USE_WMMA
return op_ptrs;
}
};
// GEMM + Add + FastGelu
// DeviceGemmMultipleD specialization
template <typename ALayout,
typename BLayout,
typename D0Layout,
typename ELayout,
typename ADataType,
typename BDataType,
typename D0DataType,
typename EDataType>
struct DeviceOperationInstanceFactory<DeviceGemmMultipleD<ALayout,
BLayout,
ck::Tuple<D0Layout>,
ELayout,
ADataType,
BDataType,
ck::Tuple<D0DataType>,
EDataType,
PassThrough,
PassThrough,
AddFastGelu>>
{
using DeviceOp = DeviceGemmMultipleD<ALayout,
BLayout,
@@ -132,7 +268,8 @@ struct DeviceOperationInstanceFactory<
{
std::vector<std::unique_ptr<DeviceOp>> op_ptrs;
#if defined(CK_ENABLE_INT8) && defined(CK_ENABLE_FP16)
#if defined(CK_USE_XDL)
#if defined(CK_ENABLE_FP16) && defined(CK_ENABLE_INT8)
if constexpr(is_same_v<ADataType, half_t> && is_same_v<BDataType, int8_t> &&
is_same_v<D0DataType, half_t> && is_same_v<EDataType, half_t>)
{
@@ -143,7 +280,7 @@ struct DeviceOperationInstanceFactory<
op_ptrs);
}
}
#endif
#endif // CK_ENABLE_FP16 && CK_ENABLE_INT8
#if defined(CK_ENABLE_BF16) && defined(CK_ENABLE_INT8)
if constexpr(is_same_v<ADataType, bhalf_t> && is_same_v<BDataType, int8_t> &&
@@ -156,8 +293,9 @@ struct DeviceOperationInstanceFactory<
op_ptrs);
}
}
#endif
#endif // CK_ENABLE_BF16 && CK_ENABLE_INT8
#if defined(CK_ENABLE_FP16)
if constexpr(is_same_v<ADataType, half_t> && is_same_v<BDataType, half_t> &&
is_same_v<D0DataType, half_t> && is_same_v<EDataType, half_t>)
{
@@ -186,6 +324,29 @@ struct DeviceOperationInstanceFactory<
op_ptrs);
}
}
#endif // CK_ENABLE_FP16
#endif // CK_USE_XDL
#if defined(CK_USE_WMMA)
// Reuse DeviceGemmMultipleDSplitK instances
using Wrapper = DeviceGemmMultipleDSplitKWrapper<ALayout,
BLayout,
ck::Tuple<D0Layout>,
ELayout,
ADataType,
BDataType,
ck::Tuple<D0DataType>,
EDataType,
PassThrough,
PassThrough,
AddFastGelu>;
auto new_op_ptrs =
DeviceOperationInstanceFactory<typename Wrapper::DeviceOp>::GetInstances();
for(auto& op_ptr : new_op_ptrs)
{
op_ptrs.emplace_back(std::make_unique<Wrapper>(std::move(op_ptr)));
}
#endif // CK_USE_WMMA
return op_ptrs;
}

View File

@@ -1,5 +1,5 @@
// SPDX-License-Identifier: MIT
// Copyright (c) 2018-2023, Advanced Micro Devices, Inc. All rights reserved.
// Copyright (c) 2018-2025, Advanced Micro Devices, Inc. All rights reserved.
#pragma once
@@ -19,6 +19,7 @@ namespace tensor_operation {
namespace device {
namespace instance {
#if defined(CK_USE_XDL)
void add_device_gemm_add_multiply_xdl_c_shuffle_f16_f16_f16_f16_f16_mk_kn_mn_mn_mn_instances(
std::vector<std::unique_ptr<DeviceGemmMultipleD<Row,
Row,
@@ -70,6 +71,145 @@ void add_device_gemm_add_multiply_xdl_c_shuffle_f16_f16_f16_f16_f16_km_nk_mn_mn_
PassThrough,
PassThrough,
AddMultiply>>>&);
#endif
#if defined(CK_USE_WMMA)
void add_device_gemm_add_multiply_wmma_c_shuffle_f16_f16_f16_f16_f16_mk_kn_mn_mn_mn_instances(
std::vector<std::unique_ptr<DeviceGemmMultipleDSplitK<Row,
Row,
Row_Row_Tuple,
Row,
F16,
F16,
F16_F16_Tuple,
F16,
PassThrough,
PassThrough,
AddMultiply>>>&);
void add_device_gemm_add_multiply_wmma_c_shuffle_f16_f16_f16_f16_f16_mk_nk_mn_mn_mn_instances(
std::vector<std::unique_ptr<DeviceGemmMultipleDSplitK<Row,
Col,
Row_Row_Tuple,
Row,
F16,
F16,
F16_F16_Tuple,
F16,
PassThrough,
PassThrough,
AddMultiply>>>&);
void add_device_gemm_add_multiply_wmma_c_shuffle_f16_f16_f16_f16_f16_km_kn_mn_mn_mn_instances(
std::vector<std::unique_ptr<DeviceGemmMultipleDSplitK<Col,
Row,
Row_Row_Tuple,
Row,
F16,
F16,
F16_F16_Tuple,
F16,
PassThrough,
PassThrough,
AddMultiply>>>&);
void add_device_gemm_add_multiply_wmma_c_shuffle_f16_f16_f16_f16_f16_km_nk_mn_mn_mn_instances(
std::vector<std::unique_ptr<DeviceGemmMultipleDSplitK<Col,
Col,
Row_Row_Tuple,
Row,
F16,
F16,
F16_F16_Tuple,
F16,
PassThrough,
PassThrough,
AddMultiply>>>&);
#endif
// GEMM + Add + Multiply
template <typename ALayout,
typename BLayout,
typename D0Layout,
typename D1Layout,
typename ELayout,
typename ADataType,
typename BDataType,
typename D0DataType,
typename D1DataType,
typename EDataType>
struct DeviceOperationInstanceFactory<ck::tensor_operation::device::DeviceGemmMultipleDSplitK<
ALayout,
BLayout,
ck::Tuple<D0Layout, D1Layout>,
ELayout,
ADataType,
BDataType,
ck::Tuple<D0DataType, D1DataType>,
EDataType,
ck::tensor_operation::element_wise::PassThrough,
ck::tensor_operation::element_wise::PassThrough,
ck::tensor_operation::element_wise::AddMultiply>>
{
using DeviceOp = DeviceGemmMultipleDSplitK<ALayout,
BLayout,
ck::Tuple<D0Layout, D1Layout>,
ELayout,
ADataType,
BDataType,
ck::Tuple<D0DataType, D1DataType>,
EDataType,
ck::tensor_operation::element_wise::PassThrough,
ck::tensor_operation::element_wise::PassThrough,
ck::tensor_operation::element_wise::AddMultiply>;
static auto GetInstances()
{
std::vector<std::unique_ptr<DeviceOp>> op_ptrs;
#ifdef CK_USE_XDL
#endif
#if defined(CK_USE_WMMA)
if constexpr(is_same_v<ADataType, half_t> && is_same_v<BDataType, half_t> &&
is_same_v<D0DataType, half_t> && is_same_v<D1DataType, half_t> &&
is_same_v<EDataType, half_t>)
{
if constexpr(is_same_v<ALayout, Row> && is_same_v<BLayout, Row> &&
is_same_v<D0Layout, Row> && is_same_v<D1Layout, Row> &&
is_same_v<ELayout, Row>)
{
add_device_gemm_add_multiply_wmma_c_shuffle_f16_f16_f16_f16_f16_mk_kn_mn_mn_mn_instances(
op_ptrs);
}
else if constexpr(is_same_v<ALayout, Row> && is_same_v<BLayout, Col> &&
is_same_v<D0Layout, Row> && is_same_v<D1Layout, Row> &&
is_same_v<ELayout, Row>)
{
add_device_gemm_add_multiply_wmma_c_shuffle_f16_f16_f16_f16_f16_mk_nk_mn_mn_mn_instances(
op_ptrs);
}
else if constexpr(is_same_v<ALayout, Col> && is_same_v<BLayout, Row> &&
is_same_v<D0Layout, Row> && is_same_v<D1Layout, Row> &&
is_same_v<ELayout, Row>)
{
add_device_gemm_add_multiply_wmma_c_shuffle_f16_f16_f16_f16_f16_km_kn_mn_mn_mn_instances(
op_ptrs);
}
else if constexpr(is_same_v<ALayout, Col> && is_same_v<BLayout, Col> &&
is_same_v<D0Layout, Row> && is_same_v<D1Layout, Row> &&
is_same_v<ELayout, Row>)
{
add_device_gemm_add_multiply_wmma_c_shuffle_f16_f16_f16_f16_f16_km_nk_mn_mn_mn_instances(
op_ptrs);
}
}
#endif
return op_ptrs;
}
};
// GEMM + Add + Multiply
template <typename ALayout,
@@ -111,6 +251,7 @@ struct DeviceOperationInstanceFactory<ck::tensor_operation::device::DeviceGemmMu
{
std::vector<std::unique_ptr<DeviceOp>> op_ptrs;
#ifdef CK_USE_XDL
if constexpr(is_same_v<ADataType, half_t> && is_same_v<BDataType, half_t> &&
is_same_v<D0DataType, half_t> && is_same_v<D1DataType, half_t> &&
is_same_v<EDataType, half_t>)
@@ -144,6 +285,27 @@ struct DeviceOperationInstanceFactory<ck::tensor_operation::device::DeviceGemmMu
op_ptrs);
}
}
#endif
#if defined(CK_USE_WMMA)
// Reuse DeviceGemmMultipleDSplitK instances
using Wrapper = DeviceGemmMultipleDSplitKWrapper<ALayout,
BLayout,
ck::Tuple<D0Layout, D1Layout>,
ELayout,
ADataType,
BDataType,
ck::Tuple<D0DataType, D1DataType>,
EDataType,
PassThrough,
PassThrough,
AddMultiply>;
auto new_op_ptrs =
DeviceOperationInstanceFactory<typename Wrapper::DeviceOp>::GetInstances();
for(auto& op_ptr : new_op_ptrs)
{
op_ptrs.emplace_back(std::make_unique<Wrapper>(std::move(op_ptr)));
}
#endif // CK_USE_WMMA
return op_ptrs;
}

View File

@@ -1,5 +1,5 @@
// SPDX-License-Identifier: MIT
// Copyright (c) 2024, Advanced Micro Devices, Inc. All rights reserved.
// Copyright (c) 2024-2025, Advanced Micro Devices, Inc. All rights reserved.
#pragma once
@@ -16,6 +16,7 @@ namespace tensor_operation {
namespace device {
namespace instance {
#if defined(CK_USE_XDL)
void add_device_gemm_add_relu_xdl_c_shuffle_f16_i8_f16_f16_mk_kn_mn_mn_instances(
std::vector<std::unique_ptr<DeviceGemmMultipleD<Row,
Row,
@@ -41,6 +42,35 @@ void add_device_gemm_add_relu_xdl_c_shuffle_bf16_i8_bf16_bf16_mk_kn_mn_mn_instan
PassThrough,
PassThrough,
AddRelu>>>&);
#endif
#if defined(CK_USE_WMMA)
void add_device_gemm_add_relu_wmma_c_shuffle_f16_f16_f16_f16_mk_kn_mn_mn_instances(
std::vector<std::unique_ptr<DeviceGemmMultipleDSplitK<Row,
Row,
Row_Tuple,
Row,
F16,
F16,
F16_Tuple,
F16,
PassThrough,
PassThrough,
AddRelu>>>&);
void add_device_gemm_add_relu_wmma_c_shuffle_bf16_bf16_bf16_bf16_mk_kn_mn_mn_instances(
std::vector<std::unique_ptr<DeviceGemmMultipleDSplitK<Row,
Row,
Row_Tuple,
Row,
BF16,
BF16,
BF16_Tuple,
BF16,
PassThrough,
PassThrough,
AddRelu>>>&);
#endif
// GEMM + Add + Relu
template <typename ALayout,
@@ -52,17 +82,92 @@ template <typename ALayout,
typename D0DataType,
typename EDataType>
struct DeviceOperationInstanceFactory<
ck::tensor_operation::device::DeviceGemmMultipleD<ALayout,
BLayout,
ck::Tuple<D0Layout>,
ELayout,
ADataType,
BDataType,
ck::Tuple<D0DataType>,
EDataType,
PassThrough,
PassThrough,
AddRelu>>
ck::tensor_operation::device::DeviceGemmMultipleDSplitK<ALayout,
BLayout,
ck::Tuple<D0Layout>,
ELayout,
ADataType,
BDataType,
ck::Tuple<D0DataType>,
EDataType,
PassThrough,
PassThrough,
AddRelu>>
{
using DeviceOp = DeviceGemmMultipleDSplitK<ALayout,
BLayout,
ck::Tuple<D0Layout>,
ELayout,
ADataType,
BDataType,
ck::Tuple<D0DataType>,
EDataType,
PassThrough,
PassThrough,
AddRelu>;
static auto GetInstances()
{
std::vector<std::unique_ptr<DeviceOp>> op_ptrs;
#if defined(CK_USE_XDL)
// No XDL instances for DeviceGemmMultipleDSplitK with AddRelu at the moment
#endif // CK_USE_XDL
#if defined(CK_USE_WMMA)
#if defined(CK_ENABLE_FP16)
if constexpr(is_same_v<ADataType, half_t> && is_same_v<BDataType, half_t> &&
is_same_v<D0DataType, half_t> && is_same_v<EDataType, half_t>)
{
if constexpr(is_same_v<ALayout, Row> && is_same_v<BLayout, Row> &&
is_same_v<D0Layout, Row> && is_same_v<ELayout, Row>)
{
add_device_gemm_add_relu_wmma_c_shuffle_f16_f16_f16_f16_mk_kn_mn_mn_instances(
op_ptrs);
}
}
#endif
#if defined(CK_ENABLE_BF16)
if constexpr(is_same_v<ADataType, ck::bhalf_t> && is_same_v<BDataType, ck::bhalf_t> &&
is_same_v<D0DataType, ck::bhalf_t> && is_same_v<EDataType, ck::bhalf_t>)
{
if constexpr(is_same_v<ALayout, Row> && is_same_v<BLayout, Row> &&
is_same_v<D0Layout, Row> && is_same_v<ELayout, Row>)
{
add_device_gemm_add_relu_wmma_c_shuffle_bf16_bf16_bf16_bf16_mk_kn_mn_mn_instances(
op_ptrs);
}
}
#endif
#endif
return op_ptrs;
}
};
// GEMM + Add + Relu
// DeviceGemmMultipleD specialization
template <typename ALayout,
typename BLayout,
typename D0Layout,
typename ELayout,
typename ADataType,
typename BDataType,
typename D0DataType,
typename EDataType>
struct DeviceOperationInstanceFactory<DeviceGemmMultipleD<ALayout,
BLayout,
ck::Tuple<D0Layout>,
ELayout,
ADataType,
BDataType,
ck::Tuple<D0DataType>,
EDataType,
PassThrough,
PassThrough,
AddRelu>>
{
using DeviceOp = DeviceGemmMultipleD<ALayout,
BLayout,
@@ -80,6 +185,7 @@ struct DeviceOperationInstanceFactory<
{
std::vector<std::unique_ptr<DeviceOp>> op_ptrs;
#ifdef CK_USE_XDL
#if defined(CK_ENABLE_INT8) && defined(CK_ENABLE_FP16)
if constexpr(is_same_v<ADataType, half_t> && is_same_v<BDataType, int8_t> &&
is_same_v<D0DataType, half_t> && is_same_v<EDataType, half_t>)
@@ -106,10 +212,32 @@ struct DeviceOperationInstanceFactory<
}
#endif
#endif // CK_USE_XDL
#if defined(CK_USE_WMMA)
// Reuse DeviceGemmMultipleDSplitK instances
using Wrapper = DeviceGemmMultipleDSplitKWrapper<ALayout,
BLayout,
ck::Tuple<D0Layout>,
ELayout,
ADataType,
BDataType,
ck::Tuple<D0DataType>,
EDataType,
PassThrough,
PassThrough,
AddRelu>;
auto new_op_ptrs =
DeviceOperationInstanceFactory<typename Wrapper::DeviceOp>::GetInstances();
for(auto& op_ptr : new_op_ptrs)
{
op_ptrs.emplace_back(std::make_unique<Wrapper>(std::move(op_ptr)));
}
#endif // CK_USE_WMMA
return op_ptrs;
}
};
} // namespace instance
} // namespace device
} // namespace tensor_operation

View File

@@ -1,5 +1,5 @@
// SPDX-License-Identifier: MIT
// Copyright (c) 2024, Advanced Micro Devices, Inc. All rights reserved.
// Copyright (c) 2024-2025, Advanced Micro Devices, Inc. All rights reserved.
#pragma once
@@ -16,6 +16,7 @@ namespace tensor_operation {
namespace device {
namespace instance {
#if defined(CK_USE_XDL)
void add_device_gemm_add_silu_xdl_c_shuffle_f16_i8_f16_f16_mk_kn_mn_mn_instances(
std::vector<std::unique_ptr<DeviceGemmMultipleD<Row,
Row,
@@ -41,6 +42,107 @@ void add_device_gemm_add_silu_xdl_c_shuffle_bf16_i8_bf16_bf16_mk_kn_mn_mn_instan
PassThrough,
PassThrough,
AddSilu>>>&);
#endif // CK_USE_XDL
#if defined(CK_USE_WMMA)
void add_device_gemm_add_silu_wmma_c_shuffle_f16_f16_f16_f16_mk_kn_mn_mn_instances(
std::vector<std::unique_ptr<DeviceGemmMultipleDSplitK<Row,
Row,
Row_Tuple,
Row,
F16,
F16,
F16_Tuple,
F16,
PassThrough,
PassThrough,
AddSilu>>>&);
void add_device_gemm_add_silu_wmma_c_shuffle_bf16_bf16_bf16_bf16_mk_kn_mn_mn_instances(
std::vector<std::unique_ptr<DeviceGemmMultipleDSplitK<Row,
Row,
Row_Tuple,
Row,
BF16,
BF16,
BF16_Tuple,
BF16,
PassThrough,
PassThrough,
AddSilu>>>&);
#endif // CK_USE_WMMA
// GEMM + Add + Silu
template <typename ALayout,
typename BLayout,
typename D0Layout,
typename ELayout,
typename ADataType,
typename BDataType,
typename D0DataType,
typename EDataType>
struct DeviceOperationInstanceFactory<
ck::tensor_operation::device::DeviceGemmMultipleDSplitK<ALayout,
BLayout,
ck::Tuple<D0Layout>,
ELayout,
ADataType,
BDataType,
ck::Tuple<D0DataType>,
EDataType,
PassThrough,
PassThrough,
AddSilu>>
{
using DeviceOp = DeviceGemmMultipleDSplitK<ALayout,
BLayout,
ck::Tuple<D0Layout>,
ELayout,
ADataType,
BDataType,
ck::Tuple<D0DataType>,
EDataType,
PassThrough,
PassThrough,
AddSilu>;
static auto GetInstances()
{
std::vector<std::unique_ptr<DeviceOp>> op_ptrs;
#if defined(CK_USE_XDL)
// no split-k xdl implementations
#endif // CL_USE_XDL
#if defined(CK_USE_WMMA)
#if defined(CK_ENABLE_FP16)
if constexpr(is_same_v<ADataType, half_t> && is_same_v<BDataType, half_t> &&
is_same_v<D0DataType, half_t> && is_same_v<EDataType, half_t>)
{
if constexpr(is_same_v<ALayout, Row> && is_same_v<BLayout, Row> &&
is_same_v<D0Layout, Row> && is_same_v<ELayout, Row>)
{
add_device_gemm_add_silu_wmma_c_shuffle_f16_f16_f16_f16_mk_kn_mn_mn_instances(
op_ptrs);
}
}
#endif // CK_ENABLE_FP16
#if defined(CK_ENABLE_BF16)
if constexpr(is_same_v<ADataType, ck::bhalf_t> && is_same_v<BDataType, ck::bhalf_t> &&
is_same_v<D0DataType, ck::bhalf_t> && is_same_v<EDataType, ck::bhalf_t>)
{
if constexpr(is_same_v<ALayout, Row> && is_same_v<BLayout, Row> &&
is_same_v<D0Layout, Row> && is_same_v<ELayout, Row>)
{
add_device_gemm_add_silu_wmma_c_shuffle_bf16_bf16_bf16_bf16_mk_kn_mn_mn_instances(
op_ptrs);
}
}
#endif
#endif // CK_USE_WMMA
return op_ptrs;
}
};
// GEMM + Add + Silu
template <typename ALayout,
@@ -78,8 +180,11 @@ struct DeviceOperationInstanceFactory<
static auto GetInstances()
{
std::vector<std::unique_ptr<DeviceOp>> op_ptrs;
#if defined(CK_USE_XDL)
#if defined(CK_ENABLE_INT8) && defined(CK_ENABLE_FP16)
if constexpr(is_same_v<ADataType, half_t> && is_same_v<BDataType, int8_t> &&
is_same_v<D0DataType, half_t> && is_same_v<EDataType, half_t>)
@@ -105,7 +210,28 @@ struct DeviceOperationInstanceFactory<
}
}
#endif
#endif // CL_USE_XDL
#if defined(CK_USE_WMMA)
// Reuse DeviceGemmMultipleDSplitK instances
using Wrapper = DeviceGemmMultipleDSplitKWrapper<ALayout,
BLayout,
ck::Tuple<D0Layout>,
ELayout,
ADataType,
BDataType,
ck::Tuple<D0DataType>,
EDataType,
PassThrough,
PassThrough,
AddSilu>;
auto new_op_ptrs =
DeviceOperationInstanceFactory<typename Wrapper::DeviceOp>::GetInstances();
for(auto& op_ptr : new_op_ptrs)
{
op_ptrs.emplace_back(std::make_unique<Wrapper>(std::move(op_ptr)));
}
#endif // CK_USE_WMMA
return op_ptrs;
}
};

View File

@@ -16,7 +16,8 @@ namespace ck {
namespace tensor_operation {
namespace device {
namespace instance {
#if defined(CK_ENABLE_FP16) && defined(CK_USE_XDL)
#if defined(CK_USE_XDL)
#if defined(CK_ENABLE_FP16)
void add_device_gemm_bilinear_xdl_c_shuffle_f16_f16_f16_f16_km_kn_mn_mn_instances(
std::vector<std::unique_ptr<DeviceGemmMultipleD<Col,
Row,
@@ -68,8 +69,11 @@ void add_device_gemm_bilinear_xdl_c_shuffle_f16_f16_f16_f16_mk_nk_mn_mn_instance
PassThrough,
PassThrough,
Bilinear>>>& instances);
#endif
#if defined(CK_ENABLE_INT8) && defined(CK_USE_WMMA)
#endif // CK_ENABLE_FP16
#endif // CK_USE_XDL
#if defined(CK_USE_WMMA)
#if defined(CK_ENABLE_INT8)
void add_device_gemm_bilinear_wmma_c_shuffle_i8_i8_i8_i8_mk_kn_mn_mn_instances(
std::vector<std::unique_ptr<DeviceGemmMultipleD<Row,
Row,
@@ -121,7 +125,63 @@ void add_device_gemm_bilinear_wmma_c_shuffle_i8_i8_i8_i8_km_nk_mn_mn_instances(
PassThrough,
PassThrough,
Bilinear>>>& instances);
#endif
#endif // CK_ENABLE_INT8
#if defined(CK_ENABLE_FP16)
void add_device_gemm_bilinear_wmma_c_shuffle_f16_f16_f16_f16_km_kn_mn_mn_instances(
std::vector<std::unique_ptr<DeviceGemmMultipleDSplitK<Col,
Row,
Row_Tuple,
Row,
F16,
F16,
F16_Tuple,
F16,
PassThrough,
PassThrough,
Bilinear>>>& instances);
void add_device_gemm_bilinear_wmma_c_shuffle_f16_f16_f16_f16_km_nk_mn_mn_instances(
std::vector<std::unique_ptr<DeviceGemmMultipleDSplitK<Col,
Col,
Row_Tuple,
Row,
F16,
F16,
F16_Tuple,
F16,
PassThrough,
PassThrough,
Bilinear>>>& instances);
void add_device_gemm_bilinear_wmma_c_shuffle_f16_f16_f16_f16_mk_kn_mn_mn_instances(
std::vector<std::unique_ptr<DeviceGemmMultipleDSplitK<Row,
Row,
Row_Tuple,
Row,
F16,
F16,
F16_Tuple,
F16,
PassThrough,
PassThrough,
Bilinear>>>& instances);
void add_device_gemm_bilinear_wmma_c_shuffle_f16_f16_f16_f16_mk_nk_mn_mn_instances(
std::vector<std::unique_ptr<DeviceGemmMultipleDSplitK<Row,
Col,
Row_Tuple,
Row,
F16,
F16,
F16_Tuple,
F16,
PassThrough,
PassThrough,
Bilinear>>>& instances);
#endif // CK_ENABLE_FP16
#endif // CK_USE_WMMA
// GEMM + Bilinear
template <typename ALayout,
typename BLayout,
@@ -131,18 +191,95 @@ template <typename ALayout,
typename BDataType,
typename DDataType,
typename EDataType>
struct DeviceOperationInstanceFactory<ck::tensor_operation::device::DeviceGemmMultipleD<
ALayout,
BLayout,
ck::Tuple<DLayout>,
ELayout,
ADataType,
BDataType,
ck::Tuple<DDataType>,
EDataType,
ck::tensor_operation::element_wise::PassThrough,
ck::tensor_operation::element_wise::PassThrough,
ck::tensor_operation::element_wise::Bilinear>>
struct DeviceOperationInstanceFactory<DeviceGemmMultipleDSplitK<ALayout,
BLayout,
ck::Tuple<DLayout>,
ELayout,
ADataType,
BDataType,
ck::Tuple<DDataType>,
EDataType,
PassThrough,
PassThrough,
Bilinear>>
{
using DeviceOp = DeviceGemmMultipleDSplitK<ALayout,
BLayout,
ck::Tuple<DLayout>,
ELayout,
ADataType,
BDataType,
ck::Tuple<DDataType>,
EDataType,
PassThrough,
PassThrough,
Bilinear>;
static auto GetInstances()
{
std::vector<std::unique_ptr<DeviceOp>> op_ptrs;
#if defined(CK_USE_XDL)
// No XDL instances for DeviceGemmMultipleDSplitK with AddBilinear at the moment
#endif // CK_USE_XDL
#if defined(CK_USE_WMMA)
#if defined(CK_ENABLE_FP16)
if constexpr(is_same_v<ADataType, half_t> && is_same_v<BDataType, half_t> &&
is_same_v<DDataType, half_t> && is_same_v<EDataType, half_t>)
{
if constexpr(is_same_v<ALayout, Row> && is_same_v<BLayout, Row> &&
is_same_v<DLayout, Row> && is_same_v<ELayout, Row>)
{
add_device_gemm_bilinear_wmma_c_shuffle_f16_f16_f16_f16_mk_kn_mn_mn_instances(
op_ptrs);
}
else if constexpr(is_same_v<ALayout, Row> && is_same_v<BLayout, Col> &&
is_same_v<DLayout, Row> && is_same_v<ELayout, Row>)
{
add_device_gemm_bilinear_wmma_c_shuffle_f16_f16_f16_f16_mk_nk_mn_mn_instances(
op_ptrs);
}
else if constexpr(is_same_v<ALayout, Col> && is_same_v<BLayout, Row> &&
is_same_v<DLayout, Row> && is_same_v<ELayout, Row>)
{
add_device_gemm_bilinear_wmma_c_shuffle_f16_f16_f16_f16_km_kn_mn_mn_instances(
op_ptrs);
}
else if constexpr(is_same_v<ALayout, Col> && is_same_v<BLayout, Col> &&
is_same_v<DLayout, Row> && is_same_v<ELayout, Row>)
{
add_device_gemm_bilinear_wmma_c_shuffle_f16_f16_f16_f16_km_nk_mn_mn_instances(
op_ptrs);
}
}
#endif // CK_ENABLE_FP16
#endif // CK_USE_WMMA
return op_ptrs;
}
};
// GEMM + Bilinear
template <typename ALayout,
typename BLayout,
typename DLayout,
typename ELayout,
typename ADataType,
typename BDataType,
typename DDataType,
typename EDataType>
struct DeviceOperationInstanceFactory<DeviceGemmMultipleD<ALayout,
BLayout,
ck::Tuple<DLayout>,
ELayout,
ADataType,
BDataType,
ck::Tuple<DDataType>,
EDataType,
PassThrough,
PassThrough,
Bilinear>>
{
using DeviceOp = DeviceGemmMultipleD<ALayout,
BLayout,
@@ -152,14 +289,15 @@ struct DeviceOperationInstanceFactory<ck::tensor_operation::device::DeviceGemmMu
BDataType,
ck::Tuple<DDataType>,
EDataType,
ck::tensor_operation::element_wise::PassThrough,
ck::tensor_operation::element_wise::PassThrough,
ck::tensor_operation::element_wise::Bilinear>;
PassThrough,
PassThrough,
Bilinear>;
static auto GetInstances()
{
std::vector<std::unique_ptr<DeviceOp>> op_ptrs;
#if defined(CK_ENABLE_FP16) && defined(CK_USE_XDL)
#if defined(CK_USE_XDL)
#if defined(CK_ENABLE_FP16)
if constexpr(is_same_v<ADataType, half_t> && is_same_v<BDataType, half_t> &&
is_same_v<DDataType, half_t> && is_same_v<EDataType, half_t>)
{
@@ -188,8 +326,31 @@ struct DeviceOperationInstanceFactory<ck::tensor_operation::device::DeviceGemmMu
op_ptrs);
}
}
#endif
#if defined(CK_ENABLE_INT8) && defined(CK_USE_WMMA)
#endif // CK_ENABLE_FP16
#endif // CK_USE_XDL
#if defined(CK_USE_WMMA)
// Reuse DeviceGemmMultipleDSplitK instances
using Wrapper = DeviceGemmMultipleDSplitKWrapper<ALayout,
BLayout,
ck::Tuple<DLayout>,
ELayout,
ADataType,
BDataType,
ck::Tuple<DDataType>,
EDataType,
PassThrough,
PassThrough,
Bilinear>;
auto new_op_ptrs =
DeviceOperationInstanceFactory<typename Wrapper::DeviceOp>::GetInstances();
for(auto& op_ptr : new_op_ptrs)
{
op_ptrs.emplace_back(std::make_unique<Wrapper>(std::move(op_ptr)));
}
// Bilinear wmma i8 instances are using DeviceGemmMultipleD interface.
#if defined(CK_ENABLE_INT8)
if constexpr(is_same_v<ADataType, std::int8_t> && is_same_v<BDataType, std::int8_t> &&
is_same_v<DDataType, std::int8_t> && is_same_v<EDataType, std::int8_t>)
{
@@ -214,7 +375,8 @@ struct DeviceOperationInstanceFactory<ck::tensor_operation::device::DeviceGemmMu
add_device_gemm_bilinear_wmma_c_shuffle_i8_i8_i8_i8_km_nk_mn_mn_instances(op_ptrs);
}
}
#endif
#endif // CK_ENABLE_INT8
#endif // CK_USE_WMMA
return op_ptrs;
}
};

View File

@@ -1,5 +1,5 @@
// SPDX-License-Identifier: MIT
// Copyright (c) 2018-2023, Advanced Micro Devices, Inc. All rights reserved.
// Copyright (c) 2018-2025, Advanced Micro Devices, Inc. All rights reserved.
#pragma once
@@ -16,6 +16,7 @@ namespace tensor_operation {
namespace device {
namespace instance {
#if defined(CK_USE_XDL)
void add_device_gemm_fastgelu_xdl_c_shuffle_f16_f16_f16_mk_kn_mn_instances(
std::vector<std::unique_ptr<DeviceGemmMultipleD<Row,
Row,
@@ -67,6 +68,132 @@ void add_device_gemm_fastgelu_xdl_c_shuffle_f16_f16_f16_km_nk_mn_instances(
PassThrough,
PassThrough,
FastGelu>>>&);
#endif // CK_USE_XDL
#if defined(CK_USE_WMMA)
void add_device_gemm_fastgelu_wmma_c_shuffle_f16_f16_f16_mk_kn_mn_instances(
std::vector<std::unique_ptr<DeviceGemmMultipleDSplitK<Row,
Row,
Empty_Tuple,
Row,
F16,
F16,
Empty_Tuple,
F16,
PassThrough,
PassThrough,
FastGelu>>>&);
void add_device_gemm_fastgelu_wmma_c_shuffle_f16_f16_f16_mk_nk_mn_instances(
std::vector<std::unique_ptr<DeviceGemmMultipleDSplitK<Row,
Col,
Empty_Tuple,
Row,
F16,
F16,
Empty_Tuple,
F16,
PassThrough,
PassThrough,
FastGelu>>>&);
void add_device_gemm_fastgelu_wmma_c_shuffle_f16_f16_f16_km_kn_mn_instances(
std::vector<std::unique_ptr<DeviceGemmMultipleDSplitK<Col,
Row,
Empty_Tuple,
Row,
F16,
F16,
Empty_Tuple,
F16,
PassThrough,
PassThrough,
FastGelu>>>&);
void add_device_gemm_fastgelu_wmma_c_shuffle_f16_f16_f16_km_nk_mn_instances(
std::vector<std::unique_ptr<DeviceGemmMultipleDSplitK<Col,
Col,
Empty_Tuple,
Row,
F16,
F16,
Empty_Tuple,
F16,
PassThrough,
PassThrough,
FastGelu>>>&);
#endif // CK_USE_WMMA
// GEMM + Add + FastGelu
// DeviceGemmMultipleDSplitK specialization
template <typename ALayout,
typename BLayout,
typename ELayout,
typename ADataType,
typename BDataType,
typename EDataType>
struct DeviceOperationInstanceFactory<DeviceGemmMultipleDSplitK<ALayout,
BLayout,
Empty_Tuple,
ELayout,
ADataType,
BDataType,
Empty_Tuple,
EDataType,
PassThrough,
PassThrough,
FastGelu>>
{
using DeviceOp = DeviceGemmMultipleDSplitK<ALayout,
BLayout,
Empty_Tuple,
ELayout,
ADataType,
BDataType,
Empty_Tuple,
EDataType,
PassThrough,
PassThrough,
FastGelu>;
static auto GetInstances()
{
std::vector<std::unique_ptr<DeviceOp>> op_ptrs;
#if defined(CK_USE_XDL)
// No XDL instances for DeviceGemmMultipleDSplitK with AddFastGelu at the moment
#endif // CK_USE_XDL
#if defined(CK_USE_WMMA)
if constexpr(is_same_v<ADataType, half_t> && is_same_v<BDataType, half_t> &&
is_same_v<EDataType, half_t>)
{
if constexpr(is_same_v<ALayout, Row> && is_same_v<BLayout, Row> &&
is_same_v<ELayout, Row>)
{
add_device_gemm_fastgelu_wmma_c_shuffle_f16_f16_f16_mk_kn_mn_instances(op_ptrs);
}
else if constexpr(is_same_v<ALayout, Row> && is_same_v<BLayout, Col> &&
is_same_v<ELayout, Row>)
{
add_device_gemm_fastgelu_wmma_c_shuffle_f16_f16_f16_mk_nk_mn_instances(op_ptrs);
}
else if constexpr(is_same_v<ALayout, Col> && is_same_v<BLayout, Row> &&
is_same_v<ELayout, Row>)
{
add_device_gemm_fastgelu_wmma_c_shuffle_f16_f16_f16_km_kn_mn_instances(op_ptrs);
}
else if constexpr(is_same_v<ALayout, Col> && is_same_v<BLayout, Col> &&
is_same_v<ELayout, Row>)
{
add_device_gemm_fastgelu_wmma_c_shuffle_f16_f16_f16_km_nk_mn_instances(op_ptrs);
}
}
#endif // CK_USE_WMMA
return op_ptrs;
}
};
// GEMM + FastGelu
template <typename ALayout,
@@ -75,17 +202,17 @@ template <typename ALayout,
typename ADataType,
typename BDataType,
typename EDataType>
struct DeviceOperationInstanceFactory<ck::tensor_operation::device::DeviceGemmMultipleD<ALayout,
BLayout,
Empty_Tuple,
ELayout,
ADataType,
BDataType,
Empty_Tuple,
EDataType,
PassThrough,
PassThrough,
FastGelu>>
struct DeviceOperationInstanceFactory<DeviceGemmMultipleD<ALayout,
BLayout,
Empty_Tuple,
ELayout,
ADataType,
BDataType,
Empty_Tuple,
EDataType,
PassThrough,
PassThrough,
FastGelu>>
{
using DeviceOp = DeviceGemmMultipleD<ALayout,
BLayout,
@@ -103,6 +230,7 @@ struct DeviceOperationInstanceFactory<ck::tensor_operation::device::DeviceGemmMu
{
std::vector<std::unique_ptr<DeviceOp>> op_ptrs;
#if defined(CK_USE_XDL)
if constexpr(is_same_v<ADataType, half_t> && is_same_v<BDataType, half_t> &&
is_same_v<EDataType, half_t>)
{
@@ -127,6 +255,28 @@ struct DeviceOperationInstanceFactory<ck::tensor_operation::device::DeviceGemmMu
add_device_gemm_fastgelu_xdl_c_shuffle_f16_f16_f16_km_nk_mn_instances(op_ptrs);
}
}
#endif // CK_USE_XDL
#if defined(CK_USE_WMMA)
// Reuse DeviceGemmMultipleDSplitK instances
using Wrapper = DeviceGemmMultipleDSplitKWrapper<ALayout,
BLayout,
Empty_Tuple,
ELayout,
ADataType,
BDataType,
Empty_Tuple,
EDataType,
PassThrough,
PassThrough,
FastGelu>;
auto new_op_ptrs =
DeviceOperationInstanceFactory<typename Wrapper::DeviceOp>::GetInstances();
for(auto& op_ptr : new_op_ptrs)
{
op_ptrs.emplace_back(std::make_unique<Wrapper>(std::move(op_ptr)));
}
#endif // CK_USE_WMMA
return op_ptrs;
}
@@ -136,4 +286,4 @@ struct DeviceOperationInstanceFactory<ck::tensor_operation::device::DeviceGemmMu
} // namespace device
} // namespace tensor_operation
} // namespace ck
#endif
#endif // CK_ENABLE_FP16

View File

@@ -1,5 +1,5 @@
// SPDX-License-Identifier: MIT
// Copyright (c) 2018-2023, Advanced Micro Devices, Inc. All rights reserved.
// Copyright (c) 2018-2025, Advanced Micro Devices, Inc. All rights reserved.
#pragma once
@@ -19,6 +19,7 @@ namespace tensor_operation {
namespace device {
namespace instance {
#if defined(CK_USE_XDL)
void add_device_gemm_multiply_add_xdl_c_shuffle_f16_f16_f16_f16_f16_mk_kn_mn_mn_mn_instances(
std::vector<std::unique_ptr<DeviceGemmMultipleD<Row,
Row,
@@ -71,9 +72,64 @@ void add_device_gemm_multiply_add_xdl_c_shuffle_f16_f8_f32_f32_f16_mk_nk_mn_mn_m
PassThrough,
PassThrough,
MultiplyAdd>>>&);
#endif
#endif // CK_ENABLE_FP8
#endif // CK_USE_XDL
#if defined(CK_USE_WMMA)
void add_device_gemm_multiply_add_wmma_c_shuffle_f16_f16_f16_f16_f16_mk_kn_mn_mn_mn_instances(
std::vector<std::unique_ptr<DeviceGemmMultipleDSplitK<Row,
Row,
Row_Row_Tuple,
Row,
F16,
F16,
F16_F16_Tuple,
F16,
PassThrough,
PassThrough,
MultiplyAdd>>>&);
void add_device_gemm_multiply_add_wmma_c_shuffle_f16_f16_f16_f16_f16_mk_nk_mn_mn_mn_instances(
std::vector<std::unique_ptr<DeviceGemmMultipleDSplitK<Row,
Col,
Row_Row_Tuple,
Row,
F16,
F16,
F16_F16_Tuple,
F16,
PassThrough,
PassThrough,
MultiplyAdd>>>&);
#ifdef CK_USE_WMMA_FP8
void add_device_gemm_multiply_add_wmma_c_shuffle_f16_f8_f32_f32_f16_mk_kn_mn_mn_mn_instances(
std::vector<std::unique_ptr<DeviceGemmMultipleDSplitK<Row,
Row,
Row_Row_Tuple,
Row,
F16,
F8,
F32_F32_Tuple,
F16,
PassThrough,
PassThrough,
MultiplyAdd>>>&);
void add_device_gemm_multiply_add_wmma_c_shuffle_f16_f8_f32_f32_f16_mk_nk_mn_mn_mn_instances(
std::vector<std::unique_ptr<DeviceGemmMultipleDSplitK<Row,
Col,
Row_Row_Tuple,
Row,
F16,
F8,
F32_F32_Tuple,
F16,
PassThrough,
PassThrough,
MultiplyAdd>>>&);
#endif // CK_USE_WMMA_FP8
#endif // CK_USE_WMMA
// GEMM + Multiply + Add
template <typename ALayout,
typename BLayout,
typename D0Layout,
@@ -84,18 +140,107 @@ template <typename ALayout,
typename D0DataType,
typename D1DataType,
typename EDataType>
struct DeviceOperationInstanceFactory<ck::tensor_operation::device::DeviceGemmMultipleD<
ALayout,
BLayout,
ck::Tuple<D0Layout, D1Layout>,
ELayout,
ADataType,
BDataType,
ck::Tuple<D0DataType, D1DataType>,
EDataType,
ck::tensor_operation::element_wise::PassThrough,
ck::tensor_operation::element_wise::PassThrough,
ck::tensor_operation::element_wise::MultiplyAdd>>
struct DeviceOperationInstanceFactory<DeviceGemmMultipleDSplitK<ALayout,
BLayout,
ck::Tuple<D0Layout, D1Layout>,
ELayout,
ADataType,
BDataType,
ck::Tuple<D0DataType, D1DataType>,
EDataType,
PassThrough,
PassThrough,
MultiplyAdd>>
{
using DeviceOp = DeviceGemmMultipleDSplitK<ALayout,
BLayout,
ck::Tuple<D0Layout, D1Layout>,
ELayout,
ADataType,
BDataType,
ck::Tuple<D0DataType, D1DataType>,
EDataType,
PassThrough,
PassThrough,
MultiplyAdd>;
static auto GetInstances()
{
std::vector<std::unique_ptr<DeviceOp>> op_ptrs;
#ifdef CK_USE_XDL
// No XDL instances for DeviceGemmMultipleDSplitK with MultiplyAdd at the moment
#endif // CK_USE_XDL
#ifdef CK_USE_WMMA
if constexpr(is_same_v<ADataType, F16> && is_same_v<BDataType, F16> &&
is_same_v<D0DataType, F16> && is_same_v<D1DataType, F16> &&
is_same_v<EDataType, F16>)
{
if constexpr(is_same_v<ALayout, Row> && is_same_v<BLayout, Row> &&
is_same_v<D0Layout, Row> && is_same_v<D1Layout, Row> &&
is_same_v<ELayout, Row>)
{
add_device_gemm_multiply_add_wmma_c_shuffle_f16_f16_f16_f16_f16_mk_kn_mn_mn_mn_instances(
op_ptrs);
}
else if constexpr(is_same_v<ALayout, Row> && is_same_v<BLayout, Col> &&
is_same_v<D0Layout, Row> && is_same_v<D1Layout, Row> &&
is_same_v<ELayout, Row>)
{
add_device_gemm_multiply_add_wmma_c_shuffle_f16_f16_f16_f16_f16_mk_nk_mn_mn_mn_instances(
op_ptrs);
}
}
#endif // CK_USE_WMMA
#ifdef CK_USE_WMMA_FP8
if constexpr(is_same_v<ADataType, F16> && is_same_v<BDataType, F8> &&
is_same_v<D0DataType, F32> && is_same_v<D1DataType, F32> &&
is_same_v<EDataType, F16>)
{
if constexpr(is_same_v<ALayout, Row> && is_same_v<BLayout, Row> &&
is_same_v<D0Layout, Row> && is_same_v<D1Layout, Row> &&
is_same_v<ELayout, Row>)
{
add_device_gemm_multiply_add_wmma_c_shuffle_f16_f8_f32_f32_f16_mk_kn_mn_mn_mn_instances(
op_ptrs);
}
else if constexpr(is_same_v<ALayout, Row> && is_same_v<BLayout, Col> &&
is_same_v<D0Layout, Row> && is_same_v<D1Layout, Row> &&
is_same_v<ELayout, Row>)
{
add_device_gemm_multiply_add_wmma_c_shuffle_f16_f8_f32_f32_f16_mk_nk_mn_mn_mn_instances(
op_ptrs);
}
}
#endif // CK_USE_WMMA
return op_ptrs;
}
};
// DeviceGemmMultipleD specialization
template <typename ALayout,
typename BLayout,
typename D0Layout,
typename D1Layout,
typename ELayout,
typename ADataType,
typename BDataType,
typename D0DataType,
typename D1DataType,
typename EDataType>
struct DeviceOperationInstanceFactory<DeviceGemmMultipleD<ALayout,
BLayout,
ck::Tuple<D0Layout, D1Layout>,
ELayout,
ADataType,
BDataType,
ck::Tuple<D0DataType, D1DataType>,
EDataType,
PassThrough,
PassThrough,
MultiplyAdd>>
{
using DeviceOp = DeviceGemmMultipleD<ALayout,
BLayout,
@@ -105,17 +250,18 @@ struct DeviceOperationInstanceFactory<ck::tensor_operation::device::DeviceGemmMu
BDataType,
ck::Tuple<D0DataType, D1DataType>,
EDataType,
ck::tensor_operation::element_wise::PassThrough,
ck::tensor_operation::element_wise::PassThrough,
ck::tensor_operation::element_wise::MultiplyAdd>;
PassThrough,
PassThrough,
MultiplyAdd>;
static auto GetInstances()
{
std::vector<std::unique_ptr<DeviceOp>> op_ptrs;
if constexpr(is_same_v<ADataType, half_t> && is_same_v<BDataType, half_t> &&
is_same_v<D0DataType, half_t> && is_same_v<D1DataType, half_t> &&
is_same_v<EDataType, half_t>)
#ifdef CK_USE_XDL
if constexpr(is_same_v<ADataType, F16> && is_same_v<BDataType, F16> &&
is_same_v<D0DataType, F16> && is_same_v<D1DataType, F16> &&
is_same_v<EDataType, F16>)
{
if constexpr(is_same_v<ALayout, Row> && is_same_v<BLayout, Row> &&
is_same_v<D0Layout, Row> && is_same_v<D1Layout, Row> &&
@@ -133,10 +279,10 @@ struct DeviceOperationInstanceFactory<ck::tensor_operation::device::DeviceGemmMu
}
}
#if defined CK_ENABLE_FP8
if constexpr(is_same_v<ADataType, half_t> && is_same_v<BDataType, f8_t> &&
is_same_v<D0DataType, float> && is_same_v<D1DataType, float> &&
is_same_v<EDataType, half_t>)
#ifdef CK_ENABLE_FP8
if constexpr(is_same_v<ADataType, F16> && is_same_v<BDataType, F8> &&
is_same_v<D0DataType, F32> && is_same_v<D1DataType, F32> &&
is_same_v<EDataType, F16>)
{
if constexpr(is_same_v<ALayout, Row> && is_same_v<BLayout, Row> &&
is_same_v<D0Layout, Row> && is_same_v<D1Layout, Row> &&
@@ -153,7 +299,29 @@ struct DeviceOperationInstanceFactory<ck::tensor_operation::device::DeviceGemmMu
op_ptrs);
}
}
#endif
#endif // CK_ENABLE_FP8
#endif // CK_USE_XDL
#ifdef CK_USE_WMMA
// Reuse DeviceGemmMultipleDSplitK instances
using Wrapper = DeviceGemmMultipleDSplitKWrapper<ALayout,
BLayout,
ck::Tuple<D0Layout, D1Layout>,
ELayout,
ADataType,
BDataType,
ck::Tuple<D0DataType, D1DataType>,
EDataType,
PassThrough,
PassThrough,
MultiplyAdd>;
auto new_op_ptrs =
DeviceOperationInstanceFactory<typename Wrapper::DeviceOp>::GetInstances();
for(auto& op_ptr : new_op_ptrs)
{
op_ptrs.emplace_back(std::make_unique<Wrapper>(std::move(op_ptr)));
}
#endif // CK_USE_WMMA
return op_ptrs;
}

View File

@@ -1,5 +1,5 @@
// SPDX-License-Identifier: MIT
// Copyright (c) 2024, Advanced Micro Devices, Inc. All rights reserved.
// Copyright (c) 2024-2025, Advanced Micro Devices, Inc. All rights reserved.
#pragma once
@@ -16,6 +16,7 @@ namespace ck {
namespace tensor_operation {
namespace device {
namespace instance {
#ifdef CK_USE_XDL
#ifdef CK_ENABLE_FP8
#ifdef CK_ENABLE_BF16
void add_device_gemm_multiply_multiply_xdl_f8_f8_bf16_mk_nk_mn_comp_default_instances_part1(
@@ -199,7 +200,7 @@ void add_device_gemm_multiply_multiply_xdl_f8_f8_bf16_mk_nk_mn_mem_v2_kpadding_i
PassThrough,
PassThrough,
MultiplyMultiply>>>& instances);
#endif
#endif // CK_ENABLE_BF16
#ifdef CK_ENABLE_FP16
void add_device_gemm_multiply_multiply_xdl_f8_f8_f16_mk_nk_mn_comp_default_instances(
std::vector<std::unique_ptr<DeviceGemmMultipleDSplitK<Row,
@@ -278,8 +279,8 @@ void add_device_gemm_multiply_multiply_xdl_f8_f8_f16_mk_nk_mn_mem_v2_kpadding_in
PassThrough,
PassThrough,
MultiplyMultiply>>>& instances);
#endif
#endif
#endif // CK_ENABLE_FP16
#endif // CK_ENABLE_FP8
#ifdef CK_ENABLE_FP16
void add_device_gemm_multiply_multiply_xdl_f8_f8_f16_mk_nk_mn_comp_default_instances_part1(
@@ -463,7 +464,7 @@ void add_device_gemm_multiply_multiply_xdl_f8_f8_f16_mk_nk_mn_mem_v2_kpadding_in
PassThrough,
PassThrough,
MultiplyMultiply>>>& instances);
#endif
#endif // CK_ENABLE_FP16
#if(defined(CK_ENABLE_FP16) || defined(CK_ENABLE_INT8))
void add_device_gemm_multiply_multiply_xdl_i8_i8_f16_mk_nk_mn_comp_default_instances(
@@ -544,7 +545,62 @@ void add_device_gemm_multiply_multiply_xdl_i8_i8_f16_mk_nk_mn_mem_v2_kpadding_in
PassThrough,
MultiplyMultiply>>>& instances);
#endif
#endif // CK_ENABLE_FP16 || CK_ENABLE_INT8
#endif // CK_USE_XDL
#ifdef CK_USE_WMMA
void add_device_gemm_multiply_multiply_wmma_c_shuffle_i8_i8_f16_km_nk_mn_instances(
std::vector<std::unique_ptr<DeviceGemmMultipleDSplitK<Row,
Col,
Row_Col_Tuple,
Row,
I8,
I8,
F16_F16_Tuple,
F16,
PassThrough,
PassThrough,
MultiplyMultiply>>>& instances);
void add_device_gemm_multiply_multiply_wmma_c_shuffle_i8_i8_bf16_km_nk_mn_instances(
std::vector<std::unique_ptr<DeviceGemmMultipleDSplitK<Row,
Col,
Row_Col_Tuple,
Row,
I8,
I8,
F32_F32_Tuple,
BF16,
PassThrough,
PassThrough,
MultiplyMultiply>>>& instances);
void add_device_gemm_multiply_multiply_wmma_c_shuffle_f8_f8_f16_km_nk_mn_instances(
std::vector<std::unique_ptr<DeviceGemmMultipleDSplitK<Row,
Col,
Row_Col_Tuple,
Row,
F8,
F8,
F32_F32_Tuple,
F16,
PassThrough,
PassThrough,
MultiplyMultiply>>>& instances);
void add_device_gemm_multiply_multiply_wmma_c_shuffle_f8_f8_bf16_km_nk_mn_instances(
std::vector<std::unique_ptr<DeviceGemmMultipleDSplitK<Row,
Col,
Row_Col_Tuple,
Row,
F8,
F8,
F32_F32_Tuple,
BF16,
PassThrough,
PassThrough,
MultiplyMultiply>>>& instances);
#endif // CK_USE_WMMA
template <typename ADataType,
typename BDataType,
@@ -553,36 +609,35 @@ template <typename ADataType,
typename ALayout,
typename BLayout,
typename CLayout>
struct DeviceOperationInstanceFactory<ck::tensor_operation::device::DeviceGemmMultipleDSplitK<
ALayout,
BLayout,
Tuple<Row, Col>,
CLayout,
ADataType,
BDataType,
DsDataType,
CDataType,
ck::tensor_operation::element_wise::PassThrough,
ck::tensor_operation::element_wise::PassThrough,
ck::tensor_operation::element_wise::MultiplyMultiply>>
struct DeviceOperationInstanceFactory<DeviceGemmMultipleDSplitK<ALayout,
BLayout,
Tuple<Row, Col>,
CLayout,
ADataType,
BDataType,
DsDataType,
CDataType,
PassThrough,
PassThrough,
MultiplyMultiply>>
{
using DeviceOp =
DeviceGemmMultipleDSplitK<ALayout,
BLayout,
Tuple<Row, Col>,
CLayout,
ADataType,
BDataType,
DsDataType,
CDataType,
ck::tensor_operation::element_wise::PassThrough,
ck::tensor_operation::element_wise::PassThrough,
ck::tensor_operation::element_wise::MultiplyMultiply>;
using DeviceOp = DeviceGemmMultipleDSplitK<ALayout,
BLayout,
Tuple<Row, Col>,
CLayout,
ADataType,
BDataType,
DsDataType,
CDataType,
PassThrough,
PassThrough,
MultiplyMultiply>;
static auto GetInstances()
{
std::vector<std::unique_ptr<DeviceOp>> op_ptrs;
#ifdef CK_USE_XDL
#ifdef CK_ENABLE_FP8
#ifdef CK_ENABLE_BF16
if constexpr(is_same_v<ADataType, f8_t> && is_same_v<BDataType, f8_t> &&
@@ -624,7 +679,7 @@ struct DeviceOperationInstanceFactory<ck::tensor_operation::device::DeviceGemmMu
op_ptrs);
}
}
#endif
#endif // CK_ENABLE_BF16
#ifdef CK_ENABLE_FP16
if constexpr(is_same_v<ADataType, f8_t> && is_same_v<BDataType, f8_t> &&
is_same_v<CDataType, half_t>)
@@ -665,8 +720,8 @@ struct DeviceOperationInstanceFactory<ck::tensor_operation::device::DeviceGemmMu
op_ptrs);
}
}
#endif
#endif
#endif // CK_ENABLE_FP16
#endif // CK_ENABLE_FP8
#if(defined(CK_ENABLE_FP16) || defined(CK_ENABLE_INT8))
if constexpr(is_same_v<ADataType, int8_t> && is_same_v<BDataType, int8_t> &&
is_same_v<CDataType, half_t>)
@@ -691,6 +746,51 @@ struct DeviceOperationInstanceFactory<ck::tensor_operation::device::DeviceGemmMu
}
}
#endif
#endif // CK_USE_XDL
#ifdef CK_USE_WMMA
if constexpr(is_same_v<ADataType, int8_t> && is_same_v<BDataType, int8_t> &&
is_same_v<CDataType, half_t>)
{
if constexpr(is_same_v<ALayout, Row> && is_same_v<BLayout, Col> &&
is_same_v<CLayout, Row>)
{
add_device_gemm_multiply_multiply_wmma_c_shuffle_i8_i8_f16_km_nk_mn_instances(
op_ptrs);
}
}
if constexpr(is_same_v<ADataType, int8_t> && is_same_v<BDataType, int8_t> &&
is_same_v<CDataType, bhalf_t>)
{
if constexpr(is_same_v<ALayout, Row> && is_same_v<BLayout, Col> &&
is_same_v<CLayout, Row>)
{
add_device_gemm_multiply_multiply_wmma_c_shuffle_i8_i8_bf16_km_nk_mn_instances(
op_ptrs);
}
}
if constexpr(is_same_v<ADataType, f8_t> && is_same_v<BDataType, f8_t> &&
is_same_v<CDataType, half_t>)
{
if constexpr(is_same_v<ALayout, Row> && is_same_v<BLayout, Col> &&
is_same_v<CLayout, Row>)
{
add_device_gemm_multiply_multiply_wmma_c_shuffle_f8_f8_f16_km_nk_mn_instances(
op_ptrs);
}
}
if constexpr(is_same_v<ADataType, f8_t> && is_same_v<BDataType, f8_t> &&
is_same_v<CDataType, bhalf_t>)
{
if constexpr(is_same_v<ALayout, Row> && is_same_v<BLayout, Col> &&
is_same_v<CLayout, Row>)
{
add_device_gemm_multiply_multiply_wmma_c_shuffle_f8_f8_bf16_km_nk_mn_instances(
op_ptrs);
}
}
#endif // CK_USE_WMMA
return op_ptrs;
}
};

View File

@@ -75,7 +75,7 @@ function(add_instance_library INSTANCE_NAME)
endif()
# Do not build XDL gemm_universal_f8 or gemm_multiply_multiply_f8 for any targets except gfx94
if(NOT CK_USE_FP8_ON_UNSUPPORTED_ARCH)
if(NOT INST_TARGETS MATCHES "gfx94" AND NOT INST_TARGETS MATCHES "gfx95" AND source_name MATCHES "gemm_multiply_multiply" AND source_name MATCHES "_f8_")
if(NOT INST_TARGETS MATCHES "gfx94|gfx95|gfx12" AND source_name MATCHES "gemm_multiply_multiply" AND source_name MATCHES "_f8_")
message(DEBUG "removing gemm_multiply_multiply_f8 instance ${source} ")
list(REMOVE_ITEM ARGN "${source}")
endif()
@@ -117,13 +117,13 @@ function(add_instance_library INSTANCE_NAME)
list(REMOVE_ITEM INST_TARGETS gfx900 gfx906 gfx906:xnack- gfx908:xnack- gfx908:xnack+ gfx90a:xnack+ gfx90a:xnack- gfx908 gfx90a gfx942 gfx1030 gfx1100 gfx1101 gfx1102 gfx1103 gfx1150 gfx1151 gfx1152 gfx1200 gfx1201 gfx10-3-generic gfx11-generic gfx12-generic)
endif()
#only build the fp8 gemm instances for gfx90a if the build argument is set, otherwise only build for gfx942/gfx950
#only build the fp8 gemm instances for gfx90a if the build argument is set, otherwise only build for gfx942/gfx950 and gfx1200/gfx1201
if(NOT CK_USE_FP8_ON_UNSUPPORTED_ARCH)
if(source_name MATCHES "gemm_xdl_universal" AND source_name MATCHES "f8")
list(REMOVE_ITEM INST_TARGETS gfx900 gfx906 gfx906:xnack- gfx908:xnack- gfx908:xnack+ gfx90a:xnack+ gfx90a:xnack- gfx908 gfx90a gfx1030 gfx1100 gfx1101 gfx1102 gfx1103 gfx1150 gfx1151 gfx1152 gfx1200 gfx1201 gfx10-3-generic gfx11-generic gfx12-generic)
endif()
if(source_name MATCHES "gemm_multiply_multiply" AND source_name MATCHES "f8")
list(REMOVE_ITEM INST_TARGETS gfx900 gfx906 gfx906:xnack- gfx908:xnack- gfx908:xnack+ gfx90a:xnack+ gfx90a:xnack- gfx908 gfx90a gfx1030 gfx1100 gfx1101 gfx1102 gfx1103 gfx1150 gfx1151 gfx1152 gfx1200 gfx1201 gfx10-3-generic gfx11-generic gfx12-generic)
list(REMOVE_ITEM INST_TARGETS gfx900 gfx906 gfx906:xnack- gfx908:xnack- gfx908:xnack+ gfx90a:xnack+ gfx90a:xnack- gfx908 gfx90a gfx1030 gfx1100 gfx1101 gfx1102 gfx1103 gfx1150 gfx1151 gfx1152 gfx10-3-generic gfx11-generic)
endif()
if(source_name MATCHES "gemm_universal_preshuffle" AND source_name MATCHES "f8")
list(REMOVE_ITEM INST_TARGETS gfx900 gfx906 gfx906:xnack- gfx908:xnack- gfx908:xnack+ gfx90a:xnack+ gfx90a:xnack- gfx908 gfx90a gfx1030 gfx1100 gfx1101 gfx1102 gfx1103 gfx1150 gfx1151 gfx1152 gfx1200 gfx1201 gfx10-3-generic gfx11-generic gfx12-generic)
@@ -136,7 +136,7 @@ function(add_instance_library INSTANCE_NAME)
list(REMOVE_ITEM INST_TARGETS gfx900 gfx906 gfx906:xnack- gfx908:xnack- gfx908:xnack+ gfx908 gfx1030 gfx1100 gfx1101 gfx1102 gfx1103 gfx1150 gfx1151 gfx1152 gfx1200 gfx1201 gfx10-3-generic gfx11-generic gfx12-generic)
endif()
if(source_name MATCHES "gemm_multiply_multiply" AND source_name MATCHES "f8")
list(REMOVE_ITEM INST_TARGETS gfx900 gfx906 gfx906:xnack- gfx908:xnack- gfx908:xnack+ gfx908 gfx1030 gfx1100 gfx1101 gfx1102 gfx1103 gfx1150 gfx1151 gfx1152 gfx1200 gfx1201 gfx10-3-generic gfx11-generic gfx12-generic)
list(REMOVE_ITEM INST_TARGETS gfx900 gfx906 gfx906:xnack- gfx908:xnack- gfx908:xnack+ gfx908 gfx1030 gfx1100 gfx1101 gfx1102 gfx1103 gfx1150 gfx1151 gfx1152 gfx10-3-generic gfx11-generic)
endif()
if(source_name MATCHES "gemm_universal_preshuffle" AND source_name MATCHES "f8")
list(REMOVE_ITEM INST_TARGETS gfx900 gfx906 gfx906:xnack- gfx908:xnack- gfx908:xnack+ gfx908 gfx1030 gfx1100 gfx1101 gfx1102 gfx1103 gfx1150 gfx1151 gfx1152 gfx1200 gfx1201 gfx10-3-generic gfx11-generic gfx12-generic)
@@ -290,7 +290,7 @@ FOREACH(subdir_path ${dir_list})
message(DEBUG "Found xdl, dl, and wmma instances, but none of those meet the target list. Skipping.")
set(add_inst 0)
endif()
if(("${cmake_instance}" MATCHES "gemm_multiply_multiply" AND "${cmake_instance}" MATCHES "_f8_" ) AND (NOT INST_TARGETS MATCHES "gfx94") AND (NOT INST_TARGETS MATCHES "gfx95") AND (NOT CK_USE_FP8_ON_UNSUPPORTED_ARCH))
if(("${cmake_instance}" MATCHES "gemm_multiply_multiply" AND "${cmake_instance}" MATCHES "_f8_" ) AND (NOT INST_TARGETS MATCHES "gfx94|gfx95|gfx11|gfx12") AND (NOT CK_USE_FP8_ON_UNSUPPORTED_ARCH))
message(DEBUG "Found gemm_multiply_multiply_f8 instances, but gfx94/gfx95 not on the target list. Skipping.")
set(add_inst 0)
endif()

View File

@@ -1,5 +1,8 @@
# ONLY XDL_KERNELS
# ONLY XDL_AND_WMMA_KERNELS
add_instance_library(device_gemm_add_instance
device_gemm_add_xdl_c_shuffle_f16_i8_f16_f16_mk_kn_mn_mn_instance.cpp
device_gemm_add_xdl_c_shuffle_bf16_i8_bf16_bf16_mk_kn_mn_mn_instance.cpp
device_gemm_add_wmma_c_shuffle_f16_f16_f16_f16_mk_kn_mn_mn_instance.cpp
device_gemm_add_wmma_c_shuffle_bf16_bf16_bf16_bf16_mk_kn_mn_mn_instance.cpp
)

View File

@@ -0,0 +1,69 @@
// SPDX-License-Identifier: MIT
// Copyright (c) 2025, Advanced Micro Devices, Inc. All rights reserved.
#include "ck/library/tensor_operation_instance/add_device_operation_instance.hpp"
#include "ck/library/tensor_operation_instance/device_operation_instance_factory.hpp"
#include "ck/tensor_operation/gpu/device/gemm_specialization.hpp"
#include "ck/tensor_operation/gpu/device/impl/device_gemm_multiple_d_wmma_cshuffle_v3.hpp"
#include "ck/utility/sequence.hpp"
namespace ck {
namespace tensor_operation {
namespace device {
namespace instance {
template <ck::index_t... Is>
using S = ck::Sequence<Is...>;
static constexpr auto GemmDefault = ck::tensor_operation::device::GemmSpecialization::Default;
static constexpr auto GemmMNKPadding = ck::tensor_operation::device::GemmSpecialization::MNKPadding;
static constexpr auto Interwave = BlockGemmPipelineScheduler::Interwave;
static constexpr auto Intrawave = BlockGemmPipelineScheduler::Intrawave;
static constexpr auto V1 = BlockGemmPipelineVersion::v1;
static constexpr auto V3 = BlockGemmPipelineVersion::v3;
template <GemmSpecialization GemmSpec>
// e = elementwise((a * b), d0, d1)
// outout: e[m, n]
// input: a[m, k], b[k, n], d0[m, n], d1[m, n]
using device_gemm_add_wmma_c_shuffle_bf16_bf16_bf16_bf16_mk_kn_mn_mn_instances = std::tuple<
// clang-format off
//##################################| ALayout| BLayout| DsLayout| ELayout| AData| BData| DsData| EData| AccData| CShuffle| A| B| CDE| GemmSpec| Block| MPer| NPer| KPer| AK1| BK1| MPer| NPer| MRepeat| NRepeat| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockLds| BBlockTransfer| BBlockTransfer| BBlockTransfer| BBlockTransfer| BBlockTransfer| BBlockTransfer| BBlockLds| CShuffle| CShuffle| CShuffleBlockTransfer| CDEShuffleBlockTransfer| BlkGemm| BlkGemm|
//##################################| | | | | Type| Type| Type| Type| Type| DataType| Elementwise| Elementwise| Elementwise| | Size| Block| Block| Block| | | Wmma| Wmma| | | ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| ExtraM| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| ExtraN| MRepeat| NRepeat| ClusterLengths| ScalarPerVectors| PipeSched| PipelineVer|
//##################################| | | | | | | | | | | Operation| Operation| Operation| | | | | | | | | | | | Lengths_AK0_M_AK1| ArrangeOrder| | | PerVector| PerVector_AK1| | Lengths_BK0_N_BK1| ArrangeOrder| | | PerVector| PerVector_BK1| | PerShuffle| PerShuffle| _MBlock_MPerBlock| | | |
//##################################| | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | _NBlock_NPerBlock| | | |
DeviceGemmMultipleD_Wmma_CShuffleV3< Row, Row, Row_Tuple, Row, BF16, BF16, BF16_Tuple, BF16, F32, F32, PassThrough, PassThrough, Add, GemmSpec, 256, 128, 128, 32, 8, 8, 16, 16, 4, 2, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 0, S<4, 64, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 1, 8, 0, 1, 1, S<1, 32, 1, 8>, S<8, 8, 8>, Interwave, V1>,
DeviceGemmMultipleD_Wmma_CShuffleV3< Row, Row, Row_Tuple, Row, BF16, BF16, BF16_Tuple, BF16, F32, F32, PassThrough, PassThrough, Add, GemmSpec, 128, 128, 64, 64, 8, 8, 16, 16, 4, 2, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 0, S<4, 32, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 1, 8, 0, 1, 1, S<1, 32, 1, 4>, S<8, 8, 8>, Intrawave, V1>,
DeviceGemmMultipleD_Wmma_CShuffleV3< Row, Row, Row_Tuple, Row, BF16, BF16, BF16_Tuple, BF16, F32, F32, PassThrough, PassThrough, Add, GemmSpec, 256, 128, 128, 32, 8, 8, 16, 16, 4, 2, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, S<4, 64, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 1, 8, 1, 1, 1, S<1, 32, 1, 8>, S<8, 8, 8>, Intrawave, V3>,
DeviceGemmMultipleD_Wmma_CShuffleV3< Row, Row, Row_Tuple, Row, BF16, BF16, BF16_Tuple, BF16, F32, F32, PassThrough, PassThrough, Add, GemmSpec, 128, 64, 64, 32, 8, 8, 16, 16, 2, 2, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 0, S<4, 32, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 2, 4, 1, 1, 1, S<1, 32, 1, 2>, S<8, 8, 8>, Intrawave, V3>
// clang-format on
>;
void add_device_gemm_add_wmma_c_shuffle_bf16_bf16_bf16_bf16_mk_kn_mn_mn_instances(
std::vector<std::unique_ptr<DeviceGemmMultipleDSplitK<Row,
Row,
Row_Tuple,
Row,
BF16,
BF16,
BF16_Tuple,
BF16,
PassThrough,
PassThrough,
Add>>>& instances)
{
add_device_operation_instances(
instances,
device_gemm_add_wmma_c_shuffle_bf16_bf16_bf16_bf16_mk_kn_mn_mn_instances<GemmDefault>{});
add_device_operation_instances(
instances,
device_gemm_add_wmma_c_shuffle_bf16_bf16_bf16_bf16_mk_kn_mn_mn_instances<GemmMNKPadding>{});
}
} // namespace instance
} // namespace device
} // namespace tensor_operation
} // namespace ck

View File

@@ -0,0 +1,69 @@
// SPDX-License-Identifier: MIT
// Copyright (c) 2025, Advanced Micro Devices, Inc. All rights reserved.
#include "ck/library/tensor_operation_instance/add_device_operation_instance.hpp"
#include "ck/library/tensor_operation_instance/device_operation_instance_factory.hpp"
#include "ck/tensor_operation/gpu/device/gemm_specialization.hpp"
#include "ck/tensor_operation/gpu/device/impl/device_gemm_multiple_d_wmma_cshuffle_v3.hpp"
#include "ck/utility/sequence.hpp"
namespace ck {
namespace tensor_operation {
namespace device {
namespace instance {
template <ck::index_t... Is>
using S = ck::Sequence<Is...>;
static constexpr auto GemmDefault = ck::tensor_operation::device::GemmSpecialization::Default;
static constexpr auto GemmMNKPadding = ck::tensor_operation::device::GemmSpecialization::MNKPadding;
static constexpr auto Interwave = BlockGemmPipelineScheduler::Interwave;
static constexpr auto Intrawave = BlockGemmPipelineScheduler::Intrawave;
static constexpr auto V1 = BlockGemmPipelineVersion::v1;
static constexpr auto V3 = BlockGemmPipelineVersion::v3;
template <GemmSpecialization GemmSpec>
// e = elementwise((a * b), d0, d1)
// outout: e[m, n]
// input: a[m, k], b[k, n], d0[m, n], d1[m, n]
using device_gemm_add_wmma_c_shuffle_f16_f16_f16_f16_mk_kn_mn_mn_instances = std::tuple<
// clang-format off
//##################################| ALayout| BLayout| DsLayout| ELayout| AData| BData| DsData| EData| AccData| CShuffle| A| B| CDE| GemmSpec| Block| MPer| NPer| KPer| AK1| BK1| MPer| NPer| MRepeat| NRepeat| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockLds| BBlockTransfer| BBlockTransfer| BBlockTransfer| BBlockTransfer| BBlockTransfer| BBlockTransfer| BBlockLds| CShuffle| CShuffle| CShuffleBlockTransfer| CDEShuffleBlockTransfer| BlkGemm| BlkGemm|
//##################################| | | | | Type| Type| Type| Type| Type| DataType| Elementwise| Elementwise| Elementwise| | Size| Block| Block| Block| | | Wmma| Wmma| | | ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| ExtraM| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| ExtraN| MRepeat| NRepeat| ClusterLengths| ScalarPerVectors| PipeSched| PipelineVer|
//##################################| | | | | | | | | | | Operation| Operation| Operation| | | | | | | | | | | | Lengths_AK0_M_AK1| ArrangeOrder| | | PerVector| PerVector_AK1| | Lengths_BK0_N_BK1| ArrangeOrder| | | PerVector| PerVector_BK1| | PerShuffle| PerShuffle| _MBlock_MPerBlock| | | |
//##################################| | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | _NBlock_NPerBlock| | | |
DeviceGemmMultipleD_Wmma_CShuffleV3< Row, Row, Row_Tuple, Row, F16, F16, F16_Tuple, F16, F32, F32, PassThrough, PassThrough, Add, GemmSpec, 256, 128, 128, 32, 8, 8, 16, 16, 4, 2, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 0, S<4, 64, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 1, 8, 0, 1, 1, S<1, 32, 1, 8>, S<8, 8, 8>, Interwave, V1>,
DeviceGemmMultipleD_Wmma_CShuffleV3< Row, Row, Row_Tuple, Row, F16, F16, F16_Tuple, F16, F32, F32, PassThrough, PassThrough, Add, GemmSpec, 128, 128, 64, 64, 8, 8, 16, 16, 4, 2, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 0, S<4, 32, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 1, 8, 0, 1, 1, S<1, 32, 1, 4>, S<8, 8, 8>, Intrawave, V1>,
DeviceGemmMultipleD_Wmma_CShuffleV3< Row, Row, Row_Tuple, Row, F16, F16, F16_Tuple, F16, F32, F32, PassThrough, PassThrough, Add, GemmSpec, 256, 128, 128, 32, 8, 8, 16, 16, 4, 2, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, S<4, 64, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 1, 8, 1, 1, 1, S<1, 32, 1, 8>, S<8, 8, 8>, Intrawave, V3>,
DeviceGemmMultipleD_Wmma_CShuffleV3< Row, Row, Row_Tuple, Row, F16, F16, F16_Tuple, F16, F32, F32, PassThrough, PassThrough, Add, GemmSpec, 128, 64, 64, 32, 8, 8, 16, 16, 2, 2, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 0, S<4, 32, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 2, 4, 1, 1, 1, S<1, 32, 1, 2>, S<8, 8, 8>, Intrawave, V3>
// clang-format on
>;
void add_device_gemm_add_wmma_c_shuffle_f16_f16_f16_f16_mk_kn_mn_mn_instances(
std::vector<std::unique_ptr<DeviceGemmMultipleDSplitK<Row,
Row,
Row_Tuple,
Row,
F16,
F16,
F16_Tuple,
F16,
PassThrough,
PassThrough,
Add>>>& instances)
{
add_device_operation_instances(
instances,
device_gemm_add_wmma_c_shuffle_f16_f16_f16_f16_mk_kn_mn_mn_instances<GemmDefault>{});
add_device_operation_instances(
instances,
device_gemm_add_wmma_c_shuffle_f16_f16_f16_f16_mk_kn_mn_mn_instances<GemmMNKPadding>{});
}
} // namespace instance
} // namespace device
} // namespace tensor_operation
} // namespace ck

View File

@@ -1,5 +1,10 @@
# ONLY XDL_KERNELS
# ONLY XDL_AND_WMMA_KERNELS
add_instance_library(device_gemm_add_add_fastgelu_instance
device_gemm_add_add_fastgelu_wmma_c_shuffle_f16_f16_f16_f16_f16_km_kn_mn_mn_mn_instance.cpp
device_gemm_add_add_fastgelu_wmma_c_shuffle_f16_f16_f16_f16_f16_km_nk_mn_mn_mn_instance.cpp
device_gemm_add_add_fastgelu_wmma_c_shuffle_f16_f16_f16_f16_f16_mk_kn_mn_mn_mn_instance.cpp
device_gemm_add_add_fastgelu_wmma_c_shuffle_f16_f16_f16_f16_f16_mk_nk_mn_mn_mn_instance.cpp
device_gemm_add_add_fastgelu_xdl_c_shuffle_f16_f16_f16_f16_f16_km_kn_mn_mn_mn_instance.cpp
device_gemm_add_add_fastgelu_xdl_c_shuffle_f16_f16_f16_f16_f16_km_nk_mn_mn_mn_instance.cpp
device_gemm_add_add_fastgelu_xdl_c_shuffle_f16_f16_f16_f16_f16_mk_kn_mn_mn_mn_instance.cpp

View File

@@ -0,0 +1,78 @@
// SPDX-License-Identifier: MIT
// Copyright (c) 2025, Advanced Micro Devices, Inc. All rights reserved.
#include "ck/library/tensor_operation_instance/add_device_operation_instance.hpp"
#include "ck/library/tensor_operation_instance/device_operation_instance_factory.hpp"
#include "ck/tensor_operation/gpu/device/gemm_specialization.hpp"
#include "ck/tensor_operation/gpu/device/impl/device_gemm_multiple_d_wmma_cshuffle_v3.hpp"
#include "ck/utility/sequence.hpp"
namespace ck {
namespace tensor_operation {
namespace device {
namespace instance {
template <ck::index_t... Is>
using S = ck::Sequence<Is...>;
static constexpr auto GemmDefault = ck::tensor_operation::device::GemmSpecialization::Default;
static constexpr auto GemmMNKPadding = ck::tensor_operation::device::GemmSpecialization::MNKPadding;
static constexpr auto Interwave = BlockGemmPipelineScheduler::Interwave;
static constexpr auto Intrawave = BlockGemmPipelineScheduler::Intrawave;
static constexpr auto V1 = BlockGemmPipelineVersion::v1;
static constexpr auto V3 = BlockGemmPipelineVersion::v3;
// e = elementwise((a * b), d0, d1)
// elementwise(c, d0, d1) = fastgelu(c + d0 + d1)
// output: e[m, n]
// input: a[m, k], b[n, k], d0[m, n], d1[m, n]
template <GemmSpecialization GemmSpec>
using device_gemm_add_add_fastgelu_wmma_c_shuffle_f16_f16_f16_f16_f16_km_kn_mn_mn_mn_instance =
std::tuple<
// clang-format off
//##################################| ALayout| BLayout| DsLayout| ELayout| AData| BData| DsData| EData| AccData| CShuffle| A| B| CDE| GemmSpec| Block| MPer| NPer| KPer| AK1| BK1| MPer| NPer| MRepeat| NRepeat| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockLds| BBlockTransfer| BBlockTransfer| BBlockTransfer| BBlockTransfer| BBlockTransfer| BBlockTransfer| BBlockLds| CShuffle| CShuffle| CShuffleBlockTransfer| CDEShuffleBlockTransfer| BlkGemm| BlkGemm|
//##################################| | | | | Type| Type| Type| Type| Type| DataType| Elementwise| Elementwise| Elementwise| | Size| Block| Block| Block| | | Wmma| Wmma| | | ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| ExtraM| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| ExtraN| MRepeat| NRepeat| ClusterLengths| ScalarPerVectors| PipeSched| PipelineVer|
//##################################| | | | | | | | | | | Operation| Operation| Operation| | | | | | | | | | | | Lengths_AK0_M_AK1| ArrangeOrder| | | PerVector| PerVector_AK1| | Lengths_BK0_N_BK1| ArrangeOrder| | | PerVector| PerVector_BK1| | PerShuffle| PerShuffle| _MBlock_MPerBlock| | | |
//##################################| | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | _NBlock_NPerBlock| | | |
DeviceGemmMultipleD_Wmma_CShuffleV3< Col, Row, Row_Row_Tuple, Row, F16, F16, F16_F16_Tuple, F16, F32, F32, PassThrough, PassThrough, AddAddFastGelu, GemmSpec, 256, 128, 128, 32, 8, 8, 16, 16, 4, 2, S<4, 64, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 1, 8, 1, S<4, 64, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 1, 8, 1, 1, 1, S<1, 32, 1, 8>, S<8, 8, 8>, Intrawave, V1>,
DeviceGemmMultipleD_Wmma_CShuffleV3< Col, Row, Row_Row_Tuple, Row, F16, F16, F16_F16_Tuple, F16, F32, F32, PassThrough, PassThrough, AddAddFastGelu, GemmSpec, 128, 128, 128, 32, 8, 8, 16, 16, 4, 4, S<4, 32, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 1, 8, 1, S<4, 32, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 1, 8, 1, 1, 1, S<1, 32, 1, 4>, S<8, 8, 8>, Intrawave, V1>,
DeviceGemmMultipleD_Wmma_CShuffleV3< Col, Row, Row_Row_Tuple, Row, F16, F16, F16_F16_Tuple, F16, F32, F32, PassThrough, PassThrough, AddAddFastGelu, GemmSpec, 64, 32, 64, 64, 8, 8, 16, 16, 2, 2, S<4, 16, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 1, 8, 0, S<4, 16, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 1, 8, 0, 1, 1, S<1, 16, 1, 4>, S<8, 8, 8>, Intrawave, V1>,
DeviceGemmMultipleD_Wmma_CShuffleV3< Col, Row, Row_Row_Tuple, Row, F16, F16, F16_F16_Tuple, F16, F32, F32, PassThrough, PassThrough, AddAddFastGelu, GemmSpec, 256, 128, 128, 32, 8, 8, 16, 16, 4, 2, S<4, 64, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 1, 8, 1, S<4, 64, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 1, 8, 1, 1, 1, S<1, 32, 1, 8>, S<8, 8, 8>, Interwave, V1>,
DeviceGemmMultipleD_Wmma_CShuffleV3< Col, Row, Row_Row_Tuple, Row, F16, F16, F16_F16_Tuple, F16, F32, F32, PassThrough, PassThrough, AddAddFastGelu, GemmSpec, 256, 128, 128, 64, 8, 8, 16, 16, 4, 2, S<4, 64, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 1, 8, 1, S<4, 64, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 1, 8, 1, 1, 1, S<1, 32, 1, 8>, S<8, 8, 8>, Interwave, V1>,
DeviceGemmMultipleD_Wmma_CShuffleV3< Col, Row, Row_Row_Tuple, Row, F16, F16, F16_F16_Tuple, F16, F32, F32, PassThrough, PassThrough, AddAddFastGelu, GemmSpec, 64, 32, 64, 64, 8, 8, 16, 16, 2, 2, S<4, 16, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 1, 8, 0, S<4, 16, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 1, 8, 0, 1, 1, S<1, 16, 1, 4>, S<8, 8, 8>, Interwave, V1>,
DeviceGemmMultipleD_Wmma_CShuffleV3< Col, Row, Row_Row_Tuple, Row, F16, F16, F16_F16_Tuple, F16, F32, F32, PassThrough, PassThrough, AddAddFastGelu, GemmSpec, 256, 128, 128, 32, 8, 8, 16, 16, 4, 2, S<4, 64, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 1, 8, 1, S<4, 64, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 1, 8, 1, 1, 1, S<1, 32, 1, 8>, S<8, 8, 8>, Intrawave, V3>,
DeviceGemmMultipleD_Wmma_CShuffleV3< Col, Row, Row_Row_Tuple, Row, F16, F16, F16_F16_Tuple, F16, F32, F32, PassThrough, PassThrough, AddAddFastGelu, GemmSpec, 256, 128, 128, 64, 8, 8, 16, 16, 4, 2, S<4, 64, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 1, 8, 1, S<4, 64, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 1, 8, 1, 1, 1, S<1, 32, 1, 8>, S<8, 8, 8>, Intrawave, V3>,
DeviceGemmMultipleD_Wmma_CShuffleV3< Col, Row, Row_Row_Tuple, Row, F16, F16, F16_F16_Tuple, F16, F32, F32, PassThrough, PassThrough, AddAddFastGelu, GemmSpec, 64, 32, 64, 64, 8, 8, 16, 16, 2, 2, S<4, 16, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 1, 8, 0, S<4, 16, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 1, 8, 0, 1, 1, S<1, 16, 1, 4>, S<8, 8, 8>, Intrawave, V3>
// clang-format on
>;
void add_device_gemm_add_add_fastgelu_wmma_c_shuffle_f16_f16_f16_f16_f16_km_kn_mn_mn_mn_instances(
std::vector<std::unique_ptr<DeviceGemmMultipleDSplitK<Col,
Row,
Row_Row_Tuple,
Row,
F16,
F16,
F16_F16_Tuple,
F16,
PassThrough,
PassThrough,
AddAddFastGelu>>>& instances)
{
add_device_operation_instances(
instances,
device_gemm_add_add_fastgelu_wmma_c_shuffle_f16_f16_f16_f16_f16_km_kn_mn_mn_mn_instance<
GemmDefault>{});
add_device_operation_instances(
instances,
device_gemm_add_add_fastgelu_wmma_c_shuffle_f16_f16_f16_f16_f16_km_kn_mn_mn_mn_instance<
GemmMNKPadding>{});
}
} // namespace instance
} // namespace device
} // namespace tensor_operation
} // namespace ck

View File

@@ -0,0 +1,80 @@
// SPDX-License-Identifier: MIT
// Copyright (c) 2025, Advanced Micro Devices, Inc. All rights reserved.
#include "ck/library/tensor_operation_instance/add_device_operation_instance.hpp"
#include "ck/library/tensor_operation_instance/device_operation_instance_factory.hpp"
#include "ck/tensor_operation/gpu/device/gemm_specialization.hpp"
#include "ck/tensor_operation/gpu/device/impl/device_gemm_multiple_d_wmma_cshuffle_v3.hpp"
#include "ck/utility/sequence.hpp"
namespace ck {
namespace tensor_operation {
namespace device {
namespace instance {
template <ck::index_t... Is>
using S = ck::Sequence<Is...>;
static constexpr auto GemmDefault = ck::tensor_operation::device::GemmSpecialization::Default;
static constexpr auto GemmMNKPadding = ck::tensor_operation::device::GemmSpecialization::MNKPadding;
static constexpr auto Interwave = BlockGemmPipelineScheduler::Interwave;
static constexpr auto Intrawave = BlockGemmPipelineScheduler::Intrawave;
static constexpr auto V1 = BlockGemmPipelineVersion::v1;
static constexpr auto V3 = BlockGemmPipelineVersion::v3;
// e = elementwise((a * b), d0, d1)
// elementwise(c, d0, d1) = fastgelu(c + d0 + d1)
// output: e[m, n]
// input: a[m, k], b[n, k], d0[m, n], d1[m, n]
template <GemmSpecialization GemmSpec>
using device_gemm_add_add_fastgelu_wmma_c_shuffle_f16_f16_f16_f16_f16_km_nk_mn_mn_mn_instances =
std::tuple<
// clang-format off
//##################################| ALayout| BLayout| DsLayout| ELayout| AData| BData| DsData| EData| AccData| CShuffle| A| B| CDE| GemmSpec| Block| MPer| NPer| KPer| AK1| BK1| MPer| NPer| MRepeat| NRepeat| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockLds| BBlockTransfer| BBlockTransfer| BBlockTransfer| BBlockTransfer| BBlockTransfer| BBlockTransfer| BBlockLds| CShuffle| CShuffle| CShuffleBlockTransfer| CDEShuffleBlockTransfer| BlkGemm| BlkGemm|
//##################################| | | | | Type| Type| Type| Type| Type| DataType| Elementwise| Elementwise| Elementwise| | Size| Block| Block| Block| | | Wmma| Wmma| | | ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| ExtraM| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| ExtraN| MRepeat| NRepeat| ClusterLengths| ScalarPerVectors| PipeSched| PipelineVer|
//##################################| | | | | | | | | | | Operation| Operation| Operation| | | | | | | | | | | | Lengths_AK0_M_AK1| ArrangeOrder| | | PerVector| PerVector_AK1| | Lengths_BK0_N_BK1| ArrangeOrder| | | PerVector| PerVector_BK1| | PerShuffle| PerShuffle| _MBlock_MPerBlock| | | |
//##################################| | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | _NBlock_NPerBlock| | | |
DeviceGemmMultipleD_Wmma_CShuffleV3< Col, Col, Row_Row_Tuple, Row, F16, F16, F16_F16_Tuple, F16, F32, F32, PassThrough, PassThrough, AddAddFastGelu, GemmSpec, 256, 128, 128, 32, 8, 8, 16, 16, 4, 2, S<4, 64, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 1, 8, 1, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, 1, 1, S<1, 32, 1, 8>, S<8, 8, 8>, Intrawave, V1>,
DeviceGemmMultipleD_Wmma_CShuffleV3< Col, Col, Row_Row_Tuple, Row, F16, F16, F16_F16_Tuple, F16, F32, F32, PassThrough, PassThrough, AddAddFastGelu, GemmSpec, 256, 128, 256, 64, 8, 8, 16, 16, 4, 4, S<4, 64, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 1, 8, 1, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, 1, 1, S<1, 32, 1, 8>, S<8, 8, 8>, Intrawave, V1>,
DeviceGemmMultipleD_Wmma_CShuffleV3< Col, Col, Row_Row_Tuple, Row, F16, F16, F16_F16_Tuple, F16, F32, F32, PassThrough, PassThrough, AddAddFastGelu, GemmSpec, 256, 128, 160, 64, 8, 8, 16, 16, 2, 5, S<4, 64, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 1, 8, 1, S<8, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, 1, 1, S<1, 64, 1, 4>, S<8, 8, 8>, Intrawave, V1>,
DeviceGemmMultipleD_Wmma_CShuffleV3< Col, Col, Row_Row_Tuple, Row, F16, F16, F16_F16_Tuple, F16, F32, F32, PassThrough, PassThrough, AddAddFastGelu, GemmSpec, 128, 64, 80, 64, 8, 8, 16, 16, 1, 5, S<4, 32, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 1, 8, 0, S<8, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 0, 1, 1, S<1, 64, 1, 2>, S<8, 8, 8>, Intrawave, V1>,
DeviceGemmMultipleD_Wmma_CShuffleV3< Col, Col, Row_Row_Tuple, Row, F16, F16, F16_F16_Tuple, F16, F32, F32, PassThrough, PassThrough, AddAddFastGelu, GemmSpec, 256, 128, 128, 32, 8, 8, 16, 16, 4, 2, S<4, 64, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 1, 8, 1, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, 1, 1, S<1, 32, 1, 8>, S<8, 8, 8>, Interwave, V1>,
DeviceGemmMultipleD_Wmma_CShuffleV3< Col, Col, Row_Row_Tuple, Row, F16, F16, F16_F16_Tuple, F16, F32, F32, PassThrough, PassThrough, AddAddFastGelu, GemmSpec, 256, 128, 128, 64, 8, 8, 16, 16, 4, 2, S<4, 64, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 1, 8, 1, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, 1, 1, S<1, 32, 1, 8>, S<8, 8, 8>, Interwave, V1>,
DeviceGemmMultipleD_Wmma_CShuffleV3< Col, Col, Row_Row_Tuple, Row, F16, F16, F16_F16_Tuple, F16, F32, F32, PassThrough, PassThrough, AddAddFastGelu, GemmSpec, 128, 64, 80, 64, 8, 8, 16, 16, 1, 5, S<4, 32, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 1, 8, 1, S<8, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, 1, 1, S<1, 64, 1, 2>, S<8, 8, 8>, Interwave, V1>,
DeviceGemmMultipleD_Wmma_CShuffleV3< Col, Col, Row_Row_Tuple, Row, F16, F16, F16_F16_Tuple, F16, F32, F32, PassThrough, PassThrough, AddAddFastGelu, GemmSpec, 64, 32, 64, 64, 8, 8, 16, 16, 2, 2, S<4, 16, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 1, 8, 0, S<4, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 0, 1, 1, S<1, 16, 1, 4>, S<8, 8, 8>, Interwave, V1>,
DeviceGemmMultipleD_Wmma_CShuffleV3< Col, Col, Row_Row_Tuple, Row, F16, F16, F16_F16_Tuple, F16, F32, F32, PassThrough, PassThrough, AddAddFastGelu, GemmSpec, 256, 128, 128, 32, 8, 8, 16, 16, 4, 2, S<4, 64, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 1, 8, 1, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, 1, 1, S<1, 32, 1, 8>, S<8, 8, 8>, Intrawave, V3>,
DeviceGemmMultipleD_Wmma_CShuffleV3< Col, Col, Row_Row_Tuple, Row, F16, F16, F16_F16_Tuple, F16, F32, F32, PassThrough, PassThrough, AddAddFastGelu, GemmSpec, 128, 128, 128, 32, 8, 8, 16, 16, 4, 4, S<4, 32, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 1, 8, 1, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, 1, 1, S<1, 32, 1, 4>, S<8, 8, 8>, Intrawave, V3>,
DeviceGemmMultipleD_Wmma_CShuffleV3< Col, Col, Row_Row_Tuple, Row, F16, F16, F16_F16_Tuple, F16, F32, F32, PassThrough, PassThrough, AddAddFastGelu, GemmSpec, 64, 64, 32, 64, 8, 8, 16, 16, 4, 1, S<4, 16, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 1, 8, 0, S<4, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 0, 1, 1, S<1, 16, 1, 4>, S<8, 8, 8>, Intrawave, V3>
// clang-format on
>;
void add_device_gemm_add_add_fastgelu_wmma_c_shuffle_f16_f16_f16_f16_f16_km_nk_mn_mn_mn_instances(
std::vector<std::unique_ptr<DeviceGemmMultipleDSplitK<Col,
Col,
Row_Row_Tuple,
Row,
F16,
F16,
F16_F16_Tuple,
F16,
PassThrough,
PassThrough,
AddAddFastGelu>>>& instances)
{
add_device_operation_instances(
instances,
device_gemm_add_add_fastgelu_wmma_c_shuffle_f16_f16_f16_f16_f16_km_nk_mn_mn_mn_instances<
GemmDefault>{});
add_device_operation_instances(
instances,
device_gemm_add_add_fastgelu_wmma_c_shuffle_f16_f16_f16_f16_f16_km_nk_mn_mn_mn_instances<
GemmMNKPadding>{});
}
} // namespace instance
} // namespace device
} // namespace tensor_operation
} // namespace ck

View File

@@ -0,0 +1,83 @@
// SPDX-License-Identifier: MIT
// Copyright (c) 2025, Advanced Micro Devices, Inc. All rights reserved.
#include "ck/library/tensor_operation_instance/add_device_operation_instance.hpp"
#include "ck/library/tensor_operation_instance/device_operation_instance_factory.hpp"
#include "ck/tensor_operation/gpu/device/gemm_specialization.hpp"
#include "ck/tensor_operation/gpu/device/impl/device_gemm_multiple_d_wmma_cshuffle_v3.hpp"
#include "ck/utility/sequence.hpp"
namespace ck {
namespace tensor_operation {
namespace device {
namespace instance {
template <ck::index_t... Is>
using S = ck::Sequence<Is...>;
static constexpr auto GemmDefault = ck::tensor_operation::device::GemmSpecialization::Default;
static constexpr auto GemmMNKPadding = ck::tensor_operation::device::GemmSpecialization::MNKPadding;
static constexpr auto Interwave = BlockGemmPipelineScheduler::Interwave;
static constexpr auto Intrawave = BlockGemmPipelineScheduler::Intrawave;
static constexpr auto V1 = BlockGemmPipelineVersion::v1;
static constexpr auto V3 = BlockGemmPipelineVersion::v3;
// e = elementwise((a * b), d0, d1)
// elementwise(c, d0, d1) = fastgelu(c + d0 + d1)
// output: e[m, n]
// input: a[m, k], b[n, k], d0[m, n], d1[m, n]
template <GemmSpecialization GemmSpec>
using device_gemm_add_add_fastgelu_wmma_c_shuffle_f16_f16_f16_f16_f16_mk_kn_mn_mn_mn_instances =
std::tuple<
// clang-format off
//##################################| ALayout| BLayout| DsLayout| ELayout| AData| BData| DsData| EData| AccData| CShuffle| A| B| CDE| GemmSpec| Block| MPer| NPer| KPer| AK1| BK1| MPer| NPer| MRepeat| NRepeat| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockLds| BBlockTransfer| BBlockTransfer| BBlockTransfer| BBlockTransfer| BBlockTransfer| BBlockTransfer| BBlockLds| CShuffle| CShuffle| CShuffleBlockTransfer| CDEShuffleBlockTransfer| BlkGemm| BlkGemm|
//##################################| | | | | Type| Type| Type| Type| Type| DataType| Elementwise| Elementwise| Elementwise| | Size| Block| Block| Block| | | Wmma| Wmma| | | ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| ExtraM| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| ExtraN| MRepeat| NRepeat| ClusterLengths| ScalarPerVectors| PipeSched| PipelineVer|
//##################################| | | | | | | | | | | Operation| Operation| Operation| | | | | | | | | | | | Lengths_AK0_M_AK1| ArrangeOrder| | | PerVector| PerVector_AK1| | Lengths_BK0_N_BK1| ArrangeOrder| | | PerVector| PerVector_BK1| | PerShuffle| PerShuffle| _MBlock_MPerBlock| | | |
//##################################| | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | _NBlock_NPerBlock| | | |
DeviceGemmMultipleD_Wmma_CShuffleV3< Row, Row, Row_Row_Tuple, Row, F16, F16, F16_F16_Tuple, F16, F32, F32, PassThrough, PassThrough, AddAddFastGelu, GemmSpec, 256, 128, 128, 32, 8, 8, 16, 16, 4, 2, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 0, S<4, 64, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 1, 8, 0, 1, 1, S<1, 32, 1, 8>, S<8, 8, 8>, Intrawave, V1>,
DeviceGemmMultipleD_Wmma_CShuffleV3< Row, Row, Row_Row_Tuple, Row, F16, F16, F16_F16_Tuple, F16, F32, F32, PassThrough, PassThrough, AddAddFastGelu, GemmSpec, 128, 128, 64, 64, 8, 8, 16, 16, 4, 2, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 0, S<4, 32, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 1, 8, 0, 1, 1, S<1, 32, 1, 4>, S<8, 8, 8>, Intrawave, V1>,
DeviceGemmMultipleD_Wmma_CShuffleV3< Row, Row, Row_Row_Tuple, Row, F16, F16, F16_F16_Tuple, F16, F32, F32, PassThrough, PassThrough, AddAddFastGelu, GemmSpec, 256, 128, 256, 64, 8, 8, 16, 16, 4, 4, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, S<4, 64, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 1, 8, 1, 1, 1, S<1, 32, 1, 8>, S<8, 8, 8>, Intrawave, V1>,
DeviceGemmMultipleD_Wmma_CShuffleV3< Row, Row, Row_Row_Tuple, Row, F16, F16, F16_F16_Tuple, F16, F32, F32, PassThrough, PassThrough, AddAddFastGelu, GemmSpec, 256, 128, 128, 64, 8, 8, 16, 16, 4, 2, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, S<4, 64, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 1, 8, 1, 1, 1, S<1, 32, 1, 8>, S<8, 8, 8>, Intrawave, V1>,
DeviceGemmMultipleD_Wmma_CShuffleV3< Row, Row, Row_Row_Tuple, Row, F16, F16, F16_F16_Tuple, F16, F32, F32, PassThrough, PassThrough, AddAddFastGelu, GemmSpec, 128, 64, 64, 32, 8, 8, 16, 16, 2, 2, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 0, S<4, 32, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 2, 4, 1, 1, 1, S<1, 32, 1, 2>, S<8, 8, 8>, Intrawave, V1>,
DeviceGemmMultipleD_Wmma_CShuffleV3< Row, Row, Row_Row_Tuple, Row, F16, F16, F16_F16_Tuple, F16, F32, F32, PassThrough, PassThrough, AddAddFastGelu, GemmSpec, 256, 128, 128, 32, 8, 8, 16, 16, 4, 2, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, S<4, 64, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 1, 8, 1, 1, 1, S<1, 32, 1, 8>, S<8, 8, 8>, Interwave, V1>,
DeviceGemmMultipleD_Wmma_CShuffleV3< Row, Row, Row_Row_Tuple, Row, F16, F16, F16_F16_Tuple, F16, F32, F32, PassThrough, PassThrough, AddAddFastGelu, GemmSpec, 256, 128, 128, 64, 8, 8, 16, 16, 4, 2, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, S<4, 64, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 1, 8, 1, 1, 1, S<1, 32, 1, 8>, S<8, 8, 8>, Interwave, V1>,
DeviceGemmMultipleD_Wmma_CShuffleV3< Row, Row, Row_Row_Tuple, Row, F16, F16, F16_F16_Tuple, F16, F32, F32, PassThrough, PassThrough, AddAddFastGelu, GemmSpec, 128, 128, 128, 32, 8, 8, 16, 16, 4, 4, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, S<4, 32, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 1, 8, 1, 1, 1, S<1, 32, 1, 4>, S<8, 8, 8>, Interwave, V1>,
DeviceGemmMultipleD_Wmma_CShuffleV3< Row, Row, Row_Row_Tuple, Row, F16, F16, F16_F16_Tuple, F16, F32, F32, PassThrough, PassThrough, AddAddFastGelu, GemmSpec, 64, 32, 64, 64, 8, 8, 16, 16, 2, 2, S<4, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 0, S<4, 16, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 1, 8, 0, 1, 1, S<1, 16, 1, 4>, S<8, 8, 8>, Interwave, V1>,
DeviceGemmMultipleD_Wmma_CShuffleV3< Row, Row, Row_Row_Tuple, Row, F16, F16, F16_F16_Tuple, F16, F32, F32, PassThrough, PassThrough, AddAddFastGelu, GemmSpec, 128, 64, 64, 32, 8, 8, 16, 16, 2, 2, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 0, S<4, 32, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 2, 4, 0, 1, 1, S<1, 32, 1, 2>, S<8, 8, 8>, Interwave, V1>,
DeviceGemmMultipleD_Wmma_CShuffleV3< Row, Row, Row_Row_Tuple, Row, F16, F16, F16_F16_Tuple, F16, F32, F32, PassThrough, PassThrough, AddAddFastGelu, GemmSpec, 256, 128, 128, 32, 8, 8, 16, 16, 4, 2, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, S<4, 64, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 1, 8, 1, 1, 1, S<1, 32, 1, 8>, S<8, 8, 8>, Intrawave, V3>,
DeviceGemmMultipleD_Wmma_CShuffleV3< Row, Row, Row_Row_Tuple, Row, F16, F16, F16_F16_Tuple, F16, F32, F32, PassThrough, PassThrough, AddAddFastGelu, GemmSpec, 256, 128, 128, 64, 8, 8, 16, 16, 4, 2, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, S<4, 64, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 1, 8, 1, 1, 1, S<1, 32, 1, 8>, S<8, 8, 8>, Intrawave, V3>,
DeviceGemmMultipleD_Wmma_CShuffleV3< Row, Row, Row_Row_Tuple, Row, F16, F16, F16_F16_Tuple, F16, F32, F32, PassThrough, PassThrough, AddAddFastGelu, GemmSpec, 128, 128, 128, 32, 8, 8, 16, 16, 4, 4, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, S<4, 32, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 1, 8, 1, 1, 1, S<1, 32, 1, 4>, S<8, 8, 8>, Intrawave, V3>,
DeviceGemmMultipleD_Wmma_CShuffleV3< Row, Row, Row_Row_Tuple, Row, F16, F16, F16_F16_Tuple, F16, F32, F32, PassThrough, PassThrough, AddAddFastGelu, GemmSpec, 128, 64, 64, 32, 8, 8, 16, 16, 2, 2, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 0, S<4, 32, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 2, 4, 1, 1, 1, S<1, 32, 1, 2>, S<8, 8, 8>, Intrawave, V3>
// clang-format on
>;
void add_device_gemm_add_add_fastgelu_wmma_c_shuffle_f16_f16_f16_f16_f16_mk_kn_mn_mn_mn_instances(
std::vector<std::unique_ptr<DeviceGemmMultipleDSplitK<Row,
Row,
Row_Row_Tuple,
Row,
F16,
F16,
F16_F16_Tuple,
F16,
PassThrough,
PassThrough,
AddAddFastGelu>>>& instances)
{
add_device_operation_instances(
instances,
device_gemm_add_add_fastgelu_wmma_c_shuffle_f16_f16_f16_f16_f16_mk_kn_mn_mn_mn_instances<
GemmDefault>{});
add_device_operation_instances(
instances,
device_gemm_add_add_fastgelu_wmma_c_shuffle_f16_f16_f16_f16_f16_mk_kn_mn_mn_mn_instances<
GemmMNKPadding>{});
}
} // namespace instance
} // namespace device
} // namespace tensor_operation
} // namespace ck

View File

@@ -0,0 +1,86 @@
// SPDX-License-Identifier: MIT
// Copyright (c) 2025, Advanced Micro Devices, Inc. All rights reserved.
#include "ck/library/tensor_operation_instance/add_device_operation_instance.hpp"
#include "ck/library/tensor_operation_instance/device_operation_instance_factory.hpp"
#include "ck/tensor_operation/gpu/device/gemm_specialization.hpp"
#include "ck/tensor_operation/gpu/device/impl/device_gemm_multiple_d_wmma_cshuffle_v3.hpp"
#include "ck/utility/sequence.hpp"
namespace ck {
namespace tensor_operation {
namespace device {
namespace instance {
template <ck::index_t... Is>
using S = ck::Sequence<Is...>;
static constexpr auto GemmDefault = ck::tensor_operation::device::GemmSpecialization::Default;
static constexpr auto GemmMNKPadding = ck::tensor_operation::device::GemmSpecialization::MNKPadding;
static constexpr auto Interwave = BlockGemmPipelineScheduler::Interwave;
static constexpr auto Intrawave = BlockGemmPipelineScheduler::Intrawave;
static constexpr auto V1 = BlockGemmPipelineVersion::v1;
static constexpr auto V3 = BlockGemmPipelineVersion::v3;
// e = elementwise((a * b), d0, d1)
// elementwise(c, d0, d1) = fastgelu(c + d0 + d1)
// output: e[m, n]
// input: a[m, k], b[n, k], d0[m, n], d1[m, n]
template <GemmSpecialization GemmSpec>
using device_gemm_add_add_fastgelu_wmma_c_shuffle_f16_f16_f16_f16_f16_mk_nk_mn_mn_mn_instances =
std::tuple<
// clang-format off
//##################################| ALayout| BLayout| DsLayout| ELayout| AData| BData| DsData| EData| AccData| CShuffle| A| B| CDE| GemmSpec| Block| MPer| NPer| KPer| AK1| BK1| MPer| NPer| MRepeat| NRepeat| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockLds| BBlockTransfer| BBlockTransfer| BBlockTransfer| BBlockTransfer| BBlockTransfer| BBlockTransfer| BBlockLds| CShuffle| CShuffle| CShuffleBlockTransfer| CDEShuffleBlockTransfer| BlkGemm| BlkGemm|
//##################################| | | | | Type| Type| Type| Type| Type| DataType| Elementwise| Elementwise| Elementwise| | Size| Block| Block| Block| | | Wmma| Wmma| | | ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| ExtraM| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| ExtraN| MRepeat| NRepeat| ClusterLengths| ScalarPerVectors| PipeSched| PipelineVer|
//##################################| | | | | | | | | | | Operation| Operation| Operation| | | | | | | | | | | | Lengths_AK0_M_AK1| ArrangeOrder| | | PerVector| PerVector_AK1| | Lengths_BK0_N_BK1| ArrangeOrder| | | PerVector| PerVector_BK1| | PerShuffle| PerShuffle| _MBlock_MPerBlock| | | |
//##################################| | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | _NBlock_NPerBlock| | | |
DeviceGemmMultipleD_Wmma_CShuffleV3< Row, Col, Row_Row_Tuple, Row, F16, F16, F16_F16_Tuple, F16, F32, F32, PassThrough, PassThrough, AddAddFastGelu, GemmSpec, 256, 128, 128, 32, 8, 8, 16, 16, 4, 2, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, 1, 1, S<1, 32, 1, 8>, S<8, 8, 8>, Intrawave, V1>,
DeviceGemmMultipleD_Wmma_CShuffleV3< Row, Col, Row_Row_Tuple, Row, F16, F16, F16_F16_Tuple, F16, F32, F32, PassThrough, PassThrough, AddAddFastGelu, GemmSpec, 128, 128, 64, 64, 8, 8, 16, 16, 4, 2, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 0, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 0, 1, 1, S<1, 32, 1, 4>, S<8, 8, 8>, Intrawave, V1>,
DeviceGemmMultipleD_Wmma_CShuffleV3< Row, Col, Row_Row_Tuple, Row, F16, F16, F16_F16_Tuple, F16, F32, F32, PassThrough, PassThrough, AddAddFastGelu, GemmSpec, 256, 128, 256, 64, 8, 8, 16, 16, 4, 4, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, 1, 1, S<1, 32, 1, 8>, S<8, 8, 8>, Intrawave, V1>,
DeviceGemmMultipleD_Wmma_CShuffleV3< Row, Col, Row_Row_Tuple, Row, F16, F16, F16_F16_Tuple, F16, F32, F32, PassThrough, PassThrough, AddAddFastGelu, GemmSpec, 256, 128, 128, 64, 8, 8, 16, 16, 4, 2, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 0, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 0, 1, 1, S<1, 32, 1, 8>, S<8, 8, 8>, Intrawave, V1>,
DeviceGemmMultipleD_Wmma_CShuffleV3< Row, Col, Row_Row_Tuple, Row, F16, F16, F16_F16_Tuple, F16, F32, F32, PassThrough, PassThrough, AddAddFastGelu, GemmSpec, 256, 128, 160, 64, 8, 8, 16, 16, 2, 5, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 0, S<8, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 0, 1, 1, S<1, 64, 1, 4>, S<8, 8, 8>, Intrawave, V1>,
DeviceGemmMultipleD_Wmma_CShuffleV3< Row, Col, Row_Row_Tuple, Row, F16, F16, F16_F16_Tuple, F16, F32, F32, PassThrough, PassThrough, AddAddFastGelu, GemmSpec, 128, 64, 80, 64, 8, 8, 16, 16, 1, 5, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 0, S<8, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 0, 1, 1, S<1, 64, 1, 2>, S<8, 8, 8>, Intrawave, V1>,
DeviceGemmMultipleD_Wmma_CShuffleV3< Row, Col, Row_Row_Tuple, Row, F16, F16, F16_F16_Tuple, F16, F32, F32, PassThrough, PassThrough, AddAddFastGelu, GemmSpec, 64, 32, 64, 64, 8, 8, 16, 16, 2, 2, S<4, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 0, S<4, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 0, 1, 1, S<1, 16, 1, 4>, S<8, 8, 8>, Intrawave, V1>,
DeviceGemmMultipleD_Wmma_CShuffleV3< Row, Col, Row_Row_Tuple, Row, F16, F16, F16_F16_Tuple, F16, F32, F32, PassThrough, PassThrough, AddAddFastGelu, GemmSpec, 256, 128, 128, 32, 8, 8, 16, 16, 4, 2, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, 1, 1, S<1, 32, 1, 8>, S<8, 8, 8>, Interwave, V1>,
DeviceGemmMultipleD_Wmma_CShuffleV3< Row, Col, Row_Row_Tuple, Row, F16, F16, F16_F16_Tuple, F16, F32, F32, PassThrough, PassThrough, AddAddFastGelu, GemmSpec, 128, 128, 64, 64, 8, 8, 16, 16, 4, 2, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 0, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 0, 1, 1, S<1, 32, 1, 4>, S<8, 8, 8>, Interwave, V1>,
DeviceGemmMultipleD_Wmma_CShuffleV3< Row, Col, Row_Row_Tuple, Row, F16, F16, F16_F16_Tuple, F16, F32, F32, PassThrough, PassThrough, AddAddFastGelu, GemmSpec, 256, 128, 160, 64, 8, 8, 16, 16, 2, 5, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, S<8, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, 1, 1, S<1, 64, 1, 4>, S<8, 8, 8>, Interwave, V1>,
DeviceGemmMultipleD_Wmma_CShuffleV3< Row, Col, Row_Row_Tuple, Row, F16, F16, F16_F16_Tuple, F16, F32, F32, PassThrough, PassThrough, AddAddFastGelu, GemmSpec, 128, 128, 128, 32, 8, 8, 16, 16, 4, 4, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, 1, 1, S<1, 32, 1, 4>, S<8, 8, 8>, Interwave, V1>,
DeviceGemmMultipleD_Wmma_CShuffleV3< Row, Col, Row_Row_Tuple, Row, F16, F16, F16_F16_Tuple, F16, F32, F32, PassThrough, PassThrough, AddAddFastGelu, GemmSpec, 64, 32, 64, 64, 8, 8, 16, 16, 2, 2, S<4, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 0, S<4, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 0, 1, 1, S<1, 16, 1, 4>, S<8, 8, 8>, Interwave, V1>,
DeviceGemmMultipleD_Wmma_CShuffleV3< Row, Col, Row_Row_Tuple, Row, F16, F16, F16_F16_Tuple, F16, F32, F32, PassThrough, PassThrough, AddAddFastGelu, GemmSpec, 256, 128, 128, 32, 8, 8, 16, 16, 4, 2, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, 1, 1, S<1, 32, 1, 8>, S<8, 8, 8>, Intrawave, V3>,
DeviceGemmMultipleD_Wmma_CShuffleV3< Row, Col, Row_Row_Tuple, Row, F16, F16, F16_F16_Tuple, F16, F32, F32, PassThrough, PassThrough, AddAddFastGelu, GemmSpec, 128, 128, 64, 64, 8, 8, 16, 16, 4, 2, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 0, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 0, 1, 1, S<1, 32, 1, 4>, S<8, 8, 8>, Intrawave, V3>,
DeviceGemmMultipleD_Wmma_CShuffleV3< Row, Col, Row_Row_Tuple, Row, F16, F16, F16_F16_Tuple, F16, F32, F32, PassThrough, PassThrough, AddAddFastGelu, GemmSpec, 256, 128, 160, 64, 8, 8, 16, 16, 2, 5, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, S<8, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, 1, 1, S<1, 64, 1, 4>, S<8, 8, 8>, Intrawave, V3>,
DeviceGemmMultipleD_Wmma_CShuffleV3< Row, Col, Row_Row_Tuple, Row, F16, F16, F16_F16_Tuple, F16, F32, F32, PassThrough, PassThrough, AddAddFastGelu, GemmSpec, 128, 128, 128, 32, 8, 8, 16, 16, 4, 4, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, 1, 1, S<1, 32, 1, 4>, S<8, 8, 8>, Intrawave, V3>,
DeviceGemmMultipleD_Wmma_CShuffleV3< Row, Col, Row_Row_Tuple, Row, F16, F16, F16_F16_Tuple, F16, F32, F32, PassThrough, PassThrough, AddAddFastGelu, GemmSpec, 64, 32, 64, 64, 8, 8, 16, 16, 2, 2, S<4, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 0, S<4, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 0, 1, 1, S<1, 16, 1, 4>, S<8, 8, 8>, Intrawave, V3>
// clang-format on
>;
void add_device_gemm_add_add_fastgelu_wmma_c_shuffle_f16_f16_f16_f16_f16_mk_nk_mn_mn_mn_instances(
std::vector<std::unique_ptr<DeviceGemmMultipleDSplitK<Row,
Col,
Row_Row_Tuple,
Row,
F16,
F16,
F16_F16_Tuple,
F16,
PassThrough,
PassThrough,
AddAddFastGelu>>>& instances)
{
add_device_operation_instances(
instances,
device_gemm_add_add_fastgelu_wmma_c_shuffle_f16_f16_f16_f16_f16_mk_nk_mn_mn_mn_instances<
GemmDefault>{});
add_device_operation_instances(
instances,
device_gemm_add_add_fastgelu_wmma_c_shuffle_f16_f16_f16_f16_f16_mk_nk_mn_mn_mn_instances<
GemmMNKPadding>{});
}
} // namespace instance
} // namespace device
} // namespace tensor_operation
} // namespace ck

View File

@@ -1,9 +1,14 @@
# ONLY XDL_KERNELS
# ONLY XDL_AND_WMMA_KERNELS
add_instance_library(device_gemm_add_fastgelu_instance
device_gemm_add_fastgelu_xdl_c_shuffle_bf16_i8_bf16_bf16_mk_kn_mn_mn_instance.cpp
device_gemm_add_fastgelu_xdl_c_shuffle_f16_f16_f16_f16_km_kn_mn_mn_instance.cpp
device_gemm_add_fastgelu_xdl_c_shuffle_f16_f16_f16_f16_km_nk_mn_mn_instance.cpp
device_gemm_add_fastgelu_xdl_c_shuffle_f16_f16_f16_f16_mk_kn_mn_mn_instance.cpp
device_gemm_add_fastgelu_xdl_c_shuffle_f16_f16_f16_f16_mk_nk_mn_mn_instance.cpp
device_gemm_add_fastgelu_xdl_c_shuffle_f16_i8_f16_f16_mk_kn_mn_mn_instance.cpp
device_gemm_add_fastgelu_wmma_c_shuffle_f16_f16_f16_f16_km_kn_mn_mn_instance.cpp
device_gemm_add_fastgelu_wmma_c_shuffle_f16_f16_f16_f16_km_nk_mn_mn_instance.cpp
device_gemm_add_fastgelu_wmma_c_shuffle_f16_f16_f16_f16_mk_kn_mn_mn_instance.cpp
device_gemm_add_fastgelu_wmma_c_shuffle_f16_f16_f16_f16_mk_nk_mn_mn_instance.cpp
device_gemm_add_fastgelu_xdl_c_shuffle_bf16_i8_bf16_bf16_mk_kn_mn_mn_instance.cpp
device_gemm_add_fastgelu_xdl_c_shuffle_f16_f16_f16_f16_km_kn_mn_mn_instance.cpp
device_gemm_add_fastgelu_xdl_c_shuffle_f16_f16_f16_f16_km_nk_mn_mn_instance.cpp
device_gemm_add_fastgelu_xdl_c_shuffle_f16_f16_f16_f16_mk_kn_mn_mn_instance.cpp
device_gemm_add_fastgelu_xdl_c_shuffle_f16_f16_f16_f16_mk_nk_mn_mn_instance.cpp
device_gemm_add_fastgelu_xdl_c_shuffle_f16_i8_f16_f16_mk_kn_mn_mn_instance.cpp
)

View File

@@ -0,0 +1,77 @@
// SPDX-License-Identifier: MIT
// Copyright (c) 2025, Advanced Micro Devices, Inc. All rights reserved.
#include "ck/library/tensor_operation_instance/add_device_operation_instance.hpp"
#include "ck/library/tensor_operation_instance/device_operation_instance_factory.hpp"
#include "ck/tensor_operation/gpu/device/gemm_specialization.hpp"
#include "ck/tensor_operation/gpu/device/impl/device_gemm_multiple_d_wmma_cshuffle_v3.hpp"
#include "ck/utility/sequence.hpp"
namespace ck {
namespace tensor_operation {
namespace device {
namespace instance {
template <ck::index_t... Is>
using S = ck::Sequence<Is...>;
static constexpr auto GemmDefault = ck::tensor_operation::device::GemmSpecialization::Default;
static constexpr auto GemmMNKPadding = ck::tensor_operation::device::GemmSpecialization::MNKPadding;
static constexpr auto Interwave = BlockGemmPipelineScheduler::Interwave;
static constexpr auto Intrawave = BlockGemmPipelineScheduler::Intrawave;
static constexpr auto V1 = BlockGemmPipelineVersion::v1;
static constexpr auto V3 = BlockGemmPipelineVersion::v3;
// e = elementwise((a * b), d0)
// elementwise(c, d0) = fastgelu(c + d0)
// output: e[m, n]
// input: a[m, k], b[n, k], d0[m, n]
template <GemmSpecialization GemmSpec>
using device_gemm_add_fastgelu_wmma_c_shuffle_f16_f16_f16_f16_km_kn_mn_mn_instance = std::tuple<
// clang-format off
//##################################| ALayout| BLayout| DsLayout| ELayout| AData| BData| DsData| EData| AccData| CShuffle| A| B| CDE| GemmSpec| Block| MPer| NPer| KPer| AK1| BK1| MPer| NPer| MRepeat| NRepeat| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockLds| BBlockTransfer| BBlockTransfer| BBlockTransfer| BBlockTransfer| BBlockTransfer| BBlockTransfer| BBlockLds| CShuffle| CShuffle| CShuffleBlockTransfer| CDEShuffleBlockTransfer| BlkGemm| BlkGemm|
//##################################| | | | | Type| Type| Type| Type| Type| DataType| Elementwise| Elementwise| Elementwise| | Size| Block| Block| Block| | | Wmma| Wmma| | | ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| ExtraM| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| ExtraN| MRepeat| NRepeat| ClusterLengths| ScalarPerVectors| PipeSched| PipelineVer|
//##################################| | | | | | | | | | | Operation| Operation| Operation| | | | | | | | | | | | Lengths_AK0_M_AK1| ArrangeOrder| | | PerVector| PerVector_AK1| | Lengths_BK0_N_BK1| ArrangeOrder| | | PerVector| PerVector_BK1| | PerShuffle| PerShuffle| _MBlock_MPerBlock| | | |
//##################################| | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | _NBlock_NPerBlock| | | |
DeviceGemmMultipleD_Wmma_CShuffleV3< Col, Row, Row_Tuple, Row, F16, F16, F16_Tuple, F16, F32, F32, PassThrough, PassThrough, AddFastGelu, GemmSpec, 256, 128, 128, 32, 8, 8, 16, 16, 4, 2, S<4, 64, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 1, 8, 1, S<4, 64, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 1, 8, 1, 1, 1, S<1, 32, 1, 8>, S<8, 8, 8>, Intrawave, V1>,
DeviceGemmMultipleD_Wmma_CShuffleV3< Col, Row, Row_Tuple, Row, F16, F16, F16_Tuple, F16, F32, F32, PassThrough, PassThrough, AddFastGelu, GemmSpec, 128, 128, 128, 32, 8, 8, 16, 16, 4, 4, S<4, 32, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 1, 8, 1, S<4, 32, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 1, 8, 1, 1, 1, S<1, 32, 1, 4>, S<8, 8, 8>, Intrawave, V1>,
DeviceGemmMultipleD_Wmma_CShuffleV3< Col, Row, Row_Tuple, Row, F16, F16, F16_Tuple, F16, F32, F32, PassThrough, PassThrough, AddFastGelu, GemmSpec, 64, 32, 64, 64, 8, 8, 16, 16, 2, 2, S<4, 16, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 1, 8, 0, S<4, 16, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 1, 8, 0, 1, 1, S<1, 16, 1, 4>, S<8, 8, 8>, Intrawave, V1>,
DeviceGemmMultipleD_Wmma_CShuffleV3< Col, Row, Row_Tuple, Row, F16, F16, F16_Tuple, F16, F32, F32, PassThrough, PassThrough, AddFastGelu, GemmSpec, 256, 128, 128, 32, 8, 8, 16, 16, 4, 2, S<4, 64, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 1, 8, 1, S<4, 64, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 1, 8, 1, 1, 1, S<1, 32, 1, 8>, S<8, 8, 8>, Interwave, V1>,
DeviceGemmMultipleD_Wmma_CShuffleV3< Col, Row, Row_Tuple, Row, F16, F16, F16_Tuple, F16, F32, F32, PassThrough, PassThrough, AddFastGelu, GemmSpec, 256, 128, 128, 64, 8, 8, 16, 16, 4, 2, S<4, 64, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 1, 8, 1, S<4, 64, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 1, 8, 1, 1, 1, S<1, 32, 1, 8>, S<8, 8, 8>, Interwave, V1>,
DeviceGemmMultipleD_Wmma_CShuffleV3< Col, Row, Row_Tuple, Row, F16, F16, F16_Tuple, F16, F32, F32, PassThrough, PassThrough, AddFastGelu, GemmSpec, 64, 32, 64, 64, 8, 8, 16, 16, 2, 2, S<4, 16, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 1, 8, 0, S<4, 16, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 1, 8, 0, 1, 1, S<1, 16, 1, 4>, S<8, 8, 8>, Interwave, V1>,
DeviceGemmMultipleD_Wmma_CShuffleV3< Col, Row, Row_Tuple, Row, F16, F16, F16_Tuple, F16, F32, F32, PassThrough, PassThrough, AddFastGelu, GemmSpec, 256, 128, 128, 32, 8, 8, 16, 16, 4, 2, S<4, 64, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 1, 8, 1, S<4, 64, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 1, 8, 1, 1, 1, S<1, 32, 1, 8>, S<8, 8, 8>, Intrawave, V3>,
DeviceGemmMultipleD_Wmma_CShuffleV3< Col, Row, Row_Tuple, Row, F16, F16, F16_Tuple, F16, F32, F32, PassThrough, PassThrough, AddFastGelu, GemmSpec, 256, 128, 128, 64, 8, 8, 16, 16, 4, 2, S<4, 64, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 1, 8, 1, S<4, 64, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 1, 8, 1, 1, 1, S<1, 32, 1, 8>, S<8, 8, 8>, Intrawave, V3>,
DeviceGemmMultipleD_Wmma_CShuffleV3< Col, Row, Row_Tuple, Row, F16, F16, F16_Tuple, F16, F32, F32, PassThrough, PassThrough, AddFastGelu, GemmSpec, 64, 32, 64, 64, 8, 8, 16, 16, 2, 2, S<4, 16, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 1, 8, 0, S<4, 16, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 1, 8, 0, 1, 1, S<1, 16, 1, 4>, S<8, 8, 8>, Intrawave, V3>
// clang-format on
>;
void add_device_gemm_add_fastgelu_wmma_c_shuffle_f16_f16_f16_f16_km_kn_mn_mn_instances(
std::vector<std::unique_ptr<DeviceGemmMultipleDSplitK<Col,
Row,
Row_Tuple,
Row,
F16,
F16,
F16_Tuple,
F16,
PassThrough,
PassThrough,
AddFastGelu>>>& instances)
{
add_device_operation_instances(
instances,
device_gemm_add_fastgelu_wmma_c_shuffle_f16_f16_f16_f16_km_kn_mn_mn_instance<
GemmDefault>{});
add_device_operation_instances(
instances,
device_gemm_add_fastgelu_wmma_c_shuffle_f16_f16_f16_f16_km_kn_mn_mn_instance<
GemmMNKPadding>{});
}
} // namespace instance
} // namespace device
} // namespace tensor_operation
} // namespace ck

View File

@@ -0,0 +1,79 @@
// SPDX-License-Identifier: MIT
// Copyright (c) 2025, Advanced Micro Devices, Inc. All rights reserved.
#include "ck/library/tensor_operation_instance/add_device_operation_instance.hpp"
#include "ck/library/tensor_operation_instance/device_operation_instance_factory.hpp"
#include "ck/tensor_operation/gpu/device/gemm_specialization.hpp"
#include "ck/tensor_operation/gpu/device/impl/device_gemm_multiple_d_wmma_cshuffle_v3.hpp"
#include "ck/utility/sequence.hpp"
namespace ck {
namespace tensor_operation {
namespace device {
namespace instance {
template <ck::index_t... Is>
using S = ck::Sequence<Is...>;
static constexpr auto GemmDefault = ck::tensor_operation::device::GemmSpecialization::Default;
static constexpr auto GemmMNKPadding = ck::tensor_operation::device::GemmSpecialization::MNKPadding;
static constexpr auto Interwave = BlockGemmPipelineScheduler::Interwave;
static constexpr auto Intrawave = BlockGemmPipelineScheduler::Intrawave;
static constexpr auto V1 = BlockGemmPipelineVersion::v1;
static constexpr auto V3 = BlockGemmPipelineVersion::v3;
// e = elementwise((a * b), d0)
// elementwise(c, d0) = fastgelu(c + d0)
// output: e[m, n]
// input: a[m, k], b[n, k], d0[m, n]
template <GemmSpecialization GemmSpec>
using device_gemm_add_fastgelu_wmma_c_shuffle_f16_f16_f16_f16_km_nk_mn_mn_instances = std::tuple<
// clang-format off
//##################################| ALayout| BLayout| DsLayout| ELayout| AData| BData| DsData| EData| AccData| CShuffle| A| B| CDE| GemmSpec| Block| MPer| NPer| KPer| AK1| BK1| MPer| NPer| MRepeat| NRepeat| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockLds| BBlockTransfer| BBlockTransfer| BBlockTransfer| BBlockTransfer| BBlockTransfer| BBlockTransfer| BBlockLds| CShuffle| CShuffle| CShuffleBlockTransfer| CDEShuffleBlockTransfer| BlkGemm| BlkGemm|
//##################################| | | | | Type| Type| Type| Type| Type| DataType| Elementwise| Elementwise| Elementwise| | Size| Block| Block| Block| | | Wmma| Wmma| | | ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| ExtraM| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| ExtraN| MRepeat| NRepeat| ClusterLengths| ScalarPerVectors| PipeSched| PipelineVer|
//##################################| | | | | | | | | | | Operation| Operation| Operation| | | | | | | | | | | | Lengths_AK0_M_AK1| ArrangeOrder| | | PerVector| PerVector_AK1| | Lengths_BK0_N_BK1| ArrangeOrder| | | PerVector| PerVector_BK1| | PerShuffle| PerShuffle| _MBlock_MPerBlock| | | |
//##################################| | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | _NBlock_NPerBlock| | | |
DeviceGemmMultipleD_Wmma_CShuffleV3< Col, Col, Row_Tuple, Row, F16, F16, F16_Tuple, F16, F32, F32, PassThrough, PassThrough, AddFastGelu, GemmSpec, 256, 128, 128, 32, 8, 8, 16, 16, 4, 2, S<4, 64, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 1, 8, 1, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, 1, 1, S<1, 32, 1, 8>, S<8, 8, 8>, Intrawave, V1>,
DeviceGemmMultipleD_Wmma_CShuffleV3< Col, Col, Row_Tuple, Row, F16, F16, F16_Tuple, F16, F32, F32, PassThrough, PassThrough, AddFastGelu, GemmSpec, 256, 128, 256, 64, 8, 8, 16, 16, 4, 4, S<4, 64, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 1, 8, 1, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, 1, 1, S<1, 32, 1, 8>, S<8, 8, 8>, Intrawave, V1>,
DeviceGemmMultipleD_Wmma_CShuffleV3< Col, Col, Row_Tuple, Row, F16, F16, F16_Tuple, F16, F32, F32, PassThrough, PassThrough, AddFastGelu, GemmSpec, 256, 128, 160, 64, 8, 8, 16, 16, 2, 5, S<4, 64, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 1, 8, 1, S<8, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, 1, 1, S<1, 64, 1, 4>, S<8, 8, 8>, Intrawave, V1>,
DeviceGemmMultipleD_Wmma_CShuffleV3< Col, Col, Row_Tuple, Row, F16, F16, F16_Tuple, F16, F32, F32, PassThrough, PassThrough, AddFastGelu, GemmSpec, 128, 64, 80, 64, 8, 8, 16, 16, 1, 5, S<4, 32, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 1, 8, 0, S<8, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 0, 1, 1, S<1, 64, 1, 2>, S<8, 8, 8>, Intrawave, V1>,
DeviceGemmMultipleD_Wmma_CShuffleV3< Col, Col, Row_Tuple, Row, F16, F16, F16_Tuple, F16, F32, F32, PassThrough, PassThrough, AddFastGelu, GemmSpec, 256, 128, 128, 32, 8, 8, 16, 16, 4, 2, S<4, 64, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 1, 8, 1, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, 1, 1, S<1, 32, 1, 8>, S<8, 8, 8>, Interwave, V1>,
DeviceGemmMultipleD_Wmma_CShuffleV3< Col, Col, Row_Tuple, Row, F16, F16, F16_Tuple, F16, F32, F32, PassThrough, PassThrough, AddFastGelu, GemmSpec, 256, 128, 128, 64, 8, 8, 16, 16, 4, 2, S<4, 64, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 1, 8, 1, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, 1, 1, S<1, 32, 1, 8>, S<8, 8, 8>, Interwave, V1>,
DeviceGemmMultipleD_Wmma_CShuffleV3< Col, Col, Row_Tuple, Row, F16, F16, F16_Tuple, F16, F32, F32, PassThrough, PassThrough, AddFastGelu, GemmSpec, 128, 64, 80, 64, 8, 8, 16, 16, 1, 5, S<4, 32, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 1, 8, 1, S<8, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, 1, 1, S<1, 64, 1, 2>, S<8, 8, 8>, Interwave, V1>,
DeviceGemmMultipleD_Wmma_CShuffleV3< Col, Col, Row_Tuple, Row, F16, F16, F16_Tuple, F16, F32, F32, PassThrough, PassThrough, AddFastGelu, GemmSpec, 64, 32, 64, 64, 8, 8, 16, 16, 2, 2, S<4, 16, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 1, 8, 0, S<4, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 0, 1, 1, S<1, 16, 1, 4>, S<8, 8, 8>, Interwave, V1>,
DeviceGemmMultipleD_Wmma_CShuffleV3< Col, Col, Row_Tuple, Row, F16, F16, F16_Tuple, F16, F32, F32, PassThrough, PassThrough, AddFastGelu, GemmSpec, 256, 128, 128, 32, 8, 8, 16, 16, 4, 2, S<4, 64, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 1, 8, 1, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, 1, 1, S<1, 32, 1, 8>, S<8, 8, 8>, Intrawave, V3>,
DeviceGemmMultipleD_Wmma_CShuffleV3< Col, Col, Row_Tuple, Row, F16, F16, F16_Tuple, F16, F32, F32, PassThrough, PassThrough, AddFastGelu, GemmSpec, 128, 128, 128, 32, 8, 8, 16, 16, 4, 4, S<4, 32, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 1, 8, 1, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, 1, 1, S<1, 32, 1, 4>, S<8, 8, 8>, Intrawave, V3>,
DeviceGemmMultipleD_Wmma_CShuffleV3< Col, Col, Row_Tuple, Row, F16, F16, F16_Tuple, F16, F32, F32, PassThrough, PassThrough, AddFastGelu, GemmSpec, 64, 64, 32, 64, 8, 8, 16, 16, 4, 1, S<4, 16, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 1, 8, 0, S<4, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 0, 1, 1, S<1, 16, 1, 4>, S<8, 8, 8>, Intrawave, V3>
// clang-format on
>;
void add_device_gemm_add_fastgelu_wmma_c_shuffle_f16_f16_f16_f16_km_nk_mn_mn_instances(
std::vector<std::unique_ptr<DeviceGemmMultipleDSplitK<Col,
Col,
Row_Tuple,
Row,
F16,
F16,
F16_Tuple,
F16,
PassThrough,
PassThrough,
AddFastGelu>>>& instances)
{
add_device_operation_instances(
instances,
device_gemm_add_fastgelu_wmma_c_shuffle_f16_f16_f16_f16_km_nk_mn_mn_instances<
GemmDefault>{});
add_device_operation_instances(
instances,
device_gemm_add_fastgelu_wmma_c_shuffle_f16_f16_f16_f16_km_nk_mn_mn_instances<
GemmMNKPadding>{});
}
} // namespace instance
} // namespace device
} // namespace tensor_operation
} // namespace ck

View File

@@ -0,0 +1,82 @@
// SPDX-License-Identifier: MIT
// Copyright (c) 2025, Advanced Micro Devices, Inc. All rights reserved.
#include "ck/library/tensor_operation_instance/add_device_operation_instance.hpp"
#include "ck/library/tensor_operation_instance/device_operation_instance_factory.hpp"
#include "ck/tensor_operation/gpu/device/gemm_specialization.hpp"
#include "ck/tensor_operation/gpu/device/impl/device_gemm_multiple_d_wmma_cshuffle_v3.hpp"
#include "ck/utility/sequence.hpp"
namespace ck {
namespace tensor_operation {
namespace device {
namespace instance {
template <ck::index_t... Is>
using S = ck::Sequence<Is...>;
static constexpr auto GemmDefault = ck::tensor_operation::device::GemmSpecialization::Default;
static constexpr auto GemmMNKPadding = ck::tensor_operation::device::GemmSpecialization::MNKPadding;
static constexpr auto Interwave = BlockGemmPipelineScheduler::Interwave;
static constexpr auto Intrawave = BlockGemmPipelineScheduler::Intrawave;
static constexpr auto V1 = BlockGemmPipelineVersion::v1;
static constexpr auto V3 = BlockGemmPipelineVersion::v3;
// e = elementwise((a * b), d0)
// elementwise(c, d0) = fastgelu(c + d0)
// output: e[m, n]
// input: a[m, k], b[n, k], d0[m, n]
template <GemmSpecialization GemmSpec>
using device_gemm_add_fastgelu_wmma_c_shuffle_f16_f16_f16_f16_mk_kn_mn_mn_instances = std::tuple<
// clang-format off
//##################################| ALayout| BLayout| DsLayout| ELayout| AData| BData| DsData| EData| AccData| CShuffle| A| B| CDE| GemmSpec| Block| MPer| NPer| KPer| AK1| BK1| MPer| NPer| MRepeat| NRepeat| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockLds| BBlockTransfer| BBlockTransfer| BBlockTransfer| BBlockTransfer| BBlockTransfer| BBlockTransfer| BBlockLds| CShuffle| CShuffle| CShuffleBlockTransfer| CDEShuffleBlockTransfer| BlkGemm| BlkGemm|
//##################################| | | | | Type| Type| Type| Type| Type| DataType| Elementwise| Elementwise| Elementwise| | Size| Block| Block| Block| | | Wmma| Wmma| | | ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| ExtraM| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| ExtraN| MRepeat| NRepeat| ClusterLengths| ScalarPerVectors| PipeSched| PipelineVer|
//##################################| | | | | | | | | | | Operation| Operation| Operation| | | | | | | | | | | | Lengths_AK0_M_AK1| ArrangeOrder| | | PerVector| PerVector_AK1| | Lengths_BK0_N_BK1| ArrangeOrder| | | PerVector| PerVector_BK1| | PerShuffle| PerShuffle| _MBlock_MPerBlock| | | |
//##################################| | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | _NBlock_NPerBlock| | | |
DeviceGemmMultipleD_Wmma_CShuffleV3< Row, Row, Row_Tuple, Row, F16, F16, F16_Tuple, F16, F32, F32, PassThrough, PassThrough, AddFastGelu, GemmSpec, 256, 128, 128, 32, 8, 8, 16, 16, 4, 2, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 0, S<4, 64, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 1, 8, 0, 1, 1, S<1, 32, 1, 8>, S<8, 8, 8>, Intrawave, V1>,
DeviceGemmMultipleD_Wmma_CShuffleV3< Row, Row, Row_Tuple, Row, F16, F16, F16_Tuple, F16, F32, F32, PassThrough, PassThrough, AddFastGelu, GemmSpec, 128, 128, 64, 64, 8, 8, 16, 16, 4, 2, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 0, S<4, 32, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 1, 8, 0, 1, 1, S<1, 32, 1, 4>, S<8, 8, 8>, Intrawave, V1>,
DeviceGemmMultipleD_Wmma_CShuffleV3< Row, Row, Row_Tuple, Row, F16, F16, F16_Tuple, F16, F32, F32, PassThrough, PassThrough, AddFastGelu, GemmSpec, 256, 128, 256, 64, 8, 8, 16, 16, 4, 4, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, S<4, 64, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 1, 8, 1, 1, 1, S<1, 32, 1, 8>, S<8, 8, 8>, Intrawave, V1>,
DeviceGemmMultipleD_Wmma_CShuffleV3< Row, Row, Row_Tuple, Row, F16, F16, F16_Tuple, F16, F32, F32, PassThrough, PassThrough, AddFastGelu, GemmSpec, 256, 128, 128, 64, 8, 8, 16, 16, 4, 2, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, S<4, 64, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 1, 8, 1, 1, 1, S<1, 32, 1, 8>, S<8, 8, 8>, Intrawave, V1>,
DeviceGemmMultipleD_Wmma_CShuffleV3< Row, Row, Row_Tuple, Row, F16, F16, F16_Tuple, F16, F32, F32, PassThrough, PassThrough, AddFastGelu, GemmSpec, 128, 64, 64, 32, 8, 8, 16, 16, 2, 2, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 0, S<4, 32, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 2, 4, 1, 1, 1, S<1, 32, 1, 2>, S<8, 8, 8>, Intrawave, V1>,
DeviceGemmMultipleD_Wmma_CShuffleV3< Row, Row, Row_Tuple, Row, F16, F16, F16_Tuple, F16, F32, F32, PassThrough, PassThrough, AddFastGelu, GemmSpec, 256, 128, 128, 32, 8, 8, 16, 16, 4, 2, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, S<4, 64, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 1, 8, 1, 1, 1, S<1, 32, 1, 8>, S<8, 8, 8>, Interwave, V1>,
DeviceGemmMultipleD_Wmma_CShuffleV3< Row, Row, Row_Tuple, Row, F16, F16, F16_Tuple, F16, F32, F32, PassThrough, PassThrough, AddFastGelu, GemmSpec, 256, 128, 128, 64, 8, 8, 16, 16, 4, 2, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, S<4, 64, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 1, 8, 1, 1, 1, S<1, 32, 1, 8>, S<8, 8, 8>, Interwave, V1>,
DeviceGemmMultipleD_Wmma_CShuffleV3< Row, Row, Row_Tuple, Row, F16, F16, F16_Tuple, F16, F32, F32, PassThrough, PassThrough, AddFastGelu, GemmSpec, 128, 128, 128, 32, 8, 8, 16, 16, 4, 4, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, S<4, 32, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 1, 8, 1, 1, 1, S<1, 32, 1, 4>, S<8, 8, 8>, Interwave, V1>,
DeviceGemmMultipleD_Wmma_CShuffleV3< Row, Row, Row_Tuple, Row, F16, F16, F16_Tuple, F16, F32, F32, PassThrough, PassThrough, AddFastGelu, GemmSpec, 64, 32, 64, 64, 8, 8, 16, 16, 2, 2, S<4, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 0, S<4, 16, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 1, 8, 0, 1, 1, S<1, 16, 1, 4>, S<8, 8, 8>, Interwave, V1>,
DeviceGemmMultipleD_Wmma_CShuffleV3< Row, Row, Row_Tuple, Row, F16, F16, F16_Tuple, F16, F32, F32, PassThrough, PassThrough, AddFastGelu, GemmSpec, 128, 64, 64, 32, 8, 8, 16, 16, 2, 2, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 0, S<4, 32, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 2, 4, 0, 1, 1, S<1, 32, 1, 2>, S<8, 8, 8>, Interwave, V1>,
DeviceGemmMultipleD_Wmma_CShuffleV3< Row, Row, Row_Tuple, Row, F16, F16, F16_Tuple, F16, F32, F32, PassThrough, PassThrough, AddFastGelu, GemmSpec, 256, 128, 128, 32, 8, 8, 16, 16, 4, 2, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, S<4, 64, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 1, 8, 1, 1, 1, S<1, 32, 1, 8>, S<8, 8, 8>, Intrawave, V3>,
DeviceGemmMultipleD_Wmma_CShuffleV3< Row, Row, Row_Tuple, Row, F16, F16, F16_Tuple, F16, F32, F32, PassThrough, PassThrough, AddFastGelu, GemmSpec, 256, 128, 128, 64, 8, 8, 16, 16, 4, 2, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, S<4, 64, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 1, 8, 1, 1, 1, S<1, 32, 1, 8>, S<8, 8, 8>, Intrawave, V3>,
DeviceGemmMultipleD_Wmma_CShuffleV3< Row, Row, Row_Tuple, Row, F16, F16, F16_Tuple, F16, F32, F32, PassThrough, PassThrough, AddFastGelu, GemmSpec, 128, 128, 128, 32, 8, 8, 16, 16, 4, 4, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, S<4, 32, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 1, 8, 1, 1, 1, S<1, 32, 1, 4>, S<8, 8, 8>, Intrawave, V3>,
DeviceGemmMultipleD_Wmma_CShuffleV3< Row, Row, Row_Tuple, Row, F16, F16, F16_Tuple, F16, F32, F32, PassThrough, PassThrough, AddFastGelu, GemmSpec, 128, 64, 64, 32, 8, 8, 16, 16, 2, 2, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 0, S<4, 32, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 2, 4, 1, 1, 1, S<1, 32, 1, 2>, S<8, 8, 8>, Intrawave, V3>
// clang-format on
>;
void add_device_gemm_add_fastgelu_wmma_c_shuffle_f16_f16_f16_f16_mk_kn_mn_mn_instances(
std::vector<std::unique_ptr<DeviceGemmMultipleDSplitK<Row,
Row,
Row_Tuple,
Row,
F16,
F16,
F16_Tuple,
F16,
PassThrough,
PassThrough,
AddFastGelu>>>& instances)
{
add_device_operation_instances(
instances,
device_gemm_add_fastgelu_wmma_c_shuffle_f16_f16_f16_f16_mk_kn_mn_mn_instances<
GemmDefault>{});
add_device_operation_instances(
instances,
device_gemm_add_fastgelu_wmma_c_shuffle_f16_f16_f16_f16_mk_kn_mn_mn_instances<
GemmMNKPadding>{});
}
} // namespace instance
} // namespace device
} // namespace tensor_operation
} // namespace ck

View File

@@ -0,0 +1,85 @@
// SPDX-License-Identifier: MIT
// Copyright (c) 2025, Advanced Micro Devices, Inc. All rights reserved.
#include "ck/library/tensor_operation_instance/add_device_operation_instance.hpp"
#include "ck/library/tensor_operation_instance/device_operation_instance_factory.hpp"
#include "ck/tensor_operation/gpu/device/gemm_specialization.hpp"
#include "ck/tensor_operation/gpu/device/impl/device_gemm_multiple_d_wmma_cshuffle_v3.hpp"
#include "ck/utility/sequence.hpp"
namespace ck {
namespace tensor_operation {
namespace device {
namespace instance {
template <ck::index_t... Is>
using S = ck::Sequence<Is...>;
static constexpr auto GemmDefault = ck::tensor_operation::device::GemmSpecialization::Default;
static constexpr auto GemmMNKPadding = ck::tensor_operation::device::GemmSpecialization::MNKPadding;
static constexpr auto Interwave = BlockGemmPipelineScheduler::Interwave;
static constexpr auto Intrawave = BlockGemmPipelineScheduler::Intrawave;
static constexpr auto V1 = BlockGemmPipelineVersion::v1;
static constexpr auto V3 = BlockGemmPipelineVersion::v3;
// e = elementwise((a * b), d0)
// elementwise(c, d0) = fastgelu(c + d0)
// output: e[m, n]
// input: a[m, k], b[n, k], d0[m, n]
template <GemmSpecialization GemmSpec>
using device_gemm_add_fastgelu_wmma_c_shuffle_f16_f16_f16_f16_mk_nk_mn_mn_instances = std::tuple<
// clang-format off
//##################################| ALayout| BLayout| DsLayout| ELayout| AData| BData| DsData| EData| AccData| CShuffle| A| B| CDE| GemmSpec| Block| MPer| NPer| KPer| AK1| BK1| MPer| NPer| MRepeat| NRepeat| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockLds| BBlockTransfer| BBlockTransfer| BBlockTransfer| BBlockTransfer| BBlockTransfer| BBlockTransfer| BBlockLds| CShuffle| CShuffle| CShuffleBlockTransfer| CDEShuffleBlockTransfer| BlkGemm| BlkGemm|
//##################################| | | | | Type| Type| Type| Type| Type| DataType| Elementwise| Elementwise| Elementwise| | Size| Block| Block| Block| | | Wmma| Wmma| | | ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| ExtraM| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| ExtraN| MRepeat| NRepeat| ClusterLengths| ScalarPerVectors| PipeSched| PipelineVer|
//##################################| | | | | | | | | | | Operation| Operation| Operation| | | | | | | | | | | | Lengths_AK0_M_AK1| ArrangeOrder| | | PerVector| PerVector_AK1| | Lengths_BK0_N_BK1| ArrangeOrder| | | PerVector| PerVector_BK1| | PerShuffle| PerShuffle| _MBlock_MPerBlock| | | |
//##################################| | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | _NBlock_NPerBlock| | | |
DeviceGemmMultipleD_Wmma_CShuffleV3< Row, Col, Row_Tuple, Row, F16, F16, F16_Tuple, F16, F32, F32, PassThrough, PassThrough, AddFastGelu, GemmSpec, 256, 128, 128, 32, 8, 8, 16, 16, 4, 2, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, 1, 1, S<1, 32, 1, 8>, S<8, 8, 8>, Intrawave, V1>,
DeviceGemmMultipleD_Wmma_CShuffleV3< Row, Col, Row_Tuple, Row, F16, F16, F16_Tuple, F16, F32, F32, PassThrough, PassThrough, AddFastGelu, GemmSpec, 128, 128, 64, 64, 8, 8, 16, 16, 4, 2, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 0, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 0, 1, 1, S<1, 32, 1, 4>, S<8, 8, 8>, Intrawave, V1>,
DeviceGemmMultipleD_Wmma_CShuffleV3< Row, Col, Row_Tuple, Row, F16, F16, F16_Tuple, F16, F32, F32, PassThrough, PassThrough, AddFastGelu, GemmSpec, 256, 128, 256, 64, 8, 8, 16, 16, 4, 4, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, 1, 1, S<1, 32, 1, 8>, S<8, 8, 8>, Intrawave, V1>,
DeviceGemmMultipleD_Wmma_CShuffleV3< Row, Col, Row_Tuple, Row, F16, F16, F16_Tuple, F16, F32, F32, PassThrough, PassThrough, AddFastGelu, GemmSpec, 256, 128, 128, 64, 8, 8, 16, 16, 4, 2, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 0, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 0, 1, 1, S<1, 32, 1, 8>, S<8, 8, 8>, Intrawave, V1>,
DeviceGemmMultipleD_Wmma_CShuffleV3< Row, Col, Row_Tuple, Row, F16, F16, F16_Tuple, F16, F32, F32, PassThrough, PassThrough, AddFastGelu, GemmSpec, 256, 128, 160, 64, 8, 8, 16, 16, 2, 5, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 0, S<8, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 0, 1, 1, S<1, 64, 1, 4>, S<8, 8, 8>, Intrawave, V1>,
DeviceGemmMultipleD_Wmma_CShuffleV3< Row, Col, Row_Tuple, Row, F16, F16, F16_Tuple, F16, F32, F32, PassThrough, PassThrough, AddFastGelu, GemmSpec, 128, 64, 80, 64, 8, 8, 16, 16, 1, 5, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 0, S<8, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 0, 1, 1, S<1, 64, 1, 2>, S<8, 8, 8>, Intrawave, V1>,
DeviceGemmMultipleD_Wmma_CShuffleV3< Row, Col, Row_Tuple, Row, F16, F16, F16_Tuple, F16, F32, F32, PassThrough, PassThrough, AddFastGelu, GemmSpec, 64, 32, 64, 64, 8, 8, 16, 16, 2, 2, S<4, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 0, S<4, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 0, 1, 1, S<1, 16, 1, 4>, S<8, 8, 8>, Intrawave, V1>,
DeviceGemmMultipleD_Wmma_CShuffleV3< Row, Col, Row_Tuple, Row, F16, F16, F16_Tuple, F16, F32, F32, PassThrough, PassThrough, AddFastGelu, GemmSpec, 256, 128, 128, 32, 8, 8, 16, 16, 4, 2, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, 1, 1, S<1, 32, 1, 8>, S<8, 8, 8>, Interwave, V1>,
DeviceGemmMultipleD_Wmma_CShuffleV3< Row, Col, Row_Tuple, Row, F16, F16, F16_Tuple, F16, F32, F32, PassThrough, PassThrough, AddFastGelu, GemmSpec, 128, 128, 64, 64, 8, 8, 16, 16, 4, 2, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 0, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 0, 1, 1, S<1, 32, 1, 4>, S<8, 8, 8>, Interwave, V1>,
DeviceGemmMultipleD_Wmma_CShuffleV3< Row, Col, Row_Tuple, Row, F16, F16, F16_Tuple, F16, F32, F32, PassThrough, PassThrough, AddFastGelu, GemmSpec, 256, 128, 160, 64, 8, 8, 16, 16, 2, 5, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, S<8, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, 1, 1, S<1, 64, 1, 4>, S<8, 8, 8>, Interwave, V1>,
DeviceGemmMultipleD_Wmma_CShuffleV3< Row, Col, Row_Tuple, Row, F16, F16, F16_Tuple, F16, F32, F32, PassThrough, PassThrough, AddFastGelu, GemmSpec, 128, 128, 128, 32, 8, 8, 16, 16, 4, 4, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, 1, 1, S<1, 32, 1, 4>, S<8, 8, 8>, Interwave, V1>,
DeviceGemmMultipleD_Wmma_CShuffleV3< Row, Col, Row_Tuple, Row, F16, F16, F16_Tuple, F16, F32, F32, PassThrough, PassThrough, AddFastGelu, GemmSpec, 64, 32, 64, 64, 8, 8, 16, 16, 2, 2, S<4, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 0, S<4, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 0, 1, 1, S<1, 16, 1, 4>, S<8, 8, 8>, Interwave, V1>,
DeviceGemmMultipleD_Wmma_CShuffleV3< Row, Col, Row_Tuple, Row, F16, F16, F16_Tuple, F16, F32, F32, PassThrough, PassThrough, AddFastGelu, GemmSpec, 256, 128, 128, 32, 8, 8, 16, 16, 4, 2, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, 1, 1, S<1, 32, 1, 8>, S<8, 8, 8>, Intrawave, V3>,
DeviceGemmMultipleD_Wmma_CShuffleV3< Row, Col, Row_Tuple, Row, F16, F16, F16_Tuple, F16, F32, F32, PassThrough, PassThrough, AddFastGelu, GemmSpec, 128, 128, 64, 64, 8, 8, 16, 16, 4, 2, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 0, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 0, 1, 1, S<1, 32, 1, 4>, S<8, 8, 8>, Intrawave, V3>,
DeviceGemmMultipleD_Wmma_CShuffleV3< Row, Col, Row_Tuple, Row, F16, F16, F16_Tuple, F16, F32, F32, PassThrough, PassThrough, AddFastGelu, GemmSpec, 256, 128, 160, 64, 8, 8, 16, 16, 2, 5, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, S<8, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, 1, 1, S<1, 64, 1, 4>, S<8, 8, 8>, Intrawave, V3>,
DeviceGemmMultipleD_Wmma_CShuffleV3< Row, Col, Row_Tuple, Row, F16, F16, F16_Tuple, F16, F32, F32, PassThrough, PassThrough, AddFastGelu, GemmSpec, 128, 128, 128, 32, 8, 8, 16, 16, 4, 4, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, 1, 1, S<1, 32, 1, 4>, S<8, 8, 8>, Intrawave, V3>,
DeviceGemmMultipleD_Wmma_CShuffleV3< Row, Col, Row_Tuple, Row, F16, F16, F16_Tuple, F16, F32, F32, PassThrough, PassThrough, AddFastGelu, GemmSpec, 64, 32, 64, 64, 8, 8, 16, 16, 2, 2, S<4, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 0, S<4, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 0, 1, 1, S<1, 16, 1, 4>, S<8, 8, 8>, Intrawave, V3>
// clang-format on
>;
void add_device_gemm_add_fastgelu_wmma_c_shuffle_f16_f16_f16_f16_mk_nk_mn_mn_instances(
std::vector<std::unique_ptr<DeviceGemmMultipleDSplitK<Row,
Col,
Row_Tuple,
Row,
F16,
F16,
F16_Tuple,
F16,
PassThrough,
PassThrough,
AddFastGelu>>>& instances)
{
add_device_operation_instances(
instances,
device_gemm_add_fastgelu_wmma_c_shuffle_f16_f16_f16_f16_mk_nk_mn_mn_instances<
GemmDefault>{});
add_device_operation_instances(
instances,
device_gemm_add_fastgelu_wmma_c_shuffle_f16_f16_f16_f16_mk_nk_mn_mn_instances<
GemmMNKPadding>{});
}
} // namespace instance
} // namespace device
} // namespace tensor_operation
} // namespace ck

View File

@@ -1,7 +1,11 @@
# ONLY XDL_KERNELS
# ONLY XDL_AND_WMMA_KERNELS
add_instance_library(device_gemm_add_multiply_instance
device_gemm_add_multiply_xdl_c_shuffle_f16_f16_f16_f16_f16_km_kn_mn_mn_mn_instance.cpp
device_gemm_add_multiply_xdl_c_shuffle_f16_f16_f16_f16_f16_km_nk_mn_mn_mn_instance.cpp
device_gemm_add_multiply_xdl_c_shuffle_f16_f16_f16_f16_f16_mk_kn_mn_mn_mn_instance.cpp
device_gemm_add_multiply_xdl_c_shuffle_f16_f16_f16_f16_f16_mk_nk_mn_mn_mn_instance.cpp
device_gemm_add_multiply_wmma_c_shuffle_f16_f16_f16_f16_f16_km_kn_mn_mn_mn_instance.cpp
device_gemm_add_multiply_wmma_c_shuffle_f16_f16_f16_f16_f16_km_nk_mn_mn_mn_instance.cpp
device_gemm_add_multiply_wmma_c_shuffle_f16_f16_f16_f16_f16_mk_kn_mn_mn_mn_instance.cpp
device_gemm_add_multiply_wmma_c_shuffle_f16_f16_f16_f16_f16_mk_nk_mn_mn_mn_instance.cpp
)

View File

@@ -0,0 +1,73 @@
// SPDX-License-Identifier: MIT
// Copyright (c) 2018-2025, Advanced Micro Devices, Inc. All rights reserved.
#include "ck/library/tensor_operation_instance/add_device_operation_instance.hpp"
#include "ck/library/tensor_operation_instance/device_operation_instance_factory.hpp"
#include "ck/tensor_operation/gpu/device/gemm_specialization.hpp"
#include "ck/tensor_operation/gpu/device/impl/device_gemm_multiple_d_wmma_cshuffle_v3.hpp"
#include "ck/utility/sequence.hpp"
namespace ck {
namespace tensor_operation {
namespace device {
namespace instance {
template <ck::index_t... Is>
using S = ck::Sequence<Is...>;
static constexpr auto GemmDefault = GemmSpecialization::Default;
static constexpr auto GemmMNKPadding = GemmSpecialization::MNKPadding;
static constexpr auto Intrawave = BlockGemmPipelineScheduler::Intrawave;
static constexpr auto Interwave = BlockGemmPipelineScheduler::Interwave;
static constexpr auto V1 = BlockGemmPipelineVersion::v1;
static constexpr auto V3 = BlockGemmPipelineVersion::v3;
template <GemmSpecialization GemmSpec>
using device_gemm_add_multiply_wmma_c_shuffle_f16_f16_f16_f16_f16_km_kn_mn_mn_mn_instances =
std::tuple<
// clang-format off
//##################################| ALayout| BLayout| DsLayout| ELayout| AData| BData| DsData| EData| AccData| CShuffle| A| B| CDE| GemmSpec| Block| MPer| NPer| KPer| AK1| BK1| MPer| NPer| MRepeat| NRepeat| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockLds| BBlockTransfer| BBlockTransfer| BBlockTransfer| BBlockTransfer| BBlockTransfer| BBlockTransfer| BBlockLds| CShuffle| CShuffle| CShuffleBlockTransfer| CDEShuffleBlockTransfer| BlkGemm| BlkGemm|
//##################################| | | | | Type| Type| Type| Type| Type| DataType| Elementwise| Elementwise| Elementwise| | Size| Block| Block| Block| | | Wmma| Wmma| | | ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| ExtraM| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| ExtraN| MRepeat| NRepeat| ClusterLengths| ScalarPerVectors| PipeSched| PipelineVer|
//##################################| | | | | | | | | | | Operation| Operation| Operation| | | | | | | | | | | | Lengths_AK0_M_AK1| ArrangeOrder| | | PerVector| PerVector_AK1| | Lengths_BK0_N_BK1| ArrangeOrder| | | PerVector| PerVector_BK1| | PerShuffle| PerShuffle| _MBlock_MPerBlock| | | |
//##################################| | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | _NBlock_NPerBlock| | | |
DeviceGemmMultipleD_Wmma_CShuffleV3< Col, Row, Row_Row_Tuple, Row, F16, F16, F16_F16_Tuple, F16, F32, F32, PassThrough, PassThrough, AddMultiply, GemmSpec, 256, 128, 128, 32, 8, 8, 16, 16, 4, 2, S<4, 64, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 1, 8, 1, S<4, 64, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 1, 8, 1, 1, 1, S<1, 32, 1, 8>, S<8, 8, 8>, Intrawave, V1>,
DeviceGemmMultipleD_Wmma_CShuffleV3< Col, Row, Row_Row_Tuple, Row, F16, F16, F16_F16_Tuple, F16, F32, F32, PassThrough, PassThrough, AddMultiply, GemmSpec, 128, 128, 128, 32, 8, 8, 16, 16, 4, 4, S<4, 32, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 1, 8, 1, S<4, 32, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 1, 8, 1, 1, 1, S<1, 32, 1, 4>, S<8, 8, 8>, Intrawave, V1>,
DeviceGemmMultipleD_Wmma_CShuffleV3< Col, Row, Row_Row_Tuple, Row, F16, F16, F16_F16_Tuple, F16, F32, F32, PassThrough, PassThrough, AddMultiply, GemmSpec, 64, 32, 64, 64, 8, 8, 16, 16, 2, 2, S<4, 16, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 1, 8, 0, S<4, 16, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 1, 8, 0, 1, 1, S<1, 16, 1, 4>, S<8, 8, 8>, Intrawave, V1>,
DeviceGemmMultipleD_Wmma_CShuffleV3< Col, Row, Row_Row_Tuple, Row, F16, F16, F16_F16_Tuple, F16, F32, F32, PassThrough, PassThrough, AddMultiply, GemmSpec, 256, 128, 128, 32, 8, 8, 16, 16, 4, 2, S<4, 64, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 1, 8, 1, S<4, 64, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 1, 8, 1, 1, 1, S<1, 32, 1, 8>, S<8, 8, 8>, Interwave, V1>,
DeviceGemmMultipleD_Wmma_CShuffleV3< Col, Row, Row_Row_Tuple, Row, F16, F16, F16_F16_Tuple, F16, F32, F32, PassThrough, PassThrough, AddMultiply, GemmSpec, 256, 128, 128, 64, 8, 8, 16, 16, 4, 2, S<4, 64, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 1, 8, 1, S<4, 64, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 1, 8, 1, 1, 1, S<1, 32, 1, 8>, S<8, 8, 8>, Interwave, V1>,
DeviceGemmMultipleD_Wmma_CShuffleV3< Col, Row, Row_Row_Tuple, Row, F16, F16, F16_F16_Tuple, F16, F32, F32, PassThrough, PassThrough, AddMultiply, GemmSpec, 64, 32, 64, 64, 8, 8, 16, 16, 2, 2, S<4, 16, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 1, 8, 0, S<4, 16, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 1, 8, 0, 1, 1, S<1, 16, 1, 4>, S<8, 8, 8>, Interwave, V1>,
DeviceGemmMultipleD_Wmma_CShuffleV3< Col, Row, Row_Row_Tuple, Row, F16, F16, F16_F16_Tuple, F16, F32, F32, PassThrough, PassThrough, AddMultiply, GemmSpec, 256, 128, 128, 32, 8, 8, 16, 16, 4, 2, S<4, 64, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 1, 8, 1, S<4, 64, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 1, 8, 1, 1, 1, S<1, 32, 1, 8>, S<8, 8, 8>, Intrawave, V3>,
DeviceGemmMultipleD_Wmma_CShuffleV3< Col, Row, Row_Row_Tuple, Row, F16, F16, F16_F16_Tuple, F16, F32, F32, PassThrough, PassThrough, AddMultiply, GemmSpec, 256, 128, 128, 64, 8, 8, 16, 16, 4, 2, S<4, 64, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 1, 8, 1, S<4, 64, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 1, 8, 1, 1, 1, S<1, 32, 1, 8>, S<8, 8, 8>, Intrawave, V3>,
DeviceGemmMultipleD_Wmma_CShuffleV3< Col, Row, Row_Row_Tuple, Row, F16, F16, F16_F16_Tuple, F16, F32, F32, PassThrough, PassThrough, AddMultiply, GemmSpec, 64, 32, 64, 64, 8, 8, 16, 16, 2, 2, S<4, 16, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 1, 8, 0, S<4, 16, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 1, 8, 0, 1, 1, S<1, 16, 1, 4>, S<8, 8, 8>, Intrawave, V3>
// clang-format on
>;
void add_device_gemm_add_multiply_wmma_c_shuffle_f16_f16_f16_f16_f16_km_kn_mn_mn_mn_instances(
std::vector<std::unique_ptr<DeviceGemmMultipleDSplitK<Col,
Row,
Row_Row_Tuple,
Row,
F16,
F16,
F16_F16_Tuple,
F16,
PassThrough,
PassThrough,
AddMultiply>>>& instances)
{
add_device_operation_instances(
instances,
device_gemm_add_multiply_wmma_c_shuffle_f16_f16_f16_f16_f16_km_kn_mn_mn_mn_instances<
GemmDefault>{});
add_device_operation_instances(
instances,
device_gemm_add_multiply_wmma_c_shuffle_f16_f16_f16_f16_f16_km_kn_mn_mn_mn_instances<
GemmMNKPadding>{});
}
} // namespace instance
} // namespace device
} // namespace tensor_operation
} // namespace ck

View File

@@ -0,0 +1,75 @@
// SPDX-License-Identifier: MIT
// Copyright (c) 2018-2025, Advanced Micro Devices, Inc. All rights reserved.
#include "ck/library/tensor_operation_instance/add_device_operation_instance.hpp"
#include "ck/library/tensor_operation_instance/device_operation_instance_factory.hpp"
#include "ck/tensor_operation/gpu/device/gemm_specialization.hpp"
#include "ck/tensor_operation/gpu/device/impl/device_gemm_multiple_d_wmma_cshuffle_v3.hpp"
#include "ck/utility/sequence.hpp"
namespace ck {
namespace tensor_operation {
namespace device {
namespace instance {
template <ck::index_t... Is>
using S = ck::Sequence<Is...>;
static constexpr auto GemmDefault = GemmSpecialization::Default;
static constexpr auto GemmMNKPadding = GemmSpecialization::MNKPadding;
static constexpr auto Intrawave = BlockGemmPipelineScheduler::Intrawave;
static constexpr auto Interwave = BlockGemmPipelineScheduler::Interwave;
static constexpr auto V1 = BlockGemmPipelineVersion::v1;
static constexpr auto V3 = BlockGemmPipelineVersion::v3;
template <GemmSpecialization GemmSpec>
using device_gemm_add_multiply_wmma_c_shuffle_f16_f16_f16_f16_f16_km_nk_mn_mn_mn_instances =
std::tuple<
// clang-format off
//##################################| ALayout| BLayout| DsLayout| ELayout| AData| BData| DsData| EData| AccData| CShuffle| A| B| CDE| GemmSpec| Block| MPer| NPer| KPer| AK1| BK1| MPer| NPer| MRepeat| NRepeat| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockLds| BBlockTransfer| BBlockTransfer| BBlockTransfer| BBlockTransfer| BBlockTransfer| BBlockTransfer| BBlockLds| CShuffle| CShuffle| CShuffleBlockTransfer| CDEShuffleBlockTransfer| BlkGemm| BlkGemm|
//##################################| | | | | Type| Type| Type| Type| Type| DataType| Elementwise| Elementwise| Elementwise| | Size| Block| Block| Block| | | Wmma| Wmma| | | ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| ExtraM| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| ExtraN| MRepeat| NRepeat| ClusterLengths| ScalarPerVectors| PipeSched| PipelineVer|
//##################################| | | | | | | | | | | Operation| Operation| Operation| | | | | | | | | | | | Lengths_AK0_M_AK1| ArrangeOrder| | | PerVector| PerVector_AK1| | Lengths_BK0_N_BK1| ArrangeOrder| | | PerVector| PerVector_BK1| | PerShuffle| PerShuffle| _MBlock_MPerBlock| | | |
//##################################| | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | _NBlock_NPerBlock| | | |
DeviceGemmMultipleD_Wmma_CShuffleV3< Col, Col, Row_Row_Tuple, Row, F16, F16, F16_F16_Tuple, F16, F32, F16, PassThrough, PassThrough, AddMultiply, GemmSpec, 256, 128, 128, 32, 8, 8, 16, 16, 4, 2, S<4, 64, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 1, 8, 1, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, 1, 1, S<1, 32, 1, 8>, S<8, 8, 8>, Intrawave, V1>,
DeviceGemmMultipleD_Wmma_CShuffleV3< Col, Col, Row_Row_Tuple, Row, F16, F16, F16_F16_Tuple, F16, F32, F16, PassThrough, PassThrough, AddMultiply, GemmSpec, 256, 128, 256, 64, 8, 8, 16, 16, 4, 4, S<4, 64, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 1, 8, 1, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, 1, 1, S<1, 32, 1, 8>, S<8, 8, 8>, Intrawave, V1>,
DeviceGemmMultipleD_Wmma_CShuffleV3< Col, Col, Row_Row_Tuple, Row, F16, F16, F16_F16_Tuple, F16, F32, F16, PassThrough, PassThrough, AddMultiply, GemmSpec, 256, 128, 160, 64, 8, 8, 16, 16, 2, 5, S<4, 64, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 1, 8, 1, S<8, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, 1, 1, S<1, 64, 1, 4>, S<8, 8, 8>, Intrawave, V1>,
DeviceGemmMultipleD_Wmma_CShuffleV3< Col, Col, Row_Row_Tuple, Row, F16, F16, F16_F16_Tuple, F16, F32, F16, PassThrough, PassThrough, AddMultiply, GemmSpec, 128, 64, 80, 64, 8, 8, 16, 16, 1, 5, S<4, 32, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 1, 8, 0, S<8, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 0, 1, 1, S<1, 64, 1, 2>, S<8, 8, 8>, Intrawave, V1>,
DeviceGemmMultipleD_Wmma_CShuffleV3< Col, Col, Row_Row_Tuple, Row, F16, F16, F16_F16_Tuple, F16, F32, F16, PassThrough, PassThrough, AddMultiply, GemmSpec, 256, 128, 128, 32, 8, 8, 16, 16, 4, 2, S<4, 64, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 1, 8, 1, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, 1, 1, S<1, 32, 1, 8>, S<8, 8, 8>, Interwave, V1>,
DeviceGemmMultipleD_Wmma_CShuffleV3< Col, Col, Row_Row_Tuple, Row, F16, F16, F16_F16_Tuple, F16, F32, F16, PassThrough, PassThrough, AddMultiply, GemmSpec, 256, 128, 128, 64, 8, 8, 16, 16, 4, 2, S<4, 64, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 1, 8, 1, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, 1, 1, S<1, 32, 1, 8>, S<8, 8, 8>, Interwave, V1>,
DeviceGemmMultipleD_Wmma_CShuffleV3< Col, Col, Row_Row_Tuple, Row, F16, F16, F16_F16_Tuple, F16, F32, F16, PassThrough, PassThrough, AddMultiply, GemmSpec, 128, 64, 80, 64, 8, 8, 16, 16, 1, 5, S<4, 32, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 1, 8, 1, S<8, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, 1, 1, S<1, 64, 1, 2>, S<8, 8, 8>, Interwave, V1>,
DeviceGemmMultipleD_Wmma_CShuffleV3< Col, Col, Row_Row_Tuple, Row, F16, F16, F16_F16_Tuple, F16, F32, F16, PassThrough, PassThrough, AddMultiply, GemmSpec, 64, 32, 64, 64, 8, 8, 16, 16, 2, 2, S<4, 16, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 1, 8, 0, S<4, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 0, 1, 1, S<1, 16, 1, 4>, S<8, 8, 8>, Interwave, V1>,
DeviceGemmMultipleD_Wmma_CShuffleV3< Col, Col, Row_Row_Tuple, Row, F16, F16, F16_F16_Tuple, F16, F32, F16, PassThrough, PassThrough, AddMultiply, GemmSpec, 256, 128, 128, 32, 8, 8, 16, 16, 4, 2, S<4, 64, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 1, 8, 1, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, 1, 1, S<1, 32, 1, 8>, S<8, 8, 8>, Intrawave, V3>,
DeviceGemmMultipleD_Wmma_CShuffleV3< Col, Col, Row_Row_Tuple, Row, F16, F16, F16_F16_Tuple, F16, F32, F16, PassThrough, PassThrough, AddMultiply, GemmSpec, 128, 128, 128, 32, 8, 8, 16, 16, 4, 4, S<4, 32, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 1, 8, 1, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, 1, 1, S<1, 32, 1, 4>, S<8, 8, 8>, Intrawave, V3>,
DeviceGemmMultipleD_Wmma_CShuffleV3< Col, Col, Row_Row_Tuple, Row, F16, F16, F16_F16_Tuple, F16, F32, F16, PassThrough, PassThrough, AddMultiply, GemmSpec, 64, 64, 32, 64, 8, 8, 16, 16, 4, 1, S<4, 16, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 1, 8, 0, S<4, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 0, 1, 1, S<1, 16, 1, 4>, S<8, 8, 8>, Intrawave, V3>
// clang-format on
>;
void add_device_gemm_add_multiply_wmma_c_shuffle_f16_f16_f16_f16_f16_km_nk_mn_mn_mn_instances(
std::vector<std::unique_ptr<DeviceGemmMultipleDSplitK<Col,
Col,
Row_Row_Tuple,
Row,
F16,
F16,
F16_F16_Tuple,
F16,
PassThrough,
PassThrough,
AddMultiply>>>& instances)
{
add_device_operation_instances(
instances,
device_gemm_add_multiply_wmma_c_shuffle_f16_f16_f16_f16_f16_km_nk_mn_mn_mn_instances<
GemmDefault>{});
add_device_operation_instances(
instances,
device_gemm_add_multiply_wmma_c_shuffle_f16_f16_f16_f16_f16_km_nk_mn_mn_mn_instances<
GemmMNKPadding>{});
}
} // namespace instance
} // namespace device
} // namespace tensor_operation
} // namespace ck

View File

@@ -0,0 +1,78 @@
// SPDX-License-Identifier: MIT
// Copyright (c) 2018-2025, Advanced Micro Devices, Inc. All rights reserved.
#include "ck/library/tensor_operation_instance/add_device_operation_instance.hpp"
#include "ck/library/tensor_operation_instance/device_operation_instance_factory.hpp"
#include "ck/tensor_operation/gpu/device/gemm_specialization.hpp"
#include "ck/tensor_operation/gpu/device/impl/device_gemm_multiple_d_wmma_cshuffle_v3.hpp"
#include "ck/utility/sequence.hpp"
namespace ck {
namespace tensor_operation {
namespace device {
namespace instance {
template <ck::index_t... Is>
using S = ck::Sequence<Is...>;
static constexpr auto GemmDefault = GemmSpecialization::Default;
static constexpr auto GemmMNKPadding = GemmSpecialization::MNKPadding;
static constexpr auto Intrawave = BlockGemmPipelineScheduler::Intrawave;
static constexpr auto Interwave = BlockGemmPipelineScheduler::Interwave;
static constexpr auto V1 = BlockGemmPipelineVersion::v1;
static constexpr auto V3 = BlockGemmPipelineVersion::v3;
template <GemmSpecialization GemmSpec>
using device_gemm_add_multiply_wmma_c_shuffle_f16_f16_f16_f16_f16_mk_kn_mn_mn_mn_instances =
std::tuple<
// clang-format off
//##################################| ALayout| BLayout| DsLayout| ELayout| AData| BData| DsData| EData| AccData| CShuffle| A| B| CDE| GemmSpec| Block| MPer| NPer| KPer| AK1| BK1| MPer| NPer| MRepeat| NRepeat| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockLds| BBlockTransfer| BBlockTransfer| BBlockTransfer| BBlockTransfer| BBlockTransfer| BBlockTransfer| BBlockLds| CShuffle| CShuffle| CShuffleBlockTransfer| CDEShuffleBlockTransfer| BlkGemm| BlkGemm|
//##################################| | | | | Type| Type| Type| Type| Type| DataType| Elementwise| Elementwise| Elementwise| | Size| Block| Block| Block| | | Wmma| Wmma| | | ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| ExtraM| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| ExtraN| MRepeat| NRepeat| ClusterLengths| ScalarPerVectors| PipeSched| PipelineVer|
//##################################| | | | | | | | | | | Operation| Operation| Operation| | | | | | | | | | | | Lengths_AK0_M_AK1| ArrangeOrder| | | PerVector| PerVector_AK1| | Lengths_BK0_N_BK1| ArrangeOrder| | | PerVector| PerVector_BK1| | PerShuffle| PerShuffle| _MBlock_MPerBlock| | | |
//##################################| | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | _NBlock_NPerBlock| | | |
DeviceGemmMultipleD_Wmma_CShuffleV3< Row, Row, Row_Row_Tuple, Row, F16, F16, F16_F16_Tuple, F16, F32, F16, PassThrough, PassThrough, AddMultiply, GemmSpec, 256, 128, 128, 32, 8, 8, 16, 16, 4, 2, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 0, S<4, 64, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 1, 8, 0, 1, 1, S<1, 32, 1, 8>, S<8, 8, 8>, Intrawave, V1>,
DeviceGemmMultipleD_Wmma_CShuffleV3< Row, Row, Row_Row_Tuple, Row, F16, F16, F16_F16_Tuple, F16, F32, F16, PassThrough, PassThrough, AddMultiply, GemmSpec, 128, 128, 64, 64, 8, 8, 16, 16, 4, 2, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 0, S<4, 32, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 1, 8, 0, 1, 1, S<1, 32, 1, 4>, S<8, 8, 8>, Intrawave, V1>,
DeviceGemmMultipleD_Wmma_CShuffleV3< Row, Row, Row_Row_Tuple, Row, F16, F16, F16_F16_Tuple, F16, F32, F16, PassThrough, PassThrough, AddMultiply, GemmSpec, 256, 128, 256, 64, 8, 8, 16, 16, 4, 4, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, S<4, 64, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 1, 8, 1, 1, 1, S<1, 32, 1, 8>, S<8, 8, 8>, Intrawave, V1>,
DeviceGemmMultipleD_Wmma_CShuffleV3< Row, Row, Row_Row_Tuple, Row, F16, F16, F16_F16_Tuple, F16, F32, F16, PassThrough, PassThrough, AddMultiply, GemmSpec, 256, 128, 128, 64, 8, 8, 16, 16, 4, 2, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, S<4, 64, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 1, 8, 1, 1, 1, S<1, 32, 1, 8>, S<8, 8, 8>, Intrawave, V1>,
DeviceGemmMultipleD_Wmma_CShuffleV3< Row, Row, Row_Row_Tuple, Row, F16, F16, F16_F16_Tuple, F16, F32, F16, PassThrough, PassThrough, AddMultiply, GemmSpec, 128, 64, 64, 32, 8, 8, 16, 16, 2, 2, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 0, S<4, 32, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 2, 4, 1, 1, 1, S<1, 32, 1, 2>, S<8, 8, 8>, Intrawave, V1>,
DeviceGemmMultipleD_Wmma_CShuffleV3< Row, Row, Row_Row_Tuple, Row, F16, F16, F16_F16_Tuple, F16, F32, F16, PassThrough, PassThrough, AddMultiply, GemmSpec, 256, 128, 128, 32, 8, 8, 16, 16, 4, 2, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, S<4, 64, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 1, 8, 1, 1, 1, S<1, 32, 1, 8>, S<8, 8, 8>, Interwave, V1>,
DeviceGemmMultipleD_Wmma_CShuffleV3< Row, Row, Row_Row_Tuple, Row, F16, F16, F16_F16_Tuple, F16, F32, F16, PassThrough, PassThrough, AddMultiply, GemmSpec, 256, 128, 128, 64, 8, 8, 16, 16, 4, 2, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, S<4, 64, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 1, 8, 1, 1, 1, S<1, 32, 1, 8>, S<8, 8, 8>, Interwave, V1>,
DeviceGemmMultipleD_Wmma_CShuffleV3< Row, Row, Row_Row_Tuple, Row, F16, F16, F16_F16_Tuple, F16, F32, F16, PassThrough, PassThrough, AddMultiply, GemmSpec, 128, 128, 128, 32, 8, 8, 16, 16, 4, 4, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, S<4, 32, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 1, 8, 1, 1, 1, S<1, 32, 1, 4>, S<8, 8, 8>, Interwave, V1>,
DeviceGemmMultipleD_Wmma_CShuffleV3< Row, Row, Row_Row_Tuple, Row, F16, F16, F16_F16_Tuple, F16, F32, F16, PassThrough, PassThrough, AddMultiply, GemmSpec, 64, 32, 64, 64, 8, 8, 16, 16, 2, 2, S<4, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 0, S<4, 16, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 1, 8, 0, 1, 1, S<1, 16, 1, 4>, S<8, 8, 8>, Interwave, V1>,
DeviceGemmMultipleD_Wmma_CShuffleV3< Row, Row, Row_Row_Tuple, Row, F16, F16, F16_F16_Tuple, F16, F32, F16, PassThrough, PassThrough, AddMultiply, GemmSpec, 128, 64, 64, 32, 8, 8, 16, 16, 2, 2, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 0, S<4, 32, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 2, 4, 0, 1, 1, S<1, 32, 1, 2>, S<8, 8, 8>, Interwave, V1>,
DeviceGemmMultipleD_Wmma_CShuffleV3< Row, Row, Row_Row_Tuple, Row, F16, F16, F16_F16_Tuple, F16, F32, F16, PassThrough, PassThrough, AddMultiply, GemmSpec, 256, 128, 128, 32, 8, 8, 16, 16, 4, 2, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, S<4, 64, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 1, 8, 1, 1, 1, S<1, 32, 1, 8>, S<8, 8, 8>, Intrawave, V3>,
DeviceGemmMultipleD_Wmma_CShuffleV3< Row, Row, Row_Row_Tuple, Row, F16, F16, F16_F16_Tuple, F16, F32, F16, PassThrough, PassThrough, AddMultiply, GemmSpec, 256, 128, 128, 64, 8, 8, 16, 16, 4, 2, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, S<4, 64, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 1, 8, 1, 1, 1, S<1, 32, 1, 8>, S<8, 8, 8>, Intrawave, V3>,
DeviceGemmMultipleD_Wmma_CShuffleV3< Row, Row, Row_Row_Tuple, Row, F16, F16, F16_F16_Tuple, F16, F32, F16, PassThrough, PassThrough, AddMultiply, GemmSpec, 128, 128, 128, 32, 8, 8, 16, 16, 4, 4, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, S<4, 32, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 1, 8, 1, 1, 1, S<1, 32, 1, 4>, S<8, 8, 8>, Intrawave, V3>,
DeviceGemmMultipleD_Wmma_CShuffleV3< Row, Row, Row_Row_Tuple, Row, F16, F16, F16_F16_Tuple, F16, F32, F16, PassThrough, PassThrough, AddMultiply, GemmSpec, 128, 64, 64, 32, 8, 8, 16, 16, 2, 2, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 0, S<4, 32, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 2, 4, 1, 1, 1, S<1, 32, 1, 2>, S<8, 8, 8>, Intrawave, V3>
// clang-format on
>;
void add_device_gemm_add_multiply_wmma_c_shuffle_f16_f16_f16_f16_f16_mk_kn_mn_mn_mn_instances(
std::vector<std::unique_ptr<DeviceGemmMultipleDSplitK<Row,
Row,
Row_Row_Tuple,
Row,
F16,
F16,
F16_F16_Tuple,
F16,
PassThrough,
PassThrough,
AddMultiply>>>& instances)
{
add_device_operation_instances(
instances,
device_gemm_add_multiply_wmma_c_shuffle_f16_f16_f16_f16_f16_mk_kn_mn_mn_mn_instances<
GemmDefault>{});
add_device_operation_instances(
instances,
device_gemm_add_multiply_wmma_c_shuffle_f16_f16_f16_f16_f16_mk_kn_mn_mn_mn_instances<
GemmMNKPadding>{});
}
} // namespace instance
} // namespace device
} // namespace tensor_operation
} // namespace ck

View File

@@ -0,0 +1,81 @@
// SPDX-License-Identifier: MIT
// Copyright (c) 2018-2025, Advanced Micro Devices, Inc. All rights reserved.
#include "ck/library/tensor_operation_instance/add_device_operation_instance.hpp"
#include "ck/library/tensor_operation_instance/device_operation_instance_factory.hpp"
#include "ck/tensor_operation/gpu/device/gemm_specialization.hpp"
#include "ck/tensor_operation/gpu/device/impl/device_gemm_multiple_d_wmma_cshuffle_v3.hpp"
#include "ck/utility/sequence.hpp"
namespace ck {
namespace tensor_operation {
namespace device {
namespace instance {
template <ck::index_t... Is>
using S = ck::Sequence<Is...>;
static constexpr auto GemmDefault = GemmSpecialization::Default;
static constexpr auto GemmMNKPadding = GemmSpecialization::MNKPadding;
static constexpr auto Intrawave = BlockGemmPipelineScheduler::Intrawave;
static constexpr auto Interwave = BlockGemmPipelineScheduler::Interwave;
static constexpr auto V1 = BlockGemmPipelineVersion::v1;
static constexpr auto V3 = BlockGemmPipelineVersion::v3;
template <GemmSpecialization GemmSpec>
using device_gemm_add_multiply_wmma_c_shuffle_f16_f16_f16_f16_f16_mk_nk_mn_mn_mn_instances =
std::tuple<
// clang-format off
//##################################| ALayout| BLayout| DsLayout| ELayout| AData| BData| DsData| EData| AccData| CShuffle| A| B| CDE| GemmSpec| Block| MPer| NPer| KPer| AK1| BK1| MPer| NPer| MRepeat| NRepeat| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockLds| BBlockTransfer| BBlockTransfer| BBlockTransfer| BBlockTransfer| BBlockTransfer| BBlockTransfer| BBlockLds| CShuffle| CShuffle| CShuffleBlockTransfer| CDEShuffleBlockTransfer| BlkGemm| BlkGemm|
//##################################| | | | | Type| Type| Type| Type| Type| DataType| Elementwise| Elementwise| Elementwise| | Size| Block| Block| Block| | | Wmma| Wmma| | | ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| ExtraM| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| ExtraN| MRepeat| NRepeat| ClusterLengths| ScalarPerVectors| PipeSched| PipelineVer|
//##################################| | | | | | | | | | | Operation| Operation| Operation| | | | | | | | | | | | Lengths_AK0_M_AK1| ArrangeOrder| | | PerVector| PerVector_AK1| | Lengths_BK0_N_BK1| ArrangeOrder| | | PerVector| PerVector_BK1| | PerShuffle| PerShuffle| _MBlock_MPerBlock| | | |
//##################################| | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | _NBlock_NPerBlock| | | |
DeviceGemmMultipleD_Wmma_CShuffleV3< Row, Col, Row_Row_Tuple, Row, F16, F16, F16_F16_Tuple, F16, F32, F16, PassThrough, PassThrough, AddMultiply, GemmSpec, 256, 128, 128, 32, 8, 8, 16, 16, 4, 2, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, 1, 1, S<1, 32, 1, 8>, S<8, 8, 8>, Intrawave, V1>,
DeviceGemmMultipleD_Wmma_CShuffleV3< Row, Col, Row_Row_Tuple, Row, F16, F16, F16_F16_Tuple, F16, F32, F16, PassThrough, PassThrough, AddMultiply, GemmSpec, 128, 128, 64, 64, 8, 8, 16, 16, 4, 2, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 0, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 0, 1, 1, S<1, 32, 1, 4>, S<8, 8, 8>, Intrawave, V1>,
DeviceGemmMultipleD_Wmma_CShuffleV3< Row, Col, Row_Row_Tuple, Row, F16, F16, F16_F16_Tuple, F16, F32, F16, PassThrough, PassThrough, AddMultiply, GemmSpec, 256, 128, 256, 64, 8, 8, 16, 16, 4, 4, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, 1, 1, S<1, 32, 1, 8>, S<8, 8, 8>, Intrawave, V1>,
DeviceGemmMultipleD_Wmma_CShuffleV3< Row, Col, Row_Row_Tuple, Row, F16, F16, F16_F16_Tuple, F16, F32, F16, PassThrough, PassThrough, AddMultiply, GemmSpec, 256, 128, 128, 64, 8, 8, 16, 16, 4, 2, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 0, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 0, 1, 1, S<1, 32, 1, 8>, S<8, 8, 8>, Intrawave, V1>,
DeviceGemmMultipleD_Wmma_CShuffleV3< Row, Col, Row_Row_Tuple, Row, F16, F16, F16_F16_Tuple, F16, F32, F16, PassThrough, PassThrough, AddMultiply, GemmSpec, 256, 128, 160, 64, 8, 8, 16, 16, 2, 5, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 0, S<8, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 0, 1, 1, S<1, 64, 1, 4>, S<8, 8, 8>, Intrawave, V1>,
DeviceGemmMultipleD_Wmma_CShuffleV3< Row, Col, Row_Row_Tuple, Row, F16, F16, F16_F16_Tuple, F16, F32, F16, PassThrough, PassThrough, AddMultiply, GemmSpec, 128, 64, 80, 64, 8, 8, 16, 16, 1, 5, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 0, S<8, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 0, 1, 1, S<1, 64, 1, 2>, S<8, 8, 8>, Intrawave, V1>,
DeviceGemmMultipleD_Wmma_CShuffleV3< Row, Col, Row_Row_Tuple, Row, F16, F16, F16_F16_Tuple, F16, F32, F16, PassThrough, PassThrough, AddMultiply, GemmSpec, 64, 32, 64, 64, 8, 8, 16, 16, 2, 2, S<4, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 0, S<4, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 0, 1, 1, S<1, 16, 1, 4>, S<8, 8, 8>, Intrawave, V1>,
DeviceGemmMultipleD_Wmma_CShuffleV3< Row, Col, Row_Row_Tuple, Row, F16, F16, F16_F16_Tuple, F16, F32, F16, PassThrough, PassThrough, AddMultiply, GemmSpec, 256, 128, 128, 32, 8, 8, 16, 16, 4, 2, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, 1, 1, S<1, 32, 1, 8>, S<8, 8, 8>, Interwave, V1>,
DeviceGemmMultipleD_Wmma_CShuffleV3< Row, Col, Row_Row_Tuple, Row, F16, F16, F16_F16_Tuple, F16, F32, F16, PassThrough, PassThrough, AddMultiply, GemmSpec, 128, 128, 64, 64, 8, 8, 16, 16, 4, 2, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 0, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 0, 1, 1, S<1, 32, 1, 4>, S<8, 8, 8>, Interwave, V1>,
DeviceGemmMultipleD_Wmma_CShuffleV3< Row, Col, Row_Row_Tuple, Row, F16, F16, F16_F16_Tuple, F16, F32, F16, PassThrough, PassThrough, AddMultiply, GemmSpec, 256, 128, 160, 64, 8, 8, 16, 16, 2, 5, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, S<8, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, 1, 1, S<1, 64, 1, 4>, S<8, 8, 8>, Interwave, V1>,
DeviceGemmMultipleD_Wmma_CShuffleV3< Row, Col, Row_Row_Tuple, Row, F16, F16, F16_F16_Tuple, F16, F32, F16, PassThrough, PassThrough, AddMultiply, GemmSpec, 128, 128, 128, 32, 8, 8, 16, 16, 4, 4, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, 1, 1, S<1, 32, 1, 4>, S<8, 8, 8>, Interwave, V1>,
DeviceGemmMultipleD_Wmma_CShuffleV3< Row, Col, Row_Row_Tuple, Row, F16, F16, F16_F16_Tuple, F16, F32, F16, PassThrough, PassThrough, AddMultiply, GemmSpec, 64, 32, 64, 64, 8, 8, 16, 16, 2, 2, S<4, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 0, S<4, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 0, 1, 1, S<1, 16, 1, 4>, S<8, 8, 8>, Interwave, V1>,
DeviceGemmMultipleD_Wmma_CShuffleV3< Row, Col, Row_Row_Tuple, Row, F16, F16, F16_F16_Tuple, F16, F32, F16, PassThrough, PassThrough, AddMultiply, GemmSpec, 256, 128, 128, 32, 8, 8, 16, 16, 4, 2, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, 1, 1, S<1, 32, 1, 8>, S<8, 8, 8>, Intrawave, V3>,
DeviceGemmMultipleD_Wmma_CShuffleV3< Row, Col, Row_Row_Tuple, Row, F16, F16, F16_F16_Tuple, F16, F32, F16, PassThrough, PassThrough, AddMultiply, GemmSpec, 128, 128, 64, 64, 8, 8, 16, 16, 4, 2, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 0, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 0, 1, 1, S<1, 32, 1, 4>, S<8, 8, 8>, Intrawave, V3>,
DeviceGemmMultipleD_Wmma_CShuffleV3< Row, Col, Row_Row_Tuple, Row, F16, F16, F16_F16_Tuple, F16, F32, F16, PassThrough, PassThrough, AddMultiply, GemmSpec, 256, 128, 160, 64, 8, 8, 16, 16, 2, 5, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, S<8, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, 1, 1, S<1, 64, 1, 4>, S<8, 8, 8>, Intrawave, V3>,
DeviceGemmMultipleD_Wmma_CShuffleV3< Row, Col, Row_Row_Tuple, Row, F16, F16, F16_F16_Tuple, F16, F32, F16, PassThrough, PassThrough, AddMultiply, GemmSpec, 128, 128, 128, 32, 8, 8, 16, 16, 4, 4, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, 1, 1, S<1, 32, 1, 4>, S<8, 8, 8>, Intrawave, V3>,
DeviceGemmMultipleD_Wmma_CShuffleV3< Row, Col, Row_Row_Tuple, Row, F16, F16, F16_F16_Tuple, F16, F32, F16, PassThrough, PassThrough, AddMultiply, GemmSpec, 64, 32, 64, 64, 8, 8, 16, 16, 2, 2, S<4, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 0, S<4, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 0, 1, 1, S<1, 16, 1, 4>, S<8, 8, 8>, Intrawave, V3>
// clang-format on
>;
void add_device_gemm_add_multiply_wmma_c_shuffle_f16_f16_f16_f16_f16_mk_nk_mn_mn_mn_instances(
std::vector<std::unique_ptr<DeviceGemmMultipleDSplitK<Row,
Col,
Row_Row_Tuple,
Row,
F16,
F16,
F16_F16_Tuple,
F16,
PassThrough,
PassThrough,
AddMultiply>>>& instances)
{
add_device_operation_instances(
instances,
device_gemm_add_multiply_wmma_c_shuffle_f16_f16_f16_f16_f16_mk_nk_mn_mn_mn_instances<
GemmDefault>{});
add_device_operation_instances(
instances,
device_gemm_add_multiply_wmma_c_shuffle_f16_f16_f16_f16_f16_mk_nk_mn_mn_mn_instances<
GemmMNKPadding>{});
}
} // namespace instance
} // namespace device
} // namespace tensor_operation
} // namespace ck

View File

@@ -1,5 +1,8 @@
# ONLY XDL_KERNELS
# ONLY XDL_AND_WMMA_KERNELS
add_instance_library(device_gemm_add_relu_instance
device_gemm_add_relu_xdl_c_shuffle_f16_i8_f16_f16_mk_kn_mn_mn_instance.cpp
device_gemm_add_relu_xdl_c_shuffle_bf16_i8_bf16_bf16_mk_kn_mn_mn_instance.cpp
device_gemm_add_relu_wmma_c_shuffle_bf16_bf16_bf16_bf16_mk_kn_mn_mn_instance.cpp
device_gemm_add_relu_wmma_c_shuffle_f16_f16_f16_f16_mk_kn_mn_mn_instance.cpp
)

View File

@@ -0,0 +1,71 @@
// SPDX-License-Identifier: MIT
// Copyright (c) 2025, Advanced Micro Devices, Inc. All rights reserved.
#include "ck/library/tensor_operation_instance/add_device_operation_instance.hpp"
#include "ck/library/tensor_operation_instance/device_operation_instance_factory.hpp"
#include "ck/tensor_operation/gpu/device/gemm_specialization.hpp"
#include "ck/tensor_operation/gpu/device/impl/device_gemm_multiple_d_wmma_cshuffle_v3.hpp"
#include "ck/utility/sequence.hpp"
namespace ck {
namespace tensor_operation {
namespace device {
namespace instance {
template <ck::index_t... Is>
using S = ck::Sequence<Is...>;
static constexpr auto GemmDefault = ck::tensor_operation::device::GemmSpecialization::Default;
static constexpr auto GemmMNKPadding = ck::tensor_operation::device::GemmSpecialization::MNKPadding;
static constexpr auto Interwave = BlockGemmPipelineScheduler::Interwave;
static constexpr auto Intrawave = BlockGemmPipelineScheduler::Intrawave;
static constexpr auto V1 = BlockGemmPipelineVersion::v1;
static constexpr auto V3 = BlockGemmPipelineVersion::v3;
template <GemmSpecialization GemmSpec>
// e = elementwise((a * b), d0, d1)
// outout: e[m, n]
// input: a[m, k], b[k, n], d0[m, n], d1[m, n]
using device_gemm_add_relu_wmma_c_shuffle_bf16_bf16_bf16_bf16_mk_kn_mn_mn_instances = std::tuple<
// clang-format off
//##################################| ALayout| BLayout| DsLayout| ELayout| AData| BData| DsData| EData| AccData| CShuffle| A| B| CDE| GemmSpec| Block| MPer| NPer| KPer| AK1| BK1| MPer| NPer| MRepeat| NRepeat| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockLds| BBlockTransfer| BBlockTransfer| BBlockTransfer| BBlockTransfer| BBlockTransfer| BBlockTransfer| BBlockLds| CShuffle| CShuffle| CShuffleBlockTransfer| CDEShuffleBlockTransfer| BlkGemm| BlkGemm|
//##################################| | | | | Type| Type| Type| Type| Type| DataType| Elementwise| Elementwise| Elementwise| | Size| Block| Block| Block| | | Wmma| Wmma| | | ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| ExtraM| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| ExtraN| MRepeat| NRepeat| ClusterLengths| ScalarPerVectors| PipeSched| PipelineVer|
//##################################| | | | | | | | | | | Operation| Operation| Operation| | | | | | | | | | | | Lengths_AK0_M_AK1| ArrangeOrder| | | PerVector| PerVector_AK1| | Lengths_BK0_N_BK1| ArrangeOrder| | | PerVector| PerVector_BK1| | PerShuffle| PerShuffle| _MBlock_MPerBlock| | | |
//##################################| | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | _NBlock_NPerBlock| | | |
DeviceGemmMultipleD_Wmma_CShuffleV3< Row, Row, Row_Tuple, Row, BF16, BF16, BF16_Tuple, BF16, F32, F32, PassThrough, PassThrough, AddRelu, GemmSpec, 256, 128, 128, 32, 8, 8, 16, 16, 4, 2, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 0, S<4, 64, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 1, 8, 0, 1, 1, S<1, 32, 1, 8>, S<8, 8, 8>, Interwave, V1>,
DeviceGemmMultipleD_Wmma_CShuffleV3< Row, Row, Row_Tuple, Row, BF16, BF16, BF16_Tuple, BF16, F32, F32, PassThrough, PassThrough, AddRelu, GemmSpec, 128, 128, 64, 64, 8, 8, 16, 16, 4, 2, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 0, S<4, 32, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 1, 8, 0, 1, 1, S<1, 32, 1, 4>, S<8, 8, 8>, Intrawave, V1>,
DeviceGemmMultipleD_Wmma_CShuffleV3< Row, Row, Row_Tuple, Row, BF16, BF16, BF16_Tuple, BF16, F32, F32, PassThrough, PassThrough, AddRelu, GemmSpec, 256, 128, 128, 32, 8, 8, 16, 16, 4, 2, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, S<4, 64, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 1, 8, 1, 1, 1, S<1, 32, 1, 8>, S<8, 8, 8>, Intrawave, V3>,
DeviceGemmMultipleD_Wmma_CShuffleV3< Row, Row, Row_Tuple, Row, BF16, BF16, BF16_Tuple, BF16, F32, F32, PassThrough, PassThrough, AddRelu, GemmSpec, 128, 64, 64, 32, 8, 8, 16, 16, 2, 2, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 0, S<4, 32, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 2, 4, 1, 1, 1, S<1, 32, 1, 2>, S<8, 8, 8>, Intrawave, V3>
// clang-format on
>;
void add_device_gemm_add_relu_wmma_c_shuffle_bf16_bf16_bf16_bf16_mk_kn_mn_mn_instances(
std::vector<std::unique_ptr<DeviceGemmMultipleDSplitK<Row,
Row,
Row_Tuple,
Row,
BF16,
BF16,
BF16_Tuple,
BF16,
PassThrough,
PassThrough,
AddRelu>>>& instances)
{
add_device_operation_instances(
instances,
device_gemm_add_relu_wmma_c_shuffle_bf16_bf16_bf16_bf16_mk_kn_mn_mn_instances<
GemmDefault>{});
add_device_operation_instances(
instances,
device_gemm_add_relu_wmma_c_shuffle_bf16_bf16_bf16_bf16_mk_kn_mn_mn_instances<
GemmMNKPadding>{});
}
} // namespace instance
} // namespace device
} // namespace tensor_operation
} // namespace ck

View File

@@ -0,0 +1,70 @@
// SPDX-License-Identifier: MIT
// Copyright (c) 2025, Advanced Micro Devices, Inc. All rights reserved.
#include "ck/library/tensor_operation_instance/add_device_operation_instance.hpp"
#include "ck/library/tensor_operation_instance/device_operation_instance_factory.hpp"
#include "ck/tensor_operation/gpu/device/gemm_specialization.hpp"
#include "ck/tensor_operation/gpu/device/impl/device_gemm_multiple_d_wmma_cshuffle_v3.hpp"
#include "ck/utility/sequence.hpp"
namespace ck {
namespace tensor_operation {
namespace device {
namespace instance {
template <ck::index_t... Is>
using S = ck::Sequence<Is...>;
static constexpr auto GemmDefault = ck::tensor_operation::device::GemmSpecialization::Default;
static constexpr auto GemmMNKPadding = ck::tensor_operation::device::GemmSpecialization::MNKPadding;
static constexpr auto Interwave = BlockGemmPipelineScheduler::Interwave;
static constexpr auto Intrawave = BlockGemmPipelineScheduler::Intrawave;
static constexpr auto V1 = BlockGemmPipelineVersion::v1;
static constexpr auto V3 = BlockGemmPipelineVersion::v3;
template <GemmSpecialization GemmSpec>
// e = elementwise((a * b), d0, d1)
// outout: e[m, n]
// input: a[m, k], b[k, n], d0[m, n], d1[m, n]
using device_gemm_add_relu_wmma_c_shuffle_f16_f16_f16_f16_mk_kn_mn_mn_instances = std::tuple<
// clang-format off
//##################################| ALayout| BLayout| DsLayout| ELayout| AData| BData| DsData| EData| AccData| CShuffle| A| B| CDE| GemmSpec| Block| MPer| NPer| KPer| AK1| BK1| MPer| NPer| MRepeat| NRepeat| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockLds| BBlockTransfer| BBlockTransfer| BBlockTransfer| BBlockTransfer| BBlockTransfer| BBlockTransfer| BBlockLds| CShuffle| CShuffle| CShuffleBlockTransfer| CDEShuffleBlockTransfer| BlkGemm| BlkGemm|
//##################################| | | | | Type| Type| Type| Type| Type| DataType| Elementwise| Elementwise| Elementwise| | Size| Block| Block| Block| | | Wmma| Wmma| | | ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| ExtraM| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| ExtraN| MRepeat| NRepeat| ClusterLengths| ScalarPerVectors| PipeSched| PipelineVer|
//##################################| | | | | | | | | | | Operation| Operation| Operation| | | | | | | | | | | | Lengths_AK0_M_AK1| ArrangeOrder| | | PerVector| PerVector_AK1| | Lengths_BK0_N_BK1| ArrangeOrder| | | PerVector| PerVector_BK1| | PerShuffle| PerShuffle| _MBlock_MPerBlock| | | |
//##################################| | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | _NBlock_NPerBlock| | | |
DeviceGemmMultipleD_Wmma_CShuffleV3< Row, Row, Row_Tuple, Row, F16, F16, F16_Tuple, F16, F32, F32, PassThrough, PassThrough, AddRelu, GemmSpec, 256, 128, 128, 32, 8, 8, 16, 16, 4, 2, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 0, S<4, 64, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 1, 8, 0, 1, 1, S<1, 32, 1, 8>, S<8, 8, 8>, Interwave, V1>,
DeviceGemmMultipleD_Wmma_CShuffleV3< Row, Row, Row_Tuple, Row, F16, F16, F16_Tuple, F16, F32, F32, PassThrough, PassThrough, AddRelu, GemmSpec, 128, 128, 64, 64, 8, 8, 16, 16, 4, 2, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 0, S<4, 32, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 1, 8, 0, 1, 1, S<1, 32, 1, 4>, S<8, 8, 8>, Intrawave, V1>,
DeviceGemmMultipleD_Wmma_CShuffleV3< Row, Row, Row_Tuple, Row, F16, F16, F16_Tuple, F16, F32, F32, PassThrough, PassThrough, AddRelu, GemmSpec, 256, 128, 128, 32, 8, 8, 16, 16, 4, 2, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, S<4, 64, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 1, 8, 1, 1, 1, S<1, 32, 1, 8>, S<8, 8, 8>, Intrawave, V3>,
DeviceGemmMultipleD_Wmma_CShuffleV3< Row, Row, Row_Tuple, Row, F16, F16, F16_Tuple, F16, F32, F32, PassThrough, PassThrough, AddRelu, GemmSpec, 128, 64, 64, 32, 8, 8, 16, 16, 2, 2, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 0, S<4, 32, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 2, 4, 1, 1, 1, S<1, 32, 1, 2>, S<8, 8, 8>, Intrawave, V3>
// clang-format on
>;
void add_device_gemm_add_relu_wmma_c_shuffle_f16_f16_f16_f16_mk_kn_mn_mn_instances(
std::vector<std::unique_ptr<DeviceGemmMultipleDSplitK<Row,
Row,
Row_Tuple,
Row,
F16,
F16,
F16_Tuple,
F16,
PassThrough,
PassThrough,
AddRelu>>>& instances)
{
add_device_operation_instances(
instances,
device_gemm_add_relu_wmma_c_shuffle_f16_f16_f16_f16_mk_kn_mn_mn_instances<GemmDefault>{});
add_device_operation_instances(
instances,
device_gemm_add_relu_wmma_c_shuffle_f16_f16_f16_f16_mk_kn_mn_mn_instances<
GemmMNKPadding>{});
}
} // namespace instance
} // namespace device
} // namespace tensor_operation
} // namespace ck

View File

@@ -1,5 +1,6 @@
# ONLY XDL_KERNELS
# ONLY XDL_AND_WMMA_KERNELS
add_instance_library(device_gemm_add_silu_instance
device_gemm_add_silu_wmma_c_shuffle_f16_f16_f16_f16_mk_kn_mn_mn_instance.cpp
device_gemm_add_silu_xdl_c_shuffle_f16_i8_f16_f16_mk_kn_mn_mn_instance.cpp
device_gemm_add_silu_xdl_c_shuffle_bf16_i8_bf16_bf16_mk_kn_mn_mn_instance.cpp
)

Some files were not shown because too many files have changed in this diff Show More