From 40af523e2c82833325f96027dc4199eaf01d3abc Mon Sep 17 00:00:00 2001 From: "Ding, Yi" Date: Mon, 26 May 2025 05:05:54 +0000 Subject: [PATCH] Add rotating to mx examples --- .../67_gemm_microscaling/gemm_mx_common.hpp | 18 +++++++++++++++--- profiler/src/profile_gemm_mx.cpp | 2 +- 2 files changed, 16 insertions(+), 4 deletions(-) diff --git a/example/67_gemm_microscaling/gemm_mx_common.hpp b/example/67_gemm_microscaling/gemm_mx_common.hpp index 5e15ccd04f..dda25cc1df 100644 --- a/example/67_gemm_microscaling/gemm_mx_common.hpp +++ b/example/67_gemm_microscaling/gemm_mx_common.hpp @@ -410,14 +410,26 @@ bool run_mx_gemm(const ProblemSizeSplitK& problem_size, const ExecutionConfig& c "not consistent with the supported device_gemm arguments."); } + std::size_t total_size = + a_m_k.GetElementSpaceSizeInBytes() + b_k_n.GetElementSpaceSizeInBytes() + + a_m_k_scale.GetElementSpaceSizeInBytes() + b_k_n_scale.GetElementSpaceSizeInBytes() + + a_shuffled_scale.GetElementSpaceSizeInBytes() + + b_shuffled_scale.GetElementSpaceSizeInBytes(); + const auto total_cnt = ck::math::integer_divide_ceil(512 * 1024 * 1024, total_size); + const int rotating_count = std::max(1, std::min(config.repeat, static_cast<int>(total_cnt))); if(config.verbosity > 0) { std::cout << "Computing GEMM on device..." 
<< std::endl << std::endl; } - float ave_time = invoker.Run( - argument, - StreamConfig{nullptr, config.time_kernel, config.verbosity, config.warm_up, config.repeat}); + float ave_time = invoker.Run(argument, + StreamConfig{nullptr, + config.time_kernel, + config.verbosity, + config.warm_up, + config.repeat, + rotating_count > 1, + rotating_count}); bool res_verified = true; if(config.do_verification > 0) diff --git a/profiler/src/profile_gemm_mx.cpp b/profiler/src/profile_gemm_mx.cpp index 7fb76bd76d..d5099cd0c9 100644 --- a/profiler/src/profile_gemm_mx.cpp +++ b/profiler/src/profile_gemm_mx.cpp @@ -67,10 +67,10 @@ int profile_gemm_mx(int argc, char* argv[]) StrideC = std::stoi(argv[arg_index++]); } + int KBatch = 1; int n_warmup = 1; int n_iter = 10; uint64_t rotating = 0; - int KBatch = 1; if(argc > arg_index) { KBatch = std::stoi(argv[arg_index++]);