77 lines
2.4 KiB
Python
Executable File
77 lines
2.4 KiB
Python
Executable File
# By Forge
|
|
|
|
|
|
import torch
|
|
|
|
|
|
def native_unpack_4x4bits_in_1x16bits_to_4x8bits_in_1x32bits(x):
    """Expand the nibbles of ``x`` into signed bytes, packed four per int32.

    The storage of ``x`` is reinterpreted byte-wise as uint8 and grouped per
    row (``x.size(0)`` rows). Each byte is split into its low nibble followed
    by its high nibble. Each nibble (0..15) is shifted by -8 into the int8
    range -8..7. The int8 results are then reinterpreted in groups of four as
    int32, so one 16-bit input element yields one int32 whose four bytes are
    its four decoded nibbles.

    The module uses this at import time, applied to every 16-bit value, to
    build ``native_4bits_lookup_table``.

    Args:
        x: tensor whose bytes hold packed 4-bit values. Presumably uint16
            (as at the import-time call site); any dtype works as long as
            each row holds a multiple of 2 bytes — assumption, confirm
            for other callers.

    Returns:
        int32 tensor of shape ``(x.size(0), -1)``.
    """
    rows = x.size(0)
    as_bytes = x.view(torch.uint8).view(rows, -1)

    low = as_bytes & 15
    high = as_bytes >> 4
    # Interleave per byte as [low, high], then flatten each row.
    nibbles = torch.stack([low, high], dim=-1).view(rows, -1)

    # Reinterpreting 0..15 as int8 keeps the values; subtracting 8 centres them on 0.
    centred = nibbles.view(torch.int8) - 8
    return centred.view(torch.int32)
|
|
|
|
|
|
def native_unpack_4x4bits_in_1x16bits_to_4x8bits_in_1x32bits_u(x):
    """Expand the nibbles of ``x`` into unsigned bytes, packed four per int32.

    This is the unsigned counterpart of
    ``native_unpack_4x4bits_in_1x16bits_to_4x8bits_in_1x32bits``. The bytes of
    ``x`` are split into [low nibble, high nibble] pairs and the values are
    kept in 0..15 (no recentring). The uint8 results are reinterpreted in
    groups of four as int32.

    The module uses this at import time to build
    ``native_4bits_lookup_table_u``.

    Args:
        x: tensor whose bytes hold packed 4-bit values. Presumably uint16
            (as at the import-time call site) — assumption, confirm for
            other callers.

    Returns:
        int32 tensor of shape ``(x.size(0), -1)``.
    """
    rows = x.size(0)
    as_bytes = x.view(torch.uint8).view(rows, -1)
    # Each byte becomes [low nibble, high nibble], both in 0..15.
    nibbles = torch.stack([as_bytes & 15, as_bytes >> 4], dim=-1)
    return nibbles.view(rows, -1).view(torch.int32)
|
|
|
|
|
|
# Lookup-table decoding reinterprets packed bytes as torch.uint16. When that
# dtype is missing (the message below ties it to PyTorch 2.3), the
# quick_unpack_* functions take their bitwise fallback path instead.
disable_all_optimizations = not hasattr(torch, 'uint16')

if disable_all_optimizations:
    print('You are using PyTorch below version 2.3. Some optimizations will be disabled.')
else:
    # Every possible 16-bit word (0..65535). Entry i of each table decodes
    # word i. The tensor is built once and shared by both tables; neither
    # unpack function modifies its input in place.
    _all_uint16_words = torch.arange(start=0, end=256 * 256, dtype=torch.long).to(torch.uint16)

    # int32 entries whose four int8 bytes are the word's nibbles minus 8 (-8..7).
    native_4bits_lookup_table = native_unpack_4x4bits_in_1x16bits_to_4x8bits_in_1x32bits(_all_uint16_words)[:, 0]

    # Same layout, but the nibbles stay unsigned (0..15).
    native_4bits_lookup_table_u = native_unpack_4x4bits_in_1x16bits_to_4x8bits_in_1x32bits_u(_all_uint16_words)[:, 0]

    # Temporary only; do not leave it in the module namespace.
    del _all_uint16_words
