mirror of
https://github.com/ostris/ai-toolkit.git
synced 2026-04-29 18:51:37 +00:00
Set step to the last step saved at when exiting
This commit is contained in:
@@ -20,11 +20,13 @@ class UITrainer(SDTrainer):
|
||||
self.is_stopping = False
|
||||
# Create a thread pool for database operations
|
||||
self.thread_pool = concurrent.futures.ThreadPoolExecutor(max_workers=1)
|
||||
# Track all async tasks
|
||||
self._async_tasks = []
|
||||
# Initialize the status
|
||||
self._run_async_operation(self._update_status("running", "Starting"))
|
||||
|
||||
def _run_async_operation(self, coro):
|
||||
"""Helper method to run an async coroutine in a new event loop."""
|
||||
"""Helper method to run an async coroutine and track the task."""
|
||||
try:
|
||||
loop = asyncio.get_event_loop()
|
||||
except RuntimeError:
|
||||
@@ -32,14 +34,14 @@ class UITrainer(SDTrainer):
|
||||
loop = asyncio.new_event_loop()
|
||||
asyncio.set_event_loop(loop)
|
||||
|
||||
# Run the coroutine in the event loop
|
||||
# Create a task and track it
|
||||
if loop.is_running():
|
||||
# If we're already in an event loop, create a future
|
||||
future = asyncio.run_coroutine_threadsafe(coro, loop)
|
||||
# We could wait for the result if needed: future.result()
|
||||
task = asyncio.run_coroutine_threadsafe(coro, loop)
|
||||
self._async_tasks.append(asyncio.wrap_future(task))
|
||||
else:
|
||||
# If no loop is running, run the coroutine and close the loop
|
||||
loop.run_until_complete(coro)
|
||||
task = loop.create_task(coro)
|
||||
self._async_tasks.append(task)
|
||||
loop.run_until_complete(task)
|
||||
|
||||
async def _execute_db_operation(self, operation_func):
|
||||
"""Execute a database operation in a separate thread to avoid blocking."""
|
||||
@@ -61,7 +63,6 @@ class UITrainer(SDTrainer):
|
||||
stop = cursor.fetchone()
|
||||
return False if stop is None else stop[0] == 1
|
||||
|
||||
# For this one we need a synchronous result, so we'll run it directly
|
||||
return _check_stop()
|
||||
|
||||
def maybe_stop(self):
|
||||
@@ -71,29 +72,36 @@ class UITrainer(SDTrainer):
|
||||
self.is_stopping = True
|
||||
raise Exception("Job stopped")
|
||||
|
||||
async def _update_key(self, key, value):
    """Write ``value`` into column ``key`` of this job's row in the ``Job`` table.

    Does nothing unless ``self.accelerator.is_main_process`` is true. The
    actual write is wrapped in a closure and handed to
    ``self._execute_db_operation``.

    Args:
        key: Name of the ``Job`` column to set. Column names cannot be bound
            as ``?`` parameters, so the name is placed into the SQL text and
            must therefore be a plain identifier.
        value: New value for the column; bound as a query parameter.

    Raises:
        ValueError: If ``key`` is not a plain identifier string.
    """
    if not self.accelerator.is_main_process:
        return

    # The column name ends up inside the SQL string itself, so refuse
    # anything that is not a bare identifier before building the query.
    if not isinstance(key, str) or not key.isidentifier():
        raise ValueError(f"Invalid Job column name: {key!r}")

    query = f"UPDATE Job SET {key} = ? WHERE id = ?"

    def _do_update():
        with self._db_connect() as conn:
            cursor = conn.cursor()
            # Take the write lock immediately instead of on first write.
            cursor.execute("BEGIN IMMEDIATE")
            try:
                cursor.execute(query, (value, self.job_id))
            finally:
                # Always end the transaction so the lock is released.
                cursor.execute("COMMIT")

    await self._execute_db_operation(_do_update)
|
||||
|
||||
def update_step(self):
    """Store the current ``self.step_num`` in the job's ``step`` column.

    Only the main process writes. The update coroutine is built with
    ``self._update_key`` and handed to ``self._run_async_operation``, which
    decides how it is executed.
    """
    if self.accelerator.is_main_process:
        self._run_async_operation(self._update_key("step", self.step_num))
|
||||
|
||||
|
||||
def update_db_key(self, key, value):
    """Set one column of this job's database row.

    Only the main process writes. The update coroutine is built with
    ``self._update_key(key, value)`` and handed to
    ``self._run_async_operation``.

    Args:
        key: Name of the column to set.
        value: Value to store in that column.
    """
    if self.accelerator.is_main_process:
        self._run_async_operation(self._update_key(key, value))
