Merge commit 'a44bea45b205a84552e417a7b069d962d73c6cb1' into develop

This commit is contained in:
assistant-librarian[bot]
2025-09-26 17:11:27 +00:00
parent 77dcfaa687
commit dd38b01ac5
16 changed files with 1527 additions and 216 deletions

View File

@@ -1,10 +1,12 @@
add_executable(tile_example_grouped_gemm EXCLUDE_FROM_ALL grouped_gemm.cpp)
add_executable(tile_example_quant_grouped_gemm EXCLUDE_FROM_ALL quant_grouped_gemm.cpp)
add_executable(tile_example_grouped_gemm_preshuffle EXCLUDE_FROM_ALL grouped_gemm_preshuffle.cpp)
add_executable(tile_example_grouped_gemm_multi_d EXCLUDE_FROM_ALL grouped_gemm_multi_d.cpp)
set(EXAMPLE_GEMM_COMPILE_OPTIONS)
if(CK_USE_OCP_FP8)
list(APPEND EXAMPLE_GEMM_COMPILE_OPTIONS -DCK_TILE_USE_OCP_FP8)
endif()
target_compile_options(tile_example_grouped_gemm PRIVATE ${EXAMPLE_GEMM_COMPILE_OPTIONS})
target_compile_options(tile_example_grouped_gemm_preshuffle PRIVATE ${EXAMPLE_GEMM_COMPILE_OPTIONS})
target_compile_options(tile_example_quant_grouped_gemm PRIVATE ${EXAMPLE_GEMM_COMPILE_OPTIONS})
target_compile_options(tile_example_grouped_gemm_multi_d PRIVATE ${EXAMPLE_GEMM_COMPILE_OPTIONS})
target_compile_options(tile_example_quant_grouped_gemm PRIVATE ${EXAMPLE_GEMM_COMPILE_OPTIONS})

View File

