fix(scripts): move contextmanagers

This commit is contained in:
Dowon
2024-04-13 16:04:40 +09:00
parent 89ee330271
commit 7d7dfb76a5
4 changed files with 59 additions and 35 deletions

0
aaaaaa/__init__.py Normal file
View File

55
aaaaaa/helper.py Normal file
View File

@@ -0,0 +1,55 @@
from __future__ import annotations
from contextlib import contextmanager
from copy import copy
from typing import TYPE_CHECKING
import torch
from modules import safe
from modules.shared import opts
if TYPE_CHECKING:
# 타입 체커가 빨간 줄을 긋지 않게 하는 편법
from types import SimpleNamespace
StableDiffusionProcessingTxt2Img = SimpleNamespace
StableDiffusionProcessingImg2Img = SimpleNamespace
else:
from modules.processing import (
StableDiffusionProcessingImg2Img,
StableDiffusionProcessingTxt2Img,
)
PT = StableDiffusionProcessingTxt2Img | StableDiffusionProcessingImg2Img
@contextmanager
def change_torch_load():
orig = torch.load
try:
torch.load = safe.unsafe_torch_load
yield
finally:
torch.load = orig
@contextmanager
def pause_total_tqdm():
orig = opts.data.get("multiple_tqdm", True)
try:
opts.data["multiple_tqdm"] = False
yield
finally:
opts.data["multiple_tqdm"] = orig
@contextmanager
def preseve_prompts(p: PT):
all_pt = copy(p.all_prompts)
all_ng = copy(p.all_negative_prompts)
try:
yield
finally:
p.all_prompts = all_pt
p.all_negative_prompts = all_ng

View File

@@ -1 +1 @@
__version__ = "24.4.0"
__version__ = "24.4.1-dev.0"

View File

@@ -4,7 +4,7 @@ import platform
import re
import sys
import traceback
from contextlib import contextmanager, suppress
from contextlib import suppress
from copy import copy
from functools import partial
from pathlib import Path
@@ -12,11 +12,11 @@ from textwrap import dedent
from typing import TYPE_CHECKING, Any, NamedTuple, cast
import gradio as gr
import torch
from PIL import Image, ImageChops
from rich import print
import modules
from aaaaaa.helper import change_torch_load, pause_total_tqdm, preseve_prompts
from adetailer import (
AFTER_DETAILER,
__version__,
@@ -44,7 +44,7 @@ from controlnet_ext import (
controlnet_type,
get_cn_models,
)
from modules import images, paths, safe, script_callbacks, scripts, shared
from modules import images, paths, script_callbacks, scripts, shared
from modules.devices import NansException
from modules.processing import (
Processed,
@@ -86,37 +86,6 @@ print(
)
@contextmanager
def change_torch_load():
orig = torch.load
try:
torch.load = safe.unsafe_torch_load
yield
finally:
torch.load = orig
@contextmanager
def pause_total_tqdm():
orig = opts.data.get("multiple_tqdm", True)
try:
opts.data["multiple_tqdm"] = False
yield
finally:
opts.data["multiple_tqdm"] = orig
@contextmanager
def preseve_prompts(p):
all_pt = copy(p.all_prompts)
all_ng = copy(p.all_negative_prompts)
try:
yield
finally:
p.all_prompts = all_pt
p.all_negative_prompts = all_ng
class AfterDetailerScript(scripts.Script):
def __init__(self):
super().__init__()