avoid potential OOM caused by computation being slower than mover

avoid OOM (or shared vram invoking) caused by computation being slower than mover (GPU filled with loaded  but uncomputed tensors), by setting the max async overhead to 512MB
This commit is contained in:
lllyasviel
2024-02-22 08:24:23 -08:00
parent 167dbc6411
commit 6ebef20db3

View File

@@ -59,10 +59,19 @@ def main_thread_worker(weight, bias, signal):
stream.current_stream.wait_event(signal)
yield
finished_signal = stream.current_stream.record_event()
gc[id(finished_signal)] = (weight, bias, finished_signal)
size = weight.element_size() * weight.nelement()
if bias is not None:
size += bias.element_size() * bias.nelement()
gc[id(finished_signal)] = (weight, bias, finished_signal, size)
overhead = sum([l for k, (w, b, s, l) in gc.items()])
if overhead > 512 * 1024 * 1024:
stream.mover_stream.synchronize()
stream.current_stream.synchronize()
garbage = []
for k, (w, b, s) in gc.items():
for k, (w, b, s, l) in gc.items():
if s.query():
garbage.append(k)
@@ -76,7 +85,7 @@ def cleanup_cache():
if stream.current_stream is not None:
with stream.stream_context()(stream.current_stream):
for k, (w, b, s) in gc.items():
for k, (w, b, s, l) in gc.items():
stream.current_stream.wait_event(s)
stream.current_stream.synchronize()