mirror of https://github.com/lllyasviel/stable-diffusion-webui-forge.git
rewrite some functions
@@ -9,27 +9,21 @@ def functional_linear_gguf(x, weight, bias=None):
     return torch.nn.functional.linear(x, weight, bias)
 
 
-def dequantize_tensor(tensor, dtype=torch.float16):
-    # (c) City96 || Apache-2.0 (apache.org/licenses/LICENSE-2.0)
-
+def dequantize_tensor(tensor, target_dtype=torch.float16):
     if tensor is None:
         return None
 
     data = torch.tensor(tensor.data)
-    qtype = tensor.gguf_type
-    oshape = tensor.gguf_real_shape
+    gguf_type = tensor.gguf_type
+    gguf_real_shape = tensor.gguf_real_shape
 
-    if qtype == gguf.GGMLQuantizationType.F32:
-        return data.to(dtype)
-    elif qtype == gguf.GGMLQuantizationType.F16:
-        return data.to(dtype)
-    elif qtype in dequantize_functions:
-        # this is the main pytorch op
-        return dequantize(data, qtype, oshape).to(dtype)
-    else:
-        # this is incredibly slow
-        new = gguf.quants.dequantize(data.cpu().numpy(), qtype)
-        return torch.from_numpy(new).to(data.device, dtype=dtype)
+    if gguf_type in [gguf.GGMLQuantizationType.F32, gguf.GGMLQuantizationType.F16, gguf.GGMLQuantizationType.BF16]:
+        return data.to(target_dtype)
+
+    if gguf_type not in dequantize_functions:
+        raise NotImplementedError(f'Quant type {gguf_type} not implemented!')
+
+    return dequantize(data, gguf_type, gguf_real_shape).to(target_dtype)
 
 
 def dequantize(data, qtype, oshape):
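In short: the rewrite renames dtype, qtype, and oshape to target_dtype, gguf_type, and gguf_real_shape, folds the separate F32/F16 branches into one membership test that now also covers BF16, and drops the "incredibly slow" numpy fallback (gguf.quants.dequantize on CPU) in favor of failing fast with NotImplementedError for quant types missing from dequantize_functions.

Below is a minimal sketch of how the rewritten function might be exercised, assuming dequantize_tensor is imported from the patched module (its path is not shown in this hunk). The GGUFTensorStub wrapper is hypothetical, not part of the commit; it only carries the three attributes the function reads: .data, .gguf_type, and .gguf_real_shape.

import gguf
import torch

# dequantize_tensor is the function patched above; import it from the
# module this hunk modifies (module path not shown in the diff).

# Hypothetical stand-in for the GGUF parameter wrapper the diff assumes:
# a plain container exposing the three attributes dequantize_tensor() reads.
class GGUFTensorStub:
    def __init__(self, data, gguf_type, gguf_real_shape):
        self.data = data                        # raw payload (a tensor here)
        self.gguf_type = gguf_type              # gguf.GGMLQuantizationType member
        self.gguf_real_shape = gguf_real_shape  # logical shape after dequant

# An unquantized F32 payload takes the new fast path: no lookup in
# dequantize_functions, just a cast to the requested target_dtype.
stub = GGUFTensorStub(
    data=torch.randn(4, 8, dtype=torch.float32),
    gguf_type=gguf.GGMLQuantizationType.F32,
    gguf_real_shape=torch.Size([4, 8]),
)
weight = dequantize_tensor(stub, target_dtype=torch.float16)
assert weight.dtype == torch.float16

# A quant type absent from dequantize_functions now raises
# NotImplementedError instead of silently hitting the old numpy fallback.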