mirror of
https://github.com/ROCm/composable_kernel.git
synced 2026-04-19 22:39:03 +00:00
Integrate Multi D GEMMs into Grouped GEMMs along with unit tests (#2923)
* feat(grouped_gemm_multi_d): add new example that integrates grouped_gemm and multi_d_gemm feature * feat: generalized grouped_gemm_kernel.hpp * feat: generalized grouped_gemm_kernel.hpp even further by removing hardcoded 0 * refactor: grouped_gemm_multi_d relies on grouped_gemm_kernel * tests(grouped_gemm): grouped_gemm test suite passes with minor adjustments * fix: segfault fix by passing correct parameters for d tensors * docs: add multi d info and trim down outdated content * tests: add unit tests for grouped_gemm_multi_d and minor changes in grouped_gemm related test for compatibility * style: clang format * fix: incorrect validation method and Dtensor layout in test suite
This commit is contained in:
@@ -1,10 +1,12 @@
|
||||
add_executable(tile_example_grouped_gemm EXCLUDE_FROM_ALL grouped_gemm.cpp)
|
||||
add_executable(tile_example_quant_grouped_gemm EXCLUDE_FROM_ALL quant_grouped_gemm.cpp)
|
||||
add_executable(tile_example_grouped_gemm_preshuffle EXCLUDE_FROM_ALL grouped_gemm_preshuffle.cpp)
|
||||
add_executable(tile_example_grouped_gemm_multi_d EXCLUDE_FROM_ALL grouped_gemm_multi_d.cpp)
|
||||
set(EXAMPLE_GEMM_COMPILE_OPTIONS)
|
||||
if(CK_USE_OCP_FP8)
|
||||
list(APPEND EXAMPLE_GEMM_COMPILE_OPTIONS -DCK_TILE_USE_OCP_FP8)
|
||||
endif()
|
||||
target_compile_options(tile_example_grouped_gemm PRIVATE ${EXAMPLE_GEMM_COMPILE_OPTIONS})
|
||||
target_compile_options(tile_example_grouped_gemm_preshuffle PRIVATE ${EXAMPLE_GEMM_COMPILE_OPTIONS})
|
||||
target_compile_options(tile_example_quant_grouped_gemm PRIVATE ${EXAMPLE_GEMM_COMPILE_OPTIONS})
|
||||
target_compile_options(tile_example_grouped_gemm_multi_d PRIVATE ${EXAMPLE_GEMM_COMPILE_OPTIONS})
|
||||
target_compile_options(tile_example_quant_grouped_gemm PRIVATE ${EXAMPLE_GEMM_COMPILE_OPTIONS})
|
||||
@@ -1,140 +1,8 @@
|
||||
# Grouped Gemm
|
||||
|
||||
Grouped General Matrix Multiplication (Grouped GEMM) is a technique used in GPU computing and high-performance computing to batch together multiple independent GEMM operations (matrix multiplications) into a single kernel launch in order to improve performance and efficiency. This folder contains Grouped GEMM examples that use the ck_tile tile-programming implementation.
|
||||
|
||||
## Quick Tour for New Users
|
||||
|
||||
The `Grouped GEMM` operators are versions of GEMM that run multiple GEMM operations within a single kernel call. Each GEMM operation performs a matrix multiplication. Unlike regular batched GEMM operations where both matrices must be of the same size and have the same configuration, Grouped GEMM operations can take matrices with different sizes and configurations, making them more flexible for diverse workloads.
|
||||
|
||||
Let's now break the example into the following parts: parsing arguments, preparing host and device buffers, preparing data, invoking GEMM, and building the example, while explaining each function.
|
||||
|
||||
### Key Arguments
|
||||
The example takes several arguments including `group_count`, `repeat`, and `warmup`:
|
||||
- `group_count`: the number of GEMM operations in the group
|
||||
- `repeat`: the number of times to repeat the kernel for benchmarking
|
||||
- `warmup`: the number of iterations before the actual kernel run time measure
|
||||
|
||||
```cpp
|
||||
// Example
|
||||
const int group_count = arg_parser.get_int("group_count");
|
||||
const int repeat = arg_parser.get_int("repeat");
|
||||
const int warmup = arg_parser.get_int("warmup");
|
||||
```
|
||||
In the next step, the input parameters `Ms`, `Ns`, `Ks`, as well as the corresponding `stride_As`, `stride_Bs`, and `stride_Cs` are either provided from the comand line or generated by default. Since one or more input data sets are expected for `A` and `B`, each parameter is stored in a `std::vector`. The size of the `vector` is defined by `group_count`.
|
||||
|
||||
```cpp
|
||||
// Example
|
||||
std::vector<ck_tile::index_t> Ms = arg_parser.get_int_vec("Ms");
|
||||
std::vector<ck_tile::index_t> Ns = arg_parser.get_int_vec("Ns");
|
||||
std::vector<ck_tile::index_t> Ks = arg_parser.get_int_vec("Ks");
|
||||
std::vector<ck_tile::index_t> stride_As = arg_parser.get_int_vec("stride_As");
|
||||
std::vector<ck_tile::index_t> stride_Bs = arg_parser.get_int_vec("stride_Bs");
|
||||
std::vector<ck_tile::index_t> stride_Cs = arg_parser.get_int_vec("stride_Cs");
|
||||
```
|
||||
Where:
|
||||
- `Ms` is the M dimension of each GEMM.
|
||||
- `Ns` is the N dimension of each GEMM.
|
||||
- `Ks` is the K dimension of each GEMM.
|
||||
- `stride_As` is the stride values for matrix A.
|
||||
- `stride_Bs` is the stride values for matrix B.
|
||||
- `stride_Cs` is the stride values for matrix C.
|
||||
|
||||
### HostTensor and Device Memory Buffers (for CPU and GPU)
|
||||
Each parameter `Ms`, `Ns`, `Ks`, `stride_As`, `stride_Bs` and `stride_Cs` contains values for more than one matrix, meaning different matrix sizes and strides can be used for different grouped GEMM computations.
|
||||
The next step is to properly load the input values. For each input matrix, `A` and `B`, and for each output matrix, `C`, you need to create both `HostTensor` and `DeviceMemory`, where:
|
||||
- `HostTensor` represents the matrix data on the host (CPU). It stores the data before they are transferred to the device for computation.
|
||||
- `DeviceMemory` represents the matrix data on the device (GPU). This will store the data on the GPU for computation during the Grouped GEMM operation.
|
||||
|
||||
#### HostTensor Buffers (for CPU)
|
||||
In the first step, create `HostTensor` for `A`, `B`, `C`. `HostTensor` allocates memory on the host (CPU) to store the matrices, initializing the memory with the appropriate dimensions and values to store the data. Below is an example code showing how to create HostTensors for those tensors:
|
||||
```cpp
|
||||
// Example
|
||||
std::vector<ck_tile::HostTensor<ADataType>> a_m_k_tensors;
|
||||
std::vector<ck_tile::HostTensor<BDataType>> b_k_n_tensors;
|
||||
std::vector<ck_tile::HostTensor<CDataType>> c_m_n_tensors;
|
||||
```
|
||||
Where:
|
||||
- `a_m_k_tensors` is the vector of `HostTensor` objects for matrix `A` (with dimensions `M × K`). Each tensor stores the data for single GEMM operation.
|
||||
- `b_k_n_tensors` is the vector of `HostTensor` objects for matrix `B` (with dimensions `K × N`).
|
||||
- `c_m_n_tensors` is the vector of `HostTensor` objects for matrix `C` (the output matrix with dimensions `M × N`).
|
||||
|
||||
The `std::vector` container is used for this purpose throughout. As mentioned above, the number of HostTensors is equal to `group_count`.
|
||||
|
||||
#### Device Memory Buffers (for GPU)
|
||||
Now it's time to allocate memory on the device (GPU) and transfer the data from `HostTensor` to `DeviceMemory` for actual computation..
|
||||
```cpp
|
||||
// Example
|
||||
std::vector<std::unique_ptr<ck_tile::DeviceMem>> a_m_k_dev_buf;
|
||||
std::vector<std::unique_ptr<ck_tile::DeviceMem>> b_k_n_dev_buf;
|
||||
std::vector<std::unique_ptr<ck_tile::DeviceMem>> c_m_n_dev_buf;
|
||||
```
|
||||
Where:
|
||||
- `a_m_k_dev_buf` is the buffer used for storing matrix A on the GPU.
|
||||
- `b_k_n_dev_buf` is the buffer used for storing matrix B on the GPU.
|
||||
- `c_m_n_dev_buf` is the buffer used for storing the result matrix C on the GPU.
|
||||
|
||||
## Prepare data
|
||||
In the next step, the input tensors are populated. A pseudorandom number generator, an existing distribution (e.g., `FillUniformDistribution`), or user data can be used to populate the tensors. Descriptors also need to be create for each input tensor.
|
||||
|
||||
Use `get_default_stride` to get the strides for A, B, and C. `get_default_stride` is a template function that calculates the default stride for a 2D array based on whether it is row-major or column-major. Template parameter determines whether the storage order is row-major (true) or column-major (false). The function takes four params `row`, `col`, `stride` and `bool_constant<is_row_major>`. If the stride is explicitly provided (`stride != 0`), the stride is returned as-is. If the stride is not provided (`stride == 0`), the function computes the default stride. For the Row-major order (`is_row_major == true`), the stride is set to the number of columns (col). For the column-major order (`is_row_major == false`), the stride is set to the number of rows (row). This function is useful when working with dynamically allocated 2D arrays, where the user may not specify the stride explicitly. It ensures a natural default stride based on the chosen storage order.
|
||||
|
||||
```cpp
|
||||
// Example, API
|
||||
template <bool is_row_major>
|
||||
auto get_default_stride(std::size_t row, std::size_t col, std::size_t stride, bool_constant<is_row_major>) {
|
||||
// code
|
||||
}
|
||||
```
|
||||
|
||||
Where:
|
||||
- `is_row_major` is a bool template parameter that determines whether the storage order is row-major (true) or column-major (false).
|
||||
- `row` is the number of rows in the matrix.
|
||||
- `col` is the number of columns in the matrix.
|
||||
- `stride` is the current stride (the distance between consecutive elements in memory).
|
||||
- `bool_constant<is_row_major>` is a tag type that helps in differentiating behavior at compile-time.
|
||||
|
||||
Next host descriptors for each of the input tensors, A, B, and C are created. Use the `f_host_tensor_descriptor` function defined below. This function takes four parameters, row, col, stride, and layout, and returns a HostTensorDescriptor based on the specified layout.
|
||||
|
||||
```cpp
|
||||
// Example for tensor A
|
||||
ck_tile::HostTensor<ADataType>(f_host_tensor_descriptor(M, K, stride_As[i], a_layout)))
|
||||
```
|
||||
|
||||
After creating the host_tensors, create `deviceMem` for each tensor `A`, `B`, and `C`, and then transfer the data to the device. The `get_element_space_size_in_bytes()` function is used to get the buffer size in bytes. Use `ToDevice()` to transfer data from the host to the device. The data that was previously generated (`a_m_k_tensors[i].data()`) is passed as a parameter to `ToDevice()`.
|
||||
|
||||
The final step before running the GEMM operation is to retrieve the pointers to the buffers of `A`, `B`, and `C` stored on the device using `->GetDeviceBuffer()` and pack them into a shared container. For example: `gemm_descs.push_back({p_a, p_b, p_c, M, N, K, stride_As[i], stride_Bs[i], stride_Cs[i]})`, where `gemm_descs` is `std::vector<grouped_gemm_kargs> gemm_descs` ([Code](https://github.com/ROCm/composable_kernel/blob/develop/example/ck_tile/17_grouped_gemm/run_grouped_gemm_example.inc#L221)). The container should include values such as:
|
||||
```cpp
|
||||
struct GroupedGemmHostArgs
|
||||
{
|
||||
const void* a_ptr;
|
||||
const void* b_ptr;
|
||||
void* c_ptr;
|
||||
index_t M;
|
||||
index_t N;
|
||||
index_t K;
|
||||
index_t stride_A;
|
||||
index_t stride_B;
|
||||
index_t stride_C;
|
||||
};
|
||||
```
|
||||
The data prepared in this way can be passed to the `invoke_gemm` function. This is a templated function that also takes three template parameters: `ALayout`, `BLayout`, and `CLayout`:
|
||||
```cpp
|
||||
// Example, API
|
||||
template <typename ALayout, typename BLayout, typename CLayout, bool Persistent>
|
||||
float invoke_gemm(int n_warmup,
|
||||
int n_repeat,
|
||||
int group_count,
|
||||
const std::vector<grouped_gemm_kargs>& args)
|
||||
```
|
||||
`invoke_gemm` returns the run time in milliseconds. The workspace memory required for computation is allocated. Workspace memory on the GPU refers to temporary memory buffers allocated when some operations are run. This extra space is needed to hold GEMM descriptions. The following structure can be used to allocate workspace:
|
||||
|
||||
```cpp
|
||||
// Example
|
||||
ck_tile::DeviceMem gemm_workspace;
|
||||
gemm_workspace.Realloc(GetWorkspaceSize(args));
|
||||
```
|
||||
|
||||
### Advanced Features: Preshuffle and Persistence
|
||||
### Preshuffle and Persistence
|
||||
|
||||
The grouped GEMM examples include two advanced optimization features:
|
||||
|
||||
@@ -153,17 +21,17 @@ Persistence mode is a GPU optimization where thread blocks remain active on the
|
||||
- **Usage**: `invoke_gemm<ALayout, BLayout, CLayout, true>` enables persistence
|
||||
- **Benefits**: Reduced kernel launch overhead, better resource utilization for small matrix sizes
|
||||
|
||||
Both features can be combined with different data types (fp16, fp8) and layout configurations to optimize performance for specific workloads.
|
||||
#### Multi-D Operations
|
||||
Multi-D operations extend the standard GEMM operation by supporting additional element-wise operations on the result tensor. This feature is particularly useful for workloads that require post-processing of the GEMM output.
|
||||
|
||||
Finally the arguments are passed to group_gemm and the kernel is launched.
|
||||
```cpp
|
||||
// API
|
||||
template <typename ALayout, typename BLayout, typename CLayout>
|
||||
float grouped_gemm(const std::vector<grouped_gemm_kargs>& gemm_descs,
|
||||
const ck_tile::stream_config& s,
|
||||
void* kargs_ptr)
|
||||
```
|
||||
All the necessary parameters are set, the tiling is computed, the GEMM pipeline and epilogue are prepared, and the GroupedGemmKernel is launched.
|
||||
- **Implementation**: Available in `grouped_gemm_multi_d.cpp`
|
||||
- **Operation**: E = C × D₀ × D₁ (where C = A × B is the standard GEMM result)
|
||||
- **Configuration**: Uses `GemmConfigV3`, `GemmConfigV4`, `GemmConfigMemory` template configuration with 2 D tensors
|
||||
- **Data Types**: Supports fp16
|
||||
- **Benefits**: Enables complex operations like scaling, activation functions, or other element-wise transformations in a single kernel call
|
||||
- **Build Target**: `make tile_example_grouped_gemm_multi_d -j`
|
||||
|
||||
Both features can be combined with different data types (fp16, fp8) and layout configurations to optimize performance for specific workloads.
|
||||
|
||||
## Build
|
||||
```
|
||||
@@ -175,10 +43,13 @@ mkdir build && cd build
|
||||
make tile_example_grouped_gemm -j
|
||||
# The preshuffle example
|
||||
make tile_example_grouped_gemm_preshuffle -j
|
||||
# The multi-D operations example
|
||||
make tile_example_grouped_gemm_multi_d -j
|
||||
# The quant grouped gemm fp8 example
|
||||
make tile_example_quant_grouped_gemm -j
|
||||
```
|
||||
This will result in an executable `build/bin/tile_example_grouped_gemm`
|
||||
This will result in an executable `build/bin/tile_example_grouped_gemm`, `build/bin/tile_example_grouped_gemm_preshuffle`, `build/bin/tile_example_grouped_gemm_multi_d`, and `build/bin/tile_example_quant_grouped_gemm`.
|
||||
|
||||
|
||||
## example
|
||||
```
|
||||
@@ -213,4 +84,4 @@ K[i] = 512 + 384 * i
|
||||
stride_A[i] = K[i]
|
||||
stride_B[i] = K[i]
|
||||
stride_C[i] = N[i]
|
||||
```
|
||||
```
|
||||
@@ -9,7 +9,6 @@
|
||||
#include "ck_tile/core.hpp"
|
||||
#include "ck_tile/host/kernel_launch.hpp"
|
||||
#include "ck_tile/ops/gemm.hpp"
|
||||
#include "ck_tile/ops/elementwise/unary_element_wise_operation.hpp"
|
||||
#include "ck_tile/utility/json_dump.hpp"
|
||||
|
||||
#define CK_TILE_PIPELINE_COMPUTE_V3 1
|
||||
@@ -296,7 +295,7 @@ struct PipelineTypeTraits<CK_TILE_PIPELINE_PRESHUFFLE_V2>
|
||||
ck_tile::BaseWeightPreshufflePipelineAGmemBGmemCRegV2<PipelineProblem>;
|
||||
};
|
||||
|
||||
using grouped_gemm_kargs = ck_tile::GroupedGemmHostArgs;
|
||||
using grouped_gemm_kargs = ck_tile::GroupedGemmHostArgs<>;
|
||||
|
||||
std::pair<bool, ck_tile::ArgParser> create_args(int argc, char* argv[])
|
||||
{
|
||||
@@ -325,7 +324,7 @@ std::pair<bool, ck_tile::ArgParser> create_args(int argc, char* argv[])
|
||||
|
||||
inline std::size_t get_workspace_size(const std::vector<grouped_gemm_kargs>& gemm_descs)
|
||||
{
|
||||
return gemm_descs.size() * sizeof(ck_tile::GemmTransKernelArg);
|
||||
return gemm_descs.size() * sizeof(ck_tile::GemmTransKernelArg<>);
|
||||
}
|
||||
|
||||
template <typename GemmConfig, typename T>
|
||||
|
||||
180
example/ck_tile/17_grouped_gemm/grouped_gemm_multi_d.cpp
Normal file
180
example/ck_tile/17_grouped_gemm/grouped_gemm_multi_d.cpp
Normal file
@@ -0,0 +1,180 @@
|
||||
// SPDX-License-Identifier: MIT
|
||||
// Copyright (c) 2025, Advanced Micro Devices, Inc. All rights reserved.
|
||||
|
||||
#include <hip/hip_runtime.h>
|
||||
|
||||
#include <cstring>
|
||||
#include <iostream>
|
||||
#include <ostream>
|
||||
#include <string>
|
||||
#include <tuple>
|
||||
#include <memory>
|
||||
|
||||
#include "ck_tile/core.hpp"
|
||||
#include "ck_tile/ops/epilogue.hpp"
|
||||
#include "ck_tile/ops/gemm.hpp"
|
||||
#include "ck_tile/host.hpp"
|
||||
#include "grouped_gemm_multi_d.hpp"
|
||||
|
||||
template <typename GemmConfig,
|
||||
typename ADataType,
|
||||
typename BDataType,
|
||||
typename DsDataType,
|
||||
typename AccDataType,
|
||||
typename EDataType,
|
||||
typename ALayout,
|
||||
typename BLayout,
|
||||
typename DsLayout,
|
||||
typename ELayout,
|
||||
typename CDEElementWise>
|
||||
float grouped_gemm_multi_d(const std::vector<grouped_gemm_multi_d_kargs>& gemm_descs,
|
||||
const ck_tile::stream_config& s,
|
||||
void* kargs_ptr)
|
||||
{
|
||||
|
||||
using GemmShape = ck_tile::TileGemmShape<
|
||||
ck_tile::sequence<GemmConfig::M_Tile, GemmConfig::N_Tile, GemmConfig::K_Tile>,
|
||||
ck_tile::sequence<GemmConfig::M_Warp, GemmConfig::N_Warp, GemmConfig::K_Warp>,
|
||||
ck_tile::
|
||||
sequence<GemmConfig::M_Warp_Tile, GemmConfig::N_Warp_Tile, GemmConfig::K_Warp_Tile>>;
|
||||
using TilePartitioner =
|
||||
ck_tile::GemmSpatiallyLocalTilePartitioner<GemmShape,
|
||||
GemmConfig::TileParitionerGroupNum,
|
||||
GemmConfig::TileParitionerM01>;
|
||||
|
||||
using Traits = ck_tile::TileGemmTraits<GemmConfig::kPadM,
|
||||
GemmConfig::kPadN,
|
||||
GemmConfig::kPadK,
|
||||
ALayout,
|
||||
BLayout,
|
||||
ELayout>;
|
||||
using GemmUniversalTraits = ck_tile::TileGemmUniversalTraits<GemmConfig::kPadM,
|
||||
GemmConfig::kPadN,
|
||||
GemmConfig::kPadK,
|
||||
GemmConfig::DoubleSmemBuffer,
|
||||
ALayout,
|
||||
BLayout,
|
||||
ELayout,
|
||||
GemmConfig::TransposeC>;
|
||||
using GemmPipelineProblem =
|
||||
ck_tile::GemmPipelineProblem<ADataType, BDataType, AccDataType, GemmShape, Traits>;
|
||||
|
||||
using BaseGemmPipeline = typename PipelineTypeTraits<
|
||||
GemmConfig::Pipeline>::template UniversalGemmPipeline<GemmPipelineProblem>;
|
||||
|
||||
const ck_tile::index_t k_grain = gemm_descs[0].k_batch * GemmConfig::K_Tile;
|
||||
const ck_tile::index_t K_split = (gemm_descs[0].K + k_grain - 1) / k_grain * GemmConfig::K_Tile;
|
||||
const ck_tile::index_t num_loop = TilePartitioner::GetLoopNum(K_split);
|
||||
const bool has_hot_loop = BaseGemmPipeline::BlockHasHotloop(num_loop);
|
||||
const ck_tile::TailNumber tail_num = BaseGemmPipeline::GetBlockLoopTailNum(num_loop);
|
||||
|
||||
float ave_time{0};
|
||||
|
||||
const auto Run = [&](const auto has_hot_loop_,
|
||||
const auto tail_number_,
|
||||
const auto memory_operation_) {
|
||||
constexpr bool has_hot_loop_v = has_hot_loop_.value;
|
||||
constexpr auto tail_number_v = tail_number_.value;
|
||||
constexpr auto scheduler = GemmConfig::Scheduler;
|
||||
constexpr auto memory_operation = memory_operation_.value;
|
||||
|
||||
using UniversalGemmProblem = ck_tile::UniversalGemmPipelineProblem<ADataType,
|
||||
BDataType,
|
||||
AccDataType,
|
||||
GemmShape,
|
||||
GemmUniversalTraits,
|
||||
scheduler,
|
||||
has_hot_loop_v,
|
||||
tail_number_v>;
|
||||
|
||||
using GemmPipeline = typename PipelineTypeTraits<
|
||||
GemmConfig::Pipeline>::template GemmPipeline<UniversalGemmProblem>;
|
||||
using GemmEpilogue = ck_tile::CShuffleEpilogue<
|
||||
ck_tile::CShuffleEpilogueProblem<ADataType,
|
||||
BDataType,
|
||||
DsDataType,
|
||||
AccDataType,
|
||||
EDataType,
|
||||
DsLayout,
|
||||
ELayout,
|
||||
CDEElementWise,
|
||||
TilePartitioner::MPerBlock,
|
||||
TilePartitioner::NPerBlock,
|
||||
GemmConfig::M_Warp,
|
||||
GemmConfig::N_Warp,
|
||||
GemmConfig::M_Warp_Tile,
|
||||
GemmConfig::N_Warp_Tile,
|
||||
GemmConfig::K_Warp_Tile,
|
||||
UniversalGemmProblem::TransposeC,
|
||||
memory_operation>>;
|
||||
|
||||
using Kernel = ck_tile::GroupedGemmKernel<TilePartitioner, GemmPipeline, GemmEpilogue>;
|
||||
auto kargs = Kernel::MakeKargs(gemm_descs);
|
||||
if(!Kernel::IsSupportedArgument(kargs))
|
||||
{
|
||||
throw std::runtime_error("Kernel arguments not supported!");
|
||||
}
|
||||
|
||||
const dim3 blocks = Kernel::BlockSize();
|
||||
const dim3 grids = Kernel::GridSize(gemm_descs);
|
||||
|
||||
HIP_CHECK_ERROR(hipMemcpyWithStream(kargs_ptr,
|
||||
kargs.data(),
|
||||
get_workspace_size(gemm_descs),
|
||||
hipMemcpyHostToDevice,
|
||||
s.stream_id_));
|
||||
|
||||
if(s.log_level_ > 0)
|
||||
{
|
||||
std::cout << "Launching kernel: " << Kernel::GetName() << " with args:" << " grid: { "
|
||||
<< grids.x << ", " << grids.y << ", " << grids.z << "}" << ", blocks: {"
|
||||
<< blocks.x << ", " << blocks.y << ", " << blocks.z << "}" << std::endl;
|
||||
}
|
||||
|
||||
ave_time =
|
||||
ck_tile::launch_kernel(s,
|
||||
ck_tile::make_kernel<GemmConfig::kBlockPerCu>(
|
||||
Kernel{},
|
||||
grids,
|
||||
blocks,
|
||||
0,
|
||||
ck_tile::cast_pointer_to_constant_address_space(kargs_ptr),
|
||||
gemm_descs.size()));
|
||||
|
||||
return ave_time;
|
||||
};
|
||||
|
||||
const auto RunSplitk = [&](const auto has_hot_loop_, const auto tail_number_) {
|
||||
if(gemm_descs[0].k_batch == 1)
|
||||
{
|
||||
Run(has_hot_loop_,
|
||||
tail_number_,
|
||||
ck_tile::integral_constant<ck_tile::memory_operation_enum,
|
||||
ck_tile::memory_operation_enum::set>{});
|
||||
}
|
||||
else
|
||||
{
|
||||
Run(has_hot_loop_,
|
||||
tail_number_,
|
||||
ck_tile::integral_constant<ck_tile::memory_operation_enum,
|
||||
ck_tile::memory_operation_enum::atomic_add>{});
|
||||
}
|
||||
};
|
||||
|
||||
BaseGemmPipeline::TailHandler(RunSplitk, has_hot_loop, tail_num);
|
||||
|
||||
return ave_time;
|
||||
}
|
||||
|
||||
#include "run_grouped_gemm_multi_d_example.inc"
|
||||
|
||||
int main(int argc, char* argv[])
|
||||
{
|
||||
#if CK_TILE_USE_WMMA
|
||||
return !run_grouped_gemm_multi_d_example<GemmConfigV3_Wmma>(argc, argv);
|
||||
#else
|
||||
return !run_grouped_gemm_multi_d_example<GemmConfigV3>(argc, argv) ||
|
||||
!run_grouped_gemm_multi_d_example<GemmConfigMemory>(argc, argv) ||
|
||||
!run_grouped_gemm_multi_d_example<GemmConfigV4>(argc, argv);
|
||||
#endif
|
||||
}
|
||||
220
example/ck_tile/17_grouped_gemm/grouped_gemm_multi_d.hpp
Normal file
220
example/ck_tile/17_grouped_gemm/grouped_gemm_multi_d.hpp
Normal file
@@ -0,0 +1,220 @@
|
||||
// SPDX-License-Identifier: MIT
|
||||
// Copyright (c) 2024-2025, Advanced Micro Devices, Inc. All rights reserved.
|
||||
|
||||
#pragma once
|
||||
|
||||
#include <string>
|
||||
#include <tuple>
|
||||
|
||||
#include "ck_tile/core.hpp"
|
||||
#include "ck_tile/host/kernel_launch.hpp"
|
||||
#include "ck_tile/ops/gemm.hpp"
|
||||
#include "ck_tile/utility/json_dump.hpp"
|
||||
|
||||
#define CK_TILE_PIPELINE_COMPUTE_V3 1
|
||||
#define CK_TILE_PIPELINE_MEMORY 2
|
||||
#define CK_TILE_PIPELINE_COMPUTE_V4 3
|
||||
|
||||
using ADataType = ck_tile::half_t;
|
||||
using BDataType = ck_tile::half_t;
|
||||
using D0DataType = ck_tile::half_t;
|
||||
using D1DataType = ck_tile::half_t;
|
||||
using EDataType = ck_tile::half_t;
|
||||
using DsDataType = ck_tile::tuple<D0DataType, D1DataType>;
|
||||
using AccDataType = float;
|
||||
|
||||
template <typename PrecType, ck_tile::index_t M_Warp_Tile>
|
||||
constexpr ck_tile::index_t get_k_warp_tile()
|
||||
{
|
||||
#if defined(CK_GFX950_SUPPORT)
|
||||
constexpr bool is_8bit_float =
|
||||
std::is_same_v<PrecType, ck_tile::fp8_t> || std::is_same_v<PrecType, ck_tile::bf8_t>;
|
||||
if constexpr(M_Warp_Tile == 32)
|
||||
return is_8bit_float ? 64 : 16;
|
||||
else
|
||||
return is_8bit_float ? 128 : 32;
|
||||
#else
|
||||
if constexpr(M_Warp_Tile == 32)
|
||||
return 16;
|
||||
else
|
||||
return 32;
|
||||
#endif
|
||||
}
|
||||
|
||||
struct GemmConfigBase
|
||||
{
|
||||
static constexpr bool kPadM = false;
|
||||
static constexpr bool kPadN = false;
|
||||
static constexpr bool kPadK = false;
|
||||
|
||||
static constexpr bool TransposeC = false;
|
||||
|
||||
static constexpr int kBlockPerCu = 1;
|
||||
static constexpr ck_tile::index_t TileParitionerGroupNum = 8;
|
||||
static constexpr ck_tile::index_t TileParitionerM01 = 4;
|
||||
static constexpr auto Scheduler = ck_tile::GemmPipelineScheduler::Intrawave;
|
||||
static constexpr ck_tile::index_t Pipeline = CK_TILE_PIPELINE_COMPUTE_V3;
|
||||
static constexpr bool Preshuffle = false; // currently preshuffle == true is not supported yet
|
||||
static constexpr bool Persistent = false; // currently persistent == true is not supported yet
|
||||
static constexpr bool DoubleSmemBuffer =
|
||||
false; // currently double smem buffer == true is not supported yet
|
||||
};
|
||||
|
||||
struct GemmConfigMemory : public GemmConfigBase
|
||||
{
|
||||
// Memory friendly for Interwave scheduler
|
||||
static constexpr ck_tile::index_t M_Tile = 128;
|
||||
static constexpr ck_tile::index_t N_Tile = 32;
|
||||
static constexpr ck_tile::index_t K_Tile = 64;
|
||||
|
||||
static constexpr ck_tile::index_t M_Warp = 4;
|
||||
static constexpr ck_tile::index_t N_Warp = 1;
|
||||
static constexpr ck_tile::index_t K_Warp = 1;
|
||||
|
||||
static constexpr ck_tile::index_t M_Warp_Tile = 32;
|
||||
static constexpr ck_tile::index_t N_Warp_Tile = 32;
|
||||
static constexpr ck_tile::index_t K_Warp_Tile = 8;
|
||||
|
||||
static constexpr bool DoubleSmemBuffer = false;
|
||||
static constexpr ck_tile::index_t Pipeline = CK_TILE_PIPELINE_MEMORY;
|
||||
static constexpr auto Scheduler = ck_tile::GemmPipelineScheduler::Interwave;
|
||||
};
|
||||
|
||||
struct GemmConfigV3 : public GemmConfigBase
|
||||
{
|
||||
// Compute friendly for Intrawave scheduler
|
||||
static constexpr ck_tile::index_t M_Tile = 256;
|
||||
static constexpr ck_tile::index_t N_Tile = 256;
|
||||
static constexpr ck_tile::index_t K_Tile = 64;
|
||||
|
||||
static constexpr ck_tile::index_t M_Warp = 2;
|
||||
static constexpr ck_tile::index_t N_Warp = 2;
|
||||
static constexpr ck_tile::index_t K_Warp = 1;
|
||||
|
||||
static constexpr ck_tile::index_t M_Warp_Tile = 32;
|
||||
static constexpr ck_tile::index_t N_Warp_Tile = 32;
|
||||
static constexpr ck_tile::index_t K_Warp_Tile = 16;
|
||||
|
||||
static constexpr bool DoubleSmemBuffer = false;
|
||||
static constexpr ck_tile::index_t Pipeline = CK_TILE_PIPELINE_COMPUTE_V3;
|
||||
static constexpr auto Scheduler = ck_tile::GemmPipelineScheduler::Intrawave;
|
||||
};
|
||||
struct GemmConfigV4 : public GemmConfigBase
|
||||
{
|
||||
// Compute friendly for Intrawave scheduler
|
||||
// Using the ping pong reader in the lds level
|
||||
static constexpr ck_tile::index_t M_Tile = 256;
|
||||
static constexpr ck_tile::index_t N_Tile = 256;
|
||||
static constexpr ck_tile::index_t K_Tile = 32;
|
||||
|
||||
static constexpr ck_tile::index_t M_Warp = 2;
|
||||
static constexpr ck_tile::index_t N_Warp = 2;
|
||||
static constexpr ck_tile::index_t K_Warp = 1;
|
||||
|
||||
static constexpr ck_tile::index_t M_Warp_Tile = 32;
|
||||
static constexpr ck_tile::index_t N_Warp_Tile = 32;
|
||||
static constexpr ck_tile::index_t K_Warp_Tile = 16;
|
||||
|
||||
static constexpr bool DoubleSmemBuffer = true;
|
||||
static constexpr ck_tile::index_t Pipeline = CK_TILE_PIPELINE_COMPUTE_V4;
|
||||
static constexpr auto Scheduler = ck_tile::GemmPipelineScheduler::Intrawave;
|
||||
};
|
||||
|
||||
struct GemmConfigV3_Wmma : public GemmConfigBase
|
||||
{
|
||||
// Compute friendly for Intrawave scheduler
|
||||
static constexpr ck_tile::index_t M_Tile = 128;
|
||||
static constexpr ck_tile::index_t N_Tile = 128;
|
||||
static constexpr ck_tile::index_t K_Tile = 64;
|
||||
|
||||
static constexpr ck_tile::index_t M_Warp = 2;
|
||||
static constexpr ck_tile::index_t N_Warp = 2;
|
||||
static constexpr ck_tile::index_t K_Warp = 1;
|
||||
|
||||
static constexpr ck_tile::index_t M_Warp_Tile = 16;
|
||||
static constexpr ck_tile::index_t N_Warp_Tile = 16;
|
||||
static constexpr ck_tile::index_t K_Warp_Tile = 16;
|
||||
|
||||
static constexpr bool DoubleSmemBuffer = false;
|
||||
static constexpr ck_tile::index_t Pipeline = CK_TILE_PIPELINE_COMPUTE_V3;
|
||||
static constexpr auto Scheduler = ck_tile::GemmPipelineScheduler::Intrawave;
|
||||
};
|
||||
|
||||
template <ck_tile::index_t PipelineId>
|
||||
struct PipelineTypeTraits;
|
||||
|
||||
template <>
|
||||
struct PipelineTypeTraits<CK_TILE_PIPELINE_MEMORY>
|
||||
{
|
||||
template <typename PipelineProblem>
|
||||
using GemmPipeline = ck_tile::GemmPipelineAgBgCrMem<PipelineProblem>;
|
||||
template <typename PipelineProblem>
|
||||
using UniversalGemmPipeline = ck_tile::BaseGemmPipelineAgBgCrMem<PipelineProblem>;
|
||||
};
|
||||
|
||||
template <>
|
||||
struct PipelineTypeTraits<CK_TILE_PIPELINE_COMPUTE_V3>
|
||||
{
|
||||
template <typename PipelineProblem>
|
||||
using GemmPipeline = ck_tile::GemmPipelineAgBgCrCompV3<PipelineProblem>;
|
||||
template <typename PipelineProblem>
|
||||
using UniversalGemmPipeline = ck_tile::BaseGemmPipelineAgBgCrCompV3<PipelineProblem>;
|
||||
};
|
||||
|
||||
template <>
|
||||
struct PipelineTypeTraits<CK_TILE_PIPELINE_COMPUTE_V4>
|
||||
{
|
||||
template <typename PipelineProblem>
|
||||
using GemmPipeline = ck_tile::GemmPipelineAgBgCrCompV4<PipelineProblem>;
|
||||
template <typename PipelineProblem>
|
||||
using UniversalGemmPipeline = ck_tile::BaseGemmPipelineAgBgCrCompV4<PipelineProblem>;
|
||||
};
|
||||
|
||||
using grouped_gemm_multi_d_kargs = ck_tile::GroupedGemmHostArgs<2>;
|
||||
|
||||
std::pair<bool, ck_tile::ArgParser> create_args(int argc, char* argv[])
|
||||
{
|
||||
ck_tile::ArgParser arg_parser;
|
||||
arg_parser.insert("Ms", "", "M dimensions - empty by default.")
|
||||
.insert("Ns", "", "N dimensions - empty by default.")
|
||||
.insert("Ks", "", "K dimensions - empty by default.")
|
||||
.insert("stride_As", "", "Tensor A strides - it is empty by default.")
|
||||
.insert("stride_Bs", "", "Tensor B strides - it is empty by default.")
|
||||
.insert("stride_Ds", "", "Tensor Ds strides - it is empty by default.")
|
||||
.insert("stride_Es", "", "Tensor E strides - it is empty by default.")
|
||||
.insert("a_layout", "R", "A tensor data layout - Row by default.")
|
||||
.insert("b_layout", "C", "B tensor data layout - Row by default.")
|
||||
.insert("ds_layout", "R", "Ds tensor data layout - Row by default.")
|
||||
.insert("e_layout", "R", "E tensor data layout - Row by default.")
|
||||
.insert("validate", "1", "0. No validation, 1. Validation on CPU.")
|
||||
.insert("prec", "fp16", "data type. fp16")
|
||||
.insert("warmup", "10", "number of iterations before benchmark the kernel.")
|
||||
.insert("repeat", "100", "number of iterations to benchmark the kernel.")
|
||||
.insert("group_count", "8", "group count.")
|
||||
.insert("kbatch", "1", "kbatch for SplitK")
|
||||
.insert("json", "0", "0: No Json, 1: Dump Results in Json format")
|
||||
.insert("jsonfile", "grouped_gemm.json", "json file name to dump results");
|
||||
|
||||
bool result = arg_parser.parse(argc, argv);
|
||||
return std::make_pair(result, arg_parser);
|
||||
}
|
||||
|
||||
inline std::size_t get_workspace_size(const std::vector<grouped_gemm_multi_d_kargs>& gemm_descs)
|
||||
{
|
||||
return gemm_descs.size() * sizeof(ck_tile::GemmTransKernelArg<2>);
|
||||
}
|
||||
|
||||
template <typename GemmConfig,
|
||||
typename ADataType,
|
||||
typename BDataType,
|
||||
typename DsDataType,
|
||||
typename AccDataType,
|
||||
typename EDataType,
|
||||
typename ALayout,
|
||||
typename BLayout,
|
||||
typename DsLayout,
|
||||
typename ELayout,
|
||||
typename CDEElementWise>
|
||||
float grouped_gemm_multi_d(const std::vector<grouped_gemm_multi_d_kargs>& gemm_descs,
|
||||
const ck_tile::stream_config& s,
|
||||
void* kargs_ptr);
|
||||
@@ -88,7 +88,7 @@ float invoke_gemm(int n_warmup,
|
||||
// The contents of the memory pointed to by `kargs_ptr` pointer could be
|
||||
// written by e.g. another kernel from earlier stage.
|
||||
|
||||
std::vector<ck_tile::GemmTransKernelArg> kargs;
|
||||
std::vector<ck_tile::GemmTransKernelArg<>> kargs;
|
||||
void* kargs_ptr = gemm_workspace.GetDeviceBuffer();
|
||||
const bool splitk = args[0].k_batch > 1;
|
||||
for(const auto& arg : args)
|
||||
@@ -109,7 +109,7 @@ float invoke_gemm(int n_warmup,
|
||||
const auto stream = ck_tile::stream_config{nullptr, true, 1, n_warmup, n_repeat};
|
||||
HIP_CHECK_ERROR(hipMemcpyWithStream(kargs_ptr,
|
||||
kargs.data(),
|
||||
kargs.size() * sizeof(ck_tile::GemmTransKernelArg),
|
||||
kargs.size() * sizeof(ck_tile::GemmTransKernelArg<>),
|
||||
hipMemcpyHostToDevice,
|
||||
stream.stream_id_));
|
||||
ave_time = grouped_gemm_tileloop<GemmConfig,
|
||||
@@ -260,8 +260,18 @@ int run_grouped_gemm_example_with_layouts(int argc,
|
||||
const void* p_b = b_k_n_dev_buf[i]->GetDeviceBuffer();
|
||||
void* p_c = c_m_n_dev_buf[i]->GetDeviceBuffer();
|
||||
|
||||
gemm_descs.push_back(
|
||||
{p_a, p_b, p_c, kbatch, M, N, K, stride_As[i], stride_Bs[i], stride_Cs[i]});
|
||||
gemm_descs.push_back({p_a,
|
||||
p_b,
|
||||
{/*ds_ptr*/},
|
||||
p_c,
|
||||
kbatch,
|
||||
M,
|
||||
N,
|
||||
K,
|
||||
stride_As[i],
|
||||
stride_Bs[i],
|
||||
{/*stride_Ds*/},
|
||||
stride_Cs[i]});
|
||||
}
|
||||
|
||||
float ave_time = invoke_gemm<GemmConfig,
|
||||
|
||||
@@ -0,0 +1,389 @@
|
||||
#pragma once
|
||||
|
||||
struct MultiplyMultiply
|
||||
{
|
||||
template <typename E, typename C, typename D0, typename D1>
|
||||
CK_TILE_HOST_DEVICE auto operator()(E& e, const C& c, const D0& d0, const D1& d1) const -> void
|
||||
{
|
||||
const float x0_f = ck_tile::type_convert<float>(c) * ck_tile::type_convert<float>(d0) *
|
||||
ck_tile::type_convert<float>(d1);
|
||||
|
||||
e = ck_tile::type_convert<E>(x0_f);
|
||||
}
|
||||
};
|
||||
|
||||
template <typename Layout>
|
||||
static constexpr inline auto is_row_major(Layout layout_)
|
||||
{
|
||||
return ck_tile::bool_constant<std::is_same_v<ck_tile::remove_cvref_t<decltype(layout_)>,
|
||||
ck_tile::tensor_layout::gemm::RowMajor>>{};
|
||||
}
|
||||
|
||||
auto calculate_rtol_atol(const ck_tile::index_t K,
|
||||
const ck_tile::index_t kbatch,
|
||||
const float max_accumulated_value)
|
||||
{
|
||||
using ComputeTypeAB =
|
||||
std::conditional_t<sizeof(ADataType) < sizeof(BDataType), ADataType, BDataType>;
|
||||
|
||||
using ComputeType =
|
||||
std::conditional_t<sizeof(ComputeTypeAB) < sizeof(D0DataType), ComputeTypeAB, D0DataType>;
|
||||
// Calculate thresholds
|
||||
const auto rtol = ck_tile::get_relative_threshold<ComputeType, EDataType, AccDataType>(
|
||||
ck_tile::integer_divide_ceil(K, kbatch));
|
||||
|
||||
const auto atol = ck_tile::get_absolute_threshold<ComputeType, EDataType, AccDataType>(
|
||||
max_accumulated_value / kbatch, ck_tile::integer_divide_ceil(K, kbatch));
|
||||
|
||||
// Calculate error due to split_k accumulation
|
||||
const auto rtol_split_k =
|
||||
ck_tile::get_relative_threshold<EDataType, EDataType, EDataType>(kbatch);
|
||||
|
||||
const auto atol_split_k = ck_tile::get_absolute_threshold<EDataType, EDataType, EDataType>(
|
||||
max_accumulated_value, kbatch);
|
||||
|
||||
// Use higher threshold
|
||||
return ck_tile::make_tuple(std::max(rtol, rtol_split_k), std::max(atol, atol_split_k));
|
||||
}
|
||||
|
||||
template <typename GemmConfig,
|
||||
typename ADataType,
|
||||
typename BDataType,
|
||||
typename DsDataType,
|
||||
typename AccDataType,
|
||||
typename EDataType,
|
||||
typename ALayout,
|
||||
typename BLayout,
|
||||
typename DsLayout,
|
||||
typename ELayout,
|
||||
typename CDEElementWise>
|
||||
float invoke_gemm(int n_warmup,
|
||||
int n_repeat,
|
||||
int group_count,
|
||||
const std::vector<grouped_gemm_multi_d_kargs>& args)
|
||||
{
|
||||
// Workspace memory allocated to hold the gemm descriptions.
|
||||
ck_tile::DeviceMem gemm_workspace;
|
||||
gemm_workspace.Realloc(get_workspace_size(args));
|
||||
|
||||
float ave_time = 0;
|
||||
if constexpr(!GemmConfig::Persistent)
|
||||
{
|
||||
ave_time = grouped_gemm_multi_d<GemmConfig,
|
||||
ADataType,
|
||||
BDataType,
|
||||
DsDataType,
|
||||
AccDataType,
|
||||
EDataType,
|
||||
ALayout,
|
||||
BLayout,
|
||||
DsLayout,
|
||||
ELayout,
|
||||
CDEElementWise>(
|
||||
args,
|
||||
ck_tile::stream_config{nullptr, true, 1, n_warmup, n_repeat},
|
||||
gemm_workspace.GetDeviceBuffer());
|
||||
}
|
||||
else
|
||||
{
|
||||
(void)group_count;
|
||||
// not supported yet
|
||||
throw std::runtime_error("Persistent grouped gemm multiple-d is not supported yet");
|
||||
}
|
||||
return ave_time;
|
||||
}
|
||||
|
||||
template <typename GemmConfig,
|
||||
typename ALayout,
|
||||
typename BLayout,
|
||||
typename D0Layout,
|
||||
typename D1Layout,
|
||||
typename ELayout>
|
||||
int run_grouped_gemm_multi_d_example_with_layouts(int argc,
|
||||
char* argv[],
|
||||
const ALayout a_layout = ALayout{},
|
||||
const BLayout b_layout = BLayout{},
|
||||
const D0Layout d0_layout = D0Layout{},
|
||||
const D1Layout d1_layout = D1Layout{},
|
||||
const ELayout e_layout = ELayout{})
|
||||
{
|
||||
auto [result, arg_parser] = create_args(argc, argv);
|
||||
|
||||
using CDElementWise = MultiplyMultiply;
|
||||
using DsLayout = ck_tile::tuple<D0Layout, D1Layout>;
|
||||
|
||||
auto valid_input_data = [&](int group_count, const auto&... args) {
|
||||
return !(args.empty() || ...) && group_count == (args.size() == ...);
|
||||
};
|
||||
|
||||
const int group_count = arg_parser.get_int("group_count");
|
||||
const int repeat = arg_parser.get_int("repeat");
|
||||
const int warmup = arg_parser.get_int("warmup");
|
||||
const int kbatch = arg_parser.get_int("kbatch");
|
||||
bool validate = arg_parser.get_bool("validate");
|
||||
|
||||
if(kbatch > 1 && validate && warmup + repeat > 1)
|
||||
{
|
||||
std::cout << "WARNING: Data validation enabled with SplitK and more than"
|
||||
<< "1 warmup/repeat. Disabling validation." << std::endl;
|
||||
validate = false;
|
||||
}
|
||||
|
||||
std::vector<ck_tile::index_t> Ms = arg_parser.get_int_vec("Ms");
|
||||
std::vector<ck_tile::index_t> Ns = arg_parser.get_int_vec("Ns");
|
||||
std::vector<ck_tile::index_t> Ks = arg_parser.get_int_vec("Ks");
|
||||
std::vector<ck_tile::index_t> stride_As = arg_parser.get_int_vec("stride_As");
|
||||
std::vector<ck_tile::index_t> stride_Bs = arg_parser.get_int_vec("stride_Bs");
|
||||
std::vector<ck_tile::index_t> stride_D0 = arg_parser.get_int_vec("stride_Ds");
|
||||
std::vector<ck_tile::index_t> stride_D1 = arg_parser.get_int_vec("stride_Ds");
|
||||
std::vector<ck_tile::index_t> stride_Es = arg_parser.get_int_vec("stride_Es");
|
||||
|
||||
if(!valid_input_data(
|
||||
group_count, Ms, Ns, Ks, stride_As, stride_Bs, stride_D0, stride_D1, stride_Es))
|
||||
{
|
||||
std::cout << "Please check the input data. Default values will be used." << std::endl;
|
||||
std::cout << "Default values: Ms (256, 512, 768, 1024..), Ns (256, 768, 1280..), Ks (512, "
|
||||
"896, 1280..), stride_As (Ks), stride_Bs (Ks), stride_D0 (Ns), stride_D1 "
|
||||
"(Ns), stride_Es (Ns)"
|
||||
<< std::endl;
|
||||
for(int i = 0; i < group_count; i++)
|
||||
{
|
||||
Ms.push_back(256 /* + 256 * i */);
|
||||
Ns.push_back(256 /* + 512 * i */);
|
||||
Ks.push_back(64 /* + 384 * i */);
|
||||
|
||||
stride_As.push_back(Ks[i]);
|
||||
stride_Bs.push_back(Ks[i]);
|
||||
stride_D0.push_back(Ns[i]);
|
||||
stride_D1.push_back(Ns[i]);
|
||||
stride_Es.push_back(Ns[i]);
|
||||
}
|
||||
}
|
||||
|
||||
std::vector<ck_tile::HostTensor<ADataType>> a_m_k_tensors;
|
||||
std::vector<ck_tile::HostTensor<BDataType>> b_k_n_tensors;
|
||||
std::vector<ck_tile::HostTensor<D0DataType>> d0_m_n_tensors;
|
||||
std::vector<ck_tile::HostTensor<D1DataType>> d1_m_n_tensors;
|
||||
std::vector<ck_tile::HostTensor<EDataType>> e_m_n_tensors;
|
||||
|
||||
a_m_k_tensors.reserve(group_count);
|
||||
b_k_n_tensors.reserve(group_count);
|
||||
d0_m_n_tensors.reserve(group_count);
|
||||
d1_m_n_tensors.reserve(group_count);
|
||||
e_m_n_tensors.reserve(group_count);
|
||||
|
||||
std::vector<std::unique_ptr<ck_tile::DeviceMem>> a_m_k_dev_buf;
|
||||
std::vector<std::unique_ptr<ck_tile::DeviceMem>> b_k_n_dev_buf;
|
||||
std::vector<std::unique_ptr<ck_tile::DeviceMem>> d0_m_n_dev_buf;
|
||||
std::vector<std::unique_ptr<ck_tile::DeviceMem>> d1_m_n_dev_buf;
|
||||
std::vector<std::unique_ptr<ck_tile::DeviceMem>> e_m_n_dev_buf;
|
||||
|
||||
a_m_k_dev_buf.reserve(group_count);
|
||||
b_k_n_dev_buf.reserve(group_count);
|
||||
d0_m_n_dev_buf.reserve(group_count);
|
||||
d1_m_n_dev_buf.reserve(group_count);
|
||||
e_m_n_dev_buf.reserve(group_count);
|
||||
|
||||
std::vector<grouped_gemm_multi_d_kargs> gemm_descs;
|
||||
gemm_descs.reserve(group_count);
|
||||
|
||||
for(int i = 0; i < group_count; ++i)
|
||||
{
|
||||
|
||||
const ck_tile::index_t M = Ms[i];
|
||||
const ck_tile::index_t N = Ns[i];
|
||||
const ck_tile::index_t K = Ks[i];
|
||||
|
||||
stride_As[i] = ck_tile::get_default_stride(M, K, stride_As[i], is_row_major(a_layout));
|
||||
stride_Bs[i] = ck_tile::get_default_stride(K, N, stride_Bs[i], is_row_major(b_layout));
|
||||
|
||||
stride_D0[i] = ck_tile::get_default_stride(M, N, stride_D0[i], is_row_major(d0_layout));
|
||||
stride_D1[i] = ck_tile::get_default_stride(M, N, stride_D1[i], is_row_major(d1_layout));
|
||||
|
||||
stride_Es[i] = ck_tile::get_default_stride(M, N, stride_Es[i], is_row_major(e_layout));
|
||||
|
||||
a_m_k_tensors.push_back(ck_tile::HostTensor<ADataType>(
|
||||
ck_tile::host_tensor_descriptor(M, K, stride_As[i], is_row_major(a_layout))));
|
||||
b_k_n_tensors.push_back(ck_tile::HostTensor<BDataType>(
|
||||
ck_tile::host_tensor_descriptor(K, N, stride_Bs[i], is_row_major(b_layout))));
|
||||
|
||||
d0_m_n_tensors.push_back(ck_tile::HostTensor<D0DataType>(
|
||||
ck_tile::host_tensor_descriptor(M, N, stride_D0[i], is_row_major(d0_layout))));
|
||||
d1_m_n_tensors.push_back(ck_tile::HostTensor<D1DataType>(
|
||||
ck_tile::host_tensor_descriptor(M, N, stride_D1[i], is_row_major(d1_layout))));
|
||||
|
||||
e_m_n_tensors.push_back(ck_tile::HostTensor<EDataType>(
|
||||
ck_tile::host_tensor_descriptor(M, N, stride_Es[i], is_row_major(e_layout))));
|
||||
|
||||
std::cout << "gemm[" << i << "]" << " a_m_k: " << a_m_k_tensors[i].mDesc
|
||||
<< " b_k_n: " << b_k_n_tensors[i].mDesc << " d0_m_n: " << d0_m_n_tensors[i].mDesc
|
||||
<< " d1_m_n: " << d1_m_n_tensors[i].mDesc << " e_m_n: " << e_m_n_tensors[i].mDesc
|
||||
<< std::endl;
|
||||
|
||||
ck_tile::FillUniformDistribution<ADataType>{-1.f, 1.f}(a_m_k_tensors[i]);
|
||||
ck_tile::FillUniformDistribution<BDataType>{-1.f, 1.f}(b_k_n_tensors[i]);
|
||||
ck_tile::FillUniformDistribution<D0DataType>{2.f, -2.f}(d0_m_n_tensors[i]);
|
||||
ck_tile::FillUniformDistribution<D1DataType>{2.f, -2.f}(d1_m_n_tensors[i]);
|
||||
|
||||
a_m_k_dev_buf.push_back(std::make_unique<ck_tile::DeviceMem>(a_m_k_tensors[i]));
|
||||
|
||||
b_k_n_dev_buf.push_back(std::make_unique<ck_tile::DeviceMem>(b_k_n_tensors[i]));
|
||||
|
||||
d0_m_n_dev_buf.push_back(std::make_unique<ck_tile::DeviceMem>(d0_m_n_tensors[i]));
|
||||
d1_m_n_dev_buf.push_back(std::make_unique<ck_tile::DeviceMem>(d1_m_n_tensors[i]));
|
||||
e_m_n_dev_buf.push_back(std::make_unique<ck_tile::DeviceMem>(e_m_n_tensors[i]));
|
||||
|
||||
e_m_n_dev_buf[i]->SetZero();
|
||||
|
||||
const void* p_a = a_m_k_dev_buf[i]->GetDeviceBuffer();
|
||||
const void* p_b = b_k_n_dev_buf[i]->GetDeviceBuffer();
|
||||
void* p_e = e_m_n_dev_buf[i]->GetDeviceBuffer();
|
||||
|
||||
std::array<const void*, DsDataType::size()> ds_ptr_buf = {
|
||||
d0_m_n_dev_buf[i]->GetDeviceBuffer(), d1_m_n_dev_buf[i]->GetDeviceBuffer()};
|
||||
std::array<ck_tile::index_t, DsDataType::size()> stridesDs = {stride_D0[i], stride_D1[i]};
|
||||
|
||||
gemm_descs.push_back({p_a,
|
||||
p_b,
|
||||
ds_ptr_buf,
|
||||
p_e,
|
||||
kbatch,
|
||||
M,
|
||||
N,
|
||||
K,
|
||||
stride_As[i],
|
||||
stride_Bs[i],
|
||||
stridesDs,
|
||||
stride_Es[i]});
|
||||
}
|
||||
|
||||
float ave_time = invoke_gemm<GemmConfig,
|
||||
ADataType,
|
||||
BDataType,
|
||||
DsDataType,
|
||||
AccDataType,
|
||||
EDataType,
|
||||
ALayout,
|
||||
BLayout,
|
||||
DsLayout,
|
||||
ELayout,
|
||||
CDElementWise>(warmup, repeat, group_count, gemm_descs);
|
||||
|
||||
std::string op_name{"Grouped Gemm Multiple-D"};
|
||||
|
||||
std::size_t flop = 0, num_btype = 0;
|
||||
for(int j = 0; j < group_count; ++j)
|
||||
{
|
||||
flop += std::size_t(2) * gemm_descs[j].M * gemm_descs[j].N * gemm_descs[j].K;
|
||||
ck_tile::static_for<0, DsDataType::size(), 1>{}([&](auto i) {
|
||||
num_btype += sizeof(ck_tile::remove_cvref_t<std::tuple_element_t<i, DsDataType>>) *
|
||||
gemm_descs[j].M * gemm_descs[j].N;
|
||||
flop += sizeof(ck_tile::remove_cvref_t<std::tuple_element_t<i, DsDataType>>) *
|
||||
gemm_descs[j].M * gemm_descs[j].N;
|
||||
});
|
||||
|
||||
num_btype += sizeof(ADataType) * gemm_descs[j].M * gemm_descs[j].K +
|
||||
sizeof(BDataType) * gemm_descs[j].K * gemm_descs[j].N +
|
||||
sizeof(EDataType) * gemm_descs[j].M * gemm_descs[j].N;
|
||||
}
|
||||
|
||||
float tflops = static_cast<float>(flop) / 1.E9 / ave_time;
|
||||
float gb_per_sec = num_btype / 1.E6 / ave_time;
|
||||
|
||||
std::cout << "Perf: " << std::setw(10) << ave_time << " ms, " << tflops << " TFlops, "
|
||||
<< gb_per_sec << " GB/s, " << op_name << std::endl;
|
||||
|
||||
std::vector<ck_tile::HostTensor<EDataType>> e_m_n_host_refs;
|
||||
e_m_n_host_refs.reserve(group_count);
|
||||
|
||||
// copy e_m_n_tensors result from device to host and initialize host tensors to zero
|
||||
for(int i = 0; i < group_count; i++)
|
||||
{
|
||||
e_m_n_dev_buf[i]->FromDevice(e_m_n_tensors[i].data());
|
||||
}
|
||||
|
||||
bool pass{true};
|
||||
if(validate)
|
||||
{
|
||||
for(int i = 0; i < group_count; ++i)
|
||||
{
|
||||
e_m_n_host_refs.push_back(ck_tile::HostTensor<EDataType>(
|
||||
host_tensor_descriptor(Ms[i], Ns[i], stride_Es[i], is_row_major(e_layout))));
|
||||
|
||||
e_m_n_host_refs[i].SetZero();
|
||||
|
||||
ck_tile::reference_gemm_multiple_d<ADataType,
|
||||
BDataType,
|
||||
DsDataType,
|
||||
AccDataType,
|
||||
EDataType,
|
||||
CDElementWise>(
|
||||
a_m_k_tensors[i],
|
||||
b_k_n_tensors[i],
|
||||
{d0_m_n_tensors[i], d1_m_n_tensors[i]},
|
||||
e_m_n_host_refs[i]);
|
||||
std::cout << "e_m_n_host_refs[i]: " << std::endl;
|
||||
e_m_n_host_refs[i].print_first_n(std::cout, 10);
|
||||
std::cout << std::endl;
|
||||
std::cout << "e_m_n_tensors[i]: " << std::endl;
|
||||
e_m_n_tensors[i].print_first_n(std::cout, 10);
|
||||
std::cout << std::endl;
|
||||
|
||||
const float max_accumulated_value =
|
||||
*std::max_element(e_m_n_host_refs[i].mData.begin(), e_m_n_host_refs[i].mData.end());
|
||||
|
||||
const auto rtol_atol = calculate_rtol_atol(Ks[i], 1, max_accumulated_value);
|
||||
|
||||
pass &=
|
||||
ck_tile::check_err(e_m_n_tensors[i],
|
||||
e_m_n_host_refs[i],
|
||||
"Error: Incorrect results! in group [" + std::to_string(i) + "]",
|
||||
rtol_atol.at(ck_tile::number<0>{}),
|
||||
rtol_atol.at(ck_tile::number<1>{}));
|
||||
|
||||
std::cout << "Relative error threshold: " << rtol_atol.at(ck_tile::number<0>{})
|
||||
<< " Absolute error threshold: " << rtol_atol.at(ck_tile::number<1>{})
|
||||
<< std::endl;
|
||||
}
|
||||
std::cout << "The CPU verification result is: " << (pass ? "correct" : "fail") << std::endl;
|
||||
}
|
||||
|
||||
if(arg_parser.get_int("json") == 1)
|
||||
{
|
||||
dump_grouped_gemm_json_results<ALayout, BLayout, ELayout>(arg_parser.get_str("jsonfile"),
|
||||
op_name,
|
||||
group_count,
|
||||
pass,
|
||||
ave_time,
|
||||
tflops,
|
||||
gb_per_sec);
|
||||
}
|
||||
|
||||
return pass;
|
||||
}
|
||||
|
||||
template <typename GemmConfig>
|
||||
int run_grouped_gemm_multi_d_example(int argc, char* argv[])
|
||||
{
|
||||
auto [result, arg_parser] = create_args(argc, argv);
|
||||
if(!result)
|
||||
{
|
||||
return -1;
|
||||
}
|
||||
const std::string a_layout = arg_parser.get_str("a_layout");
|
||||
const std::string b_layout = arg_parser.get_str("b_layout");
|
||||
const std::string ds_layout = arg_parser.get_str("ds_layout");
|
||||
|
||||
using Row = ck_tile::tensor_layout::gemm::RowMajor;
|
||||
using Col = ck_tile::tensor_layout::gemm::ColumnMajor;
|
||||
|
||||
if(a_layout == "R" && b_layout == "C" && ds_layout == "R")
|
||||
{
|
||||
return run_grouped_gemm_multi_d_example_with_layouts<GemmConfig>(
|
||||
argc, argv, Row{}, Col{}, Row{}, Row{}, Row{});
|
||||
}
|
||||
else
|
||||
{
|
||||
throw std::runtime_error("Unsupported data layout configuration for provided tensors!");
|
||||
}
|
||||
}
|
||||
Reference in New Issue
Block a user