diff --git a/backend/memory_management.py b/backend/memory_management.py index a6cc9578..308f3f3a 100644 --- a/backend/memory_management.py +++ b/backend/memory_management.py @@ -451,7 +451,9 @@ class LoadedModel: self.model_unload() raise e - if not do_not_need_cpu_swap: + if do_not_need_cpu_swap: + print('All loaded to GPU.') + else: gpu_modules, gpu_modules_only_extras, cpu_modules = build_module_profile(self.real_model, model_gpu_memory_when_using_cpu_swap) pin_memory = PIN_SHARED_MEMORY and is_device_cpu(self.model.offload_device) @@ -465,10 +467,6 @@ class LoadedModel: for m in cpu_modules: m.prev_parameters_manual_cast = m.parameters_manual_cast m.parameters_manual_cast = True - - if hasattr(m, 'weight') and m.weight is not None and hasattr(m.weight, 'bnb_quantized') and not m.weight.bnb_quantized and self.device.type == 'cuda': - m.to(self.device) # Quantize happens here - m.to(self.model.offload_device) if pin_memory: m._apply(lambda x: x.pin_memory()) @@ -477,10 +475,6 @@ class LoadedModel: for m in gpu_modules_only_extras: m.prev_parameters_manual_cast = m.parameters_manual_cast m.parameters_manual_cast = True - - if hasattr(m, 'weight') and m.weight is not None and hasattr(m.weight, 'bnb_quantized') and not m.weight.bnb_quantized and self.device.type == 'cuda': - m.to(self.device) # Quantize happens here - module_move(m, device=self.device, recursive=False, excluded_pattens=['weight']) if hasattr(m, 'weight') and m.weight is not None: if pin_memory: @@ -492,14 +486,15 @@ class LoadedModel: swap_flag = 'Shared' if PIN_SHARED_MEMORY else 'CPU' method_flag = 'asynchronous' if stream.should_use_stream() else 'blocked' - print(f"[Memory Management] Loaded to {swap_flag} Swap: {swap_counter / (1024 * 1024):.2f} MB ({method_flag} method)") - print(f"[Memory Management] Loaded to GPU: {mem_counter / (1024 * 1024):.2f} MB") + print(f"{swap_flag} Swap Loaded ({method_flag} method): {swap_counter / (1024 * 1024):.2f} MB, GPU Loaded: {mem_counter / (1024 * 1024):.2f} MB") self.model_accelerated = True global signal_empty_cache signal_empty_cache = True + self.model.lora_loader.refresh(offload_device=self.model.offload_device) + if is_intel_xpu() and not args.disable_ipex_hijack: self.real_model = torch.xpu.optimize(self.real_model.eval(), inplace=True, auto_kernel_selection=True, graph_mode=True) @@ -548,23 +543,23 @@ def unload_model_clones(model): def free_memory(memory_required, device, keep_loaded=[], free_all=False): if free_all: memory_required = 1e30 - print(f"[Unload] Trying to free all memory for {device} with {len(keep_loaded)} models keep loaded ...") + print(f"[Unload] Trying to free all memory for {device} with {len(keep_loaded)} models keep loaded ... ", end="") else: - print(f"[Unload] Trying to free {memory_required / (1024 * 1024):.2f} MB for {device} with {len(keep_loaded)} models keep loaded ...") + print(f"[Unload] Trying to free {memory_required / (1024 * 1024):.2f} MB for {device} with {len(keep_loaded)} models keep loaded ... ", end="") offload_everything = ALWAYS_VRAM_OFFLOAD or vram_state == VRAMState.NO_VRAM unloaded_model = False for i in range(len(current_loaded_models) - 1, -1, -1): if not offload_everything: free_memory = get_free_memory(device) - print(f"[Unload] Current free memory is {free_memory / (1024 * 1024):.2f} MB ... ") + print(f"Current free memory is {free_memory / (1024 * 1024):.2f} MB ... ", end="") if free_memory > memory_required: break shift_model = current_loaded_models[i] if shift_model.device == device: if shift_model not in keep_loaded: m = current_loaded_models.pop(i) - print(f"[Unload] Unload model {m.model.model.__class__.__name__}") + print(f"Unload model {m.model.model.__class__.__name__} ", end="") m.model_unload() del m unloaded_model = True @@ -577,6 +572,9 @@ def free_memory(memory_required, device, keep_loaded=[], free_all=False): if mem_free_torch > mem_free_total * 0.25: soft_empty_cache() + print('Done.') + return + def compute_model_gpu_memory_when_using_cpu_swap(current_free_mem, inference_memory): maximum_memory_available = current_free_mem - inference_memory @@ -605,8 +603,6 @@ def load_models_gpu(models, memory_required=0): current_loaded_models.insert(0, current_loaded_models.pop(index)) models_already_loaded.append(loaded_model) else: - if hasattr(x, "model"): - print(f"To load target model {x.model.__class__.__name__}") models_to_load.append(loaded_model) if len(models_to_load) == 0: @@ -621,8 +617,6 @@ def load_models_gpu(models, memory_required=0): return - print(f"Begin to load {len(models_to_load)} model{'s' if len(models_to_load) > 1 else ''}") - total_memory_required = {} for loaded_model in models_to_load: unload_model_clones(loaded_model.model) @@ -648,10 +642,7 @@ def load_models_gpu(models, memory_required=0): inference_memory = minimum_inference_memory() estimated_remaining_memory = current_free_mem - model_memory - inference_memory - print(f"[Memory Management] Current Free GPU Memory: {current_free_mem / (1024 * 1024):.2f} MB") - print(f"[Memory Management] Required Model Memory: {model_memory / (1024 * 1024):.2f} MB") - print(f"[Memory Management] Required Inference Memory: {inference_memory / (1024 * 1024):.2f} MB") - print(f"[Memory Management] Estimated Remaining GPU Memory: {estimated_remaining_memory / (1024 * 1024):.2f} MB") + print(f"[Memory Management] Target: {loaded_model.model.__class__.__name__}, Free GPU: {current_free_mem / (1024 * 1024):.2f} MB, Model Require: {model_memory / (1024 * 1024):.2f} MB, Inference Require: {inference_memory / (1024 * 1024):.2f} MB, Remaining: {estimated_remaining_memory / (1024 * 1024):.2f} MB, ", end="") if estimated_remaining_memory < 0: vram_set_state = VRAMState.LOW_VRAM diff --git a/backend/operations_bnb.py b/backend/operations_bnb.py index 968f62a6..b45ab542 100644 --- a/backend/operations_bnb.py +++ b/backend/operations_bnb.py @@ -15,7 +15,20 @@ def functional_linear_4bits(x, weight, bias): def functional_dequantize_4bit(weight): - return dequantize_4bit(weight, quant_state=weight.quant_state, blocksize=weight.blocksize, quant_type=weight.quant_type) + if not weight.bnb_quantized: + return weight + + weight_original_device = weight.device + + if weight_original_device.type != 'cuda': + weight = weight.cuda() + + weight = dequantize_4bit(weight, quant_state=weight.quant_state, blocksize=weight.blocksize, quant_type=weight.quant_type) + + if weight_original_device.type != 'cuda': + weight = weight.to(device=weight_original_device) + + return weight def copy_quant_state(state: QuantState, device: torch.device = None) -> QuantState: @@ -140,7 +153,8 @@ class ForgeLoader4Bit(torch.nn.Module): super()._load_from_state_dict(state_dict, prefix, local_metadata, strict, missing_keys, unexpected_keys, error_msgs) def reload_weight(self, weight): - self.weight = ForgeParams4bit( + weight_original_device = weight.device + weight = ForgeParams4bit( weight, requires_grad=False, compress_statistics=self.weight.compress_statistics, @@ -149,4 +163,9 @@ class ForgeLoader4Bit(torch.nn.Module): quant_storage=self.weight.quant_storage, bnb_quantized=False ) + if weight_original_device.type == 'cuda': + weight = weight.to(weight_original_device) + else: + weight = weight.cuda().to(weight_original_device) + self.weight = weight return self diff --git a/backend/patcher/base.py b/backend/patcher/base.py index 39c28c60..1592e3d0 100644 --- a/backend/patcher/base.py +++ b/backend/patcher/base.py @@ -223,10 +223,9 @@ class ModelPatcher: utils.set_attr_raw(self.model, k, item) - self.lora_loader.refresh(target_device=target_device, offload_device=self.offload_device) - if target_device is not None: self.model.to(target_device) + self.current_device = target_device return self.model diff --git a/backend/patcher/lora.py b/backend/patcher/lora.py index cb18c871..5d6083a2 100644 --- a/backend/patcher/lora.py +++ b/backend/patcher/lora.py @@ -264,6 +264,22 @@ def merge_lora_to_weight(patches, weight, key="online_lora", computation_dtype=t return weight +def get_parameter_devices(model): + parameter_devices = {} + for key, p in model.named_parameters(): + parameter_devices[key] = p.device + return parameter_devices + + +def set_parameter_devices(model, parameter_devices): + for key, device in parameter_devices.items(): + p = utils.get_attr(model, key) + if p.device != device: + p = utils.tensor2parameter(p.to(device=device)) + utils.set_attr_raw(model, key, p) + return model + + from backend import operations @@ -314,7 +330,7 @@ class LoraLoader: return list(p) @torch.inference_mode() - def refresh(self, target_device=None, offload_device=torch.device('cpu')): + def refresh(self, offload_device=torch.device('cpu')): if not self.dirty: return @@ -322,6 +338,12 @@ class LoraLoader: execution_start_time = time.perf_counter() + # Initialize + + memory_management.signal_empty_cache = True + + parameter_devices = get_parameter_devices(self.model) + # Restore for m in set(self.online_backup): @@ -338,26 +360,17 @@ class LoraLoader: self.backup = {} - if len(self.patches) > 0: - if self.online_mode: - print('Patching LoRA in on-the-fly.') - else: - print('Patching LoRA by precomputing model weights.') + set_parameter_devices(self.model, parameter_devices=parameter_devices) # Patch - memory_management.signal_empty_cache = True - - for key, current_patches in (tqdm(self.patches.items(), desc=f'Patching LoRAs for {type(self.model).__name__}') if len(self.patches) > 0 else self.patches): + for key, current_patches in self.patches.items(): try: parent_layer, child_key, weight = utils.get_attr_with_parent(self.model, key) assert isinstance(weight, torch.nn.Parameter) except: raise ValueError(f"Wrong LoRA Key: {key}") - if key not in self.backup: - self.backup[key] = weight.to(device=offload_device) - if self.online_mode: if not hasattr(parent_layer, 'forge_online_loras'): parent_layer.forge_online_loras = {} @@ -366,36 +379,15 @@ class LoraLoader: self.online_backup.append(parent_layer) continue + if key not in self.backup: + self.backup[key] = weight.to(device=offload_device) + bnb_layer = None - if operations.bnb_avaliable: - if hasattr(weight, 'bnb_quantized'): - bnb_layer = parent_layer - if weight.bnb_quantized: - weight_original_device = weight.device - - if target_device is not None: - assert target_device.type == 'cuda', 'BNB Must use CUDA!' - weight = weight.to(target_device) - else: - weight = weight.cuda() - - from backend.operations_bnb import functional_dequantize_4bit - weight = functional_dequantize_4bit(weight) - - if target_device is None: - weight = weight.to(device=weight_original_device) - else: - weight = weight.data - - if target_device is not None: - try: - weight = weight.to(device=target_device) - except: - print('Moving layer weight failed. Retrying by offloading models.') - self.model.to(device=offload_device) - memory_management.soft_empty_cache() - weight = weight.to(device=target_device) + if hasattr(weight, 'bnb_quantized') and operations.bnb_avaliable: + bnb_layer = parent_layer + from backend.operations_bnb import functional_dequantize_4bit + weight = functional_dequantize_4bit(weight) gguf_cls, gguf_type, gguf_real_shape = None, None, None @@ -409,8 +401,8 @@ class LoraLoader: try: weight = merge_lora_to_weight(current_patches, weight, key, computation_dtype=torch.float32) except: - print('Patching LoRA weights failed. Retrying by offloading models.') - self.model.to(device=offload_device) + print('Patching LoRA weights out of memory. Retrying by offloading models.') + set_parameter_devices(self.model, parameter_devices={k: offload_device for k in parameter_devices.keys()}) memory_management.soft_empty_cache() weight = merge_lora_to_weight(current_patches, weight, key, computation_dtype=torch.float32) @@ -436,9 +428,14 @@ class LoraLoader: # Time + set_parameter_devices(self.model, parameter_devices=parameter_devices) + moving_time = time.perf_counter() - execution_start_time - if moving_time > 0.1: - print(f'LoRA patching has taken {moving_time:.2f} seconds') + if len(self.patches) > 0: + if self.online_mode: + print(f'Patching LoRA on-the-fly in {moving_time:.2f} seconds.') + else: + print(f'Patching LoRA by precomputing model weights in {moving_time:.2f} seconds.') return