mirror of
https://github.com/ROCm/composable_kernel.git
synced 2026-05-04 21:51:28 +00:00
Merge flatmm Operator with universal gemm (#2434)
* Initial commit * Adding new tile partitioner to flatmm * intermediate changes * debugging kernels * Updating flatmm example to universal gemm example * updated flatmm kernel to run via gemmKernel * update universal gemm to incorporate flatmm * debug * Fix flatmm call * Fixing other kernels and tests for API changes * clang formatted * fixing gemm tests * added test for flatmm and simplify kernel arguments * adding flatmm test * fix test for flatmm * simplify gemm kernel with flatmm * remove flatmm related files * addressing review comments and code clean up * resolving empty file * resolving empty file * clang formatted * addressing review comments * enable persistent kernel for flatmm * reverted the removed files for flatmm * reverted the removed files for flatmm * changed flatmm to weightPReshuffle; removed the _1 added in teh faltmm example * some more renames * clang formatted
This commit is contained in:
@@ -69,14 +69,31 @@ auto calculate_rtol_atol(const ck_tile::index_t K,
|
||||
return ck_tile::make_tuple(std::max(rtol, rtol_split_k), std::max(atol, atol_split_k));
|
||||
}
|
||||
|
||||
template <typename ADataType,
|
||||
template <typename FlatmmConfig,
|
||||
typename ADataType,
|
||||
typename BDataType,
|
||||
typename DsDatatype,
|
||||
typename AccDataType,
|
||||
typename CDataType,
|
||||
typename FlatmmConfig,
|
||||
typename ALayout,
|
||||
typename BLayout,
|
||||
typename CLayout>
|
||||
typename DsLayout,
|
||||
typename ELayout,
|
||||
bool persistent,
|
||||
typename CDEElementWise>
|
||||
float flatmm_calc(const ck_tile::FlatmmHostArgs<>& args, const ck_tile::stream_config& s);
|
||||
|
||||
template <typename FlatmmConfig,
|
||||
typename ADataType,
|
||||
typename BDataType,
|
||||
typename DsDatatype,
|
||||
typename AccDataType,
|
||||
typename CDataType,
|
||||
typename ALayout,
|
||||
typename BLayout,
|
||||
typename DsLayout,
|
||||
typename CLayout,
|
||||
typename CDEElementWise = ck_tile::element_wise::PassThrough>
|
||||
float invoke_flatmm(ck_tile::DeviceMem& a_dev_buf,
|
||||
ck_tile::DeviceMem& b_shuffle_dev_buf,
|
||||
ck_tile::DeviceMem& c_dev_buf,
|
||||
@@ -90,27 +107,31 @@ float invoke_flatmm(ck_tile::DeviceMem& a_dev_buf,
|
||||
int n_warmup,
|
||||
int n_repeat)
|
||||
{
|
||||
ck_tile::FlatmmHostArgs args;
|
||||
args.a_ptr = a_dev_buf.GetDeviceBuffer();
|
||||
args.b_shuffle_ptr = b_shuffle_dev_buf.GetDeviceBuffer();
|
||||
args.c_ptr = c_dev_buf.GetDeviceBuffer();
|
||||
ck_tile::FlatmmHostArgs<> args = {a_dev_buf.GetDeviceBuffer(),
|
||||
b_shuffle_dev_buf.GetDeviceBuffer(),
|
||||
{},
|
||||
c_dev_buf.GetDeviceBuffer(),
|
||||
kbatch,
|
||||
M,
|
||||
N,
|
||||
K,
|
||||
stride_A,
|
||||
stride_B,
|
||||
{},
|
||||
stride_C};
|
||||
|
||||
args.k_batch = kbatch;
|
||||
args.M = M;
|
||||
args.N = N;
|
||||
args.K = K;
|
||||
args.stride_A = stride_A;
|
||||
args.stride_B = stride_B;
|
||||
args.stride_C = stride_C;
|
||||
|
||||
float ave_time = flatmm_calc<ADataType,
|
||||
float ave_time = flatmm_calc<FlatmmConfig,
|
||||
ADataType,
|
||||
BDataType,
|
||||
DsDatatype,
|
||||
AccDataType,
|
||||
CDataType,
|
||||
FlatmmConfig,
|
||||
ALayout,
|
||||
BLayout,
|
||||
CLayout>(
|
||||
DsLayout,
|
||||
CLayout,
|
||||
false,
|
||||
CDEElementWise>(
|
||||
args, ck_tile::stream_config{nullptr, true, 1, n_warmup, n_repeat, true, true, 50});
|
||||
|
||||
std::size_t flop = std::size_t(2) * M * N * K;
|
||||
@@ -159,6 +180,7 @@ int run_flatmm_example_with_layouts(int argc,
|
||||
int n_warmup = arg_parser.get_int("warmup");
|
||||
int n_repeat = arg_parser.get_int("repeat");
|
||||
ck_tile::index_t init_method = arg_parser.get_int("init");
|
||||
// persistent not added
|
||||
|
||||
stride_A = ck_tile::get_default_stride(M, K, stride_A, is_row_major(a_layout));
|
||||
stride_B = ck_tile::get_default_stride(K, N, stride_B, is_row_major(b_layout));
|
||||
@@ -204,13 +226,15 @@ int run_flatmm_example_with_layouts(int argc,
|
||||
ck_tile::DeviceMem b_shuffle_dev_buf(b_shuffle_host.get_element_space_size_in_bytes());
|
||||
b_shuffle_dev_buf.ToDevice(b_shuffle_host.data());
|
||||
|
||||
invoke_flatmm<ADataType,
|
||||
invoke_flatmm<FlatmmConfig,
|
||||
ADataType,
|
||||
BDataType,
|
||||
ck_tile::tuple<>,
|
||||
AccDataType,
|
||||
CDataType,
|
||||
FlatmmConfig,
|
||||
ALayout,
|
||||
BLayout,
|
||||
ck_tile::tuple<>,
|
||||
CLayout>(a_dev_buf,
|
||||
b_shuffle_dev_buf,
|
||||
c_dev_buf,
|
||||
|
||||
Reference in New Issue
Block a user