From 571c4586c5e7e1722c29dbb4a0c5a619f2a445ce Mon Sep 17 00:00:00 2001 From: lllyasviel Date: Wed, 31 Jan 2024 17:29:09 -0800 Subject: [PATCH] remove shit --- modules/safe.py | 185 ++++-------------------------------------------- 1 file changed, 13 insertions(+), 172 deletions(-) diff --git a/modules/safe.py b/modules/safe.py index b1d08a79..068d6083 100644 --- a/modules/safe.py +++ b/modules/safe.py @@ -1,196 +1,37 @@ -# this code is adapted from the script contributed by anon from /h/ +TypedStorage = None -import pickle -import collections - -import torch -import numpy -import _codecs -import zipfile -import re - - -# PyTorch 1.13 and later have _TypedStorage renamed to TypedStorage -from modules import errors - -TypedStorage = torch.storage.TypedStorage if hasattr(torch.storage, 'TypedStorage') else torch.storage._TypedStorage def encode(*args): - out = _codecs.encode(*args) - return out + pass -class RestrictedUnpickler(pickle.Unpickler): - extra_handler = None - - def persistent_load(self, saved_id): - assert saved_id[0] == 'storage' - - try: - return TypedStorage(_internal=True) - except TypeError: - return TypedStorage() # PyTorch before 2.0 does not have the _internal argument - - def find_class(self, module, name): - if self.extra_handler is not None: - res = self.extra_handler(module, name) - if res is not None: - return res - - if module == 'collections' and name == 'OrderedDict': - return getattr(collections, name) - if module == 'torch._utils' and name in ['_rebuild_tensor_v2', '_rebuild_parameter', '_rebuild_device_tensor_from_numpy']: - return getattr(torch._utils, name) - if module == 'torch' and name in ['FloatStorage', 'HalfStorage', 'IntStorage', 'LongStorage', 'DoubleStorage', 'ByteStorage', 'float32', 'BFloat16Storage']: - return getattr(torch, name) - if module == 'torch.nn.modules.container' and name in ['ParameterDict']: - return getattr(torch.nn.modules.container, name) - if module == 'numpy.core.multiarray' and name in ['scalar', '_reconstruct']: - return getattr(numpy.core.multiarray, name) - if module == 'numpy' and name in ['dtype', 'ndarray']: - return getattr(numpy, name) - if module == '_codecs' and name == 'encode': - return encode - if module == "pytorch_lightning.callbacks" and name == 'model_checkpoint': - import pytorch_lightning.callbacks - return pytorch_lightning.callbacks.model_checkpoint - if module == "pytorch_lightning.callbacks.model_checkpoint" and name == 'ModelCheckpoint': - import pytorch_lightning.callbacks.model_checkpoint - return pytorch_lightning.callbacks.model_checkpoint.ModelCheckpoint - if module == "__builtin__" and name == 'set': - return set - - # Forbid everything else. - raise Exception(f"global '{module}/{name}' is forbidden") +class RestrictedUnpickler: + pass -# Regular expression that accepts 'dirname/version', 'dirname/data.pkl', and 'dirname/data/' -allowed_zip_names_re = re.compile(r"^([^/]+)/((data/\d+)|version|(data\.pkl))$") -data_pkl_re = re.compile(r"^([^/]+)/data\.pkl$") +allowed_zip_names_re = None +data_pkl_re = None + def check_zip_filenames(filename, names): - for name in names: - if allowed_zip_names_re.match(name): - continue - - raise Exception(f"bad file inside {filename}: {name}") + pass def check_pt(filename, extra_handler): - try: - - # new pytorch format is a zip file - with zipfile.ZipFile(filename) as z: - check_zip_filenames(filename, z.namelist()) - - # find filename of data.pkl in zip file: '/data.pkl' - data_pkl_filenames = [f for f in z.namelist() if data_pkl_re.match(f)] - if len(data_pkl_filenames) == 0: - raise Exception(f"data.pkl not found in {filename}") - if len(data_pkl_filenames) > 1: - raise Exception(f"Multiple data.pkl found in {filename}") - with z.open(data_pkl_filenames[0]) as file: - unpickler = RestrictedUnpickler(file) - unpickler.extra_handler = extra_handler - unpickler.load() - - except zipfile.BadZipfile: - - # if it's not a zip file, it's an old pytorch format, with five objects written to pickle - with open(filename, "rb") as file: - unpickler = RestrictedUnpickler(file) - unpickler.extra_handler = extra_handler - for _ in range(5): - unpickler.load() + pass def load(filename, *args, **kwargs): - return load_with_extra(filename, *args, extra_handler=global_extra_handler, **kwargs) + pass def load_with_extra(filename, extra_handler=None, *args, **kwargs): - """ - this function is intended to be used by extensions that want to load models with - some extra classes in them that the usual unpickler would find suspicious. - - Use the extra_handler argument to specify a function that takes module and field name as text, - and returns that field's value: - - ```python - def extra(module, name): - if module == 'collections' and name == 'OrderedDict': - return collections.OrderedDict - - return None - - safe.load_with_extra('model.pt', extra_handler=extra) - ``` - - The alternative to this is just to use safe.unsafe_torch_load('model.pt'), which as the name implies is - definitely unsafe. - """ - - from modules import shared - - try: - if not shared.cmd_opts.disable_safe_unpickle: - check_pt(filename, extra_handler) - - except pickle.UnpicklingError: - errors.report( - f"Error verifying pickled file from {filename}\n" - "-----> !!!! The file is most likely corrupted !!!! <-----\n" - "You can skip this check with --disable-safe-unpickle commandline argument, but that is not going to help you.\n\n", - exc_info=True, - ) - return None - except Exception: - errors.report( - f"Error verifying pickled file from {filename}\n" - f"The file may be malicious, so the program is not going to read it.\n" - f"You can skip this check with --disable-safe-unpickle commandline argument.\n\n", - exc_info=True, - ) - return None - - return unsafe_torch_load(filename, *args, **kwargs) + pass class Extra: - """ - A class for temporarily setting the global handler for when you can't explicitly call load_with_extra - (because it's not your code making the torch.load call). The intended use is like this: - -``` -import torch -from modules import safe - -def handler(module, name): - if module == 'torch' and name in ['float64', 'float16']: - return getattr(torch, name) - - return None - -with safe.Extra(handler): - x = torch.load('model.pt') -``` - """ - - def __init__(self, handler): - self.handler = handler - - def __enter__(self): - global global_extra_handler - - assert global_extra_handler is None, 'already inside an Extra() block' - global_extra_handler = self.handler - - def __exit__(self, exc_type, exc_val, exc_tb): - global global_extra_handler - - global_extra_handler = None + pass -unsafe_torch_load = torch.load -torch.load = load +unsafe_torch_load = None global_extra_handler = None