mirror of
https://github.com/lllyasviel/stable-diffusion-webui-forge.git
synced 2026-05-01 03:31:30 +00:00
Add built-in extension "NeverOOM"
see also discussions
This commit is contained in:
@@ -0,0 +1,47 @@
|
|||||||
|
import gradio as gr
|
||||||
|
|
||||||
|
from modules import scripts
|
||||||
|
from ldm_patched.modules import model_management
|
||||||
|
|
||||||
|
|
||||||
|
class NeverOOMForForge(scripts.Script):
    """Built-in Forge script that trades speed for memory to avoid OOM errors.

    Offers two independent toggles in the UI:
      * UNet: force the backend into the NO_VRAM state so model weights are
        always maximally offloaded from the GPU.
      * VAE: force all VAE encodes/decodes to use tiled processing.
    """

    # Controls where this script's accordion appears relative to other
    # always-visible scripts in the UI.
    sorting_priority = 18

    def __init__(self):
        # Tracks the toggle state from the previous generation so we only
        # unload models / flip the VRAM state when the setting changes.
        self.previous_unet_enabled = False
        # Remembered so the backend can be restored when the toggle is
        # switched off again.
        self.original_vram_state = model_management.vram_state

    def title(self):
        """Return the accordion label shown in the UI."""
        return "Never OOM Integrated"

    def show(self, is_img2img):
        """Show this script on both txt2img and img2img tabs."""
        return scripts.AlwaysVisible

    def ui(self, *args, **kwargs):
        """Build the two checkboxes and return them as this script's args."""
        with gr.Accordion(open=False, label=self.title()):
            unet_enabled = gr.Checkbox(label='Enabled for UNet (always maximize offload)', value=False)
            vae_enabled = gr.Checkbox(label='Enabled for VAE (always tiled)', value=False)
        return unet_enabled, vae_enabled

    def process(self, p, *script_args, **kwargs):
        """Apply the NeverOOM settings before processing begins.

        Args:
            p: the processing object (unused here).
            *script_args: the values of the two checkboxes from ui(), in order.
        """
        unet_enabled, vae_enabled = script_args

        if unet_enabled:
            print('NeverOOM Enabled for UNet (always maximize offload)')

        if vae_enabled:
            print('NeverOOM Enabled for VAE (always tiled)')

        # VAE tiling is a plain flag read by the VAE encode/decode paths.
        model_management.VAE_ALWAYS_TILED = vae_enabled

        # Only touch the VRAM state machine when the UNet toggle actually
        # changed; unloading all models is expensive.
        if self.previous_unet_enabled != unet_enabled:
            model_management.unload_all_models()
            if unet_enabled:
                self.original_vram_state = model_management.vram_state
                model_management.vram_state = model_management.VRAMState.NO_VRAM
            else:
                model_management.vram_state = self.original_vram_state
            # Fixed typo: was 'VARM State Changed To ...'
            print(f'VRAM State Changed To {model_management.vram_state.name}')
            self.previous_unet_enabled = unet_enabled

        return
