mirror of
https://github.com/ROCm/composable_kernel.git
synced 2026-07-02 04:37:02 +00:00
Batched GEMM - pass batch args
This commit is contained in:
@@ -14,20 +14,28 @@ float invoke_gemm(ck_tile::DeviceMem& a_m_k_dev_buf,
|
||||
ck_tile::index_t stride_B,
|
||||
ck_tile::index_t stride_C,
|
||||
ck_tile::index_t kbatch,
|
||||
ck_tile::index_t batch_stride_A,
|
||||
ck_tile::index_t batch_stride_B,
|
||||
ck_tile::index_t batch_stride_C,
|
||||
ck_tile::index_t batch_count,
|
||||
int n_warmup,
|
||||
int n_repeat)
|
||||
{
|
||||
batched_gemm_basic_args args;
|
||||
args.p_a = a_m_k_dev_buf.GetDeviceBuffer();
|
||||
args.p_b = b_k_n_dev_buf.GetDeviceBuffer();
|
||||
args.p_c = c_m_n_dev_buf.GetDeviceBuffer();
|
||||
args.kbatch = kbatch;
|
||||
args.M = M;
|
||||
args.N = N;
|
||||
args.K = K;
|
||||
args.stride_A = stride_A;
|
||||
args.stride_B = stride_B;
|
||||
args.stride_C = stride_C;
|
||||
args.p_a = a_m_k_dev_buf.GetDeviceBuffer();
|
||||
args.p_b = b_k_n_dev_buf.GetDeviceBuffer();
|
||||
args.p_c = c_m_n_dev_buf.GetDeviceBuffer();
|
||||
args.kbatch = kbatch;
|
||||
args.M = M;
|
||||
args.N = N;
|
||||
args.K = K;
|
||||
args.stride_A = stride_A;
|
||||
args.stride_B = stride_B;
|
||||
args.stride_C = stride_C;
|
||||
args.batch_stride_A = batch_stride_A;
|
||||
args.batch_stride_B = batch_stride_B;
|
||||
args.batch_stride_C = batch_stride_C;
|
||||
args.batch_count = batch_count;
|
||||
|
||||
float ave_time = gemm_calc<ALayout, BLayout, CLayout>(
|
||||
args, ck_tile::stream_config{nullptr, true, 1, n_warmup, n_repeat});
|
||||
@@ -63,8 +71,20 @@ int run_batched_gemm_example(int argc, char* argv[])
|
||||
ck_tile::index_t stride_C = arg_parser.get_int("stride_c");
|
||||
|
||||
ck_tile::index_t batch_size = arg_parser.get_int("b");
|
||||
int n_warmup = arg_parser.get_int("warmup");
|
||||
int n_repeat = arg_parser.get_int("repeat");
|
||||
|
||||
ck_tile::index_t batch_stride_A = arg_parser.get_int("batch_stride_a");
|
||||
ck_tile::index_t batch_stride_B = arg_parser.get_int("batch_stride_b");
|
||||
ck_tile::index_t batch_stride_C = arg_parser.get_int("batch_stride_c");
|
||||
ck_tile::index_t batch_count = arg_parser.get_int("batch_count");
|
||||
|
||||
std::cout << "Received args: " << std::endl;
|
||||
std::cout << "batch_stride_A: " << batch_stride_A << '\n'
|
||||
<< "batch_stride_B: " << batch_stride_B << '\n'
|
||||
<< "batch_stride_C: " << batch_stride_C << '\n'
|
||||
<< "batch_count: " << batch_count << std::endl;
|
||||
|
||||
int n_warmup = arg_parser.get_int("warmup");
|
||||
int n_repeat = arg_parser.get_int("repeat");
|
||||
|
||||
using ALayout = ck_tile::tensor_layout::gemm::RowMajor;
|
||||
using BLayout = ck_tile::tensor_layout::gemm::ColumnMajor;
|
||||
@@ -137,6 +157,10 @@ int run_batched_gemm_example(int argc, char* argv[])
|
||||
stride_B,
|
||||
stride_C,
|
||||
batch_size,
|
||||
batch_stride_A,
|
||||
batch_stride_B,
|
||||
batch_stride_C,
|
||||
batch_count,
|
||||
n_warmup,
|
||||
n_repeat);
|
||||
|
||||
|
||||
Reference in New Issue
Block a user