diff --git a/include/ck_tile/host/rotating_buffers.hpp b/include/ck_tile/host/rotating_buffers.hpp index 154d67fb8e..601b8f2378 100644 --- a/include/ck_tile/host/rotating_buffers.hpp +++ b/include/ck_tile/host/rotating_buffers.hpp @@ -29,12 +29,12 @@ struct RotatingMemWrapper RotatingMemWrapper() = delete; RotatingMemWrapper(const void* a_ptr_, const void* b_ptr_, - std::size_t rotating_count_, + std::size_t rotating_count_hint, std::size_t size_a_, std::size_t size_b_) : a_ptr(a_ptr_), b_ptr(b_ptr_), - rotating_count(rotating_count_), + rotating_count(rotating_count_hint), size_a(size_a_), size_b(size_b_) { @@ -42,6 +42,11 @@ struct RotatingMemWrapper p_a_grids.push_back(a_ptr); p_b_grids.push_back(b_ptr); + // limit the rotating count to prevent oom + const uint64_t footprint = (size_a + size_b); + const uint64_t max_rotating_count = (1ULL << 31) / footprint; + rotating_count = std::min(rotating_count, max_rotating_count); + // Create (rotating_count - 1) additional copies at different memory addresses for(size_t i = 1; i < rotating_count; i++) {