From 021428da260daa41c77173885ada504bd22b8ed1 Mon Sep 17 00:00:00 2001
From: layerdiffusion <19834515+lllyasviel@users.noreply.github.com>
Date: Thu, 15 Aug 2024 06:35:15 -0700
Subject: [PATCH] fix nf4 lora gives pure noise on some devices

---
 backend/patcher/base.py | 23 +++++++++++------------
 1 file changed, 11 insertions(+), 12 deletions(-)

diff --git a/backend/patcher/base.py b/backend/patcher/base.py
index c50beaae..12e651c9 100644
--- a/backend/patcher/base.py
+++ b/backend/patcher/base.py
@@ -256,9 +256,6 @@ class ModelPatcher:
             except:
                 raise ValueError(f"Wrong LoRA Key: {key}")
 
-            weight_original_device = weight.device
-            lora_computation_device = weight.device
-
             if key not in self.backup:
                 self.backup[key] = weight.to(device=self.offload_device)
 
@@ -269,6 +266,8 @@ class ModelPatcher:
                 assert weight.module is not None, 'BNB bad weight without parent layer!'
                 bnb_layer = weight.module
                 if weight.bnb_quantized:
+                    weight_original_device = weight.device
+
                     if target_device is not None:
                         assert target_device.type == 'cuda', 'BNB Must use CUDA!'
                         weight = weight.to(target_device)
@@ -277,14 +276,12 @@ class ModelPatcher:
 
                     from backend.operations_bnb import functional_dequantize_4bit
                     weight = functional_dequantize_4bit(weight)
+
+                    if target_device is None:
+                        weight = weight.to(device=weight_original_device)
                 else:
                     weight = weight.data
 
-            if target_device is None:
-                weight = weight.to(device=lora_computation_device, non_blocking=memory_management.device_supports_non_blocking(lora_computation_device))
-            else:
-                weight = weight.to(device=target_device, non_blocking=memory_management.device_supports_non_blocking(target_device))
-
             gguf_cls, gguf_type, gguf_real_shape = None, None, None
 
             if hasattr(weight, 'is_gguf'):
@@ -295,11 +292,13 @@ class ModelPatcher:
                 weight = dequantize_tensor(weight)
 
             weight_original_dtype = weight.dtype
-            weight = weight.to(dtype=torch.float32, non_blocking=memory_management.device_supports_non_blocking(weight.device))
-            weight = merge_lora_to_model_weight(current_patches, weight, key).to(dtype=weight_original_dtype)
 
-            if target_device is None:
-                weight = weight.to(device=weight_original_device, non_blocking=memory_management.device_supports_non_blocking(weight_original_device))
+            if target_device is not None:
+                weight = weight.to(device=target_device, non_blocking=memory_management.device_supports_non_blocking(target_device))
+
+            weight = weight.to(dtype=torch.float32)
+
+            weight = merge_lora_to_model_weight(current_patches, weight, key).to(dtype=weight_original_dtype)
 
             if bnb_layer is not None:
                 bnb_layer.reload_weight(weight)
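
For reference, below is a minimal standalone sketch (not part of the patch) of the ordering the diff establishes: dequantize the NF4 weight, return it to its original device when no target device is given (or move it to the target device once, before the merge), cast to float32, merge the LoRA delta, then restore the original dtype. dequantize_nf4 and merge_lora are hypothetical stand-ins for functional_dequantize_4bit and merge_lora_to_model_weight; the real code also handles bitsandbytes layer reloading and non-blocking transfers, which are omitted here.

import torch
from typing import Optional


def dequantize_nf4(weight: torch.Tensor) -> torch.Tensor:
    # Placeholder: the real dequantizer unpacks bitsandbytes 4-bit blocks.
    return weight.clone()


def merge_lora(weight: torch.Tensor, up: torch.Tensor, down: torch.Tensor,
               alpha: float, target_device: Optional[torch.device] = None) -> torch.Tensor:
    original_device = weight.device
    weight = dequantize_nf4(weight)          # dequantize where the weight currently lives
    if target_device is None:
        weight = weight.to(original_device)  # no target: return to the storage device

    original_dtype = weight.dtype            # remember dtype so the merge result can be cast back
    if target_device is not None:
        weight = weight.to(target_device)    # move once, before the fp32 cast and the merge

    weight = weight.to(torch.float32)        # merge in fp32 for numerical accuracy
    weight = weight + alpha * (up.to(weight) @ down.to(weight))
    return weight.to(original_dtype)         # restore the original dtype


if __name__ == "__main__":
    w = torch.randn(16, 16, dtype=torch.float16)
    up, down = torch.randn(16, 4), torch.randn(4, 16)
    print(merge_lora(w, up, down, alpha=0.5).dtype)  # torch.float16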