initial commit
BIN  modules/Roboto-Regular.ttf  (executable file; binary file not shown)
880  modules/api/api.py  (executable file)
@@ -0,0 +1,880 @@
import base64
import io
import os
import time
import datetime
import uvicorn
import ipaddress
import requests
import gradio as gr
from threading import Lock
from io import BytesIO
from fastapi import APIRouter, Depends, FastAPI, Request, Response
from fastapi.security import HTTPBasic, HTTPBasicCredentials
from fastapi.exceptions import HTTPException
from fastapi.responses import JSONResponse
from fastapi.encoders import jsonable_encoder
from secrets import compare_digest

import modules.shared as shared
from modules import sd_samplers, deepbooru, images, scripts, ui, postprocessing, errors, restart, shared_items, script_callbacks, infotext_utils, sd_models, sd_schedulers
from modules.api import models
from modules.shared import opts
from modules.processing import StableDiffusionProcessingTxt2Img, StableDiffusionProcessingImg2Img, process_images, process_extra_images
import modules.textual_inversion.textual_inversion
from modules.shared import cmd_opts

from PIL import PngImagePlugin
from modules.realesrgan_model import get_realesrgan_models
from modules import devices
from typing import Any, Union, get_origin, get_args
import piexif
import piexif.helper
from contextlib import closing
from modules.progress import create_task_id, add_task_to_queue, start_task, finish_task, current_task


def script_name_to_index(name, scripts):
    try:
        return [script.title().lower() for script in scripts].index(name.lower())
    except Exception as e:
        raise HTTPException(status_code=422, detail=f"Script '{name}' not found") from e


def validate_sampler_name(name):
    config = sd_samplers.all_samplers_map.get(name, None)
    if config is None:
        raise HTTPException(status_code=400, detail="Sampler not found")

    return name


def setUpscalers(req: dict):
    reqDict = vars(req)
    reqDict['extras_upscaler_1'] = reqDict.pop('upscaler_1', None)
    reqDict['extras_upscaler_2'] = reqDict.pop('upscaler_2', None)
    return reqDict


def verify_url(url):
    """Returns True if the url refers to a global resource."""

    import socket
    from urllib.parse import urlparse
    try:
        parsed_url = urlparse(url)
        domain_name = parsed_url.netloc
        host = socket.gethostbyname_ex(domain_name)
        for ip in host[2]:
            ip_addr = ipaddress.ip_address(ip)
            if not ip_addr.is_global:
                return False
    except Exception:
        return False

    return True


def decode_base64_to_image(encoding):
    if encoding.startswith("http://") or encoding.startswith("https://"):
        if not opts.api_enable_requests:
            raise HTTPException(status_code=500, detail="Requests not allowed")

        if opts.api_forbid_local_requests and not verify_url(encoding):
            raise HTTPException(status_code=500, detail="Request to local resource not allowed")

        headers = {'user-agent': opts.api_useragent} if opts.api_useragent else {}
        response = requests.get(encoding, timeout=30, headers=headers)
        try:
            image = images.read(BytesIO(response.content))
            return image
        except Exception as e:
            raise HTTPException(status_code=500, detail="Invalid image url") from e

    if encoding.startswith("data:image/"):
        encoding = encoding.split(";")[1].split(",")[1]
    try:
        image = images.read(BytesIO(base64.b64decode(encoding)))
        return image
    except Exception as e:
        raise HTTPException(status_code=500, detail="Invalid encoded image") from e


def encode_pil_to_base64(image):
    with io.BytesIO() as output_bytes:
        if isinstance(image, str):
            return image
        if opts.samples_format.lower() == 'png':
            use_metadata = False
            metadata = PngImagePlugin.PngInfo()
            for key, value in image.info.items():
                if isinstance(key, str) and isinstance(value, str):
                    metadata.add_text(key, value)
                    use_metadata = True
            image.save(output_bytes, format="PNG", pnginfo=(metadata if use_metadata else None), quality=opts.jpeg_quality)

        elif opts.samples_format.lower() in ("jpg", "jpeg", "webp"):
            if image.mode in ("RGBA", "P"):
                image = image.convert("RGB")
            parameters = image.info.get('parameters', None)
            exif_bytes = piexif.dump({
                "Exif": {piexif.ExifIFD.UserComment: piexif.helper.UserComment.dump(parameters or "", encoding="unicode")}
            })
            if opts.samples_format.lower() in ("jpg", "jpeg"):
                image.save(output_bytes, format="JPEG", exif=exif_bytes, quality=opts.jpeg_quality)
            else:
                image.save(output_bytes, format="WEBP", exif=exif_bytes, quality=opts.jpeg_quality, lossless=opts.webp_lossless)

        else:
            raise HTTPException(status_code=500, detail="Invalid image format")

        bytes_data = output_bytes.getvalue()

    return base64.b64encode(bytes_data)

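# Example (illustrative, not part of the original file): the two helpers above round-trip
# images for the API. decode_base64_to_image accepts an http(s) URL, a data URI, or a bare
# base64 string, e.g.:
#   pil = decode_base64_to_image("data:image/png;base64,iVBORw0KGgo...")
#   b64 = encode_pil_to_base64(pil)  # bytes; output format follows opts.samples_format
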
def api_middleware(app: FastAPI):
    rich_available = False
    try:
        if os.environ.get('WEBUI_RICH_EXCEPTIONS', None) is not None:
            import anyio  # importing just so it can be placed on silent list
            import starlette  # importing just so it can be placed on silent list
            from rich.console import Console
            console = Console()
            rich_available = True
    except Exception:
        pass

    @app.middleware("http")
    async def log_and_time(req: Request, call_next):
        ts = time.time()
        res: Response = await call_next(req)
        duration = str(round(time.time() - ts, 4))
        res.headers["X-Process-Time"] = duration
        endpoint = req.scope.get('path', 'err')
        if shared.cmd_opts.api_log and endpoint.startswith('/sdapi'):
            print('API {t} {code} {prot}/{ver} {method} {endpoint} {cli} {duration}'.format(
                t=datetime.datetime.now().strftime("%Y-%m-%d %H:%M:%S.%f"),
                code=res.status_code,
                ver=req.scope.get('http_version', '0.0'),
                cli=req.scope.get('client', ('0:0.0.0', 0))[0],
                prot=req.scope.get('scheme', 'err'),
                method=req.scope.get('method', 'err'),
                endpoint=endpoint,
                duration=duration,
            ))
        return res

    def handle_exception(request: Request, e: Exception):
        err = {
            "error": type(e).__name__,
            "detail": vars(e).get('detail', ''),
            "body": vars(e).get('body', ''),
            "errors": str(e),
        }
        if not isinstance(e, HTTPException):  # do not print backtrace on known httpexceptions
            message = f"API error: {request.method}: {request.url} {err}"
            if rich_available:
                print(message)
                console.print_exception(show_locals=True, max_frames=2, extra_lines=1, suppress=[anyio, starlette], word_wrap=False, width=min([console.width, 200]))
            else:
                errors.report(message, exc_info=True)
        return JSONResponse(status_code=vars(e).get('status_code', 500), content=jsonable_encoder(err))

    @app.middleware("http")
    async def exception_handling(request: Request, call_next):
        try:
            return await call_next(request)
        except Exception as e:
            return handle_exception(request, e)

    @app.exception_handler(Exception)
    async def fastapi_exception_handler(request: Request, e: Exception):
        return handle_exception(request, e)

    @app.exception_handler(HTTPException)
    async def http_exception_handler(request: Request, e: HTTPException):
        return handle_exception(request, e)

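# Example (illustrative) of the access-log line emitted by log_and_time above when
# shared.cmd_opts.api_log is set; fields follow the format string in that handler:
#   API 2024-01-01 12:00:00.000000 200 http/1.1 POST /sdapi/v1/txt2img 127.0.0.1 1.2345
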
class Api:
|
||||
def __init__(self, app: FastAPI, queue_lock: Lock):
|
||||
if shared.cmd_opts.api_auth:
|
||||
self.credentials = {}
|
||||
for auth in shared.cmd_opts.api_auth.split(","):
|
||||
user, password = auth.split(":")
|
||||
self.credentials[user] = password
|
||||
|
||||
self.router = APIRouter()
|
||||
self.app = app
|
||||
self.queue_lock = queue_lock
|
||||
#api_middleware(self.app) # FIXME: (legacy) this will have to be fixed
|
||||
self.add_api_route("/sdapi/v1/txt2img", self.text2imgapi, methods=["POST"], response_model=models.TextToImageResponse)
|
||||
self.add_api_route("/sdapi/v1/img2img", self.img2imgapi, methods=["POST"], response_model=models.ImageToImageResponse)
|
||||
self.add_api_route("/sdapi/v1/extra-single-image", self.extras_single_image_api, methods=["POST"], response_model=models.ExtrasSingleImageResponse)
|
||||
self.add_api_route("/sdapi/v1/extra-batch-images", self.extras_batch_images_api, methods=["POST"], response_model=models.ExtrasBatchImagesResponse)
|
||||
self.add_api_route("/sdapi/v1/png-info", self.pnginfoapi, methods=["POST"], response_model=models.PNGInfoResponse)
|
||||
self.add_api_route("/sdapi/v1/progress", self.progressapi, methods=["GET"], response_model=models.ProgressResponse)
|
||||
self.add_api_route("/sdapi/v1/interrogate", self.interrogateapi, methods=["POST"])
|
||||
self.add_api_route("/sdapi/v1/interrupt", self.interruptapi, methods=["POST"])
|
||||
self.add_api_route("/sdapi/v1/skip", self.skip, methods=["POST"])
|
||||
self.add_api_route("/sdapi/v1/options", self.get_config, methods=["GET"], response_model=models.OptionsModel)
|
||||
self.add_api_route("/sdapi/v1/options", self.set_config, methods=["POST"])
|
||||
self.add_api_route("/sdapi/v1/cmd-flags", self.get_cmd_flags, methods=["GET"], response_model=models.FlagsModel)
|
||||
self.add_api_route("/sdapi/v1/samplers", self.get_samplers, methods=["GET"], response_model=list[models.SamplerItem])
|
||||
self.add_api_route("/sdapi/v1/schedulers", self.get_schedulers, methods=["GET"], response_model=list[models.SchedulerItem])
|
||||
self.add_api_route("/sdapi/v1/upscalers", self.get_upscalers, methods=["GET"], response_model=list[models.UpscalerItem])
|
||||
self.add_api_route("/sdapi/v1/latent-upscale-modes", self.get_latent_upscale_modes, methods=["GET"], response_model=list[models.LatentUpscalerModeItem])
|
||||
self.add_api_route("/sdapi/v1/sd-models", self.get_sd_models, methods=["GET"], response_model=list[models.SDModelItem])
|
||||
self.add_api_route("/sdapi/v1/sd-modules", self.get_sd_vaes_and_text_encoders, methods=["GET"], response_model=list[models.SDModuleItem])
|
||||
self.add_api_route("/sdapi/v1/hypernetworks", self.get_hypernetworks, methods=["GET"], response_model=list[models.HypernetworkItem])
|
||||
self.add_api_route("/sdapi/v1/face-restorers", self.get_face_restorers, methods=["GET"], response_model=list[models.FaceRestorerItem])
|
||||
self.add_api_route("/sdapi/v1/realesrgan-models", self.get_realesrgan_models, methods=["GET"], response_model=list[models.RealesrganItem])
|
||||
self.add_api_route("/sdapi/v1/prompt-styles", self.get_prompt_styles, methods=["GET"], response_model=list[models.PromptStyleItem])
|
||||
self.add_api_route("/sdapi/v1/embeddings", self.get_embeddings, methods=["GET"], response_model=models.EmbeddingsResponse)
|
||||
self.add_api_route("/sdapi/v1/refresh-embeddings", self.refresh_embeddings, methods=["POST"])
|
||||
self.add_api_route("/sdapi/v1/refresh-checkpoints", self.refresh_checkpoints, methods=["POST"])
|
||||
self.add_api_route("/sdapi/v1/refresh-vae", self.refresh_vae, methods=["POST"])
|
||||
self.add_api_route("/sdapi/v1/create/embedding", self.create_embedding, methods=["POST"], response_model=models.CreateResponse)
|
||||
self.add_api_route("/sdapi/v1/create/hypernetwork", self.create_hypernetwork, methods=["POST"], response_model=models.CreateResponse)
|
||||
self.add_api_route("/sdapi/v1/memory", self.get_memory, methods=["GET"], response_model=models.MemoryResponse)
|
||||
self.add_api_route("/sdapi/v1/unload-checkpoint", self.unloadapi, methods=["POST"])
|
||||
self.add_api_route("/sdapi/v1/reload-checkpoint", self.reloadapi, methods=["POST"])
|
||||
self.add_api_route("/sdapi/v1/scripts", self.get_scripts_list, methods=["GET"], response_model=models.ScriptsList)
|
||||
self.add_api_route("/sdapi/v1/script-info", self.get_script_info, methods=["GET"], response_model=list[models.ScriptInfo])
|
||||
self.add_api_route("/sdapi/v1/extensions", self.get_extensions_list, methods=["GET"], response_model=list[models.ExtensionItem])
|
||||
|
||||
if shared.cmd_opts.api_server_stop:
|
||||
self.add_api_route("/sdapi/v1/server-kill", self.kill_webui, methods=["POST"])
|
||||
self.add_api_route("/sdapi/v1/server-restart", self.restart_webui, methods=["POST"])
|
||||
self.add_api_route("/sdapi/v1/server-stop", self.stop_webui, methods=["POST"])
|
||||
|
||||
self.default_script_arg_txt2img = []
|
||||
self.default_script_arg_img2img = []
|
||||
|
||||
txt2img_script_runner = scripts.scripts_txt2img
|
||||
img2img_script_runner = scripts.scripts_img2img
|
||||
|
||||
if not txt2img_script_runner.scripts or not img2img_script_runner.scripts:
|
||||
ui.create_ui()
|
||||
|
||||
if not txt2img_script_runner.scripts:
|
||||
txt2img_script_runner.initialize_scripts(False)
|
||||
if not self.default_script_arg_txt2img:
|
||||
self.default_script_arg_txt2img = self.init_default_script_args(txt2img_script_runner)
|
||||
|
||||
if not img2img_script_runner.scripts:
|
||||
img2img_script_runner.initialize_scripts(True)
|
||||
if not self.default_script_arg_img2img:
|
||||
self.default_script_arg_img2img = self.init_default_script_args(img2img_script_runner)
|
||||
|
||||
self.embedding_db = modules.textual_inversion.textual_inversion.EmbeddingDatabase()
|
||||
self.embedding_db.add_embedding_dir(cmd_opts.embeddings_dir)
|
||||
self.embedding_db.load_textual_inversion_embeddings(force_reload=True, sync_with_sd_model=False)
|
||||
|
||||
|
||||
|
||||
def add_api_route(self, path: str, endpoint, **kwargs):
|
||||
if shared.cmd_opts.api_auth:
|
||||
return self.app.add_api_route(path, endpoint, dependencies=[Depends(self.auth)], **kwargs)
|
||||
return self.app.add_api_route(path, endpoint, **kwargs)
|
||||
|
||||
def auth(self, credentials: HTTPBasicCredentials = Depends(HTTPBasic())):
|
||||
if credentials.username in self.credentials:
|
||||
if compare_digest(credentials.password, self.credentials[credentials.username]):
|
||||
return True
|
||||
|
||||
raise HTTPException(status_code=401, detail="Incorrect username or password", headers={"WWW-Authenticate": "Basic"})
|
||||
|
||||
def get_selectable_script(self, script_name, script_runner):
|
||||
if script_name is None or script_name == "":
|
||||
return None, None
|
||||
|
||||
script_idx = script_name_to_index(script_name, script_runner.selectable_scripts)
|
||||
script = script_runner.selectable_scripts[script_idx]
|
||||
return script, script_idx
|
||||
|
||||
def get_scripts_list(self):
|
||||
t2ilist = [script.name for script in scripts.scripts_txt2img.scripts if script.name is not None]
|
||||
i2ilist = [script.name for script in scripts.scripts_img2img.scripts if script.name is not None]
|
||||
|
||||
return models.ScriptsList(txt2img=t2ilist, img2img=i2ilist)
|
||||
|
||||
def get_script_info(self):
|
||||
res = []
|
||||
|
||||
for script_list in [scripts.scripts_txt2img.scripts, scripts.scripts_img2img.scripts]:
|
||||
res += [script.api_info for script in script_list if script.api_info is not None]
|
||||
|
||||
return res
|
||||
|
||||
def get_script(self, script_name, script_runner):
|
||||
if script_name is None or script_name == "":
|
||||
return None, None
|
||||
|
||||
script_idx = script_name_to_index(script_name, script_runner.scripts)
|
||||
return script_runner.scripts[script_idx]
|
||||
|
||||
def init_default_script_args(self, script_runner):
|
||||
#find max idx from the scripts in runner and generate a none array to init script_args
|
||||
last_arg_index = 1
|
||||
for script in script_runner.scripts:
|
||||
if last_arg_index < script.args_to:
|
||||
last_arg_index = script.args_to
|
||||
# None everywhere except position 0 to initialize script args
|
||||
script_args = [None]*last_arg_index
|
||||
script_args[0] = 0
|
||||
|
||||
# get default values
|
||||
with gr.Blocks(): # will throw errors calling ui function without this
|
||||
for script in script_runner.scripts:
|
||||
if script.ui(script.is_img2img):
|
||||
ui_default_values = []
|
||||
for elem in script.ui(script.is_img2img):
|
||||
ui_default_values.append(elem.value)
|
||||
script_args[script.args_from:script.args_to] = ui_default_values
|
||||
return script_args
|
||||
|
||||
def init_script_args(self, request, default_script_args, selectable_scripts, selectable_idx, script_runner, *, input_script_args=None):
|
||||
script_args = default_script_args.copy()
|
||||
|
||||
if input_script_args is not None:
|
||||
for index, value in input_script_args.items():
|
||||
script_args[index] = value
|
||||
|
||||
# position 0 in script_arg is the idx+1 of the selectable script that is going to be run when using scripts.scripts_*2img.run()
|
||||
if selectable_scripts:
|
||||
script_args[selectable_scripts.args_from:selectable_scripts.args_to] = request.script_args
|
||||
script_args[0] = selectable_idx + 1
|
||||
|
||||
# Now check for always on scripts
|
||||
if request.alwayson_scripts:
|
||||
for alwayson_script_name in request.alwayson_scripts.keys():
|
||||
alwayson_script = self.get_script(alwayson_script_name, script_runner)
|
||||
if alwayson_script is None:
|
||||
raise HTTPException(status_code=422, detail=f"always on script {alwayson_script_name} not found")
|
||||
# Selectable script in always on script param check
|
||||
if alwayson_script.alwayson is False:
|
||||
raise HTTPException(status_code=422, detail="Cannot have a selectable script in the always on scripts params")
|
||||
# always on script with no arg should always run so you don't really need to add them to the requests
|
||||
if "args" in request.alwayson_scripts[alwayson_script_name]:
|
||||
# min between arg length in scriptrunner and arg length in the request
|
||||
for idx in range(0, min((alwayson_script.args_to - alwayson_script.args_from), len(request.alwayson_scripts[alwayson_script_name]["args"]))):
|
||||
script_args[alwayson_script.args_from + idx] = request.alwayson_scripts[alwayson_script_name]["args"][idx]
|
||||
return script_args
|
||||
|
||||
def apply_infotext(self, request, tabname, *, script_runner=None, mentioned_script_args=None):
|
||||
"""Processes `infotext` field from the `request`, and sets other fields of the `request` according to what's in infotext.
|
||||
|
||||
If request already has a field set, and that field is encountered in infotext too, the value from infotext is ignored.
|
||||
|
||||
Additionally, fills `mentioned_script_args` dict with index: value pairs for script arguments read from infotext.
|
||||
"""
|
||||
|
||||
if not request.infotext:
|
||||
return {}
|
||||
|
||||
possible_fields = infotext_utils.paste_fields[tabname]["fields"]
|
||||
set_fields = request.model_dump(exclude_unset=True) if hasattr(request, "request") else request.dict(exclude_unset=True) # pydantic v1/v2 have different names for this
|
||||
params = infotext_utils.parse_generation_parameters(request.infotext)
|
||||
|
||||
def get_base_type(annotation):
|
||||
origin = get_origin(annotation)
|
||||
|
||||
if origin is Union: # represents Optional
|
||||
args = get_args(annotation) # filter out NoneType
|
||||
non_none_args = [arg for arg in args if arg is not type(None)]
|
||||
if len(non_none_args) == 1: # annotation was Optional[X]
|
||||
return non_none_args[0]
|
||||
|
||||
return annotation
|
||||
|
||||
def get_field_value(field, params):
|
||||
value = field.function(params) if field.function else params.get(field.label)
|
||||
if value is None:
|
||||
return None
|
||||
|
||||
if field.api in request.__fields__:
|
||||
target_type = get_base_type(request.__fields__[field.api].annotation) # extract type from Optional[X]
|
||||
else:
|
||||
target_type = type(field.component.value)
|
||||
|
||||
if target_type == type(None):
|
||||
return None
|
||||
|
||||
if isinstance(value, dict) and value.get('__type__') == 'generic_update': # this is a gradio.update rather than a value
|
||||
value = value.get('value')
|
||||
|
||||
if value is not None and not isinstance(value, target_type):
|
||||
value = target_type(value)
|
||||
|
||||
return value
|
||||
|
||||
for field in possible_fields:
|
||||
if not field.api:
|
||||
continue
|
||||
|
||||
if field.api in set_fields:
|
||||
continue
|
||||
|
||||
value = get_field_value(field, params)
|
||||
if value is not None:
|
||||
setattr(request, field.api, value)
|
||||
|
||||
if request.override_settings is None:
|
||||
request.override_settings = {}
|
||||
|
||||
overridden_settings = infotext_utils.get_override_settings(params)
|
||||
for _, setting_name, value in overridden_settings:
|
||||
if setting_name not in request.override_settings:
|
||||
request.override_settings[setting_name] = value
|
||||
|
||||
if script_runner is not None and mentioned_script_args is not None:
|
||||
indexes = {v: i for i, v in enumerate(script_runner.inputs)}
|
||||
script_fields = ((field, indexes[field.component]) for field in possible_fields if field.component in indexes)
|
||||
|
||||
for field, index in script_fields:
|
||||
value = get_field_value(field, params)
|
||||
|
||||
if value is None:
|
||||
continue
|
||||
|
||||
mentioned_script_args[index] = value
|
||||
|
||||
return params
|
||||
|
||||
def text2imgapi(self, txt2imgreq: models.StableDiffusionTxt2ImgProcessingAPI):
|
||||
task_id = txt2imgreq.force_task_id or create_task_id("txt2img")
|
||||
|
||||
script_runner = scripts.scripts_txt2img
|
||||
|
||||
infotext_script_args = {}
|
||||
self.apply_infotext(txt2imgreq, "txt2img", script_runner=script_runner, mentioned_script_args=infotext_script_args)
|
||||
|
||||
selectable_scripts, selectable_script_idx = self.get_selectable_script(txt2imgreq.script_name, script_runner)
|
||||
sampler, scheduler = sd_samplers.get_sampler_and_scheduler(txt2imgreq.sampler_name or txt2imgreq.sampler_index, txt2imgreq.scheduler)
|
||||
|
||||
populate = txt2imgreq.copy(update={ # Override __init__ params
|
||||
"sampler_name": validate_sampler_name(sampler),
|
||||
"do_not_save_samples": not txt2imgreq.save_images,
|
||||
"do_not_save_grid": not txt2imgreq.save_images,
|
||||
})
|
||||
if populate.sampler_name:
|
||||
populate.sampler_index = None # prevent a warning later on
|
||||
|
||||
if not populate.scheduler and scheduler != "Automatic":
|
||||
populate.scheduler = scheduler
|
||||
|
||||
args = vars(populate)
|
||||
args.pop('script_name', None)
|
||||
args.pop('script_args', None) # will refeed them to the pipeline directly after initializing them
|
||||
args.pop('alwayson_scripts', None)
|
||||
args.pop('infotext', None)
|
||||
|
||||
script_args = self.init_script_args(txt2imgreq, self.default_script_arg_txt2img, selectable_scripts, selectable_script_idx, script_runner, input_script_args=infotext_script_args)
|
||||
|
||||
send_images = args.pop('send_images', True)
|
||||
args.pop('save_images', None)
|
||||
|
||||
add_task_to_queue(task_id)
|
||||
|
||||
with self.queue_lock:
|
||||
with closing(StableDiffusionProcessingTxt2Img(sd_model=shared.sd_model, **args)) as p:
|
||||
p.is_api = True
|
||||
p.scripts = script_runner
|
||||
p.outpath_grids = opts.outdir_txt2img_grids
|
||||
p.outpath_samples = opts.outdir_txt2img_samples
|
||||
|
||||
try:
|
||||
shared.state.begin(job="scripts_txt2img")
|
||||
start_task(task_id)
|
||||
if selectable_scripts is not None:
|
||||
p.script_args = script_args
|
||||
processed = scripts.scripts_txt2img.run(p, *p.script_args) # Need to pass args as list here
|
||||
else:
|
||||
p.script_args = tuple(script_args) # Need to pass args as tuple here
|
||||
processed = process_images(p)
|
||||
process_extra_images(processed)
|
||||
finish_task(task_id)
|
||||
finally:
|
||||
shared.state.end()
|
||||
shared.total_tqdm.clear()
|
||||
|
||||
b64images = list(map(encode_pil_to_base64, processed.images + processed.extra_images)) if send_images else []
|
||||
|
||||
return models.TextToImageResponse(images=b64images, parameters=vars(txt2imgreq), info=processed.js())
|
||||
|
||||
def img2imgapi(self, img2imgreq: models.StableDiffusionImg2ImgProcessingAPI):
|
||||
task_id = img2imgreq.force_task_id or create_task_id("img2img")
|
||||
|
||||
init_images = img2imgreq.init_images
|
||||
if init_images is None:
|
||||
raise HTTPException(status_code=404, detail="Init image not found")
|
||||
|
||||
mask = img2imgreq.mask
|
||||
if mask:
|
||||
mask = decode_base64_to_image(mask)
|
||||
|
||||
script_runner = scripts.scripts_img2img
|
||||
|
||||
infotext_script_args = {}
|
||||
self.apply_infotext(img2imgreq, "img2img", script_runner=script_runner, mentioned_script_args=infotext_script_args)
|
||||
|
||||
selectable_scripts, selectable_script_idx = self.get_selectable_script(img2imgreq.script_name, script_runner)
|
||||
sampler, scheduler = sd_samplers.get_sampler_and_scheduler(img2imgreq.sampler_name or img2imgreq.sampler_index, img2imgreq.scheduler)
|
||||
|
||||
populate = img2imgreq.copy(update={ # Override __init__ params
|
||||
"sampler_name": validate_sampler_name(sampler),
|
||||
"do_not_save_samples": not img2imgreq.save_images,
|
||||
"do_not_save_grid": not img2imgreq.save_images,
|
||||
"mask": mask,
|
||||
})
|
||||
if populate.sampler_name:
|
||||
populate.sampler_index = None # prevent a warning later on
|
||||
|
||||
if not populate.scheduler and scheduler != "Automatic":
|
||||
populate.scheduler = scheduler
|
||||
|
||||
args = vars(populate)
|
||||
        args.pop('include_init_images', None)  # this is meant to be handled by "exclude": True in the model, but that does not work for a reason I cannot determine
|
||||
args.pop('script_name', None)
|
||||
args.pop('script_args', None) # will refeed them to the pipeline directly after initializing them
|
||||
args.pop('alwayson_scripts', None)
|
||||
args.pop('infotext', None)
|
||||
|
||||
script_args = self.init_script_args(img2imgreq, self.default_script_arg_img2img, selectable_scripts, selectable_script_idx, script_runner, input_script_args=infotext_script_args)
|
||||
|
||||
send_images = args.pop('send_images', True)
|
||||
args.pop('save_images', None)
|
||||
|
||||
add_task_to_queue(task_id)
|
||||
|
||||
with self.queue_lock:
|
||||
with closing(StableDiffusionProcessingImg2Img(sd_model=shared.sd_model, **args)) as p:
|
||||
p.init_images = [decode_base64_to_image(x) for x in init_images]
|
||||
p.is_api = True
|
||||
p.scripts = script_runner
|
||||
p.outpath_grids = opts.outdir_img2img_grids
|
||||
p.outpath_samples = opts.outdir_img2img_samples
|
||||
|
||||
try:
|
||||
shared.state.begin(job="scripts_img2img")
|
||||
start_task(task_id)
|
||||
if selectable_scripts is not None:
|
||||
p.script_args = script_args
|
||||
processed = scripts.scripts_img2img.run(p, *p.script_args) # Need to pass args as list here
|
||||
else:
|
||||
p.script_args = tuple(script_args) # Need to pass args as tuple here
|
||||
processed = process_images(p)
|
||||
process_extra_images(processed)
|
||||
finish_task(task_id)
|
||||
finally:
|
||||
shared.state.end()
|
||||
shared.total_tqdm.clear()
|
||||
|
||||
b64images = list(map(encode_pil_to_base64, processed.images + processed.extra_images)) if send_images else []
|
||||
|
||||
if not img2imgreq.include_init_images:
|
||||
img2imgreq.init_images = None
|
||||
img2imgreq.mask = None
|
||||
|
||||
return models.ImageToImageResponse(images=b64images, parameters=vars(img2imgreq), info=processed.js())
|
||||
|
||||
def extras_single_image_api(self, req: models.ExtrasSingleImageRequest):
|
||||
reqDict = setUpscalers(req)
|
||||
|
||||
reqDict['image'] = decode_base64_to_image(reqDict['image'])
|
||||
|
||||
with self.queue_lock:
|
||||
result = postprocessing.run_extras(extras_mode=0, image_folder="", input_dir="", output_dir="", save_output=False, **reqDict)
|
||||
|
||||
return models.ExtrasSingleImageResponse(image=encode_pil_to_base64(result[0][0]), html_info=result[1])
|
||||
|
||||
def extras_batch_images_api(self, req: models.ExtrasBatchImagesRequest):
|
||||
reqDict = setUpscalers(req)
|
||||
|
||||
image_list = reqDict.pop('imageList', [])
|
||||
image_folder = [decode_base64_to_image(x.data) for x in image_list]
|
||||
|
||||
with self.queue_lock:
|
||||
result = postprocessing.run_extras(extras_mode=1, image_folder=image_folder, image="", input_dir="", output_dir="", save_output=False, **reqDict)
|
||||
|
||||
return models.ExtrasBatchImagesResponse(images=list(map(encode_pil_to_base64, result[0])), html_info=result[1])
|
||||
|
||||
def pnginfoapi(self, req: models.PNGInfoRequest):
|
||||
image = decode_base64_to_image(req.image.strip())
|
||||
if image is None:
|
||||
return models.PNGInfoResponse(info="")
|
||||
|
||||
geninfo, items = images.read_info_from_image(image)
|
||||
if geninfo is None:
|
||||
geninfo = ""
|
||||
|
||||
params = infotext_utils.parse_generation_parameters(geninfo)
|
||||
script_callbacks.infotext_pasted_callback(geninfo, params)
|
||||
|
||||
return models.PNGInfoResponse(info=geninfo, items=items, parameters=params)
|
||||
|
||||
def progressapi(self, req: models.ProgressRequest = Depends()):
|
||||
# copy from check_progress_call of ui.py
|
||||
|
||||
if shared.state.job_count == 0:
|
||||
return models.ProgressResponse(progress=0, eta_relative=0, state=shared.state.dict(), textinfo=shared.state.textinfo)
|
||||
|
||||
# avoid dividing zero
|
||||
progress = 0.01
|
||||
|
||||
if shared.state.job_count > 0:
|
||||
progress += shared.state.job_no / shared.state.job_count
|
||||
if shared.state.sampling_steps > 0:
|
||||
progress += 1 / shared.state.job_count * shared.state.sampling_step / shared.state.sampling_steps
|
||||
|
||||
time_since_start = time.time() - shared.state.time_start
|
||||
eta = (time_since_start/progress)
|
||||
eta_relative = eta-time_since_start
|
||||
|
||||
progress = min(progress, 1)
|
||||
|
||||
shared.state.set_current_image()
|
||||
|
||||
current_image = None
|
||||
if shared.state.current_image and not req.skip_current_image:
|
||||
current_image = encode_pil_to_base64(shared.state.current_image)
|
||||
|
||||
return models.ProgressResponse(progress=progress, eta_relative=eta_relative, state=shared.state.dict(), current_image=current_image, textinfo=shared.state.textinfo, current_task=current_task)
|
||||
|
||||
def interrogateapi(self, interrogatereq: models.InterrogateRequest):
|
||||
image_b64 = interrogatereq.image
|
||||
if image_b64 is None:
|
||||
raise HTTPException(status_code=404, detail="Image not found")
|
||||
|
||||
img = decode_base64_to_image(image_b64)
|
||||
img = img.convert('RGB')
|
||||
|
||||
# Override object param
|
||||
with self.queue_lock:
|
||||
if interrogatereq.model == "clip":
|
||||
processed = shared.interrogator.interrogate(img)
|
||||
elif interrogatereq.model == "deepdanbooru":
|
||||
processed = deepbooru.model.tag(img)
|
||||
else:
|
||||
raise HTTPException(status_code=404, detail="Model not found")
|
||||
|
||||
return models.InterrogateResponse(caption=processed)
|
||||
|
||||
def interruptapi(self):
|
||||
shared.state.interrupt()
|
||||
|
||||
return {}
|
||||
|
||||
def unloadapi(self):
|
||||
sd_models.unload_model_weights()
|
||||
|
||||
return {}
|
||||
|
||||
def reloadapi(self):
|
||||
sd_models.send_model_to_device(shared.sd_model)
|
||||
|
||||
return {}
|
||||
|
||||
def skip(self):
|
||||
shared.state.skip()
|
||||
|
||||
def get_config(self):
|
||||
from modules.sysinfo import get_config
|
||||
return get_config()
|
||||
|
||||
def set_config(self, req: dict[str, Any]):
|
||||
from modules.sysinfo import set_config
|
||||
set_config(req)
|
||||
|
||||
def get_cmd_flags(self):
|
||||
return vars(shared.cmd_opts)
|
||||
|
||||
def get_samplers(self):
|
||||
return [{"name": sampler[0], "aliases":sampler[2], "options":sampler[3]} for sampler in sd_samplers.all_samplers]
|
||||
|
||||
def get_schedulers(self):
|
||||
return [
|
||||
{
|
||||
"name": scheduler.name,
|
||||
"label": scheduler.label,
|
||||
"aliases": scheduler.aliases,
|
||||
"default_rho": scheduler.default_rho,
|
||||
"need_inner_model": scheduler.need_inner_model,
|
||||
}
|
||||
for scheduler in sd_schedulers.schedulers]
|
||||
|
||||
def get_upscalers(self):
|
||||
return [
|
||||
{
|
||||
"name": upscaler.name,
|
||||
"model_name": upscaler.scaler.model_name,
|
||||
"model_path": upscaler.data_path,
|
||||
"model_url": None,
|
||||
"scale": upscaler.scale,
|
||||
}
|
||||
for upscaler in shared.sd_upscalers
|
||||
]
|
||||
|
||||
def get_latent_upscale_modes(self):
|
||||
return [
|
||||
{
|
||||
"name": upscale_mode,
|
||||
}
|
||||
for upscale_mode in [*(shared.latent_upscale_modes or {})]
|
||||
]
|
||||
|
||||
def get_sd_models(self):
|
||||
import modules.sd_models as sd_models
|
||||
return [{"title": x.title, "model_name": x.model_name, "hash": x.shorthash, "sha256": x.sha256, "filename": x.filename, "config": getattr(x, 'config', None)} for x in sd_models.checkpoints_list.values()]
|
||||
|
||||
def get_sd_vaes_and_text_encoders(self):
|
||||
from modules_forge.main_entry import module_list
|
||||
return [{"model_name": x, "filename": module_list[x]} for x in module_list.keys()]
|
||||
|
||||
def get_hypernetworks(self):
|
||||
return [{"name": name, "path": shared.hypernetworks[name]} for name in shared.hypernetworks]
|
||||
|
||||
def get_face_restorers(self):
|
||||
return [{"name":x.name(), "cmd_dir": getattr(x, "cmd_dir", None)} for x in shared.face_restorers]
|
||||
|
||||
def get_realesrgan_models(self):
|
||||
return [{"name":x.name,"path":x.data_path, "scale":x.scale} for x in get_realesrgan_models(None)]
|
||||
|
||||
def get_prompt_styles(self):
|
||||
styleList = []
|
||||
for k in shared.prompt_styles.styles:
|
||||
style = shared.prompt_styles.styles[k]
|
||||
styleList.append({"name":style[0], "prompt": style[1], "negative_prompt": style[2]})
|
||||
|
||||
return styleList
|
||||
|
||||
def get_embeddings(self):
|
||||
def convert_embedding(embedding):
|
||||
return {
|
||||
"step": embedding.step,
|
||||
"sd_checkpoint": embedding.sd_checkpoint,
|
||||
"sd_checkpoint_name": embedding.sd_checkpoint_name,
|
||||
"shape": embedding.shape,
|
||||
"vectors": embedding.vectors,
|
||||
}
|
||||
|
||||
def convert_embeddings(embeddings):
|
||||
return {embedding.name: convert_embedding(embedding) for embedding in embeddings.values()}
|
||||
|
||||
return {
|
||||
"loaded": convert_embeddings(self.embedding_db.word_embeddings),
|
||||
"skipped": convert_embeddings(self.embedding_db.skipped_embeddings),
|
||||
}
|
||||
|
||||
def refresh_embeddings(self):
|
||||
with self.queue_lock:
|
||||
self.embedding_db.load_textual_inversion_embeddings(force_reload=True, sync_with_sd_model=False)
|
||||
|
||||
def refresh_checkpoints(self):
|
||||
with self.queue_lock:
|
||||
shared.refresh_checkpoints()
|
||||
|
||||
def refresh_vae(self):
|
||||
with self.queue_lock:
|
||||
shared_items.refresh_vae_list()
|
||||
|
||||
def create_embedding(self, args: dict):
|
||||
try:
|
||||
shared.state.begin(job="create_embedding")
|
||||
filename = modules.textual_inversion.textual_inversion.create_embedding(**args) # create empty embedding
|
||||
self.embedding_db.load_textual_inversion_embeddings(force_reload=True, sync_with_sd_model=False) # reload embeddings so new one can be immediately used
|
||||
return models.CreateResponse(info=f"create embedding filename: {filename}")
|
||||
except AssertionError as e:
|
||||
return models.TrainResponse(info=f"create embedding error: {e}")
|
||||
finally:
|
||||
shared.state.end()
|
||||
|
||||
def create_hypernetwork(self, args: dict):
|
||||
try:
|
||||
shared.state.begin(job="create_hypernetwork")
|
||||
            filename = create_hypernetwork(**args)  # create empty hypernetwork
|
||||
return models.CreateResponse(info=f"create hypernetwork filename: {filename}")
|
||||
except AssertionError as e:
|
||||
return models.TrainResponse(info=f"create hypernetwork error: {e}")
|
||||
finally:
|
||||
shared.state.end()
|
||||
|
||||
def get_memory(self):
|
||||
try:
|
||||
import os
|
||||
import psutil
|
||||
process = psutil.Process(os.getpid())
|
||||
            res = process.memory_info()  # only rss is cross-platform guaranteed, so we don't rely on other values
            ram_total = 100 * res.rss / process.memory_percent()  # total memory is estimated this way because the actual value is not cross-platform safe
|
||||
ram = { 'free': ram_total - res.rss, 'used': res.rss, 'total': ram_total }
|
||||
except Exception as err:
|
||||
ram = { 'error': f'{err}' }
|
||||
try:
|
||||
import torch
|
||||
if torch.cuda.is_available():
|
||||
s = torch.cuda.mem_get_info()
|
||||
system = { 'free': s[0], 'used': s[1] - s[0], 'total': s[1] }
|
||||
s = dict(torch.cuda.memory_stats(shared.device))
|
||||
allocated = { 'current': s['allocated_bytes.all.current'], 'peak': s['allocated_bytes.all.peak'] }
|
||||
reserved = { 'current': s['reserved_bytes.all.current'], 'peak': s['reserved_bytes.all.peak'] }
|
||||
active = { 'current': s['active_bytes.all.current'], 'peak': s['active_bytes.all.peak'] }
|
||||
inactive = { 'current': s['inactive_split_bytes.all.current'], 'peak': s['inactive_split_bytes.all.peak'] }
|
||||
warnings = { 'retries': s['num_alloc_retries'], 'oom': s['num_ooms'] }
|
||||
cuda = {
|
||||
'system': system,
|
||||
'active': active,
|
||||
'allocated': allocated,
|
||||
'reserved': reserved,
|
||||
'inactive': inactive,
|
||||
'events': warnings,
|
||||
}
|
||||
else:
|
||||
cuda = {'error': 'unavailable'}
|
||||
except Exception as err:
|
||||
cuda = {'error': f'{err}'}
|
||||
return models.MemoryResponse(ram=ram, cuda=cuda)
|
||||
|
||||
def get_extensions_list(self):
|
||||
from modules import extensions
|
||||
extensions.list_extensions()
|
||||
ext_list = []
|
||||
for ext in extensions.extensions:
|
||||
ext: extensions.Extension
|
||||
ext.read_info_from_repo()
|
||||
if ext.remote is not None:
|
||||
ext_list.append({
|
||||
"name": ext.name,
|
||||
"remote": ext.remote,
|
||||
"branch": ext.branch,
|
||||
"commit_hash":ext.commit_hash,
|
||||
"commit_date":ext.commit_date,
|
||||
"version":ext.version,
|
||||
"enabled":ext.enabled
|
||||
})
|
||||
return ext_list
|
||||
|
||||
def launch(self, server_name, port, root_path):
|
||||
self.app.include_router(self.router)
|
||||
uvicorn.run(
|
||||
self.app,
|
||||
host=server_name,
|
||||
port=port,
|
||||
timeout_keep_alive=shared.cmd_opts.timeout_keep_alive,
|
||||
root_path=root_path,
|
||||
ssl_keyfile=shared.cmd_opts.tls_keyfile,
|
||||
ssl_certfile=shared.cmd_opts.tls_certfile
|
||||
)
|
||||
|
||||
def kill_webui(self):
|
||||
restart.stop_program()
|
||||
|
||||
def restart_webui(self):
|
||||
if restart.is_restartable():
|
||||
restart.restart_program()
|
||||
return Response(status_code=501)
|
||||
|
||||
def stop_webui(request):
|
||||
shared.state.server_command = "stop"
|
||||
return Response("Stopping.")
|
||||
|
||||
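A minimal client sketch for the txt2img route registered above. This is illustrative only: it assumes the server was started locally with the API enabled on the default port 7860, and the field names follow the request/response models defined in modules/api/models.py below.

import base64
import requests

payload = {
    "prompt": "a photo of a cat",
    "steps": 20,
    "sampler_name": "Euler",
    "width": 512,
    "height": 512,
    "send_images": True,
    "save_images": False,
}
resp = requests.post("http://127.0.0.1:7860/sdapi/v1/txt2img", json=payload, timeout=300)
resp.raise_for_status()
data = resp.json()  # TextToImageResponse: {"images": [...], "parameters": {...}, "info": "..."}
with open("txt2img_result.png", "wb") as f:
    f.write(base64.b64decode(data["images"][0]))  # images are base64-encoded strings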
336  modules/api/models.py  (executable file)
@@ -0,0 +1,336 @@
|
||||
import inspect
|
||||
|
||||
from pydantic import BaseModel, Field, create_model, ConfigDict
|
||||
from typing import Any, Optional, Literal
|
||||
from inflection import underscore
|
||||
from modules.processing import StableDiffusionProcessingTxt2Img, StableDiffusionProcessingImg2Img
|
||||
from modules.shared import sd_upscalers, opts, parser
|
||||
|
||||
API_NOT_ALLOWED = [
|
||||
"self",
|
||||
"kwargs",
|
||||
"sd_model",
|
||||
"outpath_samples",
|
||||
"outpath_grids",
|
||||
"sampler_index",
|
||||
# "do_not_save_samples",
|
||||
# "do_not_save_grid",
|
||||
"extra_generation_params",
|
||||
"overlay_images",
|
||||
"do_not_reload_embeddings",
|
||||
"seed_enable_extras",
|
||||
"prompt_for_display",
|
||||
"sampler_noise_scheduler_override",
|
||||
"ddim_discretize"
|
||||
]
|
||||
|
||||
class ModelDef(BaseModel):
|
||||
"""Assistance Class for Pydantic Dynamic Model Generation"""
|
||||
|
||||
field: str
|
||||
field_alias: str
|
||||
field_type: Any
|
||||
field_value: Any
|
||||
field_exclude: bool = False
|
||||
|
||||
|
||||
class PydanticModelGenerator:
|
||||
"""
|
||||
Takes in created classes and stubs them out in a way FastAPI/Pydantic is happy about:
|
||||
source_data is a snapshot of the default values produced by the class
|
||||
params are the names of the actual keys required by __init__
|
||||
"""
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
model_name: str = None,
|
||||
class_instance = None,
|
||||
additional_fields = None,
|
||||
):
|
||||
def field_type_generator(k, v):
|
||||
field_type = v.annotation
|
||||
|
||||
if field_type == 'Image':
|
||||
# images are sent as base64 strings via API
|
||||
field_type = 'str'
|
||||
|
||||
return Optional[field_type]
|
||||
|
||||
def merge_class_params(class_):
|
||||
all_classes = list(filter(lambda x: x is not object, inspect.getmro(class_)))
|
||||
parameters = {}
|
||||
for classes in all_classes:
|
||||
parameters = {**parameters, **inspect.signature(classes.__init__).parameters}
|
||||
return parameters
|
||||
|
||||
self._model_name = model_name
|
||||
self._class_data = merge_class_params(class_instance)
|
||||
|
||||
self._model_def = [
|
||||
ModelDef(
|
||||
field=underscore(k),
|
||||
field_alias=k,
|
||||
field_type=field_type_generator(k, v),
|
||||
field_value=None if isinstance(v.default, property) else v.default
|
||||
)
|
||||
for (k,v) in self._class_data.items() if k not in API_NOT_ALLOWED
|
||||
]
|
||||
|
||||
for fields in additional_fields:
|
||||
self._model_def.append(ModelDef(
|
||||
field=underscore(fields["key"]),
|
||||
field_alias=fields["key"],
|
||||
field_type=fields["type"],
|
||||
field_value=fields["default"],
|
||||
field_exclude=fields["exclude"] if "exclude" in fields else False))
|
||||
|
||||
def generate_model(self):
|
||||
"""
|
||||
Creates a pydantic BaseModel
|
||||
from the json and overrides provided at initialization
|
||||
"""
|
||||
fields = {
|
||||
d.field: (d.field_type, Field(default=d.field_value, alias=d.field_alias, exclude=d.field_exclude)) for d in self._model_def
|
||||
}
|
||||
DynamicModel = create_model(self._model_name, __config__=ConfigDict(populate_by_name=True, frozen=False), **fields)
|
||||
return DynamicModel
|
||||
|
||||
StableDiffusionTxt2ImgProcessingAPI = PydanticModelGenerator(
|
||||
"StableDiffusionProcessingTxt2Img",
|
||||
StableDiffusionProcessingTxt2Img,
|
||||
[
|
||||
{"key": "sampler_index", "type": str, "default": "Euler"},
|
||||
{"key": "script_name", "type": str | None, "default": None},
|
||||
{"key": "script_args", "type": list, "default": []},
|
||||
{"key": "send_images", "type": bool, "default": True},
|
||||
{"key": "save_images", "type": bool, "default": False},
|
||||
{"key": "alwayson_scripts", "type": dict, "default": {}},
|
||||
{"key": "force_task_id", "type": str | None, "default": None},
|
||||
{"key": "infotext", "type": str | None, "default": None},
|
||||
]
|
||||
).generate_model()
|
||||
|
||||
StableDiffusionImg2ImgProcessingAPI = PydanticModelGenerator(
|
||||
"StableDiffusionProcessingImg2Img",
|
||||
StableDiffusionProcessingImg2Img,
|
||||
[
|
||||
{"key": "sampler_index", "type": str, "default": "Euler"},
|
||||
{"key": "init_images", "type": list | None, "default": None},
|
||||
{"key": "denoising_strength", "type": float, "default": 0.75},
|
||||
{"key": "mask", "type": str | None, "default": None},
|
||||
{"key": "include_init_images", "type": bool, "default": False, "exclude" : True},
|
||||
{"key": "script_name", "type": str | None, "default": None},
|
||||
{"key": "script_args", "type": list, "default": []},
|
||||
{"key": "send_images", "type": bool, "default": True},
|
||||
{"key": "save_images", "type": bool, "default": False},
|
||||
{"key": "alwayson_scripts", "type": dict, "default": {}},
|
||||
{"key": "force_task_id", "type": str | None, "default": None},
|
||||
{"key": "infotext", "type": str | None, "default": None},
|
||||
]
|
||||
).generate_model()
|
||||
|
||||
class TextToImageResponse(BaseModel):
|
||||
images: list[str] | None = Field(default=None, title="Image", description="The generated image in base64 format.")
|
||||
parameters: dict
|
||||
info: str
|
||||
|
||||
class ImageToImageResponse(BaseModel):
|
||||
images: list[str] | None = Field(default=None, title="Image", description="The generated image in base64 format.")
|
||||
parameters: dict
|
||||
info: str
|
||||
|
||||
class ExtrasBaseRequest(BaseModel):
|
||||
resize_mode: Literal[0, 1] = Field(default=0, title="Resize Mode", description="Sets the resize mode: 0 to upscale by upscaling_resize amount, 1 to upscale up to upscaling_resize_h x upscaling_resize_w.")
|
||||
show_extras_results: bool = Field(default=True, title="Show results", description="Should the backend return the generated image?")
|
||||
gfpgan_visibility: float = Field(default=0, title="GFPGAN Visibility", ge=0, le=1, allow_inf_nan=False, description="Sets the visibility of GFPGAN, values should be between 0 and 1.")
|
||||
codeformer_visibility: float = Field(default=0, title="CodeFormer Visibility", ge=0, le=1, allow_inf_nan=False, description="Sets the visibility of CodeFormer, values should be between 0 and 1.")
|
||||
codeformer_weight: float = Field(default=0, title="CodeFormer Weight", ge=0, le=1, allow_inf_nan=False, description="Sets the weight of CodeFormer, values should be between 0 and 1.")
|
||||
upscaling_resize: float = Field(default=2, title="Upscaling Factor", gt=0, description="By how much to upscale the image, only used when resize_mode=0.")
|
||||
upscaling_resize_w: int = Field(default=512, title="Target Width", ge=1, description="Target width for the upscaler to hit. Only used when resize_mode=1.")
|
||||
upscaling_resize_h: int = Field(default=512, title="Target Height", ge=1, description="Target height for the upscaler to hit. Only used when resize_mode=1.")
|
||||
upscaling_crop: bool = Field(default=True, title="Crop to fit", description="Should the upscaler crop the image to fit in the chosen size?")
|
||||
upscaler_1: str = Field(default="None", title="Main upscaler", description=f"The name of the main upscaler to use, it has to be one of this list: {' , '.join([x.name for x in sd_upscalers])}")
|
||||
upscaler_2: str = Field(default="None", title="Secondary upscaler", description=f"The name of the secondary upscaler to use, it has to be one of this list: {' , '.join([x.name for x in sd_upscalers])}")
|
||||
extras_upscaler_2_visibility: float = Field(default=0, title="Secondary upscaler visibility", ge=0, le=1, allow_inf_nan=False, description="Sets the visibility of secondary upscaler, values should be between 0 and 1.")
|
||||
upscale_first: bool = Field(default=False, title="Upscale first", description="Should the upscaler run before restoring faces?")
|
||||
|
||||
class ExtraBaseResponse(BaseModel):
|
||||
html_info: str = Field(title="HTML info", description="A series of HTML tags containing the process info.")
|
||||
|
||||
class ExtrasSingleImageRequest(ExtrasBaseRequest):
|
||||
image: str = Field(default="", title="Image", description="Image to work on, must be a Base64 string containing the image's data.")
|
||||
|
||||
class ExtrasSingleImageResponse(ExtraBaseResponse):
|
||||
image: str | None = Field(default=None, title="Image", description="The generated image in base64 format.")
|
||||
|
||||
class FileData(BaseModel):
|
||||
data: str = Field(title="File data", description="Base64 representation of the file")
|
||||
name: str = Field(title="File name")
|
||||
|
||||
class ExtrasBatchImagesRequest(ExtrasBaseRequest):
|
||||
imageList: list[FileData] = Field(title="Images", description="List of images to work on. Must be Base64 strings")
|
||||
|
||||
class ExtrasBatchImagesResponse(ExtraBaseResponse):
|
||||
images: list[str] = Field(title="Images", description="The generated images in base64 format.")
|
||||
|
||||
class PNGInfoRequest(BaseModel):
|
||||
image: str = Field(title="Image", description="The base64 encoded PNG image")
|
||||
|
||||
class PNGInfoResponse(BaseModel):
|
||||
info: str = Field(title="Image info", description="A string with the parameters used to generate the image")
|
||||
items: dict = Field(title="Items", description="A dictionary containing all the other fields the image had")
|
||||
parameters: dict = Field(title="Parameters", description="A dictionary with parsed generation info fields")
|
||||
|
||||
class ProgressRequest(BaseModel):
|
||||
skip_current_image: bool = Field(default=False, title="Skip current image", description="Skip current image serialization")
|
||||
|
||||
class ProgressResponse(BaseModel):
|
||||
progress: float = Field(title="Progress", description="The progress with a range of 0 to 1")
|
||||
eta_relative: float = Field(title="ETA in secs")
|
||||
state: dict = Field(title="State", description="The current state snapshot")
|
||||
current_image: str | None = Field(default=None, title="Current image", description="The current image in base64 format. opts.show_progress_every_n_steps is required for this to work.")
|
||||
textinfo: str | None = Field(default=None, title="Info text", description="Info text used by WebUI.")
|
||||
|
||||
class InterrogateRequest(BaseModel):
|
||||
image: str = Field(default="", title="Image", description="Image to work on, must be a Base64 string containing the image's data.")
|
||||
model: str = Field(default="clip", title="Model", description="The interrogate model used.")
|
||||
|
||||
class InterrogateResponse(BaseModel):
|
||||
caption: str | None = Field(default=None, title="Caption", description="The generated caption for the image.")
|
||||
|
||||
class TrainResponse(BaseModel):
|
||||
info: str = Field(title="Train info", description="Response string from train embedding or hypernetwork task.")
|
||||
|
||||
class CreateResponse(BaseModel):
|
||||
info: str = Field(title="Create info", description="Response string from create embedding or hypernetwork task.")
|
||||
|
||||
fields = {}
|
||||
for key, metadata in opts.data_labels.items():
|
||||
value = opts.data.get(key)
|
||||
optType = opts.typemap.get(type(metadata.default), type(metadata.default)) if metadata.default else Any
|
||||
|
||||
if metadata is not None:
|
||||
fields.update({key: (Optional[optType], Field(default=metadata.default, description=metadata.label))})
|
||||
else:
|
||||
fields.update({key: (Optional[optType], Field())})
|
||||
|
||||
OptionsModel = create_model("Options", **fields)
|
||||
|
||||
flags = {}
|
||||
_options = vars(parser)['_option_string_actions']
|
||||
for key in _options:
|
||||
if(_options[key].dest != 'help'):
|
||||
flag = _options[key]
|
||||
_type = str | None
|
||||
if _options[key].default is not None:
|
||||
_type = type(_options[key].default)
|
||||
flags.update({flag.dest: (_type, Field(default=flag.default, description=flag.help))})
|
||||
|
||||
FlagsModel = create_model("Flags", **flags)
|
||||
|
||||
class SamplerItem(BaseModel):
|
||||
name: str = Field(title="Name")
|
||||
aliases: list[str] = Field(title="Aliases")
|
||||
options: dict[str, Any] = Field(title="Options")
|
||||
|
||||
class SchedulerItem(BaseModel):
|
||||
name: str = Field(title="Name")
|
||||
label: str = Field(title="Label")
|
||||
aliases: Optional[list[str]] = Field(title="Aliases")
|
||||
default_rho: Optional[float] = Field(title="Default Rho")
|
||||
need_inner_model: Optional[bool] = Field(title="Needs Inner Model")
|
||||
|
||||
class UpscalerItem(BaseModel):
|
||||
class Config:
|
||||
protected_namespaces = ()
|
||||
|
||||
name: str = Field(title="Name")
|
||||
model_name: Optional[str] = Field(title="Model Name")
|
||||
model_path: Optional[str] = Field(title="Path")
|
||||
model_url: Optional[str] = Field(title="URL")
|
||||
scale: Optional[float] = Field(title="Scale")
|
||||
|
||||
class LatentUpscalerModeItem(BaseModel):
|
||||
name: str = Field(title="Name")
|
||||
|
||||
class SDModelItem(BaseModel):
|
||||
class Config:
|
||||
protected_namespaces = ()
|
||||
|
||||
title: str = Field(title="Title")
|
||||
model_name: str = Field(title="Model Name")
|
||||
hash: Optional[str] = Field(title="Short hash")
|
||||
sha256: Optional[str] = Field(title="sha256 hash")
|
||||
filename: str = Field(title="Filename")
|
||||
config: Optional[str] = Field(default=None, title="Config file")
|
||||
|
||||
class SDModuleItem(BaseModel):
|
||||
class Config:
|
||||
protected_namespaces = ()
|
||||
|
||||
model_name: str = Field(title="Model Name")
|
||||
filename: str = Field(title="Filename")
|
||||
|
||||
class HypernetworkItem(BaseModel):
|
||||
name: str = Field(title="Name")
|
||||
path: Optional[str] = Field(title="Path")
|
||||
|
||||
class FaceRestorerItem(BaseModel):
|
||||
name: str = Field(title="Name")
|
||||
cmd_dir: Optional[str] = Field(title="Path")
|
||||
|
||||
class RealesrganItem(BaseModel):
|
||||
name: str = Field(title="Name")
|
||||
path: Optional[str] = Field(title="Path")
|
||||
scale: Optional[int] = Field(title="Scale")
|
||||
|
||||
class PromptStyleItem(BaseModel):
|
||||
name: str = Field(title="Name")
|
||||
prompt: Optional[str] = Field(title="Prompt")
|
||||
negative_prompt: Optional[str] = Field(title="Negative Prompt")
|
||||
|
||||
|
||||
class EmbeddingItem(BaseModel):
|
||||
step: Optional[int] = Field(title="Step", description="The number of steps that were used to train this embedding, if available")
|
||||
sd_checkpoint: Optional[str] = Field(title="SD Checkpoint", description="The hash of the checkpoint this embedding was trained on, if available")
|
||||
sd_checkpoint_name: Optional[str] = Field(title="SD Checkpoint Name", description="The name of the checkpoint this embedding was trained on, if available. Note that this is the name that was used by the trainer; for a stable identifier, use `sd_checkpoint` instead")
|
||||
shape: int = Field(title="Shape", description="The length of each individual vector in the embedding")
|
||||
vectors: int = Field(title="Vectors", description="The number of vectors in the embedding")
|
||||
|
||||
class EmbeddingsResponse(BaseModel):
|
||||
loaded: dict[str, EmbeddingItem] = Field(title="Loaded", description="Embeddings loaded for the current model")
|
||||
skipped: dict[str, EmbeddingItem] = Field(title="Skipped", description="Embeddings skipped for the current model (likely due to architecture incompatibility)")
|
||||
|
||||
class MemoryResponse(BaseModel):
|
||||
ram: dict = Field(title="RAM", description="System memory stats")
|
||||
cuda: dict = Field(title="CUDA", description="nVidia CUDA memory stats")
|
||||
|
||||
|
||||
class ScriptsList(BaseModel):
|
||||
txt2img: list | None = Field(default=None, title="Txt2img", description="Titles of scripts (txt2img)")
|
||||
img2img: list | None = Field(default=None, title="Img2img", description="Titles of scripts (img2img)")
|
||||
|
||||
|
||||
class ScriptArg(BaseModel):
    label: str | None = Field(default=None, title="Label", description="Name of the argument in UI")
    value: Optional[Any] = Field(default=None, title="Value", description="Default value of the argument")
    minimum: Optional[Any] = Field(default=None, title="Minimum", description="Minimum allowed value for the argument in UI")
    maximum: Optional[Any] = Field(default=None, title="Maximum", description="Maximum allowed value for the argument in UI")
    step: Optional[Any] = Field(default=None, title="Step", description="Step for changing value of the argument in UI")
    choices: Optional[list[str]] = Field(default=None, title="Choices", description="Possible values for the argument")
|
||||
|
||||
|
||||
class ScriptInfo(BaseModel):
|
||||
name: str | None = Field(default=None, title="Name", description="Script name")
|
||||
is_alwayson: bool | None = Field(default=None, title="IsAlwayson", description="Flag specifying whether this script is an alwayson script")
|
||||
is_img2img: bool | None = Field(default=None, title="IsImg2img", description="Flag specifying whether this script is an img2img script")
|
||||
args: list[ScriptArg] = Field(title="Arguments", description="List of script's arguments")
|
||||
|
||||
class ExtensionItem(BaseModel):
|
||||
name: str = Field(title="Name", description="Extension name")
|
||||
remote: str = Field(title="Remote", description="Extension Repository URL")
|
||||
branch: str = Field(title="Branch", description="Extension Repository Branch")
|
||||
commit_hash: str = Field(title="Commit Hash", description="Extension Repository Commit Hash")
|
||||
version: str = Field(title="Version", description="Extension Version")
|
||||
commit_date: int = Field(title="Commit Date", description="Extension Repository Commit Date")
|
||||
enabled: bool = Field(title="Enabled", description="Flag specifying whether this extension is enabled")
|
||||
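As a rough illustration (not part of the commit), the schema classes above can be instantiated directly; the script name and argument values below are hypothetical.
example_arg = ScriptArg(label="Strength", value=0.5, minimum=0.0, maximum=1.0, step=0.05)  # hypothetical values
example_script = ScriptInfo(name="example script", is_alwayson=False, is_img2img=False, args=[example_arg])
print(example_script.dict())  # .model_dump() on pydantic v2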
123
modules/cache.py
Executable file
123
modules/cache.py
Executable file
@@ -0,0 +1,123 @@
|
||||
import json
|
||||
import os
|
||||
import os.path
|
||||
import threading
|
||||
|
||||
import diskcache
|
||||
import tqdm
|
||||
|
||||
from modules.paths import data_path, script_path
|
||||
|
||||
cache_filename = os.environ.get('SD_WEBUI_CACHE_FILE', os.path.join(data_path, "cache.json"))
|
||||
cache_dir = os.environ.get('SD_WEBUI_CACHE_DIR', os.path.join(data_path, "cache"))
|
||||
caches = {}
|
||||
cache_lock = threading.Lock()
|
||||
|
||||
|
||||
def dump_cache():
|
||||
"""old function for dumping cache to disk; does nothing since diskcache."""
|
||||
|
||||
pass
|
||||
|
||||
|
||||
def make_cache(subsection: str) -> diskcache.Cache:
|
||||
return diskcache.Cache(
|
||||
os.path.join(cache_dir, subsection),
|
||||
size_limit=2**32, # 4 GB, culling oldest first
|
||||
disk_min_file_size=2**18,  # keep up to 256 KB in SQLite
|
||||
)
|
||||
|
||||
|
||||
def convert_old_cached_data():
|
||||
try:
|
||||
with open(cache_filename, "r", encoding="utf8") as file:
|
||||
data = json.load(file)
|
||||
except FileNotFoundError:
|
||||
return
|
||||
except Exception:
|
||||
os.replace(cache_filename, os.path.join(script_path, "tmp", "cache.json"))
|
||||
print('[ERROR] issue occurred while trying to read cache.json; old cache has been moved to tmp/cache.json')
|
||||
return
|
||||
|
||||
total_count = sum(len(keyvalues) for keyvalues in data.values())
|
||||
|
||||
with tqdm.tqdm(total=total_count, desc="converting cache") as progress:
|
||||
for subsection, keyvalues in data.items():
|
||||
cache_obj = caches.get(subsection)
|
||||
if cache_obj is None:
|
||||
cache_obj = make_cache(subsection)
|
||||
caches[subsection] = cache_obj
|
||||
|
||||
for key, value in keyvalues.items():
|
||||
cache_obj[key] = value
|
||||
progress.update(1)
|
||||
|
||||
|
||||
def cache(subsection):
|
||||
"""
|
||||
Retrieves or initializes a cache for a specific subsection.
|
||||
|
||||
Parameters:
|
||||
subsection (str): The subsection identifier for the cache.
|
||||
|
||||
Returns:
|
||||
diskcache.Cache: The cache data for the specified subsection.
|
||||
"""
|
||||
|
||||
cache_obj = caches.get(subsection)
|
||||
if not cache_obj:
|
||||
with cache_lock:
|
||||
if not os.path.exists(cache_dir) and os.path.isfile(cache_filename):
|
||||
convert_old_cached_data()
|
||||
|
||||
cache_obj = caches.get(subsection)
|
||||
if not cache_obj:
|
||||
cache_obj = make_cache(subsection)
|
||||
caches[subsection] = cache_obj
|
||||
|
||||
return cache_obj
|
||||
|
||||
|
||||
def cached_data_for_file(subsection, title, filename, func):
|
||||
"""
|
||||
Retrieves or generates data for a specific file, using a caching mechanism.
|
||||
|
||||
Parameters:
|
||||
subsection (str): The subsection of the cache to use.
|
||||
title (str): The title of the data entry in the subsection of the cache.
|
||||
filename (str): The path to the file to be checked for modifications.
|
||||
func (callable): A function that generates the data if it is not available in the cache.
|
||||
|
||||
Returns:
|
||||
dict or None: The cached or generated data, or None if data generation fails.
|
||||
|
||||
The `cached_data_for_file` function implements a caching mechanism for data derived from files.
|
||||
It checks if the data associated with the given `title` is present in the cache and compares the
|
||||
modification time of the file with the cached modification time. If the file has been modified,
|
||||
the cache is considered invalid and the data is regenerated using the provided `func`.
|
||||
Otherwise, the cached data is returned.
|
||||
|
||||
If the data generation fails, None is returned to indicate the failure. Otherwise, the generated
|
||||
or cached data is returned as a dictionary.
|
||||
"""
|
||||
|
||||
existing_cache = cache(subsection)
|
||||
ondisk_mtime = os.path.getmtime(filename)
|
||||
|
||||
entry = existing_cache.get(title)
|
||||
if entry:
|
||||
cached_mtime = entry.get("mtime", 0)
|
||||
if ondisk_mtime > cached_mtime:
|
||||
entry = None
|
||||
|
||||
if not entry or 'value' not in entry:
|
||||
value = func()
|
||||
if value is None:
|
||||
return None
|
||||
|
||||
entry = {'mtime': ondisk_mtime, 'value': value}
|
||||
existing_cache[title] = entry
|
||||
|
||||
dump_cache()
|
||||
|
||||
return entry['value']
|
||||
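A minimal usage sketch (not part of the commit) of the mtime-aware cache above; the subsection name, file path, and metadata function are hypothetical.
import os
from modules.cache import cached_data_for_file

def compute_metadata():
    # hypothetical generator; only runs when "model.safetensors" is new or its mtime has changed
    return {"size": os.path.getsize("model.safetensors")}

meta = cached_data_for_file("example-metadata", "model.safetensors", "model.safetensors", compute_metadata)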
133
modules/call_queue.py
Executable file
133
modules/call_queue.py
Executable file
@@ -0,0 +1,133 @@
|
||||
import os.path
|
||||
from functools import wraps
|
||||
import html
|
||||
import time
|
||||
import traceback
|
||||
|
||||
from modules_forge import main_thread
|
||||
from modules import shared, progress, errors, devices, fifo_lock, profiling
|
||||
|
||||
queue_lock = fifo_lock.FIFOLock()
|
||||
|
||||
|
||||
def wrap_queued_call(func):
|
||||
def f(*args, **kwargs):
|
||||
with queue_lock:
|
||||
res = func(*args, **kwargs)
|
||||
|
||||
return res
|
||||
|
||||
return f
|
||||
|
||||
|
||||
def wrap_gradio_gpu_call(func, extra_outputs=None):
|
||||
@wraps(func)
|
||||
def f(*args, **kwargs):
|
||||
|
||||
# if the first argument is a string that says "task(...)", it is treated as a job id
|
||||
if args and isinstance(args[0], str) and args[0].startswith("task(") and args[0].endswith(")"):
|
||||
id_task = args[0]
|
||||
progress.add_task_to_queue(id_task)
|
||||
else:
|
||||
id_task = None
|
||||
|
||||
with queue_lock:
|
||||
shared.state.begin(job=id_task)
|
||||
progress.start_task(id_task)
|
||||
|
||||
try:
|
||||
res = func(*args, **kwargs)
|
||||
progress.record_results(id_task, res)
|
||||
finally:
|
||||
progress.finish_task(id_task)
|
||||
|
||||
shared.state.end()
|
||||
|
||||
return res
|
||||
|
||||
return wrap_gradio_call(f, extra_outputs=extra_outputs, add_stats=True)
|
||||
|
||||
|
||||
def wrap_gradio_call(func, extra_outputs=None, add_stats=False):
|
||||
@wraps(func)
|
||||
def f(*args, **kwargs):
|
||||
try:
|
||||
res = func(*args, **kwargs)
|
||||
finally:
|
||||
shared.state.skipped = False
|
||||
shared.state.interrupted = False
|
||||
shared.state.stopping_generation = False
|
||||
shared.state.job_count = 0
|
||||
shared.state.job = ""
|
||||
return res
|
||||
|
||||
return wrap_gradio_call_no_job(f, extra_outputs, add_stats)
|
||||
|
||||
|
||||
def wrap_gradio_call_no_job(func, extra_outputs=None, add_stats=False):
|
||||
@wraps(func)
|
||||
def f(*args, extra_outputs_array=extra_outputs, **kwargs):
|
||||
run_memmon = shared.opts.memmon_poll_rate > 0 and not shared.mem_mon.disabled and add_stats
|
||||
if run_memmon:
|
||||
shared.mem_mon.monitor()
|
||||
t = time.perf_counter()
|
||||
|
||||
try:
|
||||
res = list(func(*args, **kwargs))
|
||||
except Exception as e:
|
||||
if main_thread.last_exception is not None:
|
||||
e = main_thread.last_exception
|
||||
else:
|
||||
traceback.print_exc()
|
||||
print(e)
|
||||
|
||||
if extra_outputs_array is None:
|
||||
extra_outputs_array = [None, '']
|
||||
|
||||
error_message = f'{type(e).__name__}: {e}'
|
||||
res = extra_outputs_array + [f"<div class='error'>{html.escape(error_message)}</div>"]
|
||||
|
||||
devices.torch_gc()
|
||||
|
||||
if not add_stats:
|
||||
return tuple(res)
|
||||
|
||||
elapsed = time.perf_counter() - t
|
||||
elapsed_m = int(elapsed // 60)
|
||||
elapsed_s = elapsed % 60
|
||||
elapsed_text = f"{elapsed_s:.1f} sec."
|
||||
if elapsed_m > 0:
|
||||
elapsed_text = f"{elapsed_m} min. "+elapsed_text
|
||||
|
||||
if run_memmon:
|
||||
mem_stats = {k: -(v//-(1024*1024)) for k, v in shared.mem_mon.stop().items()}
|
||||
active_peak = mem_stats['active_peak']
|
||||
reserved_peak = mem_stats['reserved_peak']
|
||||
sys_peak = mem_stats['system_peak']
|
||||
sys_total = mem_stats['total']
|
||||
sys_pct = sys_peak/max(sys_total, 1) * 100
|
||||
|
||||
tooltip_a = "Active: peak amount of video memory used during generation (excluding cached data)"
|
||||
tooltip_r = "Reserved: total amount of video memory allocated by the Torch library"
|
||||
tooltip_sys = "System: peak amount of video memory allocated by all running programs, out of total capacity"
|
||||
|
||||
text_a = f"<abbr title='{tooltip_a}'>A</abbr>: <span class='measurement'>{active_peak/1024:.2f} GB</span>"
|
||||
text_r = f"<abbr title='{tooltip_r}'>R</abbr>: <span class='measurement'>{reserved_peak/1024:.2f} GB</span>"
|
||||
text_sys = f"<abbr title='{tooltip_sys}'>Sys</abbr>: <span class='measurement'>{sys_peak/1024:.1f}/{sys_total/1024:g} GB</span> ({sys_pct:.1f}%)"
|
||||
|
||||
vram_html = f"<p class='vram'>{text_a}, <wbr>{text_r}, <wbr>{text_sys}</p>"
|
||||
else:
|
||||
vram_html = ''
|
||||
|
||||
if shared.opts.profiling_enable and os.path.exists(shared.opts.profiling_filename):
|
||||
profiling_html = f"<p class='profile'> [ <a href='{profiling.webpath()}' download>Profile</a> ] </p>"
|
||||
else:
|
||||
profiling_html = ''
|
||||
|
||||
# last item is always HTML
|
||||
res[-1] += f"<div class='performance'><p class='time'>Time taken: <wbr><span class='measurement'>{elapsed_text}</span></p>{vram_html}{profiling_html}</div>"
|
||||
|
||||
return tuple(res)
|
||||
|
||||
return f
|
||||
|
||||
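A minimal sketch (not part of the commit) of how a UI callback might be wrapped so it runs under queue_lock and reports progress; the callback and its outputs are hypothetical.
from modules.call_queue import wrap_gradio_gpu_call

def generate(id_task, prompt):
    # hypothetical generation callback; gradio passes the "task(...)" id as the first argument
    return [f"result for {prompt}"], "generation info", "html log"

wrapped = wrap_gradio_gpu_call(generate, extra_outputs=[None, "", ""])
# gradio would later invoke: wrapped("task(12345)", "a red bicycle")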
150
modules/cmd_args.py
Executable file
150
modules/cmd_args.py
Executable file
@@ -0,0 +1,150 @@
|
||||
import argparse
|
||||
import json
|
||||
import os
|
||||
from modules.paths_internal import normalized_filepath, models_path, script_path, data_path, extensions_dir, extensions_builtin_dir, sd_default_config, sd_model_file # noqa: F401
|
||||
from pathlib import Path
|
||||
from backend.args import parser
|
||||
|
||||
parser.add_argument("-f", action='store_true', help=argparse.SUPPRESS) # allows running as root; implemented outside of webui
|
||||
parser.add_argument("--update-all-extensions", action='store_true', help="launch.py argument: download updates for all extensions when starting the program")
|
||||
parser.add_argument("--skip-python-version-check", action='store_true', help="launch.py argument: do not check python version")
|
||||
parser.add_argument("--skip-torch-cuda-test", action='store_true', help="launch.py argument: do not check if CUDA is able to work properly")
|
||||
parser.add_argument("--reinstall-xformers", action='store_true', help="launch.py argument: install the appropriate version of xformers even if you have some version already installed")
|
||||
parser.add_argument("--reinstall-torch", action='store_true', help="launch.py argument: install the appropriate version of torch even if you have some version already installed")
|
||||
parser.add_argument("--update-check", action='store_true', help="launch.py argument: check for updates at startup")
|
||||
parser.add_argument("--test-server", action='store_true', help="launch.py argument: configure server for testing")
|
||||
parser.add_argument("--log-startup", action='store_true', help="launch.py argument: print a detailed log of what's happening at startup")
|
||||
parser.add_argument("--skip-prepare-environment", action='store_true', help="launch.py argument: skip all environment preparation")
|
||||
parser.add_argument("--skip-google-blockly", action='store_true', help="launch.py argument: do not initialize google blockly modules")
|
||||
parser.add_argument("--skip-install", action='store_true', help="launch.py argument: skip installation of packages")
|
||||
parser.add_argument("--dump-sysinfo", action='store_true', help="launch.py argument: dump limited sysinfo file (without information about extensions, options) to disk and quit")
|
||||
parser.add_argument("--loglevel", type=str, help="log level; one of: CRITICAL, ERROR, WARNING, INFO, DEBUG", default=None)
|
||||
parser.add_argument("--do-not-download-clip", action='store_true', help="do not download CLIP model even if it's not included in the checkpoint")
|
||||
parser.add_argument("--data-dir", type=normalized_filepath, default=os.path.dirname(os.path.dirname(os.path.realpath(__file__))), help="base path where all user data is stored")
|
||||
parser.add_argument("--models-dir", type=normalized_filepath, default=None, help="base path where models are stored; overrides --data-dir")
|
||||
parser.add_argument("--config", type=normalized_filepath, default=sd_default_config, help="path to config which constructs model",)
|
||||
parser.add_argument("--ckpt", type=normalized_filepath, default=sd_model_file, help="path to checkpoint of stable diffusion model; if specified, this checkpoint will be added to the list of checkpoints and loaded",)
|
||||
parser.add_argument("--ckpt-dir", type=normalized_filepath, default=None, help="Path to directory with stable diffusion checkpoints")
|
||||
parser.add_argument("--vae-dir", type=normalized_filepath, default=None, help="Path to directory with VAE files")
|
||||
parser.add_argument("--text-encoder-dir", type=normalized_filepath, default=None, help="Path to directory with text encoder models")
|
||||
parser.add_argument("--gfpgan-dir", type=normalized_filepath, help="GFPGAN directory", default=('./src/gfpgan' if os.path.exists('./src/gfpgan') else './GFPGAN'))
|
||||
parser.add_argument("--gfpgan-model", type=normalized_filepath, help="GFPGAN model file name", default=None)
|
||||
parser.add_argument("--no-half", action='store_true', help="do not switch the model to 16-bit floats")
|
||||
parser.add_argument("--no-half-vae", action='store_true', help="do not switch the VAE model to 16-bit floats")
|
||||
parser.add_argument("--no-progressbar-hiding", action='store_true', help="do not hide progressbar in gradio UI (we hide it because it slows down ML if you have hardware acceleration in browser)")
|
||||
parser.add_argument("--max-batch-count", type=int, default=16, help="does not do anything")
|
||||
parser.add_argument("--embeddings-dir", type=normalized_filepath, default=os.path.join(data_path, 'embeddings'), help="embeddings directory for textual inversion (default: embeddings)")
|
||||
parser.add_argument("--textual-inversion-templates-dir", type=normalized_filepath, default=os.path.join(script_path, 'textual_inversion_templates'), help="directory with textual inversion templates")
|
||||
parser.add_argument("--hypernetwork-dir", type=normalized_filepath, default=os.path.join(models_path, 'hypernetworks'), help="hypernetwork directory")
|
||||
parser.add_argument("--localizations-dir", type=normalized_filepath, default=os.path.join(script_path, 'localizations'), help="localizations directory")
|
||||
parser.add_argument("--allow-code", action='store_true', help="allow custom script execution from webui")
|
||||
parser.add_argument("--medvram", action='store_true', help="enable stable diffusion model optimizations for sacrificing a little speed for low VRM usage")
|
||||
parser.add_argument("--medvram-sdxl", action='store_true', help="enable --medvram optimization just for SDXL models")
|
||||
parser.add_argument("--lowvram", action='store_true', help="enable stable diffusion model optimizations for sacrificing a lot of speed for very low VRM usage")
|
||||
parser.add_argument("--lowram", action='store_true', help="load stable diffusion checkpoint weights to VRAM instead of RAM")
|
||||
parser.add_argument("--always-batch-cond-uncond", action='store_true', help="does not do anything")
|
||||
parser.add_argument("--unload-gfpgan", action='store_true', help="does not do anything.")
|
||||
parser.add_argument("--precision", type=str, help="evaluate at this precision", choices=["full", "half", "autocast"], default="autocast")
|
||||
parser.add_argument("--upcast-sampling", action='store_true', help="upcast sampling. No effect with --no-half. Usually produces similar results to --no-half with better performance while using less memory.")
|
||||
parser.add_argument("--share", action='store_true', help="use share=True for gradio and make the UI accessible through their site")
|
||||
parser.add_argument("--ngrok", type=str, help="ngrok authtoken, alternative to gradio --share", default=None)
|
||||
parser.add_argument("--ngrok-region", type=str, help="does not do anything.", default="")
|
||||
parser.add_argument("--ngrok-options", type=json.loads, help='The options to pass to ngrok in JSON format, e.g.: \'{"authtoken_from_env":true, "basic_auth":"user:password", "oauth_provider":"google", "oauth_allow_emails":"user@asdf.com"}\'', default=dict())
|
||||
parser.add_argument("--enable-insecure-extension-access", action='store_true', help="enable extensions tab regardless of other options")
|
||||
parser.add_argument("--codeformer-models-path", type=normalized_filepath, help="Path to directory with codeformer model file(s).", default=os.path.join(models_path, 'Codeformer'))
|
||||
parser.add_argument("--gfpgan-models-path", type=normalized_filepath, help="Path to directory with GFPGAN model file(s).", default=os.path.join(models_path, 'GFPGAN'))
|
||||
parser.add_argument("--esrgan-models-path", type=normalized_filepath, help="Path to directory with ESRGAN model file(s).", default=os.path.join(models_path, 'ESRGAN'))
|
||||
parser.add_argument("--bsrgan-models-path", type=normalized_filepath, help="Path to directory with BSRGAN model file(s).", default=os.path.join(models_path, 'BSRGAN'))
|
||||
parser.add_argument("--realesrgan-models-path", type=normalized_filepath, help="Path to directory with RealESRGAN model file(s).", default=os.path.join(models_path, 'RealESRGAN'))
|
||||
parser.add_argument("--dat-models-path", type=normalized_filepath, help="Path to directory with DAT model file(s).", default=os.path.join(models_path, 'DAT'))
|
||||
parser.add_argument("--clip-models-path", type=normalized_filepath, help="Path to directory with CLIP model file(s), for Interrogate options.", default=None)
|
||||
parser.add_argument("--xformers", action='store_true', help="enable xformers for cross attention layers")
|
||||
parser.add_argument("--force-enable-xformers", action='store_true', help="enable xformers for cross attention layers regardless of whether the checking code thinks you can run it; do not make bug reports if this fails to work")
|
||||
parser.add_argument("--xformers-flash-attention", action='store_true', help="enable xformers with Flash Attention to improve reproducibility (supported for SD2.x or variant only)")
|
||||
parser.add_argument("--deepdanbooru", action='store_true', help="does not do anything")
|
||||
parser.add_argument("--opt-split-attention", action='store_true', help="prefer Doggettx's cross-attention layer optimization for automatic choice of optimization")
|
||||
parser.add_argument("--opt-sub-quad-attention", action='store_true', help="prefer memory efficient sub-quadratic cross-attention layer optimization for automatic choice of optimization")
|
||||
parser.add_argument("--sub-quad-q-chunk-size", type=int, help="query chunk size for the sub-quadratic cross-attention layer optimization to use", default=1024)
|
||||
parser.add_argument("--sub-quad-kv-chunk-size", type=int, help="kv chunk size for the sub-quadratic cross-attention layer optimization to use", default=None)
|
||||
parser.add_argument("--sub-quad-chunk-threshold", type=int, help="the percentage of VRAM threshold for the sub-quadratic cross-attention layer optimization to use chunking", default=None)
|
||||
parser.add_argument("--opt-split-attention-invokeai", action='store_true', help="prefer InvokeAI's cross-attention layer optimization for automatic choice of optimization")
|
||||
parser.add_argument("--opt-split-attention-v1", action='store_true', help="prefer older version of split attention optimization for automatic choice of optimization")
|
||||
parser.add_argument("--opt-sdp-attention", action='store_true', help="prefer scaled dot product cross-attention layer optimization for automatic choice of optimization; requires PyTorch 2.*")
|
||||
parser.add_argument("--opt-sdp-no-mem-attention", action='store_true', help="prefer scaled dot product cross-attention layer optimization without memory efficient attention for automatic choice of optimization, makes image generation deterministic; requires PyTorch 2.*")
|
||||
parser.add_argument("--disable-opt-split-attention", action='store_true', help="prefer no cross-attention layer optimization for automatic choice of optimization")
|
||||
parser.add_argument("--disable-nan-check", action='store_true', help="do not check if produced images/latent spaces have nans; useful for running without a checkpoint in CI")
|
||||
parser.add_argument("--use-cpu", nargs='+', help="use CPU as torch device for specified modules", default=[], type=str.lower)
|
||||
parser.add_argument("--use-ipex", action="store_true", help="use Intel XPU as torch device")
|
||||
parser.add_argument("--disable-model-loading-ram-optimization", action='store_true', help="disable an optimization that reduces RAM use when loading a model")
|
||||
parser.add_argument("--listen", action='store_true', help="launch gradio with 0.0.0.0 as server name, allowing to respond to network requests")
|
||||
parser.add_argument("--port", type=int, help="launch gradio with given server port, you need root/admin rights for ports < 1024, defaults to 7860 if available", default=None)
|
||||
parser.add_argument("--show-negative-prompt", action='store_true', help="does not do anything", default=False)
|
||||
parser.add_argument("--ui-config-file", type=str, help="filename to use for ui configuration", default=os.path.join(data_path, 'ui-config.json'))
|
||||
parser.add_argument("--hide-ui-dir-config", action='store_true', help="hide directory configuration from webui", default=False)
|
||||
parser.add_argument("--freeze-settings", action='store_true', help="disable editing of all settings globally", default=False)
|
||||
parser.add_argument("--freeze-settings-in-sections", type=str, help='disable editing settings in specific sections of the settings page by specifying a comma-delimited list such like "saving-images,upscaling". The list of setting names can be found in the modules/shared_options.py file', default=None)
|
||||
parser.add_argument("--freeze-specific-settings", type=str, help='disable editing of individual settings by specifying a comma-delimited list like "samples_save,samples_format". The list of setting names can be found in the config.json file', default=None)
|
||||
parser.add_argument("--ui-settings-file", type=str, help="filename to use for ui settings", default=os.path.join(data_path, 'config.json'))
|
||||
parser.add_argument("--gradio-debug", action='store_true', help="launch gradio with --debug option")
|
||||
parser.add_argument("--gradio-auth", type=str, help='set gradio authentication like "username:password"; or comma-delimit multiple like "u1:p1,u2:p2,u3:p3"', default=None)
|
||||
parser.add_argument("--gradio-auth-path", type=normalized_filepath, help='set gradio authentication file path ex. "/path/to/auth/file" same auth format as --gradio-auth', default=None)
|
||||
parser.add_argument("--gradio-img2img-tool", type=str, help='does not do anything')
|
||||
parser.add_argument("--gradio-inpaint-tool", type=str, help="does not do anything")
|
||||
parser.add_argument("--gradio-allowed-path", action='append', help="add path to gradio's allowed_paths, make it possible to serve files from it", default=[data_path])
|
||||
parser.add_argument("--opt-channelslast", action='store_true', help="change memory type for stable diffusion to channels last")
|
||||
parser.add_argument("--styles-file", type=str, action='append', help="path or wildcard path of styles files, allow multiple entries.", default=[])
|
||||
parser.add_argument("--autolaunch", action='store_true', help="open the webui URL in the system's default browser upon launch", default=False)
|
||||
parser.add_argument("--theme", type=str, help="launches the UI with light or dark theme", default=None)
|
||||
parser.add_argument("--use-textbox-seed", action='store_true', help="use textbox for seeds in UI (no up/down, but possible to input long seeds)", default=False)
|
||||
parser.add_argument("--disable-console-progressbars", action='store_true', help="do not output progressbars to console", default=False)
|
||||
parser.add_argument("--enable-console-prompts", action='store_true', help="does not do anything", default=False) # Legacy compatibility, use as default value shared.opts.enable_console_prompts
|
||||
parser.add_argument('--vae-path', type=normalized_filepath, help='Checkpoint to use as VAE; setting this argument disables all settings related to VAE', default=None)
|
||||
parser.add_argument("--disable-safe-unpickle", action='store_true', help="disable checking pytorch models for malicious code", default=False)
|
||||
parser.add_argument("--api", action='store_true', help="use api=True to launch the API together with the webui (use --nowebui instead for only the API)")
|
||||
parser.add_argument("--api-auth", type=str, help='Set authentication for API like "username:password"; or comma-delimit multiple like "u1:p1,u2:p2,u3:p3"', default=None)
|
||||
parser.add_argument("--api-log", action='store_true', help="use api-log=True to enable logging of all API requests")
|
||||
parser.add_argument("--nowebui", action='store_true', help="use api=True to launch the API instead of the webui")
|
||||
parser.add_argument("--ui-debug-mode", action='store_true', help="Don't load model to quickly launch UI")
|
||||
parser.add_argument("--device-id", type=str, help="Select the default CUDA device to use (export CUDA_VISIBLE_DEVICES=0,1,etc might be needed before)", default=None)
|
||||
parser.add_argument("--administrator", action='store_true', help="Administrator rights", default=False)
|
||||
parser.add_argument("--cors-allow-origins", type=str, help="Allowed CORS origin(s) in the form of a comma-separated list (no spaces)", default=None)
|
||||
parser.add_argument("--cors-allow-origins-regex", type=str, help="Allowed CORS origin(s) in the form of a single regular expression", default=None)
|
||||
parser.add_argument("--tls-keyfile", type=str, help="Partially enables TLS, requires --tls-certfile to fully function", default=None)
|
||||
parser.add_argument("--tls-certfile", type=str, help="Partially enables TLS, requires --tls-keyfile to fully function", default=None)
|
||||
parser.add_argument("--disable-tls-verify", action="store_false", help="When passed, enables the use of self-signed certificates.", default=None)
|
||||
parser.add_argument("--server-name", type=str, help="Sets hostname of server", default=None)
|
||||
parser.add_argument("--gradio-queue", action='store_true', help="does not do anything", default=True)
|
||||
parser.add_argument("--no-gradio-queue", action='store_true', help="Disables gradio queue; causes the webpage to use http requests instead of websockets; was the default in earlier versions")
|
||||
parser.add_argument("--skip-version-check", action='store_true', help="Do not check versions of torch and xformers")
|
||||
parser.add_argument("--no-hashing", action='store_true', help="disable sha256 hashing of checkpoints to help loading performance", default=False)
|
||||
parser.add_argument("--no-download-sd-model", action='store_true', help="don't download SD1.5 model even if no model is found in --ckpt-dir", default=False)
|
||||
parser.add_argument('--subpath', type=str, help='customize the subpath for gradio, use with reverse proxy')
|
||||
parser.add_argument('--add-stop-route', action='store_true', help='does not do anything')
|
||||
parser.add_argument('--api-server-stop', action='store_true', help='enable server stop/restart/kill via api')
|
||||
parser.add_argument('--timeout-keep-alive', type=int, default=30, help='set timeout_keep_alive for uvicorn')
|
||||
parser.add_argument("--disable-all-extensions", action='store_true', help="prevent all extensions from running regardless of any other settings", default=False)
|
||||
parser.add_argument("--disable-extra-extensions", action='store_true', help="prevent all extensions except built-in from running regardless of any other settings", default=False)
|
||||
parser.add_argument("--skip-load-model-at-start", action='store_true', help="if load a model at web start, only take effect when --nowebui")
|
||||
parser.add_argument("--unix-filenames-sanitization", action='store_true', help="allow any symbols except '/' in filenames. May conflict with your browser and file system")
|
||||
parser.add_argument("--filenames-max-length", type=int, default=128, help='maximal length of filenames of saved images. If you override it, it can conflict with your file system')
|
||||
parser.add_argument("--no-prompt-history", action='store_true', help="disable read prompt from last generation feature; settings this argument will not create '--data_path/params.txt' file")
|
||||
|
||||
# Arguments added by forge.
|
||||
parser.add_argument(
|
||||
'--forge-ref-a1111-home',
|
||||
type=Path,
|
||||
help="Look for models in an existing A1111 checkout's path",
|
||||
default=None
|
||||
)
|
||||
parser.add_argument(
|
||||
"--controlnet-dir",
|
||||
type=Path,
|
||||
help="Path to directory with ControlNet models",
|
||||
default=None,
|
||||
)
|
||||
parser.add_argument(
|
||||
"--controlnet-preprocessor-models-dir",
|
||||
type=Path,
|
||||
help="Path to directory with annotator model directories",
|
||||
default=None,
|
||||
)
|
||||
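A minimal sketch (not part of the commit) of consuming the parser defined above; using parse_known_args here, so that flags added elsewhere do not abort parsing, is an assumption about how the launcher calls it.
args, unknown = parser.parse_known_args()
if args.medvram:
    print("running with --medvram optimizations enabled")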
64
modules/codeformer_model.py
Executable file
64
modules/codeformer_model.py
Executable file
@@ -0,0 +1,64 @@
|
||||
from __future__ import annotations
|
||||
|
||||
import logging
|
||||
|
||||
import torch
|
||||
|
||||
from modules import (
|
||||
devices,
|
||||
errors,
|
||||
face_restoration,
|
||||
face_restoration_utils,
|
||||
modelloader,
|
||||
shared,
|
||||
)
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
model_url = 'https://github.com/sczhou/CodeFormer/releases/download/v0.1.0/codeformer.pth'
|
||||
model_download_name = 'codeformer-v0.1.0.pth'
|
||||
|
||||
# used by e.g. postprocessing_codeformer.py
|
||||
codeformer: face_restoration.FaceRestoration | None = None
|
||||
|
||||
|
||||
class FaceRestorerCodeFormer(face_restoration_utils.CommonFaceRestoration):
|
||||
def name(self):
|
||||
return "CodeFormer"
|
||||
|
||||
def load_net(self) -> torch.nn.Module:
|
||||
for model_path in modelloader.load_models(
|
||||
model_path=self.model_path,
|
||||
model_url=model_url,
|
||||
command_path=self.model_path,
|
||||
download_name=model_download_name,
|
||||
ext_filter=['.pth'],
|
||||
):
|
||||
return modelloader.load_spandrel_model(
|
||||
model_path,
|
||||
device=devices.device_codeformer,
|
||||
expected_architecture='CodeFormer',
|
||||
).model
|
||||
raise ValueError("No codeformer model found")
|
||||
|
||||
def get_device(self):
|
||||
return devices.device_codeformer
|
||||
|
||||
def restore(self, np_image, w: float | None = None):
|
||||
if w is None:
|
||||
w = getattr(shared.opts, "code_former_weight", 0.5)
|
||||
|
||||
def restore_face(cropped_face_t):
|
||||
assert self.net is not None
|
||||
return self.net(cropped_face_t, weight=w, adain=True)[0]
|
||||
|
||||
return self.restore_with_helper(np_image, restore_face)
|
||||
|
||||
|
||||
def setup_model(dirname: str) -> None:
|
||||
global codeformer
|
||||
try:
|
||||
codeformer = FaceRestorerCodeFormer(dirname)
|
||||
shared.face_restorers.append(codeformer)
|
||||
except Exception:
|
||||
errors.report("Error setting up CodeFormer", exc_info=True)
|
||||
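A minimal usage sketch (not part of the commit): register the restorer and run it on a numpy image; the model directory and the blank input image are hypothetical.
import numpy as np

setup_model("models/Codeformer")  # hypothetical model directory
if codeformer is not None:
    restored = codeformer.restore(np.zeros((512, 512, 3), dtype=np.uint8), w=0.7)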
198
modules/config_states.py
Executable file
198
modules/config_states.py
Executable file
@@ -0,0 +1,198 @@
|
||||
"""
|
||||
Supports saving and restoring webui and extensions from a known working set of commits
|
||||
"""
|
||||
|
||||
import os
|
||||
import json
|
||||
import tqdm
|
||||
|
||||
from datetime import datetime
|
||||
import git
|
||||
|
||||
from modules import shared, extensions, errors
|
||||
from modules.paths_internal import script_path, config_states_dir
|
||||
|
||||
all_config_states = {}
|
||||
|
||||
|
||||
def list_config_states():
|
||||
global all_config_states
|
||||
|
||||
all_config_states.clear()
|
||||
os.makedirs(config_states_dir, exist_ok=True)
|
||||
|
||||
config_states = []
|
||||
for filename in os.listdir(config_states_dir):
|
||||
if filename.endswith(".json"):
|
||||
path = os.path.join(config_states_dir, filename)
|
||||
try:
|
||||
with open(path, "r", encoding="utf-8") as f:
|
||||
j = json.load(f)
|
||||
assert "created_at" in j, '"created_at" does not exist'
|
||||
j["filepath"] = path
|
||||
config_states.append(j)
|
||||
except Exception as e:
|
||||
print(f'[ERROR]: Failed to load config state {path}: {e}')
|
||||
|
||||
config_states = sorted(config_states, key=lambda cs: cs["created_at"], reverse=True)
|
||||
|
||||
for cs in config_states:
|
||||
timestamp = datetime.fromtimestamp(cs["created_at"]).strftime('%Y-%m-%d %H:%M:%S')
|
||||
name = cs.get("name", "Config")
|
||||
full_name = f"{name}: {timestamp}"
|
||||
all_config_states[full_name] = cs
|
||||
|
||||
return all_config_states
|
||||
|
||||
|
||||
def get_webui_config():
|
||||
webui_repo = None
|
||||
|
||||
try:
|
||||
if os.path.exists(os.path.join(script_path, ".git")):
|
||||
webui_repo = git.Repo(script_path)
|
||||
except Exception:
|
||||
errors.report(f"Error reading webui git info from {script_path}", exc_info=True)
|
||||
|
||||
webui_remote = None
|
||||
webui_commit_hash = None
|
||||
webui_commit_date = None
|
||||
webui_branch = None
|
||||
if webui_repo and not webui_repo.bare:
|
||||
try:
|
||||
webui_remote = next(webui_repo.remote().urls, None)
|
||||
head = webui_repo.head.commit
|
||||
webui_commit_date = webui_repo.head.commit.committed_date
|
||||
webui_commit_hash = head.hexsha
|
||||
webui_branch = webui_repo.active_branch.name
|
||||
|
||||
except Exception:
|
||||
webui_remote = None
|
||||
|
||||
return {
|
||||
"remote": webui_remote,
|
||||
"commit_hash": webui_commit_hash,
|
||||
"commit_date": webui_commit_date,
|
||||
"branch": webui_branch,
|
||||
}
|
||||
|
||||
|
||||
def get_extension_config():
|
||||
ext_config = {}
|
||||
|
||||
for ext in extensions.extensions:
|
||||
ext.read_info_from_repo()
|
||||
|
||||
entry = {
|
||||
"name": ext.name,
|
||||
"path": ext.path,
|
||||
"enabled": ext.enabled,
|
||||
"is_builtin": ext.is_builtin,
|
||||
"remote": ext.remote,
|
||||
"commit_hash": ext.commit_hash,
|
||||
"commit_date": ext.commit_date,
|
||||
"branch": ext.branch,
|
||||
"have_info_from_repo": ext.have_info_from_repo
|
||||
}
|
||||
|
||||
ext_config[ext.name] = entry
|
||||
|
||||
return ext_config
|
||||
|
||||
|
||||
def get_config():
|
||||
creation_time = datetime.now().timestamp()
|
||||
webui_config = get_webui_config()
|
||||
ext_config = get_extension_config()
|
||||
|
||||
return {
|
||||
"created_at": creation_time,
|
||||
"webui": webui_config,
|
||||
"extensions": ext_config
|
||||
}
|
||||
|
||||
|
||||
def restore_webui_config(config):
|
||||
print("* Restoring webui state...")
|
||||
|
||||
if "webui" not in config:
|
||||
print("Error: No webui data saved to config")
|
||||
return
|
||||
|
||||
webui_config = config["webui"]
|
||||
|
||||
if "commit_hash" not in webui_config:
|
||||
print("Error: No commit saved to webui config")
|
||||
return
|
||||
|
||||
webui_commit_hash = webui_config.get("commit_hash", None)
|
||||
webui_repo = None
|
||||
|
||||
try:
|
||||
if os.path.exists(os.path.join(script_path, ".git")):
|
||||
webui_repo = git.Repo(script_path)
|
||||
except Exception:
|
||||
errors.report(f"Error reading webui git info from {script_path}", exc_info=True)
|
||||
return
|
||||
|
||||
try:
|
||||
webui_repo.git.fetch(all=True)
|
||||
webui_repo.git.reset(webui_commit_hash, hard=True)
|
||||
print(f"* Restored webui to commit {webui_commit_hash}.")
|
||||
except Exception:
|
||||
errors.report(f"Error restoring webui to commit{webui_commit_hash}")
|
||||
|
||||
|
||||
def restore_extension_config(config):
|
||||
print("* Restoring extension state...")
|
||||
|
||||
if "extensions" not in config:
|
||||
print("Error: No extension data saved to config")
|
||||
return
|
||||
|
||||
ext_config = config["extensions"]
|
||||
|
||||
results = []
|
||||
disabled = []
|
||||
|
||||
for ext in tqdm.tqdm(extensions.extensions):
|
||||
if ext.is_builtin:
|
||||
continue
|
||||
|
||||
ext.read_info_from_repo()
|
||||
current_commit = ext.commit_hash
|
||||
|
||||
if ext.name not in ext_config:
|
||||
ext.disabled = True
|
||||
disabled.append(ext.name)
|
||||
results.append((ext, current_commit[:8], False, "Saved extension state not found in config, marking as disabled"))
|
||||
continue
|
||||
|
||||
entry = ext_config[ext.name]
|
||||
|
||||
if "commit_hash" in entry and entry["commit_hash"]:
|
||||
try:
|
||||
ext.fetch_and_reset_hard(entry["commit_hash"])
|
||||
ext.read_info_from_repo()
|
||||
if current_commit != entry["commit_hash"]:
|
||||
results.append((ext, current_commit[:8], True, entry["commit_hash"][:8]))
|
||||
except Exception as ex:
|
||||
results.append((ext, current_commit[:8], False, ex))
|
||||
else:
|
||||
results.append((ext, current_commit[:8], False, "No commit hash found in config"))
|
||||
|
||||
if not entry.get("enabled", False):
|
||||
ext.disabled = True
|
||||
disabled.append(ext.name)
|
||||
else:
|
||||
ext.disabled = False
|
||||
|
||||
shared.opts.disabled_extensions = disabled
|
||||
shared.opts.save(shared.config_filename)
|
||||
|
||||
print("* Finished restoring extensions. Results:")
|
||||
for ext, prev_commit, success, result in results:
|
||||
if success:
|
||||
print(f" + {ext.name}: {prev_commit} -> {result}")
|
||||
else:
|
||||
print(f" ! {ext.name}: FAILURE ({result})")
|
||||
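A minimal sketch (not part of the commit) of taking and later restoring a snapshot of the webui and extension commits; the output filename is hypothetical.
import json

state = get_config()
with open("config_states/backup.json", "w", encoding="utf-8") as file:  # hypothetical path
    json.dump(state, file, indent=4)

# later, a saved state can be applied again:
# restore_webui_config(state)
# restore_extension_config(state)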
95
modules/dat_model.py
Executable file
95
modules/dat_model.py
Executable file
@@ -0,0 +1,95 @@
|
||||
import os
|
||||
|
||||
from modules import modelloader, errors
|
||||
from modules.shared import cmd_opts, opts
|
||||
from modules.upscaler import Upscaler, UpscalerData
|
||||
from modules.upscaler_utils import upscale_with_model
|
||||
from modules_forge.utils import prepare_free_memory
|
||||
|
||||
|
||||
class UpscalerDAT(Upscaler):
|
||||
def __init__(self, user_path):
|
||||
self.name = "DAT"
|
||||
self.user_path = user_path
|
||||
self.scalers = []
|
||||
super().__init__()
|
||||
|
||||
for file in self.find_models(ext_filter=[".pt", ".pth", ".safetensors"]):
|
||||
name = modelloader.friendly_name(file)
|
||||
scaler_data = UpscalerData(name, file, upscaler=self, scale=None)
|
||||
self.scalers.append(scaler_data)
|
||||
|
||||
for model in get_dat_models(self):
|
||||
if model.name in opts.dat_enabled_models:
|
||||
self.scalers.append(model)
|
||||
|
||||
def do_upscale(self, img, path):
|
||||
prepare_free_memory()
|
||||
try:
|
||||
info = self.load_model(path)
|
||||
except Exception:
|
||||
errors.report(f"Unable to load DAT model {path}", exc_info=True)
|
||||
return img
|
||||
|
||||
model_descriptor = modelloader.load_spandrel_model(
|
||||
info.local_data_path,
|
||||
device=self.device,
|
||||
prefer_half=(not cmd_opts.no_half and not cmd_opts.upcast_sampling),
|
||||
expected_architecture="DAT",
|
||||
)
|
||||
return upscale_with_model(
|
||||
model_descriptor,
|
||||
img,
|
||||
tile_size=opts.DAT_tile,
|
||||
tile_overlap=opts.DAT_tile_overlap,
|
||||
)
|
||||
|
||||
def load_model(self, path):
|
||||
for scaler in self.scalers:
|
||||
if scaler.data_path == path:
|
||||
if scaler.local_data_path.startswith("http"):
|
||||
scaler.local_data_path = modelloader.load_file_from_url(
|
||||
scaler.data_path,
|
||||
model_dir=self.model_download_path,
|
||||
hash_prefix=scaler.sha256,
|
||||
)
|
||||
|
||||
if os.path.getsize(scaler.local_data_path) < 200:
|
||||
# Re-download if the file is too small, probably an LFS pointer
|
||||
scaler.local_data_path = modelloader.load_file_from_url(
|
||||
scaler.data_path,
|
||||
model_dir=self.model_download_path,
|
||||
hash_prefix=scaler.sha256,
|
||||
re_download=True,
|
||||
)
|
||||
|
||||
if not os.path.exists(scaler.local_data_path):
|
||||
raise FileNotFoundError(f"DAT data missing: {scaler.local_data_path}")
|
||||
return scaler
|
||||
raise ValueError(f"Unable to find model info: {path}")
|
||||
|
||||
|
||||
def get_dat_models(scaler):
|
||||
return [
|
||||
UpscalerData(
|
||||
name="DAT x2",
|
||||
path="https://huggingface.co/w-e-w/DAT/resolve/main/experiments/pretrained_models/DAT/DAT_x2.pth",
|
||||
scale=2,
|
||||
upscaler=scaler,
|
||||
sha256='7760aa96e4ee77e29d4f89c3a4486200042e019461fdb8aa286f49aa00b89b51',
|
||||
),
|
||||
UpscalerData(
|
||||
name="DAT x3",
|
||||
path="https://huggingface.co/w-e-w/DAT/resolve/main/experiments/pretrained_models/DAT/DAT_x3.pth",
|
||||
scale=3,
|
||||
upscaler=scaler,
|
||||
sha256='581973e02c06f90d4eb90acf743ec9604f56f3c2c6f9e1e2c2b38ded1f80d197',
|
||||
),
|
||||
UpscalerData(
|
||||
name="DAT x4",
|
||||
path="https://huggingface.co/w-e-w/DAT/resolve/main/experiments/pretrained_models/DAT/DAT_x4.pth",
|
||||
scale=4,
|
||||
upscaler=scaler,
|
||||
sha256='391a6ce69899dff5ea3214557e9d585608254579217169faf3d4c353caff049e',
|
||||
),
|
||||
]
|
||||
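A minimal usage sketch (not part of the commit): instantiate the upscaler and run one of its scalers on a small PIL image; the blank test image is hypothetical.
from PIL import Image

upscaler = UpscalerDAT(user_path=None)
if upscaler.scalers:
    scaler = upscaler.scalers[0]
    result = upscaler.do_upscale(Image.new("RGB", (64, 64)), scaler.data_path)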
109
modules/deepbooru.py
Executable file
109
modules/deepbooru.py
Executable file
@@ -0,0 +1,109 @@
|
||||
import os
|
||||
import re
|
||||
|
||||
import torch
|
||||
import numpy as np
|
||||
|
||||
from modules import modelloader, paths, deepbooru_model, images, shared
|
||||
from backend import memory_management
|
||||
from backend.patcher.base import ModelPatcher
|
||||
|
||||
|
||||
re_special = re.compile(r'([\\()])')
|
||||
|
||||
|
||||
class DeepDanbooru:
|
||||
def __init__(self):
|
||||
self.model = None
|
||||
self.load_device = memory_management.text_encoder_device()
|
||||
self.offload_device = memory_management.text_encoder_offload_device()
|
||||
self.dtype = torch.float32
|
||||
|
||||
if memory_management.should_use_fp16(device=self.load_device):
|
||||
self.dtype = torch.float16
|
||||
|
||||
self.patcher = None
|
||||
|
||||
def load(self):
|
||||
if self.model is not None:
|
||||
return
|
||||
|
||||
files = modelloader.load_models(
|
||||
model_path=os.path.join(paths.models_path, "torch_deepdanbooru"),
|
||||
model_url='https://github.com/AUTOMATIC1111/TorchDeepDanbooru/releases/download/v1/model-resnet_custom_v3.pt',
|
||||
ext_filter=[".pt"],
|
||||
download_name='model-resnet_custom_v3.pt',
|
||||
)
|
||||
|
||||
self.model = deepbooru_model.DeepDanbooruModel()
|
||||
self.model.load_state_dict(torch.load(files[0], map_location="cpu"))
|
||||
|
||||
self.model.eval()
|
||||
self.model.to(self.offload_device, self.dtype)
|
||||
|
||||
self.patcher = ModelPatcher(self.model, load_device=self.load_device, offload_device=self.offload_device)
|
||||
|
||||
def start(self):
|
||||
self.load()
|
||||
memory_management.load_models_gpu([self.patcher])
|
||||
|
||||
def stop(self):
|
||||
pass
|
||||
|
||||
def tag(self, pil_image):
|
||||
self.start()
|
||||
res = self.tag_multi(pil_image)
|
||||
self.stop()
|
||||
|
||||
return res
|
||||
|
||||
def tag_multi(self, pil_image, force_disable_ranks=False):
|
||||
threshold = shared.opts.interrogate_deepbooru_score_threshold
|
||||
use_spaces = shared.opts.deepbooru_use_spaces
|
||||
use_escape = shared.opts.deepbooru_escape
|
||||
alpha_sort = shared.opts.deepbooru_sort_alpha
|
||||
include_ranks = shared.opts.interrogate_return_ranks and not force_disable_ranks
|
||||
|
||||
pic = images.resize_image(2, pil_image.convert("RGB"), 512, 512)
|
||||
a = np.expand_dims(np.array(pic, dtype=np.float32), 0) / 255
|
||||
|
||||
with torch.no_grad():
|
||||
x = torch.from_numpy(a).to(self.load_device, self.dtype)
|
||||
y = self.model(x)[0].detach().cpu().numpy()
|
||||
|
||||
probability_dict = {}
|
||||
|
||||
for tag, probability in zip(self.model.tags, y):
|
||||
if probability < threshold:
|
||||
continue
|
||||
|
||||
if tag.startswith("rating:"):
|
||||
continue
|
||||
|
||||
probability_dict[tag] = probability
|
||||
|
||||
if alpha_sort:
|
||||
tags = sorted(probability_dict)
|
||||
else:
|
||||
tags = [tag for tag, _ in sorted(probability_dict.items(), key=lambda x: -x[1])]
|
||||
|
||||
res = []
|
||||
|
||||
filtertags = {x.strip().replace(' ', '_') for x in shared.opts.deepbooru_filter_tags.split(",")}
|
||||
|
||||
for tag in [x for x in tags if x not in filtertags]:
|
||||
probability = probability_dict[tag]
|
||||
tag_outformat = tag
|
||||
if use_spaces:
|
||||
tag_outformat = tag_outformat.replace('_', ' ')
|
||||
if use_escape:
|
||||
tag_outformat = re.sub(re_special, r'\\\1', tag_outformat)
|
||||
if include_ranks:
|
||||
tag_outformat = f"({tag_outformat}:{probability:.3f})"
|
||||
|
||||
res.append(tag_outformat)
|
||||
|
||||
return ", ".join(res)
|
||||
|
||||
|
||||
model = DeepDanbooru()
|
||||
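A minimal usage sketch (not part of the commit): tag an image with the shared DeepDanbooru instance defined above; the image path and example output are hypothetical.
from PIL import Image

tags = model.tag(Image.open("example.png"))  # hypothetical input image
print(tags)  # e.g. "sky, cloud, outdoors"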
678
modules/deepbooru_model.py
Executable file
678
modules/deepbooru_model.py
Executable file
@@ -0,0 +1,678 @@
|
||||
import torch
|
||||
import torch.nn as nn
|
||||
import torch.nn.functional as F
|
||||
|
||||
from modules import devices
|
||||
|
||||
# see https://github.com/AUTOMATIC1111/TorchDeepDanbooru for more
|
||||
|
||||
|
||||
class DeepDanbooruModel(nn.Module):
|
||||
def __init__(self):
|
||||
super(DeepDanbooruModel, self).__init__()
|
||||
|
||||
self.tags = []
|
||||
|
||||
self.n_Conv_0 = nn.Conv2d(kernel_size=(7, 7), in_channels=3, out_channels=64, stride=(2, 2))
|
||||
self.n_MaxPool_0 = nn.MaxPool2d(kernel_size=(3, 3), stride=(2, 2))
|
||||
self.n_Conv_1 = nn.Conv2d(kernel_size=(1, 1), in_channels=64, out_channels=256)
|
||||
self.n_Conv_2 = nn.Conv2d(kernel_size=(1, 1), in_channels=64, out_channels=64)
|
||||
self.n_Conv_3 = nn.Conv2d(kernel_size=(3, 3), in_channels=64, out_channels=64)
|
||||
self.n_Conv_4 = nn.Conv2d(kernel_size=(1, 1), in_channels=64, out_channels=256)
|
||||
self.n_Conv_5 = nn.Conv2d(kernel_size=(1, 1), in_channels=256, out_channels=64)
|
||||
self.n_Conv_6 = nn.Conv2d(kernel_size=(3, 3), in_channels=64, out_channels=64)
|
||||
self.n_Conv_7 = nn.Conv2d(kernel_size=(1, 1), in_channels=64, out_channels=256)
|
||||
self.n_Conv_8 = nn.Conv2d(kernel_size=(1, 1), in_channels=256, out_channels=64)
|
||||
self.n_Conv_9 = nn.Conv2d(kernel_size=(3, 3), in_channels=64, out_channels=64)
|
||||
self.n_Conv_10 = nn.Conv2d(kernel_size=(1, 1), in_channels=64, out_channels=256)
|
||||
self.n_Conv_11 = nn.Conv2d(kernel_size=(1, 1), in_channels=256, out_channels=512, stride=(2, 2))
|
||||
self.n_Conv_12 = nn.Conv2d(kernel_size=(1, 1), in_channels=256, out_channels=128)
|
||||
self.n_Conv_13 = nn.Conv2d(kernel_size=(3, 3), in_channels=128, out_channels=128, stride=(2, 2))
|
||||
self.n_Conv_14 = nn.Conv2d(kernel_size=(1, 1), in_channels=128, out_channels=512)
|
||||
self.n_Conv_15 = nn.Conv2d(kernel_size=(1, 1), in_channels=512, out_channels=128)
|
||||
self.n_Conv_16 = nn.Conv2d(kernel_size=(3, 3), in_channels=128, out_channels=128)
|
||||
self.n_Conv_17 = nn.Conv2d(kernel_size=(1, 1), in_channels=128, out_channels=512)
|
||||
self.n_Conv_18 = nn.Conv2d(kernel_size=(1, 1), in_channels=512, out_channels=128)
|
||||
self.n_Conv_19 = nn.Conv2d(kernel_size=(3, 3), in_channels=128, out_channels=128)
|
||||
self.n_Conv_20 = nn.Conv2d(kernel_size=(1, 1), in_channels=128, out_channels=512)
|
||||
self.n_Conv_21 = nn.Conv2d(kernel_size=(1, 1), in_channels=512, out_channels=128)
|
||||
self.n_Conv_22 = nn.Conv2d(kernel_size=(3, 3), in_channels=128, out_channels=128)
|
||||
self.n_Conv_23 = nn.Conv2d(kernel_size=(1, 1), in_channels=128, out_channels=512)
|
||||
self.n_Conv_24 = nn.Conv2d(kernel_size=(1, 1), in_channels=512, out_channels=128)
|
||||
self.n_Conv_25 = nn.Conv2d(kernel_size=(3, 3), in_channels=128, out_channels=128)
|
||||
self.n_Conv_26 = nn.Conv2d(kernel_size=(1, 1), in_channels=128, out_channels=512)
|
||||
self.n_Conv_27 = nn.Conv2d(kernel_size=(1, 1), in_channels=512, out_channels=128)
|
||||
self.n_Conv_28 = nn.Conv2d(kernel_size=(3, 3), in_channels=128, out_channels=128)
|
||||
self.n_Conv_29 = nn.Conv2d(kernel_size=(1, 1), in_channels=128, out_channels=512)
|
||||
self.n_Conv_30 = nn.Conv2d(kernel_size=(1, 1), in_channels=512, out_channels=128)
|
||||
self.n_Conv_31 = nn.Conv2d(kernel_size=(3, 3), in_channels=128, out_channels=128)
|
||||
self.n_Conv_32 = nn.Conv2d(kernel_size=(1, 1), in_channels=128, out_channels=512)
|
||||
self.n_Conv_33 = nn.Conv2d(kernel_size=(1, 1), in_channels=512, out_channels=128)
|
||||
self.n_Conv_34 = nn.Conv2d(kernel_size=(3, 3), in_channels=128, out_channels=128)
|
||||
self.n_Conv_35 = nn.Conv2d(kernel_size=(1, 1), in_channels=128, out_channels=512)
|
||||
self.n_Conv_36 = nn.Conv2d(kernel_size=(1, 1), in_channels=512, out_channels=1024, stride=(2, 2))
|
||||
self.n_Conv_37 = nn.Conv2d(kernel_size=(1, 1), in_channels=512, out_channels=256)
|
||||
self.n_Conv_38 = nn.Conv2d(kernel_size=(3, 3), in_channels=256, out_channels=256, stride=(2, 2))
|
||||
self.n_Conv_39 = nn.Conv2d(kernel_size=(1, 1), in_channels=256, out_channels=1024)
|
||||
self.n_Conv_40 = nn.Conv2d(kernel_size=(1, 1), in_channels=1024, out_channels=256)
|
||||
self.n_Conv_41 = nn.Conv2d(kernel_size=(3, 3), in_channels=256, out_channels=256)
|
||||
self.n_Conv_42 = nn.Conv2d(kernel_size=(1, 1), in_channels=256, out_channels=1024)
|
||||
self.n_Conv_43 = nn.Conv2d(kernel_size=(1, 1), in_channels=1024, out_channels=256)
|
||||
self.n_Conv_44 = nn.Conv2d(kernel_size=(3, 3), in_channels=256, out_channels=256)
|
||||
self.n_Conv_45 = nn.Conv2d(kernel_size=(1, 1), in_channels=256, out_channels=1024)
|
||||
self.n_Conv_46 = nn.Conv2d(kernel_size=(1, 1), in_channels=1024, out_channels=256)
|
||||
self.n_Conv_47 = nn.Conv2d(kernel_size=(3, 3), in_channels=256, out_channels=256)
|
||||
self.n_Conv_48 = nn.Conv2d(kernel_size=(1, 1), in_channels=256, out_channels=1024)
|
||||
self.n_Conv_49 = nn.Conv2d(kernel_size=(1, 1), in_channels=1024, out_channels=256)
|
||||
self.n_Conv_50 = nn.Conv2d(kernel_size=(3, 3), in_channels=256, out_channels=256)
|
||||
self.n_Conv_51 = nn.Conv2d(kernel_size=(1, 1), in_channels=256, out_channels=1024)
|
||||
self.n_Conv_52 = nn.Conv2d(kernel_size=(1, 1), in_channels=1024, out_channels=256)
|
||||
self.n_Conv_53 = nn.Conv2d(kernel_size=(3, 3), in_channels=256, out_channels=256)
|
||||
self.n_Conv_54 = nn.Conv2d(kernel_size=(1, 1), in_channels=256, out_channels=1024)
|
||||
self.n_Conv_55 = nn.Conv2d(kernel_size=(1, 1), in_channels=1024, out_channels=256)
|
||||
self.n_Conv_56 = nn.Conv2d(kernel_size=(3, 3), in_channels=256, out_channels=256)
|
||||
self.n_Conv_57 = nn.Conv2d(kernel_size=(1, 1), in_channels=256, out_channels=1024)
|
||||
self.n_Conv_58 = nn.Conv2d(kernel_size=(1, 1), in_channels=1024, out_channels=256)
|
||||
self.n_Conv_59 = nn.Conv2d(kernel_size=(3, 3), in_channels=256, out_channels=256)
|
||||
self.n_Conv_60 = nn.Conv2d(kernel_size=(1, 1), in_channels=256, out_channels=1024)
|
||||
self.n_Conv_61 = nn.Conv2d(kernel_size=(1, 1), in_channels=1024, out_channels=256)
|
||||
self.n_Conv_62 = nn.Conv2d(kernel_size=(3, 3), in_channels=256, out_channels=256)
|
||||
self.n_Conv_63 = nn.Conv2d(kernel_size=(1, 1), in_channels=256, out_channels=1024)
|
||||
self.n_Conv_64 = nn.Conv2d(kernel_size=(1, 1), in_channels=1024, out_channels=256)
|
||||
self.n_Conv_65 = nn.Conv2d(kernel_size=(3, 3), in_channels=256, out_channels=256)
|
||||
self.n_Conv_66 = nn.Conv2d(kernel_size=(1, 1), in_channels=256, out_channels=1024)
|
||||
self.n_Conv_67 = nn.Conv2d(kernel_size=(1, 1), in_channels=1024, out_channels=256)
|
||||
self.n_Conv_68 = nn.Conv2d(kernel_size=(3, 3), in_channels=256, out_channels=256)
|
||||
self.n_Conv_69 = nn.Conv2d(kernel_size=(1, 1), in_channels=256, out_channels=1024)
|
||||
self.n_Conv_70 = nn.Conv2d(kernel_size=(1, 1), in_channels=1024, out_channels=256)
|
||||
self.n_Conv_71 = nn.Conv2d(kernel_size=(3, 3), in_channels=256, out_channels=256)
|
||||
self.n_Conv_72 = nn.Conv2d(kernel_size=(1, 1), in_channels=256, out_channels=1024)
|
||||
self.n_Conv_73 = nn.Conv2d(kernel_size=(1, 1), in_channels=1024, out_channels=256)
|
||||
self.n_Conv_74 = nn.Conv2d(kernel_size=(3, 3), in_channels=256, out_channels=256)
|
||||
self.n_Conv_75 = nn.Conv2d(kernel_size=(1, 1), in_channels=256, out_channels=1024)
|
||||
self.n_Conv_76 = nn.Conv2d(kernel_size=(1, 1), in_channels=1024, out_channels=256)
|
||||
self.n_Conv_77 = nn.Conv2d(kernel_size=(3, 3), in_channels=256, out_channels=256)
|
||||
self.n_Conv_78 = nn.Conv2d(kernel_size=(1, 1), in_channels=256, out_channels=1024)
|
||||
self.n_Conv_79 = nn.Conv2d(kernel_size=(1, 1), in_channels=1024, out_channels=256)
|
||||
self.n_Conv_80 = nn.Conv2d(kernel_size=(3, 3), in_channels=256, out_channels=256)
|
||||
self.n_Conv_81 = nn.Conv2d(kernel_size=(1, 1), in_channels=256, out_channels=1024)
|
||||
self.n_Conv_82 = nn.Conv2d(kernel_size=(1, 1), in_channels=1024, out_channels=256)
|
||||
self.n_Conv_83 = nn.Conv2d(kernel_size=(3, 3), in_channels=256, out_channels=256)
|
||||
self.n_Conv_84 = nn.Conv2d(kernel_size=(1, 1), in_channels=256, out_channels=1024)
|
||||
self.n_Conv_85 = nn.Conv2d(kernel_size=(1, 1), in_channels=1024, out_channels=256)
|
||||
self.n_Conv_86 = nn.Conv2d(kernel_size=(3, 3), in_channels=256, out_channels=256)
|
||||
self.n_Conv_87 = nn.Conv2d(kernel_size=(1, 1), in_channels=256, out_channels=1024)
|
||||
self.n_Conv_88 = nn.Conv2d(kernel_size=(1, 1), in_channels=1024, out_channels=256)
|
||||
self.n_Conv_89 = nn.Conv2d(kernel_size=(3, 3), in_channels=256, out_channels=256)
|
||||
self.n_Conv_90 = nn.Conv2d(kernel_size=(1, 1), in_channels=256, out_channels=1024)
|
||||
self.n_Conv_91 = nn.Conv2d(kernel_size=(1, 1), in_channels=1024, out_channels=256)
|
||||
self.n_Conv_92 = nn.Conv2d(kernel_size=(3, 3), in_channels=256, out_channels=256)
|
||||
self.n_Conv_93 = nn.Conv2d(kernel_size=(1, 1), in_channels=256, out_channels=1024)
|
||||
self.n_Conv_94 = nn.Conv2d(kernel_size=(1, 1), in_channels=1024, out_channels=256)
|
||||
self.n_Conv_95 = nn.Conv2d(kernel_size=(3, 3), in_channels=256, out_channels=256)
|
||||
self.n_Conv_96 = nn.Conv2d(kernel_size=(1, 1), in_channels=256, out_channels=1024)
|
||||
self.n_Conv_97 = nn.Conv2d(kernel_size=(1, 1), in_channels=1024, out_channels=256)
|
||||
self.n_Conv_98 = nn.Conv2d(kernel_size=(3, 3), in_channels=256, out_channels=256, stride=(2, 2))
|
||||
self.n_Conv_99 = nn.Conv2d(kernel_size=(1, 1), in_channels=256, out_channels=1024)
|
||||
self.n_Conv_100 = nn.Conv2d(kernel_size=(1, 1), in_channels=1024, out_channels=1024, stride=(2, 2))
|
||||
self.n_Conv_101 = nn.Conv2d(kernel_size=(1, 1), in_channels=1024, out_channels=256)
|
||||
self.n_Conv_102 = nn.Conv2d(kernel_size=(3, 3), in_channels=256, out_channels=256)
|
||||
self.n_Conv_103 = nn.Conv2d(kernel_size=(1, 1), in_channels=256, out_channels=1024)
|
||||
self.n_Conv_104 = nn.Conv2d(kernel_size=(1, 1), in_channels=1024, out_channels=256)
|
||||
self.n_Conv_105 = nn.Conv2d(kernel_size=(3, 3), in_channels=256, out_channels=256)
|
||||
self.n_Conv_106 = nn.Conv2d(kernel_size=(1, 1), in_channels=256, out_channels=1024)
|
||||
self.n_Conv_107 = nn.Conv2d(kernel_size=(1, 1), in_channels=1024, out_channels=256)
|
||||
self.n_Conv_108 = nn.Conv2d(kernel_size=(3, 3), in_channels=256, out_channels=256)
|
||||
self.n_Conv_109 = nn.Conv2d(kernel_size=(1, 1), in_channels=256, out_channels=1024)
|
||||
self.n_Conv_110 = nn.Conv2d(kernel_size=(1, 1), in_channels=1024, out_channels=256)
|
||||
self.n_Conv_111 = nn.Conv2d(kernel_size=(3, 3), in_channels=256, out_channels=256)
|
||||
self.n_Conv_112 = nn.Conv2d(kernel_size=(1, 1), in_channels=256, out_channels=1024)
|
||||
self.n_Conv_113 = nn.Conv2d(kernel_size=(1, 1), in_channels=1024, out_channels=256)
|
||||
self.n_Conv_114 = nn.Conv2d(kernel_size=(3, 3), in_channels=256, out_channels=256)
|
||||
self.n_Conv_115 = nn.Conv2d(kernel_size=(1, 1), in_channels=256, out_channels=1024)
|
||||
self.n_Conv_116 = nn.Conv2d(kernel_size=(1, 1), in_channels=1024, out_channels=256)
|
||||
self.n_Conv_117 = nn.Conv2d(kernel_size=(3, 3), in_channels=256, out_channels=256)
|
||||
self.n_Conv_118 = nn.Conv2d(kernel_size=(1, 1), in_channels=256, out_channels=1024)
|
||||
self.n_Conv_119 = nn.Conv2d(kernel_size=(1, 1), in_channels=1024, out_channels=256)
|
||||
self.n_Conv_120 = nn.Conv2d(kernel_size=(3, 3), in_channels=256, out_channels=256)
|
||||
self.n_Conv_121 = nn.Conv2d(kernel_size=(1, 1), in_channels=256, out_channels=1024)
|
||||
self.n_Conv_122 = nn.Conv2d(kernel_size=(1, 1), in_channels=1024, out_channels=256)
|
||||
self.n_Conv_123 = nn.Conv2d(kernel_size=(3, 3), in_channels=256, out_channels=256)
|
||||
self.n_Conv_124 = nn.Conv2d(kernel_size=(1, 1), in_channels=256, out_channels=1024)
|
||||
self.n_Conv_125 = nn.Conv2d(kernel_size=(1, 1), in_channels=1024, out_channels=256)
|
||||
self.n_Conv_126 = nn.Conv2d(kernel_size=(3, 3), in_channels=256, out_channels=256)
|
||||
self.n_Conv_127 = nn.Conv2d(kernel_size=(1, 1), in_channels=256, out_channels=1024)
|
||||
self.n_Conv_128 = nn.Conv2d(kernel_size=(1, 1), in_channels=1024, out_channels=256)
|
||||
self.n_Conv_129 = nn.Conv2d(kernel_size=(3, 3), in_channels=256, out_channels=256)
|
||||
self.n_Conv_130 = nn.Conv2d(kernel_size=(1, 1), in_channels=256, out_channels=1024)
|
||||
self.n_Conv_131 = nn.Conv2d(kernel_size=(1, 1), in_channels=1024, out_channels=256)
|
||||
self.n_Conv_132 = nn.Conv2d(kernel_size=(3, 3), in_channels=256, out_channels=256)
|
||||
self.n_Conv_133 = nn.Conv2d(kernel_size=(1, 1), in_channels=256, out_channels=1024)
|
||||
self.n_Conv_134 = nn.Conv2d(kernel_size=(1, 1), in_channels=1024, out_channels=256)
|
||||
self.n_Conv_135 = nn.Conv2d(kernel_size=(3, 3), in_channels=256, out_channels=256)
|
||||
self.n_Conv_136 = nn.Conv2d(kernel_size=(1, 1), in_channels=256, out_channels=1024)
|
||||
self.n_Conv_137 = nn.Conv2d(kernel_size=(1, 1), in_channels=1024, out_channels=256)
|
||||
self.n_Conv_138 = nn.Conv2d(kernel_size=(3, 3), in_channels=256, out_channels=256)
|
||||
self.n_Conv_139 = nn.Conv2d(kernel_size=(1, 1), in_channels=256, out_channels=1024)
|
||||
self.n_Conv_140 = nn.Conv2d(kernel_size=(1, 1), in_channels=1024, out_channels=256)
|
||||
self.n_Conv_141 = nn.Conv2d(kernel_size=(3, 3), in_channels=256, out_channels=256)
|
||||
self.n_Conv_142 = nn.Conv2d(kernel_size=(1, 1), in_channels=256, out_channels=1024)
|
||||
self.n_Conv_143 = nn.Conv2d(kernel_size=(1, 1), in_channels=1024, out_channels=256)
|
||||
self.n_Conv_144 = nn.Conv2d(kernel_size=(3, 3), in_channels=256, out_channels=256)
|
||||
self.n_Conv_145 = nn.Conv2d(kernel_size=(1, 1), in_channels=256, out_channels=1024)
|
||||
self.n_Conv_146 = nn.Conv2d(kernel_size=(1, 1), in_channels=1024, out_channels=256)
|
||||
self.n_Conv_147 = nn.Conv2d(kernel_size=(3, 3), in_channels=256, out_channels=256)
|
||||
self.n_Conv_148 = nn.Conv2d(kernel_size=(1, 1), in_channels=256, out_channels=1024)
|
||||
self.n_Conv_149 = nn.Conv2d(kernel_size=(1, 1), in_channels=1024, out_channels=256)
|
||||
self.n_Conv_150 = nn.Conv2d(kernel_size=(3, 3), in_channels=256, out_channels=256)
|
||||
self.n_Conv_151 = nn.Conv2d(kernel_size=(1, 1), in_channels=256, out_channels=1024)
|
||||
self.n_Conv_152 = nn.Conv2d(kernel_size=(1, 1), in_channels=1024, out_channels=256)
|
||||
self.n_Conv_153 = nn.Conv2d(kernel_size=(3, 3), in_channels=256, out_channels=256)
|
||||
self.n_Conv_154 = nn.Conv2d(kernel_size=(1, 1), in_channels=256, out_channels=1024)
|
||||
self.n_Conv_155 = nn.Conv2d(kernel_size=(1, 1), in_channels=1024, out_channels=256)
|
||||
self.n_Conv_156 = nn.Conv2d(kernel_size=(3, 3), in_channels=256, out_channels=256)
|
||||
self.n_Conv_157 = nn.Conv2d(kernel_size=(1, 1), in_channels=256, out_channels=1024)
|
||||
self.n_Conv_158 = nn.Conv2d(kernel_size=(1, 1), in_channels=1024, out_channels=2048, stride=(2, 2))
|
||||
self.n_Conv_159 = nn.Conv2d(kernel_size=(1, 1), in_channels=1024, out_channels=512)
|
||||
self.n_Conv_160 = nn.Conv2d(kernel_size=(3, 3), in_channels=512, out_channels=512, stride=(2, 2))
|
||||
self.n_Conv_161 = nn.Conv2d(kernel_size=(1, 1), in_channels=512, out_channels=2048)
|
||||
self.n_Conv_162 = nn.Conv2d(kernel_size=(1, 1), in_channels=2048, out_channels=512)
|
||||
self.n_Conv_163 = nn.Conv2d(kernel_size=(3, 3), in_channels=512, out_channels=512)
|
||||
self.n_Conv_164 = nn.Conv2d(kernel_size=(1, 1), in_channels=512, out_channels=2048)
|
||||
self.n_Conv_165 = nn.Conv2d(kernel_size=(1, 1), in_channels=2048, out_channels=512)
|
||||
self.n_Conv_166 = nn.Conv2d(kernel_size=(3, 3), in_channels=512, out_channels=512)
|
||||
self.n_Conv_167 = nn.Conv2d(kernel_size=(1, 1), in_channels=512, out_channels=2048)
|
||||
self.n_Conv_168 = nn.Conv2d(kernel_size=(1, 1), in_channels=2048, out_channels=4096, stride=(2, 2))
|
||||
self.n_Conv_169 = nn.Conv2d(kernel_size=(1, 1), in_channels=2048, out_channels=1024)
|
||||
self.n_Conv_170 = nn.Conv2d(kernel_size=(3, 3), in_channels=1024, out_channels=1024, stride=(2, 2))
|
||||
self.n_Conv_171 = nn.Conv2d(kernel_size=(1, 1), in_channels=1024, out_channels=4096)
|
||||
self.n_Conv_172 = nn.Conv2d(kernel_size=(1, 1), in_channels=4096, out_channels=1024)
|
||||
self.n_Conv_173 = nn.Conv2d(kernel_size=(3, 3), in_channels=1024, out_channels=1024)
|
||||
self.n_Conv_174 = nn.Conv2d(kernel_size=(1, 1), in_channels=1024, out_channels=4096)
|
||||
self.n_Conv_175 = nn.Conv2d(kernel_size=(1, 1), in_channels=4096, out_channels=1024)
|
||||
self.n_Conv_176 = nn.Conv2d(kernel_size=(3, 3), in_channels=1024, out_channels=1024)
|
||||
self.n_Conv_177 = nn.Conv2d(kernel_size=(1, 1), in_channels=1024, out_channels=4096)
|
||||
self.n_Conv_178 = nn.Conv2d(kernel_size=(1, 1), in_channels=4096, out_channels=9176, bias=False)
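# Note (descriptive comment, not in the original): this final 1x1 convolution acts as the
# classification head; its 9176 output channels correspond one-to-one with the entries of
# self.tags that load_state_dict() below pulls out of the checkpoint.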
|
||||
|
||||
def forward(self, *inputs):
|
||||
t_358, = inputs
|
||||
t_359 = t_358.permute(*[0, 3, 1, 2])
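# descriptive comment: the model expects NHWC input and converts it to NCHW here before the first convolution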
|
||||
t_359_padded = F.pad(t_359, [2, 3, 2, 3], value=0)
|
||||
t_360 = self.n_Conv_0(t_359_padded.to(self.n_Conv_0.bias.dtype) if devices.unet_needs_upcast else t_359_padded)
|
||||
t_361 = F.relu(t_360)
|
||||
t_361 = F.pad(t_361, [0, 1, 0, 1], value=float('-inf'))
|
||||
t_362 = self.n_MaxPool_0(t_361)
|
||||
t_363 = self.n_Conv_1(t_362)
|
||||
t_364 = self.n_Conv_2(t_362)
|
||||
t_365 = F.relu(t_364)
|
||||
t_365_padded = F.pad(t_365, [1, 1, 1, 1], value=0)
|
||||
t_366 = self.n_Conv_3(t_365_padded)
|
||||
t_367 = F.relu(t_366)
|
||||
t_368 = self.n_Conv_4(t_367)
|
||||
t_369 = torch.add(t_368, t_363)
|
||||
t_370 = F.relu(t_369)
|
||||
t_371 = self.n_Conv_5(t_370)
|
||||
t_372 = F.relu(t_371)
|
||||
t_372_padded = F.pad(t_372, [1, 1, 1, 1], value=0)
|
||||
t_373 = self.n_Conv_6(t_372_padded)
|
||||
t_374 = F.relu(t_373)
|
||||
t_375 = self.n_Conv_7(t_374)
|
||||
t_376 = torch.add(t_375, t_370)
|
||||
t_377 = F.relu(t_376)
|
||||
t_378 = self.n_Conv_8(t_377)
|
||||
t_379 = F.relu(t_378)
|
||||
t_379_padded = F.pad(t_379, [1, 1, 1, 1], value=0)
|
||||
t_380 = self.n_Conv_9(t_379_padded)
|
||||
t_381 = F.relu(t_380)
|
||||
t_382 = self.n_Conv_10(t_381)
|
||||
t_383 = torch.add(t_382, t_377)
|
||||
t_384 = F.relu(t_383)
|
||||
t_385 = self.n_Conv_11(t_384)
|
||||
t_386 = self.n_Conv_12(t_384)
|
||||
t_387 = F.relu(t_386)
|
||||
t_387_padded = F.pad(t_387, [0, 1, 0, 1], value=0)
|
||||
t_388 = self.n_Conv_13(t_387_padded)
|
||||
t_389 = F.relu(t_388)
|
||||
t_390 = self.n_Conv_14(t_389)
|
||||
t_391 = torch.add(t_390, t_385)
|
||||
t_392 = F.relu(t_391)
|
||||
t_393 = self.n_Conv_15(t_392)
|
||||
t_394 = F.relu(t_393)
|
||||
t_394_padded = F.pad(t_394, [1, 1, 1, 1], value=0)
|
||||
t_395 = self.n_Conv_16(t_394_padded)
|
||||
t_396 = F.relu(t_395)
|
||||
t_397 = self.n_Conv_17(t_396)
|
||||
t_398 = torch.add(t_397, t_392)
|
||||
t_399 = F.relu(t_398)
|
||||
t_400 = self.n_Conv_18(t_399)
|
||||
t_401 = F.relu(t_400)
|
||||
t_401_padded = F.pad(t_401, [1, 1, 1, 1], value=0)
|
||||
t_402 = self.n_Conv_19(t_401_padded)
|
||||
t_403 = F.relu(t_402)
|
||||
t_404 = self.n_Conv_20(t_403)
|
||||
t_405 = torch.add(t_404, t_399)
|
||||
t_406 = F.relu(t_405)
|
||||
t_407 = self.n_Conv_21(t_406)
|
||||
t_408 = F.relu(t_407)
|
||||
t_408_padded = F.pad(t_408, [1, 1, 1, 1], value=0)
|
||||
t_409 = self.n_Conv_22(t_408_padded)
|
||||
t_410 = F.relu(t_409)
|
||||
t_411 = self.n_Conv_23(t_410)
|
||||
t_412 = torch.add(t_411, t_406)
|
||||
t_413 = F.relu(t_412)
|
||||
t_414 = self.n_Conv_24(t_413)
|
||||
t_415 = F.relu(t_414)
|
||||
t_415_padded = F.pad(t_415, [1, 1, 1, 1], value=0)
|
||||
t_416 = self.n_Conv_25(t_415_padded)
|
||||
t_417 = F.relu(t_416)
|
||||
t_418 = self.n_Conv_26(t_417)
|
||||
t_419 = torch.add(t_418, t_413)
|
||||
t_420 = F.relu(t_419)
|
||||
t_421 = self.n_Conv_27(t_420)
|
||||
t_422 = F.relu(t_421)
|
||||
t_422_padded = F.pad(t_422, [1, 1, 1, 1], value=0)
|
||||
t_423 = self.n_Conv_28(t_422_padded)
|
||||
t_424 = F.relu(t_423)
|
||||
t_425 = self.n_Conv_29(t_424)
|
||||
t_426 = torch.add(t_425, t_420)
|
||||
t_427 = F.relu(t_426)
|
||||
t_428 = self.n_Conv_30(t_427)
|
||||
t_429 = F.relu(t_428)
|
||||
t_429_padded = F.pad(t_429, [1, 1, 1, 1], value=0)
|
||||
t_430 = self.n_Conv_31(t_429_padded)
|
||||
t_431 = F.relu(t_430)
|
||||
t_432 = self.n_Conv_32(t_431)
|
||||
t_433 = torch.add(t_432, t_427)
|
||||
t_434 = F.relu(t_433)
|
||||
t_435 = self.n_Conv_33(t_434)
|
||||
t_436 = F.relu(t_435)
|
||||
t_436_padded = F.pad(t_436, [1, 1, 1, 1], value=0)
|
||||
t_437 = self.n_Conv_34(t_436_padded)
|
||||
t_438 = F.relu(t_437)
|
||||
t_439 = self.n_Conv_35(t_438)
|
||||
t_440 = torch.add(t_439, t_434)
|
||||
t_441 = F.relu(t_440)
|
||||
t_442 = self.n_Conv_36(t_441)
|
||||
t_443 = self.n_Conv_37(t_441)
|
||||
t_444 = F.relu(t_443)
|
||||
t_444_padded = F.pad(t_444, [0, 1, 0, 1], value=0)
|
||||
t_445 = self.n_Conv_38(t_444_padded)
|
||||
t_446 = F.relu(t_445)
|
||||
t_447 = self.n_Conv_39(t_446)
|
||||
t_448 = torch.add(t_447, t_442)
|
||||
t_449 = F.relu(t_448)
|
||||
t_450 = self.n_Conv_40(t_449)
|
||||
t_451 = F.relu(t_450)
|
||||
t_451_padded = F.pad(t_451, [1, 1, 1, 1], value=0)
|
||||
t_452 = self.n_Conv_41(t_451_padded)
|
||||
t_453 = F.relu(t_452)
|
||||
t_454 = self.n_Conv_42(t_453)
|
||||
t_455 = torch.add(t_454, t_449)
|
||||
t_456 = F.relu(t_455)
|
||||
t_457 = self.n_Conv_43(t_456)
|
||||
t_458 = F.relu(t_457)
|
||||
t_458_padded = F.pad(t_458, [1, 1, 1, 1], value=0)
|
||||
t_459 = self.n_Conv_44(t_458_padded)
|
||||
t_460 = F.relu(t_459)
|
||||
t_461 = self.n_Conv_45(t_460)
|
||||
t_462 = torch.add(t_461, t_456)
|
||||
t_463 = F.relu(t_462)
|
||||
t_464 = self.n_Conv_46(t_463)
|
||||
t_465 = F.relu(t_464)
|
||||
t_465_padded = F.pad(t_465, [1, 1, 1, 1], value=0)
|
||||
t_466 = self.n_Conv_47(t_465_padded)
|
||||
t_467 = F.relu(t_466)
|
||||
t_468 = self.n_Conv_48(t_467)
|
||||
t_469 = torch.add(t_468, t_463)
|
||||
t_470 = F.relu(t_469)
|
||||
t_471 = self.n_Conv_49(t_470)
|
||||
t_472 = F.relu(t_471)
|
||||
t_472_padded = F.pad(t_472, [1, 1, 1, 1], value=0)
|
||||
t_473 = self.n_Conv_50(t_472_padded)
|
||||
t_474 = F.relu(t_473)
|
||||
t_475 = self.n_Conv_51(t_474)
|
||||
t_476 = torch.add(t_475, t_470)
|
||||
t_477 = F.relu(t_476)
|
||||
t_478 = self.n_Conv_52(t_477)
|
||||
t_479 = F.relu(t_478)
|
||||
t_479_padded = F.pad(t_479, [1, 1, 1, 1], value=0)
|
||||
t_480 = self.n_Conv_53(t_479_padded)
|
||||
t_481 = F.relu(t_480)
|
||||
t_482 = self.n_Conv_54(t_481)
|
||||
t_483 = torch.add(t_482, t_477)
|
||||
t_484 = F.relu(t_483)
|
||||
t_485 = self.n_Conv_55(t_484)
|
||||
t_486 = F.relu(t_485)
|
||||
t_486_padded = F.pad(t_486, [1, 1, 1, 1], value=0)
|
||||
t_487 = self.n_Conv_56(t_486_padded)
|
||||
t_488 = F.relu(t_487)
|
||||
t_489 = self.n_Conv_57(t_488)
|
||||
t_490 = torch.add(t_489, t_484)
|
||||
t_491 = F.relu(t_490)
|
||||
t_492 = self.n_Conv_58(t_491)
|
||||
t_493 = F.relu(t_492)
|
||||
t_493_padded = F.pad(t_493, [1, 1, 1, 1], value=0)
|
||||
t_494 = self.n_Conv_59(t_493_padded)
|
||||
t_495 = F.relu(t_494)
|
||||
t_496 = self.n_Conv_60(t_495)
|
||||
t_497 = torch.add(t_496, t_491)
|
||||
t_498 = F.relu(t_497)
|
||||
t_499 = self.n_Conv_61(t_498)
|
||||
t_500 = F.relu(t_499)
|
||||
t_500_padded = F.pad(t_500, [1, 1, 1, 1], value=0)
|
||||
t_501 = self.n_Conv_62(t_500_padded)
|
||||
t_502 = F.relu(t_501)
|
||||
t_503 = self.n_Conv_63(t_502)
|
||||
t_504 = torch.add(t_503, t_498)
|
||||
t_505 = F.relu(t_504)
|
||||
t_506 = self.n_Conv_64(t_505)
|
||||
t_507 = F.relu(t_506)
|
||||
t_507_padded = F.pad(t_507, [1, 1, 1, 1], value=0)
|
||||
t_508 = self.n_Conv_65(t_507_padded)
|
||||
t_509 = F.relu(t_508)
|
||||
t_510 = self.n_Conv_66(t_509)
|
||||
t_511 = torch.add(t_510, t_505)
|
||||
t_512 = F.relu(t_511)
|
||||
t_513 = self.n_Conv_67(t_512)
|
||||
t_514 = F.relu(t_513)
|
||||
t_514_padded = F.pad(t_514, [1, 1, 1, 1], value=0)
|
||||
t_515 = self.n_Conv_68(t_514_padded)
|
||||
t_516 = F.relu(t_515)
|
||||
t_517 = self.n_Conv_69(t_516)
|
||||
t_518 = torch.add(t_517, t_512)
|
||||
t_519 = F.relu(t_518)
|
||||
t_520 = self.n_Conv_70(t_519)
|
||||
t_521 = F.relu(t_520)
|
||||
t_521_padded = F.pad(t_521, [1, 1, 1, 1], value=0)
|
||||
t_522 = self.n_Conv_71(t_521_padded)
|
||||
t_523 = F.relu(t_522)
|
||||
t_524 = self.n_Conv_72(t_523)
|
||||
t_525 = torch.add(t_524, t_519)
|
||||
t_526 = F.relu(t_525)
|
||||
t_527 = self.n_Conv_73(t_526)
|
||||
t_528 = F.relu(t_527)
|
||||
t_528_padded = F.pad(t_528, [1, 1, 1, 1], value=0)
|
||||
t_529 = self.n_Conv_74(t_528_padded)
|
||||
t_530 = F.relu(t_529)
|
||||
t_531 = self.n_Conv_75(t_530)
|
||||
t_532 = torch.add(t_531, t_526)
|
||||
t_533 = F.relu(t_532)
|
||||
t_534 = self.n_Conv_76(t_533)
|
||||
t_535 = F.relu(t_534)
|
||||
t_535_padded = F.pad(t_535, [1, 1, 1, 1], value=0)
|
||||
t_536 = self.n_Conv_77(t_535_padded)
|
||||
t_537 = F.relu(t_536)
|
||||
t_538 = self.n_Conv_78(t_537)
|
||||
t_539 = torch.add(t_538, t_533)
|
||||
t_540 = F.relu(t_539)
|
||||
t_541 = self.n_Conv_79(t_540)
|
||||
t_542 = F.relu(t_541)
|
||||
t_542_padded = F.pad(t_542, [1, 1, 1, 1], value=0)
|
||||
t_543 = self.n_Conv_80(t_542_padded)
|
||||
t_544 = F.relu(t_543)
|
||||
t_545 = self.n_Conv_81(t_544)
|
||||
t_546 = torch.add(t_545, t_540)
|
||||
t_547 = F.relu(t_546)
|
||||
t_548 = self.n_Conv_82(t_547)
|
||||
t_549 = F.relu(t_548)
|
||||
t_549_padded = F.pad(t_549, [1, 1, 1, 1], value=0)
|
||||
t_550 = self.n_Conv_83(t_549_padded)
|
||||
t_551 = F.relu(t_550)
|
||||
t_552 = self.n_Conv_84(t_551)
|
||||
t_553 = torch.add(t_552, t_547)
|
||||
t_554 = F.relu(t_553)
|
||||
t_555 = self.n_Conv_85(t_554)
|
||||
t_556 = F.relu(t_555)
|
||||
t_556_padded = F.pad(t_556, [1, 1, 1, 1], value=0)
|
||||
t_557 = self.n_Conv_86(t_556_padded)
|
||||
t_558 = F.relu(t_557)
|
||||
t_559 = self.n_Conv_87(t_558)
|
||||
t_560 = torch.add(t_559, t_554)
|
||||
t_561 = F.relu(t_560)
|
||||
t_562 = self.n_Conv_88(t_561)
|
||||
t_563 = F.relu(t_562)
|
||||
t_563_padded = F.pad(t_563, [1, 1, 1, 1], value=0)
|
||||
t_564 = self.n_Conv_89(t_563_padded)
|
||||
t_565 = F.relu(t_564)
|
||||
t_566 = self.n_Conv_90(t_565)
|
||||
t_567 = torch.add(t_566, t_561)
|
||||
t_568 = F.relu(t_567)
|
||||
t_569 = self.n_Conv_91(t_568)
|
||||
t_570 = F.relu(t_569)
|
||||
t_570_padded = F.pad(t_570, [1, 1, 1, 1], value=0)
|
||||
t_571 = self.n_Conv_92(t_570_padded)
|
||||
t_572 = F.relu(t_571)
|
||||
t_573 = self.n_Conv_93(t_572)
|
||||
t_574 = torch.add(t_573, t_568)
|
||||
t_575 = F.relu(t_574)
|
||||
t_576 = self.n_Conv_94(t_575)
|
||||
t_577 = F.relu(t_576)
|
||||
t_577_padded = F.pad(t_577, [1, 1, 1, 1], value=0)
|
||||
t_578 = self.n_Conv_95(t_577_padded)
|
||||
t_579 = F.relu(t_578)
|
||||
t_580 = self.n_Conv_96(t_579)
|
||||
t_581 = torch.add(t_580, t_575)
|
||||
t_582 = F.relu(t_581)
|
||||
t_583 = self.n_Conv_97(t_582)
|
||||
t_584 = F.relu(t_583)
|
||||
t_584_padded = F.pad(t_584, [0, 1, 0, 1], value=0)
|
||||
t_585 = self.n_Conv_98(t_584_padded)
|
||||
t_586 = F.relu(t_585)
|
||||
t_587 = self.n_Conv_99(t_586)
|
||||
t_588 = self.n_Conv_100(t_582)
|
||||
t_589 = torch.add(t_587, t_588)
|
||||
t_590 = F.relu(t_589)
|
||||
t_591 = self.n_Conv_101(t_590)
|
||||
t_592 = F.relu(t_591)
|
||||
t_592_padded = F.pad(t_592, [1, 1, 1, 1], value=0)
|
||||
t_593 = self.n_Conv_102(t_592_padded)
|
||||
t_594 = F.relu(t_593)
|
||||
t_595 = self.n_Conv_103(t_594)
|
||||
t_596 = torch.add(t_595, t_590)
|
||||
t_597 = F.relu(t_596)
|
||||
t_598 = self.n_Conv_104(t_597)
|
||||
t_599 = F.relu(t_598)
|
||||
t_599_padded = F.pad(t_599, [1, 1, 1, 1], value=0)
|
||||
t_600 = self.n_Conv_105(t_599_padded)
|
||||
t_601 = F.relu(t_600)
|
||||
t_602 = self.n_Conv_106(t_601)
|
||||
t_603 = torch.add(t_602, t_597)
|
||||
t_604 = F.relu(t_603)
|
||||
t_605 = self.n_Conv_107(t_604)
|
||||
t_606 = F.relu(t_605)
|
||||
t_606_padded = F.pad(t_606, [1, 1, 1, 1], value=0)
|
||||
t_607 = self.n_Conv_108(t_606_padded)
|
||||
t_608 = F.relu(t_607)
|
||||
t_609 = self.n_Conv_109(t_608)
|
||||
t_610 = torch.add(t_609, t_604)
|
||||
t_611 = F.relu(t_610)
|
||||
t_612 = self.n_Conv_110(t_611)
|
||||
t_613 = F.relu(t_612)
|
||||
t_613_padded = F.pad(t_613, [1, 1, 1, 1], value=0)
|
||||
t_614 = self.n_Conv_111(t_613_padded)
|
||||
t_615 = F.relu(t_614)
|
||||
t_616 = self.n_Conv_112(t_615)
|
||||
t_617 = torch.add(t_616, t_611)
|
||||
t_618 = F.relu(t_617)
|
||||
t_619 = self.n_Conv_113(t_618)
|
||||
t_620 = F.relu(t_619)
|
||||
t_620_padded = F.pad(t_620, [1, 1, 1, 1], value=0)
|
||||
t_621 = self.n_Conv_114(t_620_padded)
|
||||
t_622 = F.relu(t_621)
|
||||
t_623 = self.n_Conv_115(t_622)
|
||||
t_624 = torch.add(t_623, t_618)
|
||||
t_625 = F.relu(t_624)
|
||||
t_626 = self.n_Conv_116(t_625)
|
||||
t_627 = F.relu(t_626)
|
||||
t_627_padded = F.pad(t_627, [1, 1, 1, 1], value=0)
|
||||
t_628 = self.n_Conv_117(t_627_padded)
|
||||
t_629 = F.relu(t_628)
|
||||
t_630 = self.n_Conv_118(t_629)
|
||||
t_631 = torch.add(t_630, t_625)
|
||||
t_632 = F.relu(t_631)
|
||||
t_633 = self.n_Conv_119(t_632)
|
||||
t_634 = F.relu(t_633)
|
||||
t_634_padded = F.pad(t_634, [1, 1, 1, 1], value=0)
|
||||
t_635 = self.n_Conv_120(t_634_padded)
|
||||
t_636 = F.relu(t_635)
|
||||
t_637 = self.n_Conv_121(t_636)
|
||||
t_638 = torch.add(t_637, t_632)
|
||||
t_639 = F.relu(t_638)
|
||||
t_640 = self.n_Conv_122(t_639)
|
||||
t_641 = F.relu(t_640)
|
||||
t_641_padded = F.pad(t_641, [1, 1, 1, 1], value=0)
|
||||
t_642 = self.n_Conv_123(t_641_padded)
|
||||
t_643 = F.relu(t_642)
|
||||
t_644 = self.n_Conv_124(t_643)
|
||||
t_645 = torch.add(t_644, t_639)
|
||||
t_646 = F.relu(t_645)
|
||||
t_647 = self.n_Conv_125(t_646)
|
||||
t_648 = F.relu(t_647)
|
||||
t_648_padded = F.pad(t_648, [1, 1, 1, 1], value=0)
|
||||
t_649 = self.n_Conv_126(t_648_padded)
|
||||
t_650 = F.relu(t_649)
|
||||
t_651 = self.n_Conv_127(t_650)
|
||||
t_652 = torch.add(t_651, t_646)
|
||||
t_653 = F.relu(t_652)
|
||||
t_654 = self.n_Conv_128(t_653)
|
||||
t_655 = F.relu(t_654)
|
||||
t_655_padded = F.pad(t_655, [1, 1, 1, 1], value=0)
|
||||
t_656 = self.n_Conv_129(t_655_padded)
|
||||
t_657 = F.relu(t_656)
|
||||
t_658 = self.n_Conv_130(t_657)
|
||||
t_659 = torch.add(t_658, t_653)
|
||||
t_660 = F.relu(t_659)
|
||||
t_661 = self.n_Conv_131(t_660)
|
||||
t_662 = F.relu(t_661)
|
||||
t_662_padded = F.pad(t_662, [1, 1, 1, 1], value=0)
|
||||
t_663 = self.n_Conv_132(t_662_padded)
|
||||
t_664 = F.relu(t_663)
|
||||
t_665 = self.n_Conv_133(t_664)
|
||||
t_666 = torch.add(t_665, t_660)
|
||||
t_667 = F.relu(t_666)
|
||||
t_668 = self.n_Conv_134(t_667)
|
||||
t_669 = F.relu(t_668)
|
||||
t_669_padded = F.pad(t_669, [1, 1, 1, 1], value=0)
|
||||
t_670 = self.n_Conv_135(t_669_padded)
|
||||
t_671 = F.relu(t_670)
|
||||
t_672 = self.n_Conv_136(t_671)
|
||||
t_673 = torch.add(t_672, t_667)
|
||||
t_674 = F.relu(t_673)
|
||||
t_675 = self.n_Conv_137(t_674)
|
||||
t_676 = F.relu(t_675)
|
||||
t_676_padded = F.pad(t_676, [1, 1, 1, 1], value=0)
|
||||
t_677 = self.n_Conv_138(t_676_padded)
|
||||
t_678 = F.relu(t_677)
|
||||
t_679 = self.n_Conv_139(t_678)
|
||||
t_680 = torch.add(t_679, t_674)
|
||||
t_681 = F.relu(t_680)
|
||||
t_682 = self.n_Conv_140(t_681)
|
||||
t_683 = F.relu(t_682)
|
||||
t_683_padded = F.pad(t_683, [1, 1, 1, 1], value=0)
|
||||
t_684 = self.n_Conv_141(t_683_padded)
|
||||
t_685 = F.relu(t_684)
|
||||
t_686 = self.n_Conv_142(t_685)
|
||||
t_687 = torch.add(t_686, t_681)
|
||||
t_688 = F.relu(t_687)
|
||||
t_689 = self.n_Conv_143(t_688)
|
||||
t_690 = F.relu(t_689)
|
||||
t_690_padded = F.pad(t_690, [1, 1, 1, 1], value=0)
|
||||
t_691 = self.n_Conv_144(t_690_padded)
|
||||
t_692 = F.relu(t_691)
|
||||
t_693 = self.n_Conv_145(t_692)
|
||||
t_694 = torch.add(t_693, t_688)
|
||||
t_695 = F.relu(t_694)
|
||||
t_696 = self.n_Conv_146(t_695)
|
||||
t_697 = F.relu(t_696)
|
||||
t_697_padded = F.pad(t_697, [1, 1, 1, 1], value=0)
|
||||
t_698 = self.n_Conv_147(t_697_padded)
|
||||
t_699 = F.relu(t_698)
|
||||
t_700 = self.n_Conv_148(t_699)
|
||||
t_701 = torch.add(t_700, t_695)
|
||||
t_702 = F.relu(t_701)
|
||||
t_703 = self.n_Conv_149(t_702)
|
||||
t_704 = F.relu(t_703)
|
||||
t_704_padded = F.pad(t_704, [1, 1, 1, 1], value=0)
|
||||
t_705 = self.n_Conv_150(t_704_padded)
|
||||
t_706 = F.relu(t_705)
|
||||
t_707 = self.n_Conv_151(t_706)
|
||||
t_708 = torch.add(t_707, t_702)
|
||||
t_709 = F.relu(t_708)
|
||||
t_710 = self.n_Conv_152(t_709)
|
||||
t_711 = F.relu(t_710)
|
||||
t_711_padded = F.pad(t_711, [1, 1, 1, 1], value=0)
|
||||
t_712 = self.n_Conv_153(t_711_padded)
|
||||
t_713 = F.relu(t_712)
|
||||
t_714 = self.n_Conv_154(t_713)
|
||||
t_715 = torch.add(t_714, t_709)
|
||||
t_716 = F.relu(t_715)
|
||||
t_717 = self.n_Conv_155(t_716)
|
||||
t_718 = F.relu(t_717)
|
||||
t_718_padded = F.pad(t_718, [1, 1, 1, 1], value=0)
|
||||
t_719 = self.n_Conv_156(t_718_padded)
|
||||
t_720 = F.relu(t_719)
|
||||
t_721 = self.n_Conv_157(t_720)
|
||||
t_722 = torch.add(t_721, t_716)
|
||||
t_723 = F.relu(t_722)
|
||||
t_724 = self.n_Conv_158(t_723)
|
||||
t_725 = self.n_Conv_159(t_723)
|
||||
t_726 = F.relu(t_725)
|
||||
t_726_padded = F.pad(t_726, [0, 1, 0, 1], value=0)
|
||||
t_727 = self.n_Conv_160(t_726_padded)
|
||||
t_728 = F.relu(t_727)
|
||||
t_729 = self.n_Conv_161(t_728)
|
||||
t_730 = torch.add(t_729, t_724)
|
||||
t_731 = F.relu(t_730)
|
||||
t_732 = self.n_Conv_162(t_731)
|
||||
t_733 = F.relu(t_732)
|
||||
t_733_padded = F.pad(t_733, [1, 1, 1, 1], value=0)
|
||||
t_734 = self.n_Conv_163(t_733_padded)
|
||||
t_735 = F.relu(t_734)
|
||||
t_736 = self.n_Conv_164(t_735)
|
||||
t_737 = torch.add(t_736, t_731)
|
||||
t_738 = F.relu(t_737)
|
||||
t_739 = self.n_Conv_165(t_738)
|
||||
t_740 = F.relu(t_739)
|
||||
t_740_padded = F.pad(t_740, [1, 1, 1, 1], value=0)
|
||||
t_741 = self.n_Conv_166(t_740_padded)
|
||||
t_742 = F.relu(t_741)
|
||||
t_743 = self.n_Conv_167(t_742)
|
||||
t_744 = torch.add(t_743, t_738)
|
||||
t_745 = F.relu(t_744)
|
||||
t_746 = self.n_Conv_168(t_745)
|
||||
t_747 = self.n_Conv_169(t_745)
|
||||
t_748 = F.relu(t_747)
|
||||
t_748_padded = F.pad(t_748, [0, 1, 0, 1], value=0)
|
||||
t_749 = self.n_Conv_170(t_748_padded)
|
||||
t_750 = F.relu(t_749)
|
||||
t_751 = self.n_Conv_171(t_750)
|
||||
t_752 = torch.add(t_751, t_746)
|
||||
t_753 = F.relu(t_752)
|
||||
t_754 = self.n_Conv_172(t_753)
|
||||
t_755 = F.relu(t_754)
|
||||
t_755_padded = F.pad(t_755, [1, 1, 1, 1], value=0)
|
||||
t_756 = self.n_Conv_173(t_755_padded)
|
||||
t_757 = F.relu(t_756)
|
||||
t_758 = self.n_Conv_174(t_757)
|
||||
t_759 = torch.add(t_758, t_753)
|
||||
t_760 = F.relu(t_759)
|
||||
t_761 = self.n_Conv_175(t_760)
|
||||
t_762 = F.relu(t_761)
|
||||
t_762_padded = F.pad(t_762, [1, 1, 1, 1], value=0)
|
||||
t_763 = self.n_Conv_176(t_762_padded)
|
||||
t_764 = F.relu(t_763)
|
||||
t_765 = self.n_Conv_177(t_764)
|
||||
t_766 = torch.add(t_765, t_760)
|
||||
t_767 = F.relu(t_766)
|
||||
t_768 = self.n_Conv_178(t_767)
|
||||
t_769 = F.avg_pool2d(t_768, kernel_size=t_768.shape[-2:])
|
||||
t_770 = torch.squeeze(t_769, 3)
|
||||
t_770 = torch.squeeze(t_770, 2)
|
||||
t_771 = torch.sigmoid(t_770)
|
||||
return t_771
|
||||
|
||||
def load_state_dict(self, state_dict, **kwargs):
|
||||
self.tags = state_dict.get('tags', [])
|
||||
|
||||
super(DeepDanbooruModel, self).load_state_dict({k: v for k, v in state_dict.items() if k != 'tags'})
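# Usage sketch (not part of the original file; the checkpoint filename, image size and
# threshold are illustrative). The model takes an NHWC float batch scaled to [0, 1] and
# returns sigmoid scores, one per entry in `self.tags`:
#
#   import numpy as np, torch
#   from PIL import Image
#
#   model = DeepDanbooruModel()
#   model.load_state_dict(torch.load("model-resnet_custom_v3.pt", map_location="cpu"))
#   model.eval()
#
#   pic = Image.open("example.png").convert("RGB").resize((512, 512))
#   x = torch.from_numpy(np.expand_dims(np.array(pic, dtype=np.float32) / 255.0, 0))  # NHWC in [0, 1]
#   with torch.no_grad():
#       probs = model(x)[0]
#   predicted = [tag for tag, p in zip(model.tags, probs) if p > 0.5]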
|
||||
|
||||
102
modules/devices.py
Executable file
@@ -0,0 +1,102 @@
|
||||
import contextlib
|
||||
import torch
|
||||
from backend import memory_management
|
||||
|
||||
|
||||
def has_xpu() -> bool:
|
||||
return memory_management.xpu_available
|
||||
|
||||
|
||||
def has_mps() -> bool:
|
||||
return memory_management.mps_mode()
|
||||
|
||||
|
||||
def cuda_no_autocast(device_id=None) -> bool:
|
||||
return False
|
||||
|
||||
|
||||
def get_cuda_device_id():
|
||||
return memory_management.get_torch_device().index
|
||||
|
||||
|
||||
def get_cuda_device_string():
|
||||
return str(memory_management.get_torch_device())
|
||||
|
||||
|
||||
def get_optimal_device_name():
|
||||
return memory_management.get_torch_device().type
|
||||
|
||||
|
||||
def get_optimal_device():
|
||||
return memory_management.get_torch_device()
|
||||
|
||||
|
||||
def get_device_for(task):
|
||||
return memory_management.get_torch_device()
|
||||
|
||||
|
||||
def torch_gc():
|
||||
memory_management.soft_empty_cache()
|
||||
|
||||
|
||||
def torch_npu_set_device():
|
||||
return
|
||||
|
||||
|
||||
def enable_tf32():
|
||||
return
|
||||
|
||||
|
||||
cpu: torch.device = torch.device("cpu")
|
||||
fp8: bool = False
|
||||
device: torch.device = memory_management.get_torch_device()
|
||||
device_interrogate: torch.device = memory_management.text_encoder_device() # for backward compatibility, not used now
|
||||
device_gfpgan: torch.device = memory_management.get_torch_device() # will be managed by memory management system
|
||||
device_esrgan: torch.device = memory_management.get_torch_device() # will be managed by memory management system
|
||||
device_codeformer: torch.device = memory_management.get_torch_device() # will be managed by memory management system
|
||||
dtype: torch.dtype = torch.float32 if memory_management.unet_dtype() is torch.float32 else torch.float16
|
||||
dtype_vae: torch.dtype = memory_management.vae_dtype()
|
||||
dtype_unet: torch.dtype = memory_management.unet_dtype()
|
||||
dtype_inference: torch.dtype = memory_management.unet_dtype()
|
||||
unet_needs_upcast = False
|
||||
|
||||
|
||||
def cond_cast_unet(input):
|
||||
return input
|
||||
|
||||
|
||||
def cond_cast_float(input):
|
||||
return input
|
||||
|
||||
|
||||
nv_rng = None
|
||||
patch_module_list = []
|
||||
|
||||
|
||||
def manual_cast_forward(target_dtype):
|
||||
return
|
||||
|
||||
|
||||
@contextlib.contextmanager
|
||||
def manual_cast(target_dtype):
|
||||
yield  # a @contextmanager body must yield; this stub is a no-op because casting is handled by the backend
|
||||
|
||||
|
||||
def autocast(disable=False):
|
||||
return contextlib.nullcontext()
|
||||
|
||||
|
||||
def without_autocast(disable=False):
|
||||
return contextlib.nullcontext()
|
||||
|
||||
|
||||
class NansException(Exception):
|
||||
pass
|
||||
|
||||
|
||||
def test_for_nans(x, where):
|
||||
return
|
||||
|
||||
|
||||
def first_time_calculation():
|
||||
return
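# Usage sketch (illustrative): this module is a thin compatibility layer over
# backend.memory_management, and the rest of the codebase consumes it roughly like:
#
#   from modules import devices
#   model = model.to(devices.device, dtype=devices.dtype)
#   with devices.autocast():      # a nullcontext here; dtype handling lives in the backend
#       y = model(x)
#   devices.torch_gc()            # frees cached memory via memory_management.soft_empty_cache()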
|
||||
150
modules/errors.py
Executable file
@@ -0,0 +1,150 @@
|
||||
import sys
|
||||
import textwrap
|
||||
import traceback
|
||||
|
||||
|
||||
exception_records = []
|
||||
|
||||
|
||||
def format_traceback(tb):
|
||||
return [[f"{x.filename}, line {x.lineno}, {x.name}", x.line] for x in traceback.extract_tb(tb)]
|
||||
|
||||
|
||||
def format_exception(e, tb):
|
||||
return {"exception": str(e), "traceback": format_traceback(tb)}
|
||||
|
||||
|
||||
def get_exceptions():
|
||||
try:
|
||||
return list(reversed(exception_records))
|
||||
except Exception as e:
|
||||
return str(e)
|
||||
|
||||
|
||||
def record_exception():
|
||||
_, e, tb = sys.exc_info()
|
||||
if e is None:
|
||||
return
|
||||
|
||||
if exception_records and exception_records[-1] == e:
|
||||
return
|
||||
|
||||
exception_records.append(format_exception(e, tb))
|
||||
|
||||
if len(exception_records) > 5:
|
||||
exception_records.pop(0)
|
||||
|
||||
|
||||
def report(message: str, *, exc_info: bool = False) -> None:
|
||||
"""
|
||||
Print an error message to stderr, with optional traceback.
|
||||
"""
|
||||
|
||||
record_exception()
|
||||
|
||||
for line in message.splitlines():
|
||||
print("***", line, file=sys.stderr)
|
||||
if exc_info:
|
||||
print(textwrap.indent(traceback.format_exc(), " "), file=sys.stderr)
|
||||
print("---", file=sys.stderr)
|
||||
|
||||
|
||||
def print_error_explanation(message):
|
||||
record_exception()
|
||||
|
||||
lines = message.strip().split("\n")
|
||||
max_len = max([len(x) for x in lines])
|
||||
|
||||
print('=' * max_len, file=sys.stderr)
|
||||
for line in lines:
|
||||
print(line, file=sys.stderr)
|
||||
print('=' * max_len, file=sys.stderr)
|
||||
|
||||
|
||||
def display(e: Exception, task, *, full_traceback=False):
|
||||
record_exception()
|
||||
|
||||
print(f"{task or 'error'}: {type(e).__name__}", file=sys.stderr)
|
||||
te = traceback.TracebackException.from_exception(e)
|
||||
if full_traceback:
|
||||
# include frames leading up to the try-catch block
|
||||
te.stack = traceback.StackSummary(traceback.extract_stack()[:-2] + te.stack)
|
||||
print(*te.format(), sep="", file=sys.stderr)
|
||||
|
||||
message = str(e)
|
||||
if "copying a param with shape torch.Size([640, 1024]) from checkpoint, the shape in current model is torch.Size([640, 768])" in message:
|
||||
print_error_explanation("""
|
||||
The most likely cause of this is you are trying to load Stable Diffusion 2.0 model without specifying its config file.
|
||||
See https://github.com/AUTOMATIC1111/stable-diffusion-webui/wiki/Features#stable-diffusion-20 for how to solve this.
|
||||
""")
|
||||
|
||||
|
||||
already_displayed = {}
|
||||
|
||||
|
||||
def display_once(e: Exception, task):
|
||||
record_exception()
|
||||
|
||||
if task in already_displayed:
|
||||
return
|
||||
|
||||
display(e, task)
|
||||
|
||||
already_displayed[task] = 1
|
||||
|
||||
|
||||
def run(code, task):
|
||||
try:
|
||||
code()
|
||||
except Exception as e:
|
||||
display(e, task)  # display() expects the exception first, then the task description
|
||||
|
||||
|
||||
def check_versions():
|
||||
from packaging import version
|
||||
from modules import shared
|
||||
|
||||
import torch
|
||||
import gradio
|
||||
|
||||
expected_torch_version = "2.3.1"
|
||||
expected_xformers_version = "0.0.27"
|
||||
expected_gradio_version = "4.40.0"
|
||||
|
||||
if version.parse(torch.__version__) < version.parse(expected_torch_version):
|
||||
print_error_explanation(f"""
|
||||
You are running torch {torch.__version__}.
|
||||
The program is tested to work with torch {expected_torch_version}.
|
||||
To reinstall the desired version, run with commandline flag --reinstall-torch.
|
||||
Beware that this will cause a lot of large files to be downloaded. There are also
reports of issues with the training tab on the latest version.
|
||||
|
||||
Use --skip-version-check commandline argument to disable this check.
|
||||
""".strip())
|
||||
|
||||
if shared.xformers_available:
|
||||
import xformers
|
||||
|
||||
if version.parse(xformers.__version__) < version.parse(expected_xformers_version):
|
||||
print_error_explanation(f"""
|
||||
You are running xformers {xformers.__version__}.
|
||||
The program is tested to work with xformers {expected_xformers_version}.
|
||||
To reinstall the desired version, run with commandline flag --reinstall-xformers.
|
||||
|
||||
Use --skip-version-check commandline argument to disable this check.
|
||||
""".strip())
|
||||
|
||||
if gradio.__version__ != expected_gradio_version:
|
||||
print_error_explanation(f"""
|
||||
You are running gradio {gradio.__version__}.
|
||||
The program is designed to work with gradio {expected_gradio_version}.
|
||||
Using a different version of gradio is extremely likely to break the program.
|
||||
|
||||
Reasons why you might have a mismatched gradio version:
|
||||
- you use --skip-install flag.
|
||||
- you use webui.py to start the program instead of launch.py.
|
||||
- an extension installs the incompatible gradio version.
|
||||
|
||||
Use --skip-version-check commandline argument to disable this check.
|
||||
""".strip())
|
||||
|
||||
64
modules/esrgan_model.py
Executable file
@@ -0,0 +1,64 @@
|
||||
from modules import modelloader, devices, errors
|
||||
from modules.shared import opts
|
||||
from modules.upscaler import Upscaler, UpscalerData
|
||||
from modules.upscaler_utils import upscale_with_model
|
||||
from modules_forge.utils import prepare_free_memory
|
||||
|
||||
|
||||
class UpscalerESRGAN(Upscaler):
|
||||
def __init__(self, dirname):
|
||||
self.name = "ESRGAN"
|
||||
self.model_url = "https://github.com/cszn/KAIR/releases/download/v1.0/ESRGAN.pth"
|
||||
self.model_name = "ESRGAN_4x"
|
||||
self.scalers = []
|
||||
self.user_path = dirname
|
||||
super().__init__()
|
||||
model_paths = self.find_models(ext_filter=[".pt", ".pth", ".safetensors"])
|
||||
if len(model_paths) == 0:
scaler_data = UpscalerData(self.model_name, self.model_url, self, 4)
self.scalers.append(scaler_data)  # register the downloadable default model so it is actually offered
|
||||
for file in model_paths:
|
||||
if file.startswith("http"):
|
||||
name = self.model_name
|
||||
else:
|
||||
name = modelloader.friendly_name(file)
|
||||
|
||||
scaler_data = UpscalerData(name, file, self, 4)
|
||||
self.scalers.append(scaler_data)
|
||||
|
||||
def do_upscale(self, img, selected_model):
|
||||
prepare_free_memory()
|
||||
try:
|
||||
model = self.load_model(selected_model)
|
||||
except Exception:
|
||||
errors.report(f"Unable to load ESRGAN model {selected_model}", exc_info=True)
|
||||
return img
|
||||
model.to(devices.device_esrgan)
|
||||
return esrgan_upscale(model, img)
|
||||
|
||||
def load_model(self, path: str):
|
||||
if path.startswith("http"):
|
||||
# TODO: this doesn't use `path` at all?
|
||||
filename = modelloader.load_file_from_url(
|
||||
url=self.model_url,
|
||||
model_dir=self.model_download_path,
|
||||
file_name=f"{self.model_name}.pth",
|
||||
)
|
||||
else:
|
||||
filename = path
|
||||
|
||||
return modelloader.load_spandrel_model(
|
||||
filename,
|
||||
device=('cpu' if devices.device_esrgan.type == 'mps' else None),
|
||||
expected_architecture='ESRGAN',
|
||||
)
|
||||
|
||||
|
||||
def esrgan_upscale(model, img):
|
||||
return upscale_with_model(
|
||||
model,
|
||||
img,
|
||||
tile_size=opts.ESRGAN_tile,
|
||||
tile_overlap=opts.ESRGAN_tile_overlap,
|
||||
)
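# Usage sketch (illustrative; assumes UpscalerData exposes the model path as `data_path`,
# as in modules.upscaler):
#
#   upscaler = UpscalerESRGAN("models/ESRGAN")
#   scaler = upscaler.scalers[0]
#   result = upscaler.do_upscale(img, scaler.data_path)   # PIL.Image in, upscaled PIL.Image out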
|
||||
318
modules/extensions.py
Executable file
@@ -0,0 +1,318 @@
|
||||
from __future__ import annotations
|
||||
|
||||
import configparser
|
||||
import dataclasses
|
||||
import os
|
||||
import threading
|
||||
import re
|
||||
import json
|
||||
|
||||
from modules import shared, errors, cache, scripts
|
||||
from modules.gitpython_hack import Repo
|
||||
from modules.paths_internal import extensions_dir, extensions_builtin_dir, script_path # noqa: F401
|
||||
from modules_forge.config import always_disabled_extensions
|
||||
|
||||
extensions: list[Extension] = []
|
||||
extension_paths: dict[str, Extension] = {}
|
||||
loaded_extensions: dict[str, Exception] = {}
|
||||
|
||||
|
||||
os.makedirs(extensions_dir, exist_ok=True)
|
||||
|
||||
|
||||
def active():
|
||||
if shared.cmd_opts.disable_all_extensions or shared.opts.disable_all_extensions == "all":
|
||||
return []
|
||||
elif shared.cmd_opts.disable_extra_extensions or shared.opts.disable_all_extensions == "extra":
|
||||
return [x for x in extensions if x.enabled and x.is_builtin]
|
||||
else:
|
||||
return [x for x in extensions if x.enabled]
|
||||
|
||||
|
||||
@dataclasses.dataclass
|
||||
class CallbackOrderInfo:
|
||||
name: str
|
||||
before: list
|
||||
after: list
|
||||
|
||||
|
||||
class ExtensionMetadata:
|
||||
filename = "metadata.ini"
|
||||
config: configparser.ConfigParser
|
||||
canonical_name: str
|
||||
requires: list
|
||||
|
||||
def __init__(self, path, canonical_name):
|
||||
self.config = configparser.ConfigParser()
|
||||
|
||||
filepath = os.path.join(path, self.filename)
|
||||
# `self.config.read()` will quietly swallow OSErrors (which FileNotFoundError is),
|
||||
# so no need to check whether the file exists beforehand.
|
||||
try:
|
||||
self.config.read(filepath)
|
||||
except Exception:
|
||||
errors.report(f"Error reading {self.filename} for extension {canonical_name}.", exc_info=True)
|
||||
|
||||
self.canonical_name = self.config.get("Extension", "Name", fallback=canonical_name)
|
||||
self.canonical_name = self.canonical_name.lower().strip()
|
||||
|
||||
self.requires = None
|
||||
|
||||
def get_script_requirements(self, field, section, extra_section=None):
|
||||
"""reads a list of requirements from the config; field is the name of the field in the ini file,
|
||||
like Requires or Before, and section is the name of the [section] in the ini file; additionally,
|
||||
reads more requirements from [extra_section] if specified."""
|
||||
|
||||
x = self.config.get(section, field, fallback='')
|
||||
|
||||
if extra_section:
|
||||
x = x + ', ' + self.config.get(extra_section, field, fallback='')
|
||||
|
||||
listed_requirements = self.parse_list(x.lower())
|
||||
res = []
|
||||
|
||||
for requirement in listed_requirements:
|
||||
loaded_requirements = (x for x in requirement.split("|") if x in loaded_extensions)
|
||||
relevant_requirement = next(loaded_requirements, requirement)
|
||||
res.append(relevant_requirement)
|
||||
|
||||
return res
|
||||
|
||||
def parse_list(self, text):
|
||||
"""converts a line from config ("ext1 ext2, ext3 ") into a python list (["ext1", "ext2", "ext3"])"""
|
||||
|
||||
if not text:
|
||||
return []
|
||||
|
||||
# both "," and " " are accepted as separator
|
||||
return [x for x in re.split(r"[,\s]+", text.strip()) if x]
|
||||
|
||||
def list_callback_order_instructions(self):
|
||||
for section in self.config.sections():
|
||||
if not section.startswith("callbacks/"):
|
||||
continue
|
||||
|
||||
callback_name = section[10:]
|
||||
|
||||
if not callback_name.startswith(self.canonical_name):
|
||||
errors.report(f"Callback order section for extension {self.canonical_name} is referencing the wrong extension: {section}")
|
||||
continue
|
||||
|
||||
before = self.parse_list(self.config.get(section, 'Before', fallback=''))
|
||||
after = self.parse_list(self.config.get(section, 'After', fallback=''))
|
||||
|
||||
yield CallbackOrderInfo(callback_name, before, after)
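# Example metadata.ini that this class can parse (illustrative, not shipped with any extension):
#
#   [Extension]
#   Name = my-extension
#   Requires = sd-webui-controlnet
#
#   [callbacks/my-extension_on_ui_settings]
#   Before = some-other-extension
#   After = sd-webui-controlnet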
|
||||
|
||||
|
||||
class Extension:
|
||||
lock = threading.Lock()
|
||||
cached_fields = ['remote', 'commit_date', 'branch', 'commit_hash', 'version']
|
||||
metadata: ExtensionMetadata
|
||||
|
||||
def __init__(self, name, path, enabled=True, is_builtin=False, metadata=None):
|
||||
self.name = name
|
||||
self.path = path
|
||||
self.enabled = enabled
|
||||
self.status = ''
|
||||
self.can_update = False
|
||||
self.is_builtin = is_builtin
|
||||
self.commit_hash = ''
|
||||
self.commit_date = None
|
||||
self.version = ''
|
||||
self.branch = None
|
||||
self.remote = None
|
||||
self.have_info_from_repo = False
|
||||
self.metadata = metadata if metadata else ExtensionMetadata(self.path, name.lower())
|
||||
self.canonical_name = self.metadata.canonical_name
|
||||
|
||||
self.is_forge_space = False
|
||||
self.space_meta = None
|
||||
|
||||
if os.path.exists(os.path.join(self.path, 'space_meta.json')) and os.path.exists(os.path.join(self.path, 'forge_app.py')):
|
||||
self.is_forge_space = True
|
||||
self.space_meta = json.load(open(os.path.join(self.path, 'space_meta.json'), 'rt', encoding='utf-8'))
|
||||
|
||||
def to_dict(self):
|
||||
return {x: getattr(self, x) for x in self.cached_fields}
|
||||
|
||||
def from_dict(self, d):
|
||||
for field in self.cached_fields:
|
||||
setattr(self, field, d[field])
|
||||
|
||||
def read_info_from_repo(self):
|
||||
if self.is_builtin or self.have_info_from_repo:
|
||||
return
|
||||
|
||||
def read_from_repo():
|
||||
with self.lock:
|
||||
if self.have_info_from_repo:
|
||||
return
|
||||
|
||||
self.do_read_info_from_repo()
|
||||
|
||||
return self.to_dict()
|
||||
|
||||
try:
|
||||
d = cache.cached_data_for_file('extensions-git', self.name, os.path.join(self.path, ".git"), read_from_repo)
|
||||
self.from_dict(d)
|
||||
except FileNotFoundError:
|
||||
pass
|
||||
self.status = 'unknown' if self.status == '' else self.status
|
||||
|
||||
def do_read_info_from_repo(self):
|
||||
repo = None
|
||||
try:
|
||||
if os.path.exists(os.path.join(self.path, ".git")):
|
||||
repo = Repo(self.path)
|
||||
except Exception:
|
||||
errors.report(f"Error reading github repository info from {self.path}", exc_info=True)
|
||||
|
||||
if repo is None or repo.bare:
|
||||
self.remote = None
|
||||
else:
|
||||
try:
|
||||
self.remote = next(repo.remote().urls, None)
|
||||
commit = repo.head.commit
|
||||
self.commit_date = commit.committed_date
|
||||
if repo.active_branch:
|
||||
self.branch = repo.active_branch.name
|
||||
self.commit_hash = commit.hexsha
|
||||
self.version = self.commit_hash[:8]
|
||||
|
||||
except Exception:
|
||||
errors.report(f"Failed reading extension data from Git repository ({self.name})", exc_info=True)
|
||||
self.remote = None
|
||||
|
||||
self.have_info_from_repo = True
|
||||
|
||||
def list_files(self, subdir, extension):
|
||||
dirpath = os.path.join(self.path, subdir)
|
||||
if not os.path.isdir(dirpath):
|
||||
return []
|
||||
|
||||
res = []
|
||||
for filename in sorted(os.listdir(dirpath)):
|
||||
res.append(scripts.ScriptFile(self.path, filename, os.path.join(dirpath, filename)))
|
||||
|
||||
res = [x for x in res if os.path.splitext(x.path)[1].lower() == extension and os.path.isfile(x.path)]
|
||||
|
||||
return res
|
||||
|
||||
def check_updates(self):
|
||||
repo = Repo(self.path)
|
||||
branch_name = f'{repo.remote().name}/{self.branch}'
|
||||
for fetch in repo.remote().fetch(dry_run=True):
|
||||
if self.branch and fetch.name != branch_name:
|
||||
continue
|
||||
if fetch.flags != fetch.HEAD_UPTODATE:
|
||||
self.can_update = True
|
||||
self.status = "new commits"
|
||||
return
|
||||
|
||||
try:
|
||||
origin = repo.rev_parse(branch_name)
|
||||
if repo.head.commit != origin:
|
||||
self.can_update = True
|
||||
self.status = "behind HEAD"
|
||||
return
|
||||
except Exception:
|
||||
self.can_update = False
|
||||
self.status = "unknown (remote error)"
|
||||
return
|
||||
|
||||
self.can_update = False
|
||||
self.status = "latest"
|
||||
|
||||
def fetch_and_reset_hard(self, commit=None):
|
||||
repo = Repo(self.path)
|
||||
if commit is None:
|
||||
commit = f'{repo.remote().name}/{self.branch}'
|
||||
# Fix: `error: Your local changes to the following files would be overwritten by merge`,
|
||||
# because WSL2 Docker set 755 file permissions instead of 644, this results to the error.
|
||||
repo.git.fetch(all=True)
|
||||
repo.git.reset(commit, hard=True)
|
||||
self.have_info_from_repo = False
|
||||
|
||||
|
||||
def list_extensions():
|
||||
extensions.clear()
|
||||
extension_paths.clear()
|
||||
loaded_extensions.clear()
|
||||
|
||||
if shared.cmd_opts.disable_all_extensions:
|
||||
print("*** \"--disable-all-extensions\" arg was used, will not load any extensions ***")
|
||||
elif shared.opts.disable_all_extensions == "all":
|
||||
print("*** \"Disable all extensions\" option was set, will not load any extensions ***")
|
||||
elif shared.cmd_opts.disable_extra_extensions:
|
||||
print("*** \"--disable-extra-extensions\" arg was used, will only load built-in extensions ***")
|
||||
elif shared.opts.disable_all_extensions == "extra":
|
||||
print("*** \"Disable all extensions\" option was set, will only load built-in extensions ***")
|
||||
|
||||
|
||||
# scan through extensions directory and load metadata
|
||||
for dirname in [extensions_builtin_dir, extensions_dir]:
|
||||
if not os.path.isdir(dirname):
|
||||
continue
|
||||
|
||||
for extension_dirname in sorted(os.listdir(dirname)):
|
||||
path = os.path.join(dirname, extension_dirname)
|
||||
if not os.path.isdir(path):
|
||||
continue
|
||||
|
||||
canonical_name = extension_dirname
|
||||
metadata = ExtensionMetadata(path, canonical_name)
|
||||
|
||||
# check for duplicated canonical names
|
||||
already_loaded_extension = loaded_extensions.get(metadata.canonical_name)
|
||||
if already_loaded_extension is not None:
|
||||
errors.report(f'Duplicate canonical name "{canonical_name}" found in extensions "{extension_dirname}" and "{already_loaded_extension.name}". Former will be discarded.', exc_info=False)
|
||||
continue
|
||||
|
||||
is_builtin = dirname == extensions_builtin_dir
|
||||
|
||||
disabled_extensions = shared.opts.disabled_extensions + always_disabled_extensions
|
||||
|
||||
extension = Extension(
|
||||
name=extension_dirname,
|
||||
path=path,
|
||||
enabled=extension_dirname not in disabled_extensions,
|
||||
is_builtin=is_builtin,
|
||||
metadata=metadata
|
||||
)
|
||||
|
||||
extensions.append(extension)
|
||||
extension_paths[extension.path] = extension
|
||||
loaded_extensions[canonical_name] = extension
|
||||
|
||||
for extension in extensions:
|
||||
extension.metadata.requires = extension.metadata.get_script_requirements("Requires", "Extension")
|
||||
|
||||
# check for requirements
|
||||
for extension in extensions:
|
||||
if not extension.enabled:
|
||||
continue
|
||||
|
||||
for req in extension.metadata.requires:
|
||||
required_extension = loaded_extensions.get(req)
|
||||
if required_extension is None:
|
||||
errors.report(f'Extension "{extension.name}" requires "{req}" which is not installed.', exc_info=False)
|
||||
continue
|
||||
|
||||
if not required_extension.enabled:
|
||||
errors.report(f'Extension "{extension.name}" requires "{required_extension.name}" which is disabled.', exc_info=False)
|
||||
continue
|
||||
|
||||
|
||||
def find_extension(filename):
|
||||
parentdir = os.path.dirname(os.path.realpath(filename))
|
||||
|
||||
while parentdir != filename:
|
||||
extension = extension_paths.get(parentdir)
|
||||
if extension is not None:
|
||||
return extension
|
||||
|
||||
filename = parentdir
|
||||
parentdir = os.path.dirname(filename)
|
||||
|
||||
return None
|
||||
|
||||
225
modules/extra_networks.py
Executable file
@@ -0,0 +1,225 @@
|
||||
import json
|
||||
import os
|
||||
import re
|
||||
import logging
|
||||
from collections import defaultdict
|
||||
|
||||
from modules import errors
|
||||
|
||||
extra_network_registry = {}
|
||||
extra_network_aliases = {}
|
||||
|
||||
|
||||
def initialize():
|
||||
extra_network_registry.clear()
|
||||
extra_network_aliases.clear()
|
||||
|
||||
|
||||
def register_extra_network(extra_network):
|
||||
extra_network_registry[extra_network.name] = extra_network
|
||||
|
||||
|
||||
def register_extra_network_alias(extra_network, alias):
|
||||
extra_network_aliases[alias] = extra_network
|
||||
|
||||
|
||||
def register_default_extra_networks():
|
||||
from modules.extra_networks_hypernet import ExtraNetworkHypernet
|
||||
register_extra_network(ExtraNetworkHypernet())
|
||||
|
||||
|
||||
class ExtraNetworkParams:
|
||||
def __init__(self, items=None):
|
||||
self.items = items or []
|
||||
self.positional = []
|
||||
self.named = {}
|
||||
|
||||
for item in self.items:
|
||||
parts = item.split('=', 2) if isinstance(item, str) else [item]
|
||||
if len(parts) == 2:
|
||||
self.named[parts[0]] = parts[1]
|
||||
else:
|
||||
self.positional.append(item)
|
||||
|
||||
def __eq__(self, other):
|
||||
return self.items == other.items
|
||||
|
||||
|
||||
class ExtraNetwork:
|
||||
def __init__(self, name):
|
||||
self.name = name
|
||||
|
||||
def activate(self, p, params_list):
|
||||
"""
|
||||
Called by processing on every run. Whatever the extra network is meant to do should be activated here.
|
||||
Passes arguments related to this extra network in params_list.
|
||||
The user passes arguments by including this in the prompt:
|
||||
|
||||
<name:arg1:arg2:arg3>
|
||||
|
||||
Where name matches the name of this ExtraNetwork object, and arg1:arg2:arg3 are any number of text arguments
|
||||
separated by colons.
|
||||
|
||||
Even if the user does not mention this ExtraNetwork in the prompt, the call will still be made, with an empty params_list -
|
||||
in this case, all effects of this extra network should be disabled.
|
||||
|
||||
Can be called multiple times before deactivate() - each new call should override the previous call completely.
|
||||
|
||||
For example, if this ExtraNetwork's name is 'hypernet' and the user's prompt is:
|
||||
|
||||
> "1girl, <hypernet:agm:1.1> <extrasupernet:master:12:13:14> <hypernet:ray>"
|
||||
|
||||
params_list will be:
|
||||
|
||||
[
|
||||
ExtraNetworkParams(items=["agm", "1.1"]),
|
||||
ExtraNetworkParams(items=["ray"])
|
||||
]
|
||||
|
||||
"""
|
||||
raise NotImplementedError
|
||||
|
||||
def deactivate(self, p):
|
||||
"""
|
||||
Called at the end of processing for housekeeping. No need to do anything here.
|
||||
"""
|
||||
|
||||
raise NotImplementedError
|
||||
|
||||
|
||||
def lookup_extra_networks(extra_network_data):
|
||||
"""returns a dict mapping ExtraNetwork objects to lists of arguments for those extra networks.
|
||||
|
||||
Example input:
|
||||
{
|
||||
'lora': [<modules.extra_networks.ExtraNetworkParams object at 0x0000020690D58310>],
|
||||
'lyco': [<modules.extra_networks.ExtraNetworkParams object at 0x0000020690D58F70>],
|
||||
'hypernet': [<modules.extra_networks.ExtraNetworkParams object at 0x0000020690D5A800>]
|
||||
}
|
||||
|
||||
Example output:
|
||||
|
||||
{
|
||||
<extra_networks_lora.ExtraNetworkLora object at 0x0000020581BEECE0>: [<modules.extra_networks.ExtraNetworkParams object at 0x0000020690D58310>, <modules.extra_networks.ExtraNetworkParams object at 0x0000020690D58F70>],
|
||||
<modules.extra_networks_hypernet.ExtraNetworkHypernet object at 0x0000020581BEEE60>: [<modules.extra_networks.ExtraNetworkParams object at 0x0000020690D5A800>]
|
||||
}
|
||||
"""
|
||||
|
||||
res = {}
|
||||
|
||||
for extra_network_name, extra_network_args in list(extra_network_data.items()):
|
||||
extra_network = extra_network_registry.get(extra_network_name, None)
|
||||
alias = extra_network_aliases.get(extra_network_name, None)
|
||||
|
||||
if alias is not None and extra_network is None:
|
||||
extra_network = alias
|
||||
|
||||
if extra_network is None:
|
||||
logging.info(f"Skipping unknown extra network: {extra_network_name}")
|
||||
continue
|
||||
|
||||
res.setdefault(extra_network, []).extend(extra_network_args)
|
||||
|
||||
return res
|
||||
|
||||
|
||||
def activate(p, extra_network_data):
|
||||
"""call activate for extra networks in extra_network_data in specified order, then call
|
||||
activate for all remaining registered networks with an empty argument list"""
|
||||
|
||||
activated = []
|
||||
|
||||
for extra_network, extra_network_args in lookup_extra_networks(extra_network_data).items():
|
||||
|
||||
try:
|
||||
extra_network.activate(p, extra_network_args)
|
||||
activated.append(extra_network)
|
||||
except Exception as e:
|
||||
errors.display(e, f"activating extra network {extra_network.name} with arguments {extra_network_args}")
|
||||
|
||||
for extra_network_name, extra_network in extra_network_registry.items():
|
||||
if extra_network in activated:
|
||||
continue
|
||||
|
||||
try:
|
||||
extra_network.activate(p, [])
|
||||
except Exception as e:
|
||||
errors.display(e, f"activating extra network {extra_network_name}")
|
||||
|
||||
if p.scripts is not None:
|
||||
p.scripts.after_extra_networks_activate(p, batch_number=p.iteration, prompts=p.prompts, seeds=p.seeds, subseeds=p.subseeds, extra_network_data=extra_network_data)
|
||||
|
||||
|
||||
def deactivate(p, extra_network_data):
|
||||
"""call deactivate for extra networks in extra_network_data in specified order, then call
|
||||
deactivate for all remaining registered networks"""
|
||||
|
||||
data = lookup_extra_networks(extra_network_data)
|
||||
|
||||
for extra_network in data:
|
||||
try:
|
||||
extra_network.deactivate(p)
|
||||
except Exception as e:
|
||||
errors.display(e, f"deactivating extra network {extra_network.name}")
|
||||
|
||||
for extra_network_name, extra_network in extra_network_registry.items():
|
||||
if extra_network in data:
|
||||
continue
|
||||
|
||||
try:
|
||||
extra_network.deactivate(p)
|
||||
except Exception as e:
|
||||
errors.display(e, f"deactivating unmentioned extra network {extra_network_name}")
|
||||
|
||||
|
||||
re_extra_net = re.compile(r"<(\w+):([^>]+)>")
|
||||
|
||||
|
||||
def parse_prompt(prompt):
|
||||
res = defaultdict(list)
|
||||
|
||||
def found(m):
|
||||
name = m.group(1)
|
||||
args = m.group(2)
|
||||
|
||||
res[name].append(ExtraNetworkParams(items=args.split(":")))
|
||||
|
||||
return ""
|
||||
|
||||
prompt = re.sub(re_extra_net, found, prompt)
|
||||
|
||||
return prompt, res
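# Example (illustrative): parse_prompt("1girl <lora:detail:0.8> <hypernet:anime>") returns
# roughly ("1girl  ", {"lora": [ExtraNetworkParams(items=["detail", "0.8"])],
#                      "hypernet": [ExtraNetworkParams(items=["anime"])]})
# i.e. the <...> blocks are stripped from the prompt text and collected per network name.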
|
||||
|
||||
|
||||
def parse_prompts(prompts):
|
||||
res = []
|
||||
extra_data = None
|
||||
|
||||
for prompt in prompts:
|
||||
updated_prompt, parsed_extra_data = parse_prompt(prompt)
|
||||
|
||||
if extra_data is None:
|
||||
extra_data = parsed_extra_data
|
||||
|
||||
res.append(updated_prompt)
|
||||
|
||||
return res, extra_data
|
||||
|
||||
|
||||
def get_user_metadata(filename, lister=None):
|
||||
if filename is None:
|
||||
return {}
|
||||
|
||||
basename, ext = os.path.splitext(filename)
|
||||
metadata_filename = basename + '.json'
|
||||
|
||||
metadata = {}
|
||||
try:
|
||||
exists = lister.exists(metadata_filename) if lister else os.path.exists(metadata_filename)
|
||||
if exists:
|
||||
with open(metadata_filename, "r", encoding="utf8") as file:
|
||||
metadata = json.load(file)
|
||||
except Exception as e:
|
||||
errors.display(e, f"reading extra network user metadata from {metadata_filename}")
|
||||
|
||||
return metadata
|
||||
28
modules/extra_networks_hypernet.py
Executable file
@@ -0,0 +1,28 @@
|
||||
from modules import extra_networks, shared
|
||||
from modules.hypernetworks import hypernetwork
|
||||
|
||||
|
||||
class ExtraNetworkHypernet(extra_networks.ExtraNetwork):
|
||||
def __init__(self):
|
||||
super().__init__('hypernet')
|
||||
|
||||
def activate(self, p, params_list):
|
||||
additional = shared.opts.sd_hypernetwork
|
||||
|
||||
if additional != "None" and additional in shared.hypernetworks and not any(x for x in params_list if x.items[0] == additional):
|
||||
hypernet_prompt_text = f"<hypernet:{additional}:{shared.opts.extra_networks_default_multiplier}>"
|
||||
p.all_prompts = [f"{prompt}{hypernet_prompt_text}" for prompt in p.all_prompts]
|
||||
params_list.append(extra_networks.ExtraNetworkParams(items=[additional, shared.opts.extra_networks_default_multiplier]))
|
||||
|
||||
names = []
|
||||
multipliers = []
|
||||
for params in params_list:
|
||||
assert params.items
|
||||
|
||||
names.append(params.items[0])
|
||||
multipliers.append(float(params.items[1]) if len(params.items) > 1 else 1.0)
|
||||
|
||||
hypernetwork.load_hypernetworks(names, multipliers)
|
||||
|
||||
def deactivate(self, p):
|
||||
pass
|
||||
331
modules/extras.py
Executable file
@@ -0,0 +1,331 @@
|
||||
import os
|
||||
import re
|
||||
import shutil
|
||||
import json
|
||||
|
||||
|
||||
import torch
|
||||
import tqdm
|
||||
|
||||
from modules import shared, images, sd_models, sd_vae, sd_models_config, errors
|
||||
from modules.ui_common import plaintext_to_html
|
||||
import gradio as gr
|
||||
import safetensors.torch
|
||||
|
||||
|
||||
def run_pnginfo(image):
|
||||
if image is None:
|
||||
return '', '', ''
|
||||
|
||||
geninfo, items = images.read_info_from_image(image)
|
||||
items = {**{'parameters': geninfo}, **items}
|
||||
|
||||
info = ''
|
||||
for key, text in items.items():
|
||||
info += f"""
|
||||
<div>
|
||||
<p><b>{plaintext_to_html(str(key))}</b></p>
|
||||
<p>{plaintext_to_html(str(text))}</p>
|
||||
</div>
|
||||
""".strip()+"\n"
|
||||
|
||||
if len(info) == 0:
|
||||
message = "Nothing found in the image."
|
||||
info = f"<div><p>{message}<p></div>"
|
||||
|
||||
return '', geninfo, info
|
||||
|
||||
|
||||
def create_config(ckpt_result, config_source, a, b, c):
|
||||
def config(x):
|
||||
res = sd_models_config.find_checkpoint_config_near_filename(x) if x else None
|
||||
return res if res != shared.sd_default_config else None
|
||||
|
||||
if config_source == 0:
|
||||
cfg = config(a) or config(b) or config(c)
|
||||
elif config_source == 1:
|
||||
cfg = config(b)
|
||||
elif config_source == 2:
|
||||
cfg = config(c)
|
||||
else:
|
||||
cfg = None
|
||||
|
||||
if cfg is None:
|
||||
return
|
||||
|
||||
filename, _ = os.path.splitext(ckpt_result)
|
||||
checkpoint_filename = filename + ".yaml"
|
||||
|
||||
print("Copying config:")
|
||||
print(" from:", cfg)
|
||||
print(" to:", checkpoint_filename)
|
||||
shutil.copyfile(cfg, checkpoint_filename)
|
||||
|
||||
|
||||
checkpoint_dict_skip_on_merge = ["cond_stage_model.transformer.text_model.embeddings.position_ids"]
|
||||
|
||||
|
||||
def to_half(tensor, enable):
|
||||
if enable and tensor.dtype == torch.float:
|
||||
return tensor.half()
|
||||
|
||||
return tensor
|
||||
|
||||
|
||||
def read_metadata(primary_model_name, secondary_model_name, tertiary_model_name):
|
||||
metadata = {}
|
||||
|
||||
for checkpoint_name in [primary_model_name, secondary_model_name, tertiary_model_name]:
|
||||
checkpoint_info = sd_models.checkpoints_list.get(checkpoint_name, None)
|
||||
if checkpoint_info is None:
|
||||
continue
|
||||
|
||||
metadata.update(checkpoint_info.metadata)
|
||||
|
||||
return json.dumps(metadata, indent=4, ensure_ascii=False)
|
||||
|
||||
|
||||
def run_modelmerger(id_task, primary_model_name, secondary_model_name, tertiary_model_name, interp_method, multiplier, save_as_half, custom_name, checkpoint_format, config_source, bake_in_vae, discard_weights, save_metadata, add_merge_recipe, copy_metadata_fields, metadata_json):
|
||||
shared.state.begin(job="model-merge")
|
||||
|
||||
def fail(message):
|
||||
shared.state.textinfo = message
|
||||
shared.state.end()
|
||||
return [*[gr.update() for _ in range(4)], message]
|
||||
|
||||
def weighted_sum(theta0, theta1, alpha):
|
||||
return ((1 - alpha) * theta0) + (alpha * theta1)
|
||||
|
||||
def get_difference(theta1, theta2):
|
||||
return theta1 - theta2
|
||||
|
||||
def add_difference(theta0, theta1_2_diff, alpha):
|
||||
return theta0 + (alpha * theta1_2_diff)
|
||||
|
||||
def filename_weighted_sum():
|
||||
a = primary_model_info.model_name
|
||||
b = secondary_model_info.model_name
|
||||
Ma = round(1 - multiplier, 2)
|
||||
Mb = round(multiplier, 2)
|
||||
|
||||
return f"{Ma}({a}) + {Mb}({b})"
|
||||
|
||||
def filename_add_difference():
|
||||
a = primary_model_info.model_name
|
||||
b = secondary_model_info.model_name
|
||||
c = tertiary_model_info.model_name
|
||||
M = round(multiplier, 2)
|
||||
|
||||
return f"{a} + {M}({b} - {c})"
|
||||
|
||||
def filename_nothing():
|
||||
return primary_model_info.model_name
|
||||
|
||||
theta_funcs = {
|
||||
"Weighted sum": (filename_weighted_sum, None, weighted_sum),
|
||||
"Add difference": (filename_add_difference, get_difference, add_difference),
|
||||
"No interpolation": (filename_nothing, None, None),
|
||||
}
|
||||
filename_generator, theta_func1, theta_func2 = theta_funcs[interp_method]
|
||||
shared.state.job_count = (1 if theta_func1 else 0) + (1 if theta_func2 else 0)
|
||||
|
||||
if not primary_model_name:
|
||||
return fail("Failed: Merging requires a primary model.")
|
||||
|
||||
primary_model_info = sd_models.checkpoint_aliases[primary_model_name]
|
||||
|
||||
if theta_func2 and not secondary_model_name:
|
||||
return fail("Failed: Merging requires a secondary model.")
|
||||
|
||||
secondary_model_info = sd_models.checkpoint_aliases[secondary_model_name] if theta_func2 else None
|
||||
|
||||
if theta_func1 and not tertiary_model_name:
|
||||
return fail(f"Failed: Interpolation method ({interp_method}) requires a tertiary model.")
|
||||
|
||||
tertiary_model_info = sd_models.checkpoint_aliases[tertiary_model_name] if theta_func1 else None
|
||||
|
||||
result_is_inpainting_model = False
|
||||
result_is_instruct_pix2pix_model = False
|
||||
|
||||
if theta_func2:
|
||||
shared.state.textinfo = "Loading B"
|
||||
print(f"Loading {secondary_model_info.filename}...")
|
||||
theta_1 = sd_models.load_torch_file(secondary_model_info.filename)
|
||||
else:
|
||||
theta_1 = None
|
||||
|
||||
if theta_func1:
|
||||
shared.state.textinfo = "Loading C"
|
||||
print(f"Loading {tertiary_model_info.filename}...")
|
||||
theta_2 = sd_models.load_torch_file(tertiary_model_info.filename)
|
||||
|
||||
shared.state.textinfo = 'Merging B and C'
|
||||
shared.state.sampling_steps = len(theta_1.keys())
|
||||
for key in tqdm.tqdm(theta_1.keys()):
|
||||
if key in checkpoint_dict_skip_on_merge:
|
||||
continue
|
||||
|
||||
if 'model' in key:
|
||||
if key in theta_2:
|
||||
t2 = theta_2.get(key, torch.zeros_like(theta_1[key]))
|
||||
theta_1[key] = theta_func1(theta_1[key], t2)
|
||||
else:
|
||||
theta_1[key] = torch.zeros_like(theta_1[key])
|
||||
|
||||
shared.state.sampling_step += 1
|
||||
del theta_2
|
||||
|
||||
shared.state.nextjob()
|
||||
|
||||
shared.state.textinfo = f"Loading {primary_model_info.filename}..."
|
||||
print(f"Loading {primary_model_info.filename}...")
|
||||
theta_0 = sd_models.load_torch_file(primary_model_info.filename)
|
||||
|
||||
print("Merging...")
|
||||
shared.state.textinfo = 'Merging A and B'
|
||||
shared.state.sampling_steps = len(theta_0.keys())
|
||||
for key in tqdm.tqdm(theta_0.keys()):
|
||||
if theta_1 and 'model' in key and key in theta_1:
|
||||
|
||||
if key in checkpoint_dict_skip_on_merge:
|
||||
continue
|
||||
|
||||
a = theta_0[key]
|
||||
b = theta_1[key]
|
||||
|
||||
# this enables merging an inpainting model (A) with another one (B);
|
||||
# where a normal model would have 4 channels for the latent space, an inpainting model would
|
||||
# have another 4 channels for the unmasked picture's latent space, plus one channel for the mask, for a total of 9
|
||||
if a.shape != b.shape and a.shape[0:1] + a.shape[2:] == b.shape[0:1] + b.shape[2:]:
|
||||
if a.shape[1] == 4 and b.shape[1] == 9:
|
||||
raise RuntimeError("When merging inpainting model with a normal one, A must be the inpainting model.")
|
||||
if a.shape[1] == 4 and b.shape[1] == 8:
|
||||
raise RuntimeError("When merging instruct-pix2pix model with a normal one, A must be the instruct-pix2pix model.")
|
||||
|
||||
if a.shape[1] == 8 and b.shape[1] == 4:  # if we have an Instruct-Pix2Pix model...
|
||||
theta_0[key][:, 0:4, :, :] = theta_func2(a[:, 0:4, :, :], b, multiplier)  # merge only the channels the models have in common; otherwise we get a dimension-mismatch error
|
||||
result_is_instruct_pix2pix_model = True
|
||||
else:
|
||||
assert a.shape[1] == 9 and b.shape[1] == 4, f"Bad dimensions for merged layer {key}: A={a.shape}, B={b.shape}"
|
||||
theta_0[key][:, 0:4, :, :] = theta_func2(a[:, 0:4, :, :], b, multiplier)
|
||||
result_is_inpainting_model = True
|
||||
else:
|
||||
theta_0[key] = theta_func2(a, b, multiplier)
|
||||
|
||||
theta_0[key] = to_half(theta_0[key], save_as_half)
|
||||
|
||||
shared.state.sampling_step += 1
|
||||
|
||||
del theta_1
|
||||
|
||||
bake_in_vae_filename = sd_vae.vae_dict.get(bake_in_vae, None)
|
||||
if bake_in_vae_filename is not None:
|
||||
print(f"Baking in VAE from {bake_in_vae_filename}")
|
||||
shared.state.textinfo = 'Baking in VAE'
|
||||
vae_dict = sd_vae.load_torch_file(bake_in_vae_filename)
|
||||
|
||||
for key in vae_dict.keys():
|
||||
theta_0_key = 'first_stage_model.' + key
|
||||
if theta_0_key in theta_0:
|
||||
theta_0[theta_0_key] = to_half(vae_dict[key], save_as_half)
|
||||
|
||||
del vae_dict
|
||||
|
||||
if save_as_half and not theta_func2:
|
||||
for key in theta_0.keys():
|
||||
theta_0[key] = to_half(theta_0[key], save_as_half)
|
||||
|
||||
if discard_weights:
|
||||
regex = re.compile(discard_weights)
|
||||
for key in list(theta_0):
|
||||
if re.search(regex, key):
|
||||
theta_0.pop(key, None)
|
||||
|
||||
ckpt_dir = shared.cmd_opts.ckpt_dir or sd_models.model_path
|
||||
|
||||
filename = filename_generator() if custom_name == '' else custom_name
|
||||
filename += ".inpainting" if result_is_inpainting_model else ""
|
||||
filename += ".instruct-pix2pix" if result_is_instruct_pix2pix_model else ""
|
||||
filename += "." + checkpoint_format
|
||||
|
||||
output_modelname = os.path.join(ckpt_dir, filename)
|
||||
|
||||
shared.state.nextjob()
|
||||
shared.state.textinfo = "Saving"
|
||||
print(f"Saving to {output_modelname}...")
|
||||
|
||||
metadata = {}
|
||||
|
||||
if save_metadata and copy_metadata_fields:
|
||||
if primary_model_info:
|
||||
metadata.update(primary_model_info.metadata)
|
||||
if secondary_model_info:
|
||||
metadata.update(secondary_model_info.metadata)
|
||||
if tertiary_model_info:
|
||||
metadata.update(tertiary_model_info.metadata)
|
||||
|
||||
if save_metadata:
|
||||
try:
|
||||
metadata.update(json.loads(metadata_json))
|
||||
except Exception as e:
|
||||
errors.display(e, "reading metadata from json")
|
||||
|
||||
metadata["format"] = "pt"
|
||||
|
||||
if save_metadata and add_merge_recipe:
|
||||
merge_recipe = {
|
||||
"type": "webui", # indicate this model was merged with webui's built-in merger
|
||||
"primary_model_hash": primary_model_info.sha256,
|
||||
"secondary_model_hash": secondary_model_info.sha256 if secondary_model_info else None,
|
||||
"tertiary_model_hash": tertiary_model_info.sha256 if tertiary_model_info else None,
|
||||
"interp_method": interp_method,
|
||||
"multiplier": multiplier,
|
||||
"save_as_half": save_as_half,
|
||||
"custom_name": custom_name,
|
||||
"config_source": config_source,
|
||||
"bake_in_vae": bake_in_vae,
|
||||
"discard_weights": discard_weights,
|
||||
"is_inpainting": result_is_inpainting_model,
|
||||
"is_instruct_pix2pix": result_is_instruct_pix2pix_model
|
||||
}
|
||||
|
||||
sd_merge_models = {}
|
||||
|
||||
def add_model_metadata(checkpoint_info):
|
||||
checkpoint_info.calculate_shorthash()
|
||||
sd_merge_models[checkpoint_info.sha256] = {
|
||||
"name": checkpoint_info.name,
|
||||
"legacy_hash": checkpoint_info.hash,
|
||||
"sd_merge_recipe": checkpoint_info.metadata.get("sd_merge_recipe", None)
|
||||
}
|
||||
|
||||
sd_merge_models.update(checkpoint_info.metadata.get("sd_merge_models", {}))
|
||||
|
||||
add_model_metadata(primary_model_info)
|
||||
if secondary_model_info:
|
||||
add_model_metadata(secondary_model_info)
|
||||
if tertiary_model_info:
|
||||
add_model_metadata(tertiary_model_info)
|
||||
|
||||
metadata["sd_merge_recipe"] = json.dumps(merge_recipe)
|
||||
metadata["sd_merge_models"] = json.dumps(sd_merge_models)
|
||||
|
||||
_, extension = os.path.splitext(output_modelname)
|
||||
if extension.lower() == ".safetensors":
|
||||
safetensors.torch.save_file(theta_0, output_modelname, metadata=metadata if len(metadata) > 0 else None)
|
||||
else:
|
||||
torch.save(theta_0, output_modelname)
|
||||
|
||||
sd_models.list_models()
|
||||
created_model = next((ckpt for ckpt in sd_models.checkpoints_list.values() if ckpt.name == filename), None)
|
||||
if created_model:
|
||||
created_model.calculate_shorthash()
|
||||
|
||||
# TODO inside create_config() sd_models_config.find_checkpoint_config_near_filename() is called which has been commented out
|
||||
#create_config(output_modelname, config_source, primary_model_info, secondary_model_info, tertiary_model_info)
|
||||
|
||||
print(f"Checkpoint saved to {output_modelname}.")
|
||||
shared.state.textinfo = "Checkpoint saved"
|
||||
shared.state.end()
|
||||
|
||||
return [*[gr.Dropdown.update(choices=sd_models.checkpoint_tiles()) for _ in range(4)], "Checkpoint saved to " + output_modelname]
|
||||
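The three interpolation formulas used by run_modelmerger above reduce to simple tensor arithmetic. A minimal sketch on toy tensors, assuming nothing beyond torch; the values are illustrative only:

import torch

theta_a = torch.tensor([1.0, 2.0])   # a weight from model A
theta_b = torch.tensor([3.0, 6.0])   # the same weight from model B
theta_c = torch.tensor([2.0, 2.0])   # the same weight from model C
alpha = 0.5                          # the "multiplier" slider

weighted = (1 - alpha) * theta_a + alpha * theta_b   # "Weighted sum"
diff = theta_b - theta_c                             # get_difference(B, C)
added = theta_a + alpha * diff                       # "Add difference"

assert torch.allclose(weighted, torch.tensor([2.0, 4.0]))
assert torch.allclose(added, torch.tensor([1.5, 4.0]))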
19
modules/face_restoration.py
Executable file
@@ -0,0 +1,19 @@
|
||||
from modules import shared
|
||||
|
||||
|
||||
class FaceRestoration:
|
||||
def name(self):
|
||||
return "None"
|
||||
|
||||
def restore(self, np_image):
|
||||
return np_image
|
||||
|
||||
|
||||
def restore_faces(np_image):
|
||||
face_restorers = [x for x in shared.face_restorers if x.name() == shared.opts.face_restoration_model or shared.opts.face_restoration_model is None]
|
||||
if len(face_restorers) == 0:
|
||||
return np_image
|
||||
|
||||
face_restorer = face_restorers[0]
|
||||
|
||||
return face_restorer.restore(np_image)
|
||||
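A minimal, self-contained sketch (not part of the diff) of the selection logic in restore_faces above: the first restorer whose name() matches the configured model wins, and the no-op base class is the fallback. The Base/Dummy classes and the "GFPGAN" name are placeholders.

import numpy as np

class Base:
    def name(self): return "None"
    def restore(self, np_image): return np_image

class Dummy(Base):
    def name(self): return "GFPGAN"
    def restore(self, np_image): return np_image.copy()

face_restorers = [Base(), Dummy()]
selected = "GFPGAN"  # stands in for shared.opts.face_restoration_model

matches = [x for x in face_restorers if x.name() == selected or selected is None]
img = np.zeros((4, 4, 3), dtype=np.uint8)
out = (matches[0] if matches else Base()).restore(img)
assert out is not img  # the Dummy restorer returned a copy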
182
modules/face_restoration_utils.py
Executable file
@@ -0,0 +1,182 @@
|
||||
from __future__ import annotations
|
||||
|
||||
import logging
|
||||
import os
|
||||
from functools import cached_property
|
||||
from typing import TYPE_CHECKING, Callable
|
||||
|
||||
import cv2
|
||||
import numpy as np
|
||||
import torch
|
||||
|
||||
from modules import devices, errors, face_restoration, shared
|
||||
from modules_forge.utils import prepare_free_memory
|
||||
|
||||
if TYPE_CHECKING:
|
||||
from facexlib.utils.face_restoration_helper import FaceRestoreHelper
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
def bgr_image_to_rgb_tensor(img: np.ndarray) -> torch.Tensor:
|
||||
"""Convert a BGR NumPy image in [0..1] range to a PyTorch RGB float32 tensor."""
|
||||
assert img.shape[2] == 3, "image must have 3 channels"
|
||||
if img.dtype == "float64":
|
||||
img = img.astype("float32")
|
||||
img = cv2.cvtColor(img, cv2.COLOR_BGR2RGB)
|
||||
return torch.from_numpy(img.transpose(2, 0, 1)).float()
|
||||
|
||||
|
||||
def rgb_tensor_to_bgr_image(tensor: torch.Tensor, *, min_max=(0.0, 1.0)) -> np.ndarray:
|
||||
"""
|
||||
Convert a PyTorch RGB tensor in range `min_max` to a BGR NumPy image in [0..1] range.
|
||||
"""
|
||||
tensor = tensor.squeeze(0).float().detach().cpu().clamp_(*min_max)
|
||||
tensor = (tensor - min_max[0]) / (min_max[1] - min_max[0])
|
||||
assert tensor.dim() == 3, "tensor must be RGB"
|
||||
img_np = tensor.numpy().transpose(1, 2, 0)
|
||||
if img_np.shape[2] == 1: # gray image, no RGB/BGR required
|
||||
return np.squeeze(img_np, axis=2)
|
||||
return cv2.cvtColor(img_np, cv2.COLOR_BGR2RGB)
|
||||
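A minimal round-trip sketch for the two converters above, assuming only NumPy; the random image stands in for a real BGR frame:

import numpy as np

bgr = np.random.rand(8, 8, 3).astype("float32")        # BGR image in [0..1]
tensor = bgr_image_to_rgb_tensor(bgr)                    # (3, 8, 8) RGB float tensor
back = rgb_tensor_to_bgr_image(tensor, min_max=(0.0, 1.0))

assert back.shape == bgr.shape
assert np.allclose(back, bgr, atol=1e-6)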
|
||||
|
||||
def create_face_helper(device) -> FaceRestoreHelper:
|
||||
from facexlib.detection import retinaface
|
||||
from facexlib.utils.face_restoration_helper import FaceRestoreHelper
|
||||
if hasattr(retinaface, 'device'):
|
||||
retinaface.device = device
|
||||
return FaceRestoreHelper(
|
||||
upscale_factor=1,
|
||||
face_size=512,
|
||||
crop_ratio=(1, 1),
|
||||
det_model='retinaface_resnet50',
|
||||
save_ext='png',
|
||||
use_parse=True,
|
||||
device=device,
|
||||
)
|
||||
|
||||
|
||||
def restore_with_face_helper(
|
||||
np_image: np.ndarray,
|
||||
face_helper: FaceRestoreHelper,
|
||||
restore_face: Callable[[torch.Tensor], torch.Tensor],
|
||||
) -> np.ndarray:
|
||||
"""
|
||||
Find faces in the image using face_helper, restore them using restore_face, and paste them back into the image.
|
||||
|
||||
`restore_face` should take a cropped face image and return a restored face image.
|
||||
"""
|
||||
from torchvision.transforms.functional import normalize
|
||||
np_image = np_image[:, :, ::-1]
|
||||
original_resolution = np_image.shape[0:2]
|
||||
|
||||
try:
|
||||
logger.debug("Detecting faces...")
|
||||
face_helper.clean_all()
|
||||
face_helper.read_image(np_image)
|
||||
face_helper.get_face_landmarks_5(only_center_face=False, resize=640, eye_dist_threshold=5)
|
||||
face_helper.align_warp_face()
|
||||
logger.debug("Found %d faces, restoring", len(face_helper.cropped_faces))
|
||||
for cropped_face in face_helper.cropped_faces:
|
||||
cropped_face_t = bgr_image_to_rgb_tensor(cropped_face / 255.0)
|
||||
normalize(cropped_face_t, (0.5, 0.5, 0.5), (0.5, 0.5, 0.5), inplace=True)
|
||||
cropped_face_t = cropped_face_t.unsqueeze(0).to(devices.device_codeformer)
|
||||
|
||||
try:
|
||||
with torch.no_grad():
|
||||
cropped_face_t = restore_face(cropped_face_t)
|
||||
devices.torch_gc()
|
||||
except Exception:
|
||||
errors.report('Failed face-restoration inference', exc_info=True)
|
||||
|
||||
restored_face = rgb_tensor_to_bgr_image(cropped_face_t, min_max=(-1, 1))
|
||||
restored_face = (restored_face * 255.0).astype('uint8')
|
||||
face_helper.add_restored_face(restored_face)
|
||||
|
||||
logger.debug("Merging restored faces into image")
|
||||
face_helper.get_inverse_affine(None)
|
||||
img = face_helper.paste_faces_to_input_image()
|
||||
img = img[:, :, ::-1]
|
||||
if original_resolution != img.shape[0:2]:
|
||||
img = cv2.resize(
|
||||
img,
|
||||
(0, 0),
|
||||
fx=original_resolution[1] / img.shape[1],
|
||||
fy=original_resolution[0] / img.shape[0],
|
||||
interpolation=cv2.INTER_LINEAR,
|
||||
)
|
||||
logger.debug("Face restoration complete")
|
||||
finally:
|
||||
face_helper.clean_all()
|
||||
return img
|
||||
|
||||
|
||||
class CommonFaceRestoration(face_restoration.FaceRestoration):
|
||||
net: torch.nn.Module | None
|
||||
model_url: str
|
||||
model_download_name: str
|
||||
|
||||
def __init__(self, model_path: str):
|
||||
super().__init__()
|
||||
self.net = None
|
||||
self.model_path = model_path
|
||||
os.makedirs(model_path, exist_ok=True)
|
||||
|
||||
@cached_property
|
||||
def face_helper(self) -> FaceRestoreHelper:
|
||||
return create_face_helper(self.get_device())
|
||||
|
||||
def send_model_to(self, device):
|
||||
if self.net:
|
||||
logger.debug("Sending %s to %s", self.net, device)
|
||||
self.net.to(device)
|
||||
if self.face_helper:
|
||||
logger.debug("Sending face helper to %s", device)
|
||||
self.face_helper.face_det.to(device)
|
||||
self.face_helper.face_parse.to(device)
|
||||
|
||||
def get_device(self):
|
||||
raise NotImplementedError("get_device must be implemented by subclasses")
|
||||
|
||||
def load_net(self) -> torch.nn.Module:
|
||||
raise NotImplementedError("load_net must be implemented by subclasses")
|
||||
|
||||
def restore_with_helper(
|
||||
self,
|
||||
np_image: np.ndarray,
|
||||
restore_face: Callable[[torch.Tensor], torch.Tensor],
|
||||
) -> np.ndarray:
|
||||
try:
|
||||
if self.net is None:
|
||||
self.net = self.load_net()
|
||||
except Exception:
|
||||
logger.warning("Unable to load face-restoration model", exc_info=True)
|
||||
return np_image
|
||||
|
||||
try:
|
||||
prepare_free_memory()
|
||||
self.send_model_to(self.get_device())
|
||||
return restore_with_face_helper(np_image, self.face_helper, restore_face)
|
||||
finally:
|
||||
if shared.opts.face_restoration_unload:
|
||||
self.send_model_to(devices.cpu)
|
||||
|
||||
|
||||
def patch_facexlib(dirname: str) -> None:
|
||||
import facexlib.detection
|
||||
import facexlib.parsing
|
||||
|
||||
det_facex_load_file_from_url = facexlib.detection.load_file_from_url
|
||||
par_facex_load_file_from_url = facexlib.parsing.load_file_from_url
|
||||
|
||||
def update_kwargs(kwargs):
|
||||
return dict(kwargs, save_dir=dirname, model_dir=None)
|
||||
|
||||
def facex_load_file_from_url(**kwargs):
|
||||
return det_facex_load_file_from_url(**update_kwargs(kwargs))
|
||||
|
||||
def facex_load_file_from_url2(**kwargs):
|
||||
return par_facex_load_file_from_url(**update_kwargs(kwargs))
|
||||
|
||||
facexlib.detection.load_file_from_url = facex_load_file_from_url
|
||||
facexlib.parsing.load_file_from_url = facex_load_file_from_url2
|
||||
37
modules/fifo_lock.py
Executable file
@@ -0,0 +1,37 @@
|
||||
import threading
|
||||
import collections
|
||||
|
||||
|
||||
# reference: https://gist.github.com/vitaliyp/6d54dd76ca2c3cdfc1149d33007dc34a
|
||||
class FIFOLock(object):
|
||||
def __init__(self):
|
||||
self._lock = threading.Lock()
|
||||
self._inner_lock = threading.Lock()
|
||||
self._pending_threads = collections.deque()
|
||||
|
||||
def acquire(self, blocking=True):
|
||||
with self._inner_lock:
|
||||
lock_acquired = self._lock.acquire(False)
|
||||
if lock_acquired:
|
||||
return True
|
||||
elif not blocking:
|
||||
return False
|
||||
|
||||
release_event = threading.Event()
|
||||
self._pending_threads.append(release_event)
|
||||
|
||||
release_event.wait()
|
||||
return self._lock.acquire()
|
||||
|
||||
def release(self):
|
||||
with self._inner_lock:
|
||||
if self._pending_threads:
|
||||
release_event = self._pending_threads.popleft()
|
||||
release_event.set()
|
||||
|
||||
self._lock.release()
|
||||
|
||||
__enter__ = acquire
|
||||
|
||||
def __exit__(self, t, v, tb):
|
||||
self.release()
|
||||
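A minimal usage sketch for the FIFOLock above: because acquire/release are also wired up as __enter__/__exit__, it can be used as a context manager, and waiting threads are released in the order they queued. The worker body is a placeholder.

import threading

lock = FIFOLock()
results = []

def worker(i):
    with lock:                 # blocks until it is this thread's turn
        results.append(i)

threads = [threading.Thread(target=worker, args=(i,)) for i in range(4)]
for t in threads:
    t.start()
for t in threads:
    t.join()

assert sorted(results) == [0, 1, 2, 3]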
80
modules/gfpgan_model.py
Executable file
@@ -0,0 +1,80 @@
|
||||
from __future__ import annotations
|
||||
|
||||
import logging
|
||||
import os
|
||||
|
||||
import torch
|
||||
|
||||
from modules import (
|
||||
devices,
|
||||
errors,
|
||||
face_restoration,
|
||||
face_restoration_utils,
|
||||
modelloader,
|
||||
shared,
|
||||
)
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
model_url = "https://github.com/TencentARC/GFPGAN/releases/download/v1.3.0/GFPGANv1.4.pth"
|
||||
model_download_name = "GFPGANv1.4.pth"
|
||||
gfpgan_face_restorer: face_restoration.FaceRestoration | None = None
|
||||
|
||||
|
||||
class FaceRestorerGFPGAN(face_restoration_utils.CommonFaceRestoration):
|
||||
def name(self):
|
||||
return "GFPGAN"
|
||||
|
||||
def get_device(self):
|
||||
return devices.device_gfpgan
|
||||
|
||||
def load_net(self) -> torch.nn.Module:
|
||||
for model_path in modelloader.load_models(
|
||||
model_path=self.model_path,
|
||||
model_url=model_url,
|
||||
command_path=self.model_path,
|
||||
download_name=model_download_name,
|
||||
ext_filter=['.pth'],
|
||||
):
|
||||
if 'GFPGAN' in os.path.basename(model_path):
|
||||
return modelloader.load_spandrel_model(
|
||||
model_path,
|
||||
device=self.get_device(),
|
||||
expected_architecture='GFPGAN',
|
||||
).model
|
||||
|
||||
# if we reach here, no model was found; the code above only downloads one when the GFPGAN directory has no models at all
|
||||
# this will download it if the supporting models exist
|
||||
try:
|
||||
GFPGANmodel = modelloader.load_file_from_url(model_url, model_dir=self.model_path, file_name=model_download_name)
|
||||
return modelloader.load_spandrel_model(
|
||||
GFPGANmodel,
|
||||
device=self.get_device(),
|
||||
expected_architecture='GFPGAN',
|
||||
).model
|
||||
except Exception as e:
|
||||
raise ValueError("No GFPGAN model found") from e
|
||||
|
||||
def restore(self, np_image):
|
||||
def restore_face(cropped_face_t):
|
||||
assert self.net is not None
|
||||
return self.net(cropped_face_t, return_rgb=False)[0]
|
||||
|
||||
return self.restore_with_helper(np_image, restore_face)
|
||||
|
||||
|
||||
def gfpgan_fix_faces(np_image):
|
||||
if gfpgan_face_restorer:
|
||||
return gfpgan_face_restorer.restore(np_image)
|
||||
logger.warning("GFPGAN face restorer not set up")
|
||||
return np_image
|
||||
|
||||
|
||||
def setup_model(dirname: str) -> None:
|
||||
global gfpgan_face_restorer
|
||||
|
||||
try:
|
||||
face_restoration_utils.patch_facexlib(dirname)
|
||||
gfpgan_face_restorer = FaceRestorerGFPGAN(model_path=dirname)
|
||||
shared.face_restorers.append(gfpgan_face_restorer)
|
||||
except Exception:
|
||||
errors.report("Error setting up GFPGAN", exc_info=True)
|
||||
42
modules/gitpython_hack.py
Executable file
@@ -0,0 +1,42 @@
|
||||
from __future__ import annotations
|
||||
|
||||
import io
|
||||
import subprocess
|
||||
|
||||
import git
|
||||
|
||||
|
||||
class Git(git.Git):
|
||||
"""
|
||||
Git subclassed to never use persistent processes.
|
||||
"""
|
||||
|
||||
def _get_persistent_cmd(self, attr_name, cmd_name, *args, **kwargs):
|
||||
raise NotImplementedError(f"Refusing to use persistent process: {attr_name} ({cmd_name} {args} {kwargs})")
|
||||
|
||||
def get_object_header(self, ref: str | bytes) -> tuple[str, str, int]:
|
||||
ret = subprocess.check_output(
|
||||
[self.GIT_PYTHON_GIT_EXECUTABLE, "cat-file", "--batch-check"],
|
||||
input=self._prepare_ref(ref),
|
||||
cwd=self._working_dir,
|
||||
timeout=2,
|
||||
)
|
||||
return self._parse_object_header(ret)
|
||||
|
||||
def stream_object_data(self, ref: str) -> tuple[str, str, int, Git.CatFileContentStream]:
|
||||
# Not really streaming, per se; this buffers the entire object in memory.
|
||||
# Shouldn't be a problem for our use case, since we're only using this for
|
||||
# object headers (commit objects).
|
||||
ret = subprocess.check_output(
|
||||
[self.GIT_PYTHON_GIT_EXECUTABLE, "cat-file", "--batch"],
|
||||
input=self._prepare_ref(ref),
|
||||
cwd=self._working_dir,
|
||||
timeout=30,
|
||||
)
|
||||
bio = io.BytesIO(ret)
|
||||
hexsha, typename, size = self._parse_object_header(bio.readline())
|
||||
return (hexsha, typename, size, self.CatFileContentStream(size, bio))
|
||||
|
||||
|
||||
class Repo(git.Repo):
|
||||
GitCommandWrapperType = Git
|
||||
181
modules/gradio_extensions.py
Executable file
@@ -0,0 +1,181 @@
|
||||
import inspect
|
||||
import types
|
||||
import warnings
|
||||
from functools import wraps
|
||||
|
||||
import gradio as gr
|
||||
import gradio.component_meta
|
||||
|
||||
|
||||
from modules import scripts, ui_tempdir, patches
|
||||
|
||||
|
||||
class GradioDeprecationWarning(DeprecationWarning):
|
||||
pass
|
||||
|
||||
|
||||
def add_classes_to_gradio_component(comp):
|
||||
"""
|
||||
this adds gradio-* classes to the component for CSS styling (e.g. gradio-button for gr.Button), as well as some others
|
||||
"""
|
||||
|
||||
comp.elem_classes = [f"gradio-{comp.get_block_name()}", *(getattr(comp, 'elem_classes', None) or [])]
|
||||
|
||||
if getattr(comp, 'multiselect', False):
|
||||
comp.elem_classes.append('multiselect')
|
||||
|
||||
|
||||
def IOComponent_init(self, *args, **kwargs):
|
||||
self.webui_tooltip = kwargs.pop('tooltip', None)
|
||||
|
||||
if scripts.scripts_current is not None:
|
||||
scripts.scripts_current.before_component(self, **kwargs)
|
||||
|
||||
scripts.script_callbacks.before_component_callback(self, **kwargs)
|
||||
|
||||
res = original_IOComponent_init(self, *args, **kwargs)
|
||||
|
||||
add_classes_to_gradio_component(self)
|
||||
|
||||
scripts.script_callbacks.after_component_callback(self, **kwargs)
|
||||
|
||||
if scripts.scripts_current is not None:
|
||||
scripts.scripts_current.after_component(self, **kwargs)
|
||||
|
||||
return res
|
||||
|
||||
|
||||
def Block_get_config(self):
|
||||
config = original_Block_get_config(self)
|
||||
|
||||
webui_tooltip = getattr(self, 'webui_tooltip', None)
|
||||
if webui_tooltip:
|
||||
config["webui_tooltip"] = webui_tooltip
|
||||
|
||||
config.pop('example_inputs', None)
|
||||
|
||||
return config
|
||||
|
||||
|
||||
def BlockContext_init(self, *args, **kwargs):
|
||||
if scripts.scripts_current is not None:
|
||||
scripts.scripts_current.before_component(self, **kwargs)
|
||||
|
||||
scripts.script_callbacks.before_component_callback(self, **kwargs)
|
||||
|
||||
res = original_BlockContext_init(self, *args, **kwargs)
|
||||
|
||||
add_classes_to_gradio_component(self)
|
||||
|
||||
scripts.script_callbacks.after_component_callback(self, **kwargs)
|
||||
|
||||
if scripts.scripts_current is not None:
|
||||
scripts.scripts_current.after_component(self, **kwargs)
|
||||
|
||||
return res
|
||||
|
||||
|
||||
def Blocks_get_config_file(self, *args, **kwargs):
|
||||
config = original_Blocks_get_config_file(self, *args, **kwargs)
|
||||
|
||||
for comp_config in config["components"]:
|
||||
if "example_inputs" in comp_config:
|
||||
comp_config["example_inputs"] = {"serialized": []}
|
||||
|
||||
return config
|
||||
|
||||
|
||||
original_IOComponent_init = patches.patch(__name__, obj=gr.components.Component, field="__init__", replacement=IOComponent_init)
|
||||
original_Block_get_config = patches.patch(__name__, obj=gr.blocks.Block, field="get_config", replacement=Block_get_config)
|
||||
original_BlockContext_init = patches.patch(__name__, obj=gr.blocks.BlockContext, field="__init__", replacement=BlockContext_init)
|
||||
original_Blocks_get_config_file = patches.patch(__name__, obj=gr.blocks.Blocks, field="get_config_file", replacement=Blocks_get_config_file)
|
||||
|
||||
|
||||
ui_tempdir.install_ui_tempdir_override()
|
||||
|
||||
|
||||
def gradio_component_meta_create_or_modify_pyi(component_class, class_name, events):
|
||||
if hasattr(component_class, 'webui_do_not_create_gradio_pyi_thank_you'):
|
||||
return
|
||||
|
||||
gradio_component_meta_create_or_modify_pyi_original(component_class, class_name, events)
|
||||
|
||||
|
||||
# this prevents creation of .pyi files in webui dir
|
||||
gradio_component_meta_create_or_modify_pyi_original = patches.patch(__file__, gradio.component_meta, 'create_or_modify_pyi', gradio_component_meta_create_or_modify_pyi)
|
||||
|
||||
# this function is broken and does not seem to do anything useful
|
||||
gradio.component_meta.updateable = lambda x: x
|
||||
|
||||
|
||||
class EventWrapper:
|
||||
def __init__(self, replaced_event):
|
||||
self.replaced_event = replaced_event
|
||||
self.has_trigger = getattr(replaced_event, 'has_trigger', None)
|
||||
self.event_name = getattr(replaced_event, 'event_name', None)
|
||||
self.callback = getattr(replaced_event, 'callback', None)
|
||||
self.real_self = getattr(replaced_event, '__self__', None)
|
||||
|
||||
def __call__(self, *args, **kwargs):
|
||||
if '_js' in kwargs:
|
||||
kwargs['js'] = kwargs['_js']
|
||||
del kwargs['_js']
|
||||
return self.replaced_event(*args, **kwargs)
|
||||
|
||||
@property
|
||||
def __self__(self):
|
||||
return self.real_self
|
||||
|
||||
|
||||
def repair(grclass):
|
||||
if not getattr(grclass, 'EVENTS', None):
|
||||
return
|
||||
|
||||
@wraps(grclass.__init__)
|
||||
def __repaired_init__(self, *args, tooltip=None, source=None, original=grclass.__init__, **kwargs):
|
||||
if source:
|
||||
kwargs["sources"] = [source]
|
||||
|
||||
allowed_kwargs = inspect.signature(original).parameters
|
||||
fixed_kwargs = {}
|
||||
for k, v in kwargs.items():
|
||||
if k in allowed_kwargs:
|
||||
fixed_kwargs[k] = v
|
||||
else:
|
||||
warnings.warn(f"unexpected argument for {grclass.__name__}: {k}", GradioDeprecationWarning, stacklevel=2)
|
||||
|
||||
original(self, *args, **fixed_kwargs)
|
||||
|
||||
self.webui_tooltip = tooltip
|
||||
|
||||
for event in self.EVENTS:
|
||||
replaced_event = getattr(self, str(event))
|
||||
fun = EventWrapper(replaced_event)
|
||||
setattr(self, str(event), fun)
|
||||
|
||||
grclass.__init__ = __repaired_init__
|
||||
grclass.update = gr.update
|
||||
|
||||
|
||||
for component in set(gr.components.__all__ + gr.layouts.__all__):
|
||||
repair(getattr(gr, component, None))
|
||||
|
||||
|
||||
class Dependency(gr.events.Dependency):
|
||||
def __init__(self, *args, **kwargs):
|
||||
super().__init__(*args, **kwargs)
|
||||
|
||||
def then(*xargs, _js=None, **xkwargs):
|
||||
if _js:
|
||||
xkwargs['js'] = _js
|
||||
|
||||
return original_then(*xargs, **xkwargs)
|
||||
|
||||
original_then = self.then
|
||||
self.then = then
|
||||
|
||||
|
||||
gr.events.Dependency = Dependency
|
||||
|
||||
gr.Box = gr.Group
|
||||
|
||||
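The kwarg-filtering pattern used by __repaired_init__ above, shown on a plain function so it can run outside gradio; greet() and its arguments are placeholders:

import inspect
import warnings

def greet(name, punctuation="!"):
    return f"hello {name}{punctuation}"

kwargs = {"name": "world", "punctuation": "?", "tooltip": "ignored"}
allowed_kwargs = inspect.signature(greet).parameters

fixed_kwargs = {}
for k, v in kwargs.items():
    if k in allowed_kwargs:
        fixed_kwargs[k] = v
    else:
        warnings.warn(f"unexpected argument for greet: {k}")

print(greet(**fixed_kwargs))  # -> hello world?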
92
modules/hashes.py
Executable file
@@ -0,0 +1,92 @@
|
||||
import hashlib
|
||||
import os.path
|
||||
|
||||
from modules import shared
|
||||
import modules.cache
|
||||
|
||||
dump_cache = modules.cache.dump_cache
|
||||
cache = modules.cache.cache
|
||||
|
||||
|
||||
def calculate_sha256_real(filename):
|
||||
hash_sha256 = hashlib.sha256()
|
||||
blksize = 1024 * 1024
|
||||
|
||||
with open(filename, "rb") as f:
|
||||
for chunk in iter(lambda: f.read(blksize), b""):
|
||||
hash_sha256.update(chunk)
|
||||
|
||||
return hash_sha256.hexdigest()
|
||||
|
||||
|
||||
def calculate_sha256(filename):
|
||||
print("Calculating real hash: ", filename)
|
||||
return calculate_sha256_real(filename)
|
||||
|
||||
|
||||
def forge_fake_calculate_sha256(filename):
|
||||
basename = os.path.basename(filename)
|
||||
hash_sha256 = hashlib.sha256()
|
||||
hash_sha256.update(basename.encode('utf-8'))
|
||||
return hash_sha256.hexdigest()
|
||||
|
||||
|
||||
def sha256_from_cache(filename, title, use_addnet_hash=False):
|
||||
hashes = cache("hashes-addnet") if use_addnet_hash else cache("hashes")
|
||||
try:
|
||||
ondisk_mtime = os.path.getmtime(filename)
|
||||
except FileNotFoundError:
|
||||
return None
|
||||
|
||||
if title not in hashes:
|
||||
return None
|
||||
|
||||
cached_sha256 = hashes[title].get("sha256", None)
|
||||
cached_mtime = hashes[title].get("mtime", 0)
|
||||
|
||||
if ondisk_mtime > cached_mtime or cached_sha256 is None:
|
||||
return None
|
||||
|
||||
return cached_sha256
|
||||
|
||||
|
||||
def sha256(filename, title, use_addnet_hash=False):
|
||||
hashes = cache("hashes-addnet") if use_addnet_hash else cache("hashes")
|
||||
|
||||
sha256_value = sha256_from_cache(filename, title, use_addnet_hash)
|
||||
if sha256_value is not None:
|
||||
return sha256_value
|
||||
|
||||
if shared.cmd_opts.no_hashing:
|
||||
return None
|
||||
|
||||
print(f"Calculating sha256 for {filename}: ", end='', flush=True)
|
||||
sha256_value = calculate_sha256_real(filename)
|
||||
print(f"{sha256_value}")
|
||||
|
||||
hashes[title] = {
|
||||
"mtime": os.path.getmtime(filename),
|
||||
"sha256": sha256_value,
|
||||
}
|
||||
|
||||
dump_cache()
|
||||
|
||||
return sha256_value
|
||||
|
||||
|
||||
def addnet_hash_safetensors(b):
|
||||
"""kohya-ss hash for safetensors from https://github.com/kohya-ss/sd-scripts/blob/main/library/train_util.py"""
|
||||
hash_sha256 = hashlib.sha256()
|
||||
blksize = 1024 * 1024
|
||||
|
||||
b.seek(0)
|
||||
header = b.read(8)
|
||||
n = int.from_bytes(header, "little")
|
||||
|
||||
offset = n + 8
|
||||
b.seek(offset)
|
||||
for chunk in iter(lambda: b.read(blksize), b""):
|
||||
hash_sha256.update(chunk)
|
||||
|
||||
return hash_sha256.hexdigest()
|
||||
|
||||
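A minimal sketch of computing the kohya-ss "addnet" hash with addnet_hash_safetensors above: the first 8 bytes of a .safetensors file hold the little-endian length of the JSON header, and only the bytes after that header (the tensor data) are hashed. The file path is a placeholder.

with open("example_model.safetensors", "rb") as f:
    addnet_hash = addnet_hash_safetensors(f)

print(addnet_hash)  # full sha256 hex digest of the tensor-data section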
45
modules/hat_model.py
Executable file
@@ -0,0 +1,45 @@
|
||||
import os
|
||||
import sys
|
||||
|
||||
from modules import modelloader, devices
|
||||
from modules.shared import opts
|
||||
from modules.upscaler import Upscaler, UpscalerData
|
||||
from modules.upscaler_utils import upscale_with_model
|
||||
from modules_forge.utils import prepare_free_memory
|
||||
|
||||
|
||||
class UpscalerHAT(Upscaler):
|
||||
def __init__(self, dirname):
|
||||
self.name = "HAT"
|
||||
self.scalers = []
|
||||
self.user_path = dirname
|
||||
super().__init__()
|
||||
for file in self.find_models(ext_filter=[".pt", ".pth"]):
|
||||
name = modelloader.friendly_name(file)
|
||||
scale = 4 # TODO: scale might not be 4, but we can't know without loading the model
|
||||
scaler_data = UpscalerData(name, file, upscaler=self, scale=scale)
|
||||
self.scalers.append(scaler_data)
|
||||
|
||||
def do_upscale(self, img, selected_model):
|
||||
prepare_free_memory()
|
||||
try:
|
||||
model = self.load_model(selected_model)
|
||||
except Exception as e:
|
||||
print(f"Unable to load HAT model {selected_model}: {e}", file=sys.stderr)
|
||||
return img
|
||||
model.to(devices.device_esrgan) # TODO: should probably be device_hat
|
||||
return upscale_with_model(
|
||||
model,
|
||||
img,
|
||||
tile_size=opts.ESRGAN_tile, # TODO: should probably be HAT_tile
|
||||
tile_overlap=opts.ESRGAN_tile_overlap, # TODO: should probably be HAT_tile_overlap
|
||||
)
|
||||
|
||||
def load_model(self, path: str):
|
||||
if not os.path.isfile(path):
|
||||
raise FileNotFoundError(f"Model file {path} not found")
|
||||
return modelloader.load_spandrel_model(
|
||||
path,
|
||||
device=devices.device_esrgan, # TODO: should probably be device_hat
|
||||
expected_architecture='HAT',
|
||||
)
|
||||
781
modules/hypernetworks/hypernetwork.py
Executable file
@@ -0,0 +1,781 @@
|
||||
import datetime
|
||||
import glob
|
||||
import html
|
||||
import os
|
||||
import inspect
|
||||
from contextlib import closing
|
||||
|
||||
import torch
|
||||
import tqdm
|
||||
from einops import rearrange, repeat
|
||||
from backend.nn.unet import default
|
||||
from modules import devices, sd_models, shared, sd_samplers, hashes, sd_hijack_checkpoint, errors
|
||||
from modules.textual_inversion import textual_inversion
|
||||
from torch import einsum
|
||||
from torch.nn.init import normal_, xavier_normal_, xavier_uniform_, kaiming_normal_, kaiming_uniform_, zeros_
|
||||
|
||||
from collections import deque
|
||||
from statistics import stdev, mean
|
||||
|
||||
|
||||
optimizer_dict = {optim_name : cls_obj for optim_name, cls_obj in inspect.getmembers(torch.optim, inspect.isclass) if optim_name != "Optimizer"}
|
||||
|
||||
class HypernetworkModule(torch.nn.Module):
|
||||
activation_dict = {
|
||||
"linear": torch.nn.Identity,
|
||||
"relu": torch.nn.ReLU,
|
||||
"leakyrelu": torch.nn.LeakyReLU,
|
||||
"elu": torch.nn.ELU,
|
||||
"swish": torch.nn.Hardswish,
|
||||
"tanh": torch.nn.Tanh,
|
||||
"sigmoid": torch.nn.Sigmoid,
|
||||
}
|
||||
activation_dict.update({cls_name.lower(): cls_obj for cls_name, cls_obj in inspect.getmembers(torch.nn.modules.activation) if inspect.isclass(cls_obj) and cls_obj.__module__ == 'torch.nn.modules.activation'})
|
||||
|
||||
def __init__(self, dim, state_dict=None, layer_structure=None, activation_func=None, weight_init='Normal',
|
||||
add_layer_norm=False, activate_output=False, dropout_structure=None):
|
||||
super().__init__()
|
||||
|
||||
self.multiplier = 1.0
|
||||
|
||||
assert layer_structure is not None, "layer_structure must not be None"
|
||||
assert layer_structure[0] == 1, "Multiplier Sequence should start with size 1!"
|
||||
assert layer_structure[-1] == 1, "Multiplier Sequence should end with size 1!"
|
||||
|
||||
linears = []
|
||||
for i in range(len(layer_structure) - 1):
|
||||
|
||||
# Add a fully-connected layer
|
||||
linears.append(torch.nn.Linear(int(dim * layer_structure[i]), int(dim * layer_structure[i+1])))
|
||||
|
||||
# Add an activation func except last layer
|
||||
if activation_func == "linear" or activation_func is None or (i >= len(layer_structure) - 2 and not activate_output):
|
||||
pass
|
||||
elif activation_func in self.activation_dict:
|
||||
linears.append(self.activation_dict[activation_func]())
|
||||
else:
|
||||
raise RuntimeError(f'hypernetwork uses an unsupported activation function: {activation_func}')
|
||||
|
||||
# Add layer normalization
|
||||
if add_layer_norm:
|
||||
linears.append(torch.nn.LayerNorm(int(dim * layer_structure[i+1])))
|
||||
|
||||
# Everything should be now parsed into dropout structure, and applied here.
|
||||
# Since we only have dropouts after layers, dropout structure should start with 0 and end with 0.
|
||||
if dropout_structure is not None and dropout_structure[i+1] > 0:
|
||||
assert 0 < dropout_structure[i+1] < 1, "Dropout probability should be 0 or float between 0 and 1!"
|
||||
linears.append(torch.nn.Dropout(p=dropout_structure[i+1]))
|
||||
# Example: [1, 2, 1] -> dropout is skipped when last_layer_dropout is False. [1, 2, 2, 1] -> [0, 0.3, 0, 0]; when it's True, [0, 0.3, 0.3, 0].
|
||||
|
||||
self.linear = torch.nn.Sequential(*linears)
|
||||
|
||||
if state_dict is not None:
|
||||
self.fix_old_state_dict(state_dict)
|
||||
self.load_state_dict(state_dict)
|
||||
else:
|
||||
for layer in self.linear:
|
||||
if type(layer) == torch.nn.Linear or type(layer) == torch.nn.LayerNorm:
|
||||
w, b = layer.weight.data, layer.bias.data
|
||||
if weight_init == "Normal" or type(layer) == torch.nn.LayerNorm:
|
||||
normal_(w, mean=0.0, std=0.01)
|
||||
normal_(b, mean=0.0, std=0)
|
||||
elif weight_init == 'XavierUniform':
|
||||
xavier_uniform_(w)
|
||||
zeros_(b)
|
||||
elif weight_init == 'XavierNormal':
|
||||
xavier_normal_(w)
|
||||
zeros_(b)
|
||||
elif weight_init == 'KaimingUniform':
|
||||
kaiming_uniform_(w, nonlinearity='leaky_relu' if 'leakyrelu' == activation_func else 'relu')
|
||||
zeros_(b)
|
||||
elif weight_init == 'KaimingNormal':
|
||||
kaiming_normal_(w, nonlinearity='leaky_relu' if 'leakyrelu' == activation_func else 'relu')
|
||||
zeros_(b)
|
||||
else:
|
||||
raise KeyError(f"Key {weight_init} is not defined as initialization!")
|
||||
devices.torch_npu_set_device()
|
||||
self.to(devices.device)
|
||||
|
||||
def fix_old_state_dict(self, state_dict):
|
||||
changes = {
|
||||
'linear1.bias': 'linear.0.bias',
|
||||
'linear1.weight': 'linear.0.weight',
|
||||
'linear2.bias': 'linear.1.bias',
|
||||
'linear2.weight': 'linear.1.weight',
|
||||
}
|
||||
|
||||
for fr, to in changes.items():
|
||||
x = state_dict.get(fr, None)
|
||||
if x is None:
|
||||
continue
|
||||
|
||||
del state_dict[fr]
|
||||
state_dict[to] = x
|
||||
|
||||
def forward(self, x):
|
||||
return x + self.linear(x) * (self.multiplier if not self.training else 1)
|
||||
|
||||
def trainables(self):
|
||||
layer_structure = []
|
||||
for layer in self.linear:
|
||||
if type(layer) == torch.nn.Linear or type(layer) == torch.nn.LayerNorm:
|
||||
layer_structure += [layer.weight, layer.bias]
|
||||
return layer_structure
|
||||
|
||||
|
||||
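A minimal sketch of using the HypernetworkModule above directly: it is a residual MLP, so with the default multiplier the output is x + MLP(x) and the shape is preserved. The sizes and activation are placeholders.

import torch

module = HypernetworkModule(dim=8, layer_structure=[1, 2, 1], activation_func="relu")
module.eval()

x = torch.randn(4, 8, device=devices.device)
out = module(x)
assert out.shape == x.shape  # residual connection keeps the shape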
# param layer_structure: sequence used for its length; use_dropout: controlling boolean; last_layer_dropout: kept for compatibility checks.
|
||||
def parse_dropout_structure(layer_structure, use_dropout, last_layer_dropout):
|
||||
if layer_structure is None:
|
||||
layer_structure = [1, 2, 1]
|
||||
if not use_dropout:
|
||||
return [0] * len(layer_structure)
|
||||
dropout_values = [0]
|
||||
dropout_values.extend([0.3] * (len(layer_structure) - 3))
|
||||
if last_layer_dropout:
|
||||
dropout_values.append(0.3)
|
||||
else:
|
||||
dropout_values.append(0)
|
||||
dropout_values.append(0)
|
||||
return dropout_values
|
||||
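A few concrete inputs and outputs for parse_dropout_structure above, matching the example given in the HypernetworkModule comments:

assert parse_dropout_structure([1, 2, 2, 1], use_dropout=True, last_layer_dropout=True) == [0, 0.3, 0.3, 0]
assert parse_dropout_structure([1, 2, 2, 1], use_dropout=True, last_layer_dropout=False) == [0, 0.3, 0, 0]
assert parse_dropout_structure([1, 2, 1], use_dropout=False, last_layer_dropout=True) == [0, 0, 0]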
|
||||
|
||||
class Hypernetwork:
|
||||
filename = None
|
||||
name = None
|
||||
|
||||
def __init__(self, name=None, enable_sizes=None, layer_structure=None, activation_func=None, weight_init=None, add_layer_norm=False, use_dropout=False, activate_output=False, **kwargs):
|
||||
self.filename = None
|
||||
self.name = name
|
||||
self.layers = {}
|
||||
self.step = 0
|
||||
self.sd_checkpoint = None
|
||||
self.sd_checkpoint_name = None
|
||||
self.layer_structure = layer_structure
|
||||
self.activation_func = activation_func
|
||||
self.weight_init = weight_init
|
||||
self.add_layer_norm = add_layer_norm
|
||||
self.use_dropout = use_dropout
|
||||
self.activate_output = activate_output
|
||||
self.last_layer_dropout = kwargs.get('last_layer_dropout', True)
|
||||
self.dropout_structure = kwargs.get('dropout_structure', None)
|
||||
if self.dropout_structure is None:
|
||||
self.dropout_structure = parse_dropout_structure(self.layer_structure, self.use_dropout, self.last_layer_dropout)
|
||||
self.optimizer_name = None
|
||||
self.optimizer_state_dict = None
|
||||
self.optional_info = None
|
||||
|
||||
for size in enable_sizes or []:
|
||||
self.layers[size] = (
|
||||
HypernetworkModule(size, None, self.layer_structure, self.activation_func, self.weight_init,
|
||||
self.add_layer_norm, self.activate_output, dropout_structure=self.dropout_structure),
|
||||
HypernetworkModule(size, None, self.layer_structure, self.activation_func, self.weight_init,
|
||||
self.add_layer_norm, self.activate_output, dropout_structure=self.dropout_structure),
|
||||
)
|
||||
self.eval()
|
||||
|
||||
def weights(self):
|
||||
res = []
|
||||
for layers in self.layers.values():
|
||||
for layer in layers:
|
||||
res += layer.parameters()
|
||||
return res
|
||||
|
||||
def train(self, mode=True):
|
||||
for layers in self.layers.values():
|
||||
for layer in layers:
|
||||
layer.train(mode=mode)
|
||||
for param in layer.parameters():
|
||||
param.requires_grad = mode
|
||||
|
||||
def to(self, device):
|
||||
for layers in self.layers.values():
|
||||
for layer in layers:
|
||||
layer.to(device)
|
||||
|
||||
return self
|
||||
|
||||
def set_multiplier(self, multiplier):
|
||||
for layers in self.layers.values():
|
||||
for layer in layers:
|
||||
layer.multiplier = multiplier
|
||||
|
||||
return self
|
||||
|
||||
def eval(self):
|
||||
for layers in self.layers.values():
|
||||
for layer in layers:
|
||||
layer.eval()
|
||||
for param in layer.parameters():
|
||||
param.requires_grad = False
|
||||
|
||||
def save(self, filename):
|
||||
state_dict = {}
|
||||
optimizer_saved_dict = {}
|
||||
|
||||
for k, v in self.layers.items():
|
||||
state_dict[k] = (v[0].state_dict(), v[1].state_dict())
|
||||
|
||||
state_dict['step'] = self.step
|
||||
state_dict['name'] = self.name
|
||||
state_dict['layer_structure'] = self.layer_structure
|
||||
state_dict['activation_func'] = self.activation_func
|
||||
state_dict['is_layer_norm'] = self.add_layer_norm
|
||||
state_dict['weight_initialization'] = self.weight_init
|
||||
state_dict['sd_checkpoint'] = self.sd_checkpoint
|
||||
state_dict['sd_checkpoint_name'] = self.sd_checkpoint_name
|
||||
state_dict['activate_output'] = self.activate_output
|
||||
state_dict['use_dropout'] = self.use_dropout
|
||||
state_dict['dropout_structure'] = self.dropout_structure
|
||||
state_dict['last_layer_dropout'] = (self.dropout_structure[-2] != 0) if self.dropout_structure is not None else self.last_layer_dropout
|
||||
state_dict['optional_info'] = self.optional_info if self.optional_info else None
|
||||
|
||||
if self.optimizer_name is not None:
|
||||
optimizer_saved_dict['optimizer_name'] = self.optimizer_name
|
||||
|
||||
torch.save(state_dict, filename)
|
||||
if shared.opts.save_optimizer_state and self.optimizer_state_dict:
|
||||
optimizer_saved_dict['hash'] = self.shorthash()
|
||||
optimizer_saved_dict['optimizer_state_dict'] = self.optimizer_state_dict
|
||||
torch.save(optimizer_saved_dict, filename + '.optim')
|
||||
|
||||
def load(self, filename):
|
||||
self.filename = filename
|
||||
if self.name is None:
|
||||
self.name = os.path.splitext(os.path.basename(filename))[0]
|
||||
|
||||
state_dict = torch.load(filename, map_location='cpu')
|
||||
|
||||
self.layer_structure = state_dict.get('layer_structure', [1, 2, 1])
|
||||
self.optional_info = state_dict.get('optional_info', None)
|
||||
self.activation_func = state_dict.get('activation_func', None)
|
||||
self.weight_init = state_dict.get('weight_initialization', 'Normal')
|
||||
self.add_layer_norm = state_dict.get('is_layer_norm', False)
|
||||
self.dropout_structure = state_dict.get('dropout_structure', None)
|
||||
self.use_dropout = True if self.dropout_structure is not None and any(self.dropout_structure) else state_dict.get('use_dropout', False)
|
||||
self.activate_output = state_dict.get('activate_output', True)
|
||||
self.last_layer_dropout = state_dict.get('last_layer_dropout', False)
|
||||
# Dropout structure should have the same length as the layer structure; every value should be in [0, 1), and the last value must be 0.
|
||||
if self.dropout_structure is None:
|
||||
self.dropout_structure = parse_dropout_structure(self.layer_structure, self.use_dropout, self.last_layer_dropout)
|
||||
|
||||
if shared.opts.print_hypernet_extra:
|
||||
if self.optional_info is not None:
|
||||
print(f" INFO:\n {self.optional_info}\n")
|
||||
|
||||
print(f" Layer structure: {self.layer_structure}")
|
||||
print(f" Activation function: {self.activation_func}")
|
||||
print(f" Weight initialization: {self.weight_init}")
|
||||
print(f" Layer norm: {self.add_layer_norm}")
|
||||
print(f" Dropout usage: {self.use_dropout}" )
|
||||
print(f" Activate last layer: {self.activate_output}")
|
||||
print(f" Dropout structure: {self.dropout_structure}")
|
||||
|
||||
optimizer_saved_dict = torch.load(self.filename + '.optim', map_location='cpu') if os.path.exists(self.filename + '.optim') else {}
|
||||
|
||||
if self.shorthash() == optimizer_saved_dict.get('hash', None):
|
||||
self.optimizer_state_dict = optimizer_saved_dict.get('optimizer_state_dict', None)
|
||||
else:
|
||||
self.optimizer_state_dict = None
|
||||
if self.optimizer_state_dict:
|
||||
self.optimizer_name = optimizer_saved_dict.get('optimizer_name', 'AdamW')
|
||||
if shared.opts.print_hypernet_extra:
|
||||
print("Loaded existing optimizer from checkpoint")
|
||||
print(f"Optimizer name is {self.optimizer_name}")
|
||||
else:
|
||||
self.optimizer_name = "AdamW"
|
||||
if shared.opts.print_hypernet_extra:
|
||||
print("No saved optimizer exists in checkpoint")
|
||||
|
||||
for size, sd in state_dict.items():
|
||||
if type(size) == int:
|
||||
self.layers[size] = (
|
||||
HypernetworkModule(size, sd[0], self.layer_structure, self.activation_func, self.weight_init,
|
||||
self.add_layer_norm, self.activate_output, self.dropout_structure),
|
||||
HypernetworkModule(size, sd[1], self.layer_structure, self.activation_func, self.weight_init,
|
||||
self.add_layer_norm, self.activate_output, self.dropout_structure),
|
||||
)
|
||||
|
||||
self.name = state_dict.get('name', self.name)
|
||||
self.step = state_dict.get('step', 0)
|
||||
self.sd_checkpoint = state_dict.get('sd_checkpoint', None)
|
||||
self.sd_checkpoint_name = state_dict.get('sd_checkpoint_name', None)
|
||||
self.eval()
|
||||
|
||||
def shorthash(self):
|
||||
sha256 = hashes.sha256(self.filename, f'hypernet/{self.name}')
|
||||
|
||||
return sha256[0:10] if sha256 else None
|
||||
|
||||
|
||||
def list_hypernetworks(path):
|
||||
res = {}
|
||||
for filename in sorted(glob.iglob(os.path.join(path, '**/*.pt'), recursive=True), key=str.lower):
|
||||
name = os.path.splitext(os.path.basename(filename))[0]
|
||||
# Prevent a hypothetical "None.pt" from being listed.
|
||||
if name != "None":
|
||||
res[name] = filename
|
||||
return res
|
||||
|
||||
|
||||
def load_hypernetwork(name):
|
||||
path = shared.hypernetworks.get(name, None)
|
||||
|
||||
if path is None:
|
||||
return None
|
||||
|
||||
try:
|
||||
hypernetwork = Hypernetwork()
|
||||
hypernetwork.load(path)
|
||||
return hypernetwork
|
||||
except Exception:
|
||||
errors.report(f"Error loading hypernetwork {path}", exc_info=True)
|
||||
return None
|
||||
|
||||
|
||||
def load_hypernetworks(names, multipliers=None):
|
||||
already_loaded = {}
|
||||
|
||||
for hypernetwork in shared.loaded_hypernetworks:
|
||||
if hypernetwork.name in names:
|
||||
already_loaded[hypernetwork.name] = hypernetwork
|
||||
|
||||
shared.loaded_hypernetworks.clear()
|
||||
|
||||
for i, name in enumerate(names):
|
||||
hypernetwork = already_loaded.get(name, None)
|
||||
if hypernetwork is None:
|
||||
hypernetwork = load_hypernetwork(name)
|
||||
|
||||
if hypernetwork is None:
|
||||
continue
|
||||
|
||||
hypernetwork.set_multiplier(multipliers[i] if multipliers else 1.0)
|
||||
shared.loaded_hypernetworks.append(hypernetwork)
|
||||
|
||||
|
||||
def apply_single_hypernetwork(hypernetwork, context_k, context_v, layer=None):
|
||||
hypernetwork_layers = (hypernetwork.layers if hypernetwork is not None else {}).get(context_k.shape[2], None)
|
||||
|
||||
if hypernetwork_layers is None:
|
||||
return context_k, context_v
|
||||
|
||||
if layer is not None:
|
||||
layer.hyper_k = hypernetwork_layers[0]
|
||||
layer.hyper_v = hypernetwork_layers[1]
|
||||
|
||||
context_k = devices.cond_cast_unet(hypernetwork_layers[0](devices.cond_cast_float(context_k)))
|
||||
context_v = devices.cond_cast_unet(hypernetwork_layers[1](devices.cond_cast_float(context_v)))
|
||||
return context_k, context_v
|
||||
|
||||
|
||||
def apply_hypernetworks(hypernetworks, context, layer=None):
|
||||
context_k = context
|
||||
context_v = context
|
||||
for hypernetwork in hypernetworks:
|
||||
context_k, context_v = apply_single_hypernetwork(hypernetwork, context_k, context_v, layer)
|
||||
|
||||
return context_k, context_v
|
||||
|
||||
|
||||
def attention_CrossAttention_forward(self, x, context=None, mask=None, **kwargs):
|
||||
h = self.heads
|
||||
|
||||
q = self.to_q(x)
|
||||
context = default(context, x)
|
||||
|
||||
context_k, context_v = apply_hypernetworks(shared.loaded_hypernetworks, context, self)
|
||||
k = self.to_k(context_k)
|
||||
v = self.to_v(context_v)
|
||||
|
||||
q, k, v = (rearrange(t, 'b n (h d) -> (b h) n d', h=h) for t in (q, k, v))
|
||||
|
||||
sim = einsum('b i d, b j d -> b i j', q, k) * self.scale
|
||||
|
||||
if mask is not None:
|
||||
mask = rearrange(mask, 'b ... -> b (...)')
|
||||
max_neg_value = -torch.finfo(sim.dtype).max
|
||||
mask = repeat(mask, 'b j -> (b h) () j', h=h)
|
||||
sim.masked_fill_(~mask, max_neg_value)
|
||||
|
||||
# attention, what we cannot get enough of
|
||||
attn = sim.softmax(dim=-1)
|
||||
|
||||
out = einsum('b i j, b j d -> b i d', attn, v)
|
||||
out = rearrange(out, '(b h) n d -> b n (h d)', h=h)
|
||||
return self.to_out(out)
|
||||
|
||||
|
||||
def stack_conds(conds):
|
||||
if len(conds) == 1:
|
||||
return torch.stack(conds)
|
||||
|
||||
# same as in reconstruct_multicond_batch
|
||||
token_count = max([x.shape[0] for x in conds])
|
||||
for i in range(len(conds)):
|
||||
if conds[i].shape[0] != token_count:
|
||||
last_vector = conds[i][-1:]
|
||||
last_vector_repeated = last_vector.repeat([token_count - conds[i].shape[0], 1])
|
||||
conds[i] = torch.vstack([conds[i], last_vector_repeated])
|
||||
|
||||
return torch.stack(conds)
|
||||
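A minimal sketch of what stack_conds above does with conditionings of unequal token counts: the shorter tensor is padded by repeating its last row before stacking. The toy tensors are placeholders.

import torch

conds = [torch.ones(3, 4), torch.full((2, 4), 2.0)]
stacked = stack_conds([c.clone() for c in conds])

assert stacked.shape == (2, 3, 4)
assert torch.equal(stacked[1][2], stacked[1][1])  # last row was repeated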
|
||||
|
||||
def statistics(data):
|
||||
if len(data) < 2:
|
||||
std = 0
|
||||
else:
|
||||
std = stdev(data)
|
||||
total_information = f"loss:{mean(data):.3f}" + u"\u00B1" + f"({std / (len(data) ** 0.5):.3f})"
|
||||
recent_data = data[-32:]
|
||||
if len(recent_data) < 2:
|
||||
std = 0
|
||||
else:
|
||||
std = stdev(recent_data)
|
||||
recent_information = f"recent 32 loss:{mean(recent_data):.3f}" + u"\u00B1" + f"({std / (len(recent_data) ** 0.5):.3f})"
|
||||
return total_information, recent_information
|
||||
|
||||
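A minimal sketch of the output format of statistics above: mean loss plus or minus the standard error, once over the whole history and once over the last 32 entries.

total, recent = statistics([0.5, 0.4, 0.6, 0.5])
print(total)   # loss:0.500±(0.041)
print(recent)  # recent 32 loss:0.500±(0.041)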
#
|
||||
# def create_hypernetwork(name, enable_sizes, overwrite_old, layer_structure=None, activation_func=None, weight_init=None, add_layer_norm=False, use_dropout=False, dropout_structure=None):
|
||||
# # Remove illegal characters from name.
|
||||
# name = "".join( x for x in name if (x.isalnum() or x in "._- "))
|
||||
# assert name, "Name cannot be empty!"
|
||||
#
|
||||
# fn = os.path.join(shared.cmd_opts.hypernetwork_dir, f"{name}.pt")
|
||||
# if not overwrite_old:
|
||||
# assert not os.path.exists(fn), f"file {fn} already exists"
|
||||
#
|
||||
# if type(layer_structure) == str:
|
||||
# layer_structure = [float(x.strip()) for x in layer_structure.split(",")]
|
||||
#
|
||||
# if use_dropout and dropout_structure and type(dropout_structure) == str:
|
||||
# dropout_structure = [float(x.strip()) for x in dropout_structure.split(",")]
|
||||
# else:
|
||||
# dropout_structure = [0] * len(layer_structure)
|
||||
#
|
||||
# hypernet = modules.hypernetworks.hypernetwork.Hypernetwork(
|
||||
# name=name,
# enable_sizes=[int(x) for x in enable_sizes],
# layer_structure=layer_structure,
# activation_func=activation_func,
# weight_init=weight_init,
# add_layer_norm=add_layer_norm,
# use_dropout=use_dropout,
# dropout_structure=dropout_structure
# )
# hypernet.save(fn)
#
# shared.reload_hypernetworks()
#
#
# def train_hypernetwork(id_task, hypernetwork_name: str, learn_rate: float, batch_size: int, gradient_step: int, data_root: str, log_directory: str, training_width: int, training_height: int, varsize: bool, steps: int, clip_grad_mode: str, clip_grad_value: float, shuffle_tags: bool, tag_drop_out: bool, latent_sampling_method: str, use_weight: bool, create_image_every: int, save_hypernetwork_every: int, template_filename: str, preview_from_txt2img: bool, preview_prompt: str, preview_negative_prompt: str, preview_steps: int, preview_sampler_name: str, preview_cfg_scale: float, preview_seed: int, preview_width: int, preview_height: int):
# from modules import images, processing
#
# save_hypernetwork_every = save_hypernetwork_every or 0
# create_image_every = create_image_every or 0
# template_file = textual_inversion.textual_inversion_templates.get(template_filename, None)
# textual_inversion.validate_train_inputs(hypernetwork_name, learn_rate, batch_size, gradient_step, data_root, template_file, template_filename, steps, save_hypernetwork_every, create_image_every, log_directory, name="hypernetwork")
# template_file = template_file.path
#
# path = shared.hypernetworks.get(hypernetwork_name, None)
# hypernetwork = Hypernetwork()
# hypernetwork.load(path)
# shared.loaded_hypernetworks = [hypernetwork]
#
# shared.state.job = "train-hypernetwork"
# shared.state.textinfo = "Initializing hypernetwork training..."
# shared.state.job_count = steps
#
# hypernetwork_name = hypernetwork_name.rsplit('(', 1)[0]
# filename = os.path.join(shared.cmd_opts.hypernetwork_dir, f'{hypernetwork_name}.pt')
#
# log_directory = os.path.join(log_directory, datetime.datetime.now().strftime("%Y-%m-%d"), hypernetwork_name)
# unload = shared.opts.unload_models_when_training
#
# if save_hypernetwork_every > 0:
# hypernetwork_dir = os.path.join(log_directory, "hypernetworks")
# os.makedirs(hypernetwork_dir, exist_ok=True)
# else:
# hypernetwork_dir = None
#
# if create_image_every > 0:
# images_dir = os.path.join(log_directory, "images")
# os.makedirs(images_dir, exist_ok=True)
# else:
# images_dir = None
#
# checkpoint = sd_models.select_checkpoint()
#
# initial_step = hypernetwork.step or 0
# if initial_step >= steps:
# shared.state.textinfo = "Model has already been trained beyond specified max steps"
# return hypernetwork, filename
#
# scheduler = LearnRateScheduler(learn_rate, steps, initial_step)
#
# clip_grad = torch.nn.utils.clip_grad_value_ if clip_grad_mode == "value" else torch.nn.utils.clip_grad_norm_ if clip_grad_mode == "norm" else None
# if clip_grad:
# clip_grad_sched = LearnRateScheduler(clip_grad_value, steps, initial_step, verbose=False)
#
# if shared.opts.training_enable_tensorboard:
# tensorboard_writer = textual_inversion.tensorboard_setup(log_directory)
#
# # dataset loading may take a while, so input validations and early returns should be done before this
# shared.state.textinfo = f"Preparing dataset from {html.escape(data_root)}..."
#
# pin_memory = shared.opts.pin_memory
#
# ds = modules.textual_inversion.dataset.PersonalizedBase(data_root=data_root, width=training_width, height=training_height, repeats=shared.opts.training_image_repeats_per_epoch, placeholder_token=hypernetwork_name, model=shared.sd_model, cond_model=shared.sd_model.cond_stage_model, device=devices.device, template_file=template_file, include_cond=True, batch_size=batch_size, gradient_step=gradient_step, shuffle_tags=shuffle_tags, tag_drop_out=tag_drop_out, latent_sampling_method=latent_sampling_method, varsize=varsize, use_weight=use_weight)
#
# if shared.opts.save_training_settings_to_txt:
# saved_params = dict(
# model_name=checkpoint.model_name, model_hash=checkpoint.shorthash, num_of_dataset_images=len(ds),
# **{field: getattr(hypernetwork, field) for field in ['layer_structure', 'activation_func', 'weight_init', 'add_layer_norm', 'use_dropout', ]}
# )
# saving_settings.save_settings_to_file(log_directory, {**saved_params, **locals()})
#
# latent_sampling_method = ds.latent_sampling_method
#
# dl = modules.textual_inversion.dataset.PersonalizedDataLoader(ds, latent_sampling_method=latent_sampling_method, batch_size=ds.batch_size, pin_memory=pin_memory)
#
# old_parallel_processing_allowed = shared.parallel_processing_allowed
#
# if unload:
# shared.parallel_processing_allowed = False
# shared.sd_model.cond_stage_model.to(devices.cpu)
# shared.sd_model.first_stage_model.to(devices.cpu)
#
# weights = hypernetwork.weights()
# hypernetwork.train()
#
# # Here we use optimizer from saved HN, or we can specify as UI option.
# if hypernetwork.optimizer_name in optimizer_dict:
# optimizer = optimizer_dict[hypernetwork.optimizer_name](params=weights, lr=scheduler.learn_rate)
# optimizer_name = hypernetwork.optimizer_name
# else:
# print(f"Optimizer type {hypernetwork.optimizer_name} is not defined!")
# optimizer = torch.optim.AdamW(params=weights, lr=scheduler.learn_rate)
# optimizer_name = 'AdamW'
#
# if hypernetwork.optimizer_state_dict: # This line must be changed if Optimizer type can be different from saved optimizer.
# try:
# optimizer.load_state_dict(hypernetwork.optimizer_state_dict)
# except RuntimeError as e:
# print("Cannot resume from saved optimizer!")
# print(e)
#
# scaler = torch.cuda.amp.GradScaler()
#
# batch_size = ds.batch_size
# gradient_step = ds.gradient_step
# # n steps = batch_size * gradient_step * n image processed
# steps_per_epoch = len(ds) // batch_size // gradient_step
# max_steps_per_epoch = len(ds) // batch_size - (len(ds) // batch_size) % gradient_step
# loss_step = 0
# _loss_step = 0 #internal
# # size = len(ds.indexes)
# # loss_dict = defaultdict(lambda : deque(maxlen = 1024))
# loss_logging = deque(maxlen=len(ds) * 3) # this should be configurable parameter, this is 3 * epoch(dataset size)
# # losses = torch.zeros((size,))
# # previous_mean_losses = [0]
# # previous_mean_loss = 0
# # print("Mean loss of {} elements".format(size))
#
# steps_without_grad = 0
#
# last_saved_file = "<none>"
# last_saved_image = "<none>"
# forced_filename = "<none>"
#
# pbar = tqdm.tqdm(total=steps - initial_step)
# try:
# sd_hijack_checkpoint.add()
#
# for _ in range((steps-initial_step) * gradient_step):
# if scheduler.finished:
# break
# if shared.state.interrupted:
# break
# for j, batch in enumerate(dl):
# # works as a drop_last=True for gradient accumulation
# if j == max_steps_per_epoch:
# break
# scheduler.apply(optimizer, hypernetwork.step)
# if scheduler.finished:
# break
# if shared.state.interrupted:
# break
#
# if clip_grad:
# clip_grad_sched.step(hypernetwork.step)
#
# with devices.autocast():
# x = batch.latent_sample.to(devices.device, non_blocking=pin_memory)
# if use_weight:
# w = batch.weight.to(devices.device, non_blocking=pin_memory)
# if tag_drop_out != 0 or shuffle_tags:
# shared.sd_model.cond_stage_model.to(devices.device)
# c = shared.sd_model.cond_stage_model(batch.cond_text).to(devices.device, non_blocking=pin_memory)
# shared.sd_model.cond_stage_model.to(devices.cpu)
# else:
# c = stack_conds(batch.cond).to(devices.device, non_blocking=pin_memory)
# if use_weight:
# loss = shared.sd_model.weighted_forward(x, c, w)[0] / gradient_step
# del w
# else:
# loss = shared.sd_model.forward(x, c)[0] / gradient_step
# del x
# del c
#
# _loss_step += loss.item()
# scaler.scale(loss).backward()
#
# # go back until we reach gradient accumulation steps
# if (j + 1) % gradient_step != 0:
# continue
# loss_logging.append(_loss_step)
# if clip_grad:
# clip_grad(weights, clip_grad_sched.learn_rate)
#
# scaler.step(optimizer)
# scaler.update()
# hypernetwork.step += 1
# pbar.update()
# optimizer.zero_grad(set_to_none=True)
# loss_step = _loss_step
# _loss_step = 0
#
# steps_done = hypernetwork.step + 1
#
# epoch_num = hypernetwork.step // steps_per_epoch
# epoch_step = hypernetwork.step % steps_per_epoch
#
# description = f"Training hypernetwork [Epoch {epoch_num}: {epoch_step+1}/{steps_per_epoch}]loss: {loss_step:.7f}"
# pbar.set_description(description)
# if hypernetwork_dir is not None and steps_done % save_hypernetwork_every == 0:
# # Before saving, change name to match current checkpoint.
# hypernetwork_name_every = f'{hypernetwork_name}-{steps_done}'
# last_saved_file = os.path.join(hypernetwork_dir, f'{hypernetwork_name_every}.pt')
# hypernetwork.optimizer_name = optimizer_name
# if shared.opts.save_optimizer_state:
# hypernetwork.optimizer_state_dict = optimizer.state_dict()
# save_hypernetwork(hypernetwork, checkpoint, hypernetwork_name, last_saved_file)
# hypernetwork.optimizer_state_dict = None # dereference it after saving, to save memory.
#
#
#
# if shared.opts.training_enable_tensorboard:
# epoch_num = hypernetwork.step // len(ds)
# epoch_step = hypernetwork.step - (epoch_num * len(ds)) + 1
# mean_loss = sum(loss_logging) / len(loss_logging)
# textual_inversion.tensorboard_add(tensorboard_writer, loss=mean_loss, global_step=hypernetwork.step, step=epoch_step, learn_rate=scheduler.learn_rate, epoch_num=epoch_num)
#
# textual_inversion.write_loss(log_directory, "hypernetwork_loss.csv", hypernetwork.step, steps_per_epoch, {
# "loss": f"{loss_step:.7f}",
# "learn_rate": scheduler.learn_rate
# })
#
# if images_dir is not None and steps_done % create_image_every == 0:
# forced_filename = f'{hypernetwork_name}-{steps_done}'
# last_saved_image = os.path.join(images_dir, forced_filename)
# hypernetwork.eval()
# rng_state = torch.get_rng_state()
# cuda_rng_state = None
# if torch.cuda.is_available():
# cuda_rng_state = torch.cuda.get_rng_state_all()
# shared.sd_model.cond_stage_model.to(devices.device)
# shared.sd_model.first_stage_model.to(devices.device)
#
# p = processing.StableDiffusionProcessingTxt2Img(
# sd_model=shared.sd_model,
# do_not_save_grid=True,
# do_not_save_samples=True,
# )
#
# p.disable_extra_networks = True
#
# if preview_from_txt2img:
# p.prompt = preview_prompt
# p.negative_prompt = preview_negative_prompt
# p.steps = preview_steps
# p.sampler_name = sd_samplers.samplers_map[preview_sampler_name.lower()]
# p.cfg_scale = preview_cfg_scale
# p.seed = preview_seed
# p.width = preview_width
# p.height = preview_height
# else:
# p.prompt = batch.cond_text[0]
# p.steps = 20
# p.width = training_width
# p.height = training_height
#
# preview_text = p.prompt
#
# with closing(p):
# processed = processing.process_images(p)
# image = processed.images[0] if len(processed.images) > 0 else None
#
# if unload:
# shared.sd_model.cond_stage_model.to(devices.cpu)
# shared.sd_model.first_stage_model.to(devices.cpu)
# torch.set_rng_state(rng_state)
# if torch.cuda.is_available():
# torch.cuda.set_rng_state_all(cuda_rng_state)
# hypernetwork.train()
# if image is not None:
# shared.state.assign_current_image(image)
# if shared.opts.training_enable_tensorboard and shared.opts.training_tensorboard_save_images:
# textual_inversion.tensorboard_add_image(tensorboard_writer,
# f"Validation at epoch {epoch_num}", image,
# hypernetwork.step)
# last_saved_image, last_text_info = images.save_image(image, images_dir, "", p.seed, p.prompt, shared.opts.samples_format, processed.infotexts[0], p=p, forced_filename=forced_filename, save_to_dirs=False)
# last_saved_image += f", prompt: {preview_text}"
#
# shared.state.job_no = hypernetwork.step
#
# shared.state.textinfo = f"""
# <p>
# Loss: {loss_step:.7f}<br/>
# Step: {steps_done}<br/>
# Last prompt: {html.escape(batch.cond_text[0])}<br/>
# Last saved hypernetwork: {html.escape(last_saved_file)}<br/>
# Last saved image: {html.escape(last_saved_image)}<br/>
# </p>
# """
# except Exception:
# errors.report("Exception in training hypernetwork", exc_info=True)
# finally:
# pbar.leave = False
# pbar.close()
# hypernetwork.eval()
# sd_hijack_checkpoint.remove()
#
#
#
# filename = os.path.join(shared.cmd_opts.hypernetwork_dir, f'{hypernetwork_name}.pt')
# hypernetwork.optimizer_name = optimizer_name
# if shared.opts.save_optimizer_state:
# hypernetwork.optimizer_state_dict = optimizer.state_dict()
# save_hypernetwork(hypernetwork, checkpoint, hypernetwork_name, filename)
#
# del optimizer
# hypernetwork.optimizer_state_dict = None # dereference it after saving, to save memory.
# shared.sd_model.cond_stage_model.to(devices.device)
# shared.sd_model.first_stage_model.to(devices.device)
# shared.parallel_processing_allowed = old_parallel_processing_allowed
#
# return hypernetwork, filename
#
# def save_hypernetwork(hypernetwork, checkpoint, hypernetwork_name, filename):
# old_hypernetwork_name = hypernetwork.name
# old_sd_checkpoint = hypernetwork.sd_checkpoint if hasattr(hypernetwork, "sd_checkpoint") else None
# old_sd_checkpoint_name = hypernetwork.sd_checkpoint_name if hasattr(hypernetwork, "sd_checkpoint_name") else None
# try:
# hypernetwork.sd_checkpoint = checkpoint.shorthash
# hypernetwork.sd_checkpoint_name = checkpoint.model_name
# hypernetwork.name = hypernetwork_name
# hypernetwork.save(filename)
# except:
# hypernetwork.sd_checkpoint = old_sd_checkpoint
# hypernetwork.sd_checkpoint_name = old_sd_checkpoint_name
# hypernetwork.name = old_hypernetwork_name
# raise

38
modules/hypernetworks/ui.py
Executable file
@@ -0,0 +1,38 @@
import html

import gradio as gr
import modules.hypernetworks.hypernetwork
from modules import devices, sd_hijack, shared

not_available = ["hardswish", "multiheadattention"]
keys = [x for x in modules.hypernetworks.hypernetwork.HypernetworkModule.activation_dict if x not in not_available]


def create_hypernetwork(name, enable_sizes, overwrite_old, layer_structure=None, activation_func=None, weight_init=None, add_layer_norm=False, use_dropout=False, dropout_structure=None):
    filename = modules.hypernetworks.hypernetwork.create_hypernetwork(name, enable_sizes, overwrite_old, layer_structure, activation_func, weight_init, add_layer_norm, use_dropout, dropout_structure)

    return gr.Dropdown.update(choices=sorted(shared.hypernetworks)), f"Created: {filename}", ""


def train_hypernetwork(*args):
    shared.loaded_hypernetworks = []

    assert not shared.cmd_opts.lowvram, 'Training models with lowvram is not possible'

    try:
        sd_hijack.undo_optimizations()

        hypernetwork, filename = modules.hypernetworks.hypernetwork.train_hypernetwork(*args)

        res = f"""
Training {'interrupted' if shared.state.interrupted else 'finished'} at {hypernetwork.step} steps.
Hypernetwork saved to {html.escape(filename)}
"""
        return res, ""
    except Exception:
        raise
    finally:
        shared.sd_model.cond_stage_model.to(devices.device)
        shared.sd_model.first_stage_model.to(devices.device)
        sd_hijack.apply_optimizations()

898
modules/images.py
Executable file
@@ -0,0 +1,898 @@
|
||||
from __future__ import annotations
|
||||
|
||||
import datetime
|
||||
import functools
|
||||
import pytz
|
||||
import io
|
||||
import math
|
||||
import os
|
||||
from collections import namedtuple
|
||||
import re
|
||||
|
||||
import numpy as np
|
||||
import piexif
|
||||
import piexif.helper
|
||||
from PIL import Image, ImageFont, ImageDraw, ImageColor, PngImagePlugin, ImageOps
|
||||
# pillow_avif needs to be imported somewhere in code for it to work
|
||||
import pillow_avif # noqa: F401
|
||||
import string
|
||||
import json
|
||||
import hashlib
|
||||
|
||||
from modules import sd_samplers, shared, script_callbacks, errors, stealth_infotext
|
||||
from modules.paths_internal import roboto_ttf_file
|
||||
from modules.shared import opts
|
||||
|
||||
LANCZOS = (Image.Resampling.LANCZOS if hasattr(Image, 'Resampling') else Image.LANCZOS)
|
||||
|
||||
|
||||
def get_font(fontsize: int):
|
||||
try:
|
||||
return ImageFont.truetype(opts.font or roboto_ttf_file, fontsize)
|
||||
except Exception:
|
||||
return ImageFont.truetype(roboto_ttf_file, fontsize)
|
||||
|
||||
|
||||
def image_grid(imgs, batch_size=1, rows=None):
|
||||
if rows is None:
|
||||
if opts.n_rows > 0:
|
||||
rows = opts.n_rows
|
||||
elif opts.n_rows == 0:
|
||||
rows = batch_size
|
||||
elif opts.grid_prevent_empty_spots:
|
||||
rows = math.floor(math.sqrt(len(imgs)))
|
||||
while len(imgs) % rows != 0:
|
||||
rows -= 1
|
||||
else:
|
||||
rows = math.sqrt(len(imgs))
|
||||
rows = round(rows)
|
||||
if rows > len(imgs):
|
||||
rows = len(imgs)
|
||||
|
||||
cols = math.ceil(len(imgs) / rows)
|
||||
|
||||
params = script_callbacks.ImageGridLoopParams(imgs, cols, rows)
|
||||
script_callbacks.image_grid_callback(params)
|
||||
|
||||
w, h = map(max, zip(*(img.size for img in imgs)))
|
||||
grid_background_color = ImageColor.getcolor(opts.grid_background_color, 'RGBA')
|
||||
grid = Image.new('RGBA', size=(params.cols * w, params.rows * h), color=grid_background_color)
|
||||
|
||||
for i, img in enumerate(params.imgs):
|
||||
img_w, img_h = img.size
|
||||
w_offset, h_offset = 0 if img_w == w else (w - img_w) // 2, 0 if img_h == h else (h - img_h) // 2
|
||||
grid.paste(img, box=(i % params.cols * w + w_offset, i // params.cols * h + h_offset))
|
||||
|
||||
return grid
|
||||
|
||||
|
||||
class Grid(namedtuple("_Grid", ["tiles", "tile_w", "tile_h", "image_w", "image_h", "overlap"])):
|
||||
@property
|
||||
def tile_count(self) -> int:
|
||||
"""
|
||||
The total number of tiles in the grid.
|
||||
"""
|
||||
return sum(len(row[2]) for row in self.tiles)
|
||||
|
||||
|
||||
def split_grid(image: Image.Image, tile_w: int = 512, tile_h: int = 512, overlap: int = 64) -> Grid:
|
||||
w, h = image.size
|
||||
|
||||
non_overlap_width = tile_w - overlap
|
||||
non_overlap_height = tile_h - overlap
|
||||
|
||||
cols = math.ceil((w - overlap) / non_overlap_width)
|
||||
rows = math.ceil((h - overlap) / non_overlap_height)
|
||||
|
||||
dx = (w - tile_w) / (cols - 1) if cols > 1 else 0
|
||||
dy = (h - tile_h) / (rows - 1) if rows > 1 else 0
|
||||
|
||||
grid = Grid([], tile_w, tile_h, w, h, overlap)
|
||||
for row in range(rows):
|
||||
row_images = []
|
||||
|
||||
y = int(row * dy)
|
||||
|
||||
if y + tile_h >= h:
|
||||
y = h - tile_h
|
||||
|
||||
for col in range(cols):
|
||||
x = int(col * dx)
|
||||
|
||||
if x + tile_w >= w:
|
||||
x = w - tile_w
|
||||
|
||||
tile = image.crop((x, y, x + tile_w, y + tile_h))
|
||||
|
||||
row_images.append([x, tile_w, tile])
|
||||
|
||||
grid.tiles.append([y, tile_h, row_images])
|
||||
|
||||
return grid
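# A worked example of the tiling arithmetic above (illustrative values, not part
# of the module): for a 1024x768 image with the default tile_w=tile_h=512 and
# overlap=64, non_overlap_width = 448, so cols = ceil((1024-64)/448) = 3 and
# rows = ceil((768-64)/448) = 2; the tiles are then spread with
# dx = (1024-512)/2 = 256 and dy = (768-512)/1 = 256, so neighbouring tiles overlap.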
|
||||
|
||||
|
||||
def combine_grid(grid):
|
||||
def make_mask_image(r):
|
||||
r = r * 255 / grid.overlap
|
||||
r = r.astype(np.uint8)
|
||||
return Image.fromarray(r, 'L')
|
||||
|
||||
mask_w = make_mask_image(np.arange(grid.overlap, dtype=np.float32).reshape((1, grid.overlap)).repeat(grid.tile_h, axis=0))
|
||||
mask_h = make_mask_image(np.arange(grid.overlap, dtype=np.float32).reshape((grid.overlap, 1)).repeat(grid.image_w, axis=1))
|
||||
|
||||
combined_image = Image.new("RGB", (grid.image_w, grid.image_h))
|
||||
for y, h, row in grid.tiles:
|
||||
combined_row = Image.new("RGB", (grid.image_w, h))
|
||||
for x, w, tile in row:
|
||||
if x == 0:
|
||||
combined_row.paste(tile, (0, 0))
|
||||
continue
|
||||
|
||||
combined_row.paste(tile.crop((0, 0, grid.overlap, h)), (x, 0), mask=mask_w)
|
||||
combined_row.paste(tile.crop((grid.overlap, 0, w, h)), (x + grid.overlap, 0))
|
||||
|
||||
if y == 0:
|
||||
combined_image.paste(combined_row, (0, 0))
|
||||
continue
|
||||
|
||||
combined_image.paste(combined_row.crop((0, 0, combined_row.width, grid.overlap)), (0, y), mask=mask_h)
|
||||
combined_image.paste(combined_row.crop((0, grid.overlap, combined_row.width, h)), (0, y + grid.overlap))
|
||||
|
||||
return combined_image
|
||||
|
||||
|
||||
class GridAnnotation:
|
||||
def __init__(self, text='', is_active=True):
|
||||
self.text = text
|
||||
self.is_active = is_active
|
||||
self.size = None
|
||||
|
||||
|
||||
def draw_grid_annotations(im, width, height, hor_texts, ver_texts, margin=0):
|
||||
|
||||
color_active = ImageColor.getcolor(opts.grid_text_active_color, 'RGB')
|
||||
color_inactive = ImageColor.getcolor(opts.grid_text_inactive_color, 'RGB')
|
||||
color_background = ImageColor.getcolor(opts.grid_background_color, 'RGB')
|
||||
|
||||
def wrap(drawing, text, font, line_length):
|
||||
lines = ['']
|
||||
for word in text.split():
|
||||
line = f'{lines[-1]} {word}'.strip()
|
||||
if drawing.textlength(line, font=font) <= line_length:
|
||||
lines[-1] = line
|
||||
else:
|
||||
lines.append(word)
|
||||
return lines
|
||||
|
||||
def draw_texts(drawing, draw_x, draw_y, lines, initial_fnt, initial_fontsize):
|
||||
for line in lines:
|
||||
fnt = initial_fnt
|
||||
fontsize = initial_fontsize
|
||||
while drawing.multiline_textsize(line.text, font=fnt)[0] > line.allowed_width and fontsize > 0:
|
||||
fontsize -= 1
|
||||
fnt = get_font(fontsize)
|
||||
drawing.multiline_text((draw_x, draw_y + line.size[1] / 2), line.text, font=fnt, fill=color_active if line.is_active else color_inactive, anchor="mm", align="center")
|
||||
|
||||
if not line.is_active:
|
||||
drawing.line((draw_x - line.size[0] // 2, draw_y + line.size[1] // 2, draw_x + line.size[0] // 2, draw_y + line.size[1] // 2), fill=color_inactive, width=4)
|
||||
|
||||
draw_y += line.size[1] + line_spacing
|
||||
|
||||
fontsize = (width + height) // 25
|
||||
line_spacing = fontsize // 2
|
||||
|
||||
fnt = get_font(fontsize)
|
||||
|
||||
pad_left = 0 if sum([sum([len(line.text) for line in lines]) for lines in ver_texts]) == 0 else width * 3 // 4
|
||||
|
||||
cols = im.width // width
|
||||
rows = im.height // height
|
||||
|
||||
assert cols == len(hor_texts), f'bad number of horizontal texts: {len(hor_texts)}; must be {cols}'
|
||||
assert rows == len(ver_texts), f'bad number of vertical texts: {len(ver_texts)}; must be {rows}'
|
||||
|
||||
calc_img = Image.new("RGB", (1, 1), color_background)
|
||||
calc_d = ImageDraw.Draw(calc_img)
|
||||
|
||||
for texts, allowed_width in zip(hor_texts + ver_texts, [width] * len(hor_texts) + [pad_left] * len(ver_texts)):
|
||||
items = [] + texts
|
||||
texts.clear()
|
||||
|
||||
for line in items:
|
||||
wrapped = wrap(calc_d, line.text, fnt, allowed_width)
|
||||
texts += [GridAnnotation(x, line.is_active) for x in wrapped]
|
||||
|
||||
for line in texts:
|
||||
bbox = calc_d.multiline_textbbox((0, 0), line.text, font=fnt)
|
||||
line.size = (bbox[2] - bbox[0], bbox[3] - bbox[1])
|
||||
line.allowed_width = allowed_width
|
||||
|
||||
hor_text_heights = [sum([line.size[1] + line_spacing for line in lines]) - line_spacing for lines in hor_texts]
|
||||
ver_text_heights = [sum([line.size[1] + line_spacing for line in lines]) - line_spacing * len(lines) for lines in ver_texts]
|
||||
|
||||
pad_top = 0 if sum(hor_text_heights) == 0 else max(hor_text_heights) + line_spacing * 2
|
||||
|
||||
result = Image.new("RGB", (im.width + pad_left + margin * (cols-1), im.height + pad_top + margin * (rows-1)), color_background)
|
||||
|
||||
for row in range(rows):
|
||||
for col in range(cols):
|
||||
cell = im.crop((width * col, height * row, width * (col+1), height * (row+1)))
|
||||
result.paste(cell, (pad_left + (width + margin) * col, pad_top + (height + margin) * row))
|
||||
|
||||
d = ImageDraw.Draw(result)
|
||||
|
||||
for col in range(cols):
|
||||
x = pad_left + (width + margin) * col + width / 2
|
||||
y = pad_top / 2 - hor_text_heights[col] / 2
|
||||
|
||||
draw_texts(d, x, y, hor_texts[col], fnt, fontsize)
|
||||
|
||||
for row in range(rows):
|
||||
x = pad_left / 2
|
||||
y = pad_top + (height + margin) * row + height / 2 - ver_text_heights[row] / 2
|
||||
|
||||
draw_texts(d, x, y, ver_texts[row], fnt, fontsize)
|
||||
|
||||
return result
|
||||
|
||||
|
||||
def draw_prompt_matrix(im, width, height, all_prompts, margin=0):
|
||||
prompts = all_prompts[1:]
|
||||
boundary = math.ceil(len(prompts) / 2)
|
||||
|
||||
prompts_horiz = prompts[:boundary]
|
||||
prompts_vert = prompts[boundary:]
|
||||
|
||||
hor_texts = [[GridAnnotation(x, is_active=pos & (1 << i) != 0) for i, x in enumerate(prompts_horiz)] for pos in range(1 << len(prompts_horiz))]
|
||||
ver_texts = [[GridAnnotation(x, is_active=pos & (1 << i) != 0) for i, x in enumerate(prompts_vert)] for pos in range(1 << len(prompts_vert))]
|
||||
|
||||
return draw_grid_annotations(im, width, height, hor_texts, ver_texts, margin)
|
||||
|
||||
|
||||
def resize_image(resize_mode, im, width, height, upscaler_name=None, force_RGBA=False):
|
||||
"""
|
||||
Resizes an image with the specified resize_mode, width, and height.
|
||||
|
||||
Args:
|
||||
resize_mode: The mode to use when resizing the image.
|
||||
0: Resize the image to the specified width and height.
|
||||
1: Resize the image to fill the specified width and height, maintaining the aspect ratio, and then center the image within the dimensions, cropping the excess.
|
||||
2: Resize the image to fit within the specified width and height, maintaining the aspect ratio, and then center the image within the dimensions, filling the empty space with data from the image.
|
||||
im: The image to resize.
|
||||
width: The width to resize the image to.
|
||||
height: The height to resize the image to.
|
||||
upscaler_name: The name of the upscaler to use. If not provided, defaults to opts.upscaler_for_img2img.
|
||||
"""
|
||||
|
||||
if not force_RGBA and im.mode == 'RGBA':
|
||||
im = im.convert('RGB')
|
||||
|
||||
upscaler_name = upscaler_name or opts.upscaler_for_img2img
|
||||
|
||||
def resize(im, w, h):
|
||||
if upscaler_name is None or upscaler_name == "None" or im.mode == 'L' or force_RGBA:
|
||||
return im.resize((w, h), resample=LANCZOS)
|
||||
|
||||
scale = max(w / im.width, h / im.height)
|
||||
|
||||
if scale > 1.0:
|
||||
upscalers = [x for x in shared.sd_upscalers if x.name == upscaler_name]
|
||||
if len(upscalers) == 0:
|
||||
upscaler = shared.sd_upscalers[0]
|
||||
print(f"could not find upscaler named {upscaler_name or '<empty string>'}, using {upscaler.name} as a fallback")
|
||||
else:
|
||||
upscaler = upscalers[0]
|
||||
|
||||
im = upscaler.scaler.upscale(im, scale, upscaler.data_path)
|
||||
|
||||
if im.width != w or im.height != h:
|
||||
im = im.resize((w, h), resample=LANCZOS)
|
||||
|
||||
return im
|
||||
|
||||
if resize_mode == 0:
|
||||
res = resize(im, width, height)
|
||||
|
||||
elif resize_mode == 1:
|
||||
ratio = width / height
|
||||
src_ratio = im.width / im.height
|
||||
|
||||
src_w = width if ratio > src_ratio else im.width * height // im.height
|
||||
src_h = height if ratio <= src_ratio else im.height * width // im.width
|
||||
|
||||
resized = resize(im, src_w, src_h)
|
||||
res = Image.new("RGB" if not force_RGBA else "RGBA", (width, height))
|
||||
res.paste(resized, box=(width // 2 - src_w // 2, height // 2 - src_h // 2))
|
||||
|
||||
else:
|
||||
ratio = width / height
|
||||
src_ratio = im.width / im.height
|
||||
|
||||
src_w = width if ratio < src_ratio else im.width * height // im.height
|
||||
src_h = height if ratio >= src_ratio else im.height * width // im.width
|
||||
|
||||
resized = resize(im, src_w, src_h)
|
||||
res = Image.new("RGB" if not force_RGBA else "RGBA", (width, height))
|
||||
res.paste(resized, box=(width // 2 - src_w // 2, height // 2 - src_h // 2))
|
||||
|
||||
if ratio < src_ratio:
|
||||
fill_height = height // 2 - src_h // 2
|
||||
if fill_height > 0:
|
||||
res.paste(resized.resize((width, fill_height), box=(0, 0, width, 0)), box=(0, 0))
|
||||
res.paste(resized.resize((width, fill_height), box=(0, resized.height, width, resized.height)), box=(0, fill_height + src_h))
|
||||
elif ratio > src_ratio:
|
||||
fill_width = width // 2 - src_w // 2
|
||||
if fill_width > 0:
|
||||
res.paste(resized.resize((fill_width, height), box=(0, 0, 0, height)), box=(0, 0))
|
||||
res.paste(resized.resize((fill_width, height), box=(resized.width, 0, resized.width, height)), box=(fill_width + src_w, 0))
|
||||
|
||||
return res
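# Usage sketch (hypothetical image and sizes, not part of the module):
#   cropped = resize_image(1, img, 512, 512)   # fill the target box, center-crop the excess
#   padded  = resize_image(2, img, 512, 768)   # fit inside the box, pad borders from the image
# resize_mode=0 simply stretches to the exact width and height.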
|
||||
|
||||
|
||||
if not shared.cmd_opts.unix_filenames_sanitization:
|
||||
invalid_filename_chars = '#<>:"/\\|?*\n\r\t'
|
||||
else:
|
||||
invalid_filename_chars = '/'
|
||||
invalid_filename_prefix = ' '
|
||||
invalid_filename_postfix = ' .'
|
||||
re_nonletters = re.compile(r'[\s' + string.punctuation + ']+')
|
||||
re_pattern = re.compile(r"(.*?)(?:\[([^\[\]]+)\]|$)")
|
||||
re_pattern_arg = re.compile(r"(.*)<([^>]*)>$")
|
||||
max_filename_part_length = shared.cmd_opts.filenames_max_length
|
||||
NOTHING_AND_SKIP_PREVIOUS_TEXT = object()
|
||||
|
||||
|
||||
def sanitize_filename_part(text, replace_spaces=True):
|
||||
if text is None:
|
||||
return None
|
||||
|
||||
if replace_spaces:
|
||||
text = text.replace(' ', '_')
|
||||
|
||||
text = text.translate({ord(x): '_' for x in invalid_filename_chars})
|
||||
text = text.lstrip(invalid_filename_prefix)[:max_filename_part_length]
|
||||
text = text.rstrip(invalid_filename_postfix)
|
||||
return text
|
||||
|
||||
|
||||
@functools.cache
|
||||
def get_scheduler_str(sampler_name, scheduler_name):
|
||||
"""Returns {Scheduler} if the scheduler is applicable to the sampler"""
|
||||
if scheduler_name == 'Automatic':
|
||||
config = sd_samplers.find_sampler_config(sampler_name)
|
||||
scheduler_name = config.options.get('scheduler', 'Automatic')
|
||||
return scheduler_name.capitalize()
|
||||
|
||||
|
||||
@functools.cache
|
||||
def get_sampler_scheduler_str(sampler_name, scheduler_name):
|
||||
"""Returns the '{Sampler} {Scheduler}' if the scheduler is applicable to the sampler"""
|
||||
return f'{sampler_name} {get_scheduler_str(sampler_name, scheduler_name)}'
|
||||
|
||||
|
||||
def get_sampler_scheduler(p, sampler):
|
||||
"""Returns '{Sampler} {Scheduler}' / '{Scheduler}' / 'NOTHING_AND_SKIP_PREVIOUS_TEXT'"""
|
||||
if hasattr(p, 'scheduler') and hasattr(p, 'sampler_name'):
|
||||
if sampler:
|
||||
sampler_scheduler = get_sampler_scheduler_str(p.sampler_name, p.scheduler)
|
||||
else:
|
||||
sampler_scheduler = get_scheduler_str(p.sampler_name, p.scheduler)
|
||||
return sanitize_filename_part(sampler_scheduler, replace_spaces=False)
|
||||
return NOTHING_AND_SKIP_PREVIOUS_TEXT
|
||||
|
||||
|
||||
class FilenameGenerator:
|
||||
replacements = {
|
||||
'basename': lambda self: self.basename or 'img',
|
||||
'seed': lambda self: self.seed if self.seed is not None else '',
|
||||
'seed_first': lambda self: self.seed if self.p.batch_size == 1 else self.p.all_seeds[0],
|
||||
'seed_last': lambda self: NOTHING_AND_SKIP_PREVIOUS_TEXT if self.p.batch_size == 1 else self.p.all_seeds[-1],
|
||||
'steps': lambda self: self.p and self.p.steps,
|
||||
'cfg': lambda self: self.p and self.p.cfg_scale,
|
||||
'width': lambda self: self.image.width,
|
||||
'height': lambda self: self.image.height,
|
||||
'styles': lambda self: self.p and sanitize_filename_part(", ".join([style for style in self.p.styles if not style == "None"]) or "None", replace_spaces=False),
|
||||
'sampler': lambda self: self.p and sanitize_filename_part(self.p.sampler_name, replace_spaces=False),
|
||||
'sampler_scheduler': lambda self: self.p and get_sampler_scheduler(self.p, True),
|
||||
'scheduler': lambda self: self.p and get_sampler_scheduler(self.p, False),
|
||||
'model_hash': lambda self: getattr(self.p, "sd_model_hash", shared.sd_model.sd_model_hash),
|
||||
'model_name': lambda self: sanitize_filename_part(shared.sd_model.sd_checkpoint_info.name_for_extra, replace_spaces=False),
|
||||
'date': lambda self: datetime.datetime.now().strftime('%Y-%m-%d'),
|
||||
'datetime': lambda self, *args: self.datetime(*args), # accepts formats: [datetime], [datetime<Format>], [datetime<Format><Time Zone>]
|
||||
'job_timestamp': lambda self: getattr(self.p, "job_timestamp", shared.state.job_timestamp),
|
||||
'prompt_hash': lambda self, *args: self.string_hash(self.prompt, *args),
|
||||
'negative_prompt_hash': lambda self, *args: self.string_hash(self.p.negative_prompt, *args),
|
||||
'full_prompt_hash': lambda self, *args: self.string_hash(f"{self.p.prompt} {self.p.negative_prompt}", *args), # a space in between to create a unique string
|
||||
'prompt': lambda self: sanitize_filename_part(self.prompt),
|
||||
'prompt_no_styles': lambda self: self.prompt_no_style(),
|
||||
'prompt_spaces': lambda self: sanitize_filename_part(self.prompt, replace_spaces=False),
|
||||
'prompt_words': lambda self: self.prompt_words(),
|
||||
'batch_number': lambda self: NOTHING_AND_SKIP_PREVIOUS_TEXT if self.p.batch_size == 1 or self.zip else self.p.batch_index + 1,
|
||||
'batch_size': lambda self: self.p.batch_size,
|
||||
'generation_number': lambda self: NOTHING_AND_SKIP_PREVIOUS_TEXT if (self.p.n_iter == 1 and self.p.batch_size == 1) or self.zip else self.p.iteration * self.p.batch_size + self.p.batch_index + 1,
|
||||
'hasprompt': lambda self, *args: self.hasprompt(*args), # accepts formats:[hasprompt<prompt1|default><prompt2>..]
|
||||
'clip_skip': lambda self: opts.data["CLIP_stop_at_last_layers"],
|
||||
'denoising': lambda self: self.p.denoising_strength if self.p and self.p.denoising_strength else NOTHING_AND_SKIP_PREVIOUS_TEXT,
|
||||
'user': lambda self: self.p.user,
|
||||
'vae_filename': lambda self: self.get_vae_filename(),
|
||||
'none': lambda self: '', # Overrides the default, so you can get just the sequence number
|
||||
'image_hash': lambda self, *args: self.image_hash(*args) # accepts formats: [image_hash<length>] default full hash
|
||||
}
|
||||
default_time_format = '%Y%m%d%H%M%S'
|
||||
|
||||
def __init__(self, p, seed, prompt, image, zip=False, basename=""):
|
||||
self.p = p
|
||||
self.seed = seed
|
||||
self.prompt = prompt
|
||||
self.image = image
|
||||
self.zip = zip
|
||||
self.basename = basename
|
||||
|
||||
def get_vae_filename(self):
|
||||
"""Get the name of the VAE file."""
|
||||
|
||||
import modules.sd_vae as sd_vae
|
||||
|
||||
if sd_vae.loaded_vae_file is None:
|
||||
return "NoneType"
|
||||
|
||||
file_name = os.path.basename(sd_vae.loaded_vae_file)
|
||||
split_file_name = file_name.split('.')
|
||||
if len(split_file_name) > 1 and split_file_name[0] == '':
|
||||
return split_file_name[1] # if the filename starts with ".", use the part after the leading dot
|
||||
else:
|
||||
return split_file_name[0]
|
||||
|
||||
|
||||
def hasprompt(self, *args):
|
||||
if self.p is None or self.prompt is None:
return None
lower = self.prompt.lower()
|
||||
outres = ""
|
||||
for arg in args:
|
||||
if arg != "":
|
||||
division = arg.split("|")
|
||||
expected = division[0].lower()
|
||||
default = division[1] if len(division) > 1 else ""
|
||||
if lower.find(expected) >= 0:
|
||||
outres = f'{outres}{expected}'
|
||||
else:
|
||||
outres = outres if default == "" else f'{outres}{default}'
|
||||
return sanitize_filename_part(outres)
|
||||
|
||||
def prompt_no_style(self):
|
||||
if self.p is None or self.prompt is None:
|
||||
return None
|
||||
|
||||
prompt_no_style = self.prompt
|
||||
for style in shared.prompt_styles.get_style_prompts(self.p.styles):
|
||||
if style:
|
||||
for part in style.split("{prompt}"):
|
||||
prompt_no_style = prompt_no_style.replace(part, "").replace(", ,", ",").strip().strip(',')
|
||||
|
||||
prompt_no_style = prompt_no_style.replace(style, "").strip().strip(',').strip()
|
||||
|
||||
return sanitize_filename_part(prompt_no_style, replace_spaces=False)
|
||||
|
||||
def prompt_words(self):
|
||||
words = [x for x in re_nonletters.split(self.prompt or "") if x]
|
||||
if len(words) == 0:
|
||||
words = ["empty"]
|
||||
return sanitize_filename_part(" ".join(words[0:opts.directories_max_prompt_words]), replace_spaces=False)
|
||||
|
||||
def datetime(self, *args):
|
||||
time_datetime = datetime.datetime.now()
|
||||
|
||||
time_format = args[0] if (args and args[0] != "") else self.default_time_format
|
||||
try:
|
||||
time_zone = pytz.timezone(args[1]) if len(args) > 1 else None
|
||||
except pytz.exceptions.UnknownTimeZoneError:
|
||||
time_zone = None
|
||||
|
||||
time_zone_time = time_datetime.astimezone(time_zone)
|
||||
try:
|
||||
formatted_time = time_zone_time.strftime(time_format)
|
||||
except (ValueError, TypeError):
|
||||
formatted_time = time_zone_time.strftime(self.default_time_format)
|
||||
|
||||
return sanitize_filename_part(formatted_time, replace_spaces=False)
|
||||
|
||||
def image_hash(self, *args):
|
||||
length = int(args[0]) if (args and args[0] != "") else None
|
||||
return hashlib.sha256(self.image.tobytes()).hexdigest()[0:length]
|
||||
|
||||
def string_hash(self, text, *args):
|
||||
length = int(args[0]) if (args and args[0] != "") else 8
|
||||
return hashlib.sha256(text.encode()).hexdigest()[0:length]
|
||||
|
||||
def apply(self, x):
|
||||
res = ''
|
||||
|
||||
for m in re_pattern.finditer(x):
|
||||
text, pattern = m.groups()
|
||||
|
||||
if pattern is None:
|
||||
res += text
|
||||
continue
|
||||
|
||||
pattern_args = []
|
||||
while True:
|
||||
m = re_pattern_arg.match(pattern)
|
||||
if m is None:
|
||||
break
|
||||
|
||||
pattern, arg = m.groups()
|
||||
pattern_args.insert(0, arg)
|
||||
|
||||
fun = self.replacements.get(pattern.lower())
|
||||
if fun is not None:
|
||||
try:
|
||||
replacement = fun(self, *pattern_args)
|
||||
except Exception:
|
||||
replacement = None
|
||||
errors.report(f"Error adding [{pattern}] to filename", exc_info=True)
|
||||
|
||||
if replacement == NOTHING_AND_SKIP_PREVIOUS_TEXT:
|
||||
continue
|
||||
elif replacement is not None:
|
||||
res += text + str(replacement)
|
||||
continue
|
||||
|
||||
res += f'{text}[{pattern}]'
|
||||
|
||||
return res
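# Pattern sketch (hypothetical template, not part of the module): apply() expands a
# template such as "[model_name]-[seed]-[prompt_words]" token by token using the
# replacements table above; unknown tokens are kept literally as "[token]", and a
# token that returns NOTHING_AND_SKIP_PREVIOUS_TEXT also drops the literal text
# immediately before it (e.g. a "-" preceding [batch_number] for single-image batches).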
|
||||
|
||||
|
||||
def get_next_sequence_number(path, basename):
|
||||
"""
|
||||
Determines and returns the next sequence number to use when saving an image in the specified directory.
|
||||
|
||||
The sequence starts at 0.
|
||||
"""
|
||||
result = -1
|
||||
if basename != '':
|
||||
basename = f"{basename}-"
|
||||
|
||||
prefix_length = len(basename)
|
||||
for p in os.listdir(path):
|
||||
if p.startswith(basename):
|
||||
parts = os.path.splitext(p[prefix_length:])[0].split('-') # splits the filename (removing the basename first if one is defined, so the sequence number is always the first element)
|
||||
try:
|
||||
result = max(int(parts[0]), result)
|
||||
except ValueError:
|
||||
pass
|
||||
|
||||
return result + 1
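# Example (hypothetical directory contents): with "00007.png" and "00012-foo.png"
# already present and basename == "", the highest leading number found is 12,
# so the function returns 13 as the next sequence number.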
|
||||
|
||||
|
||||
def save_image_with_geninfo(image, geninfo, filename, extension=None, existing_pnginfo=None, pnginfo_section_name='parameters'):
|
||||
"""
|
||||
Saves image to filename, including geninfo as text information for generation info.
|
||||
For PNG images, geninfo is added to existing pnginfo dictionary using the pnginfo_section_name argument as key.
|
||||
For JPG images, there's no dictionary and geninfo just replaces the EXIF description.
|
||||
"""
|
||||
|
||||
if extension is None:
|
||||
extension = os.path.splitext(filename)[1]
|
||||
|
||||
image_format = Image.registered_extensions()[extension]
|
||||
|
||||
if extension.lower() == '.png':
|
||||
existing_pnginfo = existing_pnginfo or {}
|
||||
if opts.enable_pnginfo:
|
||||
existing_pnginfo[pnginfo_section_name] = geninfo
|
||||
|
||||
if opts.enable_pnginfo:
|
||||
pnginfo_data = PngImagePlugin.PngInfo()
|
||||
for k, v in (existing_pnginfo or {}).items():
|
||||
pnginfo_data.add_text(k, str(v))
|
||||
else:
|
||||
pnginfo_data = None
|
||||
|
||||
image.save(filename, format=image_format, quality=opts.jpeg_quality, pnginfo=pnginfo_data)
|
||||
|
||||
elif extension.lower() in (".jpg", ".jpeg", ".webp"):
|
||||
if image.mode == 'RGBA':
|
||||
image = image.convert("RGB")
|
||||
elif image.mode == 'I;16':
|
||||
image = image.point(lambda p: p * 0.0038910505836576).convert("RGB" if extension.lower() == ".webp" else "L")
|
||||
|
||||
image.save(filename, format=image_format, quality=opts.jpeg_quality, lossless=opts.webp_lossless)
|
||||
|
||||
if opts.enable_pnginfo and geninfo is not None:
|
||||
exif_bytes = piexif.dump({
|
||||
"Exif": {
|
||||
piexif.ExifIFD.UserComment: piexif.helper.UserComment.dump(geninfo or "", encoding="unicode")
|
||||
},
|
||||
})
|
||||
|
||||
piexif.insert(exif_bytes, filename)
|
||||
elif extension.lower() == '.avif':
|
||||
if opts.enable_pnginfo and geninfo is not None:
|
||||
exif_bytes = piexif.dump({
|
||||
"Exif": {
|
||||
piexif.ExifIFD.UserComment: piexif.helper.UserComment.dump(geninfo or "", encoding="unicode")
|
||||
},
|
||||
})
|
||||
else:
|
||||
exif_bytes = None
|
||||
|
||||
image.save(filename, format=image_format, quality=opts.jpeg_quality, exif=exif_bytes)
|
||||
elif extension.lower() == ".gif":
|
||||
image.save(filename, format=image_format, comment=geninfo)
|
||||
else:
|
||||
image.save(filename, format=image_format, quality=opts.jpeg_quality)
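# Usage sketch (hypothetical path and infotext, not part of the module):
#   save_image_with_geninfo(img, "a prompt\nSteps: 20, Seed: 1", "/tmp/out.png")
# writes the text into the "parameters" PNG chunk (when opts.enable_pnginfo is on);
# with a ".jpg" path the same text goes into the EXIF UserComment instead.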
|
||||
|
||||
|
||||
def save_image(image, path, basename, seed=None, prompt=None, extension='png', info=None, short_filename=False, no_prompt=False, grid=False, pnginfo_section_name='parameters', p=None, existing_info=None, forced_filename=None, suffix="", save_to_dirs=None):
|
||||
"""Save an image.
|
||||
|
||||
Args:
|
||||
image (`PIL.Image`):
|
||||
The image to be saved.
|
||||
path (`str`):
|
||||
The directory to save the image. Note, the option `save_to_dirs` will make the image to be saved into a sub directory.
|
||||
basename (`str`):
|
||||
The base filename which will be applied to `filename pattern`.
|
||||
seed, prompt, short_filename,
|
||||
extension (`str`):
|
||||
Image file extension, default is `png`.
|
||||
pnginfo_section_name (`str`):
|
||||
Specify the name of the section which `info` will be saved in.
|
||||
info (`str` or `PngImagePlugin.iTXt`):
|
||||
PNG info chunks.
|
||||
existing_info (`dict`):
|
||||
Additional PNG info. `existing_info == {pnginfo_section_name: info, ...}`
|
||||
no_prompt:
|
||||
TODO I don't know its meaning.
|
||||
p (`StableDiffusionProcessing` or `Processing`)
|
||||
forced_filename (`str`):
|
||||
If specified, `basename` and filename pattern will be ignored.
|
||||
save_to_dirs (bool):
|
||||
If true, the image will be saved into a subdirectory of `path`.
|
||||
|
||||
Returns: (fullfn, txt_fullfn)
|
||||
fullfn (`str`):
|
||||
The full path of the saved image.
|
||||
txt_fullfn (`str` or None):
|
||||
If a text file is saved for this image, this will be its full path. Otherwise None.
|
||||
"""
|
||||
namegen = FilenameGenerator(p, seed, prompt, image, basename=basename)
|
||||
|
||||
# WebP and JPG formats have maximum dimension limits of 16383 and 65535 respectively. switch to PNG which has a much higher limit
|
||||
if (image.height > 65535 or image.width > 65535) and extension.lower() in ("jpg", "jpeg") or (image.height > 16383 or image.width > 16383) and extension.lower() == "webp":
|
||||
print('Image dimensions too large; saving as PNG')
|
||||
extension = "png"
|
||||
|
||||
if save_to_dirs is None:
|
||||
save_to_dirs = (grid and opts.grid_save_to_dirs) or (not grid and opts.save_to_dirs and not no_prompt)
|
||||
|
||||
if save_to_dirs:
|
||||
dirname = namegen.apply(opts.directories_filename_pattern or "[prompt_words]").lstrip(' ').rstrip('\\ /')
|
||||
path = os.path.join(path, dirname)
|
||||
|
||||
os.makedirs(path, exist_ok=True)
|
||||
|
||||
if forced_filename is None:
|
||||
if short_filename or seed is None:
|
||||
file_decoration = ""
|
||||
elif hasattr(p, 'override_settings'):
|
||||
file_decoration = p.override_settings.get("samples_filename_pattern")
|
||||
else:
|
||||
file_decoration = None
|
||||
|
||||
if file_decoration is None:
|
||||
file_decoration = opts.samples_filename_pattern or ("[seed]" if opts.save_to_dirs else "[seed]-[prompt_spaces]")
|
||||
|
||||
file_decoration = namegen.apply(file_decoration) + suffix
|
||||
|
||||
add_number = opts.save_images_add_number or file_decoration == ''
|
||||
|
||||
if file_decoration != "" and add_number:
|
||||
file_decoration = f"-{file_decoration}"
|
||||
|
||||
if add_number:
|
||||
basecount = get_next_sequence_number(path, basename)
|
||||
fullfn = None
|
||||
for i in range(500):
|
||||
fn = f"{basecount + i:05}" if basename == '' else f"{basename}-{basecount + i:04}"
|
||||
fullfn = os.path.join(path, f"{fn}{file_decoration}.{extension}")
|
||||
if not os.path.exists(fullfn):
|
||||
break
|
||||
else:
|
||||
fullfn = os.path.join(path, f"{file_decoration}.{extension}")
|
||||
else:
|
||||
fullfn = os.path.join(path, f"{forced_filename}.{extension}")
|
||||
|
||||
pnginfo = existing_info or {}
|
||||
if info is not None:
|
||||
pnginfo[pnginfo_section_name] = info
|
||||
|
||||
params = script_callbacks.ImageSaveParams(image, p, fullfn, pnginfo)
|
||||
if opts.enable_pnginfo:
|
||||
stealth_infotext.add_stealth_pnginfo(params)
|
||||
script_callbacks.before_image_saved_callback(params)
|
||||
|
||||
image = params.image
|
||||
fullfn = params.filename
|
||||
info = params.pnginfo.get(pnginfo_section_name, None)
|
||||
|
||||
def _atomically_save_image(image_to_save, filename_without_extension, extension):
|
||||
"""
|
||||
save image with .tmp extension to avoid race condition when another process detects new image in the directory
|
||||
"""
|
||||
temp_file_path = f"{filename_without_extension}.tmp"
|
||||
|
||||
save_image_with_geninfo(image_to_save, info, temp_file_path, extension, existing_pnginfo=params.pnginfo, pnginfo_section_name=pnginfo_section_name)
|
||||
|
||||
filename = filename_without_extension + extension
|
||||
without_extension = filename_without_extension
|
||||
if shared.opts.save_images_replace_action != "Replace":
|
||||
n = 0
|
||||
while os.path.exists(filename):
|
||||
n += 1
|
||||
without_extension = f"{filename_without_extension}-{n}"
|
||||
filename = without_extension + extension
|
||||
os.replace(temp_file_path, filename)
|
||||
return without_extension
|
||||
|
||||
fullfn_without_extension, extension = os.path.splitext(params.filename)
|
||||
if hasattr(os, 'statvfs'):
|
||||
max_name_len = os.statvfs(path).f_namemax
|
||||
fullfn_without_extension = fullfn_without_extension[:max_name_len - max(4, len(extension))]
|
||||
params.filename = fullfn_without_extension + extension
|
||||
fullfn = params.filename
|
||||
|
||||
fullfn_without_extension = _atomically_save_image(image, fullfn_without_extension, extension)
|
||||
fullfn = fullfn_without_extension + extension
|
||||
image.already_saved_as = fullfn
|
||||
|
||||
oversize = image.width > opts.target_side_length or image.height > opts.target_side_length
|
||||
if opts.export_for_4chan and (oversize or os.stat(fullfn).st_size > opts.img_downscale_threshold * 1024 * 1024):
|
||||
ratio = image.width / image.height
|
||||
resize_to = None
|
||||
if oversize and ratio > 1:
|
||||
resize_to = round(opts.target_side_length), round(image.height * opts.target_side_length / image.width)
|
||||
elif oversize:
|
||||
resize_to = round(image.width * opts.target_side_length / image.height), round(opts.target_side_length)
|
||||
|
||||
if resize_to is not None:
|
||||
try:
|
||||
# Resizing image with LANCZOS could throw an exception if e.g. image mode is I;16
|
||||
image = image.resize(resize_to, LANCZOS)
|
||||
except Exception:
|
||||
image = image.resize(resize_to)
|
||||
try:
|
||||
_ = _atomically_save_image(image, fullfn_without_extension, ".jpg")
|
||||
except Exception as e:
|
||||
errors.display(e, "saving image as downscaled JPG")
|
||||
|
||||
if opts.save_txt and info is not None:
|
||||
txt_fullfn = f"{fullfn_without_extension}.txt"
|
||||
with open(txt_fullfn, "w", encoding="utf8") as file:
|
||||
file.write(f"{info}\n")
|
||||
else:
|
||||
txt_fullfn = None
|
||||
|
||||
script_callbacks.image_saved_callback(params)
|
||||
|
||||
return fullfn, txt_fullfn
|
||||
|
||||
|
||||
IGNORED_INFO_KEYS = {
|
||||
'jfif', 'jfif_version', 'jfif_unit', 'jfif_density', 'dpi', 'exif',
|
||||
'loop', 'background', 'timestamp', 'duration', 'progressive', 'progression',
|
||||
'icc_profile', 'chromaticity', 'photoshop',
|
||||
}
|
||||
|
||||
|
||||
def read_info_from_image(image: Image.Image) -> tuple[str | None, dict]:
|
||||
"""Read generation info from an image, checking standard metadata first, then stealth info if needed."""
|
||||
|
||||
def read_standard():
|
||||
items = (image.info or {}).copy()
|
||||
|
||||
geninfo = items.pop('parameters', None)
|
||||
|
||||
if "exif" in items:
|
||||
exif_data = items["exif"]
|
||||
try:
|
||||
exif = piexif.load(exif_data)
|
||||
except OSError:
|
||||
# memory / exif was not valid so piexif tried to read from a file
|
||||
exif = None
|
||||
exif_comment = (exif or {}).get("Exif", {}).get(piexif.ExifIFD.UserComment, b'')
|
||||
try:
|
||||
exif_comment = piexif.helper.UserComment.load(exif_comment)
|
||||
except ValueError:
|
||||
exif_comment = exif_comment.decode('utf8', errors="ignore")
|
||||
|
||||
if exif_comment:
|
||||
geninfo = exif_comment
|
||||
elif "comment" in items: # for gif
|
||||
if isinstance(items["comment"], bytes):
|
||||
geninfo = items["comment"].decode('utf8', errors="ignore")
|
||||
else:
|
||||
geninfo = items["comment"]
|
||||
|
||||
for field in IGNORED_INFO_KEYS:
|
||||
items.pop(field, None)
|
||||
|
||||
if items.get("Software", None) == "NovelAI":
|
||||
try:
|
||||
json_info = json.loads(items["Comment"])
|
||||
sampler = sd_samplers.samplers_map.get(json_info["sampler"], "Euler a")
|
||||
|
||||
geninfo = f"""{items["Description"]}
|
||||
Negative prompt: {json_info["uc"]}
|
||||
Steps: {json_info["steps"]}, Sampler: {sampler}, CFG scale: {json_info["scale"]}, Seed: {json_info["seed"]}, Size: {image.width}x{image.height}, Clip skip: 2, ENSD: 31337"""
|
||||
except Exception:
|
||||
errors.report("Error parsing NovelAI image generation parameters", exc_info=True)
|
||||
|
||||
return geninfo, items
|
||||
|
||||
geninfo, items = read_standard()
|
||||
if geninfo is None:
|
||||
geninfo = stealth_infotext.read_info_from_image_stealth(image)
|
||||
|
||||
return geninfo, items
|
||||
|
||||
|
||||
def image_data(data):
|
||||
import gradio as gr
|
||||
|
||||
try:
|
||||
image = read(io.BytesIO(data))
|
||||
textinfo, _ = read_info_from_image(image)
|
||||
return textinfo, None
|
||||
except Exception:
|
||||
pass
|
||||
|
||||
try:
|
||||
text = data.decode('utf8')
|
||||
assert len(text) < 10000
|
||||
return text, None
|
||||
|
||||
except Exception:
|
||||
pass
|
||||
|
||||
return gr.update(), None
|
||||
|
||||
|
||||
def flatten(img, bgcolor):
|
||||
"""replaces transparency with bgcolor (example: "#ffffff"), returning an RGB mode image with no transparency"""
|
||||
|
||||
if img.mode == "RGBA":
|
||||
background = Image.new('RGBA', img.size, bgcolor)
|
||||
background.paste(img, mask=img)
|
||||
img = background
|
||||
|
||||
return img.convert('RGB')
|
||||
|
||||
|
||||
def read(fp, **kwargs):
|
||||
image = Image.open(fp, **kwargs)
|
||||
image = fix_image(image)
|
||||
|
||||
return image
|
||||
|
||||
|
||||
def fix_image(image: Image.Image):
|
||||
if image is None:
|
||||
return None
|
||||
|
||||
try:
|
||||
image = ImageOps.exif_transpose(image)
|
||||
image = fix_png_transparency(image)
|
||||
except Exception:
|
||||
pass
|
||||
|
||||
return image
|
||||
|
||||
|
||||
def fix_png_transparency(image: Image.Image):
|
||||
if image.mode not in ("RGB", "P") or not isinstance(image.info.get("transparency"), bytes):
|
||||
return image
|
||||
|
||||
image = image.convert("RGBA")
|
||||
return image
|
||||
263
modules/img2img.py
Executable file
@@ -0,0 +1,263 @@
|
||||
import os
|
||||
from contextlib import closing
|
||||
from pathlib import Path
|
||||
|
||||
from PIL import Image, ImageOps, ImageFilter, ImageEnhance, UnidentifiedImageError
|
||||
import gradio as gr
|
||||
|
||||
from modules import images
|
||||
from modules.infotext_utils import create_override_settings_dict, parse_generation_parameters
|
||||
from modules.processing import Processed, StableDiffusionProcessingImg2Img, process_images
|
||||
from modules.shared import opts, state
|
||||
from modules.sd_models import get_closet_checkpoint_match
|
||||
import modules.shared as shared
|
||||
import modules.processing as processing
|
||||
from modules.ui import plaintext_to_html
|
||||
import modules.scripts
|
||||
from modules_forge import main_thread
|
||||
|
||||
|
||||
def process_batch(p, input, output_dir, inpaint_mask_dir, args, to_scale=False, scale_by=1.0, use_png_info=False, png_info_props=None, png_info_dir=None):
|
||||
output_dir = output_dir.strip()
|
||||
processing.fix_seed(p)
|
||||
|
||||
if isinstance(input, str):
|
||||
batch_images = list(shared.walk_files(input, allowed_extensions=(".png", ".jpg", ".jpeg", ".webp", ".tif", ".tiff", ".avif")))
|
||||
else:
|
||||
batch_images = [os.path.abspath(x.name) for x in input]
|
||||
|
||||
is_inpaint_batch = False
|
||||
if inpaint_mask_dir:
|
||||
inpaint_masks = shared.listfiles(inpaint_mask_dir)
|
||||
is_inpaint_batch = bool(inpaint_masks)
|
||||
|
||||
if is_inpaint_batch:
|
||||
print(f"\nInpaint batch is enabled. {len(inpaint_masks)} masks found.")
|
||||
|
||||
print(f"Will process {len(batch_images)} images, creating {p.n_iter * p.batch_size} new images for each.")
|
||||
|
||||
state.job_count = len(batch_images) * p.n_iter
|
||||
|
||||
# extract "default" params to use in case getting png info fails
|
||||
prompt = p.prompt
|
||||
negative_prompt = p.negative_prompt
|
||||
seed = p.seed
|
||||
cfg_scale = p.cfg_scale
|
||||
sampler_name = p.sampler_name
|
||||
steps = p.steps
|
||||
override_settings = p.override_settings
|
||||
sd_model_checkpoint_override = get_closet_checkpoint_match(override_settings.get("sd_model_checkpoint", None))
|
||||
batch_results = None
|
||||
discard_further_results = False
|
||||
for i, image in enumerate(batch_images):
|
||||
state.job = f"{i+1} out of {len(batch_images)}"
|
||||
if state.skipped:
|
||||
state.skipped = False
|
||||
|
||||
if state.interrupted or state.stopping_generation:
|
||||
break
|
||||
|
||||
try:
|
||||
img = images.read(image)
|
||||
except UnidentifiedImageError as e:
|
||||
print(e)
|
||||
continue
|
||||
# Use the EXIF orientation of photos taken by smartphones.
|
||||
img = ImageOps.exif_transpose(img)
|
||||
|
||||
if to_scale:
|
||||
p.width = int(img.width * scale_by)
|
||||
p.height = int(img.height * scale_by)
|
||||
|
||||
p.init_images = [img] * p.batch_size
|
||||
|
||||
image_path = Path(image)
|
||||
if is_inpaint_batch:
|
||||
# try to find corresponding mask for an image using simple filename matching
|
||||
if len(inpaint_masks) == 1:
|
||||
mask_image_path = inpaint_masks[0]
|
||||
else:
|
||||
# more than one mask available: pick the one matching the image filename stem
|
||||
mask_image_dir = Path(inpaint_mask_dir)
|
||||
masks_found = list(mask_image_dir.glob(f"{image_path.stem}.*"))
|
||||
|
||||
if len(masks_found) == 0:
|
||||
print(f"Warning: mask is not found for {image_path} in {mask_image_dir}. Skipping it.")
|
||||
continue
|
||||
|
||||
# it should contain only 1 matching mask
|
||||
# otherwise user has many masks with the same name but different extensions
|
||||
mask_image_path = masks_found[0]
|
||||
|
||||
mask_image = images.read(mask_image_path)
|
||||
p.image_mask = mask_image
|
||||
|
||||
if use_png_info:
|
||||
try:
|
||||
info_img = img
|
||||
if png_info_dir:
|
||||
info_img_path = os.path.join(png_info_dir, os.path.basename(image))
|
||||
info_img = images.read(info_img_path)
|
||||
geninfo, _ = images.read_info_from_image(info_img)
|
||||
parsed_parameters = parse_generation_parameters(geninfo)
|
||||
parsed_parameters = {k: v for k, v in parsed_parameters.items() if k in (png_info_props or {})}
|
||||
except Exception:
|
||||
parsed_parameters = {}
|
||||
|
||||
p.prompt = prompt + (" " + parsed_parameters["Prompt"] if "Prompt" in parsed_parameters else "")
|
||||
p.negative_prompt = negative_prompt + (" " + parsed_parameters["Negative prompt"] if "Negative prompt" in parsed_parameters else "")
|
||||
p.seed = int(parsed_parameters.get("Seed", seed))
|
||||
p.cfg_scale = float(parsed_parameters.get("CFG scale", cfg_scale))
|
||||
p.sampler_name = parsed_parameters.get("Sampler", sampler_name)
|
||||
p.steps = int(parsed_parameters.get("Steps", steps))
|
||||
|
||||
model_info = get_closet_checkpoint_match(parsed_parameters.get("Model hash", None))
|
||||
if model_info is not None:
|
||||
p.override_settings['sd_model_checkpoint'] = model_info.name
|
||||
elif sd_model_checkpoint_override:
|
||||
p.override_settings['sd_model_checkpoint'] = sd_model_checkpoint_override
|
||||
else:
|
||||
p.override_settings.pop("sd_model_checkpoint", None)
|
||||
|
||||
if output_dir:
|
||||
p.outpath_samples = output_dir
|
||||
p.override_settings['save_to_dirs'] = False
|
||||
|
||||
if opts.img2img_batch_use_original_name:
|
||||
filename_pattern = f'{image_path.stem}-[generation_number]' if p.n_iter > 1 or p.batch_size > 1 else f'{image_path.stem}'
|
||||
p.override_settings['samples_filename_pattern'] = filename_pattern
|
||||
|
||||
proc = modules.scripts.scripts_img2img.run(p, *args)
|
||||
|
||||
if proc is None:
|
||||
proc = process_images(p)
|
||||
|
||||
if not discard_further_results and proc:
|
||||
if batch_results:
|
||||
batch_results.images.extend(proc.images)
|
||||
batch_results.infotexts.extend(proc.infotexts)
|
||||
else:
|
||||
batch_results = proc
|
||||
|
||||
if 0 <= shared.opts.img2img_batch_show_results_limit < len(batch_results.images):
|
||||
discard_further_results = True
|
||||
batch_results.images = batch_results.images[:int(shared.opts.img2img_batch_show_results_limit)]
|
||||
batch_results.infotexts = batch_results.infotexts[:int(shared.opts.img2img_batch_show_results_limit)]
|
||||
|
||||
return batch_results
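# Call sketch (hypothetical directories and props, not part of the module): reuse
# each source image's own recorded seed and step count while keeping the UI prompt:
#   process_batch(p, "/path/in", "/path/out", "", script_args,
#                 use_png_info=True, png_info_props=["Seed", "Steps"])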
|
||||
|
||||
|
||||
def img2img_function(id_task: str, request: gr.Request, mode: int, prompt: str, negative_prompt: str, prompt_styles, init_img, sketch, sketch_fg, init_img_with_mask, init_img_with_mask_fg, inpaint_color_sketch, inpaint_color_sketch_fg, init_img_inpaint, init_mask_inpaint, mask_blur: int, mask_alpha: float, inpainting_fill: int, n_iter: int, batch_size: int, cfg_scale: float, distilled_cfg_scale: float, image_cfg_scale: float, denoising_strength: float, selected_scale_tab: int, height: int, width: int, scale_by: float, resize_mode: int, inpaint_full_res: bool, inpaint_full_res_padding: int, inpainting_mask_invert: int, img2img_batch_input_dir: str, img2img_batch_output_dir: str, img2img_batch_inpaint_mask_dir: str, override_settings_texts, img2img_batch_use_png_info: bool, img2img_batch_png_info_props: list, img2img_batch_png_info_dir: str, img2img_batch_source_type: str, img2img_batch_upload: list, *args):
|
||||
|
||||
override_settings = create_override_settings_dict(override_settings_texts)
|
||||
|
||||
is_batch = mode == 5
|
||||
|
||||
height, width = int(height), int(width)
|
||||
|
||||
image = None
|
||||
mask = None
|
||||
|
||||
if mode == 0: # img2img
|
||||
image = init_img
|
||||
mask = None
|
||||
elif mode == 1: # img2img sketch
|
||||
mask = None
|
||||
image = Image.alpha_composite(sketch, sketch_fg)
|
||||
elif mode == 2: # inpaint
|
||||
image = init_img_with_mask
|
||||
mask = init_img_with_mask_fg.getchannel('A').convert('L')
|
||||
mask = Image.merge('RGBA', (mask, mask, mask, Image.new('L', mask.size, 255)))
|
||||
elif mode == 3: # inpaint sketch
|
||||
image = Image.alpha_composite(inpaint_color_sketch, inpaint_color_sketch_fg)
|
||||
mask = inpaint_color_sketch_fg.getchannel('A').convert('L')
|
||||
short_side = min(mask.size)
|
||||
dilation_size = int(0.015 * short_side) * 2 + 1
|
||||
mask = mask.filter(ImageFilter.MaxFilter(dilation_size))
|
||||
mask = Image.merge('RGBA', (mask, mask, mask, Image.new('L', mask.size, 255)))
|
||||
elif mode == 4: # inpaint upload mask
|
||||
image = init_img_inpaint
|
||||
mask = init_mask_inpaint
|
||||
|
||||
if mask and isinstance(mask, Image.Image):
|
||||
mask = mask.point(lambda v: 255 if v > 128 else 0)
|
||||
|
||||
image = images.fix_image(image)
|
||||
mask = images.fix_image(mask)
|
||||
|
||||
if selected_scale_tab == 1 and not is_batch:
|
||||
assert image, "Can't scale by because no image is selected"
|
||||
|
||||
width = int(image.width * scale_by)
|
||||
width -= width % 8
|
||||
height = int(image.height * scale_by)
|
||||
height -= height % 8
|
||||
|
||||
assert 0. <= denoising_strength <= 1., 'can only work with strength in [0.0, 1.0]'
|
||||
|
||||
p = StableDiffusionProcessingImg2Img(
|
||||
outpath_samples=opts.outdir_samples or opts.outdir_img2img_samples,
|
||||
outpath_grids=opts.outdir_grids or opts.outdir_img2img_grids,
|
||||
prompt=prompt,
|
||||
negative_prompt=negative_prompt,
|
||||
styles=prompt_styles,
|
||||
batch_size=batch_size,
|
||||
n_iter=n_iter,
|
||||
cfg_scale=cfg_scale,
|
||||
width=width,
|
||||
height=height,
|
||||
init_images=[image],
|
||||
mask=mask,
|
||||
mask_blur=mask_blur,
|
||||
inpainting_fill=inpainting_fill,
|
||||
resize_mode=resize_mode,
|
||||
denoising_strength=denoising_strength,
|
||||
image_cfg_scale=image_cfg_scale,
|
||||
inpaint_full_res=inpaint_full_res,
|
||||
inpaint_full_res_padding=inpaint_full_res_padding,
|
||||
inpainting_mask_invert=inpainting_mask_invert,
|
||||
override_settings=override_settings,
|
||||
distilled_cfg_scale=distilled_cfg_scale
|
||||
)
|
||||
|
||||
p.scripts = modules.scripts.scripts_img2img
|
||||
p.script_args = args
|
||||
|
||||
p.user = request.username
|
||||
|
||||
if shared.opts.enable_console_prompts:
|
||||
print(f"\nimg2img: {prompt}", file=shared.progress_print_out)
|
||||
|
||||
with closing(p):
|
||||
if is_batch:
|
||||
if img2img_batch_source_type == "upload":
|
||||
assert isinstance(img2img_batch_upload, list) and img2img_batch_upload
|
||||
output_dir = ""
|
||||
inpaint_mask_dir = ""
|
||||
png_info_dir = img2img_batch_png_info_dir if not shared.cmd_opts.hide_ui_dir_config else ""
|
||||
processed = process_batch(p, img2img_batch_upload, output_dir, inpaint_mask_dir, args, to_scale=selected_scale_tab == 1, scale_by=scale_by, use_png_info=img2img_batch_use_png_info, png_info_props=img2img_batch_png_info_props, png_info_dir=png_info_dir)
|
||||
else: # "from dir"
|
||||
assert not shared.cmd_opts.hide_ui_dir_config, "Launched with --hide-ui-dir-config, batch img2img disabled"
|
||||
processed = process_batch(p, img2img_batch_input_dir, img2img_batch_output_dir, img2img_batch_inpaint_mask_dir, args, to_scale=selected_scale_tab == 1, scale_by=scale_by, use_png_info=img2img_batch_use_png_info, png_info_props=img2img_batch_png_info_props, png_info_dir=img2img_batch_png_info_dir)
|
||||
|
||||
if processed is None:
|
||||
processed = Processed(p, [], p.seed, "")
|
||||
else:
|
||||
processed = modules.scripts.scripts_img2img.run(p, *args)
|
||||
if processed is None:
|
||||
processed = process_images(p)
|
||||
|
||||
shared.total_tqdm.clear()
|
||||
|
||||
generation_info_js = processed.js()
|
||||
if opts.samples_log_stdout:
|
||||
print(generation_info_js)
|
||||
|
||||
if opts.do_not_show_images:
|
||||
processed.images = []
|
||||
|
||||
return processed.images + processed.extra_images, generation_info_js, plaintext_to_html(processed.info), plaintext_to_html(processed.comments, classname="comments")
|
||||
|
||||
|
||||
def img2img(id_task: str, request: gr.Request, mode: int, prompt: str, negative_prompt: str, prompt_styles, init_img, sketch, sketch_fg, init_img_with_mask, init_img_with_mask_fg, inpaint_color_sketch, inpaint_color_sketch_fg, init_img_inpaint, init_mask_inpaint, mask_blur: int, mask_alpha: float, inpainting_fill: int, n_iter: int, batch_size: int, cfg_scale: float, distilled_cfg_scale: float, image_cfg_scale: float, denoising_strength: float, selected_scale_tab: int, height: int, width: int, scale_by: float, resize_mode: int, inpaint_full_res: bool, inpaint_full_res_padding: int, inpainting_mask_invert: int, img2img_batch_input_dir: str, img2img_batch_output_dir: str, img2img_batch_inpaint_mask_dir: str, override_settings_texts, img2img_batch_use_png_info: bool, img2img_batch_png_info_props: list, img2img_batch_png_info_dir: str, img2img_batch_source_type: str, img2img_batch_upload: list, *args):
|
||||
return main_thread.run_and_wait_result(img2img_function, id_task, request, mode, prompt, negative_prompt, prompt_styles, init_img, sketch, sketch_fg, init_img_with_mask, init_img_with_mask_fg, inpaint_color_sketch, inpaint_color_sketch_fg, init_img_inpaint, init_mask_inpaint, mask_blur, mask_alpha, inpainting_fill, n_iter, batch_size, cfg_scale, distilled_cfg_scale, image_cfg_scale, denoising_strength, selected_scale_tab, height, width, scale_by, resize_mode, inpaint_full_res, inpaint_full_res_padding, inpainting_mask_invert, img2img_batch_input_dir, img2img_batch_output_dir, img2img_batch_inpaint_mask_dir, override_settings_texts, img2img_batch_use_png_info, img2img_batch_png_info_props, img2img_batch_png_info_dir, img2img_batch_source_type, img2img_batch_upload, *args)
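# Illustrative sketch (not part of the original file): the inpaint-sketch branch above builds its
# mask from the sketch layer's alpha channel and dilates it with an odd-sized max filter of roughly
# 1.5% of the image's short side. Standalone version, assuming only Pillow and a hypothetical RGBA overlay:
if __name__ == "__main__":
    from PIL import Image, ImageFilter

    def mask_from_overlay(overlay):
        mask = overlay.getchannel('A').convert('L')            # painted pixels have non-zero alpha
        short_side = min(mask.size)
        dilation_size = int(0.015 * short_side) * 2 + 1        # MaxFilter requires an odd kernel size
        mask = mask.filter(ImageFilter.MaxFilter(dilation_size))
        return mask.point(lambda v: 255 if v > 128 else 0)     # same thresholding applied to masks above

    overlay = Image.new('RGBA', (512, 512), (0, 0, 0, 0))
    overlay.paste((255, 0, 0, 255), (200, 200, 240, 240))      # pretend the user painted a square
    print(mask_from_overlay(overlay).getbbox())                # bbox grows by the dilation radius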
|
||||
16
modules/import_hook.py
Executable file
@@ -0,0 +1,16 @@
|
||||
# import sys
|
||||
#
|
||||
# # this will break any attempt to import xformers, which will prevent the stable diffusion repo from trying to use it
|
||||
# if "--xformers" not in "".join(sys.argv):
|
||||
# sys.modules["xformers"] = None
|
||||
#
|
||||
# # Hack to fix a changed import in torchvision 0.17+, which otherwise breaks
|
||||
# # basicsr; see https://github.com/AUTOMATIC1111/stable-diffusion-webui/issues/13985
|
||||
# try:
|
||||
# import torchvision.transforms.functional_tensor # noqa: F401
|
||||
# except ImportError:
|
||||
# try:
|
||||
# import torchvision.transforms.functional as functional
|
||||
# sys.modules["torchvision.transforms.functional_tensor"] = functional
|
||||
# except ImportError:
|
||||
# pass # shrug...
|
||||
643
modules/infotext_utils.py
Executable file
@@ -0,0 +1,643 @@
|
||||
from __future__ import annotations
|
||||
import base64
|
||||
import io
|
||||
import json
|
||||
import os
|
||||
import re
|
||||
import sys
|
||||
|
||||
import gradio as gr
|
||||
from modules.paths import data_path
|
||||
from modules import shared, ui_tempdir, script_callbacks, processing, infotext_versions, images, prompt_parser, errors
|
||||
from PIL import Image
|
||||
|
||||
from modules_forge import main_entry
|
||||
|
||||
sys.modules['modules.generation_parameters_copypaste'] = sys.modules[__name__] # alias for old name
|
||||
|
||||
re_param_code = r'\s*(\w[\w \-/]+):\s*("(?:\\.|[^\\"])+"|[^,]*)(?:,|$)'
|
||||
re_param = re.compile(re_param_code)
|
||||
re_imagesize = re.compile(r"^(\d+)x(\d+)$")
|
||||
re_hypernet_hash = re.compile(r"\(([0-9a-f]+)\)$")
|
||||
type_of_gr_update = type(gr.update())
|
||||
|
||||
|
||||
class ParamBinding:
|
||||
def __init__(self, paste_button, tabname, source_text_component=None, source_image_component=None, source_tabname=None, override_settings_component=None, paste_field_names=None):
|
||||
self.paste_button = paste_button
|
||||
self.tabname = tabname
|
||||
self.source_text_component = source_text_component
|
||||
self.source_image_component = source_image_component
|
||||
self.source_tabname = source_tabname
|
||||
self.override_settings_component = override_settings_component
|
||||
self.paste_field_names = paste_field_names or []
|
||||
|
||||
|
||||
class PasteField(tuple):
|
||||
def __new__(cls, component, target, *, api=None):
|
||||
return super().__new__(cls, (component, target))
|
||||
|
||||
def __init__(self, component, target, *, api=None):
|
||||
super().__init__()
|
||||
|
||||
self.api = api
|
||||
self.component = component
|
||||
self.label = target if isinstance(target, str) else None
|
||||
self.function = target if callable(target) else None
|
||||
|
||||
|
||||
paste_fields: dict[str, dict] = {}
|
||||
registered_param_bindings: list[ParamBinding] = []
|
||||
|
||||
|
||||
def reset():
|
||||
paste_fields.clear()
|
||||
registered_param_bindings.clear()
|
||||
|
||||
|
||||
def quote(text):
|
||||
if ',' not in str(text) and '\n' not in str(text) and ':' not in str(text):
|
||||
return text
|
||||
|
||||
return json.dumps(text, ensure_ascii=False)
|
||||
|
||||
|
||||
def unquote(text):
|
||||
if len(text) == 0 or text[0] != '"' or text[-1] != '"':
|
||||
return text
|
||||
|
||||
try:
|
||||
return json.loads(text)
|
||||
except Exception:
|
||||
return text
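# Illustrative sketch (not part of the original file): quote()/unquote() exist so parameter values
# containing commas, colons or newlines survive the comma-separated infotext line by round-tripping
# through JSON. Standard library only:
if __name__ == "__main__":
    import json
    value = 'a prompt, with: tricky\ncharacters'
    quoted = json.dumps(value, ensure_ascii=False)     # what quote() emits for such values
    assert json.loads(quoted) == value                 # what unquote() recovers when pasting
    print(quoted)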
|
||||
|
||||
|
||||
def image_from_url_text(filedata):
|
||||
if filedata is None:
|
||||
return None
|
||||
|
||||
if isinstance(filedata, list):
|
||||
if len(filedata) == 0:
|
||||
return None
|
||||
|
||||
filedata = filedata[0]
|
||||
|
||||
filename = None
if isinstance(filedata, dict) and filedata.get("is_file", False):
filename = filedata["name"]
|
||||
|
||||
elif isinstance(filedata, tuple) and len(filedata) == 2: # gradio 4.16 sends images from gallery as a list of tuples
|
||||
return filedata[0]
|
||||
|
||||
if filename:
|
||||
is_in_right_dir = ui_tempdir.check_tmp_file(shared.demo, filename)
|
||||
assert is_in_right_dir, 'trying to open image file outside of allowed directories'
|
||||
|
||||
filename = filename.rsplit('?', 1)[0]
|
||||
return images.read(filename)
|
||||
|
||||
if isinstance(filedata, str):
|
||||
if filedata.startswith("data:image/png;base64,"):
|
||||
filedata = filedata[len("data:image/png;base64,"):]
|
||||
|
||||
filedata = base64.decodebytes(filedata.encode('utf-8'))
|
||||
image = images.read(io.BytesIO(filedata))
|
||||
return image
|
||||
|
||||
return None
|
||||
|
||||
|
||||
def add_paste_fields(tabname, init_img, fields, override_settings_component=None):
|
||||
|
||||
if fields:
|
||||
for i in range(len(fields)):
|
||||
if not isinstance(fields[i], PasteField):
|
||||
fields[i] = PasteField(*fields[i])
|
||||
|
||||
paste_fields[tabname] = {"init_img": init_img, "fields": fields, "override_settings_component": override_settings_component}
|
||||
|
||||
# backwards compatibility for existing extensions
|
||||
import modules.ui
|
||||
if tabname == 'txt2img':
|
||||
modules.ui.txt2img_paste_fields = fields
|
||||
elif tabname == 'img2img':
|
||||
modules.ui.img2img_paste_fields = fields
|
||||
|
||||
|
||||
def create_buttons(tabs_list):
|
||||
buttons = {}
|
||||
for tab in tabs_list:
|
||||
buttons[tab] = gr.Button(f"Send to {tab}", elem_id=f"{tab}_tab")
|
||||
return buttons
|
||||
|
||||
|
||||
def bind_buttons(buttons, send_image, send_generate_info):
|
||||
"""old function for backwards compatibility; do not use this, use register_paste_params_button"""
|
||||
for tabname, button in buttons.items():
|
||||
source_text_component = send_generate_info if isinstance(send_generate_info, gr.components.Component) else None
|
||||
source_tabname = send_generate_info if isinstance(send_generate_info, str) else None
|
||||
|
||||
register_paste_params_button(ParamBinding(paste_button=button, tabname=tabname, source_text_component=source_text_component, source_image_component=send_image, source_tabname=source_tabname))
|
||||
|
||||
|
||||
def register_paste_params_button(binding: ParamBinding):
|
||||
registered_param_bindings.append(binding)
|
||||
|
||||
|
||||
def connect_paste_params_buttons():
|
||||
for binding in registered_param_bindings:
|
||||
destination_image_component = paste_fields[binding.tabname]["init_img"]
|
||||
fields = paste_fields[binding.tabname]["fields"]
|
||||
override_settings_component = binding.override_settings_component or paste_fields[binding.tabname]["override_settings_component"]
|
||||
|
||||
destination_width_component = next(iter([field for field, name in fields if name == "Size-1"] if fields else []), None)
|
||||
destination_height_component = next(iter([field for field, name in fields if name == "Size-2"] if fields else []), None)
|
||||
|
||||
if binding.source_image_component and destination_image_component:
|
||||
need_send_dimensions = destination_width_component and binding.tabname != 'inpaint'
|
||||
if isinstance(binding.source_image_component, gr.Gallery):
|
||||
func = send_image_and_dimensions if need_send_dimensions else image_from_url_text
|
||||
jsfunc = "extract_image_from_gallery"
|
||||
else:
|
||||
func = send_image_and_dimensions if need_send_dimensions else lambda x: x
|
||||
jsfunc = None
|
||||
|
||||
binding.paste_button.click(
|
||||
fn=func,
|
||||
_js=jsfunc,
|
||||
inputs=[binding.source_image_component],
|
||||
outputs=[destination_image_component, destination_width_component, destination_height_component] if need_send_dimensions else [destination_image_component],
|
||||
show_progress=False,
|
||||
)
|
||||
|
||||
if binding.source_text_component is not None and fields is not None:
|
||||
connect_paste(binding.paste_button, fields, binding.source_text_component, override_settings_component, binding.tabname)
|
||||
|
||||
if binding.source_tabname is not None and fields is not None:
|
||||
paste_field_names = ['Prompt', 'Negative prompt', 'Steps', 'Face restoration'] + (["Seed"] if shared.opts.send_seed else []) + binding.paste_field_names
|
||||
binding.paste_button.click(
|
||||
fn=lambda *x: x,
|
||||
inputs=[field for field, name in paste_fields[binding.source_tabname]["fields"] if name in paste_field_names],
|
||||
outputs=[field for field, name in fields if name in paste_field_names],
|
||||
show_progress=False,
|
||||
)
|
||||
|
||||
binding.paste_button.click(
|
||||
fn=None,
|
||||
_js=f"switch_to_{binding.tabname}",
|
||||
inputs=None,
|
||||
outputs=None,
|
||||
show_progress=False,
|
||||
)
|
||||
|
||||
|
||||
def send_image_and_dimensions(x):
|
||||
if isinstance(x, Image.Image):
|
||||
img = x
|
||||
if img.mode == 'RGBA':
|
||||
img = img.convert('RGB')
|
||||
elif isinstance(x, list) and isinstance(x[0], tuple):
|
||||
img = x[0][0]
|
||||
else:
|
||||
img = image_from_url_text(x)
|
||||
if img is not None and img.mode == 'RGBA':
|
||||
img = img.convert('RGB')
|
||||
|
||||
if shared.opts.send_size and isinstance(img, Image.Image):
|
||||
w = img.width
|
||||
h = img.height
|
||||
else:
|
||||
w = gr.update()
|
||||
h = gr.update()
|
||||
|
||||
return img, w, h
|
||||
|
||||
|
||||
def restore_old_hires_fix_params(res):
|
||||
"""for infotexts that specify old First pass size parameter, convert it into
|
||||
width, height, and hr scale"""
|
||||
|
||||
firstpass_width = res.get('First pass size-1', None)
|
||||
firstpass_height = res.get('First pass size-2', None)
|
||||
|
||||
if shared.opts.use_old_hires_fix_width_height:
|
||||
hires_width = int(res.get("Hires resize-1", 0))
|
||||
hires_height = int(res.get("Hires resize-2", 0))
|
||||
|
||||
if hires_width and hires_height:
|
||||
res['Size-1'] = hires_width
|
||||
res['Size-2'] = hires_height
|
||||
return
|
||||
|
||||
if firstpass_width is None or firstpass_height is None:
|
||||
return
|
||||
|
||||
firstpass_width, firstpass_height = int(firstpass_width), int(firstpass_height)
|
||||
width = int(res.get("Size-1", 512))
|
||||
height = int(res.get("Size-2", 512))
|
||||
|
||||
if firstpass_width == 0 or firstpass_height == 0:
|
||||
firstpass_width, firstpass_height = processing.old_hires_fix_first_pass_dimensions(width, height)
|
||||
|
||||
res['Size-1'] = firstpass_width
|
||||
res['Size-2'] = firstpass_height
|
||||
res['Hires resize-1'] = width
|
||||
res['Hires resize-2'] = height
|
||||
|
||||
|
||||
def parse_generation_parameters(x: str, skip_fields: list[str] | None = None):
|
||||
"""parses generation parameters string, the one you see in text field under the picture in UI:
|
||||
```
|
||||
girl with an artist's beret, determined, blue eyes, desert scene, computer monitors, heavy makeup, by Alphonse Mucha and Charlie Bowater, ((eyeshadow)), (coquettish), detailed, intricate
|
||||
Negative prompt: ugly, fat, obese, chubby, (((deformed))), [blurry], bad anatomy, disfigured, poorly drawn face, mutation, mutated, (extra_limb), (ugly), (poorly drawn hands), messy drawing
|
||||
Steps: 20, Sampler: Euler a, CFG scale: 7, Seed: 965400086, Size: 512x512, Model hash: 45dee52b
|
||||
```
|
||||
|
||||
returns a dict with field values
|
||||
"""
|
||||
if skip_fields is None:
|
||||
skip_fields = shared.opts.infotext_skip_pasting
|
||||
|
||||
res = {}
|
||||
|
||||
prompt = ""
|
||||
negative_prompt = ""
|
||||
|
||||
done_with_prompt = False
|
||||
|
||||
*lines, lastline = x.strip().split("\n")
|
||||
if len(re_param.findall(lastline)) < 3:
|
||||
lines.append(lastline)
|
||||
lastline = ''
|
||||
|
||||
for line in lines:
|
||||
line = line.strip()
|
||||
if line.startswith("Negative prompt:"):
|
||||
done_with_prompt = True
|
||||
line = line[16:].strip()
|
||||
if done_with_prompt:
|
||||
negative_prompt += ("" if negative_prompt == "" else "\n") + line
|
||||
else:
|
||||
prompt += ("" if prompt == "" else "\n") + line
|
||||
|
||||
if 'Civitai' in lastline and 'FLUX' in lastline:
# Civitai likes to add a spurious Clip skip entry to Flux metadata, where Clip skip is not a thing.
for clip_skip_value in range(9):
lastline = lastline.replace(f'Clip skip: {clip_skip_value}, ', '')
|
||||
|
||||
# Civitai also adds Sampler: Undefined
|
||||
lastline = lastline.replace('Sampler: Undefined, ', 'Sampler: Euler, Schedule type: Simple, ') # <- by lllyasviel, seems to give similar results to Civitai "Undefined" Sampler
|
||||
|
||||
# Civitai also confuses CFG scale and Distilled CFG Scale
|
||||
lastline = lastline.replace('CFG scale: ', 'CFG scale: 1, Distilled CFG Scale: ')
|
||||
|
||||
print('Applied Forge Fix to broken Civitai Flux Meta.')
|
||||
|
||||
for k, v in re_param.findall(lastline):
|
||||
try:
|
||||
if v[0] == '"' and v[-1] == '"':
|
||||
v = unquote(v)
|
||||
|
||||
m = re_imagesize.match(v)
|
||||
if m is not None:
|
||||
res[f"{k}-1"] = m.group(1)
|
||||
res[f"{k}-2"] = m.group(2)
|
||||
else:
|
||||
res[k] = v
|
||||
except Exception:
|
||||
print(f"Error parsing \"{k}: {v}\"")
|
||||
|
||||
# Extract styles from prompt
|
||||
if shared.opts.infotext_styles != "Ignore":
|
||||
found_styles, prompt_no_styles, negative_prompt_no_styles = shared.prompt_styles.extract_styles_from_prompt(prompt, negative_prompt)
|
||||
|
||||
same_hr_styles = True
|
||||
if ("Hires prompt" in res or "Hires negative prompt" in res) and (infotext_ver > infotext_versions.v180_hr_styles if (infotext_ver := infotext_versions.parse_version(res.get("Version"))) else True):
|
||||
hr_prompt, hr_negative_prompt = res.get("Hires prompt", prompt), res.get("Hires negative prompt", negative_prompt)
|
||||
hr_found_styles, hr_prompt_no_styles, hr_negative_prompt_no_styles = shared.prompt_styles.extract_styles_from_prompt(hr_prompt, hr_negative_prompt)
|
||||
if same_hr_styles := found_styles == hr_found_styles:
|
||||
res["Hires prompt"] = '' if hr_prompt_no_styles == prompt_no_styles else hr_prompt_no_styles
|
||||
res['Hires negative prompt'] = '' if hr_negative_prompt_no_styles == negative_prompt_no_styles else hr_negative_prompt_no_styles
|
||||
|
||||
if same_hr_styles:
|
||||
prompt, negative_prompt = prompt_no_styles, negative_prompt_no_styles
|
||||
if (shared.opts.infotext_styles == "Apply if any" and found_styles) or shared.opts.infotext_styles == "Apply":
|
||||
res['Styles array'] = found_styles
|
||||
|
||||
res["Prompt"] = prompt
|
||||
res["Negative prompt"] = negative_prompt
|
||||
|
||||
# Missing CLIP skip means it was set to 1 (the default)
|
||||
if "Clip skip" not in res:
|
||||
res["Clip skip"] = "1"
|
||||
|
||||
hypernet = res.get("Hypernet", None)
|
||||
if hypernet is not None:
|
||||
res["Prompt"] += f"""<hypernet:{hypernet}:{res.get("Hypernet strength", "1.0")}>"""
|
||||
|
||||
if "Hires resize-1" not in res:
|
||||
res["Hires resize-1"] = 0
|
||||
res["Hires resize-2"] = 0
|
||||
|
||||
if "Hires sampler" not in res:
|
||||
res["Hires sampler"] = "Use same sampler"
|
||||
|
||||
if "Hires schedule type" not in res:
|
||||
res["Hires schedule type"] = "Use same scheduler"
|
||||
|
||||
if "Hires checkpoint" not in res:
|
||||
res["Hires checkpoint"] = "Use same checkpoint"
|
||||
|
||||
if "Hires prompt" not in res:
|
||||
res["Hires prompt"] = ""
|
||||
|
||||
if "Hires negative prompt" not in res:
|
||||
res["Hires negative prompt"] = ""
|
||||
|
||||
if "Mask mode" not in res:
|
||||
res["Mask mode"] = "Inpaint masked"
|
||||
|
||||
if "Masked content" not in res:
|
||||
res["Masked content"] = 'original'
|
||||
|
||||
if "Inpaint area" not in res:
|
||||
res["Inpaint area"] = "Whole picture"
|
||||
|
||||
if "Masked area padding" not in res:
|
||||
res["Masked area padding"] = 32
|
||||
|
||||
restore_old_hires_fix_params(res)
|
||||
|
||||
# Missing RNG means the default was set, which is GPU RNG
|
||||
if "RNG" not in res:
|
||||
res["RNG"] = "GPU"
|
||||
|
||||
if "Schedule type" not in res:
|
||||
res["Schedule type"] = "Automatic"
|
||||
|
||||
if "Schedule max sigma" not in res:
|
||||
res["Schedule max sigma"] = 0
|
||||
|
||||
if "Schedule min sigma" not in res:
|
||||
res["Schedule min sigma"] = 0
|
||||
|
||||
if "Schedule rho" not in res:
|
||||
res["Schedule rho"] = 0
|
||||
|
||||
if "VAE Encoder" not in res:
|
||||
res["VAE Encoder"] = "Full"
|
||||
|
||||
if "VAE Decoder" not in res:
|
||||
res["VAE Decoder"] = "Full"
|
||||
|
||||
if "FP8 weight" not in res:
|
||||
res["FP8 weight"] = "Disable"
|
||||
|
||||
if "Cache FP16 weight for LoRA" not in res and res["FP8 weight"] != "Disable":
|
||||
res["Cache FP16 weight for LoRA"] = False
|
||||
|
||||
prompt_attention = prompt_parser.parse_prompt_attention(prompt)
|
||||
prompt_attention += prompt_parser.parse_prompt_attention(negative_prompt)
|
||||
prompt_uses_emphasis = len(prompt_attention) != len([p for p in prompt_attention if p[1] == 1.0 or p[0] == 'BREAK'])
|
||||
if "Emphasis" not in res and prompt_uses_emphasis:
|
||||
res["Emphasis"] = "Original"
|
||||
|
||||
if "Refiner switch by sampling steps" not in res:
|
||||
res["Refiner switch by sampling steps"] = False
|
||||
|
||||
infotext_versions.backcompat(res)
|
||||
|
||||
for key in skip_fields:
|
||||
res.pop(key, None)
|
||||
|
||||
# basic check for same checkpoint using short name
|
||||
checkpoint = res.get('Model', None)
|
||||
if checkpoint is not None:
|
||||
if checkpoint in shared.opts.sd_model_checkpoint:
|
||||
res.pop('Model')
|
||||
|
||||
# VAE / TE
|
||||
modules = []
|
||||
hr_modules = []
|
||||
vae = res.pop('VAE', None) # old form
|
||||
if vae:
|
||||
modules = [vae]
|
||||
else:
|
||||
for key in res:
|
||||
if key.startswith('Module '):
|
||||
added = False
|
||||
for knownmodule in main_entry.module_list.keys():
|
||||
filename, _ = os.path.splitext(knownmodule)
|
||||
if res[key] == filename:
|
||||
added = True
|
||||
modules.append(knownmodule)
|
||||
break
|
||||
if not added:
|
||||
modules.append(res[key]) # so it shows in the override section (consistent with checkpoint and old vae)
|
||||
elif key.startswith('Hires Module '):
|
||||
for knownmodule in main_entry.module_list.keys():
|
||||
filename, _ = os.path.splitext(knownmodule)
|
||||
if res[key] == filename:
|
||||
hr_modules.append(knownmodule)
|
||||
break
|
||||
|
||||
if modules != []:
|
||||
current_modules = shared.opts.forge_additional_modules
|
||||
basename_modules = []
|
||||
for m in current_modules:
|
||||
basename_modules.append(os.path.basename(m))
|
||||
|
||||
if sorted(modules) != sorted(basename_modules):
|
||||
res['VAE/TE'] = modules
|
||||
|
||||
# if 'Use same choices' was the selection for Hires VAE / Text Encoder, it will be the only Hires Module
|
||||
# if the selection was empty, it will be the only Hires Module, saved as 'Built-in'
|
||||
if 'Hires Module 1' in res:
|
||||
if res['Hires Module 1'] == 'Use same choices':
|
||||
hr_modules = ['Use same choices']
|
||||
elif res['Hires Module 1'] == 'Built-in':
|
||||
hr_modules = []
|
||||
|
||||
res['Hires VAE/TE'] = hr_modules
|
||||
else:
|
||||
# no Hires Module infotext, use default
|
||||
res['Hires VAE/TE'] = ['Use same choices']
|
||||
|
||||
return res
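# Illustrative sketch (not part of the original file): the core of parse_generation_parameters()
# is re_param above, which splits the last infotext line into "Key: value" pairs; WxH values are
# further split into Key-1/Key-2 by re_imagesize.
if __name__ == "__main__":
    sample = 'Steps: 20, Sampler: Euler a, CFG scale: 7, Seed: 965400086, Size: 512x512, Model hash: 45dee52b'
    parsed = {}
    for key, value in re_param.findall(sample):
        size = re_imagesize.match(value)
        if size:
            parsed[f"{key}-1"], parsed[f"{key}-2"] = size.group(1), size.group(2)
        else:
            parsed[key] = value
    print(parsed)   # {'Steps': '20', 'Sampler': 'Euler a', ..., 'Size-1': '512', 'Size-2': '512', ...}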
|
||||
|
||||
|
||||
infotext_to_setting_name_mapping = [
|
||||
('VAE/TE', 'forge_additional_modules'),
|
||||
]
|
||||
"""Mapping of infotext labels to setting names. Only left for backwards compatibility - use OptionInfo(..., infotext='...') instead.
|
||||
Example content:
|
||||
|
||||
infotext_to_setting_name_mapping = [
|
||||
('Conditional mask weight', 'inpainting_mask_weight'),
|
||||
('Model hash', 'sd_model_checkpoint'),
|
||||
('ENSD', 'eta_noise_seed_delta'),
|
||||
('Schedule type', 'k_sched_type'),
|
||||
]
|
||||
"""
|
||||
from ast import literal_eval
|
||||
def create_override_settings_dict(text_pairs):
|
||||
"""creates processing's override_settings parameters from gradio's multiselect
|
||||
|
||||
Example input:
|
||||
['Clip skip: 2', 'Model hash: e6e99610c4', 'ENSD: 31337']
|
||||
|
||||
Example output:
|
||||
{'CLIP_stop_at_last_layers': 2, 'sd_model_checkpoint': 'e6e99610c4', 'eta_noise_seed_delta': 31337}
|
||||
"""
|
||||
|
||||
res = {}
|
||||
|
||||
if not text_pairs:
|
||||
return res
|
||||
|
||||
params = {}
|
||||
for pair in text_pairs:
|
||||
k, v = pair.split(":", maxsplit=1)
|
||||
|
||||
params[k] = v.strip()
|
||||
|
||||
mapping = [(info.infotext, k) for k, info in shared.opts.data_labels.items() if info.infotext]
|
||||
for param_name, setting_name in mapping + infotext_to_setting_name_mapping:
|
||||
value = params.get(param_name, None)
|
||||
|
||||
if value is None:
|
||||
continue
|
||||
|
||||
if setting_name == "forge_additional_modules":
|
||||
res[setting_name] = literal_eval(value)
|
||||
continue
|
||||
|
||||
res[setting_name] = shared.opts.cast_value(setting_name, value)
|
||||
|
||||
return res
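# Illustrative sketch (not part of the original file): each dropdown entry is a plain "Label: value"
# string; the label is looked up in a label -> option-name mapping (built from shared.opts.data_labels
# in the real code) and the value is cast to that option's type. The mapping below is a stand-in.
if __name__ == "__main__":
    text_pairs = ['Clip skip: 2', 'ENSD: 31337']
    demo_mapping = {'Clip skip': ('CLIP_stop_at_last_layers', int), 'ENSD': ('eta_noise_seed_delta', int)}
    overrides = {}
    for pair in text_pairs:
        label, value = pair.split(":", maxsplit=1)
        setting_name, cast = demo_mapping[label]
        overrides[setting_name] = cast(value.strip())
    print(overrides)   # {'CLIP_stop_at_last_layers': 2, 'eta_noise_seed_delta': 31337}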
|
||||
|
||||
|
||||
def get_override_settings(params, *, skip_fields=None):
|
||||
"""Returns a list of settings overrides from the infotext parameters dictionary.
|
||||
|
||||
This function checks the `params` dictionary for any keys that correspond to settings in `shared.opts` and returns
|
||||
a list of tuples containing the parameter name, setting name, and new value cast to correct type.
|
||||
|
||||
It checks for conditions before adding an override:
|
||||
- ignores settings that match the current value
|
||||
- ignores parameter keys present in skip_fields argument.
|
||||
|
||||
Example input:
|
||||
{"Clip skip": "2"}
|
||||
|
||||
Example output:
|
||||
[("Clip skip", "CLIP_stop_at_last_layers", 2)]
|
||||
"""
|
||||
|
||||
res = []
|
||||
|
||||
mapping = [(info.infotext, k) for k, info in shared.opts.data_labels.items() if info.infotext]
|
||||
for param_name, setting_name in mapping + infotext_to_setting_name_mapping:
|
||||
if param_name in (skip_fields or {}):
|
||||
continue
|
||||
|
||||
v = params.get(param_name, None)
|
||||
if v is None:
|
||||
continue
|
||||
|
||||
if setting_name in ["sd_model_checkpoint", "forge_additional_modules"]:
|
||||
if shared.opts.disable_weights_auto_swap:
|
||||
continue
|
||||
|
||||
v = shared.opts.cast_value(setting_name, v)
|
||||
current_value = getattr(shared.opts, setting_name, None)
|
||||
|
||||
if v == current_value:
|
||||
continue
|
||||
|
||||
res.append((param_name, setting_name, v))
|
||||
|
||||
return res
|
||||
|
||||
|
||||
def connect_paste(button, paste_fields, input_comp, override_settings_component, tabname):
|
||||
def paste_func(prompt):
|
||||
if not prompt and not shared.cmd_opts.hide_ui_dir_config and not shared.cmd_opts.no_prompt_history:
|
||||
filename = os.path.join(data_path, "params.txt")
|
||||
try:
|
||||
with open(filename, "r", encoding="utf8") as file:
|
||||
prompt = file.read()
|
||||
except OSError:
|
||||
pass
|
||||
|
||||
params = parse_generation_parameters(prompt)
|
||||
script_callbacks.infotext_pasted_callback(prompt, params)
|
||||
res = []
|
||||
|
||||
for output, key in paste_fields:
|
||||
if callable(key):
|
||||
try:
|
||||
v = key(params)
|
||||
except Exception:
|
||||
errors.report(f"Error executing {key}", exc_info=True)
|
||||
v = None
|
||||
else:
|
||||
v = params.get(key, None)
|
||||
|
||||
if v is None:
|
||||
res.append(gr.update())
|
||||
elif isinstance(v, type_of_gr_update):
|
||||
res.append(v)
|
||||
else:
|
||||
try:
|
||||
valtype = type(output.value)
|
||||
|
||||
if valtype == bool and v == "False":
|
||||
val = False
|
||||
elif valtype == int:
|
||||
val = float(v)
|
||||
else:
|
||||
val = valtype(v)
|
||||
|
||||
res.append(gr.update(value=val))
|
||||
except Exception:
|
||||
res.append(gr.update())
|
||||
|
||||
return res
|
||||
|
||||
if override_settings_component is not None:
|
||||
already_handled_fields = {key: 1 for _, key in paste_fields}
|
||||
|
||||
def paste_settings(params):
|
||||
vals = get_override_settings(params, skip_fields=already_handled_fields)
|
||||
|
||||
vals_pairs = [f"{infotext_text}: {value}" for infotext_text, setting_name, value in vals]
|
||||
|
||||
return gr.Dropdown.update(value=vals_pairs, choices=vals_pairs, visible=bool(vals_pairs))
|
||||
|
||||
paste_fields = paste_fields + [(override_settings_component, paste_settings)]
|
||||
|
||||
button.click(
|
||||
fn=paste_func,
|
||||
inputs=[input_comp],
|
||||
outputs=[x[0] for x in paste_fields],
|
||||
show_progress=False,
|
||||
)
|
||||
button.click(
|
||||
fn=None,
|
||||
_js=f"recalculate_prompts_{tabname}",
|
||||
inputs=[],
|
||||
outputs=[],
|
||||
show_progress=False,
|
||||
)
|
||||
|
||||
46
modules/infotext_versions.py
Executable file
@@ -0,0 +1,46 @@
|
||||
from modules import shared
|
||||
from packaging import version
|
||||
import re
|
||||
|
||||
|
||||
v160 = version.parse("1.6.0")
|
||||
v170_tsnr = version.parse("v1.7.0-225")
|
||||
v180 = version.parse("1.8.0")
|
||||
v180_hr_styles = version.parse("1.8.0-139")
|
||||
|
||||
|
||||
def parse_version(text):
|
||||
if text is None:
|
||||
return None
|
||||
|
||||
m = re.match(r'([^-]+-[^-]+)-.*', text)
|
||||
if m:
|
||||
text = m.group(1)
|
||||
|
||||
try:
|
||||
return version.parse(text)
|
||||
except Exception:
|
||||
return None
|
||||
|
||||
|
||||
def backcompat(d):
|
||||
"""Checks infotext Version field, and enables backwards compatibility options according to it."""
|
||||
|
||||
if not shared.opts.auto_backcompat:
|
||||
return
|
||||
|
||||
ver = parse_version(d.get("Version"))
|
||||
if ver is None:
|
||||
return
|
||||
|
||||
if ver < v160 and '[' in d.get('Prompt', ''):
|
||||
d["Old prompt editing timelines"] = True
|
||||
|
||||
if ver < v160 and d.get('Sampler', '') in ('DDIM', 'PLMS'):
|
||||
d["Pad conds v0"] = True
|
||||
|
||||
if ver < v170_tsnr:
|
||||
d["Downcast alphas_cumprod"] = True
|
||||
|
||||
if ver < v180 and d.get('Refiner'):
|
||||
d["Refiner switch by sampling steps"] = True
|
||||
136
modules/initialize.py
Executable file
@@ -0,0 +1,136 @@
|
||||
import importlib
|
||||
import logging
|
||||
import sys
|
||||
import warnings
|
||||
import os
|
||||
|
||||
from modules.timer import startup_timer
|
||||
|
||||
|
||||
def imports():
|
||||
logging.getLogger("torch.distributed.nn").setLevel(logging.ERROR) # sshh...
|
||||
logging.getLogger("xformers").addFilter(lambda record: 'A matching Triton is not available' not in record.getMessage())
|
||||
|
||||
import torch # noqa: F401
|
||||
startup_timer.record("import torch")
|
||||
import pytorch_lightning # noqa: F401
|
||||
startup_timer.record("import torch")
|
||||
warnings.filterwarnings(action="ignore", category=DeprecationWarning, module="pytorch_lightning")
|
||||
warnings.filterwarnings(action="ignore", category=UserWarning, module="torchvision")
|
||||
|
||||
os.environ.setdefault('GRADIO_ANALYTICS_ENABLED', 'False')
|
||||
import gradio # noqa: F401
|
||||
startup_timer.record("import gradio")
|
||||
|
||||
from modules import paths, timer, import_hook, errors # noqa: F401
|
||||
startup_timer.record("setup paths")
|
||||
|
||||
from modules import shared_init
|
||||
shared_init.initialize()
|
||||
startup_timer.record("initialize shared")
|
||||
|
||||
from modules import processing, gradio_extensions, ui # noqa: F401
|
||||
startup_timer.record("other imports")
|
||||
|
||||
|
||||
def check_versions():
|
||||
from modules.shared_cmd_options import cmd_opts
|
||||
|
||||
if not cmd_opts.skip_version_check:
|
||||
from modules import errors
|
||||
errors.check_versions()
|
||||
|
||||
|
||||
def initialize():
|
||||
from modules import initialize_util
|
||||
initialize_util.fix_torch_version()
|
||||
initialize_util.fix_pytorch_lightning()
|
||||
initialize_util.fix_asyncio_event_loop_policy()
|
||||
initialize_util.validate_tls_options()
|
||||
initialize_util.configure_sigint_handler()
|
||||
initialize_util.configure_opts_onchange()
|
||||
|
||||
from modules import sd_models
|
||||
sd_models.setup_model()
|
||||
startup_timer.record("setup SD model")
|
||||
|
||||
from modules.shared_cmd_options import cmd_opts
|
||||
|
||||
from modules import codeformer_model
|
||||
warnings.filterwarnings(action="ignore", category=UserWarning, module="torchvision.transforms.functional_tensor")
|
||||
codeformer_model.setup_model(cmd_opts.codeformer_models_path)
|
||||
startup_timer.record("setup codeformer")
|
||||
|
||||
from modules import gfpgan_model
|
||||
gfpgan_model.setup_model(cmd_opts.gfpgan_models_path)
|
||||
startup_timer.record("setup gfpgan")
|
||||
|
||||
initialize_rest(reload_script_modules=False)
|
||||
|
||||
|
||||
def initialize_rest(*, reload_script_modules=False):
|
||||
"""
|
||||
Called both from initialize() and when reloading the webui.
|
||||
"""
|
||||
from modules.shared_cmd_options import cmd_opts
|
||||
|
||||
from modules import sd_samplers
|
||||
sd_samplers.set_samplers()
|
||||
startup_timer.record("set samplers")
|
||||
|
||||
from modules import extensions
|
||||
extensions.list_extensions()
|
||||
startup_timer.record("list extensions")
|
||||
|
||||
from modules import initialize_util
|
||||
initialize_util.restore_config_state_file()
|
||||
startup_timer.record("restore config state file")
|
||||
|
||||
from modules import shared, upscaler, scripts
|
||||
if cmd_opts.ui_debug_mode:
|
||||
shared.sd_upscalers = upscaler.UpscalerLanczos().scalers
|
||||
scripts.load_scripts()
|
||||
return
|
||||
|
||||
from modules import sd_models
|
||||
sd_models.list_models()
|
||||
startup_timer.record("list SD models")
|
||||
|
||||
from modules import localization
|
||||
localization.list_localizations(cmd_opts.localizations_dir)
|
||||
startup_timer.record("list localizations")
|
||||
|
||||
with startup_timer.subcategory("load scripts"):
|
||||
scripts.load_scripts()
|
||||
|
||||
if reload_script_modules and shared.opts.enable_reloading_ui_scripts:
|
||||
for module in [module for name, module in sys.modules.items() if name.startswith("modules.ui")]:
|
||||
importlib.reload(module)
|
||||
startup_timer.record("reload script modules")
|
||||
|
||||
from modules import modelloader
|
||||
modelloader.load_upscalers()
|
||||
startup_timer.record("load upscalers")
|
||||
|
||||
from modules import sd_vae
|
||||
sd_vae.refresh_vae_list()
|
||||
startup_timer.record("refresh VAE")
|
||||
|
||||
from modules import sd_unet
|
||||
sd_unet.list_unets()
|
||||
startup_timer.record("scripts list_unets")
|
||||
|
||||
from modules import shared_items
|
||||
shared_items.reload_hypernetworks()
|
||||
startup_timer.record("reload hypernetworks")
|
||||
|
||||
from modules import ui_extra_networks
|
||||
ui_extra_networks.initialize()
|
||||
ui_extra_networks.register_default_pages()
|
||||
|
||||
from modules import extra_networks
|
||||
extra_networks.initialize()
|
||||
extra_networks.register_default_extra_networks()
|
||||
startup_timer.record("initialize extra networks")
|
||||
|
||||
return
|
||||
218
modules/initialize_util.py
Executable file
@@ -0,0 +1,218 @@
|
||||
import json
|
||||
import os
|
||||
import signal
|
||||
import sys
|
||||
import re
|
||||
|
||||
import starlette
|
||||
|
||||
from modules.timer import startup_timer
|
||||
|
||||
|
||||
def gradio_server_name():
|
||||
from modules.shared_cmd_options import cmd_opts
|
||||
|
||||
if cmd_opts.server_name:
|
||||
return cmd_opts.server_name
|
||||
else:
|
||||
return "0.0.0.0" if cmd_opts.listen else None
|
||||
|
||||
|
||||
def fix_torch_version():
|
||||
import torch
|
||||
|
||||
# Truncate version number of nightly/local build of PyTorch to not cause exceptions with CodeFormer or Safetensors
|
||||
if ".dev" in torch.__version__ or "+git" in torch.__version__:
|
||||
torch.__long_version__ = torch.__version__
|
||||
torch.__version__ = re.search(r'[\d.]+[\d]', torch.__version__).group(0)
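# Illustrative sketch (not part of the original file): fix_torch_version() only trims the
# nightly/local build suffix so downstream version checks see a plain "major.minor.patch" string.
# The sample version string is hypothetical.
if __name__ == "__main__":
    sample = "2.3.0.dev20240101+cu121"
    print(re.search(r'[\d.]+[\d]', sample).group(0))   # 2.3.0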
|
||||
|
||||
def fix_pytorch_lightning():
|
||||
# Checks if pytorch_lightning.utilities.distributed already exists in the sys.modules cache
|
||||
if 'pytorch_lightning.utilities.distributed' not in sys.modules:
|
||||
import pytorch_lightning
|
||||
# Let the user know the module was not found, then alias it to pytorch_lightning.utilities.rank_zero
|
||||
print("Pytorch_lightning.distributed not found, attempting pytorch_lightning.rank_zero")
|
||||
sys.modules["pytorch_lightning.utilities.distributed"] = pytorch_lightning.utilities.rank_zero
|
||||
|
||||
def fix_asyncio_event_loop_policy():
|
||||
"""
|
||||
The default `asyncio` event loop policy only automatically creates
|
||||
event loops in the main threads. Other threads must create event
|
||||
loops explicitly or `asyncio.get_event_loop` (and therefore
|
||||
`.IOLoop.current`) will fail. Installing this policy allows event
|
||||
loops to be created automatically on any thread, matching the
|
||||
behavior of Tornado versions prior to 5.0 (or 5.0 on Python 2).
|
||||
"""
|
||||
|
||||
import asyncio
|
||||
|
||||
if sys.platform == "win32" and hasattr(asyncio, "WindowsSelectorEventLoopPolicy"):
|
||||
# "Any thread" and "selector" should be orthogonal, but there's not a clean
|
||||
# interface for composing policies so pick the right base.
|
||||
_BasePolicy = asyncio.WindowsSelectorEventLoopPolicy # type: ignore
|
||||
else:
|
||||
_BasePolicy = asyncio.DefaultEventLoopPolicy
|
||||
|
||||
class AnyThreadEventLoopPolicy(_BasePolicy): # type: ignore
|
||||
"""Event loop policy that allows loop creation on any thread.
|
||||
Usage::
|
||||
|
||||
asyncio.set_event_loop_policy(AnyThreadEventLoopPolicy())
|
||||
"""
|
||||
|
||||
def get_event_loop(self) -> asyncio.AbstractEventLoop:
|
||||
try:
|
||||
return super().get_event_loop()
|
||||
except (RuntimeError, AssertionError):
|
||||
# This was an AssertionError in python 3.4.2 (which ships with debian jessie)
|
||||
# and changed to a RuntimeError in 3.4.3.
|
||||
# "There is no current event loop in thread %r"
|
||||
loop = self.new_event_loop()
|
||||
self.set_event_loop(loop)
|
||||
return loop
|
||||
|
||||
asyncio.set_event_loop_policy(AnyThreadEventLoopPolicy())
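# Illustrative sketch (not part of the original file, assumes the webui environment is importable):
# with AnyThreadEventLoopPolicy installed, asyncio.get_event_loop() succeeds from a worker thread
# instead of raising "There is no current event loop in thread ...".
if __name__ == "__main__":
    import asyncio
    import threading

    fix_asyncio_event_loop_policy()

    def worker():
        loop = asyncio.get_event_loop()     # created on demand by the policy above
        print("worker thread got loop:", type(loop).__name__)
        loop.close()

    t = threading.Thread(target=worker)
    t.start()
    t.join()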
|
||||
|
||||
|
||||
def restore_config_state_file():
|
||||
from modules import shared, config_states
|
||||
|
||||
config_state_file = shared.opts.restore_config_state_file
|
||||
if config_state_file == "":
|
||||
return
|
||||
|
||||
shared.opts.restore_config_state_file = ""
|
||||
shared.opts.save(shared.config_filename)
|
||||
|
||||
if os.path.isfile(config_state_file):
|
||||
print(f"*** About to restore extension state from file: {config_state_file}")
|
||||
with open(config_state_file, "r", encoding="utf-8") as f:
|
||||
config_state = json.load(f)
|
||||
config_states.restore_extension_config(config_state)
|
||||
startup_timer.record("restore extension config")
|
||||
elif config_state_file:
|
||||
print(f"!!! Config state backup not found: {config_state_file}")
|
||||
|
||||
|
||||
def validate_tls_options():
|
||||
from modules.shared_cmd_options import cmd_opts
|
||||
|
||||
if not (cmd_opts.tls_keyfile and cmd_opts.tls_certfile):
|
||||
return
|
||||
|
||||
try:
|
||||
if not os.path.exists(cmd_opts.tls_keyfile):
|
||||
print("Invalid path to TLS keyfile given")
|
||||
if not os.path.exists(cmd_opts.tls_certfile):
|
||||
print(f"Invalid path to TLS certfile: '{cmd_opts.tls_certfile}'")
|
||||
except TypeError:
|
||||
cmd_opts.tls_keyfile = cmd_opts.tls_certfile = None
|
||||
print("TLS setup invalid, running webui without TLS")
|
||||
else:
|
||||
print("Running with TLS")
|
||||
startup_timer.record("TLS")
|
||||
|
||||
|
||||
def get_gradio_auth_creds():
|
||||
"""
|
||||
Convert the gradio_auth and gradio_auth_path commandline arguments into
|
||||
an iterable of (username, password) tuples.
|
||||
"""
|
||||
from modules.shared_cmd_options import cmd_opts
|
||||
|
||||
def process_credential_line(s):
|
||||
s = s.strip()
|
||||
if not s:
|
||||
return None
|
||||
return tuple(s.split(':', 1))
|
||||
|
||||
if cmd_opts.gradio_auth:
|
||||
for cred in cmd_opts.gradio_auth.split(','):
|
||||
cred = process_credential_line(cred)
|
||||
if cred:
|
||||
yield cred
|
||||
|
||||
if cmd_opts.gradio_auth_path:
|
||||
with open(cmd_opts.gradio_auth_path, 'r', encoding="utf8") as file:
|
||||
for line in file.readlines():
|
||||
for cred in line.strip().split(','):
|
||||
cred = process_credential_line(cred)
|
||||
if cred:
|
||||
yield cred
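# Illustrative sketch (not part of the original file): --gradio-auth takes comma-separated
# "username:password" entries; the parsing boils down to one split per credential. The sample
# credentials are hypothetical.
if __name__ == "__main__":
    raw = "alice:secret,bob:hunter2"
    creds = [tuple(s.strip().split(':', 1)) for s in raw.split(',') if s.strip()]
    print(creds)   # [('alice', 'secret'), ('bob', 'hunter2')]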
|
||||
|
||||
|
||||
def dumpstacks():
|
||||
import threading
|
||||
import traceback
|
||||
|
||||
id2name = {th.ident: th.name for th in threading.enumerate()}
|
||||
code = []
|
||||
for threadId, stack in sys._current_frames().items():
|
||||
code.append(f"\n# Thread: {id2name.get(threadId, '')}({threadId})")
|
||||
for filename, lineno, name, line in traceback.extract_stack(stack):
|
||||
code.append(f"""File: "{filename}", line {lineno}, in {name}""")
|
||||
if line:
|
||||
code.append(" " + line.strip())
|
||||
|
||||
print("\n".join(code))
|
||||
|
||||
|
||||
def configure_sigint_handler():
|
||||
# make the program just exit at ctrl+c without waiting for anything
|
||||
|
||||
from modules import shared
|
||||
|
||||
def sigint_handler(sig, frame):
|
||||
print(f'Interrupted with signal {sig} in {frame}')
|
||||
|
||||
if shared.opts.dump_stacks_on_signal:
|
||||
dumpstacks()
|
||||
|
||||
os._exit(0)
|
||||
|
||||
if not os.environ.get("COVERAGE_RUN"):
|
||||
# Don't install the immediate-quit handler when running under coverage,
|
||||
# as then the coverage report won't be generated.
|
||||
signal.signal(signal.SIGINT, sigint_handler)
|
||||
|
||||
|
||||
def configure_opts_onchange():
|
||||
from modules import shared, sd_models, sd_vae, ui_tempdir
|
||||
from modules.call_queue import wrap_queued_call
|
||||
from modules_forge import main_thread
|
||||
|
||||
# shared.opts.onchange("sd_model_checkpoint", wrap_queued_call(lambda: main_thread.run_and_wait_result(sd_models.reload_model_weights)), call=False)
|
||||
# shared.opts.onchange("sd_vae", wrap_queued_call(lambda: main_thread.run_and_wait_result(sd_vae.reload_vae_weights)), call=False)
|
||||
shared.opts.onchange("sd_vae_overrides_per_model_preferences", wrap_queued_call(lambda: main_thread.run_and_wait_result(sd_vae.reload_vae_weights)), call=False)
|
||||
shared.opts.onchange("temp_dir", ui_tempdir.on_tmpdir_changed)
|
||||
shared.opts.onchange("gradio_theme", shared.reload_gradio_theme)
|
||||
# shared.opts.onchange("cross_attention_optimization", wrap_queued_call(lambda: sd_hijack.model_hijack.redo_hijack(shared.sd_model)), call=False)
|
||||
# shared.opts.onchange("fp8_storage", wrap_queued_call(lambda: sd_models.reload_model_weights()), call=False)
|
||||
# shared.opts.onchange("cache_fp16_weight", wrap_queued_call(lambda: sd_models.reload_model_weights(forced_reload=True)), call=False)
|
||||
startup_timer.record("opts onchange")
|
||||
|
||||
|
||||
def setup_middleware(app):
|
||||
from starlette.middleware.gzip import GZipMiddleware
|
||||
|
||||
app.user_middleware.insert(0, starlette.middleware.Middleware(GZipMiddleware, minimum_size=1000))
|
||||
configure_cors_middleware(app)
|
||||
app.build_middleware_stack() # rebuild middleware stack on-the-fly
|
||||
|
||||
|
||||
def configure_cors_middleware(app):
|
||||
from starlette.middleware.cors import CORSMiddleware
|
||||
from modules.shared_cmd_options import cmd_opts
|
||||
|
||||
cors_options = {
|
||||
"allow_methods": ["*"],
|
||||
"allow_headers": ["*"],
|
||||
"allow_credentials": True,
|
||||
}
|
||||
if cmd_opts.cors_allow_origins:
|
||||
cors_options["allow_origins"] = cmd_opts.cors_allow_origins.split(',')
|
||||
if cmd_opts.cors_allow_origins_regex:
|
||||
cors_options["allow_origin_regex"] = cmd_opts.cors_allow_origins_regex
|
||||
|
||||
app.user_middleware.insert(0, starlette.middleware.Middleware(CORSMiddleware, **cors_options))
|
||||
|
||||
219
modules/interrogate.py
Executable file
@@ -0,0 +1,219 @@
|
||||
import os
|
||||
import sys
|
||||
from collections import namedtuple
|
||||
from pathlib import Path
|
||||
import re
|
||||
|
||||
import torch
|
||||
import torch.hub
|
||||
|
||||
from torchvision import transforms
|
||||
from torchvision.transforms.functional import InterpolationMode
|
||||
|
||||
from modules import devices, paths, shared, modelloader, errors
|
||||
from backend import memory_management
|
||||
from backend.patcher.base import ModelPatcher
|
||||
|
||||
|
||||
blip_image_eval_size = 384
|
||||
clip_model_name = 'ViT-L/14'
|
||||
|
||||
Category = namedtuple("Category", ["name", "topn", "items"])
|
||||
|
||||
re_topn = re.compile(r"\.top(\d+)$")
|
||||
|
||||
def category_types():
|
||||
return [f.stem for f in Path(shared.interrogator.content_dir).glob('*.txt')]
|
||||
|
||||
|
||||
def download_default_clip_interrogate_categories(content_dir):
|
||||
print("Downloading CLIP categories...")
|
||||
|
||||
tmpdir = f"{content_dir}_tmp"
|
||||
category_types = ["artists", "flavors", "mediums", "movements"]
|
||||
|
||||
try:
|
||||
os.makedirs(tmpdir, exist_ok=True)
|
||||
for category_type in category_types:
|
||||
torch.hub.download_url_to_file(f"https://raw.githubusercontent.com/pharmapsychotic/clip-interrogator/main/clip_interrogator/data/{category_type}.txt", os.path.join(tmpdir, f"{category_type}.txt"))
|
||||
os.rename(tmpdir, content_dir)
|
||||
|
||||
except Exception as e:
|
||||
errors.display(e, "downloading default CLIP interrogate categories")
|
||||
finally:
|
||||
if os.path.exists(tmpdir):
|
||||
os.removedirs(tmpdir)
|
||||
|
||||
|
||||
class InterrogateModels:
|
||||
blip_model = None
|
||||
clip_model = None
|
||||
clip_preprocess = None
|
||||
dtype = None
|
||||
running_on_cpu = None
|
||||
|
||||
def __init__(self, content_dir):
|
||||
self.loaded_categories = None
|
||||
self.skip_categories = []
|
||||
self.content_dir = content_dir
|
||||
|
||||
self.load_device = memory_management.text_encoder_device()
|
||||
self.offload_device = memory_management.text_encoder_offload_device()
|
||||
self.dtype = torch.float32
|
||||
|
||||
if memory_management.should_use_fp16(device=self.load_device):
|
||||
self.dtype = torch.float16
|
||||
|
||||
self.blip_patcher = None
|
||||
self.clip_patcher = None
|
||||
|
||||
def categories(self):
|
||||
if not os.path.exists(self.content_dir):
|
||||
download_default_clip_interrogate_categories(self.content_dir)
|
||||
|
||||
if self.loaded_categories is not None and self.skip_categories == shared.opts.interrogate_clip_skip_categories:
|
||||
return self.loaded_categories
|
||||
|
||||
self.loaded_categories = []
|
||||
|
||||
if os.path.exists(self.content_dir):
|
||||
self.skip_categories = shared.opts.interrogate_clip_skip_categories
|
||||
category_types = []
|
||||
for filename in Path(self.content_dir).glob('*.txt'):
|
||||
category_types.append(filename.stem)
|
||||
if filename.stem in self.skip_categories:
|
||||
continue
|
||||
m = re_topn.search(filename.stem)
|
||||
topn = 1 if m is None else int(m.group(1))
|
||||
with open(filename, "r", encoding="utf8") as file:
|
||||
lines = [x.strip() for x in file.readlines()]
|
||||
|
||||
self.loaded_categories.append(Category(name=filename.stem, topn=topn, items=lines))
|
||||
|
||||
return self.loaded_categories
|
||||
|
||||
def create_fake_fairscale(self):
|
||||
class FakeFairscale:
|
||||
def checkpoint_wrapper(self):
|
||||
pass
|
||||
|
||||
sys.modules["fairscale.nn.checkpoint.checkpoint_activations"] = FakeFairscale
|
||||
|
||||
def load_blip_model(self):
|
||||
self.create_fake_fairscale()
|
||||
import models.blip
|
||||
|
||||
files = modelloader.load_models(
|
||||
model_path=os.path.join(paths.models_path, "BLIP"),
|
||||
model_url='https://storage.googleapis.com/sfr-vision-language-research/BLIP/models/model_base_caption_capfilt_large.pth',
|
||||
ext_filter=[".pth"],
|
||||
download_name='model_base_caption_capfilt_large.pth',
|
||||
)
|
||||
|
||||
blip_model = models.blip.blip_decoder(pretrained=files[0], image_size=blip_image_eval_size, vit='base', med_config=os.path.join(paths.paths["BLIP"], "configs", "med_config.json"))
|
||||
blip_model.eval()
|
||||
|
||||
return blip_model
|
||||
|
||||
def load_clip_model(self):
|
||||
import clip
|
||||
import clip.model
|
||||
|
||||
clip.model.LayerNorm = torch.nn.LayerNorm
|
||||
|
||||
model, preprocess = clip.load(clip_model_name, device="cpu", download_root=shared.cmd_opts.clip_models_path)
|
||||
model.eval()
|
||||
|
||||
return model, preprocess
|
||||
|
||||
def load(self):
|
||||
if self.blip_model is None:
|
||||
self.blip_model = self.load_blip_model()
|
||||
self.blip_model = self.blip_model.to(device=self.offload_device, dtype=self.dtype)
|
||||
self.blip_patcher = ModelPatcher(self.blip_model, load_device=self.load_device, offload_device=self.offload_device)
|
||||
|
||||
if self.clip_model is None:
|
||||
self.clip_model, self.clip_preprocess = self.load_clip_model()
|
||||
self.clip_model = self.clip_model.to(device=self.offload_device, dtype=self.dtype)
|
||||
self.clip_patcher = ModelPatcher(self.clip_model, load_device=self.load_device, offload_device=self.offload_device)
|
||||
|
||||
memory_management.load_models_gpu([self.blip_patcher, self.clip_patcher])
|
||||
return
|
||||
|
||||
def send_clip_to_ram(self):
|
||||
pass
|
||||
|
||||
def send_blip_to_ram(self):
|
||||
pass
|
||||
|
||||
def unload(self):
|
||||
pass
|
||||
|
||||
def rank(self, image_features, text_array, top_count=1):
|
||||
import clip
|
||||
|
||||
devices.torch_gc()
|
||||
|
||||
if shared.opts.interrogate_clip_dict_limit != 0:
|
||||
text_array = text_array[0:int(shared.opts.interrogate_clip_dict_limit)]
|
||||
|
||||
top_count = min(top_count, len(text_array))
|
||||
text_tokens = clip.tokenize(list(text_array), truncate=True).to(self.load_device)
|
||||
text_features = self.clip_model.encode_text(text_tokens).type(self.dtype)
|
||||
text_features /= text_features.norm(dim=-1, keepdim=True)
|
||||
|
||||
similarity = torch.zeros((1, len(text_array))).to(self.load_device)
|
||||
for i in range(image_features.shape[0]):
|
||||
similarity += (100.0 * image_features[i].unsqueeze(0) @ text_features.T).softmax(dim=-1)
|
||||
similarity /= image_features.shape[0]
|
||||
|
||||
top_probs, top_labels = similarity.cpu().topk(top_count, dim=-1)
|
||||
return [(text_array[top_labels[0][i].numpy()], (top_probs[0][i].numpy()*100)) for i in range(top_count)]
|
||||
|
||||
def generate_caption(self, pil_image):
|
||||
gpu_image = transforms.Compose([
|
||||
transforms.Resize((blip_image_eval_size, blip_image_eval_size), interpolation=InterpolationMode.BICUBIC),
|
||||
transforms.ToTensor(),
|
||||
transforms.Normalize((0.48145466, 0.4578275, 0.40821073), (0.26862954, 0.26130258, 0.27577711))
|
||||
])(pil_image).unsqueeze(0).type(self.dtype).to(self.load_device)
|
||||
|
||||
with torch.no_grad():
|
||||
caption = self.blip_model.generate(gpu_image, sample=False, num_beams=int(shared.opts.interrogate_clip_num_beams), min_length=int(shared.opts.interrogate_clip_min_length), max_length=shared.opts.interrogate_clip_max_length)
|
||||
|
||||
return caption[0]
|
||||
|
||||
def interrogate(self, pil_image):
|
||||
res = ""
|
||||
shared.state.begin(job="interrogate")
|
||||
try:
|
||||
self.load()
|
||||
|
||||
caption = self.generate_caption(pil_image)
|
||||
self.send_blip_to_ram()
|
||||
devices.torch_gc()
|
||||
|
||||
res = caption
|
||||
|
||||
clip_image = self.clip_preprocess(pil_image).unsqueeze(0).type(self.dtype).to(self.load_device)
|
||||
|
||||
with torch.no_grad(), devices.autocast():
|
||||
image_features = self.clip_model.encode_image(clip_image).type(self.dtype)
|
||||
|
||||
image_features /= image_features.norm(dim=-1, keepdim=True)
|
||||
|
||||
for cat in self.categories():
|
||||
matches = self.rank(image_features, cat.items, top_count=cat.topn)
|
||||
for match, score in matches:
|
||||
if shared.opts.interrogate_return_ranks:
|
||||
res += f", ({match}:{score/100:.3f})"
|
||||
else:
|
||||
res += f", {match}"
|
||||
|
||||
except Exception:
|
||||
errors.report("Error interrogating", exc_info=True)
|
||||
res += "<error>"
|
||||
|
||||
self.unload()
|
||||
shared.state.end()
|
||||
|
||||
return res
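# Illustrative sketch (not part of the original file): rank() scores candidate texts by cosine
# similarity between L2-normalized image and text features, softmaxed over the candidate list.
# Random tensors stand in for real CLIP features here.
if __name__ == "__main__":
    demo_image = torch.randn(1, 512)
    demo_texts = torch.randn(4, 512)                        # four candidate category entries
    demo_image /= demo_image.norm(dim=-1, keepdim=True)
    demo_texts /= demo_texts.norm(dim=-1, keepdim=True)
    similarity = (100.0 * demo_image @ demo_texts.T).softmax(dim=-1)
    top_probs, top_labels = similarity.topk(2, dim=-1)      # best two candidates, as in rank()
    print(top_labels[0].tolist(), (top_probs[0] * 100).tolist())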
|
||||
568
modules/launch_utils.py
Executable file
@@ -0,0 +1,568 @@
|
||||
# this script installs necessary requirements and launches the main program in webui.py
|
||||
import logging
|
||||
import re
|
||||
import subprocess
|
||||
import os
|
||||
import shutil
|
||||
import sys
|
||||
import importlib.util
|
||||
import importlib.metadata
|
||||
import platform
|
||||
import json
|
||||
import shlex
|
||||
from functools import lru_cache
|
||||
from typing import NamedTuple
|
||||
from pathlib import Path
|
||||
|
||||
from modules import cmd_args, errors
|
||||
from modules.paths_internal import script_path, extensions_dir, extensions_builtin_dir
|
||||
from modules.timer import startup_timer
|
||||
from modules import logging_config
|
||||
from modules_forge import forge_version
|
||||
from modules_forge.config import always_disabled_extensions
|
||||
|
||||
|
||||
args, _ = cmd_args.parser.parse_known_args()
|
||||
logging_config.setup_logging(args.loglevel)
|
||||
|
||||
python = sys.executable
|
||||
git = os.environ.get('GIT', "git")
|
||||
index_url = os.environ.get('INDEX_URL', "")
|
||||
dir_repos = "repositories"
|
||||
|
||||
# Whether to default to printing command output
|
||||
default_command_live = (os.environ.get('WEBUI_LAUNCH_LIVE_OUTPUT') == "1")
|
||||
|
||||
os.environ.setdefault('GRADIO_ANALYTICS_ENABLED', 'False')
|
||||
|
||||
|
||||
def check_python_version():
|
||||
is_windows = platform.system() == "Windows"
|
||||
major = sys.version_info.major
|
||||
minor = sys.version_info.minor
|
||||
micro = sys.version_info.micro
|
||||
|
||||
if is_windows:
|
||||
supported_minors = [10]
|
||||
else:
|
||||
supported_minors = [7, 8, 9, 10, 11]
|
||||
|
||||
if not (major == 3 and minor in supported_minors):
|
||||
import modules.errors
|
||||
|
||||
modules.errors.print_error_explanation(f"""
|
||||
INCOMPATIBLE PYTHON VERSION
|
||||
|
||||
This program is tested with 3.10.6 Python, but you have {major}.{minor}.{micro}.
|
||||
If you encounter an error with "RuntimeError: Couldn't install torch." message,
|
||||
or any other error regarding unsuccessful package (library) installation,
|
||||
please downgrade (or upgrade) to the latest version of Python 3.10
|
||||
and delete current Python and "venv" folder in WebUI's directory.
|
||||
|
||||
You can download 3.10 Python from here: https://www.python.org/downloads/release/python-3106/
|
||||
|
||||
{"Alternatively, use a binary release of WebUI: https://github.com/AUTOMATIC1111/stable-diffusion-webui/releases/tag/v1.0.0-pre" if is_windows else ""}
|
||||
|
||||
Use --skip-python-version-check to suppress this warning.
|
||||
""")
|
||||
|
||||
|
||||
@lru_cache()
|
||||
def commit_hash():
|
||||
try:
|
||||
return subprocess.check_output([git, "-C", script_path, "rev-parse", "HEAD"], shell=False, encoding='utf8').strip()
|
||||
except Exception:
|
||||
return "<none>"
|
||||
|
||||
|
||||
@lru_cache()
|
||||
def git_tag_a1111():
|
||||
try:
|
||||
return subprocess.check_output([git, "-C", script_path, "describe", "--tags"], shell=False, encoding='utf8').strip()
|
||||
except Exception:
|
||||
try:
|
||||
|
||||
changelog_md = os.path.join(script_path, "CHANGELOG.md")
|
||||
with open(changelog_md, "r", encoding="utf-8") as file:
|
||||
line = next((line.strip() for line in file if line.strip()), "<none>")
|
||||
line = line.replace("## ", "")
|
||||
return line
|
||||
except Exception:
|
||||
return "<none>"
|
||||
|
||||
|
||||
def git_tag():
|
||||
return 'f' + forge_version.version + '-' + git_tag_a1111()
|
||||
|
||||
|
||||
def run(command, desc=None, errdesc=None, custom_env=None, live: bool = default_command_live) -> str:
|
||||
if desc is not None:
|
||||
print(desc)
|
||||
|
||||
run_kwargs = {
|
||||
"args": command,
|
||||
"shell": True,
|
||||
"env": os.environ if custom_env is None else custom_env,
|
||||
"encoding": 'utf8',
|
||||
"errors": 'ignore',
|
||||
}
|
||||
|
||||
if not live:
|
||||
run_kwargs["stdout"] = run_kwargs["stderr"] = subprocess.PIPE
|
||||
|
||||
result = subprocess.run(**run_kwargs)
|
||||
|
||||
if result.returncode != 0:
|
||||
error_bits = [
|
||||
f"{errdesc or 'Error running command'}.",
|
||||
f"Command: {command}",
|
||||
f"Error code: {result.returncode}",
|
||||
]
|
||||
if result.stdout:
|
||||
error_bits.append(f"stdout: {result.stdout}")
|
||||
if result.stderr:
|
||||
error_bits.append(f"stderr: {result.stderr}")
|
||||
raise RuntimeError("\n".join(error_bits))
|
||||
|
||||
return (result.stdout or "")
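A minimal usage sketch for run(), relying on the module-level python defined above; the command and descriptions here are illustrative placeholders only:

    # With live=False the output is captured and returned as a string; a non-zero
    # exit code raises RuntimeError carrying the command, exit code, stdout and stderr.
    out = run(f'"{python}" --version', desc="Checking Python", errdesc="Couldn't run Python", live=False)
    print(out.strip())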
|
||||
|
||||
|
||||
def is_installed(package):
|
||||
try:
|
||||
dist = importlib.metadata.distribution(package)
|
||||
except importlib.metadata.PackageNotFoundError:
|
||||
try:
|
||||
spec = importlib.util.find_spec(package)
|
||||
except ModuleNotFoundError:
|
||||
return False
|
||||
|
||||
return spec is not None
|
||||
|
||||
return dist is not None
|
||||
|
||||
|
||||
def repo_dir(name):
|
||||
return os.path.join(script_path, dir_repos, name)
|
||||
|
||||
|
||||
def run_pip(command, desc=None, live=default_command_live):
|
||||
if args.skip_install:
|
||||
return
|
||||
|
||||
index_url_line = f' --index-url {index_url}' if index_url != '' else ''
|
||||
return run(f'"{python}" -m pip {command} --prefer-binary{index_url_line}', desc=f"Installing {desc}", errdesc=f"Couldn't install {desc}", live=live)
|
||||
|
||||
|
||||
def check_run_python(code: str) -> bool:
|
||||
result = subprocess.run([python, "-c", code], capture_output=True, shell=False)
|
||||
return result.returncode == 0
|
||||
|
||||
|
||||
def git_fix_workspace(dir, name):
|
||||
run(f'"{git}" -C "{dir}" fetch --refetch --no-auto-gc', f"Fetching all contents for {name}", f"Couldn't fetch {name}", live=True)
|
||||
run(f'"{git}" -C "{dir}" gc --aggressive --prune=now', f"Pruning {name}", f"Couldn't prune {name}", live=True)
|
||||
return
|
||||
|
||||
|
||||
def run_git(dir, name, command, desc=None, errdesc=None, custom_env=None, live: bool = default_command_live, autofix=True):
|
||||
try:
|
||||
return run(f'"{git}" -C "{dir}" {command}', desc=desc, errdesc=errdesc, custom_env=custom_env, live=live)
|
||||
except RuntimeError:
|
||||
if not autofix:
|
||||
raise
|
||||
|
||||
print(f"{errdesc}, attempting autofix...")
|
||||
git_fix_workspace(dir, name)
|
||||
|
||||
return run(f'"{git}" -C "{dir}" {command}', desc=desc, errdesc=errdesc, custom_env=custom_env, live=live)
|
||||
|
||||
|
||||
def git_clone(url, dir, name, commithash=None):
|
||||
# TODO clone into temporary dir and move if successful
|
||||
|
||||
if os.path.exists(dir):
|
||||
if commithash is None:
|
||||
return
|
||||
|
||||
current_hash = run_git(dir, name, 'rev-parse HEAD', None, f"Couldn't determine {name}'s hash: {commithash}", live=False).strip()
|
||||
if current_hash == commithash:
|
||||
return
|
||||
|
||||
if run_git(dir, name, 'config --get remote.origin.url', None, f"Couldn't determine {name}'s origin URL", live=False).strip() != url:
|
||||
run_git(dir, name, f'remote set-url origin "{url}"', None, f"Failed to set {name}'s origin URL", live=False)
|
||||
|
||||
run_git(dir, name, 'fetch', f"Fetching updates for {name}...", f"Couldn't fetch {name}", autofix=False)
|
||||
|
||||
run_git(dir, name, f'checkout {commithash}', f"Checking out commit for {name} with hash: {commithash}...", f"Couldn't checkout commit {commithash} for {name}", live=True)
|
||||
|
||||
return
|
||||
|
||||
try:
|
||||
run(f'"{git}" clone --config core.filemode=false "{url}" "{dir}"', f"Cloning {name} into {dir}...", f"Couldn't clone {name}", live=True)
|
||||
except RuntimeError:
|
||||
shutil.rmtree(dir, ignore_errors=True)
|
||||
raise
|
||||
|
||||
if commithash is not None:
|
||||
run(f'"{git}" -C "{dir}" checkout {commithash}', None, "Couldn't checkout {name}'s hash: {commithash}")
|
||||
|
||||
|
||||
def git_pull_recursive(dir):
|
||||
for subdir, _, _ in os.walk(dir):
|
||||
if os.path.exists(os.path.join(subdir, '.git')):
|
||||
try:
|
||||
output = subprocess.check_output([git, '-C', subdir, 'pull', '--autostash'])
|
||||
print(f"Pulled changes for repository in '{subdir}':\n{output.decode('utf-8').strip()}\n")
|
||||
except subprocess.CalledProcessError as e:
|
||||
print(f"Couldn't perform 'git pull' on repository in '{subdir}':\n{e.output.decode('utf-8').strip()}\n")
|
||||
|
||||
|
||||
def version_check(commit):
|
||||
try:
|
||||
import requests
|
||||
commits = requests.get('https://api.github.com/repos/AUTOMATIC1111/stable-diffusion-webui/branches/master').json()
|
||||
if commit != "<none>" and commits['commit']['sha'] != commit:
|
||||
print("--------------------------------------------------------")
|
||||
print("| You are not up to date with the most recent release. |")
|
||||
print("| Consider running `git pull` to update. |")
|
||||
print("--------------------------------------------------------")
|
||||
elif commits['commit']['sha'] == commit:
|
||||
print("You are up to date with the most recent release.")
|
||||
else:
|
||||
print("Not a git clone, can't perform version check.")
|
||||
except Exception as e:
|
||||
print("version check failed", e)
|
||||
|
||||
|
||||
def run_extension_installer(extension_dir):
|
||||
path_installer = os.path.join(extension_dir, "install.py")
|
||||
if not os.path.isfile(path_installer):
|
||||
return
|
||||
|
||||
try:
|
||||
env = os.environ.copy()
|
||||
env['PYTHONPATH'] = f"{script_path}{os.pathsep}{env.get('PYTHONPATH', '')}"
|
||||
|
||||
stdout = run(f'"{python}" "{path_installer}"', errdesc=f"Error running install.py for extension {extension_dir}", custom_env=env).strip()
|
||||
if stdout:
|
||||
print(stdout)
|
||||
except Exception as e:
|
||||
errors.report(str(e))
|
||||
|
||||
|
||||
def list_extensions(settings_file):
|
||||
settings = {}
|
||||
|
||||
try:
|
||||
with open(settings_file, "r", encoding="utf8") as file:
|
||||
settings = json.load(file)
|
||||
except FileNotFoundError:
|
||||
pass
|
||||
except Exception:
|
||||
errors.report(f'\nCould not load settings\nThe config file "{settings_file}" is likely corrupted\nIt has been moved to the "tmp/config.json"\nReverting config to default\n\n', exc_info=True)
|
||||
os.replace(settings_file, os.path.join(script_path, "tmp", "config.json"))
|
||||
|
||||
disabled_extensions = set(settings.get('disabled_extensions', []) + always_disabled_extensions)
|
||||
disable_all_extensions = settings.get('disable_all_extensions', 'none')
|
||||
|
||||
if disable_all_extensions != 'none' or args.disable_extra_extensions or args.disable_all_extensions or not os.path.isdir(extensions_dir):
|
||||
return []
|
||||
|
||||
return [x for x in os.listdir(extensions_dir) if x not in disabled_extensions]
|
||||
|
||||
|
||||
def list_extensions_builtin(settings_file):
|
||||
settings = {}
|
||||
|
||||
try:
|
||||
with open(settings_file, "r", encoding="utf8") as file:
|
||||
settings = json.load(file)
|
||||
except FileNotFoundError:
|
||||
pass
|
||||
except Exception:
|
||||
errors.report(f'\nCould not load settings\nThe config file "{settings_file}" is likely corrupted\nIt has been moved to the "tmp/config.json"\nReverting config to default\n\n', exc_info=True)
|
||||
os.replace(settings_file, os.path.join(script_path, "tmp", "config.json"))
|
||||
|
||||
disabled_extensions = set(settings.get('disabled_extensions', []))
|
||||
disable_all_extensions = settings.get('disable_all_extensions', 'none')
|
||||
|
||||
if disable_all_extensions != 'none' or args.disable_extra_extensions or args.disable_all_extensions or not os.path.isdir(extensions_builtin_dir):
|
||||
return []
|
||||
|
||||
return [x for x in os.listdir(extensions_builtin_dir) if x not in disabled_extensions]
|
||||
|
||||
|
||||
def run_extensions_installers(settings_file):
|
||||
if not os.path.isdir(extensions_dir):
|
||||
return
|
||||
|
||||
with startup_timer.subcategory("run extensions installers"):
|
||||
for dirname_extension in list_extensions(settings_file):
|
||||
logging.debug(f"Installing {dirname_extension}")
|
||||
|
||||
path = os.path.join(extensions_dir, dirname_extension)
|
||||
|
||||
if os.path.isdir(path):
|
||||
run_extension_installer(path)
|
||||
startup_timer.record(dirname_extension)
|
||||
|
||||
if not os.path.isdir(extensions_builtin_dir):
|
||||
return
|
||||
|
||||
with startup_timer.subcategory("run extensions_builtin installers"):
|
||||
for dirname_extension in list_extensions_builtin(settings_file):
|
||||
logging.debug(f"Installing {dirname_extension}")
|
||||
|
||||
path = os.path.join(extensions_builtin_dir, dirname_extension)
|
||||
|
||||
if os.path.isdir(path):
|
||||
run_extension_installer(path)
|
||||
startup_timer.record(dirname_extension)
|
||||
|
||||
return
|
||||
|
||||
|
||||
re_requirement = re.compile(r"\s*([-_a-zA-Z0-9]+)\s*(?:==\s*([-+_.a-zA-Z0-9]+))?\s*")
|
||||
|
||||
|
||||
def requirements_met(requirements_file):
|
||||
"""
|
||||
Does a simple parse of a requirements.txt file to determine if all requirements in it
|
||||
are already installed. Returns True if so, False if not installed or parsing fails.
|
||||
"""
|
||||
|
||||
import importlib.metadata
|
||||
import packaging.version
|
||||
|
||||
with open(requirements_file, "r", encoding="utf8") as file:
|
||||
for line in file:
|
||||
if line.strip() == "":
|
||||
continue
|
||||
|
||||
m = re.match(re_requirement, line)
|
||||
if m is None:
|
||||
return False
|
||||
|
||||
package = m.group(1).strip()
|
||||
version_required = (m.group(2) or "").strip()
|
||||
|
||||
if version_required == "":
|
||||
continue
|
||||
|
||||
try:
|
||||
version_installed = importlib.metadata.version(package)
|
||||
except Exception:
|
||||
return False
|
||||
|
||||
if packaging.version.parse(version_required) != packaging.version.parse(version_installed):
|
||||
return False
|
||||
|
||||
return True
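A small sketch of how this check is meant to be combined with run_pip(), mirroring the way prepare_environment() uses it further down in this file; the file name is the default from REQS_FILE:

    # Only invoke pip if the pinned requirements are not already satisfied;
    # requirements_met() returns False on any version mismatch or parse failure.
    reqs = os.path.join(script_path, "requirements_versions.txt")
    if not requirements_met(reqs):
        run_pip(f'install -r "{reqs}"', "requirements")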
|
||||
|
||||
|
||||
def prepare_environment():
|
||||
torch_index_url = os.environ.get('TORCH_INDEX_URL', "https://download.pytorch.org/whl/cu121")
|
||||
torch_command = os.environ.get('TORCH_COMMAND', f"pip install torch==2.3.1 torchvision==0.18.1 --extra-index-url {torch_index_url}")
|
||||
if args.use_ipex:
|
||||
if platform.system() == "Windows":
|
||||
# The "Nuullll/intel-extension-for-pytorch" wheels were built from IPEX source for Intel Arc GPU: https://github.com/intel/intel-extension-for-pytorch/tree/xpu-main
|
||||
# This is NOT an Intel official release so please use it at your own risk!!
|
||||
# See https://github.com/Nuullll/intel-extension-for-pytorch/releases/tag/v2.0.110%2Bxpu-master%2Bdll-bundle for details.
|
||||
#
|
||||
# Strengths (over official IPEX 2.0.110 windows release):
|
||||
# - AOT build (for Arc GPU only) to eliminate JIT compilation overhead: https://github.com/intel/intel-extension-for-pytorch/issues/399
|
||||
# - Bundles minimal oneAPI 2023.2 dependencies into the python wheels, so users don't need to install oneAPI for the whole system.
|
||||
# - Provides a compatible torchvision wheel: https://github.com/intel/intel-extension-for-pytorch/issues/465
|
||||
# Limitation:
|
||||
# - Only works for python 3.10
|
||||
url_prefix = "https://github.com/Nuullll/intel-extension-for-pytorch/releases/download/v2.0.110%2Bxpu-master%2Bdll-bundle"
|
||||
torch_command = os.environ.get('TORCH_COMMAND', f"pip install {url_prefix}/torch-2.0.0a0+gite9ebda2-cp310-cp310-win_amd64.whl {url_prefix}/torchvision-0.15.2a0+fa99a53-cp310-cp310-win_amd64.whl {url_prefix}/intel_extension_for_pytorch-2.0.110+gitc6ea20b-cp310-cp310-win_amd64.whl")
|
||||
else:
|
||||
# Using official IPEX release for linux since it's already an AOT build.
|
||||
# However, users still have to install oneAPI toolkit and activate oneAPI environment manually.
|
||||
# See https://intel.github.io/intel-extension-for-pytorch/index.html#installation for details.
|
||||
torch_index_url = os.environ.get('TORCH_INDEX_URL', "https://pytorch-extension.intel.com/release-whl/stable/xpu/us/")
|
||||
torch_command = os.environ.get('TORCH_COMMAND', f"pip install torch==2.0.0a0 intel-extension-for-pytorch==2.0.110+gitba7f6c1 --extra-index-url {torch_index_url}")
|
||||
requirements_file = os.environ.get('REQS_FILE', "requirements_versions.txt")
|
||||
requirements_file_for_npu = os.environ.get('REQS_FILE_FOR_NPU', "requirements_npu.txt")
|
||||
|
||||
xformers_package = os.environ.get('XFORMERS_PACKAGE', 'xformers==0.0.27')
|
||||
clip_package = os.environ.get('CLIP_PACKAGE', "https://github.com/openai/CLIP/archive/d50d76daa670286dd6cacf3bcd80b5e4823fc8e1.zip")
|
||||
openclip_package = os.environ.get('OPENCLIP_PACKAGE', "https://github.com/mlfoundations/open_clip/archive/bb6e834e9c70d9c27d0dc3ecedeebeaeb1ffad6b.zip")
|
||||
|
||||
assets_repo = os.environ.get('ASSETS_REPO', "https://github.com/AUTOMATIC1111/stable-diffusion-webui-assets.git")
|
||||
# stable_diffusion_repo = os.environ.get('STABLE_DIFFUSION_REPO', "https://github.com/Stability-AI/stablediffusion.git")
|
||||
# stable_diffusion_xl_repo = os.environ.get('STABLE_DIFFUSION_XL_REPO', "https://github.com/Stability-AI/generative-models.git")
|
||||
# k_diffusion_repo = os.environ.get('K_DIFFUSION_REPO', 'https://github.com/crowsonkb/k-diffusion.git')
|
||||
huggingface_guess_repo = os.environ.get('HUGGINGFACE_GUESS_REPO', 'https://github.com/lllyasviel/huggingface_guess.git')
|
||||
blip_repo = os.environ.get('BLIP_REPO', 'https://github.com/salesforce/BLIP.git')
|
||||
|
||||
assets_commit_hash = os.environ.get('ASSETS_COMMIT_HASH', "6f7db241d2f8ba7457bac5ca9753331f0c266917")
|
||||
# stable_diffusion_commit_hash = os.environ.get('STABLE_DIFFUSION_COMMIT_HASH', "cf1d67a6fd5ea1aa600c4df58e5b47da45f6bdbf")
|
||||
# stable_diffusion_xl_commit_hash = os.environ.get('STABLE_DIFFUSION_XL_COMMIT_HASH', "45c443b316737a4ab6e40413d7794a7f5657c19f")
|
||||
# k_diffusion_commit_hash = os.environ.get('K_DIFFUSION_COMMIT_HASH', "ab527a9a6d347f364e3d185ba6d714e22d80cb3c")
|
||||
huggingface_guess_commit_hash = os.environ.get('HUGGINGFACE_GUESS_HASH', "84826248b49bb7ca754c73293299c4d4e23a548d")
|
||||
blip_commit_hash = os.environ.get('BLIP_COMMIT_HASH', "48211a1594f1321b00f14c9f7a5b4813144b2fb9")
|
||||
|
||||
try:
|
||||
# the existence of this file is a signal to webui.sh/bat that webui needs to be restarted when it stops execution
|
||||
os.remove(os.path.join(script_path, "tmp", "restart"))
|
||||
os.environ.setdefault('SD_WEBUI_RESTARTING', '1')
|
||||
except OSError:
|
||||
pass
|
||||
|
||||
if not args.skip_python_version_check:
|
||||
check_python_version()
|
||||
|
||||
startup_timer.record("checks")
|
||||
|
||||
commit = commit_hash()
|
||||
tag = git_tag()
|
||||
startup_timer.record("git version info")
|
||||
|
||||
print(f"Python {sys.version}")
|
||||
print(f"Version: {tag}")
|
||||
print(f"Commit hash: {commit}")
|
||||
|
||||
if args.reinstall_torch or not is_installed("torch") or not is_installed("torchvision"):
|
||||
run(f'"{python}" -m {torch_command}', "Installing torch and torchvision", "Couldn't install torch", live=True)
|
||||
startup_timer.record("install torch")
|
||||
|
||||
if args.use_ipex:
|
||||
args.skip_torch_cuda_test = True
|
||||
if not args.skip_torch_cuda_test and not check_run_python("import torch; assert torch.cuda.is_available()"):
|
||||
raise RuntimeError(
|
||||
'Your device does not support the current version of Torch/CUDA! Consider downloading another version: \n'
|
||||
'https://github.com/lllyasviel/stable-diffusion-webui-forge/releases/tag/latest'
|
||||
# 'Torch is not able to use GPU; '
|
||||
# 'add --skip-torch-cuda-test to COMMANDLINE_ARGS variable to disable this check'
|
||||
)
|
||||
startup_timer.record("torch GPU test")
|
||||
|
||||
if not is_installed("clip"):
|
||||
run_pip(f"install {clip_package}", "clip")
|
||||
startup_timer.record("install clip")
|
||||
|
||||
if not is_installed("open_clip"):
|
||||
run_pip(f"install {openclip_package}", "open_clip")
|
||||
startup_timer.record("install open_clip")
|
||||
|
||||
if (not is_installed("xformers") or args.reinstall_xformers) and args.xformers:
|
||||
run_pip(f"install -U -I --no-deps {xformers_package}", "xformers")
|
||||
startup_timer.record("install xformers")
|
||||
|
||||
if not is_installed("ngrok") and args.ngrok:
|
||||
run_pip("install ngrok", "ngrok")
|
||||
startup_timer.record("install ngrok")
|
||||
|
||||
os.makedirs(os.path.join(script_path, dir_repos), exist_ok=True)
|
||||
|
||||
git_clone(assets_repo, repo_dir('stable-diffusion-webui-assets'), "assets", assets_commit_hash)
|
||||
# git_clone(stable_diffusion_repo, repo_dir('stable-diffusion-stability-ai'), "Stable Diffusion", stable_diffusion_commit_hash)
|
||||
# git_clone(stable_diffusion_xl_repo, repo_dir('generative-models'), "Stable Diffusion XL", stable_diffusion_xl_commit_hash)
|
||||
# git_clone(k_diffusion_repo, repo_dir('k-diffusion'), "K-diffusion", k_diffusion_commit_hash)
|
||||
git_clone(huggingface_guess_repo, repo_dir('huggingface_guess'), "huggingface_guess", huggingface_guess_commit_hash)
|
||||
git_clone(blip_repo, repo_dir('BLIP'), "BLIP", blip_commit_hash)
|
||||
|
||||
startup_timer.record("clone repositores")
|
||||
|
||||
if not os.path.isfile(requirements_file):
|
||||
requirements_file = os.path.join(script_path, requirements_file)
|
||||
|
||||
if not requirements_met(requirements_file):
|
||||
run_pip(f"install -r \"{requirements_file}\"", "requirements")
|
||||
startup_timer.record("install requirements")
|
||||
|
||||
if not os.path.isfile(requirements_file_for_npu):
|
||||
requirements_file_for_npu = os.path.join(script_path, requirements_file_for_npu)
|
||||
|
||||
if "torch_npu" in torch_command and not requirements_met(requirements_file_for_npu):
|
||||
run_pip(f"install -r \"{requirements_file_for_npu}\"", "requirements_for_npu")
|
||||
startup_timer.record("install requirements_for_npu")
|
||||
|
||||
if not args.skip_install:
|
||||
run_extensions_installers(settings_file=args.ui_settings_file)
|
||||
|
||||
if args.update_check:
|
||||
version_check(commit)
|
||||
startup_timer.record("check version")
|
||||
|
||||
if args.update_all_extensions:
|
||||
git_pull_recursive(extensions_dir)
|
||||
startup_timer.record("update extensions")
|
||||
|
||||
if "--exit" in sys.argv:
|
||||
print("Exiting because of --exit argument")
|
||||
exit(0)
|
||||
|
||||
|
||||
def configure_for_tests():
|
||||
if "--api" not in sys.argv:
|
||||
sys.argv.append("--api")
|
||||
if "--ckpt" not in sys.argv:
|
||||
sys.argv.append("--ckpt")
|
||||
sys.argv.append(os.path.join(script_path, "test/test_files/empty.pt"))
|
||||
if "--skip-torch-cuda-test" not in sys.argv:
|
||||
sys.argv.append("--skip-torch-cuda-test")
|
||||
if "--disable-nan-check" not in sys.argv:
|
||||
sys.argv.append("--disable-nan-check")
|
||||
|
||||
os.environ['COMMANDLINE_ARGS'] = ""
|
||||
|
||||
|
||||
def configure_forge_reference_checkout(a1111_home: Path):
|
||||
"""Set model paths based on an existing A1111 checkout."""
|
||||
class ModelRef(NamedTuple):
|
||||
arg_name: str
|
||||
relative_path: str
|
||||
|
||||
refs = [
|
||||
ModelRef(arg_name="--ckpt-dir", relative_path="models/Stable-diffusion"),
|
||||
ModelRef(arg_name="--vae-dir", relative_path="models/VAE"),
|
||||
ModelRef(arg_name="--hypernetwork-dir", relative_path="models/hypernetworks"),
|
||||
ModelRef(arg_name="--embeddings-dir", relative_path="embeddings"),
|
||||
ModelRef(arg_name="--lora-dir", relative_path="models/lora"),
|
||||
# The reference A1111 checkout needs to have sd-webui-controlnet installed.
|
||||
ModelRef(arg_name="--controlnet-dir", relative_path="models/ControlNet"),
|
||||
ModelRef(arg_name="--controlnet-preprocessor-models-dir", relative_path="extensions/sd-webui-controlnet/annotator/downloads"),
|
||||
]
|
||||
|
||||
for ref in refs:
|
||||
target_path = a1111_home / ref.relative_path
|
||||
if not target_path.exists():
|
||||
print(f"Path {target_path} does not exist. Skip setting {ref.arg_name}")
|
||||
continue
|
||||
|
||||
if ref.arg_name in sys.argv:
|
||||
# Do not override existing dir setting.
|
||||
continue
|
||||
|
||||
sys.argv.append(ref.arg_name)
|
||||
sys.argv.append(str(target_path))
|
||||
|
||||
|
||||
def start():
|
||||
print(f"Launching {'API server' if '--nowebui' in sys.argv else 'Web UI'} with arguments: {shlex.join(sys.argv[1:])}")
|
||||
import webui
|
||||
if '--nowebui' in sys.argv:
|
||||
webui.api_only()
|
||||
else:
|
||||
webui.webui()
|
||||
|
||||
from modules_forge import main_thread
|
||||
|
||||
main_thread.loop()
|
||||
return
|
||||
|
||||
|
||||
def dump_sysinfo():
|
||||
from modules import sysinfo
|
||||
import datetime
|
||||
|
||||
text = sysinfo.get()
|
||||
filename = f"sysinfo-{datetime.datetime.now(datetime.timezone.utc).strftime('%Y-%m-%d-%H-%M')}.json"
|
||||
|
||||
with open(filename, "w", encoding="utf8") as file:
|
||||
file.write(text)
|
||||
|
||||
return filename
|
||||
37
modules/localization.py
Executable file
@@ -0,0 +1,37 @@
|
||||
import json
|
||||
import os
|
||||
|
||||
from modules import errors, scripts
|
||||
|
||||
localizations = {}
|
||||
|
||||
|
||||
def list_localizations(dirname):
|
||||
localizations.clear()
|
||||
|
||||
for file in os.listdir(dirname):
|
||||
fn, ext = os.path.splitext(file)
|
||||
if ext.lower() != ".json":
|
||||
continue
|
||||
|
||||
localizations[fn] = [os.path.join(dirname, file)]
|
||||
|
||||
for file in scripts.list_scripts("localizations", ".json"):
|
||||
fn, ext = os.path.splitext(file.filename)
|
||||
if fn not in localizations:
|
||||
localizations[fn] = []
|
||||
localizations[fn].append(file.path)
|
||||
|
||||
|
||||
def localization_js(current_localization_name: str) -> str:
|
||||
fns = localizations.get(current_localization_name, None)
|
||||
data = {}
|
||||
if fns is not None:
|
||||
for fn in fns:
|
||||
try:
|
||||
with open(fn, "r", encoding="utf8") as file:
|
||||
data.update(json.load(file))
|
||||
except Exception:
|
||||
errors.report(f"Error loading localization from {fn}", exc_info=True)
|
||||
|
||||
return f"window.localization = {json.dumps(data)}"
|
||||
58
modules/logging_config.py
Executable file
@@ -0,0 +1,58 @@
|
||||
import logging
|
||||
import os
|
||||
|
||||
try:
|
||||
from tqdm import tqdm
|
||||
|
||||
|
||||
class TqdmLoggingHandler(logging.Handler):
|
||||
def __init__(self, fallback_handler: logging.Handler):
|
||||
super().__init__()
|
||||
self.fallback_handler = fallback_handler
|
||||
|
||||
def emit(self, record):
|
||||
try:
|
||||
# If there are active tqdm progress bars,
|
||||
# attempt to not interfere with them.
|
||||
if tqdm._instances:
|
||||
tqdm.write(self.format(record))
|
||||
else:
|
||||
self.fallback_handler.emit(record)
|
||||
except Exception:
|
||||
self.fallback_handler.emit(record)
|
||||
|
||||
except ImportError:
|
||||
TqdmLoggingHandler = None
|
||||
|
||||
|
||||
def setup_logging(loglevel):
|
||||
if loglevel is None:
|
||||
loglevel = os.environ.get("SD_WEBUI_LOG_LEVEL")
|
||||
|
||||
if not loglevel:
|
||||
return
|
||||
|
||||
if logging.root.handlers:
|
||||
# Already configured, do not interfere
|
||||
return
|
||||
|
||||
formatter = logging.Formatter(
|
||||
'%(asctime)s %(levelname)s [%(name)s] %(message)s',
|
||||
'%Y-%m-%d %H:%M:%S',
|
||||
)
|
||||
|
||||
if os.environ.get("SD_WEBUI_RICH_LOG"):
|
||||
from rich.logging import RichHandler
|
||||
handler = RichHandler()
|
||||
else:
|
||||
handler = logging.StreamHandler()
|
||||
handler.setFormatter(formatter)
|
||||
|
||||
if TqdmLoggingHandler:
|
||||
handler = TqdmLoggingHandler(handler)
|
||||
|
||||
handler.setFormatter(formatter)
|
||||
|
||||
log_level = getattr(logging, loglevel.upper(), None) or logging.INFO
|
||||
logging.root.setLevel(log_level)
|
||||
logging.root.addHandler(handler)
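A minimal usage sketch, assuming this module is importable as modules.logging_config; the level string is illustrative:

    import logging
    from modules import logging_config

    # Pass None to fall back to the SD_WEBUI_LOG_LEVEL environment variable;
    # when tqdm is importable, records are routed through TqdmLoggingHandler.
    logging_config.setup_logging("INFO")
    logging.getLogger(__name__).info("logging configured")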
|
||||
28
modules/lowvram.py
Executable file
@@ -0,0 +1,28 @@
|
||||
from collections import namedtuple
|
||||
|
||||
import torch
|
||||
from modules import devices, shared
|
||||
|
||||
module_in_gpu = None
|
||||
cpu = torch.device("cpu")
|
||||
|
||||
ModuleWithParent = namedtuple('ModuleWithParent', ['module', 'parent'], defaults=[None])
|
||||
|
||||
def send_everything_to_cpu():
|
||||
return
|
||||
|
||||
|
||||
def is_needed(sd_model):
|
||||
return False
|
||||
|
||||
|
||||
def apply(sd_model):
|
||||
return
|
||||
|
||||
|
||||
def setup_for_low_vram(sd_model, use_medvram):
|
||||
return
|
||||
|
||||
|
||||
def is_enabled(sd_model):
|
||||
return False
|
||||
98
modules/mac_specific.py
Executable file
@@ -0,0 +1,98 @@
|
||||
# import logging
|
||||
#
|
||||
# import torch
|
||||
# from torch import Tensor
|
||||
# import platform
|
||||
# from modules.sd_hijack_utils import CondFunc
|
||||
# from packaging import version
|
||||
# from modules import shared
|
||||
#
|
||||
# log = logging.getLogger(__name__)
|
||||
#
|
||||
#
|
||||
# # before torch version 1.13, has_mps is only available in nightly pytorch and macOS 12.3+,
|
||||
# # use check `getattr` and try it for compatibility.
|
||||
# # in torch version 1.13, backends.mps.is_available() and backends.mps.is_built() are introduced in to check mps availability,
|
||||
# # since torch 2.0.1+ nightly build, getattr(torch, 'has_mps', False) was deprecated, see https://github.com/pytorch/pytorch/pull/103279
|
||||
# def check_for_mps() -> bool:
|
||||
# if version.parse(torch.__version__) <= version.parse("2.0.1"):
|
||||
# if not getattr(torch, 'has_mps', False):
|
||||
# return False
|
||||
# try:
|
||||
# torch.zeros(1).to(torch.device("mps"))
|
||||
# return True
|
||||
# except Exception:
|
||||
# return False
|
||||
# else:
|
||||
# return torch.backends.mps.is_available() and torch.backends.mps.is_built()
|
||||
#
|
||||
#
|
||||
# has_mps = check_for_mps()
|
||||
#
|
||||
#
|
||||
# def torch_mps_gc() -> None:
|
||||
# try:
|
||||
# if shared.state.current_latent is not None:
|
||||
# log.debug("`current_latent` is set, skipping MPS garbage collection")
|
||||
# return
|
||||
# from torch.mps import empty_cache
|
||||
# empty_cache()
|
||||
# except Exception:
|
||||
# log.warning("MPS garbage collection failed", exc_info=True)
|
||||
#
|
||||
#
|
||||
# # MPS workaround for https://github.com/pytorch/pytorch/issues/89784
|
||||
# def cumsum_fix(input, cumsum_func, *args, **kwargs):
|
||||
# if input.device.type == 'mps':
|
||||
# output_dtype = kwargs.get('dtype', input.dtype)
|
||||
# if output_dtype == torch.int64:
|
||||
# return cumsum_func(input.cpu(), *args, **kwargs).to(input.device)
|
||||
# elif output_dtype == torch.bool or cumsum_needs_int_fix and (output_dtype == torch.int8 or output_dtype == torch.int16):
|
||||
# return cumsum_func(input.to(torch.int32), *args, **kwargs).to(torch.int64)
|
||||
# return cumsum_func(input, *args, **kwargs)
|
||||
#
|
||||
#
|
||||
# # MPS workaround for https://github.com/AUTOMATIC1111/stable-diffusion-webui/pull/14046
|
||||
# def interpolate_with_fp32_fallback(orig_func, *args, **kwargs) -> Tensor:
|
||||
# try:
|
||||
# return orig_func(*args, **kwargs)
|
||||
# except RuntimeError as e:
|
||||
# if "not implemented for" in str(e) and "Half" in str(e):
|
||||
# input_tensor = args[0]
|
||||
# return orig_func(input_tensor.to(torch.float32), *args[1:], **kwargs).to(input_tensor.dtype)
|
||||
# else:
|
||||
# print(f"An unexpected RuntimeError occurred: {str(e)}")
|
||||
#
|
||||
# if has_mps:
|
||||
# if platform.mac_ver()[0].startswith("13.2."):
|
||||
# # MPS workaround for https://github.com/pytorch/pytorch/issues/95188, thanks to danieldk (https://github.com/explosion/curated-transformers/pull/124)
|
||||
# CondFunc('torch.nn.functional.linear', lambda _, input, weight, bias: (torch.matmul(input, weight.t()) + bias) if bias is not None else torch.matmul(input, weight.t()), lambda _, input, weight, bias: input.numel() > 10485760)
|
||||
#
|
||||
# if version.parse(torch.__version__) < version.parse("1.13"):
|
||||
# # PyTorch 1.13 doesn't need these fixes but unfortunately is slower and has regressions that prevent training from working
|
||||
#
|
||||
# # MPS workaround for https://github.com/pytorch/pytorch/issues/79383
|
||||
# CondFunc('torch.Tensor.to', lambda orig_func, self, *args, **kwargs: orig_func(self.contiguous(), *args, **kwargs),
|
||||
# lambda _, self, *args, **kwargs: self.device.type != 'mps' and (args and isinstance(args[0], torch.device) and args[0].type == 'mps' or isinstance(kwargs.get('device'), torch.device) and kwargs['device'].type == 'mps'))
|
||||
# # MPS workaround for https://github.com/pytorch/pytorch/issues/80800
|
||||
# CondFunc('torch.nn.functional.layer_norm', lambda orig_func, *args, **kwargs: orig_func(*([args[0].contiguous()] + list(args[1:])), **kwargs),
|
||||
# lambda _, *args, **kwargs: args and isinstance(args[0], torch.Tensor) and args[0].device.type == 'mps')
|
||||
# # MPS workaround for https://github.com/pytorch/pytorch/issues/90532
|
||||
# CondFunc('torch.Tensor.numpy', lambda orig_func, self, *args, **kwargs: orig_func(self.detach(), *args, **kwargs), lambda _, self, *args, **kwargs: self.requires_grad)
|
||||
# elif version.parse(torch.__version__) > version.parse("1.13.1"):
|
||||
# cumsum_needs_int_fix = not torch.Tensor([1,2]).to(torch.device("mps")).equal(torch.ShortTensor([1,1]).to(torch.device("mps")).cumsum(0))
|
||||
# cumsum_fix_func = lambda orig_func, input, *args, **kwargs: cumsum_fix(input, orig_func, *args, **kwargs)
|
||||
# CondFunc('torch.cumsum', cumsum_fix_func, None)
|
||||
# CondFunc('torch.Tensor.cumsum', cumsum_fix_func, None)
|
||||
# CondFunc('torch.narrow', lambda orig_func, *args, **kwargs: orig_func(*args, **kwargs).clone(), None)
|
||||
#
|
||||
# # MPS workaround for https://github.com/pytorch/pytorch/issues/96113
|
||||
# CondFunc('torch.nn.functional.layer_norm', lambda orig_func, x, normalized_shape, weight, bias, eps, **kwargs: orig_func(x.float(), normalized_shape, weight.float() if weight is not None else None, bias.float() if bias is not None else bias, eps).to(x.dtype), lambda _, input, *args, **kwargs: len(args) == 4 and input.device.type == 'mps')
|
||||
#
|
||||
# # MPS workaround for https://github.com/AUTOMATIC1111/stable-diffusion-webui/pull/14046
|
||||
# CondFunc('torch.nn.functional.interpolate', interpolate_with_fp32_fallback, None)
|
||||
#
|
||||
# # MPS workaround for https://github.com/pytorch/pytorch/issues/92311
|
||||
# if platform.processor() == 'i386':
|
||||
# for funcName in ['torch.argmax', 'torch.Tensor.argmax']:
|
||||
# CondFunc(funcName, lambda _, input, *args, **kwargs: torch.max(input.float() if input.dtype == torch.int64 else input, *args, **kwargs)[1], lambda _, input, *args, **kwargs: input.device.type == 'mps')
|
||||
96
modules/masking.py
Executable file
@@ -0,0 +1,96 @@
|
||||
from PIL import Image, ImageFilter, ImageOps
|
||||
|
||||
|
||||
def get_crop_region_v2(mask, pad=0):
|
||||
"""
|
||||
Finds a rectangular region that contains all masked areas in a mask.
|
||||
Returns None if the mask is completely black (all 0)
|
||||
|
||||
Parameters:
|
||||
mask: PIL.Image.Image L mode or numpy 1d array
|
||||
pad: int number of pixels that the region will be extended on all sides
|
||||
Returns: (x1, y1, x2, y2) | None
|
||||
|
||||
Introduced post 1.9.0
|
||||
"""
|
||||
mask = mask if isinstance(mask, Image.Image) else Image.fromarray(mask)
|
||||
if box := mask.getbbox():
|
||||
x1, y1, x2, y2 = box
|
||||
return (max(x1 - pad, 0), max(y1 - pad, 0), min(x2 + pad, mask.size[0]), min(y2 + pad, mask.size[1])) if pad else box
|
||||
|
||||
|
||||
def get_crop_region(mask, pad=0):
|
||||
"""
|
||||
Same function as get_crop_region_v2 but handles a completely black mask (all 0) differently:
|
||||
when the mask is all black it still returns coordinates, but the coordinates may be invalid, i.e. x2 > x1 or y2 > y1
|
||||
Note: it is possible for the coordinates to become "valid" again if the pad size is sufficiently large
|
||||
(mask_size.x-pad, mask_size.y-pad, pad, pad)
|
||||
|
||||
Extension developers should use get_crop_region_v2 instead, unless compatibility requires otherwise.
|
||||
"""
|
||||
mask = mask if isinstance(mask, Image.Image) else Image.fromarray(mask)
|
||||
if box := get_crop_region_v2(mask, pad):
|
||||
return box
|
||||
x1, y1 = mask.size
|
||||
x2 = y2 = 0
|
||||
return max(x1 - pad, 0), max(y1 - pad, 0), min(x2 + pad, mask.size[0]), min(y2 + pad, mask.size[1])
|
||||
|
||||
|
||||
def expand_crop_region(crop_region, processing_width, processing_height, image_width, image_height):
|
||||
"""expands the crop region from get_crop_region() to match the ratio of the image the region will be processed in; returns the expanded region
|
||||
for example, if the user drew a mask in a 128x32 region and the dimensions for processing are 512x512, the region will be expanded to 128x128."""
|
||||
|
||||
x1, y1, x2, y2 = crop_region
|
||||
|
||||
ratio_crop_region = (x2 - x1) / (y2 - y1)
|
||||
ratio_processing = processing_width / processing_height
|
||||
|
||||
if ratio_crop_region > ratio_processing:
|
||||
desired_height = (x2 - x1) / ratio_processing
|
||||
desired_height_diff = int(desired_height - (y2-y1))
|
||||
y1 -= desired_height_diff//2
|
||||
y2 += desired_height_diff - desired_height_diff//2
|
||||
if y2 >= image_height:
|
||||
diff = y2 - image_height
|
||||
y2 -= diff
|
||||
y1 -= diff
|
||||
if y1 < 0:
|
||||
y2 -= y1
|
||||
y1 -= y1
|
||||
if y2 >= image_height:
|
||||
y2 = image_height
|
||||
else:
|
||||
desired_width = (y2 - y1) * ratio_processing
|
||||
desired_width_diff = int(desired_width - (x2-x1))
|
||||
x1 -= desired_width_diff//2
|
||||
x2 += desired_width_diff - desired_width_diff//2
|
||||
if x2 >= image_width:
|
||||
diff = x2 - image_width
|
||||
x2 -= diff
|
||||
x1 -= diff
|
||||
if x1 < 0:
|
||||
x2 -= x1
|
||||
x1 -= x1
|
||||
if x2 >= image_width:
|
||||
x2 = image_width
|
||||
|
||||
return x1, y1, x2, y2
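The 128x32 example from the docstring can be checked directly; a small illustrative sketch (the values come from the docstring, the surrounding 1024x1024 image size is assumed):

    # A 128x32 region inside an assumed 1024x1024 image, processed at 512x512,
    # is expanded to 128x128 so the crop matches the 1:1 processing aspect ratio.
    region = expand_crop_region((0, 0, 128, 32), 512, 512, 1024, 1024)
    assert region == (0, 0, 128, 128)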
|
||||
|
||||
|
||||
def fill(image, mask):
|
||||
"""fills masked regions with colors from image using blur. Not extremely effective."""
|
||||
|
||||
image_mod = Image.new('RGBA', (image.width, image.height))
|
||||
|
||||
image_masked = Image.new('RGBa', (image.width, image.height))
|
||||
image_masked.paste(image.convert("RGBA").convert("RGBa"), mask=ImageOps.invert(mask.convert('L')))
|
||||
|
||||
image_masked = image_masked.convert('RGBa')
|
||||
|
||||
for radius, repeats in [(256, 1), (64, 1), (16, 2), (4, 4), (2, 2), (0, 1)]:
|
||||
blurred = image_masked.filter(ImageFilter.GaussianBlur(radius)).convert('RGBA')
|
||||
for _ in range(repeats):
|
||||
image_mod.alpha_composite(blurred)
|
||||
|
||||
return image_mod.convert("RGB")
|
||||
|
||||
92
modules/memmon.py
Executable file
@@ -0,0 +1,92 @@
|
||||
import threading
|
||||
import time
|
||||
from collections import defaultdict
|
||||
|
||||
import torch
|
||||
|
||||
|
||||
class MemUsageMonitor(threading.Thread):
|
||||
run_flag = None
|
||||
device = None
|
||||
disabled = False
|
||||
opts = None
|
||||
data = None
|
||||
|
||||
def __init__(self, name, device, opts):
|
||||
threading.Thread.__init__(self)
|
||||
self.name = name
|
||||
self.device = device
|
||||
self.opts = opts
|
||||
|
||||
self.daemon = True
|
||||
self.run_flag = threading.Event()
|
||||
self.data = defaultdict(int)
|
||||
|
||||
try:
|
||||
self.cuda_mem_get_info()
|
||||
torch.cuda.memory_stats(self.device)
|
||||
except Exception as e: # AMD or whatever
|
||||
print(f"Warning: caught exception '{e}', memory monitor disabled")
|
||||
self.disabled = True
|
||||
|
||||
def cuda_mem_get_info(self):
|
||||
index = self.device.index if self.device.index is not None else torch.cuda.current_device()
|
||||
return torch.cuda.mem_get_info(index)
|
||||
|
||||
def run(self):
|
||||
if self.disabled:
|
||||
return
|
||||
|
||||
while True:
|
||||
self.run_flag.wait()
|
||||
|
||||
torch.cuda.reset_peak_memory_stats()
|
||||
self.data.clear()
|
||||
|
||||
if self.opts.memmon_poll_rate <= 0:
|
||||
self.run_flag.clear()
|
||||
continue
|
||||
|
||||
self.data["min_free"] = self.cuda_mem_get_info()[0]
|
||||
|
||||
while self.run_flag.is_set():
|
||||
free, total = self.cuda_mem_get_info()
|
||||
self.data["min_free"] = min(self.data["min_free"], free)
|
||||
|
||||
time.sleep(1 / self.opts.memmon_poll_rate)
|
||||
|
||||
def dump_debug(self):
|
||||
print(self, 'recorded data:')
|
||||
for k, v in self.read().items():
|
||||
print(k, -(v // -(1024 ** 2)))
|
||||
|
||||
print(self, 'raw torch memory stats:')
|
||||
tm = torch.cuda.memory_stats(self.device)
|
||||
for k, v in tm.items():
|
||||
if 'bytes' not in k:
|
||||
continue
|
||||
print('\t' if 'peak' in k else '', k, -(v // -(1024 ** 2)))
|
||||
|
||||
print(torch.cuda.memory_summary())
|
||||
|
||||
def monitor(self):
|
||||
self.run_flag.set()
|
||||
|
||||
def read(self):
|
||||
if not self.disabled:
|
||||
free, total = self.cuda_mem_get_info()
|
||||
self.data["free"] = free
|
||||
self.data["total"] = total
|
||||
|
||||
torch_stats = torch.cuda.memory_stats(self.device)
|
||||
self.data["active"] = torch_stats["active.all.current"]
|
||||
self.data["active_peak"] = torch_stats["active_bytes.all.peak"]
|
||||
self.data["reserved"] = torch_stats["reserved_bytes.all.current"]
|
||||
self.data["reserved_peak"] = torch_stats["reserved_bytes.all.peak"]
|
||||
self.data["system_peak"] = total - self.data["min_free"]
|
||||
|
||||
return self.data
|
||||
|
||||
def stop(self):
|
||||
self.run_flag.clear()
|
||||
return self.read()
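A hypothetical lifecycle sketch, assuming a CUDA device is available and that opts exposes memmon_poll_rate (faked here with a SimpleNamespace):

    import torch
    from types import SimpleNamespace

    monitor = MemUsageMonitor("memmon", torch.device("cuda"), SimpleNamespace(memmon_poll_rate=8))
    monitor.start()         # the thread idles until monitor() sets the run flag
    monitor.monitor()       # begin polling free VRAM at memmon_poll_rate times per second
    # ... run a generation here ...
    stats = monitor.stop()  # clears the flag and returns the collected data
    print(stats["min_free"], stats["reserved_peak"])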
|
||||
174
modules/modelloader.py
Executable file
@@ -0,0 +1,174 @@
|
||||
from __future__ import annotations
|
||||
|
||||
import importlib
|
||||
import logging
|
||||
import os
|
||||
from typing import TYPE_CHECKING
|
||||
from urllib.parse import urlparse
|
||||
|
||||
import torch
|
||||
|
||||
from modules import shared
|
||||
from modules.upscaler import Upscaler, UpscalerLanczos, UpscalerNearest, UpscalerNone
|
||||
from modules.util import load_file_from_url # noqa, backwards compatibility
|
||||
|
||||
if TYPE_CHECKING:
|
||||
import spandrel
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
def load_models(model_path: str, model_url: str = None, command_path: str = None, ext_filter=None, download_name=None, ext_blacklist=None, hash_prefix=None) -> list:
|
||||
"""
|
||||
A one-and-done loader to try finding the desired models in specified directories.
|
||||
|
||||
@param download_name: Specify to download from model_url immediately.
|
||||
@param model_url: If no other models are found, this will be downloaded on upscale.
|
||||
@param model_path: The location to store/find models in.
|
||||
@param command_path: A command-line argument to search for models in first.
|
||||
@param ext_filter: An optional list of filename extensions to filter by
|
||||
@param hash_prefix: the expected sha256 of the model_url
|
||||
@return: A list of paths containing the desired model(s)
|
||||
"""
|
||||
output = []
|
||||
|
||||
try:
|
||||
places = []
|
||||
|
||||
if command_path is not None and command_path != model_path:
|
||||
pretrained_path = os.path.join(command_path, 'experiments/pretrained_models')
|
||||
if os.path.exists(pretrained_path):
|
||||
print(f"Appending path: {pretrained_path}")
|
||||
places.append(pretrained_path)
|
||||
elif os.path.exists(command_path):
|
||||
places.append(command_path)
|
||||
|
||||
places.append(model_path)
|
||||
|
||||
for place in places:
|
||||
for full_path in shared.walk_files(place, allowed_extensions=ext_filter):
|
||||
if os.path.islink(full_path) and not os.path.exists(full_path):
|
||||
print(f"Skipping broken symlink: {full_path}")
|
||||
continue
|
||||
if ext_blacklist is not None and any(full_path.endswith(x) for x in ext_blacklist):
|
||||
continue
|
||||
if full_path not in output:
|
||||
output.append(full_path)
|
||||
|
||||
if model_url is not None and len(output) == 0:
|
||||
if download_name is not None:
|
||||
output.append(load_file_from_url(model_url, model_dir=places[0], file_name=download_name, hash_prefix=hash_prefix))
|
||||
else:
|
||||
output.append(model_url)
|
||||
|
||||
except Exception:
|
||||
pass
|
||||
|
||||
return output
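A hypothetical call sketch; the directory, URL and file name below are placeholders, and the command-line override is read defensively since the exact option name is assumed:

    # Gather upscaler checkpoints from the shared models directory (and an optional
    # command-line override), downloading a default model only if none are found.
    found = load_models(
        model_path=os.path.join(shared.models_path, "ESRGAN"),
        model_url="https://example.com/placeholder.pth",                    # placeholder
        command_path=getattr(shared.cmd_opts, "esrgan_models_path", None),  # assumed option name
        ext_filter=[".pt", ".pth"],
        download_name="placeholder.pth",
    )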
|
||||
|
||||
|
||||
def friendly_name(file: str):
|
||||
if file.startswith("http"):
|
||||
file = urlparse(file).path
|
||||
|
||||
file = os.path.basename(file)
|
||||
model_name, extension = os.path.splitext(file)
|
||||
return model_name
|
||||
|
||||
|
||||
def load_upscalers():
|
||||
# We can only do this 'magic' method to dynamically load upscalers if they are referenced,
|
||||
# so we'll try to import any _model.py files before looking in __subclasses__
|
||||
modules_dir = os.path.join(shared.script_path, "modules")
|
||||
for file in os.listdir(modules_dir):
|
||||
if "_model.py" in file:
|
||||
model_name = file.replace("_model.py", "")
|
||||
full_model = f"modules.{model_name}_model"
|
||||
try:
|
||||
importlib.import_module(full_model)
|
||||
except Exception:
|
||||
pass
|
||||
|
||||
data = []
|
||||
commandline_options = vars(shared.cmd_opts)
|
||||
|
||||
# some of the upscaler classes will not go away after reloading their modules, and we'll end
|
||||
# up with two copies of those classes. The newest copy will always be the last in the list,
|
||||
# so we go from end to beginning and ignore duplicates
|
||||
used_classes = {}
|
||||
for cls in reversed(Upscaler.__subclasses__()):
|
||||
classname = str(cls)
|
||||
if classname not in used_classes:
|
||||
used_classes[classname] = cls
|
||||
|
||||
for cls in reversed(used_classes.values()):
|
||||
name = cls.__name__
|
||||
cmd_name = f"{name.lower().replace('upscaler', '')}_models_path"
|
||||
commandline_model_path = commandline_options.get(cmd_name, None)
|
||||
scaler = cls(commandline_model_path)
|
||||
scaler.user_path = commandline_model_path
|
||||
scaler.model_download_path = commandline_model_path or scaler.model_path
|
||||
data += scaler.scalers
|
||||
|
||||
shared.sd_upscalers = sorted(
|
||||
data,
|
||||
# Special case for UpscalerNone keeps it at the beginning of the list.
|
||||
key=lambda x: x.name.lower() if not isinstance(x.scaler, (UpscalerNone, UpscalerLanczos, UpscalerNearest)) else ""
|
||||
)
|
||||
|
||||
# None: not loaded, False: failed to load, True: loaded
|
||||
_spandrel_extra_init_state = None
|
||||
|
||||
|
||||
def _init_spandrel_extra_archs() -> None:
|
||||
"""
|
||||
Try to initialize `spandrel_extra_archs` (exactly once).
|
||||
"""
|
||||
global _spandrel_extra_init_state
|
||||
if _spandrel_extra_init_state is not None:
|
||||
return
|
||||
|
||||
try:
|
||||
import spandrel
|
||||
import spandrel_extra_arches
|
||||
spandrel.MAIN_REGISTRY.add(*spandrel_extra_arches.EXTRA_REGISTRY)
|
||||
_spandrel_extra_init_state = True
|
||||
except Exception:
|
||||
logger.warning("Failed to load spandrel_extra_arches", exc_info=True)
|
||||
_spandrel_extra_init_state = False
|
||||
|
||||
|
||||
def load_spandrel_model(
|
||||
path: str | os.PathLike,
|
||||
*,
|
||||
device: str | torch.device | None,
|
||||
prefer_half: bool = False,
|
||||
dtype: str | torch.dtype | None = None,
|
||||
expected_architecture: str | None = None,
|
||||
) -> spandrel.ModelDescriptor:
|
||||
global _spandrel_extra_init_state
|
||||
|
||||
import spandrel
|
||||
_init_spandrel_extra_archs()
|
||||
|
||||
model_descriptor = spandrel.ModelLoader(device=device).load_from_file(str(path))
|
||||
arch = model_descriptor.architecture
|
||||
if expected_architecture and arch.name != expected_architecture:
|
||||
logger.warning(
|
||||
f"Model {path!r} is not a {expected_architecture!r} model (got {arch.name!r})",
|
||||
)
|
||||
half = False
|
||||
if prefer_half:
|
||||
if model_descriptor.supports_half:
|
||||
model_descriptor.model.half()
|
||||
half = True
|
||||
else:
|
||||
logger.info("Model %s does not support half precision, ignoring --half", path)
|
||||
if dtype:
|
||||
model_descriptor.model.to(dtype=dtype)
|
||||
model_descriptor.model.eval()
|
||||
logger.debug(
|
||||
"Loaded %s from %s (device=%s, half=%s, dtype=%s)",
|
||||
arch, path, device, half, dtype,
|
||||
)
|
||||
return model_descriptor
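A hypothetical usage sketch; the checkpoint path is a placeholder and the architecture name follows spandrel's conventions:

    # Load an ESRGAN-style checkpoint on the CPU without forcing half precision;
    # a mismatch with expected_architecture only logs a warning, it does not raise.
    descriptor = load_spandrel_model(
        "models/ESRGAN/placeholder.pth",   # placeholder path
        device="cpu",
        prefer_half=False,
        expected_architecture="ESRGAN",
    )
    sr_model = descriptor.model            # the underlying torch.nn.Module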
|
||||
1460
modules/models/diffusion/ddpm_edit.py
Executable file
File diff suppressed because it is too large
1
modules/models/diffusion/uni_pc/__init__.py
Executable file
@@ -0,0 +1 @@
|
||||
from .sampler import UniPCSampler # noqa: F401
|
||||
101
modules/models/diffusion/uni_pc/sampler.py
Executable file
@@ -0,0 +1,101 @@
|
||||
"""SAMPLING ONLY."""
|
||||
|
||||
import torch
|
||||
|
||||
from .uni_pc import NoiseScheduleVP, model_wrapper, UniPC
|
||||
from modules import shared, devices
|
||||
|
||||
|
||||
class UniPCSampler(object):
|
||||
def __init__(self, model, **kwargs):
|
||||
super().__init__()
|
||||
self.model = model
|
||||
to_torch = lambda x: x.clone().detach().to(torch.float32).to(model.device)
|
||||
self.before_sample = None
|
||||
self.after_sample = None
|
||||
self.register_buffer('alphas_cumprod', to_torch(model.alphas_cumprod))
|
||||
|
||||
def register_buffer(self, name, attr):
|
||||
if type(attr) == torch.Tensor:
|
||||
if attr.device != devices.device:
|
||||
attr = attr.to(devices.device)
|
||||
setattr(self, name, attr)
|
||||
|
||||
def set_hooks(self, before_sample, after_sample, after_update):
|
||||
self.before_sample = before_sample
|
||||
self.after_sample = after_sample
|
||||
self.after_update = after_update
|
||||
|
||||
@torch.no_grad()
|
||||
def sample(self,
|
||||
S,
|
||||
batch_size,
|
||||
shape,
|
||||
conditioning=None,
|
||||
callback=None,
|
||||
normals_sequence=None,
|
||||
img_callback=None,
|
||||
quantize_x0=False,
|
||||
eta=0.,
|
||||
mask=None,
|
||||
x0=None,
|
||||
temperature=1.,
|
||||
noise_dropout=0.,
|
||||
score_corrector=None,
|
||||
corrector_kwargs=None,
|
||||
verbose=True,
|
||||
x_T=None,
|
||||
log_every_t=100,
|
||||
unconditional_guidance_scale=1.,
|
||||
unconditional_conditioning=None,
|
||||
# this has to come in the same format as the conditioning, # e.g. as encoded tokens, ...
|
||||
**kwargs
|
||||
):
|
||||
if conditioning is not None:
|
||||
if isinstance(conditioning, dict):
|
||||
ctmp = conditioning[list(conditioning.keys())[0]]
|
||||
while isinstance(ctmp, list):
|
||||
ctmp = ctmp[0]
|
||||
cbs = ctmp.shape[0]
|
||||
if cbs != batch_size:
|
||||
print(f"Warning: Got {cbs} conditionings but batch-size is {batch_size}")
|
||||
|
||||
elif isinstance(conditioning, list):
|
||||
for ctmp in conditioning:
|
||||
if ctmp.shape[0] != batch_size:
|
||||
print(f"Warning: Got {cbs} conditionings but batch-size is {batch_size}")
|
||||
|
||||
else:
|
||||
if conditioning.shape[0] != batch_size:
|
||||
print(f"Warning: Got {conditioning.shape[0]} conditionings but batch-size is {batch_size}")
|
||||
|
||||
# sampling
|
||||
C, H, W = shape
|
||||
size = (batch_size, C, H, W)
|
||||
# print(f'Data shape for UniPC sampling is {size}')
|
||||
|
||||
device = self.model.betas.device
|
||||
if x_T is None:
|
||||
img = torch.randn(size, device=device)
|
||||
else:
|
||||
img = x_T
|
||||
|
||||
ns = NoiseScheduleVP('discrete', alphas_cumprod=self.alphas_cumprod)
|
||||
|
||||
# SD 1.X is "noise", SD 2.X is "v"
|
||||
model_type = "v" if self.model.parameterization == "v" else "noise"
|
||||
|
||||
model_fn = model_wrapper(
|
||||
lambda x, t, c: self.model.apply_model(x, t, c),
|
||||
ns,
|
||||
model_type=model_type,
|
||||
guidance_type="classifier-free",
|
||||
#condition=conditioning,
|
||||
#unconditional_condition=unconditional_conditioning,
|
||||
guidance_scale=unconditional_guidance_scale,
|
||||
)
|
||||
|
||||
uni_pc = UniPC(model_fn, ns, predict_x0=True, thresholding=False, variant=shared.opts.uni_pc_variant, condition=conditioning, unconditional_condition=unconditional_conditioning, before_sample=self.before_sample, after_sample=self.after_sample, after_update=self.after_update)
|
||||
x = uni_pc.sample(img, steps=S, skip_type=shared.opts.uni_pc_skip_type, method="multistep", order=shared.opts.uni_pc_order, lower_order_final=shared.opts.uni_pc_lower_order_final)
|
||||
|
||||
return x.to(device), None
|
||||
863
modules/models/diffusion/uni_pc/uni_pc.py
Executable file
@@ -0,0 +1,863 @@
|
||||
import torch
|
||||
import math
|
||||
import tqdm
|
||||
|
||||
|
||||
class NoiseScheduleVP:
|
||||
def __init__(
|
||||
self,
|
||||
schedule='discrete',
|
||||
betas=None,
|
||||
alphas_cumprod=None,
|
||||
continuous_beta_0=0.1,
|
||||
continuous_beta_1=20.,
|
||||
):
|
||||
"""Create a wrapper class for the forward SDE (VP type).
|
||||
|
||||
***
|
||||
Update: We support discrete-time diffusion models by implementing a piecewise linear interpolation for log_alpha_t.
|
||||
We recommend using schedule='discrete' for discrete-time diffusion models, especially for high-resolution images.
|
||||
***
|
||||
|
||||
The forward SDE ensures that the conditional distribution q_{t|0}(x_t | x_0) = N ( alpha_t * x_0, sigma_t^2 * I ).
|
||||
We further define lambda_t = log(alpha_t) - log(sigma_t), which is the half-logSNR (described in the DPM-Solver paper).
|
||||
Therefore, we implement the functions for computing alpha_t, sigma_t and lambda_t. For t in [0, T], we have:
|
||||
|
||||
log_alpha_t = self.marginal_log_mean_coeff(t)
|
||||
sigma_t = self.marginal_std(t)
|
||||
lambda_t = self.marginal_lambda(t)
|
||||
|
||||
Moreover, as lambda(t) is an invertible function, we also support its inverse function:
|
||||
|
||||
t = self.inverse_lambda(lambda_t)
|
||||
|
||||
===============================================================
|
||||
|
||||
We support both discrete-time DPMs (trained on n = 0, 1, ..., N-1) and continuous-time DPMs (trained on t in [t_0, T]).
|
||||
|
||||
1. For discrete-time DPMs:
|
||||
|
||||
For discrete-time DPMs trained on n = 0, 1, ..., N-1, we convert the discrete steps to continuous time steps by:
|
||||
t_i = (i + 1) / N
|
||||
e.g. for N = 1000, we have t_0 = 1e-3 and T = t_{N-1} = 1.
|
||||
We solve the corresponding diffusion ODE from time T = 1 to time t_0 = 1e-3.
|
||||
|
||||
Args:
|
||||
betas: A `torch.Tensor`. The beta array for the discrete-time DPM. (See the original DDPM paper for details)
|
||||
alphas_cumprod: A `torch.Tensor`. The cumprod alphas for the discrete-time DPM. (See the original DDPM paper for details)
|
||||
|
||||
Note that we always have alphas_cumprod = cumprod(1 - betas). Therefore, we only need to set one of `betas` and `alphas_cumprod`.
|
||||
|
||||
**Important**: Please pay special attention for the args for `alphas_cumprod`:
|
||||
The `alphas_cumprod` is the \hat{alpha_n} array in the notation of DDPM. Specifically, DDPMs assume that
|
||||
q_{t_n | 0}(x_{t_n} | x_0) = N ( \sqrt{\hat{alpha_n}} * x_0, (1 - \hat{alpha_n}) * I ).
|
||||
Therefore, the notation \hat{alpha_n} is different from the notation alpha_t in DPM-Solver. In fact, we have
|
||||
alpha_{t_n} = \sqrt{\hat{alpha_n}},
|
||||
and
|
||||
log(alpha_{t_n}) = 0.5 * log(\hat{alpha_n}).
|
||||
|
||||
|
||||
2. For continuous-time DPMs:
|
||||
|
||||
We support two types of VPSDEs: linear (DDPM) and cosine (improved-DDPM). The hyperparameters for the noise
|
||||
schedule are the default settings in DDPM and improved-DDPM:
|
||||
|
||||
Args:
|
||||
beta_min: A `float` number. The smallest beta for the linear schedule.
|
||||
beta_max: A `float` number. The largest beta for the linear schedule.
|
||||
cosine_s: A `float` number. The hyperparameter in the cosine schedule.
|
||||
cosine_beta_max: A `float` number. The hyperparameter in the cosine schedule.
|
||||
T: A `float` number. The ending time of the forward process.
|
||||
|
||||
===============================================================
|
||||
|
||||
Args:
|
||||
schedule: A `str`. The noise schedule of the forward SDE. 'discrete' for discrete-time DPMs,
|
||||
'linear' or 'cosine' for continuous-time DPMs.
|
||||
Returns:
|
||||
A wrapper object of the forward SDE (VP type).
|
||||
|
||||
===============================================================
|
||||
|
||||
Example:
|
||||
|
||||
# For discrete-time DPMs, given betas (the beta array for n = 0, 1, ..., N - 1):
|
||||
>>> ns = NoiseScheduleVP('discrete', betas=betas)
|
||||
|
||||
# For discrete-time DPMs, given alphas_cumprod (the \hat{alpha_n} array for n = 0, 1, ..., N - 1):
|
||||
>>> ns = NoiseScheduleVP('discrete', alphas_cumprod=alphas_cumprod)
|
||||
|
||||
# For continuous-time DPMs (VPSDE), linear schedule:
|
||||
>>> ns = NoiseScheduleVP('linear', continuous_beta_0=0.1, continuous_beta_1=20.)
|
||||
|
||||
"""
|
||||
|
||||
if schedule not in ['discrete', 'linear', 'cosine']:
|
||||
raise ValueError(f"Unsupported noise schedule {schedule}. The schedule needs to be 'discrete' or 'linear' or 'cosine'")
|
||||
|
||||
self.schedule = schedule
|
||||
if schedule == 'discrete':
|
||||
if betas is not None:
|
||||
log_alphas = 0.5 * torch.log(1 - betas).cumsum(dim=0)
|
||||
else:
|
||||
assert alphas_cumprod is not None
|
||||
log_alphas = 0.5 * torch.log(alphas_cumprod)
|
||||
self.total_N = len(log_alphas)
|
||||
self.T = 1.
|
||||
self.t_array = torch.linspace(0., 1., self.total_N + 1)[1:].reshape((1, -1))
|
||||
self.log_alpha_array = log_alphas.reshape((1, -1,))
|
||||
else:
|
||||
self.total_N = 1000
|
||||
self.beta_0 = continuous_beta_0
|
||||
self.beta_1 = continuous_beta_1
|
||||
self.cosine_s = 0.008
|
||||
self.cosine_beta_max = 999.
|
||||
self.cosine_t_max = math.atan(self.cosine_beta_max * (1. + self.cosine_s) / math.pi) * 2. * (1. + self.cosine_s) / math.pi - self.cosine_s
|
||||
self.cosine_log_alpha_0 = math.log(math.cos(self.cosine_s / (1. + self.cosine_s) * math.pi / 2.))
|
||||
self.schedule = schedule
|
||||
if schedule == 'cosine':
|
||||
# For the cosine schedule, T = 1 will have numerical issues. So we manually set the ending time T.
|
||||
# Note that T = 0.9946 may be not the optimal setting. However, we find it works well.
|
||||
self.T = 0.9946
|
||||
else:
|
||||
self.T = 1.
|
||||
|
||||
def marginal_log_mean_coeff(self, t):
|
||||
"""
|
||||
Compute log(alpha_t) of a given continuous-time label t in [0, T].
|
||||
"""
|
||||
if self.schedule == 'discrete':
|
||||
return interpolate_fn(t.reshape((-1, 1)), self.t_array.to(t.device), self.log_alpha_array.to(t.device)).reshape((-1))
|
||||
elif self.schedule == 'linear':
|
||||
return -0.25 * t ** 2 * (self.beta_1 - self.beta_0) - 0.5 * t * self.beta_0
|
||||
elif self.schedule == 'cosine':
|
||||
log_alpha_fn = lambda s: torch.log(torch.cos((s + self.cosine_s) / (1. + self.cosine_s) * math.pi / 2.))
|
||||
log_alpha_t = log_alpha_fn(t) - self.cosine_log_alpha_0
|
||||
return log_alpha_t
|
||||
|
||||
def marginal_alpha(self, t):
|
||||
"""
|
||||
Compute alpha_t of a given continuous-time label t in [0, T].
|
||||
"""
|
||||
return torch.exp(self.marginal_log_mean_coeff(t))
|
||||
|
||||
def marginal_std(self, t):
|
||||
"""
|
||||
Compute sigma_t of a given continuous-time label t in [0, T].
|
||||
"""
|
||||
return torch.sqrt(1. - torch.exp(2. * self.marginal_log_mean_coeff(t)))
|
||||
|
||||
def marginal_lambda(self, t):
|
||||
"""
|
||||
Compute lambda_t = log(alpha_t) - log(sigma_t) of a given continuous-time label t in [0, T].
|
||||
"""
|
||||
log_mean_coeff = self.marginal_log_mean_coeff(t)
|
||||
log_std = 0.5 * torch.log(1. - torch.exp(2. * log_mean_coeff))
|
||||
return log_mean_coeff - log_std
|
||||
|
||||
def inverse_lambda(self, lamb):
|
||||
"""
|
||||
Compute the continuous-time label t in [0, T] of a given half-logSNR lambda_t.
|
||||
"""
|
||||
if self.schedule == 'linear':
|
||||
tmp = 2. * (self.beta_1 - self.beta_0) * torch.logaddexp(-2. * lamb, torch.zeros((1,)).to(lamb))
|
||||
Delta = self.beta_0**2 + tmp
|
||||
return tmp / (torch.sqrt(Delta) + self.beta_0) / (self.beta_1 - self.beta_0)
|
||||
elif self.schedule == 'discrete':
|
||||
log_alpha = -0.5 * torch.logaddexp(torch.zeros((1,)).to(lamb.device), -2. * lamb)
|
||||
t = interpolate_fn(log_alpha.reshape((-1, 1)), torch.flip(self.log_alpha_array.to(lamb.device), [1]), torch.flip(self.t_array.to(lamb.device), [1]))
|
||||
return t.reshape((-1,))
|
||||
else:
|
||||
log_alpha = -0.5 * torch.logaddexp(-2. * lamb, torch.zeros((1,)).to(lamb))
|
||||
t_fn = lambda log_alpha_t: torch.arccos(torch.exp(log_alpha_t + self.cosine_log_alpha_0)) * 2. * (1. + self.cosine_s) / math.pi - self.cosine_s
|
||||
t = t_fn(log_alpha)
|
||||
return t
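
# Illustrative sketch (not part of the original module): for any of the supported schedules,
# inverse_lambda should undo marginal_lambda up to interpolation/rounding error. The helper
# below is an assumption-labeled sanity check; it only relies on torch, which this file
# already imports, and on a NoiseScheduleVP instance `ns` supplied by the caller.
def _check_lambda_roundtrip(ns, n_points=8, atol=1e-3):
    """Sanity-check that t -> lambda_t -> t is (approximately) the identity."""
    t = torch.linspace(1e-3, ns.T, n_points)
    lam = ns.marginal_lambda(t)
    t_back = ns.inverse_lambda(lam)
    return torch.allclose(t_back.reshape(-1), t, atol=atol)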
|
||||
|
||||
|
||||
def model_wrapper(
|
||||
model,
|
||||
noise_schedule,
|
||||
model_type="noise",
|
||||
model_kwargs=None,
|
||||
guidance_type="uncond",
|
||||
#condition=None,
|
||||
#unconditional_condition=None,
|
||||
guidance_scale=1.,
|
||||
classifier_fn=None,
|
||||
classifier_kwargs=None,
|
||||
):
|
||||
"""Create a wrapper function for the noise prediction model.
|
||||
|
||||
DPM-Solver needs to solve the continuous-time diffusion ODEs. For DPMs trained on discrete-time labels, we need to
|
||||
first wrap the model function into a noise prediction model that accepts the continuous time as the input.
|
||||
|
||||
We support four types of the diffusion model by setting `model_type`:
|
||||
|
||||
1. "noise": noise prediction model. (Trained by predicting noise).
|
||||
|
||||
2. "x_start": data prediction model. (Trained by predicting the data x_0 at time 0).
|
||||
|
||||
3. "v": velocity prediction model. (Trained by predicting the velocity).
|
||||
The "v" prediction is derivation detailed in Appendix D of [1], and is used in Imagen-Video [2].
|
||||
|
||||
[1] Salimans, Tim, and Jonathan Ho. "Progressive distillation for fast sampling of diffusion models."
|
||||
arXiv preprint arXiv:2202.00512 (2022).
|
||||
[2] Ho, Jonathan, et al. "Imagen Video: High Definition Video Generation with Diffusion Models."
|
||||
arXiv preprint arXiv:2210.02303 (2022).
|
||||
|
||||
4. "score": marginal score function. (Trained by denoising score matching).
|
||||
Note that the score function and the noise prediction model follow a simple relationship:
|
||||
```
|
||||
noise(x_t, t) = -sigma_t * score(x_t, t)
|
||||
```
|
||||
|
||||
We support three types of guided sampling by DPMs by setting `guidance_type`:
|
||||
1. "uncond": unconditional sampling by DPMs.
|
||||
The input `model` has the following format:
|
||||
``
|
||||
model(x, t_input, **model_kwargs) -> noise | x_start | v | score
|
||||
``
|
||||
|
||||
2. "classifier": classifier guidance sampling [3] by DPMs and another classifier.
|
||||
The input `model` has the following format:
|
||||
``
|
||||
model(x, t_input, **model_kwargs) -> noise | x_start | v | score
|
||||
``
|
||||
|
||||
The input `classifier_fn` has the following format:
|
||||
``
|
||||
classifier_fn(x, t_input, cond, **classifier_kwargs) -> logits(x, t_input, cond)
|
||||
``
|
||||
|
||||
[3] P. Dhariwal and A. Q. Nichol, "Diffusion models beat GANs on image synthesis,"
|
||||
in Advances in Neural Information Processing Systems, vol. 34, 2021, pp. 8780-8794.
|
||||
|
||||
3. "classifier-free": classifier-free guidance sampling by conditional DPMs.
|
||||
The input `model` has the following format:
|
||||
``
|
||||
model(x, t_input, cond, **model_kwargs) -> noise | x_start | v | score
|
||||
``
|
||||
And if cond == `unconditional_condition`, the model output is the unconditional DPM output.
|
||||
|
||||
[4] Ho, Jonathan, and Tim Salimans. "Classifier-free diffusion guidance."
|
||||
arXiv preprint arXiv:2207.12598 (2022).
|
||||
|
||||
|
||||
The `t_input` is the time label of the model, which may be discrete-time labels (i.e. 0 to 999)
|
||||
or continuous-time labels (i.e. epsilon to T).
|
||||
|
||||
We wrap the model function to accept only `x` and `t_continuous` as inputs, and to output the predicted noise:
|
||||
``
|
||||
def model_fn(x, t_continuous) -> noise:
|
||||
t_input = get_model_input_time(t_continuous)
|
||||
return noise_pred(model, x, t_input, **model_kwargs)
|
||||
``
|
||||
where `t_continuous` is the continuous time labels (i.e. epsilon to T). And we use `model_fn` for DPM-Solver.
|
||||
|
||||
===============================================================
|
||||
|
||||
Args:
|
||||
model: A diffusion model with the corresponding format described above.
|
||||
noise_schedule: A noise schedule object, such as NoiseScheduleVP.
|
||||
model_type: A `str`. The parameterization type of the diffusion model.
|
||||
"noise" or "x_start" or "v" or "score".
|
||||
model_kwargs: A `dict`. A dict for the other inputs of the model function.
|
||||
guidance_type: A `str`. The type of the guidance for sampling.
|
||||
"uncond" or "classifier" or "classifier-free".
|
||||
condition: A pytorch tensor. The condition for the guided sampling.
|
||||
Only used for "classifier" or "classifier-free" guidance type.
|
||||
unconditional_condition: A pytorch tensor. The condition for the unconditional sampling.
|
||||
Only used for "classifier-free" guidance type.
|
||||
guidance_scale: A `float`. The scale for the guided sampling.
|
||||
classifier_fn: A classifier function. Only used for the classifier guidance.
|
||||
classifier_kwargs: A `dict`. A dict for the other inputs of the classifier function.
|
||||
Returns:
|
||||
A noise prediction model that accepts the noised data and the continuous time as the inputs.
|
||||
"""
|
||||
|
||||
model_kwargs = model_kwargs or {}
|
||||
classifier_kwargs = classifier_kwargs or {}
|
||||
|
||||
def get_model_input_time(t_continuous):
|
||||
"""
|
||||
Convert the continuous-time `t_continuous` (in [epsilon, T]) to the model input time.
|
||||
For discrete-time DPMs, we convert `t_continuous` in [1 / N, 1] to `t_input` in [0, 1000 * (N - 1) / N].
|
||||
For continuous-time DPMs, we just use `t_continuous`.
|
||||
"""
|
||||
if noise_schedule.schedule == 'discrete':
|
||||
return (t_continuous - 1. / noise_schedule.total_N) * 1000.
|
||||
else:
|
||||
return t_continuous
|
||||
|
||||
def noise_pred_fn(x, t_continuous, cond=None):
|
||||
if t_continuous.reshape((-1,)).shape[0] == 1:
|
||||
t_continuous = t_continuous.expand((x.shape[0]))
|
||||
t_input = get_model_input_time(t_continuous)
|
||||
if cond is None:
|
||||
output = model(x, t_input, None, **model_kwargs)
|
||||
else:
|
||||
output = model(x, t_input, cond, **model_kwargs)
|
||||
if model_type == "noise":
|
||||
return output
|
||||
elif model_type == "x_start":
|
||||
alpha_t, sigma_t = noise_schedule.marginal_alpha(t_continuous), noise_schedule.marginal_std(t_continuous)
|
||||
dims = x.dim()
|
||||
return (x - expand_dims(alpha_t, dims) * output) / expand_dims(sigma_t, dims)
|
||||
elif model_type == "v":
|
||||
alpha_t, sigma_t = noise_schedule.marginal_alpha(t_continuous), noise_schedule.marginal_std(t_continuous)
|
||||
dims = x.dim()
|
||||
return expand_dims(alpha_t, dims) * output + expand_dims(sigma_t, dims) * x
|
||||
elif model_type == "score":
|
||||
sigma_t = noise_schedule.marginal_std(t_continuous)
|
||||
dims = x.dim()
|
||||
return -expand_dims(sigma_t, dims) * output
|
||||
|
||||
def cond_grad_fn(x, t_input, condition):
|
||||
"""
|
||||
Compute the gradient of the classifier, i.e. nabla_{x} log p_t(cond | x_t).
|
||||
"""
|
||||
with torch.enable_grad():
|
||||
x_in = x.detach().requires_grad_(True)
|
||||
log_prob = classifier_fn(x_in, t_input, condition, **classifier_kwargs)
|
||||
return torch.autograd.grad(log_prob.sum(), x_in)[0]
|
||||
|
||||
def model_fn(x, t_continuous, condition, unconditional_condition):
|
||||
"""
|
||||
The noise prediction model function that is used for DPM-Solver.
|
||||
"""
|
||||
if t_continuous.reshape((-1,)).shape[0] == 1:
|
||||
t_continuous = t_continuous.expand((x.shape[0]))
|
||||
if guidance_type == "uncond":
|
||||
return noise_pred_fn(x, t_continuous)
|
||||
elif guidance_type == "classifier":
|
||||
assert classifier_fn is not None
|
||||
t_input = get_model_input_time(t_continuous)
|
||||
cond_grad = cond_grad_fn(x, t_input, condition)
|
||||
sigma_t = noise_schedule.marginal_std(t_continuous)
|
||||
noise = noise_pred_fn(x, t_continuous)
|
||||
return noise - guidance_scale * expand_dims(sigma_t, dims=cond_grad.dim()) * cond_grad
|
||||
elif guidance_type == "classifier-free":
|
||||
if guidance_scale == 1. or unconditional_condition is None:
|
||||
return noise_pred_fn(x, t_continuous, cond=condition)
|
||||
else:
|
||||
x_in = torch.cat([x] * 2)
|
||||
t_in = torch.cat([t_continuous] * 2)
|
||||
if isinstance(condition, dict):
|
||||
assert isinstance(unconditional_condition, dict)
|
||||
c_in = {}
|
||||
for k in condition:
|
||||
if isinstance(condition[k], list):
|
||||
c_in[k] = [torch.cat([
|
||||
unconditional_condition[k][i],
|
||||
condition[k][i]]) for i in range(len(condition[k]))]
|
||||
else:
|
||||
c_in[k] = torch.cat([
|
||||
unconditional_condition[k],
|
||||
condition[k]])
|
||||
elif isinstance(condition, list):
|
||||
c_in = []
|
||||
assert isinstance(unconditional_condition, list)
|
||||
for i in range(len(condition)):
|
||||
c_in.append(torch.cat([unconditional_condition[i], condition[i]]))
|
||||
else:
|
||||
c_in = torch.cat([unconditional_condition, condition])
|
||||
noise_uncond, noise = noise_pred_fn(x_in, t_in, cond=c_in).chunk(2)
|
||||
return noise_uncond + guidance_scale * (noise - noise_uncond)
|
||||
|
||||
assert model_type in ["noise", "x_start", "v", "score"]
|
||||
assert guidance_type in ["uncond", "classifier", "classifier-free"]
|
||||
return model_fn
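
# Illustrative usage sketch (not part of the original module): a toy epsilon-prediction model
# is wrapped so it can be queried with continuous time labels and classifier-free guidance.
# The toy model, tensor shapes, and beta range below are assumptions for demonstration only;
# a real caller would pass its own denoiser and conditioning tensors.
def _example_wrap_toy_model():
    betas = torch.linspace(1e-4, 2e-2, 1000)
    ns = NoiseScheduleVP('discrete', betas=betas)

    def toy_eps_model(x, t_input, cond, **kwargs):
        # Stand-in for a real denoiser: ignores the condition and predicts zero noise.
        return torch.zeros_like(x)

    model_fn = model_wrapper(
        toy_eps_model, ns,
        model_type="noise",
        guidance_type="classifier-free",
        guidance_scale=5.0,
    )
    x = torch.randn(2, 4, 8, 8)
    t = torch.full((2,), 0.5)
    cond = torch.randn(2, 77, 768)
    uncond = torch.zeros_like(cond)
    # Returns the guided noise prediction, same shape as x.
    return model_fn(x, t, cond, uncond)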
|
||||
|
||||
|
||||
class UniPC:
|
||||
def __init__(
|
||||
self,
|
||||
model_fn,
|
||||
noise_schedule,
|
||||
predict_x0=True,
|
||||
thresholding=False,
|
||||
max_val=1.,
|
||||
variant='bh1',
|
||||
condition=None,
|
||||
unconditional_condition=None,
|
||||
before_sample=None,
|
||||
after_sample=None,
|
||||
after_update=None
|
||||
):
|
||||
"""Construct a UniPC.
|
||||
|
||||
We support both data_prediction and noise_prediction.
|
||||
"""
|
||||
self.model_fn_ = model_fn
|
||||
self.noise_schedule = noise_schedule
|
||||
self.variant = variant
|
||||
self.predict_x0 = predict_x0
|
||||
self.thresholding = thresholding
|
||||
self.max_val = max_val
|
||||
self.condition = condition
|
||||
self.unconditional_condition = unconditional_condition
|
||||
self.before_sample = before_sample
|
||||
self.after_sample = after_sample
|
||||
self.after_update = after_update
|
||||
|
||||
def dynamic_thresholding_fn(self, x0, t=None):
|
||||
"""
|
||||
The dynamic thresholding method.
|
||||
"""
|
||||
dims = x0.dim()
|
||||
p = self.dynamic_thresholding_ratio
|
||||
s = torch.quantile(torch.abs(x0).reshape((x0.shape[0], -1)), p, dim=1)
|
||||
s = expand_dims(torch.maximum(s, self.thresholding_max_val * torch.ones_like(s).to(s.device)), dims)
|
||||
x0 = torch.clamp(x0, -s, s) / s
|
||||
return x0
|
||||
|
||||
def model(self, x, t):
|
||||
cond = self.condition
|
||||
uncond = self.unconditional_condition
|
||||
if self.before_sample is not None:
|
||||
x, t, cond, uncond = self.before_sample(x, t, cond, uncond)
|
||||
res = self.model_fn_(x, t, cond, uncond)
|
||||
if self.after_sample is not None:
|
||||
x, t, cond, uncond, res = self.after_sample(x, t, cond, uncond, res)
|
||||
|
||||
if isinstance(res, tuple):
|
||||
# (None, pred_x0)
|
||||
res = res[1]
|
||||
|
||||
return res
|
||||
|
||||
def noise_prediction_fn(self, x, t):
|
||||
"""
|
||||
Return the noise prediction model.
|
||||
"""
|
||||
return self.model(x, t)
|
||||
|
||||
def data_prediction_fn(self, x, t):
|
||||
"""
|
||||
Return the data prediction model (with thresholding).
|
||||
"""
|
||||
noise = self.noise_prediction_fn(x, t)
|
||||
dims = x.dim()
|
||||
alpha_t, sigma_t = self.noise_schedule.marginal_alpha(t), self.noise_schedule.marginal_std(t)
|
||||
x0 = (x - expand_dims(sigma_t, dims) * noise) / expand_dims(alpha_t, dims)
|
||||
if self.thresholding:
|
||||
p = 0.995 # A hyperparameter in the paper of "Imagen" [1].
|
||||
s = torch.quantile(torch.abs(x0).reshape((x0.shape[0], -1)), p, dim=1)
|
||||
s = expand_dims(torch.maximum(s, self.max_val * torch.ones_like(s).to(s.device)), dims)
|
||||
x0 = torch.clamp(x0, -s, s) / s
|
||||
return x0
|
||||
|
||||
def model_fn(self, x, t):
|
||||
"""
|
||||
Convert the model to the noise prediction model or the data prediction model.
|
||||
"""
|
||||
if self.predict_x0:
|
||||
return self.data_prediction_fn(x, t)
|
||||
else:
|
||||
return self.noise_prediction_fn(x, t)
|
||||
|
||||
def get_time_steps(self, skip_type, t_T, t_0, N, device):
|
||||
"""Compute the intermediate time steps for sampling.
|
||||
"""
|
||||
if skip_type == 'logSNR':
|
||||
lambda_T = self.noise_schedule.marginal_lambda(torch.tensor(t_T).to(device))
|
||||
lambda_0 = self.noise_schedule.marginal_lambda(torch.tensor(t_0).to(device))
|
||||
logSNR_steps = torch.linspace(lambda_T.cpu().item(), lambda_0.cpu().item(), N + 1).to(device)
|
||||
return self.noise_schedule.inverse_lambda(logSNR_steps)
|
||||
elif skip_type == 'time_uniform':
|
||||
return torch.linspace(t_T, t_0, N + 1).to(device)
|
||||
elif skip_type == 'time_quadratic':
|
||||
t_order = 2
|
||||
t = torch.linspace(t_T**(1. / t_order), t_0**(1. / t_order), N + 1).pow(t_order).to(device)
|
||||
return t
|
||||
else:
|
||||
raise ValueError(f"Unsupported skip_type {skip_type}, need to be 'logSNR' or 'time_uniform' or 'time_quadratic'")
|
||||
|
||||
def get_orders_and_timesteps_for_singlestep_solver(self, steps, order, skip_type, t_T, t_0, device):
|
||||
"""
|
||||
Get the order of each step for sampling by the singlestep DPM-Solver.
|
||||
"""
|
||||
if order == 3:
|
||||
K = steps // 3 + 1
|
||||
if steps % 3 == 0:
|
||||
orders = [3,] * (K - 2) + [2, 1]
|
||||
elif steps % 3 == 1:
|
||||
orders = [3,] * (K - 1) + [1]
|
||||
else:
|
||||
orders = [3,] * (K - 1) + [2]
|
||||
elif order == 2:
|
||||
if steps % 2 == 0:
|
||||
K = steps // 2
|
||||
orders = [2,] * K
|
||||
else:
|
||||
K = steps // 2 + 1
|
||||
orders = [2,] * (K - 1) + [1]
|
||||
elif order == 1:
|
||||
K = steps
|
||||
orders = [1,] * steps
|
||||
else:
|
||||
raise ValueError("'order' must be '1' or '2' or '3'.")
|
||||
if skip_type == 'logSNR':
|
||||
# To reproduce the results in DPM-Solver paper
|
||||
timesteps_outer = self.get_time_steps(skip_type, t_T, t_0, K, device)
|
||||
else:
|
||||
timesteps_outer = self.get_time_steps(skip_type, t_T, t_0, steps, device)[torch.cumsum(torch.tensor([0,] + orders), 0).to(device)]
|
||||
return timesteps_outer, orders
|
||||
|
||||
def denoise_to_zero_fn(self, x, s):
|
||||
"""
|
||||
Denoise at the final step, which is equivalent to solve the ODE from lambda_s to infty by first-order discretization.
|
||||
"""
|
||||
return self.data_prediction_fn(x, s)
|
||||
|
||||
def multistep_uni_pc_update(self, x, model_prev_list, t_prev_list, t, order, **kwargs):
|
||||
if len(t.shape) == 0:
|
||||
t = t.view(-1)
|
||||
if 'bh' in self.variant:
|
||||
return self.multistep_uni_pc_bh_update(x, model_prev_list, t_prev_list, t, order, **kwargs)
|
||||
else:
|
||||
assert self.variant == 'vary_coeff'
|
||||
return self.multistep_uni_pc_vary_update(x, model_prev_list, t_prev_list, t, order, **kwargs)
|
||||
|
||||
def multistep_uni_pc_vary_update(self, x, model_prev_list, t_prev_list, t, order, use_corrector=True):
|
||||
#print(f'using unified predictor-corrector with order {order} (solver type: vary coeff)')
|
||||
ns = self.noise_schedule
|
||||
assert order <= len(model_prev_list)
|
||||
|
||||
# first compute rks
|
||||
t_prev_0 = t_prev_list[-1]
|
||||
lambda_prev_0 = ns.marginal_lambda(t_prev_0)
|
||||
lambda_t = ns.marginal_lambda(t)
|
||||
model_prev_0 = model_prev_list[-1]
|
||||
sigma_prev_0, sigma_t = ns.marginal_std(t_prev_0), ns.marginal_std(t)
|
||||
log_alpha_t = ns.marginal_log_mean_coeff(t)
|
||||
alpha_t = torch.exp(log_alpha_t)
|
||||
|
||||
h = lambda_t - lambda_prev_0
|
||||
|
||||
rks = []
|
||||
D1s = []
|
||||
for i in range(1, order):
|
||||
t_prev_i = t_prev_list[-(i + 1)]
|
||||
model_prev_i = model_prev_list[-(i + 1)]
|
||||
lambda_prev_i = ns.marginal_lambda(t_prev_i)
|
||||
rk = (lambda_prev_i - lambda_prev_0) / h
|
||||
rks.append(rk)
|
||||
D1s.append((model_prev_i - model_prev_0) / rk)
|
||||
|
||||
rks.append(1.)
|
||||
rks = torch.tensor(rks, device=x.device)
|
||||
|
||||
K = len(rks)
|
||||
# build C matrix
|
||||
C = []
|
||||
|
||||
col = torch.ones_like(rks)
|
||||
for k in range(1, K + 1):
|
||||
C.append(col)
|
||||
col = col * rks / (k + 1)
|
||||
C = torch.stack(C, dim=1)
|
||||
|
||||
if len(D1s) > 0:
|
||||
D1s = torch.stack(D1s, dim=1) # (B, K, C, H, W)
|
||||
C_inv_p = torch.linalg.inv(C[:-1, :-1])
|
||||
A_p = C_inv_p
|
||||
|
||||
if use_corrector:
|
||||
#print('using corrector')
|
||||
C_inv = torch.linalg.inv(C)
|
||||
A_c = C_inv
|
||||
|
||||
hh = -h if self.predict_x0 else h
|
||||
h_phi_1 = torch.expm1(hh)
|
||||
h_phi_ks = []
|
||||
factorial_k = 1
|
||||
h_phi_k = h_phi_1
|
||||
for k in range(1, K + 2):
|
||||
h_phi_ks.append(h_phi_k)
|
||||
h_phi_k = h_phi_k / hh - 1 / factorial_k
|
||||
factorial_k *= (k + 1)
|
||||
|
||||
model_t = None
|
||||
if self.predict_x0:
|
||||
x_t_ = (
|
||||
sigma_t / sigma_prev_0 * x
|
||||
- alpha_t * h_phi_1 * model_prev_0
|
||||
)
|
||||
# now predictor
|
||||
x_t = x_t_
|
||||
if len(D1s) > 0:
|
||||
# compute the residuals for predictor
|
||||
for k in range(K - 1):
|
||||
x_t = x_t - alpha_t * h_phi_ks[k + 1] * torch.einsum('bkchw,k->bchw', D1s, A_p[k])
|
||||
# now corrector
|
||||
if use_corrector:
|
||||
model_t = self.model_fn(x_t, t)
|
||||
D1_t = (model_t - model_prev_0)
|
||||
x_t = x_t_
|
||||
k = 0
|
||||
for k in range(K - 1):
|
||||
x_t = x_t - alpha_t * h_phi_ks[k + 1] * torch.einsum('bkchw,k->bchw', D1s, A_c[k][:-1])
|
||||
x_t = x_t - alpha_t * h_phi_ks[K] * (D1_t * A_c[k][-1])
|
||||
else:
|
||||
log_alpha_prev_0, log_alpha_t = ns.marginal_log_mean_coeff(t_prev_0), ns.marginal_log_mean_coeff(t)
|
||||
x_t_ = (
|
||||
(torch.exp(log_alpha_t - log_alpha_prev_0)) * x
|
||||
- (sigma_t * h_phi_1) * model_prev_0
|
||||
)
|
||||
# now predictor
|
||||
x_t = x_t_
|
||||
if len(D1s) > 0:
|
||||
# compute the residuals for predictor
|
||||
for k in range(K - 1):
|
||||
x_t = x_t - sigma_t * h_phi_ks[k + 1] * torch.einsum('bkchw,k->bchw', D1s, A_p[k])
|
||||
# now corrector
|
||||
if use_corrector:
|
||||
model_t = self.model_fn(x_t, t)
|
||||
D1_t = (model_t - model_prev_0)
|
||||
x_t = x_t_
|
||||
k = 0
|
||||
for k in range(K - 1):
|
||||
x_t = x_t - sigma_t * h_phi_ks[k + 1] * torch.einsum('bkchw,k->bchw', D1s, A_c[k][:-1])
|
||||
x_t = x_t - sigma_t * h_phi_ks[K] * (D1_t * A_c[k][-1])
|
||||
return x_t, model_t
|
||||
|
||||
def multistep_uni_pc_bh_update(self, x, model_prev_list, t_prev_list, t, order, x_t=None, use_corrector=True):
|
||||
#print(f'using unified predictor-corrector with order {order} (solver type: B(h))')
|
||||
ns = self.noise_schedule
|
||||
assert order <= len(model_prev_list)
|
||||
dims = x.dim()
|
||||
|
||||
# first compute rks
|
||||
t_prev_0 = t_prev_list[-1]
|
||||
lambda_prev_0 = ns.marginal_lambda(t_prev_0)
|
||||
lambda_t = ns.marginal_lambda(t)
|
||||
model_prev_0 = model_prev_list[-1]
|
||||
sigma_prev_0, sigma_t = ns.marginal_std(t_prev_0), ns.marginal_std(t)
|
||||
log_alpha_prev_0, log_alpha_t = ns.marginal_log_mean_coeff(t_prev_0), ns.marginal_log_mean_coeff(t)
|
||||
alpha_t = torch.exp(log_alpha_t)
|
||||
|
||||
h = lambda_t - lambda_prev_0
|
||||
|
||||
rks = []
|
||||
D1s = []
|
||||
for i in range(1, order):
|
||||
t_prev_i = t_prev_list[-(i + 1)]
|
||||
model_prev_i = model_prev_list[-(i + 1)]
|
||||
lambda_prev_i = ns.marginal_lambda(t_prev_i)
|
||||
rk = ((lambda_prev_i - lambda_prev_0) / h)[0]
|
||||
rks.append(rk)
|
||||
D1s.append((model_prev_i - model_prev_0) / rk)
|
||||
|
||||
rks.append(1.)
|
||||
rks = torch.tensor(rks, device=x.device)
|
||||
|
||||
R = []
|
||||
b = []
|
||||
|
||||
hh = -h[0] if self.predict_x0 else h[0]
|
||||
h_phi_1 = torch.expm1(hh) # h\phi_1(h) = e^h - 1
|
||||
h_phi_k = h_phi_1 / hh - 1
|
||||
|
||||
factorial_i = 1
|
||||
|
||||
if self.variant == 'bh1':
|
||||
B_h = hh
|
||||
elif self.variant == 'bh2':
|
||||
B_h = torch.expm1(hh)
|
||||
else:
|
||||
raise NotImplementedError()
|
||||
|
||||
for i in range(1, order + 1):
|
||||
R.append(torch.pow(rks, i - 1))
|
||||
b.append(h_phi_k * factorial_i / B_h)
|
||||
factorial_i *= (i + 1)
|
||||
h_phi_k = h_phi_k / hh - 1 / factorial_i
|
||||
|
||||
R = torch.stack(R)
|
||||
b = torch.tensor(b, device=x.device)
|
||||
|
||||
# now predictor
|
||||
use_predictor = len(D1s) > 0 and x_t is None
|
||||
if len(D1s) > 0:
|
||||
D1s = torch.stack(D1s, dim=1) # (B, K, C, H, W)
|
||||
if x_t is None:
|
||||
# for order 2, we use a simplified version
|
||||
if order == 2:
|
||||
rhos_p = torch.tensor([0.5], device=b.device)
|
||||
else:
|
||||
rhos_p = torch.linalg.solve(R[:-1, :-1], b[:-1])
|
||||
else:
|
||||
D1s = None
|
||||
|
||||
if use_corrector:
|
||||
#print('using corrector')
|
||||
# for order 1, we use a simplified version
|
||||
if order == 1:
|
||||
rhos_c = torch.tensor([0.5], device=b.device)
|
||||
else:
|
||||
rhos_c = torch.linalg.solve(R, b)
|
||||
|
||||
model_t = None
|
||||
if self.predict_x0:
|
||||
x_t_ = (
|
||||
expand_dims(sigma_t / sigma_prev_0, dims) * x
|
||||
- expand_dims(alpha_t * h_phi_1, dims)* model_prev_0
|
||||
)
|
||||
|
||||
if x_t is None:
|
||||
if use_predictor:
|
||||
pred_res = torch.einsum('k,bkchw->bchw', rhos_p, D1s)
|
||||
else:
|
||||
pred_res = 0
|
||||
x_t = x_t_ - expand_dims(alpha_t * B_h, dims) * pred_res
|
||||
|
||||
if use_corrector:
|
||||
model_t = self.model_fn(x_t, t)
|
||||
if D1s is not None:
|
||||
corr_res = torch.einsum('k,bkchw->bchw', rhos_c[:-1], D1s)
|
||||
else:
|
||||
corr_res = 0
|
||||
D1_t = (model_t - model_prev_0)
|
||||
x_t = x_t_ - expand_dims(alpha_t * B_h, dims) * (corr_res + rhos_c[-1] * D1_t)
|
||||
else:
|
||||
x_t_ = (
|
||||
expand_dims(torch.exp(log_alpha_t - log_alpha_prev_0), dims) * x
|
||||
- expand_dims(sigma_t * h_phi_1, dims) * model_prev_0
|
||||
)
|
||||
if x_t is None:
|
||||
if use_predictor:
|
||||
pred_res = torch.einsum('k,bkchw->bchw', rhos_p, D1s)
|
||||
else:
|
||||
pred_res = 0
|
||||
x_t = x_t_ - expand_dims(sigma_t * B_h, dims) * pred_res
|
||||
|
||||
if use_corrector:
|
||||
model_t = self.model_fn(x_t, t)
|
||||
if D1s is not None:
|
||||
corr_res = torch.einsum('k,bkchw->bchw', rhos_c[:-1], D1s)
|
||||
else:
|
||||
corr_res = 0
|
||||
D1_t = (model_t - model_prev_0)
|
||||
x_t = x_t_ - expand_dims(sigma_t * B_h, dims) * (corr_res + rhos_c[-1] * D1_t)
|
||||
return x_t, model_t
|
||||
|
||||
|
||||
def sample(self, x, steps=20, t_start=None, t_end=None, order=3, skip_type='time_uniform',
|
||||
method='singlestep', lower_order_final=True, denoise_to_zero=False, solver_type='dpm_solver',
|
||||
atol=0.0078, rtol=0.05, corrector=False,
|
||||
):
|
||||
t_0 = 1. / self.noise_schedule.total_N if t_end is None else t_end
|
||||
t_T = self.noise_schedule.T if t_start is None else t_start
|
||||
device = x.device
|
||||
if method == 'multistep':
|
||||
assert steps >= order, "UniPC order must be <= number of sampling steps"
|
||||
timesteps = self.get_time_steps(skip_type=skip_type, t_T=t_T, t_0=t_0, N=steps, device=device)
|
||||
#print(f"Running UniPC Sampling with {timesteps.shape[0]} timesteps, order {order}")
|
||||
assert timesteps.shape[0] - 1 == steps
|
||||
with torch.no_grad():
|
||||
vec_t = timesteps[0].expand((x.shape[0]))
|
||||
model_prev_list = [self.model_fn(x, vec_t)]
|
||||
t_prev_list = [vec_t]
|
||||
with tqdm.tqdm(total=steps) as pbar:
|
||||
# Init the first `order` values by lower order multistep DPM-Solver.
|
||||
for init_order in range(1, order):
|
||||
vec_t = timesteps[init_order].expand(x.shape[0])
|
||||
x, model_x = self.multistep_uni_pc_update(x, model_prev_list, t_prev_list, vec_t, init_order, use_corrector=True)
|
||||
if model_x is None:
|
||||
model_x = self.model_fn(x, vec_t)
|
||||
if self.after_update is not None:
|
||||
self.after_update(x, model_x)
|
||||
model_prev_list.append(model_x)
|
||||
t_prev_list.append(vec_t)
|
||||
pbar.update()
|
||||
|
||||
for step in range(order, steps + 1):
|
||||
vec_t = timesteps[step].expand(x.shape[0])
|
||||
if lower_order_final:
|
||||
step_order = min(order, steps + 1 - step)
|
||||
else:
|
||||
step_order = order
|
||||
#print('this step order:', step_order)
|
||||
if step == steps:
|
||||
#print('do not run corrector at the last step')
|
||||
use_corrector = False
|
||||
else:
|
||||
use_corrector = True
|
||||
x, model_x = self.multistep_uni_pc_update(x, model_prev_list, t_prev_list, vec_t, step_order, use_corrector=use_corrector)
|
||||
if self.after_update is not None:
|
||||
self.after_update(x, model_x)
|
||||
for i in range(order - 1):
|
||||
t_prev_list[i] = t_prev_list[i + 1]
|
||||
model_prev_list[i] = model_prev_list[i + 1]
|
||||
t_prev_list[-1] = vec_t
|
||||
# We do not need to evaluate the final model value.
|
||||
if step < steps:
|
||||
if model_x is None:
|
||||
model_x = self.model_fn(x, vec_t)
|
||||
model_prev_list[-1] = model_x
|
||||
pbar.update()
|
||||
else:
|
||||
raise NotImplementedError()
|
||||
if denoise_to_zero:
|
||||
x = self.denoise_to_zero_fn(x, torch.ones((x.shape[0],)).to(device) * t_0)
|
||||
return x
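
# Illustrative end-to-end sketch (not part of the original module): build a discrete schedule,
# wrap a toy denoiser via model_wrapper, and run multistep UniPC sampling. The toy model,
# shapes, and hyperparameters are assumptions; real use passes an actual model and conditioning.
def _example_unipc_sampling(steps=10, order=3):
    betas = torch.linspace(1e-4, 2e-2, 1000)
    ns = NoiseScheduleVP('discrete', betas=betas)

    def toy_eps_model(x, t_input, cond, **kwargs):
        # Stand-in denoiser: predicts zero noise regardless of the condition.
        return torch.zeros_like(x)

    model_fn = model_wrapper(
        toy_eps_model, ns,
        model_type="noise",
        guidance_type="classifier-free",
        guidance_scale=1.0,
    )
    sampler = UniPC(
        model_fn, ns,
        predict_x0=True,
        variant='bh1',
        condition=torch.zeros(1),
        unconditional_condition=None,
    )
    x_T = torch.randn(1, 4, 8, 8)
    # Returns the (toy) sample after `steps` multistep UniPC updates.
    return sampler.sample(x_T, steps=steps, order=order, method='multistep', lower_order_final=True)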
|
||||
|
||||
|
||||
#############################################################
|
||||
# other utility functions
|
||||
#############################################################
|
||||
|
||||
def interpolate_fn(x, xp, yp):
|
||||
"""
|
||||
A piecewise linear function y = f(x), using xp and yp as keypoints.
|
||||
We implement f(x) in a differentiable way (i.e. applicable for autograd).
|
||||
The function f(x) is well-defined for all x. (For x beyond the bounds of xp, we use the outermost points of xp to define the linear function.)
|
||||
|
||||
Args:
|
||||
x: PyTorch tensor with shape [N, C], where N is the batch size, C is the number of channels (we use C = 1 for DPM-Solver).
|
||||
xp: PyTorch tensor with shape [C, K], where K is the number of keypoints.
|
||||
yp: PyTorch tensor with shape [C, K].
|
||||
Returns:
|
||||
The function values f(x), with shape [N, C].
|
||||
"""
|
||||
N, K = x.shape[0], xp.shape[1]
|
||||
all_x = torch.cat([x.unsqueeze(2), xp.unsqueeze(0).repeat((N, 1, 1))], dim=2)
|
||||
sorted_all_x, x_indices = torch.sort(all_x, dim=2)
|
||||
x_idx = torch.argmin(x_indices, dim=2)
|
||||
cand_start_idx = x_idx - 1
|
||||
start_idx = torch.where(
|
||||
torch.eq(x_idx, 0),
|
||||
torch.tensor(1, device=x.device),
|
||||
torch.where(
|
||||
torch.eq(x_idx, K), torch.tensor(K - 2, device=x.device), cand_start_idx,
|
||||
),
|
||||
)
|
||||
end_idx = torch.where(torch.eq(start_idx, cand_start_idx), start_idx + 2, start_idx + 1)
|
||||
start_x = torch.gather(sorted_all_x, dim=2, index=start_idx.unsqueeze(2)).squeeze(2)
|
||||
end_x = torch.gather(sorted_all_x, dim=2, index=end_idx.unsqueeze(2)).squeeze(2)
|
||||
start_idx2 = torch.where(
|
||||
torch.eq(x_idx, 0),
|
||||
torch.tensor(0, device=x.device),
|
||||
torch.where(
|
||||
torch.eq(x_idx, K), torch.tensor(K - 2, device=x.device), cand_start_idx,
|
||||
),
|
||||
)
|
||||
y_positions_expanded = yp.unsqueeze(0).expand(N, -1, -1)
|
||||
start_y = torch.gather(y_positions_expanded, dim=2, index=start_idx2.unsqueeze(2)).squeeze(2)
|
||||
end_y = torch.gather(y_positions_expanded, dim=2, index=(start_idx2 + 1).unsqueeze(2)).squeeze(2)
|
||||
cand = start_y + (x - start_x) * (end_y - start_y) / (end_x - start_x)
|
||||
return cand
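
# Illustrative sketch (not part of the original module): interpolate_fn acts as a differentiable
# piecewise-linear lookup. With keypoints lying on y = 2x, a query at x = 0.25 returns roughly
# 0.5, and an out-of-range query is linearly extrapolated from the outermost keypoints.
# The specific keypoints and queries below are assumptions for demonstration.
def _example_interpolate_fn():
    xp = torch.linspace(0., 1., 5).reshape(1, -1)   # [C=1, K=5] keypoints
    yp = 2. * xp                                    # y = 2x at the keypoints
    x = torch.tensor([[0.25], [1.5]])               # [N=2, C=1] queries
    return interpolate_fn(x, xp, yp)                # approximately [[0.5], [3.0]]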
|
||||
|
||||
|
||||
def expand_dims(v, dims):
|
||||
"""
|
||||
Expand the tensor `v` to the dim `dims`.
|
||||
|
||||
Args:
|
||||
`v`: a PyTorch tensor with shape [N].
|
||||
`dim`: a `int`.
|
||||
Returns:
|
||||
a PyTorch tensor with shape [N, 1, 1, ..., 1] and the total dimension is `dims`.
|
||||
"""
|
||||
return v[(...,) + (None,)*(dims - 1)]
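
# Illustrative sketch (not part of the original module): expand_dims appends singleton
# dimensions so a per-sample scalar can broadcast against an image batch. The shapes
# below are assumptions for demonstration.
def _example_expand_dims():
    v = torch.tensor([1., 2.])            # shape [N] = [2]
    x = torch.randn(2, 3, 8, 8)           # shape [N, C, H, W]
    scaled = expand_dims(v, x.dim()) * x  # v becomes [2, 1, 1, 1] and broadcasts over x
    return scaled.shape                   # torch.Size([2, 3, 8, 8])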
|
||||
622
modules/models/sd3/mmdit.py
Executable file
622
modules/models/sd3/mmdit.py
Executable file
@@ -0,0 +1,622 @@
|
||||
### This file contains impls for MM-DiT, the core model component of SD3
|
||||
|
||||
import math
|
||||
from typing import Dict, Optional
|
||||
import numpy as np
|
||||
import torch
|
||||
import torch.nn as nn
|
||||
from einops import rearrange, repeat
|
||||
from modules.models.sd3.other_impls import attention, Mlp
|
||||
|
||||
|
||||
class PatchEmbed(nn.Module):
|
||||
""" 2D Image to Patch Embedding"""
|
||||
def __init__(
|
||||
self,
|
||||
img_size: Optional[int] = 224,
|
||||
patch_size: int = 16,
|
||||
in_chans: int = 3,
|
||||
embed_dim: int = 768,
|
||||
flatten: bool = True,
|
||||
bias: bool = True,
|
||||
strict_img_size: bool = True,
|
||||
dynamic_img_pad: bool = False,
|
||||
dtype=None,
|
||||
device=None,
|
||||
):
|
||||
super().__init__()
|
||||
self.patch_size = (patch_size, patch_size)
|
||||
if img_size is not None:
|
||||
self.img_size = (img_size, img_size)
|
||||
self.grid_size = tuple([s // p for s, p in zip(self.img_size, self.patch_size)])
|
||||
self.num_patches = self.grid_size[0] * self.grid_size[1]
|
||||
else:
|
||||
self.img_size = None
|
||||
self.grid_size = None
|
||||
self.num_patches = None
|
||||
|
||||
# flatten spatial dim and transpose to channels last, kept for bwd compat
|
||||
self.flatten = flatten
|
||||
self.strict_img_size = strict_img_size
|
||||
self.dynamic_img_pad = dynamic_img_pad
|
||||
|
||||
self.proj = nn.Conv2d(in_chans, embed_dim, kernel_size=patch_size, stride=patch_size, bias=bias, dtype=dtype, device=device)
|
||||
|
||||
def forward(self, x):
|
||||
B, C, H, W = x.shape
|
||||
x = self.proj(x)
|
||||
if self.flatten:
|
||||
x = x.flatten(2).transpose(1, 2) # NCHW -> NLC
|
||||
return x
|
||||
|
||||
|
||||
def modulate(x, shift, scale):
|
||||
if shift is None:
|
||||
shift = torch.zeros_like(scale)
|
||||
return x * (1 + scale.unsqueeze(1)) + shift.unsqueeze(1)
|
||||
|
||||
|
||||
#################################################################################
|
||||
# Sine/Cosine Positional Embedding Functions #
|
||||
#################################################################################
|
||||
|
||||
|
||||
def get_2d_sincos_pos_embed(embed_dim, grid_size, cls_token=False, extra_tokens=0, scaling_factor=None, offset=None):
|
||||
"""
|
||||
grid_size: int of the grid height and width
|
||||
return:
|
||||
pos_embed: [grid_size*grid_size, embed_dim] or [1+grid_size*grid_size, embed_dim] (w/ or w/o cls_token)
|
||||
"""
|
||||
grid_h = np.arange(grid_size, dtype=np.float32)
|
||||
grid_w = np.arange(grid_size, dtype=np.float32)
|
||||
grid = np.meshgrid(grid_w, grid_h) # here w goes first
|
||||
grid = np.stack(grid, axis=0)
|
||||
if scaling_factor is not None:
|
||||
grid = grid / scaling_factor
|
||||
if offset is not None:
|
||||
grid = grid - offset
|
||||
grid = grid.reshape([2, 1, grid_size, grid_size])
|
||||
pos_embed = get_2d_sincos_pos_embed_from_grid(embed_dim, grid)
|
||||
if cls_token and extra_tokens > 0:
|
||||
pos_embed = np.concatenate([np.zeros([extra_tokens, embed_dim]), pos_embed], axis=0)
|
||||
return pos_embed
|
||||
|
||||
|
||||
def get_2d_sincos_pos_embed_from_grid(embed_dim, grid):
|
||||
assert embed_dim % 2 == 0
|
||||
# use half of dimensions to encode grid_h
|
||||
emb_h = get_1d_sincos_pos_embed_from_grid(embed_dim // 2, grid[0]) # (H*W, D/2)
|
||||
emb_w = get_1d_sincos_pos_embed_from_grid(embed_dim // 2, grid[1]) # (H*W, D/2)
|
||||
emb = np.concatenate([emb_h, emb_w], axis=1) # (H*W, D)
|
||||
return emb
|
||||
|
||||
|
||||
def get_1d_sincos_pos_embed_from_grid(embed_dim, pos):
|
||||
"""
|
||||
embed_dim: output dimension for each position
|
||||
pos: a list of positions to be encoded: size (M,)
|
||||
out: (M, D)
|
||||
"""
|
||||
assert embed_dim % 2 == 0
|
||||
omega = np.arange(embed_dim // 2, dtype=np.float64)
|
||||
omega /= embed_dim / 2.0
|
||||
omega = 1.0 / 10000**omega # (D/2,)
|
||||
pos = pos.reshape(-1) # (M,)
|
||||
out = np.einsum("m,d->md", pos, omega) # (M, D/2), outer product
|
||||
emb_sin = np.sin(out) # (M, D/2)
|
||||
emb_cos = np.cos(out) # (M, D/2)
|
||||
return np.concatenate([emb_sin, emb_cos], axis=1) # (M, D)
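
# Illustrative sketch (not part of the original module): for a 4x4 patch grid and a 16-dim
# embedding, the helpers above produce one 16-dim sin/cos vector per patch. The grid size
# and embedding width are assumptions chosen only to show the output shape.
def _example_pos_embed_shapes():
    pos_embed = get_2d_sincos_pos_embed(embed_dim=16, grid_size=4)
    return pos_embed.shape  # (16, 16): grid_size * grid_size rows, embed_dim columns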
|
||||
|
||||
|
||||
#################################################################################
|
||||
# Embedding Layers for Timesteps and Class Labels #
|
||||
#################################################################################
|
||||
|
||||
|
||||
class TimestepEmbedder(nn.Module):
|
||||
"""Embeds scalar timesteps into vector representations."""
|
||||
|
||||
def __init__(self, hidden_size, frequency_embedding_size=256, dtype=None, device=None):
|
||||
super().__init__()
|
||||
self.mlp = nn.Sequential(
|
||||
nn.Linear(frequency_embedding_size, hidden_size, bias=True, dtype=dtype, device=device),
|
||||
nn.SiLU(),
|
||||
nn.Linear(hidden_size, hidden_size, bias=True, dtype=dtype, device=device),
|
||||
)
|
||||
self.frequency_embedding_size = frequency_embedding_size
|
||||
|
||||
@staticmethod
|
||||
def timestep_embedding(t, dim, max_period=10000):
|
||||
"""
|
||||
Create sinusoidal timestep embeddings.
|
||||
:param t: a 1-D Tensor of N indices, one per batch element.
|
||||
These may be fractional.
|
||||
:param dim: the dimension of the output.
|
||||
:param max_period: controls the minimum frequency of the embeddings.
|
||||
:return: an (N, D) Tensor of positional embeddings.
|
||||
"""
|
||||
half = dim // 2
|
||||
freqs = torch.exp(
|
||||
-math.log(max_period)
|
||||
* torch.arange(start=0, end=half, dtype=torch.float32)
|
||||
/ half
|
||||
).to(device=t.device)
|
||||
args = t[:, None].float() * freqs[None]
|
||||
embedding = torch.cat([torch.cos(args), torch.sin(args)], dim=-1)
|
||||
if dim % 2:
|
||||
embedding = torch.cat([embedding, torch.zeros_like(embedding[:, :1])], dim=-1)
|
||||
if torch.is_floating_point(t):
|
||||
embedding = embedding.to(dtype=t.dtype)
|
||||
return embedding
|
||||
|
||||
def forward(self, t, dtype, **kwargs):
|
||||
t_freq = self.timestep_embedding(t, self.frequency_embedding_size).to(dtype)
|
||||
t_emb = self.mlp(t_freq)
|
||||
return t_emb
|
||||
|
||||
|
||||
class VectorEmbedder(nn.Module):
|
||||
"""Embeds a flat vector of dimension input_dim"""
|
||||
|
||||
def __init__(self, input_dim: int, hidden_size: int, dtype=None, device=None):
|
||||
super().__init__()
|
||||
self.mlp = nn.Sequential(
|
||||
nn.Linear(input_dim, hidden_size, bias=True, dtype=dtype, device=device),
|
||||
nn.SiLU(),
|
||||
nn.Linear(hidden_size, hidden_size, bias=True, dtype=dtype, device=device),
|
||||
)
|
||||
|
||||
def forward(self, x: torch.Tensor) -> torch.Tensor:
|
||||
return self.mlp(x)
|
||||
|
||||
|
||||
#################################################################################
|
||||
# Core DiT Model #
|
||||
#################################################################################
|
||||
|
||||
|
||||
class QkvLinear(torch.nn.Linear):
|
||||
pass
|
||||
|
||||
def split_qkv(qkv, head_dim):
|
||||
qkv = qkv.reshape(qkv.shape[0], qkv.shape[1], 3, -1, head_dim).movedim(2, 0)
|
||||
return qkv[0], qkv[1], qkv[2]
|
||||
|
||||
def optimized_attention(qkv, num_heads):
|
||||
return attention(qkv[0], qkv[1], qkv[2], num_heads)
|
||||
|
||||
class SelfAttention(nn.Module):
|
||||
ATTENTION_MODES = ("xformers", "torch", "torch-hb", "math", "debug")
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
dim: int,
|
||||
num_heads: int = 8,
|
||||
qkv_bias: bool = False,
|
||||
qk_scale: Optional[float] = None,
|
||||
attn_mode: str = "xformers",
|
||||
pre_only: bool = False,
|
||||
qk_norm: Optional[str] = None,
|
||||
rmsnorm: bool = False,
|
||||
dtype=None,
|
||||
device=None,
|
||||
):
|
||||
super().__init__()
|
||||
self.num_heads = num_heads
|
||||
self.head_dim = dim // num_heads
|
||||
|
||||
self.qkv = QkvLinear(dim, dim * 3, bias=qkv_bias, dtype=dtype, device=device)
|
||||
if not pre_only:
|
||||
self.proj = nn.Linear(dim, dim, dtype=dtype, device=device)
|
||||
assert attn_mode in self.ATTENTION_MODES
|
||||
self.attn_mode = attn_mode
|
||||
self.pre_only = pre_only
|
||||
|
||||
if qk_norm == "rms":
|
||||
self.ln_q = RMSNorm(self.head_dim, elementwise_affine=True, eps=1.0e-6, dtype=dtype, device=device)
|
||||
self.ln_k = RMSNorm(self.head_dim, elementwise_affine=True, eps=1.0e-6, dtype=dtype, device=device)
|
||||
elif qk_norm == "ln":
|
||||
self.ln_q = nn.LayerNorm(self.head_dim, elementwise_affine=True, eps=1.0e-6, dtype=dtype, device=device)
|
||||
self.ln_k = nn.LayerNorm(self.head_dim, elementwise_affine=True, eps=1.0e-6, dtype=dtype, device=device)
|
||||
elif qk_norm is None:
|
||||
self.ln_q = nn.Identity()
|
||||
self.ln_k = nn.Identity()
|
||||
else:
|
||||
raise ValueError(qk_norm)
|
||||
|
||||
def pre_attention(self, x: torch.Tensor):
|
||||
B, L, C = x.shape
|
||||
qkv = self.qkv(x)
|
||||
q, k, v = split_qkv(qkv, self.head_dim)
|
||||
q = self.ln_q(q).reshape(q.shape[0], q.shape[1], -1)
|
||||
k = self.ln_k(k).reshape(q.shape[0], q.shape[1], -1)
|
||||
return (q, k, v)
|
||||
|
||||
def post_attention(self, x: torch.Tensor) -> torch.Tensor:
|
||||
assert not self.pre_only
|
||||
x = self.proj(x)
|
||||
return x
|
||||
|
||||
def forward(self, x: torch.Tensor) -> torch.Tensor:
|
||||
(q, k, v) = self.pre_attention(x)
|
||||
x = attention(q, k, v, self.num_heads)
|
||||
x = self.post_attention(x)
|
||||
return x
|
||||
|
||||
|
||||
class RMSNorm(torch.nn.Module):
|
||||
def __init__(
|
||||
self, dim: int, elementwise_affine: bool = False, eps: float = 1e-6, device=None, dtype=None
|
||||
):
|
||||
"""
|
||||
Initialize the RMSNorm normalization layer.
|
||||
Args:
|
||||
dim (int): The dimension of the input tensor.
|
||||
eps (float, optional): A small value added to the denominator for numerical stability. Default is 1e-6.
|
||||
Attributes:
|
||||
eps (float): A small value added to the denominator for numerical stability.
|
||||
weight (nn.Parameter): Learnable scaling parameter.
|
||||
"""
|
||||
super().__init__()
|
||||
self.eps = eps
|
||||
self.learnable_scale = elementwise_affine
|
||||
if self.learnable_scale:
|
||||
self.weight = nn.Parameter(torch.empty(dim, device=device, dtype=dtype))
|
||||
else:
|
||||
self.register_parameter("weight", None)
|
||||
|
||||
def _norm(self, x):
|
||||
"""
|
||||
Apply the RMSNorm normalization to the input tensor.
|
||||
Args:
|
||||
x (torch.Tensor): The input tensor.
|
||||
Returns:
|
||||
torch.Tensor: The normalized tensor.
|
||||
"""
|
||||
return x * torch.rsqrt(x.pow(2).mean(-1, keepdim=True) + self.eps)
|
||||
|
||||
def forward(self, x):
|
||||
"""
|
||||
Forward pass through the RMSNorm layer.
|
||||
Args:
|
||||
x (torch.Tensor): The input tensor.
|
||||
Returns:
|
||||
torch.Tensor: The output tensor after applying RMSNorm.
|
||||
"""
|
||||
x = self._norm(x)
|
||||
if self.learnable_scale:
|
||||
return x * self.weight.to(device=x.device, dtype=x.dtype)
|
||||
else:
|
||||
return x
|
||||
|
||||
|
||||
class SwiGLUFeedForward(nn.Module):
|
||||
def __init__(
|
||||
self,
|
||||
dim: int,
|
||||
hidden_dim: int,
|
||||
multiple_of: int,
|
||||
ffn_dim_multiplier: Optional[float] = None,
|
||||
):
|
||||
"""
|
||||
Initialize the FeedForward module.
|
||||
|
||||
Args:
|
||||
dim (int): Input dimension.
|
||||
hidden_dim (int): Hidden dimension of the feedforward layer.
|
||||
multiple_of (int): Value to ensure hidden dimension is a multiple of this value.
|
||||
ffn_dim_multiplier (float, optional): Custom multiplier for hidden dimension. Defaults to None.
|
||||
|
||||
Attributes:
|
||||
w1 (ColumnParallelLinear): Linear transformation for the first layer.
|
||||
w2 (RowParallelLinear): Linear transformation for the second layer.
|
||||
w3 (ColumnParallelLinear): Linear transformation for the third layer.
|
||||
|
||||
"""
|
||||
super().__init__()
|
||||
hidden_dim = int(2 * hidden_dim / 3)
|
||||
# custom dim factor multiplier
|
||||
if ffn_dim_multiplier is not None:
|
||||
hidden_dim = int(ffn_dim_multiplier * hidden_dim)
|
||||
hidden_dim = multiple_of * ((hidden_dim + multiple_of - 1) // multiple_of)
|
||||
|
||||
self.w1 = nn.Linear(dim, hidden_dim, bias=False)
|
||||
self.w2 = nn.Linear(hidden_dim, dim, bias=False)
|
||||
self.w3 = nn.Linear(dim, hidden_dim, bias=False)
|
||||
|
||||
def forward(self, x):
|
||||
return self.w2(nn.functional.silu(self.w1(x)) * self.w3(x))
|
||||
|
||||
|
||||
class DismantledBlock(nn.Module):
|
||||
"""A DiT block with gated adaptive layer norm (adaLN) conditioning."""
|
||||
|
||||
ATTENTION_MODES = ("xformers", "torch", "torch-hb", "math", "debug")
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
hidden_size: int,
|
||||
num_heads: int,
|
||||
mlp_ratio: float = 4.0,
|
||||
attn_mode: str = "xformers",
|
||||
qkv_bias: bool = False,
|
||||
pre_only: bool = False,
|
||||
rmsnorm: bool = False,
|
||||
scale_mod_only: bool = False,
|
||||
swiglu: bool = False,
|
||||
qk_norm: Optional[str] = None,
|
||||
dtype=None,
|
||||
device=None,
|
||||
**block_kwargs,
|
||||
):
|
||||
super().__init__()
|
||||
assert attn_mode in self.ATTENTION_MODES
|
||||
if not rmsnorm:
|
||||
self.norm1 = nn.LayerNorm(hidden_size, elementwise_affine=False, eps=1e-6, dtype=dtype, device=device)
|
||||
else:
|
||||
self.norm1 = RMSNorm(hidden_size, elementwise_affine=False, eps=1e-6)
|
||||
self.attn = SelfAttention(dim=hidden_size, num_heads=num_heads, qkv_bias=qkv_bias, attn_mode=attn_mode, pre_only=pre_only, qk_norm=qk_norm, rmsnorm=rmsnorm, dtype=dtype, device=device)
|
||||
if not pre_only:
|
||||
if not rmsnorm:
|
||||
self.norm2 = nn.LayerNorm(hidden_size, elementwise_affine=False, eps=1e-6, dtype=dtype, device=device)
|
||||
else:
|
||||
self.norm2 = RMSNorm(hidden_size, elementwise_affine=False, eps=1e-6)
|
||||
mlp_hidden_dim = int(hidden_size * mlp_ratio)
|
||||
if not pre_only:
|
||||
if not swiglu:
|
||||
self.mlp = Mlp(in_features=hidden_size, hidden_features=mlp_hidden_dim, act_layer=nn.GELU(approximate="tanh"), dtype=dtype, device=device)
|
||||
else:
|
||||
self.mlp = SwiGLUFeedForward(dim=hidden_size, hidden_dim=mlp_hidden_dim, multiple_of=256)
|
||||
self.scale_mod_only = scale_mod_only
|
||||
if not scale_mod_only:
|
||||
n_mods = 6 if not pre_only else 2
|
||||
else:
|
||||
n_mods = 4 if not pre_only else 1
|
||||
self.adaLN_modulation = nn.Sequential(nn.SiLU(), nn.Linear(hidden_size, n_mods * hidden_size, bias=True, dtype=dtype, device=device))
|
||||
self.pre_only = pre_only
|
||||
|
||||
def pre_attention(self, x: torch.Tensor, c: torch.Tensor):
|
||||
assert x is not None, "pre_attention called with None input"
|
||||
if not self.pre_only:
|
||||
if not self.scale_mod_only:
|
||||
shift_msa, scale_msa, gate_msa, shift_mlp, scale_mlp, gate_mlp = self.adaLN_modulation(c).chunk(6, dim=1)
|
||||
else:
|
||||
shift_msa = None
|
||||
shift_mlp = None
|
||||
scale_msa, gate_msa, scale_mlp, gate_mlp = self.adaLN_modulation(c).chunk(4, dim=1)
|
||||
qkv = self.attn.pre_attention(modulate(self.norm1(x), shift_msa, scale_msa))
|
||||
return qkv, (x, gate_msa, shift_mlp, scale_mlp, gate_mlp)
|
||||
else:
|
||||
if not self.scale_mod_only:
|
||||
shift_msa, scale_msa = self.adaLN_modulation(c).chunk(2, dim=1)
|
||||
else:
|
||||
shift_msa = None
|
||||
scale_msa = self.adaLN_modulation(c)
|
||||
qkv = self.attn.pre_attention(modulate(self.norm1(x), shift_msa, scale_msa))
|
||||
return qkv, None
|
||||
|
||||
def post_attention(self, attn, x, gate_msa, shift_mlp, scale_mlp, gate_mlp):
|
||||
assert not self.pre_only
|
||||
x = x + gate_msa.unsqueeze(1) * self.attn.post_attention(attn)
|
||||
x = x + gate_mlp.unsqueeze(1) * self.mlp(modulate(self.norm2(x), shift_mlp, scale_mlp))
|
||||
return x
|
||||
|
||||
def forward(self, x: torch.Tensor, c: torch.Tensor) -> torch.Tensor:
|
||||
assert not self.pre_only
|
||||
(q, k, v), intermediates = self.pre_attention(x, c)
|
||||
attn = attention(q, k, v, self.attn.num_heads)
|
||||
return self.post_attention(attn, *intermediates)
|
||||
|
||||
|
||||
def block_mixing(context, x, context_block, x_block, c):
|
||||
assert context is not None, "block_mixing called with None context"
|
||||
context_qkv, context_intermediates = context_block.pre_attention(context, c)
|
||||
|
||||
x_qkv, x_intermediates = x_block.pre_attention(x, c)
|
||||
|
||||
o = []
|
||||
for t in range(3):
|
||||
o.append(torch.cat((context_qkv[t], x_qkv[t]), dim=1))
|
||||
q, k, v = tuple(o)
|
||||
|
||||
attn = attention(q, k, v, x_block.attn.num_heads)
|
||||
context_attn, x_attn = (attn[:, : context_qkv[0].shape[1]], attn[:, context_qkv[0].shape[1] :])
|
||||
|
||||
if not context_block.pre_only:
|
||||
context = context_block.post_attention(context_attn, *context_intermediates)
|
||||
else:
|
||||
context = None
|
||||
x = x_block.post_attention(x_attn, *x_intermediates)
|
||||
return context, x
|
||||
|
||||
|
||||
class JointBlock(nn.Module):
|
||||
"""just a small wrapper to serve as a fsdp unit"""
|
||||
|
||||
def __init__(self, *args, **kwargs):
|
||||
super().__init__()
|
||||
pre_only = kwargs.pop("pre_only")
|
||||
qk_norm = kwargs.pop("qk_norm", None)
|
||||
self.context_block = DismantledBlock(*args, pre_only=pre_only, qk_norm=qk_norm, **kwargs)
|
||||
self.x_block = DismantledBlock(*args, pre_only=False, qk_norm=qk_norm, **kwargs)
|
||||
|
||||
def forward(self, *args, **kwargs):
|
||||
return block_mixing(*args, context_block=self.context_block, x_block=self.x_block, **kwargs)
|
||||
|
||||
|
||||
class FinalLayer(nn.Module):
|
||||
"""
|
||||
The final layer of DiT.
|
||||
"""
|
||||
|
||||
def __init__(self, hidden_size: int, patch_size: int, out_channels: int, total_out_channels: Optional[int] = None, dtype=None, device=None):
|
||||
super().__init__()
|
||||
self.norm_final = nn.LayerNorm(hidden_size, elementwise_affine=False, eps=1e-6, dtype=dtype, device=device)
|
||||
self.linear = (
|
||||
nn.Linear(hidden_size, patch_size * patch_size * out_channels, bias=True, dtype=dtype, device=device)
|
||||
if (total_out_channels is None)
|
||||
else nn.Linear(hidden_size, total_out_channels, bias=True, dtype=dtype, device=device)
|
||||
)
|
||||
self.adaLN_modulation = nn.Sequential(nn.SiLU(), nn.Linear(hidden_size, 2 * hidden_size, bias=True, dtype=dtype, device=device))
|
||||
|
||||
def forward(self, x: torch.Tensor, c: torch.Tensor) -> torch.Tensor:
|
||||
shift, scale = self.adaLN_modulation(c).chunk(2, dim=1)
|
||||
x = modulate(self.norm_final(x), shift, scale)
|
||||
x = self.linear(x)
|
||||
return x
|
||||
|
||||
|
||||
class MMDiT(nn.Module):
|
||||
"""Diffusion model with a Transformer backbone."""
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
input_size: int = 32,
|
||||
patch_size: int = 2,
|
||||
in_channels: int = 4,
|
||||
depth: int = 28,
|
||||
mlp_ratio: float = 4.0,
|
||||
learn_sigma: bool = False,
|
||||
adm_in_channels: Optional[int] = None,
|
||||
context_embedder_config: Optional[Dict] = None,
|
||||
register_length: int = 0,
|
||||
attn_mode: str = "torch",
|
||||
rmsnorm: bool = False,
|
||||
scale_mod_only: bool = False,
|
||||
swiglu: bool = False,
|
||||
out_channels: Optional[int] = None,
|
||||
pos_embed_scaling_factor: Optional[float] = None,
|
||||
pos_embed_offset: Optional[float] = None,
|
||||
pos_embed_max_size: Optional[int] = None,
|
||||
num_patches = None,
|
||||
qk_norm: Optional[str] = None,
|
||||
qkv_bias: bool = True,
|
||||
dtype = None,
|
||||
device = None,
|
||||
):
|
||||
super().__init__()
|
||||
self.dtype = dtype
|
||||
self.learn_sigma = learn_sigma
|
||||
self.in_channels = in_channels
|
||||
default_out_channels = in_channels * 2 if learn_sigma else in_channels
|
||||
self.out_channels = out_channels if out_channels is not None else default_out_channels
|
||||
self.patch_size = patch_size
|
||||
self.pos_embed_scaling_factor = pos_embed_scaling_factor
|
||||
self.pos_embed_offset = pos_embed_offset
|
||||
self.pos_embed_max_size = pos_embed_max_size
|
||||
|
||||
# apply magic --> this defines a head_size of 64
|
||||
hidden_size = 64 * depth
|
||||
num_heads = depth
|
||||
|
||||
self.num_heads = num_heads
|
||||
|
||||
self.x_embedder = PatchEmbed(input_size, patch_size, in_channels, hidden_size, bias=True, strict_img_size=self.pos_embed_max_size is None, dtype=dtype, device=device)
|
||||
self.t_embedder = TimestepEmbedder(hidden_size, dtype=dtype, device=device)
|
||||
|
||||
if adm_in_channels is not None:
|
||||
assert isinstance(adm_in_channels, int)
|
||||
self.y_embedder = VectorEmbedder(adm_in_channels, hidden_size, dtype=dtype, device=device)
|
||||
|
||||
self.context_embedder = nn.Identity()
|
||||
if context_embedder_config is not None:
|
||||
if context_embedder_config["target"] == "torch.nn.Linear":
|
||||
self.context_embedder = nn.Linear(**context_embedder_config["params"], dtype=dtype, device=device)
|
||||
|
||||
self.register_length = register_length
|
||||
if self.register_length > 0:
|
||||
self.register = nn.Parameter(torch.randn(1, register_length, hidden_size, dtype=dtype, device=device))
|
||||
|
||||
# num_patches = self.x_embedder.num_patches
|
||||
# Will use fixed sin-cos embedding:
|
||||
# just use a buffer already
|
||||
if num_patches is not None:
|
||||
self.register_buffer(
|
||||
"pos_embed",
|
||||
torch.zeros(1, num_patches, hidden_size, dtype=dtype, device=device),
|
||||
)
|
||||
else:
|
||||
self.pos_embed = None
|
||||
|
||||
self.joint_blocks = nn.ModuleList(
|
||||
[
|
||||
JointBlock(hidden_size, num_heads, mlp_ratio=mlp_ratio, qkv_bias=qkv_bias, attn_mode=attn_mode, pre_only=i == depth - 1, rmsnorm=rmsnorm, scale_mod_only=scale_mod_only, swiglu=swiglu, qk_norm=qk_norm, dtype=dtype, device=device)
|
||||
for i in range(depth)
|
||||
]
|
||||
)
|
||||
|
||||
self.final_layer = FinalLayer(hidden_size, patch_size, self.out_channels, dtype=dtype, device=device)
|
||||
|
||||
def cropped_pos_embed(self, hw):
|
||||
assert self.pos_embed_max_size is not None
|
||||
p = self.x_embedder.patch_size[0]
|
||||
h, w = hw
|
||||
# patched size
|
||||
h = h // p
|
||||
w = w // p
|
||||
assert h <= self.pos_embed_max_size, (h, self.pos_embed_max_size)
|
||||
assert w <= self.pos_embed_max_size, (w, self.pos_embed_max_size)
|
||||
top = (self.pos_embed_max_size - h) // 2
|
||||
left = (self.pos_embed_max_size - w) // 2
|
||||
spatial_pos_embed = rearrange(
|
||||
self.pos_embed,
|
||||
"1 (h w) c -> 1 h w c",
|
||||
h=self.pos_embed_max_size,
|
||||
w=self.pos_embed_max_size,
|
||||
)
|
||||
spatial_pos_embed = spatial_pos_embed[:, top : top + h, left : left + w, :]
|
||||
spatial_pos_embed = rearrange(spatial_pos_embed, "1 h w c -> 1 (h w) c")
|
||||
return spatial_pos_embed
|
||||
|
||||
def unpatchify(self, x, hw=None):
|
||||
"""
|
||||
x: (N, T, patch_size**2 * C)
|
||||
imgs: (N, H, W, C)
|
||||
"""
|
||||
c = self.out_channels
|
||||
p = self.x_embedder.patch_size[0]
|
||||
if hw is None:
|
||||
h = w = int(x.shape[1] ** 0.5)
|
||||
else:
|
||||
h, w = hw
|
||||
h = h // p
|
||||
w = w // p
|
||||
assert h * w == x.shape[1]
|
||||
|
||||
x = x.reshape(shape=(x.shape[0], h, w, p, p, c))
|
||||
x = torch.einsum("nhwpqc->nchpwq", x)
|
||||
imgs = x.reshape(shape=(x.shape[0], c, h * p, w * p))
|
||||
return imgs
|
||||
|
||||
def forward_core_with_concat(self, x: torch.Tensor, c_mod: torch.Tensor, context: Optional[torch.Tensor] = None) -> torch.Tensor:
|
||||
if self.register_length > 0:
|
||||
context = torch.cat((repeat(self.register, "1 ... -> b ...", b=x.shape[0]), context if context is not None else torch.Tensor([]).type_as(x)), 1)
|
||||
|
||||
# context is B, L', D
|
||||
# x is B, L, D
|
||||
for block in self.joint_blocks:
|
||||
context, x = block(context, x, c=c_mod)
|
||||
|
||||
x = self.final_layer(x, c_mod) # (N, T, patch_size ** 2 * out_channels)
|
||||
return x
|
||||
|
||||
def forward(self, x: torch.Tensor, t: torch.Tensor, y: Optional[torch.Tensor] = None, context: Optional[torch.Tensor] = None) -> torch.Tensor:
|
||||
"""
|
||||
Forward pass of DiT.
|
||||
x: (N, C, H, W) tensor of spatial inputs (images or latent representations of images)
|
||||
t: (N,) tensor of diffusion timesteps
|
||||
y: (N,) tensor of class labels
|
||||
"""
|
||||
hw = x.shape[-2:]
|
||||
x = self.x_embedder(x) + self.cropped_pos_embed(hw)
|
||||
c = self.t_embedder(t, dtype=x.dtype) # (N, D)
|
||||
if y is not None:
|
||||
y = self.y_embedder(y) # (N, D)
|
||||
c = c + y # (N, D)
|
||||
|
||||
context = self.context_embedder(context)
|
||||
|
||||
x = self.forward_core_with_concat(x, c, context)
|
||||
|
||||
x = self.unpatchify(x, hw=hw) # (N, out_channels, H, W)
|
||||
return x
|
||||
510
modules/models/sd3/other_impls.py
Executable file
510
modules/models/sd3/other_impls.py
Executable file
@@ -0,0 +1,510 @@
|
||||
### This file contains impls for underlying related models (CLIP, T5, etc)
|
||||
|
||||
import torch
|
||||
import math
|
||||
from torch import nn
|
||||
from transformers import CLIPTokenizer, T5TokenizerFast
|
||||
|
||||
from modules import sd_hijack
|
||||
|
||||
|
||||
#################################################################################################
|
||||
### Core/Utility
|
||||
#################################################################################################
|
||||
|
||||
|
||||
class AutocastLinear(nn.Linear):
|
||||
"""Same as usual linear layer, but casts its weights to whatever the parameter type is.
|
||||
|
||||
This is different from torch.autocast in a way that float16 layer processing float32 input
|
||||
will return float16 with autocast on, and float32 with this. T5 seems to be fucked
|
||||
if you do it in full float16 (returning almost all zeros in the final output).
|
||||
"""
|
||||
|
||||
def forward(self, x):
|
||||
return torch.nn.functional.linear(x, self.weight.to(x.dtype), self.bias.to(x.dtype) if self.bias is not None else None)
|
||||
|
||||
|
||||
def attention(q, k, v, heads, mask=None):
|
||||
"""Convenience wrapper around a basic attention operation"""
|
||||
b, _, dim_head = q.shape
|
||||
dim_head //= heads
|
||||
q, k, v = [t.view(b, -1, heads, dim_head).transpose(1, 2) for t in (q, k, v)]
|
||||
out = torch.nn.functional.scaled_dot_product_attention(q, k, v, attn_mask=mask, dropout_p=0.0, is_causal=False)
|
||||
return out.transpose(1, 2).reshape(b, -1, heads * dim_head)
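
# Illustrative sketch (not part of the original module): the attention() helper expects
# q, k, v of shape [batch, tokens, heads * dim_head] and returns a tensor of the same shape.
# The sizes below are assumptions used only to demonstrate the shape convention.
def _example_attention_shapes():
    b, tokens, heads, dim_head = 2, 5, 4, 16
    q = torch.randn(b, tokens, heads * dim_head)
    k = torch.randn(b, tokens, heads * dim_head)
    v = torch.randn(b, tokens, heads * dim_head)
    return attention(q, k, v, heads).shape  # torch.Size([2, 5, 64])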
|
||||
|
||||
|
||||
class Mlp(nn.Module):
|
||||
""" MLP as used in Vision Transformer, MLP-Mixer and related networks"""
|
||||
def __init__(self, in_features, hidden_features=None, out_features=None, act_layer=nn.GELU, bias=True, dtype=None, device=None):
|
||||
super().__init__()
|
||||
out_features = out_features or in_features
|
||||
hidden_features = hidden_features or in_features
|
||||
|
||||
self.fc1 = nn.Linear(in_features, hidden_features, bias=bias, dtype=dtype, device=device)
|
||||
self.act = act_layer
|
||||
self.fc2 = nn.Linear(hidden_features, out_features, bias=bias, dtype=dtype, device=device)
|
||||
|
||||
def forward(self, x):
|
||||
x = self.fc1(x)
|
||||
x = self.act(x)
|
||||
x = self.fc2(x)
|
||||
return x
|
||||
|
||||
|
||||
#################################################################################################
|
||||
### CLIP
|
||||
#################################################################################################
|
||||
|
||||
|
||||
class CLIPAttention(torch.nn.Module):
|
||||
def __init__(self, embed_dim, heads, dtype, device):
|
||||
super().__init__()
|
||||
self.heads = heads
|
||||
self.q_proj = nn.Linear(embed_dim, embed_dim, bias=True, dtype=dtype, device=device)
|
||||
self.k_proj = nn.Linear(embed_dim, embed_dim, bias=True, dtype=dtype, device=device)
|
||||
self.v_proj = nn.Linear(embed_dim, embed_dim, bias=True, dtype=dtype, device=device)
|
||||
self.out_proj = nn.Linear(embed_dim, embed_dim, bias=True, dtype=dtype, device=device)
|
||||
|
||||
def forward(self, x, mask=None):
|
||||
q = self.q_proj(x)
|
||||
k = self.k_proj(x)
|
||||
v = self.v_proj(x)
|
||||
out = attention(q, k, v, self.heads, mask)
|
||||
return self.out_proj(out)
|
||||
|
||||
|
||||
ACTIVATIONS = {
|
||||
"quick_gelu": lambda a: a * torch.sigmoid(1.702 * a),
|
||||
"gelu": torch.nn.functional.gelu,
|
||||
}
|
||||
|
||||
class CLIPLayer(torch.nn.Module):
|
||||
def __init__(self, embed_dim, heads, intermediate_size, intermediate_activation, dtype, device):
|
||||
super().__init__()
|
||||
self.layer_norm1 = nn.LayerNorm(embed_dim, dtype=dtype, device=device)
|
||||
self.self_attn = CLIPAttention(embed_dim, heads, dtype, device)
|
||||
self.layer_norm2 = nn.LayerNorm(embed_dim, dtype=dtype, device=device)
|
||||
#self.mlp = CLIPMLP(embed_dim, intermediate_size, intermediate_activation, dtype, device)
|
||||
self.mlp = Mlp(embed_dim, intermediate_size, embed_dim, act_layer=ACTIVATIONS[intermediate_activation], dtype=dtype, device=device)
|
||||
|
||||
def forward(self, x, mask=None):
|
||||
x += self.self_attn(self.layer_norm1(x), mask)
|
||||
x += self.mlp(self.layer_norm2(x))
|
||||
return x
|
||||
|
||||
|
||||
class CLIPEncoder(torch.nn.Module):
|
||||
def __init__(self, num_layers, embed_dim, heads, intermediate_size, intermediate_activation, dtype, device):
|
||||
super().__init__()
|
||||
self.layers = torch.nn.ModuleList([CLIPLayer(embed_dim, heads, intermediate_size, intermediate_activation, dtype, device) for i in range(num_layers)])
|
||||
|
||||
def forward(self, x, mask=None, intermediate_output=None):
|
||||
if intermediate_output is not None:
|
||||
if intermediate_output < 0:
|
||||
intermediate_output = len(self.layers) + intermediate_output
|
||||
intermediate = None
|
||||
for i, layer in enumerate(self.layers):
|
||||
x = layer(x, mask)
|
||||
if i == intermediate_output:
|
||||
intermediate = x.clone()
|
||||
return x, intermediate
|
||||
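A hedged usage sketch of intermediate_output (tensor names are placeholders): passing -2 also returns the hidden state after the second-to-last layer, which is how the "hidden"/"penultimate" CLIP layers are selected further down in this file.

final, penultimate = encoder(x, mask=causal_mask, intermediate_output=-2)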
|
||||
|
||||
class CLIPEmbeddings(torch.nn.Module):
|
||||
def __init__(self, embed_dim, vocab_size=49408, num_positions=77, dtype=None, device=None, textual_inversion_key="clip_l"):
|
||||
super().__init__()
|
||||
self.token_embedding = sd_hijack.TextualInversionEmbeddings(vocab_size, embed_dim, dtype=dtype, device=device, textual_inversion_key=textual_inversion_key)
|
||||
self.position_embedding = torch.nn.Embedding(num_positions, embed_dim, dtype=dtype, device=device)
|
||||
|
||||
def forward(self, input_tokens):
|
||||
return self.token_embedding(input_tokens) + self.position_embedding.weight
|
||||
|
||||
|
||||
class CLIPTextModel_(torch.nn.Module):
|
||||
def __init__(self, config_dict, dtype, device):
|
||||
num_layers = config_dict["num_hidden_layers"]
|
||||
embed_dim = config_dict["hidden_size"]
|
||||
heads = config_dict["num_attention_heads"]
|
||||
intermediate_size = config_dict["intermediate_size"]
|
||||
intermediate_activation = config_dict["hidden_act"]
|
||||
super().__init__()
|
||||
self.embeddings = CLIPEmbeddings(embed_dim, dtype=torch.float32, device=device, textual_inversion_key=config_dict.get('textual_inversion_key', 'clip_l'))
|
||||
self.encoder = CLIPEncoder(num_layers, embed_dim, heads, intermediate_size, intermediate_activation, dtype, device)
|
||||
self.final_layer_norm = nn.LayerNorm(embed_dim, dtype=dtype, device=device)
|
||||
|
||||
def forward(self, input_tokens, intermediate_output=None, final_layer_norm_intermediate=True):
|
||||
x = self.embeddings(input_tokens)
|
||||
causal_mask = torch.empty(x.shape[1], x.shape[1], dtype=x.dtype, device=x.device).fill_(float("-inf")).triu_(1)
|
||||
x, i = self.encoder(x, mask=causal_mask, intermediate_output=intermediate_output)
|
||||
x = self.final_layer_norm(x)
|
||||
if i is not None and final_layer_norm_intermediate:
|
||||
i = self.final_layer_norm(i)
|
||||
pooled_output = x[torch.arange(x.shape[0], device=x.device), input_tokens.to(dtype=torch.int, device=x.device).argmax(dim=-1),]
|
||||
return x, i, pooled_output
|
||||
|
||||
|
||||
class CLIPTextModel(torch.nn.Module):
|
||||
def __init__(self, config_dict, dtype, device):
|
||||
super().__init__()
|
||||
self.num_layers = config_dict["num_hidden_layers"]
|
||||
self.text_model = CLIPTextModel_(config_dict, dtype, device)
|
||||
embed_dim = config_dict["hidden_size"]
|
||||
self.text_projection = nn.Linear(embed_dim, embed_dim, bias=False, dtype=dtype, device=device)
|
||||
self.text_projection.weight.copy_(torch.eye(embed_dim))
|
||||
self.dtype = dtype
|
||||
|
||||
def get_input_embeddings(self):
|
||||
return self.text_model.embeddings.token_embedding
|
||||
|
||||
def set_input_embeddings(self, embeddings):
|
||||
self.text_model.embeddings.token_embedding = embeddings
|
||||
|
||||
def forward(self, *args, **kwargs):
|
||||
x = self.text_model(*args, **kwargs)
|
||||
out = self.text_projection(x[2])
|
||||
return (x[0], x[1], out, x[2])
|
||||
|
||||
|
||||
class SDTokenizer:
|
||||
def __init__(self, max_length=77, pad_with_end=True, tokenizer=None, has_start_token=True, pad_to_max_length=True, min_length=None):
|
||||
self.tokenizer = tokenizer
|
||||
self.max_length = max_length
|
||||
self.min_length = min_length
|
||||
empty = self.tokenizer('')["input_ids"]
|
||||
if has_start_token:
|
||||
self.tokens_start = 1
|
||||
self.start_token = empty[0]
|
||||
self.end_token = empty[1]
|
||||
else:
|
||||
self.tokens_start = 0
|
||||
self.start_token = None
|
||||
self.end_token = empty[0]
|
||||
self.pad_with_end = pad_with_end
|
||||
self.pad_to_max_length = pad_to_max_length
|
||||
vocab = self.tokenizer.get_vocab()
|
||||
self.inv_vocab = {v: k for k, v in vocab.items()}
|
||||
self.max_word_length = 8
|
||||
|
||||
|
||||
def tokenize_with_weights(self, text:str):
|
"""Tokenize the text, with weight values - presume 1.0 for all and ignore other features here. The details aren't relevant for a reference impl, and the weights themselves have only a weak effect on SD3."""
|
||||
if self.pad_with_end:
|
||||
pad_token = self.end_token
|
||||
else:
|
||||
pad_token = 0
|
||||
batch = []
|
||||
if self.start_token is not None:
|
||||
batch.append((self.start_token, 1.0))
|
||||
to_tokenize = text.replace("\n", " ").split(' ')
|
||||
to_tokenize = [x for x in to_tokenize if x != ""]
|
||||
for word in to_tokenize:
|
||||
batch.extend([(t, 1) for t in self.tokenizer(word)["input_ids"][self.tokens_start:-1]])
|
||||
batch.append((self.end_token, 1.0))
|
||||
if self.pad_to_max_length:
|
||||
batch.extend([(pad_token, 1.0)] * (self.max_length - len(batch)))
|
||||
if self.min_length is not None and len(batch) < self.min_length:
|
||||
batch.extend([(pad_token, 1.0)] * (self.min_length - len(batch)))
|
||||
return [batch]
|
||||
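An illustrative call (token ids shown abstractly, not real values): the result is a single batch of (token_id, weight) pairs, padded with the end/pad token up to max_length.

pairs = SDTokenizer(tokenizer=CLIPTokenizer.from_pretrained("openai/clip-vit-large-patch14")).tokenize_with_weights("a photo of a cat")
# pairs[0] -> [(start, 1.0), (token, 1), ..., (end, 1.0), (pad, 1.0), ...]   # 77 entries in total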
|
||||
|
||||
class SDXLClipGTokenizer(SDTokenizer):
|
||||
def __init__(self, tokenizer):
|
||||
super().__init__(pad_with_end=False, tokenizer=tokenizer)
|
||||
|
||||
|
||||
class SD3Tokenizer:
|
||||
def __init__(self):
|
||||
clip_tokenizer = CLIPTokenizer.from_pretrained("openai/clip-vit-large-patch14")
|
||||
self.clip_l = SDTokenizer(tokenizer=clip_tokenizer)
|
||||
self.clip_g = SDXLClipGTokenizer(clip_tokenizer)
|
||||
self.t5xxl = T5XXLTokenizer()
|
||||
|
||||
def tokenize_with_weights(self, text:str):
|
||||
out = {}
|
||||
out["g"] = self.clip_g.tokenize_with_weights(text)
|
||||
out["l"] = self.clip_l.tokenize_with_weights(text)
|
||||
out["t5xxl"] = self.t5xxl.tokenize_with_weights(text)
|
||||
return out
|
||||
|
||||
|
||||
class ClipTokenWeightEncoder:
|
||||
def encode_token_weights(self, token_weight_pairs):
|
||||
tokens = [a[0] for a in token_weight_pairs[0]]
|
||||
out, pooled = self([tokens])
|
||||
if pooled is not None:
|
||||
first_pooled = pooled[0:1].cpu()
|
||||
else:
|
||||
first_pooled = pooled
|
||||
output = [out[0:1]]
|
||||
return torch.cat(output, dim=-2).cpu(), first_pooled
|
||||
|
||||
|
||||
class SDClipModel(torch.nn.Module, ClipTokenWeightEncoder):
|
||||
"""Uses the CLIP transformer encoder for text (from huggingface)"""
|
||||
LAYERS = ["last", "pooled", "hidden"]
|
||||
def __init__(self, device="cpu", max_length=77, layer="last", layer_idx=None, textmodel_json_config=None, dtype=None, model_class=CLIPTextModel,
|
||||
special_tokens=None, layer_norm_hidden_state=True, return_projected_pooled=True):
|
||||
super().__init__()
|
||||
assert layer in self.LAYERS
|
||||
self.transformer = model_class(textmodel_json_config, dtype, device)
|
||||
self.num_layers = self.transformer.num_layers
|
||||
self.max_length = max_length
|
||||
self.transformer = self.transformer.eval()
|
||||
for param in self.parameters():
|
||||
param.requires_grad = False
|
||||
self.layer = layer
|
||||
self.layer_idx = None
|
||||
self.special_tokens = special_tokens if special_tokens is not None else {"start": 49406, "end": 49407, "pad": 49407}
|
||||
self.logit_scale = torch.nn.Parameter(torch.tensor(4.6055))
|
||||
self.layer_norm_hidden_state = layer_norm_hidden_state
|
||||
self.return_projected_pooled = return_projected_pooled
|
||||
if layer == "hidden":
|
||||
assert layer_idx is not None
|
||||
assert abs(layer_idx) < self.num_layers
|
||||
self.set_clip_options({"layer": layer_idx})
|
||||
self.options_default = (self.layer, self.layer_idx, self.return_projected_pooled)
|
||||
|
||||
def set_clip_options(self, options):
|
||||
layer_idx = options.get("layer", self.layer_idx)
|
||||
self.return_projected_pooled = options.get("projected_pooled", self.return_projected_pooled)
|
||||
if layer_idx is None or abs(layer_idx) > self.num_layers:
|
||||
self.layer = "last"
|
||||
else:
|
||||
self.layer = "hidden"
|
||||
self.layer_idx = layer_idx
|
||||
|
||||
def forward(self, tokens):
|
||||
backup_embeds = self.transformer.get_input_embeddings()
|
||||
tokens = torch.asarray(tokens, dtype=torch.int64, device=backup_embeds.weight.device)
|
||||
outputs = self.transformer(tokens, intermediate_output=self.layer_idx, final_layer_norm_intermediate=self.layer_norm_hidden_state)
|
||||
self.transformer.set_input_embeddings(backup_embeds)
|
||||
if self.layer == "last":
|
||||
z = outputs[0]
|
||||
else:
|
||||
z = outputs[1]
|
||||
pooled_output = None
|
||||
if len(outputs) >= 3:
|
||||
if not self.return_projected_pooled and len(outputs) >= 4 and outputs[3] is not None:
|
||||
pooled_output = outputs[3].float()
|
||||
elif outputs[2] is not None:
|
||||
pooled_output = outputs[2].float()
|
||||
return z.float(), pooled_output
|
||||
|
||||
|
||||
class SDXLClipG(SDClipModel):
|
||||
"""Wraps the CLIP-G model into the SD-CLIP-Model interface"""
|
||||
def __init__(self, config, device="cpu", layer="penultimate", layer_idx=None, dtype=None):
|
||||
if layer == "penultimate":
|
||||
layer="hidden"
|
||||
layer_idx=-2
|
||||
super().__init__(device=device, layer=layer, layer_idx=layer_idx, textmodel_json_config=config, dtype=dtype, special_tokens={"start": 49406, "end": 49407, "pad": 0}, layer_norm_hidden_state=False)
|
||||
|
||||
|
||||
class T5XXLModel(SDClipModel):
|
||||
"""Wraps the T5-XXL model into the SD-CLIP-Model interface for convenience"""
|
||||
def __init__(self, config, device="cpu", layer="last", layer_idx=None, dtype=None):
|
||||
super().__init__(device=device, layer=layer, layer_idx=layer_idx, textmodel_json_config=config, dtype=dtype, special_tokens={"end": 1, "pad": 0}, model_class=T5)
|
||||
|
||||
|
||||
#################################################################################################
|
||||
### T5 implementation, for the T5-XXL text encoder portion, largely pulled from upstream impl
|
||||
#################################################################################################
|
||||
|
||||
class T5XXLTokenizer(SDTokenizer):
|
||||
"""Wraps the T5 Tokenizer from HF into the SDTokenizer interface"""
|
||||
def __init__(self):
|
||||
super().__init__(pad_with_end=False, tokenizer=T5TokenizerFast.from_pretrained("google/t5-v1_1-xxl"), has_start_token=False, pad_to_max_length=False, max_length=99999999, min_length=77)
|
||||
|
||||
|
||||
class T5LayerNorm(torch.nn.Module):
|
||||
def __init__(self, hidden_size, eps=1e-6, dtype=None, device=None):
|
||||
super().__init__()
|
||||
self.weight = torch.nn.Parameter(torch.ones(hidden_size, dtype=dtype, device=device))
|
||||
self.variance_epsilon = eps
|
||||
|
||||
def forward(self, x):
|
||||
variance = x.pow(2).mean(-1, keepdim=True)
|
||||
x = x * torch.rsqrt(variance + self.variance_epsilon)
|
||||
return self.weight.to(device=x.device, dtype=x.dtype) * x
|
||||
|
||||
|
||||
class T5DenseGatedActDense(torch.nn.Module):
|
||||
def __init__(self, model_dim, ff_dim, dtype, device):
|
||||
super().__init__()
|
||||
self.wi_0 = AutocastLinear(model_dim, ff_dim, bias=False, dtype=dtype, device=device)
|
||||
self.wi_1 = AutocastLinear(model_dim, ff_dim, bias=False, dtype=dtype, device=device)
|
||||
self.wo = AutocastLinear(ff_dim, model_dim, bias=False, dtype=dtype, device=device)
|
||||
|
||||
def forward(self, x):
|
||||
hidden_gelu = torch.nn.functional.gelu(self.wi_0(x), approximate="tanh")
|
||||
hidden_linear = self.wi_1(x)
|
||||
x = hidden_gelu * hidden_linear
|
||||
x = self.wo(x)
|
||||
return x
|
||||
|
||||
|
||||
class T5LayerFF(torch.nn.Module):
|
||||
def __init__(self, model_dim, ff_dim, dtype, device):
|
||||
super().__init__()
|
||||
self.DenseReluDense = T5DenseGatedActDense(model_dim, ff_dim, dtype, device)
|
||||
self.layer_norm = T5LayerNorm(model_dim, dtype=dtype, device=device)
|
||||
|
||||
def forward(self, x):
|
||||
forwarded_states = self.layer_norm(x)
|
||||
forwarded_states = self.DenseReluDense(forwarded_states)
|
||||
x += forwarded_states
|
||||
return x
|
||||
|
||||
|
||||
class T5Attention(torch.nn.Module):
|
||||
def __init__(self, model_dim, inner_dim, num_heads, relative_attention_bias, dtype, device):
|
||||
super().__init__()
|
||||
# Mesh TensorFlow initialization to avoid scaling before softmax
|
||||
self.q = AutocastLinear(model_dim, inner_dim, bias=False, dtype=dtype, device=device)
|
||||
self.k = AutocastLinear(model_dim, inner_dim, bias=False, dtype=dtype, device=device)
|
||||
self.v = AutocastLinear(model_dim, inner_dim, bias=False, dtype=dtype, device=device)
|
||||
self.o = AutocastLinear(inner_dim, model_dim, bias=False, dtype=dtype, device=device)
|
||||
self.num_heads = num_heads
|
||||
self.relative_attention_bias = None
|
||||
if relative_attention_bias:
|
||||
self.relative_attention_num_buckets = 32
|
||||
self.relative_attention_max_distance = 128
|
||||
self.relative_attention_bias = torch.nn.Embedding(self.relative_attention_num_buckets, self.num_heads, device=device)
|
||||
|
||||
@staticmethod
|
||||
def _relative_position_bucket(relative_position, bidirectional=True, num_buckets=32, max_distance=128):
|
||||
"""
|
||||
Adapted from Mesh Tensorflow:
|
||||
https://github.com/tensorflow/mesh/blob/0cb87fe07da627bf0b7e60475d59f95ed6b5be3d/mesh_tensorflow/transformer/transformer_layers.py#L593
|
||||
|
||||
Translate relative position to a bucket number for relative attention. The relative position is defined as
|
||||
memory_position - query_position, i.e. the distance in tokens from the attending position to the attended-to
|
||||
position. If bidirectional=False, then positive relative positions are invalid. We use smaller buckets for
|
||||
small absolute relative_position and larger buckets for larger absolute relative_positions. All relative
|
||||
positions >=max_distance map to the same bucket. All relative positions <=-max_distance map to the same bucket.
|
||||
This should allow for more graceful generalization to longer sequences than the model has been trained on
|
||||
|
||||
Args:
|
||||
relative_position: an int32 Tensor
|
||||
bidirectional: a boolean - whether the attention is bidirectional
|
||||
num_buckets: an integer
|
||||
max_distance: an integer
|
||||
|
||||
Returns:
|
||||
a Tensor with the same shape as relative_position, containing int32 values in the range [0, num_buckets)
|
||||
"""
|
||||
relative_buckets = 0
|
||||
if bidirectional:
|
||||
num_buckets //= 2
|
||||
relative_buckets += (relative_position > 0).to(torch.long) * num_buckets
|
||||
relative_position = torch.abs(relative_position)
|
||||
else:
|
||||
relative_position = -torch.min(relative_position, torch.zeros_like(relative_position))
|
||||
# now relative_position is in the range [0, inf)
|
||||
# half of the buckets are for exact increments in positions
|
||||
max_exact = num_buckets // 2
|
||||
is_small = relative_position < max_exact
|
||||
# The other half of the buckets are for logarithmically bigger bins in positions up to max_distance
|
||||
relative_position_if_large = max_exact + (
|
||||
torch.log(relative_position.float() / max_exact)
|
||||
/ math.log(max_distance / max_exact)
|
||||
* (num_buckets - max_exact)
|
||||
).to(torch.long)
|
||||
relative_position_if_large = torch.min(relative_position_if_large, torch.full_like(relative_position_if_large, num_buckets - 1))
|
||||
relative_buckets += torch.where(is_small, relative_position, relative_position_if_large)
|
||||
return relative_buckets
|
||||
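A small sanity check of the bucketing described in the docstring (illustrative only; exact bucket ids depend on the parameters):

import torch
rel = torch.tensor([-3, 0, 3, 300])
buckets = T5Attention._relative_position_bucket(rel, bidirectional=True, num_buckets=32, max_distance=128)
# nearby offsets land in distinct "exact" buckets; anything beyond max_distance shares the top bucket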
|
||||
def compute_bias(self, query_length, key_length, device):
|
||||
"""Compute binned relative position bias"""
|
||||
context_position = torch.arange(query_length, dtype=torch.long, device=device)[:, None]
|
||||
memory_position = torch.arange(key_length, dtype=torch.long, device=device)[None, :]
|
||||
relative_position = memory_position - context_position # shape (query_length, key_length)
|
||||
relative_position_bucket = self._relative_position_bucket(
|
||||
relative_position, # shape (query_length, key_length)
|
||||
bidirectional=True,
|
||||
num_buckets=self.relative_attention_num_buckets,
|
||||
max_distance=self.relative_attention_max_distance,
|
||||
)
|
||||
values = self.relative_attention_bias(relative_position_bucket) # shape (query_length, key_length, num_heads)
|
||||
values = values.permute([2, 0, 1]).unsqueeze(0) # shape (1, num_heads, query_length, key_length)
|
||||
return values
|
||||
|
||||
def forward(self, x, past_bias=None):
|
||||
q = self.q(x)
|
||||
k = self.k(x)
|
||||
v = self.v(x)
|
||||
|
||||
if self.relative_attention_bias is not None:
|
||||
past_bias = self.compute_bias(x.shape[1], x.shape[1], x.device)
|
||||
if past_bias is not None:
|
||||
mask = past_bias
|
||||
else:
|
||||
mask = None
|
||||
|
||||
out = attention(q, k * ((k.shape[-1] / self.num_heads) ** 0.5), v, self.num_heads, mask.to(x.dtype) if mask is not None else None)
|
||||
|
||||
return self.o(out), past_bias
|
||||
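Note on the k scaling above (an inference from the code, not stated in the diff): scaled_dot_product_attention divides scores by sqrt(dim_head), while T5 uses unscaled dot-product attention, so pre-multiplying k by sqrt(k.shape[-1] / num_heads) = sqrt(dim_head) cancels the built-in scaling.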
|
||||
|
||||
class T5LayerSelfAttention(torch.nn.Module):
|
||||
def __init__(self, model_dim, inner_dim, ff_dim, num_heads, relative_attention_bias, dtype, device):
|
||||
super().__init__()
|
||||
self.SelfAttention = T5Attention(model_dim, inner_dim, num_heads, relative_attention_bias, dtype, device)
|
||||
self.layer_norm = T5LayerNorm(model_dim, dtype=dtype, device=device)
|
||||
|
||||
def forward(self, x, past_bias=None):
|
||||
output, past_bias = self.SelfAttention(self.layer_norm(x), past_bias=past_bias)
|
||||
x += output
|
||||
return x, past_bias
|
||||
|
||||
|
||||
class T5Block(torch.nn.Module):
|
||||
def __init__(self, model_dim, inner_dim, ff_dim, num_heads, relative_attention_bias, dtype, device):
|
||||
super().__init__()
|
||||
self.layer = torch.nn.ModuleList()
|
||||
self.layer.append(T5LayerSelfAttention(model_dim, inner_dim, ff_dim, num_heads, relative_attention_bias, dtype, device))
|
||||
self.layer.append(T5LayerFF(model_dim, ff_dim, dtype, device))
|
||||
|
||||
def forward(self, x, past_bias=None):
|
||||
x, past_bias = self.layer[0](x, past_bias)
|
||||
x = self.layer[-1](x)
|
||||
return x, past_bias
|
||||
|
||||
|
||||
class T5Stack(torch.nn.Module):
|
||||
def __init__(self, num_layers, model_dim, inner_dim, ff_dim, num_heads, vocab_size, dtype, device):
|
||||
super().__init__()
|
||||
self.embed_tokens = torch.nn.Embedding(vocab_size, model_dim, device=device)
|
||||
self.block = torch.nn.ModuleList([T5Block(model_dim, inner_dim, ff_dim, num_heads, relative_attention_bias=(i == 0), dtype=dtype, device=device) for i in range(num_layers)])
|
||||
self.final_layer_norm = T5LayerNorm(model_dim, dtype=dtype, device=device)
|
||||
|
||||
def forward(self, input_ids, intermediate_output=None, final_layer_norm_intermediate=True):
|
||||
intermediate = None
|
||||
x = self.embed_tokens(input_ids).to(torch.float32) # needs float32 or else T5 returns all zeroes
|
||||
past_bias = None
|
||||
for i, layer in enumerate(self.block):
|
||||
x, past_bias = layer(x, past_bias)
|
||||
if i == intermediate_output:
|
||||
intermediate = x.clone()
|
||||
x = self.final_layer_norm(x)
|
||||
if intermediate is not None and final_layer_norm_intermediate:
|
||||
intermediate = self.final_layer_norm(intermediate)
|
||||
return x, intermediate
|
||||
|
||||
|
||||
class T5(torch.nn.Module):
|
||||
def __init__(self, config_dict, dtype, device):
|
||||
super().__init__()
|
||||
self.num_layers = config_dict["num_layers"]
|
||||
self.encoder = T5Stack(self.num_layers, config_dict["d_model"], config_dict["d_model"], config_dict["d_ff"], config_dict["num_heads"], config_dict["vocab_size"], dtype, device)
|
||||
self.dtype = dtype
|
||||
|
||||
def get_input_embeddings(self):
|
||||
return self.encoder.embed_tokens
|
||||
|
||||
def set_input_embeddings(self, embeddings):
|
||||
self.encoder.embed_tokens = embeddings
|
||||
|
||||
def forward(self, *args, **kwargs):
|
||||
return self.encoder(*args, **kwargs)
|
||||
222
modules/models/sd3/sd3_cond.py
Executable file
@@ -0,0 +1,222 @@
|
||||
import os
|
||||
import safetensors
|
||||
import torch
|
||||
import typing
|
||||
|
||||
from transformers import CLIPTokenizer, T5TokenizerFast
|
||||
|
||||
from modules import shared, devices, modelloader, sd_hijack_clip, prompt_parser
|
||||
from modules.models.sd3.other_impls import SDClipModel, SDXLClipG, T5XXLModel, SD3Tokenizer
|
||||
|
||||
|
||||
class SafetensorsMapping(typing.Mapping):
|
||||
def __init__(self, file):
|
||||
self.file = file
|
||||
|
||||
def __len__(self):
|
||||
return len(self.file.keys())
|
||||
|
||||
def __iter__(self):
|
||||
for key in self.file.keys():
|
||||
yield key
|
||||
|
||||
def __getitem__(self, key):
|
||||
return self.file.get_tensor(key)
|
||||
|
||||
|
||||
CLIPL_URL = "https://huggingface.co/AUTOMATIC/stable-diffusion-3-medium-text-encoders/resolve/main/clip_l.safetensors"
|
||||
CLIPL_CONFIG = {
|
||||
"hidden_act": "quick_gelu",
|
||||
"hidden_size": 768,
|
||||
"intermediate_size": 3072,
|
||||
"num_attention_heads": 12,
|
||||
"num_hidden_layers": 12,
|
||||
}
|
||||
|
||||
CLIPG_URL = "https://huggingface.co/AUTOMATIC/stable-diffusion-3-medium-text-encoders/resolve/main/clip_g.safetensors"
|
||||
CLIPG_CONFIG = {
|
||||
"hidden_act": "gelu",
|
||||
"hidden_size": 1280,
|
||||
"intermediate_size": 5120,
|
||||
"num_attention_heads": 20,
|
||||
"num_hidden_layers": 32,
|
||||
"textual_inversion_key": "clip_g",
|
||||
}
|
||||
|
||||
T5_URL = "https://huggingface.co/AUTOMATIC/stable-diffusion-3-medium-text-encoders/resolve/main/t5xxl_fp16.safetensors"
|
||||
T5_CONFIG = {
|
||||
"d_ff": 10240,
|
||||
"d_model": 4096,
|
||||
"num_heads": 64,
|
||||
"num_layers": 24,
|
||||
"vocab_size": 32128,
|
||||
}
|
||||
|
||||
|
||||
class Sd3ClipLG(sd_hijack_clip.TextConditionalModel):
|
||||
def __init__(self, clip_l, clip_g):
|
||||
super().__init__()
|
||||
|
||||
self.clip_l = clip_l
|
||||
self.clip_g = clip_g
|
||||
|
||||
self.tokenizer = CLIPTokenizer.from_pretrained("openai/clip-vit-large-patch14")
|
||||
|
||||
empty = self.tokenizer('')["input_ids"]
|
||||
self.id_start = empty[0]
|
||||
self.id_end = empty[1]
|
||||
self.id_pad = empty[1]
|
||||
|
||||
self.return_pooled = True
|
||||
|
||||
def tokenize(self, texts):
|
||||
return self.tokenizer(texts, truncation=False, add_special_tokens=False)["input_ids"]
|
||||
|
||||
def encode_with_transformers(self, tokens):
|
||||
tokens_g = tokens.clone()
|
||||
|
||||
for batch_pos in range(tokens_g.shape[0]):
|
||||
index = tokens_g[batch_pos].cpu().tolist().index(self.id_end)
|
||||
tokens_g[batch_pos, index+1:tokens_g.shape[1]] = 0
|
||||
|
||||
l_out, l_pooled = self.clip_l(tokens)
|
||||
g_out, g_pooled = self.clip_g(tokens_g)
|
||||
|
||||
lg_out = torch.cat([l_out, g_out], dim=-1)
|
||||
lg_out = torch.nn.functional.pad(lg_out, (0, 4096 - lg_out.shape[-1]))
|
||||
|
||||
vector_out = torch.cat((l_pooled, g_pooled), dim=-1)
|
||||
|
||||
lg_out.pooled = vector_out
|
||||
return lg_out
|
||||
|
||||
def encode_embedding_init_text(self, init_text, nvpt):
|
||||
return torch.zeros((nvpt, 768+1280), device=devices.device) # XXX
|
||||
|
||||
|
||||
class Sd3T5(torch.nn.Module):
|
||||
def __init__(self, t5xxl):
|
||||
super().__init__()
|
||||
|
||||
self.t5xxl = t5xxl
|
||||
self.tokenizer = T5TokenizerFast.from_pretrained("google/t5-v1_1-xxl")
|
||||
|
||||
empty = self.tokenizer('', padding='max_length', max_length=2)["input_ids"]
|
||||
self.id_end = empty[0]
|
||||
self.id_pad = empty[1]
|
||||
|
||||
def tokenize(self, texts):
|
||||
return self.tokenizer(texts, truncation=False, add_special_tokens=False)["input_ids"]
|
||||
|
||||
def tokenize_line(self, line, *, target_token_count=None):
|
||||
if shared.opts.emphasis != "None":
|
||||
parsed = prompt_parser.parse_prompt_attention(line)
|
||||
else:
|
||||
parsed = [[line, 1.0]]
|
||||
|
||||
tokenized = self.tokenize([text for text, _ in parsed])
|
||||
|
||||
tokens = []
|
||||
multipliers = []
|
||||
|
||||
for text_tokens, (text, weight) in zip(tokenized, parsed):
|
||||
if text == 'BREAK' and weight == -1:
|
||||
continue
|
||||
|
||||
tokens += text_tokens
|
||||
multipliers += [weight] * len(text_tokens)
|
||||
|
||||
tokens += [self.id_end]
|
||||
multipliers += [1.0]
|
||||
|
||||
if target_token_count is not None:
    if len(tokens) < target_token_count:
        padding = target_token_count - len(tokens)
        tokens += [self.id_pad] * padding
        multipliers += [1.0] * padding
    else:
        tokens = tokens[0:target_token_count]
        multipliers = multipliers[0:target_token_count]
|
||||
|
||||
return tokens, multipliers
|
||||
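A hedged sketch of what tokenize_line produces (sd3_t5 is a placeholder Sd3T5 instance, ids are not shown): tokens and multipliers have equal length and are padded or truncated to target_token_count when one is given.

tokens, mults = sd3_t5.tokenize_line("a (red:1.2) ball", target_token_count=77)
assert len(tokens) == len(mults) == 77
# mults carries 1.2 for the tokens of "red" and 1.0 elsewhere (when emphasis parsing is enabled)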
|
||||
def forward(self, texts, *, token_count):
|
||||
if not self.t5xxl or not shared.opts.sd3_enable_t5:
|
||||
return torch.zeros((len(texts), token_count, 4096), device=devices.device, dtype=devices.dtype)
|
||||
|
||||
tokens_batch = []
|
||||
|
||||
for text in texts:
|
||||
tokens, multipliers = self.tokenize_line(text, target_token_count=token_count)
|
||||
tokens_batch.append(tokens)
|
||||
|
||||
t5_out, t5_pooled = self.t5xxl(tokens_batch)
|
||||
|
||||
return t5_out
|
||||
|
||||
def encode_embedding_init_text(self, init_text, nvpt):
|
||||
return torch.zeros((nvpt, 4096), device=devices.device) # XXX
|
||||
|
||||
|
||||
class SD3Cond(torch.nn.Module):
|
||||
def __init__(self, *args, **kwargs):
|
||||
super().__init__(*args, **kwargs)
|
||||
|
||||
self.tokenizer = SD3Tokenizer()
|
||||
|
||||
with torch.no_grad():
|
||||
self.clip_g = SDXLClipG(CLIPG_CONFIG, device="cpu", dtype=devices.dtype)
|
||||
self.clip_l = SDClipModel(layer="hidden", layer_idx=-2, device="cpu", dtype=devices.dtype, layer_norm_hidden_state=False, return_projected_pooled=False, textmodel_json_config=CLIPL_CONFIG)
|
||||
|
||||
if shared.opts.sd3_enable_t5:
|
||||
self.t5xxl = T5XXLModel(T5_CONFIG, device="cpu", dtype=devices.dtype)
|
||||
else:
|
||||
self.t5xxl = None
|
||||
|
||||
self.model_lg = Sd3ClipLG(self.clip_l, self.clip_g)
|
||||
self.model_t5 = Sd3T5(self.t5xxl)
|
||||
|
||||
def forward(self, prompts: list[str]):
|
||||
with devices.without_autocast():
|
||||
lg_out, vector_out = self.model_lg(prompts)
|
||||
t5_out = self.model_t5(prompts, token_count=lg_out.shape[1])
|
||||
lgt_out = torch.cat([lg_out, t5_out], dim=-2)
|
||||
|
||||
return {
|
||||
'crossattn': lgt_out,
|
||||
'vector': vector_out,
|
||||
}
|
||||
|
||||
def before_load_weights(self, state_dict):
|
||||
clip_path = os.path.join(shared.models_path, "CLIP")
|
||||
|
||||
if 'text_encoders.clip_g.transformer.text_model.embeddings.position_embedding.weight' not in state_dict:
|
||||
clip_g_file = modelloader.load_file_from_url(CLIPG_URL, model_dir=clip_path, file_name="clip_g.safetensors")
|
||||
with safetensors.safe_open(clip_g_file, framework="pt") as file:
|
||||
self.clip_g.transformer.load_state_dict(SafetensorsMapping(file))
|
||||
|
||||
if 'text_encoders.clip_l.transformer.text_model.embeddings.position_embedding.weight' not in state_dict:
|
||||
clip_l_file = modelloader.load_file_from_url(CLIPL_URL, model_dir=clip_path, file_name="clip_l.safetensors")
|
||||
with safetensors.safe_open(clip_l_file, framework="pt") as file:
|
||||
self.clip_l.transformer.load_state_dict(SafetensorsMapping(file), strict=False)
|
||||
|
||||
if self.t5xxl and 'text_encoders.t5xxl.transformer.encoder.embed_tokens.weight' not in state_dict:
|
||||
t5_file = modelloader.load_file_from_url(T5_URL, model_dir=clip_path, file_name="t5xxl_fp16.safetensors")
|
||||
with safetensors.safe_open(t5_file, framework="pt") as file:
|
||||
self.t5xxl.transformer.load_state_dict(SafetensorsMapping(file), strict=False)
|
||||
|
||||
def encode_embedding_init_text(self, init_text, nvpt):
|
||||
return self.model_lg.encode_embedding_init_text(init_text, nvpt)
|
||||
|
||||
def tokenize(self, texts):
|
||||
return self.model_lg.tokenize(texts)
|
||||
|
||||
def medvram_modules(self):
|
||||
return [self.clip_g, self.clip_l, self.t5xxl]
|
||||
|
||||
def get_token_count(self, text):
|
||||
_, token_count = self.model_lg.process_texts([text])
|
||||
|
||||
return token_count
|
||||
|
||||
def get_target_prompt_token_count(self, token_count):
|
||||
return self.model_lg.get_target_prompt_token_count(token_count)
|
||||
374
modules/models/sd3/sd3_impls.py
Executable file
@@ -0,0 +1,374 @@
|
||||
### Impls of the SD3 core diffusion model and VAE
|
||||
|
||||
import torch
|
||||
import math
|
||||
import einops
|
||||
from modules.models.sd3.mmdit import MMDiT
|
||||
from PIL import Image
|
||||
|
||||
|
||||
#################################################################################################
|
||||
### MMDiT Model Wrapping
|
||||
#################################################################################################
|
||||
|
||||
|
||||
class ModelSamplingDiscreteFlow(torch.nn.Module):
|
||||
"""Helper for sampler scheduling (ie timestep/sigma calculations) for Discrete Flow models"""
|
||||
def __init__(self, shift=1.0):
|
||||
super().__init__()
|
||||
self.shift = shift
|
||||
timesteps = 1000
|
||||
ts = self.sigma(torch.arange(1, timesteps + 1, 1))
|
||||
self.register_buffer('sigmas', ts)
|
||||
|
||||
@property
|
||||
def sigma_min(self):
|
||||
return self.sigmas[0]
|
||||
|
||||
@property
|
||||
def sigma_max(self):
|
||||
return self.sigmas[-1]
|
||||
|
||||
def timestep(self, sigma):
|
||||
return sigma * 1000
|
||||
|
||||
def sigma(self, timestep: torch.Tensor):
|
||||
timestep = timestep / 1000.0
|
||||
if self.shift == 1.0:
|
||||
return timestep
|
||||
return self.shift * timestep / (1 + (self.shift - 1) * timestep)
|
||||
|
||||
def calculate_denoised(self, sigma, model_output, model_input):
|
||||
sigma = sigma.view(sigma.shape[:1] + (1,) * (model_output.ndim - 1))
|
||||
return model_input - model_output * sigma
|
||||
|
||||
def noise_scaling(self, sigma, noise, latent_image, max_denoise=False):
|
||||
return sigma * noise + (1.0 - sigma) * latent_image
|
||||
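A worked example of the shift (values assumed): with shift=3.0 the mid-schedule timestep t=500 maps to sigma = 3*0.5 / (1 + 2*0.5) = 0.75 instead of 0.5, so the schedule spends more of its range at high noise.

import torch
ms = ModelSamplingDiscreteFlow(shift=3.0)
print(float(ms.sigma(torch.tensor(500.))))   # 0.75
print(float(ms.sigma_max))                   # 1.0 (t = 1000)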
|
||||
|
||||
class BaseModel(torch.nn.Module):
|
||||
"""Wrapper around the core MM-DiT model"""
|
||||
def __init__(self, shift=1.0, device=None, dtype=torch.float32, state_dict=None, prefix=""):
|
||||
super().__init__()
|
||||
# Important configuration values can be quickly determined by checking shapes in the source file
|
||||
# Some of these will vary between models (eg 2B vs 8B primarily differ in their depth, but also other details change)
|
||||
patch_size = state_dict[f"{prefix}x_embedder.proj.weight"].shape[2]
|
||||
depth = state_dict[f"{prefix}x_embedder.proj.weight"].shape[0] // 64
|
||||
num_patches = state_dict[f"{prefix}pos_embed"].shape[1]
|
||||
pos_embed_max_size = round(math.sqrt(num_patches))
|
||||
adm_in_channels = state_dict[f"{prefix}y_embedder.mlp.0.weight"].shape[1]
|
||||
context_shape = state_dict[f"{prefix}context_embedder.weight"].shape
|
||||
context_embedder_config = {
|
||||
"target": "torch.nn.Linear",
|
||||
"params": {
|
||||
"in_features": context_shape[1],
|
||||
"out_features": context_shape[0]
|
||||
}
|
||||
}
|
||||
self.diffusion_model = MMDiT(input_size=None, pos_embed_scaling_factor=None, pos_embed_offset=None, pos_embed_max_size=pos_embed_max_size, patch_size=patch_size, in_channels=16, depth=depth, num_patches=num_patches, adm_in_channels=adm_in_channels, context_embedder_config=context_embedder_config, device=device, dtype=dtype)
|
||||
self.model_sampling = ModelSamplingDiscreteFlow(shift=shift)
|
||||
self.depth = depth
|
||||
|
||||
def apply_model(self, x, sigma, c_crossattn=None, y=None):
|
||||
dtype = self.get_dtype()
|
||||
timestep = self.model_sampling.timestep(sigma).float()
|
||||
model_output = self.diffusion_model(x.to(dtype), timestep, context=c_crossattn.to(dtype), y=y.to(dtype)).float()
|
||||
return self.model_sampling.calculate_denoised(sigma, model_output, x)
|
||||
|
||||
def forward(self, *args, **kwargs):
|
||||
return self.apply_model(*args, **kwargs)
|
||||
|
||||
def get_dtype(self):
|
||||
return self.diffusion_model.dtype
|
||||
|
||||
|
||||
class CFGDenoiser(torch.nn.Module):
|
||||
"""Helper for applying CFG Scaling to diffusion outputs"""
|
||||
def __init__(self, model):
|
||||
super().__init__()
|
||||
self.model = model
|
||||
|
||||
def forward(self, x, timestep, cond, uncond, cond_scale):
|
||||
# Run cond and uncond in a batch together
|
||||
batched = self.model.apply_model(torch.cat([x, x]), torch.cat([timestep, timestep]), c_crossattn=torch.cat([cond["c_crossattn"], uncond["c_crossattn"]]), y=torch.cat([cond["y"], uncond["y"]]))
|
||||
# Then split and apply CFG Scaling
|
||||
pos_out, neg_out = batched.chunk(2)
|
||||
scaled = neg_out + (pos_out - neg_out) * cond_scale
|
||||
return scaled
|
||||
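The scaling step above is plain classifier-free guidance: scaled = neg_out + cond_scale * (pos_out - neg_out), so cond_scale = 1 reduces to the conditional prediction and cond_scale = 0 to the unconditional one.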
|
||||
|
||||
class SD3LatentFormat:
|
||||
"""Latents are slightly shifted from center - this class must be called after VAE Decode to correct for the shift"""
|
||||
def __init__(self):
|
||||
self.scale_factor = 1.5305
|
||||
self.shift_factor = 0.0609
|
||||
|
||||
def process_in(self, latent):
|
||||
return (latent - self.shift_factor) * self.scale_factor
|
||||
|
||||
def process_out(self, latent):
|
||||
return (latent / self.scale_factor) + self.shift_factor
|
||||
|
||||
def decode_latent_to_preview(self, x0):
|
||||
"""Quick RGB approximate preview of sd3 latents"""
|
||||
factors = torch.tensor([
|
||||
[-0.0645, 0.0177, 0.1052], [ 0.0028, 0.0312, 0.0650],
|
||||
[ 0.1848, 0.0762, 0.0360], [ 0.0944, 0.0360, 0.0889],
|
||||
[ 0.0897, 0.0506, -0.0364], [-0.0020, 0.1203, 0.0284],
|
||||
[ 0.0855, 0.0118, 0.0283], [-0.0539, 0.0658, 0.1047],
|
||||
[-0.0057, 0.0116, 0.0700], [-0.0412, 0.0281, -0.0039],
|
||||
[ 0.1106, 0.1171, 0.1220], [-0.0248, 0.0682, -0.0481],
|
||||
[ 0.0815, 0.0846, 0.1207], [-0.0120, -0.0055, -0.0867],
|
||||
[-0.0749, -0.0634, -0.0456], [-0.1418, -0.1457, -0.1259]
|
||||
], device="cpu")
|
||||
latent_image = x0[0].permute(1, 2, 0).cpu() @ factors
|
||||
|
||||
latents_ubyte = (((latent_image + 1) / 2)
|
||||
.clamp(0, 1) # change scale from -1..1 to 0..1
|
||||
.mul(0xFF) # to 0..255
|
||||
.byte()).cpu()
|
||||
|
||||
return Image.fromarray(latents_ubyte.numpy())
|
||||
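A roundtrip sketch (illustrative): process_in and process_out are exact inverses of each other, so applying one after the other recovers the latent up to float error.

import torch
fmt = SD3LatentFormat()
z = torch.randn(1, 16, 64, 64)
assert torch.allclose(fmt.process_out(fmt.process_in(z)), z, atol=1e-6)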
|
||||
|
||||
#################################################################################################
|
||||
### K-Diffusion Sampling
|
||||
#################################################################################################
|
||||
|
||||
|
||||
def append_dims(x, target_dims):
|
||||
"""Appends dimensions to the end of a tensor until it has target_dims dimensions."""
|
||||
dims_to_append = target_dims - x.ndim
|
||||
return x[(...,) + (None,) * dims_to_append]
|
||||
|
||||
|
||||
def to_d(x, sigma, denoised):
|
||||
"""Converts a denoiser output to a Karras ODE derivative."""
|
||||
return (x - denoised) / append_dims(sigma, x.ndim)
|
||||
|
||||
|
||||
@torch.no_grad()
|
||||
@torch.autocast("cuda", dtype=torch.float16)
|
||||
def sample_euler(model, x, sigmas, extra_args=None):
|
||||
"""Implements Algorithm 2 (Euler steps) from Karras et al. (2022)."""
|
||||
extra_args = {} if extra_args is None else extra_args
|
||||
s_in = x.new_ones([x.shape[0]])
|
||||
for i in range(len(sigmas) - 1):
|
||||
sigma_hat = sigmas[i]
|
||||
denoised = model(x, sigma_hat * s_in, **extra_args)
|
||||
d = to_d(x, sigma_hat, denoised)
|
||||
dt = sigmas[i + 1] - sigma_hat
|
||||
# Euler method
|
||||
x = x + d * dt
|
||||
return x
|
||||
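Each loop iteration above is one explicit Euler step: d = (x - denoised) / sigma_i and x <- x + d * (sigma_{i+1} - sigma_i); the sigma schedules passed in normally end at 0, so the final step lands on the denoised sample.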
|
||||
|
||||
#################################################################################################
|
||||
### VAE
|
||||
#################################################################################################
|
||||
|
||||
|
||||
def Normalize(in_channels, num_groups=32, dtype=torch.float32, device=None):
|
||||
return torch.nn.GroupNorm(num_groups=num_groups, num_channels=in_channels, eps=1e-6, affine=True, dtype=dtype, device=device)
|
||||
|
||||
|
||||
class ResnetBlock(torch.nn.Module):
|
||||
def __init__(self, *, in_channels, out_channels=None, dtype=torch.float32, device=None):
|
||||
super().__init__()
|
||||
self.in_channels = in_channels
|
||||
out_channels = in_channels if out_channels is None else out_channels
|
||||
self.out_channels = out_channels
|
||||
|
||||
self.norm1 = Normalize(in_channels, dtype=dtype, device=device)
|
||||
self.conv1 = torch.nn.Conv2d(in_channels, out_channels, kernel_size=3, stride=1, padding=1, dtype=dtype, device=device)
|
||||
self.norm2 = Normalize(out_channels, dtype=dtype, device=device)
|
||||
self.conv2 = torch.nn.Conv2d(out_channels, out_channels, kernel_size=3, stride=1, padding=1, dtype=dtype, device=device)
|
||||
if self.in_channels != self.out_channels:
|
||||
self.nin_shortcut = torch.nn.Conv2d(in_channels, out_channels, kernel_size=1, stride=1, padding=0, dtype=dtype, device=device)
|
||||
else:
|
||||
self.nin_shortcut = None
|
||||
self.swish = torch.nn.SiLU(inplace=True)
|
||||
|
||||
def forward(self, x):
|
||||
hidden = x
|
||||
hidden = self.norm1(hidden)
|
||||
hidden = self.swish(hidden)
|
||||
hidden = self.conv1(hidden)
|
||||
hidden = self.norm2(hidden)
|
||||
hidden = self.swish(hidden)
|
||||
hidden = self.conv2(hidden)
|
||||
if self.in_channels != self.out_channels:
|
||||
x = self.nin_shortcut(x)
|
||||
return x + hidden
|
||||
|
||||
|
||||
class AttnBlock(torch.nn.Module):
|
||||
def __init__(self, in_channels, dtype=torch.float32, device=None):
|
||||
super().__init__()
|
||||
self.norm = Normalize(in_channels, dtype=dtype, device=device)
|
||||
self.q = torch.nn.Conv2d(in_channels, in_channels, kernel_size=1, stride=1, padding=0, dtype=dtype, device=device)
|
||||
self.k = torch.nn.Conv2d(in_channels, in_channels, kernel_size=1, stride=1, padding=0, dtype=dtype, device=device)
|
||||
self.v = torch.nn.Conv2d(in_channels, in_channels, kernel_size=1, stride=1, padding=0, dtype=dtype, device=device)
|
||||
self.proj_out = torch.nn.Conv2d(in_channels, in_channels, kernel_size=1, stride=1, padding=0, dtype=dtype, device=device)
|
||||
|
||||
def forward(self, x):
|
||||
hidden = self.norm(x)
|
||||
q = self.q(hidden)
|
||||
k = self.k(hidden)
|
||||
v = self.v(hidden)
|
||||
b, c, h, w = q.shape
|
||||
q, k, v = [einops.rearrange(x, "b c h w -> b 1 (h w) c").contiguous() for x in (q, k, v)]
|
||||
hidden = torch.nn.functional.scaled_dot_product_attention(q, k, v) # scale is dim ** -0.5 per default
|
||||
hidden = einops.rearrange(hidden, "b 1 (h w) c -> b c h w", h=h, w=w, c=c, b=b)
|
||||
hidden = self.proj_out(hidden)
|
||||
return x + hidden
|
||||
|
||||
|
||||
class Downsample(torch.nn.Module):
|
||||
def __init__(self, in_channels, dtype=torch.float32, device=None):
|
||||
super().__init__()
|
||||
self.conv = torch.nn.Conv2d(in_channels, in_channels, kernel_size=3, stride=2, padding=0, dtype=dtype, device=device)
|
||||
|
||||
def forward(self, x):
|
||||
pad = (0,1,0,1)
|
||||
x = torch.nn.functional.pad(x, pad, mode="constant", value=0)
|
||||
x = self.conv(x)
|
||||
return x
|
||||
|
||||
|
||||
class Upsample(torch.nn.Module):
|
||||
def __init__(self, in_channels, dtype=torch.float32, device=None):
|
||||
super().__init__()
|
||||
self.conv = torch.nn.Conv2d(in_channels, in_channels, kernel_size=3, stride=1, padding=1, dtype=dtype, device=device)
|
||||
|
||||
def forward(self, x):
|
||||
x = torch.nn.functional.interpolate(x, scale_factor=2.0, mode="nearest")
|
||||
x = self.conv(x)
|
||||
return x
|
||||
|
||||
|
||||
class VAEEncoder(torch.nn.Module):
|
||||
def __init__(self, ch=128, ch_mult=(1,2,4,4), num_res_blocks=2, in_channels=3, z_channels=16, dtype=torch.float32, device=None):
|
||||
super().__init__()
|
||||
self.num_resolutions = len(ch_mult)
|
||||
self.num_res_blocks = num_res_blocks
|
||||
# downsampling
|
||||
self.conv_in = torch.nn.Conv2d(in_channels, ch, kernel_size=3, stride=1, padding=1, dtype=dtype, device=device)
|
||||
in_ch_mult = (1,) + tuple(ch_mult)
|
||||
self.in_ch_mult = in_ch_mult
|
||||
self.down = torch.nn.ModuleList()
|
||||
for i_level in range(self.num_resolutions):
|
||||
block = torch.nn.ModuleList()
|
||||
attn = torch.nn.ModuleList()
|
||||
block_in = ch*in_ch_mult[i_level]
|
||||
block_out = ch*ch_mult[i_level]
|
||||
for _ in range(num_res_blocks):
|
||||
block.append(ResnetBlock(in_channels=block_in, out_channels=block_out, dtype=dtype, device=device))
|
||||
block_in = block_out
|
||||
down = torch.nn.Module()
|
||||
down.block = block
|
||||
down.attn = attn
|
||||
if i_level != self.num_resolutions - 1:
|
||||
down.downsample = Downsample(block_in, dtype=dtype, device=device)
|
||||
self.down.append(down)
|
||||
# middle
|
||||
self.mid = torch.nn.Module()
|
||||
self.mid.block_1 = ResnetBlock(in_channels=block_in, out_channels=block_in, dtype=dtype, device=device)
|
||||
self.mid.attn_1 = AttnBlock(block_in, dtype=dtype, device=device)
|
||||
self.mid.block_2 = ResnetBlock(in_channels=block_in, out_channels=block_in, dtype=dtype, device=device)
|
||||
# end
|
||||
self.norm_out = Normalize(block_in, dtype=dtype, device=device)
|
||||
self.conv_out = torch.nn.Conv2d(block_in, 2 * z_channels, kernel_size=3, stride=1, padding=1, dtype=dtype, device=device)
|
||||
self.swish = torch.nn.SiLU(inplace=True)
|
||||
|
||||
def forward(self, x):
|
||||
# downsampling
|
||||
hs = [self.conv_in(x)]
|
||||
for i_level in range(self.num_resolutions):
|
||||
for i_block in range(self.num_res_blocks):
|
||||
h = self.down[i_level].block[i_block](hs[-1])
|
||||
hs.append(h)
|
||||
if i_level != self.num_resolutions-1:
|
||||
hs.append(self.down[i_level].downsample(hs[-1]))
|
||||
# middle
|
||||
h = hs[-1]
|
||||
h = self.mid.block_1(h)
|
||||
h = self.mid.attn_1(h)
|
||||
h = self.mid.block_2(h)
|
||||
# end
|
||||
h = self.norm_out(h)
|
||||
h = self.swish(h)
|
||||
h = self.conv_out(h)
|
||||
return h
|
||||
|
||||
|
||||
class VAEDecoder(torch.nn.Module):
|
||||
def __init__(self, ch=128, out_ch=3, ch_mult=(1, 2, 4, 4), num_res_blocks=2, resolution=256, z_channels=16, dtype=torch.float32, device=None):
|
||||
super().__init__()
|
||||
self.num_resolutions = len(ch_mult)
|
||||
self.num_res_blocks = num_res_blocks
|
||||
block_in = ch * ch_mult[self.num_resolutions - 1]
|
||||
curr_res = resolution // 2 ** (self.num_resolutions - 1)
|
||||
# z to block_in
|
||||
self.conv_in = torch.nn.Conv2d(z_channels, block_in, kernel_size=3, stride=1, padding=1, dtype=dtype, device=device)
|
||||
# middle
|
||||
self.mid = torch.nn.Module()
|
||||
self.mid.block_1 = ResnetBlock(in_channels=block_in, out_channels=block_in, dtype=dtype, device=device)
|
||||
self.mid.attn_1 = AttnBlock(block_in, dtype=dtype, device=device)
|
||||
self.mid.block_2 = ResnetBlock(in_channels=block_in, out_channels=block_in, dtype=dtype, device=device)
|
||||
# upsampling
|
||||
self.up = torch.nn.ModuleList()
|
||||
for i_level in reversed(range(self.num_resolutions)):
|
||||
block = torch.nn.ModuleList()
|
||||
block_out = ch * ch_mult[i_level]
|
||||
for _ in range(self.num_res_blocks + 1):
|
||||
block.append(ResnetBlock(in_channels=block_in, out_channels=block_out, dtype=dtype, device=device))
|
||||
block_in = block_out
|
||||
up = torch.nn.Module()
|
||||
up.block = block
|
||||
if i_level != 0:
|
||||
up.upsample = Upsample(block_in, dtype=dtype, device=device)
|
||||
curr_res = curr_res * 2
|
||||
self.up.insert(0, up) # prepend to get consistent order
|
||||
# end
|
||||
self.norm_out = Normalize(block_in, dtype=dtype, device=device)
|
||||
self.conv_out = torch.nn.Conv2d(block_in, out_ch, kernel_size=3, stride=1, padding=1, dtype=dtype, device=device)
|
||||
self.swish = torch.nn.SiLU(inplace=True)
|
||||
|
||||
def forward(self, z):
|
||||
# z to block_in
|
||||
hidden = self.conv_in(z)
|
||||
# middle
|
||||
hidden = self.mid.block_1(hidden)
|
||||
hidden = self.mid.attn_1(hidden)
|
||||
hidden = self.mid.block_2(hidden)
|
||||
# upsampling
|
||||
for i_level in reversed(range(self.num_resolutions)):
|
||||
for i_block in range(self.num_res_blocks + 1):
|
||||
hidden = self.up[i_level].block[i_block](hidden)
|
||||
if i_level != 0:
|
||||
hidden = self.up[i_level].upsample(hidden)
|
||||
# end
|
||||
hidden = self.norm_out(hidden)
|
||||
hidden = self.swish(hidden)
|
||||
hidden = self.conv_out(hidden)
|
||||
return hidden
|
||||
|
||||
|
||||
class SDVAE(torch.nn.Module):
|
||||
def __init__(self, dtype=torch.float32, device=None):
|
||||
super().__init__()
|
||||
self.encoder = VAEEncoder(dtype=dtype, device=device)
|
||||
self.decoder = VAEDecoder(dtype=dtype, device=device)
|
||||
|
||||
@torch.autocast("cuda", dtype=torch.float16)
|
||||
def decode(self, latent):
|
||||
return self.decoder(latent)
|
||||
|
||||
@torch.autocast("cuda", dtype=torch.float16)
|
||||
def encode(self, image):
|
||||
hidden = self.encoder(image)
|
||||
mean, logvar = torch.chunk(hidden, 2, dim=1)
|
||||
logvar = torch.clamp(logvar, -30.0, 20.0)
|
||||
std = torch.exp(0.5 * logvar)
|
||||
return mean + std * torch.randn_like(mean)
|
||||
96
modules/models/sd3/sd3_model.py
Executable file
@@ -0,0 +1,96 @@
|
||||
import contextlib
|
||||
|
||||
import torch
|
||||
|
||||
import k_diffusion
|
||||
from modules.models.sd3.sd3_impls import BaseModel, SDVAE, SD3LatentFormat
|
||||
from modules.models.sd3.sd3_cond import SD3Cond
|
||||
|
||||
from modules import shared, devices
|
||||
|
||||
|
||||
class SD3Denoiser(k_diffusion.external.DiscreteSchedule):
|
||||
def __init__(self, inner_model, sigmas):
|
||||
super().__init__(sigmas, quantize=shared.opts.enable_quantization)
|
||||
self.inner_model = inner_model
|
||||
|
||||
def forward(self, input, sigma, **kwargs):
|
||||
return self.inner_model.apply_model(input, sigma, **kwargs)
|
||||
|
||||
|
||||
class SD3Inferencer(torch.nn.Module):
|
||||
def __init__(self, state_dict, shift=3, use_ema=False):
|
||||
super().__init__()
|
||||
|
||||
self.shift = shift
|
||||
|
||||
with torch.no_grad():
|
||||
self.model = BaseModel(shift=shift, state_dict=state_dict, prefix="model.diffusion_model.", device="cpu", dtype=devices.dtype)
|
||||
self.first_stage_model = SDVAE(device="cpu", dtype=devices.dtype_vae)
|
||||
self.first_stage_model.dtype = self.model.diffusion_model.dtype
|
||||
|
||||
self.alphas_cumprod = 1 / (self.model.model_sampling.sigmas ** 2 + 1)
|
||||
|
||||
self.text_encoders = SD3Cond()
|
||||
self.cond_stage_key = 'txt'
|
||||
|
||||
self.parameterization = "eps"
|
||||
self.model.conditioning_key = "crossattn"
|
||||
|
||||
self.latent_format = SD3LatentFormat()
|
||||
self.latent_channels = 16
|
||||
|
||||
@property
|
||||
def cond_stage_model(self):
|
||||
return self.text_encoders
|
||||
|
||||
def before_load_weights(self, state_dict):
|
||||
self.cond_stage_model.before_load_weights(state_dict)
|
||||
|
||||
def ema_scope(self):
|
||||
return contextlib.nullcontext()
|
||||
|
||||
def get_learned_conditioning(self, batch: list[str]):
|
||||
return self.cond_stage_model(batch)
|
||||
|
||||
def apply_model(self, x, t, cond):
|
||||
return self.model(x, t, c_crossattn=cond['crossattn'], y=cond['vector'])
|
||||
|
||||
def decode_first_stage(self, latent):
|
||||
latent = self.latent_format.process_out(latent)
|
||||
return self.first_stage_model.decode(latent)
|
||||
|
||||
def encode_first_stage(self, image):
|
||||
latent = self.first_stage_model.encode(image)
|
||||
return self.latent_format.process_in(latent)
|
||||
|
||||
def get_first_stage_encoding(self, x):
|
||||
return x
|
||||
|
||||
def create_denoiser(self):
|
||||
return SD3Denoiser(self, self.model.model_sampling.sigmas)
|
||||
|
||||
def medvram_fields(self):
|
||||
return [
|
||||
(self, 'first_stage_model'),
|
||||
(self, 'text_encoders'),
|
||||
(self, 'model'),
|
||||
]
|
||||
|
||||
def add_noise_to_latent(self, x, noise, amount):
|
||||
return x * (1 - amount) + noise * amount
|
||||
|
||||
def fix_dimensions(self, width, height):
|
||||
return width // 16 * 16, height // 16 * 16
|
||||
|
||||
def diffusers_weight_mapping(self):
|
||||
for i in range(self.model.depth):
|
||||
yield f"transformer.transformer_blocks.{i}.attn.to_q", f"diffusion_model_joint_blocks_{i}_x_block_attn_qkv_q_proj"
|
||||
yield f"transformer.transformer_blocks.{i}.attn.to_k", f"diffusion_model_joint_blocks_{i}_x_block_attn_qkv_k_proj"
|
||||
yield f"transformer.transformer_blocks.{i}.attn.to_v", f"diffusion_model_joint_blocks_{i}_x_block_attn_qkv_v_proj"
|
||||
yield f"transformer.transformer_blocks.{i}.attn.to_out.0", f"diffusion_model_joint_blocks_{i}_x_block_attn_proj"
|
||||
|
||||
yield f"transformer.transformer_blocks.{i}.attn.add_q_proj", f"diffusion_model_joint_blocks_{i}_context_block.attn_qkv_q_proj"
|
||||
yield f"transformer.transformer_blocks.{i}.attn.add_k_proj", f"diffusion_model_joint_blocks_{i}_context_block.attn_qkv_k_proj"
|
||||
yield f"transformer.transformer_blocks.{i}.attn.add_v_proj", f"diffusion_model_joint_blocks_{i}_context_block.attn_qkv_v_proj"
|
||||
yield f"transformer.transformer_blocks.{i}.attn.add_out_proj.0", f"diffusion_model_joint_blocks_{i}_context_block_attn_proj"
|
||||
30
modules/ngrok.py
Executable file
@@ -0,0 +1,30 @@
import ngrok

# Connect to ngrok for ingress
def connect(token, port, options):
    account = None
    if token is None:
        token = 'None'
    else:
        if ':' in token:
            # token = authtoken:username:password
            token, username, password = token.split(':', 2)
            account = f"{username}:{password}"

    # For all options see: https://github.com/ngrok/ngrok-py/blob/main/examples/ngrok-connect-full.py
    if not options.get('authtoken_from_env'):
        options['authtoken'] = token
    if account:
        options['basic_auth'] = account
    if not options.get('session_metadata'):
        options['session_metadata'] = 'stable-diffusion-webui'

    try:
        public_url = ngrok.connect(f"127.0.0.1:{port}", **options).url()
    except Exception as e:
        print(f'Invalid ngrok authtoken? ngrok connection aborted due to: {e}\n'
              f'Your token: {token}, get the right one on https://dashboard.ngrok.com/get-started/your-authtoken')
    else:
        print(f'ngrok connected to localhost:{port}! URL: {public_url}\n'
              'You can use this link after the launch is complete.')
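A usage sketch under assumed values (the token, credentials, and port are placeholders):

connect("2abc...xyz:admin:hunter2", 7860, {"session_metadata": "my-webui"})
# opens an ingress tunnel to 127.0.0.1:7860 protected with basic auth admin:hunter2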
31
modules/npu_specific.py
Executable file
@@ -0,0 +1,31 @@
# import importlib
# import torch
#
# from modules import shared
#
#
# def check_for_npu():
#     if importlib.util.find_spec("torch_npu") is None:
#         return False
#     import torch_npu
#
#     try:
#         # Will raise a RuntimeError if no NPU is found
#         _ = torch_npu.npu.device_count()
#         return torch.npu.is_available()
#     except RuntimeError:
#         return False
#
#
# def get_npu_device_string():
#     if shared.cmd_opts.device_id is not None:
#         return f"npu:{shared.cmd_opts.device_id}"
#     return "npu:0"
#
#
# def torch_npu_gc():
#     with torch.npu.device(get_npu_device_string()):
#         torch.npu.empty_cache()
#
#
# has_npu = check_for_npu()
336
modules/options.py
Executable file
@@ -0,0 +1,336 @@
|
||||
import os
|
||||
import json
|
||||
import sys
|
||||
from dataclasses import dataclass
|
||||
|
||||
import gradio as gr
|
||||
|
||||
from modules import errors
|
||||
from modules.shared_cmd_options import cmd_opts
|
||||
from modules.paths_internal import script_path
|
||||
|
||||
|
||||
class OptionInfo:
|
||||
def __init__(self, default=None, label="", component=None, component_args=None, onchange=None, section=None, refresh=None, comment_before='', comment_after='', infotext=None, restrict_api=False, category_id=None):
|
||||
self.default = default
|
||||
self.label = label
|
||||
self.component = component
|
||||
self.component_args = component_args
|
||||
self.onchange = onchange
|
||||
self.section = section
|
||||
self.category_id = category_id
|
||||
self.refresh = refresh
|
||||
self.do_not_save = False
|
||||
|
||||
self.comment_before = comment_before
|
||||
"""HTML text that will be added after label in UI"""
|
||||
|
||||
self.comment_after = comment_after
|
||||
"""HTML text that will be added before label in UI"""
|
||||
|
||||
self.infotext = infotext
|
||||
|
||||
self.restrict_api = restrict_api
|
||||
"""If True, the setting will not be accessible via API"""
|
||||
|
||||
def link(self, label, url):
|
||||
self.comment_before += f"[<a href='{url}' target='_blank'>{label}</a>]"
|
||||
return self
|
||||
|
||||
def js(self, label, js_func):
|
||||
self.comment_before += f"[<a onclick='{js_func}(); return false'>{label}</a>]"
|
||||
return self
|
||||
|
||||
def info(self, info):
|
||||
self.comment_after += f"<span class='info'>({info})</span>"
|
||||
return self
|
||||
|
||||
def html(self, html):
|
||||
self.comment_after += html
|
||||
return self
|
||||
|
||||
def needs_restart(self):
|
||||
self.comment_after += " <span class='info'>(requires restart)</span>"
|
||||
return self
|
||||
|
||||
def needs_reload_ui(self):
|
||||
self.comment_after += " <span class='info'>(requires Reload UI)</span>"
|
||||
return self
|
||||
|
||||
|
||||
class OptionHTML(OptionInfo):
|
||||
def __init__(self, text):
|
||||
super().__init__(str(text).strip(), label='', component=lambda **kwargs: gr.HTML(elem_classes="settings-info", **kwargs))
|
||||
|
||||
self.do_not_save = True
|
||||
|
||||
|
||||
def options_section(section_identifier, options_dict):
|
||||
for v in options_dict.values():
|
||||
if len(section_identifier) == 2:
|
||||
v.section = section_identifier
|
||||
elif len(section_identifier) == 3:
|
||||
v.section = section_identifier[0:2]
|
||||
v.category_id = section_identifier[2]
|
||||
|
||||
return options_dict
|
||||
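An illustrative call (the option names are made up): a two-element identifier sets only the section, a three-element one also assigns a category id.

opts_dict = options_section(("saving", "Saving images", "image_saving"), {
    "jpeg_quality": OptionInfo(80, "JPEG quality"),
})
# every OptionInfo now has .section == ("saving", "Saving images") and .category_id == "image_saving"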
|
||||
|
||||
options_builtin_fields = {"data_labels", "data", "restricted_opts", "typemap"}
|
||||
|
||||
|
||||
class Options:
|
||||
typemap = {int: float}
|
||||
|
||||
def __init__(self, data_labels: dict[str, OptionInfo], restricted_opts):
|
||||
self.data_labels = data_labels
|
||||
self.data = {k: v.default for k, v in self.data_labels.items() if not v.do_not_save}
|
||||
self.restricted_opts = restricted_opts
|
||||
|
||||
def __setattr__(self, key, value):
|
||||
if key in options_builtin_fields:
|
||||
return super(Options, self).__setattr__(key, value)
|
||||
|
||||
if self.data is not None:
|
||||
if key in self.data or key in self.data_labels:
|
||||
|
||||
# Check that settings aren't globally frozen
|
||||
assert not cmd_opts.freeze_settings, "changing settings is disabled"
|
||||
|
||||
# Get the info related to the setting being changed
|
||||
info = self.data_labels.get(key, None)
|
||||
if info.do_not_save:
|
||||
return
|
||||
|
||||
# Restrict component arguments
|
||||
comp_args = info.component_args if info else None
|
||||
if isinstance(comp_args, dict) and comp_args.get('visible', True) is False:
|
||||
raise RuntimeError(f"not possible to set '{key}' because it is restricted")
|
||||
|
||||
# Check that this section isn't frozen
|
||||
if cmd_opts.freeze_settings_in_sections is not None:
|
||||
frozen_sections = list(map(str.strip, cmd_opts.freeze_settings_in_sections.split(','))) # Trim whitespace from section names
|
||||
section_key = info.section[0]
|
||||
section_name = info.section[1]
|
||||
assert section_key not in frozen_sections, f"not possible to set '{key}' because settings in section '{section_name}' ({section_key}) are frozen with --freeze-settings-in-sections"
|
||||
|
||||
# Check that this section of the settings isn't frozen
|
||||
if cmd_opts.freeze_specific_settings is not None:
|
||||
frozen_keys = list(map(str.strip, cmd_opts.freeze_specific_settings.split(','))) # Trim whitespace from setting keys
|
||||
assert key not in frozen_keys, f"not possible to set '{key}' because this setting is frozen with --freeze-specific-settings"
|
||||
|
||||
# Check shorthand option which disables editing options in "saving-paths"
|
||||
if cmd_opts.hide_ui_dir_config and key in self.restricted_opts:
|
||||
raise RuntimeError(f"not possible to set '{key}' because it is restricted with --hide_ui_dir_config")
|
||||
|
||||
self.data[key] = value
|
||||
return
|
||||
|
||||
return super(Options, self).__setattr__(key, value)
|
||||
|
||||
def __getattr__(self, item):
|
||||
if item in options_builtin_fields:
|
||||
return super(Options, self).__getattribute__(item)
|
||||
|
||||
if self.data is not None:
|
||||
if item in self.data:
|
||||
return self.data[item]
|
||||
|
||||
if item in self.data_labels:
|
||||
return self.data_labels[item].default
|
||||
|
||||
return super(Options, self).__getattribute__(item)
|
||||
|
||||
def set(self, key, value, is_api=False, run_callbacks=True):
|
||||
"""sets an option and calls its onchange callback, returning True if the option changed and False otherwise"""
|
||||
|
||||
oldval = self.data.get(key, None)
|
||||
if oldval == value:
|
||||
return False
|
||||
|
||||
option = self.data_labels[key]
|
||||
if option.do_not_save:
|
||||
return False
|
||||
|
||||
if is_api and option.restrict_api:
|
||||
return False
|
||||
|
||||
try:
|
||||
setattr(self, key, value)
|
||||
except RuntimeError:
|
||||
return False
|
||||
|
||||
if run_callbacks and option.onchange is not None:
|
||||
try:
|
||||
option.onchange()
|
||||
except Exception as e:
|
||||
errors.display(e, f"changing setting {key} to {value}")
|
||||
setattr(self, key, oldval)
|
||||
return False
|
||||
|
||||
return True
|
||||
|
||||
def get_default(self, key):
|
||||
"""returns the default value for the key"""
|
||||
|
||||
data_label = self.data_labels.get(key)
|
||||
if data_label is None:
|
||||
return None
|
||||
|
||||
return data_label.default
|
||||
|
||||
def save(self, filename):
|
||||
assert not cmd_opts.freeze_settings, "saving settings is disabled"
|
||||
|
||||
with open(filename, "w", encoding="utf8") as file:
|
||||
json.dump(self.data, file, indent=4, ensure_ascii=False)
|
||||
|
||||
def same_type(self, x, y):
|
||||
if x is None or y is None:
|
||||
return True
|
||||
|
||||
type_x = self.typemap.get(type(x), type(x))
|
||||
type_y = self.typemap.get(type(y), type(y))
|
||||
|
||||
return type_x == type_y
|
||||
|
||||
def load(self, filename):
|
||||
try:
|
||||
with open(filename, "r", encoding="utf8") as file:
|
||||
self.data = json.load(file)
|
||||
except FileNotFoundError:
|
||||
self.data = {}
|
||||
except Exception:
|
||||
errors.report(f'\nCould not load settings\nThe config file "{filename}" is likely corrupted\nIt has been moved to the "tmp/config.json"\nReverting config to default\n\n', exc_info=True)
|
||||
os.replace(filename, os.path.join(script_path, "tmp", "config.json"))
|
||||
self.data = {}
|
||||
# 1.6.0 VAE defaults
|
||||
if self.data.get('sd_vae_as_default') is not None and self.data.get('sd_vae_overrides_per_model_preferences') is None:
|
||||
self.data['sd_vae_overrides_per_model_preferences'] = not self.data.get('sd_vae_as_default')
|
||||
|
||||
# 1.1.1 quicksettings list migration
|
||||
if self.data.get('quicksettings') is not None and self.data.get('quicksettings_list') is None:
|
||||
self.data['quicksettings_list'] = [i.strip() for i in self.data.get('quicksettings').split(',')]
|
||||
|
||||
# 1.4.0 ui_reorder
|
||||
if isinstance(self.data.get('ui_reorder'), str) and self.data.get('ui_reorder') and "ui_reorder_list" not in self.data:
|
||||
self.data['ui_reorder_list'] = [i.strip() for i in self.data.get('ui_reorder').split(',')]
|
||||
|
||||
bad_settings = 0
|
||||
for k, v in self.data.items():
|
||||
info = self.data_labels.get(k, None)
|
||||
if info is not None and not self.same_type(info.default, v):
|
||||
print(f"Warning: bad setting value: {k}: {v} ({type(v).__name__}; expected {type(info.default).__name__})", file=sys.stderr)
|
||||
bad_settings += 1
|
||||
|
||||
if bad_settings > 0:
|
||||
print(f"The program is likely to not work with bad settings.\nSettings file: {filename}\nEither fix the file, or delete it and restart.", file=sys.stderr)
|
||||
|
||||
def onchange(self, key, func, call=True):
|
||||
item = self.data_labels.get(key)
|
||||
item.onchange = func
|
||||
|
||||
if call:
|
||||
func()
|
||||
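A short sketch of how set() and onchange() work together on the global Options instance (exposed as shared.opts in the rest of the program). The option key is hypothetical and would need to have been registered via options_section first:

# Hypothetical illustration -- not part of this diff.
from modules import shared

def on_example_change():
    print("example_option is now", shared.opts.example_option)

shared.opts.onchange("example_option", on_example_change, call=False)

# returns True only if the value actually changed and the callback ran without error
changed = shared.opts.set("example_option", 42)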
|
||||
def dumpjson(self):
|
||||
d = {k: self.data.get(k, v.default) for k, v in self.data_labels.items()}
|
||||
d["_comments_before"] = {k: v.comment_before for k, v in self.data_labels.items() if v.comment_before is not None}
|
||||
d["_comments_after"] = {k: v.comment_after for k, v in self.data_labels.items() if v.comment_after is not None}
|
||||
|
||||
item_categories = {}
|
||||
for item in self.data_labels.values():
|
||||
if item.section[0] is None:
|
||||
continue
|
||||
|
||||
category = categories.mapping.get(item.category_id)
|
||||
category = "Uncategorized" if category is None else category.label
|
||||
if category not in item_categories:
|
||||
item_categories[category] = item.section[1]
|
||||
|
||||
# _categories is a list of pairs: [section, category]. Each section (a setting page) will get a special heading above it with the category as text.
|
||||
d["_categories"] = [[v, k] for k, v in item_categories.items()] + [["Defaults", "Other"]]
|
||||
|
||||
return json.dumps(d)
|
||||
|
||||
def add_option(self, key, info):
|
||||
self.data_labels[key] = info
|
||||
if key not in self.data and not info.do_not_save:
|
||||
self.data[key] = info.default
|
||||
|
||||
def reorder(self):
|
||||
"""Reorder settings so that:
|
||||
- all items related to a section always go together
|
||||
- all sections belonging to a category go together
|
||||
- sections inside a category are ordered alphabetically
|
||||
- categories are ordered by creation order
|
||||
|
||||
Category is a superset of sections: for category "postprocessing" there could be multiple sections: "face restoration", "upscaling".
|
||||
|
||||
This function also changes items' category_id so that all items belonging to a section have the same category_id.
|
||||
"""
|
||||
|
||||
category_ids = {}
|
||||
section_categories = {}
|
||||
|
||||
settings_items = self.data_labels.items()
|
||||
for _, item in settings_items:
|
||||
if item.section not in section_categories:
|
||||
section_categories[item.section] = item.category_id
|
||||
|
||||
for _, item in settings_items:
|
||||
item.category_id = section_categories.get(item.section)
|
||||
|
||||
for category_id in categories.mapping:
|
||||
if category_id not in category_ids:
|
||||
category_ids[category_id] = len(category_ids)
|
||||
|
||||
def sort_key(x):
|
||||
item: OptionInfo = x[1]
|
||||
category_order = category_ids.get(item.category_id, len(category_ids))
|
||||
section_order = item.section[1]
|
||||
|
||||
return category_order, section_order
|
||||
|
||||
self.data_labels = dict(sorted(settings_items, key=sort_key))
|
||||
|
||||
def cast_value(self, key, value):
|
||||
"""casts an arbitrary to the same type as this setting's value with key
|
||||
Example: cast_value("eta_noise_seed_delta", "12") -> returns 12 (an int rather than str)
|
||||
"""
|
||||
|
||||
if value is None:
|
||||
return None
|
||||
|
||||
default_value = self.data_labels[key].default
|
||||
if default_value is None:
|
||||
default_value = getattr(self, key, None)
|
||||
if default_value is None:
|
||||
return None
|
||||
|
||||
expected_type = type(default_value)
|
||||
if expected_type == bool and value == "False":
|
||||
value = False
|
||||
else:
|
||||
value = expected_type(value)
|
||||
|
||||
return value
|
||||
|
||||
|
||||
@dataclass
|
||||
class OptionsCategory:
|
||||
id: str
|
||||
label: str
|
||||
|
||||
class OptionsCategories:
|
||||
def __init__(self):
|
||||
self.mapping = {}
|
||||
|
||||
def register_category(self, category_id, label):
|
||||
if category_id in self.mapping:
|
||||
return category_id
|
||||
|
||||
self.mapping[category_id] = OptionsCategory(category_id, label)
|
||||
|
||||
|
||||
categories = OptionsCategories()
|
||||
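A sketch of how the pieces above fit together: register_category() declares a category heading, and a three-element identifier passed to options_section() assigns that settings page to the category. The identifiers are hypothetical, and the import path assumes this file lives at modules/options.py:

# Hypothetical illustration -- not part of this diff.
from modules import shared
from modules.options import OptionInfo, categories, options_section

categories.register_category("example_category", "Example category")

shared.options_templates.update(options_section(("example_page", "Example page", "example_category"), {
    "example_flag": OptionInfo(True, "Enable the example feature"),
}))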
64
modules/patches.py
Executable file
64
modules/patches.py
Executable file
@@ -0,0 +1,64 @@
|
||||
from collections import defaultdict
|
||||
|
||||
|
||||
def patch(key, obj, field, replacement):
|
||||
"""Replaces a function in a module or a class.
|
||||
|
||||
Also stores the original function in this module, so it can later be retrieved via original(key, obj, field).
|
||||
If the function is already replaced by this caller (key), an exception is raised -- use undo() before that.
|
||||
|
||||
Arguments:
|
||||
key: identifying information for who is doing the replacement. You can use __name__.
|
||||
obj: the module or the class
|
||||
field: name of the function as a string
|
||||
replacement: the new function
|
||||
|
||||
Returns:
|
||||
the original function
|
||||
"""
|
||||
|
||||
patch_key = (obj, field)
|
||||
if patch_key in originals[key]:
|
||||
raise RuntimeError(f"patch for {field} is already applied")
|
||||
|
||||
original_func = getattr(obj, field)
|
||||
originals[key][patch_key] = original_func
|
||||
|
||||
setattr(obj, field, replacement)
|
||||
|
||||
return original_func
|
||||
|
||||
|
||||
def undo(key, obj, field):
|
||||
"""Undoes the peplacement by the patch().
|
||||
|
||||
If the function is not replaced, raises an exception.
|
||||
|
||||
Arguments:
|
||||
key: identifying information for who is doing the replacement. You can use __name__.
|
||||
obj: the module or the class
|
||||
field: name of the function as a string
|
||||
|
||||
Returns:
|
||||
Always None
|
||||
"""
|
||||
|
||||
patch_key = (obj, field)
|
||||
|
||||
if patch_key not in originals[key]:
|
||||
raise RuntimeError(f"there is no patch for {field} to undo")
|
||||
|
||||
original_func = originals[key].pop(patch_key)
|
||||
setattr(obj, field, original_func)
|
||||
|
||||
return None
|
||||
|
||||
|
||||
def original(key, obj, field):
|
||||
"""Returns the original function for the patch created by the patch() function"""
|
||||
patch_key = (obj, field)
|
||||
|
||||
return originals[key].get(patch_key, None)
|
||||
|
||||
|
||||
originals = defaultdict(dict)
|
||||
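A small, self-contained sketch of the patch/original/undo lifecycle defined above; json.dumps is just a stand-in target, not something this module actually patches:

# Hypothetical illustration -- not part of this diff.
import json
from modules import patches

def logging_dumps(*args, **kwargs):
    print("json.dumps called")
    return patches.original(__name__, json, "dumps")(*args, **kwargs)

patches.patch(__name__, json, "dumps", logging_dumps)  # json.dumps now logs every call
print(json.dumps({"a": 1}))
patches.undo(__name__, json, "dumps")                  # restore the original function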
29
modules/paths.py
Executable file
29
modules/paths.py
Executable file
@@ -0,0 +1,29 @@
|
||||
import os
|
||||
import sys
|
||||
from modules.paths_internal import models_path, script_path, data_path, extensions_dir, extensions_builtin_dir, cwd # noqa: F401
|
||||
|
||||
|
||||
sys.path.insert(0, script_path)
|
||||
|
||||
sd_path = os.path.dirname(__file__)
|
||||
|
||||
path_dirs = [
|
||||
(os.path.join(sd_path, '../repositories/BLIP'), 'models/blip.py', 'BLIP', []),
|
||||
(os.path.join(sd_path, '../packages_3rdparty'), 'gguf/quants.py', 'packages_3rdparty', []),
|
||||
# (os.path.join(sd_path, '../repositories/k-diffusion'), 'k_diffusion/sampling.py', 'k_diffusion', ["atstart"]),
|
||||
(os.path.join(sd_path, '../repositories/huggingface_guess'), 'huggingface_guess/detection.py', 'huggingface_guess', []),
|
||||
]
|
||||
|
||||
paths = {}
|
||||
|
||||
for d, must_exist, what, options in path_dirs:
|
||||
must_exist_path = os.path.abspath(os.path.join(script_path, d, must_exist))
|
||||
if not os.path.exists(must_exist_path):
|
||||
print(f"Warning: {what} not found at path {must_exist_path}", file=sys.stderr)
|
||||
else:
|
||||
d = os.path.abspath(d)
|
||||
if "atstart" in options:
|
||||
sys.path.insert(0, d)
|
||||
else:
|
||||
sys.path.append(d)
|
||||
paths[what] = d
|
||||
38
modules/paths_internal.py
Executable file
38
modules/paths_internal.py
Executable file
@@ -0,0 +1,38 @@
|
||||
"""this module defines internal paths used by program and is safe to import before dependencies are installed in launch.py"""
|
||||
|
||||
import argparse
|
||||
import os
|
||||
import sys
|
||||
import shlex
|
||||
from pathlib import Path
|
||||
|
||||
|
||||
normalized_filepath = lambda filepath: str(Path(filepath).absolute())
|
||||
|
||||
commandline_args = os.environ.get('COMMANDLINE_ARGS', "")
|
||||
sys.argv += shlex.split(commandline_args)
|
||||
|
||||
cwd = os.getcwd()
|
||||
modules_path = os.path.dirname(os.path.realpath(__file__))
|
||||
script_path = os.path.dirname(modules_path)
|
||||
|
||||
sd_configs_path = os.path.join(script_path, "configs")
|
||||
sd_default_config = os.path.join(sd_configs_path, "v1-inference.yaml")
|
||||
sd_model_file = os.path.join(script_path, 'model.ckpt')
|
||||
default_sd_model_file = sd_model_file
|
||||
|
||||
# Parse the --data-dir flag first so we can use it as a base for our other argument default values
|
||||
parser_pre = argparse.ArgumentParser(add_help=False)
|
||||
parser_pre.add_argument("--data-dir", type=str, default=os.path.dirname(modules_path), help="base path where all user data is stored", )
|
||||
parser_pre.add_argument("--models-dir", type=str, default=None, help="base path where models are stored; overrides --data-dir", )
|
||||
cmd_opts_pre = parser_pre.parse_known_args()[0]
|
||||
|
||||
data_path = cmd_opts_pre.data_dir
|
||||
|
||||
models_path = cmd_opts_pre.models_dir if cmd_opts_pre.models_dir else os.path.join(data_path, "models")
|
||||
extensions_dir = os.path.join(data_path, "extensions")
|
||||
extensions_builtin_dir = os.path.join(script_path, "extensions-builtin")
|
||||
config_states_dir = os.path.join(script_path, "config_states")
|
||||
default_output_dir = os.path.join(data_path, "outputs")
|
||||
|
||||
roboto_ttf_file = os.path.join(modules_path, 'Roboto-Regular.ttf')
|
||||
165
modules/postprocessing.py
Executable file
165
modules/postprocessing.py
Executable file
@@ -0,0 +1,165 @@
|
||||
import os
|
||||
|
||||
from PIL import Image
|
||||
|
||||
from modules import shared, images, devices, scripts, scripts_postprocessing, ui_common, infotext_utils
|
||||
from modules.shared import opts
|
||||
|
||||
|
||||
def run_postprocessing(extras_mode, image, image_folder, input_dir, output_dir, show_extras_results, *args, save_output: bool = True):
|
||||
devices.torch_gc()
|
||||
|
||||
shared.state.begin(job="extras")
|
||||
|
||||
outputs = []
|
||||
|
||||
if isinstance(image, dict):
|
||||
image = image["composite"]
|
||||
|
||||
def get_images(extras_mode, image, image_folder, input_dir):
|
||||
if extras_mode == 1:
|
||||
for img in image_folder:
|
||||
if isinstance(img, Image.Image):
|
||||
image = images.fix_image(img)
|
||||
fn = ''
|
||||
else:
|
||||
image = images.read(os.path.abspath(img.name))
|
||||
fn = os.path.splitext(img.name)[0]
|
||||
yield image, fn
|
||||
elif extras_mode == 2:
|
||||
assert not shared.cmd_opts.hide_ui_dir_config, '--hide-ui-dir-config option must be disabled'
|
||||
assert input_dir, 'input directory not selected'
|
||||
|
||||
image_list = shared.listfiles(input_dir)
|
||||
for filename in image_list:
|
||||
yield filename, filename
|
||||
else:
|
||||
assert image, 'image not selected'
|
||||
yield image, None
|
||||
|
||||
if extras_mode == 2 and output_dir != '':
|
||||
outpath = output_dir
|
||||
else:
|
||||
outpath = opts.outdir_samples or opts.outdir_extras_samples
|
||||
|
||||
infotext = ''
|
||||
|
||||
data_to_process = list(get_images(extras_mode, image, image_folder, input_dir))
|
||||
shared.state.job_count = len(data_to_process)
|
||||
|
||||
for image_placeholder, name in data_to_process:
|
||||
image_data: Image.Image
|
||||
|
||||
shared.state.nextjob()
|
||||
shared.state.textinfo = name
|
||||
shared.state.skipped = False
|
||||
|
||||
if shared.state.interrupted or shared.state.stopping_generation:
|
||||
break
|
||||
|
||||
if isinstance(image_placeholder, str):
|
||||
try:
|
||||
image_data = images.read(image_placeholder)
|
||||
except Exception:
|
||||
continue
|
||||
else:
|
||||
image_data = image_placeholder
|
||||
|
||||
image_data = image_data if image_data.mode in ("RGBA", "RGB") else image_data.convert("RGB")
|
||||
|
||||
parameters, existing_pnginfo = images.read_info_from_image(image_data)
|
||||
if parameters:
|
||||
existing_pnginfo["parameters"] = parameters
|
||||
|
||||
initial_pp = scripts_postprocessing.PostprocessedImage(image_data)
|
||||
|
||||
scripts.scripts_postproc.run(initial_pp, args)
|
||||
|
||||
if shared.state.skipped:
|
||||
continue
|
||||
|
||||
used_suffixes = {}
|
||||
for pp in [initial_pp, *initial_pp.extra_images]:
|
||||
suffix = pp.get_suffix(used_suffixes)
|
||||
|
||||
if opts.use_original_name_batch and name is not None:
|
||||
basename = os.path.splitext(os.path.basename(name))[0]
|
||||
forced_filename = basename + suffix
|
||||
else:
|
||||
basename = ''
|
||||
forced_filename = None
|
||||
|
||||
infotext = ", ".join([k if k == v else f'{k}: {infotext_utils.quote(v)}' for k, v in pp.info.items() if v is not None])
|
||||
|
||||
if opts.enable_pnginfo:
|
||||
pp.image.info = existing_pnginfo
|
||||
|
||||
shared.state.assign_current_image(pp.image)
|
||||
|
||||
if save_output:
|
||||
fullfn, _ = images.save_image(pp.image, path=outpath, basename=basename, extension=opts.samples_format, info=infotext, short_filename=True, no_prompt=True, grid=False, pnginfo_section_name="postprocessing", existing_info=existing_pnginfo, forced_filename=forced_filename, suffix=suffix)
|
||||
|
||||
if pp.caption:
|
||||
caption_filename = os.path.splitext(fullfn)[0] + ".txt"
|
||||
existing_caption = ""
|
||||
try:
|
||||
with open(caption_filename, encoding="utf8") as file:
|
||||
existing_caption = file.read().strip()
|
||||
except FileNotFoundError:
|
||||
pass
|
||||
|
||||
action = shared.opts.postprocessing_existing_caption_action
|
||||
if action == 'Prepend' and existing_caption:
|
||||
caption = f"{existing_caption} {pp.caption}"
|
||||
elif action == 'Append' and existing_caption:
|
||||
caption = f"{pp.caption} {existing_caption}"
|
||||
elif action == 'Keep' and existing_caption:
|
||||
caption = existing_caption
|
||||
else:
|
||||
caption = pp.caption
|
||||
|
||||
caption = caption.strip()
|
||||
if caption:
|
||||
with open(caption_filename, "w", encoding="utf8") as file:
|
||||
file.write(caption)
|
||||
|
||||
if extras_mode != 2 or show_extras_results:
|
||||
outputs.append(pp.image)
|
||||
|
||||
devices.torch_gc()
|
||||
shared.state.end()
|
||||
return outputs, ui_common.plaintext_to_html(infotext), ''
|
||||
|
||||
|
||||
def run_postprocessing_webui(id_task, *args, **kwargs):
|
||||
return run_postprocessing(*args, **kwargs)
|
||||
|
||||
|
||||
def run_extras(extras_mode, resize_mode, image, image_folder, input_dir, output_dir, show_extras_results, gfpgan_visibility, codeformer_visibility, codeformer_weight, upscaling_resize, upscaling_resize_w, upscaling_resize_h, upscaling_crop, extras_upscaler_1, extras_upscaler_2, extras_upscaler_2_visibility, upscale_first: bool, save_output: bool = True, max_side_length: int = 0):
|
||||
"""old handler for API"""
|
||||
|
||||
args = scripts.scripts_postproc.create_args_for_run({
|
||||
"Upscale": {
|
||||
"upscale_enabled": True,
|
||||
"upscale_mode": resize_mode,
|
||||
"upscale_by": upscaling_resize,
|
||||
"max_side_length": max_side_length,
|
||||
"upscale_to_width": upscaling_resize_w,
|
||||
"upscale_to_height": upscaling_resize_h,
|
||||
"upscale_crop": upscaling_crop,
|
||||
"upscaler_1_name": extras_upscaler_1,
|
||||
"upscaler_2_name": extras_upscaler_2,
|
||||
"upscaler_2_visibility": extras_upscaler_2_visibility,
|
||||
},
|
||||
"GFPGAN": {
|
||||
"enable": True,
|
||||
"gfpgan_visibility": gfpgan_visibility,
|
||||
},
|
||||
"CodeFormer": {
|
||||
"enable": True,
|
||||
"codeformer_visibility": codeformer_visibility,
|
||||
"codeformer_weight": codeformer_weight,
|
||||
},
|
||||
})
|
||||
|
||||
return run_postprocessing(extras_mode, image, image_folder, input_dir, output_dir, show_extras_results, *args, save_output=save_output)
|
||||
1883
modules/processing.py
Executable file
1883
modules/processing.py
Executable file
File diff suppressed because it is too large
13
modules/processing_scripts/comments.py
Executable file
13
modules/processing_scripts/comments.py
Executable file
@@ -0,0 +1,13 @@
|
||||
from modules import shared
|
||||
import re
|
||||
|
||||
def strip_comments(text):
|
||||
if shared.opts.enable_prompt_comments:
|
||||
text = re.sub('(^|\n)#[^\n]*(\n|$)', '\n', text) # whole line comment
|
||||
text = re.sub('#[^\n]*(\n|$)', '\n', text) # in the middle of the line comment
|
||||
|
||||
return text
|
||||
|
||||
shared.options_templates.update(shared.options_section(('sd', "Stable Diffusion", "sd"), {
|
||||
"enable_prompt_comments": shared.OptionInfo(True, "Enable comments").info("Use # anywhere in the prompt to hide the text between # and the end of the line from the generation."),
|
||||
}))
|
||||
49
modules/processing_scripts/refiner.py
Executable file
49
modules/processing_scripts/refiner.py
Executable file
@@ -0,0 +1,49 @@
|
||||
import gradio as gr
|
||||
|
||||
from modules import scripts, sd_models
|
||||
from modules.infotext_utils import PasteField
|
||||
from modules.ui_common import create_refresh_button
|
||||
from modules.ui_components import InputAccordion
|
||||
|
||||
|
||||
class ScriptRefiner(scripts.ScriptBuiltinUI):
|
||||
section = "accordions"
|
||||
create_group = False
|
||||
|
||||
def __init__(self):
|
||||
pass
|
||||
|
||||
def title(self):
|
||||
return "Refiner"
|
||||
|
||||
def show(self, is_img2img):
|
||||
return scripts.AlwaysVisible
|
||||
|
||||
def ui(self, is_img2img):
|
||||
with InputAccordion(False, label="Refiner", elem_id=self.elem_id("enable")) as enable_refiner:
|
||||
with gr.Row():
|
||||
refiner_checkpoint = gr.Dropdown(label='Checkpoint', info='(use model of same architecture)', elem_id=self.elem_id("checkpoint"), choices=["", *sd_models.checkpoint_tiles(use_short=True)], value='', tooltip="switch to another model in the middle of generation")
|
||||
create_refresh_button(refiner_checkpoint, sd_models.list_models, lambda: {"choices": sd_models.checkpoint_tiles(use_short=True)}, self.elem_id("checkpoint_refresh"))
|
||||
|
||||
refiner_switch_at = gr.Slider(value=0.8, label="Switch at", minimum=0.01, maximum=1.0, step=0.01, elem_id=self.elem_id("switch_at"), tooltip="fraction of sampling steps when the switch to refiner model should happen; 1=never, 0.5=switch in the middle of generation")
|
||||
|
||||
def lookup_checkpoint(title):
|
||||
info = sd_models.get_closet_checkpoint_match(title)
|
||||
return None if info is None else info.short_title
|
||||
|
||||
self.infotext_fields = [
|
||||
PasteField(enable_refiner, lambda d: 'Refiner' in d),
|
||||
PasteField(refiner_checkpoint, lambda d: lookup_checkpoint(d.get('Refiner')), api="refiner_checkpoint"),
|
||||
PasteField(refiner_switch_at, 'Refiner switch at', api="refiner_switch_at"),
|
||||
]
|
||||
|
||||
return enable_refiner, refiner_checkpoint, refiner_switch_at
|
||||
|
||||
def setup(self, p, enable_refiner, refiner_checkpoint, refiner_switch_at):
|
||||
# the actual implementation is in sd_samplers_common.py, apply_refiner
|
||||
if not enable_refiner or refiner_checkpoint in (None, "", "None"):
|
||||
p.refiner_checkpoint = None
|
||||
p.refiner_switch_at = None
|
||||
else:
|
||||
p.refiner_checkpoint = refiner_checkpoint
|
||||
p.refiner_switch_at = refiner_switch_at
|
||||
64
modules/processing_scripts/sampler.py
Executable file
64
modules/processing_scripts/sampler.py
Executable file
@@ -0,0 +1,64 @@
|
||||
import gradio as gr
|
||||
|
||||
from modules import scripts, sd_samplers, sd_schedulers, shared
|
||||
from modules.infotext_utils import PasteField
|
||||
from modules.ui_components import FormRow, FormGroup
|
||||
|
||||
|
||||
class ScriptSampler(scripts.ScriptBuiltinUI):
|
||||
section = "sampler"
|
||||
|
||||
def __init__(self):
|
||||
self.steps = None
|
||||
self.sampler_name = None
|
||||
self.scheduler = None
|
||||
|
||||
def title(self):
|
||||
return "Sampler"
|
||||
|
||||
def ui(self, is_img2img):
|
||||
sampler_names = [x.name for x in sd_samplers.visible_samplers()]
|
||||
scheduler_names = [x.label for x in sd_schedulers.schedulers]
|
||||
|
||||
if shared.opts.samplers_in_dropdown:
|
||||
with FormRow(elem_id=f"sampler_selection_{self.tabname}"):
|
||||
self.sampler_name = gr.Dropdown(label='Sampling method', elem_id=f"{self.tabname}_sampling", choices=sampler_names, value=sampler_names[0])
|
||||
self.scheduler = gr.Dropdown(label='Schedule type', elem_id=f"{self.tabname}_scheduler", choices=scheduler_names, value=scheduler_names[0])
|
||||
self.steps = gr.Slider(minimum=1, maximum=150, step=1, elem_id=f"{self.tabname}_steps", label="Sampling steps", value=20)
|
||||
else:
|
||||
with FormGroup(elem_id=f"sampler_selection_{self.tabname}"):
|
||||
self.steps = gr.Slider(minimum=1, maximum=150, step=1, elem_id=f"{self.tabname}_steps", label="Sampling steps", value=20)
|
||||
self.sampler_name = gr.Radio(label='Sampling method', elem_id=f"{self.tabname}_sampling", choices=sampler_names, value=sampler_names[0])
|
||||
self.scheduler = gr.Dropdown(label='Schedule type', elem_id=f"{self.tabname}_scheduler", choices=scheduler_names, value=scheduler_names[0])
|
||||
|
||||
self.infotext_fields = [
|
||||
PasteField(self.steps, "Steps", api="steps"),
|
||||
PasteField(self.sampler_name, sd_samplers.get_sampler_from_infotext, api="sampler_name"),
|
||||
PasteField(self.scheduler, sd_samplers.get_scheduler_from_infotext, api="scheduler"),
|
||||
]
|
||||
|
||||
shared.options_templates.update(shared.options_section(('ui_sd', "UI defaults 'sd'", "ui"), {
|
||||
"sd_t2i_sampler": shared.OptionInfo('Euler a', "txt2img sampler", gr.Dropdown, {"choices": sampler_names}),
|
||||
"sd_t2i_scheduler": shared.OptionInfo('Automatic', "txt2img scheduler", gr.Dropdown, {"choices": scheduler_names}),
|
||||
"sd_i2i_sampler": shared.OptionInfo('Euler a', "img2img sampler", gr.Dropdown, {"choices": sampler_names}),
|
||||
"sd_i2i_scheduler": shared.OptionInfo('Automatic', "img2img scheduler", gr.Dropdown, {"choices": scheduler_names}),
|
||||
}))
|
||||
shared.options_templates.update(shared.options_section(('ui_xl', "UI defaults 'xl'", "ui"), {
|
||||
"xl_t2i_sampler": shared.OptionInfo('DPM++ 2M SDE', "txt2img sampler", gr.Dropdown, {"choices": sampler_names}),
|
||||
"xl_t2i_scheduler": shared.OptionInfo('Karras', "txt2img scheduler", gr.Dropdown, {"choices": scheduler_names}),
|
||||
"xl_i2i_sampler": shared.OptionInfo('DPM++ 2M SDE', "img2img sampler", gr.Dropdown, {"choices": sampler_names}),
|
||||
"xl_i2i_scheduler": shared.OptionInfo('Karras', "img2img scheduler", gr.Dropdown, {"choices": scheduler_names}),
|
||||
}))
|
||||
shared.options_templates.update(shared.options_section(('ui_flux', "UI defaults 'flux'", "ui"), {
|
||||
"flux_t2i_sampler": shared.OptionInfo('Euler', "txt2img sampler", gr.Dropdown, {"choices": sampler_names}),
|
||||
"flux_t2i_scheduler": shared.OptionInfo('Simple', "txt2img scheduler", gr.Dropdown, {"choices": scheduler_names}),
|
||||
"flux_i2i_sampler": shared.OptionInfo('Euler', "img2img sampler", gr.Dropdown, {"choices": sampler_names}),
|
||||
"flux_i2i_scheduler": shared.OptionInfo('Simple', "img2img scheduler", gr.Dropdown, {"choices": scheduler_names}),
|
||||
}))
|
||||
|
||||
return self.steps, self.sampler_name, self.scheduler
|
||||
|
||||
def setup(self, p, steps, sampler_name, scheduler):
|
||||
p.steps = steps
|
||||
p.sampler_name = sampler_name
|
||||
p.scheduler = scheduler
|
||||
105
modules/processing_scripts/seed.py
Executable file
105
modules/processing_scripts/seed.py
Executable file
@@ -0,0 +1,105 @@
|
||||
import json
|
||||
|
||||
import gradio as gr
|
||||
|
||||
from modules import scripts, ui, errors
|
||||
from modules.infotext_utils import PasteField
|
||||
from modules.shared import cmd_opts
|
||||
from modules.ui_components import ToolButton
|
||||
from modules import infotext_utils
|
||||
|
||||
|
||||
class ScriptSeed(scripts.ScriptBuiltinUI):
|
||||
section = "seed"
|
||||
create_group = False
|
||||
|
||||
def __init__(self):
|
||||
self.seed = None
|
||||
self.reuse_seed = None
|
||||
self.reuse_subseed = None
|
||||
|
||||
def title(self):
|
||||
return "Seed"
|
||||
|
||||
def show(self, is_img2img):
|
||||
return scripts.AlwaysVisible
|
||||
|
||||
def ui(self, is_img2img):
|
||||
with gr.Row(elem_id=self.elem_id("seed_row")):
|
||||
if cmd_opts.use_textbox_seed:
|
||||
self.seed = gr.Textbox(label='Seed', value="", elem_id=self.elem_id("seed"), min_width=100)
|
||||
else:
|
||||
self.seed = gr.Number(label='Seed', value=-1, elem_id=self.elem_id("seed"), min_width=100, precision=0)
|
||||
|
||||
random_seed = ToolButton(ui.random_symbol, elem_id=self.elem_id("random_seed"), tooltip="Set seed to -1, which will cause a new random number to be used every time")
|
||||
reuse_seed = ToolButton(ui.reuse_symbol, elem_id=self.elem_id("reuse_seed"), tooltip="Reuse seed from last generation, mostly useful if it was randomized")
|
||||
|
||||
seed_checkbox = gr.Checkbox(label='Extra', elem_id=self.elem_id("subseed_show"), value=False, scale=0, min_width=60)
|
||||
|
||||
with gr.Group(visible=False, elem_id=self.elem_id("seed_extras")) as seed_extras:
|
||||
with gr.Row(elem_id=self.elem_id("subseed_row")):
|
||||
subseed = gr.Number(label='Variation seed', value=-1, elem_id=self.elem_id("subseed"), precision=0)
|
||||
random_subseed = ToolButton(ui.random_symbol, elem_id=self.elem_id("random_subseed"))
|
||||
reuse_subseed = ToolButton(ui.reuse_symbol, elem_id=self.elem_id("reuse_subseed"))
|
||||
subseed_strength = gr.Slider(label='Variation strength', value=0.0, minimum=0, maximum=1, step=0.01, elem_id=self.elem_id("subseed_strength"))
|
||||
|
||||
with gr.Row(elem_id=self.elem_id("seed_resize_from_row")):
|
||||
seed_resize_from_w = gr.Slider(minimum=0, maximum=2048, step=8, label="Resize seed from width", value=0, elem_id=self.elem_id("seed_resize_from_w"))
|
||||
seed_resize_from_h = gr.Slider(minimum=0, maximum=2048, step=8, label="Resize seed from height", value=0, elem_id=self.elem_id("seed_resize_from_h"))
|
||||
|
||||
random_seed.click(fn=None, _js="function(){setRandomSeed('" + self.elem_id("seed") + "')}", show_progress=False, inputs=[], outputs=[])
|
||||
random_subseed.click(fn=None, _js="function(){setRandomSeed('" + self.elem_id("subseed") + "')}", show_progress=False, inputs=[], outputs=[])
|
||||
|
||||
seed_checkbox.change(lambda x: gr.update(visible=x), show_progress=False, inputs=[seed_checkbox], outputs=[seed_extras])
|
||||
|
||||
self.infotext_fields = [
|
||||
PasteField(self.seed, "Seed", api="seed"),
|
||||
PasteField(seed_checkbox, lambda d: "Variation seed" in d or "Seed resize from-1" in d),
|
||||
PasteField(subseed, "Variation seed", api="subseed"),
|
||||
PasteField(subseed_strength, "Variation seed strength", api="subseed_strength"),
|
||||
PasteField(seed_resize_from_w, "Seed resize from-1", api="seed_resize_from_w"),
|
||||
PasteField(seed_resize_from_h, "Seed resize from-2", api="seed_resize_from_h"),
|
||||
]
|
||||
|
||||
self.on_after_component(lambda x: connect_reuse_seed(self.seed, reuse_seed, x.component, False), elem_id=f'generation_info_{self.tabname}')
|
||||
self.on_after_component(lambda x: connect_reuse_seed(subseed, reuse_subseed, x.component, True), elem_id=f'generation_info_{self.tabname}')
|
||||
|
||||
return self.seed, seed_checkbox, subseed, subseed_strength, seed_resize_from_w, seed_resize_from_h
|
||||
|
||||
def setup(self, p, seed, seed_checkbox, subseed, subseed_strength, seed_resize_from_w, seed_resize_from_h):
|
||||
p.seed = seed
|
||||
|
||||
if seed_checkbox and subseed_strength > 0:
|
||||
p.subseed = subseed
|
||||
p.subseed_strength = subseed_strength
|
||||
|
||||
if seed_checkbox and seed_resize_from_w > 0 and seed_resize_from_h > 0:
|
||||
p.seed_resize_from_w = seed_resize_from_w
|
||||
p.seed_resize_from_h = seed_resize_from_h
|
||||
|
||||
|
||||
def connect_reuse_seed(seed: gr.Number, reuse_seed: gr.Button, generation_info: gr.Textbox, is_subseed):
|
||||
""" Connects a 'reuse (sub)seed' button's click event so that it copies last used
|
||||
(sub)seed value from generation info to the seed field. If copying subseed and subseed strength
|
||||
was 0, i.e. no variation seed was used, it copies the normal seed value instead."""
|
||||
|
||||
def copy_seed(gen_info_string: str, index):
|
||||
res = -1
|
||||
try:
|
||||
gen_info = json.loads(gen_info_string)
|
||||
infotext = gen_info.get('infotexts')[index]
|
||||
gen_parameters = infotext_utils.parse_generation_parameters(infotext, [])
|
||||
res = int(gen_parameters.get('Variation seed' if is_subseed else 'Seed', -1))
|
||||
except Exception:
|
||||
if gen_info_string:
|
||||
errors.report(f"Error retrieving seed from generation info: {gen_info_string}", exc_info=True)
|
||||
|
||||
return [res, gr.update()]
|
||||
|
||||
reuse_seed.click(
|
||||
fn=copy_seed,
|
||||
_js="(x, y) => [x, selected_gallery_index()]",
|
||||
show_progress=False,
|
||||
inputs=[generation_info, seed],
|
||||
outputs=[seed, seed]
|
||||
)
|
||||
46
modules/profiling.py
Executable file
46
modules/profiling.py
Executable file
@@ -0,0 +1,46 @@
|
||||
import torch
|
||||
|
||||
from modules import shared, ui_gradio_extensions
|
||||
|
||||
|
||||
class Profiler:
|
||||
def __init__(self):
|
||||
if not shared.opts.profiling_enable:
|
||||
self.profiler = None
|
||||
return
|
||||
|
||||
activities = []
|
||||
if "CPU" in shared.opts.profiling_activities:
|
||||
activities.append(torch.profiler.ProfilerActivity.CPU)
|
||||
if "CUDA" in shared.opts.profiling_activities:
|
||||
activities.append(torch.profiler.ProfilerActivity.CUDA)
|
||||
|
||||
if not activities:
|
||||
self.profiler = None
|
||||
return
|
||||
|
||||
self.profiler = torch.profiler.profile(
|
||||
activities=activities,
|
||||
record_shapes=shared.opts.profiling_record_shapes,
|
||||
profile_memory=shared.opts.profiling_profile_memory,
|
||||
with_stack=shared.opts.profiling_with_stack
|
||||
)
|
||||
|
||||
def __enter__(self):
|
||||
if self.profiler:
|
||||
self.profiler.__enter__()
|
||||
|
||||
return self
|
||||
|
||||
def __exit__(self, exc_type, exc, exc_tb):
|
||||
if self.profiler:
|
||||
shared.state.textinfo = "Finishing profile..."
|
||||
|
||||
self.profiler.__exit__(exc_type, exc, exc_tb)
|
||||
|
||||
self.profiler.export_chrome_trace(shared.opts.profiling_filename)
|
||||
|
||||
|
||||
def webpath():
|
||||
return ui_gradio_extensions.webpath(shared.opts.profiling_filename)
|
||||
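A sketch of how Profiler is intended to be used: it is a no-op unless shared.opts.profiling_enable is set, and on exit it writes a Chrome trace to shared.opts.profiling_filename, which webpath() then exposes to the browser. The profiled function here is a placeholder:

# Hypothetical illustration -- not part of this diff.
from modules import profiling

def run_sampling_step():
    pass  # placeholder for the work being profiled

with profiling.Profiler():
    run_sampling_step()

trace_url = profiling.webpath()  # path the UI can link to for chrome://tracing / Perfetto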
|
||||
153
modules/progress.py
Executable file
153
modules/progress.py
Executable file
@@ -0,0 +1,153 @@
|
||||
from __future__ import annotations
|
||||
import base64
|
||||
import io
|
||||
import time
|
||||
|
||||
import gradio as gr
|
||||
from pydantic import BaseModel, Field
|
||||
|
||||
from modules.shared import opts
|
||||
|
||||
import modules.shared as shared
|
||||
from collections import OrderedDict
|
||||
import string
|
||||
import random
|
||||
from typing import List
|
||||
|
||||
current_task = None
|
||||
pending_tasks = OrderedDict()
|
||||
finished_tasks = []
|
||||
recorded_results = []
|
||||
recorded_results_limit = 2
|
||||
|
||||
|
||||
def start_task(id_task):
|
||||
global current_task
|
||||
|
||||
current_task = id_task
|
||||
pending_tasks.pop(id_task, None)
|
||||
|
||||
|
||||
def finish_task(id_task):
|
||||
global current_task
|
||||
|
||||
if current_task == id_task:
|
||||
current_task = None
|
||||
|
||||
finished_tasks.append(id_task)
|
||||
if len(finished_tasks) > 16:
|
||||
finished_tasks.pop(0)
|
||||
|
||||
def create_task_id(task_type):
|
||||
N = 7
|
||||
res = ''.join(random.choices(string.ascii_uppercase +
|
||||
string.digits, k=N))
|
||||
return f"task({task_type}-{res})"
|
||||
|
||||
def record_results(id_task, res):
|
||||
recorded_results.append((id_task, res))
|
||||
if len(recorded_results) > recorded_results_limit:
|
||||
recorded_results.pop(0)
|
||||
|
||||
|
||||
def add_task_to_queue(id_job):
|
||||
pending_tasks[id_job] = time.time()
|
||||
|
||||
class PendingTasksResponse(BaseModel):
|
||||
size: int = Field(title="Pending task size")
|
||||
tasks: List[str] = Field(title="Pending task ids")
|
||||
|
||||
class ProgressRequest(BaseModel):
|
||||
id_task: str = Field(default=None, title="Task ID", description="id of the task to get progress for")
|
||||
id_live_preview: int = Field(default=-1, title="Live preview image ID", description="id of last received last preview image")
|
||||
live_preview: bool = Field(default=True, title="Include live preview", description="boolean flag indicating whether to include the live preview image")
|
||||
|
||||
|
||||
class ProgressResponse(BaseModel):
|
||||
active: bool = Field(title="Whether the task is being worked on right now")
|
||||
queued: bool = Field(title="Whether the task is in queue")
|
||||
completed: bool = Field(title="Whether the task has already finished")
|
||||
progress: float | None = Field(default=None, title="Progress", description="The progress with a range of 0 to 1")
|
||||
eta: float | None = Field(default=None, title="ETA in secs")
|
||||
live_preview: str | None = Field(default=None, title="Live preview image", description="Current live preview; a data: uri")
|
||||
id_live_preview: int | None = Field(default=None, title="Live preview image ID", description="Send this together with next request to prevent receiving same image")
|
||||
textinfo: str | None = Field(default=None, title="Info text", description="Info text used by WebUI.")
|
||||
|
||||
|
||||
def setup_progress_api(app):
|
||||
app.add_api_route("/internal/pending-tasks", get_pending_tasks, methods=["GET"])
|
||||
return app.add_api_route("/internal/progress", progressapi, methods=["POST"], response_model=ProgressResponse)
|
||||
|
||||
|
||||
def get_pending_tasks():
|
||||
pending_tasks_ids = list(pending_tasks)
|
||||
pending_len = len(pending_tasks_ids)
|
||||
return PendingTasksResponse(size=pending_len, tasks=pending_tasks_ids)
|
||||
|
||||
|
||||
def progressapi(req: ProgressRequest):
|
||||
active = req.id_task == current_task
|
||||
queued = req.id_task in pending_tasks
|
||||
completed = req.id_task in finished_tasks
|
||||
|
||||
if not active:
|
||||
textinfo = "Waiting..."
|
||||
if queued:
|
||||
sorted_queued = sorted(pending_tasks.keys(), key=lambda x: pending_tasks[x])
|
||||
queue_index = sorted_queued.index(req.id_task)
|
||||
textinfo = "In queue: {}/{}".format(queue_index + 1, len(sorted_queued))
|
||||
return ProgressResponse(active=active, queued=queued, completed=completed, id_live_preview=-1, textinfo=textinfo)
|
||||
|
||||
progress = 0
|
||||
|
||||
job_count, job_no = shared.state.job_count, shared.state.job_no
|
||||
sampling_steps, sampling_step = shared.state.sampling_steps, shared.state.sampling_step
|
||||
|
||||
if job_count > 0:
|
||||
progress += job_no / job_count
|
||||
if sampling_steps > 0 and job_count > 0:
|
||||
progress += 1 / job_count * sampling_step / sampling_steps
|
||||
|
||||
progress = min(progress, 1)
|
||||
|
||||
elapsed_since_start = time.time() - shared.state.time_start
|
||||
predicted_duration = elapsed_since_start / progress if progress > 0 else None
|
||||
eta = predicted_duration - elapsed_since_start if predicted_duration is not None else None
|
||||
|
||||
live_preview = None
|
||||
id_live_preview = req.id_live_preview
|
||||
|
||||
if opts.live_previews_enable and req.live_preview:
|
||||
shared.state.set_current_image()
|
||||
if shared.state.id_live_preview != req.id_live_preview:
|
||||
image = shared.state.current_image
|
||||
if image is not None:
|
||||
buffered = io.BytesIO()
|
||||
|
||||
if opts.live_previews_image_format == "png":
|
||||
# using optimize for large images takes an enormous amount of time
|
||||
if max(*image.size) <= 256:
|
||||
save_kwargs = {"optimize": True}
|
||||
else:
|
||||
save_kwargs = {"optimize": False, "compress_level": 1}
|
||||
|
||||
else:
|
||||
save_kwargs = {}
|
||||
|
||||
image.save(buffered, format=opts.live_previews_image_format, **save_kwargs)
|
||||
base64_image = base64.b64encode(buffered.getvalue()).decode('ascii')
|
||||
live_preview = f"data:image/{opts.live_previews_image_format};base64,{base64_image}"
|
||||
id_live_preview = shared.state.id_live_preview
|
||||
|
||||
return ProgressResponse(active=active, queued=queued, completed=completed, progress=progress, eta=eta, live_preview=live_preview, id_live_preview=id_live_preview, textinfo=shared.state.textinfo)
|
||||
|
||||
|
||||
def restore_progress(id_task):
|
||||
while id_task == current_task or id_task in pending_tasks:
|
||||
time.sleep(0.1)
|
||||
|
||||
res = next(iter([x[1] for x in recorded_results if id_task == x[0]]), None)
|
||||
if res is not None:
|
||||
return res
|
||||
|
||||
return gr.update(), gr.update(), gr.update(), f"Couldn't restore progress for {id_task}: results either have been discarded or never were obtained"
|
||||
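A sketch of polling the /internal/progress route registered above from an external client. The host/port and the use of the requests library are assumptions of this example:

# Hypothetical illustration -- not part of this diff.
import time
import requests

def wait_for_task(id_task, url="http://127.0.0.1:7860/internal/progress"):
    while True:
        resp = requests.post(url, json={"id_task": id_task, "id_live_preview": -1, "live_preview": False}).json()
        if resp["completed"] or not (resp["active"] or resp["queued"]):
            return resp
        print(resp.get("textinfo"), resp.get("progress"))
        time.sleep(0.5)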
480
modules/prompt_parser.py
Executable file
480
modules/prompt_parser.py
Executable file
@@ -0,0 +1,480 @@
|
||||
from __future__ import annotations
|
||||
|
||||
import re
|
||||
from collections import namedtuple
|
||||
import lark
|
||||
|
||||
# a prompt like this: "fantasy landscape with a [mountain:lake:0.25] and [an oak:a christmas tree:0.75][ in foreground::0.6][: in background:0.25] [shoddy:masterful:0.5]"
|
||||
# will be represented with prompt_schedule like this (assuming steps=100):
|
||||
# [25, 'fantasy landscape with a mountain and an oak in foreground shoddy']
|
||||
# [50, 'fantasy landscape with a lake and an oak in foreground in background shoddy']
|
||||
# [60, 'fantasy landscape with a lake and an oak in foreground in background masterful']
|
||||
# [75, 'fantasy landscape with a lake and an oak in background masterful']
|
||||
# [100, 'fantasy landscape with a lake and a christmas tree in background masterful']
|
||||
|
||||
schedule_parser = lark.Lark(r"""
|
||||
!start: (prompt | /[][():]/+)*
|
||||
prompt: (emphasized | scheduled | alternate | plain | WHITESPACE)*
|
||||
!emphasized: "(" prompt ")"
|
||||
| "(" prompt ":" prompt ")"
|
||||
| "[" prompt "]"
|
||||
scheduled: "[" [prompt ":"] prompt ":" [WHITESPACE] NUMBER [WHITESPACE] "]"
|
||||
alternate: "[" prompt ("|" [prompt])+ "]"
|
||||
WHITESPACE: /\s+/
|
||||
plain: /([^\\\[\]():|]|\\.)+/
|
||||
%import common.SIGNED_NUMBER -> NUMBER
|
||||
""")
|
||||
|
||||
def get_learned_conditioning_prompt_schedules(prompts, base_steps, hires_steps=None, use_old_scheduling=False):
|
||||
"""
|
||||
>>> g = lambda p: get_learned_conditioning_prompt_schedules([p], 10)[0]
|
||||
>>> g("test")
|
||||
[[10, 'test']]
|
||||
>>> g("a [b:3]")
|
||||
[[3, 'a '], [10, 'a b']]
|
||||
>>> g("a [b: 3]")
|
||||
[[3, 'a '], [10, 'a b']]
|
||||
>>> g("a [[[b]]:2]")
|
||||
[[2, 'a '], [10, 'a [[b]]']]
|
||||
>>> g("[(a:2):3]")
|
||||
[[3, ''], [10, '(a:2)']]
|
||||
>>> g("a [b : c : 1] d")
|
||||
[[1, 'a b d'], [10, 'a c d']]
|
||||
>>> g("a[b:[c:d:2]:1]e")
|
||||
[[1, 'abe'], [2, 'ace'], [10, 'ade']]
|
||||
>>> g("a [unbalanced")
|
||||
[[10, 'a [unbalanced']]
|
||||
>>> g("a [b:.5] c")
|
||||
[[5, 'a c'], [10, 'a b c']]
|
||||
>>> g("a [{b|d{:.5] c") # not handling this right now
|
||||
[[5, 'a c'], [10, 'a {b|d{ c']]
|
||||
>>> g("((a][:b:c [d:3]")
|
||||
[[3, '((a][:b:c '], [10, '((a][:b:c d']]
|
||||
>>> g("[a|(b:1.1)]")
|
||||
[[1, 'a'], [2, '(b:1.1)'], [3, 'a'], [4, '(b:1.1)'], [5, 'a'], [6, '(b:1.1)'], [7, 'a'], [8, '(b:1.1)'], [9, 'a'], [10, '(b:1.1)']]
|
||||
>>> g("[fe|]male")
|
||||
[[1, 'female'], [2, 'male'], [3, 'female'], [4, 'male'], [5, 'female'], [6, 'male'], [7, 'female'], [8, 'male'], [9, 'female'], [10, 'male']]
|
||||
>>> g("[fe|||]male")
|
||||
[[1, 'female'], [2, 'male'], [3, 'male'], [4, 'male'], [5, 'female'], [6, 'male'], [7, 'male'], [8, 'male'], [9, 'female'], [10, 'male']]
|
||||
>>> g = lambda p: get_learned_conditioning_prompt_schedules([p], 10, 10)[0]
|
||||
>>> g("a [b:.5] c")
|
||||
[[10, 'a b c']]
|
||||
>>> g("a [b:1.5] c")
|
||||
[[5, 'a c'], [10, 'a b c']]
|
||||
"""
|
||||
|
||||
if hires_steps is None or use_old_scheduling:
|
||||
int_offset = 0
|
||||
flt_offset = 0
|
||||
steps = base_steps
|
||||
else:
|
||||
int_offset = base_steps
|
||||
flt_offset = 1.0
|
||||
steps = hires_steps
|
||||
|
||||
def collect_steps(steps, tree):
|
||||
res = [steps]
|
||||
|
||||
class CollectSteps(lark.Visitor):
|
||||
def scheduled(self, tree):
|
||||
s = tree.children[-2]
|
||||
v = float(s)
|
||||
if use_old_scheduling:
|
||||
v = v*steps if v<1 else v
|
||||
else:
|
||||
if "." in s:
|
||||
v = (v - flt_offset) * steps
|
||||
else:
|
||||
v = (v - int_offset)
|
||||
tree.children[-2] = min(steps, int(v))
|
||||
if tree.children[-2] >= 1:
|
||||
res.append(tree.children[-2])
|
||||
|
||||
def alternate(self, tree):
|
||||
res.extend(range(1, steps+1))
|
||||
|
||||
CollectSteps().visit(tree)
|
||||
return sorted(set(res))
|
||||
|
||||
def at_step(step, tree):
|
||||
class AtStep(lark.Transformer):
|
||||
def scheduled(self, args):
|
||||
before, after, _, when, _ = args
|
||||
yield before or () if step <= when else after
|
||||
def alternate(self, args):
|
||||
args = ["" if not arg else arg for arg in args]
|
||||
yield args[(step - 1) % len(args)]
|
||||
def start(self, args):
|
||||
def flatten(x):
|
||||
if isinstance(x, str):
|
||||
yield x
|
||||
else:
|
||||
for gen in x:
|
||||
yield from flatten(gen)
|
||||
return ''.join(flatten(args))
|
||||
def plain(self, args):
|
||||
yield args[0].value
|
||||
def __default__(self, data, children, meta):
|
||||
for child in children:
|
||||
yield child
|
||||
return AtStep().transform(tree)
|
||||
|
||||
def get_schedule(prompt):
|
||||
try:
|
||||
tree = schedule_parser.parse(prompt)
|
||||
except lark.exceptions.LarkError:
|
||||
if 0:
|
||||
import traceback
|
||||
traceback.print_exc()
|
||||
return [[steps, prompt]]
|
||||
return [[t, at_step(t, tree)] for t in collect_steps(steps, tree)]
|
||||
|
||||
promptdict = {prompt: get_schedule(prompt) for prompt in set(prompts)}
|
||||
return [promptdict[prompt] for prompt in prompts]
|
||||
|
||||
|
||||
ScheduledPromptConditioning = namedtuple("ScheduledPromptConditioning", ["end_at_step", "cond"])
|
||||
|
||||
|
||||
class SdConditioning(list):
|
||||
"""
|
||||
A list with prompts for stable diffusion's conditioner model.
|
||||
Can also specify width and height of created image - SDXL needs it.
|
||||
"""
|
||||
def __init__(self, prompts, is_negative_prompt=False, width=None, height=None, copy_from=None, distilled_cfg_scale=None):
|
||||
super().__init__()
|
||||
self.extend(prompts)
|
||||
|
||||
if copy_from is None:
|
||||
copy_from = prompts
|
||||
|
||||
self.is_negative_prompt = is_negative_prompt or getattr(copy_from, 'is_negative_prompt', False)
|
||||
self.width = width or getattr(copy_from, 'width', None)
|
||||
self.height = height or getattr(copy_from, 'height', None)
|
||||
self.distilled_cfg_scale = distilled_cfg_scale or getattr(copy_from, 'distilled_cfg_scale', None)
|
||||
|
||||
|
||||
|
||||
def get_learned_conditioning(model, prompts: SdConditioning | list[str], steps, hires_steps=None, use_old_scheduling=False):
|
||||
"""converts a list of prompts into a list of prompt schedules - each schedule is a list of ScheduledPromptConditioning, specifying the comdition (cond),
|
||||
and the sampling step at which this condition is to be replaced by the next one.
|
||||
|
||||
Input:
|
||||
(model, ['a red crown', 'a [blue:green:5] jeweled crown'], 20)
|
||||
|
||||
Output:
|
||||
[
|
||||
[
|
||||
ScheduledPromptConditioning(end_at_step=20, cond=tensor([[-0.3886, 0.0229, -0.0523, ..., -0.4901, -0.3066, 0.0674], ..., [ 0.3317, -0.5102, -0.4066, ..., 0.4119, -0.7647, -1.0160]], device='cuda:0'))
|
||||
],
|
||||
[
|
||||
ScheduledPromptConditioning(end_at_step=5, cond=tensor([[-0.3886, 0.0229, -0.0522, ..., -0.4901, -0.3067, 0.0673], ..., [-0.0192, 0.3867, -0.4644, ..., 0.1135, -0.3696, -0.4625]], device='cuda:0')),
|
||||
ScheduledPromptConditioning(end_at_step=20, cond=tensor([[-0.3886, 0.0229, -0.0522, ..., -0.4901, -0.3067, 0.0673], ..., [-0.7352, -0.4356, -0.7888, ..., 0.6994, -0.4312, -1.2593]], device='cuda:0'))
|
||||
]
|
||||
]
|
||||
"""
|
||||
res = []
|
||||
|
||||
prompt_schedules = get_learned_conditioning_prompt_schedules(prompts, steps, hires_steps, use_old_scheduling)
|
||||
cache = {}
|
||||
|
||||
for prompt, prompt_schedule in zip(prompts, prompt_schedules):
|
||||
|
||||
cached = cache.get(prompt, None)
|
||||
if cached is not None:
|
||||
res.append(cached)
|
||||
continue
|
||||
|
||||
texts = SdConditioning([x[1] for x in prompt_schedule], copy_from=prompts)
|
||||
conds = model.get_learned_conditioning(texts)
|
||||
|
||||
cond_schedule = []
|
||||
for i, (end_at_step, _) in enumerate(prompt_schedule):
|
||||
if isinstance(conds, dict):
|
||||
cond = {k: v[i] for k, v in conds.items()}
|
||||
else:
|
||||
cond = conds[i]
|
||||
|
||||
cond_schedule.append(ScheduledPromptConditioning(end_at_step, cond))
|
||||
|
||||
cache[prompt] = cond_schedule
|
||||
res.append(cond_schedule)
|
||||
|
||||
return res
|
||||
|
||||
|
||||
re_AND = re.compile(r"\bAND\b")
|
||||
re_weight = re.compile(r"^((?:\s|.)*?)(?:\s*:\s*([-+]?(?:\d+\.?|\d*\.\d+)))?\s*$")
|
||||
|
||||
|
||||
def get_multicond_prompt_list(prompts: SdConditioning | list[str]):
|
||||
res_indexes = []
|
||||
|
||||
prompt_indexes = {}
|
||||
prompt_flat_list = SdConditioning(prompts)
|
||||
prompt_flat_list.clear()
|
||||
|
||||
for prompt in prompts:
|
||||
subprompts = re_AND.split(prompt)
|
||||
|
||||
indexes = []
|
||||
for subprompt in subprompts:
|
||||
match = re_weight.search(subprompt)
|
||||
|
||||
text, weight = match.groups() if match is not None else (subprompt, 1.0)
|
||||
|
||||
weight = float(weight) if weight is not None else 1.0
|
||||
|
||||
index = prompt_indexes.get(text, None)
|
||||
if index is None:
|
||||
index = len(prompt_flat_list)
|
||||
prompt_flat_list.append(text)
|
||||
prompt_indexes[text] = index
|
||||
|
||||
indexes.append((index, weight))
|
||||
|
||||
res_indexes.append(indexes)
|
||||
|
||||
return res_indexes, prompt_flat_list, prompt_indexes
|
||||
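Roughly how the AND splitting and ':weight' parsing above behave on a composable prompt (whitespace in the flattened subprompts is shown approximately):

# Hypothetical illustration -- not part of this diff.
from modules import prompt_parser

res_indexes, flat_list, _ = prompt_parser.get_multicond_prompt_list(["a sunny meadow AND a thunderstorm :0.4"])
# flat_list    -> ["a sunny meadow", " a thunderstorm"]   (unique subprompts, deduplicated)
# res_indexes  -> [[(0, 1.0), (1, 0.4)]]                  (index into flat_list, weight)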
|
||||
|
||||
class ComposableScheduledPromptConditioning:
|
||||
def __init__(self, schedules, weight=1.0):
|
||||
self.schedules: list[ScheduledPromptConditioning] = schedules
|
||||
self.weight: float = weight
|
||||
|
||||
|
||||
class MulticondLearnedConditioning:
|
||||
def __init__(self, shape, batch):
|
||||
self.shape: tuple = shape # the shape field is needed to send this object to DDIM/PLMS
|
||||
self.batch: list[list[ComposableScheduledPromptConditioning]] = batch
|
||||
|
||||
|
||||
def get_multicond_learned_conditioning(model, prompts, steps, hires_steps=None, use_old_scheduling=False) -> MulticondLearnedConditioning:
|
||||
"""same as get_learned_conditioning, but returns a list of ScheduledPromptConditioning along with the weight objects for each prompt.
|
||||
For each prompt, the list is obtained by splitting the prompt using the AND separator.
|
||||
|
||||
https://energy-based-model.github.io/Compositional-Visual-Generation-with-Composable-Diffusion-Models/
|
||||
"""
|
||||
|
||||
res_indexes, prompt_flat_list, prompt_indexes = get_multicond_prompt_list(prompts)
|
||||
|
||||
learned_conditioning = get_learned_conditioning(model, prompt_flat_list, steps, hires_steps, use_old_scheduling)
|
||||
|
||||
res = []
|
||||
for indexes in res_indexes:
|
||||
res.append([ComposableScheduledPromptConditioning(learned_conditioning[i], weight) for i, weight in indexes])
|
||||
|
||||
return MulticondLearnedConditioning(shape=(len(prompts),), batch=res)
|
||||
|
||||
|
||||
class DictWithShape(dict):
|
||||
def __init__(self, x, shape=None):
|
||||
super().__init__()
|
||||
self.update(x)
|
||||
|
||||
@property
|
||||
def shape(self):
|
||||
return self["crossattn"].shape
|
||||
|
||||
def to(self, *args, **kwargs):
|
||||
for k in self.keys():
|
||||
if isinstance(self[k], torch.Tensor):
|
||||
self[k] = self[k].to(*args, **kwargs)
|
||||
return self
|
||||
|
||||
def advanced_indexing(self, item):
|
||||
result = {}
|
||||
for k in self.keys():
|
||||
if isinstance(self[k], torch.Tensor):
|
||||
result[k] = self[k][item]
|
||||
return DictWithShape(result)
|
||||
|
||||
|
||||
def reconstruct_cond_batch(c: list[list[ScheduledPromptConditioning]], current_step):
|
||||
param = c[0][0].cond
|
||||
is_dict = isinstance(param, dict)
|
||||
|
||||
if is_dict:
|
||||
dict_cond = param
|
||||
res = {k: torch.zeros((len(c),) + param.shape, device=param.device, dtype=param.dtype) for k, param in dict_cond.items()}
|
||||
res = DictWithShape(res)
|
||||
else:
|
||||
res = torch.zeros((len(c),) + param.shape, device=param.device, dtype=param.dtype)
|
||||
|
||||
for i, cond_schedule in enumerate(c):
|
||||
target_index = 0
|
||||
for current, entry in enumerate(cond_schedule):
|
||||
if current_step <= entry.end_at_step:
|
||||
target_index = current
|
||||
break
|
||||
|
||||
if is_dict:
|
||||
for k, param in cond_schedule[target_index].cond.items():
|
||||
res[k][i] = param
|
||||
else:
|
||||
res[i] = cond_schedule[target_index].cond
|
||||
|
||||
return res
|
||||
|
||||
|
||||
def stack_conds(tensors):
|
||||
try:
|
||||
result = torch.stack(tensors)
|
||||
except RuntimeError:
|
||||
# if prompts have wildly different lengths above the limit we'll get tensors of different shapes
|
||||
# and won't be able to torch.stack them. So this fixes that.
|
||||
token_count = max([x.shape[0] for x in tensors])
|
||||
for i in range(len(tensors)):
|
||||
if tensors[i].shape[0] != token_count:
|
||||
last_vector = tensors[i][-1:]
|
||||
last_vector_repeated = last_vector.repeat([token_count - tensors[i].shape[0], 1])
|
||||
tensors[i] = torch.vstack([tensors[i], last_vector_repeated])
|
||||
result = torch.stack(tensors)
|
||||
return result
|
||||
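A tiny demonstration of the padding fallback in stack_conds; the tensor shapes are merely illustrative of CLIP-style conds:

# Hypothetical illustration -- not part of this diff.
import torch

from modules import prompt_parser

short = torch.zeros(77, 768)    # e.g. a 75-token prompt -> one chunk
long_ = torch.zeros(154, 768)   # a longer prompt -> two chunks
stacked = prompt_parser.stack_conds([short, long_])
# the shorter cond has its last vector repeated up to 154 rows,
# so stacked.shape == torch.Size([2, 154, 768])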
|
||||
|
||||
def reconstruct_multicond_batch(c: MulticondLearnedConditioning, current_step):
|
||||
param = c.batch[0][0].schedules[0].cond
|
||||
|
||||
tensors = []
|
||||
conds_list = []
|
||||
|
||||
for composable_prompts in c.batch:
|
||||
conds_for_batch = []
|
||||
|
||||
for composable_prompt in composable_prompts:
|
||||
target_index = 0
|
||||
for current, entry in enumerate(composable_prompt.schedules):
|
||||
if current_step <= entry.end_at_step:
|
||||
target_index = current
|
||||
break
|
||||
|
||||
conds_for_batch.append((len(tensors), composable_prompt.weight))
|
||||
tensors.append(composable_prompt.schedules[target_index].cond)
|
||||
|
||||
conds_list.append(conds_for_batch)
|
||||
|
||||
if isinstance(tensors[0], dict):
|
||||
keys = list(tensors[0].keys())
|
||||
stacked = {k: stack_conds([x[k] for x in tensors]) for k in keys}
|
||||
stacked = DictWithShape(stacked)
|
||||
else:
|
||||
stacked = stack_conds(tensors).to(device=param.device, dtype=param.dtype)
|
||||
|
||||
return conds_list, stacked
|
||||
|
||||
|
||||
re_attention = re.compile(r"""
|
||||
\\\(|
|
||||
\\\)|
|
||||
\\\[|
|
||||
\\]|
|
||||
\\\\|
|
||||
\\|
|
||||
\(|
|
||||
\[|
|
||||
:\s*([+-]?[.\d]+)\s*\)|
|
||||
\)|
|
||||
]|
|
||||
[^\\()\[\]:]+|
|
||||
:
|
||||
""", re.X)
|
||||
|
||||
re_break = re.compile(r"\s*\bBREAK\b\s*", re.S)
|
||||
|
||||
def parse_prompt_attention(text):
|
||||
"""
|
||||
Parses a string with attention tokens and returns a list of pairs: text and its associated weight.
|
||||
Accepted tokens are:
|
||||
(abc) - increases attention to abc by a multiplier of 1.1
|
||||
(abc:3.12) - increases attention to abc by a multiplier of 3.12
|
||||
[abc] - decreases attention to abc by a multiplier of 1.1
|
||||
\( - literal character '('
|
||||
\[ - literal character '['
|
||||
\) - literal character ')'
|
||||
\] - literal character ']'
|
||||
\\ - literal character '\'
|
||||
anything else - just text
|
||||
|
||||
>>> parse_prompt_attention('normal text')
|
||||
[['normal text', 1.0]]
|
||||
>>> parse_prompt_attention('an (important) word')
|
||||
[['an ', 1.0], ['important', 1.1], [' word', 1.0]]
|
||||
>>> parse_prompt_attention('(unbalanced')
|
||||
[['unbalanced', 1.1]]
|
||||
>>> parse_prompt_attention('\(literal\]')
|
||||
[['(literal]', 1.0]]
|
||||
>>> parse_prompt_attention('(unnecessary)(parens)')
|
||||
[['unnecessaryparens', 1.1]]
|
||||
>>> parse_prompt_attention('a (((house:1.3)) [on] a (hill:0.5), sun, (((sky))).')
|
||||
[['a ', 1.0],
|
||||
['house', 1.5730000000000004],
|
||||
[' ', 1.1],
|
||||
['on', 1.0],
|
||||
[' a ', 1.1],
|
||||
['hill', 0.55],
|
||||
[', sun, ', 1.1],
|
||||
['sky', 1.4641000000000006],
|
||||
['.', 1.1]]
|
||||
"""
|
||||
|
||||
res = []
|
||||
round_brackets = []
|
||||
square_brackets = []
|
||||
|
||||
round_bracket_multiplier = 1.1
|
||||
square_bracket_multiplier = 1 / 1.1
|
||||
|
||||
def multiply_range(start_position, multiplier):
|
||||
for p in range(start_position, len(res)):
|
||||
res[p][1] *= multiplier
|
||||
|
||||
for m in re_attention.finditer(text):
|
||||
text = m.group(0)
|
||||
weight = m.group(1)
|
||||
|
||||
if text.startswith('\\'):
|
||||
res.append([text[1:], 1.0])
|
||||
elif text == '(':
|
||||
round_brackets.append(len(res))
|
||||
elif text == '[':
|
||||
square_brackets.append(len(res))
|
||||
elif weight is not None and round_brackets:
|
||||
multiply_range(round_brackets.pop(), float(weight))
|
||||
elif text == ')' and round_brackets:
|
||||
multiply_range(round_brackets.pop(), round_bracket_multiplier)
|
||||
elif text == ']' and square_brackets:
|
||||
multiply_range(square_brackets.pop(), square_bracket_multiplier)
|
||||
else:
|
||||
parts = re.split(re_break, text)
|
||||
for i, part in enumerate(parts):
|
||||
if i > 0:
|
||||
res.append(["BREAK", -1])
|
||||
res.append([part, 1.0])
|
||||
|
||||
for pos in round_brackets:
|
||||
multiply_range(pos, round_bracket_multiplier)
|
||||
|
||||
for pos in square_brackets:
|
||||
multiply_range(pos, square_bracket_multiplier)
|
||||
|
||||
if len(res) == 0:
|
||||
res = [["", 1.0]]
|
||||
|
||||
# merge runs of identical weights
|
||||
i = 0
|
||||
while i + 1 < len(res):
|
||||
if res[i][1] == res[i + 1][1]:
|
||||
res[i][0] += res[i + 1][0]
|
||||
res.pop(i + 1)
|
||||
else:
|
||||
i += 1
|
||||
|
||||
return res
|
||||
|
||||
if __name__ == "__main__":
|
||||
import doctest
|
||||
doctest.testmod(optionflags=doctest.NORMALIZE_WHITESPACE)
|
||||
else:
|
||||
import torch # doctest faster
|
||||
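A toy sketch of the fallback path in stack_conds above: when two prompts produce cond tensors with different token counts, the shorter one is padded by repeating its last vector before stacking. The 768-wide embedding used here is only an illustrative size, not something the module mandates.

```python
import torch

a = torch.zeros(77, 768)    # hypothetical cond for a short prompt
b = torch.zeros(154, 768)   # hypothetical cond for a prompt past the token limit

stacked = stack_conds([a, b])
print(stacked.shape)        # torch.Size([2, 154, 768]); `a` was padded with copies of its last row
```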
108
modules/realesrgan_model.py
Executable file
108
modules/realesrgan_model.py
Executable file
@@ -0,0 +1,108 @@
import os

from modules import modelloader, errors
from modules.shared import cmd_opts, opts
from modules.upscaler import Upscaler, UpscalerData
from modules.upscaler_utils import upscale_with_model
from modules_forge.utils import prepare_free_memory


class UpscalerRealESRGAN(Upscaler):
    def __init__(self, path):
        self.name = "RealESRGAN"
        self.user_path = path
        super().__init__()
        self.enable = True
        self.scalers = []
        scalers = get_realesrgan_models(self)

        local_model_paths = self.find_models(ext_filter=[".pth"])
        for scaler in scalers:
            if scaler.local_data_path.startswith("http"):
                filename = modelloader.friendly_name(scaler.local_data_path)
                local_model_candidates = [local_model for local_model in local_model_paths if local_model.endswith(f"{filename}.pth")]
                if local_model_candidates:
                    scaler.local_data_path = local_model_candidates[0]

            if scaler.name in opts.realesrgan_enabled_models:
                self.scalers.append(scaler)

    def do_upscale(self, img, path):
        prepare_free_memory()

        if not self.enable:
            return img

        try:
            info = self.load_model(path)
        except Exception:
            errors.report(f"Unable to load RealESRGAN model {path}", exc_info=True)
            return img

        model_descriptor = modelloader.load_spandrel_model(
            info.local_data_path,
            device=self.device,
            prefer_half=(not cmd_opts.no_half and not cmd_opts.upcast_sampling),
            expected_architecture="ESRGAN",  # "RealESRGAN" isn't a specific thing for Spandrel
        )
        return upscale_with_model(
            model_descriptor,
            img,
            tile_size=opts.ESRGAN_tile,
            tile_overlap=opts.ESRGAN_tile_overlap,
            # TODO: `outscale`?
        )

    def load_model(self, path):
        for scaler in self.scalers:
            if scaler.data_path == path:
                if scaler.local_data_path.startswith("http"):
                    scaler.local_data_path = modelloader.load_file_from_url(
                        scaler.data_path,
                        model_dir=self.model_download_path,
                    )
                if not os.path.exists(scaler.local_data_path):
                    raise FileNotFoundError(f"RealESRGAN data missing: {scaler.local_data_path}")
                return scaler
        raise ValueError(f"Unable to find model info: {path}")


def get_realesrgan_models(scaler: UpscalerRealESRGAN):
    return [
        UpscalerData(
            name="R-ESRGAN General 4xV3",
            path="https://github.com/xinntao/Real-ESRGAN/releases/download/v0.2.5.0/realesr-general-x4v3.pth",
            scale=4,
            upscaler=scaler,
        ),
        UpscalerData(
            name="R-ESRGAN General WDN 4xV3",
            path="https://github.com/xinntao/Real-ESRGAN/releases/download/v0.2.5.0/realesr-general-wdn-x4v3.pth",
            scale=4,
            upscaler=scaler,
        ),
        UpscalerData(
            name="R-ESRGAN AnimeVideo",
            path="https://github.com/xinntao/Real-ESRGAN/releases/download/v0.2.5.0/realesr-animevideov3.pth",
            scale=4,
            upscaler=scaler,
        ),
        UpscalerData(
            name="R-ESRGAN 4x+",
            path="https://github.com/xinntao/Real-ESRGAN/releases/download/v0.1.0/RealESRGAN_x4plus.pth",
            scale=4,
            upscaler=scaler,
        ),
        UpscalerData(
            name="R-ESRGAN 4x+ Anime6B",
            path="https://github.com/xinntao/Real-ESRGAN/releases/download/v0.2.2.4/RealESRGAN_x4plus_anime_6B.pth",
            scale=4,
            upscaler=scaler,
        ),
        UpscalerData(
            name="R-ESRGAN 2x+",
            path="https://github.com/xinntao/Real-ESRGAN/releases/download/v0.2.1/RealESRGAN_x2plus.pth",
            scale=2,
            upscaler=scaler,
        ),
    ]
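A minimal usage sketch for the class above, assuming the webui's shared options and devices are already initialized and that "R-ESRGAN 4x+" is among opts.realesrgan_enabled_models; the file names are made up.

```python
from PIL import Image

upscaler = UpscalerRealESRGAN(path=None)
scaler = next(s for s in upscaler.scalers if s.name == "R-ESRGAN 4x+")

img = Image.open("input.png")                        # hypothetical input file
result = upscaler.do_upscale(img, scaler.data_path)  # downloads the .pth on first use
result.save("input_4x.png")
```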
25
modules/restart.py
Executable file
25
modules/restart.py
Executable file
@@ -0,0 +1,25 @@
import os
from pathlib import Path

from modules.paths_internal import script_path


def is_restartable() -> bool:
    """
    Return True if the webui is restartable (i.e. there is something watching to restart it with)
    """
    return bool(os.environ.get('SD_WEBUI_RESTART'))


def restart_program() -> None:
    """creates file tmp/restart and immediately stops the process, which webui.bat/webui.sh interpret as a command to start webui again"""

    tmpdir = Path(script_path) / "tmp"
    tmpdir.mkdir(parents=True, exist_ok=True)
    (tmpdir / "restart").touch()

    stop_program()


def stop_program() -> None:
    os._exit(0)
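A caller-side sketch of how the two functions above are meant to be combined: only restart when a supervising launcher is watching for tmp/restart, otherwise just exit.

```python
from modules import restart

if restart.is_restartable():
    restart.restart_program()   # touches tmp/restart, then exits; webui.bat/webui.sh relaunch the process
else:
    restart.stop_program()      # plain exit; nothing will bring the process back up
```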
184
modules/rng.py
Executable file
184
modules/rng.py
Executable file
@@ -0,0 +1,184 @@
import torch

from modules import devices, rng_philox, shared


def get_noise_source_type():
    if shared.opts.forge_try_reproduce in ['ComfyUI', 'DrawThings']:
        return "CPU"

    return shared.opts.randn_source


def randn(seed, shape, generator=None):
    """Generate a tensor with random numbers from a normal distribution using seed.

    Uses the seed parameter to set the global torch seed; to generate more with that seed, use randn_like/randn_without_seed."""

    if generator is not None:
        # Forge Note:
        # If generator is not None, we must use a different seed so that
        # the global torch.randn does not produce the same noise again.
        # Note: removing this will break the DDPM sampler.
        manual_seed((seed + 100000) % 65536)
    else:
        manual_seed(seed)

    if get_noise_source_type() == "NV":
        return torch.asarray((generator or nv_rng).randn(shape), device=devices.device)

    if get_noise_source_type() == "CPU" or devices.device.type == 'mps':
        return torch.randn(shape, device=devices.cpu, generator=generator).to(devices.device)

    return torch.randn(shape, device=devices.device, generator=generator)


def randn_local(seed, shape):
    """Generate a tensor with random numbers from a normal distribution using seed.

    Does not change the global random number generator. You can only generate the seed's first tensor using this function."""

    if get_noise_source_type() == "NV":
        rng = rng_philox.Generator(seed)
        return torch.asarray(rng.randn(shape), device=devices.device)

    local_device = devices.cpu if get_noise_source_type() == "CPU" or devices.device.type == 'mps' else devices.device
    local_generator = torch.Generator(local_device).manual_seed(int(seed))
    return torch.randn(shape, device=local_device, generator=local_generator).to(devices.device)


def randn_like(x):
    """Generate a tensor with random numbers from a normal distribution using the previously initialized generator.

    Use either randn() or manual_seed() to initialize the generator."""

    if get_noise_source_type() == "NV":
        return torch.asarray(nv_rng.randn(x.shape), device=x.device, dtype=x.dtype)

    if get_noise_source_type() == "CPU" or x.device.type == 'mps':
        return torch.randn_like(x, device=devices.cpu).to(x.device)

    return torch.randn_like(x)


def randn_without_seed(shape, generator=None):
    """Generate a tensor with random numbers from a normal distribution using the previously initialized generator.

    Use either randn() or manual_seed() to initialize the generator."""

    if get_noise_source_type() == "NV":
        return torch.asarray((generator or nv_rng).randn(shape), device=devices.device)

    if get_noise_source_type() == "CPU" or devices.device.type == 'mps':
        return torch.randn(shape, device=devices.cpu, generator=generator).to(devices.device)

    return torch.randn(shape, device=devices.device, generator=generator)


def manual_seed(seed):
    """Set up a global random number generator using the specified seed."""

    if get_noise_source_type() == "NV":
        global nv_rng
        nv_rng = rng_philox.Generator(seed)
        return

    torch.manual_seed(seed)


def create_generator(seed):
    if get_noise_source_type() == "NV":
        return rng_philox.Generator(seed)

    device = devices.cpu if get_noise_source_type() == "CPU" or devices.device.type == 'mps' else devices.device
    generator = torch.Generator(device).manual_seed(int(seed))
    return generator


# from https://discuss.pytorch.org/t/help-regarding-slerp-function-for-generative-model-sampling/32475/3
def slerp(val, low, high):
    low_norm = low/torch.norm(low, dim=1, keepdim=True)
    high_norm = high/torch.norm(high, dim=1, keepdim=True)
    dot = (low_norm*high_norm).sum(1)

    if dot.mean() > 0.9995:
        return low * val + high * (1 - val)

    omega = torch.acos(dot)
    so = torch.sin(omega)
    res = (torch.sin((1.0-val)*omega)/so).unsqueeze(1)*low + (torch.sin(val*omega)/so).unsqueeze(1) * high
    return res


class ImageRNG:
    def __init__(self, shape, seeds, subseeds=None, subseed_strength=0.0, seed_resize_from_h=0, seed_resize_from_w=0):
        self.shape = tuple(map(int, shape))
        self.seeds = seeds
        self.subseeds = subseeds
        self.subseed_strength = subseed_strength
        self.seed_resize_from_h = seed_resize_from_h
        self.seed_resize_from_w = seed_resize_from_w

        self.generators = [create_generator(seed) for seed in seeds]

        self.is_first = True

    def first(self):
        noise_shape = self.shape if self.seed_resize_from_h <= 0 or self.seed_resize_from_w <= 0 else (self.shape[0], int(self.seed_resize_from_h) // 8, int(self.seed_resize_from_w // 8))

        xs = []

        for i, (seed, generator) in enumerate(zip(self.seeds, self.generators)):
            subnoise = None
            if self.subseeds is not None and self.subseed_strength != 0:
                subseed = 0 if i >= len(self.subseeds) else self.subseeds[i]
                subnoise = randn(subseed, noise_shape)

            if noise_shape != self.shape:
                noise = randn(seed, noise_shape)
            else:
                noise = randn(seed, self.shape, generator=generator)

            if subnoise is not None:
                noise = slerp(self.subseed_strength, noise, subnoise)

            if noise_shape != self.shape:
                x = randn(seed, self.shape, generator=generator)
                dx = (self.shape[2] - noise_shape[2]) // 2
                dy = (self.shape[1] - noise_shape[1]) // 2
                w = noise_shape[2] if dx >= 0 else noise_shape[2] + 2 * dx
                h = noise_shape[1] if dy >= 0 else noise_shape[1] + 2 * dy
                tx = 0 if dx < 0 else dx
                ty = 0 if dy < 0 else dy
                dx = max(-dx, 0)
                dy = max(-dy, 0)

                x[:, ty:ty + h, tx:tx + w] = noise[:, dy:dy + h, dx:dx + w]
                noise = x

            xs.append(noise)

        eta_noise_seed_delta = shared.opts.eta_noise_seed_delta or 0
        if eta_noise_seed_delta:
            self.generators = [create_generator(seed + eta_noise_seed_delta) for seed in self.seeds]

        return torch.stack(xs).to(shared.device)

    def next(self):
        if self.is_first:
            self.is_first = False
            return self.first()

        xs = []
        for generator in self.generators:
            x = randn_without_seed(self.shape, generator=generator)
            xs.append(x)

        return torch.stack(xs).to(shared.device)


devices.randn = randn
devices.randn_local = randn_local
devices.randn_like = randn_like
devices.randn_without_seed = randn_without_seed
devices.manual_seed = manual_seed
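A minimal sketch of driving ImageRNG, assuming shared.opts and devices are already initialized; the shape is the usual SD latent layout (4 channels, height/8 x width/8) for two 512x512 images.

```python
from modules.rng import ImageRNG

rng = ImageRNG(shape=(4, 64, 64), seeds=[1234, 1235], subseeds=[999, 1000], subseed_strength=0.3)

x0 = rng.next()   # first call: seeded initial noise, shape (2, 4, 64, 64), slerped toward the subseeds
x1 = rng.next()   # later calls: further noise from the same per-image generators (e.g. extra sampler steps)
```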
102
modules/rng_philox.py
Executable file
102
modules/rng_philox.py
Executable file
@@ -0,0 +1,102 @@
"""RNG imitating torch cuda randn on CPU. You are welcome.

Usage:

```
g = Generator(seed=0)
print(g.randn(shape=(3, 4)))
```

Expected output:
```
[[-0.92466259 -0.42534415 -2.6438457   0.14518388]
 [-0.12086647 -0.57972564 -0.62285122 -0.32838709]
 [-1.07454231 -0.36314407 -1.67105067  2.26550497]]
```
"""

import numpy as np

philox_m = [0xD2511F53, 0xCD9E8D57]
philox_w = [0x9E3779B9, 0xBB67AE85]

two_pow32_inv = np.array([2.3283064e-10], dtype=np.float32)
two_pow32_inv_2pi = np.array([2.3283064e-10 * 6.2831855], dtype=np.float32)


def uint32(x):
    """Converts (N,) np.uint64 array into (2, N) np.uint32 array."""
    return x.view(np.uint32).reshape(-1, 2).transpose(1, 0)


def philox4_round(counter, key):
    """A single round of the Philox 4x32 random number generator."""

    v1 = uint32(counter[0].astype(np.uint64) * philox_m[0])
    v2 = uint32(counter[2].astype(np.uint64) * philox_m[1])

    counter[0] = v2[1] ^ counter[1] ^ key[0]
    counter[1] = v2[0]
    counter[2] = v1[1] ^ counter[3] ^ key[1]
    counter[3] = v1[0]


def philox4_32(counter, key, rounds=10):
    """Generates 32-bit random numbers using the Philox 4x32 random number generator.

    Parameters:
        counter (numpy.ndarray): A 4xN array of 32-bit integers representing the counter values (offset into generation).
        key (numpy.ndarray): A 2xN array of 32-bit integers representing the key values (seed).
        rounds (int): The number of rounds to perform.

    Returns:
        numpy.ndarray: A 4xN array of 32-bit integers containing the generated random numbers.
    """

    for _ in range(rounds - 1):
        philox4_round(counter, key)

        key[0] = key[0] + philox_w[0]
        key[1] = key[1] + philox_w[1]

    philox4_round(counter, key)
    return counter


def box_muller(x, y):
    """Returns just the first out of two numbers generated by Box–Muller transform algorithm."""
    u = x * two_pow32_inv + two_pow32_inv / 2
    v = y * two_pow32_inv_2pi + two_pow32_inv_2pi / 2

    s = np.sqrt(-2.0 * np.log(u))

    r1 = s * np.sin(v)
    return r1.astype(np.float32)


class Generator:
    """RNG that produces same outputs as torch.randn(..., device='cuda') on CPU"""

    def __init__(self, seed):
        self.seed = seed
        self.offset = 0

    def randn(self, shape):
        """Generate a sequence of n standard normal random variables using the Philox 4x32 random number generator and the Box-Muller transform."""

        n = 1
        for x in shape:
            n *= x

        counter = np.zeros((4, n), dtype=np.uint32)
        counter[0] = self.offset
        counter[2] = np.arange(n, dtype=np.uint32)  # up to 2^32 numbers can be generated - if you want more you'd need to spill into counter[3]
        self.offset += 1

        key = np.empty(n, dtype=np.uint64)
        key.fill(self.seed)
        key = uint32(key)

        g = philox4_32(counter, key)

        return box_muller(g[0], g[1]).reshape(shape)  # discard g[2] and g[3]
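A tiny determinism check for the Generator above: re-creating it with the same seed reproduces the same draws, while each call advances the internal offset.

```python
g1 = Generator(seed=0)
a = g1.randn((3, 4))
b = g1.randn((3, 4))                    # differs from `a`: counter[0] (the offset) has advanced

g2 = Generator(seed=0)
assert (g2.randn((3, 4)) == a).all()    # same seed, same offset -> identical values
```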
196
modules/safe.py
Executable file
196
modules/safe.py
Executable file
@@ -0,0 +1,196 @@
# this code is adapted from the script contributed by anon from /h/

import pickle
import collections

import torch
import numpy
import _codecs
import zipfile
import re


# PyTorch 1.13 and later have _TypedStorage renamed to TypedStorage
from modules import errors

TypedStorage = torch.storage.TypedStorage if hasattr(torch.storage, 'TypedStorage') else torch.storage._TypedStorage

def encode(*args):
    out = _codecs.encode(*args)
    return out


class RestrictedUnpickler(pickle.Unpickler):
    extra_handler = None

    def persistent_load(self, saved_id):
        assert saved_id[0] == 'storage'

        try:
            return TypedStorage(_internal=True)
        except TypeError:
            return TypedStorage()  # PyTorch before 2.0 does not have the _internal argument

    def find_class(self, module, name):
        if self.extra_handler is not None:
            res = self.extra_handler(module, name)
            if res is not None:
                return res

        if module == 'collections' and name == 'OrderedDict':
            return getattr(collections, name)
        if module == 'torch._utils' and name in ['_rebuild_tensor_v2', '_rebuild_parameter', '_rebuild_device_tensor_from_numpy']:
            return getattr(torch._utils, name)
        if module == 'torch' and name in ['FloatStorage', 'HalfStorage', 'IntStorage', 'LongStorage', 'DoubleStorage', 'ByteStorage', 'float32', 'BFloat16Storage']:
            return getattr(torch, name)
        if module == 'torch.nn.modules.container' and name in ['ParameterDict']:
            return getattr(torch.nn.modules.container, name)
        if module == 'numpy.core.multiarray' and name in ['scalar', '_reconstruct']:
            return getattr(numpy.core.multiarray, name)
        if module == 'numpy' and name in ['dtype', 'ndarray']:
            return getattr(numpy, name)
        if module == '_codecs' and name == 'encode':
            return encode
        if module == "pytorch_lightning.callbacks" and name == 'model_checkpoint':
            import pytorch_lightning.callbacks
            return pytorch_lightning.callbacks.model_checkpoint
        if module == "pytorch_lightning.callbacks.model_checkpoint" and name == 'ModelCheckpoint':
            import pytorch_lightning.callbacks.model_checkpoint
            return pytorch_lightning.callbacks.model_checkpoint.ModelCheckpoint
        if module == "__builtin__" and name == 'set':
            return set

        # Forbid everything else.
        raise Exception(f"global '{module}/{name}' is forbidden")


# Regular expression that accepts 'dirname/version', 'dirname/byteorder', 'dirname/data.pkl', '.data/serialization_id', and 'dirname/data/<number>'
allowed_zip_names_re = re.compile(r"^([^/]+)/((data/\d+)|version|byteorder|.data/serialization_id|(data\.pkl))$")
data_pkl_re = re.compile(r"^([^/]+)/data\.pkl$")

def check_zip_filenames(filename, names):
    for name in names:
        if allowed_zip_names_re.match(name):
            continue

        raise Exception(f"bad file inside {filename}: {name}")


def check_pt(filename, extra_handler):
    try:

        # new pytorch format is a zip file
        with zipfile.ZipFile(filename) as z:
            check_zip_filenames(filename, z.namelist())

            # find filename of data.pkl in zip file: '<directory name>/data.pkl'
            data_pkl_filenames = [f for f in z.namelist() if data_pkl_re.match(f)]
            if len(data_pkl_filenames) == 0:
                raise Exception(f"data.pkl not found in {filename}")
            if len(data_pkl_filenames) > 1:
                raise Exception(f"Multiple data.pkl found in {filename}")
            with z.open(data_pkl_filenames[0]) as file:
                unpickler = RestrictedUnpickler(file)
                unpickler.extra_handler = extra_handler
                unpickler.load()

    except zipfile.BadZipfile:

        # if it's not a zip file, it's an old pytorch format, with five objects written to pickle
        with open(filename, "rb") as file:
            unpickler = RestrictedUnpickler(file)
            unpickler.extra_handler = extra_handler
            for _ in range(5):
                unpickler.load()


def load(filename, *args, **kwargs):
    return load_with_extra(filename, *args, extra_handler=global_extra_handler, **kwargs)


def load_with_extra(filename, extra_handler=None, *args, **kwargs):
    """
    this function is intended to be used by extensions that want to load models with
    some extra classes in them that the usual unpickler would find suspicious.

    Use the extra_handler argument to specify a function that takes module and field name as text,
    and returns that field's value:

    ```python
    def extra(module, name):
        if module == 'collections' and name == 'OrderedDict':
            return collections.OrderedDict

        return None

    safe.load_with_extra('model.pt', extra_handler=extra)
    ```

    The alternative to this is just to use safe.unsafe_torch_load('model.pt'), which as the name implies is
    definitely unsafe.
    """

    from modules import shared

    try:
        if not shared.cmd_opts.disable_safe_unpickle:
            check_pt(filename, extra_handler)

    except pickle.UnpicklingError:
        errors.report(
            f"Error verifying pickled file from {filename}\n"
            "-----> !!!! The file is most likely corrupted !!!! <-----\n"
            "You can skip this check with --disable-safe-unpickle commandline argument, but that is not going to help you.\n\n",
            exc_info=True,
        )
        return None
    except Exception:
        errors.report(
            f"Error verifying pickled file from {filename}\n"
            f"The file may be malicious, so the program is not going to read it.\n"
            f"You can skip this check with --disable-safe-unpickle commandline argument.\n\n",
            exc_info=True,
        )
        return None

    return unsafe_torch_load(filename, *args, **kwargs)


class Extra:
    """
    A class for temporarily setting the global handler for when you can't explicitly call load_with_extra
    (because it's not your code making the torch.load call). The intended use is like this:

    ```
    import torch
    from modules import safe

    def handler(module, name):
        if module == 'torch' and name in ['float64', 'float16']:
            return getattr(torch, name)

        return None

    with safe.Extra(handler):
        x = torch.load('model.pt')
    ```
    """

    def __init__(self, handler):
        self.handler = handler

    def __enter__(self):
        global global_extra_handler

        assert global_extra_handler is None, 'already inside an Extra() block'
        global_extra_handler = self.handler

    def __exit__(self, exc_type, exc_val, exc_tb):
        global global_extra_handler

        global_extra_handler = None


unsafe_torch_load = torch.load
# torch.load = load <- Forge do not need it!
global_extra_handler = None
613
modules/script_callbacks.py
Executable file
613
modules/script_callbacks.py
Executable file
@@ -0,0 +1,613 @@
|
||||
from __future__ import annotations
|
||||
|
||||
import dataclasses
|
||||
import inspect
|
||||
import os
|
||||
from typing import Optional, Any
|
||||
|
||||
from fastapi import FastAPI
|
||||
from gradio import Blocks
|
||||
|
||||
from modules import errors, timer, extensions, shared, util
|
||||
|
||||
|
||||
def report_exception(c, job):
|
||||
errors.report(f"Error executing callback {job} for {c.script}", exc_info=True)
|
||||
|
||||
|
||||
class ImageSaveParams:
|
||||
def __init__(self, image, p, filename, pnginfo):
|
||||
self.image = image
|
||||
"""the PIL image itself"""
|
||||
|
||||
self.p = p
|
||||
"""p object with processing parameters; either StableDiffusionProcessing or an object with same fields"""
|
||||
|
||||
self.filename = filename
|
||||
"""name of file that the image would be saved to"""
|
||||
|
||||
self.pnginfo = pnginfo
|
||||
"""dictionary with parameters for image's PNG info data; infotext will have the key 'parameters'"""
|
||||
|
||||
|
||||
class ExtraNoiseParams:
|
||||
def __init__(self, noise, x, xi):
|
||||
self.noise = noise
|
||||
"""Random noise generated by the seed"""
|
||||
|
||||
self.x = x
|
||||
"""Latent representation of the image"""
|
||||
|
||||
self.xi = xi
|
||||
"""Noisy latent representation of the image"""
|
||||
|
||||
|
||||
class CFGDenoiserParams:
|
||||
def __init__(self, x, image_cond, sigma, sampling_step, total_sampling_steps, text_cond, text_uncond, denoiser=None):
|
||||
self.x = x
|
||||
"""Latent image representation in the process of being denoised"""
|
||||
|
||||
self.image_cond = image_cond
|
||||
"""Conditioning image"""
|
||||
|
||||
self.sigma = sigma
|
||||
"""Current sigma noise step value"""
|
||||
|
||||
self.sampling_step = sampling_step
|
||||
"""Current Sampling step number"""
|
||||
|
||||
self.total_sampling_steps = total_sampling_steps
|
||||
"""Total number of sampling steps planned"""
|
||||
|
||||
self.text_cond = text_cond
|
||||
""" Encoder hidden states of text conditioning from prompt"""
|
||||
|
||||
self.text_uncond = text_uncond
|
||||
""" Encoder hidden states of text conditioning from negative prompt"""
|
||||
|
||||
self.denoiser = denoiser
|
||||
"""Current CFGDenoiser object with processing parameters"""
|
||||
|
||||
|
||||
class CFGDenoisedParams:
|
||||
def __init__(self, x, sampling_step, total_sampling_steps, inner_model):
|
||||
self.x = x
|
||||
"""Latent image representation in the process of being denoised"""
|
||||
|
||||
self.sampling_step = sampling_step
|
||||
"""Current Sampling step number"""
|
||||
|
||||
self.total_sampling_steps = total_sampling_steps
|
||||
"""Total number of sampling steps planned"""
|
||||
|
||||
self.inner_model = inner_model
|
||||
"""Inner model reference used for denoising"""
|
||||
|
||||
|
||||
class AfterCFGCallbackParams:
|
||||
def __init__(self, x, sampling_step, total_sampling_steps):
|
||||
self.x = x
|
||||
"""Latent image representation in the process of being denoised"""
|
||||
|
||||
self.sampling_step = sampling_step
|
||||
"""Current Sampling step number"""
|
||||
|
||||
self.total_sampling_steps = total_sampling_steps
|
||||
"""Total number of sampling steps planned"""
|
||||
|
||||
|
||||
class UiTrainTabParams:
|
||||
def __init__(self, txt2img_preview_params):
|
||||
self.txt2img_preview_params = txt2img_preview_params
|
||||
|
||||
|
||||
class ImageGridLoopParams:
|
||||
def __init__(self, imgs, cols, rows):
|
||||
self.imgs = imgs
|
||||
self.cols = cols
|
||||
self.rows = rows
|
||||
|
||||
|
||||
@dataclasses.dataclass
|
||||
class BeforeTokenCounterParams:
|
||||
prompt: str
|
||||
steps: int
|
||||
styles: list
|
||||
|
||||
is_positive: bool = True
|
||||
|
||||
|
||||
@dataclasses.dataclass
|
||||
class ScriptCallback:
|
||||
script: str
|
||||
callback: any
|
||||
name: str = "unnamed"
|
||||
|
||||
|
||||
def add_callback(callbacks, fun, *, name=None, category='unknown', filename=None):
|
||||
if filename is None:
|
||||
stack = [x for x in inspect.stack() if x.filename != __file__]
|
||||
filename = stack[0].filename if stack else 'unknown file'
|
||||
|
||||
extension = extensions.find_extension(filename)
|
||||
extension_name = extension.canonical_name if extension else 'base'
|
||||
|
||||
callback_name = f"{extension_name}/{os.path.basename(filename)}/{category}"
|
||||
if name is not None:
|
||||
callback_name += f'/{name}'
|
||||
|
||||
unique_callback_name = callback_name
|
||||
for index in range(1000):
|
||||
existing = any(x.name == unique_callback_name for x in callbacks)
|
||||
if not existing:
|
||||
break
|
||||
|
||||
unique_callback_name = f'{callback_name}-{index+1}'
|
||||
|
||||
callbacks.append(ScriptCallback(filename, fun, unique_callback_name))
|
||||
|
||||
|
||||
def sort_callbacks(category, unordered_callbacks, *, enable_user_sort=True):
|
||||
callbacks = unordered_callbacks.copy()
|
||||
callback_lookup = {x.name: x for x in callbacks}
|
||||
dependencies = {}
|
||||
|
||||
order_instructions = {}
|
||||
for extension in extensions.extensions:
|
||||
for order_instruction in extension.metadata.list_callback_order_instructions():
|
||||
if order_instruction.name in callback_lookup:
|
||||
if order_instruction.name not in order_instructions:
|
||||
order_instructions[order_instruction.name] = []
|
||||
|
||||
order_instructions[order_instruction.name].append(order_instruction)
|
||||
|
||||
if order_instructions:
|
||||
for callback in callbacks:
|
||||
dependencies[callback.name] = []
|
||||
|
||||
for callback in callbacks:
|
||||
for order_instruction in order_instructions.get(callback.name, []):
|
||||
for after in order_instruction.after:
|
||||
if after not in callback_lookup:
|
||||
continue
|
||||
|
||||
dependencies[callback.name].append(after)
|
||||
|
||||
for before in order_instruction.before:
|
||||
if before not in callback_lookup:
|
||||
continue
|
||||
|
||||
dependencies[before].append(callback.name)
|
||||
|
||||
sorted_names = util.topological_sort(dependencies)
|
||||
callbacks = [callback_lookup[x] for x in sorted_names]
|
||||
|
||||
if enable_user_sort:
|
||||
for name in reversed(getattr(shared.opts, 'prioritized_callbacks_' + category, [])):
|
||||
index = next((i for i, callback in enumerate(callbacks) if callback.name == name), None)
|
||||
if index is not None:
|
||||
callbacks.insert(0, callbacks.pop(index))
|
||||
|
||||
return callbacks
|
||||
|
||||
|
||||
def ordered_callbacks(category, unordered_callbacks=None, *, enable_user_sort=True):
|
||||
if unordered_callbacks is None:
|
||||
unordered_callbacks = callback_map.get('callbacks_' + category, [])
|
||||
|
||||
if not enable_user_sort:
|
||||
return sort_callbacks(category, unordered_callbacks, enable_user_sort=False)
|
||||
|
||||
callbacks = ordered_callbacks_map.get(category)
|
||||
if callbacks is not None and len(callbacks) == len(unordered_callbacks):
|
||||
return callbacks
|
||||
|
||||
callbacks = sort_callbacks(category, unordered_callbacks)
|
||||
|
||||
ordered_callbacks_map[category] = callbacks
|
||||
return callbacks
|
||||
|
||||
|
||||
def enumerate_callbacks():
|
||||
for category, callbacks in callback_map.items():
|
||||
if category.startswith('callbacks_'):
|
||||
category = category[10:]
|
||||
|
||||
yield category, callbacks
|
||||
|
||||
|
||||
callback_map = dict(
|
||||
callbacks_app_started=[],
|
||||
callbacks_model_loaded=[],
|
||||
callbacks_ui_tabs=[],
|
||||
callbacks_ui_train_tabs=[],
|
||||
callbacks_ui_settings=[],
|
||||
callbacks_before_image_saved=[],
|
||||
callbacks_image_saved=[],
|
||||
callbacks_extra_noise=[],
|
||||
callbacks_cfg_denoiser=[],
|
||||
callbacks_cfg_denoised=[],
|
||||
callbacks_cfg_after_cfg=[],
|
||||
callbacks_before_component=[],
|
||||
callbacks_after_component=[],
|
||||
callbacks_image_grid=[],
|
||||
callbacks_infotext_pasted=[],
|
||||
callbacks_script_unloaded=[],
|
||||
callbacks_before_ui=[],
|
||||
callbacks_on_reload=[],
|
||||
callbacks_list_optimizers=[],
|
||||
callbacks_list_unets=[],
|
||||
callbacks_before_token_counter=[],
|
||||
)
|
||||
|
||||
ordered_callbacks_map = {}
|
||||
|
||||
|
||||
def clear_callbacks():
|
||||
for callback_list in callback_map.values():
|
||||
callback_list.clear()
|
||||
|
||||
ordered_callbacks_map.clear()
|
||||
|
||||
|
||||
def app_started_callback(demo: Optional[Blocks], app: FastAPI):
|
||||
for c in ordered_callbacks('app_started'):
|
||||
try:
|
||||
c.callback(demo, app)
|
||||
timer.startup_timer.record(os.path.basename(c.script))
|
||||
except Exception:
|
||||
report_exception(c, 'app_started_callback')
|
||||
|
||||
|
||||
def app_reload_callback():
|
||||
for c in ordered_callbacks('on_reload'):
|
||||
try:
|
||||
c.callback()
|
||||
except Exception:
|
||||
report_exception(c, 'callbacks_on_reload')
|
||||
|
||||
|
||||
def model_loaded_callback(sd_model):
|
||||
for c in ordered_callbacks('model_loaded'):
|
||||
try:
|
||||
c.callback(sd_model)
|
||||
except Exception:
|
||||
report_exception(c, 'model_loaded_callback')
|
||||
|
||||
|
||||
def ui_tabs_callback():
|
||||
res = []
|
||||
|
||||
for c in ordered_callbacks('ui_tabs'):
|
||||
try:
|
||||
res += c.callback() or []
|
||||
except Exception:
|
||||
report_exception(c, 'ui_tabs_callback')
|
||||
|
||||
return res
|
||||
|
||||
|
||||
def ui_train_tabs_callback(params: UiTrainTabParams):
|
||||
for c in ordered_callbacks('ui_train_tabs'):
|
||||
try:
|
||||
c.callback(params)
|
||||
except Exception:
|
||||
report_exception(c, 'callbacks_ui_train_tabs')
|
||||
|
||||
|
||||
def ui_settings_callback():
|
||||
for c in ordered_callbacks('ui_settings'):
|
||||
try:
|
||||
c.callback()
|
||||
except Exception:
|
||||
report_exception(c, 'ui_settings_callback')
|
||||
|
||||
|
||||
def before_image_saved_callback(params: ImageSaveParams):
|
||||
for c in ordered_callbacks('before_image_saved'):
|
||||
try:
|
||||
c.callback(params)
|
||||
except Exception:
|
||||
report_exception(c, 'before_image_saved_callback')
|
||||
|
||||
|
||||
def image_saved_callback(params: ImageSaveParams):
|
||||
for c in ordered_callbacks('image_saved'):
|
||||
try:
|
||||
c.callback(params)
|
||||
except Exception:
|
||||
report_exception(c, 'image_saved_callback')
|
||||
|
||||
|
||||
def extra_noise_callback(params: ExtraNoiseParams):
|
||||
for c in ordered_callbacks('extra_noise'):
|
||||
try:
|
||||
c.callback(params)
|
||||
except Exception:
|
||||
report_exception(c, 'callbacks_extra_noise')
|
||||
|
||||
|
||||
def cfg_denoiser_callback(params: CFGDenoiserParams):
|
||||
for c in ordered_callbacks('cfg_denoiser'):
|
||||
try:
|
||||
c.callback(params)
|
||||
except Exception:
|
||||
report_exception(c, 'cfg_denoiser_callback')
|
||||
|
||||
|
||||
def cfg_denoised_callback(params: CFGDenoisedParams):
|
||||
for c in ordered_callbacks('cfg_denoised'):
|
||||
try:
|
||||
c.callback(params)
|
||||
except Exception:
|
||||
report_exception(c, 'cfg_denoised_callback')
|
||||
|
||||
|
||||
def cfg_after_cfg_callback(params: AfterCFGCallbackParams):
|
||||
for c in ordered_callbacks('cfg_after_cfg'):
|
||||
try:
|
||||
c.callback(params)
|
||||
except Exception:
|
||||
report_exception(c, 'cfg_after_cfg_callback')
|
||||
|
||||
|
||||
def before_component_callback(component, **kwargs):
|
||||
for c in ordered_callbacks('before_component'):
|
||||
try:
|
||||
c.callback(component, **kwargs)
|
||||
except Exception:
|
||||
report_exception(c, 'before_component_callback')
|
||||
|
||||
|
||||
def after_component_callback(component, **kwargs):
|
||||
for c in ordered_callbacks('after_component'):
|
||||
try:
|
||||
c.callback(component, **kwargs)
|
||||
except Exception:
|
||||
report_exception(c, 'after_component_callback')
|
||||
|
||||
|
||||
def image_grid_callback(params: ImageGridLoopParams):
|
||||
for c in ordered_callbacks('image_grid'):
|
||||
try:
|
||||
c.callback(params)
|
||||
except Exception:
|
||||
report_exception(c, 'image_grid')
|
||||
|
||||
|
||||
def infotext_pasted_callback(infotext: str, params: dict[str, Any]):
|
||||
for c in ordered_callbacks('infotext_pasted'):
|
||||
try:
|
||||
c.callback(infotext, params)
|
||||
except Exception:
|
||||
report_exception(c, 'infotext_pasted')
|
||||
|
||||
|
||||
def script_unloaded_callback():
|
||||
for c in reversed(ordered_callbacks('script_unloaded')):
|
||||
try:
|
||||
c.callback()
|
||||
except Exception:
|
||||
report_exception(c, 'script_unloaded')
|
||||
|
||||
|
||||
def before_ui_callback():
|
||||
for c in reversed(ordered_callbacks('before_ui')):
|
||||
try:
|
||||
c.callback()
|
||||
except Exception:
|
||||
report_exception(c, 'before_ui')
|
||||
|
||||
|
||||
def list_optimizers_callback():
|
||||
res = []
|
||||
|
||||
for c in ordered_callbacks('list_optimizers'):
|
||||
try:
|
||||
c.callback(res)
|
||||
except Exception:
|
||||
report_exception(c, 'list_optimizers')
|
||||
|
||||
return res
|
||||
|
||||
|
||||
def list_unets_callback():
|
||||
res = []
|
||||
|
||||
for c in ordered_callbacks('list_unets'):
|
||||
try:
|
||||
c.callback(res)
|
||||
except Exception:
|
||||
report_exception(c, 'list_unets')
|
||||
|
||||
return res
|
||||
|
||||
|
||||
def before_token_counter_callback(params: BeforeTokenCounterParams):
|
||||
for c in ordered_callbacks('before_token_counter'):
|
||||
try:
|
||||
c.callback(params)
|
||||
except Exception:
|
||||
report_exception(c, 'before_token_counter')
|
||||
|
||||
|
||||
def remove_current_script_callbacks():
|
||||
stack = [x for x in inspect.stack() if x.filename != __file__]
|
||||
filename = stack[0].filename if stack else 'unknown file'
|
||||
if filename == 'unknown file':
|
||||
return
|
||||
for callback_list in callback_map.values():
|
||||
for callback_to_remove in [cb for cb in callback_list if cb.script == filename]:
|
||||
callback_list.remove(callback_to_remove)
|
||||
for ordered_callbacks_list in ordered_callbacks_map.values():
|
||||
for callback_to_remove in [cb for cb in ordered_callbacks_list if cb.script == filename]:
|
||||
ordered_callbacks_list.remove(callback_to_remove)
|
||||
|
||||
|
||||
def remove_callbacks_for_function(callback_func):
|
||||
for callback_list in callback_map.values():
|
||||
for callback_to_remove in [cb for cb in callback_list if cb.callback == callback_func]:
|
||||
callback_list.remove(callback_to_remove)
|
||||
for ordered_callback_list in ordered_callbacks_map.values():
|
||||
for callback_to_remove in [cb for cb in ordered_callback_list if cb.callback == callback_func]:
|
||||
ordered_callback_list.remove(callback_to_remove)
|
||||
|
||||
|
||||
def on_app_started(callback, *, name=None):
|
||||
"""register a function to be called when the webui started, the gradio `Block` component and
|
||||
fastapi `FastAPI` object are passed as the arguments"""
|
||||
add_callback(callback_map['callbacks_app_started'], callback, name=name, category='app_started')
|
||||
|
||||
|
||||
def on_before_reload(callback, *, name=None):
|
||||
"""register a function to be called just before the server reloads."""
|
||||
add_callback(callback_map['callbacks_on_reload'], callback, name=name, category='on_reload')
|
||||
|
||||
|
||||
def on_model_loaded(callback, *, name=None):
|
||||
"""register a function to be called when the stable diffusion model is created; the model is
|
||||
passed as an argument; this function is also called when the script is reloaded. """
|
||||
add_callback(callback_map['callbacks_model_loaded'], callback, name=name, category='model_loaded')
|
||||
|
||||
|
||||
def on_ui_tabs(callback, *, name=None):
|
||||
"""register a function to be called when the UI is creating new tabs.
|
||||
The function must either return a None, which means no new tabs to be added, or a list, where
|
||||
each element is a tuple:
|
||||
(gradio_component, title, elem_id)
|
||||
|
||||
gradio_component is a gradio component to be used for contents of the tab (usually gr.Blocks)
|
||||
title is tab text displayed to user in the UI
|
||||
elem_id is HTML id for the tab
|
||||
"""
|
||||
add_callback(callback_map['callbacks_ui_tabs'], callback, name=name, category='ui_tabs')
|
||||
|
||||
|
||||
def on_ui_train_tabs(callback, *, name=None):
|
||||
"""register a function to be called when the UI is creating new tabs for the train tab.
|
||||
Create your new tabs with gr.Tab.
|
||||
"""
|
||||
add_callback(callback_map['callbacks_ui_train_tabs'], callback, name=name, category='ui_train_tabs')
|
||||
|
||||
|
||||
def on_ui_settings(callback, *, name=None):
|
||||
"""register a function to be called before UI settings are populated; add your settings
|
||||
by using shared.opts.add_option(shared.OptionInfo(...)) """
|
||||
add_callback(callback_map['callbacks_ui_settings'], callback, name=name, category='ui_settings')
|
||||
|
||||
|
||||
def on_before_image_saved(callback, *, name=None):
|
||||
"""register a function to be called before an image is saved to a file.
|
||||
The callback is called with one argument:
|
||||
- params: ImageSaveParams - parameters the image is to be saved with. You can change fields in this object.
|
||||
"""
|
||||
add_callback(callback_map['callbacks_before_image_saved'], callback, name=name, category='before_image_saved')
|
||||
|
||||
|
||||
def on_image_saved(callback, *, name=None):
|
||||
"""register a function to be called after an image is saved to a file.
|
||||
The callback is called with one argument:
|
||||
- params: ImageSaveParams - parameters the image was saved with. Changing fields in this object does nothing.
|
||||
"""
|
||||
add_callback(callback_map['callbacks_image_saved'], callback, name=name, category='image_saved')
|
||||
|
||||
|
||||
def on_extra_noise(callback, *, name=None):
|
||||
"""register a function to be called before adding extra noise in img2img or hires fix;
|
||||
The callback is called with one argument:
|
||||
- params: ExtraNoiseParams - contains noise determined by seed and latent representation of image
|
||||
"""
|
||||
add_callback(callback_map['callbacks_extra_noise'], callback, name=name, category='extra_noise')
|
||||
|
||||
|
||||
def on_cfg_denoiser(callback, *, name=None):
|
||||
"""register a function to be called in the kdiffussion cfg_denoiser method after building the inner model inputs.
|
||||
The callback is called with one argument:
|
||||
- params: CFGDenoiserParams - parameters to be passed to the inner model and sampling state details.
|
||||
"""
|
||||
add_callback(callback_map['callbacks_cfg_denoiser'], callback, name=name, category='cfg_denoiser')
|
||||
|
||||
|
||||
def on_cfg_denoised(callback, *, name=None):
|
||||
"""register a function to be called in the kdiffussion cfg_denoiser method after building the inner model inputs.
|
||||
The callback is called with one argument:
|
||||
- params: CFGDenoisedParams - parameters to be passed to the inner model and sampling state details.
|
||||
"""
|
||||
add_callback(callback_map['callbacks_cfg_denoised'], callback, name=name, category='cfg_denoised')
|
||||
|
||||
|
||||
def on_cfg_after_cfg(callback, *, name=None):
|
||||
"""register a function to be called in the kdiffussion cfg_denoiser method after cfg calculations are completed.
|
||||
The callback is called with one argument:
|
||||
- params: AfterCFGCallbackParams - parameters to be passed to the script for post-processing after cfg calculation.
|
||||
"""
|
||||
add_callback(callback_map['callbacks_cfg_after_cfg'], callback, name=name, category='cfg_after_cfg')
|
||||
|
||||
|
||||
def on_before_component(callback, *, name=None):
|
||||
"""register a function to be called before a component is created.
|
||||
The callback is called with arguments:
|
||||
- component - gradio component that is about to be created.
|
||||
- **kwargs - args to gradio.components.IOComponent.__init__ function
|
||||
|
||||
Use elem_id/label fields of kwargs to figure out which component it is.
|
||||
This can be useful to inject your own components somewhere in the middle of vanilla UI.
|
||||
"""
|
||||
add_callback(callback_map['callbacks_before_component'], callback, name=name, category='before_component')
|
||||
|
||||
|
||||
def on_after_component(callback, *, name=None):
|
||||
"""register a function to be called after a component is created. See on_before_component for more."""
|
||||
add_callback(callback_map['callbacks_after_component'], callback, name=name, category='after_component')
|
||||
|
||||
|
||||
def on_image_grid(callback, *, name=None):
|
||||
"""register a function to be called before making an image grid.
|
||||
The callback is called with one argument:
|
||||
- params: ImageGridLoopParams - parameters to be used for grid creation. Can be modified.
|
||||
"""
|
||||
add_callback(callback_map['callbacks_image_grid'], callback, name=name, category='image_grid')
|
||||
|
||||
|
||||
def on_infotext_pasted(callback, *, name=None):
|
||||
"""register a function to be called before applying an infotext.
|
||||
The callback is called with two arguments:
|
||||
- infotext: str - raw infotext.
|
||||
- result: dict[str, any] - parsed infotext parameters.
|
||||
"""
|
||||
add_callback(callback_map['callbacks_infotext_pasted'], callback, name=name, category='infotext_pasted')
|
||||
|
||||
|
||||
def on_script_unloaded(callback, *, name=None):
|
||||
"""register a function to be called before the script is unloaded. Any hooks/hijacks/monkeying about that
|
||||
the script did should be reverted here"""
|
||||
|
||||
add_callback(callback_map['callbacks_script_unloaded'], callback, name=name, category='script_unloaded')
|
||||
|
||||
|
||||
def on_before_ui(callback, *, name=None):
|
||||
"""register a function to be called before the UI is created."""
|
||||
|
||||
add_callback(callback_map['callbacks_before_ui'], callback, name=name, category='before_ui')
|
||||
|
||||
|
||||
def on_list_optimizers(callback, *, name=None):
|
||||
"""register a function to be called when UI is making a list of cross attention optimization options.
|
||||
The function will be called with one argument, a list, and shall add objects of type modules.sd_hijack_optimizations.SdOptimization
|
||||
to it."""
|
||||
|
||||
add_callback(callback_map['callbacks_list_optimizers'], callback, name=name, category='list_optimizers')
|
||||
|
||||
|
||||
def on_list_unets(callback, *, name=None):
|
||||
"""register a function to be called when UI is making a list of alternative options for unet.
|
||||
The function will be called with one argument, a list, and shall add objects of type modules.sd_unet.SdUnetOption to it."""
|
||||
|
||||
add_callback(callback_map['callbacks_list_unets'], callback, name=name, category='list_unets')
|
||||
|
||||
|
||||
def on_before_token_counter(callback, *, name=None):
|
||||
"""register a function to be called when UI is counting tokens for a prompt.
|
||||
The function will be called with one argument of type BeforeTokenCounterParams, and should modify its fields if necessary."""
|
||||
|
||||
add_callback(callback_map['callbacks_before_token_counter'], callback, name=name, category='before_token_counter')
|
||||
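An extension-side sketch of how the registration functions above (on_ui_settings, on_image_saved, and the rest) are typically used from a script; the option key and the print are made up for illustration.

```python
from modules import script_callbacks, shared

def on_settings():
    # hypothetical option key, registered before the settings UI is populated
    shared.opts.add_option("my_extension_enabled", shared.OptionInfo(True, "Enable my extension"))

def on_saved(params):                 # params is an ImageSaveParams instance
    print("saved", params.filename)

script_callbacks.on_ui_settings(on_settings)
script_callbacks.on_image_saved(on_saved)
```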
35
modules/script_loading.py
Executable file
35
modules/script_loading.py
Executable file
@@ -0,0 +1,35 @@
import os
import importlib.util

from modules import errors


loaded_scripts = {}


def load_module(path):
    module_spec = importlib.util.spec_from_file_location(os.path.basename(path), path)
    module = importlib.util.module_from_spec(module_spec)
    module_spec.loader.exec_module(module)

    loaded_scripts[path] = module
    return module


def preload_extensions(extensions_dir, parser, extension_list=None):
    if not os.path.isdir(extensions_dir):
        return

    extensions = extension_list if extension_list is not None else os.listdir(extensions_dir)
    for dirname in sorted(extensions):
        preload_script = os.path.join(extensions_dir, dirname, "preload.py")
        if not os.path.isfile(preload_script):
            continue

        try:
            module = load_module(preload_script)
            if hasattr(module, 'preload'):
                module.preload(parser)

        except Exception:
            errors.report(f"Error running preload() for {preload_script}", exc_info=True)
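A hypothetical extensions/&lt;name&gt;/preload.py sketch: preload() receives the command-line parser before arguments are parsed, so an extension can register its own flags here. The flag name is made up.

```python
def preload(parser):
    parser.add_argument("--my-extension-models-dir", type=str, default=None,
                        help="hypothetical path override registered by this extension")
```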
1068
modules/scripts.py
Executable file
1068
modules/scripts.py
Executable file
File diff suppressed because it is too large
42
modules/scripts_auto_postprocessing.py
Executable file
42
modules/scripts_auto_postprocessing.py
Executable file
@@ -0,0 +1,42 @@
from modules import scripts, scripts_postprocessing, shared


class ScriptPostprocessingForMainUI(scripts.Script):
    def __init__(self, script_postproc):
        self.script: scripts_postprocessing.ScriptPostprocessing = script_postproc
        self.postprocessing_controls = None

    def title(self):
        return self.script.name

    def show(self, is_img2img):
        return scripts.AlwaysVisible

    def ui(self, is_img2img):
        self.postprocessing_controls = self.script.ui()
        return self.postprocessing_controls.values()

    def postprocess_image(self, p, script_pp, *args):
        args_dict = dict(zip(self.postprocessing_controls, args))

        pp = scripts_postprocessing.PostprocessedImage(script_pp.image)
        pp.info = {}
        self.script.process(pp, **args_dict)
        p.extra_generation_params.update(pp.info)
        script_pp.image = pp.image


def create_auto_preprocessing_script_data():
    from modules import scripts

    res = []

    for name in shared.opts.postprocessing_enable_in_main_ui:
        script = next(iter([x for x in scripts.postprocessing_scripts_data if x.script_class.name == name]), None)
        if script is None:
            continue

        constructor = lambda s=script: ScriptPostprocessingForMainUI(s.script_class())
        res.append(scripts.ScriptClassData(script_class=constructor, path=script.path, basedir=script.basedir, module=script.module))

    return res
230
modules/scripts_postprocessing.py
Executable file
230
modules/scripts_postprocessing.py
Executable file
@@ -0,0 +1,230 @@
|
||||
import dataclasses
|
||||
import os
|
||||
import gradio as gr
|
||||
|
||||
from modules import errors, shared
|
||||
|
||||
|
||||
@dataclasses.dataclass
|
||||
class PostprocessedImageSharedInfo:
|
||||
target_width: int = None
|
||||
target_height: int = None
|
||||
|
||||
|
||||
class PostprocessedImage:
|
||||
def __init__(self, image):
|
||||
self.image = image
|
||||
self.info = {}
|
||||
self.shared = PostprocessedImageSharedInfo()
|
||||
self.extra_images = []
|
||||
self.nametags = []
|
||||
self.disable_processing = False
|
||||
self.caption = None
|
||||
|
||||
def get_suffix(self, used_suffixes=None):
|
||||
used_suffixes = {} if used_suffixes is None else used_suffixes
|
||||
suffix = "-".join(self.nametags)
|
||||
if suffix:
|
||||
suffix = "-" + suffix
|
||||
|
||||
if suffix not in used_suffixes:
|
||||
used_suffixes[suffix] = 1
|
||||
return suffix
|
||||
|
||||
for i in range(1, 100):
|
||||
proposed_suffix = suffix + "-" + str(i)
|
||||
|
||||
if proposed_suffix not in used_suffixes:
|
||||
used_suffixes[proposed_suffix] = 1
|
||||
return proposed_suffix
|
||||
|
||||
return suffix
|
||||
|
||||
def create_copy(self, new_image, *, nametags=None, disable_processing=False):
|
||||
pp = PostprocessedImage(new_image)
|
||||
pp.shared = self.shared
|
||||
pp.nametags = self.nametags.copy()
|
||||
pp.info = self.info.copy()
|
||||
pp.disable_processing = disable_processing
|
||||
|
||||
if nametags is not None:
|
||||
pp.nametags += nametags
|
||||
|
||||
return pp
|
||||
|
||||
|
||||
class ScriptPostprocessing:
|
||||
filename = None
|
||||
controls = None
|
||||
args_from = None
|
||||
args_to = None
|
||||
|
||||
order = 1000
|
||||
"""scripts will be ordred by this value in postprocessing UI"""
|
||||
|
||||
name = None
|
||||
"""this function should return the title of the script."""
|
||||
|
||||
group = None
|
||||
"""A gr.Group component that has all script's UI inside it"""
|
||||
|
||||
def ui(self):
|
||||
"""
|
||||
This function should create gradio UI elements. See https://gradio.app/docs/#components
|
||||
The return value should be a dictionary that maps parameter names to components used in processing.
|
||||
Values of those components will be passed to process() function.
|
||||
"""
|
||||
|
||||
pass
|
||||
|
||||
def process(self, pp: PostprocessedImage, **args):
|
||||
"""
|
||||
This function is called to postprocess the image.
|
||||
args contains a dictionary with all values returned by components from ui()
|
||||
"""
|
||||
|
||||
pass
|
||||
|
||||
def process_firstpass(self, pp: PostprocessedImage, **args):
|
||||
"""
|
||||
Called for all scripts before calling process(). Scripts can examine the image here and set fields
|
||||
of the pp object to communicate things to other scripts.
|
||||
args contains a dictionary with all values returned by components from ui()
|
||||
"""
|
||||
|
||||
pass
|
||||
|
||||
def image_changed(self):
|
||||
pass
|
||||
|
||||
|
||||
def wrap_call(func, filename, funcname, *args, default=None, **kwargs):
|
||||
try:
|
||||
res = func(*args, **kwargs)
|
||||
return res
|
||||
except Exception as e:
|
||||
errors.display(e, f"calling {filename}/{funcname}")
|
||||
|
||||
return default
|
||||
|
||||
|
||||
class ScriptPostprocessingRunner:
    def __init__(self):
        self.scripts = None
        self.ui_created = False

    def initialize_scripts(self, scripts_data):
        self.scripts = []

        for script_data in scripts_data:
            script: ScriptPostprocessing = script_data.script_class()
            script.filename = script_data.path

            if script.name == "Simple Upscale":
                continue

            self.scripts.append(script)

    def create_script_ui(self, script, inputs):
        script.args_from = len(inputs)
        script.args_to = len(inputs)

        script.controls = wrap_call(script.ui, script.filename, "ui")

        for control in script.controls.values():
            control.custom_script_source = os.path.basename(script.filename)

        inputs += list(script.controls.values())
        script.args_to = len(inputs)

    def scripts_in_preferred_order(self):
        if self.scripts is None:
            import modules.scripts
            self.initialize_scripts(modules.scripts.postprocessing_scripts_data)

        scripts_order = shared.opts.postprocessing_operation_order
        scripts_filter_out = set(shared.opts.postprocessing_disable_in_extras)

        def script_score(name):
            for i, possible_match in enumerate(scripts_order):
                if possible_match == name:
                    return i

            return len(self.scripts)

        filtered_scripts = [script for script in self.scripts if script.name not in scripts_filter_out]
        script_scores = {script.name: (script_score(script.name), script.order, script.name, original_index) for original_index, script in enumerate(filtered_scripts)}

        return sorted(filtered_scripts, key=lambda x: script_scores[x.name])

    def setup_ui(self):
        inputs = []

        for script in self.scripts_in_preferred_order():
            with gr.Row() as group:
                self.create_script_ui(script, inputs)

            script.group = group

        self.ui_created = True
        return inputs

    def run(self, pp: PostprocessedImage, args):
        scripts = []

        for script in self.scripts_in_preferred_order():
            script_args = args[script.args_from:script.args_to]

            process_args = {}
            for (name, _component), value in zip(script.controls.items(), script_args):
                process_args[name] = value

            scripts.append((script, process_args))

        for script, process_args in scripts:
            script.process_firstpass(pp, **process_args)

        all_images = [pp]

        for script, process_args in scripts:
            if shared.state.skipped:
                break

            shared.state.job = script.name

            for single_image in all_images.copy():

                if not single_image.disable_processing:
                    script.process(single_image, **process_args)

                for extra_image in single_image.extra_images:
                    if not isinstance(extra_image, PostprocessedImage):
                        extra_image = single_image.create_copy(extra_image)

                    all_images.append(extra_image)

                single_image.extra_images.clear()

        pp.extra_images = all_images[1:]

    def create_args_for_run(self, scripts_args):
        if not self.ui_created:
            with gr.Blocks(analytics_enabled=False):
                self.setup_ui()

        scripts = self.scripts_in_preferred_order()
        args = [None] * max([x.args_to for x in scripts])

        for script in scripts:
            script_args_dict = scripts_args.get(script.name, None)
            if script_args_dict is not None:

                for i, name in enumerate(script.controls):
                    args[script.args_from + i] = script_args_dict.get(name, None)

        return args

    def image_changed(self):
        for script in self.scripts_in_preferred_order():
            script.image_changed()
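
# Illustrative sketch (not part of the original file): how the runner's API fits together.
# create_args_for_run() turns a {script name: {control name: value}} mapping into the flat
# positional args list that run() then slices per script via args_from/args_to. The
# "Example Blur" script and its "radius" control are assumptions for demonstration only.
#
# from PIL import Image
#
# runner = ScriptPostprocessingRunner()
# args = runner.create_args_for_run({
#     "Example Blur": {"radius": 2},
# })
#
# pp = PostprocessedImage(Image.new("RGB", (512, 512)))
# runner.run(pp, args)          # each script receives only its own slice of args
# print(pp.info)
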
232
modules/sd_disable_initialization.py
Executable file
@@ -0,0 +1,232 @@
|
||||
# import ldm.modules.encoders.modules
|
||||
# import open_clip
|
||||
# import torch
|
||||
# import transformers.utils.hub
|
||||
#
|
||||
# from modules import shared
|
||||
#
|
||||
#
|
||||
# class ReplaceHelper:
|
||||
# def __init__(self):
|
||||
# self.replaced = []
|
||||
#
|
||||
# def replace(self, obj, field, func):
|
||||
# original = getattr(obj, field, None)
|
||||
# if original is None:
|
||||
# return None
|
||||
#
|
||||
# self.replaced.append((obj, field, original))
|
||||
# setattr(obj, field, func)
|
||||
#
|
||||
# return original
|
||||
#
|
||||
# def restore(self):
|
||||
# for obj, field, original in self.replaced:
|
||||
# setattr(obj, field, original)
|
||||
#
|
||||
# self.replaced.clear()
|
||||
#
|
||||
#
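# Illustrative sketch (not part of the original file): the ReplaceHelper pattern above is
# plain attribute monkey-patching with bookkeeping so that everything can be undone. A
# minimal standalone use, assuming only torch is available (the class itself is disabled
# in this file, so the snippet is shown commented out as well):
#
# import torch
#
# helper = ReplaceHelper()
# # make weight initialization a no-op while the helper is "active"
# helper.replace(torch.nn.init, 'kaiming_uniform_', lambda *args, **kwargs: None)
# layer = torch.nn.Linear(4, 4)   # created without running the real initializer
# helper.restore()                # torch.nn.init.kaiming_uniform_ is the original again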
|
||||
# class DisableInitialization(ReplaceHelper):
|
||||
# """
|
||||
# When an object of this class enters a `with` block, it starts:
# - preventing torch's layer initialization functions from working
# - changing CLIP and OpenCLIP so they do not download model weights
# - changing CLIP so it does not make requests to check whether there is a new version of a file you already have
|
||||
#
|
||||
# When it leaves the block, it reverts everything to how it was before.
|
||||
#
|
||||
# Use it like this:
|
||||
# ```
|
||||
# with DisableInitialization():
|
||||
# do_things()
|
||||
# ```
|
||||
# """
|
||||
#
|
||||
# def __init__(self, disable_clip=True):
|
||||
# super().__init__()
|
||||
# self.disable_clip = disable_clip
|
||||
#
|
||||
# def replace(self, obj, field, func):
|
||||
# original = getattr(obj, field, None)
|
||||
# if original is None:
|
||||
# return None
|
||||
#
|
||||
# self.replaced.append((obj, field, original))
|
||||
# setattr(obj, field, func)
|
||||
#
|
||||
# return original
|
||||
#
|
||||
# def __enter__(self):
|
||||
# def do_nothing(*args, **kwargs):
|
||||
# pass
|
||||
#
|
||||
# def create_model_and_transforms_without_pretrained(*args, pretrained=None, **kwargs):
|
||||
# return self.create_model_and_transforms(*args, pretrained=None, **kwargs)
|
||||
#
|
||||
# def CLIPTextModel_from_pretrained(pretrained_model_name_or_path, *model_args, **kwargs):
|
||||
# res = self.CLIPTextModel_from_pretrained(None, *model_args, config=pretrained_model_name_or_path, state_dict={}, **kwargs)
|
||||
# res.name_or_path = pretrained_model_name_or_path
|
||||
# return res
|
||||
#
|
||||
# def transformers_modeling_utils_load_pretrained_model(*args, **kwargs):
|
||||
# args = args[0:3] + ('/', ) + args[4:] # resolved_archive_file; must set it to something to prevent what seems to be a bug
|
||||
# return self.transformers_modeling_utils_load_pretrained_model(*args, **kwargs)
|
||||
#
|
||||
# def transformers_utils_hub_get_file_from_cache(original, url, *args, **kwargs):
|
||||
#
|
||||
# # this file is always 404, prevent making request
|
||||
# if url == 'https://huggingface.co/openai/clip-vit-large-patch14/resolve/main/added_tokens.json' or url == 'openai/clip-vit-large-patch14' and args[0] == 'added_tokens.json':
|
||||
# return None
|
||||
#
|
||||
# try:
|
||||
# res = original(url, *args, local_files_only=True, **kwargs)
|
||||
# if res is None:
|
||||
# res = original(url, *args, local_files_only=False, **kwargs)
|
||||
# return res
|
||||
# except Exception:
|
||||
# return original(url, *args, local_files_only=False, **kwargs)
|
||||
#
|
||||
# def transformers_utils_hub_get_from_cache(url, *args, local_files_only=False, **kwargs):
|
||||
# return transformers_utils_hub_get_file_from_cache(self.transformers_utils_hub_get_from_cache, url, *args, **kwargs)
|
||||
#
|
||||
# def transformers_tokenization_utils_base_cached_file(url, *args, local_files_only=False, **kwargs):
|
||||
# return transformers_utils_hub_get_file_from_cache(self.transformers_tokenization_utils_base_cached_file, url, *args, **kwargs)
|
||||
#
|
||||
# def transformers_configuration_utils_cached_file(url, *args, local_files_only=False, **kwargs):
|
||||
# return transformers_utils_hub_get_file_from_cache(self.transformers_configuration_utils_cached_file, url, *args, **kwargs)
|
||||
#
|
||||
# self.replace(torch.nn.init, 'kaiming_uniform_', do_nothing)
|
||||
# self.replace(torch.nn.init, '_no_grad_normal_', do_nothing)
|
||||
# self.replace(torch.nn.init, '_no_grad_uniform_', do_nothing)
|
||||
#
|
||||
# if self.disable_clip:
|
||||
# self.create_model_and_transforms = self.replace(open_clip, 'create_model_and_transforms', create_model_and_transforms_without_pretrained)
|
||||
# self.CLIPTextModel_from_pretrained = self.replace(ldm.modules.encoders.modules.CLIPTextModel, 'from_pretrained', CLIPTextModel_from_pretrained)
|
||||
# self.transformers_modeling_utils_load_pretrained_model = self.replace(transformers.modeling_utils.PreTrainedModel, '_load_pretrained_model', transformers_modeling_utils_load_pretrained_model)
|
||||
# self.transformers_tokenization_utils_base_cached_file = self.replace(transformers.tokenization_utils_base, 'cached_file', transformers_tokenization_utils_base_cached_file)
|
||||
# self.transformers_configuration_utils_cached_file = self.replace(transformers.configuration_utils, 'cached_file', transformers_configuration_utils_cached_file)
|
||||
# self.transformers_utils_hub_get_from_cache = self.replace(transformers.utils.hub, 'get_from_cache', transformers_utils_hub_get_from_cache)
|
||||
#
|
||||
# def __exit__(self, exc_type, exc_val, exc_tb):
|
||||
# self.restore()
|
||||
#
|
||||
#
|
||||
# class InitializeOnMeta(ReplaceHelper):
|
||||
# """
|
||||
# Context manager that causes all parameters for linear/conv2d/mha layers to be allocated on meta device,
|
||||
# which results in those parameters having no values and taking no memory. model.to() will be broken and
|
||||
# will need to be repaired by using LoadStateDictOnMeta below when loading params from state dict.
|
||||
#
|
||||
# Usage:
|
||||
# ```
|
||||
# with sd_disable_initialization.InitializeOnMeta():
|
||||
# sd_model = instantiate_from_config(sd_config.model)
|
||||
# ```
|
||||
# """
|
||||
#
|
||||
# def __enter__(self):
|
||||
# if shared.cmd_opts.disable_model_loading_ram_optimization:
|
||||
# return
|
||||
#
|
||||
# def set_device(x):
|
||||
# x["device"] = "meta"
|
||||
# return x
|
||||
#
|
||||
# linear_init = self.replace(torch.nn.Linear, '__init__', lambda *args, **kwargs: linear_init(*args, **set_device(kwargs)))
|
||||
# conv2d_init = self.replace(torch.nn.Conv2d, '__init__', lambda *args, **kwargs: conv2d_init(*args, **set_device(kwargs)))
|
||||
# mha_init = self.replace(torch.nn.MultiheadAttention, '__init__', lambda *args, **kwargs: mha_init(*args, **set_device(kwargs)))
|
||||
# self.replace(torch.nn.Module, 'to', lambda *args, **kwargs: None)
|
||||
#
|
||||
# def __exit__(self, exc_type, exc_val, exc_tb):
|
||||
# self.restore()
|
||||
#
|
||||
#
|
||||
# class LoadStateDictOnMeta(ReplaceHelper):
|
||||
# """
|
||||
# Context manager that allows reading parameters from state_dict into a model that has some of its parameters on the meta device.
|
||||
# As those parameters are read from state_dict, they will be deleted from it, so by the end state_dict will be mostly empty, to save memory.
|
||||
# Meant to be used together with InitializeOnMeta above.
|
||||
#
|
||||
# Usage:
|
||||
# ```
|
||||
# with sd_disable_initialization.LoadStateDictOnMeta(state_dict):
|
||||
# model.load_state_dict(state_dict, strict=False)
|
||||
# ```
|
||||
# """
|
||||
#
|
||||
# def __init__(self, state_dict, device, weight_dtype_conversion=None):
|
||||
# super().__init__()
|
||||
# self.state_dict = state_dict
|
||||
# self.device = device
|
||||
# self.weight_dtype_conversion = weight_dtype_conversion or {}
|
||||
# self.default_dtype = self.weight_dtype_conversion.get('')
|
||||
#
|
||||
# def get_weight_dtype(self, key):
|
||||
# key_first_term, _ = key.split('.', 1)
|
||||
# return self.weight_dtype_conversion.get(key_first_term, self.default_dtype)
|
||||
#
|
||||
# def __enter__(self):
|
||||
# if shared.cmd_opts.disable_model_loading_ram_optimization:
|
||||
# return
|
||||
#
|
||||
# sd = self.state_dict
|
||||
# device = self.device
|
||||
#
|
||||
# def load_from_state_dict(original, module, state_dict, prefix, *args, **kwargs):
|
||||
# used_param_keys = []
|
||||
#
|
||||
# for name, param in module._parameters.items():
|
||||
# if param is None:
|
||||
# continue
|
||||
#
|
||||
# key = prefix + name
|
||||
# sd_param = sd.pop(key, None)
|
||||
# if sd_param is not None:
|
||||
# state_dict[key] = sd_param.to(dtype=self.get_weight_dtype(key))
|
||||
# used_param_keys.append(key)
|
||||
#
|
||||
# if param.is_meta:
|
||||
# dtype = sd_param.dtype if sd_param is not None else param.dtype
|
||||
# module._parameters[name] = torch.nn.parameter.Parameter(torch.zeros_like(param, device=device, dtype=dtype), requires_grad=param.requires_grad)
|
||||
#
|
||||
# for name in module._buffers:
|
||||
# key = prefix + name
|
||||
#
|
||||
# sd_param = sd.pop(key, None)
|
||||
# if sd_param is not None:
|
||||
# state_dict[key] = sd_param
|
||||
# used_param_keys.append(key)
|
||||
#
|
||||
# original(module, state_dict, prefix, *args, **kwargs)
|
||||
#
|
||||
# for key in used_param_keys:
|
||||
# state_dict.pop(key, None)
|
||||
#
|
||||
# def load_state_dict(original, module, state_dict, strict=True):
|
||||
# """torch makes a lot of copies of the dictionary with weights, so just deleting entries from state_dict does not help
|
||||
# because the same values are stored in multiple copies of the dict. The trick used here is to give torch a dict with
|
||||
# all weights on meta device, i.e. deleted, and then it doesn't matter how many copies torch makes.
|
||||
#
|
||||
# In _load_from_state_dict, the correct weight will be obtained from a single dict with the right weights (sd).
|
||||
#
|
||||
# The dangerous thing about this is that if _load_from_state_dict is not called (because some exotic module overloads
# the function and does not call the original), the state dict will just fail to load because the weights
# would be on the meta device.
|
||||
# """
|
||||
#
|
||||
# if state_dict is sd:
|
||||
# state_dict = {k: v.to(device="meta", dtype=v.dtype) for k, v in state_dict.items()}
|
||||
#
|
||||
# original(module, state_dict, strict=strict)
|
||||
#
|
||||
# module_load_state_dict = self.replace(torch.nn.Module, 'load_state_dict', lambda *args, **kwargs: load_state_dict(module_load_state_dict, *args, **kwargs))
|
||||
# module_load_from_state_dict = self.replace(torch.nn.Module, '_load_from_state_dict', lambda *args, **kwargs: load_from_state_dict(module_load_from_state_dict, *args, **kwargs))
|
||||
# linear_load_from_state_dict = self.replace(torch.nn.Linear, '_load_from_state_dict', lambda *args, **kwargs: load_from_state_dict(linear_load_from_state_dict, *args, **kwargs))
|
||||
# conv2d_load_from_state_dict = self.replace(torch.nn.Conv2d, '_load_from_state_dict', lambda *args, **kwargs: load_from_state_dict(conv2d_load_from_state_dict, *args, **kwargs))
|
||||
# mha_load_from_state_dict = self.replace(torch.nn.MultiheadAttention, '_load_from_state_dict', lambda *args, **kwargs: load_from_state_dict(mha_load_from_state_dict, *args, **kwargs))
|
||||
# layer_norm_load_from_state_dict = self.replace(torch.nn.LayerNorm, '_load_from_state_dict', lambda *args, **kwargs: load_from_state_dict(layer_norm_load_from_state_dict, *args, **kwargs))
|
||||
# group_norm_load_from_state_dict = self.replace(torch.nn.GroupNorm, '_load_from_state_dict', lambda *args, **kwargs: load_from_state_dict(group_norm_load_from_state_dict, *args, **kwargs))
|
||||
#
|
||||
# def __exit__(self, exc_type, exc_val, exc_tb):
|
||||
# self.restore()
|
||||
70
modules/sd_emphasis.py
Executable file
@@ -0,0 +1,70 @@
from __future__ import annotations
import torch


class Emphasis:
    """Emphasis class decides how to deal with (emphasized:1.1) text in prompts"""

    name: str = "Base"
    description: str = ""

    tokens: list[list[int]]
    """tokens from the chunk of the prompt"""

    multipliers: torch.Tensor
    """tensor with multipliers, one for each token"""

    z: torch.Tensor
    """output of cond transformers network (CLIP)"""

    def after_transformers(self):
        """Called after cond transformers network has processed the chunk of the prompt; this function should modify self.z to apply the emphasis"""

        pass


class EmphasisNone(Emphasis):
    name = "None"
    description = "disable the mechanism entirely and treat (:1.1) as literal characters"


class EmphasisIgnore(Emphasis):
    name = "Ignore"
    description = "treat all emphasised words as if they have no emphasis"


class EmphasisOriginal(Emphasis):
    name = "Original"
    description = "the original emphasis implementation"

    def after_transformers(self):
        original_mean = self.z.mean()
        self.z = self.z * self.multipliers.reshape(self.multipliers.shape + (1,)).expand(self.z.shape)

        # restoring original mean is likely not correct, but it seems to work well to prevent artifacts that happen otherwise
        new_mean = self.z.mean()
        self.z = self.z * (original_mean / new_mean)


class EmphasisOriginalNoNorm(EmphasisOriginal):
    name = "No norm"
    description = "same as original, but without normalization (seems to work better for SDXL)"

    def after_transformers(self):
        self.z = self.z * self.multipliers.reshape(self.multipliers.shape + (1,)).expand(self.z.shape)


def get_current_option(emphasis_option_name):
    return next(iter([x for x in options if x.name == emphasis_option_name]), EmphasisOriginal)


def get_options_descriptions():
    return ", ".join(f"{x.name}: {x.description}" for x in options)


options = [
    EmphasisNone,
    EmphasisIgnore,
    EmphasisOriginal,
    EmphasisOriginalNoNorm,
]
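
# Illustrative sketch (not part of the original file): how the classes above are driven by
# the text encoder (see process_tokens in sd_hijack_clip). The shapes are assumptions chosen
# for demonstration: one chunk of 77 tokens with 768-dimensional CLIP outputs.
#
# import torch
#
# emphasis = get_current_option("Original")()
# emphasis.tokens = [[49406] + [320] * 75 + [49407]]          # token ids for one chunk
# emphasis.multipliers = torch.ones(1, 77)                    # (word:1.1) would put 1.1 here
# emphasis.multipliers[0, 1:4] = 1.1
# emphasis.z = torch.randn(1, 77, 768)                        # CLIP hidden states
#
# emphasis.after_transformers()                               # scales z, then restores its mean
# print(emphasis.z.shape)                                     # torch.Size([1, 77, 768])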
|
||||
259
modules/sd_hijack.py
Executable file
@@ -0,0 +1,259 @@
class StableDiffusionModelHijack:

    def apply_optimizations(self, option=None):
        pass

    def convert_sdxl_to_ssd(self, m):
        pass

    def hijack(self, m):
        pass

    def undo_hijack(self, m):
        pass

    def apply_circular(self, enable):
        pass

    def clear_comments(self):
        pass

    def get_prompt_lengths(self, text, cond_stage_model):
        from modules import shared
        return shared.sd_model.get_prompt_lengths_on_ui(text)

    def redo_hijack(self, m):
        pass


model_hijack = StableDiffusionModelHijack()

# import torch
|
||||
# from torch.nn.functional import silu
|
||||
# from types import MethodType
|
||||
#
|
||||
# from modules import devices, sd_hijack_optimizations, shared, script_callbacks, errors, sd_unet, patches
|
||||
# from modules.hypernetworks import hypernetwork
|
||||
# from modules.shared import cmd_opts
|
||||
# from modules import sd_hijack_clip, sd_hijack_open_clip, sd_hijack_unet, sd_hijack_xlmr, xlmr, xlmr_m18
|
||||
#
|
||||
# import ldm.modules.attention
|
||||
# import ldm.modules.diffusionmodules.model
|
||||
# import ldm.modules.diffusionmodules.openaimodel
|
||||
# import ldm.models.diffusion.ddpm
|
||||
# import ldm.models.diffusion.ddim
|
||||
# import ldm.models.diffusion.plms
|
||||
# import ldm.modules.encoders.modules
|
||||
#
|
||||
# import sgm.modules.attention
|
||||
# import sgm.modules.diffusionmodules.model
|
||||
# import sgm.modules.diffusionmodules.openaimodel
|
||||
# import sgm.modules.encoders.modules
|
||||
#
|
||||
# attention_CrossAttention_forward = ldm.modules.attention.CrossAttention.forward
|
||||
# diffusionmodules_model_nonlinearity = ldm.modules.diffusionmodules.model.nonlinearity
|
||||
# diffusionmodules_model_AttnBlock_forward = ldm.modules.diffusionmodules.model.AttnBlock.forward
|
||||
#
|
||||
# # new memory efficient cross attention blocks do not support hypernets and we already
|
||||
# # have memory efficient cross attention anyway, so this disables SD2.0's memory efficient cross attention
|
||||
# ldm.modules.attention.MemoryEfficientCrossAttention = ldm.modules.attention.CrossAttention
|
||||
# ldm.modules.attention.BasicTransformerBlock.ATTENTION_MODES["softmax-xformers"] = ldm.modules.attention.CrossAttention
|
||||
#
|
||||
# # silence new console spam from SD2
|
||||
# ldm.modules.attention.print = shared.ldm_print
|
||||
# ldm.modules.diffusionmodules.model.print = shared.ldm_print
|
||||
# ldm.util.print = shared.ldm_print
|
||||
# ldm.models.diffusion.ddpm.print = shared.ldm_print
|
||||
#
|
||||
# optimizers = []
|
||||
# current_optimizer: sd_hijack_optimizations.SdOptimization = None
|
||||
#
|
||||
# ldm_patched_forward = sd_unet.create_unet_forward(ldm.modules.diffusionmodules.openaimodel.UNetModel.forward)
|
||||
# ldm_original_forward = patches.patch(__file__, ldm.modules.diffusionmodules.openaimodel.UNetModel, "forward", ldm_patched_forward)
|
||||
#
|
||||
# sgm_patched_forward = sd_unet.create_unet_forward(sgm.modules.diffusionmodules.openaimodel.UNetModel.forward)
|
||||
# sgm_original_forward = patches.patch(__file__, sgm.modules.diffusionmodules.openaimodel.UNetModel, "forward", sgm_patched_forward)
|
||||
#
|
||||
#
|
||||
# def list_optimizers():
|
||||
# new_optimizers = script_callbacks.list_optimizers_callback()
|
||||
#
|
||||
# new_optimizers = [x for x in new_optimizers if x.is_available()]
|
||||
#
|
||||
# new_optimizers = sorted(new_optimizers, key=lambda x: x.priority, reverse=True)
|
||||
#
|
||||
# optimizers.clear()
|
||||
# optimizers.extend(new_optimizers)
|
||||
#
|
||||
#
|
||||
# def apply_optimizations(option=None):
|
||||
# return
|
||||
#
|
||||
#
|
||||
# def undo_optimizations():
|
||||
# return
|
||||
#
|
||||
#
|
||||
# def fix_checkpoint():
|
||||
# """checkpoints are now added and removed in embedding/hypernet code, since torch doesn't want
|
||||
# checkpoints to be added when not training (there's a warning)"""
|
||||
#
|
||||
# pass
|
||||
#
|
||||
#
|
||||
# def weighted_loss(sd_model, pred, target, mean=True):
|
||||
# #Calculate the weight normally, but ignore the mean
|
||||
# loss = sd_model._old_get_loss(pred, target, mean=False)
|
||||
#
|
||||
# #Check if we have weights available
|
||||
# weight = getattr(sd_model, '_custom_loss_weight', None)
|
||||
# if weight is not None:
|
||||
# loss *= weight
|
||||
#
|
||||
# #Return the loss, as mean if specified
|
||||
# return loss.mean() if mean else loss
|
||||
#
|
||||
# def weighted_forward(sd_model, x, c, w, *args, **kwargs):
|
||||
# try:
|
||||
# #Temporarily append weights to a place accessible during loss calc
|
||||
# sd_model._custom_loss_weight = w
|
||||
#
|
||||
# #Replace 'get_loss' with a weight-aware one. Otherwise we need to reimplement 'forward' completely
|
||||
# #Keep 'get_loss', but don't overwrite the previous old_get_loss if it's already set
|
||||
# if not hasattr(sd_model, '_old_get_loss'):
|
||||
# sd_model._old_get_loss = sd_model.get_loss
|
||||
# sd_model.get_loss = MethodType(weighted_loss, sd_model)
|
||||
#
|
||||
# #Run the standard forward function, but with the patched 'get_loss'
|
||||
# return sd_model.forward(x, c, *args, **kwargs)
|
||||
# finally:
|
||||
# try:
|
||||
# #Delete temporary weights if appended
|
||||
# del sd_model._custom_loss_weight
|
||||
# except AttributeError:
|
||||
# pass
|
||||
#
|
||||
# #If we have an old loss function, reset the loss function to the original one
|
||||
# if hasattr(sd_model, '_old_get_loss'):
|
||||
# sd_model.get_loss = sd_model._old_get_loss
|
||||
# del sd_model._old_get_loss
|
||||
#
|
||||
# def apply_weighted_forward(sd_model):
|
||||
# #Add new function 'weighted_forward' that can be called to calc weighted loss
|
||||
# sd_model.weighted_forward = MethodType(weighted_forward, sd_model)
|
||||
#
|
||||
# def undo_weighted_forward(sd_model):
|
||||
# try:
|
||||
# del sd_model.weighted_forward
|
||||
# except AttributeError:
|
||||
# pass
|
||||
#
|
||||
#
|
||||
# class StableDiffusionModelHijack:
|
||||
# fixes = None
|
||||
# layers = None
|
||||
# circular_enabled = False
|
||||
# clip = None
|
||||
# optimization_method = None
|
||||
#
|
||||
# def __init__(self):
|
||||
# self.extra_generation_params = {}
|
||||
# self.comments = []
|
||||
#
|
||||
# def apply_optimizations(self, option=None):
|
||||
# pass
|
||||
#
|
||||
# def convert_sdxl_to_ssd(self, m):
|
||||
# pass
|
||||
#
|
||||
# def hijack(self, m):
|
||||
# pass
|
||||
#
|
||||
# def undo_hijack(self, m):
|
||||
# pass
|
||||
#
|
||||
# def apply_circular(self, enable):
|
||||
# pass
|
||||
#
|
||||
# def clear_comments(self):
|
||||
# self.comments = []
|
||||
# self.extra_generation_params = {}
|
||||
#
|
||||
# def get_prompt_lengths(self, text, cond_stage_model):
|
||||
# pass
|
||||
#
|
||||
# def redo_hijack(self, m):
|
||||
# pass
|
||||
#
|
||||
#
|
||||
# class EmbeddingsWithFixes(torch.nn.Module):
|
||||
# def __init__(self, wrapped, embeddings, textual_inversion_key='clip_l'):
|
||||
# super().__init__()
|
||||
# self.wrapped = wrapped
|
||||
# self.embeddings = embeddings
|
||||
# self.textual_inversion_key = textual_inversion_key
|
||||
# self.weight = self.wrapped.weight
|
||||
#
|
||||
# def forward(self, input_ids):
|
||||
# batch_fixes = self.embeddings.fixes
|
||||
# self.embeddings.fixes = None
|
||||
#
|
||||
# inputs_embeds = self.wrapped(input_ids)
|
||||
#
|
||||
# if batch_fixes is None or len(batch_fixes) == 0 or max([len(x) for x in batch_fixes]) == 0:
|
||||
# return inputs_embeds
|
||||
#
|
||||
# vecs = []
|
||||
# for fixes, tensor in zip(batch_fixes, inputs_embeds):
|
||||
# for offset, embedding in fixes:
|
||||
# vec = embedding.vec[self.textual_inversion_key] if isinstance(embedding.vec, dict) else embedding.vec
|
||||
# emb = devices.cond_cast_unet(vec)
|
||||
# emb_len = min(tensor.shape[0] - offset - 1, emb.shape[0])
|
||||
# tensor = torch.cat([tensor[0:offset + 1], emb[0:emb_len], tensor[offset + 1 + emb_len:]]).to(dtype=inputs_embeds.dtype)
|
||||
#
|
||||
# vecs.append(tensor)
|
||||
#
|
||||
# return torch.stack(vecs)
|
||||
#
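# Illustrative sketch (not part of the original file): the splice performed in
# EmbeddingsWithFixes.forward for a single textual-inversion fix. Sizes are assumptions for
# demonstration (77 token positions, 768-dim embeddings, a 4-vector embedding at offset 5).
#
# import torch
#
# tensor = torch.zeros(77, 768)        # embeddings produced by the wrapped token embedding
# emb = torch.ones(4, 768)             # vectors of one textual inversion embedding
# offset = 5                           # position of the embedding's placeholder token
#
# emb_len = min(tensor.shape[0] - offset - 1, emb.shape[0])
# tensor = torch.cat([tensor[0:offset + 1], emb[0:emb_len], tensor[offset + 1 + emb_len:]])
# print(tensor.shape)                  # still torch.Size([77, 768]); rows 6..9 were replaced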
|
||||
#
|
||||
# class TextualInversionEmbeddings(torch.nn.Embedding):
|
||||
# def __init__(self, num_embeddings: int, embedding_dim: int, textual_inversion_key='clip_l', **kwargs):
|
||||
# super().__init__(num_embeddings, embedding_dim, **kwargs)
|
||||
#
|
||||
# self.embeddings = model_hijack
|
||||
# self.textual_inversion_key = textual_inversion_key
|
||||
#
|
||||
# @property
|
||||
# def wrapped(self):
|
||||
# return super().forward
|
||||
#
|
||||
# def forward(self, input_ids):
|
||||
# return EmbeddingsWithFixes.forward(self, input_ids)
|
||||
#
|
||||
#
|
||||
# def add_circular_option_to_conv_2d():
|
||||
# conv2d_constructor = torch.nn.Conv2d.__init__
|
||||
#
|
||||
# def conv2d_constructor_circular(self, *args, **kwargs):
|
||||
# return conv2d_constructor(self, *args, padding_mode='circular', **kwargs)
|
||||
#
|
||||
# torch.nn.Conv2d.__init__ = conv2d_constructor_circular
|
||||
#
|
||||
#
|
||||
#
|
||||
#
|
||||
#
|
||||
# def register_buffer(self, name, attr):
|
||||
# """
|
||||
# Fix register buffer bug for Mac OS.
|
||||
# """
|
||||
#
|
||||
# if type(attr) == torch.Tensor:
|
||||
# if attr.device != devices.device:
|
||||
# attr = attr.to(device=devices.device, dtype=(torch.float32 if devices.device.type == 'mps' else None))
|
||||
#
|
||||
# setattr(self, name, attr)
|
||||
#
|
||||
#
|
||||
# ldm.models.diffusion.ddim.DDIMSampler.register_buffer = register_buffer
|
||||
# ldm.models.diffusion.plms.PLMSSampler.register_buffer = register_buffer
|
||||
46
modules/sd_hijack_checkpoint.py
Executable file
@@ -0,0 +1,46 @@
|
||||
# from torch.utils.checkpoint import checkpoint
|
||||
#
|
||||
# import ldm.modules.attention
|
||||
# import ldm.modules.diffusionmodules.openaimodel
|
||||
#
|
||||
#
|
||||
# def BasicTransformerBlock_forward(self, x, context=None):
|
||||
# return checkpoint(self._forward, x, context)
|
||||
#
|
||||
#
|
||||
# def AttentionBlock_forward(self, x):
|
||||
# return checkpoint(self._forward, x)
|
||||
#
|
||||
#
|
||||
# def ResBlock_forward(self, x, emb):
|
||||
# return checkpoint(self._forward, x, emb)
|
||||
#
|
||||
#
|
||||
# stored = []
|
||||
#
|
||||
#
|
||||
# def add():
|
||||
# if len(stored) != 0:
|
||||
# return
|
||||
#
|
||||
# stored.extend([
|
||||
# ldm.modules.attention.BasicTransformerBlock.forward,
|
||||
# ldm.modules.diffusionmodules.openaimodel.ResBlock.forward,
|
||||
# ldm.modules.diffusionmodules.openaimodel.AttentionBlock.forward
|
||||
# ])
|
||||
#
|
||||
# ldm.modules.attention.BasicTransformerBlock.forward = BasicTransformerBlock_forward
|
||||
# ldm.modules.diffusionmodules.openaimodel.ResBlock.forward = ResBlock_forward
|
||||
# ldm.modules.diffusionmodules.openaimodel.AttentionBlock.forward = AttentionBlock_forward
|
||||
#
|
||||
#
|
||||
# def remove():
|
||||
# if len(stored) == 0:
|
||||
# return
|
||||
#
|
||||
# ldm.modules.attention.BasicTransformerBlock.forward = stored[0]
|
||||
# ldm.modules.diffusionmodules.openaimodel.ResBlock.forward = stored[1]
|
||||
# ldm.modules.diffusionmodules.openaimodel.AttentionBlock.forward = stored[2]
|
||||
#
|
||||
# stored.clear()
|
||||
#
|
||||
384
modules/sd_hijack_clip.py
Executable file
@@ -0,0 +1,384 @@
|
||||
# import math
|
||||
# from collections import namedtuple
|
||||
#
|
||||
# import torch
|
||||
#
|
||||
# from modules import prompt_parser, devices, sd_hijack, sd_emphasis
|
||||
# from modules.shared import opts
|
||||
#
|
||||
#
|
||||
# class PromptChunk:
|
||||
# """
|
||||
# This object contains token ids, weight (multipliers:1.4) and textual inversion embedding info for a chunk of prompt.
|
||||
# If a prompt is short, it is represented by one PromptChunk, otherwise, multiple are necessary.
|
||||
# Each PromptChunk contains an exact amount of tokens - 77, which includes one for start and end token,
|
||||
# so just 75 tokens from prompt.
|
||||
# """
|
||||
#
|
||||
# def __init__(self):
|
||||
# self.tokens = []
|
||||
# self.multipliers = []
|
||||
# self.fixes = []
|
||||
#
|
||||
#
|
||||
# PromptChunkFix = namedtuple('PromptChunkFix', ['offset', 'embedding'])
|
||||
# """An object of this type is a marker showing that textual inversion embedding's vectors have to placed at offset in the prompt
|
||||
# chunk. Those objects are found in PromptChunk.fixes and, are placed into FrozenCLIPEmbedderWithCustomWordsBase.hijack.fixes, and finally
|
||||
# are applied by sd_hijack.EmbeddingsWithFixes's forward function."""
|
||||
#
|
||||
#
|
||||
# class TextConditionalModel(torch.nn.Module):
|
||||
# def __init__(self):
|
||||
# super().__init__()
|
||||
#
|
||||
# self.hijack = sd_hijack.model_hijack
|
||||
# self.chunk_length = 75
|
||||
#
|
||||
# self.is_trainable = False
|
||||
# self.input_key = 'txt'
|
||||
# self.return_pooled = False
|
||||
#
|
||||
# self.comma_token = None
|
||||
# self.id_start = None
|
||||
# self.id_end = None
|
||||
# self.id_pad = None
|
||||
#
|
||||
# def empty_chunk(self):
|
||||
# """creates an empty PromptChunk and returns it"""
|
||||
#
|
||||
# chunk = PromptChunk()
|
||||
# chunk.tokens = [self.id_start] + [self.id_end] * (self.chunk_length + 1)
|
||||
# chunk.multipliers = [1.0] * (self.chunk_length + 2)
|
||||
# return chunk
|
||||
#
|
||||
# def get_target_prompt_token_count(self, token_count):
|
||||
# """returns the maximum number of tokens a prompt of a known length can have before it requires one more PromptChunk to be represented"""
|
||||
#
|
||||
# return math.ceil(max(token_count, 1) / self.chunk_length) * self.chunk_length
|
||||
#
|
||||
# def tokenize(self, texts):
|
||||
# """Converts a batch of texts into a batch of token ids"""
|
||||
#
|
||||
# raise NotImplementedError
|
||||
#
|
||||
# def encode_with_transformers(self, tokens):
|
||||
# """
|
||||
# converts a batch of token ids (in python lists) into a single tensor with numeric representation of those tokens;
|
||||
# All python lists with tokens are assumed to have same length, usually 77.
|
||||
# if input is a list with B elements and each element has T tokens, expected output shape is (B, T, C), where C depends on
|
||||
# model - can be 768 and 1024.
|
||||
# Among other things, this call will read self.hijack.fixes, apply it to its inputs, and clear it (setting it to None).
|
||||
# """
|
||||
#
|
||||
# raise NotImplementedError
|
||||
#
|
||||
# def encode_embedding_init_text(self, init_text, nvpt):
|
||||
# """Converts text into a tensor with this text's tokens' embeddings. Note that those are embeddings before they are passed through
|
||||
# transformers. nvpt is used as a maximum length in tokens. If the text produces fewer tokens than nvpt, only that many are returned."""
|
||||
#
|
||||
# raise NotImplementedError
|
||||
#
|
||||
# def tokenize_line(self, line):
|
||||
# """
|
||||
# this transforms a single prompt into a list of PromptChunk objects - as many as needed to
|
||||
# represent the prompt.
|
||||
# Returns the list and the total number of tokens in the prompt.
|
||||
# """
|
||||
#
|
||||
# if opts.emphasis != "None":
|
||||
# parsed = prompt_parser.parse_prompt_attention(line)
|
||||
# else:
|
||||
# parsed = [[line, 1.0]]
|
||||
#
|
||||
# tokenized = self.tokenize([text for text, _ in parsed])
|
||||
#
|
||||
# chunks = []
|
||||
# chunk = PromptChunk()
|
||||
# token_count = 0
|
||||
# last_comma = -1
|
||||
#
|
||||
# def next_chunk(is_last=False):
|
||||
# """puts current chunk into the list of results and produces the next one - empty;
|
||||
# if is_last is true, the <end-of-text> tokens at the end won't add to token_count"""
|
||||
# nonlocal token_count
|
||||
# nonlocal last_comma
|
||||
# nonlocal chunk
|
||||
#
|
||||
# if is_last:
|
||||
# token_count += len(chunk.tokens)
|
||||
# else:
|
||||
# token_count += self.chunk_length
|
||||
#
|
||||
# to_add = self.chunk_length - len(chunk.tokens)
|
||||
# if to_add > 0:
|
||||
# chunk.tokens += [self.id_end] * to_add
|
||||
# chunk.multipliers += [1.0] * to_add
|
||||
#
|
||||
# chunk.tokens = [self.id_start] + chunk.tokens + [self.id_end]
|
||||
# chunk.multipliers = [1.0] + chunk.multipliers + [1.0]
|
||||
#
|
||||
# last_comma = -1
|
||||
# chunks.append(chunk)
|
||||
# chunk = PromptChunk()
|
||||
#
|
||||
# for tokens, (text, weight) in zip(tokenized, parsed):
|
||||
# if text == 'BREAK' and weight == -1:
|
||||
# next_chunk()
|
||||
# continue
|
||||
#
|
||||
# position = 0
|
||||
# while position < len(tokens):
|
||||
# token = tokens[position]
|
||||
#
|
||||
# if token == self.comma_token:
|
||||
# last_comma = len(chunk.tokens)
|
||||
#
|
||||
# # this is when we are at the end of allotted 75 tokens for the current chunk, and the current token is not a comma. opts.comma_padding_backtrack
|
||||
# # is a setting that specifies that if there is a comma nearby, the text after the comma should be moved out of this chunk and into the next.
|
||||
# elif opts.comma_padding_backtrack != 0 and len(chunk.tokens) == self.chunk_length and last_comma != -1 and len(chunk.tokens) - last_comma <= opts.comma_padding_backtrack:
|
||||
# break_location = last_comma + 1
|
||||
#
|
||||
# reloc_tokens = chunk.tokens[break_location:]
|
||||
# reloc_mults = chunk.multipliers[break_location:]
|
||||
#
|
||||
# chunk.tokens = chunk.tokens[:break_location]
|
||||
# chunk.multipliers = chunk.multipliers[:break_location]
|
||||
#
|
||||
# next_chunk()
|
||||
# chunk.tokens = reloc_tokens
|
||||
# chunk.multipliers = reloc_mults
|
||||
#
|
||||
# if len(chunk.tokens) == self.chunk_length:
|
||||
# next_chunk()
|
||||
#
|
||||
# embedding, embedding_length_in_tokens = self.hijack.embedding_db.find_embedding_at_position(tokens, position)
|
||||
# if embedding is None:
|
||||
# chunk.tokens.append(token)
|
||||
# chunk.multipliers.append(weight)
|
||||
# position += 1
|
||||
# continue
|
||||
#
|
||||
# emb_len = int(embedding.vectors)
|
||||
# if len(chunk.tokens) + emb_len > self.chunk_length:
|
||||
# next_chunk()
|
||||
#
|
||||
# chunk.fixes.append(PromptChunkFix(len(chunk.tokens), embedding))
|
||||
#
|
||||
# chunk.tokens += [0] * emb_len
|
||||
# chunk.multipliers += [weight] * emb_len
|
||||
# position += embedding_length_in_tokens
|
||||
#
|
||||
# if chunk.tokens or not chunks:
|
||||
# next_chunk(is_last=True)
|
||||
#
|
||||
# return chunks, token_count
|
||||
#
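# Illustrative sketch (not part of the original file): the padding and bracketing that
# tokenize_line applies to every chunk, shown without the comma-backtracking and embedding
# handling. The id_start/id_end values are the CLIP defaults assumed for demonstration.
#
# id_start, id_end, chunk_length = 49406, 49407, 75
# prompt_tokens = list(range(100))                      # pretend tokenizer output, 100 tokens
#
# chunks = []
# for i in range(0, len(prompt_tokens), chunk_length):
#     tokens = prompt_tokens[i:i + chunk_length]
#     tokens += [id_end] * (chunk_length - len(tokens))             # pad short chunks
#     chunks.append([id_start] + tokens + [id_end])                 # always 77 ids per chunk
#
# print(len(chunks), [len(c) for c in chunks])          # 2 chunks of 77 tokens each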
|
||||
# def process_texts(self, texts):
|
||||
# """
|
||||
# Accepts a list of texts and calls tokenize_line() on each, with cache. Returns the list of results and maximum
|
||||
# length, in tokens, of all texts.
|
||||
# """
|
||||
#
|
||||
# token_count = 0
|
||||
#
|
||||
# cache = {}
|
||||
# batch_chunks = []
|
||||
# for line in texts:
|
||||
# if line in cache:
|
||||
# chunks = cache[line]
|
||||
# else:
|
||||
# chunks, current_token_count = self.tokenize_line(line)
|
||||
# token_count = max(current_token_count, token_count)
|
||||
#
|
||||
# cache[line] = chunks
|
||||
#
|
||||
# batch_chunks.append(chunks)
|
||||
#
|
||||
# return batch_chunks, token_count
|
||||
#
|
||||
# def forward(self, texts):
|
||||
# """
|
||||
# Accepts an array of texts; Passes texts through transformers network to create a tensor with numerical representation of those texts.
|
||||
# Returns a tensor with shape of (B, T, C), where B is length of the array; T is length, in tokens, of texts (including padding) - T will
|
||||
# be a multiple of 77; and C is dimensionality of each token - for SD1 it's 768, for SD2 it's 1024, and for SDXL it's 1280.
|
||||
# An example shape returned by this function can be: (2, 77, 768).
|
||||
# For SDXL, instead of returning one tensor as above, it returns a tuple with two: the other one with shape (B, 1280) holding pooled values.
|
||||
# Webui usually sends just one text at a time through this function - the only time when texts is an array with more than one element
|
||||
# is when you do prompt editing: "a picture of a [cat:dog:0.4] eating ice cream"
|
||||
# """
|
||||
#
|
||||
# batch_chunks, token_count = self.process_texts(texts)
|
||||
#
|
||||
# used_embeddings = {}
|
||||
# chunk_count = max([len(x) for x in batch_chunks])
|
||||
#
|
||||
# zs = []
|
||||
# for i in range(chunk_count):
|
||||
# batch_chunk = [chunks[i] if i < len(chunks) else self.empty_chunk() for chunks in batch_chunks]
|
||||
#
|
||||
# tokens = [x.tokens for x in batch_chunk]
|
||||
# multipliers = [x.multipliers for x in batch_chunk]
|
||||
# self.hijack.fixes = [x.fixes for x in batch_chunk]
|
||||
#
|
||||
# for fixes in self.hijack.fixes:
|
||||
# for _position, embedding in fixes:
|
||||
# used_embeddings[embedding.name] = embedding
|
||||
# devices.torch_npu_set_device()
|
||||
# z = self.process_tokens(tokens, multipliers)
|
||||
# zs.append(z)
|
||||
#
|
||||
# if opts.textual_inversion_add_hashes_to_infotext and used_embeddings:
|
||||
# hashes = []
|
||||
# for name, embedding in used_embeddings.items():
|
||||
# shorthash = embedding.shorthash
|
||||
# if not shorthash:
|
||||
# continue
|
||||
#
|
||||
# name = name.replace(":", "").replace(",", "")
|
||||
# hashes.append(f"{name}: {shorthash}")
|
||||
#
|
||||
# if hashes:
|
||||
# if self.hijack.extra_generation_params.get("TI hashes"):
|
||||
# hashes.append(self.hijack.extra_generation_params.get("TI hashes"))
|
||||
# self.hijack.extra_generation_params["TI hashes"] = ", ".join(hashes)
|
||||
#
|
||||
# if any(x for x in texts if "(" in x or "[" in x) and opts.emphasis != "Original":
|
||||
# self.hijack.extra_generation_params["Emphasis"] = opts.emphasis
|
||||
#
|
||||
# if self.return_pooled:
|
||||
# return torch.hstack(zs), zs[0].pooled
|
||||
# else:
|
||||
# return torch.hstack(zs)
|
||||
#
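# Illustrative sketch (not part of the original file): why forward() returns tensors whose
# token dimension is a multiple of 77. Two chunks per prompt and a 768-dim SD1-style encoder
# are assumptions for demonstration.
#
# import torch
#
# zs = [torch.randn(2, 77, 768), torch.randn(2, 77, 768)]   # one entry per chunk, batch of 2
# z = torch.hstack(zs)                                      # concatenates along dim 1
# print(z.shape)                                            # torch.Size([2, 154, 768])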
|
||||
# def process_tokens(self, remade_batch_tokens, batch_multipliers):
|
||||
# """
|
||||
# sends one single prompt chunk to be encoded by transformers neural network.
|
||||
# remade_batch_tokens is a batch of tokens - a list, where every element is a list of tokens; usually
|
||||
# there are exactly 77 tokens in the list. batch_multipliers is the same but for multipliers instead of tokens.
|
||||
# Multipliers are used to give more or less weight to the outputs of transformers network. Each multiplier
|
||||
# corresponds to one token.
|
||||
# """
|
||||
# tokens = torch.asarray(remade_batch_tokens).to(devices.device)
|
||||
#
|
||||
# # this is for SD2: SD1 uses the same token for padding and end of text, while SD2 uses different ones.
|
||||
# if self.id_end != self.id_pad:
|
||||
# for batch_pos in range(len(remade_batch_tokens)):
|
||||
# index = remade_batch_tokens[batch_pos].index(self.id_end)
|
||||
# tokens[batch_pos, index+1:tokens.shape[1]] = self.id_pad
|
||||
#
|
||||
# z = self.encode_with_transformers(tokens)
|
||||
#
|
||||
# pooled = getattr(z, 'pooled', None)
|
||||
#
|
||||
# emphasis = sd_emphasis.get_current_option(opts.emphasis)()
|
||||
# emphasis.tokens = remade_batch_tokens
|
||||
# emphasis.multipliers = torch.asarray(batch_multipliers).to(devices.device)
|
||||
# emphasis.z = z
|
||||
#
|
||||
# emphasis.after_transformers()
|
||||
#
|
||||
# z = emphasis.z
|
||||
#
|
||||
# if pooled is not None:
|
||||
# z.pooled = pooled
|
||||
#
|
||||
# return z
|
||||
#
|
||||
#
|
||||
# class FrozenCLIPEmbedderWithCustomWordsBase(TextConditionalModel):
|
||||
# """A pytorch module that is a wrapper for FrozenCLIPEmbedder module. it enhances FrozenCLIPEmbedder, making it possible to
|
||||
# have unlimited prompt length and assign weights to tokens in prompt.
|
||||
# """
|
||||
#
|
||||
# def __init__(self, wrapped, hijack):
|
||||
# super().__init__()
|
||||
#
|
||||
# self.hijack = hijack
|
||||
#
|
||||
# self.wrapped = wrapped
|
||||
# """Original FrozenCLIPEmbedder module; can also be FrozenOpenCLIPEmbedder or xlmr.BertSeriesModelWithTransformation,
|
||||
# depending on model."""
|
||||
#
|
||||
# self.is_trainable = getattr(wrapped, 'is_trainable', False)
|
||||
# self.input_key = getattr(wrapped, 'input_key', 'txt')
|
||||
# self.return_pooled = getattr(self.wrapped, 'return_pooled', False)
|
||||
#
|
||||
# self.legacy_ucg_val = None # for sgm codebase
|
||||
#
|
||||
# def forward(self, texts):
|
||||
# if opts.use_old_emphasis_implementation:
|
||||
# import modules.sd_hijack_clip_old
|
||||
# return modules.sd_hijack_clip_old.forward_old(self, texts)
|
||||
#
|
||||
# return super().forward(texts)
|
||||
#
|
||||
#
|
||||
# class FrozenCLIPEmbedderWithCustomWords(FrozenCLIPEmbedderWithCustomWordsBase):
|
||||
# def __init__(self, wrapped, hijack):
|
||||
# super().__init__(wrapped, hijack)
|
||||
# self.tokenizer = wrapped.tokenizer
|
||||
#
|
||||
# vocab = self.tokenizer.get_vocab()
|
||||
#
|
||||
# self.comma_token = vocab.get(',</w>', None)
|
||||
#
|
||||
# self.token_mults = {}
|
||||
# tokens_with_parens = [(k, v) for k, v in vocab.items() if '(' in k or ')' in k or '[' in k or ']' in k]
|
||||
# for text, ident in tokens_with_parens:
|
||||
# mult = 1.0
|
||||
# for c in text:
|
||||
# if c == '[':
|
||||
# mult /= 1.1
|
||||
# if c == ']':
|
||||
# mult *= 1.1
|
||||
# if c == '(':
|
||||
# mult *= 1.1
|
||||
# if c == ')':
|
||||
# mult /= 1.1
|
||||
#
|
||||
# if mult != 1.0:
|
||||
# self.token_mults[ident] = mult
|
||||
#
|
||||
# self.id_start = self.wrapped.tokenizer.bos_token_id
|
||||
# self.id_end = self.wrapped.tokenizer.eos_token_id
|
||||
# self.id_pad = self.id_end
|
||||
#
|
||||
# def tokenize(self, texts):
|
||||
# tokenized = self.wrapped.tokenizer(texts, truncation=False, add_special_tokens=False)["input_ids"]
|
||||
#
|
||||
# return tokenized
|
||||
#
|
||||
# def encode_with_transformers(self, tokens):
|
||||
# outputs = self.wrapped.transformer(input_ids=tokens, output_hidden_states=-opts.CLIP_stop_at_last_layers)
|
||||
#
|
||||
# if opts.CLIP_stop_at_last_layers > 1:
|
||||
# z = outputs.hidden_states[-opts.CLIP_stop_at_last_layers]
|
||||
# z = self.wrapped.transformer.text_model.final_layer_norm(z)
|
||||
# else:
|
||||
# z = outputs.last_hidden_state
|
||||
#
|
||||
# return z
|
||||
#
|
||||
# def encode_embedding_init_text(self, init_text, nvpt):
|
||||
# embedding_layer = self.wrapped.transformer.text_model.embeddings
|
||||
# ids = self.wrapped.tokenizer(init_text, max_length=nvpt, return_tensors="pt", add_special_tokens=False)["input_ids"]
|
||||
# embedded = embedding_layer.token_embedding.wrapped(ids.to(embedding_layer.token_embedding.wrapped.weight.device)).squeeze(0)
|
||||
#
|
||||
# return embedded
|
||||
#
|
||||
#
|
||||
# class FrozenCLIPEmbedderForSDXLWithCustomWords(FrozenCLIPEmbedderWithCustomWords):
|
||||
# def __init__(self, wrapped, hijack):
|
||||
# super().__init__(wrapped, hijack)
|
||||
#
|
||||
# def encode_with_transformers(self, tokens):
|
||||
# outputs = self.wrapped.transformer(input_ids=tokens, output_hidden_states=self.wrapped.layer == "hidden")
|
||||
#
|
||||
# if opts.sdxl_clip_l_skip is True:
|
||||
# z = outputs.hidden_states[-opts.CLIP_stop_at_last_layers]
|
||||
# elif self.wrapped.layer == "last":
|
||||
# z = outputs.last_hidden_state
|
||||
# else:
|
||||
# z = outputs.hidden_states[self.wrapped.layer_idx]
|
||||
#
|
||||
# return z
|
||||
82
modules/sd_hijack_clip_old.py
Executable file
@@ -0,0 +1,82 @@
|
||||
# from modules import sd_hijack_clip
|
||||
# from modules import shared
|
||||
#
|
||||
#
|
||||
# def process_text_old(self: sd_hijack_clip.FrozenCLIPEmbedderWithCustomWordsBase, texts):
|
||||
# id_start = self.id_start
|
||||
# id_end = self.id_end
|
||||
# maxlen = self.wrapped.max_length # you get to stay at 77
|
||||
# used_custom_terms = []
|
||||
# remade_batch_tokens = []
|
||||
# hijack_comments = []
|
||||
# hijack_fixes = []
|
||||
# token_count = 0
|
||||
#
|
||||
# cache = {}
|
||||
# batch_tokens = self.tokenize(texts)
|
||||
# batch_multipliers = []
|
||||
# for tokens in batch_tokens:
|
||||
# tuple_tokens = tuple(tokens)
|
||||
#
|
||||
# if tuple_tokens in cache:
|
||||
# remade_tokens, fixes, multipliers = cache[tuple_tokens]
|
||||
# else:
|
||||
# fixes = []
|
||||
# remade_tokens = []
|
||||
# multipliers = []
|
||||
# mult = 1.0
|
||||
#
|
||||
# i = 0
|
||||
# while i < len(tokens):
|
||||
# token = tokens[i]
|
||||
#
|
||||
# embedding, embedding_length_in_tokens = self.hijack.embedding_db.find_embedding_at_position(tokens, i)
|
||||
#
|
||||
# mult_change = self.token_mults.get(token) if shared.opts.emphasis != "None" else None
|
||||
# if mult_change is not None:
|
||||
# mult *= mult_change
|
||||
# i += 1
|
||||
# elif embedding is None:
|
||||
# remade_tokens.append(token)
|
||||
# multipliers.append(mult)
|
||||
# i += 1
|
||||
# else:
|
||||
# emb_len = int(embedding.vec.shape[0])
|
||||
# fixes.append((len(remade_tokens), embedding))
|
||||
# remade_tokens += [0] * emb_len
|
||||
# multipliers += [mult] * emb_len
|
||||
# used_custom_terms.append((embedding.name, embedding.checksum()))
|
||||
# i += embedding_length_in_tokens
|
||||
#
|
||||
# if len(remade_tokens) > maxlen - 2:
|
||||
# vocab = {v: k for k, v in self.wrapped.tokenizer.get_vocab().items()}
|
||||
# ovf = remade_tokens[maxlen - 2:]
|
||||
# overflowing_words = [vocab.get(int(x), "") for x in ovf]
|
||||
# overflowing_text = self.wrapped.tokenizer.convert_tokens_to_string(''.join(overflowing_words))
|
||||
# hijack_comments.append(f"Warning: too many input tokens; some ({len(overflowing_words)}) have been truncated:\n{overflowing_text}\n")
|
||||
#
|
||||
# token_count = len(remade_tokens)
|
||||
# remade_tokens = remade_tokens + [id_end] * (maxlen - 2 - len(remade_tokens))
|
||||
# remade_tokens = [id_start] + remade_tokens[0:maxlen - 2] + [id_end]
|
||||
# cache[tuple_tokens] = (remade_tokens, fixes, multipliers)
|
||||
#
|
||||
# multipliers = multipliers + [1.0] * (maxlen - 2 - len(multipliers))
|
||||
# multipliers = [1.0] + multipliers[0:maxlen - 2] + [1.0]
|
||||
#
|
||||
# remade_batch_tokens.append(remade_tokens)
|
||||
# hijack_fixes.append(fixes)
|
||||
# batch_multipliers.append(multipliers)
|
||||
# return batch_multipliers, remade_batch_tokens, used_custom_terms, hijack_comments, hijack_fixes, token_count
|
||||
#
|
||||
#
|
||||
# def forward_old(self: sd_hijack_clip.FrozenCLIPEmbedderWithCustomWordsBase, texts):
|
||||
# batch_multipliers, remade_batch_tokens, used_custom_terms, hijack_comments, hijack_fixes, token_count = process_text_old(self, texts)
|
||||
#
|
||||
# self.hijack.comments += hijack_comments
|
||||
#
|
||||
# if used_custom_terms:
|
||||
# embedding_names = ", ".join(f"{word} [{checksum}]" for word, checksum in used_custom_terms)
|
||||
# self.hijack.comments.append(f"Used embeddings: {embedding_names}")
|
||||
#
|
||||
# self.hijack.fixes = hijack_fixes
|
||||
# return self.process_tokens(remade_batch_tokens, batch_multipliers)
|
||||
10
modules/sd_hijack_ip2p.py
Executable file
@@ -0,0 +1,10 @@
|
||||
# import os.path
|
||||
#
|
||||
#
|
||||
# def should_hijack_ip2p(checkpoint_info):
|
||||
# from modules import sd_models_config
|
||||
#
|
||||
# ckpt_basename = os.path.basename(checkpoint_info.filename).lower()
|
||||
# cfg_basename = os.path.basename(sd_models_config.find_checkpoint_config_near_filename(checkpoint_info)).lower()
|
||||
#
|
||||
# return "pix2pix" in ckpt_basename and "pix2pix" not in cfg_basename
|
||||
71
modules/sd_hijack_open_clip.py
Executable file
@@ -0,0 +1,71 @@
|
||||
# import open_clip.tokenizer
|
||||
# import torch
|
||||
#
|
||||
# from modules import sd_hijack_clip, devices
|
||||
# from modules.shared import opts
|
||||
#
|
||||
# tokenizer = open_clip.tokenizer._tokenizer
|
||||
#
|
||||
#
|
||||
# class FrozenOpenCLIPEmbedderWithCustomWords(sd_hijack_clip.FrozenCLIPEmbedderWithCustomWordsBase):
|
||||
# def __init__(self, wrapped, hijack):
|
||||
# super().__init__(wrapped, hijack)
|
||||
#
|
||||
# self.comma_token = [v for k, v in tokenizer.encoder.items() if k == ',</w>'][0]
|
||||
# self.id_start = tokenizer.encoder["<start_of_text>"]
|
||||
# self.id_end = tokenizer.encoder["<end_of_text>"]
|
||||
# self.id_pad = 0
|
||||
#
|
||||
# def tokenize(self, texts):
|
||||
# assert not opts.use_old_emphasis_implementation, 'Old emphasis implementation not supported for Open Clip'
|
||||
#
|
||||
# tokenized = [tokenizer.encode(text) for text in texts]
|
||||
#
|
||||
# return tokenized
|
||||
#
|
||||
# def encode_with_transformers(self, tokens):
|
||||
# # set self.wrapped.layer_idx here according to opts.CLIP_stop_at_last_layers
|
||||
# z = self.wrapped.encode_with_transformer(tokens)
|
||||
#
|
||||
# return z
|
||||
#
|
||||
# def encode_embedding_init_text(self, init_text, nvpt):
|
||||
# ids = tokenizer.encode(init_text)
|
||||
# ids = torch.asarray([ids], device=devices.device, dtype=torch.int)
|
||||
# embedded = self.wrapped.model.token_embedding.wrapped(ids).squeeze(0)
|
||||
#
|
||||
# return embedded
|
||||
#
|
||||
#
|
||||
# class FrozenOpenCLIPEmbedder2WithCustomWords(sd_hijack_clip.FrozenCLIPEmbedderWithCustomWordsBase):
|
||||
# def __init__(self, wrapped, hijack):
|
||||
# super().__init__(wrapped, hijack)
|
||||
#
|
||||
# self.comma_token = [v for k, v in tokenizer.encoder.items() if k == ',</w>'][0]
|
||||
# self.id_start = tokenizer.encoder["<start_of_text>"]
|
||||
# self.id_end = tokenizer.encoder["<end_of_text>"]
|
||||
# self.id_pad = 0
|
||||
#
|
||||
# def tokenize(self, texts):
|
||||
# assert not opts.use_old_emphasis_implementation, 'Old emphasis implementation not supported for Open Clip'
|
||||
#
|
||||
# tokenized = [tokenizer.encode(text) for text in texts]
|
||||
#
|
||||
# return tokenized
|
||||
#
|
||||
# def encode_with_transformers(self, tokens):
|
||||
# d = self.wrapped.encode_with_transformer(tokens)
|
||||
# z = d[self.wrapped.layer]
|
||||
#
|
||||
# pooled = d.get("pooled")
|
||||
# if pooled is not None:
|
||||
# z.pooled = pooled
|
||||
#
|
||||
# return z
|
||||
#
|
||||
# def encode_embedding_init_text(self, init_text, nvpt):
|
||||
# ids = tokenizer.encode(init_text)
|
||||
# ids = torch.asarray([ids], device=devices.device, dtype=torch.int)
|
||||
# embedded = self.wrapped.model.token_embedding.wrapped(ids.to(self.wrapped.model.token_embedding.wrapped.weight.device)).squeeze(0)
|
||||
#
|
||||
# return embedded
|
||||
677
modules/sd_hijack_optimizations.py
Executable file
@@ -0,0 +1,677 @@
|
||||
# from __future__ import annotations
|
||||
# import math
|
||||
# import psutil
|
||||
# import platform
|
||||
#
|
||||
# import torch
|
||||
# from torch import einsum
|
||||
#
|
||||
# from ldm.util import default
|
||||
# from einops import rearrange
|
||||
#
|
||||
# from modules import shared, errors, devices, sub_quadratic_attention
|
||||
# from modules.hypernetworks import hypernetwork
|
||||
#
|
||||
# import ldm.modules.attention
|
||||
# import ldm.modules.diffusionmodules.model
|
||||
#
|
||||
# import sgm.modules.attention
|
||||
# import sgm.modules.diffusionmodules.model
|
||||
#
|
||||
# diffusionmodules_model_AttnBlock_forward = ldm.modules.diffusionmodules.model.AttnBlock.forward
|
||||
# sgm_diffusionmodules_model_AttnBlock_forward = sgm.modules.diffusionmodules.model.AttnBlock.forward
|
||||
#
|
||||
#
|
||||
# class SdOptimization:
|
||||
# name: str = None
|
||||
# label: str | None = None
|
||||
# cmd_opt: str | None = None
|
||||
# priority: int = 0
|
||||
#
|
||||
# def title(self):
|
||||
# if self.label is None:
|
||||
# return self.name
|
||||
#
|
||||
# return f"{self.name} - {self.label}"
|
||||
#
|
||||
# def is_available(self):
|
||||
# return True
|
||||
#
|
||||
# def apply(self):
|
||||
# pass
|
||||
#
|
||||
# def undo(self):
|
||||
# ldm.modules.attention.CrossAttention.forward = hypernetwork.attention_CrossAttention_forward
|
||||
# ldm.modules.diffusionmodules.model.AttnBlock.forward = diffusionmodules_model_AttnBlock_forward
|
||||
#
|
||||
# sgm.modules.attention.CrossAttention.forward = hypernetwork.attention_CrossAttention_forward
|
||||
# sgm.modules.diffusionmodules.model.AttnBlock.forward = sgm_diffusionmodules_model_AttnBlock_forward
|
||||
#
|
||||
#
|
||||
# class SdOptimizationXformers(SdOptimization):
|
||||
# name = "xformers"
|
||||
# cmd_opt = "xformers"
|
||||
# priority = 100
|
||||
#
|
||||
# def is_available(self):
|
||||
# return shared.cmd_opts.force_enable_xformers or (shared.xformers_available and torch.cuda.is_available() and (6, 0) <= torch.cuda.get_device_capability(shared.device) <= (9, 0))
|
||||
#
|
||||
# def apply(self):
|
||||
# ldm.modules.attention.CrossAttention.forward = xformers_attention_forward
|
||||
# ldm.modules.diffusionmodules.model.AttnBlock.forward = xformers_attnblock_forward
|
||||
# sgm.modules.attention.CrossAttention.forward = xformers_attention_forward
|
||||
# sgm.modules.diffusionmodules.model.AttnBlock.forward = xformers_attnblock_forward
|
||||
#
|
||||
#
|
||||
# class SdOptimizationSdpNoMem(SdOptimization):
|
||||
# name = "sdp-no-mem"
|
||||
# label = "scaled dot product without memory efficient attention"
|
||||
# cmd_opt = "opt_sdp_no_mem_attention"
|
||||
# priority = 80
|
||||
#
|
||||
# def is_available(self):
|
||||
# return hasattr(torch.nn.functional, "scaled_dot_product_attention") and callable(torch.nn.functional.scaled_dot_product_attention)
|
||||
#
|
||||
# def apply(self):
|
||||
# ldm.modules.attention.CrossAttention.forward = scaled_dot_product_no_mem_attention_forward
|
||||
# ldm.modules.diffusionmodules.model.AttnBlock.forward = sdp_no_mem_attnblock_forward
|
||||
# sgm.modules.attention.CrossAttention.forward = scaled_dot_product_no_mem_attention_forward
|
||||
# sgm.modules.diffusionmodules.model.AttnBlock.forward = sdp_no_mem_attnblock_forward
|
||||
#
|
||||
#
|
||||
# class SdOptimizationSdp(SdOptimizationSdpNoMem):
|
||||
# name = "sdp"
|
||||
# label = "scaled dot product"
|
||||
# cmd_opt = "opt_sdp_attention"
|
||||
# priority = 70
|
||||
#
|
||||
# def apply(self):
|
||||
# ldm.modules.attention.CrossAttention.forward = scaled_dot_product_attention_forward
|
||||
# ldm.modules.diffusionmodules.model.AttnBlock.forward = sdp_attnblock_forward
|
||||
# sgm.modules.attention.CrossAttention.forward = scaled_dot_product_attention_forward
|
||||
# sgm.modules.diffusionmodules.model.AttnBlock.forward = sdp_attnblock_forward
|
||||
#
|
||||
#
|
||||
# class SdOptimizationSubQuad(SdOptimization):
|
||||
# name = "sub-quadratic"
|
||||
# cmd_opt = "opt_sub_quad_attention"
|
||||
#
|
||||
# @property
|
||||
# def priority(self):
|
||||
# return 1000 if shared.device.type == 'mps' else 10
|
||||
#
|
||||
# def apply(self):
|
||||
# ldm.modules.attention.CrossAttention.forward = sub_quad_attention_forward
|
||||
# ldm.modules.diffusionmodules.model.AttnBlock.forward = sub_quad_attnblock_forward
|
||||
# sgm.modules.attention.CrossAttention.forward = sub_quad_attention_forward
|
||||
# sgm.modules.diffusionmodules.model.AttnBlock.forward = sub_quad_attnblock_forward
|
||||
#
|
||||
#
|
||||
# class SdOptimizationV1(SdOptimization):
|
||||
# name = "V1"
|
||||
# label = "original v1"
|
||||
# cmd_opt = "opt_split_attention_v1"
|
||||
# priority = 10
|
||||
#
|
||||
# def apply(self):
|
||||
# ldm.modules.attention.CrossAttention.forward = split_cross_attention_forward_v1
|
||||
# sgm.modules.attention.CrossAttention.forward = split_cross_attention_forward_v1
|
||||
#
|
||||
#
|
||||
# class SdOptimizationInvokeAI(SdOptimization):
|
||||
# name = "InvokeAI"
|
||||
# cmd_opt = "opt_split_attention_invokeai"
|
||||
#
|
||||
# @property
|
||||
# def priority(self):
|
||||
# return 1000 if shared.device.type != 'mps' and not torch.cuda.is_available() else 10
|
||||
#
|
||||
# def apply(self):
|
||||
# ldm.modules.attention.CrossAttention.forward = split_cross_attention_forward_invokeAI
|
||||
# sgm.modules.attention.CrossAttention.forward = split_cross_attention_forward_invokeAI
|
||||
#
|
||||
#
|
||||
# class SdOptimizationDoggettx(SdOptimization):
|
||||
# name = "Doggettx"
|
||||
# cmd_opt = "opt_split_attention"
|
||||
# priority = 90
|
||||
#
|
||||
# def apply(self):
|
||||
# ldm.modules.attention.CrossAttention.forward = split_cross_attention_forward
|
||||
# ldm.modules.diffusionmodules.model.AttnBlock.forward = cross_attention_attnblock_forward
|
||||
# sgm.modules.attention.CrossAttention.forward = split_cross_attention_forward
|
||||
# sgm.modules.diffusionmodules.model.AttnBlock.forward = cross_attention_attnblock_forward
|
||||
#
|
||||
#
|
||||
# def list_optimizers(res):
|
||||
# res.extend([
|
||||
# SdOptimizationXformers(),
|
||||
# SdOptimizationSdpNoMem(),
|
||||
# SdOptimizationSdp(),
|
||||
# SdOptimizationSubQuad(),
|
||||
# SdOptimizationV1(),
|
||||
# SdOptimizationInvokeAI(),
|
||||
# SdOptimizationDoggettx(),
|
||||
# ])
|
||||
#
|
||||
#
|
||||
# if shared.cmd_opts.xformers or shared.cmd_opts.force_enable_xformers:
|
||||
# try:
|
||||
# import xformers.ops
|
||||
# shared.xformers_available = True
|
||||
# except Exception:
|
||||
# errors.report("Cannot import xformers", exc_info=True)
|
||||
#
|
||||
#
|
||||
# def get_available_vram():
|
||||
# if shared.device.type == 'cuda':
|
||||
# stats = torch.cuda.memory_stats(shared.device)
|
||||
# mem_active = stats['active_bytes.all.current']
|
||||
# mem_reserved = stats['reserved_bytes.all.current']
|
||||
# mem_free_cuda, _ = torch.cuda.mem_get_info(torch.cuda.current_device())
|
||||
# mem_free_torch = mem_reserved - mem_active
|
||||
# mem_free_total = mem_free_cuda + mem_free_torch
|
||||
# return mem_free_total
|
||||
# else:
|
||||
# return psutil.virtual_memory().available
|
||||
#
|
||||
#
|
||||
# # see https://github.com/basujindal/stable-diffusion/pull/117 for discussion
|
||||
# def split_cross_attention_forward_v1(self, x, context=None, mask=None, **kwargs):
|
||||
# h = self.heads
|
||||
#
|
||||
# q_in = self.to_q(x)
|
||||
# context = default(context, x)
|
||||
#
|
||||
# context_k, context_v = hypernetwork.apply_hypernetworks(shared.loaded_hypernetworks, context)
|
||||
# k_in = self.to_k(context_k)
|
||||
# v_in = self.to_v(context_v)
|
||||
# del context, context_k, context_v, x
|
||||
#
|
||||
# q, k, v = (rearrange(t, 'b n (h d) -> (b h) n d', h=h) for t in (q_in, k_in, v_in))
|
||||
# del q_in, k_in, v_in
|
||||
#
|
||||
# dtype = q.dtype
|
||||
# if shared.opts.upcast_attn:
|
||||
# q, k, v = q.float(), k.float(), v.float()
|
||||
#
|
||||
# with devices.without_autocast(disable=not shared.opts.upcast_attn):
|
||||
# r1 = torch.zeros(q.shape[0], q.shape[1], v.shape[2], device=q.device, dtype=q.dtype)
|
||||
# for i in range(0, q.shape[0], 2):
|
||||
# end = i + 2
|
||||
# s1 = einsum('b i d, b j d -> b i j', q[i:end], k[i:end])
|
||||
# s1 *= self.scale
|
||||
#
|
||||
# s2 = s1.softmax(dim=-1)
|
||||
# del s1
|
||||
#
|
||||
# r1[i:end] = einsum('b i j, b j d -> b i d', s2, v[i:end])
|
||||
# del s2
|
||||
# del q, k, v
|
||||
#
|
||||
# r1 = r1.to(dtype)
|
||||
#
|
||||
# r2 = rearrange(r1, '(b h) n d -> b n (h d)', h=h)
|
||||
# del r1
|
||||
#
|
||||
# return self.to_out(r2)
|
||||
#
|
||||
#
|
||||
# # taken from https://github.com/Doggettx/stable-diffusion and modified
|
||||
# def split_cross_attention_forward(self, x, context=None, mask=None, **kwargs):
|
||||
# h = self.heads
|
||||
#
|
||||
# q_in = self.to_q(x)
|
||||
# context = default(context, x)
|
||||
#
|
||||
# context_k, context_v = hypernetwork.apply_hypernetworks(shared.loaded_hypernetworks, context)
|
||||
# k_in = self.to_k(context_k)
|
||||
# v_in = self.to_v(context_v)
|
||||
#
|
||||
# dtype = q_in.dtype
|
||||
# if shared.opts.upcast_attn:
|
||||
# q_in, k_in, v_in = q_in.float(), k_in.float(), v_in if v_in.device.type == 'mps' else v_in.float()
|
||||
#
|
||||
# with devices.without_autocast(disable=not shared.opts.upcast_attn):
|
||||
# k_in = k_in * self.scale
|
||||
#
|
||||
# del context, x
|
||||
#
|
||||
# q, k, v = (rearrange(t, 'b n (h d) -> (b h) n d', h=h) for t in (q_in, k_in, v_in))
|
||||
# del q_in, k_in, v_in
|
||||
#
|
||||
# r1 = torch.zeros(q.shape[0], q.shape[1], v.shape[2], device=q.device, dtype=q.dtype)
|
||||
#
|
||||
# mem_free_total = get_available_vram()
|
||||
#
|
||||
# gb = 1024 ** 3
|
||||
# tensor_size = q.shape[0] * q.shape[1] * k.shape[1] * q.element_size()
|
||||
# modifier = 3 if q.element_size() == 2 else 2.5
|
||||
# mem_required = tensor_size * modifier
|
||||
# steps = 1
|
||||
#
|
||||
# if mem_required > mem_free_total:
|
||||
# steps = 2 ** (math.ceil(math.log(mem_required / mem_free_total, 2)))
|
||||
# # print(f"Expected tensor size:{tensor_size/gb:0.1f}GB, cuda free:{mem_free_cuda/gb:0.1f}GB "
|
||||
# # f"torch free:{mem_free_torch/gb:0.1f} total:{mem_free_total/gb:0.1f} steps:{steps}")
|
||||
#
|
||||
# if steps > 64:
|
||||
# max_res = math.floor(math.sqrt(math.sqrt(mem_free_total / 2.5)) / 8) * 64
|
||||
# raise RuntimeError(f'Not enough memory, use lower resolution (max approx. {max_res}x{max_res}). '
|
||||
# f'Need: {mem_required / 64 / gb:0.1f}GB free, Have:{mem_free_total / gb:0.1f}GB free')
|
||||
#
|
||||
# slice_size = q.shape[1] // steps
|
||||
# for i in range(0, q.shape[1], slice_size):
|
||||
# end = min(i + slice_size, q.shape[1])
|
||||
# s1 = einsum('b i d, b j d -> b i j', q[:, i:end], k)
|
||||
#
|
||||
# s2 = s1.softmax(dim=-1, dtype=q.dtype)
|
||||
# del s1
|
||||
#
|
||||
# r1[:, i:end] = einsum('b i j, b j d -> b i d', s2, v)
|
||||
# del s2
|
||||
#
|
||||
# del q, k, v
|
||||
#
|
||||
# r1 = r1.to(dtype)
|
||||
#
|
||||
# r2 = rearrange(r1, '(b h) n d -> b n (h d)', h=h)
|
||||
# del r1
|
||||
#
|
||||
# return self.to_out(r2)
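A rough sketch of the memory-slicing heuristic used above (standalone illustration; the helper name and the example numbers are invented): the q @ k^T matrix needs about batch_x_heads * q_tokens * k_tokens * element_size bytes, a safety modifier is applied, and the number of slices is rounded up to the next power of two so each slice fits in free memory.

import math

def attention_slice_steps(batch_x_heads, q_tokens, k_tokens, element_size, mem_free_bytes):
    # mirrors the heuristic in split_cross_attention_forward above
    tensor_size = batch_x_heads * q_tokens * k_tokens * element_size
    modifier = 3 if element_size == 2 else 2.5
    mem_required = tensor_size * modifier
    if mem_required <= mem_free_bytes:
        return 1
    return 2 ** math.ceil(math.log2(mem_required / mem_free_bytes))

# fp16 attention, 16 batch*heads, 4096 query and key tokens, 1 GiB free -> 2 slices
print(attention_slice_steps(16, 4096, 4096, 2, 1 * 1024**3))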
|
||||
#
|
||||
#
|
||||
# # -- Taken from https://github.com/invoke-ai/InvokeAI and modified --
|
||||
# mem_total_gb = psutil.virtual_memory().total // (1 << 30)
|
||||
#
|
||||
#
|
||||
# def einsum_op_compvis(q, k, v):
|
||||
# s = einsum('b i d, b j d -> b i j', q, k)
|
||||
# s = s.softmax(dim=-1, dtype=s.dtype)
|
||||
# return einsum('b i j, b j d -> b i d', s, v)
|
||||
#
|
||||
#
|
||||
# def einsum_op_slice_0(q, k, v, slice_size):
|
||||
# r = torch.zeros(q.shape[0], q.shape[1], v.shape[2], device=q.device, dtype=q.dtype)
|
||||
# for i in range(0, q.shape[0], slice_size):
|
||||
# end = i + slice_size
|
||||
# r[i:end] = einsum_op_compvis(q[i:end], k[i:end], v[i:end])
|
||||
# return r
|
||||
#
|
||||
#
|
||||
# def einsum_op_slice_1(q, k, v, slice_size):
|
||||
# r = torch.zeros(q.shape[0], q.shape[1], v.shape[2], device=q.device, dtype=q.dtype)
|
||||
# for i in range(0, q.shape[1], slice_size):
|
||||
# end = i + slice_size
|
||||
# r[:, i:end] = einsum_op_compvis(q[:, i:end], k, v)
|
||||
# return r
|
||||
#
|
||||
#
|
||||
# def einsum_op_mps_v1(q, k, v):
|
||||
# if q.shape[0] * q.shape[1] <= 2**16: # (512x512) max q.shape[1]: 4096
|
||||
# return einsum_op_compvis(q, k, v)
|
||||
# else:
|
||||
# slice_size = math.floor(2**30 / (q.shape[0] * q.shape[1]))
|
||||
# if slice_size % 4096 == 0:
|
||||
# slice_size -= 1
|
||||
# return einsum_op_slice_1(q, k, v, slice_size)
|
||||
#
|
||||
#
|
||||
# def einsum_op_mps_v2(q, k, v):
|
||||
# if mem_total_gb > 8 and q.shape[0] * q.shape[1] <= 2**16:
|
||||
# return einsum_op_compvis(q, k, v)
|
||||
# else:
|
||||
# return einsum_op_slice_0(q, k, v, 1)
|
||||
#
|
||||
#
|
||||
# def einsum_op_tensor_mem(q, k, v, max_tensor_mb):
|
||||
# size_mb = q.shape[0] * q.shape[1] * k.shape[1] * q.element_size() // (1 << 20)
|
||||
# if size_mb <= max_tensor_mb:
|
||||
# return einsum_op_compvis(q, k, v)
|
||||
# div = 1 << int((size_mb - 1) / max_tensor_mb).bit_length()
|
||||
# if div <= q.shape[0]:
|
||||
# return einsum_op_slice_0(q, k, v, q.shape[0] // div)
|
||||
# return einsum_op_slice_1(q, k, v, max(q.shape[1] // div, 1))
|
||||
#
|
||||
#
|
||||
# def einsum_op_cuda(q, k, v):
|
||||
# stats = torch.cuda.memory_stats(q.device)
|
||||
# mem_active = stats['active_bytes.all.current']
|
||||
# mem_reserved = stats['reserved_bytes.all.current']
|
||||
# mem_free_cuda, _ = torch.cuda.mem_get_info(q.device)
|
||||
# mem_free_torch = mem_reserved - mem_active
|
||||
# mem_free_total = mem_free_cuda + mem_free_torch
|
||||
# # Divide factor of safety as there's copying and fragmentation
|
||||
# return einsum_op_tensor_mem(q, k, v, mem_free_total / 3.3 / (1 << 20))
|
||||
#
|
||||
#
|
||||
# def einsum_op(q, k, v):
|
||||
# if q.device.type == 'cuda':
|
||||
# return einsum_op_cuda(q, k, v)
|
||||
#
|
||||
# if q.device.type == 'mps':
|
||||
# if mem_total_gb >= 32 and q.shape[0] % 32 != 0 and q.shape[0] * q.shape[1] < 2**18:
|
||||
# return einsum_op_mps_v1(q, k, v)
|
||||
# return einsum_op_mps_v2(q, k, v)
|
||||
#
|
||||
# # Smaller slices are faster due to L2/L3/SLC caches.
|
||||
# # Tested on i7 with 8MB L3 cache.
|
||||
# return einsum_op_tensor_mem(q, k, v, 32)
|
||||
#
|
||||
#
|
||||
# def split_cross_attention_forward_invokeAI(self, x, context=None, mask=None, **kwargs):
|
||||
# h = self.heads
|
||||
#
|
||||
# q = self.to_q(x)
|
||||
# context = default(context, x)
|
||||
#
|
||||
# context_k, context_v = hypernetwork.apply_hypernetworks(shared.loaded_hypernetworks, context)
|
||||
# k = self.to_k(context_k)
|
||||
# v = self.to_v(context_v)
|
||||
# del context, context_k, context_v, x
|
||||
#
|
||||
# dtype = q.dtype
|
||||
# if shared.opts.upcast_attn:
|
||||
# q, k, v = q.float(), k.float(), v if v.device.type == 'mps' else v.float()
|
||||
#
|
||||
# with devices.without_autocast(disable=not shared.opts.upcast_attn):
|
||||
# k = k * self.scale
|
||||
#
|
||||
# q, k, v = (rearrange(t, 'b n (h d) -> (b h) n d', h=h) for t in (q, k, v))
|
||||
# r = einsum_op(q, k, v)
|
||||
# r = r.to(dtype)
|
||||
# return self.to_out(rearrange(r, '(b h) n d -> b n (h d)', h=h))
|
||||
#
|
||||
# # -- End of code from https://github.com/invoke-ai/InvokeAI --
|
||||
#
|
||||
#
|
||||
# # Based on Birch-san's modified implementation of sub-quadratic attention from https://github.com/Birch-san/diffusers/pull/1
|
||||
# # The sub_quad_attention_forward function is under the MIT License listed under Memory Efficient Attention in the Licenses section of the web UI interface
|
||||
# def sub_quad_attention_forward(self, x, context=None, mask=None, **kwargs):
|
||||
# assert mask is None, "attention-mask not currently implemented for SubQuadraticCrossAttnProcessor."
|
||||
#
|
||||
# h = self.heads
|
||||
#
|
||||
# q = self.to_q(x)
|
||||
# context = default(context, x)
|
||||
#
|
||||
# context_k, context_v = hypernetwork.apply_hypernetworks(shared.loaded_hypernetworks, context)
|
||||
# k = self.to_k(context_k)
|
||||
# v = self.to_v(context_v)
|
||||
# del context, context_k, context_v, x
|
||||
#
|
||||
# q = q.unflatten(-1, (h, -1)).transpose(1,2).flatten(end_dim=1)
|
||||
# k = k.unflatten(-1, (h, -1)).transpose(1,2).flatten(end_dim=1)
|
||||
# v = v.unflatten(-1, (h, -1)).transpose(1,2).flatten(end_dim=1)
|
||||
#
|
||||
# if q.device.type == 'mps':
|
||||
# q, k, v = q.contiguous(), k.contiguous(), v.contiguous()
|
||||
#
|
||||
# dtype = q.dtype
|
||||
# if shared.opts.upcast_attn:
|
||||
# q, k = q.float(), k.float()
|
||||
#
|
||||
# x = sub_quad_attention(q, k, v, q_chunk_size=shared.cmd_opts.sub_quad_q_chunk_size, kv_chunk_size=shared.cmd_opts.sub_quad_kv_chunk_size, chunk_threshold=shared.cmd_opts.sub_quad_chunk_threshold, use_checkpoint=self.training)
|
||||
#
|
||||
# x = x.to(dtype)
|
||||
#
|
||||
# x = x.unflatten(0, (-1, h)).transpose(1,2).flatten(start_dim=2)
|
||||
#
|
||||
# out_proj, dropout = self.to_out
|
||||
# x = out_proj(x)
|
||||
# x = dropout(x)
|
||||
#
|
||||
# return x
|
||||
#
|
||||
#
|
||||
# def sub_quad_attention(q, k, v, q_chunk_size=1024, kv_chunk_size=None, kv_chunk_size_min=None, chunk_threshold=None, use_checkpoint=True):
|
||||
# bytes_per_token = torch.finfo(q.dtype).bits//8
|
||||
# batch_x_heads, q_tokens, _ = q.shape
|
||||
# _, k_tokens, _ = k.shape
|
||||
# qk_matmul_size_bytes = batch_x_heads * bytes_per_token * q_tokens * k_tokens
|
||||
#
|
||||
# if chunk_threshold is None:
|
||||
# if q.device.type == 'mps':
|
||||
# chunk_threshold_bytes = 268435456 * (2 if platform.processor() == 'i386' else bytes_per_token)
|
||||
# else:
|
||||
# chunk_threshold_bytes = int(get_available_vram() * 0.7)
|
||||
# elif chunk_threshold == 0:
|
||||
# chunk_threshold_bytes = None
|
||||
# else:
|
||||
# chunk_threshold_bytes = int(0.01 * chunk_threshold * get_available_vram())
|
||||
#
|
||||
# if kv_chunk_size_min is None and chunk_threshold_bytes is not None:
|
||||
# kv_chunk_size_min = chunk_threshold_bytes // (batch_x_heads * bytes_per_token * (k.shape[2] + v.shape[2]))
|
||||
# elif kv_chunk_size_min == 0:
|
||||
# kv_chunk_size_min = None
|
||||
#
|
||||
# if chunk_threshold_bytes is not None and qk_matmul_size_bytes <= chunk_threshold_bytes:
|
||||
# # the big matmul fits into our memory limit; do everything in 1 chunk,
|
||||
# # i.e. send it down the unchunked fast-path
|
||||
# kv_chunk_size = k_tokens
|
||||
#
|
||||
# with devices.without_autocast(disable=q.dtype == v.dtype):
|
||||
# return sub_quadratic_attention.efficient_dot_product_attention(
|
||||
# q,
|
||||
# k,
|
||||
# v,
|
||||
# query_chunk_size=q_chunk_size,
|
||||
# kv_chunk_size=kv_chunk_size,
|
||||
# kv_chunk_size_min = kv_chunk_size_min,
|
||||
# use_checkpoint=use_checkpoint,
|
||||
# )
|
||||
#
|
||||
#
|
||||
# def get_xformers_flash_attention_op(q, k, v):
|
||||
# if not shared.cmd_opts.xformers_flash_attention:
|
||||
# return None
|
||||
#
|
||||
# try:
|
||||
# flash_attention_op = xformers.ops.MemoryEfficientAttentionFlashAttentionOp
|
||||
# fw, bw = flash_attention_op
|
||||
# if fw.supports(xformers.ops.fmha.Inputs(query=q, key=k, value=v, attn_bias=None)):
|
||||
# return flash_attention_op
|
||||
# except Exception as e:
|
||||
# errors.display_once(e, "enabling flash attention")
|
||||
#
|
||||
# return None
|
||||
#
|
||||
#
|
||||
# def xformers_attention_forward(self, x, context=None, mask=None, **kwargs):
|
||||
# h = self.heads
|
||||
# q_in = self.to_q(x)
|
||||
# context = default(context, x)
|
||||
#
|
||||
# context_k, context_v = hypernetwork.apply_hypernetworks(shared.loaded_hypernetworks, context)
|
||||
# k_in = self.to_k(context_k)
|
||||
# v_in = self.to_v(context_v)
|
||||
#
|
||||
# q, k, v = (t.reshape(t.shape[0], t.shape[1], h, -1) for t in (q_in, k_in, v_in))
|
||||
#
|
||||
# del q_in, k_in, v_in
|
||||
#
|
||||
# dtype = q.dtype
|
||||
# if shared.opts.upcast_attn:
|
||||
# q, k, v = q.float(), k.float(), v.float()
|
||||
#
|
||||
# out = xformers.ops.memory_efficient_attention(q, k, v, attn_bias=None, op=get_xformers_flash_attention_op(q, k, v))
|
||||
#
|
||||
# out = out.to(dtype)
|
||||
#
|
||||
# b, n, h, d = out.shape
|
||||
# out = out.reshape(b, n, h * d)
|
||||
# return self.to_out(out)
|
||||
#
|
||||
#
|
||||
# # Based on Diffusers usage of scaled dot product attention from https://github.com/huggingface/diffusers/blob/c7da8fd23359a22d0df2741688b5b4f33c26df21/src/diffusers/models/cross_attention.py
|
||||
# # The scaled_dot_product_attention_forward function contains parts of code under Apache-2.0 license listed under Scaled Dot Product Attention in the Licenses section of the web UI interface
|
||||
# def scaled_dot_product_attention_forward(self, x, context=None, mask=None, **kwargs):
|
||||
# batch_size, sequence_length, inner_dim = x.shape
|
||||
#
|
||||
# if mask is not None:
|
||||
# mask = self.prepare_attention_mask(mask, sequence_length, batch_size)
|
||||
# mask = mask.view(batch_size, self.heads, -1, mask.shape[-1])
|
||||
#
|
||||
# h = self.heads
|
||||
# q_in = self.to_q(x)
|
||||
# context = default(context, x)
|
||||
#
|
||||
# context_k, context_v = hypernetwork.apply_hypernetworks(shared.loaded_hypernetworks, context)
|
||||
# k_in = self.to_k(context_k)
|
||||
# v_in = self.to_v(context_v)
|
||||
#
|
||||
# head_dim = inner_dim // h
|
||||
# q = q_in.view(batch_size, -1, h, head_dim).transpose(1, 2)
|
||||
# k = k_in.view(batch_size, -1, h, head_dim).transpose(1, 2)
|
||||
# v = v_in.view(batch_size, -1, h, head_dim).transpose(1, 2)
|
||||
#
|
||||
# del q_in, k_in, v_in
|
||||
#
|
||||
# dtype = q.dtype
|
||||
# if shared.opts.upcast_attn:
|
||||
# q, k, v = q.float(), k.float(), v.float()
|
||||
#
|
||||
# # the output of sdp = (batch, num_heads, seq_len, head_dim)
|
||||
# hidden_states = torch.nn.functional.scaled_dot_product_attention(
|
||||
# q, k, v, attn_mask=mask, dropout_p=0.0, is_causal=False
|
||||
# )
|
||||
#
|
||||
# hidden_states = hidden_states.transpose(1, 2).reshape(batch_size, -1, h * head_dim)
|
||||
# hidden_states = hidden_states.to(dtype)
|
||||
#
|
||||
# # linear proj
|
||||
# hidden_states = self.to_out[0](hidden_states)
|
||||
# # dropout
|
||||
# hidden_states = self.to_out[1](hidden_states)
|
||||
# return hidden_states
|
||||
#
|
||||
#
|
||||
# def scaled_dot_product_no_mem_attention_forward(self, x, context=None, mask=None, **kwargs):
|
||||
# with torch.backends.cuda.sdp_kernel(enable_flash=True, enable_math=True, enable_mem_efficient=False):
|
||||
# return scaled_dot_product_attention_forward(self, x, context, mask)
|
||||
#
|
||||
#
|
||||
# def cross_attention_attnblock_forward(self, x):
|
||||
# h_ = x
|
||||
# h_ = self.norm(h_)
|
||||
# q1 = self.q(h_)
|
||||
# k1 = self.k(h_)
|
||||
# v = self.v(h_)
|
||||
#
|
||||
# # compute attention
|
||||
# b, c, h, w = q1.shape
|
||||
#
|
||||
# q2 = q1.reshape(b, c, h*w)
|
||||
# del q1
|
||||
#
|
||||
# q = q2.permute(0, 2, 1) # b,hw,c
|
||||
# del q2
|
||||
#
|
||||
# k = k1.reshape(b, c, h*w) # b,c,hw
|
||||
# del k1
|
||||
#
|
||||
# h_ = torch.zeros_like(k, device=q.device)
|
||||
#
|
||||
# mem_free_total = get_available_vram()
|
||||
#
|
||||
# tensor_size = q.shape[0] * q.shape[1] * k.shape[2] * q.element_size()
|
||||
# mem_required = tensor_size * 2.5
|
||||
# steps = 1
|
||||
#
|
||||
# if mem_required > mem_free_total:
|
||||
# steps = 2**(math.ceil(math.log(mem_required / mem_free_total, 2)))
|
||||
#
|
||||
# slice_size = q.shape[1] // steps if (q.shape[1] % steps) == 0 else q.shape[1]
|
||||
# for i in range(0, q.shape[1], slice_size):
|
||||
# end = i + slice_size
|
||||
#
|
||||
# w1 = torch.bmm(q[:, i:end], k) # b,hw,hw w[b,i,j]=sum_c q[b,i,c]k[b,c,j]
|
||||
# w2 = w1 * (int(c)**(-0.5))
|
||||
# del w1
|
||||
# w3 = torch.nn.functional.softmax(w2, dim=2, dtype=q.dtype)
|
||||
# del w2
|
||||
#
|
||||
# # attend to values
|
||||
# v1 = v.reshape(b, c, h*w)
|
||||
# w4 = w3.permute(0, 2, 1) # b,hw,hw (first hw of k, second of q)
|
||||
# del w3
|
||||
#
|
||||
# h_[:, :, i:end] = torch.bmm(v1, w4) # b, c,hw (hw of q) h_[b,c,j] = sum_i v[b,c,i] w_[b,i,j]
|
||||
# del v1, w4
|
||||
#
|
||||
# h2 = h_.reshape(b, c, h, w)
|
||||
# del h_
|
||||
#
|
||||
# h3 = self.proj_out(h2)
|
||||
# del h2
|
||||
#
|
||||
# h3 += x
|
||||
#
|
||||
# return h3
|
||||
#
|
||||
#
|
||||
# def xformers_attnblock_forward(self, x):
|
||||
# try:
|
||||
# h_ = x
|
||||
# h_ = self.norm(h_)
|
||||
# q = self.q(h_)
|
||||
# k = self.k(h_)
|
||||
# v = self.v(h_)
|
||||
# b, c, h, w = q.shape
|
||||
# q, k, v = (rearrange(t, 'b c h w -> b (h w) c') for t in (q, k, v))
|
||||
# dtype = q.dtype
|
||||
# if shared.opts.upcast_attn:
|
||||
# q, k = q.float(), k.float()
|
||||
# q = q.contiguous()
|
||||
# k = k.contiguous()
|
||||
# v = v.contiguous()
|
||||
# out = xformers.ops.memory_efficient_attention(q, k, v, op=get_xformers_flash_attention_op(q, k, v))
|
||||
# out = out.to(dtype)
|
||||
# out = rearrange(out, 'b (h w) c -> b c h w', h=h)
|
||||
# out = self.proj_out(out)
|
||||
# return x + out
|
||||
# except NotImplementedError:
|
||||
# return cross_attention_attnblock_forward(self, x)
|
||||
#
|
||||
#
|
||||
# def sdp_attnblock_forward(self, x):
|
||||
# h_ = x
|
||||
# h_ = self.norm(h_)
|
||||
# q = self.q(h_)
|
||||
# k = self.k(h_)
|
||||
# v = self.v(h_)
|
||||
# b, c, h, w = q.shape
|
||||
# q, k, v = (rearrange(t, 'b c h w -> b (h w) c') for t in (q, k, v))
|
||||
# dtype = q.dtype
|
||||
# if shared.opts.upcast_attn:
|
||||
# q, k, v = q.float(), k.float(), v.float()
|
||||
# q = q.contiguous()
|
||||
# k = k.contiguous()
|
||||
# v = v.contiguous()
|
||||
# out = torch.nn.functional.scaled_dot_product_attention(q, k, v, dropout_p=0.0, is_causal=False)
|
||||
# out = out.to(dtype)
|
||||
# out = rearrange(out, 'b (h w) c -> b c h w', h=h)
|
||||
# out = self.proj_out(out)
|
||||
# return x + out
|
||||
#
|
||||
#
|
||||
# def sdp_no_mem_attnblock_forward(self, x):
|
||||
# with torch.backends.cuda.sdp_kernel(enable_flash=True, enable_math=True, enable_mem_efficient=False):
|
||||
# return sdp_attnblock_forward(self, x)
|
||||
#
|
||||
#
|
||||
# def sub_quad_attnblock_forward(self, x):
|
||||
# h_ = x
|
||||
# h_ = self.norm(h_)
|
||||
# q = self.q(h_)
|
||||
# k = self.k(h_)
|
||||
# v = self.v(h_)
|
||||
# b, c, h, w = q.shape
|
||||
# q, k, v = (rearrange(t, 'b c h w -> b (h w) c') for t in (q, k, v))
|
||||
# q = q.contiguous()
|
||||
# k = k.contiguous()
|
||||
# v = v.contiguous()
|
||||
# out = sub_quad_attention(q, k, v, q_chunk_size=shared.cmd_opts.sub_quad_q_chunk_size, kv_chunk_size=shared.cmd_opts.sub_quad_kv_chunk_size, chunk_threshold=shared.cmd_opts.sub_quad_chunk_threshold, use_checkpoint=self.training)
|
||||
# out = rearrange(out, 'b (h w) c -> b c h w', h=h)
|
||||
# out = self.proj_out(out)
|
||||
# return x + out
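The classes above register themselves through list_optimizers and differ mainly in priority and availability. A minimal sketch of how a caller might pick one; the real selection lives in modules/sd_hijack.py and also honors the user's cross-attention setting, and the assumption here is that every optimization exposes is_available() as the no-mem-attention class above does, so this function is only illustrative.

def pick_optimization(optimizers, requested_name=None):
    # keep only optimizations that can run on this machine
    available = [o for o in optimizers if o.is_available()]
    if requested_name:
        named = [o for o in available if o.name.lower() == requested_name.lower()]
        if named:
            return named[0]
    # otherwise take the highest-priority candidate (e.g. Doggettx=90 over V1=10)
    return max(available, key=lambda o: o.priority, default=None)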
154
modules/sd_hijack_unet.py
Executable file
@@ -0,0 +1,154 @@
# import torch
|
||||
# from packaging import version
|
||||
# from einops import repeat
|
||||
# import math
|
||||
#
|
||||
# from modules import devices
|
||||
# from modules.sd_hijack_utils import CondFunc
|
||||
#
|
||||
#
|
||||
# class TorchHijackForUnet:
|
||||
# """
|
||||
# This is torch, but with cat that resizes tensors to appropriate dimensions if they do not match;
|
||||
# this makes it possible to create pictures with dimensions that are multiples of 8 rather than 64
|
||||
# """
|
||||
#
|
||||
# def __getattr__(self, item):
|
||||
# if item == 'cat':
|
||||
# return self.cat
|
||||
#
|
||||
# if hasattr(torch, item):
|
||||
# return getattr(torch, item)
|
||||
#
|
||||
# raise AttributeError(f"'{type(self).__name__}' object has no attribute '{item}'")
|
||||
#
|
||||
# def cat(self, tensors, *args, **kwargs):
|
||||
# if len(tensors) == 2:
|
||||
# a, b = tensors
|
||||
# if a.shape[-2:] != b.shape[-2:]:
|
||||
# a = torch.nn.functional.interpolate(a, b.shape[-2:], mode="nearest")
|
||||
#
|
||||
# tensors = (a, b)
|
||||
#
|
||||
# return torch.cat(tensors, *args, **kwargs)
|
||||
#
|
||||
#
|
||||
# th = TorchHijackForUnet()
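A standalone illustration of what the cat override above does (the shapes are hypothetical): when two feature maps meet at a skip connection with mismatched spatial sizes, the first is nearest-resized to the second before concatenation instead of raising a shape error.

import torch

a = torch.zeros(1, 4, 64, 64)   # e.g. an encoder skip connection
b = torch.zeros(1, 4, 63, 63)   # e.g. a decoder activation at a non-multiple-of-64 size
if a.shape[-2:] != b.shape[-2:]:
    a = torch.nn.functional.interpolate(a, b.shape[-2:], mode="nearest")
print(torch.cat((a, b), dim=1).shape)  # torch.Size([1, 8, 63, 63])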
|
||||
#
|
||||
#
|
||||
# # Below are monkey patches to enable upcasting a float16 UNet for float32 sampling
|
||||
# def apply_model(orig_func, self, x_noisy, t, cond, **kwargs):
|
||||
# """Always make sure inputs to unet are in correct dtype."""
|
||||
# if isinstance(cond, dict):
|
||||
# for y in cond.keys():
|
||||
# if isinstance(cond[y], list):
|
||||
# cond[y] = [x.to(devices.dtype_unet) if isinstance(x, torch.Tensor) else x for x in cond[y]]
|
||||
# else:
|
||||
# cond[y] = cond[y].to(devices.dtype_unet) if isinstance(cond[y], torch.Tensor) else cond[y]
|
||||
#
|
||||
# with devices.autocast():
|
||||
# result = orig_func(self, x_noisy.to(devices.dtype_unet), t.to(devices.dtype_unet), cond, **kwargs)
|
||||
# if devices.unet_needs_upcast:
|
||||
# return result.float()
|
||||
# else:
|
||||
# return result
|
||||
#
|
||||
#
|
||||
# # Monkey patch to create timestep embed tensor on device, avoiding a block.
|
||||
# def timestep_embedding(_, timesteps, dim, max_period=10000, repeat_only=False):
|
||||
# """
|
||||
# Create sinusoidal timestep embeddings.
|
||||
# :param timesteps: a 1-D Tensor of N indices, one per batch element.
|
||||
# These may be fractional.
|
||||
# :param dim: the dimension of the output.
|
||||
# :param max_period: controls the minimum frequency of the embeddings.
|
||||
# :return: an [N x dim] Tensor of positional embeddings.
|
||||
# """
|
||||
# if not repeat_only:
|
||||
# half = dim // 2
|
||||
# freqs = torch.exp(
|
||||
# -math.log(max_period) * torch.arange(start=0, end=half, dtype=torch.float32, device=timesteps.device) / half
|
||||
# )
|
||||
# args = timesteps[:, None].float() * freqs[None]
|
||||
# embedding = torch.cat([torch.cos(args), torch.sin(args)], dim=-1)
|
||||
# if dim % 2:
|
||||
# embedding = torch.cat([embedding, torch.zeros_like(embedding[:, :1])], dim=-1)
|
||||
# else:
|
||||
# embedding = repeat(timesteps, 'b -> b d', d=dim)
|
||||
# return embedding
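For reference, the non-repeat branch above computes a standard sinusoidal embedding: each timestep is multiplied by dim/2 geometrically spaced frequencies and the cosines and sines are concatenated. A minimal standalone check of the output shape (the function name below is not part of the module):

import math
import torch

def sinusoidal_embedding(timesteps, dim, max_period=10000):
    half = dim // 2
    freqs = torch.exp(-math.log(max_period) * torch.arange(half, dtype=torch.float32) / half)
    args = timesteps[:, None].float() * freqs[None]
    return torch.cat([torch.cos(args), torch.sin(args)], dim=-1)

print(sinusoidal_embedding(torch.tensor([0, 999]), dim=320).shape)  # torch.Size([2, 320])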
|
||||
#
|
||||
#
|
||||
# # Monkey patch to SpatialTransformer removing unnecessary contiguous calls.
|
||||
# # Prevents a lot of unnecessary aten::copy_ calls
|
||||
# def spatial_transformer_forward(_, self, x: torch.Tensor, context=None):
|
||||
# # note: if no context is given, cross-attention defaults to self-attention
|
||||
# if not isinstance(context, list):
|
||||
# context = [context]
|
||||
# b, c, h, w = x.shape
|
||||
# x_in = x
|
||||
# x = self.norm(x)
|
||||
# if not self.use_linear:
|
||||
# x = self.proj_in(x)
|
||||
# x = x.permute(0, 2, 3, 1).reshape(b, h * w, c)
|
||||
# if self.use_linear:
|
||||
# x = self.proj_in(x)
|
||||
# for i, block in enumerate(self.transformer_blocks):
|
||||
# x = block(x, context=context[i])
|
||||
# if self.use_linear:
|
||||
# x = self.proj_out(x)
|
||||
# x = x.view(b, h, w, c).permute(0, 3, 1, 2)
|
||||
# if not self.use_linear:
|
||||
# x = self.proj_out(x)
|
||||
# return x + x_in
|
||||
#
|
||||
#
|
||||
# class GELUHijack(torch.nn.GELU, torch.nn.Module):
|
||||
# def __init__(self, *args, **kwargs):
|
||||
# torch.nn.GELU.__init__(self, *args, **kwargs)
|
||||
# def forward(self, x):
|
||||
# if devices.unet_needs_upcast:
|
||||
# return torch.nn.GELU.forward(self.float(), x.float()).to(devices.dtype_unet)
|
||||
# else:
|
||||
# return torch.nn.GELU.forward(self, x)
|
||||
#
|
||||
#
|
||||
# ddpm_edit_hijack = None
|
||||
# def hijack_ddpm_edit():
|
||||
# global ddpm_edit_hijack
|
||||
# if not ddpm_edit_hijack:
|
||||
# CondFunc('modules.models.diffusion.ddpm_edit.LatentDiffusion.decode_first_stage', first_stage_sub, first_stage_cond)
|
||||
# CondFunc('modules.models.diffusion.ddpm_edit.LatentDiffusion.encode_first_stage', first_stage_sub, first_stage_cond)
|
||||
# ddpm_edit_hijack = CondFunc('modules.models.diffusion.ddpm_edit.LatentDiffusion.apply_model', apply_model)
|
||||
#
|
||||
#
|
||||
# unet_needs_upcast = lambda *args, **kwargs: devices.unet_needs_upcast
|
||||
# CondFunc('ldm.models.diffusion.ddpm.LatentDiffusion.apply_model', apply_model, unet_needs_upcast)
|
||||
# CondFunc('ldm.modules.diffusionmodules.openaimodel.timestep_embedding', timestep_embedding)
|
||||
# CondFunc('ldm.modules.attention.SpatialTransformer.forward', spatial_transformer_forward)
|
||||
# CondFunc('ldm.modules.diffusionmodules.openaimodel.timestep_embedding', lambda orig_func, timesteps, *args, **kwargs: orig_func(timesteps, *args, **kwargs).to(torch.float32 if timesteps.dtype == torch.int64 else devices.dtype_unet), unet_needs_upcast)
|
||||
#
|
||||
# if version.parse(torch.__version__) <= version.parse("1.13.2") or torch.cuda.is_available():
|
||||
# CondFunc('ldm.modules.diffusionmodules.util.GroupNorm32.forward', lambda orig_func, self, *args, **kwargs: orig_func(self.float(), *args, **kwargs), unet_needs_upcast)
|
||||
# CondFunc('ldm.modules.attention.GEGLU.forward', lambda orig_func, self, x: orig_func(self.float(), x.float()).to(devices.dtype_unet), unet_needs_upcast)
|
||||
# CondFunc('open_clip.transformer.ResidualAttentionBlock.__init__', lambda orig_func, *args, **kwargs: kwargs.update({'act_layer': GELUHijack}) and False or orig_func(*args, **kwargs), lambda _, *args, **kwargs: kwargs.get('act_layer') is None or kwargs['act_layer'] == torch.nn.GELU)
|
||||
#
|
||||
# first_stage_cond = lambda _, self, *args, **kwargs: devices.unet_needs_upcast and self.model.diffusion_model.dtype == torch.float16
|
||||
# first_stage_sub = lambda orig_func, self, x, **kwargs: orig_func(self, x.to(devices.dtype_vae), **kwargs)
|
||||
# CondFunc('ldm.models.diffusion.ddpm.LatentDiffusion.decode_first_stage', first_stage_sub, first_stage_cond)
|
||||
# CondFunc('ldm.models.diffusion.ddpm.LatentDiffusion.encode_first_stage', first_stage_sub, first_stage_cond)
|
||||
# CondFunc('ldm.models.diffusion.ddpm.LatentDiffusion.get_first_stage_encoding', lambda orig_func, *args, **kwargs: orig_func(*args, **kwargs).float(), first_stage_cond)
|
||||
#
|
||||
# CondFunc('ldm.models.diffusion.ddpm.LatentDiffusion.apply_model', apply_model)
|
||||
# CondFunc('sgm.modules.diffusionmodules.wrappers.OpenAIWrapper.forward', apply_model)
|
||||
#
|
||||
#
|
||||
# def timestep_embedding_cast_result(orig_func, timesteps, *args, **kwargs):
|
||||
# if devices.unet_needs_upcast and timesteps.dtype == torch.int64:
|
||||
# dtype = torch.float32
|
||||
# else:
|
||||
# dtype = devices.dtype_unet
|
||||
# return orig_func(timesteps, *args, **kwargs).to(dtype=dtype)
|
||||
#
|
||||
#
|
||||
# CondFunc('ldm.modules.diffusionmodules.openaimodel.timestep_embedding', timestep_embedding_cast_result)
|
||||
# CondFunc('sgm.modules.diffusionmodules.openaimodel.timestep_embedding', timestep_embedding_cast_result)
36
modules/sd_hijack_utils.py
Executable file
@@ -0,0 +1,36 @@
import importlib


always_true_func = lambda *args, **kwargs: True


class CondFunc:
    def __new__(cls, orig_func, sub_func, cond_func=always_true_func):
        self = super(CondFunc, cls).__new__(cls)
        if isinstance(orig_func, str):
            func_path = orig_func.split('.')
            for i in range(len(func_path)-1, -1, -1):
                try:
                    resolved_obj = importlib.import_module('.'.join(func_path[:i]))
                    break
                except ImportError:
                    pass
            try:
                for attr_name in func_path[i:-1]:
                    resolved_obj = getattr(resolved_obj, attr_name)
                orig_func = getattr(resolved_obj, func_path[-1])
                setattr(resolved_obj, func_path[-1], lambda *args, **kwargs: self(*args, **kwargs))
            except AttributeError:
                print(f"Warning: Failed to resolve {orig_func} for CondFunc hijack")
                pass
        self.__init__(orig_func, sub_func, cond_func)
        return lambda *args, **kwargs: self(*args, **kwargs)
    def __init__(self, orig_func, sub_func, cond_func):
        self.__orig_func = orig_func
        self.__sub_func = sub_func
        self.__cond_func = cond_func
    def __call__(self, *args, **kwargs):
        if not self.__cond_func or self.__cond_func(self.__orig_func, *args, **kwargs):
            return self.__sub_func(self.__orig_func, *args, **kwargs)
        else:
            return self.__orig_func(*args, **kwargs)
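A small usage sketch of CondFunc (json.dumps is just an arbitrary importable target chosen for illustration): the substitute receives the original callable as its first argument and only runs when the condition function returns True; otherwise the original is called unchanged.

import json

CondFunc(
    'json.dumps',
    lambda orig_func, obj, **kwargs: orig_func(obj, sort_keys=True, **kwargs),
    lambda orig_func, obj, **kwargs: kwargs.get('indent') is not None,
)

print(json.dumps({'b': 1, 'a': 2}, indent=2))  # substitute runs: keys come out sorted
print(json.dumps({'b': 1, 'a': 2}))            # condition fails: original behavior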
32
modules/sd_hijack_xlmr.py
Executable file
@@ -0,0 +1,32 @@
# import torch
#
# from modules import sd_hijack_clip, devices
#
#
# class FrozenXLMREmbedderWithCustomWords(sd_hijack_clip.FrozenCLIPEmbedderWithCustomWords):
#     def __init__(self, wrapped, hijack):
#         super().__init__(wrapped, hijack)
#
#         self.id_start = wrapped.config.bos_token_id
#         self.id_end = wrapped.config.eos_token_id
#         self.id_pad = wrapped.config.pad_token_id
#
#         self.comma_token = self.tokenizer.get_vocab().get(',', None)  # alt diffusion doesn't have </w> bits for comma
#
#     def encode_with_transformers(self, tokens):
#         # there's no CLIP Skip here because all hidden layers have size of 1024 and the last one uses a
#         # trained layer to transform those 1024 into 768 for unet; so you can't choose which transformer
#         # layer to work with - you have to use the last
#
#         attention_mask = (tokens != self.id_pad).to(device=tokens.device, dtype=torch.int64)
#         features = self.wrapped(input_ids=tokens, attention_mask=attention_mask)
#         z = features['projection_state']
#
#         return z
#
#     def encode_embedding_init_text(self, init_text, nvpt):
#         embedding_layer = self.wrapped.roberta.embeddings
#         ids = self.wrapped.tokenizer(init_text, max_length=nvpt, return_tensors="pt", add_special_tokens=False)["input_ids"]
#         embedded = embedding_layer.token_embedding.wrapped(ids.to(devices.device)).squeeze(0)
#
#         return embedded
526
modules/sd_models.py
Executable file
@@ -0,0 +1,526 @@
import collections
|
||||
import importlib
|
||||
import os
|
||||
import sys
|
||||
import math
|
||||
import threading
|
||||
import enum
|
||||
|
||||
import torch
|
||||
import re
|
||||
import safetensors.torch
|
||||
from omegaconf import OmegaConf, ListConfig
|
||||
from urllib import request
|
||||
import gc
|
||||
import contextlib
|
||||
|
||||
from modules import paths, shared, modelloader, devices, script_callbacks, sd_vae, sd_disable_initialization, errors, hashes, sd_models_config, sd_unet, sd_models_xl, cache, extra_networks, processing, lowvram, sd_hijack, patches
|
||||
from modules.shared import opts, cmd_opts
|
||||
from modules.timer import Timer
|
||||
import numpy as np
|
||||
from backend.loader import forge_loader
|
||||
from backend import memory_management
|
||||
from backend.args import dynamic_args
|
||||
from backend.utils import load_torch_file
|
||||
|
||||
|
||||
model_dir = "Stable-diffusion"
|
||||
model_path = os.path.abspath(os.path.join(paths.models_path, model_dir))
|
||||
|
||||
checkpoints_list = {}
|
||||
checkpoint_aliases = {}
|
||||
checkpoint_alisases = checkpoint_aliases # for compatibility with old name
|
||||
checkpoints_loaded = collections.OrderedDict()
|
||||
|
||||
|
||||
class ModelType(enum.Enum):
|
||||
SD1 = 1
|
||||
SD2 = 2
|
||||
SDXL = 3
|
||||
SSD = 4
|
||||
SD3 = 5
|
||||
|
||||
|
||||
def replace_key(d, key, new_key, value):
|
||||
keys = list(d.keys())
|
||||
|
||||
d[new_key] = value
|
||||
|
||||
if key not in keys:
|
||||
return d
|
||||
|
||||
index = keys.index(key)
|
||||
keys[index] = new_key
|
||||
|
||||
new_d = {k: d[k] for k in keys}
|
||||
|
||||
d.clear()
|
||||
d.update(new_d)
|
||||
return d
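In other words, replace_key renames a dictionary key in place while keeping its original position in insertion order, which is how a checkpoint's title can gain a shorthash further down (in calculate_shorthash) without reordering the checkpoint list. A tiny illustration with made-up entries:

d = {'a.ckpt': 1, 'b.ckpt': 2, 'c.ckpt': 3}
replace_key(d, 'b.ckpt', 'b.ckpt [0123456789]', 2)
print(list(d))  # ['a.ckpt', 'b.ckpt [0123456789]', 'c.ckpt']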
|
||||
|
||||
|
||||
class CheckpointInfo:
|
||||
def __init__(self, filename):
|
||||
self.filename = filename
|
||||
abspath = os.path.abspath(filename)
|
||||
abs_ckpt_dir = os.path.abspath(shared.cmd_opts.ckpt_dir) if shared.cmd_opts.ckpt_dir is not None else None
|
||||
|
||||
self.is_safetensors = os.path.splitext(filename)[1].lower() == ".safetensors"
|
||||
|
||||
if abs_ckpt_dir and abspath.startswith(abs_ckpt_dir):
|
||||
name = abspath.replace(abs_ckpt_dir, '')
|
||||
elif abspath.startswith(model_path):
|
||||
name = abspath.replace(model_path, '')
|
||||
else:
|
||||
name = os.path.basename(filename)
|
||||
|
||||
if name.startswith("\\") or name.startswith("/"):
|
||||
name = name[1:]
|
||||
|
||||
def read_metadata():
|
||||
metadata = read_metadata_from_safetensors(filename)
|
||||
self.modelspec_thumbnail = metadata.pop('modelspec.thumbnail', None)
|
||||
|
||||
return metadata
|
||||
|
||||
self.metadata = {}
|
||||
if self.is_safetensors:
|
||||
try:
|
||||
self.metadata = cache.cached_data_for_file('safetensors-metadata', "checkpoint/" + name, filename, read_metadata)
|
||||
except Exception as e:
|
||||
errors.display(e, f"reading metadata for {filename}")
|
||||
|
||||
self.name = name
|
||||
self.name_for_extra = os.path.splitext(os.path.basename(filename))[0]
|
||||
self.model_name = os.path.splitext(name.replace("/", "_").replace("\\", "_"))[0]
|
||||
self.hash = model_hash(filename)
|
||||
|
||||
self.sha256 = hashes.sha256_from_cache(self.filename, f"checkpoint/{name}")
|
||||
self.shorthash = self.sha256[0:10] if self.sha256 else None
|
||||
|
||||
self.title = name if self.shorthash is None else f'{name} [{self.shorthash}]'
|
||||
self.short_title = self.name_for_extra if self.shorthash is None else f'{self.name_for_extra} [{self.shorthash}]'
|
||||
|
||||
self.ids = [self.hash, self.model_name, self.title, name, self.name_for_extra, f'{name} [{self.hash}]']
|
||||
if self.shorthash:
|
||||
self.ids += [self.shorthash, self.sha256, f'{self.name} [{self.shorthash}]', f'{self.name_for_extra} [{self.shorthash}]']
|
||||
|
||||
def register(self):
|
||||
checkpoints_list[self.title] = self
|
||||
for id in self.ids:
|
||||
checkpoint_aliases[id] = self
|
||||
|
||||
def calculate_shorthash(self):
|
||||
self.sha256 = hashes.sha256(self.filename, f"checkpoint/{self.name}")
|
||||
if self.sha256 is None:
|
||||
return
|
||||
|
||||
shorthash = self.sha256[0:10]
|
||||
if self.shorthash == self.sha256[0:10]:
|
||||
return self.shorthash
|
||||
|
||||
self.shorthash = shorthash
|
||||
|
||||
if self.shorthash not in self.ids:
|
||||
self.ids += [self.shorthash, self.sha256, f'{self.name} [{self.shorthash}]', f'{self.name_for_extra} [{self.shorthash}]']
|
||||
|
||||
old_title = self.title
|
||||
self.title = f'{self.name} [{self.shorthash}]'
|
||||
self.short_title = f'{self.name_for_extra} [{self.shorthash}]'
|
||||
|
||||
replace_key(checkpoints_list, old_title, self.title, self)
|
||||
self.register()
|
||||
|
||||
return self.shorthash
|
||||
|
||||
def __str__(self):
|
||||
return str(dict(filename=self.filename, hash=self.hash))
|
||||
|
||||
def __repr__(self):
|
||||
return str(dict(filename=self.filename, hash=self.hash))
|
||||
|
||||
|
||||
# try:
|
||||
# # this silences the annoying "Some weights of the model checkpoint were not used when initializing..." message at start.
|
||||
# from transformers import logging, CLIPModel # noqa: F401
|
||||
#
|
||||
# logging.set_verbosity_error()
|
||||
# except Exception:
|
||||
# pass
|
||||
|
||||
|
||||
def setup_model():
|
||||
"""called once at startup to do various one-time tasks related to SD models"""
|
||||
|
||||
os.makedirs(model_path, exist_ok=True)
|
||||
|
||||
enable_midas_autodownload()
|
||||
patch_given_betas()
|
||||
|
||||
|
||||
def checkpoint_tiles(use_short=False):
|
||||
return [x.short_title if use_short else x.name for x in checkpoints_list.values()]
|
||||
|
||||
|
||||
def list_models():
|
||||
checkpoints_list.clear()
|
||||
checkpoint_aliases.clear()
|
||||
|
||||
cmd_ckpt = shared.cmd_opts.ckpt
|
||||
|
||||
model_list = modelloader.load_models(model_path=model_path, model_url=None, command_path=shared.cmd_opts.ckpt_dir, ext_filter=[".ckpt", ".safetensors", ".gguf"], download_name=None, ext_blacklist=[".vae.ckpt", ".vae.safetensors"])
|
||||
|
||||
if os.path.exists(cmd_ckpt):
|
||||
checkpoint_info = CheckpointInfo(cmd_ckpt)
|
||||
checkpoint_info.register()
|
||||
|
||||
shared.opts.data['sd_model_checkpoint'] = checkpoint_info.title
|
||||
elif cmd_ckpt is not None and cmd_ckpt != shared.default_sd_model_file:
|
||||
print(f"Checkpoint in --ckpt argument not found (Possible it was moved to {model_path}: {cmd_ckpt}", file=sys.stderr)
|
||||
|
||||
for filename in model_list:
|
||||
checkpoint_info = CheckpointInfo(filename)
|
||||
checkpoint_info.register()
|
||||
|
||||
|
||||
re_strip_checksum = re.compile(r"\s*\[[^]]+]\s*$")
|
||||
|
||||
def match_checkpoint_to_name(name):
|
||||
name = name.split(' [')[0]
|
||||
|
||||
for ckptname in checkpoints_list.values():
|
||||
title = ckptname.title.split(' [')[0]
|
||||
if (name in title) or (title in name):
|
||||
return ckptname.short_title if shared.opts.sd_checkpoint_dropdown_use_short else ckptname.name.split(' [')[0]
|
||||
|
||||
return name
|
||||
|
||||
def get_closet_checkpoint_match(search_string):
|
||||
if not search_string:
|
||||
return None
|
||||
|
||||
checkpoint_info = checkpoint_aliases.get(search_string, None)
|
||||
if checkpoint_info is not None:
|
||||
return checkpoint_info
|
||||
|
||||
found = sorted([info for info in checkpoints_list.values() if search_string in info.title], key=lambda x: len(x.title))
|
||||
if found:
|
||||
return found[0]
|
||||
|
||||
search_string_without_checksum = re.sub(re_strip_checksum, '', search_string)
|
||||
found = sorted([info for info in checkpoints_list.values() if search_string_without_checksum in info.title], key=lambda x: len(x.title))
|
||||
if found:
|
||||
return found[0]
|
||||
|
||||
return None
|
||||
|
||||
|
||||
def model_hash(filename):
|
||||
"""old hash that only looks at a small part of the file and is prone to collisions"""
|
||||
|
||||
try:
|
||||
with open(filename, "rb") as file:
|
||||
import hashlib
|
||||
m = hashlib.sha256()
|
||||
|
||||
file.seek(0x100000)
|
||||
m.update(file.read(0x10000))
|
||||
return m.hexdigest()[0:8]
|
||||
except FileNotFoundError:
|
||||
return 'NOFILE'
|
||||
|
||||
|
||||
def select_checkpoint():
|
||||
"""Raises `FileNotFoundError` if no checkpoints are found."""
|
||||
model_checkpoint = shared.opts.sd_model_checkpoint
|
||||
|
||||
checkpoint_info = checkpoint_aliases.get(model_checkpoint, None)
|
||||
if checkpoint_info is not None:
|
||||
return checkpoint_info
|
||||
|
||||
if len(checkpoints_list) == 0:
|
||||
print('You do not have any model!')
|
||||
return None
|
||||
|
||||
checkpoint_info = next(iter(checkpoints_list.values()))
|
||||
if model_checkpoint is not None:
|
||||
print(f"Checkpoint {model_checkpoint} not found; loading fallback {checkpoint_info.title}", file=sys.stderr)
|
||||
|
||||
return checkpoint_info
|
||||
|
||||
|
||||
def transform_checkpoint_dict_key(k, replacements):
|
||||
pass
|
||||
|
||||
|
||||
def get_state_dict_from_checkpoint(pl_sd):
|
||||
pass
|
||||
|
||||
|
||||
def read_metadata_from_safetensors(filename):
|
||||
import json
|
||||
|
||||
with open(filename, mode="rb") as file:
|
||||
metadata_len = file.read(8)
|
||||
metadata_len = int.from_bytes(metadata_len, "little")
|
||||
json_start = file.read(2)
|
||||
|
||||
assert metadata_len > 2 and json_start in (b'{"', b"{'"), f"{filename} is not a safetensors file"
|
||||
|
||||
res = {}
|
||||
|
||||
try:
|
||||
json_data = json_start + file.read(metadata_len-2)
|
||||
json_obj = json.loads(json_data)
|
||||
for k, v in json_obj.get("__metadata__", {}).items():
|
||||
res[k] = v
|
||||
if isinstance(v, str) and v[0:1] == '{':
|
||||
try:
|
||||
res[k] = json.loads(v)
|
||||
except Exception:
|
||||
pass
|
||||
except Exception:
|
||||
errors.report(f"Error reading metadata from file: {filename}", exc_info=True)
|
||||
|
||||
return res
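The reader above relies on the safetensors layout: an 8-byte little-endian header length followed by a JSON header whose optional "__metadata__" dict carries string metadata. A self-contained sketch that fabricates such a header in a temporary file (the metadata key is an arbitrary example value) and reads it back with the function above:

import json
import tempfile

header = json.dumps({"__metadata__": {"ss_base_model_version": "sd_v1"}}).encode("utf-8")
with tempfile.NamedTemporaryFile(suffix=".safetensors", delete=False) as file:
    file.write(len(header).to_bytes(8, "little"))
    file.write(header)
    path = file.name

print(read_metadata_from_safetensors(path))  # {'ss_base_model_version': 'sd_v1'}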
|
||||
|
||||
|
||||
def read_state_dict(checkpoint_file, print_global_state=False, map_location=None):
|
||||
pass
|
||||
|
||||
|
||||
def get_checkpoint_state_dict(checkpoint_info: CheckpointInfo, timer):
|
||||
sd_model_hash = checkpoint_info.calculate_shorthash()
|
||||
timer.record("calculate hash")
|
||||
|
||||
if checkpoint_info in checkpoints_loaded:
|
||||
# use checkpoint cache
|
||||
print(f"Loading weights [{sd_model_hash}] from cache")
|
||||
# move to end as latest
|
||||
checkpoints_loaded.move_to_end(checkpoint_info)
|
||||
return checkpoints_loaded[checkpoint_info]
|
||||
|
||||
print(f"Loading weights [{sd_model_hash}] from {checkpoint_info.filename}")
|
||||
res = load_torch_file(checkpoint_info.filename)
|
||||
timer.record("load weights from disk")
|
||||
|
||||
return res
|
||||
|
||||
|
||||
def SkipWritingToConfig():
|
||||
return contextlib.nullcontext()
|
||||
|
||||
|
||||
def check_fp8(model):
|
||||
pass
|
||||
|
||||
|
||||
def set_model_type(model, state_dict):
|
||||
pass
|
||||
|
||||
|
||||
def set_model_fields(model):
|
||||
pass
|
||||
|
||||
|
||||
def load_model_weights(model, checkpoint_info: CheckpointInfo, state_dict, timer):
|
||||
pass
|
||||
|
||||
|
||||
def enable_midas_autodownload():
|
||||
pass
|
||||
|
||||
|
||||
def patch_given_betas():
|
||||
pass
|
||||
|
||||
|
||||
def repair_config(sd_config, state_dict=None):
|
||||
pass
|
||||
|
||||
|
||||
def rescale_zero_terminal_snr_abar(alphas_cumprod):
|
||||
alphas_bar_sqrt = alphas_cumprod.sqrt()
|
||||
|
||||
# Store old values.
|
||||
alphas_bar_sqrt_0 = alphas_bar_sqrt[0].clone()
|
||||
alphas_bar_sqrt_T = alphas_bar_sqrt[-1].clone()
|
||||
|
||||
# Shift so the last timestep is zero.
|
||||
alphas_bar_sqrt -= (alphas_bar_sqrt_T)
|
||||
|
||||
# Scale so the first timestep is back to the old value.
|
||||
alphas_bar_sqrt *= alphas_bar_sqrt_0 / (alphas_bar_sqrt_0 - alphas_bar_sqrt_T)
|
||||
|
||||
# Convert alphas_bar_sqrt to betas
|
||||
alphas_bar = alphas_bar_sqrt ** 2 # Revert sqrt
|
||||
alphas_bar[-1] = 4.8973451890853435e-08
|
||||
return alphas_bar
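A quick numeric check of the rescale above (the beta schedule here is a synthetic stand-in, not the schedule the model actually uses): the first alpha-bar value is preserved while the terminal one is driven to (almost) zero, which is what "zero terminal SNR" means.

import torch

betas = torch.linspace(0.00085, 0.012, 1000)
alphas_cumprod = torch.cumprod(1.0 - betas, dim=0)
rescaled = rescale_zero_terminal_snr_abar(alphas_cumprod)
print(float(alphas_cumprod[0]), float(rescaled[0]))  # first step unchanged
print(float(rescaled[-1]))                           # ~4.9e-08, i.e. terminal SNR ~ 0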
|
||||
|
||||
|
||||
def apply_alpha_schedule_override(sd_model, p=None):
|
||||
"""
|
||||
Applies an override to the alpha schedule of the model according to settings.
|
||||
- downcasts the alpha schedule to half precision
|
||||
- rescales the alpha schedule to have zero terminal SNR
|
||||
"""
|
||||
|
||||
if not hasattr(sd_model, 'alphas_cumprod') or not hasattr(sd_model, 'alphas_cumprod_original'):
|
||||
return
|
||||
|
||||
sd_model.alphas_cumprod = sd_model.alphas_cumprod_original.to(shared.device)
|
||||
|
||||
if opts.use_downcasted_alpha_bar:
|
||||
if p is not None:
|
||||
p.extra_generation_params['Downcast alphas_cumprod'] = opts.use_downcasted_alpha_bar
|
||||
sd_model.alphas_cumprod = sd_model.alphas_cumprod.half().to(shared.device)
|
||||
|
||||
if opts.sd_noise_schedule == "Zero Terminal SNR":
|
||||
if p is not None:
|
||||
p.extra_generation_params['Noise Schedule'] = opts.sd_noise_schedule
|
||||
sd_model.alphas_cumprod = rescale_zero_terminal_snr_abar(sd_model.alphas_cumprod).to(shared.device)
|
||||
|
||||
|
||||
# This is a dummy class for backward compatibility when model is not load - for extensions like prompt all in one.
|
||||
class FakeInitialModel:
|
||||
def __init__(self):
|
||||
self.cond_stage_model = None
|
||||
self.chunk_length = 75
|
||||
|
||||
def get_prompt_lengths_on_ui(self, prompt):
|
||||
r = len(prompt.strip('!,. ').replace(' ', ',').replace('.', ',').replace('!', ',').replace(',,', ',').replace(',,', ',').replace(',,', ',').replace(',,', ',').split(','))
|
||||
return r, math.ceil(max(r, 1) / self.chunk_length) * self.chunk_length
|
||||
|
||||
|
||||
class SdModelData:
|
||||
def __init__(self):
|
||||
self.sd_model = FakeInitialModel()
|
||||
self.forge_loading_parameters = {}
|
||||
self.forge_hash = ''
|
||||
|
||||
def get_sd_model(self):
|
||||
return self.sd_model
|
||||
|
||||
def set_sd_model(self, v):
|
||||
self.sd_model = v
|
||||
|
||||
|
||||
model_data = SdModelData()
|
||||
|
||||
|
||||
def get_empty_cond(sd_model):
|
||||
pass
|
||||
|
||||
|
||||
def send_model_to_cpu(m):
|
||||
pass
|
||||
|
||||
|
||||
def model_target_device(m):
|
||||
return devices.device
|
||||
|
||||
|
||||
def send_model_to_device(m):
|
||||
pass
|
||||
|
||||
|
||||
def send_model_to_trash(m):
|
||||
pass
|
||||
|
||||
|
||||
def instantiate_from_config(config, state_dict=None):
|
||||
pass
|
||||
|
||||
|
||||
def get_obj_from_str(string, reload=False):
|
||||
pass
|
||||
|
||||
|
||||
def load_model(checkpoint_info=None, already_loaded_state_dict=None):
|
||||
pass
|
||||
|
||||
|
||||
def reuse_model_from_already_loaded(sd_model, checkpoint_info, timer):
|
||||
pass
|
||||
|
||||
|
||||
def reload_model_weights(sd_model=None, info=None, forced_reload=False):
|
||||
pass
|
||||
|
||||
|
||||
def unload_model_weights(sd_model=None, info=None):
|
||||
memory_management.unload_all_models()
|
||||
return
|
||||
|
||||
|
||||
def apply_token_merging(sd_model, token_merging_ratio):
|
||||
if token_merging_ratio <= 0:
|
||||
return
|
||||
|
||||
print(f'token_merging_ratio = {token_merging_ratio}')
|
||||
|
||||
from backend.misc.tomesd import TomePatcher
|
||||
|
||||
sd_model.forge_objects.unet = TomePatcher().patch(
|
||||
model=sd_model.forge_objects.unet,
|
||||
ratio=token_merging_ratio
|
||||
)
|
||||
|
||||
return
|
||||
|
||||
|
||||
@torch.inference_mode()
|
||||
def forge_model_reload():
|
||||
current_hash = str(model_data.forge_loading_parameters)
|
||||
|
||||
if model_data.forge_hash == current_hash:
|
||||
return model_data.sd_model, False
|
||||
|
||||
print('Loading Model: ' + str(model_data.forge_loading_parameters))
|
||||
|
||||
timer = Timer()
|
||||
|
||||
if model_data.sd_model:
|
||||
model_data.sd_model = None
|
||||
memory_management.unload_all_models()
|
||||
memory_management.soft_empty_cache()
|
||||
gc.collect()
|
||||
|
||||
timer.record("unload existing model")
|
||||
|
||||
checkpoint_info = model_data.forge_loading_parameters['checkpoint_info']
|
||||
|
||||
if checkpoint_info is None:
|
||||
raise ValueError('You do not have any model! Please download at least one model in [models/Stable-diffusion].')
|
||||
|
||||
state_dict = checkpoint_info.filename
|
||||
additional_state_dicts = model_data.forge_loading_parameters.get('additional_modules', [])
|
||||
|
||||
timer.record("cache state dict")
|
||||
|
||||
dynamic_args['forge_unet_storage_dtype'] = model_data.forge_loading_parameters.get('unet_storage_dtype', None)
|
||||
dynamic_args['embedding_dir'] = cmd_opts.embeddings_dir
|
||||
dynamic_args['emphasis_name'] = opts.emphasis
|
||||
sd_model = forge_loader(state_dict, additional_state_dicts=additional_state_dicts)
|
||||
timer.record("forge model load")
|
||||
|
||||
sd_model.extra_generation_params = {}
|
||||
sd_model.comments = []
|
||||
sd_model.sd_checkpoint_info = checkpoint_info
|
||||
sd_model.filename = checkpoint_info.filename
|
||||
sd_model.sd_model_hash = checkpoint_info.calculate_shorthash()
|
||||
timer.record("calculate hash")
|
||||
|
||||
shared.opts.data["sd_checkpoint_hash"] = checkpoint_info.sha256
|
||||
|
||||
model_data.set_sd_model(sd_model)
|
||||
|
||||
script_callbacks.model_loaded_callback(sd_model)
|
||||
|
||||
timer.record("scripts callbacks")
|
||||
|
||||
print(f"Model loaded in {timer.summary()}.")
|
||||
|
||||
model_data.forge_hash = current_hash
|
||||
|
||||
return sd_model, True
137
modules/sd_models_config.py
Executable file
@@ -0,0 +1,137 @@
# import os
|
||||
#
|
||||
# import torch
|
||||
#
|
||||
# from modules import shared, paths, sd_disable_initialization, devices
|
||||
#
|
||||
# sd_configs_path = shared.sd_configs_path
|
||||
# # sd_repo_configs_path = os.path.join(paths.paths['Stable Diffusion'], "configs", "stable-diffusion")
|
||||
# # sd_xl_repo_configs_path = os.path.join(paths.paths['Stable Diffusion XL'], "configs", "inference")
|
||||
#
|
||||
#
|
||||
# config_default = shared.sd_default_config
|
||||
# # config_sd2 = os.path.join(sd_repo_configs_path, "v2-inference.yaml")
|
||||
# config_sd2v = os.path.join(sd_repo_configs_path, "v2-inference-v.yaml")
|
||||
# config_sd2_inpainting = os.path.join(sd_repo_configs_path, "v2-inpainting-inference.yaml")
|
||||
# config_sdxl = os.path.join(sd_xl_repo_configs_path, "sd_xl_base.yaml")
|
||||
# config_sdxl_refiner = os.path.join(sd_xl_repo_configs_path, "sd_xl_refiner.yaml")
|
||||
# config_sdxl_inpainting = os.path.join(sd_configs_path, "sd_xl_inpaint.yaml")
|
||||
# config_depth_model = os.path.join(sd_repo_configs_path, "v2-midas-inference.yaml")
|
||||
# config_unclip = os.path.join(sd_repo_configs_path, "v2-1-stable-unclip-l-inference.yaml")
|
||||
# config_unopenclip = os.path.join(sd_repo_configs_path, "v2-1-stable-unclip-h-inference.yaml")
|
||||
# config_inpainting = os.path.join(sd_configs_path, "v1-inpainting-inference.yaml")
|
||||
# config_instruct_pix2pix = os.path.join(sd_configs_path, "instruct-pix2pix.yaml")
|
||||
# config_alt_diffusion = os.path.join(sd_configs_path, "alt-diffusion-inference.yaml")
|
||||
# config_alt_diffusion_m18 = os.path.join(sd_configs_path, "alt-diffusion-m18-inference.yaml")
|
||||
# config_sd3 = os.path.join(sd_configs_path, "sd3-inference.yaml")
|
||||
#
|
||||
#
|
||||
# def is_using_v_parameterization_for_sd2(state_dict):
|
||||
# """
|
||||
# Detects whether unet in state_dict is using v-parameterization. Returns True if it is. You're welcome.
|
||||
# """
|
||||
#
|
||||
# import ldm.modules.diffusionmodules.openaimodel
|
||||
#
|
||||
# device = devices.device
|
||||
#
|
||||
# with sd_disable_initialization.DisableInitialization():
|
||||
# unet = ldm.modules.diffusionmodules.openaimodel.UNetModel(
|
||||
# use_checkpoint=False,
|
||||
# use_fp16=False,
|
||||
# image_size=32,
|
||||
# in_channels=4,
|
||||
# out_channels=4,
|
||||
# model_channels=320,
|
||||
# attention_resolutions=[4, 2, 1],
|
||||
# num_res_blocks=2,
|
||||
# channel_mult=[1, 2, 4, 4],
|
||||
# num_head_channels=64,
|
||||
# use_spatial_transformer=True,
|
||||
# use_linear_in_transformer=True,
|
||||
# transformer_depth=1,
|
||||
# context_dim=1024,
|
||||
# legacy=False
|
||||
# )
|
||||
# unet.eval()
|
||||
#
|
||||
# with torch.no_grad():
|
||||
# unet_sd = {k.replace("model.diffusion_model.", ""): v for k, v in state_dict.items() if "model.diffusion_model." in k}
|
||||
# unet.load_state_dict(unet_sd, strict=True)
|
||||
# unet.to(device=device, dtype=devices.dtype_unet)
|
||||
#
|
||||
# test_cond = torch.ones((1, 2, 1024), device=device) * 0.5
|
||||
# x_test = torch.ones((1, 4, 8, 8), device=device) * 0.5
|
||||
#
|
||||
# with devices.autocast():
|
||||
# out = (unet(x_test, torch.asarray([999], device=device), context=test_cond) - x_test).mean().cpu().item()
|
||||
#
|
||||
# return out < -1
|
||||
#
|
||||
#
|
||||
# def guess_model_config_from_state_dict(sd, filename):
|
||||
# sd2_cond_proj_weight = sd.get('cond_stage_model.model.transformer.resblocks.0.attn.in_proj_weight', None)
|
||||
# diffusion_model_input = sd.get('model.diffusion_model.input_blocks.0.0.weight', None)
|
||||
# sd2_variations_weight = sd.get('embedder.model.ln_final.weight', None)
|
||||
#
|
||||
# if "model.diffusion_model.x_embedder.proj.weight" in sd:
|
||||
# return config_sd3
|
||||
#
|
||||
# if sd.get('conditioner.embedders.1.model.ln_final.weight', None) is not None:
|
||||
# if diffusion_model_input.shape[1] == 9:
|
||||
# return config_sdxl_inpainting
|
||||
# else:
|
||||
# return config_sdxl
|
||||
#
|
||||
# if sd.get('conditioner.embedders.0.model.ln_final.weight', None) is not None:
|
||||
# return config_sdxl_refiner
|
||||
# elif sd.get('depth_model.model.pretrained.act_postprocess3.0.project.0.bias', None) is not None:
|
||||
# return config_depth_model
|
||||
# elif sd2_variations_weight is not None and sd2_variations_weight.shape[0] == 768:
|
||||
# return config_unclip
|
||||
# elif sd2_variations_weight is not None and sd2_variations_weight.shape[0] == 1024:
|
||||
# return config_unopenclip
|
||||
#
|
||||
# if sd2_cond_proj_weight is not None and sd2_cond_proj_weight.shape[1] == 1024:
|
||||
# if diffusion_model_input.shape[1] == 9:
|
||||
# return config_sd2_inpainting
|
||||
# # elif is_using_v_parameterization_for_sd2(sd):
|
||||
# # return config_sd2v
|
||||
# else:
|
||||
# return config_sd2v
|
||||
#
|
||||
# if diffusion_model_input is not None:
|
||||
# if diffusion_model_input.shape[1] == 9:
|
||||
# return config_inpainting
|
||||
# if diffusion_model_input.shape[1] == 8:
|
||||
# return config_instruct_pix2pix
|
||||
#
|
||||
# if sd.get('cond_stage_model.roberta.embeddings.word_embeddings.weight', None) is not None:
|
||||
# if sd.get('cond_stage_model.transformation.weight').size()[0] == 1024:
|
||||
# return config_alt_diffusion_m18
|
||||
# return config_alt_diffusion
|
||||
#
|
||||
# return config_default
|
||||
#
|
||||
#
|
||||
# def find_checkpoint_config(state_dict, info):
|
||||
# if info is None:
|
||||
# return guess_model_config_from_state_dict(state_dict, "")
|
||||
#
|
||||
# config = find_checkpoint_config_near_filename(info)
|
||||
# if config is not None:
|
||||
# return config
|
||||
#
|
||||
# return guess_model_config_from_state_dict(state_dict, info.filename)
|
||||
#
|
||||
#
|
||||
# def find_checkpoint_config_near_filename(info):
|
||||
# if info is None:
|
||||
# return None
|
||||
#
|
||||
# config = f"{os.path.splitext(info.filename)[0]}.yaml"
|
||||
# if os.path.exists(config):
|
||||
# return config
|
||||
#
|
||||
# return None
|
||||
#
39
modules/sd_models_types.py
Executable file
@@ -0,0 +1,39 @@
# from typing import TYPE_CHECKING
|
||||
#
|
||||
#
|
||||
# if TYPE_CHECKING:
|
||||
# from modules.sd_models import CheckpointInfo
|
||||
#
|
||||
#
|
||||
# class WebuiSdModel:
|
||||
# """This class is not actually instantinated, but its fields are created and fieeld by webui"""
|
||||
#
|
||||
# lowvram: bool
|
||||
# """True if lowvram/medvram optimizations are enabled -- see modules.lowvram for more info"""
|
||||
#
|
||||
# sd_model_hash: str
|
||||
# """short hash, 10 first characters of SHA1 hash of the model file; may be None if --no-hashing flag is used"""
|
||||
#
|
||||
# sd_model_checkpoint: str
|
||||
# """path to the file on disk that model weights were obtained from"""
|
||||
#
|
||||
# sd_checkpoint_info: 'CheckpointInfo'
|
||||
# """structure with additional information about the file with model's weights"""
|
||||
#
|
||||
# is_sdxl: bool
|
||||
# """True if the model's architecture is SDXL or SSD"""
|
||||
#
|
||||
# is_ssd: bool
|
||||
# """True if the model is SSD"""
|
||||
#
|
||||
# is_sd2: bool
|
||||
# """True if the model's architecture is SD 2.x"""
|
||||
#
|
||||
# is_sd1: bool
|
||||
# """True if the model's architecture is SD 1.x"""
|
||||
#
|
||||
# is_sd3: bool
|
||||
# """True if the model's architecture is SD 3"""
|
||||
#
|
||||
# latent_channels: int
|
||||
# """number of layer in latent image representation; will be 16 in SD3 and 4 in other version"""
|
||||
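The annotations above describe the attributes extensions typically read from shared.sd_model. A hedged sketch of how a script might branch on them (it assumes a model is already loaded and uses only the fields listed above):

from modules import shared

def describe_loaded_model() -> str:
    m = shared.sd_model
    arch = "SDXL" if m.is_sdxl else "SD2" if m.is_sd2 else "SD3" if m.is_sd3 else "SD1"
    # per the annotations above, latent_channels is 16 for SD3 and 4 otherwise
    return f"{arch} checkpoint at {m.sd_model_checkpoint} ({m.latent_channels} latent channels)"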
115
modules/sd_models_xl.py
Executable file
115
modules/sd_models_xl.py
Executable file
@@ -0,0 +1,115 @@
# from __future__ import annotations
#
# import torch
#
# import sgm.models.diffusion
# import sgm.modules.diffusionmodules.denoiser_scaling
# import sgm.modules.diffusionmodules.discretizer
# from modules import devices, shared, prompt_parser
# from modules import torch_utils
#
# from backend import memory_management
#
#
# def get_learned_conditioning(self: sgm.models.diffusion.DiffusionEngine, batch: prompt_parser.SdConditioning | list[str]):
#
#     for embedder in self.conditioner.embedders:
#         embedder.ucg_rate = 0.0
#
#     width = getattr(batch, 'width', 1024) or 1024
#     height = getattr(batch, 'height', 1024) or 1024
#     is_negative_prompt = getattr(batch, 'is_negative_prompt', False)
#     aesthetic_score = shared.opts.sdxl_refiner_low_aesthetic_score if is_negative_prompt else shared.opts.sdxl_refiner_high_aesthetic_score
#
#     devices_args = dict(device=self.forge_objects.clip.patcher.current_device, dtype=memory_management.text_encoder_dtype())
#
#     sdxl_conds = {
#         "txt": batch,
#         "original_size_as_tuple": torch.tensor([height, width], **devices_args).repeat(len(batch), 1),
#         "crop_coords_top_left": torch.tensor([shared.opts.sdxl_crop_top, shared.opts.sdxl_crop_left], **devices_args).repeat(len(batch), 1),
#         "target_size_as_tuple": torch.tensor([height, width], **devices_args).repeat(len(batch), 1),
#         "aesthetic_score": torch.tensor([aesthetic_score], **devices_args).repeat(len(batch), 1),
#     }
#
#     force_zero_negative_prompt = is_negative_prompt and all(x == '' for x in batch)
#     c = self.conditioner(sdxl_conds, force_zero_embeddings=['txt'] if force_zero_negative_prompt else [])
#
#     return c
#
#
# def apply_model(self: sgm.models.diffusion.DiffusionEngine, x, t, cond, *args, **kwargs):
#     if self.model.diffusion_model.in_channels == 9:
#         x = torch.cat([x] + cond['c_concat'], dim=1)
#
#     return self.model(x, t, cond, *args, **kwargs)
#
#
# def get_first_stage_encoding(self, x):  # SDXL's encode_first_stage does everything so get_first_stage_encoding is just there for compatibility
#     return x
#
#
# sgm.models.diffusion.DiffusionEngine.get_learned_conditioning = get_learned_conditioning
# sgm.models.diffusion.DiffusionEngine.apply_model = apply_model
# sgm.models.diffusion.DiffusionEngine.get_first_stage_encoding = get_first_stage_encoding
#
#
# def encode_embedding_init_text(self: sgm.modules.GeneralConditioner, init_text, nvpt):
#     res = []
#
#     for embedder in [embedder for embedder in self.embedders if hasattr(embedder, 'encode_embedding_init_text')]:
#         encoded = embedder.encode_embedding_init_text(init_text, nvpt)
#         res.append(encoded)
#
#     return torch.cat(res, dim=1)
#
#
# def tokenize(self: sgm.modules.GeneralConditioner, texts):
#     for embedder in [embedder for embedder in self.embedders if hasattr(embedder, 'tokenize')]:
#         return embedder.tokenize(texts)
#
#     raise AssertionError('no tokenizer available')
#
#
# def process_texts(self, texts):
#     for embedder in [embedder for embedder in self.embedders if hasattr(embedder, 'process_texts')]:
#         return embedder.process_texts(texts)
#
#
# def get_target_prompt_token_count(self, token_count):
#     for embedder in [embedder for embedder in self.embedders if hasattr(embedder, 'get_target_prompt_token_count')]:
#         return embedder.get_target_prompt_token_count(token_count)
#
#
# # those additions to GeneralConditioner make it possible to use it as model.cond_stage_model from SD1.5 in existing code
# sgm.modules.GeneralConditioner.encode_embedding_init_text = encode_embedding_init_text
# sgm.modules.GeneralConditioner.tokenize = tokenize
# sgm.modules.GeneralConditioner.process_texts = process_texts
# sgm.modules.GeneralConditioner.get_target_prompt_token_count = get_target_prompt_token_count
#
#
# def extend_sdxl(model):
#     """this adds a bunch of parameters to make SDXL model look a bit more like SD1.5 to the rest of the codebase."""
#
#     dtype = torch_utils.get_param(model.model.diffusion_model).dtype
#     model.model.diffusion_model.dtype = dtype
#     model.model.conditioning_key = 'crossattn'
#     model.cond_stage_key = 'txt'
#     # model.cond_stage_model will be set in sd_hijack
#
#     model.parameterization = "v" if isinstance(model.denoiser.scaling, sgm.modules.diffusionmodules.denoiser_scaling.VScaling) else "eps"
#
#     discretization = sgm.modules.diffusionmodules.discretizer.LegacyDDPMDiscretization()
#     model.alphas_cumprod = torch.asarray(discretization.alphas_cumprod, device=devices.device, dtype=torch.float32)
#
#     model.conditioner.wrapped = torch.nn.Module()
#
#
# sgm.modules.attention.print = shared.ldm_print
# sgm.modules.diffusionmodules.model.print = shared.ldm_print
# sgm.modules.diffusionmodules.openaimodel.print = shared.ldm_print
# sgm.modules.encoders.modules.print = shared.ldm_print
#
# # this gets the code to load the vanilla attention that we override
# sgm.modules.attention.SDP_IS_AVAILABLE = True
# sgm.modules.attention.XFORMERS_IS_AVAILABLE = False
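Although disabled here, the module above illustrates the pattern used to adapt sgm's SDXL classes: plain functions are written at module level and then assigned onto the target classes as methods. A generic sketch of that monkey-patching pattern, independent of sgm (class and function names are illustrative only):

class Engine:
    pass

def get_first_stage_encoding(self, x):
    # identity shim kept only for API compatibility, mirroring the commented code above
    return x

Engine.get_first_stage_encoding = get_first_stage_encoding  # existing instances pick the method up too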
144
modules/sd_samplers.py
Executable file
144
modules/sd_samplers.py
Executable file
@@ -0,0 +1,144 @@
|
||||
from __future__ import annotations
|
||||
|
||||
import functools
|
||||
import logging
|
||||
from modules import sd_samplers_kdiffusion, sd_samplers_timesteps, sd_samplers_lcm, shared, sd_samplers_common, sd_schedulers
|
||||
|
||||
# imports for functions that previously were here and are used by other modules
|
||||
from modules.sd_samplers_common import samples_to_image_grid, sample_to_image # noqa: F401
|
||||
from modules_forge import alter_samplers
|
||||
|
||||
all_samplers = [
|
||||
*sd_samplers_kdiffusion.samplers_data_k_diffusion,
|
||||
*sd_samplers_timesteps.samplers_data_timesteps,
|
||||
*sd_samplers_lcm.samplers_data_lcm,
|
||||
*alter_samplers.samplers_data_alter
|
||||
]
|
||||
all_samplers_map = {x.name: x for x in all_samplers}
|
||||
|
||||
samplers: list[sd_samplers_common.SamplerData] = []
|
||||
samplers_for_img2img: list[sd_samplers_common.SamplerData] = []
|
||||
samplers_map = {}
|
||||
samplers_hidden = {}
|
||||
|
||||
|
||||
def find_sampler_config(name):
|
||||
if name is not None:
|
||||
config = all_samplers_map.get(name, None)
|
||||
else:
|
||||
config = all_samplers[0]
|
||||
|
||||
return config
|
||||
|
||||
|
||||
def create_sampler(name, model):
|
||||
config = find_sampler_config(name)
|
||||
|
||||
assert config is not None, f'bad sampler name: {name}'
|
||||
|
||||
if model.is_sdxl and config.options.get("no_sdxl", False):
|
||||
raise Exception(f"Sampler {config.name} is not supported for SDXL")
|
||||
|
||||
sampler = config.constructor(model)
|
||||
sampler.config = config
|
||||
|
||||
return sampler
|
||||
|
||||
|
||||
def set_samplers():
|
||||
global samplers, samplers_for_img2img, samplers_hidden
|
||||
|
||||
samplers_hidden = set(shared.opts.hide_samplers)
|
||||
samplers = all_samplers
|
||||
samplers_for_img2img = all_samplers
|
||||
|
||||
samplers_map.clear()
|
||||
for sampler in all_samplers:
|
||||
samplers_map[sampler.name.lower()] = sampler.name
|
||||
for alias in sampler.aliases:
|
||||
samplers_map[alias.lower()] = sampler.name
|
||||
|
||||
return
|
||||
|
||||
|
||||
def add_sampler(sampler):
|
||||
global all_samplers, all_samplers_map
|
||||
if sampler.name not in [x.name for x in all_samplers]:
|
||||
all_samplers.append(sampler)
|
||||
all_samplers_map = {x.name: x for x in all_samplers}
|
||||
set_samplers()
|
||||
return
|
||||
|
||||
|
||||
def visible_sampler_names():
|
||||
return [x.name for x in samplers if x.name not in samplers_hidden]
|
||||
|
||||
|
||||
def visible_samplers():
|
||||
return [x for x in samplers if x.name not in samplers_hidden]
|
||||
|
||||
|
||||
def get_sampler_from_infotext(d: dict):
|
||||
return get_sampler_and_scheduler(d.get("Sampler"), d.get("Schedule type"))[0]
|
||||
|
||||
|
||||
def get_scheduler_from_infotext(d: dict):
|
||||
return get_sampler_and_scheduler(d.get("Sampler"), d.get("Schedule type"))[1]
|
||||
|
||||
|
||||
def get_hr_sampler_and_scheduler(d: dict):
|
||||
hr_sampler = d.get("Hires sampler", "Use same sampler")
|
||||
sampler = d.get("Sampler") if hr_sampler == "Use same sampler" else hr_sampler
|
||||
|
||||
hr_scheduler = d.get("Hires schedule type", "Use same scheduler")
|
||||
scheduler = d.get("Schedule type") if hr_scheduler == "Use same scheduler" else hr_scheduler
|
||||
|
||||
sampler, scheduler = get_sampler_and_scheduler(sampler, scheduler)
|
||||
|
||||
sampler = sampler if sampler != d.get("Sampler") else "Use same sampler"
|
||||
scheduler = scheduler if scheduler != d.get("Schedule type") else "Use same scheduler"
|
||||
|
||||
return sampler, scheduler
|
||||
|
||||
|
||||
def get_hr_sampler_from_infotext(d: dict):
|
||||
return get_hr_sampler_and_scheduler(d)[0]
|
||||
|
||||
|
||||
def get_hr_scheduler_from_infotext(d: dict):
|
||||
return get_hr_sampler_and_scheduler(d)[1]
|
||||
|
||||
|
||||
@functools.cache
|
||||
def get_sampler_and_scheduler(sampler_name, scheduler_name, *, convert_automatic=True):
|
||||
default_sampler = samplers[0]
|
||||
found_scheduler = sd_schedulers.schedulers_map.get(scheduler_name, sd_schedulers.schedulers[0])
|
||||
|
||||
name = sampler_name or default_sampler.name
|
||||
|
||||
for scheduler in sd_schedulers.schedulers:
|
||||
name_options = [scheduler.label, scheduler.name, *(scheduler.aliases or [])]
|
||||
|
||||
for name_option in name_options:
|
||||
if name.endswith(" " + name_option):
|
||||
found_scheduler = scheduler
|
||||
name = name[0:-(len(name_option) + 1)]
|
||||
break
|
||||
|
||||
sampler = all_samplers_map.get(name, default_sampler)
|
||||
|
||||
# revert back to Automatic if it's the default scheduler for the selected sampler
|
||||
if convert_automatic and sampler.options.get('scheduler', None) == found_scheduler.name:
|
||||
found_scheduler = sd_schedulers.schedulers[0]
|
||||
|
||||
return sampler.name, found_scheduler.label
|
||||
|
||||
|
||||
def fix_p_invalid_sampler_and_scheduler(p):
|
||||
i_sampler_name, i_scheduler = p.sampler_name, p.scheduler
|
||||
p.sampler_name, p.scheduler = get_sampler_and_scheduler(p.sampler_name, p.scheduler, convert_automatic=False)
|
||||
if p.sampler_name != i_sampler_name or i_scheduler != p.scheduler:
|
||||
logging.warning(f'Sampler Scheduler autocorrection: "{i_sampler_name}" -> "{p.sampler_name}", "{i_scheduler}" -> "{p.scheduler}"')
|
||||
|
||||
|
||||
set_samplers()
|
||||
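get_sampler_and_scheduler above strips a trailing scheduler label from a combined infotext name and reverts the scheduler to Automatic when it is already the sampler's default. A hedged usage sketch (it assumes the module has been imported normally, so set_samplers() above has already run):

from modules import sd_samplers

# "DPM++ 2M Karras" splits into the "DPM++ 2M" sampler and the "Karras" scheduler;
# because Karras is that sampler's default scheduler, the result likely reverts to "Automatic".
name, schedule = sd_samplers.get_sampler_and_scheduler("DPM++ 2M Karras", None)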
228
modules/sd_samplers_cfg_denoiser.py
Executable file
228
modules/sd_samplers_cfg_denoiser.py
Executable file
@@ -0,0 +1,228 @@
|
||||
import torch
|
||||
from modules import prompt_parser, sd_samplers_common
|
||||
|
||||
from modules.shared import opts, state
|
||||
import modules.shared as shared
|
||||
from modules.script_callbacks import CFGDenoiserParams, cfg_denoiser_callback
|
||||
from modules.script_callbacks import CFGDenoisedParams, cfg_denoised_callback
|
||||
from modules.script_callbacks import AfterCFGCallbackParams, cfg_after_cfg_callback
|
||||
from backend.sampling.sampling_function import sampling_function
|
||||
|
||||
|
||||
def catenate_conds(conds):
|
||||
if not isinstance(conds[0], dict):
|
||||
return torch.cat(conds)
|
||||
|
||||
return {key: torch.cat([x[key] for x in conds]) for key in conds[0].keys()}
|
||||
|
||||
|
||||
def subscript_cond(cond, a, b):
|
||||
if not isinstance(cond, dict):
|
||||
return cond[a:b]
|
||||
|
||||
return {key: vec[a:b] for key, vec in cond.items()}
|
||||
|
||||
|
||||
def pad_cond(tensor, repeats, empty):
|
||||
if not isinstance(tensor, dict):
|
||||
return torch.cat([tensor, empty.repeat((tensor.shape[0], repeats, 1))], axis=1)
|
||||
|
||||
tensor['crossattn'] = pad_cond(tensor['crossattn'], repeats, empty)
|
||||
return tensor
|
||||
|
||||
|
||||
class CFGDenoiser(torch.nn.Module):
|
||||
"""
|
||||
Classifier free guidance denoiser. A wrapper for stable diffusion model (specifically for unet)
|
||||
that can take a noisy picture and produce a noise-free picture using two guidances (prompts)
|
||||
instead of one. Originally, the second prompt is just an empty string, but we use non-empty
|
||||
negative prompt.
|
||||
"""
|
||||
|
||||
def __init__(self, sampler):
|
||||
super().__init__()
|
||||
self.model_wrap = None
|
||||
self.mask = None
|
||||
self.nmask = None
|
||||
self.init_latent = None
|
||||
self.steps = None
|
||||
"""number of steps as specified by user in UI"""
|
||||
|
||||
self.total_steps = None
|
||||
"""expected number of calls to denoiser calculated from self.steps and specifics of the selected sampler"""
|
||||
|
||||
self.step = 0
|
||||
self.image_cfg_scale = None
|
||||
self.padded_cond_uncond = False
|
||||
self.padded_cond_uncond_v0 = False
|
||||
self.sampler = sampler
|
||||
self.model_wrap = None
|
||||
self.p = None
|
||||
|
||||
self.need_last_noise_uncond = False
|
||||
self.last_noise_uncond = None
|
||||
|
||||
# Backward Compatibility
|
||||
self.mask_before_denoising = False
|
||||
|
||||
self.classic_ddim_eps_estimation = False
|
||||
|
||||
@property
|
||||
def inner_model(self):
|
||||
raise NotImplementedError()
|
||||
|
||||
def combine_denoised(self, x_out, conds_list, uncond, cond_scale, timestep, x_in, cond):
|
||||
denoised_uncond = x_out[-uncond.shape[0]:]
|
||||
denoised = torch.clone(denoised_uncond)
|
||||
|
||||
for i, conds in enumerate(conds_list):
|
||||
for cond_index, weight in conds:
|
||||
denoised[i] += (x_out[cond_index] - denoised_uncond[i]) * (weight * cond_scale)
|
||||
|
||||
return denoised
|
||||
|
||||
def combine_denoised_for_edit_model(self, x_out, cond_scale):
|
||||
out_cond, out_img_cond, out_uncond = x_out.chunk(3)
|
||||
denoised = out_uncond + cond_scale * (out_cond - out_img_cond) + self.image_cfg_scale * (out_img_cond - out_uncond)
|
||||
|
||||
return denoised
|
||||
|
||||
def get_pred_x0(self, x_in, x_out, sigma):
|
||||
return x_out
|
||||
|
||||
def update_inner_model(self):
|
||||
self.model_wrap = None
|
||||
|
||||
c, uc = self.p.get_conds()
|
||||
self.sampler.sampler_extra_args['cond'] = c
|
||||
self.sampler.sampler_extra_args['uncond'] = uc
|
||||
|
||||
def pad_cond_uncond(self, cond, uncond):
|
||||
empty = shared.sd_model.cond_stage_model_empty_prompt
|
||||
num_repeats = (cond.shape[1] - uncond.shape[1]) // empty.shape[1]
|
||||
|
||||
if num_repeats < 0:
|
||||
cond = pad_cond(cond, -num_repeats, empty)
|
||||
self.padded_cond_uncond = True
|
||||
elif num_repeats > 0:
|
||||
uncond = pad_cond(uncond, num_repeats, empty)
|
||||
self.padded_cond_uncond = True
|
||||
|
||||
return cond, uncond
|
||||
|
||||
def pad_cond_uncond_v0(self, cond, uncond):
|
||||
"""
|
||||
Pads the 'uncond' tensor to match the shape of the 'cond' tensor.
|
||||
|
||||
If 'uncond' is a dictionary, it is assumed that the 'crossattn' key holds the tensor to be padded.
|
||||
If 'uncond' is a tensor, it is padded directly.
|
||||
|
||||
If the number of columns in 'uncond' is less than the number of columns in 'cond', the last column of 'uncond'
|
||||
is repeated to match the number of columns in 'cond'.
|
||||
|
||||
If the number of columns in 'uncond' is greater than the number of columns in 'cond', 'uncond' is truncated
|
||||
to match the number of columns in 'cond'.
|
||||
|
||||
Args:
|
||||
cond (torch.Tensor or DictWithShape): The condition tensor to match the shape of 'uncond'.
|
||||
uncond (torch.Tensor or DictWithShape): The tensor to be padded, or a dictionary containing the tensor to be padded.
|
||||
|
||||
Returns:
|
||||
tuple: A tuple containing the 'cond' tensor and the padded 'uncond' tensor.
|
||||
|
||||
Note:
|
||||
This is the padding that was always used in DDIM before version 1.6.0
|
||||
"""
|
||||
|
||||
is_dict_cond = isinstance(uncond, dict)
|
||||
uncond_vec = uncond['crossattn'] if is_dict_cond else uncond
|
||||
|
||||
if uncond_vec.shape[1] < cond.shape[1]:
|
||||
last_vector = uncond_vec[:, -1:]
|
||||
last_vector_repeated = last_vector.repeat([1, cond.shape[1] - uncond_vec.shape[1], 1])
|
||||
uncond_vec = torch.hstack([uncond_vec, last_vector_repeated])
|
||||
self.padded_cond_uncond_v0 = True
|
||||
elif uncond_vec.shape[1] > cond.shape[1]:
|
||||
uncond_vec = uncond_vec[:, :cond.shape[1]]
|
||||
self.padded_cond_uncond_v0 = True
|
||||
|
||||
if is_dict_cond:
|
||||
uncond['crossattn'] = uncond_vec
|
||||
else:
|
||||
uncond = uncond_vec
|
||||
|
||||
return cond, uncond
|
||||
|
||||
def forward(self, x, sigma, uncond, cond, cond_scale, s_min_uncond, image_cond):
|
||||
if state.interrupted or state.skipped:
|
||||
raise sd_samplers_common.InterruptedException
|
||||
|
||||
original_x_device = x.device
|
||||
original_x_dtype = x.dtype
|
||||
|
||||
if self.classic_ddim_eps_estimation:
|
||||
acd = self.inner_model.inner_model.alphas_cumprod
|
||||
fake_sigmas = ((1 - acd) / acd) ** 0.5
|
||||
real_sigma = fake_sigmas[sigma.round().long().clip(0, int(fake_sigmas.shape[0]))]
|
||||
real_sigma_data = 1.0
|
||||
x = x * (((real_sigma ** 2.0 + real_sigma_data ** 2.0) ** 0.5)[:, None, None, None])
|
||||
sigma = real_sigma
|
||||
|
||||
if sd_samplers_common.apply_refiner(self, x):
|
||||
cond = self.sampler.sampler_extra_args['cond']
|
||||
uncond = self.sampler.sampler_extra_args['uncond']
|
||||
|
||||
cond_composition, cond = prompt_parser.reconstruct_multicond_batch(cond, self.step)
|
||||
uncond = prompt_parser.reconstruct_cond_batch(uncond, self.step) if uncond is not None else None
|
||||
|
||||
if self.mask is not None:
|
||||
predictor = self.inner_model.inner_model.forge_objects.unet.model.predictor
|
||||
noisy_initial_latent = predictor.noise_scaling(sigma[:, None, None, None], torch.randn_like(self.init_latent).to(self.init_latent), self.init_latent, max_denoise=False)
|
||||
x = x * self.nmask + noisy_initial_latent * self.mask
|
||||
|
||||
denoiser_params = CFGDenoiserParams(x, image_cond, sigma, state.sampling_step, state.sampling_steps, cond, uncond, self)
|
||||
cfg_denoiser_callback(denoiser_params)
|
||||
|
||||
# NGMS
|
||||
if self.p.is_hr_pass:
|
||||
cond_scale = self.p.hr_cfg
|
||||
|
||||
if shared.opts.skip_early_cond > 0 and self.step / self.total_steps <= shared.opts.skip_early_cond:
|
||||
cond_scale = 1.0
|
||||
self.p.extra_generation_params["Skip Early CFG"] = shared.opts.skip_early_cond
|
||||
elif (self.step % 2 or shared.opts.s_min_uncond_all) and s_min_uncond > 0 and sigma[0] < s_min_uncond:
|
||||
cond_scale = 1.0
|
||||
self.p.extra_generation_params["NGMS"] = s_min_uncond
|
||||
if shared.opts.s_min_uncond_all:
|
||||
self.p.extra_generation_params["NGMS all steps"] = shared.opts.s_min_uncond_all
|
||||
|
||||
denoised, cond_pred, uncond_pred = sampling_function(self, denoiser_params=denoiser_params, cond_scale=cond_scale, cond_composition=cond_composition)
|
||||
|
||||
if self.need_last_noise_uncond:
|
||||
self.last_noise_uncond = (x - uncond_pred) / sigma[:, None, None, None]
|
||||
|
||||
if self.mask is not None:
|
||||
blended_latent = denoised * self.nmask + self.init_latent * self.mask
|
||||
|
||||
if self.p.scripts is not None:
|
||||
from modules import scripts
|
||||
mba = scripts.MaskBlendArgs(denoised, self.nmask, self.init_latent, self.mask, blended_latent, denoiser=self, sigma=sigma)
|
||||
self.p.scripts.on_mask_blend(self.p, mba)
|
||||
blended_latent = mba.blended_latent
|
||||
|
||||
denoised = blended_latent
|
||||
|
||||
preview = self.sampler.last_latent = denoised
|
||||
sd_samplers_common.store_latent(preview)
|
||||
|
||||
after_cfg_callback_params = AfterCFGCallbackParams(denoised, state.sampling_step, state.sampling_steps)
|
||||
cfg_after_cfg_callback(after_cfg_callback_params)
|
||||
denoised = after_cfg_callback_params.x
|
||||
|
||||
self.step += 1
|
||||
|
||||
if self.classic_ddim_eps_estimation:
|
||||
eps = (x - denoised) / sigma[:, None, None, None]
|
||||
return eps
|
||||
|
||||
return denoised.to(device=original_x_device, dtype=original_x_dtype)
|
||||
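combine_denoised above is the standard classifier-free-guidance mix: start from the unconditional prediction and add the weighted difference towards each conditional prediction. With a single prompt of weight 1 it reduces to the textbook formula; a minimal tensor sketch of that special case:

import torch

def cfg_mix(cond_pred: torch.Tensor, uncond_pred: torch.Tensor, cond_scale: float) -> torch.Tensor:
    # denoised = uncond + scale * (cond - uncond), the single-prompt case of combine_denoised above
    return uncond_pred + cond_scale * (cond_pred - uncond_pred)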
364
modules/sd_samplers_common.py
Executable file
364
modules/sd_samplers_common.py
Executable file
@@ -0,0 +1,364 @@
|
||||
import inspect
|
||||
from collections import namedtuple
|
||||
import numpy as np
|
||||
import torch
|
||||
from PIL import Image
|
||||
from modules import devices, images, sd_vae_approx, sd_samplers, sd_vae_taesd, shared, sd_models
|
||||
from modules.shared import opts, state
|
||||
from backend.sampling.sampling_function import sampling_prepare, sampling_cleanup
|
||||
from modules import extra_networks
|
||||
import k_diffusion.sampling
|
||||
from modules_forge import main_entry
|
||||
|
||||
SamplerDataTuple = namedtuple('SamplerData', ['name', 'constructor', 'aliases', 'options'])
|
||||
|
||||
|
||||
class SamplerData(SamplerDataTuple):
|
||||
def total_steps(self, steps):
|
||||
if self.options.get("second_order", False):
|
||||
steps = steps * 2
|
||||
|
||||
return steps
|
||||
|
||||
|
||||
def setup_img2img_steps(p, steps=None):
|
||||
if opts.img2img_fix_steps or steps is not None:
|
||||
requested_steps = (steps or p.steps)
|
||||
steps = int(requested_steps / min(p.denoising_strength, 0.999)) if p.denoising_strength > 0 else 0
|
||||
t_enc = requested_steps - 1
|
||||
else:
|
||||
steps = p.steps
|
||||
t_enc = int(min(p.denoising_strength, 0.999) * steps)
|
||||
|
||||
return steps, t_enc
|
||||
|
||||
|
||||
approximation_indexes = {"Full": 0, "Approx NN": 1, "Approx cheap": 2, "TAESD": 3}
|
||||
|
||||
|
||||
def samples_to_images_tensor(sample, approximation=None, model=None):
|
||||
"""Transforms 4-channel latent space images into 3-channel RGB image tensors, with values in range [-1, 1]."""
|
||||
|
||||
if approximation is None or (shared.state.interrupted and opts.live_preview_fast_interrupt):
|
||||
approximation = approximation_indexes.get(opts.show_progress_type, 0)
|
||||
if approximation == 0:
|
||||
approximation = 1
|
||||
|
||||
if approximation == 2:
|
||||
x_sample = sd_vae_approx.cheap_approximation(sample)
|
||||
elif approximation == 1:
|
||||
m = sd_vae_approx.model()
|
||||
if m is None:
|
||||
x_sample = sd_vae_approx.cheap_approximation(sample)
|
||||
else:
|
||||
x_sample = m(sample.to(devices.device, devices.dtype)).detach()
|
||||
elif approximation == 3:
|
||||
m = sd_vae_taesd.decoder_model()
|
||||
if m is None:
|
||||
x_sample = sd_vae_approx.cheap_approximation(sample)
|
||||
else:
|
||||
x_sample = m(sample.to(devices.device, devices.dtype)).detach()
|
||||
x_sample = x_sample * 2 - 1
|
||||
else:
|
||||
if model is None:
|
||||
model = shared.sd_model
|
||||
x_sample = model.decode_first_stage(sample)
|
||||
|
||||
return x_sample
|
||||
|
||||
|
||||
def single_sample_to_image(sample, approximation=None):
|
||||
x_sample = samples_to_images_tensor(sample.unsqueeze(0), approximation)[0] * 0.5 + 0.5
|
||||
|
||||
x_sample = x_sample.cpu()
|
||||
x_sample.clamp_(0.0, 1.0)
|
||||
x_sample.mul_(255.)
|
||||
x_sample.round_()
|
||||
x_sample = x_sample.to(torch.uint8)
|
||||
x_sample = np.moveaxis(x_sample.numpy(), 0, 2)
|
||||
|
||||
return Image.fromarray(x_sample)
|
||||
|
||||
|
||||
def decode_first_stage(model, x):
|
||||
approx_index = approximation_indexes.get(opts.sd_vae_decode_method, 0)
|
||||
return samples_to_images_tensor(x, approx_index, model)
|
||||
|
||||
|
||||
def sample_to_image(samples, index=0, approximation=None):
|
||||
return single_sample_to_image(samples[index], approximation)
|
||||
|
||||
|
||||
def samples_to_image_grid(samples, approximation=None):
|
||||
return images.image_grid([single_sample_to_image(sample, approximation) for sample in samples])
|
||||
|
||||
|
||||
def images_tensor_to_samples(image, approximation=None, model=None):
|
||||
'''image[0, 1] -> latent'''
|
||||
if approximation is None:
|
||||
approximation = approximation_indexes.get(opts.sd_vae_encode_method, 0)
|
||||
|
||||
if approximation == 3:
|
||||
image = image.to(devices.device, devices.dtype)
|
||||
x_latent = sd_vae_taesd.encoder_model()(image)
|
||||
else:
|
||||
if model is None:
|
||||
model = shared.sd_model
|
||||
|
||||
image = image.to(shared.device, dtype=devices.dtype_vae)
|
||||
image = image * 2 - 1
|
||||
if len(image) > 1:
|
||||
x_latent = torch.stack([
|
||||
model.get_first_stage_encoding(
|
||||
model.encode_first_stage(torch.unsqueeze(img, 0))
|
||||
)[0]
|
||||
for img in image
|
||||
])
|
||||
else:
|
||||
x_latent = model.get_first_stage_encoding(model.encode_first_stage(image))
|
||||
|
||||
return x_latent
|
||||
|
||||
|
||||
def store_latent(decoded):
|
||||
state.current_latent = decoded
|
||||
|
||||
if opts.live_previews_enable and opts.show_progress_every_n_steps > 0 and shared.state.sampling_step % opts.show_progress_every_n_steps == 0:
|
||||
if not shared.parallel_processing_allowed:
|
||||
shared.state.assign_current_image(sample_to_image(decoded))
|
||||
|
||||
|
||||
def is_sampler_using_eta_noise_seed_delta(p):
|
||||
"""returns whether sampler from config will use eta noise seed delta for image creation"""
|
||||
|
||||
sampler_config = sd_samplers.find_sampler_config(p.sampler_name)
|
||||
|
||||
eta = p.eta
|
||||
|
||||
if eta is None and p.sampler is not None:
|
||||
eta = p.sampler.eta
|
||||
|
||||
if eta is None and sampler_config is not None:
|
||||
eta = 0 if sampler_config.options.get("default_eta_is_0", False) else 1.0
|
||||
|
||||
if eta == 0:
|
||||
return False
|
||||
|
||||
return sampler_config.options.get("uses_ensd", False)
|
||||
|
||||
|
||||
class InterruptedException(BaseException):
|
||||
pass
|
||||
|
||||
|
||||
def replace_torchsde_browinan():
|
||||
import torchsde._brownian.brownian_interval
|
||||
|
||||
def torchsde_randn(size, dtype, device, seed):
|
||||
return devices.randn_local(seed, size).to(device=device, dtype=dtype)
|
||||
|
||||
torchsde._brownian.brownian_interval._randn = torchsde_randn
|
||||
|
||||
|
||||
replace_torchsde_browinan()
|
||||
|
||||
|
||||
def apply_refiner(cfg_denoiser, x):
|
||||
completed_ratio = cfg_denoiser.step / cfg_denoiser.total_steps
|
||||
refiner_switch_at = cfg_denoiser.p.refiner_switch_at
|
||||
refiner_checkpoint_info = cfg_denoiser.p.refiner_checkpoint_info
|
||||
|
||||
if refiner_switch_at is not None and completed_ratio < refiner_switch_at:
|
||||
return False
|
||||
|
||||
if refiner_checkpoint_info is None or shared.sd_model.sd_checkpoint_info == refiner_checkpoint_info:
|
||||
return False
|
||||
|
||||
if getattr(cfg_denoiser.p, "enable_hr", False):
|
||||
is_second_pass = cfg_denoiser.p.is_hr_pass
|
||||
|
||||
if opts.hires_fix_refiner_pass == "first pass" and is_second_pass:
|
||||
return False
|
||||
|
||||
if opts.hires_fix_refiner_pass == "second pass" and not is_second_pass:
|
||||
return False
|
||||
|
||||
if opts.hires_fix_refiner_pass != "second pass":
|
||||
cfg_denoiser.p.extra_generation_params['Hires refiner'] = opts.hires_fix_refiner_pass
|
||||
|
||||
cfg_denoiser.p.extra_generation_params['Refiner'] = refiner_checkpoint_info.short_title
|
||||
cfg_denoiser.p.extra_generation_params['Refiner switch at'] = refiner_switch_at
|
||||
|
||||
sampling_cleanup(sd_models.model_data.get_sd_model().forge_objects.unet)
|
||||
|
||||
with sd_models.SkipWritingToConfig():
|
||||
fp_checkpoint = getattr(shared.opts, 'sd_model_checkpoint')
|
||||
checkpoint_changed = main_entry.checkpoint_change(refiner_checkpoint_info.short_title, save=False, refresh=False)
|
||||
if checkpoint_changed:
|
||||
try:
|
||||
main_entry.refresh_model_loading_parameters()
|
||||
sd_models.forge_model_reload()
|
||||
finally:
|
||||
main_entry.checkpoint_change(fp_checkpoint, save=False, refresh=True)
|
||||
|
||||
if not cfg_denoiser.p.disable_extra_networks:
|
||||
extra_networks.activate(cfg_denoiser.p, cfg_denoiser.p.extra_network_data)
|
||||
|
||||
cfg_denoiser.p.setup_conds()
|
||||
cfg_denoiser.update_inner_model()
|
||||
|
||||
sampling_prepare(sd_models.model_data.get_sd_model().forge_objects.unet, x=x)
|
||||
return True
|
||||
|
||||
|
||||
class TorchHijack:
|
||||
"""This is here to replace torch.randn_like of k-diffusion.
|
||||
|
||||
k-diffusion has random_sampler argument for most samplers, but not for all, so
|
||||
this is needed to properly replace every use of torch.randn_like.
|
||||
|
||||
We need to replace to make images generated in batches to be same as images generated individually."""
|
||||
|
||||
def __init__(self, p):
|
||||
self.rng = p.rng
|
||||
|
||||
def __getattr__(self, item):
|
||||
if item == 'randn_like':
|
||||
return self.randn_like
|
||||
|
||||
if hasattr(torch, item):
|
||||
return getattr(torch, item)
|
||||
|
||||
raise AttributeError(f"'{type(self).__name__}' object has no attribute '{item}'")
|
||||
|
||||
def randn_like(self, x):
|
||||
return self.rng.next()
|
||||
|
||||
|
||||
class Sampler:
|
||||
def __init__(self, funcname):
|
||||
self.funcname = funcname
|
||||
self.func = funcname
|
||||
self.extra_params = []
|
||||
self.sampler_noises = None
|
||||
self.stop_at = None
|
||||
self.eta = None
|
||||
self.config: SamplerData = None # set by the function calling the constructor
|
||||
self.last_latent = None
|
||||
self.s_min_uncond = None
|
||||
self.s_churn = 0.0
|
||||
self.s_tmin = 0.0
|
||||
self.s_tmax = float('inf')
|
||||
self.s_noise = 1.0
|
||||
|
||||
self.eta_option_field = 'eta_ancestral'
|
||||
self.eta_infotext_field = 'Eta'
|
||||
self.eta_default = 1.0
|
||||
|
||||
self.conditioning_key = 'crossattn'
|
||||
|
||||
self.p = None
|
||||
self.model_wrap_cfg = None
|
||||
self.sampler_extra_args = None
|
||||
self.options = {}
|
||||
|
||||
def callback_state(self, d):
|
||||
step = d['i']
|
||||
|
||||
if self.stop_at is not None and step > self.stop_at:
|
||||
raise InterruptedException
|
||||
|
||||
state.sampling_step = step
|
||||
shared.total_tqdm.update()
|
||||
|
||||
def launch_sampling(self, steps, func):
|
||||
self.model_wrap_cfg.steps = steps
|
||||
self.model_wrap_cfg.total_steps = self.config.total_steps(steps)
|
||||
state.sampling_steps = steps
|
||||
state.sampling_step = 0
|
||||
|
||||
try:
|
||||
return func()
|
||||
except RecursionError:
|
||||
print(
|
||||
'Encountered RecursionError during sampling, returning last latent. '
|
||||
'rho >5 with a polyexponential scheduler may cause this error. '
|
||||
'You should try to use a smaller rho value instead.'
|
||||
)
|
||||
return self.last_latent
|
||||
except InterruptedException:
|
||||
return self.last_latent
|
||||
|
||||
def number_of_needed_noises(self, p):
|
||||
return p.steps
|
||||
|
||||
def initialize(self, p) -> dict:
|
||||
self.p = p
|
||||
self.model_wrap_cfg.p = p
|
||||
self.model_wrap_cfg.mask = p.mask if hasattr(p, 'mask') else None
|
||||
self.model_wrap_cfg.nmask = p.nmask if hasattr(p, 'nmask') else None
|
||||
self.model_wrap_cfg.step = 0
|
||||
self.model_wrap_cfg.image_cfg_scale = getattr(p, 'image_cfg_scale', None)
|
||||
self.eta = p.eta if p.eta is not None else getattr(opts, self.eta_option_field, 0.0)
|
||||
self.s_min_uncond = getattr(p, 's_min_uncond', 0.0)
|
||||
|
||||
k_diffusion.sampling.torch = TorchHijack(p)
|
||||
|
||||
extra_params_kwargs = {}
|
||||
for param_name in self.extra_params:
|
||||
if hasattr(p, param_name) and param_name in inspect.signature(self.func).parameters:
|
||||
extra_params_kwargs[param_name] = getattr(p, param_name)
|
||||
|
||||
if 'eta' in inspect.signature(self.func).parameters:
|
||||
if self.eta != self.eta_default:
|
||||
p.extra_generation_params[self.eta_infotext_field] = self.eta
|
||||
|
||||
extra_params_kwargs['eta'] = self.eta
|
||||
|
||||
if len(self.extra_params) > 0:
|
||||
s_churn = getattr(opts, 's_churn', p.s_churn)
|
||||
s_tmin = getattr(opts, 's_tmin', p.s_tmin)
|
||||
s_tmax = getattr(opts, 's_tmax', p.s_tmax) or self.s_tmax # 0 = inf
|
||||
s_noise = getattr(opts, 's_noise', p.s_noise)
|
||||
|
||||
if 's_churn' in extra_params_kwargs and s_churn != self.s_churn:
|
||||
extra_params_kwargs['s_churn'] = s_churn
|
||||
p.s_churn = s_churn
|
||||
p.extra_generation_params['Sigma churn'] = s_churn
|
||||
if 's_tmin' in extra_params_kwargs and s_tmin != self.s_tmin:
|
||||
extra_params_kwargs['s_tmin'] = s_tmin
|
||||
p.s_tmin = s_tmin
|
||||
p.extra_generation_params['Sigma tmin'] = s_tmin
|
||||
if 's_tmax' in extra_params_kwargs and s_tmax != self.s_tmax:
|
||||
extra_params_kwargs['s_tmax'] = s_tmax
|
||||
p.s_tmax = s_tmax
|
||||
p.extra_generation_params['Sigma tmax'] = s_tmax
|
||||
if 's_noise' in extra_params_kwargs and s_noise != self.s_noise:
|
||||
extra_params_kwargs['s_noise'] = s_noise
|
||||
p.s_noise = s_noise
|
||||
p.extra_generation_params['Sigma noise'] = s_noise
|
||||
|
||||
return extra_params_kwargs
|
||||
|
||||
def create_noise_sampler(self, x, sigmas, p):
|
||||
"""For DPM++ SDE: manually create noise sampler to enable deterministic results across different batch sizes"""
|
||||
if shared.opts.no_dpmpp_sde_batch_determinism:
|
||||
return None
|
||||
|
||||
from k_diffusion.sampling import BrownianTreeNoiseSampler
|
||||
sigma_min, sigma_max = sigmas[sigmas > 0].min(), sigmas.max()
|
||||
current_iter_seeds = p.all_seeds[p.iteration * p.batch_size:(p.iteration + 1) * p.batch_size]
|
||||
return BrownianTreeNoiseSampler(x, sigma_min, sigma_max, seed=current_iter_seeds)
|
||||
|
||||
def sample(self, p, x, conditioning, unconditional_conditioning, steps=None, image_conditioning=None):
|
||||
raise NotImplementedError()
|
||||
|
||||
def sample_img2img(self, p, x, noise, conditioning, unconditional_conditioning, steps=None, image_conditioning=None):
|
||||
raise NotImplementedError()
|
||||
|
||||
def add_infotext(self, p):
|
||||
if self.model_wrap_cfg.padded_cond_uncond:
|
||||
p.extra_generation_params["Pad conds"] = True
|
||||
|
||||
if self.model_wrap_cfg.padded_cond_uncond_v0:
|
||||
p.extra_generation_params["Pad conds v0"] = True
|
||||
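setup_img2img_steps above decides how much of the schedule actually runs for img2img: by default only denoising_strength * steps steps are sampled. A small worked example, assuming opts.img2img_fix_steps is disabled and using a bare object with the two fields the function reads:

from types import SimpleNamespace
from modules import sd_samplers_common

p = SimpleNamespace(steps=30, denoising_strength=0.5)
steps, t_enc = sd_samplers_common.setup_img2img_steps(p)
# steps == 30 and t_enc == int(0.5 * 30) == 15, so sampling starts 15 steps from the end of the schedule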
0
modules/sd_samplers_compvis.py
Executable file
0
modules/sd_samplers_compvis.py
Executable file
74
modules/sd_samplers_extra.py
Executable file
74
modules/sd_samplers_extra.py
Executable file
@@ -0,0 +1,74 @@
|
||||
import torch
|
||||
import tqdm
|
||||
import k_diffusion.sampling
|
||||
|
||||
|
||||
@torch.no_grad()
|
||||
def restart_sampler(model, x, sigmas, extra_args=None, callback=None, disable=None, s_noise=1., restart_list=None):
|
||||
"""Implements restart sampling in Restart Sampling for Improving Generative Processes (2023)
|
||||
Restart_list format: {min_sigma: [ restart_steps, restart_times, max_sigma]}
|
||||
If restart_list is None: will choose restart_list automatically, otherwise will use the given restart_list
|
||||
"""
|
||||
extra_args = {} if extra_args is None else extra_args
|
||||
s_in = x.new_ones([x.shape[0]])
|
||||
step_id = 0
|
||||
from k_diffusion.sampling import to_d, get_sigmas_karras
|
||||
|
||||
def heun_step(x, old_sigma, new_sigma, second_order=True):
|
||||
nonlocal step_id
|
||||
denoised = model(x, old_sigma * s_in, **extra_args)
|
||||
d = to_d(x, old_sigma, denoised)
|
||||
if callback is not None:
|
||||
callback({'x': x, 'i': step_id, 'sigma': new_sigma, 'sigma_hat': old_sigma, 'denoised': denoised})
|
||||
dt = new_sigma - old_sigma
|
||||
if new_sigma == 0 or not second_order:
|
||||
# Euler method
|
||||
x = x + d * dt
|
||||
else:
|
||||
# Heun's method
|
||||
x_2 = x + d * dt
|
||||
denoised_2 = model(x_2, new_sigma * s_in, **extra_args)
|
||||
d_2 = to_d(x_2, new_sigma, denoised_2)
|
||||
d_prime = (d + d_2) / 2
|
||||
x = x + d_prime * dt
|
||||
step_id += 1
|
||||
return x
|
||||
|
||||
steps = sigmas.shape[0] - 1
|
||||
if restart_list is None:
|
||||
if steps >= 20:
|
||||
restart_steps = 9
|
||||
restart_times = 1
|
||||
if steps >= 36:
|
||||
restart_steps = steps // 4
|
||||
restart_times = 2
|
||||
sigmas = get_sigmas_karras(steps - restart_steps * restart_times, sigmas[-2].item(), sigmas[0].item(), device=sigmas.device)
|
||||
restart_list = {0.1: [restart_steps + 1, restart_times, 2]}
|
||||
else:
|
||||
restart_list = {}
|
||||
|
||||
restart_list = {int(torch.argmin(abs(sigmas - key), dim=0)): value for key, value in restart_list.items()}
|
||||
|
||||
step_list = []
|
||||
for i in range(len(sigmas) - 1):
|
||||
step_list.append((sigmas[i], sigmas[i + 1]))
|
||||
if i + 1 in restart_list:
|
||||
restart_steps, restart_times, restart_max = restart_list[i + 1]
|
||||
min_idx = i + 1
|
||||
max_idx = int(torch.argmin(abs(sigmas - restart_max), dim=0))
|
||||
if max_idx < min_idx:
|
||||
sigma_restart = get_sigmas_karras(restart_steps, sigmas[min_idx].item(), sigmas[max_idx].item(), device=sigmas.device)[:-1]
|
||||
while restart_times > 0:
|
||||
restart_times -= 1
|
||||
step_list.extend(zip(sigma_restart[:-1], sigma_restart[1:]))
|
||||
|
||||
last_sigma = None
|
||||
for old_sigma, new_sigma in tqdm.tqdm(step_list, disable=disable):
|
||||
if last_sigma is None:
|
||||
last_sigma = old_sigma
|
||||
elif last_sigma < old_sigma:
|
||||
x = x + k_diffusion.sampling.torch.randn_like(x) * s_noise * (old_sigma ** 2 - last_sigma ** 2) ** 0.5
|
||||
x = heun_step(x, old_sigma, new_sigma)
|
||||
last_sigma = new_sigma
|
||||
|
||||
return x
|
||||
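The restart loop above re-injects noise whenever the schedule jumps back up to a higher sigma, scaled so the total variance matches the new noise level; the increment is sqrt(old_sigma^2 - last_sigma^2). A minimal sketch of that single step:

import torch

def reinject_noise(x: torch.Tensor, last_sigma: float, old_sigma: float, s_noise: float = 1.0) -> torch.Tensor:
    # mirrors the restart step above: add just enough noise to move x from noise level last_sigma up to old_sigma
    return x + torch.randn_like(x) * s_noise * (old_sigma ** 2 - last_sigma ** 2) ** 0.5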
246
modules/sd_samplers_kdiffusion.py
Executable file
246
modules/sd_samplers_kdiffusion.py
Executable file
@@ -0,0 +1,246 @@
|
||||
import torch
|
||||
import inspect
|
||||
import k_diffusion.sampling
|
||||
import k_diffusion.external
|
||||
from modules import sd_samplers_common, sd_samplers_extra, sd_samplers_cfg_denoiser, sd_schedulers, devices
|
||||
from modules.sd_samplers_cfg_denoiser import CFGDenoiser # noqa: F401
|
||||
from modules.script_callbacks import ExtraNoiseParams, extra_noise_callback
|
||||
|
||||
from modules.shared import opts
|
||||
import modules.shared as shared
|
||||
from backend.sampling.sampling_function import sampling_prepare, sampling_cleanup
|
||||
|
||||
|
||||
samplers_k_diffusion = [
|
||||
('DPM++ 2M', 'sample_dpmpp_2m', ['k_dpmpp_2m'], {'scheduler': 'karras'}),
|
||||
('DPM++ SDE', 'sample_dpmpp_sde', ['k_dpmpp_sde'], {'scheduler': 'karras', "second_order": True, "brownian_noise": True}),
|
||||
('DPM++ 2M SDE', 'sample_dpmpp_2m_sde', ['k_dpmpp_2m_sde'], {'scheduler': 'exponential', "brownian_noise": True}),
|
||||
('DPM++ 2M SDE Heun', 'sample_dpmpp_2m_sde', ['k_dpmpp_2m_sde_heun'], {'scheduler': 'exponential', "brownian_noise": True, "solver_type": "heun"}),
|
||||
('DPM++ 2S a', 'sample_dpmpp_2s_ancestral', ['k_dpmpp_2s_a'], {'scheduler': 'karras', "uses_ensd": True, "second_order": True}),
|
||||
('DPM++ 3M SDE', 'sample_dpmpp_3m_sde', ['k_dpmpp_3m_sde'], {'scheduler': 'exponential', 'discard_next_to_last_sigma': True, "brownian_noise": True}),
|
||||
('Euler a', 'sample_euler_ancestral', ['k_euler_a', 'k_euler_ancestral'], {"uses_ensd": True}),
|
||||
('Euler', 'sample_euler', ['k_euler'], {}),
|
||||
('LMS', 'sample_lms', ['k_lms'], {}),
|
||||
('Heun', 'sample_heun', ['k_heun'], {"second_order": True}),
|
||||
('DPM2', 'sample_dpm_2', ['k_dpm_2'], {'scheduler': 'karras', 'discard_next_to_last_sigma': True, "second_order": True}),
|
||||
('DPM2 a', 'sample_dpm_2_ancestral', ['k_dpm_2_a'], {'scheduler': 'karras', 'discard_next_to_last_sigma': True, "uses_ensd": True, "second_order": True}),
|
||||
('DPM fast', 'sample_dpm_fast', ['k_dpm_fast'], {"uses_ensd": True}),
|
||||
('DPM adaptive', 'sample_dpm_adaptive', ['k_dpm_ad'], {"uses_ensd": True}),
|
||||
('Restart', sd_samplers_extra.restart_sampler, ['restart'], {'scheduler': 'karras', "second_order": True}),
|
||||
('HeunPP2', 'sample_heunpp2', ['heunpp2'], {}),
|
||||
('IPNDM', 'sample_ipndm', ['ipndm'], {}),
|
||||
('IPNDM_V', 'sample_ipndm_v', ['ipndm_v'], {}),
|
||||
('DEIS', 'sample_deis', ['deis'], {}),
|
||||
]
|
||||
|
||||
|
||||
samplers_data_k_diffusion = [
|
||||
sd_samplers_common.SamplerData(label, lambda model, funcname=funcname: KDiffusionSampler(funcname, model), aliases, options)
|
||||
for label, funcname, aliases, options in samplers_k_diffusion
|
||||
if callable(funcname) or hasattr(k_diffusion.sampling, funcname)
|
||||
]
|
||||
|
||||
sampler_extra_params = {
|
||||
'sample_euler': ['s_churn', 's_tmin', 's_tmax', 's_noise'],
|
||||
'sample_heun': ['s_churn', 's_tmin', 's_tmax', 's_noise'],
|
||||
'sample_dpm_2': ['s_churn', 's_tmin', 's_tmax', 's_noise'],
|
||||
'sample_dpm_fast': ['s_noise'],
|
||||
'sample_dpm_2_ancestral': ['s_noise'],
|
||||
'sample_dpmpp_2s_ancestral': ['s_noise'],
|
||||
'sample_dpmpp_sde': ['s_noise'],
|
||||
'sample_dpmpp_2m_sde': ['s_noise'],
|
||||
'sample_dpmpp_3m_sde': ['s_noise'],
|
||||
}
|
||||
|
||||
k_diffusion_samplers_map = {x.name: x for x in samplers_data_k_diffusion}
|
||||
k_diffusion_scheduler = {x.name: x.function for x in sd_schedulers.schedulers}
|
||||
|
||||
|
||||
class CFGDenoiserKDiffusion(sd_samplers_cfg_denoiser.CFGDenoiser):
|
||||
@property
|
||||
def inner_model(self):
|
||||
if self.model_wrap is None:
|
||||
self.model_wrap = k_diffusion.external.ForgeScheduleLinker(shared.sd_model.forge_objects.unet.model.predictor)
|
||||
self.model_wrap.inner_model = shared.sd_model
|
||||
|
||||
return self.model_wrap
|
||||
|
||||
|
||||
class KDiffusionSampler(sd_samplers_common.Sampler):
|
||||
def __init__(self, funcname, sd_model, options=None):
|
||||
super().__init__(funcname)
|
||||
|
||||
self.extra_params = sampler_extra_params.get(funcname, [])
|
||||
|
||||
self.options = options or {}
|
||||
self.func = funcname if callable(funcname) else getattr(k_diffusion.sampling, self.funcname)
|
||||
|
||||
self.model_wrap_cfg = CFGDenoiserKDiffusion(self)
|
||||
self.model_wrap = self.model_wrap_cfg.inner_model
|
||||
|
||||
def get_sigmas(self, p, steps):
|
||||
discard_next_to_last_sigma = self.config is not None and self.config.options.get('discard_next_to_last_sigma', False)
|
||||
if opts.always_discard_next_to_last_sigma and not discard_next_to_last_sigma:
|
||||
discard_next_to_last_sigma = True
|
||||
p.extra_generation_params["Discard penultimate sigma"] = True
|
||||
|
||||
steps += 1 if discard_next_to_last_sigma else 0
|
||||
|
||||
scheduler_name = (p.hr_scheduler if p.is_hr_pass else p.scheduler) or 'Automatic'
|
||||
if scheduler_name == 'Automatic':
|
||||
scheduler_name = self.config.options.get('scheduler', None)
|
||||
|
||||
scheduler = sd_schedulers.schedulers_map.get(scheduler_name)
|
||||
|
||||
m_sigma_min, m_sigma_max = self.model_wrap.sigmas[0].item(), self.model_wrap.sigmas[-1].item()
|
||||
sigma_min, sigma_max = (0.1, 10) if opts.use_old_karras_scheduler_sigmas else (m_sigma_min, m_sigma_max)
|
||||
|
||||
if p.sampler_noise_scheduler_override:
|
||||
sigmas = p.sampler_noise_scheduler_override(steps)
|
||||
elif scheduler is None or scheduler.function is None:
|
||||
sigmas = self.model_wrap.get_sigmas(steps)
|
||||
else:
|
||||
sigmas_kwargs = {'sigma_min': sigma_min, 'sigma_max': sigma_max}
|
||||
|
||||
if scheduler.label != 'Automatic' and not p.is_hr_pass:
|
||||
p.extra_generation_params["Schedule type"] = scheduler.label
|
||||
elif scheduler.label != p.extra_generation_params.get("Schedule type"):
|
||||
p.extra_generation_params["Hires schedule type"] = scheduler.label
|
||||
|
||||
if opts.sigma_min != 0 and opts.sigma_min != m_sigma_min:
|
||||
sigmas_kwargs['sigma_min'] = opts.sigma_min
|
||||
p.extra_generation_params["Schedule min sigma"] = opts.sigma_min
|
||||
|
||||
if opts.sigma_max != 0 and opts.sigma_max != m_sigma_max:
|
||||
sigmas_kwargs['sigma_max'] = opts.sigma_max
|
||||
p.extra_generation_params["Schedule max sigma"] = opts.sigma_max
|
||||
|
||||
if scheduler.default_rho != -1 and opts.rho != 0 and opts.rho != scheduler.default_rho:
|
||||
sigmas_kwargs['rho'] = opts.rho
|
||||
p.extra_generation_params["Schedule rho"] = opts.rho
|
||||
|
||||
if scheduler.need_inner_model:
|
||||
sigmas_kwargs['inner_model'] = self.model_wrap
|
||||
|
||||
if scheduler.label == 'Beta':
|
||||
p.extra_generation_params["Beta schedule alpha"] = opts.beta_dist_alpha
|
||||
p.extra_generation_params["Beta schedule beta"] = opts.beta_dist_beta
|
||||
|
||||
sigmas = scheduler.function(n=steps, **sigmas_kwargs, device=devices.cpu)
|
||||
|
||||
if discard_next_to_last_sigma:
|
||||
sigmas = torch.cat([sigmas[:-2], sigmas[-1:]])
|
||||
|
||||
return sigmas.cpu()
|
||||
|
||||
def sample_img2img(self, p, x, noise, conditioning, unconditional_conditioning, steps=None, image_conditioning=None):
|
||||
unet_patcher = self.model_wrap.inner_model.forge_objects.unet
|
||||
sampling_prepare(self.model_wrap.inner_model.forge_objects.unet, x=x)
|
||||
|
||||
steps, t_enc = sd_samplers_common.setup_img2img_steps(p, steps)
|
||||
|
||||
sigmas = self.get_sigmas(p, steps).to(x.device)
|
||||
sigma_sched = sigmas[steps - t_enc - 1:]
|
||||
|
||||
x = x.to(noise)
|
||||
|
||||
xi = self.model_wrap.predictor.noise_scaling(sigma_sched[0], noise, x, max_denoise=False)
|
||||
|
||||
if opts.img2img_extra_noise > 0:
|
||||
p.extra_generation_params["Extra noise"] = opts.img2img_extra_noise
|
||||
extra_noise_params = ExtraNoiseParams(noise, x, xi)
|
||||
extra_noise_callback(extra_noise_params)
|
||||
noise = extra_noise_params.noise
|
||||
xi += noise * opts.img2img_extra_noise
|
||||
|
||||
extra_params_kwargs = self.initialize(p)
|
||||
parameters = inspect.signature(self.func).parameters
|
||||
|
||||
if 'sigma_min' in parameters:
|
||||
# the last sigma is zero, which isn't allowed by DPM Fast & Adaptive, so take the value before last
|
||||
extra_params_kwargs['sigma_min'] = sigma_sched[-2]
|
||||
if 'sigma_max' in parameters:
|
||||
extra_params_kwargs['sigma_max'] = sigma_sched[0]
|
||||
if 'n' in parameters:
|
||||
extra_params_kwargs['n'] = len(sigma_sched) - 1
|
||||
if 'sigma_sched' in parameters:
|
||||
extra_params_kwargs['sigma_sched'] = sigma_sched
|
||||
if 'sigmas' in parameters:
|
||||
extra_params_kwargs['sigmas'] = sigma_sched
|
||||
|
||||
if self.config.options.get('brownian_noise', False):
|
||||
noise_sampler = self.create_noise_sampler(x, sigmas, p)
|
||||
extra_params_kwargs['noise_sampler'] = noise_sampler
|
||||
|
||||
if self.config.options.get('solver_type', None) == 'heun':
|
||||
extra_params_kwargs['solver_type'] = 'heun'
|
||||
|
||||
self.model_wrap_cfg.init_latent = x
|
||||
self.last_latent = x
|
||||
self.sampler_extra_args = {
|
||||
'cond': conditioning,
|
||||
'image_cond': image_conditioning,
|
||||
'uncond': unconditional_conditioning,
|
||||
'cond_scale': p.cfg_scale,
|
||||
's_min_uncond': self.s_min_uncond
|
||||
}
|
||||
|
||||
samples = self.launch_sampling(t_enc + 1, lambda: self.func(self.model_wrap_cfg, xi, extra_args=self.sampler_extra_args, disable=False, callback=self.callback_state, **extra_params_kwargs))
|
||||
|
||||
self.add_infotext(p)
|
||||
|
||||
sampling_cleanup(unet_patcher)
|
||||
|
||||
return samples
|
||||
|
||||
def sample(self, p, x, conditioning, unconditional_conditioning, steps=None, image_conditioning=None):
|
||||
unet_patcher = self.model_wrap.inner_model.forge_objects.unet
|
||||
sampling_prepare(self.model_wrap.inner_model.forge_objects.unet, x=x)
|
||||
|
||||
steps = steps or p.steps
|
||||
|
||||
sigmas = self.get_sigmas(p, steps).to(x.device)
|
||||
|
||||
if opts.sgm_noise_multiplier:
|
||||
p.extra_generation_params["SGM noise multiplier"] = True
|
||||
|
||||
x = self.model_wrap.predictor.noise_scaling(sigmas[0], x, torch.zeros_like(x), max_denoise=opts.sgm_noise_multiplier)
|
||||
|
||||
extra_params_kwargs = self.initialize(p)
|
||||
parameters = inspect.signature(self.func).parameters
|
||||
|
||||
if 'n' in parameters:
|
||||
extra_params_kwargs['n'] = steps
|
||||
|
||||
if 'sigma_min' in parameters:
|
||||
extra_params_kwargs['sigma_min'] = self.model_wrap.sigmas[0].item()
|
||||
extra_params_kwargs['sigma_max'] = self.model_wrap.sigmas[-1].item()
|
||||
|
||||
if 'sigmas' in parameters:
|
||||
extra_params_kwargs['sigmas'] = sigmas
|
||||
|
||||
if self.config.options.get('brownian_noise', False):
|
||||
noise_sampler = self.create_noise_sampler(x, sigmas, p)
|
||||
extra_params_kwargs['noise_sampler'] = noise_sampler
|
||||
|
||||
if self.config.options.get('solver_type', None) == 'heun':
|
||||
extra_params_kwargs['solver_type'] = 'heun'
|
||||
|
||||
self.last_latent = x
|
||||
self.sampler_extra_args = {
|
||||
'cond': conditioning,
|
||||
'image_cond': image_conditioning,
|
||||
'uncond': unconditional_conditioning,
|
||||
'cond_scale': p.cfg_scale,
|
||||
's_min_uncond': self.s_min_uncond
|
||||
}
|
||||
|
||||
samples = self.launch_sampling(steps, lambda: self.func(self.model_wrap_cfg, x, extra_args=self.sampler_extra_args, disable=False, callback=self.callback_state, **extra_params_kwargs))
|
||||
|
||||
self.add_infotext(p)
|
||||
|
||||
sampling_cleanup(unet_patcher)
|
||||
|
||||
return samples
|
||||
|
||||
|
||||
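get_sigmas above ultimately hands a decreasing sigma schedule to the k-diffusion sampling function; when a Karras schedule is selected the values come from k_diffusion.sampling.get_sigmas_karras. A short sketch of what that call produces (the sigma_min/sigma_max values here are illustrative, not the model's real range):

import k_diffusion.sampling as ks

sigmas = ks.get_sigmas_karras(n=20, sigma_min=0.03, sigma_max=14.6, rho=7.0, device='cpu')
# a tensor of 21 decreasing values ending in 0.0, which KDiffusionSampler.sample passes on to self.func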
Some files were not shown because too many files have changed in this diff.