From 6ebef20db3c15d9146b083ed28aba68ea098d16f Mon Sep 17 00:00:00 2001
From: lllyasviel
Date: Thu, 22 Feb 2024 08:24:23 -0800
Subject: [PATCH] avoid potential OOM caused by computation being slower than
 mover

avoid OOM (or shared VRAM being invoked) caused by computation being
slower than mover (GPU filled with loaded but uncomputed tensors), by
setting the max async overhead to 512MB
---
 ldm_patched/modules/ops.py | 15 ++++++++++++---
 1 file changed, 12 insertions(+), 3 deletions(-)

diff --git a/ldm_patched/modules/ops.py b/ldm_patched/modules/ops.py
index 8cc41123..cd83fc5e 100644
--- a/ldm_patched/modules/ops.py
+++ b/ldm_patched/modules/ops.py
@@ -59,10 +59,19 @@ def main_thread_worker(weight, bias, signal):
         stream.current_stream.wait_event(signal)
         yield
         finished_signal = stream.current_stream.record_event()
-        gc[id(finished_signal)] = (weight, bias, finished_signal)
+        size = weight.element_size() * weight.nelement()
+        if bias is not None:
+            size += bias.element_size() * bias.nelement()
+        gc[id(finished_signal)] = (weight, bias, finished_signal, size)
+
+        overhead = sum([l for k, (w, b, s, l) in gc.items()])
+
+        if overhead > 512 * 1024 * 1024:
+            stream.mover_stream.synchronize()
+            stream.current_stream.synchronize()
 
     garbage = []
-    for k, (w, b, s) in gc.items():
+    for k, (w, b, s, l) in gc.items():
         if s.query():
             garbage.append(k)
 
@@ -76,7 +85,7 @@ def cleanup_cache():
 
     if stream.current_stream is not None:
         with stream.stream_context()(stream.current_stream):
-            for k, (w, b, s) in gc.items():
+            for k, (w, b, s, l) in gc.items():
                 stream.current_stream.wait_event(s)
             stream.current_stream.synchronize()