diff --git a/backend/memory_management.py b/backend/memory_management.py index 3fa5e5f1..5a6098e0 100644 --- a/backend/memory_management.py +++ b/backend/memory_management.py @@ -382,7 +382,7 @@ class LoadedModel: raise e if not do_not_need_cpu_swap: - real_async_memory = 0 + memory_in_swap = 0 mem_counter = 0 for m in self.real_model.modules(): if hasattr(m, "parameters_manual_cast"): @@ -393,7 +393,7 @@ class LoadedModel: m.to(self.device) mem_counter += module_mem else: - real_async_memory += module_mem + memory_in_swap += module_mem m.to(self.model.offload_device) if PIN_SHARED_MEMORY and is_device_cpu(self.model.offload_device): m._apply(lambda x: x.pin_memory()) @@ -402,11 +402,9 @@ class LoadedModel: mem_counter += module_size(m) print(f"[Memory Management] Swap disabled for", type(m).__name__) - if stream.should_use_stream(): - print(f"[Memory Management] Loaded to CPU Swap: {real_async_memory / (1024 * 1024):.2f} MB (asynchronous method)") - else: - print(f"[Memory Management] Loaded to CPU Swap: {real_async_memory / (1024 * 1024):.2f} MB (blocked method)") - + swap_flag = 'Shared' if PIN_SHARED_MEMORY else 'CPU' + method_flag = 'asynchronous' if stream.should_use_stream() else 'blocked' + print(f"[Memory Management] Loaded to {swap_flag} Swap: {memory_in_swap / (1024 * 1024):.2f} MB ({method_flag} method)") print(f"[Memory Management] Loaded to GPU: {mem_counter / (1024 * 1024):.2f} MB") self.model_accelerated = True