mirror of
https://github.com/lllyasviel/stable-diffusion-webui-forge.git
synced 2026-03-11 07:59:49 +00:00
Speed up quant model loading and inference ...
... based on 3 evidences: 1. torch.Tensor.view on one big tensor is slightly faster than calling torch.Tensor.to on multiple small tensors. 2. but torch.Tensor.to with dtype change is significantly slower than torch.Tensor.view 3. “baking” model on GPU is significantly faster than computing on CPU when model load. mainly influence inference of Q8_0, Q4_0/1/K and loading of all quants
This commit is contained in:
@@ -104,11 +104,6 @@ def load_huggingface_component(guess, component_name, lib_name, cls_name, repo_p
|
||||
|
||||
load_state_dict(model, state_dict, log_name=cls_name, ignore_errors=['transformer.encoder.embed_tokens.weight', 'logit_scale'])
|
||||
|
||||
if storage_dtype in ['gguf']:
|
||||
from backend.operations_gguf import bake_gguf_model
|
||||
model.computation_dtype = torch.float16
|
||||
model = bake_gguf_model(model)
|
||||
|
||||
return model
|
||||
if cls_name in ['UNet2DConditionModel', 'FluxTransformer2DModel']:
|
||||
assert isinstance(state_dict, dict) and len(state_dict) > 16, 'You do not have model state dict!'
|
||||
@@ -167,10 +162,6 @@ def load_huggingface_component(guess, component_name, lib_name, cls_name, repo_p
|
||||
model.initial_device = initial_device
|
||||
model.offload_device = offload_device
|
||||
|
||||
if storage_dtype in ['gguf']:
|
||||
from backend.operations_gguf import bake_gguf_model
|
||||
model = bake_gguf_model(model)
|
||||
|
||||
return model
|
||||
|
||||
print(f'Skipped: {component_name} = {lib_name}.{cls_name}')
|
||||
|
||||
@@ -310,7 +310,7 @@ def state_dict_parameters(sd):
|
||||
|
||||
def state_dict_dtype(state_dict):
|
||||
for k, v in state_dict.items():
|
||||
if hasattr(v, 'is_gguf'):
|
||||
if hasattr(v, 'gguf_cls'):
|
||||
return 'gguf'
|
||||
if 'bitsandbytes__nf4' in k:
|
||||
return 'nf4'
|
||||
@@ -337,6 +337,19 @@ def state_dict_dtype(state_dict):
|
||||
return major_dtype
|
||||
|
||||
|
||||
def bake_gguf_model(model):
|
||||
if getattr(model, 'gguf_baked', False):
|
||||
return
|
||||
|
||||
for p in model.parameters():
|
||||
gguf_cls = getattr(p, 'gguf_cls', None)
|
||||
if gguf_cls is not None:
|
||||
gguf_cls.bake(p)
|
||||
|
||||
model.gguf_baked = True
|
||||
return model
|
||||
|
||||
|
||||
def module_size(module, exclude_device=None, return_split=False):
|
||||
module_mem = 0
|
||||
weight_mem = 0
|
||||
@@ -493,6 +506,8 @@ class LoadedModel:
|
||||
global signal_empty_cache
|
||||
signal_empty_cache = True
|
||||
|
||||
bake_gguf_model(self.real_model)
|
||||
|
||||
self.model.lora_loader.refresh(offload_device=self.model.offload_device)
|
||||
|
||||
if is_intel_xpu() and not args.disable_ipex_hijack:
|
||||
@@ -642,7 +657,7 @@ def load_models_gpu(models, memory_required=0):
|
||||
inference_memory = minimum_inference_memory()
|
||||
estimated_remaining_memory = current_free_mem - model_memory - inference_memory
|
||||
|
||||
print(f"[Memory Management] Target: {loaded_model.model.__class__.__name__}, Free GPU: {current_free_mem / (1024 * 1024):.2f} MB, Model Require: {model_memory / (1024 * 1024):.2f} MB, Inference Require: {inference_memory / (1024 * 1024):.2f} MB, Remaining: {estimated_remaining_memory / (1024 * 1024):.2f} MB, ", end="")
|
||||
print(f"[Memory Management] Target: {loaded_model.model.model.__class__.__name__}, Free GPU: {current_free_mem / (1024 * 1024):.2f} MB, Model Require: {model_memory / (1024 * 1024):.2f} MB, Inference Require: {inference_memory / (1024 * 1024):.2f} MB, Remaining: {estimated_remaining_memory / (1024 * 1024):.2f} MB, ", end="")
|
||||
|
||||
if estimated_remaining_memory < 0:
|
||||
vram_set_state = VRAMState.LOW_VRAM
|
||||
|
||||
@@ -395,20 +395,22 @@ class ForgeOperationsGGUF(ForgeOperations):
|
||||
|
||||
def _load_from_state_dict(self, state_dict, prefix, local_metadata, strict, missing_keys, unexpected_keys, error_msgs):
|
||||
if hasattr(self, 'dummy'):
|
||||
computation_dtype = self.dummy.dtype
|
||||
if computation_dtype not in [torch.float16, torch.bfloat16]:
|
||||
# GGUF cast only supports 16bits otherwise super slow
|
||||
computation_dtype = torch.float16
|
||||
if prefix + 'weight' in state_dict:
|
||||
self.weight = state_dict[prefix + 'weight'].to(device=self.dummy.device)
|
||||
self.weight.computation_dtype = computation_dtype
|
||||
if prefix + 'bias' in state_dict:
|
||||
self.bias = state_dict[prefix + 'bias'].to(device=self.dummy.device)
|
||||
self.bias.computation_dtype = computation_dtype
|
||||
del self.dummy
|
||||
else:
|
||||
if prefix + 'weight' in state_dict:
|
||||
self.weight = state_dict[prefix + 'weight']
|
||||
if prefix + 'bias' in state_dict:
|
||||
self.bias = state_dict[prefix + 'bias']
|
||||
if self.weight is not None and hasattr(self.weight, 'parent'):
|
||||
self.weight.parent = self
|
||||
if self.bias is not None and hasattr(self.bias, 'parent'):
|
||||
self.bias.parent = self
|
||||
return
|
||||
|
||||
def _apply(self, fn, recurse=True):
|
||||
|
||||
@@ -19,77 +19,40 @@ quants_mapping = {
|
||||
class ParameterGGUF(torch.nn.Parameter):
|
||||
def __init__(self, tensor=None, requires_grad=False, no_init=False):
|
||||
super().__init__()
|
||||
self.is_gguf = True
|
||||
|
||||
if no_init:
|
||||
return
|
||||
|
||||
self.gguf_type = tensor.tensor_type
|
||||
self.gguf_real_shape = torch.Size(reversed(list(tensor.shape)))
|
||||
self.gguf_cls = quants_mapping.get(self.gguf_type, None)
|
||||
self.parent = None
|
||||
self.gguf_cls = quants_mapping.get(tensor.tensor_type, None)
|
||||
self.real_shape = torch.Size(reversed(list(tensor.shape)))
|
||||
self.computation_dtype = torch.float16
|
||||
self.baked = False
|
||||
return
|
||||
|
||||
@property
|
||||
def shape(self):
|
||||
return self.gguf_real_shape
|
||||
return self.real_shape
|
||||
|
||||
def __new__(cls, tensor=None, requires_grad=False, no_init=False):
|
||||
return super().__new__(cls, torch.tensor(tensor.data), requires_grad=requires_grad)
|
||||
|
||||
def dequantize_as_pytorch_parameter(self):
|
||||
if self.parent is None:
|
||||
self.parent = torch.nn.Module()
|
||||
self.gguf_cls.bake_layer(self.parent, self, computation_dtype=torch.float16)
|
||||
if self.gguf_cls is not None:
|
||||
self.gguf_cls.bake(self)
|
||||
return torch.nn.Parameter(dequantize_tensor(self), requires_grad=False)
|
||||
|
||||
def to(self, *args, **kwargs):
|
||||
new = ParameterGGUF(self.data.to(*args, **kwargs), no_init=True)
|
||||
new.gguf_type = self.gguf_type
|
||||
new.gguf_real_shape = self.gguf_real_shape
|
||||
def copy_with_data(self, data):
|
||||
new = ParameterGGUF(data, no_init=True)
|
||||
new.gguf_cls = self.gguf_cls
|
||||
new.parent = self.parent
|
||||
new.real_shape = self.real_shape
|
||||
new.computation_dtype = self.computation_dtype
|
||||
new.baked = self.baked
|
||||
return new
|
||||
|
||||
def to(self, *args, **kwargs):
|
||||
return self.copy_with_data(self.data.to(*args, **kwargs))
|
||||
|
||||
def pin_memory(self, device=None):
|
||||
new = ParameterGGUF(torch.Tensor.pin_memory(self, device=device), no_init=True)
|
||||
new.gguf_type = self.gguf_type
|
||||
new.gguf_real_shape = self.gguf_real_shape
|
||||
new.gguf_cls = self.gguf_cls
|
||||
new.parent = self.parent
|
||||
return new
|
||||
|
||||
@classmethod
|
||||
def make(cls, data, gguf_type, gguf_cls, gguf_real_shape, parent):
|
||||
new = ParameterGGUF(data, no_init=True)
|
||||
new.gguf_type = gguf_type
|
||||
new.gguf_real_shape = gguf_real_shape
|
||||
new.gguf_cls = gguf_cls
|
||||
new.parent = parent
|
||||
return new
|
||||
|
||||
|
||||
def bake_gguf_model(model):
|
||||
computation_dtype = model.computation_dtype
|
||||
|
||||
if computation_dtype not in [torch.float16, torch.bfloat16]:
|
||||
# Baking only supports 16bits otherwise super slow
|
||||
computation_dtype = torch.float16
|
||||
|
||||
backed_layer_counter = 0
|
||||
|
||||
for m in model.modules():
|
||||
if hasattr(m, 'weight'):
|
||||
weight = m.weight
|
||||
if hasattr(weight, 'gguf_cls'):
|
||||
gguf_cls = weight.gguf_cls
|
||||
if gguf_cls is not None:
|
||||
backed_layer_counter += 1
|
||||
gguf_cls.bake_layer(m, weight, computation_dtype)
|
||||
|
||||
if backed_layer_counter > 0:
|
||||
print(f'GGUF backed {backed_layer_counter} layers.')
|
||||
|
||||
return model
|
||||
return self.copy_with_data(torch.Tensor.pin_memory(self, device=device))
|
||||
|
||||
|
||||
def dequantize_tensor(tensor):
|
||||
@@ -99,11 +62,9 @@ def dequantize_tensor(tensor):
|
||||
if not hasattr(tensor, 'gguf_cls'):
|
||||
return tensor
|
||||
|
||||
data = tensor
|
||||
gguf_cls = tensor.gguf_cls
|
||||
gguf_real_shape = tensor.gguf_real_shape
|
||||
|
||||
if gguf_cls is None:
|
||||
return data
|
||||
return tensor
|
||||
|
||||
return gguf_cls.dequantize_pytorch(data, gguf_real_shape)
|
||||
return gguf_cls.dequantize_pytorch(tensor)
|
||||
|
||||
@@ -387,13 +387,12 @@ class LoraLoader:
|
||||
from backend.operations_bnb import functional_dequantize_4bit
|
||||
weight = functional_dequantize_4bit(weight)
|
||||
|
||||
gguf_cls, gguf_type, gguf_real_shape = None, None, None
|
||||
gguf_cls = getattr(weight, 'gguf_cls', None)
|
||||
gguf_parameter = None
|
||||
|
||||
if hasattr(weight, 'is_gguf'):
|
||||
if gguf_cls is not None:
|
||||
gguf_parameter = weight
|
||||
from backend.operations_gguf import dequantize_tensor
|
||||
gguf_cls = weight.gguf_cls
|
||||
gguf_type = weight.gguf_type
|
||||
gguf_real_shape = weight.gguf_real_shape
|
||||
weight = dequantize_tensor(weight)
|
||||
|
||||
try:
|
||||
@@ -409,17 +408,9 @@ class LoraLoader:
|
||||
continue
|
||||
|
||||
if gguf_cls is not None:
|
||||
from backend.operations_gguf import ParameterGGUF
|
||||
weight = gguf_cls.quantize_pytorch(weight, gguf_real_shape)
|
||||
weight = ParameterGGUF.make(
|
||||
data=weight,
|
||||
gguf_type=gguf_type,
|
||||
gguf_cls=gguf_cls,
|
||||
gguf_real_shape=gguf_real_shape,
|
||||
parent=parent_layer
|
||||
)
|
||||
gguf_cls.bake_layer(parent_layer, weight, gguf_cls.computation_dtype)
|
||||
utils.set_attr_raw(self.model, key, weight)
|
||||
gguf_parameter.data = gguf_cls.quantize_pytorch(weight, gguf_parameter.shape)
|
||||
gguf_parameter.baked = False
|
||||
gguf_cls.bake(gguf_parameter)
|
||||
continue
|
||||
|
||||
utils.set_attr_raw(self.model, key, torch.nn.Parameter(weight, requires_grad=False))
|
||||
|
||||
@@ -157,9 +157,9 @@ def beautiful_print_gguf_state_dict_statics(state_dict):
|
||||
from gguf.constants import GGMLQuantizationType
|
||||
type_counts = {}
|
||||
for k, v in state_dict.items():
|
||||
gguf_type = getattr(v, 'gguf_type', None)
|
||||
if gguf_type is not None:
|
||||
type_name = GGMLQuantizationType(gguf_type).name
|
||||
gguf_cls = getattr(v, 'gguf_cls', None)
|
||||
if gguf_cls is not None:
|
||||
type_name = gguf_cls.__name__
|
||||
if type_name in type_counts:
|
||||
type_counts[type_name] += 1
|
||||
else:
|
||||
|
||||
163
packages_3rdparty/gguf/quants.py
vendored
163
packages_3rdparty/gguf/quants.py
vendored
@@ -13,7 +13,7 @@ from .quick_4bits_ops import change_4bits_order, quick_unpack_4bits, quick_unpac
|
||||
import numpy as np
|
||||
|
||||
|
||||
quick_split = lambda x, p: torch.split(x, p + [x.shape[1] - sum(p)], dim=1)
|
||||
quick_split = lambda x, p: torch.split(x, p + [x.shape[1] - sum(p)], dim=-1)
|
||||
|
||||
|
||||
def quant_shape_to_byte_shape(shape: Sequence[int], quant_type: GGMLQuantizationType) -> tuple[int, ...]:
|
||||
@@ -90,8 +90,6 @@ class __Quant(ABC):
|
||||
grid_map: tuple[int | float, ...] = ()
|
||||
grid_hex: bytes | None = None
|
||||
|
||||
computation_dtype: torch.dtype = torch.bfloat16
|
||||
|
||||
def __init__(self):
|
||||
return TypeError("Quant conversion classes can't have instances")
|
||||
|
||||
@@ -144,29 +142,35 @@ class __Quant(ABC):
|
||||
return blocks.reshape(original_shape)
|
||||
|
||||
@classmethod
|
||||
def bake_layer(cls, layer, weight, computation_dtype):
|
||||
data = weight.data
|
||||
cls.computation_dtype = computation_dtype
|
||||
def bake(cls, parameter):
|
||||
if parameter.baked:
|
||||
return
|
||||
|
||||
data = parameter.data
|
||||
cls.block_size, cls.type_size = GGML_QUANT_SIZES[cls.qtype]
|
||||
rows = data.reshape((-1, data.shape[-1])).view(torch.uint8)
|
||||
n_blocks = rows.numel() // cls.type_size
|
||||
blocks = rows.reshape((n_blocks, cls.type_size))
|
||||
weight.data = blocks
|
||||
cls.bake_layer_weight(layer, weight)
|
||||
parameter.data = blocks.contiguous()
|
||||
cls.bake_inner(parameter)
|
||||
parameter.baked = True
|
||||
return
|
||||
|
||||
@classmethod
|
||||
def bake_layer_weight(cls, layer, weight):
|
||||
def bake_inner(cls, parameter):
|
||||
pass
|
||||
|
||||
@classmethod
|
||||
def dequantize_pytorch(cls, x, original_shape) -> torch.Tensor:
|
||||
blocks = cls.dequantize_blocks_pytorch(x.data, cls.block_size, cls.type_size, x.parent)
|
||||
return blocks.reshape(original_shape)
|
||||
def dequantize_pytorch(cls, x):
|
||||
if not x.baked:
|
||||
raise ValueError('GGUF Tensor is not baked!')
|
||||
|
||||
blocks = cls.dequantize_blocks_pytorch(x.data, cls.block_size, cls.type_size, x)
|
||||
return blocks.view(x.shape)
|
||||
|
||||
@classmethod
|
||||
@abstractmethod
|
||||
def dequantize_blocks_pytorch(cls, blocks, block_size, type_size, parent) -> torch.Tensor:
|
||||
def dequantize_blocks_pytorch(cls, blocks, block_size, type_size, parameter) -> torch.Tensor:
|
||||
raise NotImplementedError
|
||||
|
||||
@classmethod
|
||||
@@ -303,22 +307,18 @@ class Q4_0(__Quant, qtype=GGMLQuantizationType.Q4_0):
|
||||
return (d * qs.astype(np.float32))
|
||||
|
||||
@classmethod
|
||||
def bake_layer_weight(cls, layer, weight):
|
||||
blocks = weight.data
|
||||
def bake_inner(cls, parameter):
|
||||
blocks = parameter.data
|
||||
d, x = quick_split(blocks, [2])
|
||||
d = d.view(torch.float16).to(cls.computation_dtype)
|
||||
x = change_4bits_order(x)
|
||||
weight.data = x
|
||||
layer.quant_state_0 = torch.nn.Parameter(d, requires_grad=False)
|
||||
d = d.view(torch.float16).to(parameter.computation_dtype).view(torch.uint8)
|
||||
x = change_4bits_order(x).view(torch.uint8)
|
||||
parameter.data = torch.cat([d, x], dim=-1).contiguous()
|
||||
return
|
||||
|
||||
@classmethod
|
||||
def dequantize_blocks_pytorch(cls, blocks, block_size, type_size, parent) -> torch.Tensor:
|
||||
d, qs = parent.quant_state_0, blocks
|
||||
|
||||
if d.device != qs.device:
|
||||
d = d.to(device=qs.device)
|
||||
|
||||
def dequantize_blocks_pytorch(cls, blocks, block_size, type_size, parameter) -> torch.Tensor:
|
||||
d, qs = quick_split(blocks, [2])
|
||||
d = d.view(parameter.computation_dtype)
|
||||
qs = quick_unpack_4bits(qs)
|
||||
return d * qs
|
||||
|
||||
@@ -381,30 +381,23 @@ class Q4_1(__Quant, qtype=GGMLQuantizationType.Q4_1):
|
||||
return (d * qs) + m
|
||||
|
||||
@classmethod
|
||||
def bake_layer_weight(cls, layer, weight):
|
||||
blocks = weight.data
|
||||
def bake_inner(cls, parameter):
|
||||
blocks = parameter.data
|
||||
|
||||
d, m, qs = quick_split(blocks, [2, 2])
|
||||
d = d.view(torch.float16).to(cls.computation_dtype)
|
||||
m = m.view(torch.float16).to(cls.computation_dtype)
|
||||
d = d.view(torch.float16).to(parameter.computation_dtype).view(torch.uint8)
|
||||
m = m.view(torch.float16).to(parameter.computation_dtype).view(torch.uint8)
|
||||
qs = change_4bits_order(qs).view(torch.uint8)
|
||||
|
||||
qs = change_4bits_order(qs)
|
||||
parameter.data = torch.cat([d, m, qs], dim=-1).contiguous()
|
||||
|
||||
weight.data = qs
|
||||
layer.quant_state_0 = torch.nn.Parameter(d, requires_grad=False)
|
||||
layer.quant_state_1 = torch.nn.Parameter(m, requires_grad=False)
|
||||
return
|
||||
|
||||
@classmethod
|
||||
def dequantize_blocks_pytorch(cls, blocks, block_size, type_size, parent) -> torch.Tensor:
|
||||
d, m, qs = parent.quant_state_0, parent.quant_state_1, blocks
|
||||
|
||||
if d.device != qs.device:
|
||||
d = d.to(device=qs.device)
|
||||
|
||||
if m.device != qs.device:
|
||||
m = m.to(device=qs.device)
|
||||
|
||||
def dequantize_blocks_pytorch(cls, blocks, block_size, type_size, parameter) -> torch.Tensor:
|
||||
d, m, qs = quick_split(blocks, [2, 2])
|
||||
d = d.view(parameter.computation_dtype)
|
||||
m = m.view(parameter.computation_dtype)
|
||||
qs = quick_unpack_4bits_u(qs)
|
||||
return (d * qs) + m
|
||||
|
||||
@@ -452,7 +445,7 @@ class Q5_0(__Quant, qtype=GGMLQuantizationType.Q5_0):
|
||||
return (d * qs.astype(np.float32))
|
||||
|
||||
@classmethod
|
||||
def dequantize_blocks_pytorch(cls, blocks, block_size, type_size, parent) -> torch.Tensor:
|
||||
def dequantize_blocks_pytorch(cls, blocks, block_size, type_size, parameter) -> torch.Tensor:
|
||||
def to_uint32(x):
|
||||
# pytorch uint32 by City96 - Apache-2.0
|
||||
x = x.view(torch.uint8).to(torch.int32)
|
||||
@@ -461,7 +454,7 @@ class Q5_0(__Quant, qtype=GGMLQuantizationType.Q5_0):
|
||||
n_blocks = blocks.shape[0]
|
||||
|
||||
d, qh, qs = quick_split(blocks, [2, 4])
|
||||
d = d.view(torch.float16).to(cls.computation_dtype)
|
||||
d = d.view(torch.float16).to(parameter.computation_dtype)
|
||||
qh = to_uint32(qh)
|
||||
|
||||
qh = qh.reshape(n_blocks, 1) >> torch.arange(32, device=d.device, dtype=torch.int32).reshape(1, 32)
|
||||
@@ -555,7 +548,7 @@ class Q5_1(__Quant, qtype=GGMLQuantizationType.Q5_1):
|
||||
return (d * qs) + m
|
||||
|
||||
@classmethod
|
||||
def dequantize_blocks_pytorch(cls, blocks, block_size, type_size, parent) -> torch.Tensor:
|
||||
def dequantize_blocks_pytorch(cls, blocks, block_size, type_size, parameter) -> torch.Tensor:
|
||||
def to_uint32(x):
|
||||
# pytorch uint32 by City96 - Apache-2.0
|
||||
x = x.view(torch.uint8).to(torch.int32)
|
||||
@@ -564,8 +557,8 @@ class Q5_1(__Quant, qtype=GGMLQuantizationType.Q5_1):
|
||||
n_blocks = blocks.shape[0]
|
||||
|
||||
d, m, qh, qs = quick_split(blocks, [2, 2, 4])
|
||||
d = d.view(torch.float16).to(cls.computation_dtype)
|
||||
m = m.view(torch.float16).to(cls.computation_dtype)
|
||||
d = d.view(torch.float16).to(parameter.computation_dtype)
|
||||
m = m.view(torch.float16).to(parameter.computation_dtype)
|
||||
qh = to_uint32(qh)
|
||||
|
||||
qh = qh.reshape((n_blocks, 1)) >> torch.arange(32, device=d.device, dtype=torch.int32).reshape(1, 32)
|
||||
@@ -603,23 +596,18 @@ class Q8_0(__Quant, qtype=GGMLQuantizationType.Q8_0):
|
||||
return (x * d)
|
||||
|
||||
@classmethod
|
||||
def bake_layer_weight(cls, layer, weight):
|
||||
blocks = weight.data
|
||||
def bake_inner(cls, parameter):
|
||||
blocks = parameter.data
|
||||
d, x = quick_split(blocks, [2])
|
||||
x = x.view(torch.int8)
|
||||
d = d.view(torch.float16).to(cls.computation_dtype)
|
||||
weight.data = x
|
||||
layer.quant_state_0 = torch.nn.Parameter(d, requires_grad=False)
|
||||
d = d.view(torch.float16).to(parameter.computation_dtype).view(torch.int8)
|
||||
parameter.data = torch.cat([d, x], dim=-1).contiguous()
|
||||
return
|
||||
|
||||
@classmethod
|
||||
def dequantize_blocks_pytorch(cls, blocks, block_size, type_size, parent) -> torch.Tensor:
|
||||
x = blocks
|
||||
d = parent.quant_state_0
|
||||
|
||||
if d.device != x.device:
|
||||
d = d.to(device=x.device)
|
||||
|
||||
def dequantize_blocks_pytorch(cls, blocks, block_size, type_size, parameter) -> torch.Tensor:
|
||||
d, x = quick_split(blocks, [2])
|
||||
d = d.view(parameter.computation_dtype)
|
||||
return x * d
|
||||
|
||||
@classmethod
|
||||
@@ -660,12 +648,12 @@ class Q2_K(__Quant, qtype=GGMLQuantizationType.Q2_K):
|
||||
return qs.reshape((n_blocks, -1))
|
||||
|
||||
@classmethod
|
||||
def dequantize_blocks_pytorch(cls, blocks, block_size, type_size, parent) -> torch.Tensor:
|
||||
def dequantize_blocks_pytorch(cls, blocks, block_size, type_size, parameter) -> torch.Tensor:
|
||||
# (c) City96 || Apache-2.0 (apache.org/licenses/LICENSE-2.0)
|
||||
n_blocks = blocks.shape[0]
|
||||
scales, qs, d, dmin = quick_split(blocks, [QK_K // 16, QK_K // 4, 2])
|
||||
d = d.view(torch.float16).to(cls.computation_dtype)
|
||||
dmin = dmin.view(torch.float16).to(cls.computation_dtype)
|
||||
d = d.view(torch.float16).to(parameter.computation_dtype)
|
||||
dmin = dmin.view(torch.float16).to(parameter.computation_dtype)
|
||||
# (n_blocks, 16, 1)
|
||||
dl = (d * (scales & 0xF)).reshape((n_blocks, QK_K // 16, 1))
|
||||
ml = (dmin * (scales >> 4)).reshape((n_blocks, QK_K // 16, 1))
|
||||
@@ -720,11 +708,11 @@ class Q3_K(__Quant, qtype=GGMLQuantizationType.Q3_K):
|
||||
return (dl * q).reshape((n_blocks, QK_K))
|
||||
|
||||
@classmethod
|
||||
def dequantize_blocks_pytorch(cls, blocks, block_size, type_size, parent) -> torch.Tensor:
|
||||
def dequantize_blocks_pytorch(cls, blocks, block_size, type_size, parameter) -> torch.Tensor:
|
||||
# (c) City96 || Apache-2.0 (apache.org/licenses/LICENSE-2.0)
|
||||
n_blocks = blocks.shape[0]
|
||||
hmask, qs, scales, d = quick_split(blocks, [QK_K // 8, QK_K // 4, 12])
|
||||
d = d.view(torch.float16).to(cls.computation_dtype)
|
||||
d = d.view(torch.float16).to(parameter.computation_dtype)
|
||||
lscales, hscales = scales[:, :8], scales[:, 8:]
|
||||
lscales = lscales.reshape((n_blocks, 1, 8)) >> torch.tensor([0, 4], device=d.device, dtype=torch.uint8).reshape((1, 2, 1))
|
||||
lscales = lscales.reshape((n_blocks, 16))
|
||||
@@ -801,42 +789,39 @@ class Q4_K(__Quant, qtype=GGMLQuantizationType.Q4_K):
|
||||
return (d * qs - dm).reshape((n_blocks, QK_K))
|
||||
|
||||
@classmethod
|
||||
def bake_layer_weight(cls, layer, weight): # Only compute one time when model load
|
||||
def bake_inner(cls, parameter): # Only compute one time when model load
|
||||
# Copyright Forge 2024
|
||||
|
||||
blocks = weight.data
|
||||
K_SCALE_SIZE = 12
|
||||
blocks = parameter.data
|
||||
n_blocks = blocks.shape[0]
|
||||
d, dmin, scales, qs = quick_split(blocks, [2, 2, K_SCALE_SIZE])
|
||||
d = d.view(torch.float16).to(cls.computation_dtype)
|
||||
dmin = dmin.view(torch.float16).to(cls.computation_dtype)
|
||||
d, dmin, scales, qs = quick_split(blocks, [2, 2, cls.K_SCALE_SIZE])
|
||||
d = d.view(torch.float16).to(parameter.computation_dtype)
|
||||
dmin = dmin.view(torch.float16).to(parameter.computation_dtype)
|
||||
sc, m = Q4_K.get_scale_min_pytorch(scales)
|
||||
d = (d * sc).reshape((n_blocks, -1, 1))
|
||||
dm = (dmin * m).reshape((n_blocks, -1, 1)).to(cls.computation_dtype)
|
||||
dm = (dmin * m).reshape((n_blocks, -1, 1)).to(parameter.computation_dtype)
|
||||
|
||||
qs = qs.reshape((n_blocks, -1, 1, 32))
|
||||
qs = change_4bits_order(qs)
|
||||
|
||||
weight.data = qs
|
||||
layer.quant_state_0 = torch.nn.Parameter(d, requires_grad=False)
|
||||
layer.quant_state_1 = torch.nn.Parameter(dm, requires_grad=False)
|
||||
d = d.view(torch.uint8).reshape((n_blocks, -1))
|
||||
dm = dm.view(torch.uint8).reshape((n_blocks, -1))
|
||||
qs = qs.view(torch.uint8)
|
||||
|
||||
parameter.data = torch.cat([d, dm, qs], dim=-1).contiguous()
|
||||
return
|
||||
|
||||
@classmethod
|
||||
def dequantize_blocks_pytorch(cls, blocks, block_size, type_size, parent) -> torch.Tensor:
|
||||
def dequantize_blocks_pytorch(cls, blocks, block_size, type_size, parameter) -> torch.Tensor:
|
||||
# Compute in each diffusion iteration
|
||||
|
||||
n_blocks = blocks.shape[0]
|
||||
|
||||
d, dm, qs = parent.quant_state_0, parent.quant_state_1, blocks
|
||||
d, dm, qs = quick_split(blocks, [16, 16])
|
||||
d = d.view(parameter.computation_dtype).view((n_blocks, -1, 1))
|
||||
dm = dm.view(parameter.computation_dtype).view((n_blocks, -1, 1))
|
||||
qs = quick_unpack_4bits_u(qs).view((n_blocks, -1, 32))
|
||||
|
||||
if d.device != qs.device:
|
||||
d = d.to(device=qs.device)
|
||||
|
||||
if dm.device != qs.device:
|
||||
dm = dm.to(device=qs.device)
|
||||
|
||||
qs = quick_unpack_4bits_u(qs).reshape((n_blocks, -1, 32))
|
||||
return (d * qs - dm).reshape((n_blocks, QK_K))
|
||||
|
||||
|
||||
@@ -867,14 +852,14 @@ class Q5_K(__Quant, qtype=GGMLQuantizationType.Q5_K):
|
||||
return (d * q - dm).reshape((n_blocks, QK_K))
|
||||
|
||||
@classmethod
|
||||
def dequantize_blocks_pytorch(cls, blocks, block_size, type_size, parent) -> torch.Tensor:
|
||||
def dequantize_blocks_pytorch(cls, blocks, block_size, type_size, parameter) -> torch.Tensor:
|
||||
# (c) City96 || Apache-2.0 (apache.org/licenses/LICENSE-2.0)
|
||||
QK_K = 256
|
||||
K_SCALE_SIZE = 12
|
||||
n_blocks = blocks.shape[0]
|
||||
d, dmin, scales, qh, qs = quick_split(blocks, [2, 2, K_SCALE_SIZE, QK_K // 8])
|
||||
d = d.view(torch.float16).to(cls.computation_dtype)
|
||||
dmin = dmin.view(torch.float16).to(cls.computation_dtype)
|
||||
d = d.view(torch.float16).to(parameter.computation_dtype)
|
||||
dmin = dmin.view(torch.float16).to(parameter.computation_dtype)
|
||||
sc, m = Q4_K.get_scale_min_pytorch(scales)
|
||||
d = (d * sc).reshape((n_blocks, -1, 1))
|
||||
dm = (dmin * m).reshape((n_blocks, -1, 1))
|
||||
@@ -909,12 +894,12 @@ class Q6_K(__Quant, qtype=GGMLQuantizationType.Q6_K):
|
||||
return (d * q).reshape((n_blocks, QK_K))
|
||||
|
||||
@classmethod
|
||||
def dequantize_blocks_pytorch(cls, blocks, block_size, type_size, parent) -> torch.Tensor:
|
||||
def dequantize_blocks_pytorch(cls, blocks, block_size, type_size, parameter) -> torch.Tensor:
|
||||
# Written by ChatGPT
|
||||
n_blocks = blocks.shape[0]
|
||||
ql, qh, scales, d, = quick_split(blocks, [QK_K // 2, QK_K // 4, QK_K // 16])
|
||||
scales = scales.view(torch.int8).to(cls.computation_dtype)
|
||||
d = d.view(torch.float16).to(cls.computation_dtype)
|
||||
scales = scales.view(torch.int8).to(parameter.computation_dtype)
|
||||
d = d.view(torch.float16).to(parameter.computation_dtype)
|
||||
d = (d * scales).reshape((n_blocks, QK_K // 16, 1))
|
||||
ql = ql.reshape((n_blocks, -1, 1, 64)) >> torch.tensor([0, 4], device=d.device, dtype=torch.uint8).reshape((1, 1, 2, 1))
|
||||
ql = (ql & 0x0F).reshape((n_blocks, -1, 32))
|
||||
|
||||
Reference in New Issue
Block a user