From acf99dd74eed27222da74b96466a526c9db640c8 Mon Sep 17 00:00:00 2001
From: layerdiffusion <19834515+lllyasviel@users.noreply.github.com>
Date: Mon, 26 Aug 2024 06:51:48 -0700
Subject: [PATCH] fix old version of pytorch

---
 packages_3rdparty/gguf/quick_4bits_ops.py | 21 ++++++++++++++++++---
 1 file changed, 18 insertions(+), 3 deletions(-)

diff --git a/packages_3rdparty/gguf/quick_4bits_ops.py b/packages_3rdparty/gguf/quick_4bits_ops.py
index 97404bbc..89711964 100644
--- a/packages_3rdparty/gguf/quick_4bits_ops.py
+++ b/packages_3rdparty/gguf/quick_4bits_ops.py
@@ -8,7 +8,7 @@ def native_unpack_4x4bits_in_1x16bits_to_4x8bits_in_1x32bits(x):
     x = x.view(torch.uint8).view(x.size(0), -1)
     unpacked = torch.stack([x & 15, x >> 4], dim=-1)
     reshaped = unpacked.view(x.size(0), -1)
-    reshaped = reshaped.to(torch.int8) - 8
+    reshaped = reshaped.view(torch.int8) - 8
     return reshaped.view(torch.int32)
 
 
@@ -19,11 +19,23 @@ def native_unpack_4x4bits_in_1x16bits_to_4x8bits_in_1x32bits_u(x):
     return reshaped.view(torch.int32)
 
 
-native_4bits_lookup_table = native_unpack_4x4bits_in_1x16bits_to_4x8bits_in_1x32bits(torch.arange(start=0, end=256*256, dtype=torch.long).to(torch.uint16))[:, 0]
-native_4bits_lookup_table_u = native_unpack_4x4bits_in_1x16bits_to_4x8bits_in_1x32bits_u(torch.arange(start=0, end=256*256, dtype=torch.long).to(torch.uint16))[:, 0]
+disable_all_optimizations = False
+
+if not hasattr(torch, 'uint16'):
+    disable_all_optimizations = True
+
+if disable_all_optimizations:
+    print('You are using PyTorch below version 2.3. Some optimizations will be disabled.')
+
+if not disable_all_optimizations:
+    native_4bits_lookup_table = native_unpack_4x4bits_in_1x16bits_to_4x8bits_in_1x32bits(torch.arange(start=0, end=256*256, dtype=torch.long).to(torch.uint16))[:, 0]
+    native_4bits_lookup_table_u = native_unpack_4x4bits_in_1x16bits_to_4x8bits_in_1x32bits_u(torch.arange(start=0, end=256*256, dtype=torch.long).to(torch.uint16))[:, 0]
 
 
 def quick_unpack_4bits(x):
+    if disable_all_optimizations:
+        return torch.stack([x & 15, x >> 4], dim=-1).view(x.size(0), -1).view(torch.int8) - 8
+
     global native_4bits_lookup_table
 
     s0 = x.size(0)
@@ -40,6 +52,9 @@ def quick_unpack_4bits(x):
 
 
 def quick_unpack_4bits_u(x):
+    if disable_all_optimizations:
+        return torch.stack([x & 15, x >> 4], dim=-1).view(x.size(0), -1)
+
     global native_4bits_lookup_table_u
 
     s0 = x.size(0)