fix: unsafe torch load

This commit is contained in:
Bingsu
2023-04-26 23:03:59 +09:00
parent d0c4701383
commit 44160f2b27

View File

@@ -3,6 +3,7 @@ from __future__ import annotations
from pathlib import Path
import gradio as gr
import torch
import modules
from adetailer import __version__, get_models, mediapipe_predict, ultralytics_predict
@@ -10,6 +11,7 @@ from adetailer.common import dilate_erode, is_all_black, offset
from modules import devices, scripts
from modules.paths import models_path
from modules.processing import StableDiffusionProcessingImg2Img, process_images
from modules.safe import load, unsafe_torch_load
from modules.shared import opts, state
AFTER_DETAILER = "After Detailer"
@@ -54,11 +56,13 @@ class ADetailerArgs:
}
def with_gc(func):
def adetailer_wrapper(func):
def wrapper(*args, **kwargs):
devices.torch_gc()
torch.load = unsafe_torch_load
result = func(*args, **kwargs)
devices.torch_gc()
torch.load = load
return result
return wrapper
@@ -266,7 +270,7 @@ class AfterDetailerScript(scripts.Script):
def get_args(*args):
return ADetailerArgs(*args)
@with_gc
@adetailer_wrapper
def postprocess_image(self, p, pp, *args_):
if getattr(p, "_disable_adetailer", False):
return