[CK_TILE] Multiple-D GEMM example (#2219)

* Multiple d, initial commit

* Check Ds Layout

* Readme and clang format

* Update branch & conflicts

* Multiple D - fix clang-formatter

* Rename elemetwise_op

* Fix CI

* Code review part1

* Remove printf

* Remove unnecessary comment

* Add new tests with Col layout

* Review part 2

* Added support for Multiple D GEMM

* Update comment

* Remove maybe_unused

* Clang-format

* Review part 3

* Add comment to function

* Add comment to function: another

* Take number of params for a refrence function

* Remove additional d param for 0 tensor

* Change name of function

* Fix CI fails
This commit is contained in:
Mateusz Ozga
2025-06-13 19:39:11 +02:00
committed by GitHub
parent 3a0cb27966
commit bd96ac9742
34 changed files with 2267 additions and 285 deletions

View File

@@ -11,6 +11,7 @@
#include "ck_tile/ops/epilogue.hpp"
#include "ck_tile/ops/gemm.hpp"
#include "ck_tile/ops/gemm/kernel/batched_gemm_kernel.hpp"
#include "ck_tile/ops/elementwise/unary_element_wise_operation.hpp"
template <typename Tuple>
class TestCkTileBatchedGemm : public ::testing::Test
@@ -23,6 +24,8 @@ class TestCkTileBatchedGemm : public ::testing::Test
using BDataType = std::tuple_element_t<4, Tuple>;
using AccDataType = std::tuple_element_t<5, Tuple>;
using CDataType = std::tuple_element_t<6, Tuple>;
using DsLayout = ck_tile::tuple<>;
using DsDataType = ck_tile::tuple<>;
template <typename ALayout, typename BLayout, typename CLayout>
void invoke_batched_gemm(const ck_tile::BatchedGemmHostArgs& args,
@@ -102,9 +105,12 @@ class TestCkTileBatchedGemm : public ::testing::Test
using GemmEpilogue = ck_tile::CShuffleEpilogue<
ck_tile::CShuffleEpilogueProblem<ADataType,
BDataType,
DsDataType,
AccDataType,
CDataType,
DsLayout,
CLayout,
ck_tile::element_wise::PassThrough,
GemmPipelineProblem::kBlockSize,
TilePartitioner::MPerBlock,
TilePartitioner::NPerBlock,
@@ -239,17 +245,17 @@ class TestCkTileBatchedGemm : public ::testing::Test
ck_tile::BatchedGemmHostArgs args;
args.a_ptr = a_m_k_dev_buf.GetDeviceBuffer();
args.b_ptr = b_k_n_dev_buf.GetDeviceBuffer();
args.c_ptr = c_m_n_dev_buf.GetDeviceBuffer();
args.e_ptr = c_m_n_dev_buf.GetDeviceBuffer();
args.k_batch = 1;
args.M = M;
args.N = N;
args.K = K;
args.stride_A = StrideA;
args.stride_B = StrideB;
args.stride_C = StrideC;
args.stride_E = StrideC;
args.batch_stride_A = BatchStrideA;
args.batch_stride_B = BatchStrideB;
args.batch_stride_C = BatchStrideC;
args.batch_stride_E = BatchStrideC;
args.batch_count = BatchCount;
invoke_batched_gemm<ALayout, BLayout, CLayout>(args,