mirror of
https://github.com/AUTOMATIC1111/stable-diffusion-webui.git
synced 2026-02-08 09:00:12 +00:00
Merge branch 'dev' into improve-frontend-responsiveness
This commit is contained in:
BIN
modules/Roboto-Regular.ttf
Normal file
BIN
modules/Roboto-Regular.ttf
Normal file
Binary file not shown.
@@ -6,7 +6,6 @@ import uvicorn
|
||||
import gradio as gr
|
||||
from threading import Lock
|
||||
from io import BytesIO
|
||||
from gradio.processing_utils import decode_base64_to_file
|
||||
from fastapi import APIRouter, Depends, FastAPI, Request, Response
|
||||
from fastapi.security import HTTPBasic, HTTPBasicCredentials
|
||||
from fastapi.exceptions import HTTPException
|
||||
@@ -16,7 +15,8 @@ from secrets import compare_digest
|
||||
|
||||
import modules.shared as shared
|
||||
from modules import sd_samplers, deepbooru, sd_hijack, images, scripts, ui, postprocessing
|
||||
from modules.api.models import *
|
||||
from modules.api import models
|
||||
from modules.shared import opts
|
||||
from modules.processing import StableDiffusionProcessingTxt2Img, StableDiffusionProcessingImg2Img, process_images
|
||||
from modules.textual_inversion.textual_inversion import create_embedding, train_embedding
|
||||
from modules.textual_inversion.preprocess import preprocess
|
||||
@@ -26,21 +26,24 @@ from modules.sd_models import checkpoints_list, unload_model_weights, reload_mod
|
||||
from modules.sd_models_config import find_checkpoint_config_near_filename
|
||||
from modules.realesrgan_model import get_realesrgan_models
|
||||
from modules import devices
|
||||
from typing import List
|
||||
from typing import Dict, List, Any
|
||||
import piexif
|
||||
import piexif.helper
|
||||
|
||||
|
||||
def upscaler_to_index(name: str):
|
||||
try:
|
||||
return [x.name.lower() for x in shared.sd_upscalers].index(name.lower())
|
||||
except:
|
||||
raise HTTPException(status_code=400, detail=f"Invalid upscaler, needs to be one of these: {' , '.join([x.name for x in sd_upscalers])}")
|
||||
except Exception as e:
|
||||
raise HTTPException(status_code=400, detail=f"Invalid upscaler, needs to be one of these: {' , '.join([x.name for x in shared.sd_upscalers])}") from e
|
||||
|
||||
|
||||
def script_name_to_index(name, scripts):
|
||||
try:
|
||||
return [script.title().lower() for script in scripts].index(name.lower())
|
||||
except:
|
||||
raise HTTPException(status_code=422, detail=f"Script '{name}' not found")
|
||||
except Exception as e:
|
||||
raise HTTPException(status_code=422, detail=f"Script '{name}' not found") from e
|
||||
|
||||
|
||||
def validate_sampler_name(name):
|
||||
config = sd_samplers.all_samplers_map.get(name, None)
|
||||
@@ -49,20 +52,23 @@ def validate_sampler_name(name):
|
||||
|
||||
return name
|
||||
|
||||
|
||||
def setUpscalers(req: dict):
|
||||
reqDict = vars(req)
|
||||
reqDict['extras_upscaler_1'] = reqDict.pop('upscaler_1', None)
|
||||
reqDict['extras_upscaler_2'] = reqDict.pop('upscaler_2', None)
|
||||
return reqDict
|
||||
|
||||
|
||||
def decode_base64_to_image(encoding):
|
||||
if encoding.startswith("data:image/"):
|
||||
encoding = encoding.split(";")[1].split(",")[1]
|
||||
try:
|
||||
image = Image.open(BytesIO(base64.b64decode(encoding)))
|
||||
return image
|
||||
except Exception as err:
|
||||
raise HTTPException(status_code=500, detail="Invalid encoded image")
|
||||
except Exception as e:
|
||||
raise HTTPException(status_code=500, detail="Invalid encoded image") from e
|
||||
|
||||
|
||||
def encode_pil_to_base64(image):
|
||||
with io.BytesIO() as output_bytes:
|
||||
@@ -93,6 +99,7 @@ def encode_pil_to_base64(image):
|
||||
|
||||
return base64.b64encode(bytes_data)
|
||||
|
||||
|
||||
def api_middleware(app: FastAPI):
|
||||
rich_available = True
|
||||
try:
|
||||
@@ -100,7 +107,7 @@ def api_middleware(app: FastAPI):
|
||||
import starlette # importing just so it can be placed on silent list
|
||||
from rich.console import Console
|
||||
console = Console()
|
||||
except:
|
||||
except Exception:
|
||||
import traceback
|
||||
rich_available = False
|
||||
|
||||
@@ -131,8 +138,8 @@ def api_middleware(app: FastAPI):
|
||||
"body": vars(e).get('body', ''),
|
||||
"errors": str(e),
|
||||
}
|
||||
print(f"API error: {request.method}: {request.url} {err}")
|
||||
if not isinstance(e, HTTPException): # do not print backtrace on known httpexceptions
|
||||
print(f"API error: {request.method}: {request.url} {err}")
|
||||
if rich_available:
|
||||
console.print_exception(show_locals=True, max_frames=2, extra_lines=1, suppress=[anyio, starlette], word_wrap=False, width=min([console.width, 200]))
|
||||
else:
|
||||
@@ -158,7 +165,7 @@ def api_middleware(app: FastAPI):
|
||||
class Api:
|
||||
def __init__(self, app: FastAPI, queue_lock: Lock):
|
||||
if shared.cmd_opts.api_auth:
|
||||
self.credentials = dict()
|
||||
self.credentials = {}
|
||||
for auth in shared.cmd_opts.api_auth.split(","):
|
||||
user, password = auth.split(":")
|
||||
self.credentials[user] = password
|
||||
@@ -167,36 +174,37 @@ class Api:
|
||||
self.app = app
|
||||
self.queue_lock = queue_lock
|
||||
api_middleware(self.app)
|
||||
self.add_api_route("/sdapi/v1/txt2img", self.text2imgapi, methods=["POST"], response_model=TextToImageResponse)
|
||||
self.add_api_route("/sdapi/v1/img2img", self.img2imgapi, methods=["POST"], response_model=ImageToImageResponse)
|
||||
self.add_api_route("/sdapi/v1/extra-single-image", self.extras_single_image_api, methods=["POST"], response_model=ExtrasSingleImageResponse)
|
||||
self.add_api_route("/sdapi/v1/extra-batch-images", self.extras_batch_images_api, methods=["POST"], response_model=ExtrasBatchImagesResponse)
|
||||
self.add_api_route("/sdapi/v1/png-info", self.pnginfoapi, methods=["POST"], response_model=PNGInfoResponse)
|
||||
self.add_api_route("/sdapi/v1/progress", self.progressapi, methods=["GET"], response_model=ProgressResponse)
|
||||
self.add_api_route("/sdapi/v1/txt2img", self.text2imgapi, methods=["POST"], response_model=models.TextToImageResponse)
|
||||
self.add_api_route("/sdapi/v1/img2img", self.img2imgapi, methods=["POST"], response_model=models.ImageToImageResponse)
|
||||
self.add_api_route("/sdapi/v1/extra-single-image", self.extras_single_image_api, methods=["POST"], response_model=models.ExtrasSingleImageResponse)
|
||||
self.add_api_route("/sdapi/v1/extra-batch-images", self.extras_batch_images_api, methods=["POST"], response_model=models.ExtrasBatchImagesResponse)
|
||||
self.add_api_route("/sdapi/v1/png-info", self.pnginfoapi, methods=["POST"], response_model=models.PNGInfoResponse)
|
||||
self.add_api_route("/sdapi/v1/progress", self.progressapi, methods=["GET"], response_model=models.ProgressResponse)
|
||||
self.add_api_route("/sdapi/v1/interrogate", self.interrogateapi, methods=["POST"])
|
||||
self.add_api_route("/sdapi/v1/interrupt", self.interruptapi, methods=["POST"])
|
||||
self.add_api_route("/sdapi/v1/skip", self.skip, methods=["POST"])
|
||||
self.add_api_route("/sdapi/v1/options", self.get_config, methods=["GET"], response_model=OptionsModel)
|
||||
self.add_api_route("/sdapi/v1/options", self.get_config, methods=["GET"], response_model=models.OptionsModel)
|
||||
self.add_api_route("/sdapi/v1/options", self.set_config, methods=["POST"])
|
||||
self.add_api_route("/sdapi/v1/cmd-flags", self.get_cmd_flags, methods=["GET"], response_model=FlagsModel)
|
||||
self.add_api_route("/sdapi/v1/samplers", self.get_samplers, methods=["GET"], response_model=List[SamplerItem])
|
||||
self.add_api_route("/sdapi/v1/upscalers", self.get_upscalers, methods=["GET"], response_model=List[UpscalerItem])
|
||||
self.add_api_route("/sdapi/v1/sd-models", self.get_sd_models, methods=["GET"], response_model=List[SDModelItem])
|
||||
self.add_api_route("/sdapi/v1/hypernetworks", self.get_hypernetworks, methods=["GET"], response_model=List[HypernetworkItem])
|
||||
self.add_api_route("/sdapi/v1/face-restorers", self.get_face_restorers, methods=["GET"], response_model=List[FaceRestorerItem])
|
||||
self.add_api_route("/sdapi/v1/realesrgan-models", self.get_realesrgan_models, methods=["GET"], response_model=List[RealesrganItem])
|
||||
self.add_api_route("/sdapi/v1/prompt-styles", self.get_prompt_styles, methods=["GET"], response_model=List[PromptStyleItem])
|
||||
self.add_api_route("/sdapi/v1/embeddings", self.get_embeddings, methods=["GET"], response_model=EmbeddingsResponse)
|
||||
self.add_api_route("/sdapi/v1/cmd-flags", self.get_cmd_flags, methods=["GET"], response_model=models.FlagsModel)
|
||||
self.add_api_route("/sdapi/v1/samplers", self.get_samplers, methods=["GET"], response_model=List[models.SamplerItem])
|
||||
self.add_api_route("/sdapi/v1/upscalers", self.get_upscalers, methods=["GET"], response_model=List[models.UpscalerItem])
|
||||
self.add_api_route("/sdapi/v1/sd-models", self.get_sd_models, methods=["GET"], response_model=List[models.SDModelItem])
|
||||
self.add_api_route("/sdapi/v1/hypernetworks", self.get_hypernetworks, methods=["GET"], response_model=List[models.HypernetworkItem])
|
||||
self.add_api_route("/sdapi/v1/face-restorers", self.get_face_restorers, methods=["GET"], response_model=List[models.FaceRestorerItem])
|
||||
self.add_api_route("/sdapi/v1/realesrgan-models", self.get_realesrgan_models, methods=["GET"], response_model=List[models.RealesrganItem])
|
||||
self.add_api_route("/sdapi/v1/prompt-styles", self.get_prompt_styles, methods=["GET"], response_model=List[models.PromptStyleItem])
|
||||
self.add_api_route("/sdapi/v1/embeddings", self.get_embeddings, methods=["GET"], response_model=models.EmbeddingsResponse)
|
||||
self.add_api_route("/sdapi/v1/refresh-checkpoints", self.refresh_checkpoints, methods=["POST"])
|
||||
self.add_api_route("/sdapi/v1/create/embedding", self.create_embedding, methods=["POST"], response_model=CreateResponse)
|
||||
self.add_api_route("/sdapi/v1/create/hypernetwork", self.create_hypernetwork, methods=["POST"], response_model=CreateResponse)
|
||||
self.add_api_route("/sdapi/v1/preprocess", self.preprocess, methods=["POST"], response_model=PreprocessResponse)
|
||||
self.add_api_route("/sdapi/v1/train/embedding", self.train_embedding, methods=["POST"], response_model=TrainResponse)
|
||||
self.add_api_route("/sdapi/v1/train/hypernetwork", self.train_hypernetwork, methods=["POST"], response_model=TrainResponse)
|
||||
self.add_api_route("/sdapi/v1/memory", self.get_memory, methods=["GET"], response_model=MemoryResponse)
|
||||
self.add_api_route("/sdapi/v1/create/embedding", self.create_embedding, methods=["POST"], response_model=models.CreateResponse)
|
||||
self.add_api_route("/sdapi/v1/create/hypernetwork", self.create_hypernetwork, methods=["POST"], response_model=models.CreateResponse)
|
||||
self.add_api_route("/sdapi/v1/preprocess", self.preprocess, methods=["POST"], response_model=models.PreprocessResponse)
|
||||
self.add_api_route("/sdapi/v1/train/embedding", self.train_embedding, methods=["POST"], response_model=models.TrainResponse)
|
||||
self.add_api_route("/sdapi/v1/train/hypernetwork", self.train_hypernetwork, methods=["POST"], response_model=models.TrainResponse)
|
||||
self.add_api_route("/sdapi/v1/memory", self.get_memory, methods=["GET"], response_model=models.MemoryResponse)
|
||||
self.add_api_route("/sdapi/v1/unload-checkpoint", self.unloadapi, methods=["POST"])
|
||||
self.add_api_route("/sdapi/v1/reload-checkpoint", self.reloadapi, methods=["POST"])
|
||||
self.add_api_route("/sdapi/v1/scripts", self.get_scripts_list, methods=["GET"], response_model=ScriptsList)
|
||||
self.add_api_route("/sdapi/v1/scripts", self.get_scripts_list, methods=["GET"], response_model=models.ScriptsList)
|
||||
self.add_api_route("/sdapi/v1/script-info", self.get_script_info, methods=["GET"], response_model=List[models.ScriptInfo])
|
||||
|
||||
self.default_script_arg_txt2img = []
|
||||
self.default_script_arg_img2img = []
|
||||
@@ -220,17 +228,25 @@ class Api:
|
||||
script_idx = script_name_to_index(script_name, script_runner.selectable_scripts)
|
||||
script = script_runner.selectable_scripts[script_idx]
|
||||
return script, script_idx
|
||||
|
||||
def get_scripts_list(self):
|
||||
t2ilist = [str(title.lower()) for title in scripts.scripts_txt2img.titles]
|
||||
i2ilist = [str(title.lower()) for title in scripts.scripts_img2img.titles]
|
||||
|
||||
return ScriptsList(txt2img = t2ilist, img2img = i2ilist)
|
||||
def get_scripts_list(self):
|
||||
t2ilist = [script.name for script in scripts.scripts_txt2img.scripts if script.name is not None]
|
||||
i2ilist = [script.name for script in scripts.scripts_img2img.scripts if script.name is not None]
|
||||
|
||||
return models.ScriptsList(txt2img=t2ilist, img2img=i2ilist)
|
||||
|
||||
def get_script_info(self):
|
||||
res = []
|
||||
|
||||
for script_list in [scripts.scripts_txt2img.scripts, scripts.scripts_img2img.scripts]:
|
||||
res += [script.api_info for script in script_list if script.api_info is not None]
|
||||
|
||||
return res
|
||||
|
||||
def get_script(self, script_name, script_runner):
|
||||
if script_name is None or script_name == "":
|
||||
return None, None
|
||||
|
||||
|
||||
script_idx = script_name_to_index(script_name, script_runner.scripts)
|
||||
return script_runner.scripts[script_idx]
|
||||
|
||||
@@ -265,17 +281,19 @@ class Api:
|
||||
if request.alwayson_scripts and (len(request.alwayson_scripts) > 0):
|
||||
for alwayson_script_name in request.alwayson_scripts.keys():
|
||||
alwayson_script = self.get_script(alwayson_script_name, script_runner)
|
||||
if alwayson_script == None:
|
||||
if alwayson_script is None:
|
||||
raise HTTPException(status_code=422, detail=f"always on script {alwayson_script_name} not found")
|
||||
# Selectable script in always on script param check
|
||||
if alwayson_script.alwayson == False:
|
||||
raise HTTPException(status_code=422, detail=f"Cannot have a selectable script in the always on scripts params")
|
||||
if alwayson_script.alwayson is False:
|
||||
raise HTTPException(status_code=422, detail="Cannot have a selectable script in the always on scripts params")
|
||||
# always on script with no arg should always run so you don't really need to add them to the requests
|
||||
if "args" in request.alwayson_scripts[alwayson_script_name]:
|
||||
script_args[alwayson_script.args_from:alwayson_script.args_to] = request.alwayson_scripts[alwayson_script_name]["args"]
|
||||
# min between arg length in scriptrunner and arg length in the request
|
||||
for idx in range(0, min((alwayson_script.args_to - alwayson_script.args_from), len(request.alwayson_scripts[alwayson_script_name]["args"]))):
|
||||
script_args[alwayson_script.args_from + idx] = request.alwayson_scripts[alwayson_script_name]["args"][idx]
|
||||
return script_args
|
||||
|
||||
def text2imgapi(self, txt2imgreq: StableDiffusionTxt2ImgProcessingAPI):
|
||||
def text2imgapi(self, txt2imgreq: models.StableDiffusionTxt2ImgProcessingAPI):
|
||||
script_runner = scripts.scripts_txt2img
|
||||
if not script_runner.scripts:
|
||||
script_runner.initialize_scripts(False)
|
||||
@@ -309,7 +327,7 @@ class Api:
|
||||
p.outpath_samples = opts.outdir_txt2img_samples
|
||||
|
||||
shared.state.begin()
|
||||
if selectable_scripts != None:
|
||||
if selectable_scripts is not None:
|
||||
p.script_args = script_args
|
||||
processed = scripts.scripts_txt2img.run(p, *p.script_args) # Need to pass args as list here
|
||||
else:
|
||||
@@ -319,9 +337,9 @@ class Api:
|
||||
|
||||
b64images = list(map(encode_pil_to_base64, processed.images)) if send_images else []
|
||||
|
||||
return TextToImageResponse(images=b64images, parameters=vars(txt2imgreq), info=processed.js())
|
||||
return models.TextToImageResponse(images=b64images, parameters=vars(txt2imgreq), info=processed.js())
|
||||
|
||||
def img2imgapi(self, img2imgreq: StableDiffusionImg2ImgProcessingAPI):
|
||||
def img2imgapi(self, img2imgreq: models.StableDiffusionImg2ImgProcessingAPI):
|
||||
init_images = img2imgreq.init_images
|
||||
if init_images is None:
|
||||
raise HTTPException(status_code=404, detail="Init image not found")
|
||||
@@ -366,7 +384,7 @@ class Api:
|
||||
p.outpath_samples = opts.outdir_img2img_samples
|
||||
|
||||
shared.state.begin()
|
||||
if selectable_scripts != None:
|
||||
if selectable_scripts is not None:
|
||||
p.script_args = script_args
|
||||
processed = scripts.scripts_img2img.run(p, *p.script_args) # Need to pass args as list here
|
||||
else:
|
||||
@@ -380,9 +398,9 @@ class Api:
|
||||
img2imgreq.init_images = None
|
||||
img2imgreq.mask = None
|
||||
|
||||
return ImageToImageResponse(images=b64images, parameters=vars(img2imgreq), info=processed.js())
|
||||
return models.ImageToImageResponse(images=b64images, parameters=vars(img2imgreq), info=processed.js())
|
||||
|
||||
def extras_single_image_api(self, req: ExtrasSingleImageRequest):
|
||||
def extras_single_image_api(self, req: models.ExtrasSingleImageRequest):
|
||||
reqDict = setUpscalers(req)
|
||||
|
||||
reqDict['image'] = decode_base64_to_image(reqDict['image'])
|
||||
@@ -390,31 +408,26 @@ class Api:
|
||||
with self.queue_lock:
|
||||
result = postprocessing.run_extras(extras_mode=0, image_folder="", input_dir="", output_dir="", save_output=False, **reqDict)
|
||||
|
||||
return ExtrasSingleImageResponse(image=encode_pil_to_base64(result[0][0]), html_info=result[1])
|
||||
return models.ExtrasSingleImageResponse(image=encode_pil_to_base64(result[0][0]), html_info=result[1])
|
||||
|
||||
def extras_batch_images_api(self, req: ExtrasBatchImagesRequest):
|
||||
def extras_batch_images_api(self, req: models.ExtrasBatchImagesRequest):
|
||||
reqDict = setUpscalers(req)
|
||||
|
||||
def prepareFiles(file):
|
||||
file = decode_base64_to_file(file.data, file_path=file.name)
|
||||
file.orig_name = file.name
|
||||
return file
|
||||
|
||||
reqDict['image_folder'] = list(map(prepareFiles, reqDict['imageList']))
|
||||
reqDict.pop('imageList')
|
||||
image_list = reqDict.pop('imageList', [])
|
||||
image_folder = [decode_base64_to_image(x.data) for x in image_list]
|
||||
|
||||
with self.queue_lock:
|
||||
result = postprocessing.run_extras(extras_mode=1, image="", input_dir="", output_dir="", save_output=False, **reqDict)
|
||||
result = postprocessing.run_extras(extras_mode=1, image_folder=image_folder, image="", input_dir="", output_dir="", save_output=False, **reqDict)
|
||||
|
||||
return ExtrasBatchImagesResponse(images=list(map(encode_pil_to_base64, result[0])), html_info=result[1])
|
||||
return models.ExtrasBatchImagesResponse(images=list(map(encode_pil_to_base64, result[0])), html_info=result[1])
|
||||
|
||||
def pnginfoapi(self, req: PNGInfoRequest):
|
||||
def pnginfoapi(self, req: models.PNGInfoRequest):
|
||||
if(not req.image.strip()):
|
||||
return PNGInfoResponse(info="")
|
||||
return models.PNGInfoResponse(info="")
|
||||
|
||||
image = decode_base64_to_image(req.image.strip())
|
||||
if image is None:
|
||||
return PNGInfoResponse(info="")
|
||||
return models.PNGInfoResponse(info="")
|
||||
|
||||
geninfo, items = images.read_info_from_image(image)
|
||||
if geninfo is None:
|
||||
@@ -422,13 +435,13 @@ class Api:
|
||||
|
||||
items = {**{'parameters': geninfo}, **items}
|
||||
|
||||
return PNGInfoResponse(info=geninfo, items=items)
|
||||
return models.PNGInfoResponse(info=geninfo, items=items)
|
||||
|
||||
def progressapi(self, req: ProgressRequest = Depends()):
|
||||
def progressapi(self, req: models.ProgressRequest = Depends()):
|
||||
# copy from check_progress_call of ui.py
|
||||
|
||||
if shared.state.job_count == 0:
|
||||
return ProgressResponse(progress=0, eta_relative=0, state=shared.state.dict(), textinfo=shared.state.textinfo)
|
||||
return models.ProgressResponse(progress=0, eta_relative=0, state=shared.state.dict(), textinfo=shared.state.textinfo)
|
||||
|
||||
# avoid dividing zero
|
||||
progress = 0.01
|
||||
@@ -450,9 +463,9 @@ class Api:
|
||||
if shared.state.current_image and not req.skip_current_image:
|
||||
current_image = encode_pil_to_base64(shared.state.current_image)
|
||||
|
||||
return ProgressResponse(progress=progress, eta_relative=eta_relative, state=shared.state.dict(), current_image=current_image, textinfo=shared.state.textinfo)
|
||||
return models.ProgressResponse(progress=progress, eta_relative=eta_relative, state=shared.state.dict(), current_image=current_image, textinfo=shared.state.textinfo)
|
||||
|
||||
def interrogateapi(self, interrogatereq: InterrogateRequest):
|
||||
def interrogateapi(self, interrogatereq: models.InterrogateRequest):
|
||||
image_b64 = interrogatereq.image
|
||||
if image_b64 is None:
|
||||
raise HTTPException(status_code=404, detail="Image not found")
|
||||
@@ -469,7 +482,7 @@ class Api:
|
||||
else:
|
||||
raise HTTPException(status_code=404, detail="Model not found")
|
||||
|
||||
return InterrogateResponse(caption=processed)
|
||||
return models.InterrogateResponse(caption=processed)
|
||||
|
||||
def interruptapi(self):
|
||||
shared.state.interrupt()
|
||||
@@ -574,36 +587,36 @@ class Api:
|
||||
filename = create_embedding(**args) # create empty embedding
|
||||
sd_hijack.model_hijack.embedding_db.load_textual_inversion_embeddings() # reload embeddings so new one can be immediately used
|
||||
shared.state.end()
|
||||
return CreateResponse(info = "create embedding filename: {filename}".format(filename = filename))
|
||||
return models.CreateResponse(info=f"create embedding filename: {filename}")
|
||||
except AssertionError as e:
|
||||
shared.state.end()
|
||||
return TrainResponse(info = "create embedding error: {error}".format(error = e))
|
||||
return models.TrainResponse(info=f"create embedding error: {e}")
|
||||
|
||||
def create_hypernetwork(self, args: dict):
|
||||
try:
|
||||
shared.state.begin()
|
||||
filename = create_hypernetwork(**args) # create empty embedding
|
||||
shared.state.end()
|
||||
return CreateResponse(info = "create hypernetwork filename: {filename}".format(filename = filename))
|
||||
return models.CreateResponse(info=f"create hypernetwork filename: {filename}")
|
||||
except AssertionError as e:
|
||||
shared.state.end()
|
||||
return TrainResponse(info = "create hypernetwork error: {error}".format(error = e))
|
||||
return models.TrainResponse(info=f"create hypernetwork error: {e}")
|
||||
|
||||
def preprocess(self, args: dict):
|
||||
try:
|
||||
shared.state.begin()
|
||||
preprocess(**args) # quick operation unless blip/booru interrogation is enabled
|
||||
shared.state.end()
|
||||
return PreprocessResponse(info = 'preprocess complete')
|
||||
return models.PreprocessResponse(info = 'preprocess complete')
|
||||
except KeyError as e:
|
||||
shared.state.end()
|
||||
return PreprocessResponse(info = "preprocess error: invalid token: {error}".format(error = e))
|
||||
return models.PreprocessResponse(info=f"preprocess error: invalid token: {e}")
|
||||
except AssertionError as e:
|
||||
shared.state.end()
|
||||
return PreprocessResponse(info = "preprocess error: {error}".format(error = e))
|
||||
return models.PreprocessResponse(info=f"preprocess error: {e}")
|
||||
except FileNotFoundError as e:
|
||||
shared.state.end()
|
||||
return PreprocessResponse(info = 'preprocess error: {error}'.format(error = e))
|
||||
return models.PreprocessResponse(info=f'preprocess error: {e}')
|
||||
|
||||
def train_embedding(self, args: dict):
|
||||
try:
|
||||
@@ -621,10 +634,10 @@ class Api:
|
||||
if not apply_optimizations:
|
||||
sd_hijack.apply_optimizations()
|
||||
shared.state.end()
|
||||
return TrainResponse(info = "train embedding complete: filename: {filename} error: {error}".format(filename = filename, error = error))
|
||||
return models.TrainResponse(info=f"train embedding complete: filename: {filename} error: {error}")
|
||||
except AssertionError as msg:
|
||||
shared.state.end()
|
||||
return TrainResponse(info = "train embedding error: {msg}".format(msg = msg))
|
||||
return models.TrainResponse(info=f"train embedding error: {msg}")
|
||||
|
||||
def train_hypernetwork(self, args: dict):
|
||||
try:
|
||||
@@ -645,14 +658,15 @@ class Api:
|
||||
if not apply_optimizations:
|
||||
sd_hijack.apply_optimizations()
|
||||
shared.state.end()
|
||||
return TrainResponse(info="train embedding complete: filename: {filename} error: {error}".format(filename=filename, error=error))
|
||||
except AssertionError as msg:
|
||||
return models.TrainResponse(info=f"train embedding complete: filename: {filename} error: {error}")
|
||||
except AssertionError:
|
||||
shared.state.end()
|
||||
return TrainResponse(info="train embedding error: {error}".format(error=error))
|
||||
return models.TrainResponse(info=f"train embedding error: {error}")
|
||||
|
||||
def get_memory(self):
|
||||
try:
|
||||
import os, psutil
|
||||
import os
|
||||
import psutil
|
||||
process = psutil.Process(os.getpid())
|
||||
res = process.memory_info() # only rss is cross-platform guaranteed so we dont rely on other values
|
||||
ram_total = 100 * res.rss / process.memory_percent() # and total memory is calculated as actual value is not cross-platform safe
|
||||
@@ -679,10 +693,10 @@ class Api:
|
||||
'events': warnings,
|
||||
}
|
||||
else:
|
||||
cuda = { 'error': 'unavailable' }
|
||||
cuda = {'error': 'unavailable'}
|
||||
except Exception as err:
|
||||
cuda = { 'error': f'{err}' }
|
||||
return MemoryResponse(ram = ram, cuda = cuda)
|
||||
cuda = {'error': f'{err}'}
|
||||
return models.MemoryResponse(ram=ram, cuda=cuda)
|
||||
|
||||
def launch(self, server_name, port):
|
||||
self.app.include_router(self.router)
|
||||
|
||||
@@ -223,8 +223,9 @@ for key in _options:
|
||||
if(_options[key].dest != 'help'):
|
||||
flag = _options[key]
|
||||
_type = str
|
||||
if _options[key].default is not None: _type = type(_options[key].default)
|
||||
flags.update({flag.dest: (_type,Field(default=flag.default, description=flag.help))})
|
||||
if _options[key].default is not None:
|
||||
_type = type(_options[key].default)
|
||||
flags.update({flag.dest: (_type, Field(default=flag.default, description=flag.help))})
|
||||
|
||||
FlagsModel = create_model("Flags", **flags)
|
||||
|
||||
@@ -286,6 +287,23 @@ class MemoryResponse(BaseModel):
|
||||
ram: dict = Field(title="RAM", description="System memory stats")
|
||||
cuda: dict = Field(title="CUDA", description="nVidia CUDA memory stats")
|
||||
|
||||
|
||||
class ScriptsList(BaseModel):
|
||||
txt2img: list = Field(default=None,title="Txt2img", description="Titles of scripts (txt2img)")
|
||||
img2img: list = Field(default=None,title="Img2img", description="Titles of scripts (img2img)")
|
||||
txt2img: list = Field(default=None, title="Txt2img", description="Titles of scripts (txt2img)")
|
||||
img2img: list = Field(default=None, title="Img2img", description="Titles of scripts (img2img)")
|
||||
|
||||
|
||||
class ScriptArg(BaseModel):
|
||||
label: str = Field(default=None, title="Label", description="Name of the argument in UI")
|
||||
value: Optional[Any] = Field(default=None, title="Value", description="Default value of the argument")
|
||||
minimum: Optional[Any] = Field(default=None, title="Minimum", description="Minimum allowed value for the argumentin UI")
|
||||
maximum: Optional[Any] = Field(default=None, title="Minimum", description="Maximum allowed value for the argumentin UI")
|
||||
step: Optional[Any] = Field(default=None, title="Minimum", description="Step for changing value of the argumentin UI")
|
||||
choices: Optional[List[str]] = Field(default=None, title="Choices", description="Possible values for the argument")
|
||||
|
||||
|
||||
class ScriptInfo(BaseModel):
|
||||
name: str = Field(default=None, title="Name", description="Script name")
|
||||
is_alwayson: bool = Field(default=None, title="IsAlwayson", description="Flag specifying whether this script is an alwayson script")
|
||||
is_img2img: bool = Field(default=None, title="IsImg2img", description="Flag specifying whether this script is an img2img script")
|
||||
args: List[ScriptArg] = Field(title="Arguments", description="List of script's arguments")
|
||||
|
||||
@@ -35,6 +35,7 @@ def wrap_gradio_gpu_call(func, extra_outputs=None):
|
||||
|
||||
try:
|
||||
res = func(*args, **kwargs)
|
||||
progress.record_results(id_task, res)
|
||||
finally:
|
||||
progress.finish_task(id_task)
|
||||
|
||||
@@ -59,7 +60,7 @@ def wrap_gradio_call(func, extra_outputs=None, add_stats=False):
|
||||
max_debug_str_len = 131072 # (1024*1024)/8
|
||||
|
||||
print("Error completing request", file=sys.stderr)
|
||||
argStr = f"Arguments: {str(args)} {str(kwargs)}"
|
||||
argStr = f"Arguments: {args} {kwargs}"
|
||||
print(argStr[:max_debug_str_len], file=sys.stderr)
|
||||
if len(argStr) > max_debug_str_len:
|
||||
print(f"(Argument list truncated at {max_debug_str_len}/{len(argStr)} characters)", file=sys.stderr)
|
||||
@@ -72,7 +73,8 @@ def wrap_gradio_call(func, extra_outputs=None, add_stats=False):
|
||||
if extra_outputs_array is None:
|
||||
extra_outputs_array = [None, '']
|
||||
|
||||
res = extra_outputs_array + [f"<div class='error'>{html.escape(type(e).__name__+': '+str(e))}</div>"]
|
||||
error_message = f'{type(e).__name__}: {e}'
|
||||
res = extra_outputs_array + [f"<div class='error'>{html.escape(error_message)}</div>"]
|
||||
|
||||
shared.state.skipped = False
|
||||
shared.state.interrupted = False
|
||||
|
||||
@@ -1,6 +1,6 @@
|
||||
import argparse
|
||||
import os
|
||||
from modules.paths_internal import models_path, script_path, data_path, extensions_dir, extensions_builtin_dir, sd_default_config, sd_model_file
|
||||
from modules.paths_internal import models_path, script_path, data_path, extensions_dir, extensions_builtin_dir, sd_default_config, sd_model_file # noqa: F401
|
||||
|
||||
parser = argparse.ArgumentParser()
|
||||
|
||||
@@ -95,9 +95,12 @@ parser.add_argument("--cors-allow-origins", type=str, help="Allowed CORS origin(
|
||||
parser.add_argument("--cors-allow-origins-regex", type=str, help="Allowed CORS origin(s) in the form of a single regular expression", default=None)
|
||||
parser.add_argument("--tls-keyfile", type=str, help="Partially enables TLS, requires --tls-certfile to fully function", default=None)
|
||||
parser.add_argument("--tls-certfile", type=str, help="Partially enables TLS, requires --tls-keyfile to fully function", default=None)
|
||||
parser.add_argument("--disable-tls-verify", action="store_false", help="When passed, enables the use of self-signed certificates.", default=None)
|
||||
parser.add_argument("--server-name", type=str, help="Sets hostname of server", default=None)
|
||||
parser.add_argument("--gradio-queue", action='store_true', help="does not do anything", default=True)
|
||||
parser.add_argument("--no-gradio-queue", action='store_true', help="Disables gradio queue; causes the webpage to use http requests instead of websockets; was the defaul in earlier versions")
|
||||
parser.add_argument("--skip-version-check", action='store_true', help="Do not check versions of torch and xformers")
|
||||
parser.add_argument("--no-hashing", action='store_true', help="disable sha256 hashing of checkpoints to help loading performance", default=False)
|
||||
parser.add_argument("--no-download-sd-model", action='store_true', help="don't download SD1.5 model even if no model is found in --ckpt-dir", default=False)
|
||||
parser.add_argument('--subpath', type=str, help='customize the subpath for gradio, use with reverse proxy')
|
||||
parser.add_argument('--add-stop-route', action='store_true', help='add /_stop route to stop server')
|
||||
|
||||
@@ -1,14 +1,12 @@
|
||||
# this file is copied from CodeFormer repository. Please see comment in modules/codeformer_model.py
|
||||
|
||||
import math
|
||||
import numpy as np
|
||||
import torch
|
||||
from torch import nn, Tensor
|
||||
import torch.nn.functional as F
|
||||
from typing import Optional, List
|
||||
from typing import Optional
|
||||
|
||||
from modules.codeformer.vqgan_arch import *
|
||||
from basicsr.utils import get_root_logger
|
||||
from modules.codeformer.vqgan_arch import VQAutoEncoder, ResBlock
|
||||
from basicsr.utils.registry import ARCH_REGISTRY
|
||||
|
||||
def calc_mean_std(feat, eps=1e-5):
|
||||
@@ -121,7 +119,7 @@ class TransformerSALayer(nn.Module):
|
||||
tgt_mask: Optional[Tensor] = None,
|
||||
tgt_key_padding_mask: Optional[Tensor] = None,
|
||||
query_pos: Optional[Tensor] = None):
|
||||
|
||||
|
||||
# self attention
|
||||
tgt2 = self.norm1(tgt)
|
||||
q = k = self.with_pos_embed(tgt2, query_pos)
|
||||
@@ -161,10 +159,10 @@ class Fuse_sft_block(nn.Module):
|
||||
|
||||
@ARCH_REGISTRY.register()
|
||||
class CodeFormer(VQAutoEncoder):
|
||||
def __init__(self, dim_embd=512, n_head=8, n_layers=9,
|
||||
def __init__(self, dim_embd=512, n_head=8, n_layers=9,
|
||||
codebook_size=1024, latent_size=256,
|
||||
connect_list=['32', '64', '128', '256'],
|
||||
fix_modules=['quantize','generator']):
|
||||
connect_list=('32', '64', '128', '256'),
|
||||
fix_modules=('quantize', 'generator')):
|
||||
super(CodeFormer, self).__init__(512, 64, [1, 2, 2, 4, 4, 8], 'nearest',2, [16], codebook_size)
|
||||
|
||||
if fix_modules is not None:
|
||||
@@ -181,14 +179,14 @@ class CodeFormer(VQAutoEncoder):
|
||||
self.feat_emb = nn.Linear(256, self.dim_embd)
|
||||
|
||||
# transformer
|
||||
self.ft_layers = nn.Sequential(*[TransformerSALayer(embed_dim=dim_embd, nhead=n_head, dim_mlp=self.dim_mlp, dropout=0.0)
|
||||
self.ft_layers = nn.Sequential(*[TransformerSALayer(embed_dim=dim_embd, nhead=n_head, dim_mlp=self.dim_mlp, dropout=0.0)
|
||||
for _ in range(self.n_layers)])
|
||||
|
||||
# logits_predict head
|
||||
self.idx_pred_layer = nn.Sequential(
|
||||
nn.LayerNorm(dim_embd),
|
||||
nn.Linear(dim_embd, codebook_size, bias=False))
|
||||
|
||||
|
||||
self.channels = {
|
||||
'16': 512,
|
||||
'32': 256,
|
||||
@@ -223,7 +221,7 @@ class CodeFormer(VQAutoEncoder):
|
||||
enc_feat_dict = {}
|
||||
out_list = [self.fuse_encoder_block[f_size] for f_size in self.connect_list]
|
||||
for i, block in enumerate(self.encoder.blocks):
|
||||
x = block(x)
|
||||
x = block(x)
|
||||
if i in out_list:
|
||||
enc_feat_dict[str(x.shape[-1])] = x.clone()
|
||||
|
||||
@@ -268,11 +266,11 @@ class CodeFormer(VQAutoEncoder):
|
||||
fuse_list = [self.fuse_generator_block[f_size] for f_size in self.connect_list]
|
||||
|
||||
for i, block in enumerate(self.generator.blocks):
|
||||
x = block(x)
|
||||
x = block(x)
|
||||
if i in fuse_list: # fuse after i-th block
|
||||
f_size = str(x.shape[-1])
|
||||
if w>0:
|
||||
x = self.fuse_convs_dict[f_size](enc_feat_dict[f_size].detach(), x, w)
|
||||
out = x
|
||||
# logits doesn't need softmax before cross_entropy loss
|
||||
return out, logits, lq_feat
|
||||
return out, logits, lq_feat
|
||||
|
||||
@@ -5,17 +5,15 @@ VQGAN code, adapted from the original created by the Unleashing Transformers aut
|
||||
https://github.com/samb-t/unleashing-transformers/blob/master/models/vqgan.py
|
||||
|
||||
'''
|
||||
import numpy as np
|
||||
import torch
|
||||
import torch.nn as nn
|
||||
import torch.nn.functional as F
|
||||
import copy
|
||||
from basicsr.utils import get_root_logger
|
||||
from basicsr.utils.registry import ARCH_REGISTRY
|
||||
|
||||
def normalize(in_channels):
|
||||
return torch.nn.GroupNorm(num_groups=32, num_channels=in_channels, eps=1e-6, affine=True)
|
||||
|
||||
|
||||
|
||||
@torch.jit.script
|
||||
def swish(x):
|
||||
@@ -212,15 +210,15 @@ class AttnBlock(nn.Module):
|
||||
# compute attention
|
||||
b, c, h, w = q.shape
|
||||
q = q.reshape(b, c, h*w)
|
||||
q = q.permute(0, 2, 1)
|
||||
q = q.permute(0, 2, 1)
|
||||
k = k.reshape(b, c, h*w)
|
||||
w_ = torch.bmm(q, k)
|
||||
w_ = torch.bmm(q, k)
|
||||
w_ = w_ * (int(c)**(-0.5))
|
||||
w_ = F.softmax(w_, dim=2)
|
||||
|
||||
# attend to values
|
||||
v = v.reshape(b, c, h*w)
|
||||
w_ = w_.permute(0, 2, 1)
|
||||
w_ = w_.permute(0, 2, 1)
|
||||
h_ = torch.bmm(v, w_)
|
||||
h_ = h_.reshape(b, c, h, w)
|
||||
|
||||
@@ -272,18 +270,18 @@ class Encoder(nn.Module):
|
||||
def forward(self, x):
|
||||
for block in self.blocks:
|
||||
x = block(x)
|
||||
|
||||
|
||||
return x
|
||||
|
||||
|
||||
class Generator(nn.Module):
|
||||
def __init__(self, nf, emb_dim, ch_mult, res_blocks, img_size, attn_resolutions):
|
||||
super().__init__()
|
||||
self.nf = nf
|
||||
self.ch_mult = ch_mult
|
||||
self.nf = nf
|
||||
self.ch_mult = ch_mult
|
||||
self.num_resolutions = len(self.ch_mult)
|
||||
self.num_res_blocks = res_blocks
|
||||
self.resolution = img_size
|
||||
self.resolution = img_size
|
||||
self.attn_resolutions = attn_resolutions
|
||||
self.in_channels = emb_dim
|
||||
self.out_channels = 3
|
||||
@@ -317,29 +315,29 @@ class Generator(nn.Module):
|
||||
blocks.append(nn.Conv2d(block_in_ch, self.out_channels, kernel_size=3, stride=1, padding=1))
|
||||
|
||||
self.blocks = nn.ModuleList(blocks)
|
||||
|
||||
|
||||
|
||||
def forward(self, x):
|
||||
for block in self.blocks:
|
||||
x = block(x)
|
||||
|
||||
|
||||
return x
|
||||
|
||||
|
||||
|
||||
@ARCH_REGISTRY.register()
|
||||
class VQAutoEncoder(nn.Module):
|
||||
def __init__(self, img_size, nf, ch_mult, quantizer="nearest", res_blocks=2, attn_resolutions=[16], codebook_size=1024, emb_dim=256,
|
||||
def __init__(self, img_size, nf, ch_mult, quantizer="nearest", res_blocks=2, attn_resolutions=None, codebook_size=1024, emb_dim=256,
|
||||
beta=0.25, gumbel_straight_through=False, gumbel_kl_weight=1e-8, model_path=None):
|
||||
super().__init__()
|
||||
logger = get_root_logger()
|
||||
self.in_channels = 3
|
||||
self.nf = nf
|
||||
self.n_blocks = res_blocks
|
||||
self.in_channels = 3
|
||||
self.nf = nf
|
||||
self.n_blocks = res_blocks
|
||||
self.codebook_size = codebook_size
|
||||
self.embed_dim = emb_dim
|
||||
self.ch_mult = ch_mult
|
||||
self.resolution = img_size
|
||||
self.attn_resolutions = attn_resolutions
|
||||
self.attn_resolutions = attn_resolutions or [16]
|
||||
self.quantizer_type = quantizer
|
||||
self.encoder = Encoder(
|
||||
self.in_channels,
|
||||
@@ -365,11 +363,11 @@ class VQAutoEncoder(nn.Module):
|
||||
self.kl_weight
|
||||
)
|
||||
self.generator = Generator(
|
||||
self.nf,
|
||||
self.nf,
|
||||
self.embed_dim,
|
||||
self.ch_mult,
|
||||
self.n_blocks,
|
||||
self.resolution,
|
||||
self.ch_mult,
|
||||
self.n_blocks,
|
||||
self.resolution,
|
||||
self.attn_resolutions
|
||||
)
|
||||
|
||||
@@ -434,4 +432,4 @@ class VQGANDiscriminator(nn.Module):
|
||||
raise ValueError('Wrong params!')
|
||||
|
||||
def forward(self, x):
|
||||
return self.main(x)
|
||||
return self.main(x)
|
||||
|
||||
@@ -33,11 +33,9 @@ def setup_model(dirname):
|
||||
try:
|
||||
from torchvision.transforms.functional import normalize
|
||||
from modules.codeformer.codeformer_arch import CodeFormer
|
||||
from basicsr.utils.download_util import load_file_from_url
|
||||
from basicsr.utils import imwrite, img2tensor, tensor2img
|
||||
from basicsr.utils import img2tensor, tensor2img
|
||||
from facelib.utils.face_restoration_helper import FaceRestoreHelper
|
||||
from facelib.detection.retinaface import retinaface
|
||||
from modules.shared import cmd_opts
|
||||
|
||||
net_class = CodeFormer
|
||||
|
||||
@@ -96,7 +94,7 @@ def setup_model(dirname):
|
||||
self.face_helper.get_face_landmarks_5(only_center_face=False, resize=640, eye_dist_threshold=5)
|
||||
self.face_helper.align_warp_face()
|
||||
|
||||
for idx, cropped_face in enumerate(self.face_helper.cropped_faces):
|
||||
for cropped_face in self.face_helper.cropped_faces:
|
||||
cropped_face_t = img2tensor(cropped_face / 255., bgr2rgb=True, float32=True)
|
||||
normalize(cropped_face_t, (0.5, 0.5, 0.5), (0.5, 0.5, 0.5), inplace=True)
|
||||
cropped_face_t = cropped_face_t.unsqueeze(0).to(devices.device_codeformer)
|
||||
|
||||
202
modules/config_states.py
Normal file
202
modules/config_states.py
Normal file
@@ -0,0 +1,202 @@
|
||||
"""
|
||||
Supports saving and restoring webui and extensions from a known working set of commits
|
||||
"""
|
||||
|
||||
import os
|
||||
import sys
|
||||
import traceback
|
||||
import json
|
||||
import time
|
||||
import tqdm
|
||||
|
||||
from datetime import datetime
|
||||
from collections import OrderedDict
|
||||
import git
|
||||
|
||||
from modules import shared, extensions
|
||||
from modules.paths_internal import script_path, config_states_dir
|
||||
|
||||
|
||||
all_config_states = OrderedDict()
|
||||
|
||||
|
||||
def list_config_states():
|
||||
global all_config_states
|
||||
|
||||
all_config_states.clear()
|
||||
os.makedirs(config_states_dir, exist_ok=True)
|
||||
|
||||
config_states = []
|
||||
for filename in os.listdir(config_states_dir):
|
||||
if filename.endswith(".json"):
|
||||
path = os.path.join(config_states_dir, filename)
|
||||
with open(path, "r", encoding="utf-8") as f:
|
||||
j = json.load(f)
|
||||
j["filepath"] = path
|
||||
config_states.append(j)
|
||||
|
||||
config_states = sorted(config_states, key=lambda cs: cs["created_at"], reverse=True)
|
||||
|
||||
for cs in config_states:
|
||||
timestamp = time.asctime(time.gmtime(cs["created_at"]))
|
||||
name = cs.get("name", "Config")
|
||||
full_name = f"{name}: {timestamp}"
|
||||
all_config_states[full_name] = cs
|
||||
|
||||
return all_config_states
|
||||
|
||||
|
||||
def get_webui_config():
|
||||
webui_repo = None
|
||||
|
||||
try:
|
||||
if os.path.exists(os.path.join(script_path, ".git")):
|
||||
webui_repo = git.Repo(script_path)
|
||||
except Exception:
|
||||
print(f"Error reading webui git info from {script_path}:", file=sys.stderr)
|
||||
print(traceback.format_exc(), file=sys.stderr)
|
||||
|
||||
webui_remote = None
|
||||
webui_commit_hash = None
|
||||
webui_commit_date = None
|
||||
webui_branch = None
|
||||
if webui_repo and not webui_repo.bare:
|
||||
try:
|
||||
webui_remote = next(webui_repo.remote().urls, None)
|
||||
head = webui_repo.head.commit
|
||||
webui_commit_date = webui_repo.head.commit.committed_date
|
||||
webui_commit_hash = head.hexsha
|
||||
webui_branch = webui_repo.active_branch.name
|
||||
|
||||
except Exception:
|
||||
webui_remote = None
|
||||
|
||||
return {
|
||||
"remote": webui_remote,
|
||||
"commit_hash": webui_commit_hash,
|
||||
"commit_date": webui_commit_date,
|
||||
"branch": webui_branch,
|
||||
}
|
||||
|
||||
|
||||
def get_extension_config():
|
||||
ext_config = {}
|
||||
|
||||
for ext in extensions.extensions:
|
||||
ext.read_info_from_repo()
|
||||
|
||||
entry = {
|
||||
"name": ext.name,
|
||||
"path": ext.path,
|
||||
"enabled": ext.enabled,
|
||||
"is_builtin": ext.is_builtin,
|
||||
"remote": ext.remote,
|
||||
"commit_hash": ext.commit_hash,
|
||||
"commit_date": ext.commit_date,
|
||||
"branch": ext.branch,
|
||||
"have_info_from_repo": ext.have_info_from_repo
|
||||
}
|
||||
|
||||
ext_config[ext.name] = entry
|
||||
|
||||
return ext_config
|
||||
|
||||
|
||||
def get_config():
|
||||
creation_time = datetime.now().timestamp()
|
||||
webui_config = get_webui_config()
|
||||
ext_config = get_extension_config()
|
||||
|
||||
return {
|
||||
"created_at": creation_time,
|
||||
"webui": webui_config,
|
||||
"extensions": ext_config
|
||||
}
|
||||
|
||||
|
||||
def restore_webui_config(config):
|
||||
print("* Restoring webui state...")
|
||||
|
||||
if "webui" not in config:
|
||||
print("Error: No webui data saved to config")
|
||||
return
|
||||
|
||||
webui_config = config["webui"]
|
||||
|
||||
if "commit_hash" not in webui_config:
|
||||
print("Error: No commit saved to webui config")
|
||||
return
|
||||
|
||||
webui_commit_hash = webui_config.get("commit_hash", None)
|
||||
webui_repo = None
|
||||
|
||||
try:
|
||||
if os.path.exists(os.path.join(script_path, ".git")):
|
||||
webui_repo = git.Repo(script_path)
|
||||
except Exception:
|
||||
print(f"Error reading webui git info from {script_path}:", file=sys.stderr)
|
||||
print(traceback.format_exc(), file=sys.stderr)
|
||||
return
|
||||
|
||||
try:
|
||||
webui_repo.git.fetch(all=True)
|
||||
webui_repo.git.reset(webui_commit_hash, hard=True)
|
||||
print(f"* Restored webui to commit {webui_commit_hash}.")
|
||||
except Exception:
|
||||
print(f"Error restoring webui to commit {webui_commit_hash}:", file=sys.stderr)
|
||||
print(traceback.format_exc(), file=sys.stderr)
|
||||
|
||||
|
||||
def restore_extension_config(config):
|
||||
print("* Restoring extension state...")
|
||||
|
||||
if "extensions" not in config:
|
||||
print("Error: No extension data saved to config")
|
||||
return
|
||||
|
||||
ext_config = config["extensions"]
|
||||
|
||||
results = []
|
||||
disabled = []
|
||||
|
||||
for ext in tqdm.tqdm(extensions.extensions):
|
||||
if ext.is_builtin:
|
||||
continue
|
||||
|
||||
ext.read_info_from_repo()
|
||||
current_commit = ext.commit_hash
|
||||
|
||||
if ext.name not in ext_config:
|
||||
ext.disabled = True
|
||||
disabled.append(ext.name)
|
||||
results.append((ext, current_commit[:8], False, "Saved extension state not found in config, marking as disabled"))
|
||||
continue
|
||||
|
||||
entry = ext_config[ext.name]
|
||||
|
||||
if "commit_hash" in entry and entry["commit_hash"]:
|
||||
try:
|
||||
ext.fetch_and_reset_hard(entry["commit_hash"])
|
||||
ext.read_info_from_repo()
|
||||
if current_commit != entry["commit_hash"]:
|
||||
results.append((ext, current_commit[:8], True, entry["commit_hash"][:8]))
|
||||
except Exception as ex:
|
||||
results.append((ext, current_commit[:8], False, ex))
|
||||
else:
|
||||
results.append((ext, current_commit[:8], False, "No commit hash found in config"))
|
||||
|
||||
if not entry.get("enabled", False):
|
||||
ext.disabled = True
|
||||
disabled.append(ext.name)
|
||||
else:
|
||||
ext.disabled = False
|
||||
|
||||
shared.opts.disabled_extensions = disabled
|
||||
shared.opts.save(shared.config_filename)
|
||||
|
||||
print("* Finished restoring extensions. Results:")
|
||||
for ext, prev_commit, success, result in results:
|
||||
if success:
|
||||
print(f" + {ext.name}: {prev_commit} -> {result}")
|
||||
else:
|
||||
print(f" ! {ext.name}: FAILURE ({result})")
|
||||
@@ -2,7 +2,6 @@ import os
|
||||
import re
|
||||
|
||||
import torch
|
||||
from PIL import Image
|
||||
import numpy as np
|
||||
|
||||
from modules import modelloader, paths, deepbooru_model, devices, images, shared
|
||||
@@ -79,7 +78,7 @@ class DeepDanbooru:
|
||||
|
||||
res = []
|
||||
|
||||
filtertags = set([x.strip().replace(' ', '_') for x in shared.opts.deepbooru_filter_tags.split(",")])
|
||||
filtertags = {x.strip().replace(' ', '_') for x in shared.opts.deepbooru_filter_tags.split(",")}
|
||||
|
||||
for tag in [x for x in tags if x not in filtertags]:
|
||||
probability = probability_dict[tag]
|
||||
|
||||
@@ -65,7 +65,7 @@ def enable_tf32():
|
||||
|
||||
# enabling benchmark option seems to enable a range of cards to do fp16 when they otherwise can't
|
||||
# see https://github.com/AUTOMATIC1111/stable-diffusion-webui/pull/4407
|
||||
if any([torch.cuda.get_device_capability(devid) == (7, 5) for devid in range(0, torch.cuda.device_count())]):
|
||||
if any(torch.cuda.get_device_capability(devid) == (7, 5) for devid in range(0, torch.cuda.device_count())):
|
||||
torch.backends.cudnn.benchmark = True
|
||||
|
||||
torch.backends.cuda.matmul.allow_tf32 = True
|
||||
@@ -92,14 +92,18 @@ def cond_cast_float(input):
|
||||
|
||||
|
||||
def randn(seed, shape):
|
||||
from modules.shared import opts
|
||||
|
||||
torch.manual_seed(seed)
|
||||
if device.type == 'mps':
|
||||
if opts.randn_source == "CPU" or device.type == 'mps':
|
||||
return torch.randn(shape, device=cpu).to(device)
|
||||
return torch.randn(shape, device=device)
|
||||
|
||||
|
||||
def randn_without_seed(shape):
|
||||
if device.type == 'mps':
|
||||
from modules.shared import opts
|
||||
|
||||
if opts.randn_source == "CPU" or device.type == 'mps':
|
||||
return torch.randn(shape, device=cpu).to(device)
|
||||
return torch.randn(shape, device=device)
|
||||
|
||||
|
||||
@@ -6,7 +6,7 @@ from PIL import Image
|
||||
from basicsr.utils.download_util import load_file_from_url
|
||||
|
||||
import modules.esrgan_model_arch as arch
|
||||
from modules import shared, modelloader, images, devices
|
||||
from modules import modelloader, images, devices
|
||||
from modules.upscaler import Upscaler, UpscalerData
|
||||
from modules.shared import opts
|
||||
|
||||
@@ -16,9 +16,7 @@ def mod2normal(state_dict):
|
||||
# this code is copied from https://github.com/victorca25/iNNfer
|
||||
if 'conv_first.weight' in state_dict:
|
||||
crt_net = {}
|
||||
items = []
|
||||
for k, v in state_dict.items():
|
||||
items.append(k)
|
||||
items = list(state_dict)
|
||||
|
||||
crt_net['model.0.weight'] = state_dict['conv_first.weight']
|
||||
crt_net['model.0.bias'] = state_dict['conv_first.bias']
|
||||
@@ -52,9 +50,7 @@ def resrgan2normal(state_dict, nb=23):
|
||||
if "conv_first.weight" in state_dict and "body.0.rdb1.conv1.weight" in state_dict:
|
||||
re8x = 0
|
||||
crt_net = {}
|
||||
items = []
|
||||
for k, v in state_dict.items():
|
||||
items.append(k)
|
||||
items = list(state_dict)
|
||||
|
||||
crt_net['model.0.weight'] = state_dict['conv_first.weight']
|
||||
crt_net['model.0.bias'] = state_dict['conv_first.bias']
|
||||
@@ -156,13 +152,16 @@ class UpscalerESRGAN(Upscaler):
|
||||
|
||||
def load_model(self, path: str):
|
||||
if "http" in path:
|
||||
filename = load_file_from_url(url=self.model_url, model_dir=self.model_path,
|
||||
file_name="%s.pth" % self.model_name,
|
||||
progress=True)
|
||||
filename = load_file_from_url(
|
||||
url=self.model_url,
|
||||
model_dir=self.model_path,
|
||||
file_name=f"{self.model_name}.pth",
|
||||
progress=True,
|
||||
)
|
||||
else:
|
||||
filename = path
|
||||
if not os.path.exists(filename) or filename is None:
|
||||
print("Unable to load %s from %s" % (self.model_path, filename))
|
||||
print(f"Unable to load {self.model_path} from {filename}")
|
||||
return None
|
||||
|
||||
state_dict = torch.load(filename, map_location='cpu' if devices.device_esrgan.type == 'mps' else None)
|
||||
|
||||
@@ -2,7 +2,6 @@
|
||||
|
||||
from collections import OrderedDict
|
||||
import math
|
||||
import functools
|
||||
import torch
|
||||
import torch.nn as nn
|
||||
import torch.nn.functional as F
|
||||
@@ -38,7 +37,7 @@ class RRDBNet(nn.Module):
|
||||
elif upsample_mode == 'pixelshuffle':
|
||||
upsample_block = pixelshuffle_block
|
||||
else:
|
||||
raise NotImplementedError('upsample mode [{:s}] is not found'.format(upsample_mode))
|
||||
raise NotImplementedError(f'upsample mode [{upsample_mode}] is not found')
|
||||
if upscale == 3:
|
||||
upsampler = upsample_block(nf, nf, 3, act_type=act_type, convtype=convtype)
|
||||
else:
|
||||
@@ -106,7 +105,7 @@ class ResidualDenseBlock_5C(nn.Module):
|
||||
Modified options that can be used:
|
||||
- "Partial Convolution based Padding" arXiv:1811.11718
|
||||
- "Spectral normalization" arXiv:1802.05957
|
||||
- "ICASSP 2020 - ESRGAN+ : Further Improving ESRGAN" N. C.
|
||||
- "ICASSP 2020 - ESRGAN+ : Further Improving ESRGAN" N. C.
|
||||
{Rakotonirina} and A. {Rasoanaivo}
|
||||
"""
|
||||
|
||||
@@ -171,7 +170,7 @@ class GaussianNoise(nn.Module):
|
||||
scale = self.sigma * x.detach() if self.is_relative_detach else self.sigma * x
|
||||
sampled_noise = self.noise.repeat(*x.size()).normal_() * scale
|
||||
x = x + sampled_noise
|
||||
return x
|
||||
return x
|
||||
|
||||
def conv1x1(in_planes, out_planes, stride=1):
|
||||
return nn.Conv2d(in_planes, out_planes, kernel_size=1, stride=stride, bias=False)
|
||||
@@ -261,10 +260,10 @@ class Upsample(nn.Module):
|
||||
|
||||
def extra_repr(self):
|
||||
if self.scale_factor is not None:
|
||||
info = 'scale_factor=' + str(self.scale_factor)
|
||||
info = f'scale_factor={self.scale_factor}'
|
||||
else:
|
||||
info = 'size=' + str(self.size)
|
||||
info += ', mode=' + self.mode
|
||||
info = f'size={self.size}'
|
||||
info += f', mode={self.mode}'
|
||||
return info
|
||||
|
||||
|
||||
@@ -350,7 +349,7 @@ def act(act_type, inplace=True, neg_slope=0.2, n_prelu=1, beta=1.0):
|
||||
elif act_type == 'sigmoid': # [0, 1] range output
|
||||
layer = nn.Sigmoid()
|
||||
else:
|
||||
raise NotImplementedError('activation layer [{:s}] is not found'.format(act_type))
|
||||
raise NotImplementedError(f'activation layer [{act_type}] is not found')
|
||||
return layer
|
||||
|
||||
|
||||
@@ -372,7 +371,7 @@ def norm(norm_type, nc):
|
||||
elif norm_type == 'none':
|
||||
def norm_layer(x): return Identity()
|
||||
else:
|
||||
raise NotImplementedError('normalization layer [{:s}] is not found'.format(norm_type))
|
||||
raise NotImplementedError(f'normalization layer [{norm_type}] is not found')
|
||||
return layer
|
||||
|
||||
|
||||
@@ -388,7 +387,7 @@ def pad(pad_type, padding):
|
||||
elif pad_type == 'zero':
|
||||
layer = nn.ZeroPad2d(padding)
|
||||
else:
|
||||
raise NotImplementedError('padding layer [{:s}] is not implemented'.format(pad_type))
|
||||
raise NotImplementedError(f'padding layer [{pad_type}] is not implemented')
|
||||
return layer
|
||||
|
||||
|
||||
@@ -432,15 +431,17 @@ def conv_block(in_nc, out_nc, kernel_size, stride=1, dilation=1, groups=1, bias=
|
||||
pad_type='zero', norm_type=None, act_type='relu', mode='CNA', convtype='Conv2D',
|
||||
spectral_norm=False):
|
||||
""" Conv layer with padding, normalization, activation """
|
||||
assert mode in ['CNA', 'NAC', 'CNAC'], 'Wrong conv mode [{:s}]'.format(mode)
|
||||
assert mode in ['CNA', 'NAC', 'CNAC'], f'Wrong conv mode [{mode}]'
|
||||
padding = get_valid_padding(kernel_size, dilation)
|
||||
p = pad(pad_type, padding) if pad_type and pad_type != 'zero' else None
|
||||
padding = padding if pad_type == 'zero' else 0
|
||||
|
||||
if convtype=='PartialConv2D':
|
||||
from torchvision.ops import PartialConv2d # this is definitely not going to work, but PartialConv2d doesn't work anyway and this shuts up static analyzer
|
||||
c = PartialConv2d(in_nc, out_nc, kernel_size=kernel_size, stride=stride, padding=padding,
|
||||
dilation=dilation, bias=bias, groups=groups)
|
||||
elif convtype=='DeformConv2D':
|
||||
from torchvision.ops import DeformConv2d # not tested
|
||||
c = DeformConv2d(in_nc, out_nc, kernel_size=kernel_size, stride=stride, padding=padding,
|
||||
dilation=dilation, bias=bias, groups=groups)
|
||||
elif convtype=='Conv3D':
|
||||
|
||||
@@ -1,12 +1,12 @@
|
||||
import os
|
||||
import sys
|
||||
import threading
|
||||
import traceback
|
||||
|
||||
import time
|
||||
import git
|
||||
|
||||
from modules import shared
|
||||
from modules.paths_internal import extensions_dir, extensions_builtin_dir
|
||||
from modules.paths_internal import extensions_dir, extensions_builtin_dir, script_path # noqa: F401
|
||||
|
||||
extensions = []
|
||||
|
||||
@@ -24,6 +24,8 @@ def active():
|
||||
|
||||
|
||||
class Extension:
|
||||
lock = threading.Lock()
|
||||
|
||||
def __init__(self, name, path, enabled=True, is_builtin=False):
|
||||
self.name = name
|
||||
self.path = path
|
||||
@@ -31,16 +33,24 @@ class Extension:
|
||||
self.status = ''
|
||||
self.can_update = False
|
||||
self.is_builtin = is_builtin
|
||||
self.commit_hash = ''
|
||||
self.commit_date = None
|
||||
self.version = ''
|
||||
self.branch = None
|
||||
self.remote = None
|
||||
self.have_info_from_repo = False
|
||||
|
||||
def read_info_from_repo(self):
|
||||
if self.have_info_from_repo:
|
||||
if self.is_builtin or self.have_info_from_repo:
|
||||
return
|
||||
|
||||
self.have_info_from_repo = True
|
||||
with self.lock:
|
||||
if self.have_info_from_repo:
|
||||
return
|
||||
|
||||
self.do_read_info_from_repo()
|
||||
|
||||
def do_read_info_from_repo(self):
|
||||
repo = None
|
||||
try:
|
||||
if os.path.exists(os.path.join(self.path, ".git")):
|
||||
@@ -55,13 +65,18 @@ class Extension:
|
||||
try:
|
||||
self.status = 'unknown'
|
||||
self.remote = next(repo.remote().urls, None)
|
||||
head = repo.head.commit
|
||||
ts = time.asctime(time.gmtime(repo.head.commit.committed_date))
|
||||
self.version = f'{head.hexsha[:8]} ({ts})'
|
||||
self.commit_date = repo.head.commit.committed_date
|
||||
if repo.active_branch:
|
||||
self.branch = repo.active_branch.name
|
||||
self.commit_hash = repo.head.commit.hexsha
|
||||
self.version = repo.git.describe("--always", "--tags") # compared to `self.commit_hash[:8]` this takes about 30% more time total but since we run it in parallel we don't care
|
||||
|
||||
except Exception:
|
||||
except Exception as ex:
|
||||
print(f"Failed reading extension data from Git repository ({self.name}): {ex}", file=sys.stderr)
|
||||
self.remote = None
|
||||
|
||||
self.have_info_from_repo = True
|
||||
|
||||
def list_files(self, subdir, extension):
|
||||
from modules import scripts
|
||||
|
||||
@@ -82,18 +97,30 @@ class Extension:
|
||||
for fetch in repo.remote().fetch(dry_run=True):
|
||||
if fetch.flags != fetch.HEAD_UPTODATE:
|
||||
self.can_update = True
|
||||
self.status = "behind"
|
||||
self.status = "new commits"
|
||||
return
|
||||
|
||||
try:
|
||||
origin = repo.rev_parse('origin')
|
||||
if repo.head.commit != origin:
|
||||
self.can_update = True
|
||||
self.status = "behind HEAD"
|
||||
return
|
||||
except Exception:
|
||||
self.can_update = False
|
||||
self.status = "unknown (remote error)"
|
||||
return
|
||||
|
||||
self.can_update = False
|
||||
self.status = "latest"
|
||||
|
||||
def fetch_and_reset_hard(self):
|
||||
def fetch_and_reset_hard(self, commit='origin'):
|
||||
repo = git.Repo(self.path)
|
||||
# Fix: `error: Your local changes to the following files would be overwritten by merge`,
|
||||
# because WSL2 Docker set 755 file permissions instead of 644, this results to the error.
|
||||
repo.git.fetch(all=True)
|
||||
repo.git.reset('origin', hard=True)
|
||||
repo.git.reset(commit, hard=True)
|
||||
self.have_info_from_repo = False
|
||||
|
||||
|
||||
def list_extensions():
|
||||
|
||||
@@ -91,7 +91,7 @@ def deactivate(p, extra_network_data):
|
||||
"""call deactivate for extra networks in extra_network_data in specified order, then call
|
||||
deactivate for all remaining registered networks"""
|
||||
|
||||
for extra_network_name, extra_network_args in extra_network_data.items():
|
||||
for extra_network_name in extra_network_data:
|
||||
extra_network = extra_network_registry.get(extra_network_name, None)
|
||||
if extra_network is None:
|
||||
continue
|
||||
|
||||
@@ -1,4 +1,4 @@
|
||||
from modules import extra_networks, shared, extra_networks
|
||||
from modules import extra_networks, shared
|
||||
from modules.hypernetworks import hypernetwork
|
||||
|
||||
|
||||
@@ -9,8 +9,9 @@ class ExtraNetworkHypernet(extra_networks.ExtraNetwork):
|
||||
def activate(self, p, params_list):
|
||||
additional = shared.opts.sd_hypernetwork
|
||||
|
||||
if additional != "" and additional in shared.hypernetworks and len([x for x in params_list if x.items[0] == additional]) == 0:
|
||||
p.all_prompts = [x + f"<hypernet:{additional}:{shared.opts.extra_networks_default_multiplier}>" for x in p.all_prompts]
|
||||
if additional != "None" and additional in shared.hypernetworks and len([x for x in params_list if x.items[0] == additional]) == 0:
|
||||
hypernet_prompt_text = f"<hypernet:{additional}:{shared.opts.extra_networks_default_multiplier}>"
|
||||
p.all_prompts = [f"{prompt}{hypernet_prompt_text}" for prompt in p.all_prompts]
|
||||
params_list.append(extra_networks.ExtraNetworkParams(items=[additional, shared.opts.extra_networks_default_multiplier]))
|
||||
|
||||
names = []
|
||||
|
||||
@@ -1,6 +1,7 @@
|
||||
import os
|
||||
import re
|
||||
import shutil
|
||||
import json
|
||||
|
||||
|
||||
import torch
|
||||
@@ -71,7 +72,7 @@ def to_half(tensor, enable):
|
||||
return tensor
|
||||
|
||||
|
||||
def run_modelmerger(id_task, primary_model_name, secondary_model_name, tertiary_model_name, interp_method, multiplier, save_as_half, custom_name, checkpoint_format, config_source, bake_in_vae, discard_weights):
|
||||
def run_modelmerger(id_task, primary_model_name, secondary_model_name, tertiary_model_name, interp_method, multiplier, save_as_half, custom_name, checkpoint_format, config_source, bake_in_vae, discard_weights, save_metadata):
|
||||
shared.state.begin()
|
||||
shared.state.job = 'model-merge'
|
||||
|
||||
@@ -135,14 +136,14 @@ def run_modelmerger(id_task, primary_model_name, secondary_model_name, tertiary_
|
||||
result_is_instruct_pix2pix_model = False
|
||||
|
||||
if theta_func2:
|
||||
shared.state.textinfo = f"Loading B"
|
||||
shared.state.textinfo = "Loading B"
|
||||
print(f"Loading {secondary_model_info.filename}...")
|
||||
theta_1 = sd_models.read_state_dict(secondary_model_info.filename, map_location='cpu')
|
||||
else:
|
||||
theta_1 = None
|
||||
|
||||
if theta_func1:
|
||||
shared.state.textinfo = f"Loading C"
|
||||
shared.state.textinfo = "Loading C"
|
||||
print(f"Loading {tertiary_model_info.filename}...")
|
||||
theta_2 = sd_models.read_state_dict(tertiary_model_info.filename, map_location='cpu')
|
||||
|
||||
@@ -198,7 +199,7 @@ def run_modelmerger(id_task, primary_model_name, secondary_model_name, tertiary_
|
||||
result_is_inpainting_model = True
|
||||
else:
|
||||
theta_0[key] = theta_func2(a, b, multiplier)
|
||||
|
||||
|
||||
theta_0[key] = to_half(theta_0[key], save_as_half)
|
||||
|
||||
shared.state.sampling_step += 1
|
||||
@@ -241,13 +242,58 @@ def run_modelmerger(id_task, primary_model_name, secondary_model_name, tertiary_
|
||||
shared.state.textinfo = "Saving"
|
||||
print(f"Saving to {output_modelname}...")
|
||||
|
||||
metadata = None
|
||||
|
||||
if save_metadata:
|
||||
metadata = {"format": "pt"}
|
||||
|
||||
merge_recipe = {
|
||||
"type": "webui", # indicate this model was merged with webui's built-in merger
|
||||
"primary_model_hash": primary_model_info.sha256,
|
||||
"secondary_model_hash": secondary_model_info.sha256 if secondary_model_info else None,
|
||||
"tertiary_model_hash": tertiary_model_info.sha256 if tertiary_model_info else None,
|
||||
"interp_method": interp_method,
|
||||
"multiplier": multiplier,
|
||||
"save_as_half": save_as_half,
|
||||
"custom_name": custom_name,
|
||||
"config_source": config_source,
|
||||
"bake_in_vae": bake_in_vae,
|
||||
"discard_weights": discard_weights,
|
||||
"is_inpainting": result_is_inpainting_model,
|
||||
"is_instruct_pix2pix": result_is_instruct_pix2pix_model
|
||||
}
|
||||
metadata["sd_merge_recipe"] = json.dumps(merge_recipe)
|
||||
|
||||
sd_merge_models = {}
|
||||
|
||||
def add_model_metadata(checkpoint_info):
|
||||
checkpoint_info.calculate_shorthash()
|
||||
sd_merge_models[checkpoint_info.sha256] = {
|
||||
"name": checkpoint_info.name,
|
||||
"legacy_hash": checkpoint_info.hash,
|
||||
"sd_merge_recipe": checkpoint_info.metadata.get("sd_merge_recipe", None)
|
||||
}
|
||||
|
||||
sd_merge_models.update(checkpoint_info.metadata.get("sd_merge_models", {}))
|
||||
|
||||
add_model_metadata(primary_model_info)
|
||||
if secondary_model_info:
|
||||
add_model_metadata(secondary_model_info)
|
||||
if tertiary_model_info:
|
||||
add_model_metadata(tertiary_model_info)
|
||||
|
||||
metadata["sd_merge_models"] = json.dumps(sd_merge_models)
|
||||
|
||||
_, extension = os.path.splitext(output_modelname)
|
||||
if extension.lower() == ".safetensors":
|
||||
safetensors.torch.save_file(theta_0, output_modelname, metadata={"format": "pt"})
|
||||
safetensors.torch.save_file(theta_0, output_modelname, metadata=metadata)
|
||||
else:
|
||||
torch.save(theta_0, output_modelname)
|
||||
|
||||
sd_models.list_models()
|
||||
created_model = next((ckpt for ckpt in sd_models.checkpoints_list.values() if ckpt.name == filename), None)
|
||||
if created_model:
|
||||
created_model.calculate_shorthash()
|
||||
|
||||
create_config(output_modelname, config_source, primary_model_info, secondary_model_info, tertiary_model_info)
|
||||
|
||||
|
||||
@@ -1,15 +1,11 @@
|
||||
import base64
|
||||
import html
|
||||
import io
|
||||
import math
|
||||
import os
|
||||
import re
|
||||
from pathlib import Path
|
||||
|
||||
import gradio as gr
|
||||
from modules.paths import data_path
|
||||
from modules import shared, ui_tempdir, script_callbacks
|
||||
import tempfile
|
||||
from PIL import Image
|
||||
|
||||
re_param_code = r'\s*([\w ]+):\s*("(?:\\"[^,]|\\"|\\|[^\"])+"|[^,]*)(?:,|$)'
|
||||
@@ -23,14 +19,14 @@ registered_param_bindings = []
|
||||
|
||||
|
||||
class ParamBinding:
|
||||
def __init__(self, paste_button, tabname, source_text_component=None, source_image_component=None, source_tabname=None, override_settings_component=None, paste_field_names=[]):
|
||||
def __init__(self, paste_button, tabname, source_text_component=None, source_image_component=None, source_tabname=None, override_settings_component=None, paste_field_names=None):
|
||||
self.paste_button = paste_button
|
||||
self.tabname = tabname
|
||||
self.source_text_component = source_text_component
|
||||
self.source_image_component = source_image_component
|
||||
self.source_tabname = source_tabname
|
||||
self.override_settings_component = override_settings_component
|
||||
self.paste_field_names = paste_field_names
|
||||
self.paste_field_names = paste_field_names or []
|
||||
|
||||
|
||||
def reset():
|
||||
@@ -59,6 +55,7 @@ def image_from_url_text(filedata):
|
||||
is_in_right_dir = ui_tempdir.check_tmp_file(shared.demo, filename)
|
||||
assert is_in_right_dir, 'trying to open image file outside of allowed directories'
|
||||
|
||||
filename = filename.rsplit('?', 1)[0]
|
||||
return Image.open(filename)
|
||||
|
||||
if type(filedata) == list:
|
||||
@@ -129,6 +126,7 @@ def connect_paste_params_buttons():
|
||||
_js=jsfunc,
|
||||
inputs=[binding.source_image_component],
|
||||
outputs=[destination_image_component, destination_width_component, destination_height_component] if destination_width_component else [destination_image_component],
|
||||
show_progress=False,
|
||||
)
|
||||
|
||||
if binding.source_text_component is not None and fields is not None:
|
||||
@@ -140,6 +138,7 @@ def connect_paste_params_buttons():
|
||||
fn=lambda *x: x,
|
||||
inputs=[field for field, name in paste_fields[binding.source_tabname]["fields"] if name in paste_field_names],
|
||||
outputs=[field for field, name in fields if name in paste_field_names],
|
||||
show_progress=False,
|
||||
)
|
||||
|
||||
binding.paste_button.click(
|
||||
@@ -147,6 +146,7 @@ def connect_paste_params_buttons():
|
||||
_js=f"switch_to_{binding.tabname}",
|
||||
inputs=None,
|
||||
outputs=None,
|
||||
show_progress=False,
|
||||
)
|
||||
|
||||
|
||||
@@ -247,7 +247,7 @@ Steps: 20, Sampler: Euler a, CFG scale: 7, Seed: 965400086, Size: 512x512, Model
|
||||
lines.append(lastline)
|
||||
lastline = ''
|
||||
|
||||
for i, line in enumerate(lines):
|
||||
for line in lines:
|
||||
line = line.strip()
|
||||
if line.startswith("Negative prompt:"):
|
||||
done_with_prompt = True
|
||||
@@ -265,8 +265,8 @@ Steps: 20, Sampler: Euler a, CFG scale: 7, Seed: 965400086, Size: 512x512, Model
|
||||
v = v[1:-1] if v[0] == '"' and v[-1] == '"' else v
|
||||
m = re_imagesize.match(v)
|
||||
if m is not None:
|
||||
res[k+"-1"] = m.group(1)
|
||||
res[k+"-2"] = m.group(2)
|
||||
res[f"{k}-1"] = m.group(1)
|
||||
res[f"{k}-2"] = m.group(2)
|
||||
else:
|
||||
res[k] = v
|
||||
|
||||
@@ -284,6 +284,10 @@ Steps: 20, Sampler: Euler a, CFG scale: 7, Seed: 965400086, Size: 512x512, Model
|
||||
|
||||
restore_old_hires_fix_params(res)
|
||||
|
||||
# Missing RNG means the default was set, which is GPU RNG
|
||||
if "RNG" not in res:
|
||||
res["RNG"] = "GPU"
|
||||
|
||||
return res
|
||||
|
||||
|
||||
@@ -304,6 +308,10 @@ infotext_to_setting_name_mapping = [
|
||||
('UniPC skip type', 'uni_pc_skip_type'),
|
||||
('UniPC order', 'uni_pc_order'),
|
||||
('UniPC lower order final', 'uni_pc_lower_order_final'),
|
||||
('Token merging ratio', 'token_merging_ratio'),
|
||||
('Token merging ratio hr', 'token_merging_ratio_hr'),
|
||||
('RNG', 'randn_source'),
|
||||
('NGMS', 's_min_uncond'),
|
||||
]
|
||||
|
||||
|
||||
@@ -403,12 +411,14 @@ def connect_paste(button, paste_fields, input_comp, override_settings_component,
|
||||
fn=paste_func,
|
||||
inputs=[input_comp],
|
||||
outputs=[x[0] for x in paste_fields],
|
||||
show_progress=False,
|
||||
)
|
||||
button.click(
|
||||
fn=None,
|
||||
_js=f"recalculate_prompts_{tabname}",
|
||||
inputs=[],
|
||||
outputs=[],
|
||||
show_progress=False,
|
||||
)
|
||||
|
||||
|
||||
|
||||
@@ -78,7 +78,7 @@ def setup_model(dirname):
|
||||
|
||||
try:
|
||||
from gfpgan import GFPGANer
|
||||
from facexlib import detection, parsing
|
||||
from facexlib import detection, parsing # noqa: F401
|
||||
global user_path
|
||||
global have_gfpgan
|
||||
global gfpgan_constructor
|
||||
|
||||
@@ -13,7 +13,7 @@ cache_data = None
|
||||
|
||||
|
||||
def dump_cache():
|
||||
with filelock.FileLock(cache_filename+".lock"):
|
||||
with filelock.FileLock(f"{cache_filename}.lock"):
|
||||
with open(cache_filename, "w", encoding="utf8") as file:
|
||||
json.dump(cache_data, file, indent=4)
|
||||
|
||||
@@ -22,7 +22,7 @@ def cache(subsection):
|
||||
global cache_data
|
||||
|
||||
if cache_data is None:
|
||||
with filelock.FileLock(cache_filename+".lock"):
|
||||
with filelock.FileLock(f"{cache_filename}.lock"):
|
||||
if not os.path.isfile(cache_filename):
|
||||
cache_data = {}
|
||||
else:
|
||||
|
||||
@@ -1,4 +1,3 @@
|
||||
import csv
|
||||
import datetime
|
||||
import glob
|
||||
import html
|
||||
@@ -18,7 +17,7 @@ from modules.textual_inversion.learn_schedule import LearnRateScheduler
|
||||
from torch import einsum
|
||||
from torch.nn.init import normal_, xavier_normal_, xavier_uniform_, kaiming_normal_, kaiming_uniform_, zeros_
|
||||
|
||||
from collections import defaultdict, deque
|
||||
from collections import deque
|
||||
from statistics import stdev, mean
|
||||
|
||||
|
||||
@@ -178,34 +177,34 @@ class Hypernetwork:
|
||||
|
||||
def weights(self):
|
||||
res = []
|
||||
for k, layers in self.layers.items():
|
||||
for layers in self.layers.values():
|
||||
for layer in layers:
|
||||
res += layer.parameters()
|
||||
return res
|
||||
|
||||
def train(self, mode=True):
|
||||
for k, layers in self.layers.items():
|
||||
for layers in self.layers.values():
|
||||
for layer in layers:
|
||||
layer.train(mode=mode)
|
||||
for param in layer.parameters():
|
||||
param.requires_grad = mode
|
||||
|
||||
def to(self, device):
|
||||
for k, layers in self.layers.items():
|
||||
for layers in self.layers.values():
|
||||
for layer in layers:
|
||||
layer.to(device)
|
||||
|
||||
return self
|
||||
|
||||
def set_multiplier(self, multiplier):
|
||||
for k, layers in self.layers.items():
|
||||
for layers in self.layers.values():
|
||||
for layer in layers:
|
||||
layer.multiplier = multiplier
|
||||
|
||||
return self
|
||||
|
||||
def eval(self):
|
||||
for k, layers in self.layers.items():
|
||||
for layers in self.layers.values():
|
||||
for layer in layers:
|
||||
layer.eval()
|
||||
for param in layer.parameters():
|
||||
@@ -404,7 +403,7 @@ def attention_CrossAttention_forward(self, x, context=None, mask=None):
|
||||
k = self.to_k(context_k)
|
||||
v = self.to_v(context_v)
|
||||
|
||||
q, k, v = map(lambda t: rearrange(t, 'b n (h d) -> (b h) n d', h=h), (q, k, v))
|
||||
q, k, v = (rearrange(t, 'b n (h d) -> (b h) n d', h=h) for t in (q, k, v))
|
||||
|
||||
sim = einsum('b i d, b j d -> b i j', q, k) * self.scale
|
||||
|
||||
@@ -541,7 +540,7 @@ def train_hypernetwork(id_task, hypernetwork_name, learn_rate, batch_size, gradi
|
||||
return hypernetwork, filename
|
||||
|
||||
scheduler = LearnRateScheduler(learn_rate, steps, initial_step)
|
||||
|
||||
|
||||
clip_grad = torch.nn.utils.clip_grad_value_ if clip_grad_mode == "value" else torch.nn.utils.clip_grad_norm_ if clip_grad_mode == "norm" else None
|
||||
if clip_grad:
|
||||
clip_grad_sched = LearnRateScheduler(clip_grad_value, steps, initial_step, verbose=False)
|
||||
@@ -594,7 +593,7 @@ def train_hypernetwork(id_task, hypernetwork_name, learn_rate, batch_size, gradi
|
||||
print(e)
|
||||
|
||||
scaler = torch.cuda.amp.GradScaler()
|
||||
|
||||
|
||||
batch_size = ds.batch_size
|
||||
gradient_step = ds.gradient_step
|
||||
# n steps = batch_size * gradient_step * n image processed
|
||||
@@ -620,7 +619,7 @@ def train_hypernetwork(id_task, hypernetwork_name, learn_rate, batch_size, gradi
|
||||
try:
|
||||
sd_hijack_checkpoint.add()
|
||||
|
||||
for i in range((steps-initial_step) * gradient_step):
|
||||
for _ in range((steps-initial_step) * gradient_step):
|
||||
if scheduler.finished:
|
||||
break
|
||||
if shared.state.interrupted:
|
||||
@@ -637,7 +636,7 @@ def train_hypernetwork(id_task, hypernetwork_name, learn_rate, batch_size, gradi
|
||||
|
||||
if clip_grad:
|
||||
clip_grad_sched.step(hypernetwork.step)
|
||||
|
||||
|
||||
with devices.autocast():
|
||||
x = batch.latent_sample.to(devices.device, non_blocking=pin_memory)
|
||||
if use_weight:
|
||||
@@ -658,14 +657,14 @@ def train_hypernetwork(id_task, hypernetwork_name, learn_rate, batch_size, gradi
|
||||
|
||||
_loss_step += loss.item()
|
||||
scaler.scale(loss).backward()
|
||||
|
||||
|
||||
# go back until we reach gradient accumulation steps
|
||||
if (j + 1) % gradient_step != 0:
|
||||
continue
|
||||
loss_logging.append(_loss_step)
|
||||
if clip_grad:
|
||||
clip_grad(weights, clip_grad_sched.learn_rate)
|
||||
|
||||
|
||||
scaler.step(optimizer)
|
||||
scaler.update()
|
||||
hypernetwork.step += 1
|
||||
@@ -675,7 +674,7 @@ def train_hypernetwork(id_task, hypernetwork_name, learn_rate, batch_size, gradi
|
||||
_loss_step = 0
|
||||
|
||||
steps_done = hypernetwork.step + 1
|
||||
|
||||
|
||||
epoch_num = hypernetwork.step // steps_per_epoch
|
||||
epoch_step = hypernetwork.step % steps_per_epoch
|
||||
|
||||
|
||||
@@ -1,19 +1,17 @@
|
||||
import html
|
||||
import os
|
||||
import re
|
||||
|
||||
import gradio as gr
|
||||
import modules.hypernetworks.hypernetwork
|
||||
from modules import devices, sd_hijack, shared
|
||||
|
||||
not_available = ["hardswish", "multiheadattention"]
|
||||
keys = list(x for x in modules.hypernetworks.hypernetwork.HypernetworkModule.activation_dict.keys() if x not in not_available)
|
||||
keys = [x for x in modules.hypernetworks.hypernetwork.HypernetworkModule.activation_dict if x not in not_available]
|
||||
|
||||
|
||||
def create_hypernetwork(name, enable_sizes, overwrite_old, layer_structure=None, activation_func=None, weight_init=None, add_layer_norm=False, use_dropout=False, dropout_structure=None):
|
||||
filename = modules.hypernetworks.hypernetwork.create_hypernetwork(name, enable_sizes, overwrite_old, layer_structure, activation_func, weight_init, add_layer_norm, use_dropout, dropout_structure)
|
||||
|
||||
return gr.Dropdown.update(choices=sorted([x for x in shared.hypernetworks.keys()])), f"Created: {filename}", ""
|
||||
return gr.Dropdown.update(choices=sorted(shared.hypernetworks)), f"Created: {filename}", ""
|
||||
|
||||
|
||||
def train_hypernetwork(*args):
|
||||
|
||||
@@ -13,17 +13,24 @@ import numpy as np
|
||||
import piexif
|
||||
import piexif.helper
|
||||
from PIL import Image, ImageFont, ImageDraw, PngImagePlugin
|
||||
from fonts.ttf import Roboto
|
||||
import string
|
||||
import json
|
||||
import hashlib
|
||||
|
||||
from modules import sd_samplers, shared, script_callbacks, errors
|
||||
from modules.shared import opts, cmd_opts
|
||||
from modules.paths_internal import roboto_ttf_file
|
||||
from modules.shared import opts
|
||||
|
||||
LANCZOS = (Image.Resampling.LANCZOS if hasattr(Image, 'Resampling') else Image.LANCZOS)
|
||||
|
||||
|
||||
def get_font(fontsize: int):
|
||||
try:
|
||||
return ImageFont.truetype(opts.font or roboto_ttf_file, fontsize)
|
||||
except Exception:
|
||||
return ImageFont.truetype(roboto_ttf_file, fontsize)
|
||||
|
||||
|
||||
def image_grid(imgs, batch_size=1, rows=None):
|
||||
if rows is None:
|
||||
if opts.n_rows > 0:
|
||||
@@ -142,14 +149,8 @@ def draw_grid_annotations(im, width, height, hor_texts, ver_texts, margin=0):
|
||||
lines.append(word)
|
||||
return lines
|
||||
|
||||
def get_font(fontsize):
|
||||
try:
|
||||
return ImageFont.truetype(opts.font or Roboto, fontsize)
|
||||
except Exception:
|
||||
return ImageFont.truetype(Roboto, fontsize)
|
||||
|
||||
def draw_texts(drawing, draw_x, draw_y, lines, initial_fnt, initial_fontsize):
|
||||
for i, line in enumerate(lines):
|
||||
for line in lines:
|
||||
fnt = initial_fnt
|
||||
fontsize = initial_fontsize
|
||||
while drawing.multiline_textsize(line.text, font=fnt)[0] > line.allowed_width and fontsize > 0:
|
||||
@@ -318,6 +319,7 @@ re_nonletters = re.compile(r'[\s' + string.punctuation + ']+')
|
||||
re_pattern = re.compile(r"(.*?)(?:\[([^\[\]]+)\]|$)")
|
||||
re_pattern_arg = re.compile(r"(.*)<([^>]*)>$")
|
||||
max_filename_part_length = 128
|
||||
NOTHING_AND_SKIP_PREVIOUS_TEXT = object()
|
||||
|
||||
|
||||
def sanitize_filename_part(text, replace_spaces=True):
|
||||
@@ -352,6 +354,11 @@ class FilenameGenerator:
|
||||
'prompt_no_styles': lambda self: self.prompt_no_style(),
|
||||
'prompt_spaces': lambda self: sanitize_filename_part(self.prompt, replace_spaces=False),
|
||||
'prompt_words': lambda self: self.prompt_words(),
|
||||
'batch_number': lambda self: NOTHING_AND_SKIP_PREVIOUS_TEXT if self.p.batch_size == 1 else self.p.batch_index + 1,
|
||||
'generation_number': lambda self: NOTHING_AND_SKIP_PREVIOUS_TEXT if self.p.n_iter == 1 and self.p.batch_size == 1 else self.p.iteration * self.p.batch_size + self.p.batch_index + 1,
|
||||
'hasprompt': lambda self, *args: self.hasprompt(*args), # accepts formats:[hasprompt<prompt1|default><prompt2>..]
|
||||
'clip_skip': lambda self: opts.data["CLIP_stop_at_last_layers"],
|
||||
'denoising': lambda self: self.p.denoising_strength if self.p and self.p.denoising_strength else NOTHING_AND_SKIP_PREVIOUS_TEXT,
|
||||
}
|
||||
default_time_format = '%Y%m%d%H%M%S'
|
||||
|
||||
@@ -361,6 +368,22 @@ class FilenameGenerator:
|
||||
self.prompt = prompt
|
||||
self.image = image
|
||||
|
||||
def hasprompt(self, *args):
|
||||
lower = self.prompt.lower()
|
||||
if self.p is None or self.prompt is None:
|
||||
return None
|
||||
outres = ""
|
||||
for arg in args:
|
||||
if arg != "":
|
||||
division = arg.split("|")
|
||||
expected = division[0].lower()
|
||||
default = division[1] if len(division) > 1 else ""
|
||||
if lower.find(expected) >= 0:
|
||||
outres = f'{outres}{expected}'
|
||||
else:
|
||||
outres = outres if default == "" else f'{outres}{default}'
|
||||
return sanitize_filename_part(outres)
|
||||
|
||||
def prompt_no_style(self):
|
||||
if self.p is None or self.prompt is None:
|
||||
return None
|
||||
@@ -387,13 +410,13 @@ class FilenameGenerator:
|
||||
time_format = args[0] if len(args) > 0 and args[0] != "" else self.default_time_format
|
||||
try:
|
||||
time_zone = pytz.timezone(args[1]) if len(args) > 1 else None
|
||||
except pytz.exceptions.UnknownTimeZoneError as _:
|
||||
except pytz.exceptions.UnknownTimeZoneError:
|
||||
time_zone = None
|
||||
|
||||
time_zone_time = time_datetime.astimezone(time_zone)
|
||||
try:
|
||||
formatted_time = time_zone_time.strftime(time_format)
|
||||
except (ValueError, TypeError) as _:
|
||||
except (ValueError, TypeError):
|
||||
formatted_time = time_zone_time.strftime(self.default_time_format)
|
||||
|
||||
return sanitize_filename_part(formatted_time, replace_spaces=False)
|
||||
@@ -403,9 +426,9 @@ class FilenameGenerator:
|
||||
|
||||
for m in re_pattern.finditer(x):
|
||||
text, pattern = m.groups()
|
||||
res += text
|
||||
|
||||
if pattern is None:
|
||||
res += text
|
||||
continue
|
||||
|
||||
pattern_args = []
|
||||
@@ -426,11 +449,13 @@ class FilenameGenerator:
|
||||
print(f"Error adding [{pattern}] to filename", file=sys.stderr)
|
||||
print(traceback.format_exc(), file=sys.stderr)
|
||||
|
||||
if replacement is not None:
|
||||
res += str(replacement)
|
||||
if replacement == NOTHING_AND_SKIP_PREVIOUS_TEXT:
|
||||
continue
|
||||
elif replacement is not None:
|
||||
res += text + str(replacement)
|
||||
continue
|
||||
|
||||
res += f'[{pattern}]'
|
||||
res += f'{text}[{pattern}]'
|
||||
|
||||
return res
|
||||
|
||||
@@ -443,20 +468,57 @@ def get_next_sequence_number(path, basename):
|
||||
"""
|
||||
result = -1
|
||||
if basename != '':
|
||||
basename = basename + "-"
|
||||
basename = f"{basename}-"
|
||||
|
||||
prefix_length = len(basename)
|
||||
for p in os.listdir(path):
|
||||
if p.startswith(basename):
|
||||
l = os.path.splitext(p[prefix_length:])[0].split('-') # splits the filename (removing the basename first if one is defined, so the sequence number is always the first element)
|
||||
parts = os.path.splitext(p[prefix_length:])[0].split('-') # splits the filename (removing the basename first if one is defined, so the sequence number is always the first element)
|
||||
try:
|
||||
result = max(int(l[0]), result)
|
||||
result = max(int(parts[0]), result)
|
||||
except ValueError:
|
||||
pass
|
||||
|
||||
return result + 1
|
||||
|
||||
|
||||
def save_image_with_geninfo(image, geninfo, filename, extension=None, existing_pnginfo=None):
|
||||
if extension is None:
|
||||
extension = os.path.splitext(filename)[1]
|
||||
|
||||
image_format = Image.registered_extensions()[extension]
|
||||
|
||||
existing_pnginfo = existing_pnginfo or {}
|
||||
if opts.enable_pnginfo:
|
||||
existing_pnginfo['parameters'] = geninfo
|
||||
|
||||
if extension.lower() == '.png':
|
||||
pnginfo_data = PngImagePlugin.PngInfo()
|
||||
for k, v in (existing_pnginfo or {}).items():
|
||||
pnginfo_data.add_text(k, str(v))
|
||||
|
||||
image.save(filename, format=image_format, quality=opts.jpeg_quality, pnginfo=pnginfo_data)
|
||||
|
||||
elif extension.lower() in (".jpg", ".jpeg", ".webp"):
|
||||
if image.mode == 'RGBA':
|
||||
image = image.convert("RGB")
|
||||
elif image.mode == 'I;16':
|
||||
image = image.point(lambda p: p * 0.0038910505836576).convert("RGB" if extension.lower() == ".webp" else "L")
|
||||
|
||||
image.save(filename, format=image_format, quality=opts.jpeg_quality, lossless=opts.webp_lossless)
|
||||
|
||||
if opts.enable_pnginfo and geninfo is not None:
|
||||
exif_bytes = piexif.dump({
|
||||
"Exif": {
|
||||
piexif.ExifIFD.UserComment: piexif.helper.UserComment.dump(geninfo or "", encoding="unicode")
|
||||
},
|
||||
})
|
||||
|
||||
piexif.insert(exif_bytes, filename)
|
||||
else:
|
||||
image.save(filename, format=image_format, quality=opts.jpeg_quality)
|
||||
|
||||
|
||||
def save_image(image, path, basename, seed=None, prompt=None, extension='png', info=None, short_filename=False, no_prompt=False, grid=False, pnginfo_section_name='parameters', p=None, existing_info=None, forced_filename=None, suffix="", save_to_dirs=None):
|
||||
"""Save an image.
|
||||
|
||||
@@ -512,7 +574,7 @@ def save_image(image, path, basename, seed=None, prompt=None, extension='png', i
|
||||
add_number = opts.save_images_add_number or file_decoration == ''
|
||||
|
||||
if file_decoration != "" and add_number:
|
||||
file_decoration = "-" + file_decoration
|
||||
file_decoration = f"-{file_decoration}"
|
||||
|
||||
file_decoration = namegen.apply(file_decoration) + suffix
|
||||
|
||||
@@ -541,38 +603,13 @@ def save_image(image, path, basename, seed=None, prompt=None, extension='png', i
|
||||
info = params.pnginfo.get(pnginfo_section_name, None)
|
||||
|
||||
def _atomically_save_image(image_to_save, filename_without_extension, extension):
|
||||
# save image with .tmp extension to avoid race condition when another process detects new image in the directory
|
||||
temp_file_path = filename_without_extension + ".tmp"
|
||||
image_format = Image.registered_extensions()[extension]
|
||||
"""
|
||||
save image with .tmp extension to avoid race condition when another process detects new image in the directory
|
||||
"""
|
||||
temp_file_path = f"{filename_without_extension}.tmp"
|
||||
|
||||
if extension.lower() == '.png':
|
||||
pnginfo_data = PngImagePlugin.PngInfo()
|
||||
if opts.enable_pnginfo:
|
||||
for k, v in params.pnginfo.items():
|
||||
pnginfo_data.add_text(k, str(v))
|
||||
save_image_with_geninfo(image_to_save, info, temp_file_path, extension, params.pnginfo)
|
||||
|
||||
image_to_save.save(temp_file_path, format=image_format, quality=opts.jpeg_quality, pnginfo=pnginfo_data)
|
||||
|
||||
elif extension.lower() in (".jpg", ".jpeg", ".webp"):
|
||||
if image_to_save.mode == 'RGBA':
|
||||
image_to_save = image_to_save.convert("RGB")
|
||||
elif image_to_save.mode == 'I;16':
|
||||
image_to_save = image_to_save.point(lambda p: p * 0.0038910505836576).convert("RGB" if extension.lower() == ".webp" else "L")
|
||||
|
||||
image_to_save.save(temp_file_path, format=image_format, quality=opts.jpeg_quality, lossless=opts.webp_lossless)
|
||||
|
||||
if opts.enable_pnginfo and info is not None:
|
||||
exif_bytes = piexif.dump({
|
||||
"Exif": {
|
||||
piexif.ExifIFD.UserComment: piexif.helper.UserComment.dump(info or "", encoding="unicode")
|
||||
},
|
||||
})
|
||||
|
||||
piexif.insert(exif_bytes, temp_file_path)
|
||||
else:
|
||||
image_to_save.save(temp_file_path, format=image_format, quality=opts.jpeg_quality)
|
||||
|
||||
# atomically rename the file with correct extension
|
||||
os.replace(temp_file_path, filename_without_extension + extension)
|
||||
|
||||
fullfn_without_extension, extension = os.path.splitext(params.filename)
|
||||
@@ -602,7 +639,7 @@ def save_image(image, path, basename, seed=None, prompt=None, extension='png', i
|
||||
if opts.save_txt and info is not None:
|
||||
txt_fullfn = f"{fullfn_without_extension}.txt"
|
||||
with open(txt_fullfn, "w", encoding="utf8") as file:
|
||||
file.write(info + "\n")
|
||||
file.write(f"{info}\n")
|
||||
else:
|
||||
txt_fullfn = None
|
||||
|
||||
|
||||
@@ -1,19 +1,15 @@
|
||||
import math
|
||||
import os
|
||||
import sys
|
||||
import traceback
|
||||
|
||||
import numpy as np
|
||||
from PIL import Image, ImageOps, ImageFilter, ImageEnhance, ImageChops
|
||||
from PIL import Image, ImageOps, ImageFilter, ImageEnhance, ImageChops, UnidentifiedImageError
|
||||
|
||||
from modules import devices, sd_samplers
|
||||
from modules import sd_samplers
|
||||
from modules.generation_parameters_copypaste import create_override_settings_dict
|
||||
from modules.processing import Processed, StableDiffusionProcessingImg2Img, process_images
|
||||
from modules.shared import opts, state
|
||||
import modules.shared as shared
|
||||
import modules.processing as processing
|
||||
from modules.ui import plaintext_to_html
|
||||
import modules.images as images
|
||||
import modules.scripts
|
||||
|
||||
|
||||
@@ -46,7 +42,11 @@ def process_batch(p, input_dir, output_dir, inpaint_mask_dir, args):
|
||||
if state.interrupted:
|
||||
break
|
||||
|
||||
img = Image.open(image)
|
||||
try:
|
||||
img = Image.open(image)
|
||||
except UnidentifiedImageError as e:
|
||||
print(e)
|
||||
continue
|
||||
# Use the EXIF orientation of photos taken by smartphones.
|
||||
img = ImageOps.exif_transpose(img)
|
||||
p.init_images = [img] * p.batch_size
|
||||
@@ -55,7 +55,7 @@ def process_batch(p, input_dir, output_dir, inpaint_mask_dir, args):
|
||||
# try to find corresponding mask for an image using simple filename matching
|
||||
mask_image_path = os.path.join(inpaint_mask_dir, os.path.basename(image))
|
||||
# if not found use first one ("same mask for all images" use-case)
|
||||
if not mask_image_path in inpaint_masks:
|
||||
if mask_image_path not in inpaint_masks:
|
||||
mask_image_path = inpaint_masks[0]
|
||||
mask_image = Image.open(mask_image_path)
|
||||
p.image_mask = mask_image
|
||||
@@ -78,7 +78,7 @@ def process_batch(p, input_dir, output_dir, inpaint_mask_dir, args):
|
||||
processed_image.save(os.path.join(output_dir, filename))
|
||||
|
||||
|
||||
def img2img(id_task: str, mode: int, prompt: str, negative_prompt: str, prompt_styles, init_img, sketch, init_img_with_mask, inpaint_color_sketch, inpaint_color_sketch_orig, init_img_inpaint, init_mask_inpaint, steps: int, sampler_index: int, mask_blur: int, mask_alpha: float, inpainting_fill: int, restore_faces: bool, tiling: bool, n_iter: int, batch_size: int, cfg_scale: float, image_cfg_scale: float, denoising_strength: float, seed: int, subseed: int, subseed_strength: float, seed_resize_from_h: int, seed_resize_from_w: int, seed_enable_extras: bool, height: int, width: int, resize_mode: int, inpaint_full_res: bool, inpaint_full_res_padding: int, inpainting_mask_invert: int, img2img_batch_input_dir: str, img2img_batch_output_dir: str, img2img_batch_inpaint_mask_dir: str, override_settings_texts, *args):
|
||||
def img2img(id_task: str, mode: int, prompt: str, negative_prompt: str, prompt_styles, init_img, sketch, init_img_with_mask, inpaint_color_sketch, inpaint_color_sketch_orig, init_img_inpaint, init_mask_inpaint, steps: int, sampler_index: int, mask_blur: int, mask_alpha: float, inpainting_fill: int, restore_faces: bool, tiling: bool, n_iter: int, batch_size: int, cfg_scale: float, image_cfg_scale: float, denoising_strength: float, seed: int, subseed: int, subseed_strength: float, seed_resize_from_h: int, seed_resize_from_w: int, seed_enable_extras: bool, selected_scale_tab: int, height: int, width: int, scale_by: float, resize_mode: int, inpaint_full_res: bool, inpaint_full_res_padding: int, inpainting_mask_invert: int, img2img_batch_input_dir: str, img2img_batch_output_dir: str, img2img_batch_inpaint_mask_dir: str, override_settings_texts, *args):
|
||||
override_settings = create_override_settings_dict(override_settings_texts)
|
||||
|
||||
is_batch = mode == 5
|
||||
@@ -114,6 +114,12 @@ def img2img(id_task: str, mode: int, prompt: str, negative_prompt: str, prompt_s
|
||||
if image is not None:
|
||||
image = ImageOps.exif_transpose(image)
|
||||
|
||||
if selected_scale_tab == 1:
|
||||
assert image, "Can't scale by because no image is selected"
|
||||
|
||||
width = int(image.width * scale_by)
|
||||
height = int(image.height * scale_by)
|
||||
|
||||
assert 0. <= denoising_strength <= 1., 'can only work with strength in [0.0, 1.0]'
|
||||
|
||||
p = StableDiffusionProcessingImg2Img(
|
||||
@@ -151,7 +157,7 @@ def img2img(id_task: str, mode: int, prompt: str, negative_prompt: str, prompt_s
|
||||
override_settings=override_settings,
|
||||
)
|
||||
|
||||
p.scripts = modules.scripts.scripts_txt2img
|
||||
p.scripts = modules.scripts.scripts_img2img
|
||||
p.script_args = args
|
||||
|
||||
if shared.cmd_opts.enable_console_prompts:
|
||||
|
||||
@@ -11,7 +11,6 @@ import torch.hub
|
||||
from torchvision import transforms
|
||||
from torchvision.transforms.functional import InterpolationMode
|
||||
|
||||
import modules.shared as shared
|
||||
from modules import devices, paths, shared, lowvram, modelloader, errors
|
||||
|
||||
blip_image_eval_size = 384
|
||||
@@ -28,11 +27,11 @@ def category_types():
|
||||
def download_default_clip_interrogate_categories(content_dir):
|
||||
print("Downloading CLIP categories...")
|
||||
|
||||
tmpdir = content_dir + "_tmp"
|
||||
tmpdir = f"{content_dir}_tmp"
|
||||
category_types = ["artists", "flavors", "mediums", "movements"]
|
||||
|
||||
try:
|
||||
os.makedirs(tmpdir)
|
||||
os.makedirs(tmpdir, exist_ok=True)
|
||||
for category_type in category_types:
|
||||
torch.hub.download_url_to_file(f"https://raw.githubusercontent.com/pharmapsychotic/clip-interrogator/main/clip_interrogator/data/{category_type}.txt", os.path.join(tmpdir, f"{category_type}.txt"))
|
||||
os.rename(tmpdir, content_dir)
|
||||
@@ -41,7 +40,7 @@ def download_default_clip_interrogate_categories(content_dir):
|
||||
errors.display(e, "downloading default CLIP interrogate categories")
|
||||
finally:
|
||||
if os.path.exists(tmpdir):
|
||||
os.remove(tmpdir)
|
||||
os.removedirs(tmpdir)
|
||||
|
||||
|
||||
class InterrogateModels:
|
||||
@@ -160,7 +159,7 @@ class InterrogateModels:
|
||||
text_array = text_array[0:int(shared.opts.interrogate_clip_dict_limit)]
|
||||
|
||||
top_count = min(top_count, len(text_array))
|
||||
text_tokens = clip.tokenize([text for text in text_array], truncate=True).to(devices.device_interrogate)
|
||||
text_tokens = clip.tokenize(list(text_array), truncate=True).to(devices.device_interrogate)
|
||||
text_features = self.clip_model.encode_text(text_tokens).type(self.dtype)
|
||||
text_features /= text_features.norm(dim=-1, keepdim=True)
|
||||
|
||||
@@ -208,13 +207,13 @@ class InterrogateModels:
|
||||
|
||||
image_features /= image_features.norm(dim=-1, keepdim=True)
|
||||
|
||||
for name, topn, items in self.categories():
|
||||
matches = self.rank(image_features, items, top_count=topn)
|
||||
for cat in self.categories():
|
||||
matches = self.rank(image_features, cat.items, top_count=cat.topn)
|
||||
for match, score in matches:
|
||||
if shared.opts.interrogate_return_ranks:
|
||||
res += f", ({match}:{score/100:.3f})"
|
||||
else:
|
||||
res += ", " + match
|
||||
res += f", {match}"
|
||||
|
||||
except Exception:
|
||||
print("Error interrogating", file=sys.stderr)
|
||||
|
||||
@@ -23,7 +23,7 @@ def list_localizations(dirname):
|
||||
localizations[fn] = file.path
|
||||
|
||||
|
||||
def localization_js(current_localization_name):
|
||||
def localization_js(current_localization_name: str) -> str:
|
||||
fn = localizations.get(current_localization_name, None)
|
||||
data = {}
|
||||
if fn is not None:
|
||||
@@ -34,4 +34,4 @@ def localization_js(current_localization_name):
|
||||
print(f"Error loading localization from {fn}:", file=sys.stderr)
|
||||
print(traceback.format_exc(), file=sys.stderr)
|
||||
|
||||
return f"var localization = {json.dumps(data)}\n"
|
||||
return f"window.localization = {json.dumps(data)}"
|
||||
|
||||
@@ -1,6 +1,5 @@
|
||||
import torch
|
||||
import platform
|
||||
from modules import paths
|
||||
from modules.sd_hijack_utils import CondFunc
|
||||
from packaging import version
|
||||
|
||||
@@ -43,7 +42,7 @@ if has_mps:
|
||||
# MPS workaround for https://github.com/pytorch/pytorch/issues/79383
|
||||
CondFunc('torch.Tensor.to', lambda orig_func, self, *args, **kwargs: orig_func(self.contiguous(), *args, **kwargs),
|
||||
lambda _, self, *args, **kwargs: self.device.type != 'mps' and (args and isinstance(args[0], torch.device) and args[0].type == 'mps' or isinstance(kwargs.get('device'), torch.device) and kwargs['device'].type == 'mps'))
|
||||
# MPS workaround for https://github.com/pytorch/pytorch/issues/80800
|
||||
# MPS workaround for https://github.com/pytorch/pytorch/issues/80800
|
||||
CondFunc('torch.nn.functional.layer_norm', lambda orig_func, *args, **kwargs: orig_func(*([args[0].contiguous()] + list(args[1:])), **kwargs),
|
||||
lambda _, *args, **kwargs: args and isinstance(args[0], torch.Tensor) and args[0].device.type == 'mps')
|
||||
# MPS workaround for https://github.com/pytorch/pytorch/issues/90532
|
||||
@@ -54,6 +53,11 @@ if has_mps:
|
||||
CondFunc('torch.cumsum', cumsum_fix_func, None)
|
||||
CondFunc('torch.Tensor.cumsum', cumsum_fix_func, None)
|
||||
CondFunc('torch.narrow', lambda orig_func, *args, **kwargs: orig_func(*args, **kwargs).clone(), None)
|
||||
if version.parse(torch.__version__) == version.parse("2.0"):
|
||||
|
||||
# MPS workaround for https://github.com/pytorch/pytorch/issues/96113
|
||||
CondFunc('torch.nn.functional.layer_norm', lambda orig_func, x, normalized_shape, weight, bias, eps, **kwargs: orig_func(x.float(), normalized_shape, weight.float() if weight is not None else None, bias.float() if bias is not None else bias, eps).to(x.dtype), lambda *args, **kwargs: len(args) == 6)
|
||||
CondFunc('torch.nn.functional.layer_norm', lambda orig_func, x, normalized_shape, weight, bias, eps, **kwargs: orig_func(x.float(), normalized_shape, weight.float() if weight is not None else None, bias.float() if bias is not None else bias, eps).to(x.dtype), lambda _, input, *args, **kwargs: len(args) == 4 and input.device.type == 'mps')
|
||||
|
||||
# MPS workaround for https://github.com/pytorch/pytorch/issues/92311
|
||||
if platform.processor() == 'i386':
|
||||
for funcName in ['torch.argmax', 'torch.Tensor.argmax']:
|
||||
CondFunc(funcName, lambda _, input, *args, **kwargs: torch.max(input.float() if input.dtype == torch.int64 else input, *args, **kwargs)[1], lambda _, input, *args, **kwargs: input.device.type == 'mps')
|
||||
|
||||
@@ -4,7 +4,7 @@ from PIL import Image, ImageFilter, ImageOps
|
||||
def get_crop_region(mask, pad=0):
|
||||
"""finds a rectangular region that contains all masked ares in an image. Returns (x1, y1, x2, y2) coordinates of the rectangle.
|
||||
For example, if a user has painted the top-right part of a 512x512 image", the result may be (256, 0, 512, 256)"""
|
||||
|
||||
|
||||
h, w = mask.shape
|
||||
|
||||
crop_left = 0
|
||||
|
||||
@@ -1,4 +1,3 @@
|
||||
import glob
|
||||
import os
|
||||
import shutil
|
||||
import importlib
|
||||
@@ -22,9 +21,6 @@ def load_models(model_path: str, model_url: str = None, command_path: str = None
|
||||
"""
|
||||
output = []
|
||||
|
||||
if ext_filter is None:
|
||||
ext_filter = []
|
||||
|
||||
try:
|
||||
places = []
|
||||
|
||||
@@ -39,22 +35,14 @@ def load_models(model_path: str, model_url: str = None, command_path: str = None
|
||||
places.append(model_path)
|
||||
|
||||
for place in places:
|
||||
if os.path.exists(place):
|
||||
for file in glob.iglob(place + '**/**', recursive=True):
|
||||
full_path = file
|
||||
if os.path.isdir(full_path):
|
||||
continue
|
||||
if os.path.islink(full_path) and not os.path.exists(full_path):
|
||||
print(f"Skipping broken symlink: {full_path}")
|
||||
continue
|
||||
if ext_blacklist is not None and any([full_path.endswith(x) for x in ext_blacklist]):
|
||||
continue
|
||||
if len(ext_filter) != 0:
|
||||
model_name, extension = os.path.splitext(file)
|
||||
if extension not in ext_filter:
|
||||
continue
|
||||
if file not in output:
|
||||
output.append(full_path)
|
||||
for full_path in shared.walk_files(place, allowed_extensions=ext_filter):
|
||||
if os.path.islink(full_path) and not os.path.exists(full_path):
|
||||
print(f"Skipping broken symlink: {full_path}")
|
||||
continue
|
||||
if ext_blacklist is not None and any(full_path.endswith(x) for x in ext_blacklist):
|
||||
continue
|
||||
if full_path not in output:
|
||||
output.append(full_path)
|
||||
|
||||
if model_url is not None and len(output) == 0:
|
||||
if download_name is not None:
|
||||
@@ -119,32 +107,15 @@ def move_files(src_path: str, dest_path: str, ext_filter: str = None):
|
||||
print(f"Moving {file} from {src_path} to {dest_path}.")
|
||||
try:
|
||||
shutil.move(fullpath, dest_path)
|
||||
except:
|
||||
except Exception:
|
||||
pass
|
||||
if len(os.listdir(src_path)) == 0:
|
||||
print(f"Removing empty folder: {src_path}")
|
||||
shutil.rmtree(src_path, True)
|
||||
except:
|
||||
except Exception:
|
||||
pass
|
||||
|
||||
|
||||
builtin_upscaler_classes = []
|
||||
forbidden_upscaler_classes = set()
|
||||
|
||||
|
||||
def list_builtin_upscalers():
|
||||
load_upscalers()
|
||||
|
||||
builtin_upscaler_classes.clear()
|
||||
builtin_upscaler_classes.extend(Upscaler.__subclasses__())
|
||||
|
||||
|
||||
def forbid_loaded_nonbuiltin_upscalers():
|
||||
for cls in Upscaler.__subclasses__():
|
||||
if cls not in builtin_upscaler_classes:
|
||||
forbidden_upscaler_classes.add(cls)
|
||||
|
||||
|
||||
def load_upscalers():
|
||||
# We can only do this 'magic' method to dynamically load upscalers if they are referenced,
|
||||
# so we'll try to import any _model.py files before looking in __subclasses__
|
||||
@@ -155,15 +126,22 @@ def load_upscalers():
|
||||
full_model = f"modules.{model_name}_model"
|
||||
try:
|
||||
importlib.import_module(full_model)
|
||||
except:
|
||||
except Exception:
|
||||
pass
|
||||
|
||||
datas = []
|
||||
commandline_options = vars(shared.cmd_opts)
|
||||
for cls in Upscaler.__subclasses__():
|
||||
if cls in forbidden_upscaler_classes:
|
||||
continue
|
||||
|
||||
# some of upscaler classes will not go away after reloading their modules, and we'll end
|
||||
# up with two copies of those classes. The newest copy will always be the last in the list,
|
||||
# so we go from end to beginning and ignore duplicates
|
||||
used_classes = {}
|
||||
for cls in reversed(Upscaler.__subclasses__()):
|
||||
classname = str(cls)
|
||||
if classname not in used_classes:
|
||||
used_classes[classname] = cls
|
||||
|
||||
for cls in reversed(used_classes.values()):
|
||||
name = cls.__name__
|
||||
cmd_name = f"{name.lower().replace('upscaler', '')}_models_path"
|
||||
scaler = cls(commandline_options.get(cmd_name, None))
|
||||
|
||||
@@ -52,7 +52,7 @@ class DDPM(pl.LightningModule):
|
||||
beta_schedule="linear",
|
||||
loss_type="l2",
|
||||
ckpt_path=None,
|
||||
ignore_keys=[],
|
||||
ignore_keys=None,
|
||||
load_only_unet=False,
|
||||
monitor="val/loss",
|
||||
use_ema=True,
|
||||
@@ -107,7 +107,7 @@ class DDPM(pl.LightningModule):
|
||||
print(f"Keeping EMAs of {len(list(self.model_ema.buffers()))}.")
|
||||
|
||||
if ckpt_path is not None:
|
||||
self.init_from_ckpt(ckpt_path, ignore_keys=ignore_keys, only_model=load_only_unet)
|
||||
self.init_from_ckpt(ckpt_path, ignore_keys=ignore_keys or [], only_model=load_only_unet)
|
||||
|
||||
# If initialing from EMA-only checkpoint, create EMA model after loading.
|
||||
if self.use_ema and not load_ema:
|
||||
@@ -194,7 +194,9 @@ class DDPM(pl.LightningModule):
|
||||
if context is not None:
|
||||
print(f"{context}: Restored training weights")
|
||||
|
||||
def init_from_ckpt(self, path, ignore_keys=list(), only_model=False):
|
||||
def init_from_ckpt(self, path, ignore_keys=None, only_model=False):
|
||||
ignore_keys = ignore_keys or []
|
||||
|
||||
sd = torch.load(path, map_location="cpu")
|
||||
if "state_dict" in list(sd.keys()):
|
||||
sd = sd["state_dict"]
|
||||
@@ -223,7 +225,7 @@ class DDPM(pl.LightningModule):
|
||||
for k in keys:
|
||||
for ik in ignore_keys:
|
||||
if k.startswith(ik):
|
||||
print("Deleting key {} from state_dict.".format(k))
|
||||
print(f"Deleting key {k} from state_dict.")
|
||||
del sd[k]
|
||||
missing, unexpected = self.load_state_dict(sd, strict=False) if not only_model else self.model.load_state_dict(
|
||||
sd, strict=False)
|
||||
@@ -386,7 +388,7 @@ class DDPM(pl.LightningModule):
|
||||
_, loss_dict_no_ema = self.shared_step(batch)
|
||||
with self.ema_scope():
|
||||
_, loss_dict_ema = self.shared_step(batch)
|
||||
loss_dict_ema = {key + '_ema': loss_dict_ema[key] for key in loss_dict_ema}
|
||||
loss_dict_ema = {f"{key}_ema": loss_dict_ema[key] for key in loss_dict_ema}
|
||||
self.log_dict(loss_dict_no_ema, prog_bar=False, logger=True, on_step=False, on_epoch=True)
|
||||
self.log_dict(loss_dict_ema, prog_bar=False, logger=True, on_step=False, on_epoch=True)
|
||||
|
||||
@@ -403,7 +405,7 @@ class DDPM(pl.LightningModule):
|
||||
|
||||
@torch.no_grad()
|
||||
def log_images(self, batch, N=8, n_row=2, sample=True, return_keys=None, **kwargs):
|
||||
log = dict()
|
||||
log = {}
|
||||
x = self.get_input(batch, self.first_stage_key)
|
||||
N = min(x.shape[0], N)
|
||||
n_row = min(x.shape[0], n_row)
|
||||
@@ -411,7 +413,7 @@ class DDPM(pl.LightningModule):
|
||||
log["inputs"] = x
|
||||
|
||||
# get diffusion row
|
||||
diffusion_row = list()
|
||||
diffusion_row = []
|
||||
x_start = x[:n_row]
|
||||
|
||||
for t in range(self.num_timesteps):
|
||||
@@ -473,13 +475,13 @@ class LatentDiffusion(DDPM):
|
||||
conditioning_key = None
|
||||
ckpt_path = kwargs.pop("ckpt_path", None)
|
||||
ignore_keys = kwargs.pop("ignore_keys", [])
|
||||
super().__init__(conditioning_key=conditioning_key, *args, load_ema=load_ema, **kwargs)
|
||||
super().__init__(*args, conditioning_key=conditioning_key, load_ema=load_ema, **kwargs)
|
||||
self.concat_mode = concat_mode
|
||||
self.cond_stage_trainable = cond_stage_trainable
|
||||
self.cond_stage_key = cond_stage_key
|
||||
try:
|
||||
self.num_downs = len(first_stage_config.params.ddconfig.ch_mult) - 1
|
||||
except:
|
||||
except Exception:
|
||||
self.num_downs = 0
|
||||
if not scale_by_std:
|
||||
self.scale_factor = scale_factor
|
||||
@@ -891,16 +893,6 @@ class LatentDiffusion(DDPM):
|
||||
c = self.q_sample(x_start=c, t=tc, noise=torch.randn_like(c.float()))
|
||||
return self.p_losses(x, c, t, *args, **kwargs)
|
||||
|
||||
def _rescale_annotations(self, bboxes, crop_coordinates): # TODO: move to dataset
|
||||
def rescale_bbox(bbox):
|
||||
x0 = clamp((bbox[0] - crop_coordinates[0]) / crop_coordinates[2])
|
||||
y0 = clamp((bbox[1] - crop_coordinates[1]) / crop_coordinates[3])
|
||||
w = min(bbox[2] / crop_coordinates[2], 1 - x0)
|
||||
h = min(bbox[3] / crop_coordinates[3], 1 - y0)
|
||||
return x0, y0, w, h
|
||||
|
||||
return [rescale_bbox(b) for b in bboxes]
|
||||
|
||||
def apply_model(self, x_noisy, t, cond, return_ids=False):
|
||||
|
||||
if isinstance(cond, dict):
|
||||
@@ -1140,7 +1132,7 @@ class LatentDiffusion(DDPM):
|
||||
if cond is not None:
|
||||
if isinstance(cond, dict):
|
||||
cond = {key: cond[key][:batch_size] if not isinstance(cond[key], list) else
|
||||
list(map(lambda x: x[:batch_size], cond[key])) for key in cond}
|
||||
[x[:batch_size] for x in cond[key]] for key in cond}
|
||||
else:
|
||||
cond = [c[:batch_size] for c in cond] if isinstance(cond, list) else cond[:batch_size]
|
||||
|
||||
@@ -1171,8 +1163,10 @@ class LatentDiffusion(DDPM):
|
||||
|
||||
if i % log_every_t == 0 or i == timesteps - 1:
|
||||
intermediates.append(x0_partial)
|
||||
if callback: callback(i)
|
||||
if img_callback: img_callback(img, i)
|
||||
if callback:
|
||||
callback(i)
|
||||
if img_callback:
|
||||
img_callback(img, i)
|
||||
return img, intermediates
|
||||
|
||||
@torch.no_grad()
|
||||
@@ -1219,8 +1213,10 @@ class LatentDiffusion(DDPM):
|
||||
|
||||
if i % log_every_t == 0 or i == timesteps - 1:
|
||||
intermediates.append(img)
|
||||
if callback: callback(i)
|
||||
if img_callback: img_callback(img, i)
|
||||
if callback:
|
||||
callback(i)
|
||||
if img_callback:
|
||||
img_callback(img, i)
|
||||
|
||||
if return_intermediates:
|
||||
return img, intermediates
|
||||
@@ -1235,7 +1231,7 @@ class LatentDiffusion(DDPM):
|
||||
if cond is not None:
|
||||
if isinstance(cond, dict):
|
||||
cond = {key: cond[key][:batch_size] if not isinstance(cond[key], list) else
|
||||
list(map(lambda x: x[:batch_size], cond[key])) for key in cond}
|
||||
[x[:batch_size] for x in cond[key]] for key in cond}
|
||||
else:
|
||||
cond = [c[:batch_size] for c in cond] if isinstance(cond, list) else cond[:batch_size]
|
||||
return self.p_sample_loop(cond,
|
||||
@@ -1267,7 +1263,7 @@ class LatentDiffusion(DDPM):
|
||||
|
||||
use_ddim = False
|
||||
|
||||
log = dict()
|
||||
log = {}
|
||||
z, c, x, xrec, xc = self.get_input(batch, self.first_stage_key,
|
||||
return_first_stage_outputs=True,
|
||||
force_c_encode=True,
|
||||
@@ -1295,7 +1291,7 @@ class LatentDiffusion(DDPM):
|
||||
|
||||
if plot_diffusion_rows:
|
||||
# get diffusion row
|
||||
diffusion_row = list()
|
||||
diffusion_row = []
|
||||
z_start = z[:n_row]
|
||||
for t in range(self.num_timesteps):
|
||||
if t % self.log_every_t == 0 or t == self.num_timesteps - 1:
|
||||
@@ -1337,7 +1333,7 @@ class LatentDiffusion(DDPM):
|
||||
|
||||
if inpaint:
|
||||
# make a simple center square
|
||||
b, h, w = z.shape[0], z.shape[2], z.shape[3]
|
||||
h, w = z.shape[2], z.shape[3]
|
||||
mask = torch.ones(N, h, w).to(self.device)
|
||||
# zeros will be filled in
|
||||
mask[:, h // 4:3 * h // 4, w // 4:3 * w // 4] = 0.
|
||||
@@ -1439,10 +1435,10 @@ class Layout2ImgDiffusion(LatentDiffusion):
|
||||
# TODO: move all layout-specific hacks to this class
|
||||
def __init__(self, cond_stage_key, *args, **kwargs):
|
||||
assert cond_stage_key == 'coordinates_bbox', 'Layout2ImgDiffusion only for cond_stage_key="coordinates_bbox"'
|
||||
super().__init__(cond_stage_key=cond_stage_key, *args, **kwargs)
|
||||
super().__init__(*args, cond_stage_key=cond_stage_key, **kwargs)
|
||||
|
||||
def log_images(self, batch, N=8, *args, **kwargs):
|
||||
logs = super().log_images(batch=batch, N=N, *args, **kwargs)
|
||||
logs = super().log_images(*args, batch=batch, N=N, **kwargs)
|
||||
|
||||
key = 'train' if self.training else 'validation'
|
||||
dset = self.trainer.datamodule.datasets[key]
|
||||
|
||||
@@ -1 +1 @@
|
||||
from .sampler import UniPCSampler
|
||||
from .sampler import UniPCSampler # noqa: F401
|
||||
|
||||
@@ -54,7 +54,8 @@ class UniPCSampler(object):
|
||||
if conditioning is not None:
|
||||
if isinstance(conditioning, dict):
|
||||
ctmp = conditioning[list(conditioning.keys())[0]]
|
||||
while isinstance(ctmp, list): ctmp = ctmp[0]
|
||||
while isinstance(ctmp, list):
|
||||
ctmp = ctmp[0]
|
||||
cbs = ctmp.shape[0]
|
||||
if cbs != batch_size:
|
||||
print(f"Warning: Got {cbs} conditionings but batch-size is {batch_size}")
|
||||
|
||||
@@ -1,7 +1,6 @@
|
||||
import torch
|
||||
import torch.nn.functional as F
|
||||
import math
|
||||
from tqdm.auto import trange
|
||||
import tqdm
|
||||
|
||||
|
||||
class NoiseScheduleVP:
|
||||
@@ -94,7 +93,7 @@ class NoiseScheduleVP:
|
||||
"""
|
||||
|
||||
if schedule not in ['discrete', 'linear', 'cosine']:
|
||||
raise ValueError("Unsupported noise schedule {}. The schedule needs to be 'discrete' or 'linear' or 'cosine'".format(schedule))
|
||||
raise ValueError(f"Unsupported noise schedule {schedule}. The schedule needs to be 'discrete' or 'linear' or 'cosine'")
|
||||
|
||||
self.schedule = schedule
|
||||
if schedule == 'discrete':
|
||||
@@ -179,13 +178,13 @@ def model_wrapper(
|
||||
model,
|
||||
noise_schedule,
|
||||
model_type="noise",
|
||||
model_kwargs={},
|
||||
model_kwargs=None,
|
||||
guidance_type="uncond",
|
||||
#condition=None,
|
||||
#unconditional_condition=None,
|
||||
guidance_scale=1.,
|
||||
classifier_fn=None,
|
||||
classifier_kwargs={},
|
||||
classifier_kwargs=None,
|
||||
):
|
||||
"""Create a wrapper function for the noise prediction model.
|
||||
|
||||
@@ -276,6 +275,9 @@ def model_wrapper(
|
||||
A noise prediction model that accepts the noised data and the continuous time as the inputs.
|
||||
"""
|
||||
|
||||
model_kwargs = model_kwargs or {}
|
||||
classifier_kwargs = classifier_kwargs or {}
|
||||
|
||||
def get_model_input_time(t_continuous):
|
||||
"""
|
||||
Convert the continuous-time `t_continuous` (in [epsilon, T]) to the model input time.
|
||||
@@ -342,7 +344,7 @@ def model_wrapper(
|
||||
t_in = torch.cat([t_continuous] * 2)
|
||||
if isinstance(condition, dict):
|
||||
assert isinstance(unconditional_condition, dict)
|
||||
c_in = dict()
|
||||
c_in = {}
|
||||
for k in condition:
|
||||
if isinstance(condition[k], list):
|
||||
c_in[k] = [torch.cat([
|
||||
@@ -353,7 +355,7 @@ def model_wrapper(
|
||||
unconditional_condition[k],
|
||||
condition[k]])
|
||||
elif isinstance(condition, list):
|
||||
c_in = list()
|
||||
c_in = []
|
||||
assert isinstance(unconditional_condition, list)
|
||||
for i in range(len(condition)):
|
||||
c_in.append(torch.cat([unconditional_condition[i], condition[i]]))
|
||||
@@ -469,7 +471,7 @@ class UniPC:
|
||||
t = torch.linspace(t_T**(1. / t_order), t_0**(1. / t_order), N + 1).pow(t_order).to(device)
|
||||
return t
|
||||
else:
|
||||
raise ValueError("Unsupported skip_type {}, need to be 'logSNR' or 'time_uniform' or 'time_quadratic'".format(skip_type))
|
||||
raise ValueError(f"Unsupported skip_type {skip_type}, need to be 'logSNR' or 'time_uniform' or 'time_quadratic'")
|
||||
|
||||
def get_orders_and_timesteps_for_singlestep_solver(self, steps, order, skip_type, t_T, t_0, device):
|
||||
"""
|
||||
@@ -757,40 +759,44 @@ class UniPC:
|
||||
vec_t = timesteps[0].expand((x.shape[0]))
|
||||
model_prev_list = [self.model_fn(x, vec_t)]
|
||||
t_prev_list = [vec_t]
|
||||
# Init the first `order` values by lower order multistep DPM-Solver.
|
||||
for init_order in range(1, order):
|
||||
vec_t = timesteps[init_order].expand(x.shape[0])
|
||||
x, model_x = self.multistep_uni_pc_update(x, model_prev_list, t_prev_list, vec_t, init_order, use_corrector=True)
|
||||
if model_x is None:
|
||||
model_x = self.model_fn(x, vec_t)
|
||||
if self.after_update is not None:
|
||||
self.after_update(x, model_x)
|
||||
model_prev_list.append(model_x)
|
||||
t_prev_list.append(vec_t)
|
||||
for step in trange(order, steps + 1):
|
||||
vec_t = timesteps[step].expand(x.shape[0])
|
||||
if lower_order_final:
|
||||
step_order = min(order, steps + 1 - step)
|
||||
else:
|
||||
step_order = order
|
||||
#print('this step order:', step_order)
|
||||
if step == steps:
|
||||
#print('do not run corrector at the last step')
|
||||
use_corrector = False
|
||||
else:
|
||||
use_corrector = True
|
||||
x, model_x = self.multistep_uni_pc_update(x, model_prev_list, t_prev_list, vec_t, step_order, use_corrector=use_corrector)
|
||||
if self.after_update is not None:
|
||||
self.after_update(x, model_x)
|
||||
for i in range(order - 1):
|
||||
t_prev_list[i] = t_prev_list[i + 1]
|
||||
model_prev_list[i] = model_prev_list[i + 1]
|
||||
t_prev_list[-1] = vec_t
|
||||
# We do not need to evaluate the final model value.
|
||||
if step < steps:
|
||||
with tqdm.tqdm(total=steps) as pbar:
|
||||
# Init the first `order` values by lower order multistep DPM-Solver.
|
||||
for init_order in range(1, order):
|
||||
vec_t = timesteps[init_order].expand(x.shape[0])
|
||||
x, model_x = self.multistep_uni_pc_update(x, model_prev_list, t_prev_list, vec_t, init_order, use_corrector=True)
|
||||
if model_x is None:
|
||||
model_x = self.model_fn(x, vec_t)
|
||||
model_prev_list[-1] = model_x
|
||||
if self.after_update is not None:
|
||||
self.after_update(x, model_x)
|
||||
model_prev_list.append(model_x)
|
||||
t_prev_list.append(vec_t)
|
||||
pbar.update()
|
||||
|
||||
for step in range(order, steps + 1):
|
||||
vec_t = timesteps[step].expand(x.shape[0])
|
||||
if lower_order_final:
|
||||
step_order = min(order, steps + 1 - step)
|
||||
else:
|
||||
step_order = order
|
||||
#print('this step order:', step_order)
|
||||
if step == steps:
|
||||
#print('do not run corrector at the last step')
|
||||
use_corrector = False
|
||||
else:
|
||||
use_corrector = True
|
||||
x, model_x = self.multistep_uni_pc_update(x, model_prev_list, t_prev_list, vec_t, step_order, use_corrector=use_corrector)
|
||||
if self.after_update is not None:
|
||||
self.after_update(x, model_x)
|
||||
for i in range(order - 1):
|
||||
t_prev_list[i] = t_prev_list[i + 1]
|
||||
model_prev_list[i] = model_prev_list[i + 1]
|
||||
t_prev_list[-1] = vec_t
|
||||
# We do not need to evaluate the final model value.
|
||||
if step < steps:
|
||||
if model_x is None:
|
||||
model_x = self.model_fn(x, vec_t)
|
||||
model_prev_list[-1] = model_x
|
||||
pbar.update()
|
||||
else:
|
||||
raise NotImplementedError()
|
||||
if denoise_to_zero:
|
||||
|
||||
@@ -7,12 +7,24 @@ def connect(token, port, region):
|
||||
else:
|
||||
if ':' in token:
|
||||
# token = authtoken:username:password
|
||||
account = token.split(':')[1] + ':' + token.split(':')[-1]
|
||||
token = token.split(':')[0]
|
||||
token, username, password = token.split(':', 2)
|
||||
account = f"{username}:{password}"
|
||||
|
||||
config = conf.PyngrokConfig(
|
||||
auth_token=token, region=region
|
||||
)
|
||||
|
||||
# Guard for existing tunnels
|
||||
existing = ngrok.get_tunnels(pyngrok_config=config)
|
||||
if existing:
|
||||
for established in existing:
|
||||
# Extra configuration in the case that the user is also using ngrok for other tunnels
|
||||
if established.config['addr'][-4:] == str(port):
|
||||
public_url = existing[0].public_url
|
||||
print(f'ngrok has already been connected to localhost:{port}! URL: {public_url}\n'
|
||||
'You can use this link after the launch is complete.')
|
||||
return
|
||||
|
||||
try:
|
||||
if account is None:
|
||||
public_url = ngrok.connect(port, pyngrok_config=config, bind_tls=True).public_url
|
||||
|
||||
@@ -1,8 +1,8 @@
|
||||
import os
|
||||
import sys
|
||||
from modules.paths_internal import models_path, script_path, data_path, extensions_dir, extensions_builtin_dir
|
||||
from modules.paths_internal import models_path, script_path, data_path, extensions_dir, extensions_builtin_dir # noqa: F401
|
||||
|
||||
import modules.safe
|
||||
import modules.safe # noqa: F401
|
||||
|
||||
|
||||
# data_path = cmd_opts_pre.data
|
||||
@@ -16,7 +16,7 @@ for possible_sd_path in possible_sd_paths:
|
||||
sd_path = os.path.abspath(possible_sd_path)
|
||||
break
|
||||
|
||||
assert sd_path is not None, "Couldn't find Stable Diffusion in any of: " + str(possible_sd_paths)
|
||||
assert sd_path is not None, f"Couldn't find Stable Diffusion in any of: {possible_sd_paths}"
|
||||
|
||||
path_dirs = [
|
||||
(sd_path, 'ldm', 'Stable Diffusion', []),
|
||||
|
||||
@@ -2,8 +2,14 @@
|
||||
|
||||
import argparse
|
||||
import os
|
||||
import sys
|
||||
import shlex
|
||||
|
||||
script_path = os.path.dirname(os.path.dirname(os.path.realpath(__file__)))
|
||||
commandline_args = os.environ.get('COMMANDLINE_ARGS', "")
|
||||
sys.argv += shlex.split(commandline_args)
|
||||
|
||||
modules_path = os.path.dirname(os.path.realpath(__file__))
|
||||
script_path = os.path.dirname(modules_path)
|
||||
|
||||
sd_configs_path = os.path.join(script_path, "configs")
|
||||
sd_default_config = os.path.join(sd_configs_path, "v1-inference.yaml")
|
||||
@@ -12,7 +18,7 @@ default_sd_model_file = sd_model_file
|
||||
|
||||
# Parse the --data-dir flag first so we can use it as a base for our other argument default values
|
||||
parser_pre = argparse.ArgumentParser(add_help=False)
|
||||
parser_pre.add_argument("--data-dir", type=str, default=os.path.dirname(os.path.dirname(os.path.realpath(__file__))), help="base path where all user data is stored",)
|
||||
parser_pre.add_argument("--data-dir", type=str, default=os.path.dirname(modules_path), help="base path where all user data is stored", )
|
||||
cmd_opts_pre = parser_pre.parse_known_args()[0]
|
||||
|
||||
data_path = cmd_opts_pre.data_dir
|
||||
@@ -20,3 +26,6 @@ data_path = cmd_opts_pre.data_dir
|
||||
models_path = os.path.join(data_path, "models")
|
||||
extensions_dir = os.path.join(data_path, "extensions")
|
||||
extensions_builtin_dir = os.path.join(script_path, "extensions-builtin")
|
||||
config_states_dir = os.path.join(script_path, "config_states")
|
||||
|
||||
roboto_ttf_file = os.path.join(modules_path, 'Roboto-Regular.ttf')
|
||||
|
||||
@@ -18,9 +18,14 @@ def run_postprocessing(extras_mode, image, image_folder, input_dir, output_dir,
|
||||
|
||||
if extras_mode == 1:
|
||||
for img in image_folder:
|
||||
image = Image.open(img)
|
||||
if isinstance(img, Image.Image):
|
||||
image = img
|
||||
fn = ''
|
||||
else:
|
||||
image = Image.open(os.path.abspath(img.name))
|
||||
fn = os.path.splitext(img.orig_name)[0]
|
||||
image_data.append(image)
|
||||
image_names.append(os.path.splitext(img.orig_name)[0])
|
||||
image_names.append(fn)
|
||||
elif extras_mode == 2:
|
||||
assert not shared.cmd_opts.hide_ui_dir_config, '--hide-ui-dir-config option must be disabled'
|
||||
assert input_dir, 'input directory not selected'
|
||||
|
||||
@@ -2,7 +2,7 @@ import json
|
||||
import math
|
||||
import os
|
||||
import sys
|
||||
import warnings
|
||||
import hashlib
|
||||
|
||||
import torch
|
||||
import numpy as np
|
||||
@@ -10,10 +10,10 @@ from PIL import Image, ImageFilter, ImageOps
|
||||
import random
|
||||
import cv2
|
||||
from skimage import exposure
|
||||
from typing import Any, Dict, List, Optional
|
||||
from typing import Any, Dict, List
|
||||
|
||||
import modules.sd_hijack
|
||||
from modules import devices, prompt_parser, masking, sd_samplers, lowvram, generation_parameters_copypaste, script_callbacks, extra_networks, sd_vae_approx, scripts
|
||||
from modules import devices, prompt_parser, masking, sd_samplers, lowvram, generation_parameters_copypaste, extra_networks, sd_vae_approx, scripts, sd_samplers_common
|
||||
from modules.sd_hijack import model_hijack
|
||||
from modules.shared import opts, cmd_opts, state
|
||||
import modules.shared as shared
|
||||
@@ -30,6 +30,7 @@ from ldm.models.diffusion.ddpm import LatentDepth2ImageDiffusion
|
||||
from einops import repeat, rearrange
|
||||
from blendmodes.blend import blendLayers, BlendType
|
||||
|
||||
|
||||
# some of those options should not be changed at all because they would break the model, so I removed them from options.
|
||||
opt_C = 4
|
||||
opt_f = 8
|
||||
@@ -105,7 +106,7 @@ class StableDiffusionProcessing:
|
||||
"""
|
||||
The first set of paramaters: sd_models -> do_not_reload_embeddings represent the minimum required to create a StableDiffusionProcessing
|
||||
"""
|
||||
def __init__(self, sd_model=None, outpath_samples=None, outpath_grids=None, prompt: str = "", styles: List[str] = None, seed: int = -1, subseed: int = -1, subseed_strength: float = 0, seed_resize_from_h: int = -1, seed_resize_from_w: int = -1, seed_enable_extras: bool = True, sampler_name: str = None, batch_size: int = 1, n_iter: int = 1, steps: int = 50, cfg_scale: float = 7.0, width: int = 512, height: int = 512, restore_faces: bool = False, tiling: bool = False, do_not_save_samples: bool = False, do_not_save_grid: bool = False, extra_generation_params: Dict[Any, Any] = None, overlay_images: Any = None, negative_prompt: str = None, eta: float = None, do_not_reload_embeddings: bool = False, denoising_strength: float = 0, ddim_discretize: str = None, s_churn: float = 0.0, s_tmax: float = None, s_tmin: float = 0.0, s_noise: float = 1.0, override_settings: Dict[str, Any] = None, override_settings_restore_afterwards: bool = True, sampler_index: int = None, script_args: list = None):
|
||||
def __init__(self, sd_model=None, outpath_samples=None, outpath_grids=None, prompt: str = "", styles: List[str] = None, seed: int = -1, subseed: int = -1, subseed_strength: float = 0, seed_resize_from_h: int = -1, seed_resize_from_w: int = -1, seed_enable_extras: bool = True, sampler_name: str = None, batch_size: int = 1, n_iter: int = 1, steps: int = 50, cfg_scale: float = 7.0, width: int = 512, height: int = 512, restore_faces: bool = False, tiling: bool = False, do_not_save_samples: bool = False, do_not_save_grid: bool = False, extra_generation_params: Dict[Any, Any] = None, overlay_images: Any = None, negative_prompt: str = None, eta: float = None, do_not_reload_embeddings: bool = False, denoising_strength: float = 0, ddim_discretize: str = None, s_min_uncond: float = 0.0, s_churn: float = 0.0, s_tmax: float = None, s_tmin: float = 0.0, s_noise: float = 1.0, override_settings: Dict[str, Any] = None, override_settings_restore_afterwards: bool = True, sampler_index: int = None, script_args: list = None):
|
||||
if sampler_index is not None:
|
||||
print("sampler_index argument for StableDiffusionProcessing does not do anything; use sampler_name", file=sys.stderr)
|
||||
|
||||
@@ -140,6 +141,7 @@ class StableDiffusionProcessing:
|
||||
self.denoising_strength: float = denoising_strength
|
||||
self.sampler_noise_scheduler_override = None
|
||||
self.ddim_discretize = ddim_discretize or opts.ddim_discretize
|
||||
self.s_min_uncond = s_min_uncond or opts.s_min_uncond
|
||||
self.s_churn = s_churn or opts.s_churn
|
||||
self.s_tmin = s_tmin or opts.s_tmin
|
||||
self.s_tmax = s_tmax or float('inf') # not representable as a standard ui option
|
||||
@@ -148,6 +150,8 @@ class StableDiffusionProcessing:
|
||||
self.override_settings_restore_afterwards = override_settings_restore_afterwards
|
||||
self.is_using_inpainting_conditioning = False
|
||||
self.disable_extra_networks = False
|
||||
self.token_merging_ratio = 0
|
||||
self.token_merging_ratio_hr = 0
|
||||
|
||||
if not seed_enable_extras:
|
||||
self.subseed = -1
|
||||
@@ -162,6 +166,9 @@ class StableDiffusionProcessing:
|
||||
self.all_seeds = None
|
||||
self.all_subseeds = None
|
||||
self.iteration = 0
|
||||
self.is_hr_pass = False
|
||||
self.sampler = None
|
||||
|
||||
|
||||
@property
|
||||
def sd_model(self):
|
||||
@@ -270,6 +277,12 @@ class StableDiffusionProcessing:
|
||||
def close(self):
|
||||
self.sampler = None
|
||||
|
||||
def get_token_merging_ratio(self, for_hr=False):
|
||||
if for_hr:
|
||||
return self.token_merging_ratio_hr or opts.token_merging_ratio_hr or self.token_merging_ratio or opts.token_merging_ratio
|
||||
|
||||
return self.token_merging_ratio or opts.token_merging_ratio
|
||||
|
||||
|
||||
class Processed:
|
||||
def __init__(self, p: StableDiffusionProcessing, images_list, seed=-1, info="", subseed=None, all_prompts=None, all_negative_prompts=None, all_seeds=None, all_subseeds=None, index_of_first_image=0, infotexts=None, comments=""):
|
||||
@@ -299,6 +312,8 @@ class Processed:
|
||||
self.styles = p.styles
|
||||
self.job_timestamp = state.job_timestamp
|
||||
self.clip_skip = opts.CLIP_stop_at_last_layers
|
||||
self.token_merging_ratio = p.token_merging_ratio
|
||||
self.token_merging_ratio_hr = p.token_merging_ratio_hr
|
||||
|
||||
self.eta = p.eta
|
||||
self.ddim_discretize = p.ddim_discretize
|
||||
@@ -306,6 +321,7 @@ class Processed:
|
||||
self.s_tmin = p.s_tmin
|
||||
self.s_tmax = p.s_tmax
|
||||
self.s_noise = p.s_noise
|
||||
self.s_min_uncond = p.s_min_uncond
|
||||
self.sampler_noise_scheduler_override = p.sampler_noise_scheduler_override
|
||||
self.prompt = self.prompt if type(self.prompt) != list else self.prompt[0]
|
||||
self.negative_prompt = self.negative_prompt if type(self.negative_prompt) != list else self.negative_prompt[0]
|
||||
@@ -356,6 +372,9 @@ class Processed:
|
||||
def infotext(self, p: StableDiffusionProcessing, index):
|
||||
return create_infotext(p, self.all_prompts, self.all_seeds, self.all_subseeds, comments=[], position_in_batch=index % self.batch_size, iteration=index // self.batch_size)
|
||||
|
||||
def get_token_merging_ratio(self, for_hr=False):
|
||||
return self.token_merging_ratio_hr if for_hr else self.token_merging_ratio
|
||||
|
||||
|
||||
# from https://discuss.pytorch.org/t/help-regarding-slerp-function-for-generative-model-sampling/32475/3
|
||||
def slerp(val, low, high):
|
||||
@@ -454,10 +473,27 @@ def fix_seed(p):
|
||||
p.subseed = get_fixed_seed(p.subseed)
|
||||
|
||||
|
||||
def program_version():
|
||||
import launch
|
||||
|
||||
res = launch.git_tag()
|
||||
if res == "<none>":
|
||||
res = None
|
||||
|
||||
return res
|
||||
|
||||
|
||||
def create_infotext(p, all_prompts, all_seeds, all_subseeds, comments=None, iteration=0, position_in_batch=0):
|
||||
index = position_in_batch + iteration * p.batch_size
|
||||
|
||||
clip_skip = getattr(p, 'clip_skip', opts.CLIP_stop_at_last_layers)
|
||||
enable_hr = getattr(p, 'enable_hr', False)
|
||||
token_merging_ratio = p.get_token_merging_ratio()
|
||||
token_merging_ratio_hr = p.get_token_merging_ratio(for_hr=True)
|
||||
|
||||
uses_ensd = opts.eta_noise_seed_delta != 0
|
||||
if uses_ensd:
|
||||
uses_ensd = sd_samplers_common.is_sampler_using_eta_noise_seed_delta(p)
|
||||
|
||||
generation_params = {
|
||||
"Steps": p.steps,
|
||||
@@ -475,14 +511,19 @@ def create_infotext(p, all_prompts, all_seeds, all_subseeds, comments=None, iter
|
||||
"Denoising strength": getattr(p, 'denoising_strength', None),
|
||||
"Conditional mask weight": getattr(p, "inpainting_mask_weight", shared.opts.inpainting_mask_weight) if p.is_using_inpainting_conditioning else None,
|
||||
"Clip skip": None if clip_skip <= 1 else clip_skip,
|
||||
"ENSD": None if opts.eta_noise_seed_delta == 0 else opts.eta_noise_seed_delta,
|
||||
"ENSD": opts.eta_noise_seed_delta if uses_ensd else None,
|
||||
"Token merging ratio": None if token_merging_ratio == 0 else token_merging_ratio,
|
||||
"Token merging ratio hr": None if not enable_hr or token_merging_ratio_hr == 0 else token_merging_ratio_hr,
|
||||
"Init image hash": getattr(p, 'init_img_hash', None),
|
||||
"RNG": opts.randn_source if opts.randn_source != "GPU" else None,
|
||||
"NGMS": None if p.s_min_uncond == 0 else p.s_min_uncond,
|
||||
**p.extra_generation_params,
|
||||
"Version": program_version() if opts.add_version_to_infotext else None,
|
||||
}
|
||||
|
||||
generation_params.update(p.extra_generation_params)
|
||||
|
||||
generation_params_text = ", ".join([k if k == v else f'{k}: {generation_parameters_copypaste.quote(v)}' for k, v in generation_params.items() if v is not None])
|
||||
|
||||
negative_prompt_text = "\nNegative prompt: " + p.all_negative_prompts[index] if p.all_negative_prompts[index] else ""
|
||||
negative_prompt_text = f"\nNegative prompt: {p.all_negative_prompts[index]}" if p.all_negative_prompts[index] else ""
|
||||
|
||||
return f"{all_prompts[index]}{negative_prompt_text}\n{generation_params_text}".strip()
|
||||
|
||||
@@ -491,6 +532,11 @@ def process_images(p: StableDiffusionProcessing) -> Processed:
|
||||
stored_opts = {k: opts.data[k] for k in p.override_settings.keys()}
|
||||
|
||||
try:
|
||||
# if no checkpoint override or the override checkpoint can't be found, remove override entry and load opts checkpoint
|
||||
if sd_models.checkpoint_alisases.get(p.override_settings.get('sd_model_checkpoint')) is None:
|
||||
p.override_settings.pop('sd_model_checkpoint', None)
|
||||
sd_models.reload_model_weights()
|
||||
|
||||
for k, v in p.override_settings.items():
|
||||
setattr(opts, k, v)
|
||||
|
||||
@@ -500,15 +546,17 @@ def process_images(p: StableDiffusionProcessing) -> Processed:
|
||||
if k == 'sd_vae':
|
||||
sd_vae.reload_vae_weights()
|
||||
|
||||
sd_models.apply_token_merging(p.sd_model, p.get_token_merging_ratio())
|
||||
|
||||
res = process_images_inner(p)
|
||||
|
||||
finally:
|
||||
sd_models.apply_token_merging(p.sd_model, 0)
|
||||
|
||||
# restore opts to original state
|
||||
if p.override_settings_restore_afterwards:
|
||||
for k, v in stored_opts.items():
|
||||
setattr(opts, k, v)
|
||||
if k == 'sd_model_checkpoint':
|
||||
sd_models.reload_model_weights()
|
||||
|
||||
if k == 'sd_vae':
|
||||
sd_vae.reload_vae_weights()
|
||||
@@ -639,8 +687,10 @@ def process_images_inner(p: StableDiffusionProcessing) -> Processed:
|
||||
processed = Processed(p, [], p.seed, "")
|
||||
file.write(processed.infotext(p, 0))
|
||||
|
||||
uc = get_conds_with_caching(prompt_parser.get_learned_conditioning, negative_prompts, p.steps, cached_uc)
|
||||
c = get_conds_with_caching(prompt_parser.get_multicond_learned_conditioning, prompts, p.steps, cached_c)
|
||||
sampler_config = sd_samplers.find_sampler_config(p.sampler_name)
|
||||
step_multiplier = 2 if sampler_config and sampler_config.options.get("second_order", False) else 1
|
||||
uc = get_conds_with_caching(prompt_parser.get_learned_conditioning, negative_prompts, p.steps * step_multiplier, cached_uc)
|
||||
c = get_conds_with_caching(prompt_parser.get_multicond_learned_conditioning, prompts, p.steps * step_multiplier, cached_c)
|
||||
|
||||
if len(model_hijack.comments) > 0:
|
||||
for comment in model_hijack.comments:
|
||||
@@ -670,6 +720,8 @@ def process_images_inner(p: StableDiffusionProcessing) -> Processed:
|
||||
p.scripts.postprocess_batch(p, x_samples_ddim, batch_number=n)
|
||||
|
||||
for i, x_sample in enumerate(x_samples_ddim):
|
||||
p.batch_index = i
|
||||
|
||||
x_sample = 255. * np.moveaxis(x_sample.cpu().numpy(), 0, 2)
|
||||
x_sample = x_sample.astype(np.uint8)
|
||||
|
||||
@@ -706,9 +758,9 @@ def process_images_inner(p: StableDiffusionProcessing) -> Processed:
|
||||
image.info["parameters"] = text
|
||||
output_images.append(image)
|
||||
|
||||
if hasattr(p, 'mask_for_overlay') and p.mask_for_overlay:
|
||||
if hasattr(p, 'mask_for_overlay') and p.mask_for_overlay and any([opts.save_mask, opts.save_mask_composite, opts.return_mask, opts.return_mask_composite]):
|
||||
image_mask = p.mask_for_overlay.convert('RGB')
|
||||
image_mask_composite = Image.composite(image.convert('RGBA').convert('RGBa'), Image.new('RGBa', image.size), p.mask_for_overlay.convert('L')).convert('RGBA')
|
||||
image_mask_composite = Image.composite(image.convert('RGBA').convert('RGBa'), Image.new('RGBa', image.size), images.resize_image(2, p.mask_for_overlay, image.width, image.height).convert('L')).convert('RGBA')
|
||||
|
||||
if opts.save_mask:
|
||||
images.save_image(image_mask, p.outpath_samples, "", seeds[i], prompts[i], opts.samples_format, info=infotext(n, i), p=p, suffix="-mask")
|
||||
@@ -718,7 +770,7 @@ def process_images_inner(p: StableDiffusionProcessing) -> Processed:
|
||||
|
||||
if opts.return_mask:
|
||||
output_images.append(image_mask)
|
||||
|
||||
|
||||
if opts.return_mask_composite:
|
||||
output_images.append(image_mask_composite)
|
||||
|
||||
@@ -751,7 +803,16 @@ def process_images_inner(p: StableDiffusionProcessing) -> Processed:
|
||||
|
||||
devices.torch_gc()
|
||||
|
||||
res = Processed(p, output_images, p.all_seeds[0], infotext(), comments="".join(["\n\n" + x for x in comments]), subseed=p.all_subseeds[0], index_of_first_image=index_of_first_image, infotexts=infotexts)
|
||||
res = Processed(
|
||||
p,
|
||||
images_list=output_images,
|
||||
seed=p.all_seeds[0],
|
||||
info=infotext(),
|
||||
comments="".join(f"\n\n{comment}" for comment in comments),
|
||||
subseed=p.all_subseeds[0],
|
||||
index_of_first_image=index_of_first_image,
|
||||
infotexts=infotexts,
|
||||
)
|
||||
|
||||
if p.scripts is not None:
|
||||
p.scripts.postprocess(p, res)
|
||||
@@ -871,6 +932,8 @@ class StableDiffusionProcessingTxt2Img(StableDiffusionProcessing):
|
||||
if not self.enable_hr:
|
||||
return samples
|
||||
|
||||
self.is_hr_pass = True
|
||||
|
||||
target_width = self.hr_upscale_to_x
|
||||
target_height = self.hr_upscale_to_y
|
||||
|
||||
@@ -938,8 +1001,14 @@ class StableDiffusionProcessingTxt2Img(StableDiffusionProcessing):
|
||||
x = None
|
||||
devices.torch_gc()
|
||||
|
||||
sd_models.apply_token_merging(self.sd_model, self.get_token_merging_ratio(for_hr=True))
|
||||
|
||||
samples = self.sampler.sample_img2img(self, samples, noise, conditioning, unconditional_conditioning, steps=self.hr_second_pass_steps or self.steps, image_conditioning=image_conditioning)
|
||||
|
||||
sd_models.apply_token_merging(self.sd_model, self.get_token_merging_ratio())
|
||||
|
||||
self.is_hr_pass = False
|
||||
|
||||
return samples
|
||||
|
||||
|
||||
@@ -1007,6 +1076,12 @@ class StableDiffusionProcessingImg2Img(StableDiffusionProcessing):
|
||||
self.color_corrections = []
|
||||
imgs = []
|
||||
for img in self.init_images:
|
||||
|
||||
# Save init image
|
||||
if opts.save_init_img:
|
||||
self.init_img_hash = hashlib.md5(img.tobytes()).hexdigest()
|
||||
images.save_image(img, path=opts.outdir_init_images, basename=None, forced_filename=self.init_img_hash, save_to_dirs=False)
|
||||
|
||||
image = images.flatten(img, opts.img2img_background_color)
|
||||
|
||||
if crop_region is None and self.resize_mode != 3:
|
||||
@@ -1093,3 +1168,6 @@ class StableDiffusionProcessingImg2Img(StableDiffusionProcessing):
|
||||
devices.torch_gc()
|
||||
|
||||
return samples
|
||||
|
||||
def get_token_merging_ratio(self, for_hr=False):
|
||||
return self.token_merging_ratio or ("token_merging_ratio" in self.override_settings and opts.token_merging_ratio) or opts.token_merging_ratio_img2img or opts.token_merging_ratio
|
||||
|
||||
@@ -13,6 +13,8 @@ import modules.shared as shared
|
||||
current_task = None
|
||||
pending_tasks = {}
|
||||
finished_tasks = []
|
||||
recorded_results = []
|
||||
recorded_results_limit = 2
|
||||
|
||||
|
||||
def start_task(id_task):
|
||||
@@ -33,6 +35,12 @@ def finish_task(id_task):
|
||||
finished_tasks.pop(0)
|
||||
|
||||
|
||||
def record_results(id_task, res):
|
||||
recorded_results.append((id_task, res))
|
||||
if len(recorded_results) > recorded_results_limit:
|
||||
recorded_results.pop(0)
|
||||
|
||||
|
||||
def add_task_to_queue(id_job):
|
||||
pending_tasks[id_job] = time.time()
|
||||
|
||||
@@ -87,8 +95,20 @@ def progressapi(req: ProgressRequest):
|
||||
image = shared.state.current_image
|
||||
if image is not None:
|
||||
buffered = io.BytesIO()
|
||||
image.save(buffered, format="png")
|
||||
live_preview = 'data:image/png;base64,' + base64.b64encode(buffered.getvalue()).decode("ascii")
|
||||
|
||||
if opts.live_previews_image_format == "png":
|
||||
# using optimize for large images takes an enormous amount of time
|
||||
if max(*image.size) <= 256:
|
||||
save_kwargs = {"optimize": True}
|
||||
else:
|
||||
save_kwargs = {"optimize": False, "compress_level": 1}
|
||||
|
||||
else:
|
||||
save_kwargs = {}
|
||||
|
||||
image.save(buffered, format=opts.live_previews_image_format, **save_kwargs)
|
||||
base64_image = base64.b64encode(buffered.getvalue()).decode('ascii')
|
||||
live_preview = f"data:image/{opts.live_previews_image_format};base64,{base64_image}"
|
||||
id_live_preview = shared.state.id_live_preview
|
||||
else:
|
||||
live_preview = None
|
||||
@@ -97,3 +117,13 @@ def progressapi(req: ProgressRequest):
|
||||
|
||||
return ProgressResponse(active=active, queued=queued, completed=completed, progress=progress, eta=eta, live_preview=live_preview, id_live_preview=id_live_preview, textinfo=shared.state.textinfo)
|
||||
|
||||
|
||||
def restore_progress(id_task):
|
||||
while id_task == current_task or id_task in pending_tasks:
|
||||
time.sleep(0.1)
|
||||
|
||||
res = next(iter([x[1] for x in recorded_results if id_task == x[0]]), None)
|
||||
if res is not None:
|
||||
return res
|
||||
|
||||
return gr.update(), gr.update(), gr.update(), f"Couldn't restore progress for {id_task}: results either have been discarded or never were obtained"
|
||||
|
||||
@@ -54,18 +54,21 @@ def get_learned_conditioning_prompt_schedules(prompts, steps):
|
||||
"""
|
||||
|
||||
def collect_steps(steps, tree):
|
||||
l = [steps]
|
||||
res = [steps]
|
||||
|
||||
class CollectSteps(lark.Visitor):
|
||||
def scheduled(self, tree):
|
||||
tree.children[-1] = float(tree.children[-1])
|
||||
if tree.children[-1] < 1:
|
||||
tree.children[-1] *= steps
|
||||
tree.children[-1] = min(steps, int(tree.children[-1]))
|
||||
l.append(tree.children[-1])
|
||||
res.append(tree.children[-1])
|
||||
|
||||
def alternate(self, tree):
|
||||
l.extend(range(1, steps+1))
|
||||
res.extend(range(1, steps+1))
|
||||
|
||||
CollectSteps().visit(tree)
|
||||
return sorted(set(l))
|
||||
return sorted(set(res))
|
||||
|
||||
def at_step(step, tree):
|
||||
class AtStep(lark.Transformer):
|
||||
@@ -92,7 +95,7 @@ def get_learned_conditioning_prompt_schedules(prompts, steps):
|
||||
def get_schedule(prompt):
|
||||
try:
|
||||
tree = schedule_parser.parse(prompt)
|
||||
except lark.exceptions.LarkError as e:
|
||||
except lark.exceptions.LarkError:
|
||||
if 0:
|
||||
import traceback
|
||||
traceback.print_exc()
|
||||
@@ -140,7 +143,7 @@ def get_learned_conditioning(model, prompts, steps):
|
||||
conds = model.get_learned_conditioning(texts)
|
||||
|
||||
cond_schedule = []
|
||||
for i, (end_at_step, text) in enumerate(prompt_schedule):
|
||||
for i, (end_at_step, _) in enumerate(prompt_schedule):
|
||||
cond_schedule.append(ScheduledPromptConditioning(end_at_step, conds[i]))
|
||||
|
||||
cache[prompt] = cond_schedule
|
||||
@@ -216,8 +219,8 @@ def reconstruct_cond_batch(c: List[List[ScheduledPromptConditioning]], current_s
|
||||
res = torch.zeros((len(c),) + param.shape, device=param.device, dtype=param.dtype)
|
||||
for i, cond_schedule in enumerate(c):
|
||||
target_index = 0
|
||||
for current, (end_at, cond) in enumerate(cond_schedule):
|
||||
if current_step <= end_at:
|
||||
for current, entry in enumerate(cond_schedule):
|
||||
if current_step <= entry.end_at_step:
|
||||
target_index = current
|
||||
break
|
||||
res[i] = cond_schedule[target_index].cond
|
||||
@@ -231,13 +234,13 @@ def reconstruct_multicond_batch(c: MulticondLearnedConditioning, current_step):
|
||||
tensors = []
|
||||
conds_list = []
|
||||
|
||||
for batch_no, composable_prompts in enumerate(c.batch):
|
||||
for composable_prompts in c.batch:
|
||||
conds_for_batch = []
|
||||
|
||||
for cond_index, composable_prompt in enumerate(composable_prompts):
|
||||
for composable_prompt in composable_prompts:
|
||||
target_index = 0
|
||||
for current, (end_at, cond) in enumerate(composable_prompt.schedules):
|
||||
if current_step <= end_at:
|
||||
for current, entry in enumerate(composable_prompt.schedules):
|
||||
if current_step <= entry.end_at_step:
|
||||
target_index = current
|
||||
break
|
||||
|
||||
|
||||
@@ -9,7 +9,7 @@ from realesrgan import RealESRGANer
|
||||
|
||||
from modules.upscaler import Upscaler, UpscalerData
|
||||
from modules.shared import cmd_opts, opts
|
||||
|
||||
from modules import modelloader
|
||||
|
||||
class UpscalerRealESRGAN(Upscaler):
|
||||
def __init__(self, path):
|
||||
@@ -17,13 +17,21 @@ class UpscalerRealESRGAN(Upscaler):
|
||||
self.user_path = path
|
||||
super().__init__()
|
||||
try:
|
||||
from basicsr.archs.rrdbnet_arch import RRDBNet
|
||||
from realesrgan import RealESRGANer
|
||||
from realesrgan.archs.srvgg_arch import SRVGGNetCompact
|
||||
from basicsr.archs.rrdbnet_arch import RRDBNet # noqa: F401
|
||||
from realesrgan import RealESRGANer # noqa: F401
|
||||
from realesrgan.archs.srvgg_arch import SRVGGNetCompact # noqa: F401
|
||||
self.enable = True
|
||||
self.scalers = []
|
||||
scalers = self.load_models(path)
|
||||
|
||||
local_model_paths = self.find_models(ext_filter=[".pth"])
|
||||
for scaler in scalers:
|
||||
if scaler.local_data_path.startswith("http"):
|
||||
filename = modelloader.friendly_name(scaler.local_data_path)
|
||||
local_model_candidates = [local_model for local_model in local_model_paths if local_model.endswith(f"{filename}.pth")]
|
||||
if local_model_candidates:
|
||||
scaler.local_data_path = local_model_candidates[0]
|
||||
|
||||
if scaler.name in opts.realesrgan_enabled_models:
|
||||
self.scalers.append(scaler)
|
||||
|
||||
@@ -39,7 +47,7 @@ class UpscalerRealESRGAN(Upscaler):
|
||||
|
||||
info = self.load_model(path)
|
||||
if not os.path.exists(info.local_data_path):
|
||||
print("Unable to load RealESRGAN model: %s" % info.name)
|
||||
print(f"Unable to load RealESRGAN model: {info.name}")
|
||||
return img
|
||||
|
||||
upsampler = RealESRGANer(
|
||||
@@ -64,7 +72,9 @@ class UpscalerRealESRGAN(Upscaler):
|
||||
print(f"Unable to find model info: {path}")
|
||||
return None
|
||||
|
||||
info.local_data_path = load_file_from_url(url=info.data_path, model_dir=self.model_path, progress=True)
|
||||
if info.local_data_path.startswith("http"):
|
||||
info.local_data_path = load_file_from_url(url=info.data_path, model_dir=self.model_path, progress=True)
|
||||
|
||||
return info
|
||||
except Exception as e:
|
||||
print(f"Error making Real-ESRGAN models list: {e}", file=sys.stderr)
|
||||
@@ -124,6 +134,6 @@ def get_realesrgan_models(scaler):
|
||||
),
|
||||
]
|
||||
return models
|
||||
except Exception as e:
|
||||
except Exception:
|
||||
print("Error making Real-ESRGAN models list:", file=sys.stderr)
|
||||
print(traceback.format_exc(), file=sys.stderr)
|
||||
|
||||
@@ -1,6 +1,5 @@
|
||||
# this code is adapted from the script contributed by anon from /h/
|
||||
|
||||
import io
|
||||
import pickle
|
||||
import collections
|
||||
import sys
|
||||
@@ -12,11 +11,9 @@ import _codecs
|
||||
import zipfile
|
||||
import re
|
||||
|
||||
|
||||
# PyTorch 1.13 and later have _TypedStorage renamed to TypedStorage
|
||||
TypedStorage = torch.storage.TypedStorage if hasattr(torch.storage, 'TypedStorage') else torch.storage._TypedStorage
|
||||
|
||||
|
||||
def encode(*args):
|
||||
out = _codecs.encode(*args)
|
||||
return out
|
||||
@@ -27,7 +24,11 @@ class RestrictedUnpickler(pickle.Unpickler):
|
||||
|
||||
def persistent_load(self, saved_id):
|
||||
assert saved_id[0] == 'storage'
|
||||
return TypedStorage()
|
||||
|
||||
try:
|
||||
return TypedStorage(_internal=True)
|
||||
except TypeError:
|
||||
return TypedStorage() # PyTorch before 2.0 does not have the _internal argument
|
||||
|
||||
def find_class(self, module, name):
|
||||
if self.extra_handler is not None:
|
||||
@@ -39,7 +40,7 @@ class RestrictedUnpickler(pickle.Unpickler):
|
||||
return getattr(collections, name)
|
||||
if module == 'torch._utils' and name in ['_rebuild_tensor_v2', '_rebuild_parameter', '_rebuild_device_tensor_from_numpy']:
|
||||
return getattr(torch._utils, name)
|
||||
if module == 'torch' and name in ['FloatStorage', 'HalfStorage', 'IntStorage', 'LongStorage', 'DoubleStorage', 'ByteStorage', 'float32']:
|
||||
if module == 'torch' and name in ['FloatStorage', 'HalfStorage', 'IntStorage', 'LongStorage', 'DoubleStorage', 'ByteStorage', 'float32', 'BFloat16Storage']:
|
||||
return getattr(torch, name)
|
||||
if module == 'torch.nn.modules.container' and name in ['ParameterDict']:
|
||||
return getattr(torch.nn.modules.container, name)
|
||||
@@ -94,16 +95,16 @@ def check_pt(filename, extra_handler):
|
||||
|
||||
except zipfile.BadZipfile:
|
||||
|
||||
# if it's not a zip file, it's an olf pytorch format, with five objects written to pickle
|
||||
# if it's not a zip file, it's an old pytorch format, with five objects written to pickle
|
||||
with open(filename, "rb") as file:
|
||||
unpickler = RestrictedUnpickler(file)
|
||||
unpickler.extra_handler = extra_handler
|
||||
for i in range(5):
|
||||
for _ in range(5):
|
||||
unpickler.load()
|
||||
|
||||
|
||||
def load(filename, *args, **kwargs):
|
||||
return load_with_extra(filename, extra_handler=global_extra_handler, *args, **kwargs)
|
||||
return load_with_extra(filename, *args, extra_handler=global_extra_handler, **kwargs)
|
||||
|
||||
|
||||
def load_with_extra(filename, extra_handler=None, *args, **kwargs):
|
||||
|
||||
@@ -32,27 +32,42 @@ class CFGDenoiserParams:
|
||||
def __init__(self, x, image_cond, sigma, sampling_step, total_sampling_steps, text_cond, text_uncond):
|
||||
self.x = x
|
||||
"""Latent image representation in the process of being denoised"""
|
||||
|
||||
|
||||
self.image_cond = image_cond
|
||||
"""Conditioning image"""
|
||||
|
||||
|
||||
self.sigma = sigma
|
||||
"""Current sigma noise step value"""
|
||||
|
||||
|
||||
self.sampling_step = sampling_step
|
||||
"""Current Sampling step number"""
|
||||
|
||||
|
||||
self.total_sampling_steps = total_sampling_steps
|
||||
"""Total number of sampling steps planned"""
|
||||
|
||||
|
||||
self.text_cond = text_cond
|
||||
""" Encoder hidden states of text conditioning from prompt"""
|
||||
|
||||
|
||||
self.text_uncond = text_uncond
|
||||
""" Encoder hidden states of text conditioning from negative prompt"""
|
||||
|
||||
|
||||
class CFGDenoisedParams:
|
||||
def __init__(self, x, sampling_step, total_sampling_steps, inner_model):
|
||||
self.x = x
|
||||
"""Latent image representation in the process of being denoised"""
|
||||
|
||||
self.sampling_step = sampling_step
|
||||
"""Current Sampling step number"""
|
||||
|
||||
self.total_sampling_steps = total_sampling_steps
|
||||
"""Total number of sampling steps planned"""
|
||||
|
||||
self.inner_model = inner_model
|
||||
"""Inner model reference used for denoising"""
|
||||
|
||||
|
||||
class AfterCFGCallbackParams:
|
||||
def __init__(self, x, sampling_step, total_sampling_steps):
|
||||
self.x = x
|
||||
"""Latent image representation in the process of being denoised"""
|
||||
@@ -87,12 +102,14 @@ callback_map = dict(
|
||||
callbacks_image_saved=[],
|
||||
callbacks_cfg_denoiser=[],
|
||||
callbacks_cfg_denoised=[],
|
||||
callbacks_cfg_after_cfg=[],
|
||||
callbacks_before_component=[],
|
||||
callbacks_after_component=[],
|
||||
callbacks_image_grid=[],
|
||||
callbacks_infotext_pasted=[],
|
||||
callbacks_script_unloaded=[],
|
||||
callbacks_before_ui=[],
|
||||
callbacks_on_reload=[],
|
||||
)
|
||||
|
||||
|
||||
@@ -109,6 +126,14 @@ def app_started_callback(demo: Optional[Blocks], app: FastAPI):
|
||||
report_exception(c, 'app_started_callback')
|
||||
|
||||
|
||||
def app_reload_callback():
|
||||
for c in callback_map['callbacks_on_reload']:
|
||||
try:
|
||||
c.callback()
|
||||
except Exception:
|
||||
report_exception(c, 'callbacks_on_reload')
|
||||
|
||||
|
||||
def model_loaded_callback(sd_model):
|
||||
for c in callback_map['callbacks_model_loaded']:
|
||||
try:
|
||||
@@ -177,6 +202,14 @@ def cfg_denoised_callback(params: CFGDenoisedParams):
|
||||
report_exception(c, 'cfg_denoised_callback')
|
||||
|
||||
|
||||
def cfg_after_cfg_callback(params: AfterCFGCallbackParams):
|
||||
for c in callback_map['callbacks_cfg_after_cfg']:
|
||||
try:
|
||||
c.callback(params)
|
||||
except Exception:
|
||||
report_exception(c, 'cfg_after_cfg_callback')
|
||||
|
||||
|
||||
def before_component_callback(component, **kwargs):
|
||||
for c in callback_map['callbacks_before_component']:
|
||||
try:
|
||||
@@ -231,7 +264,7 @@ def add_callback(callbacks, fun):
|
||||
|
||||
callbacks.append(ScriptCallback(filename, fun))
|
||||
|
||||
|
||||
|
||||
def remove_current_script_callbacks():
|
||||
stack = [x for x in inspect.stack() if x.filename != __file__]
|
||||
filename = stack[0].filename if len(stack) > 0 else 'unknown file'
|
||||
@@ -254,6 +287,11 @@ def on_app_started(callback):
|
||||
add_callback(callback_map['callbacks_app_started'], callback)
|
||||
|
||||
|
||||
def on_before_reload(callback):
|
||||
"""register a function to be called just before the server reloads."""
|
||||
add_callback(callback_map['callbacks_on_reload'], callback)
|
||||
|
||||
|
||||
def on_model_loaded(callback):
|
||||
"""register a function to be called when the stable diffusion model is created; the model is
|
||||
passed as an argument; this function is also called when the script is reloaded. """
|
||||
@@ -318,6 +356,14 @@ def on_cfg_denoised(callback):
|
||||
add_callback(callback_map['callbacks_cfg_denoised'], callback)
|
||||
|
||||
|
||||
def on_cfg_after_cfg(callback):
|
||||
"""register a function to be called in the kdiffussion cfg_denoiser method after cfg calculations are completed.
|
||||
The callback is called with one argument:
|
||||
- params: AfterCFGCallbackParams - parameters to be passed to the script for post-processing after cfg calculation.
|
||||
"""
|
||||
add_callback(callback_map['callbacks_cfg_after_cfg'], callback)
|
||||
|
||||
|
||||
def on_before_component(callback):
|
||||
"""register a function to be called before a component is created.
|
||||
The callback is called with arguments:
|
||||
|
||||
@@ -2,7 +2,6 @@ import os
|
||||
import sys
|
||||
import traceback
|
||||
import importlib.util
|
||||
from types import ModuleType
|
||||
|
||||
|
||||
def load_module(path):
|
||||
|
||||
@@ -17,6 +17,9 @@ class PostprocessImageArgs:
|
||||
|
||||
|
||||
class Script:
|
||||
name = None
|
||||
"""script's internal name derived from title"""
|
||||
|
||||
filename = None
|
||||
args_from = None
|
||||
args_to = None
|
||||
@@ -25,8 +28,8 @@ class Script:
|
||||
is_txt2img = False
|
||||
is_img2img = False
|
||||
|
||||
"""A gr.Group component that has all script's UI inside it"""
|
||||
group = None
|
||||
"""A gr.Group component that has all script's UI inside it"""
|
||||
|
||||
infotext_fields = None
|
||||
"""if set in ui(), this is a list of pairs of gradio component + text; the text will be used when
|
||||
@@ -38,6 +41,9 @@ class Script:
|
||||
various "Send to <X>" buttons when clicked
|
||||
"""
|
||||
|
||||
api_info = None
|
||||
"""Generated value of type modules.api.models.ScriptInfo with information about the script for API"""
|
||||
|
||||
def title(self):
|
||||
"""this function should return the title of the script. This is what will be displayed in the dropdown menu."""
|
||||
|
||||
@@ -163,7 +169,8 @@ class Script:
|
||||
"""helper function to generate id for a HTML element, constructs final id out of script name, tab and user-supplied item_id"""
|
||||
|
||||
need_tabname = self.show(True) == self.show(False)
|
||||
tabname = ('img2img' if self.is_img2img else 'txt2txt') + "_" if need_tabname else ""
|
||||
tabkind = 'img2img' if self.is_img2img else 'txt2txt'
|
||||
tabname = f"{tabkind}_" if need_tabname else ""
|
||||
title = re.sub(r'[^a-z_0-9]', '', re.sub(r'\s', '_', self.title().lower()))
|
||||
|
||||
return f'script_{tabname}{title}_{item_id}'
|
||||
@@ -230,7 +237,7 @@ def load_scripts():
|
||||
syspath = sys.path
|
||||
|
||||
def register_scripts_from_module(module):
|
||||
for key, script_class in module.__dict__.items():
|
||||
for script_class in module.__dict__.values():
|
||||
if type(script_class) != type:
|
||||
continue
|
||||
|
||||
@@ -294,9 +301,9 @@ class ScriptRunner:
|
||||
|
||||
auto_processing_scripts = scripts_auto_postprocessing.create_auto_preprocessing_script_data()
|
||||
|
||||
for script_class, path, basedir, script_module in auto_processing_scripts + scripts_data:
|
||||
script = script_class()
|
||||
script.filename = path
|
||||
for script_data in auto_processing_scripts + scripts_data:
|
||||
script = script_data.script_class()
|
||||
script.filename = script_data.path
|
||||
script.is_txt2img = not is_img2img
|
||||
script.is_img2img = is_img2img
|
||||
|
||||
@@ -312,6 +319,8 @@ class ScriptRunner:
|
||||
self.selectable_scripts.append(script)
|
||||
|
||||
def setup_ui(self):
|
||||
import modules.api.models as api_models
|
||||
|
||||
self.titles = [wrap_call(script.title, script.filename, "title") or f"{script.filename} [error]" for script in self.selectable_scripts]
|
||||
|
||||
inputs = [None]
|
||||
@@ -326,9 +335,28 @@ class ScriptRunner:
|
||||
if controls is None:
|
||||
return
|
||||
|
||||
script.name = wrap_call(script.title, script.filename, "title", default=script.filename).lower()
|
||||
api_args = []
|
||||
|
||||
for control in controls:
|
||||
control.custom_script_source = os.path.basename(script.filename)
|
||||
|
||||
arg_info = api_models.ScriptArg(label=control.label or "")
|
||||
|
||||
for field in ("value", "minimum", "maximum", "step", "choices"):
|
||||
v = getattr(control, field, None)
|
||||
if v is not None:
|
||||
setattr(arg_info, field, v)
|
||||
|
||||
api_args.append(arg_info)
|
||||
|
||||
script.api_info = api_models.ScriptInfo(
|
||||
name=script.name,
|
||||
is_img2img=script.is_img2img,
|
||||
is_alwayson=script.alwayson,
|
||||
args=api_args,
|
||||
)
|
||||
|
||||
if script.infotext_fields is not None:
|
||||
self.infotext_fields += script.infotext_fields
|
||||
|
||||
@@ -491,7 +519,7 @@ class ScriptRunner:
|
||||
module = script_loading.load_module(script.filename)
|
||||
cache[filename] = module
|
||||
|
||||
for key, script_class in module.__dict__.items():
|
||||
for script_class in module.__dict__.values():
|
||||
if type(script_class) == type and issubclass(script_class, Script):
|
||||
self.scripts[si] = script_class()
|
||||
self.scripts[si].filename = filename
|
||||
@@ -526,7 +554,7 @@ def add_classes_to_gradio_component(comp):
|
||||
this adds gradio-* to the component for css styling (ie gradio-button to gr.Button), as well as some others
|
||||
"""
|
||||
|
||||
comp.elem_classes = ["gradio-" + comp.get_block_name(), *(comp.elem_classes or [])]
|
||||
comp.elem_classes = [f"gradio-{comp.get_block_name()}", *(comp.elem_classes or [])]
|
||||
|
||||
if getattr(comp, 'multiselect', False):
|
||||
comp.elem_classes.append('multiselect')
|
||||
|
||||
@@ -17,7 +17,7 @@ class ScriptPostprocessingForMainUI(scripts.Script):
|
||||
return self.postprocessing_controls.values()
|
||||
|
||||
def postprocess_image(self, p, script_pp, *args):
|
||||
args_dict = {k: v for k, v in zip(self.postprocessing_controls, args)}
|
||||
args_dict = dict(zip(self.postprocessing_controls, args))
|
||||
|
||||
pp = scripts_postprocessing.PostprocessedImage(script_pp.image)
|
||||
pp.info = {}
|
||||
|
||||
@@ -66,9 +66,9 @@ class ScriptPostprocessingRunner:
|
||||
def initialize_scripts(self, scripts_data):
|
||||
self.scripts = []
|
||||
|
||||
for script_class, path, basedir, script_module in scripts_data:
|
||||
script: ScriptPostprocessing = script_class()
|
||||
script.filename = path
|
||||
for script_data in scripts_data:
|
||||
script: ScriptPostprocessing = script_data.script_class()
|
||||
script.filename = script_data.path
|
||||
|
||||
if script.name == "Simple Upscale":
|
||||
continue
|
||||
@@ -124,7 +124,7 @@ class ScriptPostprocessingRunner:
|
||||
script_args = args[script.args_from:script.args_to]
|
||||
|
||||
process_args = {}
|
||||
for (name, component), value in zip(script.controls.items(), script_args):
|
||||
for (name, _component), value in zip(script.controls.items(), script_args):
|
||||
process_args[name] = value
|
||||
|
||||
script.process(pp, **process_args)
|
||||
|
||||
@@ -61,7 +61,7 @@ class DisableInitialization:
|
||||
if res is None:
|
||||
res = original(url, *args, local_files_only=False, **kwargs)
|
||||
return res
|
||||
except Exception as e:
|
||||
except Exception:
|
||||
return original(url, *args, local_files_only=False, **kwargs)
|
||||
|
||||
def transformers_utils_hub_get_from_cache(url, *args, local_files_only=False, **kwargs):
|
||||
|
||||
@@ -3,7 +3,7 @@ from torch.nn.functional import silu
|
||||
from types import MethodType
|
||||
|
||||
import modules.textual_inversion.textual_inversion
|
||||
from modules import devices, sd_hijack_optimizations, shared, sd_hijack_checkpoint
|
||||
from modules import devices, sd_hijack_optimizations, shared
|
||||
from modules.hypernetworks import hypernetwork
|
||||
from modules.shared import cmd_opts
|
||||
from modules import sd_hijack_clip, sd_hijack_open_clip, sd_hijack_unet, sd_hijack_xlmr, xlmr
|
||||
@@ -34,10 +34,10 @@ def apply_optimizations():
|
||||
|
||||
ldm.modules.diffusionmodules.model.nonlinearity = silu
|
||||
ldm.modules.diffusionmodules.openaimodel.th = sd_hijack_unet.th
|
||||
|
||||
|
||||
optimization_method = None
|
||||
|
||||
can_use_sdp = hasattr(torch.nn.functional, "scaled_dot_product_attention") and callable(getattr(torch.nn.functional, "scaled_dot_product_attention")) # not everyone has torch 2.x to use sdp
|
||||
can_use_sdp = hasattr(torch.nn.functional, "scaled_dot_product_attention") and callable(torch.nn.functional.scaled_dot_product_attention) # not everyone has torch 2.x to use sdp
|
||||
|
||||
if cmd_opts.force_enable_xformers or (cmd_opts.xformers and shared.xformers_available and torch.version.cuda and (6, 0) <= torch.cuda.get_device_capability(shared.device) <= (9, 0)):
|
||||
print("Applying xformers cross attention optimization.")
|
||||
@@ -92,12 +92,12 @@ def fix_checkpoint():
|
||||
def weighted_loss(sd_model, pred, target, mean=True):
|
||||
#Calculate the weight normally, but ignore the mean
|
||||
loss = sd_model._old_get_loss(pred, target, mean=False)
|
||||
|
||||
|
||||
#Check if we have weights available
|
||||
weight = getattr(sd_model, '_custom_loss_weight', None)
|
||||
if weight is not None:
|
||||
loss *= weight
|
||||
|
||||
|
||||
#Return the loss, as mean if specified
|
||||
return loss.mean() if mean else loss
|
||||
|
||||
@@ -105,7 +105,7 @@ def weighted_forward(sd_model, x, c, w, *args, **kwargs):
|
||||
try:
|
||||
#Temporarily append weights to a place accessible during loss calc
|
||||
sd_model._custom_loss_weight = w
|
||||
|
||||
|
||||
#Replace 'get_loss' with a weight-aware one. Otherwise we need to reimplement 'forward' completely
|
||||
#Keep 'get_loss', but don't overwrite the previous old_get_loss if it's already set
|
||||
if not hasattr(sd_model, '_old_get_loss'):
|
||||
@@ -118,9 +118,9 @@ def weighted_forward(sd_model, x, c, w, *args, **kwargs):
|
||||
try:
|
||||
#Delete temporary weights if appended
|
||||
del sd_model._custom_loss_weight
|
||||
except AttributeError as e:
|
||||
except AttributeError:
|
||||
pass
|
||||
|
||||
|
||||
#If we have an old loss function, reset the loss function to the original one
|
||||
if hasattr(sd_model, '_old_get_loss'):
|
||||
sd_model.get_loss = sd_model._old_get_loss
|
||||
@@ -133,7 +133,7 @@ def apply_weighted_forward(sd_model):
|
||||
def undo_weighted_forward(sd_model):
|
||||
try:
|
||||
del sd_model.weighted_forward
|
||||
except AttributeError as e:
|
||||
except AttributeError:
|
||||
pass
|
||||
|
||||
|
||||
@@ -184,7 +184,7 @@ class StableDiffusionModelHijack:
|
||||
|
||||
def undo_hijack(self, m):
|
||||
if type(m.cond_stage_model) == xlmr.BertSeriesModelWithTransformation:
|
||||
m.cond_stage_model = m.cond_stage_model.wrapped
|
||||
m.cond_stage_model = m.cond_stage_model.wrapped
|
||||
|
||||
elif type(m.cond_stage_model) == sd_hijack_clip.FrozenCLIPEmbedderWithCustomWords:
|
||||
m.cond_stage_model = m.cond_stage_model.wrapped
|
||||
@@ -216,6 +216,9 @@ class StableDiffusionModelHijack:
|
||||
self.comments = []
|
||||
|
||||
def get_prompt_lengths(self, text):
|
||||
if self.clip is None:
|
||||
return "-", "-"
|
||||
|
||||
_, token_count = self.clip.process_texts([text])
|
||||
|
||||
return token_count, self.clip.get_target_prompt_token_count(token_count)
|
||||
|
||||
@@ -223,7 +223,7 @@ class FrozenCLIPEmbedderWithCustomWordsBase(torch.nn.Module):
|
||||
self.hijack.fixes = [x.fixes for x in batch_chunk]
|
||||
|
||||
for fixes in self.hijack.fixes:
|
||||
for position, embedding in fixes:
|
||||
for _position, embedding in fixes:
|
||||
used_embeddings[embedding.name] = embedding
|
||||
|
||||
z = self.process_tokens(tokens, multipliers)
|
||||
|
||||
@@ -75,7 +75,8 @@ def forward_old(self: sd_hijack_clip.FrozenCLIPEmbedderWithCustomWordsBase, text
|
||||
self.hijack.comments += hijack_comments
|
||||
|
||||
if len(used_custom_terms) > 0:
|
||||
self.hijack.comments.append("Used embeddings: " + ", ".join([f'{word} [{checksum}]' for word, checksum in used_custom_terms]))
|
||||
embedding_names = ", ".join(f"{word} [{checksum}]" for word, checksum in used_custom_terms)
|
||||
self.hijack.comments.append(f"Used embeddings: {embedding_names}")
|
||||
|
||||
self.hijack.fixes = hijack_fixes
|
||||
return self.process_tokens(remade_batch_tokens, batch_multipliers)
|
||||
|
||||
@@ -1,16 +1,10 @@
|
||||
import os
|
||||
import torch
|
||||
|
||||
from einops import repeat
|
||||
from omegaconf import ListConfig
|
||||
|
||||
import ldm.models.diffusion.ddpm
|
||||
import ldm.models.diffusion.ddim
|
||||
import ldm.models.diffusion.plms
|
||||
|
||||
from ldm.models.diffusion.ddpm import LatentDiffusion
|
||||
from ldm.models.diffusion.plms import PLMSSampler
|
||||
from ldm.models.diffusion.ddim import DDIMSampler, noise_like
|
||||
from ldm.models.diffusion.ddim import noise_like
|
||||
from ldm.models.diffusion.sampling_util import norm_thresholding
|
||||
|
||||
|
||||
@@ -29,7 +23,7 @@ def p_sample_plms(self, x, c, t, index, repeat_noise=False, use_original_steps=F
|
||||
|
||||
if isinstance(c, dict):
|
||||
assert isinstance(unconditional_conditioning, dict)
|
||||
c_in = dict()
|
||||
c_in = {}
|
||||
for k in c:
|
||||
if isinstance(c[k], list):
|
||||
c_in[k] = [
|
||||
|
||||
@@ -1,8 +1,5 @@
|
||||
import collections
|
||||
import os.path
|
||||
import sys
|
||||
import gc
|
||||
import time
|
||||
|
||||
|
||||
def should_hijack_ip2p(checkpoint_info):
|
||||
from modules import sd_models_config
|
||||
@@ -10,4 +7,4 @@ def should_hijack_ip2p(checkpoint_info):
|
||||
ckpt_basename = os.path.basename(checkpoint_info.filename).lower()
|
||||
cfg_basename = os.path.basename(sd_models_config.find_checkpoint_config_near_filename(checkpoint_info)).lower()
|
||||
|
||||
return "pix2pix" in ckpt_basename and not "pix2pix" in cfg_basename
|
||||
return "pix2pix" in ckpt_basename and "pix2pix" not in cfg_basename
|
||||
|
||||
@@ -49,7 +49,7 @@ def split_cross_attention_forward_v1(self, x, context=None, mask=None):
|
||||
v_in = self.to_v(context_v)
|
||||
del context, context_k, context_v, x
|
||||
|
||||
q, k, v = map(lambda t: rearrange(t, 'b n (h d) -> (b h) n d', h=h), (q_in, k_in, v_in))
|
||||
q, k, v = (rearrange(t, 'b n (h d) -> (b h) n d', h=h) for t in (q_in, k_in, v_in))
|
||||
del q_in, k_in, v_in
|
||||
|
||||
dtype = q.dtype
|
||||
@@ -62,10 +62,10 @@ def split_cross_attention_forward_v1(self, x, context=None, mask=None):
|
||||
end = i + 2
|
||||
s1 = einsum('b i d, b j d -> b i j', q[i:end], k[i:end])
|
||||
s1 *= self.scale
|
||||
|
||||
|
||||
s2 = s1.softmax(dim=-1)
|
||||
del s1
|
||||
|
||||
|
||||
r1[i:end] = einsum('b i j, b j d -> b i d', s2, v[i:end])
|
||||
del s2
|
||||
del q, k, v
|
||||
@@ -95,43 +95,43 @@ def split_cross_attention_forward(self, x, context=None, mask=None):
|
||||
|
||||
with devices.without_autocast(disable=not shared.opts.upcast_attn):
|
||||
k_in = k_in * self.scale
|
||||
|
||||
|
||||
del context, x
|
||||
|
||||
q, k, v = map(lambda t: rearrange(t, 'b n (h d) -> (b h) n d', h=h), (q_in, k_in, v_in))
|
||||
|
||||
q, k, v = (rearrange(t, 'b n (h d) -> (b h) n d', h=h) for t in (q_in, k_in, v_in))
|
||||
del q_in, k_in, v_in
|
||||
|
||||
|
||||
r1 = torch.zeros(q.shape[0], q.shape[1], v.shape[2], device=q.device, dtype=q.dtype)
|
||||
|
||||
|
||||
mem_free_total = get_available_vram()
|
||||
|
||||
|
||||
gb = 1024 ** 3
|
||||
tensor_size = q.shape[0] * q.shape[1] * k.shape[1] * q.element_size()
|
||||
modifier = 3 if q.element_size() == 2 else 2.5
|
||||
mem_required = tensor_size * modifier
|
||||
steps = 1
|
||||
|
||||
|
||||
if mem_required > mem_free_total:
|
||||
steps = 2 ** (math.ceil(math.log(mem_required / mem_free_total, 2)))
|
||||
# print(f"Expected tensor size:{tensor_size/gb:0.1f}GB, cuda free:{mem_free_cuda/gb:0.1f}GB "
|
||||
# f"torch free:{mem_free_torch/gb:0.1f} total:{mem_free_total/gb:0.1f} steps:{steps}")
|
||||
|
||||
|
||||
if steps > 64:
|
||||
max_res = math.floor(math.sqrt(math.sqrt(mem_free_total / 2.5)) / 8) * 64
|
||||
raise RuntimeError(f'Not enough memory, use lower resolution (max approx. {max_res}x{max_res}). '
|
||||
f'Need: {mem_required / 64 / gb:0.1f}GB free, Have:{mem_free_total / gb:0.1f}GB free')
|
||||
|
||||
|
||||
slice_size = q.shape[1] // steps if (q.shape[1] % steps) == 0 else q.shape[1]
|
||||
for i in range(0, q.shape[1], slice_size):
|
||||
end = i + slice_size
|
||||
s1 = einsum('b i d, b j d -> b i j', q[:, i:end], k)
|
||||
|
||||
|
||||
s2 = s1.softmax(dim=-1, dtype=q.dtype)
|
||||
del s1
|
||||
|
||||
|
||||
r1[:, i:end] = einsum('b i j, b j d -> b i d', s2, v)
|
||||
del s2
|
||||
|
||||
|
||||
del q, k, v
|
||||
|
||||
r1 = r1.to(dtype)
|
||||
@@ -228,8 +228,8 @@ def split_cross_attention_forward_invokeAI(self, x, context=None, mask=None):
|
||||
|
||||
with devices.without_autocast(disable=not shared.opts.upcast_attn):
|
||||
k = k * self.scale
|
||||
|
||||
q, k, v = map(lambda t: rearrange(t, 'b n (h d) -> (b h) n d', h=h), (q, k, v))
|
||||
|
||||
q, k, v = (rearrange(t, 'b n (h d) -> (b h) n d', h=h) for t in (q, k, v))
|
||||
r = einsum_op(q, k, v)
|
||||
r = r.to(dtype)
|
||||
return self.to_out(rearrange(r, '(b h) n d -> b n (h d)', h=h))
|
||||
@@ -256,6 +256,9 @@ def sub_quad_attention_forward(self, x, context=None, mask=None):
|
||||
k = k.unflatten(-1, (h, -1)).transpose(1,2).flatten(end_dim=1)
|
||||
v = v.unflatten(-1, (h, -1)).transpose(1,2).flatten(end_dim=1)
|
||||
|
||||
if q.device.type == 'mps':
|
||||
q, k, v = q.contiguous(), k.contiguous(), v.contiguous()
|
||||
|
||||
dtype = q.dtype
|
||||
if shared.opts.upcast_attn:
|
||||
q, k = q.float(), k.float()
|
||||
@@ -293,7 +296,6 @@ def sub_quad_attention(q, k, v, q_chunk_size=1024, kv_chunk_size=None, kv_chunk_
|
||||
if chunk_threshold_bytes is not None and qk_matmul_size_bytes <= chunk_threshold_bytes:
|
||||
# the big matmul fits into our memory limit; do everything in 1 chunk,
|
||||
# i.e. send it down the unchunked fast-path
|
||||
query_chunk_size = q_tokens
|
||||
kv_chunk_size = k_tokens
|
||||
|
||||
with devices.without_autocast(disable=q.dtype == v.dtype):
|
||||
@@ -332,7 +334,7 @@ def xformers_attention_forward(self, x, context=None, mask=None):
|
||||
k_in = self.to_k(context_k)
|
||||
v_in = self.to_v(context_v)
|
||||
|
||||
q, k, v = map(lambda t: rearrange(t, 'b n (h d) -> b n h d', h=h), (q_in, k_in, v_in))
|
||||
q, k, v = (rearrange(t, 'b n (h d) -> b n h d', h=h) for t in (q_in, k_in, v_in))
|
||||
del q_in, k_in, v_in
|
||||
|
||||
dtype = q.dtype
|
||||
@@ -367,7 +369,7 @@ def scaled_dot_product_attention_forward(self, x, context=None, mask=None):
|
||||
q = q_in.view(batch_size, -1, h, head_dim).transpose(1, 2)
|
||||
k = k_in.view(batch_size, -1, h, head_dim).transpose(1, 2)
|
||||
v = v_in.view(batch_size, -1, h, head_dim).transpose(1, 2)
|
||||
|
||||
|
||||
del q_in, k_in, v_in
|
||||
|
||||
dtype = q.dtype
|
||||
@@ -449,7 +451,7 @@ def cross_attention_attnblock_forward(self, x):
|
||||
h3 += x
|
||||
|
||||
return h3
|
||||
|
||||
|
||||
def xformers_attnblock_forward(self, x):
|
||||
try:
|
||||
h_ = x
|
||||
@@ -458,7 +460,7 @@ def xformers_attnblock_forward(self, x):
|
||||
k = self.k(h_)
|
||||
v = self.v(h_)
|
||||
b, c, h, w = q.shape
|
||||
q, k, v = map(lambda t: rearrange(t, 'b c h w -> b (h w) c'), (q, k, v))
|
||||
q, k, v = (rearrange(t, 'b c h w -> b (h w) c') for t in (q, k, v))
|
||||
dtype = q.dtype
|
||||
if shared.opts.upcast_attn:
|
||||
q, k = q.float(), k.float()
|
||||
@@ -480,7 +482,7 @@ def sdp_attnblock_forward(self, x):
|
||||
k = self.k(h_)
|
||||
v = self.v(h_)
|
||||
b, c, h, w = q.shape
|
||||
q, k, v = map(lambda t: rearrange(t, 'b c h w -> b (h w) c'), (q, k, v))
|
||||
q, k, v = (rearrange(t, 'b c h w -> b (h w) c') for t in (q, k, v))
|
||||
dtype = q.dtype
|
||||
if shared.opts.upcast_attn:
|
||||
q, k = q.float(), k.float()
|
||||
@@ -504,7 +506,7 @@ def sub_quad_attnblock_forward(self, x):
|
||||
k = self.k(h_)
|
||||
v = self.v(h_)
|
||||
b, c, h, w = q.shape
|
||||
q, k, v = map(lambda t: rearrange(t, 'b c h w -> b (h w) c'), (q, k, v))
|
||||
q, k, v = (rearrange(t, 'b c h w -> b (h w) c') for t in (q, k, v))
|
||||
q = q.contiguous()
|
||||
k = k.contiguous()
|
||||
v = v.contiguous()
|
||||
|
||||
@@ -18,7 +18,7 @@ class TorchHijackForUnet:
|
||||
if hasattr(torch, item):
|
||||
return getattr(torch, item)
|
||||
|
||||
raise AttributeError("'{}' object has no attribute '{}'".format(type(self).__name__, item))
|
||||
raise AttributeError(f"'{type(self).__name__}' object has no attribute '{item}'")
|
||||
|
||||
def cat(self, tensors, *args, **kwargs):
|
||||
if len(tensors) == 2:
|
||||
|
||||
@@ -1,8 +1,6 @@
|
||||
import open_clip.tokenizer
|
||||
import torch
|
||||
|
||||
from modules import sd_hijack_clip, devices
|
||||
from modules.shared import opts
|
||||
|
||||
|
||||
class FrozenXLMREmbedderWithCustomWords(sd_hijack_clip.FrozenCLIPEmbedderWithCustomWords):
|
||||
|
||||
@@ -2,6 +2,8 @@ import collections
|
||||
import os.path
|
||||
import sys
|
||||
import gc
|
||||
import threading
|
||||
|
||||
import torch
|
||||
import re
|
||||
import safetensors.torch
|
||||
@@ -13,9 +15,9 @@ import ldm.modules.midas as midas
|
||||
from ldm.util import instantiate_from_config
|
||||
|
||||
from modules import paths, shared, modelloader, devices, script_callbacks, sd_vae, sd_disable_initialization, errors, hashes, sd_models_config
|
||||
from modules.paths import models_path
|
||||
from modules.sd_hijack_inpainting import do_inpainting_hijack
|
||||
from modules.timer import Timer
|
||||
import tomesd
|
||||
|
||||
model_dir = "Stable-diffusion"
|
||||
model_path = os.path.abspath(os.path.join(paths.models_path, model_dir))
|
||||
@@ -45,20 +47,29 @@ class CheckpointInfo:
|
||||
self.model_name = os.path.splitext(name.replace("/", "_").replace("\\", "_"))[0]
|
||||
self.hash = model_hash(filename)
|
||||
|
||||
self.sha256 = hashes.sha256_from_cache(self.filename, "checkpoint/" + name)
|
||||
self.sha256 = hashes.sha256_from_cache(self.filename, f"checkpoint/{name}")
|
||||
self.shorthash = self.sha256[0:10] if self.sha256 else None
|
||||
|
||||
self.title = name if self.shorthash is None else f'{name} [{self.shorthash}]'
|
||||
|
||||
self.ids = [self.hash, self.model_name, self.title, name, f'{name} [{self.hash}]'] + ([self.shorthash, self.sha256, f'{self.name} [{self.shorthash}]'] if self.shorthash else [])
|
||||
|
||||
self.metadata = {}
|
||||
|
||||
_, ext = os.path.splitext(self.filename)
|
||||
if ext.lower() == ".safetensors":
|
||||
try:
|
||||
self.metadata = read_metadata_from_safetensors(filename)
|
||||
except Exception as e:
|
||||
errors.display(e, f"reading checkpoint metadata: {filename}")
|
||||
|
||||
def register(self):
|
||||
checkpoints_list[self.title] = self
|
||||
for id in self.ids:
|
||||
checkpoint_alisases[id] = self
|
||||
|
||||
def calculate_shorthash(self):
|
||||
self.sha256 = hashes.sha256(self.filename, "checkpoint/" + self.name)
|
||||
self.sha256 = hashes.sha256(self.filename, f"checkpoint/{self.name}")
|
||||
if self.sha256 is None:
|
||||
return
|
||||
|
||||
@@ -76,8 +87,7 @@ class CheckpointInfo:
|
||||
|
||||
try:
|
||||
# this silences the annoying "Some weights of the model checkpoint were not used when initializing..." message at start.
|
||||
|
||||
from transformers import logging, CLIPModel
|
||||
from transformers import logging, CLIPModel # noqa: F401
|
||||
|
||||
logging.set_verbosity_error()
|
||||
except Exception:
|
||||
@@ -156,7 +166,7 @@ def model_hash(filename):
|
||||
|
||||
def select_checkpoint():
|
||||
model_checkpoint = shared.opts.sd_model_checkpoint
|
||||
|
||||
|
||||
checkpoint_info = checkpoint_alisases.get(model_checkpoint, None)
|
||||
if checkpoint_info is not None:
|
||||
return checkpoint_info
|
||||
@@ -228,7 +238,7 @@ def read_metadata_from_safetensors(filename):
|
||||
if isinstance(v, str) and v[0:1] == '{':
|
||||
try:
|
||||
res[k] = json.loads(v)
|
||||
except Exception as e:
|
||||
except Exception:
|
||||
pass
|
||||
|
||||
return res
|
||||
@@ -363,7 +373,7 @@ def enable_midas_autodownload():
|
||||
if not os.path.exists(path):
|
||||
if not os.path.exists(midas_path):
|
||||
mkdir(midas_path)
|
||||
|
||||
|
||||
print(f"Downloading midas model weights for {model_type} to {path}")
|
||||
request.urlretrieve(midas_urls[model_type], path)
|
||||
print(f"{model_type} downloaded")
|
||||
@@ -395,13 +405,42 @@ def repair_config(sd_config):
|
||||
sd1_clip_weight = 'cond_stage_model.transformer.text_model.embeddings.token_embedding.weight'
|
||||
sd2_clip_weight = 'cond_stage_model.model.transformer.resblocks.0.attn.in_proj_weight'
|
||||
|
||||
def load_model(checkpoint_info=None, already_loaded_state_dict=None, time_taken_to_load_state_dict=None):
|
||||
|
||||
class SdModelData:
|
||||
def __init__(self):
|
||||
self.sd_model = None
|
||||
self.lock = threading.Lock()
|
||||
|
||||
def get_sd_model(self):
|
||||
if self.sd_model is None:
|
||||
with self.lock:
|
||||
if self.sd_model is not None:
|
||||
return self.sd_model
|
||||
|
||||
try:
|
||||
load_model()
|
||||
except Exception as e:
|
||||
errors.display(e, "loading stable diffusion model")
|
||||
print("", file=sys.stderr)
|
||||
print("Stable diffusion model failed to load", file=sys.stderr)
|
||||
self.sd_model = None
|
||||
|
||||
return self.sd_model
|
||||
|
||||
def set_sd_model(self, v):
|
||||
self.sd_model = v
|
||||
|
||||
|
||||
model_data = SdModelData()
|
||||
|
||||
|
||||
def load_model(checkpoint_info=None, already_loaded_state_dict=None):
|
||||
from modules import lowvram, sd_hijack
|
||||
checkpoint_info = checkpoint_info or select_checkpoint()
|
||||
|
||||
if shared.sd_model:
|
||||
sd_hijack.model_hijack.undo_hijack(shared.sd_model)
|
||||
shared.sd_model = None
|
||||
if model_data.sd_model:
|
||||
sd_hijack.model_hijack.undo_hijack(model_data.sd_model)
|
||||
model_data.sd_model = None
|
||||
gc.collect()
|
||||
devices.torch_gc()
|
||||
|
||||
@@ -430,7 +469,7 @@ def load_model(checkpoint_info=None, already_loaded_state_dict=None, time_taken_
|
||||
try:
|
||||
with sd_disable_initialization.DisableInitialization(disable_clip=clip_is_included_into_sd):
|
||||
sd_model = instantiate_from_config(sd_config.model)
|
||||
except Exception as e:
|
||||
except Exception:
|
||||
pass
|
||||
|
||||
if sd_model is None:
|
||||
@@ -455,7 +494,7 @@ def load_model(checkpoint_info=None, already_loaded_state_dict=None, time_taken_
|
||||
timer.record("hijack")
|
||||
|
||||
sd_model.eval()
|
||||
shared.sd_model = sd_model
|
||||
model_data.sd_model = sd_model
|
||||
|
||||
sd_hijack.model_hijack.embedding_db.load_textual_inversion_embeddings(force_reload=True) # Reload embeddings after model load as they may or may not fit the model
|
||||
|
||||
@@ -475,7 +514,7 @@ def reload_model_weights(sd_model=None, info=None):
|
||||
checkpoint_info = info or select_checkpoint()
|
||||
|
||||
if not sd_model:
|
||||
sd_model = shared.sd_model
|
||||
sd_model = model_data.sd_model
|
||||
|
||||
if sd_model is None: # previous model load failed
|
||||
current_checkpoint_info = None
|
||||
@@ -501,13 +540,12 @@ def reload_model_weights(sd_model=None, info=None):
|
||||
|
||||
if sd_model is None or checkpoint_config != sd_model.used_config:
|
||||
del sd_model
|
||||
checkpoints_loaded.clear()
|
||||
load_model(checkpoint_info, already_loaded_state_dict=state_dict)
|
||||
return shared.sd_model
|
||||
return model_data.sd_model
|
||||
|
||||
try:
|
||||
load_model_weights(sd_model, checkpoint_info, state_dict, timer)
|
||||
except Exception as e:
|
||||
except Exception:
|
||||
print("Failed to load checkpoint, restoring previous")
|
||||
load_model_weights(sd_model, current_checkpoint_info, None, timer)
|
||||
raise
|
||||
@@ -526,17 +564,15 @@ def reload_model_weights(sd_model=None, info=None):
|
||||
|
||||
return sd_model
|
||||
|
||||
|
||||
def unload_model_weights(sd_model=None, info=None):
|
||||
from modules import lowvram, devices, sd_hijack
|
||||
from modules import devices, sd_hijack
|
||||
timer = Timer()
|
||||
|
||||
if shared.sd_model:
|
||||
|
||||
# shared.sd_model.cond_stage_model.to(devices.cpu)
|
||||
# shared.sd_model.first_stage_model.to(devices.cpu)
|
||||
shared.sd_model.to(devices.cpu)
|
||||
sd_hijack.model_hijack.undo_hijack(shared.sd_model)
|
||||
shared.sd_model = None
|
||||
if model_data.sd_model:
|
||||
model_data.sd_model.to(devices.cpu)
|
||||
sd_hijack.model_hijack.undo_hijack(model_data.sd_model)
|
||||
model_data.sd_model = None
|
||||
sd_model = None
|
||||
gc.collect()
|
||||
devices.torch_gc()
|
||||
@@ -544,4 +580,30 @@ def unload_model_weights(sd_model=None, info=None):
|
||||
|
||||
print(f"Unloaded weights {timer.summary()}.")
|
||||
|
||||
return sd_model
|
||||
return sd_model
|
||||
|
||||
|
||||
def apply_token_merging(sd_model, token_merging_ratio):
|
||||
"""
|
||||
Applies speed and memory optimizations from tomesd.
|
||||
"""
|
||||
|
||||
current_token_merging_ratio = getattr(sd_model, 'applied_token_merged_ratio', 0)
|
||||
|
||||
if current_token_merging_ratio == token_merging_ratio:
|
||||
return
|
||||
|
||||
if current_token_merging_ratio > 0:
|
||||
tomesd.remove_patch(sd_model)
|
||||
|
||||
if token_merging_ratio > 0:
|
||||
tomesd.apply_patch(
|
||||
sd_model,
|
||||
ratio=token_merging_ratio,
|
||||
use_rand=False, # can cause issues with some samplers
|
||||
merge_attn=True,
|
||||
merge_crossattn=False,
|
||||
merge_mlp=False
|
||||
)
|
||||
|
||||
sd_model.applied_token_merged_ratio = token_merging_ratio
|
||||
|
||||
@@ -1,4 +1,3 @@
|
||||
import re
|
||||
import os
|
||||
|
||||
import torch
|
||||
@@ -111,7 +110,7 @@ def find_checkpoint_config_near_filename(info):
|
||||
if info is None:
|
||||
return None
|
||||
|
||||
config = os.path.splitext(info.filename)[0] + ".yaml"
|
||||
config = f"{os.path.splitext(info.filename)[0]}.yaml"
|
||||
if os.path.exists(config):
|
||||
return config
|
||||
|
||||
|
||||
@@ -1,7 +1,7 @@
|
||||
from modules import sd_samplers_compvis, sd_samplers_kdiffusion, shared
|
||||
|
||||
# imports for functions that previously were here and are used by other modules
|
||||
from modules.sd_samplers_common import samples_to_image_grid, sample_to_image
|
||||
from modules.sd_samplers_common import samples_to_image_grid, sample_to_image # noqa: F401
|
||||
|
||||
all_samplers = [
|
||||
*sd_samplers_kdiffusion.samplers_data_k_diffusion,
|
||||
@@ -14,12 +14,18 @@ samplers_for_img2img = []
|
||||
samplers_map = {}
|
||||
|
||||
|
||||
def create_sampler(name, model):
|
||||
def find_sampler_config(name):
|
||||
if name is not None:
|
||||
config = all_samplers_map.get(name, None)
|
||||
else:
|
||||
config = all_samplers[0]
|
||||
|
||||
return config
|
||||
|
||||
|
||||
def create_sampler(name, model):
|
||||
config = find_sampler_config(name)
|
||||
|
||||
assert config is not None, f'bad sampler name: {name}'
|
||||
|
||||
sampler = config.constructor(model)
|
||||
|
||||
@@ -2,7 +2,7 @@ from collections import namedtuple
|
||||
import numpy as np
|
||||
import torch
|
||||
from PIL import Image
|
||||
from modules import devices, processing, images, sd_vae_approx
|
||||
from modules import devices, processing, images, sd_vae_approx, sd_samplers, sd_vae_taesd
|
||||
|
||||
from modules.shared import opts, state
|
||||
import modules.shared as shared
|
||||
@@ -22,7 +22,7 @@ def setup_img2img_steps(p, steps=None):
|
||||
return steps, t_enc
|
||||
|
||||
|
||||
approximation_indexes = {"Full": 0, "Approx NN": 1, "Approx cheap": 2}
|
||||
approximation_indexes = {"Full": 0, "Approx NN": 1, "Approx cheap": 2, "TAESD": 3}
|
||||
|
||||
|
||||
def single_sample_to_image(sample, approximation=None):
|
||||
@@ -30,15 +30,19 @@ def single_sample_to_image(sample, approximation=None):
|
||||
approximation = approximation_indexes.get(opts.show_progress_type, 0)
|
||||
|
||||
if approximation == 2:
|
||||
x_sample = sd_vae_approx.cheap_approximation(sample)
|
||||
x_sample = sd_vae_approx.cheap_approximation(sample) * 0.5 + 0.5
|
||||
elif approximation == 1:
|
||||
x_sample = sd_vae_approx.model()(sample.to(devices.device, devices.dtype).unsqueeze(0))[0].detach()
|
||||
x_sample = sd_vae_approx.model()(sample.to(devices.device, devices.dtype).unsqueeze(0))[0].detach() * 0.5 + 0.5
|
||||
elif approximation == 3:
|
||||
x_sample = sample * 1.5
|
||||
x_sample = sd_vae_taesd.model()(x_sample.to(devices.device, devices.dtype).unsqueeze(0))[0].detach()
|
||||
else:
|
||||
x_sample = processing.decode_first_stage(shared.sd_model, sample.unsqueeze(0))[0]
|
||||
x_sample = processing.decode_first_stage(shared.sd_model, sample.unsqueeze(0))[0] * 0.5 + 0.5
|
||||
|
||||
x_sample = torch.clamp((x_sample + 1.0) / 2.0, min=0.0, max=1.0)
|
||||
x_sample = torch.clamp(x_sample, min=0.0, max=1.0)
|
||||
x_sample = 255. * np.moveaxis(x_sample.cpu().numpy(), 0, 2)
|
||||
x_sample = x_sample.astype(np.uint8)
|
||||
|
||||
return Image.fromarray(x_sample)
|
||||
|
||||
|
||||
@@ -58,5 +62,34 @@ def store_latent(decoded):
|
||||
shared.state.assign_current_image(sample_to_image(decoded))
|
||||
|
||||
|
||||
def is_sampler_using_eta_noise_seed_delta(p):
|
||||
"""returns whether sampler from config will use eta noise seed delta for image creation"""
|
||||
|
||||
sampler_config = sd_samplers.find_sampler_config(p.sampler_name)
|
||||
|
||||
eta = p.eta
|
||||
|
||||
if eta is None and p.sampler is not None:
|
||||
eta = p.sampler.eta
|
||||
|
||||
if eta is None and sampler_config is not None:
|
||||
eta = 0 if sampler_config.options.get("default_eta_is_0", False) else 1.0
|
||||
|
||||
if eta == 0:
|
||||
return False
|
||||
|
||||
return sampler_config.options.get("uses_ensd", False)
|
||||
|
||||
|
||||
class InterruptedException(BaseException):
|
||||
pass
|
||||
|
||||
|
||||
if opts.randn_source == "CPU":
|
||||
import torchsde._brownian.brownian_interval
|
||||
|
||||
def torchsde_randn(size, dtype, device, seed):
|
||||
generator = torch.Generator(devices.cpu).manual_seed(int(seed))
|
||||
return torch.randn(size, dtype=dtype, device=devices.cpu, generator=generator).to(device)
|
||||
|
||||
torchsde._brownian.brownian_interval._randn = torchsde_randn
|
||||
|
||||
@@ -11,7 +11,7 @@ import modules.models.diffusion.uni_pc
|
||||
|
||||
|
||||
samplers_data_compvis = [
|
||||
sd_samplers_common.SamplerData('DDIM', lambda model: VanillaStableDiffusionSampler(ldm.models.diffusion.ddim.DDIMSampler, model), [], {}),
|
||||
sd_samplers_common.SamplerData('DDIM', lambda model: VanillaStableDiffusionSampler(ldm.models.diffusion.ddim.DDIMSampler, model), [], {"default_eta_is_0": True, "uses_ensd": True}),
|
||||
sd_samplers_common.SamplerData('PLMS', lambda model: VanillaStableDiffusionSampler(ldm.models.diffusion.plms.PLMSSampler, model), [], {}),
|
||||
sd_samplers_common.SamplerData('UniPC', lambda model: VanillaStableDiffusionSampler(modules.models.diffusion.uni_pc.UniPCSampler, model), [], {}),
|
||||
]
|
||||
@@ -55,7 +55,7 @@ class VanillaStableDiffusionSampler:
|
||||
def p_sample_ddim_hook(self, x_dec, cond, ts, unconditional_conditioning, *args, **kwargs):
|
||||
x_dec, ts, cond, unconditional_conditioning = self.before_sample(x_dec, ts, cond, unconditional_conditioning)
|
||||
|
||||
res = self.orig_p_sample_ddim(x_dec, cond, ts, unconditional_conditioning=unconditional_conditioning, *args, **kwargs)
|
||||
res = self.orig_p_sample_ddim(x_dec, cond, ts, *args, unconditional_conditioning=unconditional_conditioning, **kwargs)
|
||||
|
||||
x_dec, ts, cond, unconditional_conditioning, res = self.after_sample(x_dec, ts, cond, unconditional_conditioning, res)
|
||||
|
||||
@@ -83,7 +83,7 @@ class VanillaStableDiffusionSampler:
|
||||
conds_list, tensor = prompt_parser.reconstruct_multicond_batch(cond, self.step)
|
||||
unconditional_conditioning = prompt_parser.reconstruct_cond_batch(unconditional_conditioning, self.step)
|
||||
|
||||
assert all([len(conds) == 1 for conds in conds_list]), 'composition via AND is not supported for DDIM/PLMS samplers'
|
||||
assert all(len(conds) == 1 for conds in conds_list), 'composition via AND is not supported for DDIM/PLMS samplers'
|
||||
cond = tensor
|
||||
|
||||
# for DDIM, shapes must match, we can't just process cond and uncond independently;
|
||||
@@ -134,7 +134,11 @@ class VanillaStableDiffusionSampler:
|
||||
self.update_step(x)
|
||||
|
||||
def initialize(self, p):
|
||||
self.eta = p.eta if p.eta is not None else shared.opts.eta_ddim
|
||||
if self.is_ddim:
|
||||
self.eta = p.eta if p.eta is not None else shared.opts.eta_ddim
|
||||
else:
|
||||
self.eta = 0.0
|
||||
|
||||
if self.eta != 0.0:
|
||||
p.extra_generation_params["Eta DDIM"] = self.eta
|
||||
|
||||
|
||||
@@ -1,7 +1,6 @@
|
||||
from collections import deque
|
||||
import torch
|
||||
import inspect
|
||||
import einops
|
||||
import k_diffusion.sampling
|
||||
from modules import prompt_parser, devices, sd_samplers_common
|
||||
|
||||
@@ -9,25 +8,26 @@ from modules.shared import opts, state
|
||||
import modules.shared as shared
|
||||
from modules.script_callbacks import CFGDenoiserParams, cfg_denoiser_callback
|
||||
from modules.script_callbacks import CFGDenoisedParams, cfg_denoised_callback
|
||||
from modules.script_callbacks import AfterCFGCallbackParams, cfg_after_cfg_callback
|
||||
|
||||
samplers_k_diffusion = [
|
||||
('Euler a', 'sample_euler_ancestral', ['k_euler_a', 'k_euler_ancestral'], {}),
|
||||
('Euler a', 'sample_euler_ancestral', ['k_euler_a', 'k_euler_ancestral'], {"uses_ensd": True}),
|
||||
('Euler', 'sample_euler', ['k_euler'], {}),
|
||||
('LMS', 'sample_lms', ['k_lms'], {}),
|
||||
('Heun', 'sample_heun', ['k_heun'], {}),
|
||||
('Heun', 'sample_heun', ['k_heun'], {"second_order": True}),
|
||||
('DPM2', 'sample_dpm_2', ['k_dpm_2'], {'discard_next_to_last_sigma': True}),
|
||||
('DPM2 a', 'sample_dpm_2_ancestral', ['k_dpm_2_a'], {'discard_next_to_last_sigma': True}),
|
||||
('DPM++ 2S a', 'sample_dpmpp_2s_ancestral', ['k_dpmpp_2s_a'], {}),
|
||||
('DPM2 a', 'sample_dpm_2_ancestral', ['k_dpm_2_a'], {'discard_next_to_last_sigma': True, "uses_ensd": True}),
|
||||
('DPM++ 2S a', 'sample_dpmpp_2s_ancestral', ['k_dpmpp_2s_a'], {"uses_ensd": True, "second_order": True}),
|
||||
('DPM++ 2M', 'sample_dpmpp_2m', ['k_dpmpp_2m'], {}),
|
||||
('DPM++ SDE', 'sample_dpmpp_sde', ['k_dpmpp_sde'], {}),
|
||||
('DPM fast', 'sample_dpm_fast', ['k_dpm_fast'], {}),
|
||||
('DPM adaptive', 'sample_dpm_adaptive', ['k_dpm_ad'], {}),
|
||||
('DPM++ SDE', 'sample_dpmpp_sde', ['k_dpmpp_sde'], {"second_order": True}),
|
||||
('DPM fast', 'sample_dpm_fast', ['k_dpm_fast'], {"uses_ensd": True}),
|
||||
('DPM adaptive', 'sample_dpm_adaptive', ['k_dpm_ad'], {"uses_ensd": True}),
|
||||
('LMS Karras', 'sample_lms', ['k_lms_ka'], {'scheduler': 'karras'}),
|
||||
('DPM2 Karras', 'sample_dpm_2', ['k_dpm_2_ka'], {'scheduler': 'karras', 'discard_next_to_last_sigma': True}),
|
||||
('DPM2 a Karras', 'sample_dpm_2_ancestral', ['k_dpm_2_a_ka'], {'scheduler': 'karras', 'discard_next_to_last_sigma': True}),
|
||||
('DPM++ 2S a Karras', 'sample_dpmpp_2s_ancestral', ['k_dpmpp_2s_a_ka'], {'scheduler': 'karras'}),
|
||||
('DPM2 Karras', 'sample_dpm_2', ['k_dpm_2_ka'], {'scheduler': 'karras', 'discard_next_to_last_sigma': True, "uses_ensd": True, "second_order": True}),
|
||||
('DPM2 a Karras', 'sample_dpm_2_ancestral', ['k_dpm_2_a_ka'], {'scheduler': 'karras', 'discard_next_to_last_sigma': True, "uses_ensd": True, "second_order": True}),
|
||||
('DPM++ 2S a Karras', 'sample_dpmpp_2s_ancestral', ['k_dpmpp_2s_a_ka'], {'scheduler': 'karras', "uses_ensd": True, "second_order": True}),
|
||||
('DPM++ 2M Karras', 'sample_dpmpp_2m', ['k_dpmpp_2m_ka'], {'scheduler': 'karras'}),
|
||||
('DPM++ SDE Karras', 'sample_dpmpp_sde', ['k_dpmpp_sde_ka'], {'scheduler': 'karras'}),
|
||||
('DPM++ SDE Karras', 'sample_dpmpp_sde', ['k_dpmpp_sde_ka'], {'scheduler': 'karras', "second_order": True}),
|
||||
]
|
||||
|
||||
samplers_data_k_diffusion = [
|
||||
@@ -76,7 +76,7 @@ class CFGDenoiser(torch.nn.Module):
|
||||
|
||||
return denoised
|
||||
|
||||
def forward(self, x, sigma, uncond, cond, cond_scale, image_cond):
|
||||
def forward(self, x, sigma, uncond, cond, cond_scale, s_min_uncond, image_cond):
|
||||
if state.interrupted or state.skipped:
|
||||
raise sd_samplers_common.InterruptedException
|
||||
|
||||
@@ -87,17 +87,17 @@ class CFGDenoiser(torch.nn.Module):
|
||||
conds_list, tensor = prompt_parser.reconstruct_multicond_batch(cond, self.step)
|
||||
uncond = prompt_parser.reconstruct_cond_batch(uncond, self.step)
|
||||
|
||||
assert not is_edit_model or all([len(conds) == 1 for conds in conds_list]), "AND is not supported for InstructPix2Pix checkpoint (unless using Image CFG scale = 1.0)"
|
||||
assert not is_edit_model or all(len(conds) == 1 for conds in conds_list), "AND is not supported for InstructPix2Pix checkpoint (unless using Image CFG scale = 1.0)"
|
||||
|
||||
batch_size = len(conds_list)
|
||||
repeats = [len(conds_list[i]) for i in range(batch_size)]
|
||||
|
||||
if shared.sd_model.model.conditioning_key == "crossattn-adm":
|
||||
image_uncond = torch.zeros_like(image_cond)
|
||||
make_condition_dict = lambda c_crossattn, c_adm: {"c_crossattn": c_crossattn, "c_adm": c_adm}
|
||||
make_condition_dict = lambda c_crossattn, c_adm: {"c_crossattn": c_crossattn, "c_adm": c_adm}
|
||||
else:
|
||||
image_uncond = image_cond
|
||||
make_condition_dict = lambda c_crossattn, c_concat: {"c_crossattn": c_crossattn, "c_concat": [c_concat]}
|
||||
make_condition_dict = lambda c_crossattn, c_concat: {"c_crossattn": c_crossattn, "c_concat": [c_concat]}
|
||||
|
||||
if not is_edit_model:
|
||||
x_in = torch.cat([torch.stack([x[i] for _ in range(n)]) for i, n in enumerate(repeats)] + [x])
|
||||
@@ -115,12 +115,21 @@ class CFGDenoiser(torch.nn.Module):
|
||||
sigma_in = denoiser_params.sigma
|
||||
tensor = denoiser_params.text_cond
|
||||
uncond = denoiser_params.text_uncond
|
||||
skip_uncond = False
|
||||
|
||||
if tensor.shape[1] == uncond.shape[1]:
|
||||
if not is_edit_model:
|
||||
cond_in = torch.cat([tensor, uncond])
|
||||
else:
|
||||
# alternating uncond allows for higher thresholds without the quality loss normally expected from raising it
|
||||
if self.step % 2 and s_min_uncond > 0 and sigma[0] < s_min_uncond and not is_edit_model:
|
||||
skip_uncond = True
|
||||
x_in = x_in[:-batch_size]
|
||||
sigma_in = sigma_in[:-batch_size]
|
||||
|
||||
if tensor.shape[1] == uncond.shape[1] or skip_uncond:
|
||||
if is_edit_model:
|
||||
cond_in = torch.cat([tensor, uncond, uncond])
|
||||
elif skip_uncond:
|
||||
cond_in = tensor
|
||||
else:
|
||||
cond_in = torch.cat([tensor, uncond])
|
||||
|
||||
if shared.batch_cond_uncond:
|
||||
x_out = self.inner_model(x_in, sigma_in, cond=make_condition_dict([cond_in], image_cond_in))
|
||||
@@ -144,28 +153,39 @@ class CFGDenoiser(torch.nn.Module):
|
||||
|
||||
x_out[a:b] = self.inner_model(x_in[a:b], sigma_in[a:b], cond=make_condition_dict(c_crossattn, image_cond_in[a:b]))
|
||||
|
||||
x_out[-uncond.shape[0]:] = self.inner_model(x_in[-uncond.shape[0]:], sigma_in[-uncond.shape[0]:], cond=make_condition_dict([uncond], image_cond_in[-uncond.shape[0]:]))
|
||||
if not skip_uncond:
|
||||
x_out[-uncond.shape[0]:] = self.inner_model(x_in[-uncond.shape[0]:], sigma_in[-uncond.shape[0]:], cond=make_condition_dict([uncond], image_cond_in[-uncond.shape[0]:]))
|
||||
|
||||
denoised_params = CFGDenoisedParams(x_out, state.sampling_step, state.sampling_steps)
|
||||
denoised_image_indexes = [x[0][0] for x in conds_list]
|
||||
if skip_uncond:
|
||||
fake_uncond = torch.cat([x_out[i:i+1] for i in denoised_image_indexes])
|
||||
x_out = torch.cat([x_out, fake_uncond]) # we skipped uncond denoising, so we put cond-denoised image to where the uncond-denoised image should be
|
||||
|
||||
denoised_params = CFGDenoisedParams(x_out, state.sampling_step, state.sampling_steps, self.inner_model)
|
||||
cfg_denoised_callback(denoised_params)
|
||||
|
||||
devices.test_for_nans(x_out, "unet")
|
||||
|
||||
if opts.live_preview_content == "Prompt":
|
||||
sd_samplers_common.store_latent(x_out[0:uncond.shape[0]])
|
||||
sd_samplers_common.store_latent(torch.cat([x_out[i:i+1] for i in denoised_image_indexes]))
|
||||
elif opts.live_preview_content == "Negative prompt":
|
||||
sd_samplers_common.store_latent(x_out[-uncond.shape[0]:])
|
||||
|
||||
if not is_edit_model:
|
||||
denoised = self.combine_denoised(x_out, conds_list, uncond, cond_scale)
|
||||
else:
|
||||
if is_edit_model:
|
||||
denoised = self.combine_denoised_for_edit_model(x_out, cond_scale)
|
||||
elif skip_uncond:
|
||||
denoised = self.combine_denoised(x_out, conds_list, uncond, 1.0)
|
||||
else:
|
||||
denoised = self.combine_denoised(x_out, conds_list, uncond, cond_scale)
|
||||
|
||||
if self.mask is not None:
|
||||
denoised = self.init_latent * self.mask + self.nmask * denoised
|
||||
|
||||
self.step += 1
|
||||
after_cfg_callback_params = AfterCFGCallbackParams(denoised, state.sampling_step, state.sampling_steps)
|
||||
cfg_after_cfg_callback(after_cfg_callback_params)
|
||||
denoised = after_cfg_callback_params.x
|
||||
|
||||
self.step += 1
|
||||
return denoised
|
||||
|
||||
|
||||
@@ -182,7 +202,7 @@ class TorchHijack:
|
||||
if hasattr(torch, item):
|
||||
return getattr(torch, item)
|
||||
|
||||
raise AttributeError("'{}' object has no attribute '{}'".format(type(self).__name__, item))
|
||||
raise AttributeError(f"'{type(self).__name__}' object has no attribute '{item}'")
|
||||
|
||||
def randn_like(self, x):
|
||||
if self.sampler_noises:
|
||||
@@ -190,7 +210,7 @@ class TorchHijack:
|
||||
if noise.shape == x.shape:
|
||||
return noise
|
||||
|
||||
if x.device.type == 'mps':
|
||||
if opts.randn_source == "CPU" or x.device.type == 'mps':
|
||||
return torch.randn_like(x, device=devices.cpu).to(x.device)
|
||||
else:
|
||||
return torch.randn_like(x)
|
||||
@@ -210,6 +230,7 @@ class KDiffusionSampler:
|
||||
self.eta = None
|
||||
self.config = None
|
||||
self.last_latent = None
|
||||
self.s_min_uncond = None
|
||||
|
||||
self.conditioning_key = sd_model.model.conditioning_key
|
||||
|
||||
@@ -244,6 +265,7 @@ class KDiffusionSampler:
|
||||
self.model_wrap_cfg.step = 0
|
||||
self.model_wrap_cfg.image_cfg_scale = getattr(p, 'image_cfg_scale', None)
|
||||
self.eta = p.eta if p.eta is not None else opts.eta_ancestral
|
||||
self.s_min_uncond = getattr(p, 's_min_uncond', 0.0)
|
||||
|
||||
k_diffusion.sampling.torch = TorchHijack(self.sampler_noises if self.sampler_noises is not None else [])
|
||||
|
||||
@@ -299,7 +321,7 @@ class KDiffusionSampler:
|
||||
|
||||
sigma_sched = sigmas[steps - t_enc - 1:]
|
||||
xi = x + noise * sigma_sched[0]
|
||||
|
||||
|
||||
extra_params_kwargs = self.initialize(p)
|
||||
parameters = inspect.signature(self.func).parameters
|
||||
|
||||
@@ -322,10 +344,11 @@ class KDiffusionSampler:
|
||||
self.model_wrap_cfg.init_latent = x
|
||||
self.last_latent = x
|
||||
extra_args={
|
||||
'cond': conditioning,
|
||||
'image_cond': image_conditioning,
|
||||
'uncond': unconditional_conditioning,
|
||||
'cond': conditioning,
|
||||
'image_cond': image_conditioning,
|
||||
'uncond': unconditional_conditioning,
|
||||
'cond_scale': p.cfg_scale,
|
||||
's_min_uncond': self.s_min_uncond
|
||||
}
|
||||
|
||||
samples = self.launch_sampling(t_enc + 1, lambda: self.func(self.model_wrap_cfg, xi, extra_args=extra_args, disable=False, callback=self.callback_state, **extra_params_kwargs))
|
||||
@@ -356,10 +379,11 @@ class KDiffusionSampler:
|
||||
|
||||
self.last_latent = x
|
||||
samples = self.launch_sampling(steps, lambda: self.func(self.model_wrap_cfg, x, extra_args={
|
||||
'cond': conditioning,
|
||||
'image_cond': image_conditioning,
|
||||
'uncond': unconditional_conditioning,
|
||||
'cond_scale': p.cfg_scale
|
||||
'cond': conditioning,
|
||||
'image_cond': image_conditioning,
|
||||
'uncond': unconditional_conditioning,
|
||||
'cond_scale': p.cfg_scale,
|
||||
's_min_uncond': self.s_min_uncond
|
||||
}, disable=False, callback=self.callback_state, **extra_params_kwargs))
|
||||
|
||||
return samples
|
||||
|
||||
@@ -1,8 +1,5 @@
|
||||
import torch
|
||||
import safetensors.torch
|
||||
import os
|
||||
import collections
|
||||
from collections import namedtuple
|
||||
from modules import paths, shared, devices, script_callbacks, sd_models
|
||||
import glob
|
||||
from copy import deepcopy
|
||||
@@ -88,10 +85,10 @@ def refresh_vae_list():
|
||||
|
||||
|
||||
def find_vae_near_checkpoint(checkpoint_file):
|
||||
checkpoint_path = os.path.splitext(checkpoint_file)[0]
|
||||
for vae_location in [checkpoint_path + ".vae.pt", checkpoint_path + ".vae.ckpt", checkpoint_path + ".vae.safetensors"]:
|
||||
if os.path.isfile(vae_location):
|
||||
return vae_location
|
||||
checkpoint_path = os.path.basename(checkpoint_file).rsplit('.', 1)[0]
|
||||
for vae_file in vae_dict.values():
|
||||
if os.path.basename(vae_file).startswith(checkpoint_path):
|
||||
return vae_file
|
||||
|
||||
return None
|
||||
|
||||
|
||||
88
modules/sd_vae_taesd.py
Normal file
88
modules/sd_vae_taesd.py
Normal file
@@ -0,0 +1,88 @@
|
||||
"""
|
||||
Tiny AutoEncoder for Stable Diffusion
|
||||
(DNN for encoding / decoding SD's latent space)
|
||||
|
||||
https://github.com/madebyollin/taesd
|
||||
"""
|
||||
import os
|
||||
import torch
|
||||
import torch.nn as nn
|
||||
|
||||
from modules import devices, paths_internal
|
||||
|
||||
sd_vae_taesd = None
|
||||
|
||||
|
||||
def conv(n_in, n_out, **kwargs):
|
||||
return nn.Conv2d(n_in, n_out, 3, padding=1, **kwargs)
|
||||
|
||||
|
||||
class Clamp(nn.Module):
|
||||
@staticmethod
|
||||
def forward(x):
|
||||
return torch.tanh(x / 3) * 3
|
||||
|
||||
|
||||
class Block(nn.Module):
|
||||
def __init__(self, n_in, n_out):
|
||||
super().__init__()
|
||||
self.conv = nn.Sequential(conv(n_in, n_out), nn.ReLU(), conv(n_out, n_out), nn.ReLU(), conv(n_out, n_out))
|
||||
self.skip = nn.Conv2d(n_in, n_out, 1, bias=False) if n_in != n_out else nn.Identity()
|
||||
self.fuse = nn.ReLU()
|
||||
|
||||
def forward(self, x):
|
||||
return self.fuse(self.conv(x) + self.skip(x))
|
||||
|
||||
|
||||
def decoder():
|
||||
return nn.Sequential(
|
||||
Clamp(), conv(4, 64), nn.ReLU(),
|
||||
Block(64, 64), Block(64, 64), Block(64, 64), nn.Upsample(scale_factor=2), conv(64, 64, bias=False),
|
||||
Block(64, 64), Block(64, 64), Block(64, 64), nn.Upsample(scale_factor=2), conv(64, 64, bias=False),
|
||||
Block(64, 64), Block(64, 64), Block(64, 64), nn.Upsample(scale_factor=2), conv(64, 64, bias=False),
|
||||
Block(64, 64), conv(64, 3),
|
||||
)
|
||||
|
||||
|
||||
class TAESD(nn.Module):
|
||||
latent_magnitude = 3
|
||||
latent_shift = 0.5
|
||||
|
||||
def __init__(self, decoder_path="taesd_decoder.pth"):
|
||||
"""Initialize pretrained TAESD on the given device from the given checkpoints."""
|
||||
super().__init__()
|
||||
self.decoder = decoder()
|
||||
self.decoder.load_state_dict(
|
||||
torch.load(decoder_path, map_location='cpu' if devices.device.type != 'cuda' else None))
|
||||
|
||||
@staticmethod
|
||||
def unscale_latents(x):
|
||||
"""[0, 1] -> raw latents"""
|
||||
return x.sub(TAESD.latent_shift).mul(2 * TAESD.latent_magnitude)
|
||||
|
||||
|
||||
def download_model(model_path):
|
||||
model_url = 'https://github.com/madebyollin/taesd/raw/main/taesd_decoder.pth'
|
||||
|
||||
if not os.path.exists(model_path):
|
||||
os.makedirs(os.path.dirname(model_path), exist_ok=True)
|
||||
|
||||
print(f'Downloading TAESD decoder to: {model_path}')
|
||||
torch.hub.download_url_to_file(model_url, model_path)
|
||||
|
||||
|
||||
def model():
|
||||
global sd_vae_taesd
|
||||
|
||||
if sd_vae_taesd is None:
|
||||
model_path = os.path.join(paths_internal.models_path, "VAE-taesd", "taesd_decoder.pth")
|
||||
download_model(model_path)
|
||||
|
||||
if os.path.exists(model_path):
|
||||
sd_vae_taesd = TAESD(model_path)
|
||||
sd_vae_taesd.eval()
|
||||
sd_vae_taesd.to(devices.device, devices.dtype)
|
||||
else:
|
||||
raise FileNotFoundError('TAESD model not found')
|
||||
|
||||
return sd_vae_taesd.decoder
|
||||
@@ -1,11 +1,10 @@
|
||||
import argparse
|
||||
import datetime
|
||||
import json
|
||||
import os
|
||||
import sys
|
||||
import threading
|
||||
import time
|
||||
|
||||
from PIL import Image
|
||||
import gradio as gr
|
||||
import tqdm
|
||||
|
||||
@@ -14,7 +13,8 @@ import modules.memmon
|
||||
import modules.styles
|
||||
import modules.devices as devices
|
||||
from modules import localization, script_loading, errors, ui_components, shared_items, cmd_args
|
||||
from modules.paths_internal import models_path, script_path, data_path, sd_configs_path, sd_default_config, sd_model_file, default_sd_model_file, extensions_dir, extensions_builtin_dir
|
||||
from modules.paths_internal import models_path, script_path, data_path, sd_configs_path, sd_default_config, sd_model_file, default_sd_model_file, extensions_dir, extensions_builtin_dir # noqa: F401
|
||||
from ldm.models.diffusion.ddpm import LatentDiffusion
|
||||
|
||||
demo = None
|
||||
|
||||
@@ -39,6 +39,7 @@ restricted_opts = {
|
||||
"outdir_grids",
|
||||
"outdir_txt2img_grids",
|
||||
"outdir_save",
|
||||
"outdir_init_images"
|
||||
}
|
||||
|
||||
ui_reorder_categories = [
|
||||
@@ -54,6 +55,21 @@ ui_reorder_categories = [
|
||||
"scripts",
|
||||
]
|
||||
|
||||
# https://huggingface.co/datasets/freddyaboulton/gradio-theme-subdomains/resolve/main/subdomains.json
|
||||
gradio_hf_hub_themes = [
|
||||
"gradio/glass",
|
||||
"gradio/monochrome",
|
||||
"gradio/seafoam",
|
||||
"gradio/soft",
|
||||
"freddyaboulton/dracula_revamped",
|
||||
"gradio/dracula_test",
|
||||
"abidlabs/dracula_test",
|
||||
"abidlabs/pakistan",
|
||||
"dawood/microsoft_windows",
|
||||
"ysharma/steampunk"
|
||||
]
|
||||
|
||||
|
||||
cmd_opts.disable_extension_access = (cmd_opts.share or cmd_opts.listen or cmd_opts.server_name) and not cmd_opts.enable_insecure_extension_access
|
||||
|
||||
devices.device, devices.device_interrogate, devices.device_gfpgan, devices.device_esrgan, devices.device_codeformer = \
|
||||
@@ -95,8 +111,47 @@ class State:
|
||||
id_live_preview = 0
|
||||
textinfo = None
|
||||
time_start = None
|
||||
need_restart = False
|
||||
server_start = None
|
||||
_server_command_signal = threading.Event()
|
||||
_server_command: str | None = None
|
||||
|
||||
@property
|
||||
def need_restart(self) -> bool:
|
||||
# Compatibility getter for need_restart.
|
||||
return self.server_command == "restart"
|
||||
|
||||
@need_restart.setter
|
||||
def need_restart(self, value: bool) -> None:
|
||||
# Compatibility setter for need_restart.
|
||||
if value:
|
||||
self.server_command = "restart"
|
||||
|
||||
@property
|
||||
def server_command(self):
|
||||
return self._server_command
|
||||
|
||||
@server_command.setter
|
||||
def server_command(self, value: str | None) -> None:
|
||||
"""
|
||||
Set the server command to `value` and signal that it's been set.
|
||||
"""
|
||||
self._server_command = value
|
||||
self._server_command_signal.set()
|
||||
|
||||
def wait_for_server_command(self, timeout: float | None = None) -> str | None:
|
||||
"""
|
||||
Wait for server command to get set; return and clear the value and signal.
|
||||
"""
|
||||
if self._server_command_signal.wait(timeout):
|
||||
self._server_command_signal.clear()
|
||||
req = self._server_command
|
||||
self._server_command = None
|
||||
return req
|
||||
return None
|
||||
|
||||
def request_restart(self) -> None:
|
||||
self.interrupt()
|
||||
self.server_command = "restart"
|
||||
|
||||
def skip(self):
|
||||
self.skipped = True
|
||||
@@ -184,8 +239,9 @@ interrogator = modules.interrogate.InterrogateModels("interrogate")
|
||||
|
||||
face_restorers = []
|
||||
|
||||
|
||||
class OptionInfo:
|
||||
def __init__(self, default=None, label="", component=None, component_args=None, onchange=None, section=None, refresh=None):
|
||||
def __init__(self, default=None, label="", component=None, component_args=None, onchange=None, section=None, refresh=None, comment_before='', comment_after=''):
|
||||
self.default = default
|
||||
self.label = label
|
||||
self.component = component
|
||||
@@ -194,9 +250,33 @@ class OptionInfo:
|
||||
self.section = section
|
||||
self.refresh = refresh
|
||||
|
||||
self.comment_before = comment_before
|
||||
"""HTML text that will be added after label in UI"""
|
||||
|
||||
self.comment_after = comment_after
|
||||
"""HTML text that will be added before label in UI"""
|
||||
|
||||
def link(self, label, url):
|
||||
self.comment_before += f"[<a href='{url}' target='_blank'>{label}</a>]"
|
||||
return self
|
||||
|
||||
def js(self, label, js_func):
|
||||
self.comment_before += f"[<a onclick='{js_func}(); return false'>{label}</a>]"
|
||||
return self
|
||||
|
||||
def info(self, info):
|
||||
self.comment_after += f"<span class='info'>({info})</span>"
|
||||
return self
|
||||
|
||||
def needs_restart(self):
|
||||
self.comment_after += " <span class='info'>(requires restart)</span>"
|
||||
return self
|
||||
|
||||
|
||||
|
||||
|
||||
def options_section(section_identifier, options_dict):
|
||||
for k, v in options_dict.items():
|
||||
for v in options_dict.values():
|
||||
v.section = section_identifier
|
||||
|
||||
return options_dict
|
||||
@@ -225,7 +305,7 @@ options_templates = {}
|
||||
options_templates.update(options_section(('saving-images', "Saving images/grids"), {
|
||||
"samples_save": OptionInfo(True, "Always save all generated images"),
|
||||
"samples_format": OptionInfo('png', 'File format for images'),
|
||||
"samples_filename_pattern": OptionInfo("", "Images filename pattern", component_args=hide_dirs),
|
||||
"samples_filename_pattern": OptionInfo("", "Images filename pattern", component_args=hide_dirs).link("wiki", "https://github.com/AUTOMATIC1111/stable-diffusion-webui/wiki/Custom-Images-Filename-Name-and-Subdirectory"),
|
||||
"save_images_add_number": OptionInfo(True, "Add number to filename when saving", component_args=hide_dirs),
|
||||
|
||||
"grid_save": OptionInfo(True, "Always save all generated image grids"),
|
||||
@@ -244,15 +324,15 @@ options_templates.update(options_section(('saving-images', "Saving images/grids"
|
||||
"save_mask_composite": OptionInfo(False, "For inpainting, save a masked composite"),
|
||||
"jpeg_quality": OptionInfo(80, "Quality for saved jpeg images", gr.Slider, {"minimum": 1, "maximum": 100, "step": 1}),
|
||||
"webp_lossless": OptionInfo(False, "Use lossless compression for webp images"),
|
||||
"export_for_4chan": OptionInfo(True, "If the saved image file size is above the limit, or its either width or height are above the limit, save a downscaled copy as JPG"),
|
||||
"export_for_4chan": OptionInfo(True, "Save copy of large images as JPG").info("if the file size is above the limit, or either width or height are above the limit"),
|
||||
"img_downscale_threshold": OptionInfo(4.0, "File size limit for the above option, MB", gr.Number),
|
||||
"target_side_length": OptionInfo(4000, "Width/height limit for the above option, in pixels", gr.Number),
|
||||
"img_max_size_mp": OptionInfo(200, "Maximum image size, in megapixels", gr.Number),
|
||||
"img_max_size_mp": OptionInfo(200, "Maximum image size", gr.Number).info("in megapixels"),
|
||||
|
||||
"use_original_name_batch": OptionInfo(True, "Use original name for output filename during batch process in extras tab"),
|
||||
"use_upscaler_name_as_suffix": OptionInfo(False, "Use upscaler name as filename suffix in the extras tab"),
|
||||
"save_selected_only": OptionInfo(True, "When using 'Save' button, only save a single selected image"),
|
||||
"do_not_add_watermark": OptionInfo(False, "Do not add watermark to images"),
|
||||
"save_init_img": OptionInfo(False, "Save init images when using img2img"),
|
||||
|
||||
"temp_dir": OptionInfo("", "Directory for temporary images; leave empty for default"),
|
||||
"clean_temp_dir_at_start": OptionInfo(False, "Cleanup non-default temporary directory when starting webui"),
|
||||
@@ -268,35 +348,37 @@ options_templates.update(options_section(('saving-paths', "Paths for saving"), {
|
||||
"outdir_txt2img_grids": OptionInfo("outputs/txt2img-grids", 'Output directory for txt2img grids', component_args=hide_dirs),
|
||||
"outdir_img2img_grids": OptionInfo("outputs/img2img-grids", 'Output directory for img2img grids', component_args=hide_dirs),
|
||||
"outdir_save": OptionInfo("log/images", "Directory for saving images using the Save button", component_args=hide_dirs),
|
||||
"outdir_init_images": OptionInfo("outputs/init-images", "Directory for saving init images when using img2img", component_args=hide_dirs),
|
||||
}))
|
||||
|
||||
options_templates.update(options_section(('saving-to-dirs', "Saving to a directory"), {
|
||||
"save_to_dirs": OptionInfo(True, "Save images to a subdirectory"),
|
||||
"grid_save_to_dirs": OptionInfo(True, "Save grids to a subdirectory"),
|
||||
"use_save_to_dirs_for_ui": OptionInfo(False, "When using \"Save\" button, save images to a subdirectory"),
|
||||
"directories_filename_pattern": OptionInfo("[date]", "Directory name pattern", component_args=hide_dirs),
|
||||
"directories_filename_pattern": OptionInfo("[date]", "Directory name pattern", component_args=hide_dirs).link("wiki", "https://github.com/AUTOMATIC1111/stable-diffusion-webui/wiki/Custom-Images-Filename-Name-and-Subdirectory"),
|
||||
"directories_max_prompt_words": OptionInfo(8, "Max prompt words for [prompt_words] pattern", gr.Slider, {"minimum": 1, "maximum": 20, "step": 1, **hide_dirs}),
|
||||
}))
|
||||
|
||||
options_templates.update(options_section(('upscaling', "Upscaling"), {
|
||||
"ESRGAN_tile": OptionInfo(192, "Tile size for ESRGAN upscalers. 0 = no tiling.", gr.Slider, {"minimum": 0, "maximum": 512, "step": 16}),
|
||||
"ESRGAN_tile_overlap": OptionInfo(8, "Tile overlap, in pixels for ESRGAN upscalers. Low values = visible seam.", gr.Slider, {"minimum": 0, "maximum": 48, "step": 1}),
|
||||
"realesrgan_enabled_models": OptionInfo(["R-ESRGAN 4x+", "R-ESRGAN 4x+ Anime6B"], "Select which Real-ESRGAN models to show in the web UI. (Requires restart)", gr.CheckboxGroup, lambda: {"choices": shared_items.realesrgan_models_names()}),
|
||||
"ESRGAN_tile": OptionInfo(192, "Tile size for ESRGAN upscalers.", gr.Slider, {"minimum": 0, "maximum": 512, "step": 16}).info("0 = no tiling"),
|
||||
"ESRGAN_tile_overlap": OptionInfo(8, "Tile overlap for ESRGAN upscalers.", gr.Slider, {"minimum": 0, "maximum": 48, "step": 1}).info("Low values = visible seam"),
|
||||
"realesrgan_enabled_models": OptionInfo(["R-ESRGAN 4x+", "R-ESRGAN 4x+ Anime6B"], "Select which Real-ESRGAN models to show in the web UI.", gr.CheckboxGroup, lambda: {"choices": shared_items.realesrgan_models_names()}),
|
||||
"upscaler_for_img2img": OptionInfo(None, "Upscaler for img2img", gr.Dropdown, lambda: {"choices": [x.name for x in sd_upscalers]}),
|
||||
}))
|
||||
|
||||
options_templates.update(options_section(('face-restoration', "Face restoration"), {
|
||||
"face_restoration_model": OptionInfo("CodeFormer", "Face restoration model", gr.Radio, lambda: {"choices": [x.name() for x in face_restorers]}),
|
||||
"code_former_weight": OptionInfo(0.5, "CodeFormer weight parameter; 0 = maximum effect; 1 = minimum effect", gr.Slider, {"minimum": 0, "maximum": 1, "step": 0.01}),
|
||||
"code_former_weight": OptionInfo(0.5, "CodeFormer weight", gr.Slider, {"minimum": 0, "maximum": 1, "step": 0.01}).info("0 = maximum effect; 1 = minimum effect"),
|
||||
"face_restoration_unload": OptionInfo(False, "Move face restoration model from VRAM into RAM after processing"),
|
||||
}))
|
||||
|
||||
options_templates.update(options_section(('system', "System"), {
|
||||
"show_warnings": OptionInfo(False, "Show warnings in console."),
|
||||
"memmon_poll_rate": OptionInfo(8, "VRAM usage polls per second during generation. Set to 0 to disable.", gr.Slider, {"minimum": 0, "maximum": 40, "step": 1}),
|
||||
"memmon_poll_rate": OptionInfo(8, "VRAM usage polls per second during generation.", gr.Slider, {"minimum": 0, "maximum": 40, "step": 1}).info("0 = disable"),
|
||||
"samples_log_stdout": OptionInfo(False, "Always print all generation info to standard output"),
|
||||
"multiple_tqdm": OptionInfo(True, "Add a second progress bar to the console that shows progress for an entire job."),
|
||||
"print_hypernet_extra": OptionInfo(False, "Print extra hypernetwork information to console."),
|
||||
"list_hidden_files": OptionInfo(True, "Load models/files in hidden directories").info("directory is hidden if its name starts with \".\""),
|
||||
}))
|
||||
|
||||
options_templates.update(options_section(('training', "Training"), {
|
||||
@@ -318,19 +400,27 @@ options_templates.update(options_section(('sd', "Stable Diffusion"), {
|
||||
"sd_model_checkpoint": OptionInfo(None, "Stable Diffusion checkpoint", gr.Dropdown, lambda: {"choices": list_checkpoint_tiles()}, refresh=refresh_checkpoints),
|
||||
"sd_checkpoint_cache": OptionInfo(0, "Checkpoints to cache in RAM", gr.Slider, {"minimum": 0, "maximum": 10, "step": 1}),
|
||||
"sd_vae_checkpoint_cache": OptionInfo(0, "VAE Checkpoints to cache in RAM", gr.Slider, {"minimum": 0, "maximum": 10, "step": 1}),
|
||||
"sd_vae": OptionInfo("Automatic", "SD VAE", gr.Dropdown, lambda: {"choices": shared_items.sd_vae_items()}, refresh=shared_items.refresh_vae_list),
|
||||
"sd_vae": OptionInfo("Automatic", "SD VAE", gr.Dropdown, lambda: {"choices": shared_items.sd_vae_items()}, refresh=shared_items.refresh_vae_list).info("choose VAE model: Automatic = use one with same filename as checkpoint; None = use VAE from checkpoint"),
|
||||
"sd_vae_as_default": OptionInfo(True, "Ignore selected VAE for stable diffusion checkpoints that have their own .vae.pt next to them"),
|
||||
"inpainting_mask_weight": OptionInfo(1.0, "Inpainting conditioning mask strength", gr.Slider, {"minimum": 0.0, "maximum": 1.0, "step": 0.01}),
|
||||
"initial_noise_multiplier": OptionInfo(1.0, "Noise multiplier for img2img", gr.Slider, {"minimum": 0.5, "maximum": 1.5, "step": 0.01}),
|
||||
"img2img_color_correction": OptionInfo(False, "Apply color correction to img2img results to match original colors."),
|
||||
"img2img_fix_steps": OptionInfo(False, "With img2img, do exactly the amount of steps the slider specifies (normally you'd do less with less denoising)."),
|
||||
"img2img_fix_steps": OptionInfo(False, "With img2img, do exactly the amount of steps the slider specifies.").info("normally you'd do less with less denoising"),
|
||||
"img2img_background_color": OptionInfo("#ffffff", "With img2img, fill image's transparent parts with this color.", ui_components.FormColorPicker, {}),
|
||||
"enable_quantization": OptionInfo(False, "Enable quantization in K samplers for sharper and cleaner results. This may change existing seeds. Requires restart to apply."),
|
||||
"enable_emphasis": OptionInfo(True, "Emphasis: use (text) to make model pay more attention to text and [text] to make it pay less attention"),
|
||||
"enable_emphasis": OptionInfo(True, "Enable emphasis").info("use (text) to make model pay more attention to text and [text] to make it pay less attention"),
|
||||
"enable_batch_seeds": OptionInfo(True, "Make K-diffusion samplers produce same images in a batch as when making a single image"),
|
||||
"comma_padding_backtrack": OptionInfo(20, "Increase coherency by padding from the last comma within n tokens when using more than 75 tokens", gr.Slider, {"minimum": 0, "maximum": 74, "step": 1 }),
|
||||
"CLIP_stop_at_last_layers": OptionInfo(1, "Clip skip", gr.Slider, {"minimum": 1, "maximum": 12, "step": 1}),
|
||||
"comma_padding_backtrack": OptionInfo(20, "Prompt word wrap length limit", gr.Slider, {"minimum": 0, "maximum": 74, "step": 1}).info("in tokens - for texts shorter than specified, if they don't fit into 75 token limit, move them to the next 75 token chunk"),
|
||||
"CLIP_stop_at_last_layers": OptionInfo(1, "Clip skip", gr.Slider, {"minimum": 1, "maximum": 12, "step": 1}).link("wiki", "https://github.com/AUTOMATIC1111/stable-diffusion-webui/wiki/Features#clip-skip").info("ignore last layers of CLIP nrtwork; 1 ignores none, 2 ignores one layer"),
|
||||
"upcast_attn": OptionInfo(False, "Upcast cross attention layer to float32"),
|
||||
"randn_source": OptionInfo("GPU", "Random number generator source.", gr.Radio, {"choices": ["GPU", "CPU"]}).info("changes seeds drastically; use CPU to produce the same picture across different vidocard vendors"),
|
||||
}))
|
||||
|
||||
options_templates.update(options_section(('optimizations', "Optimizations"), {
|
||||
"s_min_uncond": OptionInfo(0, "Negative Guidance minimum sigma", gr.Slider, {"minimum": 0.0, "maximum": 4.0, "step": 0.01}).link("PR", "https://github.com/AUTOMATIC1111/stable-diffusion-webui/pull/9177").info("skip negative prompt for some steps when the image is almost ready; 0=disable, higher=faster"),
|
||||
"token_merging_ratio": OptionInfo(0.0, "Token merging ratio", gr.Slider, {"minimum": 0.0, "maximum": 0.9, "step": 0.1}).link("PR", "https://github.com/AUTOMATIC1111/stable-diffusion-webui/pull/9256").info("0=disable, higher=faster"),
|
||||
"token_merging_ratio_img2img": OptionInfo(0.0, "Token merging ratio for img2img", gr.Slider, {"minimum": 0.0, "maximum": 0.9, "step": 0.1}).info("only applies if non-zero and overrides above"),
|
||||
"token_merging_ratio_hr": OptionInfo(0.0, "Token merging ratio for high-res pass", gr.Slider, {"minimum": 0.0, "maximum": 0.9, "step": 0.1}).info("only applies if non-zero and overrides above"),
|
||||
}))
|
||||
|
||||
options_templates.update(options_section(('compatibility', "Compatibility"), {
|
||||
@@ -338,80 +428,93 @@ options_templates.update(options_section(('compatibility', "Compatibility"), {
|
||||
"use_old_karras_scheduler_sigmas": OptionInfo(False, "Use old karras scheduler sigmas (0.1 to 10)."),
|
||||
"no_dpmpp_sde_batch_determinism": OptionInfo(False, "Do not make DPM++ SDE deterministic across different batch sizes."),
|
||||
"use_old_hires_fix_width_height": OptionInfo(False, "For hires fix, use width/height sliders to set final resolution rather than first pass (disables Upscale by, Resize width/height to)."),
|
||||
"dont_fix_second_order_samplers_schedule": OptionInfo(False, "Do not fix prompt schedule for second order samplers."),
|
||||
}))
|
||||
|
||||
options_templates.update(options_section(('interrogate', "Interrogate Options"), {
|
||||
"interrogate_keep_models_in_memory": OptionInfo(False, "Interrogate: keep models in VRAM"),
|
||||
"interrogate_return_ranks": OptionInfo(False, "Interrogate: include ranks of model tags matches in results (Has no effect on caption-based interrogators)."),
|
||||
"interrogate_clip_num_beams": OptionInfo(1, "Interrogate: num_beams for BLIP", gr.Slider, {"minimum": 1, "maximum": 16, "step": 1}),
|
||||
"interrogate_clip_min_length": OptionInfo(24, "Interrogate: minimum description length (excluding artists, etc..)", gr.Slider, {"minimum": 1, "maximum": 128, "step": 1}),
|
||||
"interrogate_clip_max_length": OptionInfo(48, "Interrogate: maximum description length", gr.Slider, {"minimum": 1, "maximum": 256, "step": 1}),
|
||||
"interrogate_clip_dict_limit": OptionInfo(1500, "CLIP: maximum number of lines in text file (0 = No limit)"),
|
||||
"interrogate_keep_models_in_memory": OptionInfo(False, "Keep models in VRAM"),
|
||||
"interrogate_return_ranks": OptionInfo(False, "Include ranks of model tags matches in results.").info("booru only"),
|
||||
"interrogate_clip_num_beams": OptionInfo(1, "BLIP: num_beams", gr.Slider, {"minimum": 1, "maximum": 16, "step": 1}),
|
||||
"interrogate_clip_min_length": OptionInfo(24, "BLIP: minimum description length", gr.Slider, {"minimum": 1, "maximum": 128, "step": 1}),
|
||||
"interrogate_clip_max_length": OptionInfo(48, "BLIP: maximum description length", gr.Slider, {"minimum": 1, "maximum": 256, "step": 1}),
|
||||
"interrogate_clip_dict_limit": OptionInfo(1500, "CLIP: maximum number of lines in text file").info("0 = No limit"),
|
||||
"interrogate_clip_skip_categories": OptionInfo([], "CLIP: skip inquire categories", gr.CheckboxGroup, lambda: {"choices": modules.interrogate.category_types()}, refresh=modules.interrogate.category_types),
|
||||
"interrogate_deepbooru_score_threshold": OptionInfo(0.5, "Interrogate: deepbooru score threshold", gr.Slider, {"minimum": 0, "maximum": 1, "step": 0.01}),
|
||||
"deepbooru_sort_alpha": OptionInfo(True, "Interrogate: deepbooru sort alphabetically"),
|
||||
"deepbooru_use_spaces": OptionInfo(False, "use spaces for tags in deepbooru"),
|
||||
"deepbooru_escape": OptionInfo(True, "escape (\\) brackets in deepbooru (so they are used as literal brackets and not for emphasis)"),
|
||||
"deepbooru_filter_tags": OptionInfo("", "filter out those tags from deepbooru output (separated by comma)"),
|
||||
"interrogate_deepbooru_score_threshold": OptionInfo(0.5, "deepbooru: score threshold", gr.Slider, {"minimum": 0, "maximum": 1, "step": 0.01}),
|
||||
"deepbooru_sort_alpha": OptionInfo(True, "deepbooru: sort tags alphabetically").info("if not: sort by score"),
|
||||
"deepbooru_use_spaces": OptionInfo(True, "deepbooru: use spaces in tags").info("if not: use underscores"),
|
||||
"deepbooru_escape": OptionInfo(True, "deepbooru: escape (\\) brackets").info("so they are used as literal brackets and not for emphasis"),
|
||||
"deepbooru_filter_tags": OptionInfo("", "deepbooru: filter out those tags").info("separate by comma"),
|
||||
}))
|
||||
|
||||
options_templates.update(options_section(('extra_networks', "Extra Networks"), {
|
||||
"extra_networks_show_hidden_directories": OptionInfo(True, "Show hidden directories").info("directory is hidden if its name starts with \".\"."),
|
||||
"extra_networks_hidden_models": OptionInfo("When searched", "Show cards for models in hidden directories", gr.Radio, {"choices": ["Always", "When searched", "Never"]}).info('"When searched" option will only show the item when the search string has 4 characters or more'),
|
||||
"extra_networks_default_view": OptionInfo("cards", "Default view for Extra Networks", gr.Dropdown, {"choices": ["cards", "thumbs"]}),
|
||||
"extra_networks_default_multiplier": OptionInfo(1.0, "Multiplier for extra networks", gr.Slider, {"minimum": 0.0, "maximum": 1.0, "step": 0.01}),
|
||||
"extra_networks_card_width": OptionInfo(0, "Card width for Extra Networks (px)"),
|
||||
"extra_networks_card_height": OptionInfo(0, "Card height for Extra Networks (px)"),
|
||||
"extra_networks_add_text_separator": OptionInfo(" ", "Extra text to add before <...> when adding extra network to prompt"),
|
||||
"sd_hypernetwork": OptionInfo("None", "Add hypernetwork to prompt", gr.Dropdown, lambda: {"choices": [""] + [x for x in hypernetworks.keys()]}, refresh=reload_hypernetworks),
|
||||
"extra_networks_card_width": OptionInfo(0, "Card width for Extra Networks").info("in pixels"),
|
||||
"extra_networks_card_height": OptionInfo(0, "Card height for Extra Networks").info("in pixels"),
|
||||
"extra_networks_add_text_separator": OptionInfo(" ", "Extra networks separator").info("extra text to add before <...> when adding extra network to prompt"),
|
||||
"sd_hypernetwork": OptionInfo("None", "Add hypernetwork to prompt", gr.Dropdown, lambda: {"choices": ["None", *hypernetworks]}, refresh=reload_hypernetworks),
|
||||
}))
|
||||
|
||||
options_templates.update(options_section(('ui', "User interface"), {
|
||||
"localization": OptionInfo("None", "Localization", gr.Dropdown, lambda: {"choices": ["None"] + list(localization.localizations.keys())}, refresh=lambda: localization.list_localizations(cmd_opts.localizations_dir)).needs_restart(),
|
||||
"gradio_theme": OptionInfo("Default", "Gradio theme", ui_components.DropdownEditable, lambda: {"choices": ["Default"] + gradio_hf_hub_themes}).needs_restart(),
|
||||
"return_grid": OptionInfo(True, "Show grid in results for web"),
|
||||
"return_mask": OptionInfo(False, "For inpainting, include the greyscale mask in results for web"),
|
||||
"return_mask_composite": OptionInfo(False, "For inpainting, include masked composite in results for web"),
|
||||
"do_not_show_images": OptionInfo(False, "Do not show any images in results for web"),
|
||||
"add_model_hash_to_info": OptionInfo(True, "Add model hash to generation information"),
|
||||
"add_model_name_to_info": OptionInfo(True, "Add model name to generation information"),
|
||||
"disable_weights_auto_swap": OptionInfo(True, "When reading generation parameters from text into UI (from PNG info or pasted text), do not change the selected model/checkpoint."),
|
||||
"send_seed": OptionInfo(True, "Send seed when sending prompt or image to other interface"),
|
||||
"send_size": OptionInfo(True, "Send size when sending prompt or image to another interface"),
|
||||
"font": OptionInfo("", "Font for image grids that have text"),
|
||||
"js_modal_lightbox": OptionInfo(True, "Enable full page image viewer"),
|
||||
"js_modal_lightbox_initially_zoomed": OptionInfo(True, "Show images zoomed in by default in full page image viewer"),
|
||||
"js_modal_lightbox_gamepad": OptionInfo(True, "Navigate image viewer with gamepad"),
|
||||
"js_modal_lightbox_gamepad_repeat": OptionInfo(250, "Gamepad repeat period, in milliseconds"),
|
||||
"show_progress_in_title": OptionInfo(True, "Show generation progress in window title."),
|
||||
"samplers_in_dropdown": OptionInfo(True, "Use dropdown for sampler selection instead of radio group"),
|
||||
"dimensions_and_batch_together": OptionInfo(True, "Show Width/Height and Batch sliders in same row"),
|
||||
"samplers_in_dropdown": OptionInfo(True, "Use dropdown for sampler selection instead of radio group").needs_restart(),
|
||||
"dimensions_and_batch_together": OptionInfo(True, "Show Width/Height and Batch sliders in same row").needs_restart(),
|
||||
"keyedit_precision_attention": OptionInfo(0.1, "Ctrl+up/down precision when editing (attention:1.1)", gr.Slider, {"minimum": 0.01, "maximum": 0.2, "step": 0.001}),
|
||||
"keyedit_precision_extra": OptionInfo(0.05, "Ctrl+up/down precision when editing <extra networks:0.9>", gr.Slider, {"minimum": 0.01, "maximum": 0.2, "step": 0.001}),
|
||||
"quicksettings": OptionInfo("sd_model_checkpoint", "Quicksettings list"),
|
||||
"hidden_tabs": OptionInfo([], "Hidden UI tabs (requires restart)", ui_components.DropdownMulti, lambda: {"choices": [x for x in tab_names]}),
|
||||
"keyedit_delimiters": OptionInfo(".,\\/!?%^*;:{}=`~()", "Ctrl+up/down word delimiters"),
|
||||
"quicksettings_list": OptionInfo(["sd_model_checkpoint"], "Quicksettings list", ui_components.DropdownMulti, lambda: {"choices": list(opts.data_labels.keys())}).js("info", "settingsHintsShowQuicksettings").info("setting entries that appear at the top of page rather than in settings tab").needs_restart(),
|
||||
"ui_tab_order": OptionInfo([], "UI tab order", ui_components.DropdownMulti, lambda: {"choices": list(tab_names)}).needs_restart(),
|
||||
"hidden_tabs": OptionInfo([], "Hidden UI tabs", ui_components.DropdownMulti, lambda: {"choices": list(tab_names)}).needs_restart(),
|
||||
"ui_reorder": OptionInfo(", ".join(ui_reorder_categories), "txt2img/img2img UI item order"),
|
||||
"ui_extra_networks_tab_reorder": OptionInfo("", "Extra networks tab order"),
|
||||
"localization": OptionInfo("None", "Localization (requires restart)", gr.Dropdown, lambda: {"choices": ["None"] + list(localization.localizations.keys())}, refresh=lambda: localization.list_localizations(cmd_opts.localizations_dir)),
|
||||
"ui_extra_networks_tab_reorder": OptionInfo("", "Extra networks tab order").needs_restart(),
|
||||
}))
|
||||
|
||||
options_templates.update(options_section(('infotext', "Infotext"), {
|
||||
"add_model_hash_to_info": OptionInfo(True, "Add model hash to generation information"),
|
||||
"add_model_name_to_info": OptionInfo(True, "Add model name to generation information"),
|
||||
"add_version_to_infotext": OptionInfo(True, "Add program version to generation information"),
|
||||
"disable_weights_auto_swap": OptionInfo(True, "When reading generation parameters from text into UI (from PNG info or pasted text), do not change the selected model/checkpoint."),
|
||||
}))
|
||||
|
||||
options_templates.update(options_section(('ui', "Live previews"), {
|
||||
"show_progressbar": OptionInfo(True, "Show progressbar"),
|
||||
"live_previews_enable": OptionInfo(True, "Show live previews of the created image"),
|
||||
"live_previews_image_format": OptionInfo("png", "Live preview file format", gr.Radio, {"choices": ["jpeg", "png", "webp"]}),
|
||||
"show_progress_grid": OptionInfo(True, "Show previews of all images generated in a batch as a grid"),
|
||||
"show_progress_every_n_steps": OptionInfo(10, "Show new live preview image every N sampling steps. Set to -1 to show after completion of batch.", gr.Slider, {"minimum": -1, "maximum": 32, "step": 1}),
|
||||
"show_progress_type": OptionInfo("Approx NN", "Image creation progress preview mode", gr.Radio, {"choices": ["Full", "Approx NN", "Approx cheap"]}),
|
||||
"show_progress_every_n_steps": OptionInfo(10, "Live preview display period", gr.Slider, {"minimum": -1, "maximum": 32, "step": 1}).info("in sampling steps - show new live preview image every N sampling steps; -1 = only show after completion of batch"),
|
||||
"show_progress_type": OptionInfo("Approx NN", "Live preview method", gr.Radio, {"choices": ["Full", "Approx NN", "Approx cheap", "TAESD"]}).info("Full = slow but pretty; Approx NN and TAESD = fast but low quality; Approx cheap = super fast but terrible otherwise"),
|
||||
"live_preview_content": OptionInfo("Prompt", "Live preview subject", gr.Radio, {"choices": ["Combined", "Prompt", "Negative prompt"]}),
|
||||
"live_preview_refresh_period": OptionInfo(1000, "Progressbar/preview update period, in milliseconds")
|
||||
"live_preview_refresh_period": OptionInfo(1000, "Progressbar and preview update period").info("in milliseconds"),
|
||||
}))
|
||||
|
||||
options_templates.update(options_section(('sampler-params', "Sampler parameters"), {
|
||||
"hide_samplers": OptionInfo([], "Hide samplers in user interface (requires restart)", gr.CheckboxGroup, lambda: {"choices": [x.name for x in list_samplers()]}),
|
||||
"eta_ddim": OptionInfo(0.0, "eta (noise multiplier) for DDIM", gr.Slider, {"minimum": 0.0, "maximum": 1.0, "step": 0.01}),
|
||||
"eta_ancestral": OptionInfo(1.0, "eta (noise multiplier) for ancestral samplers", gr.Slider, {"minimum": 0.0, "maximum": 1.0, "step": 0.01}),
|
||||
"hide_samplers": OptionInfo([], "Hide samplers in user interface", gr.CheckboxGroup, lambda: {"choices": [x.name for x in list_samplers()]}).needs_restart(),
|
||||
"eta_ddim": OptionInfo(0.0, "Eta for DDIM", gr.Slider, {"minimum": 0.0, "maximum": 1.0, "step": 0.01}).info("noise multiplier; higher = more unperdictable results"),
|
||||
"eta_ancestral": OptionInfo(1.0, "Eta for ancestral samplers", gr.Slider, {"minimum": 0.0, "maximum": 1.0, "step": 0.01}).info("noise multiplier; applies to Euler a and other samplers that have a in them"),
|
||||
"ddim_discretize": OptionInfo('uniform', "img2img DDIM discretize", gr.Radio, {"choices": ['uniform', 'quad']}),
|
||||
's_churn': OptionInfo(0.0, "sigma churn", gr.Slider, {"minimum": 0.0, "maximum": 1.0, "step": 0.01}),
|
||||
's_tmin': OptionInfo(0.0, "sigma tmin", gr.Slider, {"minimum": 0.0, "maximum": 1.0, "step": 0.01}),
|
||||
's_noise': OptionInfo(1.0, "sigma noise", gr.Slider, {"minimum": 0.0, "maximum": 1.0, "step": 0.01}),
|
||||
'eta_noise_seed_delta': OptionInfo(0, "Eta noise seed delta", gr.Number, {"precision": 0}),
|
||||
'always_discard_next_to_last_sigma': OptionInfo(False, "Always discard next-to-last sigma"),
|
||||
'eta_noise_seed_delta': OptionInfo(0, "Eta noise seed delta", gr.Number, {"precision": 0}).info("ENSD; does not improve anything, just produces different results for ancestral samplers - only useful for reproducing images"),
|
||||
'always_discard_next_to_last_sigma': OptionInfo(False, "Always discard next-to-last sigma").link("PR", "https://github.com/AUTOMATIC1111/stable-diffusion-webui/pull/6044"),
|
||||
'uni_pc_variant': OptionInfo("bh1", "UniPC variant", gr.Radio, {"choices": ["bh1", "bh2", "vary_coeff"]}),
|
||||
'uni_pc_skip_type': OptionInfo("time_uniform", "UniPC skip type", gr.Radio, {"choices": ["time_uniform", "time_quadratic", "logSNR"]}),
|
||||
'uni_pc_order': OptionInfo(3, "UniPC order (must be < sampling steps)", gr.Slider, {"minimum": 1, "maximum": 50, "step": 1}),
|
||||
'uni_pc_order': OptionInfo(3, "UniPC order", gr.Slider, {"minimum": 1, "maximum": 50, "step": 1}).info("must be < sampling steps"),
|
||||
'uni_pc_lower_order_final': OptionInfo(True, "UniPC lower order final"),
|
||||
}))
|
||||
|
||||
@@ -424,9 +527,11 @@ options_templates.update(options_section(('postprocessing', "Postprocessing"), {
|
||||
options_templates.update(options_section((None, "Hidden options"), {
|
||||
"disabled_extensions": OptionInfo([], "Disable these extensions"),
|
||||
"disable_all_extensions": OptionInfo("none", "Disable all extensions (preserves the list of disabled extensions)", gr.Radio, {"choices": ["none", "extra", "all"]}),
|
||||
"restore_config_state_file": OptionInfo("", "Config state file to restore from, under 'config-states/' folder"),
|
||||
"sd_checkpoint_hash": OptionInfo("", "SHA256 hash of the current checkpoint"),
|
||||
}))
|
||||
|
||||
|
||||
options_templates.update()
|
||||
|
||||
|
||||
@@ -516,6 +621,10 @@ class Options:
|
||||
with open(filename, "r", encoding="utf8") as file:
|
||||
self.data = json.load(file)
|
||||
|
||||
# 1.1.1 quicksettings list migration
|
||||
if self.data.get('quicksettings') is not None and self.data.get('quicksettings_list') is None:
|
||||
self.data['quicksettings_list'] = [i.strip() for i in self.data.get('quicksettings').split(',')]
|
||||
|
||||
bad_settings = 0
|
||||
for k, v in self.data.items():
|
||||
info = self.data_labels.get(k, None)
|
||||
@@ -534,7 +643,9 @@ class Options:
|
||||
func()
|
||||
|
||||
def dumpjson(self):
|
||||
d = {k: self.data.get(k, self.data_labels.get(k).default) for k in self.data_labels.keys()}
|
||||
d = {k: self.data.get(k, v.default) for k, v in self.data_labels.items()}
|
||||
d["_comments_before"] = {k: v.comment_before for k, v in self.data_labels.items() if v.comment_before is not None}
|
||||
d["_comments_after"] = {k: v.comment_after for k, v in self.data_labels.items() if v.comment_after is not None}
|
||||
return json.dumps(d)
|
||||
|
||||
def add_option(self, key, info):
|
||||
@@ -545,11 +656,11 @@ class Options:
|
||||
|
||||
section_ids = {}
|
||||
settings_items = self.data_labels.items()
|
||||
for k, item in settings_items:
|
||||
for _, item in settings_items:
|
||||
if item.section not in section_ids:
|
||||
section_ids[item.section] = len(section_ids)
|
||||
|
||||
self.data_labels = {k: v for k, v in sorted(settings_items, key=lambda x: section_ids[x[1].section])}
|
||||
self.data_labels = dict(sorted(settings_items, key=lambda x: section_ids[x[1].section]))
|
||||
|
||||
def cast_value(self, key, value):
|
||||
"""casts an arbitrary to the same type as this setting's value with key
|
||||
@@ -574,13 +685,37 @@ class Options:
|
||||
return value
|
||||
|
||||
|
||||
|
||||
opts = Options()
|
||||
if os.path.exists(config_filename):
|
||||
opts.load(config_filename)
|
||||
|
||||
|
||||
class Shared(sys.modules[__name__].__class__):
|
||||
"""
|
||||
this class is here to provide sd_model field as a property, so that it can be created and loaded on demand rather than
|
||||
at program startup.
|
||||
"""
|
||||
|
||||
sd_model_val = None
|
||||
|
||||
@property
|
||||
def sd_model(self):
|
||||
import modules.sd_models
|
||||
|
||||
return modules.sd_models.model_data.get_sd_model()
|
||||
|
||||
@sd_model.setter
|
||||
def sd_model(self, value):
|
||||
import modules.sd_models
|
||||
|
||||
modules.sd_models.model_data.set_sd_model(value)
|
||||
|
||||
|
||||
sd_model: LatentDiffusion = None # this var is here just for IDE's type checking; it cannot be accessed because the class field above will be accessed instead
|
||||
sys.modules[__name__].__class__ = Shared
|
||||
|
||||
settings_components = None
|
||||
"""assinged from ui.py, a mapping on setting anmes to gradio components repsponsible for those settings"""
|
||||
"""assinged from ui.py, a mapping on setting names to gradio components repsponsible for those settings"""
|
||||
|
||||
latent_upscale_default_mode = "Latent"
|
||||
latent_upscale_modes = {
|
||||
@@ -594,12 +729,33 @@ latent_upscale_modes = {
|
||||
|
||||
sd_upscalers = []
|
||||
|
||||
sd_model = None
|
||||
|
||||
clip_model = None
|
||||
|
||||
progress_print_out = sys.stdout
|
||||
|
||||
gradio_theme = gr.themes.Base()
|
||||
|
||||
|
||||
def reload_gradio_theme(theme_name=None):
|
||||
global gradio_theme
|
||||
if not theme_name:
|
||||
theme_name = opts.gradio_theme
|
||||
|
||||
default_theme_args = dict(
|
||||
font=["Source Sans Pro", 'ui-sans-serif', 'system-ui', 'sans-serif'],
|
||||
font_mono=['IBM Plex Mono', 'ui-monospace', 'Consolas', 'monospace'],
|
||||
)
|
||||
|
||||
if theme_name == "Default":
|
||||
gradio_theme = gr.themes.Default(**default_theme_args)
|
||||
else:
|
||||
try:
|
||||
gradio_theme = gr.themes.ThemeClass.from_hub(theme_name)
|
||||
except Exception as e:
|
||||
errors.display(e, "changing gradio theme")
|
||||
gradio_theme = gr.themes.Default(**default_theme_args)
|
||||
|
||||
|
||||
|
||||
class TotalTQDM:
|
||||
def __init__(self):
|
||||
@@ -657,3 +813,23 @@ def html(filename):
|
||||
return file.read()
|
||||
|
||||
return ""
|
||||
|
||||
|
||||
def walk_files(path, allowed_extensions=None):
|
||||
if not os.path.exists(path):
|
||||
return
|
||||
|
||||
if allowed_extensions is not None:
|
||||
allowed_extensions = set(allowed_extensions)
|
||||
|
||||
for root, _, files in os.walk(path, followlinks=True):
|
||||
for filename in files:
|
||||
if allowed_extensions is not None:
|
||||
_, ext = os.path.splitext(filename)
|
||||
if ext not in allowed_extensions:
|
||||
continue
|
||||
|
||||
if not opts.list_hidden_files and ("/." in root or "\\." in root):
|
||||
continue
|
||||
|
||||
yield os.path.join(root, filename)
|
||||
|
||||
@@ -1,18 +1,9 @@
|
||||
# We need this so Python doesn't complain about the unknown StableDiffusionProcessing-typehint at runtime
|
||||
from __future__ import annotations
|
||||
|
||||
import csv
|
||||
import os
|
||||
import os.path
|
||||
import typing
|
||||
import collections.abc as abc
|
||||
import tempfile
|
||||
import shutil
|
||||
|
||||
if typing.TYPE_CHECKING:
|
||||
# Only import this when code is being type-checked, it doesn't have any effect at runtime
|
||||
from .processing import StableDiffusionProcessing
|
||||
|
||||
|
||||
class PromptStyle(typing.NamedTuple):
|
||||
name: str
|
||||
@@ -52,7 +43,7 @@ class StyleDatabase:
|
||||
return
|
||||
|
||||
with open(self.path, "r", encoding="utf-8-sig", newline='') as file:
|
||||
reader = csv.DictReader(file)
|
||||
reader = csv.DictReader(file, skipinitialspace=True)
|
||||
for row in reader:
|
||||
# Support loading old CSV format with "name, text"-columns
|
||||
prompt = row["prompt"] if "prompt" in row else row["text"]
|
||||
@@ -72,16 +63,14 @@ class StyleDatabase:
|
||||
return apply_styles_to_prompt(prompt, [self.styles.get(x, self.no_style).negative_prompt for x in styles])
|
||||
|
||||
def save_styles(self, path: str) -> None:
|
||||
# Write to temporary file first, so we don't nuke the file if something goes wrong
|
||||
fd, temp_path = tempfile.mkstemp(".csv")
|
||||
# Always keep a backup file around
|
||||
if os.path.exists(path):
|
||||
shutil.copy(path, f"{path}.bak")
|
||||
|
||||
fd = os.open(path, os.O_RDWR|os.O_CREAT)
|
||||
with os.fdopen(fd, "w", encoding="utf-8-sig", newline='') as file:
|
||||
# _fields is actually part of the public API: typing.NamedTuple is a replacement for collections.NamedTuple,
|
||||
# and collections.NamedTuple has explicit documentation for accessing _fields. Same goes for _asdict()
|
||||
writer = csv.DictWriter(file, fieldnames=PromptStyle._fields)
|
||||
writer.writeheader()
|
||||
writer.writerows(style._asdict() for k, style in self.styles.items())
|
||||
|
||||
# Always keep a backup file around
|
||||
if os.path.exists(path):
|
||||
shutil.move(path, path + ".bak")
|
||||
shutil.move(temp_path, path)
|
||||
|
||||
@@ -179,7 +179,7 @@ def efficient_dot_product_attention(
|
||||
chunk_idx,
|
||||
min(query_chunk_size, q_tokens)
|
||||
)
|
||||
|
||||
|
||||
summarize_chunk: SummarizeChunk = partial(_summarize_chunk, scale=scale)
|
||||
summarize_chunk: SummarizeChunk = partial(checkpoint, summarize_chunk) if use_checkpoint else summarize_chunk
|
||||
compute_query_chunk_attn: ComputeQueryChunkAttn = partial(
|
||||
@@ -201,14 +201,15 @@ def efficient_dot_product_attention(
|
||||
key=key,
|
||||
value=value,
|
||||
)
|
||||
|
||||
# TODO: maybe we should use torch.empty_like(query) to allocate storage in-advance,
|
||||
# and pass slices to be mutated, instead of torch.cat()ing the returned slices
|
||||
res = torch.cat([
|
||||
compute_query_chunk_attn(
|
||||
|
||||
res = torch.zeros_like(query)
|
||||
for i in range(math.ceil(q_tokens / query_chunk_size)):
|
||||
attn_scores = compute_query_chunk_attn(
|
||||
query=get_query_chunk(i * query_chunk_size),
|
||||
key=key,
|
||||
value=value,
|
||||
) for i in range(math.ceil(q_tokens / query_chunk_size))
|
||||
], dim=1)
|
||||
)
|
||||
|
||||
res[:, i * query_chunk_size:i * query_chunk_size + attn_scores.shape[1], :] = attn_scores
|
||||
|
||||
return res
|
||||
|
||||
@@ -1,10 +1,8 @@
|
||||
import cv2
|
||||
import requests
|
||||
import os
|
||||
from collections import defaultdict
|
||||
from math import log, sqrt
|
||||
import numpy as np
|
||||
from PIL import Image, ImageDraw
|
||||
from PIL import ImageDraw
|
||||
|
||||
GREEN = "#0F0"
|
||||
BLUE = "#00F"
|
||||
@@ -12,63 +10,64 @@ RED = "#F00"
|
||||
|
||||
|
||||
def crop_image(im, settings):
|
||||
""" Intelligently crop an image to the subject matter """
|
||||
""" Intelligently crop an image to the subject matter """
|
||||
|
||||
scale_by = 1
|
||||
if is_landscape(im.width, im.height):
|
||||
scale_by = settings.crop_height / im.height
|
||||
elif is_portrait(im.width, im.height):
|
||||
scale_by = settings.crop_width / im.width
|
||||
elif is_square(im.width, im.height):
|
||||
if is_square(settings.crop_width, settings.crop_height):
|
||||
scale_by = settings.crop_width / im.width
|
||||
elif is_landscape(settings.crop_width, settings.crop_height):
|
||||
scale_by = settings.crop_width / im.width
|
||||
elif is_portrait(settings.crop_width, settings.crop_height):
|
||||
scale_by = settings.crop_height / im.height
|
||||
scale_by = 1
|
||||
if is_landscape(im.width, im.height):
|
||||
scale_by = settings.crop_height / im.height
|
||||
elif is_portrait(im.width, im.height):
|
||||
scale_by = settings.crop_width / im.width
|
||||
elif is_square(im.width, im.height):
|
||||
if is_square(settings.crop_width, settings.crop_height):
|
||||
scale_by = settings.crop_width / im.width
|
||||
elif is_landscape(settings.crop_width, settings.crop_height):
|
||||
scale_by = settings.crop_width / im.width
|
||||
elif is_portrait(settings.crop_width, settings.crop_height):
|
||||
scale_by = settings.crop_height / im.height
|
||||
|
||||
im = im.resize((int(im.width * scale_by), int(im.height * scale_by)))
|
||||
im_debug = im.copy()
|
||||
|
||||
focus = focal_point(im_debug, settings)
|
||||
im = im.resize((int(im.width * scale_by), int(im.height * scale_by)))
|
||||
im_debug = im.copy()
|
||||
|
||||
# take the focal point and turn it into crop coordinates that try to center over the focal
|
||||
# point but then get adjusted back into the frame
|
||||
y_half = int(settings.crop_height / 2)
|
||||
x_half = int(settings.crop_width / 2)
|
||||
focus = focal_point(im_debug, settings)
|
||||
|
||||
x1 = focus.x - x_half
|
||||
if x1 < 0:
|
||||
x1 = 0
|
||||
elif x1 + settings.crop_width > im.width:
|
||||
x1 = im.width - settings.crop_width
|
||||
# take the focal point and turn it into crop coordinates that try to center over the focal
|
||||
# point but then get adjusted back into the frame
|
||||
y_half = int(settings.crop_height / 2)
|
||||
x_half = int(settings.crop_width / 2)
|
||||
|
||||
y1 = focus.y - y_half
|
||||
if y1 < 0:
|
||||
y1 = 0
|
||||
elif y1 + settings.crop_height > im.height:
|
||||
y1 = im.height - settings.crop_height
|
||||
x1 = focus.x - x_half
|
||||
if x1 < 0:
|
||||
x1 = 0
|
||||
elif x1 + settings.crop_width > im.width:
|
||||
x1 = im.width - settings.crop_width
|
||||
|
||||
x2 = x1 + settings.crop_width
|
||||
y2 = y1 + settings.crop_height
|
||||
y1 = focus.y - y_half
|
||||
if y1 < 0:
|
||||
y1 = 0
|
||||
elif y1 + settings.crop_height > im.height:
|
||||
y1 = im.height - settings.crop_height
|
||||
|
||||
crop = [x1, y1, x2, y2]
|
||||
x2 = x1 + settings.crop_width
|
||||
y2 = y1 + settings.crop_height
|
||||
|
||||
results = []
|
||||
crop = [x1, y1, x2, y2]
|
||||
|
||||
results.append(im.crop(tuple(crop)))
|
||||
results = []
|
||||
|
||||
if settings.annotate_image:
|
||||
d = ImageDraw.Draw(im_debug)
|
||||
rect = list(crop)
|
||||
rect[2] -= 1
|
||||
rect[3] -= 1
|
||||
d.rectangle(rect, outline=GREEN)
|
||||
results.append(im_debug)
|
||||
if settings.destop_view_image:
|
||||
im_debug.show()
|
||||
results.append(im.crop(tuple(crop)))
|
||||
|
||||
return results
|
||||
if settings.annotate_image:
|
||||
d = ImageDraw.Draw(im_debug)
|
||||
rect = list(crop)
|
||||
rect[2] -= 1
|
||||
rect[3] -= 1
|
||||
d.rectangle(rect, outline=GREEN)
|
||||
results.append(im_debug)
|
||||
if settings.destop_view_image:
|
||||
im_debug.show()
|
||||
|
||||
return results
|
||||
|
||||
def focal_point(im, settings):
|
||||
corner_points = image_corner_points(im, settings) if settings.corner_points_weight > 0 else []
|
||||
@@ -88,7 +87,7 @@ def focal_point(im, settings):
|
||||
corner_centroid = None
|
||||
if len(corner_points) > 0:
|
||||
corner_centroid = centroid(corner_points)
|
||||
corner_centroid.weight = settings.corner_points_weight / weight_pref_total
|
||||
corner_centroid.weight = settings.corner_points_weight / weight_pref_total
|
||||
pois.append(corner_centroid)
|
||||
|
||||
entropy_centroid = None
|
||||
@@ -100,7 +99,7 @@ def focal_point(im, settings):
|
||||
face_centroid = None
|
||||
if len(face_points) > 0:
|
||||
face_centroid = centroid(face_points)
|
||||
face_centroid.weight = settings.face_points_weight / weight_pref_total
|
||||
face_centroid.weight = settings.face_points_weight / weight_pref_total
|
||||
pois.append(face_centroid)
|
||||
|
||||
average_point = poi_average(pois, settings)
|
||||
@@ -111,7 +110,7 @@ def focal_point(im, settings):
|
||||
if corner_centroid is not None:
|
||||
color = BLUE
|
||||
box = corner_centroid.bounding(max_size * corner_centroid.weight)
|
||||
d.text((box[0], box[1]-15), "Edge: %.02f" % corner_centroid.weight, fill=color)
|
||||
d.text((box[0], box[1]-15), f"Edge: {corner_centroid.weight:.02f}", fill=color)
|
||||
d.ellipse(box, outline=color)
|
||||
if len(corner_points) > 1:
|
||||
for f in corner_points:
|
||||
@@ -119,7 +118,7 @@ def focal_point(im, settings):
|
||||
if entropy_centroid is not None:
|
||||
color = "#ff0"
|
||||
box = entropy_centroid.bounding(max_size * entropy_centroid.weight)
|
||||
d.text((box[0], box[1]-15), "Entropy: %.02f" % entropy_centroid.weight, fill=color)
|
||||
d.text((box[0], box[1]-15), f"Entropy: {entropy_centroid.weight:.02f}", fill=color)
|
||||
d.ellipse(box, outline=color)
|
||||
if len(entropy_points) > 1:
|
||||
for f in entropy_points:
|
||||
@@ -127,14 +126,14 @@ def focal_point(im, settings):
|
||||
if face_centroid is not None:
|
||||
color = RED
|
||||
box = face_centroid.bounding(max_size * face_centroid.weight)
|
||||
d.text((box[0], box[1]-15), "Face: %.02f" % face_centroid.weight, fill=color)
|
||||
d.text((box[0], box[1]-15), f"Face: {face_centroid.weight:.02f}", fill=color)
|
||||
d.ellipse(box, outline=color)
|
||||
if len(face_points) > 1:
|
||||
for f in face_points:
|
||||
d.rectangle(f.bounding(4), outline=color)
|
||||
|
||||
d.ellipse(average_point.bounding(max_size), outline=GREEN)
|
||||
|
||||
|
||||
return average_point
|
||||
|
||||
|
||||
@@ -185,7 +184,7 @@ def image_face_points(im, settings):
|
||||
try:
|
||||
faces = classifier.detectMultiScale(gray, scaleFactor=1.1,
|
||||
minNeighbors=7, minSize=(minsize, minsize), flags=cv2.CASCADE_SCALE_IMAGE)
|
||||
except:
|
||||
except Exception:
|
||||
continue
|
||||
|
||||
if len(faces) > 0:
|
||||
@@ -262,10 +261,11 @@ def image_entropy(im):
|
||||
hist = hist[hist > 0]
|
||||
return -np.log2(hist / hist.sum()).sum()
|
||||
|
||||
|
||||
def centroid(pois):
|
||||
x = [poi.x for poi in pois]
|
||||
y = [poi.y for poi in pois]
|
||||
return PointOfInterest(sum(x)/len(pois), sum(y)/len(pois))
|
||||
x = [poi.x for poi in pois]
|
||||
y = [poi.y for poi in pois]
|
||||
return PointOfInterest(sum(x) / len(pois), sum(y) / len(pois))
|
||||
|
||||
|
||||
def poi_average(pois, settings):
|
||||
@@ -283,59 +283,59 @@ def poi_average(pois, settings):
|
||||
|
||||
|
||||
def is_landscape(w, h):
|
||||
return w > h
|
||||
return w > h
|
||||
|
||||
|
||||
def is_portrait(w, h):
|
||||
return h > w
|
||||
return h > w
|
||||
|
||||
|
||||
def is_square(w, h):
|
||||
return w == h
|
||||
return w == h
|
||||
|
||||
|
||||
def download_and_cache_models(dirname):
|
||||
download_url = 'https://github.com/opencv/opencv_zoo/blob/91fb0290f50896f38a0ab1e558b74b16bc009428/models/face_detection_yunet/face_detection_yunet_2022mar.onnx?raw=true'
|
||||
model_file_name = 'face_detection_yunet.onnx'
|
||||
download_url = 'https://github.com/opencv/opencv_zoo/blob/91fb0290f50896f38a0ab1e558b74b16bc009428/models/face_detection_yunet/face_detection_yunet_2022mar.onnx?raw=true'
|
||||
model_file_name = 'face_detection_yunet.onnx'
|
||||
|
||||
if not os.path.exists(dirname):
|
||||
os.makedirs(dirname)
|
||||
if not os.path.exists(dirname):
|
||||
os.makedirs(dirname)
|
||||
|
||||
cache_file = os.path.join(dirname, model_file_name)
|
||||
if not os.path.exists(cache_file):
|
||||
print(f"downloading face detection model from '{download_url}' to '{cache_file}'")
|
||||
response = requests.get(download_url)
|
||||
with open(cache_file, "wb") as f:
|
||||
f.write(response.content)
|
||||
cache_file = os.path.join(dirname, model_file_name)
|
||||
if not os.path.exists(cache_file):
|
||||
print(f"downloading face detection model from '{download_url}' to '{cache_file}'")
|
||||
response = requests.get(download_url)
|
||||
with open(cache_file, "wb") as f:
|
||||
f.write(response.content)
|
||||
|
||||
if os.path.exists(cache_file):
|
||||
return cache_file
|
||||
return None
|
||||
if os.path.exists(cache_file):
|
||||
return cache_file
|
||||
return None
|
||||
|
||||
|
||||
class PointOfInterest:
|
||||
def __init__(self, x, y, weight=1.0, size=10):
|
||||
self.x = x
|
||||
self.y = y
|
||||
self.weight = weight
|
||||
self.size = size
|
||||
def __init__(self, x, y, weight=1.0, size=10):
|
||||
self.x = x
|
||||
self.y = y
|
||||
self.weight = weight
|
||||
self.size = size
|
||||
|
||||
def bounding(self, size):
|
||||
return [
|
||||
self.x - size//2,
|
||||
self.y - size//2,
|
||||
self.x + size//2,
|
||||
self.y + size//2
|
||||
]
|
||||
def bounding(self, size):
|
||||
return [
|
||||
self.x - size // 2,
|
||||
self.y - size // 2,
|
||||
self.x + size // 2,
|
||||
self.y + size // 2
|
||||
]
|
||||
|
||||
|
||||
class Settings:
|
||||
def __init__(self, crop_width=512, crop_height=512, corner_points_weight=0.5, entropy_points_weight=0.5, face_points_weight=0.5, annotate_image=False, dnn_model_path=None):
|
||||
self.crop_width = crop_width
|
||||
self.crop_height = crop_height
|
||||
self.corner_points_weight = corner_points_weight
|
||||
self.entropy_points_weight = entropy_points_weight
|
||||
self.face_points_weight = face_points_weight
|
||||
self.annotate_image = annotate_image
|
||||
self.destop_view_image = False
|
||||
self.dnn_model_path = dnn_model_path
|
||||
def __init__(self, crop_width=512, crop_height=512, corner_points_weight=0.5, entropy_points_weight=0.5, face_points_weight=0.5, annotate_image=False, dnn_model_path=None):
|
||||
self.crop_width = crop_width
|
||||
self.crop_height = crop_height
|
||||
self.corner_points_weight = corner_points_weight
|
||||
self.entropy_points_weight = entropy_points_weight
|
||||
self.face_points_weight = face_points_weight
|
||||
self.annotate_image = annotate_image
|
||||
self.destop_view_image = False
|
||||
self.dnn_model_path = dnn_model_path
|
||||
|
||||
@@ -72,7 +72,7 @@ class PersonalizedBase(Dataset):
|
||||
except Exception:
|
||||
continue
|
||||
|
||||
text_filename = os.path.splitext(path)[0] + ".txt"
|
||||
text_filename = f"{os.path.splitext(path)[0]}.txt"
|
||||
filename = os.path.basename(path)
|
||||
|
||||
if os.path.exists(text_filename):
|
||||
@@ -118,7 +118,7 @@ class PersonalizedBase(Dataset):
|
||||
weight = torch.ones(latent_sample.shape)
|
||||
else:
|
||||
weight = None
|
||||
|
||||
|
||||
if latent_sampling_method == "random":
|
||||
entry = DatasetEntry(filename=path, filename_text=filename_text, latent_dist=latent_dist, weight=weight)
|
||||
else:
|
||||
@@ -243,4 +243,4 @@ class BatchLoaderRandom(BatchLoader):
|
||||
return self
|
||||
|
||||
def collate_wrapper_random(batch):
|
||||
return BatchLoaderRandom(batch)
|
||||
return BatchLoaderRandom(batch)
|
||||
|
||||
@@ -2,10 +2,8 @@ import base64
|
||||
import json
|
||||
import numpy as np
|
||||
import zlib
|
||||
from PIL import Image, PngImagePlugin, ImageDraw, ImageFont
|
||||
from fonts.ttf import Roboto
|
||||
from PIL import Image, ImageDraw, ImageFont
|
||||
import torch
|
||||
from modules.shared import opts
|
||||
|
||||
|
||||
class EmbeddingEncoder(json.JSONEncoder):
|
||||
@@ -17,7 +15,7 @@ class EmbeddingEncoder(json.JSONEncoder):
|
||||
|
||||
class EmbeddingDecoder(json.JSONDecoder):
|
||||
def __init__(self, *args, **kwargs):
|
||||
json.JSONDecoder.__init__(self, object_hook=self.object_hook, *args, **kwargs)
|
||||
json.JSONDecoder.__init__(self, *args, object_hook=self.object_hook, **kwargs)
|
||||
|
||||
def object_hook(self, d):
|
||||
if 'TORCHTENSOR' in d:
|
||||
@@ -136,11 +134,8 @@ def caption_image_overlay(srcimage, title, footerLeft, footerMid, footerRight, t
|
||||
image = srcimage.copy()
|
||||
fontsize = 32
|
||||
if textfont is None:
|
||||
try:
|
||||
textfont = ImageFont.truetype(opts.font or Roboto, fontsize)
|
||||
textfont = opts.font or Roboto
|
||||
except Exception:
|
||||
textfont = Roboto
|
||||
from modules.images import get_font
|
||||
textfont = get_font(fontsize)
|
||||
|
||||
factor = 1.5
|
||||
gradient = Image.new('RGBA', (1, image.size[1]), color=(0, 0, 0, 0))
|
||||
|
||||
@@ -12,7 +12,7 @@ class LearnScheduleIterator:
|
||||
self.it = 0
|
||||
self.maxit = 0
|
||||
try:
|
||||
for i, pair in enumerate(pairs):
|
||||
for pair in pairs:
|
||||
if not pair.strip():
|
||||
continue
|
||||
tmp = pair.split(':')
|
||||
@@ -32,8 +32,8 @@ class LearnScheduleIterator:
|
||||
self.maxit += 1
|
||||
return
|
||||
assert self.rates
|
||||
except (ValueError, AssertionError):
|
||||
raise Exception('Invalid learning rate schedule. It should be a number or, for example, like "0.001:100, 0.00001:1000, 1e-5:10000" to have lr of 0.001 until step 100, 0.00001 until 1000, and 1e-5 until 10000.')
|
||||
except (ValueError, AssertionError) as e:
|
||||
raise Exception('Invalid learning rate schedule. It should be a number or, for example, like "0.001:100, 0.00001:1000, 1e-5:10000" to have lr of 0.001 until step 100, 0.00001 until 1000, and 1e-5 until 10000.') from e
|
||||
|
||||
|
||||
def __iter__(self):
|
||||
|
||||
@@ -1,17 +1,13 @@
|
||||
import os
|
||||
from PIL import Image, ImageOps
|
||||
import math
|
||||
import platform
|
||||
import sys
|
||||
import tqdm
|
||||
import time
|
||||
|
||||
from modules import paths, shared, images, deepbooru
|
||||
from modules.shared import opts, cmd_opts
|
||||
from modules.textual_inversion import autocrop
|
||||
|
||||
|
||||
def preprocess(id_task, process_src, process_dst, process_width, process_height, preprocess_txt_action, process_flip, process_split, process_caption, process_caption_deepbooru=False, split_threshold=0.5, overlap_ratio=0.2, process_focal_crop=False, process_focal_crop_face_weight=0.9, process_focal_crop_entropy_weight=0.3, process_focal_crop_edges_weight=0.5, process_focal_crop_debug=False, process_multicrop=None, process_multicrop_mindim=None, process_multicrop_maxdim=None, process_multicrop_minarea=None, process_multicrop_maxarea=None, process_multicrop_objective=None, process_multicrop_threshold=None):
|
||||
def preprocess(id_task, process_src, process_dst, process_width, process_height, preprocess_txt_action, process_keep_original_size, process_flip, process_split, process_caption, process_caption_deepbooru=False, split_threshold=0.5, overlap_ratio=0.2, process_focal_crop=False, process_focal_crop_face_weight=0.9, process_focal_crop_entropy_weight=0.3, process_focal_crop_edges_weight=0.5, process_focal_crop_debug=False, process_multicrop=None, process_multicrop_mindim=None, process_multicrop_maxdim=None, process_multicrop_minarea=None, process_multicrop_maxarea=None, process_multicrop_objective=None, process_multicrop_threshold=None):
|
||||
try:
|
||||
if process_caption:
|
||||
shared.interrogator.load()
|
||||
@@ -19,7 +15,7 @@ def preprocess(id_task, process_src, process_dst, process_width, process_height,
|
||||
if process_caption_deepbooru:
|
||||
deepbooru.model.start()
|
||||
|
||||
preprocess_work(process_src, process_dst, process_width, process_height, preprocess_txt_action, process_flip, process_split, process_caption, process_caption_deepbooru, split_threshold, overlap_ratio, process_focal_crop, process_focal_crop_face_weight, process_focal_crop_entropy_weight, process_focal_crop_edges_weight, process_focal_crop_debug, process_multicrop, process_multicrop_mindim, process_multicrop_maxdim, process_multicrop_minarea, process_multicrop_maxarea, process_multicrop_objective, process_multicrop_threshold)
|
||||
preprocess_work(process_src, process_dst, process_width, process_height, preprocess_txt_action, process_keep_original_size, process_flip, process_split, process_caption, process_caption_deepbooru, split_threshold, overlap_ratio, process_focal_crop, process_focal_crop_face_weight, process_focal_crop_entropy_weight, process_focal_crop_edges_weight, process_focal_crop_debug, process_multicrop, process_multicrop_mindim, process_multicrop_maxdim, process_multicrop_minarea, process_multicrop_maxarea, process_multicrop_objective, process_multicrop_threshold)
|
||||
|
||||
finally:
|
||||
|
||||
@@ -63,9 +59,9 @@ def save_pic_with_caption(image, index, params: PreprocessParams, existing_capti
|
||||
image.save(os.path.join(params.dstdir, f"{basename}.png"))
|
||||
|
||||
if params.preprocess_txt_action == 'prepend' and existing_caption:
|
||||
caption = existing_caption + ' ' + caption
|
||||
caption = f"{existing_caption} {caption}"
|
||||
elif params.preprocess_txt_action == 'append' and existing_caption:
|
||||
caption = caption + ' ' + existing_caption
|
||||
caption = f"{caption} {existing_caption}"
|
||||
elif params.preprocess_txt_action == 'copy' and existing_caption:
|
||||
caption = existing_caption
|
||||
|
||||
@@ -129,9 +125,9 @@ def multicrop_pic(image: Image, mindim, maxdim, minarea, maxarea, objective, thr
|
||||
default=None
|
||||
)
|
||||
return wh and center_crop(image, *wh)
|
||||
|
||||
|
||||
def preprocess_work(process_src, process_dst, process_width, process_height, preprocess_txt_action, process_flip, process_split, process_caption, process_caption_deepbooru=False, split_threshold=0.5, overlap_ratio=0.2, process_focal_crop=False, process_focal_crop_face_weight=0.9, process_focal_crop_entropy_weight=0.3, process_focal_crop_edges_weight=0.5, process_focal_crop_debug=False, process_multicrop=None, process_multicrop_mindim=None, process_multicrop_maxdim=None, process_multicrop_minarea=None, process_multicrop_maxarea=None, process_multicrop_objective=None, process_multicrop_threshold=None):
|
||||
|
||||
def preprocess_work(process_src, process_dst, process_width, process_height, preprocess_txt_action, process_keep_original_size, process_flip, process_split, process_caption, process_caption_deepbooru=False, split_threshold=0.5, overlap_ratio=0.2, process_focal_crop=False, process_focal_crop_face_weight=0.9, process_focal_crop_entropy_weight=0.3, process_focal_crop_edges_weight=0.5, process_focal_crop_debug=False, process_multicrop=None, process_multicrop_mindim=None, process_multicrop_maxdim=None, process_multicrop_minarea=None, process_multicrop_maxarea=None, process_multicrop_objective=None, process_multicrop_threshold=None):
|
||||
width = process_width
|
||||
height = process_height
|
||||
src = os.path.abspath(process_src)
|
||||
@@ -161,7 +157,9 @@ def preprocess_work(process_src, process_dst, process_width, process_height, pre
|
||||
params.subindex = 0
|
||||
filename = os.path.join(src, imagefile)
|
||||
try:
|
||||
img = Image.open(filename).convert("RGB")
|
||||
img = Image.open(filename)
|
||||
img = ImageOps.exif_transpose(img)
|
||||
img = img.convert("RGB")
|
||||
except Exception:
|
||||
continue
|
||||
|
||||
@@ -172,7 +170,7 @@ def preprocess_work(process_src, process_dst, process_width, process_height, pre
|
||||
params.src = filename
|
||||
|
||||
existing_caption = None
|
||||
existing_caption_filename = os.path.splitext(filename)[0] + '.txt'
|
||||
existing_caption_filename = f"{os.path.splitext(filename)[0]}.txt"
|
||||
if os.path.exists(existing_caption_filename):
|
||||
with open(existing_caption_filename, 'r', encoding="utf8") as file:
|
||||
existing_caption = file.read()
|
||||
@@ -223,6 +221,10 @@ def preprocess_work(process_src, process_dst, process_width, process_height, pre
|
||||
print(f"skipped {img.width}x{img.height} image {filename} (can't find suitable size within error threshold)")
|
||||
process_default_resize = False
|
||||
|
||||
if process_keep_original_size:
|
||||
save_pic(img, index, params, existing_caption=existing_caption)
|
||||
process_default_resize = False
|
||||
|
||||
if process_default_resize:
|
||||
img = images.resize_image(1, img, width, height)
|
||||
save_pic(img, index, params, existing_caption=existing_caption)
|
||||
|
||||
@@ -1,7 +1,6 @@
|
||||
import os
|
||||
import sys
|
||||
import traceback
|
||||
import inspect
|
||||
from collections import namedtuple
|
||||
|
||||
import torch
|
||||
@@ -30,7 +29,7 @@ textual_inversion_templates = {}
|
||||
def list_textual_inversion_templates():
|
||||
textual_inversion_templates.clear()
|
||||
|
||||
for root, dirs, fns in os.walk(shared.cmd_opts.textual_inversion_templates_dir):
|
||||
for root, _, fns in os.walk(shared.cmd_opts.textual_inversion_templates_dir):
|
||||
for fn in fns:
|
||||
path = os.path.join(root, fn)
|
||||
|
||||
@@ -69,7 +68,7 @@ class Embedding:
|
||||
'hash': self.checksum(),
|
||||
'optimizer_state_dict': self.optimizer_state_dict,
|
||||
}
|
||||
torch.save(optimizer_saved_dict, filename + '.optim')
|
||||
torch.save(optimizer_saved_dict, f"{filename}.optim")
|
||||
|
||||
def checksum(self):
|
||||
if self.cached_checksum is not None:
|
||||
@@ -167,8 +166,7 @@ class EmbeddingDatabase:
|
||||
# textual inversion embeddings
|
||||
if 'string_to_param' in data:
|
||||
param_dict = data['string_to_param']
|
||||
if hasattr(param_dict, '_parameters'):
|
||||
param_dict = getattr(param_dict, '_parameters') # fix for torch 1.12.1 loading saved file from torch 1.11
|
||||
param_dict = getattr(param_dict, '_parameters', param_dict) # fix for torch 1.12.1 loading saved file from torch 1.11
|
||||
assert len(param_dict) == 1, 'embedding file has multiple terms in it'
|
||||
emb = next(iter(param_dict.items()))[1]
|
||||
# diffuser concepts
|
||||
@@ -199,7 +197,7 @@ class EmbeddingDatabase:
|
||||
if not os.path.isdir(embdir.path):
|
||||
return
|
||||
|
||||
for root, dirs, fns in os.walk(embdir.path, followlinks=True):
|
||||
for root, _, fns in os.walk(embdir.path, followlinks=True):
|
||||
for fn in fns:
|
||||
try:
|
||||
fullfn = os.path.join(root, fn)
|
||||
@@ -216,7 +214,7 @@ class EmbeddingDatabase:
|
||||
def load_textual_inversion_embeddings(self, force_reload=False):
|
||||
if not force_reload:
|
||||
need_reload = False
|
||||
for path, embdir in self.embedding_dirs.items():
|
||||
for embdir in self.embedding_dirs.values():
|
||||
if embdir.has_changed():
|
||||
need_reload = True
|
||||
break
|
||||
@@ -229,10 +227,16 @@ class EmbeddingDatabase:
|
||||
self.skipped_embeddings.clear()
|
||||
self.expected_shape = self.get_expected_shape()
|
||||
|
||||
for path, embdir in self.embedding_dirs.items():
|
||||
for embdir in self.embedding_dirs.values():
|
||||
self.load_from_dir(embdir)
|
||||
embdir.update()
|
||||
|
||||
# re-sort word_embeddings because load_from_dir may not load in alphabetic order.
|
||||
# using a temporary copy so we don't reinitialize self.word_embeddings in case other objects have a reference to it.
|
||||
sorted_word_embeddings = {e.name: e for e in sorted(self.word_embeddings.values(), key=lambda e: e.name.lower())}
|
||||
self.word_embeddings.clear()
|
||||
self.word_embeddings.update(sorted_word_embeddings)
|
||||
|
||||
displayed_embeddings = (tuple(self.word_embeddings.keys()), tuple(self.skipped_embeddings.keys()))
|
||||
if self.previously_displayed_embeddings != displayed_embeddings:
|
||||
self.previously_displayed_embeddings = displayed_embeddings
|
||||
@@ -319,16 +323,16 @@ def tensorboard_add(tensorboard_writer, loss, global_step, step, learn_rate, epo
|
||||
tensorboard_add_scaler(tensorboard_writer, f"Learn rate/train/epoch-{epoch_num}", learn_rate, step)
|
||||
|
||||
def tensorboard_add_scaler(tensorboard_writer, tag, value, step):
|
||||
tensorboard_writer.add_scalar(tag=tag,
|
||||
tensorboard_writer.add_scalar(tag=tag,
|
||||
scalar_value=value, global_step=step)
|
||||
|
||||
def tensorboard_add_image(tensorboard_writer, tag, pil_image, step):
|
||||
# Convert a pil image to a torch tensor
|
||||
img_tensor = torch.as_tensor(np.array(pil_image, copy=True))
|
||||
img_tensor = img_tensor.view(pil_image.size[1], pil_image.size[0],
|
||||
img_tensor = img_tensor.view(pil_image.size[1], pil_image.size[0],
|
||||
len(pil_image.getbands()))
|
||||
img_tensor = img_tensor.permute((2, 0, 1))
|
||||
|
||||
|
||||
tensorboard_writer.add_image(tag, img_tensor, global_step=step)
|
||||
|
||||
def validate_train_inputs(model_name, learn_rate, batch_size, gradient_step, data_root, template_file, template_filename, steps, save_model_every, create_image_every, log_directory, name="embedding"):
|
||||
@@ -398,7 +402,7 @@ def train_embedding(id_task, embedding_name, learn_rate, batch_size, gradient_st
|
||||
if initial_step >= steps:
|
||||
shared.state.textinfo = "Model has already been trained beyond specified max steps"
|
||||
return embedding, filename
|
||||
|
||||
|
||||
scheduler = LearnRateScheduler(learn_rate, steps, initial_step)
|
||||
clip_grad = torch.nn.utils.clip_grad_value_ if clip_grad_mode == "value" else \
|
||||
torch.nn.utils.clip_grad_norm_ if clip_grad_mode == "norm" else \
|
||||
@@ -408,7 +412,7 @@ def train_embedding(id_task, embedding_name, learn_rate, batch_size, gradient_st
|
||||
# dataset loading may take a while, so input validations and early returns should be done before this
|
||||
shared.state.textinfo = f"Preparing dataset from {html.escape(data_root)}..."
|
||||
old_parallel_processing_allowed = shared.parallel_processing_allowed
|
||||
|
||||
|
||||
if shared.opts.training_enable_tensorboard:
|
||||
tensorboard_writer = tensorboard_setup(log_directory)
|
||||
|
||||
@@ -431,11 +435,11 @@ def train_embedding(id_task, embedding_name, learn_rate, batch_size, gradient_st
|
||||
optimizer = torch.optim.AdamW([embedding.vec], lr=scheduler.learn_rate, weight_decay=0.0)
|
||||
if shared.opts.save_optimizer_state:
|
||||
optimizer_state_dict = None
|
||||
if os.path.exists(filename + '.optim'):
|
||||
optimizer_saved_dict = torch.load(filename + '.optim', map_location='cpu')
|
||||
if os.path.exists(f"{filename}.optim"):
|
||||
optimizer_saved_dict = torch.load(f"{filename}.optim", map_location='cpu')
|
||||
if embedding.checksum() == optimizer_saved_dict.get('hash', None):
|
||||
optimizer_state_dict = optimizer_saved_dict.get('optimizer_state_dict', None)
|
||||
|
||||
|
||||
if optimizer_state_dict is not None:
|
||||
optimizer.load_state_dict(optimizer_state_dict)
|
||||
print("Loaded existing optimizer from checkpoint")
|
||||
@@ -464,7 +468,7 @@ def train_embedding(id_task, embedding_name, learn_rate, batch_size, gradient_st
|
||||
try:
|
||||
sd_hijack_checkpoint.add()
|
||||
|
||||
for i in range((steps-initial_step) * gradient_step):
|
||||
for _ in range((steps-initial_step) * gradient_step):
|
||||
if scheduler.finished:
|
||||
break
|
||||
if shared.state.interrupted:
|
||||
@@ -481,7 +485,7 @@ def train_embedding(id_task, embedding_name, learn_rate, batch_size, gradient_st
|
||||
|
||||
if clip_grad:
|
||||
clip_grad_sched.step(embedding.step)
|
||||
|
||||
|
||||
with devices.autocast():
|
||||
x = batch.latent_sample.to(devices.device, non_blocking=pin_memory)
|
||||
if use_weight:
|
||||
@@ -509,7 +513,7 @@ def train_embedding(id_task, embedding_name, learn_rate, batch_size, gradient_st
|
||||
# go back until we reach gradient accumulation steps
|
||||
if (j + 1) % gradient_step != 0:
|
||||
continue
|
||||
|
||||
|
||||
if clip_grad:
|
||||
clip_grad(embedding.vec, clip_grad_sched.learn_rate)
|
||||
|
||||
@@ -593,17 +597,17 @@ def train_embedding(id_task, embedding_name, learn_rate, batch_size, gradient_st
|
||||
data = torch.load(last_saved_file)
|
||||
info.add_text("sd-ti-embedding", embedding_to_b64(data))
|
||||
|
||||
title = "<{}>".format(data.get('name', '???'))
|
||||
title = f"<{data.get('name', '???')}>"
|
||||
|
||||
try:
|
||||
vectorSize = list(data['string_to_param'].values())[0].shape[0]
|
||||
except Exception as e:
|
||||
except Exception:
|
||||
vectorSize = '?'
|
||||
|
||||
checkpoint = sd_models.select_checkpoint()
|
||||
footer_left = checkpoint.model_name
|
||||
footer_mid = '[{}]'.format(checkpoint.shorthash)
|
||||
footer_right = '{}v {}s'.format(vectorSize, steps_done)
|
||||
footer_mid = f'[{checkpoint.shorthash}]'
|
||||
footer_right = f'{vectorSize}v {steps_done}s'
|
||||
|
||||
captioned_image = caption_image_overlay(image, title, footer_left, footer_mid, footer_right)
|
||||
captioned_image = insert_image_data_embed(captioned_image, data)
|
||||
|
||||
@@ -1,18 +1,15 @@
|
||||
import modules.scripts
|
||||
from modules import sd_samplers
|
||||
from modules import sd_samplers, processing
|
||||
from modules.generation_parameters_copypaste import create_override_settings_dict
|
||||
from modules.processing import StableDiffusionProcessing, Processed, StableDiffusionProcessingTxt2Img, \
|
||||
StableDiffusionProcessingImg2Img, process_images
|
||||
from modules.shared import opts, cmd_opts
|
||||
import modules.shared as shared
|
||||
import modules.processing as processing
|
||||
from modules.ui import plaintext_to_html
|
||||
|
||||
|
||||
def txt2img(id_task: str, prompt: str, negative_prompt: str, prompt_styles, steps: int, sampler_index: int, restore_faces: bool, tiling: bool, n_iter: int, batch_size: int, cfg_scale: float, seed: int, subseed: int, subseed_strength: float, seed_resize_from_h: int, seed_resize_from_w: int, seed_enable_extras: bool, height: int, width: int, enable_hr: bool, denoising_strength: float, hr_scale: float, hr_upscaler: str, hr_second_pass_steps: int, hr_resize_x: int, hr_resize_y: int, override_settings_texts, *args):
|
||||
override_settings = create_override_settings_dict(override_settings_texts)
|
||||
|
||||
p = StableDiffusionProcessingTxt2Img(
|
||||
p = processing.StableDiffusionProcessingTxt2Img(
|
||||
sd_model=shared.sd_model,
|
||||
outpath_samples=opts.outdir_samples or opts.outdir_txt2img_samples,
|
||||
outpath_grids=opts.outdir_grids or opts.outdir_txt2img_grids,
|
||||
@@ -53,7 +50,7 @@ def txt2img(id_task: str, prompt: str, negative_prompt: str, prompt_styles, step
|
||||
processed = modules.scripts.scripts_txt2img.run(p, *args)
|
||||
|
||||
if processed is None:
|
||||
processed = process_images(p)
|
||||
processed = processing.process_images(p)
|
||||
|
||||
p.close()
|
||||
|
||||
|
||||
377
modules/ui.py
377
modules/ui.py
@@ -1,29 +1,23 @@
|
||||
import html
|
||||
import json
|
||||
import math
|
||||
import mimetypes
|
||||
import os
|
||||
import platform
|
||||
import random
|
||||
import sys
|
||||
import tempfile
|
||||
import time
|
||||
import traceback
|
||||
from functools import partial, reduce
|
||||
from functools import reduce
|
||||
import warnings
|
||||
|
||||
import gradio as gr
|
||||
import gradio.routes
|
||||
import gradio.utils
|
||||
import numpy as np
|
||||
from PIL import Image, PngImagePlugin
|
||||
from PIL import Image, PngImagePlugin # noqa: F401
|
||||
from modules.call_queue import wrap_gradio_gpu_call, wrap_queued_call, wrap_gradio_call
|
||||
|
||||
from modules import sd_hijack, sd_models, localization, script_callbacks, ui_extensions, deepbooru, sd_vae, extra_networks, postprocessing, ui_components, ui_common, ui_postprocessing
|
||||
from modules.ui_components import FormRow, FormColumn, FormGroup, ToolButton, FormHTML
|
||||
from modules import sd_hijack, sd_models, localization, script_callbacks, ui_extensions, deepbooru, sd_vae, extra_networks, ui_common, ui_postprocessing, progress, ui_loadsave
|
||||
from modules.ui_components import FormRow, FormGroup, ToolButton, FormHTML
|
||||
from modules.paths import script_path, data_path
|
||||
|
||||
from modules.shared import opts, cmd_opts, restricted_opts
|
||||
from modules.shared import opts, cmd_opts
|
||||
|
||||
import modules.codeformer_model
|
||||
import modules.generation_parameters_copypaste as parameters_copypaste
|
||||
@@ -34,7 +28,6 @@ import modules.shared as shared
|
||||
import modules.styles
|
||||
import modules.textual_inversion.ui
|
||||
from modules import prompt_parser
|
||||
from modules.images import save_image
|
||||
from modules.sd_hijack import model_hijack
|
||||
from modules.sd_samplers import samplers, samplers_for_img2img
|
||||
from modules.textual_inversion import textual_inversion
|
||||
@@ -81,6 +74,7 @@ apply_style_symbol = '\U0001f4cb' # 📋
|
||||
clear_prompt_symbol = '\U0001f5d1\ufe0f' # 🗑️
|
||||
extra_networks_symbol = '\U0001F3B4' # 🎴
|
||||
switch_values_symbol = '\U000021C5' # ⇅
|
||||
restore_progress_symbol = '\U0001F300' # 🌀
|
||||
|
||||
|
||||
def plaintext_to_html(text):
|
||||
@@ -92,13 +86,6 @@ def send_gradio_gallery_to_image(x):
|
||||
return None
|
||||
return image_from_url_text(x[0])
|
||||
|
||||
def visit(x, func, path=""):
|
||||
if hasattr(x, 'children'):
|
||||
for c in x.children:
|
||||
visit(c, func, path)
|
||||
elif x.label is not None:
|
||||
func(path + "/" + str(x.label), x)
|
||||
|
||||
|
||||
def add_style(name: str, prompt: str, negative_prompt: str):
|
||||
if name is None:
|
||||
@@ -127,6 +114,16 @@ def calc_resolution_hires(enable, width, height, hr_scale, hr_resize_x, hr_resiz
|
||||
return f"resize: from <span class='resolution'>{p.width}x{p.height}</span> to <span class='resolution'>{p.hr_resize_x or p.hr_upscale_to_x}x{p.hr_resize_y or p.hr_upscale_to_y}</span>"
|
||||
|
||||
|
||||
def resize_from_to_html(width, height, scale_by):
|
||||
target_width = int(width * scale_by)
|
||||
target_height = int(height * scale_by)
|
||||
|
||||
if not target_width or not target_height:
|
||||
return "no image selected"
|
||||
|
||||
return f"resize: from <span class='resolution'>{width}x{height}</span> to <span class='resolution'>{target_width}x{target_height}</span>"
|
||||
|
||||
|
||||
def apply_styles(prompt, prompt_neg, styles):
|
||||
prompt = shared.prompt_styles.apply_styles_to_prompt(prompt, styles)
|
||||
prompt_neg = shared.prompt_styles.apply_negative_styles_to_prompt(prompt_neg, styles)
|
||||
@@ -152,7 +149,7 @@ def process_interrogate(interrogation_function, mode, ii_input_dir, ii_output_di
|
||||
img = Image.open(image)
|
||||
filename = os.path.basename(image)
|
||||
left, _ = os.path.splitext(filename)
|
||||
print(interrogation_function(img), file=open(os.path.join(ii_output_dir, left + ".txt"), 'a'))
|
||||
print(interrogation_function(img), file=open(os.path.join(ii_output_dir, f"{left}.txt"), 'a'))
|
||||
|
||||
return [gr.update(), None]
|
||||
|
||||
@@ -168,29 +165,29 @@ def interrogate_deepbooru(image):
|
||||
|
||||
|
||||
def create_seed_inputs(target_interface):
|
||||
with FormRow(elem_id=target_interface + '_seed_row', variant="compact"):
|
||||
seed = (gr.Textbox if cmd_opts.use_textbox_seed else gr.Number)(label='Seed', value=-1, elem_id=target_interface + '_seed')
|
||||
with FormRow(elem_id=f"{target_interface}_seed_row", variant="compact"):
|
||||
seed = (gr.Textbox if cmd_opts.use_textbox_seed else gr.Number)(label='Seed', value=-1, elem_id=f"{target_interface}_seed")
|
||||
seed.style(container=False)
|
||||
random_seed = ToolButton(random_symbol, elem_id=target_interface + '_random_seed')
|
||||
reuse_seed = ToolButton(reuse_symbol, elem_id=target_interface + '_reuse_seed')
|
||||
random_seed = ToolButton(random_symbol, elem_id=f"{target_interface}_random_seed", label='Random seed')
|
||||
reuse_seed = ToolButton(reuse_symbol, elem_id=f"{target_interface}_reuse_seed", label='Reuse seed')
|
||||
|
||||
seed_checkbox = gr.Checkbox(label='Extra', elem_id=target_interface + '_subseed_show', value=False)
|
||||
seed_checkbox = gr.Checkbox(label='Extra', elem_id=f"{target_interface}_subseed_show", value=False)
|
||||
|
||||
# Components to show/hide based on the 'Extra' checkbox
|
||||
seed_extras = []
|
||||
|
||||
with FormRow(visible=False, elem_id=target_interface + '_subseed_row') as seed_extra_row_1:
|
||||
with FormRow(visible=False, elem_id=f"{target_interface}_subseed_row") as seed_extra_row_1:
|
||||
seed_extras.append(seed_extra_row_1)
|
||||
subseed = gr.Number(label='Variation seed', value=-1, elem_id=target_interface + '_subseed')
|
||||
subseed = gr.Number(label='Variation seed', value=-1, elem_id=f"{target_interface}_subseed")
|
||||
subseed.style(container=False)
|
||||
random_subseed = ToolButton(random_symbol, elem_id=target_interface + '_random_subseed')
|
||||
reuse_subseed = ToolButton(reuse_symbol, elem_id=target_interface + '_reuse_subseed')
|
||||
subseed_strength = gr.Slider(label='Variation strength', value=0.0, minimum=0, maximum=1, step=0.01, elem_id=target_interface + '_subseed_strength')
|
||||
random_subseed = ToolButton(random_symbol, elem_id=f"{target_interface}_random_subseed")
|
||||
reuse_subseed = ToolButton(reuse_symbol, elem_id=f"{target_interface}_reuse_subseed")
|
||||
subseed_strength = gr.Slider(label='Variation strength', value=0.0, minimum=0, maximum=1, step=0.01, elem_id=f"{target_interface}_subseed_strength")
|
||||
|
||||
with FormRow(visible=False) as seed_extra_row_2:
|
||||
seed_extras.append(seed_extra_row_2)
|
||||
seed_resize_from_w = gr.Slider(minimum=0, maximum=2048, step=8, label="Resize seed from width", value=0, elem_id=target_interface + '_seed_resize_from_w')
|
||||
seed_resize_from_h = gr.Slider(minimum=0, maximum=2048, step=8, label="Resize seed from height", value=0, elem_id=target_interface + '_seed_resize_from_h')
|
||||
seed_resize_from_w = gr.Slider(minimum=0, maximum=2048, step=8, label="Resize seed from width", value=0, elem_id=f"{target_interface}_seed_resize_from_w")
|
||||
seed_resize_from_h = gr.Slider(minimum=0, maximum=2048, step=8, label="Resize seed from height", value=0, elem_id=f"{target_interface}_seed_resize_from_h")
|
||||
|
||||
target_interface_state = gr.Textbox(target_interface, visible=False)
|
||||
random_seed.click(fn=None, _js="setRandomSeed", show_progress=False, inputs=[target_interface_state], outputs=[])
|
||||
@@ -233,7 +230,7 @@ def connect_reuse_seed(seed: gr.Number, reuse_seed: gr.Button, generation_info:
|
||||
all_seeds = gen_info.get('all_seeds', [-1])
|
||||
res = all_seeds[index if 0 <= index < len(all_seeds) else 0]
|
||||
|
||||
except json.decoder.JSONDecodeError as e:
|
||||
except json.decoder.JSONDecodeError:
|
||||
if gen_info_string != '':
|
||||
print("Error parsing JSON generation info:", file=sys.stderr)
|
||||
print(gen_info_string, file=sys.stderr)
|
||||
@@ -313,6 +310,7 @@ def create_toprow(is_img2img):
|
||||
extra_networks_button = ToolButton(value=extra_networks_symbol, elem_id=f"{id_part}_extra_networks")
|
||||
prompt_style_apply = ToolButton(value=apply_style_symbol, elem_id=f"{id_part}_style_apply")
|
||||
save_style = ToolButton(value=save_style_symbol, elem_id=f"{id_part}_style_create")
|
||||
restore_progress_button = ToolButton(value=restore_progress_symbol, elem_id=f"{id_part}_restore_progress", visible=False)
|
||||
|
||||
token_counter = gr.HTML(value="<span>0/75</span>", elem_id=f"{id_part}_token_counter", elem_classes=["token-counter"])
|
||||
token_button = gr.Button(visible=False, elem_id=f"{id_part}_token_button")
|
||||
@@ -330,7 +328,7 @@ def create_toprow(is_img2img):
|
||||
prompt_styles = gr.Dropdown(label="Styles", elem_id=f"{id_part}_styles", choices=[k for k, v in shared.prompt_styles.styles.items()], value=[], multiselect=True)
|
||||
create_refresh_button(prompt_styles, shared.prompt_styles.reload, lambda: {"choices": [k for k, v in shared.prompt_styles.styles.items()]}, f"refresh_{id_part}_styles")
|
||||
|
||||
return prompt, prompt_styles, negative_prompt, submit, button_interrogate, button_deepbooru, prompt_style_apply, save_style, paste, extra_networks_button, token_counter, token_button, negative_token_counter, negative_token_button
|
||||
return prompt, prompt_styles, negative_prompt, submit, button_interrogate, button_deepbooru, prompt_style_apply, save_style, paste, extra_networks_button, token_counter, token_button, negative_token_counter, negative_token_button, restore_progress_button
|
||||
|
||||
|
||||
def setup_progressbar(*args, **kwargs):
|
||||
@@ -409,7 +407,7 @@ def create_sampler_and_steps_selection(choices, tabname):
|
||||
def ordered_ui_categories():
|
||||
user_order = {x.strip(): i * 2 + 1 for i, x in enumerate(shared.opts.ui_reorder.split(","))}
|
||||
|
||||
for i, category in sorted(enumerate(shared.ui_reorder_categories), key=lambda x: user_order.get(x[1], x[0] * 2 + 0)):
|
||||
for _, category in sorted(enumerate(shared.ui_reorder_categories), key=lambda x: user_order.get(x[1], x[0] * 2 + 0)):
|
||||
yield category
|
||||
|
||||
|
||||
@@ -447,7 +445,7 @@ def create_ui():
|
||||
modules.scripts.scripts_txt2img.initialize_scripts(is_img2img=False)
|
||||
|
||||
with gr.Blocks(analytics_enabled=False) as txt2img_interface:
|
||||
txt2img_prompt, txt2img_prompt_styles, txt2img_negative_prompt, submit, _, _, txt2img_prompt_style_apply, txt2img_save_style, txt2img_paste, extra_networks_button, token_counter, token_button, negative_token_counter, negative_token_button = create_toprow(is_img2img=False)
|
||||
txt2img_prompt, txt2img_prompt_styles, txt2img_negative_prompt, submit, _, _, txt2img_prompt_style_apply, txt2img_save_style, txt2img_paste, extra_networks_button, token_counter, token_button, negative_token_counter, negative_token_button, restore_progress_button = create_toprow(is_img2img=False)
|
||||
|
||||
dummy_component = gr.Label(visible=False)
|
||||
txt_prompt_img = gr.File(label="", elem_id="txt2img_prompt_image", file_count="single", type="binary", visible=False)
|
||||
@@ -469,7 +467,7 @@ def create_ui():
|
||||
height = gr.Slider(minimum=64, maximum=2048, step=8, label="Height", value=512, elem_id="txt2img_height")
|
||||
|
||||
with gr.Column(elem_id="txt2img_dimensions_row", scale=1, elem_classes="dimensions-tools"):
|
||||
res_switch_btn = ToolButton(value=switch_values_symbol, elem_id="txt2img_res_switch_btn")
|
||||
res_switch_btn = ToolButton(value=switch_values_symbol, elem_id="txt2img_res_switch_btn", label="Switch dims")
|
||||
|
||||
if opts.dimensions_and_batch_together:
|
||||
with gr.Column(elem_id="txt2img_column_batch"):
|
||||
@@ -579,6 +577,19 @@ def create_ui():
|
||||
|
||||
res_switch_btn.click(fn=None, _js="switchWidthHeightTxt2Img", inputs=None, outputs=None, show_progress=False)
|
||||
|
||||
restore_progress_button.click(
|
||||
fn=progress.restore_progress,
|
||||
_js="restoreProgressTxt2img",
|
||||
inputs=[dummy_component],
|
||||
outputs=[
|
||||
txt2img_gallery,
|
||||
generation_info,
|
||||
html_info,
|
||||
html_log,
|
||||
],
|
||||
show_progress=False,
|
||||
)
|
||||
|
||||
txt_prompt_img.change(
|
||||
fn=modules.images.image_data,
|
||||
inputs=[
|
||||
@@ -647,7 +658,7 @@ def create_ui():
|
||||
modules.scripts.scripts_img2img.initialize_scripts(is_img2img=True)
|
||||
|
||||
with gr.Blocks(analytics_enabled=False) as img2img_interface:
|
||||
img2img_prompt, img2img_prompt_styles, img2img_negative_prompt, submit, img2img_interrogate, img2img_deepbooru, img2img_prompt_style_apply, img2img_save_style, img2img_paste, extra_networks_button, token_counter, token_button, negative_token_counter, negative_token_button = create_toprow(is_img2img=True)
|
||||
img2img_prompt, img2img_prompt_styles, img2img_negative_prompt, submit, img2img_interrogate, img2img_deepbooru, img2img_prompt_style_apply, img2img_save_style, img2img_paste, extra_networks_button, token_counter, token_button, negative_token_counter, negative_token_button, restore_progress_button = create_toprow(is_img2img=True)
|
||||
|
||||
img2img_prompt_img = gr.File(label="", elem_id="img2img_prompt_image", file_count="single", type="binary", visible=False)
|
||||
|
||||
@@ -674,6 +685,8 @@ def create_ui():
|
||||
copy_image_buttons.append((button, name, elem))
|
||||
|
||||
with gr.Tabs(elem_id="mode_img2img"):
|
||||
img2img_selected_tab = gr.State(0)
|
||||
|
||||
with gr.TabItem('img2img', id='img2img', elem_id="img2img_img2img_tab") as tab_img2img:
|
||||
init_img = gr.Image(label="Image for img2img", elem_id="img2img_image", show_label=False, source="upload", interactive=True, type="pil", tool="editor", image_mode="RGBA").style(height=480)
|
||||
add_copy_image_controls('img2img', init_img)
|
||||
@@ -707,8 +720,8 @@ def create_ui():
|
||||
with gr.TabItem('Batch', id='batch', elem_id="img2img_batch_tab") as tab_batch:
|
||||
hidden = '<br>Disabled when launched with --hide-ui-dir-config.' if shared.cmd_opts.hide_ui_dir_config else ''
|
||||
gr.HTML(
|
||||
f"<p style='padding-bottom: 1em;' class=\"text-gray-500\">Process images in a directory on the same machine where the server is running." +
|
||||
f"<br>Use an empty output directory to save pictures normally instead of writing to the output directory." +
|
||||
"<p style='padding-bottom: 1em;' class=\"text-gray-500\">Process images in a directory on the same machine where the server is running." +
|
||||
"<br>Use an empty output directory to save pictures normally instead of writing to the output directory." +
|
||||
f"<br>Add inpaint batch mask directory to enable inpaint batch processing."
|
||||
f"{hidden}</p>"
|
||||
)
|
||||
@@ -716,6 +729,11 @@ def create_ui():
|
||||
img2img_batch_output_dir = gr.Textbox(label="Output directory", **shared.hide_dirs, elem_id="img2img_batch_output_dir")
|
||||
img2img_batch_inpaint_mask_dir = gr.Textbox(label="Inpaint batch mask directory (required for inpaint batch processing only)", **shared.hide_dirs, elem_id="img2img_batch_inpaint_mask_dir")
|
||||
|
||||
img2img_tabs = [tab_img2img, tab_sketch, tab_inpaint, tab_inpaint_color, tab_inpaint_upload, tab_batch]
|
||||
|
||||
for i, tab in enumerate(img2img_tabs):
|
||||
tab.select(fn=lambda tabnum=i: tabnum, inputs=[], outputs=[img2img_selected_tab])
|
||||
|
||||
def copy_image(img):
|
||||
if isinstance(img, dict) and 'image' in img:
|
||||
return img['image']
|
||||
@@ -730,7 +748,7 @@ def create_ui():
|
||||
)
|
||||
button.click(
|
||||
fn=lambda: None,
|
||||
_js="switch_to_"+name.replace(" ", "_"),
|
||||
_js=f"switch_to_{name.replace(' ', '_')}",
|
||||
inputs=[],
|
||||
outputs=[],
|
||||
)
|
||||
@@ -745,11 +763,44 @@ def create_ui():
|
||||
elif category == "dimensions":
|
||||
with FormRow():
|
||||
with gr.Column(elem_id="img2img_column_size", scale=4):
|
||||
width = gr.Slider(minimum=64, maximum=2048, step=8, label="Width", value=512, elem_id="img2img_width")
|
||||
height = gr.Slider(minimum=64, maximum=2048, step=8, label="Height", value=512, elem_id="img2img_height")
|
||||
selected_scale_tab = gr.State(value=0)
|
||||
|
||||
with gr.Column(elem_id="img2img_dimensions_row", scale=1, elem_classes="dimensions-tools"):
|
||||
res_switch_btn = ToolButton(value=switch_values_symbol, elem_id="img2img_res_switch_btn")
|
||||
with gr.Tabs():
|
||||
with gr.Tab(label="Resize to") as tab_scale_to:
|
||||
with FormRow():
|
||||
with gr.Column(elem_id="img2img_column_size", scale=4):
|
||||
width = gr.Slider(minimum=64, maximum=2048, step=8, label="Width", value=512, elem_id="img2img_width")
|
||||
height = gr.Slider(minimum=64, maximum=2048, step=8, label="Height", value=512, elem_id="img2img_height")
|
||||
with gr.Column(elem_id="img2img_dimensions_row", scale=1, elem_classes="dimensions-tools"):
|
||||
res_switch_btn = ToolButton(value=switch_values_symbol, elem_id="img2img_res_switch_btn")
|
||||
|
||||
with gr.Tab(label="Resize by") as tab_scale_by:
|
||||
scale_by = gr.Slider(minimum=0.05, maximum=4.0, step=0.05, label="Scale", value=1.0, elem_id="img2img_scale")
|
||||
|
||||
with FormRow():
|
||||
scale_by_html = FormHTML(resize_from_to_html(0, 0, 0.0), elem_id="img2img_scale_resolution_preview")
|
||||
gr.Slider(label="Unused", elem_id="img2img_unused_scale_by_slider")
|
||||
button_update_resize_to = gr.Button(visible=False, elem_id="img2img_update_resize_to")
|
||||
|
||||
on_change_args = dict(
|
||||
fn=resize_from_to_html,
|
||||
_js="currentImg2imgSourceResolution",
|
||||
inputs=[dummy_component, dummy_component, scale_by],
|
||||
outputs=scale_by_html,
|
||||
show_progress=False,
|
||||
)
|
||||
|
||||
scale_by.release(**on_change_args)
|
||||
button_update_resize_to.click(**on_change_args)
|
||||
|
||||
# the code below is meant to update the resolution label after the image in the image selection UI has changed.
|
||||
# as it is now the event keeps firing continuously for inpaint edits, which ruins the page with constant requests.
|
||||
# I assume this must be a gradio bug and for now we'll just do it for non-inpaint inputs.
|
||||
for component in [init_img, sketch]:
|
||||
component.change(fn=lambda: None, _js="updateImg2imgResizeToTextAfterChangingImage", inputs=[], outputs=[], show_progress=False)
|
||||
|
||||
tab_scale_to.select(fn=lambda: 0, inputs=[], outputs=[selected_scale_tab])
|
||||
tab_scale_by.select(fn=lambda: 1, inputs=[], outputs=[selected_scale_tab])
|
||||
|
||||
if opts.dimensions_and_batch_together:
|
||||
with gr.Column(elem_id="img2img_column_batch"):
|
||||
@@ -760,7 +811,7 @@ def create_ui():
|
||||
with FormGroup():
|
||||
with FormRow():
|
||||
cfg_scale = gr.Slider(minimum=1.0, maximum=30.0, step=0.5, label='CFG Scale', value=7.0, elem_id="img2img_cfg_scale")
|
||||
image_cfg_scale = gr.Slider(minimum=0, maximum=3.0, step=0.05, label='Image CFG Scale', value=1.5, elem_id="img2img_image_cfg_scale", visible=shared.sd_model and shared.sd_model.cond_stage_key == "edit")
|
||||
image_cfg_scale = gr.Slider(minimum=0, maximum=3.0, step=0.05, label='Image CFG Scale', value=1.5, elem_id="img2img_image_cfg_scale", visible=False)
|
||||
denoising_strength = gr.Slider(minimum=0.0, maximum=1.0, step=0.01, label='Denoising strength', value=0.75, elem_id="img2img_denoising_strength")
|
||||
|
||||
elif category == "seed":
|
||||
@@ -807,7 +858,7 @@ def create_ui():
|
||||
def select_img2img_tab(tab):
|
||||
return gr.update(visible=tab in [2, 3, 4]), gr.update(visible=tab == 3),
|
||||
|
||||
for i, elem in enumerate([tab_img2img, tab_sketch, tab_inpaint, tab_inpaint_color, tab_inpaint_upload, tab_batch]):
|
||||
for i, elem in enumerate(img2img_tabs):
|
||||
elem.select(
|
||||
fn=lambda tab=i: select_img2img_tab(tab),
|
||||
inputs=[],
|
||||
@@ -860,8 +911,10 @@ def create_ui():
|
||||
denoising_strength,
|
||||
seed,
|
||||
subseed, subseed_strength, seed_resize_from_h, seed_resize_from_w, seed_checkbox,
|
||||
selected_scale_tab,
|
||||
height,
|
||||
width,
|
||||
scale_by,
|
||||
resize_mode,
|
||||
inpaint_full_res,
|
||||
inpaint_full_res_padding,
|
||||
@@ -900,6 +953,19 @@ def create_ui():
|
||||
|
||||
res_switch_btn.click(fn=None, _js="switchWidthHeightImg2Img", inputs=None, outputs=None, show_progress=False)
|
||||
|
||||
restore_progress_button.click(
|
||||
fn=progress.restore_progress,
|
||||
_js="restoreProgressImg2img",
|
||||
inputs=[dummy_component],
|
||||
outputs=[
|
||||
img2img_gallery,
|
||||
generation_info,
|
||||
html_info,
|
||||
html_log,
|
||||
],
|
||||
show_progress=False,
|
||||
)
|
||||
|
||||
img2img_interrogate.click(
|
||||
fn=lambda *args: process_interrogate(interrogate, *args),
|
||||
**interrogate_args,
|
||||
@@ -1021,8 +1087,9 @@ def create_ui():
|
||||
interp_method.change(fn=update_interp_description, inputs=[interp_method], outputs=[interp_description])
|
||||
|
||||
with FormRow():
|
||||
checkpoint_format = gr.Radio(choices=["ckpt", "safetensors"], value="ckpt", label="Checkpoint format", elem_id="modelmerger_checkpoint_format")
|
||||
checkpoint_format = gr.Radio(choices=["ckpt", "safetensors"], value="safetensors", label="Checkpoint format", elem_id="modelmerger_checkpoint_format")
|
||||
save_as_half = gr.Checkbox(value=False, label="Save as float16", elem_id="modelmerger_save_as_half")
|
||||
save_metadata = gr.Checkbox(value=True, label="Save metadata (.safetensors only)", elem_id="modelmerger_save_metadata")
|
||||
|
||||
with FormRow():
|
||||
with gr.Column():
|
||||
@@ -1050,7 +1117,7 @@ def create_ui():
|
||||
with gr.Row(variant="compact").style(equal_height=False):
|
||||
with gr.Tabs(elem_id="train_tabs"):
|
||||
|
||||
with gr.Tab(label="Create embedding"):
|
||||
with gr.Tab(label="Create embedding", id="create_embedding"):
|
||||
new_embedding_name = gr.Textbox(label="Name", elem_id="train_new_embedding_name")
|
||||
initialization_text = gr.Textbox(label="Initialization text", value="*", elem_id="train_initialization_text")
|
||||
nvpt = gr.Slider(label="Number of vectors per token", minimum=1, maximum=75, step=1, value=1, elem_id="train_nvpt")
|
||||
@@ -1063,7 +1130,7 @@ def create_ui():
|
||||
with gr.Column():
|
||||
create_embedding = gr.Button(value="Create embedding", variant='primary', elem_id="train_create_embedding")
|
||||
|
||||
with gr.Tab(label="Create hypernetwork"):
|
||||
with gr.Tab(label="Create hypernetwork", id="create_hypernetwork"):
|
||||
new_hypernetwork_name = gr.Textbox(label="Name", elem_id="train_new_hypernetwork_name")
|
||||
new_hypernetwork_sizes = gr.CheckboxGroup(label="Modules", value=["768", "320", "640", "1280"], choices=["768", "1024", "320", "640", "1280"], elem_id="train_new_hypernetwork_sizes")
|
||||
new_hypernetwork_layer_structure = gr.Textbox("1, 2, 1", label="Enter hypernetwork layer structure", placeholder="1st and last digit must be 1. ex:'1, 2, 1'", elem_id="train_new_hypernetwork_layer_structure")
|
||||
@@ -1081,7 +1148,7 @@ def create_ui():
|
||||
with gr.Column():
|
||||
create_hypernetwork = gr.Button(value="Create hypernetwork", variant='primary', elem_id="train_create_hypernetwork")
|
||||
|
||||
with gr.Tab(label="Preprocess images"):
|
||||
with gr.Tab(label="Preprocess images", id="preprocess_images"):
|
||||
process_src = gr.Textbox(label='Source directory', elem_id="train_process_src")
|
||||
process_dst = gr.Textbox(label='Destination directory', elem_id="train_process_dst")
|
||||
process_width = gr.Slider(minimum=64, maximum=2048, step=8, label="Width", value=512, elem_id="train_process_width")
|
||||
@@ -1089,6 +1156,7 @@ def create_ui():
|
||||
preprocess_txt_action = gr.Dropdown(label='Existing Caption txt Action', value="ignore", choices=["ignore", "copy", "prepend", "append"], elem_id="train_preprocess_txt_action")
|
||||
|
||||
with gr.Row():
|
||||
process_keep_original_size = gr.Checkbox(label='Keep original size', elem_id="train_process_keep_original_size")
|
||||
process_flip = gr.Checkbox(label='Create flipped copies', elem_id="train_process_flip")
|
||||
process_split = gr.Checkbox(label='Split oversized images', elem_id="train_process_split")
|
||||
process_focal_crop = gr.Checkbox(label='Auto focal point crop', elem_id="train_process_focal_crop")
|
||||
@@ -1105,7 +1173,7 @@ def create_ui():
|
||||
process_focal_crop_entropy_weight = gr.Slider(label='Focal point entropy weight', value=0.15, minimum=0.0, maximum=1.0, step=0.05, elem_id="train_process_focal_crop_entropy_weight")
|
||||
process_focal_crop_edges_weight = gr.Slider(label='Focal point edges weight', value=0.5, minimum=0.0, maximum=1.0, step=0.05, elem_id="train_process_focal_crop_edges_weight")
|
||||
process_focal_crop_debug = gr.Checkbox(label='Create debug image', elem_id="train_process_focal_crop_debug")
|
||||
|
||||
|
||||
with gr.Column(visible=False) as process_multicrop_col:
|
||||
gr.Markdown('Each image is center-cropped with an automatically chosen width and height.')
|
||||
with gr.Row():
|
||||
@@ -1117,7 +1185,7 @@ def create_ui():
|
||||
with gr.Row():
|
||||
process_multicrop_objective = gr.Radio(["Maximize area", "Minimize error"], value="Maximize area", label="Resizing objective", elem_id="train_process_multicrop_objective")
|
||||
process_multicrop_threshold = gr.Slider(minimum=0, maximum=1, step=0.01, label="Error threshold", value=0.1, elem_id="train_process_multicrop_threshold")
|
||||
|
||||
|
||||
with gr.Row():
|
||||
with gr.Column(scale=3):
|
||||
gr.HTML(value="")
|
||||
@@ -1146,21 +1214,21 @@ def create_ui():
|
||||
)
|
||||
|
||||
def get_textual_inversion_template_names():
|
||||
return sorted([x for x in textual_inversion.textual_inversion_templates])
|
||||
return sorted(textual_inversion.textual_inversion_templates)
|
||||
|
||||
with gr.Tab(label="Train"):
|
||||
with gr.Tab(label="Train", id="train"):
|
||||
gr.HTML(value="<p style='margin-bottom: 0.7em'>Train an embedding or Hypernetwork; you must specify a directory with a set of 1:1 ratio images <a href=\"https://github.com/AUTOMATIC1111/stable-diffusion-webui/wiki/Textual-Inversion\" style=\"font-weight:bold;\">[wiki]</a></p>")
|
||||
with FormRow():
|
||||
train_embedding_name = gr.Dropdown(label='Embedding', elem_id="train_embedding", choices=sorted(sd_hijack.model_hijack.embedding_db.word_embeddings.keys()))
|
||||
create_refresh_button(train_embedding_name, sd_hijack.model_hijack.embedding_db.load_textual_inversion_embeddings, lambda: {"choices": sorted(sd_hijack.model_hijack.embedding_db.word_embeddings.keys())}, "refresh_train_embedding_name")
|
||||
|
||||
train_hypernetwork_name = gr.Dropdown(label='Hypernetwork', elem_id="train_hypernetwork", choices=[x for x in shared.hypernetworks.keys()])
|
||||
create_refresh_button(train_hypernetwork_name, shared.reload_hypernetworks, lambda: {"choices": sorted([x for x in shared.hypernetworks.keys()])}, "refresh_train_hypernetwork_name")
|
||||
train_hypernetwork_name = gr.Dropdown(label='Hypernetwork', elem_id="train_hypernetwork", choices=sorted(shared.hypernetworks))
|
||||
create_refresh_button(train_hypernetwork_name, shared.reload_hypernetworks, lambda: {"choices": sorted(shared.hypernetworks)}, "refresh_train_hypernetwork_name")
|
||||
|
||||
with FormRow():
|
||||
embedding_learn_rate = gr.Textbox(label='Embedding Learning rate', placeholder="Embedding Learning rate", value="0.005", elem_id="train_embedding_learn_rate")
|
||||
hypernetwork_learn_rate = gr.Textbox(label='Hypernetwork Learning rate', placeholder="Hypernetwork Learning rate", value="0.00001", elem_id="train_hypernetwork_learn_rate")
|
||||
|
||||
|
||||
with FormRow():
|
||||
clip_grad_mode = gr.Dropdown(value="disabled", label="Gradient Clipping", choices=["disabled", "value", "norm"])
|
||||
clip_grad_value = gr.Textbox(placeholder="Gradient clip value", value="0.1", show_label=False)
|
||||
@@ -1206,8 +1274,8 @@ def create_ui():
|
||||
|
||||
with gr.Column(elem_id='ti_gallery_container'):
|
||||
ti_output = gr.Text(elem_id="ti_output", value="", show_label=False)
|
||||
ti_gallery = gr.Gallery(label='Output', show_label=False, elem_id='ti_gallery').style(grid=4)
|
||||
ti_progress = gr.HTML(elem_id="ti_progress", value="")
|
||||
gr.Gallery(label='Output', show_label=False, elem_id='ti_gallery').style(columns=4)
|
||||
gr.HTML(elem_id="ti_progress", value="")
|
||||
ti_outcome = gr.HTML(elem_id="ti_error", value="")
|
||||
|
||||
create_embedding.click(
|
||||
@@ -1255,6 +1323,7 @@ def create_ui():
|
||||
process_width,
|
||||
process_height,
|
||||
preprocess_txt_action,
|
||||
process_keep_original_size,
|
||||
process_flip,
|
||||
process_split,
|
||||
process_caption,
|
||||
@@ -1377,23 +1446,25 @@ def create_ui():
|
||||
elif t == bool:
|
||||
comp = gr.Checkbox
|
||||
else:
|
||||
raise Exception(f'bad options item type: {str(t)} for key {key}')
|
||||
raise Exception(f'bad options item type: {t} for key {key}')
|
||||
|
||||
elem_id = "setting_"+key
|
||||
elem_id = f"setting_{key}"
|
||||
|
||||
if info.refresh is not None:
|
||||
if is_quicksettings:
|
||||
res = comp(label=info.label, value=fun(), elem_id=elem_id, **(args or {}))
|
||||
create_refresh_button(res, info.refresh, info.component_args, "refresh_" + key)
|
||||
create_refresh_button(res, info.refresh, info.component_args, f"refresh_{key}")
|
||||
else:
|
||||
with FormRow():
|
||||
res = comp(label=info.label, value=fun(), elem_id=elem_id, **(args or {}))
|
||||
create_refresh_button(res, info.refresh, info.component_args, "refresh_" + key)
|
||||
create_refresh_button(res, info.refresh, info.component_args, f"refresh_{key}")
|
||||
else:
|
||||
res = comp(label=info.label, value=fun(), elem_id=elem_id, **(args or {}))
|
||||
|
||||
return res
|
||||
|
||||
loadsave = ui_loadsave.UiLoadsave(cmd_opts.ui_config_file)
|
||||
|
||||
components = []
|
||||
component_dict = {}
|
||||
shared.settings_components = component_dict
|
||||
@@ -1440,7 +1511,7 @@ def create_ui():
|
||||
|
||||
result = gr.HTML(elem_id="settings_result")
|
||||
|
||||
quicksettings_names = [x.strip() for x in opts.quicksettings.split(",")]
|
||||
quicksettings_names = opts.quicksettings_list
|
||||
quicksettings_names = {x: i for i, x in enumerate(quicksettings_names) if x != 'quicksettings'}
|
||||
|
||||
quicksettings_list = []
|
||||
@@ -1460,7 +1531,7 @@ def create_ui():
|
||||
current_tab.__exit__()
|
||||
|
||||
gr.Group()
|
||||
current_tab = gr.TabItem(elem_id="settings_{}".format(elem_id), label=text)
|
||||
current_tab = gr.TabItem(elem_id=f"settings_{elem_id}", label=text)
|
||||
current_tab.__enter__()
|
||||
current_row = gr.Column(variant='compact')
|
||||
current_row.__enter__()
|
||||
@@ -1481,7 +1552,10 @@ def create_ui():
|
||||
current_row.__exit__()
|
||||
current_tab.__exit__()
|
||||
|
||||
with gr.TabItem("Actions"):
|
||||
with gr.TabItem("Defaults", id="defaults", elem_id="settings_tab_defaults"):
|
||||
loadsave.create_ui()
|
||||
|
||||
with gr.TabItem("Actions", id="actions", elem_id="settings_tab_actions"):
|
||||
request_notifications = gr.Button(value='Request browser notifications', elem_id="request_notifications")
|
||||
download_localization = gr.Button(value='Download localization template', elem_id="download_localization")
|
||||
reload_script_bodies = gr.Button(value='Reload custom script bodies (No ui updates, No restart)', variant='secondary', elem_id="settings_reload_script_bodies")
|
||||
@@ -1489,11 +1563,11 @@ def create_ui():
|
||||
unload_sd_model = gr.Button(value='Unload SD checkpoint to free VRAM', elem_id="sett_unload_sd_model")
|
||||
reload_sd_model = gr.Button(value='Reload the last SD checkpoint back into VRAM', elem_id="sett_reload_sd_model")
|
||||
|
||||
with gr.TabItem("Licenses"):
|
||||
with gr.TabItem("Licenses", id="licenses", elem_id="settings_tab_licenses"):
|
||||
gr.HTML(shared.html("licenses.html"), elem_id="licenses")
|
||||
|
||||
gr.Button(value="Show all pages", elem_id="settings_show_all_pages")
|
||||
|
||||
|
||||
|
||||
def unload_sd_weights():
|
||||
modules.sd_models.unload_model_weights()
|
||||
@@ -1537,12 +1611,8 @@ def create_ui():
|
||||
outputs=[]
|
||||
)
|
||||
|
||||
def request_restart():
|
||||
shared.state.interrupt()
|
||||
shared.state.need_restart = True
|
||||
|
||||
restart_gradio.click(
|
||||
fn=request_restart,
|
||||
fn=shared.state.request_restart,
|
||||
_js='restart_reload',
|
||||
inputs=[],
|
||||
outputs=[],
|
||||
@@ -1554,7 +1624,7 @@ def create_ui():
|
||||
(extras_interface, "Extras", "extras"),
|
||||
(pnginfo_interface, "PNG Info", "pnginfo"),
|
||||
(modelmerger_interface, "Checkpoint Merger", "modelmerger"),
|
||||
(train_interface, "Train", "ti"),
|
||||
(train_interface, "Train", "train"),
|
||||
]
|
||||
|
||||
interfaces += script_callbacks.ui_tabs_callback()
|
||||
@@ -1567,23 +1637,36 @@ def create_ui():
|
||||
for _interface, label, _ifid in interfaces:
|
||||
shared.tab_names.append(label)
|
||||
|
||||
with gr.Blocks(analytics_enabled=False, title="Stable Diffusion") as demo:
|
||||
with gr.Blocks(theme=shared.gradio_theme, analytics_enabled=False, title="Stable Diffusion") as demo:
|
||||
with gr.Row(elem_id="quicksettings", variant="compact"):
|
||||
for i, k, item in sorted(quicksettings_list, key=lambda x: quicksettings_names.get(x[1], x[0])):
|
||||
for _i, k, _item in sorted(quicksettings_list, key=lambda x: quicksettings_names.get(x[1], x[0])):
|
||||
component = create_setting_component(k, is_quicksettings=True)
|
||||
component_dict[k] = component
|
||||
|
||||
parameters_copypaste.connect_paste_params_buttons()
|
||||
|
||||
with gr.Tabs(elem_id="tabs") as tabs:
|
||||
for interface, label, ifid in interfaces:
|
||||
tab_order = {k: i for i, k in enumerate(opts.ui_tab_order)}
|
||||
sorted_interfaces = sorted(interfaces, key=lambda x: tab_order.get(x[1], 9999))
|
||||
|
||||
for interface, label, ifid in sorted_interfaces:
|
||||
if label in shared.opts.hidden_tabs:
|
||||
continue
|
||||
with gr.TabItem(label, id=ifid, elem_id='tab_' + ifid):
|
||||
with gr.TabItem(label, id=ifid, elem_id=f"tab_{ifid}"):
|
||||
interface.render()
|
||||
|
||||
for interface, _label, ifid in interfaces:
|
||||
if ifid in ["extensions", "settings"]:
|
||||
continue
|
||||
|
||||
loadsave.add_block(interface, ifid)
|
||||
|
||||
loadsave.add_component(f"webui/Tabs@{tabs.elem_id}", tabs)
|
||||
|
||||
loadsave.setup_ui()
|
||||
|
||||
if os.path.exists(os.path.join(script_path, "notification.mp3")):
|
||||
audio_notification = gr.Audio(interactive=False, value=os.path.join(script_path, "notification.mp3"), elem_id="audio_notification", visible=False)
|
||||
gr.Audio(interactive=False, value=os.path.join(script_path, "notification.mp3"), elem_id="audio_notification", visible=False)
|
||||
|
||||
footer = shared.html("footer.html")
|
||||
footer = footer.format(versions=versions_html())
|
||||
@@ -1596,22 +1679,21 @@ def create_ui():
|
||||
outputs=[text_settings, result],
|
||||
)
|
||||
|
||||
for i, k, item in quicksettings_list:
|
||||
for _i, k, _item in quicksettings_list:
|
||||
component = component_dict[k]
|
||||
info = opts.data_labels[k]
|
||||
|
||||
component.change(
|
||||
change_handler = component.release if hasattr(component, 'release') else component.change
|
||||
change_handler(
|
||||
fn=lambda value, k=k: run_settings_single(value, key=k),
|
||||
inputs=[component],
|
||||
outputs=[component, text_settings],
|
||||
show_progress=info.refresh is not None,
|
||||
)
|
||||
|
||||
text_settings.change(
|
||||
fn=lambda: gr.update(visible=shared.sd_model and shared.sd_model.cond_stage_key == "edit"),
|
||||
inputs=[],
|
||||
outputs=[image_cfg_scale],
|
||||
)
|
||||
update_image_cfg_scale_visibility = lambda: gr.update(visible=shared.sd_model and shared.sd_model.cond_stage_key == "edit")
|
||||
text_settings.change(fn=update_image_cfg_scale_visibility, inputs=[], outputs=[image_cfg_scale])
|
||||
demo.load(fn=update_image_cfg_scale_visibility, inputs=[], outputs=[image_cfg_scale])
|
||||
|
||||
button_set_checkpoint = gr.Button('Change checkpoint', elem_id='change_checkpoint', visible=False)
|
||||
button_set_checkpoint.click(
|
||||
@@ -1660,6 +1742,7 @@ def create_ui():
|
||||
config_source,
|
||||
bake_in_vae,
|
||||
discard_weights,
|
||||
save_metadata,
|
||||
],
|
||||
outputs=[
|
||||
primary_model_name,
|
||||
@@ -1670,82 +1753,8 @@ def create_ui():
|
||||
]
|
||||
)
|
||||
|
||||
ui_config_file = cmd_opts.ui_config_file
|
||||
ui_settings = {}
|
||||
settings_count = len(ui_settings)
|
||||
error_loading = False
|
||||
|
||||
try:
|
||||
if os.path.exists(ui_config_file):
|
||||
with open(ui_config_file, "r", encoding="utf8") as file:
|
||||
ui_settings = json.load(file)
|
||||
except Exception:
|
||||
error_loading = True
|
||||
print("Error loading settings:", file=sys.stderr)
|
||||
print(traceback.format_exc(), file=sys.stderr)
|
||||
|
||||
def loadsave(path, x):
|
||||
def apply_field(obj, field, condition=None, init_field=None):
|
||||
key = path + "/" + field
|
||||
|
||||
if getattr(obj, 'custom_script_source', None) is not None:
|
||||
key = 'customscript/' + obj.custom_script_source + '/' + key
|
||||
|
||||
if getattr(obj, 'do_not_save_to_config', False):
|
||||
return
|
||||
|
||||
saved_value = ui_settings.get(key, None)
|
||||
if saved_value is None:
|
||||
ui_settings[key] = getattr(obj, field)
|
||||
elif condition and not condition(saved_value):
|
||||
pass
|
||||
|
||||
# this warning is generally not useful;
|
||||
# print(f'Warning: Bad ui setting value: {key}: {saved_value}; Default value "{getattr(obj, field)}" will be used instead.')
|
||||
else:
|
||||
setattr(obj, field, saved_value)
|
||||
if init_field is not None:
|
||||
init_field(saved_value)
|
||||
|
||||
if type(x) in [gr.Slider, gr.Radio, gr.Checkbox, gr.Textbox, gr.Number, gr.Dropdown] and x.visible:
|
||||
apply_field(x, 'visible')
|
||||
|
||||
if type(x) == gr.Slider:
|
||||
apply_field(x, 'value')
|
||||
apply_field(x, 'minimum')
|
||||
apply_field(x, 'maximum')
|
||||
apply_field(x, 'step')
|
||||
|
||||
if type(x) == gr.Radio:
|
||||
apply_field(x, 'value', lambda val: val in x.choices)
|
||||
|
||||
if type(x) == gr.Checkbox:
|
||||
apply_field(x, 'value')
|
||||
|
||||
if type(x) == gr.Textbox:
|
||||
apply_field(x, 'value')
|
||||
|
||||
if type(x) == gr.Number:
|
||||
apply_field(x, 'value')
|
||||
|
||||
if type(x) == gr.Dropdown:
|
||||
def check_dropdown(val):
|
||||
if getattr(x, 'multiselect', False):
|
||||
return all([value in x.choices for value in val])
|
||||
else:
|
||||
return val in x.choices
|
||||
|
||||
apply_field(x, 'value', check_dropdown, getattr(x, 'init_field', None))
|
||||
|
||||
visit(txt2img_interface, loadsave, "txt2img")
|
||||
visit(img2img_interface, loadsave, "img2img")
|
||||
visit(extras_interface, loadsave, "extras")
|
||||
visit(modelmerger_interface, loadsave, "modelmerger")
|
||||
visit(train_interface, loadsave, "train")
|
||||
|
||||
if not error_loading and (not os.path.exists(ui_config_file) or settings_count != len(ui_settings)):
|
||||
with open(ui_config_file, "w", encoding="utf8") as file:
|
||||
json.dump(ui_settings, file, indent=4)
|
||||
loadsave.dump_defaults()
|
||||
demo.ui_loadsave = loadsave
|
||||
|
||||
# Required as a workaround for change() event not triggering when loading values from ui-config.json
|
||||
interp_description.value = update_interp_description(interp_method.value)
|
||||
@@ -1763,12 +1772,11 @@ def webpath(fn):
|
||||
|
||||
|
||||
def javascript_html():
|
||||
script_js = os.path.join(script_path, "script.js")
|
||||
head = f'<script type="text/javascript" src="{webpath(script_js)}"></script>\n'
|
||||
# Ensure localization is in `window` before scripts
|
||||
head = f'<script type="text/javascript">{localization.localization_js(shared.opts.localization)}</script>\n'
|
||||
|
||||
inline = f"{localization.localization_js(shared.opts.localization)};"
|
||||
if cmd_opts.theme is not None:
|
||||
inline += f"set_theme('{cmd_opts.theme}');"
|
||||
script_js = os.path.join(script_path, "script.js")
|
||||
head += f'<script type="text/javascript" src="{webpath(script_js)}"></script>\n'
|
||||
|
||||
for script in modules.scripts.list_scripts("javascript", ".js"):
|
||||
head += f'<script type="text/javascript" src="{webpath(script.path)}"></script>\n'
|
||||
@@ -1776,7 +1784,8 @@ def javascript_html():
|
||||
for script in modules.scripts.list_scripts("javascript", ".mjs"):
|
||||
head += f'<script type="module" src="{webpath(script.path)}"></script>\n'
|
||||
|
||||
head += f'<script type="text/javascript">{inline}</script>\n'
|
||||
if cmd_opts.theme:
|
||||
head += f'<script type="text/javascript">set_theme(\"{cmd_opts.theme}\");</script>\n'
|
||||
|
||||
return head
|
||||
|
||||
@@ -1823,7 +1832,7 @@ def versions_html():
|
||||
|
||||
python_version = ".".join([str(x) for x in sys.version_info[0:3]])
|
||||
commit = launch.commit_hash()
|
||||
short_commit = commit[0:8]
|
||||
tag = launch.git_tag()
|
||||
|
||||
if shared.xformers_available:
|
||||
import xformers
|
||||
@@ -1832,15 +1841,31 @@ def versions_html():
|
||||
xformers_version = "N/A"
|
||||
|
||||
return f"""
|
||||
version: <a href="https://github.com/AUTOMATIC1111/stable-diffusion-webui/commit/{commit}">{tag}</a>
|
||||
 • 
|
||||
python: <span title="{sys.version}">{python_version}</span>
|
||||
•
|
||||
 • 
|
||||
torch: {getattr(torch, '__long_version__',torch.__version__)}
|
||||
•
|
||||
 • 
|
||||
xformers: {xformers_version}
|
||||
•
|
||||
 • 
|
||||
gradio: {gr.__version__}
|
||||
•
|
||||
commit: <a href="https://github.com/AUTOMATIC1111/stable-diffusion-webui/commit/{commit}">{short_commit}</a>
|
||||
•
|
||||
 • 
|
||||
checkpoint: <a id="sd_checkpoint_hash">N/A</a>
|
||||
"""
|
||||
|
||||
|
||||
def setup_ui_api(app):
|
||||
from pydantic import BaseModel, Field
|
||||
from typing import List
|
||||
|
||||
class QuicksettingsHint(BaseModel):
|
||||
name: str = Field(title="Name of the quicksettings field")
|
||||
label: str = Field(title="Label of the quicksettings field")
|
||||
|
||||
def quicksettings_hint():
|
||||
return [QuicksettingsHint(name=k, label=v.label) for k, v in opts.data_labels.items()]
|
||||
|
||||
app.add_api_route("/internal/quicksettings-hint", quicksettings_hint, methods=["GET"], response_model=List[QuicksettingsHint])
|
||||
|
||||
app.add_api_route("/internal/ping", lambda: {}, methods=["GET"])
|
||||
|
||||
@@ -125,7 +125,7 @@ Requested path was: {f}
|
||||
|
||||
with gr.Column(variant='panel', elem_id=f"{tabname}_results"):
|
||||
with gr.Group(elem_id=f"{tabname}_gallery_container"):
|
||||
result_gallery = gr.Gallery(label='Output', show_label=False, elem_id=f"{tabname}_gallery").style(grid=4)
|
||||
result_gallery = gr.Gallery(label='Output', show_label=False, elem_id=f"{tabname}_gallery").style(columns=4)
|
||||
|
||||
generation_info = None
|
||||
with gr.Column():
|
||||
|
||||
@@ -62,3 +62,13 @@ class DropdownMulti(FormComponent, gr.Dropdown):
|
||||
|
||||
def get_block_name(self):
|
||||
return "dropdown"
|
||||
|
||||
|
||||
class DropdownEditable(FormComponent, gr.Dropdown):
|
||||
"""Same as gr.Dropdown but allows editing value"""
|
||||
def __init__(self, **kwargs):
|
||||
super().__init__(allow_custom_value=True, **kwargs)
|
||||
|
||||
def get_block_name(self):
|
||||
return "dropdown"
|
||||
|
||||
|
||||
@@ -1,7 +1,9 @@
|
||||
import json
|
||||
import os.path
|
||||
import sys
|
||||
import threading
|
||||
import time
|
||||
from datetime import datetime
|
||||
import traceback
|
||||
|
||||
import git
|
||||
@@ -11,10 +13,12 @@ import html
|
||||
import shutil
|
||||
import errno
|
||||
|
||||
from modules import extensions, shared, paths
|
||||
from modules import extensions, shared, paths, config_states
|
||||
from modules.paths_internal import config_states_dir
|
||||
from modules.call_queue import wrap_gradio_gpu_call
|
||||
|
||||
available_extensions = {"extensions": []}
|
||||
STYLE_PRIMARY = ' style="color: var(--primary-400)"'
|
||||
|
||||
|
||||
def check_access():
|
||||
@@ -30,6 +34,9 @@ def apply_and_restart(disable_list, update_list, disable_all):
|
||||
update = json.loads(update_list)
|
||||
assert type(update) == list, f"wrong update_list data for apply_and_restart: {update_list}"
|
||||
|
||||
if update:
|
||||
save_config_state("Backup (pre-update)")
|
||||
|
||||
update = set(update)
|
||||
|
||||
for ext in extensions.extensions:
|
||||
@@ -45,9 +52,47 @@ def apply_and_restart(disable_list, update_list, disable_all):
|
||||
shared.opts.disabled_extensions = disabled
|
||||
shared.opts.disable_all_extensions = disable_all
|
||||
shared.opts.save(shared.config_filename)
|
||||
shared.state.request_restart()
|
||||
|
||||
shared.state.interrupt()
|
||||
shared.state.need_restart = True
|
||||
|
||||
def save_config_state(name):
|
||||
current_config_state = config_states.get_config()
|
||||
if not name:
|
||||
name = "Config"
|
||||
current_config_state["name"] = name
|
||||
timestamp = datetime.now().strftime('%Y_%m_%d-%H_%M_%S')
|
||||
filename = os.path.join(config_states_dir, f"{timestamp}_{name}.json")
|
||||
print(f"Saving backup of webui/extension state to {filename}.")
|
||||
with open(filename, "w", encoding="utf-8") as f:
|
||||
json.dump(current_config_state, f)
|
||||
config_states.list_config_states()
|
||||
new_value = next(iter(config_states.all_config_states.keys()), "Current")
|
||||
new_choices = ["Current"] + list(config_states.all_config_states.keys())
|
||||
return gr.Dropdown.update(value=new_value, choices=new_choices), f"<span>Saved current webui/extension state to \"{filename}\"</span>"
|
||||
|
||||
|
||||
def restore_config_state(confirmed, config_state_name, restore_type):
|
||||
if config_state_name == "Current":
|
||||
return "<span>Select a config to restore from.</span>"
|
||||
if not confirmed:
|
||||
return "<span>Cancelled.</span>"
|
||||
|
||||
check_access()
|
||||
|
||||
config_state = config_states.all_config_states[config_state_name]
|
||||
|
||||
print(f"*** Restoring webui state from backup: {restore_type} ***")
|
||||
|
||||
if restore_type == "extensions" or restore_type == "both":
|
||||
shared.opts.restore_config_state_file = config_state["filepath"]
|
||||
shared.opts.save(shared.config_filename)
|
||||
|
||||
if restore_type == "webui" or restore_type == "both":
|
||||
config_states.restore_webui_config(config_state)
|
||||
|
||||
shared.state.request_restart()
|
||||
|
||||
return ""
|
||||
|
||||
|
||||
def check_updates(id_task, disable_list):
|
||||
@@ -76,6 +121,16 @@ def check_updates(id_task, disable_list):
|
||||
return extension_table(), ""
|
||||
|
||||
|
||||
def make_commit_link(commit_hash, remote, text=None):
|
||||
if text is None:
|
||||
text = commit_hash[:8]
|
||||
if remote.startswith("https://github.com/"):
|
||||
href = os.path.join(remote, "commit", commit_hash)
|
||||
return f'<a href="{href}" target="_blank">{text}</a>'
|
||||
else:
|
||||
return text
|
||||
|
||||
|
||||
def extension_table():
|
||||
code = f"""<!-- {time.time()} -->
|
||||
<table id="extensions">
|
||||
@@ -83,7 +138,9 @@ def extension_table():
|
||||
<tr>
|
||||
<th><abbr title="Use checkbox to enable the extension; it will be enabled or disabled when you click apply button">Extension</abbr></th>
|
||||
<th>URL</th>
|
||||
<th><abbr title="Extension version">Version</abbr></th>
|
||||
<th>Branch</th>
|
||||
<th>Version</th>
|
||||
<th>Date</th>
|
||||
<th><abbr title="Use checkbox to mark the extension for update; it will be updated when you click apply button">Update</abbr></th>
|
||||
</tr>
|
||||
</thead>
|
||||
@@ -91,6 +148,7 @@ def extension_table():
|
||||
"""
|
||||
|
||||
for ext in extensions.extensions:
|
||||
ext: extensions.Extension
|
||||
ext.read_info_from_repo()
|
||||
|
||||
remote = f"""<a href="{html.escape(ext.remote or '')}" target="_blank">{html.escape("built-in" if ext.is_builtin else ext.remote or '')}</a>"""
|
||||
@@ -102,13 +160,19 @@ def extension_table():
|
||||
|
||||
style = ""
|
||||
if shared.opts.disable_all_extensions == "extra" and not ext.is_builtin or shared.opts.disable_all_extensions == "all":
|
||||
style = ' style="color: var(--primary-400)"'
|
||||
style = STYLE_PRIMARY
|
||||
|
||||
version_link = ext.version
|
||||
if ext.commit_hash and ext.remote:
|
||||
version_link = make_commit_link(ext.commit_hash, ext.remote, ext.version)
|
||||
|
||||
code += f"""
|
||||
<tr>
|
||||
<td><label{style}><input class="gr-check-radio gr-checkbox" name="enable_{html.escape(ext.name)}" type="checkbox" {'checked="checked"' if ext.enabled else ''}>{html.escape(ext.name)}</label></td>
|
||||
<td>{remote}</td>
|
||||
<td>{ext.version}</td>
|
||||
<td>{ext.branch}</td>
|
||||
<td>{version_link}</td>
|
||||
<td>{time.asctime(time.gmtime(ext.commit_date))}</td>
|
||||
<td{' class="extension_status"' if ext.remote is not None else ''}>{ext_status}</td>
|
||||
</tr>
|
||||
"""
|
||||
@@ -121,6 +185,133 @@ def extension_table():
|
||||
return code
|
||||
|
||||
|
||||
def update_config_states_table(state_name):
|
||||
if state_name == "Current":
|
||||
config_state = config_states.get_config()
|
||||
else:
|
||||
config_state = config_states.all_config_states[state_name]
|
||||
|
||||
config_name = config_state.get("name", "Config")
|
||||
created_date = time.asctime(time.gmtime(config_state["created_at"]))
|
||||
filepath = config_state.get("filepath", "<unknown>")
|
||||
|
||||
code = f"""<!-- {time.time()} -->"""
|
||||
|
||||
webui_remote = config_state["webui"]["remote"] or ""
|
||||
webui_branch = config_state["webui"]["branch"]
|
||||
webui_commit_hash = config_state["webui"]["commit_hash"] or "<unknown>"
|
||||
webui_commit_date = config_state["webui"]["commit_date"]
|
||||
if webui_commit_date:
|
||||
webui_commit_date = time.asctime(time.gmtime(webui_commit_date))
|
||||
else:
|
||||
webui_commit_date = "<unknown>"
|
||||
|
||||
remote = f"""<a href="{html.escape(webui_remote)}" target="_blank">{html.escape(webui_remote or '')}</a>"""
|
||||
commit_link = make_commit_link(webui_commit_hash, webui_remote)
|
||||
date_link = make_commit_link(webui_commit_hash, webui_remote, webui_commit_date)
|
||||
|
||||
current_webui = config_states.get_webui_config()
|
||||
|
||||
style_remote = ""
|
||||
style_branch = ""
|
||||
style_commit = ""
|
||||
if current_webui["remote"] != webui_remote:
|
||||
style_remote = STYLE_PRIMARY
|
||||
if current_webui["branch"] != webui_branch:
|
||||
style_branch = STYLE_PRIMARY
|
||||
if current_webui["commit_hash"] != webui_commit_hash:
|
||||
style_commit = STYLE_PRIMARY
|
||||
|
||||
code += f"""<h2>Config Backup: {config_name}</h2>
|
||||
<div><b>Filepath:</b> {filepath}</div>
|
||||
<div><b>Created at:</b> {created_date}</div>"""
|
||||
|
||||
code += f"""<h2>WebUI State</h2>
|
||||
<table id="config_state_webui">
|
||||
<thead>
|
||||
<tr>
|
||||
<th>URL</th>
|
||||
<th>Branch</th>
|
||||
<th>Commit</th>
|
||||
<th>Date</th>
|
||||
</tr>
|
||||
</thead>
|
||||
<tbody>
|
||||
<tr>
|
||||
<td><label{style_remote}>{remote}</label></td>
|
||||
<td><label{style_branch}>{webui_branch}</label></td>
|
||||
<td><label{style_commit}>{commit_link}</label></td>
|
||||
<td><label{style_commit}>{date_link}</label></td>
|
||||
</tr>
|
||||
</tbody>
|
||||
</table>
|
||||
"""
|
||||
|
||||
code += """<h2>Extension State</h2>
|
||||
<table id="config_state_extensions">
|
||||
<thead>
|
||||
<tr>
|
||||
<th>Extension</th>
|
||||
<th>URL</th>
|
||||
<th>Branch</th>
|
||||
<th>Commit</th>
|
||||
<th>Date</th>
|
||||
</tr>
|
||||
</thead>
|
||||
<tbody>
|
||||
"""
|
||||
|
||||
ext_map = {ext.name: ext for ext in extensions.extensions}
|
||||
|
||||
for ext_name, ext_conf in config_state["extensions"].items():
|
||||
ext_remote = ext_conf["remote"] or ""
|
||||
ext_branch = ext_conf["branch"] or "<unknown>"
|
||||
ext_enabled = ext_conf["enabled"]
|
||||
ext_commit_hash = ext_conf["commit_hash"] or "<unknown>"
|
||||
ext_commit_date = ext_conf["commit_date"]
|
||||
if ext_commit_date:
|
||||
ext_commit_date = time.asctime(time.gmtime(ext_commit_date))
|
||||
else:
|
||||
ext_commit_date = "<unknown>"
|
||||
|
||||
remote = f"""<a href="{html.escape(ext_remote)}" target="_blank">{html.escape(ext_remote or '')}</a>"""
|
||||
commit_link = make_commit_link(ext_commit_hash, ext_remote)
|
||||
date_link = make_commit_link(ext_commit_hash, ext_remote, ext_commit_date)
|
||||
|
||||
style_enabled = ""
|
||||
style_remote = ""
|
||||
style_branch = ""
|
||||
style_commit = ""
|
||||
if ext_name in ext_map:
|
||||
current_ext = ext_map[ext_name]
|
||||
current_ext.read_info_from_repo()
|
||||
if current_ext.enabled != ext_enabled:
|
||||
style_enabled = STYLE_PRIMARY
|
||||
if current_ext.remote != ext_remote:
|
||||
style_remote = STYLE_PRIMARY
|
||||
if current_ext.branch != ext_branch:
|
||||
style_branch = STYLE_PRIMARY
|
||||
if current_ext.commit_hash != ext_commit_hash:
|
||||
style_commit = STYLE_PRIMARY
|
||||
|
||||
code += f"""
|
||||
<tr>
|
||||
<td><label{style_enabled}><input class="gr-check-radio gr-checkbox" type="checkbox" disabled="true" {'checked="checked"' if ext_enabled else ''}>{html.escape(ext_name)}</label></td>
|
||||
<td><label{style_remote}>{remote}</label></td>
|
||||
<td><label{style_branch}>{ext_branch}</label></td>
|
||||
<td><label{style_commit}>{commit_link}</label></td>
|
||||
<td><label{style_commit}>{date_link}</label></td>
|
||||
</tr>
|
||||
"""
|
||||
|
||||
code += """
|
||||
</tbody>
|
||||
</table>
|
||||
"""
|
||||
|
||||
return code
|
||||
|
||||
|
||||
def normalize_git_url(url):
|
||||
if url is None:
|
||||
return ""
|
||||
@@ -129,7 +320,7 @@ def normalize_git_url(url):
|
||||
return url
|
||||
|
||||
|
||||
def install_extension_from_url(dirname, url):
|
||||
def install_extension_from_url(dirname, url, branch_name=None):
|
||||
check_access()
|
||||
|
||||
assert url, 'No URL specified'
|
||||
@@ -150,10 +341,17 @@ def install_extension_from_url(dirname, url):
|
||||
|
||||
try:
|
||||
shutil.rmtree(tmpdir, True)
|
||||
with git.Repo.clone_from(url, tmpdir) as repo:
|
||||
repo.remote().fetch()
|
||||
for submodule in repo.submodules:
|
||||
submodule.update()
|
||||
if not branch_name:
|
||||
# if no branch is specified, use the default branch
|
||||
with git.Repo.clone_from(url, tmpdir) as repo:
|
||||
repo.remote().fetch()
|
||||
for submodule in repo.submodules:
|
||||
submodule.update()
|
||||
else:
|
||||
with git.Repo.clone_from(url, tmpdir, branch=branch_name) as repo:
|
||||
repo.remote().fetch()
|
||||
for submodule in repo.submodules:
|
||||
submodule.update()
|
||||
try:
|
||||
os.rename(tmpdir, target_dir)
|
||||
except OSError as err:
|
||||
@@ -272,7 +470,7 @@ def refresh_available_extensions_from_data(hide_tags, sort_column, filter_text="
|
||||
<td>{html.escape(description)}<p class="info"><span class="date_added">Added: {html.escape(added)}</span></p></td>
|
||||
<td>{install_code}</td>
|
||||
</tr>
|
||||
|
||||
|
||||
"""
|
||||
|
||||
for tag in [x for x in extension_tags if x not in tags]:
|
||||
@@ -289,12 +487,21 @@ def refresh_available_extensions_from_data(hide_tags, sort_column, filter_text="
|
||||
return code, list(tags)
|
||||
|
||||
|
||||
def preload_extensions_git_metadata():
|
||||
for extension in extensions.extensions:
|
||||
extension.read_info_from_repo()
|
||||
|
||||
|
||||
def create_ui():
|
||||
import modules.ui
|
||||
|
||||
config_states.list_config_states()
|
||||
|
||||
threading.Thread(target=preload_extensions_git_metadata).start()
|
||||
|
||||
with gr.Blocks(analytics_enabled=False) as ui:
|
||||
with gr.Tabs(elem_id="tabs_extensions") as tabs:
|
||||
with gr.TabItem("Installed"):
|
||||
with gr.Tabs(elem_id="tabs_extensions"):
|
||||
with gr.TabItem("Installed", id="installed"):
|
||||
|
||||
with gr.Row(elem_id="extensions_installed_top"):
|
||||
apply = gr.Button(value="Apply and restart UI", variant="primary")
|
||||
@@ -311,7 +518,8 @@ def create_ui():
|
||||
</span>
|
||||
"""
|
||||
info = gr.HTML(html)
|
||||
extensions_table = gr.HTML(lambda: extension_table())
|
||||
extensions_table = gr.HTML('Loading...')
|
||||
ui.load(fn=extension_table, inputs=[], outputs=[extensions_table])
|
||||
|
||||
apply.click(
|
||||
fn=apply_and_restart,
|
||||
@@ -327,7 +535,7 @@ def create_ui():
|
||||
outputs=[extensions_table, info],
|
||||
)
|
||||
|
||||
with gr.TabItem("Available"):
|
||||
with gr.TabItem("Available", id="available"):
|
||||
with gr.Row():
|
||||
refresh_available_extensions_button = gr.Button(value="Load from:", variant="primary")
|
||||
available_extensions_index = gr.Text(value="https://raw.githubusercontent.com/AUTOMATIC1111/stable-diffusion-webui-extensions/master/index.json", label="Extension index URL").style(container=False)
|
||||
@@ -338,9 +546,9 @@ def create_ui():
|
||||
hide_tags = gr.CheckboxGroup(value=["ads", "localization", "installed"], label="Hide extensions with tags", choices=["script", "ads", "localization", "installed"])
|
||||
sort_column = gr.Radio(value="newest first", label="Order", choices=["newest first", "oldest first", "a-z", "z-a", "internal order", ], type="index")
|
||||
|
||||
with gr.Row():
|
||||
with gr.Row():
|
||||
search_extensions_text = gr.Text(label="Search").style(container=False)
|
||||
|
||||
|
||||
install_result = gr.HTML()
|
||||
available_extensions_table = gr.HTML()
|
||||
|
||||
@@ -374,16 +582,43 @@ def create_ui():
|
||||
outputs=[available_extensions_table, install_result]
|
||||
)
|
||||
|
||||
with gr.TabItem("Install from URL"):
|
||||
with gr.TabItem("Install from URL", id="install_from_url"):
|
||||
install_url = gr.Text(label="URL for extension's git repository")
|
||||
install_branch = gr.Text(label="Specific branch name", placeholder="Leave empty for default main branch")
|
||||
install_dirname = gr.Text(label="Local directory name", placeholder="Leave empty for auto")
|
||||
install_button = gr.Button(value="Install", variant="primary")
|
||||
install_result = gr.HTML(elem_id="extension_install_result")
|
||||
|
||||
install_button.click(
|
||||
fn=modules.ui.wrap_gradio_call(install_extension_from_url, extra_outputs=[gr.update()]),
|
||||
inputs=[install_dirname, install_url],
|
||||
outputs=[extensions_table, install_result],
|
||||
fn=modules.ui.wrap_gradio_call(lambda *args: [gr.update(), *install_extension_from_url(*args)], extra_outputs=[gr.update(), gr.update()]),
|
||||
inputs=[install_dirname, install_url, install_branch],
|
||||
outputs=[install_url, extensions_table, install_result],
|
||||
)
|
||||
|
||||
with gr.TabItem("Backup/Restore"):
|
||||
with gr.Row(elem_id="extensions_backup_top_row"):
|
||||
config_states_list = gr.Dropdown(label="Saved Configs", elem_id="extension_backup_saved_configs", value="Current", choices=["Current"] + list(config_states.all_config_states.keys()))
|
||||
modules.ui.create_refresh_button(config_states_list, config_states.list_config_states, lambda: {"choices": ["Current"] + list(config_states.all_config_states.keys())}, "refresh_config_states")
|
||||
config_restore_type = gr.Radio(label="State to restore", choices=["extensions", "webui", "both"], value="extensions", elem_id="extension_backup_restore_type")
|
||||
config_restore_button = gr.Button(value="Restore Selected Config", variant="primary", elem_id="extension_backup_restore")
|
||||
with gr.Row(elem_id="extensions_backup_top_row2"):
|
||||
config_save_name = gr.Textbox("", placeholder="Config Name", show_label=False)
|
||||
config_save_button = gr.Button(value="Save Current Config")
|
||||
|
||||
config_states_info = gr.HTML("")
|
||||
config_states_table = gr.HTML("Loading...")
|
||||
ui.load(fn=update_config_states_table, inputs=[config_states_list], outputs=[config_states_table])
|
||||
|
||||
config_save_button.click(fn=save_config_state, inputs=[config_save_name], outputs=[config_states_list, config_states_info])
|
||||
|
||||
dummy_component = gr.Label(visible=False)
|
||||
config_restore_button.click(fn=restore_config_state, _js="config_state_confirm_restore", inputs=[dummy_component, config_states_list, config_restore_type], outputs=[config_states_info])
|
||||
|
||||
config_states_list.change(
|
||||
fn=update_config_states_table,
|
||||
inputs=[config_states_list],
|
||||
outputs=[config_states_table],
|
||||
)
|
||||
|
||||
|
||||
return ui
|
||||
|
||||
@@ -1,11 +1,10 @@
|
||||
import glob
|
||||
import os.path
|
||||
import urllib.parse
|
||||
from pathlib import Path
|
||||
from PIL import PngImagePlugin
|
||||
|
||||
from modules import shared
|
||||
from modules.images import read_info_from_image
|
||||
from modules.images import read_info_from_image, save_image_with_geninfo
|
||||
import gradio as gr
|
||||
import json
|
||||
import html
|
||||
@@ -27,11 +26,11 @@ def register_page(page):
|
||||
def fetch_file(filename: str = ""):
|
||||
from starlette.responses import FileResponse
|
||||
|
||||
if not any([Path(x).absolute() in Path(filename).absolute().parents for x in allowed_dirs]):
|
||||
if not any(Path(x).absolute() in Path(filename).absolute().parents for x in allowed_dirs):
|
||||
raise ValueError(f"File cannot be fetched: {filename}. Must be in one of directories registered by extra pages.")
|
||||
|
||||
ext = os.path.splitext(filename)[1].lower()
|
||||
if ext not in (".png", ".jpg", ".webp"):
|
||||
if ext not in (".png", ".jpg", ".jpeg", ".webp"):
|
||||
raise ValueError(f"File cannot be fetched: {filename}. Only png and jpg and webp.")
|
||||
|
||||
# would profit from returning 304
|
||||
@@ -69,7 +68,9 @@ class ExtraNetworksPage:
|
||||
pass
|
||||
|
||||
def link_preview(self, filename):
|
||||
return "./sd_extra_networks/thumb?filename=" + urllib.parse.quote(filename.replace('\\', '/')) + "&mtime=" + str(os.path.getmtime(filename))
|
||||
quoted_filename = urllib.parse.quote(filename.replace('\\', '/'))
|
||||
mtime = os.path.getmtime(filename)
|
||||
return f"./sd_extra_networks/thumb?filename={quoted_filename}&mtime={mtime}"
|
||||
|
||||
def search_terms_from_path(self, filename, possible_directories=None):
|
||||
abspath = os.path.abspath(filename)
|
||||
@@ -89,19 +90,25 @@ class ExtraNetworksPage:
|
||||
|
||||
subdirs = {}
|
||||
for parentdir in [os.path.abspath(x) for x in self.allowed_directories_for_previews()]:
|
||||
for x in glob.glob(os.path.join(parentdir, '**/*'), recursive=True):
|
||||
if not os.path.isdir(x):
|
||||
continue
|
||||
for root, dirs, _ in os.walk(parentdir, followlinks=True):
|
||||
for dirname in dirs:
|
||||
x = os.path.join(root, dirname)
|
||||
|
||||
subdir = os.path.abspath(x)[len(parentdir):].replace("\\", "/")
|
||||
while subdir.startswith("/"):
|
||||
subdir = subdir[1:]
|
||||
if not os.path.isdir(x):
|
||||
continue
|
||||
|
||||
is_empty = len(os.listdir(x)) == 0
|
||||
if not is_empty and not subdir.endswith("/"):
|
||||
subdir = subdir + "/"
|
||||
subdir = os.path.abspath(x)[len(parentdir):].replace("\\", "/")
|
||||
while subdir.startswith("/"):
|
||||
subdir = subdir[1:]
|
||||
|
||||
subdirs[subdir] = 1
|
||||
is_empty = len(os.listdir(x)) == 0
|
||||
if not is_empty and not subdir.endswith("/"):
|
||||
subdir = subdir + "/"
|
||||
|
||||
if ("/." in subdir or subdir.startswith(".")) and not shared.opts.extra_networks_show_hidden_directories:
|
||||
continue
|
||||
|
||||
subdirs[subdir] = 1
|
||||
|
||||
if subdirs:
|
||||
subdirs = {"": 1, **subdirs}
|
||||
@@ -143,6 +150,10 @@ class ExtraNetworksPage:
|
||||
return []
|
||||
|
||||
def create_html_for_item(self, item, tabname):
|
||||
"""
|
||||
Create HTML for card item in tab tabname; can return empty string if the item is not meant to be shown.
|
||||
"""
|
||||
|
||||
preview = item.get("preview", None)
|
||||
|
||||
onclick = item.get("onclick", None)
|
||||
@@ -157,8 +168,26 @@ class ExtraNetworksPage:
|
||||
if metadata:
|
||||
metadata_button = f"<div class='metadata-button' title='Show metadata' onclick='extraNetworksRequestMetadata(event, {json.dumps(self.name)}, {json.dumps(item['name'])})'></div>"
|
||||
|
||||
local_path = ""
|
||||
filename = item.get("filename", "")
|
||||
for reldir in self.allowed_directories_for_previews():
|
||||
absdir = os.path.abspath(reldir)
|
||||
|
||||
if filename.startswith(absdir):
|
||||
local_path = filename[len(absdir):]
|
||||
|
||||
# if this is true, the item must not be shown in the default view, and must instead only be
|
||||
# shown when searching for it
|
||||
if shared.opts.extra_networks_hidden_models == "Always":
|
||||
search_only = False
|
||||
else:
|
||||
search_only = "/." in local_path or "\\." in local_path
|
||||
|
||||
if search_only and shared.opts.extra_networks_hidden_models == "Never":
|
||||
return ""
|
||||
|
||||
args = {
|
||||
"style": f"'{height}{width}{background_image}'",
|
||||
"style": f"'display: none; {height}{width}{background_image}'",
|
||||
"prompt": item.get("prompt", None),
|
||||
"tabname": json.dumps(tabname),
|
||||
"local_preview": json.dumps(item["local_preview"]),
|
||||
@@ -168,6 +197,7 @@ class ExtraNetworksPage:
|
||||
"save_card_preview": '"' + html.escape(f"""return saveCardPreview(event, {json.dumps(tabname)}, {json.dumps(item["local_preview"])})""") + '"',
|
||||
"search_term": item.get("search_term", ""),
|
||||
"metadata_button": metadata_button,
|
||||
"search_only": " search_only" if search_only else "",
|
||||
}
|
||||
|
||||
return self.card_page.format(**args)
|
||||
@@ -177,7 +207,7 @@ class ExtraNetworksPage:
|
||||
Find a preview PNG for a given path (without extension) and call link_preview on it.
|
||||
"""
|
||||
|
||||
preview_extensions = ["png", "jpg", "webp"]
|
||||
preview_extensions = ["png", "jpg", "jpeg", "webp"]
|
||||
if shared.opts.samples_format not in preview_extensions:
|
||||
preview_extensions.append(shared.opts.samples_format)
|
||||
|
||||
@@ -209,6 +239,11 @@ def intialize():
|
||||
class ExtraNetworksUi:
|
||||
def __init__(self):
|
||||
self.pages = None
|
||||
"""gradio HTML components related to extra networks' pages"""
|
||||
|
||||
self.page_contents = None
|
||||
"""HTML content of the above; empty initially, filled when extra pages have to be shown"""
|
||||
|
||||
self.stored_extra_pages = None
|
||||
|
||||
self.button_save_preview = None
|
||||
@@ -236,17 +271,22 @@ def pages_in_preferred_order(pages):
|
||||
def create_ui(container, button, tabname):
|
||||
ui = ExtraNetworksUi()
|
||||
ui.pages = []
|
||||
ui.pages_contents = []
|
||||
ui.stored_extra_pages = pages_in_preferred_order(extra_pages.copy())
|
||||
ui.tabname = tabname
|
||||
|
||||
with gr.Tabs(elem_id=tabname+"_extra_tabs") as tabs:
|
||||
with gr.Tabs(elem_id=tabname+"_extra_tabs"):
|
||||
for page in ui.stored_extra_pages:
|
||||
with gr.Tab(page.title):
|
||||
page_id = page.title.lower().replace(" ", "_")
|
||||
|
||||
page_elem = gr.HTML(page.create_html(ui.tabname))
|
||||
with gr.Tab(page.title, id=page_id):
|
||||
elem_id = f"{tabname}_{page_id}_cards_html"
|
||||
page_elem = gr.HTML('Loading...', elem_id=elem_id)
|
||||
ui.pages.append(page_elem)
|
||||
|
||||
filter = gr.Textbox('', show_label=False, elem_id=tabname+"_extra_search", placeholder="Search...", visible=False)
|
||||
page_elem.change(fn=lambda: None, _js='function(){applyExtraNetworkFilter(' + json.dumps(tabname) + '); return []}', inputs=[], outputs=[])
|
||||
|
||||
gr.Textbox('', show_label=False, elem_id=tabname+"_extra_search", placeholder="Search...", visible=False)
|
||||
button_refresh = gr.Button('Refresh', elem_id=tabname+"_extra_refresh")
|
||||
|
||||
ui.button_save_preview = gr.Button('Save preview', elem_id=tabname+"_save_preview", visible=False)
|
||||
@@ -254,19 +294,33 @@ def create_ui(container, button, tabname):
|
||||
|
||||
def toggle_visibility(is_visible):
|
||||
is_visible = not is_visible
|
||||
|
||||
return is_visible, gr.update(visible=is_visible), gr.update(variant=("secondary-down" if is_visible else "secondary"))
|
||||
|
||||
def fill_tabs(is_empty):
|
||||
"""Creates HTML for extra networks' tabs when the extra networks button is clicked for the first time."""
|
||||
|
||||
if not ui.pages_contents:
|
||||
refresh()
|
||||
|
||||
if is_empty:
|
||||
return True, *ui.pages_contents
|
||||
|
||||
return True, *[gr.update() for _ in ui.pages_contents]
|
||||
|
||||
state_visible = gr.State(value=False)
|
||||
button.click(fn=toggle_visibility, inputs=[state_visible], outputs=[state_visible, container, button])
|
||||
button.click(fn=toggle_visibility, inputs=[state_visible], outputs=[state_visible, container, button], show_progress=False)
|
||||
|
||||
state_empty = gr.State(value=True)
|
||||
button.click(fn=fill_tabs, inputs=[state_empty], outputs=[state_empty, *ui.pages], show_progress=False)
|
||||
|
||||
def refresh():
|
||||
res = []
|
||||
|
||||
for pg in ui.stored_extra_pages:
|
||||
pg.refresh()
|
||||
res.append(pg.create_html(ui.tabname))
|
||||
|
||||
return res
|
||||
ui.pages_contents = [pg.create_html(ui.tabname) for pg in ui.stored_extra_pages]
|
||||
|
||||
return ui.pages_contents
|
||||
|
||||
button_refresh.click(fn=refresh, inputs=[], outputs=ui.pages)
|
||||
|
||||
@@ -296,18 +350,13 @@ def setup_ui(ui, gallery):
|
||||
|
||||
is_allowed = False
|
||||
for extra_page in ui.stored_extra_pages:
|
||||
if any([path_is_parent(x, filename) for x in extra_page.allowed_directories_for_previews()]):
|
||||
if any(path_is_parent(x, filename) for x in extra_page.allowed_directories_for_previews()):
|
||||
is_allowed = True
|
||||
break
|
||||
|
||||
assert is_allowed, f'writing to {filename} is not allowed'
|
||||
|
||||
if geninfo:
|
||||
pnginfo_data = PngImagePlugin.PngInfo()
|
||||
pnginfo_data.add_text('parameters', geninfo)
|
||||
image.save(filename, pnginfo=pnginfo_data)
|
||||
else:
|
||||
image.save(filename)
|
||||
save_image_with_geninfo(image, geninfo, filename)
|
||||
|
||||
return [page.create_html(ui.tabname) for page in ui.stored_extra_pages]
|
||||
|
||||
|
||||
208
modules/ui_loadsave.py
Normal file
208
modules/ui_loadsave.py
Normal file
@@ -0,0 +1,208 @@
|
||||
import json
|
||||
import os
|
||||
|
||||
import gradio as gr
|
||||
|
||||
from modules import errors
|
||||
from modules.ui_components import ToolButton
|
||||
|
||||
|
||||
class UiLoadsave:
|
||||
"""allows saving and restorig default values for gradio components"""
|
||||
|
||||
def __init__(self, filename):
|
||||
self.filename = filename
|
||||
self.ui_settings = {}
|
||||
self.component_mapping = {}
|
||||
self.error_loading = False
|
||||
self.finalized_ui = False
|
||||
|
||||
self.ui_defaults_view = None
|
||||
self.ui_defaults_apply = None
|
||||
self.ui_defaults_review = None
|
||||
|
||||
try:
|
||||
if os.path.exists(self.filename):
|
||||
self.ui_settings = self.read_from_file()
|
||||
except Exception as e:
|
||||
self.error_loading = True
|
||||
errors.display(e, "loading settings")
|
||||
|
||||
def add_component(self, path, x):
|
||||
"""adds component to the registry of tracked components"""
|
||||
|
||||
assert not self.finalized_ui
|
||||
|
||||
def apply_field(obj, field, condition=None, init_field=None):
|
||||
key = f"{path}/{field}"
|
||||
|
||||
if getattr(obj, 'custom_script_source', None) is not None:
|
||||
key = f"customscript/{obj.custom_script_source}/{key}"
|
||||
|
||||
if getattr(obj, 'do_not_save_to_config', False):
|
||||
return
|
||||
|
||||
saved_value = self.ui_settings.get(key, None)
|
||||
if saved_value is None:
|
||||
self.ui_settings[key] = getattr(obj, field)
|
||||
elif condition and not condition(saved_value):
|
||||
pass
|
||||
else:
|
||||
setattr(obj, field, saved_value)
|
||||
if init_field is not None:
|
||||
init_field(saved_value)
|
||||
|
||||
if field == 'value' and key not in self.component_mapping:
|
||||
self.component_mapping[key] = x
|
||||
|
||||
if type(x) in [gr.Slider, gr.Radio, gr.Checkbox, gr.Textbox, gr.Number, gr.Dropdown, ToolButton] and x.visible:
|
||||
apply_field(x, 'visible')
|
||||
|
||||
if type(x) == gr.Slider:
|
||||
apply_field(x, 'value')
|
||||
apply_field(x, 'minimum')
|
||||
apply_field(x, 'maximum')
|
||||
apply_field(x, 'step')
|
||||
|
||||
if type(x) == gr.Radio:
|
||||
apply_field(x, 'value', lambda val: val in x.choices)
|
||||
|
||||
if type(x) == gr.Checkbox:
|
||||
apply_field(x, 'value')
|
||||
|
||||
if type(x) == gr.Textbox:
|
||||
apply_field(x, 'value')
|
||||
|
||||
if type(x) == gr.Number:
|
||||
apply_field(x, 'value')
|
||||
|
||||
if type(x) == gr.Dropdown:
|
||||
def check_dropdown(val):
|
||||
if getattr(x, 'multiselect', False):
|
||||
return all(value in x.choices for value in val)
|
||||
else:
|
||||
return val in x.choices
|
||||
|
||||
apply_field(x, 'value', check_dropdown, getattr(x, 'init_field', None))
|
||||
|
||||
def check_tab_id(tab_id):
|
||||
tab_items = list(filter(lambda e: isinstance(e, gr.TabItem), x.children))
|
||||
if type(tab_id) == str:
|
||||
tab_ids = [t.id for t in tab_items]
|
||||
return tab_id in tab_ids
|
||||
elif type(tab_id) == int:
|
||||
return 0 <= tab_id < len(tab_items)
|
||||
else:
|
||||
return False
|
||||
|
||||
if type(x) == gr.Tabs:
|
||||
apply_field(x, 'selected', check_tab_id)
|
||||
|
||||
def add_block(self, x, path=""):
|
||||
"""adds all components inside a gradio block x to the registry of tracked components"""
|
||||
|
||||
if hasattr(x, 'children'):
|
||||
if isinstance(x, gr.Tabs) and x.elem_id is not None:
|
||||
# Tabs element can't have a label, have to use elem_id instead
|
||||
self.add_component(f"{path}/Tabs@{x.elem_id}", x)
|
||||
for c in x.children:
|
||||
self.add_block(c, path)
|
||||
elif x.label is not None:
|
||||
self.add_component(f"{path}/{x.label}", x)
|
||||
|
||||
def read_from_file(self):
|
||||
with open(self.filename, "r", encoding="utf8") as file:
|
||||
return json.load(file)
|
||||
|
||||
def write_to_file(self, current_ui_settings):
|
||||
with open(self.filename, "w", encoding="utf8") as file:
|
||||
json.dump(current_ui_settings, file, indent=4)
|
||||
|
||||
def dump_defaults(self):
|
||||
"""saves default values to a file unless tjhe file is present and there was an error loading default values at start"""
|
||||
|
||||
if self.error_loading and os.path.exists(self.filename):
|
||||
return
|
||||
|
||||
self.write_to_file(self.ui_settings)
|
||||
|
||||
def iter_changes(self, current_ui_settings, values):
|
||||
"""
|
||||
given a dictionary with defaults from a file and current values from gradio elements, returns
|
||||
an iterator over tuples of values that are not the same between the file and the current;
|
||||
tuple contents are: path, old value, new value
|
||||
"""
|
||||
|
||||
for (path, component), new_value in zip(self.component_mapping.items(), values):
|
||||
old_value = current_ui_settings.get(path)
|
||||
|
||||
choices = getattr(component, 'choices', None)
|
||||
if isinstance(new_value, int) and choices:
|
||||
if new_value >= len(choices):
|
||||
continue
|
||||
|
||||
new_value = choices[new_value]
|
||||
|
||||
if new_value == old_value:
|
||||
continue
|
||||
|
||||
if old_value is None and new_value == '' or new_value == []:
|
||||
continue
|
||||
|
||||
yield path, old_value, new_value
|
||||
|
||||
def ui_view(self, *values):
|
||||
text = ["<table><thead><tr><th>Path</th><th>Old value</th><th>New value</th></thead><tbody>"]
|
||||
|
||||
for path, old_value, new_value in self.iter_changes(self.read_from_file(), values):
|
||||
if old_value is None:
|
||||
old_value = "<span class='ui-defaults-none'>None</span>"
|
||||
|
||||
text.append(f"<tr><td>{path}</td><td>{old_value}</td><td>{new_value}</td></tr>")
|
||||
|
||||
if len(text) == 1:
|
||||
text.append("<tr><td colspan=3>No changes</td></tr>")
|
||||
|
||||
text.append("</tbody>")
|
||||
return "".join(text)
|
||||
|
||||
def ui_apply(self, *values):
|
||||
num_changed = 0
|
||||
|
||||
current_ui_settings = self.read_from_file()
|
||||
|
||||
for path, _, new_value in self.iter_changes(current_ui_settings.copy(), values):
|
||||
num_changed += 1
|
||||
current_ui_settings[path] = new_value
|
||||
|
||||
if num_changed == 0:
|
||||
return "No changes."
|
||||
|
||||
self.write_to_file(current_ui_settings)
|
||||
|
||||
return f"Wrote {num_changed} changes."
|
||||
|
||||
def create_ui(self):
|
||||
"""creates ui elements for editing defaults UI, without adding any logic to them"""
|
||||
|
||||
gr.HTML(
|
||||
f"This page allows you to change default values in UI elements on other tabs.<br />"
|
||||
f"Make your changes, press 'View changes' to review the changed default values,<br />"
|
||||
f"then press 'Apply' to write them to {self.filename}.<br />"
|
||||
f"New defaults will apply after you restart the UI.<br />"
|
||||
)
|
||||
|
||||
with gr.Row():
|
||||
self.ui_defaults_view = gr.Button(value='View changes', elem_id="ui_defaults_view", variant="secondary")
|
||||
self.ui_defaults_apply = gr.Button(value='Apply', elem_id="ui_defaults_apply", variant="primary")
|
||||
|
||||
self.ui_defaults_review = gr.HTML("")
|
||||
|
||||
def setup_ui(self):
|
||||
"""adds logic to elements created with create_ui; all add_block class must be made before this"""
|
||||
|
||||
assert not self.finalized_ui
|
||||
self.finalized_ui = True
|
||||
|
||||
self.ui_defaults_view.click(fn=self.ui_view, inputs=list(self.component_mapping.values()), outputs=[self.ui_defaults_review])
|
||||
self.ui_defaults_apply.click(fn=self.ui_apply, inputs=list(self.component_mapping.values()), outputs=[self.ui_defaults_review])
|
||||
@@ -1,5 +1,5 @@
|
||||
import gradio as gr
|
||||
from modules import scripts_postprocessing, scripts, shared, gfpgan_model, codeformer_model, ui_common, postprocessing, call_queue
|
||||
from modules import scripts, shared, ui_common, postprocessing, call_queue
|
||||
import modules.generation_parameters_copypaste as parameters_copypaste
|
||||
|
||||
|
||||
@@ -9,13 +9,13 @@ def create_ui():
|
||||
with gr.Row().style(equal_height=False, variant='compact'):
|
||||
with gr.Column(variant='compact'):
|
||||
with gr.Tabs(elem_id="mode_extras"):
|
||||
with gr.TabItem('Single Image', elem_id="extras_single_tab") as tab_single:
|
||||
with gr.TabItem('Single Image', id="single_image", elem_id="extras_single_tab") as tab_single:
|
||||
extras_image = gr.Image(label="Source", source="upload", interactive=True, type="pil", elem_id="extras_image")
|
||||
|
||||
with gr.TabItem('Batch Process', elem_id="extras_batch_process_tab") as tab_batch:
|
||||
image_batch = gr.File(label="Batch Process", file_count="multiple", interactive=True, type="file", elem_id="extras_image_batch")
|
||||
with gr.TabItem('Batch Process', id="batch_process", elem_id="extras_batch_process_tab") as tab_batch:
|
||||
image_batch = gr.Files(label="Batch Process", interactive=True, elem_id="extras_image_batch")
|
||||
|
||||
with gr.TabItem('Batch from Directory', elem_id="extras_batch_directory_tab") as tab_batch_dir:
|
||||
with gr.TabItem('Batch from Directory', id="batch_from_directory", elem_id="extras_batch_directory_tab") as tab_batch_dir:
|
||||
extras_batch_input_dir = gr.Textbox(label="Input directory", **shared.hide_dirs, placeholder="A directory on the same machine where the server is running.", elem_id="extras_batch_input_dir")
|
||||
extras_batch_output_dir = gr.Textbox(label="Output directory", **shared.hide_dirs, placeholder="Leave blank to save images to the default path.", elem_id="extras_batch_output_dir")
|
||||
show_extras_results = gr.Checkbox(label='Show result images', value=True, elem_id="extras_show_extras_results")
|
||||
|
||||
@@ -23,7 +23,7 @@ def register_tmp_file(gradio, filename):
|
||||
|
||||
def check_tmp_file(gradio, filename):
|
||||
if hasattr(gradio, 'temp_file_sets'):
|
||||
return any([filename in fileset for fileset in gradio.temp_file_sets])
|
||||
return any(filename in fileset for fileset in gradio.temp_file_sets)
|
||||
|
||||
if hasattr(gradio, 'temp_dirs'):
|
||||
return any(Path(temp_dir).resolve() in Path(filename).resolve().parents for temp_dir in gradio.temp_dirs)
|
||||
@@ -36,7 +36,7 @@ def save_pil_to_file(pil_image, dir=None):
|
||||
if already_saved_as and os.path.isfile(already_saved_as):
|
||||
register_tmp_file(shared.demo, already_saved_as)
|
||||
|
||||
file_obj = Savedfile(already_saved_as)
|
||||
file_obj = Savedfile(f'{already_saved_as}?{os.path.getmtime(already_saved_as)}')
|
||||
return file_obj
|
||||
|
||||
if shared.opts.temp_dir != "":
|
||||
@@ -72,7 +72,7 @@ def cleanup_tmpdr():
|
||||
if temp_dir == "" or not os.path.isdir(temp_dir):
|
||||
return
|
||||
|
||||
for root, dirs, files in os.walk(temp_dir, topdown=False):
|
||||
for root, _, files in os.walk(temp_dir, topdown=False):
|
||||
for name in files:
|
||||
_, extension = os.path.splitext(name)
|
||||
if extension != ".png":
|
||||
|
||||
@@ -2,8 +2,6 @@ import os
|
||||
from abc import abstractmethod
|
||||
|
||||
import PIL
|
||||
import numpy as np
|
||||
import torch
|
||||
from PIL import Image
|
||||
|
||||
import modules.shared
|
||||
@@ -43,9 +41,9 @@ class Upscaler:
|
||||
os.makedirs(self.model_path, exist_ok=True)
|
||||
|
||||
try:
|
||||
import cv2
|
||||
import cv2 # noqa: F401
|
||||
self.can_tile = True
|
||||
except:
|
||||
except Exception:
|
||||
pass
|
||||
|
||||
@abstractmethod
|
||||
@@ -57,7 +55,7 @@ class Upscaler:
|
||||
dest_w = int(img.width * scale)
|
||||
dest_h = int(img.height * scale)
|
||||
|
||||
for i in range(3):
|
||||
for _ in range(3):
|
||||
shape = (img.width, img.height)
|
||||
|
||||
img = self.do_upscale(img, selected_model)
|
||||
|
||||
@@ -1,4 +1,4 @@
|
||||
from transformers import BertPreTrainedModel,BertModel,BertConfig
|
||||
from transformers import BertPreTrainedModel, BertConfig
|
||||
import torch.nn as nn
|
||||
import torch
|
||||
from transformers.models.xlm_roberta.configuration_xlm_roberta import XLMRobertaConfig
|
||||
@@ -28,7 +28,7 @@ class BertSeriesModelWithTransformation(BertPreTrainedModel):
|
||||
config_class = BertSeriesConfig
|
||||
|
||||
def __init__(self, config=None, **kargs):
|
||||
# modify initialization for autoloading
|
||||
# modify initialization for autoloading
|
||||
if config is None:
|
||||
config = XLMRobertaConfig()
|
||||
config.attention_probs_dropout_prob= 0.1
|
||||
@@ -74,7 +74,7 @@ class BertSeriesModelWithTransformation(BertPreTrainedModel):
|
||||
text["attention_mask"] = torch.tensor(
|
||||
text['attention_mask']).to(device)
|
||||
features = self(**text)
|
||||
return features['projection_state']
|
||||
return features['projection_state']
|
||||
|
||||
def forward(
|
||||
self,
|
||||
@@ -134,4 +134,4 @@ class BertSeriesModelWithTransformation(BertPreTrainedModel):
|
||||
|
||||
class RobertaSeriesModelWithTransformation(BertSeriesModelWithTransformation):
|
||||
base_model_prefix = 'roberta'
|
||||
config_class= RobertaSeriesConfig
|
||||
config_class= RobertaSeriesConfig
|
||||
|
||||
Reference in New Issue
Block a user