Merge branch 'main' into upt

This commit is contained in:
lllyasviel
2024-02-11 16:42:36 -08:00
committed by GitHub
1005 changed files with 154843 additions and 3044 deletions

View File

@@ -2,11 +2,13 @@ import base64
import io
import os
import time
import itertools
import datetime
import uvicorn
import ipaddress
import requests
import gradio as gr
import numpy as np
from threading import Lock
from io import BytesIO
from fastapi import APIRouter, Depends, FastAPI, Request, Response
@@ -103,6 +105,8 @@ def encode_pil_to_base64(image):
with io.BytesIO() as output_bytes:
if isinstance(image, str):
return image
if isinstance(image, np.ndarray):
image = Image.fromarray(image)
if opts.samples_format.lower() == 'png':
use_metadata = False
metadata = PngImagePlugin.PngInfo()
@@ -480,7 +484,11 @@ class Api:
shared.state.end()
shared.total_tqdm.clear()
b64images = list(map(encode_pil_to_base64, processed.images)) if send_images else []
b64images = [
encode_pil_to_base64(image)
for image in itertools.chain(processed.images, processed.extra_images)
if send_images
]
return models.TextToImageResponse(images=b64images, parameters=vars(txt2imgreq), info=processed.js())
@@ -547,7 +555,11 @@ class Api:
shared.state.end()
shared.total_tqdm.clear()
b64images = list(map(encode_pil_to_base64, processed.images)) if send_images else []
b64images = [
encode_pil_to_base64(image)
for image in itertools.chain(processed.images, processed.extra_images)
if send_images
]
if not img2imgreq.include_init_images:
img2imgreq.init_images = None

View File

@@ -2,8 +2,9 @@ import argparse
import json
import os
from modules.paths_internal import models_path, script_path, data_path, extensions_dir, extensions_builtin_dir, sd_default_config, sd_model_file # noqa: F401
from ldm_patched.modules import args_parser
parser = argparse.ArgumentParser()
parser = args_parser.parser
parser.add_argument("-f", action='store_true', help=argparse.SUPPRESS) # allows running as root; implemented outside of webui
parser.add_argument("--update-all-extensions", action='store_true', help="launch.py argument: download updates for all extensions when starting the program")

View File

@@ -2,7 +2,7 @@ import os
from modules import modelloader, errors
from modules.shared import cmd_opts, opts
from modules.upscaler import Upscaler, UpscalerData
from modules.upscaler import Upscaler, UpscalerData, prepare_free_memory
from modules.upscaler_utils import upscale_with_model
@@ -23,6 +23,7 @@ class UpscalerDAT(Upscaler):
self.scalers.append(model)
def do_upscale(self, img, path):
prepare_free_memory()
try:
info = self.load_model(path)
except Exception:

View File

@@ -4,7 +4,10 @@ import re
import torch
import numpy as np
from modules import modelloader, paths, deepbooru_model, devices, images, shared
from modules import modelloader, paths, deepbooru_model, images, shared
from ldm_patched.modules import model_management
from ldm_patched.modules.model_patcher import ModelPatcher
re_special = re.compile(r'([\\()])')
@@ -12,6 +15,14 @@ re_special = re.compile(r'([\\()])')
class DeepDanbooru:
def __init__(self):
self.model = None
self.load_device = model_management.text_encoder_device()
self.offload_device = model_management.text_encoder_offload_device()
self.dtype = torch.float32
if model_management.should_use_fp16(device=self.load_device):
self.dtype = torch.float16
self.patcher = None
def load(self):
if self.model is not None:
@@ -28,16 +39,16 @@ class DeepDanbooru:
self.model.load_state_dict(torch.load(files[0], map_location="cpu"))
self.model.eval()
self.model.to(devices.cpu, devices.dtype)
self.model.to(self.offload_device, self.dtype)
self.patcher = ModelPatcher(self.model, load_device=self.load_device, offload_device=self.offload_device)
def start(self):
self.load()
self.model.to(devices.device)
model_management.load_models_gpu([self.patcher])
def stop(self):
if not shared.opts.interrogate_keep_models_in_memory:
self.model.to(devices.cpu)
devices.torch_gc()
pass
def tag(self, pil_image):
self.start()
@@ -56,8 +67,8 @@ class DeepDanbooru:
pic = images.resize_image(2, pil_image.convert("RGB"), 512, 512)
a = np.expand_dims(np.array(pic, dtype=np.float32), 0) / 255
with torch.no_grad(), devices.autocast():
x = torch.from_numpy(a).to(devices.device)
with torch.no_grad():
x = torch.from_numpy(a).to(self.load_device, self.dtype)
y = self.model(x)[0].detach().cpu().numpy()
probability_dict = {}

View File

@@ -1,228 +1,93 @@
import sys
import contextlib
from functools import lru_cache
import torch
from modules import errors, shared, npu_specific
if sys.platform == "darwin":
from modules import mac_specific
if shared.cmd_opts.use_ipex:
from modules import xpu_specific
import ldm_patched.modules.model_management as model_management
def has_xpu() -> bool:
return shared.cmd_opts.use_ipex and xpu_specific.has_xpu
return model_management.xpu_available
def has_mps() -> bool:
if sys.platform != "darwin":
return False
else:
return mac_specific.has_mps
return model_management.mps_mode()
def cuda_no_autocast(device_id=None) -> bool:
if device_id is None:
device_id = get_cuda_device_id()
return (
torch.cuda.get_device_capability(device_id) == (7, 5)
and torch.cuda.get_device_name(device_id).startswith("NVIDIA GeForce GTX 16")
)
return False
def get_cuda_device_id():
return (
int(shared.cmd_opts.device_id)
if shared.cmd_opts.device_id is not None and shared.cmd_opts.device_id.isdigit()
else 0
) or torch.cuda.current_device()
return model_management.get_torch_device().index
def get_cuda_device_string():
if shared.cmd_opts.device_id is not None:
return f"cuda:{shared.cmd_opts.device_id}"
return "cuda"
return str(model_management.get_torch_device())
def get_optimal_device_name():
if torch.cuda.is_available():
return get_cuda_device_string()
if has_mps():
return "mps"
if has_xpu():
return xpu_specific.get_xpu_device_string()
if npu_specific.has_npu:
return npu_specific.get_npu_device_string()
return "cpu"
return model_management.get_torch_device().type
def get_optimal_device():
return torch.device(get_optimal_device_name())
return model_management.get_torch_device()
def get_device_for(task):
if task in shared.cmd_opts.use_cpu or "all" in shared.cmd_opts.use_cpu:
return cpu
return get_optimal_device()
return model_management.get_torch_device()
def torch_gc():
if torch.cuda.is_available():
with torch.cuda.device(get_cuda_device_string()):
torch.cuda.empty_cache()
torch.cuda.ipc_collect()
if has_mps():
mac_specific.torch_mps_gc()
if has_xpu():
xpu_specific.torch_xpu_gc()
if npu_specific.has_npu:
torch_npu_set_device()
npu_specific.torch_npu_gc()
model_management.soft_empty_cache()
def torch_npu_set_device():
# Work around due to bug in torch_npu, revert me after fixed, @see https://gitee.com/ascend/pytorch/issues/I8KECW?from=project-issue
if npu_specific.has_npu:
torch.npu.set_device(0)
return
def enable_tf32():
if torch.cuda.is_available():
return
# enabling benchmark option seems to enable a range of cards to do fp16 when they otherwise can't
# see https://github.com/AUTOMATIC1111/stable-diffusion-webui/pull/4407
if cuda_no_autocast():
torch.backends.cudnn.benchmark = True
torch.backends.cuda.matmul.allow_tf32 = True
torch.backends.cudnn.allow_tf32 = True
errors.run(enable_tf32, "Enabling TF32")
cpu: torch.device = torch.device("cpu")
fp8: bool = False
device: torch.device = None
device_interrogate: torch.device = None
device_gfpgan: torch.device = None
device_esrgan: torch.device = None
device_codeformer: torch.device = None
dtype: torch.dtype = torch.float16
dtype_vae: torch.dtype = torch.float16
dtype_unet: torch.dtype = torch.float16
dtype_inference: torch.dtype = torch.float16
device: torch.device = model_management.get_torch_device()
device_interrogate: torch.device = cpu # not used
device_gfpgan: torch.device = cpu
device_esrgan: torch.device = model_management.get_torch_device() # will be managed in special way
device_codeformer: torch.device = cpu
dtype: torch.dtype = model_management.unet_dtype()
dtype_vae: torch.dtype = model_management.vae_dtype()
dtype_unet: torch.dtype = model_management.unet_dtype()
dtype_inference: torch.dtype = model_management.unet_dtype()
unet_needs_upcast = False
def cond_cast_unet(input):
return input.to(dtype_unet) if unet_needs_upcast else input
return input
def cond_cast_float(input):
return input.float() if unet_needs_upcast else input
return input
nv_rng = None
patch_module_list = [
torch.nn.Linear,
torch.nn.Conv2d,
torch.nn.MultiheadAttention,
torch.nn.GroupNorm,
torch.nn.LayerNorm,
]
patch_module_list = []
def manual_cast_forward(target_dtype):
def forward_wrapper(self, *args, **kwargs):
if any(
isinstance(arg, torch.Tensor) and arg.dtype != target_dtype
for arg in args
):
args = [arg.to(target_dtype) if isinstance(arg, torch.Tensor) else arg for arg in args]
kwargs = {k: v.to(target_dtype) if isinstance(v, torch.Tensor) else v for k, v in kwargs.items()}
org_dtype = target_dtype
for param in self.parameters():
if param.dtype != target_dtype:
org_dtype = param.dtype
break
if org_dtype != target_dtype:
self.to(target_dtype)
result = self.org_forward(*args, **kwargs)
if org_dtype != target_dtype:
self.to(org_dtype)
if target_dtype != dtype_inference:
if isinstance(result, tuple):
result = tuple(
i.to(dtype_inference)
if isinstance(i, torch.Tensor)
else i
for i in result
)
elif isinstance(result, torch.Tensor):
result = result.to(dtype_inference)
return result
return forward_wrapper
return
@contextlib.contextmanager
def manual_cast(target_dtype):
applied = False
for module_type in patch_module_list:
if hasattr(module_type, "org_forward"):
continue
applied = True
org_forward = module_type.forward
if module_type == torch.nn.MultiheadAttention:
module_type.forward = manual_cast_forward(torch.float32)
else:
module_type.forward = manual_cast_forward(target_dtype)
module_type.org_forward = org_forward
try:
yield None
finally:
if applied:
for module_type in patch_module_list:
if hasattr(module_type, "org_forward"):
module_type.forward = module_type.org_forward
delattr(module_type, "org_forward")
return
def autocast(disable=False):
if disable:
return contextlib.nullcontext()
if fp8 and device==cpu:
return torch.autocast("cpu", dtype=torch.bfloat16, enabled=True)
if fp8 and dtype_inference == torch.float32:
return manual_cast(dtype)
if dtype == torch.float32 or dtype_inference == torch.float32:
return contextlib.nullcontext()
if has_xpu() or has_mps() or cuda_no_autocast():
return manual_cast(dtype)
return torch.autocast("cuda")
return contextlib.nullcontext()
def without_autocast(disable=False):
return torch.autocast("cuda", enabled=False) if torch.is_autocast_enabled() and not disable else contextlib.nullcontext()
return contextlib.nullcontext()
class NansException(Exception):
@@ -230,42 +95,8 @@ class NansException(Exception):
def test_for_nans(x, where):
if shared.cmd_opts.disable_nan_check:
return
if not torch.all(torch.isnan(x)).item():
return
if where == "unet":
message = "A tensor with all NaNs was produced in Unet."
if not shared.cmd_opts.no_half:
message += " This could be either because there's not enough precision to represent the picture, or because your video card does not support half type. Try setting the \"Upcast cross attention layer to float32\" option in Settings > Stable Diffusion or using the --no-half commandline argument to fix this."
elif where == "vae":
message = "A tensor with all NaNs was produced in VAE."
if not shared.cmd_opts.no_half and not shared.cmd_opts.no_half_vae:
message += " This could be because there's not enough precision to represent the picture. Try adding --no-half-vae commandline argument to fix this."
else:
message = "A tensor with all NaNs was produced."
message += " Use --disable-nan-check commandline argument to disable this check."
raise NansException(message)
return
@lru_cache
def first_time_calculation():
"""
just do any calculation with pytorch layers - the first time this is done it allocaltes about 700MB of memory and
spends about 2.7 seconds doing that, at least wih NVidia.
"""
x = torch.zeros((1, 1)).to(device, dtype)
linear = torch.nn.Linear(1, 1).to(device, dtype)
linear(x)
x = torch.zeros((1, 1, 3, 3)).to(device, dtype)
conv2d = torch.nn.Conv2d(1, 1, (3, 3)).to(device, dtype)
conv2d(x)
return

View File

@@ -1,6 +1,6 @@
from modules import modelloader, devices, errors
from modules.shared import opts
from modules.upscaler import Upscaler, UpscalerData
from modules.upscaler import Upscaler, UpscalerData, prepare_free_memory
from modules.upscaler_utils import upscale_with_model
@@ -27,6 +27,7 @@ class UpscalerESRGAN(Upscaler):
self.scalers.append(scaler_data)
def do_upscale(self, img, selected_model):
prepare_free_memory()
try:
model = self.load_model(selected_model)
except Exception:

View File

@@ -8,6 +8,7 @@ import re
from modules import shared, errors, cache, scripts
from modules.gitpython_hack import Repo
from modules.paths_internal import extensions_dir, extensions_builtin_dir, script_path # noqa: F401
from modules_forge.config import always_disabled_extensions
os.makedirs(extensions_dir, exist_ok=True)
@@ -218,7 +219,17 @@ def list_extensions():
continue
is_builtin = dirname == extensions_builtin_dir
extension = Extension(name=extension_dirname, path=path, enabled=extension_dirname not in shared.opts.disabled_extensions, is_builtin=is_builtin, metadata=metadata)
disabled_extensions = shared.opts.disabled_extensions + always_disabled_extensions
extension = Extension(
name=extension_dirname,
path=path,
enabled=extension_dirname not in disabled_extensions,
is_builtin=is_builtin,
metadata=metadata
)
extensions.append(extension)
loaded_extensions[canonical_name] = extension

View File

