mirror of
https://github.com/ROCm/composable_kernel.git
synced 2026-05-03 21:21:22 +00:00
[CK_TILE] Multiple-D GEMM example (#2219)
* Multiple d, initial commit * Check Ds Layout * Readme and clang format * Update branch & conflicts * Multiple D - fix clang-formatter * Rename elemetwise_op * Fix CI * Code review part1 * Remove printf * Remove unnecessary comment * Add new tests with Col layout * Review part 2 * Added support for Multiple D GEMM * Update comment * Remove maybe_unused * Clang-format * Review part 3 * Add comment to function * Add comment to function: another * Take number of params for a reference function * Remove additional d param for 0 tensor * Change name of function * Fix CI fails
This commit is contained in:
@@ -11,6 +11,7 @@
|
||||
#include "ck_tile/ops/epilogue.hpp"
|
||||
#include "ck_tile/ops/gemm.hpp"
|
||||
#include "ck_tile/ops/gemm/kernel/grouped_gemm_kernel.hpp"
|
||||
#include "ck_tile/ops/elementwise/unary_element_wise_operation.hpp"
|
||||
|
||||
template <typename Tuple>
|
||||
class TestCkTileGroupedGemm : public ::testing::Test
|
||||
@@ -23,6 +24,8 @@ class TestCkTileGroupedGemm : public ::testing::Test
|
||||
using BDataType = std::tuple_element_t<4, Tuple>;
|
||||
using AccDataType = std::tuple_element_t<5, Tuple>;
|
||||
using CDataType = std::tuple_element_t<6, Tuple>;
|
||||
using DsLayout = ck_tile::tuple<>;
|
||||
using DsDataType = ck_tile::tuple<>;
|
||||
|
||||
// Get the persistent value from ck_tile::bool_constant
|
||||
using PersistentType = std::tuple_element_t<7, Tuple>;
|
||||
@@ -48,7 +51,7 @@ class TestCkTileGroupedGemm : public ::testing::Test
|
||||
static const ck_tile::index_t K_Warp_Tile = 16;
|
||||
};
|
||||
|
||||
using grouped_gemm_kargs = ck_tile::GemmHostArgs;
|
||||
using grouped_gemm_kargs = ck_tile::GemmHostArgs</*NumDTensor = 0*/>;
|
||||
std::size_t get_workspace_size(const std::vector<grouped_gemm_kargs>& gemm_descs)
|
||||
{
|
||||
return gemm_descs.size() * sizeof(ck_tile::GemmTransKernelArg);
|
||||
@@ -127,9 +130,12 @@ class TestCkTileGroupedGemm : public ::testing::Test
|
||||
using GemmEpilogue = ck_tile::CShuffleEpilogue<
|
||||
ck_tile::CShuffleEpilogueProblem<ADataType,
|
||||
BDataType,
|
||||
DsDataType,
|
||||
AccDataType,
|
||||
CDataType,
|
||||
DsLayout,
|
||||
CLayout,
|
||||
ck_tile::element_wise::PassThrough,
|
||||
GemmPipelineProblem::kBlockSize,
|
||||
TilePartitioner::MPerBlock,
|
||||
TilePartitioner::NPerBlock,
|
||||
@@ -256,9 +262,12 @@ class TestCkTileGroupedGemm : public ::testing::Test
|
||||
using GemmEpilogue = ck_tile::CShuffleEpilogue<
|
||||
ck_tile::CShuffleEpilogueProblem<ADataType,
|
||||
BDataType,
|
||||
DsDataType,
|
||||
AccDataType,
|
||||
CDataType,
|
||||
DsLayout,
|
||||
CLayout,
|
||||
ck_tile::element_wise::PassThrough,
|
||||
GemmPipelineProblem::kBlockSize,
|
||||
TilePartitioner::MPerBlock,
|
||||
TilePartitioner::NPerBlock,
|
||||
@@ -428,7 +437,7 @@ class TestCkTileGroupedGemm : public ::testing::Test
|
||||
void* p_c = c_m_n_dev_buf[i]->GetDeviceBuffer();
|
||||
|
||||
gemm_descs.push_back(
|
||||
{p_a, p_b, p_c, kbatch, M, N, K, stride_As[i], stride_Bs[i], stride_Cs[i]});
|
||||
{p_a, p_b, {}, p_c, kbatch, M, N, K, stride_As[i], stride_Bs[i], {}, stride_Cs[i]});
|
||||
}
|
||||
|
||||
ck_tile::DeviceMem gemm_workspace;
|
||||
@@ -442,16 +451,18 @@ class TestCkTileGroupedGemm : public ::testing::Test
|
||||
const bool splitk = gemm_descs[0].k_batch > 1;
|
||||
for(const auto& arg : gemm_descs)
|
||||
{
|
||||
kargs.emplace_back(ck_tile::GemmKernelArgs{arg.a_ptr,
|
||||
arg.b_ptr,
|
||||
arg.c_ptr,
|
||||
arg.M,
|
||||
arg.N,
|
||||
arg.K,
|
||||
arg.stride_A,
|
||||
arg.stride_B,
|
||||
arg.stride_C,
|
||||
arg.k_batch});
|
||||
kargs.emplace_back(ck_tile::GemmKernelArgs<>{arg.a_ptr,
|
||||
arg.b_ptr,
|
||||
{},
|
||||
arg.e_ptr,
|
||||
arg.M,
|
||||
arg.N,
|
||||
arg.K,
|
||||
arg.stride_A,
|
||||
arg.stride_B,
|
||||
{},
|
||||
arg.stride_E,
|
||||
arg.k_batch});
|
||||
}
|
||||
const auto stream = ck_tile::stream_config{nullptr, false, 1};
|
||||
ck_tile::hip_check_error(
|
||||
|
||||
Reference in New Issue
Block a user