Add separate cuda stream for live preview VAE (#2844)

This commit is contained in:
drhead
2025-05-01 13:16:54 -04:00
committed by GitHub
parent c055f2d43b
commit d3573962bd
2 changed files with 23 additions and 8 deletions

View File

@@ -70,9 +70,12 @@ def samples_to_images_tensor(sample, approximation=None, model=None):

 def single_sample_to_image(sample, approximation=None):
     x_sample = samples_to_images_tensor(sample.unsqueeze(0), approximation)[0] * 0.5 + 0.5
-    x_sample = torch.clamp(x_sample, min=0.0, max=1.0)
-    x_sample = 255. * np.moveaxis(x_sample.cpu().numpy(), 0, 2)
-    x_sample = x_sample.astype(np.uint8)
+    x_sample = x_sample.cpu()
+    x_sample.clamp_(0.0, 1.0)
+    x_sample.mul_(255.)
+    x_sample.round_()
+    x_sample = x_sample.to(torch.uint8)
+    x_sample = np.moveaxis(x_sample.numpy(), 0, 2)

     return Image.fromarray(x_sample)

View File

@@ -4,8 +4,10 @@ import threading
 import time
 import traceback
 import torch
+from contextlib import nullcontext

 from modules import errors, shared, devices
+from backend.args import args
 from typing import Optional

 log = logging.getLogger(__name__)
@@ -34,6 +36,10 @@ class State:

     def __init__(self):
         self.server_start = time.time()
+        if args.cuda_stream:
+            self.vae_stream = torch.cuda.Stream()
+        else:
+            self.vae_stream = None

     @property
     def need_restart(self) -> bool:
@@ -153,12 +159,18 @@ class State:
         import modules.sd_samplers

         try:
-            if shared.opts.show_progress_grid:
-                self.assign_current_image(modules.sd_samplers.samples_to_image_grid(self.current_latent))
-            else:
-                self.assign_current_image(modules.sd_samplers.sample_to_image(self.current_latent))
+            if self.vae_stream is not None:
+                # not waiting on default stream will result in corrupt results
+                # will not block main stream under any circumstances
+                self.vae_stream.wait_stream(torch.cuda.default_stream())
+                vae_context = torch.cuda.stream(self.vae_stream)
+            else:
+                vae_context = nullcontext()
+            with vae_context:
+                if shared.opts.show_progress_grid:
+                    self.assign_current_image(modules.sd_samplers.samples_to_image_grid(self.current_latent))
+                else:
+                    self.assign_current_image(modules.sd_samplers.sample_to_image(self.current_latent))

         except Exception as e:
             # traceback.print_exc()