From 99ad6f60e4ad13d20b03badfa4f11357701d8ac2 Mon Sep 17 00:00:00 2001 From: Max Podkorytov <4273004+tenpercent@users.noreply.github.com> Date: Fri, 24 Oct 2025 08:55:54 -0700 Subject: [PATCH] [CK][host] limit the rotating count to prevent oom (#3089) * [CK][host] limit the rotating count to prevent oom * add numeric header for accumulate [ROCm/composable_kernel commit: f39626fcf72d0188946040fe6441437415707343] --- include/ck/host_utility/flush_cache.hpp | 34 ++++++++++++++++++++----- 1 file changed, 28 insertions(+), 6 deletions(-) diff --git a/include/ck/host_utility/flush_cache.hpp b/include/ck/host_utility/flush_cache.hpp index 5da447125e..c98948edb7 100644 --- a/include/ck/host_utility/flush_cache.hpp +++ b/include/ck/host_utility/flush_cache.hpp @@ -4,6 +4,7 @@ #pragma once #include +#include #include #include @@ -28,12 +29,12 @@ struct RotatingMemWrapperMultiABD RotatingMemWrapperMultiABD() = delete; RotatingMemWrapperMultiABD(Argument& arg_, - std::size_t rotating_count_, + std::size_t rotating_count_hint, std::array size_as_, std::array size_bs_, std::array size_ds_) : arg(arg_), - rotating_count(rotating_count_), + rotating_count(rotating_count_hint), size_as(size_as_), size_bs(size_bs_), size_ds(size_ds_) @@ -41,6 +42,14 @@ struct RotatingMemWrapperMultiABD p_as_grids.push_back(arg.p_as_grid); p_bs_grids.push_back(arg.p_bs_grid); p_ds_grids.push_back(arg.p_ds_grid); + + // limit the rotating count to prevent oom + const uint64_t footprint = std::accumulate(size_as.begin(), size_as.end(), 0UL) + + std::accumulate(size_bs.begin(), size_bs.end(), 0UL) + + std::accumulate(size_ds.begin(), size_ds.end(), 0UL); + const uint64_t max_rotating_count = (1ULL << 31) / footprint; + rotating_count = std::min(rotating_count, max_rotating_count); + for(size_t i = 1; i < rotating_count; i++) { { @@ -171,12 +180,12 @@ struct RotatingMemWrapperMultiD RotatingMemWrapperMultiD() = delete; RotatingMemWrapperMultiD(Argument& arg_, - std::size_t rotating_count_, + std::size_t rotating_count_hint, std::size_t size_a_, std::size_t size_b_, std::array size_ds_) : arg(arg_), - rotating_count(rotating_count_), + rotating_count(rotating_count_hint), size_a(size_a_), size_b(size_b_), size_ds(size_ds_) @@ -184,6 +193,13 @@ struct RotatingMemWrapperMultiD p_a_grids.push_back(arg.p_a_grid); p_b_grids.push_back(arg.p_b_grid); p_ds_grids.push_back(arg.p_ds_grid); + + // limit the rotating count to prevent oom + const uint64_t footprint = + std::accumulate(size_ds.begin(), size_ds.end(), 0UL) + (size_a + size_b); + const uint64_t max_rotating_count = (1ULL << 31) / footprint; + rotating_count = std::min(rotating_count, max_rotating_count); + for(size_t i = 1; i < rotating_count; i++) { { @@ -286,13 +302,19 @@ struct RotatingMemWrapper RotatingMemWrapper() = delete; RotatingMemWrapper(Argument& arg_, - std::size_t rotating_count_, + std::size_t rotating_count_hint, std::size_t size_a_, std::size_t size_b_) - : arg(arg_), rotating_count(rotating_count_), size_a(size_a_), size_b(size_b_) + : arg(arg_), rotating_count(rotating_count_hint), size_a(size_a_), size_b(size_b_) { p_a_grids.push_back(arg.p_a_grid); p_b_grids.push_back(arg.p_b_grid); + + // limit the rotating count to prevent oom + const uint64_t footprint = (size_a + size_b); + const uint64_t max_rotating_count = (1ULL << 31) / footprint; + rotating_count = std::min(rotating_count, max_rotating_count); + for(size_t i = 1; i < rotating_count; i++) { {