restrict baking to 16bits

layerdiffusion
2024-08-26 06:16:13 -07:00
parent 7cd94babdd
commit f22b80ef94

@@ -70,6 +70,11 @@ class ParameterGGUF(torch.nn.Parameter):
def bake_gguf_model(model):
    computation_dtype = model.computation_dtype

    if computation_dtype not in [torch.float16, torch.bfloat16]:
        # Baking only supports 16-bit dtypes; anything else is super slow
        computation_dtype = torch.float16

    backed_layer_counter = 0

    for m in model.modules():
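
For context, a minimal standalone sketch of the dtype clamp this hunk adds. The helper name resolve_bake_dtype is hypothetical (in the real code the check is inline in bake_gguf_model); the assumption, per the inline comment, is that baking in anything other than float16/bfloat16 is prohibitively slow, so the code falls back to float16 rather than raising an error.

    import torch

    def resolve_bake_dtype(computation_dtype: torch.dtype) -> torch.dtype:
        # Hypothetical standalone version of the inline check above:
        # clamp any non-16-bit computation dtype (e.g. float32) to
        # float16, since baking is only supported in 16-bit dtypes.
        if computation_dtype not in (torch.float16, torch.bfloat16):
            return torch.float16
        return computation_dtype

    # torch dtypes are singletons, so identity checks are safe here.
    assert resolve_bake_dtype(torch.float32) is torch.float16    # clamped
    assert resolve_bake_dtype(torch.bfloat16) is torch.bfloat16  # preserved
    assert resolve_bake_dtype(torch.float16) is torch.float16    # preserved

Falling back silently to float16 instead of raising keeps baking usable for models configured with a float32 computation dtype, at the cost of possible precision loss in the baked weights.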