diff --git a/extensions_built_in/sd_trainer/UITrainer.py b/extensions_built_in/sd_trainer/UITrainer.py index 2657f60f..fa81c130 100644 --- a/extensions_built_in/sd_trainer/UITrainer.py +++ b/extensions_built_in/sd_trainer/UITrainer.py @@ -20,11 +20,13 @@ class UITrainer(SDTrainer): self.is_stopping = False # Create a thread pool for database operations self.thread_pool = concurrent.futures.ThreadPoolExecutor(max_workers=1) + # Track all async tasks + self._async_tasks = [] # Initialize the status self._run_async_operation(self._update_status("running", "Starting")) def _run_async_operation(self, coro): - """Helper method to run an async coroutine in a new event loop.""" + """Helper method to run an async coroutine and track the task.""" try: loop = asyncio.get_event_loop() except RuntimeError: @@ -32,14 +34,14 @@ class UITrainer(SDTrainer): loop = asyncio.new_event_loop() asyncio.set_event_loop(loop) - # Run the coroutine in the event loop + # Create a task and track it if loop.is_running(): - # If we're already in an event loop, create a future - future = asyncio.run_coroutine_threadsafe(coro, loop) - # We could wait for the result if needed: future.result() + task = asyncio.run_coroutine_threadsafe(coro, loop) + self._async_tasks.append(asyncio.wrap_future(task)) else: - # If no loop is running, run the coroutine and close the loop - loop.run_until_complete(coro) + task = loop.create_task(coro) + self._async_tasks.append(task) + loop.run_until_complete(task) async def _execute_db_operation(self, operation_func): """Execute a database operation in a separate thread to avoid blocking.""" @@ -61,7 +63,6 @@ class UITrainer(SDTrainer): stop = cursor.fetchone() return False if stop is None else stop[0] == 1 - # For this one we need a synchronous result, so we'll run it directly return _check_stop() def maybe_stop(self): @@ -71,29 +72,36 @@ class UITrainer(SDTrainer): self.is_stopping = True raise Exception("Job stopped") - async def _update_step(self): + async def _update_key(self, key, value): if not self.accelerator.is_main_process: return def _do_update(): with self._db_connect() as conn: cursor = conn.cursor() - cursor.execute("BEGIN IMMEDIATE") # Get an immediate lock + cursor.execute("BEGIN IMMEDIATE") try: cursor.execute( - "UPDATE Job SET step = ? WHERE id = ?", - (self.step_num, self.job_id) + f"UPDATE Job SET {key} = ? WHERE id = ?", + (value, self.job_id) ) finally: - cursor.execute("COMMIT") # Release the lock + cursor.execute("COMMIT") await self._execute_db_operation(_do_update) def update_step(self): """Non-blocking update of the step count.""" if self.accelerator.is_main_process: - # Use the helper method to run the async operation - self._run_async_operation(self._update_step()) + self._run_async_operation(self._update_key("step", self.step_num)) + + + def update_db_key(self, key, value): + """Non-blocking update a key in the database.""" + if self.accelerator.is_main_process: + self._run_async_operation(self._update_key(key, value)) + + async def _update_status(self, status: AITK_Status, info: Optional[str] = None): if not self.accelerator.is_main_process: @@ -102,7 +110,7 @@ class UITrainer(SDTrainer): def _do_update(): with self._db_connect() as conn: cursor = conn.cursor() - cursor.execute("BEGIN IMMEDIATE") # Get an immediate lock + cursor.execute("BEGIN IMMEDIATE") try: if info is not None: cursor.execute( @@ -115,26 +123,40 @@ class UITrainer(SDTrainer): (status, self.job_id) ) finally: - cursor.execute("COMMIT") # Release the lock + cursor.execute("COMMIT") await self._execute_db_operation(_do_update) def update_status(self, status: AITK_Status, info: Optional[str] = None): """Non-blocking update of status.""" if self.accelerator.is_main_process: - # Use the helper method to run the async operation self._run_async_operation(self._update_status(status, info)) + async def wait_for_all_async(self): + """Wait for all tracked async operations to complete.""" + if not self._async_tasks: + return + + try: + await asyncio.gather(*self._async_tasks) + finally: + # Clear the task list after completion + self._async_tasks.clear() + def on_error(self, e: Exception): super(UITrainer, self).on_error(e) if self.accelerator.is_main_process and not self.is_stopping: self.update_status("error", str(e)) + self.update_db_key("step", self.last_save_step) + asyncio.run(self.wait_for_all_async()) + self.thread_pool.shutdown(wait=True) def done_hook(self): super(UITrainer, self).done_hook() self.update_status("completed", "Training completed") - # Make sure we clean up the thread pool - self.thread_pool.shutdown(wait=False) + # Wait for all async operations to finish before shutting down + asyncio.run(self.wait_for_all_async()) + self.thread_pool.shutdown(wait=True) def end_step_hook(self): super(UITrainer, self).end_step_hook() @@ -154,16 +176,16 @@ class UITrainer(SDTrainer): def hook_before_train_loop(self): super().hook_before_train_loop() self.maybe_stop() + self.update_step() self.update_status("running", "Training") - + def status_update_hook_func(self, string): self.update_status("running", string) - + def hook_after_sd_init_before_load(self): super().hook_after_sd_init_before_load() self.maybe_stop() self.sd.add_status_update_hook(self.status_update_hook_func) - def sample_step_hook(self, img_num, total_imgs): super().sample_step_hook(img_num, total_imgs) @@ -184,4 +206,4 @@ class UITrainer(SDTrainer): self.update_status("running", "Saving model") super().save(step) self.maybe_stop() - self.update_status("running", "Training") + self.update_status("running", "Training") \ No newline at end of file diff --git a/jobs/process/BaseSDTrainProcess.py b/jobs/process/BaseSDTrainProcess.py index d22dd39d..2482c26d 100644 --- a/jobs/process/BaseSDTrainProcess.py +++ b/jobs/process/BaseSDTrainProcess.py @@ -92,6 +92,7 @@ class BaseSDTrainProcess(BaseTrainProcess): self.step_num = 0 self.start_step = 0 self.epoch_num = 0 + self.last_save_step = 0 # start at 1 so we can do a sample at the start self.grad_accumulation_step = 1 # if true, then we do not do an optimizer step. We are accumulating gradients @@ -459,6 +460,7 @@ class BaseSDTrainProcess(BaseTrainProcess): step_num = '' if step is not None: + self.last_save_step = step # zeropad 9 digits step_num = f"_{str(step).zfill(9)}" @@ -1827,6 +1829,7 @@ class BaseSDTrainProcess(BaseTrainProcess): self.sd) flush() + self.last_save_step = self.step_num ### HOOK ### self.hook_before_train_loop() diff --git a/ui/prisma/schema.prisma b/ui/prisma/schema.prisma index c23826ae..1489e26e 100644 --- a/ui/prisma/schema.prisma +++ b/ui/prisma/schema.prisma @@ -24,4 +24,5 @@ model Job { stop Boolean @default(false) step Int @default(0) info String @default("") + speed_string String @default("") } diff --git a/ui/src/components/JobOverview.tsx b/ui/src/components/JobOverview.tsx index 278b06a6..ba1ad5c5 100644 --- a/ui/src/components/JobOverview.tsx +++ b/ui/src/components/JobOverview.tsx @@ -68,24 +68,24 @@ export default function JobOverview({ job }: JobOverviewProps) {
{job.name}
Job Name
+{job.name}
GPUs: {job.gpu_ids}
Assigned GPUs
+GPUs: {job.gpu_ids}
{job.info}
Additional Information
+{job.info}