diff --git a/backend/patcher/base.py b/backend/patcher/base.py
index c50beaae..12e651c9 100644
--- a/backend/patcher/base.py
+++ b/backend/patcher/base.py
@@ -256,9 +256,6 @@ class ModelPatcher:
             except:
                 raise ValueError(f"Wrong LoRA Key: {key}")
 
-            weight_original_device = weight.device
-            lora_computation_device = weight.device
-
             if key not in self.backup:
                 self.backup[key] = weight.to(device=self.offload_device)
 
@@ -269,6 +266,8 @@ class ModelPatcher:
                 assert weight.module is not None, 'BNB bad weight without parent layer!'
                 bnb_layer = weight.module
                 if weight.bnb_quantized:
+                    weight_original_device = weight.device
+
                     if target_device is not None:
                         assert target_device.type == 'cuda', 'BNB Must use CUDA!'
                         weight = weight.to(target_device)
@@ -277,14 +276,12 @@ class ModelPatcher:
 
                     from backend.operations_bnb import functional_dequantize_4bit
                     weight = functional_dequantize_4bit(weight)
+
+                    if target_device is None:
+                        weight = weight.to(device=weight_original_device)
                 else:
                     weight = weight.data
 
-            if target_device is None:
-                weight = weight.to(device=lora_computation_device, non_blocking=memory_management.device_supports_non_blocking(lora_computation_device))
-            else:
-                weight = weight.to(device=target_device, non_blocking=memory_management.device_supports_non_blocking(target_device))
-
             gguf_cls, gguf_type, gguf_real_shape = None, None, None
 
             if hasattr(weight, 'is_gguf'):
@@ -295,11 +292,13 @@ class ModelPatcher:
                 weight = dequantize_tensor(weight)
 
             weight_original_dtype = weight.dtype
-            weight = weight.to(dtype=torch.float32, non_blocking=memory_management.device_supports_non_blocking(weight.device))
-            weight = merge_lora_to_model_weight(current_patches, weight, key).to(dtype=weight_original_dtype)
 
-            if target_device is None:
-                weight = weight.to(device=weight_original_device, non_blocking=memory_management.device_supports_non_blocking(weight_original_device))
+            if target_device is not None:
+                weight = weight.to(device=target_device, non_blocking=memory_management.device_supports_non_blocking(target_device))
+
+            weight = weight.to(dtype=torch.float32)
+
+            weight = merge_lora_to_model_weight(current_patches, weight, key).to(dtype=weight_original_dtype)
 
             if bnb_layer is not None:
                 bnb_layer.reload_weight(weight)
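The diff changes the per-key merge order: instead of moving the weight to a separate `lora_computation_device` and restoring it afterwards, the weight is now moved to `target_device` (when one is given), upcast to float32, merged, and cast back to its original dtype; for BNB-quantized weights, the original device is recorded inside the `bnb_quantized` branch and restored right after dequantization when no `target_device` is given. The snippet below is a minimal sketch of that new flow for an ordinary (non-BNB, non-GGUF) tensor; `merge_lora_to_model_weight` and `memory_management.device_supports_non_blocking` are names from the diff, while the stand-ins here (`device_supports_non_blocking`, `fake_merge`, `apply_patches_to_weight`) are hypothetical so the example runs on its own.

```python
# Minimal sketch of the post-patch merge flow for a plain weight tensor.
# The helpers below are stand-ins, not the real backend implementations.
from typing import List, Optional

import torch


def device_supports_non_blocking(device: torch.device) -> bool:
    # Stand-in for memory_management.device_supports_non_blocking:
    # non-blocking copies only pay off when a CUDA device is involved.
    return device.type == "cuda"


def fake_merge(patches: List[torch.Tensor], weight: torch.Tensor, key: str) -> torch.Tensor:
    # Stand-in for merge_lora_to_model_weight: add each delta to the fp32 weight.
    for delta in patches:
        weight = weight + delta.to(weight)
    return weight


def apply_patches_to_weight(weight: torch.Tensor,
                            patches: List[torch.Tensor],
                            key: str,
                            target_device: Optional[torch.device] = None) -> torch.Tensor:
    weight_original_dtype = weight.dtype

    # Order introduced by the patch: move to the target device first (if given)...
    if target_device is not None:
        weight = weight.to(device=target_device,
                           non_blocking=device_supports_non_blocking(target_device))

    # ...then upcast to float32, merge the LoRA patches, and cast back.
    weight = weight.to(dtype=torch.float32)
    weight = fake_merge(patches, weight, key).to(dtype=weight_original_dtype)
    return weight


if __name__ == "__main__":
    w = torch.randn(4, 4, dtype=torch.float16)
    deltas = [0.1 * torch.randn(4, 4)]
    merged = apply_patches_to_weight(w, deltas, "model.layer.weight")
    print(merged.dtype, merged.device)  # torch.float16 cpu
```

With no `target_device`, the tensor stays on its current device for the merge, which is why the removed `weight_original_device` / `lora_computation_device` bookkeeping at the top of the loop is no longer needed outside the BNB branch.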