diff --git a/clip/clip.py b/clip/clip.py index 00abbc7..cf2ba38 100644 --- a/clip/clip.py +++ b/clip/clip.py @@ -122,16 +122,17 @@ def load(name: str, device: Union[str, torch.device] = "cuda" if torch.cuda.is_a else: raise RuntimeError(f"Model {name} not found; available models = {available_models()}") - try: - # loading JIT archive - model = torch.jit.load(model_path, map_location=device if jit else "cpu").eval() - state_dict = None - except RuntimeError: - # loading saved state dict - if jit: - warnings.warn(f"File {model_path} is not a JIT archive. Loading as a state dict instead") - jit = False - state_dict = torch.load(model_path, map_location="cpu") + with open(model_path, 'rb') as opened_file: + try: + # loading JIT archive + model = torch.jit.load(opened_file, map_location=device if jit else "cpu").eval() + state_dict = None + except RuntimeError: + # loading saved state dict + if jit: + warnings.warn(f"File {model_path} is not a JIT archive. Loading as a state dict instead") + jit = False + state_dict = torch.load(opened_file, map_location="cpu") if not jit: model = build_model(state_dict or model.state_dict()).to(device)