diff --git a/include/ck/host_utility/flush_cache.hpp b/include/ck/host_utility/flush_cache.hpp index 5da447125e..c98948edb7 100644 --- a/include/ck/host_utility/flush_cache.hpp +++ b/include/ck/host_utility/flush_cache.hpp @@ -4,6 +4,7 @@ #pragma once #include +#include #include #include @@ -28,12 +29,12 @@ struct RotatingMemWrapperMultiABD RotatingMemWrapperMultiABD() = delete; RotatingMemWrapperMultiABD(Argument& arg_, - std::size_t rotating_count_, + std::size_t rotating_count_hint, std::array size_as_, std::array size_bs_, std::array size_ds_) : arg(arg_), - rotating_count(rotating_count_), + rotating_count(rotating_count_hint), size_as(size_as_), size_bs(size_bs_), size_ds(size_ds_) @@ -41,6 +42,14 @@ struct RotatingMemWrapperMultiABD p_as_grids.push_back(arg.p_as_grid); p_bs_grids.push_back(arg.p_bs_grid); p_ds_grids.push_back(arg.p_ds_grid); + + // limit the rotating count to prevent oom + const uint64_t footprint = std::accumulate(size_as.begin(), size_as.end(), 0UL) + + std::accumulate(size_bs.begin(), size_bs.end(), 0UL) + + std::accumulate(size_ds.begin(), size_ds.end(), 0UL); + const uint64_t max_rotating_count = (1ULL << 31) / footprint; + rotating_count = std::min(rotating_count, max_rotating_count); + for(size_t i = 1; i < rotating_count; i++) { { @@ -171,12 +180,12 @@ struct RotatingMemWrapperMultiD RotatingMemWrapperMultiD() = delete; RotatingMemWrapperMultiD(Argument& arg_, - std::size_t rotating_count_, + std::size_t rotating_count_hint, std::size_t size_a_, std::size_t size_b_, std::array size_ds_) : arg(arg_), - rotating_count(rotating_count_), + rotating_count(rotating_count_hint), size_a(size_a_), size_b(size_b_), size_ds(size_ds_) @@ -184,6 +193,13 @@ struct RotatingMemWrapperMultiD p_a_grids.push_back(arg.p_a_grid); p_b_grids.push_back(arg.p_b_grid); p_ds_grids.push_back(arg.p_ds_grid); + + // limit the rotating count to prevent oom + const uint64_t footprint = + std::accumulate(size_ds.begin(), size_ds.end(), 0UL) + (size_a + size_b); + const uint64_t max_rotating_count = (1ULL << 31) / footprint; + rotating_count = std::min(rotating_count, max_rotating_count); + for(size_t i = 1; i < rotating_count; i++) { { @@ -286,13 +302,19 @@ struct RotatingMemWrapper RotatingMemWrapper() = delete; RotatingMemWrapper(Argument& arg_, - std::size_t rotating_count_, + std::size_t rotating_count_hint, std::size_t size_a_, std::size_t size_b_) - : arg(arg_), rotating_count(rotating_count_), size_a(size_a_), size_b(size_b_) + : arg(arg_), rotating_count(rotating_count_hint), size_a(size_a_), size_b(size_b_) { p_a_grids.push_back(arg.p_a_grid); p_b_grids.push_back(arg.p_b_grid); + + // limit the rotating count to prevent oom + const uint64_t footprint = (size_a + size_b); + const uint64_t max_rotating_count = (1ULL << 31) / footprint; + rotating_count = std::min(rotating_count, max_rotating_count); + for(size_t i = 1; i < rotating_count; i++) { {