mirror of
https://github.com/lllyasviel/stable-diffusion-webui-forge.git
synced 2026-04-30 19:21:21 +00:00
Add separate cuda stream for live preview VAE (#2844)
This commit is contained in:
@@ -70,9 +70,12 @@ def samples_to_images_tensor(sample, approximation=None, model=None):
|
|||||||
def single_sample_to_image(sample, approximation=None):
|
def single_sample_to_image(sample, approximation=None):
|
||||||
x_sample = samples_to_images_tensor(sample.unsqueeze(0), approximation)[0] * 0.5 + 0.5
|
x_sample = samples_to_images_tensor(sample.unsqueeze(0), approximation)[0] * 0.5 + 0.5
|
||||||
|
|
||||||
x_sample = torch.clamp(x_sample, min=0.0, max=1.0)
|
x_sample = x_sample.cpu()
|
||||||
x_sample = 255. * np.moveaxis(x_sample.cpu().numpy(), 0, 2)
|
x_sample.clamp_(0.0, 1.0)
|
||||||
x_sample = x_sample.astype(np.uint8)
|
x_sample.mul_(255.)
|
||||||
|
x_sample.round_()
|
||||||
|
x_sample = x_sample.to(torch.uint8)
|
||||||
|
x_sample = np.moveaxis(x_sample.numpy(), 0, 2)
|
||||||
|
|
||||||
return Image.fromarray(x_sample)
|
return Image.fromarray(x_sample)
|
||||||
|
|
||||||
|
|||||||
@@ -4,8 +4,10 @@ import threading
|
|||||||
import time
|
import time
|
||||||
import traceback
|
import traceback
|
||||||
import torch
|
import torch
|
||||||
|
from contextlib import nullcontext
|
||||||
|
|
||||||
from modules import errors, shared, devices
|
from modules import errors, shared, devices
|
||||||
|
from backend.args import args
|
||||||
from typing import Optional
|
from typing import Optional
|
||||||
|
|
||||||
log = logging.getLogger(__name__)
|
log = logging.getLogger(__name__)
|
||||||
@@ -34,6 +36,10 @@ class State:
|
|||||||
|
|
||||||
def __init__(self):
|
def __init__(self):
|
||||||
self.server_start = time.time()
|
self.server_start = time.time()
|
||||||
|
if args.cuda_stream:
|
||||||
|
self.vae_stream = torch.cuda.Stream()
|
||||||
|
else:
|
||||||
|
self.vae_stream = None
|
||||||
|
|
||||||
@property
|
@property
|
||||||
def need_restart(self) -> bool:
|
def need_restart(self) -> bool:
|
||||||
@@ -153,12 +159,18 @@ class State:
|
|||||||
import modules.sd_samplers
|
import modules.sd_samplers
|
||||||
|
|
||||||
try:
|
try:
|
||||||
if shared.opts.show_progress_grid:
|
if self.vae_stream is not None:
|
||||||
self.assign_current_image(modules.sd_samplers.samples_to_image_grid(self.current_latent))
|
# not waiting on default stream will result in corrupt results
|
||||||
|
# will not block main stream under any circumstances
|
||||||
|
self.vae_stream.wait_stream(torch.cuda.default_stream())
|
||||||
|
vae_context = torch.cuda.stream(self.vae_stream)
|
||||||
else:
|
else:
|
||||||
self.assign_current_image(modules.sd_samplers.sample_to_image(self.current_latent))
|
vae_context = nullcontext()
|
||||||
|
with vae_context:
|
||||||
self.current_image_sampling_step = self.sampling_step
|
if shared.opts.show_progress_grid:
|
||||||
|
self.assign_current_image(modules.sd_samplers.samples_to_image_grid(self.current_latent))
|
||||||
|
else:
|
||||||
|
self.assign_current_image(modules.sd_samplers.sample_to_image(self.current_latent))
|
||||||
|
|
||||||
except Exception as e:
|
except Exception as e:
|
||||||
# traceback.print_exc()
|
# traceback.print_exc()
|
||||||
|
|||||||
Reference in New Issue
Block a user