From 7d7dfb76a509405342e0579b587ba2f7a8a44fa1 Mon Sep 17 00:00:00 2001 From: Dowon Date: Sat, 13 Apr 2024 16:04:40 +0900 Subject: [PATCH] fix(scripts): move contextmanagers --- aaaaaa/__init__.py | 0 aaaaaa/helper.py | 55 ++++++++++++++++++++++++++++++++++++++++ adetailer/__version__.py | 2 +- scripts/!adetailer.py | 37 +++------------------------ 4 files changed, 59 insertions(+), 35 deletions(-) create mode 100644 aaaaaa/__init__.py create mode 100644 aaaaaa/helper.py diff --git a/aaaaaa/__init__.py b/aaaaaa/__init__.py new file mode 100644 index 0000000..e69de29 diff --git a/aaaaaa/helper.py b/aaaaaa/helper.py new file mode 100644 index 0000000..aa9f7f6 --- /dev/null +++ b/aaaaaa/helper.py @@ -0,0 +1,55 @@ +from __future__ import annotations + +from contextlib import contextmanager +from copy import copy +from typing import TYPE_CHECKING + +import torch + +from modules import safe +from modules.shared import opts + +if TYPE_CHECKING: + # 타입 체커가 빨간 줄을 긋지 않게 하는 편법 + from types import SimpleNamespace + + StableDiffusionProcessingTxt2Img = SimpleNamespace + StableDiffusionProcessingImg2Img = SimpleNamespace +else: + from modules.processing import ( + StableDiffusionProcessingImg2Img, + StableDiffusionProcessingTxt2Img, + ) + +PT = StableDiffusionProcessingTxt2Img | StableDiffusionProcessingImg2Img + + +@contextmanager +def change_torch_load(): + orig = torch.load + try: + torch.load = safe.unsafe_torch_load + yield + finally: + torch.load = orig + + +@contextmanager +def pause_total_tqdm(): + orig = opts.data.get("multiple_tqdm", True) + try: + opts.data["multiple_tqdm"] = False + yield + finally: + opts.data["multiple_tqdm"] = orig + + +@contextmanager +def preseve_prompts(p: PT): + all_pt = copy(p.all_prompts) + all_ng = copy(p.all_negative_prompts) + try: + yield + finally: + p.all_prompts = all_pt + p.all_negative_prompts = all_ng diff --git a/adetailer/__version__.py b/adetailer/__version__.py index b43269d..4e34747 100644 --- a/adetailer/__version__.py +++ b/adetailer/__version__.py @@ -1 +1 @@ -__version__ = "24.4.0" +__version__ = "24.4.1-dev.0" diff --git a/scripts/!adetailer.py b/scripts/!adetailer.py index eb46a1d..36b8c34 100644 --- a/scripts/!adetailer.py +++ b/scripts/!adetailer.py @@ -4,7 +4,7 @@ import platform import re import sys import traceback -from contextlib import contextmanager, suppress +from contextlib import suppress from copy import copy from functools import partial from pathlib import Path @@ -12,11 +12,11 @@ from textwrap import dedent from typing import TYPE_CHECKING, Any, NamedTuple, cast import gradio as gr -import torch from PIL import Image, ImageChops from rich import print import modules +from aaaaaa.helper import change_torch_load, pause_total_tqdm, preseve_prompts from adetailer import ( AFTER_DETAILER, __version__, @@ -44,7 +44,7 @@ from controlnet_ext import ( controlnet_type, get_cn_models, ) -from modules import images, paths, safe, script_callbacks, scripts, shared +from modules import images, paths, script_callbacks, scripts, shared from modules.devices import NansException from modules.processing import ( Processed, @@ -86,37 +86,6 @@ print( ) -@contextmanager -def change_torch_load(): - orig = torch.load - try: - torch.load = safe.unsafe_torch_load - yield - finally: - torch.load = orig - - -@contextmanager -def pause_total_tqdm(): - orig = opts.data.get("multiple_tqdm", True) - try: - opts.data["multiple_tqdm"] = False - yield - finally: - opts.data["multiple_tqdm"] = orig - - -@contextmanager -def preseve_prompts(p): - all_pt = copy(p.all_prompts) - all_ng = copy(p.all_negative_prompts) - try: - yield - finally: - p.all_prompts = all_pt - p.all_negative_prompts = all_ng - - class AfterDetailerScript(scripts.Script): def __init__(self): super().__init__()