experimental LoRA support for NF4 Model

method may change later depending on result quality
This commit is contained in:
layerdiffusion
2024-08-14 19:52:19 -07:00
parent 70a5acd8ad
commit cb889470ba
2 changed files with 38 additions and 1 deletions

View File

@@ -4,6 +4,7 @@ import torch
import bitsandbytes as bnb
from bitsandbytes.nn.modules import Params4bit, QuantState
from bitsandbytes.functional import dequantize_4bit
def functional_linear_4bits(x, weight, bias):
@@ -12,6 +13,10 @@ def functional_linear_4bits(x, weight, bias):
return out
def functional_dequantize_4bit(weight):
return dequantize_4bit(weight, quant_state=weight.quant_state, blocksize=weight.blocksize, quant_type=weight.quant_type)
def copy_quant_state(state: QuantState, device: torch.device = None) -> QuantState:
if state is None:
return None
@@ -119,3 +124,17 @@ class ForgeLoader4Bit(torch.nn.Module):
del self.dummy
else:
super()._load_from_state_dict(state_dict, prefix, local_metadata, strict, missing_keys, unexpected_keys, error_msgs)
def reload_weight(self, weight):
self.weight = ForgeParams4bit(
weight,
requires_grad=False,
compress_statistics=self.weight.compress_statistics,
blocksize=self.weight.blocksize,
quant_type=self.weight.quant_type,
quant_storage=self.weight.quant_storage,
module=self,
bnb_quantized=False
)
self.quant_state = self.weight.quant_state
return self