[CK_TILE] Multiple-D GEMM example (#2219)

* Multiple d, initial commit

* Check Ds Layout

* Readme and clang format

* Update branch & conflicts

* Multiple D - fix clang-formatter

* Rename elemetwise_op

* Fix CI

* Code review part1

* Remove printf

* Remove unnecessary comment

* Add new tests with Col layout

* Review part 2

* Added support for Multiple D GEMM

* Update comment

* Remove maybe_unused

* Clang-format

* Review part 3

* Add comment to function

* Add comment to function: another

* Take number of params for a reference function

* Remove additional d param for 0 tensor

* Change name of function

* Fix CI fails

[ROCm/composable_kernel commit: bd96ac9742]
This commit is contained in:
Mateusz Ozga
2025-06-13 19:39:11 +02:00
committed by GitHub
parent ea36ae016e
commit 044a8560f7
34 changed files with 2267 additions and 285 deletions

View File

@@ -14,13 +14,17 @@
template <typename ADataType,
typename BDataType,
typename DsDataType,
typename AccDataType,
typename CDataType,
typename ALayout,
typename BLayout,
typename DsLayout,
typename CLayout,
bool Persistent>
float gemm_calc(const ck_tile::GemmHostArgs& args, const ck_tile::stream_config& s)
bool Persistent,
typename CDEElementWise = ck_tile::element_wise::PassThrough>
float gemm(const ck_tile::GemmHostArgs</*NumDTensor = 0*/>& args, const ck_tile::stream_config& s)
{
if constexpr(Persistent)
std::cout << "WARNING: Ignoring persistent kernel option for basic gemm." << std::endl;
@@ -53,8 +57,10 @@ float gemm_calc(const ck_tile::GemmHostArgs& args, const ck_tile::stream_config&
using CodegenGemmTraits =
ck_tile::TileGemmTraits<kPadM, kPadN, kPadK, ALayout, BLayout, CLayout>;
using CodegenPipelineProblem = ck_tile::
GemmPipelineProblem<ADataType, BDataType, AccDataType, CodegenGemmShape, CodegenGemmTraits>;
using CodegenGemmPipeline = ck_tile::GemmPipelineAGmemBGmemCRegV1<CodegenPipelineProblem>;
const auto Run = [&](const auto memory_operation_) {

View File

@@ -252,10 +252,13 @@ auto create_args(int argc, char* argv[])
// host API
template <typename ADataType,
typename BDataType,
typename DsDataType,
typename AccDataType,
typename CDataType,
typename ALayout,
typename BLayout,
typename DsLayout,
typename CLayout,
bool Persistent = false>
float gemm_calc(const ck_tile::GemmHostArgs& args, const ck_tile::stream_config& s);
bool Persistent = false,
typename CDEElementWise>
float gemm(const ck_tile::GemmHostArgs</*NumDTensor = 0*/>& args, const ck_tile::stream_config& s);

View File

@@ -146,11 +146,14 @@ void permute_vectors_i4x4_b(Tensor& tensor)
template <typename ADataType,
typename BDataType,
typename DsDataType,
typename AccDataType,
typename CDataType,
typename ALayout,
typename BLayout,
typename CLayout>
typename DsLayout,
typename CLayout,
typename CDEElementWise = ck_tile::element_wise::PassThrough>
float invoke_gemm(ck_tile::DeviceMem& a_m_k_dev_buf,
ck_tile::DeviceMem& b_k_n_dev_buf,
ck_tile::DeviceMem& c_m_n_dev_buf,
@@ -165,41 +168,48 @@ float invoke_gemm(ck_tile::DeviceMem& a_m_k_dev_buf,
int n_repeat,
bool persistent)
{
ck_tile::GemmHostArgs args;
args.a_ptr = a_m_k_dev_buf.GetDeviceBuffer();
args.b_ptr = b_k_n_dev_buf.GetDeviceBuffer();
args.c_ptr = c_m_n_dev_buf.GetDeviceBuffer();
args.k_batch = kbatch;
args.M = M;
args.N = N;
args.K = K;
args.stride_A = stride_A;
args.stride_B = stride_B;
args.stride_C = stride_C;
ck_tile::GemmHostArgs</*NumDTensor = 0*/> args = {a_m_k_dev_buf.GetDeviceBuffer(),
b_k_n_dev_buf.GetDeviceBuffer(),
{},
c_m_n_dev_buf.GetDeviceBuffer(),
kbatch,
M,
N,
K,
stride_A,
stride_B,
{},
stride_C};
float ave_time;
if(persistent)
{
ave_time = gemm_calc<ADataType,
BDataType,
AccDataType,
CDataType,
ALayout,
BLayout,
CLayout,
true>(
ave_time = gemm<ADataType,
BDataType,
DsDataType,
AccDataType,
CDataType,
ALayout,
BLayout,
DsLayout,
CLayout,
true,
CDEElementWise>(
args, ck_tile::stream_config{nullptr, true, 1, n_warmup, n_repeat, true, true, 50});
}
else
{
ave_time = gemm_calc<ADataType,
BDataType,
AccDataType,
CDataType,
ALayout,
BLayout,
CLayout,
false>(
ave_time = gemm<ADataType,
BDataType,
DsDataType,
AccDataType,
CDataType,
ALayout,
BLayout,
DsLayout,
CLayout,
false,
CDEElementWise>(
args, ck_tile::stream_config{nullptr, true, 1, n_warmup, n_repeat, true, true, 50});
}
@@ -328,20 +338,27 @@ int run_gemm_example_with_layouts(int argc,
c_m_n_dev_buf.SetZero();
c_m_n_dev_result.SetZero();
invoke_gemm<ADataType, BDataType, AccDataType, CDataType, ALayout, BLayout, CLayout>(
a_m_k_dev_buf,
b_k_n_dev_buf,
c_m_n_dev_buf,
M,
N,
K,
stride_A,
stride_B,
stride_C,
kbatch,
n_warmup,
n_repeat,
persistent);
invoke_gemm<ADataType,
BDataType,
ck_tile::tuple<>,
AccDataType,
CDataType,
ALayout,
BLayout,
ck_tile::tuple<>,
CLayout>(a_m_k_dev_buf,
b_k_n_dev_buf,
c_m_n_dev_buf,
M,
N,
K,
stride_A,
stride_B,
stride_C,
kbatch,
n_warmup,
n_repeat,
persistent);
c_m_n_dev_buf.FromDevice(c_m_n_dev_result.data());
bool pass = true;

View File

@@ -15,13 +15,17 @@
template <typename ADataType,
typename BDataType,
typename DsDataType,
typename AccDataType,
typename CDataType,
typename ALayout,
typename BLayout,
typename CLayout,
bool Persistent>
float gemm_calc(const ck_tile::GemmHostArgs& args, const ck_tile::stream_config& s)
typename DsLayout,
typename ELayout,
bool Persistent,
typename CDEElementWise = ck_tile::element_wise::PassThrough>
float gemm(const ck_tile::GemmHostArgs</*NumDTensor = 0*/>& args, const ck_tile::stream_config& s)
{
using GemmShape = ck_tile::TileGemmShape<
ck_tile::sequence<GemmConfig::M_Tile, GemmConfig::N_Tile, GemmConfig::K_Tile>,
@@ -30,24 +34,26 @@ float gemm_calc(const ck_tile::GemmHostArgs& args, const ck_tile::stream_config&
sequence<GemmConfig::M_Warp_Tile, GemmConfig::N_Warp_Tile, GemmConfig::K_Warp_Tile>,
GemmConfig::PermuteA,
GemmConfig::PermuteB>;
using TilePartitioner =
ck_tile::GemmSpatiallyLocalTilePartitioner<GemmShape,
GemmConfig::TileParitionerGroupNum,
GemmConfig::TileParitionerM01>;
using Traits = ck_tile::TileGemmTraits<GemmConfig::kPadM,
using Traits = ck_tile::TileGemmTraits<GemmConfig::kPadM,
GemmConfig::kPadN,
GemmConfig::kPadK,
ALayout,
BLayout,
CLayout>;
ELayout>;
using GemmUniversalTraits = ck_tile::TileGemmUniversalTraits<GemmConfig::kPadM,
GemmConfig::kPadN,
GemmConfig::kPadK,
GemmConfig::DoubleSmemBuffer,
ALayout,
BLayout,
CLayout,
ELayout,
GemmConfig::TransposeC,
GemmConfig::UseStructuredSparsity,
Persistent,
@@ -85,9 +91,12 @@ float gemm_calc(const ck_tile::GemmHostArgs& args, const ck_tile::stream_config&
using GemmEpilogue = ck_tile::CShuffleEpilogue<
ck_tile::CShuffleEpilogueProblem<ADataType,
BDataType,
DsDataType,
AccDataType,
CDataType,
CLayout,
DsLayout,
ELayout,
CDEElementWise,
GemmPipelineProblem::kBlockSize,
TilePartitioner::MPerBlock,
TilePartitioner::NPerBlock,

View File

@@ -15,7 +15,16 @@
#include "ck_tile/host.hpp"
#include "batched_gemm.hpp"
template <typename ALayout, typename BLayout, typename CLayout>
template <typename ADataType,
typename BDataType,
typename DsDataType,
typename AccDataType,
typename CDataType,
typename ALayout,
typename BLayout,
typename DsLayout,
typename CLayout,
typename CDEElementWise = ck_tile::element_wise::PassThrough>
float batched_gemm(const ck_tile::BatchedGemmHostArgs& args, const ck_tile::stream_config& s)
{
#if(CK_TILE_PIPELINE_DEFAULT == CK_TILE_PIPELINE_MEMORY)
@@ -123,12 +132,16 @@ float batched_gemm(const ck_tile::BatchedGemmHostArgs& args, const ck_tile::stre
tail_number_v>;
using GemmPipeline = GEMM_PIPELINE<UniversalGemmProblem>;
using GemmEpilogue = ck_tile::CShuffleEpilogue<
ck_tile::CShuffleEpilogueProblem<ADataType,
BDataType,
DsDataType,
AccDataType,
CDataType,
DsLayout,
CLayout,
CDEElementWise,
GemmPipelineProblem::kBlockSize,
TilePartitioner::MPerBlock,
TilePartitioner::NPerBlock,
@@ -139,6 +152,7 @@ float batched_gemm(const ck_tile::BatchedGemmHostArgs& args, const ck_tile::stre
K_Warp_Tile,
UniversalGemmProblem::TransposeC,
memory_operation>>;
using Kernel = ck_tile::BatchedGemmKernel<TilePartitioner, GemmPipeline, GemmEpilogue>;
auto kargs = Kernel::MakeKernelArgs(args);

View File

@@ -8,6 +8,7 @@
#include "ck_tile/core.hpp"
#include "ck_tile/host/kernel_launch.hpp"
#include "ck_tile/ops/gemm/kernel/batched_gemm_kernel.hpp"
#include "ck_tile/ops/elementwise/unary_element_wise_operation.hpp"
#define CK_TILE_PIPELINE_COMPUTE_V3 1
#define CK_TILE_PIPELINE_MEMORY 2

View File

@@ -23,7 +23,16 @@ auto calculate_rtol_atol(const ck_tile::index_t K,
return ck_tile::make_tuple(std::max(rtol, rtol_split_k), std::max(atol, atol_split_k));
}
template <typename ALayout, typename BLayout, typename CLayout>
template <typename ADataType,
typename BDataType,
typename DsDataType,
typename AccDataType,
typename CDataType,
typename ALayout,
typename BLayout,
typename DsLayout,
typename CLayout,
typename CDEElementWise = ck_tile::element_wise::PassThrough>
float invoke_batched_gemm(ck_tile::DeviceMem& a_m_k_dev_buf,
ck_tile::DeviceMem& b_k_n_dev_buf,
ck_tile::DeviceMem& c_m_n_dev_buf,
@@ -44,20 +53,29 @@ float invoke_batched_gemm(ck_tile::DeviceMem& a_m_k_dev_buf,
ck_tile::BatchedGemmHostArgs args;
args.a_ptr = a_m_k_dev_buf.GetDeviceBuffer();
args.b_ptr = b_k_n_dev_buf.GetDeviceBuffer();
args.c_ptr = c_m_n_dev_buf.GetDeviceBuffer();
args.e_ptr = c_m_n_dev_buf.GetDeviceBuffer();
args.k_batch = kbatch;
args.M = M;
args.N = N;
args.K = K;
args.stride_A = stride_A;
args.stride_B = stride_B;
args.stride_C = stride_C;
args.stride_E = stride_C;
args.batch_stride_A = batch_stride_A;
args.batch_stride_B = batch_stride_B;
args.batch_stride_C = batch_stride_C;
args.batch_stride_E = batch_stride_C;
args.batch_count = batch_count;
float ave_time = batched_gemm<ALayout, BLayout, CLayout>(
float ave_time = batched_gemm<ADataType,
BDataType,
DsDataType,
AccDataType,
CDataType,
ALayout,
BLayout,
DsLayout,
CLayout,
CDEElementWise>(
args, ck_tile::stream_config{nullptr, true, 1, n_warmup, n_repeat});
std::string op_name{"Batched Gemm"};
@@ -169,22 +187,30 @@ int run_batched_gemm_example_with_layouts(int argc,
c_m_n_dev_buf.SetZero();
c_m_n_dev_result.SetZero();
invoke_batched_gemm<ALayout, BLayout, CLayout>(a_m_k_dev_buf,
b_k_n_dev_buf,
c_m_n_dev_buf,
M,
N,
K,
stride_A,
stride_B,
stride_C,
batch_stride_A,
batch_stride_B,
batch_stride_C,
batch_count,
kbatch,
n_warmup,
n_repeat);
invoke_batched_gemm<ADataType,
BDataType,
ck_tile::tuple<>,
AccDataType,
CDataType,
ALayout,
BLayout,
ck_tile::tuple<>,
CLayout>(a_m_k_dev_buf,
b_k_n_dev_buf,
c_m_n_dev_buf,
M,
N,
K,
stride_A,
stride_B,
stride_C,
batch_stride_A,
batch_stride_B,
batch_stride_C,
batch_count,
kbatch,
n_warmup,
n_repeat);
c_m_n_dev_buf.FromDevice(c_m_n_dev_result.data());
bool pass = true;

View File

@@ -1,6 +1,6 @@
# Grouped CShuffle GEMM
This folder contains example for Grouped GEMM using ck_tile tile-programming implementation. Currently, it only supports the basic feature of the CK Tile GEMM, but creates the placeholders for the future support on different GEMM pipeline and different GEMM modules. In the near future, we will gradually migrate all the GEMM features from old CK to CK Tile.
This folder contains an example of Grouped GEMM using the ck_tile tile-programming implementation.
## build
```

View File

@@ -16,7 +16,16 @@
#include "ck_tile/host.hpp"
#include "grouped_gemm.hpp"
template <typename ALayout, typename BLayout, typename CLayout>
template <typename ADataType,
typename BDataType,
typename DsDataType,
typename AccDataType,
typename CDataType,
typename ALayout,
typename BLayout,
typename DsLayout,
typename CLayout,
typename CDEElementWise = ck_tile::element_wise::PassThrough>
float grouped_gemm(const std::vector<grouped_gemm_kargs>& gemm_descs,
const ck_tile::stream_config& s,
void* kargs_ptr)
@@ -130,9 +139,12 @@ float grouped_gemm(const std::vector<grouped_gemm_kargs>& gemm_descs,
using GemmEpilogue = ck_tile::CShuffleEpilogue<
ck_tile::CShuffleEpilogueProblem<ADataType,
BDataType,
DsDataType,
AccDataType,
CDataType,
DsLayout,
CLayout,
CDEElementWise,
GemmPipelineProblem::kBlockSize,
TilePartitioner::MPerBlock,
TilePartitioner::NPerBlock,

View File

@@ -7,7 +7,8 @@
#include "ck_tile/core.hpp"
#include "ck_tile/host/kernel_launch.hpp"
#include "ck_tile/ops/gemm/kernel/grouped_gemm_kernel.hpp"
#include "ck_tile/ops/gemm.hpp"
#include "ck_tile/ops/elementwise/unary_element_wise_operation.hpp"
#define CK_TILE_PIPELINE_COMPUTE_V3 1
#define CK_TILE_PIPELINE_MEMORY 2
@@ -53,7 +54,7 @@ using BDataType = Types::BDataType;
using AccDataType = Types::AccDataType;
using CDataType = Types::CDataType;
using grouped_gemm_kargs = ck_tile::GemmHostArgs;
using grouped_gemm_kargs = ck_tile::GemmHostArgs</*NumDTensor = 0*/>;
auto create_args(int argc, char* argv[])
{
@@ -82,7 +83,17 @@ inline std::size_t get_workspace_size(const std::vector<grouped_gemm_kargs>& gem
return gemm_descs.size() * sizeof(ck_tile::GemmTransKernelArg);
}
template <typename ALayout, typename BLayout, typename CLayout>
template <typename ADataType,
typename BDataType,
typename DsDataType,
typename AccDataType,
typename CDataType,
typename ALayout,
typename BLayout,
typename DsLayout,
typename CLayout,
bool Persistent,
typename CDEElementWise>
float grouped_gemm(const std::vector<grouped_gemm_kargs>& gemm_descs,
const ck_tile::stream_config& s,
void* kargs_ptr);

View File

@@ -30,7 +30,17 @@ auto calculate_rtol_atol(const ck_tile::index_t K,
return ck_tile::make_tuple(std::max(rtol, rtol_split_k), std::max(atol, atol_split_k));
}
template <typename ALayout, typename BLayout, typename CLayout, bool Persistent>
template <typename ADataType,
typename BDataType,
typename DsDataType,
typename AccDataType,
typename CDataType,
typename ALayout,
typename BLayout,
typename DsLayout,
typename CLayout,
bool Persistent,
typename CDEElementWise = ck_tile::element_wise::PassThrough>
float invoke_gemm(int n_warmup,
int n_repeat,
int group_count,
@@ -44,7 +54,16 @@ float invoke_gemm(int n_warmup,
if constexpr(!Persistent)
{
// Regular version of grouped gemm
ave_time = grouped_gemm<ALayout, BLayout, CLayout>(
ave_time = grouped_gemm<ADataType,
BDataType,
DsDataType,
AccDataType,
CDataType,
ALayout,
BLayout,
DsLayout,
CLayout,
CDEElementWise>(
args,
ck_tile::stream_config{nullptr, true, 1, n_warmup, n_repeat},
gemm_workspace.GetDeviceBuffer());
@@ -64,16 +83,18 @@ float invoke_gemm(int n_warmup,
const bool splitk = args[0].k_batch > 1;
for(const auto& arg : args)
{
kargs.emplace_back(ck_tile::GemmKernelArgs{arg.a_ptr,
arg.b_ptr,
arg.c_ptr,
arg.M,
arg.N,
arg.K,
arg.stride_A,
arg.stride_B,
arg.stride_C,
arg.k_batch});
kargs.emplace_back(ck_tile::GemmKernelArgs<>{arg.a_ptr,
arg.b_ptr,
{},
arg.e_ptr,
arg.M,
arg.N,
arg.K,
arg.stride_A,
arg.stride_B,
{},
arg.stride_E,
arg.k_batch});
}
const auto stream = ck_tile::stream_config{nullptr, true, 1, n_warmup, n_repeat};
HIP_CHECK_ERROR(hipMemcpyWithStream(kargs_ptr,
@@ -219,10 +240,19 @@ int run_grouped_gemm_example_with_layouts(int argc,
void* p_c = c_m_n_dev_buf[i]->GetDeviceBuffer();
gemm_descs.push_back(
{p_a, p_b, p_c, kbatch, M, N, K, stride_As[i], stride_Bs[i], stride_Cs[i]});
{p_a, p_b, {}, p_c, kbatch, M, N, K, stride_As[i], stride_Bs[i], {}, stride_Cs[i]});
}
invoke_gemm<ALayout, BLayout, CLayout, Persistent>(warmup, repeat, group_count, gemm_descs);
invoke_gemm<ADataType,
BDataType,
ck_tile::tuple<>,
AccDataType,
CDataType,
ALayout,
BLayout,
ck_tile::tuple<>,
CLayout,
Persistent>(warmup, repeat, group_count, gemm_descs);
for(int i = 0; i < group_count; i++)
{

View File

@@ -0,0 +1 @@
add_executable(tile_example_gemm_multi_d_fp16 EXCLUDE_FROM_ALL gemm_multi_d_fp16.cpp)

View File

@@ -0,0 +1,35 @@
# Multiple D GEMM
This folder contains an example of Multiple D GEMM using the ck_tile tile-programming implementation.
## build
```
#in the root of ck_tile
mkdir build && cd build
# you can replace <arch> with the appropriate architecture (for example gfx90a or gfx942) or leave it blank
sh ../script/cmake-ck-dev.sh ../ <arch>
#The basic pipeline method on the gemm calculation
make tile_example_gemm_multi_d_fp16 -j
```
This will result in an executable `build/bin/tile_example_gemm_multi_d_fp16`
## example
```
args:
-m M dimensions - (Default: 3840)
-n N dimensions - (Default: 4096)
-k K dimensions - (Default: 4096)
-a_layout Tensor A layout (default:R)
-b_layout Tensor B layout (default:C)
-ds_layout Tensor D layout (default:R)
-e_layout Tensor E layout (default:R)
-stride_a Tensor A strides - (Default: 0)
-stride_b Tensor B strides - (Default: 0)
-stride_e Tensor E strides - (Default: 0)
-stride_ds Tensor D strides - (Default: 0)
-v 0. No validation, 1. Validation on GPU. (Default: 1)
-warmup Number of iterations before benchmark the kernel. (Default: 50)
-repeat Number of iterations to benchmark the kernel. (Default: 100)
-kbatch kbatch for SplitK. (Default 1)
```

View File

@@ -0,0 +1,296 @@
// SPDX-License-Identifier: MIT
// Copyright (c) 2025, Advanced Micro Devices, Inc. All rights reserved.
#include <hip/hip_runtime.h>
#include <cstring>
#include <iostream>
#include <ostream>
#include <string>
#include <tuple>
#include <memory>
#include "ck_tile/core.hpp"
#include "ck_tile/ops/epilogue.hpp"
#include "ck_tile/ops/gemm.hpp"
#include "ck_tile/host.hpp"
#include "gemm_multi_d_fp16.hpp"
#include "utils.hpp"
/// @brief Host-side launcher of the Multiple-D GEMM example kernel.
///
/// Picks the tile configuration that matches CK_TILE_PIPELINE_DEFAULT, derives the
/// hot-loop flag and tail number from the K extent and the split-K factor, instantiates
/// the corresponding ck_tile::GemmKernel and launches it via ck_tile::launch_kernel.
///
/// @tparam ADataType      Element type forwarded to the pipeline/epilogue problems for A.
/// @tparam BDataType      Element type forwarded to the pipeline/epilogue problems for B.
/// @tparam DsDataType     Tuple of D element types forwarded to the epilogue problem.
/// @tparam AccDataType    Type forwarded as the accumulator type of the pipeline problems.
/// @tparam EDataType      Element type forwarded to the epilogue problem for the output.
/// @tparam ALayout        Layout tag of A.
/// @tparam BLayout        Layout tag of B.
/// @tparam DsLayout       Tuple of layout tags forwarded to the epilogue problem.
/// @tparam CLayout        Layout tag used for the output in the traits and the epilogue.
/// @tparam CDEElementWise Element-wise functor type forwarded to the epilogue problem.
///
/// @param args Host arguments; this function reads `k_batch`, `K`, `M` and `N` directly and
///             passes the whole object to `Kernel::MakeKernelArgs`.
/// @param s    Stream configuration forwarded to ck_tile::launch_kernel; `log_level_ > 0`
///             prints the launch dimensions.
///
/// @return The value produced by ck_tile::launch_kernel (presumably the average kernel time
///         in milliseconds - confirm against ck_tile::launch_kernel). Stays 0 when no dispatch
///         branch launched a kernel.
///
/// @throws std::runtime_error if `k_batch` is smaller than 1, if the kernel rejects the
///         arguments, or if the computed tail number has no dispatch branch.
template <typename ADataType,
          typename BDataType,
          typename DsDataType,
          typename AccDataType,
          typename EDataType,
          typename ALayout,
          typename BLayout,
          typename DsLayout,
          typename CLayout,
          typename CDEElementWise = ck_tile::element_wise::PassThrough>
auto gemm_multi_d(const gemm_multi_d_kargs& args, const ck_tile::stream_config& s) -> float
{
#if(CK_TILE_PIPELINE_DEFAULT == CK_TILE_PIPELINE_MEMORY)
    // Memory friendly for Interwave scheduler
    constexpr ck_tile::index_t M_Tile = 128;
    constexpr ck_tile::index_t N_Tile = 32;
    constexpr ck_tile::index_t K_Tile = 64;

    constexpr ck_tile::index_t M_Warp = 4;
    constexpr ck_tile::index_t N_Warp = 1;
    constexpr ck_tile::index_t K_Warp = 1;

    constexpr ck_tile::index_t M_Warp_Tile = 32;
    constexpr ck_tile::index_t N_Warp_Tile = 32;
    constexpr ck_tile::index_t K_Warp_Tile = 8;

    constexpr bool DoubleSmemBuffer = false;
#endif
#if(CK_TILE_PIPELINE_DEFAULT == CK_TILE_PIPELINE_COMPUTE_V3)
    // Compute friendly for Intrawave scheduler
    constexpr ck_tile::index_t M_Tile = 256;
    constexpr ck_tile::index_t N_Tile = 256;
    constexpr ck_tile::index_t K_Tile = 64;

    constexpr ck_tile::index_t M_Warp = 2;
    constexpr ck_tile::index_t N_Warp = 2;
    constexpr ck_tile::index_t K_Warp = 1;

    constexpr ck_tile::index_t M_Warp_Tile = 32;
    constexpr ck_tile::index_t N_Warp_Tile = 32;
    constexpr ck_tile::index_t K_Warp_Tile = 16;

    constexpr bool DoubleSmemBuffer = false;
#elif(CK_TILE_PIPELINE_DEFAULT == CK_TILE_PIPELINE_COMPUTE_V4)
    // Compute friendly for Intrawave scheduler
    // Using the ping pong reader in the lds level
    constexpr ck_tile::index_t M_Tile = 256;
    constexpr ck_tile::index_t N_Tile = 256;
    constexpr ck_tile::index_t K_Tile = 32;

    constexpr ck_tile::index_t M_Warp = 2;
    constexpr ck_tile::index_t N_Warp = 2;
    constexpr ck_tile::index_t K_Warp = 1;

    constexpr ck_tile::index_t M_Warp_Tile = 32;
    constexpr ck_tile::index_t N_Warp_Tile = 32;
    constexpr ck_tile::index_t K_Warp_Tile = 16;

    constexpr bool DoubleSmemBuffer = true;
#endif

    constexpr bool kPadM = false;
    constexpr bool kPadN = false;
    constexpr bool kPadK = false;

    constexpr bool TransposeC = false;

    constexpr int kBlockPerCu                         = 1;
    constexpr ck_tile::index_t TileParitionerGroupNum = 8;
    constexpr ck_tile::index_t TileParitionerM01      = 4;

    using GemmShape =
        ck_tile::TileGemmShape<ck_tile::sequence<M_Tile, N_Tile, K_Tile>,
                               ck_tile::sequence<M_Warp, N_Warp, K_Warp>,
                               ck_tile::sequence<M_Warp_Tile, N_Warp_Tile, K_Warp_Tile>>;
    using TilePartitioner = ck_tile::
        GemmSpatiallyLocalTilePartitioner<GemmShape, TileParitionerGroupNum, TileParitionerM01>;

    using Traits = ck_tile::TileGemmTraits<kPadM, kPadN, kPadK, ALayout, BLayout, CLayout>;
    using GemmUniversalTraits = ck_tile::TileGemmUniversalTraits<kPadM,
                                                                 kPadN,
                                                                 kPadK,
                                                                 DoubleSmemBuffer,
                                                                 ALayout,
                                                                 BLayout,
                                                                 CLayout,
                                                                 TransposeC>;
    using GemmPipelineProblem =
        ck_tile::GemmPipelineProblem<ADataType, BDataType, AccDataType, GemmShape, Traits>;

    using BaseGemmPipeline = UNIVERSAL_GEMM_PIPELINE<GemmPipelineProblem>;

    // k_batch is used as a divisor below; reject non-positive values up front instead of
    // running into an integer division by zero.
    if(args.k_batch < 1)
    {
        throw std::runtime_error("Wrong! k_batch must be at least 1! Skipping gemm!\n");
    }

    // K extent handled by one split-K slice, rounded up to a whole number of K tiles.
    const ck_tile::index_t k_grain     = args.k_batch * K_Tile;
    const ck_tile::index_t K_split     = (args.K + k_grain - 1) / k_grain * K_Tile;
    const ck_tile::index_t num_loop    = TilePartitioner::GetLoopNum(K_split);
    const bool has_hot_loop            = BaseGemmPipeline::BlockHasHotloop(num_loop);
    const ck_tile::TailNumber tail_num = BaseGemmPipeline::GetBlockLoopTailNum(num_loop);

    float ave_time{0};

    // Instantiates and launches the kernel for one compile-time combination of
    // (hot loop, tail number, memory operation); stores the launch result in ave_time.
    const auto Run =
        [&](const auto has_hot_loop_, const auto tail_number_, const auto memory_operation_) {
            constexpr bool has_hot_loop_v   = has_hot_loop_.value;
            constexpr auto tail_number_v    = tail_number_.value;
            constexpr auto scheduler        = GEMM_PIPELINE_SCHEDULER;
            constexpr auto memory_operation = memory_operation_.value;

            using UniversalGemmProblem = ck_tile::UniversalGemmPipelineProblem<ADataType,
                                                                               BDataType,
                                                                               AccDataType,
                                                                               GemmShape,
                                                                               GemmUniversalTraits,
                                                                               scheduler,
                                                                               has_hot_loop_v,
                                                                               tail_number_v>;

            using GemmPipeline = GEMM_PIPELINE<UniversalGemmProblem>;

            using GemmEpilogue = ck_tile::CShuffleEpilogue<
                ck_tile::CShuffleEpilogueProblem<ADataType,
                                                 BDataType,
                                                 DsDataType,
                                                 AccDataType,
                                                 EDataType,
                                                 DsLayout,
                                                 CLayout,
                                                 CDEElementWise,
                                                 GemmPipelineProblem::kBlockSize,
                                                 TilePartitioner::MPerBlock,
                                                 TilePartitioner::NPerBlock,
                                                 M_Warp,
                                                 N_Warp,
                                                 M_Warp_Tile,
                                                 N_Warp_Tile,
                                                 K_Warp_Tile,
                                                 UniversalGemmProblem::TransposeC,
                                                 memory_operation>>;

            using Kernel = ck_tile::GemmKernel<TilePartitioner, GemmPipeline, GemmEpilogue>;
            auto kargs   = Kernel::MakeKernelArgs(args);

            const dim3 grids      = Kernel::GridSize(args.M, args.N, args.k_batch);
            constexpr dim3 blocks = Kernel::BlockSize();

            if(!Kernel::IsSupportedArgument(kargs))
            {
                throw std::runtime_error("Wrong! Arguments not supported! Skipping gemm!\n");
            }

            if(s.log_level_ > 0)
            {
                std::cout << "Launching kernel with args:"
                          << " grid: {" << grids.x << ", " << grids.y << ", " << grids.z << "}"
                          << ", blocks: {" << blocks.x << ", " << blocks.y << ", " << blocks.z
                          << "}" << std::endl;
            }

            ave_time = ck_tile::launch_kernel(
                s, ck_tile::make_kernel<blocks.x, kBlockPerCu>(Kernel{}, grids, blocks, 0, kargs));
            return ave_time;
        };

    // Chooses the epilogue memory operation: plain store for a single K batch, atomic add
    // when several split-K slices contribute to the same output.
    const auto RunSplitk = [&](const auto has_hot_loop_, const auto tail_number_) {
        if(args.k_batch == 1)
        {
            Run(has_hot_loop_,
                tail_number_,
                ck_tile::integral_constant<ck_tile::memory_operation_enum,
                                           ck_tile::memory_operation_enum::set>{});
        }
        else
        {
            Run(has_hot_loop_,
                tail_number_,
                ck_tile::integral_constant<ck_tile::memory_operation_enum,
                                           ck_tile::memory_operation_enum::atomic_add>{});
        }
    };

    if(has_hot_loop)
    {
#if(CK_TILE_PIPELINE_DEFAULT == CK_TILE_PIPELINE_COMPUTE_V3)
        if(tail_num == ck_tile::TailNumber::Full)
        {
            RunSplitk(ck_tile::bool_constant<true>{},
                      ck_tile::integral_constant<ck_tile::TailNumber, ck_tile::TailNumber::Full>{});
        }
        else if(tail_num == ck_tile::TailNumber::Odd)
        {
            RunSplitk(ck_tile::bool_constant<true>{},
                      ck_tile::integral_constant<ck_tile::TailNumber, ck_tile::TailNumber::Odd>{});
        }
        else if(tail_num == ck_tile::TailNumber::Even)
        {
            RunSplitk(ck_tile::bool_constant<true>{},
                      ck_tile::integral_constant<ck_tile::TailNumber, ck_tile::TailNumber::Even>{});
        }
        else
        {
            // The message lists every tail number that is dispatched above.
            std::ostringstream err;
            err << "For compute pipeline tail number should be Full, Odd or Even, but have \""
                << tail_num
                << "\" which is not supported! PrefetchStages: " << BaseGemmPipeline::PrefetchStages
                << "\n File: " << __FILE__ << ":" << __LINE__ << ", in function: " << __func__;
            throw std::runtime_error(err.str());
        }
#elif(CK_TILE_PIPELINE_DEFAULT == CK_TILE_PIPELINE_MEMORY)
        if(tail_num == ck_tile::TailNumber::One)
        {
            RunSplitk(ck_tile::bool_constant<true>{},
                      ck_tile::integral_constant<ck_tile::TailNumber, ck_tile::TailNumber::One>{});
        }
        else if(tail_num == ck_tile::TailNumber::Full)
        {
            RunSplitk(ck_tile::bool_constant<true>{},
                      ck_tile::integral_constant<ck_tile::TailNumber, ck_tile::TailNumber::Full>{});
        }

        // NOTE(review): try_run is not declared in this function; it is expected to come from
        // an included header - confirm it exists and handles tail numbers Two..Seven.
        auto check_tail = [&](auto... TNs) {
            (try_run<BaseGemmPipeline, decltype(TNs)::value>(tail_num), ...);
        };

        check_tail(ck_tile::integral_constant<ck_tile::TailNumber, ck_tile::TailNumber::Two>{},
                   ck_tile::integral_constant<ck_tile::TailNumber, ck_tile::TailNumber::Three>{},
                   ck_tile::integral_constant<ck_tile::TailNumber, ck_tile::TailNumber::Four>{},
                   ck_tile::integral_constant<ck_tile::TailNumber, ck_tile::TailNumber::Five>{},
                   ck_tile::integral_constant<ck_tile::TailNumber, ck_tile::TailNumber::Six>{},
                   ck_tile::integral_constant<ck_tile::TailNumber, ck_tile::TailNumber::Seven>{});
#elif(CK_TILE_PIPELINE_DEFAULT == CK_TILE_PIPELINE_COMPUTE_V4)
        if(tail_num == ck_tile::TailNumber::Three)
        {
            RunSplitk(
                ck_tile::bool_constant<true>{},
                ck_tile::integral_constant<ck_tile::TailNumber, ck_tile::TailNumber::Three>{});
        }
        else
        {
            RunSplitk(ck_tile::bool_constant<true>{},
                      ck_tile::integral_constant<ck_tile::TailNumber, ck_tile::TailNumber::Two>{});
        }
#endif
    }
    else
    {
        if(tail_num == ck_tile::TailNumber::Full)
        {
            RunSplitk(ck_tile::bool_constant<false>{},
                      ck_tile::integral_constant<ck_tile::TailNumber, ck_tile::TailNumber::Full>{});
        }
        else if(tail_num == ck_tile::TailNumber::Odd)
        {
            RunSplitk(ck_tile::bool_constant<false>{},
                      ck_tile::integral_constant<ck_tile::TailNumber, ck_tile::TailNumber::Odd>{});
        }
        else if(tail_num == ck_tile::TailNumber::Even)
        {
            RunSplitk(ck_tile::bool_constant<false>{},
                      ck_tile::integral_constant<ck_tile::TailNumber, ck_tile::TailNumber::Even>{});
        }
        else
        {
            std::ostringstream err;
            err << "Num K loop must be larger than number of prefetech stages."
                << "\n PrefetchStages: " << BaseGemmPipeline::PrefetchStages
                << "\n File: " << __FILE__ << ":" << __LINE__ << ", in function: " << __func__;
            throw std::runtime_error(err.str());
        }
    }

    return ave_time;
}
#include "run_gemm_multi_d_fp16_example.inc"
// Program entry point: forwards the command line to run_multiple_d_gemm_example and
// returns the logical negation of its result as the process exit status.
int main(int argc, char* argv[])
{
    const bool negated_result = !run_multiple_d_gemm_example(argc, argv);
    return negated_result;
}

View File

@@ -0,0 +1,79 @@
// SPDX-License-Identifier: MIT
// Copyright (c) 2025, Advanced Micro Devices, Inc. All rights reserved.
#pragma once
#include <string>
#include "ck_tile/core.hpp"
#include "ck_tile/host/kernel_launch.hpp"
#include "ck_tile/ops/elementwise/unary_element_wise_operation.hpp"
// Integer identifiers of the GEMM pipeline variants this example can be built with.
#define CK_TILE_PIPELINE_COMPUTE_V3 1
#define CK_TILE_PIPELINE_MEMORY 2
#define CK_TILE_PIPELINE_COMPUTE_V4 3
// The build may predefine CK_TILE_PIPELINE_DEFAULT; otherwise COMPUTE_V3 is selected.
#ifndef CK_TILE_PIPELINE_DEFAULT
#define CK_TILE_PIPELINE_DEFAULT CK_TILE_PIPELINE_COMPUTE_V3
#endif
// Map the selected variant to three macros used by the launcher:
//   GEMM_PIPELINE            - pipeline template instantiated with the universal problem
//   UNIVERSAL_GEMM_PIPELINE  - base pipeline template used to query hot-loop / tail number
//   GEMM_PIPELINE_SCHEDULER  - scheduler enumerator passed to the universal problem
#if(CK_TILE_PIPELINE_DEFAULT == CK_TILE_PIPELINE_MEMORY)
#define GEMM_PIPELINE ck_tile::GemmPipelineAgBgCrMem
#define UNIVERSAL_GEMM_PIPELINE ck_tile::BaseGemmPipelineAgBgCrMem
#define GEMM_PIPELINE_SCHEDULER ck_tile::GemmPipelineScheduler::Interwave
#elif(CK_TILE_PIPELINE_DEFAULT == CK_TILE_PIPELINE_COMPUTE_V3)
#define GEMM_PIPELINE ck_tile::GemmPipelineAgBgCrCompV3
#define UNIVERSAL_GEMM_PIPELINE ck_tile::BaseGemmPipelineAgBgCrCompV3
#define GEMM_PIPELINE_SCHEDULER ck_tile::GemmPipelineScheduler::Intrawave
#elif(CK_TILE_PIPELINE_DEFAULT == CK_TILE_PIPELINE_COMPUTE_V4)
#define GEMM_PIPELINE ck_tile::GemmPipelineAgBgCrCompV4
#define UNIVERSAL_GEMM_PIPELINE ck_tile::BaseGemmPipelineAgBgCrCompV4
#define GEMM_PIPELINE_SCHEDULER ck_tile::GemmPipelineScheduler::Intrawave
#else
// Any other value is rejected at preprocessing time.
#error "unsupported CK_TILE_PIPELINE_DEFAULT value"
#endif
// Element types used by this example: A, B, both D tensors and E are all ck_tile::half_t.
using ADataType = ck_tile::half_t;
using BDataType = ck_tile::half_t;
using D0DataType = ck_tile::half_t;
using D1DataType = ck_tile::half_t;
using EDataType = ck_tile::half_t;
// Tuple of the D element types; its size() fixes the number of D tensors (two here).
using DsDataType = ck_tile::tuple<D0DataType, D1DataType>;
// Type passed as the accumulator type to the GEMM pipeline problems.
using AccDataType = float;
/// @brief Registers the command-line options of the Multiple-D GEMM example and parses argv.
///
/// Every option is registered with its default value (as a string) and a help text.
/// Declared `inline` because the definition lives in a header: without it, including this
/// header from more than one translation unit would produce multiple definitions.
///
/// @param argc Argument count as received by main.
/// @param argv Argument vector as received by main.
/// @return A tuple of (result of ArgParser::parse, the populated ck_tile::ArgParser).
inline auto create_args(int argc, char* argv[])
{
    ck_tile::ArgParser arg_parser;
    arg_parser.insert("m", "3840", "m dimension")
        .insert("n", "4096", "n dimension")
        .insert("k", "4096", "k dimension")
        .insert("a_layout", "R", "A tensor data layout - Row by default")
        .insert("b_layout", "C", "B tensor data layout - Col by default")
        .insert("ds_layout", "R", "Ds tensor data layout - Row by default")
        .insert("e_layout", "R", "E tensor data layout - Row by default")
        .insert("stride_a", "0", "Tensor A stride")
        .insert("stride_b", "0", "Tensor B stride")
        .insert("stride_ds", "0", "Tensor Ds stride")
        .insert("stride_e", "0", "Tensor E stride")
        .insert("v", "1", "0. No validation, 1. Validation on GPU")
        .insert("warmup", "50", "number of iterations before benchmark the kernel")
        .insert("repeat", "100", "number of iterations to benchmark the kernel")
        .insert("kbatch", "1", "kbatch for SplitK");

    const bool result = arg_parser.parse(argc, argv);
    return std::make_tuple(result, arg_parser);
}
// Host argument type of the example: ck_tile::GemmHostArgs instantiated with the number of
// D tensors taken from DsDataType.
using gemm_multi_d_kargs = ck_tile::GemmHostArgs<DsDataType::size()>;
// Declaration of the host-side launcher; the definition is expected in the example's
// source file. Takes the host arguments and a stream configuration and returns a float
// (presumably the measured kernel time - confirm against the definition).
template <typename ADataType,
typename BDataType,
typename DsDataType,
typename AccDataType,
typename EDataType,
typename ALayout,
typename BLayout,
typename DsLayout,
typename CLayout,
typename CDEElementWise>
float gemm_multi_d(const gemm_multi_d_kargs& kargs, const ck_tile::stream_config& s);

View File

@@ -0,0 +1,247 @@
// SPDX-License-Identifier: MIT
// Copyright (c) 2025, Advanced Micro Devices, Inc. All rights reserved.
#pragma once
#include <cstddef>
/// @brief Packs the device pointers and problem sizes into host arguments, runs
///        gemm_multi_d and prints a short performance summary to std::cout.
///
/// @param a_m_k_dev_buf  Pointer passed as the A operand.
/// @param b_k_n_dev_buf  Pointer passed as the B operand.
/// @param ds_m_n_dev_buf One pointer per D tensor (array length is DsDataType::size()).
/// @param e_m_n_dev_buf  Pointer passed as the E (output) operand.
/// @param M,N,K          Problem sizes.
/// @param StrideA        Stride passed for A.
/// @param StrideB        Stride passed for B.
/// @param StrideDs       One stride per D tensor.
/// @param StrideE        Stride passed for E.
/// @param n_warmup       Warm-up iteration count forwarded to ck_tile::stream_config.
/// @param n_repeat       Timed iteration count forwarded to ck_tile::stream_config.
/// @param k_batch        Split-K factor stored in the host arguments.
///
/// @return The value returned by gemm_multi_d (printed below with the unit "ms").
template <typename ADataType,
          typename BDataType,
          typename DsDataType,
          typename AccDataType,
          typename EDataType,
          typename ALayout,
          typename BLayout,
          typename DsLayout,
          typename ELayout,
          typename CDEElementWise = ck_tile::element_wise::PassThrough>
float invoke_gemm_multi_d(const void* a_m_k_dev_buf,
                          const void* b_k_n_dev_buf,
                          const std::array<const void*, DsDataType::size()>& ds_m_n_dev_buf,
                          void* e_m_n_dev_buf,
                          ck_tile::index_t M,
                          ck_tile::index_t N,
                          ck_tile::index_t K,
                          ck_tile::index_t StrideA,
                          ck_tile::index_t StrideB,
                          const std::array<ck_tile::index_t, DsDataType::size()>& StrideDs,
                          ck_tile::index_t StrideE,
                          int n_warmup,
                          int n_repeat,
                          int k_batch)
{
    gemm_multi_d_kargs gemm_descs({a_m_k_dev_buf,
                                   b_k_n_dev_buf,
                                   ds_m_n_dev_buf,
                                   e_m_n_dev_buf,
                                   k_batch,
                                   M,
                                   N,
                                   K,
                                   StrideA,
                                   StrideB,
                                   StrideDs,
                                   StrideE});

    const float ave_time = gemm_multi_d<ADataType,
                                        BDataType,
                                        DsDataType,
                                        AccDataType,
                                        EDataType,
                                        ALayout,
                                        BLayout,
                                        DsLayout,
                                        ELayout,
                                        CDEElementWise>(
        gemm_descs, ck_tile::stream_config{nullptr, true, 1, n_warmup, n_repeat});

    static constexpr ck_tile::index_t NumDTensor = DsDataType::size();

    // 2*M*N*K operations for the matrix product itself.
    std::size_t flop      = std::size_t(2) * M * N * K;
    std::size_t num_btype = 0;

    ck_tile::static_for<0, NumDTensor, 1>{}([&](auto i) {
        // Bytes read for D tensor i.
        num_btype += sizeof(ck_tile::remove_cvref_t<std::tuple_element_t<i, DsDataType>>) * M * N;
        // One element-wise operation per output element for each D tensor. The operation
        // count must not be scaled by the element size in bytes.
        flop += std::size_t(M) * N;
    });

    // Bytes moved for A, B and E.
    num_btype += sizeof(ADataType) * M * K + sizeof(BDataType) * K * N + sizeof(EDataType) * M * N;

    // With ave_time in milliseconds: ops / 1e9 / ms == TFlop/s, bytes / 1e6 / ms == GB/s.
    float tflops     = static_cast<float>(flop) / 1.E9 / ave_time;
    float gb_per_sec = num_btype / 1.E6 / ave_time;

    std::cout << "Run Gemm Multiple-D kernel with:\n";
    std::cout << "M =" << M << " N =" << N << " K =" << K << "\n";
    std::cout << "StrideA = " << StrideA << " StrideB = " << StrideB << " StrideE = " << StrideE
              << "\n";
    std::cout << "Perf: " << ave_time << " ms, " << tflops << " TFlops, " << gb_per_sec << " GB/s, "
              << "\n";

    return ave_time;
}
template <typename ALayout,
typename BLayout,
typename D0Layout,
typename D1Layout,
typename ELayout>
int run_multiple_d_gemm_example_with_layouts(int argc,
char* argv[],
const ALayout a_layout = ALayout{},
const BLayout b_layout = BLayout{},
const D0Layout d0_layout = D0Layout{},
const D1Layout d1_layout = D1Layout{},
const ELayout e_layout = ELayout{})
{
auto [result, arg_parser] = create_args(argc, argv);
if(!result)
{
return -1;
}
using CDElementWiseFn = MultiplyMultiply;
using DsLayout = ck_tile::tuple<D0Layout, D1Layout>;
ck_tile::index_t M = arg_parser.get_int("m");
ck_tile::index_t N = arg_parser.get_int("n");
ck_tile::index_t K = arg_parser.get_int("k");
ck_tile::index_t StrideA = arg_parser.get_int("stride_a");
ck_tile::index_t StrideB = arg_parser.get_int("stride_b");
ck_tile::index_t StrideD = arg_parser.get_int("stride_ds");
ck_tile::index_t StrideE = arg_parser.get_int("stride_e");
ck_tile::index_t StrideD0 = StrideD;
ck_tile::index_t StrideD1 = StrideD;
const int n_warmup = arg_parser.get_int("warmup");
const int n_repeat = arg_parser.get_int("repeat");
const int k_batch = arg_parser.get_int("kbatch");
StrideA = get_default_stride(M, K, StrideA, is_row_major(a_layout));
StrideB = get_default_stride(K, N, StrideB, is_row_major(b_layout));
StrideD0 = get_default_stride(M, N, StrideD0, is_row_major(d0_layout));
StrideD1 = get_default_stride(M, N, StrideD1, is_row_major(d1_layout));
StrideE = get_default_stride(M, N, StrideE, is_row_major(e_layout));
ck_tile::HostTensor<ADataType> a_m_k_tesnor(
host_tensor_descriptor(M, K, StrideA, is_row_major(a_layout)));
ck_tile::HostTensor<BDataType> b_k_n_tensors(
host_tensor_descriptor(K, N, StrideB, is_row_major(b_layout)));
ck_tile::HostTensor<D0DataType> d0_m_n_tensors(
host_tensor_descriptor(M, N, StrideD0, is_row_major(d0_layout)));
ck_tile::HostTensor<D1DataType> d1_m_n_tensors(
host_tensor_descriptor(M, N, StrideD1, is_row_major(d1_layout)));
ck_tile::HostTensor<EDataType> e_m_n_device_result(
host_tensor_descriptor(M, N, StrideE, is_row_major(e_layout)));
ck_tile::FillUniformDistribution<ADataType>{-5.f, 5.f}(a_m_k_tesnor);
ck_tile::FillUniformDistribution<BDataType>{-5.f, 5.f}(b_k_n_tensors);
ck_tile::FillUniformDistribution<D0DataType>{-1.f, 1.f}(d0_m_n_tensors);
ck_tile::FillUniformDistribution<D1DataType>{-1.f, 1.f}(d1_m_n_tensors);
ck_tile::DeviceMem a_m_k_dev_buf(a_m_k_tesnor.get_element_space_size_in_bytes());
ck_tile::DeviceMem b_k_n_dev_buf(b_k_n_tensors.get_element_space_size_in_bytes());
ck_tile::DeviceMem d0_m_n_dev_buf(d0_m_n_tensors.get_element_space_size_in_bytes());
ck_tile::DeviceMem d1_m_n_dev_buf(d1_m_n_tensors.get_element_space_size_in_bytes());
ck_tile::DeviceMem e_m_n_dev_buf(e_m_n_device_result.get_element_space_size_in_bytes());
a_m_k_dev_buf.ToDevice(a_m_k_tesnor.mData.data());
b_k_n_dev_buf.ToDevice(b_k_n_tensors.mData.data());
d0_m_n_dev_buf.ToDevice(d0_m_n_tensors.mData.data());
d1_m_n_dev_buf.ToDevice(d1_m_n_tensors.mData.data());
e_m_n_dev_buf.SetZero();
e_m_n_device_result.SetZero();
std::array<const void*, DsDataType::size()> ds_ptr_buf = {d0_m_n_dev_buf.GetDeviceBuffer(),
d1_m_n_dev_buf.GetDeviceBuffer()};
std::array<ck_tile::index_t, DsDataType::size()> stridesDs = {StrideD0, StrideD1};
invoke_gemm_multi_d<ADataType,
BDataType,
DsDataType,
AccDataType,
EDataType,
ALayout,
BLayout,
DsLayout,
ELayout,
CDElementWiseFn>(a_m_k_dev_buf.GetDeviceBuffer(),
b_k_n_dev_buf.GetDeviceBuffer(),
ds_ptr_buf,
e_m_n_dev_buf.GetDeviceBuffer(),
M,
N,
K,
StrideA,
StrideB,
stridesDs,
StrideE,
n_warmup,
n_repeat,
k_batch);
e_m_n_dev_buf.FromDevice(e_m_n_device_result.data());
ck_tile::HostTensor<EDataType> e_m_n_host_ref(
host_tensor_descriptor(M, N, StrideE, is_row_major(e_layout)));
e_m_n_host_ref.SetZero();
ck_tile::reference_gemm_multiple_d<ADataType,
BDataType,
DsDataType,
AccDataType,
EDataType,
CDElementWiseFn>(
a_m_k_tesnor, b_k_n_tensors, {d0_m_n_tensors, d1_m_n_tensors}, e_m_n_host_ref);
bool pass{true};
if(arg_parser.get_int("v"))
{
const float max_accumulated_value =
*std::max_element(e_m_n_host_ref.mData.begin(), e_m_n_host_ref.mData.end());
const auto rtol_atol = calculate_rtol_atol(K, 1, max_accumulated_value);
pass &= ck_tile::check_err(e_m_n_device_result,
e_m_n_host_ref,
"Error: Incorrect results!",
rtol_atol.at(ck_tile::number<0>{}),
rtol_atol.at(ck_tile::number<1>{}));
std::cout << "Relative error threshold: " << rtol_atol.at(ck_tile::number<0>{})
<< std::endl;
std::cout << "Absolute error threshold: " << rtol_atol.at(ck_tile::number<1>{})
<< std::endl;
std::cout << "The CPU veification result is: " << (pass ? "correct" : "fail") << std::endl;
}
return pass;
}
/**
 * @brief Selects the layout combination requested on the command line and runs the example.
 *
 * Reads the "a_layout", "b_layout" and "ds_layout" options. The only combination handled is
 * A = "R", B = "C", Ds = "R", which is dispatched to
 * run_multiple_d_gemm_example_with_layouts with Row/Col/Row/Row/Row tags.
 *
 * @param argc Argument count forwarded to create_args and to the layout-specific runner.
 * @param argv Argument vector forwarded to create_args and to the layout-specific runner.
 * @return -1 when argument parsing fails; otherwise the value returned by
 *         run_multiple_d_gemm_example_with_layouts.
 * @throws std::runtime_error for any other layout combination.
 */
int run_multiple_d_gemm_example(int argc, char* argv[])
{
    using Row = ck_tile::tensor_layout::gemm::RowMajor;
    using Col = ck_tile::tensor_layout::gemm::ColumnMajor;

    auto [parsed_ok, parser] = create_args(argc, argv);
    if(!parsed_ok)
    {
        return -1;
    }

    const std::string layout_a  = parser.get_str("a_layout");
    const std::string layout_b  = parser.get_str("b_layout");
    const std::string layout_ds = parser.get_str("ds_layout");

    const bool is_supported = (layout_a == "R") && (layout_b == "C") && (layout_ds == "R");
    if(!is_supported)
    {
        throw std::runtime_error("Unsupported data layout configuration for provided tensors!");
    }

    return run_multiple_d_gemm_example_with_layouts(
        argc, argv, Row{}, Col{}, Row{}, Row{}, Row{});
}

View File

@@ -0,0 +1,50 @@
// SPDX-License-Identifier: MIT
// Copyright (c) 2025, Advanced Micro Devices, Inc. All rights reserved.
#pragma once
/// Element-wise functor that writes e = c * d0 * d1.
/// All three inputs are converted to float, multiplied in float, and the product is
/// converted to the output type E.
struct MultiplyMultiply
{
    /// @param e  Output element, overwritten with the converted product.
    /// @param c  First factor.
    /// @param d0 Second factor.
    /// @param d1 Third factor.
    template <typename E, typename C, typename D0, typename D1>
    CK_TILE_HOST_DEVICE void operator()(E& e, const C& c, const D0& d0, const D1& d1) const
    {
        const float c_f  = ck_tile::type_convert<float>(c);
        const float d0_f = ck_tile::type_convert<float>(d0);
        const float d1_f = ck_tile::type_convert<float>(d1);

        // Evaluated left to right, i.e. (c * d0) * d1.
        const float product = c_f * d0_f * d1_f;

        e = ck_tile::type_convert<E>(product);
    }
};
template <typename Layout>
static constexpr inline auto is_row_major(Layout layout_)
{
return ck_tile::bool_constant<std::is_same_v<ck_tile::remove_cvref_t<decltype(layout_)>,
ck_tile::tensor_layout::gemm::RowMajor>>{};
}
/**
 * @brief Computes the relative and absolute error thresholds used to verify the GEMM output.
 *
 * Two pairs of thresholds are evaluated and the larger value of each kind is returned:
 *  - thresholds for accumulating ceil(K / kbatch) products, using the narrowest (by sizeof)
 *    of ADataType, BDataType and D0DataType as the compute type;
 *  - thresholds for summing kbatch partial results in EDataType (split-K accumulation).
 *
 * NOTE(review): D1DataType does not take part in the compute-type selection - confirm that
 * this is intended.
 *
 * Declared `inline` because this is a non-template function defined in a header; without it,
 * including the header from more than one translation unit produces multiple definitions.
 *
 * @param K                     Reduction dimension of the GEMM.
 * @param kbatch                Number of split-K batches; used as a divisor, so it must be
 *                              non-zero.
 * @param max_accumulated_value Largest value of the reference output, used to scale the
 *                              absolute threshold.
 * @return ck_tile tuple whose element 0 is the relative and element 1 the absolute threshold.
 */
inline auto calculate_rtol_atol(const ck_tile::index_t K,
                                const ck_tile::index_t kbatch,
                                const float max_accumulated_value)
{
    using ComputeTypeAB =
        std::conditional_t<sizeof(ADataType) < sizeof(BDataType), ADataType, BDataType>;
    using ComputeType =
        std::conditional_t<sizeof(ComputeTypeAB) < sizeof(D0DataType), ComputeTypeAB, D0DataType>;

    // Number of K elements accumulated by a single split.
    const auto k_per_split = ck_tile::integer_divide_ceil(K, kbatch);

    // Calculate thresholds
    const auto rtol =
        ck_tile::get_relative_threshold<ComputeType, EDataType, AccDataType>(k_per_split);
    const auto atol = ck_tile::get_absolute_threshold<ComputeType, EDataType, AccDataType>(
        max_accumulated_value / kbatch, k_per_split);

    // Calculate error due to split_k accumulation
    const auto rtol_split_k =
        ck_tile::get_relative_threshold<EDataType, EDataType, EDataType>(kbatch);
    const auto atol_split_k = ck_tile::get_absolute_threshold<EDataType, EDataType, EDataType>(
        max_accumulated_value, kbatch);

    // Use higher threshold
    return ck_tile::make_tuple(std::max(rtol, rtol_split_k), std::max(atol, atol_split_k));
}

View File

@@ -18,5 +18,6 @@ add_subdirectory(15_fused_moe)
# Example directories added to the build; 19_gemm_multi_d is the Multiple-D GEMM example.
add_subdirectory(16_batched_gemm)
add_subdirectory(17_grouped_gemm)
add_subdirectory(18_flatmm)
add_subdirectory(19_gemm_multi_d)
add_subdirectory(35_batched_transpose)
add_subdirectory(36_copy)