Author: layerdiffusion
Date:   2024-08-30 09:41:36 -07:00
parent 1e0e861b0d
commit f04666b19b


@@ -350,7 +350,7 @@ def bake_gguf_model(model):
     return model


-def module_size(module, exclude_device=None, return_split=False):
+def module_size(module, exclude_device=None, include_device=None, return_split=False):
     module_mem = 0
     weight_mem = 0
     weight_patterns = ['weight']
@@ -362,6 +362,10 @@ def module_size(module, exclude_device=None, return_split=False):
             if t.device == exclude_device:
                 continue

+        if include_device is not None:
+            if t.device != include_device:
+                continue
+
         element_size = t.element_size()

         if getattr(p, 'quant_type', None) in ['fp4', 'nf4']:
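
Note: the new include_device argument mirrors exclude_device: instead of skipping parameters that sit on a given device, it counts only the parameters that already sit on it. A minimal sketch of the filtering, assuming plain PyTorch modules and ignoring the fp4/nf4 element-size special case (module_size_sketch is a hypothetical stand-in, not the real function):

import torch
import torch.nn as nn

def module_size_sketch(module, exclude_device=None, include_device=None):
    # Sum parameter bytes, optionally excluding or restricting to one device.
    total = 0
    for _, p in module.named_parameters():
        t = p.data
        if exclude_device is not None and t.device == exclude_device:
            continue  # skip tensors already on the excluded device
        if include_device is not None and t.device != include_device:
            continue  # count only tensors resident on the included device
        total += t.element_size() * t.nelement()
    return total

# Example: bytes of this module currently on cuda:0 vs. everywhere else.
m = nn.Linear(1024, 1024)
on_gpu = module_size_sketch(m, include_device=torch.device('cuda:0'))
elsewhere = module_size_sketch(m, exclude_device=torch.device('cuda:0'))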
@@ -439,12 +443,13 @@ class LoadedModel:
         self.memory_required = memory_required
         self.model_accelerated = False
         self.device = model.load_device
+        self.inclusive_memory = 0
+        self.exclusive_memory = 0

-    def model_memory(self):
-        return self.model.model_size()
-
-    def model_memory_required(self, device=None):
-        return module_size(self.model.model, exclude_device=device)
+    def compute_inclusive_exclusive_memory(self):
+        self.inclusive_memory = module_size(self.model.model, include_device=self.device)
+        self.exclusive_memory = module_size(self.model.model, exclude_device=self.device)
+        return

     def model_load(self, model_gpu_memory_when_using_cpu_swap=-1):
         patch_model_to = None
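
Note: a LoadedModel now tracks two counters against its load device: inclusive_memory, the parameter bytes already resident on that device, and exclusive_memory, the bytes still living elsewhere (what a load would actually have to move). The two always sum to the model's total parameter size. A rough self-contained illustration with a toy model (hypothetical, needs a CUDA device to show a non-trivial split):

import torch
import torch.nn as nn

device = torch.device('cuda:0')
model = nn.Sequential(nn.Linear(64, 64), nn.Linear(64, 64))
model[0].to(device)  # pretend this half was moved to the GPU by an earlier load

inclusive = sum(p.numel() * p.element_size() for p in model.parameters() if p.device == device)
exclusive = sum(p.numel() * p.element_size() for p in model.parameters() if p.device != device)
# inclusive + exclusive == total parameter bytes of the model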
@@ -548,9 +553,6 @@ def unload_model_clones(model):
         if model.is_clone(current_loaded_models[i].model):
             to_unload = [i] + to_unload

-    if len(to_unload) > 0:
-        print(f"Reuse {len(to_unload)} loaded models")
-
     for i in to_unload:
         current_loaded_models.pop(i).model_unload(avoid_model_moving=True)
@@ -632,10 +634,13 @@ def load_models_gpu(models, memory_required=0):
         return

-    total_memory_required = {}
     for loaded_model in models_to_load:
         unload_model_clones(loaded_model.model)
-        total_memory_required[loaded_model.device] = total_memory_required.get(loaded_model.device, 0) + loaded_model.model_memory_required()
+
+    total_memory_required = {}
+    for loaded_model in models_to_load:
+        loaded_model.compute_inclusive_exclusive_memory()
+        total_memory_required[loaded_model.device] = total_memory_required.get(loaded_model.device, 0) + loaded_model.exclusive_memory + loaded_model.inclusive_memory * 0.25

     for device in total_memory_required:
         if device != torch.device("cpu"):
@@ -652,16 +657,20 @@ def load_models_gpu(models, memory_required=0):
         model_gpu_memory_when_using_cpu_swap = -1

         if lowvram_available and (vram_set_state == VRAMState.LOW_VRAM or vram_set_state == VRAMState.NORMAL_VRAM):
-            model_memory = loaded_model.model_memory_required(torch_dev)
+            model_require = loaded_model.exclusive_memory
+            previously_loaded = loaded_model.inclusive_memory
             current_free_mem = get_free_memory(torch_dev)
             inference_memory = minimum_inference_memory()
-            estimated_remaining_memory = current_free_mem - model_memory - inference_memory
+            estimated_remaining_memory = current_free_mem - model_require - inference_memory

-            print(f"[Memory Management] Target: {loaded_model.model.model.__class__.__name__}, Free GPU: {current_free_mem / (1024 * 1024):.2f} MB, Model Require: {model_memory / (1024 * 1024):.2f} MB, Inference Require: {inference_memory / (1024 * 1024):.2f} MB, Remaining: {estimated_remaining_memory / (1024 * 1024):.2f} MB, ", end="")
+            print(f"[Memory Management] Target: {loaded_model.model.model.__class__.__name__}, Free GPU: {current_free_mem / (1024 * 1024):.2f} MB, Model Require: {model_require / (1024 * 1024):.2f} MB, Previously Loaded: {previously_loaded / (1024 * 1024):.2f} MB, Inference Require: {inference_memory / (1024 * 1024):.2f} MB, Remaining: {estimated_remaining_memory / (1024 * 1024):.2f} MB, ", end="")

             if estimated_remaining_memory < 0:
                 vram_set_state = VRAMState.LOW_VRAM
                 model_gpu_memory_when_using_cpu_swap = compute_model_gpu_memory_when_using_cpu_swap(current_free_mem, inference_memory)

+                if previously_loaded > 0:
+                    model_gpu_memory_when_using_cpu_swap = previously_loaded
+
         if vram_set_state == VRAMState.NO_VRAM:
             model_gpu_memory_when_using_cpu_swap = 0
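
Note: on the low-VRAM path, the GPU budget for CPU swap normally comes from compute_model_gpu_memory_when_using_cpu_swap(); the new branch overrides it with previously_loaded when part of the model is already resident, keeping the budget aligned with what is already on the GPU. A hedged sketch of the decision with made-up numbers (pick_gpu_budget and the fallback value are hypothetical):

def pick_gpu_budget(free_mem, model_require, inference_memory, previously_loaded, fallback_budget):
    remaining = free_mem - model_require - inference_memory
    if remaining >= 0:
        return -1  # enough room: no CPU swap needed
    budget = fallback_budget  # stand-in for compute_model_gpu_memory_when_using_cpu_swap(...)
    if previously_loaded > 0:
        budget = previously_loaded  # do not shrink below what is already on the GPU
    return budget

# 3000 MB free, 4000 MB still to move, 1024 MB reserved for inference,
# 2500 MB already resident -> keep the resident 2500 MB as the GPU budget.
print(pick_gpu_budget(3000, 4000, 1024, 2500, 1500))  # -> 2500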