Maintenance of loading-related code

1. revise model moving orders
2. less verbose printing
3. some misc minor speedups
4. some bnb-related maintenance
This commit is contained in:
layerdiffusion
2024-08-29 19:02:43 -07:00
parent c70fb38b0a
commit 95e16f7204
4 changed files with 77 additions and 71 deletions

View File

@@ -451,7 +451,9 @@ class LoadedModel:
self.model_unload()
raise e
if not do_not_need_cpu_swap:
if do_not_need_cpu_swap:
print('All loaded to GPU.')
else:
gpu_modules, gpu_modules_only_extras, cpu_modules = build_module_profile(self.real_model, model_gpu_memory_when_using_cpu_swap)
pin_memory = PIN_SHARED_MEMORY and is_device_cpu(self.model.offload_device)
@@ -465,10 +467,6 @@ class LoadedModel:
for m in cpu_modules:
m.prev_parameters_manual_cast = m.parameters_manual_cast
m.parameters_manual_cast = True
if hasattr(m, 'weight') and m.weight is not None and hasattr(m.weight, 'bnb_quantized') and not m.weight.bnb_quantized and self.device.type == 'cuda':
m.to(self.device) # Quantize happens here
m.to(self.model.offload_device)
if pin_memory:
m._apply(lambda x: x.pin_memory())
@@ -477,10 +475,6 @@ class LoadedModel:
for m in gpu_modules_only_extras:
m.prev_parameters_manual_cast = m.parameters_manual_cast
m.parameters_manual_cast = True
if hasattr(m, 'weight') and m.weight is not None and hasattr(m.weight, 'bnb_quantized') and not m.weight.bnb_quantized and self.device.type == 'cuda':
m.to(self.device) # Quantize happens here
module_move(m, device=self.device, recursive=False, excluded_pattens=['weight'])
if hasattr(m, 'weight') and m.weight is not None:
if pin_memory:
@@ -492,14 +486,15 @@ class LoadedModel:
swap_flag = 'Shared' if PIN_SHARED_MEMORY else 'CPU'
method_flag = 'asynchronous' if stream.should_use_stream() else 'blocked'
print(f"[Memory Management] Loaded to {swap_flag} Swap: {swap_counter / (1024 * 1024):.2f} MB ({method_flag} method)")
print(f"[Memory Management] Loaded to GPU: {mem_counter / (1024 * 1024):.2f} MB")
print(f"{swap_flag} Swap Loaded ({method_flag} method): {swap_counter / (1024 * 1024):.2f} MB, GPU Loaded: {mem_counter / (1024 * 1024):.2f} MB")
self.model_accelerated = True
global signal_empty_cache
signal_empty_cache = True
self.model.lora_loader.refresh(offload_device=self.model.offload_device)
if is_intel_xpu() and not args.disable_ipex_hijack:
self.real_model = torch.xpu.optimize(self.real_model.eval(), inplace=True, auto_kernel_selection=True, graph_mode=True)
@@ -548,23 +543,23 @@ def unload_model_clones(model):
def free_memory(memory_required, device, keep_loaded=[], free_all=False):
if free_all:
memory_required = 1e30
print(f"[Unload] Trying to free all memory for {device} with {len(keep_loaded)} models keep loaded ...")
print(f"[Unload] Trying to free all memory for {device} with {len(keep_loaded)} models keep loaded ... ", end="")
else:
print(f"[Unload] Trying to free {memory_required / (1024 * 1024):.2f} MB for {device} with {len(keep_loaded)} models keep loaded ...")
print(f"[Unload] Trying to free {memory_required / (1024 * 1024):.2f} MB for {device} with {len(keep_loaded)} models keep loaded ... ", end="")
offload_everything = ALWAYS_VRAM_OFFLOAD or vram_state == VRAMState.NO_VRAM
unloaded_model = False
for i in range(len(current_loaded_models) - 1, -1, -1):
if not offload_everything:
free_memory = get_free_memory(device)
print(f"[Unload] Current free memory is {free_memory / (1024 * 1024):.2f} MB ... ")
print(f"Current free memory is {free_memory / (1024 * 1024):.2f} MB ... ", end="")
if free_memory > memory_required:
break
shift_model = current_loaded_models[i]
if shift_model.device == device:
if shift_model not in keep_loaded:
m = current_loaded_models.pop(i)
print(f"[Unload] Unload model {m.model.model.__class__.__name__}")
print(f"Unload model {m.model.model.__class__.__name__} ", end="")
m.model_unload()
del m
unloaded_model = True
@@ -577,6 +572,9 @@ def free_memory(memory_required, device, keep_loaded=[], free_all=False):
if mem_free_torch > mem_free_total * 0.25:
soft_empty_cache()
print('Done.')
return
def compute_model_gpu_memory_when_using_cpu_swap(current_free_mem, inference_memory):
maximum_memory_available = current_free_mem - inference_memory
@@ -605,8 +603,6 @@ def load_models_gpu(models, memory_required=0):
current_loaded_models.insert(0, current_loaded_models.pop(index))
models_already_loaded.append(loaded_model)
else:
if hasattr(x, "model"):
print(f"To load target model {x.model.__class__.__name__}")
models_to_load.append(loaded_model)
if len(models_to_load) == 0:
@@ -621,8 +617,6 @@ def load_models_gpu(models, memory_required=0):
return
print(f"Begin to load {len(models_to_load)} model{'s' if len(models_to_load) > 1 else ''}")
total_memory_required = {}
for loaded_model in models_to_load:
unload_model_clones(loaded_model.model)
@@ -648,10 +642,7 @@ def load_models_gpu(models, memory_required=0):
inference_memory = minimum_inference_memory()
estimated_remaining_memory = current_free_mem - model_memory - inference_memory
print(f"[Memory Management] Current Free GPU Memory: {current_free_mem / (1024 * 1024):.2f} MB")
print(f"[Memory Management] Required Model Memory: {model_memory / (1024 * 1024):.2f} MB")
print(f"[Memory Management] Required Inference Memory: {inference_memory / (1024 * 1024):.2f} MB")
print(f"[Memory Management] Estimated Remaining GPU Memory: {estimated_remaining_memory / (1024 * 1024):.2f} MB")
print(f"[Memory Management] Target: {loaded_model.model.__class__.__name__}, Free GPU: {current_free_mem / (1024 * 1024):.2f} MB, Model Require: {model_memory / (1024 * 1024):.2f} MB, Inference Require: {inference_memory / (1024 * 1024):.2f} MB, Remaining: {estimated_remaining_memory / (1024 * 1024):.2f} MB, ", end="")
if estimated_remaining_memory < 0:
vram_set_state = VRAMState.LOW_VRAM