@@ -3,7 +3,7 @@ import sys
from modules import modelloader, devices
from modules.shared import opts
from modules.upscaler import Upscaler, UpscalerData
from modules.upscaler import Upscaler, UpscalerData, prepare_free_memory
from modules.upscaler_utils import upscale_with_model
@@ -20,6 +20,7 @@ class UpscalerHAT(Upscaler):
self.scalers.append(scaler_data)
def do_upscale(self, img, selected_model):
prepare_free_memory()
try:
model = self.load_model(selected_model)
except Exception as e:

View File

@@ -15,6 +15,7 @@ import modules.shared as shared
import modules.processing as processing
from modules.ui import plaintext_to_html
import modules.scripts
from modules_forge import main_thread
def process_batch(p, input_dir, output_dir, inpaint_mask_dir, args, to_scale=False, scale_by=1.0, use_png_info=False, png_info_props=None, png_info_dir=None):
@@ -146,7 +147,7 @@ def process_batch(p, input_dir, output_dir, inpaint_mask_dir, args, to_scale=Fal
return batch_results
def img2img(id_task: str, mode: int, prompt: str, negative_prompt: str, prompt_styles, init_img, sketch, init_img_with_mask, inpaint_color_sketch, inpaint_color_sketch_orig, init_img_inpaint, init_mask_inpaint, steps: int, sampler_name: str, mask_blur: int, mask_alpha: float, inpainting_fill: int, n_iter: int, batch_size: int, cfg_scale: float, image_cfg_scale: float, denoising_strength: float, selected_scale_tab: int, height: int, width: int, scale_by: float, resize_mode: int, inpaint_full_res: bool, inpaint_full_res_padding: int, inpainting_mask_invert: int, img2img_batch_input_dir: str, img2img_batch_output_dir: str, img2img_batch_inpaint_mask_dir: str, override_settings_texts, img2img_batch_use_png_info: bool, img2img_batch_png_info_props: list, img2img_batch_png_info_dir: str, request: gr.Request, *args):
def img2img_function(id_task: str, mode: int, prompt: str, negative_prompt: str, prompt_styles, init_img, sketch, init_img_with_mask, inpaint_color_sketch, inpaint_color_sketch_orig, init_img_inpaint, init_mask_inpaint, steps: int, sampler_name: str, mask_blur: int, mask_alpha: float, inpainting_fill: int, n_iter: int, batch_size: int, cfg_scale: float, image_cfg_scale: float, denoising_strength: float, selected_scale_tab: int, height: int, width: int, scale_by: float, resize_mode: int, inpaint_full_res: bool, inpaint_full_res_padding: int, inpainting_mask_invert: int, img2img_batch_input_dir: str, img2img_batch_output_dir: str, img2img_batch_inpaint_mask_dir: str, override_settings_texts, img2img_batch_use_png_info: bool, img2img_batch_png_info_props: list, img2img_batch_png_info_dir: str, request: gr.Request, *args):
override_settings = create_override_settings_dict(override_settings_texts)
is_batch = mode == 5
@@ -243,4 +244,8 @@ def img2img(id_task: str, mode: int, prompt: str, negative_prompt: str, prompt_s
if opts.do_not_show_images:
processed.images = []
return processed.images, generation_info_js, plaintext_to_html(processed.info), plaintext_to_html(processed.comments, classname="comments")
return processed.images + processed.extra_images, generation_info_js, plaintext_to_html(processed.info), plaintext_to_html(processed.comments, classname="comments")
def img2img(id_task: str, mode: int, prompt: str, negative_prompt: str, prompt_styles, init_img, sketch, init_img_with_mask, inpaint_color_sketch, inpaint_color_sketch_orig, init_img_inpaint, init_mask_inpaint, steps: int, sampler_name: str, mask_blur: int, mask_alpha: float, inpainting_fill: int, n_iter: int, batch_size: int, cfg_scale: float, image_cfg_scale: float, denoising_strength: float, selected_scale_tab: int, height: int, width: int, scale_by: float, resize_mode: int, inpaint_full_res: bool, inpaint_full_res_padding: int, inpainting_mask_invert: int, img2img_batch_input_dir: str, img2img_batch_output_dir: str, img2img_batch_inpaint_mask_dir: str, override_settings_texts, img2img_batch_use_png_info: bool, img2img_batch_png_info_props: list, img2img_batch_png_info_dir: str, request: gr.Request, *args):
return main_thread.run_and_wait_result(img2img_function, id_task, mode, prompt, negative_prompt, prompt_styles, init_img, sketch, init_img_with_mask, inpaint_color_sketch, inpaint_color_sketch_orig, init_img_inpaint, init_mask_inpaint, steps, sampler_name, mask_blur, mask_alpha, inpainting_fill, n_iter, batch_size, cfg_scale, image_cfg_scale, denoising_strength, selected_scale_tab, height, width, scale_by, resize_mode, inpaint_full_res, inpaint_full_res_padding, inpainting_mask_invert, img2img_batch_input_dir, img2img_batch_output_dir, img2img_batch_inpaint_mask_dir, override_settings_texts, img2img_batch_use_png_info, img2img_batch_png_info_props, img2img_batch_png_info_dir, request, *args)

View File

@@ -138,6 +138,8 @@ def register_paste_params_button(binding: ParamBinding):
def connect_paste_params_buttons():
for binding in registered_param_bindings:
if binding.tabname not in paste_fields:
continue
destination_image_component = paste_fields[binding.tabname]["init_img"]
fields = paste_fields[binding.tabname]["fields"]
override_settings_component = binding.override_settings_component or paste_fields[binding.tabname]["override_settings_component"]

View File

@@ -3,11 +3,23 @@ import logging
import os
import sys
import warnings
import os
from threading import Thread
from modules.timer import startup_timer
class HiddenPrints:
def __enter__(self):
self._original_stdout = sys.stdout
sys.stdout = open(os.devnull, 'w')
def __exit__(self, exc_type, exc_val, exc_tb):
sys.stdout.close()
sys.stdout = self._original_stdout
def imports():
logging.getLogger("torch.distributed.nn").setLevel(logging.ERROR) # sshh...
logging.getLogger("xformers").addFilter(lambda record: 'A matching Triton is not available' not in record.getMessage())
@@ -23,14 +35,16 @@ def imports():
import gradio # noqa: F401
startup_timer.record("import gradio")
from modules import paths, timer, import_hook, errors # noqa: F401
startup_timer.record("setup paths")
with HiddenPrints():
from modules import paths, timer, import_hook, errors # noqa: F401
startup_timer.record("setup paths")
import ldm.modules.encoders.modules # noqa: F401
startup_timer.record("import ldm")
import ldm.modules.encoders.modules # noqa: F401
import ldm.modules.diffusionmodules.model
startup_timer.record("import ldm")
import sgm.modules.encoders.modules # noqa: F401
startup_timer.record("import sgm")
import sgm.modules.encoders.modules # noqa: F401
startup_timer.record("import sgm")
from modules import shared_init
shared_init.initialize()
@@ -135,24 +149,9 @@ def initialize_rest(*, reload_script_modules=False):
sd_unet.list_unets()
startup_timer.record("scripts list_unets")
def load_model():
"""
Accesses shared.sd_model property to load model.
After it's available, if it has been loaded before this access by some extension,
its optimization may be None because the list of optimizaers has neet been filled
by that time, so we apply optimization again.
"""
from modules import devices
devices.torch_npu_set_device()
shared.sd_model # noqa: B018
if sd_hijack.current_optimizer is None:
sd_hijack.apply_optimizations()
devices.first_time_calculation()
if not shared.cmd_opts.skip_load_model_at_start:
Thread(target=load_model).start()
from modules_forge import main_thread
import modules.sd_models
main_thread.async_run(modules.sd_models.model_data.get_sd_model)
from modules import shared_items
shared_items.reload_hypernetworks()

View File

@@ -170,10 +170,11 @@ def configure_sigint_handler():
def configure_opts_onchange():
from modules import shared, sd_models, sd_vae, ui_tempdir, sd_hijack
from modules.call_queue import wrap_queued_call
from modules_forge import main_thread
shared.opts.onchange("sd_model_checkpoint", wrap_queued_call(lambda: sd_models.reload_model_weights()), call=False)
shared.opts.onchange("sd_vae", wrap_queued_call(lambda: sd_vae.reload_vae_weights()), call=False)
shared.opts.onchange("sd_vae_overrides_per_model_preferences", wrap_queued_call(lambda: sd_vae.reload_vae_weights()), call=False)
shared.opts.onchange("sd_model_checkpoint", wrap_queued_call(lambda: main_thread.run_and_wait_result(sd_models.reload_model_weights)), call=False)
shared.opts.onchange("sd_vae", wrap_queued_call(lambda: main_thread.run_and_wait_result(sd_vae.reload_vae_weights)), call=False)
shared.opts.onchange("sd_vae_overrides_per_model_preferences", wrap_queued_call(lambda: main_thread.run_and_wait_result(sd_vae.reload_vae_weights)), call=False)
shared.opts.onchange("temp_dir", ui_tempdir.on_tmpdir_changed)
shared.opts.onchange("gradio_theme", shared.reload_gradio_theme)
shared.opts.onchange("cross_attention_optimization", wrap_queued_call(lambda: sd_hijack.model_hijack.redo_hijack(shared.sd_model)), call=False)

View File

@@ -10,7 +10,10 @@ import torch.hub
from torchvision import transforms
from torchvision.transforms.functional import InterpolationMode
from modules import devices, paths, shared, lowvram, modelloader, errors, torch_utils
from modules import devices, paths, shared, modelloader, errors
from ldm_patched.modules import model_management
from ldm_patched.modules.model_patcher import ModelPatcher
blip_image_eval_size = 384
clip_model_name = 'ViT-L/14'
@@ -53,7 +56,16 @@ class InterrogateModels:
self.loaded_categories = None
self.skip_categories = []
self.content_dir = content_dir
self.running_on_cpu = devices.device_interrogate == torch.device("cpu")
self.load_device = model_management.text_encoder_device()
self.offload_device = model_management.text_encoder_offload_device()
self.dtype = torch.float32
if model_management.should_use_fp16(device=self.load_device):
self.dtype = torch.float16
self.blip_patcher = None
self.clip_patcher = None
def categories(self):
if not os.path.exists(self.content_dir):
@@ -105,49 +117,37 @@ class InterrogateModels:
def load_clip_model(self):
import clip
import clip.model
if self.running_on_cpu:
model, preprocess = clip.load(clip_model_name, device="cpu", download_root=shared.cmd_opts.clip_models_path)
else:
model, preprocess = clip.load(clip_model_name, download_root=shared.cmd_opts.clip_models_path)
clip.model.LayerNorm = torch.nn.LayerNorm
model, preprocess = clip.load(clip_model_name, device="cpu", download_root=shared.cmd_opts.clip_models_path)
model.eval()
model = model.to(devices.device_interrogate)
return model, preprocess
def load(self):
if self.blip_model is None:
self.blip_model = self.load_blip_model()
if not shared.cmd_opts.no_half and not self.running_on_cpu:
self.blip_model = self.blip_model.half()
self.blip_model = self.blip_model.to(devices.device_interrogate)
self.blip_model = self.blip_model.to(device=self.offload_device, dtype=self.dtype)
self.blip_patcher = ModelPatcher(self.blip_model, load_device=self.load_device, offload_device=self.offload_device)
if self.clip_model is None:
self.clip_model, self.clip_preprocess = self.load_clip_model()
if not shared.cmd_opts.no_half and not self.running_on_cpu:
self.clip_model = self.clip_model.half()
self.clip_model = self.clip_model.to(device=self.offload_device, dtype=self.dtype)
self.clip_patcher = ModelPatcher(self.clip_model, load_device=self.load_device, offload_device=self.offload_device)
self.clip_model = self.clip_model.to(devices.device_interrogate)
self.dtype = torch_utils.get_param(self.clip_model).dtype
model_management.load_models_gpu([self.blip_patcher, self.clip_patcher])
return
def send_clip_to_ram(self):
if not shared.opts.interrogate_keep_models_in_memory:
if self.clip_model is not None:
self.clip_model = self.clip_model.to(devices.cpu)
pass
def send_blip_to_ram(self):
if not shared.opts.interrogate_keep_models_in_memory:
if self.blip_model is not None:
self.blip_model = self.blip_model.to(devices.cpu)
pass
def unload(self):
self.send_clip_to_ram()
self.send_blip_to_ram()
devices.torch_gc()
pass
def rank(self, image_features, text_array, top_count=1):
import clip
@@ -158,11 +158,11 @@ class InterrogateModels:
text_array = text_array[0:int(shared.opts.interrogate_clip_dict_limit)]
top_count = min(top_count, len(text_array))
text_tokens = clip.tokenize(list(text_array), truncate=True).to(devices.device_interrogate)
text_tokens = clip.tokenize(list(text_array), truncate=True).to(self.load_device)
text_features = self.clip_model.encode_text(text_tokens).type(self.dtype)
text_features /= text_features.norm(dim=-1, keepdim=True)
similarity = torch.zeros((1, len(text_array))).to(devices.device_interrogate)
similarity = torch.zeros((1, len(text_array))).to(self.load_device)
for i in range(image_features.shape[0]):
similarity += (100.0 * image_features[i].unsqueeze(0) @ text_features.T).softmax(dim=-1)
similarity /= image_features.shape[0]
@@ -175,7 +175,7 @@ class InterrogateModels:
transforms.Resize((blip_image_eval_size, blip_image_eval_size), interpolation=InterpolationMode.BICUBIC),
transforms.ToTensor(),
transforms.Normalize((0.48145466, 0.4578275, 0.40821073), (0.26862954, 0.26130258, 0.27577711))
])(pil_image).unsqueeze(0).type(self.dtype).to(devices.device_interrogate)
])(pil_image).unsqueeze(0).type(self.dtype).to(self.load_device)
with torch.no_grad():
caption = self.blip_model.generate(gpu_image, sample=False, num_beams=shared.opts.interrogate_clip_num_beams, min_length=shared.opts.interrogate_clip_min_length, max_length=shared.opts.interrogate_clip_max_length)
@@ -186,9 +186,6 @@ class InterrogateModels:
res = ""
shared.state.begin(job="interrogate")
try:
lowvram.send_everything_to_cpu()
devices.torch_gc()
self.load()
caption = self.generate_caption(pil_image)
@@ -197,7 +194,7 @@ class InterrogateModels:
res = caption
clip_image = self.clip_preprocess(pil_image).unsqueeze(0).type(self.dtype).to(devices.device_interrogate)
clip_image = self.clip_preprocess(pil_image).unsqueeze(0).type(self.dtype).to(self.load_device)
with torch.no_grad(), devices.autocast():
image_features = self.clip_model.encode_image(clip_image).type(self.dtype)

View File

