From fb62214a326cc3c7050ff00569638f0ca1eac412 Mon Sep 17 00:00:00 2001
From: layerdiffusion <19834515+lllyasviel@users.noreply.github.com>
Date: Thu, 15 Aug 2024 00:29:19 -0700
Subject: [PATCH] rewrite some functions

---
 backend/operations_gguf.py | 26 ++++++++++----------------
 1 file changed, 10 insertions(+), 16 deletions(-)

diff --git a/backend/operations_gguf.py b/backend/operations_gguf.py
index 7a01ee86..9461238c 100644
--- a/backend/operations_gguf.py
+++ b/backend/operations_gguf.py
@@ -9,27 +9,21 @@ def functional_linear_gguf(x, weight, bias=None):
     return torch.nn.functional.linear(x, weight, bias)
 
 
-def dequantize_tensor(tensor, dtype=torch.float16):
-    # (c) City96 || Apache-2.0 (apache.org/licenses/LICENSE-2.0)
-
+def dequantize_tensor(tensor, target_dtype=torch.float16):
     if tensor is None:
         return None
 
     data = torch.tensor(tensor.data)
-    qtype = tensor.gguf_type
-    oshape = tensor.gguf_real_shape
+    gguf_type = tensor.gguf_type
+    gguf_real_shape = tensor.gguf_real_shape
 
-    if qtype == gguf.GGMLQuantizationType.F32:
-        return data.to(dtype)
-    elif qtype == gguf.GGMLQuantizationType.F16:
-        return data.to(dtype)
-    elif qtype in dequantize_functions:
-        # this is the main pytorch op
-        return dequantize(data, qtype, oshape).to(dtype)
-    else:
-        # this is incredibly slow
-        new = gguf.quants.dequantize(data.cpu().numpy(), qtype)
-        return torch.from_numpy(new).to(data.device, dtype=dtype)
+    if gguf_type in [gguf.GGMLQuantizationType.F32, gguf.GGMLQuantizationType.F16, gguf.GGMLQuantizationType.BF16]:
+        return data.to(target_dtype)
+
+    if gguf_type not in dequantize_functions:
+        raise NotImplementedError(f'Quant type {gguf_type} not implemented!')
+
+    return dequantize(data, gguf_type, gguf_real_shape).to(target_dtype)
 
 
 def dequantize(data, qtype, oshape):
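
Reviewer sketch (not part of the patch): after this change, dequantize_tensor casts F32/F16/BF16 tensors straight to the target dtype, raises NotImplementedError for quant types missing from dequantize_functions, and no longer falls back to the slow numpy path through gguf.quants.dequantize. The following is a minimal, hypothetical usage sketch; FakeGGUFTensor is an assumed stand-in for whatever class in this codebase attaches .data, .gguf_type, and .gguf_real_shape to a GGUF weight, and no assumption is made about which quant types dequantize_functions actually covers.

import numpy as np
import torch
import gguf

# assumed import path for the patched module:
from backend.operations_gguf import dequantize_tensor, dequantize_functions


class FakeGGUFTensor:
    # Hypothetical stand-in: carries raw data plus the two GGUF attributes
    # that dequantize_tensor reads.
    def __init__(self, data, gguf_type, gguf_real_shape):
        self.data = data
        self.gguf_type = gguf_type
        self.gguf_real_shape = gguf_real_shape


# F16 (like F32/BF16) takes the fast path: a plain dtype cast, no dequantize kernel.
raw = np.random.randn(4, 8).astype(np.float16)
t = FakeGGUFTensor(raw, gguf.GGMLQuantizationType.F16, (4, 8))
out = dequantize_tensor(t, target_dtype=torch.bfloat16)
assert out.dtype == torch.bfloat16 and out.shape == (4, 8)

# Any quant type absent from this module's dequantize_functions table now
# fails loudly instead of silently taking the removed numpy fallback.
unsupported = next(qt for qt in gguf.GGMLQuantizationType
                   if qt not in dequantize_functions
                   and qt not in (gguf.GGMLQuantizationType.F32,
                                  gguf.GGMLQuantizationType.F16,
                                  gguf.GGMLQuantizationType.BF16))
try:
    dequantize_tensor(FakeGGUFTensor(raw, unsupported, (4, 8)))
except NotImplementedError as exc:
    print(exc)  # e.g. "Quant type ... not implemented!"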