Mirror of https://github.com/lllyasviel/stable-diffusion-webui-forge.git

Commit: Support LoRAs for Q8/Q5/Q4 GGUF Models

what a crazy night of math
backend/operations_gguf.py

@@ -2,34 +2,27 @@ import gguf
 import torch
 
 
-quants_mapping = {
-    gguf.GGMLQuantizationType.Q4_0: gguf.Q4_0,
-    gguf.GGMLQuantizationType.Q5_0: gguf.Q5_0,
-    gguf.GGMLQuantizationType.Q8_0: gguf.Q8_0,
-}
-
+# def functional_quantize_gguf(weight):
+#     gguf_cls = weight.gguf_cls
+#     gguf_cls.en
 
 
 def functional_linear_gguf(x, weight, bias=None):
     target_dtype = x.dtype
-    weight = dequantize_tensor(weight, target_dtype)
-    bias = dequantize_tensor(bias, target_dtype)
+    weight = dequantize_tensor(weight).to(target_dtype)
+    bias = dequantize_tensor(bias).to(target_dtype)
     return torch.nn.functional.linear(x, weight, bias)
 
 
-def dequantize_tensor(tensor, target_dtype=torch.float16):
+def dequantize_tensor(tensor):
     if tensor is None:
         return None
 
     data = torch.tensor(tensor.data)
-    gguf_type = tensor.gguf_type
     gguf_cls = tensor.gguf_cls
     gguf_real_shape = tensor.gguf_real_shape
 
-    if gguf_type in [gguf.GGMLQuantizationType.F32, gguf.GGMLQuantizationType.F16, gguf.GGMLQuantizationType.BF16]:
-        return data.to(target_dtype)
+    if gguf_cls is None:
+        return data
 
-    if gguf_type not in quants_mapping:
-        raise NotImplementedError(f'Quant type {gguf_type} not implemented!')
-
-    quant_cls = quants_mapping.get(gguf_type)
-
-    return quant_cls.dequantize_pytorch(data, gguf_real_shape).to(target_dtype)
+    return gguf_cls.dequantize_pytorch(data, gguf_real_shape)
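The dtype handling moves out of `dequantize_tensor`: packed data is decoded at whatever precision `dequantize_pytorch` emits, and each caller casts once with `.to(target_dtype)`. A minimal self-contained sketch of the new call path; `DummyQuant` and `DummyTensor` are hypothetical stand-ins for the real `gguf_cls` codec and `ParameterGGUF`:

```python
import torch
import torch.nn.functional as F

class DummyQuant:
    # hypothetical stand-in for gguf.Q8_0 etc.; the real codecs decode packed bytes
    @staticmethod
    def dequantize_pytorch(data, real_shape):
        return data.reshape(real_shape)

class DummyTensor:
    # hypothetical stand-in for backend.utils.ParameterGGUF
    def __init__(self, data, real_shape, gguf_cls=DummyQuant):
        self.data = data
        self.gguf_cls = gguf_cls
        self.gguf_real_shape = real_shape

def dequantize_tensor(tensor):
    # mirrors the refactored function above
    if tensor is None:
        return None
    data = torch.tensor(tensor.data)
    if tensor.gguf_cls is None:
        return data  # F16/F32/BF16 tensors carry no quant class
    return tensor.gguf_cls.dequantize_pytorch(data, tensor.gguf_real_shape)

x = torch.randn(1, 8)
w = DummyTensor(torch.randn(4 * 8), torch.Size([4, 8]))
y = F.linear(x, dequantize_tensor(w).to(x.dtype))  # matmul runs in x.dtype
print(y.shape)  # torch.Size([1, 4])
```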
backend/patcher/base.py

@@ -4,10 +4,12 @@
 # are from Forge, implemented from scratch (after forge-v1.0.1), and may have
 # certain level of differences.
 
 import time
 import torch
 import copy
 import inspect
 
+from tqdm import tqdm
 from backend import memory_management, utils, operations
 from backend.patcher.lora import merge_lora_to_model_weight
@@ -237,6 +239,8 @@ class ModelPatcher:
         return sd
 
     def forge_patch_model(self, target_device=None):
+        execution_start_time = time.perf_counter()
+
         for k, item in self.object_patches.items():
             old = utils.get_attr(self.model, k)
@@ -245,13 +249,16 @@ class ModelPatcher:
             utils.set_attr_raw(self.model, k, item)
 
-        for key, current_patches in self.patches.items():
+        for key, current_patches in (tqdm(self.patches.items(), desc='Patching LoRAs to Diffusion Model') if len(self.patches) > 0 else self.patches):
             try:
                 weight = utils.get_attr(self.model, key)
                 assert isinstance(weight, torch.nn.Parameter)
             except:
                 raise ValueError(f"Wrong LoRA Key: {key}")
 
             weight_original_device = weight.device
+            lora_computation_device = weight.device
 
             if key not in self.backup:
                 self.backup[key] = weight.to(device=self.offload_device)
@@ -262,8 +269,6 @@
                 assert weight.module is not None, 'BNB bad weight without parent layer!'
                 bnb_layer = weight.module
                 if weight.bnb_quantized:
-                    weight_original_device = weight.device
-
                     if target_device is not None:
                         assert target_device.type == 'cuda', 'BNB Must use CUDA!'
                         weight = weight.to(target_device)
@@ -272,35 +277,56 @@
                     from backend.operations_bnb import functional_dequantize_4bit
                     weight = functional_dequantize_4bit(weight)
 
-                    if target_device is None:
-                        weight = weight.to(device=weight_original_device)
                 else:
                     weight = weight.data
 
+            if target_device is None:
+                weight = weight.to(device=lora_computation_device, non_blocking=memory_management.device_supports_non_blocking(lora_computation_device))
+            else:
+                weight = weight.to(device=target_device, non_blocking=memory_management.device_supports_non_blocking(target_device))
+
+            gguf_cls, gguf_type, gguf_real_shape = None, None, None
+
             if hasattr(weight, 'is_gguf'):
-                raise NotImplementedError('LoRAs for GGUF model are under construction!')
+                from backend.operations_gguf import dequantize_tensor
+                gguf_cls = weight.gguf_cls
+                gguf_type = weight.gguf_type
+                gguf_real_shape = weight.gguf_real_shape
+                weight = dequantize_tensor(weight)
 
             weight_original_dtype = weight.dtype
-            to_args = dict(dtype=torch.float32)
+            weight = weight.to(dtype=torch.float32, non_blocking=memory_management.device_supports_non_blocking(weight.device))
+            weight = merge_lora_to_model_weight(current_patches, weight, key).to(dtype=weight_original_dtype)
 
-            if target_device is not None:
-                to_args['device'] = target_device
-                to_args['non_blocking'] = memory_management.device_supports_non_blocking(target_device)
-
-            weight = weight.to(**to_args)
-            out_weight = merge_lora_to_model_weight(current_patches, weight, key).to(dtype=weight_original_dtype)
+            if target_device is None:
+                weight = weight.to(device=weight_original_device, non_blocking=memory_management.device_supports_non_blocking(weight_original_device))
 
             if bnb_layer is not None:
-                bnb_layer.reload_weight(out_weight)
+                bnb_layer.reload_weight(weight)
                 continue
 
-            utils.set_attr_raw(self.model, key, torch.nn.Parameter(out_weight, requires_grad=False))
+            if gguf_cls is not None:
+                from backend.utils import ParameterGGUF
+                weight = gguf_cls.quantize_pytorch(weight, gguf_real_shape)
+                utils.set_attr_raw(self.model, key, ParameterGGUF.make(
+                    data=weight,
+                    gguf_type=gguf_type,
+                    gguf_cls=gguf_cls,
+                    gguf_real_shape=gguf_real_shape
+                ))
+                continue
+
+            utils.set_attr_raw(self.model, key, torch.nn.Parameter(weight, requires_grad=False))
 
         if target_device is not None:
             self.model.to(target_device)
             self.current_device = target_device
 
         moving_time = time.perf_counter() - execution_start_time
 
         if moving_time > 0.1:
             print(f'LoRA patching has taken {moving_time:.2f} seconds')
 
         return self.model
 
     def forge_unpatch_model(self, target_device=None):
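Condensed, the new GGUF branch of `forge_patch_model` is a decode → merge → re-encode cycle: unpack the GGUF bytes to a dense tensor, apply the LoRA deltas in float32, then quantize back and rebuild the parameter with its metadata. A sketch under those assumptions; `merge_lora` is a hypothetical stand-in for `merge_lora_to_model_weight`, and `gguf_cls` is any codec from `quants_mapping`:

```python
import torch

def merge_lora(weight, up, down, alpha=1.0):
    # hypothetical stand-in for backend.patcher.lora.merge_lora_to_model_weight:
    # add the low-rank update alpha * (up @ down) onto the dense weight
    return weight + alpha * (up @ down)

def patch_gguf_weight(packed, gguf_cls, gguf_real_shape, up, down):
    # 1. decode the packed GGUF bytes into a dense tensor
    dense = gguf_cls.dequantize_pytorch(packed, gguf_real_shape)
    # 2. merge the LoRA delta in float32, then restore the working dtype
    merged = merge_lora(dense.to(torch.float32), up, down).to(dense.dtype)
    # 3. re-encode; the caller wraps the result via ParameterGGUF.make(...)
    return gguf_cls.quantize_pytorch(merged, gguf_real_shape)
```

Note the cycle is lossy: after `quantize_pytorch` the merged weights are snapped back onto the Q4/Q5/Q8 grid, so repeatedly patching the quantized bytes would accumulate error; the unpatch path instead restores the pristine tensor saved in `self.backup`.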
backend/utils.py

@@ -6,6 +6,13 @@ import safetensors.torch
 import backend.misc.checkpoint_pickle
 
 
+quants_mapping = {
+    gguf.GGMLQuantizationType.Q4_0: gguf.Q4_0,
+    gguf.GGMLQuantizationType.Q5_0: gguf.Q5_0,
+    gguf.GGMLQuantizationType.Q8_0: gguf.Q8_0,
+}
+
+
 class ParameterGGUF(torch.nn.Parameter):
     def __init__(self, tensor=None, requires_grad=False, no_init=False):
         super().__init__()
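Only the three integer quants appear in the mapping on purpose: `quants_mapping.get(...)` yields `None` for F16/F32/BF16 tensors, which is exactly the `gguf_cls is None` pass-through in `dequantize_tensor` (first hunk) that replaced the explicit F32/F16/BF16 check. A quick illustration, assuming the vendored `gguf` package is importable:

```python
import gguf

# absent types fall through to None, so dequantize_tensor returns raw data
print(quants_mapping.get(gguf.GGMLQuantizationType.F16))   # None
print(quants_mapping.get(gguf.GGMLQuantizationType.Q8_0))  # the Q8_0 codec class
```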
@@ -16,6 +23,7 @@ class ParameterGGUF(torch.nn.Parameter):
 
         self.gguf_type = tensor.tensor_type
         self.gguf_real_shape = torch.Size(reversed(list(tensor.shape)))
+        self.gguf_cls = quants_mapping.get(self.gguf_type, None)
 
     @property
     def shape(self):
@@ -28,6 +36,15 @@ class ParameterGGUF(torch.nn.Parameter):
         new = ParameterGGUF(self.data.to(*args, **kwargs), no_init=True)
         new.gguf_type = self.gguf_type
         new.gguf_real_shape = self.gguf_real_shape
+        new.gguf_cls = self.gguf_cls
         return new
 
+    @classmethod
+    def make(cls, data, gguf_type, gguf_cls, gguf_real_shape):
+        new = ParameterGGUF(data, no_init=True)
+        new.gguf_type = gguf_type
+        new.gguf_real_shape = gguf_real_shape
+        new.gguf_cls = gguf_cls
+        return new
+
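`make` exists because `quantize_pytorch` returns a plain byte tensor; rebuilding the parameter by hand would drop the GGUF metadata that `dequantize_tensor` needs later. Both `make` and the patched `to` now carry `gguf_cls` along, so device moves keep the parameter decodable. A hedged usage sketch (the `packed` bytes and shapes below are hypothetical; 1280 floats per row pack to 40 Q8_0 blocks of 34 bytes):

```python
import torch
import gguf
from backend.utils import ParameterGGUF

packed = torch.zeros(320, 1360, dtype=torch.uint8)  # hypothetical Q8_0 bytes

p = ParameterGGUF.make(
    data=packed,
    gguf_type=gguf.GGMLQuantizationType.Q8_0,
    gguf_cls=gguf.Q8_0,
    gguf_real_shape=torch.Size([320, 1280]),
)
p2 = p.to('cpu')  # .to() clones the metadata onto the new parameter
assert p2.gguf_cls is gguf.Q8_0 and p2.gguf_real_shape == p.gguf_real_shape
```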
modules/ui_settings.py

@@ -324,7 +324,7 @@ class UiSettings:
         )
 
         def button_set_checkpoint_change(value, dummy):
-            return value, opts.dumpjson()
+            return value.split(' [')[0], opts.dumpjson()
 
         button_set_checkpoint = gr.Button('Change checkpoint', elem_id='change_checkpoint', visible=False)
         button_set_checkpoint.click(
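The UI fix strips the bracketed short-hash that the checkpoint dropdown appends to titles, so only the filename reaches the settings handler. With a hypothetical dropdown value:

```python
value = 'v1-5-pruned-emaonly.safetensors [6ce0161689]'
print(value.split(' [')[0])  # -> v1-5-pruned-emaonly.safetensors
```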
packages_3rdparty/gguf/quants.py (vendored, 82 changed lines)
@@ -125,8 +125,17 @@ class __Quant(ABC):
         cls.grid = grid.reshape((1, 1, *cls.grid_shape))
 
     @classmethod
-    def quantize_pytorch(cls, data: torch.Tensor) -> torch.Tensor:
-        return cls.quantize_blocks_pytorch(data)
+    def quantize_pytorch(cls, data: torch.Tensor, original_shape) -> torch.Tensor:
+        # Copyright Forge 2024, AGPL V3 + CC-BY SA
+
+        original_shape = [x for x in original_shape]
+        original_shape[-1] = -1
+        original_shape = tuple(original_shape)
+
+        block_size, type_size = GGML_QUANT_SIZES[cls.qtype]
+        blocks = data.reshape(-1, block_size)
+        blocks = cls.quantize_blocks_pytorch(blocks, block_size, type_size)
+        return blocks.reshape(original_shape)
 
     @classmethod
     def dequantize_pytorch(cls, data: torch.Tensor, original_shape) -> torch.Tensor:
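`original_shape[-1] = -1` works because packing touches only the last axis: every `block_size` floats become `type_size` bytes, so the leading dimensions survive and the last one is left for `reshape` to infer. Worked numbers for Q8_0 (`block_size=32`, `type_size=34`: 32 int8 codes plus a float16 scale), shapes only, with a dummy packer standing in for `quantize_blocks_pytorch`:

```python
import torch

block_size, type_size = 32, 34           # Q8_0 entry in GGML_QUANT_SIZES
weight = torch.randn(320, 64)

blocks = weight.reshape(-1, block_size)  # (640, 32): 640 blocks of 32 floats
packed = torch.zeros(blocks.shape[0], type_size, dtype=torch.uint8)  # dummy packer

shape = list(weight.shape)
shape[-1] = -1                           # 64 floats/row -> 68 bytes/row
print(packed.reshape(tuple(shape)).shape)  # torch.Size([320, 68])
```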
@@ -145,7 +154,7 @@
 
     @classmethod
     @abstractmethod
-    def quantize_blocks_pytorch(cls, blocks) -> torch.Tensor:
+    def quantize_blocks_pytorch(cls, blocks, block_size, type_size) -> torch.Tensor:
        raise NotImplementedError
 
     @classmethod
@@ -287,6 +296,27 @@ class Q4_0(__Quant, qtype=GGMLQuantizationType.Q4_0):
         qs = (qs & 0x0F).reshape((n_blocks, -1)).to(torch.int8) - 8
         return d * qs
 
+    @classmethod
+    def quantize_blocks_pytorch(cls, blocks, block_size, type_size) -> torch.Tensor:
+        # Copyright Forge 2024, AGPL V3 + CC-BY SA
+
+        n_blocks = blocks.shape[0]
+
+        imax = torch.abs(blocks).argmax(dim=-1, keepdim=True)
+        max_vals = torch.gather(blocks, -1, imax)
+
+        d = max_vals / -8
+        id = torch.where(d == 0, torch.tensor(0.0, device=d.device), 1.0 / d)
+
+        qs = torch.trunc((blocks * id) + 8.5).clip(0, 15).to(torch.uint8)
+
+        qs = qs.reshape((n_blocks, 2, block_size // 2))
+        qs = qs[:, 0, :] | (qs[:, 1, :] << 4)
+
+        d = d.to(torch.float16).view(torch.uint8)
+
+        return torch.cat([d, qs], dim=-1)
+
 
 class Q4_1(__Quant, qtype=GGMLQuantizationType.Q4_1):
     @classmethod
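Q4_0 stores one float16 scale `d` plus 16 bytes of nibbles per 32-value block: `trunc(x / d + 8.5)` maps each value onto 16 levels so that `code - 8` reproduces it, and the two halves of the block share bytes as low/high nibbles. A round-trip check of the packing against a reference decode (the dequantizer below follows the Q4_0 layout, not this file's exact code):

```python
import torch

def q4_0_quantize(blocks):                       # mirrors the hunk above
    n = blocks.shape[0]
    imax = torch.abs(blocks).argmax(dim=-1, keepdim=True)
    d = torch.gather(blocks, -1, imax) / -8
    inv = torch.where(d == 0, torch.zeros_like(d), 1.0 / d)
    q = torch.trunc(blocks * inv + 8.5).clip(0, 15).to(torch.uint8)
    q = q.reshape(n, 2, 16)
    return d.to(torch.float16), q[:, 0, :] | (q[:, 1, :] << 4)

def q4_0_dequantize(d, qs):
    # low nibbles are elements 0..15 of the block, high nibbles 16..31
    q = torch.cat([qs & 0x0F, qs >> 4], dim=-1).to(torch.int8) - 8
    return d.to(torch.float32) * q

blocks = torch.randn(8, 32)
d, qs = q4_0_quantize(blocks)
err = (q4_0_dequantize(d, qs) - blocks).abs().max()
print(bool(err <= 1.01 * d.abs().max()))  # error stays within ~one step |d|
```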
@@ -392,6 +422,42 @@ class Q5_0(__Quant, qtype=GGMLQuantizationType.Q5_0):
         qs = (ql | (qh << 4)).to(torch.int8) - 16
         return d * qs
 
+    @classmethod
+    def quantize_blocks_pytorch(cls, blocks, block_size, type_size) -> torch.Tensor:
+        # Copyright Forge 2024, AGPL V3 + CC-BY SA
+
+        n_blocks = blocks.shape[0]
+
+        imax = torch.abs(blocks).argmax(dim=-1, keepdim=True)
+        max_val = torch.gather(blocks, dim=-1, index=imax)
+
+        d = max_val / -16
+        id = torch.where(d == 0, torch.tensor(0.0, device=d.device), 1.0 / d)
+
+        q = torch.trunc((blocks.float() * id.float()) + 16.5).clamp(0, 31).to(torch.uint8)
+
+        qs = q.view(n_blocks, 2, block_size // 2)
+        qs = (qs[..., 0, :] & 0x0F) | (qs[..., 1, :] << 4)
+
+        qh = q.view(n_blocks, 32)
+        qh_packed = torch.zeros((n_blocks, 4), dtype=torch.uint8, device=qh.device)
+
+        for i in range(4):
+            qh_packed[:, i] = (
+                (qh[:, i * 8 + 0] >> 4) |
+                (qh[:, i * 8 + 1] >> 3 & 0x02) |
+                (qh[:, i * 8 + 2] >> 2 & 0x04) |
+                (qh[:, i * 8 + 3] >> 1 & 0x08) |
+                (qh[:, i * 8 + 4] << 0 & 0x10) |
+                (qh[:, i * 8 + 5] << 1 & 0x20) |
+                (qh[:, i * 8 + 6] << 2 & 0x40) |
+                (qh[:, i * 8 + 7] << 3 & 0x80)
+            )
+
+        d = d.to(torch.float16).view(torch.uint8)
+
+        return torch.cat([d, qh_packed, qs], dim=-1)
+
 
 class Q5_1(__Quant, qtype=GGMLQuantizationType.Q5_1):
     @classmethod
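Q5_0 splits each 5-bit code into a low nibble (two per byte in `qs`, as in Q4_0) and a high bit; the loop packs the high bits of codes `8*i .. 8*i+7` into byte `i`, so bit `j` of the little-endian 32-bit `qh` field is the high bit of code `j`, matching the `(ql | (qh << 4))` decode above. A small check of the bit placement with 32 known codes:

```python
import torch

# codes 0..31: exactly the upper half (16..31) has the high bit set
q = torch.arange(32, dtype=torch.uint8).view(1, 32)

qh_packed = torch.zeros((1, 4), dtype=torch.uint8)
for i in range(4):  # same packing as the hunk above
    qh_packed[:, i] = (
        (q[:, i * 8 + 0] >> 4) |
        (q[:, i * 8 + 1] >> 3 & 0x02) |
        (q[:, i * 8 + 2] >> 2 & 0x04) |
        (q[:, i * 8 + 3] >> 1 & 0x08) |
        (q[:, i * 8 + 4] << 0 & 0x10) |
        (q[:, i * 8 + 5] << 1 & 0x20) |
        (q[:, i * 8 + 6] << 2 & 0x40) |
        (q[:, i * 8 + 7] << 3 & 0x80)
    )

print([f'{int(b):08b}' for b in qh_packed[0]])
# ['00000000', '00000000', '11111111', '11111111']
```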
@@ -469,6 +535,16 @@ class Q8_0(__Quant, qtype=GGMLQuantizationType.Q8_0):
         x = blocks[:, 2:].view(torch.int8).to(torch.float16)
         return x * d
 
+    @classmethod
+    def quantize_blocks_pytorch(cls, blocks, block_size, type_size) -> torch.Tensor:
+        # Copyright Forge 2024, AGPL V3 + CC-BY SA
+        d = torch.abs(blocks).max(dim=1, keepdim=True).values / 127
+        ids = torch.where(d == 0, torch.zeros_like(d), 1 / d)
+        qs = torch.round(blocks * ids)
+        d = d.to(torch.float16).view(torch.uint8)
+        qs = qs.to(torch.int8).view(torch.uint8)
+        return torch.cat([d, qs], dim=1)
+
 
 class Q2_K(__Quant, qtype=GGMLQuantizationType.Q2_K):
     @classmethod
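Q8_0 is the simplest layout: per 32-value block, a float16 scale `d` followed by 32 int8 codes, decoded as `code * d`, so a round trip costs at most `|d| / 2` per element. Pairing the new quantizer with the decode shown in context (float32 here just for the check):

```python
import torch

def q8_0_quantize(blocks):                 # mirrors the hunk above
    d = torch.abs(blocks).max(dim=1, keepdim=True).values / 127
    ids = torch.where(d == 0, torch.zeros_like(d), 1 / d)
    qs = torch.round(blocks * ids)
    return torch.cat([d.to(torch.float16).view(torch.uint8),
                      qs.to(torch.int8).view(torch.uint8)], dim=1)

def q8_0_dequantize(packed):
    d = packed[:, :2].view(torch.float16).to(torch.float32)
    x = packed[:, 2:].view(torch.int8).to(torch.float32)
    return x * d

blocks = torch.randn(16, 32)
err = (q8_0_dequantize(q8_0_quantize(blocks)) - blocks).abs().max()
print(float(err))  # bounded by max|d| / 2, i.e. max|w| / 254 per block
```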