mirror of
https://github.com/AUTOMATIC1111/stable-diffusion-webui.git
synced 2026-04-24 00:09:21 +00:00
Merge branch 'dev' into npu_support
This commit is contained in:
@@ -17,7 +17,7 @@ from fastapi.encoders import jsonable_encoder
|
||||
from secrets import compare_digest
|
||||
|
||||
import modules.shared as shared
|
||||
from modules import sd_samplers, deepbooru, sd_hijack, images, scripts, ui, postprocessing, errors, restart, shared_items, script_callbacks, generation_parameters_copypaste, sd_models
|
||||
from modules import sd_samplers, deepbooru, sd_hijack, images, scripts, ui, postprocessing, errors, restart, shared_items, script_callbacks, infotext_utils, sd_models
|
||||
from modules.api import models
|
||||
from modules.shared import opts
|
||||
from modules.processing import StableDiffusionProcessingTxt2Img, StableDiffusionProcessingImg2Img, process_images
|
||||
@@ -31,7 +31,7 @@ from typing import Any
|
||||
import piexif
|
||||
import piexif.helper
|
||||
from contextlib import closing
|
||||
|
||||
from modules.progress import create_task_id, add_task_to_queue, start_task, finish_task, current_task
|
||||
|
||||
def script_name_to_index(name, scripts):
|
||||
try:
|
||||
@@ -230,6 +230,7 @@ class Api:
|
||||
self.add_api_route("/sdapi/v1/realesrgan-models", self.get_realesrgan_models, methods=["GET"], response_model=list[models.RealesrganItem])
|
||||
self.add_api_route("/sdapi/v1/prompt-styles", self.get_prompt_styles, methods=["GET"], response_model=list[models.PromptStyleItem])
|
||||
self.add_api_route("/sdapi/v1/embeddings", self.get_embeddings, methods=["GET"], response_model=models.EmbeddingsResponse)
|
||||
self.add_api_route("/sdapi/v1/refresh-embeddings", self.refresh_embeddings, methods=["POST"])
|
||||
self.add_api_route("/sdapi/v1/refresh-checkpoints", self.refresh_checkpoints, methods=["POST"])
|
||||
self.add_api_route("/sdapi/v1/refresh-vae", self.refresh_vae, methods=["POST"])
|
||||
self.add_api_route("/sdapi/v1/create/embedding", self.create_embedding, methods=["POST"], response_model=models.CreateResponse)
|
||||
@@ -251,6 +252,24 @@ class Api:
|
||||
self.default_script_arg_txt2img = []
|
||||
self.default_script_arg_img2img = []
|
||||
|
||||
txt2img_script_runner = scripts.scripts_txt2img
|
||||
img2img_script_runner = scripts.scripts_img2img
|
||||
|
||||
if not txt2img_script_runner.scripts or not img2img_script_runner.scripts:
|
||||
ui.create_ui()
|
||||
|
||||
if not txt2img_script_runner.scripts:
|
||||
txt2img_script_runner.initialize_scripts(False)
|
||||
if not self.default_script_arg_txt2img:
|
||||
self.default_script_arg_txt2img = self.init_default_script_args(txt2img_script_runner)
|
||||
|
||||
if not img2img_script_runner.scripts:
|
||||
img2img_script_runner.initialize_scripts(True)
|
||||
if not self.default_script_arg_img2img:
|
||||
self.default_script_arg_img2img = self.init_default_script_args(img2img_script_runner)
|
||||
|
||||
|
||||
|
||||
def add_api_route(self, path: str, endpoint, **kwargs):
|
||||
if shared.cmd_opts.api_auth:
|
||||
return self.app.add_api_route(path, endpoint, dependencies=[Depends(self.auth)], **kwargs)
|
||||
@@ -312,8 +331,13 @@ class Api:
|
||||
script_args[script.args_from:script.args_to] = ui_default_values
|
||||
return script_args
|
||||
|
||||
def init_script_args(self, request, default_script_args, selectable_scripts, selectable_idx, script_runner):
|
||||
def init_script_args(self, request, default_script_args, selectable_scripts, selectable_idx, script_runner, *, input_script_args=None):
|
||||
script_args = default_script_args.copy()
|
||||
|
||||
if input_script_args is not None:
|
||||
for index, value in input_script_args.items():
|
||||
script_args[index] = value
|
||||
|
||||
# position 0 in script_arg is the idx+1 of the selectable script that is going to be run when using scripts.scripts_*2img.run()
|
||||
if selectable_scripts:
|
||||
script_args[selectable_scripts.args_from:selectable_scripts.args_to] = request.script_args
|
||||
@@ -335,13 +359,83 @@ class Api:
|
||||
script_args[alwayson_script.args_from + idx] = request.alwayson_scripts[alwayson_script_name]["args"][idx]
|
||||
return script_args
|
||||
|
||||
def apply_infotext(self, request, tabname, *, script_runner=None, mentioned_script_args=None):
|
||||
"""Processes `infotext` field from the `request`, and sets other fields of the `request` accoring to what's in infotext.
|
||||
|
||||
If request already has a field set, and that field is encountered in infotext too, the value from infotext is ignored.
|
||||
|
||||
Additionally, fills `mentioned_script_args` dict with index: value pairs for script arguments read from infotext.
|
||||
"""
|
||||
|
||||
if not request.infotext:
|
||||
return {}
|
||||
|
||||
possible_fields = infotext_utils.paste_fields[tabname]["fields"]
|
||||
set_fields = request.model_dump(exclude_unset=True) if hasattr(request, "request") else request.dict(exclude_unset=True) # pydantic v1/v2 have differenrt names for this
|
||||
params = infotext_utils.parse_generation_parameters(request.infotext)
|
||||
|
||||
def get_field_value(field, params):
|
||||
value = field.function(params) if field.function else params.get(field.label)
|
||||
if value is None:
|
||||
return None
|
||||
|
||||
if field.api in request.__fields__:
|
||||
target_type = request.__fields__[field.api].type_
|
||||
else:
|
||||
target_type = type(field.component.value)
|
||||
|
||||
if target_type == type(None):
|
||||
return None
|
||||
|
||||
if isinstance(value, dict) and value.get('__type__') == 'generic_update': # this is a gradio.update rather than a value
|
||||
value = value.get('value')
|
||||
|
||||
if value is not None and not isinstance(value, target_type):
|
||||
value = target_type(value)
|
||||
|
||||
return value
|
||||
|
||||
for field in possible_fields:
|
||||
if not field.api:
|
||||
continue
|
||||
|
||||
if field.api in set_fields:
|
||||
continue
|
||||
|
||||
value = get_field_value(field, params)
|
||||
if value is not None:
|
||||
setattr(request, field.api, value)
|
||||
|
||||
if request.override_settings is None:
|
||||
request.override_settings = {}
|
||||
|
||||
overriden_settings = infotext_utils.get_override_settings(params)
|
||||
for _, setting_name, value in overriden_settings:
|
||||
if setting_name not in request.override_settings:
|
||||
request.override_settings[setting_name] = value
|
||||
|
||||
if script_runner is not None and mentioned_script_args is not None:
|
||||
indexes = {v: i for i, v in enumerate(script_runner.inputs)}
|
||||
script_fields = ((field, indexes[field.component]) for field in possible_fields if field.component in indexes)
|
||||
|
||||
for field, index in script_fields:
|
||||
value = get_field_value(field, params)
|
||||
|
||||
if value is None:
|
||||
continue
|
||||
|
||||
mentioned_script_args[index] = value
|
||||
|
||||
return params
|
||||
|
||||
def text2imgapi(self, txt2imgreq: models.StableDiffusionTxt2ImgProcessingAPI):
|
||||
task_id = txt2imgreq.force_task_id or create_task_id("txt2img")
|
||||
|
||||
script_runner = scripts.scripts_txt2img
|
||||
if not script_runner.scripts:
|
||||
script_runner.initialize_scripts(False)
|
||||
ui.create_ui()
|
||||
if not self.default_script_arg_txt2img:
|
||||
self.default_script_arg_txt2img = self.init_default_script_args(script_runner)
|
||||
|
||||
infotext_script_args = {}
|
||||
self.apply_infotext(txt2imgreq, "txt2img", script_runner=script_runner, mentioned_script_args=infotext_script_args)
|
||||
|
||||
selectable_scripts, selectable_script_idx = self.get_selectable_script(txt2imgreq.script_name, script_runner)
|
||||
|
||||
populate = txt2imgreq.copy(update={ # Override __init__ params
|
||||
@@ -356,12 +450,15 @@ class Api:
|
||||
args.pop('script_name', None)
|
||||
args.pop('script_args', None) # will refeed them to the pipeline directly after initializing them
|
||||
args.pop('alwayson_scripts', None)
|
||||
args.pop('infotext', None)
|
||||
|
||||
script_args = self.init_script_args(txt2imgreq, self.default_script_arg_txt2img, selectable_scripts, selectable_script_idx, script_runner)
|
||||
script_args = self.init_script_args(txt2imgreq, self.default_script_arg_txt2img, selectable_scripts, selectable_script_idx, script_runner, input_script_args=infotext_script_args)
|
||||
|
||||
send_images = args.pop('send_images', True)
|
||||
args.pop('save_images', None)
|
||||
|
||||
add_task_to_queue(task_id)
|
||||
|
||||
with self.queue_lock:
|
||||
with closing(StableDiffusionProcessingTxt2Img(sd_model=shared.sd_model, **args)) as p:
|
||||
p.is_api = True
|
||||
@@ -371,12 +468,14 @@ class Api:
|
||||
|
||||
try:
|
||||
shared.state.begin(job="scripts_txt2img")
|
||||
start_task(task_id)
|
||||
if selectable_scripts is not None:
|
||||
p.script_args = script_args
|
||||
processed = scripts.scripts_txt2img.run(p, *p.script_args) # Need to pass args as list here
|
||||
else:
|
||||
p.script_args = tuple(script_args) # Need to pass args as tuple here
|
||||
processed = process_images(p)
|
||||
finish_task(task_id)
|
||||
finally:
|
||||
shared.state.end()
|
||||
shared.total_tqdm.clear()
|
||||
@@ -386,6 +485,8 @@ class Api:
|
||||
return models.TextToImageResponse(images=b64images, parameters=vars(txt2imgreq), info=processed.js())
|
||||
|
||||
def img2imgapi(self, img2imgreq: models.StableDiffusionImg2ImgProcessingAPI):
|
||||
task_id = img2imgreq.force_task_id or create_task_id("img2img")
|
||||
|
||||
init_images = img2imgreq.init_images
|
||||
if init_images is None:
|
||||
raise HTTPException(status_code=404, detail="Init image not found")
|
||||
@@ -395,11 +496,10 @@ class Api:
|
||||
mask = decode_base64_to_image(mask)
|
||||
|
||||
script_runner = scripts.scripts_img2img
|
||||
if not script_runner.scripts:
|
||||
script_runner.initialize_scripts(True)
|
||||
ui.create_ui()
|
||||
if not self.default_script_arg_img2img:
|
||||
self.default_script_arg_img2img = self.init_default_script_args(script_runner)
|
||||
|
||||
infotext_script_args = {}
|
||||
self.apply_infotext(img2imgreq, "img2img", script_runner=script_runner, mentioned_script_args=infotext_script_args)
|
||||
|
||||
selectable_scripts, selectable_script_idx = self.get_selectable_script(img2imgreq.script_name, script_runner)
|
||||
|
||||
populate = img2imgreq.copy(update={ # Override __init__ params
|
||||
@@ -416,12 +516,15 @@ class Api:
|
||||
args.pop('script_name', None)
|
||||
args.pop('script_args', None) # will refeed them to the pipeline directly after initializing them
|
||||
args.pop('alwayson_scripts', None)
|
||||
args.pop('infotext', None)
|
||||
|
||||
script_args = self.init_script_args(img2imgreq, self.default_script_arg_img2img, selectable_scripts, selectable_script_idx, script_runner)
|
||||
script_args = self.init_script_args(img2imgreq, self.default_script_arg_img2img, selectable_scripts, selectable_script_idx, script_runner, input_script_args=infotext_script_args)
|
||||
|
||||
send_images = args.pop('send_images', True)
|
||||
args.pop('save_images', None)
|
||||
|
||||
add_task_to_queue(task_id)
|
||||
|
||||
with self.queue_lock:
|
||||
with closing(StableDiffusionProcessingImg2Img(sd_model=shared.sd_model, **args)) as p:
|
||||
p.init_images = [decode_base64_to_image(x) for x in init_images]
|
||||
@@ -432,12 +535,14 @@ class Api:
|
||||
|
||||
try:
|
||||
shared.state.begin(job="scripts_img2img")
|
||||
start_task(task_id)
|
||||
if selectable_scripts is not None:
|
||||
p.script_args = script_args
|
||||
processed = scripts.scripts_img2img.run(p, *p.script_args) # Need to pass args as list here
|
||||
else:
|
||||
p.script_args = tuple(script_args) # Need to pass args as tuple here
|
||||
processed = process_images(p)
|
||||
finish_task(task_id)
|
||||
finally:
|
||||
shared.state.end()
|
||||
shared.total_tqdm.clear()
|
||||
@@ -480,7 +585,7 @@ class Api:
|
||||
if geninfo is None:
|
||||
geninfo = ""
|
||||
|
||||
params = generation_parameters_copypaste.parse_generation_parameters(geninfo)
|
||||
params = infotext_utils.parse_generation_parameters(geninfo)
|
||||
script_callbacks.infotext_pasted_callback(geninfo, params)
|
||||
|
||||
return models.PNGInfoResponse(info=geninfo, items=items, parameters=params)
|
||||
@@ -511,7 +616,7 @@ class Api:
|
||||
if shared.state.current_image and not req.skip_current_image:
|
||||
current_image = encode_pil_to_base64(shared.state.current_image)
|
||||
|
||||
return models.ProgressResponse(progress=progress, eta_relative=eta_relative, state=shared.state.dict(), current_image=current_image, textinfo=shared.state.textinfo)
|
||||
return models.ProgressResponse(progress=progress, eta_relative=eta_relative, state=shared.state.dict(), current_image=current_image, textinfo=shared.state.textinfo, current_task=current_task)
|
||||
|
||||
def interrogateapi(self, interrogatereq: models.InterrogateRequest):
|
||||
image_b64 = interrogatereq.image
|
||||
@@ -643,6 +748,10 @@ class Api:
|
||||
"skipped": convert_embeddings(db.skipped_embeddings),
|
||||
}
|
||||
|
||||
def refresh_embeddings(self):
|
||||
with self.queue_lock:
|
||||
sd_hijack.model_hijack.embedding_db.load_textual_inversion_embeddings(force_reload=True)
|
||||
|
||||
def refresh_checkpoints(self):
|
||||
with self.queue_lock:
|
||||
shared.refresh_checkpoints()
|
||||
@@ -775,7 +884,15 @@ class Api:
|
||||
|
||||
def launch(self, server_name, port, root_path):
|
||||
self.app.include_router(self.router)
|
||||
uvicorn.run(self.app, host=server_name, port=port, timeout_keep_alive=shared.cmd_opts.timeout_keep_alive, root_path=root_path)
|
||||
uvicorn.run(
|
||||
self.app,
|
||||
host=server_name,
|
||||
port=port,
|
||||
timeout_keep_alive=shared.cmd_opts.timeout_keep_alive,
|
||||
root_path=root_path,
|
||||
ssl_keyfile=shared.cmd_opts.tls_keyfile,
|
||||
ssl_certfile=shared.cmd_opts.tls_certfile
|
||||
)
|
||||
|
||||
def kill_webui(self):
|
||||
restart.stop_program()
|
||||
|
||||
@@ -107,6 +107,8 @@ StableDiffusionTxt2ImgProcessingAPI = PydanticModelGenerator(
|
||||
{"key": "send_images", "type": bool, "default": True},
|
||||
{"key": "save_images", "type": bool, "default": False},
|
||||
{"key": "alwayson_scripts", "type": dict, "default": {}},
|
||||
{"key": "force_task_id", "type": str, "default": None},
|
||||
{"key": "infotext", "type": str, "default": None},
|
||||
]
|
||||
).generate_model()
|
||||
|
||||
@@ -124,6 +126,8 @@ StableDiffusionImg2ImgProcessingAPI = PydanticModelGenerator(
|
||||
{"key": "send_images", "type": bool, "default": True},
|
||||
{"key": "save_images", "type": bool, "default": False},
|
||||
{"key": "alwayson_scripts", "type": dict, "default": {}},
|
||||
{"key": "force_task_id", "type": str, "default": None},
|
||||
{"key": "infotext", "type": str, "default": None},
|
||||
]
|
||||
).generate_model()
|
||||
|
||||
|
||||
@@ -62,16 +62,15 @@ def cache(subsection):
|
||||
if cache_data is None:
|
||||
with cache_lock:
|
||||
if cache_data is None:
|
||||
if not os.path.isfile(cache_filename):
|
||||
try:
|
||||
with open(cache_filename, "r", encoding="utf8") as file:
|
||||
cache_data = json.load(file)
|
||||
except FileNotFoundError:
|
||||
cache_data = {}
|
||||
except Exception:
|
||||
os.replace(cache_filename, os.path.join(script_path, "tmp", "cache.json"))
|
||||
print('[ERROR] issue occurred while trying to read cache.json, move current cache to tmp/cache.json and create new cache')
|
||||
cache_data = {}
|
||||
else:
|
||||
try:
|
||||
with open(cache_filename, "r", encoding="utf8") as file:
|
||||
cache_data = json.load(file)
|
||||
except Exception:
|
||||
os.replace(cache_filename, os.path.join(script_path, "tmp", "cache.json"))
|
||||
print('[ERROR] issue occurred while trying to read cache.json, move current cache to tmp/cache.json and create new cache')
|
||||
cache_data = {}
|
||||
|
||||
s = cache_data.get(subsection, {})
|
||||
cache_data[subsection] = s
|
||||
|
||||
@@ -78,6 +78,7 @@ def wrap_gradio_call(func, extra_outputs=None, add_stats=False):
|
||||
|
||||
shared.state.skipped = False
|
||||
shared.state.interrupted = False
|
||||
shared.state.stopping_generation = False
|
||||
shared.state.job_count = 0
|
||||
|
||||
if not add_stats:
|
||||
|
||||
@@ -77,7 +77,9 @@ parser.add_argument("--port", type=int, help="launch gradio with given server po
|
||||
parser.add_argument("--show-negative-prompt", action='store_true', help="does not do anything", default=False)
|
||||
parser.add_argument("--ui-config-file", type=str, help="filename to use for ui configuration", default=os.path.join(data_path, 'ui-config.json'))
|
||||
parser.add_argument("--hide-ui-dir-config", action='store_true', help="hide directory configuration from webui", default=False)
|
||||
parser.add_argument("--freeze-settings", action='store_true', help="disable editing settings", default=False)
|
||||
parser.add_argument("--freeze-settings", action='store_true', help="disable editing of all settings globally", default=False)
|
||||
parser.add_argument("--freeze-settings-in-sections", type=str, help='disable editing settings in specific sections of the settings page by specifying a comma-delimited list such like "saving-images,upscaling". The list of setting names can be found in the modules/shared_options.py file', default=None)
|
||||
parser.add_argument("--freeze-specific-settings", type=str, help='disable editing of individual settings by specifying a comma-delimited list like "samples_save,samples_format". The list of setting names can be found in the config.json file', default=None)
|
||||
parser.add_argument("--ui-settings-file", type=str, help="filename to use for ui settings", default=os.path.join(data_path, 'config.json'))
|
||||
parser.add_argument("--gradio-debug", action='store_true', help="launch gradio with --debug option")
|
||||
parser.add_argument("--gradio-auth", type=str, help='set gradio authentication like "username:password"; or comma-delimit multiple like "u1:p1,u2:p2,u3:p3"', default=None)
|
||||
@@ -86,7 +88,7 @@ parser.add_argument("--gradio-img2img-tool", type=str, help='does not do anythin
|
||||
parser.add_argument("--gradio-inpaint-tool", type=str, help="does not do anything")
|
||||
parser.add_argument("--gradio-allowed-path", action='append', help="add path to gradio's allowed_paths, make it possible to serve files from it", default=[data_path])
|
||||
parser.add_argument("--opt-channelslast", action='store_true', help="change memory type for stable diffusion to channels last")
|
||||
parser.add_argument("--styles-file", type=str, help="filename to use for styles", default=os.path.join(data_path, 'styles.csv'))
|
||||
parser.add_argument("--styles-file", type=str, action='append', help="path or wildcard path of styles files, allow multiple entries.", default=[])
|
||||
parser.add_argument("--autolaunch", action='store_true', help="open the webui URL in the system's default browser upon launch", default=False)
|
||||
parser.add_argument("--theme", type=str, help="launches the UI with light or dark theme", default=None)
|
||||
parser.add_argument("--use-textbox-seed", action='store_true', help="use textbox for seeds in UI (no up/down, but possible to input long seeds)", default=False)
|
||||
|
||||
@@ -1,276 +0,0 @@
|
||||
# this file is copied from CodeFormer repository. Please see comment in modules/codeformer_model.py
|
||||
|
||||
import math
|
||||
import torch
|
||||
from torch import nn, Tensor
|
||||
import torch.nn.functional as F
|
||||
from typing import Optional
|
||||
|
||||
from modules.codeformer.vqgan_arch import VQAutoEncoder, ResBlock
|
||||
from basicsr.utils.registry import ARCH_REGISTRY
|
||||
|
||||
def calc_mean_std(feat, eps=1e-5):
|
||||
"""Calculate mean and std for adaptive_instance_normalization.
|
||||
|
||||
Args:
|
||||
feat (Tensor): 4D tensor.
|
||||
eps (float): A small value added to the variance to avoid
|
||||
divide-by-zero. Default: 1e-5.
|
||||
"""
|
||||
size = feat.size()
|
||||
assert len(size) == 4, 'The input feature should be 4D tensor.'
|
||||
b, c = size[:2]
|
||||
feat_var = feat.view(b, c, -1).var(dim=2) + eps
|
||||
feat_std = feat_var.sqrt().view(b, c, 1, 1)
|
||||
feat_mean = feat.view(b, c, -1).mean(dim=2).view(b, c, 1, 1)
|
||||
return feat_mean, feat_std
|
||||
|
||||
|
||||
def adaptive_instance_normalization(content_feat, style_feat):
|
||||
"""Adaptive instance normalization.
|
||||
|
||||
Adjust the reference features to have the similar color and illuminations
|
||||
as those in the degradate features.
|
||||
|
||||
Args:
|
||||
content_feat (Tensor): The reference feature.
|
||||
style_feat (Tensor): The degradate features.
|
||||
"""
|
||||
size = content_feat.size()
|
||||
style_mean, style_std = calc_mean_std(style_feat)
|
||||
content_mean, content_std = calc_mean_std(content_feat)
|
||||
normalized_feat = (content_feat - content_mean.expand(size)) / content_std.expand(size)
|
||||
return normalized_feat * style_std.expand(size) + style_mean.expand(size)
|
||||
|
||||
|
||||
class PositionEmbeddingSine(nn.Module):
|
||||
"""
|
||||
This is a more standard version of the position embedding, very similar to the one
|
||||
used by the Attention is all you need paper, generalized to work on images.
|
||||
"""
|
||||
|
||||
def __init__(self, num_pos_feats=64, temperature=10000, normalize=False, scale=None):
|
||||
super().__init__()
|
||||
self.num_pos_feats = num_pos_feats
|
||||
self.temperature = temperature
|
||||
self.normalize = normalize
|
||||
if scale is not None and normalize is False:
|
||||
raise ValueError("normalize should be True if scale is passed")
|
||||
if scale is None:
|
||||
scale = 2 * math.pi
|
||||
self.scale = scale
|
||||
|
||||
def forward(self, x, mask=None):
|
||||
if mask is None:
|
||||
mask = torch.zeros((x.size(0), x.size(2), x.size(3)), device=x.device, dtype=torch.bool)
|
||||
not_mask = ~mask
|
||||
y_embed = not_mask.cumsum(1, dtype=torch.float32)
|
||||
x_embed = not_mask.cumsum(2, dtype=torch.float32)
|
||||
if self.normalize:
|
||||
eps = 1e-6
|
||||
y_embed = y_embed / (y_embed[:, -1:, :] + eps) * self.scale
|
||||
x_embed = x_embed / (x_embed[:, :, -1:] + eps) * self.scale
|
||||
|
||||
dim_t = torch.arange(self.num_pos_feats, dtype=torch.float32, device=x.device)
|
||||
dim_t = self.temperature ** (2 * (dim_t // 2) / self.num_pos_feats)
|
||||
|
||||
pos_x = x_embed[:, :, :, None] / dim_t
|
||||
pos_y = y_embed[:, :, :, None] / dim_t
|
||||
pos_x = torch.stack(
|
||||
(pos_x[:, :, :, 0::2].sin(), pos_x[:, :, :, 1::2].cos()), dim=4
|
||||
).flatten(3)
|
||||
pos_y = torch.stack(
|
||||
(pos_y[:, :, :, 0::2].sin(), pos_y[:, :, :, 1::2].cos()), dim=4
|
||||
).flatten(3)
|
||||
pos = torch.cat((pos_y, pos_x), dim=3).permute(0, 3, 1, 2)
|
||||
return pos
|
||||
|
||||
def _get_activation_fn(activation):
|
||||
"""Return an activation function given a string"""
|
||||
if activation == "relu":
|
||||
return F.relu
|
||||
if activation == "gelu":
|
||||
return F.gelu
|
||||
if activation == "glu":
|
||||
return F.glu
|
||||
raise RuntimeError(F"activation should be relu/gelu, not {activation}.")
|
||||
|
||||
|
||||
class TransformerSALayer(nn.Module):
|
||||
def __init__(self, embed_dim, nhead=8, dim_mlp=2048, dropout=0.0, activation="gelu"):
|
||||
super().__init__()
|
||||
self.self_attn = nn.MultiheadAttention(embed_dim, nhead, dropout=dropout)
|
||||
# Implementation of Feedforward model - MLP
|
||||
self.linear1 = nn.Linear(embed_dim, dim_mlp)
|
||||
self.dropout = nn.Dropout(dropout)
|
||||
self.linear2 = nn.Linear(dim_mlp, embed_dim)
|
||||
|
||||
self.norm1 = nn.LayerNorm(embed_dim)
|
||||
self.norm2 = nn.LayerNorm(embed_dim)
|
||||
self.dropout1 = nn.Dropout(dropout)
|
||||
self.dropout2 = nn.Dropout(dropout)
|
||||
|
||||
self.activation = _get_activation_fn(activation)
|
||||
|
||||
def with_pos_embed(self, tensor, pos: Optional[Tensor]):
|
||||
return tensor if pos is None else tensor + pos
|
||||
|
||||
def forward(self, tgt,
|
||||
tgt_mask: Optional[Tensor] = None,
|
||||
tgt_key_padding_mask: Optional[Tensor] = None,
|
||||
query_pos: Optional[Tensor] = None):
|
||||
|
||||
# self attention
|
||||
tgt2 = self.norm1(tgt)
|
||||
q = k = self.with_pos_embed(tgt2, query_pos)
|
||||
tgt2 = self.self_attn(q, k, value=tgt2, attn_mask=tgt_mask,
|
||||
key_padding_mask=tgt_key_padding_mask)[0]
|
||||
tgt = tgt + self.dropout1(tgt2)
|
||||
|
||||
# ffn
|
||||
tgt2 = self.norm2(tgt)
|
||||
tgt2 = self.linear2(self.dropout(self.activation(self.linear1(tgt2))))
|
||||
tgt = tgt + self.dropout2(tgt2)
|
||||
return tgt
|
||||
|
||||
class Fuse_sft_block(nn.Module):
|
||||
def __init__(self, in_ch, out_ch):
|
||||
super().__init__()
|
||||
self.encode_enc = ResBlock(2*in_ch, out_ch)
|
||||
|
||||
self.scale = nn.Sequential(
|
||||
nn.Conv2d(in_ch, out_ch, kernel_size=3, padding=1),
|
||||
nn.LeakyReLU(0.2, True),
|
||||
nn.Conv2d(out_ch, out_ch, kernel_size=3, padding=1))
|
||||
|
||||
self.shift = nn.Sequential(
|
||||
nn.Conv2d(in_ch, out_ch, kernel_size=3, padding=1),
|
||||
nn.LeakyReLU(0.2, True),
|
||||
nn.Conv2d(out_ch, out_ch, kernel_size=3, padding=1))
|
||||
|
||||
def forward(self, enc_feat, dec_feat, w=1):
|
||||
enc_feat = self.encode_enc(torch.cat([enc_feat, dec_feat], dim=1))
|
||||
scale = self.scale(enc_feat)
|
||||
shift = self.shift(enc_feat)
|
||||
residual = w * (dec_feat * scale + shift)
|
||||
out = dec_feat + residual
|
||||
return out
|
||||
|
||||
|
||||
@ARCH_REGISTRY.register()
|
||||
class CodeFormer(VQAutoEncoder):
|
||||
def __init__(self, dim_embd=512, n_head=8, n_layers=9,
|
||||
codebook_size=1024, latent_size=256,
|
||||
connect_list=('32', '64', '128', '256'),
|
||||
fix_modules=('quantize', 'generator')):
|
||||
super(CodeFormer, self).__init__(512, 64, [1, 2, 2, 4, 4, 8], 'nearest',2, [16], codebook_size)
|
||||
|
||||
if fix_modules is not None:
|
||||
for module in fix_modules:
|
||||
for param in getattr(self, module).parameters():
|
||||
param.requires_grad = False
|
||||
|
||||
self.connect_list = connect_list
|
||||
self.n_layers = n_layers
|
||||
self.dim_embd = dim_embd
|
||||
self.dim_mlp = dim_embd*2
|
||||
|
||||
self.position_emb = nn.Parameter(torch.zeros(latent_size, self.dim_embd))
|
||||
self.feat_emb = nn.Linear(256, self.dim_embd)
|
||||
|
||||
# transformer
|
||||
self.ft_layers = nn.Sequential(*[TransformerSALayer(embed_dim=dim_embd, nhead=n_head, dim_mlp=self.dim_mlp, dropout=0.0)
|
||||
for _ in range(self.n_layers)])
|
||||
|
||||
# logits_predict head
|
||||
self.idx_pred_layer = nn.Sequential(
|
||||
nn.LayerNorm(dim_embd),
|
||||
nn.Linear(dim_embd, codebook_size, bias=False))
|
||||
|
||||
self.channels = {
|
||||
'16': 512,
|
||||
'32': 256,
|
||||
'64': 256,
|
||||
'128': 128,
|
||||
'256': 128,
|
||||
'512': 64,
|
||||
}
|
||||
|
||||
# after second residual block for > 16, before attn layer for ==16
|
||||
self.fuse_encoder_block = {'512':2, '256':5, '128':8, '64':11, '32':14, '16':18}
|
||||
# after first residual block for > 16, before attn layer for ==16
|
||||
self.fuse_generator_block = {'16':6, '32': 9, '64':12, '128':15, '256':18, '512':21}
|
||||
|
||||
# fuse_convs_dict
|
||||
self.fuse_convs_dict = nn.ModuleDict()
|
||||
for f_size in self.connect_list:
|
||||
in_ch = self.channels[f_size]
|
||||
self.fuse_convs_dict[f_size] = Fuse_sft_block(in_ch, in_ch)
|
||||
|
||||
def _init_weights(self, module):
|
||||
if isinstance(module, (nn.Linear, nn.Embedding)):
|
||||
module.weight.data.normal_(mean=0.0, std=0.02)
|
||||
if isinstance(module, nn.Linear) and module.bias is not None:
|
||||
module.bias.data.zero_()
|
||||
elif isinstance(module, nn.LayerNorm):
|
||||
module.bias.data.zero_()
|
||||
module.weight.data.fill_(1.0)
|
||||
|
||||
def forward(self, x, w=0, detach_16=True, code_only=False, adain=False):
|
||||
# ################### Encoder #####################
|
||||
enc_feat_dict = {}
|
||||
out_list = [self.fuse_encoder_block[f_size] for f_size in self.connect_list]
|
||||
for i, block in enumerate(self.encoder.blocks):
|
||||
x = block(x)
|
||||
if i in out_list:
|
||||
enc_feat_dict[str(x.shape[-1])] = x.clone()
|
||||
|
||||
lq_feat = x
|
||||
# ################# Transformer ###################
|
||||
# quant_feat, codebook_loss, quant_stats = self.quantize(lq_feat)
|
||||
pos_emb = self.position_emb.unsqueeze(1).repeat(1,x.shape[0],1)
|
||||
# BCHW -> BC(HW) -> (HW)BC
|
||||
feat_emb = self.feat_emb(lq_feat.flatten(2).permute(2,0,1))
|
||||
query_emb = feat_emb
|
||||
# Transformer encoder
|
||||
for layer in self.ft_layers:
|
||||
query_emb = layer(query_emb, query_pos=pos_emb)
|
||||
|
||||
# output logits
|
||||
logits = self.idx_pred_layer(query_emb) # (hw)bn
|
||||
logits = logits.permute(1,0,2) # (hw)bn -> b(hw)n
|
||||
|
||||
if code_only: # for training stage II
|
||||
# logits doesn't need softmax before cross_entropy loss
|
||||
return logits, lq_feat
|
||||
|
||||
# ################# Quantization ###################
|
||||
# if self.training:
|
||||
# quant_feat = torch.einsum('btn,nc->btc', [soft_one_hot, self.quantize.embedding.weight])
|
||||
# # b(hw)c -> bc(hw) -> bchw
|
||||
# quant_feat = quant_feat.permute(0,2,1).view(lq_feat.shape)
|
||||
# ------------
|
||||
soft_one_hot = F.softmax(logits, dim=2)
|
||||
_, top_idx = torch.topk(soft_one_hot, 1, dim=2)
|
||||
quant_feat = self.quantize.get_codebook_feat(top_idx, shape=[x.shape[0],16,16,256])
|
||||
# preserve gradients
|
||||
# quant_feat = lq_feat + (quant_feat - lq_feat).detach()
|
||||
|
||||
if detach_16:
|
||||
quant_feat = quant_feat.detach() # for training stage III
|
||||
if adain:
|
||||
quant_feat = adaptive_instance_normalization(quant_feat, lq_feat)
|
||||
|
||||
# ################## Generator ####################
|
||||
x = quant_feat
|
||||
fuse_list = [self.fuse_generator_block[f_size] for f_size in self.connect_list]
|
||||
|
||||
for i, block in enumerate(self.generator.blocks):
|
||||
x = block(x)
|
||||
if i in fuse_list: # fuse after i-th block
|
||||
f_size = str(x.shape[-1])
|
||||
if w>0:
|
||||
x = self.fuse_convs_dict[f_size](enc_feat_dict[f_size].detach(), x, w)
|
||||
out = x
|
||||
# logits doesn't need softmax before cross_entropy loss
|
||||
return out, logits, lq_feat
|
||||
@@ -1,435 +0,0 @@
|
||||
# this file is copied from CodeFormer repository. Please see comment in modules/codeformer_model.py
|
||||
|
||||
'''
|
||||
VQGAN code, adapted from the original created by the Unleashing Transformers authors:
|
||||
https://github.com/samb-t/unleashing-transformers/blob/master/models/vqgan.py
|
||||
|
||||
'''
|
||||
import torch
|
||||
import torch.nn as nn
|
||||
import torch.nn.functional as F
|
||||
from basicsr.utils import get_root_logger
|
||||
from basicsr.utils.registry import ARCH_REGISTRY
|
||||
|
||||
def normalize(in_channels):
|
||||
return torch.nn.GroupNorm(num_groups=32, num_channels=in_channels, eps=1e-6, affine=True)
|
||||
|
||||
|
||||
@torch.jit.script
|
||||
def swish(x):
|
||||
return x*torch.sigmoid(x)
|
||||
|
||||
|
||||
# Define VQVAE classes
|
||||
class VectorQuantizer(nn.Module):
|
||||
def __init__(self, codebook_size, emb_dim, beta):
|
||||
super(VectorQuantizer, self).__init__()
|
||||
self.codebook_size = codebook_size # number of embeddings
|
||||
self.emb_dim = emb_dim # dimension of embedding
|
||||
self.beta = beta # commitment cost used in loss term, beta * ||z_e(x)-sg[e]||^2
|
||||
self.embedding = nn.Embedding(self.codebook_size, self.emb_dim)
|
||||
self.embedding.weight.data.uniform_(-1.0 / self.codebook_size, 1.0 / self.codebook_size)
|
||||
|
||||
def forward(self, z):
|
||||
# reshape z -> (batch, height, width, channel) and flatten
|
||||
z = z.permute(0, 2, 3, 1).contiguous()
|
||||
z_flattened = z.view(-1, self.emb_dim)
|
||||
|
||||
# distances from z to embeddings e_j (z - e)^2 = z^2 + e^2 - 2 e * z
|
||||
d = (z_flattened ** 2).sum(dim=1, keepdim=True) + (self.embedding.weight**2).sum(1) - \
|
||||
2 * torch.matmul(z_flattened, self.embedding.weight.t())
|
||||
|
||||
mean_distance = torch.mean(d)
|
||||
# find closest encodings
|
||||
# min_encoding_indices = torch.argmin(d, dim=1).unsqueeze(1)
|
||||
min_encoding_scores, min_encoding_indices = torch.topk(d, 1, dim=1, largest=False)
|
||||
# [0-1], higher score, higher confidence
|
||||
min_encoding_scores = torch.exp(-min_encoding_scores/10)
|
||||
|
||||
min_encodings = torch.zeros(min_encoding_indices.shape[0], self.codebook_size).to(z)
|
||||
min_encodings.scatter_(1, min_encoding_indices, 1)
|
||||
|
||||
# get quantized latent vectors
|
||||
z_q = torch.matmul(min_encodings, self.embedding.weight).view(z.shape)
|
||||
# compute loss for embedding
|
||||
loss = torch.mean((z_q.detach()-z)**2) + self.beta * torch.mean((z_q - z.detach()) ** 2)
|
||||
# preserve gradients
|
||||
z_q = z + (z_q - z).detach()
|
||||
|
||||
# perplexity
|
||||
e_mean = torch.mean(min_encodings, dim=0)
|
||||
perplexity = torch.exp(-torch.sum(e_mean * torch.log(e_mean + 1e-10)))
|
||||
# reshape back to match original input shape
|
||||
z_q = z_q.permute(0, 3, 1, 2).contiguous()
|
||||
|
||||
return z_q, loss, {
|
||||
"perplexity": perplexity,
|
||||
"min_encodings": min_encodings,
|
||||
"min_encoding_indices": min_encoding_indices,
|
||||
"min_encoding_scores": min_encoding_scores,
|
||||
"mean_distance": mean_distance
|
||||
}
|
||||
|
||||
def get_codebook_feat(self, indices, shape):
|
||||
# input indices: batch*token_num -> (batch*token_num)*1
|
||||
# shape: batch, height, width, channel
|
||||
indices = indices.view(-1,1)
|
||||
min_encodings = torch.zeros(indices.shape[0], self.codebook_size).to(indices)
|
||||
min_encodings.scatter_(1, indices, 1)
|
||||
# get quantized latent vectors
|
||||
z_q = torch.matmul(min_encodings.float(), self.embedding.weight)
|
||||
|
||||
if shape is not None: # reshape back to match original input shape
|
||||
z_q = z_q.view(shape).permute(0, 3, 1, 2).contiguous()
|
||||
|
||||
return z_q
|
||||
|
||||
|
||||
class GumbelQuantizer(nn.Module):
|
||||
def __init__(self, codebook_size, emb_dim, num_hiddens, straight_through=False, kl_weight=5e-4, temp_init=1.0):
|
||||
super().__init__()
|
||||
self.codebook_size = codebook_size # number of embeddings
|
||||
self.emb_dim = emb_dim # dimension of embedding
|
||||
self.straight_through = straight_through
|
||||
self.temperature = temp_init
|
||||
self.kl_weight = kl_weight
|
||||
self.proj = nn.Conv2d(num_hiddens, codebook_size, 1) # projects last encoder layer to quantized logits
|
||||
self.embed = nn.Embedding(codebook_size, emb_dim)
|
||||
|
||||
def forward(self, z):
|
||||
hard = self.straight_through if self.training else True
|
||||
|
||||
logits = self.proj(z)
|
||||
|
||||
soft_one_hot = F.gumbel_softmax(logits, tau=self.temperature, dim=1, hard=hard)
|
||||
|
||||
z_q = torch.einsum("b n h w, n d -> b d h w", soft_one_hot, self.embed.weight)
|
||||
|
||||
# + kl divergence to the prior loss
|
||||
qy = F.softmax(logits, dim=1)
|
||||
diff = self.kl_weight * torch.sum(qy * torch.log(qy * self.codebook_size + 1e-10), dim=1).mean()
|
||||
min_encoding_indices = soft_one_hot.argmax(dim=1)
|
||||
|
||||
return z_q, diff, {
|
||||
"min_encoding_indices": min_encoding_indices
|
||||
}
|
||||
|
||||
|
||||
class Downsample(nn.Module):
|
||||
def __init__(self, in_channels):
|
||||
super().__init__()
|
||||
self.conv = torch.nn.Conv2d(in_channels, in_channels, kernel_size=3, stride=2, padding=0)
|
||||
|
||||
def forward(self, x):
|
||||
pad = (0, 1, 0, 1)
|
||||
x = torch.nn.functional.pad(x, pad, mode="constant", value=0)
|
||||
x = self.conv(x)
|
||||
return x
|
||||
|
||||
|
||||
class Upsample(nn.Module):
|
||||
def __init__(self, in_channels):
|
||||
super().__init__()
|
||||
self.conv = nn.Conv2d(in_channels, in_channels, kernel_size=3, stride=1, padding=1)
|
||||
|
||||
def forward(self, x):
|
||||
x = F.interpolate(x, scale_factor=2.0, mode="nearest")
|
||||
x = self.conv(x)
|
||||
|
||||
return x
|
||||
|
||||
|
||||
class ResBlock(nn.Module):
|
||||
def __init__(self, in_channels, out_channels=None):
|
||||
super(ResBlock, self).__init__()
|
||||
self.in_channels = in_channels
|
||||
self.out_channels = in_channels if out_channels is None else out_channels
|
||||
self.norm1 = normalize(in_channels)
|
||||
self.conv1 = nn.Conv2d(in_channels, out_channels, kernel_size=3, stride=1, padding=1)
|
||||
self.norm2 = normalize(out_channels)
|
||||
self.conv2 = nn.Conv2d(out_channels, out_channels, kernel_size=3, stride=1, padding=1)
|
||||
if self.in_channels != self.out_channels:
|
||||
self.conv_out = nn.Conv2d(in_channels, out_channels, kernel_size=1, stride=1, padding=0)
|
||||
|
||||
def forward(self, x_in):
|
||||
x = x_in
|
||||
x = self.norm1(x)
|
||||
x = swish(x)
|
||||
x = self.conv1(x)
|
||||
x = self.norm2(x)
|
||||
x = swish(x)
|
||||
x = self.conv2(x)
|
||||
if self.in_channels != self.out_channels:
|
||||
x_in = self.conv_out(x_in)
|
||||
|
||||
return x + x_in
|
||||
|
||||
|
||||
class AttnBlock(nn.Module):
|
||||
def __init__(self, in_channels):
|
||||
super().__init__()
|
||||
self.in_channels = in_channels
|
||||
|
||||
self.norm = normalize(in_channels)
|
||||
self.q = torch.nn.Conv2d(
|
||||
in_channels,
|
||||
in_channels,
|
||||
kernel_size=1,
|
||||
stride=1,
|
||||
padding=0
|
||||
)
|
||||
self.k = torch.nn.Conv2d(
|
||||
in_channels,
|
||||
in_channels,
|
||||
kernel_size=1,
|
||||
stride=1,
|
||||
padding=0
|
||||
)
|
||||
self.v = torch.nn.Conv2d(
|
||||
in_channels,
|
||||
in_channels,
|
||||
kernel_size=1,
|
||||
stride=1,
|
||||
padding=0
|
||||
)
|
||||
self.proj_out = torch.nn.Conv2d(
|
||||
in_channels,
|
||||
in_channels,
|
||||
kernel_size=1,
|
||||
stride=1,
|
||||
padding=0
|
||||
)
|
||||
|
||||
def forward(self, x):
|
||||
h_ = x
|
||||
h_ = self.norm(h_)
|
||||
q = self.q(h_)
|
||||
k = self.k(h_)
|
||||
v = self.v(h_)
|
||||
|
||||
# compute attention
|
||||
b, c, h, w = q.shape
|
||||
q = q.reshape(b, c, h*w)
|
||||
q = q.permute(0, 2, 1)
|
||||
k = k.reshape(b, c, h*w)
|
||||
w_ = torch.bmm(q, k)
|
||||
w_ = w_ * (int(c)**(-0.5))
|
||||
w_ = F.softmax(w_, dim=2)
|
||||
|
||||
# attend to values
|
||||
v = v.reshape(b, c, h*w)
|
||||
w_ = w_.permute(0, 2, 1)
|
||||
h_ = torch.bmm(v, w_)
|
||||
h_ = h_.reshape(b, c, h, w)
|
||||
|
||||
h_ = self.proj_out(h_)
|
||||
|
||||
return x+h_
|
||||
|
||||
|
||||
class Encoder(nn.Module):
|
||||
def __init__(self, in_channels, nf, emb_dim, ch_mult, num_res_blocks, resolution, attn_resolutions):
|
||||
super().__init__()
|
||||
self.nf = nf
|
||||
self.num_resolutions = len(ch_mult)
|
||||
self.num_res_blocks = num_res_blocks
|
||||
self.resolution = resolution
|
||||
self.attn_resolutions = attn_resolutions
|
||||
|
||||
curr_res = self.resolution
|
||||
in_ch_mult = (1,)+tuple(ch_mult)
|
||||
|
||||
blocks = []
|
||||
# initial convultion
|
||||
blocks.append(nn.Conv2d(in_channels, nf, kernel_size=3, stride=1, padding=1))
|
||||
|
||||
# residual and downsampling blocks, with attention on smaller res (16x16)
|
||||
for i in range(self.num_resolutions):
|
||||
block_in_ch = nf * in_ch_mult[i]
|
||||
block_out_ch = nf * ch_mult[i]
|
||||
for _ in range(self.num_res_blocks):
|
||||
blocks.append(ResBlock(block_in_ch, block_out_ch))
|
||||
block_in_ch = block_out_ch
|
||||
if curr_res in attn_resolutions:
|
||||
blocks.append(AttnBlock(block_in_ch))
|
||||
|
||||
if i != self.num_resolutions - 1:
|
||||
blocks.append(Downsample(block_in_ch))
|
||||
curr_res = curr_res // 2
|
||||
|
||||
# non-local attention block
|
||||
blocks.append(ResBlock(block_in_ch, block_in_ch))
|
||||
blocks.append(AttnBlock(block_in_ch))
|
||||
blocks.append(ResBlock(block_in_ch, block_in_ch))
|
||||
|
||||
# normalise and convert to latent size
|
||||
blocks.append(normalize(block_in_ch))
|
||||
blocks.append(nn.Conv2d(block_in_ch, emb_dim, kernel_size=3, stride=1, padding=1))
|
||||
self.blocks = nn.ModuleList(blocks)
|
||||
|
||||
def forward(self, x):
|
||||
for block in self.blocks:
|
||||
x = block(x)
|
||||
|
||||
return x
|
||||
|
||||
|
||||
class Generator(nn.Module):
|
||||
def __init__(self, nf, emb_dim, ch_mult, res_blocks, img_size, attn_resolutions):
|
||||
super().__init__()
|
||||
self.nf = nf
|
||||
self.ch_mult = ch_mult
|
||||
self.num_resolutions = len(self.ch_mult)
|
||||
self.num_res_blocks = res_blocks
|
||||
self.resolution = img_size
|
||||
self.attn_resolutions = attn_resolutions
|
||||
self.in_channels = emb_dim
|
||||
self.out_channels = 3
|
||||
block_in_ch = self.nf * self.ch_mult[-1]
|
||||
curr_res = self.resolution // 2 ** (self.num_resolutions-1)
|
||||
|
||||
blocks = []
|
||||
# initial conv
|
||||
blocks.append(nn.Conv2d(self.in_channels, block_in_ch, kernel_size=3, stride=1, padding=1))
|
||||
|
||||
# non-local attention block
|
||||
blocks.append(ResBlock(block_in_ch, block_in_ch))
|
||||
blocks.append(AttnBlock(block_in_ch))
|
||||
blocks.append(ResBlock(block_in_ch, block_in_ch))
|
||||
|
||||
for i in reversed(range(self.num_resolutions)):
|
||||
block_out_ch = self.nf * self.ch_mult[i]
|
||||
|
||||
for _ in range(self.num_res_blocks):
|
||||
blocks.append(ResBlock(block_in_ch, block_out_ch))
|
||||
block_in_ch = block_out_ch
|
||||
|
||||
if curr_res in self.attn_resolutions:
|
||||
blocks.append(AttnBlock(block_in_ch))
|
||||
|
||||
if i != 0:
|
||||
blocks.append(Upsample(block_in_ch))
|
||||
curr_res = curr_res * 2
|
||||
|
||||
blocks.append(normalize(block_in_ch))
|
||||
blocks.append(nn.Conv2d(block_in_ch, self.out_channels, kernel_size=3, stride=1, padding=1))
|
||||
|
||||
self.blocks = nn.ModuleList(blocks)
|
||||
|
||||
|
||||
def forward(self, x):
|
||||
for block in self.blocks:
|
||||
x = block(x)
|
||||
|
||||
return x
|
||||
|
||||
|
||||
@ARCH_REGISTRY.register()
|
||||
class VQAutoEncoder(nn.Module):
|
||||
def __init__(self, img_size, nf, ch_mult, quantizer="nearest", res_blocks=2, attn_resolutions=None, codebook_size=1024, emb_dim=256,
|
||||
beta=0.25, gumbel_straight_through=False, gumbel_kl_weight=1e-8, model_path=None):
|
||||
super().__init__()
|
||||
logger = get_root_logger()
|
||||
self.in_channels = 3
|
||||
self.nf = nf
|
||||
self.n_blocks = res_blocks
|
||||
self.codebook_size = codebook_size
|
||||
self.embed_dim = emb_dim
|
||||
self.ch_mult = ch_mult
|
||||
self.resolution = img_size
|
||||
self.attn_resolutions = attn_resolutions or [16]
|
||||
self.quantizer_type = quantizer
|
||||
self.encoder = Encoder(
|
||||
self.in_channels,
|
||||
self.nf,
|
||||
self.embed_dim,
|
||||
self.ch_mult,
|
||||
self.n_blocks,
|
||||
self.resolution,
|
||||
self.attn_resolutions
|
||||
)
|
||||
if self.quantizer_type == "nearest":
|
||||
self.beta = beta #0.25
|
||||
self.quantize = VectorQuantizer(self.codebook_size, self.embed_dim, self.beta)
|
||||
elif self.quantizer_type == "gumbel":
|
||||
self.gumbel_num_hiddens = emb_dim
|
||||
self.straight_through = gumbel_straight_through
|
||||
self.kl_weight = gumbel_kl_weight
|
||||
self.quantize = GumbelQuantizer(
|
||||
self.codebook_size,
|
||||
self.embed_dim,
|
||||
self.gumbel_num_hiddens,
|
||||
self.straight_through,
|
||||
self.kl_weight
|
||||
)
|
||||
self.generator = Generator(
|
||||
self.nf,
|
||||
self.embed_dim,
|
||||
self.ch_mult,
|
||||
self.n_blocks,
|
||||
self.resolution,
|
||||
self.attn_resolutions
|
||||
)
|
||||
|
||||
if model_path is not None:
|
||||
chkpt = torch.load(model_path, map_location='cpu')
|
||||
if 'params_ema' in chkpt:
|
||||
self.load_state_dict(torch.load(model_path, map_location='cpu')['params_ema'])
|
||||
logger.info(f'vqgan is loaded from: {model_path} [params_ema]')
|
||||
elif 'params' in chkpt:
|
||||
self.load_state_dict(torch.load(model_path, map_location='cpu')['params'])
|
||||
logger.info(f'vqgan is loaded from: {model_path} [params]')
|
||||
else:
|
||||
raise ValueError('Wrong params!')
|
||||
|
||||
|
||||
def forward(self, x):
|
||||
x = self.encoder(x)
|
||||
quant, codebook_loss, quant_stats = self.quantize(x)
|
||||
x = self.generator(quant)
|
||||
return x, codebook_loss, quant_stats
|
||||
|
||||
|
||||
|
||||
# patch based discriminator
|
||||
@ARCH_REGISTRY.register()
|
||||
class VQGANDiscriminator(nn.Module):
|
||||
def __init__(self, nc=3, ndf=64, n_layers=4, model_path=None):
|
||||
super().__init__()
|
||||
|
||||
layers = [nn.Conv2d(nc, ndf, kernel_size=4, stride=2, padding=1), nn.LeakyReLU(0.2, True)]
|
||||
ndf_mult = 1
|
||||
ndf_mult_prev = 1
|
||||
for n in range(1, n_layers): # gradually increase the number of filters
|
||||
ndf_mult_prev = ndf_mult
|
||||
ndf_mult = min(2 ** n, 8)
|
||||
layers += [
|
||||
nn.Conv2d(ndf * ndf_mult_prev, ndf * ndf_mult, kernel_size=4, stride=2, padding=1, bias=False),
|
||||
nn.BatchNorm2d(ndf * ndf_mult),
|
||||
nn.LeakyReLU(0.2, True)
|
||||
]
|
||||
|
||||
ndf_mult_prev = ndf_mult
|
||||
ndf_mult = min(2 ** n_layers, 8)
|
||||
|
||||
layers += [
|
||||
nn.Conv2d(ndf * ndf_mult_prev, ndf * ndf_mult, kernel_size=4, stride=1, padding=1, bias=False),
|
||||
nn.BatchNorm2d(ndf * ndf_mult),
|
||||
nn.LeakyReLU(0.2, True)
|
||||
]
|
||||
|
||||
layers += [
|
||||
nn.Conv2d(ndf * ndf_mult, 1, kernel_size=4, stride=1, padding=1)] # output 1 channel prediction map
|
||||
self.main = nn.Sequential(*layers)
|
||||
|
||||
if model_path is not None:
|
||||
chkpt = torch.load(model_path, map_location='cpu')
|
||||
if 'params_d' in chkpt:
|
||||
self.load_state_dict(torch.load(model_path, map_location='cpu')['params_d'])
|
||||
elif 'params' in chkpt:
|
||||
self.load_state_dict(torch.load(model_path, map_location='cpu')['params'])
|
||||
else:
|
||||
raise ValueError('Wrong params!')
|
||||
|
||||
def forward(self, x):
|
||||
return self.main(x)
|
||||
@@ -1,132 +1,64 @@
|
||||
import os
|
||||
from __future__ import annotations
|
||||
|
||||
import logging
|
||||
|
||||
import cv2
|
||||
import torch
|
||||
|
||||
import modules.face_restoration
|
||||
import modules.shared
|
||||
from modules import shared, devices, modelloader, errors
|
||||
from modules.paths import models_path
|
||||
from modules import (
|
||||
devices,
|
||||
errors,
|
||||
face_restoration,
|
||||
face_restoration_utils,
|
||||
modelloader,
|
||||
shared,
|
||||
)
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
# codeformer people made a choice to include modified basicsr library to their project which makes
|
||||
# it utterly impossible to use it alongside with other libraries that also use basicsr, like GFPGAN.
|
||||
# I am making a choice to include some files from codeformer to work around this issue.
|
||||
model_dir = "Codeformer"
|
||||
model_path = os.path.join(models_path, model_dir)
|
||||
model_url = 'https://github.com/sczhou/CodeFormer/releases/download/v0.1.0/codeformer.pth'
|
||||
model_download_name = 'codeformer-v0.1.0.pth'
|
||||
|
||||
codeformer = None
|
||||
# used by e.g. postprocessing_codeformer.py
|
||||
codeformer: face_restoration.FaceRestoration | None = None
|
||||
|
||||
|
||||
def setup_model(dirname):
|
||||
os.makedirs(model_path, exist_ok=True)
|
||||
class FaceRestorerCodeFormer(face_restoration_utils.CommonFaceRestoration):
|
||||
def name(self):
|
||||
return "CodeFormer"
|
||||
|
||||
path = modules.paths.paths.get("CodeFormer", None)
|
||||
if path is None:
|
||||
return
|
||||
def load_net(self) -> torch.Module:
|
||||
for model_path in modelloader.load_models(
|
||||
model_path=self.model_path,
|
||||
model_url=model_url,
|
||||
command_path=self.model_path,
|
||||
download_name=model_download_name,
|
||||
ext_filter=['.pth'],
|
||||
):
|
||||
return modelloader.load_spandrel_model(
|
||||
model_path,
|
||||
device=devices.device_codeformer,
|
||||
expected_architecture='CodeFormer',
|
||||
).model
|
||||
raise ValueError("No codeformer model found")
|
||||
|
||||
def get_device(self):
|
||||
return devices.device_codeformer
|
||||
|
||||
def restore(self, np_image, w: float | None = None):
|
||||
if w is None:
|
||||
w = getattr(shared.opts, "code_former_weight", 0.5)
|
||||
|
||||
def restore_face(cropped_face_t):
|
||||
assert self.net is not None
|
||||
return self.net(cropped_face_t, w=w, adain=True)[0]
|
||||
|
||||
return self.restore_with_helper(np_image, restore_face)
|
||||
|
||||
|
||||
def setup_model(dirname: str) -> None:
|
||||
global codeformer
|
||||
try:
|
||||
from torchvision.transforms.functional import normalize
|
||||
from modules.codeformer.codeformer_arch import CodeFormer
|
||||
from basicsr.utils import img2tensor, tensor2img
|
||||
from facelib.utils.face_restoration_helper import FaceRestoreHelper
|
||||
from facelib.detection.retinaface import retinaface
|
||||
|
||||
net_class = CodeFormer
|
||||
|
||||
class FaceRestorerCodeFormer(modules.face_restoration.FaceRestoration):
|
||||
def name(self):
|
||||
return "CodeFormer"
|
||||
|
||||
def __init__(self, dirname):
|
||||
self.net = None
|
||||
self.face_helper = None
|
||||
self.cmd_dir = dirname
|
||||
|
||||
def create_models(self):
|
||||
|
||||
if self.net is not None and self.face_helper is not None:
|
||||
self.net.to(devices.device_codeformer)
|
||||
return self.net, self.face_helper
|
||||
model_paths = modelloader.load_models(model_path, model_url, self.cmd_dir, download_name='codeformer-v0.1.0.pth', ext_filter=['.pth'])
|
||||
if len(model_paths) != 0:
|
||||
ckpt_path = model_paths[0]
|
||||
else:
|
||||
print("Unable to load codeformer model.")
|
||||
return None, None
|
||||
net = net_class(dim_embd=512, codebook_size=1024, n_head=8, n_layers=9, connect_list=['32', '64', '128', '256']).to(devices.device_codeformer)
|
||||
checkpoint = torch.load(ckpt_path)['params_ema']
|
||||
net.load_state_dict(checkpoint)
|
||||
net.eval()
|
||||
|
||||
if hasattr(retinaface, 'device'):
|
||||
retinaface.device = devices.device_codeformer
|
||||
face_helper = FaceRestoreHelper(1, face_size=512, crop_ratio=(1, 1), det_model='retinaface_resnet50', save_ext='png', use_parse=True, device=devices.device_codeformer)
|
||||
|
||||
self.net = net
|
||||
self.face_helper = face_helper
|
||||
|
||||
return net, face_helper
|
||||
|
||||
def send_model_to(self, device):
|
||||
self.net.to(device)
|
||||
self.face_helper.face_det.to(device)
|
||||
self.face_helper.face_parse.to(device)
|
||||
|
||||
def restore(self, np_image, w=None):
|
||||
np_image = np_image[:, :, ::-1]
|
||||
|
||||
original_resolution = np_image.shape[0:2]
|
||||
|
||||
self.create_models()
|
||||
if self.net is None or self.face_helper is None:
|
||||
return np_image
|
||||
|
||||
self.send_model_to(devices.device_codeformer)
|
||||
|
||||
self.face_helper.clean_all()
|
||||
self.face_helper.read_image(np_image)
|
||||
self.face_helper.get_face_landmarks_5(only_center_face=False, resize=640, eye_dist_threshold=5)
|
||||
self.face_helper.align_warp_face()
|
||||
|
||||
for cropped_face in self.face_helper.cropped_faces:
|
||||
cropped_face_t = img2tensor(cropped_face / 255., bgr2rgb=True, float32=True)
|
||||
normalize(cropped_face_t, (0.5, 0.5, 0.5), (0.5, 0.5, 0.5), inplace=True)
|
||||
cropped_face_t = cropped_face_t.unsqueeze(0).to(devices.device_codeformer)
|
||||
|
||||
try:
|
||||
with torch.no_grad():
|
||||
output = self.net(cropped_face_t, w=w if w is not None else shared.opts.code_former_weight, adain=True)[0]
|
||||
restored_face = tensor2img(output, rgb2bgr=True, min_max=(-1, 1))
|
||||
del output
|
||||
devices.torch_gc()
|
||||
except Exception:
|
||||
errors.report('Failed inference for CodeFormer', exc_info=True)
|
||||
restored_face = tensor2img(cropped_face_t, rgb2bgr=True, min_max=(-1, 1))
|
||||
|
||||
restored_face = restored_face.astype('uint8')
|
||||
self.face_helper.add_restored_face(restored_face)
|
||||
|
||||
self.face_helper.get_inverse_affine(None)
|
||||
|
||||
restored_img = self.face_helper.paste_faces_to_input_image()
|
||||
restored_img = restored_img[:, :, ::-1]
|
||||
|
||||
if original_resolution != restored_img.shape[0:2]:
|
||||
restored_img = cv2.resize(restored_img, (0, 0), fx=original_resolution[1]/restored_img.shape[1], fy=original_resolution[0]/restored_img.shape[0], interpolation=cv2.INTER_LINEAR)
|
||||
|
||||
self.face_helper.clean_all()
|
||||
|
||||
if shared.opts.face_restoration_unload:
|
||||
self.send_model_to(devices.cpu)
|
||||
|
||||
return restored_img
|
||||
|
||||
global codeformer
|
||||
codeformer = FaceRestorerCodeFormer(dirname)
|
||||
shared.face_restorers.append(codeformer)
|
||||
|
||||
except Exception:
|
||||
errors.report("Error setting up CodeFormer", exc_info=True)
|
||||
|
||||
# sys.path = stored_sys_path
|
||||
|
||||
79
modules/dat_model.py
Normal file
79
modules/dat_model.py
Normal file
@@ -0,0 +1,79 @@
|
||||
import os
|
||||
|
||||
from modules import modelloader, errors
|
||||
from modules.shared import cmd_opts, opts
|
||||
from modules.upscaler import Upscaler, UpscalerData
|
||||
from modules.upscaler_utils import upscale_with_model
|
||||
|
||||
|
||||
class UpscalerDAT(Upscaler):
|
||||
def __init__(self, user_path):
|
||||
self.name = "DAT"
|
||||
self.user_path = user_path
|
||||
self.scalers = []
|
||||
super().__init__()
|
||||
|
||||
for file in self.find_models(ext_filter=[".pt", ".pth"]):
|
||||
name = modelloader.friendly_name(file)
|
||||
scaler_data = UpscalerData(name, file, upscaler=self, scale=None)
|
||||
self.scalers.append(scaler_data)
|
||||
|
||||
for model in get_dat_models(self):
|
||||
if model.name in opts.dat_enabled_models:
|
||||
self.scalers.append(model)
|
||||
|
||||
def do_upscale(self, img, path):
|
||||
try:
|
||||
info = self.load_model(path)
|
||||
except Exception:
|
||||
errors.report(f"Unable to load DAT model {path}", exc_info=True)
|
||||
return img
|
||||
|
||||
model_descriptor = modelloader.load_spandrel_model(
|
||||
info.local_data_path,
|
||||
device=self.device,
|
||||
prefer_half=(not cmd_opts.no_half and not cmd_opts.upcast_sampling),
|
||||
expected_architecture="DAT",
|
||||
)
|
||||
return upscale_with_model(
|
||||
model_descriptor,
|
||||
img,
|
||||
tile_size=opts.DAT_tile,
|
||||
tile_overlap=opts.DAT_tile_overlap,
|
||||
)
|
||||
|
||||
def load_model(self, path):
|
||||
for scaler in self.scalers:
|
||||
if scaler.data_path == path:
|
||||
if scaler.local_data_path.startswith("http"):
|
||||
scaler.local_data_path = modelloader.load_file_from_url(
|
||||
scaler.data_path,
|
||||
model_dir=self.model_download_path,
|
||||
)
|
||||
if not os.path.exists(scaler.local_data_path):
|
||||
raise FileNotFoundError(f"DAT data missing: {scaler.local_data_path}")
|
||||
return scaler
|
||||
raise ValueError(f"Unable to find model info: {path}")
|
||||
|
||||
|
||||
def get_dat_models(scaler):
|
||||
return [
|
||||
UpscalerData(
|
||||
name="DAT x2",
|
||||
path="https://github.com/n0kovo/dat_upscaler_models/raw/main/DAT/DAT_x2.pth",
|
||||
scale=2,
|
||||
upscaler=scaler,
|
||||
),
|
||||
UpscalerData(
|
||||
name="DAT x3",
|
||||
path="https://github.com/n0kovo/dat_upscaler_models/raw/main/DAT/DAT_x3.pth",
|
||||
scale=3,
|
||||
upscaler=scaler,
|
||||
),
|
||||
UpscalerData(
|
||||
name="DAT x4",
|
||||
path="https://github.com/n0kovo/dat_upscaler_models/raw/main/DAT/DAT_x4.pth",
|
||||
scale=4,
|
||||
upscaler=scaler,
|
||||
),
|
||||
]
|
||||
@@ -23,6 +23,23 @@ def has_mps() -> bool:
|
||||
return mac_specific.has_mps
|
||||
|
||||
|
||||
def cuda_no_autocast(device_id=None) -> bool:
|
||||
if device_id is None:
|
||||
device_id = get_cuda_device_id()
|
||||
return (
|
||||
torch.cuda.get_device_capability(device_id) == (7, 5)
|
||||
and torch.cuda.get_device_name(device_id).startswith("NVIDIA GeForce GTX 16")
|
||||
)
|
||||
|
||||
|
||||
def get_cuda_device_id():
|
||||
return (
|
||||
int(shared.cmd_opts.device_id)
|
||||
if shared.cmd_opts.device_id is not None and shared.cmd_opts.device_id.isdigit()
|
||||
else 0
|
||||
) or torch.cuda.current_device()
|
||||
|
||||
|
||||
def get_cuda_device_string():
|
||||
if shared.cmd_opts.device_id is not None:
|
||||
return f"cuda:{shared.cmd_opts.device_id}"
|
||||
@@ -79,8 +96,7 @@ def enable_tf32():
|
||||
|
||||
# enabling benchmark option seems to enable a range of cards to do fp16 when they otherwise can't
|
||||
# see https://github.com/AUTOMATIC1111/stable-diffusion-webui/pull/4407
|
||||
device_id = (int(shared.cmd_opts.device_id) if shared.cmd_opts.device_id is not None and shared.cmd_opts.device_id.isdigit() else 0) or torch.cuda.current_device()
|
||||
if torch.cuda.get_device_capability(device_id) == (7, 5) and torch.cuda.get_device_name(device_id).startswith("NVIDIA GeForce GTX 16"):
|
||||
if cuda_no_autocast():
|
||||
torch.backends.cudnn.benchmark = True
|
||||
|
||||
torch.backends.cuda.matmul.allow_tf32 = True
|
||||
@@ -90,6 +106,7 @@ def enable_tf32():
|
||||
errors.run(enable_tf32, "Enabling TF32")
|
||||
|
||||
cpu: torch.device = torch.device("cpu")
|
||||
fp8: bool = False
|
||||
device: torch.device = None
|
||||
device_interrogate: torch.device = None
|
||||
device_gfpgan: torch.device = None
|
||||
@@ -98,6 +115,7 @@ device_codeformer: torch.device = None
|
||||
dtype: torch.dtype = torch.float16
|
||||
dtype_vae: torch.dtype = torch.float16
|
||||
dtype_unet: torch.dtype = torch.float16
|
||||
dtype_inference: torch.dtype = torch.float16
|
||||
unet_needs_upcast = False
|
||||
|
||||
|
||||
@@ -110,15 +128,89 @@ def cond_cast_float(input):
|
||||
|
||||
|
||||
nv_rng = None
|
||||
patch_module_list = [
|
||||
torch.nn.Linear,
|
||||
torch.nn.Conv2d,
|
||||
torch.nn.MultiheadAttention,
|
||||
torch.nn.GroupNorm,
|
||||
torch.nn.LayerNorm,
|
||||
]
|
||||
|
||||
|
||||
def manual_cast_forward(target_dtype):
|
||||
def forward_wrapper(self, *args, **kwargs):
|
||||
if any(
|
||||
isinstance(arg, torch.Tensor) and arg.dtype != target_dtype
|
||||
for arg in args
|
||||
):
|
||||
args = [arg.to(target_dtype) if isinstance(arg, torch.Tensor) else arg for arg in args]
|
||||
kwargs = {k: v.to(target_dtype) if isinstance(v, torch.Tensor) else v for k, v in kwargs.items()}
|
||||
|
||||
org_dtype = target_dtype
|
||||
for param in self.parameters():
|
||||
if param.dtype != target_dtype:
|
||||
org_dtype = param.dtype
|
||||
break
|
||||
|
||||
if org_dtype != target_dtype:
|
||||
self.to(target_dtype)
|
||||
result = self.org_forward(*args, **kwargs)
|
||||
if org_dtype != target_dtype:
|
||||
self.to(org_dtype)
|
||||
|
||||
if target_dtype != dtype_inference:
|
||||
if isinstance(result, tuple):
|
||||
result = tuple(
|
||||
i.to(dtype_inference)
|
||||
if isinstance(i, torch.Tensor)
|
||||
else i
|
||||
for i in result
|
||||
)
|
||||
elif isinstance(result, torch.Tensor):
|
||||
result = result.to(dtype_inference)
|
||||
return result
|
||||
return forward_wrapper
|
||||
|
||||
|
||||
@contextlib.contextmanager
|
||||
def manual_cast(target_dtype):
|
||||
applied = False
|
||||
for module_type in patch_module_list:
|
||||
if hasattr(module_type, "org_forward"):
|
||||
continue
|
||||
applied = True
|
||||
org_forward = module_type.forward
|
||||
if module_type == torch.nn.MultiheadAttention:
|
||||
module_type.forward = manual_cast_forward(torch.float32)
|
||||
else:
|
||||
module_type.forward = manual_cast_forward(target_dtype)
|
||||
module_type.org_forward = org_forward
|
||||
try:
|
||||
yield None
|
||||
finally:
|
||||
if applied:
|
||||
for module_type in patch_module_list:
|
||||
if hasattr(module_type, "org_forward"):
|
||||
module_type.forward = module_type.org_forward
|
||||
delattr(module_type, "org_forward")
|
||||
|
||||
|
||||
def autocast(disable=False):
|
||||
if disable:
|
||||
return contextlib.nullcontext()
|
||||
|
||||
if dtype == torch.float32 or shared.cmd_opts.precision == "full":
|
||||
if fp8 and device==cpu:
|
||||
return torch.autocast("cpu", dtype=torch.bfloat16, enabled=True)
|
||||
|
||||
if fp8 and dtype_inference == torch.float32:
|
||||
return manual_cast(dtype)
|
||||
|
||||
if dtype == torch.float32 or dtype_inference == torch.float32:
|
||||
return contextlib.nullcontext()
|
||||
|
||||
if has_xpu() or has_mps() or cuda_no_autocast():
|
||||
return manual_cast(dtype)
|
||||
|
||||
return torch.autocast("cuda")
|
||||
|
||||
|
||||
|
||||
@@ -107,8 +107,8 @@ def check_versions():
|
||||
import torch
|
||||
import gradio
|
||||
|
||||
expected_torch_version = "2.0.0"
|
||||
expected_xformers_version = "0.0.20"
|
||||
expected_torch_version = "2.1.2"
|
||||
expected_xformers_version = "0.0.23.post1"
|
||||
expected_gradio_version = "3.41.2"
|
||||
|
||||
if version.parse(torch.__version__) < version.parse(expected_torch_version):
|
||||
|
||||
@@ -1,121 +1,7 @@
|
||||
import sys
|
||||
|
||||
import numpy as np
|
||||
import torch
|
||||
from PIL import Image
|
||||
|
||||
import modules.esrgan_model_arch as arch
|
||||
from modules import modelloader, images, devices
|
||||
from modules import modelloader, devices, errors
|
||||
from modules.shared import opts
|
||||
from modules.upscaler import Upscaler, UpscalerData
|
||||
|
||||
|
||||
def mod2normal(state_dict):
|
||||
# this code is copied from https://github.com/victorca25/iNNfer
|
||||
if 'conv_first.weight' in state_dict:
|
||||
crt_net = {}
|
||||
items = list(state_dict)
|
||||
|
||||
crt_net['model.0.weight'] = state_dict['conv_first.weight']
|
||||
crt_net['model.0.bias'] = state_dict['conv_first.bias']
|
||||
|
||||
for k in items.copy():
|
||||
if 'RDB' in k:
|
||||
ori_k = k.replace('RRDB_trunk.', 'model.1.sub.')
|
||||
if '.weight' in k:
|
||||
ori_k = ori_k.replace('.weight', '.0.weight')
|
||||
elif '.bias' in k:
|
||||
ori_k = ori_k.replace('.bias', '.0.bias')
|
||||
crt_net[ori_k] = state_dict[k]
|
||||
items.remove(k)
|
||||
|
||||
crt_net['model.1.sub.23.weight'] = state_dict['trunk_conv.weight']
|
||||
crt_net['model.1.sub.23.bias'] = state_dict['trunk_conv.bias']
|
||||
crt_net['model.3.weight'] = state_dict['upconv1.weight']
|
||||
crt_net['model.3.bias'] = state_dict['upconv1.bias']
|
||||
crt_net['model.6.weight'] = state_dict['upconv2.weight']
|
||||
crt_net['model.6.bias'] = state_dict['upconv2.bias']
|
||||
crt_net['model.8.weight'] = state_dict['HRconv.weight']
|
||||
crt_net['model.8.bias'] = state_dict['HRconv.bias']
|
||||
crt_net['model.10.weight'] = state_dict['conv_last.weight']
|
||||
crt_net['model.10.bias'] = state_dict['conv_last.bias']
|
||||
state_dict = crt_net
|
||||
return state_dict
|
||||
|
||||
|
||||
def resrgan2normal(state_dict, nb=23):
|
||||
# this code is copied from https://github.com/victorca25/iNNfer
|
||||
if "conv_first.weight" in state_dict and "body.0.rdb1.conv1.weight" in state_dict:
|
||||
re8x = 0
|
||||
crt_net = {}
|
||||
items = list(state_dict)
|
||||
|
||||
crt_net['model.0.weight'] = state_dict['conv_first.weight']
|
||||
crt_net['model.0.bias'] = state_dict['conv_first.bias']
|
||||
|
||||
for k in items.copy():
|
||||
if "rdb" in k:
|
||||
ori_k = k.replace('body.', 'model.1.sub.')
|
||||
ori_k = ori_k.replace('.rdb', '.RDB')
|
||||
if '.weight' in k:
|
||||
ori_k = ori_k.replace('.weight', '.0.weight')
|
||||
elif '.bias' in k:
|
||||
ori_k = ori_k.replace('.bias', '.0.bias')
|
||||
crt_net[ori_k] = state_dict[k]
|
||||
items.remove(k)
|
||||
|
||||
crt_net[f'model.1.sub.{nb}.weight'] = state_dict['conv_body.weight']
|
||||
crt_net[f'model.1.sub.{nb}.bias'] = state_dict['conv_body.bias']
|
||||
crt_net['model.3.weight'] = state_dict['conv_up1.weight']
|
||||
crt_net['model.3.bias'] = state_dict['conv_up1.bias']
|
||||
crt_net['model.6.weight'] = state_dict['conv_up2.weight']
|
||||
crt_net['model.6.bias'] = state_dict['conv_up2.bias']
|
||||
|
||||
if 'conv_up3.weight' in state_dict:
|
||||
# modification supporting: https://github.com/ai-forever/Real-ESRGAN/blob/main/RealESRGAN/rrdbnet_arch.py
|
||||
re8x = 3
|
||||
crt_net['model.9.weight'] = state_dict['conv_up3.weight']
|
||||
crt_net['model.9.bias'] = state_dict['conv_up3.bias']
|
||||
|
||||
crt_net[f'model.{8+re8x}.weight'] = state_dict['conv_hr.weight']
|
||||
crt_net[f'model.{8+re8x}.bias'] = state_dict['conv_hr.bias']
|
||||
crt_net[f'model.{10+re8x}.weight'] = state_dict['conv_last.weight']
|
||||
crt_net[f'model.{10+re8x}.bias'] = state_dict['conv_last.bias']
|
||||
|
||||
state_dict = crt_net
|
||||
return state_dict
|
||||
|
||||
|
||||
def infer_params(state_dict):
|
||||
# this code is copied from https://github.com/victorca25/iNNfer
|
||||
scale2x = 0
|
||||
scalemin = 6
|
||||
n_uplayer = 0
|
||||
plus = False
|
||||
|
||||
for block in list(state_dict):
|
||||
parts = block.split(".")
|
||||
n_parts = len(parts)
|
||||
if n_parts == 5 and parts[2] == "sub":
|
||||
nb = int(parts[3])
|
||||
elif n_parts == 3:
|
||||
part_num = int(parts[1])
|
||||
if (part_num > scalemin
|
||||
and parts[0] == "model"
|
||||
and parts[2] == "weight"):
|
||||
scale2x += 1
|
||||
if part_num > n_uplayer:
|
||||
n_uplayer = part_num
|
||||
out_nc = state_dict[block].shape[0]
|
||||
if not plus and "conv1x1" in block:
|
||||
plus = True
|
||||
|
||||
nf = state_dict["model.0.weight"].shape[0]
|
||||
in_nc = state_dict["model.0.weight"].shape[1]
|
||||
out_nc = out_nc
|
||||
scale = 2 ** scale2x
|
||||
|
||||
return in_nc, out_nc, nf, nb, plus, scale
|
||||
from modules.upscaler_utils import upscale_with_model
|
||||
|
||||
|
||||
class UpscalerESRGAN(Upscaler):
|
||||
@@ -143,12 +29,11 @@ class UpscalerESRGAN(Upscaler):
|
||||
def do_upscale(self, img, selected_model):
|
||||
try:
|
||||
model = self.load_model(selected_model)
|
||||
except Exception as e:
|
||||
print(f"Unable to load ESRGAN model {selected_model}: {e}", file=sys.stderr)
|
||||
except Exception:
|
||||
errors.report(f"Unable to load ESRGAN model {selected_model}", exc_info=True)
|
||||
return img
|
||||
model.to(devices.device_esrgan)
|
||||
img = esrgan_upscale(model, img)
|
||||
return img
|
||||
return esrgan_upscale(model, img)
|
||||
|
||||
def load_model(self, path: str):
|
||||
if path.startswith("http"):
|
||||
@@ -161,69 +46,17 @@ class UpscalerESRGAN(Upscaler):
|
||||
else:
|
||||
filename = path
|
||||
|
||||
state_dict = torch.load(filename, map_location='cpu' if devices.device_esrgan.type == 'mps' else None)
|
||||
|
||||
if "params_ema" in state_dict:
|
||||
state_dict = state_dict["params_ema"]
|
||||
elif "params" in state_dict:
|
||||
state_dict = state_dict["params"]
|
||||
num_conv = 16 if "realesr-animevideov3" in filename else 32
|
||||
model = arch.SRVGGNetCompact(num_in_ch=3, num_out_ch=3, num_feat=64, num_conv=num_conv, upscale=4, act_type='prelu')
|
||||
model.load_state_dict(state_dict)
|
||||
model.eval()
|
||||
return model
|
||||
|
||||
if "body.0.rdb1.conv1.weight" in state_dict and "conv_first.weight" in state_dict:
|
||||
nb = 6 if "RealESRGAN_x4plus_anime_6B" in filename else 23
|
||||
state_dict = resrgan2normal(state_dict, nb)
|
||||
elif "conv_first.weight" in state_dict:
|
||||
state_dict = mod2normal(state_dict)
|
||||
elif "model.0.weight" not in state_dict:
|
||||
raise Exception("The file is not a recognized ESRGAN model.")
|
||||
|
||||
in_nc, out_nc, nf, nb, plus, mscale = infer_params(state_dict)
|
||||
|
||||
model = arch.RRDBNet(in_nc=in_nc, out_nc=out_nc, nf=nf, nb=nb, upscale=mscale, plus=plus)
|
||||
model.load_state_dict(state_dict)
|
||||
model.eval()
|
||||
|
||||
return model
|
||||
|
||||
|
||||
def upscale_without_tiling(model, img):
|
||||
img = np.array(img)
|
||||
img = img[:, :, ::-1]
|
||||
img = np.ascontiguousarray(np.transpose(img, (2, 0, 1))) / 255
|
||||
img = torch.from_numpy(img).float()
|
||||
img = img.unsqueeze(0).to(devices.device_esrgan)
|
||||
with torch.no_grad():
|
||||
output = model(img)
|
||||
output = output.squeeze().float().cpu().clamp_(0, 1).numpy()
|
||||
output = 255. * np.moveaxis(output, 0, 2)
|
||||
output = output.astype(np.uint8)
|
||||
output = output[:, :, ::-1]
|
||||
return Image.fromarray(output, 'RGB')
|
||||
return modelloader.load_spandrel_model(
|
||||
filename,
|
||||
device=('cpu' if devices.device_esrgan.type == 'mps' else None),
|
||||
expected_architecture='ESRGAN',
|
||||
)
|
||||
|
||||
|
||||
def esrgan_upscale(model, img):
|
||||
if opts.ESRGAN_tile == 0:
|
||||
return upscale_without_tiling(model, img)
|
||||
|
||||
grid = images.split_grid(img, opts.ESRGAN_tile, opts.ESRGAN_tile, opts.ESRGAN_tile_overlap)
|
||||
newtiles = []
|
||||
scale_factor = 1
|
||||
|
||||
for y, h, row in grid.tiles:
|
||||
newrow = []
|
||||
for tiledata in row:
|
||||
x, w, tile = tiledata
|
||||
|
||||
output = upscale_without_tiling(model, tile)
|
||||
scale_factor = output.width // tile.width
|
||||
|
||||
newrow.append([x * scale_factor, w * scale_factor, output])
|
||||
newtiles.append([y * scale_factor, h * scale_factor, newrow])
|
||||
|
||||
newgrid = images.Grid(newtiles, grid.tile_w * scale_factor, grid.tile_h * scale_factor, grid.image_w * scale_factor, grid.image_h * scale_factor, grid.overlap * scale_factor)
|
||||
output = images.combine_grid(newgrid)
|
||||
return output
|
||||
return upscale_with_model(
|
||||
model,
|
||||
img,
|
||||
tile_size=opts.ESRGAN_tile,
|
||||
tile_overlap=opts.ESRGAN_tile_overlap,
|
||||
)
|
||||
|
||||
@@ -1,465 +0,0 @@
|
||||
# this file is adapted from https://github.com/victorca25/iNNfer
|
||||
|
||||
from collections import OrderedDict
|
||||
import math
|
||||
import torch
|
||||
import torch.nn as nn
|
||||
import torch.nn.functional as F
|
||||
|
||||
|
||||
####################
|
||||
# RRDBNet Generator
|
||||
####################
|
||||
|
||||
class RRDBNet(nn.Module):
|
||||
def __init__(self, in_nc, out_nc, nf, nb, nr=3, gc=32, upscale=4, norm_type=None,
|
||||
act_type='leakyrelu', mode='CNA', upsample_mode='upconv', convtype='Conv2D',
|
||||
finalact=None, gaussian_noise=False, plus=False):
|
||||
super(RRDBNet, self).__init__()
|
||||
n_upscale = int(math.log(upscale, 2))
|
||||
if upscale == 3:
|
||||
n_upscale = 1
|
||||
|
||||
self.resrgan_scale = 0
|
||||
if in_nc % 16 == 0:
|
||||
self.resrgan_scale = 1
|
||||
elif in_nc != 4 and in_nc % 4 == 0:
|
||||
self.resrgan_scale = 2
|
||||
|
||||
fea_conv = conv_block(in_nc, nf, kernel_size=3, norm_type=None, act_type=None, convtype=convtype)
|
||||
rb_blocks = [RRDB(nf, nr, kernel_size=3, gc=32, stride=1, bias=1, pad_type='zero',
|
||||
norm_type=norm_type, act_type=act_type, mode='CNA', convtype=convtype,
|
||||
gaussian_noise=gaussian_noise, plus=plus) for _ in range(nb)]
|
||||
LR_conv = conv_block(nf, nf, kernel_size=3, norm_type=norm_type, act_type=None, mode=mode, convtype=convtype)
|
||||
|
||||
if upsample_mode == 'upconv':
|
||||
upsample_block = upconv_block
|
||||
elif upsample_mode == 'pixelshuffle':
|
||||
upsample_block = pixelshuffle_block
|
||||
else:
|
||||
raise NotImplementedError(f'upsample mode [{upsample_mode}] is not found')
|
||||
if upscale == 3:
|
||||
upsampler = upsample_block(nf, nf, 3, act_type=act_type, convtype=convtype)
|
||||
else:
|
||||
upsampler = [upsample_block(nf, nf, act_type=act_type, convtype=convtype) for _ in range(n_upscale)]
|
||||
HR_conv0 = conv_block(nf, nf, kernel_size=3, norm_type=None, act_type=act_type, convtype=convtype)
|
||||
HR_conv1 = conv_block(nf, out_nc, kernel_size=3, norm_type=None, act_type=None, convtype=convtype)
|
||||
|
||||
outact = act(finalact) if finalact else None
|
||||
|
||||
self.model = sequential(fea_conv, ShortcutBlock(sequential(*rb_blocks, LR_conv)),
|
||||
*upsampler, HR_conv0, HR_conv1, outact)
|
||||
|
||||
def forward(self, x, outm=None):
|
||||
if self.resrgan_scale == 1:
|
||||
feat = pixel_unshuffle(x, scale=4)
|
||||
elif self.resrgan_scale == 2:
|
||||
feat = pixel_unshuffle(x, scale=2)
|
||||
else:
|
||||
feat = x
|
||||
|
||||
return self.model(feat)
|
||||
|
||||
|
||||
class RRDB(nn.Module):
|
||||
"""
|
||||
Residual in Residual Dense Block
|
||||
(ESRGAN: Enhanced Super-Resolution Generative Adversarial Networks)
|
||||
"""
|
||||
|
||||
def __init__(self, nf, nr=3, kernel_size=3, gc=32, stride=1, bias=1, pad_type='zero',
|
||||
norm_type=None, act_type='leakyrelu', mode='CNA', convtype='Conv2D',
|
||||
spectral_norm=False, gaussian_noise=False, plus=False):
|
||||
super(RRDB, self).__init__()
|
||||
# This is for backwards compatibility with existing models
|
||||
if nr == 3:
|
||||
self.RDB1 = ResidualDenseBlock_5C(nf, kernel_size, gc, stride, bias, pad_type,
|
||||
norm_type, act_type, mode, convtype, spectral_norm=spectral_norm,
|
||||
gaussian_noise=gaussian_noise, plus=plus)
|
||||
self.RDB2 = ResidualDenseBlock_5C(nf, kernel_size, gc, stride, bias, pad_type,
|
||||
norm_type, act_type, mode, convtype, spectral_norm=spectral_norm,
|
||||
gaussian_noise=gaussian_noise, plus=plus)
|
||||
self.RDB3 = ResidualDenseBlock_5C(nf, kernel_size, gc, stride, bias, pad_type,
|
||||
norm_type, act_type, mode, convtype, spectral_norm=spectral_norm,
|
||||
gaussian_noise=gaussian_noise, plus=plus)
|
||||
else:
|
||||
RDB_list = [ResidualDenseBlock_5C(nf, kernel_size, gc, stride, bias, pad_type,
|
||||
norm_type, act_type, mode, convtype, spectral_norm=spectral_norm,
|
||||
gaussian_noise=gaussian_noise, plus=plus) for _ in range(nr)]
|
||||
self.RDBs = nn.Sequential(*RDB_list)
|
||||
|
||||
def forward(self, x):
|
||||
if hasattr(self, 'RDB1'):
|
||||
out = self.RDB1(x)
|
||||
out = self.RDB2(out)
|
||||
out = self.RDB3(out)
|
||||
else:
|
||||
out = self.RDBs(x)
|
||||
return out * 0.2 + x
|
||||
|
||||
|
||||
class ResidualDenseBlock_5C(nn.Module):
|
||||
"""
|
||||
Residual Dense Block
|
||||
The core module of paper: (Residual Dense Network for Image Super-Resolution, CVPR 18)
|
||||
Modified options that can be used:
|
||||
- "Partial Convolution based Padding" arXiv:1811.11718
|
||||
- "Spectral normalization" arXiv:1802.05957
|
||||
- "ICASSP 2020 - ESRGAN+ : Further Improving ESRGAN" N. C.
|
||||
{Rakotonirina} and A. {Rasoanaivo}
|
||||
"""
|
||||
|
||||
def __init__(self, nf=64, kernel_size=3, gc=32, stride=1, bias=1, pad_type='zero',
|
||||
norm_type=None, act_type='leakyrelu', mode='CNA', convtype='Conv2D',
|
||||
spectral_norm=False, gaussian_noise=False, plus=False):
|
||||
super(ResidualDenseBlock_5C, self).__init__()
|
||||
|
||||
self.noise = GaussianNoise() if gaussian_noise else None
|
||||
self.conv1x1 = conv1x1(nf, gc) if plus else None
|
||||
|
||||
self.conv1 = conv_block(nf, gc, kernel_size, stride, bias=bias, pad_type=pad_type,
|
||||
norm_type=norm_type, act_type=act_type, mode=mode, convtype=convtype,
|
||||
spectral_norm=spectral_norm)
|
||||
self.conv2 = conv_block(nf+gc, gc, kernel_size, stride, bias=bias, pad_type=pad_type,
|
||||
norm_type=norm_type, act_type=act_type, mode=mode, convtype=convtype,
|
||||
spectral_norm=spectral_norm)
|
||||
self.conv3 = conv_block(nf+2*gc, gc, kernel_size, stride, bias=bias, pad_type=pad_type,
|
||||
norm_type=norm_type, act_type=act_type, mode=mode, convtype=convtype,
|
||||
spectral_norm=spectral_norm)
|
||||
self.conv4 = conv_block(nf+3*gc, gc, kernel_size, stride, bias=bias, pad_type=pad_type,
|
||||
norm_type=norm_type, act_type=act_type, mode=mode, convtype=convtype,
|
||||
spectral_norm=spectral_norm)
|
||||
if mode == 'CNA':
|
||||
last_act = None
|
||||
else:
|
||||
last_act = act_type
|
||||
self.conv5 = conv_block(nf+4*gc, nf, 3, stride, bias=bias, pad_type=pad_type,
|
||||
norm_type=norm_type, act_type=last_act, mode=mode, convtype=convtype,
|
||||
spectral_norm=spectral_norm)
|
||||
|
||||
def forward(self, x):
|
||||
x1 = self.conv1(x)
|
||||
x2 = self.conv2(torch.cat((x, x1), 1))
|
||||
if self.conv1x1:
|
||||
x2 = x2 + self.conv1x1(x)
|
||||
x3 = self.conv3(torch.cat((x, x1, x2), 1))
|
||||
x4 = self.conv4(torch.cat((x, x1, x2, x3), 1))
|
||||
if self.conv1x1:
|
||||
x4 = x4 + x2
|
||||
x5 = self.conv5(torch.cat((x, x1, x2, x3, x4), 1))
|
||||
if self.noise:
|
||||
return self.noise(x5.mul(0.2) + x)
|
||||
else:
|
||||
return x5 * 0.2 + x
|
||||
|
||||
|
||||
####################
|
||||
# ESRGANplus
|
||||
####################
|
||||
|
||||
class GaussianNoise(nn.Module):
|
||||
def __init__(self, sigma=0.1, is_relative_detach=False):
|
||||
super().__init__()
|
||||
self.sigma = sigma
|
||||
self.is_relative_detach = is_relative_detach
|
||||
self.noise = torch.tensor(0, dtype=torch.float)
|
||||
|
||||
def forward(self, x):
|
||||
if self.training and self.sigma != 0:
|
||||
self.noise = self.noise.to(x.device)
|
||||
scale = self.sigma * x.detach() if self.is_relative_detach else self.sigma * x
|
||||
sampled_noise = self.noise.repeat(*x.size()).normal_() * scale
|
||||
x = x + sampled_noise
|
||||
return x
|
||||
|
||||
def conv1x1(in_planes, out_planes, stride=1):
|
||||
return nn.Conv2d(in_planes, out_planes, kernel_size=1, stride=stride, bias=False)
|
||||
|
||||
|
||||
####################
|
||||
# SRVGGNetCompact
|
||||
####################
|
||||
|
||||
class SRVGGNetCompact(nn.Module):
|
||||
"""A compact VGG-style network structure for super-resolution.
|
||||
This class is copied from https://github.com/xinntao/Real-ESRGAN
|
||||
"""
|
||||
|
||||
def __init__(self, num_in_ch=3, num_out_ch=3, num_feat=64, num_conv=16, upscale=4, act_type='prelu'):
|
||||
super(SRVGGNetCompact, self).__init__()
|
||||
self.num_in_ch = num_in_ch
|
||||
self.num_out_ch = num_out_ch
|
||||
self.num_feat = num_feat
|
||||
self.num_conv = num_conv
|
||||
self.upscale = upscale
|
||||
self.act_type = act_type
|
||||
|
||||
self.body = nn.ModuleList()
|
||||
# the first conv
|
||||
self.body.append(nn.Conv2d(num_in_ch, num_feat, 3, 1, 1))
|
||||
# the first activation
|
||||
if act_type == 'relu':
|
||||
activation = nn.ReLU(inplace=True)
|
||||
elif act_type == 'prelu':
|
||||
activation = nn.PReLU(num_parameters=num_feat)
|
||||
elif act_type == 'leakyrelu':
|
||||
activation = nn.LeakyReLU(negative_slope=0.1, inplace=True)
|
||||
self.body.append(activation)
|
||||
|
||||
# the body structure
|
||||
for _ in range(num_conv):
|
||||
self.body.append(nn.Conv2d(num_feat, num_feat, 3, 1, 1))
|
||||
# activation
|
||||
if act_type == 'relu':
|
||||
activation = nn.ReLU(inplace=True)
|
||||
elif act_type == 'prelu':
|
||||
activation = nn.PReLU(num_parameters=num_feat)
|
||||
elif act_type == 'leakyrelu':
|
||||
activation = nn.LeakyReLU(negative_slope=0.1, inplace=True)
|
||||
self.body.append(activation)
|
||||
|
||||
# the last conv
|
||||
self.body.append(nn.Conv2d(num_feat, num_out_ch * upscale * upscale, 3, 1, 1))
|
||||
# upsample
|
||||
self.upsampler = nn.PixelShuffle(upscale)
|
||||
|
||||
def forward(self, x):
|
||||
out = x
|
||||
for i in range(0, len(self.body)):
|
||||
out = self.body[i](out)
|
||||
|
||||
out = self.upsampler(out)
|
||||
# add the nearest upsampled image, so that the network learns the residual
|
||||
base = F.interpolate(x, scale_factor=self.upscale, mode='nearest')
|
||||
out += base
|
||||
return out
|
||||
|
||||
|
||||
####################
|
||||
# Upsampler
|
||||
####################
|
||||
|
||||
class Upsample(nn.Module):
|
||||
r"""Upsamples a given multi-channel 1D (temporal), 2D (spatial) or 3D (volumetric) data.
|
||||
The input data is assumed to be of the form
|
||||
`minibatch x channels x [optional depth] x [optional height] x width`.
|
||||
"""
|
||||
|
||||
def __init__(self, size=None, scale_factor=None, mode="nearest", align_corners=None):
|
||||
super(Upsample, self).__init__()
|
||||
if isinstance(scale_factor, tuple):
|
||||
self.scale_factor = tuple(float(factor) for factor in scale_factor)
|
||||
else:
|
||||
self.scale_factor = float(scale_factor) if scale_factor else None
|
||||
self.mode = mode
|
||||
self.size = size
|
||||
self.align_corners = align_corners
|
||||
|
||||
def forward(self, x):
|
||||
return nn.functional.interpolate(x, size=self.size, scale_factor=self.scale_factor, mode=self.mode, align_corners=self.align_corners)
|
||||
|
||||
def extra_repr(self):
|
||||
if self.scale_factor is not None:
|
||||
info = f'scale_factor={self.scale_factor}'
|
||||
else:
|
||||
info = f'size={self.size}'
|
||||
info += f', mode={self.mode}'
|
||||
return info
|
||||
|
||||
|
||||
def pixel_unshuffle(x, scale):
|
||||
""" Pixel unshuffle.
|
||||
Args:
|
||||
x (Tensor): Input feature with shape (b, c, hh, hw).
|
||||
scale (int): Downsample ratio.
|
||||
Returns:
|
||||
Tensor: the pixel unshuffled feature.
|
||||
"""
|
||||
b, c, hh, hw = x.size()
|
||||
out_channel = c * (scale**2)
|
||||
assert hh % scale == 0 and hw % scale == 0
|
||||
h = hh // scale
|
||||
w = hw // scale
|
||||
x_view = x.view(b, c, h, scale, w, scale)
|
||||
return x_view.permute(0, 1, 3, 5, 2, 4).reshape(b, out_channel, h, w)
|
||||
|
||||
|
||||
def pixelshuffle_block(in_nc, out_nc, upscale_factor=2, kernel_size=3, stride=1, bias=True,
|
||||
pad_type='zero', norm_type=None, act_type='relu', convtype='Conv2D'):
|
||||
"""
|
||||
Pixel shuffle layer
|
||||
(Real-Time Single Image and Video Super-Resolution Using an Efficient Sub-Pixel Convolutional
|
||||
Neural Network, CVPR17)
|
||||
"""
|
||||
conv = conv_block(in_nc, out_nc * (upscale_factor ** 2), kernel_size, stride, bias=bias,
|
||||
pad_type=pad_type, norm_type=None, act_type=None, convtype=convtype)
|
||||
pixel_shuffle = nn.PixelShuffle(upscale_factor)
|
||||
|
||||
n = norm(norm_type, out_nc) if norm_type else None
|
||||
a = act(act_type) if act_type else None
|
||||
return sequential(conv, pixel_shuffle, n, a)
|
||||
|
||||
|
||||
def upconv_block(in_nc, out_nc, upscale_factor=2, kernel_size=3, stride=1, bias=True,
|
||||
pad_type='zero', norm_type=None, act_type='relu', mode='nearest', convtype='Conv2D'):
|
||||
""" Upconv layer """
|
||||
upscale_factor = (1, upscale_factor, upscale_factor) if convtype == 'Conv3D' else upscale_factor
|
||||
upsample = Upsample(scale_factor=upscale_factor, mode=mode)
|
||||
conv = conv_block(in_nc, out_nc, kernel_size, stride, bias=bias,
|
||||
pad_type=pad_type, norm_type=norm_type, act_type=act_type, convtype=convtype)
|
||||
return sequential(upsample, conv)
|
||||
|
||||
|
||||
|
||||
|
||||
|
||||
|
||||
|
||||
|
||||
####################
|
||||
# Basic blocks
|
||||
####################
|
||||
|
||||
|
||||
def make_layer(basic_block, num_basic_block, **kwarg):
|
||||
"""Make layers by stacking the same blocks.
|
||||
Args:
|
||||
basic_block (nn.module): nn.module class for basic block. (block)
|
||||
num_basic_block (int): number of blocks. (n_layers)
|
||||
Returns:
|
||||
nn.Sequential: Stacked blocks in nn.Sequential.
|
||||
"""
|
||||
layers = []
|
||||
for _ in range(num_basic_block):
|
||||
layers.append(basic_block(**kwarg))
|
||||
return nn.Sequential(*layers)
|
||||
|
||||
|
||||
def act(act_type, inplace=True, neg_slope=0.2, n_prelu=1, beta=1.0):
|
||||
""" activation helper """
|
||||
act_type = act_type.lower()
|
||||
if act_type == 'relu':
|
||||
layer = nn.ReLU(inplace)
|
||||
elif act_type in ('leakyrelu', 'lrelu'):
|
||||
layer = nn.LeakyReLU(neg_slope, inplace)
|
||||
elif act_type == 'prelu':
|
||||
layer = nn.PReLU(num_parameters=n_prelu, init=neg_slope)
|
||||
elif act_type == 'tanh': # [-1, 1] range output
|
||||
layer = nn.Tanh()
|
||||
elif act_type == 'sigmoid': # [0, 1] range output
|
||||
layer = nn.Sigmoid()
|
||||
else:
|
||||
raise NotImplementedError(f'activation layer [{act_type}] is not found')
|
||||
return layer
|
||||
|
||||
|
||||
class Identity(nn.Module):
|
||||
def __init__(self, *kwargs):
|
||||
super(Identity, self).__init__()
|
||||
|
||||
def forward(self, x, *kwargs):
|
||||
return x
|
||||
|
||||
|
||||
def norm(norm_type, nc):
|
||||
""" Return a normalization layer """
|
||||
norm_type = norm_type.lower()
|
||||
if norm_type == 'batch':
|
||||
layer = nn.BatchNorm2d(nc, affine=True)
|
||||
elif norm_type == 'instance':
|
||||
layer = nn.InstanceNorm2d(nc, affine=False)
|
||||
elif norm_type == 'none':
|
||||
def norm_layer(x): return Identity()
|
||||
else:
|
||||
raise NotImplementedError(f'normalization layer [{norm_type}] is not found')
|
||||
return layer
|
||||
|
||||
|
||||
def pad(pad_type, padding):
|
||||
""" padding layer helper """
|
||||
pad_type = pad_type.lower()
|
||||
if padding == 0:
|
||||
return None
|
||||
if pad_type == 'reflect':
|
||||
layer = nn.ReflectionPad2d(padding)
|
||||
elif pad_type == 'replicate':
|
||||
layer = nn.ReplicationPad2d(padding)
|
||||
elif pad_type == 'zero':
|
||||
layer = nn.ZeroPad2d(padding)
|
||||
else:
|
||||
raise NotImplementedError(f'padding layer [{pad_type}] is not implemented')
|
||||
return layer
|
||||
|
||||
|
||||
def get_valid_padding(kernel_size, dilation):
|
||||
kernel_size = kernel_size + (kernel_size - 1) * (dilation - 1)
|
||||
padding = (kernel_size - 1) // 2
|
||||
return padding
|
||||
|
||||
|
||||
class ShortcutBlock(nn.Module):
|
||||
""" Elementwise sum the output of a submodule to its input """
|
||||
def __init__(self, submodule):
|
||||
super(ShortcutBlock, self).__init__()
|
||||
self.sub = submodule
|
||||
|
||||
def forward(self, x):
|
||||
output = x + self.sub(x)
|
||||
return output
|
||||
|
||||
def __repr__(self):
|
||||
return 'Identity + \n|' + self.sub.__repr__().replace('\n', '\n|')
|
||||
|
||||
|
||||
def sequential(*args):
|
||||
""" Flatten Sequential. It unwraps nn.Sequential. """
|
||||
if len(args) == 1:
|
||||
if isinstance(args[0], OrderedDict):
|
||||
raise NotImplementedError('sequential does not support OrderedDict input.')
|
||||
return args[0] # No sequential is needed.
|
||||
modules = []
|
||||
for module in args:
|
||||
if isinstance(module, nn.Sequential):
|
||||
for submodule in module.children():
|
||||
modules.append(submodule)
|
||||
elif isinstance(module, nn.Module):
|
||||
modules.append(module)
|
||||
return nn.Sequential(*modules)
|
||||
|
||||
|
||||
def conv_block(in_nc, out_nc, kernel_size, stride=1, dilation=1, groups=1, bias=True,
|
||||
pad_type='zero', norm_type=None, act_type='relu', mode='CNA', convtype='Conv2D',
|
||||
spectral_norm=False):
|
||||
""" Conv layer with padding, normalization, activation """
|
||||
assert mode in ['CNA', 'NAC', 'CNAC'], f'Wrong conv mode [{mode}]'
|
||||
padding = get_valid_padding(kernel_size, dilation)
|
||||
p = pad(pad_type, padding) if pad_type and pad_type != 'zero' else None
|
||||
padding = padding if pad_type == 'zero' else 0
|
||||
|
||||
if convtype=='PartialConv2D':
|
||||
from torchvision.ops import PartialConv2d # this is definitely not going to work, but PartialConv2d doesn't work anyway and this shuts up static analyzer
|
||||
c = PartialConv2d(in_nc, out_nc, kernel_size=kernel_size, stride=stride, padding=padding,
|
||||
dilation=dilation, bias=bias, groups=groups)
|
||||
elif convtype=='DeformConv2D':
|
||||
from torchvision.ops import DeformConv2d # not tested
|
||||
c = DeformConv2d(in_nc, out_nc, kernel_size=kernel_size, stride=stride, padding=padding,
|
||||
dilation=dilation, bias=bias, groups=groups)
|
||||
elif convtype=='Conv3D':
|
||||
c = nn.Conv3d(in_nc, out_nc, kernel_size=kernel_size, stride=stride, padding=padding,
|
||||
dilation=dilation, bias=bias, groups=groups)
|
||||
else:
|
||||
c = nn.Conv2d(in_nc, out_nc, kernel_size=kernel_size, stride=stride, padding=padding,
|
||||
dilation=dilation, bias=bias, groups=groups)
|
||||
|
||||
if spectral_norm:
|
||||
c = nn.utils.spectral_norm(c)
|
||||
|
||||
a = act(act_type) if act_type else None
|
||||
if 'CNA' in mode:
|
||||
n = norm(norm_type, out_nc) if norm_type else None
|
||||
return sequential(p, c, n, a)
|
||||
elif mode == 'NAC':
|
||||
if norm_type is None and act_type is not None:
|
||||
a = act(act_type, inplace=False)
|
||||
n = norm(norm_type, in_nc) if norm_type else None
|
||||
return sequential(n, a, p, c)
|
||||
@@ -32,11 +32,12 @@ class ExtensionMetadata:
|
||||
self.config = configparser.ConfigParser()
|
||||
|
||||
filepath = os.path.join(path, self.filename)
|
||||
if os.path.isfile(filepath):
|
||||
try:
|
||||
self.config.read(filepath)
|
||||
except Exception:
|
||||
errors.report(f"Error reading {self.filename} for extension {canonical_name}.", exc_info=True)
|
||||
# `self.config.read()` will quietly swallow OSErrors (which FileNotFoundError is),
|
||||
# so no need to check whether the file exists beforehand.
|
||||
try:
|
||||
self.config.read(filepath)
|
||||
except Exception:
|
||||
errors.report(f"Error reading {self.filename} for extension {canonical_name}.", exc_info=True)
|
||||
|
||||
self.canonical_name = self.config.get("Extension", "Name", fallback=canonical_name)
|
||||
self.canonical_name = canonical_name.lower().strip()
|
||||
@@ -223,13 +224,16 @@ def list_extensions():
|
||||
|
||||
# check for requirements
|
||||
for extension in extensions:
|
||||
if not extension.enabled:
|
||||
continue
|
||||
|
||||
for req in extension.metadata.requires:
|
||||
required_extension = loaded_extensions.get(req)
|
||||
if required_extension is None:
|
||||
errors.report(f'Extension "{extension.name}" requires "{req}" which is not installed.', exc_info=False)
|
||||
continue
|
||||
|
||||
if not extension.enabled:
|
||||
if not required_extension.enabled:
|
||||
errors.report(f'Extension "{extension.name}" requires "{required_extension.name}" which is disabled.', exc_info=False)
|
||||
continue
|
||||
|
||||
|
||||
@@ -206,7 +206,7 @@ def parse_prompts(prompts):
|
||||
return res, extra_data
|
||||
|
||||
|
||||
def get_user_metadata(filename):
|
||||
def get_user_metadata(filename, lister=None):
|
||||
if filename is None:
|
||||
return {}
|
||||
|
||||
@@ -215,7 +215,8 @@ def get_user_metadata(filename):
|
||||
|
||||
metadata = {}
|
||||
try:
|
||||
if os.path.isfile(metadata_filename):
|
||||
exists = lister.exists(metadata_filename) if lister else os.path.exists(metadata_filename)
|
||||
if exists:
|
||||
with open(metadata_filename, "r", encoding="utf8") as file:
|
||||
metadata = json.load(file)
|
||||
except Exception as e:
|
||||
|
||||
180
modules/face_restoration_utils.py
Normal file
180
modules/face_restoration_utils.py
Normal file
@@ -0,0 +1,180 @@
|
||||
from __future__ import annotations
|
||||
|
||||
import logging
|
||||
import os
|
||||
from functools import cached_property
|
||||
from typing import TYPE_CHECKING, Callable
|
||||
|
||||
import cv2
|
||||
import numpy as np
|
||||
import torch
|
||||
|
||||
from modules import devices, errors, face_restoration, shared
|
||||
|
||||
if TYPE_CHECKING:
|
||||
from facexlib.utils.face_restoration_helper import FaceRestoreHelper
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
def bgr_image_to_rgb_tensor(img: np.ndarray) -> torch.Tensor:
|
||||
"""Convert a BGR NumPy image in [0..1] range to a PyTorch RGB float32 tensor."""
|
||||
assert img.shape[2] == 3, "image must be RGB"
|
||||
if img.dtype == "float64":
|
||||
img = img.astype("float32")
|
||||
img = cv2.cvtColor(img, cv2.COLOR_BGR2RGB)
|
||||
return torch.from_numpy(img.transpose(2, 0, 1)).float()
|
||||
|
||||
|
||||
def rgb_tensor_to_bgr_image(tensor: torch.Tensor, *, min_max=(0.0, 1.0)) -> np.ndarray:
|
||||
"""
|
||||
Convert a PyTorch RGB tensor in range `min_max` to a BGR NumPy image in [0..1] range.
|
||||
"""
|
||||
tensor = tensor.squeeze(0).float().detach().cpu().clamp_(*min_max)
|
||||
tensor = (tensor - min_max[0]) / (min_max[1] - min_max[0])
|
||||
assert tensor.dim() == 3, "tensor must be RGB"
|
||||
img_np = tensor.numpy().transpose(1, 2, 0)
|
||||
if img_np.shape[2] == 1: # gray image, no RGB/BGR required
|
||||
return np.squeeze(img_np, axis=2)
|
||||
return cv2.cvtColor(img_np, cv2.COLOR_BGR2RGB)
|
||||
|
||||
|
||||
def create_face_helper(device) -> FaceRestoreHelper:
|
||||
from facexlib.detection import retinaface
|
||||
from facexlib.utils.face_restoration_helper import FaceRestoreHelper
|
||||
if hasattr(retinaface, 'device'):
|
||||
retinaface.device = device
|
||||
return FaceRestoreHelper(
|
||||
upscale_factor=1,
|
||||
face_size=512,
|
||||
crop_ratio=(1, 1),
|
||||
det_model='retinaface_resnet50',
|
||||
save_ext='png',
|
||||
use_parse=True,
|
||||
device=device,
|
||||
)
|
||||
|
||||
|
||||
def restore_with_face_helper(
|
||||
np_image: np.ndarray,
|
||||
face_helper: FaceRestoreHelper,
|
||||
restore_face: Callable[[torch.Tensor], torch.Tensor],
|
||||
) -> np.ndarray:
|
||||
"""
|
||||
Find faces in the image using face_helper, restore them using restore_face, and paste them back into the image.
|
||||
|
||||
`restore_face` should take a cropped face image and return a restored face image.
|
||||
"""
|
||||
from torchvision.transforms.functional import normalize
|
||||
np_image = np_image[:, :, ::-1]
|
||||
original_resolution = np_image.shape[0:2]
|
||||
|
||||
try:
|
||||
logger.debug("Detecting faces...")
|
||||
face_helper.clean_all()
|
||||
face_helper.read_image(np_image)
|
||||
face_helper.get_face_landmarks_5(only_center_face=False, resize=640, eye_dist_threshold=5)
|
||||
face_helper.align_warp_face()
|
||||
logger.debug("Found %d faces, restoring", len(face_helper.cropped_faces))
|
||||
for cropped_face in face_helper.cropped_faces:
|
||||
cropped_face_t = bgr_image_to_rgb_tensor(cropped_face / 255.0)
|
||||
normalize(cropped_face_t, (0.5, 0.5, 0.5), (0.5, 0.5, 0.5), inplace=True)
|
||||
cropped_face_t = cropped_face_t.unsqueeze(0).to(devices.device_codeformer)
|
||||
|
||||
try:
|
||||
with torch.no_grad():
|
||||
cropped_face_t = restore_face(cropped_face_t)
|
||||
devices.torch_gc()
|
||||
except Exception:
|
||||
errors.report('Failed face-restoration inference', exc_info=True)
|
||||
|
||||
restored_face = rgb_tensor_to_bgr_image(cropped_face_t, min_max=(-1, 1))
|
||||
restored_face = (restored_face * 255.0).astype('uint8')
|
||||
face_helper.add_restored_face(restored_face)
|
||||
|
||||
logger.debug("Merging restored faces into image")
|
||||
face_helper.get_inverse_affine(None)
|
||||
img = face_helper.paste_faces_to_input_image()
|
||||
img = img[:, :, ::-1]
|
||||
if original_resolution != img.shape[0:2]:
|
||||
img = cv2.resize(
|
||||
img,
|
||||
(0, 0),
|
||||
fx=original_resolution[1] / img.shape[1],
|
||||
fy=original_resolution[0] / img.shape[0],
|
||||
interpolation=cv2.INTER_LINEAR,
|
||||
)
|
||||
logger.debug("Face restoration complete")
|
||||
finally:
|
||||
face_helper.clean_all()
|
||||
return img
|
||||
|
||||
|
||||
class CommonFaceRestoration(face_restoration.FaceRestoration):
|
||||
net: torch.Module | None
|
||||
model_url: str
|
||||
model_download_name: str
|
||||
|
||||
def __init__(self, model_path: str):
|
||||
super().__init__()
|
||||
self.net = None
|
||||
self.model_path = model_path
|
||||
os.makedirs(model_path, exist_ok=True)
|
||||
|
||||
@cached_property
|
||||
def face_helper(self) -> FaceRestoreHelper:
|
||||
return create_face_helper(self.get_device())
|
||||
|
||||
def send_model_to(self, device):
|
||||
if self.net:
|
||||
logger.debug("Sending %s to %s", self.net, device)
|
||||
self.net.to(device)
|
||||
if self.face_helper:
|
||||
logger.debug("Sending face helper to %s", device)
|
||||
self.face_helper.face_det.to(device)
|
||||
self.face_helper.face_parse.to(device)
|
||||
|
||||
def get_device(self):
|
||||
raise NotImplementedError("get_device must be implemented by subclasses")
|
||||
|
||||
def load_net(self) -> torch.Module:
|
||||
raise NotImplementedError("load_net must be implemented by subclasses")
|
||||
|
||||
def restore_with_helper(
|
||||
self,
|
||||
np_image: np.ndarray,
|
||||
restore_face: Callable[[torch.Tensor], torch.Tensor],
|
||||
) -> np.ndarray:
|
||||
try:
|
||||
if self.net is None:
|
||||
self.net = self.load_net()
|
||||
except Exception:
|
||||
logger.warning("Unable to load face-restoration model", exc_info=True)
|
||||
return np_image
|
||||
|
||||
try:
|
||||
self.send_model_to(self.get_device())
|
||||
return restore_with_face_helper(np_image, self.face_helper, restore_face)
|
||||
finally:
|
||||
if shared.opts.face_restoration_unload:
|
||||
self.send_model_to(devices.cpu)
|
||||
|
||||
|
||||
def patch_facexlib(dirname: str) -> None:
|
||||
import facexlib.detection
|
||||
import facexlib.parsing
|
||||
|
||||
det_facex_load_file_from_url = facexlib.detection.load_file_from_url
|
||||
par_facex_load_file_from_url = facexlib.parsing.load_file_from_url
|
||||
|
||||
def update_kwargs(kwargs):
|
||||
return dict(kwargs, save_dir=dirname, model_dir=None)
|
||||
|
||||
def facex_load_file_from_url(**kwargs):
|
||||
return det_facex_load_file_from_url(**update_kwargs(kwargs))
|
||||
|
||||
def facex_load_file_from_url2(**kwargs):
|
||||
return par_facex_load_file_from_url(**update_kwargs(kwargs))
|
||||
|
||||
facexlib.detection.load_file_from_url = facex_load_file_from_url
|
||||
facexlib.parsing.load_file_from_url = facex_load_file_from_url2
|
||||
@@ -1,125 +1,71 @@
|
||||
from __future__ import annotations
|
||||
|
||||
import logging
|
||||
import os
|
||||
|
||||
import facexlib
|
||||
import gfpgan
|
||||
import torch
|
||||
|
||||
import modules.face_restoration
|
||||
from modules import paths, shared, devices, modelloader, errors
|
||||
from modules import (
|
||||
devices,
|
||||
errors,
|
||||
face_restoration,
|
||||
face_restoration_utils,
|
||||
modelloader,
|
||||
shared,
|
||||
)
|
||||
|
||||
model_dir = "GFPGAN"
|
||||
user_path = None
|
||||
model_path = os.path.join(paths.models_path, model_dir)
|
||||
model_file_path = None
|
||||
logger = logging.getLogger(__name__)
|
||||
model_url = "https://github.com/TencentARC/GFPGAN/releases/download/v1.3.0/GFPGANv1.4.pth"
|
||||
have_gfpgan = False
|
||||
loaded_gfpgan_model = None
|
||||
model_download_name = "GFPGANv1.4.pth"
|
||||
gfpgan_face_restorer: face_restoration.FaceRestoration | None = None
|
||||
|
||||
|
||||
def gfpgann():
|
||||
global loaded_gfpgan_model
|
||||
global model_path
|
||||
global model_file_path
|
||||
if loaded_gfpgan_model is not None:
|
||||
loaded_gfpgan_model.gfpgan.to(devices.device_gfpgan)
|
||||
return loaded_gfpgan_model
|
||||
class FaceRestorerGFPGAN(face_restoration_utils.CommonFaceRestoration):
|
||||
def name(self):
|
||||
return "GFPGAN"
|
||||
|
||||
if gfpgan_constructor is None:
|
||||
return None
|
||||
def get_device(self):
|
||||
return devices.device_gfpgan
|
||||
|
||||
models = modelloader.load_models(model_path, model_url, user_path, ext_filter=['.pth'])
|
||||
def load_net(self) -> torch.Module:
|
||||
for model_path in modelloader.load_models(
|
||||
model_path=self.model_path,
|
||||
model_url=model_url,
|
||||
command_path=self.model_path,
|
||||
download_name=model_download_name,
|
||||
ext_filter=['.pth'],
|
||||
):
|
||||
if 'GFPGAN' in os.path.basename(model_path):
|
||||
model = modelloader.load_spandrel_model(
|
||||
model_path,
|
||||
device=self.get_device(),
|
||||
expected_architecture='GFPGAN',
|
||||
).model
|
||||
model.different_w = True # see https://github.com/chaiNNer-org/spandrel/pull/81
|
||||
return model
|
||||
raise ValueError("No GFPGAN model found")
|
||||
|
||||
if len(models) == 1 and models[0].startswith("http"):
|
||||
model_file = models[0]
|
||||
elif len(models) != 0:
|
||||
gfp_models = []
|
||||
for item in models:
|
||||
if 'GFPGAN' in os.path.basename(item):
|
||||
gfp_models.append(item)
|
||||
latest_file = max(gfp_models, key=os.path.getctime)
|
||||
model_file = latest_file
|
||||
else:
|
||||
print("Unable to load gfpgan model!")
|
||||
return None
|
||||
def restore(self, np_image):
|
||||
def restore_face(cropped_face_t):
|
||||
assert self.net is not None
|
||||
return self.net(cropped_face_t, return_rgb=False)[0]
|
||||
|
||||
if hasattr(facexlib.detection.retinaface, 'device'):
|
||||
facexlib.detection.retinaface.device = devices.device_gfpgan
|
||||
model_file_path = model_file
|
||||
model = gfpgan_constructor(model_path=model_file, upscale=1, arch='clean', channel_multiplier=2, bg_upsampler=None, device=devices.device_gfpgan)
|
||||
loaded_gfpgan_model = model
|
||||
|
||||
return model
|
||||
|
||||
|
||||
def send_model_to(model, device):
|
||||
model.gfpgan.to(device)
|
||||
model.face_helper.face_det.to(device)
|
||||
model.face_helper.face_parse.to(device)
|
||||
return self.restore_with_helper(np_image, restore_face)
|
||||
|
||||
|
||||
def gfpgan_fix_faces(np_image):
|
||||
model = gfpgann()
|
||||
if model is None:
|
||||
return np_image
|
||||
|
||||
send_model_to(model, devices.device_gfpgan)
|
||||
|
||||
np_image_bgr = np_image[:, :, ::-1]
|
||||
cropped_faces, restored_faces, gfpgan_output_bgr = model.enhance(np_image_bgr, has_aligned=False, only_center_face=False, paste_back=True)
|
||||
np_image = gfpgan_output_bgr[:, :, ::-1]
|
||||
|
||||
model.face_helper.clean_all()
|
||||
|
||||
if shared.opts.face_restoration_unload:
|
||||
send_model_to(model, devices.cpu)
|
||||
|
||||
if gfpgan_face_restorer:
|
||||
return gfpgan_face_restorer.restore(np_image)
|
||||
logger.warning("GFPGAN face restorer not set up")
|
||||
return np_image
|
||||
|
||||
|
||||
gfpgan_constructor = None
|
||||
def setup_model(dirname: str) -> None:
|
||||
global gfpgan_face_restorer
|
||||
|
||||
|
||||
def setup_model(dirname):
|
||||
try:
|
||||
os.makedirs(model_path, exist_ok=True)
|
||||
from gfpgan import GFPGANer
|
||||
from facexlib import detection, parsing # noqa: F401
|
||||
global user_path
|
||||
global have_gfpgan
|
||||
global gfpgan_constructor
|
||||
global model_file_path
|
||||
|
||||
facexlib_path = model_path
|
||||
|
||||
if dirname is not None:
|
||||
facexlib_path = dirname
|
||||
|
||||
load_file_from_url_orig = gfpgan.utils.load_file_from_url
|
||||
facex_load_file_from_url_orig = facexlib.detection.load_file_from_url
|
||||
facex_load_file_from_url_orig2 = facexlib.parsing.load_file_from_url
|
||||
|
||||
def my_load_file_from_url(**kwargs):
|
||||
return load_file_from_url_orig(**dict(kwargs, model_dir=model_file_path))
|
||||
|
||||
def facex_load_file_from_url(**kwargs):
|
||||
return facex_load_file_from_url_orig(**dict(kwargs, save_dir=facexlib_path, model_dir=None))
|
||||
|
||||
def facex_load_file_from_url2(**kwargs):
|
||||
return facex_load_file_from_url_orig2(**dict(kwargs, save_dir=facexlib_path, model_dir=None))
|
||||
|
||||
gfpgan.utils.load_file_from_url = my_load_file_from_url
|
||||
facexlib.detection.load_file_from_url = facex_load_file_from_url
|
||||
facexlib.parsing.load_file_from_url = facex_load_file_from_url2
|
||||
user_path = dirname
|
||||
have_gfpgan = True
|
||||
gfpgan_constructor = GFPGANer
|
||||
|
||||
class FaceRestorerGFPGAN(modules.face_restoration.FaceRestoration):
|
||||
def name(self):
|
||||
return "GFPGAN"
|
||||
|
||||
def restore(self, np_image):
|
||||
return gfpgan_fix_faces(np_image)
|
||||
|
||||
shared.face_restorers.append(FaceRestorerGFPGAN())
|
||||
face_restoration_utils.patch_facexlib(dirname)
|
||||
gfpgan_face_restorer = FaceRestorerGFPGAN(model_path=dirname)
|
||||
shared.face_restorers.append(gfpgan_face_restorer)
|
||||
except Exception:
|
||||
errors.report("Error setting up GFPGAN", exc_info=True)
|
||||
|
||||
43
modules/hat_model.py
Normal file
43
modules/hat_model.py
Normal file
@@ -0,0 +1,43 @@
|
||||
import os
|
||||
import sys
|
||||
|
||||
from modules import modelloader, devices
|
||||
from modules.shared import opts
|
||||
from modules.upscaler import Upscaler, UpscalerData
|
||||
from modules.upscaler_utils import upscale_with_model
|
||||
|
||||
|
||||
class UpscalerHAT(Upscaler):
|
||||
def __init__(self, dirname):
|
||||
self.name = "HAT"
|
||||
self.scalers = []
|
||||
self.user_path = dirname
|
||||
super().__init__()
|
||||
for file in self.find_models(ext_filter=[".pt", ".pth"]):
|
||||
name = modelloader.friendly_name(file)
|
||||
scale = 4 # TODO: scale might not be 4, but we can't know without loading the model
|
||||
scaler_data = UpscalerData(name, file, upscaler=self, scale=scale)
|
||||
self.scalers.append(scaler_data)
|
||||
|
||||
def do_upscale(self, img, selected_model):
|
||||
try:
|
||||
model = self.load_model(selected_model)
|
||||
except Exception as e:
|
||||
print(f"Unable to load HAT model {selected_model}: {e}", file=sys.stderr)
|
||||
return img
|
||||
model.to(devices.device_esrgan) # TODO: should probably be device_hat
|
||||
return upscale_with_model(
|
||||
model,
|
||||
img,
|
||||
tile_size=opts.ESRGAN_tile, # TODO: should probably be HAT_tile
|
||||
tile_overlap=opts.ESRGAN_tile_overlap, # TODO: should probably be HAT_tile_overlap
|
||||
)
|
||||
|
||||
def load_model(self, path: str):
|
||||
if not os.path.isfile(path):
|
||||
raise FileNotFoundError(f"Model file {path} not found")
|
||||
return modelloader.load_spandrel_model(
|
||||
path,
|
||||
device=devices.device_esrgan, # TODO: should probably be device_hat
|
||||
expected_architecture='HAT',
|
||||
)
|
||||
@@ -61,12 +61,17 @@ def image_grid(imgs, batch_size=1, rows=None):
|
||||
return grid
|
||||
|
||||
|
||||
Grid = namedtuple("Grid", ["tiles", "tile_w", "tile_h", "image_w", "image_h", "overlap"])
|
||||
class Grid(namedtuple("_Grid", ["tiles", "tile_w", "tile_h", "image_w", "image_h", "overlap"])):
|
||||
@property
|
||||
def tile_count(self) -> int:
|
||||
"""
|
||||
The total number of tiles in the grid.
|
||||
"""
|
||||
return sum(len(row[2]) for row in self.tiles)
|
||||
|
||||
|
||||
def split_grid(image, tile_w=512, tile_h=512, overlap=64):
|
||||
w = image.width
|
||||
h = image.height
|
||||
def split_grid(image: Image.Image, tile_w: int = 512, tile_h: int = 512, overlap: int = 64) -> Grid:
|
||||
w, h = image.size
|
||||
|
||||
non_overlap_width = tile_w - overlap
|
||||
non_overlap_height = tile_h - overlap
|
||||
@@ -316,7 +321,7 @@ def resize_image(resize_mode, im, width, height, upscaler_name=None):
|
||||
return res
|
||||
|
||||
|
||||
invalid_filename_chars = '<>:"/\\|?*\n\r\t'
|
||||
invalid_filename_chars = '#<>:"/\\|?*\n\r\t'
|
||||
invalid_filename_prefix = ' '
|
||||
invalid_filename_postfix = ' .'
|
||||
re_nonletters = re.compile(r'[\s' + string.punctuation + ']+')
|
||||
@@ -791,3 +796,4 @@ def flatten(img, bgcolor):
|
||||
img = background
|
||||
|
||||
return img.convert('RGB')
|
||||
|
||||
|
||||
@@ -7,7 +7,7 @@ from PIL import Image, ImageOps, ImageFilter, ImageEnhance, UnidentifiedImageErr
|
||||
import gradio as gr
|
||||
|
||||
from modules import images as imgutil
|
||||
from modules.generation_parameters_copypaste import create_override_settings_dict, parse_generation_parameters
|
||||
from modules.infotext_utils import create_override_settings_dict, parse_generation_parameters
|
||||
from modules.processing import Processed, StableDiffusionProcessingImg2Img, process_images
|
||||
from modules.shared import opts, state
|
||||
from modules.sd_models import get_closet_checkpoint_match
|
||||
@@ -51,7 +51,7 @@ def process_batch(p, input_dir, output_dir, inpaint_mask_dir, args, to_scale=Fal
|
||||
if state.skipped:
|
||||
state.skipped = False
|
||||
|
||||
if state.interrupted:
|
||||
if state.interrupted or state.stopping_generation:
|
||||
break
|
||||
|
||||
try:
|
||||
@@ -222,9 +222,6 @@ def img2img(id_task: str, mode: int, prompt: str, negative_prompt: str, prompt_s
|
||||
if shared.opts.enable_console_prompts:
|
||||
print(f"\nimg2img: {prompt}", file=shared.progress_print_out)
|
||||
|
||||
if mask:
|
||||
p.extra_generation_params["Mask blur"] = mask_blur
|
||||
|
||||
with closing(p):
|
||||
if is_batch:
|
||||
assert not shared.cmd_opts.hide_ui_dir_config, "Launched with --hide-ui-dir-config, batch img2img disabled"
|
||||
|
||||
@@ -4,12 +4,15 @@ import io
|
||||
import json
|
||||
import os
|
||||
import re
|
||||
import sys
|
||||
|
||||
import gradio as gr
|
||||
from modules.paths import data_path
|
||||
from modules import shared, ui_tempdir, script_callbacks, processing
|
||||
from modules import shared, ui_tempdir, script_callbacks, processing, infotext_versions
|
||||
from PIL import Image
|
||||
|
||||
sys.modules['modules.generation_parameters_copypaste'] = sys.modules[__name__] # alias for old name
|
||||
|
||||
re_param_code = r'\s*(\w[\w \-/]+):\s*("(?:\\.|[^\\"])+"|[^,]*)(?:,|$)'
|
||||
re_param = re.compile(re_param_code)
|
||||
re_imagesize = re.compile(r"^(\d+)x(\d+)$")
|
||||
@@ -28,6 +31,19 @@ class ParamBinding:
|
||||
self.paste_field_names = paste_field_names or []
|
||||
|
||||
|
||||
class PasteField(tuple):
|
||||
def __new__(cls, component, target, *, api=None):
|
||||
return super().__new__(cls, (component, target))
|
||||
|
||||
def __init__(self, component, target, *, api=None):
|
||||
super().__init__()
|
||||
|
||||
self.api = api
|
||||
self.component = component
|
||||
self.label = target if isinstance(target, str) else None
|
||||
self.function = target if callable(target) else None
|
||||
|
||||
|
||||
paste_fields: dict[str, dict] = {}
|
||||
registered_param_bindings: list[ParamBinding] = []
|
||||
|
||||
@@ -84,6 +100,12 @@ def image_from_url_text(filedata):
|
||||
|
||||
|
||||
def add_paste_fields(tabname, init_img, fields, override_settings_component=None):
|
||||
|
||||
if fields:
|
||||
for i in range(len(fields)):
|
||||
if not isinstance(fields[i], PasteField):
|
||||
fields[i] = PasteField(*fields[i])
|
||||
|
||||
paste_fields[tabname] = {"init_img": init_img, "fields": fields, "override_settings_component": override_settings_component}
|
||||
|
||||
# backwards compatibility for existing extensions
|
||||
@@ -208,7 +230,7 @@ def restore_old_hires_fix_params(res):
|
||||
res['Hires resize-2'] = height
|
||||
|
||||
|
||||
def parse_generation_parameters(x: str):
|
||||
def parse_generation_parameters(x: str, skip_fields: list[str] | None = None):
|
||||
"""parses generation parameters string, the one you see in text field under the picture in UI:
|
||||
```
|
||||
girl with an artist's beret, determined, blue eyes, desert scene, computer monitors, heavy makeup, by Alphonse Mucha and Charlie Bowater, ((eyeshadow)), (coquettish), detailed, intricate
|
||||
@@ -218,6 +240,8 @@ Steps: 20, Sampler: Euler a, CFG scale: 7, Seed: 965400086, Size: 512x512, Model
|
||||
|
||||
returns a dict with field values
|
||||
"""
|
||||
if skip_fields is None:
|
||||
skip_fields = shared.opts.infotext_skip_pasting
|
||||
|
||||
res = {}
|
||||
|
||||
@@ -290,6 +314,18 @@ Steps: 20, Sampler: Euler a, CFG scale: 7, Seed: 965400086, Size: 512x512, Model
|
||||
if "Hires negative prompt" not in res:
|
||||
res["Hires negative prompt"] = ""
|
||||
|
||||
if "Mask mode" not in res:
|
||||
res["Mask mode"] = "Inpaint masked"
|
||||
|
||||
if "Masked content" not in res:
|
||||
res["Masked content"] = 'original'
|
||||
|
||||
if "Inpaint area" not in res:
|
||||
res["Inpaint area"] = "Whole picture"
|
||||
|
||||
if "Masked area padding" not in res:
|
||||
res["Masked area padding"] = 32
|
||||
|
||||
restore_old_hires_fix_params(res)
|
||||
|
||||
# Missing RNG means the default was set, which is GPU RNG
|
||||
@@ -314,8 +350,16 @@ Steps: 20, Sampler: Euler a, CFG scale: 7, Seed: 965400086, Size: 512x512, Model
|
||||
if "VAE Decoder" not in res:
|
||||
res["VAE Decoder"] = "Full"
|
||||
|
||||
skip = set(shared.opts.infotext_skip_pasting)
|
||||
res = {k: v for k, v in res.items() if k not in skip}
|
||||
if "FP8 weight" not in res:
|
||||
res["FP8 weight"] = "Disable"
|
||||
|
||||
if "Cache FP16 weight for LoRA" not in res and res["FP8 weight"] != "Disable":
|
||||
res["Cache FP16 weight for LoRA"] = False
|
||||
|
||||
infotext_versions.backcompat(res)
|
||||
|
||||
for key in skip_fields:
|
||||
res.pop(key, None)
|
||||
|
||||
return res
|
||||
|
||||
@@ -365,13 +409,57 @@ def create_override_settings_dict(text_pairs):
|
||||
return res
|
||||
|
||||
|
||||
def get_override_settings(params, *, skip_fields=None):
|
||||
"""Returns a list of settings overrides from the infotext parameters dictionary.
|
||||
|
||||
This function checks the `params` dictionary for any keys that correspond to settings in `shared.opts` and returns
|
||||
a list of tuples containing the parameter name, setting name, and new value cast to correct type.
|
||||
|
||||
It checks for conditions before adding an override:
|
||||
- ignores settings that match the current value
|
||||
- ignores parameter keys present in skip_fields argument.
|
||||
|
||||
Example input:
|
||||
{"Clip skip": "2"}
|
||||
|
||||
Example output:
|
||||
[("Clip skip", "CLIP_stop_at_last_layers", 2)]
|
||||
"""
|
||||
|
||||
res = []
|
||||
|
||||
mapping = [(info.infotext, k) for k, info in shared.opts.data_labels.items() if info.infotext]
|
||||
for param_name, setting_name in mapping + infotext_to_setting_name_mapping:
|
||||
if param_name in (skip_fields or {}):
|
||||
continue
|
||||
|
||||
v = params.get(param_name, None)
|
||||
if v is None:
|
||||
continue
|
||||
|
||||
if setting_name == "sd_model_checkpoint" and shared.opts.disable_weights_auto_swap:
|
||||
continue
|
||||
|
||||
v = shared.opts.cast_value(setting_name, v)
|
||||
current_value = getattr(shared.opts, setting_name, None)
|
||||
|
||||
if v == current_value:
|
||||
continue
|
||||
|
||||
res.append((param_name, setting_name, v))
|
||||
|
||||
return res
|
||||
|
||||
|
||||
def connect_paste(button, paste_fields, input_comp, override_settings_component, tabname):
|
||||
def paste_func(prompt):
|
||||
if not prompt and not shared.cmd_opts.hide_ui_dir_config:
|
||||
filename = os.path.join(data_path, "params.txt")
|
||||
if os.path.exists(filename):
|
||||
try:
|
||||
with open(filename, "r", encoding="utf8") as file:
|
||||
prompt = file.read()
|
||||
except OSError:
|
||||
pass
|
||||
|
||||
params = parse_generation_parameters(prompt)
|
||||
script_callbacks.infotext_pasted_callback(prompt, params)
|
||||
@@ -393,6 +481,8 @@ def connect_paste(button, paste_fields, input_comp, override_settings_component,
|
||||
|
||||
if valtype == bool and v == "False":
|
||||
val = False
|
||||
elif valtype == int:
|
||||
val = float(v)
|
||||
else:
|
||||
val = valtype(v)
|
||||
|
||||
@@ -406,29 +496,9 @@ def connect_paste(button, paste_fields, input_comp, override_settings_component,
|
||||
already_handled_fields = {key: 1 for _, key in paste_fields}
|
||||
|
||||
def paste_settings(params):
|
||||
vals = {}
|
||||
vals = get_override_settings(params, skip_fields=already_handled_fields)
|
||||
|
||||
mapping = [(info.infotext, k) for k, info in shared.opts.data_labels.items() if info.infotext]
|
||||
for param_name, setting_name in mapping + infotext_to_setting_name_mapping:
|
||||
if param_name in already_handled_fields:
|
||||
continue
|
||||
|
||||
v = params.get(param_name, None)
|
||||
if v is None:
|
||||
continue
|
||||
|
||||
if setting_name == "sd_model_checkpoint" and shared.opts.disable_weights_auto_swap:
|
||||
continue
|
||||
|
||||
v = shared.opts.cast_value(setting_name, v)
|
||||
current_value = getattr(shared.opts, setting_name, None)
|
||||
|
||||
if v == current_value:
|
||||
continue
|
||||
|
||||
vals[param_name] = v
|
||||
|
||||
vals_pairs = [f"{k}: {v}" for k, v in vals.items()]
|
||||
vals_pairs = [f"{infotext_text}: {value}" for infotext_text, setting_name, value in vals]
|
||||
|
||||
return gr.Dropdown.update(value=vals_pairs, choices=vals_pairs, visible=bool(vals_pairs))
|
||||
|
||||
42
modules/infotext_versions.py
Normal file
42
modules/infotext_versions.py
Normal file
@@ -0,0 +1,42 @@
|
||||
from modules import shared
|
||||
from packaging import version
|
||||
import re
|
||||
|
||||
|
||||
v160 = version.parse("1.6.0")
|
||||
v170_tsnr = version.parse("v1.7.0-225")
|
||||
|
||||
|
||||
def parse_version(text):
|
||||
if text is None:
|
||||
return None
|
||||
|
||||
m = re.match(r'([^-]+-[^-]+)-.*', text)
|
||||
if m:
|
||||
text = m.group(1)
|
||||
|
||||
try:
|
||||
return version.parse(text)
|
||||
except Exception:
|
||||
return None
|
||||
|
||||
|
||||
def backcompat(d):
|
||||
"""Checks infotext Version field, and enables backwards compatibility options according to it."""
|
||||
|
||||
if not shared.opts.auto_backcompat:
|
||||
return
|
||||
|
||||
ver = parse_version(d.get("Version"))
|
||||
if ver is None:
|
||||
return
|
||||
|
||||
if ver < v160 and '[' in d.get('Prompt', ''):
|
||||
d["Old prompt editing timelines"] = True
|
||||
|
||||
if ver < v160 and d.get('Sampler', '') in ('DDIM', 'PLMS'):
|
||||
d["Pad conds v0"] = True
|
||||
|
||||
if ver < v170_tsnr:
|
||||
d["Downcast alphas_cumprod"] = True
|
||||
|
||||
@@ -1,5 +1,6 @@
|
||||
import importlib
|
||||
import logging
|
||||
import os
|
||||
import sys
|
||||
import warnings
|
||||
from threading import Thread
|
||||
@@ -18,6 +19,7 @@ def imports():
|
||||
warnings.filterwarnings(action="ignore", category=DeprecationWarning, module="pytorch_lightning")
|
||||
warnings.filterwarnings(action="ignore", category=UserWarning, module="torchvision")
|
||||
|
||||
os.environ.setdefault('GRADIO_ANALYTICS_ENABLED', 'False')
|
||||
import gradio # noqa: F401
|
||||
startup_timer.record("import gradio")
|
||||
|
||||
@@ -54,9 +56,6 @@ def initialize():
|
||||
initialize_util.configure_sigint_handler()
|
||||
initialize_util.configure_opts_onchange()
|
||||
|
||||
from modules import modelloader
|
||||
modelloader.cleanup_models()
|
||||
|
||||
from modules import sd_models
|
||||
sd_models.setup_model()
|
||||
startup_timer.record("setup SD model")
|
||||
|
||||
@@ -177,6 +177,8 @@ def configure_opts_onchange():
|
||||
shared.opts.onchange("temp_dir", ui_tempdir.on_tmpdir_changed)
|
||||
shared.opts.onchange("gradio_theme", shared.reload_gradio_theme)
|
||||
shared.opts.onchange("cross_attention_optimization", wrap_queued_call(lambda: sd_hijack.model_hijack.redo_hijack(shared.sd_model)), call=False)
|
||||
shared.opts.onchange("fp8_storage", wrap_queued_call(lambda: sd_models.reload_model_weights()), call=False)
|
||||
shared.opts.onchange("cache_fp16_weight", wrap_queued_call(lambda: sd_models.reload_model_weights(forced_reload=True)), call=False)
|
||||
startup_timer.record("opts onchange")
|
||||
|
||||
|
||||
|
||||
@@ -10,14 +10,14 @@ import torch.hub
|
||||
from torchvision import transforms
|
||||
from torchvision.transforms.functional import InterpolationMode
|
||||
|
||||
from modules import devices, paths, shared, lowvram, modelloader, errors
|
||||
from modules import devices, paths, shared, lowvram, modelloader, errors, torch_utils
|
||||
|
||||
blip_image_eval_size = 384
|
||||
clip_model_name = 'ViT-L/14'
|
||||
|
||||
Category = namedtuple("Category", ["name", "topn", "items"])
|
||||
|
||||
re_topn = re.compile(r"\.top(\d+)\.")
|
||||
re_topn = re.compile(r"\.top(\d+)$")
|
||||
|
||||
def category_types():
|
||||
return [f.stem for f in Path(shared.interrogator.content_dir).glob('*.txt')]
|
||||
@@ -131,7 +131,7 @@ class InterrogateModels:
|
||||
|
||||
self.clip_model = self.clip_model.to(devices.device_interrogate)
|
||||
|
||||
self.dtype = next(self.clip_model.parameters()).dtype
|
||||
self.dtype = torch_utils.get_param(self.clip_model).dtype
|
||||
|
||||
def send_clip_to_ram(self):
|
||||
if not shared.opts.interrogate_keep_models_in_memory:
|
||||
|
||||
@@ -27,8 +27,7 @@ dir_repos = "repositories"
|
||||
# Whether to default to printing command output
|
||||
default_command_live = (os.environ.get('WEBUI_LAUNCH_LIVE_OUTPUT') == "1")
|
||||
|
||||
if 'GRADIO_ANALYTICS_ENABLED' not in os.environ:
|
||||
os.environ['GRADIO_ANALYTICS_ENABLED'] = 'False'
|
||||
os.environ.setdefault('GRADIO_ANALYTICS_ENABLED', 'False')
|
||||
|
||||
|
||||
def check_python_version():
|
||||
@@ -245,11 +244,13 @@ def list_extensions(settings_file):
|
||||
settings = {}
|
||||
|
||||
try:
|
||||
if os.path.isfile(settings_file):
|
||||
with open(settings_file, "r", encoding="utf8") as file:
|
||||
settings = json.load(file)
|
||||
with open(settings_file, "r", encoding="utf8") as file:
|
||||
settings = json.load(file)
|
||||
except FileNotFoundError:
|
||||
pass
|
||||
except Exception:
|
||||
errors.report("Could not load settings", exc_info=True)
|
||||
errors.report(f'\nCould not load settings\nThe config file "{settings_file}" is likely corrupted\nIt has been moved to the "tmp/config.json"\nReverting config to default\n\n''', exc_info=True)
|
||||
os.replace(settings_file, os.path.join(script_path, "tmp", "config.json"))
|
||||
|
||||
disabled_extensions = set(settings.get('disabled_extensions', []))
|
||||
disable_all_extensions = settings.get('disable_all_extensions', 'none')
|
||||
@@ -314,8 +315,8 @@ def requirements_met(requirements_file):
|
||||
|
||||
|
||||
def prepare_environment():
|
||||
torch_index_url = os.environ.get('TORCH_INDEX_URL', "https://download.pytorch.org/whl/cu118")
|
||||
torch_command = os.environ.get('TORCH_COMMAND', f"pip install torch==2.0.1 torchvision==0.15.2 --extra-index-url {torch_index_url}")
|
||||
torch_index_url = os.environ.get('TORCH_INDEX_URL', "https://download.pytorch.org/whl/cu121")
|
||||
torch_command = os.environ.get('TORCH_COMMAND', f"pip install torch==2.1.2 torchvision==0.16.2 --extra-index-url {torch_index_url}")
|
||||
if args.use_ipex:
|
||||
if platform.system() == "Windows":
|
||||
# The "Nuullll/intel-extension-for-pytorch" wheels were built from IPEX source for Intel Arc GPU: https://github.com/intel/intel-extension-for-pytorch/tree/xpu-main
|
||||
@@ -338,20 +339,20 @@ def prepare_environment():
|
||||
torch_command = os.environ.get('TORCH_COMMAND', f"pip install torch==2.0.0a0 intel-extension-for-pytorch==2.0.110+gitba7f6c1 --extra-index-url {torch_index_url}")
|
||||
requirements_file = os.environ.get('REQS_FILE', "requirements_versions.txt")
|
||||
|
||||
xformers_package = os.environ.get('XFORMERS_PACKAGE', 'xformers==0.0.20')
|
||||
xformers_package = os.environ.get('XFORMERS_PACKAGE', 'xformers==0.0.23.post1')
|
||||
clip_package = os.environ.get('CLIP_PACKAGE', "https://github.com/openai/CLIP/archive/d50d76daa670286dd6cacf3bcd80b5e4823fc8e1.zip")
|
||||
openclip_package = os.environ.get('OPENCLIP_PACKAGE', "https://github.com/mlfoundations/open_clip/archive/bb6e834e9c70d9c27d0dc3ecedeebeaeb1ffad6b.zip")
|
||||
|
||||
assets_repo = os.environ.get('ASSETS_REPO', "https://github.com/AUTOMATIC1111/stable-diffusion-webui-assets.git")
|
||||
stable_diffusion_repo = os.environ.get('STABLE_DIFFUSION_REPO', "https://github.com/Stability-AI/stablediffusion.git")
|
||||
stable_diffusion_xl_repo = os.environ.get('STABLE_DIFFUSION_XL_REPO', "https://github.com/Stability-AI/generative-models.git")
|
||||
k_diffusion_repo = os.environ.get('K_DIFFUSION_REPO', 'https://github.com/crowsonkb/k-diffusion.git')
|
||||
codeformer_repo = os.environ.get('CODEFORMER_REPO', 'https://github.com/sczhou/CodeFormer.git')
|
||||
blip_repo = os.environ.get('BLIP_REPO', 'https://github.com/salesforce/BLIP.git')
|
||||
|
||||
assets_commit_hash = os.environ.get('ASSETS_COMMIT_HASH', "6f7db241d2f8ba7457bac5ca9753331f0c266917")
|
||||
stable_diffusion_commit_hash = os.environ.get('STABLE_DIFFUSION_COMMIT_HASH', "cf1d67a6fd5ea1aa600c4df58e5b47da45f6bdbf")
|
||||
stable_diffusion_xl_commit_hash = os.environ.get('STABLE_DIFFUSION_XL_COMMIT_HASH', "45c443b316737a4ab6e40413d7794a7f5657c19f")
|
||||
k_diffusion_commit_hash = os.environ.get('K_DIFFUSION_COMMIT_HASH', "ab527a9a6d347f364e3d185ba6d714e22d80cb3c")
|
||||
codeformer_commit_hash = os.environ.get('CODEFORMER_COMMIT_HASH', "c5b4593074ba6214284d6acd5f1719b6c5d739af")
|
||||
blip_commit_hash = os.environ.get('BLIP_COMMIT_HASH', "48211a1594f1321b00f14c9f7a5b4813144b2fb9")
|
||||
|
||||
try:
|
||||
@@ -405,18 +406,14 @@ def prepare_environment():
|
||||
|
||||
os.makedirs(os.path.join(script_path, dir_repos), exist_ok=True)
|
||||
|
||||
git_clone(assets_repo, repo_dir('stable-diffusion-webui-assets'), "assets", assets_commit_hash)
|
||||
git_clone(stable_diffusion_repo, repo_dir('stable-diffusion-stability-ai'), "Stable Diffusion", stable_diffusion_commit_hash)
|
||||
git_clone(stable_diffusion_xl_repo, repo_dir('generative-models'), "Stable Diffusion XL", stable_diffusion_xl_commit_hash)
|
||||
git_clone(k_diffusion_repo, repo_dir('k-diffusion'), "K-diffusion", k_diffusion_commit_hash)
|
||||
git_clone(codeformer_repo, repo_dir('CodeFormer'), "CodeFormer", codeformer_commit_hash)
|
||||
git_clone(blip_repo, repo_dir('BLIP'), "BLIP", blip_commit_hash)
|
||||
|
||||
startup_timer.record("clone repositores")
|
||||
|
||||
if not is_installed("lpips"):
|
||||
run_pip(f"install -r \"{os.path.join(repo_dir('CodeFormer'), 'requirements.txt')}\"", "requirements for CodeFormer")
|
||||
startup_timer.record("install CodeFormer requirements")
|
||||
|
||||
if not os.path.isfile(requirements_file):
|
||||
requirements_file = os.path.join(script_path, requirements_file)
|
||||
|
||||
|
||||
@@ -1,41 +1,58 @@
|
||||
import os
|
||||
import logging
|
||||
import os
|
||||
|
||||
try:
|
||||
from tqdm.auto import tqdm
|
||||
from tqdm import tqdm
|
||||
|
||||
|
||||
class TqdmLoggingHandler(logging.Handler):
|
||||
def __init__(self, level=logging.INFO):
|
||||
super().__init__(level)
|
||||
def __init__(self, fallback_handler: logging.Handler):
|
||||
super().__init__()
|
||||
self.fallback_handler = fallback_handler
|
||||
|
||||
def emit(self, record):
|
||||
try:
|
||||
msg = self.format(record)
|
||||
tqdm.write(msg)
|
||||
self.flush()
|
||||
# If there are active tqdm progress bars,
|
||||
# attempt to not interfere with them.
|
||||
if tqdm._instances:
|
||||
tqdm.write(self.format(record))
|
||||
else:
|
||||
self.fallback_handler.emit(record)
|
||||
except Exception:
|
||||
self.handleError(record)
|
||||
self.fallback_handler.emit(record)
|
||||
|
||||
TQDM_IMPORTED = True
|
||||
except ImportError:
|
||||
# tqdm does not exist before first launch
|
||||
# I will import once the UI finishes seting up the enviroment and reloads.
|
||||
TQDM_IMPORTED = False
|
||||
TqdmLoggingHandler = None
|
||||
|
||||
|
||||
def setup_logging(loglevel):
|
||||
if loglevel is None:
|
||||
loglevel = os.environ.get("SD_WEBUI_LOG_LEVEL")
|
||||
|
||||
loghandlers = []
|
||||
if not loglevel:
|
||||
return
|
||||
|
||||
if TQDM_IMPORTED:
|
||||
loghandlers.append(TqdmLoggingHandler())
|
||||
if logging.root.handlers:
|
||||
# Already configured, do not interfere
|
||||
return
|
||||
|
||||
if loglevel:
|
||||
log_level = getattr(logging, loglevel.upper(), None) or logging.INFO
|
||||
logging.basicConfig(
|
||||
level=log_level,
|
||||
format='%(asctime)s %(levelname)s [%(name)s] %(message)s',
|
||||
datefmt='%Y-%m-%d %H:%M:%S',
|
||||
handlers=loghandlers
|
||||
)
|
||||
formatter = logging.Formatter(
|
||||
'%(asctime)s %(levelname)s [%(name)s] %(message)s',
|
||||
'%Y-%m-%d %H:%M:%S',
|
||||
)
|
||||
|
||||
if os.environ.get("SD_WEBUI_RICH_LOG"):
|
||||
from rich.logging import RichHandler
|
||||
handler = RichHandler()
|
||||
else:
|
||||
handler = logging.StreamHandler()
|
||||
handler.setFormatter(formatter)
|
||||
|
||||
if TqdmLoggingHandler:
|
||||
handler = TqdmLoggingHandler(handler)
|
||||
|
||||
handler.setFormatter(formatter)
|
||||
|
||||
log_level = getattr(logging, loglevel.upper(), None) or logging.INFO
|
||||
logging.root.setLevel(log_level)
|
||||
logging.root.addHandler(handler)
|
||||
|
||||
@@ -3,40 +3,15 @@ from PIL import Image, ImageFilter, ImageOps
|
||||
|
||||
def get_crop_region(mask, pad=0):
|
||||
"""finds a rectangular region that contains all masked ares in an image. Returns (x1, y1, x2, y2) coordinates of the rectangle.
|
||||
For example, if a user has painted the top-right part of a 512x512 image", the result may be (256, 0, 512, 256)"""
|
||||
|
||||
h, w = mask.shape
|
||||
|
||||
crop_left = 0
|
||||
for i in range(w):
|
||||
if not (mask[:, i] == 0).all():
|
||||
break
|
||||
crop_left += 1
|
||||
|
||||
crop_right = 0
|
||||
for i in reversed(range(w)):
|
||||
if not (mask[:, i] == 0).all():
|
||||
break
|
||||
crop_right += 1
|
||||
|
||||
crop_top = 0
|
||||
for i in range(h):
|
||||
if not (mask[i] == 0).all():
|
||||
break
|
||||
crop_top += 1
|
||||
|
||||
crop_bottom = 0
|
||||
for i in reversed(range(h)):
|
||||
if not (mask[i] == 0).all():
|
||||
break
|
||||
crop_bottom += 1
|
||||
|
||||
return (
|
||||
int(max(crop_left-pad, 0)),
|
||||
int(max(crop_top-pad, 0)),
|
||||
int(min(w - crop_right + pad, w)),
|
||||
int(min(h - crop_bottom + pad, h))
|
||||
)
|
||||
For example, if a user has painted the top-right part of a 512x512 image, the result may be (256, 0, 512, 256)"""
|
||||
mask_img = mask if isinstance(mask, Image.Image) else Image.fromarray(mask)
|
||||
box = mask_img.getbbox()
|
||||
if box:
|
||||
x1, y1, x2, y2 = box
|
||||
else: # when no box is found
|
||||
x1, y1 = mask_img.size
|
||||
x2 = y2 = 0
|
||||
return max(x1 - pad, 0), max(y1 - pad, 0), min(x2 + pad, mask_img.size[0]), min(y2 + pad, mask_img.size[1])
|
||||
|
||||
|
||||
def expand_crop_region(crop_region, processing_width, processing_height, image_width, image_height):
|
||||
|
||||
@@ -1,13 +1,20 @@
|
||||
from __future__ import annotations
|
||||
|
||||
import os
|
||||
import shutil
|
||||
import importlib
|
||||
import logging
|
||||
import os
|
||||
from typing import TYPE_CHECKING
|
||||
from urllib.parse import urlparse
|
||||
|
||||
import torch
|
||||
|
||||
from modules import shared
|
||||
from modules.upscaler import Upscaler, UpscalerLanczos, UpscalerNearest, UpscalerNone
|
||||
from modules.paths import script_path, models_path
|
||||
|
||||
if TYPE_CHECKING:
|
||||
import spandrel
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
def load_file_from_url(
|
||||
@@ -90,54 +97,6 @@ def friendly_name(file: str):
|
||||
return model_name
|
||||
|
||||
|
||||
def cleanup_models():
|
||||
# This code could probably be more efficient if we used a tuple list or something to store the src/destinations
|
||||
# and then enumerate that, but this works for now. In the future, it'd be nice to just have every "model" scaler
|
||||
# somehow auto-register and just do these things...
|
||||
root_path = script_path
|
||||
src_path = models_path
|
||||
dest_path = os.path.join(models_path, "Stable-diffusion")
|
||||
move_files(src_path, dest_path, ".ckpt")
|
||||
move_files(src_path, dest_path, ".safetensors")
|
||||
src_path = os.path.join(root_path, "ESRGAN")
|
||||
dest_path = os.path.join(models_path, "ESRGAN")
|
||||
move_files(src_path, dest_path)
|
||||
src_path = os.path.join(models_path, "BSRGAN")
|
||||
dest_path = os.path.join(models_path, "ESRGAN")
|
||||
move_files(src_path, dest_path, ".pth")
|
||||
src_path = os.path.join(root_path, "gfpgan")
|
||||
dest_path = os.path.join(models_path, "GFPGAN")
|
||||
move_files(src_path, dest_path)
|
||||
src_path = os.path.join(root_path, "SwinIR")
|
||||
dest_path = os.path.join(models_path, "SwinIR")
|
||||
move_files(src_path, dest_path)
|
||||
src_path = os.path.join(root_path, "repositories/latent-diffusion/experiments/pretrained_models/")
|
||||
dest_path = os.path.join(models_path, "LDSR")
|
||||
move_files(src_path, dest_path)
|
||||
|
||||
|
||||
def move_files(src_path: str, dest_path: str, ext_filter: str = None):
|
||||
try:
|
||||
os.makedirs(dest_path, exist_ok=True)
|
||||
if os.path.exists(src_path):
|
||||
for file in os.listdir(src_path):
|
||||
fullpath = os.path.join(src_path, file)
|
||||
if os.path.isfile(fullpath):
|
||||
if ext_filter is not None:
|
||||
if ext_filter not in file:
|
||||
continue
|
||||
print(f"Moving {file} from {src_path} to {dest_path}.")
|
||||
try:
|
||||
shutil.move(fullpath, dest_path)
|
||||
except Exception:
|
||||
pass
|
||||
if len(os.listdir(src_path)) == 0:
|
||||
print(f"Removing empty folder: {src_path}")
|
||||
shutil.rmtree(src_path, True)
|
||||
except Exception:
|
||||
pass
|
||||
|
||||
|
||||
def load_upscalers():
|
||||
# We can only do this 'magic' method to dynamically load upscalers if they are referenced,
|
||||
# so we'll try to import any _model.py files before looking in __subclasses__
|
||||
@@ -177,3 +136,34 @@ def load_upscalers():
|
||||
# Special case for UpscalerNone keeps it at the beginning of the list.
|
||||
key=lambda x: x.name.lower() if not isinstance(x.scaler, (UpscalerNone, UpscalerLanczos, UpscalerNearest)) else ""
|
||||
)
|
||||
|
||||
|
||||
def load_spandrel_model(
|
||||
path: str | os.PathLike,
|
||||
*,
|
||||
device: str | torch.device | None,
|
||||
prefer_half: bool = False,
|
||||
dtype: str | torch.dtype | None = None,
|
||||
expected_architecture: str | None = None,
|
||||
) -> spandrel.ModelDescriptor:
|
||||
import spandrel
|
||||
model_descriptor = spandrel.ModelLoader(device=device).load_from_file(str(path))
|
||||
if expected_architecture and model_descriptor.architecture != expected_architecture:
|
||||
logger.warning(
|
||||
f"Model {path!r} is not a {expected_architecture!r} model (got {model_descriptor.architecture!r})",
|
||||
)
|
||||
half = False
|
||||
if prefer_half:
|
||||
if model_descriptor.supports_half:
|
||||
model_descriptor.model.half()
|
||||
half = True
|
||||
else:
|
||||
logger.info("Model %s does not support half precision, ignoring --half", path)
|
||||
if dtype:
|
||||
model_descriptor.model.to(dtype=dtype)
|
||||
model_descriptor.model.eval()
|
||||
logger.debug(
|
||||
"Loaded %s from %s (device=%s, half=%s, dtype=%s)",
|
||||
model_descriptor, path, device, half, dtype,
|
||||
)
|
||||
return model_descriptor
|
||||
|
||||
@@ -1,3 +1,4 @@
|
||||
import os
|
||||
import json
|
||||
import sys
|
||||
from dataclasses import dataclass
|
||||
@@ -6,6 +7,7 @@ import gradio as gr
|
||||
|
||||
from modules import errors
|
||||
from modules.shared_cmd_options import cmd_opts
|
||||
from modules.paths_internal import script_path
|
||||
|
||||
|
||||
class OptionInfo:
|
||||
@@ -91,18 +93,35 @@ class Options:
|
||||
|
||||
if self.data is not None:
|
||||
if key in self.data or key in self.data_labels:
|
||||
|
||||
# Check that settings aren't globally frozen
|
||||
assert not cmd_opts.freeze_settings, "changing settings is disabled"
|
||||
|
||||
# Get the info related to the setting being changed
|
||||
info = self.data_labels.get(key, None)
|
||||
if info.do_not_save:
|
||||
return
|
||||
|
||||
# Restrict component arguments
|
||||
comp_args = info.component_args if info else None
|
||||
if isinstance(comp_args, dict) and comp_args.get('visible', True) is False:
|
||||
raise RuntimeError(f"not possible to set {key} because it is restricted")
|
||||
raise RuntimeError(f"not possible to set '{key}' because it is restricted")
|
||||
|
||||
# Check that this section isn't frozen
|
||||
if cmd_opts.freeze_settings_in_sections is not None:
|
||||
frozen_sections = list(map(str.strip, cmd_opts.freeze_settings_in_sections.split(','))) # Trim whitespace from section names
|
||||
section_key = info.section[0]
|
||||
section_name = info.section[1]
|
||||
assert section_key not in frozen_sections, f"not possible to set '{key}' because settings in section '{section_name}' ({section_key}) are frozen with --freeze-settings-in-sections"
|
||||
|
||||
# Check that this section of the settings isn't frozen
|
||||
if cmd_opts.freeze_specific_settings is not None:
|
||||
frozen_keys = list(map(str.strip, cmd_opts.freeze_specific_settings.split(','))) # Trim whitespace from setting keys
|
||||
assert key not in frozen_keys, f"not possible to set '{key}' because this setting is frozen with --freeze-specific-settings"
|
||||
|
||||
# Check shorthand option which disables editing options in "saving-paths"
|
||||
if cmd_opts.hide_ui_dir_config and key in self.restricted_opts:
|
||||
raise RuntimeError(f"not possible to set {key} because it is restricted")
|
||||
raise RuntimeError(f"not possible to set '{key}' because it is restricted with --hide_ui_dir_config")
|
||||
|
||||
self.data[key] = value
|
||||
return
|
||||
@@ -176,9 +195,15 @@ class Options:
|
||||
return type_x == type_y
|
||||
|
||||
def load(self, filename):
|
||||
with open(filename, "r", encoding="utf8") as file:
|
||||
self.data = json.load(file)
|
||||
|
||||
try:
|
||||
with open(filename, "r", encoding="utf8") as file:
|
||||
self.data = json.load(file)
|
||||
except FileNotFoundError:
|
||||
self.data = {}
|
||||
except Exception:
|
||||
errors.report(f'\nCould not load settings\nThe config file "{filename}" is likely corrupted\nIt has been moved to the "tmp/config.json"\nReverting config to default\n\n''', exc_info=True)
|
||||
os.replace(filename, os.path.join(script_path, "tmp", "config.json"))
|
||||
self.data = {}
|
||||
# 1.6.0 VAE defaults
|
||||
if self.data.get('sd_vae_as_default') is not None and self.data.get('sd_vae_overrides_per_model_preferences') is None:
|
||||
self.data['sd_vae_overrides_per_model_preferences'] = not self.data.get('sd_vae_as_default')
|
||||
|
||||
@@ -38,7 +38,6 @@ mute_sdxl_imports()
|
||||
path_dirs = [
|
||||
(sd_path, 'ldm', 'Stable Diffusion', []),
|
||||
(os.path.join(sd_path, '../generative-models'), 'sgm', 'Stable Diffusion XL', ["sgm"]),
|
||||
(os.path.join(sd_path, '../CodeFormer'), 'inference_codeformer.py', 'CodeFormer', []),
|
||||
(os.path.join(sd_path, '../BLIP'), 'models/blip.py', 'BLIP', []),
|
||||
(os.path.join(sd_path, '../k-diffusion'), 'k_diffusion/sampling.py', 'k_diffusion', ["atstart"]),
|
||||
]
|
||||
|
||||
@@ -28,5 +28,6 @@ models_path = os.path.join(data_path, "models")
|
||||
extensions_dir = os.path.join(data_path, "extensions")
|
||||
extensions_builtin_dir = os.path.join(script_path, "extensions-builtin")
|
||||
config_states_dir = os.path.join(script_path, "config_states")
|
||||
default_output_dir = os.path.join(data_path, "output")
|
||||
|
||||
roboto_ttf_file = os.path.join(modules_path, 'Roboto-Regular.ttf')
|
||||
|
||||
@@ -2,7 +2,7 @@ import os
|
||||
|
||||
from PIL import Image
|
||||
|
||||
from modules import shared, images, devices, scripts, scripts_postprocessing, ui_common, generation_parameters_copypaste
|
||||
from modules import shared, images, devices, scripts, scripts_postprocessing, ui_common, infotext_utils
|
||||
from modules.shared import opts
|
||||
|
||||
|
||||
@@ -62,8 +62,6 @@ def run_postprocessing(extras_mode, image, image_folder, input_dir, output_dir,
|
||||
else:
|
||||
image_data = image_placeholder
|
||||
|
||||
shared.state.assign_current_image(image_data)
|
||||
|
||||
parameters, existing_pnginfo = images.read_info_from_image(image_data)
|
||||
if parameters:
|
||||
existing_pnginfo["parameters"] = parameters
|
||||
@@ -86,22 +84,25 @@ def run_postprocessing(extras_mode, image, image_folder, input_dir, output_dir,
|
||||
basename = ''
|
||||
forced_filename = None
|
||||
|
||||
infotext = ", ".join([k if k == v else f'{k}: {generation_parameters_copypaste.quote(v)}' for k, v in pp.info.items() if v is not None])
|
||||
infotext = ", ".join([k if k == v else f'{k}: {infotext_utils.quote(v)}' for k, v in pp.info.items() if v is not None])
|
||||
|
||||
if opts.enable_pnginfo:
|
||||
pp.image.info = existing_pnginfo
|
||||
pp.image.info["postprocessing"] = infotext
|
||||
|
||||
shared.state.assign_current_image(pp.image)
|
||||
|
||||
if save_output:
|
||||
fullfn, _ = images.save_image(pp.image, path=outpath, basename=basename, extension=opts.samples_format, info=infotext, short_filename=True, no_prompt=True, grid=False, pnginfo_section_name="extras", existing_info=existing_pnginfo, forced_filename=forced_filename, suffix=suffix)
|
||||
|
||||
if pp.caption:
|
||||
caption_filename = os.path.splitext(fullfn)[0] + ".txt"
|
||||
if os.path.isfile(caption_filename):
|
||||
existing_caption = ""
|
||||
try:
|
||||
with open(caption_filename, encoding="utf8") as file:
|
||||
existing_caption = file.read().strip()
|
||||
else:
|
||||
existing_caption = ""
|
||||
except FileNotFoundError:
|
||||
pass
|
||||
|
||||
action = shared.opts.postprocessing_existing_caption_action
|
||||
if action == 'Prepend' and existing_caption:
|
||||
|
||||
@@ -16,7 +16,7 @@ from skimage import exposure
|
||||
from typing import Any
|
||||
|
||||
import modules.sd_hijack
|
||||
from modules import devices, prompt_parser, masking, sd_samplers, lowvram, generation_parameters_copypaste, extra_networks, sd_vae_approx, scripts, sd_samplers_common, sd_unet, errors, rng
|
||||
from modules import devices, prompt_parser, masking, sd_samplers, lowvram, infotext_utils, extra_networks, sd_vae_approx, scripts, sd_samplers_common, sd_unet, errors, rng
|
||||
from modules.rng import slerp # noqa: F401
|
||||
from modules.sd_hijack import model_hijack
|
||||
from modules.sd_samplers_common import images_tensor_to_samples, decode_first_stage, approximation_indexes
|
||||
@@ -62,18 +62,22 @@ def apply_color_correction(correction, original_image):
|
||||
return image.convert('RGB')
|
||||
|
||||
|
||||
def apply_overlay(image, paste_loc, index, overlays):
|
||||
if overlays is None or index >= len(overlays):
|
||||
def uncrop(image, dest_size, paste_loc):
|
||||
x, y, w, h = paste_loc
|
||||
base_image = Image.new('RGBA', dest_size)
|
||||
image = images.resize_image(1, image, w, h)
|
||||
base_image.paste(image, (x, y))
|
||||
image = base_image
|
||||
|
||||
return image
|
||||
|
||||
|
||||
def apply_overlay(image, paste_loc, overlay):
|
||||
if overlay is None:
|
||||
return image
|
||||
|
||||
overlay = overlays[index]
|
||||
|
||||
if paste_loc is not None:
|
||||
x, y, w, h = paste_loc
|
||||
base_image = Image.new('RGBA', (overlay.width, overlay.height))
|
||||
image = images.resize_image(1, image, w, h)
|
||||
base_image.paste(image, (x, y))
|
||||
image = base_image
|
||||
image = uncrop(image, (overlay.width, overlay.height), paste_loc)
|
||||
|
||||
image = image.convert('RGBA')
|
||||
image.alpha_composite(overlay)
|
||||
@@ -81,9 +85,12 @@ def apply_overlay(image, paste_loc, index, overlays):
|
||||
|
||||
return image
|
||||
|
||||
def create_binary_mask(image):
|
||||
def create_binary_mask(image, round=True):
|
||||
if image.mode == 'RGBA' and image.getextrema()[-1] != (255, 255):
|
||||
image = image.split()[-1].convert("L").point(lambda x: 255 if x > 128 else 0)
|
||||
if round:
|
||||
image = image.split()[-1].convert("L").point(lambda x: 255 if x > 128 else 0)
|
||||
else:
|
||||
image = image.split()[-1].convert("L")
|
||||
else:
|
||||
image = image.convert('L')
|
||||
return image
|
||||
@@ -106,6 +113,21 @@ def txt2img_image_conditioning(sd_model, x, width, height):
|
||||
return x.new_zeros(x.shape[0], 2*sd_model.noise_augmentor.time_embed.dim, dtype=x.dtype, device=x.device)
|
||||
|
||||
else:
|
||||
sd = sd_model.model.state_dict()
|
||||
diffusion_model_input = sd.get('diffusion_model.input_blocks.0.0.weight', None)
|
||||
if diffusion_model_input is not None:
|
||||
if diffusion_model_input.shape[1] == 9:
|
||||
# The "masked-image" in this case will just be all 0.5 since the entire image is masked.
|
||||
image_conditioning = torch.ones(x.shape[0], 3, height, width, device=x.device) * 0.5
|
||||
image_conditioning = images_tensor_to_samples(image_conditioning,
|
||||
approximation_indexes.get(opts.sd_vae_encode_method))
|
||||
|
||||
# Add the fake full 1s mask to the first dimension.
|
||||
image_conditioning = torch.nn.functional.pad(image_conditioning, (0, 0, 0, 0, 1, 0), value=1.0)
|
||||
image_conditioning = image_conditioning.to(x.dtype)
|
||||
|
||||
return image_conditioning
|
||||
|
||||
# Dummy zero conditioning if we're not using inpainting or unclip models.
|
||||
# Still takes up a bit of memory, but no encoder call.
|
||||
# Pretty sure we can just make this a 1x1 image since its not going to be used besides its batch size.
|
||||
@@ -157,6 +179,7 @@ class StableDiffusionProcessing:
|
||||
token_merging_ratio = 0
|
||||
token_merging_ratio_hr = 0
|
||||
disable_extra_networks: bool = False
|
||||
firstpass_image: Image = None
|
||||
|
||||
scripts_value: scripts.ScriptRunner = field(default=None, init=False)
|
||||
script_args_value: list = field(default=None, init=False)
|
||||
@@ -308,7 +331,7 @@ class StableDiffusionProcessing:
|
||||
c_adm = torch.cat((c_adm, noise_level_emb), 1)
|
||||
return c_adm
|
||||
|
||||
def inpainting_image_conditioning(self, source_image, latent_image, image_mask=None):
|
||||
def inpainting_image_conditioning(self, source_image, latent_image, image_mask=None, round_image_mask=True):
|
||||
self.is_using_inpainting_conditioning = True
|
||||
|
||||
# Handle the different mask inputs
|
||||
@@ -320,8 +343,10 @@ class StableDiffusionProcessing:
|
||||
conditioning_mask = conditioning_mask.astype(np.float32) / 255.0
|
||||
conditioning_mask = torch.from_numpy(conditioning_mask[None, None])
|
||||
|
||||
# Inpainting model uses a discretized mask as input, so we round to either 1.0 or 0.0
|
||||
conditioning_mask = torch.round(conditioning_mask)
|
||||
if round_image_mask:
|
||||
# Caller is requesting a discretized mask as input, so we round to either 1.0 or 0.0
|
||||
conditioning_mask = torch.round(conditioning_mask)
|
||||
|
||||
else:
|
||||
conditioning_mask = source_image.new_ones(1, 1, *source_image.shape[-2:])
|
||||
|
||||
@@ -345,7 +370,7 @@ class StableDiffusionProcessing:
|
||||
|
||||
return image_conditioning
|
||||
|
||||
def img2img_image_conditioning(self, source_image, latent_image, image_mask=None):
|
||||
def img2img_image_conditioning(self, source_image, latent_image, image_mask=None, round_image_mask=True):
|
||||
source_image = devices.cond_cast_float(source_image)
|
||||
|
||||
# HACK: Using introspection as the Depth2Image model doesn't appear to uniquely
|
||||
@@ -357,11 +382,17 @@ class StableDiffusionProcessing:
|
||||
return self.edit_image_conditioning(source_image)
|
||||
|
||||
if self.sampler.conditioning_key in {'hybrid', 'concat'}:
|
||||
return self.inpainting_image_conditioning(source_image, latent_image, image_mask=image_mask)
|
||||
return self.inpainting_image_conditioning(source_image, latent_image, image_mask=image_mask, round_image_mask=round_image_mask)
|
||||
|
||||
if self.sampler.conditioning_key == "crossattn-adm":
|
||||
return self.unclip_image_conditioning(source_image)
|
||||
|
||||
sd = self.sampler.model_wrap.inner_model.model.state_dict()
|
||||
diffusion_model_input = sd.get('diffusion_model.input_blocks.0.0.weight', None)
|
||||
if diffusion_model_input is not None:
|
||||
if diffusion_model_input.shape[1] == 9:
|
||||
return self.inpainting_image_conditioning(source_image, latent_image, image_mask=image_mask)
|
||||
|
||||
# Dummy zero conditioning if we're not using inpainting or depth model.
|
||||
return latent_image.new_zeros(latent_image.shape[0], 5, 1, 1)
|
||||
|
||||
@@ -422,6 +453,8 @@ class StableDiffusionProcessing:
|
||||
opts.sdxl_crop_top,
|
||||
self.width,
|
||||
self.height,
|
||||
opts.fp8_storage,
|
||||
opts.cache_fp16_weight,
|
||||
)
|
||||
|
||||
def get_conds_with_caching(self, function, required_prompts, steps, caches, extra_network_data, hires_steps=None):
|
||||
@@ -596,20 +629,33 @@ def decode_latent_batch(model, batch, target_device=None, check_for_nans=False):
|
||||
sample = decode_first_stage(model, batch[i:i + 1])[0]
|
||||
|
||||
if check_for_nans:
|
||||
|
||||
try:
|
||||
devices.test_for_nans(sample, "vae")
|
||||
except devices.NansException as e:
|
||||
if devices.dtype_vae == torch.float32 or not shared.opts.auto_vae_precision:
|
||||
if shared.opts.auto_vae_precision_bfloat16:
|
||||
autofix_dtype = torch.bfloat16
|
||||
autofix_dtype_text = "bfloat16"
|
||||
autofix_dtype_setting = "Automatically convert VAE to bfloat16"
|
||||
autofix_dtype_comment = ""
|
||||
elif shared.opts.auto_vae_precision:
|
||||
autofix_dtype = torch.float32
|
||||
autofix_dtype_text = "32-bit float"
|
||||
autofix_dtype_setting = "Automatically revert VAE to 32-bit floats"
|
||||
autofix_dtype_comment = "\nTo always start with 32-bit VAE, use --no-half-vae commandline flag."
|
||||
else:
|
||||
raise e
|
||||
|
||||
if devices.dtype_vae == autofix_dtype:
|
||||
raise e
|
||||
|
||||
errors.print_error_explanation(
|
||||
"A tensor with all NaNs was produced in VAE.\n"
|
||||
"Web UI will now convert VAE into 32-bit float and retry.\n"
|
||||
"To disable this behavior, disable the 'Automatically revert VAE to 32-bit floats' setting.\n"
|
||||
"To always start with 32-bit VAE, use --no-half-vae commandline flag."
|
||||
f"Web UI will now convert VAE into {autofix_dtype_text} and retry.\n"
|
||||
f"To disable this behavior, disable the '{autofix_dtype_setting}' setting.{autofix_dtype_comment}"
|
||||
)
|
||||
|
||||
devices.dtype_vae = torch.float32
|
||||
devices.dtype_vae = autofix_dtype
|
||||
model.first_stage_model.to(devices.dtype_vae)
|
||||
batch = batch.to(devices.dtype_vae)
|
||||
|
||||
@@ -679,12 +725,14 @@ def create_infotext(p, all_prompts, all_seeds, all_subseeds, comments=None, iter
|
||||
"Size": f"{p.width}x{p.height}",
|
||||
"Model hash": p.sd_model_hash if opts.add_model_hash_to_info else None,
|
||||
"Model": p.sd_model_name if opts.add_model_name_to_info else None,
|
||||
"FP8 weight": opts.fp8_storage if devices.fp8 else None,
|
||||
"Cache FP16 weight for LoRA": opts.cache_fp16_weight if devices.fp8 else None,
|
||||
"VAE hash": p.sd_vae_hash if opts.add_vae_hash_to_info else None,
|
||||
"VAE": p.sd_vae_name if opts.add_vae_name_to_info else None,
|
||||
"Variation seed": (None if p.subseed_strength == 0 else (p.all_subseeds[0] if use_main_prompt else all_subseeds[index])),
|
||||
"Variation seed strength": (None if p.subseed_strength == 0 else p.subseed_strength),
|
||||
"Seed resize from": (None if p.seed_resize_from_w <= 0 or p.seed_resize_from_h <= 0 else f"{p.seed_resize_from_w}x{p.seed_resize_from_h}"),
|
||||
"Denoising strength": getattr(p, 'denoising_strength', None),
|
||||
"Denoising strength": p.extra_generation_params.get("Denoising strength"),
|
||||
"Conditional mask weight": getattr(p, "inpainting_mask_weight", shared.opts.inpainting_mask_weight) if p.is_using_inpainting_conditioning else None,
|
||||
"Clip skip": None if clip_skip <= 1 else clip_skip,
|
||||
"ENSD": opts.eta_noise_seed_delta if uses_ensd else None,
|
||||
@@ -699,7 +747,7 @@ def create_infotext(p, all_prompts, all_seeds, all_subseeds, comments=None, iter
|
||||
"User": p.user if opts.add_user_name_to_info else None,
|
||||
}
|
||||
|
||||
generation_params_text = ", ".join([k if k == v else f'{k}: {generation_parameters_copypaste.quote(v)}' for k, v in generation_params.items() if v is not None])
|
||||
generation_params_text = ", ".join([k if k == v else f'{k}: {infotext_utils.quote(v)}' for k, v in generation_params.items() if v is not None])
|
||||
|
||||
prompt_text = p.main_prompt if use_main_prompt else all_prompts[index]
|
||||
negative_prompt_text = f"\nNegative prompt: {p.main_negative_prompt if use_main_prompt else all_negative_prompts[index]}" if all_negative_prompts[index] else ""
|
||||
@@ -818,7 +866,7 @@ def process_images_inner(p: StableDiffusionProcessing) -> Processed:
|
||||
if state.skipped:
|
||||
state.skipped = False
|
||||
|
||||
if state.interrupted:
|
||||
if state.interrupted or state.stopping_generation:
|
||||
break
|
||||
|
||||
sd_models.reload_model_weights() # model can be changed for example by refiner
|
||||
@@ -864,9 +912,42 @@ def process_images_inner(p: StableDiffusionProcessing) -> Processed:
|
||||
if p.n_iter > 1:
|
||||
shared.state.job = f"Batch {n+1} out of {p.n_iter}"
|
||||
|
||||
def rescale_zero_terminal_snr_abar(alphas_cumprod):
|
||||
alphas_bar_sqrt = alphas_cumprod.sqrt()
|
||||
|
||||
# Store old values.
|
||||
alphas_bar_sqrt_0 = alphas_bar_sqrt[0].clone()
|
||||
alphas_bar_sqrt_T = alphas_bar_sqrt[-1].clone()
|
||||
|
||||
# Shift so the last timestep is zero.
|
||||
alphas_bar_sqrt -= (alphas_bar_sqrt_T)
|
||||
|
||||
# Scale so the first timestep is back to the old value.
|
||||
alphas_bar_sqrt *= alphas_bar_sqrt_0 / (alphas_bar_sqrt_0 - alphas_bar_sqrt_T)
|
||||
|
||||
# Convert alphas_bar_sqrt to betas
|
||||
alphas_bar = alphas_bar_sqrt**2 # Revert sqrt
|
||||
alphas_bar[-1] = 4.8973451890853435e-08
|
||||
return alphas_bar
|
||||
|
||||
if hasattr(p.sd_model, 'alphas_cumprod') and hasattr(p.sd_model, 'alphas_cumprod_original'):
|
||||
p.sd_model.alphas_cumprod = p.sd_model.alphas_cumprod_original.to(shared.device)
|
||||
|
||||
if opts.use_downcasted_alpha_bar:
|
||||
p.extra_generation_params['Downcast alphas_cumprod'] = opts.use_downcasted_alpha_bar
|
||||
p.sd_model.alphas_cumprod = p.sd_model.alphas_cumprod.half().to(shared.device)
|
||||
if opts.sd_noise_schedule == "Zero Terminal SNR":
|
||||
p.extra_generation_params['Noise Schedule'] = opts.sd_noise_schedule
|
||||
p.sd_model.alphas_cumprod = rescale_zero_terminal_snr_abar(p.sd_model.alphas_cumprod).to(shared.device)
|
||||
|
||||
with devices.without_autocast() if devices.unet_needs_upcast else devices.autocast():
|
||||
samples_ddim = p.sample(conditioning=p.c, unconditional_conditioning=p.uc, seeds=p.seeds, subseeds=p.subseeds, subseed_strength=p.subseed_strength, prompts=p.prompts)
|
||||
|
||||
if p.scripts is not None:
|
||||
ps = scripts.PostSampleArgs(samples_ddim)
|
||||
p.scripts.post_sample(p, ps)
|
||||
samples_ddim = ps.samples
|
||||
|
||||
if getattr(samples_ddim, 'already_decoded', False):
|
||||
x_samples_ddim = samples_ddim
|
||||
else:
|
||||
@@ -922,13 +1003,42 @@ def process_images_inner(p: StableDiffusionProcessing) -> Processed:
|
||||
pp = scripts.PostprocessImageArgs(image)
|
||||
p.scripts.postprocess_image(p, pp)
|
||||
image = pp.image
|
||||
|
||||
mask_for_overlay = getattr(p, "mask_for_overlay", None)
|
||||
|
||||
if not shared.opts.overlay_inpaint:
|
||||
overlay_image = None
|
||||
elif getattr(p, "overlay_images", None) is not None and i < len(p.overlay_images):
|
||||
overlay_image = p.overlay_images[i]
|
||||
else:
|
||||
overlay_image = None
|
||||
|
||||
if p.scripts is not None:
|
||||
ppmo = scripts.PostProcessMaskOverlayArgs(i, mask_for_overlay, overlay_image)
|
||||
p.scripts.postprocess_maskoverlay(p, ppmo)
|
||||
mask_for_overlay, overlay_image = ppmo.mask_for_overlay, ppmo.overlay_image
|
||||
|
||||
if p.color_corrections is not None and i < len(p.color_corrections):
|
||||
if save_samples and opts.save_images_before_color_correction:
|
||||
image_without_cc = apply_overlay(image, p.paste_to, i, p.overlay_images)
|
||||
image_without_cc = apply_overlay(image, p.paste_to, overlay_image)
|
||||
images.save_image(image_without_cc, p.outpath_samples, "", p.seeds[i], p.prompts[i], opts.samples_format, info=infotext(i), p=p, suffix="-before-color-correction")
|
||||
image = apply_color_correction(p.color_corrections[i], image)
|
||||
|
||||
image = apply_overlay(image, p.paste_to, i, p.overlay_images)
|
||||
# If the intention is to show the output from the model
|
||||
# that is being composited over the original image,
|
||||
# we need to keep the original image around
|
||||
# and use it in the composite step.
|
||||
original_denoised_image = image.copy()
|
||||
|
||||
if p.paste_to is not None:
|
||||
original_denoised_image = uncrop(original_denoised_image, (overlay_image.width, overlay_image.height), p.paste_to)
|
||||
|
||||
image = apply_overlay(image, p.paste_to, overlay_image)
|
||||
|
||||
if p.scripts is not None:
|
||||
pp = scripts.PostprocessImageArgs(image)
|
||||
p.scripts.postprocess_image_after_composite(p, pp)
|
||||
image = pp.image
|
||||
|
||||
if save_samples:
|
||||
images.save_image(image, p.outpath_samples, "", p.seeds[i], p.prompts[i], opts.samples_format, info=infotext(i), p=p)
|
||||
@@ -938,16 +1048,17 @@ def process_images_inner(p: StableDiffusionProcessing) -> Processed:
|
||||
if opts.enable_pnginfo:
|
||||
image.info["parameters"] = text
|
||||
output_images.append(image)
|
||||
if hasattr(p, 'mask_for_overlay') and p.mask_for_overlay:
|
||||
|
||||
if mask_for_overlay is not None:
|
||||
if opts.return_mask or opts.save_mask:
|
||||
image_mask = p.mask_for_overlay.convert('RGB')
|
||||
image_mask = mask_for_overlay.convert('RGB')
|
||||
if save_samples and opts.save_mask:
|
||||
images.save_image(image_mask, p.outpath_samples, "", p.seeds[i], p.prompts[i], opts.samples_format, info=infotext(i), p=p, suffix="-mask")
|
||||
if opts.return_mask:
|
||||
output_images.append(image_mask)
|
||||
|
||||
if opts.return_mask_composite or opts.save_mask_composite:
|
||||
image_mask_composite = Image.composite(image.convert('RGBA').convert('RGBa'), Image.new('RGBa', image.size), images.resize_image(2, p.mask_for_overlay, image.width, image.height).convert('L')).convert('RGBA')
|
||||
image_mask_composite = Image.composite(original_denoised_image.convert('RGBA').convert('RGBa'), Image.new('RGBa', image.size), images.resize_image(2, mask_for_overlay, image.width, image.height).convert('L')).convert('RGBA')
|
||||
if save_samples and opts.save_mask_composite:
|
||||
images.save_image(image_mask_composite, p.outpath_samples, "", p.seeds[i], p.prompts[i], opts.samples_format, info=infotext(i), p=p, suffix="-mask-composite")
|
||||
if opts.return_mask_composite:
|
||||
@@ -1025,6 +1136,7 @@ class StableDiffusionProcessingTxt2Img(StableDiffusionProcessing):
|
||||
hr_sampler_name: str = None
|
||||
hr_prompt: str = ''
|
||||
hr_negative_prompt: str = ''
|
||||
force_task_id: str = None
|
||||
|
||||
cached_hr_uc = [None, None]
|
||||
cached_hr_c = [None, None]
|
||||
@@ -1097,7 +1209,9 @@ class StableDiffusionProcessingTxt2Img(StableDiffusionProcessing):
|
||||
|
||||
def init(self, all_prompts, all_seeds, all_subseeds):
|
||||
if self.enable_hr:
|
||||
if self.hr_checkpoint_name:
|
||||
self.extra_generation_params["Denoising strength"] = self.denoising_strength
|
||||
|
||||
if self.hr_checkpoint_name and self.hr_checkpoint_name != 'Use same checkpoint':
|
||||
self.hr_checkpoint_info = sd_models.get_closet_checkpoint_match(self.hr_checkpoint_name)
|
||||
|
||||
if self.hr_checkpoint_info is None:
|
||||
@@ -1124,8 +1238,11 @@ class StableDiffusionProcessingTxt2Img(StableDiffusionProcessing):
|
||||
if not state.processing_has_refined_job_count:
|
||||
if state.job_count == -1:
|
||||
state.job_count = self.n_iter
|
||||
|
||||
shared.total_tqdm.updateTotal((self.steps + (self.hr_second_pass_steps or self.steps)) * state.job_count)
|
||||
if getattr(self, 'txt2img_upscale', False):
|
||||
total_steps = (self.hr_second_pass_steps or self.steps) * state.job_count
|
||||
else:
|
||||
total_steps = (self.steps + (self.hr_second_pass_steps or self.steps)) * state.job_count
|
||||
shared.total_tqdm.updateTotal(total_steps)
|
||||
state.job_count = state.job_count * 2
|
||||
state.processing_has_refined_job_count = True
|
||||
|
||||
@@ -1138,18 +1255,45 @@ class StableDiffusionProcessingTxt2Img(StableDiffusionProcessing):
|
||||
def sample(self, conditioning, unconditional_conditioning, seeds, subseeds, subseed_strength, prompts):
|
||||
self.sampler = sd_samplers.create_sampler(self.sampler_name, self.sd_model)
|
||||
|
||||
x = self.rng.next()
|
||||
samples = self.sampler.sample(self, x, conditioning, unconditional_conditioning, image_conditioning=self.txt2img_image_conditioning(x))
|
||||
del x
|
||||
if self.firstpass_image is not None and self.enable_hr:
|
||||
# here we don't need to generate image, we just take self.firstpass_image and prepare it for hires fix
|
||||
|
||||
if not self.enable_hr:
|
||||
return samples
|
||||
devices.torch_gc()
|
||||
if self.latent_scale_mode is None:
|
||||
image = np.array(self.firstpass_image).astype(np.float32) / 255.0 * 2.0 - 1.0
|
||||
image = np.moveaxis(image, 2, 0)
|
||||
|
||||
samples = None
|
||||
decoded_samples = torch.asarray(np.expand_dims(image, 0))
|
||||
|
||||
else:
|
||||
image = np.array(self.firstpass_image).astype(np.float32) / 255.0
|
||||
image = np.moveaxis(image, 2, 0)
|
||||
image = torch.from_numpy(np.expand_dims(image, axis=0))
|
||||
image = image.to(shared.device, dtype=devices.dtype_vae)
|
||||
|
||||
if opts.sd_vae_encode_method != 'Full':
|
||||
self.extra_generation_params['VAE Encoder'] = opts.sd_vae_encode_method
|
||||
|
||||
samples = images_tensor_to_samples(image, approximation_indexes.get(opts.sd_vae_encode_method), self.sd_model)
|
||||
decoded_samples = None
|
||||
devices.torch_gc()
|
||||
|
||||
if self.latent_scale_mode is None:
|
||||
decoded_samples = torch.stack(decode_latent_batch(self.sd_model, samples, target_device=devices.cpu, check_for_nans=True)).to(dtype=torch.float32)
|
||||
else:
|
||||
decoded_samples = None
|
||||
# here we generate an image normally
|
||||
|
||||
x = self.rng.next()
|
||||
samples = self.sampler.sample(self, x, conditioning, unconditional_conditioning, image_conditioning=self.txt2img_image_conditioning(x))
|
||||
del x
|
||||
|
||||
if not self.enable_hr:
|
||||
return samples
|
||||
|
||||
devices.torch_gc()
|
||||
|
||||
if self.latent_scale_mode is None:
|
||||
decoded_samples = torch.stack(decode_latent_batch(self.sd_model, samples, target_device=devices.cpu, check_for_nans=True)).to(dtype=torch.float32)
|
||||
else:
|
||||
decoded_samples = None
|
||||
|
||||
with sd_models.SkipWritingToConfig():
|
||||
sd_models.reload_model_weights(info=self.hr_checkpoint_info)
|
||||
@@ -1351,12 +1495,14 @@ class StableDiffusionProcessingImg2Img(StableDiffusionProcessing):
|
||||
mask_blur_x: int = 4
|
||||
mask_blur_y: int = 4
|
||||
mask_blur: int = None
|
||||
mask_round: bool = True
|
||||
inpainting_fill: int = 0
|
||||
inpaint_full_res: bool = True
|
||||
inpaint_full_res_padding: int = 0
|
||||
inpainting_mask_invert: int = 0
|
||||
initial_noise_multiplier: float = None
|
||||
latent_mask: Image = None
|
||||
force_task_id: str = None
|
||||
|
||||
image_mask: Any = field(default=None, init=False)
|
||||
|
||||
@@ -1386,6 +1532,8 @@ class StableDiffusionProcessingImg2Img(StableDiffusionProcessing):
|
||||
self.mask_blur_y = value
|
||||
|
||||
def init(self, all_prompts, all_seeds, all_subseeds):
|
||||
self.extra_generation_params["Denoising strength"] = self.denoising_strength
|
||||
|
||||
self.image_cfg_scale: float = self.image_cfg_scale if shared.sd_model.cond_stage_key == "edit" else None
|
||||
|
||||
self.sampler = sd_samplers.create_sampler(self.sampler_name, self.sd_model)
|
||||
@@ -1396,10 +1544,11 @@ class StableDiffusionProcessingImg2Img(StableDiffusionProcessing):
|
||||
if image_mask is not None:
|
||||
# image_mask is passed in as RGBA by Gradio to support alpha masks,
|
||||
# but we still want to support binary masks.
|
||||
image_mask = create_binary_mask(image_mask)
|
||||
image_mask = create_binary_mask(image_mask, round=self.mask_round)
|
||||
|
||||
if self.inpainting_mask_invert:
|
||||
image_mask = ImageOps.invert(image_mask)
|
||||
self.extra_generation_params["Mask mode"] = "Inpaint not masked"
|
||||
|
||||
if self.mask_blur_x > 0:
|
||||
np_mask = np.array(image_mask)
|
||||
@@ -1413,16 +1562,22 @@ class StableDiffusionProcessingImg2Img(StableDiffusionProcessing):
|
||||
np_mask = cv2.GaussianBlur(np_mask, (1, kernel_size), self.mask_blur_y)
|
||||
image_mask = Image.fromarray(np_mask)
|
||||
|
||||
if self.mask_blur_x > 0 or self.mask_blur_y > 0:
|
||||
self.extra_generation_params["Mask blur"] = self.mask_blur
|
||||
|
||||
if self.inpaint_full_res:
|
||||
self.mask_for_overlay = image_mask
|
||||
mask = image_mask.convert('L')
|
||||
crop_region = masking.get_crop_region(np.array(mask), self.inpaint_full_res_padding)
|
||||
crop_region = masking.get_crop_region(mask, self.inpaint_full_res_padding)
|
||||
crop_region = masking.expand_crop_region(crop_region, self.width, self.height, mask.width, mask.height)
|
||||
x1, y1, x2, y2 = crop_region
|
||||
|
||||
mask = mask.crop(crop_region)
|
||||
image_mask = images.resize_image(2, mask, self.width, self.height)
|
||||
self.paste_to = (x1, y1, x2-x1, y2-y1)
|
||||
|
||||
self.extra_generation_params["Inpaint area"] = "Only masked"
|
||||
self.extra_generation_params["Masked area padding"] = self.inpaint_full_res_padding
|
||||
else:
|
||||
image_mask = images.resize_image(self.resize_mode, image_mask, self.width, self.height)
|
||||
np_mask = np.array(image_mask)
|
||||
@@ -1442,7 +1597,7 @@ class StableDiffusionProcessingImg2Img(StableDiffusionProcessing):
|
||||
# Save init image
|
||||
if opts.save_init_img:
|
||||
self.init_img_hash = hashlib.md5(img.tobytes()).hexdigest()
|
||||
images.save_image(img, path=opts.outdir_init_images, basename=None, forced_filename=self.init_img_hash, save_to_dirs=False)
|
||||
images.save_image(img, path=opts.outdir_init_images, basename=None, forced_filename=self.init_img_hash, save_to_dirs=False, existing_info=img.info)
|
||||
|
||||
image = images.flatten(img, opts.img2img_background_color)
|
||||
|
||||
@@ -1464,6 +1619,9 @@ class StableDiffusionProcessingImg2Img(StableDiffusionProcessing):
|
||||
if self.inpainting_fill != 1:
|
||||
image = masking.fill(image, latent_mask)
|
||||
|
||||
if self.inpainting_fill == 0:
|
||||
self.extra_generation_params["Masked content"] = 'fill'
|
||||
|
||||
if add_color_corrections:
|
||||
self.color_corrections.append(setup_color_correction(image))
|
||||
|
||||
@@ -1503,7 +1661,8 @@ class StableDiffusionProcessingImg2Img(StableDiffusionProcessing):
|
||||
latmask = init_mask.convert('RGB').resize((self.init_latent.shape[3], self.init_latent.shape[2]))
|
||||
latmask = np.moveaxis(np.array(latmask, dtype=np.float32), 2, 0) / 255
|
||||
latmask = latmask[0]
|
||||
latmask = np.around(latmask)
|
||||
if self.mask_round:
|
||||
latmask = np.around(latmask)
|
||||
latmask = np.tile(latmask[None], (4, 1, 1))
|
||||
|
||||
self.mask = torch.asarray(1.0 - latmask).to(shared.device).type(self.sd_model.dtype)
|
||||
@@ -1512,10 +1671,13 @@ class StableDiffusionProcessingImg2Img(StableDiffusionProcessing):
|
||||
# this needs to be fixed to be done in sample() using actual seeds for batches
|
||||
if self.inpainting_fill == 2:
|
||||
self.init_latent = self.init_latent * self.mask + create_random_tensors(self.init_latent.shape[1:], all_seeds[0:self.init_latent.shape[0]]) * self.nmask
|
||||
self.extra_generation_params["Masked content"] = 'latent noise'
|
||||
|
||||
elif self.inpainting_fill == 3:
|
||||
self.init_latent = self.init_latent * self.mask
|
||||
self.extra_generation_params["Masked content"] = 'latent nothing'
|
||||
|
||||
self.image_conditioning = self.img2img_image_conditioning(image * 2 - 1, self.init_latent, image_mask)
|
||||
self.image_conditioning = self.img2img_image_conditioning(image * 2 - 1, self.init_latent, image_mask, self.mask_round)
|
||||
|
||||
def sample(self, conditioning, unconditional_conditioning, seeds, subseeds, subseed_strength, prompts):
|
||||
x = self.rng.next()
|
||||
@@ -1527,7 +1689,14 @@ class StableDiffusionProcessingImg2Img(StableDiffusionProcessing):
|
||||
samples = self.sampler.sample_img2img(self, self.init_latent, x, conditioning, unconditional_conditioning, image_conditioning=self.image_conditioning)
|
||||
|
||||
if self.mask is not None:
|
||||
samples = samples * self.nmask + self.init_latent * self.mask
|
||||
blended_samples = samples * self.nmask + self.init_latent * self.mask
|
||||
|
||||
if self.scripts is not None:
|
||||
mba = scripts.MaskBlendArgs(samples, self.nmask, self.init_latent, self.mask, blended_samples)
|
||||
self.scripts.on_mask_blend(self, mba)
|
||||
blended_samples = mba.blended_latent
|
||||
|
||||
samples = blended_samples
|
||||
|
||||
del x
|
||||
devices.torch_gc()
|
||||
|
||||
@@ -1,6 +1,7 @@
|
||||
import gradio as gr
|
||||
|
||||
from modules import scripts, sd_models
|
||||
from modules.infotext_utils import PasteField
|
||||
from modules.ui_common import create_refresh_button
|
||||
from modules.ui_components import InputAccordion
|
||||
|
||||
@@ -31,9 +32,9 @@ class ScriptRefiner(scripts.ScriptBuiltinUI):
|
||||
return None if info is None else info.title
|
||||
|
||||
self.infotext_fields = [
|
||||
(enable_refiner, lambda d: 'Refiner' in d),
|
||||
(refiner_checkpoint, lambda d: lookup_checkpoint(d.get('Refiner'))),
|
||||
(refiner_switch_at, 'Refiner switch at'),
|
||||
PasteField(enable_refiner, lambda d: 'Refiner' in d),
|
||||
PasteField(refiner_checkpoint, lambda d: lookup_checkpoint(d.get('Refiner')), api="refiner_checkpoint"),
|
||||
PasteField(refiner_switch_at, 'Refiner switch at', api="refiner_switch_at"),
|
||||
]
|
||||
|
||||
return enable_refiner, refiner_checkpoint, refiner_switch_at
|
||||
|
||||
@@ -3,8 +3,10 @@ import json
|
||||
import gradio as gr
|
||||
|
||||
from modules import scripts, ui, errors
|
||||
from modules.infotext_utils import PasteField
|
||||
from modules.shared import cmd_opts
|
||||
from modules.ui_components import ToolButton
|
||||
from modules import infotext_utils
|
||||
|
||||
|
||||
class ScriptSeed(scripts.ScriptBuiltinUI):
|
||||
@@ -51,12 +53,12 @@ class ScriptSeed(scripts.ScriptBuiltinUI):
|
||||
seed_checkbox.change(lambda x: gr.update(visible=x), show_progress=False, inputs=[seed_checkbox], outputs=[seed_extras])
|
||||
|
||||
self.infotext_fields = [
|
||||
(self.seed, "Seed"),
|
||||
(seed_checkbox, lambda d: "Variation seed" in d or "Seed resize from-1" in d),
|
||||
(subseed, "Variation seed"),
|
||||
(subseed_strength, "Variation seed strength"),
|
||||
(seed_resize_from_w, "Seed resize from-1"),
|
||||
(seed_resize_from_h, "Seed resize from-2"),
|
||||
PasteField(self.seed, "Seed", api="seed"),
|
||||
PasteField(seed_checkbox, lambda d: "Variation seed" in d or "Seed resize from-1" in d),
|
||||
PasteField(subseed, "Variation seed", api="subseed"),
|
||||
PasteField(subseed_strength, "Variation seed strength", api="subseed_strength"),
|
||||
PasteField(seed_resize_from_w, "Seed resize from-1", api="seed_resize_from_h"),
|
||||
PasteField(seed_resize_from_h, "Seed resize from-2", api="seed_resize_from_w"),
|
||||
]
|
||||
|
||||
self.on_after_component(lambda x: connect_reuse_seed(self.seed, reuse_seed, x.component, False), elem_id=f'generation_info_{self.tabname}')
|
||||
@@ -76,7 +78,6 @@ class ScriptSeed(scripts.ScriptBuiltinUI):
|
||||
p.seed_resize_from_h = seed_resize_from_h
|
||||
|
||||
|
||||
|
||||
def connect_reuse_seed(seed: gr.Number, reuse_seed: gr.Button, generation_info: gr.Textbox, is_subseed):
|
||||
""" Connects a 'reuse (sub)seed' button's click event so that it copies last used
|
||||
(sub)seed value from generation info the to the seed field. If copying subseed and subseed strength
|
||||
@@ -84,21 +85,14 @@ def connect_reuse_seed(seed: gr.Number, reuse_seed: gr.Button, generation_info:
|
||||
|
||||
def copy_seed(gen_info_string: str, index):
|
||||
res = -1
|
||||
|
||||
try:
|
||||
gen_info = json.loads(gen_info_string)
|
||||
index -= gen_info.get('index_of_first_image', 0)
|
||||
|
||||
if is_subseed and gen_info.get('subseed_strength', 0) > 0:
|
||||
all_subseeds = gen_info.get('all_subseeds', [-1])
|
||||
res = all_subseeds[index if 0 <= index < len(all_subseeds) else 0]
|
||||
else:
|
||||
all_seeds = gen_info.get('all_seeds', [-1])
|
||||
res = all_seeds[index if 0 <= index < len(all_seeds) else 0]
|
||||
|
||||
except json.decoder.JSONDecodeError:
|
||||
infotext = gen_info.get('infotexts')[index]
|
||||
gen_parameters = infotext_utils.parse_generation_parameters(infotext, [])
|
||||
res = int(gen_parameters.get('Variation seed' if is_subseed else 'Seed', -1))
|
||||
except Exception:
|
||||
if gen_info_string:
|
||||
errors.report(f"Error parsing JSON generation info: {gen_info_string}")
|
||||
errors.report(f"Error retrieving seed from generation info: {gen_info_string}", exc_info=True)
|
||||
|
||||
return [res, gr.update()]
|
||||
|
||||
|
||||
@@ -8,10 +8,13 @@ from pydantic import BaseModel, Field
|
||||
from modules.shared import opts
|
||||
|
||||
import modules.shared as shared
|
||||
|
||||
from collections import OrderedDict
|
||||
import string
|
||||
import random
|
||||
from typing import List
|
||||
|
||||
current_task = None
|
||||
pending_tasks = {}
|
||||
pending_tasks = OrderedDict()
|
||||
finished_tasks = []
|
||||
recorded_results = []
|
||||
recorded_results_limit = 2
|
||||
@@ -34,6 +37,11 @@ def finish_task(id_task):
|
||||
if len(finished_tasks) > 16:
|
||||
finished_tasks.pop(0)
|
||||
|
||||
def create_task_id(task_type):
|
||||
N = 7
|
||||
res = ''.join(random.choices(string.ascii_uppercase +
|
||||
string.digits, k=N))
|
||||
return f"task({task_type}-{res})"
|
||||
|
||||
def record_results(id_task, res):
|
||||
recorded_results.append((id_task, res))
|
||||
@@ -44,6 +52,9 @@ def record_results(id_task, res):
|
||||
def add_task_to_queue(id_job):
|
||||
pending_tasks[id_job] = time.time()
|
||||
|
||||
class PendingTasksResponse(BaseModel):
|
||||
size: int = Field(title="Pending task size")
|
||||
tasks: List[str] = Field(title="Pending task ids")
|
||||
|
||||
class ProgressRequest(BaseModel):
|
||||
id_task: str = Field(default=None, title="Task ID", description="id of the task to get progress for")
|
||||
@@ -63,9 +74,16 @@ class ProgressResponse(BaseModel):
|
||||
|
||||
|
||||
def setup_progress_api(app):
|
||||
app.add_api_route("/internal/pending-tasks", get_pending_tasks, methods=["GET"])
|
||||
return app.add_api_route("/internal/progress", progressapi, methods=["POST"], response_model=ProgressResponse)
|
||||
|
||||
|
||||
def get_pending_tasks():
|
||||
pending_tasks_ids = list(pending_tasks)
|
||||
pending_len = len(pending_tasks_ids)
|
||||
return PendingTasksResponse(size=pending_len, tasks=pending_tasks_ids)
|
||||
|
||||
|
||||
def progressapi(req: ProgressRequest):
|
||||
active = req.id_task == current_task
|
||||
queued = req.id_task in pending_tasks
|
||||
|
||||
@@ -1,12 +1,9 @@
|
||||
import os
|
||||
|
||||
import numpy as np
|
||||
from PIL import Image
|
||||
from realesrgan import RealESRGANer
|
||||
|
||||
from modules.upscaler import Upscaler, UpscalerData
|
||||
from modules.shared import cmd_opts, opts
|
||||
from modules import modelloader, errors
|
||||
from modules.shared import cmd_opts, opts
|
||||
from modules.upscaler import Upscaler, UpscalerData
|
||||
from modules.upscaler_utils import upscale_with_model
|
||||
|
||||
|
||||
class UpscalerRealESRGAN(Upscaler):
|
||||
@@ -14,29 +11,20 @@ class UpscalerRealESRGAN(Upscaler):
|
||||
self.name = "RealESRGAN"
|
||||
self.user_path = path
|
||||
super().__init__()
|
||||
try:
|
||||
from basicsr.archs.rrdbnet_arch import RRDBNet # noqa: F401
|
||||
from realesrgan import RealESRGANer # noqa: F401
|
||||
from realesrgan.archs.srvgg_arch import SRVGGNetCompact # noqa: F401
|
||||
self.enable = True
|
||||
self.scalers = []
|
||||
scalers = self.load_models(path)
|
||||
self.enable = True
|
||||
self.scalers = []
|
||||
scalers = get_realesrgan_models(self)
|
||||
|
||||
local_model_paths = self.find_models(ext_filter=[".pth"])
|
||||
for scaler in scalers:
|
||||
if scaler.local_data_path.startswith("http"):
|
||||
filename = modelloader.friendly_name(scaler.local_data_path)
|
||||
local_model_candidates = [local_model for local_model in local_model_paths if local_model.endswith(f"{filename}.pth")]
|
||||
if local_model_candidates:
|
||||
scaler.local_data_path = local_model_candidates[0]
|
||||
local_model_paths = self.find_models(ext_filter=[".pth"])
|
||||
for scaler in scalers:
|
||||
if scaler.local_data_path.startswith("http"):
|
||||
filename = modelloader.friendly_name(scaler.local_data_path)
|
||||
local_model_candidates = [local_model for local_model in local_model_paths if local_model.endswith(f"{filename}.pth")]
|
||||
if local_model_candidates:
|
||||
scaler.local_data_path = local_model_candidates[0]
|
||||
|
||||
if scaler.name in opts.realesrgan_enabled_models:
|
||||
self.scalers.append(scaler)
|
||||
|
||||
except Exception:
|
||||
errors.report("Error importing Real-ESRGAN", exc_info=True)
|
||||
self.enable = False
|
||||
self.scalers = []
|
||||
if scaler.name in opts.realesrgan_enabled_models:
|
||||
self.scalers.append(scaler)
|
||||
|
||||
def do_upscale(self, img, path):
|
||||
if not self.enable:
|
||||
@@ -48,20 +36,19 @@ class UpscalerRealESRGAN(Upscaler):
|
||||
errors.report(f"Unable to load RealESRGAN model {path}", exc_info=True)
|
||||
return img
|
||||
|
||||
upsampler = RealESRGANer(
|
||||
scale=info.scale,
|
||||
model_path=info.local_data_path,
|
||||
model=info.model(),
|
||||
half=not cmd_opts.no_half and not cmd_opts.upcast_sampling,
|
||||
tile=opts.ESRGAN_tile,
|
||||
tile_pad=opts.ESRGAN_tile_overlap,
|
||||
model_descriptor = modelloader.load_spandrel_model(
|
||||
info.local_data_path,
|
||||
device=self.device,
|
||||
prefer_half=(not cmd_opts.no_half and not cmd_opts.upcast_sampling),
|
||||
expected_architecture="ESRGAN", # "RealESRGAN" isn't a specific thing for Spandrel
|
||||
)
|
||||
return upscale_with_model(
|
||||
model_descriptor,
|
||||
img,
|
||||
tile_size=opts.ESRGAN_tile,
|
||||
tile_overlap=opts.ESRGAN_tile_overlap,
|
||||
# TODO: `outscale`?
|
||||
)
|
||||
|
||||
upsampled = upsampler.enhance(np.array(img), outscale=info.scale)[0]
|
||||
|
||||
image = Image.fromarray(upsampled)
|
||||
return image
|
||||
|
||||
def load_model(self, path):
|
||||
for scaler in self.scalers:
|
||||
@@ -76,58 +63,43 @@ class UpscalerRealESRGAN(Upscaler):
|
||||
return scaler
|
||||
raise ValueError(f"Unable to find model info: {path}")
|
||||
|
||||
def load_models(self, _):
|
||||
return get_realesrgan_models(self)
|
||||
|
||||
|
||||
def get_realesrgan_models(scaler):
|
||||
try:
|
||||
from basicsr.archs.rrdbnet_arch import RRDBNet
|
||||
from realesrgan.archs.srvgg_arch import SRVGGNetCompact
|
||||
models = [
|
||||
UpscalerData(
|
||||
name="R-ESRGAN General 4xV3",
|
||||
path="https://github.com/xinntao/Real-ESRGAN/releases/download/v0.2.5.0/realesr-general-x4v3.pth",
|
||||
scale=4,
|
||||
upscaler=scaler,
|
||||
model=lambda: SRVGGNetCompact(num_in_ch=3, num_out_ch=3, num_feat=64, num_conv=32, upscale=4, act_type='prelu')
|
||||
),
|
||||
UpscalerData(
|
||||
name="R-ESRGAN General WDN 4xV3",
|
||||
path="https://github.com/xinntao/Real-ESRGAN/releases/download/v0.2.5.0/realesr-general-wdn-x4v3.pth",
|
||||
scale=4,
|
||||
upscaler=scaler,
|
||||
model=lambda: SRVGGNetCompact(num_in_ch=3, num_out_ch=3, num_feat=64, num_conv=32, upscale=4, act_type='prelu')
|
||||
),
|
||||
UpscalerData(
|
||||
name="R-ESRGAN AnimeVideo",
|
||||
path="https://github.com/xinntao/Real-ESRGAN/releases/download/v0.2.5.0/realesr-animevideov3.pth",
|
||||
scale=4,
|
||||
upscaler=scaler,
|
||||
model=lambda: SRVGGNetCompact(num_in_ch=3, num_out_ch=3, num_feat=64, num_conv=16, upscale=4, act_type='prelu')
|
||||
),
|
||||
UpscalerData(
|
||||
name="R-ESRGAN 4x+",
|
||||
path="https://github.com/xinntao/Real-ESRGAN/releases/download/v0.1.0/RealESRGAN_x4plus.pth",
|
||||
scale=4,
|
||||
upscaler=scaler,
|
||||
model=lambda: RRDBNet(num_in_ch=3, num_out_ch=3, num_feat=64, num_block=23, num_grow_ch=32, scale=4)
|
||||
),
|
||||
UpscalerData(
|
||||
name="R-ESRGAN 4x+ Anime6B",
|
||||
path="https://github.com/xinntao/Real-ESRGAN/releases/download/v0.2.2.4/RealESRGAN_x4plus_anime_6B.pth",
|
||||
scale=4,
|
||||
upscaler=scaler,
|
||||
model=lambda: RRDBNet(num_in_ch=3, num_out_ch=3, num_feat=64, num_block=6, num_grow_ch=32, scale=4)
|
||||
),
|
||||
UpscalerData(
|
||||
name="R-ESRGAN 2x+",
|
||||
path="https://github.com/xinntao/Real-ESRGAN/releases/download/v0.2.1/RealESRGAN_x2plus.pth",
|
||||
scale=2,
|
||||
upscaler=scaler,
|
||||
model=lambda: RRDBNet(num_in_ch=3, num_out_ch=3, num_feat=64, num_block=23, num_grow_ch=32, scale=2)
|
||||
),
|
||||
]
|
||||
return models
|
||||
except Exception:
|
||||
errors.report("Error making Real-ESRGAN models list", exc_info=True)
|
||||
def get_realesrgan_models(scaler: UpscalerRealESRGAN):
|
||||
return [
|
||||
UpscalerData(
|
||||
name="R-ESRGAN General 4xV3",
|
||||
path="https://github.com/xinntao/Real-ESRGAN/releases/download/v0.2.5.0/realesr-general-x4v3.pth",
|
||||
scale=4,
|
||||
upscaler=scaler,
|
||||
),
|
||||
UpscalerData(
|
||||
name="R-ESRGAN General WDN 4xV3",
|
||||
path="https://github.com/xinntao/Real-ESRGAN/releases/download/v0.2.5.0/realesr-general-wdn-x4v3.pth",
|
||||
scale=4,
|
||||
upscaler=scaler,
|
||||
),
|
||||
UpscalerData(
|
||||
name="R-ESRGAN AnimeVideo",
|
||||
path="https://github.com/xinntao/Real-ESRGAN/releases/download/v0.2.5.0/realesr-animevideov3.pth",
|
||||
scale=4,
|
||||
upscaler=scaler,
|
||||
),
|
||||
UpscalerData(
|
||||
name="R-ESRGAN 4x+",
|
||||
path="https://github.com/xinntao/Real-ESRGAN/releases/download/v0.1.0/RealESRGAN_x4plus.pth",
|
||||
scale=4,
|
||||
upscaler=scaler,
|
||||
),
|
||||
UpscalerData(
|
||||
name="R-ESRGAN 4x+ Anime6B",
|
||||
path="https://github.com/xinntao/Real-ESRGAN/releases/download/v0.2.2.4/RealESRGAN_x4plus_anime_6B.pth",
|
||||
scale=4,
|
||||
upscaler=scaler,
|
||||
),
|
||||
UpscalerData(
|
||||
name="R-ESRGAN 2x+",
|
||||
path="https://github.com/xinntao/Real-ESRGAN/releases/download/v0.2.1/RealESRGAN_x2plus.pth",
|
||||
scale=2,
|
||||
upscaler=scaler,
|
||||
),
|
||||
]
|
||||
|
||||
@@ -41,7 +41,7 @@ class ExtraNoiseParams:
|
||||
|
||||
|
||||
class CFGDenoiserParams:
|
||||
def __init__(self, x, image_cond, sigma, sampling_step, total_sampling_steps, text_cond, text_uncond):
|
||||
def __init__(self, x, image_cond, sigma, sampling_step, total_sampling_steps, text_cond, text_uncond, denoiser=None):
|
||||
self.x = x
|
||||
"""Latent image representation in the process of being denoised"""
|
||||
|
||||
@@ -63,6 +63,9 @@ class CFGDenoiserParams:
|
||||
self.text_uncond = text_uncond
|
||||
""" Encoder hidden states of text conditioning from negative prompt"""
|
||||
|
||||
self.denoiser = denoiser
|
||||
"""Current CFGDenoiser object with processing parameters"""
|
||||
|
||||
|
||||
class CFGDenoisedParams:
|
||||
def __init__(self, x, sampling_step, total_sampling_steps, inner_model):
|
||||
|
||||
@@ -11,11 +11,31 @@ from modules import shared, paths, script_callbacks, extensions, script_loading,
|
||||
|
||||
AlwaysVisible = object()
|
||||
|
||||
class MaskBlendArgs:
|
||||
def __init__(self, current_latent, nmask, init_latent, mask, blended_latent, denoiser=None, sigma=None):
|
||||
self.current_latent = current_latent
|
||||
self.nmask = nmask
|
||||
self.init_latent = init_latent
|
||||
self.mask = mask
|
||||
self.blended_latent = blended_latent
|
||||
|
||||
self.denoiser = denoiser
|
||||
self.is_final_blend = denoiser is None
|
||||
self.sigma = sigma
|
||||
|
||||
class PostSampleArgs:
|
||||
def __init__(self, samples):
|
||||
self.samples = samples
|
||||
|
||||
class PostprocessImageArgs:
|
||||
def __init__(self, image):
|
||||
self.image = image
|
||||
|
||||
class PostProcessMaskOverlayArgs:
|
||||
def __init__(self, index, mask_for_overlay, overlay_image):
|
||||
self.index = index
|
||||
self.mask_for_overlay = mask_for_overlay
|
||||
self.overlay_image = overlay_image
|
||||
|
||||
class PostprocessBatchListArgs:
|
||||
def __init__(self, images):
|
||||
@@ -71,6 +91,9 @@ class Script:
|
||||
setup_for_ui_only = False
|
||||
"""If true, the script setup will only be run in Gradio UI, not in API"""
|
||||
|
||||
controls = None
|
||||
"""A list of controls retured by the ui()."""
|
||||
|
||||
def title(self):
|
||||
"""this function should return the title of the script. This is what will be displayed in the dropdown menu."""
|
||||
|
||||
@@ -206,6 +229,25 @@ class Script:
|
||||
|
||||
pass
|
||||
|
||||
def on_mask_blend(self, p, mba: MaskBlendArgs, *args):
|
||||
"""
|
||||
Called in inpainting mode when the original content is blended with the inpainted content.
|
||||
This is called at every step in the denoising process and once at the end.
|
||||
If is_final_blend is true, this is called for the final blending stage.
|
||||
Otherwise, denoiser and sigma are defined and may be used to inform the procedure.
|
||||
"""
|
||||
|
||||
pass
|
||||
|
||||
def post_sample(self, p, ps: PostSampleArgs, *args):
|
||||
"""
|
||||
Called after the samples have been generated,
|
||||
but before they have been decoded by the VAE, if applicable.
|
||||
Check getattr(samples, 'already_decoded', False) to test if the images are decoded.
|
||||
"""
|
||||
|
||||
pass
|
||||
|
||||
def postprocess_image(self, p, pp: PostprocessImageArgs, *args):
|
||||
"""
|
||||
Called for every image after it has been generated.
|
||||
@@ -213,6 +255,22 @@ class Script:
|
||||
|
||||
pass
|
||||
|
||||
def postprocess_maskoverlay(self, p, ppmo: PostProcessMaskOverlayArgs, *args):
|
||||
"""
|
||||
Called for every image after it has been generated.
|
||||
"""
|
||||
|
||||
pass
|
||||
|
||||
def postprocess_image_after_composite(self, p, pp: PostprocessImageArgs, *args):
|
||||
"""
|
||||
Called for every image after it has been generated.
|
||||
Same as postprocess_image but after inpaint_full_res composite
|
||||
So that it operates on the full image instead of the inpaint_full_res crop region.
|
||||
"""
|
||||
|
||||
pass
|
||||
|
||||
def postprocess(self, p, processed, *args):
|
||||
"""
|
||||
This function is called after processing ends for AlwaysVisible scripts.
|
||||
@@ -520,7 +578,12 @@ class ScriptRunner:
|
||||
auto_processing_scripts = scripts_auto_postprocessing.create_auto_preprocessing_script_data()
|
||||
|
||||
for script_data in auto_processing_scripts + scripts_data:
|
||||
script = script_data.script_class()
|
||||
try:
|
||||
script = script_data.script_class()
|
||||
except Exception:
|
||||
errors.report(f"Error # failed to initialize Script {script_data.module}: ", exc_info=True)
|
||||
continue
|
||||
|
||||
script.filename = script_data.path
|
||||
script.is_txt2img = not is_img2img
|
||||
script.is_img2img = is_img2img
|
||||
@@ -573,6 +636,7 @@ class ScriptRunner:
|
||||
import modules.api.models as api_models
|
||||
|
||||
controls = wrap_call(script.ui, script.filename, "ui", script.is_img2img)
|
||||
script.controls = controls
|
||||
|
||||
if controls is None:
|
||||
return
|
||||
@@ -645,6 +709,8 @@ class ScriptRunner:
|
||||
self.setup_ui_for_section(None, self.selectable_scripts)
|
||||
|
||||
def select_script(script_index):
|
||||
if script_index is None:
|
||||
script_index = 0
|
||||
selected_script = self.selectable_scripts[script_index - 1] if script_index>0 else None
|
||||
|
||||
return [gr.update(visible=selected_script == s) for s in self.selectable_scripts]
|
||||
@@ -688,7 +754,7 @@ class ScriptRunner:
|
||||
def run(self, p, *args):
|
||||
script_index = args[0]
|
||||
|
||||
if script_index == 0:
|
||||
if script_index == 0 or script_index is None:
|
||||
return None
|
||||
|
||||
script = self.selectable_scripts[script_index-1]
|
||||
@@ -767,6 +833,22 @@ class ScriptRunner:
|
||||
except Exception:
|
||||
errors.report(f"Error running postprocess_batch_list: {script.filename}", exc_info=True)
|
||||
|
||||
def post_sample(self, p, ps: PostSampleArgs):
|
||||
for script in self.alwayson_scripts:
|
||||
try:
|
||||
script_args = p.script_args[script.args_from:script.args_to]
|
||||
script.post_sample(p, ps, *script_args)
|
||||
except Exception:
|
||||
errors.report(f"Error running post_sample: {script.filename}", exc_info=True)
|
||||
|
||||
def on_mask_blend(self, p, mba: MaskBlendArgs):
|
||||
for script in self.alwayson_scripts:
|
||||
try:
|
||||
script_args = p.script_args[script.args_from:script.args_to]
|
||||
script.on_mask_blend(p, mba, *script_args)
|
||||
except Exception:
|
||||
errors.report(f"Error running post_sample: {script.filename}", exc_info=True)
|
||||
|
||||
def postprocess_image(self, p, pp: PostprocessImageArgs):
|
||||
for script in self.alwayson_scripts:
|
||||
try:
|
||||
@@ -775,6 +857,22 @@ class ScriptRunner:
|
||||
except Exception:
|
||||
errors.report(f"Error running postprocess_image: {script.filename}", exc_info=True)
|
||||
|
||||
def postprocess_maskoverlay(self, p, ppmo: PostProcessMaskOverlayArgs):
|
||||
for script in self.alwayson_scripts:
|
||||
try:
|
||||
script_args = p.script_args[script.args_from:script.args_to]
|
||||
script.postprocess_maskoverlay(p, ppmo, *script_args)
|
||||
except Exception:
|
||||
errors.report(f"Error running postprocess_image: {script.filename}", exc_info=True)
|
||||
|
||||
def postprocess_image_after_composite(self, p, pp: PostprocessImageArgs):
|
||||
for script in self.alwayson_scripts:
|
||||
try:
|
||||
script_args = p.script_args[script.args_from:script.args_to]
|
||||
script.postprocess_image_after_composite(p, pp, *script_args)
|
||||
except Exception:
|
||||
errors.report(f"Error running postprocess_image_after_composite: {script.filename}", exc_info=True)
|
||||
|
||||
def before_component(self, component, **kwargs):
|
||||
for callback, script in self.on_before_component_elem_id.get(kwargs.get("elem_id"), []):
|
||||
try:
|
||||
@@ -841,6 +939,35 @@ class ScriptRunner:
|
||||
except Exception:
|
||||
errors.report(f"Error running setup: {script.filename}", exc_info=True)
|
||||
|
||||
def set_named_arg(self, args, script_name, arg_elem_id, value, fuzzy=False):
|
||||
"""Locate an arg of a specific script in script_args and set its value
|
||||
Args:
|
||||
args: all script args of process p, p.script_args
|
||||
script_name: the name target script name to
|
||||
arg_elem_id: the elem_id of the target arg
|
||||
value: the value to set
|
||||
fuzzy: if True, arg_elem_id can be a substring of the control.elem_id else exact match
|
||||
Returns:
|
||||
Updated script args
|
||||
when script_name in not found or arg_elem_id is not found in script controls, raise RuntimeError
|
||||
"""
|
||||
script = next((x for x in self.scripts if x.name == script_name), None)
|
||||
if script is None:
|
||||
raise RuntimeError(f"script {script_name} not found")
|
||||
|
||||
for i, control in enumerate(script.controls):
|
||||
if arg_elem_id in control.elem_id if fuzzy else arg_elem_id == control.elem_id:
|
||||
index = script.args_from + i
|
||||
|
||||
if isinstance(args, tuple):
|
||||
return args[:index] + (value,) + args[index + 1:]
|
||||
elif isinstance(args, list):
|
||||
args[index] = value
|
||||
return args
|
||||
else:
|
||||
raise RuntimeError(f"args is not a list or tuple, but {type(args)}")
|
||||
raise RuntimeError(f"arg_elem_id {arg_elem_id} not found in script {script_name}")
|
||||
|
||||
|
||||
scripts_txt2img: ScriptRunner = None
|
||||
scripts_img2img: ScriptRunner = None
|
||||
|
||||
@@ -11,10 +11,14 @@ class CondFunc:
|
||||
break
|
||||
except ImportError:
|
||||
pass
|
||||
for attr_name in func_path[i:-1]:
|
||||
resolved_obj = getattr(resolved_obj, attr_name)
|
||||
orig_func = getattr(resolved_obj, func_path[-1])
|
||||
setattr(resolved_obj, func_path[-1], lambda *args, **kwargs: self(*args, **kwargs))
|
||||
try:
|
||||
for attr_name in func_path[i:-1]:
|
||||
resolved_obj = getattr(resolved_obj, attr_name)
|
||||
orig_func = getattr(resolved_obj, func_path[-1])
|
||||
setattr(resolved_obj, func_path[-1], lambda *args, **kwargs: self(*args, **kwargs))
|
||||
except AttributeError:
|
||||
print(f"Warning: Failed to resolve {orig_func} for CondFunc hijack")
|
||||
pass
|
||||
self.__init__(orig_func, sub_func, cond_func)
|
||||
return lambda *args, **kwargs: self(*args, **kwargs)
|
||||
def __init__(self, orig_func, sub_func, cond_func):
|
||||
|
||||
@@ -348,10 +348,28 @@ class SkipWritingToConfig:
|
||||
SkipWritingToConfig.skip = self.previous
|
||||
|
||||
|
||||
def check_fp8(model):
|
||||
if model is None:
|
||||
return None
|
||||
if devices.get_optimal_device_name() == "mps":
|
||||
enable_fp8 = False
|
||||
elif shared.opts.fp8_storage == "Enable":
|
||||
enable_fp8 = True
|
||||
elif getattr(model, "is_sdxl", False) and shared.opts.fp8_storage == "Enable for SDXL":
|
||||
enable_fp8 = True
|
||||
else:
|
||||
enable_fp8 = False
|
||||
return enable_fp8
|
||||
|
||||
|
||||
def load_model_weights(model, checkpoint_info: CheckpointInfo, state_dict, timer):
|
||||
sd_model_hash = checkpoint_info.calculate_shorthash()
|
||||
timer.record("calculate hash")
|
||||
|
||||
if devices.fp8:
|
||||
# prevent model to load state dict in fp8
|
||||
model.half()
|
||||
|
||||
if not SkipWritingToConfig.skip:
|
||||
shared.opts.data["sd_model_checkpoint"] = checkpoint_info.title
|
||||
|
||||
@@ -383,6 +401,7 @@ def load_model_weights(model, checkpoint_info: CheckpointInfo, state_dict, timer
|
||||
|
||||
if shared.cmd_opts.no_half:
|
||||
model.float()
|
||||
model.alphas_cumprod_original = model.alphas_cumprod
|
||||
devices.dtype_unet = torch.float32
|
||||
timer.record("apply float()")
|
||||
else:
|
||||
@@ -396,7 +415,11 @@ def load_model_weights(model, checkpoint_info: CheckpointInfo, state_dict, timer
|
||||
if shared.cmd_opts.upcast_sampling and depth_model:
|
||||
model.depth_model = None
|
||||
|
||||
alphas_cumprod = model.alphas_cumprod
|
||||
model.alphas_cumprod = None
|
||||
model.half()
|
||||
model.alphas_cumprod = alphas_cumprod
|
||||
model.alphas_cumprod_original = alphas_cumprod
|
||||
model.first_stage_model = vae
|
||||
if depth_model:
|
||||
model.depth_model = depth_model
|
||||
@@ -404,6 +427,28 @@ def load_model_weights(model, checkpoint_info: CheckpointInfo, state_dict, timer
|
||||
devices.dtype_unet = torch.float16
|
||||
timer.record("apply half()")
|
||||
|
||||
for module in model.modules():
|
||||
if hasattr(module, 'fp16_weight'):
|
||||
del module.fp16_weight
|
||||
if hasattr(module, 'fp16_bias'):
|
||||
del module.fp16_bias
|
||||
|
||||
if check_fp8(model):
|
||||
devices.fp8 = True
|
||||
first_stage = model.first_stage_model
|
||||
model.first_stage_model = None
|
||||
for module in model.modules():
|
||||
if isinstance(module, (torch.nn.Conv2d, torch.nn.Linear)):
|
||||
if shared.opts.cache_fp16_weight:
|
||||
module.fp16_weight = module.weight.data.clone().cpu().half()
|
||||
if module.bias is not None:
|
||||
module.fp16_bias = module.bias.data.clone().cpu().half()
|
||||
module.to(torch.float8_e4m3fn)
|
||||
model.first_stage_model = first_stage
|
||||
timer.record("apply fp8")
|
||||
else:
|
||||
devices.fp8 = False
|
||||
|
||||
devices.unet_needs_upcast = shared.cmd_opts.upcast_sampling and devices.dtype == torch.float16 and devices.dtype_unet == torch.float16
|
||||
|
||||
model.first_stage_model.to(devices.dtype_vae)
|
||||
@@ -651,6 +696,7 @@ def load_model(checkpoint_info=None, already_loaded_state_dict=None):
|
||||
else:
|
||||
weight_dtype_conversion = {
|
||||
'first_stage_model': None,
|
||||
'alphas_cumprod': None,
|
||||
'': torch.float16,
|
||||
}
|
||||
|
||||
@@ -746,7 +792,7 @@ def reuse_model_from_already_loaded(sd_model, checkpoint_info, timer):
|
||||
return None
|
||||
|
||||
|
||||
def reload_model_weights(sd_model=None, info=None):
|
||||
def reload_model_weights(sd_model=None, info=None, forced_reload=False):
|
||||
checkpoint_info = info or select_checkpoint()
|
||||
|
||||
timer = Timer()
|
||||
@@ -758,11 +804,14 @@ def reload_model_weights(sd_model=None, info=None):
|
||||
current_checkpoint_info = None
|
||||
else:
|
||||
current_checkpoint_info = sd_model.sd_checkpoint_info
|
||||
if sd_model.sd_model_checkpoint == checkpoint_info.filename:
|
||||
if check_fp8(sd_model) != devices.fp8:
|
||||
# load from state dict again to prevent extra numerical errors
|
||||
forced_reload = True
|
||||
elif sd_model.sd_model_checkpoint == checkpoint_info.filename and not forced_reload:
|
||||
return sd_model
|
||||
|
||||
sd_model = reuse_model_from_already_loaded(sd_model, checkpoint_info, timer)
|
||||
if sd_model is not None and sd_model.sd_checkpoint_info.filename == checkpoint_info.filename:
|
||||
if not forced_reload and sd_model is not None and sd_model.sd_checkpoint_info.filename == checkpoint_info.filename:
|
||||
return sd_model
|
||||
|
||||
if sd_model is not None:
|
||||
@@ -793,13 +842,13 @@ def reload_model_weights(sd_model=None, info=None):
|
||||
sd_hijack.model_hijack.hijack(sd_model)
|
||||
timer.record("hijack")
|
||||
|
||||
script_callbacks.model_loaded_callback(sd_model)
|
||||
timer.record("script callbacks")
|
||||
|
||||
if not sd_model.lowvram:
|
||||
sd_model.to(devices.device)
|
||||
timer.record("move model to device")
|
||||
|
||||
script_callbacks.model_loaded_callback(sd_model)
|
||||
timer.record("script callbacks")
|
||||
|
||||
print(f"Weights loaded in {timer.summary()}.")
|
||||
|
||||
model_data.set_sd_model(sd_model)
|
||||
|
||||
@@ -15,6 +15,7 @@ config_sd2v = os.path.join(sd_repo_configs_path, "v2-inference-v.yaml")
|
||||
config_sd2_inpainting = os.path.join(sd_repo_configs_path, "v2-inpainting-inference.yaml")
|
||||
config_sdxl = os.path.join(sd_xl_repo_configs_path, "sd_xl_base.yaml")
|
||||
config_sdxl_refiner = os.path.join(sd_xl_repo_configs_path, "sd_xl_refiner.yaml")
|
||||
config_sdxl_inpainting = os.path.join(sd_configs_path, "sd_xl_inpaint.yaml")
|
||||
config_depth_model = os.path.join(sd_repo_configs_path, "v2-midas-inference.yaml")
|
||||
config_unclip = os.path.join(sd_repo_configs_path, "v2-1-stable-unclip-l-inference.yaml")
|
||||
config_unopenclip = os.path.join(sd_repo_configs_path, "v2-1-stable-unclip-h-inference.yaml")
|
||||
@@ -71,7 +72,10 @@ def guess_model_config_from_state_dict(sd, filename):
|
||||
sd2_variations_weight = sd.get('embedder.model.ln_final.weight', None)
|
||||
|
||||
if sd.get('conditioner.embedders.1.model.ln_final.weight', None) is not None:
|
||||
return config_sdxl
|
||||
if diffusion_model_input.shape[1] == 9:
|
||||
return config_sdxl_inpainting
|
||||
else:
|
||||
return config_sdxl
|
||||
if sd.get('conditioner.embedders.0.model.ln_final.weight', None) is not None:
|
||||
return config_sdxl_refiner
|
||||
elif sd.get('depth_model.model.pretrained.act_postprocess3.0.project.0.bias', None) is not None:
|
||||
|
||||
@@ -6,6 +6,7 @@ import sgm.models.diffusion
|
||||
import sgm.modules.diffusionmodules.denoiser_scaling
|
||||
import sgm.modules.diffusionmodules.discretizer
|
||||
from modules import devices, shared, prompt_parser
|
||||
from modules import torch_utils
|
||||
|
||||
|
||||
def get_learned_conditioning(self: sgm.models.diffusion.DiffusionEngine, batch: prompt_parser.SdConditioning | list[str]):
|
||||
@@ -34,6 +35,12 @@ def get_learned_conditioning(self: sgm.models.diffusion.DiffusionEngine, batch:
|
||||
|
||||
|
||||
def apply_model(self: sgm.models.diffusion.DiffusionEngine, x, t, cond):
|
||||
sd = self.model.state_dict()
|
||||
diffusion_model_input = sd.get('diffusion_model.input_blocks.0.0.weight', None)
|
||||
if diffusion_model_input is not None:
|
||||
if diffusion_model_input.shape[1] == 9:
|
||||
x = torch.cat([x] + cond['c_concat'], dim=1)
|
||||
|
||||
return self.model(x, t, cond)
|
||||
|
||||
|
||||
@@ -84,7 +91,7 @@ sgm.modules.GeneralConditioner.get_target_prompt_token_count = get_target_prompt
|
||||
def extend_sdxl(model):
|
||||
"""this adds a bunch of parameters to make SDXL model look a bit more like SD1.5 to the rest of the codebase."""
|
||||
|
||||
dtype = next(model.model.diffusion_model.parameters()).dtype
|
||||
dtype = torch_utils.get_param(model.model.diffusion_model).dtype
|
||||
model.model.diffusion_model.dtype = dtype
|
||||
model.model.conditioning_key = 'crossattn'
|
||||
model.cond_stage_key = 'txt'
|
||||
@@ -93,7 +100,7 @@ def extend_sdxl(model):
|
||||
model.parameterization = "v" if isinstance(model.denoiser.scaling, sgm.modules.diffusionmodules.denoiser_scaling.VScaling) else "eps"
|
||||
|
||||
discretization = sgm.modules.diffusionmodules.discretizer.LegacyDDPMDiscretization()
|
||||
model.alphas_cumprod = torch.asarray(discretization.alphas_cumprod, device=devices.device, dtype=dtype)
|
||||
model.alphas_cumprod = torch.asarray(discretization.alphas_cumprod, device=devices.device, dtype=torch.float32)
|
||||
|
||||
model.conditioner.wrapped = torch.nn.Module()
|
||||
|
||||
|
||||
@@ -1,4 +1,4 @@
|
||||
from modules import sd_samplers_kdiffusion, sd_samplers_timesteps, shared
|
||||
from modules import sd_samplers_kdiffusion, sd_samplers_timesteps, sd_samplers_lcm, shared
|
||||
|
||||
# imports for functions that previously were here and are used by other modules
|
||||
from modules.sd_samplers_common import samples_to_image_grid, sample_to_image # noqa: F401
|
||||
@@ -6,6 +6,7 @@ from modules.sd_samplers_common import samples_to_image_grid, sample_to_image #
|
||||
all_samplers = [
|
||||
*sd_samplers_kdiffusion.samplers_data_k_diffusion,
|
||||
*sd_samplers_timesteps.samplers_data_timesteps,
|
||||
*sd_samplers_lcm.samplers_data_lcm,
|
||||
]
|
||||
all_samplers_map = {x.name: x for x in all_samplers}
|
||||
|
||||
|
||||
@@ -53,9 +53,13 @@ class CFGDenoiser(torch.nn.Module):
|
||||
self.step = 0
|
||||
self.image_cfg_scale = None
|
||||
self.padded_cond_uncond = False
|
||||
self.padded_cond_uncond_v0 = False
|
||||
self.sampler = sampler
|
||||
self.model_wrap = None
|
||||
self.p = None
|
||||
|
||||
# NOTE: masking before denoising can cause the original latents to be oversmoothed
|
||||
# as the original latents do not have noise
|
||||
self.mask_before_denoising = False
|
||||
|
||||
@property
|
||||
@@ -88,6 +92,62 @@ class CFGDenoiser(torch.nn.Module):
|
||||
self.sampler.sampler_extra_args['cond'] = c
|
||||
self.sampler.sampler_extra_args['uncond'] = uc
|
||||
|
||||
def pad_cond_uncond(self, cond, uncond):
|
||||
empty = shared.sd_model.cond_stage_model_empty_prompt
|
||||
num_repeats = (cond.shape[1] - uncond.shape[1]) // empty.shape[1]
|
||||
|
||||
if num_repeats < 0:
|
||||
cond = pad_cond(cond, -num_repeats, empty)
|
||||
self.padded_cond_uncond = True
|
||||
elif num_repeats > 0:
|
||||
uncond = pad_cond(uncond, num_repeats, empty)
|
||||
self.padded_cond_uncond = True
|
||||
|
||||
return cond, uncond
|
||||
|
||||
def pad_cond_uncond_v0(self, cond, uncond):
|
||||
"""
|
||||
Pads the 'uncond' tensor to match the shape of the 'cond' tensor.
|
||||
|
||||
If 'uncond' is a dictionary, it is assumed that the 'crossattn' key holds the tensor to be padded.
|
||||
If 'uncond' is a tensor, it is padded directly.
|
||||
|
||||
If the number of columns in 'uncond' is less than the number of columns in 'cond', the last column of 'uncond'
|
||||
is repeated to match the number of columns in 'cond'.
|
||||
|
||||
If the number of columns in 'uncond' is greater than the number of columns in 'cond', 'uncond' is truncated
|
||||
to match the number of columns in 'cond'.
|
||||
|
||||
Args:
|
||||
cond (torch.Tensor or DictWithShape): The condition tensor to match the shape of 'uncond'.
|
||||
uncond (torch.Tensor or DictWithShape): The tensor to be padded, or a dictionary containing the tensor to be padded.
|
||||
|
||||
Returns:
|
||||
tuple: A tuple containing the 'cond' tensor and the padded 'uncond' tensor.
|
||||
|
||||
Note:
|
||||
This is the padding that was always used in DDIM before version 1.6.0
|
||||
"""
|
||||
|
||||
is_dict_cond = isinstance(uncond, dict)
|
||||
uncond_vec = uncond['crossattn'] if is_dict_cond else uncond
|
||||
|
||||
if uncond_vec.shape[1] < cond.shape[1]:
|
||||
last_vector = uncond_vec[:, -1:]
|
||||
last_vector_repeated = last_vector.repeat([1, cond.shape[1] - uncond_vec.shape[1], 1])
|
||||
uncond_vec = torch.hstack([uncond_vec, last_vector_repeated])
|
||||
self.padded_cond_uncond_v0 = True
|
||||
elif uncond_vec.shape[1] > cond.shape[1]:
|
||||
uncond_vec = uncond_vec[:, :cond.shape[1]]
|
||||
self.padded_cond_uncond_v0 = True
|
||||
|
||||
if is_dict_cond:
|
||||
uncond['crossattn'] = uncond_vec
|
||||
else:
|
||||
uncond = uncond_vec
|
||||
|
||||
return cond, uncond
|
||||
|
||||
def forward(self, x, sigma, uncond, cond, cond_scale, s_min_uncond, image_cond):
|
||||
if state.interrupted or state.skipped:
|
||||
raise sd_samplers_common.InterruptedException
|
||||
@@ -105,8 +165,21 @@ class CFGDenoiser(torch.nn.Module):
|
||||
|
||||
assert not is_edit_model or all(len(conds) == 1 for conds in conds_list), "AND is not supported for InstructPix2Pix checkpoint (unless using Image CFG scale = 1.0)"
|
||||
|
||||
# If we use masks, blending between the denoised and original latent images occurs here.
|
||||
def apply_blend(current_latent):
|
||||
blended_latent = current_latent * self.nmask + self.init_latent * self.mask
|
||||
|
||||
if self.p.scripts is not None:
|
||||
from modules import scripts
|
||||
mba = scripts.MaskBlendArgs(current_latent, self.nmask, self.init_latent, self.mask, blended_latent, denoiser=self, sigma=sigma)
|
||||
self.p.scripts.on_mask_blend(self.p, mba)
|
||||
blended_latent = mba.blended_latent
|
||||
|
||||
return blended_latent
|
||||
|
||||
# Blend in the original latents (before)
|
||||
if self.mask_before_denoising and self.mask is not None:
|
||||
x = self.init_latent * self.mask + self.nmask * x
|
||||
x = apply_blend(x)
|
||||
|
||||
batch_size = len(conds_list)
|
||||
repeats = [len(conds_list[i]) for i in range(batch_size)]
|
||||
@@ -130,7 +203,7 @@ class CFGDenoiser(torch.nn.Module):
|
||||
sigma_in = torch.cat([torch.stack([sigma[i] for _ in range(n)]) for i, n in enumerate(repeats)] + [sigma] + [sigma])
|
||||
image_cond_in = torch.cat([torch.stack([image_cond[i] for _ in range(n)]) for i, n in enumerate(repeats)] + [image_uncond] + [torch.zeros_like(self.init_latent)])
|
||||
|
||||
denoiser_params = CFGDenoiserParams(x_in, image_cond_in, sigma_in, state.sampling_step, state.sampling_steps, tensor, uncond)
|
||||
denoiser_params = CFGDenoiserParams(x_in, image_cond_in, sigma_in, state.sampling_step, state.sampling_steps, tensor, uncond, self)
|
||||
cfg_denoiser_callback(denoiser_params)
|
||||
x_in = denoiser_params.x
|
||||
image_cond_in = denoiser_params.image_cond
|
||||
@@ -146,16 +219,11 @@ class CFGDenoiser(torch.nn.Module):
|
||||
sigma_in = sigma_in[:-batch_size]
|
||||
|
||||
self.padded_cond_uncond = False
|
||||
self.padded_cond_uncond_v0 = False
|
||||
if shared.opts.pad_cond_uncond and tensor.shape[1] != uncond.shape[1]:
|
||||
empty = shared.sd_model.cond_stage_model_empty_prompt
|
||||
num_repeats = (tensor.shape[1] - uncond.shape[1]) // empty.shape[1]
|
||||
|
||||
if num_repeats < 0:
|
||||
tensor = pad_cond(tensor, -num_repeats, empty)
|
||||
self.padded_cond_uncond = True
|
||||
elif num_repeats > 0:
|
||||
uncond = pad_cond(uncond, num_repeats, empty)
|
||||
self.padded_cond_uncond = True
|
||||
tensor, uncond = self.pad_cond_uncond(tensor, uncond)
|
||||
elif shared.opts.pad_cond_uncond_v0 and tensor.shape[1] != uncond.shape[1]:
|
||||
tensor, uncond = self.pad_cond_uncond_v0(tensor, uncond)
|
||||
|
||||
if tensor.shape[1] == uncond.shape[1] or skip_uncond:
|
||||
if is_edit_model:
|
||||
@@ -207,8 +275,9 @@ class CFGDenoiser(torch.nn.Module):
|
||||
else:
|
||||
denoised = self.combine_denoised(x_out, conds_list, uncond, cond_scale)
|
||||
|
||||
# Blend in the original latents (after)
|
||||
if not self.mask_before_denoising and self.mask is not None:
|
||||
denoised = self.init_latent * self.mask + self.nmask * denoised
|
||||
denoised = apply_blend(denoised)
|
||||
|
||||
self.sampler.last_latent = self.get_pred_x0(torch.cat([x_in[i:i + 1] for i in denoised_image_indexes]), torch.cat([x_out[i:i + 1] for i in denoised_image_indexes]), sigma)
|
||||
|
||||
|
||||
@@ -335,3 +335,10 @@ class Sampler:
|
||||
|
||||
def sample_img2img(self, p, x, noise, conditioning, unconditional_conditioning, steps=None, image_conditioning=None):
|
||||
raise NotImplementedError()
|
||||
|
||||
def add_infotext(self, p):
|
||||
if self.model_wrap_cfg.padded_cond_uncond:
|
||||
p.extra_generation_params["Pad conds"] = True
|
||||
|
||||
if self.model_wrap_cfg.padded_cond_uncond_v0:
|
||||
p.extra_generation_params["Pad conds v0"] = True
|
||||
|
||||
@@ -187,8 +187,7 @@ class KDiffusionSampler(sd_samplers_common.Sampler):
|
||||
|
||||
samples = self.launch_sampling(t_enc + 1, lambda: self.func(self.model_wrap_cfg, xi, extra_args=self.sampler_extra_args, disable=False, callback=self.callback_state, **extra_params_kwargs))
|
||||
|
||||
if self.model_wrap_cfg.padded_cond_uncond:
|
||||
p.extra_generation_params["Pad conds"] = True
|
||||
self.add_infotext(p)
|
||||
|
||||
return samples
|
||||
|
||||
@@ -234,8 +233,7 @@ class KDiffusionSampler(sd_samplers_common.Sampler):
|
||||
|
||||
samples = self.launch_sampling(steps, lambda: self.func(self.model_wrap_cfg, x, extra_args=self.sampler_extra_args, disable=False, callback=self.callback_state, **extra_params_kwargs))
|
||||
|
||||
if self.model_wrap_cfg.padded_cond_uncond:
|
||||
p.extra_generation_params["Pad conds"] = True
|
||||
self.add_infotext(p)
|
||||
|
||||
return samples
|
||||
|
||||
|
||||
104
modules/sd_samplers_lcm.py
Normal file
104
modules/sd_samplers_lcm.py
Normal file
@@ -0,0 +1,104 @@
|
||||
import torch
|
||||
|
||||
from k_diffusion import utils, sampling
|
||||
from k_diffusion.external import DiscreteEpsDDPMDenoiser
|
||||
from k_diffusion.sampling import default_noise_sampler, trange
|
||||
|
||||
from modules import shared, sd_samplers_cfg_denoiser, sd_samplers_kdiffusion, sd_samplers_common
|
||||
|
||||
|
||||
class LCMCompVisDenoiser(DiscreteEpsDDPMDenoiser):
|
||||
def __init__(self, model):
|
||||
timesteps = 1000
|
||||
original_timesteps = 50 # LCM Original Timesteps (default=50, for current version of LCM)
|
||||
self.skip_steps = timesteps // original_timesteps
|
||||
|
||||
alphas_cumprod_valid = torch.zeros((original_timesteps), dtype=torch.float32, device=model.device)
|
||||
for x in range(original_timesteps):
|
||||
alphas_cumprod_valid[original_timesteps - 1 - x] = model.alphas_cumprod[timesteps - 1 - x * self.skip_steps]
|
||||
|
||||
super().__init__(model, alphas_cumprod_valid, quantize=None)
|
||||
|
||||
|
||||
def get_sigmas(self, n=None,):
|
||||
if n is None:
|
||||
return sampling.append_zero(self.sigmas.flip(0))
|
||||
|
||||
start = self.sigma_to_t(self.sigma_max)
|
||||
end = self.sigma_to_t(self.sigma_min)
|
||||
|
||||
t = torch.linspace(start, end, n, device=shared.sd_model.device)
|
||||
|
||||
return sampling.append_zero(self.t_to_sigma(t))
|
||||
|
||||
|
||||
def sigma_to_t(self, sigma, quantize=None):
|
||||
log_sigma = sigma.log()
|
||||
dists = log_sigma - self.log_sigmas[:, None]
|
||||
return dists.abs().argmin(dim=0).view(sigma.shape) * self.skip_steps + (self.skip_steps - 1)
|
||||
|
||||
|
||||
def t_to_sigma(self, timestep):
|
||||
t = torch.clamp(((timestep - (self.skip_steps - 1)) / self.skip_steps).float(), min=0, max=(len(self.sigmas) - 1))
|
||||
return super().t_to_sigma(t)
|
||||
|
||||
|
||||
def get_eps(self, *args, **kwargs):
|
||||
return self.inner_model.apply_model(*args, **kwargs)
|
||||
|
||||
|
||||
def get_scaled_out(self, sigma, output, input):
|
||||
sigma_data = 0.5
|
||||
scaled_timestep = utils.append_dims(self.sigma_to_t(sigma), output.ndim) * 10.0
|
||||
|
||||
c_skip = sigma_data**2 / (scaled_timestep**2 + sigma_data**2)
|
||||
c_out = scaled_timestep / (scaled_timestep**2 + sigma_data**2) ** 0.5
|
||||
|
||||
return c_out * output + c_skip * input
|
||||
|
||||
|
||||
def forward(self, input, sigma, **kwargs):
|
||||
c_out, c_in = [utils.append_dims(x, input.ndim) for x in self.get_scalings(sigma)]
|
||||
eps = self.get_eps(input * c_in, self.sigma_to_t(sigma), **kwargs)
|
||||
return self.get_scaled_out(sigma, input + eps * c_out, input)
|
||||
|
||||
|
||||
def sample_lcm(model, x, sigmas, extra_args=None, callback=None, disable=None, noise_sampler=None):
|
||||
extra_args = {} if extra_args is None else extra_args
|
||||
noise_sampler = default_noise_sampler(x) if noise_sampler is None else noise_sampler
|
||||
s_in = x.new_ones([x.shape[0]])
|
||||
|
||||
for i in trange(len(sigmas) - 1, disable=disable):
|
||||
denoised = model(x, sigmas[i] * s_in, **extra_args)
|
||||
|
||||
if callback is not None:
|
||||
callback({'x': x, 'i': i, 'sigma': sigmas[i], 'sigma_hat': sigmas[i], 'denoised': denoised})
|
||||
|
||||
x = denoised
|
||||
if sigmas[i + 1] > 0:
|
||||
x += sigmas[i + 1] * noise_sampler(sigmas[i], sigmas[i + 1])
|
||||
return x
|
||||
|
||||
|
||||
class CFGDenoiserLCM(sd_samplers_cfg_denoiser.CFGDenoiser):
|
||||
@property
|
||||
def inner_model(self):
|
||||
if self.model_wrap is None:
|
||||
denoiser = LCMCompVisDenoiser
|
||||
self.model_wrap = denoiser(shared.sd_model)
|
||||
|
||||
return self.model_wrap
|
||||
|
||||
|
||||
class LCMSampler(sd_samplers_kdiffusion.KDiffusionSampler):
|
||||
def __init__(self, funcname, sd_model, options=None):
|
||||
super().__init__(funcname, sd_model, options)
|
||||
self.model_wrap_cfg = CFGDenoiserLCM(self)
|
||||
self.model_wrap = self.model_wrap_cfg.inner_model
|
||||
|
||||
|
||||
samplers_lcm = [('LCM', sample_lcm, ['k_lcm'], {})]
|
||||
samplers_data_lcm = [
|
||||
sd_samplers_common.SamplerData(label, lambda model, funcname=funcname: LCMSampler(funcname, model), aliases, options)
|
||||
for label, funcname, aliases, options in samplers_lcm
|
||||
]
|
||||
@@ -36,7 +36,7 @@ class CompVisTimestepsVDenoiser(torch.nn.Module):
|
||||
self.inner_model = model
|
||||
|
||||
def predict_eps_from_z_and_v(self, x_t, t, v):
|
||||
return self.inner_model.sqrt_alphas_cumprod[t.to(torch.int), None, None, None] * v + self.inner_model.sqrt_one_minus_alphas_cumprod[t.to(torch.int), None, None, None] * x_t
|
||||
return torch.sqrt(self.inner_model.alphas_cumprod)[t.to(torch.int), None, None, None] * v + torch.sqrt(1 - self.inner_model.alphas_cumprod)[t.to(torch.int), None, None, None] * x_t
|
||||
|
||||
def forward(self, input, timesteps, **kwargs):
|
||||
model_output = self.inner_model.apply_model(input, timesteps, **kwargs)
|
||||
@@ -80,6 +80,7 @@ class CompVisSampler(sd_samplers_common.Sampler):
|
||||
self.eta_default = 0.0
|
||||
|
||||
self.model_wrap_cfg = CFGDenoiserTimesteps(self)
|
||||
self.model_wrap = self.model_wrap_cfg.inner_model
|
||||
|
||||
def get_timesteps(self, p, steps):
|
||||
discard_next_to_last_sigma = self.config is not None and self.config.options.get('discard_next_to_last_sigma', False)
|
||||
@@ -132,8 +133,7 @@ class CompVisSampler(sd_samplers_common.Sampler):
|
||||
|
||||
samples = self.launch_sampling(t_enc + 1, lambda: self.func(self.model_wrap_cfg, xi, extra_args=self.sampler_extra_args, disable=False, callback=self.callback_state, **extra_params_kwargs))
|
||||
|
||||
if self.model_wrap_cfg.padded_cond_uncond:
|
||||
p.extra_generation_params["Pad conds"] = True
|
||||
self.add_infotext(p)
|
||||
|
||||
return samples
|
||||
|
||||
@@ -157,8 +157,7 @@ class CompVisSampler(sd_samplers_common.Sampler):
|
||||
}
|
||||
samples = self.launch_sampling(steps, lambda: self.func(self.model_wrap_cfg, x, extra_args=self.sampler_extra_args, disable=False, callback=self.callback_state, **extra_params_kwargs))
|
||||
|
||||
if self.model_wrap_cfg.padded_cond_uncond:
|
||||
p.extra_generation_params["Pad conds"] = True
|
||||
self.add_infotext(p)
|
||||
|
||||
return samples
|
||||
|
||||
|
||||
@@ -273,10 +273,11 @@ def reload_vae_weights(sd_model=None, vae_file=unspecified):
|
||||
load_vae(sd_model, vae_file, vae_source)
|
||||
|
||||
sd_hijack.model_hijack.hijack(sd_model)
|
||||
script_callbacks.model_loaded_callback(sd_model)
|
||||
|
||||
if not sd_model.lowvram:
|
||||
sd_model.to(devices.device)
|
||||
|
||||
script_callbacks.model_loaded_callback(sd_model)
|
||||
|
||||
print("VAE weights loaded.")
|
||||
return sd_model
|
||||
|
||||
@@ -1,3 +1,4 @@
|
||||
import os
|
||||
import sys
|
||||
|
||||
import gradio as gr
|
||||
@@ -11,7 +12,7 @@ parser = shared_cmd_options.parser
|
||||
|
||||
batch_cond_uncond = True # old field, unused now in favor of shared.opts.batch_cond_uncond
|
||||
parallel_processing_allowed = True
|
||||
styles_filename = cmd_opts.styles_file
|
||||
styles_filename = cmd_opts.styles_file = cmd_opts.styles_file if len(cmd_opts.styles_file) > 0 else [os.path.join(data_path, 'styles.csv')]
|
||||
config_filename = cmd_opts.ui_settings_file
|
||||
hide_dirs = {"visible": not cmd_opts.hide_ui_dir_config}
|
||||
|
||||
|
||||
@@ -65,3 +65,7 @@ def reload_gradio_theme(theme_name=None):
|
||||
except Exception as e:
|
||||
errors.display(e, "changing gradio theme")
|
||||
shared.gradio_theme = gr.themes.Default(**default_theme_args)
|
||||
|
||||
# append additional values gradio_theme
|
||||
shared.gradio_theme.sd_webui_modal_lightbox_toolbar_opacity = shared.opts.sd_webui_modal_lightbox_toolbar_opacity
|
||||
shared.gradio_theme.sd_webui_modal_lightbox_icon_opacity = shared.opts.sd_webui_modal_lightbox_icon_opacity
|
||||
|
||||
@@ -18,8 +18,10 @@ def initialize():
|
||||
shared.options_templates = shared_options.options_templates
|
||||
shared.opts = options.Options(shared_options.options_templates, shared_options.restricted_opts)
|
||||
shared.restricted_opts = shared_options.restricted_opts
|
||||
if os.path.exists(shared.config_filename):
|
||||
try:
|
||||
shared.opts.load(shared.config_filename)
|
||||
except FileNotFoundError:
|
||||
pass
|
||||
|
||||
from modules import devices
|
||||
devices.device, devices.device_interrogate, devices.device_gfpgan, devices.device_esrgan, devices.device_codeformer = \
|
||||
@@ -27,6 +29,7 @@ def initialize():
|
||||
|
||||
devices.dtype = torch.float32 if cmd_opts.no_half else torch.float16
|
||||
devices.dtype_vae = torch.float32 if cmd_opts.no_half or cmd_opts.no_half_vae else torch.float16
|
||||
devices.dtype_inference = torch.float32 if cmd_opts.precision == 'full' else devices.dtype
|
||||
|
||||
shared.device = devices.device
|
||||
shared.weight_load_location = None if cmd_opts.lowram else "cpu"
|
||||
|
||||
@@ -8,6 +8,11 @@ def realesrgan_models_names():
|
||||
return [x.name for x in modules.realesrgan_model.get_realesrgan_models(None)]
|
||||
|
||||
|
||||
def dat_models_names():
|
||||
import modules.dat_model
|
||||
return [x.name for x in modules.dat_model.get_dat_models(None)]
|
||||
|
||||
|
||||
def postprocessing_scripts():
|
||||
import modules.scripts
|
||||
|
||||
@@ -67,14 +72,14 @@ def reload_hypernetworks():
|
||||
|
||||
|
||||
def get_infotext_names():
|
||||
from modules import generation_parameters_copypaste, shared
|
||||
from modules import infotext_utils, shared
|
||||
res = {}
|
||||
|
||||
for info in shared.opts.data_labels.values():
|
||||
if info.infotext:
|
||||
res[info.infotext] = 1
|
||||
|
||||
for tab_data in generation_parameters_copypaste.paste_fields.values():
|
||||
for tab_data in infotext_utils.paste_fields.values():
|
||||
for _, name in tab_data.get("fields") or []:
|
||||
if isinstance(name, str):
|
||||
res[name] = 1
|
||||
|
||||
@@ -1,7 +1,8 @@
|
||||
import os
|
||||
import gradio as gr
|
||||
|
||||
from modules import localization, ui_components, shared_items, shared, interrogate, shared_gradio_themes
|
||||
from modules.paths_internal import models_path, script_path, data_path, sd_configs_path, sd_default_config, sd_model_file, default_sd_model_file, extensions_dir, extensions_builtin_dir # noqa: F401
|
||||
from modules import localization, ui_components, shared_items, shared, interrogate, shared_gradio_themes, util
|
||||
from modules.paths_internal import models_path, script_path, data_path, sd_configs_path, sd_default_config, sd_model_file, default_sd_model_file, extensions_dir, extensions_builtin_dir, default_output_dir # noqa: F401
|
||||
from modules.shared_cmd_options import cmd_opts
|
||||
from modules.options import options_section, OptionInfo, OptionHTML, categories
|
||||
|
||||
@@ -74,14 +75,14 @@ options_templates.update(options_section(('saving-images', "Saving images/grids"
|
||||
|
||||
options_templates.update(options_section(('saving-paths', "Paths for saving", "saving"), {
|
||||
"outdir_samples": OptionInfo("", "Output directory for images; if empty, defaults to three directories below", component_args=hide_dirs),
|
||||
"outdir_txt2img_samples": OptionInfo("outputs/txt2img-images", 'Output directory for txt2img images', component_args=hide_dirs),
|
||||
"outdir_img2img_samples": OptionInfo("outputs/img2img-images", 'Output directory for img2img images', component_args=hide_dirs),
|
||||
"outdir_extras_samples": OptionInfo("outputs/extras-images", 'Output directory for images from extras tab', component_args=hide_dirs),
|
||||
"outdir_txt2img_samples": OptionInfo(util.truncate_path(os.path.join(default_output_dir, 'txt2img-images')), 'Output directory for txt2img images', component_args=hide_dirs),
|
||||
"outdir_img2img_samples": OptionInfo(util.truncate_path(os.path.join(default_output_dir, 'img2img-images')), 'Output directory for img2img images', component_args=hide_dirs),
|
||||
"outdir_extras_samples": OptionInfo(util.truncate_path(os.path.join(default_output_dir, 'extras-images')), 'Output directory for images from extras tab', component_args=hide_dirs),
|
||||
"outdir_grids": OptionInfo("", "Output directory for grids; if empty, defaults to two directories below", component_args=hide_dirs),
|
||||
"outdir_txt2img_grids": OptionInfo("outputs/txt2img-grids", 'Output directory for txt2img grids', component_args=hide_dirs),
|
||||
"outdir_img2img_grids": OptionInfo("outputs/img2img-grids", 'Output directory for img2img grids', component_args=hide_dirs),
|
||||
"outdir_save": OptionInfo("log/images", "Directory for saving images using the Save button", component_args=hide_dirs),
|
||||
"outdir_init_images": OptionInfo("outputs/init-images", "Directory for saving init images when using img2img", component_args=hide_dirs),
|
||||
"outdir_txt2img_grids": OptionInfo(util.truncate_path(os.path.join(default_output_dir, 'txt2img-grids')), 'Output directory for txt2img grids', component_args=hide_dirs),
|
||||
"outdir_img2img_grids": OptionInfo(util.truncate_path(os.path.join(default_output_dir, 'img2img-grids')), 'Output directory for img2img grids', component_args=hide_dirs),
|
||||
"outdir_save": OptionInfo(util.truncate_path(os.path.join(data_path, 'log', 'images')), "Directory for saving images using the Save button", component_args=hide_dirs),
|
||||
"outdir_init_images": OptionInfo(util.truncate_path(os.path.join(default_output_dir, 'init-images')), "Directory for saving init images when using img2img", component_args=hide_dirs),
|
||||
}))
|
||||
|
||||
options_templates.update(options_section(('saving-to-dirs', "Saving to a directory", "saving"), {
|
||||
@@ -96,6 +97,9 @@ options_templates.update(options_section(('upscaling', "Upscaling", "postprocess
|
||||
"ESRGAN_tile": OptionInfo(192, "Tile size for ESRGAN upscalers.", gr.Slider, {"minimum": 0, "maximum": 512, "step": 16}).info("0 = no tiling"),
|
||||
"ESRGAN_tile_overlap": OptionInfo(8, "Tile overlap for ESRGAN upscalers.", gr.Slider, {"minimum": 0, "maximum": 48, "step": 1}).info("Low values = visible seam"),
|
||||
"realesrgan_enabled_models": OptionInfo(["R-ESRGAN 4x+", "R-ESRGAN 4x+ Anime6B"], "Select which Real-ESRGAN models to show in the web UI.", gr.CheckboxGroup, lambda: {"choices": shared_items.realesrgan_models_names()}),
|
||||
"dat_enabled_models": OptionInfo(["DAT x2", "DAT x3", "DAT x4"], "Select which DAT models to show in the web UI.", gr.CheckboxGroup, lambda: {"choices": shared_items.dat_models_names()}),
|
||||
"DAT_tile": OptionInfo(192, "Tile size for DAT upscalers.", gr.Slider, {"minimum": 0, "maximum": 512, "step": 16}).info("0 = no tiling"),
|
||||
"DAT_tile_overlap": OptionInfo(8, "Tile overlap for DAT upscalers.", gr.Slider, {"minimum": 0, "maximum": 48, "step": 1}).info("Low values = visible seam"),
|
||||
"upscaler_for_img2img": OptionInfo(None, "Upscaler for img2img", gr.Dropdown, lambda: {"choices": [x.name for x in shared.sd_upscalers]}),
|
||||
}))
|
||||
|
||||
@@ -114,6 +118,7 @@ options_templates.update(options_section(('system', "System", "system"), {
|
||||
"memmon_poll_rate": OptionInfo(8, "VRAM usage polls per second during generation.", gr.Slider, {"minimum": 0, "maximum": 40, "step": 1}).info("0 = disable"),
|
||||
"samples_log_stdout": OptionInfo(False, "Always print all generation info to standard output"),
|
||||
"multiple_tqdm": OptionInfo(True, "Add a second progress bar to the console that shows progress for an entire job."),
|
||||
"enable_upscale_progressbar": OptionInfo(True, "Show a progress bar in the console for tiled upscaling."),
|
||||
"print_hypernet_extra": OptionInfo(False, "Print extra hypernetwork information to console."),
|
||||
"list_hidden_files": OptionInfo(True, "Load models/files in hidden directories").info("directory is hidden if its name starts with \".\""),
|
||||
"disable_mmap_load_safetensors": OptionInfo(False, "Disable memmapping for loading .safetensors files.").info("fixes very slow loading speed in some cases"),
|
||||
@@ -176,6 +181,7 @@ For img2img, VAE is used to process user's input image before the sampling, and
|
||||
"sd_vae_checkpoint_cache": OptionInfo(0, "VAE Checkpoints to cache in RAM", gr.Slider, {"minimum": 0, "maximum": 10, "step": 1}),
|
||||
"sd_vae": OptionInfo("Automatic", "SD VAE", gr.Dropdown, lambda: {"choices": shared_items.sd_vae_items()}, refresh=shared_items.refresh_vae_list, infotext='VAE').info("choose VAE model: Automatic = use one with same filename as checkpoint; None = use VAE from checkpoint"),
|
||||
"sd_vae_overrides_per_model_preferences": OptionInfo(True, "Selected VAE overrides per-model preferences").info("you can set per-model VAE either by editing user metadata for checkpoints, or by making the VAE have same name as checkpoint"),
|
||||
"auto_vae_precision_bfloat16": OptionInfo(False, "Automatically convert VAE to bfloat16").info("triggers when a tensor with NaNs is produced in VAE; disabling the option in this case will result in a black square image; if enabled, overrides the option below"),
|
||||
"auto_vae_precision": OptionInfo(True, "Automatically revert VAE to 32-bit floats").info("triggers when a tensor with NaNs is produced in VAE; disabling the option in this case will result in a black square image"),
|
||||
"sd_vae_encode_method": OptionInfo("Full", "VAE type for encode", gr.Radio, {"choices": ["Full", "TAESD"]}, infotext='VAE Encoder').info("method to encode image to latent (use in img2img, hires-fix or inpaint mask)"),
|
||||
"sd_vae_decode_method": OptionInfo("Full", "VAE type for decode", gr.Radio, {"choices": ["Full", "TAESD"]}, infotext='VAE Decoder').info("method to decode latent to image"),
|
||||
@@ -195,6 +201,7 @@ options_templates.update(options_section(('img2img', "img2img", "sd"), {
|
||||
"return_mask": OptionInfo(False, "For inpainting, include the greyscale mask in results for web"),
|
||||
"return_mask_composite": OptionInfo(False, "For inpainting, include masked composite in results for web"),
|
||||
"img2img_batch_show_results_limit": OptionInfo(32, "Show the first N batch img2img results in UI", gr.Slider, {"minimum": -1, "maximum": 1000, "step": 1}).info('0: disable, -1: show all images. Too many images can cause lag'),
|
||||
"overlay_inpaint": OptionInfo(True, "Overlay original for inpaint").info("when inpainting, overlay the original image over the areas that weren't inpainted."),
|
||||
}))
|
||||
|
||||
options_templates.update(options_section(('optimizations', "Optimizations", "sd"), {
|
||||
@@ -203,12 +210,16 @@ options_templates.update(options_section(('optimizations', "Optimizations", "sd"
|
||||
"token_merging_ratio": OptionInfo(0.0, "Token merging ratio", gr.Slider, {"minimum": 0.0, "maximum": 0.9, "step": 0.1}, infotext='Token merging ratio').link("PR", "https://github.com/AUTOMATIC1111/stable-diffusion-webui/pull/9256").info("0=disable, higher=faster"),
|
||||
"token_merging_ratio_img2img": OptionInfo(0.0, "Token merging ratio for img2img", gr.Slider, {"minimum": 0.0, "maximum": 0.9, "step": 0.1}).info("only applies if non-zero and overrides above"),
|
||||
"token_merging_ratio_hr": OptionInfo(0.0, "Token merging ratio for high-res pass", gr.Slider, {"minimum": 0.0, "maximum": 0.9, "step": 0.1}, infotext='Token merging ratio hr').info("only applies if non-zero and overrides above"),
|
||||
"pad_cond_uncond": OptionInfo(False, "Pad prompt/negative prompt to be same length", infotext='Pad conds').info("improves performance when prompt and negative prompt have different lengths; changes seeds"),
|
||||
"pad_cond_uncond": OptionInfo(False, "Pad prompt/negative prompt", infotext='Pad conds').info("improves performance when prompt and negative prompt have different lengths; changes seeds"),
|
||||
"pad_cond_uncond_v0": OptionInfo(False, "Pad prompt/negative prompt (v0)", infotext='Pad conds v0').info("alternative implementation for the above; used prior to 1.6.0 for DDIM sampler; ignored if the above is set; changes seeds"),
|
||||
"persistent_cond_cache": OptionInfo(True, "Persistent cond cache").info("do not recalculate conds from prompts if prompts have not changed since previous calculation"),
|
||||
"batch_cond_uncond": OptionInfo(True, "Batch cond/uncond").info("do both conditional and unconditional denoising in one batch; uses a bit more VRAM during sampling, but improves speed; previously this was controlled by --always-batch-cond-uncond comandline argument"),
|
||||
"fp8_storage": OptionInfo("Disable", "FP8 weight", gr.Radio, {"choices": ["Disable", "Enable for SDXL", "Enable"]}).info("Use FP8 to store Linear/Conv layers' weight. Require pytorch>=2.1.0."),
|
||||
"cache_fp16_weight": OptionInfo(False, "Cache FP16 weight for LoRA").info("Cache fp16 weight when enabling FP8, will increase the quality of LoRA. Use more system ram."),
|
||||
}))
|
||||
|
||||
options_templates.update(options_section(('compatibility', "Compatibility", "sd"), {
|
||||
"auto_backcompat": OptionInfo(True, "Automatic backward compatibility").info("automatically enable options for backwards compatibility when importing generation parameters from infotext that has program version."),
|
||||
"use_old_emphasis_implementation": OptionInfo(False, "Use old emphasis implementation. Can be useful to reproduce old seeds."),
|
||||
"use_old_karras_scheduler_sigmas": OptionInfo(False, "Use old karras scheduler sigmas (0.1 to 10)."),
|
||||
"no_dpmpp_sde_batch_determinism": OptionInfo(False, "Do not make DPM++ SDE deterministic across different batch sizes."),
|
||||
@@ -216,6 +227,7 @@ options_templates.update(options_section(('compatibility', "Compatibility", "sd"
|
||||
"dont_fix_second_order_samplers_schedule": OptionInfo(False, "Do not fix prompt schedule for second order samplers."),
|
||||
"hires_fix_use_firstpass_conds": OptionInfo(False, "For hires fix, calculate conds of second pass using extra networks of first pass."),
|
||||
"use_old_scheduling": OptionInfo(False, "Use old prompt editing timelines.", infotext="Old prompt editing timelines").info("For [red:green:N]; old: If N < 1, it's a fraction of steps (and hires fix uses range from 0 to 1), if N >= 1, it's an absolute number of steps; new: If N has a decimal point in it, it's a fraction of steps (and hires fix uses range from 1 to 2), othewrwise it's an absolute number of steps"),
|
||||
"use_downcasted_alpha_bar": OptionInfo(False, "Downcast model alphas_cumprod to fp16 before sampling. For reproducing old seeds.", infotext="Downcast alphas_cumprod")
|
||||
}))
|
||||
|
||||
options_templates.update(options_section(('interrogate', "Interrogate"), {
|
||||
@@ -244,6 +256,7 @@ options_templates.update(options_section(('extra_networks', "Extra Networks", "s
|
||||
"extra_networks_card_show_desc": OptionInfo(True, "Show description on card"),
|
||||
"extra_networks_card_order_field": OptionInfo("Path", "Default order field for Extra Networks cards", gr.Dropdown, {"choices": ['Path', 'Name', 'Date Created', 'Date Modified']}).needs_reload_ui(),
|
||||
"extra_networks_card_order": OptionInfo("Ascending", "Default order for Extra Networks cards", gr.Dropdown, {"choices": ['Ascending', 'Descending']}).needs_reload_ui(),
|
||||
"extra_networks_tree_view_default_enabled": OptionInfo(False, "Enables the Extra Networks directory tree view by default").needs_reload_ui(),
|
||||
"extra_networks_add_text_separator": OptionInfo(" ", "Extra networks separator").info("extra text to add before <...> when adding extra network to prompt"),
|
||||
"ui_extra_networks_tab_reorder": OptionInfo("", "Extra networks tab order").needs_reload_ui(),
|
||||
"textual_inversion_print_at_load": OptionInfo(False, "Print a list of Textual Inversion embeddings when loading model"),
|
||||
@@ -267,6 +280,8 @@ options_templates.update(options_section(('ui_gallery', "Gallery", "ui"), {
|
||||
"js_modal_lightbox_initially_zoomed": OptionInfo(True, "Full page image viewer: show images zoomed in by default"),
|
||||
"js_modal_lightbox_gamepad": OptionInfo(False, "Full page image viewer: navigate with gamepad"),
|
||||
"js_modal_lightbox_gamepad_repeat": OptionInfo(250, "Full page image viewer: gamepad repeat period").info("in milliseconds"),
|
||||
"sd_webui_modal_lightbox_icon_opacity": OptionInfo(1, "Full page image viewer: control icon unfocused opacity", gr.Slider, {"minimum": 0.0, "maximum": 1, "step": 0.01}, onchange=shared.reload_gradio_theme).info('for mouse only').needs_reload_ui(),
|
||||
"sd_webui_modal_lightbox_toolbar_opacity": OptionInfo(0.9, "Full page image viewer: tool bar opacity", gr.Slider, {"minimum": 0.0, "maximum": 1, "step": 0.01}, onchange=shared.reload_gradio_theme).info('for mouse only').needs_reload_ui(),
|
||||
"gallery_height": OptionInfo("", "Gallery height", gr.Textbox).info("can be any valid CSS value, for example 768px or 20em").needs_reload_ui(),
|
||||
}))
|
||||
|
||||
@@ -279,6 +294,7 @@ options_templates.update(options_section(('ui_alternatives', "UI alternatives",
|
||||
"hires_fix_show_prompts": OptionInfo(False, "Hires fix: show hires prompt and negative prompt").needs_reload_ui(),
|
||||
"txt2img_settings_accordion": OptionInfo(False, "Settings in txt2img hidden under Accordion").needs_reload_ui(),
|
||||
"img2img_settings_accordion": OptionInfo(False, "Settings in img2img hidden under Accordion").needs_reload_ui(),
|
||||
"interrupt_after_current": OptionInfo(True, "Don't Interrupt in the middle").info("when using Interrupt button, if generating more than one image, stop after the generation of an image has finished, instead of immediately"),
|
||||
}))
|
||||
|
||||
options_templates.update(options_section(('ui', "User interface", "ui"), {
|
||||
@@ -354,6 +370,7 @@ options_templates.update(options_section(('sampler-params', "Sampler parameters"
|
||||
'uni_pc_skip_type': OptionInfo("time_uniform", "UniPC skip type", gr.Radio, {"choices": ["time_uniform", "time_quadratic", "logSNR"]}, infotext='UniPC skip type'),
|
||||
'uni_pc_order': OptionInfo(3, "UniPC order", gr.Slider, {"minimum": 1, "maximum": 50, "step": 1}, infotext='UniPC order').info("must be < sampling steps"),
|
||||
'uni_pc_lower_order_final': OptionInfo(True, "UniPC lower order final", infotext='UniPC lower order final'),
|
||||
'sd_noise_schedule': OptionInfo("Default", "Noise schedule for sampling", gr.Radio, {"choices": ["Default", "Zero Terminal SNR"]}, infotext="Noise Schedule").info("for use with zero terminal SNR trained models")
|
||||
}))
|
||||
|
||||
options_templates.update(options_section(('postprocessing', "Postprocessing", "postprocessing"), {
|
||||
|
||||
@@ -12,6 +12,7 @@ log = logging.getLogger(__name__)
|
||||
class State:
|
||||
skipped = False
|
||||
interrupted = False
|
||||
stopping_generation = False
|
||||
job = ""
|
||||
job_no = 0
|
||||
job_count = 0
|
||||
@@ -79,6 +80,10 @@ class State:
|
||||
self.interrupted = True
|
||||
log.info("Received interrupt request")
|
||||
|
||||
def stop_generating(self):
|
||||
self.stopping_generation = True
|
||||
log.info("Received stop generating request")
|
||||
|
||||
def nextjob(self):
|
||||
if shared.opts.live_previews_enable and shared.opts.show_progress_every_n_steps == -1:
|
||||
self.do_set_current_image()
|
||||
@@ -91,6 +96,7 @@ class State:
|
||||
obj = {
|
||||
"skipped": self.skipped,
|
||||
"interrupted": self.interrupted,
|
||||
"stopping_generation": self.stopping_generation,
|
||||
"job": self.job,
|
||||
"job_count": self.job_count,
|
||||
"job_timestamp": self.job_timestamp,
|
||||
@@ -114,6 +120,7 @@ class State:
|
||||
self.id_live_preview = 0
|
||||
self.skipped = False
|
||||
self.interrupted = False
|
||||
self.stopping_generation = False
|
||||
self.textinfo = None
|
||||
self.job = job
|
||||
devices.torch_gc()
|
||||
|
||||
@@ -1,16 +1,15 @@
|
||||
from pathlib import Path
|
||||
import csv
|
||||
import fnmatch
|
||||
import os
|
||||
import os.path
|
||||
import typing
|
||||
import shutil
|
||||
|
||||
|
||||
class PromptStyle(typing.NamedTuple):
|
||||
name: str
|
||||
prompt: str
|
||||
negative_prompt: str
|
||||
path: str = None
|
||||
prompt: str | None
|
||||
negative_prompt: str | None
|
||||
path: str | None = None
|
||||
|
||||
|
||||
def merge_prompts(style_prompt: str, prompt: str) -> str:
|
||||
@@ -30,38 +29,29 @@ def apply_styles_to_prompt(prompt, styles):
|
||||
return prompt
|
||||
|
||||
|
||||
def unwrap_style_text_from_prompt(style_text, prompt):
|
||||
"""
|
||||
Checks the prompt to see if the style text is wrapped around it. If so,
|
||||
returns True plus the prompt text without the style text. Otherwise, returns
|
||||
False with the original prompt.
|
||||
def extract_style_text_from_prompt(style_text, prompt):
|
||||
"""This function extracts the text from a given prompt based on a provided style text. It checks if the style text contains the placeholder {prompt} or if it appears at the end of the prompt. If a match is found, it returns True along with the extracted text. Otherwise, it returns False and the original prompt.
|
||||
|
||||
Note that the "cleaned" version of the style text is only used for matching
|
||||
purposes here. It isn't returned; the original style text is not modified.
|
||||
extract_style_text_from_prompt("masterpiece", "1girl, art by greg, masterpiece") outputs (True, "1girl, art by greg")
|
||||
extract_style_text_from_prompt("masterpiece, {prompt}", "masterpiece, 1girl, art by greg") outputs (True, "1girl, art by greg")
|
||||
extract_style_text_from_prompt("masterpiece, {prompt}", "exquisite, 1girl, art by greg") outputs (False, "exquisite, 1girl, art by greg")
|
||||
"""
|
||||
stripped_prompt = prompt
|
||||
stripped_style_text = style_text
|
||||
|
||||
stripped_prompt = prompt.strip()
|
||||
stripped_style_text = style_text.strip()
|
||||
|
||||
if "{prompt}" in stripped_style_text:
|
||||
# Work out whether the prompt is wrapped in the style text. If so, we
|
||||
# return True and the "inner" prompt text that isn't part of the style.
|
||||
try:
|
||||
left, right = stripped_style_text.split("{prompt}", 2)
|
||||
except ValueError as e:
|
||||
# If the style text has multple "{prompt}"s, we can't split it into
|
||||
# two parts. This is an error, but we can't do anything about it.
|
||||
print(f"Unable to compare style text to prompt:\n{style_text}")
|
||||
print(f"Error: {e}")
|
||||
return False, prompt
|
||||
left, right = stripped_style_text.split("{prompt}", 2)
|
||||
if stripped_prompt.startswith(left) and stripped_prompt.endswith(right):
|
||||
prompt = stripped_prompt[len(left) : len(stripped_prompt) - len(right)]
|
||||
prompt = stripped_prompt[len(left):len(stripped_prompt)-len(right)]
|
||||
return True, prompt
|
||||
else:
|
||||
# Work out whether the given prompt ends with the style text. If so, we
|
||||
# return True and the prompt text up to where the style text starts.
|
||||
if stripped_prompt.endswith(stripped_style_text):
|
||||
prompt = stripped_prompt[: len(stripped_prompt) - len(stripped_style_text)]
|
||||
if prompt.endswith(", "):
|
||||
prompt = stripped_prompt[:len(stripped_prompt)-len(stripped_style_text)]
|
||||
|
||||
if prompt.endswith(', '):
|
||||
prompt = prompt[:-2]
|
||||
|
||||
return True, prompt
|
||||
|
||||
return False, prompt
|
||||
@@ -76,15 +66,11 @@ def extract_original_prompts(style: PromptStyle, prompt, negative_prompt):
|
||||
if not style.prompt and not style.negative_prompt:
|
||||
return False, prompt, negative_prompt
|
||||
|
||||
match_positive, extracted_positive = unwrap_style_text_from_prompt(
|
||||
style.prompt, prompt
|
||||
)
|
||||
match_positive, extracted_positive = extract_style_text_from_prompt(style.prompt, prompt)
|
||||
if not match_positive:
|
||||
return False, prompt, negative_prompt
|
||||
|
||||
match_negative, extracted_negative = unwrap_style_text_from_prompt(
|
||||
style.negative_prompt, negative_prompt
|
||||
)
|
||||
match_negative, extracted_negative = extract_style_text_from_prompt(style.negative_prompt, negative_prompt)
|
||||
if not match_negative:
|
||||
return False, prompt, negative_prompt
|
||||
|
||||
@@ -92,14 +78,19 @@ def extract_original_prompts(style: PromptStyle, prompt, negative_prompt):
|
||||
|
||||
|
||||
class StyleDatabase:
|
||||
def __init__(self, path: str):
|
||||
def __init__(self, paths: list[str | Path]):
|
||||
self.no_style = PromptStyle("None", "", "", None)
|
||||
self.styles = {}
|
||||
self.path = path
|
||||
self.paths = paths
|
||||
self.all_styles_files: list[Path] = []
|
||||
|
||||
folder, file = os.path.split(self.path)
|
||||
filename, _, ext = file.partition('*')
|
||||
self.default_path = os.path.join(folder, filename + ext)
|
||||
folder, file = os.path.split(self.paths[0])
|
||||
if '*' in file or '?' in file:
|
||||
# if the first path is a wildcard pattern, find the first match else use "folder/styles.csv" as the default path
|
||||
self.default_path = next(Path(folder).glob(file), Path(os.path.join(folder, 'styles.csv')))
|
||||
self.paths.insert(0, self.default_path)
|
||||
else:
|
||||
self.default_path = Path(self.paths[0])
|
||||
|
||||
self.prompt_fields = [field for field in PromptStyle._fields if field != "path"]
|
||||
|
||||
@@ -112,33 +103,31 @@ class StyleDatabase:
|
||||
"""
|
||||
self.styles.clear()
|
||||
|
||||
path, filename = os.path.split(self.path)
|
||||
# scans for all styles files
|
||||
all_styles_files = []
|
||||
for pattern in self.paths:
|
||||
folder, file = os.path.split(pattern)
|
||||
if '*' in file or '?' in file:
|
||||
found_files = Path(folder).glob(file)
|
||||
[all_styles_files.append(file) for file in found_files]
|
||||
else:
|
||||
# if os.path.exists(pattern):
|
||||
all_styles_files.append(Path(pattern))
|
||||
|
||||
if "*" in filename:
|
||||
fileglob = filename.split("*")[0] + "*.csv"
|
||||
filelist = []
|
||||
for file in os.listdir(path):
|
||||
if fnmatch.fnmatch(file, fileglob):
|
||||
filelist.append(file)
|
||||
# Add a visible divider to the style list
|
||||
half_len = round(len(file) / 2)
|
||||
divider = f"{'-' * (20 - half_len)} {file.upper()}"
|
||||
divider = f"{divider} {'-' * (40 - len(divider))}"
|
||||
self.styles[divider] = PromptStyle(
|
||||
f"{divider}", None, None, "do_not_save"
|
||||
)
|
||||
# Add styles from this CSV file
|
||||
self.load_from_csv(os.path.join(path, file))
|
||||
if len(filelist) == 0:
|
||||
print(f"No styles found in {path} matching {fileglob}")
|
||||
return
|
||||
elif not os.path.exists(self.path):
|
||||
print(f"Style database not found: {self.path}")
|
||||
return
|
||||
else:
|
||||
self.load_from_csv(self.path)
|
||||
# Remove any duplicate entries
|
||||
seen = set()
|
||||
self.all_styles_files = [s for s in all_styles_files if not (s in seen or seen.add(s))]
|
||||
|
||||
def load_from_csv(self, path: str):
|
||||
for styles_file in self.all_styles_files:
|
||||
if len(all_styles_files) > 1:
|
||||
# add divider when more than styles file
|
||||
# '---------------- STYLES ----------------'
|
||||
divider = f' {styles_file.stem.upper()} '.center(40, '-')
|
||||
self.styles[divider] = PromptStyle(f"{divider}", None, None, "do_not_save")
|
||||
if styles_file.is_file():
|
||||
self.load_from_csv(styles_file)
|
||||
|
||||
def load_from_csv(self, path: str | Path):
|
||||
with open(path, "r", encoding="utf-8-sig", newline="") as file:
|
||||
reader = csv.DictReader(file, skipinitialspace=True)
|
||||
for row in reader:
|
||||
@@ -150,7 +139,7 @@ class StyleDatabase:
|
||||
negative_prompt = row.get("negative_prompt", "")
|
||||
# Add style to database
|
||||
self.styles[row["name"]] = PromptStyle(
|
||||
row["name"], prompt, negative_prompt, path
|
||||
row["name"], prompt, negative_prompt, str(path)
|
||||
)
|
||||
|
||||
def get_style_paths(self) -> set:
|
||||
@@ -158,11 +147,11 @@ class StyleDatabase:
|
||||
# Update any styles without a path to the default path
|
||||
for style in list(self.styles.values()):
|
||||
if not style.path:
|
||||
self.styles[style.name] = style._replace(path=self.default_path)
|
||||
self.styles[style.name] = style._replace(path=str(self.default_path))
|
||||
|
||||
# Create a list of all distinct paths, including the default path
|
||||
style_paths = set()
|
||||
style_paths.add(self.default_path)
|
||||
style_paths.add(str(self.default_path))
|
||||
for _, style in self.styles.items():
|
||||
if style.path:
|
||||
style_paths.add(style.path)
|
||||
@@ -190,7 +179,6 @@ class StyleDatabase:
|
||||
|
||||
def save_styles(self, path: str = None) -> None:
|
||||
# The path argument is deprecated, but kept for backwards compatibility
|
||||
_ = path
|
||||
|
||||
style_paths = self.get_style_paths()
|
||||
|
||||
|
||||
@@ -24,13 +24,13 @@ environment_whitelist = {
|
||||
"XFORMERS_PACKAGE",
|
||||
"CLIP_PACKAGE",
|
||||
"OPENCLIP_PACKAGE",
|
||||
"ASSETS_REPO",
|
||||
"STABLE_DIFFUSION_REPO",
|
||||
"K_DIFFUSION_REPO",
|
||||
"CODEFORMER_REPO",
|
||||
"BLIP_REPO",
|
||||
"ASSETS_COMMIT_HASH",
|
||||
"STABLE_DIFFUSION_COMMIT_HASH",
|
||||
"K_DIFFUSION_COMMIT_HASH",
|
||||
"CODEFORMER_COMMIT_HASH",
|
||||
"BLIP_COMMIT_HASH",
|
||||
"COMMANDLINE_ARGS",
|
||||
"IGNORE_CMD_ARGS_ERRORS",
|
||||
|
||||
@@ -11,7 +11,6 @@ import safetensors.torch
|
||||
|
||||
import numpy as np
|
||||
from PIL import Image, PngImagePlugin
|
||||
from torch.utils.tensorboard import SummaryWriter
|
||||
|
||||
from modules import shared, devices, sd_hijack, sd_models, images, sd_samplers, sd_hijack_checkpoint, errors, hashes
|
||||
import modules.textual_inversion.dataset
|
||||
@@ -348,6 +347,7 @@ def write_loss(log_directory, filename, step, epoch_len, values):
|
||||
})
|
||||
|
||||
def tensorboard_setup(log_directory):
|
||||
from torch.utils.tensorboard import SummaryWriter
|
||||
os.makedirs(os.path.join(log_directory, "tensorboard"), exist_ok=True)
|
||||
return SummaryWriter(
|
||||
log_dir=os.path.join(log_directory, "tensorboard"),
|
||||
@@ -452,8 +452,12 @@ def train_embedding(id_task, embedding_name, learn_rate, batch_size, gradient_st
|
||||
shared.state.textinfo = f"Preparing dataset from {html.escape(data_root)}..."
|
||||
old_parallel_processing_allowed = shared.parallel_processing_allowed
|
||||
|
||||
tensorboard_writer = None
|
||||
if shared.opts.training_enable_tensorboard:
|
||||
tensorboard_writer = tensorboard_setup(log_directory)
|
||||
try:
|
||||
tensorboard_writer = tensorboard_setup(log_directory)
|
||||
except ImportError:
|
||||
errors.report("Error initializing tensorboard", exc_info=True)
|
||||
|
||||
pin_memory = shared.opts.pin_memory
|
||||
|
||||
@@ -626,7 +630,7 @@ def train_embedding(id_task, embedding_name, learn_rate, batch_size, gradient_st
|
||||
last_saved_image, last_text_info = images.save_image(image, images_dir, "", p.seed, p.prompt, shared.opts.samples_format, processed.infotexts[0], p=p, forced_filename=forced_filename, save_to_dirs=False)
|
||||
last_saved_image += f", prompt: {preview_text}"
|
||||
|
||||
if shared.opts.training_enable_tensorboard and shared.opts.training_tensorboard_save_images:
|
||||
if tensorboard_writer and shared.opts.training_tensorboard_save_images:
|
||||
tensorboard_add_image(tensorboard_writer, f"Validation at epoch {epoch_num}", image, embedding.step)
|
||||
|
||||
if save_image_with_stored_embedding and os.path.exists(last_saved_file) and embedding_yet_to_be_embedded:
|
||||
|
||||
17
modules/torch_utils.py
Normal file
17
modules/torch_utils.py
Normal file
@@ -0,0 +1,17 @@
|
||||
from __future__ import annotations
|
||||
|
||||
import torch.nn
|
||||
|
||||
|
||||
def get_param(model) -> torch.nn.Parameter:
|
||||
"""
|
||||
Find the first parameter in a model or module.
|
||||
"""
|
||||
if hasattr(model, "model") and hasattr(model.model, "parameters"):
|
||||
# Unpeel a model descriptor to get at the actual Torch module.
|
||||
model = model.model
|
||||
|
||||
for param in model.parameters():
|
||||
return param
|
||||
|
||||
raise ValueError(f"No parameters found in model {model!r}")
|
||||
@@ -1,17 +1,22 @@
|
||||
import json
|
||||
from contextlib import closing
|
||||
|
||||
import modules.scripts
|
||||
from modules import processing
|
||||
from modules.generation_parameters_copypaste import create_override_settings_dict
|
||||
from modules import processing, infotext_utils
|
||||
from modules.infotext_utils import create_override_settings_dict, parse_generation_parameters
|
||||
from modules.shared import opts
|
||||
import modules.shared as shared
|
||||
from modules.ui import plaintext_to_html
|
||||
from PIL import Image
|
||||
import gradio as gr
|
||||
|
||||
|
||||
def txt2img(id_task: str, prompt: str, negative_prompt: str, prompt_styles, steps: int, sampler_name: str, n_iter: int, batch_size: int, cfg_scale: float, height: int, width: int, enable_hr: bool, denoising_strength: float, hr_scale: float, hr_upscaler: str, hr_second_pass_steps: int, hr_resize_x: int, hr_resize_y: int, hr_checkpoint_name: str, hr_sampler_name: str, hr_prompt: str, hr_negative_prompt, override_settings_texts, request: gr.Request, *args):
|
||||
def txt2img_create_processing(id_task: str, request: gr.Request, prompt: str, negative_prompt: str, prompt_styles, steps: int, sampler_name: str, n_iter: int, batch_size: int, cfg_scale: float, height: int, width: int, enable_hr: bool, denoising_strength: float, hr_scale: float, hr_upscaler: str, hr_second_pass_steps: int, hr_resize_x: int, hr_resize_y: int, hr_checkpoint_name: str, hr_sampler_name: str, hr_prompt: str, hr_negative_prompt, override_settings_texts, *args, force_enable_hr=False):
|
||||
override_settings = create_override_settings_dict(override_settings_texts)
|
||||
|
||||
if force_enable_hr:
|
||||
enable_hr = True
|
||||
|
||||
p = processing.StableDiffusionProcessingTxt2Img(
|
||||
sd_model=shared.sd_model,
|
||||
outpath_samples=opts.outdir_samples or opts.outdir_txt2img_samples,
|
||||
@@ -27,7 +32,7 @@ def txt2img(id_task: str, prompt: str, negative_prompt: str, prompt_styles, step
|
||||
width=width,
|
||||
height=height,
|
||||
enable_hr=enable_hr,
|
||||
denoising_strength=denoising_strength if enable_hr else None,
|
||||
denoising_strength=denoising_strength,
|
||||
hr_scale=hr_scale,
|
||||
hr_upscaler=hr_upscaler,
|
||||
hr_second_pass_steps=hr_second_pass_steps,
|
||||
@@ -48,8 +53,58 @@ def txt2img(id_task: str, prompt: str, negative_prompt: str, prompt_styles, step
|
||||
if shared.opts.enable_console_prompts:
|
||||
print(f"\ntxt2img: {prompt}", file=shared.progress_print_out)
|
||||
|
||||
return p
|
||||
|
||||
|
||||
def txt2img_upscale(id_task: str, request: gr.Request, gallery, gallery_index, generation_info, *args):
|
||||
assert len(gallery) > 0, 'No image to upscale'
|
||||
assert 0 <= gallery_index < len(gallery), f'Bad image index: {gallery_index}'
|
||||
|
||||
p = txt2img_create_processing(id_task, request, *args, force_enable_hr=True)
|
||||
p.batch_size = 1
|
||||
p.n_iter = 1
|
||||
# txt2img_upscale attribute that signifies this is called by txt2img_upscale
|
||||
p.txt2img_upscale = True
|
||||
|
||||
geninfo = json.loads(generation_info)
|
||||
|
||||
image_info = gallery[gallery_index] if 0 <= gallery_index < len(gallery) else gallery[0]
|
||||
p.firstpass_image = infotext_utils.image_from_url_text(image_info)
|
||||
|
||||
parameters = parse_generation_parameters(geninfo.get('infotexts')[gallery_index], [])
|
||||
p.seed = parameters.get('Seed', -1)
|
||||
p.subseed = parameters.get('Variation seed', -1)
|
||||
|
||||
p.override_settings['save_images_before_highres_fix'] = False
|
||||
|
||||
with closing(p):
|
||||
processed = modules.scripts.scripts_txt2img.run(p, *args)
|
||||
processed = modules.scripts.scripts_txt2img.run(p, *p.script_args)
|
||||
|
||||
if processed is None:
|
||||
processed = processing.process_images(p)
|
||||
|
||||
shared.total_tqdm.clear()
|
||||
|
||||
new_gallery = []
|
||||
for i, image in enumerate(gallery):
|
||||
if i == gallery_index:
|
||||
geninfo["infotexts"][gallery_index: gallery_index+1] = processed.infotexts
|
||||
new_gallery.extend(processed.images)
|
||||
else:
|
||||
fake_image = Image.new(mode="RGB", size=(1, 1))
|
||||
fake_image.already_saved_as = image["name"].rsplit('?', 1)[0]
|
||||
new_gallery.append(fake_image)
|
||||
|
||||
geninfo["infotexts"][gallery_index] = processed.info
|
||||
|
||||
return new_gallery, json.dumps(geninfo), plaintext_to_html(processed.info), plaintext_to_html(processed.comments, classname="comments")
|
||||
|
||||
|
||||
def txt2img(id_task: str, request: gr.Request, *args):
|
||||
p = txt2img_create_processing(id_task, request, *args)
|
||||
|
||||
with closing(p):
|
||||
processed = modules.scripts.scripts_txt2img.run(p, *p.script_args)
|
||||
|
||||
if processed is None:
|
||||
processed = processing.process_images(p)
|
||||
|
||||
177
modules/ui.py
177
modules/ui.py
@@ -13,7 +13,7 @@ from PIL import Image, PngImagePlugin # noqa: F401
|
||||
from modules.call_queue import wrap_gradio_gpu_call, wrap_queued_call, wrap_gradio_call
|
||||
|
||||
from modules import gradio_extensons # noqa: F401
|
||||
from modules import sd_hijack, sd_models, script_callbacks, ui_extensions, deepbooru, extra_networks, ui_common, ui_postprocessing, progress, ui_loadsave, shared_items, ui_settings, timer, sysinfo, ui_checkpoint_merger, scripts, sd_samplers, processing, ui_extra_networks, ui_toprow
|
||||
from modules import sd_hijack, sd_models, script_callbacks, ui_extensions, deepbooru, extra_networks, ui_common, ui_postprocessing, progress, ui_loadsave, shared_items, ui_settings, timer, sysinfo, ui_checkpoint_merger, scripts, sd_samplers, processing, ui_extra_networks, ui_toprow, launch_utils
|
||||
from modules.ui_components import FormRow, FormGroup, ToolButton, FormHTML, InputAccordion, ResizeHandleRow
|
||||
from modules.paths import script_path
|
||||
from modules.ui_common import create_refresh_button
|
||||
@@ -21,14 +21,14 @@ from modules.ui_gradio_extensions import reload_javascript
|
||||
|
||||
from modules.shared import opts, cmd_opts
|
||||
|
||||
import modules.generation_parameters_copypaste as parameters_copypaste
|
||||
import modules.infotext_utils as parameters_copypaste
|
||||
import modules.hypernetworks.ui as hypernetworks_ui
|
||||
import modules.textual_inversion.ui as textual_inversion_ui
|
||||
import modules.textual_inversion.textual_inversion as textual_inversion
|
||||
import modules.shared as shared
|
||||
from modules import prompt_parser
|
||||
from modules.sd_hijack import model_hijack
|
||||
from modules.generation_parameters_copypaste import image_from_url_text
|
||||
from modules.infotext_utils import image_from_url_text, PasteField
|
||||
|
||||
create_setting_component = ui_settings.create_setting_component
|
||||
|
||||
@@ -177,7 +177,6 @@ def update_negative_prompt_token_counter(text, steps):
|
||||
return update_token_counter(text, steps, is_positive=False)
|
||||
|
||||
|
||||
|
||||
def setup_progressbar(*args, **kwargs):
|
||||
pass
|
||||
|
||||
@@ -267,7 +266,7 @@ def create_ui():
|
||||
|
||||
dummy_component = gr.Label(visible=False)
|
||||
|
||||
extra_tabs = gr.Tabs(elem_id="txt2img_extra_tabs")
|
||||
extra_tabs = gr.Tabs(elem_id="txt2img_extra_tabs", elem_classes=["extra-networks"])
|
||||
extra_tabs.__enter__()
|
||||
|
||||
with gr.Tab("Generation", id="txt2img_generation") as txt2img_generation_tab, ResizeHandleRow(equal_height=False):
|
||||
@@ -376,50 +375,60 @@ def create_ui():
|
||||
show_progress=False,
|
||||
)
|
||||
|
||||
txt2img_gallery, generation_info, html_info, html_log = create_output_panel("txt2img", opts.outdir_txt2img_samples, toprow)
|
||||
output_panel = create_output_panel("txt2img", opts.outdir_txt2img_samples, toprow)
|
||||
|
||||
txt2img_inputs = [
|
||||
dummy_component,
|
||||
toprow.prompt,
|
||||
toprow.negative_prompt,
|
||||
toprow.ui_styles.dropdown,
|
||||
steps,
|
||||
sampler_name,
|
||||
batch_count,
|
||||
batch_size,
|
||||
cfg_scale,
|
||||
height,
|
||||
width,
|
||||
enable_hr,
|
||||
denoising_strength,
|
||||
hr_scale,
|
||||
hr_upscaler,
|
||||
hr_second_pass_steps,
|
||||
hr_resize_x,
|
||||
hr_resize_y,
|
||||
hr_checkpoint_name,
|
||||
hr_sampler_name,
|
||||
hr_prompt,
|
||||
hr_negative_prompt,
|
||||
override_settings,
|
||||
] + custom_inputs
|
||||
|
||||
txt2img_outputs = [
|
||||
output_panel.gallery,
|
||||
output_panel.generation_info,
|
||||
output_panel.infotext,
|
||||
output_panel.html_log,
|
||||
]
|
||||
|
||||
txt2img_args = dict(
|
||||
fn=wrap_gradio_gpu_call(modules.txt2img.txt2img, extra_outputs=[None, '', '']),
|
||||
_js="submit",
|
||||
inputs=[
|
||||
dummy_component,
|
||||
toprow.prompt,
|
||||
toprow.negative_prompt,
|
||||
toprow.ui_styles.dropdown,
|
||||
steps,
|
||||
sampler_name,
|
||||
batch_count,
|
||||
batch_size,
|
||||
cfg_scale,
|
||||
height,
|
||||
width,
|
||||
enable_hr,
|
||||
denoising_strength,
|
||||
hr_scale,
|
||||
hr_upscaler,
|
||||
hr_second_pass_steps,
|
||||
hr_resize_x,
|
||||
hr_resize_y,
|
||||
hr_checkpoint_name,
|
||||
hr_sampler_name,
|
||||
hr_prompt,
|
||||
hr_negative_prompt,
|
||||
override_settings,
|
||||
|
||||
] + custom_inputs,
|
||||
|
||||
outputs=[
|
||||
txt2img_gallery,
|
||||
generation_info,
|
||||
html_info,
|
||||
html_log,
|
||||
],
|
||||
inputs=txt2img_inputs,
|
||||
outputs=txt2img_outputs,
|
||||
show_progress=False,
|
||||
)
|
||||
|
||||
toprow.prompt.submit(**txt2img_args)
|
||||
toprow.submit.click(**txt2img_args)
|
||||
|
||||
output_panel.button_upscale.click(
|
||||
fn=wrap_gradio_gpu_call(modules.txt2img.txt2img_upscale, extra_outputs=[None, '', '']),
|
||||
_js="submit_txt2img_upscale",
|
||||
inputs=txt2img_inputs[0:1] + [output_panel.gallery, dummy_component, output_panel.generation_info] + txt2img_inputs[1:],
|
||||
outputs=txt2img_outputs,
|
||||
show_progress=False,
|
||||
)
|
||||
|
||||
res_switch_btn.click(fn=None, _js="function(){switchWidthHeight('txt2img')}", inputs=None, outputs=None, show_progress=False)
|
||||
|
||||
toprow.restore_progress_button.click(
|
||||
@@ -427,37 +436,37 @@ def create_ui():
|
||||
_js="restoreProgressTxt2img",
|
||||
inputs=[dummy_component],
|
||||
outputs=[
|
||||
txt2img_gallery,
|
||||
generation_info,
|
||||
html_info,
|
||||
html_log,
|
||||
output_panel.gallery,
|
||||
output_panel.generation_info,
|
||||
output_panel.infotext,
|
||||
output_panel.html_log,
|
||||
],
|
||||
show_progress=False,
|
||||
)
|
||||
|
||||
txt2img_paste_fields = [
|
||||
(toprow.prompt, "Prompt"),
|
||||
(toprow.negative_prompt, "Negative prompt"),
|
||||
(steps, "Steps"),
|
||||
(sampler_name, "Sampler"),
|
||||
(cfg_scale, "CFG scale"),
|
||||
(width, "Size-1"),
|
||||
(height, "Size-2"),
|
||||
(batch_size, "Batch size"),
|
||||
(toprow.ui_styles.dropdown, lambda d: d["Styles array"] if isinstance(d.get("Styles array"), list) else gr.update()),
|
||||
(denoising_strength, "Denoising strength"),
|
||||
(enable_hr, lambda d: "Denoising strength" in d and ("Hires upscale" in d or "Hires upscaler" in d or "Hires resize-1" in d)),
|
||||
(hr_scale, "Hires upscale"),
|
||||
(hr_upscaler, "Hires upscaler"),
|
||||
(hr_second_pass_steps, "Hires steps"),
|
||||
(hr_resize_x, "Hires resize-1"),
|
||||
(hr_resize_y, "Hires resize-2"),
|
||||
(hr_checkpoint_name, "Hires checkpoint"),
|
||||
(hr_sampler_name, "Hires sampler"),
|
||||
(hr_sampler_container, lambda d: gr.update(visible=True) if d.get("Hires sampler", "Use same sampler") != "Use same sampler" or d.get("Hires checkpoint", "Use same checkpoint") != "Use same checkpoint" else gr.update()),
|
||||
(hr_prompt, "Hires prompt"),
|
||||
(hr_negative_prompt, "Hires negative prompt"),
|
||||
(hr_prompts_container, lambda d: gr.update(visible=True) if d.get("Hires prompt", "") != "" or d.get("Hires negative prompt", "") != "" else gr.update()),
|
||||
PasteField(toprow.prompt, "Prompt", api="prompt"),
|
||||
PasteField(toprow.negative_prompt, "Negative prompt", api="negative_prompt"),
|
||||
PasteField(steps, "Steps", api="steps"),
|
||||
PasteField(sampler_name, "Sampler", api="sampler_name"),
|
||||
PasteField(cfg_scale, "CFG scale", api="cfg_scale"),
|
||||
PasteField(width, "Size-1", api="width"),
|
||||
PasteField(height, "Size-2", api="height"),
|
||||
PasteField(batch_size, "Batch size", api="batch_size"),
|
||||
PasteField(toprow.ui_styles.dropdown, lambda d: d["Styles array"] if isinstance(d.get("Styles array"), list) else gr.update(), api="styles"),
|
||||
PasteField(denoising_strength, "Denoising strength", api="denoising_strength"),
|
||||
PasteField(enable_hr, lambda d: "Denoising strength" in d and ("Hires upscale" in d or "Hires upscaler" in d or "Hires resize-1" in d), api="enable_hr"),
|
||||
PasteField(hr_scale, "Hires upscale", api="hr_scale"),
|
||||
PasteField(hr_upscaler, "Hires upscaler", api="hr_upscaler"),
|
||||
PasteField(hr_second_pass_steps, "Hires steps", api="hr_second_pass_steps"),
|
||||
PasteField(hr_resize_x, "Hires resize-1", api="hr_resize_x"),
|
||||
PasteField(hr_resize_y, "Hires resize-2", api="hr_resize_y"),
|
||||
PasteField(hr_checkpoint_name, "Hires checkpoint", api="hr_checkpoint_name"),
|
||||
PasteField(hr_sampler_name, "Hires sampler", api="hr_sampler_name"),
|
||||
PasteField(hr_sampler_container, lambda d: gr.update(visible=True) if d.get("Hires sampler", "Use same sampler") != "Use same sampler" or d.get("Hires checkpoint", "Use same checkpoint") != "Use same checkpoint" else gr.update()),
|
||||
PasteField(hr_prompt, "Hires prompt", api="hr_prompt"),
|
||||
PasteField(hr_negative_prompt, "Hires negative prompt", api="hr_negative_prompt"),
|
||||
PasteField(hr_prompts_container, lambda d: gr.update(visible=True) if d.get("Hires prompt", "") != "" or d.get("Hires negative prompt", "") != "" else gr.update()),
|
||||
*scripts.scripts_txt2img.infotext_fields
|
||||
]
|
||||
parameters_copypaste.add_paste_fields("txt2img", None, txt2img_paste_fields, override_settings)
|
||||
@@ -480,7 +489,7 @@ def create_ui():
|
||||
toprow.negative_token_button.click(fn=wrap_queued_call(update_negative_prompt_token_counter), inputs=[toprow.negative_prompt, steps], outputs=[toprow.negative_token_counter])
|
||||
|
||||
extra_networks_ui = ui_extra_networks.create_ui(txt2img_interface, [txt2img_generation_tab], 'txt2img')
|
||||
ui_extra_networks.setup_ui(extra_networks_ui, txt2img_gallery)
|
||||
ui_extra_networks.setup_ui(extra_networks_ui, output_panel.gallery)
|
||||
|
||||
extra_tabs.__exit__()
|
||||
|
||||
@@ -490,7 +499,7 @@ def create_ui():
|
||||
with gr.Blocks(analytics_enabled=False) as img2img_interface:
|
||||
toprow = ui_toprow.Toprow(is_img2img=True, is_compact=shared.opts.compact_prompt_box)
|
||||
|
||||
extra_tabs = gr.Tabs(elem_id="img2img_extra_tabs")
|
||||
extra_tabs = gr.Tabs(elem_id="img2img_extra_tabs", elem_classes=["extra-networks"])
|
||||
extra_tabs.__enter__()
|
||||
|
||||
with gr.Tab("Generation", id="img2img_generation") as img2img_generation_tab, ResizeHandleRow(equal_height=False):
|
||||
@@ -523,7 +532,7 @@ def create_ui():
|
||||
|
||||
if category == "image":
|
||||
with gr.Tabs(elem_id="mode_img2img"):
|
||||
img2img_selected_tab = gr.State(0)
|
||||
img2img_selected_tab = gr.Number(value=0, visible=False)
|
||||
|
||||
with gr.TabItem('img2img', id='img2img', elem_id="img2img_img2img_tab") as tab_img2img:
|
||||
init_img = gr.Image(label="Image for img2img", elem_id="img2img_image", show_label=False, source="upload", interactive=True, type="pil", tool="editor", image_mode="RGBA", height=opts.img2img_editor_height)
|
||||
@@ -604,7 +613,7 @@ def create_ui():
|
||||
elif category == "dimensions":
|
||||
with FormRow():
|
||||
with gr.Column(elem_id="img2img_column_size", scale=4):
|
||||
selected_scale_tab = gr.State(value=0)
|
||||
selected_scale_tab = gr.Number(value=0, visible=False)
|
||||
|
||||
with gr.Tabs():
|
||||
with gr.Tab(label="Resize to", elem_id="img2img_tab_resize_to") as tab_scale_to:
|
||||
@@ -711,7 +720,7 @@ def create_ui():
|
||||
outputs=[inpaint_controls, mask_alpha],
|
||||
)
|
||||
|
||||
img2img_gallery, generation_info, html_info, html_log = create_output_panel("img2img", opts.outdir_img2img_samples, toprow)
|
||||
output_panel = create_output_panel("img2img", opts.outdir_img2img_samples, toprow)
|
||||
|
||||
img2img_args = dict(
|
||||
fn=wrap_gradio_gpu_call(modules.img2img.img2img, extra_outputs=[None, '', '']),
|
||||
@@ -756,10 +765,10 @@ def create_ui():
|
||||
img2img_batch_png_info_dir,
|
||||
] + custom_inputs,
|
||||
outputs=[
|
||||
img2img_gallery,
|
||||
generation_info,
|
||||
html_info,
|
||||
html_log,
|
||||
output_panel.gallery,
|
||||
output_panel.generation_info,
|
||||
output_panel.infotext,
|
||||
output_panel.html_log,
|
||||
],
|
||||
show_progress=False,
|
||||
)
|
||||
@@ -797,10 +806,10 @@ def create_ui():
|
||||
_js="restoreProgressImg2img",
|
||||
inputs=[dummy_component],
|
||||
outputs=[
|
||||
img2img_gallery,
|
||||
generation_info,
|
||||
html_info,
|
||||
html_log,
|
||||
output_panel.gallery,
|
||||
output_panel.generation_info,
|
||||
output_panel.infotext,
|
||||
output_panel.html_log,
|
||||
],
|
||||
show_progress=False,
|
||||
)
|
||||
@@ -831,6 +840,10 @@ def create_ui():
|
||||
(toprow.ui_styles.dropdown, lambda d: d["Styles array"] if isinstance(d.get("Styles array"), list) else gr.update()),
|
||||
(denoising_strength, "Denoising strength"),
|
||||
(mask_blur, "Mask blur"),
|
||||
(inpainting_mask_invert, 'Mask mode'),
|
||||
(inpainting_fill, 'Masked content'),
|
||||
(inpaint_full_res, 'Inpaint area'),
|
||||
(inpaint_full_res_padding, 'Masked area padding'),
|
||||
*scripts.scripts_img2img.infotext_fields
|
||||
]
|
||||
parameters_copypaste.add_paste_fields("img2img", init_img, img2img_paste_fields, override_settings)
|
||||
@@ -840,7 +853,7 @@ def create_ui():
|
||||
))
|
||||
|
||||
extra_networks_ui_img2img = ui_extra_networks.create_ui(img2img_interface, [img2img_generation_tab], 'img2img')
|
||||
ui_extra_networks.setup_ui(extra_networks_ui_img2img, img2img_gallery)
|
||||
ui_extra_networks.setup_ui(extra_networks_ui_img2img, output_panel.gallery)
|
||||
|
||||
extra_tabs.__exit__()
|
||||
|
||||
@@ -1086,6 +1099,7 @@ def create_ui():
|
||||
)
|
||||
|
||||
loadsave = ui_loadsave.UiLoadsave(cmd_opts.ui_config_file)
|
||||
ui_settings_from_file = loadsave.ui_settings.copy()
|
||||
|
||||
settings = ui_settings.UiSettings()
|
||||
settings.create_ui(loadsave, dummy_component)
|
||||
@@ -1146,7 +1160,8 @@ def create_ui():
|
||||
|
||||
modelmerger_ui.setup_ui(dummy_component=dummy_component, sd_model_checkpoint_component=settings.component_dict['sd_model_checkpoint'])
|
||||
|
||||
loadsave.dump_defaults()
|
||||
if ui_settings_from_file != loadsave.ui_settings:
|
||||
loadsave.dump_defaults()
|
||||
demo.ui_loadsave = loadsave
|
||||
|
||||
return demo
|
||||
@@ -1208,3 +1223,5 @@ def setup_ui_api(app):
|
||||
app.add_api_route("/internal/sysinfo", download_sysinfo, methods=["GET"])
|
||||
app.add_api_route("/internal/sysinfo-download", lambda: download_sysinfo(attachment=True), methods=["GET"])
|
||||
|
||||
import fastapi.staticfiles
|
||||
app.mount("/webui-assets", fastapi.staticfiles.StaticFiles(directory=launch_utils.repo_dir('stable-diffusion-webui-assets')), name="webui-assets")
|
||||
|
||||
@@ -1,3 +1,5 @@
|
||||
import csv
|
||||
import dataclasses
|
||||
import json
|
||||
import html
|
||||
import os
|
||||
@@ -8,10 +10,10 @@ import gradio as gr
|
||||
import subprocess as sp
|
||||
|
||||
from modules import call_queue, shared
|
||||
from modules.generation_parameters_copypaste import image_from_url_text
|
||||
from modules.infotext_utils import image_from_url_text
|
||||
import modules.images
|
||||
from modules.ui_components import ToolButton
|
||||
import modules.generation_parameters_copypaste as parameters_copypaste
|
||||
import modules.infotext_utils as parameters_copypaste
|
||||
|
||||
folder_symbol = '\U0001f4c2' # 📂
|
||||
refresh_symbol = '\U0001f504' # 🔄
|
||||
@@ -35,12 +37,38 @@ def plaintext_to_html(text, classname=None):
|
||||
return f"<p class='{classname}'>{content}</p>" if classname else f"<p>{content}</p>"
|
||||
|
||||
|
||||
def update_logfile(logfile_path, fields):
|
||||
"""Update a logfile from old format to new format to maintain CSV integrity."""
|
||||
with open(logfile_path, "r", encoding="utf8", newline="") as file:
|
||||
reader = csv.reader(file)
|
||||
rows = list(reader)
|
||||
|
||||
# blank file: leave it as is
|
||||
if not rows:
|
||||
return
|
||||
|
||||
# file is already synced, do nothing
|
||||
if len(rows[0]) == len(fields):
|
||||
return
|
||||
|
||||
rows[0] = fields
|
||||
|
||||
# append new fields to each row as empty values
|
||||
for row in rows[1:]:
|
||||
while len(row) < len(fields):
|
||||
row.append("")
|
||||
|
||||
with open(logfile_path, "w", encoding="utf8", newline="") as file:
|
||||
writer = csv.writer(file)
|
||||
writer.writerows(rows)
|
||||
|
||||
|
||||
def save_files(js_data, images, do_make_zip, index):
|
||||
import csv
|
||||
filenames = []
|
||||
fullfns = []
|
||||
parsed_infotexts = []
|
||||
|
||||
#quick dictionary to class object conversion. Its necessary due apply_filename_pattern requiring it
|
||||
# quick dictionary to class object conversion. Its necessary due apply_filename_pattern requiring it
|
||||
class MyObject:
|
||||
def __init__(self, d=None):
|
||||
if d is not None:
|
||||
@@ -48,35 +76,55 @@ def save_files(js_data, images, do_make_zip, index):
|
||||
setattr(self, key, value)
|
||||
|
||||
data = json.loads(js_data)
|
||||
|
||||
p = MyObject(data)
|
||||
|
||||
path = shared.opts.outdir_save
|
||||
save_to_dirs = shared.opts.use_save_to_dirs_for_ui
|
||||
extension: str = shared.opts.samples_format
|
||||
start_index = 0
|
||||
only_one = False
|
||||
|
||||
if index > -1 and shared.opts.save_selected_only and (index >= data["index_of_first_image"]): # ensures we are looking at a specific non-grid picture, and we have save_selected_only
|
||||
only_one = True
|
||||
images = [images[index]]
|
||||
start_index = index
|
||||
|
||||
os.makedirs(shared.opts.outdir_save, exist_ok=True)
|
||||
|
||||
with open(os.path.join(shared.opts.outdir_save, "log.csv"), "a", encoding="utf8", newline='') as file:
|
||||
fields = [
|
||||
"prompt",
|
||||
"seed",
|
||||
"width",
|
||||
"height",
|
||||
"sampler",
|
||||
"cfgs",
|
||||
"steps",
|
||||
"filename",
|
||||
"negative_prompt",
|
||||
"sd_model_name",
|
||||
"sd_model_hash",
|
||||
]
|
||||
logfile_path = os.path.join(shared.opts.outdir_save, "log.csv")
|
||||
|
||||
# NOTE: ensure csv integrity when fields are added by
|
||||
# updating headers and padding with delimeters where needed
|
||||
if os.path.exists(logfile_path):
|
||||
update_logfile(logfile_path, fields)
|
||||
|
||||
with open(logfile_path, "a", encoding="utf8", newline='') as file:
|
||||
at_start = file.tell() == 0
|
||||
writer = csv.writer(file)
|
||||
if at_start:
|
||||
writer.writerow(["prompt", "seed", "width", "height", "sampler", "cfgs", "steps", "filename", "negative_prompt"])
|
||||
writer.writerow(fields)
|
||||
|
||||
for image_index, filedata in enumerate(images, start_index):
|
||||
image = image_from_url_text(filedata)
|
||||
|
||||
is_grid = image_index < p.index_of_first_image
|
||||
i = 0 if is_grid else (image_index - p.index_of_first_image)
|
||||
|
||||
p.batch_index = image_index-1
|
||||
fullfn, txt_fullfn = modules.images.save_image(image, path, "", seed=p.all_seeds[i], prompt=p.all_prompts[i], extension=extension, info=p.infotexts[image_index], grid=is_grid, p=p, save_to_dirs=save_to_dirs)
|
||||
|
||||
parameters = parameters_copypaste.parse_generation_parameters(data["infotexts"][image_index], [])
|
||||
parsed_infotexts.append(parameters)
|
||||
fullfn, txt_fullfn = modules.images.save_image(image, path, "", seed=parameters['Seed'], prompt=parameters['Prompt'], extension=extension, info=p.infotexts[image_index], grid=is_grid, p=p, save_to_dirs=save_to_dirs)
|
||||
|
||||
filename = os.path.relpath(fullfn, path)
|
||||
filenames.append(filename)
|
||||
@@ -85,12 +133,12 @@ def save_files(js_data, images, do_make_zip, index):
|
||||
filenames.append(os.path.basename(txt_fullfn))
|
||||
fullfns.append(txt_fullfn)
|
||||
|
||||
writer.writerow([data["prompt"], data["seed"], data["width"], data["height"], data["sampler_name"], data["cfg_scale"], data["steps"], filenames[0], data["negative_prompt"]])
|
||||
writer.writerow([parsed_infotexts[0]['Prompt'], parsed_infotexts[0]['Seed'], data["width"], data["height"], data["sampler_name"], data["cfg_scale"], data["steps"], filenames[0], parsed_infotexts[0]['Negative prompt'], data["sd_model_name"], data["sd_model_hash"]])
|
||||
|
||||
# Make Zip
|
||||
if do_make_zip:
|
||||
zip_fileseed = p.all_seeds[index-1] if only_one else p.all_seeds[0]
|
||||
namegen = modules.images.FilenameGenerator(p, zip_fileseed, p.all_prompts[0], image, True)
|
||||
p.all_seeds = [parameters['Seed'] for parameters in parsed_infotexts]
|
||||
namegen = modules.images.FilenameGenerator(p, parsed_infotexts[0]['Seed'], parsed_infotexts[0]['Prompt'], image, True)
|
||||
zip_filename = namegen.apply(shared.opts.grid_zip_filename_pattern or "[datetime]_[[model_name]]_[seed]-[seed_last]")
|
||||
zip_filepath = os.path.join(path, f"{zip_filename}.zip")
|
||||
|
||||
@@ -104,7 +152,17 @@ def save_files(js_data, images, do_make_zip, index):
|
||||
return gr.File.update(value=fullfns, visible=True), plaintext_to_html(f"Saved: {filenames[0]}")
|
||||
|
||||
|
||||
@dataclasses.dataclass
|
||||
class OutputPanel:
|
||||
gallery = None
|
||||
generation_info = None
|
||||
infotext = None
|
||||
html_log = None
|
||||
button_upscale = None
|
||||
|
||||
|
||||
def create_output_panel(tabname, outdir, toprow=None):
|
||||
res = OutputPanel()
|
||||
|
||||
def open_folder(f):
|
||||
if not os.path.exists(f):
|
||||
@@ -136,9 +194,8 @@ Requested path was: {f}
|
||||
|
||||
with gr.Column(variant='panel', elem_id=f"{tabname}_results_panel"):
|
||||
with gr.Group(elem_id=f"{tabname}_gallery_container"):
|
||||
result_gallery = gr.Gallery(label='Output', show_label=False, elem_id=f"{tabname}_gallery", columns=4, preview=True, height=shared.opts.gallery_height or None)
|
||||
res.gallery = gr.Gallery(label='Output', show_label=False, elem_id=f"{tabname}_gallery", columns=4, preview=True, height=shared.opts.gallery_height or None)
|
||||
|
||||
generation_info = None
|
||||
with gr.Row(elem_id=f"image_buttons_{tabname}", elem_classes="image-buttons"):
|
||||
open_folder_button = ToolButton(folder_symbol, elem_id=f'{tabname}_open_folder', visible=not shared.cmd_opts.hide_ui_dir_config, tooltip="Open images output directory.")
|
||||
|
||||
@@ -152,6 +209,9 @@ Requested path was: {f}
|
||||
'extras': ToolButton('📐', elem_id=f'{tabname}_send_to_extras', tooltip="Send image and generation parameters to extras tab.")
|
||||
}
|
||||
|
||||
if tabname == 'txt2img':
|
||||
res.button_upscale = ToolButton('✨', elem_id=f'{tabname}_upscale', tooltip="Create an upscaled version of the current image using hires fix settings.")
|
||||
|
||||
open_folder_button.click(
|
||||
fn=lambda: open_folder(shared.opts.outdir_samples or outdir),
|
||||
inputs=[],
|
||||
@@ -162,17 +222,17 @@ Requested path was: {f}
|
||||
download_files = gr.File(None, file_count="multiple", interactive=False, show_label=False, visible=False, elem_id=f'download_files_{tabname}')
|
||||
|
||||
with gr.Group():
|
||||
html_info = gr.HTML(elem_id=f'html_info_{tabname}', elem_classes="infotext")
|
||||
html_log = gr.HTML(elem_id=f'html_log_{tabname}', elem_classes="html-log")
|
||||
res.infotext = gr.HTML(elem_id=f'html_info_{tabname}', elem_classes="infotext")
|
||||
res.html_log = gr.HTML(elem_id=f'html_log_{tabname}', elem_classes="html-log")
|
||||
|
||||
generation_info = gr.Textbox(visible=False, elem_id=f'generation_info_{tabname}')
|
||||
res.generation_info = gr.Textbox(visible=False, elem_id=f'generation_info_{tabname}')
|
||||
if tabname == 'txt2img' or tabname == 'img2img':
|
||||
generation_info_button = gr.Button(visible=False, elem_id=f"{tabname}_generation_info_button")
|
||||
generation_info_button.click(
|
||||
fn=update_generation_info,
|
||||
_js="function(x, y, z){ return [x, y, selected_gallery_index()] }",
|
||||
inputs=[generation_info, html_info, html_info],
|
||||
outputs=[html_info, html_info],
|
||||
inputs=[res.generation_info, res.infotext, res.infotext],
|
||||
outputs=[res.infotext, res.infotext],
|
||||
show_progress=False,
|
||||
)
|
||||
|
||||
@@ -180,14 +240,14 @@ Requested path was: {f}
|
||||
fn=call_queue.wrap_gradio_call(save_files),
|
||||
_js="(x, y, z, w) => [x, y, false, selected_gallery_index()]",
|
||||
inputs=[
|
||||
generation_info,
|
||||
result_gallery,
|
||||
html_info,
|
||||
html_info,
|
||||
res.generation_info,
|
||||
res.gallery,
|
||||
res.infotext,
|
||||
res.infotext,
|
||||
],
|
||||
outputs=[
|
||||
download_files,
|
||||
html_log,
|
||||
res.html_log,
|
||||
],
|
||||
show_progress=False,
|
||||
)
|
||||
@@ -196,21 +256,21 @@ Requested path was: {f}
|
||||
fn=call_queue.wrap_gradio_call(save_files),
|
||||
_js="(x, y, z, w) => [x, y, true, selected_gallery_index()]",
|
||||
inputs=[
|
||||
generation_info,
|
||||
result_gallery,
|
||||
html_info,
|
||||
html_info,
|
||||
res.generation_info,
|
||||
res.gallery,
|
||||
res.infotext,
|
||||
res.infotext,
|
||||
],
|
||||
outputs=[
|
||||
download_files,
|
||||
html_log,
|
||||
res.html_log,
|
||||
]
|
||||
)
|
||||
|
||||
else:
|
||||
html_info_x = gr.HTML(elem_id=f'html_info_x_{tabname}')
|
||||
html_info = gr.HTML(elem_id=f'html_info_{tabname}', elem_classes="infotext")
|
||||
html_log = gr.HTML(elem_id=f'html_log_{tabname}')
|
||||
res.generation_info = gr.HTML(elem_id=f'html_info_x_{tabname}')
|
||||
res.infotext = gr.HTML(elem_id=f'html_info_{tabname}', elem_classes="infotext")
|
||||
res.html_log = gr.HTML(elem_id=f'html_log_{tabname}')
|
||||
|
||||
paste_field_names = []
|
||||
if tabname == "txt2img":
|
||||
@@ -220,11 +280,11 @@ Requested path was: {f}
|
||||
|
||||
for paste_tabname, paste_button in buttons.items():
|
||||
parameters_copypaste.register_paste_params_button(parameters_copypaste.ParamBinding(
|
||||
paste_button=paste_button, tabname=paste_tabname, source_tabname="txt2img" if tabname == "txt2img" else None, source_image_component=result_gallery,
|
||||
paste_button=paste_button, tabname=paste_tabname, source_tabname="txt2img" if tabname == "txt2img" else None, source_image_component=res.gallery,
|
||||
paste_field_names=paste_field_names
|
||||
))
|
||||
|
||||
return result_gallery, generation_info if tabname != "extras" else html_info_x, html_info, html_log
|
||||
return res
|
||||
|
||||
|
||||
def create_refresh_button(refresh_component, refresh_method, refreshed_args, elem_id):
|
||||
|
||||
@@ -2,23 +2,22 @@ import functools
|
||||
import os.path
|
||||
import urllib.parse
|
||||
from pathlib import Path
|
||||
from typing import Optional, Union
|
||||
from dataclasses import dataclass
|
||||
|
||||
from modules import shared, ui_extra_networks_user_metadata, errors, extra_networks
|
||||
from modules import shared, ui_extra_networks_user_metadata, errors, extra_networks, util
|
||||
from modules.images import read_info_from_image, save_image_with_geninfo
|
||||
import gradio as gr
|
||||
import json
|
||||
import html
|
||||
from fastapi.exceptions import HTTPException
|
||||
|
||||
from modules.generation_parameters_copypaste import image_from_url_text
|
||||
from modules.ui_components import ToolButton
|
||||
from modules.infotext_utils import image_from_url_text
|
||||
|
||||
extra_pages = []
|
||||
allowed_dirs = set()
|
||||
|
||||
default_allowed_preview_extensions = ["png", "jpg", "jpeg", "webp", "gif"]
|
||||
|
||||
|
||||
@functools.cache
|
||||
def allowed_preview_extensions_with_extra(extra_extensions=None):
|
||||
return set(default_allowed_preview_extensions) | set(extra_extensions or [])
|
||||
@@ -28,6 +27,62 @@ def allowed_preview_extensions():
|
||||
return allowed_preview_extensions_with_extra((shared.opts.samples_format, ))
|
||||
|
||||
|
||||
@dataclass
|
||||
class ExtraNetworksItem:
|
||||
"""Wrapper for dictionaries representing ExtraNetworks items."""
|
||||
item: dict
|
||||
|
||||
|
||||
def get_tree(paths: Union[str, list[str]], items: dict[str, ExtraNetworksItem]) -> dict:
|
||||
"""Recursively builds a directory tree.
|
||||
|
||||
Args:
|
||||
paths: Path or list of paths to directories. These paths are treated as roots from which
|
||||
the tree will be built.
|
||||
items: A dictionary associating filepaths to an ExtraNetworksItem instance.
|
||||
|
||||
Returns:
|
||||
The result directory tree.
|
||||
"""
|
||||
if isinstance(paths, (str,)):
|
||||
paths = [paths]
|
||||
|
||||
def _get_tree(_paths: list[str], _root: str):
|
||||
_res = {}
|
||||
for path in _paths:
|
||||
relpath = os.path.relpath(path, _root)
|
||||
if os.path.isdir(path):
|
||||
dir_items = os.listdir(path)
|
||||
# Ignore empty directories.
|
||||
if not dir_items:
|
||||
continue
|
||||
dir_tree = _get_tree([os.path.join(path, x) for x in dir_items], _root)
|
||||
# We only want to store non-empty folders in the tree.
|
||||
if dir_tree:
|
||||
_res[relpath] = dir_tree
|
||||
else:
|
||||
if path not in items:
|
||||
continue
|
||||
# Add the ExtraNetworksItem to the result.
|
||||
_res[relpath] = items[path]
|
||||
return _res
|
||||
|
||||
res = {}
|
||||
# Handle each root directory separately.
|
||||
# Each root WILL have a key/value at the root of the result dict though
|
||||
# the value can be an empty dict if the directory is empty. We want these
|
||||
# placeholders for empty dirs so we can inform the user later.
|
||||
for path in paths:
|
||||
root = os.path.dirname(path)
|
||||
relpath = os.path.relpath(path, root)
|
||||
# Wrap the path in a list since that is what the `_get_tree` expects.
|
||||
res[relpath] = _get_tree([path], root)
|
||||
if res[relpath]:
|
||||
# We need to pull the inner path out one for these root dirs.
|
||||
res[relpath] = res[relpath][relpath]
|
||||
|
||||
return res
|
||||
|
||||
def register_page(page):
|
||||
"""registers extra networks page for the UI; recommend doing it in on_before_ui() callback for extensions"""
|
||||
|
||||
@@ -80,7 +135,7 @@ def get_single_card(page: str = "", tabname: str = "", name: str = ""):
|
||||
item = page.items.get(name)
|
||||
|
||||
page.read_user_metadata(item)
|
||||
item_html = page.create_html_for_item(item, tabname)
|
||||
item_html = page.create_item_html(tabname, item)
|
||||
|
||||
return JSONResponse({"html": item_html})
|
||||
|
||||
@@ -96,24 +151,31 @@ def quote_js(s):
|
||||
s = s.replace('"', '\\"')
|
||||
return f'"{s}"'
|
||||
|
||||
|
||||
class ExtraNetworksPage:
|
||||
def __init__(self, title):
|
||||
self.title = title
|
||||
self.name = title.lower()
|
||||
self.id_page = self.name.replace(" ", "_")
|
||||
self.card_page = shared.html("extra-networks-card.html")
|
||||
# This is the actual name of the extra networks tab (not txt2img/img2img).
|
||||
self.extra_networks_tabname = self.name.replace(" ", "_")
|
||||
self.allow_prompt = True
|
||||
self.allow_negative_prompt = False
|
||||
self.metadata = {}
|
||||
self.items = {}
|
||||
self.lister = util.MassFileLister()
|
||||
# HTML Templates
|
||||
self.pane_tpl = shared.html("extra-networks-pane.html")
|
||||
self.card_tpl = shared.html("extra-networks-card.html")
|
||||
self.btn_tree_tpl = shared.html("extra-networks-tree-button.html")
|
||||
self.btn_copy_path_tpl = shared.html("extra-networks-copy-path-button.html")
|
||||
self.btn_metadata_tpl = shared.html("extra-networks-metadata-button.html")
|
||||
self.btn_edit_item_tpl = shared.html("extra-networks-edit-item-button.html")
|
||||
|
||||
def refresh(self):
|
||||
pass
|
||||
|
||||
def read_user_metadata(self, item):
|
||||
filename = item.get("filename", None)
|
||||
metadata = extra_networks.get_user_metadata(filename)
|
||||
metadata = extra_networks.get_user_metadata(filename, lister=self.lister)
|
||||
|
||||
desc = metadata.get("description", None)
|
||||
if desc is not None:
|
||||
@@ -123,117 +185,74 @@ class ExtraNetworksPage:
|
||||
|
||||
def link_preview(self, filename):
|
||||
quoted_filename = urllib.parse.quote(filename.replace('\\', '/'))
|
||||
mtime = os.path.getmtime(filename)
|
||||
mtime, _ = self.lister.mctime(filename)
|
||||
return f"./sd_extra_networks/thumb?filename={quoted_filename}&mtime={mtime}"
|
||||
|
||||
def search_terms_from_path(self, filename, possible_directories=None):
|
||||
abspath = os.path.abspath(filename)
|
||||
|
||||
for parentdir in (possible_directories if possible_directories is not None else self.allowed_directories_for_previews()):
|
||||
parentdir = os.path.abspath(parentdir)
|
||||
parentdir = os.path.dirname(os.path.abspath(parentdir))
|
||||
if abspath.startswith(parentdir):
|
||||
return abspath[len(parentdir):].replace('\\', '/')
|
||||
return os.path.relpath(abspath, parentdir)
|
||||
|
||||
return ""
|
||||
|
||||
def create_html(self, tabname):
|
||||
items_html = ''
|
||||
def create_item_html(
|
||||
self,
|
||||
tabname: str,
|
||||
item: dict,
|
||||
template: Optional[str] = None,
|
||||
) -> Union[str, dict]:
|
||||
"""Generates HTML for a single ExtraNetworks Item.
|
||||
|
||||
self.metadata = {}
|
||||
Args:
|
||||
tabname: The name of the active tab.
|
||||
item: Dictionary containing item information.
|
||||
template: Optional template string to use.
|
||||
|
||||
subdirs = {}
|
||||
for parentdir in [os.path.abspath(x) for x in self.allowed_directories_for_previews()]:
|
||||
for root, dirs, _ in sorted(os.walk(parentdir, followlinks=True), key=lambda x: shared.natural_sort_key(x[0])):
|
||||
for dirname in sorted(dirs, key=shared.natural_sort_key):
|
||||
x = os.path.join(root, dirname)
|
||||
|
||||
if not os.path.isdir(x):
|
||||
continue
|
||||
|
||||
subdir = os.path.abspath(x)[len(parentdir):].replace("\\", "/")
|
||||
|
||||
if shared.opts.extra_networks_dir_button_function:
|
||||
if not subdir.startswith("/"):
|
||||
subdir = "/" + subdir
|
||||
else:
|
||||
while subdir.startswith("/"):
|
||||
subdir = subdir[1:]
|
||||
|
||||
is_empty = len(os.listdir(x)) == 0
|
||||
if not is_empty and not subdir.endswith("/"):
|
||||
subdir = subdir + "/"
|
||||
|
||||
if ("/." in subdir or subdir.startswith(".")) and not shared.opts.extra_networks_show_hidden_directories:
|
||||
continue
|
||||
|
||||
subdirs[subdir] = 1
|
||||
|
||||
if subdirs:
|
||||
subdirs = {"": 1, **subdirs}
|
||||
|
||||
subdirs_html = "".join([f"""
|
||||
<button class='lg secondary gradio-button custom-button{" search-all" if subdir=="" else ""}' onclick='extraNetworksSearchButton("{tabname}_extra_search", event)'>
|
||||
{html.escape(subdir if subdir!="" else "all")}
|
||||
</button>
|
||||
""" for subdir in subdirs])
|
||||
|
||||
self.items = {x["name"]: x for x in self.list_items()}
|
||||
for item in self.items.values():
|
||||
metadata = item.get("metadata")
|
||||
if metadata:
|
||||
self.metadata[item["name"]] = metadata
|
||||
|
||||
if "user_metadata" not in item:
|
||||
self.read_user_metadata(item)
|
||||
|
||||
items_html += self.create_html_for_item(item, tabname)
|
||||
|
||||
if items_html == '':
|
||||
dirs = "".join([f"<li>{x}</li>" for x in self.allowed_directories_for_previews()])
|
||||
items_html = shared.html("extra-networks-no-cards.html").format(dirs=dirs)
|
||||
|
||||
self_name_id = self.name.replace(" ", "_")
|
||||
|
||||
res = f"""
|
||||
<div id='{tabname}_{self_name_id}_subdirs' class='extra-network-subdirs extra-network-subdirs-cards'>
|
||||
{subdirs_html}
|
||||
</div>
|
||||
<div id='{tabname}_{self_name_id}_cards' class='extra-network-cards'>
|
||||
{items_html}
|
||||
</div>
|
||||
"""
|
||||
|
||||
return res
|
||||
|
||||
def create_item(self, name, index=None):
|
||||
raise NotImplementedError()
|
||||
|
||||
def list_items(self):
|
||||
raise NotImplementedError()
|
||||
|
||||
def allowed_directories_for_previews(self):
|
||||
return []
|
||||
|
||||
def create_html_for_item(self, item, tabname):
|
||||
Returns:
|
||||
If a template is passed: HTML string generated for this item.
|
||||
Can be empty if the item is not meant to be shown.
|
||||
If no template is passed: A dictionary containing the generated item's attributes.
|
||||
"""
|
||||
Create HTML for card item in tab tabname; can return empty string if the item is not meant to be shown.
|
||||
"""
|
||||
|
||||
preview = item.get("preview", None)
|
||||
style_height = f"height: {shared.opts.extra_networks_card_height}px;" if shared.opts.extra_networks_card_height else ''
|
||||
style_width = f"width: {shared.opts.extra_networks_card_width}px;" if shared.opts.extra_networks_card_width else ''
|
||||
style_font_size = f"font-size: {shared.opts.extra_networks_card_text_scale*100}%;"
|
||||
card_style = style_height + style_width + style_font_size
|
||||
background_image = f'<img src="{html.escape(preview)}" class="preview" loading="lazy">' if preview else ''
|
||||
|
||||
onclick = item.get("onclick", None)
|
||||
if onclick is None:
|
||||
onclick = '"' + html.escape(f"""return cardClicked({quote_js(tabname)}, {item["prompt"]}, {"true" if self.allow_negative_prompt else "false"})""") + '"'
|
||||
# Don't quote prompt/neg_prompt since they are stored as js strings already.
|
||||
onclick_js_tpl = "cardClicked('{tabname}', {prompt}, {neg_prompt}, {allow_neg});"
|
||||
onclick = onclick_js_tpl.format(
|
||||
**{
|
||||
"tabname": tabname,
|
||||
"prompt": item["prompt"],
|
||||
"neg_prompt": item.get("negative_prompt", "''"),
|
||||
"allow_neg": str(self.allow_negative_prompt).lower(),
|
||||
}
|
||||
)
|
||||
onclick = html.escape(onclick)
|
||||
|
||||
height = f"height: {shared.opts.extra_networks_card_height}px;" if shared.opts.extra_networks_card_height else ''
|
||||
width = f"width: {shared.opts.extra_networks_card_width}px;" if shared.opts.extra_networks_card_width else ''
|
||||
background_image = f'<img src="{html.escape(preview)}" class="preview" loading="lazy">' if preview else ''
|
||||
metadata_button = ""
|
||||
btn_copy_path = self.btn_copy_path_tpl.format(**{"filename": item["filename"]})
|
||||
btn_metadata = ""
|
||||
metadata = item.get("metadata")
|
||||
if metadata:
|
||||
metadata_button = f"<div class='metadata-button card-button' title='Show internal metadata' onclick='extraNetworksRequestMetadata(event, {quote_js(self.name)}, {quote_js(html.escape(item['name']))})'></div>"
|
||||
|
||||
edit_button = f"<div class='edit-button card-button' title='Edit metadata' onclick='extraNetworksEditUserMetadata(event, {quote_js(tabname)}, {quote_js(self.id_page)}, {quote_js(html.escape(item['name']))})'></div>"
|
||||
btn_metadata = self.btn_metadata_tpl.format(
|
||||
**{
|
||||
"extra_networks_tabname": self.extra_networks_tabname,
|
||||
"name": html.escape(item["name"]),
|
||||
}
|
||||
)
|
||||
btn_edit_item = self.btn_edit_item_tpl.format(
|
||||
**{
|
||||
"tabname": tabname,
|
||||
"extra_networks_tabname": self.extra_networks_tabname,
|
||||
"name": html.escape(item["name"]),
|
||||
}
|
||||
)
|
||||
|
||||
local_path = ""
|
||||
filename = item.get("filename", "")
|
||||
@@ -253,36 +272,292 @@ class ExtraNetworksPage:
|
||||
if search_only and shared.opts.extra_networks_hidden_models == "Never":
|
||||
return ""
|
||||
|
||||
sort_keys = " ".join([f'data-sort-{k}="{html.escape(str(v))}"' for k, v in item.get("sort_keys", {}).items()]).strip()
|
||||
sort_keys = " ".join(
|
||||
[
|
||||
f'data-sort-{k}="{html.escape(str(v))}"'
|
||||
for k, v in item.get("sort_keys", {}).items()
|
||||
]
|
||||
).strip()
|
||||
|
||||
search_terms_html = ""
|
||||
search_term_template = "<span class='hidden {class}'>{search_term}</span>"
|
||||
for search_term in item.get("search_terms", []):
|
||||
search_terms_html += search_term_template.format(
|
||||
**{
|
||||
"class": f"search_terms{' search_only' if search_only else ''}",
|
||||
"search_term": search_term,
|
||||
}
|
||||
)
|
||||
|
||||
# Some items here might not be used depending on HTML template used.
|
||||
args = {
|
||||
"background_image": background_image,
|
||||
"style": f"'display: none; {height}{width}; font-size: {shared.opts.extra_networks_card_text_scale*100}%'",
|
||||
"prompt": item.get("prompt", None),
|
||||
"tabname": quote_js(tabname),
|
||||
"local_preview": quote_js(item["local_preview"]),
|
||||
"name": html.escape(item["name"]),
|
||||
"description": (item.get("description") or "" if shared.opts.extra_networks_card_show_desc else ""),
|
||||
"card_clicked": onclick,
|
||||
"save_card_preview": '"' + html.escape(f"""return saveCardPreview(event, {quote_js(tabname)}, {quote_js(item["local_preview"])})""") + '"',
|
||||
"search_term": item.get("search_term", ""),
|
||||
"metadata_button": metadata_button,
|
||||
"edit_button": edit_button,
|
||||
"copy_path_button": btn_copy_path,
|
||||
"description": (item.get("description", "") or "" if shared.opts.extra_networks_card_show_desc else ""),
|
||||
"edit_button": btn_edit_item,
|
||||
"local_preview": quote_js(item["local_preview"]),
|
||||
"metadata_button": btn_metadata,
|
||||
"name": html.escape(item["name"]),
|
||||
"prompt": item.get("prompt", None),
|
||||
"save_card_preview": html.escape(f"return saveCardPreview(event, '{tabname}', '{item['local_preview']}');"),
|
||||
"search_only": " search_only" if search_only else "",
|
||||
"search_terms": search_terms_html,
|
||||
"sort_keys": sort_keys,
|
||||
"style": card_style,
|
||||
"tabname": tabname,
|
||||
"extra_networks_tabname": self.extra_networks_tabname,
|
||||
}
|
||||
|
||||
return self.card_page.format(**args)
|
||||
if template:
|
||||
return template.format(**args)
|
||||
else:
|
||||
return args
|
||||
|
||||
def create_tree_dir_item_html(
|
||||
self,
|
||||
tabname: str,
|
||||
dir_path: str,
|
||||
content: Optional[str] = None,
|
||||
) -> Optional[str]:
|
||||
"""Generates HTML for a directory item in the tree.
|
||||
|
||||
The generated HTML is of the format:
|
||||
```html
|
||||
<li class="tree-list-item tree-list-item--has-subitem">
|
||||
<div class="tree-list-content tree-list-content-dir"></div>
|
||||
<ul class="tree-list tree-list--subgroup">
|
||||
{content}
|
||||
</ul>
|
||||
</li>
|
||||
```
|
||||
|
||||
Args:
|
||||
tabname: The name of the active tab.
|
||||
dir_path: Path to the directory for this item.
|
||||
content: Optional HTML string that will be wrapped by this <ul>.
|
||||
|
||||
Returns:
|
||||
HTML formatted string.
|
||||
"""
|
||||
if not content:
|
||||
return None
|
||||
|
||||
btn = self.btn_tree_tpl.format(
|
||||
**{
|
||||
"search_terms": "",
|
||||
"subclass": "tree-list-content-dir",
|
||||
"tabname": tabname,
|
||||
"extra_networks_tabname": self.extra_networks_tabname,
|
||||
"onclick_extra": "",
|
||||
"data_path": dir_path,
|
||||
"data_hash": "",
|
||||
"action_list_item_action_leading": "<i class='tree-list-item-action-chevron'></i>",
|
||||
"action_list_item_visual_leading": "🗀",
|
||||
"action_list_item_label": os.path.basename(dir_path),
|
||||
"action_list_item_visual_trailing": "",
|
||||
"action_list_item_action_trailing": "",
|
||||
}
|
||||
)
|
||||
ul = f"<ul class='tree-list tree-list--subgroup' hidden>{content}</ul>"
|
||||
return (
|
||||
"<li class='tree-list-item tree-list-item--has-subitem' data-tree-entry-type='dir'>"
|
||||
f"{btn}{ul}"
|
||||
"</li>"
|
||||
)
|
||||
|
||||
def create_tree_file_item_html(self, tabname: str, file_path: str, item: dict) -> str:
|
||||
"""Generates HTML for a file item in the tree.
|
||||
|
||||
The generated HTML is of the format:
|
||||
```html
|
||||
<li class="tree-list-item tree-list-item--subitem">
|
||||
<span data-filterable-item-text hidden></span>
|
||||
<div class="tree-list-content tree-list-content-file"></div>
|
||||
</li>
|
||||
```
|
||||
|
||||
Args:
|
||||
tabname: The name of the active tab.
|
||||
file_path: The path to the file for this item.
|
||||
item: Dictionary containing the item information.
|
||||
|
||||
Returns:
|
||||
HTML formatted string.
|
||||
"""
|
||||
item_html_args = self.create_item_html(tabname, item)
|
||||
action_buttons = "".join(
|
||||
[
|
||||
item_html_args["copy_path_button"],
|
||||
item_html_args["metadata_button"],
|
||||
item_html_args["edit_button"],
|
||||
]
|
||||
)
|
||||
action_buttons = f"<div class=\"button-row\">{action_buttons}</div>"
|
||||
btn = self.btn_tree_tpl.format(
|
||||
**{
|
||||
"search_terms": "",
|
||||
"subclass": "tree-list-content-file",
|
||||
"tabname": tabname,
|
||||
"extra_networks_tabname": self.extra_networks_tabname,
|
||||
"onclick_extra": item_html_args["card_clicked"],
|
||||
"data_path": file_path,
|
||||
"data_hash": item["shorthash"],
|
||||
"action_list_item_action_leading": "<i class='tree-list-item-action-chevron'></i>",
|
||||
"action_list_item_visual_leading": "🗎",
|
||||
"action_list_item_label": item["name"],
|
||||
"action_list_item_visual_trailing": "",
|
||||
"action_list_item_action_trailing": action_buttons,
|
||||
}
|
||||
)
|
||||
return (
|
||||
"<li class='tree-list-item tree-list-item--subitem' data-tree-entry-type='file'>"
|
||||
f"{btn}"
|
||||
"</li>"
|
||||
)
|
||||
|
||||
def create_tree_view_html(self, tabname: str) -> str:
|
||||
"""Generates HTML for displaying folders in a tree view.
|
||||
|
||||
Args:
|
||||
tabname: The name of the active tab.
|
||||
|
||||
Returns:
|
||||
HTML string generated for this tree view.
|
||||
"""
|
||||
res = ""
|
||||
|
||||
# Setup the tree dictionary.
|
||||
roots = self.allowed_directories_for_previews()
|
||||
tree_items = {v["filename"]: ExtraNetworksItem(v) for v in self.items.values()}
|
||||
tree = get_tree([os.path.abspath(x) for x in roots], items=tree_items)
|
||||
|
||||
if not tree:
|
||||
return res
|
||||
|
||||
def _build_tree(data: Optional[dict[str, ExtraNetworksItem]] = None) -> Optional[str]:
|
||||
"""Recursively builds HTML for a tree.
|
||||
|
||||
Args:
|
||||
data: Dictionary representing a directory tree. Can be NoneType.
|
||||
Data keys should be absolute paths from the root and values
|
||||
should be subdirectory trees or an ExtraNetworksItem.
|
||||
|
||||
Returns:
|
||||
If data is not None: HTML string
|
||||
Else: None
|
||||
"""
|
||||
if not data:
|
||||
return None
|
||||
|
||||
# Lists for storing <li> items html for directories and files separately.
|
||||
_dir_li = []
|
||||
_file_li = []
|
||||
|
||||
for k, v in sorted(data.items(), key=lambda x: shared.natural_sort_key(x[0])):
|
||||
if isinstance(v, (ExtraNetworksItem,)):
|
||||
_file_li.append(self.create_tree_file_item_html(tabname, k, v.item))
|
||||
else:
|
||||
_dir_li.append(self.create_tree_dir_item_html(tabname, k, _build_tree(v)))
|
||||
|
||||
# Directories should always be displayed before files so we order them here.
|
||||
return "".join(_dir_li) + "".join(_file_li)
|
||||
|
||||
# Add each root directory to the tree.
|
||||
for k, v in sorted(tree.items(), key=lambda x: shared.natural_sort_key(x[0])):
|
||||
item_html = self.create_tree_dir_item_html(tabname, k, _build_tree(v))
|
||||
# Only add non-empty entries to the tree.
|
||||
if item_html is not None:
|
||||
res += item_html
|
||||
|
||||
return f"<ul class='tree-list tree-list--tree'>{res}</ul>"
|
||||
|
||||
def create_card_view_html(self, tabname: str) -> str:
|
||||
"""Generates HTML for the network Card View section for a tab.
|
||||
|
||||
This HTML goes into the `extra-networks-pane.html` <div> with
|
||||
`id='{tabname}_{extra_networks_tabname}_cards`.
|
||||
|
||||
Args:
|
||||
tabname: The name of the active tab.
|
||||
|
||||
Returns:
|
||||
HTML formatted string.
|
||||
"""
|
||||
res = ""
|
||||
for item in self.items.values():
|
||||
res += self.create_item_html(tabname, item, self.card_tpl)
|
||||
|
||||
if res == "":
|
||||
dirs = "".join([f"<li>{x}</li>" for x in self.allowed_directories_for_previews()])
|
||||
res = shared.html("extra-networks-no-cards.html").format(dirs=dirs)
|
||||
|
||||
return res
|
||||
|
||||
def create_html(self, tabname):
|
||||
"""Generates an HTML string for the current pane.
|
||||
|
||||
The generated HTML uses `extra-networks-pane.html` as a template.
|
||||
|
||||
Args:
|
||||
tabname: The name of the active tab.
|
||||
|
||||
Returns:
|
||||
HTML formatted string.
|
||||
"""
|
||||
self.lister.reset()
|
||||
self.metadata = {}
|
||||
self.items = {x["name"]: x for x in self.list_items()}
|
||||
# Populate the instance metadata for each item.
|
||||
for item in self.items.values():
|
||||
metadata = item.get("metadata")
|
||||
if metadata:
|
||||
self.metadata[item["name"]] = metadata
|
||||
|
||||
if "user_metadata" not in item:
|
||||
self.read_user_metadata(item)
|
||||
|
||||
data_sortdir = shared.opts.extra_networks_card_order
|
||||
data_sortmode = shared.opts.extra_networks_card_order_field.lower().replace("sort", "").replace(" ", "_").rstrip("_").strip()
|
||||
data_sortkey = f"{data_sortmode}-{data_sortdir}-{len(self.items)}"
|
||||
tree_view_btn_extra_class = ""
|
||||
tree_view_div_extra_class = "hidden"
|
||||
if shared.opts.extra_networks_tree_view_default_enabled:
|
||||
tree_view_btn_extra_class = "extra-network-control--enabled"
|
||||
tree_view_div_extra_class = ""
|
||||
|
||||
return self.pane_tpl.format(
|
||||
**{
|
||||
"tabname": tabname,
|
||||
"extra_networks_tabname": self.extra_networks_tabname,
|
||||
"data_sortmode": data_sortmode,
|
||||
"data_sortkey": data_sortkey,
|
||||
"data_sortdir": data_sortdir,
|
||||
"tree_view_btn_extra_class": tree_view_btn_extra_class,
|
||||
"tree_view_div_extra_class": tree_view_div_extra_class,
|
||||
"tree_html": self.create_tree_view_html(tabname),
|
||||
"items_html": self.create_card_view_html(tabname),
|
||||
}
|
||||
)
|
||||
|
||||
def create_item(self, name, index=None):
|
||||
raise NotImplementedError()
|
||||
|
||||
def list_items(self):
|
||||
raise NotImplementedError()
|
||||
|
||||
def allowed_directories_for_previews(self):
|
||||
return []
|
||||
|
||||
def get_sort_keys(self, path):
|
||||
"""
|
||||
List of default keys used for sorting in the UI.
|
||||
"""
|
||||
pth = Path(path)
|
||||
stat = pth.stat()
|
||||
mtime, ctime = self.lister.mctime(path)
|
||||
return {
|
||||
"date_created": int(stat.st_ctime or 0),
|
||||
"date_modified": int(stat.st_mtime or 0),
|
||||
"date_created": int(mtime),
|
||||
"date_modified": int(ctime),
|
||||
"name": pth.name.lower(),
|
||||
"path": str(pth.parent).lower(),
|
||||
}
|
||||
@@ -292,10 +567,10 @@ class ExtraNetworksPage:
|
||||
Find a preview PNG for a given path (without extension) and call link_preview on it.
|
||||
"""
|
||||
|
||||
potential_files = sum([[path + "." + ext, path + ".preview." + ext] for ext in allowed_preview_extensions()], [])
|
||||
potential_files = sum([[f"{path}.{ext}", f"{path}.preview.{ext}"] for ext in allowed_preview_extensions()], [])
|
||||
|
||||
for file in potential_files:
|
||||
if os.path.isfile(file):
|
||||
if self.lister.exists(file):
|
||||
return self.link_preview(file)
|
||||
|
||||
return None
|
||||
@@ -305,6 +580,9 @@ class ExtraNetworksPage:
|
||||
Find and read a description file for a given path (without extension).
|
||||
"""
|
||||
for file in [f"{path}.txt", f"{path}.description.txt"]:
|
||||
if not self.lister.exists(file):
|
||||
continue
|
||||
|
||||
try:
|
||||
with open(file, "r", encoding="utf-8", errors="replace") as f:
|
||||
return f.read()
|
||||
@@ -360,10 +638,7 @@ def pages_in_preferred_order(pages):
|
||||
|
||||
return sorted(pages, key=lambda x: tab_scores[x.name])
|
||||
|
||||
|
||||
def create_ui(interface: gr.Blocks, unrelated_tabs, tabname):
|
||||
from modules.ui import switch_values_symbol
|
||||
|
||||
ui = ExtraNetworksUi()
|
||||
ui.pages = []
|
||||
ui.pages_contents = []
|
||||
@@ -373,62 +648,53 @@ def create_ui(interface: gr.Blocks, unrelated_tabs, tabname):
|
||||
|
||||
related_tabs = []
|
||||
|
||||
button_refresh = gr.Button("Refresh", elem_id=f"{tabname}_extra_refresh_internal", visible=False)
|
||||
|
||||
for page in ui.stored_extra_pages:
|
||||
with gr.Tab(page.title, elem_id=f"{tabname}_{page.id_page}", elem_classes=["extra-page"]) as tab:
|
||||
with gr.Column(elem_id=f"{tabname}_{page.id_page}_prompts", elem_classes=["extra-page-prompts"]):
|
||||
with gr.Tab(page.title, elem_id=f"{tabname}_{page.extra_networks_tabname}", elem_classes=["extra-page"]) as tab:
|
||||
with gr.Column(elem_id=f"{tabname}_{page.extra_networks_tabname}_prompts", elem_classes=["extra-page-prompts"]):
|
||||
pass
|
||||
|
||||
elem_id = f"{tabname}_{page.id_page}_cards_html"
|
||||
elem_id = f"{tabname}_{page.extra_networks_tabname}_cards_html"
|
||||
page_elem = gr.HTML('Loading...', elem_id=elem_id)
|
||||
ui.pages.append(page_elem)
|
||||
|
||||
page_elem.change(fn=lambda: None, _js='function(){applyExtraNetworkFilter(' + quote_js(tabname) + '); return []}', inputs=[], outputs=[])
|
||||
|
||||
editor = page.create_user_metadata_editor(ui, tabname)
|
||||
editor.create_ui()
|
||||
ui.user_metadata_editors.append(editor)
|
||||
|
||||
related_tabs.append(tab)
|
||||
|
||||
edit_search = gr.Textbox('', show_label=False, elem_id=tabname+"_extra_search", elem_classes="search", placeholder="Search...", visible=False, interactive=True)
|
||||
dropdown_sort = gr.Dropdown(choices=['Path', 'Name', 'Date Created', 'Date Modified', ], value=shared.opts.extra_networks_card_order_field, elem_id=tabname+"_extra_sort", elem_classes="sort", multiselect=False, visible=False, show_label=False, interactive=True, label=tabname+"_extra_sort_order")
|
||||
button_sortorder = ToolButton(switch_values_symbol, elem_id=tabname+"_extra_sortorder", elem_classes=["sortorder"] + ([] if shared.opts.extra_networks_card_order == "Ascending" else ["sortReverse"]), visible=False, tooltip="Invert sort order")
|
||||
button_refresh = gr.Button('Refresh', elem_id=tabname+"_extra_refresh", visible=False)
|
||||
checkbox_show_dirs = gr.Checkbox(True, label='Show dirs', elem_id=tabname+"_extra_show_dirs", elem_classes="show-dirs", visible=False)
|
||||
|
||||
ui.button_save_preview = gr.Button('Save preview', elem_id=tabname+"_save_preview", visible=False)
|
||||
ui.preview_target_filename = gr.Textbox('Preview save filename', elem_id=tabname+"_preview_filename", visible=False)
|
||||
|
||||
tab_controls = [edit_search, dropdown_sort, button_sortorder, button_refresh, checkbox_show_dirs]
|
||||
ui.button_save_preview = gr.Button('Save preview', elem_id=f"{tabname}_save_preview", visible=False)
|
||||
ui.preview_target_filename = gr.Textbox('Preview save filename', elem_id=f"{tabname}_preview_filename", visible=False)
|
||||
|
||||
for tab in unrelated_tabs:
|
||||
tab.select(fn=lambda: [gr.update(visible=False) for _ in tab_controls], _js='function(){ extraNetworksUrelatedTabSelected("' + tabname + '"); }', inputs=[], outputs=tab_controls, show_progress=False)
|
||||
tab.select(fn=None, _js=f"function(){{extraNetworksUnrelatedTabSelected('{tabname}');}}", inputs=[], outputs=[], show_progress=False)
|
||||
|
||||
for page, tab in zip(ui.stored_extra_pages, related_tabs):
|
||||
allow_prompt = "true" if page.allow_prompt else "false"
|
||||
allow_negative_prompt = "true" if page.allow_negative_prompt else "false"
|
||||
jscode = (
|
||||
"function(){{"
|
||||
f"extraNetworksTabSelected('{tabname}', '{tabname}_{page.extra_networks_tabname}_prompts', {str(page.allow_prompt).lower()}, {str(page.allow_negative_prompt).lower()}, '{tabname}_{page.extra_networks_tabname}');"
|
||||
f"applyExtraNetworkFilter('{tabname}_{page.extra_networks_tabname}');"
|
||||
"}}"
|
||||
)
|
||||
tab.select(fn=None, _js=jscode, inputs=[], outputs=[], show_progress=False)
|
||||
|
||||
jscode = 'extraNetworksTabSelected("' + tabname + '", "' + f"{tabname}_{page.id_page}_prompts" + '", ' + allow_prompt + ', ' + allow_negative_prompt + ');'
|
||||
|
||||
tab.select(fn=lambda: [gr.update(visible=True) for _ in tab_controls], _js='function(){ ' + jscode + ' }', inputs=[], outputs=tab_controls, show_progress=False)
|
||||
|
||||
dropdown_sort.change(fn=lambda: None, _js="function(){ applyExtraNetworkSort('" + tabname + "'); }")
|
||||
def create_html():
|
||||
ui.pages_contents = [pg.create_html(ui.tabname) for pg in ui.stored_extra_pages]
|
||||
|
||||
def pages_html():
|
||||
if not ui.pages_contents:
|
||||
return refresh()
|
||||
|
||||
create_html()
|
||||
return ui.pages_contents
|
||||
|
||||
def refresh():
|
||||
for pg in ui.stored_extra_pages:
|
||||
pg.refresh()
|
||||
|
||||
ui.pages_contents = [pg.create_html(ui.tabname) for pg in ui.stored_extra_pages]
|
||||
|
||||
create_html()
|
||||
return ui.pages_contents
|
||||
|
||||
interface.load(fn=pages_html, inputs=[], outputs=[*ui.pages])
|
||||
interface.load(fn=pages_html, inputs=[], outputs=ui.pages)
|
||||
# NOTE: Event is manually fired in extraNetworks.js:extraNetworksTreeRefreshOnClick()
|
||||
# button is unused and hidden at all times. Only used in order to fire this event.
|
||||
button_refresh.click(fn=refresh, inputs=[], outputs=ui.pages)
|
||||
|
||||
return ui
|
||||
@@ -478,5 +744,3 @@ def setup_ui(ui, gallery):
|
||||
|
||||
for editor in ui.user_metadata_editors:
|
||||
editor.setup_ui(gallery)
|
||||
|
||||
|
||||
|
||||
@@ -2,7 +2,6 @@ import html
|
||||
import os
|
||||
|
||||
from modules import shared, ui_extra_networks, sd_models
|
||||
from modules.ui_extra_networks import quote_js
|
||||
from modules.ui_extra_networks_checkpoints_user_metadata import CheckpointUserMetadataEditor
|
||||
|
||||
|
||||
@@ -21,14 +20,17 @@ class ExtraNetworksPageCheckpoints(ui_extra_networks.ExtraNetworksPage):
|
||||
return
|
||||
|
||||
path, ext = os.path.splitext(checkpoint.filename)
|
||||
search_terms = [self.search_terms_from_path(checkpoint.filename)]
|
||||
if checkpoint.sha256:
|
||||
search_terms.append(checkpoint.sha256)
|
||||
return {
|
||||
"name": checkpoint.name_for_extra,
|
||||
"filename": checkpoint.filename,
|
||||
"shorthash": checkpoint.shorthash,
|
||||
"preview": self.find_preview(path),
|
||||
"description": self.find_description(path),
|
||||
"search_term": self.search_terms_from_path(checkpoint.filename) + " " + (checkpoint.sha256 or ""),
|
||||
"onclick": '"' + html.escape(f"""return selectCheckpoint({quote_js(name)})""") + '"',
|
||||
"search_terms": search_terms,
|
||||
"onclick": html.escape(f"return selectCheckpoint('{name}');"),
|
||||
"local_preview": f"{path}.{shared.opts.samples_format}",
|
||||
"metadata": checkpoint.metadata,
|
||||
"sort_keys": {'default': index, **self.get_sort_keys(checkpoint.filename)},
|
||||
|
||||
@@ -20,14 +20,16 @@ class ExtraNetworksPageHypernetworks(ui_extra_networks.ExtraNetworksPage):
|
||||
path, ext = os.path.splitext(full_path)
|
||||
sha256 = sha256_from_cache(full_path, f'hypernet/{name}')
|
||||
shorthash = sha256[0:10] if sha256 else None
|
||||
|
||||
search_terms = [self.search_terms_from_path(path)]
|
||||
if sha256:
|
||||
search_terms.append(sha256)
|
||||
return {
|
||||
"name": name,
|
||||
"filename": full_path,
|
||||
"shorthash": shorthash,
|
||||
"preview": self.find_preview(path),
|
||||
"description": self.find_description(path),
|
||||
"search_term": self.search_terms_from_path(path) + " " + (sha256 or ""),
|
||||
"search_terms": search_terms,
|
||||
"prompt": quote_js(f"<hypernet:{name}:") + " + opts.extra_networks_default_multiplier + " + quote_js(">"),
|
||||
"local_preview": f"{path}.preview.{shared.opts.samples_format}",
|
||||
"sort_keys": {'default': index, **self.get_sort_keys(path + ext)},
|
||||
|
||||
@@ -18,13 +18,16 @@ class ExtraNetworksPageTextualInversion(ui_extra_networks.ExtraNetworksPage):
|
||||
return
|
||||
|
||||
path, ext = os.path.splitext(embedding.filename)
|
||||
search_terms = [self.search_terms_from_path(embedding.filename)]
|
||||
if embedding.hash:
|
||||
search_terms.append(embedding.hash)
|
||||
return {
|
||||
"name": name,
|
||||
"filename": embedding.filename,
|
||||
"shorthash": embedding.shorthash,
|
||||
"preview": self.find_preview(path),
|
||||
"description": self.find_description(path),
|
||||
"search_term": self.search_terms_from_path(embedding.filename) + " " + (embedding.hash or ""),
|
||||
"search_terms": search_terms,
|
||||
"prompt": quote_js(embedding.name),
|
||||
"local_preview": f"{path}.preview.{shared.opts.samples_format}",
|
||||
"sort_keys": {'default': index, **self.get_sort_keys(embedding.filename)},
|
||||
|
||||
@@ -5,7 +5,7 @@ import os.path
|
||||
|
||||
import gradio as gr
|
||||
|
||||
from modules import generation_parameters_copypaste, images, sysinfo, errors, ui_extra_networks
|
||||
from modules import infotext_utils, images, sysinfo, errors, ui_extra_networks
|
||||
|
||||
|
||||
class UserMetadataEditor:
|
||||
@@ -14,7 +14,7 @@ class UserMetadataEditor:
|
||||
self.ui = ui
|
||||
self.tabname = tabname
|
||||
self.page = page
|
||||
self.id_part = f"{self.tabname}_{self.page.id_page}_edit_user_metadata"
|
||||
self.id_part = f"{self.tabname}_{self.page.extra_networks_tabname}_edit_user_metadata"
|
||||
|
||||
self.box = None
|
||||
|
||||
@@ -181,7 +181,7 @@ class UserMetadataEditor:
|
||||
index = len(gallery) - 1 if index >= len(gallery) else index
|
||||
|
||||
img_info = gallery[index if index >= 0 else 0]
|
||||
image = generation_parameters_copypaste.image_from_url_text(img_info)
|
||||
image = infotext_utils.image_from_url_text(img_info)
|
||||
geninfo, items = images.read_info_from_image(image)
|
||||
|
||||
images.save_image_with_geninfo(image, geninfo, item["local_preview"])
|
||||
|
||||
@@ -1,17 +1,12 @@
|
||||
import os
|
||||
import gradio as gr
|
||||
|
||||
from modules import localization, shared, scripts
|
||||
from modules.paths import script_path, data_path, cwd
|
||||
from modules import localization, shared, scripts, util
|
||||
from modules.paths import script_path, data_path
|
||||
|
||||
|
||||
def webpath(fn):
|
||||
if fn.startswith(cwd):
|
||||
web_path = os.path.relpath(fn, cwd)
|
||||
else:
|
||||
web_path = os.path.abspath(fn)
|
||||
|
||||
return f'file={web_path}?{os.path.getmtime(fn)}'
|
||||
return f'file={util.truncate_path(fn)}?{os.path.getmtime(fn)}'
|
||||
|
||||
|
||||
def javascript_html():
|
||||
@@ -40,13 +35,11 @@ def css_html():
|
||||
return f'<link rel="stylesheet" property="stylesheet" href="{webpath(fn)}">'
|
||||
|
||||
for cssfile in scripts.list_files_with_name("style.css"):
|
||||
if not os.path.isfile(cssfile):
|
||||
continue
|
||||
|
||||
head += stylesheet(cssfile)
|
||||
|
||||
if os.path.exists(os.path.join(data_path, "user.css")):
|
||||
head += stylesheet(os.path.join(data_path, "user.css"))
|
||||
user_css = os.path.join(data_path, "user.css")
|
||||
if os.path.exists(user_css):
|
||||
head += stylesheet(user_css)
|
||||
|
||||
return head
|
||||
|
||||
|
||||
@@ -26,8 +26,9 @@ class UiLoadsave:
|
||||
self.ui_defaults_review = None
|
||||
|
||||
try:
|
||||
if os.path.exists(self.filename):
|
||||
self.ui_settings = self.read_from_file()
|
||||
self.ui_settings = self.read_from_file()
|
||||
except FileNotFoundError:
|
||||
pass
|
||||
except Exception as e:
|
||||
self.error_loading = True
|
||||
errors.display(e, "loading settings")
|
||||
@@ -144,7 +145,7 @@ class UiLoadsave:
|
||||
json.dump(current_ui_settings, file, indent=4, ensure_ascii=False)
|
||||
|
||||
def dump_defaults(self):
|
||||
"""saves default values to a file unless tjhe file is present and there was an error loading default values at start"""
|
||||
"""saves default values to a file unless the file is present and there was an error loading default values at start"""
|
||||
|
||||
if self.error_loading and os.path.exists(self.filename):
|
||||
return
|
||||
|
||||
@@ -1,13 +1,14 @@
|
||||
import gradio as gr
|
||||
from modules import scripts, shared, ui_common, postprocessing, call_queue, ui_toprow
|
||||
import modules.generation_parameters_copypaste as parameters_copypaste
|
||||
import modules.infotext_utils as parameters_copypaste
|
||||
from modules.ui_components import ResizeHandleRow
|
||||
|
||||
|
||||
def create_ui():
|
||||
dummy_component = gr.Label(visible=False)
|
||||
tab_index = gr.State(value=0)
|
||||
tab_index = gr.Number(value=0, visible=False)
|
||||
|
||||
with gr.Row(equal_height=False, variant='compact'):
|
||||
with ResizeHandleRow(equal_height=False, variant='compact'):
|
||||
with gr.Column(variant='compact'):
|
||||
with gr.Tabs(elem_id="mode_extras"):
|
||||
with gr.TabItem('Single Image', id="single_image", elem_id="extras_single_tab") as tab_single:
|
||||
@@ -28,7 +29,7 @@ def create_ui():
|
||||
toprow.create_inline_toprow_image()
|
||||
submit = toprow.submit
|
||||
|
||||
result_images, html_info_x, html_info, html_log = ui_common.create_output_panel("extras", shared.opts.outdir_extras_samples)
|
||||
output_panel = ui_common.create_output_panel("extras", shared.opts.outdir_extras_samples)
|
||||
|
||||
tab_single.select(fn=lambda: 0, inputs=[], outputs=[tab_index])
|
||||
tab_batch.select(fn=lambda: 1, inputs=[], outputs=[tab_index])
|
||||
@@ -48,9 +49,9 @@ def create_ui():
|
||||
*script_inputs
|
||||
],
|
||||
outputs=[
|
||||
result_images,
|
||||
html_info_x,
|
||||
html_log,
|
||||
output_panel.gallery,
|
||||
output_panel.generation_info,
|
||||
output_panel.html_log,
|
||||
],
|
||||
show_progress=False,
|
||||
)
|
||||
|
||||
@@ -22,9 +22,12 @@ def save_style(name, prompt, negative_prompt):
|
||||
if not name:
|
||||
return gr.update(visible=False)
|
||||
|
||||
style = styles.PromptStyle(name, prompt, negative_prompt)
|
||||
existing_style = shared.prompt_styles.styles.get(name)
|
||||
path = existing_style.path if existing_style is not None else None
|
||||
|
||||
style = styles.PromptStyle(name, prompt, negative_prompt, path)
|
||||
shared.prompt_styles.styles[style.name] = style
|
||||
shared.prompt_styles.save_styles(shared.styles_filename)
|
||||
shared.prompt_styles.save_styles()
|
||||
|
||||
return gr.update(visible=True)
|
||||
|
||||
@@ -34,7 +37,7 @@ def delete_style(name):
|
||||
return
|
||||
|
||||
shared.prompt_styles.styles.pop(name, None)
|
||||
shared.prompt_styles.save_styles(shared.styles_filename)
|
||||
shared.prompt_styles.save_styles()
|
||||
|
||||
return '', '', ''
|
||||
|
||||
|
||||
@@ -79,11 +79,11 @@ class Toprow:
|
||||
def create_prompts(self):
|
||||
with gr.Column(elem_id=f"{self.id_part}_prompt_container", elem_classes=["prompt-container-compact"] if self.is_compact else [], scale=6):
|
||||
with gr.Row(elem_id=f"{self.id_part}_prompt_row", elem_classes=["prompt-row"]):
|
||||
self.prompt = gr.Textbox(label="Prompt", elem_id=f"{self.id_part}_prompt", show_label=False, lines=3, placeholder="Prompt (press Ctrl+Enter or Alt+Enter to generate)", elem_classes=["prompt"])
|
||||
self.prompt = gr.Textbox(label="Prompt", elem_id=f"{self.id_part}_prompt", show_label=False, lines=3, placeholder="Prompt\n(Press Ctrl+Enter to generate, Alt+Enter to skip, Esc to interrupt)", elem_classes=["prompt"])
|
||||
self.prompt_img = gr.File(label="", elem_id=f"{self.id_part}_prompt_image", file_count="single", type="binary", visible=False)
|
||||
|
||||
with gr.Row(elem_id=f"{self.id_part}_neg_prompt_row", elem_classes=["prompt-row"]):
|
||||
self.negative_prompt = gr.Textbox(label="Negative prompt", elem_id=f"{self.id_part}_neg_prompt", show_label=False, lines=3, placeholder="Negative prompt (press Ctrl+Enter or Alt+Enter to generate)", elem_classes=["prompt"])
|
||||
self.negative_prompt = gr.Textbox(label="Negative prompt", elem_id=f"{self.id_part}_neg_prompt", show_label=False, lines=3, placeholder="Negative prompt\n(Press Ctrl+Enter to generate, Alt+Enter to skip, Esc to interrupt)", elem_classes=["prompt"])
|
||||
|
||||
self.prompt_img.change(
|
||||
fn=modules.images.image_data,
|
||||
@@ -106,8 +106,15 @@ class Toprow:
|
||||
outputs=[],
|
||||
)
|
||||
|
||||
def interrupt_function():
|
||||
if not shared.state.stopping_generation and shared.state.job_count > 1 and shared.opts.interrupt_after_current:
|
||||
shared.state.stop_generating()
|
||||
gr.Info("Generation will stop after finishing this image, click again to stop immediately.")
|
||||
else:
|
||||
shared.state.interrupt()
|
||||
|
||||
self.interrupt.click(
|
||||
fn=lambda: shared.state.interrupt(),
|
||||
fn=interrupt_function,
|
||||
inputs=[],
|
||||
outputs=[],
|
||||
)
|
||||
|
||||
@@ -98,6 +98,9 @@ class UpscalerData:
|
||||
self.scale = scale
|
||||
self.model = model
|
||||
|
||||
def __repr__(self):
|
||||
return f"<UpscalerData name={self.name} path={self.data_path} scale={self.scale}>"
|
||||
|
||||
|
||||
class UpscalerNone(Upscaler):
|
||||
name = "None"
|
||||
|
||||
189
modules/upscaler_utils.py
Normal file
189
modules/upscaler_utils.py
Normal file
@@ -0,0 +1,189 @@
|
||||
import logging
|
||||
from typing import Callable
|
||||
|
||||
import numpy as np
|
||||
import torch
|
||||
import tqdm
|
||||
from PIL import Image
|
||||
|
||||
from modules import images, shared, torch_utils
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
def pil_image_to_torch_bgr(img: Image.Image) -> torch.Tensor:
|
||||
img = np.array(img.convert("RGB"))
|
||||
img = img[:, :, ::-1] # flip RGB to BGR
|
||||
img = np.transpose(img, (2, 0, 1)) # HWC to CHW
|
||||
img = np.ascontiguousarray(img) / 255 # Rescale to [0, 1]
|
||||
return torch.from_numpy(img)
|
||||
|
||||
|
||||
def torch_bgr_to_pil_image(tensor: torch.Tensor) -> Image.Image:
|
||||
if tensor.ndim == 4:
|
||||
# If we're given a tensor with a batch dimension, squeeze it out
|
||||
# (but only if it's a batch of size 1).
|
||||
if tensor.shape[0] != 1:
|
||||
raise ValueError(f"{tensor.shape} does not describe a BCHW tensor")
|
||||
tensor = tensor.squeeze(0)
|
||||
assert tensor.ndim == 3, f"{tensor.shape} does not describe a CHW tensor"
|
||||
# TODO: is `tensor.float().cpu()...numpy()` the most efficient idiom?
|
||||
arr = tensor.float().cpu().clamp_(0, 1).numpy() # clamp
|
||||
arr = 255.0 * np.moveaxis(arr, 0, 2) # CHW to HWC, rescale
|
||||
arr = arr.round().astype(np.uint8)
|
||||
arr = arr[:, :, ::-1] # flip BGR to RGB
|
||||
return Image.fromarray(arr, "RGB")
|
||||
|
||||
|
||||
def upscale_pil_patch(model, img: Image.Image) -> Image.Image:
|
||||
"""
|
||||
Upscale a given PIL image using the given model.
|
||||
"""
|
||||
param = torch_utils.get_param(model)
|
||||
|
||||
with torch.no_grad():
|
||||
tensor = pil_image_to_torch_bgr(img).unsqueeze(0) # add batch dimension
|
||||
tensor = tensor.to(device=param.device, dtype=param.dtype)
|
||||
return torch_bgr_to_pil_image(model(tensor))
|
||||
|
||||
|
||||
def upscale_with_model(
|
||||
model: Callable[[torch.Tensor], torch.Tensor],
|
||||
img: Image.Image,
|
||||
*,
|
||||
tile_size: int,
|
||||
tile_overlap: int = 0,
|
||||
desc="tiled upscale",
|
||||
) -> Image.Image:
|
||||
if tile_size <= 0:
|
||||
logger.debug("Upscaling %s without tiling", img)
|
||||
output = upscale_pil_patch(model, img)
|
||||
logger.debug("=> %s", output)
|
||||
return output
|
||||
|
||||
grid = images.split_grid(img, tile_size, tile_size, tile_overlap)
|
||||
newtiles = []
|
||||
|
||||
with tqdm.tqdm(total=grid.tile_count, desc=desc, disable=not shared.opts.enable_upscale_progressbar) as p:
|
||||
for y, h, row in grid.tiles:
|
||||
newrow = []
|
||||
for x, w, tile in row:
|
||||
logger.debug("Tile (%d, %d) %s...", x, y, tile)
|
||||
output = upscale_pil_patch(model, tile)
|
||||
scale_factor = output.width // tile.width
|
||||
logger.debug("=> %s (scale factor %s)", output, scale_factor)
|
||||
newrow.append([x * scale_factor, w * scale_factor, output])
|
||||
p.update(1)
|
||||
newtiles.append([y * scale_factor, h * scale_factor, newrow])
|
||||
|
||||
newgrid = images.Grid(
|
||||
newtiles,
|
||||
tile_w=grid.tile_w * scale_factor,
|
||||
tile_h=grid.tile_h * scale_factor,
|
||||
image_w=grid.image_w * scale_factor,
|
||||
image_h=grid.image_h * scale_factor,
|
||||
overlap=grid.overlap * scale_factor,
|
||||
)
|
||||
return images.combine_grid(newgrid)
|
||||
|
||||
|
||||
def tiled_upscale_2(
|
||||
img: torch.Tensor,
|
||||
model,
|
||||
*,
|
||||
tile_size: int,
|
||||
tile_overlap: int,
|
||||
scale: int,
|
||||
device: torch.device,
|
||||
desc="Tiled upscale",
|
||||
):
|
||||
# Alternative implementation of `upscale_with_model` originally used by
|
||||
# SwinIR and ScuNET. It differs from `upscale_with_model` in that tiling and
|
||||
# weighting is done in PyTorch space, as opposed to `images.Grid` doing it in
|
||||
# Pillow space without weighting.
|
||||
|
||||
b, c, h, w = img.size()
|
||||
tile_size = min(tile_size, h, w)
|
||||
|
||||
if tile_size <= 0:
|
||||
logger.debug("Upscaling %s without tiling", img.shape)
|
||||
return model(img)
|
||||
|
||||
stride = tile_size - tile_overlap
|
||||
h_idx_list = list(range(0, h - tile_size, stride)) + [h - tile_size]
|
||||
w_idx_list = list(range(0, w - tile_size, stride)) + [w - tile_size]
|
||||
result = torch.zeros(
|
||||
b,
|
||||
c,
|
||||
h * scale,
|
||||
w * scale,
|
||||
device=device,
|
||||
dtype=img.dtype,
|
||||
)
|
||||
weights = torch.zeros_like(result)
|
||||
logger.debug("Upscaling %s to %s with tiles", img.shape, result.shape)
|
||||
with tqdm.tqdm(total=len(h_idx_list) * len(w_idx_list), desc=desc, disable=not shared.opts.enable_upscale_progressbar) as pbar:
|
||||
for h_idx in h_idx_list:
|
||||
if shared.state.interrupted or shared.state.skipped:
|
||||
break
|
||||
|
||||
for w_idx in w_idx_list:
|
||||
if shared.state.interrupted or shared.state.skipped:
|
||||
break
|
||||
|
||||
# Only move this patch to the device if it's not already there.
|
||||
in_patch = img[
|
||||
...,
|
||||
h_idx : h_idx + tile_size,
|
||||
w_idx : w_idx + tile_size,
|
||||
].to(device=device)
|
||||
|
||||
out_patch = model(in_patch)
|
||||
|
||||
result[
|
||||
...,
|
||||
h_idx * scale : (h_idx + tile_size) * scale,
|
||||
w_idx * scale : (w_idx + tile_size) * scale,
|
||||
].add_(out_patch)
|
||||
|
||||
out_patch_mask = torch.ones_like(out_patch)
|
||||
|
||||
weights[
|
||||
...,
|
||||
h_idx * scale : (h_idx + tile_size) * scale,
|
||||
w_idx * scale : (w_idx + tile_size) * scale,
|
||||
].add_(out_patch_mask)
|
||||
|
||||
pbar.update(1)
|
||||
|
||||
output = result.div_(weights)
|
||||
|
||||
return output
|
||||
|
||||
|
||||
def upscale_2(
|
||||
img: Image.Image,
|
||||
model,
|
||||
*,
|
||||
tile_size: int,
|
||||
tile_overlap: int,
|
||||
scale: int,
|
||||
desc: str,
|
||||
):
|
||||
"""
|
||||
Convenience wrapper around `tiled_upscale_2` that handles PIL images.
|
||||
"""
|
||||
param = torch_utils.get_param(model)
|
||||
tensor = pil_image_to_torch_bgr(img).to(dtype=param.dtype).unsqueeze(0) # add batch dimension
|
||||
|
||||
with torch.no_grad():
|
||||
output = tiled_upscale_2(
|
||||
tensor,
|
||||
model,
|
||||
tile_size=tile_size,
|
||||
tile_overlap=tile_overlap,
|
||||
scale=scale,
|
||||
desc=desc,
|
||||
device=param.device,
|
||||
)
|
||||
return torch_bgr_to_pil_image(output)
|
||||
@@ -2,7 +2,7 @@ import os
|
||||
import re
|
||||
|
||||
from modules import shared
|
||||
from modules.paths_internal import script_path
|
||||
from modules.paths_internal import script_path, cwd
|
||||
|
||||
|
||||
def natural_sort_key(s, regex=re.compile('([0-9]+)')):
|
||||
@@ -21,11 +21,11 @@ def html_path(filename):
|
||||
def html(filename):
|
||||
path = html_path(filename)
|
||||
|
||||
if os.path.exists(path):
|
||||
try:
|
||||
with open(path, encoding="utf8") as file:
|
||||
return file.read()
|
||||
|
||||
return ""
|
||||
except OSError:
|
||||
return ""
|
||||
|
||||
|
||||
def walk_files(path, allowed_extensions=None):
|
||||
@@ -56,3 +56,83 @@ def ldm_print(*args, **kwargs):
|
||||
return
|
||||
|
||||
print(*args, **kwargs)
|
||||
|
||||
|
||||
def truncate_path(target_path, base_path=cwd):
|
||||
abs_target, abs_base = os.path.abspath(target_path), os.path.abspath(base_path)
|
||||
try:
|
||||
if os.path.commonpath([abs_target, abs_base]) == abs_base:
|
||||
return os.path.relpath(abs_target, abs_base)
|
||||
except ValueError:
|
||||
pass
|
||||
return abs_target
|
||||
|
||||
|
||||
class MassFileListerCachedDir:
|
||||
"""A class that caches file metadata for a specific directory."""
|
||||
|
||||
def __init__(self, dirname):
|
||||
self.files = None
|
||||
self.files_cased = None
|
||||
self.dirname = dirname
|
||||
|
||||
stats = ((x.name, x.stat(follow_symlinks=False)) for x in os.scandir(self.dirname))
|
||||
files = [(n, s.st_mtime, s.st_ctime) for n, s in stats]
|
||||
self.files = {x[0].lower(): x for x in files}
|
||||
self.files_cased = {x[0]: x for x in files}
|
||||
|
||||
|
||||
class MassFileLister:
|
||||
"""A class that provides a way to check for the existence and mtime/ctile of files without doing more than one stat call per file."""
|
||||
|
||||
def __init__(self):
|
||||
self.cached_dirs = {}
|
||||
|
||||
def find(self, path):
|
||||
"""
|
||||
Find the metadata for a file at the given path.
|
||||
|
||||
Returns:
|
||||
tuple or None: A tuple of (name, mtime, ctime) if the file exists, or None if it does not.
|
||||
"""
|
||||
|
||||
dirname, filename = os.path.split(path)
|
||||
|
||||
cached_dir = self.cached_dirs.get(dirname)
|
||||
if cached_dir is None:
|
||||
cached_dir = MassFileListerCachedDir(dirname)
|
||||
self.cached_dirs[dirname] = cached_dir
|
||||
|
||||
stats = cached_dir.files_cased.get(filename)
|
||||
if stats is not None:
|
||||
return stats
|
||||
|
||||
stats = cached_dir.files.get(filename.lower())
|
||||
if stats is None:
|
||||
return None
|
||||
|
||||
try:
|
||||
os_stats = os.stat(path, follow_symlinks=False)
|
||||
return filename, os_stats.st_mtime, os_stats.st_ctime
|
||||
except Exception:
|
||||
return None
|
||||
|
||||
def exists(self, path):
|
||||
"""Check if a file exists at the given path."""
|
||||
|
||||
return self.find(path) is not None
|
||||
|
||||
def mctime(self, path):
|
||||
"""
|
||||
Get the modification and creation times for a file at the given path.
|
||||
|
||||
Returns:
|
||||
tuple: A tuple of (mtime, ctime) if the file exists, or (0, 0) if it does not.
|
||||
"""
|
||||
|
||||
stats = self.find(path)
|
||||
return (0, 0) if stats is None else stats[1:3]
|
||||
|
||||
def reset(self):
|
||||
"""Clear the cache of all directories."""
|
||||
self.cached_dirs.clear()
|
||||
|
||||
@@ -5,6 +5,9 @@ from transformers.models.xlm_roberta.configuration_xlm_roberta import XLMRoberta
|
||||
from transformers import XLMRobertaModel,XLMRobertaTokenizer
|
||||
from typing import Optional
|
||||
|
||||
from modules import torch_utils
|
||||
|
||||
|
||||
class BertSeriesConfig(BertConfig):
|
||||
def __init__(self, vocab_size=30522, hidden_size=768, num_hidden_layers=12, num_attention_heads=12, intermediate_size=3072, hidden_act="gelu", hidden_dropout_prob=0.1, attention_probs_dropout_prob=0.1, max_position_embeddings=512, type_vocab_size=2, initializer_range=0.02, layer_norm_eps=1e-12, pad_token_id=0, position_embedding_type="absolute", use_cache=True, classifier_dropout=None,project_dim=512, pooler_fn="average",learn_encoder=False,model_type='bert',**kwargs):
|
||||
|
||||
@@ -62,7 +65,7 @@ class BertSeriesModelWithTransformation(BertPreTrainedModel):
|
||||
self.post_init()
|
||||
|
||||
def encode(self,c):
|
||||
device = next(self.parameters()).device
|
||||
device = torch_utils.get_param(self).device
|
||||
text = self.tokenizer(c,
|
||||
truncation=True,
|
||||
max_length=77,
|
||||
|
||||
@@ -4,6 +4,8 @@ import torch
|
||||
from transformers.models.xlm_roberta.configuration_xlm_roberta import XLMRobertaConfig
|
||||
from transformers import XLMRobertaModel,XLMRobertaTokenizer
|
||||
from typing import Optional
|
||||
from modules import torch_utils
|
||||
|
||||
|
||||
class BertSeriesConfig(BertConfig):
|
||||
def __init__(self, vocab_size=30522, hidden_size=768, num_hidden_layers=12, num_attention_heads=12, intermediate_size=3072, hidden_act="gelu", hidden_dropout_prob=0.1, attention_probs_dropout_prob=0.1, max_position_embeddings=512, type_vocab_size=2, initializer_range=0.02, layer_norm_eps=1e-12, pad_token_id=0, position_embedding_type="absolute", use_cache=True, classifier_dropout=None,project_dim=512, pooler_fn="average",learn_encoder=False,model_type='bert',**kwargs):
|
||||
@@ -68,7 +70,7 @@ class BertSeriesModelWithTransformation(BertPreTrainedModel):
|
||||
self.post_init()
|
||||
|
||||
def encode(self,c):
|
||||
device = next(self.parameters()).device
|
||||
device = torch_utils.get_param(self).device
|
||||
text = self.tokenizer(c,
|
||||
truncation=True,
|
||||
max_length=77,
|
||||
|
||||
@@ -27,11 +27,90 @@ def torch_xpu_gc():
|
||||
|
||||
has_xpu = check_for_xpu()
|
||||
|
||||
|
||||
# Arc GPU cannot allocate a single block larger than 4GB: https://github.com/intel/compute-runtime/issues/627
|
||||
# Here we implement a slicing algorithm to split large batch size into smaller chunks,
|
||||
# so that SDPA of each chunk wouldn't require any allocation larger than ARC_SINGLE_ALLOCATION_LIMIT.
|
||||
# The heuristic limit (TOTAL_VRAM // 8) is tuned for Intel Arc A770 16G and Arc A750 8G,
|
||||
# which is the best trade-off between VRAM usage and performance.
|
||||
ARC_SINGLE_ALLOCATION_LIMIT = {}
|
||||
orig_sdp_attn_func = torch.nn.functional.scaled_dot_product_attention
|
||||
def torch_xpu_scaled_dot_product_attention(
|
||||
query, key, value, attn_mask=None, dropout_p=0.0, is_causal=False, *args, **kwargs
|
||||
):
|
||||
# cast to same dtype first
|
||||
key = key.to(query.dtype)
|
||||
value = value.to(query.dtype)
|
||||
if attn_mask is not None and attn_mask.dtype != torch.bool:
|
||||
attn_mask = attn_mask.to(query.dtype)
|
||||
|
||||
N = query.shape[:-2] # Batch size
|
||||
L = query.size(-2) # Target sequence length
|
||||
E = query.size(-1) # Embedding dimension of the query and key
|
||||
S = key.size(-2) # Source sequence length
|
||||
Ev = value.size(-1) # Embedding dimension of the value
|
||||
|
||||
total_batch_size = torch.numel(torch.empty(N))
|
||||
device_id = query.device.index
|
||||
if device_id not in ARC_SINGLE_ALLOCATION_LIMIT:
|
||||
ARC_SINGLE_ALLOCATION_LIMIT[device_id] = min(torch.xpu.get_device_properties(device_id).total_memory // 8, 4 * 1024 * 1024 * 1024)
|
||||
batch_size_limit = max(1, ARC_SINGLE_ALLOCATION_LIMIT[device_id] // (L * S * query.element_size()))
|
||||
|
||||
if total_batch_size <= batch_size_limit:
|
||||
return orig_sdp_attn_func(
|
||||
query,
|
||||
key,
|
||||
value,
|
||||
attn_mask,
|
||||
dropout_p,
|
||||
is_causal,
|
||||
*args, **kwargs
|
||||
)
|
||||
|
||||
query = torch.reshape(query, (-1, L, E))
|
||||
key = torch.reshape(key, (-1, S, E))
|
||||
value = torch.reshape(value, (-1, S, Ev))
|
||||
if attn_mask is not None:
|
||||
attn_mask = attn_mask.view(-1, L, S)
|
||||
chunk_count = (total_batch_size + batch_size_limit - 1) // batch_size_limit
|
||||
outputs = []
|
||||
for i in range(chunk_count):
|
||||
attn_mask_chunk = (
|
||||
None
|
||||
if attn_mask is None
|
||||
else attn_mask[i * batch_size_limit : (i + 1) * batch_size_limit, :, :]
|
||||
)
|
||||
chunk_output = orig_sdp_attn_func(
|
||||
query[i * batch_size_limit : (i + 1) * batch_size_limit, :, :],
|
||||
key[i * batch_size_limit : (i + 1) * batch_size_limit, :, :],
|
||||
value[i * batch_size_limit : (i + 1) * batch_size_limit, :, :],
|
||||
attn_mask_chunk,
|
||||
dropout_p,
|
||||
is_causal,
|
||||
*args, **kwargs
|
||||
)
|
||||
outputs.append(chunk_output)
|
||||
result = torch.cat(outputs, dim=0)
|
||||
return torch.reshape(result, (*N, L, Ev))
|
||||
|
||||
|
||||
def is_xpu_device(device: str | torch.device = None):
|
||||
if device is None:
|
||||
return False
|
||||
if isinstance(device, str):
|
||||
return device.startswith("xpu")
|
||||
return device.type == "xpu"
|
||||
|
||||
|
||||
if has_xpu:
|
||||
# W/A for https://github.com/intel/intel-extension-for-pytorch/issues/452: torch.Generator API doesn't support XPU device
|
||||
CondFunc('torch.Generator',
|
||||
lambda orig_func, device=None: torch.xpu.Generator(device),
|
||||
lambda orig_func, device=None: device is not None and device.type == "xpu")
|
||||
try:
|
||||
# torch.Generator supports "xpu" device since 2.1
|
||||
torch.Generator("xpu")
|
||||
except RuntimeError:
|
||||
# W/A for https://github.com/intel/intel-extension-for-pytorch/issues/452: torch.Generator API doesn't support XPU device (for torch < 2.1)
|
||||
CondFunc('torch.Generator',
|
||||
lambda orig_func, device=None: torch.xpu.Generator(device),
|
||||
lambda orig_func, device=None: is_xpu_device(device))
|
||||
|
||||
# W/A for some OPs that could not handle different input dtypes
|
||||
CondFunc('torch.nn.functional.layer_norm',
|
||||
@@ -55,5 +134,5 @@ if has_xpu:
|
||||
lambda orig_func, tensors, dim=0, out=None: orig_func([t.to(tensors[0].dtype) for t in tensors], dim=dim, out=out),
|
||||
lambda orig_func, tensors, dim=0, out=None: not all(t.dtype == tensors[0].dtype for t in tensors))
|
||||
CondFunc('torch.nn.functional.scaled_dot_product_attention',
|
||||
lambda orig_func, query, key, value, attn_mask=None, dropout_p=0.0, is_causal=False: orig_func(query, key.to(query.dtype), value.to(query.dtype), attn_mask, dropout_p, is_causal),
|
||||
lambda orig_func, query, key, value, attn_mask=None, dropout_p=0.0, is_causal=False: query.dtype != key.dtype or query.dtype != value.dtype)
|
||||
lambda orig_func, *args, **kwargs: torch_xpu_scaled_dot_product_attention(*args, **kwargs),
|
||||
lambda orig_func, query, *args, **kwargs: query.is_xpu)
|
||||
|
||||
Reference in New Issue
Block a user