From 44160f2b27f4a20a1e781a779ca8cddfb086f6cb Mon Sep 17 00:00:00 2001 From: Bingsu Date: Wed, 26 Apr 2023 23:03:59 +0900 Subject: [PATCH] fix: unsafe torch load --- scripts/!adetailer.py | 8 ++++++-- 1 file changed, 6 insertions(+), 2 deletions(-) diff --git a/scripts/!adetailer.py b/scripts/!adetailer.py index 3753c16..4de11cb 100644 --- a/scripts/!adetailer.py +++ b/scripts/!adetailer.py @@ -3,6 +3,7 @@ from __future__ import annotations from pathlib import Path import gradio as gr +import torch import modules from adetailer import __version__, get_models, mediapipe_predict, ultralytics_predict @@ -10,6 +11,7 @@ from adetailer.common import dilate_erode, is_all_black, offset from modules import devices, scripts from modules.paths import models_path from modules.processing import StableDiffusionProcessingImg2Img, process_images +from modules.safe import load, unsafe_torch_load from modules.shared import opts, state AFTER_DETAILER = "After Detailer" @@ -54,11 +56,13 @@ class ADetailerArgs: } -def with_gc(func): +def adetailer_wrapper(func): def wrapper(*args, **kwargs): devices.torch_gc() + torch.load = unsafe_torch_load result = func(*args, **kwargs) devices.torch_gc() + torch.load = load return result return wrapper @@ -266,7 +270,7 @@ class AfterDetailerScript(scripts.Script): def get_args(*args): return ADetailerArgs(*args) - @with_gc + @adetailer_wrapper def postprocess_image(self, p, pp, *args_): if getattr(p, "_disable_adetailer", False): return