# this code is adapted from the script contributed by anon from /h/

import pickle
import collections

import torch
import numpy
import _codecs
import zipfile
import re

from modules import errors

# PyTorch 1.13 and later have _TypedStorage renamed to TypedStorage
TypedStorage = torch.storage.TypedStorage if hasattr(torch.storage, 'TypedStorage') else torch.storage._TypedStorage


def encode(*args):
    out = _codecs.encode(*args)
    return out


class RestrictedUnpickler(pickle.Unpickler):
    extra_handler = None

    def persistent_load(self, saved_id):
        assert saved_id[0] == 'storage'

        # this unpickler only validates data.pkl; the actual tensor bytes live in
        # separate zip entries, so an empty storage is enough of a stand-in here
        try:
            return TypedStorage(_internal=True)
        except TypeError:
            return TypedStorage()  # PyTorch before 2.0 does not have the _internal argument

    def find_class(self, module, name):
        # give the caller-supplied handler first say on any global
        if self.extra_handler is not None:
            res = self.extra_handler(module, name)
            if res is not None:
                return res

        if module == 'collections' and name == 'OrderedDict':
            return getattr(collections, name)
        if module == 'torch._utils' and name in ['_rebuild_tensor_v2', '_rebuild_parameter', '_rebuild_device_tensor_from_numpy']:
            return getattr(torch._utils, name)
        if module == 'torch' and name in ['FloatStorage', 'HalfStorage', 'IntStorage', 'LongStorage', 'DoubleStorage', 'ByteStorage', 'float32', 'BFloat16Storage']:
            return getattr(torch, name)
        if module == 'torch.nn.modules.container' and name in ['ParameterDict']:
            return getattr(torch.nn.modules.container, name)
        if module == 'numpy.core.multiarray' and name in ['scalar', '_reconstruct']:
            return getattr(numpy.core.multiarray, name)
        if module == 'numpy' and name in ['dtype', 'ndarray']:
            return getattr(numpy, name)
        if module == '_codecs' and name == 'encode':
            return encode
        if module == "pytorch_lightning.callbacks" and name == 'model_checkpoint':
            import pytorch_lightning.callbacks
            return pytorch_lightning.callbacks.model_checkpoint
        if module == "pytorch_lightning.callbacks.model_checkpoint" and name == 'ModelCheckpoint':
            import pytorch_lightning.callbacks.model_checkpoint
            return pytorch_lightning.callbacks.model_checkpoint.ModelCheckpoint
        if module == "__builtin__" and name == 'set':
            return set

        # Forbid everything else.
        raise Exception(f"global '{module}/{name}' is forbidden")
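

# Illustrative sketch (not part of the original module): how the allowlist above
# blocks a classic pickle payload. A pickle can name any importable global via
# __reduce__, so with the stock Unpickler, loading the object below would execute
# os.system. RestrictedUnpickler.find_class sees the forbidden 'posix/system'
# (or 'nt/system' on Windows) global and raises instead. The helper name
# _demo_reject_malicious_pickle is hypothetical, added here for illustration only.
def _demo_reject_malicious_pickle():
    import io
    import os

    class Payload:
        def __reduce__(self):
            # unpickling this with the stock Unpickler would run the command
            return (os.system, ('echo pwned',))

    buffer = io.BytesIO()
    pickle.dump(Payload(), buffer)
    buffer.seek(0)

    try:
        RestrictedUnpickler(buffer).load()
    except Exception as e:
        print(f"blocked as expected: {e}")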


# Regular expression that accepts 'dirname/version', 'dirname/byteorder', 'dirname/data.pkl', 'dirname/.data/serialization_id', and 'dirname/data/<number>'
allowed_zip_names_re = re.compile(r"^([^/]+)/((data/\d+)|version|byteorder|.data/serialization_id|(data\.pkl))$")
data_pkl_re = re.compile(r"^([^/]+)/data\.pkl$")


def check_zip_filenames(filename, names):
    for name in names:
        if allowed_zip_names_re.match(name):
            continue

        raise Exception(f"bad file inside {filename}: {name}")


def check_pt(filename, extra_handler):
    try:

        # new pytorch format is a zip file
        with zipfile.ZipFile(filename) as z:
            check_zip_filenames(filename, z.namelist())

            # find filename of data.pkl in zip file: '<directory name>/data.pkl'
            data_pkl_filenames = [f for f in z.namelist() if data_pkl_re.match(f)]
            if len(data_pkl_filenames) == 0:
                raise Exception(f"data.pkl not found in {filename}")
            if len(data_pkl_filenames) > 1:
                raise Exception(f"Multiple data.pkl found in {filename}")
            with z.open(data_pkl_filenames[0]) as file:
                unpickler = RestrictedUnpickler(file)
                unpickler.extra_handler = extra_handler
                unpickler.load()

    except zipfile.BadZipfile:

        # if it's not a zip file, it's an old pytorch format, with five objects written to pickle
        with open(filename, "rb") as file:
            unpickler = RestrictedUnpickler(file)
            unpickler.extra_handler = extra_handler
            for _ in range(5):
                unpickler.load()


def load(filename, *args, **kwargs):
    return load_with_extra(filename, *args, extra_handler=global_extra_handler, **kwargs)


def load_with_extra(filename, extra_handler=None, *args, **kwargs):
    """
    this function is intended to be used by extensions that want to load models with
    some extra classes in them that the usual unpickler would find suspicious.

    Use the extra_handler argument to specify a function that takes module and field name as text,
    and returns that field's value:

    ```python
    def extra(module, name):
        if module == 'collections' and name == 'OrderedDict':
            return collections.OrderedDict

        return None

    safe.load_with_extra('model.pt', extra_handler=extra)
    ```

    The alternative to this is just to use safe.unsafe_torch_load('model.pt'), which as the name implies is
    definitely unsafe.
    """

    from modules import shared

    try:
        if not shared.cmd_opts.disable_safe_unpickle:
            check_pt(filename, extra_handler)

    except pickle.UnpicklingError:
        errors.report(
            f"Error verifying pickled file from {filename}\n"
            "-----> !!!! The file is most likely corrupted !!!! <-----\n"
            "You can skip this check with --disable-safe-unpickle commandline argument, but that is not going to help you.\n\n",
            exc_info=True,
        )
        return None
    except Exception:
        errors.report(
            f"Error verifying pickled file from {filename}\n"
            f"The file may be malicious, so the program is not going to read it.\n"
            f"You can skip this check with --disable-safe-unpickle commandline argument.\n\n",
            exc_info=True,
        )
        return None

    return unsafe_torch_load(filename, *args, **kwargs)
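

# Illustrative sketch (not part of the original module): what check_pt accepts.
# It builds a minimal zip-format checkpoint whose data.pkl references only the
# allowlisted collections.OrderedDict global, so validation passes silently;
# any other top-level zip entry, or a pickle naming a forbidden global, would
# raise. _demo_check_pt and the 'archive/' prefix are illustrative choices.
def _demo_check_pt():
    import os
    import tempfile

    payload = pickle.dumps(collections.OrderedDict(weight=1.0))

    with tempfile.NamedTemporaryFile(suffix=".pt", delete=False) as tmp:
        with zipfile.ZipFile(tmp, "w") as z:
            z.writestr("archive/version", "3")       # matches allowed_zip_names_re
            z.writestr("archive/data.pkl", payload)  # found via data_pkl_re

    check_pt(tmp.name, extra_handler=None)  # no exception: nothing forbidden inside
    os.unlink(tmp.name)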


class Extra:
    """
    A class for temporarily setting the global handler for when you can't explicitly call load_with_extra
    (because it's not your code making the torch.load call). The intended use is like this:

    ```
    import torch
    from modules import safe

    def handler(module, name):
        if module == 'torch' and name in ['float64', 'float16']:
            return getattr(torch, name)

        return None

    with safe.Extra(handler):
        x = torch.load('model.pt')
    ```
    """

    def __init__(self, handler):
        self.handler = handler

    def __enter__(self):
        global global_extra_handler

        assert global_extra_handler is None, 'already inside an Extra() block'
        global_extra_handler = self.handler

    def __exit__(self, exc_type, exc_val, exc_tb):
        global global_extra_handler

        global_extra_handler = None


unsafe_torch_load = torch.load
global_extra_handler = None
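

# Illustrative sketch (not part of the original module): the Extra docstring above
# assumes that somewhere during startup torch.load gets rebound to this module's
# load(), so that torch.load calls made by code you don't control still go through
# the restricted unpickler, while unsafe_torch_load keeps a reference to the
# original. _demo_patch_torch_load is a hypothetical helper showing that wiring.
def _demo_patch_torch_load(filename):
    torch.load = load  # every torch.load call now validates with check_pt first
    try:
        return torch.load(filename)  # routed through the allowlist
    finally:
        torch.load = unsafe_torch_load  # restore the original loader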