Set step to the last step saved at when exiting

This commit is contained in:
Jaret Burkett
2025-02-23 13:21:22 -07:00
parent 60f848a877
commit 3e49337a58
4 changed files with 54 additions and 28 deletions

View File

@@ -20,11 +20,13 @@ class UITrainer(SDTrainer):
self.is_stopping = False
# Create a thread pool for database operations
self.thread_pool = concurrent.futures.ThreadPoolExecutor(max_workers=1)
# Track all async tasks
self._async_tasks = []
# Initialize the status
self._run_async_operation(self._update_status("running", "Starting"))
def _run_async_operation(self, coro):
"""Helper method to run an async coroutine in a new event loop."""
"""Helper method to run an async coroutine and track the task."""
try:
loop = asyncio.get_event_loop()
except RuntimeError:
@@ -32,14 +34,14 @@ class UITrainer(SDTrainer):
loop = asyncio.new_event_loop()
asyncio.set_event_loop(loop)
# Run the coroutine in the event loop
# Create a task and track it
if loop.is_running():
# If we're already in an event loop, create a future
future = asyncio.run_coroutine_threadsafe(coro, loop)
# We could wait for the result if needed: future.result()
task = asyncio.run_coroutine_threadsafe(coro, loop)
self._async_tasks.append(asyncio.wrap_future(task))
else:
# If no loop is running, run the coroutine and close the loop
loop.run_until_complete(coro)
task = loop.create_task(coro)
self._async_tasks.append(task)
loop.run_until_complete(task)
async def _execute_db_operation(self, operation_func):
"""Execute a database operation in a separate thread to avoid blocking."""
@@ -61,7 +63,6 @@ class UITrainer(SDTrainer):
stop = cursor.fetchone()
return False if stop is None else stop[0] == 1
# For this one we need a synchronous result, so we'll run it directly
return _check_stop()
def maybe_stop(self):
@@ -71,29 +72,36 @@ class UITrainer(SDTrainer):
self.is_stopping = True
raise Exception("Job stopped")
async def _update_step(self):
async def _update_key(self, key, value):
if not self.accelerator.is_main_process:
return
def _do_update():
with self._db_connect() as conn:
cursor = conn.cursor()
cursor.execute("BEGIN IMMEDIATE") # Get an immediate lock
cursor.execute("BEGIN IMMEDIATE")
try:
cursor.execute(
"UPDATE Job SET step = ? WHERE id = ?",
(self.step_num, self.job_id)
f"UPDATE Job SET {key} = ? WHERE id = ?",
(value, self.job_id)
)
finally:
cursor.execute("COMMIT") # Release the lock
cursor.execute("COMMIT")
await self._execute_db_operation(_do_update)
def update_step(self):
"""Non-blocking update of the step count."""
if self.accelerator.is_main_process:
# Use the helper method to run the async operation
self._run_async_operation(self._update_step())
self._run_async_operation(self._update_key("step", self.step_num))
def update_db_key(self, key, value):
"""Non-blocking update a key in the database."""
if self.accelerator.is_main_process:
self._run_async_operation(self._update_key(key, value))
async def _update_status(self, status: AITK_Status, info: Optional[str] = None):
if not self.accelerator.is_main_process:
@@ -102,7 +110,7 @@ class UITrainer(SDTrainer):
def _do_update():
with self._db_connect() as conn:
cursor = conn.cursor()
cursor.execute("BEGIN IMMEDIATE") # Get an immediate lock
cursor.execute("BEGIN IMMEDIATE")
try:
if info is not None:
cursor.execute(
@@ -115,26 +123,40 @@ class UITrainer(SDTrainer):
(status, self.job_id)
)
finally:
cursor.execute("COMMIT") # Release the lock
cursor.execute("COMMIT")
await self._execute_db_operation(_do_update)
def update_status(self, status: AITK_Status, info: Optional[str] = None):
"""Non-blocking update of status."""
if self.accelerator.is_main_process:
# Use the helper method to run the async operation
self._run_async_operation(self._update_status(status, info))
async def wait_for_all_async(self):
"""Wait for all tracked async operations to complete."""
if not self._async_tasks:
return
try:
await asyncio.gather(*self._async_tasks)
finally:
# Clear the task list after completion
self._async_tasks.clear()
def on_error(self, e: Exception):
super(UITrainer, self).on_error(e)
if self.accelerator.is_main_process and not self.is_stopping:
self.update_status("error", str(e))
self.update_db_key("step", self.last_save_step)
asyncio.run(self.wait_for_all_async())
self.thread_pool.shutdown(wait=True)
def done_hook(self):
super(UITrainer, self).done_hook()
self.update_status("completed", "Training completed")
# Make sure we clean up the thread pool
self.thread_pool.shutdown(wait=False)
# Wait for all async operations to finish before shutting down
asyncio.run(self.wait_for_all_async())
self.thread_pool.shutdown(wait=True)
def end_step_hook(self):
super(UITrainer, self).end_step_hook()
@@ -154,16 +176,16 @@ class UITrainer(SDTrainer):
def hook_before_train_loop(self):
super().hook_before_train_loop()
self.maybe_stop()
self.update_step()
self.update_status("running", "Training")
def status_update_hook_func(self, string):
self.update_status("running", string)
def hook_after_sd_init_before_load(self):
super().hook_after_sd_init_before_load()
self.maybe_stop()
self.sd.add_status_update_hook(self.status_update_hook_func)
def sample_step_hook(self, img_num, total_imgs):
super().sample_step_hook(img_num, total_imgs)
@@ -184,4 +206,4 @@ class UITrainer(SDTrainer):
self.update_status("running", "Saving model")
super().save(step)
self.maybe_stop()
self.update_status("running", "Training")
self.update_status("running", "Training")

