mirror of https://github.com/AUTOMATIC1111/stable-diffusion-webui.git, synced 2026-02-28 10:44:19 +00:00

Merge branch 'dev' into patch-1

BIN modules/Roboto-Regular.ttf (new file; binary file not shown)
--- a/modules/api/api.py
+++ b/modules/api/api.py
@@ -14,32 +14,31 @@ from fastapi.encoders import jsonable_encoder
 from secrets import compare_digest
 
 import modules.shared as shared
-from modules import sd_samplers, deepbooru, sd_hijack, images, scripts, ui, postprocessing
-from modules.api.models import *
+from modules import sd_samplers, deepbooru, sd_hijack, images, scripts, ui, postprocessing, errors, restart
+from modules.api import models
 from modules.shared import opts
 from modules.processing import StableDiffusionProcessingTxt2Img, StableDiffusionProcessingImg2Img, process_images
 from modules.textual_inversion.textual_inversion import create_embedding, train_embedding
 from modules.textual_inversion.preprocess import preprocess
 from modules.hypernetworks.hypernetwork import create_hypernetwork, train_hypernetwork
 from PIL import PngImagePlugin,Image
-from modules.sd_models import checkpoints_list, unload_model_weights, reload_model_weights
+from modules.sd_models import checkpoints_list, unload_model_weights, reload_model_weights, checkpoint_alisases
+from modules.sd_vae import vae_dict
 from modules.sd_models_config import find_checkpoint_config_near_filename
 from modules.realesrgan_model import get_realesrgan_models
 from modules import devices
-from typing import List
+from typing import Dict, List, Any
 import piexif
 import piexif.helper
+from contextlib import closing
 
-def upscaler_to_index(name: str):
-    try:
-        return [x.name.lower() for x in shared.sd_upscalers].index(name.lower())
-    except:
-        raise HTTPException(status_code=400, detail=f"Invalid upscaler, needs to be one of these: {' , '.join([x.name for x in sd_upscalers])}")
 
 def script_name_to_index(name, scripts):
     try:
         return [script.title().lower() for script in scripts].index(name.lower())
-    except:
-        raise HTTPException(status_code=422, detail=f"Script '{name}' not found")
+    except Exception as e:
+        raise HTTPException(status_code=422, detail=f"Script '{name}' not found") from e
 
 
 def validate_sampler_name(name):
     config = sd_samplers.all_samplers_map.get(name, None)
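The `raise ... from e` adopted in `script_name_to_index` above chains the original lookup failure onto the HTTP error, so server logs show the real cause rather than a bare 422. A minimal, standalone sketch of the pattern (illustrative, not webui code):

```python
from fastapi import HTTPException

def name_to_index(name: str, names: list) -> int:
    try:
        return names.index(name.lower())
    except ValueError as e:
        # "from e" stores the ValueError as __cause__, so tracebacks
        # print both the HTTP error and the failed lookup that caused it.
        raise HTTPException(status_code=422, detail=f"'{name}' not found") from e
```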
@@ -48,20 +47,23 @@ def validate_sampler_name(name):
 
     return name
 
 
 def setUpscalers(req: dict):
     reqDict = vars(req)
     reqDict['extras_upscaler_1'] = reqDict.pop('upscaler_1', None)
     reqDict['extras_upscaler_2'] = reqDict.pop('upscaler_2', None)
     return reqDict
 
 
 def decode_base64_to_image(encoding):
     if encoding.startswith("data:image/"):
         encoding = encoding.split(";")[1].split(",")[1]
     try:
         image = Image.open(BytesIO(base64.b64decode(encoding)))
         return image
-    except Exception as err:
-        raise HTTPException(status_code=500, detail="Invalid encoded image")
+    except Exception as e:
+        raise HTTPException(status_code=500, detail="Invalid encoded image") from e
 
 
 def encode_pil_to_base64(image):
     with io.BytesIO() as output_bytes:
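`decode_base64_to_image` above accepts either a raw base64 string or a full data URL, peeling off the `data:image/...;base64,` prefix first. A hedged round-trip sketch (`example.png` is a hypothetical input file):

```python
import base64
from io import BytesIO
from PIL import Image

with open("example.png", "rb") as f:  # hypothetical input file
    encoded = "data:image/png;base64," + base64.b64encode(f.read()).decode()

# mirror of the split(";")[1].split(",")[1] prefix-stripping above
payload = encoded.split(";")[1].split(",")[1]
image = Image.open(BytesIO(base64.b64decode(payload)))
print(image.size)
```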
@@ -76,6 +78,8 @@ def encode_pil_to_base64(image):
             image.save(output_bytes, format="PNG", pnginfo=(metadata if use_metadata else None), quality=opts.jpeg_quality)
 
         elif opts.samples_format.lower() in ("jpg", "jpeg", "webp"):
+            if image.mode == "RGBA":
+                image = image.convert("RGB")
             parameters = image.info.get('parameters', None)
             exif_bytes = piexif.dump({
                 "Exif": { piexif.ExifIFD.UserComment: piexif.helper.UserComment.dump(parameters or "", encoding="unicode") }
@@ -92,6 +96,7 @@ def encode_pil_to_base64(image):
 
     return base64.b64encode(bytes_data)
 
+
 def api_middleware(app: FastAPI):
     rich_available = True
     try:
@@ -99,8 +104,7 @@
         import starlette # importing just so it can be placed on silent list
         from rich.console import Console
         console = Console()
-    except:
-        import traceback
+    except Exception:
         rich_available = False
 
     @app.middleware("http")
@@ -131,11 +135,12 @@
             "errors": str(e),
         }
         if not isinstance(e, HTTPException): # do not print backtrace on known httpexceptions
-            print(f"API error: {request.method}: {request.url} {err}")
+            message = f"API error: {request.method}: {request.url} {err}"
             if rich_available:
+                print(message)
                 console.print_exception(show_locals=True, max_frames=2, extra_lines=1, suppress=[anyio, starlette], word_wrap=False, width=min([console.width, 200]))
             else:
-                traceback.print_exc()
+                errors.report(message, exc_info=True)
         return JSONResponse(status_code=vars(e).get('status_code', 500), content=jsonable_encoder(err))
 
     @app.middleware("http")
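The middleware's non-rich fallback now routes through `errors.report` instead of `traceback.print_exc()`. A minimal sketch of the helper's assumed shape (the real one lives in modules/errors.py; this is an assumption for illustration):

```python
import sys
import traceback

def report(message: str, *, exc_info: bool = False):
    # assumed shape: log the message to stderr and, when exc_info is set,
    # the traceback of the exception currently being handled
    print(message, file=sys.stderr)
    if exc_info:
        print(traceback.format_exc(), file=sys.stderr)
```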
@@ -157,7 +162,7 @@
 class Api:
     def __init__(self, app: FastAPI, queue_lock: Lock):
         if shared.cmd_opts.api_auth:
-            self.credentials = dict()
+            self.credentials = {}
             for auth in shared.cmd_opts.api_auth.split(","):
                 user, password = auth.split(":")
                 self.credentials[user] = password
@@ -166,36 +171,44 @@ class Api:
         self.app = app
         self.queue_lock = queue_lock
         api_middleware(self.app)
-        self.add_api_route("/sdapi/v1/txt2img", self.text2imgapi, methods=["POST"], response_model=TextToImageResponse)
-        self.add_api_route("/sdapi/v1/img2img", self.img2imgapi, methods=["POST"], response_model=ImageToImageResponse)
-        self.add_api_route("/sdapi/v1/extra-single-image", self.extras_single_image_api, methods=["POST"], response_model=ExtrasSingleImageResponse)
-        self.add_api_route("/sdapi/v1/extra-batch-images", self.extras_batch_images_api, methods=["POST"], response_model=ExtrasBatchImagesResponse)
-        self.add_api_route("/sdapi/v1/png-info", self.pnginfoapi, methods=["POST"], response_model=PNGInfoResponse)
-        self.add_api_route("/sdapi/v1/progress", self.progressapi, methods=["GET"], response_model=ProgressResponse)
+        self.add_api_route("/sdapi/v1/txt2img", self.text2imgapi, methods=["POST"], response_model=models.TextToImageResponse)
+        self.add_api_route("/sdapi/v1/img2img", self.img2imgapi, methods=["POST"], response_model=models.ImageToImageResponse)
+        self.add_api_route("/sdapi/v1/extra-single-image", self.extras_single_image_api, methods=["POST"], response_model=models.ExtrasSingleImageResponse)
+        self.add_api_route("/sdapi/v1/extra-batch-images", self.extras_batch_images_api, methods=["POST"], response_model=models.ExtrasBatchImagesResponse)
+        self.add_api_route("/sdapi/v1/png-info", self.pnginfoapi, methods=["POST"], response_model=models.PNGInfoResponse)
+        self.add_api_route("/sdapi/v1/progress", self.progressapi, methods=["GET"], response_model=models.ProgressResponse)
         self.add_api_route("/sdapi/v1/interrogate", self.interrogateapi, methods=["POST"])
         self.add_api_route("/sdapi/v1/interrupt", self.interruptapi, methods=["POST"])
         self.add_api_route("/sdapi/v1/skip", self.skip, methods=["POST"])
-        self.add_api_route("/sdapi/v1/options", self.get_config, methods=["GET"], response_model=OptionsModel)
+        self.add_api_route("/sdapi/v1/options", self.get_config, methods=["GET"], response_model=models.OptionsModel)
         self.add_api_route("/sdapi/v1/options", self.set_config, methods=["POST"])
-        self.add_api_route("/sdapi/v1/cmd-flags", self.get_cmd_flags, methods=["GET"], response_model=FlagsModel)
-        self.add_api_route("/sdapi/v1/samplers", self.get_samplers, methods=["GET"], response_model=List[SamplerItem])
-        self.add_api_route("/sdapi/v1/upscalers", self.get_upscalers, methods=["GET"], response_model=List[UpscalerItem])
-        self.add_api_route("/sdapi/v1/sd-models", self.get_sd_models, methods=["GET"], response_model=List[SDModelItem])
-        self.add_api_route("/sdapi/v1/hypernetworks", self.get_hypernetworks, methods=["GET"], response_model=List[HypernetworkItem])
-        self.add_api_route("/sdapi/v1/face-restorers", self.get_face_restorers, methods=["GET"], response_model=List[FaceRestorerItem])
-        self.add_api_route("/sdapi/v1/realesrgan-models", self.get_realesrgan_models, methods=["GET"], response_model=List[RealesrganItem])
-        self.add_api_route("/sdapi/v1/prompt-styles", self.get_prompt_styles, methods=["GET"], response_model=List[PromptStyleItem])
-        self.add_api_route("/sdapi/v1/embeddings", self.get_embeddings, methods=["GET"], response_model=EmbeddingsResponse)
+        self.add_api_route("/sdapi/v1/cmd-flags", self.get_cmd_flags, methods=["GET"], response_model=models.FlagsModel)
+        self.add_api_route("/sdapi/v1/samplers", self.get_samplers, methods=["GET"], response_model=List[models.SamplerItem])
+        self.add_api_route("/sdapi/v1/upscalers", self.get_upscalers, methods=["GET"], response_model=List[models.UpscalerItem])
+        self.add_api_route("/sdapi/v1/latent-upscale-modes", self.get_latent_upscale_modes, methods=["GET"], response_model=List[models.LatentUpscalerModeItem])
+        self.add_api_route("/sdapi/v1/sd-models", self.get_sd_models, methods=["GET"], response_model=List[models.SDModelItem])
+        self.add_api_route("/sdapi/v1/sd-vae", self.get_sd_vaes, methods=["GET"], response_model=List[models.SDVaeItem])
+        self.add_api_route("/sdapi/v1/hypernetworks", self.get_hypernetworks, methods=["GET"], response_model=List[models.HypernetworkItem])
+        self.add_api_route("/sdapi/v1/face-restorers", self.get_face_restorers, methods=["GET"], response_model=List[models.FaceRestorerItem])
+        self.add_api_route("/sdapi/v1/realesrgan-models", self.get_realesrgan_models, methods=["GET"], response_model=List[models.RealesrganItem])
+        self.add_api_route("/sdapi/v1/prompt-styles", self.get_prompt_styles, methods=["GET"], response_model=List[models.PromptStyleItem])
+        self.add_api_route("/sdapi/v1/embeddings", self.get_embeddings, methods=["GET"], response_model=models.EmbeddingsResponse)
         self.add_api_route("/sdapi/v1/refresh-checkpoints", self.refresh_checkpoints, methods=["POST"])
-        self.add_api_route("/sdapi/v1/create/embedding", self.create_embedding, methods=["POST"], response_model=CreateResponse)
-        self.add_api_route("/sdapi/v1/create/hypernetwork", self.create_hypernetwork, methods=["POST"], response_model=CreateResponse)
-        self.add_api_route("/sdapi/v1/preprocess", self.preprocess, methods=["POST"], response_model=PreprocessResponse)
-        self.add_api_route("/sdapi/v1/train/embedding", self.train_embedding, methods=["POST"], response_model=TrainResponse)
-        self.add_api_route("/sdapi/v1/train/hypernetwork", self.train_hypernetwork, methods=["POST"], response_model=TrainResponse)
-        self.add_api_route("/sdapi/v1/memory", self.get_memory, methods=["GET"], response_model=MemoryResponse)
+        self.add_api_route("/sdapi/v1/create/embedding", self.create_embedding, methods=["POST"], response_model=models.CreateResponse)
+        self.add_api_route("/sdapi/v1/create/hypernetwork", self.create_hypernetwork, methods=["POST"], response_model=models.CreateResponse)
+        self.add_api_route("/sdapi/v1/preprocess", self.preprocess, methods=["POST"], response_model=models.PreprocessResponse)
+        self.add_api_route("/sdapi/v1/train/embedding", self.train_embedding, methods=["POST"], response_model=models.TrainResponse)
+        self.add_api_route("/sdapi/v1/train/hypernetwork", self.train_hypernetwork, methods=["POST"], response_model=models.TrainResponse)
+        self.add_api_route("/sdapi/v1/memory", self.get_memory, methods=["GET"], response_model=models.MemoryResponse)
         self.add_api_route("/sdapi/v1/unload-checkpoint", self.unloadapi, methods=["POST"])
         self.add_api_route("/sdapi/v1/reload-checkpoint", self.reloadapi, methods=["POST"])
-        self.add_api_route("/sdapi/v1/scripts", self.get_scripts_list, methods=["GET"], response_model=ScriptsList)
+        self.add_api_route("/sdapi/v1/scripts", self.get_scripts_list, methods=["GET"], response_model=models.ScriptsList)
+        self.add_api_route("/sdapi/v1/script-info", self.get_script_info, methods=["GET"], response_model=List[models.ScriptInfo])
+
+        if shared.cmd_opts.api_server_stop:
+            self.add_api_route("/sdapi/v1/server-kill", self.kill_webui, methods=["POST"])
+            self.add_api_route("/sdapi/v1/server-restart", self.restart_webui, methods=["POST"])
+            self.add_api_route("/sdapi/v1/server-stop", self.stop_webui, methods=["POST"])
 
         self.default_script_arg_txt2img = []
         self.default_script_arg_img2img = []
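The routes keep their paths; only the `response_model` references move under `models.*`, so existing clients are unaffected. A hedged smoke test against a local instance started with `--api` (address and payload are assumptions):

```python
import requests

resp = requests.post(
    "http://127.0.0.1:7860/sdapi/v1/txt2img",
    json={"prompt": "a lighthouse at dusk", "steps": 20},
)
data = resp.json()
print(len(data["images"]))  # base64 PNGs, per models.TextToImageResponse
print(data["info"])         # generation parameters as a JSON string
```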
@@ -219,17 +232,25 @@
             script_idx = script_name_to_index(script_name, script_runner.selectable_scripts)
             script = script_runner.selectable_scripts[script_idx]
             return script, script_idx
 
-    def get_scripts_list(self):
-        t2ilist = [str(title.lower()) for title in scripts.scripts_txt2img.titles]
-        i2ilist = [str(title.lower()) for title in scripts.scripts_img2img.titles]
-
-        return ScriptsList(txt2img = t2ilist, img2img = i2ilist)
+    def get_scripts_list(self):
+        t2ilist = [script.name for script in scripts.scripts_txt2img.scripts if script.name is not None]
+        i2ilist = [script.name for script in scripts.scripts_img2img.scripts if script.name is not None]
+
+        return models.ScriptsList(txt2img=t2ilist, img2img=i2ilist)
+
+    def get_script_info(self):
+        res = []
+
+        for script_list in [scripts.scripts_txt2img.scripts, scripts.scripts_img2img.scripts]:
+            res += [script.api_info for script in script_list if script.api_info is not None]
+
+        return res
 
     def get_script(self, script_name, script_runner):
         if script_name is None or script_name == "":
             return None, None
 
         script_idx = script_name_to_index(script_name, script_runner.scripts)
         return script_runner.scripts[script_idx]
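`get_script_info` collects the `api_info` of every loaded script, so the new `/sdapi/v1/script-info` route describes each script's arguments to clients. A hedged query sketch (assumes a local instance with `--api`):

```python
import requests

for info in requests.get("http://127.0.0.1:7860/sdapi/v1/script-info").json():
    # fields follow models.ScriptInfo: name, is_alwayson, is_img2img, args
    kind = "alwayson" if info["is_alwayson"] else "selectable"
    print(f"{info['name']} ({kind}), {len(info['args'])} args")
```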
@@ -261,14 +282,14 @@
             script_args[0] = selectable_idx + 1
 
         # Now check for always on scripts
-        if request.alwayson_scripts and (len(request.alwayson_scripts) > 0):
+        if request.alwayson_scripts:
             for alwayson_script_name in request.alwayson_scripts.keys():
                 alwayson_script = self.get_script(alwayson_script_name, script_runner)
-                if alwayson_script == None:
+                if alwayson_script is None:
                     raise HTTPException(status_code=422, detail=f"always on script {alwayson_script_name} not found")
                 # Selectable script in always on script param check
-                if alwayson_script.alwayson == False:
-                    raise HTTPException(status_code=422, detail=f"Cannot have a selectable script in the always on scripts params")
+                if alwayson_script.alwayson is False:
+                    raise HTTPException(status_code=422, detail="Cannot have a selectable script in the always on scripts params")
                 # always on script with no arg should always run so you don't really need to add them to the requests
                 if "args" in request.alwayson_scripts[alwayson_script_name]:
                     # min between arg length in scriptrunner and arg length in the request
@@ -276,7 +297,7 @@
                     script_args[alwayson_script.args_from + idx] = request.alwayson_scripts[alwayson_script_name]["args"][idx]
         return script_args
 
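The `args` list of each alwayson script is copied positionally into `script_args` starting at that script's `args_from` offset. A hedged request sketch ("ControlNet" stands in for whatever alwayson script is actually installed):

```python
import requests

payload = {
    "prompt": "a castle on a hill",
    "alwayson_scripts": {
        # hypothetical script name; args are positional, matching the
        # slice script_args[args_from:args_to] of the target script
        "ControlNet": {"args": []},
    },
}
requests.post("http://127.0.0.1:7860/sdapi/v1/txt2img", json=payload)
```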
-    def text2imgapi(self, txt2imgreq: StableDiffusionTxt2ImgProcessingAPI):
+    def text2imgapi(self, txt2imgreq: models.StableDiffusionTxt2ImgProcessingAPI):
         script_runner = scripts.scripts_txt2img
         if not script_runner.scripts:
             script_runner.initialize_scripts(False)
@@ -304,25 +325,25 @@
         args.pop('save_images', None)
 
         with self.queue_lock:
-            p = StableDiffusionProcessingTxt2Img(sd_model=shared.sd_model, **args)
-            p.scripts = script_runner
-            p.outpath_grids = opts.outdir_txt2img_grids
-            p.outpath_samples = opts.outdir_txt2img_samples
+            with closing(StableDiffusionProcessingTxt2Img(sd_model=shared.sd_model, **args)) as p:
+                p.scripts = script_runner
+                p.outpath_grids = opts.outdir_txt2img_grids
+                p.outpath_samples = opts.outdir_txt2img_samples
 
-            shared.state.begin()
-            if selectable_scripts != None:
-                p.script_args = script_args
-                processed = scripts.scripts_txt2img.run(p, *p.script_args) # Need to pass args as list here
-            else:
-                p.script_args = tuple(script_args) # Need to pass args as tuple here
-                processed = process_images(p)
-            shared.state.end()
+                shared.state.begin(job="scripts_txt2img")
+                if selectable_scripts is not None:
+                    p.script_args = script_args
+                    processed = scripts.scripts_txt2img.run(p, *p.script_args) # Need to pass args as list here
+                else:
+                    p.script_args = tuple(script_args) # Need to pass args as tuple here
+                    processed = process_images(p)
+                shared.state.end()
 
         b64images = list(map(encode_pil_to_base64, processed.images)) if send_images else []
 
-        return TextToImageResponse(images=b64images, parameters=vars(txt2imgreq), info=processed.js())
+        return models.TextToImageResponse(images=b64images, parameters=vars(txt2imgreq), info=processed.js())
 
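Wrapping the processing object in `contextlib.closing` guarantees its `close()` runs whether generation succeeds or raises. A standalone sketch of that guarantee (the `Processing` class is a stand-in, not the real one):

```python
from contextlib import closing

class Processing:  # stand-in for StableDiffusionProcessingTxt2Img
    def close(self):
        print("resources released")

try:
    with closing(Processing()) as p:
        raise RuntimeError("generation failed")  # close() still runs
except RuntimeError:
    pass
```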
-    def img2imgapi(self, img2imgreq: StableDiffusionImg2ImgProcessingAPI):
+    def img2imgapi(self, img2imgreq: models.StableDiffusionImg2ImgProcessingAPI):
         init_images = img2imgreq.init_images
         if init_images is None:
             raise HTTPException(status_code=404, detail="Init image not found")
@@ -360,20 +381,20 @@
         args.pop('save_images', None)
 
         with self.queue_lock:
-            p = StableDiffusionProcessingImg2Img(sd_model=shared.sd_model, **args)
-            p.init_images = [decode_base64_to_image(x) for x in init_images]
-            p.scripts = script_runner
-            p.outpath_grids = opts.outdir_img2img_grids
-            p.outpath_samples = opts.outdir_img2img_samples
+            with closing(StableDiffusionProcessingImg2Img(sd_model=shared.sd_model, **args)) as p:
+                p.init_images = [decode_base64_to_image(x) for x in init_images]
+                p.scripts = script_runner
+                p.outpath_grids = opts.outdir_img2img_grids
+                p.outpath_samples = opts.outdir_img2img_samples
 
-            shared.state.begin()
-            if selectable_scripts != None:
-                p.script_args = script_args
-                processed = scripts.scripts_img2img.run(p, *p.script_args) # Need to pass args as list here
-            else:
-                p.script_args = tuple(script_args) # Need to pass args as tuple here
-                processed = process_images(p)
-            shared.state.end()
+                shared.state.begin(job="scripts_img2img")
+                if selectable_scripts is not None:
+                    p.script_args = script_args
+                    processed = scripts.scripts_img2img.run(p, *p.script_args) # Need to pass args as list here
+                else:
+                    p.script_args = tuple(script_args) # Need to pass args as tuple here
+                    processed = process_images(p)
+                shared.state.end()
 
         b64images = list(map(encode_pil_to_base64, processed.images)) if send_images else []
 
@@ -381,9 +402,9 @@
         img2imgreq.init_images = None
         img2imgreq.mask = None
 
-        return ImageToImageResponse(images=b64images, parameters=vars(img2imgreq), info=processed.js())
+        return models.ImageToImageResponse(images=b64images, parameters=vars(img2imgreq), info=processed.js())
 
-    def extras_single_image_api(self, req: ExtrasSingleImageRequest):
+    def extras_single_image_api(self, req: models.ExtrasSingleImageRequest):
         reqDict = setUpscalers(req)
 
         reqDict['image'] = decode_base64_to_image(reqDict['image'])
@@ -391,9 +412,9 @@
         with self.queue_lock:
             result = postprocessing.run_extras(extras_mode=0, image_folder="", input_dir="", output_dir="", save_output=False, **reqDict)
 
-        return ExtrasSingleImageResponse(image=encode_pil_to_base64(result[0][0]), html_info=result[1])
+        return models.ExtrasSingleImageResponse(image=encode_pil_to_base64(result[0][0]), html_info=result[1])
 
-    def extras_batch_images_api(self, req: ExtrasBatchImagesRequest):
+    def extras_batch_images_api(self, req: models.ExtrasBatchImagesRequest):
         reqDict = setUpscalers(req)
 
         image_list = reqDict.pop('imageList', [])
@@ -402,15 +423,15 @@
         with self.queue_lock:
             result = postprocessing.run_extras(extras_mode=1, image_folder=image_folder, image="", input_dir="", output_dir="", save_output=False, **reqDict)
 
-        return ExtrasBatchImagesResponse(images=list(map(encode_pil_to_base64, result[0])), html_info=result[1])
+        return models.ExtrasBatchImagesResponse(images=list(map(encode_pil_to_base64, result[0])), html_info=result[1])
 
-    def pnginfoapi(self, req: PNGInfoRequest):
+    def pnginfoapi(self, req: models.PNGInfoRequest):
         if(not req.image.strip()):
-            return PNGInfoResponse(info="")
+            return models.PNGInfoResponse(info="")
 
         image = decode_base64_to_image(req.image.strip())
         if image is None:
-            return PNGInfoResponse(info="")
+            return models.PNGInfoResponse(info="")
 
         geninfo, items = images.read_info_from_image(image)
         if geninfo is None:
@@ -418,13 +439,13 @@
 
         items = {**{'parameters': geninfo}, **items}
 
-        return PNGInfoResponse(info=geninfo, items=items)
+        return models.PNGInfoResponse(info=geninfo, items=items)
 
-    def progressapi(self, req: ProgressRequest = Depends()):
+    def progressapi(self, req: models.ProgressRequest = Depends()):
         # copy from check_progress_call of ui.py
 
         if shared.state.job_count == 0:
-            return ProgressResponse(progress=0, eta_relative=0, state=shared.state.dict(), textinfo=shared.state.textinfo)
+            return models.ProgressResponse(progress=0, eta_relative=0, state=shared.state.dict(), textinfo=shared.state.textinfo)
 
         # avoid dividing zero
         progress = 0.01
@@ -446,9 +467,9 @@
         if shared.state.current_image and not req.skip_current_image:
             current_image = encode_pil_to_base64(shared.state.current_image)
 
-        return ProgressResponse(progress=progress, eta_relative=eta_relative, state=shared.state.dict(), current_image=current_image, textinfo=shared.state.textinfo)
+        return models.ProgressResponse(progress=progress, eta_relative=eta_relative, state=shared.state.dict(), current_image=current_image, textinfo=shared.state.textinfo)
 
-    def interrogateapi(self, interrogatereq: InterrogateRequest):
+    def interrogateapi(self, interrogatereq: models.InterrogateRequest):
         image_b64 = interrogatereq.image
         if image_b64 is None:
             raise HTTPException(status_code=404, detail="Image not found")
@@ -465,7 +486,7 @@
         else:
             raise HTTPException(status_code=404, detail="Model not found")
 
-        return InterrogateResponse(caption=processed)
+        return models.InterrogateResponse(caption=processed)
 
     def interruptapi(self):
         shared.state.interrupt()
@@ -497,6 +518,10 @@
         return options
 
     def set_config(self, req: Dict[str, Any]):
+        checkpoint_name = req.get("sd_model_checkpoint", None)
+        if checkpoint_name is not None and checkpoint_name not in checkpoint_alisases:
+            raise RuntimeError(f"model {checkpoint_name!r} not found")
+
         for k, v in req.items():
             shared.opts.set(k, v)
 
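`set_config` now validates `sd_model_checkpoint` against `checkpoint_alisases` before writing any option, so a typoed checkpoint name fails fast. A hedged client-side sketch (the title string is hypothetical):

```python
import requests

base = "http://127.0.0.1:7860"  # assumes a local instance with --api
opts = {"sd_model_checkpoint": "v1-5-pruned-emaonly.safetensors"}  # hypothetical title
r = requests.post(f"{base}/sdapi/v1/options", json=opts)
print(r.status_code)  # an unknown title now errors before any option is written
```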
@@ -521,9 +546,20 @@
             for upscaler in shared.sd_upscalers
         ]
 
+    def get_latent_upscale_modes(self):
+        return [
+            {
+                "name": upscale_mode,
+            }
+            for upscale_mode in [*(shared.latent_upscale_modes or {})]
+        ]
+
     def get_sd_models(self):
         return [{"title": x.title, "model_name": x.model_name, "hash": x.shorthash, "sha256": x.sha256, "filename": x.filename, "config": find_checkpoint_config_near_filename(x)} for x in checkpoints_list.values()]
 
+    def get_sd_vaes(self):
+        return [{"model_name": x, "filename": vae_dict[x]} for x in vae_dict.keys()]
+
     def get_hypernetworks(self):
         return [{"name": name, "path": shared.hypernetworks[name]} for name in shared.hypernetworks]
 
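The two new read-only listings can be queried directly. A hedged sketch (assumes a local instance with `--api`):

```python
import requests

base = "http://127.0.0.1:7860"
print(requests.get(f"{base}/sdapi/v1/sd-vae").json())
# -> [{"model_name": ..., "filename": ...}, ...] per models.SDVaeItem
print(requests.get(f"{base}/sdapi/v1/latent-upscale-modes").json())
# -> [{"name": ...}, ...] built from shared.latent_upscale_modes
```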
@@ -566,44 +602,42 @@
 
     def create_embedding(self, args: dict):
         try:
-            shared.state.begin()
+            shared.state.begin(job="create_embedding")
             filename = create_embedding(**args) # create empty embedding
             sd_hijack.model_hijack.embedding_db.load_textual_inversion_embeddings() # reload embeddings so new one can be immediately used
-            shared.state.end()
-            return CreateResponse(info=f"create embedding filename: {filename}")
+            return models.CreateResponse(info=f"create embedding filename: {filename}")
         except AssertionError as e:
-            return TrainResponse(info=f"create embedding error: {e}")
+            return models.TrainResponse(info=f"create embedding error: {e}")
+        finally:
+            shared.state.end()
 
     def create_hypernetwork(self, args: dict):
         try:
-            shared.state.begin()
+            shared.state.begin(job="create_hypernetwork")
             filename = create_hypernetwork(**args) # create empty embedding
-            shared.state.end()
-            return CreateResponse(info=f"create hypernetwork filename: {filename}")
+            return models.CreateResponse(info=f"create hypernetwork filename: {filename}")
         except AssertionError as e:
-            return TrainResponse(info=f"create hypernetwork error: {e}")
+            return models.TrainResponse(info=f"create hypernetwork error: {e}")
+        finally:
+            shared.state.end()
 
     def preprocess(self, args: dict):
         try:
-            shared.state.begin()
+            shared.state.begin(job="preprocess")
             preprocess(**args) # quick operation unless blip/booru interrogation is enabled
-            shared.state.end()
-            return PreprocessResponse(info = 'preprocess complete')
+            return models.PreprocessResponse(info='preprocess complete')
         except KeyError as e:
-            return PreprocessResponse(info=f"preprocess error: invalid token: {e}")
-        except AssertionError as e:
-            shared.state.end()
-            return PreprocessResponse(info=f"preprocess error: {e}")
-        except FileNotFoundError as e:
-            shared.state.end()
-            return PreprocessResponse(info=f'preprocess error: {e}')
+            return models.PreprocessResponse(info=f"preprocess error: invalid token: {e}")
+        except Exception as e:
+            return models.PreprocessResponse(info=f"preprocess error: {e}")
+        finally:
+            shared.state.end()
 
     def train_embedding(self, args: dict):
         try:
-            shared.state.begin()
+            shared.state.begin(job="train_embedding")
             apply_optimizations = shared.opts.training_xattention_optimizations
             error = None
             filename = ''
@@ -616,15 +650,15 @@
             finally:
                 if not apply_optimizations:
                     sd_hijack.apply_optimizations()
-                shared.state.end()
-            return TrainResponse(info=f"train embedding complete: filename: {filename} error: {error}")
-        except AssertionError as msg:
-            return TrainResponse(info=f"train embedding error: {msg}")
+            return models.TrainResponse(info=f"train embedding complete: filename: {filename} error: {error}")
+        except Exception as msg:
+            return models.TrainResponse(info=f"train embedding error: {msg}")
+        finally:
+            shared.state.end()
 
     def train_hypernetwork(self, args: dict):
         try:
-            shared.state.begin()
+            shared.state.begin(job="train_hypernetwork")
             shared.loaded_hypernetworks = []
             apply_optimizations = shared.opts.training_xattention_optimizations
             error = None
@@ -641,14 +675,16 @@
                 if not apply_optimizations:
                     sd_hijack.apply_optimizations()
-                shared.state.end()
-            return TrainResponse(info=f"train embedding complete: filename: {filename} error: {error}")
-        except AssertionError as msg:
-            return TrainResponse(info=f"train embedding error: {error}")
+            return models.TrainResponse(info=f"train embedding complete: filename: {filename} error: {error}")
+        except Exception as exc:
+            return models.TrainResponse(info=f"train embedding error: {exc}")
+        finally:
+            shared.state.end()
 
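The recurring reshape in these endpoints moves `shared.state.end()` into a `finally` block so the job state is cleared even when the handler raises. The pattern in isolation (`State` is a stand-in for the webui's `shared.state`):

```python
class State:  # stand-in for modules.shared.state
    def begin(self, job="(unknown)"):
        print(f"begin: {job}")

    def end(self):
        print("end")

state = State()
try:
    state.begin(job="create_embedding")
    raise AssertionError("bad args")  # end() below still runs
except AssertionError as e:
    print(f"error: {e}")
finally:
    state.end()
```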
     def get_memory(self):
         try:
-            import os, psutil
+            import os
+            import psutil
             process = psutil.Process(os.getpid())
             res = process.memory_info() # only rss is cross-platform guaranteed so we dont rely on other values
             ram_total = 100 * res.rss / process.memory_percent() # and total memory is calculated as actual value is not cross-platform safe
@@ -675,11 +711,23 @@
                     'events': warnings,
                 }
             else:
-                cuda = { 'error': 'unavailable' }
+                cuda = {'error': 'unavailable'}
         except Exception as err:
-            cuda = { 'error': f'{err}' }
-        return MemoryResponse(ram = ram, cuda = cuda)
+            cuda = {'error': f'{err}'}
+        return models.MemoryResponse(ram=ram, cuda=cuda)
 
     def launch(self, server_name, port):
         self.app.include_router(self.router)
-        uvicorn.run(self.app, host=server_name, port=port)
+        uvicorn.run(self.app, host=server_name, port=port, timeout_keep_alive=0)
+
+    def kill_webui(self):
+        restart.stop_program()
+
+    def restart_webui(self):
+        if restart.is_restartable():
+            restart.restart_program()
+        return Response(status_code=501)
+
+    def stop_webui(request):
+        shared.state.server_command = "stop"
+        return Response("Stopping.")
 
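`get_memory` estimates total RAM from the two values psutil guarantees cross-platform: since rss / total equals memory_percent() / 100, total equals 100 * rss / percent. The arithmetic in isolation (requires psutil):

```python
import os

import psutil

process = psutil.Process(os.getpid())
rss = process.memory_info().rss        # resident set size, bytes
percent = process.memory_percent()     # rss as a percentage of total RAM
ram_total = 100 * rss / percent        # solve rss/total = percent/100
print(f"estimated total RAM: {ram_total / 2**30:.1f} GiB")
```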
--- a/modules/api/models.py
+++ b/modules/api/models.py
@@ -223,8 +223,9 @@ for key in _options:
     if(_options[key].dest != 'help'):
         flag = _options[key]
         _type = str
-        if _options[key].default is not None: _type = type(_options[key].default)
-        flags.update({flag.dest: (_type,Field(default=flag.default, description=flag.help))})
+        if _options[key].default is not None:
+            _type = type(_options[key].default)
+        flags.update({flag.dest: (_type, Field(default=flag.default, description=flag.help))})
 
 FlagsModel = create_model("Flags", **flags)
 
@@ -240,6 +241,9 @@ class UpscalerItem(BaseModel):
     model_url: Optional[str] = Field(title="URL")
     scale: Optional[float] = Field(title="Scale")
 
+class LatentUpscalerModeItem(BaseModel):
+    name: str = Field(title="Name")
+
 class SDModelItem(BaseModel):
     title: str = Field(title="Title")
     model_name: str = Field(title="Model Name")
@@ -248,6 +252,10 @@ class SDModelItem(BaseModel):
     filename: str = Field(title="Filename")
     config: Optional[str] = Field(title="Config file")
 
+class SDVaeItem(BaseModel):
+    model_name: str = Field(title="Model Name")
+    filename: str = Field(title="Filename")
+
 class HypernetworkItem(BaseModel):
     name: str = Field(title="Name")
     path: Optional[str] = Field(title="Path")
@@ -266,10 +274,6 @@ class PromptStyleItem(BaseModel):
     prompt: Optional[str] = Field(title="Prompt")
     negative_prompt: Optional[str] = Field(title="Negative Prompt")
 
-class ArtistItem(BaseModel):
-    name: str = Field(title="Name")
-    score: float = Field(title="Score")
-    category: str = Field(title="Category")
-
 class EmbeddingItem(BaseModel):
     step: Optional[int] = Field(title="Step", description="The number of steps that were used to train this embedding, if available")
@@ -286,6 +290,23 @@ class MemoryResponse(BaseModel):
     ram: dict = Field(title="RAM", description="System memory stats")
     cuda: dict = Field(title="CUDA", description="nVidia CUDA memory stats")
 
 
 class ScriptsList(BaseModel):
-    txt2img: list = Field(default=None,title="Txt2img", description="Titles of scripts (txt2img)")
-    img2img: list = Field(default=None,title="Img2img", description="Titles of scripts (img2img)")
+    txt2img: list = Field(default=None, title="Txt2img", description="Titles of scripts (txt2img)")
+    img2img: list = Field(default=None, title="Img2img", description="Titles of scripts (img2img)")
+
+
+class ScriptArg(BaseModel):
+    label: str = Field(default=None, title="Label", description="Name of the argument in UI")
+    value: Optional[Any] = Field(default=None, title="Value", description="Default value of the argument")
+    minimum: Optional[Any] = Field(default=None, title="Minimum", description="Minimum allowed value for the argumentin UI")
+    maximum: Optional[Any] = Field(default=None, title="Minimum", description="Maximum allowed value for the argumentin UI")
+    step: Optional[Any] = Field(default=None, title="Minimum", description="Step for changing value of the argumentin UI")
+    choices: Optional[List[str]] = Field(default=None, title="Choices", description="Possible values for the argument")
+
+
+class ScriptInfo(BaseModel):
+    name: str = Field(default=None, title="Name", description="Script name")
+    is_alwayson: bool = Field(default=None, title="IsAlwayson", description="Flag specifying whether this script is an alwayson script")
+    is_img2img: bool = Field(default=None, title="IsImg2img", description="Flag specifying whether this script is an img2img script")
+    args: List[ScriptArg] = Field(title="Arguments", description="List of script's arguments")
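FlagsModel is built dynamically: each argparse flag becomes a `(type, Field(...))` tuple fed to pydantic's `create_model`. A minimal standalone sketch of that pattern:

```python
from pydantic import Field, create_model

flags = {
    # each entry mirrors flags.update({flag.dest: (_type, Field(...))}) above
    "listen": (bool, Field(default=False, description="bind on 0.0.0.0")),
    "port": (int, Field(default=7860, description="server port")),
}
FlagsDemo = create_model("FlagsDemo", **flags)
print(FlagsDemo().dict())  # {'listen': False, 'port': 7860}
```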
--- a/modules/call_queue.py
+++ b/modules/call_queue.py
@@ -1,10 +1,9 @@
 from functools import wraps
 import html
-import sys
 import threading
-import traceback
 import time
 
-from modules import shared, progress
+from modules import shared, progress, errors
 
 queue_lock = threading.Lock()
 
@@ -20,17 +19,18 @@ def wrap_queued_call(func):
 
 
 def wrap_gradio_gpu_call(func, extra_outputs=None):
     @wraps(func)
     def f(*args, **kwargs):
 
         # if the first argument is a string that says "task(...)", it is treated as a job id
-        if len(args) > 0 and type(args[0]) == str and args[0][0:5] == "task(" and args[0][-1] == ")":
+        if args and type(args[0]) == str and args[0].startswith("task(") and args[0].endswith(")"):
             id_task = args[0]
             progress.add_task_to_queue(id_task)
         else:
             id_task = None
 
         with queue_lock:
-            shared.state.begin()
+            shared.state.begin(job=id_task)
             progress.start_task(id_task)
 
             try:
@@ -47,6 +47,7 @@ def wrap_gradio_gpu_call(func, extra_outputs=None):
 
 
 def wrap_gradio_call(func, extra_outputs=None, add_stats=False):
+    @wraps(func)
     def f(*args, extra_outputs_array=extra_outputs, **kwargs):
         run_memmon = shared.opts.memmon_poll_rate > 0 and not shared.mem_mon.disabled and add_stats
         if run_memmon:
@@ -56,16 +57,14 @@ def wrap_gradio_call(func, extra_outputs=None, add_stats=False):
         try:
             res = list(func(*args, **kwargs))
         except Exception as e:
-            # When printing out our debug argument list, do not print out more than a MB of text
-            max_debug_str_len = 131072 # (1024*1024)/8
-
-            print("Error completing request", file=sys.stderr)
-            argStr = f"Arguments: {args} {kwargs}"
-            print(argStr[:max_debug_str_len], file=sys.stderr)
-            if len(argStr) > max_debug_str_len:
-                print(f"(Argument list truncated at {max_debug_str_len}/{len(argStr)} characters)", file=sys.stderr)
-
-            print(traceback.format_exc(), file=sys.stderr)
+            # When printing out our debug argument list,
+            # do not print out more than a 100 KB of text
+            max_debug_str_len = 131072
+            message = "Error completing request"
+            arg_str = f"Arguments: {args} {kwargs}"[:max_debug_str_len]
+            if len(arg_str) > max_debug_str_len:
+                arg_str += f" (Argument list truncated at {max_debug_str_len}/{len(arg_str)} characters)"
+            errors.report(f"{message}\n{arg_str}", exc_info=True)
 
             shared.state.job = ""
             shared.state.job_count = 0
@@ -108,4 +107,3 @@ def wrap_gradio_call(func, extra_outputs=None, add_stats=False):
         return tuple(res)
 
     return f
-
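The rewritten job-id guard is behavior-preserving: truthiness replaces the length check, and `startswith`/`endswith` replace the manual slicing. In isolation:

```python
def looks_like_task_id(args) -> bool:
    return bool(args) and type(args[0]) == str and args[0].startswith("task(") and args[0].endswith(")")

assert looks_like_task_id(("task(abc123)",))
assert not looks_like_task_id(())
assert not looks_like_task_id((42,))
```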
--- a/modules/cmd_args.py
+++ b/modules/cmd_args.py
@@ -1,6 +1,7 @@
 import argparse
+import json
 import os
-from modules.paths_internal import models_path, script_path, data_path, extensions_dir, extensions_builtin_dir, sd_default_config, sd_model_file
+from modules.paths_internal import models_path, script_path, data_path, extensions_dir, extensions_builtin_dir, sd_default_config, sd_model_file  # noqa: F401
 
 parser = argparse.ArgumentParser()
 
@@ -10,9 +11,9 @@ parser.add_argument("--skip-python-version-check", action='store_true', help="la
 parser.add_argument("--skip-torch-cuda-test", action='store_true', help="launch.py argument: do not check if CUDA is able to work properly")
 parser.add_argument("--reinstall-xformers", action='store_true', help="launch.py argument: install the appropriate version of xformers even if you have some version already installed")
 parser.add_argument("--reinstall-torch", action='store_true', help="launch.py argument: install the appropriate version of torch even if you have some version already installed")
-parser.add_argument("--update-check", action='store_true', help="launch.py argument: chck for updates at startup")
-parser.add_argument("--tests", type=str, default=None, help="launch.py argument: run tests in the specified directory")
-parser.add_argument("--no-tests", action='store_true', help="launch.py argument: do not run tests even if --tests option is specified")
+parser.add_argument("--update-check", action='store_true', help="launch.py argument: check for updates at startup")
+parser.add_argument("--test-server", action='store_true', help="launch.py argument: configure server for testing")
+parser.add_argument("--skip-prepare-environment", action='store_true', help="launch.py argument: skip all environment preparation")
 parser.add_argument("--skip-install", action='store_true', help="launch.py argument: skip installation of packages")
 parser.add_argument("--data-dir", type=str, default=os.path.dirname(os.path.dirname(os.path.realpath(__file__))), help="base path where all user data is stored")
 parser.add_argument("--config", type=str, default=sd_default_config, help="path to config which constructs model",)
@@ -39,7 +40,8 @@ parser.add_argument("--precision", type=str, help="evaluate at this precision",
 parser.add_argument("--upcast-sampling", action='store_true', help="upcast sampling. No effect with --no-half. Usually produces similar results to --no-half with better performance while using less memory.")
 parser.add_argument("--share", action='store_true', help="use share=True for gradio and make the UI accessible through their site")
 parser.add_argument("--ngrok", type=str, help="ngrok authtoken, alternative to gradio --share", default=None)
-parser.add_argument("--ngrok-region", type=str, help="The region in which ngrok should start.", default="us")
+parser.add_argument("--ngrok-region", type=str, help="does not do anything.", default="")
+parser.add_argument("--ngrok-options", type=json.loads, help='The options to pass to ngrok in JSON format, e.g.: \'{"authtoken_from_env":true, "basic_auth":"user:password", "oauth_provider":"google", "oauth_allow_emails":"user@asdf.com"}\'', default=dict())
 parser.add_argument("--enable-insecure-extension-access", action='store_true', help="enable extensions tab regardless of other options")
 parser.add_argument("--codeformer-models-path", type=str, help="Path to directory with codeformer model file(s).", default=os.path.join(models_path, 'Codeformer'))
 parser.add_argument("--gfpgan-models-path", type=str, help="Path to directory with GFPGAN model file(s).", default=os.path.join(models_path, 'GFPGAN'))
@@ -51,16 +53,16 @@ parser.add_argument("--xformers", action='store_true', help="enable xformers for
 parser.add_argument("--force-enable-xformers", action='store_true', help="enable xformers for cross attention layers regardless of whether the checking code thinks you can run it; do not make bug reports if this fails to work")
 parser.add_argument("--xformers-flash-attention", action='store_true', help="enable xformers with Flash Attention to improve reproducibility (supported for SD2.x or variant only)")
 parser.add_argument("--deepdanbooru", action='store_true', help="does not do anything")
-parser.add_argument("--opt-split-attention", action='store_true', help="force-enables Doggettx's cross-attention layer optimization. By default, it's on for torch cuda.")
-parser.add_argument("--opt-sub-quad-attention", action='store_true', help="enable memory efficient sub-quadratic cross-attention layer optimization")
+parser.add_argument("--opt-split-attention", action='store_true', help="prefer Doggettx's cross-attention layer optimization for automatic choice of optimization")
+parser.add_argument("--opt-sub-quad-attention", action='store_true', help="prefer memory efficient sub-quadratic cross-attention layer optimization for automatic choice of optimization")
 parser.add_argument("--sub-quad-q-chunk-size", type=int, help="query chunk size for the sub-quadratic cross-attention layer optimization to use", default=1024)
 parser.add_argument("--sub-quad-kv-chunk-size", type=int, help="kv chunk size for the sub-quadratic cross-attention layer optimization to use", default=None)
 parser.add_argument("--sub-quad-chunk-threshold", type=int, help="the percentage of VRAM threshold for the sub-quadratic cross-attention layer optimization to use chunking", default=None)
-parser.add_argument("--opt-split-attention-invokeai", action='store_true', help="force-enables InvokeAI's cross-attention layer optimization. By default, it's on when cuda is unavailable.")
-parser.add_argument("--opt-split-attention-v1", action='store_true', help="enable older version of split attention optimization that does not consume all the VRAM it can find")
-parser.add_argument("--opt-sdp-attention", action='store_true', help="enable scaled dot product cross-attention layer optimization; requires PyTorch 2.*")
-parser.add_argument("--opt-sdp-no-mem-attention", action='store_true', help="enable scaled dot product cross-attention layer optimization without memory efficient attention, makes image generation deterministic; requires PyTorch 2.*")
-parser.add_argument("--disable-opt-split-attention", action='store_true', help="force-disables cross-attention layer optimization")
+parser.add_argument("--opt-split-attention-invokeai", action='store_true', help="prefer InvokeAI's cross-attention layer optimization for automatic choice of optimization")
+parser.add_argument("--opt-split-attention-v1", action='store_true', help="prefer older version of split attention optimization for automatic choice of optimization")
+parser.add_argument("--opt-sdp-attention", action='store_true', help="prefer scaled dot product cross-attention layer optimization for automatic choice of optimization; requires PyTorch 2.*")
+parser.add_argument("--opt-sdp-no-mem-attention", action='store_true', help="prefer scaled dot product cross-attention layer optimization without memory efficient attention for automatic choice of optimization, makes image generation deterministic; requires PyTorch 2.*")
+parser.add_argument("--disable-opt-split-attention", action='store_true', help="prefer no cross-attention layer optimization for automatic choice of optimization")
 parser.add_argument("--disable-nan-check", action='store_true', help="do not check if produced images/latent spaces have nans; useful for running without a checkpoint in CI")
 parser.add_argument("--use-cpu", nargs='+', help="use CPU as torch device for specified modules", default=[], type=str.lower)
 parser.add_argument("--listen", action='store_true', help="launch gradio with 0.0.0.0 as server name, allowing to respond to network requests")
@@ -75,6 +77,7 @@ parser.add_argument("--gradio-auth", type=str, help='set gradio authentication l
 parser.add_argument("--gradio-auth-path", type=str, help='set gradio authentication file path ex. "/path/to/auth/file" same auth format as --gradio-auth', default=None)
 parser.add_argument("--gradio-img2img-tool", type=str, help='does not do anything')
 parser.add_argument("--gradio-inpaint-tool", type=str, help="does not do anything")
+parser.add_argument("--gradio-allowed-path", action='append', help="add path to gradio's allowed_paths, make it possible to serve files from it")
 parser.add_argument("--opt-channelslast", action='store_true', help="change memory type for stable diffusion to channels last")
 parser.add_argument("--styles-file", type=str, help="filename to use for styles", default=os.path.join(data_path, 'styles.csv'))
 parser.add_argument("--autolaunch", action='store_true', help="open the webui URL in the system's default browser upon launch", default=False)
@@ -102,4 +105,5 @@ parser.add_argument("--no-gradio-queue", action='store_true', help="Disables gra
 parser.add_argument("--skip-version-check", action='store_true', help="Do not check versions of torch and xformers")
 parser.add_argument("--no-hashing", action='store_true', help="disable sha256 hashing of checkpoints to help loading performance", default=False)
 parser.add_argument("--no-download-sd-model", action='store_true', help="don't download SD1.5 model even if no model is found in --ckpt-dir", default=False)
 parser.add_argument('--subpath', type=str, help='customize the subpath for gradio, use with reverse proxy')
+parser.add_argument('--api-server-stop', action='store_true', help='enable server stop/restart/kill via api')
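Because `--ngrok-options` is declared with `type=json.loads`, the shell-quoted JSON string becomes a dict during parsing. A standalone sketch:

```python
import argparse
import json

parser = argparse.ArgumentParser()
parser.add_argument("--ngrok-options", type=json.loads, default={})
args = parser.parse_args(['--ngrok-options', '{"authtoken_from_env": true}'])
print(args.ngrok_options["authtoken_from_env"])  # True
```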
--- a/modules/codeformer/codeformer_arch.py
+++ b/modules/codeformer/codeformer_arch.py
@@ -1,14 +1,12 @@
 # this file is copied from CodeFormer repository. Please see comment in modules/codeformer_model.py
 
 import math
-import numpy as np
 import torch
 from torch import nn, Tensor
 import torch.nn.functional as F
-from typing import Optional, List
+from typing import Optional
 
-from modules.codeformer.vqgan_arch import *
-from basicsr.utils import get_root_logger
+from modules.codeformer.vqgan_arch import VQAutoEncoder, ResBlock
 from basicsr.utils.registry import ARCH_REGISTRY
 
 def calc_mean_std(feat, eps=1e-5):
@@ -121,7 +119,7 @@ class TransformerSALayer(nn.Module):
                 tgt_mask: Optional[Tensor] = None,
                 tgt_key_padding_mask: Optional[Tensor] = None,
                 query_pos: Optional[Tensor] = None):
-        
+
         # self attention
         tgt2 = self.norm1(tgt)
         q = k = self.with_pos_embed(tgt2, query_pos)
@@ -161,10 +159,10 @@ class Fuse_sft_block(nn.Module):
 
 @ARCH_REGISTRY.register()
 class CodeFormer(VQAutoEncoder):
-    def __init__(self, dim_embd=512, n_head=8, n_layers=9, 
+    def __init__(self, dim_embd=512, n_head=8, n_layers=9,
                 codebook_size=1024, latent_size=256,
-                connect_list=['32', '64', '128', '256'],
-                fix_modules=['quantize','generator']):
+                connect_list=('32', '64', '128', '256'),
+                fix_modules=('quantize', 'generator')):
         super(CodeFormer, self).__init__(512, 64, [1, 2, 2, 4, 4, 8], 'nearest',2, [16], codebook_size)
 
         if fix_modules is not None:
@@ -181,14 +179,14 @@
         self.feat_emb = nn.Linear(256, self.dim_embd)
 
         # transformer
-        self.ft_layers = nn.Sequential(*[TransformerSALayer(embed_dim=dim_embd, nhead=n_head, dim_mlp=self.dim_mlp, dropout=0.0) 
+        self.ft_layers = nn.Sequential(*[TransformerSALayer(embed_dim=dim_embd, nhead=n_head, dim_mlp=self.dim_mlp, dropout=0.0)
                                     for _ in range(self.n_layers)])
 
         # logits_predict head
         self.idx_pred_layer = nn.Sequential(
             nn.LayerNorm(dim_embd),
             nn.Linear(dim_embd, codebook_size, bias=False))
-        
+
         self.channels = {
             '16': 512,
             '32': 256,
@@ -223,7 +221,7 @@
         enc_feat_dict = {}
         out_list = [self.fuse_encoder_block[f_size] for f_size in self.connect_list]
         for i, block in enumerate(self.encoder.blocks):
-            x = block(x) 
+            x = block(x)
             if i in out_list:
                 enc_feat_dict[str(x.shape[-1])] = x.clone()
 
@@ -268,11 +266,11 @@
         fuse_list = [self.fuse_generator_block[f_size] for f_size in self.connect_list]
 
         for i, block in enumerate(self.generator.blocks):
-            x = block(x) 
+            x = block(x)
             if i in fuse_list: # fuse after i-th block
                 f_size = str(x.shape[-1])
                 if w>0:
                     x = self.fuse_convs_dict[f_size](enc_feat_dict[f_size].detach(), x, w)
         out = x
         # logits doesn't need softmax before cross_entropy loss
-        return out, logits, lq_feat
\ No newline at end of file
+        return out, logits, lq_feat
--- a/modules/codeformer/vqgan_arch.py
+++ b/modules/codeformer/vqgan_arch.py
@@ -5,17 +5,15 @@ VQGAN code, adapted from the original created by the Unleashing Transformers aut
 https://github.com/samb-t/unleashing-transformers/blob/master/models/vqgan.py
 
 '''
-import numpy as np
 import torch
 import torch.nn as nn
 import torch.nn.functional as F
-import copy
 from basicsr.utils import get_root_logger
 from basicsr.utils.registry import ARCH_REGISTRY
 
 def normalize(in_channels):
     return torch.nn.GroupNorm(num_groups=32, num_channels=in_channels, eps=1e-6, affine=True)
 
 
 @torch.jit.script
 def swish(x):
@@ -212,15 +210,15 @@ class AttnBlock(nn.Module):
         # compute attention
         b, c, h, w = q.shape
         q = q.reshape(b, c, h*w)
-        q = q.permute(0, 2, 1) 
+        q = q.permute(0, 2, 1)
         k = k.reshape(b, c, h*w)
-        w_ = torch.bmm(q, k) 
+        w_ = torch.bmm(q, k)
         w_ = w_ * (int(c)**(-0.5))
         w_ = F.softmax(w_, dim=2)
 
         # attend to values
         v = v.reshape(b, c, h*w)
-        w_ = w_.permute(0, 2, 1) 
+        w_ = w_.permute(0, 2, 1)
         h_ = torch.bmm(v, w_)
         h_ = h_.reshape(b, c, h, w)
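The hunk above is whitespace-only, but the attention math it touches is worth restating: q is flattened to (b, hw, c), k to (b, c, hw), their bmm gives (b, hw, hw) scores, and the transposed scores weight v back into (b, c, h, w). A runnable shape walk-through (requires torch):

```python
import torch

b, c, h, w = 2, 8, 4, 4
q = torch.randn(b, c, h, w).reshape(b, c, h * w).permute(0, 2, 1)  # (b, hw, c)
k = torch.randn(b, c, h, w).reshape(b, c, h * w)                   # (b, c, hw)
w_ = torch.softmax(torch.bmm(q, k) * c ** -0.5, dim=2)             # (b, hw, hw)
v = torch.randn(b, c, h, w).reshape(b, c, h * w)                   # (b, c, hw)
h_ = torch.bmm(v, w_.permute(0, 2, 1)).reshape(b, c, h, w)
print(h_.shape)  # torch.Size([2, 8, 4, 4])
```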
@@ -272,18 +270,18 @@ class Encoder(nn.Module):
     def forward(self, x):
         for block in self.blocks:
             x = block(x)
-            
+
         return x
 
 
 class Generator(nn.Module):
     def __init__(self, nf, emb_dim, ch_mult, res_blocks, img_size, attn_resolutions):
         super().__init__()
-        self.nf = nf 
-        self.ch_mult = ch_mult 
+        self.nf = nf
+        self.ch_mult = ch_mult
         self.num_resolutions = len(self.ch_mult)
         self.num_res_blocks = res_blocks
-        self.resolution = img_size 
+        self.resolution = img_size
         self.attn_resolutions = attn_resolutions
         self.in_channels = emb_dim
         self.out_channels = 3
@@ -317,29 +315,29 @@
         blocks.append(nn.Conv2d(block_in_ch, self.out_channels, kernel_size=3, stride=1, padding=1))
 
         self.blocks = nn.ModuleList(blocks)
-   
+
     def forward(self, x):
         for block in self.blocks:
             x = block(x)
-            
+
         return x
 
 
 @ARCH_REGISTRY.register()
 class VQAutoEncoder(nn.Module):
-    def __init__(self, img_size, nf, ch_mult, quantizer="nearest", res_blocks=2, attn_resolutions=[16], codebook_size=1024, emb_dim=256,
+    def __init__(self, img_size, nf, ch_mult, quantizer="nearest", res_blocks=2, attn_resolutions=None, codebook_size=1024, emb_dim=256,
                 beta=0.25, gumbel_straight_through=False, gumbel_kl_weight=1e-8, model_path=None):
         super().__init__()
         logger = get_root_logger()
-        self.in_channels = 3 
-        self.nf = nf 
-        self.n_blocks = res_blocks 
+        self.in_channels = 3
+        self.nf = nf
+        self.n_blocks = res_blocks
         self.codebook_size = codebook_size
         self.embed_dim = emb_dim
         self.ch_mult = ch_mult
         self.resolution = img_size
-        self.attn_resolutions = attn_resolutions
+        self.attn_resolutions = attn_resolutions or [16]
         self.quantizer_type = quantizer
         self.encoder = Encoder(
             self.in_channels,
@@ -365,11 +363,11 @@
             self.kl_weight
         )
         self.generator = Generator(
-            self.nf, 
+            self.nf,
             self.embed_dim,
-            self.ch_mult, 
-            self.n_blocks, 
-            self.resolution, 
+            self.ch_mult,
+            self.n_blocks,
+            self.resolution,
             self.attn_resolutions
         )
 
@@ -434,4 +432,4 @@ class VQGANDiscriminator(nn.Module):
             raise ValueError('Wrong params!')
 
     def forward(self, x):
-        return self.main(x)
\ No newline at end of file
+        return self.main(x)
--- a/modules/codeformer_model.py
+++ b/modules/codeformer_model.py
@@ -1,13 +1,11 @@
 import os
-import sys
-import traceback
 
 import cv2
 import torch
 
 import modules.face_restoration
 import modules.shared
-from modules import shared, devices, modelloader
+from modules import shared, devices, modelloader, errors
 from modules.paths import models_path
 
 # codeformer people made a choice to include modified basicsr library to their project which makes
@@ -17,14 +15,11 @@ model_dir = "Codeformer"
 model_path = os.path.join(models_path, model_dir)
 model_url = 'https://github.com/sczhou/CodeFormer/releases/download/v0.1.0/codeformer.pth'
 
-have_codeformer = False
 codeformer = None
 
 
 def setup_model(dirname):
     global model_path
-    if not os.path.exists(model_path):
-        os.makedirs(model_path)
+    os.makedirs(model_path, exist_ok=True)
 
     path = modules.paths.paths.get("CodeFormer", None)
     if path is None:
@@ -33,11 +28,9 @@ def setup_model(dirname):
     try:
         from torchvision.transforms.functional import normalize
         from modules.codeformer.codeformer_arch import CodeFormer
         from basicsr.utils.download_util import load_file_from_url
-        from basicsr.utils import imwrite, img2tensor, tensor2img
+        from basicsr.utils import img2tensor, tensor2img
         from facelib.utils.face_restoration_helper import FaceRestoreHelper
         from facelib.detection.retinaface import retinaface
-        from modules.shared import cmd_opts
 
         net_class = CodeFormer
 
@@ -96,7 +89,7 @@
             self.face_helper.get_face_landmarks_5(only_center_face=False, resize=640, eye_dist_threshold=5)
             self.face_helper.align_warp_face()
 
-            for idx, cropped_face in enumerate(self.face_helper.cropped_faces):
+            for cropped_face in self.face_helper.cropped_faces:
                 cropped_face_t = img2tensor(cropped_face / 255., bgr2rgb=True, float32=True)
                 normalize(cropped_face_t, (0.5, 0.5, 0.5), (0.5, 0.5, 0.5), inplace=True)
                 cropped_face_t = cropped_face_t.unsqueeze(0).to(devices.device_codeformer)
@@ -107,8 +100,8 @@
                     restored_face = tensor2img(output, rgb2bgr=True, min_max=(-1, 1))
                     del output
                     torch.cuda.empty_cache()
-                except Exception as error:
-                    print(f'\tFailed inference for CodeFormer: {error}', file=sys.stderr)
+                except Exception:
+                    errors.report('Failed inference for CodeFormer', exc_info=True)
                     restored_face = tensor2img(cropped_face_t, rgb2bgr=True, min_max=(-1, 1))
 
                 restored_face = restored_face.astype('uint8')
@@ -129,15 +122,11 @@
 
             return restored_img
 
-        global have_codeformer
-        have_codeformer = True
-
         global codeformer
         codeformer = FaceRestorerCodeFormer(dirname)
         shared.face_restorers.append(codeformer)
 
     except Exception:
-        print("Error setting up CodeFormer:", file=sys.stderr)
-        print(traceback.format_exc(), file=sys.stderr)
+        errors.report("Error setting up CodeFormer", exc_info=True)
 
     # sys.path = stored_sys_path
@@ -3,8 +3,6 @@ Supports saving and restoring webui and extensions from a known working set of c
"""

import os
import sys
import traceback
import json
import time
import tqdm
@@ -13,8 +11,8 @@ from datetime import datetime
from collections import OrderedDict
import git

from modules import shared, extensions
from modules.paths_internal import extensions_dir, extensions_builtin_dir, script_path, config_states_dir
from modules import shared, extensions, errors
from modules.paths_internal import script_path, config_states_dir


all_config_states = OrderedDict()
@@ -35,7 +33,7 @@ def list_config_states():
        j["filepath"] = path
        config_states.append(j)

    config_states = list(sorted(config_states, key=lambda cs: cs["created_at"], reverse=True))
    config_states = sorted(config_states, key=lambda cs: cs["created_at"], reverse=True)

    for cs in config_states:
        timestamp = time.asctime(time.gmtime(cs["created_at"]))
@@ -53,8 +51,7 @@ def get_webui_config():
        if os.path.exists(os.path.join(script_path, ".git")):
            webui_repo = git.Repo(script_path)
    except Exception:
        print(f"Error reading webui git info from {script_path}:", file=sys.stderr)
        print(traceback.format_exc(), file=sys.stderr)
        errors.report(f"Error reading webui git info from {script_path}", exc_info=True)

    webui_remote = None
    webui_commit_hash = None
@@ -83,6 +80,8 @@ def get_extension_config():
    ext_config = {}

    for ext in extensions.extensions:
        ext.read_info_from_repo()

        entry = {
            "name": ext.name,
            "path": ext.path,
@@ -132,8 +131,7 @@ def restore_webui_config(config):
        if os.path.exists(os.path.join(script_path, ".git")):
            webui_repo = git.Repo(script_path)
    except Exception:
        print(f"Error reading webui git info from {script_path}:", file=sys.stderr)
        print(traceback.format_exc(), file=sys.stderr)
        errors.report(f"Error reading webui git info from {script_path}", exc_info=True)
        return

    try:
@@ -141,8 +139,7 @@ def restore_webui_config(config):
        webui_repo.git.reset(webui_commit_hash, hard=True)
        print(f"* Restored webui to commit {webui_commit_hash}.")
    except Exception:
        print(f"Error restoring webui to commit {webui_commit_hash}:", file=sys.stderr)
        print(traceback.format_exc(), file=sys.stderr)
        errors.report(f"Error restoring webui to commit {webui_commit_hash}")


def restore_extension_config(config):

@@ -2,7 +2,6 @@ import os
import re

import torch
from PIL import Image
import numpy as np

from modules import modelloader, paths, deepbooru_model, devices, images, shared
@@ -79,7 +78,7 @@ class DeepDanbooru:

        res = []

        filtertags = set([x.strip().replace(' ', '_') for x in shared.opts.deepbooru_filter_tags.split(",")])
        filtertags = {x.strip().replace(' ', '_') for x in shared.opts.deepbooru_filter_tags.split(",")}

        for tag in [x for x in tags if x not in filtertags]:
            probability = probability_dict[tag]

@@ -1,5 +1,7 @@
import sys
import contextlib
from functools import lru_cache

import torch
from modules import errors

@@ -13,13 +15,6 @@ def has_mps() -> bool:
    else:
        return mac_specific.has_mps

def extract_device_id(args, name):
    for x in range(len(args)):
        if name in args[x]:
            return args[x + 1]

    return None


def get_cuda_device_string():
    from modules import shared
@@ -65,7 +60,7 @@ def enable_tf32():

        # enabling benchmark option seems to enable a range of cards to do fp16 when they otherwise can't
        # see https://github.com/AUTOMATIC1111/stable-diffusion-webui/pull/4407
        if any([torch.cuda.get_device_capability(devid) == (7, 5) for devid in range(0, torch.cuda.device_count())]):
        if any(torch.cuda.get_device_capability(devid) == (7, 5) for devid in range(0, torch.cuda.device_count())):
            torch.backends.cudnn.benchmark = True

        torch.backends.cuda.matmul.allow_tf32 = True
@@ -154,3 +149,19 @@ def test_for_nans(x, where):
        message += " Use --disable-nan-check commandline argument to disable this check."

    raise NansException(message)


@lru_cache
def first_time_calculation():
    """
    just do any calculation with pytorch layers - the first time this is done it allocates about 700MB of memory and
    spends about 2.7 seconds doing that, at least with Nvidia.
    """

    x = torch.zeros((1, 1)).to(device, dtype)
    linear = torch.nn.Linear(1, 1).to(device, dtype)
    linear(x)

    x = torch.zeros((1, 1, 3, 3)).to(device, dtype)
    conv2d = torch.nn.Conv2d(1, 1, (3, 3)).to(device, dtype)
    conv2d(x)

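Note: first_time_calculation is safe to call more than once thanks to lru_cache. A minimal sketch of paying the warmup cost off the main thread at startup (the threading wrapper is illustrative, not part of this diff):

import threading

from modules import devices

# run the dummy linear/conv pass in the background so the one-time
# ~700MB allocation and ~2.7s delay don't block startup
threading.Thread(target=devices.first_time_calculation).start()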
@@ -1,8 +1,42 @@
import sys
import textwrap
import traceback


exception_records = []


def record_exception():
    _, e, tb = sys.exc_info()
    if e is None:
        return

    if exception_records and exception_records[-1] == e:
        return

    exception_records.append((e, tb))

    if len(exception_records) > 5:
        exception_records.pop(0)


def report(message: str, *, exc_info: bool = False) -> None:
    """
    Print an error message to stderr, with optional traceback.
    """

    record_exception()

    for line in message.splitlines():
        print("***", line, file=sys.stderr)
    if exc_info:
        print(textwrap.indent(traceback.format_exc(), "    "), file=sys.stderr)
        print("---", file=sys.stderr)


def print_error_explanation(message):
    record_exception()

    lines = message.strip().split("\n")
    max_len = max([len(x) for x in lines])

@@ -12,9 +46,15 @@ def print_error_explanation(message):
    print('=' * max_len, file=sys.stderr)


def display(e: Exception, task):
def display(e: Exception, task, *, full_traceback=False):
    record_exception()

    print(f"{task or 'error'}: {type(e).__name__}", file=sys.stderr)
    print(traceback.format_exc(), file=sys.stderr)
    te = traceback.TracebackException.from_exception(e)
    if full_traceback:
        # include frames leading up to the try-catch block
        te.stack = traceback.StackSummary(traceback.extract_stack()[:-2] + te.stack)
    print(*te.format(), sep="", file=sys.stderr)

    message = str(e)
    if "copying a param with shape torch.Size([640, 1024]) from checkpoint, the shape in current model is torch.Size([640, 768])" in message:
@@ -28,6 +68,8 @@ already_displayed = {}


def display_once(e: Exception, task):
    record_exception()

    if task in already_displayed:
        return

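Note: the helpers above are what the rest of this commit migrates to, replacing the repeated print(..., file=sys.stderr) / traceback.format_exc() pairs. A minimal usage sketch (load_something is hypothetical):

from modules import errors

try:
    load_something()  # hypothetical operation that may raise
except Exception:
    # prints '*** <message>' plus an indented traceback to stderr,
    # and remembers the exception in errors.exception_records
    errors.report("Error loading something", exc_info=True)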
@@ -1,24 +1,20 @@
import os
import sys

import numpy as np
import torch
from PIL import Image
from basicsr.utils.download_util import load_file_from_url

import modules.esrgan_model_arch as arch
from modules import shared, modelloader, images, devices
from modules.upscaler import Upscaler, UpscalerData
from modules import modelloader, images, devices
from modules.shared import opts

from modules.upscaler import Upscaler, UpscalerData


def mod2normal(state_dict):
    # this code is copied from https://github.com/victorca25/iNNfer
    if 'conv_first.weight' in state_dict:
        crt_net = {}
        items = []
        for k, v in state_dict.items():
            items.append(k)
        items = list(state_dict)

        crt_net['model.0.weight'] = state_dict['conv_first.weight']
        crt_net['model.0.bias'] = state_dict['conv_first.bias']
@@ -52,9 +48,7 @@ def resrgan2normal(state_dict, nb=23):
    if "conv_first.weight" in state_dict and "body.0.rdb1.conv1.weight" in state_dict:
        re8x = 0
        crt_net = {}
        items = []
        for k, v in state_dict.items():
            items.append(k)
        items = list(state_dict)

        crt_net['model.0.weight'] = state_dict['conv_first.weight']
        crt_net['model.0.bias'] = state_dict['conv_first.bias']
@@ -138,7 +132,7 @@ class UpscalerESRGAN(Upscaler):
        scaler_data = UpscalerData(self.model_name, self.model_url, self, 4)
        scalers.append(scaler_data)
        for file in model_paths:
            if "http" in file:
            if file.startswith("http"):
                name = self.model_name
            else:
                name = modelloader.friendly_name(file)
@@ -147,26 +141,25 @@ class UpscalerESRGAN(Upscaler):
            self.scalers.append(scaler_data)

    def do_upscale(self, img, selected_model):
        model = self.load_model(selected_model)
        if model is None:
        try:
            model = self.load_model(selected_model)
        except Exception as e:
            print(f"Unable to load ESRGAN model {selected_model}: {e}", file=sys.stderr)
            return img
        model.to(devices.device_esrgan)
        img = esrgan_upscale(model, img)
        return img

    def load_model(self, path: str):
        if "http" in path:
            filename = load_file_from_url(
        if path.startswith("http"):
            # TODO: this doesn't use `path` at all?
            filename = modelloader.load_file_from_url(
                url=self.model_url,
                model_dir=self.model_path,
                model_dir=self.model_download_path,
                file_name=f"{self.model_name}.pth",
                progress=True,
            )
        else:
            filename = path
        if not os.path.exists(filename) or filename is None:
            print(f"Unable to load {self.model_path} from {filename}")
            return None

        state_dict = torch.load(filename, map_location='cpu' if devices.device_esrgan.type == 'mps' else None)

@@ -2,7 +2,6 @@

from collections import OrderedDict
import math
import functools
import torch
import torch.nn as nn
import torch.nn.functional as F
@@ -106,7 +105,7 @@ class ResidualDenseBlock_5C(nn.Module):
    Modified options that can be used:
        - "Partial Convolution based Padding" arXiv:1811.11718
        - "Spectral normalization" arXiv:1802.05957
        - "ICASSP 2020 - ESRGAN+ : Further Improving ESRGAN" N. C.
        - "ICASSP 2020 - ESRGAN+ : Further Improving ESRGAN" N. C.
            {Rakotonirina} and A. {Rasoanaivo}
    """

@@ -171,7 +170,7 @@ class GaussianNoise(nn.Module):
            scale = self.sigma * x.detach() if self.is_relative_detach else self.sigma * x
            sampled_noise = self.noise.repeat(*x.size()).normal_() * scale
            x = x + sampled_noise
        return x
        return x

def conv1x1(in_planes, out_planes, stride=1):
    return nn.Conv2d(in_planes, out_planes, kernel_size=1, stride=stride, bias=False)
@@ -438,9 +437,11 @@ def conv_block(in_nc, out_nc, kernel_size, stride=1, dilation=1, groups=1, bias=
    padding = padding if pad_type == 'zero' else 0

    if convtype=='PartialConv2D':
        from torchvision.ops import PartialConv2d  # this is definitely not going to work, but PartialConv2d doesn't work anyway and this shuts up static analyzer
        c = PartialConv2d(in_nc, out_nc, kernel_size=kernel_size, stride=stride, padding=padding,
                          dilation=dilation, bias=bias, groups=groups)
    elif convtype=='DeformConv2D':
        from torchvision.ops import DeformConv2d  # not tested
        c = DeformConv2d(in_nc, out_nc, kernel_size=kernel_size, stride=stride, padding=padding,
                         dilation=dilation, bias=bias, groups=groups)
    elif convtype=='Conv3D':

@@ -1,18 +1,13 @@
import os
import sys
import traceback
import threading

import time
from datetime import datetime
import git

from modules import shared
from modules.paths_internal import extensions_dir, extensions_builtin_dir, script_path
from modules import shared, errors
from modules.gitpython_hack import Repo
from modules.paths_internal import extensions_dir, extensions_builtin_dir, script_path  # noqa: F401

extensions = []

if not os.path.exists(extensions_dir):
    os.makedirs(extensions_dir)
os.makedirs(extensions_dir, exist_ok=True)


def active():
@@ -25,6 +20,8 @@ def active():


class Extension:
    lock = threading.Lock()

    def __init__(self, name, path, enabled=True, is_builtin=False):
        self.name = name
        self.path = path
@@ -43,15 +40,19 @@ class Extension:
        if self.is_builtin or self.have_info_from_repo:
            return

        self.have_info_from_repo = True
        with self.lock:
            if self.have_info_from_repo:
                return

            self.do_read_info_from_repo()

    def do_read_info_from_repo(self):
        repo = None
        try:
            if os.path.exists(os.path.join(self.path, ".git")):
                repo = git.Repo(self.path)
                repo = Repo(self.path)
        except Exception:
            print(f"Error reading github repository info from {self.path}:", file=sys.stderr)
            print(traceback.format_exc(), file=sys.stderr)
            errors.report(f"Error reading github repository info from {self.path}", exc_info=True)

        if repo is None or repo.bare:
            self.remote = None
@@ -59,18 +60,19 @@ class Extension:
            try:
                self.status = 'unknown'
                self.remote = next(repo.remote().urls, None)
                head = repo.head.commit
                self.commit_date = repo.head.commit.committed_date
                ts = time.asctime(time.gmtime(self.commit_date))
                commit = repo.head.commit
                self.commit_date = commit.committed_date
                if repo.active_branch:
                    self.branch = repo.active_branch.name
                self.commit_hash = head.hexsha
                self.version = f'{self.commit_hash[:8]} ({ts})'
                self.commit_hash = commit.hexsha
                self.version = self.commit_hash[:8]

            except Exception as ex:
                print(f"Failed reading extension data from Git repository ({self.name}): {ex}", file=sys.stderr)
            except Exception:
                errors.report(f"Failed reading extension data from Git repository ({self.name})", exc_info=True)
                self.remote = None

        self.have_info_from_repo = True

    def list_files(self, subdir, extension):
        from modules import scripts

@@ -87,7 +89,7 @@ class Extension:
        return res

    def check_updates(self):
        repo = git.Repo(self.path)
        repo = Repo(self.path)
        for fetch in repo.remote().fetch(dry_run=True):
            if fetch.flags != fetch.HEAD_UPTODATE:
                self.can_update = True
@@ -109,7 +111,7 @@ class Extension:
        self.status = "latest"

    def fetch_and_reset_hard(self, commit='origin'):
        repo = git.Repo(self.path)
        repo = Repo(self.path)
        # Fix: `error: Your local changes to the following files would be overwritten by merge`,
        # because WSL2 Docker set 755 file permissions instead of 644, this results to the error.
        repo.git.fetch(all=True)

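Note: read_info_from_repo now uses double-checked locking: the have_info_from_repo flag is tested once without the lock for the fast path, then re-tested under the lock so only one thread does the git work. The same shape in isolation (a sketch, not code from this commit):

import threading

class LazyInfo:
    lock = threading.Lock()

    def __init__(self):
        self.loaded = False

    def ensure_loaded(self):
        if self.loaded:  # unlocked fast path
            return
        with self.lock:
            if self.loaded:  # re-check: another thread may have finished first
                return
            self.do_load()  # expensive work, e.g. reading git metadata
            self.loaded = True

    def do_load(self):
        pass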
@@ -14,9 +14,26 @@ def register_extra_network(extra_network):
    extra_network_registry[extra_network.name] = extra_network


def register_default_extra_networks():
    from modules.extra_networks_hypernet import ExtraNetworkHypernet
    register_extra_network(ExtraNetworkHypernet())


class ExtraNetworkParams:
    def __init__(self, items=None):
        self.items = items or []
        self.positional = []
        self.named = {}

        for item in self.items:
            parts = item.split('=', 2) if isinstance(item, str) else [item]
            if len(parts) == 2:
                self.named[parts[0]] = parts[1]
            else:
                self.positional.append(item)

    def __eq__(self, other):
        return self.items == other.items


class ExtraNetwork:
@@ -86,12 +103,15 @@ def activate(p, extra_network_data):
        except Exception as e:
            errors.display(e, f"activating extra network {extra_network_name}")

    if p.scripts is not None:
        p.scripts.after_extra_networks_activate(p, batch_number=p.iteration, prompts=p.prompts, seeds=p.seeds, subseeds=p.subseeds, extra_network_data=extra_network_data)


def deactivate(p, extra_network_data):
    """call deactivate for extra networks in extra_network_data in specified order, then call
    deactivate for all remaining registered networks"""

    for extra_network_name, extra_network_args in extra_network_data.items():
    for extra_network_name in extra_network_data:
        extra_network = extra_network_registry.get(extra_network_name, None)
        if extra_network is None:
            continue

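Note: ExtraNetworkParams splits each string item on '=' into named and positional arguments (an item with more than one '=' falls through to positional, since the maxsplit of 2 yields three parts). For example, with illustrative values:

from modules.extra_networks import ExtraNetworkParams

params = ExtraNetworkParams(items=["mynet", "te=0.5", "unet=0.8"])
assert params.positional == ["mynet"]
assert params.named == {"te": "0.5", "unet": "0.8"}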
@@ -1,4 +1,4 @@
from modules import extra_networks, shared, extra_networks
from modules import extra_networks, shared
from modules.hypernetworks import hypernetwork


@@ -9,7 +9,7 @@ class ExtraNetworkHypernet(extra_networks.ExtraNetwork):
    def activate(self, p, params_list):
        additional = shared.opts.sd_hypernetwork

        if additional != "None" and additional in shared.hypernetworks and len([x for x in params_list if x.items[0] == additional]) == 0:
        if additional != "None" and additional in shared.hypernetworks and not any(x for x in params_list if x.items[0] == additional):
            hypernet_prompt_text = f"<hypernet:{additional}:{shared.opts.extra_networks_default_multiplier}>"
            p.all_prompts = [f"{prompt}{hypernet_prompt_text}" for prompt in p.all_prompts]
            params_list.append(extra_networks.ExtraNetworkParams(items=[additional, shared.opts.extra_networks_default_multiplier]))
@@ -17,7 +17,7 @@ class ExtraNetworkHypernet(extra_networks.ExtraNetwork):
        names = []
        multipliers = []
        for params in params_list:
            assert len(params.items) > 0
            assert params.items

            names.append(params.items[0])
            multipliers.append(float(params.items[1]) if len(params.items) > 1 else 1.0)

@@ -73,8 +73,7 @@ def to_half(tensor, enable):


def run_modelmerger(id_task, primary_model_name, secondary_model_name, tertiary_model_name, interp_method, multiplier, save_as_half, custom_name, checkpoint_format, config_source, bake_in_vae, discard_weights, save_metadata):
    shared.state.begin()
    shared.state.job = 'model-merge'
    shared.state.begin(job="model-merge")

    def fail(message):
        shared.state.textinfo = message
@@ -136,14 +135,14 @@ def run_modelmerger(id_task, primary_model_name, secondary_model_name, tertiary_
    result_is_instruct_pix2pix_model = False

    if theta_func2:
        shared.state.textinfo = f"Loading B"
        shared.state.textinfo = "Loading B"
        print(f"Loading {secondary_model_info.filename}...")
        theta_1 = sd_models.read_state_dict(secondary_model_info.filename, map_location='cpu')
    else:
        theta_1 = None

    if theta_func1:
        shared.state.textinfo = f"Loading C"
        shared.state.textinfo = "Loading C"
        print(f"Loading {tertiary_model_info.filename}...")
        theta_2 = sd_models.read_state_dict(tertiary_model_info.filename, map_location='cpu')

@@ -199,7 +198,7 @@ def run_modelmerger(id_task, primary_model_name, secondary_model_name, tertiary_
                result_is_inpainting_model = True
            else:
                theta_0[key] = theta_func2(a, b, multiplier)


            theta_0[key] = to_half(theta_0[key], save_as_half)

            shared.state.sampling_step += 1
@@ -242,9 +241,11 @@ def run_modelmerger(id_task, primary_model_name, secondary_model_name, tertiary_
    shared.state.textinfo = "Saving"
    print(f"Saving to {output_modelname}...")

    metadata = {"format": "pt", "sd_merge_models": {}, "sd_merge_recipe": None}
    metadata = None

    if save_metadata:
        metadata = {"format": "pt"}

        merge_recipe = {
            "type": "webui", # indicate this model was merged with webui's built-in merger
            "primary_model_hash": primary_model_info.sha256,
@@ -262,15 +263,17 @@ def run_modelmerger(id_task, primary_model_name, secondary_model_name, tertiary_
        }
        metadata["sd_merge_recipe"] = json.dumps(merge_recipe)

        sd_merge_models = {}

        def add_model_metadata(checkpoint_info):
            checkpoint_info.calculate_shorthash()
            metadata["sd_merge_models"][checkpoint_info.sha256] = {
            sd_merge_models[checkpoint_info.sha256] = {
                "name": checkpoint_info.name,
                "legacy_hash": checkpoint_info.hash,
                "sd_merge_recipe": checkpoint_info.metadata.get("sd_merge_recipe", None)
            }

            metadata["sd_merge_models"].update(checkpoint_info.metadata.get("sd_merge_models", {}))
            sd_merge_models.update(checkpoint_info.metadata.get("sd_merge_models", {}))

        add_model_metadata(primary_model_info)
        if secondary_model_info:
@@ -278,7 +281,7 @@ def run_modelmerger(id_task, primary_model_name, secondary_model_name, tertiary_
        if tertiary_model_info:
            add_model_metadata(tertiary_model_info)

        metadata["sd_merge_models"] = json.dumps(metadata["sd_merge_models"])
        metadata["sd_merge_models"] = json.dumps(sd_merge_models)

    _, extension = os.path.splitext(output_modelname)
    if extension.lower() == ".safetensors":

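Note: with save_metadata enabled, the merged checkpoint's safetensors header ends up carrying the two JSON-encoded keys built above, roughly of this shape (hashes and names illustrative):

metadata = {
    "format": "pt",
    "sd_merge_recipe": '{"type": "webui", "primary_model_hash": "aaaa...", ...}',
    "sd_merge_models": '{"aaaa...": {"name": "modelA", "legacy_hash": "aaaa1111", "sd_merge_recipe": null}}',
}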
@@ -1,15 +1,12 @@
import base64
import html
import io
import math
import json
import os
import re
from pathlib import Path

import gradio as gr
from modules.paths import data_path
from modules import shared, ui_tempdir, script_callbacks
import tempfile
from PIL import Image

re_param_code = r'\s*([\w ]+):\s*("(?:\\"[^,]|\\"|\\|[^\"])+"|[^,]*)(?:,|$)'
@@ -23,14 +20,14 @@ registered_param_bindings = []


class ParamBinding:
    def __init__(self, paste_button, tabname, source_text_component=None, source_image_component=None, source_tabname=None, override_settings_component=None, paste_field_names=[]):
    def __init__(self, paste_button, tabname, source_text_component=None, source_image_component=None, source_tabname=None, override_settings_component=None, paste_field_names=None):
        self.paste_button = paste_button
        self.tabname = tabname
        self.source_text_component = source_text_component
        self.source_image_component = source_image_component
        self.source_tabname = source_tabname
        self.override_settings_component = override_settings_component
        self.paste_field_names = paste_field_names
        self.paste_field_names = paste_field_names or []


def reset():
@@ -38,20 +35,27 @@ def reset():


def quote(text):
    if ',' not in str(text):
    if ',' not in str(text) and '\n' not in str(text) and ':' not in str(text):
        return text

    text = str(text)
    text = text.replace('\\', '\\\\')
    text = text.replace('"', '\\"')
    return f'"{text}"'
    return json.dumps(text, ensure_ascii=False)


def unquote(text):
    if len(text) == 0 or text[0] != '"' or text[-1] != '"':
        return text

    try:
        return json.loads(text)
    except Exception:
        return text

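Note: switching to json.dumps/json.loads makes quoting round-trip safely for values containing commas, colons, quotes or newlines, which the old manual escaping did not guarantee. For instance:

from modules.generation_parameters_copypaste import quote, unquote

s = quote('20,20')            # -> '"20,20"', JSON-quoted because of the comma
assert unquote(s) == '20,20'
assert quote('512') == '512'  # plain values still pass through unquoted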
def image_from_url_text(filedata):
    if filedata is None:
        return None

    if type(filedata) == list and len(filedata) > 0 and type(filedata[0]) == dict and filedata[0].get("is_file", False):
    if type(filedata) == list and filedata and type(filedata[0]) == dict and filedata[0].get("is_file", False):
        filedata = filedata[0]

    if type(filedata) == dict and filedata.get("is_file", False):
@@ -170,31 +174,6 @@ def send_image_and_dimensions(x):
    return img, w, h



def find_hypernetwork_key(hypernet_name, hypernet_hash=None):
    """Determines the config parameter name to use for the hypernet based on the parameters in the infotext.

    Example: an infotext provides "Hypernet: ke-ta" and "Hypernet hash: 1234abcd". For the "Hypernet" config
    parameter this means there should be an entry that looks like "ke-ta-10000(1234abcd)" to set it to.

    If the infotext has no hash, then a hypernet with the same name will be selected instead.
    """
    hypernet_name = hypernet_name.lower()
    if hypernet_hash is not None:
        # Try to match the hash in the name
        for hypernet_key in shared.hypernetworks.keys():
            result = re_hypernet_hash.search(hypernet_key)
            if result is not None and result[1] == hypernet_hash:
                return hypernet_key
    else:
        # Fall back to a hypernet with the same name
        for hypernet_key in shared.hypernetworks.keys():
            if hypernet_key.lower().startswith(hypernet_name):
                return hypernet_key

    return None

def restore_old_hires_fix_params(res):
    """for infotexts that specify old First pass size parameter, convert it into
    width, height, and hr scale"""
@@ -251,28 +230,40 @@ Steps: 20, Sampler: Euler a, CFG scale: 7, Seed: 965400086, Size: 512x512, Model
        lines.append(lastline)
        lastline = ''

    for i, line in enumerate(lines):
    for line in lines:
        line = line.strip()
        if line.startswith("Negative prompt:"):
            done_with_prompt = True
            line = line[16:].strip()

        if done_with_prompt:
            negative_prompt += ("" if negative_prompt == "" else "\n") + line
        else:
            prompt += ("" if prompt == "" else "\n") + line

    if shared.opts.infotext_styles != "Ignore":
        found_styles, prompt, negative_prompt = shared.prompt_styles.extract_styles_from_prompt(prompt, negative_prompt)

        if shared.opts.infotext_styles == "Apply":
            res["Styles array"] = found_styles
        elif shared.opts.infotext_styles == "Apply if any" and found_styles:
            res["Styles array"] = found_styles

    res["Prompt"] = prompt
    res["Negative prompt"] = negative_prompt

    for k, v in re_param.findall(lastline):
        v = v[1:-1] if v[0] == '"' and v[-1] == '"' else v
        m = re_imagesize.match(v)
        if m is not None:
            res[f"{k}-1"] = m.group(1)
            res[f"{k}-2"] = m.group(2)
        else:
            res[k] = v
        try:
            if v[0] == '"' and v[-1] == '"':
                v = unquote(v)

            m = re_imagesize.match(v)
            if m is not None:
                res[f"{k}-1"] = m.group(1)
                res[f"{k}-2"] = m.group(2)
            else:
                res[k] = v
        except Exception:
            print(f"Error parsing \"{k}: {v}\"")

    # Missing CLIP skip means it was set to 1 (the default)
    if "Clip skip" not in res:
@@ -286,24 +277,45 @@ Steps: 20, Sampler: Euler a, CFG scale: 7, Seed: 965400086, Size: 512x512, Model
        res["Hires resize-1"] = 0
        res["Hires resize-2"] = 0

    if "Hires sampler" not in res:
        res["Hires sampler"] = "Use same sampler"

    if "Hires prompt" not in res:
        res["Hires prompt"] = ""

    if "Hires negative prompt" not in res:
        res["Hires negative prompt"] = ""

    restore_old_hires_fix_params(res)

    # Missing RNG means the default was set, which is GPU RNG
    if "RNG" not in res:
        res["RNG"] = "GPU"

    if "Schedule type" not in res:
        res["Schedule type"] = "Automatic"

    if "Schedule max sigma" not in res:
        res["Schedule max sigma"] = 0

    if "Schedule min sigma" not in res:
        res["Schedule min sigma"] = 0

    if "Schedule rho" not in res:
        res["Schedule rho"] = 0

    return res


settings_map = {}


infotext_to_setting_name_mapping = [
    ('Clip skip', 'CLIP_stop_at_last_layers', ),
    ('Conditional mask weight', 'inpainting_mask_weight'),
    ('Model hash', 'sd_model_checkpoint'),
    ('ENSD', 'eta_noise_seed_delta'),
    ('Schedule type', 'k_sched_type'),
    ('Schedule max sigma', 'sigma_max'),
    ('Schedule min sigma', 'sigma_min'),
    ('Schedule rho', 'rho'),
    ('Noise multiplier', 'initial_noise_multiplier'),
    ('Eta', 'eta_ancestral'),
    ('Eta DDIM', 'eta_ddim'),
@@ -312,8 +324,11 @@ infotext_to_setting_name_mapping = [
    ('UniPC skip type', 'uni_pc_skip_type'),
    ('UniPC order', 'uni_pc_order'),
    ('UniPC lower order final', 'uni_pc_lower_order_final'),
    ('Token merging ratio', 'token_merging_ratio'),
    ('Token merging ratio hr', 'token_merging_ratio_hr'),
    ('RNG', 'randn_source'),
    ('NGMS', 's_min_uncond'),
    ('Pad conds', 'pad_cond_uncond'),
]


@@ -405,7 +420,7 @@ def connect_paste(button, paste_fields, input_comp, override_settings_component,

        vals_pairs = [f"{k}: {v}" for k, v in vals.items()]

        return gr.Dropdown.update(value=vals_pairs, choices=vals_pairs, visible=len(vals_pairs) > 0)
        return gr.Dropdown.update(value=vals_pairs, choices=vals_pairs, visible=bool(vals_pairs))

    paste_fields = paste_fields + [(override_settings_component, paste_settings)]

@@ -422,5 +437,3 @@ def connect_paste(button, paste_fields, input_comp, override_settings_component,
        outputs=[],
        show_progress=False,
    )

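Note: the key/value tail of an infotext is pulled apart by the re_param_code pattern from the top of this file; a sketch of the mechanics (the compiled re_param object is assumed to be defined next to the pattern):

import re

re_param = re.compile(r'\s*([\w ]+):\s*("(?:\\"[^,]|\\"|\\|[^\"])+"|[^,]*)(?:,|$)')

lastline = 'Steps: 20, Sampler: Euler a, CFG scale: 7, Seed: 965400086, Size: 512x512'
print(re_param.findall(lastline))
# [('Steps', '20'), ('Sampler', 'Euler a'), ('CFG scale', '7'),
#  ('Seed', '965400086'), ('Size', '512x512')]
# 'Size' is then split into 'Size-1'/'Size-2' by re_imagesize, as above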
@@ -1,12 +1,10 @@
import os
import sys
import traceback

import facexlib
import gfpgan

import modules.face_restoration
from modules import paths, shared, devices, modelloader
from modules import paths, shared, devices, modelloader, errors

model_dir = "GFPGAN"
user_path = None
@@ -27,7 +25,7 @@ def gfpgann():
        return None

    models = modelloader.load_models(model_path, model_url, user_path, ext_filter="GFPGAN")
    if len(models) == 1 and "http" in models[0]:
    if len(models) == 1 and models[0].startswith("http"):
        model_file = models[0]
    elif len(models) != 0:
        latest_file = max(models, key=os.path.getctime)
@@ -72,13 +70,10 @@ gfpgan_constructor = None


def setup_model(dirname):
    global model_path
    if not os.path.exists(model_path):
        os.makedirs(model_path)

    try:
        os.makedirs(model_path, exist_ok=True)
        from gfpgan import GFPGANer
        from facexlib import detection, parsing
        from facexlib import detection, parsing  # noqa: F401
        global user_path
        global have_gfpgan
        global gfpgan_constructor
@@ -112,5 +107,4 @@ def setup_model(dirname):

        shared.face_restorers.append(FaceRestorerGFPGAN())
    except Exception:
        print("Error setting up GFPGAN:", file=sys.stderr)
        print(traceback.format_exc(), file=sys.stderr)
        errors.report("Error setting up GFPGAN", exc_info=True)

42
modules/gitpython_hack.py
Normal file
@@ -0,0 +1,42 @@
from __future__ import annotations

import io
import subprocess

import git


class Git(git.Git):
    """
    Git subclassed to never use persistent processes.
    """

    def _get_persistent_cmd(self, attr_name, cmd_name, *args, **kwargs):
        raise NotImplementedError(f"Refusing to use persistent process: {attr_name} ({cmd_name} {args} {kwargs})")

    def get_object_header(self, ref: str | bytes) -> tuple[str, str, int]:
        ret = subprocess.check_output(
            [self.GIT_PYTHON_GIT_EXECUTABLE, "cat-file", "--batch-check"],
            input=self._prepare_ref(ref),
            cwd=self._working_dir,
            timeout=2,
        )
        return self._parse_object_header(ret)

    def stream_object_data(self, ref: str) -> tuple[str, str, int, "Git.CatFileContentStream"]:
        # Not really streaming, per se; this buffers the entire object in memory.
        # Shouldn't be a problem for our use case, since we're only using this for
        # object headers (commit objects).
        ret = subprocess.check_output(
            [self.GIT_PYTHON_GIT_EXECUTABLE, "cat-file", "--batch"],
            input=self._prepare_ref(ref),
            cwd=self._working_dir,
            timeout=30,
        )
        bio = io.BytesIO(ret)
        hexsha, typename, size = self._parse_object_header(bio.readline())
        return (hexsha, typename, size, self.CatFileContentStream(size, bio))


class Repo(git.Repo):
    GitCommandWrapperType = Git
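Note: elsewhere in this commit, extensions.py swaps git.Repo for this Repo subclass; usage stays the same (the path is illustrative):

from modules.gitpython_hack import Repo

repo = Repo("/path/to/extension")  # illustrative path
# object lookups now go through one-shot `git cat-file` subprocesses with
# timeouts instead of long-lived `--batch` processes that can leak
print(repo.head.commit.hexsha)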
@@ -46,8 +46,8 @@ def calculate_sha256(filename):
    return hash_sha256.hexdigest()


def sha256_from_cache(filename, title):
    hashes = cache("hashes")
def sha256_from_cache(filename, title, use_addnet_hash=False):
    hashes = cache("hashes-addnet") if use_addnet_hash else cache("hashes")
    ondisk_mtime = os.path.getmtime(filename)

    if title not in hashes:
@@ -62,10 +62,10 @@ def sha256_from_cache(filename, title):
    return cached_sha256


def sha256(filename, title):
    hashes = cache("hashes")
def sha256(filename, title, use_addnet_hash=False):
    hashes = cache("hashes-addnet") if use_addnet_hash else cache("hashes")

    sha256_value = sha256_from_cache(filename, title)
    sha256_value = sha256_from_cache(filename, title, use_addnet_hash)
    if sha256_value is not None:
        return sha256_value

@@ -73,7 +73,11 @@ def sha256(filename, title):
        return None

    print(f"Calculating sha256 for {filename}: ", end='')
    sha256_value = calculate_sha256(filename)
    if use_addnet_hash:
        with open(filename, "rb") as file:
            sha256_value = addnet_hash_safetensors(file)
    else:
        sha256_value = calculate_sha256(filename)
    print(f"{sha256_value}")

    hashes[title] = {
@@ -86,6 +90,19 @@ def sha256(filename, title):
    return sha256_value


def addnet_hash_safetensors(b):
    """kohya-ss hash for safetensors from https://github.com/kohya-ss/sd-scripts/blob/main/library/train_util.py"""
    hash_sha256 = hashlib.sha256()
    blksize = 1024 * 1024

    b.seek(0)
    header = b.read(8)
    n = int.from_bytes(header, "little")

    offset = n + 8
    b.seek(offset)
    for chunk in iter(lambda: b.read(blksize), b""):
        hash_sha256.update(chunk)

    return hash_sha256.hexdigest()

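Note: addnet_hash_safetensors hashes only the tensor payload: the first 8 bytes of a safetensors file hold the little-endian length n of the JSON header, so hashing starts at offset n + 8. Through the cached frontend (arguments illustrative):

from modules import hashes

# full-file hash, cached under "hashes"
h1 = hashes.sha256("model.safetensors", "checkpoint/model")

# kohya-ss/addnet-style hash of the tensor data only, cached under "hashes-addnet"
h2 = hashes.sha256("model.safetensors", "checkpoint/model", use_addnet_hash=True)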
@@ -1,10 +1,7 @@
import csv
import datetime
import glob
import html
import os
import sys
import traceback
import inspect

import modules.textual_inversion.dataset
@@ -12,13 +9,13 @@ import torch
import tqdm
from einops import rearrange, repeat
from ldm.util import default
from modules import devices, processing, sd_models, shared, sd_samplers, hashes, sd_hijack_checkpoint
from modules import devices, processing, sd_models, shared, sd_samplers, hashes, sd_hijack_checkpoint, errors
from modules.textual_inversion import textual_inversion, logging
from modules.textual_inversion.learn_schedule import LearnRateScheduler
from torch import einsum
from torch.nn.init import normal_, xavier_normal_, xavier_uniform_, kaiming_normal_, kaiming_uniform_, zeros_

from collections import defaultdict, deque
from collections import deque
from statistics import stdev, mean


@@ -178,34 +175,34 @@ class Hypernetwork:

    def weights(self):
        res = []
        for k, layers in self.layers.items():
        for layers in self.layers.values():
            for layer in layers:
                res += layer.parameters()
        return res

    def train(self, mode=True):
        for k, layers in self.layers.items():
        for layers in self.layers.values():
            for layer in layers:
                layer.train(mode=mode)
                for param in layer.parameters():
                    param.requires_grad = mode

    def to(self, device):
        for k, layers in self.layers.items():
        for layers in self.layers.values():
            for layer in layers:
                layer.to(device)

        return self

    def set_multiplier(self, multiplier):
        for k, layers in self.layers.items():
        for layers in self.layers.values():
            for layer in layers:
                layer.multiplier = multiplier

        return self

    def eval(self):
        for k, layers in self.layers.items():
        for layers in self.layers.values():
            for layer in layers:
                layer.eval()
                for param in layer.parameters():
@@ -326,17 +323,14 @@ def load_hypernetwork(name):
    if path is None:
        return None

    hypernetwork = Hypernetwork()

    try:
        hypernetwork = Hypernetwork()
        hypernetwork.load(path)
        return hypernetwork
    except Exception:
        print(f"Error loading hypernetwork {path}", file=sys.stderr)
        print(traceback.format_exc(), file=sys.stderr)
        errors.report(f"Error loading hypernetwork {path}", exc_info=True)
        return None

    return hypernetwork


def load_hypernetworks(names, multipliers=None):
    already_loaded = {}
@@ -359,17 +353,6 @@ def load_hypernetworks(names, multipliers=None):
    shared.loaded_hypernetworks.append(hypernetwork)


def find_closest_hypernetwork_name(search: str):
    if not search:
        return None
    search = search.lower()
    applicable = [name for name in shared.hypernetworks if search in name.lower()]
    if not applicable:
        return None
    applicable = sorted(applicable, key=lambda name: len(name))
    return applicable[0]


def apply_single_hypernetwork(hypernetwork, context_k, context_v, layer=None):
    hypernetwork_layers = (hypernetwork.layers if hypernetwork is not None else {}).get(context_k.shape[2], None)

@@ -404,7 +387,7 @@ def attention_CrossAttention_forward(self, x, context=None, mask=None):
    k = self.to_k(context_k)
    v = self.to_v(context_v)

    q, k, v = map(lambda t: rearrange(t, 'b n (h d) -> (b h) n d', h=h), (q, k, v))
    q, k, v = (rearrange(t, 'b n (h d) -> (b h) n d', h=h) for t in (q, k, v))

    sim = einsum('b i d, b j d -> b i j', q, k) * self.scale

@@ -452,18 +435,6 @@ def statistics(data):
    return total_information, recent_information


def report_statistics(loss_info:dict):
    keys = sorted(loss_info.keys(), key=lambda x: sum(loss_info[x]) / len(loss_info[x]))
    for key in keys:
        try:
            print("Loss statistics for file " + key)
            info, recent = statistics(list(loss_info[key]))
            print(info)
            print(recent)
        except Exception as e:
            print(e)


def create_hypernetwork(name, enable_sizes, overwrite_old, layer_structure=None, activation_func=None, weight_init=None, add_layer_norm=False, use_dropout=False, dropout_structure=None):
    # Remove illegal characters from name.
    name = "".join( x for x in name if (x.isalnum() or x in "._- "))
@@ -541,7 +512,7 @@ def train_hypernetwork(id_task, hypernetwork_name, learn_rate, batch_size, gradi
        return hypernetwork, filename

    scheduler = LearnRateScheduler(learn_rate, steps, initial_step)


    clip_grad = torch.nn.utils.clip_grad_value_ if clip_grad_mode == "value" else torch.nn.utils.clip_grad_norm_ if clip_grad_mode == "norm" else None
    if clip_grad:
        clip_grad_sched = LearnRateScheduler(clip_grad_value, steps, initial_step, verbose=False)
@@ -594,7 +565,7 @@ def train_hypernetwork(id_task, hypernetwork_name, learn_rate, batch_size, gradi
            print(e)

    scaler = torch.cuda.amp.GradScaler()

    batch_size = ds.batch_size
    gradient_step = ds.gradient_step
    # n steps = batch_size * gradient_step * n image processed
@@ -620,7 +591,7 @@ def train_hypernetwork(id_task, hypernetwork_name, learn_rate, batch_size, gradi
    try:
        sd_hijack_checkpoint.add()

        for i in range((steps-initial_step) * gradient_step):
        for _ in range((steps-initial_step) * gradient_step):
            if scheduler.finished:
                break
            if shared.state.interrupted:
@@ -637,7 +608,7 @@ def train_hypernetwork(id_task, hypernetwork_name, learn_rate, batch_size, gradi

            if clip_grad:
                clip_grad_sched.step(hypernetwork.step)

            with devices.autocast():
                x = batch.latent_sample.to(devices.device, non_blocking=pin_memory)
                if use_weight:
@@ -658,14 +629,14 @@ def train_hypernetwork(id_task, hypernetwork_name, learn_rate, batch_size, gradi

            _loss_step += loss.item()
            scaler.scale(loss).backward()

            # go back until we reach gradient accumulation steps
            if (j + 1) % gradient_step != 0:
                continue
            loss_logging.append(_loss_step)
            if clip_grad:
                clip_grad(weights, clip_grad_sched.learn_rate)

            scaler.step(optimizer)
            scaler.update()
            hypernetwork.step += 1
@@ -675,7 +646,7 @@ def train_hypernetwork(id_task, hypernetwork_name, learn_rate, batch_size, gradi
            _loss_step = 0

            steps_done = hypernetwork.step + 1

            epoch_num = hypernetwork.step // steps_per_epoch
            epoch_step = hypernetwork.step % steps_per_epoch

@@ -771,12 +742,11 @@ Last saved image: {html.escape(last_saved_image)}<br/>
</p>
"""
    except Exception:
        print(traceback.format_exc(), file=sys.stderr)
        errors.report("Exception in training hypernetwork", exc_info=True)
    finally:
        pbar.leave = False
        pbar.close()
        hypernetwork.eval()
        #report_statistics(loss_dict)
        sd_hijack_checkpoint.remove()

@@ -1,19 +1,17 @@
import html
import os
import re

import gradio as gr
import modules.hypernetworks.hypernetwork
from modules import devices, sd_hijack, shared

not_available = ["hardswish", "multiheadattention"]
keys = list(x for x in modules.hypernetworks.hypernetwork.HypernetworkModule.activation_dict.keys() if x not in not_available)
keys = [x for x in modules.hypernetworks.hypernetwork.HypernetworkModule.activation_dict if x not in not_available]


def create_hypernetwork(name, enable_sizes, overwrite_old, layer_structure=None, activation_func=None, weight_init=None, add_layer_norm=False, use_dropout=False, dropout_structure=None):
    filename = modules.hypernetworks.hypernetwork.create_hypernetwork(name, enable_sizes, overwrite_old, layer_structure, activation_func, weight_init, add_layer_norm, use_dropout, dropout_structure)

    return gr.Dropdown.update(choices=sorted([x for x in shared.hypernetworks.keys()])), f"Created: {filename}", ""
    return gr.Dropdown.update(choices=sorted(shared.hypernetworks)), f"Created: {filename}", ""


def train_hypernetwork(*args):

@@ -1,6 +1,6 @@
from __future__ import annotations

import datetime
import sys
import traceback

import pytz
import io
@@ -12,18 +12,27 @@ import re
import numpy as np
import piexif
import piexif.helper
from PIL import Image, ImageFont, ImageDraw, PngImagePlugin
from fonts.ttf import Roboto
from PIL import Image, ImageFont, ImageDraw, ImageColor, PngImagePlugin
import string
import json
import hashlib

from modules import sd_samplers, shared, script_callbacks, errors
from modules.shared import opts, cmd_opts
from modules.paths_internal import roboto_ttf_file
from modules.shared import opts

import modules.sd_vae as sd_vae

LANCZOS = (Image.Resampling.LANCZOS if hasattr(Image, 'Resampling') else Image.LANCZOS)


def get_font(fontsize: int):
    try:
        return ImageFont.truetype(opts.font or roboto_ttf_file, fontsize)
    except Exception:
        return ImageFont.truetype(roboto_ttf_file, fontsize)


def image_grid(imgs, batch_size=1, rows=None):
    if rows is None:
        if opts.n_rows > 0:
@@ -132,6 +141,11 @@ class GridAnnotation:


def draw_grid_annotations(im, width, height, hor_texts, ver_texts, margin=0):

    color_active = ImageColor.getcolor(opts.grid_text_active_color, 'RGB')
    color_inactive = ImageColor.getcolor(opts.grid_text_inactive_color, 'RGB')
    color_background = ImageColor.getcolor(opts.grid_background_color, 'RGB')

    def wrap(drawing, text, font, line_length):
        lines = ['']
        for word in text.split():
@@ -142,14 +156,8 @@ def draw_grid_annotations(im, width, height, hor_texts, ver_texts, margin=0):
                lines.append(word)
        return lines

    def get_font(fontsize):
        try:
            return ImageFont.truetype(opts.font or Roboto, fontsize)
        except Exception:
            return ImageFont.truetype(Roboto, fontsize)

    def draw_texts(drawing, draw_x, draw_y, lines, initial_fnt, initial_fontsize):
        for i, line in enumerate(lines):
        for line in lines:
            fnt = initial_fnt
            fontsize = initial_fontsize
            while drawing.multiline_textsize(line.text, font=fnt)[0] > line.allowed_width and fontsize > 0:
@@ -167,9 +175,6 @@ def draw_grid_annotations(im, width, height, hor_texts, ver_texts, margin=0):

    fnt = get_font(fontsize)

    color_active = (0, 0, 0)
    color_inactive = (153, 153, 153)

    pad_left = 0 if sum([sum([len(line.text) for line in lines]) for lines in ver_texts]) == 0 else width * 3 // 4

    cols = im.width // width
@@ -178,7 +183,7 @@ def draw_grid_annotations(im, width, height, hor_texts, ver_texts, margin=0):
    assert cols == len(hor_texts), f'bad number of horizontal texts: {len(hor_texts)}; must be {cols}'
    assert rows == len(ver_texts), f'bad number of vertical texts: {len(ver_texts)}; must be {rows}'

    calc_img = Image.new("RGB", (1, 1), "white")
    calc_img = Image.new("RGB", (1, 1), color_background)
    calc_d = ImageDraw.Draw(calc_img)

    for texts, allowed_width in zip(hor_texts + ver_texts, [width] * len(hor_texts) + [pad_left] * len(ver_texts)):
@@ -199,7 +204,7 @@ def draw_grid_annotations(im, width, height, hor_texts, ver_texts, margin=0):

    pad_top = 0 if sum(hor_text_heights) == 0 else max(hor_text_heights) + line_spacing * 2

    result = Image.new("RGB", (im.width + pad_left + margin * (cols-1), im.height + pad_top + margin * (rows-1)), "white")
    result = Image.new("RGB", (im.width + pad_left + margin * (cols-1), im.height + pad_top + margin * (rows-1)), color_background)

    for row in range(rows):
        for col in range(cols):
@@ -335,8 +340,20 @@ def sanitize_filename_part(text, replace_spaces=True):


class FilenameGenerator:
    def get_vae_filename(self): #get the name of the VAE file.
        if sd_vae.loaded_vae_file is None:
            return "NoneType"
        file_name = os.path.basename(sd_vae.loaded_vae_file)
        split_file_name = file_name.split('.')
        if len(split_file_name) > 1 and split_file_name[0] == '':
            return split_file_name[1] # if the first character of the filename is "." then [1] is obtained.
        else:
            return split_file_name[0]

    replacements = {
        'seed': lambda self: self.seed if self.seed is not None else '',
        'seed_first': lambda self: self.seed if self.p.batch_size == 1 else self.p.all_seeds[0],
        'seed_last': lambda self: NOTHING_AND_SKIP_PREVIOUS_TEXT if self.p.batch_size == 1 else self.p.all_seeds[-1],
        'steps': lambda self: self.p and self.p.steps,
        'cfg': lambda self: self.p and self.p.cfg_scale,
        'width': lambda self: self.image.width,
@@ -353,20 +370,24 @@ class FilenameGenerator:
        'prompt_no_styles': lambda self: self.prompt_no_style(),
        'prompt_spaces': lambda self: sanitize_filename_part(self.prompt, replace_spaces=False),
        'prompt_words': lambda self: self.prompt_words(),
        'batch_number': lambda self: NOTHING_AND_SKIP_PREVIOUS_TEXT if self.p.batch_size == 1 else self.p.batch_index + 1,
        'generation_number': lambda self: NOTHING_AND_SKIP_PREVIOUS_TEXT if self.p.n_iter == 1 and self.p.batch_size == 1 else self.p.iteration * self.p.batch_size + self.p.batch_index + 1,
        'batch_number': lambda self: NOTHING_AND_SKIP_PREVIOUS_TEXT if self.p.batch_size == 1 or self.zip else self.p.batch_index + 1,
        'batch_size': lambda self: self.p.batch_size,
        'generation_number': lambda self: NOTHING_AND_SKIP_PREVIOUS_TEXT if (self.p.n_iter == 1 and self.p.batch_size == 1) or self.zip else self.p.iteration * self.p.batch_size + self.p.batch_index + 1,
        'hasprompt': lambda self, *args: self.hasprompt(*args), # accepts formats:[hasprompt<prompt1|default><prompt2>..]
        'clip_skip': lambda self: opts.data["CLIP_stop_at_last_layers"],
        'denoising': lambda self: self.p.denoising_strength if self.p and self.p.denoising_strength else NOTHING_AND_SKIP_PREVIOUS_TEXT,
        'user': lambda self: self.p.user,
        'vae_filename': lambda self: self.get_vae_filename(),
    }
    default_time_format = '%Y%m%d%H%M%S'

    def __init__(self, p, seed, prompt, image):
    def __init__(self, p, seed, prompt, image, zip=False):
        self.p = p
        self.seed = seed
        self.prompt = prompt
        self.image = image

        self.zip = zip

    def hasprompt(self, *args):
        lower = self.prompt.lower()
        if self.p is None or self.prompt is None:
@@ -389,7 +410,7 @@ class FilenameGenerator:

        prompt_no_style = self.prompt
        for style in shared.prompt_styles.get_style_prompts(self.p.styles):
            if len(style) > 0:
            if style:
                for part in style.split("{prompt}"):
                    prompt_no_style = prompt_no_style.replace(part, "").replace(", ,", ",").strip().strip(',')

@@ -398,7 +419,7 @@ class FilenameGenerator:
        return sanitize_filename_part(prompt_no_style, replace_spaces=False)

    def prompt_words(self):
        words = [x for x in re_nonletters.split(self.prompt or "") if len(x) > 0]
        words = [x for x in re_nonletters.split(self.prompt or "") if x]
        if len(words) == 0:
            words = ["empty"]
        return sanitize_filename_part(" ".join(words[0:opts.directories_max_prompt_words]), replace_spaces=False)
@@ -406,16 +427,16 @@ class FilenameGenerator:
    def datetime(self, *args):
        time_datetime = datetime.datetime.now()

        time_format = args[0] if len(args) > 0 and args[0] != "" else self.default_time_format
        time_format = args[0] if (args and args[0] != "") else self.default_time_format
        try:
            time_zone = pytz.timezone(args[1]) if len(args) > 1 else None
        except pytz.exceptions.UnknownTimeZoneError as _:
        except pytz.exceptions.UnknownTimeZoneError:
            time_zone = None

        time_zone_time = time_datetime.astimezone(time_zone)
        try:
            formatted_time = time_zone_time.strftime(time_format)
        except (ValueError, TypeError) as _:
        except (ValueError, TypeError):
            formatted_time = time_zone_time.strftime(self.default_time_format)

        return sanitize_filename_part(formatted_time, replace_spaces=False)
@@ -445,8 +466,7 @@ class FilenameGenerator:
                replacement = fun(self, *pattern_args)
            except Exception:
                replacement = None
                print(f"Error adding [{pattern}] to filename", file=sys.stderr)
                print(traceback.format_exc(), file=sys.stderr)
                errors.report(f"Error adding [{pattern}] to filename", exc_info=True)

            if replacement == NOTHING_AND_SKIP_PREVIOUS_TEXT:
                continue
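Note: the datetime replacement takes an optional strftime format and an optional pytz timezone name, falling back to the defaults on bad input. A sketch of calling it directly (values illustrative):

from modules.images import FilenameGenerator

gen = FilenameGenerator(p=None, seed=1234, prompt="an example prompt", image=None)
print(gen.datetime())                          # e.g. '20230615221000', the default format
print(gen.datetime('%Y-%m-%d', 'Asia/Tokyo'))  # same moment rendered in Tokyo time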
@@ -472,15 +492,61 @@ def get_next_sequence_number(path, basename):
    prefix_length = len(basename)
    for p in os.listdir(path):
        if p.startswith(basename):
            l = os.path.splitext(p[prefix_length:])[0].split('-')  # splits the filename (removing the basename first if one is defined, so the sequence number is always the first element)
            parts = os.path.splitext(p[prefix_length:])[0].split('-')  # splits the filename (removing the basename first if one is defined, so the sequence number is always the first element)
            try:
                result = max(int(l[0]), result)
                result = max(int(parts[0]), result)
            except ValueError:
                pass

    return result + 1


def save_image_with_geninfo(image, geninfo, filename, extension=None, existing_pnginfo=None, pnginfo_section_name='parameters'):
    """
    Saves image to filename, including geninfo as text information for generation info.
    For PNG images, geninfo is added to existing pnginfo dictionary using the pnginfo_section_name argument as key.
    For JPG images, there's no dictionary and geninfo just replaces the EXIF description.
    """

    if extension is None:
        extension = os.path.splitext(filename)[1]

    image_format = Image.registered_extensions()[extension]

    if extension.lower() == '.png':
        existing_pnginfo = existing_pnginfo or {}
        if opts.enable_pnginfo:
            existing_pnginfo[pnginfo_section_name] = geninfo

        if opts.enable_pnginfo:
            pnginfo_data = PngImagePlugin.PngInfo()
            for k, v in (existing_pnginfo or {}).items():
                pnginfo_data.add_text(k, str(v))
        else:
            pnginfo_data = None

        image.save(filename, format=image_format, quality=opts.jpeg_quality, pnginfo=pnginfo_data)

    elif extension.lower() in (".jpg", ".jpeg", ".webp"):
        if image.mode == 'RGBA':
            image = image.convert("RGB")
        elif image.mode == 'I;16':
            image = image.point(lambda p: p * 0.0038910505836576).convert("RGB" if extension.lower() == ".webp" else "L")

        image.save(filename, format=image_format, quality=opts.jpeg_quality, lossless=opts.webp_lossless)

        if opts.enable_pnginfo and geninfo is not None:
            exif_bytes = piexif.dump({
                "Exif": {
                    piexif.ExifIFD.UserComment: piexif.helper.UserComment.dump(geninfo or "", encoding="unicode")
                },
            })

            piexif.insert(exif_bytes, filename)
    else:
        image.save(filename, format=image_format, quality=opts.jpeg_quality)

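Note: save_image_with_geninfo now carries the embedding logic that _atomically_save_image below delegates to. A minimal direct use inside a running webui (filename and prompt text illustrative):

from PIL import Image

from modules.images import save_image_with_geninfo

img = Image.new("RGB", (64, 64))
# PNG: stored as a 'parameters' tEXt chunk; JPEG/WEBP: stored as EXIF UserComment
save_image_with_geninfo(img, "an example prompt, Steps: 20", "out.png")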
def save_image(image, path, basename, seed=None, prompt=None, extension='png', info=None, short_filename=False, no_prompt=False, grid=False, pnginfo_section_name='parameters', p=None, existing_info=None, forced_filename=None, suffix="", save_to_dirs=None):
"""Save an image.

@@ -565,38 +631,13 @@ def save_image(image, path, basename, seed=None, prompt=None, extension='png', i
info = params.pnginfo.get(pnginfo_section_name, None)

def _atomically_save_image(image_to_save, filename_without_extension, extension):
# save image with .tmp extension to avoid race condition when another process detects new image in the directory
"""
save image with .tmp extension to avoid race condition when another process detects new image in the directory
"""
temp_file_path = f"{filename_without_extension}.tmp"
image_format = Image.registered_extensions()[extension]

if extension.lower() == '.png':
pnginfo_data = PngImagePlugin.PngInfo()
if opts.enable_pnginfo:
for k, v in params.pnginfo.items():
pnginfo_data.add_text(k, str(v))
save_image_with_geninfo(image_to_save, info, temp_file_path, extension, existing_pnginfo=params.pnginfo, pnginfo_section_name=pnginfo_section_name)

image_to_save.save(temp_file_path, format=image_format, quality=opts.jpeg_quality, pnginfo=pnginfo_data)

elif extension.lower() in (".jpg", ".jpeg", ".webp"):
if image_to_save.mode == 'RGBA':
image_to_save = image_to_save.convert("RGB")
elif image_to_save.mode == 'I;16':
image_to_save = image_to_save.point(lambda p: p * 0.0038910505836576).convert("RGB" if extension.lower() == ".webp" else "L")

image_to_save.save(temp_file_path, format=image_format, quality=opts.jpeg_quality, lossless=opts.webp_lossless)

if opts.enable_pnginfo and info is not None:
exif_bytes = piexif.dump({
"Exif": {
piexif.ExifIFD.UserComment: piexif.helper.UserComment.dump(info or "", encoding="unicode")
},
})

piexif.insert(exif_bytes, temp_file_path)
else:
image_to_save.save(temp_file_path, format=image_format, quality=opts.jpeg_quality)

# atomically rename the file with correct extension
os.replace(temp_file_path, filename_without_extension + extension)

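The .tmp-then-os.replace sequence above is the classic atomic-write pattern. A minimal, generic sketch of the same idea, independent of the webui code:

import os

def atomic_write_bytes(data: bytes, filename: str) -> None:
    # write under a temporary name so a watcher process never sees
    # a half-written file at the final path
    temp_path = f"{filename}.tmp"
    with open(temp_path, "wb") as fh:
        fh.write(data)
    # os.replace is atomic on the same filesystem, so the final name
    # appears only once the contents are complete
    os.replace(temp_path, filename)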
fullfn_without_extension, extension = os.path.splitext(params.filename)
@@ -612,12 +653,18 @@ def save_image(image, path, basename, seed=None, prompt=None, extension='png', i
oversize = image.width > opts.target_side_length or image.height > opts.target_side_length
if opts.export_for_4chan and (oversize or os.stat(fullfn).st_size > opts.img_downscale_threshold * 1024 * 1024):
ratio = image.width / image.height

resize_to = None
if oversize and ratio > 1:
image = image.resize((round(opts.target_side_length), round(image.height * opts.target_side_length / image.width)), LANCZOS)
resize_to = round(opts.target_side_length), round(image.height * opts.target_side_length / image.width)
elif oversize:
image = image.resize((round(image.width * opts.target_side_length / image.height), round(opts.target_side_length)), LANCZOS)
resize_to = round(image.width * opts.target_side_length / image.height), round(opts.target_side_length)

if resize_to is not None:
try:
# Resizing image with LANCZOS could throw an exception if e.g. image mode is I;16
image = image.resize(resize_to, LANCZOS)
except Exception:
image = image.resize(resize_to)
try:
_atomically_save_image(image, fullfn_without_extension, ".jpg")
except Exception as e:
@@ -635,8 +682,15 @@ def save_image(image, path, basename, seed=None, prompt=None, extension='png', i
return fullfn, txt_fullfn


def read_info_from_image(image):
items = image.info or {}
IGNORED_INFO_KEYS = {
'jfif', 'jfif_version', 'jfif_unit', 'jfif_density', 'dpi', 'exif',
'loop', 'background', 'timestamp', 'duration', 'progressive', 'progression',
'icc_profile', 'chromaticity', 'photoshop',
}


def read_info_from_image(image: Image.Image) -> tuple[str | None, dict]:
items = (image.info or {}).copy()

geninfo = items.pop('parameters', None)

@@ -652,9 +706,8 @@ def read_info_from_image(image):
items['exif comment'] = exif_comment
geninfo = exif_comment

for field in ['jfif', 'jfif_version', 'jfif_unit', 'jfif_density', 'dpi', 'exif',
'loop', 'background', 'timestamp', 'duration']:
items.pop(field, None)
for field in IGNORED_INFO_KEYS:
items.pop(field, None)

if items.get("Software", None) == "NovelAI":
try:
@@ -665,8 +718,7 @@ def read_info_from_image(image):
Negative prompt: {json_info["uc"]}
Steps: {json_info["steps"]}, Sampler: {sampler}, CFG scale: {json_info["scale"]}, Seed: {json_info["seed"]}, Size: {image.width}x{image.height}, Clip skip: 2, ENSD: 31337"""
except Exception:
print("Error parsing NovelAI image generation parameters:", file=sys.stderr)
print(traceback.format_exc(), file=sys.stderr)
errors.report("Error parsing NovelAI image generation parameters", exc_info=True)

return geninfo, items


@@ -1,23 +1,21 @@
import math
import os
import sys
import traceback
from pathlib import Path

import numpy as np
from PIL import Image, ImageOps, ImageFilter, ImageEnhance, ImageChops, UnidentifiedImageError
import gradio as gr

from modules import devices, sd_samplers
from modules.generation_parameters_copypaste import create_override_settings_dict
from modules import sd_samplers, images as imgutil
from modules.generation_parameters_copypaste import create_override_settings_dict, parse_generation_parameters
from modules.processing import Processed, StableDiffusionProcessingImg2Img, process_images
from modules.shared import opts, state
import modules.shared as shared
import modules.processing as processing
from modules.ui import plaintext_to_html
import modules.images as images
import modules.scripts


def process_batch(p, input_dir, output_dir, inpaint_mask_dir, args):
def process_batch(p, input_dir, output_dir, inpaint_mask_dir, args, to_scale=False, scale_by=1.0, use_png_info=False, png_info_props=None, png_info_dir=None):
processing.fix_seed(p)

images = []
@@ -31,9 +29,10 @@ def process_batch(p, input_dir, output_dir, inpaint_mask_dir, args):
is_inpaint_batch = False
if inpaint_mask_dir:
inpaint_masks = shared.listfiles(inpaint_mask_dir)
is_inpaint_batch = len(inpaint_masks) > 0
if is_inpaint_batch:
print(f"\nInpaint batch is enabled. {len(inpaint_masks)} masks found.")
is_inpaint_batch = bool(inpaint_masks)

if is_inpaint_batch:
print(f"\nInpaint batch is enabled. {len(inpaint_masks)} masks found.")

print(f"Will process {len(images)} images, creating {p.n_iter * p.batch_size} new images for each.")

@@ -44,6 +43,14 @@ def process_batch(p, input_dir, output_dir, inpaint_mask_dir, args):

state.job_count = len(images) * p.n_iter

# extract "default" params to use in case getting png info fails
prompt = p.prompt
negative_prompt = p.negative_prompt
seed = p.seed
cfg_scale = p.cfg_scale
sampler_name = p.sampler_name
steps = p.steps

for i, image in enumerate(images):
state.job = f"{i+1} out of {len(images)}"
if state.skipped:
@@ -59,23 +66,59 @@ def process_batch(p, input_dir, output_dir, inpaint_mask_dir, args):
continue
# Use the EXIF orientation of photos taken by smartphones.
img = ImageOps.exif_transpose(img)

if to_scale:
p.width = int(img.width * scale_by)
p.height = int(img.height * scale_by)

p.init_images = [img] * p.batch_size

image_path = Path(image)
if is_inpaint_batch:
# try to find corresponding mask for an image using simple filename matching
mask_image_path = os.path.join(inpaint_mask_dir, os.path.basename(image))
# if not found use first one ("same mask for all images" use-case)
if not mask_image_path in inpaint_masks:
if len(inpaint_masks) == 1:
mask_image_path = inpaint_masks[0]
else:
# try to find corresponding mask for an image using simple filename matching
mask_image_dir = Path(inpaint_mask_dir)
masks_found = list(mask_image_dir.glob(f"{image_path.stem}.*"))

if len(masks_found) == 0:
print(f"Warning: mask is not found for {image_path} in {mask_image_dir}. Skipping it.")
continue

# it should contain only 1 matching mask
# otherwise user has many masks with the same name but different extensions
mask_image_path = masks_found[0]

mask_image = Image.open(mask_image_path)
p.image_mask = mask_image

if use_png_info:
try:
info_img = img
if png_info_dir:
info_img_path = os.path.join(png_info_dir, os.path.basename(image))
info_img = Image.open(info_img_path)
geninfo, _ = imgutil.read_info_from_image(info_img)
parsed_parameters = parse_generation_parameters(geninfo)
parsed_parameters = {k: v for k, v in parsed_parameters.items() if k in (png_info_props or {})}
except Exception:
parsed_parameters = {}

p.prompt = prompt + (" " + parsed_parameters["Prompt"] if "Prompt" in parsed_parameters else "")
p.negative_prompt = negative_prompt + (" " + parsed_parameters["Negative prompt"] if "Negative prompt" in parsed_parameters else "")
p.seed = int(parsed_parameters.get("Seed", seed))
p.cfg_scale = float(parsed_parameters.get("CFG scale", cfg_scale))
p.sampler_name = parsed_parameters.get("Sampler", sampler_name)
p.steps = int(parsed_parameters.get("Steps", steps))

proc = modules.scripts.scripts_img2img.run(p, *args)
if proc is None:
proc = process_images(p)

for n, processed_image in enumerate(proc.images):
filename = os.path.basename(image)
filename = image_path.name
relpath = os.path.dirname(os.path.relpath(image, input_dir))

if n > 0:
@@ -89,7 +132,7 @@ def process_batch(p, input_dir, output_dir, inpaint_mask_dir, args):
processed_image.save(os.path.join(output_dir, relpath, filename))

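To make the use_png_info fallback above concrete, a reduced sketch of the override logic — the parsed dictionary here is hypothetical, and the key names follow parse_generation_parameters:

# values parsed from the source image win; the p.* values captured
# before the loop act as defaults when a key is missing or parsing failed
parsed_parameters = {"Seed": "1234", "CFG scale": "7.5"}
seed, cfg_scale, steps = 42, 7.0, 20  # defaults captured from p

print(int(parsed_parameters.get("Seed", seed)))              # 1234
print(float(parsed_parameters.get("CFG scale", cfg_scale)))  # 7.5
print(int(parsed_parameters.get("Steps", steps)))            # 20 (key absent)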
def img2img(id_task: str, mode: int, prompt: str, negative_prompt: str, prompt_styles, init_img, sketch, init_img_with_mask, inpaint_color_sketch, inpaint_color_sketch_orig, init_img_inpaint, init_mask_inpaint, steps: int, sampler_index: int, mask_blur: int, mask_alpha: float, inpainting_fill: int, restore_faces: bool, tiling: bool, n_iter: int, batch_size: int, cfg_scale: float, image_cfg_scale: float, denoising_strength: float, seed: int, subseed: int, subseed_strength: float, seed_resize_from_h: int, seed_resize_from_w: int, seed_enable_extras: bool, selected_scale_tab: int, height: int, width: int, scale_by: float, resize_mode: int, inpaint_full_res: bool, inpaint_full_res_padding: int, inpainting_mask_invert: int, img2img_batch_input_dir: str, img2img_batch_output_dir: str, img2img_batch_inpaint_mask_dir: str, override_settings_texts, *args):
def img2img(id_task: str, mode: int, prompt: str, negative_prompt: str, prompt_styles, init_img, sketch, init_img_with_mask, inpaint_color_sketch, inpaint_color_sketch_orig, init_img_inpaint, init_mask_inpaint, steps: int, sampler_index: int, mask_blur: int, mask_alpha: float, inpainting_fill: int, restore_faces: bool, tiling: bool, n_iter: int, batch_size: int, cfg_scale: float, image_cfg_scale: float, denoising_strength: float, seed: int, subseed: int, subseed_strength: float, seed_resize_from_h: int, seed_resize_from_w: int, seed_enable_extras: bool, selected_scale_tab: int, height: int, width: int, scale_by: float, resize_mode: int, inpaint_full_res: bool, inpaint_full_res_padding: int, inpainting_mask_invert: int, img2img_batch_input_dir: str, img2img_batch_output_dir: str, img2img_batch_inpaint_mask_dir: str, override_settings_texts, img2img_batch_use_png_info: bool, img2img_batch_png_info_props: list, img2img_batch_png_info_dir: str, request: gr.Request, *args):
override_settings = create_override_settings_dict(override_settings_texts)

is_batch = mode == 5
@@ -103,7 +146,8 @@ def img2img(id_task: str, mode: int, prompt: str, negative_prompt: str, prompt_s
elif mode == 2:  # inpaint
image, mask = init_img_with_mask["image"], init_img_with_mask["mask"]
alpha_mask = ImageOps.invert(image.split()[-1]).convert('L').point(lambda x: 255 if x > 0 else 0, mode='1')
mask = ImageChops.lighter(alpha_mask, mask.convert('L')).convert('L')
mask = mask.convert('L').point(lambda x: 255 if x > 128 else 0, mode='1')
mask = ImageChops.lighter(alpha_mask, mask).convert('L')
image = image.convert("RGB")
elif mode == 3:  # inpaint sketch
image = inpaint_color_sketch
@@ -125,7 +169,7 @@ def img2img(id_task: str, mode: int, prompt: str, negative_prompt: str, prompt_s
if image is not None:
image = ImageOps.exif_transpose(image)

if selected_scale_tab == 1:
if selected_scale_tab == 1 and not is_batch:
assert image, "Can't scale by because no image is selected"

width = int(image.width * scale_by)
@@ -171,6 +215,8 @@ def img2img(id_task: str, mode: int, prompt: str, negative_prompt: str, prompt_s
p.scripts = modules.scripts.scripts_img2img
p.script_args = args

p.user = request.username

if shared.cmd_opts.enable_console_prompts:
print(f"\nimg2img: {prompt}", file=shared.progress_print_out)

@@ -180,7 +226,7 @@ def img2img(id_task: str, mode: int, prompt: str, negative_prompt: str, prompt_s
if is_batch:
assert not shared.cmd_opts.hide_ui_dir_config, "Launched with --hide-ui-dir-config, batch img2img disabled"

process_batch(p, img2img_batch_input_dir, img2img_batch_output_dir, img2img_batch_inpaint_mask_dir, args)
process_batch(p, img2img_batch_input_dir, img2img_batch_output_dir, img2img_batch_inpaint_mask_dir, args, to_scale=selected_scale_tab == 1, scale_by=scale_by, use_png_info=img2img_batch_use_png_info, png_info_props=img2img_batch_png_info_props, png_info_dir=img2img_batch_png_info_dir)

processed = Processed(p, [], p.seed, "")
else:

@@ -1,6 +1,5 @@
import os
import sys
import traceback
from collections import namedtuple
from pathlib import Path
import re
@@ -11,7 +10,6 @@ import torch.hub
from torchvision import transforms
from torchvision.transforms.functional import InterpolationMode

import modules.shared as shared
from modules import devices, paths, shared, lowvram, modelloader, errors

blip_image_eval_size = 384
@@ -160,7 +158,7 @@ class InterrogateModels:
text_array = text_array[0:int(shared.opts.interrogate_clip_dict_limit)]

top_count = min(top_count, len(text_array))
text_tokens = clip.tokenize([text for text in text_array], truncate=True).to(devices.device_interrogate)
text_tokens = clip.tokenize(list(text_array), truncate=True).to(devices.device_interrogate)
text_features = self.clip_model.encode_text(text_tokens).type(self.dtype)
text_features /= text_features.norm(dim=-1, keepdim=True)

@@ -186,8 +184,7 @@

def interrogate(self, pil_image):
res = ""
shared.state.begin()
shared.state.job = 'interrogate'
shared.state.begin(job="interrogate")
try:
if shared.cmd_opts.lowvram or shared.cmd_opts.medvram:
lowvram.send_everything_to_cpu()
@@ -208,8 +205,8 @@

image_features /= image_features.norm(dim=-1, keepdim=True)

for name, topn, items in self.categories():
matches = self.rank(image_features, items, top_count=topn)
for cat in self.categories():
matches = self.rank(image_features, cat.items, top_count=cat.topn)
for match, score in matches:
if shared.opts.interrogate_return_ranks:
res += f", ({match}:{score/100:.3f})"
@@ -217,8 +214,7 @@
res += f", {match}"

except Exception:
print("Error interrogating", file=sys.stderr)
print(traceback.format_exc(), file=sys.stderr)
errors.report("Error interrogating", exc_info=True)
res += "<error>"

self.unload()

344
modules/launch_utils.py
Normal file
@@ -0,0 +1,344 @@
# this script installs necessary requirements and launches main program in webui.py
import subprocess
import os
import sys
import importlib.util
import platform
import json
from functools import lru_cache

from modules import cmd_args, errors
from modules.paths_internal import script_path, extensions_dir

args, _ = cmd_args.parser.parse_known_args()

python = sys.executable
git = os.environ.get('GIT', "git")
index_url = os.environ.get('INDEX_URL', "")
dir_repos = "repositories"

# Whether to default to printing command output
default_command_live = (os.environ.get('WEBUI_LAUNCH_LIVE_OUTPUT') == "1")

if 'GRADIO_ANALYTICS_ENABLED' not in os.environ:
os.environ['GRADIO_ANALYTICS_ENABLED'] = 'False'


def check_python_version():
is_windows = platform.system() == "Windows"
major = sys.version_info.major
minor = sys.version_info.minor
micro = sys.version_info.micro

if is_windows:
supported_minors = [10]
else:
supported_minors = [7, 8, 9, 10, 11]

if not (major == 3 and minor in supported_minors):
import modules.errors

modules.errors.print_error_explanation(f"""
INCOMPATIBLE PYTHON VERSION

This program is tested with 3.10.6 Python, but you have {major}.{minor}.{micro}.
If you encounter an error with "RuntimeError: Couldn't install torch." message,
or any other error regarding unsuccessful package (library) installation,
please downgrade (or upgrade) to the latest version of 3.10 Python
and delete current Python and "venv" folder in WebUI's directory.

You can download 3.10 Python from here: https://www.python.org/downloads/release/python-3106/

{"Alternatively, use a binary release of WebUI: https://github.com/AUTOMATIC1111/stable-diffusion-webui/releases" if is_windows else ""}

Use --skip-python-version-check to suppress this warning.
""")


@lru_cache()
def commit_hash():
try:
return subprocess.check_output([git, "rev-parse", "HEAD"], shell=False, encoding='utf8').strip()
except Exception:
return "<none>"


@lru_cache()
def git_tag():
try:
return subprocess.check_output([git, "describe", "--tags"], shell=False, encoding='utf8').strip()
except Exception:
try:
from pathlib import Path
changelog_md = Path(__file__).parent.parent / "CHANGELOG.md"
with changelog_md.open(encoding="utf-8") as file:
return next((line.strip() for line in file if line.strip()), "<none>")
except Exception:
return "<none>"


def run(command, desc=None, errdesc=None, custom_env=None, live: bool = default_command_live) -> str:
if desc is not None:
print(desc)

run_kwargs = {
"args": command,
"shell": True,
"env": os.environ if custom_env is None else custom_env,
"encoding": 'utf8',
"errors": 'ignore',
}

if not live:
run_kwargs["stdout"] = run_kwargs["stderr"] = subprocess.PIPE

result = subprocess.run(**run_kwargs)

if result.returncode != 0:
error_bits = [
f"{errdesc or 'Error running command'}.",
f"Command: {command}",
f"Error code: {result.returncode}",
]
if result.stdout:
error_bits.append(f"stdout: {result.stdout}")
if result.stderr:
error_bits.append(f"stderr: {result.stderr}")
raise RuntimeError("\n".join(error_bits))

return (result.stdout or "")

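A hedged example of calling run() as the rest of this file does — the command itself is illustrative:

# with live=False the output is captured and returned; a non-zero exit
# raises RuntimeError carrying the command, exit code, stdout and stderr
output = run("git status", desc="Checking repository state", errdesc="Couldn't get git status", live=False)
print(output)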
def is_installed(package):
try:
spec = importlib.util.find_spec(package)
except ModuleNotFoundError:
return False

return spec is not None


def repo_dir(name):
return os.path.join(script_path, dir_repos, name)


def run_pip(command, desc=None, live=default_command_live):
if args.skip_install:
return

index_url_line = f' --index-url {index_url}' if index_url != '' else ''
return run(f'"{python}" -m pip {command} --prefer-binary{index_url_line}', desc=f"Installing {desc}", errdesc=f"Couldn't install {desc}", live=live)


def check_run_python(code: str) -> bool:
result = subprocess.run([python, "-c", code], capture_output=True, shell=False)
return result.returncode == 0


def git_clone(url, dir, name, commithash=None):
# TODO clone into temporary dir and move if successful

if os.path.exists(dir):
if commithash is None:
return

current_hash = run(f'"{git}" -C "{dir}" rev-parse HEAD', None, f"Couldn't determine {name}'s hash: {commithash}", live=False).strip()
if current_hash == commithash:
return

run(f'"{git}" -C "{dir}" fetch', f"Fetching updates for {name}...", f"Couldn't fetch {name}")
run(f'"{git}" -C "{dir}" checkout {commithash}', f"Checking out commit for {name} with hash: {commithash}...", f"Couldn't checkout commit {commithash} for {name}", live=True)
return

run(f'"{git}" clone "{url}" "{dir}"', f"Cloning {name} into {dir}...", f"Couldn't clone {name}", live=True)

if commithash is not None:
run(f'"{git}" -C "{dir}" checkout {commithash}', None, f"Couldn't checkout {name}'s hash: {commithash}")

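For illustration, prepare_environment() below uses this helper to pin each repository to an exact commit; a representative call (values taken from the defaults further down):

# clones only when the directory is missing; if it exists but sits on a
# different commit than the pinned hash, fetches and checks that hash out
git_clone(
    "https://github.com/crowsonkb/k-diffusion.git",
    repo_dir('k-diffusion'),
    "K-diffusion",
    commithash="c9fe758757e022f05ca5a53fa8fac28889e4f1cf",
)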
def git_pull_recursive(dir):
for subdir, _, _ in os.walk(dir):
if os.path.exists(os.path.join(subdir, '.git')):
try:
output = subprocess.check_output([git, '-C', subdir, 'pull', '--autostash'])
print(f"Pulled changes for repository in '{subdir}':\n{output.decode('utf-8').strip()}\n")
except subprocess.CalledProcessError as e:
print(f"Couldn't perform 'git pull' on repository in '{subdir}':\n{e.output.decode('utf-8').strip()}\n")


def version_check(commit):
try:
import requests
commits = requests.get('https://api.github.com/repos/AUTOMATIC1111/stable-diffusion-webui/branches/master').json()
if commit != "<none>" and commits['commit']['sha'] != commit:
print("--------------------------------------------------------")
print("| You are not up to date with the most recent release. |")
print("| Consider running `git pull` to update.                |")
print("--------------------------------------------------------")
elif commits['commit']['sha'] == commit:
print("You are up to date with the most recent release.")
else:
print("Not a git clone, can't perform version check.")
except Exception as e:
print("version check failed", e)


def run_extension_installer(extension_dir):
path_installer = os.path.join(extension_dir, "install.py")
if not os.path.isfile(path_installer):
return

try:
env = os.environ.copy()
env['PYTHONPATH'] = os.path.abspath(".")

print(run(f'"{python}" "{path_installer}"', errdesc=f"Error running install.py for extension {extension_dir}", custom_env=env))
except Exception as e:
errors.report(str(e))


def list_extensions(settings_file):
settings = {}

try:
if os.path.isfile(settings_file):
with open(settings_file, "r", encoding="utf8") as file:
settings = json.load(file)
except Exception:
errors.report("Could not load settings", exc_info=True)

disabled_extensions = set(settings.get('disabled_extensions', []))
disable_all_extensions = settings.get('disable_all_extensions', 'none')

if disable_all_extensions != 'none':
return []

return [x for x in os.listdir(extensions_dir) if x not in disabled_extensions]


def run_extensions_installers(settings_file):
if not os.path.isdir(extensions_dir):
return

for dirname_extension in list_extensions(settings_file):
run_extension_installer(os.path.join(extensions_dir, dirname_extension))


def prepare_environment():
torch_index_url = os.environ.get('TORCH_INDEX_URL', "https://download.pytorch.org/whl/cu118")
torch_command = os.environ.get('TORCH_COMMAND', f"pip install torch==2.0.1 torchvision==0.15.2 --extra-index-url {torch_index_url}")
requirements_file = os.environ.get('REQS_FILE', "requirements_versions.txt")

xformers_package = os.environ.get('XFORMERS_PACKAGE', 'xformers==0.0.20')
gfpgan_package = os.environ.get('GFPGAN_PACKAGE', "https://github.com/TencentARC/GFPGAN/archive/8d2447a2d918f8eba5a4a01463fd48e45126a379.zip")
clip_package = os.environ.get('CLIP_PACKAGE', "https://github.com/openai/CLIP/archive/d50d76daa670286dd6cacf3bcd80b5e4823fc8e1.zip")
openclip_package = os.environ.get('OPENCLIP_PACKAGE', "https://github.com/mlfoundations/open_clip/archive/bb6e834e9c70d9c27d0dc3ecedeebeaeb1ffad6b.zip")

stable_diffusion_repo = os.environ.get('STABLE_DIFFUSION_REPO', "https://github.com/Stability-AI/stablediffusion.git")
k_diffusion_repo = os.environ.get('K_DIFFUSION_REPO', 'https://github.com/crowsonkb/k-diffusion.git')
codeformer_repo = os.environ.get('CODEFORMER_REPO', 'https://github.com/sczhou/CodeFormer.git')
blip_repo = os.environ.get('BLIP_REPO', 'https://github.com/salesforce/BLIP.git')

stable_diffusion_commit_hash = os.environ.get('STABLE_DIFFUSION_COMMIT_HASH', "cf1d67a6fd5ea1aa600c4df58e5b47da45f6bdbf")
k_diffusion_commit_hash = os.environ.get('K_DIFFUSION_COMMIT_HASH', "c9fe758757e022f05ca5a53fa8fac28889e4f1cf")
codeformer_commit_hash = os.environ.get('CODEFORMER_COMMIT_HASH', "c5b4593074ba6214284d6acd5f1719b6c5d739af")
blip_commit_hash = os.environ.get('BLIP_COMMIT_HASH', "48211a1594f1321b00f14c9f7a5b4813144b2fb9")

try:
# the existence of this file is a signal to webui.sh/bat that webui needs to be restarted when it stops execution
os.remove(os.path.join(script_path, "tmp", "restart"))
os.environ.setdefault('SD_WEBUI_RESTARTING', '1')
except OSError:
pass

if not args.skip_python_version_check:
check_python_version()

commit = commit_hash()
tag = git_tag()

print(f"Python {sys.version}")
print(f"Version: {tag}")
print(f"Commit hash: {commit}")

if args.reinstall_torch or not is_installed("torch") or not is_installed("torchvision"):
run(f'"{python}" -m {torch_command}', "Installing torch and torchvision", "Couldn't install torch", live=True)

if not args.skip_torch_cuda_test and not check_run_python("import torch; assert torch.cuda.is_available()"):
raise RuntimeError(
'Torch is not able to use GPU; '
'add --skip-torch-cuda-test to COMMANDLINE_ARGS variable to disable this check'
)

if not is_installed("gfpgan"):
run_pip(f"install {gfpgan_package}", "gfpgan")

if not is_installed("clip"):
run_pip(f"install {clip_package}", "clip")

if not is_installed("open_clip"):
run_pip(f"install {openclip_package}", "open_clip")

if (not is_installed("xformers") or args.reinstall_xformers) and args.xformers:
if platform.system() == "Windows":
if platform.python_version().startswith("3.10"):
run_pip(f"install -U -I --no-deps {xformers_package}", "xformers", live=True)
else:
print("Installation of xformers is not supported in this version of Python.")
print("You can also check this and build manually: https://github.com/AUTOMATIC1111/stable-diffusion-webui/wiki/Xformers#building-xformers-on-windows-by-duckness")
if not is_installed("xformers"):
exit(0)
elif platform.system() == "Linux":
run_pip(f"install -U -I --no-deps {xformers_package}", "xformers")

if not is_installed("ngrok") and args.ngrok:
run_pip("install ngrok", "ngrok")

os.makedirs(os.path.join(script_path, dir_repos), exist_ok=True)

git_clone(stable_diffusion_repo, repo_dir('stable-diffusion-stability-ai'), "Stable Diffusion", stable_diffusion_commit_hash)
git_clone(k_diffusion_repo, repo_dir('k-diffusion'), "K-diffusion", k_diffusion_commit_hash)
git_clone(codeformer_repo, repo_dir('CodeFormer'), "CodeFormer", codeformer_commit_hash)
git_clone(blip_repo, repo_dir('BLIP'), "BLIP", blip_commit_hash)

if not is_installed("lpips"):
run_pip(f"install -r \"{os.path.join(repo_dir('CodeFormer'), 'requirements.txt')}\"", "requirements for CodeFormer")

if not os.path.isfile(requirements_file):
requirements_file = os.path.join(script_path, requirements_file)
run_pip(f"install -r \"{requirements_file}\"", "requirements")

run_extensions_installers(settings_file=args.ui_settings_file)

if args.update_check:
version_check(commit)

if args.update_all_extensions:
git_pull_recursive(extensions_dir)

if "--exit" in sys.argv:
print("Exiting because of --exit argument")
exit(0)


def configure_for_tests():
if "--api" not in sys.argv:
sys.argv.append("--api")
if "--ckpt" not in sys.argv:
sys.argv.append("--ckpt")
sys.argv.append(os.path.join(script_path, "test/test_files/empty.pt"))
if "--skip-torch-cuda-test" not in sys.argv:
sys.argv.append("--skip-torch-cuda-test")
if "--disable-nan-check" not in sys.argv:
sys.argv.append("--disable-nan-check")

os.environ['COMMANDLINE_ARGS'] = ""


def start():
print(f"Launching {'API server' if '--nowebui' in sys.argv else 'Web UI'} with arguments: {' '.join(sys.argv[1:])}")
import webui
if '--nowebui' in sys.argv:
webui.api_only()
else:
webui.webui()

@@ -1,8 +1,7 @@
import json
import os
import sys
import traceback

from modules import errors

localizations = {}

@@ -31,7 +30,6 @@ def localization_js(current_localization_name: str) -> str:
with open(fn, "r", encoding="utf8") as file:
data = json.load(file)
except Exception:
print(f"Error loading localization from {fn}:", file=sys.stderr)
print(traceback.format_exc(), file=sys.stderr)
errors.report(f"Error loading localization from {fn}", exc_info=True)

return f"window.localization = {json.dumps(data)}"

@@ -15,6 +15,8 @@ def send_everything_to_cpu():


def setup_for_low_vram(sd_model, use_medvram):
sd_model.lowvram = True

parents = {}

def send_me_to_gpu(module, _):
@@ -96,3 +98,7 @@ def setup_for_low_vram(sd_model, use_medvram):
diff_model.middle_block.register_forward_pre_hook(send_me_to_gpu)
for block in diff_model.output_blocks:
block.register_forward_pre_hook(send_me_to_gpu)


def is_enabled(sd_model):
return getattr(sd_model, 'lowvram', False)

@@ -1,20 +1,24 @@
import torch
import platform
from modules import paths
from modules.sd_hijack_utils import CondFunc
from packaging import version


# has_mps is only available in nightly pytorch (for now) and macOS 12.3+.
# check `getattr` and try it for compatibility
# before torch version 1.13, has_mps is only available in nightly pytorch and macOS 12.3+,
# so check `getattr` and try it for compatibility.
# in torch version 1.13, backends.mps.is_available() and backends.mps.is_built() were introduced to check MPS availability,
# since torch 2.0.1+ nightly build, getattr(torch, 'has_mps', False) was deprecated, see https://github.com/pytorch/pytorch/pull/103279
def check_for_mps() -> bool:
if not getattr(torch, 'has_mps', False):
return False
try:
torch.zeros(1).to(torch.device("mps"))
return True
except Exception:
return False
if version.parse(torch.__version__) <= version.parse("2.0.1"):
if not getattr(torch, 'has_mps', False):
return False
try:
torch.zeros(1).to(torch.device("mps"))
return True
except Exception:
return False
else:
return torch.backends.mps.is_available() and torch.backends.mps.is_built()
has_mps = check_for_mps()

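A small sketch of how the module-level has_mps flag is typically consumed for device selection (the selection itself is an assumption, not shown in this diff):

import torch

# fall back to CPU when MPS is unavailable or the build lacks it
device = torch.device("mps") if has_mps else torch.device("cpu")
x = torch.zeros(4, device=device)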
@@ -43,7 +47,7 @@ if has_mps:
# MPS workaround for https://github.com/pytorch/pytorch/issues/79383
CondFunc('torch.Tensor.to', lambda orig_func, self, *args, **kwargs: orig_func(self.contiguous(), *args, **kwargs),
lambda _, self, *args, **kwargs: self.device.type != 'mps' and (args and isinstance(args[0], torch.device) and args[0].type == 'mps' or isinstance(kwargs.get('device'), torch.device) and kwargs['device'].type == 'mps'))
# MPS workaround for https://github.com/pytorch/pytorch/issues/80800
# MPS workaround for https://github.com/pytorch/pytorch/issues/80800
CondFunc('torch.nn.functional.layer_norm', lambda orig_func, *args, **kwargs: orig_func(*([args[0].contiguous()] + list(args[1:])), **kwargs),
lambda _, *args, **kwargs: args and isinstance(args[0], torch.Tensor) and args[0].device.type == 'mps')
# MPS workaround for https://github.com/pytorch/pytorch/issues/90532
@@ -61,4 +65,4 @@ if has_mps:
# MPS workaround for https://github.com/pytorch/pytorch/issues/92311
if platform.processor() == 'i386':
for funcName in ['torch.argmax', 'torch.Tensor.argmax']:
CondFunc(funcName, lambda _, input, *args, **kwargs: torch.max(input.float() if input.dtype == torch.int64 else input, *args, **kwargs)[1], lambda _, input, *args, **kwargs: input.device.type == 'mps')
CondFunc(funcName, lambda _, input, *args, **kwargs: torch.max(input.float() if input.dtype == torch.int64 else input, *args, **kwargs)[1], lambda _, input, *args, **kwargs: input.device.type == 'mps')

@@ -4,7 +4,7 @@ from PIL import Image, ImageFilter, ImageOps
def get_crop_region(mask, pad=0):
"""finds a rectangular region that contains all masked areas in an image. Returns (x1, y1, x2, y2) coordinates of the rectangle.
For example, if a user has painted the top-right part of a 512x512 image, the result may be (256, 0, 512, 256)"""


h, w = mask.shape

crop_left = 0

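A worked version of the docstring's example, assuming mask is a 2D numpy array with non-zero pixels where the user painted:

import numpy as np

mask = np.zeros((512, 512), dtype=np.uint8)
mask[0:256, 256:512] = 255  # paint the top-right quadrant
# the bounding box of the painted region, per the docstring,
# should come out as (256, 0, 512, 256) with pad=0
print(get_crop_region(mask))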
@@ -1,4 +1,5 @@
import glob
from __future__ import annotations

import os
import shutil
import importlib
@@ -9,6 +10,29 @@ from modules.upscaler import Upscaler, UpscalerLanczos, UpscalerNearest, Upscale
from modules.paths import script_path, models_path


def load_file_from_url(
url: str,
*,
model_dir: str,
progress: bool = True,
file_name: str | None = None,
) -> str:
"""Download a file from `url` into `model_dir`, using the file present if possible.

Returns the path to the downloaded file.
"""
os.makedirs(model_dir, exist_ok=True)
if not file_name:
parts = urlparse(url)
file_name = os.path.basename(parts.path)
cached_file = os.path.abspath(os.path.join(model_dir, file_name))
if not os.path.exists(cached_file):
print(f'Downloading: "{url}" to {cached_file}\n')
from torch.hub import download_url_to_file
download_url_to_file(url, cached_file, progress=progress)
return cached_file

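A hypothetical call to the new keyword-only helper above — the URL and target directory are placeholders:

# downloads on first use, then returns the cached copy on later calls
path = load_file_from_url(
    "https://example.com/models/ESRGAN_x4.pth",
    model_dir="models/ESRGAN",
)
print(path)  # absolute path to models/ESRGAN/ESRGAN_x4.pth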
def load_models(model_path: str, model_url: str = None, command_path: str = None, ext_filter=None, download_name=None, ext_blacklist=None) -> list:
"""
A one-and-done loader to try finding the desired models in specified directories.
@@ -40,16 +64,14 @@ def load_models(model_path: str, model_url: str = None, command_path: str = None
if os.path.islink(full_path) and not os.path.exists(full_path):
print(f"Skipping broken symlink: {full_path}")
continue
if ext_blacklist is not None and any([full_path.endswith(x) for x in ext_blacklist]):
if ext_blacklist is not None and any(full_path.endswith(x) for x in ext_blacklist):
continue
if full_path not in output:
output.append(full_path)

if model_url is not None and len(output) == 0:
if download_name is not None:
from basicsr.utils.download_util import load_file_from_url
dl = load_file_from_url(model_url, model_path, True, download_name)
output.append(dl)
output.append(load_file_from_url(model_url, model_dir=places[0], file_name=download_name))
else:
output.append(model_url)

@@ -60,7 +82,7 @@ def load_models(model_path: str, model_url: str = None, command_path: str = None


def friendly_name(file: str):
if "http" in file:
if file.startswith("http"):
file = urlparse(file).path

file = os.path.basename(file)
@@ -96,8 +118,7 @@ def cleanup_models():

def move_files(src_path: str, dest_path: str, ext_filter: str = None):
try:
if not os.path.exists(dest_path):
os.makedirs(dest_path)
os.makedirs(dest_path, exist_ok=True)
if os.path.exists(src_path):
for file in os.listdir(src_path):
fullpath = os.path.join(src_path, file)
@@ -108,12 +129,12 @@ def move_files(src_path: str, dest_path: str, ext_filter: str = None):
print(f"Moving {file} from {src_path} to {dest_path}.")
try:
shutil.move(fullpath, dest_path)
except:
except Exception:
pass
if len(os.listdir(src_path)) == 0:
print(f"Removing empty folder: {src_path}")
shutil.rmtree(src_path, True)
except:
except Exception:
pass


@@ -127,7 +148,7 @@ def load_upscalers():
full_model = f"modules.{model_name}_model"
try:
importlib.import_module(full_model)
except:
except Exception:
pass

datas = []
@@ -145,7 +166,10 @@ def load_upscalers():
for cls in reversed(used_classes.values()):
name = cls.__name__
cmd_name = f"{name.lower().replace('upscaler', '')}_models_path"
scaler = cls(commandline_options.get(cmd_name, None))
commandline_model_path = commandline_options.get(cmd_name, None)
scaler = cls(commandline_model_path)
scaler.user_path = commandline_model_path
scaler.model_download_path = commandline_model_path or scaler.model_path
datas += scaler.scalers

shared.sd_upscalers = sorted(

@@ -52,7 +52,7 @@ class DDPM(pl.LightningModule):
beta_schedule="linear",
loss_type="l2",
ckpt_path=None,
ignore_keys=[],
ignore_keys=None,
load_only_unet=False,
monitor="val/loss",
use_ema=True,
@@ -107,7 +107,7 @@ class DDPM(pl.LightningModule):
print(f"Keeping EMAs of {len(list(self.model_ema.buffers()))}.")

if ckpt_path is not None:
self.init_from_ckpt(ckpt_path, ignore_keys=ignore_keys, only_model=load_only_unet)
self.init_from_ckpt(ckpt_path, ignore_keys=ignore_keys or [], only_model=load_only_unet)

# If initializing from EMA-only checkpoint, create EMA model after loading.
if self.use_ema and not load_ema:
@@ -194,7 +194,9 @@
if context is not None:
print(f"{context}: Restored training weights")

def init_from_ckpt(self, path, ignore_keys=list(), only_model=False):
def init_from_ckpt(self, path, ignore_keys=None, only_model=False):
ignore_keys = ignore_keys or []

sd = torch.load(path, map_location="cpu")
if "state_dict" in list(sd.keys()):
sd = sd["state_dict"]
@@ -228,9 +230,9 @@
missing, unexpected = self.load_state_dict(sd, strict=False) if not only_model else self.model.load_state_dict(
sd, strict=False)
print(f"Restored from {path} with {len(missing)} missing and {len(unexpected)} unexpected keys")
if len(missing) > 0:
if missing:
print(f"Missing Keys: {missing}")
if len(unexpected) > 0:
if unexpected:
print(f"Unexpected Keys: {unexpected}")

def q_mean_variance(self, x_start, t):
@@ -403,7 +405,7 @@

@torch.no_grad()
def log_images(self, batch, N=8, n_row=2, sample=True, return_keys=None, **kwargs):
log = dict()
log = {}
x = self.get_input(batch, self.first_stage_key)
N = min(x.shape[0], N)
n_row = min(x.shape[0], n_row)
@@ -411,7 +413,7 @@
log["inputs"] = x

# get diffusion row
diffusion_row = list()
diffusion_row = []
x_start = x[:n_row]

for t in range(self.num_timesteps):
@@ -473,13 +475,13 @@ class LatentDiffusion(DDPM):
conditioning_key = None
ckpt_path = kwargs.pop("ckpt_path", None)
ignore_keys = kwargs.pop("ignore_keys", [])
super().__init__(conditioning_key=conditioning_key, *args, load_ema=load_ema, **kwargs)
super().__init__(*args, conditioning_key=conditioning_key, load_ema=load_ema, **kwargs)
self.concat_mode = concat_mode
self.cond_stage_trainable = cond_stage_trainable
self.cond_stage_key = cond_stage_key
try:
self.num_downs = len(first_stage_config.params.ddconfig.ch_mult) - 1
except:
except Exception:
self.num_downs = 0
if not scale_by_std:
self.scale_factor = scale_factor
@@ -891,16 +893,6 @@
c = self.q_sample(x_start=c, t=tc, noise=torch.randn_like(c.float()))
return self.p_losses(x, c, t, *args, **kwargs)

def _rescale_annotations(self, bboxes, crop_coordinates):  # TODO: move to dataset
def rescale_bbox(bbox):
x0 = clamp((bbox[0] - crop_coordinates[0]) / crop_coordinates[2])
y0 = clamp((bbox[1] - crop_coordinates[1]) / crop_coordinates[3])
w = min(bbox[2] / crop_coordinates[2], 1 - x0)
h = min(bbox[3] / crop_coordinates[3], 1 - y0)
return x0, y0, w, h

return [rescale_bbox(b) for b in bboxes]

def apply_model(self, x_noisy, t, cond, return_ids=False):

if isinstance(cond, dict):
@@ -1140,7 +1132,7 @@
if cond is not None:
if isinstance(cond, dict):
cond = {key: cond[key][:batch_size] if not isinstance(cond[key], list) else
list(map(lambda x: x[:batch_size], cond[key])) for key in cond}
[x[:batch_size] for x in cond[key]] for key in cond}
else:
cond = [c[:batch_size] for c in cond] if isinstance(cond, list) else cond[:batch_size]

@@ -1171,8 +1163,10 @@

if i % log_every_t == 0 or i == timesteps - 1:
intermediates.append(x0_partial)
if callback: callback(i)
if img_callback: img_callback(img, i)
if callback:
callback(i)
if img_callback:
img_callback(img, i)
return img, intermediates

@torch.no_grad()
@@ -1219,8 +1213,10 @@

if i % log_every_t == 0 or i == timesteps - 1:
intermediates.append(img)
if callback: callback(i)
if img_callback: img_callback(img, i)
if callback:
callback(i)
if img_callback:
img_callback(img, i)

if return_intermediates:
return img, intermediates
@@ -1235,7 +1231,7 @@
if cond is not None:
if isinstance(cond, dict):
cond = {key: cond[key][:batch_size] if not isinstance(cond[key], list) else
list(map(lambda x: x[:batch_size], cond[key])) for key in cond}
[x[:batch_size] for x in cond[key]] for key in cond}
else:
cond = [c[:batch_size] for c in cond] if isinstance(cond, list) else cond[:batch_size]
return self.p_sample_loop(cond,
@@ -1267,7 +1263,7 @@

use_ddim = False

log = dict()
log = {}
z, c, x, xrec, xc = self.get_input(batch, self.first_stage_key,
return_first_stage_outputs=True,
force_c_encode=True,
@@ -1295,7 +1291,7 @@

if plot_diffusion_rows:
# get diffusion row
diffusion_row = list()
diffusion_row = []
z_start = z[:n_row]
for t in range(self.num_timesteps):
if t % self.log_every_t == 0 or t == self.num_timesteps - 1:
@@ -1337,7 +1333,7 @@

if inpaint:
# make a simple center square
b, h, w = z.shape[0], z.shape[2], z.shape[3]
h, w = z.shape[2], z.shape[3]
mask = torch.ones(N, h, w).to(self.device)
# zeros will be filled in
mask[:, h // 4:3 * h // 4, w // 4:3 * w // 4] = 0.
@@ -1439,10 +1435,10 @@ class Layout2ImgDiffusion(LatentDiffusion):
# TODO: move all layout-specific hacks to this class
def __init__(self, cond_stage_key, *args, **kwargs):
assert cond_stage_key == 'coordinates_bbox', 'Layout2ImgDiffusion only for cond_stage_key="coordinates_bbox"'
super().__init__(cond_stage_key=cond_stage_key, *args, **kwargs)
super().__init__(*args, cond_stage_key=cond_stage_key, **kwargs)

def log_images(self, batch, N=8, *args, **kwargs):
logs = super().log_images(batch=batch, N=N, *args, **kwargs)
logs = super().log_images(*args, batch=batch, N=N, **kwargs)

key = 'train' if self.training else 'validation'
dset = self.trainer.datamodule.datasets[key]

@@ -1 +1 @@
from .sampler import UniPCSampler
from .sampler import UniPCSampler  # noqa: F401

@@ -54,7 +54,8 @@ class UniPCSampler(object):
if conditioning is not None:
if isinstance(conditioning, dict):
ctmp = conditioning[list(conditioning.keys())[0]]
while isinstance(ctmp, list): ctmp = ctmp[0]
while isinstance(ctmp, list):
ctmp = ctmp[0]
cbs = ctmp.shape[0]
if cbs != batch_size:
print(f"Warning: Got {cbs} conditionings but batch-size is {batch_size}")

@@ -1,7 +1,6 @@
import torch
import torch.nn.functional as F
import math
from tqdm.auto import trange
import tqdm


class NoiseScheduleVP:
@@ -179,13 +178,13 @@ def model_wrapper(
model,
noise_schedule,
model_type="noise",
model_kwargs={},
model_kwargs=None,
guidance_type="uncond",
#condition=None,
#unconditional_condition=None,
guidance_scale=1.,
classifier_fn=None,
classifier_kwargs={},
classifier_kwargs=None,
):
"""Create a wrapper function for the noise prediction model.

@@ -276,6 +275,9 @@
A noise prediction model that accepts the noised data and the continuous time as the inputs.
"""

model_kwargs = model_kwargs or {}
classifier_kwargs = classifier_kwargs or {}

def get_model_input_time(t_continuous):
"""
Convert the continuous-time `t_continuous` (in [epsilon, T]) to the model input time.
@@ -342,7 +344,7 @@
t_in = torch.cat([t_continuous] * 2)
if isinstance(condition, dict):
assert isinstance(unconditional_condition, dict)
c_in = dict()
c_in = {}
for k in condition:
if isinstance(condition[k], list):
c_in[k] = [torch.cat([
@@ -353,7 +355,7 @@
unconditional_condition[k],
condition[k]])
elif isinstance(condition, list):
c_in = list()
c_in = []
assert isinstance(unconditional_condition, list)
for i in range(len(condition)):
c_in.append(torch.cat([unconditional_condition[i], condition[i]]))
@@ -757,40 +759,44 @@ class UniPC:
vec_t = timesteps[0].expand((x.shape[0]))
model_prev_list = [self.model_fn(x, vec_t)]
t_prev_list = [vec_t]
# Init the first `order` values by lower order multistep DPM-Solver.
for init_order in range(1, order):
vec_t = timesteps[init_order].expand(x.shape[0])
x, model_x = self.multistep_uni_pc_update(x, model_prev_list, t_prev_list, vec_t, init_order, use_corrector=True)
if model_x is None:
model_x = self.model_fn(x, vec_t)
if self.after_update is not None:
self.after_update(x, model_x)
model_prev_list.append(model_x)
t_prev_list.append(vec_t)
for step in trange(order, steps + 1):
vec_t = timesteps[step].expand(x.shape[0])
if lower_order_final:
step_order = min(order, steps + 1 - step)
else:
step_order = order
#print('this step order:', step_order)
if step == steps:
#print('do not run corrector at the last step')
use_corrector = False
else:
use_corrector = True
x, model_x = self.multistep_uni_pc_update(x, model_prev_list, t_prev_list, vec_t, step_order, use_corrector=use_corrector)
if self.after_update is not None:
self.after_update(x, model_x)
for i in range(order - 1):
t_prev_list[i] = t_prev_list[i + 1]
model_prev_list[i] = model_prev_list[i + 1]
t_prev_list[-1] = vec_t
# We do not need to evaluate the final model value.
if step < steps:
with tqdm.tqdm(total=steps) as pbar:
# Init the first `order` values by lower order multistep DPM-Solver.
for init_order in range(1, order):
vec_t = timesteps[init_order].expand(x.shape[0])
x, model_x = self.multistep_uni_pc_update(x, model_prev_list, t_prev_list, vec_t, init_order, use_corrector=True)
if model_x is None:
model_x = self.model_fn(x, vec_t)
model_prev_list[-1] = model_x
if self.after_update is not None:
self.after_update(x, model_x)
model_prev_list.append(model_x)
t_prev_list.append(vec_t)
pbar.update()

for step in range(order, steps + 1):
vec_t = timesteps[step].expand(x.shape[0])
if lower_order_final:
step_order = min(order, steps + 1 - step)
else:
step_order = order
#print('this step order:', step_order)
if step == steps:
#print('do not run corrector at the last step')
use_corrector = False
else:
use_corrector = True
x, model_x = self.multistep_uni_pc_update(x, model_prev_list, t_prev_list, vec_t, step_order, use_corrector=use_corrector)
if self.after_update is not None:
self.after_update(x, model_x)
for i in range(order - 1):
t_prev_list[i] = t_prev_list[i + 1]
model_prev_list[i] = model_prev_list[i + 1]
t_prev_list[-1] = vec_t
# We do not need to evaluate the final model value.
if step < steps:
if model_x is None:
model_x = self.model_fn(x, vec_t)
model_prev_list[-1] = model_x
pbar.update()
else:
raise NotImplementedError()
if denoise_to_zero:

@@ -1,6 +1,7 @@
from pyngrok import ngrok, conf, exception
import ngrok

def connect(token, port, region):
# Connect to ngrok for ingress
def connect(token, port, options):
account = None
if token is None:
token = 'None'
@@ -10,28 +11,19 @@ def connect(token, port, region):
token, username, password = token.split(':', 2)
account = f"{username}:{password}"

config = conf.PyngrokConfig(
auth_token=token, region=region
)

# Guard for existing tunnels
existing = ngrok.get_tunnels(pyngrok_config=config)
if existing:
for established in existing:
# Extra configuration in the case that the user is also using ngrok for other tunnels
if established.config['addr'][-4:] == str(port):
public_url = existing[0].public_url
print(f'ngrok has already been connected to localhost:{port}! URL: {public_url}\n'
'You can use this link after the launch is complete.')
return

# For all options see: https://github.com/ngrok/ngrok-py/blob/main/examples/ngrok-connect-full.py
if not options.get('authtoken_from_env'):
options['authtoken'] = token
if account:
options['basic_auth'] = account
if not options.get('session_metadata'):
options['session_metadata'] = 'stable-diffusion-webui'


try:
if account is None:
public_url = ngrok.connect(port, pyngrok_config=config, bind_tls=True).public_url
else:
public_url = ngrok.connect(port, pyngrok_config=config, bind_tls=True, auth=account).public_url
except exception.PyngrokNgrokError:
print(f'Invalid ngrok authtoken, ngrok connection aborted.\n'
public_url = ngrok.connect(f"127.0.0.1:{port}", **options).url()
except Exception as e:
print(f'Invalid ngrok authtoken? ngrok connection aborted due to: {e}\n'
f'Your token: {token}, get the right one on https://dashboard.ngrok.com/get-started/your-authtoken')
else:
print(f'ngrok connected to localhost:{port}! URL: {public_url}\n'

@@ -1,8 +1,8 @@
import os
import sys
from modules.paths_internal import models_path, script_path, data_path, extensions_dir, extensions_builtin_dir
from modules.paths_internal import models_path, script_path, data_path, extensions_dir, extensions_builtin_dir  # noqa: F401

import modules.safe
import modules.safe  # noqa: F401


# data_path = cmd_opts_pre.data
@@ -20,7 +20,6 @@ assert sd_path is not None, f"Couldn't find Stable Diffusion in any of: {possibl

path_dirs = [
(sd_path, 'ldm', 'Stable Diffusion', []),
(os.path.join(sd_path, '../taming-transformers'), 'taming', 'Taming Transformers', []),
(os.path.join(sd_path, '../CodeFormer'), 'inference_codeformer.py', 'CodeFormer', []),
(os.path.join(sd_path, '../BLIP'), 'models/blip.py', 'BLIP', []),
(os.path.join(sd_path, '../k-diffusion'), 'k_diffusion/sampling.py', 'k_diffusion', ["atstart"]),
@@ -39,17 +38,3 @@ for d, must_exist, what, options in path_dirs:
else:
sys.path.append(d)
paths[what] = d


class Prioritize:
def __init__(self, name):
self.name = name
self.path = None

def __enter__(self):
self.path = sys.path.copy()
sys.path = [paths[self.name]] + sys.path

def __exit__(self, exc_type, exc_val, exc_tb):
sys.path = self.path
self.path = None

@@ -2,8 +2,14 @@
|
||||
|
||||
import argparse
|
||||
import os
|
||||
import sys
|
||||
import shlex
|
||||
|
||||
script_path = os.path.dirname(os.path.dirname(os.path.realpath(__file__)))
|
||||
commandline_args = os.environ.get('COMMANDLINE_ARGS', "")
|
||||
sys.argv += shlex.split(commandline_args)
|
||||
|
||||
modules_path = os.path.dirname(os.path.realpath(__file__))
|
||||
script_path = os.path.dirname(modules_path)
|
||||
|
||||
sd_configs_path = os.path.join(script_path, "configs")
|
||||
sd_default_config = os.path.join(sd_configs_path, "v1-inference.yaml")
|
||||
@@ -12,7 +18,7 @@ default_sd_model_file = sd_model_file
|
||||
|
||||
# Parse the --data-dir flag first so we can use it as a base for our other argument default values
|
||||
parser_pre = argparse.ArgumentParser(add_help=False)
|
||||
parser_pre.add_argument("--data-dir", type=str, default=os.path.dirname(os.path.dirname(os.path.realpath(__file__))), help="base path where all user data is stored",)
|
||||
parser_pre.add_argument("--data-dir", type=str, default=os.path.dirname(modules_path), help="base path where all user data is stored", )
|
||||
cmd_opts_pre = parser_pre.parse_known_args()[0]
|
||||
|
||||
data_path = cmd_opts_pre.data_dir
|
||||
@@ -21,3 +27,5 @@ models_path = os.path.join(data_path, "models")
|
||||
extensions_dir = os.path.join(data_path, "extensions")
|
||||
extensions_builtin_dir = os.path.join(script_path, "extensions-builtin")
|
||||
config_states_dir = os.path.join(script_path, "config_states")
|
||||
|
||||
roboto_ttf_file = os.path.join(modules_path, 'Roboto-Regular.ttf')
|
||||
|
||||
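The two-stage parse above lets --data-dir influence the defaults of later arguments: a pre-parser built with add_help=False reads just that one flag via parse_known_args(), ignoring everything meant for the full parser. A small self-contained sketch of the pattern (paths and flag names here are illustrative):

import argparse

parser_pre = argparse.ArgumentParser(add_help=False)        # no -h, so the real parser still owns help
parser_pre.add_argument("--data-dir", type=str, default="/opt/webui")
cmd_opts_pre, _rest = parser_pre.parse_known_args(["--data-dir", "/data", "--some-other-flag"])
print(cmd_opts_pre.data_dir)                                # "/data"; --some-other-flag is left for later
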
@@ -9,8 +9,7 @@ from modules.shared import opts
def run_postprocessing(extras_mode, image, image_folder, input_dir, output_dir, show_extras_results, *args, save_output: bool = True):
devices.torch_gc()

shared.state.begin()
shared.state.job = 'extras'
shared.state.begin(job="extras")

image_data = []
image_names = []
@@ -54,7 +53,9 @@ def run_postprocessing(extras_mode, image, image_folder, input_dir, output_dir,
for image, name in zip(image_data, image_names):
shared.state.textinfo = name

existing_pnginfo = image.info or {}
parameters, existing_pnginfo = images.read_info_from_image(image)
if parameters:
existing_pnginfo["parameters"] = parameters

pp = scripts_postprocessing.PostprocessedImage(image.convert("RGB"))


@@ -1,20 +1,20 @@
import json
import logging
import math
import os
import sys
import warnings
import hashlib

import torch
import numpy as np
from PIL import Image, ImageFilter, ImageOps
from PIL import Image, ImageOps
import random
import cv2
from skimage import exposure
from typing import Any, Dict, List, Optional
from typing import Any, Dict, List

import modules.sd_hijack
from modules import devices, prompt_parser, masking, sd_samplers, lowvram, generation_parameters_copypaste, script_callbacks, extra_networks, sd_vae_approx, scripts
from modules import devices, prompt_parser, masking, sd_samplers, lowvram, generation_parameters_copypaste, extra_networks, sd_vae_approx, scripts, sd_samplers_common, sd_unet
from modules.sd_hijack import model_hijack
from modules.shared import opts, cmd_opts, state
import modules.shared as shared
@@ -24,13 +24,13 @@ import modules.images as images
import modules.styles
import modules.sd_models as sd_models
import modules.sd_vae as sd_vae
import logging
from ldm.data.util import AddMiDaS
from ldm.models.diffusion.ddpm import LatentDepth2ImageDiffusion

from einops import repeat, rearrange
from blendmodes.blend import blendLayers, BlendType


# some of those options should not be changed at all because they would break the model, so I removed them from options.
opt_C = 4
opt_f = 8
@@ -106,6 +106,9 @@ class StableDiffusionProcessing:
"""
The first set of parameters: sd_models -> do_not_reload_embeddings represent the minimum required to create a StableDiffusionProcessing
"""
cached_uc = [None, None]
cached_c = [None, None]

def __init__(self, sd_model=None, outpath_samples=None, outpath_grids=None, prompt: str = "", styles: List[str] = None, seed: int = -1, subseed: int = -1, subseed_strength: float = 0, seed_resize_from_h: int = -1, seed_resize_from_w: int = -1, seed_enable_extras: bool = True, sampler_name: str = None, batch_size: int = 1, n_iter: int = 1, steps: int = 50, cfg_scale: float = 7.0, width: int = 512, height: int = 512, restore_faces: bool = False, tiling: bool = False, do_not_save_samples: bool = False, do_not_save_grid: bool = False, extra_generation_params: Dict[Any, Any] = None, overlay_images: Any = None, negative_prompt: str = None, eta: float = None, do_not_reload_embeddings: bool = False, denoising_strength: float = 0, ddim_discretize: str = None, s_min_uncond: float = 0.0, s_churn: float = 0.0, s_tmax: float = None, s_tmin: float = 0.0, s_noise: float = 1.0, override_settings: Dict[str, Any] = None, override_settings_restore_afterwards: bool = True, sampler_index: int = None, script_args: list = None):
if sampler_index is not None:
print("sampler_index argument for StableDiffusionProcessing does not do anything; use sampler_name", file=sys.stderr)
@@ -150,6 +153,8 @@ class StableDiffusionProcessing:
self.override_settings_restore_afterwards = override_settings_restore_afterwards
self.is_using_inpainting_conditioning = False
self.disable_extra_networks = False
self.token_merging_ratio = 0
self.token_merging_ratio_hr = 0

if not seed_enable_extras:
self.subseed = -1
@@ -165,7 +170,21 @@ class StableDiffusionProcessing:
self.all_subseeds = None
self.iteration = 0
self.is_hr_pass = False

self.sampler = None

self.prompts = None
self.negative_prompts = None
self.extra_network_data = None
self.seeds = None
self.subseeds = None

self.step_multiplier = 1
self.cached_uc = StableDiffusionProcessing.cached_uc
self.cached_c = StableDiffusionProcessing.cached_c
self.uc = None
self.c = None

self.user = None

@property
def sd_model(self):
@@ -273,6 +292,64 @@ class StableDiffusionProcessing:

def close(self):
self.sampler = None
self.c = None
self.uc = None
if not opts.experimental_persistent_cond_cache:
StableDiffusionProcessing.cached_c = [None, None]
StableDiffusionProcessing.cached_uc = [None, None]

def get_token_merging_ratio(self, for_hr=False):
if for_hr:
return self.token_merging_ratio_hr or opts.token_merging_ratio_hr or self.token_merging_ratio or opts.token_merging_ratio

return self.token_merging_ratio or opts.token_merging_ratio

def setup_prompts(self):
if type(self.prompt) == list:
self.all_prompts = self.prompt
else:
self.all_prompts = self.batch_size * self.n_iter * [self.prompt]

if type(self.negative_prompt) == list:
self.all_negative_prompts = self.negative_prompt
else:
self.all_negative_prompts = self.batch_size * self.n_iter * [self.negative_prompt]

self.all_prompts = [shared.prompt_styles.apply_styles_to_prompt(x, self.styles) for x in self.all_prompts]
self.all_negative_prompts = [shared.prompt_styles.apply_negative_styles_to_prompt(x, self.styles) for x in self.all_negative_prompts]

def get_conds_with_caching(self, function, required_prompts, steps, caches, extra_network_data):
"""
Returns the result of calling function(shared.sd_model, required_prompts, steps)
using a cache to store the result if the same arguments have been used before.

cache is an array containing two elements. The first element is a tuple
representing the previously used arguments, or None if no arguments
have been used before. The second element is where the previously
computed result is stored.

caches is a list with items described above.
"""
for cache in caches:
if cache[0] is not None and (required_prompts, steps, opts.CLIP_stop_at_last_layers, shared.sd_model.sd_checkpoint_info, extra_network_data) == cache[0]:
return cache[1]

cache = caches[0]

with devices.autocast():
cache[1] = function(shared.sd_model, required_prompts, steps)

cache[0] = (required_prompts, steps, opts.CLIP_stop_at_last_layers, shared.sd_model.sd_checkpoint_info, extra_network_data)
return cache[1]

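The cache layout used by get_conds_with_caching is a two-slot list: slot 0 holds the tuple of arguments last seen, slot 1 the computed conditioning. A stripped-down sketch of the same idea with generic names (not the actual module code):

def cached_call(function, key, cache):
    # cache == [last_key_or_None, last_value]
    if cache[0] is not None and cache[0] == key:
        return cache[1]                  # hit: same arguments as before
    cache[1] = function(*key)            # miss: recompute and remember
    cache[0] = key
    return cache[1]

cache = [None, None]
cached_call(pow, (2, 10), cache)         # computes 1024
cached_call(pow, (2, 10), cache)         # served from the cache
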
def setup_conds(self):
sampler_config = sd_samplers.find_sampler_config(self.sampler_name)
self.step_multiplier = 2 if sampler_config and sampler_config.options.get("second_order", False) else 1
self.uc = self.get_conds_with_caching(prompt_parser.get_learned_conditioning, self.negative_prompts, self.steps * self.step_multiplier, [self.cached_uc], self.extra_network_data)
self.c = self.get_conds_with_caching(prompt_parser.get_multicond_learned_conditioning, self.prompts, self.steps * self.step_multiplier, [self.cached_c], self.extra_network_data)

def parse_extra_network_prompts(self):
self.prompts, self.extra_network_data = extra_networks.parse_prompts(self.prompts)


class Processed:
@@ -303,6 +380,8 @@ class Processed:
self.styles = p.styles
self.job_timestamp = state.job_timestamp
self.clip_skip = opts.CLIP_stop_at_last_layers
self.token_merging_ratio = p.token_merging_ratio
self.token_merging_ratio_hr = p.token_merging_ratio_hr

self.eta = p.eta
self.ddim_discretize = p.ddim_discretize
@@ -310,6 +389,7 @@ class Processed:
self.s_tmin = p.s_tmin
self.s_tmax = p.s_tmax
self.s_noise = p.s_noise
self.s_min_uncond = p.s_min_uncond
self.sampler_noise_scheduler_override = p.sampler_noise_scheduler_override
self.prompt = self.prompt if type(self.prompt) != list else self.prompt[0]
self.negative_prompt = self.negative_prompt if type(self.negative_prompt) != list else self.negative_prompt[0]
@@ -360,6 +440,9 @@ class Processed:
def infotext(self, p: StableDiffusionProcessing, index):
return create_infotext(p, self.all_prompts, self.all_seeds, self.all_subseeds, comments=[], position_in_batch=index % self.batch_size, iteration=index // self.batch_size)

def get_token_merging_ratio(self, for_hr=False):
return self.token_merging_ratio_hr if for_hr else self.token_merging_ratio


# from https://discuss.pytorch.org/t/help-regarding-slerp-function-for-generative-model-sampling/32475/3
def slerp(val, low, high):
@@ -468,10 +551,17 @@ def program_version():
return res


def create_infotext(p, all_prompts, all_seeds, all_subseeds, comments=None, iteration=0, position_in_batch=0):
def create_infotext(p, all_prompts, all_seeds, all_subseeds, comments=None, iteration=0, position_in_batch=0, use_main_prompt=False):
index = position_in_batch + iteration * p.batch_size

clip_skip = getattr(p, 'clip_skip', opts.CLIP_stop_at_last_layers)
enable_hr = getattr(p, 'enable_hr', False)
token_merging_ratio = p.get_token_merging_ratio()
token_merging_ratio_hr = p.get_token_merging_ratio(for_hr=True)

uses_ensd = opts.eta_noise_seed_delta != 0
if uses_ensd:
uses_ensd = sd_samplers_common.is_sampler_using_eta_noise_seed_delta(p)

generation_params = {
"Steps": p.steps,
@@ -485,27 +575,33 @@ def create_infotext(p, all_prompts, all_seeds, all_subseeds, comments=None, iter
"Model": (None if not opts.add_model_name_to_info or not shared.sd_model.sd_checkpoint_info.model_name else shared.sd_model.sd_checkpoint_info.model_name.replace(',', '').replace(':', '')),
"Variation seed": (None if p.subseed_strength == 0 else all_subseeds[index]),
"Variation seed strength": (None if p.subseed_strength == 0 else p.subseed_strength),
"Seed resize from": (None if p.seed_resize_from_w == 0 or p.seed_resize_from_h == 0 else f"{p.seed_resize_from_w}x{p.seed_resize_from_h}"),
"Seed resize from": (None if p.seed_resize_from_w <= 0 or p.seed_resize_from_h <= 0 else f"{p.seed_resize_from_w}x{p.seed_resize_from_h}"),
"Denoising strength": getattr(p, 'denoising_strength', None),
"Conditional mask weight": getattr(p, "inpainting_mask_weight", shared.opts.inpainting_mask_weight) if p.is_using_inpainting_conditioning else None,
"Clip skip": None if clip_skip <= 1 else clip_skip,
"ENSD": None if opts.eta_noise_seed_delta == 0 else opts.eta_noise_seed_delta,
"ENSD": opts.eta_noise_seed_delta if uses_ensd else None,
"Token merging ratio": None if token_merging_ratio == 0 else token_merging_ratio,
"Token merging ratio hr": None if not enable_hr or token_merging_ratio_hr == 0 else token_merging_ratio_hr,
"Init image hash": getattr(p, 'init_img_hash', None),
"RNG": opts.randn_source if opts.randn_source != "GPU" else None,
"NGMS": None if p.s_min_uncond == 0 else p.s_min_uncond,
**p.extra_generation_params,
"Version": program_version() if opts.add_version_to_infotext else None,
"User": p.user if opts.add_user_name_to_info else None,
}

generation_params.update(p.extra_generation_params)

generation_params_text = ", ".join([k if k == v else f'{k}: {generation_parameters_copypaste.quote(v)}' for k, v in generation_params.items() if v is not None])

prompt_text = p.prompt if use_main_prompt else all_prompts[index]
negative_prompt_text = f"\nNegative prompt: {p.all_negative_prompts[index]}" if p.all_negative_prompts[index] else ""

return f"{all_prompts[index]}{negative_prompt_text}\n{generation_params_text}".strip()
return f"{prompt_text}{negative_prompt_text}\n{generation_params_text}".strip()


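For reference, the string create_infotext() assembles is the familiar PNG-info block: prompt on the first line, an optional "Negative prompt:" line, then the comma-separated parameter list. Roughly, with illustrative values:

# a photo of a cat
# Negative prompt: blurry, low quality
# Steps: 20, Sampler: Euler a, CFG scale: 7.0, Seed: 1234, Size: 512x512, Clip skip: 2
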
def process_images(p: StableDiffusionProcessing) -> Processed:
if p.scripts is not None:
p.scripts.before_process(p)

stored_opts = {k: opts.data[k] for k in p.override_settings.keys()}

try:
@@ -523,9 +619,13 @@ def process_images(p: StableDiffusionProcessing) -> Processed:
if k == 'sd_vae':
sd_vae.reload_vae_weights()

sd_models.apply_token_merging(p.sd_model, p.get_token_merging_ratio())

res = process_images_inner(p)

finally:
sd_models.apply_token_merging(p.sd_model, 0)

# restore opts to original state
if p.override_settings_restore_afterwards:
for k, v in stored_opts.items():
@@ -555,15 +655,7 @@ def process_images_inner(p: StableDiffusionProcessing) -> Processed:

comments = {}

if type(p.prompt) == list:
p.all_prompts = [shared.prompt_styles.apply_styles_to_prompt(x, p.styles) for x in p.prompt]
else:
p.all_prompts = p.batch_size * p.n_iter * [shared.prompt_styles.apply_styles_to_prompt(p.prompt, p.styles)]

if type(p.negative_prompt) == list:
p.all_negative_prompts = [shared.prompt_styles.apply_negative_styles_to_prompt(x, p.styles) for x in p.negative_prompt]
else:
p.all_negative_prompts = p.batch_size * p.n_iter * [shared.prompt_styles.apply_negative_styles_to_prompt(p.negative_prompt, p.styles)]
p.setup_prompts()

if type(seed) == list:
p.all_seeds = seed
@@ -575,8 +667,8 @@ def process_images_inner(p: StableDiffusionProcessing) -> Processed:
else:
p.all_subseeds = [int(subseed) + x for x in range(len(p.all_prompts))]

def infotext(iteration=0, position_in_batch=0):
return create_infotext(p, p.all_prompts, p.all_seeds, p.all_subseeds, comments, iteration, position_in_batch)
def infotext(iteration=0, position_in_batch=0, use_main_prompt=False):
return create_infotext(p, p.all_prompts, p.all_seeds, p.all_subseeds, comments, iteration, position_in_batch, use_main_prompt)

if os.path.exists(cmd_opts.embeddings_dir) and not p.do_not_reload_embeddings:
model_hijack.embedding_db.load_textual_inversion_embeddings()
@@ -587,29 +679,6 @@ def process_images_inner(p: StableDiffusionProcessing) -> Processed:
infotexts = []
output_images = []

cached_uc = [None, None]
cached_c = [None, None]

def get_conds_with_caching(function, required_prompts, steps, cache):
"""
Returns the result of calling function(shared.sd_model, required_prompts, steps)
using a cache to store the result if the same arguments have been used before.

cache is an array containing two elements. The first element is a tuple
representing the previously used arguments, or None if no arguments
have been used before. The second element is where the previously
computed result is stored.
"""

if cache[0] is not None and (required_prompts, steps) == cache[0]:
return cache[1]

with devices.autocast():
cache[1] = function(shared.sd_model, required_prompts, steps)

cache[0] = (required_prompts, steps)
return cache[1]

with torch.no_grad(), p.sd_model.ema_scope():
with devices.autocast():
p.init(p.all_prompts, p.all_seeds, p.all_subseeds)
@@ -618,10 +687,11 @@ def process_images_inner(p: StableDiffusionProcessing) -> Processed:
if shared.opts.live_previews_enable and opts.show_progress_type == "Approx NN":
sd_vae_approx.model()

sd_unet.apply_unet()

if state.job_count == -1:
state.job_count = p.n_iter

extra_network_data = None
for n in range(p.n_iter):
p.iteration = n

@@ -631,25 +701,25 @@ def process_images_inner(p: StableDiffusionProcessing) -> Processed:
if state.interrupted:
break

prompts = p.all_prompts[n * p.batch_size:(n + 1) * p.batch_size]
negative_prompts = p.all_negative_prompts[n * p.batch_size:(n + 1) * p.batch_size]
seeds = p.all_seeds[n * p.batch_size:(n + 1) * p.batch_size]
subseeds = p.all_subseeds[n * p.batch_size:(n + 1) * p.batch_size]
p.prompts = p.all_prompts[n * p.batch_size:(n + 1) * p.batch_size]
p.negative_prompts = p.all_negative_prompts[n * p.batch_size:(n + 1) * p.batch_size]
p.seeds = p.all_seeds[n * p.batch_size:(n + 1) * p.batch_size]
p.subseeds = p.all_subseeds[n * p.batch_size:(n + 1) * p.batch_size]

if p.scripts is not None:
p.scripts.before_process_batch(p, batch_number=n, prompts=prompts, seeds=seeds, subseeds=subseeds)
p.scripts.before_process_batch(p, batch_number=n, prompts=p.prompts, seeds=p.seeds, subseeds=p.subseeds)

if len(prompts) == 0:
if len(p.prompts) == 0:
break

prompts, extra_network_data = extra_networks.parse_prompts(prompts)
p.parse_extra_network_prompts()

if not p.disable_extra_networks:
with devices.autocast():
extra_networks.activate(p, extra_network_data)
extra_networks.activate(p, p.extra_network_data)

if p.scripts is not None:
p.scripts.process_batch(p, batch_number=n, prompts=prompts, seeds=seeds, subseeds=subseeds)
p.scripts.process_batch(p, batch_number=n, prompts=p.prompts, seeds=p.seeds, subseeds=p.subseeds)

# params.txt should be saved after scripts.process_batch, since the
# infotext could be modified by that callback
@@ -660,14 +730,7 @@ def process_images_inner(p: StableDiffusionProcessing) -> Processed:
processed = Processed(p, [], p.seed, "")
file.write(processed.infotext(p, 0))

step_multiplier = 1
if not shared.opts.dont_fix_second_order_samplers_schedule:
try:
step_multiplier = 2 if sd_samplers.all_samplers_map.get(p.sampler_name).aliases[0] in ['k_dpmpp_2s_a', 'k_dpmpp_2s_a_ka', 'k_dpmpp_sde', 'k_dpmpp_sde_ka', 'k_dpm_2', 'k_dpm_2_a', 'k_heun'] else 1
except:
pass
uc = get_conds_with_caching(prompt_parser.get_learned_conditioning, negative_prompts, p.steps * step_multiplier, cached_uc)
c = get_conds_with_caching(prompt_parser.get_multicond_learned_conditioning, prompts, p.steps * step_multiplier, cached_c)
p.setup_conds()

if len(model_hijack.comments) > 0:
for comment in model_hijack.comments:
@@ -677,7 +740,7 @@ def process_images_inner(p: StableDiffusionProcessing) -> Processed:
shared.state.job = f"Batch {n+1} out of {p.n_iter}"

with devices.without_autocast() if devices.unet_needs_upcast else devices.autocast():
samples_ddim = p.sample(conditioning=c, unconditional_conditioning=uc, seeds=seeds, subseeds=subseeds, subseed_strength=p.subseed_strength, prompts=prompts)
samples_ddim = p.sample(conditioning=p.c, unconditional_conditioning=p.uc, seeds=p.seeds, subseeds=p.subseeds, subseed_strength=p.subseed_strength, prompts=p.prompts)

x_samples_ddim = [decode_first_stage(p.sd_model, samples_ddim[i:i+1].to(dtype=devices.dtype_vae))[0].cpu() for i in range(samples_ddim.size(0))]
for x in x_samples_ddim:
@@ -688,7 +751,7 @@ def process_images_inner(p: StableDiffusionProcessing) -> Processed:

del samples_ddim

if shared.cmd_opts.lowvram or shared.cmd_opts.medvram:
if lowvram.is_enabled(shared.sd_model):
lowvram.send_everything_to_cpu()

devices.torch_gc()
@@ -704,7 +767,7 @@ def process_images_inner(p: StableDiffusionProcessing) -> Processed:

if p.restore_faces:
if opts.save and not p.do_not_save_samples and opts.save_images_before_face_restoration:
images.save_image(Image.fromarray(x_sample), p.outpath_samples, "", seeds[i], prompts[i], opts.samples_format, info=infotext(n, i), p=p, suffix="-before-face-restoration")
images.save_image(Image.fromarray(x_sample), p.outpath_samples, "", p.seeds[i], p.prompts[i], opts.samples_format, info=infotext(n, i), p=p, suffix="-before-face-restoration")

devices.torch_gc()

@@ -721,13 +784,13 @@ def process_images_inner(p: StableDiffusionProcessing) -> Processed:
if p.color_corrections is not None and i < len(p.color_corrections):
if opts.save and not p.do_not_save_samples and opts.save_images_before_color_correction:
image_without_cc = apply_overlay(image, p.paste_to, i, p.overlay_images)
images.save_image(image_without_cc, p.outpath_samples, "", seeds[i], prompts[i], opts.samples_format, info=infotext(n, i), p=p, suffix="-before-color-correction")
images.save_image(image_without_cc, p.outpath_samples, "", p.seeds[i], p.prompts[i], opts.samples_format, info=infotext(n, i), p=p, suffix="-before-color-correction")
image = apply_color_correction(p.color_corrections[i], image)

image = apply_overlay(image, p.paste_to, i, p.overlay_images)

if opts.samples_save and not p.do_not_save_samples:
images.save_image(image, p.outpath_samples, "", seeds[i], prompts[i], opts.samples_format, info=infotext(n, i), p=p)
images.save_image(image, p.outpath_samples, "", p.seeds[i], p.prompts[i], opts.samples_format, info=infotext(n, i), p=p)

text = infotext(n, i)
infotexts.append(text)
@@ -740,10 +803,10 @@ def process_images_inner(p: StableDiffusionProcessing) -> Processed:
image_mask_composite = Image.composite(image.convert('RGBA').convert('RGBa'), Image.new('RGBa', image.size), images.resize_image(2, p.mask_for_overlay, image.width, image.height).convert('L')).convert('RGBA')

if opts.save_mask:
images.save_image(image_mask, p.outpath_samples, "", seeds[i], prompts[i], opts.samples_format, info=infotext(n, i), p=p, suffix="-mask")
images.save_image(image_mask, p.outpath_samples, "", p.seeds[i], p.prompts[i], opts.samples_format, info=infotext(n, i), p=p, suffix="-mask")

if opts.save_mask_composite:
images.save_image(image_mask_composite, p.outpath_samples, "", seeds[i], prompts[i], opts.samples_format, info=infotext(n, i), p=p, suffix="-mask-composite")
images.save_image(image_mask_composite, p.outpath_samples, "", p.seeds[i], p.prompts[i], opts.samples_format, info=infotext(n, i), p=p, suffix="-mask-composite")

if opts.return_mask:
output_images.append(image_mask)
@@ -765,7 +828,7 @@ def process_images_inner(p: StableDiffusionProcessing) -> Processed:
grid = images.image_grid(output_images, p.batch_size)

if opts.return_grid:
text = infotext()
text = infotext(use_main_prompt=True)
infotexts.insert(0, text)
if opts.enable_pnginfo:
grid.info["parameters"] = text
@@ -773,10 +836,10 @@ def process_images_inner(p: StableDiffusionProcessing) -> Processed:
index_of_first_image = 1

if opts.grid_save:
images.save_image(grid, p.outpath_grids, "grid", p.all_seeds[0], p.all_prompts[0], opts.grid_format, info=infotext(), short_filename=not opts.grid_extended_filename, p=p, grid=True)
images.save_image(grid, p.outpath_grids, "grid", p.all_seeds[0], p.all_prompts[0], opts.grid_format, info=infotext(use_main_prompt=True), short_filename=not opts.grid_extended_filename, p=p, grid=True)

if not p.disable_extra_networks and extra_network_data:
extra_networks.deactivate(p, extra_network_data)
if not p.disable_extra_networks and p.extra_network_data:
extra_networks.deactivate(p, p.extra_network_data)

devices.torch_gc()

@@ -785,7 +848,7 @@ def process_images_inner(p: StableDiffusionProcessing) -> Processed:
images_list=output_images,
seed=p.all_seeds[0],
info=infotext(),
comments="".join(f"\n\n{comment}" for comment in comments),
comments="".join(f"{comment}\n" for comment in comments),
subseed=p.all_subseeds[0],
index_of_first_image=index_of_first_image,
infotexts=infotexts,
@@ -811,8 +874,10 @@ def old_hires_fix_first_pass_dimensions(width, height):

class StableDiffusionProcessingTxt2Img(StableDiffusionProcessing):
sampler = None
cached_hr_uc = [None, None]
cached_hr_c = [None, None]

def __init__(self, enable_hr: bool = False, denoising_strength: float = 0.75, firstphase_width: int = 0, firstphase_height: int = 0, hr_scale: float = 2.0, hr_upscaler: str = None, hr_second_pass_steps: int = 0, hr_resize_x: int = 0, hr_resize_y: int = 0, **kwargs):
def __init__(self, enable_hr: bool = False, denoising_strength: float = 0.75, firstphase_width: int = 0, firstphase_height: int = 0, hr_scale: float = 2.0, hr_upscaler: str = None, hr_second_pass_steps: int = 0, hr_resize_x: int = 0, hr_resize_y: int = 0, hr_sampler_name: str = None, hr_prompt: str = '', hr_negative_prompt: str = '', **kwargs):
super().__init__(**kwargs)
self.enable_hr = enable_hr
self.denoising_strength = denoising_strength
@@ -823,6 +888,11 @@ class StableDiffusionProcessingTxt2Img(StableDiffusionProcessing):
self.hr_resize_y = hr_resize_y
self.hr_upscale_to_x = hr_resize_x
self.hr_upscale_to_y = hr_resize_y
self.hr_sampler_name = hr_sampler_name
self.hr_prompt = hr_prompt
self.hr_negative_prompt = hr_negative_prompt
self.all_hr_prompts = None
self.all_hr_negative_prompts = None

if firstphase_width != 0 or firstphase_height != 0:
self.hr_upscale_to_x = self.width
@@ -834,8 +904,26 @@ class StableDiffusionProcessingTxt2Img(StableDiffusionProcessing):
self.truncate_y = 0
self.applied_old_hires_behavior_to = None

self.hr_prompts = None
self.hr_negative_prompts = None
self.hr_extra_network_data = None

self.cached_hr_uc = StableDiffusionProcessingTxt2Img.cached_hr_uc
self.cached_hr_c = StableDiffusionProcessingTxt2Img.cached_hr_c
self.hr_c = None
self.hr_uc = None

def init(self, all_prompts, all_seeds, all_subseeds):
if self.enable_hr:
if self.hr_sampler_name is not None and self.hr_sampler_name != self.sampler_name:
self.extra_generation_params["Hires sampler"] = self.hr_sampler_name

if tuple(self.hr_prompt) != tuple(self.prompt):
self.extra_generation_params["Hires prompt"] = self.hr_prompt

if tuple(self.hr_negative_prompt) != tuple(self.negative_prompt):
self.extra_generation_params["Hires negative prompt"] = self.hr_negative_prompt

if opts.use_old_hires_fix_width_height and self.applied_old_hires_behavior_to != (self.width, self.height):
self.hr_resize_x = self.width
self.hr_resize_y = self.height
@@ -901,7 +989,8 @@ class StableDiffusionProcessingTxt2Img(StableDiffusionProcessing):

latent_scale_mode = shared.latent_upscale_modes.get(self.hr_upscaler, None) if self.hr_upscaler is not None else shared.latent_upscale_modes.get(shared.latent_upscale_default_mode, "nearest")
if self.enable_hr and latent_scale_mode is None:
assert len([x for x in shared.sd_upscalers if x.name == self.hr_upscaler]) > 0, f"could not find upscaler named {self.hr_upscaler}"
if not any(x.name == self.hr_upscaler for x in shared.sd_upscalers):
raise Exception(f"could not find upscaler named {self.hr_upscaler}")

x = create_random_tensors([opt_C, self.height // opt_f, self.width // opt_f], seeds=seeds, subseeds=subseeds, subseed_strength=self.subseed_strength, seed_resize_from_h=self.seed_resize_from_h, seed_resize_from_w=self.seed_resize_from_w, p=self)
samples = self.sampler.sample(self, x, conditioning, unconditional_conditioning, image_conditioning=self.txt2img_image_conditioning(x))
@@ -965,9 +1054,11 @@ class StableDiffusionProcessingTxt2Img(StableDiffusionProcessing):

shared.state.nextjob()

img2img_sampler_name = self.sampler_name
img2img_sampler_name = self.hr_sampler_name or self.sampler_name

if self.sampler_name in ['PLMS', 'UniPC']: # PLMS/UniPC do not support img2img so we just silently switch to DDIM
img2img_sampler_name = 'DDIM'

self.sampler = sd_samplers.create_sampler(img2img_sampler_name, self.sd_model)

samples = samples[:, :, self.truncate_y//2:samples.shape[2]-(self.truncate_y+1)//2, self.truncate_x//2:samples.shape[3]-(self.truncate_x+1)//2]
@@ -978,17 +1069,101 @@ class StableDiffusionProcessingTxt2Img(StableDiffusionProcessing):
x = None
devices.torch_gc()

samples = self.sampler.sample_img2img(self, samples, noise, conditioning, unconditional_conditioning, steps=self.hr_second_pass_steps or self.steps, image_conditioning=image_conditioning)
if not self.disable_extra_networks:
with devices.autocast():
extra_networks.activate(self, self.hr_extra_network_data)

with devices.autocast():
self.calculate_hr_conds()

sd_models.apply_token_merging(self.sd_model, self.get_token_merging_ratio(for_hr=True))

if self.scripts is not None:
self.scripts.before_hr(self)

samples = self.sampler.sample_img2img(self, samples, noise, self.hr_c, self.hr_uc, steps=self.hr_second_pass_steps or self.steps, image_conditioning=image_conditioning)

sd_models.apply_token_merging(self.sd_model, self.get_token_merging_ratio())

self.is_hr_pass = False

return samples

def close(self):
super().close()
self.hr_c = None
self.hr_uc = None
if not opts.experimental_persistent_cond_cache:
StableDiffusionProcessingTxt2Img.cached_hr_uc = [None, None]
StableDiffusionProcessingTxt2Img.cached_hr_c = [None, None]

def setup_prompts(self):
super().setup_prompts()

if not self.enable_hr:
return

if self.hr_prompt == '':
self.hr_prompt = self.prompt

if self.hr_negative_prompt == '':
self.hr_negative_prompt = self.negative_prompt

if type(self.hr_prompt) == list:
self.all_hr_prompts = self.hr_prompt
else:
self.all_hr_prompts = self.batch_size * self.n_iter * [self.hr_prompt]

if type(self.hr_negative_prompt) == list:
self.all_hr_negative_prompts = self.hr_negative_prompt
else:
self.all_hr_negative_prompts = self.batch_size * self.n_iter * [self.hr_negative_prompt]

self.all_hr_prompts = [shared.prompt_styles.apply_styles_to_prompt(x, self.styles) for x in self.all_hr_prompts]
self.all_hr_negative_prompts = [shared.prompt_styles.apply_negative_styles_to_prompt(x, self.styles) for x in self.all_hr_negative_prompts]

def calculate_hr_conds(self):
if self.hr_c is not None:
return

self.hr_uc = self.get_conds_with_caching(prompt_parser.get_learned_conditioning, self.hr_negative_prompts, self.steps * self.step_multiplier, [self.cached_hr_uc, self.cached_uc], self.hr_extra_network_data)
self.hr_c = self.get_conds_with_caching(prompt_parser.get_multicond_learned_conditioning, self.hr_prompts, self.steps * self.step_multiplier, [self.cached_hr_c, self.cached_c], self.hr_extra_network_data)

def setup_conds(self):
super().setup_conds()

self.hr_uc = None
self.hr_c = None

if self.enable_hr:
if shared.opts.hires_fix_use_firstpass_conds:
self.calculate_hr_conds()

elif lowvram.is_enabled(shared.sd_model): # if in lowvram mode, we need to calculate conds right away, before the cond NN is unloaded
with devices.autocast():
extra_networks.activate(self, self.hr_extra_network_data)

self.calculate_hr_conds()

with devices.autocast():
extra_networks.activate(self, self.extra_network_data)

def parse_extra_network_prompts(self):
res = super().parse_extra_network_prompts()

if self.enable_hr:
self.hr_prompts = self.all_hr_prompts[self.iteration * self.batch_size:(self.iteration + 1) * self.batch_size]
self.hr_negative_prompts = self.all_hr_negative_prompts[self.iteration * self.batch_size:(self.iteration + 1) * self.batch_size]

self.hr_prompts, self.hr_extra_network_data = extra_networks.parse_prompts(self.hr_prompts)

return res


class StableDiffusionProcessingImg2Img(StableDiffusionProcessing):
sampler = None

def __init__(self, init_images: list = None, resize_mode: int = 0, denoising_strength: float = 0.75, image_cfg_scale: float = None, mask: Any = None, mask_blur: int = 4, inpainting_fill: int = 0, inpaint_full_res: bool = True, inpaint_full_res_padding: int = 0, inpainting_mask_invert: int = 0, initial_noise_multiplier: float = None, **kwargs):
def __init__(self, init_images: list = None, resize_mode: int = 0, denoising_strength: float = 0.75, image_cfg_scale: float = None, mask: Any = None, mask_blur: int = None, mask_blur_x: int = 4, mask_blur_y: int = 4, inpainting_fill: int = 0, inpaint_full_res: bool = True, inpaint_full_res_padding: int = 0, inpainting_mask_invert: int = 0, initial_noise_multiplier: float = None, **kwargs):
super().__init__(**kwargs)

self.init_images = init_images
@@ -999,7 +1174,11 @@ class StableDiffusionProcessingImg2Img(StableDiffusionProcessing):
self.image_mask = mask
self.latent_mask = None
self.mask_for_overlay = None
self.mask_blur = mask_blur
if mask_blur is not None:
mask_blur_x = mask_blur
mask_blur_y = mask_blur
self.mask_blur_x = mask_blur_x
self.mask_blur_y = mask_blur_y
self.inpainting_fill = inpainting_fill
self.inpaint_full_res = inpaint_full_res
self.inpaint_full_res_padding = inpaint_full_res_padding
@@ -1021,8 +1200,17 @@ class StableDiffusionProcessingImg2Img(StableDiffusionProcessing):
if self.inpainting_mask_invert:
image_mask = ImageOps.invert(image_mask)

if self.mask_blur > 0:
image_mask = image_mask.filter(ImageFilter.GaussianBlur(self.mask_blur))
if self.mask_blur_x > 0:
np_mask = np.array(image_mask)
kernel_size = 2 * int(4 * self.mask_blur_x + 0.5) + 1
np_mask = cv2.GaussianBlur(np_mask, (kernel_size, 1), self.mask_blur_x)
image_mask = Image.fromarray(np_mask)

if self.mask_blur_y > 0:
np_mask = np.array(image_mask)
kernel_size = 2 * int(4 * self.mask_blur_y + 0.5) + 1
np_mask = cv2.GaussianBlur(np_mask, (1, kernel_size), self.mask_blur_y)
image_mask = Image.fromarray(np_mask)

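The rewritten blur exploits the separability of the Gaussian kernel: one horizontal pass with a (kernel_size, 1) kernel followed by one vertical pass with (1, kernel_size) reproduces a full 2-D blur while allowing different strengths per axis. A quick check of that equivalence, using the same kernel-size rule as the code above (the sigma value is illustrative):

import numpy as np
import cv2

sigma = 4
k = 2 * int(4 * sigma + 0.5) + 1                   # same size rule as above
img = np.zeros((64, 64), dtype=np.uint8)
img[24:40, 24:40] = 255
two_pass = cv2.GaussianBlur(cv2.GaussianBlur(img, (k, 1), sigma), (1, k), sigma)
one_pass = cv2.GaussianBlur(img, (k, k), sigma)
print(np.abs(two_pass.astype(int) - one_pass.astype(int)).max())  # 0 or 1: differs by rounding only
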
if self.inpaint_full_res:
self.mask_for_overlay = image_mask
@@ -1141,3 +1329,6 @@ class StableDiffusionProcessingImg2Img(StableDiffusionProcessing):
devices.torch_gc()

return samples

def get_token_merging_ratio(self, for_hr=False):
return self.token_merging_ratio or ("token_merging_ratio" in self.override_settings and opts.token_merging_ratio) or opts.token_merging_ratio_img2img or opts.token_merging_ratio

@@ -95,9 +95,20 @@ def progressapi(req: ProgressRequest):
image = shared.state.current_image
if image is not None:
buffered = io.BytesIO()
image.save(buffered, format="png")

if opts.live_previews_image_format == "png":
# using optimize for large images takes an enormous amount of time
if max(*image.size) <= 256:
save_kwargs = {"optimize": True}
else:
save_kwargs = {"optimize": False, "compress_level": 1}

else:
save_kwargs = {}

image.save(buffered, format=opts.live_previews_image_format, **save_kwargs)
base64_image = base64.b64encode(buffered.getvalue()).decode('ascii')
live_preview = f"data:image/png;base64,{base64_image}"
live_preview = f"data:image/{opts.live_previews_image_format};base64,{base64_image}"
id_live_preview = shared.state.id_live_preview
else:
live_preview = None

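With the preview format made configurable, a compact payload can be built the same way the endpoint now does it; a sketch assuming Pillow was compiled with WebP support:

import base64
import io
from PIL import Image

image = Image.new("RGB", (512, 512), "gray")
buffered = io.BytesIO()
image.save(buffered, format="webp")                 # typically smaller than PNG for frequent previews
base64_image = base64.b64encode(buffered.getvalue()).decode('ascii')
live_preview = f"data:image/webp;base64,{base64_image}"
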
@@ -54,18 +54,21 @@ def get_learned_conditioning_prompt_schedules(prompts, steps):
"""

def collect_steps(steps, tree):
l = [steps]
res = [steps]

class CollectSteps(lark.Visitor):
def scheduled(self, tree):
tree.children[-1] = float(tree.children[-1])
if tree.children[-1] < 1:
tree.children[-1] *= steps
tree.children[-1] = min(steps, int(tree.children[-1]))
l.append(tree.children[-1])
res.append(tree.children[-1])

def alternate(self, tree):
l.extend(range(1, steps+1))
res.extend(range(1, steps+1))

CollectSteps().visit(tree)
return sorted(set(l))
return sorted(set(res))

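As a concrete illustration of what collect_steps gathers: in the prompt-editing syntax [from:to:when], a fractional when is scaled by the step count, exactly as scheduled() does above.

steps = 20
when = 0.5
switch_at = min(steps, int(when * steps))   # -> 10: "a [cat:dog:0.5] photo" flips text at step 10
print(sorted({switch_at, steps}))           # -> [10, 20], the schedule boundaries
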
def at_step(step, tree):
class AtStep(lark.Transformer):
@@ -92,7 +95,7 @@ def get_learned_conditioning_prompt_schedules(prompts, steps):
def get_schedule(prompt):
try:
tree = schedule_parser.parse(prompt)
except lark.exceptions.LarkError as e:
except lark.exceptions.LarkError:
if 0:
import traceback
traceback.print_exc()
@@ -140,7 +143,7 @@ def get_learned_conditioning(model, prompts, steps):
conds = model.get_learned_conditioning(texts)

cond_schedule = []
for i, (end_at_step, text) in enumerate(prompt_schedule):
for i, (end_at_step, _) in enumerate(prompt_schedule):
cond_schedule.append(ScheduledPromptConditioning(end_at_step, conds[i]))

cache[prompt] = cond_schedule
@@ -216,8 +219,8 @@ def reconstruct_cond_batch(c: List[List[ScheduledPromptConditioning]], current_s
res = torch.zeros((len(c),) + param.shape, device=param.device, dtype=param.dtype)
for i, cond_schedule in enumerate(c):
target_index = 0
for current, (end_at, cond) in enumerate(cond_schedule):
if current_step <= end_at:
for current, entry in enumerate(cond_schedule):
if current_step <= entry.end_at_step:
target_index = current
break
res[i] = cond_schedule[target_index].cond
@@ -231,13 +234,13 @@ def reconstruct_multicond_batch(c: MulticondLearnedConditioning, current_step):
tensors = []
conds_list = []

for batch_no, composable_prompts in enumerate(c.batch):
for composable_prompts in c.batch:
conds_for_batch = []

for cond_index, composable_prompt in enumerate(composable_prompts):
for composable_prompt in composable_prompts:
target_index = 0
for current, (end_at, cond) in enumerate(composable_prompt.schedules):
if current_step <= end_at:
for current, entry in enumerate(composable_prompt.schedules):
if current_step <= entry.end_at_step:
target_index = current
break

@@ -333,11 +336,11 @@ def parse_prompt_attention(text):
round_brackets.append(len(res))
elif text == '[':
square_brackets.append(len(res))
elif weight is not None and len(round_brackets) > 0:
elif weight is not None and round_brackets:
multiply_range(round_brackets.pop(), float(weight))
elif text == ')' and len(round_brackets) > 0:
elif text == ')' and round_brackets:
multiply_range(round_brackets.pop(), round_bracket_multiplier)
elif text == ']' and len(square_brackets) > 0:
elif text == ']' and square_brackets:
multiply_range(square_brackets.pop(), square_bracket_multiplier)
else:
parts = re.split(re_break, text)

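For orientation, the bracket handling above implements the attention-weight syntax; a rough sketch of the resulting weights, assuming the module's usual multipliers of 1.1 for () and 1/1.1 for []:

# parse_prompt_attention("a (red) car")     -> roughly [['a ', 1.0], ['red', 1.1], [' car', 1.0]]
# parse_prompt_attention("a [red] car")     -> 'red' weighted by 1/1.1 (about 0.909)
# parse_prompt_attention("a (red:1.5) car") -> 'red' weighted 1.5, via the explicit-weight branch
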
@@ -1,15 +1,13 @@
import os
import sys
import traceback

import numpy as np
from PIL import Image
from basicsr.utils.download_util import load_file_from_url
from realesrgan import RealESRGANer

from modules.upscaler import Upscaler, UpscalerData
from modules.shared import cmd_opts, opts
from modules import modelloader
from modules import modelloader, errors


class UpscalerRealESRGAN(Upscaler):
def __init__(self, path):
@@ -17,9 +15,9 @@ class UpscalerRealESRGAN(Upscaler):
self.user_path = path
super().__init__()
try:
from basicsr.archs.rrdbnet_arch import RRDBNet
from realesrgan import RealESRGANer
from realesrgan.archs.srvgg_arch import SRVGGNetCompact
from basicsr.archs.rrdbnet_arch import RRDBNet # noqa: F401
from realesrgan import RealESRGANer # noqa: F401
from realesrgan.archs.srvgg_arch import SRVGGNetCompact # noqa: F401
self.enable = True
self.scalers = []
scalers = self.load_models(path)
@@ -36,8 +34,7 @@ class UpscalerRealESRGAN(Upscaler):
self.scalers.append(scaler)

except Exception:
print("Error importing Real-ESRGAN:", file=sys.stderr)
print(traceback.format_exc(), file=sys.stderr)
errors.report("Error importing Real-ESRGAN", exc_info=True)
self.enable = False
self.scalers = []

@@ -45,9 +42,10 @@ class UpscalerRealESRGAN(Upscaler):
if not self.enable:
return img

info = self.load_model(path)
if not os.path.exists(info.local_data_path):
print(f"Unable to load RealESRGAN model: {info.name}")
try:
info = self.load_model(path)
except Exception:
errors.report(f"Unable to load RealESRGAN model {path}", exc_info=True)
return img

upsampler = RealESRGANer(
@@ -65,21 +63,17 @@ class UpscalerRealESRGAN(Upscaler):
return image

def load_model(self, path):
try:
info = next(iter([scaler for scaler in self.scalers if scaler.data_path == path]), None)

if info is None:
print(f"Unable to find model info: {path}")
return None

if info.local_data_path.startswith("http"):
info.local_data_path = load_file_from_url(url=info.data_path, model_dir=self.model_path, progress=True)

return info
except Exception as e:
print(f"Error making Real-ESRGAN models list: {e}", file=sys.stderr)
print(traceback.format_exc(), file=sys.stderr)
return None
for scaler in self.scalers:
if scaler.data_path == path:
if scaler.local_data_path.startswith("http"):
scaler.local_data_path = modelloader.load_file_from_url(
scaler.data_path,
model_dir=self.model_download_path,
)
if not os.path.exists(scaler.local_data_path):
raise FileNotFoundError(f"RealESRGAN data missing: {scaler.local_data_path}")
return scaler
raise ValueError(f"Unable to find model info: {path}")

def load_models(self, _):
return get_realesrgan_models(self)
@@ -134,6 +128,5 @@ def get_realesrgan_models(scaler):
),
]
return models
except Exception as e:
print("Error making Real-ESRGAN models list:", file=sys.stderr)
print(traceback.format_exc(), file=sys.stderr)
except Exception:
errors.report("Error making Real-ESRGAN models list", exc_info=True)

23
modules/restart.py
Normal file
@@ -0,0 +1,23 @@
import os
from pathlib import Path

from modules.paths_internal import script_path


def is_restartable() -> bool:
"""
Return True if the webui is restartable (i.e. there is something watching to restart it with)
"""
return bool(os.environ.get('SD_WEBUI_RESTART'))


def restart_program() -> None:
"""creates file tmp/restart and immediately stops the process, which webui.bat/webui.sh interpret as a command to start webui again"""

(Path(script_path) / "tmp" / "restart").touch()

stop_program()


def stop_program() -> None:
os._exit(0)
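The new module only touches a sentinel file and exits; the actual relaunch is done by whatever started the process (webui.bat/webui.sh, per the docstring). A launcher-side sketch of that contract, written in Python purely for illustration:

import os
import subprocess
from pathlib import Path

os.environ["SD_WEBUI_RESTART"] = "1"        # makes is_restartable() return True in the child
while True:
    subprocess.run(["python", "launch.py"])
    sentinel = Path("tmp") / "restart"
    if not sentinel.exists():
        break                               # normal shutdown: stop relaunching
    sentinel.unlink()                       # restart requested: clear the flag and loop
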
@@ -2,8 +2,6 @@

import pickle
import collections
import sys
import traceback

import torch
import numpy
@@ -11,7 +9,10 @@ import _codecs
import zipfile
import re


# PyTorch 1.13 and later have _TypedStorage renamed to TypedStorage
from modules import errors

TypedStorage = torch.storage.TypedStorage if hasattr(torch.storage, 'TypedStorage') else torch.storage._TypedStorage

def encode(*args):
@@ -95,16 +96,16 @@ def check_pt(filename, extra_handler):

except zipfile.BadZipfile:

# if it's not a zip file, it's an olf pytorch format, with five objects written to pickle
# if it's not a zip file, it's an old pytorch format, with five objects written to pickle
with open(filename, "rb") as file:
unpickler = RestrictedUnpickler(file)
unpickler.extra_handler = extra_handler
for i in range(5):
for _ in range(5):
unpickler.load()


def load(filename, *args, **kwargs):
return load_with_extra(filename, extra_handler=global_extra_handler, *args, **kwargs)
return load_with_extra(filename, *args, extra_handler=global_extra_handler, **kwargs)


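The reordering in load() fixes a real call-time hazard: positional arguments unpacked after a keyword in the call expression are still bound to earlier parameter slots, which collides with the keyword. A minimal reproduction of the failure mode (names illustrative):

def load_with_extra(filename, extra_handler=None, *args):
    return filename, extra_handler, args

args = ("positional",)
print(load_with_extra("model.pt", *args, extra_handler="h"))  # ok: ('model.pt', 'h', ('positional',))
try:
    load_with_extra("model.pt", extra_handler="h", *args)     # the old ordering
except TypeError as e:
    print(e)                                                  # got multiple values for argument 'extra_handler'
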
def load_with_extra(filename, extra_handler=None, *args, **kwargs):
@@ -136,17 +137,20 @@ def load_with_extra(filename, extra_handler=None, *args, **kwargs):
check_pt(filename, extra_handler)

except pickle.UnpicklingError:
print(f"Error verifying pickled file from {filename}:", file=sys.stderr)
print(traceback.format_exc(), file=sys.stderr)
print("-----> !!!! The file is most likely corrupted !!!! <-----", file=sys.stderr)
print("You can skip this check with --disable-safe-unpickle commandline argument, but that is not going to help you.\n\n", file=sys.stderr)
errors.report(
f"Error verifying pickled file from {filename}\n"
"-----> !!!! The file is most likely corrupted !!!! <-----\n"
"You can skip this check with --disable-safe-unpickle commandline argument, but that is not going to help you.\n\n",
exc_info=True,
)
return None

except Exception:
print(f"Error verifying pickled file from {filename}:", file=sys.stderr)
print(traceback.format_exc(), file=sys.stderr)
print("\nThe file may be malicious, so the program is not going to read it.", file=sys.stderr)
print("You can skip this check with --disable-safe-unpickle commandline argument.\n\n", file=sys.stderr)
errors.report(
f"Error verifying pickled file from {filename}\n"
f"The file may be malicious, so the program is not going to read it.\n"
f"You can skip this check with --disable-safe-unpickle commandline argument.\n\n",
exc_info=True,
)
return None

return unsafe_torch_load(filename, *args, **kwargs)
@@ -190,4 +194,3 @@ with safe.Extra(handler):
unsafe_torch_load = torch.load
torch.load = load
global_extra_handler = None


@@ -1,16 +1,16 @@
import sys
import traceback
from collections import namedtuple
import inspect
import os
from collections import namedtuple
from typing import Optional, Dict, Any

from fastapi import FastAPI
from gradio import Blocks

from modules import errors, timer


def report_exception(c, job):
print(f"Error executing callback {job} for {c.script}", file=sys.stderr)
print(traceback.format_exc(), file=sys.stderr)
errors.report(f"Error executing callback {job} for {c.script}", exc_info=True)


class ImageSaveParams:
@@ -32,27 +32,42 @@ class CFGDenoiserParams:
def __init__(self, x, image_cond, sigma, sampling_step, total_sampling_steps, text_cond, text_uncond):
self.x = x
"""Latent image representation in the process of being denoised"""


self.image_cond = image_cond
"""Conditioning image"""


self.sigma = sigma
"""Current sigma noise step value"""


self.sampling_step = sampling_step
"""Current Sampling step number"""


self.total_sampling_steps = total_sampling_steps
"""Total number of sampling steps planned"""


self.text_cond = text_cond
""" Encoder hidden states of text conditioning from prompt"""


self.text_uncond = text_uncond
""" Encoder hidden states of text conditioning from negative prompt"""


class CFGDenoisedParams:
def __init__(self, x, sampling_step, total_sampling_steps, inner_model):
self.x = x
"""Latent image representation in the process of being denoised"""

self.sampling_step = sampling_step
"""Current Sampling step number"""

self.total_sampling_steps = total_sampling_steps
"""Total number of sampling steps planned"""

self.inner_model = inner_model
"""Inner model reference used for denoising"""


class AfterCFGCallbackParams:
def __init__(self, x, sampling_step, total_sampling_steps):
self.x = x
"""Latent image representation in the process of being denoised"""
@@ -87,6 +102,7 @@ callback_map = dict(
callbacks_image_saved=[],
callbacks_cfg_denoiser=[],
callbacks_cfg_denoised=[],
callbacks_cfg_after_cfg=[],
callbacks_before_component=[],
callbacks_after_component=[],
callbacks_image_grid=[],
@@ -94,6 +110,8 @@ callback_map = dict(
callbacks_script_unloaded=[],
callbacks_before_ui=[],
callbacks_on_reload=[],
callbacks_list_optimizers=[],
callbacks_list_unets=[],
)


@@ -106,6 +124,7 @@ def app_started_callback(demo: Optional[Blocks], app: FastAPI):
for c in callback_map['callbacks_app_started']:
try:
c.callback(demo, app)
timer.startup_timer.record(os.path.basename(c.script))
except Exception:
report_exception(c, 'app_started_callback')

@@ -186,6 +205,14 @@ def cfg_denoised_callback(params: CFGDenoisedParams):
report_exception(c, 'cfg_denoised_callback')


def cfg_after_cfg_callback(params: AfterCFGCallbackParams):
for c in callback_map['callbacks_cfg_after_cfg']:
try:
c.callback(params)
except Exception:
report_exception(c, 'cfg_after_cfg_callback')


def before_component_callback(component, **kwargs):
for c in callback_map['callbacks_before_component']:
try:
@@ -234,16 +261,40 @@ def before_ui_callback():
report_exception(c, 'before_ui')


def list_optimizers_callback():
res = []

for c in callback_map['callbacks_list_optimizers']:
try:
c.callback(res)
except Exception:
report_exception(c, 'list_optimizers')

return res


def list_unets_callback():
res = []

for c in callback_map['callbacks_list_unets']:
try:
c.callback(res)
except Exception:
report_exception(c, 'list_unets')

return res


def add_callback(callbacks, fun):
stack = [x for x in inspect.stack() if x.filename != __file__]
filename = stack[0].filename if len(stack) > 0 else 'unknown file'
filename = stack[0].filename if stack else 'unknown file'

callbacks.append(ScriptCallback(filename, fun))



def remove_current_script_callbacks():
stack = [x for x in inspect.stack() if x.filename != __file__]
filename = stack[0].filename if len(stack) > 0 else 'unknown file'
filename = stack[0].filename if stack else 'unknown file'
if filename == 'unknown file':
return
for callback_list in callback_map.values():
@@ -332,6 +383,14 @@ def on_cfg_denoised(callback):
add_callback(callback_map['callbacks_cfg_denoised'], callback)


def on_cfg_after_cfg(callback):
"""register a function to be called in the kdiffusion cfg_denoiser method after cfg calculations are completed.
The callback is called with one argument:
- params: AfterCFGCallbackParams - parameters to be passed to the script for post-processing after cfg calculation.
"""
add_callback(callback_map['callbacks_cfg_after_cfg'], callback)


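An extension hooks the new event like any other callback; a small sketch (the handler name is illustrative):

from modules import script_callbacks

def log_after_cfg(params):                  # params: AfterCFGCallbackParams
    print(params.sampling_step, "/", params.total_sampling_steps)

script_callbacks.on_cfg_after_cfg(log_after_cfg)
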
def on_before_component(callback):
|
||||
"""register a function to be called before a component is created.
|
||||
The callback is called with arguments:
|
||||
@@ -377,3 +436,18 @@ def on_before_ui(callback):
|
||||
"""register a function to be called before the UI is created."""
|
||||
|
||||
add_callback(callback_map['callbacks_before_ui'], callback)
|
||||
|
||||
|
||||
def on_list_optimizers(callback):
|
||||
"""register a function to be called when UI is making a list of cross attention optimization options.
|
||||
The function will be called with one argument, a list, and shall add objects of type modules.sd_hijack_optimizations.SdOptimization
|
||||
to it."""
|
||||
|
||||
add_callback(callback_map['callbacks_list_optimizers'], callback)
|
||||
|
||||
|
||||
def on_list_unets(callback):
|
||||
"""register a function to be called when UI is making a list of alternative options for unet.
|
||||
The function will be called with one argument, a list, and shall add objects of type modules.sd_unet.SdUnetOption to it."""
|
||||
|
||||
add_callback(callback_map['callbacks_list_unets'], callback)
|
||||
|
||||
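

# Illustrative sketch (not part of the commit): how an extension script might
# use the new cfg_after_cfg registration above. The attribute names on
# AfterCFGCallbackParams (x, sampling_step, total_sampling_steps) are taken
# from this module; the print body is a hypothetical placeholder.
from modules import script_callbacks


def my_after_cfg(params):
    # inspect (or modify in place) the tensor produced after CFG combination
    print(f"after cfg: step {params.sampling_step}/{params.total_sampling_steps}, shape {tuple(params.x.shape)}")


script_callbacks.on_cfg_after_cfg(my_after_cfg)
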
@@ -1,8 +1,7 @@
import os
import sys
import traceback
import importlib.util
from types import ModuleType

from modules import errors


def load_module(path):
@@ -28,5 +27,4 @@ def preload_extensions(extensions_dir, parser):
                module.preload(parser)

        except Exception:
            print(f"Error running preload() for {preload_script}", file=sys.stderr)
            print(traceback.format_exc(), file=sys.stderr)
            errors.report(f"Error running preload() for {preload_script}", exc_info=True)

@@ -1,12 +1,12 @@
import os
import re
import sys
import traceback
import inspect
from collections import namedtuple

import gradio as gr

from modules import shared, paths, script_callbacks, extensions, script_loading, scripts_postprocessing
from modules import shared, paths, script_callbacks, extensions, script_loading, scripts_postprocessing, errors, timer

AlwaysVisible = object()

@@ -17,6 +17,12 @@ class PostprocessImageArgs:


class Script:
    name = None
    """script's internal name derived from title"""

    section = None
    """name of UI section that the script's controls will be placed into"""

    filename = None
    args_from = None
    args_to = None
@@ -25,8 +31,8 @@ class Script:
    is_txt2img = False
    is_img2img = False

    """A gr.Group component that has all script's UI inside it"""
    group = None
    """A gr.Group component that has all script's UI inside it"""

    infotext_fields = None
    """if set in ui(), this is a list of pairs of gradio component + text; the text will be used when
@@ -38,6 +44,9 @@ class Script:
    various "Send to <X>" buttons when clicked
    """

    api_info = None
    """Generated value of type modules.api.models.ScriptInfo with information about the script for API"""

    def title(self):
        """this function should return the title of the script. This is what will be displayed in the dropdown menu."""

@@ -76,6 +85,15 @@ class Script:

        pass

    def before_process(self, p, *args):
        """
        This function is called very early before processing begins for AlwaysVisible scripts.
        You can modify the processing object (p) here, inject hooks, etc.
        args contains all values returned by components from ui()
        """

        pass

    def process(self, p, *args):
        """
        This function is called before processing begins for AlwaysVisible scripts.
@@ -99,6 +117,21 @@ class Script:

        pass

    def after_extra_networks_activate(self, p, *args, **kwargs):
        """
        Called after extra networks activation, before conds calculation.
        Allows modification of the network after extra networks activation has been applied;
        won't be called if p.disable_extra_networks

        **kwargs will have those items:
          - batch_number - index of current batch, from 0 to number of batches-1
          - prompts - list of prompts for current batch; you can change contents of this list but changing the number of entries will likely break things
          - seeds - list of seeds for current batch
          - subseeds - list of subseeds for current batch
          - extra_network_data - list of ExtraNetworkParams for current stage
        """
        pass

    def process_batch(self, p, *args, **kwargs):
        """
        Same as process(), but called for every batch.
@@ -169,6 +202,11 @@ class Script:

        return f'script_{tabname}{title}_{item_id}'

    def before_hr(self, p, *args):
        """
        This function is called before hires fix starts.
        """
        pass

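# Illustrative sketch (not part of the commit): a minimal AlwaysVisible script
# exercising the new before_process/before_hr hooks defined above. The class
# name, title, and bodies are hypothetical placeholders.
from modules import scripts


class ExampleHooks(scripts.Script):
    def title(self):
        return "Example hooks"

    def show(self, is_img2img):
        return scripts.AlwaysVisible

    def before_process(self, p, *args):
        # runs very early; a safe place to tweak the processing object
        p.extra_generation_params["Example hook"] = True

    def before_hr(self, p, *args):
        # runs just before the hires pass starts
        print("hires fix about to start")
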
current_basedir = paths.script_path

@@ -231,8 +269,8 @@ def load_scripts():
    syspath = sys.path

    def register_scripts_from_module(module):
        for key, script_class in module.__dict__.items():
            if type(script_class) != type:
        for script_class in module.__dict__.values():
            if not inspect.isclass(script_class):
                continue

            if issubclass(script_class, Script):
@@ -258,21 +296,25 @@ def load_scripts():
            register_scripts_from_module(script_module)

        except Exception:
            print(f"Error loading script: {scriptfile.filename}", file=sys.stderr)
            print(traceback.format_exc(), file=sys.stderr)
            errors.report(f"Error loading script: {scriptfile.filename}", exc_info=True)

        finally:
            sys.path = syspath
            current_basedir = paths.script_path
            timer.startup_timer.record(scriptfile.filename)

    global scripts_txt2img, scripts_img2img, scripts_postproc

    scripts_txt2img = ScriptRunner()
    scripts_img2img = ScriptRunner()
    scripts_postproc = scripts_postprocessing.ScriptPostprocessingRunner()


def wrap_call(func, filename, funcname, *args, default=None, **kwargs):
    try:
        res = func(*args, **kwargs)
        return res
        return func(*args, **kwargs)
    except Exception:
        print(f"Error calling: {filename}/{funcname}", file=sys.stderr)
        print(traceback.format_exc(), file=sys.stderr)
        errors.report(f"Error calling: {filename}/{funcname}", exc_info=True)

    return default

@@ -285,6 +327,7 @@ class ScriptRunner:
        self.titles = []
        self.infotext_fields = []
        self.paste_field_names = []
        self.inputs = [None]

    def initialize_scripts(self, is_img2img):
        from modules import scripts_auto_postprocessing
@@ -295,9 +338,9 @@ class ScriptRunner:

        auto_processing_scripts = scripts_auto_postprocessing.create_auto_preprocessing_script_data()

        for script_class, path, basedir, script_module in auto_processing_scripts + scripts_data:
            script = script_class()
            script.filename = path
        for script_data in auto_processing_scripts + scripts_data:
            script = script_data.script_class()
            script.filename = script_data.path
            script.is_txt2img = not is_img2img
            script.is_img2img = is_img2img

@@ -312,48 +355,73 @@ class ScriptRunner:
                self.scripts.append(script)
                self.selectable_scripts.append(script)

    def create_script_ui(self, script):
        import modules.api.models as api_models

        script.args_from = len(self.inputs)
        script.args_to = len(self.inputs)

        controls = wrap_call(script.ui, script.filename, "ui", script.is_img2img)

        if controls is None:
            return

        script.name = wrap_call(script.title, script.filename, "title", default=script.filename).lower()
        api_args = []

        for control in controls:
            control.custom_script_source = os.path.basename(script.filename)

            arg_info = api_models.ScriptArg(label=control.label or "")

            for field in ("value", "minimum", "maximum", "step", "choices"):
                v = getattr(control, field, None)
                if v is not None:
                    setattr(arg_info, field, v)

            api_args.append(arg_info)

        script.api_info = api_models.ScriptInfo(
            name=script.name,
            is_img2img=script.is_img2img,
            is_alwayson=script.alwayson,
            args=api_args,
        )

        if script.infotext_fields is not None:
            self.infotext_fields += script.infotext_fields

        if script.paste_field_names is not None:
            self.paste_field_names += script.paste_field_names

        self.inputs += controls
        script.args_to = len(self.inputs)

    def setup_ui_for_section(self, section, scriptlist=None):
        if scriptlist is None:
            scriptlist = self.alwayson_scripts

        for script in scriptlist:
            if script.alwayson and script.section != section:
                continue

            with gr.Group(visible=script.alwayson) as group:
                self.create_script_ui(script)

            script.group = group

    def prepare_ui(self):
        self.inputs = [None]

    def setup_ui(self):
        self.titles = [wrap_call(script.title, script.filename, "title") or f"{script.filename} [error]" for script in self.selectable_scripts]

        inputs = [None]
        inputs_alwayson = [True]

        def create_script_ui(script, inputs, inputs_alwayson):
            script.args_from = len(inputs)
            script.args_to = len(inputs)

            controls = wrap_call(script.ui, script.filename, "ui", script.is_img2img)

            if controls is None:
                return

            for control in controls:
                control.custom_script_source = os.path.basename(script.filename)

            if script.infotext_fields is not None:
                self.infotext_fields += script.infotext_fields

            if script.paste_field_names is not None:
                self.paste_field_names += script.paste_field_names

            inputs += controls
            inputs_alwayson += [script.alwayson for _ in controls]
            script.args_to = len(inputs)

        for script in self.alwayson_scripts:
            with gr.Group() as group:
                create_script_ui(script, inputs, inputs_alwayson)

            script.group = group
        self.setup_ui_for_section(None)

        dropdown = gr.Dropdown(label="Script", elem_id="script_list", choices=["None"] + self.titles, value="None", type="index")
        inputs[0] = dropdown
        self.inputs[0] = dropdown

        for script in self.selectable_scripts:
            with gr.Group(visible=False) as group:
                create_script_ui(script, inputs, inputs_alwayson)

            script.group = group
        self.setup_ui_for_section(None, self.selectable_scripts)

        def select_script(script_index):
            selected_script = self.selectable_scripts[script_index - 1] if script_index>0 else None
@@ -378,6 +446,7 @@ class ScriptRunner:
        )

        self.script_load_ctr = 0

        def onload_script_visibility(params):
            title = params.get('Script', None)
            if title:
@@ -388,10 +457,10 @@ class ScriptRunner:
            else:
                return gr.update(visible=False)

        self.infotext_fields.append( (dropdown, lambda x: gr.update(value=x.get('Script', 'None'))) )
        self.infotext_fields.extend( [(script.group, onload_script_visibility) for script in self.selectable_scripts] )
        self.infotext_fields.append((dropdown, lambda x: gr.update(value=x.get('Script', 'None'))))
        self.infotext_fields.extend([(script.group, onload_script_visibility) for script in self.selectable_scripts])

        return inputs
        return self.inputs

    def run(self, p, *args):
        script_index = args[0]
@@ -411,14 +480,21 @@ class ScriptRunner:

        return processed

    def before_process(self, p):
        for script in self.alwayson_scripts:
            try:
                script_args = p.script_args[script.args_from:script.args_to]
                script.before_process(p, *script_args)
            except Exception:
                errors.report(f"Error running before_process: {script.filename}", exc_info=True)

    def process(self, p):
        for script in self.alwayson_scripts:
            try:
                script_args = p.script_args[script.args_from:script.args_to]
                script.process(p, *script_args)
            except Exception:
                print(f"Error running process: {script.filename}", file=sys.stderr)
                print(traceback.format_exc(), file=sys.stderr)
                errors.report(f"Error running process: {script.filename}", exc_info=True)

    def before_process_batch(self, p, **kwargs):
        for script in self.alwayson_scripts:
@@ -426,8 +502,15 @@ class ScriptRunner:
                script_args = p.script_args[script.args_from:script.args_to]
                script.before_process_batch(p, *script_args, **kwargs)
            except Exception:
                print(f"Error running before_process_batch: {script.filename}", file=sys.stderr)
                print(traceback.format_exc(), file=sys.stderr)
                errors.report(f"Error running before_process_batch: {script.filename}", exc_info=True)

    def after_extra_networks_activate(self, p, **kwargs):
        for script in self.alwayson_scripts:
            try:
                script_args = p.script_args[script.args_from:script.args_to]
                script.after_extra_networks_activate(p, *script_args, **kwargs)
            except Exception:
                errors.report(f"Error running after_extra_networks_activate: {script.filename}", exc_info=True)

    def process_batch(self, p, **kwargs):
        for script in self.alwayson_scripts:
@@ -435,8 +518,7 @@ class ScriptRunner:
                script_args = p.script_args[script.args_from:script.args_to]
                script.process_batch(p, *script_args, **kwargs)
            except Exception:
                print(f"Error running process_batch: {script.filename}", file=sys.stderr)
                print(traceback.format_exc(), file=sys.stderr)
                errors.report(f"Error running process_batch: {script.filename}", exc_info=True)

    def postprocess(self, p, processed):
        for script in self.alwayson_scripts:
@@ -444,8 +526,7 @@ class ScriptRunner:
                script_args = p.script_args[script.args_from:script.args_to]
                script.postprocess(p, processed, *script_args)
            except Exception:
                print(f"Error running postprocess: {script.filename}", file=sys.stderr)
                print(traceback.format_exc(), file=sys.stderr)
                errors.report(f"Error running postprocess: {script.filename}", exc_info=True)

    def postprocess_batch(self, p, images, **kwargs):
        for script in self.alwayson_scripts:
@@ -453,8 +534,7 @@ class ScriptRunner:
                script_args = p.script_args[script.args_from:script.args_to]
                script.postprocess_batch(p, *script_args, images=images, **kwargs)
            except Exception:
                print(f"Error running postprocess_batch: {script.filename}", file=sys.stderr)
                print(traceback.format_exc(), file=sys.stderr)
                errors.report(f"Error running postprocess_batch: {script.filename}", exc_info=True)

    def postprocess_image(self, p, pp: PostprocessImageArgs):
        for script in self.alwayson_scripts:
@@ -462,24 +542,21 @@ class ScriptRunner:
                script_args = p.script_args[script.args_from:script.args_to]
                script.postprocess_image(p, pp, *script_args)
            except Exception:
                print(f"Error running postprocess_batch: {script.filename}", file=sys.stderr)
                print(traceback.format_exc(), file=sys.stderr)
                errors.report(f"Error running postprocess_image: {script.filename}", exc_info=True)

    def before_component(self, component, **kwargs):
        for script in self.scripts:
            try:
                script.before_component(component, **kwargs)
            except Exception:
                print(f"Error running before_component: {script.filename}", file=sys.stderr)
                print(traceback.format_exc(), file=sys.stderr)
                errors.report(f"Error running before_component: {script.filename}", exc_info=True)

    def after_component(self, component, **kwargs):
        for script in self.scripts:
            try:
                script.after_component(component, **kwargs)
            except Exception:
                print(f"Error running after_component: {script.filename}", file=sys.stderr)
                print(traceback.format_exc(), file=sys.stderr)
                errors.report(f"Error running after_component: {script.filename}", exc_info=True)

    def reload_sources(self, cache):
        for si, script in list(enumerate(self.scripts)):
@@ -492,7 +569,7 @@ class ScriptRunner:
                module = script_loading.load_module(script.filename)
                cache[filename] = module

            for key, script_class in module.__dict__.items():
            for script_class in module.__dict__.values():
                if type(script_class) == type and issubclass(script_class, Script):
                    self.scripts[si] = script_class()
                    self.scripts[si].filename = filename
@@ -500,9 +577,18 @@ class ScriptRunner:
                    self.scripts[si].args_to = args_to


scripts_txt2img = ScriptRunner()
scripts_img2img = ScriptRunner()
scripts_postproc = scripts_postprocessing.ScriptPostprocessingRunner()
    def before_hr(self, p):
        for script in self.alwayson_scripts:
            try:
                script_args = p.script_args[script.args_from:script.args_to]
                script.before_hr(p, *script_args)
            except Exception:
                errors.report(f"Error running before_hr: {script.filename}", exc_info=True)


scripts_txt2img: ScriptRunner = None
scripts_img2img: ScriptRunner = None
scripts_postproc: scripts_postprocessing.ScriptPostprocessingRunner = None
scripts_current: ScriptRunner = None


@@ -512,14 +598,7 @@ def reload_script_body_only():
    scripts_img2img.reload_sources(cache)


def reload_scripts():
    global scripts_txt2img, scripts_img2img, scripts_postproc

    load_scripts()

    scripts_txt2img = ScriptRunner()
    scripts_img2img = ScriptRunner()
    scripts_postproc = scripts_postprocessing.ScriptPostprocessingRunner()
reload_scripts = load_scripts  # compatibility alias


def add_classes_to_gradio_component(comp):

@@ -17,7 +17,7 @@ class ScriptPostprocessingForMainUI(scripts.Script):
        return self.postprocessing_controls.values()

    def postprocess_image(self, p, script_pp, *args):
        args_dict = {k: v for k, v in zip(self.postprocessing_controls, args)}
        args_dict = dict(zip(self.postprocessing_controls, args))

        pp = scripts_postprocessing.PostprocessedImage(script_pp.image)
        pp.info = {}

@@ -66,9 +66,9 @@ class ScriptPostprocessingRunner:
    def initialize_scripts(self, scripts_data):
        self.scripts = []

        for script_class, path, basedir, script_module in scripts_data:
            script: ScriptPostprocessing = script_class()
            script.filename = path
        for script_data in scripts_data:
            script: ScriptPostprocessing = script_data.script_class()
            script.filename = script_data.path

            if script.name == "Simple Upscale":
                continue
@@ -124,7 +124,7 @@ class ScriptPostprocessingRunner:
            script_args = args[script.args_from:script.args_to]

            process_args = {}
            for (name, component), value in zip(script.controls.items(), script_args):
            for (name, _component), value in zip(script.controls.items(), script_args):
                process_args[name] = value

            script.process(pp, **process_args)

@@ -61,7 +61,7 @@ class DisableInitialization:
                if res is None:
                    res = original(url, *args, local_files_only=False, **kwargs)
                return res
            except Exception as e:
            except Exception:
                return original(url, *args, local_files_only=False, **kwargs)

        def transformers_utils_hub_get_from_cache(url, *args, local_files_only=False, **kwargs):

@@ -3,7 +3,7 @@ from torch.nn.functional import silu
from types import MethodType

import modules.textual_inversion.textual_inversion
from modules import devices, sd_hijack_optimizations, shared, sd_hijack_checkpoint
from modules import devices, sd_hijack_optimizations, shared, script_callbacks, errors, sd_unet
from modules.hypernetworks import hypernetwork
from modules.shared import cmd_opts
from modules import sd_hijack_clip, sd_hijack_open_clip, sd_hijack_unet, sd_hijack_xlmr, xlmr
@@ -28,57 +28,65 @@ ldm.modules.attention.BasicTransformerBlock.ATTENTION_MODES["softmax-xformers"]
ldm.modules.attention.print = lambda *args: None
ldm.modules.diffusionmodules.model.print = lambda *args: None

optimizers = []
current_optimizer: sd_hijack_optimizations.SdOptimization = None


def list_optimizers():
    new_optimizers = script_callbacks.list_optimizers_callback()

    new_optimizers = [x for x in new_optimizers if x.is_available()]

    new_optimizers = sorted(new_optimizers, key=lambda x: x.priority, reverse=True)

    optimizers.clear()
    optimizers.extend(new_optimizers)


def apply_optimizations(option=None):
    global current_optimizer

def apply_optimizations():
    undo_optimizations()

    if len(optimizers) == 0:
        # a script can access the model very early, and optimizations would not be filled by then
        current_optimizer = None
        return ''

    ldm.modules.diffusionmodules.model.nonlinearity = silu
    ldm.modules.diffusionmodules.openaimodel.th = sd_hijack_unet.th

    optimization_method = None

    can_use_sdp = hasattr(torch.nn.functional, "scaled_dot_product_attention") and callable(getattr(torch.nn.functional, "scaled_dot_product_attention"))  # not everyone has torch 2.x to use sdp
    if current_optimizer is not None:
        current_optimizer.undo()
        current_optimizer = None

    if cmd_opts.force_enable_xformers or (cmd_opts.xformers and shared.xformers_available and torch.version.cuda and (6, 0) <= torch.cuda.get_device_capability(shared.device) <= (9, 0)):
        print("Applying xformers cross attention optimization.")
        ldm.modules.attention.CrossAttention.forward = sd_hijack_optimizations.xformers_attention_forward
        ldm.modules.diffusionmodules.model.AttnBlock.forward = sd_hijack_optimizations.xformers_attnblock_forward
        optimization_method = 'xformers'
    elif cmd_opts.opt_sdp_no_mem_attention and can_use_sdp:
        print("Applying scaled dot product cross attention optimization (without memory efficient attention).")
        ldm.modules.attention.CrossAttention.forward = sd_hijack_optimizations.scaled_dot_product_no_mem_attention_forward
        ldm.modules.diffusionmodules.model.AttnBlock.forward = sd_hijack_optimizations.sdp_no_mem_attnblock_forward
        optimization_method = 'sdp-no-mem'
    elif cmd_opts.opt_sdp_attention and can_use_sdp:
        print("Applying scaled dot product cross attention optimization.")
        ldm.modules.attention.CrossAttention.forward = sd_hijack_optimizations.scaled_dot_product_attention_forward
        ldm.modules.diffusionmodules.model.AttnBlock.forward = sd_hijack_optimizations.sdp_attnblock_forward
        optimization_method = 'sdp'
    elif cmd_opts.opt_sub_quad_attention:
        print("Applying sub-quadratic cross attention optimization.")
        ldm.modules.attention.CrossAttention.forward = sd_hijack_optimizations.sub_quad_attention_forward
        ldm.modules.diffusionmodules.model.AttnBlock.forward = sd_hijack_optimizations.sub_quad_attnblock_forward
        optimization_method = 'sub-quadratic'
    elif cmd_opts.opt_split_attention_v1:
        print("Applying v1 cross attention optimization.")
        ldm.modules.attention.CrossAttention.forward = sd_hijack_optimizations.split_cross_attention_forward_v1
        optimization_method = 'V1'
    elif not cmd_opts.disable_opt_split_attention and (cmd_opts.opt_split_attention_invokeai or not cmd_opts.opt_split_attention and not torch.cuda.is_available()):
        print("Applying cross attention optimization (InvokeAI).")
        ldm.modules.attention.CrossAttention.forward = sd_hijack_optimizations.split_cross_attention_forward_invokeAI
        optimization_method = 'InvokeAI'
    elif not cmd_opts.disable_opt_split_attention and (cmd_opts.opt_split_attention or torch.cuda.is_available()):
        print("Applying cross attention optimization (Doggettx).")
        ldm.modules.attention.CrossAttention.forward = sd_hijack_optimizations.split_cross_attention_forward
        ldm.modules.diffusionmodules.model.AttnBlock.forward = sd_hijack_optimizations.cross_attention_attnblock_forward
        optimization_method = 'Doggettx'
    selection = option or shared.opts.cross_attention_optimization
    if selection == "Automatic" and len(optimizers) > 0:
        matching_optimizer = next(iter([x for x in optimizers if x.cmd_opt and getattr(shared.cmd_opts, x.cmd_opt, False)]), optimizers[0])
    else:
        matching_optimizer = next(iter([x for x in optimizers if x.title() == selection]), None)

    return optimization_method
    if selection == "None":
        matching_optimizer = None
    elif selection == "Automatic" and shared.cmd_opts.disable_opt_split_attention:
        matching_optimizer = None
    elif matching_optimizer is None:
        matching_optimizer = optimizers[0]

    if matching_optimizer is not None:
        print(f"Applying attention optimization: {matching_optimizer.name}... ", end='')
        matching_optimizer.apply()
        print("done.")
        current_optimizer = matching_optimizer
        return current_optimizer.name
    else:
        print("Disabling attention optimization")
        return ''


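# Illustrative sketch (not part of the commit): with the refactor above, a
# caller can force a specific optimizer instead of relying on "Automatic",
# assuming the built-in optimizers were registered at startup via the
# list_optimizers callback. Selection matches SdOptimization.title(), so
# "Doggettx" (no label) is used here; labeled entries need the full title.
from modules import sd_hijack

sd_hijack.list_optimizers()                       # fill the registry via callbacks
applied = sd_hijack.apply_optimizations("Doggettx")
print(f"attention optimization in use: {applied or 'none'}")
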
def undo_optimizations():
    ldm.modules.attention.CrossAttention.forward = hypernetwork.attention_CrossAttention_forward
    ldm.modules.diffusionmodules.model.nonlinearity = diffusionmodules_model_nonlinearity
    ldm.modules.attention.CrossAttention.forward = hypernetwork.attention_CrossAttention_forward
    ldm.modules.diffusionmodules.model.AttnBlock.forward = diffusionmodules_model_AttnBlock_forward


@@ -92,12 +100,12 @@ def fix_checkpoint():
def weighted_loss(sd_model, pred, target, mean=True):
    #Calculate the weight normally, but ignore the mean
    loss = sd_model._old_get_loss(pred, target, mean=False)

    #Check if we have weights available
    weight = getattr(sd_model, '_custom_loss_weight', None)
    if weight is not None:
        loss *= weight

    #Return the loss, as mean if specified
    return loss.mean() if mean else loss

@@ -105,7 +113,7 @@ def weighted_forward(sd_model, x, c, w, *args, **kwargs):
    try:
        #Temporarily append weights to a place accessible during loss calc
        sd_model._custom_loss_weight = w

        #Replace 'get_loss' with a weight-aware one. Otherwise we need to reimplement 'forward' completely
        #Keep 'get_loss', but don't overwrite the previous old_get_loss if it's already set
        if not hasattr(sd_model, '_old_get_loss'):
@@ -118,9 +126,9 @@ def weighted_forward(sd_model, x, c, w, *args, **kwargs):
        try:
            #Delete temporary weights if appended
            del sd_model._custom_loss_weight
        except AttributeError as e:
        except AttributeError:
            pass

        #If we have an old loss function, reset the loss function to the original one
        if hasattr(sd_model, '_old_get_loss'):
            sd_model.get_loss = sd_model._old_get_loss
@@ -133,7 +141,7 @@ def apply_weighted_forward(sd_model):
def undo_weighted_forward(sd_model):
    try:
        del sd_model.weighted_forward
    except AttributeError as e:
    except AttributeError:
        pass


@@ -150,6 +158,13 @@ class StableDiffusionModelHijack:
    def __init__(self):
        self.embedding_db.add_embedding_dir(cmd_opts.embeddings_dir)

    def apply_optimizations(self, option=None):
        try:
            self.optimization_method = apply_optimizations(option)
        except Exception as e:
            errors.display(e, "applying cross attention optimization")
            undo_optimizations()

    def hijack(self, m):
        if type(m.cond_stage_model) == xlmr.BertSeriesModelWithTransformation:
            model_embeddings = m.cond_stage_model.roberta.embeddings
@@ -169,7 +184,7 @@ class StableDiffusionModelHijack:
        if m.cond_stage_key == "edit":
            sd_hijack_unet.hijack_ddpm_edit()

        self.optimization_method = apply_optimizations()
        self.apply_optimizations()

        self.clip = m.cond_stage_model

@@ -182,9 +197,14 @@ class StableDiffusionModelHijack:

        self.layers = flatten(m)

        if not hasattr(ldm.modules.diffusionmodules.openaimodel, 'copy_of_UNetModel_forward_for_webui'):
            ldm.modules.diffusionmodules.openaimodel.copy_of_UNetModel_forward_for_webui = ldm.modules.diffusionmodules.openaimodel.UNetModel.forward

        ldm.modules.diffusionmodules.openaimodel.UNetModel.forward = sd_unet.UNetModel_forward

    def undo_hijack(self, m):
        if type(m.cond_stage_model) == xlmr.BertSeriesModelWithTransformation:
            m.cond_stage_model = m.cond_stage_model.wrapped
            m.cond_stage_model = m.cond_stage_model.wrapped

        elif type(m.cond_stage_model) == sd_hijack_clip.FrozenCLIPEmbedderWithCustomWords:
            m.cond_stage_model = m.cond_stage_model.wrapped
@@ -203,6 +223,8 @@ class StableDiffusionModelHijack:
        self.layers = None
        self.clip = None

        ldm.modules.diffusionmodules.openaimodel.UNetModel.forward = ldm.modules.diffusionmodules.openaimodel.copy_of_UNetModel_forward_for_webui

    def apply_circular(self, enable):
        if self.circular_enabled == enable:
            return
@@ -216,10 +238,17 @@ class StableDiffusionModelHijack:
        self.comments = []

    def get_prompt_lengths(self, text):
        if self.clip is None:
            return "-", "-"

        _, token_count = self.clip.process_texts([text])

        return token_count, self.clip.get_target_prompt_token_count(token_count)

    def redo_hijack(self, m):
        self.undo_hijack(m)
        self.hijack(m)


class EmbeddingsWithFixes(torch.nn.Module):
    def __init__(self, wrapped, embeddings):

@@ -167,7 +167,7 @@ class FrozenCLIPEmbedderWithCustomWordsBase(torch.nn.Module):
            chunk.multipliers += [weight] * emb_len
            position += embedding_length_in_tokens

        if len(chunk.tokens) > 0 or len(chunks) == 0:
        if chunk.tokens or not chunks:
            next_chunk(is_last=True)

        return chunks, token_count
@@ -223,7 +223,7 @@ class FrozenCLIPEmbedderWithCustomWordsBase(torch.nn.Module):
            self.hijack.fixes = [x.fixes for x in batch_chunk]

            for fixes in self.hijack.fixes:
                for position, embedding in fixes:
                for _position, embedding in fixes:
                    used_embeddings[embedding.name] = embedding

            z = self.process_tokens(tokens, multipliers)

@@ -74,7 +74,7 @@ def forward_old(self: sd_hijack_clip.FrozenCLIPEmbedderWithCustomWordsBase, text

    self.hijack.comments += hijack_comments

    if len(used_custom_terms) > 0:
    if used_custom_terms:
        embedding_names = ", ".join(f"{word} [{checksum}]" for word, checksum in used_custom_terms)
        self.hijack.comments.append(f"Used embeddings: {embedding_names}")


@@ -1,16 +1,10 @@
import os
import torch

from einops import repeat
from omegaconf import ListConfig

import ldm.models.diffusion.ddpm
import ldm.models.diffusion.ddim
import ldm.models.diffusion.plms

from ldm.models.diffusion.ddpm import LatentDiffusion
from ldm.models.diffusion.plms import PLMSSampler
from ldm.models.diffusion.ddim import DDIMSampler, noise_like
from ldm.models.diffusion.ddim import noise_like
from ldm.models.diffusion.sampling_util import norm_thresholding


@@ -29,7 +23,7 @@ def p_sample_plms(self, x, c, t, index, repeat_noise=False, use_original_steps=F

    if isinstance(c, dict):
        assert isinstance(unconditional_conditioning, dict)
        c_in = dict()
        c_in = {}
        for k in c:
            if isinstance(c[k], list):
                c_in[k] = [

@@ -1,8 +1,5 @@
import collections
import os.path
import sys
import gc
import time


def should_hijack_ip2p(checkpoint_info):
    from modules import sd_models_config
@@ -10,4 +7,4 @@ def should_hijack_ip2p(checkpoint_info):
    ckpt_basename = os.path.basename(checkpoint_info.filename).lower()
    cfg_basename = os.path.basename(sd_models_config.find_checkpoint_config_near_filename(checkpoint_info)).lower()

    return "pix2pix" in ckpt_basename and not "pix2pix" in cfg_basename
    return "pix2pix" in ckpt_basename and "pix2pix" not in cfg_basename

@@ -1,6 +1,5 @@
from __future__ import annotations
import math
import sys
import traceback
import psutil

import torch
@@ -9,10 +8,129 @@ from torch import einsum
from ldm.util import default
from einops import rearrange

from modules import shared, errors, devices
from modules import shared, errors, devices, sub_quadratic_attention
from modules.hypernetworks import hypernetwork

from .sub_quadratic_attention import efficient_dot_product_attention
import ldm.modules.attention
import ldm.modules.diffusionmodules.model

diffusionmodules_model_AttnBlock_forward = ldm.modules.diffusionmodules.model.AttnBlock.forward


class SdOptimization:
    name: str = None
    label: str | None = None
    cmd_opt: str | None = None
    priority: int = 0

    def title(self):
        if self.label is None:
            return self.name

        return f"{self.name} - {self.label}"

    def is_available(self):
        return True

    def apply(self):
        pass

    def undo(self):
        ldm.modules.attention.CrossAttention.forward = hypernetwork.attention_CrossAttention_forward
        ldm.modules.diffusionmodules.model.AttnBlock.forward = diffusionmodules_model_AttnBlock_forward


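# Illustrative sketch (not part of the commit): an extension could define its
# own optimization and feed it into the registry through the list_optimizers
# callback. The class name, label, and print body are hypothetical.
from modules import script_callbacks, sd_hijack_optimizations


class SdOptimizationExample(sd_hijack_optimizations.SdOptimization):
    name = "example"
    label = "no-op demonstration"
    priority = 1  # low priority, so "Automatic" prefers the built-ins

    def apply(self):
        print("example optimization selected; leaving attention untouched")


script_callbacks.on_list_optimizers(lambda res: res.append(SdOptimizationExample()))

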
class SdOptimizationXformers(SdOptimization):
    name = "xformers"
    cmd_opt = "xformers"
    priority = 100

    def is_available(self):
        return shared.cmd_opts.force_enable_xformers or (shared.xformers_available and torch.cuda.is_available() and (6, 0) <= torch.cuda.get_device_capability(shared.device) <= (9, 0))

    def apply(self):
        ldm.modules.attention.CrossAttention.forward = xformers_attention_forward
        ldm.modules.diffusionmodules.model.AttnBlock.forward = xformers_attnblock_forward


class SdOptimizationSdpNoMem(SdOptimization):
    name = "sdp-no-mem"
    label = "scaled dot product without memory efficient attention"
    cmd_opt = "opt_sdp_no_mem_attention"
    priority = 80

    def is_available(self):
        return hasattr(torch.nn.functional, "scaled_dot_product_attention") and callable(torch.nn.functional.scaled_dot_product_attention)

    def apply(self):
        ldm.modules.attention.CrossAttention.forward = scaled_dot_product_no_mem_attention_forward
        ldm.modules.diffusionmodules.model.AttnBlock.forward = sdp_no_mem_attnblock_forward


class SdOptimizationSdp(SdOptimizationSdpNoMem):
    name = "sdp"
    label = "scaled dot product"
    cmd_opt = "opt_sdp_attention"
    priority = 70

    def apply(self):
        ldm.modules.attention.CrossAttention.forward = scaled_dot_product_attention_forward
        ldm.modules.diffusionmodules.model.AttnBlock.forward = sdp_attnblock_forward


class SdOptimizationSubQuad(SdOptimization):
    name = "sub-quadratic"
    cmd_opt = "opt_sub_quad_attention"
    priority = 10

    def apply(self):
        ldm.modules.attention.CrossAttention.forward = sub_quad_attention_forward
        ldm.modules.diffusionmodules.model.AttnBlock.forward = sub_quad_attnblock_forward


class SdOptimizationV1(SdOptimization):
    name = "V1"
    label = "original v1"
    cmd_opt = "opt_split_attention_v1"
    priority = 10

    def apply(self):
        ldm.modules.attention.CrossAttention.forward = split_cross_attention_forward_v1


class SdOptimizationInvokeAI(SdOptimization):
    name = "InvokeAI"
    cmd_opt = "opt_split_attention_invokeai"

    @property
    def priority(self):
        return 1000 if not torch.cuda.is_available() else 10

    def apply(self):
        ldm.modules.attention.CrossAttention.forward = split_cross_attention_forward_invokeAI


class SdOptimizationDoggettx(SdOptimization):
    name = "Doggettx"
    cmd_opt = "opt_split_attention"
    priority = 90

    def apply(self):
        ldm.modules.attention.CrossAttention.forward = split_cross_attention_forward
        ldm.modules.diffusionmodules.model.AttnBlock.forward = cross_attention_attnblock_forward


def list_optimizers(res):
    res.extend([
        SdOptimizationXformers(),
        SdOptimizationSdpNoMem(),
        SdOptimizationSdp(),
        SdOptimizationSubQuad(),
        SdOptimizationV1(),
        SdOptimizationInvokeAI(),
        SdOptimizationDoggettx(),
    ])


if shared.cmd_opts.xformers or shared.cmd_opts.force_enable_xformers:
@@ -20,8 +138,7 @@ if shared.cmd_opts.xformers or shared.cmd_opts.force_enable_xformers:
        import xformers.ops
        shared.xformers_available = True
    except Exception:
        print("Cannot import xformers", file=sys.stderr)
        print(traceback.format_exc(), file=sys.stderr)
        errors.report("Cannot import xformers", exc_info=True)


def get_available_vram():
@@ -49,7 +166,7 @@ def split_cross_attention_forward_v1(self, x, context=None, mask=None):
    v_in = self.to_v(context_v)
    del context, context_k, context_v, x

    q, k, v = map(lambda t: rearrange(t, 'b n (h d) -> (b h) n d', h=h), (q_in, k_in, v_in))
    q, k, v = (rearrange(t, 'b n (h d) -> (b h) n d', h=h) for t in (q_in, k_in, v_in))
    del q_in, k_in, v_in

    dtype = q.dtype
@@ -62,10 +179,10 @@ def split_cross_attention_forward_v1(self, x, context=None, mask=None):
        end = i + 2
        s1 = einsum('b i d, b j d -> b i j', q[i:end], k[i:end])
        s1 *= self.scale

        s2 = s1.softmax(dim=-1)
        del s1

        r1[i:end] = einsum('b i j, b j d -> b i d', s2, v[i:end])
        del s2
    del q, k, v
@@ -95,43 +212,43 @@ def split_cross_attention_forward(self, x, context=None, mask=None):

    with devices.without_autocast(disable=not shared.opts.upcast_attn):
        k_in = k_in * self.scale

        del context, x

        q, k, v = map(lambda t: rearrange(t, 'b n (h d) -> (b h) n d', h=h), (q_in, k_in, v_in))
        q, k, v = (rearrange(t, 'b n (h d) -> (b h) n d', h=h) for t in (q_in, k_in, v_in))
        del q_in, k_in, v_in

        r1 = torch.zeros(q.shape[0], q.shape[1], v.shape[2], device=q.device, dtype=q.dtype)

        mem_free_total = get_available_vram()

        gb = 1024 ** 3
        tensor_size = q.shape[0] * q.shape[1] * k.shape[1] * q.element_size()
        modifier = 3 if q.element_size() == 2 else 2.5
        mem_required = tensor_size * modifier
        steps = 1

        if mem_required > mem_free_total:
            steps = 2 ** (math.ceil(math.log(mem_required / mem_free_total, 2)))
            # print(f"Expected tensor size:{tensor_size/gb:0.1f}GB, cuda free:{mem_free_cuda/gb:0.1f}GB "
            #       f"torch free:{mem_free_torch/gb:0.1f} total:{mem_free_total/gb:0.1f} steps:{steps}")

        if steps > 64:
            max_res = math.floor(math.sqrt(math.sqrt(mem_free_total / 2.5)) / 8) * 64
            raise RuntimeError(f'Not enough memory, use lower resolution (max approx. {max_res}x{max_res}). '
                               f'Need: {mem_required / 64 / gb:0.1f}GB free, Have:{mem_free_total / gb:0.1f}GB free')

        slice_size = q.shape[1] // steps if (q.shape[1] % steps) == 0 else q.shape[1]
        for i in range(0, q.shape[1], slice_size):
            end = i + slice_size
            s1 = einsum('b i d, b j d -> b i j', q[:, i:end], k)

            s2 = s1.softmax(dim=-1, dtype=q.dtype)
            del s1

            r1[:, i:end] = einsum('b i j, b j d -> b i d', s2, v)
            del s2

        del q, k, v

    r1 = r1.to(dtype)
@@ -228,8 +345,8 @@ def split_cross_attention_forward_invokeAI(self, x, context=None, mask=None):

    with devices.without_autocast(disable=not shared.opts.upcast_attn):
        k = k * self.scale

        q, k, v = map(lambda t: rearrange(t, 'b n (h d) -> (b h) n d', h=h), (q, k, v))

        q, k, v = (rearrange(t, 'b n (h d) -> (b h) n d', h=h) for t in (q, k, v))
        r = einsum_op(q, k, v)
    r = r.to(dtype)
    return self.to_out(rearrange(r, '(b h) n d -> b n (h d)', h=h))
@@ -296,11 +413,10 @@ def sub_quad_attention(q, k, v, q_chunk_size=1024, kv_chunk_size=None, kv_chunk_
    if chunk_threshold_bytes is not None and qk_matmul_size_bytes <= chunk_threshold_bytes:
        # the big matmul fits into our memory limit; do everything in 1 chunk,
        # i.e. send it down the unchunked fast-path
        query_chunk_size = q_tokens
        kv_chunk_size = k_tokens

    with devices.without_autocast(disable=q.dtype == v.dtype):
        return efficient_dot_product_attention(
        return sub_quadratic_attention.efficient_dot_product_attention(
            q,
            k,
            v,
@@ -335,7 +451,7 @@ def xformers_attention_forward(self, x, context=None, mask=None):
    k_in = self.to_k(context_k)
    v_in = self.to_v(context_v)

    q, k, v = map(lambda t: rearrange(t, 'b n (h d) -> b n h d', h=h), (q_in, k_in, v_in))
    q, k, v = (rearrange(t, 'b n (h d) -> b n h d', h=h) for t in (q_in, k_in, v_in))
    del q_in, k_in, v_in

    dtype = q.dtype
@@ -370,7 +486,7 @@ def scaled_dot_product_attention_forward(self, x, context=None, mask=None):
    q = q_in.view(batch_size, -1, h, head_dim).transpose(1, 2)
    k = k_in.view(batch_size, -1, h, head_dim).transpose(1, 2)
    v = v_in.view(batch_size, -1, h, head_dim).transpose(1, 2)

    del q_in, k_in, v_in

    dtype = q.dtype
@@ -452,7 +568,7 @@ def cross_attention_attnblock_forward(self, x):
        h3 += x

        return h3


def xformers_attnblock_forward(self, x):
    try:
        h_ = x
@@ -461,7 +577,7 @@ def xformers_attnblock_forward(self, x):
        k = self.k(h_)
        v = self.v(h_)
        b, c, h, w = q.shape
        q, k, v = map(lambda t: rearrange(t, 'b c h w -> b (h w) c'), (q, k, v))
        q, k, v = (rearrange(t, 'b c h w -> b (h w) c') for t in (q, k, v))
        dtype = q.dtype
        if shared.opts.upcast_attn:
            q, k = q.float(), k.float()
@@ -483,10 +599,10 @@ def sdp_attnblock_forward(self, x):
    k = self.k(h_)
    v = self.v(h_)
    b, c, h, w = q.shape
    q, k, v = map(lambda t: rearrange(t, 'b c h w -> b (h w) c'), (q, k, v))
    q, k, v = (rearrange(t, 'b c h w -> b (h w) c') for t in (q, k, v))
    dtype = q.dtype
    if shared.opts.upcast_attn:
        q, k = q.float(), k.float()
        q, k, v = q.float(), k.float(), v.float()
    q = q.contiguous()
    k = k.contiguous()
    v = v.contiguous()
@@ -507,7 +623,7 @@ def sub_quad_attnblock_forward(self, x):
    k = self.k(h_)
    v = self.v(h_)
    b, c, h, w = q.shape
    q, k, v = map(lambda t: rearrange(t, 'b c h w -> b (h w) c'), (q, k, v))
    q, k, v = (rearrange(t, 'b c h w -> b (h w) c') for t in (q, k, v))
    q = q.contiguous()
    k = k.contiguous()
    v = v.contiguous()

@@ -1,8 +1,6 @@
import open_clip.tokenizer
import torch

from modules import sd_hijack_clip, devices
from modules.shared import opts


class FrozenXLMREmbedderWithCustomWords(sd_hijack_clip.FrozenCLIPEmbedderWithCustomWords):

@@ -14,10 +14,10 @@ import ldm.modules.midas as midas

from ldm.util import instantiate_from_config

from modules import paths, shared, modelloader, devices, script_callbacks, sd_vae, sd_disable_initialization, errors, hashes, sd_models_config
from modules.paths import models_path
from modules import paths, shared, modelloader, devices, script_callbacks, sd_vae, sd_disable_initialization, errors, hashes, sd_models_config, sd_unet
from modules.sd_hijack_inpainting import do_inpainting_hijack
from modules.timer import Timer
import tomesd

model_dir = "Stable-diffusion"
model_path = os.path.abspath(os.path.join(paths.models_path, model_dir))
@@ -87,8 +87,7 @@ class CheckpointInfo:

try:
    # this silences the annoying "Some weights of the model checkpoint were not used when initializing..." message at start.
    from transformers import logging, CLIPModel
    from transformers import logging, CLIPModel  # noqa: F401

    logging.set_verbosity_error()
except Exception:
@@ -96,10 +95,8 @@ except Exception:


def setup_model():
    if not os.path.exists(model_path):
        os.makedirs(model_path)
    os.makedirs(model_path, exist_ok=True)

    list_models()
    enable_midas_autodownload()


@@ -166,21 +163,22 @@ def model_hash(filename):


def select_checkpoint():
    """Raises `FileNotFoundError` if no checkpoints are found."""
    model_checkpoint = shared.opts.sd_model_checkpoint

    checkpoint_info = checkpoint_alisases.get(model_checkpoint, None)
    if checkpoint_info is not None:
        return checkpoint_info

    if len(checkpoints_list) == 0:
        print("No checkpoints found. When searching for checkpoints, looked at:", file=sys.stderr)
        error_message = "No checkpoints found. When searching for checkpoints, looked at:"
        if shared.cmd_opts.ckpt is not None:
            print(f" - file {os.path.abspath(shared.cmd_opts.ckpt)}", file=sys.stderr)
        print(f" - directory {model_path}", file=sys.stderr)
            error_message += f"\n - file {os.path.abspath(shared.cmd_opts.ckpt)}"
        error_message += f"\n - directory {model_path}"
        if shared.cmd_opts.ckpt_dir is not None:
            print(f" - directory {os.path.abspath(shared.cmd_opts.ckpt_dir)}", file=sys.stderr)
        print("Can't run without a checkpoint. Find and place a .ckpt or .safetensors file into any of those locations. The program will exit.", file=sys.stderr)
        exit(1)
            error_message += f"\n - directory {os.path.abspath(shared.cmd_opts.ckpt_dir)}"
        error_message += "Can't run without a checkpoint. Find and place a .ckpt or .safetensors file into any of those locations."
        raise FileNotFoundError(error_message)

    checkpoint_info = next(iter(checkpoints_list.values()))
    if model_checkpoint is not None:
@@ -239,7 +237,7 @@ def read_metadata_from_safetensors(filename):
        if isinstance(v, str) and v[0:1] == '{':
            try:
                res[k] = json.loads(v)
            except Exception as e:
            except Exception:
                pass

    return res
@@ -249,7 +247,12 @@ def read_state_dict(checkpoint_file, print_global_state=False, map_location=None
    _, extension = os.path.splitext(checkpoint_file)
    if extension.lower() == ".safetensors":
        device = map_location or shared.weight_load_location or devices.get_optimal_device_name()
        pl_sd = safetensors.torch.load_file(checkpoint_file, device=device)

        if not shared.opts.disable_mmap_load_safetensors:
            pl_sd = safetensors.torch.load_file(checkpoint_file, device=device)
        else:
            pl_sd = safetensors.torch.load(open(checkpoint_file, 'rb').read())
            pl_sd = {k: v.to(device) for k, v in pl_sd.items()}
    else:
        pl_sd = torch.load(checkpoint_file, map_location=map_location or shared.weight_load_location)

@@ -315,8 +318,6 @@ def load_model_weights(model, checkpoint_info: CheckpointInfo, state_dict, timer

    timer.record("apply half()")

    devices.dtype = torch.float32 if shared.cmd_opts.no_half else torch.float16
    devices.dtype_vae = torch.float32 if shared.cmd_opts.no_half or shared.cmd_opts.no_half_vae else torch.float16
    devices.dtype_unet = model.model.diffusion_model.dtype
    devices.unet_needs_upcast = shared.cmd_opts.upcast_sampling and devices.dtype == torch.float16 and devices.dtype_unet == torch.float16

@@ -374,7 +375,7 @@ def enable_midas_autodownload():
        if not os.path.exists(path):
            if not os.path.exists(midas_path):
                mkdir(midas_path)

            print(f"Downloading midas model weights for {model_type} to {path}")
            request.urlretrieve(midas_urls[model_type], path)
            print(f"{model_type} downloaded")
@@ -410,15 +411,22 @@ sd2_clip_weight = 'cond_stage_model.model.transformer.resblocks.0.attn.in_proj_w
class SdModelData:
    def __init__(self):
        self.sd_model = None
        self.was_loaded_at_least_once = False
        self.lock = threading.Lock()

    def get_sd_model(self):
        if self.was_loaded_at_least_once:
            return self.sd_model

        if self.sd_model is None:
            with self.lock:
                if self.sd_model is not None or self.was_loaded_at_least_once:
                    return self.sd_model

                try:
                    load_model()
                except Exception as e:
                    errors.display(e, "loading stable diffusion model")
                    errors.display(e, "loading stable diffusion model", full_traceback=True)
                    print("", file=sys.stderr)
                    print("Stable diffusion model failed to load", file=sys.stderr)
                    self.sd_model = None
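
# Illustrative sketch (not part of the commit): the guard above is the classic
# double-checked locking pattern, shown here in isolation with generic names.
import threading

_lock = threading.Lock()
_model = None

def get_model():
    global _model
    if _model is None:          # first check without the lock (fast path)
        with _lock:
            if _model is None:  # re-check: another thread may have loaded it
                _model = object()  # stand-in for the expensive load_model()
    return _model
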
@@ -467,7 +475,7 @@ def load_model(checkpoint_info=None, already_loaded_state_dict=None):
    try:
        with sd_disable_initialization.DisableInitialization(disable_clip=clip_is_included_into_sd):
            sd_model = instantiate_from_config(sd_config.model)
    except Exception as e:
    except Exception:
        pass

    if sd_model is None:
@@ -493,6 +501,7 @@ def load_model(checkpoint_info=None, already_loaded_state_dict=None):

    sd_model.eval()
    model_data.sd_model = sd_model
    model_data.was_loaded_at_least_once = True

    sd_hijack.model_hijack.embedding_db.load_textual_inversion_embeddings(force_reload=True)  # Reload embeddings after model load as they may or may not fit the model

@@ -502,6 +511,11 @@ def load_model(checkpoint_info=None, already_loaded_state_dict=None):

    timer.record("scripts callbacks")

    with devices.autocast(), torch.no_grad():
        sd_model.cond_stage_model_empty_prompt = sd_model.cond_stage_model([""])

    timer.record("calculate empty prompt")

    print(f"Model loaded in {timer.summary()}.")

    return sd_model
@@ -521,6 +535,8 @@ def reload_model_weights(sd_model=None, info=None):
        if sd_model.sd_model_checkpoint == checkpoint_info.filename:
            return

    sd_unet.apply_unet("None")

    if shared.cmd_opts.lowvram or shared.cmd_opts.medvram:
        lowvram.send_everything_to_cpu()
    else:
@@ -538,13 +554,12 @@ def reload_model_weights(sd_model=None, info=None):

    if sd_model is None or checkpoint_config != sd_model.used_config:
        del sd_model
        checkpoints_loaded.clear()
        load_model(checkpoint_info, already_loaded_state_dict=state_dict)
        return model_data.sd_model

    try:
        load_model_weights(sd_model, checkpoint_info, state_dict, timer)
    except Exception as e:
    except Exception:
        print("Failed to load checkpoint, restoring previous")
        load_model_weights(sd_model, current_checkpoint_info, None, timer)
        raise
@@ -565,7 +580,7 @@ def reload_model_weights(sd_model=None, info=None):


def unload_model_weights(sd_model=None, info=None):
    from modules import lowvram, devices, sd_hijack
    from modules import devices, sd_hijack
    timer = Timer()

    if model_data.sd_model:
@@ -580,3 +595,29 @@ def unload_model_weights(sd_model=None, info=None):
    print(f"Unloaded weights {timer.summary()}.")

    return sd_model


def apply_token_merging(sd_model, token_merging_ratio):
    """
    Applies speed and memory optimizations from tomesd.
    """

    current_token_merging_ratio = getattr(sd_model, 'applied_token_merged_ratio', 0)

    if current_token_merging_ratio == token_merging_ratio:
        return

    if current_token_merging_ratio > 0:
        tomesd.remove_patch(sd_model)

    if token_merging_ratio > 0:
        tomesd.apply_patch(
            sd_model,
            ratio=token_merging_ratio,
            use_rand=False,  # can cause issues with some samplers
            merge_attn=True,
            merge_crossattn=False,
            merge_mlp=False
        )

    sd_model.applied_token_merged_ratio = token_merging_ratio

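# Illustrative sketch (not part of the commit): how apply_token_merging is
# meant to be driven from a setting; the 0.5 ratio is an arbitrary example.
from modules import sd_models, shared

sd_models.apply_token_merging(shared.sd_model, 0.5)  # patch the model with tomesd
sd_models.apply_token_merging(shared.sd_model, 0.0)  # remove the patch again
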
@@ -1,4 +1,3 @@
import re
import os

import torch

@@ -1,7 +1,7 @@
from modules import sd_samplers_compvis, sd_samplers_kdiffusion, shared

# imports for functions that previously were here and are used by other modules
from modules.sd_samplers_common import samples_to_image_grid, sample_to_image
from modules.sd_samplers_common import samples_to_image_grid, sample_to_image  # noqa: F401

all_samplers = [
    *sd_samplers_kdiffusion.samplers_data_k_diffusion,
@@ -14,12 +14,18 @@ samplers_for_img2img = []
samplers_map = {}


def create_sampler(name, model):
def find_sampler_config(name):
    if name is not None:
        config = all_samplers_map.get(name, None)
    else:
        config = all_samplers[0]

    return config


def create_sampler(name, model):
    config = find_sampler_config(name)

    assert config is not None, f'bad sampler name: {name}'

    sampler = config.constructor(model)

@@ -2,7 +2,7 @@ from collections import namedtuple
import numpy as np
import torch
from PIL import Image
from modules import devices, processing, images, sd_vae_approx
from modules import devices, processing, images, sd_vae_approx, sd_samplers, sd_vae_taesd

from modules.shared import opts, state
import modules.shared as shared
@@ -22,7 +22,7 @@ def setup_img2img_steps(p, steps=None):
    return steps, t_enc


approximation_indexes = {"Full": 0, "Approx NN": 1, "Approx cheap": 2}
approximation_indexes = {"Full": 0, "Approx NN": 1, "Approx cheap": 2, "TAESD": 3}


def single_sample_to_image(sample, approximation=None):
@@ -30,15 +30,19 @@ def single_sample_to_image(sample, approximation=None):
        approximation = approximation_indexes.get(opts.show_progress_type, 0)

    if approximation == 2:
        x_sample = sd_vae_approx.cheap_approximation(sample)
        x_sample = sd_vae_approx.cheap_approximation(sample) * 0.5 + 0.5
    elif approximation == 1:
        x_sample = sd_vae_approx.model()(sample.to(devices.device, devices.dtype).unsqueeze(0))[0].detach()
        x_sample = sd_vae_approx.model()(sample.to(devices.device, devices.dtype).unsqueeze(0))[0].detach() * 0.5 + 0.5
    elif approximation == 3:
        x_sample = sample * 1.5
        x_sample = sd_vae_taesd.model()(x_sample.to(devices.device, devices.dtype).unsqueeze(0))[0].detach()
    else:
        x_sample = processing.decode_first_stage(shared.sd_model, sample.unsqueeze(0))[0]
        x_sample = processing.decode_first_stage(shared.sd_model, sample.unsqueeze(0))[0] * 0.5 + 0.5

    x_sample = torch.clamp((x_sample + 1.0) / 2.0, min=0.0, max=1.0)
    x_sample = torch.clamp(x_sample, min=0.0, max=1.0)
    x_sample = 255. * np.moveaxis(x_sample.cpu().numpy(), 0, 2)
    x_sample = x_sample.astype(np.uint8)

    return Image.fromarray(x_sample)


@@ -58,6 +62,25 @@ def store_latent(decoded):
|
||||
shared.state.assign_current_image(sample_to_image(decoded))
|
||||
|
||||
|
||||
def is_sampler_using_eta_noise_seed_delta(p):
|
||||
"""returns whether sampler from config will use eta noise seed delta for image creation"""
|
||||
|
||||
sampler_config = sd_samplers.find_sampler_config(p.sampler_name)
|
||||
|
||||
eta = p.eta
|
||||
|
||||
if eta is None and p.sampler is not None:
|
||||
eta = p.sampler.eta
|
||||
|
||||
if eta is None and sampler_config is not None:
|
||||
eta = 0 if sampler_config.options.get("default_eta_is_0", False) else 1.0
|
||||
|
||||
if eta == 0:
|
||||
return False
|
||||
|
||||
return sampler_config.options.get("uses_ensd", False)
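# ENSD (opts.eta_noise_seed_delta) offsets the seed used to generate sampler
# noise, so it can only affect output when the sampler both consumes eta
# noise (eta != 0) and is flagged "uses_ensd" - exactly the two conditions
# checked above.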


class InterruptedException(BaseException):
    pass
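# Deriving from BaseException rather than Exception is deliberate: a user
# interrupt must not be swallowed by broad `except Exception` handlers in
# sampling code or extensions.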


@@ -11,7 +11,7 @@ import modules.models.diffusion.uni_pc

samplers_data_compvis = [
    sd_samplers_common.SamplerData('DDIM', lambda model: VanillaStableDiffusionSampler(ldm.models.diffusion.ddim.DDIMSampler, model), [], {}),
    sd_samplers_common.SamplerData('DDIM', lambda model: VanillaStableDiffusionSampler(ldm.models.diffusion.ddim.DDIMSampler, model), [], {"default_eta_is_0": True, "uses_ensd": True}),
    sd_samplers_common.SamplerData('PLMS', lambda model: VanillaStableDiffusionSampler(ldm.models.diffusion.plms.PLMSSampler, model), [], {}),
    sd_samplers_common.SamplerData('UniPC', lambda model: VanillaStableDiffusionSampler(modules.models.diffusion.uni_pc.UniPCSampler, model), [], {}),
]

@@ -55,7 +55,7 @@ class VanillaStableDiffusionSampler:
    def p_sample_ddim_hook(self, x_dec, cond, ts, unconditional_conditioning, *args, **kwargs):
        x_dec, ts, cond, unconditional_conditioning = self.before_sample(x_dec, ts, cond, unconditional_conditioning)

        res = self.orig_p_sample_ddim(x_dec, cond, ts, unconditional_conditioning=unconditional_conditioning, *args, **kwargs)
        res = self.orig_p_sample_ddim(x_dec, cond, ts, *args, unconditional_conditioning=unconditional_conditioning, **kwargs)

        x_dec, ts, cond, unconditional_conditioning, res = self.after_sample(x_dec, ts, cond, unconditional_conditioning, res)

@@ -83,7 +83,7 @@ class VanillaStableDiffusionSampler:
        conds_list, tensor = prompt_parser.reconstruct_multicond_batch(cond, self.step)
        unconditional_conditioning = prompt_parser.reconstruct_cond_batch(unconditional_conditioning, self.step)

        assert all([len(conds) == 1 for conds in conds_list]), 'composition via AND is not supported for DDIM/PLMS samplers'
        assert all(len(conds) == 1 for conds in conds_list), 'composition via AND is not supported for DDIM/PLMS samplers'
        cond = tensor

        # for DDIM, shapes must match, we can't just process cond and uncond independently;

@@ -134,7 +134,11 @@ class VanillaStableDiffusionSampler:
        self.update_step(x)

    def initialize(self, p):
        self.eta = p.eta if p.eta is not None else shared.opts.eta_ddim
        if self.is_ddim:
            self.eta = p.eta if p.eta is not None else shared.opts.eta_ddim
        else:
            self.eta = 0.0

        if self.eta != 0.0:
            p.extra_generation_params["Eta DDIM"] = self.eta


@@ -1,7 +1,6 @@
from collections import deque
import torch
import inspect
import einops
import k_diffusion.sampling
from modules import prompt_parser, devices, sd_samplers_common

@@ -9,25 +8,28 @@ from modules.shared import opts, state
import modules.shared as shared
from modules.script_callbacks import CFGDenoiserParams, cfg_denoiser_callback
from modules.script_callbacks import CFGDenoisedParams, cfg_denoised_callback
from modules.script_callbacks import AfterCFGCallbackParams, cfg_after_cfg_callback

samplers_k_diffusion = [
    ('Euler a', 'sample_euler_ancestral', ['k_euler_a', 'k_euler_ancestral'], {}),
    ('Euler a', 'sample_euler_ancestral', ['k_euler_a', 'k_euler_ancestral'], {"uses_ensd": True}),
    ('Euler', 'sample_euler', ['k_euler'], {}),
    ('LMS', 'sample_lms', ['k_lms'], {}),
    ('Heun', 'sample_heun', ['k_heun'], {}),
    ('Heun', 'sample_heun', ['k_heun'], {"second_order": True}),
    ('DPM2', 'sample_dpm_2', ['k_dpm_2'], {'discard_next_to_last_sigma': True}),
    ('DPM2 a', 'sample_dpm_2_ancestral', ['k_dpm_2_a'], {'discard_next_to_last_sigma': True}),
    ('DPM++ 2S a', 'sample_dpmpp_2s_ancestral', ['k_dpmpp_2s_a'], {}),
    ('DPM2 a', 'sample_dpm_2_ancestral', ['k_dpm_2_a'], {'discard_next_to_last_sigma': True, "uses_ensd": True}),
    ('DPM++ 2S a', 'sample_dpmpp_2s_ancestral', ['k_dpmpp_2s_a'], {"uses_ensd": True, "second_order": True}),
    ('DPM++ 2M', 'sample_dpmpp_2m', ['k_dpmpp_2m'], {}),
    ('DPM++ SDE', 'sample_dpmpp_sde', ['k_dpmpp_sde'], {}),
    ('DPM fast', 'sample_dpm_fast', ['k_dpm_fast'], {}),
    ('DPM adaptive', 'sample_dpm_adaptive', ['k_dpm_ad'], {}),
    ('DPM++ SDE', 'sample_dpmpp_sde', ['k_dpmpp_sde'], {"second_order": True, "brownian_noise": True}),
    ('DPM++ 2M SDE', 'sample_dpmpp_2m_sde', ['k_dpmpp_2m_sde_ka'], {"brownian_noise": True}),
    ('DPM fast', 'sample_dpm_fast', ['k_dpm_fast'], {"uses_ensd": True}),
    ('DPM adaptive', 'sample_dpm_adaptive', ['k_dpm_ad'], {"uses_ensd": True}),
    ('LMS Karras', 'sample_lms', ['k_lms_ka'], {'scheduler': 'karras'}),
    ('DPM2 Karras', 'sample_dpm_2', ['k_dpm_2_ka'], {'scheduler': 'karras', 'discard_next_to_last_sigma': True}),
    ('DPM2 a Karras', 'sample_dpm_2_ancestral', ['k_dpm_2_a_ka'], {'scheduler': 'karras', 'discard_next_to_last_sigma': True}),
    ('DPM++ 2S a Karras', 'sample_dpmpp_2s_ancestral', ['k_dpmpp_2s_a_ka'], {'scheduler': 'karras'}),
    ('DPM2 Karras', 'sample_dpm_2', ['k_dpm_2_ka'], {'scheduler': 'karras', 'discard_next_to_last_sigma': True, "uses_ensd": True, "second_order": True}),
    ('DPM2 a Karras', 'sample_dpm_2_ancestral', ['k_dpm_2_a_ka'], {'scheduler': 'karras', 'discard_next_to_last_sigma': True, "uses_ensd": True, "second_order": True}),
    ('DPM++ 2S a Karras', 'sample_dpmpp_2s_ancestral', ['k_dpmpp_2s_a_ka'], {'scheduler': 'karras', "uses_ensd": True, "second_order": True}),
    ('DPM++ 2M Karras', 'sample_dpmpp_2m', ['k_dpmpp_2m_ka'], {'scheduler': 'karras'}),
    ('DPM++ SDE Karras', 'sample_dpmpp_sde', ['k_dpmpp_sde_ka'], {'scheduler': 'karras'}),
    ('DPM++ SDE Karras', 'sample_dpmpp_sde', ['k_dpmpp_sde_ka'], {'scheduler': 'karras', "second_order": True, "brownian_noise": True}),
    ('DPM++ 2M SDE Karras', 'sample_dpmpp_2m_sde', ['k_dpmpp_2m_sde_ka'], {'scheduler': 'karras', "brownian_noise": True}),
]
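# Flags used in the options dicts above:
#   uses_ensd - sampler noise is affected by Eta noise seed delta (ENSD)
#   second_order - sampler evaluates the model twice per step, so step counts
#       and prompt-edit schedules need adjusting
#   brownian_noise - sampler gets a Brownian-tree noise sampler (see
#       create_noise_sampler usage below)
#   discard_next_to_last_sigma - drop the next-to-last sigma from the
#       schedule, a quality workaround for the DPM2 variants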

samplers_data_k_diffusion = [

@@ -42,6 +44,14 @@ sampler_extra_params = {
    'sample_dpm_2': ['s_churn', 's_tmin', 's_tmax', 's_noise'],
}

k_diffusion_samplers_map = {x.name: x for x in samplers_data_k_diffusion}
k_diffusion_scheduler = {
    'Automatic': None,
    'karras': k_diffusion.sampling.get_sigmas_karras,
    'exponential': k_diffusion.sampling.get_sigmas_exponential,
    'polyexponential': k_diffusion.sampling.get_sigmas_polyexponential
}


class CFGDenoiser(torch.nn.Module):
    """
@@ -59,6 +69,7 @@ class CFGDenoiser(torch.nn.Module):
        self.init_latent = None
        self.step = 0
        self.image_cfg_scale = None
        self.padded_cond_uncond = False

    def combine_denoised(self, x_out, conds_list, uncond, cond_scale):
        denoised_uncond = x_out[-uncond.shape[0]:]

@@ -87,17 +98,17 @@ class CFGDenoiser(torch.nn.Module):
        conds_list, tensor = prompt_parser.reconstruct_multicond_batch(cond, self.step)
        uncond = prompt_parser.reconstruct_cond_batch(uncond, self.step)

        assert not is_edit_model or all([len(conds) == 1 for conds in conds_list]), "AND is not supported for InstructPix2Pix checkpoint (unless using Image CFG scale = 1.0)"
        assert not is_edit_model or all(len(conds) == 1 for conds in conds_list), "AND is not supported for InstructPix2Pix checkpoint (unless using Image CFG scale = 1.0)"

        batch_size = len(conds_list)
        repeats = [len(conds_list[i]) for i in range(batch_size)]

        if shared.sd_model.model.conditioning_key == "crossattn-adm":
            image_uncond = torch.zeros_like(image_cond)
            make_condition_dict = lambda c_crossattn, c_adm: {"c_crossattn": c_crossattn, "c_adm": c_adm}
        else:
            image_uncond = image_cond
            make_condition_dict = lambda c_crossattn, c_concat: {"c_crossattn": c_crossattn, "c_concat": [c_concat]}

        if not is_edit_model:
            x_in = torch.cat([torch.stack([x[i] for _ in range(n)]) for i, n in enumerate(repeats)] + [x])

@@ -123,6 +134,18 @@ class CFGDenoiser(torch.nn.Module):
            x_in = x_in[:-batch_size]
            sigma_in = sigma_in[:-batch_size]

        self.padded_cond_uncond = False
        if shared.opts.pad_cond_uncond and tensor.shape[1] != uncond.shape[1]:
            empty = shared.sd_model.cond_stage_model_empty_prompt
            num_repeats = (tensor.shape[1] - uncond.shape[1]) // empty.shape[1]

            if num_repeats < 0:
                tensor = torch.cat([tensor, empty.repeat((tensor.shape[0], -num_repeats, 1))], axis=1)
                self.padded_cond_uncond = True
            elif num_repeats > 0:
                uncond = torch.cat([uncond, empty.repeat((uncond.shape[0], num_repeats, 1))], axis=1)
                self.padded_cond_uncond = True
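        # Prompts are encoded in 75-token chunks, so cond and uncond can end up
        # with different sequence lengths; padding the shorter one with the
        # encoding of an empty prompt (in whole chunks) lets both be batched
        # through the unet in a single call.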

        if tensor.shape[1] == uncond.shape[1] or skip_uncond:
            if is_edit_model:
                cond_in = torch.cat([tensor, uncond, uncond])

@@ -161,7 +184,7 @@ class CFGDenoiser(torch.nn.Module):
            fake_uncond = torch.cat([x_out[i:i+1] for i in denoised_image_indexes])
            x_out = torch.cat([x_out, fake_uncond])  # we skipped uncond denoising, so we put cond-denoised image to where the uncond-denoised image should be

        denoised_params = CFGDenoisedParams(x_out, state.sampling_step, state.sampling_steps)
        denoised_params = CFGDenoisedParams(x_out, state.sampling_step, state.sampling_steps, self.inner_model)
        cfg_denoised_callback(denoised_params)

        devices.test_for_nans(x_out, "unet")

@@ -181,6 +204,10 @@ class CFGDenoiser(torch.nn.Module):
        if self.mask is not None:
            denoised = self.init_latent * self.mask + self.nmask * denoised

        after_cfg_callback_params = AfterCFGCallbackParams(denoised, state.sampling_step, state.sampling_steps)
        cfg_after_cfg_callback(after_cfg_callback_params)
        denoised = after_cfg_callback_params.x

        self.step += 1
        return denoised
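# Per-step callback order in CFGDenoiser.forward: cfg_denoiser_callback fires
# earlier in forward (not shown in these hunks), cfg_denoised_callback right
# after the unet runs, and cfg_after_cfg_callback last, once CFG mixing and
# inpainting masking are done; that last callback may replace `denoised`
# wholesale via params.x.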


@@ -224,7 +251,7 @@ class KDiffusionSampler:
        self.sampler_noises = None
        self.stop_at = None
        self.eta = None
        self.config = None
        self.config = None  # set by the function calling the constructor
        self.last_latent = None
        self.s_min_uncond = None

@@ -249,6 +276,13 @@ class KDiffusionSampler:

        try:
            return func()
        except RecursionError:
            print(
                'Encountered RecursionError during sampling, returning last latent. '
                'rho >5 with a polyexponential scheduler may cause this error. '
                'You should try to use a smaller rho value instead.'
            )
            return self.last_latent
        except sd_samplers_common.InterruptedException:
            return self.last_latent

@@ -288,6 +322,31 @@ class KDiffusionSampler:

        if p.sampler_noise_scheduler_override:
            sigmas = p.sampler_noise_scheduler_override(steps)
        elif opts.k_sched_type != "Automatic":
            m_sigma_min, m_sigma_max = (self.model_wrap.sigmas[0].item(), self.model_wrap.sigmas[-1].item())
            sigma_min, sigma_max = (0.1, 10) if opts.use_old_karras_scheduler_sigmas else (m_sigma_min, m_sigma_max)
            sigmas_kwargs = {
                'sigma_min': sigma_min,
                'sigma_max': sigma_max,
            }

            sigmas_func = k_diffusion_scheduler[opts.k_sched_type]
            p.extra_generation_params["Schedule type"] = opts.k_sched_type

            if opts.sigma_min != m_sigma_min and opts.sigma_min != 0:
                sigmas_kwargs['sigma_min'] = opts.sigma_min
                p.extra_generation_params["Schedule min sigma"] = opts.sigma_min
            if opts.sigma_max != m_sigma_max and opts.sigma_max != 0:
                sigmas_kwargs['sigma_max'] = opts.sigma_max
                p.extra_generation_params["Schedule max sigma"] = opts.sigma_max

            default_rho = 1. if opts.k_sched_type == "polyexponential" else 7.

            if opts.k_sched_type != 'exponential' and opts.rho != 0 and opts.rho != default_rho:
                sigmas_kwargs['rho'] = opts.rho
                p.extra_generation_params["Schedule rho"] = opts.rho

            sigmas = sigmas_func(n=steps, **sigmas_kwargs, device=shared.device)
        elif self.config is not None and self.config.options.get('scheduler', None) == 'karras':
            sigma_min, sigma_max = (0.1, 10) if opts.use_old_karras_scheduler_sigmas else (self.model_wrap.sigmas[0].item(), self.model_wrap.sigmas[-1].item())
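            # For reference, get_sigmas_karras spaces noise levels as
            #   sigma_i = (sigma_max^(1/rho) + i/(n-1) * (sigma_min^(1/rho) - sigma_max^(1/rho)))^rho
            # with rho defaulting to 7, so larger rho packs more of the steps
            # into the low-noise end of the schedule.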

@@ -317,7 +376,7 @@ class KDiffusionSampler:

        sigma_sched = sigmas[steps - t_enc - 1:]
        xi = x + noise * sigma_sched[0]

        extra_params_kwargs = self.initialize(p)
        parameters = inspect.signature(self.func).parameters

@@ -333,22 +392,25 @@ class KDiffusionSampler:
        if 'sigmas' in parameters:
            extra_params_kwargs['sigmas'] = sigma_sched

        if self.funcname == 'sample_dpmpp_sde':
        if self.config.options.get('brownian_noise', False):
            noise_sampler = self.create_noise_sampler(x, sigmas, p)
            extra_params_kwargs['noise_sampler'] = noise_sampler

        self.model_wrap_cfg.init_latent = x
        self.last_latent = x
        extra_args={
            'cond': conditioning,
            'image_cond': image_conditioning,
            'uncond': unconditional_conditioning,
        extra_args = {
            'cond': conditioning,
            'image_cond': image_conditioning,
            'uncond': unconditional_conditioning,
            'cond_scale': p.cfg_scale,
            's_min_uncond': self.s_min_uncond
        }

        samples = self.launch_sampling(t_enc + 1, lambda: self.func(self.model_wrap_cfg, xi, extra_args=extra_args, disable=False, callback=self.callback_state, **extra_params_kwargs))

        if self.model_wrap_cfg.padded_cond_uncond:
            p.extra_generation_params["Pad conds"] = True

        return samples

    def sample(self, p, x, conditioning, unconditional_conditioning, steps=None, image_conditioning=None):
@@ -369,18 +431,21 @@ class KDiffusionSampler:
        else:
            extra_params_kwargs['sigmas'] = sigmas

        if self.funcname == 'sample_dpmpp_sde':
        if self.config.options.get('brownian_noise', False):
            noise_sampler = self.create_noise_sampler(x, sigmas, p)
            extra_params_kwargs['noise_sampler'] = noise_sampler

        self.last_latent = x
        samples = self.launch_sampling(steps, lambda: self.func(self.model_wrap_cfg, x, extra_args={
            'cond': conditioning,
            'image_cond': image_conditioning,
            'uncond': unconditional_conditioning,
            'cond': conditioning,
            'image_cond': image_conditioning,
            'uncond': unconditional_conditioning,
            'cond_scale': p.cfg_scale,
            's_min_uncond': self.s_min_uncond
        }, disable=False, callback=self.callback_state, **extra_params_kwargs))

        if self.model_wrap_cfg.padded_cond_uncond:
            p.extra_generation_params["Pad conds"] = True

        return samples


92
modules/sd_unet.py
Normal file
@@ -0,0 +1,92 @@
import torch.nn
import ldm.modules.diffusionmodules.openaimodel

from modules import script_callbacks, shared, devices

unet_options = []
current_unet_option = None
current_unet = None


def list_unets():
    new_unets = script_callbacks.list_unets_callback()

    unet_options.clear()
    unet_options.extend(new_unets)


def get_unet_option(option=None):
    option = option or shared.opts.sd_unet

    if option == "None":
        return None

    if option == "Automatic":
        name = shared.sd_model.sd_checkpoint_info.model_name

        options = [x for x in unet_options if x.model_name == name]

        option = options[0].label if options else "None"

    return next(iter([x for x in unet_options if x.label == option]), None)


def apply_unet(option=None):
    global current_unet_option
    global current_unet

    new_option = get_unet_option(option)
    if new_option == current_unet_option:
        return

    if current_unet is not None:
        print(f"Deactivating unet: {current_unet.option.label}")
        current_unet.deactivate()

    current_unet_option = new_option
    if current_unet_option is None:
        current_unet = None

        if not (shared.cmd_opts.lowvram or shared.cmd_opts.medvram):
            shared.sd_model.model.diffusion_model.to(devices.device)

        return

    shared.sd_model.model.diffusion_model.to(devices.cpu)
    devices.torch_gc()

    current_unet = current_unet_option.create_unet()
    current_unet.option = current_unet_option
    print(f"Activating unet: {current_unet.option.label}")
    current_unet.activate()


class SdUnetOption:
    model_name = None
    """name of related checkpoint - this option will be selected automatically for unet if the name of checkpoint matches this"""

    label = None
    """name of the unet in UI"""

    def create_unet(self):
        """returns SdUnet object to be used as a Unet instead of built-in unet when making pictures"""
        raise NotImplementedError()


class SdUnet(torch.nn.Module):
    def forward(self, x, timesteps, context, *args, **kwargs):
        raise NotImplementedError()

    def activate(self):
        pass

    def deactivate(self):
        pass


def UNetModel_forward(self, x, timesteps=None, context=None, *args, **kwargs):
    if current_unet is not None:
        return current_unet.forward(x, timesteps, context, *args, **kwargs)

    return ldm.modules.diffusionmodules.openaimodel.copy_of_UNetModel_forward_for_webui(self, x, timesteps, context, *args, **kwargs)
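A hedged sketch of the extension side of this API: the registration hook `script_callbacks.on_list_unets` and a callback that receives a list to append to are assumptions inferred from `list_unets_callback()` above, and `MyUnet`/`MyUnetOption` are hypothetical names, not part of this commit.

from modules import script_callbacks, sd_unet

class MyUnetOption(sd_unet.SdUnetOption):
    model_name = "my-checkpoint"  # hypothetical checkpoint name to auto-match
    label = "My optimized unet"

    def create_unet(self):
        return MyUnet()

class MyUnet(sd_unet.SdUnet):
    def forward(self, x, timesteps, context, *args, **kwargs):
        # run the replacement unet here and return its prediction
        raise NotImplementedError()

def on_list_unets(unet_list):
    unet_list.append(MyUnetOption())

script_callbacks.on_list_unets(on_list_unets)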

@@ -1,8 +1,5 @@
import torch
import safetensors.torch
import os
import collections
from collections import namedtuple
from modules import paths, shared, devices, script_callbacks, sd_models
import glob
from copy import deepcopy

@@ -88,10 +85,10 @@ def refresh_vae_list():


def find_vae_near_checkpoint(checkpoint_file):
    checkpoint_path = os.path.splitext(checkpoint_file)[0]
    for vae_location in [f"{checkpoint_path}.vae.pt", f"{checkpoint_path}.vae.ckpt", f"{checkpoint_path}.vae.safetensors"]:
        if os.path.isfile(vae_location):
            return vae_location
    checkpoint_path = os.path.basename(checkpoint_file).rsplit('.', 1)[0]
    for vae_file in vae_dict.values():
        if os.path.basename(vae_file).startswith(checkpoint_path):
            return vae_file

    return None
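# Behavior change: the old lookup only found "<checkpoint>.vae.{pt,ckpt,safetensors}"
# files sitting directly next to the checkpoint; the new one matches any entry in
# vae_dict whose basename starts with the checkpoint's basename, wherever it lives.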


88
modules/sd_vae_taesd.py
Normal file
@@ -0,0 +1,88 @@
"""
Tiny AutoEncoder for Stable Diffusion
(DNN for encoding / decoding SD's latent space)

https://github.com/madebyollin/taesd
"""
import os
import torch
import torch.nn as nn

from modules import devices, paths_internal

sd_vae_taesd = None


def conv(n_in, n_out, **kwargs):
    return nn.Conv2d(n_in, n_out, 3, padding=1, **kwargs)


class Clamp(nn.Module):
    @staticmethod
    def forward(x):
        return torch.tanh(x / 3) * 3


class Block(nn.Module):
    def __init__(self, n_in, n_out):
        super().__init__()
        self.conv = nn.Sequential(conv(n_in, n_out), nn.ReLU(), conv(n_out, n_out), nn.ReLU(), conv(n_out, n_out))
        self.skip = nn.Conv2d(n_in, n_out, 1, bias=False) if n_in != n_out else nn.Identity()
        self.fuse = nn.ReLU()

    def forward(self, x):
        return self.fuse(self.conv(x) + self.skip(x))


def decoder():
    return nn.Sequential(
        Clamp(), conv(4, 64), nn.ReLU(),
        Block(64, 64), Block(64, 64), Block(64, 64), nn.Upsample(scale_factor=2), conv(64, 64, bias=False),
        Block(64, 64), Block(64, 64), Block(64, 64), nn.Upsample(scale_factor=2), conv(64, 64, bias=False),
        Block(64, 64), Block(64, 64), Block(64, 64), nn.Upsample(scale_factor=2), conv(64, 64, bias=False),
        Block(64, 64), conv(64, 3),
    )


class TAESD(nn.Module):
    latent_magnitude = 3
    latent_shift = 0.5

    def __init__(self, decoder_path="taesd_decoder.pth"):
        """Initialize pretrained TAESD on the given device from the given checkpoints."""
        super().__init__()
        self.decoder = decoder()
        self.decoder.load_state_dict(
            torch.load(decoder_path, map_location='cpu' if devices.device.type != 'cuda' else None))

    @staticmethod
    def unscale_latents(x):
        """[0, 1] -> raw latents"""
        return x.sub(TAESD.latent_shift).mul(2 * TAESD.latent_magnitude)
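    # The decoder works on latents mapped into [0, 1]:
    #   scaled = raw / (2 * latent_magnitude) + latent_shift
    # unscale_latents above is the exact inverse of that mapping.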


def download_model(model_path):
    model_url = 'https://github.com/madebyollin/taesd/raw/main/taesd_decoder.pth'

    if not os.path.exists(model_path):
        os.makedirs(os.path.dirname(model_path), exist_ok=True)

        print(f'Downloading TAESD decoder to: {model_path}')
        torch.hub.download_url_to_file(model_url, model_path)


def model():
    global sd_vae_taesd

    if sd_vae_taesd is None:
        model_path = os.path.join(paths_internal.models_path, "VAE-taesd", "taesd_decoder.pth")
        download_model(model_path)

        if os.path.exists(model_path):
            sd_vae_taesd = TAESD(model_path)
            sd_vae_taesd.eval()
            sd_vae_taesd.to(devices.device, devices.dtype)
        else:
            raise FileNotFoundError('TAESD model not found')

    return sd_vae_taesd.decoder
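# Decoding mirrors the approximation == 3 branch added to
# sd_samplers_common.single_sample_to_image in this same commit:
#     x_sample = sample * 1.5
#     x_sample = sd_vae_taesd.model()(x_sample.to(devices.device, devices.dtype).unsqueeze(0))[0].detach()
# i.e. the caller rescales the latent and gets back an image in [0, 1].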

@@ -1,13 +1,14 @@
import argparse
import datetime
import json
import os
import re
import sys
import threading
import time
import requests
import logging

from PIL import Image
import gradio as gr
import torch
import tqdm

import modules.interrogate
@@ -15,8 +16,11 @@ import modules.memmon
import modules.styles
import modules.devices as devices
from modules import localization, script_loading, errors, ui_components, shared_items, cmd_args
from modules.paths_internal import models_path, script_path, data_path, sd_configs_path, sd_default_config, sd_model_file, default_sd_model_file, extensions_dir, extensions_builtin_dir
from modules.paths_internal import models_path, script_path, data_path, sd_configs_path, sd_default_config, sd_model_file, default_sd_model_file, extensions_dir, extensions_builtin_dir  # noqa: F401
from ldm.models.diffusion.ddpm import LatentDiffusion
from typing import Optional

log = logging.getLogger(__name__)

demo = None

@@ -44,19 +48,6 @@ restricted_opts = {
    "outdir_init_images"
}

ui_reorder_categories = [
    "inpaint",
    "sampler",
    "checkboxes",
    "hires_fix",
    "dimensions",
    "cfg",
    "seed",
    "batch",
    "override_settings",
    "scripts",
]

# https://huggingface.co/datasets/freddyaboulton/gradio-theme-subdomains/resolve/main/subdomains.json
gradio_hf_hub_themes = [
    "gradio/glass",

@@ -77,6 +68,9 @@ cmd_opts.disable_extension_access = (cmd_opts.share or cmd_opts.listen or cmd_op
devices.device, devices.device_interrogate, devices.device_gfpgan, devices.device_esrgan, devices.device_codeformer = \
    (devices.cpu if any(y in cmd_opts.use_cpu for y in [x, 'all']) else devices.get_optimal_device() for x in ['sd', 'interrogate', 'gfpgan', 'esrgan', 'codeformer'])

devices.dtype = torch.float32 if cmd_opts.no_half else torch.float16
devices.dtype_vae = torch.float32 if cmd_opts.no_half or cmd_opts.no_half_vae else torch.float16

device = devices.device
weight_load_location = None if cmd_opts.lowram else "cpu"

@@ -113,14 +107,56 @@ class State:
    id_live_preview = 0
    textinfo = None
    time_start = None
    need_restart = False
    server_start = None
    _server_command_signal = threading.Event()
    _server_command: Optional[str] = None

    @property
    def need_restart(self) -> bool:
        # Compatibility getter for need_restart.
        return self.server_command == "restart"

    @need_restart.setter
    def need_restart(self, value: bool) -> None:
        # Compatibility setter for need_restart.
        if value:
            self.server_command = "restart"

    @property
    def server_command(self):
        return self._server_command

    @server_command.setter
    def server_command(self, value: Optional[str]) -> None:
        """
        Set the server command to `value` and signal that it's been set.
        """
        self._server_command = value
        self._server_command_signal.set()

    def wait_for_server_command(self, timeout: Optional[float] = None) -> Optional[str]:
        """
        Wait for server command to get set; return and clear the value and signal.
        """
        if self._server_command_signal.wait(timeout):
            self._server_command_signal.clear()
            req = self._server_command
            self._server_command = None
            return req
        return None
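    # A typical consumer (assumed to be the server main loop, which is not
    # shown in this commit):
    #     while True:
    #         server_command = shared.state.wait_for_server_command(timeout=5)
    #         if server_command in ("stop", "restart"):
    #             break
    # i.e. a UI handler sets server_command, and the main loop wakes up,
    # clears it, and acts on it.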

    def request_restart(self) -> None:
        self.interrupt()
        self.server_command = "restart"
        log.info("Received restart request")

    def skip(self):
        self.skipped = True
        log.info("Received skip request")

    def interrupt(self):
        self.interrupted = True
        log.info("Received interrupt request")

    def nextjob(self):
        if opts.live_previews_enable and opts.show_progress_every_n_steps == -1:

@@ -144,7 +180,7 @@ class State:

        return obj

    def begin(self):
    def begin(self, job: str = "(unknown)"):
        self.sampling_step = 0
        self.job_count = -1
        self.processing_has_refined_job_count = False

@@ -158,10 +194,13 @@ class State:
        self.interrupted = False
        self.textinfo = None
        self.time_start = time.time()

        self.job = job
        devices.torch_gc()
        log.info("Starting job %s", job)

    def end(self):
        duration = time.time() - self.time_start
        log.info("Ending job %s (%.2f seconds)", self.job, duration)
        self.job = ""
        self.job_count = 0

@@ -202,8 +241,9 @@ interrogator = modules.interrogate.InterrogateModels("interrogate")

face_restorers = []


class OptionInfo:
    def __init__(self, default=None, label="", component=None, component_args=None, onchange=None, section=None, refresh=None):
    def __init__(self, default=None, label="", component=None, component_args=None, onchange=None, section=None, refresh=None, comment_before='', comment_after=''):
        self.default = default
        self.label = label
        self.component = component

@@ -212,9 +252,37 @@ class OptionInfo:
        self.section = section
        self.refresh = refresh

        self.comment_before = comment_before
        """HTML text that will be added after label in UI"""

        self.comment_after = comment_after
        """HTML text that will be added before label in UI"""

    def link(self, label, url):
        self.comment_before += f"[<a href='{url}' target='_blank'>{label}</a>]"
        return self

    def js(self, label, js_func):
        self.comment_before += f"[<a onclick='{js_func}(); return false'>{label}</a>]"
        return self

    def info(self, info):
        self.comment_after += f"<span class='info'>({info})</span>"
        return self

    def html(self, html):
        self.comment_after += html
        return self

    def needs_restart(self):
        self.comment_after += " <span class='info'>(requires restart)</span>"
        return self
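    # Each helper returns self, so option declarations can chain, e.g.:
    #     OptionInfo(200, "Maximum image size", gr.Number).info("in megapixels")
    # comment_before/comment_after accumulate raw HTML rendered around the label.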



def options_section(section_identifier, options_dict):
    for k, v in options_dict.items():
    for v in options_dict.values():
        v.section = section_identifier

    return options_dict

@@ -243,7 +311,7 @@ options_templates = {}

options_templates.update(options_section(('saving-images', "Saving images/grids"), {
    "samples_save": OptionInfo(True, "Always save all generated images"),
    "samples_format": OptionInfo('png', 'File format for images'),
    "samples_filename_pattern": OptionInfo("", "Images filename pattern", component_args=hide_dirs),
    "samples_filename_pattern": OptionInfo("", "Images filename pattern", component_args=hide_dirs).link("wiki", "https://github.com/AUTOMATIC1111/stable-diffusion-webui/wiki/Custom-Images-Filename-Name-and-Subdirectory"),
    "save_images_add_number": OptionInfo(True, "Add number to filename when saving", component_args=hide_dirs),

    "grid_save": OptionInfo(True, "Always save all generated image grids"),

@@ -251,7 +319,12 @@ options_templates.update(options_section(('saving-images', "Saving images/grids"
    "grid_extended_filename": OptionInfo(False, "Add extended info (seed, prompt) to filename when saving grid"),
    "grid_only_if_multiple": OptionInfo(True, "Do not save grids consisting of one picture"),
    "grid_prevent_empty_spots": OptionInfo(False, "Prevent empty spots in grid (when set to autodetect)"),
    "grid_zip_filename_pattern": OptionInfo("", "Archive filename pattern", component_args=hide_dirs).link("wiki", "https://github.com/AUTOMATIC1111/stable-diffusion-webui/wiki/Custom-Images-Filename-Name-and-Subdirectory"),
    "n_rows": OptionInfo(-1, "Grid row count; use -1 for autodetect and 0 for it to be same as batch size", gr.Slider, {"minimum": -1, "maximum": 16, "step": 1}),
    "font": OptionInfo("", "Font for image grids that have text"),
    "grid_text_active_color": OptionInfo("#000000", "Text color for image grids", ui_components.FormColorPicker, {}),
    "grid_text_inactive_color": OptionInfo("#999999", "Inactive text color for image grids", ui_components.FormColorPicker, {}),
    "grid_background_color": OptionInfo("#ffffff", "Background color for image grids", ui_components.FormColorPicker, {}),

    "enable_pnginfo": OptionInfo(True, "Save text information about generation parameters as chunks to png files"),
    "save_txt": OptionInfo(False, "Create a text file next to every image with generation parameters."),

@@ -262,10 +335,10 @@ options_templates.update(options_section(('saving-images', "Saving images/grids"
    "save_mask_composite": OptionInfo(False, "For inpainting, save a masked composite"),
    "jpeg_quality": OptionInfo(80, "Quality for saved jpeg images", gr.Slider, {"minimum": 1, "maximum": 100, "step": 1}),
    "webp_lossless": OptionInfo(False, "Use lossless compression for webp images"),
    "export_for_4chan": OptionInfo(True, "If the saved image file size is above the limit, or its either width or height are above the limit, save a downscaled copy as JPG"),
    "export_for_4chan": OptionInfo(True, "Save copy of large images as JPG").info("if the file size is above the limit, or either width or height are above the limit"),
    "img_downscale_threshold": OptionInfo(4.0, "File size limit for the above option, MB", gr.Number),
    "target_side_length": OptionInfo(4000, "Width/height limit for the above option, in pixels", gr.Number),
    "img_max_size_mp": OptionInfo(200, "Maximum image size, in megapixels", gr.Number),
    "img_max_size_mp": OptionInfo(200, "Maximum image size", gr.Number).info("in megapixels"),

    "use_original_name_batch": OptionInfo(True, "Use original name for output filename during batch process in extras tab"),
    "use_upscaler_name_as_suffix": OptionInfo(False, "Use upscaler name as filename suffix in the extras tab"),

@@ -293,31 +366,31 @@ options_templates.update(options_section(('saving-to-dirs', "Saving to a directo
    "save_to_dirs": OptionInfo(True, "Save images to a subdirectory"),
    "grid_save_to_dirs": OptionInfo(True, "Save grids to a subdirectory"),
    "use_save_to_dirs_for_ui": OptionInfo(False, "When using \"Save\" button, save images to a subdirectory"),
    "directories_filename_pattern": OptionInfo("[date]", "Directory name pattern", component_args=hide_dirs),
    "directories_filename_pattern": OptionInfo("[date]", "Directory name pattern", component_args=hide_dirs).link("wiki", "https://github.com/AUTOMATIC1111/stable-diffusion-webui/wiki/Custom-Images-Filename-Name-and-Subdirectory"),
    "directories_max_prompt_words": OptionInfo(8, "Max prompt words for [prompt_words] pattern", gr.Slider, {"minimum": 1, "maximum": 20, "step": 1, **hide_dirs}),
}))

options_templates.update(options_section(('upscaling', "Upscaling"), {
    "ESRGAN_tile": OptionInfo(192, "Tile size for ESRGAN upscalers. 0 = no tiling.", gr.Slider, {"minimum": 0, "maximum": 512, "step": 16}),
    "ESRGAN_tile_overlap": OptionInfo(8, "Tile overlap, in pixels for ESRGAN upscalers. Low values = visible seam.", gr.Slider, {"minimum": 0, "maximum": 48, "step": 1}),
    "realesrgan_enabled_models": OptionInfo(["R-ESRGAN 4x+", "R-ESRGAN 4x+ Anime6B"], "Select which Real-ESRGAN models to show in the web UI. (Requires restart)", gr.CheckboxGroup, lambda: {"choices": shared_items.realesrgan_models_names()}),
    "ESRGAN_tile": OptionInfo(192, "Tile size for ESRGAN upscalers.", gr.Slider, {"minimum": 0, "maximum": 512, "step": 16}).info("0 = no tiling"),
    "ESRGAN_tile_overlap": OptionInfo(8, "Tile overlap for ESRGAN upscalers.", gr.Slider, {"minimum": 0, "maximum": 48, "step": 1}).info("Low values = visible seam"),
    "realesrgan_enabled_models": OptionInfo(["R-ESRGAN 4x+", "R-ESRGAN 4x+ Anime6B"], "Select which Real-ESRGAN models to show in the web UI.", gr.CheckboxGroup, lambda: {"choices": shared_items.realesrgan_models_names()}),
    "upscaler_for_img2img": OptionInfo(None, "Upscaler for img2img", gr.Dropdown, lambda: {"choices": [x.name for x in sd_upscalers]}),
    "SCUNET_tile": OptionInfo(256, "Tile size for SCUNET upscalers. 0 = no tiling.", gr.Slider, {"minimum": 0, "maximum": 512, "step": 16}),
    "SCUNET_tile_overlap": OptionInfo(8, "Tile overlap, in pixels for SCUNET upscalers. Low values = visible seam.", gr.Slider, {"minimum": 0, "maximum": 64, "step": 1}),
}))

options_templates.update(options_section(('face-restoration', "Face restoration"), {
    "face_restoration_model": OptionInfo("CodeFormer", "Face restoration model", gr.Radio, lambda: {"choices": [x.name() for x in face_restorers]}),
    "code_former_weight": OptionInfo(0.5, "CodeFormer weight parameter; 0 = maximum effect; 1 = minimum effect", gr.Slider, {"minimum": 0, "maximum": 1, "step": 0.01}),
    "code_former_weight": OptionInfo(0.5, "CodeFormer weight", gr.Slider, {"minimum": 0, "maximum": 1, "step": 0.01}).info("0 = maximum effect; 1 = minimum effect"),
    "face_restoration_unload": OptionInfo(False, "Move face restoration model from VRAM into RAM after processing"),
}))

options_templates.update(options_section(('system', "System"), {
    "show_warnings": OptionInfo(False, "Show warnings in console."),
    "memmon_poll_rate": OptionInfo(8, "VRAM usage polls per second during generation. Set to 0 to disable.", gr.Slider, {"minimum": 0, "maximum": 40, "step": 1}),
    "memmon_poll_rate": OptionInfo(8, "VRAM usage polls per second during generation.", gr.Slider, {"minimum": 0, "maximum": 40, "step": 1}).info("0 = disable"),
    "samples_log_stdout": OptionInfo(False, "Always print all generation info to standard output"),
    "multiple_tqdm": OptionInfo(True, "Add a second progress bar to the console that shows progress for an entire job."),
    "print_hypernet_extra": OptionInfo(False, "Print extra hypernetwork information to console."),
    "list_hidden_files": OptionInfo(True, "Load models/files in hidden directories").info("directory is hidden if its name starts with \".\""),
    "disable_mmap_load_safetensors": OptionInfo(False, "Disable memmapping for loading .safetensors files.").info("fixes very slow loading speed in some cases"),
}))

options_templates.update(options_section(('training', "Training"), {

@@ -339,20 +412,31 @@ options_templates.update(options_section(('sd', "Stable Diffusion"), {
    "sd_model_checkpoint": OptionInfo(None, "Stable Diffusion checkpoint", gr.Dropdown, lambda: {"choices": list_checkpoint_tiles()}, refresh=refresh_checkpoints),
    "sd_checkpoint_cache": OptionInfo(0, "Checkpoints to cache in RAM", gr.Slider, {"minimum": 0, "maximum": 10, "step": 1}),
    "sd_vae_checkpoint_cache": OptionInfo(0, "VAE Checkpoints to cache in RAM", gr.Slider, {"minimum": 0, "maximum": 10, "step": 1}),
    "sd_vae": OptionInfo("Automatic", "SD VAE", gr.Dropdown, lambda: {"choices": shared_items.sd_vae_items()}, refresh=shared_items.refresh_vae_list),
    "sd_vae": OptionInfo("Automatic", "SD VAE", gr.Dropdown, lambda: {"choices": shared_items.sd_vae_items()}, refresh=shared_items.refresh_vae_list).info("choose VAE model: Automatic = use one with same filename as checkpoint; None = use VAE from checkpoint"),
    "sd_vae_as_default": OptionInfo(True, "Ignore selected VAE for stable diffusion checkpoints that have their own .vae.pt next to them"),
    "sd_unet": OptionInfo("Automatic", "SD Unet", gr.Dropdown, lambda: {"choices": shared_items.sd_unet_items()}, refresh=shared_items.refresh_unet_list).info("choose Unet model: Automatic = use one with same filename as checkpoint; None = use Unet from checkpoint"),
    "inpainting_mask_weight": OptionInfo(1.0, "Inpainting conditioning mask strength", gr.Slider, {"minimum": 0.0, "maximum": 1.0, "step": 0.01}),
    "initial_noise_multiplier": OptionInfo(1.0, "Noise multiplier for img2img", gr.Slider, {"minimum": 0.5, "maximum": 1.5, "step": 0.01}),
    "img2img_color_correction": OptionInfo(False, "Apply color correction to img2img results to match original colors."),
    "img2img_fix_steps": OptionInfo(False, "With img2img, do exactly the amount of steps the slider specifies (normally you'd do less with less denoising)."),
    "img2img_fix_steps": OptionInfo(False, "With img2img, do exactly the amount of steps the slider specifies.").info("normally you'd do less with less denoising"),
    "img2img_background_color": OptionInfo("#ffffff", "With img2img, fill image's transparent parts with this color.", ui_components.FormColorPicker, {}),
    "enable_quantization": OptionInfo(False, "Enable quantization in K samplers for sharper and cleaner results. This may change existing seeds. Requires restart to apply."),
    "enable_emphasis": OptionInfo(True, "Emphasis: use (text) to make model pay more attention to text and [text] to make it pay less attention"),
    "enable_emphasis": OptionInfo(True, "Enable emphasis").info("use (text) to make model pay more attention to text and [text] to make it pay less attention"),
    "enable_batch_seeds": OptionInfo(True, "Make K-diffusion samplers produce same images in a batch as when making a single image"),
    "comma_padding_backtrack": OptionInfo(20, "Increase coherency by padding from the last comma within n tokens when using more than 75 tokens", gr.Slider, {"minimum": 0, "maximum": 74, "step": 1 }),
    "CLIP_stop_at_last_layers": OptionInfo(1, "Clip skip", gr.Slider, {"minimum": 1, "maximum": 12, "step": 1}),
    "comma_padding_backtrack": OptionInfo(20, "Prompt word wrap length limit", gr.Slider, {"minimum": 0, "maximum": 74, "step": 1}).info("in tokens - for texts shorter than specified, if they don't fit into 75 token limit, move them to the next 75 token chunk"),
    "CLIP_stop_at_last_layers": OptionInfo(1, "Clip skip", gr.Slider, {"minimum": 1, "maximum": 12, "step": 1}).link("wiki", "https://github.com/AUTOMATIC1111/stable-diffusion-webui/wiki/Features#clip-skip").info("ignore last layers of CLIP network; 1 ignores none, 2 ignores one layer"),
    "upcast_attn": OptionInfo(False, "Upcast cross attention layer to float32"),
    "randn_source": OptionInfo("GPU", "Random number generator source. Changes seeds drastically. Use CPU to produce the same picture across different vidocard vendors.", gr.Radio, {"choices": ["GPU", "CPU"]}),
    "randn_source": OptionInfo("GPU", "Random number generator source.", gr.Radio, {"choices": ["GPU", "CPU"]}).info("changes seeds drastically; use CPU to produce the same picture across different videocard vendors"),
}))

options_templates.update(options_section(('optimizations', "Optimizations"), {
    "cross_attention_optimization": OptionInfo("Automatic", "Cross attention optimization", gr.Dropdown, lambda: {"choices": shared_items.cross_attention_optimizations()}),
    "s_min_uncond": OptionInfo(0.0, "Negative Guidance minimum sigma", gr.Slider, {"minimum": 0.0, "maximum": 4.0, "step": 0.01}).link("PR", "https://github.com/AUTOMATIC1111/stable-diffusion-webui/pull/9177").info("skip negative prompt for some steps when the image is almost ready; 0=disable, higher=faster"),
    "token_merging_ratio": OptionInfo(0.0, "Token merging ratio", gr.Slider, {"minimum": 0.0, "maximum": 0.9, "step": 0.1}).link("PR", "https://github.com/AUTOMATIC1111/stable-diffusion-webui/pull/9256").info("0=disable, higher=faster"),
    "token_merging_ratio_img2img": OptionInfo(0.0, "Token merging ratio for img2img", gr.Slider, {"minimum": 0.0, "maximum": 0.9, "step": 0.1}).info("only applies if non-zero and overrides above"),
    "token_merging_ratio_hr": OptionInfo(0.0, "Token merging ratio for high-res pass", gr.Slider, {"minimum": 0.0, "maximum": 0.9, "step": 0.1}).info("only applies if non-zero and overrides above"),
    "pad_cond_uncond": OptionInfo(False, "Pad prompt/negative prompt to be same length").info("improves performance when prompt and negative prompt have different lengths; changes seeds"),
    "experimental_persistent_cond_cache": OptionInfo(False, "persistent cond cache").info("Experimental, keep cond caches across jobs, reduce overhead."),
}))

options_templates.update(options_section(('compatibility', "Compatibility"), {

@@ -361,89 +445,109 @@ options_templates.update(options_section(('compatibility', "Compatibility"), {
    "no_dpmpp_sde_batch_determinism": OptionInfo(False, "Do not make DPM++ SDE deterministic across different batch sizes."),
    "use_old_hires_fix_width_height": OptionInfo(False, "For hires fix, use width/height sliders to set final resolution rather than first pass (disables Upscale by, Resize width/height to)."),
    "dont_fix_second_order_samplers_schedule": OptionInfo(False, "Do not fix prompt schedule for second order samplers."),
    "hires_fix_use_firstpass_conds": OptionInfo(False, "For hires fix, calculate conds of second pass using extra networks of first pass."),
}))

options_templates.update(options_section(('interrogate', "Interrogate Options"), {
    "interrogate_keep_models_in_memory": OptionInfo(False, "Interrogate: keep models in VRAM"),
    "interrogate_return_ranks": OptionInfo(False, "Interrogate: include ranks of model tags matches in results (Has no effect on caption-based interrogators)."),
    "interrogate_clip_num_beams": OptionInfo(1, "Interrogate: num_beams for BLIP", gr.Slider, {"minimum": 1, "maximum": 16, "step": 1}),
    "interrogate_clip_min_length": OptionInfo(24, "Interrogate: minimum description length (excluding artists, etc..)", gr.Slider, {"minimum": 1, "maximum": 128, "step": 1}),
    "interrogate_clip_max_length": OptionInfo(48, "Interrogate: maximum description length", gr.Slider, {"minimum": 1, "maximum": 256, "step": 1}),
    "interrogate_clip_dict_limit": OptionInfo(1500, "CLIP: maximum number of lines in text file (0 = No limit)"),
    "interrogate_keep_models_in_memory": OptionInfo(False, "Keep models in VRAM"),
    "interrogate_return_ranks": OptionInfo(False, "Include ranks of model tags matches in results.").info("booru only"),
    "interrogate_clip_num_beams": OptionInfo(1, "BLIP: num_beams", gr.Slider, {"minimum": 1, "maximum": 16, "step": 1}),
    "interrogate_clip_min_length": OptionInfo(24, "BLIP: minimum description length", gr.Slider, {"minimum": 1, "maximum": 128, "step": 1}),
    "interrogate_clip_max_length": OptionInfo(48, "BLIP: maximum description length", gr.Slider, {"minimum": 1, "maximum": 256, "step": 1}),
    "interrogate_clip_dict_limit": OptionInfo(1500, "CLIP: maximum number of lines in text file").info("0 = No limit"),
    "interrogate_clip_skip_categories": OptionInfo([], "CLIP: skip inquire categories", gr.CheckboxGroup, lambda: {"choices": modules.interrogate.category_types()}, refresh=modules.interrogate.category_types),
    "interrogate_deepbooru_score_threshold": OptionInfo(0.5, "Interrogate: deepbooru score threshold", gr.Slider, {"minimum": 0, "maximum": 1, "step": 0.01}),
    "deepbooru_sort_alpha": OptionInfo(True, "Interrogate: deepbooru sort alphabetically"),
    "deepbooru_use_spaces": OptionInfo(False, "use spaces for tags in deepbooru"),
    "deepbooru_escape": OptionInfo(True, "escape (\\) brackets in deepbooru (so they are used as literal brackets and not for emphasis)"),
    "deepbooru_filter_tags": OptionInfo("", "filter out those tags from deepbooru output (separated by comma)"),
    "interrogate_deepbooru_score_threshold": OptionInfo(0.5, "deepbooru: score threshold", gr.Slider, {"minimum": 0, "maximum": 1, "step": 0.01}),
    "deepbooru_sort_alpha": OptionInfo(True, "deepbooru: sort tags alphabetically").info("if not: sort by score"),
    "deepbooru_use_spaces": OptionInfo(True, "deepbooru: use spaces in tags").info("if not: use underscores"),
    "deepbooru_escape": OptionInfo(True, "deepbooru: escape (\\) brackets").info("so they are used as literal brackets and not for emphasis"),
    "deepbooru_filter_tags": OptionInfo("", "deepbooru: filter out those tags").info("separate by comma"),
}))

options_templates.update(options_section(('extra_networks', "Extra Networks"), {
    "extra_networks_show_hidden_directories": OptionInfo(True, "Show hidden directories").info("directory is hidden if its name starts with \".\"."),
    "extra_networks_hidden_models": OptionInfo("When searched", "Show cards for models in hidden directories", gr.Radio, {"choices": ["Always", "When searched", "Never"]}).info('"When searched" option will only show the item when the search string has 4 characters or more'),
    "extra_networks_default_view": OptionInfo("cards", "Default view for Extra Networks", gr.Dropdown, {"choices": ["cards", "thumbs"]}),
    "extra_networks_default_multiplier": OptionInfo(1.0, "Multiplier for extra networks", gr.Slider, {"minimum": 0.0, "maximum": 1.0, "step": 0.01}),
    "extra_networks_card_width": OptionInfo(0, "Card width for Extra Networks (px)"),
    "extra_networks_card_height": OptionInfo(0, "Card height for Extra Networks (px)"),
    "extra_networks_add_text_separator": OptionInfo(" ", "Extra text to add before <...> when adding extra network to prompt"),
    "sd_hypernetwork": OptionInfo("None", "Add hypernetwork to prompt", gr.Dropdown, lambda: {"choices": ["None"] + [x for x in hypernetworks.keys()]}, refresh=reload_hypernetworks),
    "extra_networks_card_width": OptionInfo(0, "Card width for Extra Networks").info("in pixels"),
    "extra_networks_card_height": OptionInfo(0, "Card height for Extra Networks").info("in pixels"),
    "extra_networks_add_text_separator": OptionInfo(" ", "Extra networks separator").info("extra text to add before <...> when adding extra network to prompt"),
    "ui_extra_networks_tab_reorder": OptionInfo("", "Extra networks tab order").needs_restart(),
    "sd_hypernetwork": OptionInfo("None", "Add hypernetwork to prompt", gr.Dropdown, lambda: {"choices": ["None", *hypernetworks]}, refresh=reload_hypernetworks),
}))

options_templates.update(options_section(('ui', "User interface"), {
    "localization": OptionInfo("None", "Localization", gr.Dropdown, lambda: {"choices": ["None"] + list(localization.localizations.keys())}, refresh=lambda: localization.list_localizations(cmd_opts.localizations_dir)).needs_restart(),
    "gradio_theme": OptionInfo("Default", "Gradio theme", ui_components.DropdownEditable, lambda: {"choices": ["Default"] + gradio_hf_hub_themes}).needs_restart(),
    "img2img_editor_height": OptionInfo(720, "img2img: height of image editor", gr.Slider, {"minimum": 80, "maximum": 1600, "step": 1}).info("in pixels").needs_restart(),
    "return_grid": OptionInfo(True, "Show grid in results for web"),
    "return_mask": OptionInfo(False, "For inpainting, include the greyscale mask in results for web"),
    "return_mask_composite": OptionInfo(False, "For inpainting, include masked composite in results for web"),
    "do_not_show_images": OptionInfo(False, "Do not show any images in results for web"),
    "send_seed": OptionInfo(True, "Send seed when sending prompt or image to other interface"),
    "send_size": OptionInfo(True, "Send size when sending prompt or image to another interface"),
    "font": OptionInfo("", "Font for image grids that have text"),
    "js_modal_lightbox": OptionInfo(True, "Enable full page image viewer"),
    "js_modal_lightbox_initially_zoomed": OptionInfo(True, "Show images zoomed in by default in full page image viewer"),
    "js_modal_lightbox_gamepad": OptionInfo(True, "Navigate image viewer with gamepad"),
    "js_modal_lightbox_gamepad": OptionInfo(False, "Navigate image viewer with gamepad"),
    "js_modal_lightbox_gamepad_repeat": OptionInfo(250, "Gamepad repeat period, in milliseconds"),
    "show_progress_in_title": OptionInfo(True, "Show generation progress in window title."),
    "samplers_in_dropdown": OptionInfo(True, "Use dropdown for sampler selection instead of radio group"),
    "dimensions_and_batch_together": OptionInfo(True, "Show Width/Height and Batch sliders in same row"),
    "samplers_in_dropdown": OptionInfo(True, "Use dropdown for sampler selection instead of radio group").needs_restart(),
    "dimensions_and_batch_together": OptionInfo(True, "Show Width/Height and Batch sliders in same row").needs_restart(),
    "keyedit_precision_attention": OptionInfo(0.1, "Ctrl+up/down precision when editing (attention:1.1)", gr.Slider, {"minimum": 0.01, "maximum": 0.2, "step": 0.001}),
    "keyedit_precision_extra": OptionInfo(0.05, "Ctrl+up/down precision when editing <extra networks:0.9>", gr.Slider, {"minimum": 0.01, "maximum": 0.2, "step": 0.001}),
    "keyedit_delimiters": OptionInfo(".,\\/!?%^*;:{}=`~()", "Ctrl+up/down word delimiters"),
    "quicksettings_list": OptionInfo(["sd_model_checkpoint"], "Quicksettings list", ui_components.DropdownMulti, lambda: {"choices": list(opts.data_labels.keys())}),
    "hidden_tabs": OptionInfo([], "Hidden UI tabs (requires restart)", ui_components.DropdownMulti, lambda: {"choices": [x for x in tab_names]}),
    "ui_reorder": OptionInfo(", ".join(ui_reorder_categories), "txt2img/img2img UI item order"),
    "ui_extra_networks_tab_reorder": OptionInfo("", "Extra networks tab order"),
    "localization": OptionInfo("None", "Localization (requires restart)", gr.Dropdown, lambda: {"choices": ["None"] + list(localization.localizations.keys())}, refresh=lambda: localization.list_localizations(cmd_opts.localizations_dir)),
    "gradio_theme": OptionInfo("Default", "Gradio theme (requires restart)", ui_components.DropdownEditable, lambda: {"choices": ["Default"] + gradio_hf_hub_themes})
    "keyedit_move": OptionInfo(True, "Alt+left/right moves prompt elements"),
    "quicksettings_list": OptionInfo(["sd_model_checkpoint"], "Quicksettings list", ui_components.DropdownMulti, lambda: {"choices": list(opts.data_labels.keys())}).js("info", "settingsHintsShowQuicksettings").info("setting entries that appear at the top of page rather than in settings tab").needs_restart(),
    "ui_tab_order": OptionInfo([], "UI tab order", ui_components.DropdownMulti, lambda: {"choices": list(tab_names)}).needs_restart(),
    "hidden_tabs": OptionInfo([], "Hidden UI tabs", ui_components.DropdownMulti, lambda: {"choices": list(tab_names)}).needs_restart(),
    "ui_reorder_list": OptionInfo([], "txt2img/img2img UI item order", ui_components.DropdownMulti, lambda: {"choices": list(shared_items.ui_reorder_categories())}).info("selected items appear first").needs_restart(),
    "hires_fix_show_sampler": OptionInfo(False, "Hires fix: show hires sampler selection").needs_restart(),
    "hires_fix_show_prompts": OptionInfo(False, "Hires fix: show hires prompt and negative prompt").needs_restart(),
    "disable_token_counters": OptionInfo(False, "Disable prompt token counters").needs_restart(),
}))

options_templates.update(options_section(('infotext', "Infotext"), {
    "add_model_hash_to_info": OptionInfo(True, "Add model hash to generation information"),
    "add_model_name_to_info": OptionInfo(True, "Add model name to generation information"),
    "add_user_name_to_info": OptionInfo(False, "Add user name to generation information when authenticated"),
    "add_version_to_infotext": OptionInfo(True, "Add program version to generation information"),
    "disable_weights_auto_swap": OptionInfo(True, "When reading generation parameters from text into UI (from PNG info or pasted text), do not change the selected model/checkpoint."),
    "disable_weights_auto_swap": OptionInfo(True, "Disregard checkpoint information from pasted infotext").info("when reading generation parameters from text into UI"),
    "infotext_styles": OptionInfo("Apply if any", "Infer styles from prompts of pasted infotext", gr.Radio, {"choices": ["Ignore", "Apply", "Discard", "Apply if any"]}).info("when reading generation parameters from text into UI").html("""<ul style='margin-left: 1.5em'>
<li>Ignore: keep prompt and styles dropdown as it is.</li>
<li>Apply: remove style text from prompt, always replace styles dropdown value with found styles (even if none are found).</li>
<li>Discard: remove style text from prompt, keep styles dropdown as it is.</li>
<li>Apply if any: remove style text from prompt; if any styles are found in prompt, put them into styles dropdown, otherwise keep it as it is.</li>
</ul>"""),
}))

options_templates.update(options_section(('ui', "Live previews"), {
    "show_progressbar": OptionInfo(True, "Show progressbar"),
    "live_previews_enable": OptionInfo(True, "Show live previews of the created image"),
    "live_previews_image_format": OptionInfo("png", "Live preview file format", gr.Radio, {"choices": ["jpeg", "png", "webp"]}),
    "show_progress_grid": OptionInfo(True, "Show previews of all images generated in a batch as a grid"),
    "show_progress_every_n_steps": OptionInfo(10, "Show new live preview image every N sampling steps. Set to -1 to show after completion of batch.", gr.Slider, {"minimum": -1, "maximum": 32, "step": 1}),
    "show_progress_type": OptionInfo("Approx NN", "Image creation progress preview mode", gr.Radio, {"choices": ["Full", "Approx NN", "Approx cheap"]}),
    "show_progress_every_n_steps": OptionInfo(10, "Live preview display period", gr.Slider, {"minimum": -1, "maximum": 32, "step": 1}).info("in sampling steps - show new live preview image every N sampling steps; -1 = only show after completion of batch"),
    "show_progress_type": OptionInfo("Approx NN", "Live preview method", gr.Radio, {"choices": ["Full", "Approx NN", "Approx cheap", "TAESD"]}).info("Full = slow but pretty; Approx NN and TAESD = fast but low quality; Approx cheap = super fast but terrible otherwise"),
    "live_preview_content": OptionInfo("Prompt", "Live preview subject", gr.Radio, {"choices": ["Combined", "Prompt", "Negative prompt"]}),
    "live_preview_refresh_period": OptionInfo(1000, "Progressbar/preview update period, in milliseconds")
    "live_preview_refresh_period": OptionInfo(1000, "Progressbar and preview update period").info("in milliseconds"),
}))

options_templates.update(options_section(('sampler-params', "Sampler parameters"), {
    "hide_samplers": OptionInfo([], "Hide samplers in user interface (requires restart)", gr.CheckboxGroup, lambda: {"choices": [x.name for x in list_samplers()]}),
    "eta_ddim": OptionInfo(0.0, "eta (noise multiplier) for DDIM", gr.Slider, {"minimum": 0.0, "maximum": 1.0, "step": 0.01}),
    "eta_ancestral": OptionInfo(1.0, "eta (noise multiplier) for ancestral samplers", gr.Slider, {"minimum": 0.0, "maximum": 1.0, "step": 0.01}),
    "hide_samplers": OptionInfo([], "Hide samplers in user interface", gr.CheckboxGroup, lambda: {"choices": [x.name for x in list_samplers()]}).needs_restart(),
    "eta_ddim": OptionInfo(0.0, "Eta for DDIM", gr.Slider, {"minimum": 0.0, "maximum": 1.0, "step": 0.01}).info("noise multiplier; higher = more unpredictable results"),
    "eta_ancestral": OptionInfo(1.0, "Eta for ancestral samplers", gr.Slider, {"minimum": 0.0, "maximum": 1.0, "step": 0.01}).info("noise multiplier; applies to Euler a and other samplers that have a in them"),
    "ddim_discretize": OptionInfo('uniform', "img2img DDIM discretize", gr.Radio, {"choices": ['uniform', 'quad']}),
    's_churn': OptionInfo(0.0, "sigma churn", gr.Slider, {"minimum": 0.0, "maximum": 1.0, "step": 0.01}),
    's_min_uncond': OptionInfo(0, "Negative Guidance minimum sigma", gr.Slider, {"minimum": 0.0, "maximum": 4.0, "step": 0.01}),
    's_tmin': OptionInfo(0.0, "sigma tmin", gr.Slider, {"minimum": 0.0, "maximum": 1.0, "step": 0.01}),
    's_noise': OptionInfo(1.0, "sigma noise", gr.Slider, {"minimum": 0.0, "maximum": 1.0, "step": 0.01}),
    'eta_noise_seed_delta': OptionInfo(0, "Eta noise seed delta", gr.Number, {"precision": 0}),
    'always_discard_next_to_last_sigma': OptionInfo(False, "Always discard next-to-last sigma"),
    'k_sched_type': OptionInfo("Automatic", "scheduler type", gr.Dropdown, {"choices": ["Automatic", "karras", "exponential", "polyexponential"]}).info("lets you override the noise schedule for k-diffusion samplers; choosing Automatic disables the three parameters below"),
    'sigma_min': OptionInfo(0.0, "sigma min", gr.Number).info("0 = default (~0.03); minimum noise strength for k-diffusion noise scheduler"),
    'sigma_max': OptionInfo(0.0, "sigma max", gr.Number).info("0 = default (~14.6); maximum noise strength for k-diffusion noise schedule"),
    'rho': OptionInfo(0.0, "rho", gr.Number).info("0 = default (7 for karras, 1 for polyexponential); higher values result in a more steep noise schedule (decreases faster)"),
|
||||
'eta_noise_seed_delta': OptionInfo(0, "Eta noise seed delta", gr.Number, {"precision": 0}).info("ENSD; does not improve anything, just produces different results for ancestral samplers - only useful for reproducing images"),
|
||||
'always_discard_next_to_last_sigma': OptionInfo(False, "Always discard next-to-last sigma").link("PR", "https://github.com/AUTOMATIC1111/stable-diffusion-webui/pull/6044"),
|
||||
'uni_pc_variant': OptionInfo("bh1", "UniPC variant", gr.Radio, {"choices": ["bh1", "bh2", "vary_coeff"]}),
|
||||
'uni_pc_skip_type': OptionInfo("time_uniform", "UniPC skip type", gr.Radio, {"choices": ["time_uniform", "time_quadratic", "logSNR"]}),
|
||||
'uni_pc_order': OptionInfo(3, "UniPC order (must be < sampling steps)", gr.Slider, {"minimum": 1, "maximum": 50, "step": 1}),
|
||||
'uni_pc_order': OptionInfo(3, "UniPC order", gr.Slider, {"minimum": 1, "maximum": 50, "step": 1}).info("must be < sampling steps"),
|
||||
'uni_pc_lower_order_final': OptionInfo(True, "UniPC lower order final"),
|
||||
}))
|
||||
|
||||
@@ -460,6 +564,7 @@ options_templates.update(options_section((None, "Hidden options"), {
|
||||
"sd_checkpoint_hash": OptionInfo("", "SHA256 hash of the current checkpoint"),
|
||||
}))
|
||||
|
||||
|
||||
options_templates.update()
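
A note on the pattern used throughout these sections: the annotation helpers (.info(), .html(), .link(), .js(), .needs_restart()) each return the OptionInfo instance, so they chain. A hedged sketch of registering a custom section in the same style (the section key and option names here are made up for illustration):

    options_templates.update(options_section(('my-section', "My Section"), {
        "my_flag": OptionInfo(False, "Enable my feature").info("takes effect on next generation"),
        "my_level": OptionInfo(1, "Detail level", gr.Slider, {"minimum": 0, "maximum": 10, "step": 1}).needs_restart(),
    }))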

@@ -553,6 +658,10 @@ class Options:
if self.data.get('quicksettings') is not None and self.data.get('quicksettings_list') is None:
self.data['quicksettings_list'] = [i.strip() for i in self.data.get('quicksettings').split(',')]

# 1.4.0 ui_reorder
if isinstance(self.data.get('ui_reorder'), str) and self.data.get('ui_reorder') and "ui_reorder_list" not in self.data:
self.data['ui_reorder_list'] = [i.strip() for i in self.data.get('ui_reorder').split(',')]

bad_settings = 0
for k, v in self.data.items():
info = self.data_labels.get(k, None)
@@ -571,7 +680,9 @@ class Options:
func()

def dumpjson(self):
d = {k: self.data.get(k, self.data_labels.get(k).default) for k in self.data_labels.keys()}
d = {k: self.data.get(k, v.default) for k, v in self.data_labels.items()}
d["_comments_before"] = {k: v.comment_before for k, v in self.data_labels.items() if v.comment_before is not None}
d["_comments_after"] = {k: v.comment_after for k, v in self.data_labels.items() if v.comment_after is not None}
return json.dumps(d)

def add_option(self, key, info):
@@ -582,11 +693,11 @@ class Options:

section_ids = {}
settings_items = self.data_labels.items()
for k, item in settings_items:
for _, item in settings_items:
if item.section not in section_ids:
section_ids[item.section] = len(section_ids)

self.data_labels = {k: v for k, v in sorted(settings_items, key=lambda x: section_ids[x[1].section])}
self.data_labels = dict(sorted(settings_items, key=lambda x: section_ids[x[1].section]))

def cast_value(self, key, value):
"""casts an arbitrary value to the same type as this setting's value with key
@@ -722,8 +833,12 @@ mem_mon = modules.memmon.MemUsageMonitor("MemMon", device, opts)
mem_mon.start()


def natural_sort_key(s, regex=re.compile('([0-9]+)')):
return [int(text) if text.isdigit() else text.lower() for text in regex.split(s)]
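
A quick illustration of what natural_sort_key buys over the old str.lower sort (standalone Python, matching the function above):

    import re

    def natural_sort_key(s, regex=re.compile('([0-9]+)')):
        return [int(text) if text.isdigit() else text.lower() for text in regex.split(s)]

    names = ["img10.png", "img2.png", "IMG1.png"]
    print(sorted(names, key=str.lower))         # ['IMG1.png', 'img10.png', 'img2.png']
    print(sorted(names, key=natural_sort_key))  # ['IMG1.png', 'img2.png', 'img10.png']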


def listfiles(dirname):
filenames = [os.path.join(dirname, x) for x in sorted(os.listdir(dirname), key=str.lower) if not x.startswith(".")]
filenames = [os.path.join(dirname, x) for x in sorted(os.listdir(dirname), key=natural_sort_key) if not x.startswith(".")]
return [file for file in filenames if os.path.isfile(file)]


@@ -748,11 +863,17 @@ def walk_files(path, allowed_extensions=None):
if allowed_extensions is not None:
allowed_extensions = set(allowed_extensions)

for root, dirs, files in os.walk(path):
for filename in files:
items = list(os.walk(path, followlinks=True))
items = sorted(items, key=lambda x: natural_sort_key(x[0]))

for root, _, files in items:
for filename in sorted(files, key=natural_sort_key):
if allowed_extensions is not None:
_, ext = os.path.splitext(filename)
if ext not in allowed_extensions:
continue

if not opts.list_hidden_files and ("/." in root or "\\." in root):
continue

yield os.path.join(root, filename)

@@ -21,3 +21,49 @@ def refresh_vae_list():
import modules.sd_vae

modules.sd_vae.refresh_vae_list()


def cross_attention_optimizations():
import modules.sd_hijack

return ["Automatic"] + [x.title() for x in modules.sd_hijack.optimizers] + ["None"]


def sd_unet_items():
import modules.sd_unet

return ["Automatic"] + [x.label for x in modules.sd_unet.unet_options] + ["None"]


def refresh_unet_list():
import modules.sd_unet

modules.sd_unet.list_unets()


ui_reorder_categories_builtin_items = [
"inpaint",
"sampler",
"checkboxes",
"hires_fix",
"dimensions",
"cfg",
"seed",
"batch",
"override_settings",
]


def ui_reorder_categories():
from modules import scripts

yield from ui_reorder_categories_builtin_items

sections = {}
for script in scripts.scripts_txt2img.scripts + scripts.scripts_img2img.scripts:
if isinstance(script.section, str):
sections[script.section] = 1

yield from sections

yield "scripts"

@@ -1,18 +1,10 @@
# We need this so Python doesn't complain about the unknown StableDiffusionProcessing-typehint at runtime
from __future__ import annotations

import csv
import os
import os.path
import re
import typing
import collections.abc as abc
import tempfile
import shutil

if typing.TYPE_CHECKING:
# Only import this when code is being type-checked, it doesn't have any effect at runtime
from .processing import StableDiffusionProcessing


class PromptStyle(typing.NamedTuple):
name: str
@@ -37,6 +29,44 @@ def apply_styles_to_prompt(prompt, styles):
return prompt


re_spaces = re.compile(" +")


def extract_style_text_from_prompt(style_text, prompt):
stripped_prompt = re.sub(re_spaces, " ", prompt.strip())
stripped_style_text = re.sub(re_spaces, " ", style_text.strip())
if "{prompt}" in stripped_style_text:
left, right = stripped_style_text.split("{prompt}", 2)
if stripped_prompt.startswith(left) and stripped_prompt.endswith(right):
prompt = stripped_prompt[len(left):len(stripped_prompt)-len(right)]
return True, prompt
else:
if stripped_prompt.endswith(stripped_style_text):
prompt = stripped_prompt[:len(stripped_prompt)-len(stripped_style_text)]

if prompt.endswith(', '):
prompt = prompt[:-2]

return True, prompt

return False, prompt


def extract_style_from_prompts(style: PromptStyle, prompt, negative_prompt):
if not style.prompt and not style.negative_prompt:
return False, prompt, negative_prompt

match_positive, extracted_positive = extract_style_text_from_prompt(style.prompt, prompt)
if not match_positive:
return False, prompt, negative_prompt

match_negative, extracted_negative = extract_style_text_from_prompt(style.negative_prompt, negative_prompt)
if not match_negative:
return False, prompt, negative_prompt

return True, extracted_positive, extracted_negative


class StyleDatabase:
def __init__(self, path: str):
self.no_style = PromptStyle("None", "", "")
@@ -52,7 +82,7 @@ class StyleDatabase:
return

with open(self.path, "r", encoding="utf-8-sig", newline='') as file:
reader = csv.DictReader(file)
reader = csv.DictReader(file, skipinitialspace=True)
for row in reader:
# Support loading old CSV format with "name, text"-columns
prompt = row["prompt"] if "prompt" in row else row["text"]
@@ -76,10 +106,34 @@ class StyleDatabase:
if os.path.exists(path):
shutil.copy(path, f"{path}.bak")

fd = os.open(path, os.O_RDWR|os.O_CREAT)
fd = os.open(path, os.O_RDWR | os.O_CREAT)
with os.fdopen(fd, "w", encoding="utf-8-sig", newline='') as file:
# _fields is actually part of the public API: typing.NamedTuple is a replacement for collections.namedtuple,
# and collections.namedtuple has explicit documentation for accessing _fields. Same goes for _asdict()
writer = csv.DictWriter(file, fieldnames=PromptStyle._fields)
writer.writeheader()
writer.writerows(style._asdict() for k, style in self.styles.items())
writer.writerows(style._asdict() for k, style in self.styles.items())

def extract_styles_from_prompt(self, prompt, negative_prompt):
extracted = []

applicable_styles = list(self.styles.values())

while True:
found_style = None

for style in applicable_styles:
is_match, new_prompt, new_neg_prompt = extract_style_from_prompts(style, prompt, negative_prompt)
if is_match:
found_style = style
prompt = new_prompt
negative_prompt = new_neg_prompt
break

if not found_style:
break

applicable_styles.remove(found_style)
extracted.append(found_style.name)

return list(reversed(extracted)), prompt, negative_prompt
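
To make the matching rules above concrete, this is roughly how extract_style_text_from_prompt behaves on the two style shapes it supports (illustrative calls, assuming the function as defined above):

    # Style without "{prompt}": matched as a suffix, then stripped.
    extract_style_text_from_prompt("masterpiece, best quality", "a cat, masterpiece, best quality")
    # -> (True, 'a cat')   (the trailing ", " is also removed)

    # Style with "{prompt}": matched as a prefix/suffix pair around the user prompt.
    extract_style_text_from_prompt("photo of {prompt}, 35mm", "photo of a cat, 35mm")
    # -> (True, 'a cat')

    # No match leaves the prompt untouched.
    extract_style_text_from_prompt("oil painting", "a cat")
    # -> (False, 'a cat')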

@@ -179,7 +179,7 @@ def efficient_dot_product_attention(
chunk_idx,
min(query_chunk_size, q_tokens)
)


summarize_chunk: SummarizeChunk = partial(_summarize_chunk, scale=scale)
summarize_chunk: SummarizeChunk = partial(checkpoint, summarize_chunk) if use_checkpoint else summarize_chunk
compute_query_chunk_attn: ComputeQueryChunkAttn = partial(
@@ -201,14 +201,15 @@ def efficient_dot_product_attention(
key=key,
value=value,
)

# TODO: maybe we should use torch.empty_like(query) to allocate storage in-advance,
# and pass slices to be mutated, instead of torch.cat()ing the returned slices
res = torch.cat([
compute_query_chunk_attn(

res = torch.zeros_like(query)
for i in range(math.ceil(q_tokens / query_chunk_size)):
attn_scores = compute_query_chunk_attn(
query=get_query_chunk(i * query_chunk_size),
key=key,
value=value,
) for i in range(math.ceil(q_tokens / query_chunk_size))
], dim=1)
)

res[:, i * query_chunk_size:i * query_chunk_size + attn_scores.shape[1], :] = attn_scores

return res
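
The change above swaps torch.cat of per-chunk results for writing into a preallocated tensor, which avoids holding both the list of chunks and the concatenated copy in memory at once. A minimal standalone sketch of the same pattern (hypothetical shapes, plain PyTorch, with a no-op standing in for the attention call):

    import math
    import torch

    def chunked_identity(query, chunk_size):
        # stand-in for compute_query_chunk_attn: process the query in slices along dim 1
        q_tokens = query.shape[1]
        res = torch.zeros_like(query)
        for i in range(math.ceil(q_tokens / chunk_size)):
            chunk = query[:, i * chunk_size:(i + 1) * chunk_size, :]
            out = chunk * 1.0  # real code would run attention on this chunk
            res[:, i * chunk_size:i * chunk_size + out.shape[1], :] = out
        return res

    x = torch.randn(2, 10, 4)
    assert torch.equal(chunked_identity(x, chunk_size=3), x)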

162  modules/sysinfo.py  Normal file
@@ -0,0 +1,162 @@
import json
import os
import sys
import traceback

import platform
import hashlib
import pkg_resources
import psutil
import re

import launch
from modules import paths_internal, timer

checksum_token = "DontStealMyGamePlz__WINNERS_DONT_USE_DRUGS__DONT_COPY_THAT_FLOPPY"
environment_whitelist = {
"GIT",
"INDEX_URL",
"WEBUI_LAUNCH_LIVE_OUTPUT",
"GRADIO_ANALYTICS_ENABLED",
"PYTHONPATH",
"TORCH_INDEX_URL",
"TORCH_COMMAND",
"REQS_FILE",
"XFORMERS_PACKAGE",
"GFPGAN_PACKAGE",
"CLIP_PACKAGE",
"OPENCLIP_PACKAGE",
"STABLE_DIFFUSION_REPO",
"K_DIFFUSION_REPO",
"CODEFORMER_REPO",
"BLIP_REPO",
"STABLE_DIFFUSION_COMMIT_HASH",
"K_DIFFUSION_COMMIT_HASH",
"CODEFORMER_COMMIT_HASH",
"BLIP_COMMIT_HASH",
"COMMANDLINE_ARGS",
"IGNORE_CMD_ARGS_ERRORS",
}


def pretty_bytes(num, suffix="B"):
for unit in ["", "K", "M", "G", "T", "P", "E", "Z", "Y"]:
if abs(num) < 1024 or unit == 'Y':
return f"{num:.0f}{unit}{suffix}"
num /= 1024
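
Sample outputs (note that the %.0f format rounds, so 1.5K displays as 2KB):

    pretty_bytes(512)          # '512B'
    pretty_bytes(1536)         # '2KB'
    pretty_bytes(3 * 1024**3)  # '3GB'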


def get():
res = get_dict()

text = json.dumps(res, ensure_ascii=False, indent=4)

h = hashlib.sha256(text.encode("utf8"))
text = text.replace(checksum_token, h.hexdigest())

return text


re_checksum = re.compile(r'"Checksum": "([0-9a-fA-F]{64})"')


def check(x):
m = re.search(re_checksum, x)
if not m:
return False

replaced = re.sub(re_checksum, f'"Checksum": "{checksum_token}"', x)

h = hashlib.sha256(replaced.encode("utf8"))
return h.hexdigest() == m.group(1)
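
The checksum scheme is worth spelling out: get() hashes the report while it still contains the placeholder token, then substitutes the digest in; check() reverses that by swapping the digest back to the token before re-hashing, so any edit to the report invalidates it. A standalone sketch of the same idea (hypothetical report, not the module itself):

    import hashlib
    import re

    TOKEN = "__PLACEHOLDER__"

    def seal(text):
        digest = hashlib.sha256(text.encode("utf8")).hexdigest()
        return text.replace(TOKEN, digest)

    def verify(text):
        m = re.search(r"[0-9a-f]{64}", text)
        if not m:
            return False
        restored = text.replace(m.group(0), TOKEN)
        return hashlib.sha256(restored.encode("utf8")).hexdigest() == m.group(0)

    report = '{"Checksum": "' + TOKEN + '", "Platform": "demo"}'
    sealed = seal(report)
    assert verify(sealed)
    assert not verify(sealed.replace("demo", "tampered"))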


def get_dict():
ram = psutil.virtual_memory()

res = {
"Platform": platform.platform(),
"Python": platform.python_version(),
"Version": launch.git_tag(),
"Commit": launch.commit_hash(),
"Script path": paths_internal.script_path,
"Data path": paths_internal.data_path,
"Extensions dir": paths_internal.extensions_dir,
"Checksum": checksum_token,
"Commandline": sys.argv,
"Torch env info": get_torch_sysinfo(),
"Exceptions": get_exceptions(),
"CPU": {
"model": platform.processor(),
"count logical": psutil.cpu_count(logical=True),
"count physical": psutil.cpu_count(logical=False),
},
"RAM": {
x: pretty_bytes(getattr(ram, x, 0)) for x in ["total", "used", "free", "active", "inactive", "buffers", "cached", "shared"] if getattr(ram, x, 0) != 0
},
"Extensions": get_extensions(enabled=True),
"Inactive extensions": get_extensions(enabled=False),
"Environment": get_environment(),
"Config": get_config(),
"Startup": timer.startup_record,
"Packages": sorted([f"{pkg.key}=={pkg.version}" for pkg in pkg_resources.working_set]),
}

return res


def format_traceback(tb):
return [[f"{x.filename}, line {x.lineno}, {x.name}", x.line] for x in traceback.extract_tb(tb)]


def get_exceptions():
try:
from modules import errors

return [{"exception": str(e), "traceback": format_traceback(tb)} for e, tb in reversed(errors.exception_records)]
except Exception as e:
return str(e)


def get_environment():
return {k: os.environ[k] for k in sorted(os.environ) if k in environment_whitelist}


re_newline = re.compile(r"\r*\n")


def get_torch_sysinfo():
try:
import torch.utils.collect_env
info = torch.utils.collect_env.get_env_info()._asdict()

return {k: re.split(re_newline, str(v)) if "\n" in str(v) else v for k, v in info.items()}
except Exception as e:
return str(e)


def get_extensions(*, enabled):

try:
from modules import extensions

def to_json(x: extensions.Extension):
return {
"name": x.name,
"path": x.path,
"version": x.version,
"branch": x.branch,
"remote": x.remote,
}

return [to_json(x) for x in extensions.extensions if not x.is_builtin and x.enabled == enabled]
except Exception as e:
return str(e)


def get_config():
try:
from modules import shared
return shared.opts.data
except Exception as e:
return str(e)
@@ -1,10 +1,8 @@
import cv2
import requests
import os
from collections import defaultdict
from math import log, sqrt
import numpy as np
from PIL import Image, ImageDraw
from PIL import ImageDraw

GREEN = "#0F0"
BLUE = "#00F"
@@ -12,63 +10,64 @@ RED = "#F00"


def crop_image(im, settings):
""" Intelligently crop an image to the subject matter """
""" Intelligently crop an image to the subject matter """

scale_by = 1
if is_landscape(im.width, im.height):
scale_by = settings.crop_height / im.height
elif is_portrait(im.width, im.height):
scale_by = settings.crop_width / im.width
elif is_square(im.width, im.height):
if is_square(settings.crop_width, settings.crop_height):
scale_by = settings.crop_width / im.width
elif is_landscape(settings.crop_width, settings.crop_height):
scale_by = settings.crop_width / im.width
elif is_portrait(settings.crop_width, settings.crop_height):
scale_by = settings.crop_height / im.height
scale_by = 1
if is_landscape(im.width, im.height):
scale_by = settings.crop_height / im.height
elif is_portrait(im.width, im.height):
scale_by = settings.crop_width / im.width
elif is_square(im.width, im.height):
if is_square(settings.crop_width, settings.crop_height):
scale_by = settings.crop_width / im.width
elif is_landscape(settings.crop_width, settings.crop_height):
scale_by = settings.crop_width / im.width
elif is_portrait(settings.crop_width, settings.crop_height):
scale_by = settings.crop_height / im.height

im = im.resize((int(im.width * scale_by), int(im.height * scale_by)))
im_debug = im.copy()

focus = focal_point(im_debug, settings)
im = im.resize((int(im.width * scale_by), int(im.height * scale_by)))
im_debug = im.copy()

# take the focal point and turn it into crop coordinates that try to center over the focal
# point but then get adjusted back into the frame
y_half = int(settings.crop_height / 2)
x_half = int(settings.crop_width / 2)
focus = focal_point(im_debug, settings)

x1 = focus.x - x_half
if x1 < 0:
x1 = 0
elif x1 + settings.crop_width > im.width:
x1 = im.width - settings.crop_width
# take the focal point and turn it into crop coordinates that try to center over the focal
# point but then get adjusted back into the frame
y_half = int(settings.crop_height / 2)
x_half = int(settings.crop_width / 2)

y1 = focus.y - y_half
if y1 < 0:
y1 = 0
elif y1 + settings.crop_height > im.height:
y1 = im.height - settings.crop_height
x1 = focus.x - x_half
if x1 < 0:
x1 = 0
elif x1 + settings.crop_width > im.width:
x1 = im.width - settings.crop_width

x2 = x1 + settings.crop_width
y2 = y1 + settings.crop_height
y1 = focus.y - y_half
if y1 < 0:
y1 = 0
elif y1 + settings.crop_height > im.height:
y1 = im.height - settings.crop_height

crop = [x1, y1, x2, y2]
x2 = x1 + settings.crop_width
y2 = y1 + settings.crop_height

results = []
crop = [x1, y1, x2, y2]

results.append(im.crop(tuple(crop)))
results = []

if settings.annotate_image:
d = ImageDraw.Draw(im_debug)
rect = list(crop)
rect[2] -= 1
rect[3] -= 1
d.rectangle(rect, outline=GREEN)
results.append(im_debug)
if settings.destop_view_image:
im_debug.show()
results.append(im.crop(tuple(crop)))

return results
if settings.annotate_image:
d = ImageDraw.Draw(im_debug)
rect = list(crop)
rect[2] -= 1
rect[3] -= 1
d.rectangle(rect, outline=GREEN)
results.append(im_debug)
if settings.destop_view_image:
im_debug.show()

return results

def focal_point(im, settings):
corner_points = image_corner_points(im, settings) if settings.corner_points_weight > 0 else []
@@ -78,29 +77,29 @@ def focal_point(im, settings):
pois = []

weight_pref_total = 0
if len(corner_points) > 0:
if corner_points:
weight_pref_total += settings.corner_points_weight
if len(entropy_points) > 0:
if entropy_points:
weight_pref_total += settings.entropy_points_weight
if len(face_points) > 0:
if face_points:
weight_pref_total += settings.face_points_weight

corner_centroid = None
if len(corner_points) > 0:
if corner_points:
corner_centroid = centroid(corner_points)
corner_centroid.weight = settings.corner_points_weight / weight_pref_total
corner_centroid.weight = settings.corner_points_weight / weight_pref_total
pois.append(corner_centroid)

entropy_centroid = None
if len(entropy_points) > 0:
if entropy_points:
entropy_centroid = centroid(entropy_points)
entropy_centroid.weight = settings.entropy_points_weight / weight_pref_total
pois.append(entropy_centroid)

face_centroid = None
if len(face_points) > 0:
if face_points:
face_centroid = centroid(face_points)
face_centroid.weight = settings.face_points_weight / weight_pref_total
face_centroid.weight = settings.face_points_weight / weight_pref_total
pois.append(face_centroid)

average_point = poi_average(pois, settings)
@@ -134,7 +133,7 @@ def focal_point(im, settings):
d.rectangle(f.bounding(4), outline=color)

d.ellipse(average_point.bounding(max_size), outline=GREEN)


return average_point


@@ -185,10 +184,10 @@ def image_face_points(im, settings):
try:
faces = classifier.detectMultiScale(gray, scaleFactor=1.1,
minNeighbors=7, minSize=(minsize, minsize), flags=cv2.CASCADE_SCALE_IMAGE)
except:
except Exception:
continue

if len(faces) > 0:
if faces:
rects = [[f[0], f[1], f[0] + f[2], f[1] + f[3]] for f in faces]
return [PointOfInterest((r[0] +r[2]) // 2, (r[1] + r[3]) // 2, size=abs(r[0]-r[2]), weight=1/len(rects)) for r in rects]
return []
@@ -262,10 +261,11 @@ def image_entropy(im):
hist = hist[hist > 0]
return -np.log2(hist / hist.sum()).sum()


def centroid(pois):
x = [poi.x for poi in pois]
y = [poi.y for poi in pois]
return PointOfInterest(sum(x)/len(pois), sum(y)/len(pois))
x = [poi.x for poi in pois]
y = [poi.y for poi in pois]
return PointOfInterest(sum(x) / len(pois), sum(y) / len(pois))


def poi_average(pois, settings):
@@ -283,59 +283,58 @@ def poi_average(pois, settings):


def is_landscape(w, h):
return w > h
return w > h


def is_portrait(w, h):
return h > w
return h > w


def is_square(w, h):
return w == h
return w == h


def download_and_cache_models(dirname):
download_url = 'https://github.com/opencv/opencv_zoo/blob/91fb0290f50896f38a0ab1e558b74b16bc009428/models/face_detection_yunet/face_detection_yunet_2022mar.onnx?raw=true'
model_file_name = 'face_detection_yunet.onnx'
download_url = 'https://github.com/opencv/opencv_zoo/blob/91fb0290f50896f38a0ab1e558b74b16bc009428/models/face_detection_yunet/face_detection_yunet_2022mar.onnx?raw=true'
model_file_name = 'face_detection_yunet.onnx'

if not os.path.exists(dirname):
os.makedirs(dirname)
os.makedirs(dirname, exist_ok=True)

cache_file = os.path.join(dirname, model_file_name)
if not os.path.exists(cache_file):
print(f"downloading face detection model from '{download_url}' to '{cache_file}'")
response = requests.get(download_url)
with open(cache_file, "wb") as f:
f.write(response.content)
cache_file = os.path.join(dirname, model_file_name)
if not os.path.exists(cache_file):
print(f"downloading face detection model from '{download_url}' to '{cache_file}'")
response = requests.get(download_url)
with open(cache_file, "wb") as f:
f.write(response.content)

if os.path.exists(cache_file):
return cache_file
return None
if os.path.exists(cache_file):
return cache_file
return None


class PointOfInterest:
def __init__(self, x, y, weight=1.0, size=10):
self.x = x
self.y = y
self.weight = weight
self.size = size
def __init__(self, x, y, weight=1.0, size=10):
self.x = x
self.y = y
self.weight = weight
self.size = size

def bounding(self, size):
return [
self.x - size//2,
self.y - size//2,
self.x + size//2,
self.y + size//2
]
def bounding(self, size):
return [
self.x - size // 2,
self.y - size // 2,
self.x + size // 2,
self.y + size // 2
]


class Settings:
def __init__(self, crop_width=512, crop_height=512, corner_points_weight=0.5, entropy_points_weight=0.5, face_points_weight=0.5, annotate_image=False, dnn_model_path=None):
self.crop_width = crop_width
self.crop_height = crop_height
self.corner_points_weight = corner_points_weight
self.entropy_points_weight = entropy_points_weight
self.face_points_weight = face_points_weight
self.annotate_image = annotate_image
self.destop_view_image = False
self.dnn_model_path = dnn_model_path
def __init__(self, crop_width=512, crop_height=512, corner_points_weight=0.5, entropy_points_weight=0.5, face_points_weight=0.5, annotate_image=False, dnn_model_path=None):
self.crop_width = crop_width
self.crop_height = crop_height
self.corner_points_weight = corner_points_weight
self.entropy_points_weight = entropy_points_weight
self.face_points_weight = face_points_weight
self.annotate_image = annotate_image
self.destop_view_image = False
self.dnn_model_path = dnn_model_path
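
For orientation, this is roughly how the module is driven (a hedged sketch; it assumes a local input image and uses the download_and_cache_models helper above):

    from PIL import Image
    from modules.textual_inversion import autocrop

    settings = autocrop.Settings(
        crop_width=512,
        crop_height=512,
        face_points_weight=0.9,  # favor detected faces when picking the focal point
        annotate_image=False,
        dnn_model_path=autocrop.download_and_cache_models("models/opencv"),
    )
    im = Image.open("input.png")
    cropped = autocrop.crop_image(im, settings)[0]  # first result is the cropped image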

@@ -32,7 +32,7 @@ class DatasetEntry:

class PersonalizedBase(Dataset):
def __init__(self, data_root, width, height, repeats, flip_p=0.5, placeholder_token="*", model=None, cond_model=None, device=None, template_file=None, include_cond=False, batch_size=1, gradient_step=1, shuffle_tags=False, tag_drop_out=0, latent_sampling_method='once', varsize=False, use_weight=False):
re_word = re.compile(shared.opts.dataset_filename_word_regex) if len(shared.opts.dataset_filename_word_regex) > 0 else None
re_word = re.compile(shared.opts.dataset_filename_word_regex) if shared.opts.dataset_filename_word_regex else None

self.placeholder_token = placeholder_token

@@ -118,7 +118,7 @@ class PersonalizedBase(Dataset):
weight = torch.ones(latent_sample.shape)
else:
weight = None


if latent_sampling_method == "random":
entry = DatasetEntry(filename=path, filename_text=filename_text, latent_dist=latent_dist, weight=weight)
else:
@@ -243,4 +243,4 @@ class BatchLoaderRandom(BatchLoader):
return self

def collate_wrapper_random(batch):
return BatchLoaderRandom(batch)
return BatchLoaderRandom(batch)

@@ -1,11 +1,11 @@
import base64
import json
import warnings

import numpy as np
import zlib
from PIL import Image, PngImagePlugin, ImageDraw, ImageFont
from fonts.ttf import Roboto
from PIL import Image, ImageDraw
import torch
from modules.shared import opts


class EmbeddingEncoder(json.JSONEncoder):
@@ -17,7 +17,7 @@ class EmbeddingEncoder(json.JSONEncoder):

class EmbeddingDecoder(json.JSONDecoder):
def __init__(self, *args, **kwargs):
json.JSONDecoder.__init__(self, object_hook=self.object_hook, *args, **kwargs)
json.JSONDecoder.__init__(self, *args, object_hook=self.object_hook, **kwargs)

def object_hook(self, d):
if 'TORCHTENSOR' in d:
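
The one-line change to EmbeddingDecoder.__init__ above forwards positional arguments before the explicit keyword; placing *args after a keyword can raise "got multiple values for argument" if a caller ever supplies object_hook positionally, and linters flag the old ordering. A hedged round-trip sketch consistent with the TORCHTENSOR hook shown (the encoder body here is illustrative):

    import json
    import torch

    class EmbeddingEncoder(json.JSONEncoder):
        def default(self, obj):
            if isinstance(obj, torch.Tensor):
                return {'TORCHTENSOR': obj.cpu().detach().numpy().tolist()}
            return json.JSONEncoder.default(self, obj)

    class EmbeddingDecoder(json.JSONDecoder):
        def __init__(self, *args, **kwargs):
            json.JSONDecoder.__init__(self, *args, object_hook=self.object_hook, **kwargs)

        def object_hook(self, d):
            if 'TORCHTENSOR' in d:
                return torch.tensor(d['TORCHTENSOR'])
            return d

    data = {"vec": torch.ones(2)}
    text = json.dumps(data, cls=EmbeddingEncoder)
    restored = json.loads(text, cls=EmbeddingDecoder)
    assert torch.equal(restored["vec"], data["vec"])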
@@ -131,17 +131,17 @@ def extract_image_data_embed(image):


def caption_image_overlay(srcimage, title, footerLeft, footerMid, footerRight, textfont=None):
from modules.images import get_font
if textfont:
warnings.warn(
'passing in a textfont to caption_image_overlay is deprecated and does nothing',
DeprecationWarning,
stacklevel=2,
)
from math import cos

image = srcimage.copy()
fontsize = 32
if textfont is None:
try:
textfont = ImageFont.truetype(opts.font or Roboto, fontsize)
textfont = opts.font or Roboto
except Exception:
textfont = Roboto

factor = 1.5
gradient = Image.new('RGBA', (1, image.size[1]), color=(0, 0, 0, 0))
for y in range(image.size[1]):
@@ -152,12 +152,12 @@ def caption_image_overlay(srcimage, title, footerLeft, footerMid, footerRight, t

draw = ImageDraw.Draw(image)

font = ImageFont.truetype(textfont, fontsize)
font = get_font(fontsize)
padding = 10

_, _, w, h = draw.textbbox((0, 0), title, font=font)
fontsize = min(int(fontsize * (((image.size[0]*0.75)-(padding*4))/w)), 72)
font = ImageFont.truetype(textfont, fontsize)
font = get_font(fontsize)
_, _, w, h = draw.textbbox((0, 0), title, font=font)
draw.text((padding, padding), title, anchor='lt', font=font, fill=(255, 255, 255, 230))

@@ -168,7 +168,7 @@ def caption_image_overlay(srcimage, title, footerLeft, footerMid, footerRight, t
_, _, w, h = draw.textbbox((0, 0), footerRight, font=font)
fontsize_right = min(int(fontsize * (((image.size[0]/3)-(padding))/w)), 72)

font = ImageFont.truetype(textfont, min(fontsize_left, fontsize_mid, fontsize_right))
font = get_font(min(fontsize_left, fontsize_mid, fontsize_right))

draw.text((padding, image.size[1]-padding), footerLeft, anchor='ls', font=font, fill=(255, 255, 255, 230))
draw.text((image.size[0]/2, image.size[1]-padding), footerMid, anchor='ms', font=font, fill=(255, 255, 255, 230))

@@ -12,7 +12,7 @@ class LearnScheduleIterator:
self.it = 0
self.maxit = 0
try:
for i, pair in enumerate(pairs):
for pair in pairs:
if not pair.strip():
continue
tmp = pair.split(':')
@@ -32,8 +32,8 @@ class LearnScheduleIterator:
self.maxit += 1
return
assert self.rates
except (ValueError, AssertionError):
raise Exception('Invalid learning rate schedule. It should be a number or, for example, like "0.001:100, 0.00001:1000, 1e-5:10000" to have lr of 0.001 until step 100, 0.00001 until 1000, and 1e-5 until 10000.')
except (ValueError, AssertionError) as e:
raise Exception('Invalid learning rate schedule. It should be a number or, for example, like "0.001:100, 0.00001:1000, 1e-5:10000" to have lr of 0.001 until step 100, 0.00001 until 1000, and 1e-5 until 10000.') from e


def __iter__(self):
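
The schedule grammar the error message describes is a comma-separated list of rate:step pairs, where each rate applies until its step. Illustrative parses (not module code):

    "0.005"                   # a bare number: constant 0.005 for the whole run
    "0.001:100, 1e-5:10000"   # 0.001 until step 100, then 1e-5 until step 10000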

@@ -2,11 +2,51 @@ import datetime
import json
import os

saved_params_shared = {"model_name", "model_hash", "initial_step", "num_of_dataset_images", "learn_rate", "batch_size", "clip_grad_mode", "clip_grad_value", "gradient_step", "data_root", "log_directory", "training_width", "training_height", "steps", "create_image_every", "template_file", "gradient_step", "latent_sampling_method"}
saved_params_ti = {"embedding_name", "num_vectors_per_token", "save_embedding_every", "save_image_with_stored_embedding"}
saved_params_hypernet = {"hypernetwork_name", "layer_structure", "activation_func", "weight_init", "add_layer_norm", "use_dropout", "save_hypernetwork_every"}
saved_params_shared = {
"batch_size",
"clip_grad_mode",
"clip_grad_value",
"create_image_every",
"data_root",
"gradient_step",
"initial_step",
"latent_sampling_method",
"learn_rate",
"log_directory",
"model_hash",
"model_name",
"num_of_dataset_images",
"steps",
"template_file",
"training_height",
"training_width",
}
saved_params_ti = {
"embedding_name",
"num_vectors_per_token",
"save_embedding_every",
"save_image_with_stored_embedding",
}
saved_params_hypernet = {
"activation_func",
"add_layer_norm",
"hypernetwork_name",
"layer_structure",
"save_hypernetwork_every",
"use_dropout",
"weight_init",
}
saved_params_all = saved_params_shared | saved_params_ti | saved_params_hypernet
saved_params_previews = {"preview_prompt", "preview_negative_prompt", "preview_steps", "preview_sampler_index", "preview_cfg_scale", "preview_seed", "preview_width", "preview_height"}
saved_params_previews = {
"preview_cfg_scale",
"preview_height",
"preview_negative_prompt",
"preview_prompt",
"preview_sampler_index",
"preview_seed",
"preview_steps",
"preview_width",
}


def save_settings_to_file(log_directory, all_params):

@@ -1,17 +1,13 @@
import os
from PIL import Image, ImageOps
import math
import platform
import sys
import tqdm
import time

from modules import paths, shared, images, deepbooru
from modules.shared import opts, cmd_opts
from modules.textual_inversion import autocrop


def preprocess(id_task, process_src, process_dst, process_width, process_height, preprocess_txt_action, process_keep_original_size, process_flip, process_split, process_caption, process_caption_deepbooru=False, split_threshold=0.5, overlap_ratio=0.2, process_focal_crop=False, process_focal_crop_face_weight=0.9, process_focal_crop_entropy_weight=0.3, process_focal_crop_edges_weight=0.5, process_focal_crop_debug=False, process_multicrop=None, process_multicrop_mindim=None, process_multicrop_maxdim=None, process_multicrop_minarea=None, process_multicrop_maxarea=None, process_multicrop_objective=None, process_multicrop_threshold=None):
def preprocess(id_task, process_src, process_dst, process_width, process_height, preprocess_txt_action, process_keep_original_size, process_flip, process_split, process_caption, process_caption_deepbooru=False, split_threshold=0.5, overlap_ratio=0.2, process_focal_crop=False, process_focal_crop_face_weight=0.9, process_focal_crop_entropy_weight=0.15, process_focal_crop_edges_weight=0.5, process_focal_crop_debug=False, process_multicrop=None, process_multicrop_mindim=None, process_multicrop_maxdim=None, process_multicrop_minarea=None, process_multicrop_maxarea=None, process_multicrop_objective=None, process_multicrop_threshold=None):
try:
if process_caption:
shared.interrogator.load()
@@ -51,7 +47,7 @@ def save_pic_with_caption(image, index, params: PreprocessParams, existing_capti
caption += shared.interrogator.generate_caption(image)

if params.process_caption_deepbooru:
if len(caption) > 0:
if caption:
caption += ", "
caption += deepbooru.model.tag_multi(image)

@@ -71,7 +67,7 @@ def save_pic_with_caption(image, index, params: PreprocessParams, existing_capti

caption = caption.strip()

if len(caption) > 0:
if caption:
with open(os.path.join(params.dstdir, f"{basename}.txt"), "w", encoding="utf8") as file:
file.write(caption)

@@ -129,7 +125,7 @@ def multicrop_pic(image: Image, mindim, maxdim, minarea, maxarea, objective, thr
default=None
)
return wh and center_crop(image, *wh)



def preprocess_work(process_src, process_dst, process_width, process_height, preprocess_txt_action, process_keep_original_size, process_flip, process_split, process_caption, process_caption_deepbooru=False, split_threshold=0.5, overlap_ratio=0.2, process_focal_crop=False, process_focal_crop_face_weight=0.9, process_focal_crop_entropy_weight=0.3, process_focal_crop_edges_weight=0.5, process_focal_crop_debug=False, process_multicrop=None, process_multicrop_mindim=None, process_multicrop_maxdim=None, process_multicrop_minarea=None, process_multicrop_maxarea=None, process_multicrop_objective=None, process_multicrop_threshold=None):
width = process_width

@@ -1,7 +1,4 @@
import os
import sys
import traceback
import inspect
from collections import namedtuple

import torch
@@ -15,7 +12,7 @@ import numpy as np
from PIL import Image, PngImagePlugin
from torch.utils.tensorboard import SummaryWriter

from modules import shared, devices, sd_hijack, processing, sd_models, images, sd_samplers, sd_hijack_checkpoint
from modules import shared, devices, sd_hijack, processing, sd_models, images, sd_samplers, sd_hijack_checkpoint, errors
import modules.textual_inversion.dataset
from modules.textual_inversion.learn_schedule import LearnRateScheduler

@@ -30,7 +27,7 @@ textual_inversion_templates = {}
def list_textual_inversion_templates():
textual_inversion_templates.clear()

for root, dirs, fns in os.walk(shared.cmd_opts.textual_inversion_templates_dir):
for root, _, fns in os.walk(shared.cmd_opts.textual_inversion_templates_dir):
for fn in fns:
path = os.path.join(root, fn)

@@ -121,16 +118,29 @@ class EmbeddingDatabase:
self.embedding_dirs.clear()

def register_embedding(self, embedding, model):
self.word_embeddings[embedding.name] = embedding

ids = model.cond_stage_model.tokenize([embedding.name])[0]
return self.register_embedding_by_name(embedding, model, embedding.name)

def register_embedding_by_name(self, embedding, model, name):
ids = model.cond_stage_model.tokenize([name])[0]
first_id = ids[0]
if first_id not in self.ids_lookup:
self.ids_lookup[first_id] = []

self.ids_lookup[first_id] = sorted(self.ids_lookup[first_id] + [(ids, embedding)], key=lambda x: len(x[0]), reverse=True)

if name in self.word_embeddings:
# remove old one from the lookup list
lookup = [x for x in self.ids_lookup[first_id] if x[1].name!=name]
else:
lookup = self.ids_lookup[first_id]
if embedding is not None:
lookup += [(ids, embedding)]
self.ids_lookup[first_id] = sorted(lookup, key=lambda x: len(x[0]), reverse=True)
if embedding is None:
# unregister embedding with specified name
if name in self.word_embeddings:
del self.word_embeddings[name]
if len(self.ids_lookup[first_id])==0:
del self.ids_lookup[first_id]
return None
self.word_embeddings[name] = embedding
return embedding

def get_expected_shape(self):
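
The rewritten register_embedding_by_name above doubles as an unregister: passing embedding=None removes the named entry from both word_embeddings and its ids_lookup bucket (dropping the bucket when it empties). A hedged sketch of the intended call patterns, assuming an EmbeddingDatabase instance db, a loaded model, and an embedding emb:

    db.register_embedding(emb, model)                    # register under emb.name
    db.register_embedding_by_name(emb, model, "alias")   # register the same vector under another name
    db.register_embedding_by_name(None, model, "alias")  # unregister "alias" again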

@@ -167,8 +177,7 @@ class EmbeddingDatabase:
# textual inversion embeddings
if 'string_to_param' in data:
param_dict = data['string_to_param']
if hasattr(param_dict, '_parameters'):
param_dict = getattr(param_dict, '_parameters') # fix for torch 1.12.1 loading saved file from torch 1.11
param_dict = getattr(param_dict, '_parameters', param_dict) # fix for torch 1.12.1 loading saved file from torch 1.11
assert len(param_dict) == 1, 'embedding file has multiple terms in it'
emb = next(iter(param_dict.items()))[1]
# diffuser concepts
@@ -199,7 +208,7 @@ class EmbeddingDatabase:
if not os.path.isdir(embdir.path):
return

for root, dirs, fns in os.walk(embdir.path, followlinks=True):
for root, _, fns in os.walk(embdir.path, followlinks=True):
for fn in fns:
try:
fullfn = os.path.join(root, fn)
@@ -209,14 +218,13 @@ class EmbeddingDatabase:

self.load_from_file(fullfn, fn)
except Exception:
print(f"Error loading embedding {fn}:", file=sys.stderr)
print(traceback.format_exc(), file=sys.stderr)
errors.report(f"Error loading embedding {fn}", exc_info=True)
continue

def load_textual_inversion_embeddings(self, force_reload=False):
if not force_reload:
need_reload = False
for path, embdir in self.embedding_dirs.items():
for embdir in self.embedding_dirs.values():
if embdir.has_changed():
need_reload = True
break
@@ -229,7 +237,7 @@ class EmbeddingDatabase:
self.skipped_embeddings.clear()
self.expected_shape = self.get_expected_shape()

for path, embdir in self.embedding_dirs.items():
for embdir in self.embedding_dirs.values():
self.load_from_dir(embdir)
embdir.update()

@@ -243,7 +251,7 @@ class EmbeddingDatabase:
if self.previously_displayed_embeddings != displayed_embeddings:
self.previously_displayed_embeddings = displayed_embeddings
print(f"Textual inversion embeddings loaded({len(self.word_embeddings)}): {', '.join(self.word_embeddings.keys())}")
if len(self.skipped_embeddings) > 0:
if self.skipped_embeddings:
print(f"Textual inversion embeddings skipped({len(self.skipped_embeddings)}): {', '.join(self.skipped_embeddings.keys())}")

def find_embedding_at_position(self, tokens, offset):
@@ -325,16 +333,16 @@ def tensorboard_add(tensorboard_writer, loss, global_step, step, learn_rate, epo
tensorboard_add_scaler(tensorboard_writer, f"Learn rate/train/epoch-{epoch_num}", learn_rate, step)

def tensorboard_add_scaler(tensorboard_writer, tag, value, step):
tensorboard_writer.add_scalar(tag=tag,
tensorboard_writer.add_scalar(tag=tag,
scalar_value=value, global_step=step)

def tensorboard_add_image(tensorboard_writer, tag, pil_image, step):
# Convert a pil image to a torch tensor
img_tensor = torch.as_tensor(np.array(pil_image, copy=True))
img_tensor = img_tensor.view(pil_image.size[1], pil_image.size[0],
img_tensor = img_tensor.view(pil_image.size[1], pil_image.size[0],
len(pil_image.getbands()))
img_tensor = img_tensor.permute((2, 0, 1))


tensorboard_writer.add_image(tag, img_tensor, global_step=step)

def validate_train_inputs(model_name, learn_rate, batch_size, gradient_step, data_root, template_file, template_filename, steps, save_model_every, create_image_every, log_directory, name="embedding"):
@@ -404,7 +412,7 @@ def train_embedding(id_task, embedding_name, learn_rate, batch_size, gradient_st
if initial_step >= steps:
shared.state.textinfo = "Model has already been trained beyond specified max steps"
return embedding, filename


scheduler = LearnRateScheduler(learn_rate, steps, initial_step)
clip_grad = torch.nn.utils.clip_grad_value_ if clip_grad_mode == "value" else \
torch.nn.utils.clip_grad_norm_ if clip_grad_mode == "norm" else \
@@ -414,7 +422,7 @@ def train_embedding(id_task, embedding_name, learn_rate, batch_size, gradient_st
# dataset loading may take a while, so input validations and early returns should be done before this
shared.state.textinfo = f"Preparing dataset from {html.escape(data_root)}..."
old_parallel_processing_allowed = shared.parallel_processing_allowed


if shared.opts.training_enable_tensorboard:
tensorboard_writer = tensorboard_setup(log_directory)

@@ -441,7 +449,7 @@ def train_embedding(id_task, embedding_name, learn_rate, batch_size, gradient_st
optimizer_saved_dict = torch.load(f"{filename}.optim", map_location='cpu')
if embedding.checksum() == optimizer_saved_dict.get('hash', None):
optimizer_state_dict = optimizer_saved_dict.get('optimizer_state_dict', None)


if optimizer_state_dict is not None:
optimizer.load_state_dict(optimizer_state_dict)
print("Loaded existing optimizer from checkpoint")
@@ -470,7 +478,7 @@ def train_embedding(id_task, embedding_name, learn_rate, batch_size, gradient_st
try:
sd_hijack_checkpoint.add()

for i in range((steps-initial_step) * gradient_step):
for _ in range((steps-initial_step) * gradient_step):
if scheduler.finished:
break
if shared.state.interrupted:
@@ -487,7 +495,7 @@ def train_embedding(id_task, embedding_name, learn_rate, batch_size, gradient_st

if clip_grad:
clip_grad_sched.step(embedding.step)


with devices.autocast():
x = batch.latent_sample.to(devices.device, non_blocking=pin_memory)
if use_weight:
@@ -515,7 +523,7 @@ def train_embedding(id_task, embedding_name, learn_rate, batch_size, gradient_st
# go back until we reach gradient accumulation steps
if (j + 1) % gradient_step != 0:
continue


if clip_grad:
clip_grad(embedding.vec, clip_grad_sched.learn_rate)

@@ -603,7 +611,7 @@ def train_embedding(id_task, embedding_name, learn_rate, batch_size, gradient_st

try:
vectorSize = list(data['string_to_param'].values())[0].shape[0]
except Exception as e:
except Exception:
vectorSize = '?'

checkpoint = sd_models.select_checkpoint()
@@ -634,8 +642,7 @@ Last saved image: {html.escape(last_saved_image)}<br/>
filename = os.path.join(shared.cmd_opts.embeddings_dir, f'{embedding_name}.pt')
save_embedding(embedding, optimizer, checkpoint, embedding_name, filename, remove_cached_checksum=True)
except Exception:
print(traceback.format_exc(), file=sys.stderr)
pass
errors.report("Error training embedding", exc_info=True)
finally:
pbar.leave = False
pbar.close()

@@ -1,11 +1,30 @@
import time


class TimerSubcategory:
def __init__(self, timer, category):
self.timer = timer
self.category = category
self.start = None
self.original_base_category = timer.base_category

def __enter__(self):
self.start = time.time()
self.timer.base_category = self.original_base_category + self.category + "/"

def __exit__(self, exc_type, exc_val, exc_tb):
elapsed_for_subcategory = time.time() - self.start
self.timer.base_category = self.original_base_category
self.timer.add_time_to_record(self.original_base_category + self.category, elapsed_for_subcategory)
self.timer.record(self.category)


class Timer:
def __init__(self):
self.start = time.time()
self.records = {}
self.total = 0
self.base_category = ''

def elapsed(self):
end = time.time()
@@ -13,18 +32,29 @@ class Timer:
self.start = end
return res

def record(self, category, extra_time=0):
e = self.elapsed()
def add_time_to_record(self, category, amount):
if category not in self.records:
self.records[category] = 0

self.records[category] += e + extra_time
self.records[category] += amount

def record(self, category, extra_time=0):
e = self.elapsed()

self.add_time_to_record(self.base_category + category, e + extra_time)

self.total += e + extra_time

def subcategory(self, name):
self.elapsed()

subcat = TimerSubcategory(self, name)
return subcat

def summary(self):
res = f"{self.total:.1f}s"

additions = [x for x in self.records.items() if x[1] >= 0.1]
additions = [(category, time_taken) for category, time_taken in self.records.items() if time_taken >= 0.1 and '/' not in category]
if not additions:
return res

@@ -34,5 +64,13 @@ class Timer:

return res

def dump(self):
return {'total': self.total, 'records': self.records}

def reset(self):
self.__init__()


startup_timer = Timer()

startup_record = None
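
Putting the new pieces together: subcategory() is a context manager that prefixes nested record() calls with "category/", add_time_to_record() also credits the parent on exit so totals stay consistent, and summary() now hides the slash-prefixed entries. A hedged usage sketch:

    timer = Timer()
    # ... load model weights ...
    timer.record("load model")

    with timer.subcategory("extensions"):
        # ... run per-extension setup ...
        timer.record("scripts")       # stored as "extensions/scripts"

    print(timer.summary())            # e.g. "3.2s (load model: 2.1s, extensions: 1.1s)"
    print(timer.dump()["records"])    # includes the "extensions/scripts" breakdown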

@@ -1,18 +1,16 @@
import modules.scripts
from modules import sd_samplers
from modules import sd_samplers, processing
from modules.generation_parameters_copypaste import create_override_settings_dict
from modules.processing import StableDiffusionProcessing, Processed, StableDiffusionProcessingTxt2Img, \
StableDiffusionProcessingImg2Img, process_images
from modules.shared import opts, cmd_opts
import modules.shared as shared
import modules.processing as processing
from modules.ui import plaintext_to_html
import gradio as gr


def txt2img(id_task: str, prompt: str, negative_prompt: str, prompt_styles, steps: int, sampler_index: int, restore_faces: bool, tiling: bool, n_iter: int, batch_size: int, cfg_scale: float, seed: int, subseed: int, subseed_strength: float, seed_resize_from_h: int, seed_resize_from_w: int, seed_enable_extras: bool, height: int, width: int, enable_hr: bool, denoising_strength: float, hr_scale: float, hr_upscaler: str, hr_second_pass_steps: int, hr_resize_x: int, hr_resize_y: int, override_settings_texts, *args):
def txt2img(id_task: str, prompt: str, negative_prompt: str, prompt_styles, steps: int, sampler_index: int, restore_faces: bool, tiling: bool, n_iter: int, batch_size: int, cfg_scale: float, seed: int, subseed: int, subseed_strength: float, seed_resize_from_h: int, seed_resize_from_w: int, seed_enable_extras: bool, height: int, width: int, enable_hr: bool, denoising_strength: float, hr_scale: float, hr_upscaler: str, hr_second_pass_steps: int, hr_resize_x: int, hr_resize_y: int, hr_sampler_index: int, hr_prompt: str, hr_negative_prompt, override_settings_texts, request: gr.Request, *args):
override_settings = create_override_settings_dict(override_settings_texts)

p = StableDiffusionProcessingTxt2Img(
p = processing.StableDiffusionProcessingTxt2Img(
sd_model=shared.sd_model,
outpath_samples=opts.outdir_samples or opts.outdir_txt2img_samples,
outpath_grids=opts.outdir_grids or opts.outdir_txt2img_grids,
@@ -41,19 +39,24 @@ def txt2img(id_task: str, prompt: str, negative_prompt: str, prompt_styles, step
hr_second_pass_steps=hr_second_pass_steps,
hr_resize_x=hr_resize_x,
hr_resize_y=hr_resize_y,
hr_sampler_name=sd_samplers.samplers_for_img2img[hr_sampler_index - 1].name if hr_sampler_index != 0 else None,
hr_prompt=hr_prompt,
hr_negative_prompt=hr_negative_prompt,
override_settings=override_settings,
)

p.scripts = modules.scripts.scripts_txt2img
p.script_args = args

p.user = request.username

if cmd_opts.enable_console_prompts:
print(f"\ntxt2img: {prompt}", file=shared.progress_print_out)

processed = modules.scripts.scripts_txt2img.run(p, *args)

if processed is None:
processed = process_images(p)
processed = processing.process_images(p)

p.close()


638  modules/ui.py
@@ -1,29 +1,25 @@
import html
|
||||
import datetime
|
||||
import json
|
||||
import math
|
||||
import mimetypes
|
||||
import os
|
||||
import platform
|
||||
import random
|
||||
import sys
|
||||
import tempfile
|
||||
import time
|
||||
import traceback
|
||||
from functools import partial, reduce
|
||||
from functools import reduce
|
||||
import warnings
|
||||
|
||||
import gradio as gr
|
||||
import gradio.routes
|
||||
import gradio.utils
|
||||
import numpy as np
|
||||
from PIL import Image, PngImagePlugin
|
||||
from PIL import Image, PngImagePlugin # noqa: F401
|
||||
from modules.call_queue import wrap_gradio_gpu_call, wrap_queued_call, wrap_gradio_call
|
||||
|
||||
from modules import sd_hijack, sd_models, localization, script_callbacks, ui_extensions, deepbooru, sd_vae, extra_networks, postprocessing, ui_components, ui_common, ui_postprocessing, progress
|
||||
from modules.ui_components import FormRow, FormColumn, FormGroup, ToolButton, FormHTML
|
||||
from modules.paths import script_path, data_path
|
||||
from modules import sd_hijack, sd_models, script_callbacks, ui_extensions, deepbooru, sd_vae, extra_networks, ui_common, ui_postprocessing, progress, ui_loadsave, errors, shared_items, ui_settings, timer, sysinfo
|
||||
from modules.ui_components import FormRow, FormGroup, ToolButton, FormHTML
|
||||
from modules.paths import script_path
|
||||
from modules.ui_common import create_refresh_button
|
||||
from modules.ui_gradio_extensions import reload_javascript
|
||||
|
||||
from modules.shared import opts, cmd_opts, restricted_opts
|
||||
|
||||
from modules.shared import opts, cmd_opts
|
||||
|
||||
import modules.codeformer_model
|
||||
import modules.generation_parameters_copypaste as parameters_copypaste
|
||||
@@ -34,7 +30,6 @@ import modules.shared as shared
|
||||
import modules.styles
|
||||
import modules.textual_inversion.ui
|
||||
from modules import prompt_parser
|
||||
from modules.images import save_image
|
||||
from modules.sd_hijack import model_hijack
|
||||
from modules.sd_samplers import samplers, samplers_for_img2img
|
||||
from modules.textual_inversion import textual_inversion
|
||||
@@ -42,6 +37,8 @@ import modules.hypernetworks.ui
|
||||
from modules.generation_parameters_copypaste import image_from_url_text
|
||||
import modules.extras
|
||||
|
||||
create_setting_component = ui_settings.create_setting_component
|
||||
|
||||
warnings.filterwarnings("default" if opts.show_warnings else "ignore", category=UserWarning)
|
||||
|
||||
# this is a fix for Windows users. Without it, javascript files will be served with text/html content-type and the browser will not show any UI
@@ -59,7 +56,7 @@ if cmd_opts.ngrok is not None:
ngrok.connect(
cmd_opts.ngrok,
cmd_opts.port if cmd_opts.port is not None else 7860,
cmd_opts.ngrok_region
cmd_opts.ngrok_options
)

@@ -82,6 +79,8 @@ clear_prompt_symbol = '\U0001f5d1\ufe0f' # 🗑️
extra_networks_symbol = '\U0001F3B4' # 🎴
switch_values_symbol = '\U000021C5' # ⇅
restore_progress_symbol = '\U0001F300' # 🌀
detect_image_size_symbol = '\U0001F4D0' # 📐
up_down_symbol = '\u2195\ufe0f' # ↕️

def plaintext_to_html(text):
@@ -93,16 +92,6 @@ def send_gradio_gallery_to_image(x):
return None
return image_from_url_text(x[0])

def visit(x, func, path=""):
if hasattr(x, 'children'):
if isinstance(x, gr.Tabs) and x.elem_id is not None:
# Tabs element can't have a label, have to use elem_id instead
func(f"{path}/Tabs@{x.elem_id}", x)
for c in x.children:
visit(c, func, path)
elif x.label is not None:
func(f"{path}/{x.label}", x)

def add_style(name: str, prompt: str, negative_prompt: str):
if name is None:
@@ -166,7 +155,7 @@ def process_interrogate(interrogation_function, mode, ii_input_dir, ii_output_di
img = Image.open(image)
filename = os.path.basename(image)
left, _ = os.path.splitext(filename)
print(interrogation_function(img), file=open(os.path.join(ii_output_dir, f"{left}.txt"), 'a'))
print(interrogation_function(img), file=open(os.path.join(ii_output_dir, f"{left}.txt"), 'a', encoding='utf-8'))

return [gr.update(), None]

@@ -206,8 +195,8 @@ def create_seed_inputs(target_interface):
seed_resize_from_w = gr.Slider(minimum=0, maximum=2048, step=8, label="Resize seed from width", value=0, elem_id=f"{target_interface}_seed_resize_from_w")
seed_resize_from_h = gr.Slider(minimum=0, maximum=2048, step=8, label="Resize seed from height", value=0, elem_id=f"{target_interface}_seed_resize_from_h")

random_seed.click(fn=lambda: -1, show_progress=False, inputs=[], outputs=[seed])
random_subseed.click(fn=lambda: -1, show_progress=False, inputs=[], outputs=[subseed])
random_seed.click(fn=None, _js="function(){setRandomSeed('" + target_interface + "_seed')}", show_progress=False, inputs=[], outputs=[])
random_subseed.click(fn=None, _js="function(){setRandomSeed('" + target_interface + "_subseed')}", show_progress=False, inputs=[], outputs=[])
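The reworked seed buttons illustrate a general Gradio pattern: passing fn=None together with a _js snippet runs the update entirely in the browser, avoiding a server round trip. A minimal sketch (the setRandomSeed helper lives in the webui's JavaScript files; the elem_id here is illustrative):

    import gradio as gr

    with gr.Blocks() as demo:
        seed = gr.Number(value=-1, elem_id="demo_seed")
        random_seed = gr.Button("Randomize")
        # No Python function runs; the JavaScript expression executes client-side.
        random_seed.click(fn=None, _js="function(){setRandomSeed('demo_seed')}", inputs=[], outputs=[])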

def change_visibility(show):
return {comp: gr_show(show) for comp in seed_extras}
@@ -246,10 +235,9 @@ def connect_reuse_seed(seed: gr.Number, reuse_seed: gr.Button, generation_info:
all_seeds = gen_info.get('all_seeds', [-1])
res = all_seeds[index if 0 <= index < len(all_seeds) else 0]

except json.decoder.JSONDecodeError as e:
if gen_info_string != '':
print("Error parsing JSON generation info:", file=sys.stderr)
print(gen_info_string, file=sys.stderr)
except json.decoder.JSONDecodeError:
if gen_info_string:
errors.report(f"Error parsing JSON generation info: {gen_info_string}")

return [res, gr_show(False)]
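Several hunks in this commit replace ad-hoc print(..., file=sys.stderr) pairs with errors.report. Judging only from the call sites shown in this diff, the helper takes a message plus an optional exc_info flag; a plausible minimal implementation, which is an assumption since the real one lives in modules/errors.py, would be:

    import sys
    import traceback

    def report(message: str, *, exc_info: bool = False):
        # Assumed sketch: print the message, optionally followed by the
        # traceback of the exception currently being handled.
        print(message, file=sys.stderr)
        if exc_info:
            print(traceback.format_exc(), file=sys.stderr)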

@@ -288,12 +276,12 @@ def create_toprow(is_img2img):
with gr.Row():
with gr.Column(scale=80):
with gr.Row():
prompt = gr.Textbox(label="Prompt", elem_id=f"{id_part}_prompt", show_label=False, lines=3, placeholder="Prompt (press Ctrl+Enter or Alt+Enter to generate)")
prompt = gr.Textbox(label="Prompt", elem_id=f"{id_part}_prompt", show_label=False, lines=3, placeholder="Prompt (press Ctrl+Enter or Alt+Enter to generate)", elem_classes=["prompt"])

with gr.Row():
with gr.Column(scale=80):
with gr.Row():
negative_prompt = gr.Textbox(label="Negative prompt", elem_id=f"{id_part}_neg_prompt", show_label=False, lines=3, placeholder="Negative prompt (press Ctrl+Enter or Alt+Enter to generate)")
negative_prompt = gr.Textbox(label="Negative prompt", elem_id=f"{id_part}_neg_prompt", show_label=False, lines=3, placeholder="Negative prompt (press Ctrl+Enter or Alt+Enter to generate)", elem_classes=["prompt"])

button_interrogate = None
button_deepbooru = None
@@ -384,25 +372,6 @@ def apply_setting(key, value):
return getattr(opts, key)

def create_refresh_button(refresh_component, refresh_method, refreshed_args, elem_id):
def refresh():
refresh_method()
args = refreshed_args() if callable(refreshed_args) else refreshed_args

for k, v in args.items():
setattr(refresh_component, k, v)

return gr.update(**(args or {}))

refresh_button = ToolButton(value=refresh_symbol, elem_id=elem_id)
refresh_button.click(
fn=refresh,
inputs=[],
outputs=[refresh_component]
)
return refresh_button

def create_output_panel(tabname, outdir):
return ui_common.create_output_panel(tabname, outdir)

@@ -421,27 +390,17 @@ def create_sampler_and_steps_selection(choices, tabname):

def ordered_ui_categories():
user_order = {x.strip(): i * 2 + 1 for i, x in enumerate(shared.opts.ui_reorder.split(","))}
user_order = {x.strip(): i * 2 + 1 for i, x in enumerate(shared.opts.ui_reorder_list)}

for i, category in sorted(enumerate(shared.ui_reorder_categories), key=lambda x: user_order.get(x[1], x[0] * 2 + 0)):
for _, category in sorted(enumerate(shared_items.ui_reorder_categories()), key=lambda x: user_order.get(x[1], x[0] * 2 + 0)):
yield category

def get_value_for_setting(key):
value = getattr(opts, key)

info = opts.data_labels[key]
args = info.component_args() if callable(info.component_args) else info.component_args or {}
args = {k: v for k, v in args.items() if k not in {'precision'}}

return gr.update(value=value, **args)

def create_override_settings_dropdown(tabname, row):
dropdown = gr.Dropdown([], label="Override settings", visible=False, elem_id=f"{tabname}_override_settings", multiselect=True)

dropdown.change(
fn=lambda x: gr.Dropdown.update(visible=len(x) > 0),
fn=lambda x: gr.Dropdown.update(visible=bool(x)),
inputs=[dropdown],
outputs=[dropdown],
)
@@ -472,6 +431,8 @@ def create_ui():

with gr.Row().style(equal_height=False):
with gr.Column(variant='compact', elem_id="txt2img_settings"):
modules.scripts.scripts_txt2img.prepare_ui()

for category in ordered_ui_categories():
if category == "sampler":
steps, sampler_index = create_sampler_and_steps_selection(samplers, "txt2img")
@@ -515,6 +476,17 @@ def create_ui():
hr_resize_x = gr.Slider(minimum=0, maximum=2048, step=8, label="Resize width to", value=0, elem_id="txt2img_hr_resize_x")
hr_resize_y = gr.Slider(minimum=0, maximum=2048, step=8, label="Resize height to", value=0, elem_id="txt2img_hr_resize_y")

with FormRow(elem_id="txt2img_hires_fix_row3", variant="compact", visible=opts.hires_fix_show_sampler) as hr_sampler_container:
hr_sampler_index = gr.Dropdown(label='Hires sampling method', elem_id="hr_sampler", choices=["Use same sampler"] + [x.name for x in samplers_for_img2img], value="Use same sampler", type="index")

with FormRow(elem_id="txt2img_hires_fix_row4", variant="compact", visible=opts.hires_fix_show_prompts) as hr_prompts_container:
with gr.Column(scale=80):
with gr.Row():
hr_prompt = gr.Textbox(label="Hires prompt", elem_id="hires_prompt", show_label=False, lines=3, placeholder="Prompt for hires fix pass.\nLeave empty to use the same prompt as in first pass.", elem_classes=["prompt"])
with gr.Column(scale=80):
with gr.Row():
hr_negative_prompt = gr.Textbox(label="Hires negative prompt", elem_id="hires_neg_prompt", show_label=False, lines=3, placeholder="Negative prompt for hires fix pass.\nLeave empty to use the same negative prompt as in first pass.", elem_classes=["prompt"])

elif category == "batch":
if not opts.dimensions_and_batch_together:
with FormRow(elem_id="txt2img_column_batch"):
@@ -529,15 +501,21 @@ def create_ui():
with FormGroup(elem_id="txt2img_script_container"):
custom_inputs = modules.scripts.scripts_txt2img.setup_ui()

else:
modules.scripts.scripts_txt2img.setup_ui_for_section(category)

hr_resolution_preview_inputs = [enable_hr, width, height, hr_scale, hr_resize_x, hr_resize_y]
for input in hr_resolution_preview_inputs:
input.change(

for component in hr_resolution_preview_inputs:
event = component.release if isinstance(component, gr.Slider) else component.change

event(
fn=calc_resolution_hires,
inputs=hr_resolution_preview_inputs,
outputs=[hr_final_resolution],
show_progress=False,
)
input.change(
event(
None,
_js="onCalcResolutionHires",
inputs=hr_resolution_preview_inputs,
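The loop above switches sliders from the change event to release, so the hires resolution preview only recomputes when the user lets go of the handle rather than on every drag tick; for non-slider inputs, change still fires. The dispatch is plain attribute selection on the component, as shown in the diff:

    # For sliders, fire on mouse-up; for other components, fire on change.
    event = component.release if isinstance(component, gr.Slider) else component.change
    event(fn=calc_resolution_hires, inputs=hr_resolution_preview_inputs,
          outputs=[hr_final_resolution], show_progress=False)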
@@ -576,7 +554,11 @@ def create_ui():
hr_second_pass_steps,
hr_resize_x,
hr_resize_y,
hr_sampler_index,
hr_prompt,
hr_negative_prompt,
override_settings,

] + custom_inputs,

outputs=[
@@ -591,7 +573,7 @@ def create_ui():
txt2img_prompt.submit(**txt2img_args)
submit.click(**txt2img_args)

res_switch_btn.click(lambda w, h: (h, w), inputs=[width, height], outputs=[width, height], show_progress=False)
res_switch_btn.click(fn=None, _js="function(){switchWidthHeight('txt2img')}", inputs=None, outputs=None, show_progress=False)

restore_progress_button.click(
fn=progress.restore_progress,
@@ -614,7 +596,8 @@ def create_ui():
outputs=[
txt2img_prompt,
txt_prompt_img
]
],
show_progress=False,
)

enable_hr.change(
@@ -639,6 +622,7 @@ def create_ui():
(subseed_strength, "Variation seed strength"),
(seed_resize_from_w, "Seed resize from-1"),
(seed_resize_from_h, "Seed resize from-2"),
(txt2img_prompt_styles, lambda d: d["Styles array"] if isinstance(d.get("Styles array"), list) else gr.update()),
(denoising_strength, "Denoising strength"),
(enable_hr, lambda d: "Denoising strength" in d),
(hr_options, lambda d: gr.Row.update(visible="Denoising strength" in d)),
@@ -647,6 +631,11 @@ def create_ui():
(hr_second_pass_steps, "Hires steps"),
(hr_resize_x, "Hires resize-1"),
(hr_resize_y, "Hires resize-2"),
(hr_sampler_index, "Hires sampler"),
(hr_sampler_container, lambda d: gr.update(visible=True) if d.get("Hires sampler", "Use same sampler") != "Use same sampler" else gr.update()),
(hr_prompt, "Hires prompt"),
(hr_negative_prompt, "Hires negative prompt"),
(hr_prompts_container, lambda d: gr.update(visible=True) if d.get("Hires prompt", "") != "" or d.get("Hires negative prompt", "") != "" else gr.update()),
*modules.scripts.scripts_txt2img.infotext_fields
]
parameters_copypaste.add_paste_fields("txt2img", None, txt2img_paste_fields, override_settings)
@@ -704,19 +693,19 @@ def create_ui():
img2img_selected_tab = gr.State(0)

with gr.TabItem('img2img', id='img2img', elem_id="img2img_img2img_tab") as tab_img2img:
init_img = gr.Image(label="Image for img2img", elem_id="img2img_image", show_label=False, source="upload", interactive=True, type="pil", tool="editor", image_mode="RGBA").style(height=480)
init_img = gr.Image(label="Image for img2img", elem_id="img2img_image", show_label=False, source="upload", interactive=True, type="pil", tool="editor", image_mode="RGBA").style(height=opts.img2img_editor_height)
add_copy_image_controls('img2img', init_img)

with gr.TabItem('Sketch', id='img2img_sketch', elem_id="img2img_img2img_sketch_tab") as tab_sketch:
sketch = gr.Image(label="Image for img2img", elem_id="img2img_sketch", show_label=False, source="upload", interactive=True, type="pil", tool="color-sketch", image_mode="RGBA").style(height=480)
sketch = gr.Image(label="Image for img2img", elem_id="img2img_sketch", show_label=False, source="upload", interactive=True, type="pil", tool="color-sketch", image_mode="RGBA").style(height=opts.img2img_editor_height)
add_copy_image_controls('sketch', sketch)

with gr.TabItem('Inpaint', id='inpaint', elem_id="img2img_inpaint_tab") as tab_inpaint:
init_img_with_mask = gr.Image(label="Image for inpainting with mask", show_label=False, elem_id="img2maskimg", source="upload", interactive=True, type="pil", tool="sketch", image_mode="RGBA").style(height=480)
init_img_with_mask = gr.Image(label="Image for inpainting with mask", show_label=False, elem_id="img2maskimg", source="upload", interactive=True, type="pil", tool="sketch", image_mode="RGBA").style(height=opts.img2img_editor_height)
add_copy_image_controls('inpaint', init_img_with_mask)

with gr.TabItem('Inpaint sketch', id='inpaint_sketch', elem_id="img2img_inpaint_sketch_tab") as tab_inpaint_color:
inpaint_color_sketch = gr.Image(label="Color sketch inpainting", show_label=False, elem_id="inpaint_sketch", source="upload", interactive=True, type="pil", tool="color-sketch", image_mode="RGBA").style(height=480)
inpaint_color_sketch = gr.Image(label="Color sketch inpainting", show_label=False, elem_id="inpaint_sketch", source="upload", interactive=True, type="pil", tool="color-sketch", image_mode="RGBA").style(height=opts.img2img_editor_height)
inpaint_color_sketch_orig = gr.State(None)
add_copy_image_controls('inpaint_sketch', inpaint_color_sketch)

@@ -736,17 +725,20 @@ def create_ui():
with gr.TabItem('Batch', id='batch', elem_id="img2img_batch_tab") as tab_batch:
hidden = '<br>Disabled when launched with --hide-ui-dir-config.' if shared.cmd_opts.hide_ui_dir_config else ''
gr.HTML(
f"<p style='padding-bottom: 1em;' class=\"text-gray-500\">Process images in a directory on the same machine where the server is running." +
f"<br>Use an empty output directory to save pictures normally instead of writing to the output directory." +
"<p style='padding-bottom: 1em;' class=\"text-gray-500\">Process images in a directory on the same machine where the server is running." +
"<br>Use an empty output directory to save pictures normally instead of writing to the output directory." +
f"<br>Add inpaint batch mask directory to enable inpaint batch processing."
f"{hidden}</p>"
)
img2img_batch_input_dir = gr.Textbox(label="Input directory", **shared.hide_dirs, elem_id="img2img_batch_input_dir")
img2img_batch_output_dir = gr.Textbox(label="Output directory", **shared.hide_dirs, elem_id="img2img_batch_output_dir")
img2img_batch_inpaint_mask_dir = gr.Textbox(label="Inpaint batch mask directory (required for inpaint batch processing only)", **shared.hide_dirs, elem_id="img2img_batch_inpaint_mask_dir")
with gr.Accordion("PNG info", open=False):
img2img_batch_use_png_info = gr.Checkbox(label="Append png info to prompts", **shared.hide_dirs, elem_id="img2img_batch_use_png_info")
img2img_batch_png_info_dir = gr.Textbox(label="PNG info directory", **shared.hide_dirs, placeholder="Leave empty to use input directory", elem_id="img2img_batch_png_info_dir")
img2img_batch_png_info_props = gr.CheckboxGroup(["Prompt", "Negative prompt", "Seed", "CFG scale", "Sampler", "Steps"], label="Parameters to take from png info", info="Prompts from png info will be appended to prompts set in ui.")

img2img_tabs = [tab_img2img, tab_sketch, tab_inpaint, tab_inpaint_color, tab_inpaint_upload, tab_batch]
img2img_image_inputs = [init_img, sketch, init_img_with_mask, inpaint_color_sketch]

for i, tab in enumerate(img2img_tabs):
tab.select(fn=lambda tabnum=i: tabnum, inputs=[], outputs=[img2img_selected_tab])
@@ -773,6 +765,8 @@ def create_ui():
with FormRow():
resize_mode = gr.Radio(label="Resize mode", elem_id="resize_mode", choices=["Just resize", "Crop and resize", "Resize and fill", "Just resize (latent upscale)"], type="index", value="Just resize")

modules.scripts.scripts_img2img.prepare_ui()

for category in ordered_ui_categories():
if category == "sampler":
steps, sampler_index = create_sampler_and_steps_selection(samplers_for_img2img, "img2img")
@@ -783,15 +777,16 @@ def create_ui():
selected_scale_tab = gr.State(value=0)

with gr.Tabs():
with gr.Tab(label="Resize to") as tab_scale_to:
with gr.Tab(label="Resize to", elem_id="img2img_tab_resize_to") as tab_scale_to:
with FormRow():
with gr.Column(elem_id="img2img_column_size", scale=4):
width = gr.Slider(minimum=64, maximum=2048, step=8, label="Width", value=512, elem_id="img2img_width")
height = gr.Slider(minimum=64, maximum=2048, step=8, label="Height", value=512, elem_id="img2img_height")
with gr.Column(elem_id="img2img_dimensions_row", scale=1, elem_classes="dimensions-tools"):
res_switch_btn = ToolButton(value=switch_values_symbol, elem_id="img2img_res_switch_btn")
detect_image_size_btn = ToolButton(value=detect_image_size_symbol, elem_id="img2img_detect_image_size_btn")

with gr.Tab(label="Resize by") as tab_scale_by:
with gr.Tab(label="Resize by", elem_id="img2img_tab_resize_by") as tab_scale_by:
scale_by = gr.Slider(minimum=0.05, maximum=4.0, step=0.05, label="Scale", value=1.0, elem_id="img2img_scale")

with FormRow():
@@ -881,6 +876,8 @@ def create_ui():
inputs=[],
outputs=[inpaint_controls, mask_alpha],
)
else:
modules.scripts.scripts_img2img.setup_ui_for_section(category)

img2img_gallery, generation_info, html_info, html_log = create_output_panel("img2img", opts.outdir_img2img_samples)

@@ -895,7 +892,8 @@ def create_ui():
outputs=[
img2img_prompt,
img2img_prompt_img
]
],
show_progress=False,
)

img2img_args = dict(
@@ -940,6 +938,9 @@ def create_ui():
img2img_batch_output_dir,
img2img_batch_inpaint_mask_dir,
override_settings,
img2img_batch_use_png_info,
img2img_batch_png_info_props,
img2img_batch_png_info_dir,
] + custom_inputs,
outputs=[
img2img_gallery,
@@ -967,7 +968,16 @@ def create_ui():

img2img_prompt.submit(**img2img_args)
submit.click(**img2img_args)
res_switch_btn.click(lambda w, h: (h, w), inputs=[width, height], outputs=[width, height], show_progress=False)

res_switch_btn.click(fn=None, _js="function(){switchWidthHeight('img2img')}", inputs=None, outputs=None, show_progress=False)

detect_image_size_btn.click(
fn=lambda w, h, _: (w or gr.update(), h or gr.update()),
_js="currentImg2imgSourceResolution",
inputs=[dummy_component, dummy_component, dummy_component],
outputs=[width, height],
show_progress=False,
)
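The detect-size handler just above also shows the idiom of returning a bare gr.update() to leave an output untouched: width and height are only overwritten when the JavaScript side reports a real source resolution. A standalone sketch of the same idiom, with an illustrative function name:

    import gradio as gr

    def maybe_set(value):
        # A bare gr.update() leaves the component exactly as it is,
        # so outputs are only touched when there is something to write.
        return value if value else gr.update()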

restore_progress_button.click(
fn=progress.restore_progress,
@@ -1035,6 +1045,7 @@ def create_ui():
(subseed_strength, "Variation seed strength"),
(seed_resize_from_w, "Seed resize from-1"),
(seed_resize_from_h, "Seed resize from-2"),
(img2img_prompt_styles, lambda d: d["Styles array"] if isinstance(d.get("Styles array"), list) else gr.update()),
(denoising_strength, "Denoising strength"),
(mask_blur, "Mask blur"),
*modules.scripts.scripts_img2img.infotext_fields
@@ -1189,7 +1200,7 @@ def create_ui():
process_focal_crop_entropy_weight = gr.Slider(label='Focal point entropy weight', value=0.15, minimum=0.0, maximum=1.0, step=0.05, elem_id="train_process_focal_crop_entropy_weight")
process_focal_crop_edges_weight = gr.Slider(label='Focal point edges weight', value=0.5, minimum=0.0, maximum=1.0, step=0.05, elem_id="train_process_focal_crop_edges_weight")
process_focal_crop_debug = gr.Checkbox(label='Create debug image', elem_id="train_process_focal_crop_debug")

with gr.Column(visible=False) as process_multicrop_col:
gr.Markdown('Each image is center-cropped with an automatically chosen width and height.')
with gr.Row():
@@ -1201,7 +1212,7 @@ def create_ui():
with gr.Row():
process_multicrop_objective = gr.Radio(["Maximize area", "Minimize error"], value="Maximize area", label="Resizing objective", elem_id="train_process_multicrop_objective")
process_multicrop_threshold = gr.Slider(minimum=0, maximum=1, step=0.01, label="Error threshold", value=0.1, elem_id="train_process_multicrop_threshold")

with gr.Row():
with gr.Column(scale=3):
gr.HTML(value="")
@@ -1230,7 +1241,7 @@ def create_ui():
)

def get_textual_inversion_template_names():
return sorted([x for x in textual_inversion.textual_inversion_templates])
return sorted(textual_inversion.textual_inversion_templates)

with gr.Tab(label="Train", id="train"):
gr.HTML(value="<p style='margin-bottom: 0.7em'>Train an embedding or Hypernetwork; you must specify a directory with a set of 1:1 ratio images <a href=\"https://github.com/AUTOMATIC1111/stable-diffusion-webui/wiki/Textual-Inversion\" style=\"font-weight:bold;\">[wiki]</a></p>")
@@ -1238,13 +1249,13 @@ def create_ui():
train_embedding_name = gr.Dropdown(label='Embedding', elem_id="train_embedding", choices=sorted(sd_hijack.model_hijack.embedding_db.word_embeddings.keys()))
create_refresh_button(train_embedding_name, sd_hijack.model_hijack.embedding_db.load_textual_inversion_embeddings, lambda: {"choices": sorted(sd_hijack.model_hijack.embedding_db.word_embeddings.keys())}, "refresh_train_embedding_name")

train_hypernetwork_name = gr.Dropdown(label='Hypernetwork', elem_id="train_hypernetwork", choices=[x for x in shared.hypernetworks.keys()])
create_refresh_button(train_hypernetwork_name, shared.reload_hypernetworks, lambda: {"choices": sorted([x for x in shared.hypernetworks.keys()])}, "refresh_train_hypernetwork_name")
train_hypernetwork_name = gr.Dropdown(label='Hypernetwork', elem_id="train_hypernetwork", choices=sorted(shared.hypernetworks))
create_refresh_button(train_hypernetwork_name, shared.reload_hypernetworks, lambda: {"choices": sorted(shared.hypernetworks)}, "refresh_train_hypernetwork_name")

with FormRow():
embedding_learn_rate = gr.Textbox(label='Embedding Learning rate', placeholder="Embedding Learning rate", value="0.005", elem_id="train_embedding_learn_rate")
hypernetwork_learn_rate = gr.Textbox(label='Hypernetwork Learning rate', placeholder="Hypernetwork Learning rate", value="0.00001", elem_id="train_hypernetwork_learn_rate")

with FormRow():
clip_grad_mode = gr.Dropdown(value="disabled", label="Gradient Clipping", choices=["disabled", "value", "norm"])
clip_grad_value = gr.Textbox(placeholder="Gradient clip value", value="0.1", show_label=False)
@@ -1290,8 +1301,8 @@ def create_ui():

with gr.Column(elem_id='ti_gallery_container'):
ti_output = gr.Text(elem_id="ti_output", value="", show_label=False)
ti_gallery = gr.Gallery(label='Output', show_label=False, elem_id='ti_gallery').style(columns=4)
ti_progress = gr.HTML(elem_id="ti_progress", value="")
gr.Gallery(label='Output', show_label=False, elem_id='ti_gallery').style(columns=4)
gr.HTML(elem_id="ti_progress", value="")
ti_outcome = gr.HTML(elem_id="ti_error", value="")

create_embedding.click(
@@ -1444,194 +1455,10 @@ def create_ui():
outputs=[],
)

def create_setting_component(key, is_quicksettings=False):
def fun():
return opts.data[key] if key in opts.data else opts.data_labels[key].default
loadsave = ui_loadsave.UiLoadsave(cmd_opts.ui_config_file)

info = opts.data_labels[key]
t = type(info.default)

args = info.component_args() if callable(info.component_args) else info.component_args

if info.component is not None:
comp = info.component
elif t == str:
comp = gr.Textbox
elif t == int:
comp = gr.Number
elif t == bool:
comp = gr.Checkbox
else:
raise Exception(f'bad options item type: {t} for key {key}')

elem_id = f"setting_{key}"

if info.refresh is not None:
if is_quicksettings:
res = comp(label=info.label, value=fun(), elem_id=elem_id, **(args or {}))
create_refresh_button(res, info.refresh, info.component_args, f"refresh_{key}")
else:
with FormRow():
res = comp(label=info.label, value=fun(), elem_id=elem_id, **(args or {}))
create_refresh_button(res, info.refresh, info.component_args, f"refresh_{key}")
else:
res = comp(label=info.label, value=fun(), elem_id=elem_id, **(args or {}))

return res

components = []
component_dict = {}
shared.settings_components = component_dict

script_callbacks.ui_settings_callback()
opts.reorder()

def run_settings(*args):
changed = []

for key, value, comp in zip(opts.data_labels.keys(), args, components):
assert comp == dummy_component or opts.same_type(value, opts.data_labels[key].default), f"Bad value for setting {key}: {value}; expecting {type(opts.data_labels[key].default).__name__}"

for key, value, comp in zip(opts.data_labels.keys(), args, components):
if comp == dummy_component:
continue

if opts.set(key, value):
changed.append(key)

try:
opts.save(shared.config_filename)
except RuntimeError:
return opts.dumpjson(), f'{len(changed)} settings changed without save: {", ".join(changed)}.'
return opts.dumpjson(), f'{len(changed)} settings changed{": " if len(changed) > 0 else ""}{", ".join(changed)}.'

def run_settings_single(value, key):
if not opts.same_type(value, opts.data_labels[key].default):
return gr.update(visible=True), opts.dumpjson()

if not opts.set(key, value):
return gr.update(value=getattr(opts, key)), opts.dumpjson()

opts.save(shared.config_filename)

return get_value_for_setting(key), opts.dumpjson()

with gr.Blocks(analytics_enabled=False) as settings_interface:
with gr.Row():
with gr.Column(scale=6):
settings_submit = gr.Button(value="Apply settings", variant='primary', elem_id="settings_submit")
with gr.Column():
restart_gradio = gr.Button(value='Reload UI', variant='primary', elem_id="settings_restart_gradio")

result = gr.HTML(elem_id="settings_result")

quicksettings_names = opts.quicksettings_list
quicksettings_names = {x: i for i, x in enumerate(quicksettings_names) if x != 'quicksettings'}

quicksettings_list = []

previous_section = None
current_tab = None
current_row = None
with gr.Tabs(elem_id="settings"):
for i, (k, item) in enumerate(opts.data_labels.items()):
section_must_be_skipped = item.section[0] is None

if previous_section != item.section and not section_must_be_skipped:
elem_id, text = item.section

if current_tab is not None:
current_row.__exit__()
current_tab.__exit__()

gr.Group()
current_tab = gr.TabItem(elem_id=f"settings_{elem_id}", label=text)
current_tab.__enter__()
current_row = gr.Column(variant='compact')
current_row.__enter__()

previous_section = item.section

if k in quicksettings_names and not shared.cmd_opts.freeze_settings:
quicksettings_list.append((i, k, item))
components.append(dummy_component)
elif section_must_be_skipped:
components.append(dummy_component)
else:
component = create_setting_component(k)
component_dict[k] = component
components.append(component)

if current_tab is not None:
current_row.__exit__()
current_tab.__exit__()

with gr.TabItem("Actions", id="actions", elem_id="settings_tab_actions"):
request_notifications = gr.Button(value='Request browser notifications', elem_id="request_notifications")
download_localization = gr.Button(value='Download localization template', elem_id="download_localization")
reload_script_bodies = gr.Button(value='Reload custom script bodies (No ui updates, No restart)', variant='secondary', elem_id="settings_reload_script_bodies")
with gr.Row():
unload_sd_model = gr.Button(value='Unload SD checkpoint to free VRAM', elem_id="sett_unload_sd_model")
reload_sd_model = gr.Button(value='Reload the last SD checkpoint back into VRAM', elem_id="sett_reload_sd_model")

with gr.TabItem("Licenses", id="licenses", elem_id="settings_tab_licenses"):
gr.HTML(shared.html("licenses.html"), elem_id="licenses")

gr.Button(value="Show all pages", elem_id="settings_show_all_pages")

def unload_sd_weights():
modules.sd_models.unload_model_weights()

def reload_sd_weights():
modules.sd_models.reload_model_weights()

unload_sd_model.click(
fn=unload_sd_weights,
inputs=[],
outputs=[]
)

reload_sd_model.click(
fn=reload_sd_weights,
inputs=[],
outputs=[]
)

request_notifications.click(
fn=lambda: None,
inputs=[],
outputs=[],
_js='function(){}'
)

download_localization.click(
fn=lambda: None,
inputs=[],
outputs=[],
_js='download_localization'
)

def reload_scripts():
modules.scripts.reload_script_body_only()
reload_javascript() # need to refresh the html page

reload_script_bodies.click(
fn=reload_scripts,
inputs=[],
outputs=[]
)

def request_restart():
shared.state.interrupt()
shared.state.need_restart = True

restart_gradio.click(
fn=request_restart,
_js='restart_reload',
inputs=[],
outputs=[],
)
settings = ui_settings.UiSettings()
settings.create_ui(loadsave, dummy_component)

interfaces = [
(txt2img_interface, "txt2img", "txt2img"),
@@ -1639,11 +1466,11 @@ def create_ui():
(extras_interface, "Extras", "extras"),
(pnginfo_interface, "PNG Info", "pnginfo"),
(modelmerger_interface, "Checkpoint Merger", "modelmerger"),
(train_interface, "Train", "ti"),
(train_interface, "Train", "train"),
]

interfaces += script_callbacks.ui_tabs_callback()
interfaces += [(settings_interface, "Settings", "settings")]
interfaces += [(settings.interface, "Settings", "settings")]

extensions_interface = ui_extensions.create_ui()
interfaces += [(extensions_interface, "Extensions", "extensions")]
@@ -1653,76 +1480,48 @@ def create_ui():
shared.tab_names.append(label)

with gr.Blocks(theme=shared.gradio_theme, analytics_enabled=False, title="Stable Diffusion") as demo:
with gr.Row(elem_id="quicksettings", variant="compact"):
for i, k, item in sorted(quicksettings_list, key=lambda x: quicksettings_names.get(x[1], x[0])):
component = create_setting_component(k, is_quicksettings=True)
component_dict[k] = component
settings.add_quicksettings()

parameters_copypaste.connect_paste_params_buttons()

with gr.Tabs(elem_id="tabs") as tabs:
for interface, label, ifid in interfaces:
tab_order = {k: i for i, k in enumerate(opts.ui_tab_order)}
sorted_interfaces = sorted(interfaces, key=lambda x: tab_order.get(x[1], 9999))

for interface, label, ifid in sorted_interfaces:
if label in shared.opts.hidden_tabs:
continue
with gr.TabItem(label, id=ifid, elem_id=f"tab_{ifid}"):
interface.render()

for interface, _label, ifid in interfaces:
if ifid in ["extensions", "settings"]:
continue

loadsave.add_block(interface, ifid)

loadsave.add_component(f"webui/Tabs@{tabs.elem_id}", tabs)

loadsave.setup_ui()

if os.path.exists(os.path.join(script_path, "notification.mp3")):
audio_notification = gr.Audio(interactive=False, value=os.path.join(script_path, "notification.mp3"), elem_id="audio_notification", visible=False)
gr.Audio(interactive=False, value=os.path.join(script_path, "notification.mp3"), elem_id="audio_notification", visible=False)

footer = shared.html("footer.html")
footer = footer.format(versions=versions_html())
footer = footer.format(versions=versions_html(), api_docs="/docs" if shared.cmd_opts.api else "https://github.com/AUTOMATIC1111/stable-diffusion-webui/wiki/API")
gr.HTML(footer, elem_id="footer")

text_settings = gr.Textbox(elem_id="settings_json", value=lambda: opts.dumpjson(), visible=False)
settings_submit.click(
fn=wrap_gradio_call(run_settings, extra_outputs=[gr.update()]),
inputs=components,
outputs=[text_settings, result],
)

for i, k, item in quicksettings_list:
component = component_dict[k]
info = opts.data_labels[k]

change_handler = component.release if hasattr(component, 'release') else component.change
change_handler(
fn=lambda value, k=k: run_settings_single(value, key=k),
inputs=[component],
outputs=[component, text_settings],
show_progress=info.refresh is not None,
)
settings.add_functionality(demo)

update_image_cfg_scale_visibility = lambda: gr.update(visible=shared.sd_model and shared.sd_model.cond_stage_key == "edit")
text_settings.change(fn=update_image_cfg_scale_visibility, inputs=[], outputs=[image_cfg_scale])
settings.text_settings.change(fn=update_image_cfg_scale_visibility, inputs=[], outputs=[image_cfg_scale])
demo.load(fn=update_image_cfg_scale_visibility, inputs=[], outputs=[image_cfg_scale])

button_set_checkpoint = gr.Button('Change checkpoint', elem_id='change_checkpoint', visible=False)
button_set_checkpoint.click(
fn=lambda value, _: run_settings_single(value, key='sd_model_checkpoint'),
_js="function(v){ var res = desiredCheckpointName; desiredCheckpointName = ''; return [res || v, null]; }",
inputs=[component_dict['sd_model_checkpoint'], dummy_component],
outputs=[component_dict['sd_model_checkpoint'], text_settings],
)

component_keys = [k for k in opts.data_labels.keys() if k in component_dict]

def get_settings_values():
return [get_value_for_setting(key) for key in component_keys]

demo.load(
fn=get_settings_values,
inputs=[],
outputs=[component_dict[k] for k in component_keys],
queue=False,
)

def modelmerger(*args):
try:
results = modules.extras.run_modelmerger(*args)
except Exception as e:
print("Error loading/saving model file:", file=sys.stderr)
print(traceback.format_exc(), file=sys.stderr)
errors.report("Error loading/saving model file", exc_info=True)
modules.sd_models.list_models() # to remove the potentially missing models from the list
return [*[gr.Dropdown.update(choices=modules.sd_models.checkpoint_tiles()) for _ in range(4)], f"Error merging checkpoints: {e}"]
return results
@@ -1750,102 +1549,13 @@ def create_ui():
primary_model_name,
secondary_model_name,
tertiary_model_name,
component_dict['sd_model_checkpoint'],
settings.component_dict['sd_model_checkpoint'],
modelmerger_result,
]
)

ui_config_file = cmd_opts.ui_config_file
ui_settings = {}
settings_count = len(ui_settings)
error_loading = False

try:
if os.path.exists(ui_config_file):
with open(ui_config_file, "r", encoding="utf8") as file:
ui_settings = json.load(file)
except Exception:
error_loading = True
print("Error loading settings:", file=sys.stderr)
print(traceback.format_exc(), file=sys.stderr)

def loadsave(path, x):
def apply_field(obj, field, condition=None, init_field=None):
key = f"{path}/{field}"

if getattr(obj, 'custom_script_source', None) is not None:
key = f"customscript/{obj.custom_script_source}/{key}"

if getattr(obj, 'do_not_save_to_config', False):
return

saved_value = ui_settings.get(key, None)
if saved_value is None:
ui_settings[key] = getattr(obj, field)
elif condition and not condition(saved_value):
pass

# this warning is generally not useful;
# print(f'Warning: Bad ui setting value: {key}: {saved_value}; Default value "{getattr(obj, field)}" will be used instead.')
else:
setattr(obj, field, saved_value)
if init_field is not None:
init_field(saved_value)

if type(x) in [gr.Slider, gr.Radio, gr.Checkbox, gr.Textbox, gr.Number, gr.Dropdown, ToolButton] and x.visible:
apply_field(x, 'visible')

if type(x) == gr.Slider:
apply_field(x, 'value')
apply_field(x, 'minimum')
apply_field(x, 'maximum')
apply_field(x, 'step')

if type(x) == gr.Radio:
apply_field(x, 'value', lambda val: val in x.choices)

if type(x) == gr.Checkbox:
apply_field(x, 'value')

if type(x) == gr.Textbox:
apply_field(x, 'value')

if type(x) == gr.Number:
apply_field(x, 'value')

if type(x) == gr.Dropdown:
def check_dropdown(val):
if getattr(x, 'multiselect', False):
return all([value in x.choices for value in val])
else:
return val in x.choices

apply_field(x, 'value', check_dropdown, getattr(x, 'init_field', None))

def check_tab_id(tab_id):
tab_items = list(filter(lambda e: isinstance(e, gr.TabItem), x.children))
if type(tab_id) == str:
tab_ids = [t.id for t in tab_items]
return tab_id in tab_ids
elif type(tab_id) == int:
return tab_id >= 0 and tab_id < len(tab_items)
else:
return False

if type(x) == gr.Tabs:
apply_field(x, 'selected', check_tab_id)

visit(txt2img_interface, loadsave, "txt2img")
visit(img2img_interface, loadsave, "img2img")
visit(extras_interface, loadsave, "extras")
visit(modelmerger_interface, loadsave, "modelmerger")
visit(train_interface, loadsave, "train")

loadsave(f"webui/Tabs@{tabs.elem_id}", tabs)

if not error_loading and (not os.path.exists(ui_config_file) or settings_count != len(ui_settings)):
with open(ui_config_file, "w", encoding="utf8") as file:
json.dump(ui_settings, file, indent=4)
loadsave.dump_defaults()
demo.ui_loadsave = loadsave
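Taken together, the ui_loadsave changes in this commit replace the inline visit/loadsave walker with a small object whose lifecycle, as visible from the call sites in this diff, is: construct with the config path, register blocks and extra components, wire the UI, then persist defaults. Collected in one place:

    # Lifecycle as used in create_ui() above (all calls appear in this diff).
    loadsave = ui_loadsave.UiLoadsave(cmd_opts.ui_config_file)
    for interface, _label, ifid in interfaces:
        if ifid not in ["extensions", "settings"]:
            loadsave.add_block(interface, ifid)
    loadsave.add_component(f"webui/Tabs@{tabs.elem_id}", tabs)
    loadsave.setup_ui()
    loadsave.dump_defaults()  # writes ui-config.json defaults to disk
    demo.ui_loadsave = loadsave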

# Required as a workaround for change() event not triggering when loading values from ui-config.json
interp_description.value = update_interp_description(interp_method.value)
@@ -1853,70 +1563,6 @@ def create_ui():
return demo

def webpath(fn):
if fn.startswith(script_path):
web_path = os.path.relpath(fn, script_path).replace('\\', '/')
else:
web_path = os.path.abspath(fn)

return f'file={web_path}?{os.path.getmtime(fn)}'
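webpath appends the file's mtime as a query string, which busts the browser cache whenever a script or stylesheet changes on disk. The same idea in isolation (the function name here is illustrative):

    import os

    def cache_busted(path: str) -> str:
        # The mtime query parameter changes whenever the file changes,
        # so the browser never serves a stale cached copy.
        return f"file={path}?{os.path.getmtime(path)}"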

def javascript_html():
# Ensure localization is in `window` before scripts
head = f'<script type="text/javascript">{localization.localization_js(shared.opts.localization)}</script>\n'

script_js = os.path.join(script_path, "script.js")
head += f'<script type="text/javascript" src="{webpath(script_js)}"></script>\n'

for script in modules.scripts.list_scripts("javascript", ".js"):
head += f'<script type="text/javascript" src="{webpath(script.path)}"></script>\n'

for script in modules.scripts.list_scripts("javascript", ".mjs"):
head += f'<script type="module" src="{webpath(script.path)}"></script>\n'

if cmd_opts.theme:
head += f'<script type="text/javascript">set_theme(\"{cmd_opts.theme}\");</script>\n'

return head

def css_html():
head = ""

def stylesheet(fn):
return f'<link rel="stylesheet" property="stylesheet" href="{webpath(fn)}">'

for cssfile in modules.scripts.list_files_with_name("style.css"):
if not os.path.isfile(cssfile):
continue

head += stylesheet(cssfile)

if os.path.exists(os.path.join(data_path, "user.css")):
head += stylesheet(os.path.join(data_path, "user.css"))

return head

def reload_javascript():
js = javascript_html()
css = css_html()

def template_response(*args, **kwargs):
res = shared.GradioTemplateResponseOriginal(*args, **kwargs)
res.body = res.body.replace(b'</head>', f'{js}</head>'.encode("utf8"))
res.body = res.body.replace(b'</body>', f'{css}</body>'.encode("utf8"))
res.init_headers()
return res

gradio.routes.templates.TemplateResponse = template_response

if not hasattr(shared, 'GradioTemplateResponseOriginal'):
shared.GradioTemplateResponseOriginal = gradio.routes.templates.TemplateResponse

def versions_html():
import torch
import launch
@@ -1933,15 +1579,15 @@ def versions_html():

return f"""
version: <a href="https://github.com/AUTOMATIC1111/stable-diffusion-webui/commit/{commit}">{tag}</a>
•
 • 
python: <span title="{sys.version}">{python_version}</span>
•
 • 
torch: {getattr(torch, '__long_version__',torch.__version__)}
•
 • 
xformers: {xformers_version}
•
 • 
gradio: {gr.__version__}
•
 • 
checkpoint: <a id="sd_checkpoint_hash">N/A</a>
"""

@@ -1960,3 +1606,17 @@ def setup_ui_api(app):
app.add_api_route("/internal/quicksettings-hint", quicksettings_hint, methods=["GET"], response_model=List[QuicksettingsHint])

app.add_api_route("/internal/ping", lambda: {}, methods=["GET"])

app.add_api_route("/internal/profile-startup", lambda: timer.startup_record, methods=["GET"])

def download_sysinfo(attachment=False):
from fastapi.responses import PlainTextResponse

text = sysinfo.get()
filename = f"sysinfo-{datetime.datetime.utcnow().strftime('%Y-%m-%d-%H-%M')}.txt"

return PlainTextResponse(text, headers={'Content-Disposition': f'{"attachment" if attachment else "inline"}; filename="{filename}"'})

app.add_api_route("/internal/sysinfo", download_sysinfo, methods=["GET"])
app.add_api_route("/internal/sysinfo-download", lambda: download_sysinfo(attachment=True), methods=["GET"])

@@ -10,8 +10,11 @@ import subprocess as sp
from modules import call_queue, shared
from modules.generation_parameters_copypaste import image_from_url_text
import modules.images
from modules.ui_components import ToolButton

folder_symbol = '\U0001f4c2' # 📂
refresh_symbol = '\U0001f504' # 🔄

def update_generation_info(generation_info, html_info, img_index):
@@ -50,9 +53,10 @@ def save_files(js_data, images, do_make_zip, index):
save_to_dirs = shared.opts.use_save_to_dirs_for_ui
extension: str = shared.opts.samples_format
start_index = 0
only_one = False

if index > -1 and shared.opts.save_selected_only and (index >= data["index_of_first_image"]): # ensures we are looking at a specific non-grid picture, and we have save_selected_only

only_one = True
images = [images[index]]
start_index = index

@@ -70,6 +74,7 @@ def save_files(js_data, images, do_make_zip, index):
is_grid = image_index < p.index_of_first_image
i = 0 if is_grid else (image_index - p.index_of_first_image)

p.batch_index = image_index-1
fullfn, txt_fullfn = modules.images.save_image(image, path, "", seed=p.all_seeds[i], prompt=p.all_prompts[i], extension=extension, info=p.infotexts[image_index], grid=is_grid, p=p, save_to_dirs=save_to_dirs)

filename = os.path.relpath(fullfn, path)
@@ -83,7 +88,10 @@ def save_files(js_data, images, do_make_zip, index):

# Make Zip
if do_make_zip:
zip_filepath = os.path.join(path, "images.zip")
zip_fileseed = p.all_seeds[index-1] if only_one else p.all_seeds[0]
namegen = modules.images.FilenameGenerator(p, zip_fileseed, p.all_prompts[0], image, True)
zip_filename = namegen.apply(shared.opts.grid_zip_filename_pattern or "[datetime]_[[model_name]]_[seed]-[seed_last]")
zip_filepath = os.path.join(path, f"{zip_filename}.zip")

from zipfile import ZipFile
with ZipFile(zip_filepath, "w") as zip_file:
@@ -211,3 +219,23 @@ Requested path was: {f}
))

return result_gallery, generation_info if tabname != "extras" else html_info_x, html_info, html_log

def create_refresh_button(refresh_component, refresh_method, refreshed_args, elem_id):
def refresh():
refresh_method()
args = refreshed_args() if callable(refreshed_args) else refreshed_args

for k, v in args.items():
setattr(refresh_component, k, v)

return gr.update(**(args or {}))

refresh_button = ToolButton(value=refresh_symbol, elem_id=elem_id)
refresh_button.click(
fn=refresh,
inputs=[],
outputs=[refresh_component]
)
return refresh_button

@@ -1,9 +1,8 @@
import json
import os.path
import sys
import threading
import time
from datetime import datetime
import traceback

import git

@@ -12,7 +11,7 @@ import html
import shutil
import errno

from modules import extensions, shared, paths, config_states
from modules import extensions, shared, paths, config_states, errors, restart
from modules.paths_internal import config_states_dir
from modules.call_queue import wrap_gradio_gpu_call

@@ -45,15 +44,16 @@ def apply_and_restart(disable_list, update_list, disable_all):
try:
ext.fetch_and_reset_hard()
except Exception:
print(f"Error getting updates for {ext.name}:", file=sys.stderr)
print(traceback.format_exc(), file=sys.stderr)
errors.report(f"Error getting updates for {ext.name}", exc_info=True)

shared.opts.disabled_extensions = disabled
shared.opts.disable_all_extensions = disable_all
shared.opts.save(shared.config_filename)

shared.state.interrupt()
shared.state.need_restart = True
if restart.is_restartable():
restart.restart_program()
else:
restart.stop_program()

def save_config_state(name):
@@ -91,8 +91,7 @@ def restore_config_state(confirmed, config_state_name, restore_type):
if restore_type == "webui" or restore_type == "both":
config_states.restore_webui_config(config_state)

shared.state.interrupt()
shared.state.need_restart = True
shared.state.request_restart()

return ""

@@ -115,8 +114,7 @@ def check_updates(id_task, disable_list):
if 'FETCH_HEAD' not in str(e):
raise
except Exception:
print(f"Error checking updates for {ext.name}:", file=sys.stderr)
print(traceback.format_exc(), file=sys.stderr)
errors.report(f"Error checking updates for {ext.name}", exc_info=True)

shared.state.nextjob()

@@ -127,7 +125,9 @@ def make_commit_link(commit_hash, remote, text=None):
if text is None:
text = commit_hash[:8]
if remote.startswith("https://github.com/"):
href = os.path.join(remote, "commit", commit_hash)
if remote.endswith(".git"):
remote = remote[:-4]
href = remote + "/commit/" + commit_hash
return f'<a href="{href}" target="_blank">{text}</a>'
else:
return text
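The rewrite of make_commit_link fixes two problems at once: os.path.join is the wrong tool for URLs (on Windows it joins with backslashes), and remotes recorded with a .git suffix produced broken commit links. The new string construction handles both; for example, with a hypothetical hash:

    remote = "https://github.com/AUTOMATIC1111/stable-diffusion-webui.git"
    if remote.endswith(".git"):
        remote = remote[:-4]
    href = remote + "/commit/" + "deadbeef"  # hypothetical hash for illustration
    # -> https://github.com/AUTOMATIC1111/stable-diffusion-webui/commit/deadbeef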
@@ -138,9 +138,14 @@ def extension_table():
<table id="extensions">
<thead>
<tr>
<th><abbr title="Use checkbox to enable the extension; it will be enabled or disabled when you click apply button">Extension</abbr></th>
<th>
<input class="gr-check-radio gr-checkbox all_extensions_toggle" type="checkbox" {'checked="checked"' if all(ext.enabled for ext in extensions.extensions) else ''} onchange="toggle_all_extensions(event)" />
<abbr title="Use checkbox to enable the extension; it will be enabled or disabled when you click apply button">Extension</abbr>
</th>
<th>URL</th>
<th><abbr title="Extension version">Version</abbr></th>
<th>Branch</th>
<th>Version</th>
<th>Date</th>
<th><abbr title="Use checkbox to mark the extension for update; it will be updated when you click apply button">Update</abbr></th>
</tr>
</thead>
@@ -148,6 +153,7 @@ def extension_table():
"""

for ext in extensions.extensions:
ext: extensions.Extension
ext.read_info_from_repo()

remote = f"""<a href="{html.escape(ext.remote or '')}" target="_blank">{html.escape("built-in" if ext.is_builtin else ext.remote or '')}</a>"""
@@ -167,9 +173,11 @@ def extension_table():

code += f"""
<tr>
<td><label{style}><input class="gr-check-radio gr-checkbox" name="enable_{html.escape(ext.name)}" type="checkbox" {'checked="checked"' if ext.enabled else ''}>{html.escape(ext.name)}</label></td>
<td><label{style}><input class="gr-check-radio gr-checkbox extension_toggle" name="enable_{html.escape(ext.name)}" type="checkbox" {'checked="checked"' if ext.enabled else ''} onchange="toggle_extension(event)" />{html.escape(ext.name)}</label></td>
<td>{remote}</td>
<td>{ext.branch}</td>
<td>{version_link}</td>
<td>{time.asctime(time.gmtime(ext.commit_date))}</td>
<td{' class="extension_status"' if ext.remote is not None else ''}>{ext_status}</td>
</tr>
"""
@@ -320,6 +328,11 @@ def normalize_git_url(url):
def install_extension_from_url(dirname, url, branch_name=None):
check_access()

if isinstance(dirname, str):
dirname = dirname.strip()
if isinstance(url, str):
url = url.strip()

assert url, 'No URL specified'

if dirname is None or dirname == "":
@@ -332,7 +345,8 @@ def install_extension_from_url(dirname, url, branch_name=None):
assert not os.path.exists(target_dir), f'Extension directory already exists: {target_dir}'

normalized_url = normalize_git_url(url)
assert len([x for x in extensions.extensions if normalize_git_url(x.remote) == normalized_url]) == 0, 'Extension with this URL is already installed'
if any(x for x in extensions.extensions if normalize_git_url(x.remote) == normalized_url):
raise Exception(f'Extension with this URL is already installed: {url}')

tmpdir = os.path.join(paths.data_path, "tmp", dirname)

@@ -340,12 +354,12 @@ def install_extension_from_url(dirname, url, branch_name=None):
shutil.rmtree(tmpdir, True)
if not branch_name:
# if no branch is specified, use the default branch
with git.Repo.clone_from(url, tmpdir) as repo:
with git.Repo.clone_from(url, tmpdir, filter=['blob:none']) as repo:
repo.remote().fetch()
for submodule in repo.submodules:
submodule.update()
else:
with git.Repo.clone_from(url, tmpdir, branch=branch_name) as repo:
with git.Repo.clone_from(url, tmpdir, filter=['blob:none'], branch=branch_name) as repo:
repo.remote().fetch()
for submodule in repo.submodules:
submodule.update()
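Adding filter=['blob:none'] to both clone paths turns the extension install into a blobless partial clone: GitPython forwards the keyword to the git CLI, so the call is equivalent to git clone --filter=blob:none, which downloads commits and trees up front and fetches file contents lazily, making installs faster on large repositories. A sketch of the equivalent invocation, with illustrative URL and path:

    import subprocess

    # Equivalent CLI form of the GitPython call above.
    subprocess.run(
        ["git", "clone", "--filter=blob:none", "https://github.com/user/extension.git", "tmp/extension"],
        check=True,
    )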
@@ -410,9 +424,19 @@ sort_ordering = [
(False, lambda x: x.get('name', 'z')),
(True, lambda x: x.get('name', 'z')),
(False, lambda x: 'z'),
(True, lambda x: x.get('commit_time', '')),
(True, lambda x: x.get('created_at', '')),
(True, lambda x: x.get('stars', 0)),
]

def get_date(info: dict, key):
try:
return datetime.strptime(info.get(key), "%Y-%m-%dT%H:%M:%SZ").strftime("%Y-%m-%d")
except (ValueError, TypeError):
return ''
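get_date defensively parses the GitHub-style UTC timestamps in the extension index and falls back to an empty string on missing or malformed values; for example:

    get_date({"commit_time": "2023-06-27T12:30:00Z"}, "commit_time")  # -> '2023-06-27'
    get_date({}, "commit_time")  # -> '' (info.get returns None, raising TypeError)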


def refresh_available_extensions_from_data(hide_tags, sort_column, filter_text=""):
    extlist = available_extensions["extensions"]
    installed_extension_urls = {normalize_git_url(extension.remote): extension.name for extension in extensions.extensions}
@@ -437,7 +461,10 @@ def refresh_available_extensions_from_data(hide_tags, sort_column, filter_text=""):

    for ext in sorted(extlist, key=sort_function, reverse=sort_reverse):
        name = ext.get("name", "noname")
        stars = int(ext.get("stars", 0))
        added = ext.get('added', 'unknown')
        update_time = get_date(ext, 'commit_time')
        create_time = get_date(ext, 'created_at')
        url = ext.get("url", None)
        description = ext.get("description", "")
        extension_tags = ext.get("tags", [])
@@ -448,7 +475,7 @@ def refresh_available_extensions_from_data(hide_tags, sort_column, filter_text=""):
        existing = installed_extension_urls.get(normalize_git_url(url), None)
        extension_tags = extension_tags + ["installed"] if existing else extension_tags

        if len([x for x in extension_tags if x in tags_to_hide]) > 0:
        if any(x for x in extension_tags if x in tags_to_hide):
            hidden += 1
            continue

@@ -464,10 +491,11 @@ def refresh_available_extensions_from_data(hide_tags, sort_column, filter_text=""):
        code += f"""
            <tr>
                <td><a href="{html.escape(url)}" target="_blank">{html.escape(name)}</a><br />{tags_text}</td>
                <td>{html.escape(description)}<p class="info"><span class="date_added">Added: {html.escape(added)}</span></p></td>
                <td>{html.escape(description)}<p class="info">
                <span class="date_added">Update: {html.escape(update_time)} Added: {html.escape(added)} Created: {html.escape(create_time)}</span><span class="star_count">stars: <b>{stars}</b></a></p></td>
                <td>{install_code}</td>
            </tr>


        """

        for tag in [x for x in extension_tags if x not in tags]:
@@ -484,17 +512,31 @@ def refresh_available_extensions_from_data(hide_tags, sort_column, filter_text=""):
    return code, list(tags)


def preload_extensions_git_metadata():
    t0 = time.time()
    for extension in extensions.extensions:
        extension.read_info_from_repo()
    print(
        f"preload_extensions_git_metadata for "
        f"{len(extensions.extensions)} extensions took "
        f"{time.time() - t0:.2f}s"
    )


def create_ui():
    import modules.ui

    config_states.list_config_states()

    threading.Thread(target=preload_extensions_git_metadata).start()

    with gr.Blocks(analytics_enabled=False) as ui:
        with gr.Tabs(elem_id="tabs_extensions") as tabs:
        with gr.Tabs(elem_id="tabs_extensions"):
            with gr.TabItem("Installed", id="installed"):

                with gr.Row(elem_id="extensions_installed_top"):
                    apply = gr.Button(value="Apply and restart UI", variant="primary")
                    apply_label = ("Apply and restart UI" if restart.is_restartable() else "Apply and quit")
                    apply = gr.Button(value=apply_label, variant="primary")
                    check = gr.Button(value="Check for updates")
                    extensions_disable_all = gr.Radio(label="Disable all extensions", choices=["none", "extra", "all"], value=shared.opts.disable_all_extensions, elem_id="extensions_disable_all")
                    extensions_disabled_list = gr.Text(elem_id="extensions_disabled_list", visible=False).style(container=False)
@@ -508,7 +550,8 @@ def create_ui():
                </span>
                """
                info = gr.HTML(html)
                extensions_table = gr.HTML(lambda: extension_table())
                extensions_table = gr.HTML('Loading...')
                ui.load(fn=extension_table, inputs=[], outputs=[extensions_table])

                apply.click(
                    fn=apply_and_restart,
@@ -533,18 +576,18 @@ def create_ui():

                with gr.Row():
                    hide_tags = gr.CheckboxGroup(value=["ads", "localization", "installed"], label="Hide extensions with tags", choices=["script", "ads", "localization", "installed"])
                    sort_column = gr.Radio(value="newest first", label="Order", choices=["newest first", "oldest first", "a-z", "z-a", "internal order", ], type="index")
                    sort_column = gr.Radio(value="newest first", label="Order", choices=["newest first", "oldest first", "a-z", "z-a", "internal order", 'update time', 'create time', "stars"], type="index")

                with gr.Row():
                with gr.Row():
                    search_extensions_text = gr.Text(label="Search").style(container=False)


                install_result = gr.HTML()
                available_extensions_table = gr.HTML()

                refresh_available_extensions_button.click(
                    fn=modules.ui.wrap_gradio_call(refresh_available_extensions, extra_outputs=[gr.update(), gr.update(), gr.update()]),
                    fn=modules.ui.wrap_gradio_call(refresh_available_extensions, extra_outputs=[gr.update(), gr.update(), gr.update(), gr.update()]),
                    inputs=[available_extensions_index, hide_tags, sort_column],
                    outputs=[available_extensions_index, available_extensions_table, hide_tags, install_result, search_extensions_text],
                    outputs=[available_extensions_index, available_extensions_table, hide_tags, search_extensions_text, install_result],
                )

                install_extension_button.click(
@@ -579,9 +622,9 @@ def create_ui():
                install_result = gr.HTML(elem_id="extension_install_result")

                install_button.click(
                    fn=modules.ui.wrap_gradio_call(install_extension_from_url, extra_outputs=[gr.update()]),
                    fn=modules.ui.wrap_gradio_call(lambda *args: [gr.update(), *install_extension_from_url(*args)], extra_outputs=[gr.update(), gr.update()]),
                    inputs=[install_dirname, install_url, install_branch],
                    outputs=[extensions_table, install_result],
                    outputs=[install_url, extensions_table, install_result],
                )

            with gr.TabItem("Backup/Restore"):
@@ -595,7 +638,8 @@ def create_ui():
                config_save_button = gr.Button(value="Save Current Config")

                config_states_info = gr.HTML("")
                config_states_table = gr.HTML(lambda: update_config_states_table("Current"))
                config_states_table = gr.HTML("Loading...")
                ui.load(fn=update_config_states_table, inputs=[config_states_list], outputs=[config_states_table])

                config_save_button.click(fn=save_config_state, inputs=[config_save_name], outputs=[config_states_list, config_states_info])

@@ -608,4 +652,5 @@ def create_ui():
            outputs=[config_states_table],
        )


    return ui

@@ -1,11 +1,10 @@
import glob
import os.path
import urllib.parse
from pathlib import Path
from PIL import PngImagePlugin

from modules import shared
from modules.images import read_info_from_image
from modules.images import read_info_from_image, save_image_with_geninfo
from modules.ui import up_down_symbol
import gradio as gr
import json
import html
@@ -27,12 +26,12 @@ def register_page(page):
def fetch_file(filename: str = ""):
    from starlette.responses import FileResponse

    if not any([Path(x).absolute() in Path(filename).absolute().parents for x in allowed_dirs]):
    if not any(Path(x).absolute() in Path(filename).absolute().parents for x in allowed_dirs):
        raise ValueError(f"File cannot be fetched: {filename}. Must be in one of directories registered by extra pages.")

    ext = os.path.splitext(filename)[1].lower()
    if ext not in (".png", ".jpg", ".webp"):
        raise ValueError(f"File cannot be fetched: {filename}. Only png and jpg and webp.")
    if ext not in (".png", ".jpg", ".jpeg", ".webp", ".gif"):
        raise ValueError(f"File cannot be fetched: {filename}. Only png, jpg, webp, and gif.")

    # would profit from returning 304
    return FileResponse(filename, headers={"Accept-Ranges": "bytes"})
@@ -91,8 +90,8 @@ class ExtraNetworksPage:

        subdirs = {}
        for parentdir in [os.path.abspath(x) for x in self.allowed_directories_for_previews()]:
            for root, dirs, files in os.walk(parentdir):
                for dirname in dirs:
            for root, dirs, _ in sorted(os.walk(parentdir, followlinks=True), key=lambda x: shared.natural_sort_key(x[0])):
                for dirname in sorted(dirs, key=shared.natural_sort_key):
                    x = os.path.join(root, dirname)

                    if not os.path.isdir(x):
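The directory walk above is now deterministic: both the roots and the subdirectory names are ordered with shared.natural_sort_key, so "model10" sorts after "model2" instead of before it. A standalone key that mimics what the real implementation in modules/shared.py does:

import re

def natural_sort_key(s):
    # split into digit and non-digit runs so numbers compare numerically
    return [int(part) if part.isdigit() else part.lower() for part in re.split(r'(\d+)', s)]

dirs = ["model10", "model2", "Model1"]
print(sorted(dirs))                        # ['Model1', 'model10', 'model2']
print(sorted(dirs, key=natural_sort_key))  # ['Model1', 'model2', 'model10']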
@@ -106,6 +105,9 @@ class ExtraNetworksPage:
                    if not is_empty and not subdir.endswith("/"):
                        subdir = subdir + "/"

                    if ("/." in subdir or subdir.startswith(".")) and not shared.opts.extra_networks_show_hidden_directories:
                        continue

                    subdirs[subdir] = 1

        if subdirs:
@@ -148,6 +150,10 @@ class ExtraNetworksPage:
        return []

    def create_html_for_item(self, item, tabname):
        """
        Create HTML for card item in tab tabname; can return empty string if the item is not meant to be shown.
        """

        preview = item.get("preview", None)

        onclick = item.get("onclick", None)
@@ -156,7 +162,7 @@ class ExtraNetworksPage:

        height = f"height: {shared.opts.extra_networks_card_height}px;" if shared.opts.extra_networks_card_height else ''
        width = f"width: {shared.opts.extra_networks_card_width}px;" if shared.opts.extra_networks_card_width else ''
        background_image = f"background-image: url(\"{html.escape(preview)}\");" if preview else ''
        background_image = f'<img src="{html.escape(preview)}" class="preview" loading="lazy">' if preview else ''
        metadata_button = ""
        metadata = item.get("metadata")
        if metadata:
@@ -170,12 +176,21 @@ class ExtraNetworksPage:
            if filename.startswith(absdir):
                local_path = filename[len(absdir):]

        # if this is true, the item must not be show in the default view, and must instead only be
        # if this is true, the item must not be shown in the default view, and must instead only be
        # shown when searching for it
        serach_only = "/." in local_path or "\\." in local_path
        if shared.opts.extra_networks_hidden_models == "Always":
            search_only = False
        else:
            search_only = "/." in local_path or "\\." in local_path

        if search_only and shared.opts.extra_networks_hidden_models == "Never":
            return ""

        sort_keys = " ".join([html.escape(f'data-sort-{k}={v}') for k, v in item.get("sort_keys", {}).items()]).strip()

        args = {
            "style": f"'display: none; {height}{width}{background_image}'",
            "background_image": background_image,
            "style": f"'display: none; {height}{width}'",
            "prompt": item.get("prompt", None),
            "tabname": json.dumps(tabname),
            "local_preview": json.dumps(item["local_preview"]),
@@ -185,17 +200,30 @@ class ExtraNetworksPage:
            "save_card_preview": '"' + html.escape(f"""return saveCardPreview(event, {json.dumps(tabname)}, {json.dumps(item["local_preview"])})""") + '"',
            "search_term": item.get("search_term", ""),
            "metadata_button": metadata_button,
            "serach_only": " search_only" if serach_only else "",
            "search_only": " search_only" if search_only else "",
            "sort_keys": sort_keys,
        }

        return self.card_page.format(**args)

    def get_sort_keys(self, path):
        """
        List of default keys used for sorting in the UI.
        """
        pth = Path(path)
        stat = pth.stat()
        return {
            "date_created": int(stat.st_ctime or 0),
            "date_modified": int(stat.st_mtime or 0),
            "name": pth.name.lower(),
        }

    def find_preview(self, path):
        """
        Find a preview PNG for a given path (without extension) and call link_preview on it.
        """

        preview_extensions = ["png", "jpg", "webp"]
        preview_extensions = ["png", "jpg", "jpeg", "webp"]
        if shared.opts.samples_format not in preview_extensions:
            preview_extensions.append(shared.opts.samples_format)

@@ -220,10 +248,19 @@ class ExtraNetworksPage:
    return None


def intialize():
def initialize():
    extra_pages.clear()


def register_default_pages():
    from modules.ui_extra_networks_textual_inversion import ExtraNetworksPageTextualInversion
    from modules.ui_extra_networks_hypernets import ExtraNetworksPageHypernetworks
    from modules.ui_extra_networks_checkpoints import ExtraNetworksPageCheckpoints
    register_page(ExtraNetworksPageTextualInversion())
    register_page(ExtraNetworksPageHypernetworks())
    register_page(ExtraNetworksPageCheckpoints())


class ExtraNetworksUi:
    def __init__(self):
        self.pages = None
@@ -263,18 +300,20 @@ def create_ui(container, button, tabname):
    ui.stored_extra_pages = pages_in_preferred_order(extra_pages.copy())
    ui.tabname = tabname

    with gr.Tabs(elem_id=tabname+"_extra_tabs") as tabs:
    with gr.Tabs(elem_id=tabname+"_extra_tabs"):
        for page in ui.stored_extra_pages:
            page_id = page.title.lower().replace(" ", "_")

            with gr.Tab(page.title, id=page_id):
                elem_id = f"{tabname}_{page_id}_cards_html"
                page_elem = gr.HTML('', elem_id=elem_id)
                page_elem = gr.HTML('Loading...', elem_id=elem_id)
                ui.pages.append(page_elem)

                page_elem.change(fn=lambda: None, _js='function(){applyExtraNetworkFilter(' + json.dumps(tabname) + '); return []}', inputs=[], outputs=[])

    gr.Textbox('', show_label=False, elem_id=tabname+"_extra_search", placeholder="Search...", visible=False)
    gr.Dropdown(choices=['Default Sort', 'Date Created', 'Date Modified', 'Name'], value='Default Sort', elem_id=tabname+"_extra_sort", multiselect=False, visible=False, show_label=False, interactive=True)
    gr.Button(up_down_symbol, elem_id=tabname+"_extra_sortorder")
    button_refresh = gr.Button('Refresh', elem_id=tabname+"_extra_refresh")

    ui.button_save_preview = gr.Button('Save preview', elem_id=tabname+"_save_preview", visible=False)
@@ -283,13 +322,24 @@ def create_ui(container, button, tabname):
    def toggle_visibility(is_visible):
        is_visible = not is_visible

        if is_visible and not ui.pages_contents:
        return is_visible, gr.update(visible=is_visible), gr.update(variant=("secondary-down" if is_visible else "secondary"))

    def fill_tabs(is_empty):
        """Creates HTML for extra networks' tabs when the extra networks button is clicked for the first time."""

        if not ui.pages_contents:
            refresh()

        return is_visible, gr.update(visible=is_visible), gr.update(variant=("secondary-down" if is_visible else "secondary")), *ui.pages_contents
        if is_empty:
            return True, *ui.pages_contents

        return True, *[gr.update() for _ in ui.pages_contents]

    state_visible = gr.State(value=False)
    button.click(fn=toggle_visibility, inputs=[state_visible], outputs=[state_visible, container, button, *ui.pages])
    button.click(fn=toggle_visibility, inputs=[state_visible], outputs=[state_visible, container, button], show_progress=False)

    state_empty = gr.State(value=True)
    button.click(fn=fill_tabs, inputs=[state_empty], outputs=[state_empty, *ui.pages], show_progress=False)

    def refresh():
        for pg in ui.stored_extra_pages:
@@ -327,18 +377,13 @@ def setup_ui(ui, gallery):

    is_allowed = False
    for extra_page in ui.stored_extra_pages:
        if any([path_is_parent(x, filename) for x in extra_page.allowed_directories_for_previews()]):
        if any(path_is_parent(x, filename) for x in extra_page.allowed_directories_for_previews()):
            is_allowed = True
            break

    assert is_allowed, f'writing to {filename} is not allowed'

    if geninfo:
        pnginfo_data = PngImagePlugin.PngInfo()
        pnginfo_data.add_text('parameters', geninfo)
        image.save(filename, pnginfo=pnginfo_data)
    else:
        image.save(filename)
    save_image_with_geninfo(image, geninfo, filename)

    return [page.create_html(ui.tabname) for page in ui.stored_extra_pages]

@@ -14,7 +14,7 @@ class ExtraNetworksPageCheckpoints(ui_extra_networks.ExtraNetworksPage):

    def list_items(self):
        checkpoint: sd_models.CheckpointInfo
        for name, checkpoint in sd_models.checkpoints_list.items():
        for index, (name, checkpoint) in enumerate(sd_models.checkpoints_list.items()):
            path, ext = os.path.splitext(checkpoint.filename)
            yield {
                "name": checkpoint.name_for_extra,
@@ -24,6 +24,8 @@ class ExtraNetworksPageCheckpoints(ui_extra_networks.ExtraNetworksPage):
                "search_term": self.search_terms_from_path(checkpoint.filename) + " " + (checkpoint.sha256 or ""),
                "onclick": '"' + html.escape(f"""return selectCheckpoint({json.dumps(name)})""") + '"',
                "local_preview": f"{path}.{shared.opts.samples_format}",
                "sort_keys": {'default': index, **self.get_sort_keys(checkpoint.filename)},

            }

    def allowed_directories_for_previews(self):

@@ -12,7 +12,7 @@ class ExtraNetworksPageHypernetworks(ui_extra_networks.ExtraNetworksPage):
        shared.reload_hypernetworks()

    def list_items(self):
        for name, path in shared.hypernetworks.items():
        for index, (name, path) in enumerate(shared.hypernetworks.items()):
            path, ext = os.path.splitext(path)

            yield {
@@ -23,6 +23,8 @@ class ExtraNetworksPageHypernetworks(ui_extra_networks.ExtraNetworksPage):
                "search_term": self.search_terms_from_path(path),
                "prompt": json.dumps(f"<hypernet:{name}:") + " + opts.extra_networks_default_multiplier + " + json.dumps(">"),
                "local_preview": f"{path}.preview.{shared.opts.samples_format}",
                "sort_keys": {'default': index, **self.get_sort_keys(path + ext)},

            }

    def allowed_directories_for_previews(self):

@@ -13,7 +13,7 @@ class ExtraNetworksPageTextualInversion(ui_extra_networks.ExtraNetworksPage):
        sd_hijack.model_hijack.embedding_db.load_textual_inversion_embeddings(force_reload=True)

    def list_items(self):
        for embedding in sd_hijack.model_hijack.embedding_db.word_embeddings.values():
        for index, embedding in enumerate(sd_hijack.model_hijack.embedding_db.word_embeddings.values()):
            path, ext = os.path.splitext(embedding.filename)
            yield {
                "name": embedding.name,
@@ -23,6 +23,8 @@ class ExtraNetworksPageTextualInversion(ui_extra_networks.ExtraNetworksPage):
                "search_term": self.search_terms_from_path(embedding.filename),
                "prompt": json.dumps(embedding.name),
                "local_preview": f"{path}.preview.{shared.opts.samples_format}",
                "sort_keys": {'default': index, **self.get_sort_keys(embedding.filename)},

            }

    def allowed_directories_for_previews(self):

69
modules/ui_gradio_extensions.py
Normal file
@@ -0,0 +1,69 @@
import os
import gradio as gr

from modules import localization, shared, scripts
from modules.paths import script_path, data_path


def webpath(fn):
    if fn.startswith(script_path):
        web_path = os.path.relpath(fn, script_path).replace('\\', '/')
    else:
        web_path = os.path.abspath(fn)

    return f'file={web_path}?{os.path.getmtime(fn)}'
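The mtime query string appended by webpath() acts as a cache-buster: the browser re-downloads a script or stylesheet only when the file on disk changes. A runnable sketch of the same idea (the second parameter stands in for script_path):

import os
import tempfile

def webpath(fn, script_path):
    # mirrors the function above, parameterised for the example
    if fn.startswith(script_path):
        web_path = os.path.relpath(fn, script_path).replace('\\', '/')
    else:
        web_path = os.path.abspath(fn)
    return f'file={web_path}?{os.path.getmtime(fn)}'

with tempfile.TemporaryDirectory() as root:
    fn = os.path.join(root, "script.js")
    open(fn, "w").close()
    print(webpath(fn, root))  # e.g. file=script.js?1686307200.0 -- changes on every edit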


def javascript_html():
    # Ensure localization is in `window` before scripts
    head = f'<script type="text/javascript">{localization.localization_js(shared.opts.localization)}</script>\n'

    script_js = os.path.join(script_path, "script.js")
    head += f'<script type="text/javascript" src="{webpath(script_js)}"></script>\n'

    for script in scripts.list_scripts("javascript", ".js"):
        head += f'<script type="text/javascript" src="{webpath(script.path)}"></script>\n'

    for script in scripts.list_scripts("javascript", ".mjs"):
        head += f'<script type="module" src="{webpath(script.path)}"></script>\n'

    if shared.cmd_opts.theme:
        head += f'<script type="text/javascript">set_theme(\"{shared.cmd_opts.theme}\");</script>\n'

    return head


def css_html():
    head = ""

    def stylesheet(fn):
        return f'<link rel="stylesheet" property="stylesheet" href="{webpath(fn)}">'

    for cssfile in scripts.list_files_with_name("style.css"):
        if not os.path.isfile(cssfile):
            continue

        head += stylesheet(cssfile)

    if os.path.exists(os.path.join(data_path, "user.css")):
        head += stylesheet(os.path.join(data_path, "user.css"))

    return head


def reload_javascript():
    js = javascript_html()
    css = css_html()

    def template_response(*args, **kwargs):
        res = shared.GradioTemplateResponseOriginal(*args, **kwargs)
        res.body = res.body.replace(b'</head>', f'{js}</head>'.encode("utf8"))
        res.body = res.body.replace(b'</body>', f'{css}</body>'.encode("utf8"))
        res.init_headers()
        return res

    gr.routes.templates.TemplateResponse = template_response


if not hasattr(shared, 'GradioTemplateResponseOriginal'):
    shared.GradioTemplateResponseOriginal = gr.routes.templates.TemplateResponse
210
modules/ui_loadsave.py
Normal file
@@ -0,0 +1,210 @@
import json
import os

import gradio as gr

from modules import errors
from modules.ui_components import ToolButton


class UiLoadsave:
    """allows saving and restoring default values for gradio components"""

    def __init__(self, filename):
        self.filename = filename
        self.ui_settings = {}
        self.component_mapping = {}
        self.error_loading = False
        self.finalized_ui = False

        self.ui_defaults_view = None
        self.ui_defaults_apply = None
        self.ui_defaults_review = None

        try:
            if os.path.exists(self.filename):
                self.ui_settings = self.read_from_file()
        except Exception as e:
            self.error_loading = True
            errors.display(e, "loading settings")

    def add_component(self, path, x):
        """adds component to the registry of tracked components"""

        assert not self.finalized_ui

        def apply_field(obj, field, condition=None, init_field=None):
            key = f"{path}/{field}"

            if getattr(obj, 'custom_script_source', None) is not None:
                key = f"customscript/{obj.custom_script_source}/{key}"

            if getattr(obj, 'do_not_save_to_config', False):
                return

            saved_value = self.ui_settings.get(key, None)
            if saved_value is None:
                self.ui_settings[key] = getattr(obj, field)
            elif condition and not condition(saved_value):
                pass
            else:
                setattr(obj, field, saved_value)
                if init_field is not None:
                    init_field(saved_value)

            if field == 'value' and key not in self.component_mapping:
                self.component_mapping[key] = x

        if type(x) in [gr.Slider, gr.Radio, gr.Checkbox, gr.Textbox, gr.Number, gr.Dropdown, ToolButton, gr.Button] and x.visible:
            apply_field(x, 'visible')

        if type(x) == gr.Slider:
            apply_field(x, 'value')
            apply_field(x, 'minimum')
            apply_field(x, 'maximum')
            apply_field(x, 'step')

        if type(x) == gr.Radio:
            apply_field(x, 'value', lambda val: val in x.choices)

        if type(x) == gr.Checkbox:
            apply_field(x, 'value')

        if type(x) == gr.Textbox:
            apply_field(x, 'value')

        if type(x) == gr.Number:
            apply_field(x, 'value')

        if type(x) == gr.Dropdown:
            def check_dropdown(val):
                if getattr(x, 'multiselect', False):
                    return all(value in x.choices for value in val)
                else:
                    return val in x.choices

            apply_field(x, 'value', check_dropdown, getattr(x, 'init_field', None))

        def check_tab_id(tab_id):
            tab_items = list(filter(lambda e: isinstance(e, gr.TabItem), x.children))
            if type(tab_id) == str:
                tab_ids = [t.id for t in tab_items]
                return tab_id in tab_ids
            elif type(tab_id) == int:
                return 0 <= tab_id < len(tab_items)
            else:
                return False

        if type(x) == gr.Tabs:
            apply_field(x, 'selected', check_tab_id)

    def add_block(self, x, path=""):
        """adds all components inside a gradio block x to the registry of tracked components"""

        if hasattr(x, 'children'):
            if isinstance(x, gr.Tabs) and x.elem_id is not None:
                # Tabs element can't have a label, have to use elem_id instead
                self.add_component(f"{path}/Tabs@{x.elem_id}", x)
            for c in x.children:
                self.add_block(c, path)
        elif x.label is not None:
            self.add_component(f"{path}/{x.label}", x)
        elif isinstance(x, gr.Button) and x.value is not None:
            self.add_component(f"{path}/{x.value}", x)

    def read_from_file(self):
        with open(self.filename, "r", encoding="utf8") as file:
            return json.load(file)

    def write_to_file(self, current_ui_settings):
        with open(self.filename, "w", encoding="utf8") as file:
            json.dump(current_ui_settings, file, indent=4)

    def dump_defaults(self):
        """saves default values to a file unless the file is present and there was an error loading default values at start"""

        if self.error_loading and os.path.exists(self.filename):
            return

        self.write_to_file(self.ui_settings)

    def iter_changes(self, current_ui_settings, values):
        """
        given a dictionary with defaults from a file and current values from gradio elements, returns
        an iterator over tuples of values that are not the same between the file and the current;
        tuple contents are: path, old value, new value
        """

        for (path, component), new_value in zip(self.component_mapping.items(), values):
            old_value = current_ui_settings.get(path)

            choices = getattr(component, 'choices', None)
            if isinstance(new_value, int) and choices:
                if new_value >= len(choices):
                    continue

                new_value = choices[new_value]

            if new_value == old_value:
                continue

            if old_value is None and new_value == '' or new_value == []:
                continue

            yield path, old_value, new_value

    def ui_view(self, *values):
        text = ["<table><thead><tr><th>Path</th><th>Old value</th><th>New value</th></thead><tbody>"]

        for path, old_value, new_value in self.iter_changes(self.read_from_file(), values):
            if old_value is None:
                old_value = "<span class='ui-defaults-none'>None</span>"

            text.append(f"<tr><td>{path}</td><td>{old_value}</td><td>{new_value}</td></tr>")

        if len(text) == 1:
            text.append("<tr><td colspan=3>No changes</td></tr>")

        text.append("</tbody>")
        return "".join(text)

    def ui_apply(self, *values):
        num_changed = 0

        current_ui_settings = self.read_from_file()

        for path, _, new_value in self.iter_changes(current_ui_settings.copy(), values):
            num_changed += 1
            current_ui_settings[path] = new_value

        if num_changed == 0:
            return "No changes."

        self.write_to_file(current_ui_settings)

        return f"Wrote {num_changed} changes."

    def create_ui(self):
        """creates ui elements for editing defaults UI, without adding any logic to them"""

        gr.HTML(
            f"This page allows you to change default values in UI elements on other tabs.<br />"
            f"Make your changes, press 'View changes' to review the changed default values,<br />"
            f"then press 'Apply' to write them to {self.filename}.<br />"
            f"New defaults will apply after you restart the UI.<br />"
        )

        with gr.Row():
            self.ui_defaults_view = gr.Button(value='View changes', elem_id="ui_defaults_view", variant="secondary")
            self.ui_defaults_apply = gr.Button(value='Apply', elem_id="ui_defaults_apply", variant="primary")

        self.ui_defaults_review = gr.HTML("")

    def setup_ui(self):
        """adds logic to elements created with create_ui; all add_block calls must be made before this"""

        assert not self.finalized_ui
        self.finalized_ui = True

        self.ui_defaults_view.click(fn=self.ui_view, inputs=list(self.component_mapping.values()), outputs=[self.ui_defaults_review])
        self.ui_defaults_apply.click(fn=self.ui_apply, inputs=list(self.component_mapping.values()), outputs=[self.ui_defaults_review])
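A hedged sketch of the lifecycle this class expects (the real call sites live in modules/ui.py; `demo` and the textbox here are illustrative):

import gradio as gr

loadsave = UiLoadsave("ui-config.json")

with gr.Blocks() as demo:
    txt = gr.Textbox(label="Prompt")
    loadsave.create_ui()      # the View changes / Apply buttons
    loadsave.add_block(demo)  # register every labelled component created above
    loadsave.setup_ui()       # attach click handlers; all add_block calls must precede this

loadsave.dump_defaults()      # persist defaults for components not yet in the file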
@@ -1,5 +1,5 @@
import gradio as gr
from modules import scripts_postprocessing, scripts, shared, gfpgan_model, codeformer_model, ui_common, postprocessing, call_queue
from modules import scripts, shared, ui_common, postprocessing, call_queue
import modules.generation_parameters_copypaste as parameters_copypaste

289
modules/ui_settings.py
Normal file
@@ -0,0 +1,289 @@
import gradio as gr

from modules import ui_common, shared, script_callbacks, scripts, sd_models, sysinfo
from modules.call_queue import wrap_gradio_call
from modules.shared import opts
from modules.ui_components import FormRow
from modules.ui_gradio_extensions import reload_javascript


def get_value_for_setting(key):
    value = getattr(opts, key)

    info = opts.data_labels[key]
    args = info.component_args() if callable(info.component_args) else info.component_args or {}
    args = {k: v for k, v in args.items() if k not in {'precision'}}

    return gr.update(value=value, **args)


def create_setting_component(key, is_quicksettings=False):
    def fun():
        return opts.data[key] if key in opts.data else opts.data_labels[key].default

    info = opts.data_labels[key]
    t = type(info.default)

    args = info.component_args() if callable(info.component_args) else info.component_args

    if info.component is not None:
        comp = info.component
    elif t == str:
        comp = gr.Textbox
    elif t == int:
        comp = gr.Number
    elif t == bool:
        comp = gr.Checkbox
    else:
        raise Exception(f'bad options item type: {t} for key {key}')

    elem_id = f"setting_{key}"

    if info.refresh is not None:
        if is_quicksettings:
            res = comp(label=info.label, value=fun(), elem_id=elem_id, **(args or {}))
            ui_common.create_refresh_button(res, info.refresh, info.component_args, f"refresh_{key}")
        else:
            with FormRow():
                res = comp(label=info.label, value=fun(), elem_id=elem_id, **(args or {}))
                ui_common.create_refresh_button(res, info.refresh, info.component_args, f"refresh_{key}")
    else:
        res = comp(label=info.label, value=fun(), elem_id=elem_id, **(args or {}))

    return res


class UiSettings:
    submit = None
    result = None
    interface = None
    components = None
    component_dict = None
    dummy_component = None
    quicksettings_list = None
    quicksettings_names = None
    text_settings = None

    def run_settings(self, *args):
        changed = []

        for key, value, comp in zip(opts.data_labels.keys(), args, self.components):
            assert comp == self.dummy_component or opts.same_type(value, opts.data_labels[key].default), f"Bad value for setting {key}: {value}; expecting {type(opts.data_labels[key].default).__name__}"

        for key, value, comp in zip(opts.data_labels.keys(), args, self.components):
            if comp == self.dummy_component:
                continue

            if opts.set(key, value):
                changed.append(key)

        try:
            opts.save(shared.config_filename)
        except RuntimeError:
            return opts.dumpjson(), f'{len(changed)} settings changed without save: {", ".join(changed)}.'
        return opts.dumpjson(), f'{len(changed)} settings changed{": " if changed else ""}{", ".join(changed)}.'

    def run_settings_single(self, value, key):
        if not opts.same_type(value, opts.data_labels[key].default):
            return gr.update(visible=True), opts.dumpjson()

        if not opts.set(key, value):
            return gr.update(value=getattr(opts, key)), opts.dumpjson()

        opts.save(shared.config_filename)

        return get_value_for_setting(key), opts.dumpjson()

    def create_ui(self, loadsave, dummy_component):
        self.components = []
        self.component_dict = {}
        self.dummy_component = dummy_component

        shared.settings_components = self.component_dict

        script_callbacks.ui_settings_callback()
        opts.reorder()

        with gr.Blocks(analytics_enabled=False) as settings_interface:
            with gr.Row():
                with gr.Column(scale=6):
                    self.submit = gr.Button(value="Apply settings", variant='primary', elem_id="settings_submit")
                with gr.Column():
                    restart_gradio = gr.Button(value='Reload UI', variant='primary', elem_id="settings_restart_gradio")

            self.result = gr.HTML(elem_id="settings_result")

            self.quicksettings_names = opts.quicksettings_list
            self.quicksettings_names = {x: i for i, x in enumerate(self.quicksettings_names) if x != 'quicksettings'}

            self.quicksettings_list = []

            previous_section = None
            current_tab = None
            current_row = None
            with gr.Tabs(elem_id="settings"):
                for i, (k, item) in enumerate(opts.data_labels.items()):
                    section_must_be_skipped = item.section[0] is None

                    if previous_section != item.section and not section_must_be_skipped:
                        elem_id, text = item.section

                        if current_tab is not None:
                            current_row.__exit__()
                            current_tab.__exit__()

                        gr.Group()
                        current_tab = gr.TabItem(elem_id=f"settings_{elem_id}", label=text)
                        current_tab.__enter__()
                        current_row = gr.Column(variant='compact')
                        current_row.__enter__()

                        previous_section = item.section

                    if k in self.quicksettings_names and not shared.cmd_opts.freeze_settings:
                        self.quicksettings_list.append((i, k, item))
                        self.components.append(dummy_component)
                    elif section_must_be_skipped:
                        self.components.append(dummy_component)
                    else:
                        component = create_setting_component(k)
                        self.component_dict[k] = component
                        self.components.append(component)

                if current_tab is not None:
                    current_row.__exit__()
                    current_tab.__exit__()

                with gr.TabItem("Defaults", id="defaults", elem_id="settings_tab_defaults"):
                    loadsave.create_ui()

                with gr.TabItem("Sysinfo", id="sysinfo", elem_id="settings_tab_sysinfo"):
                    gr.HTML('<a href="./internal/sysinfo-download" class="sysinfo_big_link" download>Download system info</a><br /><a href="./internal/sysinfo">(or open as text in a new page)</a>', elem_id="sysinfo_download")

                    with gr.Row():
                        with gr.Column(scale=1):
                            sysinfo_check_file = gr.File(label="Check system info for validity", type='binary')
                        with gr.Column(scale=1):
                            sysinfo_check_output = gr.HTML("", elem_id="sysinfo_validity")
                        with gr.Column(scale=100):
                            pass

                with gr.TabItem("Actions", id="actions", elem_id="settings_tab_actions"):
                    request_notifications = gr.Button(value='Request browser notifications', elem_id="request_notifications")
                    download_localization = gr.Button(value='Download localization template', elem_id="download_localization")
                    reload_script_bodies = gr.Button(value='Reload custom script bodies (No ui updates, No restart)', variant='secondary', elem_id="settings_reload_script_bodies")
                    with gr.Row():
                        unload_sd_model = gr.Button(value='Unload SD checkpoint to free VRAM', elem_id="sett_unload_sd_model")
                        reload_sd_model = gr.Button(value='Reload the last SD checkpoint back into VRAM', elem_id="sett_reload_sd_model")

                with gr.TabItem("Licenses", id="licenses", elem_id="settings_tab_licenses"):
                    gr.HTML(shared.html("licenses.html"), elem_id="licenses")

                gr.Button(value="Show all pages", elem_id="settings_show_all_pages")

            self.text_settings = gr.Textbox(elem_id="settings_json", value=lambda: opts.dumpjson(), visible=False)

            unload_sd_model.click(
                fn=sd_models.unload_model_weights,
                inputs=[],
                outputs=[]
            )

            reload_sd_model.click(
                fn=sd_models.reload_model_weights,
                inputs=[],
                outputs=[]
            )

            request_notifications.click(
                fn=lambda: None,
                inputs=[],
                outputs=[],
                _js='function(){}'
            )

            download_localization.click(
                fn=lambda: None,
                inputs=[],
                outputs=[],
                _js='download_localization'
            )

            def reload_scripts():
                scripts.reload_script_body_only()
                reload_javascript()  # need to refresh the html page

            reload_script_bodies.click(
                fn=reload_scripts,
                inputs=[],
                outputs=[]
            )

            restart_gradio.click(
                fn=shared.state.request_restart,
                _js='restart_reload',
                inputs=[],
                outputs=[],
            )

            def check_file(x):
                if x is None:
                    return ''

                if sysinfo.check(x.decode('utf8', errors='ignore')):
                    return 'Valid'

                return 'Invalid'

            sysinfo_check_file.change(
                fn=check_file,
                inputs=[sysinfo_check_file],
                outputs=[sysinfo_check_output],
            )

        self.interface = settings_interface

    def add_quicksettings(self):
        with gr.Row(elem_id="quicksettings", variant="compact"):
            for _i, k, _item in sorted(self.quicksettings_list, key=lambda x: self.quicksettings_names.get(x[1], x[0])):
                component = create_setting_component(k, is_quicksettings=True)
                self.component_dict[k] = component

    def add_functionality(self, demo):
        self.submit.click(
            fn=wrap_gradio_call(lambda *args: self.run_settings(*args), extra_outputs=[gr.update()]),
            inputs=self.components,
            outputs=[self.text_settings, self.result],
        )

        for _i, k, _item in self.quicksettings_list:
            component = self.component_dict[k]
            info = opts.data_labels[k]

            change_handler = component.release if hasattr(component, 'release') else component.change
            change_handler(
                fn=lambda value, k=k: self.run_settings_single(value, key=k),
                inputs=[component],
                outputs=[component, self.text_settings],
                show_progress=info.refresh is not None,
            )

        button_set_checkpoint = gr.Button('Change checkpoint', elem_id='change_checkpoint', visible=False)
        button_set_checkpoint.click(
            fn=lambda value, _: self.run_settings_single(value, key='sd_model_checkpoint'),
            _js="function(v){ var res = desiredCheckpointName; desiredCheckpointName = ''; return [res || v, null]; }",
            inputs=[self.component_dict['sd_model_checkpoint'], self.dummy_component],
            outputs=[self.component_dict['sd_model_checkpoint'], self.text_settings],
        )

        component_keys = [k for k in opts.data_labels.keys() if k in self.component_dict]

        def get_settings_values():
            return [get_value_for_setting(key) for key in component_keys]

        demo.load(
            fn=get_settings_values,
            inputs=[],
            outputs=[self.component_dict[k] for k in component_keys],
            queue=False,
        )
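A hedged sketch of the call order this class expects from the main UI builder (the real wiring is in modules/ui.py; `demo` and the dummy textbox are illustrative stand-ins):

import gradio as gr
from modules.ui_loadsave import UiLoadsave
from modules.ui_settings import UiSettings

settings = UiSettings()
loadsave = UiLoadsave("ui-config.json")

with gr.Blocks() as demo:
    dummy = gr.Textbox(visible=False)    # placeholder component for skipped settings
    settings.create_ui(loadsave, dummy)  # builds the whole Settings tab
    settings.add_quicksettings()         # the pinned settings row at the top
    settings.add_functionality(demo)     # wires Apply, quicksettings handlers, demo.load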
@@ -3,7 +3,7 @@ import tempfile
from collections import namedtuple
from pathlib import Path

import gradio as gr
import gradio.components

from PIL import PngImagePlugin

@@ -23,7 +23,7 @@ def register_tmp_file(gradio, filename):

def check_tmp_file(gradio, filename):
    if hasattr(gradio, 'temp_file_sets'):
        return any([filename in fileset for fileset in gradio.temp_file_sets])
        return any(filename in fileset for fileset in gradio.temp_file_sets)

    if hasattr(gradio, 'temp_dirs'):
        return any(Path(temp_dir).resolve() in Path(filename).resolve().parents for temp_dir in gradio.temp_dirs)
@@ -31,13 +31,16 @@ def check_tmp_file(gradio, filename):
    return False


def save_pil_to_file(pil_image, dir=None):
def save_pil_to_file(self, pil_image, dir=None, format="png"):
    already_saved_as = getattr(pil_image, 'already_saved_as', None)
    if already_saved_as and os.path.isfile(already_saved_as):
        register_tmp_file(shared.demo, already_saved_as)
        filename = already_saved_as

        file_obj = Savedfile(f'{already_saved_as}?{os.path.getmtime(already_saved_as)}')
        return file_obj
        if not shared.opts.save_images_add_number:
            filename += f'?{os.path.getmtime(already_saved_as)}'

        return filename

    if shared.opts.temp_dir != "":
        dir = shared.opts.temp_dir
@@ -51,11 +54,11 @@ def save_pil_to_file(self, pil_image, dir=None, format="png"):

    file_obj = tempfile.NamedTemporaryFile(delete=False, suffix=".png", dir=dir)
    pil_image.save(file_obj, pnginfo=(metadata if use_metadata else None))
    return file_obj
    return file_obj.name


# override save to file function so that it also writes PNG info
gr.processing_utils.save_pil_to_file = save_pil_to_file
gradio.components.IOComponent.pil_to_temp_file = save_pil_to_file


def on_tmpdir_changed():
@@ -72,7 +75,7 @@ def cleanup_tmpdr():
    if temp_dir == "" or not os.path.isdir(temp_dir):
        return

    for root, dirs, files in os.walk(temp_dir, topdown=False):
    for root, _, files in os.walk(temp_dir, topdown=False):
        for name in files:
            _, extension = os.path.splitext(name)
            if extension != ".png":

@@ -2,8 +2,6 @@ import os
from abc import abstractmethod

import PIL
import numpy as np
import torch
from PIL import Image

import modules.shared
@@ -36,6 +34,7 @@ class Upscaler:
        self.half = not modules.shared.cmd_opts.no_half
        self.pre_pad = 0
        self.mod_scale = None
        self.model_download_path = None

        if self.model_path is None and self.name:
            self.model_path = os.path.join(shared.models_path, self.name)
@@ -43,9 +42,9 @@ class Upscaler:
            os.makedirs(self.model_path, exist_ok=True)

        try:
            import cv2
            import cv2  # noqa: F401
            self.can_tile = True
        except:
        except Exception:
            pass

    @abstractmethod
@@ -54,10 +53,10 @@ class Upscaler:

    def upscale(self, img: PIL.Image, scale, selected_model: str = None):
        self.scale = scale
        dest_w = int(img.width * scale)
        dest_h = int(img.height * scale)
        dest_w = int((img.width * scale) // 8 * 8)
        dest_h = int((img.height * scale) // 8 * 8)

        for i in range(3):
        for _ in range(3):
            shape = (img.width, img.height)

            img = self.do_upscale(img, selected_model)
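The new dest_w/dest_h computation snaps the target size down to a multiple of 8, which the SD pipeline expects; a quick check of the arithmetic:

scale = 2
width = 513
dest_w = int((width * scale) // 8 * 8)
print(dest_w)              # 1024, not 1026
print(int(width * scale))  # 1026 -- the old behaviour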
@@ -78,7 +77,7 @@ class Upscaler:
        pass

    def find_models(self, ext_filter=None) -> list:
        return modelloader.load_models(model_path=self.model_path, model_url=self.model_url, command_path=self.user_path)
        return modelloader.load_models(model_path=self.model_path, model_url=self.model_url, command_path=self.user_path, ext_filter=ext_filter)

    def update_status(self, prompt):
        print(f"\nextras: {prompt}", file=shared.progress_print_out)

@@ -1,4 +1,4 @@
from transformers import BertPreTrainedModel,BertModel,BertConfig
from transformers import BertPreTrainedModel, BertConfig
import torch.nn as nn
import torch
from transformers.models.xlm_roberta.configuration_xlm_roberta import XLMRobertaConfig
@@ -28,7 +28,7 @@ class BertSeriesModelWithTransformation(BertPreTrainedModel):
    config_class = BertSeriesConfig

    def __init__(self, config=None, **kargs):
        # modify initialization for autoloading
        # modify initialization for autoloading
        if config is None:
            config = XLMRobertaConfig()
            config.attention_probs_dropout_prob= 0.1
@@ -74,7 +74,7 @@ class BertSeriesModelWithTransformation(BertPreTrainedModel):
        text["attention_mask"] = torch.tensor(
            text['attention_mask']).to(device)
        features = self(**text)
        return features['projection_state']
        return features['projection_state']

    def forward(
        self,
@@ -134,4 +134,4 @@ class BertSeriesModelWithTransformation(BertPreTrainedModel):

class RobertaSeriesModelWithTransformation(BertSeriesModelWithTransformation):
    base_model_prefix = 'roberta'
    config_class= RobertaSeriesConfig
    config_class= RobertaSeriesConfig
