diff --git a/ldm_patched/modules/ops.py b/ldm_patched/modules/ops.py index 8cc41123..cd83fc5e 100644 --- a/ldm_patched/modules/ops.py +++ b/ldm_patched/modules/ops.py @@ -59,10 +59,19 @@ def main_thread_worker(weight, bias, signal): stream.current_stream.wait_event(signal) yield finished_signal = stream.current_stream.record_event() - gc[id(finished_signal)] = (weight, bias, finished_signal) + size = weight.element_size() * weight.nelement() + if bias is not None: + size += bias.element_size() * bias.nelement() + gc[id(finished_signal)] = (weight, bias, finished_signal, size) + + overhead = sum([l for k, (w, b, s, l) in gc.items()]) + + if overhead > 512 * 1024 * 1024: + stream.mover_stream.synchronize() + stream.current_stream.synchronize() garbage = [] - for k, (w, b, s) in gc.items(): + for k, (w, b, s, l) in gc.items(): if s.query(): garbage.append(k) @@ -76,7 +85,7 @@ def cleanup_cache(): if stream.current_stream is not None: with stream.stream_context()(stream.current_stream): - for k, (w, b, s) in gc.items(): + for k, (w, b, s, l) in gc.items(): stream.current_stream.wait_event(s) stream.current_stream.synchronize()