fix NF4 LoRA giving pure noise on some devices

Author: layerdiffusion
Date:   2024-08-15 06:35:15 -07:00
parent 3d751eb69f
commit 021428da26


@@ -256,9 +256,6 @@ class ModelPatcher:
             except:
                 raise ValueError(f"Wrong LoRA Key: {key}")
-            weight_original_device = weight.device
-            lora_computation_device = weight.device
             if key not in self.backup:
                 self.backup[key] = weight.to(device=self.offload_device)
@@ -269,6 +266,8 @@ class ModelPatcher:
                 assert weight.module is not None, 'BNB bad weight without parent layer!'
                 bnb_layer = weight.module
                 if weight.bnb_quantized:
+                    weight_original_device = weight.device
                     if target_device is not None:
                         assert target_device.type == 'cuda', 'BNB Must use CUDA!'
                         weight = weight.to(target_device)
@@ -277,14 +276,12 @@ class ModelPatcher:
                     from backend.operations_bnb import functional_dequantize_4bit
                     weight = functional_dequantize_4bit(weight)
+                    if target_device is None:
+                        weight = weight.to(device=weight_original_device)
                 else:
                     weight = weight.data
-            if target_device is None:
-                weight = weight.to(device=lora_computation_device, non_blocking=memory_management.device_supports_non_blocking(lora_computation_device))
-            else:
-                weight = weight.to(device=target_device, non_blocking=memory_management.device_supports_non_blocking(target_device))
             gguf_cls, gguf_type, gguf_real_shape = None, None, None
             if hasattr(weight, 'is_gguf'):
@@ -295,11 +292,13 @@ class ModelPatcher:
                 weight = dequantize_tensor(weight)
             weight_original_dtype = weight.dtype
-            weight = weight.to(dtype=torch.float32, non_blocking=memory_management.device_supports_non_blocking(weight.device))
-            weight = merge_lora_to_model_weight(current_patches, weight, key).to(dtype=weight_original_dtype)
-            if target_device is None:
-                weight = weight.to(device=weight_original_device, non_blocking=memory_management.device_supports_non_blocking(weight_original_device))
+            if target_device is not None:
+                weight = weight.to(device=target_device, non_blocking=memory_management.device_supports_non_blocking(target_device))
+            weight = weight.to(dtype=torch.float32)
+            weight = merge_lora_to_model_weight(current_patches, weight, key).to(dtype=weight_original_dtype)
             if bnb_layer is not None:
                 bnb_layer.reload_weight(weight)
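
For reference, below is a minimal sketch of how the per-key merge flow reads once this commit is applied. It is not the actual file: the wrapper function, its signature, and everything outside the hunks above are assumptions made for illustration; only the helper names (functional_dequantize_4bit, merge_lora_to_model_weight, memory_management.device_supports_non_blocking, bnb_layer.reload_weight) and the ordering of operations are taken from the diff.

import torch

def merge_lora_for_key(weight, current_patches, key, target_device,
                       memory_management, merge_lora_to_model_weight,
                       bnb_layer=None):
    # Hypothetical wrapper: collaborators are passed in because their real
    # modules and signatures are not shown in this diff.

    if getattr(weight, 'bnb_quantized', False):
        # Record the device the NF4 weight actually lives on *inside* the
        # bitsandbytes branch (this commit moved the assignment here).
        weight_original_device = weight.device

        if target_device is not None:
            assert target_device.type == 'cuda', 'BNB Must use CUDA!'
            weight = weight.to(target_device)
        else:
            weight = weight.cuda()  # assumption: dequantization still requires CUDA

        from backend.operations_bnb import functional_dequantize_4bit
        weight = functional_dequantize_4bit(weight)

        if target_device is None:
            # Return the dequantized tensor to where the weight originally lived.
            weight = weight.to(device=weight_original_device)

    weight_original_dtype = weight.dtype

    # Move to the target device before casting and merging; the old code
    # instead bounced the weight to a separately tracked
    # lora_computation_device with a non_blocking copy, which appears to be
    # what produced pure noise on some devices.
    if target_device is not None:
        weight = weight.to(
            device=target_device,
            non_blocking=memory_management.device_supports_non_blocking(target_device))

    # Plain (blocking) float32 cast, then merge the LoRA patches and restore
    # the original dtype.
    weight = weight.to(dtype=torch.float32)
    weight = merge_lora_to_model_weight(current_patches, weight, key).to(dtype=weight_original_dtype)

    if bnb_layer is not None:
        bnb_layer.reload_weight(weight)

    return weight

The behavioral change the diff makes is the ordering: any device transfer happens first, then a blocking float32 cast, then the LoRA merge, so the merge never reads a tensor whose copy may still be in flight.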