@@ -12,9 +12,12 @@ import json
from functools import lru_cache
from modules import cmd_args, errors
from modules.paths_internal import script_path, extensions_dir
from modules.paths_internal import script_path, extensions_dir, extensions_builtin_dir
from modules.timer import startup_timer
from modules import logging_config
from modules_forge import forge_version
from modules_forge.config import always_disabled_extensions
args, _ = cmd_args.parser.parse_known_args()
logging_config.setup_logging(args.loglevel)
@@ -70,7 +73,7 @@ def commit_hash():
@lru_cache()
def git_tag():
def git_tag_a1111():
try:
return subprocess.check_output([git, "-C", script_path, "describe", "--tags"], shell=False, encoding='utf8').strip()
except Exception:
@@ -85,6 +88,10 @@ def git_tag():
return "<none>"
def git_tag():
return 'f' + forge_version.version + '-' + git_tag_a1111()
def run(command, desc=None, errdesc=None, custom_env=None, live: bool = default_command_live) -> str:
if desc is not None:
print(desc)
@@ -252,7 +259,7 @@ def list_extensions(settings_file):
errors.report(f'\nCould not load settings\nThe config file "{settings_file}" is likely corrupted\nIt has been moved to the "tmp/config.json"\nReverting config to default\n\n''', exc_info=True)
os.replace(settings_file, os.path.join(script_path, "tmp", "config.json"))
disabled_extensions = set(settings.get('disabled_extensions', []))
disabled_extensions = set(settings.get('disabled_extensions', []) + always_disabled_extensions)
disable_all_extensions = settings.get('disable_all_extensions', 'none')
if disable_all_extensions != 'none' or args.disable_extra_extensions or args.disable_all_extensions or not os.path.isdir(extensions_dir):
@@ -261,6 +268,27 @@ def list_extensions(settings_file):
return [x for x in os.listdir(extensions_dir) if x not in disabled_extensions]
def list_extensions_builtin(settings_file):
settings = {}
try:
with open(settings_file, "r", encoding="utf8") as file:
settings = json.load(file)
except FileNotFoundError:
pass
except Exception:
errors.report(f'\nCould not load settings\nThe config file "{settings_file}" is likely corrupted\nIt has been moved to the "tmp/config.json"\nReverting config to default\n\n''', exc_info=True)
os.replace(settings_file, os.path.join(script_path, "tmp", "config.json"))
disabled_extensions = set(settings.get('disabled_extensions', []))
disable_all_extensions = settings.get('disable_all_extensions', 'none')
if disable_all_extensions != 'none' or args.disable_extra_extensions or args.disable_all_extensions or not os.path.isdir(extensions_builtin_dir):
return []
return [x for x in os.listdir(extensions_builtin_dir) if x not in disabled_extensions]
def run_extensions_installers(settings_file):
if not os.path.isdir(extensions_dir):
return
@@ -275,6 +303,21 @@ def run_extensions_installers(settings_file):
run_extension_installer(path)
startup_timer.record(dirname_extension)
if not os.path.isdir(extensions_builtin_dir):
return
with startup_timer.subcategory("run extensions_builtin installers"):
for dirname_extension in list_extensions_builtin(settings_file):
logging.debug(f"Installing {dirname_extension}")
path = os.path.join(extensions_builtin_dir, dirname_extension)
if os.path.isdir(path):
run_extension_installer(path)
startup_timer.record(dirname_extension)
return
re_requirement = re.compile(r"\s*([-_a-zA-Z0-9]+)\s*(?:==\s*([-+_.a-zA-Z0-9]+))?\s*")
@@ -468,6 +511,11 @@ def start():
else:
webui.webui()
from modules_forge import main_thread
main_thread.loop()
return
def dump_sysinfo():
from modules import sysinfo

View File

@@ -6,142 +6,20 @@ cpu = torch.device("cpu")
def send_everything_to_cpu():
global module_in_gpu
if module_in_gpu is not None:
module_in_gpu.to(cpu)
module_in_gpu = None
return
def is_needed(sd_model):
return shared.cmd_opts.lowvram or shared.cmd_opts.medvram or shared.cmd_opts.medvram_sdxl and hasattr(sd_model, 'conditioner')
return False
def apply(sd_model):
enable = is_needed(sd_model)
shared.parallel_processing_allowed = not enable
if enable:
setup_for_low_vram(sd_model, not shared.cmd_opts.lowvram)
else:
sd_model.lowvram = False
return
def setup_for_low_vram(sd_model, use_medvram):
if getattr(sd_model, 'lowvram', False):
return
sd_model.lowvram = True
parents = {}
def send_me_to_gpu(module, _):
"""send this module to GPU; send whatever tracked module was previous in GPU to CPU;
we add this as forward_pre_hook to a lot of modules and this way all but one of them will
be in CPU
"""
global module_in_gpu
module = parents.get(module, module)
if module_in_gpu == module:
return
if module_in_gpu is not None:
module_in_gpu.to(cpu)
module.to(devices.device)
module_in_gpu = module
# see below for register_forward_pre_hook;
# first_stage_model does not use forward(), it uses encode/decode, so register_forward_pre_hook is
# useless here, and we just replace those methods
first_stage_model = sd_model.first_stage_model
first_stage_model_encode = sd_model.first_stage_model.encode
first_stage_model_decode = sd_model.first_stage_model.decode
def first_stage_model_encode_wrap(x):
send_me_to_gpu(first_stage_model, None)
return first_stage_model_encode(x)
def first_stage_model_decode_wrap(z):
send_me_to_gpu(first_stage_model, None)
return first_stage_model_decode(z)
to_remain_in_cpu = [
(sd_model, 'first_stage_model'),
(sd_model, 'depth_model'),
(sd_model, 'embedder'),
(sd_model, 'model'),
(sd_model, 'embedder'),
]
is_sdxl = hasattr(sd_model, 'conditioner')
is_sd2 = not is_sdxl and hasattr(sd_model.cond_stage_model, 'model')
if is_sdxl:
to_remain_in_cpu.append((sd_model, 'conditioner'))
elif is_sd2:
to_remain_in_cpu.append((sd_model.cond_stage_model, 'model'))
else:
to_remain_in_cpu.append((sd_model.cond_stage_model, 'transformer'))
# remove several big modules: cond, first_stage, depth/embedder (if applicable), and unet from the model
stored = []
for obj, field in to_remain_in_cpu:
module = getattr(obj, field, None)
stored.append(module)
setattr(obj, field, None)
# send the model to GPU.
sd_model.to(devices.device)
# put modules back. the modules will be in CPU.
for (obj, field), module in zip(to_remain_in_cpu, stored):
setattr(obj, field, module)
# register hooks for those the first three models
if is_sdxl:
sd_model.conditioner.register_forward_pre_hook(send_me_to_gpu)
elif is_sd2:
sd_model.cond_stage_model.model.register_forward_pre_hook(send_me_to_gpu)
sd_model.cond_stage_model.model.token_embedding.register_forward_pre_hook(send_me_to_gpu)
parents[sd_model.cond_stage_model.model] = sd_model.cond_stage_model
parents[sd_model.cond_stage_model.model.token_embedding] = sd_model.cond_stage_model
else:
sd_model.cond_stage_model.transformer.register_forward_pre_hook(send_me_to_gpu)
parents[sd_model.cond_stage_model.transformer] = sd_model.cond_stage_model
sd_model.first_stage_model.register_forward_pre_hook(send_me_to_gpu)
sd_model.first_stage_model.encode = first_stage_model_encode_wrap
sd_model.first_stage_model.decode = first_stage_model_decode_wrap
if sd_model.depth_model:
sd_model.depth_model.register_forward_pre_hook(send_me_to_gpu)
if sd_model.embedder:
sd_model.embedder.register_forward_pre_hook(send_me_to_gpu)
if use_medvram:
sd_model.model.register_forward_pre_hook(send_me_to_gpu)
else:
diff_model = sd_model.model.diffusion_model
# the third remaining model is still too big for 4 GB, so we also do the same for its submodules
# so that only one of them is in GPU at a time
stored = diff_model.input_blocks, diff_model.middle_block, diff_model.output_blocks, diff_model.time_embed
diff_model.input_blocks, diff_model.middle_block, diff_model.output_blocks, diff_model.time_embed = None, None, None, None
sd_model.model.to(devices.device)
diff_model.input_blocks, diff_model.middle_block, diff_model.output_blocks, diff_model.time_embed = stored
# install hooks for bits of third model
diff_model.time_embed.register_forward_pre_hook(send_me_to_gpu)
for block in diff_model.input_blocks:
block.register_forward_pre_hook(send_me_to_gpu)
diff_model.middle_block.register_forward_pre_hook(send_me_to_gpu)
for block in diff_model.output_blocks:
block.register_forward_pre_hook(send_me_to_gpu)
return
def is_enabled(sd_model):
return sd_model.lowvram
return False

View File

