mirror of https://github.com/lllyasviel/stable-diffusion-webui-forge.git
fix offline quant lora precision
@@ -377,9 +377,7 @@ class LoraLoader:
                 continue
 
             if gguf_cls is not None:
-                gguf_parameter.data = gguf_cls.quantize_pytorch(weight, gguf_parameter.shape)
-                gguf_parameter.baked = False
-                gguf_cls.bake(gguf_parameter)
+                gguf_cls.quantize_pytorch(weight, gguf_parameter)
                 continue
 
             utils.set_attr_raw(self.model, key, torch.nn.Parameter(weight, requires_grad=False))
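In the loader hunk above, the merged weight is no longer quantized into a fresh tensor that is then assigned and re-baked; the live GGUF parameter is passed in so the quantizer can check parent.baked and write the packed result back into parent.data itself. A minimal, self-contained sketch of that call shape (ToyGGUFParameter and ToyQuant are illustrative stand-ins, not the Forge classes, and the actual block packing is elided):

import torch

class ToyGGUFParameter(torch.nn.Parameter):
    # Stand-in for a baked GGUF parameter: quantization metadata travels with the tensor.
    def __new__(cls, data):
        self = super().__new__(cls, data, requires_grad=False)
        self.baked = True                         # layout already prepared for the PyTorch kernels
        self.computation_dtype = torch.bfloat16   # dtype the dequantizer will compute in (assumed)
        return self

class ToyQuant:
    @classmethod
    def quantize_pytorch(cls, data, parent):
        # New-style API: validate the parameter, then store the quantized result on it.
        if not parent.baked:
            raise ValueError('GGUF Tensor is not baked!')
        parent.data = data.to(parent.computation_dtype).contiguous()  # real code packs blocks here
        return parent

merged = torch.randn(4, 32)                       # e.g. base weight with the LoRA delta merged in
param = ToyGGUFParameter(torch.empty(4, 32))
ToyQuant.quantize_pytorch(merged, param)          # mirrors: gguf_cls.quantize_pytorch(weight, gguf_parameter)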
packages_3rdparty/gguf/quants.py (vendored)
@@ -129,17 +129,14 @@ class __Quant(ABC):
         cls.grid = grid.reshape((1, 1, *cls.grid_shape))
 
     @classmethod
-    def quantize_pytorch(cls, data: torch.Tensor, original_shape) -> torch.Tensor:
-        # Copyright Forge 2024, AGPL V3 + CC-BY SA
-
-        original_shape = [x for x in original_shape]
-        original_shape[-1] = -1
-        original_shape = tuple(original_shape)
+    def quantize_pytorch(cls, data, parent) -> torch.Tensor:
+        if not parent.baked:
+            raise ValueError('GGUF Tensor is not baked!')
 
         block_size, type_size = GGML_QUANT_SIZES[cls.qtype]
         blocks = data.reshape(-1, block_size)
-        blocks = cls.quantize_blocks_pytorch(blocks, block_size, type_size)
-        return blocks.reshape(original_shape)
+        parent.data = cls.quantize_blocks_pytorch(blocks, block_size, type_size, parent).contiguous()
+        return parent
 
     @classmethod
     def bake(cls, parameter):
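A note on the shape bookkeeping changed above: quantized GGUF data is stored as n_blocks rows of type_size bytes, where n_blocks = weight.numel() // block_size. The old path reshaped that byte tensor back toward the weight's original shape and returned it for the caller to assign; the new path keeps the flat block layout and writes it straight into parent.data. A rough sketch of the bookkeeping only, using Q8_0's sizes (32 values per block, 34 bytes per block) and a dummy packer standing in for quantize_blocks_pytorch:

import torch

block_size, type_size = 32, 34                  # Q8_0: 32 weights -> 2-byte scale + 32 int8 codes

def quantize_like_forge(weight: torch.Tensor, parent: torch.nn.Parameter) -> torch.nn.Parameter:
    blocks = weight.reshape(-1, block_size)     # (n_blocks, block_size) values
    packed = torch.zeros(blocks.shape[0], type_size, dtype=torch.uint8)  # dummy packer
    parent.data = packed.contiguous()           # new behaviour: flat block layout lives on the parameter
    return parent                               # old behaviour returned packed.reshape(out_features, -1) instead

w = torch.randn(8, 64)                          # 512 values -> 16 blocks -> (16, 34) bytes
p = torch.nn.Parameter(torch.empty(0), requires_grad=False)
print(quantize_like_forge(w, p).shape)          # torch.Size([16, 34])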
@@ -175,7 +172,7 @@ class __Quant(ABC):
 
     @classmethod
     @abstractmethod
-    def quantize_blocks_pytorch(cls, blocks, block_size, type_size) -> torch.Tensor:
+    def quantize_blocks_pytorch(cls, blocks, block_size, type_size, parent) -> torch.Tensor:
         raise NotImplementedError('Low bit LoRA for this data type is not implemented yet. Please select "Automatic (fp16 LoRA)" in "Diffusion in Low Bits" (on the top line of this page) to use this LoRA.')
 
     @classmethod
@@ -323,7 +320,7 @@ class Q4_0(__Quant, qtype=GGMLQuantizationType.Q4_0):
         return d * qs
 
     @classmethod
-    def quantize_blocks_pytorch(cls, blocks, block_size, type_size) -> torch.Tensor:
+    def quantize_blocks_pytorch(cls, blocks, block_size, type_size, parent) -> torch.Tensor:
         # Copyright Forge 2024, AGPL V3 + CC-BY SA
 
         n_blocks = blocks.shape[0]
@@ -336,10 +333,10 @@ class Q4_0(__Quant, qtype=GGMLQuantizationType.Q4_0):
 
         qs = torch.trunc((blocks * id) + 8.5).clip(0, 15).to(torch.uint8)
 
-        qs = qs.reshape((n_blocks, 2, block_size // 2))
-        qs = qs[:, 0, :] | (qs[:, 1, :] << 4)
+        qs = qs.reshape((n_blocks, block_size // 2, 2))
+        qs = qs[..., 0] | (qs[..., 1] << 4)
 
-        d = d.to(torch.float16).view(torch.uint8)
+        d = d.to(parent.computation_dtype).view(torch.uint8)
 
         return torch.cat([d, qs], dim=-1)
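The Q4_0 hunk above also changes how two 4-bit codes are paired into one byte: the old reshape paired value j with value j + 16 within a block, while the new one pairs adjacent values 2j and 2j + 1. A single-block worked example (indices stand in for the real quantized codes; whether the dequantizer expects the old or the new order is not shown in this hunk):

import torch

block_size = 32
v = torch.arange(block_size, dtype=torch.uint8) % 16   # stand-in 4-bit codes

# old packing: low nibble from the first half of the block, high nibble from the second half
old = v.reshape(2, block_size // 2)
old_bytes = old[0, :] | (old[1, :] << 4)               # byte j = v[j] | (v[j + 16] << 4)

# new packing: adjacent values share a byte
new = v.reshape(block_size // 2, 2)
new_bytes = new[..., 0] | (new[..., 1] << 4)           # byte j = v[2j] | (v[2j + 1] << 4)

print(old_bytes[0].item(), new_bytes[0].item())        # 0 | (0 << 4) = 0  vs  0 | (1 << 4) = 16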
@@ -467,7 +464,7 @@ class Q5_0(__Quant, qtype=GGMLQuantizationType.Q5_0):
         return (d * qs)
 
     @classmethod
-    def quantize_blocks_pytorch(cls, blocks, block_size, type_size) -> torch.Tensor:
+    def quantize_blocks_pytorch(cls, blocks, block_size, type_size, parent) -> torch.Tensor:
         # Copyright Forge 2024, AGPL V3 + CC-BY SA
 
         n_blocks = blocks.shape[0]
@@ -611,13 +608,13 @@ class Q8_0(__Quant, qtype=GGMLQuantizationType.Q8_0):
         return x * d
 
     @classmethod
-    def quantize_blocks_pytorch(cls, blocks, block_size, type_size) -> torch.Tensor:
+    def quantize_blocks_pytorch(cls, blocks, block_size, type_size, parent) -> torch.Tensor:
         # Copyright Forge 2024, AGPL V3 + CC-BY SA
 
         d = torch.abs(blocks).max(dim=1, keepdim=True).values / 127
         ids = torch.where(d == 0, torch.zeros_like(d), 1 / d)
         qs = torch.round(blocks * ids)
 
-        d = d.to(torch.float16).view(torch.uint8)
-        qs = qs.to(torch.int8).view(torch.uint8)
+        d = d.to(parent.computation_dtype).view(torch.int8)
+        qs = qs.to(torch.int8)
         return torch.cat([d, qs], dim=1)
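The Q8_0 hunk is where the precision fix in the commit title is most visible: the per-block scale d is now serialized in parent.computation_dtype instead of hard-coded float16, and qs is no longer reinterpreted as uint8. The snippet below only demonstrates why the stored byte pattern has to match the dtype it will later be viewed as; the choice of bfloat16 is an assumption for illustration:

import torch

d = torch.tensor([0.0123])                              # a per-block scale
as_fp16 = d.to(torch.float16).view(torch.uint8)         # bytes of the float16 encoding
as_bf16 = d.to(torch.bfloat16).view(torch.uint8)        # bytes of the bfloat16 encoding

# Same value, different byte patterns: if the dequantizer reinterprets the stored
# bytes in the model's computation dtype, the quantizer must write them in that
# dtype too, otherwise the recovered scale is wrong.
print(as_fp16.tolist(), as_bf16.tolist())
print(as_fp16.view(torch.bfloat16).item())              # nonsense scale from mismatched dtypes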