# Copyright (c) Facebook, Inc. and its affiliates.
#
# This source code is licensed under the MIT license found in the
# LICENSE file in the root directory of this source tree.

"""
Utilities for working with the local dataset cache.
This file is adapted from `AllenNLP <https://github.com/allenai/allennlp>`_
and `huggingface <https://github.com/huggingface>`_.
"""

import fnmatch
import json
import logging
import os
import shutil
import tarfile
import tempfile
from functools import partial, wraps
from hashlib import sha256
from io import open

try:
    from torch.hub import _get_torch_home

    torch_cache_home = _get_torch_home()
except ImportError:
    torch_cache_home = os.path.expanduser(
        os.getenv(
            "TORCH_HOME", os.path.join(os.getenv("XDG_CACHE_HOME", "~/.cache"), "torch")
        )
    )
default_cache_path = os.path.join(torch_cache_home, "pytorch_fairseq")

try:
    from urllib.parse import urlparse
except ImportError:
    from urlparse import urlparse

try:
    from pathlib import Path

    PYTORCH_FAIRSEQ_CACHE = Path(os.getenv("PYTORCH_FAIRSEQ_CACHE", default_cache_path))
except (AttributeError, ImportError):
    PYTORCH_FAIRSEQ_CACHE = os.getenv("PYTORCH_FAIRSEQ_CACHE", default_cache_path)

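# The cache root can be overridden with the PYTORCH_FAIRSEQ_CACHE environment
# variable; otherwise it defaults to <torch cache home>/pytorch_fairseq.
# Hypothetical example (the path and script name are placeholders):
#
#     PYTORCH_FAIRSEQ_CACHE=/data/fairseq_cache python train.py
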
CONFIG_NAME = "config.json"
WEIGHTS_NAME = "pytorch_model.bin"

logger = logging.getLogger(__name__)  # pylint: disable=invalid-name

def load_archive_file(archive_file):
    # redirect to the cache, if necessary
    try:
        resolved_archive_file = cached_path(archive_file, cache_dir=None)
    except EnvironmentError:
        logger.info(
            "Archive name '{}' was not found in archive name list. "
            "We assumed '{}' was a path or URL but couldn't find any file "
            "associated to this path or URL.".format(
                archive_file,
                archive_file,
            )
        )
        return None

    if resolved_archive_file == archive_file:
        logger.info("loading archive file {}".format(archive_file))
    else:
        logger.info(
            "loading archive file {} from cache at {}".format(
                archive_file, resolved_archive_file
            )
        )

    # Extract archive to temp dir and replace .tar.bz2 if necessary
    tempdir = None
    if not os.path.isdir(resolved_archive_file):
        tempdir = tempfile.mkdtemp()
        logger.info(
            "extracting archive file {} to temp dir {}".format(
                resolved_archive_file, tempdir
            )
        )
        ext = os.path.splitext(archive_file)[1][1:]
        with tarfile.open(resolved_archive_file, "r:" + ext) as archive:
            top_dir = os.path.commonprefix(archive.getnames())
            archive.extractall(tempdir)
        os.remove(resolved_archive_file)
        shutil.move(os.path.join(tempdir, top_dir), resolved_archive_file)
        shutil.rmtree(tempdir)

    return resolved_archive_file

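# Illustrative usage (a sketch; the URL below is hypothetical):
#
#     >>> model_dir = load_archive_file("https://example.com/checkpoints/model.tar.bz2")
#
# The archive is resolved through cached_path() (downloading if needed), extracted
# next to the cached file, and the returned path points at the extracted directory.
# None is returned if the archive could not be resolved at all.
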
def url_to_filename(url, etag=None):
    """
    Convert `url` into a hashed filename in a repeatable way.
    If `etag` is specified, append its hash to the URL's, delimited
    by a period.
    """
    url_bytes = url.encode("utf-8")
    url_hash = sha256(url_bytes)
    filename = url_hash.hexdigest()

    if etag:
        etag_bytes = etag.encode("utf-8")
        etag_hash = sha256(etag_bytes)
        filename += "." + etag_hash.hexdigest()

    return filename

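# Illustrative behaviour (sketch; the URL and ETag values are hypothetical):
#
#     >>> url_to_filename("https://example.com/dict.txt")
#     '<sha256 hex digest of the URL>'
#     >>> url_to_filename("https://example.com/dict.txt", etag='"abc123"')
#     '<sha256 of URL>.<sha256 of ETag>'
#
# Because sha256 is deterministic, the same URL/ETag pair always maps to the
# same cache filename, so an unchanged remote file reuses its cache entry.
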
def filename_to_url(filename, cache_dir=None):
    """
    Return the url and etag (which may be ``None``) stored for `filename`.
    Raise ``EnvironmentError`` if `filename` or its stored metadata do not exist.
    """
    if cache_dir is None:
        cache_dir = PYTORCH_FAIRSEQ_CACHE
    if isinstance(cache_dir, Path):
        cache_dir = str(cache_dir)

    cache_path = os.path.join(cache_dir, filename)
    if not os.path.exists(cache_path):
        raise EnvironmentError("file {} not found".format(cache_path))

    meta_path = cache_path + ".json"
    if not os.path.exists(meta_path):
        raise EnvironmentError("file {} not found".format(meta_path))

    with open(meta_path, encoding="utf-8") as meta_file:
        metadata = json.load(meta_file)
    url = metadata["url"]
    etag = metadata["etag"]

    return url, etag

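# Illustrative round trip (sketch; values are hypothetical):
#
#     >>> name = url_to_filename("https://example.com/dict.txt", etag='"abc123"')
#     >>> filename_to_url(name)          # reads <cache_dir>/<name>.json
#     ('https://example.com/dict.txt', '"abc123"')
#
# The .json metadata sidecar is written by get_from_cache() below, so this only
# works for files that were actually downloaded into the cache.
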
def cached_path_from_pm(url_or_filename):
    """
    Tries to cache the specified URL using the PathManager class.
    Returns the cached path on success, otherwise ``None``.
    """
    try:
        from fairseq.file_io import PathManager

        local_path = PathManager.get_local_path(url_or_filename)
        return local_path
    except Exception:
        return None

def cached_path(url_or_filename, cache_dir=None):
    """
    Given something that might be a URL (or might be a local path),
    determine which. If it's a URL, download the file and cache it, and
    return the path to the cached file. If it's already a local path,
    make sure the file exists and then return the path.
    """
    if cache_dir is None:
        cache_dir = PYTORCH_FAIRSEQ_CACHE
    if isinstance(url_or_filename, Path):
        url_or_filename = str(url_or_filename)
    if isinstance(cache_dir, Path):
        cache_dir = str(cache_dir)

    parsed = urlparse(url_or_filename)

    if parsed.scheme in ("http", "https", "s3"):
        # URL, so get it from the cache (downloading if necessary)
        return get_from_cache(url_or_filename, cache_dir)
    elif os.path.exists(url_or_filename):
        # File, and it exists.
        return url_or_filename
    elif parsed.scheme == "":
        # File, but it doesn't exist.
        raise EnvironmentError("file {} not found".format(url_or_filename))
    else:
        cached_path = cached_path_from_pm(url_or_filename)
        if cached_path:
            return cached_path
        # Something unknown
        raise ValueError(
            "unable to parse {} as a URL or as a local path".format(url_or_filename)
        )

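# Illustrative dispatch (sketch; the URLs and paths are hypothetical):
#
#     >>> cached_path("https://example.com/dict.txt")    # downloaded via get_from_cache()
#     '/home/user/.cache/torch/pytorch_fairseq/<hashed name>'
#     >>> cached_path("/tmp/existing_file.txt")           # existing local path, returned as-is
#     '/tmp/existing_file.txt'
#     >>> cached_path("/tmp/missing_file.txt")            # raises EnvironmentError: file ... not found
#
# Any other scheme is first handed to cached_path_from_pm() before failing with
# a ValueError.
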
def split_s3_path(url):
    """Split a full s3 path into the bucket name and path."""
    parsed = urlparse(url)
    if not parsed.netloc or not parsed.path:
        raise ValueError("bad s3 path {}".format(url))
    bucket_name = parsed.netloc
    s3_path = parsed.path
    # Remove '/' at beginning of path.
    if s3_path.startswith("/"):
        s3_path = s3_path[1:]
    return bucket_name, s3_path

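# Illustrative split (sketch; the bucket and key are hypothetical):
#
#     >>> split_s3_path("s3://my-bucket/models/checkpoint.pt")
#     ('my-bucket', 'models/checkpoint.pt')
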
def s3_request(func):
    """
    Wrapper function for s3 requests in order to create more helpful error
    messages.
    """

    @wraps(func)
    def wrapper(url, *args, **kwargs):
        from botocore.exceptions import ClientError

        try:
            return func(url, *args, **kwargs)
        except ClientError as exc:
            if int(exc.response["Error"]["Code"]) == 404:
                raise EnvironmentError("file {} not found".format(url))
            else:
                raise

    return wrapper

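# Usage sketch (the decorated function name below is hypothetical):
#
#     @s3_request
#     def s3_head(url):
#         ...  # any boto3 call; a 404 ClientError becomes EnvironmentError
#
# The wrapped function must take the S3 URL as its first argument, as s3_etag
# and s3_get below do; non-404 ClientErrors are re-raised unchanged.
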
@s3_request
def s3_etag(url):
    """Check ETag on S3 object."""
    import boto3

    s3_resource = boto3.resource("s3")
    bucket_name, s3_path = split_s3_path(url)
    s3_object = s3_resource.Object(bucket_name, s3_path)
    return s3_object.e_tag

@s3_request
def s3_get(url, temp_file):
    """Pull a file directly from S3."""
    import boto3

    s3_resource = boto3.resource("s3")
    bucket_name, s3_path = split_s3_path(url)
    s3_resource.Bucket(bucket_name).download_fileobj(s3_path, temp_file)

def request_wrap_timeout(func, url):
    import requests

    for attempt, timeout in enumerate([10, 20, 40, 60, 60]):
        try:
            return func(timeout=timeout)
        except requests.exceptions.Timeout as e:
            logger.warning(
                "Request for %s timed-out (attempt %d). Retrying with a timeout of %d secs",
                url,
                attempt,
                timeout,
                exc_info=e,
            )
            continue
    raise RuntimeError(f"Unable to fetch file {url}")

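# Usage sketch: `func` must be a zero-argument callable accepting a `timeout=`
# keyword, typically built with functools.partial (the URL is hypothetical):
#
#     response = request_wrap_timeout(
#         partial(requests.head, "https://example.com/dict.txt", allow_redirects=True),
#         "https://example.com/dict.txt",
#     )
#
# The timeout roughly doubles on each retry (10, 20, 40, 60, 60 seconds) before
# giving up with a RuntimeError.
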
def http_get(url, temp_file):
    import requests
    from tqdm import tqdm

    req = request_wrap_timeout(partial(requests.get, url, stream=True), url)
    content_length = req.headers.get("Content-Length")
    total = int(content_length) if content_length is not None else None
    progress = tqdm(unit="B", total=total)
    for chunk in req.iter_content(chunk_size=1024):
        if chunk:  # filter out keep-alive new chunks
            progress.update(len(chunk))
            temp_file.write(chunk)
    progress.close()

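# Usage sketch: `temp_file` is any file object opened for binary writing, e.g.
# the NamedTemporaryFile used by get_from_cache() below (the URL is hypothetical):
#
#     with tempfile.NamedTemporaryFile() as tmp:
#         http_get("https://example.com/dict.txt", tmp)
#
# The response is streamed in 1 KiB chunks behind a tqdm progress bar, so large
# files are never held fully in memory.
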
def get_from_cache(url, cache_dir=None):
    """
    Given a URL, look for the corresponding dataset in the local cache.
    If it's not there, download it. Then return the path to the cached file.
    """
    if cache_dir is None:
        cache_dir = PYTORCH_FAIRSEQ_CACHE
    if isinstance(cache_dir, Path):
        cache_dir = str(cache_dir)

    if not os.path.exists(cache_dir):
        os.makedirs(cache_dir)

    # Get eTag to add to filename, if it exists.
    if url.startswith("s3://"):
        etag = s3_etag(url)
    else:
        try:
            import requests

            response = request_wrap_timeout(
                partial(requests.head, url, allow_redirects=True), url
            )
            if response.status_code != 200:
                etag = None
            else:
                etag = response.headers.get("ETag")
        except RuntimeError:
            etag = None

    filename = url_to_filename(url, etag)

    # get cache path to put the file
    cache_path = os.path.join(cache_dir, filename)

    # If we don't have a connection (etag is None) and can't identify the file
    # try to get the last downloaded one
    if not os.path.exists(cache_path) and etag is None:
        matching_files = fnmatch.filter(os.listdir(cache_dir), filename + ".*")
        matching_files = list(filter(lambda s: not s.endswith(".json"), matching_files))
        if matching_files:
            cache_path = os.path.join(cache_dir, matching_files[-1])

    if not os.path.exists(cache_path):
        # Download to temporary file, then copy to cache dir once finished.
        # Otherwise you get corrupt cache entries if the download gets interrupted.
        with tempfile.NamedTemporaryFile() as temp_file:
            logger.info("%s not found in cache, downloading to %s", url, temp_file.name)

            # GET file object
            if url.startswith("s3://"):
                s3_get(url, temp_file)
            else:
                http_get(url, temp_file)

            # we are copying the file before closing it, so flush to avoid truncation
            temp_file.flush()
            # shutil.copyfileobj() starts at the current position, so go to the start
            temp_file.seek(0)

            logger.info("copying %s to cache at %s", temp_file.name, cache_path)
            with open(cache_path, "wb") as cache_file:
                shutil.copyfileobj(temp_file, cache_file)

            logger.info("creating metadata file for %s", cache_path)
            meta = {"url": url, "etag": etag}
            meta_path = cache_path + ".json"
            with open(meta_path, "w") as meta_file:
                output_string = json.dumps(meta)
                meta_file.write(output_string)

            logger.info("removing temp file %s", temp_file.name)

    return cache_path

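# Resulting cache layout (sketch; the names stand in for sha256 digests):
#
#     <cache_dir>/
#         <url hash>.<etag hash>          # downloaded file, named by url_to_filename()
#         <url hash>.<etag hash>.json     # metadata sidecar: {"url": ..., "etag": ...}
#
# If the ETag cannot be fetched (e.g. no network connection), a previously
# downloaded file whose name starts with the URL hash is reused instead of failing.
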
def read_set_from_file(filename):
    """
    Extract a de-duped collection (set) of text from a file.
    Expected file format is one item per line.
    """
    collection = set()
    with open(filename, "r", encoding="utf-8") as file_:
        for line in file_:
            collection.add(line.rstrip())
    return collection

def get_file_extension(path, dot=True, lower=True):
    ext = os.path.splitext(path)[1]
    ext = ext if dot else ext[1:]
    return ext.lower() if lower else ext

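# Illustrative behaviour (sketch; the filenames are hypothetical):
#
#     >>> get_file_extension("model.tar.BZ2")
#     '.bz2'
#     >>> get_file_extension("model.tar.BZ2", dot=False, lower=False)
#     'BZ2'
#
# Only the final suffix is returned: os.path.splitext treats "model.tar.bz2" as
# a ".bz2" file whose base name ends in ".tar".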