[CK_TILE] Multiple-D GEMM example (#2219)

* Multiple d, initial commit

* Check Ds Layout

* Readme and clang format

* Update branch & conflicts

* Multiple D - fix clang-formatter

* Rename elemetwise_op

* Fix CI

* Code review part1

* Remove printf

* Remove unnecessary comment

* Add new tests with Col layout

* Review part 2

* Added support for Multiple D GEMM

* Update comment

* Remove maybe_unused

* Clang-format

* Review part 3

* Add comment to function

* Add comment to function: another

* Take number of params for a reference function

* Remove additional d param for 0 tensor

* Change name of function

* Fix CI fails
Author: Mateusz Ozga
Date: 2025-06-13 19:39:11 +02:00
Committed by: GitHub
Parent: 3a0cb27966
Commit: bd96ac9742
34 changed files with 2267 additions and 285 deletions
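
This change set extends the CK Tile grouped GEMM example so that the epilogue can fuse additional "D" input tensors into the output through a CDE element-wise operator: instead of writing E = A·B, each output element becomes E[m,n] = cde_op(acc[m,n], D0[m,n], D1[m,n], ...). The snippet below is a minimal host-side sketch of that contract only; the names in it (ref_gemm_multiple_d, AddAdd) are invented for illustration and are not part of ck_tile or of this commit.

```cpp
#include <cstddef>
#include <vector>

// Hypothetical CDE element-wise operator for two D tensors: e = acc + d0 + d1.
struct AddAdd
{
    float operator()(float acc, float d0, float d1) const { return acc + d0 + d1; }
};

// Host-side reference of a "multiple D" GEMM with two row-major D tensors:
// E[m,n] = cde_op(sum_k A[m,k] * B[k,n], D0[m,n], D1[m,n]).
template <typename CDEOp>
void ref_gemm_multiple_d(const std::vector<float>& a,  // M x K
                         const std::vector<float>& b,  // K x N
                         const std::vector<float>& d0, // M x N
                         const std::vector<float>& d1, // M x N
                         std::vector<float>& e,        // M x N
                         std::size_t M, std::size_t N, std::size_t K,
                         CDEOp cde_op)
{
    for(std::size_t m = 0; m < M; ++m)
        for(std::size_t n = 0; n < N; ++n)
        {
            float acc = 0.f;
            for(std::size_t k = 0; k < K; ++k)
                acc += a[m * K + k] * b[k * N + n];
            e[m * N + n] = cde_op(acc, d0[m * N + n], d1[m * N + n]);
        }
}
```

With zero D tensors and a pass-through operator this degenerates to the plain grouped GEMM the example already supported, which is exactly how the diffs below instantiate the new path.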

@@ -1,6 +1,6 @@
 # Grouped CShuffle GEMM
-This folder contains example for Grouped GEMM using ck_tile tile-programming implementation. Currently, it only supports the basic feature of the CK Tile GEMM, but creates the placeholders for the future support on different GEMM pipeline and different GEMM modules. In the near future, we will gradually migrate all the GEMM features from old CK to CK Tile.
+This folder contains example for Grouped GEMM using ck_tile tile-programming implementation.
 ## build
 ```

@@ -16,7 +16,16 @@
 #include "ck_tile/host.hpp"
 #include "grouped_gemm.hpp"
-template <typename ALayout, typename BLayout, typename CLayout>
+template <typename ADataType,
+          typename BDataType,
+          typename DsDataType,
+          typename AccDataType,
+          typename CDataType,
+          typename ALayout,
+          typename BLayout,
+          typename DsLayout,
+          typename CLayout,
+          typename CDEElementWise = ck_tile::element_wise::PassThrough>
 float grouped_gemm(const std::vector<grouped_gemm_kargs>& gemm_descs,
                    const ck_tile::stream_config& s,
                    void* kargs_ptr)
@@ -130,9 +139,12 @@ float grouped_gemm(const std::vector<grouped_gemm_kargs>& gemm_descs,
     using GemmEpilogue = ck_tile::CShuffleEpilogue<
         ck_tile::CShuffleEpilogueProblem<ADataType,
                                          BDataType,
+                                         DsDataType,
                                          AccDataType,
                                          CDataType,
+                                         DsLayout,
                                          CLayout,
+                                         CDEElementWise,
                                          GemmPipelineProblem::kBlockSize,
                                          TilePartitioner::MPerBlock,
                                          TilePartitioner::NPerBlock,
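
In the expanded template list above, CDEElementWise defaults to ck_tile::element_wise::PassThrough, which forwards the accumulator unchanged; a user-supplied functor would combine the accumulator with the D tensors instead. Below is a hedged sketch of what such a functor could look like for a single bias-style D tensor. It is not part of this commit, and the exact operator signature the CShuffle epilogue expects should be checked against the ck_tile headers.

```cpp
// Hypothetical single-D operator: e = c + d (a bias add). Illustrative only; the
// argument order and reference style must match what ck_tile's epilogue actually calls.
struct AddBias
{
    template <typename E, typename C, typename D>
    void operator()(E& e, const C& c, const D& d) const
    {
        e = static_cast<E>(c + d);
    }
};
```

A functor like this would be supplied as the CDEElementWise template argument in place of PassThrough, with DsDataType and DsLayout describing the single D tensor.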

@@ -7,7 +7,8 @@
 #include "ck_tile/core.hpp"
 #include "ck_tile/host/kernel_launch.hpp"
-#include "ck_tile/ops/gemm/kernel/grouped_gemm_kernel.hpp"
+#include "ck_tile/ops/gemm.hpp"
+#include "ck_tile/ops/elementwise/unary_element_wise_operation.hpp"
 #define CK_TILE_PIPELINE_COMPUTE_V3 1
 #define CK_TILE_PIPELINE_MEMORY 2
@@ -53,7 +54,7 @@ using BDataType = Types::BDataType;
 using AccDataType = Types::AccDataType;
 using CDataType = Types::CDataType;
-using grouped_gemm_kargs = ck_tile::GemmHostArgs;
+using grouped_gemm_kargs = ck_tile::GemmHostArgs</*NumDTensor = 0*/>;
 auto create_args(int argc, char* argv[])
 {
@@ -82,7 +83,17 @@ inline std::size_t get_workspace_size(const std::vector<grouped_gemm_kargs>& gem
     return gemm_descs.size() * sizeof(ck_tile::GemmTransKernelArg);
 }
-template <typename ALayout, typename BLayout, typename CLayout>
+template <typename ADataType,
+          typename BDataType,
+          typename DsDataType,
+          typename AccDataType,
+          typename CDataType,
+          typename ALayout,
+          typename BLayout,
+          typename DsLayout,
+          typename CLayout,
+          bool Persistent,
+          typename CDEElementWise>
 float grouped_gemm(const std::vector<grouped_gemm_kargs>& gemm_descs,
                    const ck_tile::stream_config& s,
                    void* kargs_ptr);
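
GemmHostArgs is now parameterised on the number of D tensors (zero in this example), which is why the host descriptors and kernel arguments later in this change pass empty `{}` initialisers where the D pointers and D strides go. A rough mental model, as a hypothetical struct rather than the real ck_tile::GemmHostArgs definition, is an argument block whose D members are arrays sized by NumDTensor:

```cpp
#include <array>
#include <cstddef>
#include <cstdint>

// Hypothetical shape of a NumDTensor-parameterised host-args block. Member names and
// order mirror the initializer used in this example ({p_a, p_b, {}, p_c, kbatch, M, N,
// K, stride_A, stride_B, {}, stride_C}); the real ck_tile struct may differ.
template <std::size_t NumDTensor = 0>
struct HostArgsSketch
{
    const void* a_ptr;
    const void* b_ptr;
    std::array<const void*, NumDTensor> ds_ptr; // empty when NumDTensor == 0, so "{}" is valid
    void* e_ptr;
    std::int64_t k_batch;
    std::int64_t M, N, K;
    std::int64_t stride_A;
    std::int64_t stride_B;
    std::array<std::int64_t, NumDTensor> stride_Ds; // empty when NumDTensor == 0
    std::int64_t stride_E;
};
```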

@@ -30,7 +30,17 @@ auto calculate_rtol_atol(const ck_tile::index_t K,
     return ck_tile::make_tuple(std::max(rtol, rtol_split_k), std::max(atol, atol_split_k));
 }
-template <typename ALayout, typename BLayout, typename CLayout, bool Persistent>
+template <typename ADataType,
+          typename BDataType,
+          typename DsDataType,
+          typename AccDataType,
+          typename CDataType,
+          typename ALayout,
+          typename BLayout,
+          typename DsLayout,
+          typename CLayout,
+          bool Persistent,
+          typename CDEElementWise = ck_tile::element_wise::PassThrough>
 float invoke_gemm(int n_warmup,
                   int n_repeat,
                   int group_count,
@@ -44,7 +54,16 @@ float invoke_gemm(int n_warmup,
     if constexpr(!Persistent)
     {
         // Regular version of grouped gemm
-        ave_time = grouped_gemm<ALayout, BLayout, CLayout>(
+        ave_time = grouped_gemm<ADataType,
+                                BDataType,
+                                DsDataType,
+                                AccDataType,
+                                CDataType,
+                                ALayout,
+                                BLayout,
+                                DsLayout,
+                                CLayout,
+                                CDEElementWise>(
             args,
             ck_tile::stream_config{nullptr, true, 1, n_warmup, n_repeat},
             gemm_workspace.GetDeviceBuffer());
@@ -64,16 +83,18 @@ float invoke_gemm(int n_warmup,
         const bool splitk = args[0].k_batch > 1;
         for(const auto& arg : args)
         {
-            kargs.emplace_back(ck_tile::GemmKernelArgs{arg.a_ptr,
-                                                       arg.b_ptr,
-                                                       arg.c_ptr,
-                                                       arg.M,
-                                                       arg.N,
-                                                       arg.K,
-                                                       arg.stride_A,
-                                                       arg.stride_B,
-                                                       arg.stride_C,
-                                                       arg.k_batch});
+            kargs.emplace_back(ck_tile::GemmKernelArgs<>{arg.a_ptr,
+                                                         arg.b_ptr,
+                                                         {},
+                                                         arg.e_ptr,
+                                                         arg.M,
+                                                         arg.N,
+                                                         arg.K,
+                                                         arg.stride_A,
+                                                         arg.stride_B,
+                                                         {},
+                                                         arg.stride_E,
+                                                         arg.k_batch});
         }
         const auto stream = ck_tile::stream_config{nullptr, true, 1, n_warmup, n_repeat};
         HIP_CHECK_ERROR(hipMemcpyWithStream(kargs_ptr,
@@ -219,10 +240,19 @@ int run_grouped_gemm_example_with_layouts(int argc,
         void* p_c = c_m_n_dev_buf[i]->GetDeviceBuffer();
         gemm_descs.push_back(
-            {p_a, p_b, p_c, kbatch, M, N, K, stride_As[i], stride_Bs[i], stride_Cs[i]});
+            {p_a, p_b, {}, p_c, kbatch, M, N, K, stride_As[i], stride_Bs[i], {}, stride_Cs[i]});
     }
-    invoke_gemm<ALayout, BLayout, CLayout, Persistent>(warmup, repeat, group_count, gemm_descs);
+    invoke_gemm<ADataType,
+                BDataType,
+                ck_tile::tuple<>,
+                AccDataType,
+                CDataType,
+                ALayout,
+                BLayout,
+                ck_tile::tuple<>,
+                CLayout,
+                Persistent>(warmup, repeat, group_count, gemm_descs);
     for(int i = 0; i < group_count; i++)
     {
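
The example instantiates invoke_gemm with ck_tile::tuple<> for both DsDataType and DsLayout, i.e. zero D tensors, so the new code path reproduces the previous behaviour. The number of D tensors is a compile-time property of that type list; a self-contained sketch of the pattern (using std::tuple as a stand-in for ck_tile::tuple, with an invented struct name) is:

```cpp
#include <array>
#include <cstddef>
#include <tuple>

// The D-tensor count falls out of the type list, so an empty tuple means "no extra
// epilogue inputs" and the per-D argument arrays collapse to size zero.
template <typename DsDataType>
struct EpilogueArgsSketch
{
    static constexpr std::size_t kNumD = std::tuple_size<DsDataType>::value;
    std::array<const void*, kNumD> ds_ptr{}; // one pointer per D tensor
};

static_assert(EpilogueArgsSketch<std::tuple<>>::kNumD == 0, "zero D tensors");
static_assert(EpilogueArgsSketch<std::tuple<float>>::kNumD == 1, "one D tensor");
```

Switching the example to one D tensor would mean passing a one-element ck_tile::tuple for DsDataType and DsLayout in the invoke_gemm call above and supplying matching D pointers and strides in place of the empty `{}` initialisers.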