Add built-in extension "NeverOOM"

see also discussions
This commit is contained in:
lllyasviel
2024-02-24 19:09:06 -08:00
parent 50229a05c1
commit 437c348926
5 changed files with 68 additions and 6 deletions

View File

@@ -204,6 +204,9 @@ elif args.vae_in_fp32:
VAE_DTYPE = torch.float32
VAE_ALWAYS_TILED = False
if ENABLE_PYTORCH_ATTENTION:
torch.backends.cuda.enable_math_sdp(True)
torch.backends.cuda.enable_flash_sdp(True)

View File

@@ -208,7 +208,7 @@ class VAE:
steps = samples.shape[0] * ldm_patched.modules.utils.get_tiled_scale_steps(samples.shape[3], samples.shape[2], tile_x, tile_y, overlap)
steps += samples.shape[0] * ldm_patched.modules.utils.get_tiled_scale_steps(samples.shape[3], samples.shape[2], tile_x // 2, tile_y * 2, overlap)
steps += samples.shape[0] * ldm_patched.modules.utils.get_tiled_scale_steps(samples.shape[3], samples.shape[2], tile_x * 2, tile_y // 2, overlap)
pbar = ldm_patched.modules.utils.ProgressBar(steps)
pbar = ldm_patched.modules.utils.ProgressBar(steps, title='VAE tiled decode')
decode_fn = lambda a: (self.first_stage_model.decode(a.to(self.vae_dtype).to(self.device)) + 1.0).float()
output = torch.clamp((
@@ -222,7 +222,7 @@ class VAE:
steps = pixel_samples.shape[0] * ldm_patched.modules.utils.get_tiled_scale_steps(pixel_samples.shape[3], pixel_samples.shape[2], tile_x, tile_y, overlap)
steps += pixel_samples.shape[0] * ldm_patched.modules.utils.get_tiled_scale_steps(pixel_samples.shape[3], pixel_samples.shape[2], tile_x // 2, tile_y * 2, overlap)
steps += pixel_samples.shape[0] * ldm_patched.modules.utils.get_tiled_scale_steps(pixel_samples.shape[3], pixel_samples.shape[2], tile_x * 2, tile_y // 2, overlap)
pbar = ldm_patched.modules.utils.ProgressBar(steps)
pbar = ldm_patched.modules.utils.ProgressBar(steps, title='VAE tiled encode')
encode_fn = lambda a: self.first_stage_model.encode((2. * a - 1.).to(self.vae_dtype).to(self.device)).float()
samples = ldm_patched.modules.utils.tiled_scale(pixel_samples, encode_fn, tile_x, tile_y, overlap, upscale_amount = (1/self.downscale_ratio), out_channels=self.latent_channels, output_device=self.output_device, pbar=pbar)
@@ -232,6 +232,9 @@ class VAE:
return samples
def decode(self, samples_in):
if model_management.VAE_ALWAYS_TILED:
return self.decode_tiled(samples_in).to(self.output_device)
try:
memory_used = self.memory_used_decode(samples_in.shape, self.vae_dtype)
model_management.load_models_gpu([self.patcher], memory_required=memory_used)
@@ -256,6 +259,9 @@ class VAE:
return output.movedim(1,-1)
def encode(self, pixel_samples):
if model_management.VAE_ALWAYS_TILED:
return self.encode_tiled(pixel_samples)
pixel_samples = pixel_samples.movedim(-1,1)
try:
memory_used = self.memory_used_encode(pixel_samples.shape, self.vae_dtype)

View File

@@ -1,5 +1,5 @@
# Taken from https://github.com/comfyanonymous/ComfyUI
# This file is only for reference, and not used in the backend or runtime.
# 1st edit https://github.com/comfyanonymous/ComfyUI
# 2nd edit by Forge
import torch
@@ -9,6 +9,7 @@ import ldm_patched.modules.checkpoint_pickle
import safetensors.torch
import numpy as np
from PIL import Image
from tqdm import tqdm
def load_torch_file(ckpt, safe_load=False, device=None):
if device is None:
@@ -448,20 +449,25 @@ def set_progress_bar_global_hook(function):
PROGRESS_BAR_HOOK = function
class ProgressBar:
    """Progress reporter that mirrors updates to a tqdm bar and an optional
    globally-registered hook (PROGRESS_BAR_HOOK)."""

    def __init__(self, total, title=None):
        """Create a progress bar.

        total -- expected number of steps.
        title -- optional description shown by tqdm (passed as `desc`);
                 defaults to None for backward compatibility with callers
                 that pass only `total`.
        """
        global PROGRESS_BAR_HOOK
        self.total = total
        self.current = 0
        # Snapshot the hook at construction time; set via
        # set_progress_bar_global_hook elsewhere in this module.
        self.hook = PROGRESS_BAR_HOOK
        self.tqdm = tqdm(total=total, desc=title)

    def update_absolute(self, value, total=None, preview=None):
        """Set progress to an absolute `value`, optionally retargeting `total`.

        `preview` is forwarded untouched to the hook (semantics defined by
        the hook's consumer — presumably a UI preview image; not used here).
        """
        if total is not None:
            self.total = total
        # Clamp so we never report past 100%.
        if value > self.total:
            value = self.total
        # tqdm only accepts deltas, so convert the absolute value.
        inc = value - self.current
        self.tqdm.update(inc)
        self.current = value
        if self.hook is not None:
            self.hook(self.current, self.total, preview)
        # Close the tqdm bar once complete so the terminal line is released.
        if self.current >= self.total:
            self.tqdm.close()

    def update(self, value):
        """Advance progress by a relative amount `value`."""
        self.update_absolute(self.current + value)