update some codes related to win32

This commit is contained in:
lllyasviel
2024-08-21 03:23:34 -07:00
committed by GitHub
parent 3e253012b5
commit c73ff3724c
2 changed files with 33 additions and 6 deletions

View File

@@ -43,25 +43,36 @@ def find_free_port(server_name, start_port=None):
port += 1
def long_path_prefix(path):
if os.name == 'nt' and not path.startswith("\\\\?\\") and not os.path.exists(path):
return f"\\\\?\\{path}"
return path
def remove_dir(dir_path):
dir_path = long_path_prefix(dir_path)
for root, dirs, files in os.walk(dir_path, topdown=False):
for name in files:
file_path = os.path.join(root, name)
file_path = long_path_prefix(file_path)
try:
os.remove(os.path.join(root, name))
os.remove(file_path)
except Exception as e:
print(f"Error removing file {os.path.join(root, name)}: {e}")
print(f"Error removing file {file_path}: {e}")
for name in dirs:
dir_to_remove = os.path.join(root, name)
dir_to_remove = long_path_prefix(dir_to_remove)
try:
os.rmdir(os.path.join(root, name))
os.rmdir(dir_to_remove)
except Exception as e:
print(f"Error removing directory {os.path.join(root, name)}: {e}")
print(f"Error removing directory {dir_to_remove}: {e}")
try:
os.rmdir(dir_path)
print(f"Deleted: {dir_path}")
except:
print(f'Something went wrong when trying to delete a folder. You may try to manually delete the folder [{dir_path}].')
except Exception as e:
print(f"Error removing directory {dir_path}: {e}. You may try to manually delete the folder.")
return

View File

@@ -6,6 +6,7 @@ import warnings
import gradio.networking
import safetensors.torch
from pathlib import Path
from tqdm import tqdm
@@ -64,6 +65,12 @@ def always_show_tqdm(*args, **kwargs):
return tqdm(*args, **kwargs)
def long_path_prefix(path: Path) -> Path:
if os.name == 'nt' and not str(path).startswith("\\\\?\\") and not path.exists():
return Path("\\\\?\\" + str(path))
return path
def patch_all_basics():
import logging
from huggingface_hub import file_download
@@ -71,6 +78,15 @@ def patch_all_basics():
from transformers.dynamic_module_utils import logger
logger.setLevel(logging.ERROR)
from huggingface_hub.file_download import _download_to_tmp_and_move as original_download_to_tmp_and_move
def patched_download_to_tmp_and_move(incomplete_path, destination_path, url_to_download, proxies, headers, expected_size, filename, force_download):
incomplete_path = long_path_prefix(incomplete_path)
destination_path = long_path_prefix(destination_path)
return original_download_to_tmp_and_move(incomplete_path, destination_path, url_to_download, proxies, headers, expected_size, filename, force_download)
file_download._download_to_tmp_and_move = patched_download_to_tmp_and_move
gradio.networking.url_ok = gradio_url_ok_fix
build_loaded(safetensors.torch, 'load_file')
build_loaded(torch, 'load')