Mirror of https://github.com/lllyasviel/stable-diffusion-webui-forge.git (synced 2026-04-24 00:09:11 +00:00)

Merge branch 'main' into upt
@@ -2,11 +2,13 @@ import base64
import io
import os
import time
import itertools
import datetime
import uvicorn
import ipaddress
import requests
import gradio as gr
import numpy as np
from threading import Lock
from io import BytesIO
from fastapi import APIRouter, Depends, FastAPI, Request, Response
@@ -103,6 +105,8 @@ def encode_pil_to_base64(image):
with io.BytesIO() as output_bytes:
if isinstance(image, str):
return image
if isinstance(image, np.ndarray):
image = Image.fromarray(image)
if opts.samples_format.lower() == 'png':
use_metadata = False
metadata = PngImagePlugin.PngInfo()
@@ -480,7 +484,11 @@ class Api:
shared.state.end()
shared.total_tqdm.clear()

b64images = list(map(encode_pil_to_base64, processed.images)) if send_images else []
b64images = [
encode_pil_to_base64(image)
for image in itertools.chain(processed.images, processed.extra_images)
if send_images
]

return models.TextToImageResponse(images=b64images, parameters=vars(txt2imgreq), info=processed.js())

@@ -547,7 +555,11 @@ class Api:
shared.state.end()
shared.total_tqdm.clear()

b64images = list(map(encode_pil_to_base64, processed.images)) if send_images else []
b64images = [
encode_pil_to_base64(image)
for image in itertools.chain(processed.images, processed.extra_images)
if send_images
]

if not img2imgreq.include_init_images:
img2imgreq.init_images = None

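A minimal, self-contained sketch of the response-encoding change above: extra result images are chained after the regular ones, and the whole list collapses to an empty list when send_images is false. The stub class and the str encoder stand in for Processed and encode_pil_to_base64, which are only available inside the webui.

```python
import itertools

class ProcessedStub:
    """Illustrative stand-in for modules.processing.Processed."""
    def __init__(self, images, extra_images):
        self.images = images
        self.extra_images = extra_images

def collect_b64(processed, send_images, encode=str):
    # Mirrors the new comprehension: extra images ride along with the
    # regular ones, and the guard empties the list when images are not wanted.
    return [
        encode(image)
        for image in itertools.chain(processed.images, processed.extra_images)
        if send_images
    ]

p = ProcessedStub(images=["img0", "img1"], extra_images=["mask0"])
assert collect_b64(p, send_images=True) == ["img0", "img1", "mask0"]
assert collect_b64(p, send_images=False) == []
```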
@@ -2,8 +2,9 @@ import argparse
import json
import os
from modules.paths_internal import models_path, script_path, data_path, extensions_dir, extensions_builtin_dir, sd_default_config, sd_model_file # noqa: F401
from ldm_patched.modules import args_parser

parser = argparse.ArgumentParser()
parser = args_parser.parser

parser.add_argument("-f", action='store_true', help=argparse.SUPPRESS) # allows running as root; implemented outside of webui
parser.add_argument("--update-all-extensions", action='store_true', help="launch.py argument: download updates for all extensions when starting the program")

@@ -2,7 +2,7 @@ import os

from modules import modelloader, errors
from modules.shared import cmd_opts, opts
from modules.upscaler import Upscaler, UpscalerData
from modules.upscaler import Upscaler, UpscalerData, prepare_free_memory
from modules.upscaler_utils import upscale_with_model


@@ -23,6 +23,7 @@ class UpscalerDAT(Upscaler):
self.scalers.append(model)

def do_upscale(self, img, path):
prepare_free_memory()
try:
info = self.load_model(path)
except Exception:

@@ -4,7 +4,10 @@ import re
import torch
import numpy as np

from modules import modelloader, paths, deepbooru_model, devices, images, shared
from modules import modelloader, paths, deepbooru_model, images, shared
from ldm_patched.modules import model_management
from ldm_patched.modules.model_patcher import ModelPatcher


re_special = re.compile(r'([\\()])')

@@ -12,6 +15,14 @@ re_special = re.compile(r'([\\()])')
class DeepDanbooru:
def __init__(self):
self.model = None
self.load_device = model_management.text_encoder_device()
self.offload_device = model_management.text_encoder_offload_device()
self.dtype = torch.float32

if model_management.should_use_fp16(device=self.load_device):
self.dtype = torch.float16

self.patcher = None

def load(self):
if self.model is not None:
@@ -28,16 +39,16 @@ class DeepDanbooru:
self.model.load_state_dict(torch.load(files[0], map_location="cpu"))

self.model.eval()
self.model.to(devices.cpu, devices.dtype)
self.model.to(self.offload_device, self.dtype)

self.patcher = ModelPatcher(self.model, load_device=self.load_device, offload_device=self.offload_device)

def start(self):
self.load()
self.model.to(devices.device)
model_management.load_models_gpu([self.patcher])

def stop(self):
if not shared.opts.interrogate_keep_models_in_memory:
self.model.to(devices.cpu)
devices.torch_gc()
pass

def tag(self, pil_image):
self.start()
@@ -56,8 +67,8 @@ class DeepDanbooru:
pic = images.resize_image(2, pil_image.convert("RGB"), 512, 512)
a = np.expand_dims(np.array(pic, dtype=np.float32), 0) / 255

with torch.no_grad(), devices.autocast():
x = torch.from_numpy(a).to(devices.device)
with torch.no_grad():
x = torch.from_numpy(a).to(self.load_device, self.dtype)
y = self.model(x)[0].detach().cpu().numpy()

probability_dict = {}

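The DeepDanbooru rework above follows a pattern that recurs throughout this commit: keep the module on an offload device, wrap it in a ModelPatcher, and let model_management place it on the load device right before inference. A hedged sketch of that pattern, using the ldm_patched names from the diff (exact signatures may differ from the real API):

```python
import torch
from ldm_patched.modules import model_management
from ldm_patched.modules.model_patcher import ModelPatcher

class PatchedInterrogator:
    """Sketch only: mirrors the load/offload handling added in this commit."""
    def __init__(self, model: torch.nn.Module):
        self.load_device = model_management.text_encoder_device()             # where inference runs
        self.offload_device = model_management.text_encoder_offload_device()  # where idle weights live
        self.dtype = torch.float16 if model_management.should_use_fp16(device=self.load_device) else torch.float32
        self.model = model.eval().to(self.offload_device, self.dtype)
        self.patcher = ModelPatcher(self.model, load_device=self.load_device, offload_device=self.offload_device)

    def __call__(self, x: torch.Tensor) -> torch.Tensor:
        # Ask the memory manager to place the model before running it;
        # it is free to offload other models to make room.
        model_management.load_models_gpu([self.patcher])
        with torch.no_grad():
            return self.model(x.to(self.load_device, self.dtype))
```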
@@ -1,228 +1,93 @@
import sys
import contextlib
from functools import lru_cache

import torch
from modules import errors, shared, npu_specific

if sys.platform == "darwin":
from modules import mac_specific

if shared.cmd_opts.use_ipex:
from modules import xpu_specific
import ldm_patched.modules.model_management as model_management


def has_xpu() -> bool:
return shared.cmd_opts.use_ipex and xpu_specific.has_xpu
return model_management.xpu_available


def has_mps() -> bool:
if sys.platform != "darwin":
return False
else:
return mac_specific.has_mps
return model_management.mps_mode()


def cuda_no_autocast(device_id=None) -> bool:
if device_id is None:
device_id = get_cuda_device_id()
return (
torch.cuda.get_device_capability(device_id) == (7, 5)
and torch.cuda.get_device_name(device_id).startswith("NVIDIA GeForce GTX 16")
)
return False


def get_cuda_device_id():
return (
int(shared.cmd_opts.device_id)
if shared.cmd_opts.device_id is not None and shared.cmd_opts.device_id.isdigit()
else 0
) or torch.cuda.current_device()
return model_management.get_torch_device().index


def get_cuda_device_string():
|
||||
if shared.cmd_opts.device_id is not None:
|
||||
return f"cuda:{shared.cmd_opts.device_id}"
|
||||
|
||||
return "cuda"
|
||||
return str(model_management.get_torch_device())
|
||||
|
||||
|
||||
def get_optimal_device_name():
|
||||
if torch.cuda.is_available():
|
||||
return get_cuda_device_string()
|
||||
|
||||
if has_mps():
|
||||
return "mps"
|
||||
|
||||
if has_xpu():
|
||||
return xpu_specific.get_xpu_device_string()
|
||||
|
||||
if npu_specific.has_npu:
|
||||
return npu_specific.get_npu_device_string()
|
||||
|
||||
return "cpu"
|
||||
return model_management.get_torch_device().type
|
||||
|
||||
|
||||
def get_optimal_device():
|
||||
return torch.device(get_optimal_device_name())
|
||||
return model_management.get_torch_device()
|
||||
|
||||
|
||||
def get_device_for(task):
|
||||
if task in shared.cmd_opts.use_cpu or "all" in shared.cmd_opts.use_cpu:
|
||||
return cpu
|
||||
|
||||
return get_optimal_device()
|
||||
return model_management.get_torch_device()
|
||||
|
||||
|
||||
def torch_gc():
|
||||
|
||||
if torch.cuda.is_available():
|
||||
with torch.cuda.device(get_cuda_device_string()):
|
||||
torch.cuda.empty_cache()
|
||||
torch.cuda.ipc_collect()
|
||||
|
||||
if has_mps():
|
||||
mac_specific.torch_mps_gc()
|
||||
|
||||
if has_xpu():
|
||||
xpu_specific.torch_xpu_gc()
|
||||
|
||||
if npu_specific.has_npu:
|
||||
torch_npu_set_device()
|
||||
npu_specific.torch_npu_gc()
|
||||
model_management.soft_empty_cache()
|
||||
|
||||
|
||||
def torch_npu_set_device():
|
||||
# Work around due to bug in torch_npu, revert me after fixed, @see https://gitee.com/ascend/pytorch/issues/I8KECW?from=project-issue
|
||||
if npu_specific.has_npu:
|
||||
torch.npu.set_device(0)
|
||||
return
|
||||
|
||||
|
||||
def enable_tf32():
|
||||
if torch.cuda.is_available():
|
||||
return
|
||||
|
||||
# enabling benchmark option seems to enable a range of cards to do fp16 when they otherwise can't
|
||||
# see https://github.com/AUTOMATIC1111/stable-diffusion-webui/pull/4407
|
||||
if cuda_no_autocast():
|
||||
torch.backends.cudnn.benchmark = True
|
||||
|
||||
torch.backends.cuda.matmul.allow_tf32 = True
|
||||
torch.backends.cudnn.allow_tf32 = True
|
||||
|
||||
|
||||
errors.run(enable_tf32, "Enabling TF32")
|
||||
|
||||
cpu: torch.device = torch.device("cpu")
|
||||
fp8: bool = False
|
||||
device: torch.device = None
|
||||
device_interrogate: torch.device = None
|
||||
device_gfpgan: torch.device = None
|
||||
device_esrgan: torch.device = None
|
||||
device_codeformer: torch.device = None
|
||||
dtype: torch.dtype = torch.float16
|
||||
dtype_vae: torch.dtype = torch.float16
|
||||
dtype_unet: torch.dtype = torch.float16
|
||||
dtype_inference: torch.dtype = torch.float16
|
||||
device: torch.device = model_management.get_torch_device()
|
||||
device_interrogate: torch.device = cpu # not used
|
||||
device_gfpgan: torch.device = cpu
|
||||
device_esrgan: torch.device = model_management.get_torch_device() # will be managed in special way
|
||||
device_codeformer: torch.device = cpu
|
||||
dtype: torch.dtype = model_management.unet_dtype()
|
||||
dtype_vae: torch.dtype = model_management.vae_dtype()
|
||||
dtype_unet: torch.dtype = model_management.unet_dtype()
|
||||
dtype_inference: torch.dtype = model_management.unet_dtype()
|
||||
unet_needs_upcast = False
|
||||
|
||||
|
||||
def cond_cast_unet(input):
|
||||
return input.to(dtype_unet) if unet_needs_upcast else input
|
||||
return input
|
||||
|
||||
|
||||
def cond_cast_float(input):
|
||||
return input.float() if unet_needs_upcast else input
|
||||
return input
|
||||
|
||||
|
||||
nv_rng = None
|
||||
patch_module_list = [
|
||||
torch.nn.Linear,
|
||||
torch.nn.Conv2d,
|
||||
torch.nn.MultiheadAttention,
|
||||
torch.nn.GroupNorm,
|
||||
torch.nn.LayerNorm,
|
||||
]
|
||||
patch_module_list = []
|
||||
|
||||
|
||||
def manual_cast_forward(target_dtype):
|
||||
def forward_wrapper(self, *args, **kwargs):
|
||||
if any(
|
||||
isinstance(arg, torch.Tensor) and arg.dtype != target_dtype
|
||||
for arg in args
|
||||
):
|
||||
args = [arg.to(target_dtype) if isinstance(arg, torch.Tensor) else arg for arg in args]
|
||||
kwargs = {k: v.to(target_dtype) if isinstance(v, torch.Tensor) else v for k, v in kwargs.items()}
|
||||
|
||||
org_dtype = target_dtype
|
||||
for param in self.parameters():
|
||||
if param.dtype != target_dtype:
|
||||
org_dtype = param.dtype
|
||||
break
|
||||
|
||||
if org_dtype != target_dtype:
|
||||
self.to(target_dtype)
|
||||
result = self.org_forward(*args, **kwargs)
|
||||
if org_dtype != target_dtype:
|
||||
self.to(org_dtype)
|
||||
|
||||
if target_dtype != dtype_inference:
|
||||
if isinstance(result, tuple):
|
||||
result = tuple(
|
||||
i.to(dtype_inference)
|
||||
if isinstance(i, torch.Tensor)
|
||||
else i
|
||||
for i in result
|
||||
)
|
||||
elif isinstance(result, torch.Tensor):
|
||||
result = result.to(dtype_inference)
|
||||
return result
|
||||
return forward_wrapper
|
||||
return
|
||||
|
||||
|
||||
@contextlib.contextmanager
|
||||
def manual_cast(target_dtype):
|
||||
applied = False
|
||||
for module_type in patch_module_list:
|
||||
if hasattr(module_type, "org_forward"):
|
||||
continue
|
||||
applied = True
|
||||
org_forward = module_type.forward
|
||||
if module_type == torch.nn.MultiheadAttention:
|
||||
module_type.forward = manual_cast_forward(torch.float32)
|
||||
else:
|
||||
module_type.forward = manual_cast_forward(target_dtype)
|
||||
module_type.org_forward = org_forward
|
||||
try:
|
||||
yield None
|
||||
finally:
|
||||
if applied:
|
||||
for module_type in patch_module_list:
|
||||
if hasattr(module_type, "org_forward"):
|
||||
module_type.forward = module_type.org_forward
|
||||
delattr(module_type, "org_forward")
|
||||
return
|
||||
|
||||
|
||||
def autocast(disable=False):
|
||||
if disable:
|
||||
return contextlib.nullcontext()
|
||||
|
||||
if fp8 and device==cpu:
|
||||
return torch.autocast("cpu", dtype=torch.bfloat16, enabled=True)
|
||||
|
||||
if fp8 and dtype_inference == torch.float32:
|
||||
return manual_cast(dtype)
|
||||
|
||||
if dtype == torch.float32 or dtype_inference == torch.float32:
|
||||
return contextlib.nullcontext()
|
||||
|
||||
if has_xpu() or has_mps() or cuda_no_autocast():
|
||||
return manual_cast(dtype)
|
||||
|
||||
return torch.autocast("cuda")
|
||||
return contextlib.nullcontext()
|
||||
|
||||
|
||||
def without_autocast(disable=False):
|
||||
return torch.autocast("cuda", enabled=False) if torch.is_autocast_enabled() and not disable else contextlib.nullcontext()
|
||||
return contextlib.nullcontext()
|
||||
|
||||
|
||||
class NansException(Exception):
|
||||
@@ -230,42 +95,8 @@ class NansException(Exception):
|
||||
|
||||
|
||||
def test_for_nans(x, where):
|
||||
if shared.cmd_opts.disable_nan_check:
|
||||
return
|
||||
|
||||
if not torch.all(torch.isnan(x)).item():
|
||||
return
|
||||
|
||||
if where == "unet":
|
||||
message = "A tensor with all NaNs was produced in Unet."
|
||||
|
||||
if not shared.cmd_opts.no_half:
|
||||
message += " This could be either because there's not enough precision to represent the picture, or because your video card does not support half type. Try setting the \"Upcast cross attention layer to float32\" option in Settings > Stable Diffusion or using the --no-half commandline argument to fix this."
|
||||
|
||||
elif where == "vae":
|
||||
message = "A tensor with all NaNs was produced in VAE."
|
||||
|
||||
if not shared.cmd_opts.no_half and not shared.cmd_opts.no_half_vae:
|
||||
message += " This could be because there's not enough precision to represent the picture. Try adding --no-half-vae commandline argument to fix this."
|
||||
else:
|
||||
message = "A tensor with all NaNs was produced."
|
||||
|
||||
message += " Use --disable-nan-check commandline argument to disable this check."
|
||||
|
||||
raise NansException(message)
|
||||
return
|
||||
|
||||
|
||||
@lru_cache
def first_time_calculation():
"""
just do any calculation with pytorch layers - the first time this is done it allocates about 700MB of memory and
spends about 2.7 seconds doing that, at least with NVidia.
"""

x = torch.zeros((1, 1)).to(device, dtype)
linear = torch.nn.Linear(1, 1).to(device, dtype)
linear(x)

x = torch.zeros((1, 1, 3, 3)).to(device, dtype)
conv2d = torch.nn.Conv2d(1, 1, (3, 3)).to(device, dtype)
conv2d(x)
return

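Most of devices.py now simply defers to ldm_patched's model management, as the rewrite above shows. A compact sketch of the resulting contract, using only function names that appear in the diff; treat it as an outline rather than the exact module:

```python
import contextlib
import torch
import ldm_patched.modules.model_management as model_management

cpu = torch.device("cpu")

def get_optimal_device() -> torch.device:
    # Device selection is delegated entirely to the backend.
    return model_management.get_torch_device()

def torch_gc() -> None:
    # One call replaces the per-backend CUDA/MPS/XPU/NPU cleanup paths.
    model_management.soft_empty_cache()

def autocast(disable: bool = False):
    # Precision is handled by the backend's own casting, so the webui-side
    # autocast degrades to a no-op context manager.
    return contextlib.nullcontext()
```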
@@ -1,6 +1,6 @@
|
||||
from modules import modelloader, devices, errors
|
||||
from modules.shared import opts
|
||||
from modules.upscaler import Upscaler, UpscalerData
|
||||
from modules.upscaler import Upscaler, UpscalerData, prepare_free_memory
|
||||
from modules.upscaler_utils import upscale_with_model
|
||||
|
||||
|
||||
@@ -27,6 +27,7 @@ class UpscalerESRGAN(Upscaler):
|
||||
self.scalers.append(scaler_data)
|
||||
|
||||
def do_upscale(self, img, selected_model):
|
||||
prepare_free_memory()
|
||||
try:
|
||||
model = self.load_model(selected_model)
|
||||
except Exception:
|
||||
|
||||
@@ -8,6 +8,7 @@ import re
|
||||
from modules import shared, errors, cache, scripts
|
||||
from modules.gitpython_hack import Repo
|
||||
from modules.paths_internal import extensions_dir, extensions_builtin_dir, script_path # noqa: F401
|
||||
from modules_forge.config import always_disabled_extensions
|
||||
|
||||
|
||||
os.makedirs(extensions_dir, exist_ok=True)
|
||||
@@ -218,7 +219,17 @@ def list_extensions():
|
||||
continue
|
||||
|
||||
is_builtin = dirname == extensions_builtin_dir
|
||||
extension = Extension(name=extension_dirname, path=path, enabled=extension_dirname not in shared.opts.disabled_extensions, is_builtin=is_builtin, metadata=metadata)
|
||||
|
||||
disabled_extensions = shared.opts.disabled_extensions + always_disabled_extensions
|
||||
|
||||
extension = Extension(
|
||||
name=extension_dirname,
|
||||
path=path,
|
||||
enabled=extension_dirname not in disabled_extensions,
|
||||
is_builtin=is_builtin,
|
||||
metadata=metadata
|
||||
)
|
||||
|
||||
extensions.append(extension)
|
||||
loaded_extensions[canonical_name] = extension
|
||||
|
||||
|
||||
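A small illustration of the disabled-extension merge introduced above: the user's disabled_extensions setting is combined with Forge's always_disabled_extensions before each extension's enabled flag is computed. The extension names below are made up.

```python
# Hypothetical data standing in for shared.opts and modules_forge.config.
user_disabled = ["some-user-extension"]
always_disabled_extensions = ["incompatible-legacy-extension"]  # illustrative name

disabled_extensions = user_disabled + always_disabled_extensions

def is_enabled(extension_dirname: str) -> bool:
    # Same test the new Extension(...) call performs.
    return extension_dirname not in disabled_extensions

assert not is_enabled("incompatible-legacy-extension")
assert is_enabled("controlnet")
```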
@@ -3,7 +3,7 @@ import sys
|
||||
|
||||
from modules import modelloader, devices
|
||||
from modules.shared import opts
|
||||
from modules.upscaler import Upscaler, UpscalerData
|
||||
from modules.upscaler import Upscaler, UpscalerData, prepare_free_memory
|
||||
from modules.upscaler_utils import upscale_with_model
|
||||
|
||||
|
||||
@@ -20,6 +20,7 @@ class UpscalerHAT(Upscaler):
|
||||
self.scalers.append(scaler_data)
|
||||
|
||||
def do_upscale(self, img, selected_model):
|
||||
prepare_free_memory()
|
||||
try:
|
||||
model = self.load_model(selected_model)
|
||||
except Exception as e:
|
||||
|
||||
@@ -15,6 +15,7 @@ import modules.shared as shared
|
||||
import modules.processing as processing
|
||||
from modules.ui import plaintext_to_html
|
||||
import modules.scripts
|
||||
from modules_forge import main_thread
|
||||
|
||||
|
||||
def process_batch(p, input_dir, output_dir, inpaint_mask_dir, args, to_scale=False, scale_by=1.0, use_png_info=False, png_info_props=None, png_info_dir=None):
|
||||
@@ -146,7 +147,7 @@ def process_batch(p, input_dir, output_dir, inpaint_mask_dir, args, to_scale=Fal
|
||||
return batch_results
|
||||
|
||||
|
||||
def img2img(id_task: str, mode: int, prompt: str, negative_prompt: str, prompt_styles, init_img, sketch, init_img_with_mask, inpaint_color_sketch, inpaint_color_sketch_orig, init_img_inpaint, init_mask_inpaint, steps: int, sampler_name: str, mask_blur: int, mask_alpha: float, inpainting_fill: int, n_iter: int, batch_size: int, cfg_scale: float, image_cfg_scale: float, denoising_strength: float, selected_scale_tab: int, height: int, width: int, scale_by: float, resize_mode: int, inpaint_full_res: bool, inpaint_full_res_padding: int, inpainting_mask_invert: int, img2img_batch_input_dir: str, img2img_batch_output_dir: str, img2img_batch_inpaint_mask_dir: str, override_settings_texts, img2img_batch_use_png_info: bool, img2img_batch_png_info_props: list, img2img_batch_png_info_dir: str, request: gr.Request, *args):
|
||||
def img2img_function(id_task: str, mode: int, prompt: str, negative_prompt: str, prompt_styles, init_img, sketch, init_img_with_mask, inpaint_color_sketch, inpaint_color_sketch_orig, init_img_inpaint, init_mask_inpaint, steps: int, sampler_name: str, mask_blur: int, mask_alpha: float, inpainting_fill: int, n_iter: int, batch_size: int, cfg_scale: float, image_cfg_scale: float, denoising_strength: float, selected_scale_tab: int, height: int, width: int, scale_by: float, resize_mode: int, inpaint_full_res: bool, inpaint_full_res_padding: int, inpainting_mask_invert: int, img2img_batch_input_dir: str, img2img_batch_output_dir: str, img2img_batch_inpaint_mask_dir: str, override_settings_texts, img2img_batch_use_png_info: bool, img2img_batch_png_info_props: list, img2img_batch_png_info_dir: str, request: gr.Request, *args):
|
||||
override_settings = create_override_settings_dict(override_settings_texts)
|
||||
|
||||
is_batch = mode == 5
|
||||
@@ -243,4 +244,8 @@ def img2img(id_task: str, mode: int, prompt: str, negative_prompt: str, prompt_s
|
||||
if opts.do_not_show_images:
|
||||
processed.images = []
|
||||
|
||||
return processed.images, generation_info_js, plaintext_to_html(processed.info), plaintext_to_html(processed.comments, classname="comments")
|
||||
return processed.images + processed.extra_images, generation_info_js, plaintext_to_html(processed.info), plaintext_to_html(processed.comments, classname="comments")
|
||||
|
||||
|
||||
def img2img(id_task: str, mode: int, prompt: str, negative_prompt: str, prompt_styles, init_img, sketch, init_img_with_mask, inpaint_color_sketch, inpaint_color_sketch_orig, init_img_inpaint, init_mask_inpaint, steps: int, sampler_name: str, mask_blur: int, mask_alpha: float, inpainting_fill: int, n_iter: int, batch_size: int, cfg_scale: float, image_cfg_scale: float, denoising_strength: float, selected_scale_tab: int, height: int, width: int, scale_by: float, resize_mode: int, inpaint_full_res: bool, inpaint_full_res_padding: int, inpainting_mask_invert: int, img2img_batch_input_dir: str, img2img_batch_output_dir: str, img2img_batch_inpaint_mask_dir: str, override_settings_texts, img2img_batch_use_png_info: bool, img2img_batch_png_info_props: list, img2img_batch_png_info_dir: str, request: gr.Request, *args):
|
||||
return main_thread.run_and_wait_result(img2img_function, id_task, mode, prompt, negative_prompt, prompt_styles, init_img, sketch, init_img_with_mask, inpaint_color_sketch, inpaint_color_sketch_orig, init_img_inpaint, init_mask_inpaint, steps, sampler_name, mask_blur, mask_alpha, inpainting_fill, n_iter, batch_size, cfg_scale, image_cfg_scale, denoising_strength, selected_scale_tab, height, width, scale_by, resize_mode, inpaint_full_res, inpaint_full_res_padding, inpainting_mask_invert, img2img_batch_input_dir, img2img_batch_output_dir, img2img_batch_inpaint_mask_dir, override_settings_texts, img2img_batch_use_png_info, img2img_batch_png_info_props, img2img_batch_png_info_dir, request, *args)
|
||||
|
||||
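img2img (like txt2img) is now a thin wrapper that pushes the real work onto a main-thread task queue via main_thread.run_and_wait_result and blocks until the result arrives. The diff only shows the call sites; the sketch below is a hedged, minimal illustration of what such a run_and_wait_result/loop pair could look like, not Forge's actual modules_forge.main_thread implementation.

```python
import queue
import threading

_tasks: "queue.Queue[tuple]" = queue.Queue()

def run_and_wait_result(func, *args, **kwargs):
    """Schedule func on the main loop and block the caller until it finishes."""
    done = threading.Event()
    box = {}
    _tasks.put((func, args, kwargs, box, done))
    done.wait()
    if "exception" in box:
        raise box["exception"]
    return box["result"]

def loop():
    """Run queued tasks forever on the thread that calls loop() (the main thread)."""
    while True:
        func, args, kwargs, box, done = _tasks.get()
        try:
            box["result"] = func(*args, **kwargs)
        except Exception as e:  # propagate errors back to the waiting caller
            box["exception"] = e
        finally:
            done.set()
```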
@@ -138,6 +138,8 @@ def register_paste_params_button(binding: ParamBinding):
|
||||
|
||||
def connect_paste_params_buttons():
|
||||
for binding in registered_param_bindings:
|
||||
if binding.tabname not in paste_fields:
|
||||
continue
|
||||
destination_image_component = paste_fields[binding.tabname]["init_img"]
|
||||
fields = paste_fields[binding.tabname]["fields"]
|
||||
override_settings_component = binding.override_settings_component or paste_fields[binding.tabname]["override_settings_component"]
|
||||
|
||||
@@ -3,11 +3,23 @@ import logging
import os
import sys
import warnings
import os

from threading import Thread

from modules.timer import startup_timer


class HiddenPrints:
def __enter__(self):
self._original_stdout = sys.stdout
sys.stdout = open(os.devnull, 'w')

def __exit__(self, exc_type, exc_val, exc_tb):
sys.stdout.close()
sys.stdout = self._original_stdout


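HiddenPrints, added above, temporarily swaps sys.stdout for os.devnull so that noisy imports stay quiet. A quick usage example (the class is repeated here verbatim so the snippet runs on its own):

```python
import os
import sys

class HiddenPrints:
    def __enter__(self):
        self._original_stdout = sys.stdout
        sys.stdout = open(os.devnull, 'w')

    def __exit__(self, exc_type, exc_val, exc_tb):
        sys.stdout.close()
        sys.stdout = self._original_stdout

with HiddenPrints():
    print("this line is swallowed")   # written to os.devnull
print("this line is visible again")   # stdout restored on exit
```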
def imports():
|
||||
logging.getLogger("torch.distributed.nn").setLevel(logging.ERROR) # sshh...
|
||||
logging.getLogger("xformers").addFilter(lambda record: 'A matching Triton is not available' not in record.getMessage())
|
||||
@@ -23,14 +35,16 @@ def imports():
|
||||
import gradio # noqa: F401
|
||||
startup_timer.record("import gradio")
|
||||
|
||||
from modules import paths, timer, import_hook, errors # noqa: F401
|
||||
startup_timer.record("setup paths")
|
||||
with HiddenPrints():
|
||||
from modules import paths, timer, import_hook, errors # noqa: F401
|
||||
startup_timer.record("setup paths")
|
||||
|
||||
import ldm.modules.encoders.modules # noqa: F401
|
||||
startup_timer.record("import ldm")
|
||||
import ldm.modules.encoders.modules # noqa: F401
|
||||
import ldm.modules.diffusionmodules.model
|
||||
startup_timer.record("import ldm")
|
||||
|
||||
import sgm.modules.encoders.modules # noqa: F401
|
||||
startup_timer.record("import sgm")
|
||||
import sgm.modules.encoders.modules # noqa: F401
|
||||
startup_timer.record("import sgm")
|
||||
|
||||
from modules import shared_init
|
||||
shared_init.initialize()
|
||||
@@ -135,24 +149,9 @@ def initialize_rest(*, reload_script_modules=False):
sd_unet.list_unets()
startup_timer.record("scripts list_unets")

def load_model():
"""
Accesses the shared.sd_model property to load the model.
After it's available, if it has been loaded before this access by some extension,
its optimization may be None because the list of optimizers has not yet been filled
by that time, so we apply optimization again.
"""
from modules import devices
devices.torch_npu_set_device()

shared.sd_model # noqa: B018

if sd_hijack.current_optimizer is None:
sd_hijack.apply_optimizations()

devices.first_time_calculation()
if not shared.cmd_opts.skip_load_model_at_start:
Thread(target=load_model).start()
from modules_forge import main_thread
import modules.sd_models
main_thread.async_run(modules.sd_models.model_data.get_sd_model)
|
||||
from modules import shared_items
|
||||
shared_items.reload_hypernetworks()
|
||||
|
||||
@@ -170,10 +170,11 @@ def configure_sigint_handler():
|
||||
def configure_opts_onchange():
|
||||
from modules import shared, sd_models, sd_vae, ui_tempdir, sd_hijack
|
||||
from modules.call_queue import wrap_queued_call
|
||||
from modules_forge import main_thread
|
||||
|
||||
shared.opts.onchange("sd_model_checkpoint", wrap_queued_call(lambda: sd_models.reload_model_weights()), call=False)
|
||||
shared.opts.onchange("sd_vae", wrap_queued_call(lambda: sd_vae.reload_vae_weights()), call=False)
|
||||
shared.opts.onchange("sd_vae_overrides_per_model_preferences", wrap_queued_call(lambda: sd_vae.reload_vae_weights()), call=False)
|
||||
shared.opts.onchange("sd_model_checkpoint", wrap_queued_call(lambda: main_thread.run_and_wait_result(sd_models.reload_model_weights)), call=False)
|
||||
shared.opts.onchange("sd_vae", wrap_queued_call(lambda: main_thread.run_and_wait_result(sd_vae.reload_vae_weights)), call=False)
|
||||
shared.opts.onchange("sd_vae_overrides_per_model_preferences", wrap_queued_call(lambda: main_thread.run_and_wait_result(sd_vae.reload_vae_weights)), call=False)
|
||||
shared.opts.onchange("temp_dir", ui_tempdir.on_tmpdir_changed)
|
||||
shared.opts.onchange("gradio_theme", shared.reload_gradio_theme)
|
||||
shared.opts.onchange("cross_attention_optimization", wrap_queued_call(lambda: sd_hijack.model_hijack.redo_hijack(shared.sd_model)), call=False)
|
||||
|
||||
@@ -10,7 +10,10 @@ import torch.hub
|
||||
from torchvision import transforms
|
||||
from torchvision.transforms.functional import InterpolationMode
|
||||
|
||||
from modules import devices, paths, shared, lowvram, modelloader, errors, torch_utils
|
||||
from modules import devices, paths, shared, modelloader, errors
|
||||
from ldm_patched.modules import model_management
|
||||
from ldm_patched.modules.model_patcher import ModelPatcher
|
||||
|
||||
|
||||
blip_image_eval_size = 384
|
||||
clip_model_name = 'ViT-L/14'
|
||||
@@ -53,7 +56,16 @@ class InterrogateModels:
|
||||
self.loaded_categories = None
|
||||
self.skip_categories = []
|
||||
self.content_dir = content_dir
|
||||
self.running_on_cpu = devices.device_interrogate == torch.device("cpu")
|
||||
|
||||
self.load_device = model_management.text_encoder_device()
|
||||
self.offload_device = model_management.text_encoder_offload_device()
|
||||
self.dtype = torch.float32
|
||||
|
||||
if model_management.should_use_fp16(device=self.load_device):
|
||||
self.dtype = torch.float16
|
||||
|
||||
self.blip_patcher = None
|
||||
self.clip_patcher = None
|
||||
|
||||
def categories(self):
|
||||
if not os.path.exists(self.content_dir):
|
||||
@@ -105,49 +117,37 @@ class InterrogateModels:
|
||||
|
||||
def load_clip_model(self):
|
||||
import clip
|
||||
import clip.model
|
||||
|
||||
if self.running_on_cpu:
|
||||
model, preprocess = clip.load(clip_model_name, device="cpu", download_root=shared.cmd_opts.clip_models_path)
|
||||
else:
|
||||
model, preprocess = clip.load(clip_model_name, download_root=shared.cmd_opts.clip_models_path)
|
||||
clip.model.LayerNorm = torch.nn.LayerNorm
|
||||
|
||||
model, preprocess = clip.load(clip_model_name, device="cpu", download_root=shared.cmd_opts.clip_models_path)
|
||||
model.eval()
|
||||
model = model.to(devices.device_interrogate)
|
||||
|
||||
return model, preprocess
|
||||
|
||||
def load(self):
|
||||
if self.blip_model is None:
|
||||
self.blip_model = self.load_blip_model()
|
||||
if not shared.cmd_opts.no_half and not self.running_on_cpu:
|
||||
self.blip_model = self.blip_model.half()
|
||||
|
||||
self.blip_model = self.blip_model.to(devices.device_interrogate)
|
||||
self.blip_model = self.blip_model.to(device=self.offload_device, dtype=self.dtype)
|
||||
self.blip_patcher = ModelPatcher(self.blip_model, load_device=self.load_device, offload_device=self.offload_device)
|
||||
|
||||
if self.clip_model is None:
|
||||
self.clip_model, self.clip_preprocess = self.load_clip_model()
|
||||
if not shared.cmd_opts.no_half and not self.running_on_cpu:
|
||||
self.clip_model = self.clip_model.half()
|
||||
self.clip_model = self.clip_model.to(device=self.offload_device, dtype=self.dtype)
|
||||
self.clip_patcher = ModelPatcher(self.clip_model, load_device=self.load_device, offload_device=self.offload_device)
|
||||
|
||||
self.clip_model = self.clip_model.to(devices.device_interrogate)
|
||||
|
||||
self.dtype = torch_utils.get_param(self.clip_model).dtype
|
||||
model_management.load_models_gpu([self.blip_patcher, self.clip_patcher])
|
||||
return
|
||||
|
||||
def send_clip_to_ram(self):
|
||||
if not shared.opts.interrogate_keep_models_in_memory:
|
||||
if self.clip_model is not None:
|
||||
self.clip_model = self.clip_model.to(devices.cpu)
|
||||
pass
|
||||
|
||||
def send_blip_to_ram(self):
|
||||
if not shared.opts.interrogate_keep_models_in_memory:
|
||||
if self.blip_model is not None:
|
||||
self.blip_model = self.blip_model.to(devices.cpu)
|
||||
pass
|
||||
|
||||
def unload(self):
|
||||
self.send_clip_to_ram()
|
||||
self.send_blip_to_ram()
|
||||
|
||||
devices.torch_gc()
|
||||
pass
|
||||
|
||||
def rank(self, image_features, text_array, top_count=1):
|
||||
import clip
|
||||
@@ -158,11 +158,11 @@ class InterrogateModels:
|
||||
text_array = text_array[0:int(shared.opts.interrogate_clip_dict_limit)]
|
||||
|
||||
top_count = min(top_count, len(text_array))
|
||||
text_tokens = clip.tokenize(list(text_array), truncate=True).to(devices.device_interrogate)
|
||||
text_tokens = clip.tokenize(list(text_array), truncate=True).to(self.load_device)
|
||||
text_features = self.clip_model.encode_text(text_tokens).type(self.dtype)
|
||||
text_features /= text_features.norm(dim=-1, keepdim=True)
|
||||
|
||||
similarity = torch.zeros((1, len(text_array))).to(devices.device_interrogate)
|
||||
similarity = torch.zeros((1, len(text_array))).to(self.load_device)
|
||||
for i in range(image_features.shape[0]):
|
||||
similarity += (100.0 * image_features[i].unsqueeze(0) @ text_features.T).softmax(dim=-1)
|
||||
similarity /= image_features.shape[0]
|
||||
@@ -175,7 +175,7 @@ class InterrogateModels:
|
||||
transforms.Resize((blip_image_eval_size, blip_image_eval_size), interpolation=InterpolationMode.BICUBIC),
|
||||
transforms.ToTensor(),
|
||||
transforms.Normalize((0.48145466, 0.4578275, 0.40821073), (0.26862954, 0.26130258, 0.27577711))
|
||||
])(pil_image).unsqueeze(0).type(self.dtype).to(devices.device_interrogate)
|
||||
])(pil_image).unsqueeze(0).type(self.dtype).to(self.load_device)
|
||||
|
||||
with torch.no_grad():
|
||||
caption = self.blip_model.generate(gpu_image, sample=False, num_beams=shared.opts.interrogate_clip_num_beams, min_length=shared.opts.interrogate_clip_min_length, max_length=shared.opts.interrogate_clip_max_length)
|
||||
@@ -186,9 +186,6 @@ class InterrogateModels:
|
||||
res = ""
|
||||
shared.state.begin(job="interrogate")
|
||||
try:
|
||||
lowvram.send_everything_to_cpu()
|
||||
devices.torch_gc()
|
||||
|
||||
self.load()
|
||||
|
||||
caption = self.generate_caption(pil_image)
|
||||
@@ -197,7 +194,7 @@ class InterrogateModels:
|
||||
|
||||
res = caption
|
||||
|
||||
clip_image = self.clip_preprocess(pil_image).unsqueeze(0).type(self.dtype).to(devices.device_interrogate)
|
||||
clip_image = self.clip_preprocess(pil_image).unsqueeze(0).type(self.dtype).to(self.load_device)
|
||||
|
||||
with torch.no_grad(), devices.autocast():
|
||||
image_features = self.clip_model.encode_image(clip_image).type(self.dtype)
|
||||
|
||||
@@ -12,9 +12,12 @@ import json
|
||||
from functools import lru_cache
|
||||
|
||||
from modules import cmd_args, errors
|
||||
from modules.paths_internal import script_path, extensions_dir
|
||||
from modules.paths_internal import script_path, extensions_dir, extensions_builtin_dir
|
||||
from modules.timer import startup_timer
|
||||
from modules import logging_config
|
||||
from modules_forge import forge_version
|
||||
from modules_forge.config import always_disabled_extensions
|
||||
|
||||
|
||||
args, _ = cmd_args.parser.parse_known_args()
|
||||
logging_config.setup_logging(args.loglevel)
|
||||
@@ -70,7 +73,7 @@ def commit_hash():
|
||||
|
||||
|
||||
@lru_cache()
|
||||
def git_tag():
|
||||
def git_tag_a1111():
|
||||
try:
|
||||
return subprocess.check_output([git, "-C", script_path, "describe", "--tags"], shell=False, encoding='utf8').strip()
|
||||
except Exception:
|
||||
@@ -85,6 +88,10 @@ def git_tag():
|
||||
return "<none>"
|
||||
|
||||
|
||||
def git_tag():
|
||||
return 'f' + forge_version.version + '-' + git_tag_a1111()
|
||||
|
||||
|
||||
def run(command, desc=None, errdesc=None, custom_env=None, live: bool = default_command_live) -> str:
|
||||
if desc is not None:
|
||||
print(desc)
|
||||
@@ -252,7 +259,7 @@ def list_extensions(settings_file):
|
||||
errors.report(f'\nCould not load settings\nThe config file "{settings_file}" is likely corrupted\nIt has been moved to the "tmp/config.json"\nReverting config to default\n\n''', exc_info=True)
|
||||
os.replace(settings_file, os.path.join(script_path, "tmp", "config.json"))
|
||||
|
||||
disabled_extensions = set(settings.get('disabled_extensions', []))
|
||||
disabled_extensions = set(settings.get('disabled_extensions', []) + always_disabled_extensions)
|
||||
disable_all_extensions = settings.get('disable_all_extensions', 'none')
|
||||
|
||||
if disable_all_extensions != 'none' or args.disable_extra_extensions or args.disable_all_extensions or not os.path.isdir(extensions_dir):
|
||||
@@ -261,6 +268,27 @@ def list_extensions(settings_file):
|
||||
return [x for x in os.listdir(extensions_dir) if x not in disabled_extensions]
|
||||
|
||||
|
||||
def list_extensions_builtin(settings_file):
|
||||
settings = {}
|
||||
|
||||
try:
|
||||
with open(settings_file, "r", encoding="utf8") as file:
|
||||
settings = json.load(file)
|
||||
except FileNotFoundError:
|
||||
pass
|
||||
except Exception:
|
||||
errors.report(f'\nCould not load settings\nThe config file "{settings_file}" is likely corrupted\nIt has been moved to the "tmp/config.json"\nReverting config to default\n\n''', exc_info=True)
|
||||
os.replace(settings_file, os.path.join(script_path, "tmp", "config.json"))
|
||||
|
||||
disabled_extensions = set(settings.get('disabled_extensions', []))
|
||||
disable_all_extensions = settings.get('disable_all_extensions', 'none')
|
||||
|
||||
if disable_all_extensions != 'none' or args.disable_extra_extensions or args.disable_all_extensions or not os.path.isdir(extensions_builtin_dir):
|
||||
return []
|
||||
|
||||
return [x for x in os.listdir(extensions_builtin_dir) if x not in disabled_extensions]
|
||||
|
||||
|
||||
def run_extensions_installers(settings_file):
|
||||
if not os.path.isdir(extensions_dir):
|
||||
return
|
||||
@@ -275,6 +303,21 @@ def run_extensions_installers(settings_file):
|
||||
run_extension_installer(path)
|
||||
startup_timer.record(dirname_extension)
|
||||
|
||||
if not os.path.isdir(extensions_builtin_dir):
|
||||
return
|
||||
|
||||
with startup_timer.subcategory("run extensions_builtin installers"):
|
||||
for dirname_extension in list_extensions_builtin(settings_file):
|
||||
logging.debug(f"Installing {dirname_extension}")
|
||||
|
||||
path = os.path.join(extensions_builtin_dir, dirname_extension)
|
||||
|
||||
if os.path.isdir(path):
|
||||
run_extension_installer(path)
|
||||
startup_timer.record(dirname_extension)
|
||||
|
||||
return
|
||||
|
||||
|
||||
re_requirement = re.compile(r"\s*([-_a-zA-Z0-9]+)\s*(?:==\s*([-+_.a-zA-Z0-9]+))?\s*")
|
||||
|
||||
@@ -468,6 +511,11 @@ def start():
|
||||
else:
|
||||
webui.webui()
|
||||
|
||||
from modules_forge import main_thread
|
||||
|
||||
main_thread.loop()
|
||||
return
|
||||
|
||||
|
||||
def dump_sysinfo():
|
||||
from modules import sysinfo
|
||||
|
||||
@@ -6,142 +6,20 @@ cpu = torch.device("cpu")
|
||||
|
||||
|
||||
def send_everything_to_cpu():
|
||||
global module_in_gpu
|
||||
|
||||
if module_in_gpu is not None:
|
||||
module_in_gpu.to(cpu)
|
||||
|
||||
module_in_gpu = None
|
||||
return
|
||||
|
||||
|
||||
def is_needed(sd_model):
|
||||
return shared.cmd_opts.lowvram or shared.cmd_opts.medvram or shared.cmd_opts.medvram_sdxl and hasattr(sd_model, 'conditioner')
|
||||
return False
|
||||
|
||||
|
||||
def apply(sd_model):
|
||||
enable = is_needed(sd_model)
|
||||
shared.parallel_processing_allowed = not enable
|
||||
|
||||
if enable:
|
||||
setup_for_low_vram(sd_model, not shared.cmd_opts.lowvram)
|
||||
else:
|
||||
sd_model.lowvram = False
|
||||
return
|
||||
|
||||
|
||||
def setup_for_low_vram(sd_model, use_medvram):
|
||||
if getattr(sd_model, 'lowvram', False):
|
||||
return
|
||||
|
||||
sd_model.lowvram = True
|
||||
|
||||
parents = {}
|
||||
|
||||
def send_me_to_gpu(module, _):
|
||||
"""send this module to GPU; send whatever tracked module was previous in GPU to CPU;
|
||||
we add this as forward_pre_hook to a lot of modules and this way all but one of them will
|
||||
be in CPU
|
||||
"""
|
||||
global module_in_gpu
|
||||
|
||||
module = parents.get(module, module)
|
||||
|
||||
if module_in_gpu == module:
|
||||
return
|
||||
|
||||
if module_in_gpu is not None:
|
||||
module_in_gpu.to(cpu)
|
||||
|
||||
module.to(devices.device)
|
||||
module_in_gpu = module
|
||||
|
||||
# see below for register_forward_pre_hook;
|
||||
# first_stage_model does not use forward(), it uses encode/decode, so register_forward_pre_hook is
|
||||
# useless here, and we just replace those methods
|
||||
|
||||
first_stage_model = sd_model.first_stage_model
|
||||
first_stage_model_encode = sd_model.first_stage_model.encode
|
||||
first_stage_model_decode = sd_model.first_stage_model.decode
|
||||
|
||||
def first_stage_model_encode_wrap(x):
|
||||
send_me_to_gpu(first_stage_model, None)
|
||||
return first_stage_model_encode(x)
|
||||
|
||||
def first_stage_model_decode_wrap(z):
|
||||
send_me_to_gpu(first_stage_model, None)
|
||||
return first_stage_model_decode(z)
|
||||
|
||||
to_remain_in_cpu = [
|
||||
(sd_model, 'first_stage_model'),
|
||||
(sd_model, 'depth_model'),
|
||||
(sd_model, 'embedder'),
|
||||
(sd_model, 'model'),
|
||||
(sd_model, 'embedder'),
|
||||
]
|
||||
|
||||
is_sdxl = hasattr(sd_model, 'conditioner')
|
||||
is_sd2 = not is_sdxl and hasattr(sd_model.cond_stage_model, 'model')
|
||||
|
||||
if is_sdxl:
|
||||
to_remain_in_cpu.append((sd_model, 'conditioner'))
|
||||
elif is_sd2:
|
||||
to_remain_in_cpu.append((sd_model.cond_stage_model, 'model'))
|
||||
else:
|
||||
to_remain_in_cpu.append((sd_model.cond_stage_model, 'transformer'))
|
||||
|
||||
# remove several big modules: cond, first_stage, depth/embedder (if applicable), and unet from the model
|
||||
stored = []
|
||||
for obj, field in to_remain_in_cpu:
|
||||
module = getattr(obj, field, None)
|
||||
stored.append(module)
|
||||
setattr(obj, field, None)
|
||||
|
||||
# send the model to GPU.
|
||||
sd_model.to(devices.device)
|
||||
|
||||
# put modules back. the modules will be in CPU.
|
||||
for (obj, field), module in zip(to_remain_in_cpu, stored):
|
||||
setattr(obj, field, module)
|
||||
|
||||
# register hooks for those the first three models
|
||||
if is_sdxl:
|
||||
sd_model.conditioner.register_forward_pre_hook(send_me_to_gpu)
|
||||
elif is_sd2:
|
||||
sd_model.cond_stage_model.model.register_forward_pre_hook(send_me_to_gpu)
|
||||
sd_model.cond_stage_model.model.token_embedding.register_forward_pre_hook(send_me_to_gpu)
|
||||
parents[sd_model.cond_stage_model.model] = sd_model.cond_stage_model
|
||||
parents[sd_model.cond_stage_model.model.token_embedding] = sd_model.cond_stage_model
|
||||
else:
|
||||
sd_model.cond_stage_model.transformer.register_forward_pre_hook(send_me_to_gpu)
|
||||
parents[sd_model.cond_stage_model.transformer] = sd_model.cond_stage_model
|
||||
|
||||
sd_model.first_stage_model.register_forward_pre_hook(send_me_to_gpu)
|
||||
sd_model.first_stage_model.encode = first_stage_model_encode_wrap
|
||||
sd_model.first_stage_model.decode = first_stage_model_decode_wrap
|
||||
if sd_model.depth_model:
|
||||
sd_model.depth_model.register_forward_pre_hook(send_me_to_gpu)
|
||||
if sd_model.embedder:
|
||||
sd_model.embedder.register_forward_pre_hook(send_me_to_gpu)
|
||||
|
||||
if use_medvram:
|
||||
sd_model.model.register_forward_pre_hook(send_me_to_gpu)
|
||||
else:
|
||||
diff_model = sd_model.model.diffusion_model
|
||||
|
||||
# the third remaining model is still too big for 4 GB, so we also do the same for its submodules
|
||||
# so that only one of them is in GPU at a time
|
||||
stored = diff_model.input_blocks, diff_model.middle_block, diff_model.output_blocks, diff_model.time_embed
|
||||
diff_model.input_blocks, diff_model.middle_block, diff_model.output_blocks, diff_model.time_embed = None, None, None, None
|
||||
sd_model.model.to(devices.device)
|
||||
diff_model.input_blocks, diff_model.middle_block, diff_model.output_blocks, diff_model.time_embed = stored
|
||||
|
||||
# install hooks for bits of third model
|
||||
diff_model.time_embed.register_forward_pre_hook(send_me_to_gpu)
|
||||
for block in diff_model.input_blocks:
|
||||
block.register_forward_pre_hook(send_me_to_gpu)
|
||||
diff_model.middle_block.register_forward_pre_hook(send_me_to_gpu)
|
||||
for block in diff_model.output_blocks:
|
||||
block.register_forward_pre_hook(send_me_to_gpu)
|
||||
return
|
||||
|
||||
|
||||
def is_enabled(sd_model):
|
||||
return sd_model.lowvram
|
||||
return False
|
||||
|
||||
@@ -445,7 +445,7 @@ class UniPC:
|
||||
s = torch.quantile(torch.abs(x0).reshape((x0.shape[0], -1)), p, dim=1)
|
||||
s = expand_dims(torch.maximum(s, self.max_val * torch.ones_like(s).to(s.device)), dims)
|
||||
x0 = torch.clamp(x0, -s, s) / s
|
||||
return x0
|
||||
return x0.to(x)
|
||||
|
||||
def model_fn(self, x, t):
|
||||
"""
|
||||
|
||||
@@ -62,3 +62,13 @@ for d, must_exist, what, options in path_dirs:
|
||||
else:
|
||||
sys.path.append(d)
|
||||
paths[what] = d
|
||||
|
||||
|
||||
import ldm_patched.utils.path_utils as ldm_patched_path_utils
|
||||
|
||||
ldm_patched_path_utils.base_path = data_path
|
||||
ldm_patched_path_utils.models_dir = models_path
|
||||
ldm_patched_path_utils.output_directory = os.path.join(data_path, "output")
|
||||
ldm_patched_path_utils.temp_directory = os.path.join(data_path, "temp")
|
||||
ldm_patched_path_utils.input_directory = os.path.join(data_path, "input")
|
||||
ldm_patched_path_utils.user_directory = os.path.join(data_path, "user")
|
||||
|
||||
@@ -256,6 +256,9 @@ class StableDiffusionProcessing:
|
||||
self.cached_uc = StableDiffusionProcessing.cached_uc
|
||||
self.cached_c = StableDiffusionProcessing.cached_c
|
||||
|
||||
self.extra_result_images = []
|
||||
self.modified_noise = None
|
||||
|
||||
@property
|
||||
def sd_model(self):
|
||||
return shared.sd_model
|
||||
@@ -515,8 +518,9 @@ class StableDiffusionProcessing:
|
||||
|
||||
|
||||
class Processed:
|
||||
def __init__(self, p: StableDiffusionProcessing, images_list, seed=-1, info="", subseed=None, all_prompts=None, all_negative_prompts=None, all_seeds=None, all_subseeds=None, index_of_first_image=0, infotexts=None, comments=""):
|
||||
def __init__(self, p: StableDiffusionProcessing, images_list, seed=-1, info="", subseed=None, all_prompts=None, all_negative_prompts=None, all_seeds=None, all_subseeds=None, index_of_first_image=0, infotexts=None, comments="", extra_images_list=[]):
|
||||
self.images = images_list
|
||||
self.extra_images = extra_images_list
|
||||
self.prompt = p.prompt
|
||||
self.negative_prompt = p.negative_prompt
|
||||
self.seed = seed
|
||||
@@ -628,44 +632,7 @@ def decode_latent_batch(model, batch, target_device=None, check_for_nans=False):
|
||||
|
||||
for i in range(batch.shape[0]):
|
||||
sample = decode_first_stage(model, batch[i:i + 1])[0]
|
||||
|
||||
if check_for_nans:
|
||||
|
||||
try:
|
||||
devices.test_for_nans(sample, "vae")
|
||||
except devices.NansException as e:
|
||||
if shared.opts.auto_vae_precision_bfloat16:
|
||||
autofix_dtype = torch.bfloat16
|
||||
autofix_dtype_text = "bfloat16"
|
||||
autofix_dtype_setting = "Automatically convert VAE to bfloat16"
|
||||
autofix_dtype_comment = ""
|
||||
elif shared.opts.auto_vae_precision:
|
||||
autofix_dtype = torch.float32
|
||||
autofix_dtype_text = "32-bit float"
|
||||
autofix_dtype_setting = "Automatically revert VAE to 32-bit floats"
|
||||
autofix_dtype_comment = "\nTo always start with 32-bit VAE, use --no-half-vae commandline flag."
|
||||
else:
|
||||
raise e
|
||||
|
||||
if devices.dtype_vae == autofix_dtype:
|
||||
raise e
|
||||
|
||||
errors.print_error_explanation(
|
||||
"A tensor with all NaNs was produced in VAE.\n"
|
||||
f"Web UI will now convert VAE into {autofix_dtype_text} and retry.\n"
|
||||
f"To disable this behavior, disable the '{autofix_dtype_setting}' setting.{autofix_dtype_comment}"
|
||||
)
|
||||
|
||||
devices.dtype_vae = autofix_dtype
|
||||
model.first_stage_model.to(devices.dtype_vae)
|
||||
batch = batch.to(devices.dtype_vae)
|
||||
|
||||
sample = decode_first_stage(model, batch[i:i + 1])[0]
|
||||
|
||||
if target_device is not None:
|
||||
sample = sample.to(target_device)
|
||||
|
||||
samples.append(sample)
|
||||
samples.append(sample.to(target_device))
|
||||
|
||||
return samples
|
||||
|
||||
@@ -848,7 +815,7 @@ def process_images_inner(p: StableDiffusionProcessing) -> Processed:
|
||||
|
||||
infotexts = []
|
||||
output_images = []
|
||||
with torch.no_grad(), p.sd_model.ema_scope():
|
||||
with torch.inference_mode():
|
||||
with devices.autocast():
|
||||
p.init(p.all_prompts, p.all_seeds, p.all_subseeds)
|
||||
|
||||
@@ -872,6 +839,7 @@ def process_images_inner(p: StableDiffusionProcessing) -> Processed:
|
||||
|
||||
sd_models.reload_model_weights() # model can be changed for example by refiner
|
||||
|
||||
p.sd_model.forge_objects = p.sd_model.forge_objects_original.shallow_copy()
|
||||
p.prompts = p.all_prompts[n * p.batch_size:(n + 1) * p.batch_size]
|
||||
p.negative_prompts = p.all_negative_prompts[n * p.batch_size:(n + 1) * p.batch_size]
|
||||
p.seeds = p.all_seeds[n * p.batch_size:(n + 1) * p.batch_size]
|
||||
@@ -888,8 +856,9 @@ def process_images_inner(p: StableDiffusionProcessing) -> Processed:
|
||||
p.parse_extra_network_prompts()
|
||||
|
||||
if not p.disable_extra_networks:
|
||||
with devices.autocast():
|
||||
extra_networks.activate(p, p.extra_network_data)
|
||||
extra_networks.activate(p, p.extra_network_data)
|
||||
|
||||
p.sd_model.forge_objects = p.sd_model.forge_objects_after_applying_lora.shallow_copy()
|
||||
|
||||
if p.scripts is not None:
|
||||
p.scripts.process_batch(p, batch_number=n, prompts=p.prompts, seeds=p.seeds, subseeds=p.subseeds)
|
||||
@@ -941,8 +910,18 @@ def process_images_inner(p: StableDiffusionProcessing) -> Processed:
p.extra_generation_params['Noise Schedule'] = opts.sd_noise_schedule
p.sd_model.alphas_cumprod = rescale_zero_terminal_snr_abar(p.sd_model.alphas_cumprod).to(shared.device)

with devices.without_autocast() if devices.unet_needs_upcast else devices.autocast():
samples_ddim = p.sample(conditioning=p.c, unconditional_conditioning=p.uc, seeds=p.seeds, subseeds=p.subseeds, subseed_strength=p.subseed_strength, prompts=p.prompts)
alphas_cumprod_modifiers = p.sd_model.forge_objects.unet.model_options.get('alphas_cumprod_modifiers', [])
alphas_cumprod_backup = None

if len(alphas_cumprod_modifiers) > 0:
alphas_cumprod_backup = p.sd_model.alphas_cumprod
for modifier in alphas_cumprod_modifiers:
p.sd_model.alphas_cumprod = modifier(p.sd_model.alphas_cumprod)

samples_ddim = p.sample(conditioning=p.c, unconditional_conditioning=p.uc, seeds=p.seeds, subseeds=p.subseeds, subseed_strength=p.subseed_strength, prompts=p.prompts)

if alphas_cumprod_backup is not None:
p.sd_model.alphas_cumprod = alphas_cumprod_backup

if p.scripts is not None:
ps = scripts.PostSampleArgs(samples_ddim)
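The sampling change above replaces the old upcast/autocast wrapper with a backup-modify-restore sequence around alphas_cumprod. The sketch below shows the same control flow in isolation; the try/finally is an extra safety net of this example, the diff restores the backup unconditionally after sampling.

```python
import torch

def sample_with_modifiers(model, modifiers, do_sample):
    """Apply registered alphas_cumprod modifiers for one sampling call, then restore."""
    backup = None
    if modifiers:
        backup = model.alphas_cumprod
        for modifier in modifiers:
            model.alphas_cumprod = modifier(model.alphas_cumprod)
    try:
        return do_sample()
    finally:
        if backup is not None:
            model.alphas_cumprod = backup

# Tiny demo with a dummy "model".
class Dummy:
    alphas_cumprod = torch.linspace(1.0, 0.0, 10)

m = Dummy()
out = sample_with_modifiers(m, [lambda a: a * 0.5], do_sample=lambda: m.alphas_cumprod.sum())
print(out, m.alphas_cumprod[0])  # sum of the modified schedule, then original values restored
```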
@@ -961,9 +940,6 @@ def process_images_inner(p: StableDiffusionProcessing) -> Processed:
|
||||
|
||||
del samples_ddim
|
||||
|
||||
if lowvram.is_enabled(shared.sd_model):
|
||||
lowvram.send_everything_to_cpu()
|
||||
|
||||
devices.torch_gc()
|
||||
|
||||
state.nextjob()
|
||||
@@ -1102,6 +1078,7 @@ def process_images_inner(p: StableDiffusionProcessing) -> Processed:
|
||||
subseed=p.all_subseeds[0],
|
||||
index_of_first_image=index_of_first_image,
|
||||
infotexts=infotexts,
|
||||
extra_images_list=p.extra_result_images,
|
||||
)
|
||||
|
||||
if p.scripts is not None:
|
||||
@@ -1270,7 +1247,7 @@ class StableDiffusionProcessingTxt2Img(StableDiffusionProcessing):
|
||||
image = np.array(self.firstpass_image).astype(np.float32) / 255.0
|
||||
image = np.moveaxis(image, 2, 0)
|
||||
image = torch.from_numpy(np.expand_dims(image, axis=0))
|
||||
image = image.to(shared.device, dtype=devices.dtype_vae)
|
||||
image = image.to(shared.device, dtype=torch.float32)
|
||||
|
||||
if opts.sd_vae_encode_method != 'Full':
|
||||
self.extra_generation_params['VAE Encoder'] = opts.sd_vae_encode_method
|
||||
@@ -1283,6 +1260,19 @@ class StableDiffusionProcessingTxt2Img(StableDiffusionProcessing):
|
||||
# here we generate an image normally
|
||||
|
||||
x = self.rng.next()
|
||||
|
||||
self.sd_model.forge_objects = self.sd_model.forge_objects_after_applying_lora.shallow_copy()
|
||||
if self.scripts is not None:
|
||||
self.scripts.process_before_every_sampling(self,
|
||||
x=x,
|
||||
noise=x,
|
||||
c=conditioning,
|
||||
uc=unconditional_conditioning)
|
||||
|
||||
if self.modified_noise is not None:
|
||||
x = self.modified_noise
|
||||
self.modified_noise = None
|
||||
|
||||
samples = self.sampler.sample(self, x, conditioning, unconditional_conditioning, image_conditioning=self.txt2img_image_conditioning(x))
|
||||
del x
|
||||
|
||||
@@ -1354,7 +1344,7 @@ class StableDiffusionProcessingTxt2Img(StableDiffusionProcessing):
|
||||
batch_images.append(image)
|
||||
|
||||
decoded_samples = torch.from_numpy(np.array(batch_images))
|
||||
decoded_samples = decoded_samples.to(shared.device, dtype=devices.dtype_vae)
|
||||
decoded_samples = decoded_samples.to(shared.device, dtype=torch.float32)
|
||||
|
||||
if opts.sd_vae_encode_method != 'Full':
|
||||
self.extra_generation_params['VAE Encoder'] = opts.sd_vae_encode_method
|
||||
@@ -1384,6 +1374,18 @@ class StableDiffusionProcessingTxt2Img(StableDiffusionProcessing):
|
||||
if self.scripts is not None:
|
||||
self.scripts.before_hr(self)
|
||||
|
||||
self.sd_model.forge_objects = self.sd_model.forge_objects_after_applying_lora.shallow_copy()
|
||||
if self.scripts is not None:
|
||||
self.scripts.process_before_every_sampling(self,
|
||||
x=samples,
|
||||
noise=noise,
|
||||
c=self.hr_c,
|
||||
uc=self.hr_uc)
|
||||
|
||||
if self.modified_noise is not None:
|
||||
noise = self.modified_noise
|
||||
self.modified_noise = None
|
||||
|
||||
samples = self.sampler.sample_img2img(self, samples, noise, self.hr_c, self.hr_uc, steps=self.hr_second_pass_steps or self.steps, image_conditioning=image_conditioning)
|
||||
|
||||
sd_models.apply_token_merging(self.sd_model, self.get_token_merging_ratio())
|
||||
@@ -1459,7 +1461,7 @@ class StableDiffusionProcessingTxt2Img(StableDiffusionProcessing):
|
||||
if shared.opts.hires_fix_use_firstpass_conds:
|
||||
self.calculate_hr_conds()
|
||||
|
||||
elif lowvram.is_enabled(shared.sd_model) and shared.sd_model.sd_checkpoint_info == sd_models.select_checkpoint(): # if in lowvram mode, we need to calculate conds right away, before the cond NN is unloaded
|
||||
elif shared.sd_model.sd_checkpoint_info == sd_models.select_checkpoint(): # if in lowvram mode, we need to calculate conds right away, before the cond NN is unloaded
|
||||
with devices.autocast():
|
||||
extra_networks.activate(self, self.hr_extra_network_data)
|
||||
|
||||
@@ -1646,7 +1648,7 @@ class StableDiffusionProcessingImg2Img(StableDiffusionProcessing):
|
||||
raise RuntimeError(f"bad number of images passed: {len(imgs)}; expecting {self.batch_size} or less")
|
||||
|
||||
image = torch.from_numpy(batch_images)
|
||||
image = image.to(shared.device, dtype=devices.dtype_vae)
|
||||
image = image.to(shared.device, dtype=torch.float32)
|
||||
|
||||
if opts.sd_vae_encode_method != 'Full':
|
||||
self.extra_generation_params['VAE Encoder'] = opts.sd_vae_encode_method
|
||||
@@ -1687,6 +1689,18 @@ class StableDiffusionProcessingImg2Img(StableDiffusionProcessing):
|
||||
self.extra_generation_params["Noise multiplier"] = self.initial_noise_multiplier
|
||||
x *= self.initial_noise_multiplier
|
||||
|
||||
self.sd_model.forge_objects = self.sd_model.forge_objects_after_applying_lora.shallow_copy()
|
||||
if self.scripts is not None:
|
||||
self.scripts.process_before_every_sampling(self,
|
||||
x=self.init_latent,
|
||||
noise=x,
|
||||
c=conditioning,
|
||||
uc=unconditional_conditioning)
|
||||
|
||||
if self.modified_noise is not None:
|
||||
x = self.modified_noise
|
||||
self.modified_noise = None
|
||||
|
||||
samples = self.sampler.sample_img2img(self, self.init_latent, x, conditioning, unconditional_conditioning, image_conditioning=self.image_conditioning)
|
||||
|
||||
if self.mask is not None:
|
||||
|
||||
@@ -268,7 +268,7 @@ def get_multicond_learned_conditioning(model, prompts, steps, hires_steps=None,
|
||||
|
||||
|
||||
class DictWithShape(dict):
|
||||
def __init__(self, x, shape):
|
||||
def __init__(self, x):
|
||||
super().__init__()
|
||||
self.update(x)
|
||||
|
||||
@@ -276,6 +276,19 @@ class DictWithShape(dict):
|
||||
def shape(self):
|
||||
return self["crossattn"].shape
|
||||
|
||||
def to(self, *args, **kwargs):
|
||||
for k in self.keys():
|
||||
if isinstance(self[k], torch.Tensor):
|
||||
self[k] = self[k].to(*args, **kwargs)
|
||||
return self
|
||||
|
||||
def advanced_indexing(self, item):
|
||||
result = {}
|
||||
for k in self.keys():
|
||||
if isinstance(self[k], torch.Tensor):
|
||||
result[k] = self[k][item]
|
||||
return DictWithShape(result)
|
||||
|
||||
|
||||
def reconstruct_cond_batch(c: list[list[ScheduledPromptConditioning]], current_step):
|
||||
param = c[0][0].cond
|
||||
@@ -284,7 +297,7 @@ def reconstruct_cond_batch(c: list[list[ScheduledPromptConditioning]], current_s
|
||||
if is_dict:
|
||||
dict_cond = param
|
||||
res = {k: torch.zeros((len(c),) + param.shape, device=param.device, dtype=param.dtype) for k, param in dict_cond.items()}
|
||||
res = DictWithShape(res, (len(c),) + dict_cond['crossattn'].shape)
|
||||
res = DictWithShape(res)
|
||||
else:
|
||||
res = torch.zeros((len(c),) + param.shape, device=param.device, dtype=param.dtype)
|
||||
|
||||
@@ -342,7 +355,7 @@ def reconstruct_multicond_batch(c: MulticondLearnedConditioning, current_step):
if isinstance(tensors[0], dict):
keys = list(tensors[0].keys())
stacked = {k: stack_conds([x[k] for x in tensors]) for k in keys}
stacked = DictWithShape(stacked, stacked['crossattn'].shape)
stacked = DictWithShape(stacked)
else:
stacked = stack_conds(tensors).to(device=param.device, dtype=param.dtype)


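DictWithShape, extended above, lets a dict of conditioning tensors stand in for a single tensor for the few operations the samplers need: .shape, .to() and fancy indexing. A self-contained copy of the new behaviour for illustration (the real class lives in modules/prompt_parser.py):

```python
import torch

class DictWithShape(dict):
    """Dict of tensors that exposes a tensor-like subset of the API."""
    def __init__(self, x):
        super().__init__()
        self.update(x)

    @property
    def shape(self):
        return self["crossattn"].shape

    def to(self, *args, **kwargs):
        # Move/cast every tensor value, like Tensor.to but mutating in place.
        for k in self.keys():
            if isinstance(self[k], torch.Tensor):
                self[k] = self[k].to(*args, **kwargs)
        return self

    def advanced_indexing(self, item):
        # Apply the same index to every tensor value and wrap the result again.
        return DictWithShape({k: v[item] for k, v in self.items() if isinstance(v, torch.Tensor)})

cond = DictWithShape({"crossattn": torch.zeros(2, 77, 768), "vector": torch.zeros(2, 1280)})
print(cond.shape)                         # torch.Size([2, 77, 768])
print(cond.advanced_indexing([0]).shape)  # torch.Size([1, 77, 768])
```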
@@ -2,7 +2,7 @@ import os

from modules import modelloader, errors
from modules.shared import cmd_opts, opts
from modules.upscaler import Upscaler, UpscalerData
from modules.upscaler import Upscaler, UpscalerData, prepare_free_memory
from modules.upscaler_utils import upscale_with_model


@@ -27,6 +27,8 @@ class UpscalerRealESRGAN(Upscaler):
self.scalers.append(scaler)

def do_upscale(self, img, path):
prepare_free_memory()

if not self.enable:
return img


@@ -8,7 +8,13 @@ def randn(seed, shape, generator=None):

Uses the seed parameter to set the global torch seed; to generate more with that seed, use randn_like/randn_without_seed."""

manual_seed(seed)
if generator is not None:
# If generator is not None, we must use another seed to
# avoid the global torch RNG producing the same noise again.
# Note: removing this will break the DDPM sampler.
manual_seed((seed + 100000) % 65536)
else:
manual_seed(seed)

if shared.opts.randn_source == "NV":
return torch.asarray((generator or nv_rng).randn(shape), device=devices.device)

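A hedged sketch of the seeding rule above; the helper name make_noise is invented for illustration. When a dedicated generator is supplied, the global seed is deliberately offset so later torch.randn calls elsewhere cannot reproduce the generator's noise:

import torch

def make_noise(seed, shape, generator=None):
    if generator is not None:
        # offset the global RNG so it cannot collide with the generator's stream
        torch.manual_seed((seed + 100000) % 65536)
        return torch.randn(shape, generator=generator)
    torch.manual_seed(seed)
    return torch.randn(shape)

g = torch.Generator().manual_seed(1234)
a = make_noise(1234, (4,), generator=g)
b = torch.randn(4)  # uses the offset global seed, so it differs from `a`
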
@@ -192,5 +192,4 @@ with safe.Extra(handler):

unsafe_torch_load = torch.load
torch.load = load
global_extra_handler = None

@@ -140,6 +140,9 @@ callback_map = dict(
callbacks_list_unets=[],
callbacks_before_token_counter=[],
)
event_subscriber_map = dict(
callbacks_setting_updated=[],
)

def clear_callbacks():
@@ -328,6 +331,23 @@ def before_token_counter_callback(params: BeforeTokenCounterParams):
report_exception(c, 'before_token_counter')

def setting_updated_event_subscriber_chain(handler, component, setting_name: str):
"""
Arguments:
- handler: The handler returned from calling an event subscriber.
- component: The component that was updated. The component should provide
the value of the setting after the update.
- setting_name: The name of the setting.
"""
for param in event_subscriber_map['callbacks_setting_updated']:
handler = handler.then(
fn=lambda *args: param["fn"](*args, setting_name),
inputs=param["inputs"] + [component],
outputs=param["outputs"],
show_progress=False,
)

def add_callback(callbacks, fun):
stack = [x for x in inspect.stack() if x.filename != __file__]
filename = stack[0].filename if stack else 'unknown file'
@@ -509,3 +529,14 @@ def on_before_token_counter(callback):
The function will be called with one argument of type BeforeTokenCounterParams, and should modify its fields if necessary."""

add_callback(callback_map['callbacks_before_token_counter'], callback)

def on_setting_updated_subscriber(subscriber_params):
"""Register a function to be called after a setting is updated. `subscriber_params`
should contain the fields needed to register a gradio event handler:
["fn", "outputs", "inputs"].
The updated setting's value and name are appended to `inputs`, so make sure
the callback function handles these extra parameters.
"""
event_subscriber_map['callbacks_setting_updated'].append(subscriber_params)
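A hypothetical registration from an extension's UI code, assuming a gradio Textbox created by that extension; the two trailing parameters come from the chaining helper above:

import gradio as gr
from modules import script_callbacks

status_box = gr.Textbox(label="Last updated setting")

def show_setting_update(new_value, setting_name):
    # the updated value and the setting name are appended by the chain helper
    return f"{setting_name} -> {new_value}"

script_callbacks.on_setting_updated_subscriber({
    "fn": show_setting_update,
    "inputs": [],
    "outputs": [status_box],
})
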
@@ -186,6 +186,14 @@ class Script:
"""
pass

def process_before_every_sampling(self, p, *args, **kwargs):
"""
Similar to process(), called before every sampling.
If you use high-res fix, this will be called two times.
"""

pass
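A hedged sketch of an always-on extension script using this new hook; the kwargs (x, noise, c, uc) mirror the call site shown earlier in processing.py, and the class name here is invented for illustration:

from modules import scripts

class LogNoiseScript(scripts.Script):
    def title(self):
        return "Log noise before sampling"

    def show(self, is_img2img):
        return scripts.AlwaysVisible

    def process_before_every_sampling(self, p, *args, **kwargs):
        noise = kwargs.get("noise")
        if noise is not None:
            # runs once per sampling pass, so twice when high-res fix is enabled
            print(f"sampling starts with noise of shape {tuple(noise.shape)}")
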
def process_batch(self, p, *args, **kwargs):
|
||||
"""
|
||||
Same as process(), but called for every batch.
|
||||
@@ -504,8 +512,14 @@ def load_scripts():
|
||||
|
||||
scripts_list = list_scripts("scripts", ".py") + list_scripts("modules/processing_scripts", ".py", include_extensions=False)
|
||||
|
||||
for s in scripts_list:
|
||||
if s.basedir not in sys.path:
|
||||
sys.path = [s.basedir] + sys.path
|
||||
|
||||
syspath = sys.path
|
||||
|
||||
# print(f'Current System Paths = {syspath}')
|
||||
|
||||
def register_scripts_from_module(module):
|
||||
for script_class in module.__dict__.values():
|
||||
if not inspect.isclass(script_class):
|
||||
@@ -809,6 +823,14 @@ class ScriptRunner:
|
||||
except Exception:
|
||||
errors.report(f"Error running process_batch: {script.filename}", exc_info=True)
|
||||
|
||||
def process_before_every_sampling(self, p, **kwargs):
|
||||
for script in self.alwayson_scripts:
|
||||
try:
|
||||
script_args = p.script_args[script.args_from:script.args_to]
|
||||
script.process_before_every_sampling(p, *script_args, **kwargs)
|
||||
except Exception:
|
||||
errors.report(f"Error running process_before_every_sampling: {script.filename}", exc_info=True)
|
||||
|
||||
def postprocess(self, p, processed):
|
||||
for script in self.alwayson_scripts:
|
||||
try:
|
||||
|
||||
@@ -57,57 +57,11 @@ def list_optimizers():
|
||||
|
||||
|
||||
def apply_optimizations(option=None):
|
||||
global current_optimizer
|
||||
|
||||
undo_optimizations()
|
||||
|
||||
if len(optimizers) == 0:
|
||||
# a script can access the model very early, and optimizations would not be filled by then
|
||||
current_optimizer = None
|
||||
return ''
|
||||
|
||||
ldm.modules.diffusionmodules.model.nonlinearity = silu
|
||||
ldm.modules.diffusionmodules.openaimodel.th = sd_hijack_unet.th
|
||||
|
||||
sgm.modules.diffusionmodules.model.nonlinearity = silu
|
||||
sgm.modules.diffusionmodules.openaimodel.th = sd_hijack_unet.th
|
||||
|
||||
if current_optimizer is not None:
|
||||
current_optimizer.undo()
|
||||
current_optimizer = None
|
||||
|
||||
selection = option or shared.opts.cross_attention_optimization
|
||||
if selection == "Automatic" and len(optimizers) > 0:
|
||||
matching_optimizer = next(iter([x for x in optimizers if x.cmd_opt and getattr(shared.cmd_opts, x.cmd_opt, False)]), optimizers[0])
|
||||
else:
|
||||
matching_optimizer = next(iter([x for x in optimizers if x.title() == selection]), None)
|
||||
|
||||
if selection == "None":
|
||||
matching_optimizer = None
|
||||
elif selection == "Automatic" and shared.cmd_opts.disable_opt_split_attention:
|
||||
matching_optimizer = None
|
||||
elif matching_optimizer is None:
|
||||
matching_optimizer = optimizers[0]
|
||||
|
||||
if matching_optimizer is not None:
|
||||
print(f"Applying attention optimization: {matching_optimizer.name}... ", end='')
|
||||
matching_optimizer.apply()
|
||||
print("done.")
|
||||
current_optimizer = matching_optimizer
|
||||
return current_optimizer.name
|
||||
else:
|
||||
print("Disabling attention optimization")
|
||||
return ''
|
||||
return
|
||||
|
||||
|
||||
def undo_optimizations():
|
||||
ldm.modules.diffusionmodules.model.nonlinearity = diffusionmodules_model_nonlinearity
|
||||
ldm.modules.attention.CrossAttention.forward = hypernetwork.attention_CrossAttention_forward
|
||||
ldm.modules.diffusionmodules.model.AttnBlock.forward = diffusionmodules_model_AttnBlock_forward
|
||||
|
||||
sgm.modules.diffusionmodules.model.nonlinearity = diffusionmodules_model_nonlinearity
|
||||
sgm.modules.attention.CrossAttention.forward = hypernetwork.attention_CrossAttention_forward
|
||||
sgm.modules.diffusionmodules.model.AttnBlock.forward = diffusionmodules_model_AttnBlock_forward
|
||||
return
|
||||
|
||||
|
||||
def fix_checkpoint():
|
||||
@@ -182,156 +136,30 @@ class StableDiffusionModelHijack:
|
||||
self.embedding_db.add_embedding_dir(cmd_opts.embeddings_dir)
|
||||
|
||||
def apply_optimizations(self, option=None):
|
||||
try:
|
||||
self.optimization_method = apply_optimizations(option)
|
||||
except Exception as e:
|
||||
errors.display(e, "applying cross attention optimization")
|
||||
undo_optimizations()
|
||||
pass
|
||||
|
||||
def convert_sdxl_to_ssd(self, m):
|
||||
"""Converts an SDXL model to a Segmind Stable Diffusion model (see https://huggingface.co/segmind/SSD-1B)"""
|
||||
|
||||
delattr(m.model.diffusion_model.middle_block, '1')
|
||||
delattr(m.model.diffusion_model.middle_block, '2')
|
||||
for i in ['9', '8', '7', '6', '5', '4']:
|
||||
delattr(m.model.diffusion_model.input_blocks[7][1].transformer_blocks, i)
|
||||
delattr(m.model.diffusion_model.input_blocks[8][1].transformer_blocks, i)
|
||||
delattr(m.model.diffusion_model.output_blocks[0][1].transformer_blocks, i)
|
||||
delattr(m.model.diffusion_model.output_blocks[1][1].transformer_blocks, i)
|
||||
delattr(m.model.diffusion_model.output_blocks[4][1].transformer_blocks, '1')
|
||||
delattr(m.model.diffusion_model.output_blocks[5][1].transformer_blocks, '1')
|
||||
devices.torch_gc()
|
||||
pass
|
||||
|
||||
def hijack(self, m):
|
||||
conditioner = getattr(m, 'conditioner', None)
|
||||
if conditioner:
|
||||
text_cond_models = []
|
||||
|
||||
for i in range(len(conditioner.embedders)):
|
||||
embedder = conditioner.embedders[i]
|
||||
typename = type(embedder).__name__
|
||||
if typename == 'FrozenOpenCLIPEmbedder':
|
||||
embedder.model.token_embedding = EmbeddingsWithFixes(embedder.model.token_embedding, self)
|
||||
conditioner.embedders[i] = sd_hijack_open_clip.FrozenOpenCLIPEmbedderWithCustomWords(embedder, self)
|
||||
text_cond_models.append(conditioner.embedders[i])
|
||||
if typename == 'FrozenCLIPEmbedder':
|
||||
model_embeddings = embedder.transformer.text_model.embeddings
|
||||
model_embeddings.token_embedding = EmbeddingsWithFixes(model_embeddings.token_embedding, self)
|
||||
conditioner.embedders[i] = sd_hijack_clip.FrozenCLIPEmbedderForSDXLWithCustomWords(embedder, self)
|
||||
text_cond_models.append(conditioner.embedders[i])
|
||||
if typename == 'FrozenOpenCLIPEmbedder2':
|
||||
embedder.model.token_embedding = EmbeddingsWithFixes(embedder.model.token_embedding, self, textual_inversion_key='clip_g')
|
||||
conditioner.embedders[i] = sd_hijack_open_clip.FrozenOpenCLIPEmbedder2WithCustomWords(embedder, self)
|
||||
text_cond_models.append(conditioner.embedders[i])
|
||||
|
||||
if len(text_cond_models) == 1:
|
||||
m.cond_stage_model = text_cond_models[0]
|
||||
else:
|
||||
m.cond_stage_model = conditioner
|
||||
|
||||
if type(m.cond_stage_model) == xlmr.BertSeriesModelWithTransformation or type(m.cond_stage_model) == xlmr_m18.BertSeriesModelWithTransformation:
|
||||
model_embeddings = m.cond_stage_model.roberta.embeddings
|
||||
model_embeddings.token_embedding = EmbeddingsWithFixes(model_embeddings.word_embeddings, self)
|
||||
m.cond_stage_model = sd_hijack_xlmr.FrozenXLMREmbedderWithCustomWords(m.cond_stage_model, self)
|
||||
|
||||
elif type(m.cond_stage_model) == ldm.modules.encoders.modules.FrozenCLIPEmbedder:
|
||||
model_embeddings = m.cond_stage_model.transformer.text_model.embeddings
|
||||
model_embeddings.token_embedding = EmbeddingsWithFixes(model_embeddings.token_embedding, self)
|
||||
m.cond_stage_model = sd_hijack_clip.FrozenCLIPEmbedderWithCustomWords(m.cond_stage_model, self)
|
||||
|
||||
elif type(m.cond_stage_model) == ldm.modules.encoders.modules.FrozenOpenCLIPEmbedder:
|
||||
m.cond_stage_model.model.token_embedding = EmbeddingsWithFixes(m.cond_stage_model.model.token_embedding, self)
|
||||
m.cond_stage_model = sd_hijack_open_clip.FrozenOpenCLIPEmbedderWithCustomWords(m.cond_stage_model, self)
|
||||
|
||||
apply_weighted_forward(m)
|
||||
if m.cond_stage_key == "edit":
|
||||
sd_hijack_unet.hijack_ddpm_edit()
|
||||
|
||||
self.apply_optimizations()
|
||||
|
||||
self.clip = m.cond_stage_model
|
||||
|
||||
def flatten(el):
|
||||
flattened = [flatten(children) for children in el.children()]
|
||||
res = [el]
|
||||
for c in flattened:
|
||||
res += c
|
||||
return res
|
||||
|
||||
self.layers = flatten(m)
|
||||
|
||||
import modules.models.diffusion.ddpm_edit
|
||||
|
||||
if isinstance(m, ldm.models.diffusion.ddpm.LatentDiffusion):
|
||||
sd_unet.original_forward = ldm_original_forward
|
||||
elif isinstance(m, modules.models.diffusion.ddpm_edit.LatentDiffusion):
|
||||
sd_unet.original_forward = ldm_original_forward
|
||||
elif isinstance(m, sgm.models.diffusion.DiffusionEngine):
|
||||
sd_unet.original_forward = sgm_original_forward
|
||||
else:
|
||||
sd_unet.original_forward = None
|
||||
|
||||
pass
|
||||
|
||||
def undo_hijack(self, m):
|
||||
conditioner = getattr(m, 'conditioner', None)
|
||||
if conditioner:
|
||||
for i in range(len(conditioner.embedders)):
|
||||
embedder = conditioner.embedders[i]
|
||||
if isinstance(embedder, (sd_hijack_open_clip.FrozenOpenCLIPEmbedderWithCustomWords, sd_hijack_open_clip.FrozenOpenCLIPEmbedder2WithCustomWords)):
|
||||
embedder.wrapped.model.token_embedding = embedder.wrapped.model.token_embedding.wrapped
|
||||
conditioner.embedders[i] = embedder.wrapped
|
||||
if isinstance(embedder, sd_hijack_clip.FrozenCLIPEmbedderForSDXLWithCustomWords):
|
||||
embedder.wrapped.transformer.text_model.embeddings.token_embedding = embedder.wrapped.transformer.text_model.embeddings.token_embedding.wrapped
|
||||
conditioner.embedders[i] = embedder.wrapped
|
||||
|
||||
if hasattr(m, 'cond_stage_model'):
|
||||
delattr(m, 'cond_stage_model')
|
||||
|
||||
elif type(m.cond_stage_model) == sd_hijack_xlmr.FrozenXLMREmbedderWithCustomWords:
|
||||
m.cond_stage_model = m.cond_stage_model.wrapped
|
||||
|
||||
elif type(m.cond_stage_model) == sd_hijack_clip.FrozenCLIPEmbedderWithCustomWords:
|
||||
m.cond_stage_model = m.cond_stage_model.wrapped
|
||||
|
||||
model_embeddings = m.cond_stage_model.transformer.text_model.embeddings
|
||||
if type(model_embeddings.token_embedding) == EmbeddingsWithFixes:
|
||||
model_embeddings.token_embedding = model_embeddings.token_embedding.wrapped
|
||||
elif type(m.cond_stage_model) == sd_hijack_open_clip.FrozenOpenCLIPEmbedderWithCustomWords:
|
||||
m.cond_stage_model.wrapped.model.token_embedding = m.cond_stage_model.wrapped.model.token_embedding.wrapped
|
||||
m.cond_stage_model = m.cond_stage_model.wrapped
|
||||
|
||||
undo_optimizations()
|
||||
undo_weighted_forward(m)
|
||||
|
||||
self.apply_circular(False)
|
||||
self.layers = None
|
||||
self.clip = None
|
||||
|
||||
pass
|
||||
|
||||
def apply_circular(self, enable):
|
||||
if self.circular_enabled == enable:
|
||||
return
|
||||
|
||||
self.circular_enabled = enable
|
||||
|
||||
for layer in [layer for layer in self.layers if type(layer) == torch.nn.Conv2d]:
|
||||
layer.padding_mode = 'circular' if enable else 'zeros'
|
||||
pass
|
||||
|
||||
def clear_comments(self):
|
||||
self.comments = []
|
||||
self.extra_generation_params = {}
|
||||
|
||||
def get_prompt_lengths(self, text):
|
||||
if self.clip is None:
|
||||
return "-", "-"
|
||||
|
||||
_, token_count = self.clip.process_texts([text])
|
||||
|
||||
return token_count, self.clip.get_target_prompt_token_count(token_count)
|
||||
def get_prompt_lengths(self, text, cond_stage_model):
|
||||
_, token_count = cond_stage_model.process_texts([text])
|
||||
return token_count, cond_stage_model.get_target_prompt_token_count(token_count)
|
||||
|
||||
def redo_hijack(self, m):
|
||||
self.undo_hijack(m)
|
||||
self.hijack(m)
|
||||
pass
|
||||
|
||||
|
||||
class EmbeddingsWithFixes(torch.nn.Module):
|
||||
@@ -340,6 +168,7 @@ class EmbeddingsWithFixes(torch.nn.Module):
|
||||
self.wrapped = wrapped
|
||||
self.embeddings = embeddings
|
||||
self.textual_inversion_key = textual_inversion_key
|
||||
self.weight = self.wrapped.weight
|
||||
|
||||
def forward(self, input_ids):
|
||||
batch_fixes = self.embeddings.fixes
|
||||
|
||||
@@ -10,13 +10,19 @@ from omegaconf import OmegaConf, ListConfig
|
||||
from os import mkdir
|
||||
from urllib import request
|
||||
import ldm.modules.midas as midas
|
||||
import gc
|
||||
|
||||
from ldm.util import instantiate_from_config
|
||||
|
||||
from modules import paths, shared, modelloader, devices, script_callbacks, sd_vae, sd_disable_initialization, errors, hashes, sd_models_config, sd_unet, sd_models_xl, cache, extra_networks, processing, lowvram, sd_hijack, patches
|
||||
from modules.timer import Timer
|
||||
import tomesd
|
||||
import numpy as np
|
||||
from modules_forge import forge_loader
|
||||
import modules_forge.ops as forge_ops
|
||||
from ldm_patched.modules.ops import manual_cast
|
||||
from ldm_patched.modules import model_management as model_management
|
||||
import ldm_patched.modules.model_patcher
|
||||
|
||||
|
||||
model_dir = "Stable-diffusion"
|
||||
model_path = os.path.abspath(os.path.join(paths.models_path, model_dir))
|
||||
@@ -150,9 +156,9 @@ def list_models():
|
||||
if shared.cmd_opts.no_download_sd_model or cmd_ckpt != shared.sd_model_file or os.path.exists(cmd_ckpt):
|
||||
model_url = None
|
||||
else:
|
||||
model_url = "https://huggingface.co/runwayml/stable-diffusion-v1-5/resolve/main/v1-5-pruned-emaonly.safetensors"
|
||||
model_url = "https://huggingface.co/lllyasviel/fav_models/resolve/main/fav/realisticVisionV51_v51VAE.safetensors"
|
||||
|
||||
model_list = modelloader.load_models(model_path=model_path, model_url=model_url, command_path=shared.cmd_opts.ckpt_dir, ext_filter=[".ckpt", ".safetensors"], download_name="v1-5-pruned-emaonly.safetensors", ext_blacklist=[".vae.ckpt", ".vae.safetensors"])
|
||||
model_list = modelloader.load_models(model_path=model_path, model_url=model_url, command_path=shared.cmd_opts.ckpt_dir, ext_filter=[".ckpt", ".safetensors"], download_name="realisticVisionV51_v51VAE.safetensors", ext_blacklist=[".vae.ckpt", ".vae.safetensors"])
|
||||
|
||||
if os.path.exists(cmd_ckpt):
|
||||
checkpoint_info = CheckpointInfo(cmd_ckpt)
|
||||
@@ -366,26 +372,12 @@ def load_model_weights(model, checkpoint_info: CheckpointInfo, state_dict, timer
|
||||
sd_model_hash = checkpoint_info.calculate_shorthash()
|
||||
timer.record("calculate hash")
|
||||
|
||||
if devices.fp8:
|
||||
# prevent model to load state dict in fp8
|
||||
model.half()
|
||||
|
||||
if not SkipWritingToConfig.skip:
|
||||
shared.opts.data["sd_model_checkpoint"] = checkpoint_info.title
|
||||
|
||||
if state_dict is None:
|
||||
state_dict = get_checkpoint_state_dict(checkpoint_info, timer)
|
||||
|
||||
model.is_sdxl = hasattr(model, 'conditioner')
|
||||
model.is_sd2 = not model.is_sdxl and hasattr(model.cond_stage_model, 'model')
|
||||
model.is_sd1 = not model.is_sdxl and not model.is_sd2
|
||||
model.is_ssd = model.is_sdxl and 'model.diffusion_model.middle_block.1.transformer_blocks.0.attn1.to_q.weight' not in state_dict.keys()
|
||||
if model.is_sdxl:
|
||||
sd_models_xl.extend_sdxl(model)
|
||||
|
||||
if model.is_ssd:
|
||||
sd_hijack.model_hijack.convert_sdxl_to_ssd(model)
|
||||
|
||||
if shared.opts.sd_checkpoint_cache > 0:
|
||||
# cache newly loaded model
|
||||
checkpoints_loaded[checkpoint_info] = state_dict.copy()
|
||||
@@ -395,65 +387,6 @@ def load_model_weights(model, checkpoint_info: CheckpointInfo, state_dict, timer
|
||||
|
||||
del state_dict
|
||||
|
||||
if shared.cmd_opts.opt_channelslast:
|
||||
model.to(memory_format=torch.channels_last)
|
||||
timer.record("apply channels_last")
|
||||
|
||||
if shared.cmd_opts.no_half:
|
||||
model.float()
|
||||
model.alphas_cumprod_original = model.alphas_cumprod
|
||||
devices.dtype_unet = torch.float32
|
||||
timer.record("apply float()")
|
||||
else:
|
||||
vae = model.first_stage_model
|
||||
depth_model = getattr(model, 'depth_model', None)
|
||||
|
||||
# with --no-half-vae, remove VAE from model when doing half() to prevent its weights from being converted to float16
|
||||
if shared.cmd_opts.no_half_vae:
|
||||
model.first_stage_model = None
|
||||
# with --upcast-sampling, don't convert the depth model weights to float16
|
||||
if shared.cmd_opts.upcast_sampling and depth_model:
|
||||
model.depth_model = None
|
||||
|
||||
alphas_cumprod = model.alphas_cumprod
|
||||
model.alphas_cumprod = None
|
||||
model.half()
|
||||
model.alphas_cumprod = alphas_cumprod
|
||||
model.alphas_cumprod_original = alphas_cumprod
|
||||
model.first_stage_model = vae
|
||||
if depth_model:
|
||||
model.depth_model = depth_model
|
||||
|
||||
devices.dtype_unet = torch.float16
|
||||
timer.record("apply half()")
|
||||
|
||||
for module in model.modules():
|
||||
if hasattr(module, 'fp16_weight'):
|
||||
del module.fp16_weight
|
||||
if hasattr(module, 'fp16_bias'):
|
||||
del module.fp16_bias
|
||||
|
||||
if check_fp8(model):
|
||||
devices.fp8 = True
|
||||
first_stage = model.first_stage_model
|
||||
model.first_stage_model = None
|
||||
for module in model.modules():
|
||||
if isinstance(module, (torch.nn.Conv2d, torch.nn.Linear)):
|
||||
if shared.opts.cache_fp16_weight:
|
||||
module.fp16_weight = module.weight.data.clone().cpu().half()
|
||||
if module.bias is not None:
|
||||
module.fp16_bias = module.bias.data.clone().cpu().half()
|
||||
module.to(torch.float8_e4m3fn)
|
||||
model.first_stage_model = first_stage
|
||||
timer.record("apply fp8")
|
||||
else:
|
||||
devices.fp8 = False
|
||||
|
||||
devices.unet_needs_upcast = shared.cmd_opts.upcast_sampling and devices.dtype == torch.float16 and devices.dtype_unet == torch.float16
|
||||
|
||||
model.first_stage_model.to(devices.dtype_vae)
|
||||
timer.record("apply dtype to VAE")
|
||||
|
||||
# clean up cache if limit is reached
|
||||
while len(checkpoints_loaded) > shared.opts.sd_checkpoint_cache:
|
||||
checkpoints_loaded.popitem(last=False)
|
||||
@@ -590,14 +523,6 @@ class SdModelData:
|
||||
sd_vae.loaded_vae_file = getattr(v, "loaded_vae_file", None)
|
||||
sd_vae.checkpoint_info = v.sd_checkpoint_info
|
||||
|
||||
try:
|
||||
self.loaded_sd_models.remove(v)
|
||||
except ValueError:
|
||||
pass
|
||||
|
||||
if v is not None:
|
||||
self.loaded_sd_models.insert(0, v)
|
||||
|
||||
|
||||
model_data = SdModelData()
|
||||
|
||||
@@ -615,31 +540,19 @@ def get_empty_cond(sd_model):
|
||||
|
||||
|
||||
def send_model_to_cpu(m):
|
||||
if m.lowvram:
|
||||
lowvram.send_everything_to_cpu()
|
||||
else:
|
||||
m.to(devices.cpu)
|
||||
|
||||
devices.torch_gc()
|
||||
pass
|
||||
|
||||
|
||||
def model_target_device(m):
|
||||
if lowvram.is_needed(m):
|
||||
return devices.cpu
|
||||
else:
|
||||
return devices.device
|
||||
return devices.device
|
||||
|
||||
|
||||
def send_model_to_device(m):
|
||||
lowvram.apply(m)
|
||||
|
||||
if not m.lowvram:
|
||||
m.to(shared.device)
|
||||
pass
|
||||
|
||||
|
||||
def send_model_to_trash(m):
|
||||
m.to(device="meta")
|
||||
devices.torch_gc()
|
||||
pass
|
||||
|
||||
|
||||
def load_model(checkpoint_info=None, already_loaded_state_dict=None):
|
||||
@@ -649,9 +562,14 @@ def load_model(checkpoint_info=None, already_loaded_state_dict=None):
|
||||
timer = Timer()
|
||||
|
||||
if model_data.sd_model:
|
||||
send_model_to_trash(model_data.sd_model)
|
||||
if model_data.sd_model.filename == checkpoint_info.filename:
|
||||
return model_data.sd_model
|
||||
|
||||
model_data.sd_model = None
|
||||
devices.torch_gc()
|
||||
model_data.loaded_sd_models = []
|
||||
model_management.unload_all_models()
|
||||
model_management.soft_empty_cache()
|
||||
gc.collect()
|
||||
|
||||
timer.record("unload existing model")
|
||||
|
||||
@@ -660,58 +578,27 @@ def load_model(checkpoint_info=None, already_loaded_state_dict=None):
|
||||
else:
|
||||
state_dict = get_checkpoint_state_dict(checkpoint_info, timer)
|
||||
|
||||
checkpoint_config = sd_models_config.find_checkpoint_config(state_dict, checkpoint_info)
|
||||
clip_is_included_into_sd = any(x for x in [sd1_clip_weight, sd2_clip_weight, sdxl_clip_weight, sdxl_refiner_clip_weight] if x in state_dict)
|
||||
if shared.opts.sd_checkpoint_cache > 0:
|
||||
# cache newly loaded model
|
||||
checkpoints_loaded[checkpoint_info] = state_dict.copy()
|
||||
|
||||
timer.record("find config")
|
||||
sd_model = forge_loader.load_model_for_a1111(timer=timer, checkpoint_info=checkpoint_info, state_dict=state_dict)
|
||||
sd_model.filename = checkpoint_info.filename
|
||||
|
||||
sd_config = OmegaConf.load(checkpoint_config)
|
||||
repair_config(sd_config)
|
||||
del state_dict
|
||||
|
||||
timer.record("load config")
|
||||
# clean up cache if limit is reached
|
||||
while len(checkpoints_loaded) > shared.opts.sd_checkpoint_cache:
|
||||
checkpoints_loaded.popitem(last=False)
|
||||
|
||||
print(f"Creating model from config: {checkpoint_config}")
|
||||
shared.opts.data["sd_checkpoint_hash"] = checkpoint_info.sha256
|
||||
|
||||
sd_model = None
|
||||
try:
|
||||
with sd_disable_initialization.DisableInitialization(disable_clip=clip_is_included_into_sd or shared.cmd_opts.do_not_download_clip):
|
||||
with sd_disable_initialization.InitializeOnMeta():
|
||||
sd_model = instantiate_from_config(sd_config.model)
|
||||
sd_vae.delete_base_vae()
|
||||
sd_vae.clear_loaded_vae()
|
||||
vae_file, vae_source = sd_vae.resolve_vae(checkpoint_info.filename).tuple()
|
||||
sd_vae.load_vae(sd_model, vae_file, vae_source)
|
||||
timer.record("load VAE")
|
||||
|
||||
except Exception as e:
|
||||
errors.display(e, "creating model quickly", full_traceback=True)
|
||||
|
||||
if sd_model is None:
|
||||
print('Failed to create model quickly; will retry using slow method.', file=sys.stderr)
|
||||
|
||||
with sd_disable_initialization.InitializeOnMeta():
|
||||
sd_model = instantiate_from_config(sd_config.model)
|
||||
|
||||
sd_model.used_config = checkpoint_config
|
||||
|
||||
timer.record("create model")
|
||||
|
||||
if shared.cmd_opts.no_half:
|
||||
weight_dtype_conversion = None
|
||||
else:
|
||||
weight_dtype_conversion = {
|
||||
'first_stage_model': None,
|
||||
'alphas_cumprod': None,
|
||||
'': torch.float16,
|
||||
}
|
||||
|
||||
with sd_disable_initialization.LoadStateDictOnMeta(state_dict, device=model_target_device(sd_model), weight_dtype_conversion=weight_dtype_conversion):
|
||||
load_model_weights(sd_model, checkpoint_info, state_dict, timer)
|
||||
timer.record("load weights from state dict")
|
||||
|
||||
send_model_to_device(sd_model)
|
||||
timer.record("move model to device")
|
||||
|
||||
sd_hijack.model_hijack.hijack(sd_model)
|
||||
|
||||
timer.record("hijack")
|
||||
|
||||
sd_model.eval()
|
||||
model_data.set_sd_model(sd_model)
|
||||
model_data.was_loaded_at_least_once = True
|
||||
|
||||
@@ -723,7 +610,7 @@ def load_model(checkpoint_info=None, already_loaded_state_dict=None):
|
||||
|
||||
timer.record("scripts callbacks")
|
||||
|
||||
with devices.autocast(), torch.no_grad():
|
||||
with torch.no_grad():
|
||||
sd_model.cond_stage_model_empty_prompt = get_empty_cond(sd_model)
|
||||
|
||||
timer.record("calculate empty prompt")
|
||||
@@ -734,156 +621,20 @@ def load_model(checkpoint_info=None, already_loaded_state_dict=None):
|
||||
|
||||
|
||||
def reuse_model_from_already_loaded(sd_model, checkpoint_info, timer):
|
||||
"""
|
||||
Checks if the desired checkpoint from checkpoint_info is not already loaded in model_data.loaded_sd_models.
|
||||
If it is loaded, returns that (moving it to GPU if necessary, and moving the currently loaded model to CPU if necessary).
|
||||
If not, returns the model that can be used to load weights from checkpoint_info's file.
|
||||
If no such model exists, returns None.
|
||||
Additionally deletes loaded models that are over the limit set in settings (sd_checkpoints_limit).
|
||||
"""
|
||||
|
||||
already_loaded = None
|
||||
for i in reversed(range(len(model_data.loaded_sd_models))):
|
||||
loaded_model = model_data.loaded_sd_models[i]
|
||||
if loaded_model.sd_checkpoint_info.filename == checkpoint_info.filename:
|
||||
already_loaded = loaded_model
|
||||
continue
|
||||
|
||||
if len(model_data.loaded_sd_models) > shared.opts.sd_checkpoints_limit > 0:
|
||||
print(f"Unloading model {len(model_data.loaded_sd_models)} over the limit of {shared.opts.sd_checkpoints_limit}: {loaded_model.sd_checkpoint_info.title}")
|
||||
model_data.loaded_sd_models.pop()
|
||||
send_model_to_trash(loaded_model)
|
||||
timer.record("send model to trash")
|
||||
|
||||
if shared.opts.sd_checkpoints_keep_in_cpu:
|
||||
send_model_to_cpu(sd_model)
|
||||
timer.record("send model to cpu")
|
||||
|
||||
if already_loaded is not None:
|
||||
send_model_to_device(already_loaded)
|
||||
timer.record("send model to device")
|
||||
|
||||
model_data.set_sd_model(already_loaded, already_loaded=True)
|
||||
|
||||
if not SkipWritingToConfig.skip:
|
||||
shared.opts.data["sd_model_checkpoint"] = already_loaded.sd_checkpoint_info.title
|
||||
shared.opts.data["sd_checkpoint_hash"] = already_loaded.sd_checkpoint_info.sha256
|
||||
|
||||
print(f"Using already loaded model {already_loaded.sd_checkpoint_info.title}: done in {timer.summary()}")
|
||||
sd_vae.reload_vae_weights(already_loaded)
|
||||
return model_data.sd_model
|
||||
elif shared.opts.sd_checkpoints_limit > 1 and len(model_data.loaded_sd_models) < shared.opts.sd_checkpoints_limit:
|
||||
print(f"Loading model {checkpoint_info.title} ({len(model_data.loaded_sd_models) + 1} out of {shared.opts.sd_checkpoints_limit})")
|
||||
|
||||
model_data.sd_model = None
|
||||
load_model(checkpoint_info)
|
||||
return model_data.sd_model
|
||||
elif len(model_data.loaded_sd_models) > 0:
|
||||
sd_model = model_data.loaded_sd_models.pop()
|
||||
model_data.sd_model = sd_model
|
||||
|
||||
sd_vae.base_vae = getattr(sd_model, "base_vae", None)
|
||||
sd_vae.loaded_vae_file = getattr(sd_model, "loaded_vae_file", None)
|
||||
sd_vae.checkpoint_info = sd_model.sd_checkpoint_info
|
||||
|
||||
print(f"Reusing loaded model {sd_model.sd_checkpoint_info.title} to load {checkpoint_info.title}")
|
||||
return sd_model
|
||||
else:
|
||||
return None
|
||||
pass
|
||||
|
||||
|
||||
def reload_model_weights(sd_model=None, info=None, forced_reload=False):
|
||||
checkpoint_info = info or select_checkpoint()
|
||||
|
||||
timer = Timer()
|
||||
|
||||
if not sd_model:
|
||||
sd_model = model_data.sd_model
|
||||
|
||||
if sd_model is None: # previous model load failed
|
||||
current_checkpoint_info = None
|
||||
else:
|
||||
current_checkpoint_info = sd_model.sd_checkpoint_info
|
||||
if check_fp8(sd_model) != devices.fp8:
|
||||
# load from state dict again to prevent extra numerical errors
|
||||
forced_reload = True
|
||||
elif sd_model.sd_model_checkpoint == checkpoint_info.filename and not forced_reload:
|
||||
return sd_model
|
||||
|
||||
sd_model = reuse_model_from_already_loaded(sd_model, checkpoint_info, timer)
|
||||
if not forced_reload and sd_model is not None and sd_model.sd_checkpoint_info.filename == checkpoint_info.filename:
|
||||
return sd_model
|
||||
|
||||
if sd_model is not None:
|
||||
sd_unet.apply_unet("None")
|
||||
send_model_to_cpu(sd_model)
|
||||
sd_hijack.model_hijack.undo_hijack(sd_model)
|
||||
|
||||
state_dict = get_checkpoint_state_dict(checkpoint_info, timer)
|
||||
|
||||
checkpoint_config = sd_models_config.find_checkpoint_config(state_dict, checkpoint_info)
|
||||
|
||||
timer.record("find config")
|
||||
|
||||
if sd_model is None or checkpoint_config != sd_model.used_config:
|
||||
if sd_model is not None:
|
||||
send_model_to_trash(sd_model)
|
||||
|
||||
load_model(checkpoint_info, already_loaded_state_dict=state_dict)
|
||||
return model_data.sd_model
|
||||
|
||||
try:
|
||||
load_model_weights(sd_model, checkpoint_info, state_dict, timer)
|
||||
except Exception:
|
||||
print("Failed to load checkpoint, restoring previous")
|
||||
load_model_weights(sd_model, current_checkpoint_info, None, timer)
|
||||
raise
|
||||
finally:
|
||||
sd_hijack.model_hijack.hijack(sd_model)
|
||||
timer.record("hijack")
|
||||
|
||||
if not sd_model.lowvram:
|
||||
sd_model.to(devices.device)
|
||||
timer.record("move model to device")
|
||||
|
||||
script_callbacks.model_loaded_callback(sd_model)
|
||||
timer.record("script callbacks")
|
||||
|
||||
print(f"Weights loaded in {timer.summary()}.")
|
||||
|
||||
model_data.set_sd_model(sd_model)
|
||||
sd_unet.apply_unet()
|
||||
|
||||
return sd_model
|
||||
return load_model(info)
|
||||
|
||||
|
||||
def unload_model_weights(sd_model=None, info=None):
|
||||
send_model_to_cpu(sd_model or shared.sd_model)
|
||||
|
||||
return sd_model
|
||||
|
||||
|
||||
def apply_token_merging(sd_model, token_merging_ratio):
|
||||
"""
|
||||
Applies speed and memory optimizations from tomesd.
|
||||
"""
|
||||
|
||||
current_token_merging_ratio = getattr(sd_model, 'applied_token_merged_ratio', 0)
|
||||
|
||||
if current_token_merging_ratio == token_merging_ratio:
|
||||
return
|
||||
|
||||
if current_token_merging_ratio > 0:
|
||||
tomesd.remove_patch(sd_model)
|
||||
|
||||
if token_merging_ratio > 0:
|
||||
tomesd.apply_patch(
|
||||
sd_model,
|
||||
ratio=token_merging_ratio,
|
||||
use_rand=False, # can cause issues with some samplers
|
||||
merge_attn=True,
|
||||
merge_crossattn=False,
|
||||
merge_mlp=False
|
||||
)
|
||||
print('Token merging is under construction now and the setting will not take effect.')
|
||||
|
||||
sd_model.applied_token_merged_ratio = token_merging_ratio
|
||||
# TODO: rework using new UNet patcher system
|
||||
return
|
||||
|
||||
@@ -8,8 +8,13 @@ import sgm.modules.diffusionmodules.discretizer
|
||||
from modules import devices, shared, prompt_parser
|
||||
from modules import torch_utils
|
||||
|
||||
import ldm_patched.modules.model_management as model_management
|
||||
from modules_forge.forge_clip import move_clip_to_gpu
|
||||
|
||||
|
||||
def get_learned_conditioning(self: sgm.models.diffusion.DiffusionEngine, batch: prompt_parser.SdConditioning | list[str]):
|
||||
move_clip_to_gpu()
|
||||
|
||||
for embedder in self.conditioner.embedders:
|
||||
embedder.ucg_rate = 0.0
|
||||
|
||||
@@ -18,7 +23,7 @@ def get_learned_conditioning(self: sgm.models.diffusion.DiffusionEngine, batch:
|
||||
is_negative_prompt = getattr(batch, 'is_negative_prompt', False)
|
||||
aesthetic_score = shared.opts.sdxl_refiner_low_aesthetic_score if is_negative_prompt else shared.opts.sdxl_refiner_high_aesthetic_score
|
||||
|
||||
devices_args = dict(device=devices.device, dtype=devices.dtype)
|
||||
devices_args = dict(device=self.forge_objects.clip.patcher.current_device, dtype=model_management.text_encoder_dtype())
|
||||
|
||||
sdxl_conds = {
|
||||
"txt": batch,
|
||||
@@ -34,14 +39,11 @@ def get_learned_conditioning(self: sgm.models.diffusion.DiffusionEngine, batch:
|
||||
return c
|
||||
|
||||
|
||||
def apply_model(self: sgm.models.diffusion.DiffusionEngine, x, t, cond):
|
||||
sd = self.model.state_dict()
|
||||
diffusion_model_input = sd.get('diffusion_model.input_blocks.0.0.weight', None)
|
||||
if diffusion_model_input is not None:
|
||||
if diffusion_model_input.shape[1] == 9:
|
||||
x = torch.cat([x] + cond['c_concat'], dim=1)
|
||||
def apply_model(self: sgm.models.diffusion.DiffusionEngine, x, t, cond, *args, **kwargs):
|
||||
if self.model.diffusion_model.in_channels == 9:
|
||||
x = torch.cat([x] + cond['c_concat'], dim=1)
|
||||
|
||||
return self.model(x, t, cond)
|
||||
return self.model(x, t, cond, *args, **kwargs)
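A hedged illustration of the 9-channel inpainting input assumed by the check above: four latent channels, plus a one-channel mask and four masked-image latent channels concatenated along the channel dimension (shapes are illustrative):

import torch

latent = torch.randn(1, 4, 64, 64)               # noisy latent being denoised
mask = torch.zeros(1, 1, 64, 64)                 # inpainting mask
masked_image_latent = torch.randn(1, 4, 64, 64)  # VAE-encoded masked image
unet_input = torch.cat([latent, mask, masked_image_latent], dim=1)
assert unet_input.shape[1] == 9
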
|
||||
|
||||
|
||||
def get_first_stage_encoding(self, x): # SDXL's encode_first_stage does everything so get_first_stage_encoding is just there for compatibility
|
||||
|
||||
@@ -2,11 +2,13 @@ from modules import sd_samplers_kdiffusion, sd_samplers_timesteps, sd_samplers_l
|
||||
|
||||
# imports for functions that previously were here and are used by other modules
|
||||
from modules.sd_samplers_common import samples_to_image_grid, sample_to_image # noqa: F401
|
||||
from modules_forge import forge_alter_samplers
|
||||
|
||||
all_samplers = [
|
||||
*sd_samplers_kdiffusion.samplers_data_k_diffusion,
|
||||
*sd_samplers_timesteps.samplers_data_timesteps,
|
||||
*sd_samplers_lcm.samplers_data_lcm,
|
||||
*forge_alter_samplers.samplers_data_alter
|
||||
]
|
||||
all_samplers_map = {x.name: x for x in all_samplers}
|
||||
|
||||
|
||||
@@ -6,6 +6,7 @@ import modules.shared as shared
|
||||
from modules.script_callbacks import CFGDenoiserParams, cfg_denoiser_callback
|
||||
from modules.script_callbacks import CFGDenoisedParams, cfg_denoised_callback
|
||||
from modules.script_callbacks import AfterCFGCallbackParams, cfg_after_cfg_callback
|
||||
from modules_forge import forge_sampler
|
||||
|
||||
|
||||
def catenate_conds(conds):
|
||||
@@ -58,15 +59,16 @@ class CFGDenoiser(torch.nn.Module):
|
||||
self.model_wrap = None
|
||||
self.p = None
|
||||
|
||||
# NOTE: masking before denoising can cause the original latents to be oversmoothed
|
||||
# as the original latents do not have noise
|
||||
# Backward Compatibility
|
||||
self.mask_before_denoising = False
|
||||
|
||||
self.classic_ddim_eps_estimation = False
|
||||
|
||||
@property
|
||||
def inner_model(self):
|
||||
raise NotImplementedError()
|
||||
|
||||
def combine_denoised(self, x_out, conds_list, uncond, cond_scale):
|
||||
def combine_denoised(self, x_out, conds_list, uncond, cond_scale, timestep, x_in, cond):
|
||||
denoised_uncond = x_out[-uncond.shape[0]:]
|
||||
denoised = torch.clone(denoised_uncond)
|
||||
|
||||
@@ -152,142 +154,38 @@ class CFGDenoiser(torch.nn.Module):
|
||||
if state.interrupted or state.skipped:
|
||||
raise sd_samplers_common.InterruptedException
|
||||
|
||||
if sd_samplers_common.apply_refiner(self):
|
||||
original_x_device = x.device
|
||||
original_x_dtype = x.dtype
|
||||
|
||||
if self.classic_ddim_eps_estimation:
|
||||
acd = self.inner_model.inner_model.alphas_cumprod
|
||||
fake_sigmas = ((1 - acd) / acd) ** 0.5
|
||||
real_sigma = fake_sigmas[sigma.round().long().clip(0, int(fake_sigmas.shape[0]))]
|
||||
real_sigma_data = 1.0
|
||||
x = x * (((real_sigma ** 2.0 + real_sigma_data ** 2.0) ** 0.5)[:, None, None, None])
|
||||
sigma = real_sigma
|
||||
|
||||
if sd_samplers_common.apply_refiner(self, x):
|
||||
cond = self.sampler.sampler_extra_args['cond']
|
||||
uncond = self.sampler.sampler_extra_args['uncond']
|
||||
|
||||
# at self.image_cfg_scale == 1.0 produced results for edit model are the same as with normal sampling,
|
||||
# so is_edit_model is set to False to support AND composition.
|
||||
is_edit_model = shared.sd_model.cond_stage_key == "edit" and self.image_cfg_scale is not None and self.image_cfg_scale != 1.0
|
||||
|
||||
conds_list, tensor = prompt_parser.reconstruct_multicond_batch(cond, self.step)
|
||||
cond_composition, cond = prompt_parser.reconstruct_multicond_batch(cond, self.step)
|
||||
uncond = prompt_parser.reconstruct_cond_batch(uncond, self.step)
|
||||
|
||||
assert not is_edit_model or all(len(conds) == 1 for conds in conds_list), "AND is not supported for InstructPix2Pix checkpoint (unless using Image CFG scale = 1.0)"
|
||||
if self.mask is not None:
|
||||
noisy_initial_latent = self.init_latent + sigma[:, None, None, None] * torch.randn_like(self.init_latent).to(self.init_latent)
|
||||
x = x * self.nmask + noisy_initial_latent * self.mask
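A hedged sketch of the inpainting blend above, with made-up shapes: masked regions are pinned to a re-noised copy of the original latent while unmasked regions follow normal sampling (mask = 1 keeps the original):

import torch

init_latent = torch.randn(1, 4, 64, 64)
x = torch.randn_like(init_latent)        # current noisy latent
mask = torch.zeros(1, 1, 64, 64)
mask[..., :32] = 1.0                     # keep the left half of the image
nmask = 1.0 - mask
sigma = torch.tensor([3.0])
noisy_initial = init_latent + sigma[:, None, None, None] * torch.randn_like(init_latent)
x = x * nmask + noisy_initial * mask
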
|
||||
|
||||
# If we use masks, blending between the denoised and original latent images occurs here.
|
||||
def apply_blend(current_latent):
|
||||
blended_latent = current_latent * self.nmask + self.init_latent * self.mask
|
||||
|
||||
if self.p.scripts is not None:
|
||||
from modules import scripts
|
||||
mba = scripts.MaskBlendArgs(current_latent, self.nmask, self.init_latent, self.mask, blended_latent, denoiser=self, sigma=sigma)
|
||||
self.p.scripts.on_mask_blend(self.p, mba)
|
||||
blended_latent = mba.blended_latent
|
||||
|
||||
return blended_latent
|
||||
|
||||
# Blend in the original latents (before)
|
||||
if self.mask_before_denoising and self.mask is not None:
|
||||
x = apply_blend(x)
|
||||
|
||||
batch_size = len(conds_list)
|
||||
repeats = [len(conds_list[i]) for i in range(batch_size)]
|
||||
|
||||
if shared.sd_model.model.conditioning_key == "crossattn-adm":
|
||||
image_uncond = torch.zeros_like(image_cond)
|
||||
make_condition_dict = lambda c_crossattn, c_adm: {"c_crossattn": [c_crossattn], "c_adm": c_adm}
|
||||
else:
|
||||
image_uncond = image_cond
|
||||
if isinstance(uncond, dict):
|
||||
make_condition_dict = lambda c_crossattn, c_concat: {**c_crossattn, "c_concat": [c_concat]}
|
||||
else:
|
||||
make_condition_dict = lambda c_crossattn, c_concat: {"c_crossattn": [c_crossattn], "c_concat": [c_concat]}
|
||||
|
||||
if not is_edit_model:
|
||||
x_in = torch.cat([torch.stack([x[i] for _ in range(n)]) for i, n in enumerate(repeats)] + [x])
|
||||
sigma_in = torch.cat([torch.stack([sigma[i] for _ in range(n)]) for i, n in enumerate(repeats)] + [sigma])
|
||||
image_cond_in = torch.cat([torch.stack([image_cond[i] for _ in range(n)]) for i, n in enumerate(repeats)] + [image_uncond])
|
||||
else:
|
||||
x_in = torch.cat([torch.stack([x[i] for _ in range(n)]) for i, n in enumerate(repeats)] + [x] + [x])
|
||||
sigma_in = torch.cat([torch.stack([sigma[i] for _ in range(n)]) for i, n in enumerate(repeats)] + [sigma] + [sigma])
|
||||
image_cond_in = torch.cat([torch.stack([image_cond[i] for _ in range(n)]) for i, n in enumerate(repeats)] + [image_uncond] + [torch.zeros_like(self.init_latent)])
|
||||
|
||||
denoiser_params = CFGDenoiserParams(x_in, image_cond_in, sigma_in, state.sampling_step, state.sampling_steps, tensor, uncond, self)
|
||||
denoiser_params = CFGDenoiserParams(x, image_cond, sigma, state.sampling_step, state.sampling_steps, cond, uncond, self)
|
||||
cfg_denoiser_callback(denoiser_params)
|
||||
x_in = denoiser_params.x
|
||||
image_cond_in = denoiser_params.image_cond
|
||||
sigma_in = denoiser_params.sigma
|
||||
tensor = denoiser_params.text_cond
|
||||
uncond = denoiser_params.text_uncond
|
||||
skip_uncond = False
|
||||
|
||||
# alternating uncond allows for higher thresholds without the quality loss normally expected from raising it
|
||||
if self.step % 2 and s_min_uncond > 0 and sigma[0] < s_min_uncond and not is_edit_model:
|
||||
skip_uncond = True
|
||||
x_in = x_in[:-batch_size]
|
||||
sigma_in = sigma_in[:-batch_size]
|
||||
denoised = forge_sampler.forge_sample(self, denoiser_params=denoiser_params,
|
||||
cond_scale=cond_scale, cond_composition=cond_composition)
|
||||
|
||||
self.padded_cond_uncond = False
|
||||
self.padded_cond_uncond_v0 = False
|
||||
if shared.opts.pad_cond_uncond and tensor.shape[1] != uncond.shape[1]:
|
||||
tensor, uncond = self.pad_cond_uncond(tensor, uncond)
|
||||
elif shared.opts.pad_cond_uncond_v0 and tensor.shape[1] != uncond.shape[1]:
|
||||
tensor, uncond = self.pad_cond_uncond_v0(tensor, uncond)
|
||||
|
||||
if tensor.shape[1] == uncond.shape[1] or skip_uncond:
|
||||
if is_edit_model:
|
||||
cond_in = catenate_conds([tensor, uncond, uncond])
|
||||
elif skip_uncond:
|
||||
cond_in = tensor
|
||||
else:
|
||||
cond_in = catenate_conds([tensor, uncond])
|
||||
|
||||
if shared.opts.batch_cond_uncond:
|
||||
x_out = self.inner_model(x_in, sigma_in, cond=make_condition_dict(cond_in, image_cond_in))
|
||||
else:
|
||||
x_out = torch.zeros_like(x_in)
|
||||
for batch_offset in range(0, x_out.shape[0], batch_size):
|
||||
a = batch_offset
|
||||
b = a + batch_size
|
||||
x_out[a:b] = self.inner_model(x_in[a:b], sigma_in[a:b], cond=make_condition_dict(subscript_cond(cond_in, a, b), image_cond_in[a:b]))
|
||||
else:
|
||||
x_out = torch.zeros_like(x_in)
|
||||
batch_size = batch_size*2 if shared.opts.batch_cond_uncond else batch_size
|
||||
for batch_offset in range(0, tensor.shape[0], batch_size):
|
||||
a = batch_offset
|
||||
b = min(a + batch_size, tensor.shape[0])
|
||||
|
||||
if not is_edit_model:
|
||||
c_crossattn = subscript_cond(tensor, a, b)
|
||||
else:
|
||||
c_crossattn = torch.cat([tensor[a:b]], uncond)
|
||||
|
||||
x_out[a:b] = self.inner_model(x_in[a:b], sigma_in[a:b], cond=make_condition_dict(c_crossattn, image_cond_in[a:b]))
|
||||
|
||||
if not skip_uncond:
|
||||
x_out[-uncond.shape[0]:] = self.inner_model(x_in[-uncond.shape[0]:], sigma_in[-uncond.shape[0]:], cond=make_condition_dict(uncond, image_cond_in[-uncond.shape[0]:]))
|
||||
|
||||
denoised_image_indexes = [x[0][0] for x in conds_list]
|
||||
if skip_uncond:
|
||||
fake_uncond = torch.cat([x_out[i:i+1] for i in denoised_image_indexes])
|
||||
x_out = torch.cat([x_out, fake_uncond]) # we skipped uncond denoising, so we put cond-denoised image to where the uncond-denoised image should be
|
||||
|
||||
denoised_params = CFGDenoisedParams(x_out, state.sampling_step, state.sampling_steps, self.inner_model)
|
||||
cfg_denoised_callback(denoised_params)
|
||||
|
||||
devices.test_for_nans(x_out, "unet")
|
||||
|
||||
if is_edit_model:
|
||||
denoised = self.combine_denoised_for_edit_model(x_out, cond_scale)
|
||||
elif skip_uncond:
|
||||
denoised = self.combine_denoised(x_out, conds_list, uncond, 1.0)
|
||||
else:
|
||||
denoised = self.combine_denoised(x_out, conds_list, uncond, cond_scale)
|
||||
|
||||
# Blend in the original latents (after)
|
||||
if not self.mask_before_denoising and self.mask is not None:
|
||||
denoised = apply_blend(denoised)
|
||||
|
||||
self.sampler.last_latent = self.get_pred_x0(torch.cat([x_in[i:i + 1] for i in denoised_image_indexes]), torch.cat([x_out[i:i + 1] for i in denoised_image_indexes]), sigma)
|
||||
|
||||
if opts.live_preview_content == "Prompt":
|
||||
preview = self.sampler.last_latent
|
||||
elif opts.live_preview_content == "Negative prompt":
|
||||
preview = self.get_pred_x0(x_in[-uncond.shape[0]:], x_out[-uncond.shape[0]:], sigma)
|
||||
else:
|
||||
preview = self.get_pred_x0(torch.cat([x_in[i:i+1] for i in denoised_image_indexes]), torch.cat([denoised[i:i+1] for i in denoised_image_indexes]), sigma)
|
||||
if self.mask is not None:
|
||||
denoised = denoised * self.nmask + self.init_latent * self.mask
|
||||
|
||||
preview = self.sampler.last_latent = denoised
|
||||
sd_samplers_common.store_latent(preview)
|
||||
|
||||
after_cfg_callback_params = AfterCFGCallbackParams(denoised, state.sampling_step, state.sampling_steps)
|
||||
@@ -295,5 +193,10 @@ class CFGDenoiser(torch.nn.Module):
|
||||
denoised = after_cfg_callback_params.x
|
||||
|
||||
self.step += 1
|
||||
return denoised
|
||||
|
||||
if self.classic_ddim_eps_estimation:
|
||||
eps = (x - denoised) / sigma[:, None, None, None]
|
||||
return eps
|
||||
|
||||
return denoised.to(device=original_x_device, dtype=original_x_dtype)
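A hedged sketch of the epsilon/denoised relationship used by the classic DDIM path above: in the k-diffusion convention x = x0 + sigma * eps, so eps is recovered from the model's denoised prediction as (x - denoised) / sigma:

import torch

sigma = torch.tensor(2.0)
x0 = torch.randn(1, 4, 8, 8)    # clean latent
eps = torch.randn_like(x0)      # noise
x = x0 + sigma * eps            # noised sample
denoised = x0                   # a perfect denoiser would return x0
recovered_eps = (x - denoised) / sigma
assert torch.allclose(recovered_eps, eps)
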
|
||||
|
||||
|
||||
@@ -5,6 +5,8 @@ import torch
|
||||
from PIL import Image
|
||||
from modules import devices, images, sd_vae_approx, sd_samplers, sd_vae_taesd, shared, sd_models
|
||||
from modules.shared import opts, state
|
||||
from modules_forge.forge_sampler import sampling_prepare, sampling_cleanup
|
||||
from modules import extra_networks
|
||||
import k_diffusion.sampling
|
||||
|
||||
|
||||
@@ -39,9 +41,7 @@ def samples_to_images_tensor(sample, approximation=None, model=None):
|
||||
|
||||
if approximation is None or (shared.state.interrupted and opts.live_preview_fast_interrupt):
|
||||
approximation = approximation_indexes.get(opts.show_progress_type, 0)
|
||||
|
||||
from modules import lowvram
|
||||
if approximation == 0 and lowvram.is_enabled(shared.sd_model) and not shared.opts.live_preview_allow_lowvram_full:
|
||||
if approximation == 0:
|
||||
approximation = 1
|
||||
|
||||
if approximation == 2:
|
||||
@@ -54,8 +54,7 @@ def samples_to_images_tensor(sample, approximation=None, model=None):
|
||||
else:
|
||||
if model is None:
|
||||
model = shared.sd_model
|
||||
with devices.without_autocast(): # fixes an issue with unstable VAEs that are flaky even in fp32
|
||||
x_sample = model.decode_first_stage(sample.to(model.first_stage_model.dtype))
|
||||
x_sample = model.decode_first_stage(sample)
|
||||
|
||||
return x_sample
|
||||
|
||||
@@ -71,7 +70,6 @@ def single_sample_to_image(sample, approximation=None):
|
||||
|
||||
|
||||
def decode_first_stage(model, x):
|
||||
x = x.to(devices.dtype_vae)
|
||||
approx_index = approximation_indexes.get(opts.sd_vae_decode_method, 0)
|
||||
return samples_to_images_tensor(x, approx_index, model)
|
||||
|
||||
@@ -95,7 +93,6 @@ def images_tensor_to_samples(image, approximation=None, model=None):
|
||||
else:
|
||||
if model is None:
|
||||
model = shared.sd_model
|
||||
model.first_stage_model.to(devices.dtype_vae)
|
||||
|
||||
image = image.to(shared.device, dtype=devices.dtype_vae)
|
||||
image = image * 2 - 1
|
||||
@@ -155,7 +152,7 @@ def replace_torchsde_browinan():
|
||||
replace_torchsde_browinan()
|
||||
|
||||
|
||||
def apply_refiner(cfg_denoiser):
|
||||
def apply_refiner(cfg_denoiser, x):
|
||||
completed_ratio = cfg_denoiser.step / cfg_denoiser.total_steps
|
||||
refiner_switch_at = cfg_denoiser.p.refiner_switch_at
|
||||
refiner_checkpoint_info = cfg_denoiser.p.refiner_checkpoint_info
|
||||
@@ -181,13 +178,18 @@ def apply_refiner(cfg_denoiser):
|
||||
cfg_denoiser.p.extra_generation_params['Refiner'] = refiner_checkpoint_info.short_title
|
||||
cfg_denoiser.p.extra_generation_params['Refiner switch at'] = refiner_switch_at
|
||||
|
||||
sampling_cleanup(sd_models.model_data.get_sd_model().forge_objects.unet)
|
||||
|
||||
with sd_models.SkipWritingToConfig():
|
||||
sd_models.reload_model_weights(info=refiner_checkpoint_info)
|
||||
|
||||
devices.torch_gc()
|
||||
if not cfg_denoiser.p.disable_extra_networks:
|
||||
extra_networks.activate(cfg_denoiser.p, cfg_denoiser.p.extra_network_data)
|
||||
|
||||
cfg_denoiser.p.setup_conds()
|
||||
cfg_denoiser.update_inner_model()
|
||||
|
||||
sampling_prepare(sd_models.model_data.get_sd_model().forge_objects.unet, x=x)
|
||||
return True
|
||||
|
||||
|
||||
|
||||
@@ -7,6 +7,8 @@ from modules.script_callbacks import ExtraNoiseParams, extra_noise_callback
|
||||
|
||||
from modules.shared import opts
|
||||
import modules.shared as shared
|
||||
from modules_forge.forge_sampler import sampling_prepare, sampling_cleanup
|
||||
|
||||
|
||||
samplers_k_diffusion = [
|
||||
('DPM++ 2M Karras', 'sample_dpmpp_2m', ['k_dpmpp_2m_ka'], {'scheduler': 'karras'}),
|
||||
@@ -139,11 +141,18 @@ class KDiffusionSampler(sd_samplers_common.Sampler):
|
||||
return sigmas
|
||||
|
||||
def sample_img2img(self, p, x, noise, conditioning, unconditional_conditioning, steps=None, image_conditioning=None):
|
||||
unet_patcher = self.model_wrap.inner_model.forge_objects.unet
|
||||
sampling_prepare(self.model_wrap.inner_model.forge_objects.unet, x=x)
|
||||
|
||||
self.model_wrap.log_sigmas = self.model_wrap.log_sigmas.to(x.device)
|
||||
self.model_wrap.sigmas = self.model_wrap.sigmas.to(x.device)
|
||||
|
||||
steps, t_enc = sd_samplers_common.setup_img2img_steps(p, steps)
|
||||
|
||||
sigmas = self.get_sigmas(p, steps)
|
||||
sigmas = self.get_sigmas(p, steps).to(x.device)
|
||||
sigma_sched = sigmas[steps - t_enc - 1:]
|
||||
|
||||
x = x.to(noise)
|
||||
xi = x + noise * sigma_sched[0]
|
||||
|
||||
if opts.img2img_extra_noise > 0:
|
||||
@@ -189,12 +198,20 @@ class KDiffusionSampler(sd_samplers_common.Sampler):
|
||||
|
||||
self.add_infotext(p)
|
||||
|
||||
sampling_cleanup(unet_patcher)
|
||||
|
||||
return samples
|
||||
|
||||
def sample(self, p, x, conditioning, unconditional_conditioning, steps=None, image_conditioning=None):
|
||||
unet_patcher = self.model_wrap.inner_model.forge_objects.unet
|
||||
sampling_prepare(self.model_wrap.inner_model.forge_objects.unet, x=x)
|
||||
|
||||
self.model_wrap.log_sigmas = self.model_wrap.log_sigmas.to(x.device)
|
||||
self.model_wrap.sigmas = self.model_wrap.sigmas.to(x.device)
|
||||
|
||||
steps = steps or p.steps
|
||||
|
||||
sigmas = self.get_sigmas(p, steps)
|
||||
sigmas = self.get_sigmas(p, steps).to(x.device)
|
||||
|
||||
if opts.sgm_noise_multiplier:
|
||||
p.extra_generation_params["SGM noise multiplier"] = True
|
||||
@@ -235,6 +252,8 @@ class KDiffusionSampler(sd_samplers_common.Sampler):
|
||||
|
||||
self.add_infotext(p)
|
||||
|
||||
sampling_cleanup(unet_patcher)
|
||||
|
||||
return samples
|
||||
|
||||
|
||||
|
||||
@@ -27,14 +27,14 @@ class LCMCompVisDenoiser(DiscreteEpsDDPMDenoiser):
|
||||
start = self.sigma_to_t(self.sigma_max)
|
||||
end = self.sigma_to_t(self.sigma_min)
|
||||
|
||||
t = torch.linspace(start, end, n, device=shared.sd_model.device)
|
||||
t = torch.linspace(start, end, n, device=self.sigmas.device)
|
||||
|
||||
return sampling.append_zero(self.t_to_sigma(t))
|
||||
|
||||
|
||||
def sigma_to_t(self, sigma, quantize=None):
|
||||
log_sigma = sigma.log()
|
||||
dists = log_sigma - self.log_sigmas[:, None]
|
||||
dists = log_sigma - self.log_sigmas.to(sigma)[:, None]
|
||||
return dists.abs().argmin(dim=0).view(sigma.shape) * self.skip_steps + (self.skip_steps - 1)
|
||||
|
||||
|
||||
|
||||
@@ -7,6 +7,8 @@ from modules.script_callbacks import ExtraNoiseParams, extra_noise_callback
|
||||
|
||||
from modules.shared import opts
|
||||
import modules.shared as shared
|
||||
from modules_forge.forge_sampler import sampling_prepare, sampling_cleanup
|
||||
|
||||
|
||||
samplers_timesteps = [
|
||||
('DDIM', sd_samplers_timesteps_impl.ddim, ['ddim'], {}),
|
||||
@@ -50,10 +52,11 @@ class CFGDenoiserTimesteps(CFGDenoiser):
|
||||
super().__init__(sampler)
|
||||
|
||||
self.alphas = shared.sd_model.alphas_cumprod
|
||||
self.mask_before_denoising = True
|
||||
self.classic_ddim_eps_estimation = True
|
||||
|
||||
def get_pred_x0(self, x_in, x_out, sigma):
|
||||
ts = sigma.to(dtype=int)
|
||||
self.alphas = self.alphas.to(ts.device)
|
||||
|
||||
a_t = self.alphas[ts][:, None, None, None]
|
||||
sqrt_one_minus_at = (1 - a_t).sqrt()
|
||||
@@ -95,16 +98,21 @@ class CompVisSampler(sd_samplers_common.Sampler):
|
||||
return timesteps
|
||||
|
||||
def sample_img2img(self, p, x, noise, conditioning, unconditional_conditioning, steps=None, image_conditioning=None):
|
||||
unet_patcher = self.model_wrap.inner_model.forge_objects.unet
|
||||
sampling_prepare(self.model_wrap.inner_model.forge_objects.unet, x=x)
|
||||
|
||||
self.model_wrap.inner_model.alphas_cumprod = self.model_wrap.inner_model.alphas_cumprod.to(x.device)
|
||||
|
||||
steps, t_enc = sd_samplers_common.setup_img2img_steps(p, steps)
|
||||
|
||||
timesteps = self.get_timesteps(p, steps)
|
||||
timesteps = self.get_timesteps(p, steps).to(x.device)
|
||||
timesteps_sched = timesteps[:t_enc]
|
||||
|
||||
alphas_cumprod = shared.sd_model.alphas_cumprod
|
||||
sqrt_alpha_cumprod = torch.sqrt(alphas_cumprod[timesteps[t_enc]])
|
||||
sqrt_one_minus_alpha_cumprod = torch.sqrt(1 - alphas_cumprod[timesteps[t_enc]])
|
||||
|
||||
xi = x * sqrt_alpha_cumprod + noise * sqrt_one_minus_alpha_cumprod
|
||||
xi = x.to(noise) * sqrt_alpha_cumprod + noise * sqrt_one_minus_alpha_cumprod
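A hedged sketch of the DDPM forward-noising formula applied above, with illustrative values: x_t = sqrt(alphas_cumprod[t]) * x_0 + sqrt(1 - alphas_cumprod[t]) * noise:

import torch

alphas_cumprod = torch.linspace(0.999, 0.01, 1000)  # illustrative schedule
t = 500
x0 = torch.randn(1, 4, 64, 64)
noise = torch.randn_like(x0)
xt = alphas_cumprod[t].sqrt() * x0 + (1.0 - alphas_cumprod[t]).sqrt() * noise
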
|
||||
|
||||
if opts.img2img_extra_noise > 0:
|
||||
p.extra_generation_params["Extra noise"] = opts.img2img_extra_noise
|
||||
@@ -135,11 +143,18 @@ class CompVisSampler(sd_samplers_common.Sampler):
|
||||
|
||||
self.add_infotext(p)
|
||||
|
||||
sampling_cleanup(unet_patcher)
|
||||
|
||||
return samples
|
||||
|
||||
def sample(self, p, x, conditioning, unconditional_conditioning, steps=None, image_conditioning=None):
|
||||
unet_patcher = self.model_wrap.inner_model.forge_objects.unet
|
||||
sampling_prepare(self.model_wrap.inner_model.forge_objects.unet, x=x)
|
||||
|
||||
self.model_wrap.inner_model.alphas_cumprod = self.model_wrap.inner_model.alphas_cumprod.to(x.device)
|
||||
|
||||
steps = steps or p.steps
|
||||
timesteps = self.get_timesteps(p, steps)
|
||||
timesteps = self.get_timesteps(p, steps).to(x.device)
|
||||
|
||||
extra_params_kwargs = self.initialize(p)
|
||||
parameters = inspect.signature(self.func).parameters
|
||||
@@ -159,6 +174,8 @@ class CompVisSampler(sd_samplers_common.Sampler):
|
||||
|
||||
self.add_infotext(p)
|
||||
|
||||
sampling_cleanup(unet_patcher)
|
||||
|
||||
return samples
|
||||
|
||||
|
||||
|
||||
@@ -45,15 +45,8 @@ def apply_unet(option=None):
|
||||
current_unet_option = new_option
|
||||
if current_unet_option is None:
|
||||
current_unet = None
|
||||
|
||||
if not shared.sd_model.lowvram:
|
||||
shared.sd_model.model.diffusion_model.to(devices.device)
|
||||
|
||||
return
|
||||
|
||||
shared.sd_model.model.diffusion_model.to(devices.cpu)
|
||||
devices.torch_gc()
|
||||
|
||||
current_unet = current_unet_option.create_unet()
|
||||
current_unet.option = current_unet_option
|
||||
print(f"Activating unet: {current_unet.option.label}")
|
||||
|
||||
@@ -237,7 +237,6 @@ def load_vae(model, vae_file=None, vae_source="from unknown source"):
|
||||
# don't call this from outside
|
||||
def _load_vae_dict(model, vae_dict_1):
|
||||
model.first_stage_model.load_state_dict(vae_dict_1)
|
||||
model.first_stage_model.to(devices.dtype_vae)
|
||||
|
||||
|
||||
def clear_loaded_vae():
|
||||
@@ -263,20 +262,12 @@ def reload_vae_weights(sd_model=None, vae_file=unspecified):
|
||||
if loaded_vae_file == vae_file:
|
||||
return
|
||||
|
||||
if sd_model.lowvram:
|
||||
lowvram.send_everything_to_cpu()
|
||||
else:
|
||||
sd_model.to(devices.cpu)
|
||||
|
||||
sd_hijack.model_hijack.undo_hijack(sd_model)
|
||||
|
||||
load_vae(sd_model, vae_file, vae_source)
|
||||
|
||||
sd_hijack.model_hijack.hijack(sd_model)
|
||||
|
||||
if not sd_model.lowvram:
|
||||
sd_model.to(devices.device)
|
||||
|
||||
script_callbacks.model_loaded_callback(sd_model)
|
||||
|
||||
print("VAE weights loaded.")
|
||||
|
||||
@@ -24,13 +24,6 @@ def initialize():
pass

from modules import devices
devices.device, devices.device_interrogate, devices.device_gfpgan, devices.device_esrgan, devices.device_codeformer = \
(devices.cpu if any(y in cmd_opts.use_cpu for y in [x, 'all']) else devices.get_optimal_device() for x in ['sd', 'interrogate', 'gfpgan', 'esrgan', 'codeformer'])

devices.dtype = torch.float32 if cmd_opts.no_half else torch.float16
devices.dtype_vae = torch.float32 if cmd_opts.no_half or cmd_opts.no_half_vae else torch.float16
devices.dtype_inference = torch.float32 if cmd_opts.precision == 'full' else devices.dtype

shared.device = devices.device
shared.weight_load_location = None if cmd_opts.lowram else "cpu"

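The one-line generator expression above is dense: each component gets the CPU if it (or 'all') was listed in --use-cpu, otherwise the best available accelerator. An equivalent, spelled-out illustration (the component names here are only examples):

import torch

def pick_devices(use_cpu: list, components: list) -> dict:
    default = torch.device("cuda") if torch.cuda.is_available() else torch.device("cpu")
    # A component is pinned to the CPU if it, or the literal 'all', appears in --use-cpu.
    return {
        name: torch.device("cpu") if any(y in use_cpu for y in [name, "all"]) else default
        for name in components
    }

# pick_devices(["interrogate"], ["sd", "interrogate", "esrgan"])
# -> {"sd": default, "interrogate": cpu, "esrgan": default}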
@@ -300,7 +300,7 @@ options_templates.update(options_section(('ui_alternatives', "UI alternatives",

options_templates.update(options_section(('ui', "User interface", "ui"), {
"localization": OptionInfo("None", "Localization", gr.Dropdown, lambda: {"choices": ["None"] + list(localization.localizations.keys())}, refresh=lambda: localization.list_localizations(cmd_opts.localizations_dir)).needs_reload_ui(),
"quicksettings_list": OptionInfo(["sd_model_checkpoint"], "Quicksettings list", ui_components.DropdownMulti, lambda: {"choices": list(shared.opts.data_labels.keys())}).js("info", "settingsHintsShowQuicksettings").info("setting entries that appear at the top of page rather than in settings tab").needs_reload_ui(),
"quicksettings_list": OptionInfo(["sd_model_checkpoint", "sd_vae", "CLIP_stop_at_last_layers"], "Quicksettings list", ui_components.DropdownMulti, lambda: {"choices": list(shared.opts.data_labels.keys())}).js("info", "settingsHintsShowQuicksettings").info("setting entries that appear at the top of page rather than in settings tab").needs_reload_ui(),
"ui_tab_order": OptionInfo([], "UI tab order", ui_components.DropdownMulti, lambda: {"choices": list(shared.tab_names)}).needs_reload_ui(),
"hidden_tabs": OptionInfo([], "Hidden UI tabs", ui_components.DropdownMulti, lambda: {"choices": list(shared.tab_names)}).needs_reload_ui(),
"ui_reorder_list": OptionInfo([], "UI item order for txt2img/img2img tabs", ui_components.DropdownMulti, lambda: {"choices": list(shared_items.ui_reorder_categories())}).info("selected items appear first").needs_reload_ui(),

@@ -2,6 +2,8 @@ import datetime
import logging
import threading
import time
import traceback
import torch

from modules import errors, shared, devices
from typing import Optional
@@ -134,6 +136,7 @@ class State:

devices.torch_gc()

@torch.inference_mode()
def set_current_image(self):
"""if enough sampling steps have been made after the last call to this, sets self.current_image from self.current_latent, and modifies self.id_live_preview accordingly"""
if not shared.parallel_processing_allowed:
@@ -142,6 +145,7 @@ class State:
if self.sampling_step - self.current_image_sampling_step >= shared.opts.show_progress_every_n_steps and shared.opts.live_previews_enable and shared.opts.show_progress_every_n_steps != -1:
self.do_set_current_image()

@torch.inference_mode()
def do_set_current_image(self):
if self.current_latent is None:
return
@@ -156,11 +160,14 @@ class State:

self.current_image_sampling_step = self.sampling_step

except Exception:
except Exception as e:
# traceback.print_exc()
# print(e)
# when switching models during generation, VAE would be on CPU, so creating an image will fail.
# we silently ignore this error
errors.record_exception()

@torch.inference_mode()
def assign_current_image(self, image):
self.current_image = image
self.id_live_preview += 1

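The new @torch.inference_mode() decorators ensure the live-preview decode path never records autograd state, which saves memory and avoids accidental graph retention mid-generation. A toy example of the effect, unrelated to the State class itself:

import torch

@torch.inference_mode()
def decode_preview(latent: torch.Tensor) -> torch.Tensor:
    # Tensors created here are inference tensors: no grad history is tracked.
    return (latent.clamp(-1, 1) + 1) / 2

preview = decode_preview(torch.randn(1, 3, 64, 64))
assert not preview.requires_grad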
@@ -9,6 +9,7 @@ import modules.shared as shared
from modules.ui import plaintext_to_html
from PIL import Image
import gradio as gr
from modules_forge import main_thread


def txt2img_create_processing(id_task: str, request: gr.Request, prompt: str, negative_prompt: str, prompt_styles, steps: int, sampler_name: str, n_iter: int, batch_size: int, cfg_scale: float, height: int, width: int, enable_hr: bool, denoising_strength: float, hr_scale: float, hr_upscaler: str, hr_second_pass_steps: int, hr_resize_x: int, hr_resize_y: int, hr_checkpoint_name: str, hr_sampler_name: str, hr_prompt: str, hr_negative_prompt, override_settings_texts, *args, force_enable_hr=False):
@@ -56,7 +57,7 @@ def txt2img_create_processing(id_task: str, request: gr.Request, prompt: str, ne
return p


def txt2img_upscale(id_task: str, request: gr.Request, gallery, gallery_index, generation_info, *args):
def txt2img_upscale_function(id_task: str, request: gr.Request, gallery, gallery_index, generation_info, *args):
assert len(gallery) > 0, 'No image to upscale'
assert 0 <= gallery_index < len(gallery), f'Bad image index: {gallery_index}'

@@ -100,7 +101,7 @@ def txt2img_upscale(id_task: str, request: gr.Request, gallery, gallery_index, g
return new_gallery, json.dumps(geninfo), plaintext_to_html(processed.info), plaintext_to_html(processed.comments, classname="comments")


def txt2img(id_task: str, request: gr.Request, *args):
def txt2img_function(id_task: str, request: gr.Request, *args):
p = txt2img_create_processing(id_task, request, *args)

with closing(p):
@@ -118,4 +119,12 @@ def txt2img(id_task: str, request: gr.Request, *args):
if opts.do_not_show_images:
processed.images = []

return processed.images, generation_info_js, plaintext_to_html(processed.info), plaintext_to_html(processed.comments, classname="comments")
return processed.images + processed.extra_images, generation_info_js, plaintext_to_html(processed.info), plaintext_to_html(processed.comments, classname="comments")


def txt2img_upscale(id_task: str, request: gr.Request, gallery, gallery_index, generation_info, *args):
return main_thread.run_and_wait_result(txt2img_upscale_function, id_task, request, gallery, gallery_index, generation_info, *args)


def txt2img(id_task: str, request: gr.Request, *args):
return main_thread.run_and_wait_result(txt2img_function, id_task, request, *args)

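txt2img and txt2img_upscale are now thin wrappers that hand the real work to main_thread.run_and_wait_result, so the heavy generation code always runs on one designated thread while the Gradio worker blocks for the result. The repository's implementation is not shown in this diff; a minimal sketch of the general pattern (all names hypothetical) could look like:

import queue
import threading

_tasks = queue.Queue()

def run_and_wait_result(func, *args, **kwargs):
    done = threading.Event()
    box = {}
    _tasks.put((func, args, kwargs, box, done))
    done.wait()                      # calling thread blocks until the main loop finishes the task
    if "error" in box:
        raise box["error"]
    return box["result"]

def main_loop():
    while True:
        func, args, kwargs, box, done = _tasks.get()
        try:
            box["result"] = func(*args, **kwargs)
        except Exception as exc:     # surface exceptions back to the caller
            box["error"] = exc
        done.set()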
@@ -178,9 +178,15 @@ def update_token_counter(text, steps, styles, *, is_positive=True):
# messages related to it in console
prompt_schedules = [[[steps, text]]]

try:
cond_stage_model = sd_models.model_data.sd_model.cond_stage_model
assert cond_stage_model is not None
except Exception:
return f"<span class='gr-box gr-text-input'>?/?</span>"

flat_prompts = reduce(lambda list1, list2: list1+list2, prompt_schedules)
prompts = [prompt_text for step, prompt_text in flat_prompts]
token_count, max_length = max([model_hijack.get_prompt_lengths(prompt) for prompt in prompts], key=lambda args: args[0])
token_count, max_length = max([model_hijack.get_prompt_lengths(prompt, cond_stage_model) for prompt in prompts], key=lambda args: args[0])
return f"<span class='gr-box gr-text-input'>{token_count}/{max_length}</span>"

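update_token_counter now resolves the active text encoder up front and passes it into get_prompt_lengths instead of letting the hijack reach for a global model. As a rough, standalone illustration of what counting prompt tokens against a CLIP tokenizer involves (using Hugging Face transformers, not the webui's hijack classes; the 75-token chunking only approximates what the UI reports):

import math
from transformers import CLIPTokenizer

tokenizer = CLIPTokenizer.from_pretrained("openai/clip-vit-large-patch14")

def count_tokens(prompt: str, chunk_size: int = 75):
    # Count tokens without BOS/EOS, then round the reported capacity up to a full chunk.
    ids = tokenizer(prompt, add_special_tokens=False)["input_ids"]
    max_length = math.ceil(max(len(ids), 1) / chunk_size) * chunk_size
    return len(ids), max_length

# e.g. count_tokens("a photo of a cat") returns the raw token count and 75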
@@ -194,7 +194,7 @@ Requested path was: {f}

with gr.Column(variant='panel', elem_id=f"{tabname}_results_panel"):
with gr.Group(elem_id=f"{tabname}_gallery_container"):
res.gallery = gr.Gallery(label='Output', show_label=False, elem_id=f"{tabname}_gallery", columns=4, preview=True, height=shared.opts.gallery_height or None)
res.gallery = gr.Gallery(label='Output', show_label=False, elem_id=f"{tabname}_gallery", columns=4, preview=True, height=shared.opts.gallery_height or None, object_fit='contain')

with gr.Row(elem_id=f"image_buttons_{tabname}", elem_classes="image-buttons"):
open_folder_button = ToolButton(folder_symbol, elem_id=f'{tabname}_open_folder', visible=not shared.cmd_opts.hide_ui_dir_config, tooltip="Open images output directory.")
@@ -206,7 +206,8 @@ Requested path was: {f}
buttons = {
'img2img': ToolButton('🖼️', elem_id=f'{tabname}_send_to_img2img', tooltip="Send image and generation parameters to img2img tab."),
'inpaint': ToolButton('🎨️', elem_id=f'{tabname}_send_to_inpaint', tooltip="Send image and generation parameters to img2img inpaint tab."),
'extras': ToolButton('📐', elem_id=f'{tabname}_send_to_extras', tooltip="Send image and generation parameters to extras tab.")
'extras': ToolButton('📐', elem_id=f'{tabname}_send_to_extras', tooltip="Send image and generation parameters to extras tab."),
'svd': ToolButton('🎬', elem_id=f'{tabname}_send_to_svd', tooltip="Send image and generation parameters to SVD tab."),
}

if tabname == 'txt2img':

@@ -294,7 +294,6 @@ class UiSettings:

for _i, k, _item in self.quicksettings_list:
component = self.component_dict[k]
info = opts.data_labels[k]

if isinstance(component, gr.Textbox):
methods = [component.submit, component.blur]
@@ -304,20 +303,30 @@ class UiSettings:
methods = [component.change]

for method in methods:
method(
handler = method(
fn=lambda value, k=k: self.run_settings_single(value, key=k),
inputs=[component],
outputs=[component, self.text_settings],
show_progress=info.refresh is not None,
show_progress=False,
)
script_callbacks.setting_updated_event_subscriber_chain(
handler=handler,
component=component,
setting_name=k,
)

button_set_checkpoint = gr.Button('Change checkpoint', elem_id='change_checkpoint', visible=False)
button_set_checkpoint.click(
handler = button_set_checkpoint.click(
fn=lambda value, _: self.run_settings_single(value, key='sd_model_checkpoint'),
_js="function(v){ var res = desiredCheckpointName; desiredCheckpointName = ''; return [res || v, null]; }",
inputs=[self.component_dict['sd_model_checkpoint'], self.dummy_component],
outputs=[self.component_dict['sd_model_checkpoint'], self.text_settings],
)
script_callbacks.setting_updated_event_subscriber_chain(
handler=handler,
component=self.component_dict['sd_model_checkpoint'],
setting_name="sd_model_checkpoint"
)

component_keys = [k for k in opts.data_labels.keys() if k in self.component_dict]

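The `handler = method(...)` and `handler = button_set_checkpoint.click(...)` changes keep the Gradio event dependency object so further callbacks can be attached to it afterwards (here via the forge-specific setting_updated_event_subscriber_chain). In plain Gradio the same chaining is done with .then(); a toy example of the idea, independent of the webui's subscriber mechanism:

import gradio as gr

def save_setting(value):
    return f"saved: {value}"

def after_save(status):
    print(status)  # e.g. notify other components that the setting changed

with gr.Blocks() as demo:
    setting_box = gr.Textbox(label="setting")
    status_box = gr.Textbox(label="status")
    handler = setting_box.change(fn=save_setting, inputs=setting_box, outputs=status_box)
    handler.then(fn=after_save, inputs=status_box)  # runs after the first callback completes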
@@ -6,6 +6,18 @@ from PIL import Image

import modules.shared
from modules import modelloader, shared
from ldm_patched.modules import model_management


def prepare_free_memory(aggressive=False):
if aggressive:
model_management.unload_all_models()
print('Upscale script freed all memory.')
return

model_management.free_memory(memory_required=1024*1024*3, device=model_management.get_torch_device())
print('Upscale script freed memory successfully.')


LANCZOS = (Image.Resampling.LANCZOS if hasattr(Image, 'Resampling') else Image.LANCZOS)
NEAREST = (Image.Resampling.NEAREST if hasattr(Image, 'Resampling') else Image.NEAREST)
