diff --git a/backend/memory_management.py b/backend/memory_management.py
index 29e5239c..e1c1ffba 100644
--- a/backend/memory_management.py
+++ b/backend/memory_management.py
@@ -340,8 +340,6 @@ class LoadedModel:
             raise e
 
         if not disable_async_load:
-            flag = 'ASYNC' if stream.using_stream else 'SYNC'
-            print(f"[Memory Management] Requested {flag} Preserved Memory (MB) = ", async_kept_memory / (1024 * 1024))
             real_async_memory = 0
             mem_counter = 0
             for m in self.real_model.modules():
@@ -360,9 +358,14 @@ class LoadedModel:
                 elif hasattr(m, "weight"):
                     m.to(self.device)
                     mem_counter += module_size(m)
-                    print(f"[Memory Management] {flag} Loader Disabled for", type(m).__name__)
-            print(f"[Memory Management] Parameters Loaded to {flag} Stream (MB) = ", real_async_memory / (1024 * 1024))
-            print(f"[Memory Management] Parameters Loaded to GPU (MB) = ", mem_counter / (1024 * 1024))
+                    print(f"[Memory Management] Swap disabled for", type(m).__name__)
+
+            if stream.should_use_stream():
+                print(f"[Memory Management] Loaded to CPU Swap: {real_async_memory / (1024 * 1024):.2f} MB (asynchronous method)")
+            else:
+                print(f"[Memory Management] Loaded to CPU Swap: {real_async_memory / (1024 * 1024):.2f} MB (blocked method)")
+
+            print(f"[Memory Management] Loaded to GPU: {mem_counter / (1024 * 1024):.2f} MB")
 
             self.model_accelerated = True
 
@@ -390,8 +393,12 @@ class LoadedModel:
         return self.model is other.model  # and self.memory_required == other.memory_required
 
 
+current_inference_memory = 1024 * 1024 * 1024
+
+
 def minimum_inference_memory():
-    return 1024 * 1024 * 1024
+    global current_inference_memory
+    return current_inference_memory
 
 
 def unload_model_clones(model):
@@ -487,17 +494,17 @@ def load_models_gpu(models, memory_required=0):
         if lowvram_available and (vram_set_state == VRAMState.LOW_VRAM or vram_set_state == VRAMState.NORMAL_VRAM):
             model_memory = loaded_model.model_memory_required(torch_dev)
             current_free_mem = get_free_memory(torch_dev)
-            minimal_inference_memory = minimum_inference_memory()
-            estimated_remaining_memory = current_free_mem - model_memory - minimal_inference_memory
+            inference_memory = minimum_inference_memory()
+            estimated_remaining_memory = current_free_mem - model_memory - inference_memory
 
-            print("[Memory Management] Current Free GPU Memory (MB) = ", current_free_mem / (1024 * 1024))
-            print("[Memory Management] Model Memory (MB) = ", model_memory / (1024 * 1024))
-            print("[Memory Management] Minimal Inference Memory (MB) = ", minimal_inference_memory / (1024 * 1024))
-            print("[Memory Management] Estimated Remaining GPU Memory (MB) = ", estimated_remaining_memory / (1024 * 1024))
+            print(f"[Memory Management] Current Free GPU Memory: {current_free_mem / (1024 * 1024):.2f} MB")
+            print(f"[Memory Management] Required Model Memory: {model_memory / (1024 * 1024):.2f} MB")
+            print(f"[Memory Management] Required Inference Memory: {inference_memory / (1024 * 1024):.2f} MB")
+            print(f"[Memory Management] Estimated Remaining GPU Memory: {estimated_remaining_memory / (1024 * 1024):.2f} MB")
 
             if estimated_remaining_memory < 0:
                 vram_set_state = VRAMState.LOW_VRAM
-                async_kept_memory = (current_free_mem - minimal_inference_memory) / 1.3
+                async_kept_memory = (current_free_mem - inference_memory) / 1.3
                 async_kept_memory = int(max(0, async_kept_memory))
 
         if vram_set_state == VRAMState.NO_VRAM:
diff --git a/backend/operations.py b/backend/operations.py
index 250285be..db411dee 100644
--- a/backend/operations.py
+++ b/backend/operations.py
@@ -21,7 +21,7 @@ def weights_manual_cast(layer, x, skip_dtype=False):
     if skip_dtype:
         target_dtype = None
 
-    if stream.using_stream:
+    if stream.should_use_stream():
         with stream.stream_context()(stream.mover_stream):
             if layer.weight is not None:
                 weight = layer.weight.to(device=target_device, dtype=target_dtype, non_blocking=non_blocking)
@@ -39,7 +39,7 @@ def weights_manual_cast(layer, x, skip_dtype=False):
 
 @contextlib.contextmanager
 def main_stream_worker(weight, bias, signal):
-    if not stream.using_stream or signal is None:
+    if signal is None or not stream.should_use_stream():
         yield
         return
 
@@ -60,7 +60,7 @@ def main_stream_worker(weight, bias, signal):
 
 
 def cleanup_cache():
-    if not stream.using_stream:
+    if not stream.should_use_stream():
         return
 
     stream.current_stream.synchronize()
diff --git a/backend/stream.py b/backend/stream.py
index e051d442..f3fcd7bc 100644
--- a/backend/stream.py
+++ b/backend/stream.py
@@ -52,11 +52,10 @@ def get_new_stream():
     return None
 
 
-current_stream = None
-mover_stream = None
-using_stream = False
+def should_use_stream():
+    return stream_activated and current_stream is not None and mover_stream is not None
-
-if args.cuda_stream:
-    current_stream = get_current_stream()
-    mover_stream = get_new_stream()
-    using_stream = current_stream is not None and mover_stream is not None
+
+
+current_stream = get_current_stream()
+mover_stream = get_new_stream()
+stream_activated = args.cuda_stream
diff --git a/modules_forge/initialization.py b/modules_forge/initialization.py
index 72dd9dfa..5e43b44e 100644
--- a/modules_forge/initialization.py
+++ b/modules_forge/initialization.py
@@ -58,7 +58,7 @@ def initialize_forge():
     modules_forge.patch_basic.patch_all_basics()
 
     from backend import stream
-    print('CUDA Stream Activated: ', stream.using_stream)
+    print('CUDA Using Stream:', stream.should_use_stream())
 
     from modules_forge.shared import diffusers_dir
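
Note (not part of the patch): a minimal sketch of how the two APIs introduced above could be exercised, assuming the patched backend package imports cleanly in your environment. Module and function names are taken from the hunks; everything else here is illustrative.

# Illustrative usage only -- not part of the diff above.
# Assumes the patched Forge backend is on the import path.
from backend import memory_management, stream

# should_use_stream() checks the stream_activated flag and both CUDA streams
# at call time, instead of a boolean computed once at import.
if stream.should_use_stream():
    print('[Example] Asynchronous CPU swap path is available')
else:
    print('[Example] Falling back to blocking weight transfers')

# minimum_inference_memory() now returns the module-level current_inference_memory,
# so a caller could reserve, for example, 2 GiB instead of the default 1 GiB.
memory_management.current_inference_memory = 2 * 1024 * 1024 * 1024
print(f'[Example] Reserved inference memory: '
      f'{memory_management.minimum_inference_memory() / (1024 * 1024):.2f} MB')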