mirror of
https://github.com/lllyasviel/stable-diffusion-webui-forge.git
synced 2026-03-02 19:49:48 +00:00
maintain loading related
1. revise model moving orders 2. less verbose printing 3. some misc minor speedups 4. some bnb related maintain
This commit is contained in:
@@ -451,7 +451,9 @@ class LoadedModel:
|
||||
self.model_unload()
|
||||
raise e
|
||||
|
||||
if not do_not_need_cpu_swap:
|
||||
if do_not_need_cpu_swap:
|
||||
print('All loaded to GPU.')
|
||||
else:
|
||||
gpu_modules, gpu_modules_only_extras, cpu_modules = build_module_profile(self.real_model, model_gpu_memory_when_using_cpu_swap)
|
||||
pin_memory = PIN_SHARED_MEMORY and is_device_cpu(self.model.offload_device)
|
||||
|
||||
@@ -465,10 +467,6 @@ class LoadedModel:
|
||||
for m in cpu_modules:
|
||||
m.prev_parameters_manual_cast = m.parameters_manual_cast
|
||||
m.parameters_manual_cast = True
|
||||
|
||||
if hasattr(m, 'weight') and m.weight is not None and hasattr(m.weight, 'bnb_quantized') and not m.weight.bnb_quantized and self.device.type == 'cuda':
|
||||
m.to(self.device) # Quantize happens here
|
||||
|
||||
m.to(self.model.offload_device)
|
||||
if pin_memory:
|
||||
m._apply(lambda x: x.pin_memory())
|
||||
@@ -477,10 +475,6 @@ class LoadedModel:
|
||||
for m in gpu_modules_only_extras:
|
||||
m.prev_parameters_manual_cast = m.parameters_manual_cast
|
||||
m.parameters_manual_cast = True
|
||||
|
||||
if hasattr(m, 'weight') and m.weight is not None and hasattr(m.weight, 'bnb_quantized') and not m.weight.bnb_quantized and self.device.type == 'cuda':
|
||||
m.to(self.device) # Quantize happens here
|
||||
|
||||
module_move(m, device=self.device, recursive=False, excluded_pattens=['weight'])
|
||||
if hasattr(m, 'weight') and m.weight is not None:
|
||||
if pin_memory:
|
||||
@@ -492,14 +486,15 @@ class LoadedModel:
|
||||
|
||||
swap_flag = 'Shared' if PIN_SHARED_MEMORY else 'CPU'
|
||||
method_flag = 'asynchronous' if stream.should_use_stream() else 'blocked'
|
||||
print(f"[Memory Management] Loaded to {swap_flag} Swap: {swap_counter / (1024 * 1024):.2f} MB ({method_flag} method)")
|
||||
print(f"[Memory Management] Loaded to GPU: {mem_counter / (1024 * 1024):.2f} MB")
|
||||
print(f"{swap_flag} Swap Loaded ({method_flag} method): {swap_counter / (1024 * 1024):.2f} MB, GPU Loaded: {mem_counter / (1024 * 1024):.2f} MB")
|
||||
|
||||
self.model_accelerated = True
|
||||
|
||||
global signal_empty_cache
|
||||
signal_empty_cache = True
|
||||
|
||||
self.model.lora_loader.refresh(offload_device=self.model.offload_device)
|
||||
|
||||
if is_intel_xpu() and not args.disable_ipex_hijack:
|
||||
self.real_model = torch.xpu.optimize(self.real_model.eval(), inplace=True, auto_kernel_selection=True, graph_mode=True)
|
||||
|
||||
@@ -548,23 +543,23 @@ def unload_model_clones(model):
|
||||
def free_memory(memory_required, device, keep_loaded=[], free_all=False):
|
||||
if free_all:
|
||||
memory_required = 1e30
|
||||
print(f"[Unload] Trying to free all memory for {device} with {len(keep_loaded)} models keep loaded ...")
|
||||
print(f"[Unload] Trying to free all memory for {device} with {len(keep_loaded)} models keep loaded ... ", end="")
|
||||
else:
|
||||
print(f"[Unload] Trying to free {memory_required / (1024 * 1024):.2f} MB for {device} with {len(keep_loaded)} models keep loaded ...")
|
||||
print(f"[Unload] Trying to free {memory_required / (1024 * 1024):.2f} MB for {device} with {len(keep_loaded)} models keep loaded ... ", end="")
|
||||
|
||||
offload_everything = ALWAYS_VRAM_OFFLOAD or vram_state == VRAMState.NO_VRAM
|
||||
unloaded_model = False
|
||||
for i in range(len(current_loaded_models) - 1, -1, -1):
|
||||
if not offload_everything:
|
||||
free_memory = get_free_memory(device)
|
||||
print(f"[Unload] Current free memory is {free_memory / (1024 * 1024):.2f} MB ... ")
|
||||
print(f"Current free memory is {free_memory / (1024 * 1024):.2f} MB ... ", end="")
|
||||
if free_memory > memory_required:
|
||||
break
|
||||
shift_model = current_loaded_models[i]
|
||||
if shift_model.device == device:
|
||||
if shift_model not in keep_loaded:
|
||||
m = current_loaded_models.pop(i)
|
||||
print(f"[Unload] Unload model {m.model.model.__class__.__name__}")
|
||||
print(f"Unload model {m.model.model.__class__.__name__} ", end="")
|
||||
m.model_unload()
|
||||
del m
|
||||
unloaded_model = True
|
||||
@@ -577,6 +572,9 @@ def free_memory(memory_required, device, keep_loaded=[], free_all=False):
|
||||
if mem_free_torch > mem_free_total * 0.25:
|
||||
soft_empty_cache()
|
||||
|
||||
print('Done.')
|
||||
return
|
||||
|
||||
|
||||
def compute_model_gpu_memory_when_using_cpu_swap(current_free_mem, inference_memory):
|
||||
maximum_memory_available = current_free_mem - inference_memory
|
||||
@@ -605,8 +603,6 @@ def load_models_gpu(models, memory_required=0):
|
||||
current_loaded_models.insert(0, current_loaded_models.pop(index))
|
||||
models_already_loaded.append(loaded_model)
|
||||
else:
|
||||
if hasattr(x, "model"):
|
||||
print(f"To load target model {x.model.__class__.__name__}")
|
||||
models_to_load.append(loaded_model)
|
||||
|
||||
if len(models_to_load) == 0:
|
||||
@@ -621,8 +617,6 @@ def load_models_gpu(models, memory_required=0):
|
||||
|
||||
return
|
||||
|
||||
print(f"Begin to load {len(models_to_load)} model{'s' if len(models_to_load) > 1 else ''}")
|
||||
|
||||
total_memory_required = {}
|
||||
for loaded_model in models_to_load:
|
||||
unload_model_clones(loaded_model.model)
|
||||
@@ -648,10 +642,7 @@ def load_models_gpu(models, memory_required=0):
|
||||
inference_memory = minimum_inference_memory()
|
||||
estimated_remaining_memory = current_free_mem - model_memory - inference_memory
|
||||
|
||||
print(f"[Memory Management] Current Free GPU Memory: {current_free_mem / (1024 * 1024):.2f} MB")
|
||||
print(f"[Memory Management] Required Model Memory: {model_memory / (1024 * 1024):.2f} MB")
|
||||
print(f"[Memory Management] Required Inference Memory: {inference_memory / (1024 * 1024):.2f} MB")
|
||||
print(f"[Memory Management] Estimated Remaining GPU Memory: {estimated_remaining_memory / (1024 * 1024):.2f} MB")
|
||||
print(f"[Memory Management] Target: {loaded_model.model.__class__.__name__}, Free GPU: {current_free_mem / (1024 * 1024):.2f} MB, Model Require: {model_memory / (1024 * 1024):.2f} MB, Inference Require: {inference_memory / (1024 * 1024):.2f} MB, Remaining: {estimated_remaining_memory / (1024 * 1024):.2f} MB, ", end="")
|
||||
|
||||
if estimated_remaining_memory < 0:
|
||||
vram_set_state = VRAMState.LOW_VRAM
|
||||
|
||||
@@ -15,7 +15,20 @@ def functional_linear_4bits(x, weight, bias):
|
||||
|
||||
|
||||
def functional_dequantize_4bit(weight):
|
||||
return dequantize_4bit(weight, quant_state=weight.quant_state, blocksize=weight.blocksize, quant_type=weight.quant_type)
|
||||
if not weight.bnb_quantized:
|
||||
return weight
|
||||
|
||||
weight_original_device = weight.device
|
||||
|
||||
if weight_original_device.type != 'cuda':
|
||||
weight = weight.cuda()
|
||||
|
||||
weight = dequantize_4bit(weight, quant_state=weight.quant_state, blocksize=weight.blocksize, quant_type=weight.quant_type)
|
||||
|
||||
if weight_original_device.type != 'cuda':
|
||||
weight = weight.to(device=weight_original_device)
|
||||
|
||||
return weight
|
||||
|
||||
|
||||
def copy_quant_state(state: QuantState, device: torch.device = None) -> QuantState:
|
||||
@@ -140,7 +153,8 @@ class ForgeLoader4Bit(torch.nn.Module):
|
||||
super()._load_from_state_dict(state_dict, prefix, local_metadata, strict, missing_keys, unexpected_keys, error_msgs)
|
||||
|
||||
def reload_weight(self, weight):
|
||||
self.weight = ForgeParams4bit(
|
||||
weight_original_device = weight.device
|
||||
weight = ForgeParams4bit(
|
||||
weight,
|
||||
requires_grad=False,
|
||||
compress_statistics=self.weight.compress_statistics,
|
||||
@@ -149,4 +163,9 @@ class ForgeLoader4Bit(torch.nn.Module):
|
||||
quant_storage=self.weight.quant_storage,
|
||||
bnb_quantized=False
|
||||
)
|
||||
if weight_original_device.type == 'cuda':
|
||||
weight = weight.to(weight_original_device)
|
||||
else:
|
||||
weight = weight.cuda().to(weight_original_device)
|
||||
self.weight = weight
|
||||
return self
|
||||
|
||||
@@ -223,10 +223,9 @@ class ModelPatcher:
|
||||
|
||||
utils.set_attr_raw(self.model, k, item)
|
||||
|
||||
self.lora_loader.refresh(target_device=target_device, offload_device=self.offload_device)
|
||||
|
||||
if target_device is not None:
|
||||
self.model.to(target_device)
|
||||
self.current_device = target_device
|
||||
|
||||
return self.model
|
||||
|
||||
|
||||
@@ -264,6 +264,22 @@ def merge_lora_to_weight(patches, weight, key="online_lora", computation_dtype=t
|
||||
return weight
|
||||
|
||||
|
||||
def get_parameter_devices(model):
|
||||
parameter_devices = {}
|
||||
for key, p in model.named_parameters():
|
||||
parameter_devices[key] = p.device
|
||||
return parameter_devices
|
||||
|
||||
|
||||
def set_parameter_devices(model, parameter_devices):
|
||||
for key, device in parameter_devices.items():
|
||||
p = utils.get_attr(model, key)
|
||||
if p.device != device:
|
||||
p = utils.tensor2parameter(p.to(device=device))
|
||||
utils.set_attr_raw(model, key, p)
|
||||
return model
|
||||
|
||||
|
||||
from backend import operations
|
||||
|
||||
|
||||
@@ -314,7 +330,7 @@ class LoraLoader:
|
||||
return list(p)
|
||||
|
||||
@torch.inference_mode()
|
||||
def refresh(self, target_device=None, offload_device=torch.device('cpu')):
|
||||
def refresh(self, offload_device=torch.device('cpu')):
|
||||
if not self.dirty:
|
||||
return
|
||||
|
||||
@@ -322,6 +338,12 @@ class LoraLoader:
|
||||
|
||||
execution_start_time = time.perf_counter()
|
||||
|
||||
# Initialize
|
||||
|
||||
memory_management.signal_empty_cache = True
|
||||
|
||||
parameter_devices = get_parameter_devices(self.model)
|
||||
|
||||
# Restore
|
||||
|
||||
for m in set(self.online_backup):
|
||||
@@ -338,26 +360,17 @@ class LoraLoader:
|
||||
|
||||
self.backup = {}
|
||||
|
||||
if len(self.patches) > 0:
|
||||
if self.online_mode:
|
||||
print('Patching LoRA in on-the-fly.')
|
||||
else:
|
||||
print('Patching LoRA by precomputing model weights.')
|
||||
set_parameter_devices(self.model, parameter_devices=parameter_devices)
|
||||
|
||||
# Patch
|
||||
|
||||
memory_management.signal_empty_cache = True
|
||||
|
||||
for key, current_patches in (tqdm(self.patches.items(), desc=f'Patching LoRAs for {type(self.model).__name__}') if len(self.patches) > 0 else self.patches):
|
||||
for key, current_patches in self.patches.items():
|
||||
try:
|
||||
parent_layer, child_key, weight = utils.get_attr_with_parent(self.model, key)
|
||||
assert isinstance(weight, torch.nn.Parameter)
|
||||
except:
|
||||
raise ValueError(f"Wrong LoRA Key: {key}")
|
||||
|
||||
if key not in self.backup:
|
||||
self.backup[key] = weight.to(device=offload_device)
|
||||
|
||||
if self.online_mode:
|
||||
if not hasattr(parent_layer, 'forge_online_loras'):
|
||||
parent_layer.forge_online_loras = {}
|
||||
@@ -366,36 +379,15 @@ class LoraLoader:
|
||||
self.online_backup.append(parent_layer)
|
||||
continue
|
||||
|
||||
if key not in self.backup:
|
||||
self.backup[key] = weight.to(device=offload_device)
|
||||
|
||||
bnb_layer = None
|
||||
|
||||
if operations.bnb_avaliable:
|
||||
if hasattr(weight, 'bnb_quantized'):
|
||||
bnb_layer = parent_layer
|
||||
if weight.bnb_quantized:
|
||||
weight_original_device = weight.device
|
||||
|
||||
if target_device is not None:
|
||||
assert target_device.type == 'cuda', 'BNB Must use CUDA!'
|
||||
weight = weight.to(target_device)
|
||||
else:
|
||||
weight = weight.cuda()
|
||||
|
||||
from backend.operations_bnb import functional_dequantize_4bit
|
||||
weight = functional_dequantize_4bit(weight)
|
||||
|
||||
if target_device is None:
|
||||
weight = weight.to(device=weight_original_device)
|
||||
else:
|
||||
weight = weight.data
|
||||
|
||||
if target_device is not None:
|
||||
try:
|
||||
weight = weight.to(device=target_device)
|
||||
except:
|
||||
print('Moving layer weight failed. Retrying by offloading models.')
|
||||
self.model.to(device=offload_device)
|
||||
memory_management.soft_empty_cache()
|
||||
weight = weight.to(device=target_device)
|
||||
if hasattr(weight, 'bnb_quantized') and operations.bnb_avaliable:
|
||||
bnb_layer = parent_layer
|
||||
from backend.operations_bnb import functional_dequantize_4bit
|
||||
weight = functional_dequantize_4bit(weight)
|
||||
|
||||
gguf_cls, gguf_type, gguf_real_shape = None, None, None
|
||||
|
||||
@@ -409,8 +401,8 @@ class LoraLoader:
|
||||
try:
|
||||
weight = merge_lora_to_weight(current_patches, weight, key, computation_dtype=torch.float32)
|
||||
except:
|
||||
print('Patching LoRA weights failed. Retrying by offloading models.')
|
||||
self.model.to(device=offload_device)
|
||||
print('Patching LoRA weights out of memory. Retrying by offloading models.')
|
||||
set_parameter_devices(self.model, parameter_devices={k: offload_device for k in parameter_devices.keys()})
|
||||
memory_management.soft_empty_cache()
|
||||
weight = merge_lora_to_weight(current_patches, weight, key, computation_dtype=torch.float32)
|
||||
|
||||
@@ -436,9 +428,14 @@ class LoraLoader:
|
||||
|
||||
# Time
|
||||
|
||||
set_parameter_devices(self.model, parameter_devices=parameter_devices)
|
||||
|
||||
moving_time = time.perf_counter() - execution_start_time
|
||||
|
||||
if moving_time > 0.1:
|
||||
print(f'LoRA patching has taken {moving_time:.2f} seconds')
|
||||
if len(self.patches) > 0:
|
||||
if self.online_mode:
|
||||
print(f'Patching LoRA on-the-fly in {moving_time:.2f} seconds.')
|
||||
else:
|
||||
print(f'Patching LoRA by precomputing model weights in {moving_time:.2f} seconds.')
|
||||
|
||||
return
|
||||
|
||||
Reference in New Issue
Block a user