@@ -445,7 +445,7 @@ class UniPC:
s = torch.quantile(torch.abs(x0).reshape((x0.shape[0], -1)), p, dim=1)
s = expand_dims(torch.maximum(s, self.max_val * torch.ones_like(s).to(s.device)), dims)
x0 = torch.clamp(x0, -s, s) / s
return x0
return x0.to(x)
def model_fn(self, x, t):
"""

View File

@@ -62,3 +62,13 @@ for d, must_exist, what, options in path_dirs:
else:
sys.path.append(d)
paths[what] = d
import ldm_patched.utils.path_utils as ldm_patched_path_utils
ldm_patched_path_utils.base_path = data_path
ldm_patched_path_utils.models_dir = models_path
ldm_patched_path_utils.output_directory = os.path.join(data_path, "output")
ldm_patched_path_utils.temp_directory = os.path.join(data_path, "temp")
ldm_patched_path_utils.input_directory = os.path.join(data_path, "input")
ldm_patched_path_utils.user_directory = os.path.join(data_path, "user")

View File

@@ -256,6 +256,9 @@ class StableDiffusionProcessing:
self.cached_uc = StableDiffusionProcessing.cached_uc
self.cached_c = StableDiffusionProcessing.cached_c
self.extra_result_images = []
self.modified_noise = None
@property
def sd_model(self):
return shared.sd_model
@@ -515,8 +518,9 @@ class StableDiffusionProcessing:
class Processed:
def __init__(self, p: StableDiffusionProcessing, images_list, seed=-1, info="", subseed=None, all_prompts=None, all_negative_prompts=None, all_seeds=None, all_subseeds=None, index_of_first_image=0, infotexts=None, comments=""):
def __init__(self, p: StableDiffusionProcessing, images_list, seed=-1, info="", subseed=None, all_prompts=None, all_negative_prompts=None, all_seeds=None, all_subseeds=None, index_of_first_image=0, infotexts=None, comments="", extra_images_list=[]):
self.images = images_list
self.extra_images = extra_images_list
self.prompt = p.prompt
self.negative_prompt = p.negative_prompt
self.seed = seed
@@ -628,44 +632,7 @@ def decode_latent_batch(model, batch, target_device=None, check_for_nans=False):
for i in range(batch.shape[0]):
sample = decode_first_stage(model, batch[i:i + 1])[0]
if check_for_nans:
try:
devices.test_for_nans(sample, "vae")
except devices.NansException as e:
if shared.opts.auto_vae_precision_bfloat16:
autofix_dtype = torch.bfloat16
autofix_dtype_text = "bfloat16"
autofix_dtype_setting = "Automatically convert VAE to bfloat16"
autofix_dtype_comment = ""
elif shared.opts.auto_vae_precision:
autofix_dtype = torch.float32
autofix_dtype_text = "32-bit float"
autofix_dtype_setting = "Automatically revert VAE to 32-bit floats"
autofix_dtype_comment = "\nTo always start with 32-bit VAE, use --no-half-vae commandline flag."
else:
raise e
if devices.dtype_vae == autofix_dtype:
raise e
errors.print_error_explanation(
"A tensor with all NaNs was produced in VAE.\n"
f"Web UI will now convert VAE into {autofix_dtype_text} and retry.\n"
f"To disable this behavior, disable the '{autofix_dtype_setting}' setting.{autofix_dtype_comment}"
)
devices.dtype_vae = autofix_dtype
model.first_stage_model.to(devices.dtype_vae)
batch = batch.to(devices.dtype_vae)
sample = decode_first_stage(model, batch[i:i + 1])[0]
if target_device is not None:
sample = sample.to(target_device)
samples.append(sample)
samples.append(sample.to(target_device))
return samples
@@ -848,7 +815,7 @@ def process_images_inner(p: StableDiffusionProcessing) -> Processed:
infotexts = []
output_images = []
with torch.no_grad(), p.sd_model.ema_scope():
with torch.inference_mode():
with devices.autocast():
p.init(p.all_prompts, p.all_seeds, p.all_subseeds)
@@ -872,6 +839,7 @@ def process_images_inner(p: StableDiffusionProcessing) -> Processed:
sd_models.reload_model_weights() # model can be changed for example by refiner
p.sd_model.forge_objects = p.sd_model.forge_objects_original.shallow_copy()
p.prompts = p.all_prompts[n * p.batch_size:(n + 1) * p.batch_size]
p.negative_prompts = p.all_negative_prompts[n * p.batch_size:(n + 1) * p.batch_size]
p.seeds = p.all_seeds[n * p.batch_size:(n + 1) * p.batch_size]
@@ -888,8 +856,9 @@ def process_images_inner(p: StableDiffusionProcessing) -> Processed:
p.parse_extra_network_prompts()
if not p.disable_extra_networks:
with devices.autocast():
extra_networks.activate(p, p.extra_network_data)
extra_networks.activate(p, p.extra_network_data)
p.sd_model.forge_objects = p.sd_model.forge_objects_after_applying_lora.shallow_copy()
if p.scripts is not None:
p.scripts.process_batch(p, batch_number=n, prompts=p.prompts, seeds=p.seeds, subseeds=p.subseeds)
@@ -941,8 +910,18 @@ def process_images_inner(p: StableDiffusionProcessing) -> Processed:
p.extra_generation_params['Noise Schedule'] = opts.sd_noise_schedule
p.sd_model.alphas_cumprod = rescale_zero_terminal_snr_abar(p.sd_model.alphas_cumprod).to(shared.device)
with devices.without_autocast() if devices.unet_needs_upcast else devices.autocast():
samples_ddim = p.sample(conditioning=p.c, unconditional_conditioning=p.uc, seeds=p.seeds, subseeds=p.subseeds, subseed_strength=p.subseed_strength, prompts=p.prompts)
alphas_cumprod_modifiers = p.sd_model.forge_objects.unet.model_options.get('alphas_cumprod_modifiers', [])
alphas_cumprod_backup = None
if len(alphas_cumprod_modifiers) > 0:
alphas_cumprod_backup = p.sd_model.alphas_cumprod
for modifier in alphas_cumprod_modifiers:
p.sd_model.alphas_cumprod = modifier(p.sd_model.alphas_cumprod)
samples_ddim = p.sample(conditioning=p.c, unconditional_conditioning=p.uc, seeds=p.seeds, subseeds=p.subseeds, subseed_strength=p.subseed_strength, prompts=p.prompts)
if alphas_cumprod_backup is not None:
p.sd_model.alphas_cumprod = alphas_cumprod_backup
if p.scripts is not None:
ps = scripts.PostSampleArgs(samples_ddim)
@@ -961,9 +940,6 @@ def process_images_inner(p: StableDiffusionProcessing) -> Processed:
del samples_ddim
if lowvram.is_enabled(shared.sd_model):
lowvram.send_everything_to_cpu()
devices.torch_gc()
state.nextjob()
@@ -1102,6 +1078,7 @@ def process_images_inner(p: StableDiffusionProcessing) -> Processed:
subseed=p.all_subseeds[0],
index_of_first_image=index_of_first_image,
infotexts=infotexts,
extra_images_list=p.extra_result_images,
)
if p.scripts is not None:
@@ -1270,7 +1247,7 @@ class StableDiffusionProcessingTxt2Img(StableDiffusionProcessing):
image = np.array(self.firstpass_image).astype(np.float32) / 255.0
image = np.moveaxis(image, 2, 0)
image = torch.from_numpy(np.expand_dims(image, axis=0))
image = image.to(shared.device, dtype=devices.dtype_vae)
image = image.to(shared.device, dtype=torch.float32)
if opts.sd_vae_encode_method != 'Full':
self.extra_generation_params['VAE Encoder'] = opts.sd_vae_encode_method
@@ -1283,6 +1260,19 @@ class StableDiffusionProcessingTxt2Img(StableDiffusionProcessing):
# here we generate an image normally
x = self.rng.next()
self.sd_model.forge_objects = self.sd_model.forge_objects_after_applying_lora.shallow_copy()
if self.scripts is not None:
self.scripts.process_before_every_sampling(self,
x=x,
noise=x,
c=conditioning,
uc=unconditional_conditioning)
if self.modified_noise is not None:
x = self.modified_noise
self.modified_noise = None
samples = self.sampler.sample(self, x, conditioning, unconditional_conditioning, image_conditioning=self.txt2img_image_conditioning(x))
del x
@@ -1354,7 +1344,7 @@ class StableDiffusionProcessingTxt2Img(StableDiffusionProcessing):
batch_images.append(image)
decoded_samples = torch.from_numpy(np.array(batch_images))
decoded_samples = decoded_samples.to(shared.device, dtype=devices.dtype_vae)
decoded_samples = decoded_samples.to(shared.device, dtype=torch.float32)
if opts.sd_vae_encode_method != 'Full':
self.extra_generation_params['VAE Encoder'] = opts.sd_vae_encode_method
@@ -1384,6 +1374,18 @@ class StableDiffusionProcessingTxt2Img(StableDiffusionProcessing):
if self.scripts is not None:
self.scripts.before_hr(self)
self.sd_model.forge_objects = self.sd_model.forge_objects_after_applying_lora.shallow_copy()
if self.scripts is not None:
self.scripts.process_before_every_sampling(self,
x=samples,
noise=noise,
c=self.hr_c,
uc=self.hr_uc)
if self.modified_noise is not None:
noise = self.modified_noise
self.modified_noise = None
samples = self.sampler.sample_img2img(self, samples, noise, self.hr_c, self.hr_uc, steps=self.hr_second_pass_steps or self.steps, image_conditioning=image_conditioning)
sd_models.apply_token_merging(self.sd_model, self.get_token_merging_ratio())
@@ -1459,7 +1461,7 @@ class StableDiffusionProcessingTxt2Img(StableDiffusionProcessing):
if shared.opts.hires_fix_use_firstpass_conds:
self.calculate_hr_conds()
elif lowvram.is_enabled(shared.sd_model) and shared.sd_model.sd_checkpoint_info == sd_models.select_checkpoint(): # if in lowvram mode, we need to calculate conds right away, before the cond NN is unloaded
elif shared.sd_model.sd_checkpoint_info == sd_models.select_checkpoint(): # if in lowvram mode, we need to calculate conds right away, before the cond NN is unloaded
with devices.autocast():
extra_networks.activate(self, self.hr_extra_network_data)
@@ -1646,7 +1648,7 @@ class StableDiffusionProcessingImg2Img(StableDiffusionProcessing):
raise RuntimeError(f"bad number of images passed: {len(imgs)}; expecting {self.batch_size} or less")
image = torch.from_numpy(batch_images)
image = image.to(shared.device, dtype=devices.dtype_vae)
image = image.to(shared.device, dtype=torch.float32)
if opts.sd_vae_encode_method != 'Full':
self.extra_generation_params['VAE Encoder'] = opts.sd_vae_encode_method
@@ -1687,6 +1689,18 @@ class StableDiffusionProcessingImg2Img(StableDiffusionProcessing):
self.extra_generation_params["Noise multiplier"] = self.initial_noise_multiplier
x *= self.initial_noise_multiplier
self.sd_model.forge_objects = self.sd_model.forge_objects_after_applying_lora.shallow_copy()
if self.scripts is not None:
self.scripts.process_before_every_sampling(self,
x=self.init_latent,
noise=x,
c=conditioning,
uc=unconditional_conditioning)
if self.modified_noise is not None:
x = self.modified_noise
self.modified_noise = None
samples = self.sampler.sample_img2img(self, self.init_latent, x, conditioning, unconditional_conditioning, image_conditioning=self.image_conditioning)
if self.mask is not None:

View File

@@ -268,7 +268,7 @@ def get_multicond_learned_conditioning(model, prompts, steps, hires_steps=None,
class DictWithShape(dict):
def __init__(self, x, shape):
def __init__(self, x):
super().__init__()
self.update(x)
@@ -276,6 +276,19 @@ class DictWithShape(dict):
def shape(self):
return self["crossattn"].shape
def to(self, *args, **kwargs):
for k in self.keys():
if isinstance(self[k], torch.Tensor):
self[k] = self[k].to(*args, **kwargs)
return self
def advanced_indexing(self, item):
result = {}
for k in self.keys():
if isinstance(self[k], torch.Tensor):
result[k] = self[k][item]
return DictWithShape(result)
def reconstruct_cond_batch(c: list[list[ScheduledPromptConditioning]], current_step):
param = c[0][0].cond
@@ -284,7 +297,7 @@ def reconstruct_cond_batch(c: list[list[ScheduledPromptConditioning]], current_s
if is_dict:
dict_cond = param
res = {k: torch.zeros((len(c),) + param.shape, device=param.device, dtype=param.dtype) for k, param in dict_cond.items()}
res = DictWithShape(res, (len(c),) + dict_cond['crossattn'].shape)
res = DictWithShape(res)
else:
res = torch.zeros((len(c),) + param.shape, device=param.device, dtype=param.dtype)
@@ -342,7 +355,7 @@ def reconstruct_multicond_batch(c: MulticondLearnedConditioning, current_step):
if isinstance(tensors[0], dict):
keys = list(tensors[0].keys())
stacked = {k: stack_conds([x[k] for x in tensors]) for k in keys}
stacked = DictWithShape(stacked, stacked['crossattn'].shape)
stacked = DictWithShape(stacked)
else:
stacked = stack_conds(tensors).to(device=param.device, dtype=param.dtype)

View File

@@ -2,7 +2,7 @@ import os
from modules import modelloader, errors
from modules.shared import cmd_opts, opts
from modules.upscaler import Upscaler, UpscalerData
from modules.upscaler import Upscaler, UpscalerData, prepare_free_memory
from modules.upscaler_utils import upscale_with_model
@@ -27,6 +27,8 @@ class UpscalerRealESRGAN(Upscaler):
self.scalers.append(scaler)
def do_upscale(self, img, path):
prepare_free_memory()
if not self.enable:
return img

View File

@@ -8,7 +8,13 @@ def randn(seed, shape, generator=None):
Uses the seed parameter to set the global torch seed; to generate more with that seed, use randn_like/randn_without_seed."""
manual_seed(seed)
if generator is not None:
# If generator is not none, we must use another seed to
# avoid global torch.rand to get same noise again.
# Note: removing this will make DDPM sampler broken.
manual_seed((seed + 100000) % 65536)
else:
manual_seed(seed)
if shared.opts.randn_source == "NV":
return torch.asarray((generator or nv_rng).randn(shape), device=devices.device)

View File

@@ -192,5 +192,4 @@ with safe.Extra(handler):
unsafe_torch_load = torch.load
torch.load = load
global_extra_handler = None

View File