@@ -1,140 +1,8 @@
# Grouped Gemm
Grouped General Matrix Multiplication (Grouped GEMM) is a technique used in GPU computing and high-performance computing to batch together multiple independent GEMM operations (matrix multiplications) into a single kernel launch in order to improve performance and efficiency. This folder contains Grouped GEMM examples that use the ck_tile tile-programming implementation.
## Quick Tour for New Users
The `Grouped GEMM` operators are versions of GEMM that run multiple GEMM operations within a single kernel call. Each GEMM operation performs a matrix multiplication. Unlike regular batched GEMM operations where both matrices must be of the same size and have the same configuration, Grouped GEMM operations can take matrices with different sizes and configurations, making them more flexible for diverse workloads.
Let's now break the example into the following parts: parsing arguments, preparing host and device buffers, preparing data, invoking GEMM, and building the example, while explaining each function.
### Key Arguments
The example takes several arguments including `group_count`, `repeat`, and `warmup`:
- `group_count`: the number of GEMM operations in the group
- `repeat`: the number of times to repeat the kernel for benchmarking
- `warmup`: the number of iterations before the actual kernel run time measure
```cpp
// Example
const int group_count = arg_parser.get_int("group_count");
const int repeat = arg_parser.get_int("repeat");
const int warmup = arg_parser.get_int("warmup");
```
In the next step, the input parameters `Ms`, `Ns`, `Ks`, as well as the corresponding `stride_As`, `stride_Bs`, and `stride_Cs` are either provided from the comand line or generated by default. Since one or more input data sets are expected for `A` and `B`, each parameter is stored in a `std::vector`. The size of the `vector` is defined by `group_count`.
```cpp
// Example
std::vector<ck_tile::index_t> Ms = arg_parser.get_int_vec("Ms");
std::vector<ck_tile::index_t> Ns = arg_parser.get_int_vec("Ns");
std::vector<ck_tile::index_t> Ks = arg_parser.get_int_vec("Ks");
std::vector<ck_tile::index_t> stride_As = arg_parser.get_int_vec("stride_As");
std::vector<ck_tile::index_t> stride_Bs = arg_parser.get_int_vec("stride_Bs");
std::vector<ck_tile::index_t> stride_Cs = arg_parser.get_int_vec("stride_Cs");
```
Where:
- `Ms` is the M dimension of each GEMM.
- `Ns` is the N dimension of each GEMM.
- `Ks` is the K dimension of each GEMM.
- `stride_As` is the stride values for matrix A.
- `stride_Bs` is the stride values for matrix B.
- `stride_Cs` is the stride values for matrix C.
### HostTensor and Device Memory Buffers (for CPU and GPU)
Each parameter `Ms`, `Ns`, `Ks`, `stride_As`, `stride_Bs` and `stride_Cs` contains values for more than one matrix, meaning different matrix sizes and strides can be used for different grouped GEMM computations.
The next step is to properly load the input values. For each input matrix, `A` and `B`, and for each output matrix, `C`, you need to create both `HostTensor` and `DeviceMemory`, where:
- `HostTensor` represents the matrix data on the host (CPU). It stores the data before they are transferred to the device for computation.
- `DeviceMemory` represents the matrix data on the device (GPU). This will store the data on the GPU for computation during the Grouped GEMM operation.
#### HostTensor Buffers (for CPU)
In the first step, create `HostTensor` for `A`, `B`, `C`. `HostTensor` allocates memory on the host (CPU) to store the matrices, initializing the memory with the appropriate dimensions and values to store the data. Below is an example code showing how to create HostTensors for those tensors:
```cpp
// Example
std::vector<ck_tile::HostTensor<ADataType>> a_m_k_tensors;
std::vector<ck_tile::HostTensor<BDataType>> b_k_n_tensors;
std::vector<ck_tile::HostTensor<CDataType>> c_m_n_tensors;
```
Where:
- `a_m_k_tensors` is the vector of `HostTensor` objects for matrix `A` (with dimensions `M × K`). Each tensor stores the data for single GEMM operation.
- `b_k_n_tensors` is the vector of `HostTensor` objects for matrix `B` (with dimensions `K × N`).
- `c_m_n_tensors` is the vector of `HostTensor` objects for matrix `C` (the output matrix with dimensions `M × N`).
The `std::vector` container is used for this purpose throughout. As mentioned above, the number of HostTensors is equal to `group_count`.
#### Device Memory Buffers (for GPU)
Now it's time to allocate memory on the device (GPU) and transfer the data from `HostTensor` to `DeviceMemory` for actual computation..
```cpp
// Example
std::vector<std::unique_ptr<ck_tile::DeviceMem>> a_m_k_dev_buf;
std::vector<std::unique_ptr<ck_tile::DeviceMem>> b_k_n_dev_buf;
std::vector<std::unique_ptr<ck_tile::DeviceMem>> c_m_n_dev_buf;
```
Where:
- `a_m_k_dev_buf` is the buffer used for storing matrix A on the GPU.
- `b_k_n_dev_buf` is the buffer used for storing matrix B on the GPU.
- `c_m_n_dev_buf` is the buffer used for storing the result matrix C on the GPU.
## Prepare data
In the next step, the input tensors are populated. A pseudorandom number generator, an existing distribution (e.g., `FillUniformDistribution`), or user data can be used to populate the tensors. Descriptors also need to be create for each input tensor.
Use `get_default_stride` to get the strides for A, B, and C. `get_default_stride` is a template function that calculates the default stride for a 2D array based on whether it is row-major or column-major. Template parameter determines whether the storage order is row-major (true) or column-major (false). The function takes four params `row`, `col`, `stride` and `bool_constant<is_row_major>`. If the stride is explicitly provided (`stride != 0`), the stride is returned as-is. If the stride is not provided (`stride == 0`), the function computes the default stride. For the Row-major order (`is_row_major == true`), the stride is set to the number of columns (col). For the column-major order (`is_row_major == false`), the stride is set to the number of rows (row). This function is useful when working with dynamically allocated 2D arrays, where the user may not specify the stride explicitly. It ensures a natural default stride based on the chosen storage order.
```cpp
// Example, API
template <bool is_row_major>
auto get_default_stride(std::size_t row, std::size_t col, std::size_t stride, bool_constant<is_row_major>) {
// code
}
```
Where:
- `is_row_major` is a bool template parameter that determines whether the storage order is row-major (true) or column-major (false).
- `row` is the number of rows in the matrix.
- `col` is the number of columns in the matrix.
- `stride` is the current stride (the distance between consecutive elements in memory).
- `bool_constant<is_row_major>` is a tag type that helps in differentiating behavior at compile-time.
Next host descriptors for each of the input tensors, A, B, and C are created. Use the `f_host_tensor_descriptor` function defined below. This function takes four parameters, row, col, stride, and layout, and returns a HostTensorDescriptor based on the specified layout.
```cpp
// Example for tensor A
ck_tile::HostTensor<ADataType>(f_host_tensor_descriptor(M, K, stride_As[i], a_layout)))
```
After creating the host_tensors, create `deviceMem` for each tensor `A`, `B`, and `C`, and then transfer the data to the device. The `get_element_space_size_in_bytes()` function is used to get the buffer size in bytes. Use `ToDevice()` to transfer data from the host to the device. The data that was previously generated (`a_m_k_tensors[i].data()`) is passed as a parameter to `ToDevice()`.
The final step before running the GEMM operation is to retrieve the pointers to the buffers of `A`, `B`, and `C` stored on the device using `->GetDeviceBuffer()` and pack them into a shared container. For example: `gemm_descs.push_back({p_a, p_b, p_c, M, N, K, stride_As[i], stride_Bs[i], stride_Cs[i]})`, where `gemm_descs` is `std::vector<grouped_gemm_kargs> gemm_descs` ([Code](https://github.com/ROCm/composable_kernel/blob/develop/example/ck_tile/17_grouped_gemm/run_grouped_gemm_example.inc#L221)). The container should include values such as:
```cpp
struct GroupedGemmHostArgs
{
const void* a_ptr;
const void* b_ptr;
void* c_ptr;
index_t M;
index_t N;
index_t K;
index_t stride_A;
index_t stride_B;
index_t stride_C;
};
```
The data prepared in this way can be passed to the `invoke_gemm` function. This is a templated function that also takes three template parameters: `ALayout`, `BLayout`, and `CLayout`:
```cpp
// Example, API
template <typename ALayout, typename BLayout, typename CLayout, bool Persistent>
float invoke_gemm(int n_warmup,
int n_repeat,
int group_count,
const std::vector<grouped_gemm_kargs>& args)
```
`invoke_gemm` returns the run time in milliseconds. The workspace memory required for computation is allocated. Workspace memory on the GPU refers to temporary memory buffers allocated when some operations are run. This extra space is needed to hold GEMM descriptions. The following structure can be used to allocate workspace:
```cpp
// Example
ck_tile::DeviceMem gemm_workspace;
gemm_workspace.Realloc(GetWorkspaceSize(args));
```
### Advanced Features: Preshuffle and Persistence
### Preshuffle and Persistence
The grouped GEMM examples include two advanced optimization features:
@@ -153,17 +21,17 @@ Persistence mode is a GPU optimization where thread blocks remain active on the
- **Usage**: `invoke_gemm<ALayout, BLayout, CLayout, true>` enables persistence
- **Benefits**: Reduced kernel launch overhead, better resource utilization for small matrix sizes
Both features can be combined with different data types (fp16, fp8) and layout configurations to optimize performance for specific workloads.
#### Multi-D Operations
Multi-D operations extend the standard GEMM operation by supporting additional element-wise operations on the result tensor. This feature is particularly useful for workloads that require post-processing of the GEMM output.
Finally the arguments are passed to group_gemm and the kernel is launched.
```cpp
// API
template <typename ALayout, typename BLayout, typename CLayout>
float grouped_gemm(const std::vector<grouped_gemm_kargs>& gemm_descs,
const ck_tile::stream_config& s,
void* kargs_ptr)
```
All the necessary parameters are set, the tiling is computed, the GEMM pipeline and epilogue are prepared, and the GroupedGemmKernel is launched.
- **Implementation**: Available in `grouped_gemm_multi_d.cpp`
- **Operation**: E = C × D₀ × D₁ (where C = A × B is the standard GEMM result)
- **Configuration**: Uses `GemmConfigV3`, `GemmConfigV4`, `GemmConfigMemory` template configuration with 2 D tensors
- **Data Types**: Supports fp16
- **Benefits**: Enables complex operations like scaling, activation functions, or other element-wise transformations in a single kernel call
- **Build Target**: `make tile_example_grouped_gemm_multi_d -j`
Both features can be combined with different data types (fp16, fp8) and layout configurations to optimize performance for specific workloads.
## Build
```
@@ -175,10 +43,13 @@ mkdir build && cd build
make tile_example_grouped_gemm -j
# The preshuffle example
make tile_example_grouped_gemm_preshuffle -j
# The multi-D operations example
make tile_example_grouped_gemm_multi_d -j
# The quant grouped gemm fp8 example
make tile_example_quant_grouped_gemm -j
```
This will result in an executable `build/bin/tile_example_grouped_gemm`
This will result in an executable `build/bin/tile_example_grouped_gemm`, `build/bin/tile_example_grouped_gemm_preshuffle`, `build/bin/tile_example_grouped_gemm_multi_d`, and `build/bin/tile_example_quant_grouped_gemm`.
## example
```
@@ -213,4 +84,4 @@ K[i] = 512 + 384 * i
stride_A[i] = K[i]
stride_B[i] = K[i]
stride_C[i] = N[i]
```
```

View File

@@ -9,7 +9,6 @@
#include "ck_tile/core.hpp"
#include "ck_tile/host/kernel_launch.hpp"
#include "ck_tile/ops/gemm.hpp"
#include "ck_tile/ops/elementwise/unary_element_wise_operation.hpp"
#include "ck_tile/utility/json_dump.hpp"
#define CK_TILE_PIPELINE_COMPUTE_V3 1
@@ -296,7 +295,7 @@ struct PipelineTypeTraits<CK_TILE_PIPELINE_PRESHUFFLE_V2>
ck_tile::BaseWeightPreshufflePipelineAGmemBGmemCRegV2<PipelineProblem>;
};
using grouped_gemm_kargs = ck_tile::GroupedGemmHostArgs;
using grouped_gemm_kargs = ck_tile::GroupedGemmHostArgs<>;
std::pair<bool, ck_tile::ArgParser> create_args(int argc, char* argv[])
{
@@ -325,7 +324,7 @@ std::pair<bool, ck_tile::ArgParser> create_args(int argc, char* argv[])
inline std::size_t get_workspace_size(const std::vector<grouped_gemm_kargs>& gemm_descs)
{
return gemm_descs.size() * sizeof(ck_tile::GemmTransKernelArg);
return gemm_descs.size() * sizeof(ck_tile::GemmTransKernelArg<>);
}
template <typename GemmConfig, typename T>

View File

@@ -0,0 +1,180 @@
// SPDX-License-Identifier: MIT
// Copyright (c) 2025, Advanced Micro Devices, Inc. All rights reserved.
#include <hip/hip_runtime.h>
#include <cstring>
#include <iostream>
#include <ostream>
#include <string>
#include <tuple>
#include <memory>
#include "ck_tile/core.hpp"
#include "ck_tile/ops/epilogue.hpp"
#include "ck_tile/ops/gemm.hpp"
#include "ck_tile/host.hpp"
#include "grouped_gemm_multi_d.hpp"
template <typename GemmConfig,
typename ADataType,
typename BDataType,
typename DsDataType,
typename AccDataType,
typename EDataType,
typename ALayout,
typename BLayout,
typename DsLayout,
typename ELayout,
typename CDEElementWise>
float grouped_gemm_multi_d(const std::vector<grouped_gemm_multi_d_kargs>& gemm_descs,
const ck_tile::stream_config& s,
void* kargs_ptr)
{
using GemmShape = ck_tile::TileGemmShape<
ck_tile::sequence<GemmConfig::M_Tile, GemmConfig::N_Tile, GemmConfig::K_Tile>,
ck_tile::sequence<GemmConfig::M_Warp, GemmConfig::N_Warp, GemmConfig::K_Warp>,
ck_tile::
sequence<GemmConfig::M_Warp_Tile, GemmConfig::N_Warp_Tile, GemmConfig::K_Warp_Tile>>;
using TilePartitioner =
ck_tile::GemmSpatiallyLocalTilePartitioner<GemmShape,
GemmConfig::TileParitionerGroupNum,
GemmConfig::TileParitionerM01>;
using Traits = ck_tile::TileGemmTraits<GemmConfig::kPadM,
GemmConfig::kPadN,
GemmConfig::kPadK,
ALayout,
BLayout,
ELayout>;
using GemmUniversalTraits = ck_tile::TileGemmUniversalTraits<GemmConfig::kPadM,
GemmConfig::kPadN,
GemmConfig::kPadK,
GemmConfig::DoubleSmemBuffer,
ALayout,
BLayout,
ELayout,
GemmConfig::TransposeC>;
using GemmPipelineProblem =
ck_tile::GemmPipelineProblem<ADataType, BDataType, AccDataType, GemmShape, Traits>;
using BaseGemmPipeline = typename PipelineTypeTraits<
GemmConfig::Pipeline>::template UniversalGemmPipeline<GemmPipelineProblem>;
const ck_tile::index_t k_grain = gemm_descs[0].k_batch * GemmConfig::K_Tile;
const ck_tile::index_t K_split = (gemm_descs[0].K + k_grain - 1) / k_grain * GemmConfig::K_Tile;
const ck_tile::index_t num_loop = TilePartitioner::GetLoopNum(K_split);
const bool has_hot_loop = BaseGemmPipeline::BlockHasHotloop(num_loop);
const ck_tile::TailNumber tail_num = BaseGemmPipeline::GetBlockLoopTailNum(num_loop);
float ave_time{0};
const auto Run = [&](const auto has_hot_loop_,
const auto tail_number_,
const auto memory_operation_) {
constexpr bool has_hot_loop_v = has_hot_loop_.value;
constexpr auto tail_number_v = tail_number_.value;
constexpr auto scheduler = GemmConfig::Scheduler;
constexpr auto memory_operation = memory_operation_.value;
using UniversalGemmProblem = ck_tile::UniversalGemmPipelineProblem<ADataType,
BDataType,
AccDataType,
GemmShape,
GemmUniversalTraits,
scheduler,
has_hot_loop_v,
tail_number_v>;
using GemmPipeline = typename PipelineTypeTraits<
GemmConfig::Pipeline>::template GemmPipeline<UniversalGemmProblem>;
using GemmEpilogue = ck_tile::CShuffleEpilogue<
ck_tile::CShuffleEpilogueProblem<ADataType,
BDataType,
DsDataType,
AccDataType,
EDataType,
DsLayout,
ELayout,
CDEElementWise,
TilePartitioner::MPerBlock,
TilePartitioner::NPerBlock,
GemmConfig::M_Warp,
GemmConfig::N_Warp,
GemmConfig::M_Warp_Tile,
GemmConfig::N_Warp_Tile,
GemmConfig::K_Warp_Tile,
UniversalGemmProblem::TransposeC,
memory_operation>>;
using Kernel = ck_tile::GroupedGemmKernel<TilePartitioner, GemmPipeline, GemmEpilogue>;
auto kargs = Kernel::MakeKargs(gemm_descs);
if(!Kernel::IsSupportedArgument(kargs))
{
throw std::runtime_error("Kernel arguments not supported!");
}
const dim3 blocks = Kernel::BlockSize();
const dim3 grids = Kernel::GridSize(gemm_descs);
HIP_CHECK_ERROR(hipMemcpyWithStream(kargs_ptr,
kargs.data(),
get_workspace_size(gemm_descs),
hipMemcpyHostToDevice,
s.stream_id_));
if(s.log_level_ > 0)
{
std::cout << "Launching kernel: " << Kernel::GetName() << " with args:" << " grid: { "
<< grids.x << ", " << grids.y << ", " << grids.z << "}" << ", blocks: {"
<< blocks.x << ", " << blocks.y << ", " << blocks.z << "}" << std::endl;
}
ave_time =
ck_tile::launch_kernel(s,
ck_tile::make_kernel<GemmConfig::kBlockPerCu>(
Kernel{},
grids,
blocks,
0,
ck_tile::cast_pointer_to_constant_address_space(kargs_ptr),
gemm_descs.size()));
return ave_time;
};
const auto RunSplitk = [&](const auto has_hot_loop_, const auto tail_number_) {
if(gemm_descs[0].k_batch == 1)
{
Run(has_hot_loop_,
tail_number_,
ck_tile::integral_constant<ck_tile::memory_operation_enum,
ck_tile::memory_operation_enum::set>{});
}
else
{
Run(has_hot_loop_,
tail_number_,
ck_tile::integral_constant<ck_tile::memory_operation_enum,
ck_tile::memory_operation_enum::atomic_add>{});
}
};
BaseGemmPipeline::TailHandler(RunSplitk, has_hot_loop, tail_num);
return ave_time;
}
#include "run_grouped_gemm_multi_d_example.inc"
int main(int argc, char* argv[])
{
#if CK_TILE_USE_WMMA
return !run_grouped_gemm_multi_d_example<GemmConfigV3_Wmma>(argc, argv);
#else
return !run_grouped_gemm_multi_d_example<GemmConfigV3>(argc, argv) ||
!run_grouped_gemm_multi_d_example<GemmConfigMemory>(argc, argv) ||
!run_grouped_gemm_multi_d_example<GemmConfigV4>(argc, argv);
#endif
}

View File

@@ -0,0 +1,220 @@
// SPDX-License-Identifier: MIT
// Copyright (c) 2024-2025, Advanced Micro Devices, Inc. All rights reserved.
#pragma once
#include <string>
#include <tuple>
#include "ck_tile/core.hpp"
#include "ck_tile/host/kernel_launch.hpp"
#include "ck_tile/ops/gemm.hpp"
#include "ck_tile/utility/json_dump.hpp"
#define CK_TILE_PIPELINE_COMPUTE_V3 1
#define CK_TILE_PIPELINE_MEMORY 2
#define CK_TILE_PIPELINE_COMPUTE_V4 3
using ADataType = ck_tile::half_t;
using BDataType = ck_tile::half_t;
using D0DataType = ck_tile::half_t;
using D1DataType = ck_tile::half_t;
using EDataType = ck_tile::half_t;
using DsDataType = ck_tile::tuple<D0DataType, D1DataType>;
using AccDataType = float;
template <typename PrecType, ck_tile::index_t M_Warp_Tile>
constexpr ck_tile::index_t get_k_warp_tile()
{
#if defined(CK_GFX950_SUPPORT)
constexpr bool is_8bit_float =
std::is_same_v<PrecType, ck_tile::fp8_t> || std::is_same_v<PrecType, ck_tile::bf8_t>;
if constexpr(M_Warp_Tile == 32)
return is_8bit_float ? 64 : 16;
else
return is_8bit_float ? 128 : 32;
#else
if constexpr(M_Warp_Tile == 32)
return 16;
else
return 32;
#endif
}
struct GemmConfigBase
{
static constexpr bool kPadM = false;
static constexpr bool kPadN = false;
static constexpr bool kPadK = false;
static constexpr bool TransposeC = false;
static constexpr int kBlockPerCu = 1;
static constexpr ck_tile::index_t TileParitionerGroupNum = 8;
static constexpr ck_tile::index_t TileParitionerM01 = 4;
static constexpr auto Scheduler = ck_tile::GemmPipelineScheduler::Intrawave;
static constexpr ck_tile::index_t Pipeline = CK_TILE_PIPELINE_COMPUTE_V3;
static constexpr bool Preshuffle = false; // currently preshuffle == true is not supported yet
static constexpr bool Persistent = false; // currently persistent == true is not supported yet
static constexpr bool DoubleSmemBuffer =
false; // currently double smem buffer == true is not supported yet
};
struct GemmConfigMemory : public GemmConfigBase
{
// Memory friendly for Interwave scheduler
static constexpr ck_tile::index_t M_Tile = 128;
static constexpr ck_tile::index_t N_Tile = 32;
static constexpr ck_tile::index_t K_Tile = 64;
static constexpr ck_tile::index_t M_Warp = 4;
static constexpr ck_tile::index_t N_Warp = 1;
static constexpr ck_tile::index_t K_Warp = 1;
static constexpr ck_tile::index_t M_Warp_Tile = 32;
static constexpr ck_tile::index_t N_Warp_Tile = 32;
static constexpr ck_tile::index_t K_Warp_Tile = 8;
static constexpr bool DoubleSmemBuffer = false;
static constexpr ck_tile::index_t Pipeline = CK_TILE_PIPELINE_MEMORY;
static constexpr auto Scheduler = ck_tile::GemmPipelineScheduler::Interwave;
};
struct GemmConfigV3 : public GemmConfigBase
{
// Compute friendly for Intrawave scheduler
static constexpr ck_tile::index_t M_Tile = 256;
static constexpr ck_tile::index_t N_Tile = 256;
static constexpr ck_tile::index_t K_Tile = 64;
static constexpr ck_tile::index_t M_Warp = 2;
static constexpr ck_tile::index_t N_Warp = 2;
static constexpr ck_tile::index_t K_Warp = 1;
static constexpr ck_tile::index_t M_Warp_Tile = 32;
static constexpr ck_tile::index_t N_Warp_Tile = 32;
static constexpr ck_tile::index_t K_Warp_Tile = 16;
static constexpr bool DoubleSmemBuffer = false;
static constexpr ck_tile::index_t Pipeline = CK_TILE_PIPELINE_COMPUTE_V3;
static constexpr auto Scheduler = ck_tile::GemmPipelineScheduler::Intrawave;
};
struct GemmConfigV4 : public GemmConfigBase
{
// Compute friendly for Intrawave scheduler
// Using the ping pong reader in the lds level
static constexpr ck_tile::index_t M_Tile = 256;
static constexpr ck_tile::index_t N_Tile = 256;
static constexpr ck_tile::index_t K_Tile = 32;
static constexpr ck_tile::index_t M_Warp = 2;
static constexpr ck_tile::index_t N_Warp = 2;
static constexpr ck_tile::index_t K_Warp = 1;
static constexpr ck_tile::index_t M_Warp_Tile = 32;
static constexpr ck_tile::index_t N_Warp_Tile = 32;
static constexpr ck_tile::index_t K_Warp_Tile = 16;
static constexpr bool DoubleSmemBuffer = true;
static constexpr ck_tile::index_t Pipeline = CK_TILE_PIPELINE_COMPUTE_V4;
static constexpr auto Scheduler = ck_tile::GemmPipelineScheduler::Intrawave;
};
struct GemmConfigV3_Wmma : public GemmConfigBase
{
// Compute friendly for Intrawave scheduler
static constexpr ck_tile::index_t M_Tile = 128;
static constexpr ck_tile::index_t N_Tile = 128;
static constexpr ck_tile::index_t K_Tile = 64;
static constexpr ck_tile::index_t M_Warp = 2;
static constexpr ck_tile::index_t N_Warp = 2;
static constexpr ck_tile::index_t K_Warp = 1;
static constexpr ck_tile::index_t M_Warp_Tile = 16;
static constexpr ck_tile::index_t N_Warp_Tile = 16;
static constexpr ck_tile::index_t K_Warp_Tile = 16;
static constexpr bool DoubleSmemBuffer = false;
static constexpr ck_tile::index_t Pipeline = CK_TILE_PIPELINE_COMPUTE_V3;
static constexpr auto Scheduler = ck_tile::GemmPipelineScheduler::Intrawave;
};
template <ck_tile::index_t PipelineId>
struct PipelineTypeTraits;
template <>
struct PipelineTypeTraits<CK_TILE_PIPELINE_MEMORY>
{
template <typename PipelineProblem>
using GemmPipeline = ck_tile::GemmPipelineAgBgCrMem<PipelineProblem>;
template <typename PipelineProblem>
using UniversalGemmPipeline = ck_tile::BaseGemmPipelineAgBgCrMem<PipelineProblem>;
};
template <>
struct PipelineTypeTraits<CK_TILE_PIPELINE_COMPUTE_V3>
{
template <typename PipelineProblem>
using GemmPipeline = ck_tile::GemmPipelineAgBgCrCompV3<PipelineProblem>;
template <typename PipelineProblem>
using UniversalGemmPipeline = ck_tile::BaseGemmPipelineAgBgCrCompV3<PipelineProblem>;
};
template <>
struct PipelineTypeTraits<CK_TILE_PIPELINE_COMPUTE_V4>
{
template <typename PipelineProblem>
using GemmPipeline = ck_tile::GemmPipelineAgBgCrCompV4<PipelineProblem>;
template <typename PipelineProblem>
using UniversalGemmPipeline = ck_tile::BaseGemmPipelineAgBgCrCompV4<PipelineProblem>;
};
using grouped_gemm_multi_d_kargs = ck_tile::GroupedGemmHostArgs<2>;
std::pair<bool, ck_tile::ArgParser> create_args(int argc, char* argv[])
{
ck_tile::ArgParser arg_parser;
arg_parser.insert("Ms", "", "M dimensions - empty by default.")
.insert("Ns", "", "N dimensions - empty by default.")
.insert("Ks", "", "K dimensions - empty by default.")
.insert("stride_As", "", "Tensor A strides - it is empty by default.")
.insert("stride_Bs", "", "Tensor B strides - it is empty by default.")
.insert("stride_Ds", "", "Tensor Ds strides - it is empty by default.")
.insert("stride_Es", "", "Tensor E strides - it is empty by default.")
.insert("a_layout", "R", "A tensor data layout - Row by default.")
.insert("b_layout", "C", "B tensor data layout - Row by default.")
.insert("ds_layout", "R", "Ds tensor data layout - Row by default.")
.insert("e_layout", "R", "E tensor data layout - Row by default.")
.insert("validate", "1", "0. No validation, 1. Validation on CPU.")
.insert("prec", "fp16", "data type. fp16")
.insert("warmup", "10", "number of iterations before benchmark the kernel.")
.insert("repeat", "100", "number of iterations to benchmark the kernel.")
.insert("group_count", "8", "group count.")
.insert("kbatch", "1", "kbatch for SplitK")
.insert("json", "0", "0: No Json, 1: Dump Results in Json format")
.insert("jsonfile", "grouped_gemm.json", "json file name to dump results");
bool result = arg_parser.parse(argc, argv);
return std::make_pair(result, arg_parser);
}
inline std::size_t get_workspace_size(const std::vector<grouped_gemm_multi_d_kargs>& gemm_descs)
{
return gemm_descs.size() * sizeof(ck_tile::GemmTransKernelArg<2>);
}
template <typename GemmConfig,
typename ADataType,
typename BDataType,
typename DsDataType,
typename AccDataType,
typename EDataType,
typename ALayout,
typename BLayout,
typename DsLayout,
typename ELayout,
typename CDEElementWise>
float grouped_gemm_multi_d(const std::vector<grouped_gemm_multi_d_kargs>& gemm_descs,
const ck_tile::stream_config& s,
void* kargs_ptr);

View File

@@ -88,7 +88,7 @@ float invoke_gemm(int n_warmup,
// The contents of the memory pointed to by `kargs_ptr` pointer could be
// written by e.g. another kernel from earlier stage.
std::vector<ck_tile::GemmTransKernelArg> kargs;
std::vector<ck_tile::GemmTransKernelArg<>> kargs;
void* kargs_ptr = gemm_workspace.GetDeviceBuffer();
const bool splitk = args[0].k_batch > 1;
for(const auto& arg : args)
@@ -109,7 +109,7 @@ float invoke_gemm(int n_warmup,
const auto stream = ck_tile::stream_config{nullptr, true, 1, n_warmup, n_repeat};
HIP_CHECK_ERROR(hipMemcpyWithStream(kargs_ptr,
kargs.data(),
kargs.size() * sizeof(ck_tile::GemmTransKernelArg),
kargs.size() * sizeof(ck_tile::GemmTransKernelArg<>),
hipMemcpyHostToDevice,
stream.stream_id_));
ave_time = grouped_gemm_tileloop<GemmConfig,
@@ -260,8 +260,18 @@ int run_grouped_gemm_example_with_layouts(int argc,
const void* p_b = b_k_n_dev_buf[i]->GetDeviceBuffer();
void* p_c = c_m_n_dev_buf[i]->GetDeviceBuffer();
gemm_descs.push_back(
{p_a, p_b, p_c, kbatch, M, N, K, stride_As[i], stride_Bs[i], stride_Cs[i]});
gemm_descs.push_back({p_a,
p_b,
{/*ds_ptr*/},
p_c,
kbatch,
M,
N,
K,
stride_As[i],
stride_Bs[i],
{/*stride_Ds*/},
stride_Cs[i]});
}
float ave_time = invoke_gemm<GemmConfig,

View File

@@ -0,0 +1,389 @@
#pragma once
struct MultiplyMultiply
{
template <typename E, typename C, typename D0, typename D1>
CK_TILE_HOST_DEVICE auto operator()(E& e, const C& c, const D0& d0, const D1& d1) const -> void
{
const float x0_f = ck_tile::type_convert<float>(c) * ck_tile::type_convert<float>(d0) *
ck_tile::type_convert<float>(d1);
e = ck_tile::type_convert<E>(x0_f);
}
};
template <typename Layout>
static constexpr inline auto is_row_major(Layout layout_)
{
return ck_tile::bool_constant<std::is_same_v<ck_tile::remove_cvref_t<decltype(layout_)>,
ck_tile::tensor_layout::gemm::RowMajor>>{};
}
auto calculate_rtol_atol(const ck_tile::index_t K,
const ck_tile::index_t kbatch,
const float max_accumulated_value)
{
using ComputeTypeAB =
std::conditional_t<sizeof(ADataType) < sizeof(BDataType), ADataType, BDataType>;
using ComputeType =
std::conditional_t<sizeof(ComputeTypeAB) < sizeof(D0DataType), ComputeTypeAB, D0DataType>;
// Calculate thresholds
const auto rtol = ck_tile::get_relative_threshold<ComputeType, EDataType, AccDataType>(
ck_tile::integer_divide_ceil(K, kbatch));
const auto atol = ck_tile::get_absolute_threshold<ComputeType, EDataType, AccDataType>(
max_accumulated_value / kbatch, ck_tile::integer_divide_ceil(K, kbatch));
// Calculate error due to split_k accumulation
const auto rtol_split_k =
ck_tile::get_relative_threshold<EDataType, EDataType, EDataType>(kbatch);
const auto atol_split_k = ck_tile::get_absolute_threshold<EDataType, EDataType, EDataType>(
max_accumulated_value, kbatch);
// Use higher threshold
return ck_tile::make_tuple(std::max(rtol, rtol_split_k), std::max(atol, atol_split_k));
}
template <typename GemmConfig,
typename ADataType,
typename BDataType,
typename DsDataType,
typename AccDataType,
typename EDataType,
typename ALayout,
typename BLayout,
typename DsLayout,
typename ELayout,
typename CDEElementWise>
float invoke_gemm(int n_warmup,
int n_repeat,
int group_count,
const std::vector<grouped_gemm_multi_d_kargs>& args)
{
// Workspace memory allocated to hold the gemm descriptions.
ck_tile::DeviceMem gemm_workspace;
gemm_workspace.Realloc(get_workspace_size(args));
float ave_time = 0;
if constexpr(!GemmConfig::Persistent)
{
ave_time = grouped_gemm_multi_d<GemmConfig,
ADataType,
BDataType,
DsDataType,
AccDataType,
EDataType,
ALayout,
BLayout,
DsLayout,
ELayout,
CDEElementWise>(
args,
ck_tile::stream_config{nullptr, true, 1, n_warmup, n_repeat},
gemm_workspace.GetDeviceBuffer());
}
else
{
(void)group_count;
// not supported yet
throw std::runtime_error("Persistent grouped gemm multiple-d is not supported yet");
}
return ave_time;
}
template <typename GemmConfig,
typename ALayout,
typename BLayout,
typename D0Layout,
typename D1Layout,
typename ELayout>
int run_grouped_gemm_multi_d_example_with_layouts(int argc,
char* argv[],
const ALayout a_layout = ALayout{},
const BLayout b_layout = BLayout{},
const D0Layout d0_layout = D0Layout{},
const D1Layout d1_layout = D1Layout{},
const ELayout e_layout = ELayout{})
{
auto [result, arg_parser] = create_args(argc, argv);
using CDElementWise = MultiplyMultiply;
using DsLayout = ck_tile::tuple<D0Layout, D1Layout>;
auto valid_input_data = [&](int group_count, const auto&... args) {
return !(args.empty() || ...) && group_count == (args.size() == ...);
};
const int group_count = arg_parser.get_int("group_count");
const int repeat = arg_parser.get_int("repeat");
const int warmup = arg_parser.get_int("warmup");
const int kbatch = arg_parser.get_int("kbatch");
bool validate = arg_parser.get_bool("validate");
if(kbatch > 1 && validate && warmup + repeat > 1)
{
std::cout << "WARNING: Data validation enabled with SplitK and more than"
<< "1 warmup/repeat. Disabling validation." << std::endl;
validate = false;
}
std::vector<ck_tile::index_t> Ms = arg_parser.get_int_vec("Ms");
std::vector<ck_tile::index_t> Ns = arg_parser.get_int_vec("Ns");
std::vector<ck_tile::index_t> Ks = arg_parser.get_int_vec("Ks");
std::vector<ck_tile::index_t> stride_As = arg_parser.get_int_vec("stride_As");
std::vector<ck_tile::index_t> stride_Bs = arg_parser.get_int_vec("stride_Bs");
std::vector<ck_tile::index_t> stride_D0 = arg_parser.get_int_vec("stride_Ds");
std::vector<ck_tile::index_t> stride_D1 = arg_parser.get_int_vec("stride_Ds");
std::vector<ck_tile::index_t> stride_Es = arg_parser.get_int_vec("stride_Es");
if(!valid_input_data(
group_count, Ms, Ns, Ks, stride_As, stride_Bs, stride_D0, stride_D1, stride_Es))
{
std::cout << "Please check the input data. Default values will be used." << std::endl;
std::cout << "Default values: Ms (256, 512, 768, 1024..), Ns (256, 768, 1280..), Ks (512, "
"896, 1280..), stride_As (Ks), stride_Bs (Ks), stride_D0 (Ns), stride_D1 "
"(Ns), stride_Es (Ns)"
<< std::endl;
for(int i = 0; i < group_count; i++)
{
Ms.push_back(256 /* + 256 * i */);
Ns.push_back(256 /* + 512 * i */);
Ks.push_back(64 /* + 384 * i */);
stride_As.push_back(Ks[i]);
stride_Bs.push_back(Ks[i]);
stride_D0.push_back(Ns[i]);
stride_D1.push_back(Ns[i]);
stride_Es.push_back(Ns[i]);
}
}
std::vector<ck_tile::HostTensor<ADataType>> a_m_k_tensors;
std::vector<ck_tile::HostTensor<BDataType>> b_k_n_tensors;
std::vector<ck_tile::HostTensor<D0DataType>> d0_m_n_tensors;
std::vector<ck_tile::HostTensor<D1DataType>> d1_m_n_tensors;
std::vector<ck_tile::HostTensor<EDataType>> e_m_n_tensors;
a_m_k_tensors.reserve(group_count);
b_k_n_tensors.reserve(group_count);
d0_m_n_tensors.reserve(group_count);
d1_m_n_tensors.reserve(group_count);
e_m_n_tensors.reserve(group_count);
std::vector<std::unique_ptr<ck_tile::DeviceMem>> a_m_k_dev_buf;
std::vector<std::unique_ptr<ck_tile::DeviceMem>> b_k_n_dev_buf;
std::vector<std::unique_ptr<ck_tile::DeviceMem>> d0_m_n_dev_buf;
std::vector<std::unique_ptr<ck_tile::DeviceMem>> d1_m_n_dev_buf;
std::vector<std::unique_ptr<ck_tile::DeviceMem>> e_m_n_dev_buf;
a_m_k_dev_buf.reserve(group_count);
b_k_n_dev_buf.reserve(group_count);
d0_m_n_dev_buf.reserve(group_count);
d1_m_n_dev_buf.reserve(group_count);
e_m_n_dev_buf.reserve(group_count);
std::vector<grouped_gemm_multi_d_kargs> gemm_descs;
gemm_descs.reserve(group_count);
for(int i = 0; i < group_count; ++i)
{
const ck_tile::index_t M = Ms[i];
const ck_tile::index_t N = Ns[i];
const ck_tile::index_t K = Ks[i];
stride_As[i] = ck_tile::get_default_stride(M, K, stride_As[i], is_row_major(a_layout));
stride_Bs[i] = ck_tile::get_default_stride(K, N, stride_Bs[i], is_row_major(b_layout));
stride_D0[i] = ck_tile::get_default_stride(M, N, stride_D0[i], is_row_major(d0_layout));
stride_D1[i] = ck_tile::get_default_stride(M, N, stride_D1[i], is_row_major(d1_layout));
stride_Es[i] = ck_tile::get_default_stride(M, N, stride_Es[i], is_row_major(e_layout));
a_m_k_tensors.push_back(ck_tile::HostTensor<ADataType>(
ck_tile::host_tensor_descriptor(M, K, stride_As[i], is_row_major(a_layout))));
b_k_n_tensors.push_back(ck_tile::HostTensor<BDataType>(
ck_tile::host_tensor_descriptor(K, N, stride_Bs[i], is_row_major(b_layout))));
d0_m_n_tensors.push_back(ck_tile::HostTensor<D0DataType>(
ck_tile::host_tensor_descriptor(M, N, stride_D0[i], is_row_major(d0_layout))));
d1_m_n_tensors.push_back(ck_tile::HostTensor<D1DataType>(
ck_tile::host_tensor_descriptor(M, N, stride_D1[i], is_row_major(d1_layout))));
e_m_n_tensors.push_back(ck_tile::HostTensor<EDataType>(
ck_tile::host_tensor_descriptor(M, N, stride_Es[i], is_row_major(e_layout))));
std::cout << "gemm[" << i << "]" << " a_m_k: " << a_m_k_tensors[i].mDesc
<< " b_k_n: " << b_k_n_tensors[i].mDesc << " d0_m_n: " << d0_m_n_tensors[i].mDesc
<< " d1_m_n: " << d1_m_n_tensors[i].mDesc << " e_m_n: " << e_m_n_tensors[i].mDesc
<< std::endl;
ck_tile::FillUniformDistribution<ADataType>{-1.f, 1.f}(a_m_k_tensors[i]);
ck_tile::FillUniformDistribution<BDataType>{-1.f, 1.f}(b_k_n_tensors[i]);
ck_tile::FillUniformDistribution<D0DataType>{2.f, -2.f}(d0_m_n_tensors[i]);
ck_tile::FillUniformDistribution<D1DataType>{2.f, -2.f}(d1_m_n_tensors[i]);
a_m_k_dev_buf.push_back(std::make_unique<ck_tile::DeviceMem>(a_m_k_tensors[i]));
b_k_n_dev_buf.push_back(std::make_unique<ck_tile::DeviceMem>(b_k_n_tensors[i]));
d0_m_n_dev_buf.push_back(std::make_unique<ck_tile::DeviceMem>(d0_m_n_tensors[i]));
d1_m_n_dev_buf.push_back(std::make_unique<ck_tile::DeviceMem>(d1_m_n_tensors[i]));
e_m_n_dev_buf.push_back(std::make_unique<ck_tile::DeviceMem>(e_m_n_tensors[i]));
e_m_n_dev_buf[i]->SetZero();
const void* p_a = a_m_k_dev_buf[i]->GetDeviceBuffer();
const void* p_b = b_k_n_dev_buf[i]->GetDeviceBuffer();
void* p_e = e_m_n_dev_buf[i]->GetDeviceBuffer();
std::array<const void*, DsDataType::size()> ds_ptr_buf = {
d0_m_n_dev_buf[i]->GetDeviceBuffer(), d1_m_n_dev_buf[i]->GetDeviceBuffer()};
std::array<ck_tile::index_t, DsDataType::size()> stridesDs = {stride_D0[i], stride_D1[i]};
gemm_descs.push_back({p_a,
p_b,
ds_ptr_buf,
p_e,
kbatch,
M,
N,
K,
stride_As[i],
stride_Bs[i],
stridesDs,
stride_Es[i]});
}
float ave_time = invoke_gemm<GemmConfig,
ADataType,
BDataType,
DsDataType,
AccDataType,
EDataType,
ALayout,
BLayout,
DsLayout,
ELayout,
CDElementWise>(warmup, repeat, group_count, gemm_descs);
std::string op_name{"Grouped Gemm Multiple-D"};
std::size_t flop = 0, num_btype = 0;
for(int j = 0; j < group_count; ++j)
{
flop += std::size_t(2) * gemm_descs[j].M * gemm_descs[j].N * gemm_descs[j].K;
ck_tile::static_for<0, DsDataType::size(), 1>{}([&](auto i) {
num_btype += sizeof(ck_tile::remove_cvref_t<std::tuple_element_t<i, DsDataType>>) *
gemm_descs[j].M * gemm_descs[j].N;
flop += sizeof(ck_tile::remove_cvref_t<std::tuple_element_t<i, DsDataType>>) *
gemm_descs[j].M * gemm_descs[j].N;
});
num_btype += sizeof(ADataType) * gemm_descs[j].M * gemm_descs[j].K +
sizeof(BDataType) * gemm_descs[j].K * gemm_descs[j].N +
sizeof(EDataType) * gemm_descs[j].M * gemm_descs[j].N;
}
float tflops = static_cast<float>(flop) / 1.E9 / ave_time;
float gb_per_sec = num_btype / 1.E6 / ave_time;
std::cout << "Perf: " << std::setw(10) << ave_time << " ms, " << tflops << " TFlops, "
<< gb_per_sec << " GB/s, " << op_name << std::endl;
std::vector<ck_tile::HostTensor<EDataType>> e_m_n_host_refs;
e_m_n_host_refs.reserve(group_count);
// copy e_m_n_tensors result from device to host and initialize host tensors to zero
for(int i = 0; i < group_count; i++)
{
e_m_n_dev_buf[i]->FromDevice(e_m_n_tensors[i].data());
}
bool pass{true};
if(validate)
{
for(int i = 0; i < group_count; ++i)
{
e_m_n_host_refs.push_back(ck_tile::HostTensor<EDataType>(
host_tensor_descriptor(Ms[i], Ns[i], stride_Es[i], is_row_major(e_layout))));
e_m_n_host_refs[i].SetZero();
ck_tile::reference_gemm_multiple_d<ADataType,
BDataType,
DsDataType,
AccDataType,
EDataType,
CDElementWise>(
a_m_k_tensors[i],
b_k_n_tensors[i],
{d0_m_n_tensors[i], d1_m_n_tensors[i]},
e_m_n_host_refs[i]);
std::cout << "e_m_n_host_refs[i]: " << std::endl;
e_m_n_host_refs[i].print_first_n(std::cout, 10);
std::cout << std::endl;
std::cout << "e_m_n_tensors[i]: " << std::endl;
e_m_n_tensors[i].print_first_n(std::cout, 10);
std::cout << std::endl;
const float max_accumulated_value =
*std::max_element(e_m_n_host_refs[i].mData.begin(), e_m_n_host_refs[i].mData.end());
const auto rtol_atol = calculate_rtol_atol(Ks[i], 1, max_accumulated_value);
pass &=
ck_tile::check_err(e_m_n_tensors[i],
e_m_n_host_refs[i],
"Error: Incorrect results! in group [" + std::to_string(i) + "]",
rtol_atol.at(ck_tile::number<0>{}),
rtol_atol.at(ck_tile::number<1>{}));
std::cout << "Relative error threshold: " << rtol_atol.at(ck_tile::number<0>{})
<< " Absolute error threshold: " << rtol_atol.at(ck_tile::number<1>{})
<< std::endl;
}
std::cout << "The CPU verification result is: " << (pass ? "correct" : "fail") << std::endl;
}
if(arg_parser.get_int("json") == 1)
{
dump_grouped_gemm_json_results<ALayout, BLayout, ELayout>(arg_parser.get_str("jsonfile"),
op_name,
group_count,
pass,
ave_time,
tflops,
gb_per_sec);
}
return pass;
}
template <typename GemmConfig>
int run_grouped_gemm_multi_d_example(int argc, char* argv[])
{
auto [result, arg_parser] = create_args(argc, argv);
if(!result)
{
return -1;
}
const std::string a_layout = arg_parser.get_str("a_layout");
const std::string b_layout = arg_parser.get_str("b_layout");
const std::string ds_layout = arg_parser.get_str("ds_layout");
using Row = ck_tile::tensor_layout::gemm::RowMajor;
using Col = ck_tile::tensor_layout::gemm::ColumnMajor;
if(a_layout == "R" && b_layout == "C" && ds_layout == "R")
{
return run_grouped_gemm_multi_d_example_with_layouts<GemmConfig>(
argc, argv, Row{}, Col{}, Row{}, Row{}, Row{});
}
else
{
throw std::runtime_error("Unsupported data layout configuration for provided tensors!");
}
}

View File

@@ -23,10 +23,13 @@ namespace ck_tile {
/// arguments object. It contain all necessary information required to build proper kernel
/// argument and launch kernel on GPU. This structure defines the GEMM problem configuration by
/// stating all required information like M,N,K sizes and respective strides.
template <index_t NumDTensor = 0>
struct GroupedGemmHostArgs
{
CK_TILE_HOST GroupedGemmHostArgs(const void* a_ptr_,
const void* b_ptr_,
const std::array<const void*, NumDTensor>& ds_ptr_,
void* e_ptr_,
index_t k_batch_,
index_t M_,
@@ -34,15 +37,18 @@ struct GroupedGemmHostArgs
index_t K_,
index_t stride_A_,
index_t stride_B_,
const std::array<index_t, NumDTensor>& stride_Ds_,
index_t stride_E_)
: a_ptr(a_ptr_),
b_ptr(b_ptr_),
ds_ptr(ds_ptr_),
e_ptr(e_ptr_),
M(M_),
N(N_),
K(K_),
stride_A(stride_A_),
stride_B(stride_B_),
stride_Ds(stride_Ds_),
stride_E(stride_E_),
k_batch(k_batch_)
{
@@ -50,6 +56,7 @@ struct GroupedGemmHostArgs
const void* a_ptr;
const void* b_ptr;
const std::array<const void*, NumDTensor> ds_ptr;
union
{
void* e_ptr;
@@ -61,7 +68,7 @@ struct GroupedGemmHostArgs
index_t K;
index_t stride_A;
index_t stride_B;
const std::array<index_t, NumDTensor> stride_Ds;
union
{
index_t stride_E;
@@ -71,20 +78,23 @@ struct GroupedGemmHostArgs
index_t k_batch;
};
template <index_t NumDTensor = 0>
struct GemmTransKernelArg
{
UniversalGemmKernelArgs<> group_karg;
UniversalGemmKernelArgs<1, 1, NumDTensor> group_karg;
ck_tile::index_t block_start;
ck_tile::index_t block_end;
GemmTransKernelArg() = delete;
GemmTransKernelArg(UniversalGemmKernelArgs<>&& karg, index_t bl_start, index_t bl_end)
: group_karg{karg}, block_start{bl_start}, block_end{bl_end}
GemmTransKernelArg(UniversalGemmKernelArgs<1, 1, NumDTensor>&& karg,
index_t bl_start,
index_t bl_end)
: group_karg{std::move(karg)}, block_start{bl_start}, block_end{bl_end}
{
}
GemmTransKernelArg(UniversalGemmKernelArgs<>&& karg)
: group_karg{karg}, block_start{0}, block_end{0}
GemmTransKernelArg(UniversalGemmKernelArgs<1, 1, NumDTensor>&& karg)
: group_karg{std::move(karg)}, block_start{0}, block_end{0}
{
}
};
@@ -106,9 +116,12 @@ struct GroupedGemmKernel
using CLayout = remove_cvref_t<typename GemmPipeline::CLayout>;
/// @brief Specify the data type configurations for A, B, C/E
using ADataType = remove_cvref_t<typename GemmPipeline::ADataType>;
using BDataType = remove_cvref_t<typename GemmPipeline::BDataType>;
using CDataType = remove_cvref_t<typename EpiloguePipeline::ODataType>;
using ADataType = remove_cvref_t<typename GemmPipeline::ADataType>;
using BDataType = remove_cvref_t<typename GemmPipeline::BDataType>;
using CDataType = remove_cvref_t<typename EpiloguePipeline::ODataType>;
using DsDataType = remove_cvref_t<typename EpiloguePipeline::DsDataType>;
static constexpr index_t NumDTensor_ = DsDataType::size();
/// @brief ALayout and ADataType are expected to be scalars, not a tuple.
static_assert(
@@ -140,19 +153,21 @@ struct GroupedGemmKernel
concat('x', P_::MPerBlock, P_::NPerBlock, P_::KPerBlock),
concat('x', P_::GetVectorSizeA(), P_::GetVectorSizeB(), P_::GetVectorSizeC()),
concat('x', P_::kPadM, P_::kPadN, P_::kPadK),
(UsePersistentKernel ? "Persistent" : "NonPersistent"));
(UsePersistentKernel ? "Persistent" : "NonPersistent"),
(NumDTensor_ == 2 ? "MultiD" : "NoMultiD"),
(GemmPipeline::DoubleSmemBuffer ? "DoubleSmemBuffer" : "SingleSmemBuffer"));
// clang-format on
}
CK_TILE_HOST static auto
GetWorkSpaceSize(const std::vector<GroupedGemmHostArgs>& gemm_descs) -> std::size_t
GetWorkSpaceSize(const std::vector<GroupedGemmHostArgs<>>& gemm_descs) -> std::size_t
{
return gemm_descs.size() * sizeof(GemmTransKernelArg);
return gemm_descs.size() * sizeof(GemmTransKernelArg<NumDTensor_>);
}
CK_TILE_HOST static auto GetWorkSpaceSize(index_t group_count) -> std::size_t
{
return group_count * sizeof(GemmTransKernelArg);
return group_count * sizeof(GemmTransKernelArg<NumDTensor_>);
}
CK_TILE_HOST static auto BlockSize() -> dim3
@@ -184,7 +199,8 @@ struct GroupedGemmKernel
return dim3(grid_size, 1, 1);
}
CK_TILE_HOST static auto GridSize(const std::vector<GroupedGemmHostArgs>& gemm_descs)
CK_TILE_HOST static auto
GridSize(const std::vector<GroupedGemmHostArgs<NumDTensor_>>& gemm_descs)
{
index_t grid_size = 0;
for(const auto& it_desc : gemm_descs)
@@ -196,9 +212,10 @@ struct GroupedGemmKernel
}
CK_TILE_HOST static auto
MakeKargs(const std::vector<GroupedGemmHostArgs>& gemm_descs) -> std::vector<GemmTransKernelArg>
MakeKargs(const std::vector<GroupedGemmHostArgs<NumDTensor_>>& gemm_descs)
-> std::vector<GemmTransKernelArg<NumDTensor_>>
{
std::vector<GemmTransKernelArg> gemm_kernel_args_;
std::vector<GemmTransKernelArg<NumDTensor_>> gemm_kernel_args_;
index_t group_count = ck_tile::type_convert<ck_tile::index_t>(gemm_descs.size());
index_t grid_size = 0;
gemm_kernel_args_.reserve(group_count);
@@ -217,6 +234,7 @@ struct GroupedGemmKernel
const index_t stride_a = gemm_descs[i].stride_A;
const index_t stride_b = gemm_descs[i].stride_B;
const index_t stride_e = gemm_descs[i].stride_E;
auto stride_ds = gemm_descs[i].stride_Ds;
const index_t grid_size_grp = TilePartitioner::GridSize(M, N) * gemm_descs[i].k_batch;
@@ -225,19 +243,19 @@ struct GroupedGemmKernel
grid_size += grid_size_grp;
auto karg =
UniversalGemmKernelArgs<>{{type_convert<const ADataType*>(gemm_descs[i].a_ptr)},
{type_convert<const BDataType*>(gemm_descs[i].b_ptr)},
{/*ds_ptr*/},
type_convert<CDataType*>(gemm_descs[i].e_ptr),
M,
N,
K,
{stride_a},
{stride_b},
{/*stride_ds*/},
stride_e,
gemm_descs[i].k_batch};
auto karg = UniversalGemmKernelArgs<1, 1, NumDTensor_>{
{type_convert<const ADataType*>(gemm_descs[i].a_ptr)},
{type_convert<const BDataType*>(gemm_descs[i].b_ptr)},
{gemm_descs[i].ds_ptr},
type_convert<CDataType*>(gemm_descs[i].e_ptr),
M,
N,
K,
{stride_a},
{stride_b},
stride_ds,
stride_e,
gemm_descs[i].k_batch};
gemm_kernel_args_.emplace_back(std::move(karg), block_start, block_end);
}
@@ -245,7 +263,8 @@ struct GroupedGemmKernel
return gemm_kernel_args_;
}
CK_TILE_HOST static bool IsSupportedArgument(const std::vector<GemmTransKernelArg>& kargs)
CK_TILE_HOST static bool
IsSupportedArgument(const std::vector<GemmTransKernelArg<NumDTensor_>>& kargs)
{
for(const auto& karg : kargs)
{
@@ -262,7 +281,7 @@ struct GroupedGemmKernel
return max(GemmPipeline::GetSmemSize(), EpiloguePipeline::GetSmemSize());
}
CK_TILE_DEVICE void Run(const UniversalGemmKernelArgs<>& kargs,
CK_TILE_DEVICE void Run(const UniversalGemmKernelArgs<1, 1, NumDTensor_>& kargs,
const tuple<index_t, index_t>& block_idx_2d,
const index_t block_idx_z) const
{
@@ -292,8 +311,16 @@ struct GroupedGemmKernel
{
__shared__ char smem_ptr_1[GetSmemSize()];
RunGemmWithPipelineSelection2LDS(
a_ptr, b_ptr, c_ptr, smem_ptr_0, smem_ptr_1, kargs, splitk_batch_offset, i_m, i_n);
RunGemmWithPipelineSelection2LDS(a_ptr,
b_ptr,
c_ptr,
kargs.ds_ptr,
smem_ptr_0,
smem_ptr_1,
kargs,
splitk_batch_offset,
i_m,
i_n);
}
else // SingleSmemBuffer
{
@@ -306,7 +333,7 @@ struct GroupedGemmKernel
{
Base::RunGemm({a_ptr},
{b_ptr},
{/*ds_ptr*/},
kargs.ds_ptr,
c_ptr,
smem_ptr_0,
kargs,
@@ -340,7 +367,7 @@ struct GroupedGemmKernel
const BDataType* b_ptr,
CDataType* c_ptr,
void* smem_ptr_0,
const UniversalGemmKernelArgs<>& kargs,
const UniversalGemmKernelArgs<1, 1, NumDTensor_>& kargs,
const typename Base::SplitKBatchOffset& splitk_batch_offset,
const index_t block_idx_m,
const index_t block_idx_n)
@@ -396,9 +423,10 @@ struct GroupedGemmKernel
RunGemmWithPipelineSelection2LDS(const ADataType* a_ptr,
const BDataType* b_ptr,
CDataType* c_ptr,
const std::array<const void*, NumDTensor_>& ds_ptr,
void* __restrict__ smem_ptr_0,
void* __restrict__ smem_ptr_1,
const UniversalGemmKernelArgs<>& kargs,
const UniversalGemmKernelArgs<1, 1, NumDTensor_>& kargs,
const typename Base::SplitKBatchOffset& splitk_batch_offset,
const index_t block_idx_m,
const index_t block_idx_n)
@@ -406,7 +434,7 @@ struct GroupedGemmKernel
// Create Gemm tensor views, pad views and tile windows
const auto& gemm_tensor_views_tuple =
Base::template MakeGemmTensorViews<EpiloguePipeline::MemoryOperation>(
{a_ptr}, {b_ptr}, {/*ds_ptr*/}, c_ptr, kargs, splitk_batch_offset.splitted_k);
{a_ptr}, {b_ptr}, ds_ptr, c_ptr, kargs, splitk_batch_offset.splitted_k);
const auto& gemm_pad_views = Base::MakeGemmPadViews(gemm_tensor_views_tuple);
auto gemm_tile_windows =
@@ -453,7 +481,7 @@ struct GroupedGemmKernel
c_block_window, c_block_tile, d_block_window, smem_ptr_0);
}
CK_TILE_DEVICE index_t FindGroupId(const GemmTransKernelArg* gemm_desc_ptr,
CK_TILE_DEVICE index_t FindGroupId(const GemmTransKernelArg<NumDTensor_>* gemm_desc_ptr,
index_t block_id,
index_t group_count) const
{
@@ -485,7 +513,7 @@ struct GroupedGemmKernel
index_t group_count) const
{
const index_t block_id = ck_tile::get_block_1d_id();
const auto gemm_desc_ptr = reinterpret_cast<const GemmTransKernelArg*>(
const auto gemm_desc_ptr = reinterpret_cast<const GemmTransKernelArg<NumDTensor_>*>(
cast_pointer_to_generic_address_space(gemm_descs_const));
const index_t group_id = FindGroupId(gemm_desc_ptr, block_id, group_count);
@@ -508,7 +536,7 @@ struct GroupedGemmKernel
const index_t group_count) const
{
const index_t grid_size = ck_tile::get_grid_size();
const auto gemm_desc_ptr = reinterpret_cast<const GemmTransKernelArg*>(
const auto gemm_desc_ptr = reinterpret_cast<const GemmTransKernelArg<NumDTensor_>*>(
cast_pointer_to_generic_address_space(gemm_descs_const));
index_t block_id = ck_tile::get_block_1d_id(); // initial block_id
index_t cum_grid_size = 0;

View File

@@ -4,6 +4,7 @@ add_subdirectory(gemm_weight_preshuffle)
add_subdirectory(batched_gemm)
add_subdirectory(grouped_gemm)
add_subdirectory(grouped_gemm_preshuffle)
add_subdirectory(grouped_gemm_multi_d)
add_subdirectory(gemm_multi_d)
add_subdirectory(gemm_multi_abd)
add_subdirectory(gemm_streamk)

View File

@@ -116,19 +116,6 @@ class TestCkTileGemmPipeline : public ::testing::Test
template <typename GemmConfig, bool PadM, bool PadN, bool PadK, bool Preshuffle>
void invoke_gemm(const ck_tile::GemmHostArgs& args, const ck_tile::stream_config& s)
{
// TODO: This should be parameterized in tests
// constexpr ck_tile::index_t M_Tile = 128;
// constexpr ck_tile::index_t N_Tile = 128;
// constexpr ck_tile::index_t K_Tile = 128;
// constexpr ck_tile::index_t M_Warp = 1;
// constexpr ck_tile::index_t N_Warp = 4;
// constexpr ck_tile::index_t K_Warp = 1;
// constexpr ck_tile::index_t M_Warp_Tile = 32;
// constexpr ck_tile::index_t N_Warp_Tile = 32;
// constexpr ck_tile::index_t K_Warp_Tile = sizeof(ADataType) == 2 ? 16 : 32;
constexpr bool kPadM = PadM;
constexpr bool kPadN = PadN;
constexpr bool kPadK = PadK;

View File

@@ -62,10 +62,10 @@ class TestCkTileGroupedGemm : public ::testing::Test
static const ck_tile::index_t K_Warp_Tile = 16;
};
using grouped_gemm_kargs = ck_tile::GroupedGemmHostArgs;
using grouped_gemm_kargs = ck_tile::GroupedGemmHostArgs<>;
std::size_t get_workspace_size(const std::vector<grouped_gemm_kargs>& gemm_descs)
{
return gemm_descs.size() * sizeof(ck_tile::GemmTransKernelArg);
return gemm_descs.size() * sizeof(ck_tile::GemmTransKernelArg<>);
}
template <typename GroupedGemKernelParam, typename ALayout, typename BLayout, typename CLayout>
@@ -436,8 +436,18 @@ class TestCkTileGroupedGemm : public ::testing::Test
const void* p_b = b_k_n_dev_buf[i]->GetDeviceBuffer();
void* p_c = c_m_n_dev_buf[i]->GetDeviceBuffer();
gemm_descs.push_back(
{p_a, p_b, p_c, kbatch, M, N, K, stride_As[i], stride_Bs[i], stride_Cs[i]});
gemm_descs.push_back({p_a,
p_b,
{/*ds_ptr*/},
p_c,
kbatch,
M,
N,
K,
stride_As[i],
stride_Bs[i],
{/*stride_Ds*/},
stride_Cs[i]});
}
ck_tile::DeviceMem gemm_workspace;
@@ -446,7 +456,7 @@ class TestCkTileGroupedGemm : public ::testing::Test
if constexpr(Persistent)
{
// Generate kernel arguments
std::vector<ck_tile::GemmTransKernelArg> kargs;
std::vector<ck_tile::GemmTransKernelArg<>> kargs;
void* kargs_ptr = gemm_workspace.GetDeviceBuffer();
const bool splitk = gemm_descs[0].k_batch > 1;
for(const auto& arg : gemm_descs)
@@ -468,7 +478,7 @@ class TestCkTileGroupedGemm : public ::testing::Test
ck_tile::hip_check_error(
hipMemcpyWithStream(kargs_ptr,
kargs.data(),
kargs.size() * sizeof(ck_tile::GemmTransKernelArg),
kargs.size() * sizeof(ck_tile::GemmTransKernelArg<>),
hipMemcpyHostToDevice,
stream.stream_id_));
#if CK_TILE_USE_WMMA

View File

@@ -0,0 +1,9 @@
set(EXAMPLE_GEMM_COMPILE_OPTIONS)
if(CK_USE_OCP_FP8)
list(APPEND EXAMPLE_GEMM_COMPILE_OPTIONS -DCK_TILE_USE_OCP_FP8)
endif()
if(GPU_TARGETS MATCHES "gfx94|gfx95")
add_gtest_executable(test_ck_tile_grouped_gemm_multi_d test_grouped_gemm_multi_d.cpp)
target_compile_options(test_ck_tile_grouped_gemm_multi_d PRIVATE ${EXAMPLE_GEMM_COMPILE_OPTIONS})
endif()

View File

@@ -0,0 +1,73 @@
// SPDX-License-Identifier: MIT
// Copyright (c) 2024-2025, Advanced Micro Devices, Inc. All rights reserved.
#include <tuple>
#include <gtest/gtest.h>
#include "ck_tile/host.hpp"
#include "test_grouped_gemm_multi_d_util.hpp"
using F16 = ck_tile::half_t;
using F8 = ck_tile::fp8_t;
using F32 = float;
// Custom tuple-like structure for kernel configuration
template <typename ALayout_,
typename BLayout_,
typename ELayout_,
typename ADataType_,
typename BDataType_,
typename AccDataType_,
typename EDataType_,
int M_Tile_val_,
int N_Tile_val_,
int K_Tile_val_,
int M_Warp_val_,
int N_Warp_val_,
int K_Warp_val_,
int M_Warp_Tile_val_,
int N_Warp_Tile_val_,
int K_Warp_Tile_val_,
bool DoubleSmemBuffer_val_,
ck_tile::GemmPipelineScheduler Scheduler_val_,
PipelineType Pipeline_val_>
struct KernelConfig
{
using ALayoutType = ALayout_;
using BLayoutType = BLayout_;
using ELayoutType = ELayout_;
using DsLayoutType = ck_tile::tuple<Row, Row>;
using ADataType = ADataType_;
using BDataType = BDataType_;
using AccDataType = AccDataType_;
using EDataType = EDataType_;
using DsDataType = ck_tile::tuple<F16, F16>;
static constexpr int M_Tile_ = M_Tile_val_;
static constexpr int N_Tile_ = N_Tile_val_;
static constexpr int K_Tile_ = K_Tile_val_;
static constexpr int M_Warp_ = M_Warp_val_;
static constexpr int N_Warp_ = N_Warp_val_;
static constexpr int K_Warp_ = K_Warp_val_;
static constexpr int M_Warp_Tile_ = M_Warp_Tile_val_;
static constexpr int N_Warp_Tile_ = N_Warp_Tile_val_;
static constexpr int K_Warp_Tile_ = K_Warp_Tile_val_;
static constexpr bool DoubleSmemBuffer_ = DoubleSmemBuffer_val_;
static constexpr auto Scheduler_ = Scheduler_val_;
static constexpr PipelineType Pipeline_ = Pipeline_val_;
static constexpr int BlockPerCu_ = 1;
};
// clang-format off
using KernelTypes = ::testing::Types<
// ALayout, BLayout, ELayout, ADataType, BDataType, AccDataType, EDataType, M_N_KTiles, M_N_K_Warps, M_N_K_Warp_Tile, DoubleSmemBuffer, Scheduler, Pipeline
KernelConfig< Row, Col, Row, F16, F16, F32, F16, 128, 32, 64, 4, 1, 1, 32, 32, 8, false, ck_tile::GemmPipelineScheduler::Interwave, PipelineType::Memory>, // memory
KernelConfig< Row, Col, Row, F16, F16, F32, F16, 256, 256, 64, 2, 2, 1, 32, 32, 16, false, ck_tile::GemmPipelineScheduler::Intrawave, PipelineType::CompV3>, // v3
KernelConfig< Row, Col, Row, F16, F16, F32, F16, 256, 256, 32, 2, 2, 1, 32, 32, 16, true, ck_tile::GemmPipelineScheduler::Intrawave, PipelineType::CompV4> // v4
>;
// clang-format on
TYPED_TEST_SUITE(TestCkTileGroupedGemmMultiD, KernelTypes);
#include "test_grouped_gemm_multi_d_ut_cases.inc"

View File

@@ -0,0 +1,91 @@
#pragma once
TYPED_TEST(TestCkTileGroupedGemmMultiD, K256)
{
const int group_count = 7;
const int kbatch = 1;
std::vector<int> Ms;
std::vector<int> Ns;
std::vector<int> Ks;
std::vector<int> stride_As;
std::vector<int> stride_Bs;
std::vector<int> stride_Es;
std::vector<int> stride_D0;
std::vector<int> stride_D1;
for(int i = 0; i < group_count; i++)
{
Ms.push_back(256 + 256 * i);
Ns.push_back(256 + 512 * i);
Ks.push_back(512 + 256 * i);
stride_As.push_back(Ks[i]);
stride_Bs.push_back(Ks[i]);
stride_Es.push_back(Ns[i]);
stride_D0.push_back(Ns[i]);
stride_D1.push_back(Ns[i]);
}
this->Run(
Ms, Ns, Ks, stride_As, stride_Bs, stride_Es, stride_D0, stride_D1, kbatch, group_count);
}
TYPED_TEST(TestCkTileGroupedGemmMultiD, K128)
{
const int group_count = 5;
const int kbatch = 1;
std::vector<int> Ms;
std::vector<int> Ns;
std::vector<int> Ks;
std::vector<int> stride_As;
std::vector<int> stride_Bs;
std::vector<int> stride_Es;
std::vector<int> stride_D0;
std::vector<int> stride_D1;
for(int i = 0; i < group_count; i++)
{
Ms.push_back(256 + 256 * i);
Ns.push_back(256 + 512 * i);
Ks.push_back(512 + 128 * i);
stride_As.push_back(Ks[i]);
stride_Bs.push_back(Ks[i]);
stride_Es.push_back(Ns[i]);
stride_D0.push_back(Ns[i]);
stride_D1.push_back(Ns[i]);
}
this->Run(
Ms, Ns, Ks, stride_As, stride_Bs, stride_Es, stride_D0, stride_D1, kbatch, group_count);
}
TYPED_TEST(TestCkTileGroupedGemmMultiD, LargeMNK_8Groups)
{
const int group_count = 8;
const int kbatch = 1;
std::vector<int> Ms;
std::vector<int> Ns;
std::vector<int> Ks;
std::vector<int> stride_As;
std::vector<int> stride_Bs;
std::vector<int> stride_Es;
std::vector<int> stride_D0;
std::vector<int> stride_D1;
for(int i = 0; i < group_count; i++)
{
Ms.push_back(512 + 256 * i);
Ns.push_back(512 + 256 * i);
Ks.push_back(768 + 256 * i);
stride_As.push_back(Ks[i]);
stride_Bs.push_back(Ks[i]);
stride_Es.push_back(Ns[i]);
stride_D0.push_back(Ns[i]);
stride_D1.push_back(Ns[i]);
}
this->Run(
Ms, Ns, Ks, stride_As, stride_Bs, stride_Es, stride_D0, stride_D1, kbatch, group_count);
}

View File

@@ -0,0 +1,431 @@
// SPDX-License-Identifier: MIT
// Copyright (c) 2024-2025, Advanced Micro Devices, Inc. All rights reserved.
#pragma once
#include <gtest/gtest.h>
#include "ck_tile/core.hpp"
#include "ck_tile/host.hpp"
#include "ck_tile/host/kernel_launch.hpp"
#include "ck_tile/ops/epilogue.hpp"
#include "ck_tile/ops/gemm.hpp"
#include "ck_tile/ops/gemm/kernel/grouped_gemm_kernel.hpp"
enum class PipelineType
{
Memory = 0,
CompV3 = 1,
CompV4 = 2
};
using Row = ck_tile::tensor_layout::gemm::RowMajor;
using Col = ck_tile::tensor_layout::gemm::ColumnMajor;
struct MultiplyMultiply
{
template <typename E, typename C, typename D0, typename D1>
CK_TILE_HOST_DEVICE auto operator()(E& e, const C& c, const D0& d0, const D1& d1) const -> void
{
const float x0_f = ck_tile::type_convert<float>(c) * ck_tile::type_convert<float>(d0) *
ck_tile::type_convert<float>(d1);
e = ck_tile::type_convert<E>(x0_f);
}
};
template <typename Config>
class TestCkTileGroupedGemmMultiD : public ::testing::Test
{
protected:
using ALayout = typename Config::ALayoutType;
using BLayout = typename Config::BLayoutType;
using ELayout = typename Config::ELayoutType;
using DsLayout = typename Config::DsLayoutType;
using ADataType = typename Config::ADataType;
using BDataType = typename Config::BDataType;
using AccDataType = typename Config::AccDataType;
using EDataType = typename Config::EDataType;
using PrecType = BDataType;
using DsDataType = typename Config::DsDataType;
using D0DataType = std::tuple_element_t<0, DsDataType>;
using D1DataType = std::tuple_element_t<1, DsDataType>;
using D0Layout = std::tuple_element_t<0, DsLayout>;
using D1Layout = std::tuple_element_t<1, DsLayout>;
static const bool kPadM = false;
static const bool kPadN = false;
static const bool kPadK = false;
static constexpr bool TransposeC = false; // transpose c is not supported
static constexpr ck_tile::index_t TileParitionerGroupNum = 8;
static constexpr ck_tile::index_t TileParitionerM01 = 4;
auto calculate_rtol_atol(const ck_tile::index_t K,
const ck_tile::index_t kbatch,
const float max_accumulated_value)
{
using ComputeTypeAB =
std::conditional_t<sizeof(ADataType) < sizeof(BDataType), ADataType, BDataType>;
using ComputeType = std::
conditional_t<sizeof(ComputeTypeAB) < sizeof(D0DataType), ComputeTypeAB, D0DataType>;
// Calculate thresholds
const auto rtol = ck_tile::get_relative_threshold<ComputeType, EDataType, AccDataType>(
ck_tile::integer_divide_ceil(K, kbatch));
const auto atol = ck_tile::get_absolute_threshold<ComputeType, EDataType, AccDataType>(
max_accumulated_value / kbatch, ck_tile::integer_divide_ceil(K, kbatch));
// Calculate error due to split_k accumulation
const auto rtol_split_k =
ck_tile::get_relative_threshold<EDataType, EDataType, EDataType>(kbatch);
const auto atol_split_k = ck_tile::get_absolute_threshold<EDataType, EDataType, EDataType>(
max_accumulated_value, kbatch);
// Use higher threshold
return ck_tile::make_tuple(std::max(rtol, rtol_split_k), std::max(atol, atol_split_k));
}
using grouped_gemm_kargs = ck_tile::GroupedGemmHostArgs<DsDataType::size()>;
inline std::size_t get_workspace_size(const std::vector<grouped_gemm_kargs>& gemm_descs)
{
return gemm_descs.size() * sizeof(ck_tile::GemmTransKernelArg<DsDataType::size()>);
}
template <typename ALayout, typename BLayout, typename ELayout>
void invoke_grouped_gemm(const std::vector<grouped_gemm_kargs>& gemm_descs,
const ck_tile::stream_config& s,
void* kargs_ptr)
{
using GemmShape = ck_tile::TileGemmShape<
ck_tile::sequence<Config::M_Tile_, Config::N_Tile_, Config::K_Tile_>,
ck_tile::sequence<Config::M_Warp_, Config::N_Warp_, Config::K_Warp_>,
ck_tile::sequence<Config::M_Warp_Tile_, Config::N_Warp_Tile_, Config::K_Warp_Tile_>>;
using TilePartitioner = ck_tile::
GemmSpatiallyLocalTilePartitioner<GemmShape, TileParitionerGroupNum, TileParitionerM01>;
using Traits = ck_tile::TileGemmTraits<kPadM, kPadN, kPadK, ALayout, BLayout, ELayout>;
// for testing purposes, we can hardcode the values here as we what is compatible with
// pipeline
using GemmUniversalTraits =
ck_tile::TileGemmUniversalTraits<kPadM,
kPadN,
kPadK,
Config::DoubleSmemBuffer_,
ALayout,
BLayout,
ELayout,
TransposeC,
/*UseStructuredSparsity*/ false,
/*Persistent*/ false,
/*NumWaveGroups*/ 1,
/*Preshuffle*/ false>;
using GemmPipelineProblem =
ck_tile::GemmPipelineProblem<ADataType, BDataType, AccDataType, GemmShape, Traits>;
using BaseGemmPipeline = ck_tile::BaseGemmPipelineAgBgCrCompV3<GemmPipelineProblem>;
const ck_tile::index_t k_grain = gemm_descs[0].k_batch * Config::K_Tile_;
const ck_tile::index_t K_split =
(gemm_descs[0].K + k_grain - 1) / k_grain * Config::K_Tile_;
const ck_tile::index_t num_loop =
ck_tile::GemmSpatiallyLocalTilePartitioner<GemmShape,
TileParitionerGroupNum,
TileParitionerM01>::GetLoopNum(K_split);
const bool has_hot_loop = BaseGemmPipeline::BlockHasHotloop(num_loop);
const ck_tile::TailNumber tail_num = BaseGemmPipeline::GetBlockLoopTailNum(num_loop);
float ave_time{0};
const auto Run = [&](const auto has_hot_loop_,
const auto tail_number_,
const auto memory_operation_) {
constexpr bool has_hot_loop_v = has_hot_loop_.value;
constexpr auto tail_number_v = tail_number_.value;
constexpr auto memory_operation = memory_operation_.value;
using UniversalGemmProblem = ck_tile::UniversalGemmPipelineProblem<ADataType,
BDataType,
AccDataType,
GemmShape,
GemmUniversalTraits,
Config::Scheduler_,
has_hot_loop_v,
tail_number_v>;
using GemmPipeline = std::conditional_t<
Config::Pipeline_ == (PipelineType::Memory),
ck_tile::GemmPipelineAgBgCrMem<UniversalGemmProblem>,
std::conditional_t<Config::Pipeline_ == (PipelineType::CompV3),
ck_tile::GemmPipelineAgBgCrCompV3<UniversalGemmProblem>,
ck_tile::GemmPipelineAgBgCrCompV4<UniversalGemmProblem>>>;
using GemmEpilogue = ck_tile::CShuffleEpilogue<
ck_tile::CShuffleEpilogueProblem<ADataType,
BDataType,
DsDataType,
AccDataType,
EDataType,
DsLayout,
ELayout,
MultiplyMultiply,
TilePartitioner::MPerBlock,
TilePartitioner::NPerBlock,
Config::M_Warp_,
Config::N_Warp_,
Config::M_Warp_Tile_,
Config::N_Warp_Tile_,
Config::K_Warp_Tile_,
UniversalGemmProblem::TransposeC,
memory_operation>>;
using Kernel = ck_tile::GroupedGemmKernel<TilePartitioner, GemmPipeline, GemmEpilogue>;
auto kargs = Kernel::MakeKargs(gemm_descs);
EXPECT_TRUE(Kernel::IsSupportedArgument(kargs));
const dim3 grids = Kernel::GridSize(gemm_descs);
const dim3 blocks = Kernel::BlockSize();
if(s.log_level_ > 0)
{
std::cout << "Launching kernel: " << Kernel::GetName()
<< " with args:" << " grid: {" << grids.x << ", " << grids.y << ", "
<< grids.z << "}" << ", blocks: {" << blocks.x << ", " << blocks.y << ", "
<< blocks.z << "}" << std::endl;
}
ck_tile::hip_check_error(hipMemcpyWithStream(kargs_ptr,
kargs.data(),
get_workspace_size(gemm_descs),
hipMemcpyHostToDevice,
s.stream_id_));
ave_time = ck_tile::launch_kernel(
s,
ck_tile::make_kernel<Config::BlockPerCu_>(
Kernel{},
grids,
blocks,
0,
ck_tile::cast_pointer_to_constant_address_space(kargs_ptr),
gemm_descs.size()));
return ave_time;
};
const auto RunSplitk = [&](const auto has_hot_loop_, const auto tail_number_) {
if(gemm_descs[0].k_batch == 1)
{
Run(has_hot_loop_,
tail_number_,
ck_tile::integral_constant<ck_tile::memory_operation_enum,
ck_tile::memory_operation_enum::set>{});
}
else
{
// EXPECT TO FAIL because splitk is not supported
EXPECT_FALSE(true);
}
};
BaseGemmPipeline::TailHandler(RunSplitk, has_hot_loop, tail_num);
}
public:
void Run(const std::vector<int>& Ms,
const std::vector<int>& Ns,
const std::vector<int>& Ks,
std::vector<int>& stride_As,
std::vector<int>& stride_Bs,
std::vector<int>& stride_Es,
std::vector<int>& stride_D0,
std::vector<int>& stride_D1,
const int kbatch = 1,
const int group_count = 16)
{
using namespace ck_tile::literals;
auto f_host_tensor_descriptor = [](std::size_t row,
std::size_t col,
std::size_t stride,
auto layout) {
if constexpr(std::is_same_v<decltype(layout), ck_tile::tensor_layout::gemm::RowMajor>)
{
return ck_tile::HostTensorDescriptor({row, col}, {stride, 1_uz});
}
else
{
return ck_tile::HostTensorDescriptor({row, col}, {1_uz, stride});
}
};
auto f_get_default_stride =
[](std::size_t row, std::size_t col, std::size_t stride, auto layout) {
if(stride == 0)
{
if constexpr(std::is_same_v<decltype(layout),
ck_tile::tensor_layout::gemm::RowMajor>)
{
return col;
}
else
{
return row;
}
}
else
return stride;
};
std::vector<ck_tile::HostTensor<ADataType>> a_m_k_tensors;
std::vector<ck_tile::HostTensor<BDataType>> b_k_n_tensors;
std::vector<ck_tile::HostTensor<EDataType>> e_m_n_tensors;
std::vector<ck_tile::HostTensor<D0DataType>> d0_m_n_tensors;
std::vector<ck_tile::HostTensor<D1DataType>> d1_m_n_tensors;
a_m_k_tensors.reserve(group_count);
b_k_n_tensors.reserve(group_count);
e_m_n_tensors.reserve(group_count);
d0_m_n_tensors.reserve(group_count);
d1_m_n_tensors.reserve(group_count);
std::vector<std::unique_ptr<ck_tile::DeviceMem>> a_m_k_dev_buf;
std::vector<std::unique_ptr<ck_tile::DeviceMem>> b_k_n_dev_buf;
std::vector<std::unique_ptr<ck_tile::DeviceMem>> e_m_n_dev_buf;
std::vector<std::unique_ptr<ck_tile::DeviceMem>> d0_m_n_dev_buf;
std::vector<std::unique_ptr<ck_tile::DeviceMem>> d1_m_n_dev_buf;
a_m_k_dev_buf.reserve(group_count);
b_k_n_dev_buf.reserve(group_count);
e_m_n_dev_buf.reserve(group_count);
d0_m_n_dev_buf.reserve(group_count);
d1_m_n_dev_buf.reserve(group_count);
std::vector<grouped_gemm_kargs> gemm_descs;
gemm_descs.reserve(group_count);
for(int i = 0; i < group_count; ++i)
{
const ck_tile::index_t M = Ms[i];
const ck_tile::index_t N = Ns[i];
const ck_tile::index_t K = Ks[i];
stride_As[i] = f_get_default_stride(M, K, stride_As[i], ALayout{});
stride_Bs[i] = f_get_default_stride(K, N, stride_Bs[i], BLayout{});
stride_Es[i] = f_get_default_stride(M, N, stride_Es[i], ELayout{});
stride_D0[i] = f_get_default_stride(M, N, stride_D0[i], D0Layout{});
stride_D1[i] = f_get_default_stride(M, N, stride_D1[i], D1Layout{});
a_m_k_tensors.push_back(ck_tile::HostTensor<ADataType>(
f_host_tensor_descriptor(M, K, stride_As[i], ALayout{})));
b_k_n_tensors.push_back(ck_tile::HostTensor<BDataType>(
f_host_tensor_descriptor(K, N, stride_Bs[i], BLayout{})));
e_m_n_tensors.push_back(ck_tile::HostTensor<EDataType>(
f_host_tensor_descriptor(M, N, stride_Es[i], ELayout{})));
d0_m_n_tensors.push_back(ck_tile::HostTensor<D0DataType>(
f_host_tensor_descriptor(M, N, stride_D0[i], D0Layout{})));
d1_m_n_tensors.push_back(ck_tile::HostTensor<D1DataType>(
f_host_tensor_descriptor(M, N, stride_D1[i], D1Layout{})));
std::cout << "gemm[" << i << "]" << " a_m_k: " << a_m_k_tensors[i].mDesc
<< " b_k_n: " << b_k_n_tensors[i].mDesc
<< " e_m_n: " << e_m_n_tensors[i].mDesc
<< " d0_m_n: " << d0_m_n_tensors[i].mDesc
<< " d1_m_n: " << d1_m_n_tensors[i].mDesc << std::endl;
ck_tile::FillUniformDistribution<ADataType>{-1.f, 1.f}(a_m_k_tensors[i]);
ck_tile::FillUniformDistribution<BDataType>{-1.f, 1.f}(b_k_n_tensors[i]);
ck_tile::FillUniformDistribution<D0DataType>{-2.f, 2.f}(d0_m_n_tensors[i]);
ck_tile::FillUniformDistribution<D1DataType>{-1.f, 1.f}(d1_m_n_tensors[i]);
a_m_k_dev_buf.push_back(std::make_unique<ck_tile::DeviceMem>(
a_m_k_tensors[i].get_element_space_size_in_bytes()));
b_k_n_dev_buf.push_back(std::make_unique<ck_tile::DeviceMem>(
b_k_n_tensors[i].get_element_space_size_in_bytes()));
e_m_n_dev_buf.push_back(std::make_unique<ck_tile::DeviceMem>(
e_m_n_tensors[i].get_element_space_size_in_bytes()));
d0_m_n_dev_buf.push_back(std::make_unique<ck_tile::DeviceMem>(
d0_m_n_tensors[i].get_element_space_size_in_bytes()));
d1_m_n_dev_buf.push_back(std::make_unique<ck_tile::DeviceMem>(
d1_m_n_tensors[i].get_element_space_size_in_bytes()));
a_m_k_dev_buf[i]->ToDevice(a_m_k_tensors[i].data());
b_k_n_dev_buf[i]->ToDevice(b_k_n_tensors[i].data());
e_m_n_dev_buf[i]->SetZero();
d0_m_n_dev_buf[i]->ToDevice(d0_m_n_tensors[i].data());
d1_m_n_dev_buf[i]->ToDevice(d1_m_n_tensors[i].data());
const void* p_a = a_m_k_dev_buf[i]->GetDeviceBuffer();
const void* p_b = b_k_n_dev_buf[i]->GetDeviceBuffer();
void* p_e = e_m_n_dev_buf[i]->GetDeviceBuffer();
std::array<const void*, DsDataType::size()> ds_ptr_buf = {
d0_m_n_dev_buf[i]->GetDeviceBuffer(), d1_m_n_dev_buf[i]->GetDeviceBuffer()};
std::array<ck_tile::index_t, DsDataType::size()> stridesDs = {stride_D0[i],
stride_D1[i]};
gemm_descs.push_back({p_a,
p_b,
ds_ptr_buf,
p_e,
kbatch,
M,
N,
K,
stride_As[i],
stride_Bs[i],
stridesDs,
stride_Es[i]});
}
ck_tile::DeviceMem gemm_workspace;
gemm_workspace.Realloc(get_workspace_size(gemm_descs));
invoke_grouped_gemm<ALayout, BLayout, ELayout>(gemm_descs,
ck_tile::stream_config{nullptr, false, 1},
gemm_workspace.GetDeviceBuffer());
// Copy results back to host for validation
for(int i = 0; i < group_count; i++)
{
e_m_n_dev_buf[i]->FromDevice(e_m_n_tensors[i].data());
}
std::vector<ck_tile::HostTensor<EDataType>> e_m_n_host_refs;
e_m_n_host_refs.reserve(group_count);
bool pass{true};
for(int i = 0; i < group_count; ++i)
{
e_m_n_host_refs.push_back(ck_tile::HostTensor<EDataType>(
f_host_tensor_descriptor(Ms[i], Ns[i], stride_Es[i], ELayout{})));
e_m_n_host_refs[i].SetZero();
ck_tile::reference_gemm_multiple_d<ADataType,
BDataType,
DsDataType,
AccDataType,
EDataType,
MultiplyMultiply>(
a_m_k_tensors[i],
b_k_n_tensors[i],
{d0_m_n_tensors[i], d1_m_n_tensors[i]},
e_m_n_host_refs[i]);
const float max_accumulated_value =
*std::max_element(e_m_n_host_refs[i].mData.begin(), e_m_n_host_refs[i].mData.end());
const auto rtol_atol = calculate_rtol_atol(Ks[i], 1, max_accumulated_value);
pass &=
ck_tile::check_err(e_m_n_tensors[i],
e_m_n_host_refs[i],
"Error: Incorrect results! in group [" + std::to_string(i) + "]",
rtol_atol.at(ck_tile::number<0>{}),
rtol_atol.at(ck_tile::number<1>{}));
std::cout << "Relative error threshold: " << rtol_atol.at(ck_tile::number<0>{})
<< " Absolute error threshold: " << rtol_atol.at(ck_tile::number<1>{})
<< std::endl;
}
EXPECT_TRUE(pass);
}
};

View File

@@ -88,10 +88,10 @@ class TestCkTileGroupedGemmPreshuffle : public ::testing::Test
return ck_tile::make_tuple(std::max(rtol, rtol_split_k), std::max(atol, atol_split_k));
}
using grouped_gemm_kargs = ck_tile::GroupedGemmHostArgs;
using grouped_gemm_kargs = ck_tile::GroupedGemmHostArgs<>;
inline std::size_t get_workspace_size(const std::vector<grouped_gemm_kargs>& gemm_descs)
{
return gemm_descs.size() * sizeof(ck_tile::GemmTransKernelArg);
return gemm_descs.size() * sizeof(ck_tile::GemmTransKernelArg<>);
}
template <typename T>
@@ -333,8 +333,18 @@ class TestCkTileGroupedGemmPreshuffle : public ::testing::Test
const void* p_b = b_k_n_dev_buf[i]->GetDeviceBuffer();
void* p_c = c_m_n_dev_buf[i]->GetDeviceBuffer();
gemm_descs.push_back(
{p_a, p_b, p_c, kbatch, M, N, K, stride_As[i], stride_Bs[i], stride_Cs[i]});
gemm_descs.push_back({p_a,
p_b,
{/*ds_ptr*/},
p_c,
kbatch,
M,
N,
K,
stride_As[i],
stride_Bs[i],
{/*stride_Ds*/},
stride_Cs[i]});
}
ck_tile::DeviceMem gemm_workspace;