mirror of
https://github.com/ROCm/composable_kernel.git
synced 2026-05-11 00:40:09 +00:00
[CK_TILE] Multiple-D GEMM example (#2219)
* Multiple d, initial commit * Check Ds Layout * Readme and clang format * Update branch & conflicts * Multiple D - fix clang-formatter * Rename elemetwise_op * Fix CI * Code review part1 * Remove printf * Remove unnecessary comment * Add new tests with Col layout * Review part 2 * Added support for Multiple D GEMM * Update comment * Remove maybe_unused * Clang-format * Review part 3 * Add comment to function * Add comment to function: another * Take number of params for a refrence function * Remove additional d param for 0 tensor * Change name of function * Fix CI fails
This commit is contained in:
@@ -15,7 +15,16 @@
|
||||
#include "ck_tile/host.hpp"
|
||||
#include "batched_gemm.hpp"
|
||||
|
||||
template <typename ALayout, typename BLayout, typename CLayout>
|
||||
template <typename ADataType,
|
||||
typename BDataType,
|
||||
typename DsDataType,
|
||||
typename AccDataType,
|
||||
typename CDataType,
|
||||
typename ALayout,
|
||||
typename BLayout,
|
||||
typename DsLayout,
|
||||
typename CLayout,
|
||||
typename CDEElementWise = ck_tile::element_wise::PassThrough>
|
||||
float batched_gemm(const ck_tile::BatchedGemmHostArgs& args, const ck_tile::stream_config& s)
|
||||
{
|
||||
#if(CK_TILE_PIPELINE_DEFAULT == CK_TILE_PIPELINE_MEMORY)
|
||||
@@ -123,12 +132,16 @@ float batched_gemm(const ck_tile::BatchedGemmHostArgs& args, const ck_tile::stre
|
||||
tail_number_v>;
|
||||
|
||||
using GemmPipeline = GEMM_PIPELINE<UniversalGemmProblem>;
|
||||
|
||||
using GemmEpilogue = ck_tile::CShuffleEpilogue<
|
||||
ck_tile::CShuffleEpilogueProblem<ADataType,
|
||||
BDataType,
|
||||
DsDataType,
|
||||
AccDataType,
|
||||
CDataType,
|
||||
DsLayout,
|
||||
CLayout,
|
||||
CDEElementWise,
|
||||
GemmPipelineProblem::kBlockSize,
|
||||
TilePartitioner::MPerBlock,
|
||||
TilePartitioner::NPerBlock,
|
||||
@@ -139,6 +152,7 @@ float batched_gemm(const ck_tile::BatchedGemmHostArgs& args, const ck_tile::stre
|
||||
K_Warp_Tile,
|
||||
UniversalGemmProblem::TransposeC,
|
||||
memory_operation>>;
|
||||
|
||||
using Kernel = ck_tile::BatchedGemmKernel<TilePartitioner, GemmPipeline, GemmEpilogue>;
|
||||
auto kargs = Kernel::MakeKernelArgs(args);
|
||||
|
||||
|
||||
@@ -8,6 +8,7 @@
|
||||
#include "ck_tile/core.hpp"
|
||||
#include "ck_tile/host/kernel_launch.hpp"
|
||||
#include "ck_tile/ops/gemm/kernel/batched_gemm_kernel.hpp"
|
||||
#include "ck_tile/ops/elementwise/unary_element_wise_operation.hpp"
|
||||
|
||||
#define CK_TILE_PIPELINE_COMPUTE_V3 1
|
||||
#define CK_TILE_PIPELINE_MEMORY 2
|
||||
|
||||
@@ -23,7 +23,16 @@ auto calculate_rtol_atol(const ck_tile::index_t K,
|
||||
return ck_tile::make_tuple(std::max(rtol, rtol_split_k), std::max(atol, atol_split_k));
|
||||
}
|
||||
|
||||
template <typename ALayout, typename BLayout, typename CLayout>
|
||||
template <typename ADataType,
|
||||
typename BDataType,
|
||||
typename DsDataType,
|
||||
typename AccDataType,
|
||||
typename CDataType,
|
||||
typename ALayout,
|
||||
typename BLayout,
|
||||
typename DsLayout,
|
||||
typename CLayout,
|
||||
typename CDEElementWise = ck_tile::element_wise::PassThrough>
|
||||
float invoke_batched_gemm(ck_tile::DeviceMem& a_m_k_dev_buf,
|
||||
ck_tile::DeviceMem& b_k_n_dev_buf,
|
||||
ck_tile::DeviceMem& c_m_n_dev_buf,
|
||||
@@ -44,20 +53,29 @@ float invoke_batched_gemm(ck_tile::DeviceMem& a_m_k_dev_buf,
|
||||
ck_tile::BatchedGemmHostArgs args;
|
||||
args.a_ptr = a_m_k_dev_buf.GetDeviceBuffer();
|
||||
args.b_ptr = b_k_n_dev_buf.GetDeviceBuffer();
|
||||
args.c_ptr = c_m_n_dev_buf.GetDeviceBuffer();
|
||||
args.e_ptr = c_m_n_dev_buf.GetDeviceBuffer();
|
||||
args.k_batch = kbatch;
|
||||
args.M = M;
|
||||
args.N = N;
|
||||
args.K = K;
|
||||
args.stride_A = stride_A;
|
||||
args.stride_B = stride_B;
|
||||
args.stride_C = stride_C;
|
||||
args.stride_E = stride_C;
|
||||
args.batch_stride_A = batch_stride_A;
|
||||
args.batch_stride_B = batch_stride_B;
|
||||
args.batch_stride_C = batch_stride_C;
|
||||
args.batch_stride_E = batch_stride_C;
|
||||
args.batch_count = batch_count;
|
||||
|
||||
float ave_time = batched_gemm<ALayout, BLayout, CLayout>(
|
||||
float ave_time = batched_gemm<ADataType,
|
||||
BDataType,
|
||||
DsDataType,
|
||||
AccDataType,
|
||||
CDataType,
|
||||
ALayout,
|
||||
BLayout,
|
||||
DsLayout,
|
||||
CLayout,
|
||||
CDEElementWise>(
|
||||
args, ck_tile::stream_config{nullptr, true, 1, n_warmup, n_repeat});
|
||||
|
||||
std::string op_name{"Batched Gemm"};
|
||||
@@ -169,22 +187,30 @@ int run_batched_gemm_example_with_layouts(int argc,
|
||||
c_m_n_dev_buf.SetZero();
|
||||
c_m_n_dev_result.SetZero();
|
||||
|
||||
invoke_batched_gemm<ALayout, BLayout, CLayout>(a_m_k_dev_buf,
|
||||
b_k_n_dev_buf,
|
||||
c_m_n_dev_buf,
|
||||
M,
|
||||
N,
|
||||
K,
|
||||
stride_A,
|
||||
stride_B,
|
||||
stride_C,
|
||||
batch_stride_A,
|
||||
batch_stride_B,
|
||||
batch_stride_C,
|
||||
batch_count,
|
||||
kbatch,
|
||||
n_warmup,
|
||||
n_repeat);
|
||||
invoke_batched_gemm<ADataType,
|
||||
BDataType,
|
||||
ck_tile::tuple<>,
|
||||
AccDataType,
|
||||
CDataType,
|
||||
ALayout,
|
||||
BLayout,
|
||||
ck_tile::tuple<>,
|
||||
CLayout>(a_m_k_dev_buf,
|
||||
b_k_n_dev_buf,
|
||||
c_m_n_dev_buf,
|
||||
M,
|
||||
N,
|
||||
K,
|
||||
stride_A,
|
||||
stride_B,
|
||||
stride_C,
|
||||
batch_stride_A,
|
||||
batch_stride_B,
|
||||
batch_stride_C,
|
||||
batch_count,
|
||||
kbatch,
|
||||
n_warmup,
|
||||
n_repeat);
|
||||
c_m_n_dev_buf.FromDevice(c_m_n_dev_result.data());
|
||||
bool pass = true;
|
||||
|
||||
|
||||
Reference in New Issue
Block a user