mirror of
https://github.com/AUTOMATIC1111/stable-diffusion-webui.git
synced 2026-03-13 17:10:23 +00:00
Merge branch 'master' of https://github.com/cluder/stable-diffusion-webui into 4448_fix_ckpt_cache
This commit is contained in:
@@ -63,6 +63,7 @@ class Api:
|
||||
self.app.add_api_route("/sdapi/v1/extra-batch-images", self.extras_batch_images_api, methods=["POST"], response_model=ExtrasBatchImagesResponse)
|
||||
self.app.add_api_route("/sdapi/v1/png-info", self.pnginfoapi, methods=["POST"], response_model=PNGInfoResponse)
|
||||
self.app.add_api_route("/sdapi/v1/progress", self.progressapi, methods=["GET"], response_model=ProgressResponse)
|
||||
self.app.add_api_route("/sdapi/v1/interrogate", self.interrogateapi, methods=["POST"])
|
||||
self.app.add_api_route("/sdapi/v1/interrupt", self.interruptapi, methods=["POST"])
|
||||
self.app.add_api_route("/sdapi/v1/options", self.get_config, methods=["GET"], response_model=OptionsModel)
|
||||
self.app.add_api_route("/sdapi/v1/options", self.set_config, methods=["POST"])
|
||||
@@ -214,6 +215,19 @@ class Api:
|
||||
|
||||
return ProgressResponse(progress=progress, eta_relative=eta_relative, state=shared.state.dict(), current_image=current_image)
|
||||
|
||||
def interrogateapi(self, interrogatereq: InterrogateRequest):
|
||||
image_b64 = interrogatereq.image
|
||||
if image_b64 is None:
|
||||
raise HTTPException(status_code=404, detail="Image not found")
|
||||
|
||||
img = self.__base64_to_image(image_b64)
|
||||
|
||||
# Override object param
|
||||
with self.queue_lock:
|
||||
processed = shared.interrogator.interrogate(img)
|
||||
|
||||
return InterrogateResponse(caption=processed)
|
||||
|
||||
def interruptapi(self):
|
||||
shared.state.interrupt()
|
||||
|
||||
|
||||
@@ -65,6 +65,7 @@ class PydanticModelGenerator:
|
||||
|
||||
self._model_name = model_name
|
||||
self._class_data = merge_class_params(class_instance)
|
||||
|
||||
self._model_def = [
|
||||
ModelDef(
|
||||
field=underscore(k),
|
||||
@@ -167,6 +168,12 @@ class ProgressResponse(BaseModel):
|
||||
state: dict = Field(title="State", description="The current state snapshot")
|
||||
current_image: str = Field(default=None, title="Current image", description="The current image in base64 format. opts.show_progress_every_n_steps is required for this to work.")
|
||||
|
||||
class InterrogateRequest(BaseModel):
|
||||
image: str = Field(default="", title="Image", description="Image to work on, must be a Base64 string containing the image's data.")
|
||||
|
||||
class InterrogateResponse(BaseModel):
|
||||
caption: str = Field(default=None, title="Caption", description="The generated caption for the image.")
|
||||
|
||||
fields = {}
|
||||
for key, value in opts.data.items():
|
||||
metadata = opts.data_labels.get(key)
|
||||
@@ -231,3 +238,4 @@ class ArtistItem(BaseModel):
|
||||
name: str = Field(title="Name")
|
||||
score: float = Field(title="Score")
|
||||
category: str = Field(title="Category")
|
||||
|
||||
|
||||
@@ -101,8 +101,8 @@ class LDSR:
|
||||
down_sample_rate = target_scale / 4
|
||||
wd = width_og * down_sample_rate
|
||||
hd = height_og * down_sample_rate
|
||||
width_downsampled_pre = int(wd)
|
||||
height_downsampled_pre = int(hd)
|
||||
width_downsampled_pre = int(np.ceil(wd))
|
||||
height_downsampled_pre = int(np.ceil(hd))
|
||||
|
||||
if down_sample_rate != 1:
|
||||
print(
|
||||
@@ -110,7 +110,12 @@ class LDSR:
|
||||
im_og = im_og.resize((width_downsampled_pre, height_downsampled_pre), Image.LANCZOS)
|
||||
else:
|
||||
print(f"Down sample rate is 1 from {target_scale} / 4 (Not downsampling)")
|
||||
logs = self.run(model["model"], im_og, diffusion_steps, eta)
|
||||
|
||||
# pad width and height to multiples of 64, pads with the edge values of image to avoid artifacts
|
||||
pad_w, pad_h = np.max(((2, 2), np.ceil(np.array(im_og.size) / 64).astype(int)), axis=0) * 64 - im_og.size
|
||||
im_padded = Image.fromarray(np.pad(np.array(im_og), ((0, pad_h), (0, pad_w), (0, 0)), mode='edge'))
|
||||
|
||||
logs = self.run(model["model"], im_padded, diffusion_steps, eta)
|
||||
|
||||
sample = logs["sample"]
|
||||
sample = sample.detach().cpu()
|
||||
@@ -120,6 +125,9 @@ class LDSR:
|
||||
sample = np.transpose(sample, (0, 2, 3, 1))
|
||||
a = Image.fromarray(sample[0])
|
||||
|
||||
# remove padding
|
||||
a = a.crop((0, 0) + tuple(np.array(im_og.size) * 4))
|
||||
|
||||
del model
|
||||
gc.collect()
|
||||
torch.cuda.empty_cache()
|
||||
|
||||
@@ -3,6 +3,7 @@ import os
|
||||
import sys
|
||||
import traceback
|
||||
|
||||
|
||||
localizations = {}
|
||||
|
||||
|
||||
@@ -16,6 +17,11 @@ def list_localizations(dirname):
|
||||
|
||||
localizations[fn] = os.path.join(dirname, file)
|
||||
|
||||
from modules import scripts
|
||||
for file in scripts.list_scripts("localizations", ".json"):
|
||||
fn, ext = os.path.splitext(file.filename)
|
||||
localizations[fn] = file.path
|
||||
|
||||
|
||||
def localization_js(current_localization_name):
|
||||
fn = localizations.get(current_localization_name, None)
|
||||
|
||||
@@ -23,11 +23,18 @@ def encode(*args):
|
||||
|
||||
|
||||
class RestrictedUnpickler(pickle.Unpickler):
|
||||
extra_handler = None
|
||||
|
||||
def persistent_load(self, saved_id):
|
||||
assert saved_id[0] == 'storage'
|
||||
return TypedStorage()
|
||||
|
||||
def find_class(self, module, name):
|
||||
if self.extra_handler is not None:
|
||||
res = self.extra_handler(module, name)
|
||||
if res is not None:
|
||||
return res
|
||||
|
||||
if module == 'collections' and name == 'OrderedDict':
|
||||
return getattr(collections, name)
|
||||
if module == 'torch._utils' and name in ['_rebuild_tensor_v2', '_rebuild_parameter']:
|
||||
@@ -52,7 +59,7 @@ class RestrictedUnpickler(pickle.Unpickler):
|
||||
return set
|
||||
|
||||
# Forbid everything else.
|
||||
raise pickle.UnpicklingError(f"global '{module}/{name}' is forbidden")
|
||||
raise Exception(f"global '{module}/{name}' is forbidden")
|
||||
|
||||
|
||||
allowed_zip_names = ["archive/data.pkl", "archive/version"]
|
||||
@@ -69,7 +76,7 @@ def check_zip_filenames(filename, names):
|
||||
raise Exception(f"bad file inside {filename}: {name}")
|
||||
|
||||
|
||||
def check_pt(filename):
|
||||
def check_pt(filename, extra_handler):
|
||||
try:
|
||||
|
||||
# new pytorch format is a zip file
|
||||
@@ -78,6 +85,7 @@ def check_pt(filename):
|
||||
|
||||
with z.open('archive/data.pkl') as file:
|
||||
unpickler = RestrictedUnpickler(file)
|
||||
unpickler.extra_handler = extra_handler
|
||||
unpickler.load()
|
||||
|
||||
except zipfile.BadZipfile:
|
||||
@@ -85,16 +93,42 @@ def check_pt(filename):
|
||||
# if it's not a zip file, it's an olf pytorch format, with five objects written to pickle
|
||||
with open(filename, "rb") as file:
|
||||
unpickler = RestrictedUnpickler(file)
|
||||
unpickler.extra_handler = extra_handler
|
||||
for i in range(5):
|
||||
unpickler.load()
|
||||
|
||||
|
||||
def load(filename, *args, **kwargs):
|
||||
return load_with_extra(filename, *args, **kwargs)
|
||||
|
||||
|
||||
def load_with_extra(filename, extra_handler=None, *args, **kwargs):
|
||||
"""
|
||||
this functon is intended to be used by extensions that want to load models with
|
||||
some extra classes in them that the usual unpickler would find suspicious.
|
||||
|
||||
Use the extra_handler argument to specify a function that takes module and field name as text,
|
||||
and returns that field's value:
|
||||
|
||||
```python
|
||||
def extra(module, name):
|
||||
if module == 'collections' and name == 'OrderedDict':
|
||||
return collections.OrderedDict
|
||||
|
||||
return None
|
||||
|
||||
safe.load_with_extra('model.pt', extra_handler=extra)
|
||||
```
|
||||
|
||||
The alternative to this is just to use safe.unsafe_torch_load('model.pt'), which as the name implies is
|
||||
definitely unsafe.
|
||||
"""
|
||||
|
||||
from modules import shared
|
||||
|
||||
try:
|
||||
if not shared.cmd_opts.disable_safe_unpickle:
|
||||
check_pt(filename)
|
||||
check_pt(filename, extra_handler)
|
||||
|
||||
except pickle.UnpicklingError:
|
||||
print(f"Error verifying pickled file from {filename}:", file=sys.stderr)
|
||||
|
||||
@@ -7,6 +7,7 @@ from typing import Optional
|
||||
from fastapi import FastAPI
|
||||
from gradio import Blocks
|
||||
|
||||
|
||||
def report_exception(c, job):
|
||||
print(f"Error executing callback {job} for {c.script}", file=sys.stderr)
|
||||
print(traceback.format_exc(), file=sys.stderr)
|
||||
@@ -45,15 +46,21 @@ class CFGDenoiserParams:
|
||||
"""Total number of sampling steps planned"""
|
||||
|
||||
|
||||
class UiTrainTabParams:
|
||||
def __init__(self, txt2img_preview_params):
|
||||
self.txt2img_preview_params = txt2img_preview_params
|
||||
|
||||
|
||||
ScriptCallback = namedtuple("ScriptCallback", ["script", "callback"])
|
||||
callback_map = dict(
|
||||
callbacks_app_started=[],
|
||||
callbacks_model_loaded=[],
|
||||
callbacks_ui_tabs=[],
|
||||
callbacks_ui_train_tabs=[],
|
||||
callbacks_ui_settings=[],
|
||||
callbacks_before_image_saved=[],
|
||||
callbacks_image_saved=[],
|
||||
callbacks_cfg_denoiser=[]
|
||||
callbacks_cfg_denoiser=[],
|
||||
)
|
||||
|
||||
|
||||
@@ -61,6 +68,7 @@ def clear_callbacks():
|
||||
for callback_list in callback_map.values():
|
||||
callback_list.clear()
|
||||
|
||||
|
||||
def app_started_callback(demo: Optional[Blocks], app: FastAPI):
|
||||
for c in callback_map['callbacks_app_started']:
|
||||
try:
|
||||
@@ -79,7 +87,7 @@ def model_loaded_callback(sd_model):
|
||||
|
||||
def ui_tabs_callback():
|
||||
res = []
|
||||
|
||||
|
||||
for c in callback_map['callbacks_ui_tabs']:
|
||||
try:
|
||||
res += c.callback() or []
|
||||
@@ -89,6 +97,14 @@ def ui_tabs_callback():
|
||||
return res
|
||||
|
||||
|
||||
def ui_train_tabs_callback(params: UiTrainTabParams):
|
||||
for c in callback_map['callbacks_ui_train_tabs']:
|
||||
try:
|
||||
c.callback(params)
|
||||
except Exception:
|
||||
report_exception(c, 'callbacks_ui_train_tabs')
|
||||
|
||||
|
||||
def ui_settings_callback():
|
||||
for c in callback_map['callbacks_ui_settings']:
|
||||
try:
|
||||
@@ -169,6 +185,13 @@ def on_ui_tabs(callback):
|
||||
add_callback(callback_map['callbacks_ui_tabs'], callback)
|
||||
|
||||
|
||||
def on_ui_train_tabs(callback):
|
||||
"""register a function to be called when the UI is creating new tabs for the train tab.
|
||||
Create your new tabs with gr.Tab.
|
||||
"""
|
||||
add_callback(callback_map['callbacks_ui_train_tabs'], callback)
|
||||
|
||||
|
||||
def on_ui_settings(callback):
|
||||
"""register a function to be called before UI settings are populated; add your settings
|
||||
by using shared.opts.add_option(shared.OptionInfo(...)) """
|
||||
|
||||
@@ -3,7 +3,6 @@ import sys
|
||||
import traceback
|
||||
from collections import namedtuple
|
||||
|
||||
import modules.ui as ui
|
||||
import gradio as gr
|
||||
|
||||
from modules.processing import StableDiffusionProcessing
|
||||
|
||||
@@ -221,8 +221,6 @@ interrogator = modules.interrogate.InterrogateModels("interrogate")
|
||||
|
||||
face_restorers = []
|
||||
|
||||
localization.list_localizations(cmd_opts.localizations_dir)
|
||||
|
||||
|
||||
def realesrgan_models_names():
|
||||
import modules.realesrgan_model
|
||||
|
||||
@@ -35,6 +35,84 @@ def preprocess(process_src, process_dst, process_width, process_height, preproce
|
||||
deepbooru.release_process()
|
||||
|
||||
|
||||
def listfiles(dirname):
|
||||
return os.listdir(dirname)
|
||||
|
||||
|
||||
class PreprocessParams:
|
||||
src = None
|
||||
dstdir = None
|
||||
subindex = 0
|
||||
flip = False
|
||||
process_caption = False
|
||||
process_caption_deepbooru = False
|
||||
preprocess_txt_action = None
|
||||
|
||||
|
||||
def save_pic_with_caption(image, index, params: PreprocessParams, existing_caption=None):
|
||||
caption = ""
|
||||
|
||||
if params.process_caption:
|
||||
caption += shared.interrogator.generate_caption(image)
|
||||
|
||||
if params.process_caption_deepbooru:
|
||||
if len(caption) > 0:
|
||||
caption += ", "
|
||||
caption += deepbooru.get_tags_from_process(image)
|
||||
|
||||
filename_part = params.src
|
||||
filename_part = os.path.splitext(filename_part)[0]
|
||||
filename_part = os.path.basename(filename_part)
|
||||
|
||||
basename = f"{index:05}-{params.subindex}-{filename_part}"
|
||||
image.save(os.path.join(params.dstdir, f"{basename}.png"))
|
||||
|
||||
if params.preprocess_txt_action == 'prepend' and existing_caption:
|
||||
caption = existing_caption + ' ' + caption
|
||||
elif params.preprocess_txt_action == 'append' and existing_caption:
|
||||
caption = caption + ' ' + existing_caption
|
||||
elif params.preprocess_txt_action == 'copy' and existing_caption:
|
||||
caption = existing_caption
|
||||
|
||||
caption = caption.strip()
|
||||
|
||||
if len(caption) > 0:
|
||||
with open(os.path.join(params.dstdir, f"{basename}.txt"), "w", encoding="utf8") as file:
|
||||
file.write(caption)
|
||||
|
||||
params.subindex += 1
|
||||
|
||||
|
||||
def save_pic(image, index, params, existing_caption=None):
|
||||
save_pic_with_caption(image, index, params, existing_caption=existing_caption)
|
||||
|
||||
if params.flip:
|
||||
save_pic_with_caption(ImageOps.mirror(image), index, params, existing_caption=existing_caption)
|
||||
|
||||
|
||||
def split_pic(image, inverse_xy, width, height, overlap_ratio):
|
||||
if inverse_xy:
|
||||
from_w, from_h = image.height, image.width
|
||||
to_w, to_h = height, width
|
||||
else:
|
||||
from_w, from_h = image.width, image.height
|
||||
to_w, to_h = width, height
|
||||
h = from_h * to_w // from_w
|
||||
if inverse_xy:
|
||||
image = image.resize((h, to_w))
|
||||
else:
|
||||
image = image.resize((to_w, h))
|
||||
|
||||
split_count = math.ceil((h - to_h * overlap_ratio) / (to_h * (1.0 - overlap_ratio)))
|
||||
y_step = (h - to_h) / (split_count - 1)
|
||||
for i in range(split_count):
|
||||
y = int(y_step * i)
|
||||
if inverse_xy:
|
||||
splitted = image.crop((y, 0, y + to_h, to_w))
|
||||
else:
|
||||
splitted = image.crop((0, y, to_w, y + to_h))
|
||||
yield splitted
|
||||
|
||||
|
||||
def preprocess_work(process_src, process_dst, process_width, process_height, preprocess_txt_action, process_flip, process_split, process_caption, process_caption_deepbooru=False, split_threshold=0.5, overlap_ratio=0.2, process_focal_crop=False, process_focal_crop_face_weight=0.9, process_focal_crop_entropy_weight=0.3, process_focal_crop_edges_weight=0.5, process_focal_crop_debug=False):
|
||||
width = process_width
|
||||
@@ -48,82 +126,28 @@ def preprocess_work(process_src, process_dst, process_width, process_height, pre
|
||||
|
||||
os.makedirs(dst, exist_ok=True)
|
||||
|
||||
files = os.listdir(src)
|
||||
files = listfiles(src)
|
||||
|
||||
shared.state.textinfo = "Preprocessing..."
|
||||
shared.state.job_count = len(files)
|
||||
|
||||
def save_pic_with_caption(image, index, existing_caption=None):
|
||||
caption = ""
|
||||
|
||||
if process_caption:
|
||||
caption += shared.interrogator.generate_caption(image)
|
||||
|
||||
if process_caption_deepbooru:
|
||||
if len(caption) > 0:
|
||||
caption += ", "
|
||||
caption += deepbooru.get_tags_from_process(image)
|
||||
|
||||
filename_part = filename
|
||||
filename_part = os.path.splitext(filename_part)[0]
|
||||
filename_part = os.path.basename(filename_part)
|
||||
|
||||
basename = f"{index:05}-{subindex[0]}-{filename_part}"
|
||||
image.save(os.path.join(dst, f"{basename}.png"))
|
||||
|
||||
if preprocess_txt_action == 'prepend' and existing_caption:
|
||||
caption = existing_caption + ' ' + caption
|
||||
elif preprocess_txt_action == 'append' and existing_caption:
|
||||
caption = caption + ' ' + existing_caption
|
||||
elif preprocess_txt_action == 'copy' and existing_caption:
|
||||
caption = existing_caption
|
||||
|
||||
caption = caption.strip()
|
||||
|
||||
if len(caption) > 0:
|
||||
with open(os.path.join(dst, f"{basename}.txt"), "w", encoding="utf8") as file:
|
||||
file.write(caption)
|
||||
|
||||
subindex[0] += 1
|
||||
|
||||
def save_pic(image, index, existing_caption=None):
|
||||
save_pic_with_caption(image, index, existing_caption=existing_caption)
|
||||
|
||||
if process_flip:
|
||||
save_pic_with_caption(ImageOps.mirror(image), index, existing_caption=existing_caption)
|
||||
|
||||
def split_pic(image, inverse_xy):
|
||||
if inverse_xy:
|
||||
from_w, from_h = image.height, image.width
|
||||
to_w, to_h = height, width
|
||||
else:
|
||||
from_w, from_h = image.width, image.height
|
||||
to_w, to_h = width, height
|
||||
h = from_h * to_w // from_w
|
||||
if inverse_xy:
|
||||
image = image.resize((h, to_w))
|
||||
else:
|
||||
image = image.resize((to_w, h))
|
||||
|
||||
split_count = math.ceil((h - to_h * overlap_ratio) / (to_h * (1.0 - overlap_ratio)))
|
||||
y_step = (h - to_h) / (split_count - 1)
|
||||
for i in range(split_count):
|
||||
y = int(y_step * i)
|
||||
if inverse_xy:
|
||||
splitted = image.crop((y, 0, y + to_h, to_w))
|
||||
else:
|
||||
splitted = image.crop((0, y, to_w, y + to_h))
|
||||
yield splitted
|
||||
|
||||
params = PreprocessParams()
|
||||
params.dstdir = dst
|
||||
params.flip = process_flip
|
||||
params.process_caption = process_caption
|
||||
params.process_caption_deepbooru = process_caption_deepbooru
|
||||
params.preprocess_txt_action = preprocess_txt_action
|
||||
|
||||
for index, imagefile in enumerate(tqdm.tqdm(files)):
|
||||
subindex = [0]
|
||||
params.subindex = 0
|
||||
filename = os.path.join(src, imagefile)
|
||||
try:
|
||||
img = Image.open(filename).convert("RGB")
|
||||
except Exception:
|
||||
continue
|
||||
|
||||
params.src = filename
|
||||
|
||||
existing_caption = None
|
||||
existing_caption_filename = os.path.splitext(filename)[0] + '.txt'
|
||||
if os.path.exists(existing_caption_filename):
|
||||
@@ -143,8 +167,8 @@ def preprocess_work(process_src, process_dst, process_width, process_height, pre
|
||||
process_default_resize = True
|
||||
|
||||
if process_split and ratio < 1.0 and ratio <= split_threshold:
|
||||
for splitted in split_pic(img, inverse_xy):
|
||||
save_pic(splitted, index, existing_caption=existing_caption)
|
||||
for splitted in split_pic(img, inverse_xy, width, height, overlap_ratio):
|
||||
save_pic(splitted, index, params, existing_caption=existing_caption)
|
||||
process_default_resize = False
|
||||
|
||||
if process_focal_crop and img.height != img.width:
|
||||
@@ -165,11 +189,11 @@ def preprocess_work(process_src, process_dst, process_width, process_height, pre
|
||||
dnn_model_path = dnn_model_path,
|
||||
)
|
||||
for focal in autocrop.crop_image(img, autocrop_settings):
|
||||
save_pic(focal, index, existing_caption=existing_caption)
|
||||
save_pic(focal, index, params, existing_caption=existing_caption)
|
||||
process_default_resize = False
|
||||
|
||||
if process_default_resize:
|
||||
img = images.resize_image(1, img, width, height)
|
||||
save_pic(img, index, existing_caption=existing_caption)
|
||||
save_pic(img, index, params, existing_caption=existing_caption)
|
||||
|
||||
shared.state.nextjob()
|
||||
shared.state.nextjob()
|
||||
|
||||
@@ -174,9 +174,9 @@ def save_pil_to_file(pil_image, dir=None):
|
||||
gr.processing_utils.save_pil_to_file = save_pil_to_file
|
||||
|
||||
|
||||
def wrap_gradio_call(func, extra_outputs=None):
|
||||
def wrap_gradio_call(func, extra_outputs=None, add_stats=False):
|
||||
def f(*args, extra_outputs_array=extra_outputs, **kwargs):
|
||||
run_memmon = opts.memmon_poll_rate > 0 and not shared.mem_mon.disabled
|
||||
run_memmon = opts.memmon_poll_rate > 0 and not shared.mem_mon.disabled and add_stats
|
||||
if run_memmon:
|
||||
shared.mem_mon.monitor()
|
||||
t = time.perf_counter()
|
||||
@@ -203,11 +203,18 @@ def wrap_gradio_call(func, extra_outputs=None):
|
||||
|
||||
res = extra_outputs_array + [f"<div class='error'>{plaintext_to_html(type(e).__name__+': '+str(e))}</div>"]
|
||||
|
||||
shared.state.skipped = False
|
||||
shared.state.interrupted = False
|
||||
shared.state.job_count = 0
|
||||
|
||||
if not add_stats:
|
||||
return tuple(res)
|
||||
|
||||
elapsed = time.perf_counter() - t
|
||||
elapsed_m = int(elapsed // 60)
|
||||
elapsed_s = elapsed % 60
|
||||
elapsed_text = f"{elapsed_s:.2f}s"
|
||||
if (elapsed_m > 0):
|
||||
if elapsed_m > 0:
|
||||
elapsed_text = f"{elapsed_m}m "+elapsed_text
|
||||
|
||||
if run_memmon:
|
||||
@@ -225,10 +232,6 @@ def wrap_gradio_call(func, extra_outputs=None):
|
||||
# last item is always HTML
|
||||
res[-1] += f"<div class='performance'><p class='time'>Time taken: <wbr>{elapsed_text}</p>{vram_html}</div>"
|
||||
|
||||
shared.state.skipped = False
|
||||
shared.state.interrupted = False
|
||||
shared.state.job_count = 0
|
||||
|
||||
return tuple(res)
|
||||
|
||||
return f
|
||||
@@ -1138,7 +1141,7 @@ def create_ui(wrap_gradio_gpu_call):
|
||||
outputs=[html, generation_info, html2],
|
||||
)
|
||||
|
||||
with gr.Blocks() as modelmerger_interface:
|
||||
with gr.Blocks(analytics_enabled=False) as modelmerger_interface:
|
||||
with gr.Row().style(equal_height=False):
|
||||
with gr.Column(variant='panel'):
|
||||
gr.HTML(value="<p>A merger of the two checkpoints will be generated in your <b>checkpoint</b> directory.</p>")
|
||||
@@ -1158,7 +1161,7 @@ def create_ui(wrap_gradio_gpu_call):
|
||||
|
||||
sd_hijack.model_hijack.embedding_db.load_textual_inversion_embeddings()
|
||||
|
||||
with gr.Blocks() as train_interface:
|
||||
with gr.Blocks(analytics_enabled=False) as train_interface:
|
||||
with gr.Row().style(equal_height=False):
|
||||
gr.HTML(value="<p style='margin-bottom: 0.7em'>See <b><a href=\"https://github.com/AUTOMATIC1111/stable-diffusion-webui/wiki/Textual-Inversion\">wiki</a></b> for detailed explanation.</p>")
|
||||
|
||||
@@ -1267,6 +1270,10 @@ def create_ui(wrap_gradio_gpu_call):
|
||||
train_hypernetwork = gr.Button(value="Train Hypernetwork", variant='primary')
|
||||
train_embedding = gr.Button(value="Train Embedding", variant='primary')
|
||||
|
||||
params = script_callbacks.UiTrainTabParams(txt2img_preview_params)
|
||||
|
||||
script_callbacks.ui_train_tabs_callback(params)
|
||||
|
||||
with gr.Column():
|
||||
progressbar = gr.HTML(elem_id="ti_progressbar")
|
||||
ti_output = gr.Text(elem_id="ti_output", value="", show_label=False)
|
||||
@@ -1417,15 +1424,14 @@ def create_ui(wrap_gradio_gpu_call):
|
||||
|
||||
if info.refresh is not None:
|
||||
if is_quicksettings:
|
||||
res = comp(label=info.label, value=fun, elem_id=elem_id, **(args or {}))
|
||||
res = comp(label=info.label, value=fun(), elem_id=elem_id, **(args or {}))
|
||||
create_refresh_button(res, info.refresh, info.component_args, "refresh_" + key)
|
||||
else:
|
||||
with gr.Row(variant="compact"):
|
||||
res = comp(label=info.label, value=fun, elem_id=elem_id, **(args or {}))
|
||||
res = comp(label=info.label, value=fun(), elem_id=elem_id, **(args or {}))
|
||||
create_refresh_button(res, info.refresh, info.component_args, "refresh_" + key)
|
||||
else:
|
||||
res = comp(label=info.label, value=fun, elem_id=elem_id, **(args or {}))
|
||||
|
||||
res = comp(label=info.label, value=fun(), elem_id=elem_id, **(args or {}))
|
||||
|
||||
return res
|
||||
|
||||
@@ -1436,7 +1442,7 @@ def create_ui(wrap_gradio_gpu_call):
|
||||
opts.reorder()
|
||||
|
||||
def run_settings(*args):
|
||||
changed = 0
|
||||
changed = []
|
||||
|
||||
for key, value, comp in zip(opts.data_labels.keys(), args, components):
|
||||
assert comp == dummy_component or opts.same_type(value, opts.data_labels[key].default), f"Bad value for setting {key}: {value}; expecting {type(opts.data_labels[key].default).__name__}"
|
||||
@@ -1454,12 +1460,12 @@ def create_ui(wrap_gradio_gpu_call):
|
||||
if opts.data_labels[key].onchange is not None:
|
||||
opts.data_labels[key].onchange()
|
||||
|
||||
changed += 1
|
||||
changed.append(key)
|
||||
try:
|
||||
opts.save(shared.config_filename)
|
||||
except RuntimeError:
|
||||
return opts.dumpjson(), f'{changed} settings changed without save.'
|
||||
return opts.dumpjson(), f'{changed} settings changed.'
|
||||
return opts.dumpjson(), f'{len(changed)} settings changed without save: {", ".join(changed)}.'
|
||||
return opts.dumpjson(), f'{len(changed)} settings changed: {", ".join(changed)}.'
|
||||
|
||||
def run_settings_single(value, key):
|
||||
if not opts.same_type(value, opts.data_labels[key].default):
|
||||
@@ -1563,11 +1569,10 @@ def create_ui(wrap_gradio_gpu_call):
|
||||
shared.state.need_restart = True
|
||||
|
||||
restart_gradio.click(
|
||||
|
||||
fn=request_restart,
|
||||
_js='restart_reload',
|
||||
inputs=[],
|
||||
outputs=[],
|
||||
_js='restart_reload'
|
||||
)
|
||||
|
||||
if column is not None:
|
||||
@@ -1637,6 +1642,17 @@ def create_ui(wrap_gradio_gpu_call):
|
||||
outputs=[component, text_settings],
|
||||
)
|
||||
|
||||
component_keys = [k for k in opts.data_labels.keys() if k in component_dict]
|
||||
|
||||
def get_settings_values():
|
||||
return [getattr(opts, key) for key in component_keys]
|
||||
|
||||
demo.load(
|
||||
fn=get_settings_values,
|
||||
inputs=[],
|
||||
outputs=[component_dict[k] for k in component_keys],
|
||||
)
|
||||
|
||||
def modelmerger(*args):
|
||||
try:
|
||||
results = modules.extras.run_modelmerger(*args)
|
||||
@@ -1740,7 +1756,7 @@ def create_ui(wrap_gradio_gpu_call):
|
||||
return demo
|
||||
|
||||
|
||||
def load_javascript(raw_response):
|
||||
def reload_javascript():
|
||||
with open(os.path.join(script_path, "script.js"), "r", encoding="utf8") as jsfile:
|
||||
javascript = f'<script>{jsfile.read()}</script>'
|
||||
|
||||
@@ -1756,7 +1772,7 @@ def load_javascript(raw_response):
|
||||
javascript += f"\n<script>{localization.localization_js(shared.opts.localization)}</script>"
|
||||
|
||||
def template_response(*args, **kwargs):
|
||||
res = raw_response(*args, **kwargs)
|
||||
res = shared.GradioTemplateResponseOriginal(*args, **kwargs)
|
||||
res.body = res.body.replace(
|
||||
b'</head>', f'{javascript}</head>'.encode("utf8"))
|
||||
res.init_headers()
|
||||
@@ -1765,4 +1781,5 @@ def load_javascript(raw_response):
|
||||
gradio.routes.templates.TemplateResponse = template_response
|
||||
|
||||
|
||||
reload_javascript = partial(load_javascript, gradio.routes.templates.TemplateResponse)
|
||||
if not hasattr(shared, 'GradioTemplateResponseOriginal'):
|
||||
shared.GradioTemplateResponseOriginal = gradio.routes.templates.TemplateResponse
|
||||
|
||||
@@ -140,13 +140,15 @@ def install_extension_from_url(dirname, url):
|
||||
shutil.rmtree(tmpdir, True)
|
||||
|
||||
|
||||
def install_extension_from_index(url):
|
||||
def install_extension_from_index(url, hide_tags):
|
||||
ext_table, message = install_extension_from_url(None, url)
|
||||
|
||||
return refresh_available_extensions_from_data(), ext_table, message
|
||||
code, _ = refresh_available_extensions_from_data(hide_tags)
|
||||
|
||||
return code, ext_table, message
|
||||
|
||||
|
||||
def refresh_available_extensions(url):
|
||||
def refresh_available_extensions(url, hide_tags):
|
||||
global available_extensions
|
||||
|
||||
import urllib.request
|
||||
@@ -155,13 +157,25 @@ def refresh_available_extensions(url):
|
||||
|
||||
available_extensions = json.loads(text)
|
||||
|
||||
return url, refresh_available_extensions_from_data(), ''
|
||||
code, tags = refresh_available_extensions_from_data(hide_tags)
|
||||
|
||||
return url, code, gr.CheckboxGroup.update(choices=tags), ''
|
||||
|
||||
|
||||
def refresh_available_extensions_from_data():
|
||||
def refresh_available_extensions_for_tags(hide_tags):
|
||||
code, _ = refresh_available_extensions_from_data(hide_tags)
|
||||
|
||||
return code, ''
|
||||
|
||||
|
||||
def refresh_available_extensions_from_data(hide_tags):
|
||||
extlist = available_extensions["extensions"]
|
||||
installed_extension_urls = {normalize_git_url(extension.remote): extension.name for extension in extensions.extensions}
|
||||
|
||||
tags = available_extensions.get("tags", {})
|
||||
tags_to_hide = set(hide_tags)
|
||||
hidden = 0
|
||||
|
||||
code = f"""<!-- {time.time()} -->
|
||||
<table id="available_extensions">
|
||||
<thead>
|
||||
@@ -178,17 +192,24 @@ def refresh_available_extensions_from_data():
|
||||
name = ext.get("name", "noname")
|
||||
url = ext.get("url", None)
|
||||
description = ext.get("description", "")
|
||||
extension_tags = ext.get("tags", [])
|
||||
|
||||
if url is None:
|
||||
continue
|
||||
|
||||
if len([x for x in extension_tags if x in tags_to_hide]) > 0:
|
||||
hidden += 1
|
||||
continue
|
||||
|
||||
existing = installed_extension_urls.get(normalize_git_url(url), None)
|
||||
|
||||
install_code = f"""<input onclick="install_extension_from_index(this, '{html.escape(url)}')" type="button" value="{"Install" if not existing else "Installed"}" {"disabled=disabled" if existing else ""} class="gr-button gr-button-lg gr-button-secondary">"""
|
||||
|
||||
tags_text = ", ".join([f"<span class='extension-tag' title='{tags.get(x, '')}'>{x}</span>" for x in extension_tags])
|
||||
|
||||
code += f"""
|
||||
<tr>
|
||||
<td><a href="{html.escape(url)}" target="_blank">{html.escape(name)}</a></td>
|
||||
<td><a href="{html.escape(url)}" target="_blank">{html.escape(name)}</a><br />{tags_text}</td>
|
||||
<td>{html.escape(description)}</td>
|
||||
<td>{install_code}</td>
|
||||
</tr>
|
||||
@@ -199,7 +220,10 @@ def refresh_available_extensions_from_data():
|
||||
</table>
|
||||
"""
|
||||
|
||||
return code
|
||||
if hidden > 0:
|
||||
code += f"<p>Extension hidden: {hidden}</p>"
|
||||
|
||||
return code, list(tags)
|
||||
|
||||
|
||||
def create_ui():
|
||||
@@ -238,21 +262,30 @@ def create_ui():
|
||||
extension_to_install = gr.Text(elem_id="extension_to_install", visible=False)
|
||||
install_extension_button = gr.Button(elem_id="install_extension_button", visible=False)
|
||||
|
||||
with gr.Row():
|
||||
hide_tags = gr.CheckboxGroup(value=["ads", "localization"], label="Hide extensions with tags", choices=["script", "ads", "localization"])
|
||||
|
||||
install_result = gr.HTML()
|
||||
available_extensions_table = gr.HTML()
|
||||
|
||||
refresh_available_extensions_button.click(
|
||||
fn=modules.ui.wrap_gradio_call(refresh_available_extensions, extra_outputs=[gr.update(), gr.update()]),
|
||||
inputs=[available_extensions_index],
|
||||
outputs=[available_extensions_index, available_extensions_table, install_result],
|
||||
fn=modules.ui.wrap_gradio_call(refresh_available_extensions, extra_outputs=[gr.update(), gr.update(), gr.update()]),
|
||||
inputs=[available_extensions_index, hide_tags],
|
||||
outputs=[available_extensions_index, available_extensions_table, hide_tags, install_result],
|
||||
)
|
||||
|
||||
install_extension_button.click(
|
||||
fn=modules.ui.wrap_gradio_call(install_extension_from_index, extra_outputs=[gr.update(), gr.update()]),
|
||||
inputs=[extension_to_install],
|
||||
inputs=[extension_to_install, hide_tags],
|
||||
outputs=[available_extensions_table, extensions_table, install_result],
|
||||
)
|
||||
|
||||
hide_tags.change(
|
||||
fn=modules.ui.wrap_gradio_call(refresh_available_extensions_for_tags, extra_outputs=[gr.update()]),
|
||||
inputs=[hide_tags],
|
||||
outputs=[available_extensions_table, install_result]
|
||||
)
|
||||
|
||||
with gr.TabItem("Install from URL"):
|
||||
install_url = gr.Text(label="URL for extension's git repository")
|
||||
install_dirname = gr.Text(label="Local directory name", placeholder="Leave empty for auto")
|
||||
|
||||
Reference in New Issue
Block a user