diff --git a/backend/memory_management.py b/backend/memory_management.py
index 5b0e66a5..c6ff6b95 100644
--- a/backend/memory_management.py
+++ b/backend/memory_management.py
@@ -363,7 +363,7 @@ class LoadedModel:
     def model_memory(self):
         return self.model.model_size()
 
-    def model_memory_required(self, device):
+    def model_memory_required(self, device=None):
         return module_size(self.model.model, exclude_device=device)
 
     def model_load(self, model_gpu_memory_when_using_cpu_swap=-1):
@@ -538,7 +538,7 @@ def load_models_gpu(models, memory_required=0):
     total_memory_required = {}
     for loaded_model in models_to_load:
         unload_model_clones(loaded_model.model)
-        total_memory_required[loaded_model.device] = total_memory_required.get(loaded_model.device, 0) + loaded_model.model_memory_required(loaded_model.device)
+        total_memory_required[loaded_model.device] = total_memory_required.get(loaded_model.device, 0) + loaded_model.model_memory_required()
 
     for device in total_memory_required:
         if device != torch.device("cpu"):
diff --git a/backend/patcher/lora.py b/backend/patcher/lora.py
index 238744b2..8880b318 100644
--- a/backend/patcher/lora.py
+++ b/backend/patcher/lora.py
@@ -279,7 +279,7 @@ class LoraLoader:
         self.dirty = True
         return list(p)
 
-    def refresh(self, target_device=None, offload_device=torch.cpu):
+    def refresh(self, target_device=None, offload_device=torch.device('cpu')):
        if not self.dirty:
            return
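
Note on the load_models_gpu change: calling model_memory_required() with no argument passes device=None down to module_size, so presumably nothing is excluded and the per-device total now accumulates each model's full size rather than only the weights not yet on the target device. A minimal sketch of that difference, assuming a simplified module_size along these lines (hypothetical, not the repository's implementation):

import torch
import torch.nn as nn

def module_size(module, exclude_device=None):
    # Sum parameter byte sizes, optionally skipping parameters that
    # already live on exclude_device (when one is given).
    total = 0
    for p in module.parameters():
        if exclude_device is not None and p.device == exclude_device:
            continue
        total += p.numel() * p.element_size()
    return total

model = nn.Linear(4, 4)  # all parameters live on the CPU here
print(module_size(model))                                       # full size: nothing excluded
print(module_size(model, exclude_device=torch.device('cpu')))   # 0: every parameter excluded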