From 0d4c6c2c13f255495ca0f6a9f1f4e1918662ae0a Mon Sep 17 00:00:00 2001 From: "assistant-librarian[bot]" Date: Fri, 24 Oct 2025 16:13:23 +0000 Subject: [PATCH] Merge commit 'f39626fcf72d0188946040fe6441437415707343' into develop --- include/ck/host_utility/flush_cache.hpp | 34 +++++++++++++++++++---- include/ck_tile/host/rotating_buffers.hpp | 9 ++++-- 2 files changed, 35 insertions(+), 8 deletions(-) diff --git a/include/ck/host_utility/flush_cache.hpp b/include/ck/host_utility/flush_cache.hpp index 5da447125e..c98948edb7 100644 --- a/include/ck/host_utility/flush_cache.hpp +++ b/include/ck/host_utility/flush_cache.hpp @@ -4,6 +4,7 @@ #pragma once #include +#include #include #include @@ -28,12 +29,12 @@ struct RotatingMemWrapperMultiABD RotatingMemWrapperMultiABD() = delete; RotatingMemWrapperMultiABD(Argument& arg_, - std::size_t rotating_count_, + std::size_t rotating_count_hint, std::array size_as_, std::array size_bs_, std::array size_ds_) : arg(arg_), - rotating_count(rotating_count_), + rotating_count(rotating_count_hint), size_as(size_as_), size_bs(size_bs_), size_ds(size_ds_) @@ -41,6 +42,14 @@ struct RotatingMemWrapperMultiABD p_as_grids.push_back(arg.p_as_grid); p_bs_grids.push_back(arg.p_bs_grid); p_ds_grids.push_back(arg.p_ds_grid); + + // limit the rotating count to prevent oom + const uint64_t footprint = std::accumulate(size_as.begin(), size_as.end(), 0UL) + + std::accumulate(size_bs.begin(), size_bs.end(), 0UL) + + std::accumulate(size_ds.begin(), size_ds.end(), 0UL); + const uint64_t max_rotating_count = (1ULL << 31) / footprint; + rotating_count = std::min(rotating_count, max_rotating_count); + for(size_t i = 1; i < rotating_count; i++) { { @@ -171,12 +180,12 @@ struct RotatingMemWrapperMultiD RotatingMemWrapperMultiD() = delete; RotatingMemWrapperMultiD(Argument& arg_, - std::size_t rotating_count_, + std::size_t rotating_count_hint, std::size_t size_a_, std::size_t size_b_, std::array size_ds_) : arg(arg_), - rotating_count(rotating_count_), + rotating_count(rotating_count_hint), size_a(size_a_), size_b(size_b_), size_ds(size_ds_) @@ -184,6 +193,13 @@ struct RotatingMemWrapperMultiD p_a_grids.push_back(arg.p_a_grid); p_b_grids.push_back(arg.p_b_grid); p_ds_grids.push_back(arg.p_ds_grid); + + // limit the rotating count to prevent oom + const uint64_t footprint = + std::accumulate(size_ds.begin(), size_ds.end(), 0UL) + (size_a + size_b); + const uint64_t max_rotating_count = (1ULL << 31) / footprint; + rotating_count = std::min(rotating_count, max_rotating_count); + for(size_t i = 1; i < rotating_count; i++) { { @@ -286,13 +302,19 @@ struct RotatingMemWrapper RotatingMemWrapper() = delete; RotatingMemWrapper(Argument& arg_, - std::size_t rotating_count_, + std::size_t rotating_count_hint, std::size_t size_a_, std::size_t size_b_) - : arg(arg_), rotating_count(rotating_count_), size_a(size_a_), size_b(size_b_) + : arg(arg_), rotating_count(rotating_count_hint), size_a(size_a_), size_b(size_b_) { p_a_grids.push_back(arg.p_a_grid); p_b_grids.push_back(arg.p_b_grid); + + // limit the rotating count to prevent oom + const uint64_t footprint = (size_a + size_b); + const uint64_t max_rotating_count = (1ULL << 31) / footprint; + rotating_count = std::min(rotating_count, max_rotating_count); + for(size_t i = 1; i < rotating_count; i++) { { diff --git a/include/ck_tile/host/rotating_buffers.hpp b/include/ck_tile/host/rotating_buffers.hpp index 154d67fb8e..601b8f2378 100644 --- a/include/ck_tile/host/rotating_buffers.hpp +++ b/include/ck_tile/host/rotating_buffers.hpp @@ -29,12 +29,12 @@ struct RotatingMemWrapper RotatingMemWrapper() = delete; RotatingMemWrapper(const void* a_ptr_, const void* b_ptr_, - std::size_t rotating_count_, + std::size_t rotating_count_hint, std::size_t size_a_, std::size_t size_b_) : a_ptr(a_ptr_), b_ptr(b_ptr_), - rotating_count(rotating_count_), + rotating_count(rotating_count_hint), size_a(size_a_), size_b(size_b_) { @@ -42,6 +42,11 @@ struct RotatingMemWrapper p_a_grids.push_back(a_ptr); p_b_grids.push_back(b_ptr); + // limit the rotating count to prevent oom + const uint64_t footprint = (size_a + size_b); + const uint64_t max_rotating_count = (1ULL << 31) / footprint; + rotating_count = std::min(rotating_count, max_rotating_count); + // Create (rotating_count - 1) additional copies at different memory addresses for(size_t i = 1; i < rotating_count; i++) {