initial commit
This commit is contained in:
189
modules/shared_state.py
Executable file
189
modules/shared_state.py
Executable file
@@ -0,0 +1,189 @@
|
||||
import datetime
|
||||
import logging
|
||||
import threading
|
||||
import time
|
||||
import traceback
|
||||
import torch
|
||||
from contextlib import nullcontext
|
||||
|
||||
from modules import errors, shared, devices
|
||||
from backend.args import args
|
||||
from typing import Optional
|
||||
|
||||
log = logging.getLogger(__name__)
|
||||
|
||||
|
||||
class State:
|
||||
skipped = False
|
||||
interrupted = False
|
||||
stopping_generation = False
|
||||
job = ""
|
||||
job_no = 0
|
||||
job_count = 0
|
||||
processing_has_refined_job_count = False
|
||||
job_timestamp = '0'
|
||||
sampling_step = 0
|
||||
sampling_steps = 0
|
||||
current_latent = None
|
||||
current_image = None
|
||||
current_image_sampling_step = 0
|
||||
id_live_preview = 0
|
||||
textinfo = None
|
||||
time_start = None
|
||||
server_start = None
|
||||
_server_command_signal = threading.Event()
|
||||
_server_command: Optional[str] = None
|
||||
|
||||
def __init__(self):
|
||||
self.server_start = time.time()
|
||||
if args.cuda_stream:
|
||||
self.vae_stream = torch.cuda.Stream()
|
||||
else:
|
||||
self.vae_stream = None
|
||||
|
||||
@property
|
||||
def need_restart(self) -> bool:
|
||||
# Compatibility getter for need_restart.
|
||||
return self.server_command == "restart"
|
||||
|
||||
@need_restart.setter
|
||||
def need_restart(self, value: bool) -> None:
|
||||
# Compatibility setter for need_restart.
|
||||
if value:
|
||||
self.server_command = "restart"
|
||||
|
||||
@property
|
||||
def server_command(self):
|
||||
return self._server_command
|
||||
|
||||
@server_command.setter
|
||||
def server_command(self, value: Optional[str]) -> None:
|
||||
"""
|
||||
Set the server command to `value` and signal that it's been set.
|
||||
"""
|
||||
self._server_command = value
|
||||
self._server_command_signal.set()
|
||||
|
||||
def wait_for_server_command(self, timeout: Optional[float] = None) -> Optional[str]:
|
||||
"""
|
||||
Wait for server command to get set; return and clear the value and signal.
|
||||
"""
|
||||
if self._server_command_signal.wait(timeout):
|
||||
self._server_command_signal.clear()
|
||||
req = self._server_command
|
||||
self._server_command = None
|
||||
return req
|
||||
return None
|
||||
|
||||
def request_restart(self) -> None:
|
||||
self.interrupt()
|
||||
self.server_command = "restart"
|
||||
log.info("Received restart request")
|
||||
|
||||
def skip(self):
|
||||
self.skipped = True
|
||||
log.info("Received skip request")
|
||||
|
||||
def interrupt(self):
|
||||
self.interrupted = True
|
||||
log.info("Received interrupt request")
|
||||
|
||||
def stop_generating(self):
|
||||
self.stopping_generation = True
|
||||
log.info("Received stop generating request")
|
||||
|
||||
def nextjob(self):
|
||||
if shared.opts.live_previews_enable and shared.opts.show_progress_every_n_steps == -1:
|
||||
self.do_set_current_image()
|
||||
|
||||
self.job_no += 1
|
||||
self.sampling_step = 0
|
||||
self.current_image_sampling_step = 0
|
||||
|
||||
def dict(self):
|
||||
obj = {
|
||||
"skipped": self.skipped,
|
||||
"interrupted": self.interrupted,
|
||||
"stopping_generation": self.stopping_generation,
|
||||
"job": self.job,
|
||||
"job_count": self.job_count,
|
||||
"job_timestamp": self.job_timestamp,
|
||||
"job_no": self.job_no,
|
||||
"sampling_step": self.sampling_step,
|
||||
"sampling_steps": self.sampling_steps,
|
||||
}
|
||||
|
||||
return obj
|
||||
|
||||
def begin(self, job: str = "(unknown)"):
|
||||
self.sampling_step = 0
|
||||
self.time_start = time.time()
|
||||
self.job_count = -1
|
||||
self.processing_has_refined_job_count = False
|
||||
self.job_no = 0
|
||||
self.job_timestamp = datetime.datetime.now().strftime("%Y%m%d%H%M%S")
|
||||
self.current_latent = None
|
||||
self.current_image = None
|
||||
self.current_image_sampling_step = 0
|
||||
self.id_live_preview = 0
|
||||
self.skipped = False
|
||||
self.interrupted = False
|
||||
self.stopping_generation = False
|
||||
self.textinfo = None
|
||||
self.job = job
|
||||
devices.torch_gc()
|
||||
log.info("Starting job %s", job)
|
||||
|
||||
def end(self):
|
||||
duration = time.time() - self.time_start
|
||||
log.info("Ending job %s (%.2f seconds)", self.job, duration)
|
||||
self.job = ""
|
||||
self.job_count = 0
|
||||
|
||||
devices.torch_gc()
|
||||
|
||||
@torch.inference_mode()
|
||||
def set_current_image(self):
|
||||
"""if enough sampling steps have been made after the last call to this, sets self.current_image from self.current_latent, and modifies self.id_live_preview accordingly"""
|
||||
if not shared.parallel_processing_allowed:
|
||||
return
|
||||
|
||||
if self.sampling_step - self.current_image_sampling_step >= shared.opts.show_progress_every_n_steps and shared.opts.live_previews_enable and shared.opts.show_progress_every_n_steps != -1:
|
||||
self.do_set_current_image()
|
||||
|
||||
@torch.inference_mode()
|
||||
def do_set_current_image(self):
|
||||
if self.current_latent is None:
|
||||
return
|
||||
|
||||
import modules.sd_samplers
|
||||
|
||||
try:
|
||||
if self.vae_stream is not None:
|
||||
# not waiting on default stream will result in corrupt results
|
||||
# will not block main stream under any circumstances
|
||||
self.vae_stream.wait_stream(torch.cuda.default_stream())
|
||||
vae_context = torch.cuda.stream(self.vae_stream)
|
||||
else:
|
||||
vae_context = nullcontext()
|
||||
with vae_context:
|
||||
if shared.opts.show_progress_grid:
|
||||
self.assign_current_image(modules.sd_samplers.samples_to_image_grid(self.current_latent))
|
||||
else:
|
||||
self.assign_current_image(modules.sd_samplers.sample_to_image(self.current_latent))
|
||||
|
||||
self.current_image_sampling_step = self.sampling_step
|
||||
|
||||
except Exception as e:
|
||||
# traceback.print_exc()
|
||||
# print(e)
|
||||
# when switching models during genration, VAE would be on CPU, so creating an image will fail.
|
||||
# we silently ignore this error
|
||||
errors.record_exception()
|
||||
|
||||
@torch.inference_mode()
|
||||
def assign_current_image(self, image):
|
||||
if shared.opts.live_previews_image_format == 'jpeg' and image.mode in ('RGBA', 'P'):
|
||||
image = image.convert('RGB')
|
||||
self.current_image = image
|
||||
self.id_live_preview += 1
|
||||
Reference in New Issue
Block a user