Speed up quant model loading and inference ...

... based on three pieces of evidence:
1. torch.Tensor.view on one big tensor is slightly faster than calling torch.Tensor.to on multiple small tensors.
2. however, torch.Tensor.to with a dtype change is significantly slower than torch.Tensor.view
3. “baking” the model on the GPU is significantly faster than computing on the CPU at model-load time.

This mainly affects inference of Q8_0 and Q4_0/1/K, and the loading of all quants.
This commit is contained in:
layerdiffusion
2024-08-30 00:49:05 -07:00
parent 3d62fa9598
commit 4c9380c46a
7 changed files with 126 additions and 181 deletions

View File

@@ -104,11 +104,6 @@ def load_huggingface_component(guess, component_name, lib_name, cls_name, repo_p
load_state_dict(model, state_dict, log_name=cls_name, ignore_errors=['transformer.encoder.embed_tokens.weight', 'logit_scale'])
if storage_dtype in ['gguf']:
from backend.operations_gguf import bake_gguf_model
model.computation_dtype = torch.float16
model = bake_gguf_model(model)
return model
if cls_name in ['UNet2DConditionModel', 'FluxTransformer2DModel']:
assert isinstance(state_dict, dict) and len(state_dict) > 16, 'You do not have model state dict!'
@@ -167,10 +162,6 @@ def load_huggingface_component(guess, component_name, lib_name, cls_name, repo_p
model.initial_device = initial_device
model.offload_device = offload_device
if storage_dtype in ['gguf']:
from backend.operations_gguf import bake_gguf_model
model = bake_gguf_model(model)
return model
print(f'Skipped: {component_name} = {lib_name}.{cls_name}')

View File

@@ -310,7 +310,7 @@ def state_dict_parameters(sd):
def state_dict_dtype(state_dict):
for k, v in state_dict.items():
if hasattr(v, 'is_gguf'):
if hasattr(v, 'gguf_cls'):
return 'gguf'
if 'bitsandbytes__nf4' in k:
return 'nf4'
@@ -337,6 +337,19 @@ def state_dict_dtype(state_dict):
return major_dtype
def bake_gguf_model(model):
    """Bake all GGUF-quantized parameters of ``model`` in place.

    Walks the model's parameters and, for each one carrying a ``gguf_cls``
    quant class, calls ``gguf_cls.bake(p)`` to convert its raw GGUF blocks
    into the fast inference layout.  A ``gguf_baked`` flag on the model makes
    repeated calls no-ops.

    Returns:
        The same ``model`` object, so callers may write
        ``model = bake_gguf_model(model)``.
    """
    if getattr(model, 'gguf_baked', False):
        # Fix: previously this was a bare `return` (i.e. None), so a second
        # `model = bake_gguf_model(model)` call clobbered the caller's
        # reference with None.  Always hand the model back.
        return model
    for p in model.parameters():
        gguf_cls = getattr(p, 'gguf_cls', None)
        if gguf_cls is not None:
            gguf_cls.bake(p)
    model.gguf_baked = True
    return model
def module_size(module, exclude_device=None, return_split=False):
module_mem = 0
weight_mem = 0
@@ -493,6 +506,8 @@ class LoadedModel:
global signal_empty_cache
signal_empty_cache = True
bake_gguf_model(self.real_model)
self.model.lora_loader.refresh(offload_device=self.model.offload_device)
if is_intel_xpu() and not args.disable_ipex_hijack:
@@ -642,7 +657,7 @@ def load_models_gpu(models, memory_required=0):
inference_memory = minimum_inference_memory()
estimated_remaining_memory = current_free_mem - model_memory - inference_memory
print(f"[Memory Management] Target: {loaded_model.model.__class__.__name__}, Free GPU: {current_free_mem / (1024 * 1024):.2f} MB, Model Require: {model_memory / (1024 * 1024):.2f} MB, Inference Require: {inference_memory / (1024 * 1024):.2f} MB, Remaining: {estimated_remaining_memory / (1024 * 1024):.2f} MB, ", end="")
print(f"[Memory Management] Target: {loaded_model.model.model.__class__.__name__}, Free GPU: {current_free_mem / (1024 * 1024):.2f} MB, Model Require: {model_memory / (1024 * 1024):.2f} MB, Inference Require: {inference_memory / (1024 * 1024):.2f} MB, Remaining: {estimated_remaining_memory / (1024 * 1024):.2f} MB, ", end="")
if estimated_remaining_memory < 0:
vram_set_state = VRAMState.LOW_VRAM

View File

@@ -395,20 +395,22 @@ class ForgeOperationsGGUF(ForgeOperations):
def _load_from_state_dict(self, state_dict, prefix, local_metadata, strict, missing_keys, unexpected_keys, error_msgs):
if hasattr(self, 'dummy'):
computation_dtype = self.dummy.dtype
if computation_dtype not in [torch.float16, torch.bfloat16]:
# GGUF cast only supports 16bits otherwise super slow
computation_dtype = torch.float16
if prefix + 'weight' in state_dict:
self.weight = state_dict[prefix + 'weight'].to(device=self.dummy.device)
self.weight.computation_dtype = computation_dtype
if prefix + 'bias' in state_dict:
self.bias = state_dict[prefix + 'bias'].to(device=self.dummy.device)
self.bias.computation_dtype = computation_dtype
del self.dummy
else:
if prefix + 'weight' in state_dict:
self.weight = state_dict[prefix + 'weight']
if prefix + 'bias' in state_dict:
self.bias = state_dict[prefix + 'bias']
if self.weight is not None and hasattr(self.weight, 'parent'):
self.weight.parent = self
if self.bias is not None and hasattr(self.bias, 'parent'):
self.bias.parent = self
return
def _apply(self, fn, recurse=True):

View File

@@ -19,77 +19,40 @@ quants_mapping = {
class ParameterGGUF(torch.nn.Parameter):
def __init__(self, tensor=None, requires_grad=False, no_init=False):
super().__init__()
self.is_gguf = True
if no_init:
return
self.gguf_type = tensor.tensor_type
self.gguf_real_shape = torch.Size(reversed(list(tensor.shape)))
self.gguf_cls = quants_mapping.get(self.gguf_type, None)
self.parent = None
self.gguf_cls = quants_mapping.get(tensor.tensor_type, None)
self.real_shape = torch.Size(reversed(list(tensor.shape)))
self.computation_dtype = torch.float16
self.baked = False
return
@property
def shape(self):
return self.gguf_real_shape
return self.real_shape
def __new__(cls, tensor=None, requires_grad=False, no_init=False):
return super().__new__(cls, torch.tensor(tensor.data), requires_grad=requires_grad)
def dequantize_as_pytorch_parameter(self):
if self.parent is None:
self.parent = torch.nn.Module()
self.gguf_cls.bake_layer(self.parent, self, computation_dtype=torch.float16)
if self.gguf_cls is not None:
self.gguf_cls.bake(self)
return torch.nn.Parameter(dequantize_tensor(self), requires_grad=False)
def to(self, *args, **kwargs):
new = ParameterGGUF(self.data.to(*args, **kwargs), no_init=True)
new.gguf_type = self.gguf_type
new.gguf_real_shape = self.gguf_real_shape
def copy_with_data(self, data):
new = ParameterGGUF(data, no_init=True)
new.gguf_cls = self.gguf_cls
new.parent = self.parent
new.real_shape = self.real_shape
new.computation_dtype = self.computation_dtype
new.baked = self.baked
return new
def to(self, *args, **kwargs):
return self.copy_with_data(self.data.to(*args, **kwargs))
def pin_memory(self, device=None):
new = ParameterGGUF(torch.Tensor.pin_memory(self, device=device), no_init=True)
new.gguf_type = self.gguf_type
new.gguf_real_shape = self.gguf_real_shape
new.gguf_cls = self.gguf_cls
new.parent = self.parent
return new
@classmethod
def make(cls, data, gguf_type, gguf_cls, gguf_real_shape, parent):
new = ParameterGGUF(data, no_init=True)
new.gguf_type = gguf_type
new.gguf_real_shape = gguf_real_shape
new.gguf_cls = gguf_cls
new.parent = parent
return new
def bake_gguf_model(model):
computation_dtype = model.computation_dtype
if computation_dtype not in [torch.float16, torch.bfloat16]:
# Baking only supports 16bits otherwise super slow
computation_dtype = torch.float16
backed_layer_counter = 0
for m in model.modules():
if hasattr(m, 'weight'):
weight = m.weight
if hasattr(weight, 'gguf_cls'):
gguf_cls = weight.gguf_cls
if gguf_cls is not None:
backed_layer_counter += 1
gguf_cls.bake_layer(m, weight, computation_dtype)
if backed_layer_counter > 0:
print(f'GGUF backed {backed_layer_counter} layers.')
return model
return self.copy_with_data(torch.Tensor.pin_memory(self, device=device))
def dequantize_tensor(tensor):
@@ -99,11 +62,9 @@ def dequantize_tensor(tensor):
if not hasattr(tensor, 'gguf_cls'):
return tensor
data = tensor
gguf_cls = tensor.gguf_cls
gguf_real_shape = tensor.gguf_real_shape
if gguf_cls is None:
return data
return tensor
return gguf_cls.dequantize_pytorch(data, gguf_real_shape)
return gguf_cls.dequantize_pytorch(tensor)

View File

@@ -387,13 +387,12 @@ class LoraLoader:
from backend.operations_bnb import functional_dequantize_4bit
weight = functional_dequantize_4bit(weight)
gguf_cls, gguf_type, gguf_real_shape = None, None, None
gguf_cls = getattr(weight, 'gguf_cls', None)
gguf_parameter = None
if hasattr(weight, 'is_gguf'):
if gguf_cls is not None:
gguf_parameter = weight
from backend.operations_gguf import dequantize_tensor
gguf_cls = weight.gguf_cls
gguf_type = weight.gguf_type
gguf_real_shape = weight.gguf_real_shape
weight = dequantize_tensor(weight)
try:
@@ -409,17 +408,9 @@ class LoraLoader:
continue
if gguf_cls is not None:
from backend.operations_gguf import ParameterGGUF
weight = gguf_cls.quantize_pytorch(weight, gguf_real_shape)
weight = ParameterGGUF.make(
data=weight,
gguf_type=gguf_type,
gguf_cls=gguf_cls,
gguf_real_shape=gguf_real_shape,
parent=parent_layer
)
gguf_cls.bake_layer(parent_layer, weight, gguf_cls.computation_dtype)
utils.set_attr_raw(self.model, key, weight)
gguf_parameter.data = gguf_cls.quantize_pytorch(weight, gguf_parameter.shape)
gguf_parameter.baked = False
gguf_cls.bake(gguf_parameter)
continue
utils.set_attr_raw(self.model, key, torch.nn.Parameter(weight, requires_grad=False))

View File

@@ -157,9 +157,9 @@ def beautiful_print_gguf_state_dict_statics(state_dict):
from gguf.constants import GGMLQuantizationType
type_counts = {}
for k, v in state_dict.items():
gguf_type = getattr(v, 'gguf_type', None)
if gguf_type is not None:
type_name = GGMLQuantizationType(gguf_type).name
gguf_cls = getattr(v, 'gguf_cls', None)
if gguf_cls is not None:
type_name = gguf_cls.__name__
if type_name in type_counts:
type_counts[type_name] += 1
else:

View File

@@ -13,7 +13,7 @@ from .quick_4bits_ops import change_4bits_order, quick_unpack_4bits, quick_unpac
import numpy as np
quick_split = lambda x, p: torch.split(x, p + [x.shape[1] - sum(p)], dim=1)
quick_split = lambda x, p: torch.split(x, p + [x.shape[1] - sum(p)], dim=-1)
def quant_shape_to_byte_shape(shape: Sequence[int], quant_type: GGMLQuantizationType) -> tuple[int, ...]:
@@ -90,8 +90,6 @@ class __Quant(ABC):
grid_map: tuple[int | float, ...] = ()
grid_hex: bytes | None = None
computation_dtype: torch.dtype = torch.bfloat16
def __init__(self):
return TypeError("Quant conversion classes can't have instances")
@@ -144,29 +142,35 @@ class __Quant(ABC):
return blocks.reshape(original_shape)
@classmethod
def bake_layer(cls, layer, weight, computation_dtype):
data = weight.data
cls.computation_dtype = computation_dtype
def bake(cls, parameter):
if parameter.baked:
return
data = parameter.data
cls.block_size, cls.type_size = GGML_QUANT_SIZES[cls.qtype]
rows = data.reshape((-1, data.shape[-1])).view(torch.uint8)
n_blocks = rows.numel() // cls.type_size
blocks = rows.reshape((n_blocks, cls.type_size))
weight.data = blocks
cls.bake_layer_weight(layer, weight)
parameter.data = blocks.contiguous()
cls.bake_inner(parameter)
parameter.baked = True
return
@classmethod
def bake_layer_weight(cls, layer, weight):
def bake_inner(cls, parameter):
pass
@classmethod
def dequantize_pytorch(cls, x, original_shape) -> torch.Tensor:
blocks = cls.dequantize_blocks_pytorch(x.data, cls.block_size, cls.type_size, x.parent)
return blocks.reshape(original_shape)
def dequantize_pytorch(cls, x):
if not x.baked:
raise ValueError('GGUF Tensor is not baked!')
blocks = cls.dequantize_blocks_pytorch(x.data, cls.block_size, cls.type_size, x)
return blocks.view(x.shape)
@classmethod
@abstractmethod
def dequantize_blocks_pytorch(cls, blocks, block_size, type_size, parent) -> torch.Tensor:
def dequantize_blocks_pytorch(cls, blocks, block_size, type_size, parameter) -> torch.Tensor:
raise NotImplementedError
@classmethod
@@ -303,22 +307,18 @@ class Q4_0(__Quant, qtype=GGMLQuantizationType.Q4_0):
return (d * qs.astype(np.float32))
@classmethod
def bake_layer_weight(cls, layer, weight):
blocks = weight.data
def bake_inner(cls, parameter):
blocks = parameter.data
d, x = quick_split(blocks, [2])
d = d.view(torch.float16).to(cls.computation_dtype)
x = change_4bits_order(x)
weight.data = x
layer.quant_state_0 = torch.nn.Parameter(d, requires_grad=False)
d = d.view(torch.float16).to(parameter.computation_dtype).view(torch.uint8)
x = change_4bits_order(x).view(torch.uint8)
parameter.data = torch.cat([d, x], dim=-1).contiguous()
return
@classmethod
def dequantize_blocks_pytorch(cls, blocks, block_size, type_size, parent) -> torch.Tensor:
d, qs = parent.quant_state_0, blocks
if d.device != qs.device:
d = d.to(device=qs.device)
def dequantize_blocks_pytorch(cls, blocks, block_size, type_size, parameter) -> torch.Tensor:
d, qs = quick_split(blocks, [2])
d = d.view(parameter.computation_dtype)
qs = quick_unpack_4bits(qs)
return d * qs
@@ -381,30 +381,23 @@ class Q4_1(__Quant, qtype=GGMLQuantizationType.Q4_1):
return (d * qs) + m
@classmethod
def bake_layer_weight(cls, layer, weight):
blocks = weight.data
def bake_inner(cls, parameter):
blocks = parameter.data
d, m, qs = quick_split(blocks, [2, 2])
d = d.view(torch.float16).to(cls.computation_dtype)
m = m.view(torch.float16).to(cls.computation_dtype)
d = d.view(torch.float16).to(parameter.computation_dtype).view(torch.uint8)
m = m.view(torch.float16).to(parameter.computation_dtype).view(torch.uint8)
qs = change_4bits_order(qs).view(torch.uint8)
qs = change_4bits_order(qs)
parameter.data = torch.cat([d, m, qs], dim=-1).contiguous()
weight.data = qs
layer.quant_state_0 = torch.nn.Parameter(d, requires_grad=False)
layer.quant_state_1 = torch.nn.Parameter(m, requires_grad=False)
return
@classmethod
def dequantize_blocks_pytorch(cls, blocks, block_size, type_size, parent) -> torch.Tensor:
d, m, qs = parent.quant_state_0, parent.quant_state_1, blocks
if d.device != qs.device:
d = d.to(device=qs.device)
if m.device != qs.device:
m = m.to(device=qs.device)
def dequantize_blocks_pytorch(cls, blocks, block_size, type_size, parameter) -> torch.Tensor:
d, m, qs = quick_split(blocks, [2, 2])
d = d.view(parameter.computation_dtype)
m = m.view(parameter.computation_dtype)
qs = quick_unpack_4bits_u(qs)
return (d * qs) + m
@@ -452,7 +445,7 @@ class Q5_0(__Quant, qtype=GGMLQuantizationType.Q5_0):
return (d * qs.astype(np.float32))
@classmethod
def dequantize_blocks_pytorch(cls, blocks, block_size, type_size, parent) -> torch.Tensor:
def dequantize_blocks_pytorch(cls, blocks, block_size, type_size, parameter) -> torch.Tensor:
def to_uint32(x):
# pytorch uint32 by City96 - Apache-2.0
x = x.view(torch.uint8).to(torch.int32)
@@ -461,7 +454,7 @@ class Q5_0(__Quant, qtype=GGMLQuantizationType.Q5_0):
n_blocks = blocks.shape[0]
d, qh, qs = quick_split(blocks, [2, 4])
d = d.view(torch.float16).to(cls.computation_dtype)
d = d.view(torch.float16).to(parameter.computation_dtype)
qh = to_uint32(qh)
qh = qh.reshape(n_blocks, 1) >> torch.arange(32, device=d.device, dtype=torch.int32).reshape(1, 32)
@@ -555,7 +548,7 @@ class Q5_1(__Quant, qtype=GGMLQuantizationType.Q5_1):
return (d * qs) + m
@classmethod
def dequantize_blocks_pytorch(cls, blocks, block_size, type_size, parent) -> torch.Tensor:
def dequantize_blocks_pytorch(cls, blocks, block_size, type_size, parameter) -> torch.Tensor:
def to_uint32(x):
# pytorch uint32 by City96 - Apache-2.0
x = x.view(torch.uint8).to(torch.int32)
@@ -564,8 +557,8 @@ class Q5_1(__Quant, qtype=GGMLQuantizationType.Q5_1):
n_blocks = blocks.shape[0]
d, m, qh, qs = quick_split(blocks, [2, 2, 4])
d = d.view(torch.float16).to(cls.computation_dtype)
m = m.view(torch.float16).to(cls.computation_dtype)
d = d.view(torch.float16).to(parameter.computation_dtype)
m = m.view(torch.float16).to(parameter.computation_dtype)
qh = to_uint32(qh)
qh = qh.reshape((n_blocks, 1)) >> torch.arange(32, device=d.device, dtype=torch.int32).reshape(1, 32)
@@ -603,23 +596,18 @@ class Q8_0(__Quant, qtype=GGMLQuantizationType.Q8_0):
return (x * d)
@classmethod
def bake_layer_weight(cls, layer, weight):
blocks = weight.data
def bake_inner(cls, parameter):
blocks = parameter.data
d, x = quick_split(blocks, [2])
x = x.view(torch.int8)
d = d.view(torch.float16).to(cls.computation_dtype)
weight.data = x
layer.quant_state_0 = torch.nn.Parameter(d, requires_grad=False)
d = d.view(torch.float16).to(parameter.computation_dtype).view(torch.int8)
parameter.data = torch.cat([d, x], dim=-1).contiguous()
return
@classmethod
def dequantize_blocks_pytorch(cls, blocks, block_size, type_size, parent) -> torch.Tensor:
x = blocks
d = parent.quant_state_0
if d.device != x.device:
d = d.to(device=x.device)
def dequantize_blocks_pytorch(cls, blocks, block_size, type_size, parameter) -> torch.Tensor:
d, x = quick_split(blocks, [2])
d = d.view(parameter.computation_dtype)
return x * d
@classmethod
@@ -660,12 +648,12 @@ class Q2_K(__Quant, qtype=GGMLQuantizationType.Q2_K):
return qs.reshape((n_blocks, -1))
@classmethod
def dequantize_blocks_pytorch(cls, blocks, block_size, type_size, parent) -> torch.Tensor:
def dequantize_blocks_pytorch(cls, blocks, block_size, type_size, parameter) -> torch.Tensor:
# (c) City96 || Apache-2.0 (apache.org/licenses/LICENSE-2.0)
n_blocks = blocks.shape[0]
scales, qs, d, dmin = quick_split(blocks, [QK_K // 16, QK_K // 4, 2])
d = d.view(torch.float16).to(cls.computation_dtype)
dmin = dmin.view(torch.float16).to(cls.computation_dtype)
d = d.view(torch.float16).to(parameter.computation_dtype)
dmin = dmin.view(torch.float16).to(parameter.computation_dtype)
# (n_blocks, 16, 1)
dl = (d * (scales & 0xF)).reshape((n_blocks, QK_K // 16, 1))
ml = (dmin * (scales >> 4)).reshape((n_blocks, QK_K // 16, 1))
@@ -720,11 +708,11 @@ class Q3_K(__Quant, qtype=GGMLQuantizationType.Q3_K):
return (dl * q).reshape((n_blocks, QK_K))
@classmethod
def dequantize_blocks_pytorch(cls, blocks, block_size, type_size, parent) -> torch.Tensor:
def dequantize_blocks_pytorch(cls, blocks, block_size, type_size, parameter) -> torch.Tensor:
# (c) City96 || Apache-2.0 (apache.org/licenses/LICENSE-2.0)
n_blocks = blocks.shape[0]
hmask, qs, scales, d = quick_split(blocks, [QK_K // 8, QK_K // 4, 12])
d = d.view(torch.float16).to(cls.computation_dtype)
d = d.view(torch.float16).to(parameter.computation_dtype)
lscales, hscales = scales[:, :8], scales[:, 8:]
lscales = lscales.reshape((n_blocks, 1, 8)) >> torch.tensor([0, 4], device=d.device, dtype=torch.uint8).reshape((1, 2, 1))
lscales = lscales.reshape((n_blocks, 16))
@@ -801,42 +789,39 @@ class Q4_K(__Quant, qtype=GGMLQuantizationType.Q4_K):
return (d * qs - dm).reshape((n_blocks, QK_K))
@classmethod
def bake_layer_weight(cls, layer, weight): # Only compute one time when model load
def bake_inner(cls, parameter): # Only compute one time when model load
# Copyright Forge 2024
blocks = weight.data
K_SCALE_SIZE = 12
blocks = parameter.data
n_blocks = blocks.shape[0]
d, dmin, scales, qs = quick_split(blocks, [2, 2, K_SCALE_SIZE])
d = d.view(torch.float16).to(cls.computation_dtype)
dmin = dmin.view(torch.float16).to(cls.computation_dtype)
d, dmin, scales, qs = quick_split(blocks, [2, 2, cls.K_SCALE_SIZE])
d = d.view(torch.float16).to(parameter.computation_dtype)
dmin = dmin.view(torch.float16).to(parameter.computation_dtype)
sc, m = Q4_K.get_scale_min_pytorch(scales)
d = (d * sc).reshape((n_blocks, -1, 1))
dm = (dmin * m).reshape((n_blocks, -1, 1)).to(cls.computation_dtype)
dm = (dmin * m).reshape((n_blocks, -1, 1)).to(parameter.computation_dtype)
qs = qs.reshape((n_blocks, -1, 1, 32))
qs = change_4bits_order(qs)
weight.data = qs
layer.quant_state_0 = torch.nn.Parameter(d, requires_grad=False)
layer.quant_state_1 = torch.nn.Parameter(dm, requires_grad=False)
d = d.view(torch.uint8).reshape((n_blocks, -1))
dm = dm.view(torch.uint8).reshape((n_blocks, -1))
qs = qs.view(torch.uint8)
parameter.data = torch.cat([d, dm, qs], dim=-1).contiguous()
return
@classmethod
def dequantize_blocks_pytorch(cls, blocks, block_size, type_size, parent) -> torch.Tensor:
def dequantize_blocks_pytorch(cls, blocks, block_size, type_size, parameter) -> torch.Tensor:
# Compute in each diffusion iteration
n_blocks = blocks.shape[0]
d, dm, qs = parent.quant_state_0, parent.quant_state_1, blocks
d, dm, qs = quick_split(blocks, [16, 16])
d = d.view(parameter.computation_dtype).view((n_blocks, -1, 1))
dm = dm.view(parameter.computation_dtype).view((n_blocks, -1, 1))
qs = quick_unpack_4bits_u(qs).view((n_blocks, -1, 32))
if d.device != qs.device:
d = d.to(device=qs.device)
if dm.device != qs.device:
dm = dm.to(device=qs.device)
qs = quick_unpack_4bits_u(qs).reshape((n_blocks, -1, 32))
return (d * qs - dm).reshape((n_blocks, QK_K))
@@ -867,14 +852,14 @@ class Q5_K(__Quant, qtype=GGMLQuantizationType.Q5_K):
return (d * q - dm).reshape((n_blocks, QK_K))
@classmethod
def dequantize_blocks_pytorch(cls, blocks, block_size, type_size, parent) -> torch.Tensor:
def dequantize_blocks_pytorch(cls, blocks, block_size, type_size, parameter) -> torch.Tensor:
# (c) City96 || Apache-2.0 (apache.org/licenses/LICENSE-2.0)
QK_K = 256
K_SCALE_SIZE = 12
n_blocks = blocks.shape[0]
d, dmin, scales, qh, qs = quick_split(blocks, [2, 2, K_SCALE_SIZE, QK_K // 8])
d = d.view(torch.float16).to(cls.computation_dtype)
dmin = dmin.view(torch.float16).to(cls.computation_dtype)
d = d.view(torch.float16).to(parameter.computation_dtype)
dmin = dmin.view(torch.float16).to(parameter.computation_dtype)
sc, m = Q4_K.get_scale_min_pytorch(scales)
d = (d * sc).reshape((n_blocks, -1, 1))
dm = (dmin * m).reshape((n_blocks, -1, 1))
@@ -909,12 +894,12 @@ class Q6_K(__Quant, qtype=GGMLQuantizationType.Q6_K):
return (d * q).reshape((n_blocks, QK_K))
@classmethod
def dequantize_blocks_pytorch(cls, blocks, block_size, type_size, parent) -> torch.Tensor:
def dequantize_blocks_pytorch(cls, blocks, block_size, type_size, parameter) -> torch.Tensor:
# Written by ChatGPT
n_blocks = blocks.shape[0]
ql, qh, scales, d, = quick_split(blocks, [QK_K // 2, QK_K // 4, QK_K // 16])
scales = scales.view(torch.int8).to(cls.computation_dtype)
d = d.view(torch.float16).to(cls.computation_dtype)
scales = scales.view(torch.int8).to(parameter.computation_dtype)
d = d.view(torch.float16).to(parameter.computation_dtype)
d = (d * scales).reshape((n_blocks, QK_K // 16, 1))
ql = ql.reshape((n_blocks, -1, 1, 64)) >> torch.tensor([0, 4], device=d.device, dtype=torch.uint8).reshape((1, 1, 2, 1))
ql = (ql & 0x0F).reshape((n_blocks, -1, 32))