mirror of
https://github.com/ROCm/composable_kernel.git
synced 2026-04-20 14:59:17 +00:00
[CK][host] limit the rotating count to prevent oom (#3089)
* [CK][host] limit the rotating count to prevent oom * add numeric header for accumulate
This commit is contained in:
@@ -4,6 +4,7 @@
|
||||
#pragma once
|
||||
|
||||
#include <hip/hip_runtime.h>
|
||||
#include <numeric>
|
||||
#include <set>
|
||||
#include <vector>
|
||||
|
||||
@@ -28,12 +29,12 @@ struct RotatingMemWrapperMultiABD
|
||||
|
||||
RotatingMemWrapperMultiABD() = delete;
|
||||
RotatingMemWrapperMultiABD(Argument& arg_,
|
||||
std::size_t rotating_count_,
|
||||
std::size_t rotating_count_hint,
|
||||
std::array<std::size_t, NumAs> size_as_,
|
||||
std::array<std::size_t, NumBs> size_bs_,
|
||||
std::array<std::size_t, NumDs> size_ds_)
|
||||
: arg(arg_),
|
||||
rotating_count(rotating_count_),
|
||||
rotating_count(rotating_count_hint),
|
||||
size_as(size_as_),
|
||||
size_bs(size_bs_),
|
||||
size_ds(size_ds_)
|
||||
@@ -41,6 +42,14 @@ struct RotatingMemWrapperMultiABD
|
||||
p_as_grids.push_back(arg.p_as_grid);
|
||||
p_bs_grids.push_back(arg.p_bs_grid);
|
||||
p_ds_grids.push_back(arg.p_ds_grid);
|
||||
|
||||
// limit the rotating count to prevent oom
|
||||
const uint64_t footprint = std::accumulate(size_as.begin(), size_as.end(), 0UL) +
|
||||
std::accumulate(size_bs.begin(), size_bs.end(), 0UL) +
|
||||
std::accumulate(size_ds.begin(), size_ds.end(), 0UL);
|
||||
const uint64_t max_rotating_count = (1ULL << 31) / footprint;
|
||||
rotating_count = std::min(rotating_count, max_rotating_count);
|
||||
|
||||
for(size_t i = 1; i < rotating_count; i++)
|
||||
{
|
||||
{
|
||||
@@ -171,12 +180,12 @@ struct RotatingMemWrapperMultiD
|
||||
|
||||
RotatingMemWrapperMultiD() = delete;
|
||||
RotatingMemWrapperMultiD(Argument& arg_,
|
||||
std::size_t rotating_count_,
|
||||
std::size_t rotating_count_hint,
|
||||
std::size_t size_a_,
|
||||
std::size_t size_b_,
|
||||
std::array<std::size_t, NumDs> size_ds_)
|
||||
: arg(arg_),
|
||||
rotating_count(rotating_count_),
|
||||
rotating_count(rotating_count_hint),
|
||||
size_a(size_a_),
|
||||
size_b(size_b_),
|
||||
size_ds(size_ds_)
|
||||
@@ -184,6 +193,13 @@ struct RotatingMemWrapperMultiD
|
||||
p_a_grids.push_back(arg.p_a_grid);
|
||||
p_b_grids.push_back(arg.p_b_grid);
|
||||
p_ds_grids.push_back(arg.p_ds_grid);
|
||||
|
||||
// limit the rotating count to prevent oom
|
||||
const uint64_t footprint =
|
||||
std::accumulate(size_ds.begin(), size_ds.end(), 0UL) + (size_a + size_b);
|
||||
const uint64_t max_rotating_count = (1ULL << 31) / footprint;
|
||||
rotating_count = std::min(rotating_count, max_rotating_count);
|
||||
|
||||
for(size_t i = 1; i < rotating_count; i++)
|
||||
{
|
||||
{
|
||||
@@ -286,13 +302,19 @@ struct RotatingMemWrapper
|
||||
|
||||
RotatingMemWrapper() = delete;
|
||||
RotatingMemWrapper(Argument& arg_,
|
||||
std::size_t rotating_count_,
|
||||
std::size_t rotating_count_hint,
|
||||
std::size_t size_a_,
|
||||
std::size_t size_b_)
|
||||
: arg(arg_), rotating_count(rotating_count_), size_a(size_a_), size_b(size_b_)
|
||||
: arg(arg_), rotating_count(rotating_count_hint), size_a(size_a_), size_b(size_b_)
|
||||
{
|
||||
p_a_grids.push_back(arg.p_a_grid);
|
||||
p_b_grids.push_back(arg.p_b_grid);
|
||||
|
||||
// limit the rotating count to prevent oom
|
||||
const uint64_t footprint = (size_a + size_b);
|
||||
const uint64_t max_rotating_count = (1ULL << 31) / footprint;
|
||||
rotating_count = std::min(rotating_count, max_rotating_count);
|
||||
|
||||
for(size_t i = 1; i < rotating_count; i++)
|
||||
{
|
||||
{
|
||||
|
||||
Reference in New Issue
Block a user