diff --git a/modules_forge/patch_basic.py b/modules_forge/patch_basic.py index bb530325..87048a20 100644 --- a/modules_forge/patch_basic.py +++ b/modules_forge/patch_basic.py @@ -1,8 +1,26 @@ import torch import os +import time +import httpx +import warnings +import gradio.networking import safetensors.torch +def gradio_url_ok_fix(url: str) -> bool: + try: + for _ in range(5): + with warnings.catch_warnings(): + warnings.filterwarnings("ignore") + r = httpx.head(url, timeout=999, verify=False) + if r.status_code in (200, 401, 302): + return True + time.sleep(0.500) + except (ConnectionError, httpx.ConnectError): + return False + return False + + def build_loaded(module, loader_name): original_loader_name = loader_name + '_origin' @@ -38,6 +56,7 @@ def build_loaded(module, loader_name): def patch_all_basics(): + gradio.networking.url_ok = gradio_url_ok_fix build_loaded(safetensors.torch, 'load_file') build_loaded(torch, 'load') return