revise stream logics

This commit is contained in:
layerdiffusion
2024-08-08 18:45:36 -07:00
parent d3b81924df
commit 60c5aea11b
4 changed files with 30 additions and 24 deletions

View File

@@ -21,7 +21,7 @@ def weights_manual_cast(layer, x, skip_dtype=False):
if skip_dtype:
target_dtype = None
if stream.using_stream:
if stream.should_use_stream():
with stream.stream_context()(stream.mover_stream):
if layer.weight is not None:
weight = layer.weight.to(device=target_device, dtype=target_dtype, non_blocking=non_blocking)
@@ -39,7 +39,7 @@ def weights_manual_cast(layer, x, skip_dtype=False):
@contextlib.contextmanager
def main_stream_worker(weight, bias, signal):
if not stream.using_stream or signal is None:
if signal is None or not stream.should_use_stream():
yield
return
@@ -60,7 +60,7 @@ def main_stream_worker(weight, bias, signal):
def cleanup_cache():
if not stream.using_stream:
if not stream.should_use_stream():
return
stream.current_stream.synchronize()