feat: change unsafe pickling

This commit is contained in:
Dowon
2025-02-18 20:18:42 +09:00
parent 589412052d
commit 06100063b3
2 changed files with 15 additions and 4 deletions

View File

@@ -1,15 +1,17 @@
from __future__ import annotations
import os
from contextlib import contextmanager
from copy import copy
from typing import TYPE_CHECKING, Any, Union
from unittest.mock import patch
import torch
from PIL import Image
from typing_extensions import Protocol
from modules import safe
from modules.shared import opts
from modules.shared import cmd_opts, opts
if TYPE_CHECKING:
# 타입 체커가 빨간 줄을 긋지 않게 하는 편법
@@ -36,6 +38,15 @@ def change_torch_load():
torch.load = orig
@contextmanager
def disable_safe_unpickle():
with (
patch.dict(os.environ, {"TORCH_FORCE_NO_WEIGHTS_ONLY_LOAD": "1"}, clear=False),
patch.object(cmd_opts, "disable_safe_unpickle", True),
):
yield
@contextmanager
def pause_total_tqdm():
orig = opts.data.get("multiple_tqdm", True)

View File

@@ -18,8 +18,8 @@ import modules
from aaaaaa.conditional import create_binary_mask, schedulers
from aaaaaa.helper import (
PPImage,
change_torch_load,
copy_extra_params,
disable_safe_unpickle,
pause_total_tqdm,
preserve_prompts,
)
@@ -825,8 +825,8 @@ class AfterDetailerScript(scripts.Script):
pred = mediapipe_predict(args.ad_model, pp.image, args.ad_confidence)
else:
with change_torch_load():
ad_model = self.get_ad_model(args.ad_model)
with disable_safe_unpickle():
pred = ultralytics_predict(
ad_model,
image=pp.image,