From fd0d25ba8aa4ef32d77571e93d438e5a6a949a9a Mon Sep 17 00:00:00 2001 From: layerdiffusion <19834515+lllyasviel@users.noreply.github.com> Date: Thu, 15 Aug 2024 03:08:25 -0700 Subject: [PATCH] fix type hints --- packages_3rdparty/gguf/quants.py | 11 ++++++++++- 1 file changed, 10 insertions(+), 1 deletion(-) diff --git a/packages_3rdparty/gguf/quants.py b/packages_3rdparty/gguf/quants.py index 51283e57..08ce9437 100644 --- a/packages_3rdparty/gguf/quants.py +++ b/packages_3rdparty/gguf/quants.py @@ -125,7 +125,11 @@ class __Quant(ABC): cls.grid = grid.reshape((1, 1, *cls.grid_shape)) @classmethod - def dequantize_pytorch(cls, data: torch.Tensor, original_shape=torch.float16) -> torch.Tensor: + def quantize_pytorch(cls, data: torch.Tensor) -> torch.Tensor: + return cls.quantize_blocks_pytorch(data) + + @classmethod + def dequantize_pytorch(cls, data: torch.Tensor, original_shape) -> torch.Tensor: # (c) City96 || Apache-2.0 (apache.org/licenses/LICENSE-2.0) block_size, type_size = GGML_QUANT_SIZES[cls.qtype] rows = data.reshape((-1, data.shape[-1])).view(torch.uint8) @@ -139,6 +143,11 @@ class __Quant(ABC): def dequantize_blocks_pytorch(cls, blocks, block_size, type_size) -> torch.Tensor: raise NotImplementedError + @classmethod + @abstractmethod + def quantize_blocks_pytorch(cls, blocks) -> torch.Tensor: + raise NotImplementedError + @classmethod @abstractmethod def quantize_blocks(cls, blocks: np.ndarray) -> np.ndarray: