From 4d120f3ec35b30bd0f992f5d8af2d793aad98d2a Mon Sep 17 00:00:00 2001
From: John Sutor
Date: Thu, 14 Jul 2022 08:40:02 -0400
Subject: [PATCH] Add PyTorch Hub configuration file (#259)

---
 hubconf.py | 42 ++++++++++++++++++++++++++++++++++++++++++
 1 file changed, 42 insertions(+)
 create mode 100644 hubconf.py

diff --git a/hubconf.py b/hubconf.py
new file mode 100644
index 0000000..520b354
--- /dev/null
+++ b/hubconf.py
@@ -0,0 +1,42 @@
+from clip.clip import tokenize as _tokenize, load as _load, available_models as _available_models
+import re
+import string
+
+dependencies = ["torch", "torchvision", "ftfy", "regex", "tqdm"]
+
+# For compatibility (cannot include special characters in function name)
+model_functions = { model: re.sub(f'[{string.punctuation}]', '_', model) for model in _available_models()}
+
+def _create_hub_entrypoint(model):
+    def entrypoint(**kwargs):
+        return _load(model, **kwargs)
+
+    entrypoint.__doc__ = f"""Loads the {model} CLIP model
+
+    Parameters
+    ----------
+    device : Union[str, torch.device]
+        The device to put the loaded model
+
+    jit : bool
+        Whether to load the optimized JIT model or more hackable non-JIT model (default).
+
+    download_root: str
+        path to download the model files; by default, it uses "~/.cache/clip"
+
+    Returns
+    -------
+    model : torch.nn.Module
+        The {model} CLIP model
+
+    preprocess : Callable[[PIL.Image], torch.Tensor]
+        A torchvision transform that converts a PIL image into a tensor that the returned model can take as its input
+    """
+    return entrypoint
+
+def tokenize():
+    return _tokenize
+
+_entrypoints = {model_functions[model]: _create_hub_entrypoint(model) for model in _available_models()}
+
+globals().update(_entrypoints)
\ No newline at end of file
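
A minimal usage sketch of the resulting hub entrypoints, assuming this hubconf.py ships on the repository's default branch and that the repo is reachable at the GitHub path openai/CLIP (that path and the ViT_B_32 name are illustrative; entrypoint names are the model names with punctuation replaced by underscores, so "ViT-B/32" becomes "ViT_B_32"):

import torch

# Assumed repo path; torch.hub.load forwards extra keyword arguments to the
# entrypoint, which passes them through to clip.load (device, jit, download_root).
model, preprocess = torch.hub.load("openai/CLIP", "ViT_B_32", device="cpu", jit=False)

# The "tokenize" entrypoint returns clip's tokenize function itself.
tokenize = torch.hub.load("openai/CLIP", "tokenize")
tokens = tokenize(["a diagram", "a dog", "a cat"])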