View File

@@ -15,7 +15,20 @@ def functional_linear_4bits(x, weight, bias):
def functional_dequantize_4bit(weight):
    """Dequantize a bitsandbytes 4-bit weight back into a dense tensor.

    If the weight has not actually been quantized yet (``bnb_quantized`` is
    False) it is returned unchanged.  ``dequantize_4bit`` only runs on CUDA,
    so a weight living on another device is moved to the GPU for the
    operation and moved back to its original device afterwards.
    """
    # Defect fixed: a stale unconditional `return dequantize_4bit(...)` line
    # before this guard made the rest of the function unreachable.
    if not weight.bnb_quantized:
        return weight

    weight_original_device = weight.device

    # dequantize_4bit requires a CUDA-resident tensor.
    if weight_original_device.type != 'cuda':
        weight = weight.cuda()

    weight = dequantize_4bit(weight, quant_state=weight.quant_state,
                             blocksize=weight.blocksize, quant_type=weight.quant_type)

    # Restore the caller's original device placement.
    if weight_original_device.type != 'cuda':
        weight = weight.to(device=weight_original_device)

    return weight
def copy_quant_state(state: QuantState, device: torch.device = None) -> QuantState:
@@ -140,7 +153,8 @@ class ForgeLoader4Bit(torch.nn.Module):
super()._load_from_state_dict(state_dict, prefix, local_metadata, strict, missing_keys, unexpected_keys, error_msgs)
def reload_weight(self, weight):
self.weight = ForgeParams4bit(
weight_original_device = weight.device
weight = ForgeParams4bit(
weight,
requires_grad=False,
compress_statistics=self.weight.compress_statistics,
@@ -149,4 +163,9 @@ class ForgeLoader4Bit(torch.nn.Module):
quant_storage=self.weight.quant_storage,
bnb_quantized=False
)
if weight_original_device.type == 'cuda':
weight = weight.to(weight_original_device)
else:
weight = weight.cuda().to(weight_original_device)
self.weight = weight
return self

View File

@@ -223,10 +223,9 @@ class ModelPatcher:
utils.set_attr_raw(self.model, k, item)
self.lora_loader.refresh(target_device=target_device, offload_device=self.offload_device)
if target_device is not None:
self.model.to(target_device)
self.current_device = target_device
return self.model

View File

@@ -264,6 +264,22 @@ def merge_lora_to_weight(patches, weight, key="online_lora", computation_dtype=t
return weight
def get_parameter_devices(model):
    """Snapshot the current device of every named parameter of *model*.

    Returns a dict mapping parameter name -> torch.device, suitable for
    later restoration via ``set_parameter_devices``.
    """
    return {name: param.device for name, param in model.named_parameters()}
def set_parameter_devices(model, parameter_devices):
    """Move each named parameter of *model* to its recorded target device.

    *parameter_devices* is a dict of parameter name -> torch.device, as
    produced by ``get_parameter_devices``.  Parameters already on their
    target device are left untouched.  Returns *model* for chaining.
    """
    for name, target_device in parameter_devices.items():
        param = utils.get_attr(model, name)
        if param.device == target_device:
            # Already in place — skip the move and the re-wrap.
            utils.set_attr_raw(model, name, param)
            continue
        moved = utils.tensor2parameter(param.to(device=target_device))
        utils.set_attr_raw(model, name, moved)
    return model
from backend import operations
@@ -314,7 +330,7 @@ class LoraLoader:
return list(p)
@torch.inference_mode()
def refresh(self, target_device=None, offload_device=torch.device('cpu')):
def refresh(self, offload_device=torch.device('cpu')):
if not self.dirty:
return
@@ -322,6 +338,12 @@ class LoraLoader:
execution_start_time = time.perf_counter()
# Initialize
memory_management.signal_empty_cache = True
parameter_devices = get_parameter_devices(self.model)
# Restore
for m in set(self.online_backup):
@@ -338,26 +360,17 @@ class LoraLoader:
self.backup = {}
if len(self.patches) > 0:
if self.online_mode:
print('Patching LoRA in on-the-fly.')
else:
print('Patching LoRA by precomputing model weights.')
set_parameter_devices(self.model, parameter_devices=parameter_devices)
# Patch
memory_management.signal_empty_cache = True
for key, current_patches in (tqdm(self.patches.items(), desc=f'Patching LoRAs for {type(self.model).__name__}') if len(self.patches) > 0 else self.patches):
for key, current_patches in self.patches.items():
try:
parent_layer, child_key, weight = utils.get_attr_with_parent(self.model, key)
assert isinstance(weight, torch.nn.Parameter)
except:
raise ValueError(f"Wrong LoRA Key: {key}")
if key not in self.backup:
self.backup[key] = weight.to(device=offload_device)
if self.online_mode:
if not hasattr(parent_layer, 'forge_online_loras'):
parent_layer.forge_online_loras = {}
@@ -366,36 +379,15 @@ class LoraLoader:
self.online_backup.append(parent_layer)
continue
if key not in self.backup:
self.backup[key] = weight.to(device=offload_device)
bnb_layer = None
if operations.bnb_avaliable:
if hasattr(weight, 'bnb_quantized'):
bnb_layer = parent_layer
if weight.bnb_quantized:
weight_original_device = weight.device
if target_device is not None:
assert target_device.type == 'cuda', 'BNB Must use CUDA!'
weight = weight.to(target_device)
else:
weight = weight.cuda()
from backend.operations_bnb import functional_dequantize_4bit
weight = functional_dequantize_4bit(weight)
if target_device is None:
weight = weight.to(device=weight_original_device)
else:
weight = weight.data
if target_device is not None:
try:
weight = weight.to(device=target_device)
except:
print('Moving layer weight failed. Retrying by offloading models.')
self.model.to(device=offload_device)
memory_management.soft_empty_cache()
weight = weight.to(device=target_device)
if hasattr(weight, 'bnb_quantized') and operations.bnb_avaliable:
bnb_layer = parent_layer
from backend.operations_bnb import functional_dequantize_4bit
weight = functional_dequantize_4bit(weight)
gguf_cls, gguf_type, gguf_real_shape = None, None, None
@@ -409,8 +401,8 @@ class LoraLoader:
try:
weight = merge_lora_to_weight(current_patches, weight, key, computation_dtype=torch.float32)
except:
print('Patching LoRA weights failed. Retrying by offloading models.')
self.model.to(device=offload_device)
print('Patching LoRA weights out of memory. Retrying by offloading models.')
set_parameter_devices(self.model, parameter_devices={k: offload_device for k in parameter_devices.keys()})
memory_management.soft_empty_cache()
weight = merge_lora_to_weight(current_patches, weight, key, computation_dtype=torch.float32)
@@ -436,9 +428,14 @@ class LoraLoader:
# Time
set_parameter_devices(self.model, parameter_devices=parameter_devices)
moving_time = time.perf_counter() - execution_start_time
if moving_time > 0.1:
print(f'LoRA patching has taken {moving_time:.2f} seconds')
if len(self.patches) > 0:
if self.online_mode:
print(f'Patching LoRA on-the-fly in {moving_time:.2f} seconds.')
else:
print(f'Patching LoRA by precomputing model weights in {moving_time:.2f} seconds.')
return