From cb889470ba33722a89c3f625f972a795504abdc6 Mon Sep 17 00:00:00 2001
From: layerdiffusion <19834515+lllyasviel@users.noreply.github.com>
Date: Wed, 14 Aug 2024 19:52:19 -0700
Subject: [PATCH] experimental LoRA support for NF4

Model method may change later depending on result quality
---
 backend/operations_bnb.py | 19 +++++++++++++++++++
 backend/patcher/base.py   | 20 +++++++++++++++++++-
 2 files changed, 38 insertions(+), 1 deletion(-)

diff --git a/backend/operations_bnb.py b/backend/operations_bnb.py
index 0ce9db7f..7418806e 100644
--- a/backend/operations_bnb.py
+++ b/backend/operations_bnb.py
@@ -4,6 +4,7 @@ import torch
 
 import bitsandbytes as bnb
 from bitsandbytes.nn.modules import Params4bit, QuantState
+from bitsandbytes.functional import dequantize_4bit
 
 
 def functional_linear_4bits(x, weight, bias):
@@ -12,6 +13,10 @@ def functional_linear_4bits(x, weight, bias):
     return out
 
 
+def functional_dequantize_4bit(weight):
+    return dequantize_4bit(weight, quant_state=weight.quant_state, blocksize=weight.blocksize, quant_type=weight.quant_type)
+
+
 def copy_quant_state(state: QuantState, device: torch.device = None) -> QuantState:
     if state is None:
         return None
@@ -119,3 +124,17 @@ class ForgeLoader4Bit(torch.nn.Module):
             del self.dummy
         else:
             super()._load_from_state_dict(state_dict, prefix, local_metadata, strict, missing_keys, unexpected_keys, error_msgs)
+
+    def reload_weight(self, weight):
+        self.weight = ForgeParams4bit(
+            weight,
+            requires_grad=False,
+            compress_statistics=self.weight.compress_statistics,
+            blocksize=self.weight.blocksize,
+            quant_type=self.weight.quant_type,
+            quant_storage=self.weight.quant_storage,
+            module=self,
+            bnb_quantized=False
+        )
+        self.quant_state = self.weight.quant_state
+        return self
diff --git a/backend/patcher/base.py b/backend/patcher/base.py
index 57e5be99..1b40e984 100644
--- a/backend/patcher/base.py
+++ b/backend/patcher/base.py
@@ -255,9 +255,23 @@ class ModelPatcher:
             if key not in self.backup:
                 self.backup[key] = weight.to(device=self.offload_device)
 
+            bnb_layer = None
+
             if operations.bnb_avaliable:
                 if hasattr(weight, 'bnb_quantized'):
-                    raise NotImplementedError('LoRAs for NF4/FP4 models are under construction and not available now.\nSorry for the inconvenience!')
+                    assert weight.module is not None, 'BNB bad weight without parent layer!'
+                    bnb_layer = weight.module
+                    if weight.bnb_quantized:
+                        if device_to is not None:
+                            assert device_to.type == 'cuda', 'BNB Must use CUDA!'
+                            weight = weight.to(device_to)
+                        else:
+                            weight = weight.cuda()
+
+                        from backend.operations_bnb import functional_dequantize_4bit
+                        weight = functional_dequantize_4bit(weight)
+                    else:
+                        weight = weight.data
 
             to_args = dict(dtype=torch.float32)
 
@@ -269,6 +283,10 @@
 
             out_weight = merge_lora_to_model_weight(current_patches, temp_weight, key).to(weight.dtype)
 
+            if bnb_layer is not None:
+                bnb_layer.reload_weight(out_weight)
+                continue
+
             utils.set_attr_raw(self.model, key, torch.nn.Parameter(out_weight, requires_grad=False))
 
             if device_to is not None:
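--
Not part of the patch: a minimal usage sketch of how the new pieces are meant to fit together when a LoRA is merged into an NF4-quantized layer. The helper merge_lora_into_nf4_layer() and the pre-computed lora_delta tensor are hypothetical names used only for illustration; only functional_dequantize_4bit() and reload_weight() come from the patch above.

import torch

from backend.operations_bnb import functional_dequantize_4bit


def merge_lora_into_nf4_layer(layer, lora_delta):
    # layer.weight is assumed to be an already quantized ForgeParams4bit.
    # bitsandbytes can only dequantize on CUDA, hence the explicit move.
    weight = layer.weight.cuda()

    # Recover a dense tensor from the 4-bit blocks.
    dense = functional_dequantize_4bit(weight)

    # Apply the (hypothetical) LoRA delta in float32, then cast back.
    merged = (dense.to(torch.float32) + lora_delta.to(dense.device, torch.float32)).to(dense.dtype)

    # reload_weight() wraps the merged tensor in a fresh ForgeParams4bit with
    # bnb_quantized=False, so it is re-quantized the next time the layer is
    # moved to CUDA.
    layer.reload_weight(merged)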