mirror of
https://github.com/ostris/ai-toolkit.git
synced 2026-01-26 16:39:47 +00:00
Set step to the last saved step when exiting
This commit is contained in:
@@ -20,11 +20,13 @@ class UITrainer(SDTrainer):
|
||||
self.is_stopping = False
|
||||
# Create a thread pool for database operations
|
||||
self.thread_pool = concurrent.futures.ThreadPoolExecutor(max_workers=1)
|
||||
# Track all async tasks
|
||||
self._async_tasks = []
|
||||
# Initialize the status
|
||||
self._run_async_operation(self._update_status("running", "Starting"))
|
||||
|
||||
def _run_async_operation(self, coro):
|
||||
"""Helper method to run an async coroutine in a new event loop."""
|
||||
"""Helper method to run an async coroutine and track the task."""
|
||||
try:
|
||||
loop = asyncio.get_event_loop()
|
||||
except RuntimeError:
|
||||
@@ -32,14 +34,14 @@ class UITrainer(SDTrainer):
|
||||
loop = asyncio.new_event_loop()
|
||||
asyncio.set_event_loop(loop)
|
||||
|
||||
# Run the coroutine in the event loop
|
||||
# Create a task and track it
|
||||
if loop.is_running():
|
||||
# If we're already in an event loop, create a future
|
||||
future = asyncio.run_coroutine_threadsafe(coro, loop)
|
||||
# We could wait for the result if needed: future.result()
|
||||
task = asyncio.run_coroutine_threadsafe(coro, loop)
|
||||
self._async_tasks.append(asyncio.wrap_future(task))
|
||||
else:
|
||||
# If no loop is running, run the coroutine and close the loop
|
||||
loop.run_until_complete(coro)
|
||||
task = loop.create_task(coro)
|
||||
self._async_tasks.append(task)
|
||||
loop.run_until_complete(task)
|
||||
|
||||
async def _execute_db_operation(self, operation_func):
|
||||
"""Execute a database operation in a separate thread to avoid blocking."""
|
||||
@@ -61,7 +63,6 @@ class UITrainer(SDTrainer):
|
||||
stop = cursor.fetchone()
|
||||
return False if stop is None else stop[0] == 1
|
||||
|
||||
# For this one we need a synchronous result, so we'll run it directly
|
||||
return _check_stop()
|
||||
|
||||
def maybe_stop(self):
|
||||
@@ -71,29 +72,36 @@ class UITrainer(SDTrainer):
|
||||
self.is_stopping = True
|
||||
raise Exception("Job stopped")
|
||||
|
||||
async def _update_step(self):
|
||||
async def _update_key(self, key, value):
|
||||
if not self.accelerator.is_main_process:
|
||||
return
|
||||
|
||||
def _do_update():
|
||||
with self._db_connect() as conn:
|
||||
cursor = conn.cursor()
|
||||
cursor.execute("BEGIN IMMEDIATE") # Get an immediate lock
|
||||
cursor.execute("BEGIN IMMEDIATE")
|
||||
try:
|
||||
cursor.execute(
|
||||
"UPDATE Job SET step = ? WHERE id = ?",
|
||||
(self.step_num, self.job_id)
|
||||
f"UPDATE Job SET {key} = ? WHERE id = ?",
|
||||
(value, self.job_id)
|
||||
)
|
||||
finally:
|
||||
cursor.execute("COMMIT") # Release the lock
|
||||
cursor.execute("COMMIT")
|
||||
|
||||
await self._execute_db_operation(_do_update)
|
||||
|
||||
def update_step(self):
|
||||
"""Non-blocking update of the step count."""
|
||||
if self.accelerator.is_main_process:
|
||||
# Use the helper method to run the async operation
|
||||
self._run_async_operation(self._update_step())
|
||||
self._run_async_operation(self._update_key("step", self.step_num))
|
||||
|
||||
|
||||
def update_db_key(self, key, value):
|
||||
"""Non-blocking update a key in the database."""
|
||||
if self.accelerator.is_main_process:
|
||||
self._run_async_operation(self._update_key(key, value))
|
||||
|
||||
|
||||
|
||||
async def _update_status(self, status: AITK_Status, info: Optional[str] = None):
|
||||
if not self.accelerator.is_main_process:
|
||||
@@ -102,7 +110,7 @@ class UITrainer(SDTrainer):
|
||||
def _do_update():
|
||||
with self._db_connect() as conn:
|
||||
cursor = conn.cursor()
|
||||
cursor.execute("BEGIN IMMEDIATE") # Get an immediate lock
|
||||
cursor.execute("BEGIN IMMEDIATE")
|
||||
try:
|
||||
if info is not None:
|
||||
cursor.execute(
|
||||
@@ -115,26 +123,40 @@ class UITrainer(SDTrainer):
|
||||
(status, self.job_id)
|
||||
)
|
||||
finally:
|
||||
cursor.execute("COMMIT") # Release the lock
|
||||
cursor.execute("COMMIT")
|
||||
|
||||
await self._execute_db_operation(_do_update)
|
||||
|
||||
def update_status(self, status: AITK_Status, info: Optional[str] = None):
|
||||
"""Non-blocking update of status."""
|
||||
if self.accelerator.is_main_process:
|
||||
# Use the helper method to run the async operation
|
||||
self._run_async_operation(self._update_status(status, info))
|
||||
|
||||
async def wait_for_all_async(self):
|
||||
"""Wait for all tracked async operations to complete."""
|
||||
if not self._async_tasks:
|
||||
return
|
||||
|
||||
try:
|
||||
await asyncio.gather(*self._async_tasks)
|
||||
finally:
|
||||
# Clear the task list after completion
|
||||
self._async_tasks.clear()
|
||||
|
||||
def on_error(self, e: Exception):
|
||||
super(UITrainer, self).on_error(e)
|
||||
if self.accelerator.is_main_process and not self.is_stopping:
|
||||
self.update_status("error", str(e))
|
||||
self.update_db_key("step", self.last_save_step)
|
||||
asyncio.run(self.wait_for_all_async())
|
||||
self.thread_pool.shutdown(wait=True)
|
||||
|
||||
def done_hook(self):
|
||||
super(UITrainer, self).done_hook()
|
||||
self.update_status("completed", "Training completed")
|
||||
# Make sure we clean up the thread pool
|
||||
self.thread_pool.shutdown(wait=False)
|
||||
# Wait for all async operations to finish before shutting down
|
||||
asyncio.run(self.wait_for_all_async())
|
||||
self.thread_pool.shutdown(wait=True)
|
||||
|
||||
def end_step_hook(self):
|
||||
super(UITrainer, self).end_step_hook()
|
||||
@@ -154,16 +176,16 @@ class UITrainer(SDTrainer):
|
||||
def hook_before_train_loop(self):
|
||||
super().hook_before_train_loop()
|
||||
self.maybe_stop()
|
||||
self.update_step()
|
||||
self.update_status("running", "Training")
|
||||
|
||||
|
||||
def status_update_hook_func(self, string):
|
||||
self.update_status("running", string)
|
||||
|
||||
|
||||
def hook_after_sd_init_before_load(self):
|
||||
super().hook_after_sd_init_before_load()
|
||||
self.maybe_stop()
|
||||
self.sd.add_status_update_hook(self.status_update_hook_func)
|
||||
|
||||
|
||||
def sample_step_hook(self, img_num, total_imgs):
|
||||
super().sample_step_hook(img_num, total_imgs)
|
||||
@@ -184,4 +206,4 @@ class UITrainer(SDTrainer):
|
||||
self.update_status("running", "Saving model")
|
||||
super().save(step)
|
||||
self.maybe_stop()
|
||||
self.update_status("running", "Training")
|
||||
self.update_status("running", "Training")
|
||||
@@ -92,6 +92,7 @@ class BaseSDTrainProcess(BaseTrainProcess):
|
||||
self.step_num = 0
|
||||
self.start_step = 0
|
||||
self.epoch_num = 0
|
||||
self.last_save_step = 0
|
||||
# start at 1 so we can do a sample at the start
|
||||
self.grad_accumulation_step = 1
|
||||
# if true, then we do not do an optimizer step. We are accumulating gradients
|
||||
@@ -459,6 +460,7 @@ class BaseSDTrainProcess(BaseTrainProcess):
|
||||
|
||||
step_num = ''
|
||||
if step is not None:
|
||||
self.last_save_step = step
|
||||
# zeropad 9 digits
|
||||
step_num = f"_{str(step).zfill(9)}"
|
||||
|
||||
@@ -1827,6 +1829,7 @@ class BaseSDTrainProcess(BaseTrainProcess):
|
||||
self.sd)
|
||||
|
||||
flush()
|
||||
self.last_save_step = self.step_num
|
||||
### HOOK ###
|
||||
self.hook_before_train_loop()
|
||||
|
||||
|
||||
@@ -24,4 +24,5 @@ model Job {
|
||||
stop Boolean @default(false)
|
||||
step Int @default(0)
|
||||
info String @default("")
|
||||
speed_string String @default("")
|
||||
}
|
||||
|
||||
@@ -68,24 +68,24 @@ export default function JobOverview({ job }: JobOverviewProps) {
|
||||
<div className="flex items-center space-x-4">
|
||||
<HardDrive className="w-5 h-5 text-blue-400" />
|
||||
<div>
|
||||
<p className="text-sm font-medium text-gray-200">{job.name}</p>
|
||||
<p className="text-xs text-gray-400">Job Name</p>
|
||||
<p className="text-sm font-medium text-gray-200">{job.name}</p>
|
||||
</div>
|
||||
</div>
|
||||
|
||||
<div className="flex items-center space-x-4">
|
||||
<Cpu className="w-5 h-5 text-purple-400" />
|
||||
<div>
|
||||
<p className="text-sm font-medium text-gray-200">GPUs: {job.gpu_ids}</p>
|
||||
<p className="text-xs text-gray-400">Assigned GPUs</p>
|
||||
<p className="text-sm font-medium text-gray-200">GPUs: {job.gpu_ids}</p>
|
||||
</div>
|
||||
</div>
|
||||
|
||||
<div className="flex items-center space-x-4">
|
||||
<Info className="w-5 h-5 text-amber-400" />
|
||||
<div>
|
||||
<p className="text-sm font-medium text-gray-200">{job.info}</p>
|
||||
<p className="text-xs text-gray-400">Additional Information</p>
|
||||
<p className="text-sm font-medium text-gray-200">{job.info}</p>
|
||||
</div>
|
||||
</div>
|
||||
</div>
|
||||
@@ -95,7 +95,7 @@ export default function JobOverview({ job }: JobOverviewProps) {
|
||||
{/* GPU Widget Panel */}
|
||||
<div className="col-span-1">
|
||||
<div>{isGPUInfoLoaded && gpuList.length > 0 && <GPUWidget gpu={gpuList[0]} />}</div>
|
||||
<div className='mt-4'>
|
||||
<div className="mt-4">
|
||||
<FilesWidget jobID={job.id} />
|
||||
</div>
|
||||
</div>
|
||||
|
||||
Reference in New Issue
Block a user