mirror of
https://github.com/lllyasviel/stable-diffusion-webui-forge.git
synced 2026-04-30 11:11:15 +00:00
Maintenance of loading-related code
1. Revise the order in which models are moved 2. Less verbose printing 3. Some miscellaneous minor speedups 4. Some bnb-related maintenance
This commit is contained in:
@@ -451,7 +451,9 @@ class LoadedModel:
|
|||||||
self.model_unload()
|
self.model_unload()
|
||||||
raise e
|
raise e
|
||||||
|
|
||||||
if not do_not_need_cpu_swap:
|
if do_not_need_cpu_swap:
|
||||||
|
print('All loaded to GPU.')
|
||||||
|
else:
|
||||||
gpu_modules, gpu_modules_only_extras, cpu_modules = build_module_profile(self.real_model, model_gpu_memory_when_using_cpu_swap)
|
gpu_modules, gpu_modules_only_extras, cpu_modules = build_module_profile(self.real_model, model_gpu_memory_when_using_cpu_swap)
|
||||||
pin_memory = PIN_SHARED_MEMORY and is_device_cpu(self.model.offload_device)
|
pin_memory = PIN_SHARED_MEMORY and is_device_cpu(self.model.offload_device)
|
||||||
|
|
||||||
@@ -465,10 +467,6 @@ class LoadedModel:
|
|||||||
for m in cpu_modules:
|
for m in cpu_modules:
|
||||||
m.prev_parameters_manual_cast = m.parameters_manual_cast
|
m.prev_parameters_manual_cast = m.parameters_manual_cast
|
||||||
m.parameters_manual_cast = True
|
m.parameters_manual_cast = True
|
||||||
|
|
||||||
if hasattr(m, 'weight') and m.weight is not None and hasattr(m.weight, 'bnb_quantized') and not m.weight.bnb_quantized and self.device.type == 'cuda':
|
|
||||||
m.to(self.device) # Quantize happens here
|
|
||||||
|
|
||||||
m.to(self.model.offload_device)
|
m.to(self.model.offload_device)
|
||||||
if pin_memory:
|
if pin_memory:
|
||||||
m._apply(lambda x: x.pin_memory())
|
m._apply(lambda x: x.pin_memory())
|
||||||
@@ -477,10 +475,6 @@ class LoadedModel:
|
|||||||
for m in gpu_modules_only_extras:
|
for m in gpu_modules_only_extras:
|
||||||
m.prev_parameters_manual_cast = m.parameters_manual_cast
|
m.prev_parameters_manual_cast = m.parameters_manual_cast
|
||||||
m.parameters_manual_cast = True
|
m.parameters_manual_cast = True
|
||||||
|
|
||||||
if hasattr(m, 'weight') and m.weight is not None and hasattr(m.weight, 'bnb_quantized') and not m.weight.bnb_quantized and self.device.type == 'cuda':
|
|
||||||
m.to(self.device) # Quantize happens here
|
|
||||||
|
|
||||||
module_move(m, device=self.device, recursive=False, excluded_pattens=['weight'])
|
module_move(m, device=self.device, recursive=False, excluded_pattens=['weight'])
|
||||||
if hasattr(m, 'weight') and m.weight is not None:
|
if hasattr(m, 'weight') and m.weight is not None:
|
||||||
if pin_memory:
|
if pin_memory:
|
||||||
@@ -492,14 +486,15 @@ class LoadedModel:
|
|||||||
|
|
||||||
swap_flag = 'Shared' if PIN_SHARED_MEMORY else 'CPU'
|
swap_flag = 'Shared' if PIN_SHARED_MEMORY else 'CPU'
|
||||||
method_flag = 'asynchronous' if stream.should_use_stream() else 'blocked'
|
method_flag = 'asynchronous' if stream.should_use_stream() else 'blocked'
|
||||||
print(f"[Memory Management] Loaded to {swap_flag} Swap: {swap_counter / (1024 * 1024):.2f} MB ({method_flag} method)")
|
print(f"{swap_flag} Swap Loaded ({method_flag} method): {swap_counter / (1024 * 1024):.2f} MB, GPU Loaded: {mem_counter / (1024 * 1024):.2f} MB")
|
||||||
print(f"[Memory Management] Loaded to GPU: {mem_counter / (1024 * 1024):.2f} MB")
|
|
||||||
|
|
||||||
self.model_accelerated = True
|
self.model_accelerated = True
|
||||||
|
|
||||||
global signal_empty_cache
|
global signal_empty_cache
|
||||||
signal_empty_cache = True
|
signal_empty_cache = True
|
||||||
|
|
||||||
|
self.model.lora_loader.refresh(offload_device=self.model.offload_device)
|
||||||
|
|
||||||
if is_intel_xpu() and not args.disable_ipex_hijack:
|
if is_intel_xpu() and not args.disable_ipex_hijack:
|
||||||
self.real_model = torch.xpu.optimize(self.real_model.eval(), inplace=True, auto_kernel_selection=True, graph_mode=True)
|
self.real_model = torch.xpu.optimize(self.real_model.eval(), inplace=True, auto_kernel_selection=True, graph_mode=True)
|
||||||
|
|
||||||
@@ -548,23 +543,23 @@ def unload_model_clones(model):
|
|||||||
def free_memory(memory_required, device, keep_loaded=[], free_all=False):
|
def free_memory(memory_required, device, keep_loaded=[], free_all=False):
|
||||||
if free_all:
|
if free_all:
|
||||||
memory_required = 1e30
|
memory_required = 1e30
|
||||||
print(f"[Unload] Trying to free all memory for {device} with {len(keep_loaded)} models keep loaded ...")
|
print(f"[Unload] Trying to free all memory for {device} with {len(keep_loaded)} models keep loaded ... ", end="")
|
||||||
else:
|
else:
|
||||||
print(f"[Unload] Trying to free {memory_required / (1024 * 1024):.2f} MB for {device} with {len(keep_loaded)} models keep loaded ...")
|
print(f"[Unload] Trying to free {memory_required / (1024 * 1024):.2f} MB for {device} with {len(keep_loaded)} models keep loaded ... ", end="")
|
||||||
|
|
||||||
offload_everything = ALWAYS_VRAM_OFFLOAD or vram_state == VRAMState.NO_VRAM
|
offload_everything = ALWAYS_VRAM_OFFLOAD or vram_state == VRAMState.NO_VRAM
|
||||||
unloaded_model = False
|
unloaded_model = False
|
||||||
for i in range(len(current_loaded_models) - 1, -1, -1):
|
for i in range(len(current_loaded_models) - 1, -1, -1):
|
||||||
if not offload_everything:
|
if not offload_everything:
|
||||||
free_memory = get_free_memory(device)
|
free_memory = get_free_memory(device)
|
||||||
print(f"[Unload] Current free memory is {free_memory / (1024 * 1024):.2f} MB ... ")
|
print(f"Current free memory is {free_memory / (1024 * 1024):.2f} MB ... ", end="")
|
||||||
if free_memory > memory_required:
|
if free_memory > memory_required:
|
||||||
break
|
break
|
||||||
shift_model = current_loaded_models[i]
|
shift_model = current_loaded_models[i]
|
||||||
if shift_model.device == device:
|
if shift_model.device == device:
|
||||||
if shift_model not in keep_loaded:
|
if shift_model not in keep_loaded:
|
||||||
m = current_loaded_models.pop(i)
|
m = current_loaded_models.pop(i)
|
||||||
print(f"[Unload] Unload model {m.model.model.__class__.__name__}")
|
print(f"Unload model {m.model.model.__class__.__name__} ", end="")
|
||||||
m.model_unload()
|
m.model_unload()
|
||||||
del m
|
del m
|
||||||
unloaded_model = True
|
unloaded_model = True
|
||||||
@@ -577,6 +572,9 @@ def free_memory(memory_required, device, keep_loaded=[], free_all=False):
|
|||||||
if mem_free_torch > mem_free_total * 0.25:
|
if mem_free_torch > mem_free_total * 0.25:
|
||||||
soft_empty_cache()
|
soft_empty_cache()
|
||||||
|
|
||||||
|
print('Done.')
|
||||||
|
return
|
||||||
|
|
||||||
|
|
||||||
def compute_model_gpu_memory_when_using_cpu_swap(current_free_mem, inference_memory):
|
def compute_model_gpu_memory_when_using_cpu_swap(current_free_mem, inference_memory):
|
||||||
maximum_memory_available = current_free_mem - inference_memory
|
maximum_memory_available = current_free_mem - inference_memory
|
||||||
@@ -605,8 +603,6 @@ def load_models_gpu(models, memory_required=0):
|
|||||||
current_loaded_models.insert(0, current_loaded_models.pop(index))
|
current_loaded_models.insert(0, current_loaded_models.pop(index))
|
||||||
models_already_loaded.append(loaded_model)
|
models_already_loaded.append(loaded_model)
|
||||||
else:
|
else:
|
||||||
if hasattr(x, "model"):
|
|
||||||
print(f"To load target model {x.model.__class__.__name__}")
|
|
||||||
models_to_load.append(loaded_model)
|
models_to_load.append(loaded_model)
|
||||||
|
|
||||||
if len(models_to_load) == 0:
|
if len(models_to_load) == 0:
|
||||||
@@ -621,8 +617,6 @@ def load_models_gpu(models, memory_required=0):
|
|||||||
|
|
||||||
return
|
return
|
||||||
|
|
||||||
print(f"Begin to load {len(models_to_load)} model{'s' if len(models_to_load) > 1 else ''}")
|
|
||||||
|
|
||||||
total_memory_required = {}
|
total_memory_required = {}
|
||||||
for loaded_model in models_to_load:
|
for loaded_model in models_to_load:
|
||||||
unload_model_clones(loaded_model.model)
|
unload_model_clones(loaded_model.model)
|
||||||
@@ -648,10 +642,7 @@ def load_models_gpu(models, memory_required=0):
|
|||||||
inference_memory = minimum_inference_memory()
|
inference_memory = minimum_inference_memory()
|
||||||
estimated_remaining_memory = current_free_mem - model_memory - inference_memory
|
estimated_remaining_memory = current_free_mem - model_memory - inference_memory
|
||||||
|
|
||||||
print(f"[Memory Management] Current Free GPU Memory: {current_free_mem / (1024 * 1024):.2f} MB")
|
print(f"[Memory Management] Target: {loaded_model.model.__class__.__name__}, Free GPU: {current_free_mem / (1024 * 1024):.2f} MB, Model Require: {model_memory / (1024 * 1024):.2f} MB, Inference Require: {inference_memory / (1024 * 1024):.2f} MB, Remaining: {estimated_remaining_memory / (1024 * 1024):.2f} MB, ", end="")
|
||||||
print(f"[Memory Management] Required Model Memory: {model_memory / (1024 * 1024):.2f} MB")
|
|
||||||
print(f"[Memory Management] Required Inference Memory: {inference_memory / (1024 * 1024):.2f} MB")
|
|
||||||
print(f"[Memory Management] Estimated Remaining GPU Memory: {estimated_remaining_memory / (1024 * 1024):.2f} MB")
|
|
||||||
|
|
||||||
if estimated_remaining_memory < 0:
|
if estimated_remaining_memory < 0:
|
||||||
vram_set_state = VRAMState.LOW_VRAM
|
vram_set_state = VRAMState.LOW_VRAM
|
||||||
|
|||||||
@@ -15,7 +15,20 @@ def functional_linear_4bits(x, weight, bias):
|
|||||||
|
|
||||||
|
|
||||||
def functional_dequantize_4bit(weight):
|
def functional_dequantize_4bit(weight):
|
||||||
return dequantize_4bit(weight, quant_state=weight.quant_state, blocksize=weight.blocksize, quant_type=weight.quant_type)
|
if not weight.bnb_quantized:
|
||||||
|
return weight
|
||||||
|
|
||||||
|
weight_original_device = weight.device
|
||||||
|
|
||||||
|
if weight_original_device.type != 'cuda':
|
||||||
|
weight = weight.cuda()
|
||||||
|
|
||||||
|
weight = dequantize_4bit(weight, quant_state=weight.quant_state, blocksize=weight.blocksize, quant_type=weight.quant_type)
|
||||||
|
|
||||||
|
if weight_original_device.type != 'cuda':
|
||||||
|
weight = weight.to(device=weight_original_device)
|
||||||
|
|
||||||
|
return weight
|
||||||
|
|
||||||
|
|
||||||
def copy_quant_state(state: QuantState, device: torch.device = None) -> QuantState:
|
def copy_quant_state(state: QuantState, device: torch.device = None) -> QuantState:
|
||||||
@@ -140,7 +153,8 @@ class ForgeLoader4Bit(torch.nn.Module):
|
|||||||
super()._load_from_state_dict(state_dict, prefix, local_metadata, strict, missing_keys, unexpected_keys, error_msgs)
|
super()._load_from_state_dict(state_dict, prefix, local_metadata, strict, missing_keys, unexpected_keys, error_msgs)
|
||||||
|
|
||||||
def reload_weight(self, weight):
|
def reload_weight(self, weight):
|
||||||
self.weight = ForgeParams4bit(
|
weight_original_device = weight.device
|
||||||
|
weight = ForgeParams4bit(
|
||||||
weight,
|
weight,
|
||||||
requires_grad=False,
|
requires_grad=False,
|
||||||
compress_statistics=self.weight.compress_statistics,
|
compress_statistics=self.weight.compress_statistics,
|
||||||
@@ -149,4 +163,9 @@ class ForgeLoader4Bit(torch.nn.Module):
|
|||||||
quant_storage=self.weight.quant_storage,
|
quant_storage=self.weight.quant_storage,
|
||||||
bnb_quantized=False
|
bnb_quantized=False
|
||||||
)
|
)
|
||||||
|
if weight_original_device.type == 'cuda':
|
||||||
|
weight = weight.to(weight_original_device)
|
||||||
|
else:
|
||||||
|
weight = weight.cuda().to(weight_original_device)
|
||||||
|
self.weight = weight
|
||||||
return self
|
return self
|
||||||
|
|||||||
@@ -223,10 +223,9 @@ class ModelPatcher:
|
|||||||
|
|
||||||
utils.set_attr_raw(self.model, k, item)
|
utils.set_attr_raw(self.model, k, item)
|
||||||
|
|
||||||
self.lora_loader.refresh(target_device=target_device, offload_device=self.offload_device)
|
|
||||||
|
|
||||||
if target_device is not None:
|
if target_device is not None:
|
||||||
self.model.to(target_device)
|
self.model.to(target_device)
|
||||||
|
self.current_device = target_device
|
||||||
|
|
||||||
return self.model
|
return self.model
|
||||||
|
|
||||||
|
|||||||
@@ -264,6 +264,22 @@ def merge_lora_to_weight(patches, weight, key="online_lora", computation_dtype=t
|
|||||||
return weight
|
return weight
|
||||||
|
|
||||||
|
|
||||||
|
def get_parameter_devices(model):
    """Snapshot the device of every named parameter of ``model``.

    Args:
        model: Object exposing ``named_parameters()`` that yields
            ``(key, parameter)`` pairs, where each parameter has a
            ``.device`` attribute.

    Returns:
        A dict mapping each parameter key to that parameter's device,
        in the order ``named_parameters()`` yields them.
    """
    return {key: p.device for key, p in model.named_parameters()}
|
||||||
|
|
||||||
|
|
||||||
|
def set_parameter_devices(model, parameter_devices):
    """Put each listed parameter of ``model`` on its requested device.

    Args:
        model: The object whose attributes are looked up and replaced by key.
        parameter_devices: Mapping of parameter key to the device that
            parameter should live on.

    Returns:
        The same ``model`` object that was passed in.
    """
    for name, wanted_device in parameter_devices.items():
        current = utils.get_attr(model, name)
        if current.device == wanted_device:
            # Already where it should be; leave the attribute untouched.
            continue
        # Move the tensor, re-wrap it (presumably as a Parameter -- see
        # utils.tensor2parameter), and install it back under the same key.
        relocated = utils.tensor2parameter(current.to(device=wanted_device))
        utils.set_attr_raw(model, name, relocated)
    return model
|
||||||
|
|
||||||
|
|
||||||
from backend import operations
|
from backend import operations
|
||||||
|
|
||||||
|
|
||||||
@@ -314,7 +330,7 @@ class LoraLoader:
|
|||||||
return list(p)
|
return list(p)
|
||||||
|
|
||||||
@torch.inference_mode()
|
@torch.inference_mode()
|
||||||
def refresh(self, target_device=None, offload_device=torch.device('cpu')):
|
def refresh(self, offload_device=torch.device('cpu')):
|
||||||
if not self.dirty:
|
if not self.dirty:
|
||||||
return
|
return
|
||||||
|
|
||||||
@@ -322,6 +338,12 @@ class LoraLoader:
|
|||||||
|
|
||||||
execution_start_time = time.perf_counter()
|
execution_start_time = time.perf_counter()
|
||||||
|
|
||||||
|
# Initialize
|
||||||
|
|
||||||
|
memory_management.signal_empty_cache = True
|
||||||
|
|
||||||
|
parameter_devices = get_parameter_devices(self.model)
|
||||||
|
|
||||||
# Restore
|
# Restore
|
||||||
|
|
||||||
for m in set(self.online_backup):
|
for m in set(self.online_backup):
|
||||||
@@ -338,26 +360,17 @@ class LoraLoader:
|
|||||||
|
|
||||||
self.backup = {}
|
self.backup = {}
|
||||||
|
|
||||||
if len(self.patches) > 0:
|
set_parameter_devices(self.model, parameter_devices=parameter_devices)
|
||||||
if self.online_mode:
|
|
||||||
print('Patching LoRA in on-the-fly.')
|
|
||||||
else:
|
|
||||||
print('Patching LoRA by precomputing model weights.')
|
|
||||||
|
|
||||||
# Patch
|
# Patch
|
||||||
|
|
||||||
memory_management.signal_empty_cache = True
|
for key, current_patches in self.patches.items():
|
||||||
|
|
||||||
for key, current_patches in (tqdm(self.patches.items(), desc=f'Patching LoRAs for {type(self.model).__name__}') if len(self.patches) > 0 else self.patches):
|
|
||||||
try:
|
try:
|
||||||
parent_layer, child_key, weight = utils.get_attr_with_parent(self.model, key)
|
parent_layer, child_key, weight = utils.get_attr_with_parent(self.model, key)
|
||||||
assert isinstance(weight, torch.nn.Parameter)
|
assert isinstance(weight, torch.nn.Parameter)
|
||||||
except:
|
except:
|
||||||
raise ValueError(f"Wrong LoRA Key: {key}")
|
raise ValueError(f"Wrong LoRA Key: {key}")
|
||||||
|
|
||||||
if key not in self.backup:
|
|
||||||
self.backup[key] = weight.to(device=offload_device)
|
|
||||||
|
|
||||||
if self.online_mode:
|
if self.online_mode:
|
||||||
if not hasattr(parent_layer, 'forge_online_loras'):
|
if not hasattr(parent_layer, 'forge_online_loras'):
|
||||||
parent_layer.forge_online_loras = {}
|
parent_layer.forge_online_loras = {}
|
||||||
@@ -366,36 +379,15 @@ class LoraLoader:
|
|||||||
self.online_backup.append(parent_layer)
|
self.online_backup.append(parent_layer)
|
||||||
continue
|
continue
|
||||||
|
|
||||||
|
if key not in self.backup:
|
||||||
|
self.backup[key] = weight.to(device=offload_device)
|
||||||
|
|
||||||
bnb_layer = None
|
bnb_layer = None
|
||||||
|
|
||||||
if operations.bnb_avaliable:
|
if hasattr(weight, 'bnb_quantized') and operations.bnb_avaliable:
|
||||||
if hasattr(weight, 'bnb_quantized'):
|
bnb_layer = parent_layer
|
||||||
bnb_layer = parent_layer
|
from backend.operations_bnb import functional_dequantize_4bit
|
||||||
if weight.bnb_quantized:
|
weight = functional_dequantize_4bit(weight)
|
||||||
weight_original_device = weight.device
|
|
||||||
|
|
||||||
if target_device is not None:
|
|
||||||
assert target_device.type == 'cuda', 'BNB Must use CUDA!'
|
|
||||||
weight = weight.to(target_device)
|
|
||||||
else:
|
|
||||||
weight = weight.cuda()
|
|
||||||
|
|
||||||
from backend.operations_bnb import functional_dequantize_4bit
|
|
||||||
weight = functional_dequantize_4bit(weight)
|
|
||||||
|
|
||||||
if target_device is None:
|
|
||||||
weight = weight.to(device=weight_original_device)
|
|
||||||
else:
|
|
||||||
weight = weight.data
|
|
||||||
|
|
||||||
if target_device is not None:
|
|
||||||
try:
|
|
||||||
weight = weight.to(device=target_device)
|
|
||||||
except:
|
|
||||||
print('Moving layer weight failed. Retrying by offloading models.')
|
|
||||||
self.model.to(device=offload_device)
|
|
||||||
memory_management.soft_empty_cache()
|
|
||||||
weight = weight.to(device=target_device)
|
|
||||||
|
|
||||||
gguf_cls, gguf_type, gguf_real_shape = None, None, None
|
gguf_cls, gguf_type, gguf_real_shape = None, None, None
|
||||||
|
|
||||||
@@ -409,8 +401,8 @@ class LoraLoader:
|
|||||||
try:
|
try:
|
||||||
weight = merge_lora_to_weight(current_patches, weight, key, computation_dtype=torch.float32)
|
weight = merge_lora_to_weight(current_patches, weight, key, computation_dtype=torch.float32)
|
||||||
except:
|
except:
|
||||||
print('Patching LoRA weights failed. Retrying by offloading models.')
|
print('Patching LoRA weights out of memory. Retrying by offloading models.')
|
||||||
self.model.to(device=offload_device)
|
set_parameter_devices(self.model, parameter_devices={k: offload_device for k in parameter_devices.keys()})
|
||||||
memory_management.soft_empty_cache()
|
memory_management.soft_empty_cache()
|
||||||
weight = merge_lora_to_weight(current_patches, weight, key, computation_dtype=torch.float32)
|
weight = merge_lora_to_weight(current_patches, weight, key, computation_dtype=torch.float32)
|
||||||
|
|
||||||
@@ -436,9 +428,14 @@ class LoraLoader:
|
|||||||
|
|
||||||
# Time
|
# Time
|
||||||
|
|
||||||
|
set_parameter_devices(self.model, parameter_devices=parameter_devices)
|
||||||
|
|
||||||
moving_time = time.perf_counter() - execution_start_time
|
moving_time = time.perf_counter() - execution_start_time
|
||||||
|
|
||||||
if moving_time > 0.1:
|
if len(self.patches) > 0:
|
||||||
print(f'LoRA patching has taken {moving_time:.2f} seconds')
|
if self.online_mode:
|
||||||
|
print(f'Patching LoRA on-the-fly in {moving_time:.2f} seconds.')
|
||||||
|
else:
|
||||||
|
print(f'Patching LoRA by precomputing model weights in {moving_time:.2f} seconds.')
|
||||||
|
|
||||||
return
|
return
|
||||||
|
|||||||
Reference in New Issue
Block a user