mirror of
https://github.com/ROCm/composable_kernel.git
synced 2026-05-26 08:00:13 +00:00
[CK_TILE] Multiple-D GEMM example (#2219)
* Multiple d, initial commit
* Check Ds Layout
* Readme and clang format
* Update branch & conflicts
* Multiple D - fix clang-formatter
* Rename elemetwise_op
* Fix CI
* Code review part1
* Remove printf
* Remove unnecessary comment
* Add new tests with Col layout
* Review part 2
* Added support for Multiple D GEMM
* Update comment
* Remove maybe_unused
* Clang-format
* Review part 3
* Add comment to function
* Add comment to function: another
* Take number of params for a refrence function
* Remove additional d param for 0 tensor
* Change name of function
* Fix CI fails
[ROCm/composable_kernel commit: bd96ac9742]
This commit is contained in:
@@ -76,12 +76,17 @@ class TestCkTileGemmPipeline : public ::testing::Test
|
||||
using CDataType = std::tuple_element_t<6, Tuple>;
|
||||
static constexpr auto Scheduler = std::tuple_element_t<7, Tuple>::value;
|
||||
static constexpr auto PipelineType = std::tuple_element_t<8, Tuple>::value;
|
||||
|
||||
using DsLayout = ck_tile::tuple<>;
|
||||
using DsDataType = ck_tile::tuple<>;
|
||||
|
||||
static constexpr bool Persistent =
|
||||
ck_tile::tuple_element_or_default_t<Tuple, 9, std::false_type>::value;
|
||||
// TODO: expose tile size through test t-param ?
|
||||
|
||||
template <bool PadM, bool PadN, bool PadK>
|
||||
void invoke_gemm(const ck_tile::GemmHostArgs& args, const ck_tile::stream_config& s)
|
||||
void invoke_gemm(const ck_tile::GemmHostArgs</*NumDTensor = 0*/>& args,
|
||||
const ck_tile::stream_config& s)
|
||||
{
|
||||
// TODO: This should be parameterized in tests
|
||||
constexpr ck_tile::index_t M_Tile = 256;
|
||||
@@ -165,9 +170,12 @@ class TestCkTileGemmPipeline : public ::testing::Test
|
||||
using GemmEpilogue = ck_tile::CShuffleEpilogue<
|
||||
ck_tile::CShuffleEpilogueProblem<ADataType,
|
||||
BDataType,
|
||||
DsDataType,
|
||||
AccDataType,
|
||||
CDataType,
|
||||
DsLayout,
|
||||
CLayout,
|
||||
ck_tile::element_wise::PassThrough,
|
||||
GemmPipeline::BlockSize,
|
||||
TilePartitioner::MPerBlock,
|
||||
TilePartitioner::NPerBlock,
|
||||
@@ -326,17 +334,17 @@ class TestCkTileGemmPipeline : public ::testing::Test
|
||||
c_m_n_dev_buf.SetZero();
|
||||
c_m_n_dev_result.SetZero();
|
||||
|
||||
ck_tile::GemmHostArgs args;
|
||||
ck_tile::GemmHostArgs</*NumDTensor = 0*/> args;
|
||||
args.a_ptr = a_m_k_dev_buf.GetDeviceBuffer();
|
||||
args.b_ptr = b_k_n_dev_buf.GetDeviceBuffer();
|
||||
args.c_ptr = c_m_n_dev_buf.GetDeviceBuffer();
|
||||
args.e_ptr = c_m_n_dev_buf.GetDeviceBuffer();
|
||||
args.k_batch = kbatch;
|
||||
args.M = M;
|
||||
args.N = N;
|
||||
args.K = K;
|
||||
args.stride_A = stride_A;
|
||||
args.stride_B = stride_B;
|
||||
args.stride_C = stride_C;
|
||||
args.stride_E = stride_C;
|
||||
|
||||
invoke_gemm<PadM, PadN, PadK>(args, ck_tile::stream_config{nullptr, false});
|
||||
|
||||
|
||||
Reference in New Issue
Block a user