mirror of
https://github.com/ostris/ai-toolkit.git
synced 2026-01-26 16:39:47 +00:00
Added queing system to the UI
This commit is contained in:
@@ -124,6 +124,19 @@ class DiffusionTrainer(SDTrainer):
|
||||
|
||||
return _check_stop()
|
||||
|
||||
def should_return_to_queue(self):
|
||||
if not self.is_ui_trainer:
|
||||
return False
|
||||
def _check_return_to_queue():
|
||||
with self._db_connect() as conn:
|
||||
cursor = conn.cursor()
|
||||
cursor.execute(
|
||||
"SELECT return_to_queue FROM Job WHERE id = ?", (self.job_id,))
|
||||
return_to_queue = cursor.fetchone()
|
||||
return False if return_to_queue is None else return_to_queue[0] == 1
|
||||
|
||||
return _check_return_to_queue()
|
||||
|
||||
def maybe_stop(self):
|
||||
if not self.is_ui_trainer:
|
||||
return
|
||||
@@ -132,6 +145,11 @@ class DiffusionTrainer(SDTrainer):
|
||||
self._update_status("stopped", "Job stopped"))
|
||||
self.is_stopping = True
|
||||
raise Exception("Job stopped")
|
||||
if self.should_return_to_queue():
|
||||
self._run_async_operation(
|
||||
self._update_status("queued", "Job queued"))
|
||||
self.is_stopping = True
|
||||
raise Exception("Job returning to queue")
|
||||
|
||||
async def _update_key(self, key, value):
|
||||
if not self.accelerator.is_main_process:
|
||||
|
||||
@@ -115,6 +115,17 @@ class UITrainer(SDTrainer):
|
||||
return False if stop is None else stop[0] == 1
|
||||
|
||||
return _check_stop()
|
||||
|
||||
def should_return_to_queue(self):
|
||||
def _check_return_to_queue():
|
||||
with self._db_connect() as conn:
|
||||
cursor = conn.cursor()
|
||||
cursor.execute(
|
||||
"SELECT return_to_queue FROM Job WHERE id = ?", (self.job_id,))
|
||||
return_to_queue = cursor.fetchone()
|
||||
return False if return_to_queue is None else return_to_queue[0] == 1
|
||||
|
||||
return _check_return_to_queue()
|
||||
|
||||
def maybe_stop(self):
|
||||
if self.should_stop():
|
||||
@@ -122,6 +133,11 @@ class UITrainer(SDTrainer):
|
||||
self._update_status("stopped", "Job stopped"))
|
||||
self.is_stopping = True
|
||||
raise Exception("Job stopped")
|
||||
if self.should_return_to_queue():
|
||||
self._run_async_operation(
|
||||
self._update_status("queued", "Job queued"))
|
||||
self.is_stopping = True
|
||||
raise Exception("Job returning to queue")
|
||||
|
||||
async def _update_key(self, key, value):
|
||||
if not self.accelerator.is_main_process:
|
||||
|
||||
Reference in New Issue
Block a user