diff --git a/backend/operations_gguf.py b/backend/operations_gguf.py index d2c43c8c..3d9cec66 100644 --- a/backend/operations_gguf.py +++ b/backend/operations_gguf.py @@ -70,6 +70,11 @@ class ParameterGGUF(torch.nn.Parameter): def bake_gguf_model(model): computation_dtype = model.computation_dtype + + if computation_dtype not in [torch.float16, torch.bfloat16]: + # Baking only supports 16bits otherwise super slow + computation_dtype = torch.float16 + backed_layer_counter = 0 for m in model.modules():