|
|
|
|
|
|
def quick_unpack_4bits(x):
    """Decode 4-bit values packed two per byte into signed int8 values (-8..7).

    Every input byte produces two outputs, low nibble first and then high
    nibble, each shifted by -8. The result is reshaped to ``(x.size(0), -1)``.

    Fast path (``torch.uint16`` available): pairs of bytes are reinterpreted as
    uint16 and used as indices into ``native_4bits_lookup_table``. Each int32
    entry of that table holds the four decoded int8 values for one uint16. If
    the table lives on a different device than ``x``, it is moved there and
    the module-level table is rebound to the moved copy.

    Fallback (``disable_all_optimizations``): the nibbles are split with
    bitwise ops.

    Assumes ``x`` is uint8 and, for the fast path, that its rows can be viewed
    as uint16 (even, contiguous last dimension) — confirm at call sites.
    """
    global native_4bits_lookup_table

    if disable_all_optimizations:
        # Split each byte into [low, high] nibbles and recentre 0..15 to -8..7.
        nibble_pairs = torch.stack([x & 15, x >> 4], dim=-1)
        return nibble_pairs.view(x.size(0), -1).view(torch.int8) - 8

    rows = x.size(0)
    words = x.view(torch.uint16)

    # index_select needs the table and the indices on the same device.
    if native_4bits_lookup_table.device != words.device:
        native_4bits_lookup_table = native_4bits_lookup_table.to(device=words.device)

    flat_index = words.to(dtype=torch.int32).flatten()
    decoded = torch.index_select(input=native_4bits_lookup_table, dim=0, index=flat_index)
    # Each looked-up int32 is reinterpreted as its four int8 nibble values.
    return decoded.view(torch.int8).view(rows, -1)
|
|
|
|
|
|
def quick_unpack_4bits_u(x):
    """Decode 4-bit values packed two per byte into unsigned uint8 values (0..15).

    This is the unsigned counterpart of ``quick_unpack_4bits``. Every input
    byte produces two outputs, low nibble first and then high nibble, with no
    recentring. The result is reshaped to ``(x.size(0), -1)``.

    Fast path (``torch.uint16`` available): pairs of bytes are reinterpreted as
    uint16 and used as indices into ``native_4bits_lookup_table_u``. If the
    table lives on a different device than ``x``, it is moved there and the
    module-level table is rebound to the moved copy.

    Fallback (``disable_all_optimizations``): the nibbles are split with
    bitwise ops.

    Assumes ``x`` is uint8 and, for the fast path, that its rows can be viewed
    as uint16 (even, contiguous last dimension) — confirm at call sites.
    """
    global native_4bits_lookup_table_u

    if disable_all_optimizations:
        # Split each byte into [low, high] nibbles; values stay in 0..15.
        nibble_pairs = torch.stack([x & 15, x >> 4], dim=-1)
        return nibble_pairs.view(x.size(0), -1)

    rows = x.size(0)
    words = x.view(torch.uint16)

    # index_select needs the table and the indices on the same device.
    if native_4bits_lookup_table_u.device != words.device:
        native_4bits_lookup_table_u = native_4bits_lookup_table_u.to(device=words.device)

    flat_index = words.to(dtype=torch.int32).flatten()
    decoded = torch.index_select(input=native_4bits_lookup_table_u, dim=0, index=flat_index)
    # Each looked-up int32 is reinterpreted as its four uint8 nibble values.
    return decoded.view(torch.uint8).view(rows, -1)
|
|
|
|
|
|
def change_4bits_order(x):
    """Repack the 4-bit values of each row so consecutive values share a byte.

    For each row, the low nibbles of all bytes are listed first, followed by
    the high nibbles of all bytes. That sequence is then packed two at a
    time: element ``2k`` goes into the low nibble and element ``2k + 1`` into
    the high nibble of output byte ``k``. For a 2-D ``x`` the output has the
    same shape as the input.

    Presumably this converts a layout where byte ``j`` holds values ``j``
    (low) and ``j + width`` (high) into the interleaved low-then-high order
    that ``quick_unpack_4bits`` decodes — confirm against the producer of
    ``x``. Assumes ``x`` is uint8; a signed dtype would sign-extend in
    ``x >> 4``.
    """
    rows = x.size(0)
    # Shape (rows, 2, width) -> (rows, 2 * width): all low nibbles, then all high.
    nibbles = torch.stack([x & 15, x >> 4], dim=-2).view(rows, -1)
    evens = nibbles[:, 0::2]
    odds = nibbles[:, 1::2]
    return evens | (odds << 4)
|