fix offline quant lora precision

layerdiffusion
2024-08-31 13:12:23 -07:00
parent 79b25a8235
commit a8a81d3d77
2 changed files with 15 additions and 20 deletions


@@ -377,9 +377,7 @@ class LoraLoader:
             continue
         if gguf_cls is not None:
-            gguf_parameter.data = gguf_cls.quantize_pytorch(weight, gguf_parameter.shape)
-            gguf_parameter.baked = False
-            gguf_cls.bake(gguf_parameter)
+            gguf_cls.quantize_pytorch(weight, gguf_parameter)
             continue
         utils.set_attr_raw(self.model, key, torch.nn.Parameter(weight, requires_grad=False))

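Note: the call site now passes the GGUF parameter object itself rather than only its shape, so the quantizer can check the baked layout, read the parameter's compute dtype, and write the re-quantized blocks back in place. A simplified sketch of the new contract, assuming `weight` already holds the de-quantized base weight with the LoRA delta merged in (names taken from the diff; the surrounding LoraLoader loop is omitted):

    # Simplified sketch, not the full LoraLoader code path.
    if gguf_cls is not None:
        # quantize_pytorch now expects an already-baked parameter; it re-quantizes
        # `weight` block by block and stores the packed bytes into gguf_parameter.data.
        gguf_cls.quantize_pytorch(weight, gguf_parameter)

Passing the parameter instead of its shape is what lets the per-type quantizers below read parent.computation_dtype.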

@@ -129,17 +129,14 @@ class __Quant(ABC):
         cls.grid = grid.reshape((1, 1, *cls.grid_shape))
 
     @classmethod
-    def quantize_pytorch(cls, data: torch.Tensor, original_shape) -> torch.Tensor:
-        # Copyright Forge 2024, AGPL V3 + CC-BY SA
-        original_shape = [x for x in original_shape]
-        original_shape[-1] = -1
-        original_shape = tuple(original_shape)
+    def quantize_pytorch(cls, data, parent) -> torch.Tensor:
+        if not parent.baked:
+            raise ValueError('GGUF Tensor is not baked!')
         block_size, type_size = GGML_QUANT_SIZES[cls.qtype]
         blocks = data.reshape(-1, block_size)
-        blocks = cls.quantize_blocks_pytorch(blocks, block_size, type_size)
-        return blocks.reshape(original_shape)
+        parent.data = cls.quantize_blocks_pytorch(blocks, block_size, type_size, parent).contiguous()
+        return parent
 
     @classmethod
     def bake(cls, parameter):
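For reference, GGML_QUANT_SIZES maps each GGUF quantization type to a (block_size, type_size) pair: every group of block_size weights is stored as type_size bytes. In the reference GGUF layout these are Q4_0: (32, 18), Q5_0: (32, 22) and Q8_0: (32, 34). A rough sketch of what the reshape above does, using a hypothetical 3072x3072 weight and Q8_0 (the shape and numbers are illustrative only):

    import torch

    block_size, type_size = 32, 34         # GGML_QUANT_SIZES[Q8_0]: 2-byte scale + 32 int8 values
    data = torch.randn(3072, 3072)         # merged full-precision weight (hypothetical shape)
    blocks = data.reshape(-1, block_size)  # -> (294912, 32); each row is quantized independently
    # quantize_blocks_pytorch maps every row to `type_size` bytes; the resulting
    # byte tensor is what gets written back into parent.data above.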
@@ -175,7 +172,7 @@ class __Quant(ABC):
     @classmethod
     @abstractmethod
-    def quantize_blocks_pytorch(cls, blocks, block_size, type_size) -> torch.Tensor:
+    def quantize_blocks_pytorch(cls, blocks, block_size, type_size, parent) -> torch.Tensor:
         raise NotImplementedError('Low bit LoRA for this data type is not implemented yet. Please select "Automatic (fp16 LoRA)" in "Diffusion in Low Bits" (on the top line of this page) to use this LoRA.')
 
     @classmethod
@@ -323,7 +320,7 @@ class Q4_0(__Quant, qtype=GGMLQuantizationType.Q4_0):
         return d * qs
 
     @classmethod
-    def quantize_blocks_pytorch(cls, blocks, block_size, type_size) -> torch.Tensor:
+    def quantize_blocks_pytorch(cls, blocks, block_size, type_size, parent) -> torch.Tensor:
         # Copyright Forge 2024, AGPL V3 + CC-BY SA
 
         n_blocks = blocks.shape[0]
@@ -336,10 +333,10 @@ class Q4_0(__Quant, qtype=GGMLQuantizationType.Q4_0):
         qs = torch.trunc((blocks * id) + 8.5).clip(0, 15).to(torch.uint8)
-        qs = qs.reshape((n_blocks, 2, block_size // 2))
-        qs = qs[:, 0, :] | (qs[:, 1, :] << 4)
+        qs = qs.reshape((n_blocks, block_size // 2, 2))
+        qs = qs[..., 0] | (qs[..., 1] << 4)
 
-        d = d.to(torch.float16).view(torch.uint8)
+        d = d.to(parent.computation_dtype).view(torch.uint8)
 
         return torch.cat([d, qs], dim=-1)
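The packing change is the substantive part of this hunk: the old code paired nibble i with nibble i + block_size // 2 (the order the reference GGUF tools use), while the new code packs adjacent nibbles 2i and 2i + 1, presumably to match the layout the baked de-quantizer reads. A tiny worked example of the two orders with made-up values:

    import torch

    qs = torch.tensor([[1, 2, 3, 4]], dtype=torch.uint8)  # one toy block of 4 already-clipped nibbles
    n_blocks, block_size = qs.shape

    new = qs.reshape((n_blocks, block_size // 2, 2))
    new = new[..., 0] | (new[..., 1] << 4)                 # adjacent pairs -> [0x21, 0x43]

    old = qs.reshape((n_blocks, 2, block_size // 2))
    old = old[:, 0, :] | (old[:, 1, :] << 4)               # first-half/second-half pairs -> [0x31, 0x42]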
@@ -467,7 +464,7 @@ class Q5_0(__Quant, qtype=GGMLQuantizationType.Q5_0):
         return (d * qs)
 
     @classmethod
-    def quantize_blocks_pytorch(cls, blocks, block_size, type_size) -> torch.Tensor:
+    def quantize_blocks_pytorch(cls, blocks, block_size, type_size, parent) -> torch.Tensor:
         # Copyright Forge 2024, AGPL V3 + CC-BY SA
 
         n_blocks = blocks.shape[0]
@@ -611,13 +608,13 @@ class Q8_0(__Quant, qtype=GGMLQuantizationType.Q8_0):
         return x * d
 
     @classmethod
-    def quantize_blocks_pytorch(cls, blocks, block_size, type_size) -> torch.Tensor:
+    def quantize_blocks_pytorch(cls, blocks, block_size, type_size, parent) -> torch.Tensor:
         # Copyright Forge 2024, AGPL V3 + CC-BY SA
 
         d = torch.abs(blocks).max(dim=1, keepdim=True).values / 127
         ids = torch.where(d == 0, torch.zeros_like(d), 1 / d)
         qs = torch.round(blocks * ids)
 
-        d = d.to(torch.float16).view(torch.uint8)
-        qs = qs.to(torch.int8).view(torch.uint8)
+        d = d.to(parent.computation_dtype).view(torch.int8)
+        qs = qs.to(torch.int8)
 
         return torch.cat([d, qs], dim=1)
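As a sanity check of the Q8_0 math above: each block keeps one scale d = absmax / 127 plus the rounded signed 8-bit values, and de-quantization is simply qs * d. Keeping d in parent.computation_dtype instead of hard-coding float16 (and viewing both halves as int8 before the concat) appears to be the precision fix the commit title refers to. A small round-trip with made-up numbers:

    import torch

    block = torch.tensor([[0.5, -1.0, 0.25, 0.0]])  # toy block (real Q8_0 blocks hold 32 values)
    d = torch.abs(block).max(dim=1, keepdim=True).values / 127  # 1/127, roughly 0.00787
    ids = torch.where(d == 0, torch.zeros_like(d), 1 / d)
    qs = torch.round(block * ids)                   # [64., -127., 32., 0.]
    restored = qs * d                               # roughly [0.5039, -1.0, 0.2520, 0.0]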