Rework unload system

The previous repeated loading (in ControlNet or other extensions) is fixed. ControlNet now saves about 0.7 to 1.1 seconds on my two devices when batch count > 1.

Devices with 8GB VRAM can now run SDXL at 6144x6144 out of the box, without Tiled Diffusion or similar workarounds.

(The maximum resolution in the Automatic1111 txt2img UI is 2048, but one can use Highres. fix to try 6144 or even 8192.)
Author: lllyasviel
Date: 2024-02-23 12:58:09 -08:00 (committed by GitHub)
Parent: 19473b1a26
Commit: df12dde12e


@@ -268,17 +268,23 @@ print("VAE dtype:", VAE_DTYPE)
 current_loaded_models = []
 
-def module_size(module):
+def module_size(module, exclude_device=None):
     module_mem = 0
     sd = module.state_dict()
     for k in sd:
         t = sd[k]
+        if exclude_device is not None:
+            if t.device == exclude_device:
+                continue
         module_mem += t.nelement() * t.element_size()
     return module_mem
 
 class LoadedModel:
-    def __init__(self, model):
+    def __init__(self, model, memory_required):
         self.model = model
+        self.memory_required = memory_required
         self.model_accelerated = False
         self.device = model.load_device
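
A quick sketch of the new module_size semantics (plain PyTorch, not repo code): tensors already sitting on exclude_device contribute nothing, so a module fully resident on the target device reports zero bytes left to load.

import torch

def module_size(module, exclude_device=None):
    # Sum the byte size of every tensor in the state dict, skipping tensors
    # that already live on exclude_device.
    module_mem = 0
    for t in module.state_dict().values():
        if exclude_device is not None and t.device == exclude_device:
            continue
        module_mem += t.nelement() * t.element_size()
    return module_mem

linear = torch.nn.Linear(1024, 1024)  # parameters start on the CPU
print(module_size(linear))                                      # ~4 MB in fp32
print(module_size(linear, exclude_device=torch.device("cpu")))  # 0: nothing left to move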
@@ -286,10 +292,7 @@ class LoadedModel:
         return self.model.model_size()
 
     def model_memory_required(self, device):
-        if device == self.model.current_device:
-            return 0
-        else:
-            return self.model_memory()
+        return module_size(self.model.model, exclude_device=device)
 
     def model_load(self, async_kept_memory=-1):
         patch_model_to = None
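
The consequence for the loader's accounting, sketched below with a hypothetical helper name (the torch "meta" device stands in for a GPU so the snippet runs anywhere): a partially moved model now reports only the bytes that still have to travel, where the old rule reported either zero or the full model size.

import torch

def remaining_bytes(module, device):
    # Per-tensor accounting, as model_memory_required now does via module_size:
    # tensors already on the target device cost nothing.
    total = 0
    for t in module.state_dict().values():
        if t.device != device:
            total += t.nelement() * t.element_size()
    return total

m = torch.nn.Sequential(torch.nn.Linear(256, 256), torch.nn.Linear(256, 256))
target = torch.device("meta")
print(remaining_bytes(m, target))  # both layers still to move: full size
m[0].to(target)                    # pretend layer 0 is already loaded
print(remaining_bytes(m, target))  # roughly half: only layer 1 remains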
@@ -324,7 +327,9 @@ class LoadedModel:
                     real_kept_memory += module_mem
                 else:
                     real_async_memory += module_mem
-                    m._apply(lambda x: x.pin_memory())
+                    m.to(self.model.offload_device)
+                    if is_device_cpu(self.model.offload_device):
+                        m._apply(lambda x: x.pin_memory())
             elif hasattr(m, "weight"):
                 m.to(self.device)
                 mem_counter += module_size(m)
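
The new guard matters because pin_memory() is only valid for CPU tensors: pinning page-locks host RAM so that later host-to-device copies can run asynchronously. A minimal illustration, assuming a CUDA build of PyTorch, not repo code:

import torch

if torch.cuda.is_available():
    t = torch.empty(1024, 1024)  # regular pageable CPU tensor
    t = t.pin_memory()           # page-locked host memory
    print(t.is_pinned())         # True
    # Pinned memory allows a truly asynchronous copy to the GPU:
    g = t.to("cuda", non_blocking=True)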
@@ -339,7 +344,7 @@ class LoadedModel:
         return self.real_model
 
-    def model_unload(self):
+    def model_unload(self, avoid_model_moving=False):
         if self.model_accelerated:
             for m in self.real_model.modules():
                 if hasattr(m, "prev_ldm_patched_cast_weights"):
@@ -348,11 +353,14 @@ class LoadedModel:
             self.model_accelerated = False
 
-        self.model.unpatch_model(self.model.offload_device)
-        self.model.model_patches_to(self.model.offload_device)
+        if avoid_model_moving:
+            self.model.unpatch_model()
+        else:
+            self.model.unpatch_model(self.model.offload_device)
+            self.model.model_patches_to(self.model.offload_device)
 
     def __eq__(self, other):
-        return self.model is other.model
+        return self.model is other.model and self.memory_required == other.memory_required
 
 def minimum_inference_memory():
     return (1024 * 1024 * 1024)
@@ -363,9 +371,10 @@ def unload_model_clones(model):
         if model.is_clone(current_loaded_models[i].model):
             to_unload = [i] + to_unload
 
+    print(f"Reuse {len(to_unload)} loaded models")
+
     for i in to_unload:
-        print("unload clone", i)
-        current_loaded_models.pop(i).model_unload()
+        current_loaded_models.pop(i).model_unload(avoid_model_moving=True)
 
 def free_memory(memory_required, device, keep_loaded=[]):
     unloaded_model = False
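
A side note on the pop order: to_unload is built by prepending ([i] + to_unload) while i increases, so the final list is in descending index order, and popping in that order never invalidates the remaining indices. A standalone illustration of the pattern:

items = ["a", "b", "c", "d"]
to_remove = []
for i in range(len(items)):
    if items[i] in ("b", "d"):
        to_remove = [i] + to_remove  # prepend: ends up as [3, 1]
for i in to_remove:
    items.pop(i)                     # descending pops keep indices valid
print(items)                         # ['a', 'c']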
@@ -400,7 +409,7 @@ def load_models_gpu(models, memory_required=0):
     models_to_load = []
     models_already_loaded = []
     for x in models:
-        loaded_model = LoadedModel(x)
+        loaded_model = LoadedModel(x, memory_required=memory_required)
 
         if loaded_model in current_loaded_models:
             index = current_loaded_models.index(loaded_model)
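
Since __eq__ now also compares memory_required, the "in" / .index() lookup here treats a model re-requested with a different memory budget as a cache miss and reloads it. A minimal stand-in sketch of that membership behavior (simplified class, not the repo's):

class LoadedModel:
    def __init__(self, model, memory_required):
        self.model = model
        self.memory_required = memory_required

    def __eq__(self, other):
        # Same underlying model AND same memory budget count as "already loaded".
        return self.model is other.model and self.memory_required == other.memory_required

model = object()
current_loaded_models = [LoadedModel(model, memory_required=0)]
print(LoadedModel(model, memory_required=0) in current_loaded_models)      # True: reuse
print(LoadedModel(model, memory_required=2**30) in current_loaded_models)  # False: reload path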