|
||||
|
||||
|
||||
|
||||
async def _update_status(self, status: AITK_Status, info: Optional[str] = None):
|
||||
if not self.accelerator.is_main_process:
|
||||
@@ -102,7 +110,7 @@ class UITrainer(SDTrainer):
|
||||
def _do_update():
|
||||
with self._db_connect() as conn:
|
||||
cursor = conn.cursor()
|
||||
cursor.execute("BEGIN IMMEDIATE") # Get an immediate lock
|
||||
cursor.execute("BEGIN IMMEDIATE")
|
||||
try:
|
||||
if info is not None:
|
||||
cursor.execute(
|
||||
@@ -115,26 +123,40 @@ class UITrainer(SDTrainer):
|
||||
(status, self.job_id)
|
||||
)
|
||||
finally:
|
||||
cursor.execute("COMMIT") # Release the lock
|
||||
cursor.execute("COMMIT")
|
||||
|
||||
await self._execute_db_operation(_do_update)
|
||||
|
||||
def update_status(self, status: AITK_Status, info: Optional[str] = None):
    """Record a new job status, optionally with an info message.

    Only the main process writes. The update coroutine is built with
    ``self._update_status(status, info)`` and handed to
    ``self._run_async_operation``.

    Args:
        status: Status value to store for the job.
        info: Optional message passed through unchanged to
            ``_update_status``.
    """
    if self.accelerator.is_main_process:
        self._run_async_operation(self._update_status(status, info))
|
||||
|
||||
async def wait_for_all_async(self):
    """Wait for every tracked async operation, then forget them all.

    Tasks in ``self._async_tasks`` that are already finished are skipped:
    there is nothing left to wait for, and this coroutine is driven through
    ``asyncio.run()`` by its callers, i.e. on a fresh event loop. Handing a
    task that belongs to another loop to ``asyncio.gather`` from there can
    fail with an "attached to a different loop" error.

    Raises:
        Exception: Whatever a still-pending operation raises while it is
            being awaited. The tracking list is cleared in every case.
    """
    if not self._async_tasks:
        return

    try:
        still_running = [task for task in self._async_tasks if not task.done()]
        if still_running:
            await asyncio.gather(*still_running)
    finally:
        # Drop all references, finished or not, so the list cannot grow
        # across repeated calls.
        self._async_tasks.clear()
|
||||
|
||||
def on_error(self, e: Exception):
    """Report a training failure to the job database and release the pool.

    After delegating to the parent handler, the main process (unless the
    job is already stopping) stores status ``"error"`` with ``str(e)`` as
    the info text, rewrites the job's ``step`` column to
    ``self.last_save_step``, waits for the tracked async operations and
    shuts the thread pool down.

    Args:
        e: The exception that ended training.
    """
    super(UITrainer, self).on_error(e)
    if self.accelerator.is_main_process and not self.is_stopping:
        try:
            self.update_status("error", str(e))
            # NOTE(review): presumably the step of the last written
            # checkpoint, so the recorded step matches what is on disk.
            self.update_db_key("step", self.last_save_step)
            asyncio.run(self.wait_for_all_async())
        finally:
            # Release the worker thread even if one of the writes above
            # fails while we are already handling an error.
            self.thread_pool.shutdown(wait=True)
|
||||
|
||||
def done_hook(self):
    """Mark the job as completed and release the database thread pool.

    Runs the parent hook, stores status ``"completed"``, waits for the
    tracked async operations to finish and only then shuts the thread pool
    down, so nothing is submitted to an executor that is already closed.
    """
    super(UITrainer, self).done_hook()
    self.update_status("completed", "Training completed")
    try:
        # Wait for all async operations to finish before shutting down.
        asyncio.run(self.wait_for_all_async())
    finally:
        # Release the worker thread even if waiting raised.
        self.thread_pool.shutdown(wait=True)
|
||||
|
||||
def end_step_hook(self):
|
||||
super(UITrainer, self).end_step_hook()
|
||||
@@ -154,16 +176,16 @@ class UITrainer(SDTrainer):
|
||||
def hook_before_train_loop(self):
    """Sync job state with the database just before the training loop.

    Runs the parent hook, calls ``self.maybe_stop()``, then writes the
    current step and the status ``"running"`` / ``"Training"``.
    """
    super().hook_before_train_loop()
    self.maybe_stop()
    self.update_step()
    self.update_status("running", "Training")
|
||||
|
||||
|
||||
def status_update_hook_func(self, string):
    """Store ``string`` as the info text of a ``"running"`` status.

    Args:
        string: Progress message to show for the job.
    """
    self.update_status("running", string)
|
||||
|
||||
|
||||
def hook_after_sd_init_before_load(self):
    """Attach status reporting to ``self.sd`` after the parent hook runs.

    Runs the parent hook, calls ``self.maybe_stop()``, then registers
    ``self.status_update_hook_func`` through
    ``self.sd.add_status_update_hook``.
    """
    super().hook_after_sd_init_before_load()
    self.maybe_stop()
    self.sd.add_status_update_hook(self.status_update_hook_func)
|
||||
|
||||
|
||||
def sample_step_hook(self, img_num, total_imgs):
|
||||
super().sample_step_hook(img_num, total_imgs)
|
||||
@@ -184,4 +206,4 @@ class UITrainer(SDTrainer):
|
||||
self.update_status("running", "Saving model")
|
||||
super().save(step)
|
||||
self.maybe_stop()
|
||||
self.update_status("running", "Training")
|
||||
self.update_status("running", "Training")
|
||||
Reference in New Issue
Block a user