Add separate cuda stream for live preview VAE (#2844)

This commit is contained in:
drhead
2025-05-01 13:16:54 -04:00
committed by GitHub
parent c055f2d43b
commit d3573962bd
2 changed files with 23 additions and 8 deletions

View File

@@ -70,9 +70,12 @@ def samples_to_images_tensor(sample, approximation=None, model=None):
def single_sample_to_image(sample, approximation=None):
x_sample = samples_to_images_tensor(sample.unsqueeze(0), approximation)[0] * 0.5 + 0.5
x_sample = torch.clamp(x_sample, min=0.0, max=1.0)
x_sample = 255. * np.moveaxis(x_sample.cpu().numpy(), 0, 2)
x_sample = x_sample.astype(np.uint8)
x_sample = x_sample.cpu()
x_sample.clamp_(0.0, 1.0)
x_sample.mul_(255.)
x_sample.round_()
x_sample = x_sample.to(torch.uint8)
x_sample = np.moveaxis(x_sample.numpy(), 0, 2)
return Image.fromarray(x_sample)

View File

@@ -4,8 +4,10 @@ import threading
import time
import traceback
import torch
from contextlib import nullcontext
from modules import errors, shared, devices
from backend.args import args
from typing import Optional
log = logging.getLogger(__name__)
@@ -34,6 +36,10 @@ class State:
def __init__(self):
self.server_start = time.time()
if args.cuda_stream:
self.vae_stream = torch.cuda.Stream()
else:
self.vae_stream = None
@property
def need_restart(self) -> bool:
@@ -153,12 +159,18 @@ class State:
import modules.sd_samplers
try:
if shared.opts.show_progress_grid:
self.assign_current_image(modules.sd_samplers.samples_to_image_grid(self.current_latent))
if self.vae_stream is not None:
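# The VAE stream must first wait for the work already queued on the default stream;
# decoding without this wait will produce corrupt results.
# wait_stream() only delays the VAE stream - it never blocks the default (main) stream.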
self.vae_stream.wait_stream(torch.cuda.default_stream())
vae_context = torch.cuda.stream(self.vae_stream)
else:
self.assign_current_image(modules.sd_samplers.sample_to_image(self.current_latent))
self.current_image_sampling_step = self.sampling_step
vae_context = nullcontext()
with vae_context:
if shared.opts.show_progress_grid:
self.assign_current_image(modules.sd_samplers.samples_to_image_grid(self.current_latent))
else:
self.assign_current_image(modules.sd_samplers.sample_to_image(self.current_latent))
except Exception as e:
# traceback.print_exc()