mirror of
https://github.com/AUTOMATIC1111/stable-diffusion-webui.git
synced 2026-04-25 00:39:36 +00:00
Merge branch 'dev' into fix-11805
This commit is contained in:
@@ -1,5 +1,6 @@
|
||||
import base64
|
||||
import io
|
||||
import os
|
||||
import time
|
||||
import datetime
|
||||
import uvicorn
|
||||
@@ -14,7 +15,7 @@ from fastapi.encoders import jsonable_encoder
|
||||
from secrets import compare_digest
|
||||
|
||||
import modules.shared as shared
|
||||
from modules import sd_samplers, deepbooru, sd_hijack, images, scripts, ui, postprocessing, errors
|
||||
from modules import sd_samplers, deepbooru, sd_hijack, images, scripts, ui, postprocessing, errors, restart
|
||||
from modules.api import models
|
||||
from modules.shared import opts
|
||||
from modules.processing import StableDiffusionProcessingTxt2Img, StableDiffusionProcessingImg2Img, process_images
|
||||
@@ -22,7 +23,7 @@ from modules.textual_inversion.textual_inversion import create_embedding, train_
|
||||
from modules.textual_inversion.preprocess import preprocess
|
||||
from modules.hypernetworks.hypernetwork import create_hypernetwork, train_hypernetwork
|
||||
from PIL import PngImagePlugin,Image
|
||||
from modules.sd_models import checkpoints_list, unload_model_weights, reload_model_weights
|
||||
from modules.sd_models import checkpoints_list, unload_model_weights, reload_model_weights, checkpoint_aliases
|
||||
from modules.sd_vae import vae_dict
|
||||
from modules.sd_models_config import find_checkpoint_config_near_filename
|
||||
from modules.realesrgan_model import get_realesrgan_models
|
||||
@@ -30,13 +31,7 @@ from modules import devices
|
||||
from typing import Dict, List, Any
|
||||
import piexif
|
||||
import piexif.helper
|
||||
|
||||
|
||||
def upscaler_to_index(name: str):
|
||||
try:
|
||||
return [x.name.lower() for x in shared.sd_upscalers].index(name.lower())
|
||||
except Exception as e:
|
||||
raise HTTPException(status_code=400, detail=f"Invalid upscaler, needs to be one of these: {' , '.join([x.name for x in shared.sd_upscalers])}") from e
|
||||
from contextlib import closing
|
||||
|
||||
|
||||
def script_name_to_index(name, scripts):
|
||||
@@ -84,6 +79,8 @@ def encode_pil_to_base64(image):
|
||||
image.save(output_bytes, format="PNG", pnginfo=(metadata if use_metadata else None), quality=opts.jpeg_quality)
|
||||
|
||||
elif opts.samples_format.lower() in ("jpg", "jpeg", "webp"):
|
||||
if image.mode == "RGBA":
|
||||
image = image.convert("RGB")
|
||||
parameters = image.info.get('parameters', None)
|
||||
exif_bytes = piexif.dump({
|
||||
"Exif": { piexif.ExifIFD.UserComment: piexif.helper.UserComment.dump(parameters or "", encoding="unicode") }
|
||||
@@ -102,14 +99,16 @@ def encode_pil_to_base64(image):
|
||||
|
||||
|
||||
def api_middleware(app: FastAPI):
|
||||
rich_available = True
|
||||
rich_available = False
|
||||
try:
|
||||
import anyio # importing just so it can be placed on silent list
|
||||
import starlette # importing just so it can be placed on silent list
|
||||
from rich.console import Console
|
||||
console = Console()
|
||||
if os.environ.get('WEBUI_RICH_EXCEPTIONS', None) is not None:
|
||||
import anyio # importing just so it can be placed on silent list
|
||||
import starlette # importing just so it can be placed on silent list
|
||||
from rich.console import Console
|
||||
console = Console()
|
||||
rich_available = True
|
||||
except Exception:
|
||||
rich_available = False
|
||||
pass
|
||||
|
||||
@app.middleware("http")
|
||||
async def log_and_time(req: Request, call_next):
|
||||
@@ -120,14 +119,14 @@ def api_middleware(app: FastAPI):
|
||||
endpoint = req.scope.get('path', 'err')
|
||||
if shared.cmd_opts.api_log and endpoint.startswith('/sdapi'):
|
||||
print('API {t} {code} {prot}/{ver} {method} {endpoint} {cli} {duration}'.format(
|
||||
t = datetime.datetime.now().strftime("%Y-%m-%d %H:%M:%S.%f"),
|
||||
code = res.status_code,
|
||||
ver = req.scope.get('http_version', '0.0'),
|
||||
cli = req.scope.get('client', ('0:0.0.0', 0))[0],
|
||||
prot = req.scope.get('scheme', 'err'),
|
||||
method = req.scope.get('method', 'err'),
|
||||
endpoint = endpoint,
|
||||
duration = duration,
|
||||
t=datetime.datetime.now().strftime("%Y-%m-%d %H:%M:%S.%f"),
|
||||
code=res.status_code,
|
||||
ver=req.scope.get('http_version', '0.0'),
|
||||
cli=req.scope.get('client', ('0:0.0.0', 0))[0],
|
||||
prot=req.scope.get('scheme', 'err'),
|
||||
method=req.scope.get('method', 'err'),
|
||||
endpoint=endpoint,
|
||||
duration=duration,
|
||||
))
|
||||
return res
|
||||
|
||||
@@ -138,7 +137,7 @@ def api_middleware(app: FastAPI):
|
||||
"body": vars(e).get('body', ''),
|
||||
"errors": str(e),
|
||||
}
|
||||
if not isinstance(e, HTTPException): # do not print backtrace on known httpexceptions
|
||||
if not isinstance(e, HTTPException): # do not print backtrace on known httpexceptions
|
||||
message = f"API error: {request.method}: {request.url} {err}"
|
||||
if rich_available:
|
||||
print(message)
|
||||
@@ -209,6 +208,11 @@ class Api:
|
||||
self.add_api_route("/sdapi/v1/scripts", self.get_scripts_list, methods=["GET"], response_model=models.ScriptsList)
|
||||
self.add_api_route("/sdapi/v1/script-info", self.get_script_info, methods=["GET"], response_model=List[models.ScriptInfo])
|
||||
|
||||
if shared.cmd_opts.api_server_stop:
|
||||
self.add_api_route("/sdapi/v1/server-kill", self.kill_webui, methods=["POST"])
|
||||
self.add_api_route("/sdapi/v1/server-restart", self.restart_webui, methods=["POST"])
|
||||
self.add_api_route("/sdapi/v1/server-stop", self.stop_webui, methods=["POST"])
|
||||
|
||||
self.default_script_arg_txt2img = []
|
||||
self.default_script_arg_img2img = []
|
||||
|
||||
@@ -324,19 +328,19 @@ class Api:
|
||||
args.pop('save_images', None)
|
||||
|
||||
with self.queue_lock:
|
||||
p = StableDiffusionProcessingTxt2Img(sd_model=shared.sd_model, **args)
|
||||
p.scripts = script_runner
|
||||
p.outpath_grids = opts.outdir_txt2img_grids
|
||||
p.outpath_samples = opts.outdir_txt2img_samples
|
||||
with closing(StableDiffusionProcessingTxt2Img(sd_model=shared.sd_model, **args)) as p:
|
||||
p.scripts = script_runner
|
||||
p.outpath_grids = opts.outdir_txt2img_grids
|
||||
p.outpath_samples = opts.outdir_txt2img_samples
|
||||
|
||||
shared.state.begin()
|
||||
if selectable_scripts is not None:
|
||||
p.script_args = script_args
|
||||
processed = scripts.scripts_txt2img.run(p, *p.script_args) # Need to pass args as list here
|
||||
else:
|
||||
p.script_args = tuple(script_args) # Need to pass args as tuple here
|
||||
processed = process_images(p)
|
||||
shared.state.end()
|
||||
shared.state.begin(job="scripts_txt2img")
|
||||
if selectable_scripts is not None:
|
||||
p.script_args = script_args
|
||||
processed = scripts.scripts_txt2img.run(p, *p.script_args) # Need to pass args as list here
|
||||
else:
|
||||
p.script_args = tuple(script_args) # Need to pass args as tuple here
|
||||
processed = process_images(p)
|
||||
shared.state.end()
|
||||
|
||||
b64images = list(map(encode_pil_to_base64, processed.images)) if send_images else []
|
||||
|
||||
@@ -380,20 +384,20 @@ class Api:
|
||||
args.pop('save_images', None)
|
||||
|
||||
with self.queue_lock:
|
||||
p = StableDiffusionProcessingImg2Img(sd_model=shared.sd_model, **args)
|
||||
p.init_images = [decode_base64_to_image(x) for x in init_images]
|
||||
p.scripts = script_runner
|
||||
p.outpath_grids = opts.outdir_img2img_grids
|
||||
p.outpath_samples = opts.outdir_img2img_samples
|
||||
with closing(StableDiffusionProcessingImg2Img(sd_model=shared.sd_model, **args)) as p:
|
||||
p.init_images = [decode_base64_to_image(x) for x in init_images]
|
||||
p.scripts = script_runner
|
||||
p.outpath_grids = opts.outdir_img2img_grids
|
||||
p.outpath_samples = opts.outdir_img2img_samples
|
||||
|
||||
shared.state.begin()
|
||||
if selectable_scripts is not None:
|
||||
p.script_args = script_args
|
||||
processed = scripts.scripts_img2img.run(p, *p.script_args) # Need to pass args as list here
|
||||
else:
|
||||
p.script_args = tuple(script_args) # Need to pass args as tuple here
|
||||
processed = process_images(p)
|
||||
shared.state.end()
|
||||
shared.state.begin(job="scripts_img2img")
|
||||
if selectable_scripts is not None:
|
||||
p.script_args = script_args
|
||||
processed = scripts.scripts_img2img.run(p, *p.script_args) # Need to pass args as list here
|
||||
else:
|
||||
p.script_args = tuple(script_args) # Need to pass args as tuple here
|
||||
processed = process_images(p)
|
||||
shared.state.end()
|
||||
|
||||
b64images = list(map(encode_pil_to_base64, processed.images)) if send_images else []
|
||||
|
||||
@@ -517,6 +521,10 @@ class Api:
|
||||
return options
|
||||
|
||||
def set_config(self, req: Dict[str, Any]):
|
||||
checkpoint_name = req.get("sd_model_checkpoint", None)
|
||||
if checkpoint_name is not None and checkpoint_name not in checkpoint_aliases:
|
||||
raise RuntimeError(f"model {checkpoint_name!r} not found")
|
||||
|
||||
for k, v in req.items():
|
||||
shared.opts.set(k, v)
|
||||
|
||||
@@ -598,44 +606,42 @@ class Api:
|
||||
|
||||
def create_embedding(self, args: dict):
|
||||
try:
|
||||
shared.state.begin()
|
||||
shared.state.begin(job="create_embedding")
|
||||
filename = create_embedding(**args) # create empty embedding
|
||||
sd_hijack.model_hijack.embedding_db.load_textual_inversion_embeddings() # reload embeddings so new one can be immediately used
|
||||
shared.state.end()
|
||||
return models.CreateResponse(info=f"create embedding filename: {filename}")
|
||||
except AssertionError as e:
|
||||
shared.state.end()
|
||||
return models.TrainResponse(info=f"create embedding error: {e}")
|
||||
finally:
|
||||
shared.state.end()
|
||||
|
||||
|
||||
def create_hypernetwork(self, args: dict):
|
||||
try:
|
||||
shared.state.begin()
|
||||
shared.state.begin(job="create_hypernetwork")
|
||||
filename = create_hypernetwork(**args) # create empty embedding
|
||||
shared.state.end()
|
||||
return models.CreateResponse(info=f"create hypernetwork filename: {filename}")
|
||||
except AssertionError as e:
|
||||
shared.state.end()
|
||||
return models.TrainResponse(info=f"create hypernetwork error: {e}")
|
||||
finally:
|
||||
shared.state.end()
|
||||
|
||||
def preprocess(self, args: dict):
|
||||
try:
|
||||
shared.state.begin()
|
||||
shared.state.begin(job="preprocess")
|
||||
preprocess(**args) # quick operation unless blip/booru interrogation is enabled
|
||||
shared.state.end()
|
||||
return models.PreprocessResponse(info = 'preprocess complete')
|
||||
return models.PreprocessResponse(info='preprocess complete')
|
||||
except KeyError as e:
|
||||
shared.state.end()
|
||||
return models.PreprocessResponse(info=f"preprocess error: invalid token: {e}")
|
||||
except AssertionError as e:
|
||||
shared.state.end()
|
||||
except Exception as e:
|
||||
return models.PreprocessResponse(info=f"preprocess error: {e}")
|
||||
except FileNotFoundError as e:
|
||||
finally:
|
||||
shared.state.end()
|
||||
return models.PreprocessResponse(info=f'preprocess error: {e}')
|
||||
|
||||
def train_embedding(self, args: dict):
|
||||
try:
|
||||
shared.state.begin()
|
||||
shared.state.begin(job="train_embedding")
|
||||
apply_optimizations = shared.opts.training_xattention_optimizations
|
||||
error = None
|
||||
filename = ''
|
||||
@@ -648,15 +654,15 @@ class Api:
|
||||
finally:
|
||||
if not apply_optimizations:
|
||||
sd_hijack.apply_optimizations()
|
||||
shared.state.end()
|
||||
return models.TrainResponse(info=f"train embedding complete: filename: {filename} error: {error}")
|
||||
except AssertionError as msg:
|
||||
shared.state.end()
|
||||
except Exception as msg:
|
||||
return models.TrainResponse(info=f"train embedding error: {msg}")
|
||||
finally:
|
||||
shared.state.end()
|
||||
|
||||
def train_hypernetwork(self, args: dict):
|
||||
try:
|
||||
shared.state.begin()
|
||||
shared.state.begin(job="train_hypernetwork")
|
||||
shared.loaded_hypernetworks = []
|
||||
apply_optimizations = shared.opts.training_xattention_optimizations
|
||||
error = None
|
||||
@@ -674,9 +680,10 @@ class Api:
|
||||
sd_hijack.apply_optimizations()
|
||||
shared.state.end()
|
||||
return models.TrainResponse(info=f"train embedding complete: filename: {filename} error: {error}")
|
||||
except AssertionError:
|
||||
except Exception as exc:
|
||||
return models.TrainResponse(info=f"train embedding error: {exc}")
|
||||
finally:
|
||||
shared.state.end()
|
||||
return models.TrainResponse(info=f"train embedding error: {error}")
|
||||
|
||||
def get_memory(self):
|
||||
try:
|
||||
@@ -715,4 +722,17 @@ class Api:
|
||||
|
||||
def launch(self, server_name, port):
|
||||
self.app.include_router(self.router)
|
||||
uvicorn.run(self.app, host=server_name, port=port, timeout_keep_alive=0)
|
||||
uvicorn.run(self.app, host=server_name, port=port, timeout_keep_alive=shared.cmd_opts.timeout_keep_alive)
|
||||
|
||||
def kill_webui(self):
|
||||
restart.stop_program()
|
||||
|
||||
def restart_webui(self):
|
||||
if restart.is_restartable():
|
||||
restart.restart_program()
|
||||
return Response(status_code=501)
|
||||
|
||||
def stop_webui(request):
|
||||
shared.state.server_command = "stop"
|
||||
return Response("Stopping.")
|
||||
|
||||
|
||||
@@ -274,10 +274,6 @@ class PromptStyleItem(BaseModel):
|
||||
prompt: Optional[str] = Field(title="Prompt")
|
||||
negative_prompt: Optional[str] = Field(title="Negative Prompt")
|
||||
|
||||
class ArtistItem(BaseModel):
|
||||
name: str = Field(title="Name")
|
||||
score: float = Field(title="Score")
|
||||
category: str = Field(title="Category")
|
||||
|
||||
class EmbeddingItem(BaseModel):
|
||||
step: Optional[int] = Field(title="Step", description="The number of steps that were used to train this embedding, if available")
|
||||
|
||||
120
modules/cache.py
Normal file
120
modules/cache.py
Normal file
@@ -0,0 +1,120 @@
|
||||
import json
|
||||
import os.path
|
||||
import threading
|
||||
import time
|
||||
|
||||
from modules.paths import data_path, script_path
|
||||
|
||||
cache_filename = os.path.join(data_path, "cache.json")
|
||||
cache_data = None
|
||||
cache_lock = threading.Lock()
|
||||
|
||||
dump_cache_after = None
|
||||
dump_cache_thread = None
|
||||
|
||||
|
||||
def dump_cache():
|
||||
"""
|
||||
Marks cache for writing to disk. 5 seconds after no one else flags the cache for writing, it is written.
|
||||
"""
|
||||
|
||||
global dump_cache_after
|
||||
global dump_cache_thread
|
||||
|
||||
def thread_func():
|
||||
global dump_cache_after
|
||||
global dump_cache_thread
|
||||
|
||||
while dump_cache_after is not None and time.time() < dump_cache_after:
|
||||
time.sleep(1)
|
||||
|
||||
with cache_lock:
|
||||
with open(cache_filename, "w", encoding="utf8") as file:
|
||||
json.dump(cache_data, file, indent=4)
|
||||
|
||||
dump_cache_after = None
|
||||
dump_cache_thread = None
|
||||
|
||||
with cache_lock:
|
||||
dump_cache_after = time.time() + 5
|
||||
if dump_cache_thread is None:
|
||||
dump_cache_thread = threading.Thread(name='cache-writer', target=thread_func)
|
||||
dump_cache_thread.start()
|
||||
|
||||
|
||||
def cache(subsection):
|
||||
"""
|
||||
Retrieves or initializes a cache for a specific subsection.
|
||||
|
||||
Parameters:
|
||||
subsection (str): The subsection identifier for the cache.
|
||||
|
||||
Returns:
|
||||
dict: The cache data for the specified subsection.
|
||||
"""
|
||||
|
||||
global cache_data
|
||||
|
||||
if cache_data is None:
|
||||
with cache_lock:
|
||||
if cache_data is None:
|
||||
if not os.path.isfile(cache_filename):
|
||||
cache_data = {}
|
||||
else:
|
||||
try:
|
||||
with open(cache_filename, "r", encoding="utf8") as file:
|
||||
cache_data = json.load(file)
|
||||
except Exception:
|
||||
os.replace(cache_filename, os.path.join(script_path, "tmp", "cache.json"))
|
||||
print('[ERROR] issue occurred while trying to read cache.json, move current cache to tmp/cache.json and create new cache')
|
||||
cache_data = {}
|
||||
|
||||
s = cache_data.get(subsection, {})
|
||||
cache_data[subsection] = s
|
||||
|
||||
return s
|
||||
|
||||
|
||||
def cached_data_for_file(subsection, title, filename, func):
|
||||
"""
|
||||
Retrieves or generates data for a specific file, using a caching mechanism.
|
||||
|
||||
Parameters:
|
||||
subsection (str): The subsection of the cache to use.
|
||||
title (str): The title of the data entry in the subsection of the cache.
|
||||
filename (str): The path to the file to be checked for modifications.
|
||||
func (callable): A function that generates the data if it is not available in the cache.
|
||||
|
||||
Returns:
|
||||
dict or None: The cached or generated data, or None if data generation fails.
|
||||
|
||||
The `cached_data_for_file` function implements a caching mechanism for data stored in files.
|
||||
It checks if the data associated with the given `title` is present in the cache and compares the
|
||||
modification time of the file with the cached modification time. If the file has been modified,
|
||||
the cache is considered invalid and the data is regenerated using the provided `func`.
|
||||
Otherwise, the cached data is returned.
|
||||
|
||||
If the data generation fails, None is returned to indicate the failure. Otherwise, the generated
|
||||
or cached data is returned as a dictionary.
|
||||
"""
|
||||
|
||||
existing_cache = cache(subsection)
|
||||
ondisk_mtime = os.path.getmtime(filename)
|
||||
|
||||
entry = existing_cache.get(title)
|
||||
if entry:
|
||||
cached_mtime = entry.get("mtime", 0)
|
||||
if ondisk_mtime > cached_mtime:
|
||||
entry = None
|
||||
|
||||
if not entry or 'value' not in entry:
|
||||
value = func()
|
||||
if value is None:
|
||||
return None
|
||||
|
||||
entry = {'mtime': ondisk_mtime, 'value': value}
|
||||
existing_cache[title] = entry
|
||||
|
||||
dump_cache()
|
||||
|
||||
return entry['value']
|
||||
@@ -1,3 +1,4 @@
|
||||
from functools import wraps
|
||||
import html
|
||||
import threading
|
||||
import time
|
||||
@@ -18,6 +19,7 @@ def wrap_queued_call(func):
|
||||
|
||||
|
||||
def wrap_gradio_gpu_call(func, extra_outputs=None):
|
||||
@wraps(func)
|
||||
def f(*args, **kwargs):
|
||||
|
||||
# if the first argument is a string that says "task(...)", it is treated as a job id
|
||||
@@ -28,7 +30,7 @@ def wrap_gradio_gpu_call(func, extra_outputs=None):
|
||||
id_task = None
|
||||
|
||||
with queue_lock:
|
||||
shared.state.begin()
|
||||
shared.state.begin(job=id_task)
|
||||
progress.start_task(id_task)
|
||||
|
||||
try:
|
||||
@@ -45,6 +47,7 @@ def wrap_gradio_gpu_call(func, extra_outputs=None):
|
||||
|
||||
|
||||
def wrap_gradio_call(func, extra_outputs=None, add_stats=False):
|
||||
@wraps(func)
|
||||
def f(*args, extra_outputs_array=extra_outputs, **kwargs):
|
||||
run_memmon = shared.opts.memmon_poll_rate > 0 and not shared.mem_mon.disabled and add_stats
|
||||
if run_memmon:
|
||||
@@ -82,9 +85,9 @@ def wrap_gradio_call(func, extra_outputs=None, add_stats=False):
|
||||
elapsed = time.perf_counter() - t
|
||||
elapsed_m = int(elapsed // 60)
|
||||
elapsed_s = elapsed % 60
|
||||
elapsed_text = f"{elapsed_s:.2f}s"
|
||||
elapsed_text = f"{elapsed_s:.1f} sec."
|
||||
if elapsed_m > 0:
|
||||
elapsed_text = f"{elapsed_m}m "+elapsed_text
|
||||
elapsed_text = f"{elapsed_m} min. "+elapsed_text
|
||||
|
||||
if run_memmon:
|
||||
mem_stats = {k: -(v//-(1024*1024)) for k, v in shared.mem_mon.stop().items()}
|
||||
@@ -92,14 +95,22 @@ def wrap_gradio_call(func, extra_outputs=None, add_stats=False):
|
||||
reserved_peak = mem_stats['reserved_peak']
|
||||
sys_peak = mem_stats['system_peak']
|
||||
sys_total = mem_stats['total']
|
||||
sys_pct = round(sys_peak/max(sys_total, 1) * 100, 2)
|
||||
sys_pct = sys_peak/max(sys_total, 1) * 100
|
||||
|
||||
vram_html = f"<p class='vram'>Torch active/reserved: {active_peak}/{reserved_peak} MiB, <wbr>Sys VRAM: {sys_peak}/{sys_total} MiB ({sys_pct}%)</p>"
|
||||
toltip_a = "Active: peak amount of video memory used during generation (excluding cached data)"
|
||||
toltip_r = "Reserved: total amout of video memory allocated by the Torch library "
|
||||
toltip_sys = "System: peak amout of video memory allocated by all running programs, out of total capacity"
|
||||
|
||||
text_a = f"<abbr title='{toltip_a}'>A</abbr>: <span class='measurement'>{active_peak/1024:.2f} GB</span>"
|
||||
text_r = f"<abbr title='{toltip_r}'>R</abbr>: <span class='measurement'>{reserved_peak/1024:.2f} GB</span>"
|
||||
text_sys = f"<abbr title='{toltip_sys}'>Sys</abbr>: <span class='measurement'>{sys_peak/1024:.1f}/{sys_total/1024:g} GB</span> ({sys_pct:.1f}%)"
|
||||
|
||||
vram_html = f"<p class='vram'>{text_a}, <wbr>{text_r}, <wbr>{text_sys}</p>"
|
||||
else:
|
||||
vram_html = ''
|
||||
|
||||
# last item is always HTML
|
||||
res[-1] += f"<div class='performance'><p class='time'>Time taken: <wbr>{elapsed_text}</p>{vram_html}</div>"
|
||||
res[-1] += f"<div class='performance'><p class='time'>Time taken: <wbr><span class='measurement'>{elapsed_text}</span></p>{vram_html}</div>"
|
||||
|
||||
return tuple(res)
|
||||
|
||||
|
||||
@@ -107,3 +107,5 @@ parser.add_argument("--no-hashing", action='store_true', help="disable sha256 ha
|
||||
parser.add_argument("--no-download-sd-model", action='store_true', help="don't download SD1.5 model even if no model is found in --ckpt-dir", default=False)
|
||||
parser.add_argument('--subpath', type=str, help='customize the subpath for gradio, use with reverse proxy')
|
||||
parser.add_argument('--add-stop-route', action='store_true', help='add /_stop route to stop server')
|
||||
parser.add_argument('--api-server-stop', action='store_true', help='enable server stop/restart/kill via api')
|
||||
parser.add_argument('--timeout-keep-alive', type=int, default=30, help='set timeout_keep_alive for uvicorn')
|
||||
|
||||
@@ -15,7 +15,6 @@ model_dir = "Codeformer"
|
||||
model_path = os.path.join(models_path, model_dir)
|
||||
model_url = 'https://github.com/sczhou/CodeFormer/releases/download/v0.1.0/codeformer.pth'
|
||||
|
||||
have_codeformer = False
|
||||
codeformer = None
|
||||
|
||||
|
||||
@@ -100,7 +99,7 @@ def setup_model(dirname):
|
||||
output = self.net(cropped_face_t, w=w if w is not None else shared.opts.code_former_weight, adain=True)[0]
|
||||
restored_face = tensor2img(output, rgb2bgr=True, min_max=(-1, 1))
|
||||
del output
|
||||
torch.cuda.empty_cache()
|
||||
devices.torch_gc()
|
||||
except Exception:
|
||||
errors.report('Failed inference for CodeFormer', exc_info=True)
|
||||
restored_face = tensor2img(cropped_face_t, rgb2bgr=True, min_max=(-1, 1))
|
||||
@@ -123,9 +122,6 @@ def setup_model(dirname):
|
||||
|
||||
return restored_img
|
||||
|
||||
global have_codeformer
|
||||
have_codeformer = True
|
||||
|
||||
global codeformer
|
||||
codeformer = FaceRestorerCodeFormer(dirname)
|
||||
shared.face_restorers.append(codeformer)
|
||||
|
||||
@@ -15,13 +15,6 @@ def has_mps() -> bool:
|
||||
else:
|
||||
return mac_specific.has_mps
|
||||
|
||||
def extract_device_id(args, name):
|
||||
for x in range(len(args)):
|
||||
if name in args[x]:
|
||||
return args[x + 1]
|
||||
|
||||
return None
|
||||
|
||||
|
||||
def get_cuda_device_string():
|
||||
from modules import shared
|
||||
@@ -56,11 +49,15 @@ def get_device_for(task):
|
||||
|
||||
|
||||
def torch_gc():
|
||||
|
||||
if torch.cuda.is_available():
|
||||
with torch.cuda.device(get_cuda_device_string()):
|
||||
torch.cuda.empty_cache()
|
||||
torch.cuda.ipc_collect()
|
||||
|
||||
if has_mps():
|
||||
mac_specific.torch_mps_gc()
|
||||
|
||||
|
||||
def enable_tf32():
|
||||
if torch.cuda.is_available():
|
||||
|
||||
@@ -1,15 +1,13 @@
|
||||
import os
|
||||
import sys
|
||||
|
||||
import numpy as np
|
||||
import torch
|
||||
from PIL import Image
|
||||
from basicsr.utils.download_util import load_file_from_url
|
||||
|
||||
import modules.esrgan_model_arch as arch
|
||||
from modules import modelloader, images, devices
|
||||
from modules.upscaler import Upscaler, UpscalerData
|
||||
from modules.shared import opts
|
||||
|
||||
from modules.upscaler import Upscaler, UpscalerData
|
||||
|
||||
|
||||
def mod2normal(state_dict):
|
||||
@@ -134,7 +132,7 @@ class UpscalerESRGAN(Upscaler):
|
||||
scaler_data = UpscalerData(self.model_name, self.model_url, self, 4)
|
||||
scalers.append(scaler_data)
|
||||
for file in model_paths:
|
||||
if "http" in file:
|
||||
if file.startswith("http"):
|
||||
name = self.model_name
|
||||
else:
|
||||
name = modelloader.friendly_name(file)
|
||||
@@ -143,26 +141,25 @@ class UpscalerESRGAN(Upscaler):
|
||||
self.scalers.append(scaler_data)
|
||||
|
||||
def do_upscale(self, img, selected_model):
|
||||
model = self.load_model(selected_model)
|
||||
if model is None:
|
||||
try:
|
||||
model = self.load_model(selected_model)
|
||||
except Exception as e:
|
||||
print(f"Unable to load ESRGAN model {selected_model}: {e}", file=sys.stderr)
|
||||
return img
|
||||
model.to(devices.device_esrgan)
|
||||
img = esrgan_upscale(model, img)
|
||||
return img
|
||||
|
||||
def load_model(self, path: str):
|
||||
if "http" in path:
|
||||
filename = load_file_from_url(
|
||||
if path.startswith("http"):
|
||||
# TODO: this doesn't use `path` at all?
|
||||
filename = modelloader.load_file_from_url(
|
||||
url=self.model_url,
|
||||
model_dir=self.model_download_path,
|
||||
file_name=f"{self.model_name}.pth",
|
||||
progress=True,
|
||||
)
|
||||
else:
|
||||
filename = path
|
||||
if not os.path.exists(filename) or filename is None:
|
||||
print(f"Unable to load {self.model_path} from {filename}")
|
||||
return None
|
||||
|
||||
state_dict = torch.load(filename, map_location='cpu' if devices.device_esrgan.type == 'mps' else None)
|
||||
|
||||
|
||||
@@ -1,7 +1,7 @@
|
||||
import os
|
||||
import threading
|
||||
|
||||
from modules import shared, errors
|
||||
from modules import shared, errors, cache
|
||||
from modules.gitpython_hack import Repo
|
||||
from modules.paths_internal import extensions_dir, extensions_builtin_dir, script_path # noqa: F401
|
||||
|
||||
@@ -21,6 +21,7 @@ def active():
|
||||
|
||||
class Extension:
|
||||
lock = threading.Lock()
|
||||
cached_fields = ['remote', 'commit_date', 'branch', 'commit_hash', 'version']
|
||||
|
||||
def __init__(self, name, path, enabled=True, is_builtin=False):
|
||||
self.name = name
|
||||
@@ -36,15 +37,29 @@ class Extension:
|
||||
self.remote = None
|
||||
self.have_info_from_repo = False
|
||||
|
||||
def to_dict(self):
|
||||
return {x: getattr(self, x) for x in self.cached_fields}
|
||||
|
||||
def from_dict(self, d):
|
||||
for field in self.cached_fields:
|
||||
setattr(self, field, d[field])
|
||||
|
||||
def read_info_from_repo(self):
|
||||
if self.is_builtin or self.have_info_from_repo:
|
||||
return
|
||||
|
||||
with self.lock:
|
||||
if self.have_info_from_repo:
|
||||
return
|
||||
def read_from_repo():
|
||||
with self.lock:
|
||||
if self.have_info_from_repo:
|
||||
return
|
||||
|
||||
self.do_read_info_from_repo()
|
||||
self.do_read_info_from_repo()
|
||||
|
||||
return self.to_dict()
|
||||
|
||||
d = cache.cached_data_for_file('extensions-git', self.name, os.path.join(self.path, ".git"), read_from_repo)
|
||||
self.from_dict(d)
|
||||
self.status = 'unknown'
|
||||
|
||||
def do_read_info_from_repo(self):
|
||||
repo = None
|
||||
@@ -58,7 +73,6 @@ class Extension:
|
||||
self.remote = None
|
||||
else:
|
||||
try:
|
||||
self.status = 'unknown'
|
||||
self.remote = next(repo.remote().urls, None)
|
||||
commit = repo.head.commit
|
||||
self.commit_date = commit.committed_date
|
||||
|
||||
@@ -4,16 +4,22 @@ from collections import defaultdict
|
||||
from modules import errors
|
||||
|
||||
extra_network_registry = {}
|
||||
extra_network_aliases = {}
|
||||
|
||||
|
||||
def initialize():
|
||||
extra_network_registry.clear()
|
||||
extra_network_aliases.clear()
|
||||
|
||||
|
||||
def register_extra_network(extra_network):
|
||||
extra_network_registry[extra_network.name] = extra_network
|
||||
|
||||
|
||||
def register_extra_network_alias(extra_network, alias):
|
||||
extra_network_aliases[alias] = extra_network
|
||||
|
||||
|
||||
def register_default_extra_networks():
|
||||
from modules.extra_networks_hypernet import ExtraNetworkHypernet
|
||||
register_extra_network(ExtraNetworkHypernet())
|
||||
@@ -82,20 +88,26 @@ def activate(p, extra_network_data):
|
||||
"""call activate for extra networks in extra_network_data in specified order, then call
|
||||
activate for all remaining registered networks with an empty argument list"""
|
||||
|
||||
activated = []
|
||||
|
||||
for extra_network_name, extra_network_args in extra_network_data.items():
|
||||
extra_network = extra_network_registry.get(extra_network_name, None)
|
||||
|
||||
if extra_network is None:
|
||||
extra_network = extra_network_aliases.get(extra_network_name, None)
|
||||
|
||||
if extra_network is None:
|
||||
print(f"Skipping unknown extra network: {extra_network_name}")
|
||||
continue
|
||||
|
||||
try:
|
||||
extra_network.activate(p, extra_network_args)
|
||||
activated.append(extra_network)
|
||||
except Exception as e:
|
||||
errors.display(e, f"activating extra network {extra_network_name} with arguments {extra_network_args}")
|
||||
|
||||
for extra_network_name, extra_network in extra_network_registry.items():
|
||||
args = extra_network_data.get(extra_network_name, None)
|
||||
if args is not None:
|
||||
if extra_network in activated:
|
||||
continue
|
||||
|
||||
try:
|
||||
@@ -103,6 +115,9 @@ def activate(p, extra_network_data):
|
||||
except Exception as e:
|
||||
errors.display(e, f"activating extra network {extra_network_name}")
|
||||
|
||||
if p.scripts is not None:
|
||||
p.scripts.after_extra_networks_activate(p, batch_number=p.iteration, prompts=p.prompts, seeds=p.seeds, subseeds=p.subseeds, extra_network_data=extra_network_data)
|
||||
|
||||
|
||||
def deactivate(p, extra_network_data):
|
||||
"""call deactivate for extra networks in extra_network_data in specified order, then call
|
||||
|
||||
@@ -73,8 +73,7 @@ def to_half(tensor, enable):
|
||||
|
||||
|
||||
def run_modelmerger(id_task, primary_model_name, secondary_model_name, tertiary_model_name, interp_method, multiplier, save_as_half, custom_name, checkpoint_format, config_source, bake_in_vae, discard_weights, save_metadata):
|
||||
shared.state.begin()
|
||||
shared.state.job = 'model-merge'
|
||||
shared.state.begin(job="model-merge")
|
||||
|
||||
def fail(message):
|
||||
shared.state.textinfo = message
|
||||
|
||||
@@ -174,31 +174,6 @@ def send_image_and_dimensions(x):
|
||||
return img, w, h
|
||||
|
||||
|
||||
|
||||
def find_hypernetwork_key(hypernet_name, hypernet_hash=None):
|
||||
"""Determines the config parameter name to use for the hypernet based on the parameters in the infotext.
|
||||
|
||||
Example: an infotext provides "Hypernet: ke-ta" and "Hypernet hash: 1234abcd". For the "Hypernet" config
|
||||
parameter this means there should be an entry that looks like "ke-ta-10000(1234abcd)" to set it to.
|
||||
|
||||
If the infotext has no hash, then a hypernet with the same name will be selected instead.
|
||||
"""
|
||||
hypernet_name = hypernet_name.lower()
|
||||
if hypernet_hash is not None:
|
||||
# Try to match the hash in the name
|
||||
for hypernet_key in shared.hypernetworks.keys():
|
||||
result = re_hypernet_hash.search(hypernet_key)
|
||||
if result is not None and result[1] == hypernet_hash:
|
||||
return hypernet_key
|
||||
else:
|
||||
# Fall back to a hypernet with the same name
|
||||
for hypernet_key in shared.hypernetworks.keys():
|
||||
if hypernet_key.lower().startswith(hypernet_name):
|
||||
return hypernet_key
|
||||
|
||||
return None
|
||||
|
||||
|
||||
def restore_old_hires_fix_params(res):
|
||||
"""for infotexts that specify old First pass size parameter, convert it into
|
||||
width, height, and hr scale"""
|
||||
@@ -332,10 +307,6 @@ Steps: 20, Sampler: Euler a, CFG scale: 7, Seed: 965400086, Size: 512x512, Model
|
||||
return res
|
||||
|
||||
|
||||
settings_map = {}
|
||||
|
||||
|
||||
|
||||
infotext_to_setting_name_mapping = [
|
||||
('Clip skip', 'CLIP_stop_at_last_layers', ),
|
||||
('Conditional mask weight', 'inpainting_mask_weight'),
|
||||
|
||||
@@ -25,7 +25,7 @@ def gfpgann():
|
||||
return None
|
||||
|
||||
models = modelloader.load_models(model_path, model_url, user_path, ext_filter="GFPGAN")
|
||||
if len(models) == 1 and "http" in models[0]:
|
||||
if len(models) == 1 and models[0].startswith("http"):
|
||||
model_file = models[0]
|
||||
elif len(models) != 0:
|
||||
latest_file = max(models, key=os.path.getctime)
|
||||
|
||||
@@ -1,38 +1,11 @@
|
||||
import hashlib
|
||||
import json
|
||||
import os.path
|
||||
|
||||
import filelock
|
||||
|
||||
from modules import shared
|
||||
from modules.paths import data_path
|
||||
import modules.cache
|
||||
|
||||
|
||||
cache_filename = os.path.join(data_path, "cache.json")
|
||||
cache_data = None
|
||||
|
||||
|
||||
def dump_cache():
|
||||
with filelock.FileLock(f"{cache_filename}.lock"):
|
||||
with open(cache_filename, "w", encoding="utf8") as file:
|
||||
json.dump(cache_data, file, indent=4)
|
||||
|
||||
|
||||
def cache(subsection):
|
||||
global cache_data
|
||||
|
||||
if cache_data is None:
|
||||
with filelock.FileLock(f"{cache_filename}.lock"):
|
||||
if not os.path.isfile(cache_filename):
|
||||
cache_data = {}
|
||||
else:
|
||||
with open(cache_filename, "r", encoding="utf8") as file:
|
||||
cache_data = json.load(file)
|
||||
|
||||
s = cache_data.get(subsection, {})
|
||||
cache_data[subsection] = s
|
||||
|
||||
return s
|
||||
dump_cache = modules.cache.dump_cache
|
||||
cache = modules.cache.cache
|
||||
|
||||
|
||||
def calculate_sha256(filename):
|
||||
|
||||
@@ -3,6 +3,7 @@ import glob
|
||||
import html
|
||||
import os
|
||||
import inspect
|
||||
from contextlib import closing
|
||||
|
||||
import modules.textual_inversion.dataset
|
||||
import torch
|
||||
@@ -353,17 +354,6 @@ def load_hypernetworks(names, multipliers=None):
|
||||
shared.loaded_hypernetworks.append(hypernetwork)
|
||||
|
||||
|
||||
def find_closest_hypernetwork_name(search: str):
|
||||
if not search:
|
||||
return None
|
||||
search = search.lower()
|
||||
applicable = [name for name in shared.hypernetworks if search in name.lower()]
|
||||
if not applicable:
|
||||
return None
|
||||
applicable = sorted(applicable, key=lambda name: len(name))
|
||||
return applicable[0]
|
||||
|
||||
|
||||
def apply_single_hypernetwork(hypernetwork, context_k, context_v, layer=None):
|
||||
hypernetwork_layers = (hypernetwork.layers if hypernetwork is not None else {}).get(context_k.shape[2], None)
|
||||
|
||||
@@ -388,7 +378,7 @@ def apply_hypernetworks(hypernetworks, context, layer=None):
|
||||
return context_k, context_v
|
||||
|
||||
|
||||
def attention_CrossAttention_forward(self, x, context=None, mask=None):
|
||||
def attention_CrossAttention_forward(self, x, context=None, mask=None, **kwargs):
|
||||
h = self.heads
|
||||
|
||||
q = self.to_q(x)
|
||||
@@ -446,18 +436,6 @@ def statistics(data):
|
||||
return total_information, recent_information
|
||||
|
||||
|
||||
def report_statistics(loss_info:dict):
|
||||
keys = sorted(loss_info.keys(), key=lambda x: sum(loss_info[x]) / len(loss_info[x]))
|
||||
for key in keys:
|
||||
try:
|
||||
print("Loss statistics for file " + key)
|
||||
info, recent = statistics(list(loss_info[key]))
|
||||
print(info)
|
||||
print(recent)
|
||||
except Exception as e:
|
||||
print(e)
|
||||
|
||||
|
||||
def create_hypernetwork(name, enable_sizes, overwrite_old, layer_structure=None, activation_func=None, weight_init=None, add_layer_norm=False, use_dropout=False, dropout_structure=None):
|
||||
# Remove illegal characters from name.
|
||||
name = "".join( x for x in name if (x.isalnum() or x in "._- "))
|
||||
@@ -734,8 +712,9 @@ def train_hypernetwork(id_task, hypernetwork_name, learn_rate, batch_size, gradi
|
||||
|
||||
preview_text = p.prompt
|
||||
|
||||
processed = processing.process_images(p)
|
||||
image = processed.images[0] if len(processed.images) > 0 else None
|
||||
with closing(p):
|
||||
processed = processing.process_images(p)
|
||||
image = processed.images[0] if len(processed.images) > 0 else None
|
||||
|
||||
if unload:
|
||||
shared.sd_model.cond_stage_model.to(devices.cpu)
|
||||
@@ -770,7 +749,6 @@ Last saved image: {html.escape(last_saved_image)}<br/>
|
||||
pbar.leave = False
|
||||
pbar.close()
|
||||
hypernetwork.eval()
|
||||
#report_statistics(loss_dict)
|
||||
sd_hijack_checkpoint.remove()
|
||||
|
||||
|
||||
|
||||
@@ -1,3 +1,5 @@
|
||||
from __future__ import annotations
|
||||
|
||||
import datetime
|
||||
|
||||
import pytz
|
||||
@@ -10,7 +12,7 @@ import re
|
||||
import numpy as np
|
||||
import piexif
|
||||
import piexif.helper
|
||||
from PIL import Image, ImageFont, ImageDraw, PngImagePlugin
|
||||
from PIL import Image, ImageFont, ImageDraw, ImageColor, PngImagePlugin
|
||||
import string
|
||||
import json
|
||||
import hashlib
|
||||
@@ -139,6 +141,11 @@ class GridAnnotation:
|
||||
|
||||
|
||||
def draw_grid_annotations(im, width, height, hor_texts, ver_texts, margin=0):
|
||||
|
||||
color_active = ImageColor.getcolor(opts.grid_text_active_color, 'RGB')
|
||||
color_inactive = ImageColor.getcolor(opts.grid_text_inactive_color, 'RGB')
|
||||
color_background = ImageColor.getcolor(opts.grid_background_color, 'RGB')
|
||||
|
||||
def wrap(drawing, text, font, line_length):
|
||||
lines = ['']
|
||||
for word in text.split():
|
||||
@@ -168,9 +175,6 @@ def draw_grid_annotations(im, width, height, hor_texts, ver_texts, margin=0):
|
||||
|
||||
fnt = get_font(fontsize)
|
||||
|
||||
color_active = (0, 0, 0)
|
||||
color_inactive = (153, 153, 153)
|
||||
|
||||
pad_left = 0 if sum([sum([len(line.text) for line in lines]) for lines in ver_texts]) == 0 else width * 3 // 4
|
||||
|
||||
cols = im.width // width
|
||||
@@ -179,7 +183,7 @@ def draw_grid_annotations(im, width, height, hor_texts, ver_texts, margin=0):
|
||||
assert cols == len(hor_texts), f'bad number of horizontal texts: {len(hor_texts)}; must be {cols}'
|
||||
assert rows == len(ver_texts), f'bad number of vertical texts: {len(ver_texts)}; must be {rows}'
|
||||
|
||||
calc_img = Image.new("RGB", (1, 1), "white")
|
||||
calc_img = Image.new("RGB", (1, 1), color_background)
|
||||
calc_d = ImageDraw.Draw(calc_img)
|
||||
|
||||
for texts, allowed_width in zip(hor_texts + ver_texts, [width] * len(hor_texts) + [pad_left] * len(ver_texts)):
|
||||
@@ -200,7 +204,7 @@ def draw_grid_annotations(im, width, height, hor_texts, ver_texts, margin=0):
|
||||
|
||||
pad_top = 0 if sum(hor_text_heights) == 0 else max(hor_text_heights) + line_spacing * 2
|
||||
|
||||
result = Image.new("RGB", (im.width + pad_left + margin * (cols-1), im.height + pad_top + margin * (rows-1)), "white")
|
||||
result = Image.new("RGB", (im.width + pad_left + margin * (cols-1), im.height + pad_top + margin * (rows-1)), color_background)
|
||||
|
||||
for row in range(rows):
|
||||
for col in range(cols):
|
||||
@@ -302,12 +306,14 @@ def resize_image(resize_mode, im, width, height, upscaler_name=None):
|
||||
|
||||
if ratio < src_ratio:
|
||||
fill_height = height // 2 - src_h // 2
|
||||
res.paste(resized.resize((width, fill_height), box=(0, 0, width, 0)), box=(0, 0))
|
||||
res.paste(resized.resize((width, fill_height), box=(0, resized.height, width, resized.height)), box=(0, fill_height + src_h))
|
||||
if fill_height > 0:
|
||||
res.paste(resized.resize((width, fill_height), box=(0, 0, width, 0)), box=(0, 0))
|
||||
res.paste(resized.resize((width, fill_height), box=(0, resized.height, width, resized.height)), box=(0, fill_height + src_h))
|
||||
elif ratio > src_ratio:
|
||||
fill_width = width // 2 - src_w // 2
|
||||
res.paste(resized.resize((fill_width, height), box=(0, 0, 0, height)), box=(0, 0))
|
||||
res.paste(resized.resize((fill_width, height), box=(resized.width, 0, resized.width, height)), box=(fill_width + src_w, 0))
|
||||
if fill_width > 0:
|
||||
res.paste(resized.resize((fill_width, height), box=(0, 0, 0, height)), box=(0, 0))
|
||||
res.paste(resized.resize((fill_width, height), box=(resized.width, 0, resized.width, height)), box=(fill_width + src_w, 0))
|
||||
|
||||
return res
|
||||
|
||||
@@ -372,8 +378,9 @@ class FilenameGenerator:
|
||||
'hasprompt': lambda self, *args: self.hasprompt(*args), # accepts formats:[hasprompt<prompt1|default><prompt2>..]
|
||||
'clip_skip': lambda self: opts.data["CLIP_stop_at_last_layers"],
|
||||
'denoising': lambda self: self.p.denoising_strength if self.p and self.p.denoising_strength else NOTHING_AND_SKIP_PREVIOUS_TEXT,
|
||||
'user': lambda self: self.p.user,
|
||||
'vae_filename': lambda self: self.get_vae_filename(),
|
||||
|
||||
'none': lambda self: '', # Overrides the default so you can get just the sequence number
|
||||
}
|
||||
default_time_format = '%Y%m%d%H%M%S'
|
||||
|
||||
@@ -497,13 +504,23 @@ def get_next_sequence_number(path, basename):
|
||||
return result + 1
|
||||
|
||||
|
||||
def save_image_with_geninfo(image, geninfo, filename, extension=None, existing_pnginfo=None):
|
||||
def save_image_with_geninfo(image, geninfo, filename, extension=None, existing_pnginfo=None, pnginfo_section_name='parameters'):
|
||||
"""
|
||||
Saves image to filename, including geninfo as text information for generation info.
|
||||
For PNG images, geninfo is added to existing pnginfo dictionary using the pnginfo_section_name argument as key.
|
||||
For JPG images, there's no dictionary and geninfo just replaces the EXIF description.
|
||||
"""
|
||||
|
||||
if extension is None:
|
||||
extension = os.path.splitext(filename)[1]
|
||||
|
||||
image_format = Image.registered_extensions()[extension]
|
||||
|
||||
if extension.lower() == '.png':
|
||||
existing_pnginfo = existing_pnginfo or {}
|
||||
if opts.enable_pnginfo:
|
||||
existing_pnginfo[pnginfo_section_name] = geninfo
|
||||
|
||||
if opts.enable_pnginfo:
|
||||
pnginfo_data = PngImagePlugin.PngInfo()
|
||||
for k, v in (existing_pnginfo or {}).items():
|
||||
@@ -585,13 +602,13 @@ def save_image(image, path, basename, seed=None, prompt=None, extension='png', i
|
||||
else:
|
||||
file_decoration = opts.samples_filename_pattern or "[seed]-[prompt_spaces]"
|
||||
|
||||
file_decoration = namegen.apply(file_decoration) + suffix
|
||||
|
||||
add_number = opts.save_images_add_number or file_decoration == ''
|
||||
|
||||
if file_decoration != "" and add_number:
|
||||
file_decoration = f"-{file_decoration}"
|
||||
|
||||
file_decoration = namegen.apply(file_decoration) + suffix
|
||||
|
||||
if add_number:
|
||||
basecount = get_next_sequence_number(path, basename)
|
||||
fullfn = None
|
||||
@@ -622,7 +639,7 @@ def save_image(image, path, basename, seed=None, prompt=None, extension='png', i
|
||||
"""
|
||||
temp_file_path = f"{filename_without_extension}.tmp"
|
||||
|
||||
save_image_with_geninfo(image_to_save, info, temp_file_path, extension, params.pnginfo)
|
||||
save_image_with_geninfo(image_to_save, info, temp_file_path, extension, existing_pnginfo=params.pnginfo, pnginfo_section_name=pnginfo_section_name)
|
||||
|
||||
os.replace(temp_file_path, filename_without_extension + extension)
|
||||
|
||||
@@ -639,12 +656,18 @@ def save_image(image, path, basename, seed=None, prompt=None, extension='png', i
|
||||
oversize = image.width > opts.target_side_length or image.height > opts.target_side_length
|
||||
if opts.export_for_4chan and (oversize or os.stat(fullfn).st_size > opts.img_downscale_threshold * 1024 * 1024):
|
||||
ratio = image.width / image.height
|
||||
|
||||
resize_to = None
|
||||
if oversize and ratio > 1:
|
||||
image = image.resize((round(opts.target_side_length), round(image.height * opts.target_side_length / image.width)), LANCZOS)
|
||||
resize_to = round(opts.target_side_length), round(image.height * opts.target_side_length / image.width)
|
||||
elif oversize:
|
||||
image = image.resize((round(image.width * opts.target_side_length / image.height), round(opts.target_side_length)), LANCZOS)
|
||||
resize_to = round(image.width * opts.target_side_length / image.height), round(opts.target_side_length)
|
||||
|
||||
if resize_to is not None:
|
||||
try:
|
||||
# Resizing image with LANCZOS could throw an exception if e.g. image mode is I;16
|
||||
image = image.resize(resize_to, LANCZOS)
|
||||
except Exception:
|
||||
image = image.resize(resize_to)
|
||||
try:
|
||||
_atomically_save_image(image, fullfn_without_extension, ".jpg")
|
||||
except Exception as e:
|
||||
@@ -662,8 +685,15 @@ def save_image(image, path, basename, seed=None, prompt=None, extension='png', i
|
||||
return fullfn, txt_fullfn
|
||||
|
||||
|
||||
def read_info_from_image(image):
|
||||
items = image.info or {}
|
||||
IGNORED_INFO_KEYS = {
|
||||
'jfif', 'jfif_version', 'jfif_unit', 'jfif_density', 'dpi', 'exif',
|
||||
'loop', 'background', 'timestamp', 'duration', 'progressive', 'progression',
|
||||
'icc_profile', 'chromaticity', 'photoshop',
|
||||
}
|
||||
|
||||
|
||||
def read_info_from_image(image: Image.Image) -> tuple[str | None, dict]:
|
||||
items = (image.info or {}).copy()
|
||||
|
||||
geninfo = items.pop('parameters', None)
|
||||
|
||||
@@ -679,9 +709,7 @@ def read_info_from_image(image):
|
||||
items['exif comment'] = exif_comment
|
||||
geninfo = exif_comment
|
||||
|
||||
for field in ['jfif', 'jfif_version', 'jfif_unit', 'jfif_density', 'dpi', 'exif',
|
||||
'loop', 'background', 'timestamp', 'duration', 'progressive', 'progression',
|
||||
'icc_profile', 'chromaticity']:
|
||||
for field in IGNORED_INFO_KEYS:
|
||||
items.pop(field, None)
|
||||
|
||||
if items.get("Software", None) == "NovelAI":
|
||||
|
||||
@@ -1,23 +1,26 @@
|
||||
import os
|
||||
from contextlib import closing
|
||||
from pathlib import Path
|
||||
|
||||
import numpy as np
|
||||
from PIL import Image, ImageOps, ImageFilter, ImageEnhance, ImageChops, UnidentifiedImageError
|
||||
import gradio as gr
|
||||
|
||||
from modules import sd_samplers
|
||||
from modules.generation_parameters_copypaste import create_override_settings_dict
|
||||
from modules import sd_samplers, images as imgutil
|
||||
from modules.generation_parameters_copypaste import create_override_settings_dict, parse_generation_parameters
|
||||
from modules.processing import Processed, StableDiffusionProcessingImg2Img, process_images
|
||||
from modules.shared import opts, state
|
||||
from modules.images import save_image
|
||||
import modules.shared as shared
|
||||
import modules.processing as processing
|
||||
from modules.ui import plaintext_to_html
|
||||
import modules.scripts
|
||||
|
||||
|
||||
def process_batch(p, input_dir, output_dir, inpaint_mask_dir, args, to_scale=False, scale_by=1.0):
|
||||
def process_batch(p, input_dir, output_dir, inpaint_mask_dir, args, to_scale=False, scale_by=1.0, use_png_info=False, png_info_props=None, png_info_dir=None):
|
||||
processing.fix_seed(p)
|
||||
|
||||
images = shared.listfiles(input_dir)
|
||||
images = list(shared.walk_files(input_dir, allowed_extensions=(".png", ".jpg", ".jpeg", ".webp")))
|
||||
|
||||
is_inpaint_batch = False
|
||||
if inpaint_mask_dir:
|
||||
@@ -36,6 +39,14 @@ def process_batch(p, input_dir, output_dir, inpaint_mask_dir, args, to_scale=Fal
|
||||
|
||||
state.job_count = len(images) * p.n_iter
|
||||
|
||||
# extract "default" params to use in case getting png info fails
|
||||
prompt = p.prompt
|
||||
negative_prompt = p.negative_prompt
|
||||
seed = p.seed
|
||||
cfg_scale = p.cfg_scale
|
||||
sampler_name = p.sampler_name
|
||||
steps = p.steps
|
||||
|
||||
for i, image in enumerate(images):
|
||||
state.job = f"{i+1} out of {len(images)}"
|
||||
if state.skipped:
|
||||
@@ -79,25 +90,45 @@ def process_batch(p, input_dir, output_dir, inpaint_mask_dir, args, to_scale=Fal
|
||||
mask_image = Image.open(mask_image_path)
|
||||
p.image_mask = mask_image
|
||||
|
||||
if use_png_info:
|
||||
try:
|
||||
info_img = img
|
||||
if png_info_dir:
|
||||
info_img_path = os.path.join(png_info_dir, os.path.basename(image))
|
||||
info_img = Image.open(info_img_path)
|
||||
geninfo, _ = imgutil.read_info_from_image(info_img)
|
||||
parsed_parameters = parse_generation_parameters(geninfo)
|
||||
parsed_parameters = {k: v for k, v in parsed_parameters.items() if k in (png_info_props or {})}
|
||||
except Exception:
|
||||
parsed_parameters = {}
|
||||
|
||||
p.prompt = prompt + (" " + parsed_parameters["Prompt"] if "Prompt" in parsed_parameters else "")
|
||||
p.negative_prompt = negative_prompt + (" " + parsed_parameters["Negative prompt"] if "Negative prompt" in parsed_parameters else "")
|
||||
p.seed = int(parsed_parameters.get("Seed", seed))
|
||||
p.cfg_scale = float(parsed_parameters.get("CFG scale", cfg_scale))
|
||||
p.sampler_name = parsed_parameters.get("Sampler", sampler_name)
|
||||
p.steps = int(parsed_parameters.get("Steps", steps))
|
||||
|
||||
proc = modules.scripts.scripts_img2img.run(p, *args)
|
||||
if proc is None:
|
||||
proc = process_images(p)
|
||||
|
||||
for n, processed_image in enumerate(proc.images):
|
||||
filename = image_path.name
|
||||
filename = image_path.stem
|
||||
infotext = proc.infotext(p, n)
|
||||
relpath = os.path.dirname(os.path.relpath(image, input_dir))
|
||||
|
||||
if n > 0:
|
||||
left, right = os.path.splitext(filename)
|
||||
filename = f"{left}-{n}{right}"
|
||||
filename += f"-{n}"
|
||||
|
||||
if not save_normally:
|
||||
os.makedirs(output_dir, exist_ok=True)
|
||||
os.makedirs(os.path.join(output_dir, relpath), exist_ok=True)
|
||||
if processed_image.mode == 'RGBA':
|
||||
processed_image = processed_image.convert("RGB")
|
||||
processed_image.save(os.path.join(output_dir, filename))
|
||||
save_image(processed_image, os.path.join(output_dir, relpath), None, extension=opts.samples_format, info=infotext, forced_filename=filename, save_to_dirs=False)
|
||||
|
||||
|
||||
def img2img(id_task: str, mode: int, prompt: str, negative_prompt: str, prompt_styles, init_img, sketch, init_img_with_mask, inpaint_color_sketch, inpaint_color_sketch_orig, init_img_inpaint, init_mask_inpaint, steps: int, sampler_index: int, mask_blur: int, mask_alpha: float, inpainting_fill: int, restore_faces: bool, tiling: bool, n_iter: int, batch_size: int, cfg_scale: float, image_cfg_scale: float, denoising_strength: float, seed: int, subseed: int, subseed_strength: float, seed_resize_from_h: int, seed_resize_from_w: int, seed_enable_extras: bool, selected_scale_tab: int, height: int, width: int, scale_by: float, resize_mode: int, inpaint_full_res: bool, inpaint_full_res_padding: int, inpainting_mask_invert: int, img2img_batch_input_dir: str, img2img_batch_output_dir: str, img2img_batch_inpaint_mask_dir: str, override_settings_texts, *args):
|
||||
def img2img(id_task: str, mode: int, prompt: str, negative_prompt: str, prompt_styles, init_img, sketch, init_img_with_mask, inpaint_color_sketch, inpaint_color_sketch_orig, init_img_inpaint, init_mask_inpaint, steps: int, sampler_index: int, mask_blur: int, mask_alpha: float, inpainting_fill: int, restore_faces: bool, tiling: bool, n_iter: int, batch_size: int, cfg_scale: float, image_cfg_scale: float, denoising_strength: float, seed: int, subseed: int, subseed_strength: float, seed_resize_from_h: int, seed_resize_from_w: int, seed_enable_extras: bool, selected_scale_tab: int, height: int, width: int, scale_by: float, resize_mode: int, inpaint_full_res: bool, inpaint_full_res_padding: int, inpainting_mask_invert: int, img2img_batch_input_dir: str, img2img_batch_output_dir: str, img2img_batch_inpaint_mask_dir: str, override_settings_texts, img2img_batch_use_png_info: bool, img2img_batch_png_info_props: list, img2img_batch_png_info_dir: str, request: gr.Request, *args):
|
||||
override_settings = create_override_settings_dict(override_settings_texts)
|
||||
|
||||
is_batch = mode == 5
|
||||
@@ -180,24 +211,25 @@ def img2img(id_task: str, mode: int, prompt: str, negative_prompt: str, prompt_s
|
||||
p.scripts = modules.scripts.scripts_img2img
|
||||
p.script_args = args
|
||||
|
||||
p.user = request.username
|
||||
|
||||
if shared.cmd_opts.enable_console_prompts:
|
||||
print(f"\nimg2img: {prompt}", file=shared.progress_print_out)
|
||||
|
||||
if mask:
|
||||
p.extra_generation_params["Mask blur"] = mask_blur
|
||||
|
||||
if is_batch:
|
||||
assert not shared.cmd_opts.hide_ui_dir_config, "Launched with --hide-ui-dir-config, batch img2img disabled"
|
||||
with closing(p):
|
||||
if is_batch:
|
||||
assert not shared.cmd_opts.hide_ui_dir_config, "Launched with --hide-ui-dir-config, batch img2img disabled"
|
||||
|
||||
process_batch(p, img2img_batch_input_dir, img2img_batch_output_dir, img2img_batch_inpaint_mask_dir, args, to_scale=selected_scale_tab == 1, scale_by=scale_by)
|
||||
process_batch(p, img2img_batch_input_dir, img2img_batch_output_dir, img2img_batch_inpaint_mask_dir, args, to_scale=selected_scale_tab == 1, scale_by=scale_by, use_png_info=img2img_batch_use_png_info, png_info_props=img2img_batch_png_info_props, png_info_dir=img2img_batch_png_info_dir)
|
||||
|
||||
processed = Processed(p, [], p.seed, "")
|
||||
else:
|
||||
processed = modules.scripts.scripts_img2img.run(p, *args)
|
||||
if processed is None:
|
||||
processed = process_images(p)
|
||||
|
||||
p.close()
|
||||
processed = Processed(p, [], p.seed, "")
|
||||
else:
|
||||
processed = modules.scripts.scripts_img2img.run(p, *args)
|
||||
if processed is None:
|
||||
processed = process_images(p)
|
||||
|
||||
shared.total_tqdm.clear()
|
||||
|
||||
@@ -208,4 +240,4 @@ def img2img(id_task: str, mode: int, prompt: str, negative_prompt: str, prompt_s
|
||||
if opts.do_not_show_images:
|
||||
processed.images = []
|
||||
|
||||
return processed.images, generation_info_js, plaintext_to_html(processed.info), plaintext_to_html(processed.comments)
|
||||
return processed.images, generation_info_js, plaintext_to_html(processed.info), plaintext_to_html(processed.comments, classname="comments")
|
||||
|
||||
@@ -184,8 +184,7 @@ class InterrogateModels:
|
||||
|
||||
def interrogate(self, pil_image):
|
||||
res = ""
|
||||
shared.state.begin()
|
||||
shared.state.job = 'interrogate'
|
||||
shared.state.begin(job="interrogate")
|
||||
try:
|
||||
if shared.cmd_opts.lowvram or shared.cmd_opts.medvram:
|
||||
lowvram.send_everything_to_cpu()
|
||||
|
||||
@@ -1,4 +1,5 @@
|
||||
# this scripts installs necessary requirements and launches main program in webui.py
|
||||
import re
|
||||
import subprocess
|
||||
import os
|
||||
import sys
|
||||
@@ -9,6 +10,9 @@ from functools import lru_cache
|
||||
|
||||
from modules import cmd_args, errors
|
||||
from modules.paths_internal import script_path, extensions_dir
|
||||
from modules import timer
|
||||
|
||||
timer.startup_timer.record("start")
|
||||
|
||||
args, _ = cmd_args.parser.parse_known_args()
|
||||
|
||||
@@ -69,10 +73,12 @@ def git_tag():
|
||||
return subprocess.check_output([git, "describe", "--tags"], shell=False, encoding='utf8').strip()
|
||||
except Exception:
|
||||
try:
|
||||
from pathlib import Path
|
||||
changelog_md = Path(__file__).parent.parent / "CHANGELOG.md"
|
||||
with changelog_md.open(encoding="utf-8") as file:
|
||||
return next((line.strip() for line in file if line.strip()), "<none>")
|
||||
|
||||
changelog_md = os.path.join(os.path.dirname(os.path.dirname(__file__)), "CHANGELOG.md")
|
||||
with open(changelog_md, "r", encoding="utf-8") as file:
|
||||
line = next((line.strip() for line in file if line.strip()), "<none>")
|
||||
line = line.replace("## ", "")
|
||||
return line
|
||||
except Exception:
|
||||
return "<none>"
|
||||
|
||||
@@ -142,15 +148,15 @@ def git_clone(url, dir, name, commithash=None):
|
||||
if commithash is None:
|
||||
return
|
||||
|
||||
current_hash = run(f'"{git}" -C "{dir}" rev-parse HEAD', None, f"Couldn't determine {name}'s hash: {commithash}").strip()
|
||||
current_hash = run(f'"{git}" -C "{dir}" rev-parse HEAD', None, f"Couldn't determine {name}'s hash: {commithash}", live=False).strip()
|
||||
if current_hash == commithash:
|
||||
return
|
||||
|
||||
run(f'"{git}" -C "{dir}" fetch', f"Fetching updates for {name}...", f"Couldn't fetch {name}")
|
||||
run(f'"{git}" -C "{dir}" checkout {commithash}', f"Checking out commit for {name} with hash: {commithash}...", f"Couldn't checkout commit {commithash} for {name}")
|
||||
run(f'"{git}" -C "{dir}" checkout {commithash}', f"Checking out commit for {name} with hash: {commithash}...", f"Couldn't checkout commit {commithash} for {name}", live=True)
|
||||
return
|
||||
|
||||
run(f'"{git}" clone "{url}" "{dir}"', f"Cloning {name} into {dir}...", f"Couldn't clone {name}")
|
||||
run(f'"{git}" clone "{url}" "{dir}"', f"Cloning {name} into {dir}...", f"Couldn't clone {name}", live=True)
|
||||
|
||||
if commithash is not None:
|
||||
run(f'"{git}" -C "{dir}" checkout {commithash}', None, "Couldn't checkout {name}'s hash: {commithash}")
|
||||
@@ -224,6 +230,44 @@ def run_extensions_installers(settings_file):
|
||||
run_extension_installer(os.path.join(extensions_dir, dirname_extension))
|
||||
|
||||
|
||||
re_requirement = re.compile(r"\s*([-_a-zA-Z0-9]+)\s*(?:==\s*([-+_.a-zA-Z0-9]+))?\s*")
|
||||
|
||||
|
||||
def requrements_met(requirements_file):
|
||||
"""
|
||||
Does a simple parse of a requirements.txt file to determine if all rerqirements in it
|
||||
are already installed. Returns True if so, False if not installed or parsing fails.
|
||||
"""
|
||||
|
||||
import importlib.metadata
|
||||
import packaging.version
|
||||
|
||||
with open(requirements_file, "r", encoding="utf8") as file:
|
||||
for line in file:
|
||||
if line.strip() == "":
|
||||
continue
|
||||
|
||||
m = re.match(re_requirement, line)
|
||||
if m is None:
|
||||
return False
|
||||
|
||||
package = m.group(1).strip()
|
||||
version_required = (m.group(2) or "").strip()
|
||||
|
||||
if version_required == "":
|
||||
continue
|
||||
|
||||
try:
|
||||
version_installed = importlib.metadata.version(package)
|
||||
except Exception:
|
||||
return False
|
||||
|
||||
if packaging.version.parse(version_required) != packaging.version.parse(version_installed):
|
||||
return False
|
||||
|
||||
return True
|
||||
|
||||
|
||||
def prepare_environment():
|
||||
torch_index_url = os.environ.get('TORCH_INDEX_URL', "https://download.pytorch.org/whl/cu118")
|
||||
torch_command = os.environ.get('TORCH_COMMAND', f"pip install torch==2.0.1 torchvision==0.15.2 --extra-index-url {torch_index_url}")
|
||||
@@ -235,11 +279,13 @@ def prepare_environment():
|
||||
openclip_package = os.environ.get('OPENCLIP_PACKAGE', "https://github.com/mlfoundations/open_clip/archive/bb6e834e9c70d9c27d0dc3ecedeebeaeb1ffad6b.zip")
|
||||
|
||||
stable_diffusion_repo = os.environ.get('STABLE_DIFFUSION_REPO', "https://github.com/Stability-AI/stablediffusion.git")
|
||||
stable_diffusion_xl_repo = os.environ.get('STABLE_DIFFUSION_XL_REPO', "https://github.com/Stability-AI/generative-models.git")
|
||||
k_diffusion_repo = os.environ.get('K_DIFFUSION_REPO', 'https://github.com/crowsonkb/k-diffusion.git')
|
||||
codeformer_repo = os.environ.get('CODEFORMER_REPO', 'https://github.com/sczhou/CodeFormer.git')
|
||||
blip_repo = os.environ.get('BLIP_REPO', 'https://github.com/salesforce/BLIP.git')
|
||||
|
||||
stable_diffusion_commit_hash = os.environ.get('STABLE_DIFFUSION_COMMIT_HASH', "cf1d67a6fd5ea1aa600c4df58e5b47da45f6bdbf")
|
||||
stable_diffusion_xl_commit_hash = os.environ.get('STABLE_DIFFUSION_XL_COMMIT_HASH', "5c10deee76adad0032b412294130090932317a87")
|
||||
k_diffusion_commit_hash = os.environ.get('K_DIFFUSION_COMMIT_HASH', "c9fe758757e022f05ca5a53fa8fac28889e4f1cf")
|
||||
codeformer_commit_hash = os.environ.get('CODEFORMER_COMMIT_HASH', "c5b4593074ba6214284d6acd5f1719b6c5d739af")
|
||||
blip_commit_hash = os.environ.get('BLIP_COMMIT_HASH', "48211a1594f1321b00f14c9f7a5b4813144b2fb9")
|
||||
@@ -297,6 +343,7 @@ def prepare_environment():
|
||||
os.makedirs(os.path.join(script_path, dir_repos), exist_ok=True)
|
||||
|
||||
git_clone(stable_diffusion_repo, repo_dir('stable-diffusion-stability-ai'), "Stable Diffusion", stable_diffusion_commit_hash)
|
||||
git_clone(stable_diffusion_xl_repo, repo_dir('generative-models'), "Stable Diffusion XL", stable_diffusion_xl_commit_hash)
|
||||
git_clone(k_diffusion_repo, repo_dir('k-diffusion'), "K-diffusion", k_diffusion_commit_hash)
|
||||
git_clone(codeformer_repo, repo_dir('CodeFormer'), "CodeFormer", codeformer_commit_hash)
|
||||
git_clone(blip_repo, repo_dir('BLIP'), "BLIP", blip_commit_hash)
|
||||
@@ -306,7 +353,9 @@ def prepare_environment():
|
||||
|
||||
if not os.path.isfile(requirements_file):
|
||||
requirements_file = os.path.join(script_path, requirements_file)
|
||||
run_pip(f"install -r \"{requirements_file}\"", "requirements")
|
||||
|
||||
if not requrements_met(requirements_file):
|
||||
run_pip(f"install -r \"{requirements_file}\"", "requirements")
|
||||
|
||||
run_extensions_installers(settings_file=args.ui_settings_file)
|
||||
|
||||
@@ -321,6 +370,7 @@ def prepare_environment():
|
||||
exit(0)
|
||||
|
||||
|
||||
|
||||
def configure_for_tests():
|
||||
if "--api" not in sys.argv:
|
||||
sys.argv.append("--api")
|
||||
|
||||
@@ -53,19 +53,46 @@ def setup_for_low_vram(sd_model, use_medvram):
|
||||
send_me_to_gpu(first_stage_model, None)
|
||||
return first_stage_model_decode(z)
|
||||
|
||||
# for SD1, cond_stage_model is CLIP and its NN is in the tranformer frield, but for SD2, it's open clip, and it's in model field
|
||||
if hasattr(sd_model.cond_stage_model, 'model'):
|
||||
sd_model.cond_stage_model.transformer = sd_model.cond_stage_model.model
|
||||
to_remain_in_cpu = [
|
||||
(sd_model, 'first_stage_model'),
|
||||
(sd_model, 'depth_model'),
|
||||
(sd_model, 'embedder'),
|
||||
(sd_model, 'model'),
|
||||
(sd_model, 'embedder'),
|
||||
]
|
||||
|
||||
# remove several big modules: cond, first_stage, depth/embedder (if applicable), and unet from the model and then
|
||||
# send the model to GPU. Then put modules back. the modules will be in CPU.
|
||||
stored = sd_model.cond_stage_model.transformer, sd_model.first_stage_model, getattr(sd_model, 'depth_model', None), getattr(sd_model, 'embedder', None), sd_model.model
|
||||
sd_model.cond_stage_model.transformer, sd_model.first_stage_model, sd_model.depth_model, sd_model.embedder, sd_model.model = None, None, None, None, None
|
||||
is_sdxl = hasattr(sd_model, 'conditioner')
|
||||
is_sd2 = not is_sdxl and hasattr(sd_model.cond_stage_model, 'model')
|
||||
|
||||
if is_sdxl:
|
||||
to_remain_in_cpu.append((sd_model, 'conditioner'))
|
||||
elif is_sd2:
|
||||
to_remain_in_cpu.append((sd_model.cond_stage_model, 'model'))
|
||||
else:
|
||||
to_remain_in_cpu.append((sd_model.cond_stage_model, 'transformer'))
|
||||
|
||||
# remove several big modules: cond, first_stage, depth/embedder (if applicable), and unet from the model
|
||||
stored = []
|
||||
for obj, field in to_remain_in_cpu:
|
||||
module = getattr(obj, field, None)
|
||||
stored.append(module)
|
||||
setattr(obj, field, None)
|
||||
|
||||
# send the model to GPU.
|
||||
sd_model.to(devices.device)
|
||||
sd_model.cond_stage_model.transformer, sd_model.first_stage_model, sd_model.depth_model, sd_model.embedder, sd_model.model = stored
|
||||
|
||||
# put modules back. the modules will be in CPU.
|
||||
for (obj, field), module in zip(to_remain_in_cpu, stored):
|
||||
setattr(obj, field, module)
|
||||
|
||||
# register hooks for those the first three models
|
||||
sd_model.cond_stage_model.transformer.register_forward_pre_hook(send_me_to_gpu)
|
||||
if is_sdxl:
|
||||
sd_model.conditioner.register_forward_pre_hook(send_me_to_gpu)
|
||||
elif is_sd2:
|
||||
sd_model.cond_stage_model.model.register_forward_pre_hook(send_me_to_gpu)
|
||||
else:
|
||||
sd_model.cond_stage_model.transformer.register_forward_pre_hook(send_me_to_gpu)
|
||||
|
||||
sd_model.first_stage_model.register_forward_pre_hook(send_me_to_gpu)
|
||||
sd_model.first_stage_model.encode = first_stage_model_encode_wrap
|
||||
sd_model.first_stage_model.decode = first_stage_model_decode_wrap
|
||||
@@ -73,11 +100,9 @@ def setup_for_low_vram(sd_model, use_medvram):
|
||||
sd_model.depth_model.register_forward_pre_hook(send_me_to_gpu)
|
||||
if sd_model.embedder:
|
||||
sd_model.embedder.register_forward_pre_hook(send_me_to_gpu)
|
||||
parents[sd_model.cond_stage_model.transformer] = sd_model.cond_stage_model
|
||||
|
||||
if hasattr(sd_model.cond_stage_model, 'model'):
|
||||
sd_model.cond_stage_model.model = sd_model.cond_stage_model.transformer
|
||||
del sd_model.cond_stage_model.transformer
|
||||
if hasattr(sd_model, 'cond_stage_model'):
|
||||
parents[sd_model.cond_stage_model.transformer] = sd_model.cond_stage_model
|
||||
|
||||
if use_medvram:
|
||||
sd_model.model.register_forward_pre_hook(send_me_to_gpu)
|
||||
|
||||
@@ -1,22 +1,45 @@
|
||||
import logging
|
||||
|
||||
import torch
|
||||
import platform
|
||||
from modules.sd_hijack_utils import CondFunc
|
||||
from packaging import version
|
||||
|
||||
log = logging.getLogger(__name__)
|
||||
|
||||
# has_mps is only available in nightly pytorch (for now) and macOS 12.3+.
|
||||
# check `getattr` and try it for compatibility
|
||||
|
||||
# before torch version 1.13, has_mps is only available in nightly pytorch and macOS 12.3+,
|
||||
# use check `getattr` and try it for compatibility.
|
||||
# in torch version 1.13, backends.mps.is_available() and backends.mps.is_built() are introduced in to check mps availabilty,
|
||||
# since torch 2.0.1+ nightly build, getattr(torch, 'has_mps', False) was deprecated, see https://github.com/pytorch/pytorch/pull/103279
|
||||
def check_for_mps() -> bool:
|
||||
if not getattr(torch, 'has_mps', False):
|
||||
return False
|
||||
try:
|
||||
torch.zeros(1).to(torch.device("mps"))
|
||||
return True
|
||||
except Exception:
|
||||
return False
|
||||
if version.parse(torch.__version__) <= version.parse("2.0.1"):
|
||||
if not getattr(torch, 'has_mps', False):
|
||||
return False
|
||||
try:
|
||||
torch.zeros(1).to(torch.device("mps"))
|
||||
return True
|
||||
except Exception:
|
||||
return False
|
||||
else:
|
||||
return torch.backends.mps.is_available() and torch.backends.mps.is_built()
|
||||
|
||||
|
||||
has_mps = check_for_mps()
|
||||
|
||||
|
||||
def torch_mps_gc() -> None:
|
||||
try:
|
||||
from modules.shared import state
|
||||
if state.current_latent is not None:
|
||||
log.debug("`current_latent` is set, skipping MPS garbage collection")
|
||||
return
|
||||
from torch.mps import empty_cache
|
||||
empty_cache()
|
||||
except Exception:
|
||||
log.warning("MPS garbage collection failed", exc_info=True)
|
||||
|
||||
|
||||
# MPS workaround for https://github.com/pytorch/pytorch/issues/89784
|
||||
def cumsum_fix(input, cumsum_func, *args, **kwargs):
|
||||
if input.device.type == 'mps':
|
||||
|
||||
@@ -1,3 +1,5 @@
|
||||
from __future__ import annotations
|
||||
|
||||
import os
|
||||
import shutil
|
||||
import importlib
|
||||
@@ -8,6 +10,29 @@ from modules.upscaler import Upscaler, UpscalerLanczos, UpscalerNearest, Upscale
|
||||
from modules.paths import script_path, models_path
|
||||
|
||||
|
||||
def load_file_from_url(
|
||||
url: str,
|
||||
*,
|
||||
model_dir: str,
|
||||
progress: bool = True,
|
||||
file_name: str | None = None,
|
||||
) -> str:
|
||||
"""Download a file from `url` into `model_dir`, using the file present if possible.
|
||||
|
||||
Returns the path to the downloaded file.
|
||||
"""
|
||||
os.makedirs(model_dir, exist_ok=True)
|
||||
if not file_name:
|
||||
parts = urlparse(url)
|
||||
file_name = os.path.basename(parts.path)
|
||||
cached_file = os.path.abspath(os.path.join(model_dir, file_name))
|
||||
if not os.path.exists(cached_file):
|
||||
print(f'Downloading: "{url}" to {cached_file}\n')
|
||||
from torch.hub import download_url_to_file
|
||||
download_url_to_file(url, cached_file, progress=progress)
|
||||
return cached_file
|
||||
|
||||
|
||||
def load_models(model_path: str, model_url: str = None, command_path: str = None, ext_filter=None, download_name=None, ext_blacklist=None) -> list:
|
||||
"""
|
||||
A one-and done loader to try finding the desired models in specified directories.
|
||||
@@ -46,9 +71,7 @@ def load_models(model_path: str, model_url: str = None, command_path: str = None
|
||||
|
||||
if model_url is not None and len(output) == 0:
|
||||
if download_name is not None:
|
||||
from basicsr.utils.download_util import load_file_from_url
|
||||
dl = load_file_from_url(model_url, places[0], True, download_name)
|
||||
output.append(dl)
|
||||
output.append(load_file_from_url(model_url, model_dir=places[0], file_name=download_name))
|
||||
else:
|
||||
output.append(model_url)
|
||||
|
||||
@@ -59,7 +82,7 @@ def load_models(model_path: str, model_url: str = None, command_path: str = None
|
||||
|
||||
|
||||
def friendly_name(file: str):
|
||||
if "http" in file:
|
||||
if file.startswith("http"):
|
||||
file = urlparse(file).path
|
||||
|
||||
file = os.path.basename(file)
|
||||
|
||||
@@ -5,6 +5,21 @@ from modules.paths_internal import models_path, script_path, data_path, extensio
|
||||
import modules.safe # noqa: F401
|
||||
|
||||
|
||||
def mute_sdxl_imports():
|
||||
"""create fake modules that SDXL wants to import but doesn't actually use for our purposes"""
|
||||
|
||||
class Dummy:
|
||||
pass
|
||||
|
||||
module = Dummy()
|
||||
module.LPIPS = None
|
||||
sys.modules['taming.modules.losses.lpips'] = module
|
||||
|
||||
module = Dummy()
|
||||
module.StableDataModuleFromConfig = None
|
||||
sys.modules['sgm.data'] = module
|
||||
|
||||
|
||||
# data_path = cmd_opts_pre.data
|
||||
sys.path.insert(0, script_path)
|
||||
|
||||
@@ -18,8 +33,11 @@ for possible_sd_path in possible_sd_paths:
|
||||
|
||||
assert sd_path is not None, f"Couldn't find Stable Diffusion in any of: {possible_sd_paths}"
|
||||
|
||||
mute_sdxl_imports()
|
||||
|
||||
path_dirs = [
|
||||
(sd_path, 'ldm', 'Stable Diffusion', []),
|
||||
(os.path.join(sd_path, '../generative-models'), 'sgm', 'Stable Diffusion XL', ["sgm"]),
|
||||
(os.path.join(sd_path, '../CodeFormer'), 'inference_codeformer.py', 'CodeFormer', []),
|
||||
(os.path.join(sd_path, '../BLIP'), 'models/blip.py', 'BLIP', []),
|
||||
(os.path.join(sd_path, '../k-diffusion'), 'k_diffusion/sampling.py', 'k_diffusion', ["atstart"]),
|
||||
@@ -35,20 +53,13 @@ for d, must_exist, what, options in path_dirs:
|
||||
d = os.path.abspath(d)
|
||||
if "atstart" in options:
|
||||
sys.path.insert(0, d)
|
||||
elif "sgm" in options:
|
||||
# Stable Diffusion XL repo has scripts dir with __init__.py in it which ruins every extension's scripts dir, so we
|
||||
# import sgm and remove it from sys.path so that when a script imports scripts.something, it doesbn't use sgm's scripts dir.
|
||||
|
||||
sys.path.insert(0, d)
|
||||
import sgm # noqa: F401
|
||||
sys.path.pop(0)
|
||||
else:
|
||||
sys.path.append(d)
|
||||
paths[what] = d
|
||||
|
||||
|
||||
class Prioritize:
|
||||
def __init__(self, name):
|
||||
self.name = name
|
||||
self.path = None
|
||||
|
||||
def __enter__(self):
|
||||
self.path = sys.path.copy()
|
||||
sys.path = [paths[self.name]] + sys.path
|
||||
|
||||
def __exit__(self, exc_type, exc_val, exc_tb):
|
||||
sys.path = self.path
|
||||
self.path = None
|
||||
|
||||
@@ -9,8 +9,7 @@ from modules.shared import opts
|
||||
def run_postprocessing(extras_mode, image, image_folder, input_dir, output_dir, show_extras_results, *args, save_output: bool = True):
|
||||
devices.torch_gc()
|
||||
|
||||
shared.state.begin()
|
||||
shared.state.job = 'extras'
|
||||
shared.state.begin(job="extras")
|
||||
|
||||
image_data = []
|
||||
image_names = []
|
||||
@@ -54,7 +53,9 @@ def run_postprocessing(extras_mode, image, image_folder, input_dir, output_dir,
|
||||
for image, name in zip(image_data, image_names):
|
||||
shared.state.textinfo = name
|
||||
|
||||
existing_pnginfo = image.info or {}
|
||||
parameters, existing_pnginfo = images.read_info_from_image(image)
|
||||
if parameters:
|
||||
existing_pnginfo["parameters"] = parameters
|
||||
|
||||
pp = scripts_postprocessing.PostprocessedImage(image.convert("RGB"))
|
||||
|
||||
|
||||
@@ -184,6 +184,8 @@ class StableDiffusionProcessing:
|
||||
self.uc = None
|
||||
self.c = None
|
||||
|
||||
self.user = None
|
||||
|
||||
@property
|
||||
def sd_model(self):
|
||||
return shared.sd_model
|
||||
@@ -328,8 +330,21 @@ class StableDiffusionProcessing:
|
||||
|
||||
caches is a list with items described above.
|
||||
"""
|
||||
|
||||
cached_params = (
|
||||
required_prompts,
|
||||
steps,
|
||||
opts.CLIP_stop_at_last_layers,
|
||||
shared.sd_model.sd_checkpoint_info,
|
||||
extra_network_data,
|
||||
opts.sdxl_crop_left,
|
||||
opts.sdxl_crop_top,
|
||||
self.width,
|
||||
self.height,
|
||||
)
|
||||
|
||||
for cache in caches:
|
||||
if cache[0] is not None and (required_prompts, steps, opts.CLIP_stop_at_last_layers, shared.sd_model.sd_checkpoint_info, extra_network_data) == cache[0]:
|
||||
if cache[0] is not None and cached_params == cache[0]:
|
||||
return cache[1]
|
||||
|
||||
cache = caches[0]
|
||||
@@ -337,14 +352,17 @@ class StableDiffusionProcessing:
|
||||
with devices.autocast():
|
||||
cache[1] = function(shared.sd_model, required_prompts, steps)
|
||||
|
||||
cache[0] = (required_prompts, steps, opts.CLIP_stop_at_last_layers, shared.sd_model.sd_checkpoint_info, extra_network_data)
|
||||
cache[0] = cached_params
|
||||
return cache[1]
|
||||
|
||||
def setup_conds(self):
|
||||
prompts = prompt_parser.SdConditioning(self.prompts, width=self.width, height=self.height)
|
||||
negative_prompts = prompt_parser.SdConditioning(self.negative_prompts, width=self.width, height=self.height, is_negative_prompt=True)
|
||||
|
||||
sampler_config = sd_samplers.find_sampler_config(self.sampler_name)
|
||||
self.step_multiplier = 2 if sampler_config and sampler_config.options.get("second_order", False) else 1
|
||||
self.uc = self.get_conds_with_caching(prompt_parser.get_learned_conditioning, self.negative_prompts, self.steps * self.step_multiplier, [self.cached_uc], self.extra_network_data)
|
||||
self.c = self.get_conds_with_caching(prompt_parser.get_multicond_learned_conditioning, self.prompts, self.steps * self.step_multiplier, [self.cached_c], self.extra_network_data)
|
||||
self.uc = self.get_conds_with_caching(prompt_parser.get_learned_conditioning, negative_prompts, self.steps * self.step_multiplier, [self.cached_uc], self.extra_network_data)
|
||||
self.c = self.get_conds_with_caching(prompt_parser.get_multicond_learned_conditioning, prompts, self.steps * self.step_multiplier, [self.cached_c], self.extra_network_data)
|
||||
|
||||
def parse_extra_network_prompts(self):
|
||||
self.prompts, self.extra_network_data = extra_networks.parse_prompts(self.prompts)
|
||||
@@ -521,8 +539,7 @@ def create_random_tensors(shape, seeds, subseeds=None, subseed_strength=0.0, see
|
||||
|
||||
|
||||
def decode_first_stage(model, x):
|
||||
with devices.autocast(disable=x.dtype == devices.dtype_vae):
|
||||
x = model.decode_first_stage(x)
|
||||
x = model.decode_first_stage(x.to(devices.dtype_vae))
|
||||
|
||||
return x
|
||||
|
||||
@@ -549,7 +566,7 @@ def program_version():
|
||||
return res
|
||||
|
||||
|
||||
def create_infotext(p, all_prompts, all_seeds, all_subseeds, comments=None, iteration=0, position_in_batch=0):
|
||||
def create_infotext(p, all_prompts, all_seeds, all_subseeds, comments=None, iteration=0, position_in_batch=0, use_main_prompt=False):
|
||||
index = position_in_batch + iteration * p.batch_size
|
||||
|
||||
clip_skip = getattr(p, 'clip_skip', opts.CLIP_stop_at_last_layers)
|
||||
@@ -573,7 +590,7 @@ def create_infotext(p, all_prompts, all_seeds, all_subseeds, comments=None, iter
|
||||
"Model": (None if not opts.add_model_name_to_info or not shared.sd_model.sd_checkpoint_info.model_name else shared.sd_model.sd_checkpoint_info.model_name.replace(',', '').replace(':', '')),
|
||||
"Variation seed": (None if p.subseed_strength == 0 else all_subseeds[index]),
|
||||
"Variation seed strength": (None if p.subseed_strength == 0 else p.subseed_strength),
|
||||
"Seed resize from": (None if p.seed_resize_from_w == 0 or p.seed_resize_from_h == 0 else f"{p.seed_resize_from_w}x{p.seed_resize_from_h}"),
|
||||
"Seed resize from": (None if p.seed_resize_from_w <= 0 or p.seed_resize_from_h <= 0 else f"{p.seed_resize_from_w}x{p.seed_resize_from_h}"),
|
||||
"Denoising strength": getattr(p, 'denoising_strength', None),
|
||||
"Conditional mask weight": getattr(p, "inpainting_mask_weight", shared.opts.inpainting_mask_weight) if p.is_using_inpainting_conditioning else None,
|
||||
"Clip skip": None if clip_skip <= 1 else clip_skip,
|
||||
@@ -585,13 +602,15 @@ def create_infotext(p, all_prompts, all_seeds, all_subseeds, comments=None, iter
|
||||
"NGMS": None if p.s_min_uncond == 0 else p.s_min_uncond,
|
||||
**p.extra_generation_params,
|
||||
"Version": program_version() if opts.add_version_to_infotext else None,
|
||||
"User": p.user if opts.add_user_name_to_info else None,
|
||||
}
|
||||
|
||||
generation_params_text = ", ".join([k if k == v else f'{k}: {generation_parameters_copypaste.quote(v)}' for k, v in generation_params.items() if v is not None])
|
||||
|
||||
prompt_text = p.prompt if use_main_prompt else all_prompts[index]
|
||||
negative_prompt_text = f"\nNegative prompt: {p.all_negative_prompts[index]}" if p.all_negative_prompts[index] else ""
|
||||
|
||||
return f"{all_prompts[index]}{negative_prompt_text}\n{generation_params_text}".strip()
|
||||
return f"{prompt_text}{negative_prompt_text}\n{generation_params_text}".strip()
|
||||
|
||||
|
||||
def process_images(p: StableDiffusionProcessing) -> Processed:
|
||||
@@ -602,7 +621,7 @@ def process_images(p: StableDiffusionProcessing) -> Processed:
|
||||
|
||||
try:
|
||||
# if no checkpoint override or the override checkpoint can't be found, remove override entry and load opts checkpoint
|
||||
if sd_models.checkpoint_alisases.get(p.override_settings.get('sd_model_checkpoint')) is None:
|
||||
if sd_models.checkpoint_aliases.get(p.override_settings.get('sd_model_checkpoint')) is None:
|
||||
p.override_settings.pop('sd_model_checkpoint', None)
|
||||
sd_models.reload_model_weights()
|
||||
|
||||
@@ -663,8 +682,8 @@ def process_images_inner(p: StableDiffusionProcessing) -> Processed:
|
||||
else:
|
||||
p.all_subseeds = [int(subseed) + x for x in range(len(p.all_prompts))]
|
||||
|
||||
def infotext(iteration=0, position_in_batch=0):
|
||||
return create_infotext(p, p.all_prompts, p.all_seeds, p.all_subseeds, comments, iteration, position_in_batch)
|
||||
def infotext(iteration=0, position_in_batch=0, use_main_prompt=False):
|
||||
return create_infotext(p, p.all_prompts, p.all_seeds, p.all_subseeds, comments, iteration, position_in_batch, use_main_prompt)
|
||||
|
||||
if os.path.exists(cmd_opts.embeddings_dir) and not p.do_not_reload_embeddings:
|
||||
model_hijack.embedding_db.load_textual_inversion_embeddings()
|
||||
@@ -728,9 +747,10 @@ def process_images_inner(p: StableDiffusionProcessing) -> Processed:
|
||||
|
||||
p.setup_conds()
|
||||
|
||||
if len(model_hijack.comments) > 0:
|
||||
for comment in model_hijack.comments:
|
||||
comments[comment] = 1
|
||||
for comment in model_hijack.comments:
|
||||
comments[comment] = 1
|
||||
|
||||
p.extra_generation_params.update(model_hijack.extra_generation_params)
|
||||
|
||||
if p.n_iter > 1:
|
||||
shared.state.job = f"Batch {n+1} out of {p.n_iter}"
|
||||
@@ -824,7 +844,7 @@ def process_images_inner(p: StableDiffusionProcessing) -> Processed:
|
||||
grid = images.image_grid(output_images, p.batch_size)
|
||||
|
||||
if opts.return_grid:
|
||||
text = infotext()
|
||||
text = infotext(use_main_prompt=True)
|
||||
infotexts.insert(0, text)
|
||||
if opts.enable_pnginfo:
|
||||
grid.info["parameters"] = text
|
||||
@@ -832,7 +852,7 @@ def process_images_inner(p: StableDiffusionProcessing) -> Processed:
|
||||
index_of_first_image = 1
|
||||
|
||||
if opts.grid_save:
|
||||
images.save_image(grid, p.outpath_grids, "grid", p.all_seeds[0], p.all_prompts[0], opts.grid_format, info=infotext(), short_filename=not opts.grid_extended_filename, p=p, grid=True)
|
||||
images.save_image(grid, p.outpath_grids, "grid", p.all_seeds[0], p.all_prompts[0], opts.grid_format, info=infotext(use_main_prompt=True), short_filename=not opts.grid_extended_filename, p=p, grid=True)
|
||||
|
||||
if not p.disable_extra_networks and p.extra_network_data:
|
||||
extra_networks.deactivate(p, p.extra_network_data)
|
||||
@@ -1074,6 +1094,9 @@ class StableDiffusionProcessingTxt2Img(StableDiffusionProcessing):
|
||||
|
||||
sd_models.apply_token_merging(self.sd_model, self.get_token_merging_ratio(for_hr=True))
|
||||
|
||||
if self.scripts is not None:
|
||||
self.scripts.before_hr(self)
|
||||
|
||||
samples = self.sampler.sample_img2img(self, samples, noise, self.hr_c, self.hr_uc, steps=self.hr_second_pass_steps or self.steps, image_conditioning=image_conditioning)
|
||||
|
||||
sd_models.apply_token_merging(self.sd_model, self.get_token_merging_ratio())
|
||||
@@ -1280,7 +1303,7 @@ class StableDiffusionProcessingImg2Img(StableDiffusionProcessing):
|
||||
|
||||
image = torch.from_numpy(batch_images)
|
||||
image = 2. * image - 1.
|
||||
image = image.to(shared.device)
|
||||
image = image.to(shared.device, dtype=devices.dtype_vae)
|
||||
|
||||
self.init_latent = self.sd_model.get_first_stage_encoding(self.sd_model.encode_first_stage(image))
|
||||
|
||||
|
||||
@@ -1,3 +1,5 @@
|
||||
from __future__ import annotations
|
||||
|
||||
import re
|
||||
from collections import namedtuple
|
||||
from typing import List
|
||||
@@ -109,7 +111,25 @@ def get_learned_conditioning_prompt_schedules(prompts, steps):
|
||||
ScheduledPromptConditioning = namedtuple("ScheduledPromptConditioning", ["end_at_step", "cond"])
|
||||
|
||||
|
||||
def get_learned_conditioning(model, prompts, steps):
|
||||
class SdConditioning(list):
|
||||
"""
|
||||
A list with prompts for stable diffusion's conditioner model.
|
||||
Can also specify width and height of created image - SDXL needs it.
|
||||
"""
|
||||
def __init__(self, prompts, is_negative_prompt=False, width=None, height=None, copy_from=None):
|
||||
super().__init__()
|
||||
self.extend(prompts)
|
||||
|
||||
if copy_from is None:
|
||||
copy_from = prompts
|
||||
|
||||
self.is_negative_prompt = is_negative_prompt or getattr(copy_from, 'is_negative_prompt', False)
|
||||
self.width = width or getattr(copy_from, 'width', None)
|
||||
self.height = height or getattr(copy_from, 'height', None)
|
||||
|
||||
|
||||
|
||||
def get_learned_conditioning(model, prompts: SdConditioning | list[str], steps):
|
||||
"""converts a list of prompts into a list of prompt schedules - each schedule is a list of ScheduledPromptConditioning, specifying the comdition (cond),
|
||||
and the sampling step at which this condition is to be replaced by the next one.
|
||||
|
||||
@@ -139,12 +159,17 @@ def get_learned_conditioning(model, prompts, steps):
|
||||
res.append(cached)
|
||||
continue
|
||||
|
||||
texts = [x[1] for x in prompt_schedule]
|
||||
texts = SdConditioning([x[1] for x in prompt_schedule], copy_from=prompts)
|
||||
conds = model.get_learned_conditioning(texts)
|
||||
|
||||
cond_schedule = []
|
||||
for i, (end_at_step, _) in enumerate(prompt_schedule):
|
||||
cond_schedule.append(ScheduledPromptConditioning(end_at_step, conds[i]))
|
||||
if isinstance(conds, dict):
|
||||
cond = {k: v[i] for k, v in conds.items()}
|
||||
else:
|
||||
cond = conds[i]
|
||||
|
||||
cond_schedule.append(ScheduledPromptConditioning(end_at_step, cond))
|
||||
|
||||
cache[prompt] = cond_schedule
|
||||
res.append(cond_schedule)
|
||||
@@ -155,11 +180,13 @@ def get_learned_conditioning(model, prompts, steps):
|
||||
re_AND = re.compile(r"\bAND\b")
|
||||
re_weight = re.compile(r"^(.*?)(?:\s*:\s*([-+]?(?:\d+\.?|\d*\.\d+)))?\s*$")
|
||||
|
||||
def get_multicond_prompt_list(prompts):
|
||||
|
||||
def get_multicond_prompt_list(prompts: SdConditioning | list[str]):
|
||||
res_indexes = []
|
||||
|
||||
prompt_flat_list = []
|
||||
prompt_indexes = {}
|
||||
prompt_flat_list = SdConditioning(prompts)
|
||||
prompt_flat_list.clear()
|
||||
|
||||
for prompt in prompts:
|
||||
subprompts = re_AND.split(prompt)
|
||||
@@ -196,6 +223,7 @@ class MulticondLearnedConditioning:
|
||||
self.shape: tuple = shape # the shape field is needed to send this object to DDIM/PLMS
|
||||
self.batch: List[List[ComposableScheduledPromptConditioning]] = batch
|
||||
|
||||
|
||||
def get_multicond_learned_conditioning(model, prompts, steps) -> MulticondLearnedConditioning:
|
||||
"""same as get_learned_conditioning, but returns a list of ScheduledPromptConditioning along with the weight objects for each prompt.
|
||||
For each prompt, the list is obtained by splitting the prompt using the AND separator.
|
||||
@@ -214,20 +242,57 @@ def get_multicond_learned_conditioning(model, prompts, steps) -> MulticondLearne
|
||||
return MulticondLearnedConditioning(shape=(len(prompts),), batch=res)
|
||||
|
||||
|
||||
class DictWithShape(dict):
|
||||
def __init__(self, x, shape):
|
||||
super().__init__()
|
||||
self.update(x)
|
||||
|
||||
@property
|
||||
def shape(self):
|
||||
return self["crossattn"].shape
|
||||
|
||||
|
||||
def reconstruct_cond_batch(c: List[List[ScheduledPromptConditioning]], current_step):
|
||||
param = c[0][0].cond
|
||||
res = torch.zeros((len(c),) + param.shape, device=param.device, dtype=param.dtype)
|
||||
is_dict = isinstance(param, dict)
|
||||
|
||||
if is_dict:
|
||||
dict_cond = param
|
||||
res = {k: torch.zeros((len(c),) + param.shape, device=param.device, dtype=param.dtype) for k, param in dict_cond.items()}
|
||||
res = DictWithShape(res, (len(c),) + dict_cond['crossattn'].shape)
|
||||
else:
|
||||
res = torch.zeros((len(c),) + param.shape, device=param.device, dtype=param.dtype)
|
||||
|
||||
for i, cond_schedule in enumerate(c):
|
||||
target_index = 0
|
||||
for current, entry in enumerate(cond_schedule):
|
||||
if current_step <= entry.end_at_step:
|
||||
target_index = current
|
||||
break
|
||||
res[i] = cond_schedule[target_index].cond
|
||||
|
||||
if is_dict:
|
||||
for k, param in cond_schedule[target_index].cond.items():
|
||||
res[k][i] = param
|
||||
else:
|
||||
res[i] = cond_schedule[target_index].cond
|
||||
|
||||
return res
|
||||
|
||||
|
||||
def stack_conds(tensors):
|
||||
# if prompts have wildly different lengths above the limit we'll get tensors of different shapes
|
||||
# and won't be able to torch.stack them. So this fixes that.
|
||||
token_count = max([x.shape[0] for x in tensors])
|
||||
for i in range(len(tensors)):
|
||||
if tensors[i].shape[0] != token_count:
|
||||
last_vector = tensors[i][-1:]
|
||||
last_vector_repeated = last_vector.repeat([token_count - tensors[i].shape[0], 1])
|
||||
tensors[i] = torch.vstack([tensors[i], last_vector_repeated])
|
||||
|
||||
return torch.stack(tensors)
|
||||
|
||||
|
||||
|
||||
def reconstruct_multicond_batch(c: MulticondLearnedConditioning, current_step):
|
||||
param = c.batch[0][0].schedules[0].cond
|
||||
|
||||
@@ -249,16 +314,14 @@ def reconstruct_multicond_batch(c: MulticondLearnedConditioning, current_step):
|
||||
|
||||
conds_list.append(conds_for_batch)
|
||||
|
||||
# if prompts have wildly different lengths above the limit we'll get tensors fo different shapes
|
||||
# and won't be able to torch.stack them. So this fixes that.
|
||||
token_count = max([x.shape[0] for x in tensors])
|
||||
for i in range(len(tensors)):
|
||||
if tensors[i].shape[0] != token_count:
|
||||
last_vector = tensors[i][-1:]
|
||||
last_vector_repeated = last_vector.repeat([token_count - tensors[i].shape[0], 1])
|
||||
tensors[i] = torch.vstack([tensors[i], last_vector_repeated])
|
||||
if isinstance(tensors[0], dict):
|
||||
keys = list(tensors[0].keys())
|
||||
stacked = {k: stack_conds([x[k] for x in tensors]) for k in keys}
|
||||
stacked = DictWithShape(stacked, stacked['crossattn'].shape)
|
||||
else:
|
||||
stacked = stack_conds(tensors).to(device=param.device, dtype=param.dtype)
|
||||
|
||||
return conds_list, torch.stack(tensors).to(device=param.device, dtype=param.dtype)
|
||||
return conds_list, stacked
|
||||
|
||||
|
||||
re_attention = re.compile(r"""
|
||||
|
||||
@@ -2,7 +2,6 @@ import os
|
||||
|
||||
import numpy as np
|
||||
from PIL import Image
|
||||
from basicsr.utils.download_util import load_file_from_url
|
||||
from realesrgan import RealESRGANer
|
||||
|
||||
from modules.upscaler import Upscaler, UpscalerData
|
||||
@@ -43,9 +42,10 @@ class UpscalerRealESRGAN(Upscaler):
|
||||
if not self.enable:
|
||||
return img
|
||||
|
||||
info = self.load_model(path)
|
||||
if not os.path.exists(info.local_data_path):
|
||||
print(f"Unable to load RealESRGAN model: {info.name}")
|
||||
try:
|
||||
info = self.load_model(path)
|
||||
except Exception:
|
||||
errors.report(f"Unable to load RealESRGAN model {path}", exc_info=True)
|
||||
return img
|
||||
|
||||
upsampler = RealESRGANer(
|
||||
@@ -63,20 +63,17 @@ class UpscalerRealESRGAN(Upscaler):
|
||||
return image
|
||||
|
||||
def load_model(self, path):
|
||||
try:
|
||||
info = next(iter([scaler for scaler in self.scalers if scaler.data_path == path]), None)
|
||||
|
||||
if info is None:
|
||||
print(f"Unable to find model info: {path}")
|
||||
return None
|
||||
|
||||
if info.local_data_path.startswith("http"):
|
||||
info.local_data_path = load_file_from_url(url=info.data_path, model_dir=self.model_download_path, progress=True)
|
||||
|
||||
return info
|
||||
except Exception:
|
||||
errors.report("Error making Real-ESRGAN models list", exc_info=True)
|
||||
return None
|
||||
for scaler in self.scalers:
|
||||
if scaler.data_path == path:
|
||||
if scaler.local_data_path.startswith("http"):
|
||||
scaler.local_data_path = modelloader.load_file_from_url(
|
||||
scaler.data_path,
|
||||
model_dir=self.model_download_path,
|
||||
)
|
||||
if not os.path.exists(scaler.local_data_path):
|
||||
raise FileNotFoundError(f"RealESRGAN data missing: {scaler.local_data_path}")
|
||||
return scaler
|
||||
raise ValueError(f"Unable to find model info: {path}")
|
||||
|
||||
def load_models(self, _):
|
||||
return get_realesrgan_models(self)
|
||||
|
||||
@@ -1,6 +1,7 @@
|
||||
import os
|
||||
import re
|
||||
import sys
|
||||
import inspect
|
||||
from collections import namedtuple
|
||||
|
||||
import gradio as gr
|
||||
@@ -116,6 +117,21 @@ class Script:
|
||||
|
||||
pass
|
||||
|
||||
def after_extra_networks_activate(self, p, *args, **kwargs):
|
||||
"""
|
||||
Calledafter extra networks activation, before conds calculation
|
||||
allow modification of the network after extra networks activation been applied
|
||||
won't be call if p.disable_extra_networks
|
||||
|
||||
**kwargs will have those items:
|
||||
- batch_number - index of current batch, from 0 to number of batches-1
|
||||
- prompts - list of prompts for current batch; you can change contents of this list but changing the number of entries will likely break things
|
||||
- seeds - list of seeds for current batch
|
||||
- subseeds - list of subseeds for current batch
|
||||
- extra_network_data - list of ExtraNetworkParams for current stage
|
||||
"""
|
||||
pass
|
||||
|
||||
def process_batch(self, p, *args, **kwargs):
|
||||
"""
|
||||
Same as process(), but called for every batch.
|
||||
@@ -186,6 +202,11 @@ class Script:
|
||||
|
||||
return f'script_{tabname}{title}_{item_id}'
|
||||
|
||||
def before_hr(self, p, *args):
|
||||
"""
|
||||
This function is called before hires fix start.
|
||||
"""
|
||||
pass
|
||||
|
||||
current_basedir = paths.script_path
|
||||
|
||||
@@ -249,7 +270,7 @@ def load_scripts():
|
||||
|
||||
def register_scripts_from_module(module):
|
||||
for script_class in module.__dict__.values():
|
||||
if type(script_class) != type:
|
||||
if not inspect.isclass(script_class):
|
||||
continue
|
||||
|
||||
if issubclass(script_class, Script):
|
||||
@@ -483,6 +504,14 @@ class ScriptRunner:
|
||||
except Exception:
|
||||
errors.report(f"Error running before_process_batch: {script.filename}", exc_info=True)
|
||||
|
||||
def after_extra_networks_activate(self, p, **kwargs):
|
||||
for script in self.alwayson_scripts:
|
||||
try:
|
||||
script_args = p.script_args[script.args_from:script.args_to]
|
||||
script.after_extra_networks_activate(p, *script_args, **kwargs)
|
||||
except Exception:
|
||||
errors.report(f"Error running after_extra_networks_activate: {script.filename}", exc_info=True)
|
||||
|
||||
def process_batch(self, p, **kwargs):
|
||||
for script in self.alwayson_scripts:
|
||||
try:
|
||||
@@ -548,6 +577,15 @@ class ScriptRunner:
|
||||
self.scripts[si].args_to = args_to
|
||||
|
||||
|
||||
def before_hr(self, p):
|
||||
for script in self.alwayson_scripts:
|
||||
try:
|
||||
script_args = p.script_args[script.args_from:script.args_to]
|
||||
script.before_hr(p, *script_args)
|
||||
except Exception:
|
||||
errors.report(f"Error running before_hr: {script.filename}", exc_info=True)
|
||||
|
||||
|
||||
scripts_txt2img: ScriptRunner = None
|
||||
scripts_img2img: ScriptRunner = None
|
||||
scripts_postproc: scripts_postprocessing.ScriptPostprocessingRunner = None
|
||||
|
||||
@@ -15,6 +15,11 @@ import ldm.models.diffusion.ddim
|
||||
import ldm.models.diffusion.plms
|
||||
import ldm.modules.encoders.modules
|
||||
|
||||
import sgm.modules.attention
|
||||
import sgm.modules.diffusionmodules.model
|
||||
import sgm.modules.diffusionmodules.openaimodel
|
||||
import sgm.modules.encoders.modules
|
||||
|
||||
attention_CrossAttention_forward = ldm.modules.attention.CrossAttention.forward
|
||||
diffusionmodules_model_nonlinearity = ldm.modules.diffusionmodules.model.nonlinearity
|
||||
diffusionmodules_model_AttnBlock_forward = ldm.modules.diffusionmodules.model.AttnBlock.forward
|
||||
@@ -56,6 +61,9 @@ def apply_optimizations(option=None):
|
||||
ldm.modules.diffusionmodules.model.nonlinearity = silu
|
||||
ldm.modules.diffusionmodules.openaimodel.th = sd_hijack_unet.th
|
||||
|
||||
sgm.modules.diffusionmodules.model.nonlinearity = silu
|
||||
sgm.modules.diffusionmodules.openaimodel.th = sd_hijack_unet.th
|
||||
|
||||
if current_optimizer is not None:
|
||||
current_optimizer.undo()
|
||||
current_optimizer = None
|
||||
@@ -89,6 +97,10 @@ def undo_optimizations():
|
||||
ldm.modules.attention.CrossAttention.forward = hypernetwork.attention_CrossAttention_forward
|
||||
ldm.modules.diffusionmodules.model.AttnBlock.forward = diffusionmodules_model_AttnBlock_forward
|
||||
|
||||
sgm.modules.diffusionmodules.model.nonlinearity = diffusionmodules_model_nonlinearity
|
||||
sgm.modules.attention.CrossAttention.forward = hypernetwork.attention_CrossAttention_forward
|
||||
sgm.modules.diffusionmodules.model.AttnBlock.forward = diffusionmodules_model_AttnBlock_forward
|
||||
|
||||
|
||||
def fix_checkpoint():
|
||||
"""checkpoints are now added and removed in embedding/hypernet code, since torch doesn't want
|
||||
@@ -147,7 +159,6 @@ def undo_weighted_forward(sd_model):
|
||||
|
||||
class StableDiffusionModelHijack:
|
||||
fixes = None
|
||||
comments = []
|
||||
layers = None
|
||||
circular_enabled = False
|
||||
clip = None
|
||||
@@ -156,6 +167,9 @@ class StableDiffusionModelHijack:
|
||||
embedding_db = modules.textual_inversion.textual_inversion.EmbeddingDatabase()
|
||||
|
||||
def __init__(self):
|
||||
self.extra_generation_params = {}
|
||||
self.comments = []
|
||||
|
||||
self.embedding_db.add_embedding_dir(cmd_opts.embeddings_dir)
|
||||
|
||||
def apply_optimizations(self, option=None):
|
||||
@@ -166,6 +180,32 @@ class StableDiffusionModelHijack:
|
||||
undo_optimizations()
|
||||
|
||||
def hijack(self, m):
|
||||
conditioner = getattr(m, 'conditioner', None)
|
||||
if conditioner:
|
||||
text_cond_models = []
|
||||
|
||||
for i in range(len(conditioner.embedders)):
|
||||
embedder = conditioner.embedders[i]
|
||||
typename = type(embedder).__name__
|
||||
if typename == 'FrozenOpenCLIPEmbedder':
|
||||
embedder.model.token_embedding = EmbeddingsWithFixes(embedder.model.token_embedding, self)
|
||||
conditioner.embedders[i] = sd_hijack_open_clip.FrozenOpenCLIPEmbedderWithCustomWords(embedder, self)
|
||||
text_cond_models.append(conditioner.embedders[i])
|
||||
if typename == 'FrozenCLIPEmbedder':
|
||||
model_embeddings = embedder.transformer.text_model.embeddings
|
||||
model_embeddings.token_embedding = EmbeddingsWithFixes(model_embeddings.token_embedding, self)
|
||||
conditioner.embedders[i] = sd_hijack_clip.FrozenCLIPEmbedderForSDXLWithCustomWords(embedder, self)
|
||||
text_cond_models.append(conditioner.embedders[i])
|
||||
if typename == 'FrozenOpenCLIPEmbedder2':
|
||||
embedder.model.token_embedding = EmbeddingsWithFixes(embedder.model.token_embedding, self)
|
||||
conditioner.embedders[i] = sd_hijack_open_clip.FrozenOpenCLIPEmbedder2WithCustomWords(embedder, self)
|
||||
text_cond_models.append(conditioner.embedders[i])
|
||||
|
||||
if len(text_cond_models) == 1:
|
||||
m.cond_stage_model = text_cond_models[0]
|
||||
else:
|
||||
m.cond_stage_model = conditioner
|
||||
|
||||
if type(m.cond_stage_model) == xlmr.BertSeriesModelWithTransformation:
|
||||
model_embeddings = m.cond_stage_model.roberta.embeddings
|
||||
model_embeddings.token_embedding = EmbeddingsWithFixes(model_embeddings.word_embeddings, self)
|
||||
@@ -236,6 +276,7 @@ class StableDiffusionModelHijack:
|
||||
|
||||
def clear_comments(self):
|
||||
self.comments = []
|
||||
self.extra_generation_params = {}
|
||||
|
||||
def get_prompt_lengths(self, text):
|
||||
if self.clip is None:
|
||||
|
||||
@@ -42,6 +42,10 @@ class FrozenCLIPEmbedderWithCustomWordsBase(torch.nn.Module):
|
||||
self.hijack: sd_hijack.StableDiffusionModelHijack = hijack
|
||||
self.chunk_length = 75
|
||||
|
||||
self.is_trainable = getattr(wrapped, 'is_trainable', False)
|
||||
self.input_key = getattr(wrapped, 'input_key', 'txt')
|
||||
self.legacy_ucg_val = None
|
||||
|
||||
def empty_chunk(self):
|
||||
"""creates an empty PromptChunk and returns it"""
|
||||
|
||||
@@ -199,8 +203,9 @@ class FrozenCLIPEmbedderWithCustomWordsBase(torch.nn.Module):
|
||||
"""
|
||||
Accepts an array of texts; Passes texts through transformers network to create a tensor with numerical representation of those texts.
|
||||
Returns a tensor with shape of (B, T, C), where B is length of the array; T is length, in tokens, of texts (including padding) - T will
|
||||
be a multiple of 77; and C is dimensionality of each token - for SD1 it's 768, and for SD2 it's 1024.
|
||||
be a multiple of 77; and C is dimensionality of each token - for SD1 it's 768, for SD2 it's 1024, and for SDXL it's 1280.
|
||||
An example shape returned by this function can be: (2, 77, 768).
|
||||
For SDXL, instead of returning one tensor avobe, it returns a tuple with two: the other one with shape (B, 1280) with pooled values.
|
||||
Webui usually sends just one text at a time through this function - the only time when texts is an array with more than one elemenet
|
||||
is when you do prompt editing: "a picture of a [cat:dog:0.4] eating ice cream"
|
||||
"""
|
||||
@@ -229,11 +234,23 @@ class FrozenCLIPEmbedderWithCustomWordsBase(torch.nn.Module):
|
||||
z = self.process_tokens(tokens, multipliers)
|
||||
zs.append(z)
|
||||
|
||||
if len(used_embeddings) > 0:
|
||||
embeddings_list = ", ".join([f'{name} [{embedding.checksum()}]' for name, embedding in used_embeddings.items()])
|
||||
self.hijack.comments.append(f"Used embeddings: {embeddings_list}")
|
||||
if opts.textual_inversion_add_hashes_to_infotext and used_embeddings:
|
||||
hashes = []
|
||||
for name, embedding in used_embeddings.items():
|
||||
shorthash = embedding.shorthash
|
||||
if not shorthash:
|
||||
continue
|
||||
|
||||
return torch.hstack(zs)
|
||||
name = name.replace(":", "").replace(",", "")
|
||||
hashes.append(f"{name}: {shorthash}")
|
||||
|
||||
if hashes:
|
||||
self.hijack.extra_generation_params["TI hashes"] = ", ".join(hashes)
|
||||
|
||||
if getattr(self.wrapped, 'return_pooled', False):
|
||||
return torch.hstack(zs), zs[0].pooled
|
||||
else:
|
||||
return torch.hstack(zs)
|
||||
|
||||
def process_tokens(self, remade_batch_tokens, batch_multipliers):
|
||||
"""
|
||||
@@ -256,9 +273,9 @@ class FrozenCLIPEmbedderWithCustomWordsBase(torch.nn.Module):
|
||||
# restoring original mean is likely not correct, but it seems to work well to prevent artifacts that happen otherwise
|
||||
batch_multipliers = torch.asarray(batch_multipliers).to(devices.device)
|
||||
original_mean = z.mean()
|
||||
z = z * batch_multipliers.reshape(batch_multipliers.shape + (1,)).expand(z.shape)
|
||||
z *= batch_multipliers.reshape(batch_multipliers.shape + (1,)).expand(z.shape)
|
||||
new_mean = z.mean()
|
||||
z = z * (original_mean / new_mean)
|
||||
z *= (original_mean / new_mean)
|
||||
|
||||
return z
|
||||
|
||||
@@ -315,3 +332,18 @@ class FrozenCLIPEmbedderWithCustomWords(FrozenCLIPEmbedderWithCustomWordsBase):
|
||||
embedded = embedding_layer.token_embedding.wrapped(ids.to(embedding_layer.token_embedding.wrapped.weight.device)).squeeze(0)
|
||||
|
||||
return embedded
|
||||
|
||||
|
||||
class FrozenCLIPEmbedderForSDXLWithCustomWords(FrozenCLIPEmbedderWithCustomWords):
|
||||
def __init__(self, wrapped, hijack):
|
||||
super().__init__(wrapped, hijack)
|
||||
|
||||
def encode_with_transformers(self, tokens):
|
||||
outputs = self.wrapped.transformer(input_ids=tokens, output_hidden_states=self.wrapped.layer == "hidden")
|
||||
|
||||
if self.wrapped.layer == "last":
|
||||
z = outputs.last_hidden_state
|
||||
else:
|
||||
z = outputs.hidden_states[self.wrapped.layer_idx]
|
||||
|
||||
return z
|
||||
|
||||
@@ -32,6 +32,40 @@ class FrozenOpenCLIPEmbedderWithCustomWords(sd_hijack_clip.FrozenCLIPEmbedderWit
|
||||
def encode_embedding_init_text(self, init_text, nvpt):
|
||||
ids = tokenizer.encode(init_text)
|
||||
ids = torch.asarray([ids], device=devices.device, dtype=torch.int)
|
||||
embedded = self.wrapped.model.token_embedding.wrapped(ids).squeeze(0)
|
||||
embedded = self.wrapped.model.token_embedding.wrapped(ids.to(self.wrapped.model.token_embedding.wrapped.weight.device)).squeeze(0)
|
||||
|
||||
return embedded
|
||||
|
||||
|
||||
class FrozenOpenCLIPEmbedder2WithCustomWords(sd_hijack_clip.FrozenCLIPEmbedderWithCustomWordsBase):
|
||||
def __init__(self, wrapped, hijack):
|
||||
super().__init__(wrapped, hijack)
|
||||
|
||||
self.comma_token = [v for k, v in tokenizer.encoder.items() if k == ',</w>'][0]
|
||||
self.id_start = tokenizer.encoder["<start_of_text>"]
|
||||
self.id_end = tokenizer.encoder["<end_of_text>"]
|
||||
self.id_pad = 0
|
||||
|
||||
def tokenize(self, texts):
|
||||
assert not opts.use_old_emphasis_implementation, 'Old emphasis implementation not supported for Open Clip'
|
||||
|
||||
tokenized = [tokenizer.encode(text) for text in texts]
|
||||
|
||||
return tokenized
|
||||
|
||||
def encode_with_transformers(self, tokens):
|
||||
d = self.wrapped.encode_with_transformer(tokens)
|
||||
z = d[self.wrapped.layer]
|
||||
|
||||
pooled = d.get("pooled")
|
||||
if pooled is not None:
|
||||
z.pooled = pooled
|
||||
|
||||
return z
|
||||
|
||||
def encode_embedding_init_text(self, init_text, nvpt):
|
||||
ids = tokenizer.encode(init_text)
|
||||
ids = torch.asarray([ids], device=devices.device, dtype=torch.int)
|
||||
embedded = self.wrapped.model.token_embedding.wrapped(ids.to(self.wrapped.model.token_embedding.wrapped.weight.device)).squeeze(0)
|
||||
|
||||
return embedded
|
||||
|
||||
@@ -14,7 +14,11 @@ from modules.hypernetworks import hypernetwork
|
||||
import ldm.modules.attention
|
||||
import ldm.modules.diffusionmodules.model
|
||||
|
||||
import sgm.modules.attention
|
||||
import sgm.modules.diffusionmodules.model
|
||||
|
||||
diffusionmodules_model_AttnBlock_forward = ldm.modules.diffusionmodules.model.AttnBlock.forward
|
||||
sgm_diffusionmodules_model_AttnBlock_forward = sgm.modules.diffusionmodules.model.AttnBlock.forward
|
||||
|
||||
|
||||
class SdOptimization:
|
||||
@@ -39,6 +43,9 @@ class SdOptimization:
|
||||
ldm.modules.attention.CrossAttention.forward = hypernetwork.attention_CrossAttention_forward
|
||||
ldm.modules.diffusionmodules.model.AttnBlock.forward = diffusionmodules_model_AttnBlock_forward
|
||||
|
||||
sgm.modules.attention.CrossAttention.forward = hypernetwork.attention_CrossAttention_forward
|
||||
sgm.modules.diffusionmodules.model.AttnBlock.forward = sgm_diffusionmodules_model_AttnBlock_forward
|
||||
|
||||
|
||||
class SdOptimizationXformers(SdOptimization):
|
||||
name = "xformers"
|
||||
@@ -51,6 +58,8 @@ class SdOptimizationXformers(SdOptimization):
|
||||
def apply(self):
|
||||
ldm.modules.attention.CrossAttention.forward = xformers_attention_forward
|
||||
ldm.modules.diffusionmodules.model.AttnBlock.forward = xformers_attnblock_forward
|
||||
sgm.modules.attention.CrossAttention.forward = xformers_attention_forward
|
||||
sgm.modules.diffusionmodules.model.AttnBlock.forward = xformers_attnblock_forward
|
||||
|
||||
|
||||
class SdOptimizationSdpNoMem(SdOptimization):
|
||||
@@ -65,6 +74,8 @@ class SdOptimizationSdpNoMem(SdOptimization):
|
||||
def apply(self):
|
||||
ldm.modules.attention.CrossAttention.forward = scaled_dot_product_no_mem_attention_forward
|
||||
ldm.modules.diffusionmodules.model.AttnBlock.forward = sdp_no_mem_attnblock_forward
|
||||
sgm.modules.attention.CrossAttention.forward = scaled_dot_product_no_mem_attention_forward
|
||||
sgm.modules.diffusionmodules.model.AttnBlock.forward = sdp_no_mem_attnblock_forward
|
||||
|
||||
|
||||
class SdOptimizationSdp(SdOptimizationSdpNoMem):
|
||||
@@ -76,6 +87,8 @@ class SdOptimizationSdp(SdOptimizationSdpNoMem):
|
||||
def apply(self):
|
||||
ldm.modules.attention.CrossAttention.forward = scaled_dot_product_attention_forward
|
||||
ldm.modules.diffusionmodules.model.AttnBlock.forward = sdp_attnblock_forward
|
||||
sgm.modules.attention.CrossAttention.forward = scaled_dot_product_attention_forward
|
||||
sgm.modules.diffusionmodules.model.AttnBlock.forward = sdp_attnblock_forward
|
||||
|
||||
|
||||
class SdOptimizationSubQuad(SdOptimization):
|
||||
@@ -86,6 +99,8 @@ class SdOptimizationSubQuad(SdOptimization):
|
||||
def apply(self):
|
||||
ldm.modules.attention.CrossAttention.forward = sub_quad_attention_forward
|
||||
ldm.modules.diffusionmodules.model.AttnBlock.forward = sub_quad_attnblock_forward
|
||||
sgm.modules.attention.CrossAttention.forward = sub_quad_attention_forward
|
||||
sgm.modules.diffusionmodules.model.AttnBlock.forward = sub_quad_attnblock_forward
|
||||
|
||||
|
||||
class SdOptimizationV1(SdOptimization):
|
||||
@@ -94,9 +109,9 @@ class SdOptimizationV1(SdOptimization):
|
||||
cmd_opt = "opt_split_attention_v1"
|
||||
priority = 10
|
||||
|
||||
|
||||
def apply(self):
|
||||
ldm.modules.attention.CrossAttention.forward = split_cross_attention_forward_v1
|
||||
sgm.modules.attention.CrossAttention.forward = split_cross_attention_forward_v1
|
||||
|
||||
|
||||
class SdOptimizationInvokeAI(SdOptimization):
|
||||
@@ -109,6 +124,7 @@ class SdOptimizationInvokeAI(SdOptimization):
|
||||
|
||||
def apply(self):
|
||||
ldm.modules.attention.CrossAttention.forward = split_cross_attention_forward_invokeAI
|
||||
sgm.modules.attention.CrossAttention.forward = split_cross_attention_forward_invokeAI
|
||||
|
||||
|
||||
class SdOptimizationDoggettx(SdOptimization):
|
||||
@@ -119,6 +135,8 @@ class SdOptimizationDoggettx(SdOptimization):
|
||||
def apply(self):
|
||||
ldm.modules.attention.CrossAttention.forward = split_cross_attention_forward
|
||||
ldm.modules.diffusionmodules.model.AttnBlock.forward = cross_attention_attnblock_forward
|
||||
sgm.modules.attention.CrossAttention.forward = split_cross_attention_forward
|
||||
sgm.modules.diffusionmodules.model.AttnBlock.forward = cross_attention_attnblock_forward
|
||||
|
||||
|
||||
def list_optimizers(res):
|
||||
@@ -155,7 +173,7 @@ def get_available_vram():
|
||||
|
||||
|
||||
# see https://github.com/basujindal/stable-diffusion/pull/117 for discussion
|
||||
def split_cross_attention_forward_v1(self, x, context=None, mask=None):
|
||||
def split_cross_attention_forward_v1(self, x, context=None, mask=None, **kwargs):
|
||||
h = self.heads
|
||||
|
||||
q_in = self.to_q(x)
|
||||
@@ -196,7 +214,7 @@ def split_cross_attention_forward_v1(self, x, context=None, mask=None):
|
||||
|
||||
|
||||
# taken from https://github.com/Doggettx/stable-diffusion and modified
|
||||
def split_cross_attention_forward(self, x, context=None, mask=None):
|
||||
def split_cross_attention_forward(self, x, context=None, mask=None, **kwargs):
|
||||
h = self.heads
|
||||
|
||||
q_in = self.to_q(x)
|
||||
@@ -262,11 +280,13 @@ def split_cross_attention_forward(self, x, context=None, mask=None):
|
||||
# -- Taken from https://github.com/invoke-ai/InvokeAI and modified --
|
||||
mem_total_gb = psutil.virtual_memory().total // (1 << 30)
|
||||
|
||||
|
||||
def einsum_op_compvis(q, k, v):
|
||||
s = einsum('b i d, b j d -> b i j', q, k)
|
||||
s = s.softmax(dim=-1, dtype=s.dtype)
|
||||
return einsum('b i j, b j d -> b i d', s, v)
|
||||
|
||||
|
||||
def einsum_op_slice_0(q, k, v, slice_size):
|
||||
r = torch.zeros(q.shape[0], q.shape[1], v.shape[2], device=q.device, dtype=q.dtype)
|
||||
for i in range(0, q.shape[0], slice_size):
|
||||
@@ -274,6 +294,7 @@ def einsum_op_slice_0(q, k, v, slice_size):
|
||||
r[i:end] = einsum_op_compvis(q[i:end], k[i:end], v[i:end])
|
||||
return r
|
||||
|
||||
|
||||
def einsum_op_slice_1(q, k, v, slice_size):
|
||||
r = torch.zeros(q.shape[0], q.shape[1], v.shape[2], device=q.device, dtype=q.dtype)
|
||||
for i in range(0, q.shape[1], slice_size):
|
||||
@@ -281,6 +302,7 @@ def einsum_op_slice_1(q, k, v, slice_size):
|
||||
r[:, i:end] = einsum_op_compvis(q[:, i:end], k, v)
|
||||
return r
|
||||
|
||||
|
||||
def einsum_op_mps_v1(q, k, v):
|
||||
if q.shape[0] * q.shape[1] <= 2**16: # (512x512) max q.shape[1]: 4096
|
||||
return einsum_op_compvis(q, k, v)
|
||||
@@ -290,12 +312,14 @@ def einsum_op_mps_v1(q, k, v):
|
||||
slice_size -= 1
|
||||
return einsum_op_slice_1(q, k, v, slice_size)
|
||||
|
||||
|
||||
def einsum_op_mps_v2(q, k, v):
|
||||
if mem_total_gb > 8 and q.shape[0] * q.shape[1] <= 2**16:
|
||||
return einsum_op_compvis(q, k, v)
|
||||
else:
|
||||
return einsum_op_slice_0(q, k, v, 1)
|
||||
|
||||
|
||||
def einsum_op_tensor_mem(q, k, v, max_tensor_mb):
|
||||
size_mb = q.shape[0] * q.shape[1] * k.shape[1] * q.element_size() // (1 << 20)
|
||||
if size_mb <= max_tensor_mb:
|
||||
@@ -305,6 +329,7 @@ def einsum_op_tensor_mem(q, k, v, max_tensor_mb):
|
||||
return einsum_op_slice_0(q, k, v, q.shape[0] // div)
|
||||
return einsum_op_slice_1(q, k, v, max(q.shape[1] // div, 1))
|
||||
|
||||
|
||||
def einsum_op_cuda(q, k, v):
|
||||
stats = torch.cuda.memory_stats(q.device)
|
||||
mem_active = stats['active_bytes.all.current']
|
||||
@@ -315,6 +340,7 @@ def einsum_op_cuda(q, k, v):
|
||||
# Divide factor of safety as there's copying and fragmentation
|
||||
return einsum_op_tensor_mem(q, k, v, mem_free_total / 3.3 / (1 << 20))
|
||||
|
||||
|
||||
def einsum_op(q, k, v):
|
||||
if q.device.type == 'cuda':
|
||||
return einsum_op_cuda(q, k, v)
|
||||
@@ -328,7 +354,8 @@ def einsum_op(q, k, v):
|
||||
# Tested on i7 with 8MB L3 cache.
|
||||
return einsum_op_tensor_mem(q, k, v, 32)
|
||||
|
||||
def split_cross_attention_forward_invokeAI(self, x, context=None, mask=None):
|
||||
|
||||
def split_cross_attention_forward_invokeAI(self, x, context=None, mask=None, **kwargs):
|
||||
h = self.heads
|
||||
|
||||
q = self.to_q(x)
|
||||
@@ -356,7 +383,7 @@ def split_cross_attention_forward_invokeAI(self, x, context=None, mask=None):
|
||||
|
||||
# Based on Birch-san's modified implementation of sub-quadratic attention from https://github.com/Birch-san/diffusers/pull/1
|
||||
# The sub_quad_attention_forward function is under the MIT License listed under Memory Efficient Attention in the Licenses section of the web UI interface
|
||||
def sub_quad_attention_forward(self, x, context=None, mask=None):
|
||||
def sub_quad_attention_forward(self, x, context=None, mask=None, **kwargs):
|
||||
assert mask is None, "attention-mask not currently implemented for SubQuadraticCrossAttnProcessor."
|
||||
|
||||
h = self.heads
|
||||
@@ -392,6 +419,7 @@ def sub_quad_attention_forward(self, x, context=None, mask=None):
|
||||
|
||||
return x
|
||||
|
||||
|
||||
def sub_quad_attention(q, k, v, q_chunk_size=1024, kv_chunk_size=None, kv_chunk_size_min=None, chunk_threshold=None, use_checkpoint=True):
|
||||
bytes_per_token = torch.finfo(q.dtype).bits//8
|
||||
batch_x_heads, q_tokens, _ = q.shape
|
||||
@@ -442,7 +470,7 @@ def get_xformers_flash_attention_op(q, k, v):
|
||||
return None
|
||||
|
||||
|
||||
def xformers_attention_forward(self, x, context=None, mask=None):
|
||||
def xformers_attention_forward(self, x, context=None, mask=None, **kwargs):
|
||||
h = self.heads
|
||||
q_in = self.to_q(x)
|
||||
context = default(context, x)
|
||||
@@ -465,9 +493,10 @@ def xformers_attention_forward(self, x, context=None, mask=None):
|
||||
out = rearrange(out, 'b n h d -> b n (h d)', h=h)
|
||||
return self.to_out(out)
|
||||
|
||||
|
||||
# Based on Diffusers usage of scaled dot product attention from https://github.com/huggingface/diffusers/blob/c7da8fd23359a22d0df2741688b5b4f33c26df21/src/diffusers/models/cross_attention.py
|
||||
# The scaled_dot_product_attention_forward function contains parts of code under Apache-2.0 license listed under Scaled Dot Product Attention in the Licenses section of the web UI interface
|
||||
def scaled_dot_product_attention_forward(self, x, context=None, mask=None):
|
||||
def scaled_dot_product_attention_forward(self, x, context=None, mask=None, **kwargs):
|
||||
batch_size, sequence_length, inner_dim = x.shape
|
||||
|
||||
if mask is not None:
|
||||
@@ -507,10 +536,12 @@ def scaled_dot_product_attention_forward(self, x, context=None, mask=None):
|
||||
hidden_states = self.to_out[1](hidden_states)
|
||||
return hidden_states
|
||||
|
||||
def scaled_dot_product_no_mem_attention_forward(self, x, context=None, mask=None):
|
||||
|
||||
def scaled_dot_product_no_mem_attention_forward(self, x, context=None, mask=None, **kwargs):
|
||||
with torch.backends.cuda.sdp_kernel(enable_flash=True, enable_math=True, enable_mem_efficient=False):
|
||||
return scaled_dot_product_attention_forward(self, x, context, mask)
|
||||
|
||||
|
||||
def cross_attention_attnblock_forward(self, x):
|
||||
h_ = x
|
||||
h_ = self.norm(h_)
|
||||
@@ -569,6 +600,7 @@ def cross_attention_attnblock_forward(self, x):
|
||||
|
||||
return h3
|
||||
|
||||
|
||||
def xformers_attnblock_forward(self, x):
|
||||
try:
|
||||
h_ = x
|
||||
@@ -592,6 +624,7 @@ def xformers_attnblock_forward(self, x):
|
||||
except NotImplementedError:
|
||||
return cross_attention_attnblock_forward(self, x)
|
||||
|
||||
|
||||
def sdp_attnblock_forward(self, x):
|
||||
h_ = x
|
||||
h_ = self.norm(h_)
|
||||
@@ -612,10 +645,12 @@ def sdp_attnblock_forward(self, x):
|
||||
out = self.proj_out(out)
|
||||
return x + out
|
||||
|
||||
|
||||
def sdp_no_mem_attnblock_forward(self, x):
|
||||
with torch.backends.cuda.sdp_kernel(enable_flash=True, enable_math=True, enable_mem_efficient=False):
|
||||
return sdp_attnblock_forward(self, x)
|
||||
|
||||
|
||||
def sub_quad_attnblock_forward(self, x):
|
||||
h_ = x
|
||||
h_ = self.norm(h_)
|
||||
|
||||
@@ -39,7 +39,10 @@ def apply_model(orig_func, self, x_noisy, t, cond, **kwargs):
|
||||
|
||||
if isinstance(cond, dict):
|
||||
for y in cond.keys():
|
||||
cond[y] = [x.to(devices.dtype_unet) if isinstance(x, torch.Tensor) else x for x in cond[y]]
|
||||
if isinstance(cond[y], list):
|
||||
cond[y] = [x.to(devices.dtype_unet) if isinstance(x, torch.Tensor) else x for x in cond[y]]
|
||||
else:
|
||||
cond[y] = cond[y].to(devices.dtype_unet) if isinstance(cond[y], torch.Tensor) else cond[y]
|
||||
|
||||
with devices.autocast():
|
||||
return orig_func(self, x_noisy.to(devices.dtype_unet), t.to(devices.dtype_unet), cond, **kwargs).float()
|
||||
@@ -77,3 +80,6 @@ first_stage_sub = lambda orig_func, self, x, **kwargs: orig_func(self, x.to(devi
|
||||
CondFunc('ldm.models.diffusion.ddpm.LatentDiffusion.decode_first_stage', first_stage_sub, first_stage_cond)
|
||||
CondFunc('ldm.models.diffusion.ddpm.LatentDiffusion.encode_first_stage', first_stage_sub, first_stage_cond)
|
||||
CondFunc('ldm.models.diffusion.ddpm.LatentDiffusion.get_first_stage_encoding', lambda orig_func, *args, **kwargs: orig_func(*args, **kwargs).float(), first_stage_cond)
|
||||
|
||||
CondFunc('sgm.modules.diffusionmodules.wrappers.OpenAIWrapper.forward', apply_model, unet_needs_upcast)
|
||||
CondFunc('sgm.modules.diffusionmodules.openaimodel.timestep_embedding', lambda orig_func, timesteps, *args, **kwargs: orig_func(timesteps, *args, **kwargs).to(torch.float32 if timesteps.dtype == torch.int64 else devices.dtype_unet), unet_needs_upcast)
|
||||
|
||||
@@ -14,7 +14,7 @@ import ldm.modules.midas as midas
|
||||
|
||||
from ldm.util import instantiate_from_config
|
||||
|
||||
from modules import paths, shared, modelloader, devices, script_callbacks, sd_vae, sd_disable_initialization, errors, hashes, sd_models_config, sd_unet
|
||||
from modules import paths, shared, modelloader, devices, script_callbacks, sd_vae, sd_disable_initialization, errors, hashes, sd_models_config, sd_unet, sd_models_xl
|
||||
from modules.sd_hijack_inpainting import do_inpainting_hijack
|
||||
from modules.timer import Timer
|
||||
import tomesd
|
||||
@@ -23,7 +23,8 @@ model_dir = "Stable-diffusion"
|
||||
model_path = os.path.abspath(os.path.join(paths.models_path, model_dir))
|
||||
|
||||
checkpoints_list = {}
|
||||
checkpoint_alisases = {}
|
||||
checkpoint_aliases = {}
|
||||
checkpoint_alisases = checkpoint_aliases # for compatibility with old name
|
||||
checkpoints_loaded = collections.OrderedDict()
|
||||
|
||||
|
||||
@@ -66,7 +67,7 @@ class CheckpointInfo:
|
||||
def register(self):
|
||||
checkpoints_list[self.title] = self
|
||||
for id in self.ids:
|
||||
checkpoint_alisases[id] = self
|
||||
checkpoint_aliases[id] = self
|
||||
|
||||
def calculate_shorthash(self):
|
||||
self.sha256 = hashes.sha256(self.filename, f"checkpoint/{self.name}")
|
||||
@@ -112,7 +113,7 @@ def checkpoint_tiles():
|
||||
|
||||
def list_models():
|
||||
checkpoints_list.clear()
|
||||
checkpoint_alisases.clear()
|
||||
checkpoint_aliases.clear()
|
||||
|
||||
cmd_ckpt = shared.cmd_opts.ckpt
|
||||
if shared.cmd_opts.no_download_sd_model or cmd_ckpt != shared.sd_model_file or os.path.exists(cmd_ckpt):
|
||||
@@ -136,7 +137,7 @@ def list_models():
|
||||
|
||||
|
||||
def get_closet_checkpoint_match(search_string):
|
||||
checkpoint_info = checkpoint_alisases.get(search_string, None)
|
||||
checkpoint_info = checkpoint_aliases.get(search_string, None)
|
||||
if checkpoint_info is not None:
|
||||
return checkpoint_info
|
||||
|
||||
@@ -166,7 +167,7 @@ def select_checkpoint():
|
||||
"""Raises `FileNotFoundError` if no checkpoints are found."""
|
||||
model_checkpoint = shared.opts.sd_model_checkpoint
|
||||
|
||||
checkpoint_info = checkpoint_alisases.get(model_checkpoint, None)
|
||||
checkpoint_info = checkpoint_aliases.get(model_checkpoint, None)
|
||||
if checkpoint_info is not None:
|
||||
return checkpoint_info
|
||||
|
||||
@@ -247,7 +248,12 @@ def read_state_dict(checkpoint_file, print_global_state=False, map_location=None
|
||||
_, extension = os.path.splitext(checkpoint_file)
|
||||
if extension.lower() == ".safetensors":
|
||||
device = map_location or shared.weight_load_location or devices.get_optimal_device_name()
|
||||
pl_sd = safetensors.torch.load_file(checkpoint_file, device=device)
|
||||
|
||||
if not shared.opts.disable_mmap_load_safetensors:
|
||||
pl_sd = safetensors.torch.load_file(checkpoint_file, device=device)
|
||||
else:
|
||||
pl_sd = safetensors.torch.load(open(checkpoint_file, 'rb').read())
|
||||
pl_sd = {k: v.to(device) for k, v in pl_sd.items()}
|
||||
else:
|
||||
pl_sd = torch.load(checkpoint_file, map_location=map_location or shared.weight_load_location)
|
||||
|
||||
@@ -283,6 +289,13 @@ def load_model_weights(model, checkpoint_info: CheckpointInfo, state_dict, timer
|
||||
if state_dict is None:
|
||||
state_dict = get_checkpoint_state_dict(checkpoint_info, timer)
|
||||
|
||||
model.is_sdxl = hasattr(model, 'conditioner')
|
||||
model.is_sd2 = not model.is_sdxl and hasattr(model.cond_stage_model, 'model')
|
||||
model.is_sd1 = not model.is_sdxl and not model.is_sd2
|
||||
|
||||
if model.is_sdxl:
|
||||
sd_models_xl.extend_sdxl(model)
|
||||
|
||||
model.load_state_dict(state_dict, strict=False)
|
||||
del state_dict
|
||||
timer.record("apply weights to model")
|
||||
@@ -313,7 +326,7 @@ def load_model_weights(model, checkpoint_info: CheckpointInfo, state_dict, timer
|
||||
|
||||
timer.record("apply half()")
|
||||
|
||||
devices.dtype_unet = model.model.diffusion_model.dtype
|
||||
devices.dtype_unet = torch.float16 if model.is_sdxl and not shared.cmd_opts.no_half else model.model.diffusion_model.dtype
|
||||
devices.unet_needs_upcast = shared.cmd_opts.upcast_sampling and devices.dtype == torch.float16 and devices.dtype_unet == torch.float16
|
||||
|
||||
model.first_stage_model.to(devices.dtype_vae)
|
||||
@@ -328,7 +341,8 @@ def load_model_weights(model, checkpoint_info: CheckpointInfo, state_dict, timer
|
||||
model.sd_checkpoint_info = checkpoint_info
|
||||
shared.opts.data["sd_checkpoint_hash"] = checkpoint_info.sha256
|
||||
|
||||
model.logvar = model.logvar.to(devices.device) # fix for training
|
||||
if hasattr(model, 'logvar'):
|
||||
model.logvar = model.logvar.to(devices.device) # fix for training
|
||||
|
||||
sd_vae.delete_base_vae()
|
||||
sd_vae.clear_loaded_vae()
|
||||
@@ -385,10 +399,11 @@ def repair_config(sd_config):
|
||||
if not hasattr(sd_config.model.params, "use_ema"):
|
||||
sd_config.model.params.use_ema = False
|
||||
|
||||
if shared.cmd_opts.no_half:
|
||||
sd_config.model.params.unet_config.params.use_fp16 = False
|
||||
elif shared.cmd_opts.upcast_sampling:
|
||||
sd_config.model.params.unet_config.params.use_fp16 = True
|
||||
if hasattr(sd_config.model.params, 'unet_config'):
|
||||
if shared.cmd_opts.no_half:
|
||||
sd_config.model.params.unet_config.params.use_fp16 = False
|
||||
elif shared.cmd_opts.upcast_sampling:
|
||||
sd_config.model.params.unet_config.params.use_fp16 = True
|
||||
|
||||
if getattr(sd_config.model.params.first_stage_config.params.ddconfig, "attn_type", None) == "vanilla-xformers" and not shared.xformers_available:
|
||||
sd_config.model.params.first_stage_config.params.ddconfig.attn_type = "vanilla"
|
||||
@@ -401,6 +416,8 @@ def repair_config(sd_config):
|
||||
|
||||
sd1_clip_weight = 'cond_stage_model.transformer.text_model.embeddings.token_embedding.weight'
|
||||
sd2_clip_weight = 'cond_stage_model.model.transformer.resblocks.0.attn.in_proj_weight'
|
||||
sdxl_clip_weight = 'conditioner.embedders.1.model.ln_final.weight'
|
||||
sdxl_refiner_clip_weight = 'conditioner.embedders.0.model.ln_final.weight'
|
||||
|
||||
|
||||
class SdModelData:
|
||||
@@ -435,6 +452,15 @@ class SdModelData:
|
||||
model_data = SdModelData()
|
||||
|
||||
|
||||
def get_empty_cond(sd_model):
|
||||
if hasattr(sd_model, 'conditioner'):
|
||||
d = sd_model.get_learned_conditioning([""])
|
||||
return d['crossattn']
|
||||
else:
|
||||
return sd_model.cond_stage_model([""])
|
||||
|
||||
|
||||
|
||||
def load_model(checkpoint_info=None, already_loaded_state_dict=None):
|
||||
from modules import lowvram, sd_hijack
|
||||
checkpoint_info = checkpoint_info or select_checkpoint()
|
||||
@@ -455,7 +481,7 @@ def load_model(checkpoint_info=None, already_loaded_state_dict=None):
|
||||
state_dict = get_checkpoint_state_dict(checkpoint_info, timer)
|
||||
|
||||
checkpoint_config = sd_models_config.find_checkpoint_config(state_dict, checkpoint_info)
|
||||
clip_is_included_into_sd = sd1_clip_weight in state_dict or sd2_clip_weight in state_dict
|
||||
clip_is_included_into_sd = any(x for x in [sd1_clip_weight, sd2_clip_weight, sdxl_clip_weight, sdxl_refiner_clip_weight] if x in state_dict)
|
||||
|
||||
timer.record("find config")
|
||||
|
||||
@@ -507,7 +533,7 @@ def load_model(checkpoint_info=None, already_loaded_state_dict=None):
|
||||
timer.record("scripts callbacks")
|
||||
|
||||
with devices.autocast(), torch.no_grad():
|
||||
sd_model.cond_stage_model_empty_prompt = sd_model.cond_stage_model([""])
|
||||
sd_model.cond_stage_model_empty_prompt = get_empty_cond(sd_model)
|
||||
|
||||
timer.record("calculate empty prompt")
|
||||
|
||||
@@ -585,7 +611,6 @@ def unload_model_weights(sd_model=None, info=None):
|
||||
sd_model = None
|
||||
gc.collect()
|
||||
devices.torch_gc()
|
||||
torch.cuda.empty_cache()
|
||||
|
||||
print(f"Unloaded weights {timer.summary()}.")
|
||||
|
||||
|
||||
@@ -6,12 +6,15 @@ from modules import shared, paths, sd_disable_initialization
|
||||
|
||||
sd_configs_path = shared.sd_configs_path
|
||||
sd_repo_configs_path = os.path.join(paths.paths['Stable Diffusion'], "configs", "stable-diffusion")
|
||||
sd_xl_repo_configs_path = os.path.join(paths.paths['Stable Diffusion XL'], "configs", "inference")
|
||||
|
||||
|
||||
config_default = shared.sd_default_config
|
||||
config_sd2 = os.path.join(sd_repo_configs_path, "v2-inference.yaml")
|
||||
config_sd2v = os.path.join(sd_repo_configs_path, "v2-inference-v.yaml")
|
||||
config_sd2_inpainting = os.path.join(sd_repo_configs_path, "v2-inpainting-inference.yaml")
|
||||
config_sdxl = os.path.join(sd_xl_repo_configs_path, "sd_xl_base.yaml")
|
||||
config_sdxl_refiner = os.path.join(sd_xl_repo_configs_path, "sd_xl_refiner.yaml")
|
||||
config_depth_model = os.path.join(sd_repo_configs_path, "v2-midas-inference.yaml")
|
||||
config_unclip = os.path.join(sd_repo_configs_path, "v2-1-stable-unclip-l-inference.yaml")
|
||||
config_unopenclip = os.path.join(sd_repo_configs_path, "v2-1-stable-unclip-h-inference.yaml")
|
||||
@@ -68,7 +71,11 @@ def guess_model_config_from_state_dict(sd, filename):
|
||||
diffusion_model_input = sd.get('model.diffusion_model.input_blocks.0.0.weight', None)
|
||||
sd2_variations_weight = sd.get('embedder.model.ln_final.weight', None)
|
||||
|
||||
if sd.get('depth_model.model.pretrained.act_postprocess3.0.project.0.bias', None) is not None:
|
||||
if sd.get('conditioner.embedders.1.model.ln_final.weight', None) is not None:
|
||||
return config_sdxl
|
||||
if sd.get('conditioner.embedders.0.model.ln_final.weight', None) is not None:
|
||||
return config_sdxl_refiner
|
||||
elif sd.get('depth_model.model.pretrained.act_postprocess3.0.project.0.bias', None) is not None:
|
||||
return config_depth_model
|
||||
elif sd2_variations_weight is not None and sd2_variations_weight.shape[0] == 768:
|
||||
return config_unclip
|
||||
|
||||
99
modules/sd_models_xl.py
Normal file
99
modules/sd_models_xl.py
Normal file
@@ -0,0 +1,99 @@
|
||||
from __future__ import annotations
|
||||
|
||||
import torch
|
||||
|
||||
import sgm.models.diffusion
|
||||
import sgm.modules.diffusionmodules.denoiser_scaling
|
||||
import sgm.modules.diffusionmodules.discretizer
|
||||
from modules import devices, shared, prompt_parser
|
||||
|
||||
|
||||
def get_learned_conditioning(self: sgm.models.diffusion.DiffusionEngine, batch: prompt_parser.SdConditioning | list[str]):
|
||||
for embedder in self.conditioner.embedders:
|
||||
embedder.ucg_rate = 0.0
|
||||
|
||||
width = getattr(self, 'target_width', 1024)
|
||||
height = getattr(self, 'target_height', 1024)
|
||||
is_negative_prompt = getattr(batch, 'is_negative_prompt', False)
|
||||
aesthetic_score = shared.opts.sdxl_refiner_low_aesthetic_score if is_negative_prompt else shared.opts.sdxl_refiner_high_aesthetic_score
|
||||
|
||||
devices_args = dict(device=devices.device, dtype=devices.dtype)
|
||||
|
||||
sdxl_conds = {
|
||||
"txt": batch,
|
||||
"original_size_as_tuple": torch.tensor([height, width], **devices_args).repeat(len(batch), 1),
|
||||
"crop_coords_top_left": torch.tensor([shared.opts.sdxl_crop_top, shared.opts.sdxl_crop_left], **devices_args).repeat(len(batch), 1),
|
||||
"target_size_as_tuple": torch.tensor([height, width], **devices_args).repeat(len(batch), 1),
|
||||
"aesthetic_score": torch.tensor([aesthetic_score], **devices_args).repeat(len(batch), 1),
|
||||
}
|
||||
|
||||
force_zero_negative_prompt = is_negative_prompt and all(x == '' for x in batch)
|
||||
c = self.conditioner(sdxl_conds, force_zero_embeddings=['txt'] if force_zero_negative_prompt else [])
|
||||
|
||||
return c
|
||||
|
||||
|
||||
def apply_model(self: sgm.models.diffusion.DiffusionEngine, x, t, cond):
|
||||
return self.model(x, t, cond)
|
||||
|
||||
|
||||
def get_first_stage_encoding(self, x): # SDXL's encode_first_stage does everything so get_first_stage_encoding is just there for compatibility
|
||||
return x
|
||||
|
||||
|
||||
sgm.models.diffusion.DiffusionEngine.get_learned_conditioning = get_learned_conditioning
|
||||
sgm.models.diffusion.DiffusionEngine.apply_model = apply_model
|
||||
sgm.models.diffusion.DiffusionEngine.get_first_stage_encoding = get_first_stage_encoding
|
||||
|
||||
|
||||
def encode_embedding_init_text(self: sgm.modules.GeneralConditioner, init_text, nvpt):
|
||||
res = []
|
||||
|
||||
for embedder in [embedder for embedder in self.embedders if hasattr(embedder, 'encode_embedding_init_text')]:
|
||||
encoded = embedder.encode_embedding_init_text(init_text, nvpt)
|
||||
res.append(encoded)
|
||||
|
||||
return torch.cat(res, dim=1)
|
||||
|
||||
|
||||
def process_texts(self, texts):
|
||||
for embedder in [embedder for embedder in self.embedders if hasattr(embedder, 'process_texts')]:
|
||||
return embedder.process_texts(texts)
|
||||
|
||||
|
||||
def get_target_prompt_token_count(self, token_count):
|
||||
for embedder in [embedder for embedder in self.embedders if hasattr(embedder, 'get_target_prompt_token_count')]:
|
||||
return embedder.get_target_prompt_token_count(token_count)
|
||||
|
||||
|
||||
# those additions to GeneralConditioner make it possible to use it as model.cond_stage_model from SD1.5 in exist
|
||||
sgm.modules.GeneralConditioner.encode_embedding_init_text = encode_embedding_init_text
|
||||
sgm.modules.GeneralConditioner.process_texts = process_texts
|
||||
sgm.modules.GeneralConditioner.get_target_prompt_token_count = get_target_prompt_token_count
|
||||
|
||||
|
||||
def extend_sdxl(model):
|
||||
"""this adds a bunch of parameters to make SDXL model look a bit more like SD1.5 to the rest of the codebase."""
|
||||
|
||||
dtype = next(model.model.diffusion_model.parameters()).dtype
|
||||
model.model.diffusion_model.dtype = dtype
|
||||
model.model.conditioning_key = 'crossattn'
|
||||
model.cond_stage_key = 'txt'
|
||||
# model.cond_stage_model will be set in sd_hijack
|
||||
|
||||
model.parameterization = "v" if isinstance(model.denoiser.scaling, sgm.modules.diffusionmodules.denoiser_scaling.VScaling) else "eps"
|
||||
|
||||
discretization = sgm.modules.diffusionmodules.discretizer.LegacyDDPMDiscretization()
|
||||
model.alphas_cumprod = torch.asarray(discretization.alphas_cumprod, device=devices.device, dtype=dtype)
|
||||
|
||||
model.conditioner.wrapped = torch.nn.Module()
|
||||
|
||||
|
||||
sgm.modules.attention.print = lambda *args: None
|
||||
sgm.modules.diffusionmodules.model.print = lambda *args: None
|
||||
sgm.modules.diffusionmodules.openaimodel.print = lambda *args: None
|
||||
sgm.modules.encoders.modules.print = lambda *args: None
|
||||
|
||||
# this gets the code to load the vanilla attention that we override
|
||||
sgm.modules.attention.SDP_IS_AVAILABLE = True
|
||||
sgm.modules.attention.XFORMERS_IS_AVAILABLE = False
|
||||
@@ -28,6 +28,9 @@ def create_sampler(name, model):
|
||||
|
||||
assert config is not None, f'bad sampler name: {name}'
|
||||
|
||||
if model.is_sdxl and config.options.get("no_sdxl", False):
|
||||
raise Exception(f"Sampler {config.name} is not supported for SDXL")
|
||||
|
||||
sampler = config.constructor(model)
|
||||
sampler.config = config
|
||||
|
||||
|
||||
@@ -11,9 +11,9 @@ import modules.models.diffusion.uni_pc
|
||||
|
||||
|
||||
samplers_data_compvis = [
|
||||
sd_samplers_common.SamplerData('DDIM', lambda model: VanillaStableDiffusionSampler(ldm.models.diffusion.ddim.DDIMSampler, model), [], {"default_eta_is_0": True, "uses_ensd": True}),
|
||||
sd_samplers_common.SamplerData('PLMS', lambda model: VanillaStableDiffusionSampler(ldm.models.diffusion.plms.PLMSSampler, model), [], {}),
|
||||
sd_samplers_common.SamplerData('UniPC', lambda model: VanillaStableDiffusionSampler(modules.models.diffusion.uni_pc.UniPCSampler, model), [], {}),
|
||||
sd_samplers_common.SamplerData('DDIM', lambda model: VanillaStableDiffusionSampler(ldm.models.diffusion.ddim.DDIMSampler, model), [], {"default_eta_is_0": True, "uses_ensd": True, "no_sdxl": True}),
|
||||
sd_samplers_common.SamplerData('PLMS', lambda model: VanillaStableDiffusionSampler(ldm.models.diffusion.plms.PLMSSampler, model), [], {"no_sdxl": True}),
|
||||
sd_samplers_common.SamplerData('UniPC', lambda model: VanillaStableDiffusionSampler(modules.models.diffusion.uni_pc.UniPCSampler, model), [], {"no_sdxl": True}),
|
||||
]
|
||||
|
||||
|
||||
|
||||
@@ -53,6 +53,28 @@ k_diffusion_scheduler = {
|
||||
}
|
||||
|
||||
|
||||
def catenate_conds(conds):
|
||||
if not isinstance(conds[0], dict):
|
||||
return torch.cat(conds)
|
||||
|
||||
return {key: torch.cat([x[key] for x in conds]) for key in conds[0].keys()}
|
||||
|
||||
|
||||
def subscript_cond(cond, a, b):
|
||||
if not isinstance(cond, dict):
|
||||
return cond[a:b]
|
||||
|
||||
return {key: vec[a:b] for key, vec in cond.items()}
|
||||
|
||||
|
||||
def pad_cond(tensor, repeats, empty):
|
||||
if not isinstance(tensor, dict):
|
||||
return torch.cat([tensor, empty.repeat((tensor.shape[0], repeats, 1))], axis=1)
|
||||
|
||||
tensor['crossattn'] = pad_cond(tensor['crossattn'], repeats, empty)
|
||||
return tensor
|
||||
|
||||
|
||||
class CFGDenoiser(torch.nn.Module):
|
||||
"""
|
||||
Classifier free guidance denoiser. A wrapper for stable diffusion model (specifically for unet)
|
||||
@@ -105,10 +127,13 @@ class CFGDenoiser(torch.nn.Module):
|
||||
|
||||
if shared.sd_model.model.conditioning_key == "crossattn-adm":
|
||||
image_uncond = torch.zeros_like(image_cond)
|
||||
make_condition_dict = lambda c_crossattn, c_adm: {"c_crossattn": c_crossattn, "c_adm": c_adm}
|
||||
make_condition_dict = lambda c_crossattn, c_adm: {"c_crossattn": [c_crossattn], "c_adm": c_adm}
|
||||
else:
|
||||
image_uncond = image_cond
|
||||
make_condition_dict = lambda c_crossattn, c_concat: {"c_crossattn": c_crossattn, "c_concat": [c_concat]}
|
||||
if isinstance(uncond, dict):
|
||||
make_condition_dict = lambda c_crossattn, c_concat: {**c_crossattn, "c_concat": [c_concat]}
|
||||
else:
|
||||
make_condition_dict = lambda c_crossattn, c_concat: {"c_crossattn": [c_crossattn], "c_concat": [c_concat]}
|
||||
|
||||
if not is_edit_model:
|
||||
x_in = torch.cat([torch.stack([x[i] for _ in range(n)]) for i, n in enumerate(repeats)] + [x])
|
||||
@@ -140,28 +165,28 @@ class CFGDenoiser(torch.nn.Module):
|
||||
num_repeats = (tensor.shape[1] - uncond.shape[1]) // empty.shape[1]
|
||||
|
||||
if num_repeats < 0:
|
||||
tensor = torch.cat([tensor, empty.repeat((tensor.shape[0], -num_repeats, 1))], axis=1)
|
||||
tensor = pad_cond(tensor, -num_repeats, empty)
|
||||
self.padded_cond_uncond = True
|
||||
elif num_repeats > 0:
|
||||
uncond = torch.cat([uncond, empty.repeat((uncond.shape[0], num_repeats, 1))], axis=1)
|
||||
uncond = pad_cond(uncond, num_repeats, empty)
|
||||
self.padded_cond_uncond = True
|
||||
|
||||
if tensor.shape[1] == uncond.shape[1] or skip_uncond:
|
||||
if is_edit_model:
|
||||
cond_in = torch.cat([tensor, uncond, uncond])
|
||||
cond_in = catenate_conds([tensor, uncond, uncond])
|
||||
elif skip_uncond:
|
||||
cond_in = tensor
|
||||
else:
|
||||
cond_in = torch.cat([tensor, uncond])
|
||||
cond_in = catenate_conds([tensor, uncond])
|
||||
|
||||
if shared.batch_cond_uncond:
|
||||
x_out = self.inner_model(x_in, sigma_in, cond=make_condition_dict([cond_in], image_cond_in))
|
||||
x_out = self.inner_model(x_in, sigma_in, cond=make_condition_dict(cond_in, image_cond_in))
|
||||
else:
|
||||
x_out = torch.zeros_like(x_in)
|
||||
for batch_offset in range(0, x_out.shape[0], batch_size):
|
||||
a = batch_offset
|
||||
b = a + batch_size
|
||||
x_out[a:b] = self.inner_model(x_in[a:b], sigma_in[a:b], cond=make_condition_dict([cond_in[a:b]], image_cond_in[a:b]))
|
||||
x_out[a:b] = self.inner_model(x_in[a:b], sigma_in[a:b], cond=make_condition_dict(subscript_cond(cond_in, a, b), image_cond_in[a:b]))
|
||||
else:
|
||||
x_out = torch.zeros_like(x_in)
|
||||
batch_size = batch_size*2 if shared.batch_cond_uncond else batch_size
|
||||
@@ -170,14 +195,14 @@ class CFGDenoiser(torch.nn.Module):
|
||||
b = min(a + batch_size, tensor.shape[0])
|
||||
|
||||
if not is_edit_model:
|
||||
c_crossattn = [tensor[a:b]]
|
||||
c_crossattn = subscript_cond(tensor, a, b)
|
||||
else:
|
||||
c_crossattn = torch.cat([tensor[a:b]], uncond)
|
||||
|
||||
x_out[a:b] = self.inner_model(x_in[a:b], sigma_in[a:b], cond=make_condition_dict(c_crossattn, image_cond_in[a:b]))
|
||||
|
||||
if not skip_uncond:
|
||||
x_out[-uncond.shape[0]:] = self.inner_model(x_in[-uncond.shape[0]:], sigma_in[-uncond.shape[0]:], cond=make_condition_dict([uncond], image_cond_in[-uncond.shape[0]:]))
|
||||
x_out[-uncond.shape[0]:] = self.inner_model(x_in[-uncond.shape[0]:], sigma_in[-uncond.shape[0]:], cond=make_condition_dict(uncond, image_cond_in[-uncond.shape[0]:]))
|
||||
|
||||
denoised_image_indexes = [x[0][0] for x in conds_list]
|
||||
if skip_uncond:
|
||||
|
||||
@@ -2,9 +2,9 @@ import os
|
||||
|
||||
import torch
|
||||
from torch import nn
|
||||
from modules import devices, paths
|
||||
from modules import devices, paths, shared
|
||||
|
||||
sd_vae_approx_model = None
|
||||
sd_vae_approx_models = {}
|
||||
|
||||
|
||||
class VAEApprox(nn.Module):
|
||||
@@ -31,30 +31,55 @@ class VAEApprox(nn.Module):
|
||||
return x
|
||||
|
||||
|
||||
def download_model(model_path, model_url):
|
||||
if not os.path.exists(model_path):
|
||||
os.makedirs(os.path.dirname(model_path), exist_ok=True)
|
||||
|
||||
print(f'Downloading VAEApprox model to: {model_path}')
|
||||
torch.hub.download_url_to_file(model_url, model_path)
|
||||
|
||||
|
||||
def model():
|
||||
global sd_vae_approx_model
|
||||
model_name = "vaeapprox-sdxl.pt" if getattr(shared.sd_model, 'is_sdxl', False) else "model.pt"
|
||||
loaded_model = sd_vae_approx_models.get(model_name)
|
||||
|
||||
if sd_vae_approx_model is None:
|
||||
model_path = os.path.join(paths.models_path, "VAE-approx", "model.pt")
|
||||
sd_vae_approx_model = VAEApprox()
|
||||
if loaded_model is None:
|
||||
model_path = os.path.join(paths.models_path, "VAE-approx", model_name)
|
||||
if not os.path.exists(model_path):
|
||||
model_path = os.path.join(paths.script_path, "models", "VAE-approx", "model.pt")
|
||||
sd_vae_approx_model.load_state_dict(torch.load(model_path, map_location='cpu' if devices.device.type != 'cuda' else None))
|
||||
sd_vae_approx_model.eval()
|
||||
sd_vae_approx_model.to(devices.device, devices.dtype)
|
||||
model_path = os.path.join(paths.script_path, "models", "VAE-approx", model_name)
|
||||
|
||||
return sd_vae_approx_model
|
||||
if not os.path.exists(model_path):
|
||||
model_path = os.path.join(paths.models_path, "VAE-approx", model_name)
|
||||
download_model(model_path, 'https://github.com/AUTOMATIC1111/stable-diffusion-webui/releases/download/v1.0.0-pre/' + model_name)
|
||||
|
||||
loaded_model = VAEApprox()
|
||||
loaded_model.load_state_dict(torch.load(model_path, map_location='cpu' if devices.device.type != 'cuda' else None))
|
||||
loaded_model.eval()
|
||||
loaded_model.to(devices.device, devices.dtype)
|
||||
sd_vae_approx_models[model_name] = loaded_model
|
||||
|
||||
return loaded_model
|
||||
|
||||
|
||||
def cheap_approximation(sample):
|
||||
# https://discuss.huggingface.co/t/decoding-latents-to-rgb-without-upscaling/23204/2
|
||||
|
||||
coefs = torch.tensor([
|
||||
[0.298, 0.207, 0.208],
|
||||
[0.187, 0.286, 0.173],
|
||||
[-0.158, 0.189, 0.264],
|
||||
[-0.184, -0.271, -0.473],
|
||||
]).to(sample.device)
|
||||
if shared.sd_model.is_sdxl:
|
||||
coeffs = [
|
||||
[ 0.3448, 0.4168, 0.4395],
|
||||
[-0.1953, -0.0290, 0.0250],
|
||||
[ 0.1074, 0.0886, -0.0163],
|
||||
[-0.3730, -0.2499, -0.2088],
|
||||
]
|
||||
else:
|
||||
coeffs = [
|
||||
[ 0.298, 0.207, 0.208],
|
||||
[ 0.187, 0.286, 0.173],
|
||||
[-0.158, 0.189, 0.264],
|
||||
[-0.184, -0.271, -0.473],
|
||||
]
|
||||
|
||||
coefs = torch.tensor(coeffs).to(sample.device)
|
||||
|
||||
x_sample = torch.einsum("lxy,lr -> rxy", sample, coefs)
|
||||
|
||||
|
||||
@@ -8,9 +8,9 @@ import os
|
||||
import torch
|
||||
import torch.nn as nn
|
||||
|
||||
from modules import devices, paths_internal
|
||||
from modules import devices, paths_internal, shared
|
||||
|
||||
sd_vae_taesd = None
|
||||
sd_vae_taesd_models = {}
|
||||
|
||||
|
||||
def conv(n_in, n_out, **kwargs):
|
||||
@@ -61,9 +61,7 @@ class TAESD(nn.Module):
|
||||
return x.sub(TAESD.latent_shift).mul(2 * TAESD.latent_magnitude)
|
||||
|
||||
|
||||
def download_model(model_path):
|
||||
model_url = 'https://github.com/madebyollin/taesd/raw/main/taesd_decoder.pth'
|
||||
|
||||
def download_model(model_path, model_url):
|
||||
if not os.path.exists(model_path):
|
||||
os.makedirs(os.path.dirname(model_path), exist_ok=True)
|
||||
|
||||
@@ -72,17 +70,19 @@ def download_model(model_path):
|
||||
|
||||
|
||||
def model():
|
||||
global sd_vae_taesd
|
||||
model_name = "taesdxl_decoder.pth" if getattr(shared.sd_model, 'is_sdxl', False) else "taesd_decoder.pth"
|
||||
loaded_model = sd_vae_taesd_models.get(model_name)
|
||||
|
||||
if sd_vae_taesd is None:
|
||||
model_path = os.path.join(paths_internal.models_path, "VAE-taesd", "taesd_decoder.pth")
|
||||
download_model(model_path)
|
||||
if loaded_model is None:
|
||||
model_path = os.path.join(paths_internal.models_path, "VAE-taesd", model_name)
|
||||
download_model(model_path, 'https://github.com/madebyollin/taesd/raw/main/' + model_name)
|
||||
|
||||
if os.path.exists(model_path):
|
||||
sd_vae_taesd = TAESD(model_path)
|
||||
sd_vae_taesd.eval()
|
||||
sd_vae_taesd.to(devices.device, devices.dtype)
|
||||
loaded_model = TAESD(model_path)
|
||||
loaded_model.eval()
|
||||
loaded_model.to(devices.device, devices.dtype)
|
||||
sd_vae_taesd_models[model_name] = loaded_model
|
||||
else:
|
||||
raise FileNotFoundError('TAESD model not found')
|
||||
|
||||
return sd_vae_taesd.decoder
|
||||
return loaded_model.decoder
|
||||
|
||||
@@ -1,9 +1,11 @@
|
||||
import datetime
|
||||
import json
|
||||
import os
|
||||
import re
|
||||
import sys
|
||||
import threading
|
||||
import time
|
||||
import logging
|
||||
|
||||
import gradio as gr
|
||||
import torch
|
||||
@@ -18,6 +20,8 @@ from modules.paths_internal import models_path, script_path, data_path, sd_confi
|
||||
from ldm.models.diffusion.ddpm import LatentDiffusion
|
||||
from typing import Optional
|
||||
|
||||
log = logging.getLogger(__name__)
|
||||
|
||||
demo = None
|
||||
|
||||
parser = cmd_args.parser
|
||||
@@ -144,12 +148,15 @@ class State:
|
||||
def request_restart(self) -> None:
|
||||
self.interrupt()
|
||||
self.server_command = "restart"
|
||||
log.info("Received restart request")
|
||||
|
||||
def skip(self):
|
||||
self.skipped = True
|
||||
log.info("Received skip request")
|
||||
|
||||
def interrupt(self):
|
||||
self.interrupted = True
|
||||
log.info("Received interrupt request")
|
||||
|
||||
def nextjob(self):
|
||||
if opts.live_previews_enable and opts.show_progress_every_n_steps == -1:
|
||||
@@ -173,7 +180,7 @@ class State:
|
||||
|
||||
return obj
|
||||
|
||||
def begin(self):
|
||||
def begin(self, job: str = "(unknown)"):
|
||||
self.sampling_step = 0
|
||||
self.job_count = -1
|
||||
self.processing_has_refined_job_count = False
|
||||
@@ -187,10 +194,13 @@ class State:
|
||||
self.interrupted = False
|
||||
self.textinfo = None
|
||||
self.time_start = time.time()
|
||||
|
||||
self.job = job
|
||||
devices.torch_gc()
|
||||
log.info("Starting job %s", job)
|
||||
|
||||
def end(self):
|
||||
duration = time.time() - self.time_start
|
||||
log.info("Ending job %s (%.2f seconds)", self.job, duration)
|
||||
self.job = ""
|
||||
self.job_count = 0
|
||||
|
||||
@@ -311,6 +321,10 @@ options_templates.update(options_section(('saving-images', "Saving images/grids"
|
||||
"grid_prevent_empty_spots": OptionInfo(False, "Prevent empty spots in grid (when set to autodetect)"),
|
||||
"grid_zip_filename_pattern": OptionInfo("", "Archive filename pattern", component_args=hide_dirs).link("wiki", "https://github.com/AUTOMATIC1111/stable-diffusion-webui/wiki/Custom-Images-Filename-Name-and-Subdirectory"),
|
||||
"n_rows": OptionInfo(-1, "Grid row count; use -1 for autodetect and 0 for it to be same as batch size", gr.Slider, {"minimum": -1, "maximum": 16, "step": 1}),
|
||||
"font": OptionInfo("", "Font for image grids that have text"),
|
||||
"grid_text_active_color": OptionInfo("#000000", "Text color for image grids", ui_components.FormColorPicker, {}),
|
||||
"grid_text_inactive_color": OptionInfo("#999999", "Inactive text color for image grids", ui_components.FormColorPicker, {}),
|
||||
"grid_background_color": OptionInfo("#ffffff", "Background color for image grids", ui_components.FormColorPicker, {}),
|
||||
|
||||
"enable_pnginfo": OptionInfo(True, "Save text information about generation parameters as chunks to png files"),
|
||||
"save_txt": OptionInfo(False, "Create a text file next to every image with generation parameters."),
|
||||
@@ -376,6 +390,7 @@ options_templates.update(options_section(('system', "System"), {
|
||||
"multiple_tqdm": OptionInfo(True, "Add a second progress bar to the console that shows progress for an entire job."),
|
||||
"print_hypernet_extra": OptionInfo(False, "Print extra hypernetwork information to console."),
|
||||
"list_hidden_files": OptionInfo(True, "Load models/files in hidden directories").info("directory is hidden if its name starts with \".\""),
|
||||
"disable_mmap_load_safetensors": OptionInfo(False, "Disable memmapping for loading .safetensors files.").info("fixes very slow loading speed in some cases"),
|
||||
}))
|
||||
|
||||
options_templates.update(options_section(('training', "Training"), {
|
||||
@@ -414,9 +429,16 @@ options_templates.update(options_section(('sd', "Stable Diffusion"), {
|
||||
"randn_source": OptionInfo("GPU", "Random number generator source.", gr.Radio, {"choices": ["GPU", "CPU"]}).info("changes seeds drastically; use CPU to produce the same picture across different videocard vendors"),
|
||||
}))
|
||||
|
||||
options_templates.update(options_section(('sdxl', "Stable Diffusion XL"), {
|
||||
"sdxl_crop_top": OptionInfo(0, "crop top coordinate"),
|
||||
"sdxl_crop_left": OptionInfo(0, "crop left coordinate"),
|
||||
"sdxl_refiner_low_aesthetic_score": OptionInfo(2.5, "SDXL low aesthetic score", gr.Number).info("used for refiner model negative prompt"),
|
||||
"sdxl_refiner_high_aesthetic_score": OptionInfo(6.0, "SDXL high aesthetic score", gr.Number).info("used for refiner model prompt"),
|
||||
}))
|
||||
|
||||
options_templates.update(options_section(('optimizations', "Optimizations"), {
|
||||
"cross_attention_optimization": OptionInfo("Automatic", "Cross attention optimization", gr.Dropdown, lambda: {"choices": shared_items.cross_attention_optimizations()}),
|
||||
"s_min_uncond": OptionInfo(0.0, "Negative Guidance minimum sigma", gr.Slider, {"minimum": 0.0, "maximum": 4.0, "step": 0.01}).link("PR", "https://github.com/AUTOMATIC1111/stable-diffusion-webui/pull/9177").info("skip negative prompt for some steps when the image is almost ready; 0=disable, higher=faster"),
|
||||
"s_min_uncond": OptionInfo(0.0, "Negative Guidance minimum sigma", gr.Slider, {"minimum": 0.0, "maximum": 15.0, "step": 0.01}).link("PR", "https://github.com/AUTOMATIC1111/stable-diffusion-webui/pull/9177").info("skip negative prompt for some steps when the image is almost ready; 0=disable, higher=faster"),
|
||||
"token_merging_ratio": OptionInfo(0.0, "Token merging ratio", gr.Slider, {"minimum": 0.0, "maximum": 0.9, "step": 0.1}).link("PR", "https://github.com/AUTOMATIC1111/stable-diffusion-webui/pull/9256").info("0=disable, higher=faster"),
|
||||
"token_merging_ratio_img2img": OptionInfo(0.0, "Token merging ratio for img2img", gr.Slider, {"minimum": 0.0, "maximum": 0.9, "step": 0.1}).info("only applies if non-zero and overrides above"),
|
||||
"token_merging_ratio_hr": OptionInfo(0.0, "Token merging ratio for high-res pass", gr.Slider, {"minimum": 0.0, "maximum": 0.9, "step": 0.1}).info("only applies if non-zero and overrides above"),
|
||||
@@ -451,12 +473,15 @@ options_templates.update(options_section(('interrogate', "Interrogate Options"),
|
||||
options_templates.update(options_section(('extra_networks', "Extra Networks"), {
|
||||
"extra_networks_show_hidden_directories": OptionInfo(True, "Show hidden directories").info("directory is hidden if its name starts with \".\"."),
|
||||
"extra_networks_hidden_models": OptionInfo("When searched", "Show cards for models in hidden directories", gr.Radio, {"choices": ["Always", "When searched", "Never"]}).info('"When searched" option will only show the item when the search string has 4 characters or more'),
|
||||
"extra_networks_default_view": OptionInfo("cards", "Default view for Extra Networks", gr.Dropdown, {"choices": ["cards", "thumbs"]}),
|
||||
"extra_networks_default_multiplier": OptionInfo(1.0, "Multiplier for extra networks", gr.Slider, {"minimum": 0.0, "maximum": 1.0, "step": 0.01}),
|
||||
"extra_networks_default_multiplier": OptionInfo(1.0, "Default multiplier for extra networks", gr.Slider, {"minimum": 0.0, "maximum": 2.0, "step": 0.01}),
|
||||
"extra_networks_card_width": OptionInfo(0, "Card width for Extra Networks").info("in pixels"),
|
||||
"extra_networks_card_height": OptionInfo(0, "Card height for Extra Networks").info("in pixels"),
|
||||
"extra_networks_card_text_scale": OptionInfo(1.0, "Card text scale", gr.Slider, {"minimum": 0.0, "maximum": 2.0, "step": 0.01}).info("1 = original size"),
|
||||
"extra_networks_card_show_desc": OptionInfo(True, "Show description on card"),
|
||||
"extra_networks_add_text_separator": OptionInfo(" ", "Extra networks separator").info("extra text to add before <...> when adding extra network to prompt"),
|
||||
"ui_extra_networks_tab_reorder": OptionInfo("", "Extra networks tab order").needs_restart(),
|
||||
"textual_inversion_print_at_load": OptionInfo(False, "Print a list of Textual Inversion embeddings when loading model"),
|
||||
"textual_inversion_add_hashes_to_infotext": OptionInfo(True, "Add Textual Inversion hashes to infotext"),
|
||||
"sd_hypernetwork": OptionInfo("None", "Add hypernetwork to prompt", gr.Dropdown, lambda: {"choices": ["None", *hypernetworks]}, refresh=reload_hypernetworks),
|
||||
}))
|
||||
|
||||
@@ -470,7 +495,6 @@ options_templates.update(options_section(('ui', "User interface"), {
|
||||
"do_not_show_images": OptionInfo(False, "Do not show any images in results for web"),
|
||||
"send_seed": OptionInfo(True, "Send seed when sending prompt or image to other interface"),
|
||||
"send_size": OptionInfo(True, "Send size when sending prompt or image to another interface"),
|
||||
"font": OptionInfo("", "Font for image grids that have text"),
|
||||
"js_modal_lightbox": OptionInfo(True, "Enable full page image viewer"),
|
||||
"js_modal_lightbox_initially_zoomed": OptionInfo(True, "Show images zoomed in by default in full page image viewer"),
|
||||
"js_modal_lightbox_gamepad": OptionInfo(False, "Navigate image viewer with gamepad"),
|
||||
@@ -481,6 +505,7 @@ options_templates.update(options_section(('ui', "User interface"), {
|
||||
"keyedit_precision_attention": OptionInfo(0.1, "Ctrl+up/down precision when editing (attention:1.1)", gr.Slider, {"minimum": 0.01, "maximum": 0.2, "step": 0.001}),
|
||||
"keyedit_precision_extra": OptionInfo(0.05, "Ctrl+up/down precision when editing <extra networks:0.9>", gr.Slider, {"minimum": 0.01, "maximum": 0.2, "step": 0.001}),
|
||||
"keyedit_delimiters": OptionInfo(".,\\/!?%^*;:{}=`~()", "Ctrl+up/down word delimiters"),
|
||||
"keyedit_move": OptionInfo(True, "Alt+left/right moves prompt elements"),
|
||||
"quicksettings_list": OptionInfo(["sd_model_checkpoint"], "Quicksettings list", ui_components.DropdownMulti, lambda: {"choices": list(opts.data_labels.keys())}).js("info", "settingsHintsShowQuicksettings").info("setting entries that appear at the top of page rather than in settings tab").needs_restart(),
|
||||
"ui_tab_order": OptionInfo([], "UI tab order", ui_components.DropdownMulti, lambda: {"choices": list(tab_names)}).needs_restart(),
|
||||
"hidden_tabs": OptionInfo([], "Hidden UI tabs", ui_components.DropdownMulti, lambda: {"choices": list(tab_names)}).needs_restart(),
|
||||
@@ -493,6 +518,7 @@ options_templates.update(options_section(('ui', "User interface"), {
|
||||
options_templates.update(options_section(('infotext', "Infotext"), {
|
||||
"add_model_hash_to_info": OptionInfo(True, "Add model hash to generation information"),
|
||||
"add_model_name_to_info": OptionInfo(True, "Add model name to generation information"),
|
||||
"add_user_name_to_info": OptionInfo(False, "Add user name to generation information when authenticated"),
|
||||
"add_version_to_infotext": OptionInfo(True, "Add program version to generation information"),
|
||||
"disable_weights_auto_swap": OptionInfo(True, "Disregard checkpoint information from pasted infotext").info("when reading generation parameters from text into UI"),
|
||||
"infotext_styles": OptionInfo("Apply if any", "Infer styles from prompts of pasted infotext", gr.Radio, {"choices": ["Ignore", "Apply", "Discard", "Apply if any"]}).info("when reading generation parameters from text into UI)").html("""<ul style='margin-left: 1.5em'>
|
||||
@@ -817,8 +843,12 @@ mem_mon = modules.memmon.MemUsageMonitor("MemMon", device, opts)
|
||||
mem_mon.start()
|
||||
|
||||
|
||||
def natural_sort_key(s, regex=re.compile('([0-9]+)')):
|
||||
return [int(text) if text.isdigit() else text.lower() for text in regex.split(s)]
|
||||
|
||||
|
||||
def listfiles(dirname):
|
||||
filenames = [os.path.join(dirname, x) for x in sorted(os.listdir(dirname), key=str.lower) if not x.startswith(".")]
|
||||
filenames = [os.path.join(dirname, x) for x in sorted(os.listdir(dirname), key=natural_sort_key) if not x.startswith(".")]
|
||||
return [file for file in filenames if os.path.isfile(file)]
|
||||
|
||||
|
||||
@@ -843,8 +873,11 @@ def walk_files(path, allowed_extensions=None):
|
||||
if allowed_extensions is not None:
|
||||
allowed_extensions = set(allowed_extensions)
|
||||
|
||||
for root, _, files in os.walk(path, followlinks=True):
|
||||
for filename in files:
|
||||
items = list(os.walk(path, followlinks=True))
|
||||
items = sorted(items, key=lambda x: natural_sort_key(x[0]))
|
||||
|
||||
for root, _, files in items:
|
||||
for filename in sorted(files, key=natural_sort_key):
|
||||
if allowed_extensions is not None:
|
||||
_, ext = os.path.splitext(filename)
|
||||
if ext not in allowed_extensions:
|
||||
|
||||
@@ -2,11 +2,51 @@ import datetime
|
||||
import json
|
||||
import os
|
||||
|
||||
saved_params_shared = {"model_name", "model_hash", "initial_step", "num_of_dataset_images", "learn_rate", "batch_size", "clip_grad_mode", "clip_grad_value", "data_root", "log_directory", "training_width", "training_height", "steps", "create_image_every", "template_file", "gradient_step", "latent_sampling_method"}
|
||||
saved_params_ti = {"embedding_name", "num_vectors_per_token", "save_embedding_every", "save_image_with_stored_embedding"}
|
||||
saved_params_hypernet = {"hypernetwork_name", "layer_structure", "activation_func", "weight_init", "add_layer_norm", "use_dropout", "save_hypernetwork_every"}
|
||||
saved_params_shared = {
|
||||
"batch_size",
|
||||
"clip_grad_mode",
|
||||
"clip_grad_value",
|
||||
"create_image_every",
|
||||
"data_root",
|
||||
"gradient_step",
|
||||
"initial_step",
|
||||
"latent_sampling_method",
|
||||
"learn_rate",
|
||||
"log_directory",
|
||||
"model_hash",
|
||||
"model_name",
|
||||
"num_of_dataset_images",
|
||||
"steps",
|
||||
"template_file",
|
||||
"training_height",
|
||||
"training_width",
|
||||
}
|
||||
saved_params_ti = {
|
||||
"embedding_name",
|
||||
"num_vectors_per_token",
|
||||
"save_embedding_every",
|
||||
"save_image_with_stored_embedding",
|
||||
}
|
||||
saved_params_hypernet = {
|
||||
"activation_func",
|
||||
"add_layer_norm",
|
||||
"hypernetwork_name",
|
||||
"layer_structure",
|
||||
"save_hypernetwork_every",
|
||||
"use_dropout",
|
||||
"weight_init",
|
||||
}
|
||||
saved_params_all = saved_params_shared | saved_params_ti | saved_params_hypernet
|
||||
saved_params_previews = {"preview_prompt", "preview_negative_prompt", "preview_steps", "preview_sampler_index", "preview_cfg_scale", "preview_seed", "preview_width", "preview_height"}
|
||||
saved_params_previews = {
|
||||
"preview_cfg_scale",
|
||||
"preview_height",
|
||||
"preview_negative_prompt",
|
||||
"preview_prompt",
|
||||
"preview_sampler_index",
|
||||
"preview_seed",
|
||||
"preview_steps",
|
||||
"preview_width",
|
||||
}
|
||||
|
||||
|
||||
def save_settings_to_file(log_directory, all_params):
|
||||
|
||||
@@ -7,7 +7,7 @@ from modules import paths, shared, images, deepbooru
|
||||
from modules.textual_inversion import autocrop
|
||||
|
||||
|
||||
def preprocess(id_task, process_src, process_dst, process_width, process_height, preprocess_txt_action, process_keep_original_size, process_flip, process_split, process_caption, process_caption_deepbooru=False, split_threshold=0.5, overlap_ratio=0.2, process_focal_crop=False, process_focal_crop_face_weight=0.9, process_focal_crop_entropy_weight=0.3, process_focal_crop_edges_weight=0.5, process_focal_crop_debug=False, process_multicrop=None, process_multicrop_mindim=None, process_multicrop_maxdim=None, process_multicrop_minarea=None, process_multicrop_maxarea=None, process_multicrop_objective=None, process_multicrop_threshold=None):
|
||||
def preprocess(id_task, process_src, process_dst, process_width, process_height, preprocess_txt_action, process_keep_original_size, process_flip, process_split, process_caption, process_caption_deepbooru=False, split_threshold=0.5, overlap_ratio=0.2, process_focal_crop=False, process_focal_crop_face_weight=0.9, process_focal_crop_entropy_weight=0.15, process_focal_crop_edges_weight=0.5, process_focal_crop_debug=False, process_multicrop=None, process_multicrop_mindim=None, process_multicrop_maxdim=None, process_multicrop_minarea=None, process_multicrop_maxarea=None, process_multicrop_objective=None, process_multicrop_threshold=None):
|
||||
try:
|
||||
if process_caption:
|
||||
shared.interrogator.load()
|
||||
|
||||
@@ -1,5 +1,6 @@
|
||||
import os
|
||||
from collections import namedtuple
|
||||
from contextlib import closing
|
||||
|
||||
import torch
|
||||
import tqdm
|
||||
@@ -12,7 +13,7 @@ import numpy as np
|
||||
from PIL import Image, PngImagePlugin
|
||||
from torch.utils.tensorboard import SummaryWriter
|
||||
|
||||
from modules import shared, devices, sd_hijack, processing, sd_models, images, sd_samplers, sd_hijack_checkpoint, errors
|
||||
from modules import shared, devices, sd_hijack, processing, sd_models, images, sd_samplers, sd_hijack_checkpoint, errors, hashes
|
||||
import modules.textual_inversion.dataset
|
||||
from modules.textual_inversion.learn_schedule import LearnRateScheduler
|
||||
|
||||
@@ -48,6 +49,8 @@ class Embedding:
|
||||
self.sd_checkpoint_name = None
|
||||
self.optimizer_state_dict = None
|
||||
self.filename = None
|
||||
self.hash = None
|
||||
self.shorthash = None
|
||||
|
||||
def save(self, filename):
|
||||
embedding_data = {
|
||||
@@ -81,6 +84,10 @@ class Embedding:
|
||||
self.cached_checksum = f'{const_hash(self.vec.reshape(-1) * 100) & 0xffff:04x}'
|
||||
return self.cached_checksum
|
||||
|
||||
def set_hash(self, v):
|
||||
self.hash = v
|
||||
self.shorthash = self.hash[0:12]
|
||||
|
||||
|
||||
class DirWithTextualInversionEmbeddings:
|
||||
def __init__(self, path):
|
||||
@@ -198,6 +205,7 @@ class EmbeddingDatabase:
|
||||
embedding.vectors = vec.shape[0]
|
||||
embedding.shape = vec.shape[-1]
|
||||
embedding.filename = path
|
||||
embedding.set_hash(hashes.sha256(embedding.filename, "textual_inversion/" + name) or '')
|
||||
|
||||
if self.expected_shape == -1 or self.expected_shape == embedding.shape:
|
||||
self.register_embedding(embedding, shared.sd_model)
|
||||
@@ -248,7 +256,7 @@ class EmbeddingDatabase:
|
||||
self.word_embeddings.update(sorted_word_embeddings)
|
||||
|
||||
displayed_embeddings = (tuple(self.word_embeddings.keys()), tuple(self.skipped_embeddings.keys()))
|
||||
if self.previously_displayed_embeddings != displayed_embeddings:
|
||||
if shared.opts.textual_inversion_print_at_load and self.previously_displayed_embeddings != displayed_embeddings:
|
||||
self.previously_displayed_embeddings = displayed_embeddings
|
||||
print(f"Textual inversion embeddings loaded({len(self.word_embeddings)}): {', '.join(self.word_embeddings.keys())}")
|
||||
if self.skipped_embeddings:
|
||||
@@ -584,8 +592,9 @@ def train_embedding(id_task, embedding_name, learn_rate, batch_size, gradient_st
|
||||
|
||||
preview_text = p.prompt
|
||||
|
||||
processed = processing.process_images(p)
|
||||
image = processed.images[0] if len(processed.images) > 0 else None
|
||||
with closing(p):
|
||||
processed = processing.process_images(p)
|
||||
image = processed.images[0] if len(processed.images) > 0 else None
|
||||
|
||||
if unload:
|
||||
shared.sd_model.first_stage_model.to(devices.cpu)
|
||||
|
||||
@@ -1,13 +1,15 @@
|
||||
from contextlib import closing
|
||||
|
||||
import modules.scripts
|
||||
from modules import sd_samplers, processing
|
||||
from modules.generation_parameters_copypaste import create_override_settings_dict
|
||||
from modules.shared import opts, cmd_opts
|
||||
import modules.shared as shared
|
||||
from modules.ui import plaintext_to_html
|
||||
import gradio as gr
|
||||
|
||||
|
||||
|
||||
def txt2img(id_task: str, prompt: str, negative_prompt: str, prompt_styles, steps: int, sampler_index: int, restore_faces: bool, tiling: bool, n_iter: int, batch_size: int, cfg_scale: float, seed: int, subseed: int, subseed_strength: float, seed_resize_from_h: int, seed_resize_from_w: int, seed_enable_extras: bool, height: int, width: int, enable_hr: bool, denoising_strength: float, hr_scale: float, hr_upscaler: str, hr_second_pass_steps: int, hr_resize_x: int, hr_resize_y: int, hr_sampler_index: int, hr_prompt: str, hr_negative_prompt, override_settings_texts, *args):
|
||||
def txt2img(id_task: str, prompt: str, negative_prompt: str, prompt_styles, steps: int, sampler_index: int, restore_faces: bool, tiling: bool, n_iter: int, batch_size: int, cfg_scale: float, seed: int, subseed: int, subseed_strength: float, seed_resize_from_h: int, seed_resize_from_w: int, seed_enable_extras: bool, height: int, width: int, enable_hr: bool, denoising_strength: float, hr_scale: float, hr_upscaler: str, hr_second_pass_steps: int, hr_resize_x: int, hr_resize_y: int, hr_sampler_index: int, hr_prompt: str, hr_negative_prompt, override_settings_texts, request: gr.Request, *args):
|
||||
override_settings = create_override_settings_dict(override_settings_texts)
|
||||
|
||||
p = processing.StableDiffusionProcessingTxt2Img(
|
||||
@@ -48,15 +50,16 @@ def txt2img(id_task: str, prompt: str, negative_prompt: str, prompt_styles, step
|
||||
p.scripts = modules.scripts.scripts_txt2img
|
||||
p.script_args = args
|
||||
|
||||
p.user = request.username
|
||||
|
||||
if cmd_opts.enable_console_prompts:
|
||||
print(f"\ntxt2img: {prompt}", file=shared.progress_print_out)
|
||||
|
||||
processed = modules.scripts.scripts_txt2img.run(p, *args)
|
||||
with closing(p):
|
||||
processed = modules.scripts.scripts_txt2img.run(p, *args)
|
||||
|
||||
if processed is None:
|
||||
processed = processing.process_images(p)
|
||||
|
||||
p.close()
|
||||
if processed is None:
|
||||
processed = processing.process_images(p)
|
||||
|
||||
shared.total_tqdm.clear()
|
||||
|
||||
@@ -67,4 +70,4 @@ def txt2img(id_task: str, prompt: str, negative_prompt: str, prompt_styles, step
|
||||
if opts.do_not_show_images:
|
||||
processed.images = []
|
||||
|
||||
return processed.images, generation_info_js, plaintext_to_html(processed.info), plaintext_to_html(processed.comments)
|
||||
return processed.images, generation_info_js, plaintext_to_html(processed.info), plaintext_to_html(processed.comments, classname="comments")
|
||||
|
||||
@@ -83,8 +83,7 @@ detect_image_size_symbol = '\U0001F4D0' # 📐
|
||||
up_down_symbol = '\u2195\ufe0f' # ↕️
|
||||
|
||||
|
||||
def plaintext_to_html(text):
|
||||
return ui_common.plaintext_to_html(text)
|
||||
plaintext_to_html = ui_common.plaintext_to_html
|
||||
|
||||
|
||||
def send_gradio_gallery_to_image(x):
|
||||
@@ -155,7 +154,7 @@ def process_interrogate(interrogation_function, mode, ii_input_dir, ii_output_di
|
||||
img = Image.open(image)
|
||||
filename = os.path.basename(image)
|
||||
left, _ = os.path.splitext(filename)
|
||||
print(interrogation_function(img), file=open(os.path.join(ii_output_dir, f"{left}.txt"), 'a'))
|
||||
print(interrogation_function(img), file=open(os.path.join(ii_output_dir, f"{left}.txt"), 'a', encoding='utf-8'))
|
||||
|
||||
return [gr.update(), None]
|
||||
|
||||
@@ -733,6 +732,10 @@ def create_ui():
|
||||
img2img_batch_input_dir = gr.Textbox(label="Input directory", **shared.hide_dirs, elem_id="img2img_batch_input_dir")
|
||||
img2img_batch_output_dir = gr.Textbox(label="Output directory", **shared.hide_dirs, elem_id="img2img_batch_output_dir")
|
||||
img2img_batch_inpaint_mask_dir = gr.Textbox(label="Inpaint batch mask directory (required for inpaint batch processing only)", **shared.hide_dirs, elem_id="img2img_batch_inpaint_mask_dir")
|
||||
with gr.Accordion("PNG info", open=False):
|
||||
img2img_batch_use_png_info = gr.Checkbox(label="Append png info to prompts", **shared.hide_dirs, elem_id="img2img_batch_use_png_info")
|
||||
img2img_batch_png_info_dir = gr.Textbox(label="PNG info directory", **shared.hide_dirs, placeholder="Leave empty to use input directory", elem_id="img2img_batch_png_info_dir")
|
||||
img2img_batch_png_info_props = gr.CheckboxGroup(["Prompt", "Negative prompt", "Seed", "CFG scale", "Sampler", "Steps"], label="Parameters to take from png info", info="Prompts from png info will be appended to prompts set in ui.")
|
||||
|
||||
img2img_tabs = [tab_img2img, tab_sketch, tab_inpaint, tab_inpaint_color, tab_inpaint_upload, tab_batch]
|
||||
|
||||
@@ -773,7 +776,7 @@ def create_ui():
|
||||
selected_scale_tab = gr.State(value=0)
|
||||
|
||||
with gr.Tabs():
|
||||
with gr.Tab(label="Resize to") as tab_scale_to:
|
||||
with gr.Tab(label="Resize to", elem_id="img2img_tab_resize_to") as tab_scale_to:
|
||||
with FormRow():
|
||||
with gr.Column(elem_id="img2img_column_size", scale=4):
|
||||
width = gr.Slider(minimum=64, maximum=2048, step=8, label="Width", value=512, elem_id="img2img_width")
|
||||
@@ -782,7 +785,7 @@ def create_ui():
|
||||
res_switch_btn = ToolButton(value=switch_values_symbol, elem_id="img2img_res_switch_btn")
|
||||
detect_image_size_btn = ToolButton(value=detect_image_size_symbol, elem_id="img2img_detect_image_size_btn")
|
||||
|
||||
with gr.Tab(label="Resize by") as tab_scale_by:
|
||||
with gr.Tab(label="Resize by", elem_id="img2img_tab_resize_by") as tab_scale_by:
|
||||
scale_by = gr.Slider(minimum=0.05, maximum=4.0, step=0.05, label="Scale", value=1.0, elem_id="img2img_scale")
|
||||
|
||||
with FormRow():
|
||||
@@ -934,6 +937,9 @@ def create_ui():
|
||||
img2img_batch_output_dir,
|
||||
img2img_batch_inpaint_mask_dir,
|
||||
override_settings,
|
||||
img2img_batch_use_png_info,
|
||||
img2img_batch_png_info_props,
|
||||
img2img_batch_png_info_dir,
|
||||
] + custom_inputs,
|
||||
outputs=[
|
||||
img2img_gallery,
|
||||
|
||||
@@ -29,9 +29,10 @@ def update_generation_info(generation_info, html_info, img_index):
|
||||
return html_info, gr.update()
|
||||
|
||||
|
||||
def plaintext_to_html(text):
|
||||
text = "<p>" + "<br>\n".join([f"{html.escape(x)}" for x in text.split('\n')]) + "</p>"
|
||||
return text
|
||||
def plaintext_to_html(text, classname=None):
|
||||
content = "<br>\n".join(html.escape(x) for x in text.split('\n'))
|
||||
|
||||
return f"<p class='{classname}'>{content}</p>" if classname else f"<p>{content}</p>"
|
||||
|
||||
|
||||
def save_files(js_data, images, do_make_zip, index):
|
||||
@@ -157,7 +158,7 @@ Requested path was: {f}
|
||||
|
||||
with gr.Group():
|
||||
html_info = gr.HTML(elem_id=f'html_info_{tabname}', elem_classes="infotext")
|
||||
html_log = gr.HTML(elem_id=f'html_log_{tabname}')
|
||||
html_log = gr.HTML(elem_id=f'html_log_{tabname}', elem_classes="html-log")
|
||||
|
||||
generation_info = gr.Textbox(visible=False, elem_id=f'generation_info_{tabname}')
|
||||
if tabname == 'txt2img' or tabname == 'img2img':
|
||||
|
||||
@@ -1,5 +1,5 @@
|
||||
import json
|
||||
import os.path
|
||||
import os
|
||||
import threading
|
||||
import time
|
||||
from datetime import datetime
|
||||
@@ -138,7 +138,10 @@ def extension_table():
|
||||
<table id="extensions">
|
||||
<thead>
|
||||
<tr>
|
||||
<th><abbr title="Use checkbox to enable the extension; it will be enabled or disabled when you click apply button">Extension</abbr></th>
|
||||
<th>
|
||||
<input class="gr-check-radio gr-checkbox all_extensions_toggle" type="checkbox" {'checked="checked"' if all(ext.enabled for ext in extensions.extensions) else ''} onchange="toggle_all_extensions(event)" />
|
||||
<abbr title="Use checkbox to enable the extension; it will be enabled or disabled when you click apply button">Extension</abbr>
|
||||
</th>
|
||||
<th>URL</th>
|
||||
<th>Branch</th>
|
||||
<th>Version</th>
|
||||
@@ -170,7 +173,7 @@ def extension_table():
|
||||
|
||||
code += f"""
|
||||
<tr>
|
||||
<td><label{style}><input class="gr-check-radio gr-checkbox" name="enable_{html.escape(ext.name)}" type="checkbox" {'checked="checked"' if ext.enabled else ''}>{html.escape(ext.name)}</label></td>
|
||||
<td><label{style}><input class="gr-check-radio gr-checkbox extension_toggle" name="enable_{html.escape(ext.name)}" type="checkbox" {'checked="checked"' if ext.enabled else ''} onchange="toggle_extension(event)" />{html.escape(ext.name)}</label></td>
|
||||
<td>{remote}</td>
|
||||
<td>{ext.branch}</td>
|
||||
<td>{version_link}</td>
|
||||
@@ -421,9 +424,19 @@ sort_ordering = [
|
||||
(False, lambda x: x.get('name', 'z')),
|
||||
(True, lambda x: x.get('name', 'z')),
|
||||
(False, lambda x: 'z'),
|
||||
(True, lambda x: x.get('commit_time', '')),
|
||||
(True, lambda x: x.get('created_at', '')),
|
||||
(True, lambda x: x.get('stars', 0)),
|
||||
]
|
||||
|
||||
|
||||
def get_date(info: dict, key):
|
||||
try:
|
||||
return datetime.strptime(info.get(key), "%Y-%m-%dT%H:%M:%SZ").strftime("%Y-%m-%d")
|
||||
except (ValueError, TypeError):
|
||||
return ''
|
||||
|
||||
|
||||
def refresh_available_extensions_from_data(hide_tags, sort_column, filter_text=""):
|
||||
extlist = available_extensions["extensions"]
|
||||
installed_extension_urls = {normalize_git_url(extension.remote): extension.name for extension in extensions.extensions}
|
||||
@@ -448,7 +461,10 @@ def refresh_available_extensions_from_data(hide_tags, sort_column, filter_text="
|
||||
|
||||
for ext in sorted(extlist, key=sort_function, reverse=sort_reverse):
|
||||
name = ext.get("name", "noname")
|
||||
stars = int(ext.get("stars", 0))
|
||||
added = ext.get('added', 'unknown')
|
||||
update_time = get_date(ext, 'commit_time')
|
||||
create_time = get_date(ext, 'created_at')
|
||||
url = ext.get("url", None)
|
||||
description = ext.get("description", "")
|
||||
extension_tags = ext.get("tags", [])
|
||||
@@ -475,7 +491,8 @@ def refresh_available_extensions_from_data(hide_tags, sort_column, filter_text="
|
||||
code += f"""
|
||||
<tr>
|
||||
<td><a href="{html.escape(url)}" target="_blank">{html.escape(name)}</a><br />{tags_text}</td>
|
||||
<td>{html.escape(description)}<p class="info"><span class="date_added">Added: {html.escape(added)}</span></p></td>
|
||||
<td>{html.escape(description)}<p class="info">
|
||||
<span class="date_added">Update: {html.escape(update_time)} Added: {html.escape(added)} Created: {html.escape(create_time)}</span><span class="star_count">stars: <b>{stars}</b></a></p></td>
|
||||
<td>{install_code}</td>
|
||||
</tr>
|
||||
|
||||
@@ -496,14 +513,8 @@ def refresh_available_extensions_from_data(hide_tags, sort_column, filter_text="
|
||||
|
||||
|
||||
def preload_extensions_git_metadata():
|
||||
t0 = time.time()
|
||||
for extension in extensions.extensions:
|
||||
extension.read_info_from_repo()
|
||||
print(
|
||||
f"preload_extensions_git_metadata for "
|
||||
f"{len(extensions.extensions)} extensions took "
|
||||
f"{time.time() - t0:.2f}s"
|
||||
)
|
||||
|
||||
|
||||
def create_ui():
|
||||
@@ -553,13 +564,14 @@ def create_ui():
|
||||
with gr.TabItem("Available", id="available"):
|
||||
with gr.Row():
|
||||
refresh_available_extensions_button = gr.Button(value="Load from:", variant="primary")
|
||||
available_extensions_index = gr.Text(value="https://raw.githubusercontent.com/AUTOMATIC1111/stable-diffusion-webui-extensions/master/index.json", label="Extension index URL").style(container=False)
|
||||
extensions_index_url = os.environ.get('WEBUI_EXTENSIONS_INDEX', "https://raw.githubusercontent.com/AUTOMATIC1111/stable-diffusion-webui-extensions/master/index.json")
|
||||
available_extensions_index = gr.Text(value=extensions_index_url, label="Extension index URL").style(container=False)
|
||||
extension_to_install = gr.Text(elem_id="extension_to_install", visible=False)
|
||||
install_extension_button = gr.Button(elem_id="install_extension_button", visible=False)
|
||||
|
||||
with gr.Row():
|
||||
hide_tags = gr.CheckboxGroup(value=["ads", "localization", "installed"], label="Hide extensions with tags", choices=["script", "ads", "localization", "installed"])
|
||||
sort_column = gr.Radio(value="newest first", label="Order", choices=["newest first", "oldest first", "a-z", "z-a", "internal order", ], type="index")
|
||||
sort_column = gr.Radio(value="newest first", label="Order", choices=["newest first", "oldest first", "a-z", "z-a", "internal order",'update time', 'create time', "stars"], type="index")
|
||||
|
||||
with gr.Row():
|
||||
search_extensions_text = gr.Text(label="Search").style(container=False)
|
||||
@@ -568,9 +580,9 @@ def create_ui():
|
||||
available_extensions_table = gr.HTML()
|
||||
|
||||
refresh_available_extensions_button.click(
|
||||
fn=modules.ui.wrap_gradio_call(refresh_available_extensions, extra_outputs=[gr.update(), gr.update(), gr.update()]),
|
||||
fn=modules.ui.wrap_gradio_call(refresh_available_extensions, extra_outputs=[gr.update(), gr.update(), gr.update(), gr.update()]),
|
||||
inputs=[available_extensions_index, hide_tags, sort_column],
|
||||
outputs=[available_extensions_index, available_extensions_table, hide_tags, install_result, search_extensions_text],
|
||||
outputs=[available_extensions_index, available_extensions_table, hide_tags, search_extensions_text, install_result],
|
||||
)
|
||||
|
||||
install_extension_button.click(
|
||||
|
||||
@@ -2,14 +2,16 @@ import os.path
|
||||
import urllib.parse
|
||||
from pathlib import Path
|
||||
|
||||
from modules import shared
|
||||
from modules import shared, ui_extra_networks_user_metadata, errors
|
||||
from modules.images import read_info_from_image, save_image_with_geninfo
|
||||
from modules.ui import up_down_symbol
|
||||
import gradio as gr
|
||||
import json
|
||||
import html
|
||||
from fastapi.exceptions import HTTPException
|
||||
|
||||
from modules.generation_parameters_copypaste import image_from_url_text
|
||||
from modules.ui_components import ToolButton
|
||||
|
||||
extra_pages = []
|
||||
allowed_dirs = set()
|
||||
@@ -26,12 +28,15 @@ def register_page(page):
|
||||
def fetch_file(filename: str = ""):
|
||||
from starlette.responses import FileResponse
|
||||
|
||||
if not os.path.isfile(filename):
|
||||
raise HTTPException(status_code=404, detail="File not found")
|
||||
|
||||
if not any(Path(x).absolute() in Path(filename).absolute().parents for x in allowed_dirs):
|
||||
raise ValueError(f"File cannot be fetched: {filename}. Must be in one of directories registered by extra pages.")
|
||||
|
||||
ext = os.path.splitext(filename)[1].lower()
|
||||
if ext not in (".png", ".jpg", ".jpeg", ".webp"):
|
||||
raise ValueError(f"File cannot be fetched: {filename}. Only png and jpg and webp.")
|
||||
if ext not in (".png", ".jpg", ".jpeg", ".webp", ".gif"):
|
||||
raise ValueError(f"File cannot be fetched: {filename}. Only png, jpg, webp, and gif.")
|
||||
|
||||
# would profit from returning 304
|
||||
return FileResponse(filename, headers={"Accept-Ranges": "bytes"})
|
||||
@@ -48,25 +53,71 @@ def get_metadata(page: str = "", item: str = ""):
|
||||
if metadata is None:
|
||||
return JSONResponse({})
|
||||
|
||||
return JSONResponse({"metadata": metadata})
|
||||
return JSONResponse({"metadata": json.dumps(metadata, indent=4, ensure_ascii=False)})
|
||||
|
||||
|
||||
def get_single_card(page: str = "", tabname: str = "", name: str = ""):
|
||||
from starlette.responses import JSONResponse
|
||||
|
||||
page = next(iter([x for x in extra_pages if x.name == page]), None)
|
||||
|
||||
try:
|
||||
item = page.create_item(name, enable_filter=False)
|
||||
page.items[name] = item
|
||||
except Exception as e:
|
||||
errors.display(e, "creating item for extra network")
|
||||
item = page.items.get(name)
|
||||
|
||||
page.read_user_metadata(item)
|
||||
item_html = page.create_html_for_item(item, tabname)
|
||||
|
||||
return JSONResponse({"html": item_html})
|
||||
|
||||
|
||||
def add_pages_to_demo(app):
|
||||
app.add_api_route("/sd_extra_networks/thumb", fetch_file, methods=["GET"])
|
||||
app.add_api_route("/sd_extra_networks/metadata", get_metadata, methods=["GET"])
|
||||
app.add_api_route("/sd_extra_networks/get-single-card", get_single_card, methods=["GET"])
|
||||
|
||||
|
||||
def quote_js(s):
|
||||
s = s.replace('\\', '\\\\')
|
||||
s = s.replace('"', '\\"')
|
||||
return f'"{s}"'
|
||||
|
||||
|
||||
class ExtraNetworksPage:
|
||||
def __init__(self, title):
|
||||
self.title = title
|
||||
self.name = title.lower()
|
||||
self.id_page = self.name.replace(" ", "_")
|
||||
self.card_page = shared.html("extra-networks-card.html")
|
||||
self.allow_negative_prompt = False
|
||||
self.metadata = {}
|
||||
self.items = {}
|
||||
|
||||
def refresh(self):
|
||||
pass
|
||||
|
||||
def read_user_metadata(self, item):
|
||||
filename = item.get("filename", None)
|
||||
basename, ext = os.path.splitext(filename)
|
||||
metadata_filename = basename + '.json'
|
||||
|
||||
metadata = {}
|
||||
try:
|
||||
if os.path.isfile(metadata_filename):
|
||||
with open(metadata_filename, "r", encoding="utf8") as file:
|
||||
metadata = json.load(file)
|
||||
except Exception as e:
|
||||
errors.display(e, f"reading extra network user metadata from {metadata_filename}")
|
||||
|
||||
desc = metadata.get("description", None)
|
||||
if desc is not None:
|
||||
item["description"] = desc
|
||||
|
||||
item["user_metadata"] = metadata
|
||||
|
||||
def link_preview(self, filename):
|
||||
quoted_filename = urllib.parse.quote(filename.replace('\\', '/'))
|
||||
mtime = os.path.getmtime(filename)
|
||||
@@ -83,15 +134,14 @@ class ExtraNetworksPage:
|
||||
return ""
|
||||
|
||||
def create_html(self, tabname):
|
||||
view = shared.opts.extra_networks_default_view
|
||||
items_html = ''
|
||||
|
||||
self.metadata = {}
|
||||
|
||||
subdirs = {}
|
||||
for parentdir in [os.path.abspath(x) for x in self.allowed_directories_for_previews()]:
|
||||
for root, dirs, _ in os.walk(parentdir, followlinks=True):
|
||||
for dirname in dirs:
|
||||
for root, dirs, _ in sorted(os.walk(parentdir, followlinks=True), key=lambda x: shared.natural_sort_key(x[0])):
|
||||
for dirname in sorted(dirs, key=shared.natural_sort_key):
|
||||
x = os.path.join(root, dirname)
|
||||
|
||||
if not os.path.isdir(x):
|
||||
@@ -119,11 +169,15 @@ class ExtraNetworksPage:
|
||||
</button>
|
||||
""" for subdir in subdirs])
|
||||
|
||||
for item in self.list_items():
|
||||
self.items = {x["name"]: x for x in self.list_items()}
|
||||
for item in self.items.values():
|
||||
metadata = item.get("metadata")
|
||||
if metadata:
|
||||
self.metadata[item["name"]] = metadata
|
||||
|
||||
if "user_metadata" not in item:
|
||||
self.read_user_metadata(item)
|
||||
|
||||
items_html += self.create_html_for_item(item, tabname)
|
||||
|
||||
if items_html == '':
|
||||
@@ -133,16 +187,19 @@ class ExtraNetworksPage:
|
||||
self_name_id = self.name.replace(" ", "_")
|
||||
|
||||
res = f"""
|
||||
<div id='{tabname}_{self_name_id}_subdirs' class='extra-network-subdirs extra-network-subdirs-{view}'>
|
||||
<div id='{tabname}_{self_name_id}_subdirs' class='extra-network-subdirs extra-network-subdirs-cards'>
|
||||
{subdirs_html}
|
||||
</div>
|
||||
<div id='{tabname}_{self_name_id}_cards' class='extra-network-{view}'>
|
||||
<div id='{tabname}_{self_name_id}_cards' class='extra-network-cards'>
|
||||
{items_html}
|
||||
</div>
|
||||
"""
|
||||
|
||||
return res
|
||||
|
||||
def create_item(self, name, index=None):
|
||||
raise NotImplementedError()
|
||||
|
||||
def list_items(self):
|
||||
raise NotImplementedError()
|
||||
|
||||
@@ -158,7 +215,7 @@ class ExtraNetworksPage:
|
||||
|
||||
onclick = item.get("onclick", None)
|
||||
if onclick is None:
|
||||
onclick = '"' + html.escape(f"""return cardClicked({json.dumps(tabname)}, {item["prompt"]}, {"true" if self.allow_negative_prompt else "false"})""") + '"'
|
||||
onclick = '"' + html.escape(f"""return cardClicked({quote_js(tabname)}, {item["prompt"]}, {"true" if self.allow_negative_prompt else "false"})""") + '"'
|
||||
|
||||
height = f"height: {shared.opts.extra_networks_card_height}px;" if shared.opts.extra_networks_card_height else ''
|
||||
width = f"width: {shared.opts.extra_networks_card_width}px;" if shared.opts.extra_networks_card_width else ''
|
||||
@@ -166,7 +223,9 @@ class ExtraNetworksPage:
|
||||
metadata_button = ""
|
||||
metadata = item.get("metadata")
|
||||
if metadata:
|
||||
metadata_button = f"<div class='metadata-button' title='Show metadata' onclick='extraNetworksRequestMetadata(event, {json.dumps(self.name)}, {json.dumps(item['name'])})'></div>"
|
||||
metadata_button = f"<div class='metadata-button card-button' title='Show internal metadata' onclick='extraNetworksRequestMetadata(event, {quote_js(self.name)}, {quote_js(item['name'])})'></div>"
|
||||
|
||||
edit_button = f"<div class='edit-button card-button' title='Edit metadata' onclick='extraNetworksEditUserMetadata(event, {quote_js(tabname)}, {quote_js(self.id_page)}, {quote_js(item['name'])})'></div>"
|
||||
|
||||
local_path = ""
|
||||
filename = item.get("filename", "")
|
||||
@@ -190,16 +249,17 @@ class ExtraNetworksPage:
|
||||
|
||||
args = {
|
||||
"background_image": background_image,
|
||||
"style": f"'display: none; {height}{width}'",
|
||||
"style": f"'display: none; {height}{width}; font-size: {shared.opts.extra_networks_card_text_scale*100}%'",
|
||||
"prompt": item.get("prompt", None),
|
||||
"tabname": json.dumps(tabname),
|
||||
"local_preview": json.dumps(item["local_preview"]),
|
||||
"tabname": quote_js(tabname),
|
||||
"local_preview": quote_js(item["local_preview"]),
|
||||
"name": item["name"],
|
||||
"description": (item.get("description") or ""),
|
||||
"description": (item.get("description") or "" if shared.opts.extra_networks_card_show_desc else ""),
|
||||
"card_clicked": onclick,
|
||||
"save_card_preview": '"' + html.escape(f"""return saveCardPreview(event, {json.dumps(tabname)}, {json.dumps(item["local_preview"])})""") + '"',
|
||||
"save_card_preview": '"' + html.escape(f"""return saveCardPreview(event, {quote_js(tabname)}, {quote_js(item["local_preview"])})""") + '"',
|
||||
"search_term": item.get("search_term", ""),
|
||||
"metadata_button": metadata_button,
|
||||
"edit_button": edit_button,
|
||||
"search_only": " search_only" if search_only else "",
|
||||
"sort_keys": sort_keys,
|
||||
}
|
||||
@@ -247,6 +307,9 @@ class ExtraNetworksPage:
|
||||
pass
|
||||
return None
|
||||
|
||||
def create_user_metadata_editor(self, ui, tabname):
|
||||
return ui_extra_networks_user_metadata.UserMetadataEditor(ui, tabname, self)
|
||||
|
||||
|
||||
def initialize():
|
||||
extra_pages.clear()
|
||||
@@ -297,23 +360,26 @@ def create_ui(container, button, tabname):
|
||||
ui = ExtraNetworksUi()
|
||||
ui.pages = []
|
||||
ui.pages_contents = []
|
||||
ui.user_metadata_editors = []
|
||||
ui.stored_extra_pages = pages_in_preferred_order(extra_pages.copy())
|
||||
ui.tabname = tabname
|
||||
|
||||
with gr.Tabs(elem_id=tabname+"_extra_tabs"):
|
||||
for page in ui.stored_extra_pages:
|
||||
page_id = page.title.lower().replace(" ", "_")
|
||||
|
||||
with gr.Tab(page.title, id=page_id):
|
||||
elem_id = f"{tabname}_{page_id}_cards_html"
|
||||
with gr.Tab(page.title, id=page.id_page):
|
||||
elem_id = f"{tabname}_{page.id_page}_cards_html"
|
||||
page_elem = gr.HTML('Loading...', elem_id=elem_id)
|
||||
ui.pages.append(page_elem)
|
||||
|
||||
page_elem.change(fn=lambda: None, _js='function(){applyExtraNetworkFilter(' + json.dumps(tabname) + '); return []}', inputs=[], outputs=[])
|
||||
page_elem.change(fn=lambda: None, _js='function(){applyExtraNetworkFilter(' + quote_js(tabname) + '); return []}', inputs=[], outputs=[])
|
||||
|
||||
editor = page.create_user_metadata_editor(ui, tabname)
|
||||
editor.create_ui()
|
||||
ui.user_metadata_editors.append(editor)
|
||||
|
||||
gr.Textbox('', show_label=False, elem_id=tabname+"_extra_search", placeholder="Search...", visible=False)
|
||||
gr.Dropdown(choices=['Default Sort', 'Date Created', 'Date Modified', 'Name'], value='Default Sort', elem_id=tabname+"_extra_sort", multiselect=False, visible=False, show_label=False, interactive=True)
|
||||
gr.Button(up_down_symbol, elem_id=tabname+"_extra_sortorder")
|
||||
ToolButton(up_down_symbol, elem_id=tabname+"_extra_sortorder")
|
||||
button_refresh = gr.Button('Refresh', elem_id=tabname+"_extra_refresh")
|
||||
|
||||
ui.button_save_preview = gr.Button('Save preview', elem_id=tabname+"_save_preview", visible=False)
|
||||
@@ -363,6 +429,8 @@ def path_is_parent(parent_path, child_path):
|
||||
|
||||
def setup_ui(ui, gallery):
|
||||
def save_preview(index, images, filename):
|
||||
# this function is here for backwards compatibility and likely will be removed soon
|
||||
|
||||
if len(images) == 0:
|
||||
print("There is no image in gallery to save as a preview.")
|
||||
return [page.create_html(ui.tabname) for page in ui.stored_extra_pages]
|
||||
@@ -394,3 +462,7 @@ def setup_ui(ui, gallery):
|
||||
outputs=[*ui.pages]
|
||||
)
|
||||
|
||||
for editor in ui.user_metadata_editors:
|
||||
editor.setup_ui(gallery)
|
||||
|
||||
|
||||
|
||||
@@ -1,8 +1,8 @@
|
||||
import html
|
||||
import json
|
||||
import os
|
||||
|
||||
from modules import shared, ui_extra_networks, sd_models
|
||||
from modules.ui_extra_networks import quote_js
|
||||
|
||||
|
||||
class ExtraNetworksPageCheckpoints(ui_extra_networks.ExtraNetworksPage):
|
||||
@@ -12,21 +12,23 @@ class ExtraNetworksPageCheckpoints(ui_extra_networks.ExtraNetworksPage):
|
||||
def refresh(self):
|
||||
shared.refresh_checkpoints()
|
||||
|
||||
def list_items(self):
|
||||
checkpoint: sd_models.CheckpointInfo
|
||||
for index, (name, checkpoint) in enumerate(sd_models.checkpoints_list.items()):
|
||||
path, ext = os.path.splitext(checkpoint.filename)
|
||||
yield {
|
||||
"name": checkpoint.name_for_extra,
|
||||
"filename": path,
|
||||
"preview": self.find_preview(path),
|
||||
"description": self.find_description(path),
|
||||
"search_term": self.search_terms_from_path(checkpoint.filename) + " " + (checkpoint.sha256 or ""),
|
||||
"onclick": '"' + html.escape(f"""return selectCheckpoint({json.dumps(name)})""") + '"',
|
||||
"local_preview": f"{path}.{shared.opts.samples_format}",
|
||||
"sort_keys": {'default': index, **self.get_sort_keys(checkpoint.filename)},
|
||||
def create_item(self, name, index=None):
|
||||
checkpoint: sd_models.CheckpointInfo = sd_models.checkpoint_aliases.get(name)
|
||||
path, ext = os.path.splitext(checkpoint.filename)
|
||||
return {
|
||||
"name": checkpoint.name_for_extra,
|
||||
"filename": checkpoint.filename,
|
||||
"preview": self.find_preview(path),
|
||||
"description": self.find_description(path),
|
||||
"search_term": self.search_terms_from_path(checkpoint.filename) + " " + (checkpoint.sha256 or ""),
|
||||
"onclick": '"' + html.escape(f"""return selectCheckpoint({quote_js(name)})""") + '"',
|
||||
"local_preview": f"{path}.{shared.opts.samples_format}",
|
||||
"sort_keys": {'default': index, **self.get_sort_keys(checkpoint.filename)},
|
||||
}
|
||||
|
||||
}
|
||||
def list_items(self):
|
||||
for index, name in enumerate(sd_models.checkpoints_list):
|
||||
yield self.create_item(name, index)
|
||||
|
||||
def allowed_directories_for_previews(self):
|
||||
return [v for v in [shared.cmd_opts.ckpt_dir, sd_models.model_path] if v is not None]
|
||||
|
||||
@@ -1,7 +1,7 @@
|
||||
import json
|
||||
import os
|
||||
|
||||
from modules import shared, ui_extra_networks
|
||||
from modules.ui_extra_networks import quote_js
|
||||
|
||||
|
||||
class ExtraNetworksPageHypernetworks(ui_extra_networks.ExtraNetworksPage):
|
||||
@@ -11,21 +11,24 @@ class ExtraNetworksPageHypernetworks(ui_extra_networks.ExtraNetworksPage):
|
||||
def refresh(self):
|
||||
shared.reload_hypernetworks()
|
||||
|
||||
def create_item(self, name, index=None):
|
||||
full_path = shared.hypernetworks[name]
|
||||
path, ext = os.path.splitext(full_path)
|
||||
|
||||
return {
|
||||
"name": name,
|
||||
"filename": full_path,
|
||||
"preview": self.find_preview(path),
|
||||
"description": self.find_description(path),
|
||||
"search_term": self.search_terms_from_path(path),
|
||||
"prompt": quote_js(f"<hypernet:{name}:") + " + opts.extra_networks_default_multiplier + " + quote_js(">"),
|
||||
"local_preview": f"{path}.preview.{shared.opts.samples_format}",
|
||||
"sort_keys": {'default': index, **self.get_sort_keys(path + ext)},
|
||||
}
|
||||
|
||||
def list_items(self):
|
||||
for index, (name, path) in enumerate(shared.hypernetworks.items()):
|
||||
path, ext = os.path.splitext(path)
|
||||
|
||||
yield {
|
||||
"name": name,
|
||||
"filename": path,
|
||||
"preview": self.find_preview(path),
|
||||
"description": self.find_description(path),
|
||||
"search_term": self.search_terms_from_path(path),
|
||||
"prompt": json.dumps(f"<hypernet:{name}:") + " + opts.extra_networks_default_multiplier + " + json.dumps(">"),
|
||||
"local_preview": f"{path}.preview.{shared.opts.samples_format}",
|
||||
"sort_keys": {'default': index, **self.get_sort_keys(path + ext)},
|
||||
|
||||
}
|
||||
for index, name in enumerate(shared.hypernetworks):
|
||||
yield self.create_item(name, index)
|
||||
|
||||
def allowed_directories_for_previews(self):
|
||||
return [shared.cmd_opts.hypernetwork_dir]
|
||||
|
||||
@@ -1,7 +1,7 @@
|
||||
import json
|
||||
import os
|
||||
|
||||
from modules import ui_extra_networks, sd_hijack, shared
|
||||
from modules.ui_extra_networks import quote_js
|
||||
|
||||
|
||||
class ExtraNetworksPageTextualInversion(ui_extra_networks.ExtraNetworksPage):
|
||||
@@ -12,20 +12,24 @@ class ExtraNetworksPageTextualInversion(ui_extra_networks.ExtraNetworksPage):
|
||||
def refresh(self):
|
||||
sd_hijack.model_hijack.embedding_db.load_textual_inversion_embeddings(force_reload=True)
|
||||
|
||||
def list_items(self):
|
||||
for index, embedding in enumerate(sd_hijack.model_hijack.embedding_db.word_embeddings.values()):
|
||||
path, ext = os.path.splitext(embedding.filename)
|
||||
yield {
|
||||
"name": embedding.name,
|
||||
"filename": embedding.filename,
|
||||
"preview": self.find_preview(path),
|
||||
"description": self.find_description(path),
|
||||
"search_term": self.search_terms_from_path(embedding.filename),
|
||||
"prompt": json.dumps(embedding.name),
|
||||
"local_preview": f"{path}.preview.{shared.opts.samples_format}",
|
||||
"sort_keys": {'default': index, **self.get_sort_keys(embedding.filename)},
|
||||
def create_item(self, name, index=None):
|
||||
embedding = sd_hijack.model_hijack.embedding_db.word_embeddings.get(name)
|
||||
|
||||
}
|
||||
path, ext = os.path.splitext(embedding.filename)
|
||||
return {
|
||||
"name": name,
|
||||
"filename": embedding.filename,
|
||||
"preview": self.find_preview(path),
|
||||
"description": self.find_description(path),
|
||||
"search_term": self.search_terms_from_path(embedding.filename),
|
||||
"prompt": quote_js(embedding.name),
|
||||
"local_preview": f"{path}.preview.{shared.opts.samples_format}",
|
||||
"sort_keys": {'default': index, **self.get_sort_keys(embedding.filename)},
|
||||
}
|
||||
|
||||
def list_items(self):
|
||||
for index, name in enumerate(sd_hijack.model_hijack.embedding_db.word_embeddings):
|
||||
yield self.create_item(name, index)
|
||||
|
||||
def allowed_directories_for_previews(self):
|
||||
return list(sd_hijack.model_hijack.embedding_db.embedding_dirs)
|
||||
|
||||
195
modules/ui_extra_networks_user_metadata.py
Normal file
195
modules/ui_extra_networks_user_metadata.py
Normal file
@@ -0,0 +1,195 @@
|
||||
import datetime
|
||||
import html
|
||||
import json
|
||||
import os.path
|
||||
|
||||
import gradio as gr
|
||||
|
||||
from modules import generation_parameters_copypaste, images, sysinfo, errors
|
||||
|
||||
|
||||
class UserMetadataEditor:
|
||||
|
||||
def __init__(self, ui, tabname, page):
|
||||
self.ui = ui
|
||||
self.tabname = tabname
|
||||
self.page = page
|
||||
self.id_part = f"{self.tabname}_{self.page.id_page}_edit_user_metadata"
|
||||
|
||||
self.box = None
|
||||
|
||||
self.edit_name_input = None
|
||||
self.button_edit = None
|
||||
|
||||
self.edit_name = None
|
||||
self.edit_description = None
|
||||
self.edit_notes = None
|
||||
self.html_filedata = None
|
||||
self.html_preview = None
|
||||
self.html_status = None
|
||||
|
||||
self.button_cancel = None
|
||||
self.button_replace_preview = None
|
||||
self.button_save = None
|
||||
|
||||
def get_user_metadata(self, name):
|
||||
item = self.page.items.get(name, {})
|
||||
|
||||
user_metadata = item.get('user_metadata', None)
|
||||
if user_metadata is None:
|
||||
user_metadata = {}
|
||||
item['user_metadata'] = user_metadata
|
||||
|
||||
return user_metadata
|
||||
|
||||
def create_extra_default_items_in_left_column(self):
|
||||
pass
|
||||
|
||||
def create_default_editor_elems(self):
|
||||
with gr.Row():
|
||||
with gr.Column(scale=2):
|
||||
self.edit_name = gr.HTML(elem_classes="extra-network-name")
|
||||
self.edit_description = gr.Textbox(label="Description", lines=4)
|
||||
self.html_filedata = gr.HTML()
|
||||
|
||||
self.create_extra_default_items_in_left_column()
|
||||
|
||||
with gr.Column(scale=1, min_width=0):
|
||||
self.html_preview = gr.HTML()
|
||||
|
||||
def create_default_buttons(self):
|
||||
|
||||
with gr.Row(elem_classes="edit-user-metadata-buttons"):
|
||||
self.button_cancel = gr.Button('Cancel')
|
||||
self.button_replace_preview = gr.Button('Replace preview', variant='primary')
|
||||
self.button_save = gr.Button('Save', variant='primary')
|
||||
|
||||
self.html_status = gr.HTML(elem_classes="edit-user-metadata-status")
|
||||
|
||||
self.button_cancel.click(fn=None, _js="closePopup")
|
||||
|
||||
def get_card_html(self, name):
|
||||
item = self.page.items.get(name, {})
|
||||
|
||||
preview_url = item.get("preview", None)
|
||||
|
||||
if not preview_url:
|
||||
filename, _ = os.path.splitext(item["filename"])
|
||||
preview_url = self.page.find_preview(filename)
|
||||
item["preview"] = preview_url
|
||||
|
||||
if preview_url:
|
||||
preview = f'''
|
||||
<div class='card standalone-card-preview'>
|
||||
<img src="{html.escape(preview_url)}" class="preview">
|
||||
</div>
|
||||
'''
|
||||
else:
|
||||
preview = "<div class='card standalone-card-preview'></div>"
|
||||
|
||||
return preview
|
||||
|
||||
def get_metadata_table(self, name):
|
||||
item = self.page.items.get(name, {})
|
||||
try:
|
||||
filename = item["filename"]
|
||||
|
||||
stats = os.stat(filename)
|
||||
params = [
|
||||
('File size: ', sysinfo.pretty_bytes(stats.st_size)),
|
||||
('Modified: ', datetime.datetime.fromtimestamp(stats.st_mtime).strftime('%Y-%m-%d %H:%M')),
|
||||
]
|
||||
|
||||
return params
|
||||
except Exception as e:
|
||||
errors.display(e, f"reading info for {name}")
|
||||
return []
|
||||
|
||||
def put_values_into_components(self, name):
|
||||
user_metadata = self.get_user_metadata(name)
|
||||
|
||||
try:
|
||||
params = self.get_metadata_table(name)
|
||||
except Exception as e:
|
||||
errors.display(e, f"reading metadata info for {name}")
|
||||
params = []
|
||||
|
||||
table = '<table class="file-metadata">' + "".join(f"<tr><th>{name}</th><td>{value}</td></tr>" for name, value in params) + '</table>'
|
||||
|
||||
return html.escape(name), user_metadata.get('description', ''), table, self.get_card_html(name), user_metadata.get('notes', '')
|
||||
|
||||
def write_user_metadata(self, name, metadata):
|
||||
item = self.page.items.get(name, {})
|
||||
filename = item.get("filename", None)
|
||||
basename, ext = os.path.splitext(filename)
|
||||
|
||||
with open(basename + '.json', "w", encoding="utf8") as file:
|
||||
json.dump(metadata, file)
|
||||
|
||||
def save_user_metadata(self, name, desc, notes):
|
||||
user_metadata = self.get_user_metadata(name)
|
||||
user_metadata["description"] = desc
|
||||
user_metadata["notes"] = notes
|
||||
|
||||
self.write_user_metadata(name, user_metadata)
|
||||
|
||||
def setup_save_handler(self, button, func, components):
|
||||
button\
|
||||
.click(fn=func, inputs=[self.edit_name_input, *components], outputs=[])\
|
||||
.then(fn=None, _js="function(name){closePopup(); extraNetworksRefreshSingleCard(" + json.dumps(self.page.name) + "," + json.dumps(self.tabname) + ", name);}", inputs=[self.edit_name_input], outputs=[])
|
||||
|
||||
def create_editor(self):
|
||||
self.create_default_editor_elems()
|
||||
|
||||
self.edit_notes = gr.TextArea(label='Notes', lines=4)
|
||||
|
||||
self.create_default_buttons()
|
||||
|
||||
self.button_edit\
|
||||
.click(fn=self.put_values_into_components, inputs=[self.edit_name_input], outputs=[self.edit_name, self.edit_description, self.html_filedata, self.html_preview, self.edit_notes])\
|
||||
.then(fn=lambda: gr.update(visible=True), inputs=[], outputs=[self.box])
|
||||
|
||||
self.setup_save_handler(self.button_save, self.save_user_metadata, [self.edit_description, self.edit_notes])
|
||||
|
||||
def create_ui(self):
|
||||
with gr.Box(visible=False, elem_id=self.id_part, elem_classes="edit-user-metadata") as box:
|
||||
self.box = box
|
||||
|
||||
self.edit_name_input = gr.Textbox("Edit user metadata card id", visible=False, elem_id=f"{self.id_part}_name")
|
||||
self.button_edit = gr.Button("Edit user metadata", visible=False, elem_id=f"{self.id_part}_button")
|
||||
|
||||
self.create_editor()
|
||||
|
||||
def save_preview(self, index, gallery, name):
|
||||
if len(gallery) == 0:
|
||||
return self.get_card_html(name), "There is no image in gallery to save as a preview."
|
||||
|
||||
item = self.page.items.get(name, {})
|
||||
|
||||
index = int(index)
|
||||
index = 0 if index < 0 else index
|
||||
index = len(gallery) - 1 if index >= len(gallery) else index
|
||||
|
||||
img_info = gallery[index if index >= 0 else 0]
|
||||
image = generation_parameters_copypaste.image_from_url_text(img_info)
|
||||
geninfo, items = images.read_info_from_image(image)
|
||||
|
||||
images.save_image_with_geninfo(image, geninfo, item["local_preview"])
|
||||
|
||||
return self.get_card_html(name), ''
|
||||
|
||||
def setup_ui(self, gallery):
|
||||
self.button_replace_preview.click(
|
||||
fn=self.save_preview,
|
||||
_js="function(x, y, z){return [selected_gallery_index(), y, z]}",
|
||||
inputs=[self.edit_name_input, gallery, self.edit_name_input],
|
||||
outputs=[self.html_preview, self.html_status]
|
||||
).then(
|
||||
fn=None,
|
||||
_js="function(name){extraNetworksRefreshSingleCard(" + json.dumps(self.page.name) + "," + json.dumps(self.tabname) + ", name);}",
|
||||
inputs=[self.edit_name_input],
|
||||
outputs=[]
|
||||
)
|
||||
|
||||
|
||||
|
||||
@@ -260,13 +260,20 @@ class UiSettings:
|
||||
component = self.component_dict[k]
|
||||
info = opts.data_labels[k]
|
||||
|
||||
change_handler = component.release if hasattr(component, 'release') else component.change
|
||||
change_handler(
|
||||
fn=lambda value, k=k: self.run_settings_single(value, key=k),
|
||||
inputs=[component],
|
||||
outputs=[component, self.text_settings],
|
||||
show_progress=info.refresh is not None,
|
||||
)
|
||||
if isinstance(component, gr.Textbox):
|
||||
methods = [component.submit, component.blur]
|
||||
elif hasattr(component, 'release'):
|
||||
methods = [component.release]
|
||||
else:
|
||||
methods = [component.change]
|
||||
|
||||
for method in methods:
|
||||
method(
|
||||
fn=lambda value, k=k: self.run_settings_single(value, key=k),
|
||||
inputs=[component],
|
||||
outputs=[component, self.text_settings],
|
||||
show_progress=info.refresh is not None,
|
||||
)
|
||||
|
||||
button_set_checkpoint = gr.Button('Change checkpoint', elem_id='change_checkpoint', visible=False)
|
||||
button_set_checkpoint.click(
|
||||
|
||||
Reference in New Issue
Block a user