@@ -140,6 +140,9 @@ callback_map = dict(
callbacks_list_unets=[],
callbacks_before_token_counter=[],
)
event_subscriber_map = dict(
callbacks_setting_updated=[],
)
def clear_callbacks():
@@ -328,6 +331,23 @@ def before_token_counter_callback(params: BeforeTokenCounterParams):
report_exception(c, 'before_token_counter')
def setting_updated_event_subscriber_chain(handler, component, setting_name: str):
"""
Arguments:
- handler: The returned handler from calling an event subscriber.
- component: The component that is updated. The component should provide
the value of setting after update.
- setting_name: The name of the setting.
"""
for param in event_subscriber_map['callbacks_setting_updated']:
handler = handler.then(
fn=lambda *args: param["fn"](*args, setting_name),
inputs=param["inputs"] + [component],
outputs=param["outputs"],
show_progress=False,
)
def add_callback(callbacks, fun):
stack = [x for x in inspect.stack() if x.filename != __file__]
filename = stack[0].filename if stack else 'unknown file'
@@ -509,3 +529,14 @@ def on_before_token_counter(callback):
The function will be called with one argument of type BeforeTokenCounterParams, and should modify its fields if necessary."""
add_callback(callback_map['callbacks_before_token_counter'], callback)
def on_setting_updated_subscriber(subscriber_params):
"""register a function to be called after settings update. `subscriber_params`
should contain necessary fields to register an gradio event handler. Necessary
fields are ["fn", "outputs", "inputs"].
Setting name and setting value after update will be append to inputs. So be
sure to handle these extra params when defining the callback function.
"""
event_subscriber_map['callbacks_setting_updated'].append(subscriber_params)

View File

@@ -186,6 +186,14 @@ class Script:
"""
pass
def process_before_every_sampling(self, p, *args, **kwargs):
"""
Similar to process(), called before every sampling.
If you use high-res fix, this will be called two times.
"""
pass
def process_batch(self, p, *args, **kwargs):
"""
Same as process(), but called for every batch.
@@ -504,8 +512,14 @@ def load_scripts():
scripts_list = list_scripts("scripts", ".py") + list_scripts("modules/processing_scripts", ".py", include_extensions=False)
for s in scripts_list:
if s.basedir not in sys.path:
sys.path = [s.basedir] + sys.path
syspath = sys.path
# print(f'Current System Paths = {syspath}')
def register_scripts_from_module(module):
for script_class in module.__dict__.values():
if not inspect.isclass(script_class):
@@ -809,6 +823,14 @@ class ScriptRunner:
except Exception:
errors.report(f"Error running process_batch: {script.filename}", exc_info=True)
def process_before_every_sampling(self, p, **kwargs):
for script in self.alwayson_scripts:
try:
script_args = p.script_args[script.args_from:script.args_to]
script.process_before_every_sampling(p, *script_args, **kwargs)
except Exception:
errors.report(f"Error running process_before_every_sampling: {script.filename}", exc_info=True)
def postprocess(self, p, processed):
for script in self.alwayson_scripts:
try:

View File

@@ -57,57 +57,11 @@ def list_optimizers():
def apply_optimizations(option=None):
global current_optimizer
undo_optimizations()
if len(optimizers) == 0:
# a script can access the model very early, and optimizations would not be filled by then
current_optimizer = None
return ''
ldm.modules.diffusionmodules.model.nonlinearity = silu
ldm.modules.diffusionmodules.openaimodel.th = sd_hijack_unet.th
sgm.modules.diffusionmodules.model.nonlinearity = silu
sgm.modules.diffusionmodules.openaimodel.th = sd_hijack_unet.th
if current_optimizer is not None:
current_optimizer.undo()
current_optimizer = None
selection = option or shared.opts.cross_attention_optimization
if selection == "Automatic" and len(optimizers) > 0:
matching_optimizer = next(iter([x for x in optimizers if x.cmd_opt and getattr(shared.cmd_opts, x.cmd_opt, False)]), optimizers[0])
else:
matching_optimizer = next(iter([x for x in optimizers if x.title() == selection]), None)
if selection == "None":
matching_optimizer = None
elif selection == "Automatic" and shared.cmd_opts.disable_opt_split_attention:
matching_optimizer = None
elif matching_optimizer is None:
matching_optimizer = optimizers[0]
if matching_optimizer is not None:
print(f"Applying attention optimization: {matching_optimizer.name}... ", end='')
matching_optimizer.apply()
print("done.")
current_optimizer = matching_optimizer
return current_optimizer.name
else:
print("Disabling attention optimization")
return ''
return
def undo_optimizations():
ldm.modules.diffusionmodules.model.nonlinearity = diffusionmodules_model_nonlinearity
ldm.modules.attention.CrossAttention.forward = hypernetwork.attention_CrossAttention_forward
ldm.modules.diffusionmodules.model.AttnBlock.forward = diffusionmodules_model_AttnBlock_forward
sgm.modules.diffusionmodules.model.nonlinearity = diffusionmodules_model_nonlinearity
sgm.modules.attention.CrossAttention.forward = hypernetwork.attention_CrossAttention_forward
sgm.modules.diffusionmodules.model.AttnBlock.forward = diffusionmodules_model_AttnBlock_forward
return
def fix_checkpoint():
@@ -182,156 +136,30 @@ class StableDiffusionModelHijack:
self.embedding_db.add_embedding_dir(cmd_opts.embeddings_dir)
def apply_optimizations(self, option=None):
try:
self.optimization_method = apply_optimizations(option)
except Exception as e:
errors.display(e, "applying cross attention optimization")
undo_optimizations()
pass
def convert_sdxl_to_ssd(self, m):
"""Converts an SDXL model to a Segmind Stable Diffusion model (see https://huggingface.co/segmind/SSD-1B)"""
delattr(m.model.diffusion_model.middle_block, '1')
delattr(m.model.diffusion_model.middle_block, '2')
for i in ['9', '8', '7', '6', '5', '4']:
delattr(m.model.diffusion_model.input_blocks[7][1].transformer_blocks, i)
delattr(m.model.diffusion_model.input_blocks[8][1].transformer_blocks, i)
delattr(m.model.diffusion_model.output_blocks[0][1].transformer_blocks, i)
delattr(m.model.diffusion_model.output_blocks[1][1].transformer_blocks, i)
delattr(m.model.diffusion_model.output_blocks[4][1].transformer_blocks, '1')
delattr(m.model.diffusion_model.output_blocks[5][1].transformer_blocks, '1')
devices.torch_gc()
pass
def hijack(self, m):
conditioner = getattr(m, 'conditioner', None)
if conditioner:
text_cond_models = []
for i in range(len(conditioner.embedders)):
embedder = conditioner.embedders[i]
typename = type(embedder).__name__
if typename == 'FrozenOpenCLIPEmbedder':
embedder.model.token_embedding = EmbeddingsWithFixes(embedder.model.token_embedding, self)
conditioner.embedders[i] = sd_hijack_open_clip.FrozenOpenCLIPEmbedderWithCustomWords(embedder, self)
text_cond_models.append(conditioner.embedders[i])
if typename == 'FrozenCLIPEmbedder':
model_embeddings = embedder.transformer.text_model.embeddings
model_embeddings.token_embedding = EmbeddingsWithFixes(model_embeddings.token_embedding, self)
conditioner.embedders[i] = sd_hijack_clip.FrozenCLIPEmbedderForSDXLWithCustomWords(embedder, self)
text_cond_models.append(conditioner.embedders[i])
if typename == 'FrozenOpenCLIPEmbedder2':
embedder.model.token_embedding = EmbeddingsWithFixes(embedder.model.token_embedding, self, textual_inversion_key='clip_g')
conditioner.embedders[i] = sd_hijack_open_clip.FrozenOpenCLIPEmbedder2WithCustomWords(embedder, self)
text_cond_models.append(conditioner.embedders[i])
if len(text_cond_models) == 1:
m.cond_stage_model = text_cond_models[0]
else:
m.cond_stage_model = conditioner
if type(m.cond_stage_model) == xlmr.BertSeriesModelWithTransformation or type(m.cond_stage_model) == xlmr_m18.BertSeriesModelWithTransformation:
model_embeddings = m.cond_stage_model.roberta.embeddings
model_embeddings.token_embedding = EmbeddingsWithFixes(model_embeddings.word_embeddings, self)
m.cond_stage_model = sd_hijack_xlmr.FrozenXLMREmbedderWithCustomWords(m.cond_stage_model, self)
elif type(m.cond_stage_model) == ldm.modules.encoders.modules.FrozenCLIPEmbedder:
model_embeddings = m.cond_stage_model.transformer.text_model.embeddings
model_embeddings.token_embedding = EmbeddingsWithFixes(model_embeddings.token_embedding, self)
m.cond_stage_model = sd_hijack_clip.FrozenCLIPEmbedderWithCustomWords(m.cond_stage_model, self)
elif type(m.cond_stage_model) == ldm.modules.encoders.modules.FrozenOpenCLIPEmbedder:
m.cond_stage_model.model.token_embedding = EmbeddingsWithFixes(m.cond_stage_model.model.token_embedding, self)
m.cond_stage_model = sd_hijack_open_clip.FrozenOpenCLIPEmbedderWithCustomWords(m.cond_stage_model, self)
apply_weighted_forward(m)
if m.cond_stage_key == "edit":
sd_hijack_unet.hijack_ddpm_edit()
self.apply_optimizations()
self.clip = m.cond_stage_model
def flatten(el):
flattened = [flatten(children) for children in el.children()]
res = [el]
for c in flattened:
res += c
return res
self.layers = flatten(m)
import modules.models.diffusion.ddpm_edit
if isinstance(m, ldm.models.diffusion.ddpm.LatentDiffusion):
sd_unet.original_forward = ldm_original_forward
elif isinstance(m, modules.models.diffusion.ddpm_edit.LatentDiffusion):
sd_unet.original_forward = ldm_original_forward
elif isinstance(m, sgm.models.diffusion.DiffusionEngine):
sd_unet.original_forward = sgm_original_forward
else:
sd_unet.original_forward = None
pass
def undo_hijack(self, m):
conditioner = getattr(m, 'conditioner', None)
if conditioner:
for i in range(len(conditioner.embedders)):
embedder = conditioner.embedders[i]
if isinstance(embedder, (sd_hijack_open_clip.FrozenOpenCLIPEmbedderWithCustomWords, sd_hijack_open_clip.FrozenOpenCLIPEmbedder2WithCustomWords)):
embedder.wrapped.model.token_embedding = embedder.wrapped.model.token_embedding.wrapped
conditioner.embedders[i] = embedder.wrapped
if isinstance(embedder, sd_hijack_clip.FrozenCLIPEmbedderForSDXLWithCustomWords):
embedder.wrapped.transformer.text_model.embeddings.token_embedding = embedder.wrapped.transformer.text_model.embeddings.token_embedding.wrapped
conditioner.embedders[i] = embedder.wrapped
if hasattr(m, 'cond_stage_model'):
delattr(m, 'cond_stage_model')
elif type(m.cond_stage_model) == sd_hijack_xlmr.FrozenXLMREmbedderWithCustomWords:
m.cond_stage_model = m.cond_stage_model.wrapped
elif type(m.cond_stage_model) == sd_hijack_clip.FrozenCLIPEmbedderWithCustomWords:
m.cond_stage_model = m.cond_stage_model.wrapped
model_embeddings = m.cond_stage_model.transformer.text_model.embeddings
if type(model_embeddings.token_embedding) == EmbeddingsWithFixes:
model_embeddings.token_embedding = model_embeddings.token_embedding.wrapped
elif type(m.cond_stage_model) == sd_hijack_open_clip.FrozenOpenCLIPEmbedderWithCustomWords:
m.cond_stage_model.wrapped.model.token_embedding = m.cond_stage_model.wrapped.model.token_embedding.wrapped
m.cond_stage_model = m.cond_stage_model.wrapped
undo_optimizations()
undo_weighted_forward(m)
self.apply_circular(False)
self.layers = None
self.clip = None
pass
def apply_circular(self, enable):
if self.circular_enabled == enable:
return
self.circular_enabled = enable
for layer in [layer for layer in self.layers if type(layer) == torch.nn.Conv2d]:
layer.padding_mode = 'circular' if enable else 'zeros'
pass
def clear_comments(self):
self.comments = []
self.extra_generation_params = {}
def get_prompt_lengths(self, text):
if self.clip is None:
return "-", "-"
_, token_count = self.clip.process_texts([text])
return token_count, self.clip.get_target_prompt_token_count(token_count)
def get_prompt_lengths(self, text, cond_stage_model):
_, token_count = cond_stage_model.process_texts([text])
return token_count, cond_stage_model.get_target_prompt_token_count(token_count)
def redo_hijack(self, m):
self.undo_hijack(m)
self.hijack(m)
pass
class EmbeddingsWithFixes(torch.nn.Module):
@@ -340,6 +168,7 @@ class EmbeddingsWithFixes(torch.nn.Module):
self.wrapped = wrapped
self.embeddings = embeddings
self.textual_inversion_key = textual_inversion_key
self.weight = self.wrapped.weight
def forward(self, input_ids):
batch_fixes = self.embeddings.fixes

View File

@@ -10,13 +10,19 @@ from omegaconf import OmegaConf, ListConfig
from os import mkdir
from urllib import request
import ldm.modules.midas as midas
import gc
from ldm.util import instantiate_from_config
from modules import paths, shared, modelloader, devices, script_callbacks, sd_vae, sd_disable_initialization, errors, hashes, sd_models_config, sd_unet, sd_models_xl, cache, extra_networks, processing, lowvram, sd_hijack, patches
from modules.timer import Timer
import tomesd
import numpy as np
from modules_forge import forge_loader
import modules_forge.ops as forge_ops
from ldm_patched.modules.ops import manual_cast
from ldm_patched.modules import model_management as model_management
import ldm_patched.modules.model_patcher
model_dir = "Stable-diffusion"
model_path = os.path.abspath(os.path.join(paths.models_path, model_dir))
@@ -150,9 +156,9 @@ def list_models():
if shared.cmd_opts.no_download_sd_model or cmd_ckpt != shared.sd_model_file or os.path.exists(cmd_ckpt):
model_url = None
else:
model_url = "https://huggingface.co/runwayml/stable-diffusion-v1-5/resolve/main/v1-5-pruned-emaonly.safetensors"
model_url = "https://huggingface.co/lllyasviel/fav_models/resolve/main/fav/realisticVisionV51_v51VAE.safetensors"
model_list = modelloader.load_models(model_path=model_path, model_url=model_url, command_path=shared.cmd_opts.ckpt_dir, ext_filter=[".ckpt", ".safetensors"], download_name="v1-5-pruned-emaonly.safetensors", ext_blacklist=[".vae.ckpt", ".vae.safetensors"])
model_list = modelloader.load_models(model_path=model_path, model_url=model_url, command_path=shared.cmd_opts.ckpt_dir, ext_filter=[".ckpt", ".safetensors"], download_name="realisticVisionV51_v51VAE.safetensors", ext_blacklist=[".vae.ckpt", ".vae.safetensors"])
if os.path.exists(cmd_ckpt):
checkpoint_info = CheckpointInfo(cmd_ckpt)
@@ -366,26 +372,12 @@ def load_model_weights(model, checkpoint_info: CheckpointInfo, state_dict, timer
sd_model_hash = checkpoint_info.calculate_shorthash()
timer.record("calculate hash")
if devices.fp8:
# prevent model to load state dict in fp8
model.half()
if not SkipWritingToConfig.skip:
shared.opts.data["sd_model_checkpoint"] = checkpoint_info.title
if state_dict is None:
state_dict = get_checkpoint_state_dict(checkpoint_info, timer)
model.is_sdxl = hasattr(model, 'conditioner')
model.is_sd2 = not model.is_sdxl and hasattr(model.cond_stage_model, 'model')
model.is_sd1 = not model.is_sdxl and not model.is_sd2
model.is_ssd = model.is_sdxl and 'model.diffusion_model.middle_block.1.transformer_blocks.0.attn1.to_q.weight' not in state_dict.keys()
if model.is_sdxl:
sd_models_xl.extend_sdxl(model)
if model.is_ssd:
sd_hijack.model_hijack.convert_sdxl_to_ssd(model)
if shared.opts.sd_checkpoint_cache > 0:
# cache newly loaded model
checkpoints_loaded[checkpoint_info] = state_dict.copy()
@@ -395,65 +387,6 @@ def load_model_weights(model, checkpoint_info: CheckpointInfo, state_dict, timer
del state_dict
if shared.cmd_opts.opt_channelslast:
model.to(memory_format=torch.channels_last)
timer.record("apply channels_last")
if shared.cmd_opts.no_half:
model.float()
model.alphas_cumprod_original = model.alphas_cumprod
devices.dtype_unet = torch.float32
timer.record("apply float()")
else:
vae = model.first_stage_model
depth_model = getattr(model, 'depth_model', None)
# with --no-half-vae, remove VAE from model when doing half() to prevent its weights from being converted to float16
if shared.cmd_opts.no_half_vae:
model.first_stage_model = None
# with --upcast-sampling, don't convert the depth model weights to float16
if shared.cmd_opts.upcast_sampling and depth_model:
model.depth_model = None
alphas_cumprod = model.alphas_cumprod
model.alphas_cumprod = None
model.half()
model.alphas_cumprod = alphas_cumprod
model.alphas_cumprod_original = alphas_cumprod
model.first_stage_model = vae
if depth_model:
model.depth_model = depth_model
devices.dtype_unet = torch.float16
timer.record("apply half()")
for module in model.modules():
if hasattr(module, 'fp16_weight'):
del module.fp16_weight
if hasattr(module, 'fp16_bias'):
del module.fp16_bias
if check_fp8(model):
devices.fp8 = True
first_stage = model.first_stage_model
model.first_stage_model = None
for module in model.modules():
if isinstance(module, (torch.nn.Conv2d, torch.nn.Linear)):
if shared.opts.cache_fp16_weight:
module.fp16_weight = module.weight.data.clone().cpu().half()
if module.bias is not None:
module.fp16_bias = module.bias.data.clone().cpu().half()
module.to(torch.float8_e4m3fn)
model.first_stage_model = first_stage
timer.record("apply fp8")
else:
devices.fp8 = False
devices.unet_needs_upcast = shared.cmd_opts.upcast_sampling and devices.dtype == torch.float16 and devices.dtype_unet == torch.float16
model.first_stage_model.to(devices.dtype_vae)
timer.record("apply dtype to VAE")
# clean up cache if limit is reached
while len(checkpoints_loaded) > shared.opts.sd_checkpoint_cache:
checkpoints_loaded.popitem(last=False)
@@ -590,14 +523,6 @@ class SdModelData:
sd_vae.loaded_vae_file = getattr(v, "loaded_vae_file", None)
sd_vae.checkpoint_info = v.sd_checkpoint_info
try:
self.loaded_sd_models.remove(v)
except ValueError:
pass
if v is not None:
self.loaded_sd_models.insert(0, v)
model_data = SdModelData()
@@ -615,31 +540,19 @@ def get_empty_cond(sd_model):
def send_model_to_cpu(m):
if m.lowvram:
lowvram.send_everything_to_cpu()
else:
m.to(devices.cpu)
devices.torch_gc()
pass
def model_target_device(m):
if lowvram.is_needed(m):
return devices.cpu
else:
return devices.device
return devices.device
def send_model_to_device(m):
lowvram.apply(m)
if not m.lowvram:
m.to(shared.device)
pass
def send_model_to_trash(m):
m.to(device="meta")
devices.torch_gc()
pass
def load_model(checkpoint_info=None, already_loaded_state_dict=None):
@@ -649,9 +562,14 @@ def load_model(checkpoint_info=None, already_loaded_state_dict=None):
timer = Timer()
if model_data.sd_model:
send_model_to_trash(model_data.sd_model)
if model_data.sd_model.filename == checkpoint_info.filename:
return model_data.sd_model
model_data.sd_model = None
devices.torch_gc()
model_data.loaded_sd_models = []
model_management.unload_all_models()
model_management.soft_empty_cache()
gc.collect()
timer.record("unload existing model")
@@ -660,58 +578,27 @@ def load_model(checkpoint_info=None, already_loaded_state_dict=None):
else:
state_dict = get_checkpoint_state_dict(checkpoint_info, timer)
checkpoint_config = sd_models_config.find_checkpoint_config(state_dict, checkpoint_info)
clip_is_included_into_sd = any(x for x in [sd1_clip_weight, sd2_clip_weight, sdxl_clip_weight, sdxl_refiner_clip_weight] if x in state_dict)
if shared.opts.sd_checkpoint_cache > 0:
# cache newly loaded model
checkpoints_loaded[checkpoint_info] = state_dict.copy()
timer.record("find config")
sd_model = forge_loader.load_model_for_a1111(timer=timer, checkpoint_info=checkpoint_info, state_dict=state_dict)
sd_model.filename = checkpoint_info.filename
sd_config = OmegaConf.load(checkpoint_config)
repair_config(sd_config)
del state_dict
timer.record("load config")
# clean up cache if limit is reached
while len(checkpoints_loaded) > shared.opts.sd_checkpoint_cache:
checkpoints_loaded.popitem(last=False)
print(f"Creating model from config: {checkpoint_config}")
shared.opts.data["sd_checkpoint_hash"] = checkpoint_info.sha256
sd_model = None
try:
with sd_disable_initialization.DisableInitialization(disable_clip=clip_is_included_into_sd or shared.cmd_opts.do_not_download_clip):
with sd_disable_initialization.InitializeOnMeta():
sd_model = instantiate_from_config(sd_config.model)
sd_vae.delete_base_vae()
sd_vae.clear_loaded_vae()
vae_file, vae_source = sd_vae.resolve_vae(checkpoint_info.filename).tuple()
sd_vae.load_vae(sd_model, vae_file, vae_source)
timer.record("load VAE")
except Exception as e:
errors.display(e, "creating model quickly", full_traceback=True)
if sd_model is None:
print('Failed to create model quickly; will retry using slow method.', file=sys.stderr)
with sd_disable_initialization.InitializeOnMeta():
sd_model = instantiate_from_config(sd_config.model)
sd_model.used_config = checkpoint_config
timer.record("create model")
if shared.cmd_opts.no_half:
weight_dtype_conversion = None
else:
weight_dtype_conversion = {
'first_stage_model': None,
'alphas_cumprod': None,
'': torch.float16,
}
with sd_disable_initialization.LoadStateDictOnMeta(state_dict, device=model_target_device(sd_model), weight_dtype_conversion=weight_dtype_conversion):
load_model_weights(sd_model, checkpoint_info, state_dict, timer)
timer.record("load weights from state dict")
send_model_to_device(sd_model)
timer.record("move model to device")
sd_hijack.model_hijack.hijack(sd_model)
timer.record("hijack")
sd_model.eval()
model_data.set_sd_model(sd_model)
model_data.was_loaded_at_least_once = True
@@ -723,7 +610,7 @@ def load_model(checkpoint_info=None, already_loaded_state_dict=None):
timer.record("scripts callbacks")
with devices.autocast(), torch.no_grad():
with torch.no_grad():
sd_model.cond_stage_model_empty_prompt = get_empty_cond(sd_model)
timer.record("calculate empty prompt")
@@ -734,156 +621,20 @@ def load_model(checkpoint_info=None, already_loaded_state_dict=None):
def reuse_model_from_already_loaded(sd_model, checkpoint_info, timer):
"""
Checks if the desired checkpoint from checkpoint_info is not already loaded in model_data.loaded_sd_models.
If it is loaded, returns that (moving it to GPU if necessary, and moving the currently loadded model to CPU if necessary).
If not, returns the model that can be used to load weights from checkpoint_info's file.
If no such model exists, returns None.
Additionaly deletes loaded models that are over the limit set in settings (sd_checkpoints_limit).
"""
already_loaded = None
for i in reversed(range(len(model_data.loaded_sd_models))):
loaded_model = model_data.loaded_sd_models[i]
if loaded_model.sd_checkpoint_info.filename == checkpoint_info.filename:
already_loaded = loaded_model
continue
if len(model_data.loaded_sd_models) > shared.opts.sd_checkpoints_limit > 0:
print(f"Unloading model {len(model_data.loaded_sd_models)} over the limit of {shared.opts.sd_checkpoints_limit}: {loaded_model.sd_checkpoint_info.title}")
model_data.loaded_sd_models.pop()
send_model_to_trash(loaded_model)
timer.record("send model to trash")
if shared.opts.sd_checkpoints_keep_in_cpu:
send_model_to_cpu(sd_model)
timer.record("send model to cpu")
if already_loaded is not None:
send_model_to_device(already_loaded)
timer.record("send model to device")
model_data.set_sd_model(already_loaded, already_loaded=True)
if not SkipWritingToConfig.skip:
shared.opts.data["sd_model_checkpoint"] = already_loaded.sd_checkpoint_info.title
shared.opts.data["sd_checkpoint_hash"] = already_loaded.sd_checkpoint_info.sha256
print(f"Using already loaded model {already_loaded.sd_checkpoint_info.title}: done in {timer.summary()}")
sd_vae.reload_vae_weights(already_loaded)
return model_data.sd_model
elif shared.opts.sd_checkpoints_limit > 1 and len(model_data.loaded_sd_models) < shared.opts.sd_checkpoints_limit:
print(f"Loading model {checkpoint_info.title} ({len(model_data.loaded_sd_models) + 1} out of {shared.opts.sd_checkpoints_limit})")
model_data.sd_model = None
load_model(checkpoint_info)
return model_data.sd_model
elif len(model_data.loaded_sd_models) > 0:
sd_model = model_data.loaded_sd_models.pop()
model_data.sd_model = sd_model
sd_vae.base_vae = getattr(sd_model, "base_vae", None)
sd_vae.loaded_vae_file = getattr(sd_model, "loaded_vae_file", None)
sd_vae.checkpoint_info = sd_model.sd_checkpoint_info
print(f"Reusing loaded model {sd_model.sd_checkpoint_info.title} to load {checkpoint_info.title}")
return sd_model
else:
return None
pass
def reload_model_weights(sd_model=None, info=None, forced_reload=False):
checkpoint_info = info or select_checkpoint()
timer = Timer()
if not sd_model:
sd_model = model_data.sd_model
if sd_model is None: # previous model load failed
current_checkpoint_info = None
else:
current_checkpoint_info = sd_model.sd_checkpoint_info
if check_fp8(sd_model) != devices.fp8:
# load from state dict again to prevent extra numerical errors
forced_reload = True
elif sd_model.sd_model_checkpoint == checkpoint_info.filename and not forced_reload:
return sd_model
sd_model = reuse_model_from_already_loaded(sd_model, checkpoint_info, timer)
if not forced_reload and sd_model is not None and sd_model.sd_checkpoint_info.filename == checkpoint_info.filename:
return sd_model
if sd_model is not None:
sd_unet.apply_unet("None")
send_model_to_cpu(sd_model)
sd_hijack.model_hijack.undo_hijack(sd_model)
state_dict = get_checkpoint_state_dict(checkpoint_info, timer)
checkpoint_config = sd_models_config.find_checkpoint_config(state_dict, checkpoint_info)
timer.record("find config")
if sd_model is None or checkpoint_config != sd_model.used_config:
if sd_model is not None:
send_model_to_trash(sd_model)
load_model(checkpoint_info, already_loaded_state_dict=state_dict)
return model_data.sd_model
try:
load_model_weights(sd_model, checkpoint_info, state_dict, timer)
except Exception:
print("Failed to load checkpoint, restoring previous")
load_model_weights(sd_model, current_checkpoint_info, None, timer)
raise
finally:
sd_hijack.model_hijack.hijack(sd_model)
timer.record("hijack")
if not sd_model.lowvram:
sd_model.to(devices.device)
timer.record("move model to device")
script_callbacks.model_loaded_callback(sd_model)
timer.record("script callbacks")
print(f"Weights loaded in {timer.summary()}.")
model_data.set_sd_model(sd_model)
sd_unet.apply_unet()
return sd_model
return load_model(info)
def unload_model_weights(sd_model=None, info=None):
send_model_to_cpu(sd_model or shared.sd_model)
return sd_model
def apply_token_merging(sd_model, token_merging_ratio):
"""
Applies speed and memory optimizations from tomesd.
"""
current_token_merging_ratio = getattr(sd_model, 'applied_token_merged_ratio', 0)
if current_token_merging_ratio == token_merging_ratio:
return
if current_token_merging_ratio > 0:
tomesd.remove_patch(sd_model)
if token_merging_ratio > 0:
tomesd.apply_patch(
sd_model,
ratio=token_merging_ratio,
use_rand=False, # can cause issues with some samplers
merge_attn=True,
merge_crossattn=False,
merge_mlp=False
)
print('Token merging is under construction now and the setting will not take effect.')
sd_model.applied_token_merged_ratio = token_merging_ratio
# TODO: rework using new UNet patcher system
return

View File

@@ -8,8 +8,13 @@ import sgm.modules.diffusionmodules.discretizer
from modules import devices, shared, prompt_parser
from modules import torch_utils
import ldm_patched.modules.model_management as model_management
from modules_forge.forge_clip import move_clip_to_gpu
def get_learned_conditioning(self: sgm.models.diffusion.DiffusionEngine, batch: prompt_parser.SdConditioning | list[str]):
move_clip_to_gpu()
for embedder in self.conditioner.embedders:
embedder.ucg_rate = 0.0
@@ -18,7 +23,7 @@ def get_learned_conditioning(self: sgm.models.diffusion.DiffusionEngine, batch:
is_negative_prompt = getattr(batch, 'is_negative_prompt', False)
aesthetic_score = shared.opts.sdxl_refiner_low_aesthetic_score if is_negative_prompt else shared.opts.sdxl_refiner_high_aesthetic_score
devices_args = dict(device=devices.device, dtype=devices.dtype)
devices_args = dict(device=self.forge_objects.clip.patcher.current_device, dtype=model_management.text_encoder_dtype())
sdxl_conds = {
"txt": batch,
@@ -34,14 +39,11 @@ def get_learned_conditioning(self: sgm.models.diffusion.DiffusionEngine, batch:
return c
def apply_model(self: sgm.models.diffusion.DiffusionEngine, x, t, cond):
sd = self.model.state_dict()
diffusion_model_input = sd.get('diffusion_model.input_blocks.0.0.weight', None)
if diffusion_model_input is not None:
if diffusion_model_input.shape[1] == 9:
x = torch.cat([x] + cond['c_concat'], dim=1)
def apply_model(self: sgm.models.diffusion.DiffusionEngine, x, t, cond, *args, **kwargs):
if self.model.diffusion_model.in_channels == 9:
x = torch.cat([x] + cond['c_concat'], dim=1)
return self.model(x, t, cond)
return self.model(x, t, cond, *args, **kwargs)
def get_first_stage_encoding(self, x): # SDXL's encode_first_stage does everything so get_first_stage_encoding is just there for compatibility

View File

@@ -2,11 +2,13 @@ from modules import sd_samplers_kdiffusion, sd_samplers_timesteps, sd_samplers_l
# imports for functions that previously were here and are used by other modules
from modules.sd_samplers_common import samples_to_image_grid, sample_to_image # noqa: F401
from modules_forge import forge_alter_samplers
all_samplers = [
*sd_samplers_kdiffusion.samplers_data_k_diffusion,
*sd_samplers_timesteps.samplers_data_timesteps,
*sd_samplers_lcm.samplers_data_lcm,
*forge_alter_samplers.samplers_data_alter
]
all_samplers_map = {x.name: x for x in all_samplers}

View File

@@ -6,6 +6,7 @@ import modules.shared as shared
from modules.script_callbacks import CFGDenoiserParams, cfg_denoiser_callback
from modules.script_callbacks import CFGDenoisedParams, cfg_denoised_callback
from modules.script_callbacks import AfterCFGCallbackParams, cfg_after_cfg_callback
from modules_forge import forge_sampler
def catenate_conds(conds):
@@ -58,15 +59,16 @@ class CFGDenoiser(torch.nn.Module):
self.model_wrap = None
self.p = None
# NOTE: masking before denoising can cause the original latents to be oversmoothed
# as the original latents do not have noise
# Backward Compatibility
self.mask_before_denoising = False
self.classic_ddim_eps_estimation = False
@property
def inner_model(self):
raise NotImplementedError()
def combine_denoised(self, x_out, conds_list, uncond, cond_scale):
def combine_denoised(self, x_out, conds_list, uncond, cond_scale, timestep, x_in, cond):
denoised_uncond = x_out[-uncond.shape[0]:]
denoised = torch.clone(denoised_uncond)
@@ -152,142 +154,38 @@ class CFGDenoiser(torch.nn.Module):
if state.interrupted or state.skipped:
raise sd_samplers_common.InterruptedException
if sd_samplers_common.apply_refiner(self):
original_x_device = x.device
original_x_dtype = x.dtype
if self.classic_ddim_eps_estimation:
acd = self.inner_model.inner_model.alphas_cumprod
fake_sigmas = ((1 - acd) / acd) ** 0.5
real_sigma = fake_sigmas[sigma.round().long().clip(0, int(fake_sigmas.shape[0]))]
real_sigma_data = 1.0
x = x * (((real_sigma ** 2.0 + real_sigma_data ** 2.0) ** 0.5)[:, None, None, None])
sigma = real_sigma
if sd_samplers_common.apply_refiner(self, x):
cond = self.sampler.sampler_extra_args['cond']
uncond = self.sampler.sampler_extra_args['uncond']
# at self.image_cfg_scale == 1.0 produced results for edit model are the same as with normal sampling,
# so is_edit_model is set to False to support AND composition.
is_edit_model = shared.sd_model.cond_stage_key == "edit" and self.image_cfg_scale is not None and self.image_cfg_scale != 1.0
conds_list, tensor = prompt_parser.reconstruct_multicond_batch(cond, self.step)
cond_composition, cond = prompt_parser.reconstruct_multicond_batch(cond, self.step)
uncond = prompt_parser.reconstruct_cond_batch(uncond, self.step)
assert not is_edit_model or all(len(conds) == 1 for conds in conds_list), "AND is not supported for InstructPix2Pix checkpoint (unless using Image CFG scale = 1.0)"
if self.mask is not None:
noisy_initial_latent = self.init_latent + sigma[:, None, None, None] * torch.randn_like(self.init_latent).to(self.init_latent)
x = x * self.nmask + noisy_initial_latent * self.mask
# If we use masks, blending between the denoised and original latent images occurs here.
def apply_blend(current_latent):
blended_latent = current_latent * self.nmask + self.init_latent * self.mask
if self.p.scripts is not None:
from modules import scripts
mba = scripts.MaskBlendArgs(current_latent, self.nmask, self.init_latent, self.mask, blended_latent, denoiser=self, sigma=sigma)
self.p.scripts.on_mask_blend(self.p, mba)
blended_latent = mba.blended_latent
return blended_latent
# Blend in the original latents (before)
if self.mask_before_denoising and self.mask is not None:
x = apply_blend(x)
batch_size = len(conds_list)
repeats = [len(conds_list[i]) for i in range(batch_size)]
if shared.sd_model.model.conditioning_key == "crossattn-adm":
image_uncond = torch.zeros_like(image_cond)
make_condition_dict = lambda c_crossattn, c_adm: {"c_crossattn": [c_crossattn], "c_adm": c_adm}
else:
image_uncond = image_cond
if isinstance(uncond, dict):
make_condition_dict = lambda c_crossattn, c_concat: {**c_crossattn, "c_concat": [c_concat]}
else:
make_condition_dict = lambda c_crossattn, c_concat: {"c_crossattn": [c_crossattn], "c_concat": [c_concat]}
if not is_edit_model:
x_in = torch.cat([torch.stack([x[i] for _ in range(n)]) for i, n in enumerate(repeats)] + [x])
sigma_in = torch.cat([torch.stack([sigma[i] for _ in range(n)]) for i, n in enumerate(repeats)] + [sigma])
image_cond_in = torch.cat([torch.stack([image_cond[i] for _ in range(n)]) for i, n in enumerate(repeats)] + [image_uncond])
else:
x_in = torch.cat([torch.stack([x[i] for _ in range(n)]) for i, n in enumerate(repeats)] + [x] + [x])
sigma_in = torch.cat([torch.stack([sigma[i] for _ in range(n)]) for i, n in enumerate(repeats)] + [sigma] + [sigma])
image_cond_in = torch.cat([torch.stack([image_cond[i] for _ in range(n)]) for i, n in enumerate(repeats)] + [image_uncond] + [torch.zeros_like(self.init_latent)])
denoiser_params = CFGDenoiserParams(x_in, image_cond_in, sigma_in, state.sampling_step, state.sampling_steps, tensor, uncond, self)
denoiser_params = CFGDenoiserParams(x, image_cond, sigma, state.sampling_step, state.sampling_steps, cond, uncond, self)
cfg_denoiser_callback(denoiser_params)
x_in = denoiser_params.x
image_cond_in = denoiser_params.image_cond
sigma_in = denoiser_params.sigma
tensor = denoiser_params.text_cond
uncond = denoiser_params.text_uncond
skip_uncond = False
# alternating uncond allows for higher thresholds without the quality loss normally expected from raising it
if self.step % 2 and s_min_uncond > 0 and sigma[0] < s_min_uncond and not is_edit_model:
skip_uncond = True
x_in = x_in[:-batch_size]
sigma_in = sigma_in[:-batch_size]
denoised = forge_sampler.forge_sample(self, denoiser_params=denoiser_params,
cond_scale=cond_scale, cond_composition=cond_composition)
self.padded_cond_uncond = False
self.padded_cond_uncond_v0 = False
if shared.opts.pad_cond_uncond and tensor.shape[1] != uncond.shape[1]:
tensor, uncond = self.pad_cond_uncond(tensor, uncond)
elif shared.opts.pad_cond_uncond_v0 and tensor.shape[1] != uncond.shape[1]:
tensor, uncond = self.pad_cond_uncond_v0(tensor, uncond)
if tensor.shape[1] == uncond.shape[1] or skip_uncond:
if is_edit_model:
cond_in = catenate_conds([tensor, uncond, uncond])
elif skip_uncond:
cond_in = tensor
else:
cond_in = catenate_conds([tensor, uncond])
if shared.opts.batch_cond_uncond:
x_out = self.inner_model(x_in, sigma_in, cond=make_condition_dict(cond_in, image_cond_in))
else:
x_out = torch.zeros_like(x_in)
for batch_offset in range(0, x_out.shape[0], batch_size):
a = batch_offset
b = a + batch_size
x_out[a:b] = self.inner_model(x_in[a:b], sigma_in[a:b], cond=make_condition_dict(subscript_cond(cond_in, a, b), image_cond_in[a:b]))
else:
x_out = torch.zeros_like(x_in)
batch_size = batch_size*2 if shared.opts.batch_cond_uncond else batch_size
for batch_offset in range(0, tensor.shape[0], batch_size):
a = batch_offset
b = min(a + batch_size, tensor.shape[0])
if not is_edit_model:
c_crossattn = subscript_cond(tensor, a, b)
else:
c_crossattn = torch.cat([tensor[a:b]], uncond)
x_out[a:b] = self.inner_model(x_in[a:b], sigma_in[a:b], cond=make_condition_dict(c_crossattn, image_cond_in[a:b]))
if not skip_uncond:
x_out[-uncond.shape[0]:] = self.inner_model(x_in[-uncond.shape[0]:], sigma_in[-uncond.shape[0]:], cond=make_condition_dict(uncond, image_cond_in[-uncond.shape[0]:]))
denoised_image_indexes = [x[0][0] for x in conds_list]
if skip_uncond:
fake_uncond = torch.cat([x_out[i:i+1] for i in denoised_image_indexes])
x_out = torch.cat([x_out, fake_uncond]) # we skipped uncond denoising, so we put cond-denoised image to where the uncond-denoised image should be
denoised_params = CFGDenoisedParams(x_out, state.sampling_step, state.sampling_steps, self.inner_model)
cfg_denoised_callback(denoised_params)
devices.test_for_nans(x_out, "unet")
if is_edit_model:
denoised = self.combine_denoised_for_edit_model(x_out, cond_scale)
elif skip_uncond:
denoised = self.combine_denoised(x_out, conds_list, uncond, 1.0)
else:
denoised = self.combine_denoised(x_out, conds_list, uncond, cond_scale)
# Blend in the original latents (after)
if not self.mask_before_denoising and self.mask is not None:
denoised = apply_blend(denoised)
self.sampler.last_latent = self.get_pred_x0(torch.cat([x_in[i:i + 1] for i in denoised_image_indexes]), torch.cat([x_out[i:i + 1] for i in denoised_image_indexes]), sigma)
if opts.live_preview_content == "Prompt":
preview = self.sampler.last_latent
elif opts.live_preview_content == "Negative prompt":
preview = self.get_pred_x0(x_in[-uncond.shape[0]:], x_out[-uncond.shape[0]:], sigma)
else:
preview = self.get_pred_x0(torch.cat([x_in[i:i+1] for i in denoised_image_indexes]), torch.cat([denoised[i:i+1] for i in denoised_image_indexes]), sigma)
if self.mask is not None:
denoised = denoised * self.nmask + self.init_latent * self.mask
preview = self.sampler.last_latent = denoised
sd_samplers_common.store_latent(preview)
after_cfg_callback_params = AfterCFGCallbackParams(denoised, state.sampling_step, state.sampling_steps)
@@ -295,5 +193,10 @@ class CFGDenoiser(torch.nn.Module):
denoised = after_cfg_callback_params.x
self.step += 1
return denoised
if self.classic_ddim_eps_estimation:
eps = (x - denoised) / sigma[:, None, None, None]
return eps
return denoised.to(device=original_x_device, dtype=original_x_dtype)

View File

@@ -5,6 +5,8 @@ import torch
from PIL import Image
from modules import devices, images, sd_vae_approx, sd_samplers, sd_vae_taesd, shared, sd_models
from modules.shared import opts, state
from modules_forge.forge_sampler import sampling_prepare, sampling_cleanup
from modules import extra_networks
import k_diffusion.sampling
@@ -39,9 +41,7 @@ def samples_to_images_tensor(sample, approximation=None, model=None):
if approximation is None or (shared.state.interrupted and opts.live_preview_fast_interrupt):
approximation = approximation_indexes.get(opts.show_progress_type, 0)
from modules import lowvram
if approximation == 0 and lowvram.is_enabled(shared.sd_model) and not shared.opts.live_preview_allow_lowvram_full:
if approximation == 0:
approximation = 1
if approximation == 2:
@@ -54,8 +54,7 @@ def samples_to_images_tensor(sample, approximation=None, model=None):
else:
if model is None:
model = shared.sd_model
with devices.without_autocast(): # fixes an issue with unstable VAEs that are flaky even in fp32
x_sample = model.decode_first_stage(sample.to(model.first_stage_model.dtype))
x_sample = model.decode_first_stage(sample)
return x_sample
@@ -71,7 +70,6 @@ def single_sample_to_image(sample, approximation=None):
def decode_first_stage(model, x):
x = x.to(devices.dtype_vae)
approx_index = approximation_indexes.get(opts.sd_vae_decode_method, 0)
return samples_to_images_tensor(x, approx_index, model)
@@ -95,7 +93,6 @@ def images_tensor_to_samples(image, approximation=None, model=None):
else:
if model is None:
model = shared.sd_model
model.first_stage_model.to(devices.dtype_vae)
image = image.to(shared.device, dtype=devices.dtype_vae)
image = image * 2 - 1
@@ -155,7 +152,7 @@ def replace_torchsde_browinan():
replace_torchsde_browinan()
def apply_refiner(cfg_denoiser):
def apply_refiner(cfg_denoiser, x):
completed_ratio = cfg_denoiser.step / cfg_denoiser.total_steps
refiner_switch_at = cfg_denoiser.p.refiner_switch_at
refiner_checkpoint_info = cfg_denoiser.p.refiner_checkpoint_info
@@ -181,13 +178,18 @@ def apply_refiner(cfg_denoiser):
cfg_denoiser.p.extra_generation_params['Refiner'] = refiner_checkpoint_info.short_title
cfg_denoiser.p.extra_generation_params['Refiner switch at'] = refiner_switch_at
sampling_cleanup(sd_models.model_data.get_sd_model().forge_objects.unet)
with sd_models.SkipWritingToConfig():
sd_models.reload_model_weights(info=refiner_checkpoint_info)
devices.torch_gc()
if not cfg_denoiser.p.disable_extra_networks:
extra_networks.activate(cfg_denoiser.p, cfg_denoiser.p.extra_network_data)
cfg_denoiser.p.setup_conds()
cfg_denoiser.update_inner_model()
sampling_prepare(sd_models.model_data.get_sd_model().forge_objects.unet, x=x)
return True

View File

@@ -7,6 +7,8 @@ from modules.script_callbacks import ExtraNoiseParams, extra_noise_callback
from modules.shared import opts
import modules.shared as shared
from modules_forge.forge_sampler import sampling_prepare, sampling_cleanup
samplers_k_diffusion = [
('DPM++ 2M Karras', 'sample_dpmpp_2m', ['k_dpmpp_2m_ka'], {'scheduler': 'karras'}),
@@ -139,11 +141,18 @@ class KDiffusionSampler(sd_samplers_common.Sampler):
return sigmas
def sample_img2img(self, p, x, noise, conditioning, unconditional_conditioning, steps=None, image_conditioning=None):
unet_patcher = self.model_wrap.inner_model.forge_objects.unet
sampling_prepare(self.model_wrap.inner_model.forge_objects.unet, x=x)
self.model_wrap.log_sigmas = self.model_wrap.log_sigmas.to(x.device)
self.model_wrap.sigmas = self.model_wrap.sigmas.to(x.device)
steps, t_enc = sd_samplers_common.setup_img2img_steps(p, steps)
sigmas = self.get_sigmas(p, steps)
sigmas = self.get_sigmas(p, steps).to(x.device)
sigma_sched = sigmas[steps - t_enc - 1:]
x = x.to(noise)
xi = x + noise * sigma_sched[0]
if opts.img2img_extra_noise > 0:
@@ -189,12 +198,20 @@ class KDiffusionSampler(sd_samplers_common.Sampler):
self.add_infotext(p)
sampling_cleanup(unet_patcher)
return samples
def sample(self, p, x, conditioning, unconditional_conditioning, steps=None, image_conditioning=None):
unet_patcher = self.model_wrap.inner_model.forge_objects.unet
sampling_prepare(self.model_wrap.inner_model.forge_objects.unet, x=x)
self.model_wrap.log_sigmas = self.model_wrap.log_sigmas.to(x.device)
self.model_wrap.sigmas = self.model_wrap.sigmas.to(x.device)
steps = steps or p.steps
sigmas = self.get_sigmas(p, steps)
sigmas = self.get_sigmas(p, steps).to(x.device)
if opts.sgm_noise_multiplier:
p.extra_generation_params["SGM noise multiplier"] = True
@@ -235,6 +252,8 @@ class KDiffusionSampler(sd_samplers_common.Sampler):
self.add_infotext(p)
sampling_cleanup(unet_patcher)
return samples

View File

@@ -27,14 +27,14 @@ class LCMCompVisDenoiser(DiscreteEpsDDPMDenoiser):
start = self.sigma_to_t(self.sigma_max)
end = self.sigma_to_t(self.sigma_min)
t = torch.linspace(start, end, n, device=shared.sd_model.device)
t = torch.linspace(start, end, n, device=self.sigmas.device)
return sampling.append_zero(self.t_to_sigma(t))
def sigma_to_t(self, sigma, quantize=None):
log_sigma = sigma.log()
dists = log_sigma - self.log_sigmas[:, None]
dists = log_sigma - self.log_sigmas.to(sigma)[:, None]
return dists.abs().argmin(dim=0).view(sigma.shape) * self.skip_steps + (self.skip_steps - 1)

View File

@@ -7,6 +7,8 @@ from modules.script_callbacks import ExtraNoiseParams, extra_noise_callback
from modules.shared import opts
import modules.shared as shared
from modules_forge.forge_sampler import sampling_prepare, sampling_cleanup
samplers_timesteps = [
('DDIM', sd_samplers_timesteps_impl.ddim, ['ddim'], {}),
@@ -50,10 +52,11 @@ class CFGDenoiserTimesteps(CFGDenoiser):
super().__init__(sampler)
self.alphas = shared.sd_model.alphas_cumprod
self.mask_before_denoising = True
self.classic_ddim_eps_estimation = True
def get_pred_x0(self, x_in, x_out, sigma):
ts = sigma.to(dtype=int)
self.alphas = self.alphas.to(ts.device)
a_t = self.alphas[ts][:, None, None, None]
sqrt_one_minus_at = (1 - a_t).sqrt()
@@ -95,16 +98,21 @@ class CompVisSampler(sd_samplers_common.Sampler):
return timesteps
def sample_img2img(self, p, x, noise, conditioning, unconditional_conditioning, steps=None, image_conditioning=None):
unet_patcher = self.model_wrap.inner_model.forge_objects.unet
sampling_prepare(self.model_wrap.inner_model.forge_objects.unet, x=x)
self.model_wrap.inner_model.alphas_cumprod = self.model_wrap.inner_model.alphas_cumprod.to(x.device)
steps, t_enc = sd_samplers_common.setup_img2img_steps(p, steps)
timesteps = self.get_timesteps(p, steps)
timesteps = self.get_timesteps(p, steps).to(x.device)
timesteps_sched = timesteps[:t_enc]
alphas_cumprod = shared.sd_model.alphas_cumprod
sqrt_alpha_cumprod = torch.sqrt(alphas_cumprod[timesteps[t_enc]])
sqrt_one_minus_alpha_cumprod = torch.sqrt(1 - alphas_cumprod[timesteps[t_enc]])
xi = x * sqrt_alpha_cumprod + noise * sqrt_one_minus_alpha_cumprod
xi = x.to(noise) * sqrt_alpha_cumprod + noise * sqrt_one_minus_alpha_cumprod
if opts.img2img_extra_noise > 0:
p.extra_generation_params["Extra noise"] = opts.img2img_extra_noise
@@ -135,11 +143,18 @@ class CompVisSampler(sd_samplers_common.Sampler):
self.add_infotext(p)
sampling_cleanup(unet_patcher)
return samples
def sample(self, p, x, conditioning, unconditional_conditioning, steps=None, image_conditioning=None):
unet_patcher = self.model_wrap.inner_model.forge_objects.unet
sampling_prepare(self.model_wrap.inner_model.forge_objects.unet, x=x)
self.model_wrap.inner_model.alphas_cumprod = self.model_wrap.inner_model.alphas_cumprod.to(x.device)
steps = steps or p.steps
timesteps = self.get_timesteps(p, steps)
timesteps = self.get_timesteps(p, steps).to(x.device)
extra_params_kwargs = self.initialize(p)
parameters = inspect.signature(self.func).parameters
@@ -159,6 +174,8 @@ class CompVisSampler(sd_samplers_common.Sampler):
self.add_infotext(p)
sampling_cleanup(unet_patcher)
return samples

View File

@@ -45,15 +45,8 @@ def apply_unet(option=None):
current_unet_option = new_option
if current_unet_option is None:
current_unet = None
if not shared.sd_model.lowvram:
shared.sd_model.model.diffusion_model.to(devices.device)
return
shared.sd_model.model.diffusion_model.to(devices.cpu)
devices.torch_gc()
current_unet = current_unet_option.create_unet()
current_unet.option = current_unet_option
print(f"Activating unet: {current_unet.option.label}")

View File

@@ -237,7 +237,6 @@ def load_vae(model, vae_file=None, vae_source="from unknown source"):
# don't call this from outside
def _load_vae_dict(model, vae_dict_1):
model.first_stage_model.load_state_dict(vae_dict_1)
model.first_stage_model.to(devices.dtype_vae)
def clear_loaded_vae():
@@ -263,20 +262,12 @@ def reload_vae_weights(sd_model=None, vae_file=unspecified):
if loaded_vae_file == vae_file:
return
if sd_model.lowvram:
lowvram.send_everything_to_cpu()
else:
sd_model.to(devices.cpu)
sd_hijack.model_hijack.undo_hijack(sd_model)
load_vae(sd_model, vae_file, vae_source)
sd_hijack.model_hijack.hijack(sd_model)
if not sd_model.lowvram:
sd_model.to(devices.device)
script_callbacks.model_loaded_callback(sd_model)
print("VAE weights loaded.")

View File

@@ -24,13 +24,6 @@ def initialize():
pass
from modules import devices
devices.device, devices.device_interrogate, devices.device_gfpgan, devices.device_esrgan, devices.device_codeformer = \
(devices.cpu if any(y in cmd_opts.use_cpu for y in [x, 'all']) else devices.get_optimal_device() for x in ['sd', 'interrogate', 'gfpgan', 'esrgan', 'codeformer'])
devices.dtype = torch.float32 if cmd_opts.no_half else torch.float16
devices.dtype_vae = torch.float32 if cmd_opts.no_half or cmd_opts.no_half_vae else torch.float16
devices.dtype_inference = torch.float32 if cmd_opts.precision == 'full' else devices.dtype
shared.device = devices.device
shared.weight_load_location = None if cmd_opts.lowram else "cpu"

View File

@@ -300,7 +300,7 @@ options_templates.update(options_section(('ui_alternatives', "UI alternatives",
options_templates.update(options_section(('ui', "User interface", "ui"), {
"localization": OptionInfo("None", "Localization", gr.Dropdown, lambda: {"choices": ["None"] + list(localization.localizations.keys())}, refresh=lambda: localization.list_localizations(cmd_opts.localizations_dir)).needs_reload_ui(),
"quicksettings_list": OptionInfo(["sd_model_checkpoint"], "Quicksettings list", ui_components.DropdownMulti, lambda: {"choices": list(shared.opts.data_labels.keys())}).js("info", "settingsHintsShowQuicksettings").info("setting entries that appear at the top of page rather than in settings tab").needs_reload_ui(),
"quicksettings_list": OptionInfo(["sd_model_checkpoint", "sd_vae", "CLIP_stop_at_last_layers"], "Quicksettings list", ui_components.DropdownMulti, lambda: {"choices": list(shared.opts.data_labels.keys())}).js("info", "settingsHintsShowQuicksettings").info("setting entries that appear at the top of page rather than in settings tab").needs_reload_ui(),
"ui_tab_order": OptionInfo([], "UI tab order", ui_components.DropdownMulti, lambda: {"choices": list(shared.tab_names)}).needs_reload_ui(),
"hidden_tabs": OptionInfo([], "Hidden UI tabs", ui_components.DropdownMulti, lambda: {"choices": list(shared.tab_names)}).needs_reload_ui(),
"ui_reorder_list": OptionInfo([], "UI item order for txt2img/img2img tabs", ui_components.DropdownMulti, lambda: {"choices": list(shared_items.ui_reorder_categories())}).info("selected items appear first").needs_reload_ui(),

View File

@@ -2,6 +2,8 @@ import datetime
import logging
import threading
import time
import traceback
import torch
from modules import errors, shared, devices
from typing import Optional
@@ -134,6 +136,7 @@ class State:
devices.torch_gc()
@torch.inference_mode()
def set_current_image(self):
"""if enough sampling steps have been made after the last call to this, sets self.current_image from self.current_latent, and modifies self.id_live_preview accordingly"""
if not shared.parallel_processing_allowed:
@@ -142,6 +145,7 @@ class State:
if self.sampling_step - self.current_image_sampling_step >= shared.opts.show_progress_every_n_steps and shared.opts.live_previews_enable and shared.opts.show_progress_every_n_steps != -1:
self.do_set_current_image()
@torch.inference_mode()
def do_set_current_image(self):
if self.current_latent is None:
return
@@ -156,11 +160,14 @@ class State:
self.current_image_sampling_step = self.sampling_step
except Exception:
except Exception as e:
# traceback.print_exc()
# print(e)
# when switching models during genration, VAE would be on CPU, so creating an image will fail.
# we silently ignore this error
errors.record_exception()
@torch.inference_mode()
def assign_current_image(self, image):
self.current_image = image
self.id_live_preview += 1

View File

@@ -9,6 +9,7 @@ import modules.shared as shared
from modules.ui import plaintext_to_html
from PIL import Image
import gradio as gr
from modules_forge import main_thread
def txt2img_create_processing(id_task: str, request: gr.Request, prompt: str, negative_prompt: str, prompt_styles, steps: int, sampler_name: str, n_iter: int, batch_size: int, cfg_scale: float, height: int, width: int, enable_hr: bool, denoising_strength: float, hr_scale: float, hr_upscaler: str, hr_second_pass_steps: int, hr_resize_x: int, hr_resize_y: int, hr_checkpoint_name: str, hr_sampler_name: str, hr_prompt: str, hr_negative_prompt, override_settings_texts, *args, force_enable_hr=False):
@@ -56,7 +57,7 @@ def txt2img_create_processing(id_task: str, request: gr.Request, prompt: str, ne
return p
def txt2img_upscale(id_task: str, request: gr.Request, gallery, gallery_index, generation_info, *args):
def txt2img_upscale_function(id_task: str, request: gr.Request, gallery, gallery_index, generation_info, *args):
assert len(gallery) > 0, 'No image to upscale'
assert 0 <= gallery_index < len(gallery), f'Bad image index: {gallery_index}'
@@ -100,7 +101,7 @@ def txt2img_upscale(id_task: str, request: gr.Request, gallery, gallery_index, g
return new_gallery, json.dumps(geninfo), plaintext_to_html(processed.info), plaintext_to_html(processed.comments, classname="comments")
def txt2img(id_task: str, request: gr.Request, *args):
def txt2img_function(id_task: str, request: gr.Request, *args):
p = txt2img_create_processing(id_task, request, *args)
with closing(p):
@@ -118,4 +119,12 @@ def txt2img(id_task: str, request: gr.Request, *args):
if opts.do_not_show_images:
processed.images = []
return processed.images, generation_info_js, plaintext_to_html(processed.info), plaintext_to_html(processed.comments, classname="comments")
return processed.images + processed.extra_images, generation_info_js, plaintext_to_html(processed.info), plaintext_to_html(processed.comments, classname="comments")
def txt2img_upscale(id_task: str, request: gr.Request, gallery, gallery_index, generation_info, *args):
return main_thread.run_and_wait_result(txt2img_upscale_function, id_task, request, gallery, gallery_index, generation_info, *args)
def txt2img(id_task: str, request: gr.Request, *args):
return main_thread.run_and_wait_result(txt2img_function, id_task, request, *args)

View File

@@ -178,9 +178,15 @@ def update_token_counter(text, steps, styles, *, is_positive=True):
# messages related to it in console
prompt_schedules = [[[steps, text]]]
try:
cond_stage_model = sd_models.model_data.sd_model.cond_stage_model
assert cond_stage_model is not None
except Exception:
return f"<span class='gr-box gr-text-input'>?/?</span>"
flat_prompts = reduce(lambda list1, list2: list1+list2, prompt_schedules)
prompts = [prompt_text for step, prompt_text in flat_prompts]
token_count, max_length = max([model_hijack.get_prompt_lengths(prompt) for prompt in prompts], key=lambda args: args[0])
token_count, max_length = max([model_hijack.get_prompt_lengths(prompt, cond_stage_model) for prompt in prompts], key=lambda args: args[0])
return f"<span class='gr-box gr-text-input'>{token_count}/{max_length}</span>"

View File

@@ -194,7 +194,7 @@ Requested path was: {f}
with gr.Column(variant='panel', elem_id=f"{tabname}_results_panel"):
with gr.Group(elem_id=f"{tabname}_gallery_container"):
res.gallery = gr.Gallery(label='Output', show_label=False, elem_id=f"{tabname}_gallery", columns=4, preview=True, height=shared.opts.gallery_height or None)
res.gallery = gr.Gallery(label='Output', show_label=False, elem_id=f"{tabname}_gallery", columns=4, preview=True, height=shared.opts.gallery_height or None, object_fit='contain')
with gr.Row(elem_id=f"image_buttons_{tabname}", elem_classes="image-buttons"):
open_folder_button = ToolButton(folder_symbol, elem_id=f'{tabname}_open_folder', visible=not shared.cmd_opts.hide_ui_dir_config, tooltip="Open images output directory.")
@@ -206,7 +206,8 @@ Requested path was: {f}
buttons = {
'img2img': ToolButton('🖼️', elem_id=f'{tabname}_send_to_img2img', tooltip="Send image and generation parameters to img2img tab."),
'inpaint': ToolButton('🎨️', elem_id=f'{tabname}_send_to_inpaint', tooltip="Send image and generation parameters to img2img inpaint tab."),
'extras': ToolButton('📐', elem_id=f'{tabname}_send_to_extras', tooltip="Send image and generation parameters to extras tab.")
'extras': ToolButton('📐', elem_id=f'{tabname}_send_to_extras', tooltip="Send image and generation parameters to extras tab."),
'svd': ToolButton('🎬', elem_id=f'{tabname}_send_to_svd', tooltip="Send image and generation parameters to SVD tab."),
}
if tabname == 'txt2img':

View File

@@ -294,7 +294,6 @@ class UiSettings:
for _i, k, _item in self.quicksettings_list:
component = self.component_dict[k]
info = opts.data_labels[k]
if isinstance(component, gr.Textbox):
methods = [component.submit, component.blur]
@@ -304,20 +303,30 @@ class UiSettings:
methods = [component.change]
for method in methods:
method(
handler = method(
fn=lambda value, k=k: self.run_settings_single(value, key=k),
inputs=[component],
outputs=[component, self.text_settings],
show_progress=info.refresh is not None,
show_progress=False,
)
script_callbacks.setting_updated_event_subscriber_chain(
handler=handler,
component=component,
setting_name=k,
)
button_set_checkpoint = gr.Button('Change checkpoint', elem_id='change_checkpoint', visible=False)
button_set_checkpoint.click(
handler = button_set_checkpoint.click(
fn=lambda value, _: self.run_settings_single(value, key='sd_model_checkpoint'),
_js="function(v){ var res = desiredCheckpointName; desiredCheckpointName = ''; return [res || v, null]; }",
inputs=[self.component_dict['sd_model_checkpoint'], self.dummy_component],
outputs=[self.component_dict['sd_model_checkpoint'], self.text_settings],
)
script_callbacks.setting_updated_event_subscriber_chain(
handler=handler,
component=self.component_dict['sd_model_checkpoint'],
setting_name="sd_model_checkpoint"
)
component_keys = [k for k in opts.data_labels.keys() if k in self.component_dict]

View File

@@ -6,6 +6,18 @@ from PIL import Image
import modules.shared
from modules import modelloader, shared
from ldm_patched.modules import model_management
def prepare_free_memory(aggressive=False):
if aggressive:
model_management.unload_all_models()
print('Upscale script freed all memory.')
return
model_management.free_memory(memory_required=1024*1024*3, device=model_management.get_torch_device())
print('Upscale script freed memory successfully.')
LANCZOS = (Image.Resampling.LANCZOS if hasattr(Image, 'Resampling') else Image.LANCZOS)
NEAREST = (Image.Resampling.NEAREST if hasattr(Image, 'Resampling') else Image.NEAREST)