|
||||||
@@ -204,6 +204,9 @@ elif args.vae_in_fp32:
|
|||||||
VAE_DTYPE = torch.float32
|
VAE_DTYPE = torch.float32
|
||||||
|
|
||||||
|
|
||||||
|
VAE_ALWAYS_TILED = False
|
||||||
|
|
||||||
|
|
||||||
if ENABLE_PYTORCH_ATTENTION:
|
if ENABLE_PYTORCH_ATTENTION:
|
||||||
torch.backends.cuda.enable_math_sdp(True)
|
torch.backends.cuda.enable_math_sdp(True)
|
||||||
torch.backends.cuda.enable_flash_sdp(True)
|
torch.backends.cuda.enable_flash_sdp(True)
|
||||||
|
|||||||
@@ -208,7 +208,7 @@ class VAE:
|
|||||||
steps = samples.shape[0] * ldm_patched.modules.utils.get_tiled_scale_steps(samples.shape[3], samples.shape[2], tile_x, tile_y, overlap)
|
steps = samples.shape[0] * ldm_patched.modules.utils.get_tiled_scale_steps(samples.shape[3], samples.shape[2], tile_x, tile_y, overlap)
|
||||||
steps += samples.shape[0] * ldm_patched.modules.utils.get_tiled_scale_steps(samples.shape[3], samples.shape[2], tile_x // 2, tile_y * 2, overlap)
|
steps += samples.shape[0] * ldm_patched.modules.utils.get_tiled_scale_steps(samples.shape[3], samples.shape[2], tile_x // 2, tile_y * 2, overlap)
|
||||||
steps += samples.shape[0] * ldm_patched.modules.utils.get_tiled_scale_steps(samples.shape[3], samples.shape[2], tile_x * 2, tile_y // 2, overlap)
|
steps += samples.shape[0] * ldm_patched.modules.utils.get_tiled_scale_steps(samples.shape[3], samples.shape[2], tile_x * 2, tile_y // 2, overlap)
|
||||||
pbar = ldm_patched.modules.utils.ProgressBar(steps)
|
pbar = ldm_patched.modules.utils.ProgressBar(steps, title='VAE tiled decode')
|
||||||
|
|
||||||
decode_fn = lambda a: (self.first_stage_model.decode(a.to(self.vae_dtype).to(self.device)) + 1.0).float()
|
decode_fn = lambda a: (self.first_stage_model.decode(a.to(self.vae_dtype).to(self.device)) + 1.0).float()
|
||||||
output = torch.clamp((
|
output = torch.clamp((
|
||||||
@@ -222,7 +222,7 @@ class VAE:
|
|||||||
steps = pixel_samples.shape[0] * ldm_patched.modules.utils.get_tiled_scale_steps(pixel_samples.shape[3], pixel_samples.shape[2], tile_x, tile_y, overlap)
|
steps = pixel_samples.shape[0] * ldm_patched.modules.utils.get_tiled_scale_steps(pixel_samples.shape[3], pixel_samples.shape[2], tile_x, tile_y, overlap)
|
||||||
steps += pixel_samples.shape[0] * ldm_patched.modules.utils.get_tiled_scale_steps(pixel_samples.shape[3], pixel_samples.shape[2], tile_x // 2, tile_y * 2, overlap)
|
steps += pixel_samples.shape[0] * ldm_patched.modules.utils.get_tiled_scale_steps(pixel_samples.shape[3], pixel_samples.shape[2], tile_x // 2, tile_y * 2, overlap)
|
||||||
steps += pixel_samples.shape[0] * ldm_patched.modules.utils.get_tiled_scale_steps(pixel_samples.shape[3], pixel_samples.shape[2], tile_x * 2, tile_y // 2, overlap)
|
steps += pixel_samples.shape[0] * ldm_patched.modules.utils.get_tiled_scale_steps(pixel_samples.shape[3], pixel_samples.shape[2], tile_x * 2, tile_y // 2, overlap)
|
||||||
pbar = ldm_patched.modules.utils.ProgressBar(steps)
|
pbar = ldm_patched.modules.utils.ProgressBar(steps, title='VAE tiled encode')
|
||||||
|
|
||||||
encode_fn = lambda a: self.first_stage_model.encode((2. * a - 1.).to(self.vae_dtype).to(self.device)).float()
|
encode_fn = lambda a: self.first_stage_model.encode((2. * a - 1.).to(self.vae_dtype).to(self.device)).float()
|
||||||
samples = ldm_patched.modules.utils.tiled_scale(pixel_samples, encode_fn, tile_x, tile_y, overlap, upscale_amount = (1/self.downscale_ratio), out_channels=self.latent_channels, output_device=self.output_device, pbar=pbar)
|
samples = ldm_patched.modules.utils.tiled_scale(pixel_samples, encode_fn, tile_x, tile_y, overlap, upscale_amount = (1/self.downscale_ratio), out_channels=self.latent_channels, output_device=self.output_device, pbar=pbar)
|
||||||
@@ -232,6 +232,9 @@ class VAE:
|
|||||||
return samples
|
return samples
|
||||||
|
|
||||||
def decode(self, samples_in):
|
def decode(self, samples_in):
|
||||||
|
if model_management.VAE_ALWAYS_TILED:
|
||||||
|
return self.decode_tiled(samples_in).to(self.output_device)
|
||||||
|
|
||||||
try:
|
try:
|
||||||
memory_used = self.memory_used_decode(samples_in.shape, self.vae_dtype)
|
memory_used = self.memory_used_decode(samples_in.shape, self.vae_dtype)
|
||||||
model_management.load_models_gpu([self.patcher], memory_required=memory_used)
|
model_management.load_models_gpu([self.patcher], memory_required=memory_used)
|
||||||
@@ -256,6 +259,9 @@ class VAE:
|
|||||||
return output.movedim(1,-1)
|
return output.movedim(1,-1)
|
||||||
|
|
||||||
def encode(self, pixel_samples):
|
def encode(self, pixel_samples):
|
||||||
|
if model_management.VAE_ALWAYS_TILED:
|
||||||
|
return self.encode_tiled(pixel_samples)
|
||||||
|
|
||||||
pixel_samples = pixel_samples.movedim(-1,1)
|
pixel_samples = pixel_samples.movedim(-1,1)
|
||||||
try:
|
try:
|
||||||
memory_used = self.memory_used_encode(pixel_samples.shape, self.vae_dtype)
|
memory_used = self.memory_used_encode(pixel_samples.shape, self.vae_dtype)
|
||||||
|
|||||||
@@ -1,5 +1,5 @@
|
|||||||
# Taken from https://github.com/comfyanonymous/ComfyUI
|
# 1st edit https://github.com/comfyanonymous/ComfyUI
|
||||||
# This file is only for reference, and not used in the backend or runtime.
|
# 2nd edit by Forge
|
||||||
|
|
||||||
|
|
||||||
import torch
|
import torch
|
||||||
@@ -9,6 +9,7 @@ import ldm_patched.modules.checkpoint_pickle
|
|||||||
import safetensors.torch
|
import safetensors.torch
|
||||||
import numpy as np
|
import numpy as np
|
||||||
from PIL import Image
|
from PIL import Image
|
||||||
|
from tqdm import tqdm
|
||||||
|
|
||||||
def load_torch_file(ckpt, safe_load=False, device=None):
|
def load_torch_file(ckpt, safe_load=False, device=None):
|
||||||
if device is None:
|
if device is None:
|
||||||
@@ -448,20 +449,25 @@ def set_progress_bar_global_hook(function):
|
|||||||
PROGRESS_BAR_HOOK = function
|
PROGRESS_BAR_HOOK = function
|
||||||
|
|
||||||
class ProgressBar:
    """Progress reporter that mirrors updates to a tqdm bar and an optional
    globally-registered hook (PROGRESS_BAR_HOOK).
    """

    def __init__(self, total, title=None):
        """Create a bar expecting `total` steps; `title` becomes the tqdm desc.

        Note: PROGRESS_BAR_HOOK is only read here, so the previous
        `global PROGRESS_BAR_HOOK` declaration was unnecessary and is removed.
        """
        self.total = total
        self.current = 0
        self.hook = PROGRESS_BAR_HOOK
        self.tqdm = tqdm(total=total, desc=title)

    def update_absolute(self, value, total=None, preview=None):
        """Set progress to an absolute `value`, clamped to the total.

        Args:
            value: new absolute progress position.
            total: if given, replaces the expected total first.
            preview: optional payload forwarded unchanged to the hook.
        """
        if total is not None:
            self.total = total
        if value > self.total:
            value = self.total
        # tqdm only accepts increments, so convert the absolute position.
        inc = value - self.current
        self.tqdm.update(inc)
        self.current = value
        if self.hook is not None:
            self.hook(self.current, self.total, preview)
        # Close the bar as soon as it completes so the terminal line is freed.
        if self.current >= self.total:
            self.tqdm.close()

    def update(self, value):
        """Advance progress by `value` steps relative to the current position."""
        self.update_absolute(self.current + value)
|
||||||
|
|||||||
@@ -1 +1 @@
|
|||||||
version = '0.0.15v1.8.0rc'
|
version = '0.0.16v1.8.0rc'
|
||||||
|
|||||||
Reference in New Issue
Block a user