[CK_TILE] Multiple-D GEMM example (#2219)

* Multiple d, initial commit

* Check Ds Layout

* Readme and clang format

* Update branch & conflicts

* Multiple D - fix clang-formatter

* Rename elemetwise_op

* Fix CI

* Code review part1

* Remove printf

* Remove unnecessary comment

* Add new tests with Col layout

* Review part 2

* Added support for Multiple D GEMM

* Update comment

* Remove maybe_unused

* Clang-format

* Review part 3

* Add comment to function

* Add comment to function: another

* Take number of params for a refrence function

* Remove additional d param for 0 tensor

* Change name of function

* Fix CI fails
This commit is contained in:
Mateusz Ozga
2025-06-13 19:39:11 +02:00
committed by GitHub
parent 3a0cb27966
commit bd96ac9742
34 changed files with 2267 additions and 285 deletions

View File

@@ -11,6 +11,7 @@
#include "ck_tile/ops/epilogue.hpp"
#include "ck_tile/ops/gemm.hpp"
#include "ck_tile/ops/gemm/kernel/grouped_gemm_kernel.hpp"
#include "ck_tile/ops/elementwise/unary_element_wise_operation.hpp"
template <typename Tuple>
class TestCkTileGroupedGemm : public ::testing::Test
@@ -23,6 +24,8 @@ class TestCkTileGroupedGemm : public ::testing::Test
using BDataType = std::tuple_element_t<4, Tuple>;
using AccDataType = std::tuple_element_t<5, Tuple>;
using CDataType = std::tuple_element_t<6, Tuple>;
using DsLayout = ck_tile::tuple<>;
using DsDataType = ck_tile::tuple<>;
// Get the persistent value from ck_tile::bool_constant
using PersistentType = std::tuple_element_t<7, Tuple>;
@@ -48,7 +51,7 @@ class TestCkTileGroupedGemm : public ::testing::Test
static const ck_tile::index_t K_Warp_Tile = 16;
};
using grouped_gemm_kargs = ck_tile::GemmHostArgs;
using grouped_gemm_kargs = ck_tile::GemmHostArgs</*NumDTensor = 0*/>;
std::size_t get_workspace_size(const std::vector<grouped_gemm_kargs>& gemm_descs)
{
return gemm_descs.size() * sizeof(ck_tile::GemmTransKernelArg);
@@ -127,9 +130,12 @@ class TestCkTileGroupedGemm : public ::testing::Test
using GemmEpilogue = ck_tile::CShuffleEpilogue<
ck_tile::CShuffleEpilogueProblem<ADataType,
BDataType,
DsDataType,
AccDataType,
CDataType,
DsLayout,
CLayout,
ck_tile::element_wise::PassThrough,
GemmPipelineProblem::kBlockSize,
TilePartitioner::MPerBlock,
TilePartitioner::NPerBlock,
@@ -256,9 +262,12 @@ class TestCkTileGroupedGemm : public ::testing::Test
using GemmEpilogue = ck_tile::CShuffleEpilogue<
ck_tile::CShuffleEpilogueProblem<ADataType,
BDataType,
DsDataType,
AccDataType,
CDataType,
DsLayout,
CLayout,
ck_tile::element_wise::PassThrough,
GemmPipelineProblem::kBlockSize,
TilePartitioner::MPerBlock,
TilePartitioner::NPerBlock,
@@ -428,7 +437,7 @@ class TestCkTileGroupedGemm : public ::testing::Test
void* p_c = c_m_n_dev_buf[i]->GetDeviceBuffer();
gemm_descs.push_back(
{p_a, p_b, p_c, kbatch, M, N, K, stride_As[i], stride_Bs[i], stride_Cs[i]});
{p_a, p_b, {}, p_c, kbatch, M, N, K, stride_As[i], stride_Bs[i], {}, stride_Cs[i]});
}
ck_tile::DeviceMem gemm_workspace;
@@ -442,16 +451,18 @@ class TestCkTileGroupedGemm : public ::testing::Test
const bool splitk = gemm_descs[0].k_batch > 1;
for(const auto& arg : gemm_descs)
{
kargs.emplace_back(ck_tile::GemmKernelArgs{arg.a_ptr,
arg.b_ptr,
arg.c_ptr,
arg.M,
arg.N,
arg.K,
arg.stride_A,
arg.stride_B,
arg.stride_C,
arg.k_batch});
kargs.emplace_back(ck_tile::GemmKernelArgs<>{arg.a_ptr,
arg.b_ptr,
{},
arg.e_ptr,
arg.M,
arg.N,
arg.K,
arg.stride_A,
arg.stride_B,
{},
arg.stride_E,
arg.k_batch});
}
const auto stream = ck_tile::stream_config{nullptr, false, 1};
ck_tile::hip_check_error(