mirror of
https://github.com/Bing-su/adetailer.git
synced 2026-01-26 19:29:54 +00:00
60 lines
1.4 KiB
Python
60 lines
1.4 KiB
Python
from __future__ import annotations
|
|
|
|
from contextlib import contextmanager
|
|
from copy import copy
|
|
from typing import TYPE_CHECKING, Any
|
|
|
|
import torch
|
|
|
|
from modules import safe
|
|
from modules.shared import opts
|
|
|
|
if TYPE_CHECKING:
|
|
# 타입 체커가 빨간 줄을 긋지 않게 하는 편법
|
|
from types import SimpleNamespace
|
|
|
|
StableDiffusionProcessingTxt2Img = SimpleNamespace
|
|
StableDiffusionProcessingImg2Img = SimpleNamespace
|
|
else:
|
|
from modules.processing import (
|
|
StableDiffusionProcessingImg2Img,
|
|
StableDiffusionProcessingTxt2Img,
|
|
)
|
|
|
|
PT = StableDiffusionProcessingTxt2Img | StableDiffusionProcessingImg2Img
|
|
|
|
|
|
@contextmanager
|
|
def change_torch_load():
|
|
orig = torch.load
|
|
try:
|
|
torch.load = safe.unsafe_torch_load
|
|
yield
|
|
finally:
|
|
torch.load = orig
|
|
|
|
|
|
@contextmanager
|
|
def pause_total_tqdm():
|
|
orig = opts.data.get("multiple_tqdm", True)
|
|
try:
|
|
opts.data["multiple_tqdm"] = False
|
|
yield
|
|
finally:
|
|
opts.data["multiple_tqdm"] = orig
|
|
|
|
|
|
@contextmanager
|
|
def preseve_prompts(p: PT):
|
|
all_pt = copy(p.all_prompts)
|
|
all_ng = copy(p.all_negative_prompts)
|
|
try:
|
|
yield
|
|
finally:
|
|
p.all_prompts = all_pt
|
|
p.all_negative_prompts = all_ng
|
|
|
|
|
|
def copy_extra_params(extra_params: dict[str, Any]) -> dict[str, Any]:
|
|
return {k: v for k, v in extra_params.items() if not callable(v)}
|