View File

@@ -92,6 +92,7 @@ class BaseSDTrainProcess(BaseTrainProcess):
self.step_num = 0
self.start_step = 0
self.epoch_num = 0
self.last_save_step = 0
# start at 1 so we can do a sample at the start
self.grad_accumulation_step = 1
# if true, then we do not do an optimizer step. We are accumulating gradients
@@ -459,6 +460,7 @@ class BaseSDTrainProcess(BaseTrainProcess):
step_num = ''
if step is not None:
self.last_save_step = step
# zeropad 9 digits
step_num = f"_{str(step).zfill(9)}"
@@ -1827,6 +1829,7 @@ class BaseSDTrainProcess(BaseTrainProcess):
self.sd)
flush()
self.last_save_step = self.step_num
### HOOK ###
self.hook_before_train_loop()

View File

@@ -24,4 +24,5 @@ model Job {
stop Boolean @default(false)
step Int @default(0)
info String @default("")
speed_string String @default("")
}

View File

@@ -68,24 +68,24 @@ export default function JobOverview({ job }: JobOverviewProps) {
<div className="flex items-center space-x-4">
<HardDrive className="w-5 h-5 text-blue-400" />
<div>
<p className="text-sm font-medium text-gray-200">{job.name}</p>
<p className="text-xs text-gray-400">Job Name</p>
<p className="text-sm font-medium text-gray-200">{job.name}</p>
</div>
</div>
<div className="flex items-center space-x-4">
<Cpu className="w-5 h-5 text-purple-400" />
<div>
<p className="text-sm font-medium text-gray-200">GPUs: {job.gpu_ids}</p>
<p className="text-xs text-gray-400">Assigned GPUs</p>
<p className="text-sm font-medium text-gray-200">GPUs: {job.gpu_ids}</p>
</div>
</div>
<div className="flex items-center space-x-4">
<Info className="w-5 h-5 text-amber-400" />
<div>
<p className="text-sm font-medium text-gray-200">{job.info}</p>
<p className="text-xs text-gray-400">Additional Information</p>
<p className="text-sm font-medium text-gray-200">{job.info}</p>
</div>
</div>
</div>
@@ -95,7 +95,7 @@ export default function JobOverview({ job }: JobOverviewProps) {
{/* GPU Widget Panel */}
<div className="col-span-1">
<div>{isGPUInfoLoaded && gpuList.length > 0 && <GPUWidget gpu={gpuList[0]} />}</div>
<div className='mt-4'>
<div className="mt-4">
<FilesWidget jobID={job.id} />
</div>
</div>