From 008a0bb7774d9eef428d14f58a74c1eb0b73bd10 Mon Sep 17 00:00:00 2001 From: Colin Date: Tue, 20 Feb 2024 19:41:57 -0500 Subject: [PATCH] Fix converting files with docker command --- util/convert_safetensors.py | 16 ++++++++++------ 1 file changed, 10 insertions(+), 6 deletions(-) diff --git a/util/convert_safetensors.py b/util/convert_safetensors.py index 6bb2fb5..e249f2c 100644 --- a/util/convert_safetensors.py +++ b/util/convert_safetensors.py @@ -1,15 +1,19 @@ import torch -import argparse, os +import argparse, os, glob from safetensors.torch import save_file parser = argparse.ArgumentParser(description="Convert .bin/.pt files to .safetensors") -parser.add_argument("--unshare", action = "store_true", help="Detach tensors to prevent any from sharing memory") -parser.add_argument("input_files", nargs='+', type=str, help="Input file(s)") +parser.add_argument("--unshare", action="store_true", help="Detach tensors to prevent any from sharing memory") +parser.add_argument("input_files", nargs="+", type=str, help="Input file(s)") args = parser.parse_args() -for file in args.input_files: +tensor_files = [] +for file_pattern in args.input_files: + tensor_files.extend(glob.glob(file_pattern)) + +for file in tensor_files: print(f" -- Loading {file}...") - state_dict = torch.load(file, map_location = "cpu") + state_dict = torch.load(file, map_location="cpu") if args.unshare: for k in state_dict.keys(): @@ -17,4 +21,4 @@ for file in args.input_files: out_file = os.path.splitext(file)[0] + ".safetensors" print(f" -- Saving {out_file}...") - save_file(state_dict, out_file, metadata = {'format': 'pt'}) \ No newline at end of file + save_file(state_dict, out_file, metadata={"format": "pt"})