mirror of
https://github.com/ROCm/composable_kernel.git
synced 2026-05-16 19:09:59 +00:00
Merge commit '88d72178d6739c7e277074e5f9bb5d1e59bf0152' into develop
This commit is contained in:
@@ -458,7 +458,8 @@ auto create_args(int argc, char* argv[])
|
||||
.insert("split_k", "1", "splitK value")
|
||||
.insert("init", "0", "0:random, 1:linear, 2:constant(1)")
|
||||
.insert("persistent", "0", "0:non-persistent, 1:persistent")
|
||||
.insert("bench_time_ms", "0", "benchmark time in ms, defaults to 0 ms");
|
||||
.insert("flush_cache", "true", "flush cache before running the kernel, defaults to true")
|
||||
.insert("rotating_count", "1", "rotating count, defaults to 1");
|
||||
|
||||
bool result = arg_parser.parse(argc, argv);
|
||||
return std::make_tuple(result, arg_parser);
|
||||
|
||||
@@ -184,7 +184,8 @@ float invoke_gemm(ck_tile::DeviceMem& a_m_k_dev_buf,
|
||||
int n_warmup,
|
||||
int n_repeat,
|
||||
bool persistent,
|
||||
int bench_time_ms)
|
||||
bool flush_cache,
|
||||
int rotating_count)
|
||||
{
|
||||
ck_tile::GemmHostArgs args = {a_m_k_dev_buf.GetDeviceBuffer(),
|
||||
b_k_n_dev_buf.GetDeviceBuffer(),
|
||||
@@ -214,7 +215,7 @@ float invoke_gemm(ck_tile::DeviceMem& a_m_k_dev_buf,
|
||||
CDEElementWise>(
|
||||
args,
|
||||
ck_tile::stream_config{
|
||||
nullptr, true, 1, n_warmup, n_repeat, true, true, 50, bench_time_ms});
|
||||
nullptr, true, 1, n_warmup, n_repeat, true, flush_cache, rotating_count});
|
||||
}
|
||||
else
|
||||
{
|
||||
@@ -232,7 +233,7 @@ float invoke_gemm(ck_tile::DeviceMem& a_m_k_dev_buf,
|
||||
CDEElementWise>(
|
||||
args,
|
||||
ck_tile::stream_config{
|
||||
nullptr, true, 1, n_warmup, n_repeat, true, true, 50, bench_time_ms});
|
||||
nullptr, true, 1, n_warmup, n_repeat, true, flush_cache, rotating_count});
|
||||
}
|
||||
|
||||
std::size_t flop = std::size_t(2) * M * N * K;
|
||||
@@ -303,7 +304,8 @@ int run_gemm_example_with_layouts(int argc,
|
||||
int n_repeat = arg_parser.get_int("repeat");
|
||||
ck_tile::index_t init_method = arg_parser.get_int("init");
|
||||
bool persistent = arg_parser.get_int("persistent");
|
||||
int bench_time_ms = arg_parser.get_int("bench_time_ms");
|
||||
bool flush_cache = arg_parser.get_bool("flush_cache");
|
||||
int rotating_count = arg_parser.get_int("rotating_count");
|
||||
|
||||
const bool preshuffle = GemmConfig::Preshuffle;
|
||||
|
||||
@@ -422,7 +424,8 @@ int run_gemm_example_with_layouts(int argc,
|
||||
n_warmup,
|
||||
n_repeat,
|
||||
persistent,
|
||||
bench_time_ms);
|
||||
flush_cache,
|
||||
rotating_count);
|
||||
|
||||
c_m_n_dev_buf.FromDevice(c_m_n_dev_result.data());
|
||||
bool pass = true;
|
||||
|
||||
@@ -168,7 +168,7 @@ float flatmm_calc(const ck_tile::FlatmmHostArgs<>& args, const ck_tile::stream_c
|
||||
hipGetErrorString(hipMemsetAsync(
|
||||
args.e_ptr, 0, args.M * args.N * sizeof(CDataType), s.stream_id_));
|
||||
};
|
||||
ave_time = ck_tile::launch_kernel_preprocess(
|
||||
ave_time = ck_tile::launch_kernel_time_mask(
|
||||
s,
|
||||
run_flush_cache,
|
||||
ck_tile::make_kernel<blocks.x, FlatmmConfig::kBlockPerCu>(
|
||||
|
||||
@@ -120,7 +120,7 @@ float grouped_conv_bwd_weight(const ck_tile::GroupedConvBwdWeightHostArgs& args,
|
||||
<< ", Vector size C: " << ConvEpilogue::GetVectorSizeC() << std::endl;
|
||||
}
|
||||
|
||||
float ave_time = ck_tile::launch_kernel_preprocess(
|
||||
float ave_time = ck_tile::launch_kernel_time_mask(
|
||||
s,
|
||||
Kernel::Preprocess(kargs, s),
|
||||
ck_tile::make_kernel<blocks.x, kBlockPerCu>(Kernel{}, grids, blocks, 0, kargs));
|
||||
|
||||
Reference in New Issue
Block a user