mirror of
https://github.com/ostris/ai-toolkit.git
synced 2026-01-26 16:39:47 +00:00
38 lines
1.2 KiB
Python
38 lines
1.2 KiB
Python
import json
|
|
from collections import OrderedDict
|
|
|
|
from safetensors import safe_open
|
|
|
|
from info import software_meta
|
|
|
|
|
|
def get_meta_for_safetensors(meta: OrderedDict, name=None) -> OrderedDict:
    """Build a metadata mapping suitable for a safetensors header.

    The input is copied through a JSON round trip. Every "[name]"
    placeholder is substituted with ``name`` (when given), the "software"
    entry is set to ``software_meta``, and every top-level value that is
    not already a ``str`` is JSON-encoded. This is needed because
    safetensors metadata must be a flat ``str -> str`` mapping.

    Args:
        meta: JSON-serializable mapping of metadata. It is not modified.
        name: Optional replacement for every "[name]" placeholder that
            appears anywhere in the serialized ``meta`` (keys or values,
            at any nesting depth). Non-string values are converted with
            ``str()``.

    Returns:
        A new OrderedDict, in the same key order, whose values are all strings.
    """
    # Stringify the meta so "[name]" can be replaced anywhere, however deeply
    # nested, then reparse it. The round trip also leaves ``meta`` untouched.
    meta_string = json.dumps(meta)
    if name is not None:
        # Escape ``name`` as JSON string content. Otherwise quotes, backslashes
        # or control characters in it would corrupt the serialized document.
        escaped_name = json.dumps(str(name))[1:-1]
        meta_string = meta_string.replace("[name]", escaped_name)
    save_meta = json.loads(meta_string, object_pairs_hook=OrderedDict)

    save_meta["software"] = software_meta

    # safetensors metadata can only be one level deep (str -> str). Every value
    # that is not already a string (int, float, bool, None, list, dict, ...) is
    # stored as its JSON encoding. A new mapping is built instead of
    # reassigning values while iterating over save_meta.
    return OrderedDict(
        (key, value if isinstance(value, str) else json.dumps(value))
        for key, value in save_meta.items()
    )
|
|
|
|
|
|
def parse_metadata_from_safetensors(meta: OrderedDict) -> OrderedDict:
    """Decode safetensors header metadata back into Python values.

    Each value is run through ``json.loads``. Values that are not valid JSON
    (plain strings such as a model name) are kept unchanged. Values that are
    not strings or bytes at all are also kept unchanged.

    Args:
        meta: Mapping of metadata keys to (usually JSON-encoded) strings.
            ``None`` is accepted and treated as empty, because
            ``safe_open(...).metadata()`` returns ``None`` for files saved
            without metadata.

    Returns:
        A new OrderedDict with the same keys, in the same order, holding
        the decoded values.
    """
    parsed_meta = OrderedDict()
    if meta is None:
        return parsed_meta
    for key, value in meta.items():
        try:
            parsed_meta[key] = json.loads(value)
        except (json.JSONDecodeError, TypeError):
            # JSONDecodeError: not JSON (e.g. a bare string).
            # TypeError: not a str/bytes value at all.
            # Either way, keep the raw value rather than failing the whole load.
            parsed_meta[key] = value
    return parsed_meta
|
|
|
|
|
|
def load_metadata_from_safetensors(file_path: str) -> OrderedDict:
    """Read and decode the metadata header of a safetensors file.

    Args:
        file_path: Path to the safetensors file.

    Returns:
        The decoded metadata (see ``parse_metadata_from_safetensors``) as an
        OrderedDict. It is empty when the file has no metadata section.
    """
    with safe_open(file_path, framework="pt") as f:
        # metadata() returns None when the file was saved without metadata;
        # fall back to an empty mapping so decoding does not crash.
        metadata = f.metadata() or {}
    return parse_metadata_from_safetensors(metadata)
|