Merge branch 'master' into master

This commit is contained in:
KyuSeok Jung
2022-11-06 03:08:45 +09:00
committed by GitHub
40 changed files with 1757 additions and 540 deletions

View File

@@ -2,15 +2,18 @@ import base64
import io
import time
import uvicorn
from gradio.processing_utils import decode_base64_to_file, decode_base64_to_image
from fastapi import APIRouter, Depends, HTTPException
from threading import Lock
from gradio.processing_utils import encode_pil_to_base64, decode_base64_to_file, decode_base64_to_image
from fastapi import APIRouter, Depends, FastAPI, HTTPException
import modules.shared as shared
from modules import devices
from modules.api.models import *
from modules.processing import StableDiffusionProcessingTxt2Img, StableDiffusionProcessingImg2Img, process_images
from modules.sd_samplers import all_samplers
from modules.extras import run_extras, run_pnginfo
from PIL import PngImagePlugin
from modules.sd_models import checkpoints_list
from modules.realesrgan_model import get_realesrgan_models
from typing import List
def upscaler_to_index(name: str):
try:
@@ -32,13 +35,25 @@ def setUpscalers(req: dict):
def encode_pil_to_base64(image):
buffer = io.BytesIO()
image.save(buffer, format="png")
return base64.b64encode(buffer.getvalue())
with io.BytesIO() as output_bytes:
# Copy any text-only metadata
use_metadata = False
metadata = PngImagePlugin.PngInfo()
for key, value in image.info.items():
if isinstance(key, str) and isinstance(value, str):
metadata.add_text(key, value)
use_metadata = True
image.save(
output_bytes, "PNG", pnginfo=(metadata if use_metadata else None)
)
bytes_data = output_bytes.getvalue()
return base64.b64encode(bytes_data)
class Api:
def __init__(self, app, queue_lock):
def __init__(self, app: FastAPI, queue_lock: Lock):
self.router = APIRouter()
self.app = app
self.queue_lock = queue_lock
@@ -49,6 +64,18 @@ class Api:
self.app.add_api_route("/sdapi/v1/png-info", self.pnginfoapi, methods=["POST"], response_model=PNGInfoResponse)
self.app.add_api_route("/sdapi/v1/progress", self.progressapi, methods=["GET"], response_model=ProgressResponse)
self.app.add_api_route("/sdapi/v1/interrupt", self.interruptapi, methods=["POST"])
self.app.add_api_route("/sdapi/v1/options", self.get_config, methods=["GET"], response_model=OptionsModel)
self.app.add_api_route("/sdapi/v1/options", self.set_config, methods=["POST"])
self.app.add_api_route("/sdapi/v1/cmd-flags", self.get_cmd_flags, methods=["GET"], response_model=FlagsModel)
self.app.add_api_route("/sdapi/v1/samplers", self.get_samplers, methods=["GET"], response_model=List[SamplerItem])
self.app.add_api_route("/sdapi/v1/upscalers", self.get_upscalers, methods=["GET"], response_model=List[UpscalerItem])
self.app.add_api_route("/sdapi/v1/sd-models", self.get_sd_models, methods=["GET"], response_model=List[SDModelItem])
self.app.add_api_route("/sdapi/v1/hypernetworks", self.get_hypernetworks, methods=["GET"], response_model=List[HypernetworkItem])
self.app.add_api_route("/sdapi/v1/face-restorers", self.get_face_restorers, methods=["GET"], response_model=List[FaceRestorerItem])
self.app.add_api_route("/sdapi/v1/realesrgan-models", self.get_realesrgan_models, methods=["GET"], response_model=List[RealesrganItem])
self.app.add_api_route("/sdapi/v1/prompt-styles", self.get_promp_styles, methods=["GET"], response_model=List[PromptStyleItem])
self.app.add_api_route("/sdapi/v1/artist-categories", self.get_artists_categories, methods=["GET"], response_model=List[str])
self.app.add_api_route("/sdapi/v1/artists", self.get_artists, methods=["GET"], response_model=List[ArtistItem])
def text2imgapi(self, txt2imgreq: StableDiffusionTxt2ImgProcessingAPI):
sampler_index = sampler_to_index(txt2imgreq.sampler_index)
@@ -179,6 +206,8 @@ class Api:
progress = min(progress, 1)
shared.state.set_current_image()
current_image = None
if shared.state.current_image and not req.skip_current_image:
current_image = encode_pil_to_base64(shared.state.current_image)
@@ -190,6 +219,70 @@ class Api:
return {}
def get_config(self):
options = {}
for key in shared.opts.data.keys():
metadata = shared.opts.data_labels.get(key)
if(metadata is not None):
options.update({key: shared.opts.data.get(key, shared.opts.data_labels.get(key).default)})
else:
options.update({key: shared.opts.data.get(key, None)})
return options
def set_config(self, req: OptionsModel):
# currently req has all options fields even if you send a dict like { "send_seed": false }, which means it will
# overwrite all options with default values.
raise RuntimeError('Setting options via API is not supported')
reqDict = vars(req)
for o in reqDict:
setattr(shared.opts, o, reqDict[o])
shared.opts.save(shared.config_filename)
return
def get_cmd_flags(self):
return vars(shared.cmd_opts)
def get_samplers(self):
return [{"name":sampler[0], "aliases":sampler[2], "options":sampler[3]} for sampler in all_samplers]
def get_upscalers(self):
upscalers = []
for upscaler in shared.sd_upscalers:
u = upscaler.scaler
upscalers.append({"name":u.name, "model_name":u.model_name, "model_path":u.model_path, "model_url":u.model_url})
return upscalers
def get_sd_models(self):
return [{"title":x.title, "model_name":x.model_name, "hash":x.hash, "filename": x.filename, "config": x.config} for x in checkpoints_list.values()]
def get_hypernetworks(self):
return [{"name": name, "path": shared.hypernetworks[name]} for name in shared.hypernetworks]
def get_face_restorers(self):
return [{"name":x.name(), "cmd_dir": getattr(x, "cmd_dir", None)} for x in shared.face_restorers]
def get_realesrgan_models(self):
return [{"name":x.name,"path":x.data_path, "scale":x.scale} for x in get_realesrgan_models(None)]
def get_promp_styles(self):
styleList = []
for k in shared.prompt_styles.styles:
style = shared.prompt_styles.styles[k]
styleList.append({"name":style[0], "prompt": style[1], "negative_prompr": style[2]})
return styleList
def get_artists_categories(self):
return shared.artist_db.cats
def get_artists(self):
return [{"name":x[0], "score":x[1], "category":x[2]} for x in shared.artist_db.artists]
def launch(self, server_name, port):
self.app.include_router(self.router)
uvicorn.run(self.app, host=server_name, port=port)

View File

