mirror of
https://github.com/lllyasviel/stable-diffusion-webui-forge.git
synced 2026-05-01 03:31:30 +00:00
Add build-in extension "NeverOOM"
see also discussions
This commit is contained in:
@@ -204,6 +204,9 @@ elif args.vae_in_fp32:
|
||||
VAE_DTYPE = torch.float32
|
||||
|
||||
|
||||
VAE_ALWAYS_TILED = False
|
||||
|
||||
|
||||
if ENABLE_PYTORCH_ATTENTION:
|
||||
torch.backends.cuda.enable_math_sdp(True)
|
||||
torch.backends.cuda.enable_flash_sdp(True)
|
||||
|
||||
@@ -208,7 +208,7 @@ class VAE:
|
||||
steps = samples.shape[0] * ldm_patched.modules.utils.get_tiled_scale_steps(samples.shape[3], samples.shape[2], tile_x, tile_y, overlap)
|
||||
steps += samples.shape[0] * ldm_patched.modules.utils.get_tiled_scale_steps(samples.shape[3], samples.shape[2], tile_x // 2, tile_y * 2, overlap)
|
||||
steps += samples.shape[0] * ldm_patched.modules.utils.get_tiled_scale_steps(samples.shape[3], samples.shape[2], tile_x * 2, tile_y // 2, overlap)
|
||||
pbar = ldm_patched.modules.utils.ProgressBar(steps)
|
||||
pbar = ldm_patched.modules.utils.ProgressBar(steps, title='VAE tiled decode')
|
||||
|
||||
decode_fn = lambda a: (self.first_stage_model.decode(a.to(self.vae_dtype).to(self.device)) + 1.0).float()
|
||||
output = torch.clamp((
|
||||
@@ -222,7 +222,7 @@ class VAE:
|
||||
steps = pixel_samples.shape[0] * ldm_patched.modules.utils.get_tiled_scale_steps(pixel_samples.shape[3], pixel_samples.shape[2], tile_x, tile_y, overlap)
|
||||
steps += pixel_samples.shape[0] * ldm_patched.modules.utils.get_tiled_scale_steps(pixel_samples.shape[3], pixel_samples.shape[2], tile_x // 2, tile_y * 2, overlap)
|
||||
steps += pixel_samples.shape[0] * ldm_patched.modules.utils.get_tiled_scale_steps(pixel_samples.shape[3], pixel_samples.shape[2], tile_x * 2, tile_y // 2, overlap)
|
||||
pbar = ldm_patched.modules.utils.ProgressBar(steps)
|
||||
pbar = ldm_patched.modules.utils.ProgressBar(steps, title='VAE tiled encode')
|
||||
|
||||
encode_fn = lambda a: self.first_stage_model.encode((2. * a - 1.).to(self.vae_dtype).to(self.device)).float()
|
||||
samples = ldm_patched.modules.utils.tiled_scale(pixel_samples, encode_fn, tile_x, tile_y, overlap, upscale_amount = (1/self.downscale_ratio), out_channels=self.latent_channels, output_device=self.output_device, pbar=pbar)
|
||||
@@ -232,6 +232,9 @@ class VAE:
|
||||
return samples
|
||||
|
||||
def decode(self, samples_in):
|
||||
if model_management.VAE_ALWAYS_TILED:
|
||||
return self.decode_tiled(samples_in).to(self.output_device)
|
||||
|
||||
try:
|
||||
memory_used = self.memory_used_decode(samples_in.shape, self.vae_dtype)
|
||||
model_management.load_models_gpu([self.patcher], memory_required=memory_used)
|
||||
@@ -256,6 +259,9 @@ class VAE:
|
||||
return output.movedim(1,-1)
|
||||
|
||||
def encode(self, pixel_samples):
|
||||
if model_management.VAE_ALWAYS_TILED:
|
||||
return self.encode_tiled(pixel_samples)
|
||||
|
||||
pixel_samples = pixel_samples.movedim(-1,1)
|
||||
try:
|
||||
memory_used = self.memory_used_encode(pixel_samples.shape, self.vae_dtype)
|
||||
|
||||
@@ -1,5 +1,5 @@
|
||||
# Taken from https://github.com/comfyanonymous/ComfyUI
|
||||
# This file is only for reference, and not used in the backend or runtime.
|
||||
# 1st edit https://github.com/comfyanonymous/ComfyUI
|
||||
# 2nd edit by Forge
|
||||
|
||||
|
||||
import torch
|
||||
@@ -9,6 +9,7 @@ import ldm_patched.modules.checkpoint_pickle
|
||||
import safetensors.torch
|
||||
import numpy as np
|
||||
from PIL import Image
|
||||
from tqdm import tqdm
|
||||
|
||||
def load_torch_file(ckpt, safe_load=False, device=None):
|
||||
if device is None:
|
||||
@@ -448,20 +449,25 @@ def set_progress_bar_global_hook(function):
|
||||
PROGRESS_BAR_HOOK = function
|
||||
|
||||
class ProgressBar:
|
||||
def __init__(self, total):
|
||||
def __init__(self, total, title=None):
|
||||
global PROGRESS_BAR_HOOK
|
||||
self.total = total
|
||||
self.current = 0
|
||||
self.hook = PROGRESS_BAR_HOOK
|
||||
self.tqdm = tqdm(total=total, desc=title)
|
||||
|
||||
def update_absolute(self, value, total=None, preview=None):
|
||||
if total is not None:
|
||||
self.total = total
|
||||
if value > self.total:
|
||||
value = self.total
|
||||
inc = value - self.current
|
||||
self.tqdm.update(inc)
|
||||
self.current = value
|
||||
if self.hook is not None:
|
||||
self.hook(self.current, self.total, preview)
|
||||
if self.current >= self.total:
|
||||
self.tqdm.close()
|
||||
|
||||
def update(self, value):
|
||||
self.update_absolute(self.current + value)
|
||||
|
||||
Reference in New Issue
Block a user