diff --git a/packages_3rdparty/gguf/quants.py b/packages_3rdparty/gguf/quants.py
index 51283e57..08ce9437 100644
--- a/packages_3rdparty/gguf/quants.py
+++ b/packages_3rdparty/gguf/quants.py
@@ -125,7 +125,11 @@ class __Quant(ABC):
         cls.grid = grid.reshape((1, 1, *cls.grid_shape))
 
     @classmethod
-    def dequantize_pytorch(cls, data: torch.Tensor, original_shape=torch.float16) -> torch.Tensor:
+    def quantize_pytorch(cls, data: torch.Tensor) -> torch.Tensor:
+        return cls.quantize_blocks_pytorch(data)
+
+    @classmethod
+    def dequantize_pytorch(cls, data: torch.Tensor, original_shape) -> torch.Tensor:
         # (c) City96 || Apache-2.0 (apache.org/licenses/LICENSE-2.0)
         block_size, type_size = GGML_QUANT_SIZES[cls.qtype]
         rows = data.reshape((-1, data.shape[-1])).view(torch.uint8)
@@ -139,6 +143,11 @@ class __Quant(ABC):
     def dequantize_blocks_pytorch(cls, blocks, block_size, type_size) -> torch.Tensor:
         raise NotImplementedError
 
+    @classmethod
+    @abstractmethod
+    def quantize_blocks_pytorch(cls, blocks) -> torch.Tensor:
+        raise NotImplementedError
+
     @classmethod
     @abstractmethod
     def quantize_blocks(cls, blocks: np.ndarray) -> np.ndarray: