revise stream

This commit is contained in:
layerdiffusion
2024-08-08 20:18:56 -07:00
parent 02ffb04649
commit 6f254f3599

View File

@@ -322,24 +322,24 @@ class LoadedModel:
def model_memory_required(self, device):
return module_size(self.model.model, exclude_device=device)
def model_load(self, async_kept_memory=-1):
def model_load(self, model_gpu_memory_when_using_cpu_swap=-1):
patch_model_to = None
disable_async_load = async_kept_memory < 0
do_not_need_cpu_swap = model_gpu_memory_when_using_cpu_swap < 0
if disable_async_load:
if do_not_need_cpu_swap:
patch_model_to = self.device
self.model.model_patches_to(self.device)
self.model.model_patches_to(self.model.model_dtype())
try:
self.real_model = self.model.patch_model(device_to=patch_model_to) # TODO: do something with loras and offloading to CPU
self.real_model = self.model.patch_model(device_to=patch_model_to)
except Exception as e:
self.model.unpatch_model(self.model.offload_device)
self.model_unload()
raise e
if not disable_async_load:
if not do_not_need_cpu_swap:
real_async_memory = 0
mem_counter = 0
for m in self.real_model.modules():
@@ -347,7 +347,7 @@ class LoadedModel:
m.prev_parameters_manual_cast = m.parameters_manual_cast
m.parameters_manual_cast = True
module_mem = module_size(m)
if mem_counter + module_mem < async_kept_memory:
if mem_counter + module_mem < model_gpu_memory_when_using_cpu_swap:
m.to(self.device)
mem_counter += module_mem
else:
@@ -438,6 +438,18 @@ def free_memory(memory_required, device, keep_loaded=[]):
soft_empty_cache()
def compute_model_gpu_memory_when_using_cpu_swap(current_free_mem, inference_memory):
maximum_memory_available = current_free_mem - inference_memory
k_1GB = float(inference_memory / (1024 * 1024 * 1024))
k_1GB = max(0.0, min(1.0, k_1GB))
adaptive_safe_factor = 1.0 - 0.23 * k_1GB
suggestion = maximum_memory_available * adaptive_safe_factor
return int(max(0, suggestion))
def load_models_gpu(models, memory_required=0):
global vram_state
@@ -489,7 +501,7 @@ def load_models_gpu(models, memory_required=0):
else:
vram_set_state = vram_state
async_kept_memory = -1
model_gpu_memory_when_using_cpu_swap = -1
if lowvram_available and (vram_set_state == VRAMState.LOW_VRAM or vram_set_state == VRAMState.NORMAL_VRAM):
model_memory = loaded_model.model_memory_required(torch_dev)
@@ -504,13 +516,12 @@ def load_models_gpu(models, memory_required=0):
if estimated_remaining_memory < 0:
vram_set_state = VRAMState.LOW_VRAM
async_kept_memory = (current_free_mem - inference_memory) / 1.3
async_kept_memory = int(max(0, async_kept_memory))
model_gpu_memory_when_using_cpu_swap = compute_model_gpu_memory_when_using_cpu_swap(current_free_mem, inference_memory)
if vram_set_state == VRAMState.NO_VRAM:
async_kept_memory = 0
model_gpu_memory_when_using_cpu_swap = 0
loaded_model.model_load(async_kept_memory)
loaded_model.model_load(model_gpu_memory_when_using_cpu_swap)
current_loaded_models.insert(0, loaded_model)
moving_time = time.perf_counter() - execution_start_time