rewrite some functions

layerdiffusion
2024-08-15 00:29:19 -07:00
parent c74f603ea2
commit fb62214a32


@@ -9,27 +9,21 @@ def functional_linear_gguf(x, weight, bias=None):
     return torch.nn.functional.linear(x, weight, bias)
 
 
-def dequantize_tensor(tensor, dtype=torch.float16):
+# (c) City96 || Apache-2.0 (apache.org/licenses/LICENSE-2.0)
+def dequantize_tensor(tensor, target_dtype=torch.float16):
     if tensor is None:
         return None
 
     data = torch.tensor(tensor.data)
-    qtype = tensor.gguf_type
-    oshape = tensor.gguf_real_shape
+    gguf_type = tensor.gguf_type
+    gguf_real_shape = tensor.gguf_real_shape
 
-    if qtype == gguf.GGMLQuantizationType.F32:
-        return data.to(dtype)
-    elif qtype == gguf.GGMLQuantizationType.F16:
-        return data.to(dtype)
-    elif qtype in dequantize_functions:
-        # this is the main pytorch op
-        return dequantize(data, qtype, oshape).to(dtype)
-    else:
-        # this is incredibly slow
-        new = gguf.quants.dequantize(data.cpu().numpy(), qtype)
-        return torch.from_numpy(new).to(data.device, dtype=dtype)
+    if gguf_type in [gguf.GGMLQuantizationType.F32, gguf.GGMLQuantizationType.F16, gguf.GGMLQuantizationType.BF16]:
+        return data.to(target_dtype)
+
+    if gguf_type not in dequantize_functions:
+        raise NotImplementedError(f'Quant type {gguf_type} not implemented!')
+
+    return dequantize(data, gguf_type, gguf_real_shape).to(target_dtype)
 
 
 def dequantize(data, qtype, oshape):
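
For reference, below is a minimal, self-contained sketch of the dispatch pattern the rewritten function now relies on: dequantize_tensor looks up the quant type in dequantize_functions and hands the raw bytes to dequantize(). The real table and dequantize() body live elsewhere in this file and are not shown in the hunk, so the Q8_0 kernel and the table contents here are illustrative stand-ins, not the repository's code. GGML_QUANT_SIZES is part of the gguf Python package and maps each quant type to its (elements per block, bytes per block) pair, e.g. (32, 34) for Q8_0.

    import torch
    import gguf

    def dequantize_blocks_q8_0(blocks):
        # GGUF Q8_0 layout: each 34-byte block is one float16 scale
        # followed by 32 int8 quantized values.
        scale = blocks[:, :2].view(torch.float16)   # shape (n, 1)
        quants = blocks[:, 2:].view(torch.int8)     # shape (n, 32)
        return scale * quants                       # broadcasts to (n, 32) float16

    # Hypothetical stand-in for the table the diff dispatches through.
    dequantize_functions = {
        gguf.GGMLQuantizationType.Q8_0: dequantize_blocks_q8_0,
    }

    def dequantize(data, qtype, oshape):
        # Reinterpret the raw byte tensor as (num_blocks, bytes_per_block),
        # run the per-type kernel, then restore the real pre-quantization shape.
        _, type_size = gguf.GGML_QUANT_SIZES[qtype]
        blocks = data.view(torch.uint8).reshape(-1, type_size)
        return dequantize_functions[qtype](blocks).reshape(oshape)

    # Usage with made-up bytes: 8 Q8_0 blocks (8 * 34 bytes) hold 256 values,
    # dequantized back to a (4, 64) tensor; the caller applies the final
    # .to(target_dtype) cast, as dequantize_tensor does in the diff above.
    raw = torch.randint(0, 256, (8 * 34,), dtype=torch.uint8)
    weight = dequantize(raw, gguf.GGMLQuantizationType.Q8_0, (4, 64)).to(torch.float16)

Keeping the dtype cast in the caller matches the call shape in the hunk, where dequantize(data, gguf_type, gguf_real_shape) returns the dequantized tensor and dequantize_tensor applies .to(target_dtype) once at the end.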