Add rotating to mx examples

This commit is contained in:
Ding, Yi
2025-05-26 05:05:54 +00:00
parent fdfc9c6fd8
commit 40af523e2c
2 changed files with 16 additions and 4 deletions

View File

@@ -410,14 +410,26 @@ bool run_mx_gemm(const ProblemSizeSplitK& problem_size, const ExecutionConfig& c
"not consistent with the supported device_gemm arguments.");
}
std::size_t total_size =
a_m_k.GetElementSpaceSizeInBytes() + b_k_n.GetElementSpaceSizeInBytes() +
a_m_k_scale.GetElementSpaceSizeInBytes() + b_k_n_scale.GetElementSpaceSizeInBytes() +
a_shuffled_scale.GetElementSpaceSizeInBytes() +
b_shuffled_scale.GetElementSpaceSizeInBytes();
const auto total_cnt = ck::math::integer_divide_ceil(512 * 1024 * 1024, total_size);
const int rotating_count = std::max(1, std::min(config.repeat, static_cast<int>(total_cnt)));
if(config.verbosity > 0)
{
std::cout << "Computing GEMM on device..." << std::endl << std::endl;
}
float ave_time = invoker.Run(
argument,
StreamConfig{nullptr, config.time_kernel, config.verbosity, config.warm_up, config.repeat});
float ave_time = invoker.Run(argument,
StreamConfig{nullptr,
config.time_kernel,
config.verbosity,
config.warm_up,
config.repeat,
rotating_count > 1,
rotating_count});
bool res_verified = true;
if(config.do_verification > 0)

View File

@@ -67,10 +67,10 @@ int profile_gemm_mx(int argc, char* argv[])
StrideC = std::stoi(argv[arg_index++]);
}
int KBatch = 1;
int n_warmup = 1;
int n_iter = 10;
uint64_t rotating = 0;
int KBatch = 1;
if(argc > arg_index)
{
KBatch = std::stoi(argv[arg_index++]);