diff --git a/modules/sd_samplers_common.py b/modules/sd_samplers_common.py index 03858cdb..065eb714 100644 --- a/modules/sd_samplers_common.py +++ b/modules/sd_samplers_common.py @@ -70,9 +70,12 @@ def samples_to_images_tensor(sample, approximation=None, model=None): def single_sample_to_image(sample, approximation=None): x_sample = samples_to_images_tensor(sample.unsqueeze(0), approximation)[0] * 0.5 + 0.5 - x_sample = torch.clamp(x_sample, min=0.0, max=1.0) - x_sample = 255. * np.moveaxis(x_sample.cpu().numpy(), 0, 2) - x_sample = x_sample.astype(np.uint8) + x_sample = x_sample.cpu() + x_sample.clamp_(0.0, 1.0) + x_sample.mul_(255.) + x_sample.round_() + x_sample = x_sample.to(torch.uint8) + x_sample = np.moveaxis(x_sample.numpy(), 0, 2) return Image.fromarray(x_sample) diff --git a/modules/shared_state.py b/modules/shared_state.py index bdbb4714..699cb0fc 100644 --- a/modules/shared_state.py +++ b/modules/shared_state.py @@ -4,8 +4,10 @@ import threading import time import traceback import torch +from contextlib import nullcontext from modules import errors, shared, devices +from backend.args import args from typing import Optional log = logging.getLogger(__name__) @@ -34,6 +36,10 @@ class State: def __init__(self): self.server_start = time.time() + if args.cuda_stream: + self.vae_stream = torch.cuda.Stream() + else: + self.vae_stream = None @property def need_restart(self) -> bool: @@ -153,12 +159,18 @@ class State: import modules.sd_samplers try: - if shared.opts.show_progress_grid: - self.assign_current_image(modules.sd_samplers.samples_to_image_grid(self.current_latent)) + if self.vae_stream is not None: + # not waiting on default stream will result in corrupt results + # will not block main stream under any circumstances + self.vae_stream.wait_stream(torch.cuda.default_stream()) + vae_context = torch.cuda.stream(self.vae_stream) else: - self.assign_current_image(modules.sd_samplers.sample_to_image(self.current_latent)) - - self.current_image_sampling_step = self.sampling_step + vae_context = nullcontext() + with vae_context: + if shared.opts.show_progress_grid: + self.assign_current_image(modules.sd_samplers.samples_to_image_grid(self.current_latent)) + else: + self.assign_current_image(modules.sd_samplers.sample_to_image(self.current_latent)) except Exception as e: # traceback.print_exc()