diff --git a/backend/patcher/base.py b/backend/patcher/base.py
index 1b40e984..dbeed0fa 100644
--- a/backend/patcher/base.py
+++ b/backend/patcher/base.py
@@ -273,15 +273,15 @@ class ModelPatcher:
         else:
             weight = weight.data
 
+        weight_original_dtype = weight.dtype
         to_args = dict(dtype=torch.float32)
 
         if device_to is not None:
             to_args['device'] = device_to
             to_args['non_blocking'] = memory_management.device_supports_non_blocking(device_to)
 
-        temp_weight = weight.to(**to_args)
-
-        out_weight = merge_lora_to_model_weight(current_patches, temp_weight, key).to(weight.dtype)
+        weight = weight.to(**to_args)
+        out_weight = merge_lora_to_model_weight(current_patches, weight, key).to(dtype=weight_original_dtype)
 
         if bnb_layer is not None:
             bnb_layer.reload_weight(out_weight)
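
For reference, a minimal standalone sketch of the dtype round-trip this patch settles on. `merge_lora_fp32` and `merge_fn` are hypothetical stand-ins for the surrounding method and `merge_lora_to_model_weight`; the real code additionally routes device placement through `memory_management.device_supports_non_blocking`:

```python
import torch

def merge_lora_fp32(weight: torch.Tensor, merge_fn, device=None) -> torch.Tensor:
    # Remember the storage dtype up front: after the .to() below,
    # weight.dtype is already float32, so it must be saved first.
    weight_original_dtype = weight.dtype

    # Upcast (and optionally move) so the LoRA merge runs in float32.
    weight = weight.to(dtype=torch.float32, device=device or weight.device)

    # merge_fn stands in for merge_lora_to_model_weight(current_patches, weight, key).
    out_weight = merge_fn(weight)

    # Cast the merged result back to the original storage dtype.
    return out_weight.to(dtype=weight_original_dtype)
```

Rebinding `weight` instead of allocating a separate `temp_weight` presumably lets the original tensor be freed while the merge runs, lowering peak memory; saving `weight_original_dtype` beforehand is what keeps the final cast correct once `weight` points at the float32 copy.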