mirror of
https://github.com/Bing-su/adetailer.git
synced 2026-01-26 11:19:53 +00:00
fix(scripts): move contextmanagers
This commit is contained in:
0
aaaaaa/__init__.py
Normal file
0
aaaaaa/__init__.py
Normal file
55
aaaaaa/helper.py
Normal file
55
aaaaaa/helper.py
Normal file
@@ -0,0 +1,55 @@
|
||||
from __future__ import annotations
|
||||
|
||||
from contextlib import contextmanager
|
||||
from copy import copy
|
||||
from typing import TYPE_CHECKING
|
||||
|
||||
import torch
|
||||
|
||||
from modules import safe
|
||||
from modules.shared import opts
|
||||
|
||||
if TYPE_CHECKING:
|
||||
# 타입 체커가 빨간 줄을 긋지 않게 하는 편법
|
||||
from types import SimpleNamespace
|
||||
|
||||
StableDiffusionProcessingTxt2Img = SimpleNamespace
|
||||
StableDiffusionProcessingImg2Img = SimpleNamespace
|
||||
else:
|
||||
from modules.processing import (
|
||||
StableDiffusionProcessingImg2Img,
|
||||
StableDiffusionProcessingTxt2Img,
|
||||
)
|
||||
|
||||
PT = StableDiffusionProcessingTxt2Img | StableDiffusionProcessingImg2Img
|
||||
|
||||
|
||||
@contextmanager
|
||||
def change_torch_load():
|
||||
orig = torch.load
|
||||
try:
|
||||
torch.load = safe.unsafe_torch_load
|
||||
yield
|
||||
finally:
|
||||
torch.load = orig
|
||||
|
||||
|
||||
@contextmanager
|
||||
def pause_total_tqdm():
|
||||
orig = opts.data.get("multiple_tqdm", True)
|
||||
try:
|
||||
opts.data["multiple_tqdm"] = False
|
||||
yield
|
||||
finally:
|
||||
opts.data["multiple_tqdm"] = orig
|
||||
|
||||
|
||||
@contextmanager
|
||||
def preseve_prompts(p: PT):
|
||||
all_pt = copy(p.all_prompts)
|
||||
all_ng = copy(p.all_negative_prompts)
|
||||
try:
|
||||
yield
|
||||
finally:
|
||||
p.all_prompts = all_pt
|
||||
p.all_negative_prompts = all_ng
|
||||
@@ -1 +1 @@
|
||||
__version__ = "24.4.0"
|
||||
__version__ = "24.4.1-dev.0"
|
||||
|
||||
@@ -4,7 +4,7 @@ import platform
|
||||
import re
|
||||
import sys
|
||||
import traceback
|
||||
from contextlib import contextmanager, suppress
|
||||
from contextlib import suppress
|
||||
from copy import copy
|
||||
from functools import partial
|
||||
from pathlib import Path
|
||||
@@ -12,11 +12,11 @@ from textwrap import dedent
|
||||
from typing import TYPE_CHECKING, Any, NamedTuple, cast
|
||||
|
||||
import gradio as gr
|
||||
import torch
|
||||
from PIL import Image, ImageChops
|
||||
from rich import print
|
||||
|
||||
import modules
|
||||
from aaaaaa.helper import change_torch_load, pause_total_tqdm, preseve_prompts
|
||||
from adetailer import (
|
||||
AFTER_DETAILER,
|
||||
__version__,
|
||||
@@ -44,7 +44,7 @@ from controlnet_ext import (
|
||||
controlnet_type,
|
||||
get_cn_models,
|
||||
)
|
||||
from modules import images, paths, safe, script_callbacks, scripts, shared
|
||||
from modules import images, paths, script_callbacks, scripts, shared
|
||||
from modules.devices import NansException
|
||||
from modules.processing import (
|
||||
Processed,
|
||||
@@ -86,37 +86,6 @@ print(
|
||||
)
|
||||
|
||||
|
||||
@contextmanager
|
||||
def change_torch_load():
|
||||
orig = torch.load
|
||||
try:
|
||||
torch.load = safe.unsafe_torch_load
|
||||
yield
|
||||
finally:
|
||||
torch.load = orig
|
||||
|
||||
|
||||
@contextmanager
|
||||
def pause_total_tqdm():
|
||||
orig = opts.data.get("multiple_tqdm", True)
|
||||
try:
|
||||
opts.data["multiple_tqdm"] = False
|
||||
yield
|
||||
finally:
|
||||
opts.data["multiple_tqdm"] = orig
|
||||
|
||||
|
||||
@contextmanager
|
||||
def preseve_prompts(p):
|
||||
all_pt = copy(p.all_prompts)
|
||||
all_ng = copy(p.all_negative_prompts)
|
||||
try:
|
||||
yield
|
||||
finally:
|
||||
p.all_prompts = all_pt
|
||||
p.all_negative_prompts = all_ng
|
||||
|
||||
|
||||
class AfterDetailerScript(scripts.Script):
|
||||
def __init__(self):
|
||||
super().__init__()
|
||||
|
||||
Reference in New Issue
Block a user