gguf preview

This commit is contained in:
layerdiffusion
2024-08-14 22:26:00 -07:00
parent 59790f2cb4
commit d8b83a9501
8 changed files with 190 additions and 7 deletions

View File

@@ -1,3 +1,4 @@
import gguf
import torch
import os
import json
@@ -5,6 +6,31 @@ import safetensors.torch
import backend.misc.checkpoint_pickle
class ParameterGGUF(torch.nn.Parameter):
def __init__(self, tensor=None, requires_grad=False, no_init=False):
super().__init__()
self.is_gguf = True
if no_init:
return
self.gguf_type = tensor.tensor_type
self.gguf_real_shape = torch.Size(reversed(list(tensor.shape)))
@property
def shape(self):
return self.gguf_real_shape
def __new__(cls, tensor=None, requires_grad=False, no_init=False):
return super().__new__(cls, torch.tensor(tensor.data), requires_grad=requires_grad)
def to(self, *args, **kwargs):
new = ParameterGGUF(self.data.to(*args, **kwargs), no_init=True)
new.gguf_type = self.gguf_type
new.gguf_real_shape = self.gguf_real_shape
return new
def read_arbitrary_config(directory):
config_path = os.path.join(directory, 'config.json')
@@ -22,6 +48,11 @@ def load_torch_file(ckpt, safe_load=False, device=None):
device = torch.device("cpu")
if ckpt.lower().endswith(".safetensors"):
sd = safetensors.torch.load_file(ckpt, device=device.type)
elif ckpt.lower().endswith(".gguf"):
reader = gguf.GGUFReader(ckpt)
sd = {}
for tensor in reader.tensors:
sd[str(tensor.name)] = ParameterGGUF(tensor)
else:
if safe_load:
if not 'weights_only' in torch.load.__code__.co_varnames: