From c73ff3724c6f0ee49c890bbc04de4d58a1c75bb7 Mon Sep 17 00:00:00 2001 From: lllyasviel <19834515+lllyasviel@users.noreply.github.com> Date: Wed, 21 Aug 2024 03:23:34 -0700 Subject: [PATCH] update some codes related to win32 --- modules_forge/forge_space.py | 23 +++++++++++++++++------ modules_forge/patch_basic.py | 16 ++++++++++++++++ 2 files changed, 33 insertions(+), 6 deletions(-) diff --git a/modules_forge/forge_space.py b/modules_forge/forge_space.py index f64e98e9..8a1d24e9 100644 --- a/modules_forge/forge_space.py +++ b/modules_forge/forge_space.py @@ -43,25 +43,36 @@ def find_free_port(server_name, start_port=None): port += 1 +def long_path_prefix(path): + if os.name == 'nt' and not path.startswith("\\\\?\\") and not os.path.exists(path): + return f"\\\\?\\{path}" + return path + + def remove_dir(dir_path): + dir_path = long_path_prefix(dir_path) for root, dirs, files in os.walk(dir_path, topdown=False): for name in files: + file_path = os.path.join(root, name) + file_path = long_path_prefix(file_path) try: - os.remove(os.path.join(root, name)) + os.remove(file_path) except Exception as e: - print(f"Error removing file {os.path.join(root, name)}: {e}") + print(f"Error removing file {file_path}: {e}") for name in dirs: + dir_to_remove = os.path.join(root, name) + dir_to_remove = long_path_prefix(dir_to_remove) try: - os.rmdir(os.path.join(root, name)) + os.rmdir(dir_to_remove) except Exception as e: - print(f"Error removing directory {os.path.join(root, name)}: {e}") + print(f"Error removing directory {dir_to_remove}: {e}") try: os.rmdir(dir_path) print(f"Deleted: {dir_path}") - except: - print(f'Something went wrong when trying to delete a folder. You may try to manually delete the folder [{dir_path}].') + except Exception as e: + print(f"Error removing directory {dir_path}: {e}. You may try to manually delete the folder.") return diff --git a/modules_forge/patch_basic.py b/modules_forge/patch_basic.py index d22bc003..822e2838 100644 --- a/modules_forge/patch_basic.py +++ b/modules_forge/patch_basic.py @@ -6,6 +6,7 @@ import warnings import gradio.networking import safetensors.torch +from pathlib import Path from tqdm import tqdm @@ -64,6 +65,12 @@ def always_show_tqdm(*args, **kwargs): return tqdm(*args, **kwargs) +def long_path_prefix(path: Path) -> Path: + if os.name == 'nt' and not str(path).startswith("\\\\?\\") and not path.exists(): + return Path("\\\\?\\" + str(path)) + return path + + def patch_all_basics(): import logging from huggingface_hub import file_download @@ -71,6 +78,15 @@ def patch_all_basics(): from transformers.dynamic_module_utils import logger logger.setLevel(logging.ERROR) + from huggingface_hub.file_download import _download_to_tmp_and_move as original_download_to_tmp_and_move + + def patched_download_to_tmp_and_move(incomplete_path, destination_path, url_to_download, proxies, headers, expected_size, filename, force_download): + incomplete_path = long_path_prefix(incomplete_path) + destination_path = long_path_prefix(destination_path) + return original_download_to_tmp_and_move(incomplete_path, destination_path, url_to_download, proxies, headers, expected_size, filename, force_download) + + file_download._download_to_tmp_and_move = patched_download_to_tmp_and_move + gradio.networking.url_ok = gradio_url_ok_fix build_loaded(safetensors.torch, 'load_file') build_loaded(torch, 'load')