mirror of
https://github.com/AUTOMATIC1111/stable-diffusion-webui.git
synced 2026-02-28 02:34:18 +00:00
Merge branch 'AUTOMATIC1111:master' into js
This commit is contained in:
@@ -1,6 +1,8 @@
|
||||
import base64
|
||||
import io
|
||||
import time
|
||||
import uvicorn
|
||||
from gradio.processing_utils import encode_pil_to_base64, decode_base64_to_file, decode_base64_to_image
|
||||
from gradio.processing_utils import decode_base64_to_file, decode_base64_to_image
|
||||
from fastapi import APIRouter, Depends, HTTPException
|
||||
import modules.shared as shared
|
||||
from modules import devices
|
||||
@@ -29,6 +31,12 @@ def setUpscalers(req: dict):
|
||||
return reqDict
|
||||
|
||||
|
||||
def encode_pil_to_base64(image):
|
||||
buffer = io.BytesIO()
|
||||
image.save(buffer, format="png")
|
||||
return base64.b64encode(buffer.getvalue())
|
||||
|
||||
|
||||
class Api:
|
||||
def __init__(self, app, queue_lock):
|
||||
self.router = APIRouter()
|
||||
@@ -40,6 +48,7 @@ class Api:
|
||||
self.app.add_api_route("/sdapi/v1/extra-batch-images", self.extras_batch_images_api, methods=["POST"], response_model=ExtrasBatchImagesResponse)
|
||||
self.app.add_api_route("/sdapi/v1/png-info", self.pnginfoapi, methods=["POST"], response_model=PNGInfoResponse)
|
||||
self.app.add_api_route("/sdapi/v1/progress", self.progressapi, methods=["GET"], response_model=ProgressResponse)
|
||||
self.app.add_api_route("/sdapi/v1/interrupt", self.interruptapi, methods=["POST"])
|
||||
|
||||
def text2imgapi(self, txt2imgreq: StableDiffusionTxt2ImgProcessingAPI):
|
||||
sampler_index = sampler_to_index(txt2imgreq.sampler_index)
|
||||
@@ -176,6 +185,11 @@ class Api:
|
||||
|
||||
return ProgressResponse(progress=progress, eta_relative=eta_relative, state=shared.state.dict(), current_image=current_image)
|
||||
|
||||
def interruptapi(self):
|
||||
shared.state.interrupt()
|
||||
|
||||
return {}
|
||||
|
||||
def launch(self, server_name, port):
|
||||
self.app.include_router(self.router)
|
||||
uvicorn.run(self.app, host=server_name, port=port)
|
||||
|
||||
@@ -141,7 +141,7 @@ def run_extras(extras_mode, resize_mode, image, image_folder, input_dir, output_
|
||||
upscaling_resize_w, upscaling_resize_h, upscaling_crop)
|
||||
cache_key = LruCache.Key(image_hash=hash(np.array(image.getdata()).tobytes()),
|
||||
info_hash=hash(info),
|
||||
args_hash=hash(upscale_args))
|
||||
args_hash=hash((upscale_args, upscale_first)))
|
||||
cached_entry = cached_images.get(cache_key)
|
||||
if cached_entry is None:
|
||||
res = upscale(image, *upscale_args)
|
||||
|
||||
@@ -510,8 +510,9 @@ def save_image(image, path, basename, seed=None, prompt=None, extension='png', i
|
||||
|
||||
if extension.lower() == '.png':
|
||||
pnginfo_data = PngImagePlugin.PngInfo()
|
||||
for k, v in params.pnginfo.items():
|
||||
pnginfo_data.add_text(k, str(v))
|
||||
if opts.enable_pnginfo:
|
||||
for k, v in params.pnginfo.items():
|
||||
pnginfo_data.add_text(k, str(v))
|
||||
|
||||
image.save(fullfn, quality=opts.jpeg_quality, pnginfo=pnginfo_data)
|
||||
|
||||
|
||||
@@ -55,6 +55,7 @@ def process_batch(p, input_dir, output_dir, args):
|
||||
filename = f"{left}-{n}{right}"
|
||||
|
||||
if not save_normally:
|
||||
os.makedirs(output_dir, exist_ok=True)
|
||||
processed_image.save(os.path.join(output_dir, filename))
|
||||
|
||||
|
||||
|
||||
@@ -56,9 +56,9 @@ class InterrogateModels:
|
||||
import clip
|
||||
|
||||
if self.running_on_cpu:
|
||||
model, preprocess = clip.load(clip_model_name, device="cpu")
|
||||
model, preprocess = clip.load(clip_model_name, device="cpu", download_root=shared.cmd_opts.clip_models_path)
|
||||
else:
|
||||
model, preprocess = clip.load(clip_model_name)
|
||||
model, preprocess = clip.load(clip_model_name, download_root=shared.cmd_opts.clip_models_path)
|
||||
|
||||
model.eval()
|
||||
model = model.to(devices.device_interrogate)
|
||||
|
||||
@@ -38,13 +38,18 @@ def setup_for_low_vram(sd_model, use_medvram):
|
||||
# see below for register_forward_pre_hook;
|
||||
# first_stage_model does not use forward(), it uses encode/decode, so register_forward_pre_hook is
|
||||
# useless here, and we just replace those methods
|
||||
def first_stage_model_encode_wrap(self, encoder, x):
|
||||
send_me_to_gpu(self, None)
|
||||
return encoder(x)
|
||||
|
||||
def first_stage_model_decode_wrap(self, decoder, z):
|
||||
send_me_to_gpu(self, None)
|
||||
return decoder(z)
|
||||
first_stage_model = sd_model.first_stage_model
|
||||
first_stage_model_encode = sd_model.first_stage_model.encode
|
||||
first_stage_model_decode = sd_model.first_stage_model.decode
|
||||
|
||||
def first_stage_model_encode_wrap(x):
|
||||
send_me_to_gpu(first_stage_model, None)
|
||||
return first_stage_model_encode(x)
|
||||
|
||||
def first_stage_model_decode_wrap(z):
|
||||
send_me_to_gpu(first_stage_model, None)
|
||||
return first_stage_model_decode(z)
|
||||
|
||||
# remove three big modules, cond, first_stage, and unet from the model and then
|
||||
# send the model to GPU. Then put modules back. the modules will be in CPU.
|
||||
@@ -56,8 +61,8 @@ def setup_for_low_vram(sd_model, use_medvram):
|
||||
# register hooks for those the first two models
|
||||
sd_model.cond_stage_model.transformer.register_forward_pre_hook(send_me_to_gpu)
|
||||
sd_model.first_stage_model.register_forward_pre_hook(send_me_to_gpu)
|
||||
sd_model.first_stage_model.encode = lambda x, en=sd_model.first_stage_model.encode: first_stage_model_encode_wrap(sd_model.first_stage_model, en, x)
|
||||
sd_model.first_stage_model.decode = lambda z, de=sd_model.first_stage_model.decode: first_stage_model_decode_wrap(sd_model.first_stage_model, de, z)
|
||||
sd_model.first_stage_model.encode = first_stage_model_encode_wrap
|
||||
sd_model.first_stage_model.decode = first_stage_model_decode_wrap
|
||||
parents[sd_model.cond_stage_model.transformer] = sd_model.cond_stage_model
|
||||
|
||||
if use_medvram:
|
||||
|
||||
@@ -597,6 +597,9 @@ def process_images_inner(p: StableDiffusionProcessing) -> Processed:
|
||||
if p.scripts is not None:
|
||||
p.scripts.postprocess(p, res)
|
||||
|
||||
p.sd_model = None
|
||||
p.sampler = None
|
||||
|
||||
return res
|
||||
|
||||
|
||||
|
||||
@@ -32,7 +32,7 @@ class RestrictedUnpickler(pickle.Unpickler):
|
||||
return getattr(collections, name)
|
||||
if module == 'torch._utils' and name in ['_rebuild_tensor_v2', '_rebuild_parameter']:
|
||||
return getattr(torch._utils, name)
|
||||
if module == 'torch' and name in ['FloatStorage', 'HalfStorage', 'IntStorage', 'LongStorage', 'DoubleStorage']:
|
||||
if module == 'torch' and name in ['FloatStorage', 'HalfStorage', 'IntStorage', 'LongStorage', 'DoubleStorage', 'ByteStorage']:
|
||||
return getattr(torch, name)
|
||||
if module == 'torch.nn.modules.container' and name in ['ParameterDict']:
|
||||
return getattr(torch.nn.modules.container, name)
|
||||
|
||||
@@ -3,6 +3,8 @@ import traceback
|
||||
from collections import namedtuple
|
||||
import inspect
|
||||
|
||||
from fastapi import FastAPI
|
||||
from gradio import Blocks
|
||||
|
||||
def report_exception(c, job):
|
||||
print(f"Error executing callback {job} for {c.script}", file=sys.stderr)
|
||||
@@ -25,6 +27,7 @@ class ImageSaveParams:
|
||||
|
||||
|
||||
ScriptCallback = namedtuple("ScriptCallback", ["script", "callback"])
|
||||
callbacks_app_started = []
|
||||
callbacks_model_loaded = []
|
||||
callbacks_ui_tabs = []
|
||||
callbacks_ui_settings = []
|
||||
@@ -40,6 +43,14 @@ def clear_callbacks():
|
||||
callbacks_image_saved.clear()
|
||||
|
||||
|
||||
def app_started_callback(demo: Blocks, app: FastAPI):
|
||||
for c in callbacks_app_started:
|
||||
try:
|
||||
c.callback(demo, app)
|
||||
except Exception:
|
||||
report_exception(c, 'app_started_callback')
|
||||
|
||||
|
||||
def model_loaded_callback(sd_model):
|
||||
for c in callbacks_model_loaded:
|
||||
try:
|
||||
@@ -69,7 +80,7 @@ def ui_settings_callback():
|
||||
|
||||
|
||||
def before_image_saved_callback(params: ImageSaveParams):
|
||||
for c in callbacks_image_saved:
|
||||
for c in callbacks_before_image_saved:
|
||||
try:
|
||||
c.callback(params)
|
||||
except Exception:
|
||||
@@ -91,6 +102,12 @@ def add_callback(callbacks, fun):
|
||||
callbacks.append(ScriptCallback(filename, fun))
|
||||
|
||||
|
||||
def on_app_started(callback):
|
||||
"""register a function to be called when the webui started, the gradio `Block` component and
|
||||
fastapi `FastAPI` object are passed as the arguments"""
|
||||
add_callback(callbacks_app_started, callback)
|
||||
|
||||
|
||||
def on_model_loaded(callback):
|
||||
"""register a function to be called when the stable diffusion model is created; the model is
|
||||
passed as an argument"""
|
||||
|
||||
@@ -94,6 +94,10 @@ class StableDiffusionModelHijack:
|
||||
if type(model_embeddings.token_embedding) == EmbeddingsWithFixes:
|
||||
model_embeddings.token_embedding = model_embeddings.token_embedding.wrapped
|
||||
|
||||
self.layers = None
|
||||
self.circular_enabled = False
|
||||
self.clip = None
|
||||
|
||||
def apply_circular(self, enable):
|
||||
if self.circular_enabled == enable:
|
||||
return
|
||||
|
||||
@@ -1,6 +1,7 @@
|
||||
import collections
|
||||
import os.path
|
||||
import sys
|
||||
import gc
|
||||
from collections import namedtuple
|
||||
import torch
|
||||
import re
|
||||
@@ -220,6 +221,12 @@ def load_model(checkpoint_info=None):
|
||||
if checkpoint_info.config != shared.cmd_opts.config:
|
||||
print(f"Loading config from: {checkpoint_info.config}")
|
||||
|
||||
if shared.sd_model:
|
||||
sd_hijack.model_hijack.undo_hijack(shared.sd_model)
|
||||
shared.sd_model = None
|
||||
gc.collect()
|
||||
devices.torch_gc()
|
||||
|
||||
sd_config = OmegaConf.load(checkpoint_info.config)
|
||||
|
||||
if should_hijack_inpainting(checkpoint_info):
|
||||
@@ -233,6 +240,7 @@ def load_model(checkpoint_info=None):
|
||||
checkpoint_info = checkpoint_info._replace(config=checkpoint_info.config.replace(".yaml", "-inpainting.yaml"))
|
||||
|
||||
do_inpainting_hijack()
|
||||
|
||||
sd_model = instantiate_from_config(sd_config.model)
|
||||
load_model_weights(sd_model, checkpoint_info)
|
||||
|
||||
@@ -252,14 +260,18 @@ def load_model(checkpoint_info=None):
|
||||
return sd_model
|
||||
|
||||
|
||||
def reload_model_weights(sd_model, info=None):
|
||||
def reload_model_weights(sd_model=None, info=None):
|
||||
from modules import lowvram, devices, sd_hijack
|
||||
checkpoint_info = info or select_checkpoint()
|
||||
|
||||
if not sd_model:
|
||||
sd_model = shared.sd_model
|
||||
|
||||
if sd_model.sd_model_checkpoint == checkpoint_info.filename:
|
||||
return
|
||||
|
||||
if sd_model.sd_checkpoint_info.config != checkpoint_info.config or should_hijack_inpainting(checkpoint_info) != should_hijack_inpainting(sd_model.sd_checkpoint_info):
|
||||
del sd_model
|
||||
checkpoints_loaded.clear()
|
||||
load_model(checkpoint_info)
|
||||
return shared.sd_model
|
||||
|
||||
@@ -1,5 +1,6 @@
|
||||
from collections import namedtuple
|
||||
import numpy as np
|
||||
from math import floor
|
||||
import torch
|
||||
import tqdm
|
||||
from PIL import Image
|
||||
@@ -205,17 +206,22 @@ class VanillaStableDiffusionSampler:
|
||||
self.mask = p.mask if hasattr(p, 'mask') else None
|
||||
self.nmask = p.nmask if hasattr(p, 'nmask') else None
|
||||
|
||||
|
||||
def adjust_steps_if_invalid(self, p, num_steps):
|
||||
if (self.config.name == 'DDIM' and p.ddim_discretize == 'uniform') or (self.config.name == 'PLMS'):
|
||||
valid_step = 999 / (1000 // num_steps)
|
||||
if valid_step == floor(valid_step):
|
||||
return int(valid_step) + 1
|
||||
|
||||
return num_steps
|
||||
|
||||
|
||||
def sample_img2img(self, p, x, noise, conditioning, unconditional_conditioning, steps=None, image_conditioning=None):
|
||||
steps, t_enc = setup_img2img_steps(p, steps)
|
||||
|
||||
steps = self.adjust_steps_if_invalid(p, steps)
|
||||
self.initialize(p)
|
||||
|
||||
# existing code fails with certain step counts, like 9
|
||||
try:
|
||||
self.sampler.make_schedule(ddim_num_steps=steps, ddim_eta=self.eta, ddim_discretize=p.ddim_discretize, verbose=False)
|
||||
except Exception:
|
||||
self.sampler.make_schedule(ddim_num_steps=steps+1, ddim_eta=self.eta, ddim_discretize=p.ddim_discretize, verbose=False)
|
||||
|
||||
self.sampler.make_schedule(ddim_num_steps=steps, ddim_eta=self.eta, ddim_discretize=p.ddim_discretize, verbose=False)
|
||||
x1 = self.sampler.stochastic_encode(x, torch.tensor([t_enc] * int(x.shape[0])).to(shared.device), noise=noise)
|
||||
|
||||
self.init_latent = x
|
||||
@@ -239,18 +245,14 @@ class VanillaStableDiffusionSampler:
|
||||
self.last_latent = x
|
||||
self.step = 0
|
||||
|
||||
steps = steps or p.steps
|
||||
steps = self.adjust_steps_if_invalid(p, steps or p.steps)
|
||||
|
||||
# Wrap the conditioning models with additional image conditioning for inpainting model
|
||||
if image_conditioning is not None:
|
||||
conditioning = {"c_concat": [image_conditioning], "c_crossattn": [conditioning]}
|
||||
unconditional_conditioning = {"c_concat": [image_conditioning], "c_crossattn": [unconditional_conditioning]}
|
||||
|
||||
# existing code fails with certain step counts, like 9
|
||||
try:
|
||||
samples_ddim = self.launch_sampling(steps, lambda: self.sampler.sample(S=steps, conditioning=conditioning, batch_size=int(x.shape[0]), shape=x[0].shape, verbose=False, unconditional_guidance_scale=p.cfg_scale, unconditional_conditioning=unconditional_conditioning, x_T=x, eta=self.eta)[0])
|
||||
except Exception:
|
||||
samples_ddim = self.launch_sampling(steps, lambda: self.sampler.sample(S=steps+1, conditioning=conditioning, batch_size=int(x.shape[0]), shape=x[0].shape, verbose=False, unconditional_guidance_scale=p.cfg_scale, unconditional_conditioning=unconditional_conditioning, x_T=x, eta=self.eta)[0])
|
||||
samples_ddim = self.launch_sampling(steps, lambda: self.sampler.sample(S=steps, conditioning=conditioning, batch_size=int(x.shape[0]), shape=x[0].shape, verbose=False, unconditional_guidance_scale=p.cfg_scale, unconditional_conditioning=unconditional_conditioning, x_T=x, eta=self.eta)[0])
|
||||
|
||||
return samples_ddim
|
||||
|
||||
|
||||
@@ -51,6 +51,7 @@ parser.add_argument("--realesrgan-models-path", type=str, help="Path to director
|
||||
parser.add_argument("--scunet-models-path", type=str, help="Path to directory with ScuNET model file(s).", default=os.path.join(models_path, 'ScuNET'))
|
||||
parser.add_argument("--swinir-models-path", type=str, help="Path to directory with SwinIR model file(s).", default=os.path.join(models_path, 'SwinIR'))
|
||||
parser.add_argument("--ldsr-models-path", type=str, help="Path to directory with LDSR model file(s).", default=os.path.join(models_path, 'LDSR'))
|
||||
parser.add_argument("--clip-models-path", type=str, help="Path to directory with CLIP model file(s).", default=None)
|
||||
parser.add_argument("--xformers", action='store_true', help="enable xformers for cross attention layers")
|
||||
parser.add_argument("--force-enable-xformers", action='store_true', help="enable xformers for cross attention layers regardless of whether the checking code thinks you can run it; do not make bug reports if this fails to work")
|
||||
parser.add_argument("--deepdanbooru", action='store_true', help="enable deepdanbooru interrogator")
|
||||
@@ -288,11 +289,12 @@ options_templates.update(options_section(('system', "System"), {
|
||||
}))
|
||||
|
||||
options_templates.update(options_section(('training', "Training"), {
|
||||
"unload_models_when_training": OptionInfo(False, "Move VAE and CLIP to RAM when training hypernetwork. Saves VRAM."),
|
||||
"unload_models_when_training": OptionInfo(False, "Move VAE and CLIP to RAM when training if possible. Saves VRAM."),
|
||||
"dataset_filename_word_regex": OptionInfo("", "Filename word regex"),
|
||||
"dataset_filename_join_string": OptionInfo(" ", "Filename join string"),
|
||||
"training_image_repeats_per_epoch": OptionInfo(1, "Number of repeats for a single input image per epoch; used only for displaying epoch number", gr.Number, {"precision": 0}),
|
||||
"training_write_csv_every": OptionInfo(500, "Save an csv containing the loss to log directory every N steps, 0 to disable"),
|
||||
"training_xattention_optimizations": OptionInfo(False, "Use cross attention optimizations while training"),
|
||||
}))
|
||||
|
||||
options_templates.update(options_section(('sd', "Stable Diffusion"), {
|
||||
|
||||
@@ -235,6 +235,7 @@ def train_embedding(embedding_name, learn_rate, batch_size, data_root, log_direc
|
||||
filename = os.path.join(shared.cmd_opts.embeddings_dir, f'{embedding_name}.pt')
|
||||
|
||||
log_directory = os.path.join(log_directory, datetime.datetime.now().strftime("%Y-%m-%d"), embedding_name)
|
||||
unload = shared.opts.unload_models_when_training
|
||||
|
||||
if save_embedding_every > 0:
|
||||
embedding_dir = os.path.join(log_directory, "embeddings")
|
||||
@@ -272,6 +273,8 @@ def train_embedding(embedding_name, learn_rate, batch_size, data_root, log_direc
|
||||
shared.state.textinfo = f"Preparing dataset from {html.escape(data_root)}..."
|
||||
with torch.autocast("cuda"):
|
||||
ds = modules.textual_inversion.dataset.PersonalizedBase(data_root=data_root, width=training_width, height=training_height, repeats=shared.opts.training_image_repeats_per_epoch, placeholder_token=embedding_name, model=shared.sd_model, device=devices.device, template_file=template_file, batch_size=batch_size)
|
||||
if unload:
|
||||
shared.sd_model.first_stage_model.to(devices.cpu)
|
||||
|
||||
embedding.vec.requires_grad = True
|
||||
optimizer = torch.optim.AdamW([embedding.vec], lr=scheduler.learn_rate)
|
||||
@@ -328,6 +331,9 @@ def train_embedding(embedding_name, learn_rate, batch_size, data_root, log_direc
|
||||
if images_dir is not None and steps_done % create_image_every == 0:
|
||||
forced_filename = f'{embedding_name}-{steps_done}'
|
||||
last_saved_image = os.path.join(images_dir, forced_filename)
|
||||
|
||||
shared.sd_model.first_stage_model.to(devices.device)
|
||||
|
||||
p = processing.StableDiffusionProcessingTxt2Img(
|
||||
sd_model=shared.sd_model,
|
||||
do_not_save_grid=True,
|
||||
@@ -355,6 +361,9 @@ def train_embedding(embedding_name, learn_rate, batch_size, data_root, log_direc
|
||||
processed = processing.process_images(p)
|
||||
image = processed.images[0]
|
||||
|
||||
if unload:
|
||||
shared.sd_model.first_stage_model.to(devices.cpu)
|
||||
|
||||
shared.state.current_image = image
|
||||
|
||||
if save_image_with_stored_embedding and os.path.exists(last_saved_file) and embedding_yet_to_be_embedded:
|
||||
@@ -400,6 +409,7 @@ Last saved image: {html.escape(last_saved_image)}<br/>
|
||||
|
||||
filename = os.path.join(shared.cmd_opts.embeddings_dir, f'{embedding_name}.pt')
|
||||
save_embedding(embedding, checkpoint, embedding_name, filename, remove_cached_checksum=True)
|
||||
shared.sd_model.first_stage_model.to(devices.device)
|
||||
|
||||
return embedding, filename
|
||||
|
||||
|
||||
@@ -25,8 +25,10 @@ def train_embedding(*args):
|
||||
|
||||
assert not shared.cmd_opts.lowvram, 'Training models with lowvram not possible'
|
||||
|
||||
apply_optimizations = shared.opts.training_xattention_optimizations
|
||||
try:
|
||||
sd_hijack.undo_optimizations()
|
||||
if not apply_optimizations:
|
||||
sd_hijack.undo_optimizations()
|
||||
|
||||
embedding, filename = modules.textual_inversion.textual_inversion.train_embedding(*args)
|
||||
|
||||
@@ -38,5 +40,6 @@ Embedding saved to {html.escape(filename)}
|
||||
except Exception:
|
||||
raise
|
||||
finally:
|
||||
sd_hijack.apply_optimizations()
|
||||
if not apply_optimizations:
|
||||
sd_hijack.apply_optimizations()
|
||||
|
||||
|
||||
@@ -13,6 +13,9 @@ import html
|
||||
from modules import extensions, shared, paths
|
||||
|
||||
|
||||
available_extensions = {"extensions": []}
|
||||
|
||||
|
||||
def check_access():
|
||||
assert not shared.cmd_opts.disable_extension_access, "extension access disabed because of commandline flags"
|
||||
|
||||
@@ -96,6 +99,14 @@ def extension_table():
|
||||
return code
|
||||
|
||||
|
||||
def normalize_git_url(url):
|
||||
if url is None:
|
||||
return ""
|
||||
|
||||
url = url.replace(".git", "")
|
||||
return url
|
||||
|
||||
|
||||
def install_extension_from_url(dirname, url):
|
||||
check_access()
|
||||
|
||||
@@ -103,14 +114,15 @@ def install_extension_from_url(dirname, url):
|
||||
|
||||
if dirname is None or dirname == "":
|
||||
*parts, last_part = url.split('/')
|
||||
last_part = last_part.replace(".git", "")
|
||||
last_part = normalize_git_url(last_part)
|
||||
|
||||
dirname = last_part
|
||||
|
||||
target_dir = os.path.join(extensions.extensions_dir, dirname)
|
||||
assert not os.path.exists(target_dir), f'Extension directory already exists: {target_dir}'
|
||||
|
||||
assert len([x for x in extensions.extensions if x.remote == url]) == 0, 'Extension with this URL is already installed'
|
||||
normalized_url = normalize_git_url(url)
|
||||
assert len([x for x in extensions.extensions if normalize_git_url(x.remote) == normalized_url]) == 0, 'Extension with this URL is already installed'
|
||||
|
||||
tmpdir = os.path.join(paths.script_path, "tmp", dirname)
|
||||
|
||||
@@ -128,18 +140,80 @@ def install_extension_from_url(dirname, url):
|
||||
shutil.rmtree(tmpdir, True)
|
||||
|
||||
|
||||
def install_extension_from_index(url):
|
||||
ext_table, message = install_extension_from_url(None, url)
|
||||
|
||||
return refresh_available_extensions_from_data(), ext_table, message
|
||||
|
||||
|
||||
def refresh_available_extensions(url):
|
||||
global available_extensions
|
||||
|
||||
import urllib.request
|
||||
with urllib.request.urlopen(url) as response:
|
||||
text = response.read()
|
||||
|
||||
available_extensions = json.loads(text)
|
||||
|
||||
return url, refresh_available_extensions_from_data(), ''
|
||||
|
||||
|
||||
def refresh_available_extensions_from_data():
|
||||
extlist = available_extensions["extensions"]
|
||||
installed_extension_urls = {normalize_git_url(extension.remote): extension.name for extension in extensions.extensions}
|
||||
|
||||
code = f"""<!-- {time.time()} -->
|
||||
<table id="available_extensions">
|
||||
<thead>
|
||||
<tr>
|
||||
<th>Extension</th>
|
||||
<th>Description</th>
|
||||
<th>Action</th>
|
||||
</tr>
|
||||
</thead>
|
||||
<tbody>
|
||||
"""
|
||||
|
||||
for ext in extlist:
|
||||
name = ext.get("name", "noname")
|
||||
url = ext.get("url", None)
|
||||
description = ext.get("description", "")
|
||||
|
||||
if url is None:
|
||||
continue
|
||||
|
||||
existing = installed_extension_urls.get(normalize_git_url(url), None)
|
||||
|
||||
install_code = f"""<input onclick="install_extension_from_index(this, '{html.escape(url)}')" type="button" value="{"Install" if not existing else "Installed"}" {"disabled=disabled" if existing else ""} class="gr-button gr-button-lg gr-button-secondary">"""
|
||||
|
||||
code += f"""
|
||||
<tr>
|
||||
<td><a href="{html.escape(url)}">{html.escape(name)}</a></td>
|
||||
<td>{html.escape(description)}</td>
|
||||
<td>{install_code}</td>
|
||||
</tr>
|
||||
"""
|
||||
|
||||
code += """
|
||||
</tbody>
|
||||
</table>
|
||||
"""
|
||||
|
||||
return code
|
||||
|
||||
|
||||
def create_ui():
|
||||
import modules.ui
|
||||
|
||||
with gr.Blocks(analytics_enabled=False) as ui:
|
||||
with gr.Tabs(elem_id="tabs_extensions") as tabs:
|
||||
with gr.TabItem("Installed"):
|
||||
extensions_disabled_list = gr.Text(elem_id="extensions_disabled_list", visible=False)
|
||||
extensions_update_list = gr.Text(elem_id="extensions_update_list", visible=False)
|
||||
|
||||
with gr.Row():
|
||||
apply = gr.Button(value="Apply and restart UI", variant="primary")
|
||||
check = gr.Button(value="Check for updates")
|
||||
extensions_disabled_list = gr.Text(elem_id="extensions_disabled_list", visible=False).style(container=False)
|
||||
extensions_update_list = gr.Text(elem_id="extensions_update_list", visible=False).style(container=False)
|
||||
|
||||
extensions_table = gr.HTML(lambda: extension_table())
|
||||
|
||||
@@ -157,16 +231,38 @@ def create_ui():
|
||||
outputs=[extensions_table],
|
||||
)
|
||||
|
||||
with gr.TabItem("Available"):
|
||||
with gr.Row():
|
||||
refresh_available_extensions_button = gr.Button(value="Load from:", variant="primary")
|
||||
available_extensions_index = gr.Text(value="https://raw.githubusercontent.com/wiki/AUTOMATIC1111/stable-diffusion-webui/Extensions-index.md", label="Extension index URL").style(container=False)
|
||||
extension_to_install = gr.Text(elem_id="extension_to_install", visible=False)
|
||||
install_extension_button = gr.Button(elem_id="install_extension_button", visible=False)
|
||||
|
||||
install_result = gr.HTML()
|
||||
available_extensions_table = gr.HTML()
|
||||
|
||||
refresh_available_extensions_button.click(
|
||||
fn=modules.ui.wrap_gradio_call(refresh_available_extensions, extra_outputs=[gr.update(), gr.update()]),
|
||||
inputs=[available_extensions_index],
|
||||
outputs=[available_extensions_index, available_extensions_table, install_result],
|
||||
)
|
||||
|
||||
install_extension_button.click(
|
||||
fn=modules.ui.wrap_gradio_call(install_extension_from_index, extra_outputs=[gr.update(), gr.update()]),
|
||||
inputs=[extension_to_install],
|
||||
outputs=[available_extensions_table, extensions_table, install_result],
|
||||
)
|
||||
|
||||
with gr.TabItem("Install from URL"):
|
||||
install_url = gr.Text(label="URL for extension's git repository")
|
||||
install_dirname = gr.Text(label="Local directory name", placeholder="Leave empty for auto")
|
||||
intall_button = gr.Button(value="Install", variant="primary")
|
||||
intall_result = gr.HTML(elem_id="extension_install_result")
|
||||
install_button = gr.Button(value="Install", variant="primary")
|
||||
install_result = gr.HTML(elem_id="extension_install_result")
|
||||
|
||||
intall_button.click(
|
||||
install_button.click(
|
||||
fn=modules.ui.wrap_gradio_call(install_extension_from_url, extra_outputs=[gr.update()]),
|
||||
inputs=[install_dirname, install_url],
|
||||
outputs=[extensions_table, intall_result],
|
||||
outputs=[extensions_table, install_result],
|
||||
)
|
||||
|
||||
return ui
|
||||
|
||||
Reference in New Issue
Block a user