diff --git a/example/ck_tile/05_batched_gemm/run_batched_gemm_example.inc b/example/ck_tile/05_batched_gemm/run_batched_gemm_example.inc index e07e5e536f..1270785693 100644 --- a/example/ck_tile/05_batched_gemm/run_batched_gemm_example.inc +++ b/example/ck_tile/05_batched_gemm/run_batched_gemm_example.inc @@ -14,20 +14,28 @@ float invoke_gemm(ck_tile::DeviceMem& a_m_k_dev_buf, ck_tile::index_t stride_B, ck_tile::index_t stride_C, ck_tile::index_t kbatch, + ck_tile::index_t batch_stride_A, + ck_tile::index_t batch_stride_B, + ck_tile::index_t batch_stride_C, + ck_tile::index_t batch_count, int n_warmup, int n_repeat) { batched_gemm_basic_args args; - args.p_a = a_m_k_dev_buf.GetDeviceBuffer(); - args.p_b = b_k_n_dev_buf.GetDeviceBuffer(); - args.p_c = c_m_n_dev_buf.GetDeviceBuffer(); - args.kbatch = kbatch; - args.M = M; - args.N = N; - args.K = K; - args.stride_A = stride_A; - args.stride_B = stride_B; - args.stride_C = stride_C; + args.p_a = a_m_k_dev_buf.GetDeviceBuffer(); + args.p_b = b_k_n_dev_buf.GetDeviceBuffer(); + args.p_c = c_m_n_dev_buf.GetDeviceBuffer(); + args.kbatch = kbatch; + args.M = M; + args.N = N; + args.K = K; + args.stride_A = stride_A; + args.stride_B = stride_B; + args.stride_C = stride_C; + args.batch_stride_A = batch_stride_A; + args.batch_stride_B = batch_stride_B; + args.batch_stride_C = batch_stride_C; + args.batch_count = batch_count; float ave_time = gemm_calc( args, ck_tile::stream_config{nullptr, true, 1, n_warmup, n_repeat}); @@ -63,8 +71,20 @@ int run_batched_gemm_example(int argc, char* argv[]) ck_tile::index_t stride_C = arg_parser.get_int("stride_c"); ck_tile::index_t batch_size = arg_parser.get_int("b"); - int n_warmup = arg_parser.get_int("warmup"); - int n_repeat = arg_parser.get_int("repeat"); + + ck_tile::index_t batch_stride_A = arg_parser.get_int("batch_stride_a"); + ck_tile::index_t batch_stride_B = arg_parser.get_int("batch_stride_b"); + ck_tile::index_t batch_stride_C = arg_parser.get_int("batch_stride_c"); + ck_tile::index_t batch_count = arg_parser.get_int("batch_count"); + + std::cout << "Received args: " << std::endl; + std::cout << "batch_stride_A: " << batch_stride_A << '\n' + << "batch_stride_B: " << batch_stride_B << '\n' + << "batch_stride_C: " << batch_stride_C << '\n' + << "batch_count: " << batch_count << std::endl; + + int n_warmup = arg_parser.get_int("warmup"); + int n_repeat = arg_parser.get_int("repeat"); using ALayout = ck_tile::tensor_layout::gemm::RowMajor; using BLayout = ck_tile::tensor_layout::gemm::ColumnMajor; @@ -137,6 +157,10 @@ int run_batched_gemm_example(int argc, char* argv[]) stride_B, stride_C, batch_size, + batch_stride_A, + batch_stride_B, + batch_stride_C, + batch_count, n_warmup, n_repeat);