mirror of
https://github.com/Bing-su/adetailer.git
synced 2026-03-13 17:30:01 +00:00
fix: unsafe torch load
This commit is contained in:
@@ -3,6 +3,7 @@ from __future__ import annotations
|
||||
from pathlib import Path
|
||||
|
||||
import gradio as gr
|
||||
import torch
|
||||
|
||||
import modules
|
||||
from adetailer import __version__, get_models, mediapipe_predict, ultralytics_predict
|
||||
@@ -10,6 +11,7 @@ from adetailer.common import dilate_erode, is_all_black, offset
|
||||
from modules import devices, scripts
|
||||
from modules.paths import models_path
|
||||
from modules.processing import StableDiffusionProcessingImg2Img, process_images
|
||||
from modules.safe import load, unsafe_torch_load
|
||||
from modules.shared import opts, state
|
||||
|
||||
AFTER_DETAILER = "After Detailer"
|
||||
@@ -54,11 +56,13 @@ class ADetailerArgs:
|
||||
}
|
||||
|
||||
|
||||
def with_gc(func):
|
||||
def adetailer_wrapper(func):
|
||||
def wrapper(*args, **kwargs):
|
||||
devices.torch_gc()
|
||||
torch.load = unsafe_torch_load
|
||||
result = func(*args, **kwargs)
|
||||
devices.torch_gc()
|
||||
torch.load = load
|
||||
return result
|
||||
|
||||
return wrapper
|
||||
@@ -266,7 +270,7 @@ class AfterDetailerScript(scripts.Script):
|
||||
def get_args(*args):
|
||||
return ADetailerArgs(*args)
|
||||
|
||||
@with_gc
|
||||
@adetailer_wrapper
|
||||
def postprocess_image(self, p, pp, *args_):
|
||||
if getattr(p, "_disable_adetailer", False):
|
||||
return
|
||||
|
||||
Reference in New Issue
Block a user