mirror of
https://github.com/lllyasviel/stable-diffusion-webui-forge.git
synced 2026-02-08 08:59:58 +00:00
experimental LoRA support for NF4 Model
method may change later depending on result quality
This commit is contained in:
@@ -4,6 +4,7 @@ import torch
|
||||
import bitsandbytes as bnb
|
||||
|
||||
from bitsandbytes.nn.modules import Params4bit, QuantState
|
||||
from bitsandbytes.functional import dequantize_4bit
|
||||
|
||||
|
||||
def functional_linear_4bits(x, weight, bias):
|
||||
@@ -12,6 +13,10 @@ def functional_linear_4bits(x, weight, bias):
|
||||
return out
|
||||
|
||||
|
||||
def functional_dequantize_4bit(weight):
|
||||
return dequantize_4bit(weight, quant_state=weight.quant_state, blocksize=weight.blocksize, quant_type=weight.quant_type)
|
||||
|
||||
|
||||
def copy_quant_state(state: QuantState, device: torch.device = None) -> QuantState:
|
||||
if state is None:
|
||||
return None
|
||||
@@ -119,3 +124,17 @@ class ForgeLoader4Bit(torch.nn.Module):
|
||||
del self.dummy
|
||||
else:
|
||||
super()._load_from_state_dict(state_dict, prefix, local_metadata, strict, missing_keys, unexpected_keys, error_msgs)
|
||||
|
||||
def reload_weight(self, weight):
|
||||
self.weight = ForgeParams4bit(
|
||||
weight,
|
||||
requires_grad=False,
|
||||
compress_statistics=self.weight.compress_statistics,
|
||||
blocksize=self.weight.blocksize,
|
||||
quant_type=self.weight.quant_type,
|
||||
quant_storage=self.weight.quant_storage,
|
||||
module=self,
|
||||
bnb_quantized=False
|
||||
)
|
||||
self.quant_state = self.weight.quant_state
|
||||
return self
|
||||
|
||||
@@ -255,9 +255,23 @@ class ModelPatcher:
|
||||
if key not in self.backup:
|
||||
self.backup[key] = weight.to(device=self.offload_device)
|
||||
|
||||
bnb_layer = None
|
||||
|
||||
if operations.bnb_avaliable:
|
||||
if hasattr(weight, 'bnb_quantized'):
|
||||
raise NotImplementedError('LoRAs for NF4/FP4 models are under construction and not available now.\nSorry for the inconvenience!')
|
||||
assert weight.module is not None, 'BNB bad weight without parent layer!'
|
||||
bnb_layer = weight.module
|
||||
if weight.bnb_quantized:
|
||||
if device_to is not None:
|
||||
assert device_to.type == 'cuda', 'BNB Must use CUDA!'
|
||||
weight = weight.to(device_to)
|
||||
else:
|
||||
weight = weight.cuda()
|
||||
|
||||
from backend.operations_bnb import functional_dequantize_4bit
|
||||
weight = functional_dequantize_4bit(weight)
|
||||
else:
|
||||
weight = weight.data
|
||||
|
||||
to_args = dict(dtype=torch.float32)
|
||||
|
||||
@@ -269,6 +283,10 @@ class ModelPatcher:
|
||||
|
||||
out_weight = merge_lora_to_model_weight(current_patches, temp_weight, key).to(weight.dtype)
|
||||
|
||||
if bnb_layer is not None:
|
||||
bnb_layer.reload_weight(out_weight)
|
||||
continue
|
||||
|
||||
utils.set_attr_raw(self.model, key, torch.nn.Parameter(out_weight, requires_grad=False))
|
||||
|
||||
if device_to is not None:
|
||||
|
||||
Reference in New Issue
Block a user