[CK][host] limit the rotating count to prevent oom (#3089)

* [CK][host] limit the rotating count to prevent oom

* add numeric header for accumulate
This commit is contained in:
Max Podkorytov
2025-10-24 08:55:54 -07:00
committed by GitHub
parent fdcc1f75c3
commit f39626fcf7

View File

@@ -4,6 +4,7 @@
#pragma once
#include <hip/hip_runtime.h>
#include <numeric>
#include <set>
#include <vector>
@@ -28,12 +29,12 @@ struct RotatingMemWrapperMultiABD
RotatingMemWrapperMultiABD() = delete;
RotatingMemWrapperMultiABD(Argument& arg_,
std::size_t rotating_count_,
std::size_t rotating_count_hint,
std::array<std::size_t, NumAs> size_as_,
std::array<std::size_t, NumBs> size_bs_,
std::array<std::size_t, NumDs> size_ds_)
: arg(arg_),
rotating_count(rotating_count_),
rotating_count(rotating_count_hint),
size_as(size_as_),
size_bs(size_bs_),
size_ds(size_ds_)
@@ -41,6 +42,14 @@ struct RotatingMemWrapperMultiABD
p_as_grids.push_back(arg.p_as_grid);
p_bs_grids.push_back(arg.p_bs_grid);
p_ds_grids.push_back(arg.p_ds_grid);
// limit the rotating count to prevent oom
const uint64_t footprint = std::accumulate(size_as.begin(), size_as.end(), 0UL) +
std::accumulate(size_bs.begin(), size_bs.end(), 0UL) +
std::accumulate(size_ds.begin(), size_ds.end(), 0UL);
const uint64_t max_rotating_count = (1ULL << 31) / footprint;
rotating_count = std::min(rotating_count, max_rotating_count);
for(size_t i = 1; i < rotating_count; i++)
{
{
@@ -171,12 +180,12 @@ struct RotatingMemWrapperMultiD
RotatingMemWrapperMultiD() = delete;
RotatingMemWrapperMultiD(Argument& arg_,
std::size_t rotating_count_,
std::size_t rotating_count_hint,
std::size_t size_a_,
std::size_t size_b_,
std::array<std::size_t, NumDs> size_ds_)
: arg(arg_),
rotating_count(rotating_count_),
rotating_count(rotating_count_hint),
size_a(size_a_),
size_b(size_b_),
size_ds(size_ds_)
@@ -184,6 +193,13 @@ struct RotatingMemWrapperMultiD
p_a_grids.push_back(arg.p_a_grid);
p_b_grids.push_back(arg.p_b_grid);
p_ds_grids.push_back(arg.p_ds_grid);
// limit the rotating count to prevent oom
const uint64_t footprint =
std::accumulate(size_ds.begin(), size_ds.end(), 0UL) + (size_a + size_b);
const uint64_t max_rotating_count = (1ULL << 31) / footprint;
rotating_count = std::min(rotating_count, max_rotating_count);
for(size_t i = 1; i < rotating_count; i++)
{
{
@@ -286,13 +302,19 @@ struct RotatingMemWrapper
RotatingMemWrapper() = delete;
RotatingMemWrapper(Argument& arg_,
std::size_t rotating_count_,
std::size_t rotating_count_hint,
std::size_t size_a_,
std::size_t size_b_)
: arg(arg_), rotating_count(rotating_count_), size_a(size_a_), size_b(size_b_)
: arg(arg_), rotating_count(rotating_count_hint), size_a(size_a_), size_b(size_b_)
{
p_a_grids.push_back(arg.p_a_grid);
p_b_grids.push_back(arg.p_b_grid);
// limit the rotating count to prevent oom
const uint64_t footprint = (size_a + size_b);
const uint64_t max_rotating_count = (1ULL << 31) / footprint;
rotating_count = std::min(rotating_count, max_rotating_count);
for(size_t i = 1; i < rotating_count; i++)
{
{