mirror of
https://github.com/lllyasviel/stable-diffusion-webui-forge.git
synced 2026-02-03 22:58:58 +00:00
This moves all major gradio calls into the main thread rather than random gradio threads. This ensures that all torch.module.to() calls are performed in the main thread, to avoid GPU memory fragmentation as much as possible. In my tests, model moving is now 0.7 ~ 1.2 seconds faster, which means all 6GB/8GB VRAM users will get 0.7 ~ 1.2 seconds faster per image on SDXL.
174 lines
5.5 KiB
Python
174 lines
5.5 KiB
Python
import datetime
|
|
import logging
|
|
import threading
|
|
import time
|
|
import traceback
|
|
import torch
|
|
|
|
from modules import errors, shared, devices
|
|
from typing import Optional
|
|
|
|
# Module-level logger named after this module; used by State for request and job messages.
log = logging.getLogger(__name__)
|
|
|
|
|
|
class State:
    """Mutable record of generation progress plus a one-slot server-command mailbox.

    The class keeps:
      * request flags (``skipped``, ``interrupted``, ``stopping_generation``),
      * bookkeeping for the current job (name, counters, sampling step),
      * the latest latent and the live-preview image derived from it,
      * a pending "server command" string paired with a ``threading.Event``
        so another thread can block in ``wait_for_server_command``.
    """

    # Request flags; set by skip()/interrupt()/stop_generating(), cleared by begin().
    skipped = False
    interrupted = False
    stopping_generation = False

    # Bookkeeping for the job currently running (reset by begin()).
    job = ""
    job_no = 0
    job_count = 0
    processing_has_refined_job_count = False
    job_timestamp = '0'
    sampling_step = 0
    sampling_steps = 0

    # Live-preview state: latest latent, the image built from it, the sampling
    # step at which that image was built, and a counter bumped on every new image.
    current_latent = None
    current_image = None
    current_image_sampling_step = 0
    id_live_preview = 0

    textinfo = None
    time_start = None  # time.time() at the last begin(); None until begin() runs
    server_start = None  # time.time() at construction

    # NOTE: defined at class level, so this Event is shared by all instances.
    _server_command_signal = threading.Event()
    _server_command: Optional[str] = None

    def __init__(self):
        self.server_start = time.time()

    @property
    def need_restart(self) -> bool:
        """Compatibility getter: True when the pending server command is "restart"."""
        return self.server_command == "restart"

    @need_restart.setter
    def need_restart(self, value: bool) -> None:
        """Compatibility setter: a truthy value queues "restart"; a falsy one is a no-op."""
        if value:
            self.server_command = "restart"

    @property
    def server_command(self):
        """The pending server command, or None when nothing is queued."""
        return self._server_command

    @server_command.setter
    def server_command(self, value: Optional[str]) -> None:
        """
        Set the server command to `value` and signal that it's been set.
        """
        self._server_command = value
        self._server_command_signal.set()

    def wait_for_server_command(self, timeout: Optional[float] = None) -> Optional[str]:
        """
        Wait for server command to get set; return and clear the value and signal.

        Returns None if `timeout` seconds pass without the signal being set.
        """
        if not self._server_command_signal.wait(timeout):
            return None

        self._server_command_signal.clear()
        command = self._server_command
        self._server_command = None
        return command

    def request_restart(self) -> None:
        """Interrupt the current work and queue a "restart" server command."""
        self.interrupt()
        self.server_command = "restart"
        log.info("Received restart request")

    def skip(self):
        """Raise the `skipped` flag."""
        self.skipped = True
        log.info("Received skip request")

    def interrupt(self):
        """Raise the `interrupted` flag."""
        self.interrupted = True
        log.info("Received interrupt request")

    def stop_generating(self):
        """Raise the `stopping_generation` flag."""
        self.stopping_generation = True
        log.info("Received stop generating request")

    def nextjob(self):
        """Advance to the next job and reset the per-job sampling counters.

        When live previews are enabled and `show_progress_every_n_steps` is -1,
        the preview image is refreshed first.
        """
        if shared.opts.live_previews_enable and shared.opts.show_progress_every_n_steps == -1:
            self.do_set_current_image()

        self.job_no += 1
        self.sampling_step = 0
        self.current_image_sampling_step = 0

    def dict(self):
        """Return a plain-dict snapshot of the flags and job counters."""
        return {
            "skipped": self.skipped,
            "interrupted": self.interrupted,
            "stopping_generation": self.stopping_generation,
            "job": self.job,
            "job_count": self.job_count,
            "job_timestamp": self.job_timestamp,
            "job_no": self.job_no,
            "sampling_step": self.sampling_step,
            "sampling_steps": self.sampling_steps,
        }

    def begin(self, job: str = "(unknown)"):
        """Reset all per-job state, start the job timer and record the job name."""
        self.sampling_step = 0
        self.time_start = time.time()
        self.job_count = -1
        self.processing_has_refined_job_count = False
        self.job_no = 0
        self.job_timestamp = datetime.datetime.now().strftime("%Y%m%d%H%M%S")
        self.current_latent = None
        self.current_image = None
        self.current_image_sampling_step = 0
        self.id_live_preview = 0
        self.skipped = False
        self.interrupted = False
        self.stopping_generation = False
        self.textinfo = None
        self.job = job
        devices.torch_gc()
        log.info("Starting job %s", job)

    def _job_duration(self) -> float:
        """Return seconds elapsed since begin(), or 0.0 if begin() has not run yet."""
        if self.time_start is None:
            return 0.0
        return time.time() - self.time_start

    def end(self):
        """Log the job duration, clear the job name and count, and run torch GC.

        Safe to call even if begin() was never called (duration is reported as 0).
        """
        duration = self._job_duration()
        log.info("Ending job %s (%.2f seconds)", self.job, duration)
        self.job = ""
        self.job_count = 0

        devices.torch_gc()

    @torch.inference_mode()
    def set_current_image(self):
        """if enough sampling steps have been made after the last call to this, sets self.current_image from self.current_latent, and modifies self.id_live_preview accordingly"""
        if not shared.parallel_processing_allowed:
            return

        if self.sampling_step - self.current_image_sampling_step >= shared.opts.show_progress_every_n_steps and shared.opts.live_previews_enable and shared.opts.show_progress_every_n_steps != -1:
            self.do_set_current_image()

    @torch.inference_mode()
    def do_set_current_image(self):
        """Build the preview image from `current_latent` (a grid if configured) and store it.

        Does nothing when there is no latent. Any exception raised while
        building the image is recorded via `errors.record_exception()` and
        not propagated.
        """
        if self.current_latent is None:
            return

        # Imported here rather than at module top, as in the original code.
        import modules.sd_samplers

        try:
            if shared.opts.show_progress_grid:
                self.assign_current_image(modules.sd_samplers.samples_to_image_grid(self.current_latent))
            else:
                self.assign_current_image(modules.sd_samplers.sample_to_image(self.current_latent))

            self.current_image_sampling_step = self.sampling_step

        except Exception:
            # when switching models during generation, VAE would be on CPU, so creating an image will fail.
            # we record the error and carry on instead of raising
            errors.record_exception()

    @torch.inference_mode()
    def assign_current_image(self, image):
        """Store `image` as the current preview and bump the live-preview id."""
        self.current_image = image
        self.id_live_preview += 1
|