mirror of
https://github.com/AUTOMATIC1111/stable-diffusion-webui.git
synced 2026-03-15 01:48:20 +00:00
Merge branch 'dev' into gradio4
This commit is contained in:
@@ -230,6 +230,7 @@ class Api:
|
||||
self.add_api_route("/sdapi/v1/realesrgan-models", self.get_realesrgan_models, methods=["GET"], response_model=list[models.RealesrganItem])
|
||||
self.add_api_route("/sdapi/v1/prompt-styles", self.get_prompt_styles, methods=["GET"], response_model=list[models.PromptStyleItem])
|
||||
self.add_api_route("/sdapi/v1/embeddings", self.get_embeddings, methods=["GET"], response_model=models.EmbeddingsResponse)
|
||||
self.add_api_route("/sdapi/v1/refresh-embeddings", self.refresh_embeddings, methods=["POST"])
|
||||
self.add_api_route("/sdapi/v1/refresh-checkpoints", self.refresh_checkpoints, methods=["POST"])
|
||||
self.add_api_route("/sdapi/v1/refresh-vae", self.refresh_vae, methods=["POST"])
|
||||
self.add_api_route("/sdapi/v1/create/embedding", self.create_embedding, methods=["POST"], response_model=models.CreateResponse)
|
||||
@@ -747,6 +748,10 @@ class Api:
|
||||
"skipped": convert_embeddings(db.skipped_embeddings),
|
||||
}
|
||||
|
||||
def refresh_embeddings(self):
|
||||
with self.queue_lock:
|
||||
sd_hijack.model_hijack.embedding_db.load_textual_inversion_embeddings(force_reload=True)
|
||||
|
||||
def refresh_checkpoints(self):
|
||||
with self.queue_lock:
|
||||
shared.refresh_checkpoints()
|
||||
|
||||
@@ -1,7 +1,7 @@
|
||||
import argparse
|
||||
import json
|
||||
import os
|
||||
from modules.paths_internal import models_path, script_path, data_path, extensions_dir, extensions_builtin_dir, sd_default_config, sd_model_file # noqa: F401
|
||||
from modules.paths_internal import normalized_filepath, models_path, script_path, data_path, extensions_dir, extensions_builtin_dir, sd_default_config, sd_model_file # noqa: F401
|
||||
|
||||
parser = argparse.ArgumentParser()
|
||||
|
||||
@@ -19,21 +19,21 @@ parser.add_argument("--skip-install", action='store_true', help="launch.py argum
|
||||
parser.add_argument("--dump-sysinfo", action='store_true', help="launch.py argument: dump limited sysinfo file (without information about extensions, options) to disk and quit")
|
||||
parser.add_argument("--loglevel", type=str, help="log level; one of: CRITICAL, ERROR, WARNING, INFO, DEBUG", default=None)
|
||||
parser.add_argument("--do-not-download-clip", action='store_true', help="do not download CLIP model even if it's not included in the checkpoint")
|
||||
parser.add_argument("--data-dir", type=str, default=os.path.dirname(os.path.dirname(os.path.realpath(__file__))), help="base path where all user data is stored")
|
||||
parser.add_argument("--config", type=str, default=sd_default_config, help="path to config which constructs model",)
|
||||
parser.add_argument("--ckpt", type=str, default=sd_model_file, help="path to checkpoint of stable diffusion model; if specified, this checkpoint will be added to the list of checkpoints and loaded",)
|
||||
parser.add_argument("--ckpt-dir", type=str, default=None, help="Path to directory with stable diffusion checkpoints")
|
||||
parser.add_argument("--vae-dir", type=str, default=None, help="Path to directory with VAE files")
|
||||
parser.add_argument("--gfpgan-dir", type=str, help="GFPGAN directory", default=('./src/gfpgan' if os.path.exists('./src/gfpgan') else './GFPGAN'))
|
||||
parser.add_argument("--gfpgan-model", type=str, help="GFPGAN model file name", default=None)
|
||||
parser.add_argument("--data-dir", type=normalized_filepath, default=os.path.dirname(os.path.dirname(os.path.realpath(__file__))), help="base path where all user data is stored")
|
||||
parser.add_argument("--config", type=normalized_filepath, default=sd_default_config, help="path to config which constructs model",)
|
||||
parser.add_argument("--ckpt", type=normalized_filepath, default=sd_model_file, help="path to checkpoint of stable diffusion model; if specified, this checkpoint will be added to the list of checkpoints and loaded",)
|
||||
parser.add_argument("--ckpt-dir", type=normalized_filepath, default=None, help="Path to directory with stable diffusion checkpoints")
|
||||
parser.add_argument("--vae-dir", type=normalized_filepath, default=None, help="Path to directory with VAE files")
|
||||
parser.add_argument("--gfpgan-dir", type=normalized_filepath, help="GFPGAN directory", default=('./src/gfpgan' if os.path.exists('./src/gfpgan') else './GFPGAN'))
|
||||
parser.add_argument("--gfpgan-model", type=normalized_filepath, help="GFPGAN model file name", default=None)
|
||||
parser.add_argument("--no-half", action='store_true', help="do not switch the model to 16-bit floats")
|
||||
parser.add_argument("--no-half-vae", action='store_true', help="do not switch the VAE model to 16-bit floats")
|
||||
parser.add_argument("--no-progressbar-hiding", action='store_true', help="do not hide progressbar in gradio UI (we hide it because it slows down ML if you have hardware acceleration in browser)")
|
||||
parser.add_argument("--max-batch-count", type=int, default=16, help="maximum batch count value for the UI")
|
||||
parser.add_argument("--embeddings-dir", type=str, default=os.path.join(data_path, 'embeddings'), help="embeddings directory for textual inversion (default: embeddings)")
|
||||
parser.add_argument("--textual-inversion-templates-dir", type=str, default=os.path.join(script_path, 'textual_inversion_templates'), help="directory with textual inversion templates")
|
||||
parser.add_argument("--hypernetwork-dir", type=str, default=os.path.join(models_path, 'hypernetworks'), help="hypernetwork directory")
|
||||
parser.add_argument("--localizations-dir", type=str, default=os.path.join(script_path, 'localizations'), help="localizations directory")
|
||||
parser.add_argument("--embeddings-dir", type=normalized_filepath, default=os.path.join(data_path, 'embeddings'), help="embeddings directory for textual inversion (default: embeddings)")
|
||||
parser.add_argument("--textual-inversion-templates-dir", type=normalized_filepath, default=os.path.join(script_path, 'textual_inversion_templates'), help="directory with textual inversion templates")
|
||||
parser.add_argument("--hypernetwork-dir", type=normalized_filepath, default=os.path.join(models_path, 'hypernetworks'), help="hypernetwork directory")
|
||||
parser.add_argument("--localizations-dir", type=normalized_filepath, default=os.path.join(script_path, 'localizations'), help="localizations directory")
|
||||
parser.add_argument("--allow-code", action='store_true', help="allow custom script execution from webui")
|
||||
parser.add_argument("--medvram", action='store_true', help="enable stable diffusion model optimizations for sacrificing a little speed for low VRM usage")
|
||||
parser.add_argument("--medvram-sdxl", action='store_true', help="enable --medvram optimization just for SDXL models")
|
||||
@@ -48,12 +48,13 @@ parser.add_argument("--ngrok", type=str, help="ngrok authtoken, alternative to g
|
||||
parser.add_argument("--ngrok-region", type=str, help="does not do anything.", default="")
|
||||
parser.add_argument("--ngrok-options", type=json.loads, help='The options to pass to ngrok in JSON format, e.g.: \'{"authtoken_from_env":true, "basic_auth":"user:password", "oauth_provider":"google", "oauth_allow_emails":"user@asdf.com"}\'', default=dict())
|
||||
parser.add_argument("--enable-insecure-extension-access", action='store_true', help="enable extensions tab regardless of other options")
|
||||
parser.add_argument("--codeformer-models-path", type=str, help="Path to directory with codeformer model file(s).", default=os.path.join(models_path, 'Codeformer'))
|
||||
parser.add_argument("--gfpgan-models-path", type=str, help="Path to directory with GFPGAN model file(s).", default=os.path.join(models_path, 'GFPGAN'))
|
||||
parser.add_argument("--esrgan-models-path", type=str, help="Path to directory with ESRGAN model file(s).", default=os.path.join(models_path, 'ESRGAN'))
|
||||
parser.add_argument("--bsrgan-models-path", type=str, help="Path to directory with BSRGAN model file(s).", default=os.path.join(models_path, 'BSRGAN'))
|
||||
parser.add_argument("--realesrgan-models-path", type=str, help="Path to directory with RealESRGAN model file(s).", default=os.path.join(models_path, 'RealESRGAN'))
|
||||
parser.add_argument("--clip-models-path", type=str, help="Path to directory with CLIP model file(s).", default=None)
|
||||
parser.add_argument("--codeformer-models-path", type=normalized_filepath, help="Path to directory with codeformer model file(s).", default=os.path.join(models_path, 'Codeformer'))
|
||||
parser.add_argument("--gfpgan-models-path", type=normalized_filepath, help="Path to directory with GFPGAN model file(s).", default=os.path.join(models_path, 'GFPGAN'))
|
||||
parser.add_argument("--esrgan-models-path", type=normalized_filepath, help="Path to directory with ESRGAN model file(s).", default=os.path.join(models_path, 'ESRGAN'))
|
||||
parser.add_argument("--bsrgan-models-path", type=normalized_filepath, help="Path to directory with BSRGAN model file(s).", default=os.path.join(models_path, 'BSRGAN'))
|
||||
parser.add_argument("--realesrgan-models-path", type=normalized_filepath, help="Path to directory with RealESRGAN model file(s).", default=os.path.join(models_path, 'RealESRGAN'))
|
||||
parser.add_argument("--dat-models-path", type=normalized_filepath, help="Path to directory with DAT model file(s).", default=os.path.join(models_path, 'DAT'))
|
||||
parser.add_argument("--clip-models-path", type=normalized_filepath, help="Path to directory with CLIP model file(s).", default=None)
|
||||
parser.add_argument("--xformers", action='store_true', help="enable xformers for cross attention layers")
|
||||
parser.add_argument("--force-enable-xformers", action='store_true', help="enable xformers for cross attention layers regardless of whether the checking code thinks you can run it; do not make bug reports if this fails to work")
|
||||
parser.add_argument("--xformers-flash-attention", action='store_true', help="enable xformers with Flash Attention to improve reproducibility (supported for SD2.x or variant only)")
|
||||
@@ -83,18 +84,18 @@ parser.add_argument("--freeze-specific-settings", type=str, help='disable editin
|
||||
parser.add_argument("--ui-settings-file", type=str, help="filename to use for ui settings", default=os.path.join(data_path, 'config.json'))
|
||||
parser.add_argument("--gradio-debug", action='store_true', help="launch gradio with --debug option")
|
||||
parser.add_argument("--gradio-auth", type=str, help='set gradio authentication like "username:password"; or comma-delimit multiple like "u1:p1,u2:p2,u3:p3"', default=None)
|
||||
parser.add_argument("--gradio-auth-path", type=str, help='set gradio authentication file path ex. "/path/to/auth/file" same auth format as --gradio-auth', default=None)
|
||||
parser.add_argument("--gradio-auth-path", type=normalized_filepath, help='set gradio authentication file path ex. "/path/to/auth/file" same auth format as --gradio-auth', default=None)
|
||||
parser.add_argument("--gradio-img2img-tool", type=str, help='does not do anything')
|
||||
parser.add_argument("--gradio-inpaint-tool", type=str, help="does not do anything")
|
||||
parser.add_argument("--gradio-allowed-path", action='append', help="add path to gradio's allowed_paths, make it possible to serve files from it", default=[data_path])
|
||||
parser.add_argument("--opt-channelslast", action='store_true', help="change memory type for stable diffusion to channels last")
|
||||
parser.add_argument("--styles-file", type=str, help="filename to use for styles", default=os.path.join(data_path, 'styles.csv'))
|
||||
parser.add_argument("--styles-file", type=str, action='append', help="path or wildcard path of styles files, allow multiple entries.", default=[])
|
||||
parser.add_argument("--autolaunch", action='store_true', help="open the webui URL in the system's default browser upon launch", default=False)
|
||||
parser.add_argument("--theme", type=str, help="launches the UI with light or dark theme", default=None)
|
||||
parser.add_argument("--use-textbox-seed", action='store_true', help="use textbox for seeds in UI (no up/down, but possible to input long seeds)", default=False)
|
||||
parser.add_argument("--disable-console-progressbars", action='store_true', help="do not output progressbars to console", default=False)
|
||||
parser.add_argument("--enable-console-prompts", action='store_true', help="does not do anything", default=False) # Legacy compatibility, use as default value shared.opts.enable_console_prompts
|
||||
parser.add_argument('--vae-path', type=str, help='Checkpoint to use as VAE; setting this argument disables all settings related to VAE', default=None)
|
||||
parser.add_argument('--vae-path', type=normalized_filepath, help='Checkpoint to use as VAE; setting this argument disables all settings related to VAE', default=None)
|
||||
parser.add_argument("--disable-safe-unpickle", action='store_true', help="disable checking pytorch models for malicious code", default=False)
|
||||
parser.add_argument("--api", action='store_true', help="use api=True to launch the API together with the webui (use --nowebui instead for only the API)")
|
||||
parser.add_argument("--api-auth", type=str, help='Set authentication for API like "username:password"; or comma-delimit multiple like "u1:p1,u2:p2,u3:p3"', default=None)
|
||||
@@ -120,4 +121,6 @@ parser.add_argument('--api-server-stop', action='store_true', help='enable serve
|
||||
parser.add_argument('--timeout-keep-alive', type=int, default=30, help='set timeout_keep_alive for uvicorn')
|
||||
parser.add_argument("--disable-all-extensions", action='store_true', help="prevent all extensions from running regardless of any other settings", default=False)
|
||||
parser.add_argument("--disable-extra-extensions", action='store_true', help="prevent all extensions except built-in from running regardless of any other settings", default=False)
|
||||
parser.add_argument("--skip-load-model-at-start", action='store_true', help="if load a model at web start, only take effect when --nowebui", )
|
||||
parser.add_argument("--skip-load-model-at-start", action='store_true', help="if load a model at web start, only take effect when --nowebui")
|
||||
parser.add_argument("--unix-filenames-sanitization", action='store_true', help="allow any symbols except '/' in filenames. May conflict with your browser and file system")
|
||||
parser.add_argument("--filenames-max-length", type=int, default=128, help='maximal length of filenames of saved images. If you override it, it can conflict with your file system')
|
||||
|
||||
@@ -3,8 +3,7 @@ import contextlib
|
||||
from functools import lru_cache
|
||||
|
||||
import torch
|
||||
from modules import errors, shared
|
||||
from modules import torch_utils
|
||||
from modules import errors, shared, npu_specific
|
||||
|
||||
if sys.platform == "darwin":
|
||||
from modules import mac_specific
|
||||
@@ -58,6 +57,9 @@ def get_optimal_device_name():
|
||||
if has_xpu():
|
||||
return xpu_specific.get_xpu_device_string()
|
||||
|
||||
if npu_specific.has_npu:
|
||||
return npu_specific.get_npu_device_string()
|
||||
|
||||
return "cpu"
|
||||
|
||||
|
||||
@@ -85,6 +87,16 @@ def torch_gc():
|
||||
if has_xpu():
|
||||
xpu_specific.torch_xpu_gc()
|
||||
|
||||
if npu_specific.has_npu:
|
||||
torch_npu_set_device()
|
||||
npu_specific.torch_npu_gc()
|
||||
|
||||
|
||||
def torch_npu_set_device():
|
||||
# Work around due to bug in torch_npu, revert me after fixed, @see https://gitee.com/ascend/pytorch/issues/I8KECW?from=project-issue
|
||||
if npu_specific.has_npu:
|
||||
torch.npu.set_device(0)
|
||||
|
||||
|
||||
def enable_tf32():
|
||||
if torch.cuda.is_available():
|
||||
@@ -141,7 +153,12 @@ def manual_cast_forward(target_dtype):
|
||||
args = [arg.to(target_dtype) if isinstance(arg, torch.Tensor) else arg for arg in args]
|
||||
kwargs = {k: v.to(target_dtype) if isinstance(v, torch.Tensor) else v for k, v in kwargs.items()}
|
||||
|
||||
org_dtype = torch_utils.get_param(self).dtype
|
||||
org_dtype = target_dtype
|
||||
for param in self.parameters():
|
||||
if param.dtype != target_dtype:
|
||||
org_dtype = param.dtype
|
||||
break
|
||||
|
||||
if org_dtype != target_dtype:
|
||||
self.to(target_dtype)
|
||||
result = self.org_forward(*args, **kwargs)
|
||||
@@ -170,7 +187,7 @@ def manual_cast(target_dtype):
|
||||
continue
|
||||
applied = True
|
||||
org_forward = module_type.forward
|
||||
if module_type == torch.nn.MultiheadAttention and has_xpu():
|
||||
if module_type == torch.nn.MultiheadAttention:
|
||||
module_type.forward = manual_cast_forward(torch.float32)
|
||||
else:
|
||||
module_type.forward = manual_cast_forward(target_dtype)
|
||||
@@ -252,4 +269,3 @@ def first_time_calculation():
|
||||
x = torch.zeros((1, 1, 3, 3)).to(device, dtype)
|
||||
conv2d = torch.nn.Conv2d(1, 1, (3, 3)).to(device, dtype)
|
||||
conv2d(x)
|
||||
|
||||
|
||||
@@ -21,7 +21,10 @@ def calculate_sha256(filename):
|
||||
|
||||
def sha256_from_cache(filename, title, use_addnet_hash=False):
|
||||
hashes = cache("hashes-addnet") if use_addnet_hash else cache("hashes")
|
||||
ondisk_mtime = os.path.getmtime(filename)
|
||||
try:
|
||||
ondisk_mtime = os.path.getmtime(filename)
|
||||
except FileNotFoundError:
|
||||
return None
|
||||
|
||||
if title not in hashes:
|
||||
return None
|
||||
|
||||
@@ -321,13 +321,16 @@ def resize_image(resize_mode, im, width, height, upscaler_name=None):
|
||||
return res
|
||||
|
||||
|
||||
invalid_filename_chars = '#<>:"/\\|?*\n\r\t'
|
||||
if not shared.cmd_opts.unix_filenames_sanitization:
|
||||
invalid_filename_chars = '#<>:"/\\|?*\n\r\t'
|
||||
else:
|
||||
invalid_filename_chars = '/'
|
||||
invalid_filename_prefix = ' '
|
||||
invalid_filename_postfix = ' .'
|
||||
re_nonletters = re.compile(r'[\s' + string.punctuation + ']+')
|
||||
re_pattern = re.compile(r"(.*?)(?:\[([^\[\]]+)\]|$)")
|
||||
re_pattern_arg = re.compile(r"(.*)<([^>]*)>$")
|
||||
max_filename_part_length = 128
|
||||
max_filename_part_length = shared.cmd_opts.filenames_max_length
|
||||
NOTHING_AND_SKIP_PREVIOUS_TEXT = object()
|
||||
|
||||
|
||||
|
||||
@@ -365,6 +365,12 @@ Steps: 20, Sampler: Euler a, CFG scale: 7, Seed: 965400086, Size: 512x512, Model
|
||||
if "Cache FP16 weight for LoRA" not in res and res["FP8 weight"] != "Disable":
|
||||
res["Cache FP16 weight for LoRA"] = False
|
||||
|
||||
if "Emphasis" not in res:
|
||||
res["Emphasis"] = "Original"
|
||||
|
||||
if "Refiner switch by sampling steps" not in res:
|
||||
res["Refiner switch by sampling steps"] = False
|
||||
|
||||
infotext_versions.backcompat(res)
|
||||
|
||||
for key in skip_fields:
|
||||
|
||||
@@ -5,6 +5,7 @@ import re
|
||||
|
||||
v160 = version.parse("1.6.0")
|
||||
v170_tsnr = version.parse("v1.7.0-225")
|
||||
v180 = version.parse("1.8.0")
|
||||
|
||||
|
||||
def parse_version(text):
|
||||
@@ -31,9 +32,14 @@ def backcompat(d):
|
||||
if ver is None:
|
||||
return
|
||||
|
||||
if ver < v160:
|
||||
if ver < v160 and '[' in d.get('Prompt', ''):
|
||||
d["Old prompt editing timelines"] = True
|
||||
|
||||
if ver < v160 and d.get('Sampler', '') in ('DDIM', 'PLMS'):
|
||||
d["Pad conds v0"] = True
|
||||
|
||||
if ver < v170_tsnr:
|
||||
d["Downcast alphas_cumprod"] = True
|
||||
|
||||
if ver < v180 and d.get('Refiner'):
|
||||
d["Refiner switch by sampling steps"] = True
|
||||
|
||||
@@ -142,13 +142,14 @@ def initialize_rest(*, reload_script_modules=False):
|
||||
its optimization may be None because the list of optimizaers has neet been filled
|
||||
by that time, so we apply optimization again.
|
||||
"""
|
||||
from modules import devices
|
||||
devices.torch_npu_set_device()
|
||||
|
||||
shared.sd_model # noqa: B018
|
||||
|
||||
if sd_hijack.current_optimizer is None:
|
||||
sd_hijack.apply_optimizations()
|
||||
|
||||
from modules import devices
|
||||
devices.first_time_calculation()
|
||||
if not shared.cmd_opts.skip_load_model_at_start:
|
||||
Thread(target=load_model).start()
|
||||
|
||||
@@ -17,7 +17,7 @@ clip_model_name = 'ViT-L/14'
|
||||
|
||||
Category = namedtuple("Category", ["name", "topn", "items"])
|
||||
|
||||
re_topn = re.compile(r"\.top(\d+)\.")
|
||||
re_topn = re.compile(r"\.top(\d+)$")
|
||||
|
||||
def category_types():
|
||||
return [f.stem for f in Path(shared.interrogator.content_dir).glob('*.txt')]
|
||||
|
||||
@@ -55,7 +55,7 @@ and delete current Python and "venv" folder in WebUI's directory.
|
||||
|
||||
You can download 3.10 Python from here: https://www.python.org/downloads/release/python-3106/
|
||||
|
||||
{"Alternatively, use a binary release of WebUI: https://github.com/AUTOMATIC1111/stable-diffusion-webui/releases" if is_windows else ""}
|
||||
{"Alternatively, use a binary release of WebUI: https://github.com/AUTOMATIC1111/stable-diffusion-webui/releases/tag/v1.0.0-pre" if is_windows else ""}
|
||||
|
||||
Use --skip-python-version-check to suppress this warning.
|
||||
""")
|
||||
@@ -188,7 +188,7 @@ def git_clone(url, dir, name, commithash=None):
|
||||
return
|
||||
|
||||
try:
|
||||
run(f'"{git}" clone "{url}" "{dir}"', f"Cloning {name} into {dir}...", f"Couldn't clone {name}", live=True)
|
||||
run(f'"{git}" clone --config core.filemode=false "{url}" "{dir}"', f"Cloning {name} into {dir}...", f"Couldn't clone {name}", live=True)
|
||||
except RuntimeError:
|
||||
shutil.rmtree(dir, ignore_errors=True)
|
||||
raise
|
||||
@@ -251,7 +251,6 @@ def list_extensions(settings_file):
|
||||
except Exception:
|
||||
errors.report(f'\nCould not load settings\nThe config file "{settings_file}" is likely corrupted\nIt has been moved to the "tmp/config.json"\nReverting config to default\n\n''', exc_info=True)
|
||||
os.replace(settings_file, os.path.join(script_path, "tmp", "config.json"))
|
||||
settings = {}
|
||||
|
||||
disabled_extensions = set(settings.get('disabled_extensions', []))
|
||||
disable_all_extensions = settings.get('disable_all_extensions', 'none')
|
||||
@@ -339,6 +338,7 @@ def prepare_environment():
|
||||
torch_index_url = os.environ.get('TORCH_INDEX_URL', "https://pytorch-extension.intel.com/release-whl/stable/xpu/us/")
|
||||
torch_command = os.environ.get('TORCH_COMMAND', f"pip install torch==2.0.0a0 intel-extension-for-pytorch==2.0.110+gitba7f6c1 --extra-index-url {torch_index_url}")
|
||||
requirements_file = os.environ.get('REQS_FILE', "requirements_versions.txt")
|
||||
requirements_file_for_npu = os.environ.get('REQS_FILE_FOR_NPU', "requirements_npu.txt")
|
||||
|
||||
xformers_package = os.environ.get('XFORMERS_PACKAGE', 'xformers==0.0.23.post1')
|
||||
clip_package = os.environ.get('CLIP_PACKAGE', "https://github.com/openai/CLIP/archive/d50d76daa670286dd6cacf3bcd80b5e4823fc8e1.zip")
|
||||
@@ -422,6 +422,13 @@ def prepare_environment():
|
||||
run_pip(f"install -r \"{requirements_file}\"", "requirements")
|
||||
startup_timer.record("install requirements")
|
||||
|
||||
if not os.path.isfile(requirements_file_for_npu):
|
||||
requirements_file_for_npu = os.path.join(script_path, requirements_file_for_npu)
|
||||
|
||||
if "torch_npu" in torch_command and not requirements_met(requirements_file_for_npu):
|
||||
run_pip(f"install -r \"{requirements_file_for_npu}\"", "requirements_for_npu")
|
||||
startup_timer.record("install requirements_for_npu")
|
||||
|
||||
if not args.skip_install:
|
||||
run_extensions_installers(settings_file=args.ui_settings_file)
|
||||
|
||||
|
||||
31
modules/npu_specific.py
Normal file
31
modules/npu_specific.py
Normal file
@@ -0,0 +1,31 @@
|
||||
import importlib
|
||||
import torch
|
||||
|
||||
from modules import shared
|
||||
|
||||
|
||||
def check_for_npu():
|
||||
if importlib.util.find_spec("torch_npu") is None:
|
||||
return False
|
||||
import torch_npu
|
||||
|
||||
try:
|
||||
# Will raise a RuntimeError if no NPU is found
|
||||
_ = torch_npu.npu.device_count()
|
||||
return torch.npu.is_available()
|
||||
except RuntimeError:
|
||||
return False
|
||||
|
||||
|
||||
def get_npu_device_string():
|
||||
if shared.cmd_opts.device_id is not None:
|
||||
return f"npu:{shared.cmd_opts.device_id}"
|
||||
return "npu:0"
|
||||
|
||||
|
||||
def torch_npu_gc():
|
||||
with torch.npu.device(get_npu_device_string()):
|
||||
torch.npu.empty_cache()
|
||||
|
||||
|
||||
has_npu = check_for_npu()
|
||||
@@ -198,6 +198,8 @@ class Options:
|
||||
try:
|
||||
with open(filename, "r", encoding="utf8") as file:
|
||||
self.data = json.load(file)
|
||||
except FileNotFoundError:
|
||||
self.data = {}
|
||||
except Exception:
|
||||
errors.report(f'\nCould not load settings\nThe config file "{filename}" is likely corrupted\nIt has been moved to the "tmp/config.json"\nReverting config to default\n\n''', exc_info=True)
|
||||
os.replace(filename, os.path.join(script_path, "tmp", "config.json"))
|
||||
|
||||
@@ -4,6 +4,10 @@ import argparse
|
||||
import os
|
||||
import sys
|
||||
import shlex
|
||||
from pathlib import Path
|
||||
|
||||
|
||||
normalized_filepath = lambda filepath: str(Path(filepath).absolute())
|
||||
|
||||
commandline_args = os.environ.get('COMMANDLINE_ARGS', "")
|
||||
sys.argv += shlex.split(commandline_args)
|
||||
|
||||
@@ -74,16 +74,18 @@ def uncrop(image, dest_size, paste_loc):
|
||||
|
||||
def apply_overlay(image, paste_loc, overlay):
|
||||
if overlay is None:
|
||||
return image
|
||||
return image, image.copy()
|
||||
|
||||
if paste_loc is not None:
|
||||
image = uncrop(image, (overlay.width, overlay.height), paste_loc)
|
||||
|
||||
original_denoised_image = image.copy()
|
||||
|
||||
image = image.convert('RGBA')
|
||||
image.alpha_composite(overlay)
|
||||
image = image.convert('RGB')
|
||||
|
||||
return image
|
||||
return image, original_denoised_image
|
||||
|
||||
def create_binary_mask(image, round=True):
|
||||
if image.mode == 'RGBA' and image.getextrema()[-1] != (255, 255):
|
||||
@@ -455,6 +457,7 @@ class StableDiffusionProcessing:
|
||||
self.height,
|
||||
opts.fp8_storage,
|
||||
opts.cache_fp16_weight,
|
||||
opts.emphasis,
|
||||
)
|
||||
|
||||
def get_conds_with_caching(self, function, required_prompts, steps, caches, extra_network_data, hires_steps=None):
|
||||
@@ -912,33 +915,7 @@ def process_images_inner(p: StableDiffusionProcessing) -> Processed:
|
||||
if p.n_iter > 1:
|
||||
shared.state.job = f"Batch {n+1} out of {p.n_iter}"
|
||||
|
||||
def rescale_zero_terminal_snr_abar(alphas_cumprod):
|
||||
alphas_bar_sqrt = alphas_cumprod.sqrt()
|
||||
|
||||
# Store old values.
|
||||
alphas_bar_sqrt_0 = alphas_bar_sqrt[0].clone()
|
||||
alphas_bar_sqrt_T = alphas_bar_sqrt[-1].clone()
|
||||
|
||||
# Shift so the last timestep is zero.
|
||||
alphas_bar_sqrt -= (alphas_bar_sqrt_T)
|
||||
|
||||
# Scale so the first timestep is back to the old value.
|
||||
alphas_bar_sqrt *= alphas_bar_sqrt_0 / (alphas_bar_sqrt_0 - alphas_bar_sqrt_T)
|
||||
|
||||
# Convert alphas_bar_sqrt to betas
|
||||
alphas_bar = alphas_bar_sqrt**2 # Revert sqrt
|
||||
alphas_bar[-1] = 4.8973451890853435e-08
|
||||
return alphas_bar
|
||||
|
||||
if hasattr(p.sd_model, 'alphas_cumprod') and hasattr(p.sd_model, 'alphas_cumprod_original'):
|
||||
p.sd_model.alphas_cumprod = p.sd_model.alphas_cumprod_original.to(shared.device)
|
||||
|
||||
if opts.use_downcasted_alpha_bar:
|
||||
p.extra_generation_params['Downcast alphas_cumprod'] = opts.use_downcasted_alpha_bar
|
||||
p.sd_model.alphas_cumprod = p.sd_model.alphas_cumprod.half().to(shared.device)
|
||||
if opts.sd_noise_schedule == "Zero Terminal SNR":
|
||||
p.extra_generation_params['Noise Schedule'] = opts.sd_noise_schedule
|
||||
p.sd_model.alphas_cumprod = rescale_zero_terminal_snr_abar(p.sd_model.alphas_cumprod).to(shared.device)
|
||||
sd_models.apply_alpha_schedule_override(p.sd_model, p)
|
||||
|
||||
with devices.without_autocast() if devices.unet_needs_upcast else devices.autocast():
|
||||
samples_ddim = p.sample(conditioning=p.c, unconditional_conditioning=p.uc, seeds=p.seeds, subseeds=p.subseeds, subseed_strength=p.subseed_strength, prompts=p.prompts)
|
||||
@@ -1020,7 +997,7 @@ def process_images_inner(p: StableDiffusionProcessing) -> Processed:
|
||||
|
||||
if p.color_corrections is not None and i < len(p.color_corrections):
|
||||
if save_samples and opts.save_images_before_color_correction:
|
||||
image_without_cc = apply_overlay(image, p.paste_to, overlay_image)
|
||||
image_without_cc, _ = apply_overlay(image, p.paste_to, overlay_image)
|
||||
images.save_image(image_without_cc, p.outpath_samples, "", p.seeds[i], p.prompts[i], opts.samples_format, info=infotext(i), p=p, suffix="-before-color-correction")
|
||||
image = apply_color_correction(p.color_corrections[i], image)
|
||||
|
||||
@@ -1028,12 +1005,7 @@ def process_images_inner(p: StableDiffusionProcessing) -> Processed:
|
||||
# that is being composited over the original image,
|
||||
# we need to keep the original image around
|
||||
# and use it in the composite step.
|
||||
original_denoised_image = image.copy()
|
||||
|
||||
if p.paste_to is not None:
|
||||
original_denoised_image = uncrop(original_denoised_image, (overlay_image.width, overlay_image.height), p.paste_to)
|
||||
|
||||
image = apply_overlay(image, p.paste_to, overlay_image)
|
||||
image, original_denoised_image = apply_overlay(image, p.paste_to, overlay_image)
|
||||
|
||||
if p.scripts is not None:
|
||||
pp = scripts.PostprocessImageArgs(image)
|
||||
|
||||
42
modules/processing_scripts/comments.py
Normal file
42
modules/processing_scripts/comments.py
Normal file
@@ -0,0 +1,42 @@
|
||||
from modules import scripts, shared, script_callbacks
|
||||
import re
|
||||
|
||||
|
||||
def strip_comments(text):
|
||||
text = re.sub('(^|\n)#[^\n]*(\n|$)', '\n', text) # while line comment
|
||||
text = re.sub('#[^\n]*(\n|$)', '\n', text) # in the middle of the line comment
|
||||
|
||||
return text
|
||||
|
||||
|
||||
class ScriptStripComments(scripts.Script):
|
||||
def title(self):
|
||||
return "Comments"
|
||||
|
||||
def show(self, is_img2img):
|
||||
return scripts.AlwaysVisible
|
||||
|
||||
def process(self, p, *args):
|
||||
if not shared.opts.enable_prompt_comments:
|
||||
return
|
||||
|
||||
p.all_prompts = [strip_comments(x) for x in p.all_prompts]
|
||||
p.all_negative_prompts = [strip_comments(x) for x in p.all_negative_prompts]
|
||||
|
||||
p.main_prompt = strip_comments(p.main_prompt)
|
||||
p.main_negative_prompt = strip_comments(p.main_negative_prompt)
|
||||
|
||||
|
||||
def before_token_counter(params: script_callbacks.BeforeTokenCounterParams):
|
||||
if not shared.opts.enable_prompt_comments:
|
||||
return
|
||||
|
||||
params.prompt = strip_comments(params.prompt)
|
||||
|
||||
|
||||
script_callbacks.on_before_token_counter(before_token_counter)
|
||||
|
||||
|
||||
shared.options_templates.update(shared.options_section(('sd', "Stable Diffusion", "sd"), {
|
||||
"enable_prompt_comments": shared.OptionInfo(True, "Enable comments").info("Use # anywhere in the prompt to hide the text between # and the end of the line from the generation."),
|
||||
}))
|
||||
@@ -1,3 +1,4 @@
|
||||
import dataclasses
|
||||
import inspect
|
||||
import os
|
||||
from collections import namedtuple
|
||||
@@ -106,6 +107,15 @@ class ImageGridLoopParams:
|
||||
self.rows = rows
|
||||
|
||||
|
||||
@dataclasses.dataclass
|
||||
class BeforeTokenCounterParams:
|
||||
prompt: str
|
||||
steps: int
|
||||
styles: list
|
||||
|
||||
is_positive: bool = True
|
||||
|
||||
|
||||
ScriptCallback = namedtuple("ScriptCallback", ["script", "callback"])
|
||||
callback_map = dict(
|
||||
callbacks_app_started=[],
|
||||
@@ -128,6 +138,7 @@ callback_map = dict(
|
||||
callbacks_on_reload=[],
|
||||
callbacks_list_optimizers=[],
|
||||
callbacks_list_unets=[],
|
||||
callbacks_before_token_counter=[],
|
||||
)
|
||||
|
||||
|
||||
@@ -309,6 +320,14 @@ def list_unets_callback():
|
||||
return res
|
||||
|
||||
|
||||
def before_token_counter_callback(params: BeforeTokenCounterParams):
|
||||
for c in callback_map['callbacks_before_token_counter']:
|
||||
try:
|
||||
c.callback(params)
|
||||
except Exception:
|
||||
report_exception(c, 'before_token_counter')
|
||||
|
||||
|
||||
def add_callback(callbacks, fun):
|
||||
stack = [x for x in inspect.stack() if x.filename != __file__]
|
||||
filename = stack[0].filename if stack else 'unknown file'
|
||||
@@ -483,3 +502,10 @@ def on_list_unets(callback):
|
||||
The function will be called with one argument, a list, and shall add objects of type modules.sd_unet.SdUnetOption to it."""
|
||||
|
||||
add_callback(callback_map['callbacks_list_unets'], callback)
|
||||
|
||||
|
||||
def on_before_token_counter(callback):
|
||||
"""register a function to be called when UI is counting tokens for a prompt.
|
||||
The function will be called with one argument of type BeforeTokenCounterParams, and should modify its fields if necessary."""
|
||||
|
||||
add_callback(callback_map['callbacks_before_token_counter'], callback)
|
||||
|
||||
@@ -939,22 +939,34 @@ class ScriptRunner:
|
||||
except Exception:
|
||||
errors.report(f"Error running setup: {script.filename}", exc_info=True)
|
||||
|
||||
def set_named_arg(self, args, script_type, arg_elem_id, value):
|
||||
script = next((x for x in self.scripts if type(x).__name__ == script_type), None)
|
||||
def set_named_arg(self, args, script_name, arg_elem_id, value, fuzzy=False):
|
||||
"""Locate an arg of a specific script in script_args and set its value
|
||||
Args:
|
||||
args: all script args of process p, p.script_args
|
||||
script_name: the name target script name to
|
||||
arg_elem_id: the elem_id of the target arg
|
||||
value: the value to set
|
||||
fuzzy: if True, arg_elem_id can be a substring of the control.elem_id else exact match
|
||||
Returns:
|
||||
Updated script args
|
||||
when script_name in not found or arg_elem_id is not found in script controls, raise RuntimeError
|
||||
"""
|
||||
script = next((x for x in self.scripts if x.name == script_name), None)
|
||||
if script is None:
|
||||
return
|
||||
raise RuntimeError(f"script {script_name} not found")
|
||||
|
||||
for i, control in enumerate(script.controls):
|
||||
if arg_elem_id in control.elem_id:
|
||||
if arg_elem_id in control.elem_id if fuzzy else arg_elem_id == control.elem_id:
|
||||
index = script.args_from + i
|
||||
|
||||
if isinstance(args, list):
|
||||
if isinstance(args, tuple):
|
||||
return args[:index] + (value,) + args[index + 1:]
|
||||
elif isinstance(args, list):
|
||||
args[index] = value
|
||||
return args
|
||||
elif isinstance(args, tuple):
|
||||
return args[:index] + (value,) + args[index+1:]
|
||||
else:
|
||||
return None
|
||||
raise RuntimeError(f"args is not a list or tuple, but {type(args)}")
|
||||
raise RuntimeError(f"arg_elem_id {arg_elem_id} not found in script {script_name}")
|
||||
|
||||
|
||||
scripts_txt2img: ScriptRunner = None
|
||||
|
||||
70
modules/sd_emphasis.py
Normal file
70
modules/sd_emphasis.py
Normal file
@@ -0,0 +1,70 @@
|
||||
from __future__ import annotations
|
||||
import torch
|
||||
|
||||
|
||||
class Emphasis:
|
||||
"""Emphasis class decides how to death with (emphasized:1.1) text in prompts"""
|
||||
|
||||
name: str = "Base"
|
||||
description: str = ""
|
||||
|
||||
tokens: list[list[int]]
|
||||
"""tokens from the chunk of the prompt"""
|
||||
|
||||
multipliers: torch.Tensor
|
||||
"""tensor with multipliers, once for each token"""
|
||||
|
||||
z: torch.Tensor
|
||||
"""output of cond transformers network (CLIP)"""
|
||||
|
||||
def after_transformers(self):
|
||||
"""Called after cond transformers network has processed the chunk of the prompt; this function should modify self.z to apply the emphasis"""
|
||||
|
||||
pass
|
||||
|
||||
|
||||
class EmphasisNone(Emphasis):
|
||||
name = "None"
|
||||
description = "disable the mechanism entirely and treat (:.1.1) as literal characters"
|
||||
|
||||
|
||||
class EmphasisIgnore(Emphasis):
|
||||
name = "Ignore"
|
||||
description = "treat all empasised words as if they have no emphasis"
|
||||
|
||||
|
||||
class EmphasisOriginal(Emphasis):
|
||||
name = "Original"
|
||||
description = "the orginal emphasis implementation"
|
||||
|
||||
def after_transformers(self):
|
||||
original_mean = self.z.mean()
|
||||
self.z = self.z * self.multipliers.reshape(self.multipliers.shape + (1,)).expand(self.z.shape)
|
||||
|
||||
# restoring original mean is likely not correct, but it seems to work well to prevent artifacts that happen otherwise
|
||||
new_mean = self.z.mean()
|
||||
self.z = self.z * (original_mean / new_mean)
|
||||
|
||||
|
||||
class EmphasisOriginalNoNorm(EmphasisOriginal):
|
||||
name = "No norm"
|
||||
description = "same as orginal, but without normalization (seems to work better for SDXL)"
|
||||
|
||||
def after_transformers(self):
|
||||
self.z = self.z * self.multipliers.reshape(self.multipliers.shape + (1,)).expand(self.z.shape)
|
||||
|
||||
|
||||
def get_current_option(emphasis_option_name):
|
||||
return next(iter([x for x in options if x.name == emphasis_option_name]), EmphasisOriginal)
|
||||
|
||||
|
||||
def get_options_descriptions():
|
||||
return ", ".join(f"{x.name}: {x.description}" for x in options)
|
||||
|
||||
|
||||
options = [
|
||||
EmphasisNone,
|
||||
EmphasisIgnore,
|
||||
EmphasisOriginal,
|
||||
EmphasisOriginalNoNorm,
|
||||
]
|
||||
@@ -3,7 +3,7 @@ from collections import namedtuple
|
||||
|
||||
import torch
|
||||
|
||||
from modules import prompt_parser, devices, sd_hijack
|
||||
from modules import prompt_parser, devices, sd_hijack, sd_emphasis
|
||||
from modules.shared import opts
|
||||
|
||||
|
||||
@@ -88,7 +88,7 @@ class FrozenCLIPEmbedderWithCustomWordsBase(torch.nn.Module):
|
||||
Returns the list and the total number of tokens in the prompt.
|
||||
"""
|
||||
|
||||
if opts.enable_emphasis:
|
||||
if opts.emphasis != "None":
|
||||
parsed = prompt_parser.parse_prompt_attention(line)
|
||||
else:
|
||||
parsed = [[line, 1.0]]
|
||||
@@ -249,6 +249,9 @@ class FrozenCLIPEmbedderWithCustomWordsBase(torch.nn.Module):
|
||||
hashes.append(self.hijack.extra_generation_params.get("TI hashes"))
|
||||
self.hijack.extra_generation_params["TI hashes"] = ", ".join(hashes)
|
||||
|
||||
if any(x for x in texts if "(" in x or "[" in x) and opts.emphasis != "Original":
|
||||
self.hijack.extra_generation_params["Emphasis"] = opts.emphasis
|
||||
|
||||
if getattr(self.wrapped, 'return_pooled', False):
|
||||
return torch.hstack(zs), zs[0].pooled
|
||||
else:
|
||||
@@ -274,12 +277,14 @@ class FrozenCLIPEmbedderWithCustomWordsBase(torch.nn.Module):
|
||||
|
||||
pooled = getattr(z, 'pooled', None)
|
||||
|
||||
# restoring original mean is likely not correct, but it seems to work well to prevent artifacts that happen otherwise
|
||||
batch_multipliers = torch.asarray(batch_multipliers).to(devices.device)
|
||||
original_mean = z.mean()
|
||||
z = z * batch_multipliers.reshape(batch_multipliers.shape + (1,)).expand(z.shape)
|
||||
new_mean = z.mean()
|
||||
z = z * (original_mean / new_mean)
|
||||
emphasis = sd_emphasis.get_current_option(opts.emphasis)()
|
||||
emphasis.tokens = remade_batch_tokens
|
||||
emphasis.multipliers = torch.asarray(batch_multipliers).to(devices.device)
|
||||
emphasis.z = z
|
||||
|
||||
emphasis.after_transformers()
|
||||
|
||||
z = emphasis.z
|
||||
|
||||
if pooled is not None:
|
||||
z.pooled = pooled
|
||||
|
||||
@@ -32,7 +32,7 @@ def process_text_old(self: sd_hijack_clip.FrozenCLIPEmbedderWithCustomWordsBase,
|
||||
|
||||
embedding, embedding_length_in_tokens = self.hijack.embedding_db.find_embedding_at_position(tokens, i)
|
||||
|
||||
mult_change = self.token_mults.get(token) if shared.opts.enable_emphasis else None
|
||||
mult_change = self.token_mults.get(token) if shared.opts.emphasis != "None" else None
|
||||
if mult_change is not None:
|
||||
mult *= mult_change
|
||||
i += 1
|
||||
|
||||
@@ -15,6 +15,7 @@ from ldm.util import instantiate_from_config
|
||||
|
||||
from modules import paths, shared, modelloader, devices, script_callbacks, sd_vae, sd_disable_initialization, errors, hashes, sd_models_config, sd_unet, sd_models_xl, cache, extra_networks, processing, lowvram, sd_hijack, patches
|
||||
from modules.timer import Timer
|
||||
from modules.shared import opts
|
||||
import tomesd
|
||||
import numpy as np
|
||||
|
||||
@@ -427,6 +428,8 @@ def load_model_weights(model, checkpoint_info: CheckpointInfo, state_dict, timer
|
||||
devices.dtype_unet = torch.float16
|
||||
timer.record("apply half()")
|
||||
|
||||
apply_alpha_schedule_override(model)
|
||||
|
||||
for module in model.modules():
|
||||
if hasattr(module, 'fp16_weight'):
|
||||
del module.fp16_weight
|
||||
@@ -550,6 +553,48 @@ def repair_config(sd_config):
|
||||
sd_config.model.params.noise_aug_config.params.clip_stats_path = sd_config.model.params.noise_aug_config.params.clip_stats_path.replace("checkpoints/karlo_models", karlo_path)
|
||||
|
||||
|
||||
def rescale_zero_terminal_snr_abar(alphas_cumprod):
|
||||
alphas_bar_sqrt = alphas_cumprod.sqrt()
|
||||
|
||||
# Store old values.
|
||||
alphas_bar_sqrt_0 = alphas_bar_sqrt[0].clone()
|
||||
alphas_bar_sqrt_T = alphas_bar_sqrt[-1].clone()
|
||||
|
||||
# Shift so the last timestep is zero.
|
||||
alphas_bar_sqrt -= (alphas_bar_sqrt_T)
|
||||
|
||||
# Scale so the first timestep is back to the old value.
|
||||
alphas_bar_sqrt *= alphas_bar_sqrt_0 / (alphas_bar_sqrt_0 - alphas_bar_sqrt_T)
|
||||
|
||||
# Convert alphas_bar_sqrt to betas
|
||||
alphas_bar = alphas_bar_sqrt ** 2 # Revert sqrt
|
||||
alphas_bar[-1] = 4.8973451890853435e-08
|
||||
return alphas_bar
|
||||
|
||||
|
||||
def apply_alpha_schedule_override(sd_model, p=None):
|
||||
"""
|
||||
Applies an override to the alpha schedule of the model according to settings.
|
||||
- downcasts the alpha schedule to half precision
|
||||
- rescales the alpha schedule to have zero terminal SNR
|
||||
"""
|
||||
|
||||
if not hasattr(sd_model, 'alphas_cumprod') or not hasattr(sd_model, 'alphas_cumprod_original'):
|
||||
return
|
||||
|
||||
sd_model.alphas_cumprod = sd_model.alphas_cumprod_original.to(shared.device)
|
||||
|
||||
if opts.use_downcasted_alpha_bar:
|
||||
if p is not None:
|
||||
p.extra_generation_params['Downcast alphas_cumprod'] = opts.use_downcasted_alpha_bar
|
||||
sd_model.alphas_cumprod = sd_model.alphas_cumprod.half().to(shared.device)
|
||||
|
||||
if opts.sd_noise_schedule == "Zero Terminal SNR":
|
||||
if p is not None:
|
||||
p.extra_generation_params['Noise Schedule'] = opts.sd_noise_schedule
|
||||
sd_model.alphas_cumprod = rescale_zero_terminal_snr_abar(sd_model.alphas_cumprod).to(shared.device)
|
||||
|
||||
|
||||
sd1_clip_weight = 'cond_stage_model.transformer.text_model.embeddings.token_embedding.weight'
|
||||
sd2_clip_weight = 'cond_stage_model.model.transformer.resblocks.0.attn.in_proj_weight'
|
||||
sdxl_clip_weight = 'conditioner.embedders.1.model.ln_final.weight'
|
||||
|
||||
@@ -53,6 +53,7 @@ class CFGDenoiser(torch.nn.Module):
|
||||
self.step = 0
|
||||
self.image_cfg_scale = None
|
||||
self.padded_cond_uncond = False
|
||||
self.padded_cond_uncond_v0 = False
|
||||
self.sampler = sampler
|
||||
self.model_wrap = None
|
||||
self.p = None
|
||||
@@ -91,11 +92,67 @@ class CFGDenoiser(torch.nn.Module):
|
||||
self.sampler.sampler_extra_args['cond'] = c
|
||||
self.sampler.sampler_extra_args['uncond'] = uc
|
||||
|
||||
def pad_cond_uncond(self, cond, uncond):
|
||||
empty = shared.sd_model.cond_stage_model_empty_prompt
|
||||
num_repeats = (cond.shape[1] - uncond.shape[1]) // empty.shape[1]
|
||||
|
||||
if num_repeats < 0:
|
||||
cond = pad_cond(cond, -num_repeats, empty)
|
||||
self.padded_cond_uncond = True
|
||||
elif num_repeats > 0:
|
||||
uncond = pad_cond(uncond, num_repeats, empty)
|
||||
self.padded_cond_uncond = True
|
||||
|
||||
return cond, uncond
|
||||
|
||||
def pad_cond_uncond_v0(self, cond, uncond):
|
||||
"""
|
||||
Pads the 'uncond' tensor to match the shape of the 'cond' tensor.
|
||||
|
||||
If 'uncond' is a dictionary, it is assumed that the 'crossattn' key holds the tensor to be padded.
|
||||
If 'uncond' is a tensor, it is padded directly.
|
||||
|
||||
If the number of columns in 'uncond' is less than the number of columns in 'cond', the last column of 'uncond'
|
||||
is repeated to match the number of columns in 'cond'.
|
||||
|
||||
If the number of columns in 'uncond' is greater than the number of columns in 'cond', 'uncond' is truncated
|
||||
to match the number of columns in 'cond'.
|
||||
|
||||
Args:
|
||||
cond (torch.Tensor or DictWithShape): The condition tensor to match the shape of 'uncond'.
|
||||
uncond (torch.Tensor or DictWithShape): The tensor to be padded, or a dictionary containing the tensor to be padded.
|
||||
|
||||
Returns:
|
||||
tuple: A tuple containing the 'cond' tensor and the padded 'uncond' tensor.
|
||||
|
||||
Note:
|
||||
This is the padding that was always used in DDIM before version 1.6.0
|
||||
"""
|
||||
|
||||
is_dict_cond = isinstance(uncond, dict)
|
||||
uncond_vec = uncond['crossattn'] if is_dict_cond else uncond
|
||||
|
||||
if uncond_vec.shape[1] < cond.shape[1]:
|
||||
last_vector = uncond_vec[:, -1:]
|
||||
last_vector_repeated = last_vector.repeat([1, cond.shape[1] - uncond_vec.shape[1], 1])
|
||||
uncond_vec = torch.hstack([uncond_vec, last_vector_repeated])
|
||||
self.padded_cond_uncond_v0 = True
|
||||
elif uncond_vec.shape[1] > cond.shape[1]:
|
||||
uncond_vec = uncond_vec[:, :cond.shape[1]]
|
||||
self.padded_cond_uncond_v0 = True
|
||||
|
||||
if is_dict_cond:
|
||||
uncond['crossattn'] = uncond_vec
|
||||
else:
|
||||
uncond = uncond_vec
|
||||
|
||||
return cond, uncond
|
||||
|
||||
def forward(self, x, sigma, uncond, cond, cond_scale, s_min_uncond, image_cond):
|
||||
if state.interrupted or state.skipped:
|
||||
raise sd_samplers_common.InterruptedException
|
||||
|
||||
if sd_samplers_common.apply_refiner(self):
|
||||
if sd_samplers_common.apply_refiner(self, sigma):
|
||||
cond = self.sampler.sampler_extra_args['cond']
|
||||
uncond = self.sampler.sampler_extra_args['uncond']
|
||||
|
||||
@@ -162,16 +219,11 @@ class CFGDenoiser(torch.nn.Module):
|
||||
sigma_in = sigma_in[:-batch_size]
|
||||
|
||||
self.padded_cond_uncond = False
|
||||
if shared.opts.pad_cond_uncond and tensor.shape[1] != uncond.shape[1]:
|
||||
empty = shared.sd_model.cond_stage_model_empty_prompt
|
||||
num_repeats = (tensor.shape[1] - uncond.shape[1]) // empty.shape[1]
|
||||
|
||||
if num_repeats < 0:
|
||||
tensor = pad_cond(tensor, -num_repeats, empty)
|
||||
self.padded_cond_uncond = True
|
||||
elif num_repeats > 0:
|
||||
uncond = pad_cond(uncond, num_repeats, empty)
|
||||
self.padded_cond_uncond = True
|
||||
self.padded_cond_uncond_v0 = False
|
||||
if shared.opts.pad_cond_uncond_v0 and tensor.shape[1] != uncond.shape[1]:
|
||||
tensor, uncond = self.pad_cond_uncond_v0(tensor, uncond)
|
||||
elif shared.opts.pad_cond_uncond and tensor.shape[1] != uncond.shape[1]:
|
||||
tensor, uncond = self.pad_cond_uncond(tensor, uncond)
|
||||
|
||||
if tensor.shape[1] == uncond.shape[1] or skip_uncond:
|
||||
if is_edit_model:
|
||||
|
||||
@@ -155,8 +155,19 @@ def replace_torchsde_browinan():
|
||||
replace_torchsde_browinan()
|
||||
|
||||
|
||||
def apply_refiner(cfg_denoiser):
|
||||
completed_ratio = cfg_denoiser.step / cfg_denoiser.total_steps
|
||||
def apply_refiner(cfg_denoiser, sigma=None):
|
||||
if opts.refiner_switch_by_sample_steps or not sigma:
|
||||
completed_ratio = cfg_denoiser.step / cfg_denoiser.total_steps
|
||||
cfg_denoiser.p.extra_generation_params["Refiner switch by sampling steps"] = True
|
||||
|
||||
else:
|
||||
# torch.max(sigma) only to handle rare case where we might have different sigmas in the same batch
|
||||
try:
|
||||
timestep = torch.argmin(torch.abs(cfg_denoiser.inner_model.sigmas - torch.max(sigma)))
|
||||
except AttributeError: # for samplers that don't use sigmas (DDIM) sigma is actually the timestep
|
||||
timestep = torch.max(sigma).to(dtype=int)
|
||||
completed_ratio = (999 - timestep) / 1000
|
||||
|
||||
refiner_switch_at = cfg_denoiser.p.refiner_switch_at
|
||||
refiner_checkpoint_info = cfg_denoiser.p.refiner_checkpoint_info
|
||||
|
||||
@@ -335,3 +346,10 @@ class Sampler:
|
||||
|
||||
def sample_img2img(self, p, x, noise, conditioning, unconditional_conditioning, steps=None, image_conditioning=None):
|
||||
raise NotImplementedError()
|
||||
|
||||
def add_infotext(self, p):
|
||||
if self.model_wrap_cfg.padded_cond_uncond:
|
||||
p.extra_generation_params["Pad conds"] = True
|
||||
|
||||
if self.model_wrap_cfg.padded_cond_uncond_v0:
|
||||
p.extra_generation_params["Pad conds v0"] = True
|
||||
|
||||
@@ -187,8 +187,7 @@ class KDiffusionSampler(sd_samplers_common.Sampler):
|
||||
|
||||
samples = self.launch_sampling(t_enc + 1, lambda: self.func(self.model_wrap_cfg, xi, extra_args=self.sampler_extra_args, disable=False, callback=self.callback_state, **extra_params_kwargs))
|
||||
|
||||
if self.model_wrap_cfg.padded_cond_uncond:
|
||||
p.extra_generation_params["Pad conds"] = True
|
||||
self.add_infotext(p)
|
||||
|
||||
return samples
|
||||
|
||||
@@ -234,8 +233,7 @@ class KDiffusionSampler(sd_samplers_common.Sampler):
|
||||
|
||||
samples = self.launch_sampling(steps, lambda: self.func(self.model_wrap_cfg, x, extra_args=self.sampler_extra_args, disable=False, callback=self.callback_state, **extra_params_kwargs))
|
||||
|
||||
if self.model_wrap_cfg.padded_cond_uncond:
|
||||
p.extra_generation_params["Pad conds"] = True
|
||||
self.add_infotext(p)
|
||||
|
||||
return samples
|
||||
|
||||
|
||||
@@ -133,8 +133,7 @@ class CompVisSampler(sd_samplers_common.Sampler):
|
||||
|
||||
samples = self.launch_sampling(t_enc + 1, lambda: self.func(self.model_wrap_cfg, xi, extra_args=self.sampler_extra_args, disable=False, callback=self.callback_state, **extra_params_kwargs))
|
||||
|
||||
if self.model_wrap_cfg.padded_cond_uncond:
|
||||
p.extra_generation_params["Pad conds"] = True
|
||||
self.add_infotext(p)
|
||||
|
||||
return samples
|
||||
|
||||
@@ -158,8 +157,7 @@ class CompVisSampler(sd_samplers_common.Sampler):
|
||||
}
|
||||
samples = self.launch_sampling(steps, lambda: self.func(self.model_wrap_cfg, x, extra_args=self.sampler_extra_args, disable=False, callback=self.callback_state, **extra_params_kwargs))
|
||||
|
||||
if self.model_wrap_cfg.padded_cond_uncond:
|
||||
p.extra_generation_params["Pad conds"] = True
|
||||
self.add_infotext(p)
|
||||
|
||||
return samples
|
||||
|
||||
|
||||
@@ -1,3 +1,4 @@
|
||||
import os
|
||||
import sys
|
||||
|
||||
import gradio as gr
|
||||
@@ -11,7 +12,7 @@ parser = shared_cmd_options.parser
|
||||
|
||||
batch_cond_uncond = True # old field, unused now in favor of shared.opts.batch_cond_uncond
|
||||
parallel_processing_allowed = True
|
||||
styles_filename = cmd_opts.styles_file
|
||||
styles_filename = cmd_opts.styles_file = cmd_opts.styles_file if len(cmd_opts.styles_file) > 0 else [os.path.join(data_path, 'styles.csv')]
|
||||
config_filename = cmd_opts.ui_settings_file
|
||||
hide_dirs = {"visible": not cmd_opts.hide_ui_dir_config}
|
||||
|
||||
|
||||
@@ -1,7 +1,7 @@
|
||||
import os
|
||||
import gradio as gr
|
||||
|
||||
from modules import localization, ui_components, shared_items, shared, interrogate, shared_gradio_themes, util
|
||||
from modules import localization, ui_components, shared_items, shared, interrogate, shared_gradio_themes, util, sd_emphasis
|
||||
from modules.paths_internal import models_path, script_path, data_path, sd_configs_path, sd_default_config, sd_model_file, default_sd_model_file, extensions_dir, extensions_builtin_dir, default_output_dir # noqa: F401
|
||||
from modules.shared_cmd_options import cmd_opts
|
||||
from modules.options import options_section, OptionInfo, OptionHTML, categories
|
||||
@@ -154,7 +154,7 @@ options_templates.update(options_section(('sd', "Stable Diffusion", "sd"), {
|
||||
"sd_checkpoint_cache": OptionInfo(0, "Checkpoints to cache in RAM", gr.Slider, {"minimum": 0, "maximum": 10, "step": 1}).info("obsolete; set to 0 and use the two settings above instead"),
|
||||
"sd_unet": OptionInfo("Automatic", "SD Unet", gr.Dropdown, lambda: {"choices": shared_items.sd_unet_items()}, refresh=shared_items.refresh_unet_list).info("choose Unet model: Automatic = use one with same filename as checkpoint; None = use Unet from checkpoint"),
|
||||
"enable_quantization": OptionInfo(False, "Enable quantization in K samplers for sharper and cleaner results. This may change existing seeds").needs_reload_ui(),
|
||||
"enable_emphasis": OptionInfo(True, "Enable emphasis").info("use (text) to make model pay more attention to text and [text] to make it pay less attention"),
|
||||
"emphasis": OptionInfo("Original", "Emphasis mode", gr.Radio, lambda: {"choices": [x.name for x in sd_emphasis.options]}, infotext="Emphasis").info("makes it possible to make model to pay (more:1.1) or (less:0.9) attention to text when you use the syntax in prompt; " + sd_emphasis.get_options_descriptions()),
|
||||
"enable_batch_seeds": OptionInfo(True, "Make K-diffusion samplers produce same images in a batch as when making a single image"),
|
||||
"comma_padding_backtrack": OptionInfo(20, "Prompt word wrap length limit", gr.Slider, {"minimum": 0, "maximum": 74, "step": 1}).info("in tokens - for texts shorter than specified, if they don't fit into 75 token limit, move them to the next 75 token chunk"),
|
||||
"CLIP_stop_at_last_layers": OptionInfo(1, "Clip skip", gr.Slider, {"minimum": 1, "maximum": 12, "step": 1}, infotext="Clip skip").link("wiki", "https://github.com/AUTOMATIC1111/stable-diffusion-webui/wiki/Features#clip-skip").info("ignore last layers of CLIP network; 1 ignores none, 2 ignores one layer"),
|
||||
@@ -209,7 +209,8 @@ options_templates.update(options_section(('optimizations', "Optimizations", "sd"
|
||||
"token_merging_ratio": OptionInfo(0.0, "Token merging ratio", gr.Slider, {"minimum": 0.0, "maximum": 0.9, "step": 0.1}, infotext='Token merging ratio').link("PR", "https://github.com/AUTOMATIC1111/stable-diffusion-webui/pull/9256").info("0=disable, higher=faster"),
|
||||
"token_merging_ratio_img2img": OptionInfo(0.0, "Token merging ratio for img2img", gr.Slider, {"minimum": 0.0, "maximum": 0.9, "step": 0.1}).info("only applies if non-zero and overrides above"),
|
||||
"token_merging_ratio_hr": OptionInfo(0.0, "Token merging ratio for high-res pass", gr.Slider, {"minimum": 0.0, "maximum": 0.9, "step": 0.1}, infotext='Token merging ratio hr').info("only applies if non-zero and overrides above"),
|
||||
"pad_cond_uncond": OptionInfo(False, "Pad prompt/negative prompt to be same length", infotext='Pad conds').info("improves performance when prompt and negative prompt have different lengths; changes seeds"),
|
||||
"pad_cond_uncond": OptionInfo(False, "Pad prompt/negative prompt", infotext='Pad conds').info("improves performance when prompt and negative prompt have different lengths; changes seeds"),
|
||||
"pad_cond_uncond_v0": OptionInfo(False, "Pad prompt/negative prompt (v0)", infotext='Pad conds v0').info("alternative implementation for the above; used prior to 1.6.0 for DDIM sampler; overrides the above if set; WARNING: truncates negative prompt if it's too long; changes seeds"),
|
||||
"persistent_cond_cache": OptionInfo(True, "Persistent cond cache").info("do not recalculate conds from prompts if prompts have not changed since previous calculation"),
|
||||
"batch_cond_uncond": OptionInfo(True, "Batch cond/uncond").info("do both conditional and unconditional denoising in one batch; uses a bit more VRAM during sampling, but improves speed; previously this was controlled by --always-batch-cond-uncond comandline argument"),
|
||||
"fp8_storage": OptionInfo("Disable", "FP8 weight", gr.Radio, {"choices": ["Disable", "Enable for SDXL", "Enable"]}).info("Use FP8 to store Linear/Conv layers' weight. Require pytorch>=2.1.0."),
|
||||
@@ -225,7 +226,8 @@ options_templates.update(options_section(('compatibility', "Compatibility", "sd"
|
||||
"dont_fix_second_order_samplers_schedule": OptionInfo(False, "Do not fix prompt schedule for second order samplers."),
|
||||
"hires_fix_use_firstpass_conds": OptionInfo(False, "For hires fix, calculate conds of second pass using extra networks of first pass."),
|
||||
"use_old_scheduling": OptionInfo(False, "Use old prompt editing timelines.", infotext="Old prompt editing timelines").info("For [red:green:N]; old: If N < 1, it's a fraction of steps (and hires fix uses range from 0 to 1), if N >= 1, it's an absolute number of steps; new: If N has a decimal point in it, it's a fraction of steps (and hires fix uses range from 1 to 2), othewrwise it's an absolute number of steps"),
|
||||
"use_downcasted_alpha_bar": OptionInfo(False, "Downcast model alphas_cumprod to fp16 before sampling. For reproducing old seeds.", infotext="Downcast alphas_cumprod")
|
||||
"use_downcasted_alpha_bar": OptionInfo(False, "Downcast model alphas_cumprod to fp16 before sampling. For reproducing old seeds.", infotext="Downcast alphas_cumprod"),
|
||||
"refiner_switch_by_sample_steps": OptionInfo(False, "Switch to refiner by sampling steps instead of model timesteps. Old behavior for refiner.", infotext="Refiner switch by sampling steps")
|
||||
}))
|
||||
|
||||
options_templates.update(options_section(('interrogate', "Interrogate"), {
|
||||
@@ -252,9 +254,11 @@ options_templates.update(options_section(('extra_networks', "Extra Networks", "s
|
||||
"extra_networks_card_height": OptionInfo(0, "Card height for Extra Networks").info("in pixels"),
|
||||
"extra_networks_card_text_scale": OptionInfo(1.0, "Card text scale", gr.Slider, {"minimum": 0.0, "maximum": 2.0, "step": 0.01}).info("1 = original size"),
|
||||
"extra_networks_card_show_desc": OptionInfo(True, "Show description on card"),
|
||||
"extra_networks_card_description_is_html": OptionInfo(False, "Treat card description as HTML"),
|
||||
"extra_networks_card_order_field": OptionInfo("Path", "Default order field for Extra Networks cards", gr.Dropdown, {"choices": ['Path', 'Name', 'Date Created', 'Date Modified']}).needs_reload_ui(),
|
||||
"extra_networks_card_order": OptionInfo("Ascending", "Default order for Extra Networks cards", gr.Dropdown, {"choices": ['Ascending', 'Descending']}).needs_reload_ui(),
|
||||
"extra_networks_tree_view_default_enabled": OptionInfo(False, "Enables the Extra Networks directory tree view by default").needs_reload_ui(),
|
||||
"extra_networks_tree_view_default_width": OptionInfo(180, "Default width for the Extra Networks directory tree view", gr.Number).needs_reload_ui(),
|
||||
"extra_networks_add_text_separator": OptionInfo(" ", "Extra networks separator").info("extra text to add before <...> when adding extra network to prompt"),
|
||||
"ui_extra_networks_tab_reorder": OptionInfo("", "Extra networks tab order").needs_reload_ui(),
|
||||
"textual_inversion_print_at_load": OptionInfo(False, "Print a list of Textual Inversion embeddings when loading model"),
|
||||
@@ -268,7 +272,8 @@ options_templates.update(options_section(('ui_prompt_editing', "Prompt editing",
|
||||
"keyedit_delimiters": OptionInfo(r".,\/!?%^*;:{}=`~() ", "Word delimiters when editing the prompt with Ctrl+up/down"),
|
||||
"keyedit_delimiters_whitespace": OptionInfo(["Tab", "Carriage Return", "Line Feed"], "Ctrl+up/down whitespace delimiters", gr.CheckboxGroup, lambda: {"choices": ["Tab", "Carriage Return", "Line Feed"]}),
|
||||
"keyedit_move": OptionInfo(True, "Alt+left/right moves prompt elements"),
|
||||
"disable_token_counters": OptionInfo(False, "Disable prompt token counters").needs_reload_ui(),
|
||||
"disable_token_counters": OptionInfo(False, "Disable prompt token counters"),
|
||||
"include_styles_into_token_counters": OptionInfo(True, "Count tokens of enabled styles").info("When calculating how many tokens the prompt has, also consider tokens added by enabled styles."),
|
||||
}))
|
||||
|
||||
options_templates.update(options_section(('ui_gallery', "Gallery", "ui"), {
|
||||
@@ -281,6 +286,7 @@ options_templates.update(options_section(('ui_gallery', "Gallery", "ui"), {
|
||||
"sd_webui_modal_lightbox_icon_opacity": OptionInfo(1, "Full page image viewer: control icon unfocused opacity", gr.Slider, {"minimum": 0.0, "maximum": 1, "step": 0.01}, onchange=shared.reload_gradio_theme).info('for mouse only').needs_reload_ui(),
|
||||
"sd_webui_modal_lightbox_toolbar_opacity": OptionInfo(0.9, "Full page image viewer: tool bar opacity", gr.Slider, {"minimum": 0.0, "maximum": 1, "step": 0.01}, onchange=shared.reload_gradio_theme).info('for mouse only').needs_reload_ui(),
|
||||
"gallery_height": OptionInfo("", "Gallery height", gr.Textbox).info("can be any valid CSS value, for example 768px or 20em").needs_reload_ui(),
|
||||
"open_dir_button_choice": OptionInfo("Subdirectory", "What directory the [📂] button opens", gr.Radio, {"choices": ["Output Root", "Subdirectory", "Subdirectory (even temp dir)"]}),
|
||||
}))
|
||||
|
||||
options_templates.update(options_section(('ui_alternatives', "UI alternatives", "ui"), {
|
||||
|
||||
@@ -1,16 +1,16 @@
|
||||
from pathlib import Path
|
||||
from modules import errors
|
||||
import csv
|
||||
import fnmatch
|
||||
import os
|
||||
import os.path
|
||||
import typing
|
||||
import shutil
|
||||
|
||||
|
||||
class PromptStyle(typing.NamedTuple):
|
||||
name: str
|
||||
prompt: str
|
||||
negative_prompt: str
|
||||
path: str = None
|
||||
prompt: str | None
|
||||
negative_prompt: str | None
|
||||
path: str | None = None
|
||||
|
||||
|
||||
def merge_prompts(style_prompt: str, prompt: str) -> str:
|
||||
@@ -79,14 +79,19 @@ def extract_original_prompts(style: PromptStyle, prompt, negative_prompt):
|
||||
|
||||
|
||||
class StyleDatabase:
|
||||
def __init__(self, path: str):
|
||||
def __init__(self, paths: list[str | Path]):
|
||||
self.no_style = PromptStyle("None", "", "", None)
|
||||
self.styles = {}
|
||||
self.path = path
|
||||
self.paths = paths
|
||||
self.all_styles_files: list[Path] = []
|
||||
|
||||
folder, file = os.path.split(self.path)
|
||||
filename, _, ext = file.partition('*')
|
||||
self.default_path = os.path.join(folder, filename + ext)
|
||||
folder, file = os.path.split(self.paths[0])
|
||||
if '*' in file or '?' in file:
|
||||
# if the first path is a wildcard pattern, find the first match else use "folder/styles.csv" as the default path
|
||||
self.default_path = next(Path(folder).glob(file), Path(os.path.join(folder, 'styles.csv')))
|
||||
self.paths.insert(0, self.default_path)
|
||||
else:
|
||||
self.default_path = Path(self.paths[0])
|
||||
|
||||
self.prompt_fields = [field for field in PromptStyle._fields if field != "path"]
|
||||
|
||||
@@ -99,57 +104,58 @@ class StyleDatabase:
|
||||
"""
|
||||
self.styles.clear()
|
||||
|
||||
path, filename = os.path.split(self.path)
|
||||
# scans for all styles files
|
||||
all_styles_files = []
|
||||
for pattern in self.paths:
|
||||
folder, file = os.path.split(pattern)
|
||||
if '*' in file or '?' in file:
|
||||
found_files = Path(folder).glob(file)
|
||||
[all_styles_files.append(file) for file in found_files]
|
||||
else:
|
||||
# if os.path.exists(pattern):
|
||||
all_styles_files.append(Path(pattern))
|
||||
|
||||
if "*" in filename:
|
||||
fileglob = filename.split("*")[0] + "*.csv"
|
||||
filelist = []
|
||||
for file in os.listdir(path):
|
||||
if fnmatch.fnmatch(file, fileglob):
|
||||
filelist.append(file)
|
||||
# Add a visible divider to the style list
|
||||
half_len = round(len(file) / 2)
|
||||
divider = f"{'-' * (20 - half_len)} {file.upper()}"
|
||||
divider = f"{divider} {'-' * (40 - len(divider))}"
|
||||
self.styles[divider] = PromptStyle(
|
||||
f"{divider}", None, None, "do_not_save"
|
||||
# Remove any duplicate entries
|
||||
seen = set()
|
||||
self.all_styles_files = [s for s in all_styles_files if not (s in seen or seen.add(s))]
|
||||
|
||||
for styles_file in self.all_styles_files:
|
||||
if len(all_styles_files) > 1:
|
||||
# add divider when more than styles file
|
||||
# '---------------- STYLES ----------------'
|
||||
divider = f' {styles_file.stem.upper()} '.center(40, '-')
|
||||
self.styles[divider] = PromptStyle(f"{divider}", None, None, "do_not_save")
|
||||
if styles_file.is_file():
|
||||
self.load_from_csv(styles_file)
|
||||
|
||||
def load_from_csv(self, path: str | Path):
|
||||
try:
|
||||
with open(path, "r", encoding="utf-8-sig", newline="") as file:
|
||||
reader = csv.DictReader(file, skipinitialspace=True)
|
||||
for row in reader:
|
||||
# Ignore empty rows or rows starting with a comment
|
||||
if not row or row["name"].startswith("#"):
|
||||
continue
|
||||
# Support loading old CSV format with "name, text"-columns
|
||||
prompt = row["prompt"] if "prompt" in row else row["text"]
|
||||
negative_prompt = row.get("negative_prompt", "")
|
||||
# Add style to database
|
||||
self.styles[row["name"]] = PromptStyle(
|
||||
row["name"], prompt, negative_prompt, str(path)
|
||||
)
|
||||
# Add styles from this CSV file
|
||||
self.load_from_csv(os.path.join(path, file))
|
||||
if len(filelist) == 0:
|
||||
print(f"No styles found in {path} matching {fileglob}")
|
||||
return
|
||||
elif not os.path.exists(self.path):
|
||||
print(f"Style database not found: {self.path}")
|
||||
return
|
||||
else:
|
||||
self.load_from_csv(self.path)
|
||||
|
||||
def load_from_csv(self, path: str):
|
||||
with open(path, "r", encoding="utf-8-sig", newline="") as file:
|
||||
reader = csv.DictReader(file, skipinitialspace=True)
|
||||
for row in reader:
|
||||
# Ignore empty rows or rows starting with a comment
|
||||
if not row or row["name"].startswith("#"):
|
||||
continue
|
||||
# Support loading old CSV format with "name, text"-columns
|
||||
prompt = row["prompt"] if "prompt" in row else row["text"]
|
||||
negative_prompt = row.get("negative_prompt", "")
|
||||
# Add style to database
|
||||
self.styles[row["name"]] = PromptStyle(
|
||||
row["name"], prompt, negative_prompt, path
|
||||
)
|
||||
except Exception:
|
||||
errors.report(f'Error loading styles from {path}: ', exc_info=True)
|
||||
|
||||
def get_style_paths(self) -> set:
|
||||
"""Returns a set of all distinct paths of files that styles are loaded from."""
|
||||
# Update any styles without a path to the default path
|
||||
for style in list(self.styles.values()):
|
||||
if not style.path:
|
||||
self.styles[style.name] = style._replace(path=self.default_path)
|
||||
self.styles[style.name] = style._replace(path=str(self.default_path))
|
||||
|
||||
# Create a list of all distinct paths, including the default path
|
||||
style_paths = set()
|
||||
style_paths.add(self.default_path)
|
||||
style_paths.add(str(self.default_path))
|
||||
for _, style in self.styles.items():
|
||||
if style.path:
|
||||
style_paths.add(style.path)
|
||||
@@ -177,7 +183,6 @@ class StyleDatabase:
|
||||
|
||||
def save_styles(self, path: str = None) -> None:
|
||||
# The path argument is deprecated, but kept for backwards compatibility
|
||||
_ = path
|
||||
|
||||
style_paths = self.get_style_paths()
|
||||
|
||||
|
||||
@@ -150,6 +150,7 @@ class EmbeddingDatabase:
|
||||
return embedding
|
||||
|
||||
def get_expected_shape(self):
|
||||
devices.torch_npu_set_device()
|
||||
vec = shared.sd_model.cond_stage_model.encode_embedding_init_text(",", 1)
|
||||
return vec.shape[1]
|
||||
|
||||
|
||||
@@ -60,10 +60,10 @@ def txt2img_upscale(id_task: str, request: gr.Request, gallery, gallery_index, g
|
||||
assert len(gallery) > 0, 'No image to upscale'
|
||||
assert 0 <= gallery_index < len(gallery), f'Bad image index: {gallery_index}'
|
||||
|
||||
p = txt2img_create_processing(id_task, request, *args)
|
||||
p.enable_hr = True
|
||||
p = txt2img_create_processing(id_task, request, *args, force_enable_hr=True)
|
||||
p.batch_size = 1
|
||||
p.n_iter = 1
|
||||
# txt2img_upscale attribute that signifies this is called by txt2img_upscale
|
||||
p.txt2img_upscale = True
|
||||
|
||||
geninfo = json.loads(generation_info)
|
||||
|
||||
@@ -152,7 +152,18 @@ def connect_clear_prompt(button):
|
||||
)
|
||||
|
||||
|
||||
def update_token_counter(text, steps, *, is_positive=True):
|
||||
def update_token_counter(text, steps, styles, *, is_positive=True):
|
||||
params = script_callbacks.BeforeTokenCounterParams(text, steps, styles, is_positive=is_positive)
|
||||
script_callbacks.before_token_counter_callback(params)
|
||||
text = params.prompt
|
||||
steps = params.steps
|
||||
styles = params.styles
|
||||
is_positive = params.is_positive
|
||||
|
||||
if shared.opts.include_styles_into_token_counters:
|
||||
apply_styles = shared.prompt_styles.apply_styles_to_prompt if is_positive else shared.prompt_styles.apply_negative_styles_to_prompt
|
||||
text = apply_styles(text, styles)
|
||||
|
||||
try:
|
||||
text, _ = extra_networks.parse_prompt(text)
|
||||
|
||||
@@ -174,8 +185,8 @@ def update_token_counter(text, steps, *, is_positive=True):
|
||||
return f"<span class='gr-box gr-text-input'>{token_count}/{max_length}</span>"
|
||||
|
||||
|
||||
def update_negative_prompt_token_counter(text, steps):
|
||||
return update_token_counter(text, steps, is_positive=False)
|
||||
def update_negative_prompt_token_counter(*args):
|
||||
return update_token_counter(*args, is_positive=False)
|
||||
|
||||
|
||||
def setup_progressbar(*args, **kwargs):
|
||||
@@ -487,8 +498,10 @@ def create_ui():
|
||||
height,
|
||||
]
|
||||
|
||||
toprow.token_button.click(fn=wrap_queued_call(update_token_counter), inputs=[toprow.prompt, steps], outputs=[toprow.token_counter])
|
||||
toprow.negative_token_button.click(fn=wrap_queued_call(update_negative_prompt_token_counter), inputs=[toprow.negative_prompt, steps], outputs=[toprow.negative_token_counter])
|
||||
toprow.ui_styles.dropdown.change(fn=wrap_queued_call(update_token_counter), inputs=[toprow.prompt, steps, toprow.ui_styles.dropdown], outputs=[toprow.token_counter])
|
||||
toprow.ui_styles.dropdown.change(fn=wrap_queued_call(update_negative_prompt_token_counter), inputs=[toprow.negative_prompt, steps, toprow.ui_styles.dropdown], outputs=[toprow.negative_token_counter])
|
||||
toprow.token_button.click(fn=wrap_queued_call(update_token_counter), inputs=[toprow.prompt, steps, toprow.ui_styles.dropdown], outputs=[toprow.token_counter])
|
||||
toprow.negative_token_button.click(fn=wrap_queued_call(update_negative_prompt_token_counter), inputs=[toprow.negative_prompt, steps, toprow.ui_styles.dropdown], outputs=[toprow.negative_token_counter])
|
||||
|
||||
extra_networks_ui = ui_extra_networks.create_ui(txt2img_interface, [txt2img_generation_tab], 'txt2img')
|
||||
ui_extra_networks.setup_ui(extra_networks_ui, output_panel.gallery)
|
||||
@@ -812,8 +825,10 @@ def create_ui():
|
||||
**interrogate_args,
|
||||
)
|
||||
|
||||
toprow.token_button.click(fn=update_token_counter, inputs=[toprow.prompt, steps], outputs=[toprow.token_counter])
|
||||
toprow.negative_token_button.click(fn=wrap_queued_call(update_token_counter), inputs=[toprow.negative_prompt, steps], outputs=[toprow.negative_token_counter])
|
||||
toprow.ui_styles.dropdown.change(fn=wrap_queued_call(update_token_counter), inputs=[toprow.prompt, steps, toprow.ui_styles.dropdown], outputs=[toprow.token_counter])
|
||||
toprow.ui_styles.dropdown.change(fn=wrap_queued_call(update_negative_prompt_token_counter), inputs=[toprow.negative_prompt, steps, toprow.ui_styles.dropdown], outputs=[toprow.negative_token_counter])
|
||||
toprow.token_button.click(fn=update_token_counter, inputs=[toprow.prompt, steps, toprow.ui_styles.dropdown], outputs=[toprow.token_counter])
|
||||
toprow.negative_token_button.click(fn=wrap_queued_call(update_negative_prompt_token_counter), inputs=[toprow.negative_prompt, steps, toprow.ui_styles.dropdown], outputs=[toprow.negative_token_counter])
|
||||
|
||||
img2img_paste_fields = [
|
||||
(toprow.prompt, "Prompt"),
|
||||
@@ -851,7 +866,7 @@ def create_ui():
|
||||
ui_postprocessing.create_ui()
|
||||
|
||||
with gr.Blocks(analytics_enabled=False) as pnginfo_interface:
|
||||
with gr.Row(equal_height=False):
|
||||
with ResizeHandleRow(equal_height=False):
|
||||
with gr.Column(variant='panel'):
|
||||
image = gr.Image(elem_id="pnginfo_image", label="Source", source="upload", interactive=True, type="pil")
|
||||
|
||||
@@ -879,7 +894,7 @@ def create_ui():
|
||||
with gr.Row(equal_height=False):
|
||||
gr.HTML(value="<p style='margin-bottom: 0.7em'>See <b><a href=\"https://github.com/AUTOMATIC1111/stable-diffusion-webui/wiki/Textual-Inversion\">wiki</a></b> for detailed explanation.</p>")
|
||||
|
||||
with gr.Row(variant="compact", equal_height=False):
|
||||
with ResizeHandleRow(variant="compact", equal_height=False):
|
||||
with gr.Tabs(elem_id="train_tabs"):
|
||||
|
||||
with gr.Tab(label="Create embedding", id="create_embedding"):
|
||||
|
||||
@@ -10,7 +10,8 @@ import gradio as gr
|
||||
import subprocess as sp
|
||||
from PIL import Image
|
||||
|
||||
from modules import call_queue, shared
|
||||
from modules import call_queue, shared, ui_tempdir
|
||||
from modules.infotext_utils import image_from_url_text
|
||||
import modules.images
|
||||
from modules.ui_components import ToolButton
|
||||
import modules.infotext_utils as parameters_copypaste
|
||||
@@ -167,29 +168,43 @@ class OutputPanel:
|
||||
def create_output_panel(tabname, outdir, toprow=None):
|
||||
res = OutputPanel()
|
||||
|
||||
def open_folder(f):
|
||||
def open_folder(f, images=None, index=None):
|
||||
if shared.cmd_opts.hide_ui_dir_config:
|
||||
return
|
||||
|
||||
try:
|
||||
if 'Sub' in shared.opts.open_dir_button_choice:
|
||||
image_dir = os.path.split(images[index]["name"].rsplit('?', 1)[0])[0]
|
||||
if 'temp' in shared.opts.open_dir_button_choice or not ui_tempdir.is_gradio_temp_path(image_dir):
|
||||
f = image_dir
|
||||
except Exception:
|
||||
pass
|
||||
|
||||
if not os.path.exists(f):
|
||||
print(f'Folder "{f}" does not exist. After you create an image, the folder will be created.')
|
||||
msg = f'Folder "{f}" does not exist. After you create an image, the folder will be created.'
|
||||
print(msg)
|
||||
gr.Info(msg)
|
||||
return
|
||||
elif not os.path.isdir(f):
|
||||
print(f"""
|
||||
msg = f"""
|
||||
WARNING
|
||||
An open_folder request was made with an argument that is not a folder.
|
||||
This could be an error or a malicious attempt to run code on your computer.
|
||||
Requested path was: {f}
|
||||
""", file=sys.stderr)
|
||||
"""
|
||||
print(msg, file=sys.stderr)
|
||||
gr.Warning(msg)
|
||||
return
|
||||
|
||||
if not shared.cmd_opts.hide_ui_dir_config:
|
||||
path = os.path.normpath(f)
|
||||
if platform.system() == "Windows":
|
||||
os.startfile(path)
|
||||
elif platform.system() == "Darwin":
|
||||
sp.Popen(["open", path])
|
||||
elif "microsoft-standard-WSL2" in platform.uname().release:
|
||||
sp.Popen(["wsl-open", path])
|
||||
else:
|
||||
sp.Popen(["xdg-open", path])
|
||||
path = os.path.normpath(f)
|
||||
if platform.system() == "Windows":
|
||||
os.startfile(path)
|
||||
elif platform.system() == "Darwin":
|
||||
sp.Popen(["open", path])
|
||||
elif "microsoft-standard-WSL2" in platform.uname().release:
|
||||
sp.Popen(["wsl-open", path])
|
||||
else:
|
||||
sp.Popen(["xdg-open", path])
|
||||
|
||||
with gr.Column(elem_id=f"{tabname}_results"):
|
||||
if toprow:
|
||||
@@ -216,8 +231,12 @@ Requested path was: {f}
|
||||
res.button_upscale = ToolButton('✨', elem_id=f'{tabname}_upscale', tooltip="Create an upscaled version of the current image using hires fix settings.")
|
||||
|
||||
open_folder_button.click(
|
||||
fn=lambda: open_folder(shared.opts.outdir_samples or outdir),
|
||||
inputs=[],
|
||||
fn=lambda images, index: open_folder(shared.opts.outdir_samples or outdir, images, index),
|
||||
_js="(y, w) => [y, selected_gallery_index()]",
|
||||
inputs=[
|
||||
res.gallery,
|
||||
open_folder_button, # placeholder for index
|
||||
],
|
||||
outputs=[],
|
||||
)
|
||||
|
||||
|
||||
@@ -548,6 +548,7 @@ def create_ui():
|
||||
extensions_disable_all = gr.Radio(label="Disable all extensions", choices=["none", "extra", "all"], value=shared.opts.disable_all_extensions, elem_id="extensions_disable_all")
|
||||
extensions_disabled_list = gr.Text(elem_id="extensions_disabled_list", visible=False, container=False)
|
||||
extensions_update_list = gr.Text(elem_id="extensions_update_list", visible=False, container=False)
|
||||
refresh = gr.Button(value='Refresh', variant="compact")
|
||||
|
||||
html = ""
|
||||
|
||||
@@ -566,7 +567,8 @@ def create_ui():
|
||||
with gr.Row(elem_classes="progress-container"):
|
||||
extensions_table = gr.HTML('Loading...', elem_id="extensions_installed_html")
|
||||
|
||||
ui.load(fn=extension_table, inputs=[], outputs=[extensions_table])
|
||||
ui.load(fn=extension_table, inputs=[], outputs=[extensions_table], show_progress=False)
|
||||
refresh.click(fn=extension_table, inputs=[], outputs=[extensions_table], show_progress=False)
|
||||
|
||||
apply.click(
|
||||
fn=apply_and_restart,
|
||||
|
||||
@@ -134,8 +134,8 @@ def get_single_card(page: str = "", tabname: str = "", name: str = ""):
|
||||
errors.display(e, "creating item for extra network")
|
||||
item = page.items.get(name)
|
||||
|
||||
page.read_user_metadata(item)
|
||||
item_html = page.create_item_html(tabname, item)
|
||||
page.read_user_metadata(item, use_cache=False)
|
||||
item_html = page.create_item_html(tabname, item, shared.html("extra-networks-card.html"))
|
||||
|
||||
return JSONResponse({"html": item_html})
|
||||
|
||||
@@ -173,9 +173,9 @@ class ExtraNetworksPage:
|
||||
def refresh(self):
|
||||
pass
|
||||
|
||||
def read_user_metadata(self, item):
|
||||
def read_user_metadata(self, item, use_cache=True):
|
||||
filename = item.get("filename", None)
|
||||
metadata = extra_networks.get_user_metadata(filename, lister=self.lister)
|
||||
metadata = extra_networks.get_user_metadata(filename, lister=self.lister if use_cache else None)
|
||||
|
||||
desc = metadata.get("description", None)
|
||||
if desc is not None:
|
||||
@@ -289,12 +289,16 @@ class ExtraNetworksPage:
|
||||
}
|
||||
)
|
||||
|
||||
description = (item.get("description", "") or "" if shared.opts.extra_networks_card_show_desc else "")
|
||||
if not shared.opts.extra_networks_card_description_is_html:
|
||||
description = html.escape(description)
|
||||
|
||||
# Some items here might not be used depending on HTML template used.
|
||||
args = {
|
||||
"background_image": background_image,
|
||||
"card_clicked": onclick,
|
||||
"copy_path_button": btn_copy_path,
|
||||
"description": (item.get("description", "") or "" if shared.opts.extra_networks_card_show_desc else ""),
|
||||
"description": description,
|
||||
"edit_button": btn_edit_item,
|
||||
"local_preview": quote_js(item["local_preview"]),
|
||||
"metadata_button": btn_metadata,
|
||||
@@ -472,7 +476,7 @@ class ExtraNetworksPage:
|
||||
|
||||
return f"<ul class='tree-list tree-list--tree'>{res}</ul>"
|
||||
|
||||
def create_card_view_html(self, tabname: str) -> str:
|
||||
def create_card_view_html(self, tabname: str, *, none_message) -> str:
|
||||
"""Generates HTML for the network Card View section for a tab.
|
||||
|
||||
This HTML goes into the `extra-networks-pane.html` <div> with
|
||||
@@ -480,6 +484,7 @@ class ExtraNetworksPage:
|
||||
|
||||
Args:
|
||||
tabname: The name of the active tab.
|
||||
none_message: HTML text to show when there are no cards.
|
||||
|
||||
Returns:
|
||||
HTML formatted string.
|
||||
@@ -490,24 +495,28 @@ class ExtraNetworksPage:
|
||||
|
||||
if res == "":
|
||||
dirs = "".join([f"<li>{x}</li>" for x in self.allowed_directories_for_previews()])
|
||||
res = shared.html("extra-networks-no-cards.html").format(dirs=dirs)
|
||||
res = none_message or shared.html("extra-networks-no-cards.html").format(dirs=dirs)
|
||||
|
||||
return res
|
||||
|
||||
def create_html(self, tabname):
|
||||
def create_html(self, tabname, *, empty=False):
|
||||
"""Generates an HTML string for the current pane.
|
||||
|
||||
The generated HTML uses `extra-networks-pane.html` as a template.
|
||||
|
||||
Args:
|
||||
tabname: The name of the active tab.
|
||||
empty: create an empty HTML page with no items
|
||||
|
||||
Returns:
|
||||
HTML formatted string.
|
||||
"""
|
||||
self.lister.reset()
|
||||
self.metadata = {}
|
||||
self.items = {x["name"]: x for x in self.list_items()}
|
||||
|
||||
items_list = [] if empty else self.list_items()
|
||||
self.items = {x["name"]: x for x in items_list}
|
||||
|
||||
# Populate the instance metadata for each item.
|
||||
for item in self.items.values():
|
||||
metadata = item.get("metadata")
|
||||
@@ -522,9 +531,13 @@ class ExtraNetworksPage:
|
||||
data_sortkey = f"{data_sortmode}-{data_sortdir}-{len(self.items)}"
|
||||
tree_view_btn_extra_class = ""
|
||||
tree_view_div_extra_class = "hidden"
|
||||
tree_view_div_default_display = "none"
|
||||
extra_network_pane_content_default_display = "flex"
|
||||
if shared.opts.extra_networks_tree_view_default_enabled:
|
||||
tree_view_btn_extra_class = "extra-network-control--enabled"
|
||||
tree_view_div_extra_class = ""
|
||||
tree_view_div_default_display = "block"
|
||||
extra_network_pane_content_default_display = "grid"
|
||||
|
||||
return self.pane_tpl.format(
|
||||
**{
|
||||
@@ -536,7 +549,10 @@ class ExtraNetworksPage:
|
||||
"tree_view_btn_extra_class": tree_view_btn_extra_class,
|
||||
"tree_view_div_extra_class": tree_view_div_extra_class,
|
||||
"tree_html": self.create_tree_view_html(tabname),
|
||||
"items_html": self.create_card_view_html(tabname),
|
||||
"items_html": self.create_card_view_html(tabname, none_message="Loading..." if empty else None),
|
||||
"extra_networks_tree_view_default_width": shared.opts.extra_networks_tree_view_default_width,
|
||||
"tree_view_div_default_display": tree_view_div_default_display,
|
||||
"extra_network_pane_content_default_display": extra_network_pane_content_default_display,
|
||||
}
|
||||
)
|
||||
|
||||
@@ -559,7 +575,7 @@ class ExtraNetworksPage:
|
||||
"date_created": int(mtime),
|
||||
"date_modified": int(ctime),
|
||||
"name": pth.name.lower(),
|
||||
"path": str(pth.parent).lower(),
|
||||
"path": str(pth).lower(),
|
||||
}
|
||||
|
||||
def find_preview(self, path):
|
||||
@@ -638,6 +654,7 @@ def pages_in_preferred_order(pages):
|
||||
|
||||
return sorted(pages, key=lambda x: tab_scores[x.name])
|
||||
|
||||
|
||||
def create_ui(interface: gr.Blocks, unrelated_tabs, tabname):
|
||||
ui = ExtraNetworksUi()
|
||||
ui.pages = []
|
||||
@@ -648,15 +665,13 @@ def create_ui(interface: gr.Blocks, unrelated_tabs, tabname):
|
||||
|
||||
related_tabs = []
|
||||
|
||||
button_refresh = gr.Button("Refresh", elem_id=f"{tabname}_extra_refresh_internal", visible=False)
|
||||
|
||||
for page in ui.stored_extra_pages:
|
||||
with gr.Tab(page.title, elem_id=f"{tabname}_{page.extra_networks_tabname}", elem_classes=["extra-page"]) as tab:
|
||||
with gr.Column(elem_id=f"{tabname}_{page.extra_networks_tabname}_prompts", elem_classes=["extra-page-prompts"]):
|
||||
pass
|
||||
|
||||
elem_id = f"{tabname}_{page.extra_networks_tabname}_cards_html"
|
||||
page_elem = gr.HTML('Loading...', elem_id=elem_id)
|
||||
page_elem = gr.HTML(page.create_html(tabname, empty=True), elem_id=elem_id)
|
||||
ui.pages.append(page_elem)
|
||||
|
||||
editor = page.create_user_metadata_editor(ui, tabname)
|
||||
@@ -680,6 +695,15 @@ def create_ui(interface: gr.Blocks, unrelated_tabs, tabname):
|
||||
)
|
||||
tab.select(fn=None, _js=jscode, inputs=[], outputs=[], show_progress=False)
|
||||
|
||||
def refresh():
|
||||
for pg in ui.stored_extra_pages:
|
||||
pg.refresh()
|
||||
create_html()
|
||||
return ui.pages_contents
|
||||
|
||||
button_refresh = gr.Button("Refresh", elem_id=f"{tabname}_{page.extra_networks_tabname}_extra_refresh_internal", visible=False)
|
||||
button_refresh.click(fn=refresh, inputs=[], outputs=ui.pages).then(fn=lambda: None, _js="function(){ " + f"applyExtraNetworkFilter('{tabname}_{page.extra_networks_tabname}');" + " }").then(fn=lambda: None, _js='setupAllResizeHandles')
|
||||
|
||||
def create_html():
|
||||
ui.pages_contents = [pg.create_html(ui.tabname) for pg in ui.stored_extra_pages]
|
||||
|
||||
@@ -688,14 +712,7 @@ def create_ui(interface: gr.Blocks, unrelated_tabs, tabname):
|
||||
create_html()
|
||||
return ui.pages_contents
|
||||
|
||||
def refresh():
|
||||
for pg in ui.stored_extra_pages:
|
||||
pg.refresh()
|
||||
create_html()
|
||||
return ui.pages_contents
|
||||
|
||||
interface.load(fn=pages_html, inputs=[], outputs=[*ui.pages]).then(fn=None, js='function(){applyExtraNetworkFilter(' + quote_js(tabname) + '); return []}')
|
||||
button_refresh.click(fn=refresh, inputs=[], outputs=ui.pages)
|
||||
interface.load(fn=pages_html, inputs=[], outputs=ui.pages).then(fn=lambda: None, _js='setupAllResizeHandles')
|
||||
|
||||
return ui
|
||||
|
||||
|
||||
@@ -30,7 +30,7 @@ class ExtraNetworksPageCheckpoints(ui_extra_networks.ExtraNetworksPage):
|
||||
"preview": self.find_preview(path),
|
||||
"description": self.find_description(path),
|
||||
"search_terms": search_terms,
|
||||
"onclick": html.escape(f"return selectCheckpoint('{name}');"),
|
||||
"onclick": html.escape(f"return selectCheckpoint({ui_extra_networks.quote_js(name)})"),
|
||||
"local_preview": f"{path}.{shared.opts.samples_format}",
|
||||
"metadata": checkpoint.metadata,
|
||||
"sort_keys": {'default': index, **self.get_sort_keys(checkpoint.filename)},
|
||||
|
||||
@@ -22,9 +22,12 @@ def save_style(name, prompt, negative_prompt):
|
||||
if not name:
|
||||
return gr.update(visible=False)
|
||||
|
||||
style = styles.PromptStyle(name, prompt, negative_prompt)
|
||||
existing_style = shared.prompt_styles.styles.get(name)
|
||||
path = existing_style.path if existing_style is not None else None
|
||||
|
||||
style = styles.PromptStyle(name, prompt, negative_prompt, path)
|
||||
shared.prompt_styles.styles[style.name] = style
|
||||
shared.prompt_styles.save_styles(shared.styles_filename)
|
||||
shared.prompt_styles.save_styles()
|
||||
|
||||
return gr.update(visible=True)
|
||||
|
||||
@@ -34,7 +37,7 @@ def delete_style(name):
|
||||
return
|
||||
|
||||
shared.prompt_styles.styles.pop(name, None)
|
||||
shared.prompt_styles.save_styles(shared.styles_filename)
|
||||
shared.prompt_styles.save_styles()
|
||||
|
||||
return '', '', ''
|
||||
|
||||
|
||||
@@ -46,12 +46,9 @@ def save_pil_to_file(pil_image, cache_dir=None, format="png"):
|
||||
already_saved_as = getattr(pil_image, 'already_saved_as', None)
|
||||
if already_saved_as and os.path.isfile(already_saved_as):
|
||||
register_tmp_file(shared.demo, already_saved_as)
|
||||
filename = already_saved_as
|
||||
|
||||
if not shared.opts.save_images_add_number:
|
||||
filename += f'?{os.path.getmtime(already_saved_as)}'
|
||||
|
||||
return filename
|
||||
filename_with_mtime = f'{already_saved_as}?{os.path.getmtime(already_saved_as)}'
|
||||
register_tmp_file(shared.demo, filename_with_mtime)
|
||||
return filename_with_mtime
|
||||
|
||||
if shared.opts.temp_dir:
|
||||
dir = shared.opts.temp_dir
|
||||
@@ -179,3 +176,18 @@ def cleanup_tmpdr():
|
||||
|
||||
filename = os.path.join(root, name)
|
||||
os.remove(filename)
|
||||
|
||||
|
||||
def is_gradio_temp_path(path):
|
||||
"""
|
||||
Check if the path is a temp dir used by gradio
|
||||
"""
|
||||
path = Path(path)
|
||||
if shared.opts.temp_dir and path.is_relative_to(shared.opts.temp_dir):
|
||||
return True
|
||||
if gradio_temp_dir := os.environ.get("GRADIO_TEMP_DIR"):
|
||||
if path.is_relative_to(gradio_temp_dir):
|
||||
return True
|
||||
if path.is_relative_to(Path(tempfile.gettempdir()) / "gradio"):
|
||||
return True
|
||||
return False
|
||||
|
||||
@@ -17,6 +17,7 @@ class Toprow:
|
||||
button_deepbooru = None
|
||||
|
||||
interrupt = None
|
||||
interrupting = None
|
||||
skip = None
|
||||
submit = None
|
||||
|
||||
@@ -96,15 +97,10 @@ class Toprow:
|
||||
with gr.Row(elem_id=f"{self.id_part}_generate_box", elem_classes=["generate-box"] + (["generate-box-compact"] if self.is_compact else []), render=not self.is_compact) as submit_box:
|
||||
self.submit_box = submit_box
|
||||
|
||||
self.interrupt = gr.Button('Interrupt', elem_id=f"{self.id_part}_interrupt", elem_classes="generate-box-interrupt")
|
||||
self.skip = gr.Button('Skip', elem_id=f"{self.id_part}_skip", elem_classes="generate-box-skip")
|
||||
self.submit = gr.Button('Generate', elem_id=f"{self.id_part}_generate", variant='primary')
|
||||
|
||||
self.skip.click(
|
||||
fn=lambda: shared.state.skip(),
|
||||
inputs=[],
|
||||
outputs=[],
|
||||
)
|
||||
self.interrupt = gr.Button('Interrupt', elem_id=f"{self.id_part}_interrupt", elem_classes="generate-box-interrupt", tooltip="End generation immediately or after completing current batch")
|
||||
self.skip = gr.Button('Skip', elem_id=f"{self.id_part}_skip", elem_classes="generate-box-skip", tooltip="Stop generation of current batch and continues onto next batch")
|
||||
self.interrupting = gr.Button('Interrupting...', elem_id=f"{self.id_part}_interrupting", elem_classes="generate-box-interrupting", tooltip="Interrupting generation...")
|
||||
self.submit = gr.Button('Generate', elem_id=f"{self.id_part}_generate", variant='primary', tooltip="Right click generate forever menu")
|
||||
|
||||
def interrupt_function():
|
||||
if not shared.state.stopping_generation and shared.state.job_count > 1 and shared.opts.interrupt_after_current:
|
||||
@@ -113,11 +109,9 @@ class Toprow:
|
||||
else:
|
||||
shared.state.interrupt()
|
||||
|
||||
self.interrupt.click(
|
||||
fn=interrupt_function,
|
||||
inputs=[],
|
||||
outputs=[],
|
||||
)
|
||||
self.skip.click(fn=shared.state.skip)
|
||||
self.interrupt.click(fn=interrupt_function, _js='function(){ showSubmitInterruptingPlaceholder("' + self.id_part + '"); }')
|
||||
self.interrupting.click(fn=interrupt_function)
|
||||
|
||||
def create_tools_row(self):
|
||||
with gr.Row(elem_id=f"{self.id_part}_tools"):
|
||||
@@ -133,9 +127,9 @@ class Toprow:
|
||||
|
||||
self.restore_progress_button = ToolButton(value=restore_progress_symbol, elem_id=f"{self.id_part}_restore_progress", visible=False, tooltip="Restore progress")
|
||||
|
||||
self.token_counter = gr.HTML(value="<span>0/75</span>", elem_id=f"{self.id_part}_token_counter", elem_classes=["token-counter"])
|
||||
self.token_counter = gr.HTML(value="<span>0/75</span>", elem_id=f"{self.id_part}_token_counter", elem_classes=["token-counter"], visible=False)
|
||||
self.token_button = gr.Button(visible=False, elem_id=f"{self.id_part}_token_button")
|
||||
self.negative_token_counter = gr.HTML(value="<span>0/75</span>", elem_id=f"{self.id_part}_negative_token_counter", elem_classes=["token-counter"])
|
||||
self.negative_token_counter = gr.HTML(value="<span>0/75</span>", elem_id=f"{self.id_part}_negative_token_counter", elem_classes=["token-counter"], visible=False)
|
||||
self.negative_token_button = gr.Button(visible=False, elem_id=f"{self.id_part}_negative_token_button")
|
||||
|
||||
self.clear_prompt_button.click(
|
||||
|
||||
@@ -6,7 +6,7 @@ import torch
|
||||
import tqdm
|
||||
from PIL import Image
|
||||
|
||||
from modules import images, shared, torch_utils
|
||||
from modules import devices, images, shared, torch_utils
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
@@ -44,7 +44,8 @@ def upscale_pil_patch(model, img: Image.Image) -> Image.Image:
|
||||
with torch.no_grad():
|
||||
tensor = pil_image_to_torch_bgr(img).unsqueeze(0) # add batch dimension
|
||||
tensor = tensor.to(device=param.device, dtype=param.dtype)
|
||||
return torch_bgr_to_pil_image(model(tensor))
|
||||
with devices.without_autocast():
|
||||
return torch_bgr_to_pil_image(model(tensor))
|
||||
|
||||
|
||||
def upscale_with_model(
|
||||
|
||||
@@ -42,7 +42,7 @@ def walk_files(path, allowed_extensions=None):
|
||||
for filename in sorted(files, key=natural_sort_key):
|
||||
if allowed_extensions is not None:
|
||||
_, ext = os.path.splitext(filename)
|
||||
if ext not in allowed_extensions:
|
||||
if ext.lower() not in allowed_extensions:
|
||||
continue
|
||||
|
||||
if not shared.opts.list_hidden_files and ("/." in root or "\\." in root):
|
||||
|
||||
Reference in New Issue
Block a user