@@ -1,11 +1,11 @@
import inspect
from click import prompt
from pydantic import BaseModel, Field, create_model
from typing import Any, Optional
from typing_extensions import Literal
from inflection import underscore
from modules.processing import StableDiffusionProcessingTxt2Img, StableDiffusionProcessingImg2Img
from modules.shared import sd_upscalers
from modules.shared import sd_upscalers, opts, parser
from typing import Dict, List
API_NOT_ALLOWED = [
"self",
@@ -109,12 +109,12 @@ StableDiffusionImg2ImgProcessingAPI = PydanticModelGenerator(
).generate_model()
class TextToImageResponse(BaseModel):
images: list[str] = Field(default=None, title="Image", description="The generated image in base64 format.")
images: List[str] = Field(default=None, title="Image", description="The generated image in base64 format.")
parameters: dict
info: str
class ImageToImageResponse(BaseModel):
images: list[str] = Field(default=None, title="Image", description="The generated image in base64 format.")
images: List[str] = Field(default=None, title="Image", description="The generated image in base64 format.")
parameters: dict
info: str
@@ -131,6 +131,7 @@ class ExtrasBaseRequest(BaseModel):
upscaler_1: str = Field(default="None", title="Main upscaler", description=f"The name of the main upscaler to use, it has to be one of this list: {' , '.join([x.name for x in sd_upscalers])}")
upscaler_2: str = Field(default="None", title="Secondary upscaler", description=f"The name of the secondary upscaler to use, it has to be one of this list: {' , '.join([x.name for x in sd_upscalers])}")
extras_upscaler_2_visibility: float = Field(default=0, title="Secondary upscaler visibility", ge=0, le=1, allow_inf_nan=False, description="Sets the visibility of secondary upscaler, values should be between 0 and 1.")
upscale_first: bool = Field(default=False, title="Upscale first", description="Should the upscaler run before restoring faces?")
class ExtraBaseResponse(BaseModel):
html_info: str = Field(title="HTML info", description="A series of HTML tags containing the process info.")
@@ -146,10 +147,10 @@ class FileData(BaseModel):
name: str = Field(title="File name")
class ExtrasBatchImagesRequest(ExtrasBaseRequest):
imageList: list[FileData] = Field(title="Images", description="List of images to work on. Must be Base64 strings")
imageList: List[FileData] = Field(title="Images", description="List of images to work on. Must be Base64 strings")
class ExtrasBatchImagesResponse(ExtraBaseResponse):
images: list[str] = Field(title="Images", description="The generated images in base64 format.")
images: List[str] = Field(title="Images", description="The generated images in base64 format.")
class PNGInfoRequest(BaseModel):
image: str = Field(title="Image", description="The base64 encoded PNG image")
@@ -165,3 +166,68 @@ class ProgressResponse(BaseModel):
eta_relative: float = Field(title="ETA in secs")
state: dict = Field(title="State", description="The current state snapshot")
current_image: str = Field(default=None, title="Current image", description="The current image in base64 format. opts.show_progress_every_n_steps is required for this to work.")
fields = {}
for key, value in opts.data.items():
metadata = opts.data_labels.get(key)
optType = opts.typemap.get(type(value), type(value))
if (metadata is not None):
fields.update({key: (Optional[optType], Field(
default=metadata.default ,description=metadata.label))})
else:
fields.update({key: (Optional[optType], Field())})
OptionsModel = create_model("Options", **fields)
flags = {}
_options = vars(parser)['_option_string_actions']
for key in _options:
if(_options[key].dest != 'help'):
flag = _options[key]
_type = str
if _options[key].default is not None: _type = type(_options[key].default)
flags.update({flag.dest: (_type,Field(default=flag.default, description=flag.help))})
FlagsModel = create_model("Flags", **flags)
class SamplerItem(BaseModel):
name: str = Field(title="Name")
aliases: List[str] = Field(title="Aliases")
options: Dict[str, str] = Field(title="Options")
class UpscalerItem(BaseModel):
name: str = Field(title="Name")
model_name: Optional[str] = Field(title="Model Name")
model_path: Optional[str] = Field(title="Path")
model_url: Optional[str] = Field(title="URL")
class SDModelItem(BaseModel):
title: str = Field(title="Title")
model_name: str = Field(title="Model Name")
hash: str = Field(title="Hash")
filename: str = Field(title="Filename")
config: str = Field(title="Config file")
class HypernetworkItem(BaseModel):
name: str = Field(title="Name")
path: Optional[str] = Field(title="Path")
class FaceRestorerItem(BaseModel):
name: str = Field(title="Name")
cmd_dir: Optional[str] = Field(title="Path")
class RealesrganItem(BaseModel):
name: str = Field(title="Name")
path: Optional[str] = Field(title="Path")
scale: Optional[int] = Field(title="Scale")
class PromptStyleItem(BaseModel):
name: str = Field(title="Name")
prompt: Optional[str] = Field(title="Prompt")
negative_prompt: Optional[str] = Field(title="Negative Prompt")
class ArtistItem(BaseModel):
name: str = Field(title="Name")
score: float = Field(title="Score")
category: str = Field(title="Category")

View File

@@ -50,6 +50,7 @@ def mod2normal(state_dict):
def resrgan2normal(state_dict, nb=23):
# this code is copied from https://github.com/victorca25/iNNfer
if "conv_first.weight" in state_dict and "body.0.rdb1.conv1.weight" in state_dict:
re8x = 0
crt_net = {}
items = []
for k, v in state_dict.items():
@@ -75,10 +76,18 @@ def resrgan2normal(state_dict, nb=23):
crt_net['model.3.bias'] = state_dict['conv_up1.bias']
crt_net['model.6.weight'] = state_dict['conv_up2.weight']
crt_net['model.6.bias'] = state_dict['conv_up2.bias']
crt_net['model.8.weight'] = state_dict['conv_hr.weight']
crt_net['model.8.bias'] = state_dict['conv_hr.bias']
crt_net['model.10.weight'] = state_dict['conv_last.weight']
crt_net['model.10.bias'] = state_dict['conv_last.bias']
if 'conv_up3.weight' in state_dict:
# modification supporting: https://github.com/ai-forever/Real-ESRGAN/blob/main/RealESRGAN/rrdbnet_arch.py
re8x = 3
crt_net['model.9.weight'] = state_dict['conv_up3.weight']
crt_net['model.9.bias'] = state_dict['conv_up3.bias']
crt_net[f'model.{8+re8x}.weight'] = state_dict['conv_hr.weight']
crt_net[f'model.{8+re8x}.bias'] = state_dict['conv_hr.bias']
crt_net[f'model.{10+re8x}.weight'] = state_dict['conv_last.weight']
crt_net[f'model.{10+re8x}.bias'] = state_dict['conv_last.bias']
state_dict = crt_net
return state_dict

View File

@@ -34,8 +34,11 @@ class Extension:
if repo is None or repo.bare:
self.remote = None
else:
self.remote = next(repo.remote().urls, None)
self.status = 'unknown'
try:
self.remote = next(repo.remote().urls, None)
self.status = 'unknown'
except Exception:
self.remote = None
def list_files(self, subdir, extension):
from modules import scripts

View File

@@ -136,12 +136,13 @@ def run_extras(extras_mode, resize_mode, image, image_folder, input_dir, output_
def run_upscalers_blend(params: List[UpscaleParams], image: Image.Image, info: str) -> Tuple[Image.Image, str]:
blended_result: Image.Image = None
image_hash: str = hash(np.array(image.getdata()).tobytes())
for upscaler in params:
upscale_args = (upscaler.upscaler_idx, upscaling_resize, resize_mode,
upscaling_resize_w, upscaling_resize_h, upscaling_crop)
cache_key = LruCache.Key(image_hash=hash(np.array(image.getdata()).tobytes()),
cache_key = LruCache.Key(image_hash=image_hash,
info_hash=hash(info),
args_hash=hash((upscale_args, upscale_first)))
args_hash=hash(upscale_args))
cached_entry = cached_images.get(cache_key)
if cached_entry is None:
res = upscale(image, *upscale_args)

View File

@@ -22,6 +22,8 @@ from collections import defaultdict, deque
from statistics import stdev, mean
optimizer_dict = {optim_name : cls_obj for optim_name, cls_obj in inspect.getmembers(torch.optim, inspect.isclass) if optim_name != "Optimizer"}
class HypernetworkModule(torch.nn.Module):
multiplier = 1.0
activation_dict = {
@@ -35,7 +37,8 @@ class HypernetworkModule(torch.nn.Module):
}
activation_dict.update({cls_name.lower(): cls_obj for cls_name, cls_obj in inspect.getmembers(torch.nn.modules.activation) if inspect.isclass(cls_obj) and cls_obj.__module__ == 'torch.nn.modules.activation'})
def __init__(self, dim, state_dict=None, layer_structure=None, activation_func=None, weight_init='Normal', add_layer_norm=False, use_dropout=False):
def __init__(self, dim, state_dict=None, layer_structure=None, activation_func=None, weight_init='Normal',
add_layer_norm=False, use_dropout=False, activate_output=False, last_layer_dropout=True):
super().__init__()
assert layer_structure is not None, "layer_structure must not be None"
@@ -48,8 +51,8 @@ class HypernetworkModule(torch.nn.Module):
# Add a fully-connected layer
linears.append(torch.nn.Linear(int(dim * layer_structure[i]), int(dim * layer_structure[i+1])))
# Add an activation func
if activation_func == "linear" or activation_func is None:
# Add an activation func except last layer
if activation_func == "linear" or activation_func is None or (i >= len(layer_structure) - 2 and not activate_output):
pass
elif activation_func in self.activation_dict:
linears.append(self.activation_dict[activation_func]())
@@ -60,8 +63,8 @@ class HypernetworkModule(torch.nn.Module):
if add_layer_norm:
linears.append(torch.nn.LayerNorm(int(dim * layer_structure[i+1])))
# Add dropout expect last layer
if use_dropout and i < len(layer_structure) - 3:
# Add dropout except last layer
if use_dropout and (i < len(layer_structure) - 3 or last_layer_dropout and i < len(layer_structure) - 2):
linears.append(torch.nn.Dropout(p=0.3))
self.linear = torch.nn.Sequential(*linears)
@@ -75,7 +78,7 @@ class HypernetworkModule(torch.nn.Module):
w, b = layer.weight.data, layer.bias.data
if weight_init == "Normal" or type(layer) == torch.nn.LayerNorm:
normal_(w, mean=0.0, std=0.01)
normal_(b, mean=0.0, std=0.005)
normal_(b, mean=0.0, std=0)
elif weight_init == 'XavierUniform':
xavier_uniform_(w)
zeros_(b)
@@ -127,7 +130,7 @@ class Hypernetwork:
filename = None
name = None
def __init__(self, name=None, enable_sizes=None, layer_structure=None, activation_func=None, weight_init=None, add_layer_norm=False, use_dropout=False):
def __init__(self, name=None, enable_sizes=None, layer_structure=None, activation_func=None, weight_init=None, add_layer_norm=False, use_dropout=False, activate_output=False, **kwargs):
self.filename = None
self.name = name
self.layers = {}
@@ -139,11 +142,17 @@ class Hypernetwork:
self.weight_init = weight_init
self.add_layer_norm = add_layer_norm
self.use_dropout = use_dropout
self.activate_output = activate_output
self.last_layer_dropout = kwargs['last_layer_dropout'] if 'last_layer_dropout' in kwargs else True
self.optimizer_name = None
self.optimizer_state_dict = None
for size in enable_sizes or []:
self.layers[size] = (
HypernetworkModule(size, None, self.layer_structure, self.activation_func, self.weight_init, self.add_layer_norm, self.use_dropout),
HypernetworkModule(size, None, self.layer_structure, self.activation_func, self.weight_init, self.add_layer_norm, self.use_dropout),
HypernetworkModule(size, None, self.layer_structure, self.activation_func, self.weight_init,
self.add_layer_norm, self.use_dropout, self.activate_output, last_layer_dropout=self.last_layer_dropout),
HypernetworkModule(size, None, self.layer_structure, self.activation_func, self.weight_init,
self.add_layer_norm, self.use_dropout, self.activate_output, last_layer_dropout=self.last_layer_dropout),
)
def weights(self):
@@ -158,6 +167,7 @@ class Hypernetwork:
def save(self, filename):
state_dict = {}
optimizer_saved_dict = {}
for k, v in self.layers.items():
state_dict[k] = (v[0].state_dict(), v[1].state_dict())
@@ -171,8 +181,17 @@ class Hypernetwork:
state_dict['use_dropout'] = self.use_dropout
state_dict['sd_checkpoint'] = self.sd_checkpoint
state_dict['sd_checkpoint_name'] = self.sd_checkpoint_name
state_dict['activate_output'] = self.activate_output
state_dict['last_layer_dropout'] = self.last_layer_dropout
if self.optimizer_name is not None:
optimizer_saved_dict['optimizer_name'] = self.optimizer_name
torch.save(state_dict, filename)
if shared.opts.save_optimizer_state and self.optimizer_state_dict:
optimizer_saved_dict['hash'] = sd_models.model_hash(filename)
optimizer_saved_dict['optimizer_state_dict'] = self.optimizer_state_dict
torch.save(optimizer_saved_dict, filename + '.optim')
def load(self, filename):
self.filename = filename
@@ -191,12 +210,29 @@ class Hypernetwork:
print(f"Layer norm is set to {self.add_layer_norm}")
self.use_dropout = state_dict.get('use_dropout', False)
print(f"Dropout usage is set to {self.use_dropout}" )
self.activate_output = state_dict.get('activate_output', True)
print(f"Activate last layer is set to {self.activate_output}")
self.last_layer_dropout = state_dict.get('last_layer_dropout', False)
optimizer_saved_dict = torch.load(self.filename + '.optim', map_location = 'cpu') if os.path.exists(self.filename + '.optim') else {}
self.optimizer_name = optimizer_saved_dict.get('optimizer_name', 'AdamW')
print(f"Optimizer name is {self.optimizer_name}")
if sd_models.model_hash(filename) == optimizer_saved_dict.get('hash', None):
self.optimizer_state_dict = optimizer_saved_dict.get('optimizer_state_dict', None)
else:
self.optimizer_state_dict = None
if self.optimizer_state_dict:
print("Loaded existing optimizer from checkpoint")
else:
print("No saved optimizer exists in checkpoint")
for size, sd in state_dict.items():
if type(size) == int:
self.layers[size] = (
HypernetworkModule(size, sd[0], self.layer_structure, self.activation_func, self.weight_init, self.add_layer_norm, self.use_dropout),
HypernetworkModule(size, sd[1], self.layer_structure, self.activation_func, self.weight_init, self.add_layer_norm, self.use_dropout),
HypernetworkModule(size, sd[0], self.layer_structure, self.activation_func, self.weight_init,
self.add_layer_norm, self.use_dropout, self.activate_output, last_layer_dropout=self.last_layer_dropout),
HypernetworkModule(size, sd[1], self.layer_structure, self.activation_func, self.weight_init,
self.add_layer_norm, self.use_dropout, self.activate_output, last_layer_dropout=self.last_layer_dropout),
)
self.name = state_dict.get('name', self.name)
@@ -207,11 +243,11 @@ class Hypernetwork:
def list_hypernetworks(path):
res = {}
for filename in glob.iglob(os.path.join(path, '**/*.pt'), recursive=True):
for filename in sorted(glob.iglob(os.path.join(path, '**/*.pt'), recursive=True)):
name = os.path.splitext(os.path.basename(filename))[0]
# Prevent a hypothetical "None.pt" from being listed.
if name != "None":
res[name] = filename
res[name + f"({sd_models.model_hash(filename)})"] = filename
return res
@@ -346,6 +382,7 @@ def train_hypernetwork(hypernetwork_name, learn_rate, batch_size, data_root, log
shared.state.textinfo = "Initializing hypernetwork training..."
shared.state.job_count = steps
hypernetwork_name = hypernetwork_name.rsplit('(', 1)[0]
filename = os.path.join(shared.cmd_opts.hypernetwork_dir, f'{hypernetwork_name}.pt')
log_directory = os.path.join(log_directory, datetime.datetime.now().strftime("%Y-%m-%d"), hypernetwork_name)
@@ -392,8 +429,22 @@ def train_hypernetwork(hypernetwork_name, learn_rate, batch_size, data_root, log
weights = hypernetwork.weights()
for weight in weights:
weight.requires_grad = True
# if optimizer == "AdamW": or else Adam / AdamW / SGD, etc...
optimizer = torch.optim.AdamW(weights, lr=scheduler.learn_rate)
# Here we use optimizer from saved HN, or we can specify as UI option.
if hypernetwork.optimizer_name in optimizer_dict:
optimizer = optimizer_dict[hypernetwork.optimizer_name](params=weights, lr=scheduler.learn_rate)
optimizer_name = hypernetwork.optimizer_name
else:
print(f"Optimizer type {hypernetwork.optimizer_name} is not defined!")
optimizer = torch.optim.AdamW(params=weights, lr=scheduler.learn_rate)
optimizer_name = 'AdamW'
if hypernetwork.optimizer_state_dict: # This line must be changed if Optimizer type can be different from saved optimizer.
try:
optimizer.load_state_dict(hypernetwork.optimizer_state_dict)
except RuntimeError as e:
print("Cannot resume from saved optimizer!")
print(e)
steps_without_grad = 0
@@ -455,7 +506,11 @@ def train_hypernetwork(hypernetwork_name, learn_rate, batch_size, data_root, log
# Before saving, change name to match current checkpoint.
hypernetwork_name_every = f'{hypernetwork_name}-{steps_done}'
last_saved_file = os.path.join(hypernetwork_dir, f'{hypernetwork_name_every}.pt')
hypernetwork.optimizer_name = optimizer_name
if shared.opts.save_optimizer_state:
hypernetwork.optimizer_state_dict = optimizer.state_dict()
save_hypernetwork(hypernetwork, checkpoint, hypernetwork_name, last_saved_file)
hypernetwork.optimizer_state_dict = None # dereference it after saving, to save memory.
textual_inversion.write_loss(log_directory, "hypernetwork_loss.csv", hypernetwork.step, len(ds), {
"loss": f"{previous_mean_loss:.7f}",
@@ -518,8 +573,12 @@ Last saved image: {html.escape(last_saved_image)}<br/>
report_statistics(loss_dict)
filename = os.path.join(shared.cmd_opts.hypernetwork_dir, f'{hypernetwork_name}.pt')
hypernetwork.optimizer_name = optimizer_name
if shared.opts.save_optimizer_state:
hypernetwork.optimizer_state_dict = optimizer.state_dict()
save_hypernetwork(hypernetwork, checkpoint, hypernetwork_name, filename)
del optimizer
hypernetwork.optimizer_state_dict = None # dereference it after saving, to save memory.
return hypernetwork, filename
def save_hypernetwork(hypernetwork, checkpoint, hypernetwork_name, filename):

View File

@@ -9,7 +9,7 @@ from modules import devices, sd_hijack, shared
from modules.hypernetworks import hypernetwork
not_available = ["hardswish", "multiheadattention"]
keys = ["linear"] + list(x for x in hypernetwork.HypernetworkModule.activation_dict.keys() if x not in not_available)
keys = list(x for x in hypernetwork.HypernetworkModule.activation_dict.keys() if x not in not_available)
def create_hypernetwork(name, enable_sizes, overwrite_old, layer_structure=None, activation_func=None, weight_init=None, add_layer_norm=False, use_dropout=False):
# Remove illegal characters from name.

View File

@@ -81,7 +81,8 @@ def img2img(mode: int, prompt: str, negative_prompt: str, prompt_style: str, pro
mask = None
# Use the EXIF orientation of photos taken by smartphones.
image = ImageOps.exif_transpose(image)
if image is not None:
image = ImageOps.exif_transpose(image)
assert 0. <= denoising_strength <= 1., 'can only work with strength in [0.0, 1.0]'

View File

@@ -49,7 +49,7 @@ def expand_crop_region(crop_region, processing_width, processing_height, image_w
ratio_processing = processing_width / processing_height
if ratio_crop_region > ratio_processing:
desired_height = (x2 - x1) * ratio_processing
desired_height = (x2 - x1) / ratio_processing
desired_height_diff = int(desired_height - (y2-y1))
y1 -= desired_height_diff//2
y2 += desired_height_diff - desired_height_diff//2

View File

@@ -85,6 +85,9 @@ def cleanup_models():
src_path = os.path.join(root_path, "ESRGAN")
dest_path = os.path.join(models_path, "ESRGAN")
move_files(src_path, dest_path)
src_path = os.path.join(models_path, "BSRGAN")
dest_path = os.path.join(models_path, "ESRGAN")
move_files(src_path, dest_path, ".pth")
src_path = os.path.join(root_path, "gfpgan")
dest_path = os.path.join(models_path, "GFPGAN")
move_files(src_path, dest_path)

View File

@@ -134,11 +134,7 @@ class StableDiffusionProcessing():
# Dummy zero conditioning if we're not using inpainting model.
# Still takes up a bit of memory, but no encoder call.
# Pretty sure we can just make this a 1x1 image since its not going to be used besides its batch size.
return torch.zeros(
x.shape[0], 5, 1, 1,
dtype=x.dtype,
device=x.device
)
return x.new_zeros(x.shape[0], 5, 1, 1)
height = height or self.height
width = width or self.width
@@ -156,11 +152,7 @@ class StableDiffusionProcessing():
def img2img_image_conditioning(self, source_image, latent_image, image_mask = None):
if self.sampler.conditioning_key not in {'hybrid', 'concat'}:
# Dummy zero conditioning if we're not using inpainting model.
return torch.zeros(
latent_image.shape[0], 5, 1, 1,
dtype=latent_image.dtype,
device=latent_image.device
)
return latent_image.new_zeros(latent_image.shape[0], 5, 1, 1)
# Handle the different mask inputs
if image_mask is not None:
@@ -174,11 +166,11 @@ class StableDiffusionProcessing():
# Inpainting model uses a discretized mask as input, so we round to either 1.0 or 0.0
conditioning_mask = torch.round(conditioning_mask)
else:
conditioning_mask = torch.ones(1, 1, *source_image.shape[-2:])
conditioning_mask = source_image.new_ones(1, 1, *source_image.shape[-2:])
# Create another latent image, this time with a masked version of the original input.
# Smoothly interpolate between the masked and unmasked latent conditioning image using a parameter.
conditioning_mask = conditioning_mask.to(source_image.device)
conditioning_mask = conditioning_mask.to(source_image.device).to(source_image.dtype)
conditioning_image = torch.lerp(
source_image,
source_image * (1.0 - conditioning_mask),
@@ -199,7 +191,7 @@ class StableDiffusionProcessing():
def init(self, all_prompts, all_seeds, all_subseeds):
pass
def sample(self, conditioning, unconditional_conditioning, seeds, subseeds, subseed_strength):
def sample(self, conditioning, unconditional_conditioning, seeds, subseeds, subseed_strength, prompts):
raise NotImplementedError()
def close(self):
@@ -426,13 +418,13 @@ def process_images(p: StableDiffusionProcessing) -> Processed:
try:
for k, v in p.override_settings.items():
opts.data[k] = v # we don't call onchange for simplicity which makes changing model, hypernet impossible
setattr(opts, k, v) # we don't call onchange for simplicity which makes changing model, hypernet impossible
res = process_images_inner(p)
finally:
for k, v in stored_opts.items():
opts.data[k] = v
setattr(opts, k, v)
return res
@@ -509,6 +501,9 @@ def process_images_inner(p: StableDiffusionProcessing) -> Processed:
if len(prompts) == 0:
break
if p.scripts is not None:
p.scripts.process_batch(p, batch_number=n, prompts=prompts, seeds=seeds, subseeds=subseeds)
with devices.autocast():
uc = prompt_parser.get_learned_conditioning(shared.sd_model, len(prompts) * [p.negative_prompt], p.steps)
c = prompt_parser.get_multicond_learned_conditioning(shared.sd_model, prompts, p.steps)
@@ -521,7 +516,7 @@ def process_images_inner(p: StableDiffusionProcessing) -> Processed:
shared.state.job = f"Batch {n+1} out of {p.n_iter}"
with devices.autocast():
samples_ddim = p.sample(conditioning=c, unconditional_conditioning=uc, seeds=seeds, subseeds=subseeds, subseed_strength=p.subseed_strength)
samples_ddim = p.sample(conditioning=c, unconditional_conditioning=uc, seeds=seeds, subseeds=subseeds, subseed_strength=p.subseed_strength, prompts=prompts)
samples_ddim = samples_ddim.to(devices.dtype_vae)
x_samples_ddim = decode_first_stage(p.sd_model, samples_ddim)
@@ -649,7 +644,7 @@ class StableDiffusionProcessingTxt2Img(StableDiffusionProcessing):
self.truncate_x = int(self.firstphase_width - firstphase_width_truncated) // opt_f
self.truncate_y = int(self.firstphase_height - firstphase_height_truncated) // opt_f
def sample(self, conditioning, unconditional_conditioning, seeds, subseeds, subseed_strength):
def sample(self, conditioning, unconditional_conditioning, seeds, subseeds, subseed_strength, prompts):
self.sampler = sd_samplers.create_sampler_with_index(sd_samplers.samplers, self.sampler_index, self.sd_model)
if not self.enable_hr:
@@ -662,9 +657,28 @@ class StableDiffusionProcessingTxt2Img(StableDiffusionProcessing):
samples = samples[:, :, self.truncate_y//2:samples.shape[2]-self.truncate_y//2, self.truncate_x//2:samples.shape[3]-self.truncate_x//2]
"""saves image before applying hires fix, if enabled in options; takes as an arguyment either an image or batch with latent space images"""
def save_intermediate(image, index):
if not opts.save or self.do_not_save_samples or not opts.save_images_before_highres_fix:
return
if not isinstance(image, Image.Image):
image = sd_samplers.sample_to_image(image, index)
images.save_image(image, self.outpath_samples, "", seeds[index], prompts[index], opts.samples_format, suffix="-before-highres-fix")
if opts.use_scale_latent_for_hires_fix:
for i in range(samples.shape[0]):
save_intermediate(samples, i)
samples = torch.nn.functional.interpolate(samples, size=(self.height // opt_f, self.width // opt_f), mode="bilinear")
# Avoid making the inpainting conditioning unless necessary as
# this does need some extra compute to decode / encode the image again.
if getattr(self, "inpainting_mask_weight", shared.opts.inpainting_mask_weight) < 1.0:
image_conditioning = self.img2img_image_conditioning(decode_first_stage(self.sd_model, samples), samples)
else:
image_conditioning = self.txt2img_image_conditioning(samples)
else:
decoded_samples = decode_first_stage(self.sd_model, samples)
lowres_samples = torch.clamp((decoded_samples + 1.0) / 2.0, min=0.0, max=1.0)
@@ -674,6 +688,9 @@ class StableDiffusionProcessingTxt2Img(StableDiffusionProcessing):
x_sample = 255. * np.moveaxis(x_sample.cpu().numpy(), 0, 2)
x_sample = x_sample.astype(np.uint8)
image = Image.fromarray(x_sample)
save_intermediate(image, i)
image = images.resize_image(0, image, self.width, self.height)
image = np.array(image).astype(np.float32) / 255.0
image = np.moveaxis(image, 2, 0)
@@ -685,14 +702,14 @@ class StableDiffusionProcessingTxt2Img(StableDiffusionProcessing):
samples = self.sd_model.get_first_stage_encoding(self.sd_model.encode_first_stage(decoded_samples))
image_conditioning = self.img2img_image_conditioning(decoded_samples, samples)
shared.state.nextjob()
self.sampler = sd_samplers.create_sampler_with_index(sd_samplers.samplers, self.sampler_index, self.sd_model)
noise = create_random_tensors(samples.shape[1:], seeds=seeds, subseeds=subseeds, subseed_strength=subseed_strength, seed_resize_from_h=self.seed_resize_from_h, seed_resize_from_w=self.seed_resize_from_w, p=self)
image_conditioning = self.txt2img_image_conditioning(x)
# GC now before running the next img2img to prevent running out of memory
x = None
devices.torch_gc()
@@ -831,8 +848,7 @@ class StableDiffusionProcessingImg2Img(StableDiffusionProcessing):
self.image_conditioning = self.img2img_image_conditioning(image, self.init_latent, self.image_mask)
def sample(self, conditioning, unconditional_conditioning, seeds, subseeds, subseed_strength):
def sample(self, conditioning, unconditional_conditioning, seeds, subseeds, subseed_strength, prompts):
x = create_random_tensors([opt_C, self.height // opt_f, self.width // opt_f], seeds=seeds, subseeds=subseeds, subseed_strength=self.subseed_strength, seed_resize_from_h=self.seed_resize_from_h, seed_resize_from_w=self.seed_resize_from_w, p=self)
samples = self.sampler.sample_img2img(self, self.init_latent, x, conditioning, unconditional_conditioning, image_conditioning=self.image_conditioning)
@@ -843,4 +859,4 @@ class StableDiffusionProcessingImg2Img(StableDiffusionProcessing):
del x
devices.torch_gc()
return samples
return samples

View File

@@ -2,6 +2,7 @@ import sys
import traceback
from collections import namedtuple
import inspect
from typing import Optional
from fastapi import FastAPI
from gradio import Blocks
@@ -45,25 +46,23 @@ class CFGDenoiserParams:
ScriptCallback = namedtuple("ScriptCallback", ["script", "callback"])
callbacks_app_started = []
callbacks_model_loaded = []
callbacks_ui_tabs = []
callbacks_ui_settings = []
callbacks_before_image_saved = []
callbacks_image_saved = []
callbacks_cfg_denoiser = []
callback_map = dict(
callbacks_app_started=[],
callbacks_model_loaded=[],
callbacks_ui_tabs=[],
callbacks_ui_settings=[],
callbacks_before_image_saved=[],
callbacks_image_saved=[],
callbacks_cfg_denoiser=[]
)
def clear_callbacks():
callbacks_model_loaded.clear()
callbacks_ui_tabs.clear()
callbacks_ui_settings.clear()
callbacks_before_image_saved.clear()
callbacks_image_saved.clear()
callbacks_cfg_denoiser.clear()
for callback_list in callback_map.values():
callback_list.clear()
def app_started_callback(demo: Blocks, app: FastAPI):
for c in callbacks_app_started:
def app_started_callback(demo: Optional[Blocks], app: FastAPI):
for c in callback_map['callbacks_app_started']:
try:
c.callback(demo, app)
except Exception:
@@ -71,7 +70,7 @@ def app_started_callback(demo: Blocks, app: FastAPI):
def model_loaded_callback(sd_model):
for c in callbacks_model_loaded:
for c in callback_map['callbacks_model_loaded']:
try:
c.callback(sd_model)
except Exception:
@@ -81,7 +80,7 @@ def model_loaded_callback(sd_model):
def ui_tabs_callback():
res = []
for c in callbacks_ui_tabs:
for c in callback_map['callbacks_ui_tabs']:
try:
res += c.callback() or []
except Exception:
@@ -91,7 +90,7 @@ def ui_tabs_callback():
def ui_settings_callback():
for c in callbacks_ui_settings:
for c in callback_map['callbacks_ui_settings']:
try:
c.callback()
except Exception:
@@ -99,7 +98,7 @@ def ui_settings_callback():
def before_image_saved_callback(params: ImageSaveParams):
for c in callbacks_before_image_saved:
for c in callback_map['callbacks_before_image_saved']:
try:
c.callback(params)
except Exception:
@@ -107,7 +106,7 @@ def before_image_saved_callback(params: ImageSaveParams):
def image_saved_callback(params: ImageSaveParams):
for c in callbacks_image_saved:
for c in callback_map['callbacks_image_saved']:
try:
c.callback(params)
except Exception:
@@ -115,7 +114,7 @@ def image_saved_callback(params: ImageSaveParams):
def cfg_denoiser_callback(params: CFGDenoiserParams):
for c in callbacks_cfg_denoiser:
for c in callback_map['callbacks_cfg_denoiser']:
try:
c.callback(params)
except Exception:
@@ -128,17 +127,33 @@ def add_callback(callbacks, fun):
callbacks.append(ScriptCallback(filename, fun))
def remove_current_script_callbacks():
stack = [x for x in inspect.stack() if x.filename != __file__]
filename = stack[0].filename if len(stack) > 0 else 'unknown file'
if filename == 'unknown file':
return
for callback_list in callback_map.values():
for callback_to_remove in [cb for cb in callback_list if cb.script == filename]:
callback_list.remove(callback_to_remove)
def remove_callbacks_for_function(callback_func):
for callback_list in callback_map.values():
for callback_to_remove in [cb for cb in callback_list if cb.callback == callback_func]:
callback_list.remove(callback_to_remove)
def on_app_started(callback):
"""register a function to be called when the webui started, the gradio `Block` component and
fastapi `FastAPI` object are passed as the arguments"""
add_callback(callbacks_app_started, callback)
add_callback(callback_map['callbacks_app_started'], callback)
def on_model_loaded(callback):
"""register a function to be called when the stable diffusion model is created; the model is
passed as an argument"""
add_callback(callbacks_model_loaded, callback)
add_callback(callback_map['callbacks_model_loaded'], callback)
def on_ui_tabs(callback):
@@ -151,13 +166,13 @@ def on_ui_tabs(callback):
title is tab text displayed to user in the UI
elem_id is HTML id for the tab
"""
add_callback(callbacks_ui_tabs, callback)
add_callback(callback_map['callbacks_ui_tabs'], callback)
def on_ui_settings(callback):
"""register a function to be called before UI settings are populated; add your settings
by using shared.opts.add_option(shared.OptionInfo(...)) """
add_callback(callbacks_ui_settings, callback)
add_callback(callback_map['callbacks_ui_settings'], callback)
def on_before_image_saved(callback):
@@ -165,7 +180,7 @@ def on_before_image_saved(callback):
The callback is called with one argument:
- params: ImageSaveParams - parameters the image is to be saved with. You can change fields in this object.
"""
add_callback(callbacks_before_image_saved, callback)
add_callback(callback_map['callbacks_before_image_saved'], callback)
def on_image_saved(callback):
@@ -173,7 +188,7 @@ def on_image_saved(callback):
The callback is called with one argument:
- params: ImageSaveParams - parameters the image was saved with. Changing fields in this object does nothing.
"""
add_callback(callbacks_image_saved, callback)
add_callback(callback_map['callbacks_image_saved'], callback)
def on_cfg_denoiser(callback):
@@ -181,5 +196,4 @@ def on_cfg_denoiser(callback):
The callback is called with one argument:
- params: CFGDenoiserParams - parameters to be passed to the inner model and sampling state details.
"""
add_callback(callbacks_cfg_denoiser, callback)
add_callback(callback_map['callbacks_cfg_denoiser'], callback)

View File

@@ -18,6 +18,9 @@ class Script:
args_to = None
alwayson = False
"""A gr.Group component that has all script's UI inside it"""
group = None
infotext_fields = None
"""if set in ui(), this is a list of pairs of gradio component + text; the text will be used when
parsing infotext to set the value for the component; see ui.py's txt2img_paste_fields for an example
@@ -70,6 +73,19 @@ class Script:
pass
def process_batch(self, p, *args, **kwargs):
"""
Same as process(), but called for every batch.
**kwargs will have those items:
- batch_number - index of current batch, from 0 to number of batches-1
- prompts - list of prompts for current batch; you can change contents of this list but changing the number of entries will likely break things
- seeds - list of seeds for current batch
- subseeds - list of subseeds for current batch
"""
pass
def postprocess(self, p, processed, *args):
"""
This function is called after processing ends for AlwaysVisible scripts.
@@ -218,8 +234,6 @@ class ScriptRunner:
for control in controls:
control.custom_script_source = os.path.basename(script.filename)
if not script.alwayson:
control.visible = False
if script.infotext_fields is not None:
self.infotext_fields += script.infotext_fields
@@ -229,40 +243,41 @@ class ScriptRunner:
script.args_to = len(inputs)
for script in self.alwayson_scripts:
with gr.Group():
with gr.Group() as group:
create_script_ui(script, inputs, inputs_alwayson)
script.group = group
dropdown = gr.Dropdown(label="Script", elem_id="script_list", choices=["None"] + self.titles, value="None", type="index")
dropdown.save_to_config = True
inputs[0] = dropdown
for script in self.selectable_scripts:
create_script_ui(script, inputs, inputs_alwayson)
with gr.Group(visible=False) as group:
create_script_ui(script, inputs, inputs_alwayson)
script.group = group
def select_script(script_index):
if 0 < script_index <= len(self.selectable_scripts):
script = self.selectable_scripts[script_index-1]
args_from = script.args_from
args_to = script.args_to
else:
args_from = 0
args_to = 0
selected_script = self.selectable_scripts[script_index - 1] if script_index>0 else None
return [ui.gr_show(True if i == 0 else args_from <= i < args_to or is_alwayson) for i, is_alwayson in enumerate(inputs_alwayson)]
return [gr.update(visible=selected_script == s) for s in self.selectable_scripts]
def init_field(title):
"""called when an initial value is set from ui-config.json to show script's UI components"""
if title == 'None':
return
script_index = self.titles.index(title)
script = self.selectable_scripts[script_index]
for i in range(script.args_from, script.args_to):
inputs[i].visible = True
self.selectable_scripts[script_index].group.visible = True
dropdown.init_field = init_field
dropdown.change(
fn=select_script,
inputs=[dropdown],
outputs=inputs
outputs=[script.group for script in self.selectable_scripts]
)
return inputs
@@ -294,6 +309,15 @@ class ScriptRunner:
print(f"Error running process: {script.filename}", file=sys.stderr)
print(traceback.format_exc(), file=sys.stderr)
def process_batch(self, p, **kwargs):
for script in self.alwayson_scripts:
try:
script_args = p.script_args[script.args_from:script.args_to]
script.process_batch(p, *script_args, **kwargs)
except Exception:
print(f"Error running process_batch: {script.filename}", file=sys.stderr)
print(traceback.format_exc(), file=sys.stderr)
def postprocess(self, p, processed):
for script in self.alwayson_scripts:
try:

View File

@@ -9,7 +9,7 @@ from omegaconf import OmegaConf
from ldm.util import instantiate_from_config
from modules import shared, modelloader, devices, script_callbacks
from modules import shared, modelloader, devices, script_callbacks, sd_vae
from modules.paths import models_path
from modules.sd_hijack_inpainting import do_inpainting_hijack, should_hijack_inpainting
@@ -159,13 +159,16 @@ def get_state_dict_from_checkpoint(pl_sd):
return pl_sd
vae_ignore_keys = {"model_ema.decay", "model_ema.num_updates"}
def load_model_weights(model, checkpoint_info):
def load_model_weights(model, checkpoint_info, vae_file="auto"):
checkpoint_file = checkpoint_info.filename
sd_model_hash = checkpoint_info.hash
if shared.opts.sd_checkpoint_cache > 0 and hasattr(model, "sd_checkpoint_info"):
sd_vae.restore_base_vae(model)
checkpoints_loaded[model.sd_checkpoint_info] = model.state_dict().copy()
vae_file = sd_vae.resolve_vae(checkpoint_file, vae_file=vae_file)
if checkpoint_info not in checkpoints_loaded:
print(f"Loading weights [{sd_model_hash}] from {checkpoint_file}")
@@ -182,37 +185,36 @@ def load_model_weights(model, checkpoint_info):
model.to(memory_format=torch.channels_last)
if not shared.cmd_opts.no_half:
vae = model.first_stage_model
# with --no-half-vae, remove VAE from model when doing half() to prevent its weights from being converted to float16
if shared.cmd_opts.no_half_vae:
model.first_stage_model = None
model.half()
model.first_stage_model = vae
devices.dtype = torch.float32 if shared.cmd_opts.no_half else torch.float16
devices.dtype_vae = torch.float32 if shared.cmd_opts.no_half or shared.cmd_opts.no_half_vae else torch.float16
vae_file = os.path.splitext(checkpoint_file)[0] + ".vae.pt"
if not os.path.exists(vae_file) and shared.cmd_opts.vae_path is not None:
vae_file = shared.cmd_opts.vae_path
if os.path.exists(vae_file):
print(f"Loading VAE weights from: {vae_file}")
vae_ckpt = torch.load(vae_file, map_location=shared.weight_load_location)
vae_dict = {k: v for k, v in vae_ckpt["state_dict"].items() if k[0:4] != "loss" and k not in vae_ignore_keys}
model.first_stage_model.load_state_dict(vae_dict)
model.first_stage_model.to(devices.dtype_vae)
if shared.opts.sd_checkpoint_cache > 0:
checkpoints_loaded[checkpoint_info] = model.state_dict().copy()
while len(checkpoints_loaded) > shared.opts.sd_checkpoint_cache:
checkpoints_loaded.popitem(last=False) # LRU
else:
print(f"Loading weights [{sd_model_hash}] from cache")
checkpoints_loaded.move_to_end(checkpoint_info)
vae_name = sd_vae.get_filename(vae_file) if vae_file else None
vae_message = f" with {vae_name} VAE" if vae_name else ""
print(f"Loading weights [{sd_model_hash}]{vae_message} from cache")
model.load_state_dict(checkpoints_loaded[checkpoint_info])
if shared.opts.sd_checkpoint_cache > 0:
while len(checkpoints_loaded) > shared.opts.sd_checkpoint_cache:
checkpoints_loaded.popitem(last=False) # LRU
model.sd_model_hash = sd_model_hash
model.sd_model_checkpoint = checkpoint_file
model.sd_checkpoint_info = checkpoint_info
sd_vae.load_vae(model, vae_file)
def load_model(checkpoint_info=None):
from modules import lowvram, sd_hijack
@@ -263,7 +265,7 @@ def load_model(checkpoint_info=None):
def reload_model_weights(sd_model=None, info=None):
from modules import lowvram, devices, sd_hijack
checkpoint_info = info or select_checkpoint()
if not sd_model:
sd_model = shared.sd_model

View File

@@ -24,11 +24,15 @@ samplers_k_diffusion = [
('Heun', 'sample_heun', ['k_heun'], {}),
('DPM2', 'sample_dpm_2', ['k_dpm_2'], {}),
('DPM2 a', 'sample_dpm_2_ancestral', ['k_dpm_2_a'], {}),
('DPM++ 2S a', 'sample_dpmpp_2s_ancestral', ['k_dpmpp_2s_a'], {}),
('DPM++ 2M', 'sample_dpmpp_2m', ['k_dpmpp_2m'], {}),
('DPM fast', 'sample_dpm_fast', ['k_dpm_fast'], {}),
('DPM adaptive', 'sample_dpm_adaptive', ['k_dpm_ad'], {}),
('LMS Karras', 'sample_lms', ['k_lms_ka'], {'scheduler': 'karras'}),
('DPM2 Karras', 'sample_dpm_2', ['k_dpm_2_ka'], {'scheduler': 'karras'}),
('DPM2 a Karras', 'sample_dpm_2_ancestral', ['k_dpm_2_a_ka'], {'scheduler': 'karras'}),
('DPM++ 2S a Karras', 'sample_dpmpp_2s_ancestral', ['k_dpmpp_2s_a_ka'], {'scheduler': 'karras'}),
('DPM++ 2M Karras', 'sample_dpmpp_2m', ['k_dpmpp_2m_ka'], {'scheduler': 'karras'}),
]
samplers_data_k_diffusion = [
@@ -93,8 +97,8 @@ def single_sample_to_image(sample):
return Image.fromarray(x_sample)
def sample_to_image(samples):
return single_sample_to_image(samples[0])
def sample_to_image(samples, index=0):
return single_sample_to_image(samples[index])
def samples_to_image_grid(samples):

207
modules/sd_vae.py Normal file
View File

@@ -0,0 +1,207 @@
import torch
import os
from collections import namedtuple
from modules import shared, devices, script_callbacks
from modules.paths import models_path
import glob
model_dir = "Stable-diffusion"
model_path = os.path.abspath(os.path.join(models_path, model_dir))
vae_dir = "VAE"
vae_path = os.path.abspath(os.path.join(models_path, vae_dir))
vae_ignore_keys = {"model_ema.decay", "model_ema.num_updates"}
default_vae_dict = {"auto": "auto", "None": "None"}
default_vae_list = ["auto", "None"]
default_vae_values = [default_vae_dict[x] for x in default_vae_list]
vae_dict = dict(default_vae_dict)
vae_list = list(default_vae_list)
first_load = True
base_vae = None
loaded_vae_file = None
checkpoint_info = None
def get_base_vae(model):
if base_vae is not None and checkpoint_info == model.sd_checkpoint_info and model:
return base_vae
return None
def store_base_vae(model):
global base_vae, checkpoint_info
if checkpoint_info != model.sd_checkpoint_info:
base_vae = model.first_stage_model.state_dict().copy()
checkpoint_info = model.sd_checkpoint_info
def delete_base_vae():
global base_vae, checkpoint_info
base_vae = None
checkpoint_info = None
def restore_base_vae(model):
global base_vae, checkpoint_info
if base_vae is not None and checkpoint_info == model.sd_checkpoint_info:
load_vae_dict(model, base_vae)
delete_base_vae()
def get_filename(filepath):
return os.path.splitext(os.path.basename(filepath))[0]
def refresh_vae_list(vae_path=vae_path, model_path=model_path):
global vae_dict, vae_list
res = {}
candidates = [
*glob.iglob(os.path.join(model_path, '**/*.vae.ckpt'), recursive=True),
*glob.iglob(os.path.join(model_path, '**/*.vae.pt'), recursive=True),
*glob.iglob(os.path.join(vae_path, '**/*.ckpt'), recursive=True),
*glob.iglob(os.path.join(vae_path, '**/*.pt'), recursive=True)
]
if shared.cmd_opts.vae_path is not None and os.path.isfile(shared.cmd_opts.vae_path):
candidates.append(shared.cmd_opts.vae_path)
for filepath in candidates:
name = get_filename(filepath)
res[name] = filepath
vae_list.clear()
vae_list.extend(default_vae_list)
vae_list.extend(list(res.keys()))
vae_dict.clear()
vae_dict.update(res)
vae_dict.update(default_vae_dict)
return vae_list
def resolve_vae(checkpoint_file, vae_file="auto"):
global first_load, vae_dict, vae_list
# if vae_file argument is provided, it takes priority, but not saved
if vae_file and vae_file not in default_vae_list:
if not os.path.isfile(vae_file):
vae_file = "auto"
print("VAE provided as function argument doesn't exist")
# for the first load, if vae-path is provided, it takes priority, saved, and failure is reported
if first_load and shared.cmd_opts.vae_path is not None:
if os.path.isfile(shared.cmd_opts.vae_path):
vae_file = shared.cmd_opts.vae_path
shared.opts.data['sd_vae'] = get_filename(vae_file)
else:
print("VAE provided as command line argument doesn't exist")
# else, we load from settings
if vae_file == "auto" and shared.opts.sd_vae is not None:
# if saved VAE settings isn't recognized, fallback to auto
vae_file = vae_dict.get(shared.opts.sd_vae, "auto")
# if VAE selected but not found, fallback to auto
if vae_file not in default_vae_values and not os.path.isfile(vae_file):
vae_file = "auto"
print("Selected VAE doesn't exist")
# vae-path cmd arg takes priority for auto
if vae_file == "auto" and shared.cmd_opts.vae_path is not None:
if os.path.isfile(shared.cmd_opts.vae_path):
vae_file = shared.cmd_opts.vae_path
print("Using VAE provided as command line argument")
# if still not found, try look for ".vae.pt" beside model
model_path = os.path.splitext(checkpoint_file)[0]
if vae_file == "auto":
vae_file_try = model_path + ".vae.pt"
if os.path.isfile(vae_file_try):
vae_file = vae_file_try
print("Using VAE found beside selected model")
# if still not found, try look for ".vae.ckpt" beside model
if vae_file == "auto":
vae_file_try = model_path + ".vae.ckpt"
if os.path.isfile(vae_file_try):
vae_file = vae_file_try
print("Using VAE found beside selected model")
# No more fallbacks for auto
if vae_file == "auto":
vae_file = None
# Last check, just because
if vae_file and not os.path.exists(vae_file):
vae_file = None
return vae_file
def load_vae(model, vae_file=None):
global first_load, vae_dict, vae_list, loaded_vae_file
# save_settings = False
if vae_file:
print(f"Loading VAE weights from: {vae_file}")
vae_ckpt = torch.load(vae_file, map_location=shared.weight_load_location)
vae_dict_1 = {k: v for k, v in vae_ckpt["state_dict"].items() if k[0:4] != "loss" and k not in vae_ignore_keys}
load_vae_dict(model, vae_dict_1)
# If vae used is not in dict, update it
# It will be removed on refresh though
vae_opt = get_filename(vae_file)
if vae_opt not in vae_dict:
vae_dict[vae_opt] = vae_file
vae_list.append(vae_opt)
loaded_vae_file = vae_file
"""
# Save current VAE to VAE settings, maybe? will it work?
if save_settings:
if vae_file is None:
vae_opt = "None"
# shared.opts.sd_vae = vae_opt
"""
first_load = False
# don't call this from outside
def load_vae_dict(model, vae_dict_1=None):
if vae_dict_1:
store_base_vae(model)
model.first_stage_model.load_state_dict(vae_dict_1)
else:
restore_base_vae()
model.first_stage_model.to(devices.dtype_vae)
def reload_vae_weights(sd_model=None, vae_file="auto"):
from modules import lowvram, devices, sd_hijack
if not sd_model:
sd_model = shared.sd_model
checkpoint_info = sd_model.sd_checkpoint_info
checkpoint_file = checkpoint_info.filename
vae_file = resolve_vae(checkpoint_file, vae_file=vae_file)
if loaded_vae_file == vae_file:
return
if shared.cmd_opts.lowvram or shared.cmd_opts.medvram:
lowvram.send_everything_to_cpu()
else:
sd_model.to(devices.cpu)
sd_hijack.model_hijack.undo_hijack(sd_model)
load_vae(sd_model, vae_file)
sd_hijack.model_hijack.hijack(sd_model)
script_callbacks.model_loaded_callback(sd_model)
if not shared.cmd_opts.lowvram and not shared.cmd_opts.medvram:
sd_model.to(devices.device)
print(f"VAE Weights loaded.")
return sd_model

View File

@@ -4,6 +4,7 @@ import json
import os
import sys
from collections import OrderedDict
import time
import gradio as gr
import tqdm
@@ -14,7 +15,7 @@ import modules.memmon
import modules.sd_models
import modules.styles
import modules.devices as devices
from modules import sd_samplers, sd_models, localization
from modules import sd_samplers, sd_models, localization, sd_vae
from modules.hypernetworks import hypernetwork
from modules.paths import models_path, script_path, sd_path
@@ -43,6 +44,7 @@ parser.add_argument("--precision", type=str, help="evaluate at this precision",
parser.add_argument("--share", action='store_true', help="use share=True for gradio and make the UI accessible through their site")
parser.add_argument("--ngrok", type=str, help="ngrok authtoken, alternative to gradio --share", default=None)
parser.add_argument("--ngrok-region", type=str, help="The region in which ngrok should start.", default="us")
parser.add_argument("--enable-insecure-extension-access", action='store_true', help="enable extensions tab regardless of other options")
parser.add_argument("--codeformer-models-path", type=str, help="Path to directory with codeformer model file(s).", default=os.path.join(models_path, 'Codeformer'))
parser.add_argument("--gfpgan-models-path", type=str, help="Path to directory with GFPGAN model file(s).", default=os.path.join(models_path, 'GFPGAN'))
parser.add_argument("--esrgan-models-path", type=str, help="Path to directory with ESRGAN model file(s).", default=os.path.join(models_path, 'ESRGAN'))
@@ -84,6 +86,10 @@ parser.add_argument("--nowebui", action='store_true', help="use api=True to laun
parser.add_argument("--ui-debug-mode", action='store_true', help="Don't load model to quickly launch UI")
parser.add_argument("--device-id", type=str, help="Select the default CUDA device to use (export CUDA_VISIBLE_DEVICES=0,1,etc might be needed before)", default=None)
parser.add_argument("--administrator", action='store_true', help="Administrator rights", default=False)
parser.add_argument("--cors-allow-origins", type=str, help="Allowed CORS origins", default=None)
parser.add_argument("--tls-keyfile", type=str, help="Partially enables TLS, requires --tls-certfile to fully function", default=None)
parser.add_argument("--tls-certfile", type=str, help="Partially enables TLS, requires --tls-keyfile to fully function", default=None)
parser.add_argument("--server-name", type=str, help="Sets hostname of server", default=None)
cmd_opts = parser.parse_args()
restricted_opts = {
@@ -98,7 +104,7 @@ restricted_opts = {
"outdir_save",
}
cmd_opts.disable_extension_access = cmd_opts.share or cmd_opts.listen
cmd_opts.disable_extension_access = (cmd_opts.share or cmd_opts.listen) and not cmd_opts.enable_insecure_extension_access
devices.device, devices.device_interrogate, devices.device_gfpgan, devices.device_swinir, devices.device_esrgan, devices.device_scunet, devices.device_codeformer = \
(devices.cpu if any(y in cmd_opts.use_cpu for y in [x, 'all']) else devices.get_optimal_device() for x in ['sd', 'interrogate', 'gfpgan', 'swinir', 'esrgan', 'scunet', 'codeformer'])
@@ -135,6 +141,7 @@ class State:
current_image = None
current_image_sampling_step = 0
textinfo = None
time_start = None
need_restart = False
def skip(self):
@@ -144,6 +151,9 @@ class State:
self.interrupted = True
def nextjob(self):
if opts.show_progress_every_n_steps == -1:
self.do_set_current_image()
self.job_no += 1
self.sampling_step = 0
self.current_image_sampling_step = 0
@@ -172,6 +182,7 @@ class State:
self.skipped = False
self.interrupted = False
self.textinfo = None
self.time_start = time.time()
devices.torch_gc()
@@ -181,6 +192,24 @@ class State:
devices.torch_gc()
"""sets self.current_image from self.current_latent if enough sampling steps have been made after the last call to this"""
def set_current_image(self):
if self.sampling_step - self.current_image_sampling_step >= opts.show_progress_every_n_steps and opts.show_progress_every_n_steps > 0:
self.do_set_current_image()
def do_set_current_image(self):
if not parallel_processing_allowed:
return
if self.current_latent is None:
return
if opts.show_progress_grid:
self.current_image = sd_samplers.samples_to_image_grid(self.current_latent)
else:
self.current_image = sd_samplers.sample_to_image(self.current_latent)
self.current_image_sampling_step = self.sampling_step
state = State()
artist_db = modules.artists.ArtistsDatabase(os.path.join(script_path, 'artists.csv'))
@@ -238,6 +267,8 @@ options_templates.update(options_section(('saving-images', "Saving images/grids"
"enable_pnginfo": OptionInfo(True, "Save text information about generation parameters as chunks to png files"),
"save_txt": OptionInfo(False, "Create a text file next to every image with generation parameters."),
"save_images_before_face_restoration": OptionInfo(False, "Save a copy of image before doing face restoration."),
"save_images_before_highres_fix": OptionInfo(False, "Save a copy of image before applying highres fix."),
"save_images_before_color_correction": OptionInfo(False, "Save a copy of image before applying color correction to img2img results"),
"jpeg_quality": OptionInfo(80, "Quality for saved jpeg images", gr.Slider, {"minimum": 1, "maximum": 100, "step": 1}),
"export_for_4chan": OptionInfo(True, "If PNG image is larger than 4MB or any dimension is larger than 4000, downscale and save copy as JPG"),
@@ -291,6 +322,7 @@ options_templates.update(options_section(('system', "System"), {
options_templates.update(options_section(('training', "Training"), {
"unload_models_when_training": OptionInfo(False, "Move VAE and CLIP to RAM when training if possible. Saves VRAM."),
"shuffle_tags": OptionInfo(False, "Shuffleing tags by "," when create texts."),
"save_optimizer_state": OptionInfo(False, "Saves Optimizer state as separate *.optim file. Training can be resumed with HN itself and matching optim file."),
"dataset_filename_word_regex": OptionInfo("", "Filename word regex"),
"dataset_filename_join_string": OptionInfo(" ", "Filename join string"),
"training_image_repeats_per_epoch": OptionInfo(1, "Number of repeats for a single input image per epoch; used only for displaying epoch number", gr.Number, {"precision": 0}),
@@ -301,11 +333,11 @@ options_templates.update(options_section(('training', "Training"), {
options_templates.update(options_section(('sd', "Stable Diffusion"), {
"sd_model_checkpoint": OptionInfo(None, "Stable Diffusion checkpoint", gr.Dropdown, lambda: {"choices": modules.sd_models.checkpoint_tiles()}, refresh=sd_models.list_models),
"sd_checkpoint_cache": OptionInfo(0, "Checkpoints to cache in RAM", gr.Slider, {"minimum": 0, "maximum": 10, "step": 1}),
"sd_vae": OptionInfo("auto", "SD VAE", gr.Dropdown, lambda: {"choices": list(sd_vae.vae_list)}, refresh=sd_vae.refresh_vae_list),
"sd_hypernetwork": OptionInfo("None", "Hypernetwork", gr.Dropdown, lambda: {"choices": ["None"] + [x for x in hypernetworks.keys()]}, refresh=reload_hypernetworks),
"sd_hypernetwork_strength": OptionInfo(1.0, "Hypernetwork strength", gr.Slider, {"minimum": 0.0, "maximum": 1.0, "step": 0.001}),
"inpainting_mask_weight": OptionInfo(1.0, "Inpainting conditioning mask strength", gr.Slider, {"minimum": 0.0, "maximum": 1.0, "step": 0.01}),
"img2img_color_correction": OptionInfo(False, "Apply color correction to img2img results to match original colors."),
"save_images_before_color_correction": OptionInfo(False, "Save a copy of image before applying color correction to img2img results"),
"img2img_fix_steps": OptionInfo(False, "With img2img, do exactly the amount of steps the slider specifies (normally you'd do less with less denoising)."),
"enable_quantization": OptionInfo(False, "Enable quantization in K samplers for sharper and cleaner results. This may change existing seeds. Requires restart to apply."),
"enable_emphasis": OptionInfo(True, "Emphasis: use (text) to make model pay more attention to text and [text] to make it pay less attention"),
@@ -333,7 +365,7 @@ options_templates.update(options_section(('interrogate', "Interrogate Options"),
options_templates.update(options_section(('ui', "User interface"), {
"show_progressbar": OptionInfo(True, "Show progressbar"),
"show_progress_every_n_steps": OptionInfo(0, "Show image creation progress every N sampling steps. Set 0 to disable.", gr.Slider, {"minimum": 0, "maximum": 32, "step": 1}),
"show_progress_every_n_steps": OptionInfo(0, "Show image creation progress every N sampling steps. Set to 0 to disable. Set to -1 to show after completion of batch.", gr.Slider, {"minimum": -1, "maximum": 32, "step": 1}),
"show_progress_grid": OptionInfo(True, "Show previews of all images generated in a batch as a grid"),
"return_grid": OptionInfo(True, "Show grid in results for web"),
"do_not_show_images": OptionInfo(False, "Do not show any images in results for web"),
@@ -378,6 +410,16 @@ class Options:
def __setattr__(self, key, value):
if self.data is not None:
if key in self.data or key in self.data_labels:
assert not cmd_opts.freeze_settings, "changing settings is disabled"
info = opts.data_labels.get(key, None)
comp_args = info.component_args if info else None
if isinstance(comp_args, dict) and comp_args.get('visible', True) is False:
raise RuntimeError(f"not possible to set {key} because it is restricted")
if cmd_opts.hide_ui_dir_config and key in restricted_opts:
raise RuntimeError(f"not possible to set {key} because it is restricted")
self.data[key] = value
return
@@ -394,6 +436,8 @@ class Options:
return super(Options, self).__getattribute__(item)
def save(self, filename):
assert not cmd_opts.freeze_settings, "saving settings is disabled"
with open(filename, "w", encoding="utf8") as file:
json.dump(self.data, file, indent=4)
@@ -420,11 +464,12 @@ class Options:
if bad_settings > 0:
print(f"The program is likely to not work with bad settings.\nSettings file: {filename}\nEither fix the file, or delete it and restart.", file=sys.stderr)
def onchange(self, key, func):
def onchange(self, key, func, call=True):
item = self.data_labels.get(key)
item.onchange = func
func()
if call:
func()
def dumpjson(self):
d = {k: self.data.get(k, self.data_labels.get(k).default) for k in self.data_labels.keys()}

View File

@@ -276,16 +276,8 @@ def check_progress_call(id_part):
image = gr_show(False)
preview_visibility = gr_show(False)
if opts.show_progress_every_n_steps > 0:
if shared.parallel_processing_allowed:
if shared.state.sampling_step - shared.state.current_image_sampling_step >= opts.show_progress_every_n_steps and shared.state.current_latent is not None:
if opts.show_progress_grid:
shared.state.current_image = modules.sd_samplers.samples_to_image_grid(shared.state.current_latent)
else:
shared.state.current_image = modules.sd_samplers.sample_to_image(shared.state.current_latent)
shared.state.current_image_sampling_step = shared.state.sampling_step
if opts.show_progress_every_n_steps != 0:
shared.state.set_current_image()
image = shared.state.current_image
if image is None:
@@ -1060,9 +1052,11 @@ def create_ui(wrap_gradio_gpu_call):
extras_batch_output_dir = gr.Textbox(label="Output directory", **shared.hide_dirs, placeholder="Leave blank to save images to the default path.")
show_extras_results = gr.Checkbox(label='Show result images', value=True)
submit = gr.Button('Generate', elem_id="extras_generate", variant='primary')
with gr.Tabs(elem_id="extras_resize_mode"):
with gr.TabItem('Scale by'):
upscaling_resize = gr.Slider(minimum=1.0, maximum=4.0, step=0.05, label="Resize", value=2)
upscaling_resize = gr.Slider(minimum=1.0, maximum=8.0, step=0.05, label="Resize", value=4)
with gr.TabItem('Scale to'):
with gr.Group():
with gr.Row():
@@ -1087,8 +1081,6 @@ def create_ui(wrap_gradio_gpu_call):
with gr.Group():
upscale_before_face_fix = gr.Checkbox(label='Upscale Before Restoring Faces', value=False)
submit = gr.Button('Generate', elem_id="extras_generate", variant='primary')
result_images, html_info_x, html_info = create_output_panel("extras", opts.outdir_extras_samples)
submit.click(
@@ -1190,8 +1182,8 @@ def create_ui(wrap_gradio_gpu_call):
new_hypernetwork_name = gr.Textbox(label="Name")
new_hypernetwork_sizes = gr.CheckboxGroup(label="Modules", value=["768", "320", "640", "1280"], choices=["768", "320", "640", "1280"])
new_hypernetwork_layer_structure = gr.Textbox("1, 2, 1", label="Enter hypernetwork layer structure", placeholder="1st and last digit must be 1. ex:'1, 2, 1'")
new_hypernetwork_activation_func = gr.Dropdown(value="linear", label="Select activation function of hypernetwork", choices=modules.hypernetworks.ui.keys)
new_hypernetwork_initialization_option = gr.Dropdown(value = "Normal", label="Select Layer weights initialization. relu-like - Kaiming, sigmoid-like - Xavier is recommended", choices=["Normal", "KaimingUniform", "KaimingNormal", "XavierUniform", "XavierNormal"])
new_hypernetwork_activation_func = gr.Dropdown(value="linear", label="Select activation function of hypernetwork. Recommended : Swish / Linear(none)", choices=modules.hypernetworks.ui.keys)
new_hypernetwork_initialization_option = gr.Dropdown(value = "Normal", label="Select Layer weights initialization. Recommended: Kaiming for relu-like, Xavier for sigmoid-like, Normal otherwise", choices=["Normal", "KaimingUniform", "KaimingNormal", "XavierUniform", "XavierNormal"])
new_hypernetwork_add_layer_norm = gr.Checkbox(label="Add layer normalization")
new_hypernetwork_use_dropout = gr.Checkbox(label="Use dropout")
overwrite_old_hypernetwork = gr.Checkbox(value=False, label="Overwrite Old Hypernetwork")
@@ -1446,48 +1438,39 @@ def create_ui(wrap_gradio_gpu_call):
def run_settings(*args):
changed = 0
assert not shared.cmd_opts.freeze_settings, "changing settings is disabled"
for key, value, comp in zip(opts.data_labels.keys(), args, components):
if comp != dummy_component and not opts.same_type(value, opts.data_labels[key].default):
return f"Bad value for setting {key}: {value}; expecting {type(opts.data_labels[key].default).__name__}", opts.dumpjson()
assert comp == dummy_component or opts.same_type(value, opts.data_labels[key].default), f"Bad value for setting {key}: {value}; expecting {type(opts.data_labels[key].default).__name__}"
for key, value, comp in zip(opts.data_labels.keys(), args, components):
if comp == dummy_component:
continue
comp_args = opts.data_labels[key].component_args
if comp_args and isinstance(comp_args, dict) and comp_args.get('visible') is False:
continue
if cmd_opts.hide_ui_dir_config and key in restricted_opts:
continue
oldval = opts.data.get(key, None)
opts.data[key] = value
try:
setattr(opts, key, value)
except RuntimeError:
continue
if oldval != value:
if opts.data_labels[key].onchange is not None:
opts.data_labels[key].onchange()
changed += 1
opts.save(shared.config_filename)
return f'{changed} settings changed.', opts.dumpjson()
try:
opts.save(shared.config_filename)
except RuntimeError:
return opts.dumpjson(), f'{changed} settings changed without save.'
return opts.dumpjson(), f'{changed} settings changed.'
def run_settings_single(value, key):
assert not shared.cmd_opts.freeze_settings, "changing settings is disabled"
if not opts.same_type(value, opts.data_labels[key].default):
return gr.update(visible=True), opts.dumpjson()
oldval = opts.data.get(key, None)
if cmd_opts.hide_ui_dir_config and key in restricted_opts:
try:
setattr(opts, key, value)
except Exception:
return gr.update(value=oldval), opts.dumpjson()
opts.data[key] = value
if oldval != value:
if opts.data_labels[key].onchange is not None:
opts.data_labels[key].onchange()
@@ -1640,9 +1623,9 @@ def create_ui(wrap_gradio_gpu_call):
text_settings = gr.Textbox(elem_id="settings_json", value=lambda: opts.dumpjson(), visible=False)
settings_submit.click(
fn=run_settings,
fn=wrap_gradio_call(run_settings, extra_outputs=[gr.update()]),
inputs=components,
outputs=[result, text_settings],
outputs=[text_settings, result],
)
for i, k, item in quicksettings_list:

View File

@@ -86,7 +86,7 @@ def extension_table():
code += f"""
<tr>
<td><label><input class="gr-check-radio gr-checkbox" name="enable_{html.escape(ext.name)}" type="checkbox" {'checked="checked"' if ext.enabled else ''}>{html.escape(ext.name)}</label></td>
<td><a href="{html.escape(ext.remote or '')}">{html.escape(ext.remote or '')}</a></td>
<td><a href="{html.escape(ext.remote or '')}" target="_blank">{html.escape(ext.remote or '')}</a></td>
<td{' class="extension_status"' if ext.remote is not None else ''}>{ext_status}</td>
</tr>
"""
@@ -188,7 +188,7 @@ def refresh_available_extensions_from_data():
code += f"""
<tr>
<td><a href="{html.escape(url)}">{html.escape(name)}</a></td>
<td><a href="{html.escape(url)}" target="_blank">{html.escape(name)}</a></td>
<td>{html.escape(description)}</td>
<td>{install_code}</td>
</tr>

View File

@@ -10,6 +10,7 @@ import modules.shared
from modules import modelloader, shared
LANCZOS = (Image.Resampling.LANCZOS if hasattr(Image, 'Resampling') else Image.LANCZOS)
NEAREST = (Image.Resampling.NEAREST if hasattr(Image, 'Resampling') else Image.NEAREST)
from modules.paths import models_path
@@ -56,10 +57,18 @@ class Upscaler:
self.scale = scale
dest_w = img.width * scale
dest_h = img.height * scale
for i in range(3):
shape = (img.width, img.height)
img = self.do_upscale(img, selected_model)
if shape == (img.width, img.height):
break
if img.width >= dest_w and img.height >= dest_h:
break
img = self.do_upscale(img, selected_model)
if img.width != dest_w or img.height != dest_h:
img = img.resize((int(dest_w), int(dest_h)), resample=LANCZOS)
@@ -120,3 +129,17 @@ class UpscalerLanczos(Upscaler):
self.name = "Lanczos"
self.scalers = [UpscalerData("Lanczos", None, self)]
class UpscalerNearest(Upscaler):
scalers = []
def do_upscale(self, img, selected_model=None):
return img.resize((int(img.width * self.scale), int(img.height * self.scale)), resample=NEAREST)
def load_model(self, _):
pass
def __init__(self, dirname=None):
super().__init__(False)
self.name = "Nearest"
self.scalers = [UpscalerData("Nearest", None, self)]