mirror of
https://github.com/turboderp-org/exllamav2.git
synced 2026-04-20 14:29:28 +00:00
The default encoding on Linux is UTF-8, but Windows uses cp1252, which isn't compatible with some Unicode characters. Signed-off-by: kingbri <bdashore3@proton.me>
85 lines
2.2 KiB
Python
85 lines
2.2 KiB
Python
import argparse, json, math, os
|
|
from safetensors import safe_open
|
|
from safetensors.torch import save_file
|
|
|
|
# Command line: the file to split and the maximum size of each shard

parser = argparse.ArgumentParser(description = "Split .safetensors file into shards")
for _arg_name, _arg_type, _arg_help in (
    ("input_file", str, "Path to input file"),
    ("shard_size", int, "Shard size in megabytes"),
):
    parser.add_argument(_arg_name, type = _arg_type, help = _arg_help)
args = parser.parse_args()

# Output names are derived from the input path with its extension removed

input_file = args.input_file
input_base, _ = os.path.splitext(input_file)

# The shard size is given in megabytes on the command line; work in bytes

shard_size = args.shard_size * 1024 * 1024
|
|
|
|
# Create tensor map
|
|
|
|
# Storage size in bytes of one element, keyed by safetensors dtype code

_DTYPE_BYTES = {
    "BOOL": 1,
    "U8": 1,
    "I8": 1,
    "F8_E5M2": 1,
    "F8_E4M3": 1,
    "I16": 2,
    "U16": 2,
    "F16": 2,
    "BF16": 2,
    "I32": 4,
    "U32": 4,
    "F32": 4,
    "I64": 8,
    "U64": 8,
    "F64": 8,
}

def _tsize(st, key):
    """
    Return the size in bytes of the tensor stored under `key`.

    :param st:
        Open safetensors handle providing get_slice(key)

    :param key:
        Name of the tensor to measure

    :return:
        Number of elements multiplied by the byte width of the tensor's dtype

    :raises ValueError:
        If the tensor's dtype code is not in the size table
    """

    tslice = st.get_slice(key)
    shape = tslice.get_shape()
    dtype = tslice.get_dtype()
    del tslice

    # An empty shape (scalar) leaves numel at 1
    numel = 1
    for x in shape: numel *= x

    try:
        return numel * _DTYPE_BYTES[dtype]
    except KeyError:
        # Report the dtype itself as well as the tensor it belongs to
        raise ValueError(f"Unexpected datatype {dtype} for tensor: {key}") from None
|
|
|
|
# tensor_map holds one list of tensor names per output shard, in file order

tensor_map = []
total_size = 0

# Start "over budget" so that the first tensor always opens the first shard

current_size = shard_size + 1

print(f" -- Scanning tensors in {input_file}")

with safe_open(input_file, framework = "pt", device = "cpu") as f:

    for key in f.keys():

        tensor_size = _tsize(f, key)
        total_size += tensor_size

        # Open a new shard when this tensor would push the current one past the limit

        if current_size + tensor_size > shard_size:
            tensor_map.append([])
            current_size = 0

        tensor_map[-1].append(key)
        current_size += tensor_size

# Every shard opened above becomes one output file

num_files = len(tensor_map)
|
|
|
|
# Split into output files

# weight_map records, for every tensor name, the shard file that contains it

weight_map = {}

for file_index, keys in enumerate(tensor_map):

    shard = {}
    shard_filename = f"{input_base}-{file_index + 1:05}-of-{num_files:05}.safetensors"

    # The index is written alongside the shards, so it must refer to each shard by its bare
    # file name. Storing the full path would make the entries unresolvable relative to the
    # index whenever the input path includes a directory

    shard_basename = os.path.basename(shard_filename)

    with safe_open(input_file, framework = "pt", device = "cpu") as f:
        for key in keys:
            print(f" -- Reading: {key}")
            shard[key] = f.get_tensor(key)
            weight_map[key] = shard_basename

    print(f" -- Writing: {shard_filename}")
    save_file(shard, shard_filename)
|
|
|
|
# Compile index

# The index sits next to the input file and lists the total tensor size and the shard
# holding each tensor

index_filename = f"{input_file}.index.json"
index = {
    "metadata": { "total_size": total_size },
    "weight_map": weight_map,
}

print(f" -- Writing: {index_filename}")

# Explicit encoding so the output does not depend on the platform's default codec

with open(index_filename, 'w', encoding = "utf8") as f:
    f.write(json.dumps(index, indent = 2))

# Done

print(" -- Done")
|