mirror of
https://github.com/ostris/ai-toolkit.git
synced 2026-01-26 16:39:47 +00:00
Set step to the last step saved at when exiting
This commit is contained in:
@@ -20,11 +20,13 @@ class UITrainer(SDTrainer):
|
|||||||
self.is_stopping = False
|
self.is_stopping = False
|
||||||
# Create a thread pool for database operations
|
# Create a thread pool for database operations
|
||||||
self.thread_pool = concurrent.futures.ThreadPoolExecutor(max_workers=1)
|
self.thread_pool = concurrent.futures.ThreadPoolExecutor(max_workers=1)
|
||||||
|
# Track all async tasks
|
||||||
|
self._async_tasks = []
|
||||||
# Initialize the status
|
# Initialize the status
|
||||||
self._run_async_operation(self._update_status("running", "Starting"))
|
self._run_async_operation(self._update_status("running", "Starting"))
|
||||||
|
|
||||||
def _run_async_operation(self, coro):
|
def _run_async_operation(self, coro):
|
||||||
"""Helper method to run an async coroutine in a new event loop."""
|
"""Helper method to run an async coroutine and track the task."""
|
||||||
try:
|
try:
|
||||||
loop = asyncio.get_event_loop()
|
loop = asyncio.get_event_loop()
|
||||||
except RuntimeError:
|
except RuntimeError:
|
||||||
@@ -32,14 +34,14 @@ class UITrainer(SDTrainer):
|
|||||||
loop = asyncio.new_event_loop()
|
loop = asyncio.new_event_loop()
|
||||||
asyncio.set_event_loop(loop)
|
asyncio.set_event_loop(loop)
|
||||||
|
|
||||||
# Run the coroutine in the event loop
|
# Create a task and track it
|
||||||
if loop.is_running():
|
if loop.is_running():
|
||||||
# If we're already in an event loop, create a future
|
task = asyncio.run_coroutine_threadsafe(coro, loop)
|
||||||
future = asyncio.run_coroutine_threadsafe(coro, loop)
|
self._async_tasks.append(asyncio.wrap_future(task))
|
||||||
# We could wait for the result if needed: future.result()
|
|
||||||
else:
|
else:
|
||||||
# If no loop is running, run the coroutine and close the loop
|
task = loop.create_task(coro)
|
||||||
loop.run_until_complete(coro)
|
self._async_tasks.append(task)
|
||||||
|
loop.run_until_complete(task)
|
||||||
|
|
||||||
async def _execute_db_operation(self, operation_func):
|
async def _execute_db_operation(self, operation_func):
|
||||||
"""Execute a database operation in a separate thread to avoid blocking."""
|
"""Execute a database operation in a separate thread to avoid blocking."""
|
||||||
@@ -61,7 +63,6 @@ class UITrainer(SDTrainer):
|
|||||||
stop = cursor.fetchone()
|
stop = cursor.fetchone()
|
||||||
return False if stop is None else stop[0] == 1
|
return False if stop is None else stop[0] == 1
|
||||||
|
|
||||||
# For this one we need a synchronous result, so we'll run it directly
|
|
||||||
return _check_stop()
|
return _check_stop()
|
||||||
|
|
||||||
def maybe_stop(self):
|
def maybe_stop(self):
|
||||||
@@ -71,29 +72,36 @@ class UITrainer(SDTrainer):
|
|||||||
self.is_stopping = True
|
self.is_stopping = True
|
||||||
raise Exception("Job stopped")
|
raise Exception("Job stopped")
|
||||||
|
|
||||||
async def _update_step(self):
|
async def _update_key(self, key, value):
|
||||||
if not self.accelerator.is_main_process:
|
if not self.accelerator.is_main_process:
|
||||||
return
|
return
|
||||||
|
|
||||||
def _do_update():
|
def _do_update():
|
||||||
with self._db_connect() as conn:
|
with self._db_connect() as conn:
|
||||||
cursor = conn.cursor()
|
cursor = conn.cursor()
|
||||||
cursor.execute("BEGIN IMMEDIATE") # Get an immediate lock
|
cursor.execute("BEGIN IMMEDIATE")
|
||||||
try:
|
try:
|
||||||
cursor.execute(
|
cursor.execute(
|
||||||
"UPDATE Job SET step = ? WHERE id = ?",
|
f"UPDATE Job SET {key} = ? WHERE id = ?",
|
||||||
(self.step_num, self.job_id)
|
(value, self.job_id)
|
||||||
)
|
)
|
||||||
finally:
|
finally:
|
||||||
cursor.execute("COMMIT") # Release the lock
|
cursor.execute("COMMIT")
|
||||||
|
|
||||||
await self._execute_db_operation(_do_update)
|
await self._execute_db_operation(_do_update)
|
||||||
|
|
||||||
def update_step(self):
|
def update_step(self):
|
||||||
"""Non-blocking update of the step count."""
|
"""Non-blocking update of the step count."""
|
||||||
if self.accelerator.is_main_process:
|
if self.accelerator.is_main_process:
|
||||||
# Use the helper method to run the async operation
|
self._run_async_operation(self._update_key("step", self.step_num))
|
||||||
self._run_async_operation(self._update_step())
|
|
||||||
|
|
||||||
|
def update_db_key(self, key, value):
|
||||||
|
"""Non-blocking update a key in the database."""
|
||||||
|
if self.accelerator.is_main_process:
|
||||||
|
self._run_async_operation(self._update_key(key, value))
|
||||||
|
|
||||||
|
|
||||||
|
|
||||||
async def _update_status(self, status: AITK_Status, info: Optional[str] = None):
|
async def _update_status(self, status: AITK_Status, info: Optional[str] = None):
|
||||||
if not self.accelerator.is_main_process:
|
if not self.accelerator.is_main_process:
|
||||||
@@ -102,7 +110,7 @@ class UITrainer(SDTrainer):
|
|||||||
def _do_update():
|
def _do_update():
|
||||||
with self._db_connect() as conn:
|
with self._db_connect() as conn:
|
||||||
cursor = conn.cursor()
|
cursor = conn.cursor()
|
||||||
cursor.execute("BEGIN IMMEDIATE") # Get an immediate lock
|
cursor.execute("BEGIN IMMEDIATE")
|
||||||
try:
|
try:
|
||||||
if info is not None:
|
if info is not None:
|
||||||
cursor.execute(
|
cursor.execute(
|
||||||
@@ -115,26 +123,40 @@ class UITrainer(SDTrainer):
|
|||||||
(status, self.job_id)
|
(status, self.job_id)
|
||||||
)
|
)
|
||||||
finally:
|
finally:
|
||||||
cursor.execute("COMMIT") # Release the lock
|
cursor.execute("COMMIT")
|
||||||
|
|
||||||
await self._execute_db_operation(_do_update)
|
await self._execute_db_operation(_do_update)
|
||||||
|
|
||||||
def update_status(self, status: AITK_Status, info: Optional[str] = None):
|
def update_status(self, status: AITK_Status, info: Optional[str] = None):
|
||||||
"""Non-blocking update of status."""
|
"""Non-blocking update of status."""
|
||||||
if self.accelerator.is_main_process:
|
if self.accelerator.is_main_process:
|
||||||
# Use the helper method to run the async operation
|
|
||||||
self._run_async_operation(self._update_status(status, info))
|
self._run_async_operation(self._update_status(status, info))
|
||||||
|
|
||||||
|
async def wait_for_all_async(self):
|
||||||
|
"""Wait for all tracked async operations to complete."""
|
||||||
|
if not self._async_tasks:
|
||||||
|
return
|
||||||
|
|
||||||
|
try:
|
||||||
|
await asyncio.gather(*self._async_tasks)
|
||||||
|
finally:
|
||||||
|
# Clear the task list after completion
|
||||||
|
self._async_tasks.clear()
|
||||||
|
|
||||||
def on_error(self, e: Exception):
|
def on_error(self, e: Exception):
|
||||||
super(UITrainer, self).on_error(e)
|
super(UITrainer, self).on_error(e)
|
||||||
if self.accelerator.is_main_process and not self.is_stopping:
|
if self.accelerator.is_main_process and not self.is_stopping:
|
||||||
self.update_status("error", str(e))
|
self.update_status("error", str(e))
|
||||||
|
self.update_db_key("step", self.last_save_step)
|
||||||
|
asyncio.run(self.wait_for_all_async())
|
||||||
|
self.thread_pool.shutdown(wait=True)
|
||||||
|
|
||||||
def done_hook(self):
|
def done_hook(self):
|
||||||
super(UITrainer, self).done_hook()
|
super(UITrainer, self).done_hook()
|
||||||
self.update_status("completed", "Training completed")
|
self.update_status("completed", "Training completed")
|
||||||
# Make sure we clean up the thread pool
|
# Wait for all async operations to finish before shutting down
|
||||||
self.thread_pool.shutdown(wait=False)
|
asyncio.run(self.wait_for_all_async())
|
||||||
|
self.thread_pool.shutdown(wait=True)
|
||||||
|
|
||||||
def end_step_hook(self):
|
def end_step_hook(self):
|
||||||
super(UITrainer, self).end_step_hook()
|
super(UITrainer, self).end_step_hook()
|
||||||
@@ -154,16 +176,16 @@ class UITrainer(SDTrainer):
|
|||||||
def hook_before_train_loop(self):
|
def hook_before_train_loop(self):
|
||||||
super().hook_before_train_loop()
|
super().hook_before_train_loop()
|
||||||
self.maybe_stop()
|
self.maybe_stop()
|
||||||
|
self.update_step()
|
||||||
self.update_status("running", "Training")
|
self.update_status("running", "Training")
|
||||||
|
|
||||||
def status_update_hook_func(self, string):
|
def status_update_hook_func(self, string):
|
||||||
self.update_status("running", string)
|
self.update_status("running", string)
|
||||||
|
|
||||||
def hook_after_sd_init_before_load(self):
|
def hook_after_sd_init_before_load(self):
|
||||||
super().hook_after_sd_init_before_load()
|
super().hook_after_sd_init_before_load()
|
||||||
self.maybe_stop()
|
self.maybe_stop()
|
||||||
self.sd.add_status_update_hook(self.status_update_hook_func)
|
self.sd.add_status_update_hook(self.status_update_hook_func)
|
||||||
|
|
||||||
|
|
||||||
def sample_step_hook(self, img_num, total_imgs):
|
def sample_step_hook(self, img_num, total_imgs):
|
||||||
super().sample_step_hook(img_num, total_imgs)
|
super().sample_step_hook(img_num, total_imgs)
|
||||||
@@ -184,4 +206,4 @@ class UITrainer(SDTrainer):
|
|||||||
self.update_status("running", "Saving model")
|
self.update_status("running", "Saving model")
|
||||||
super().save(step)
|
super().save(step)
|
||||||
self.maybe_stop()
|
self.maybe_stop()
|
||||||
self.update_status("running", "Training")
|
self.update_status("running", "Training")
|
||||||
@@ -92,6 +92,7 @@ class BaseSDTrainProcess(BaseTrainProcess):
|
|||||||
self.step_num = 0
|
self.step_num = 0
|
||||||
self.start_step = 0
|
self.start_step = 0
|
||||||
self.epoch_num = 0
|
self.epoch_num = 0
|
||||||
|
self.last_save_step = 0
|
||||||
# start at 1 so we can do a sample at the start
|
# start at 1 so we can do a sample at the start
|
||||||
self.grad_accumulation_step = 1
|
self.grad_accumulation_step = 1
|
||||||
# if true, then we do not do an optimizer step. We are accumulating gradients
|
# if true, then we do not do an optimizer step. We are accumulating gradients
|
||||||
@@ -459,6 +460,7 @@ class BaseSDTrainProcess(BaseTrainProcess):
|
|||||||
|
|
||||||
step_num = ''
|
step_num = ''
|
||||||
if step is not None:
|
if step is not None:
|
||||||
|
self.last_save_step = step
|
||||||
# zeropad 9 digits
|
# zeropad 9 digits
|
||||||
step_num = f"_{str(step).zfill(9)}"
|
step_num = f"_{str(step).zfill(9)}"
|
||||||
|
|
||||||
@@ -1827,6 +1829,7 @@ class BaseSDTrainProcess(BaseTrainProcess):
|
|||||||
self.sd)
|
self.sd)
|
||||||
|
|
||||||
flush()
|
flush()
|
||||||
|
self.last_save_step = self.step_num
|
||||||
### HOOK ###
|
### HOOK ###
|
||||||
self.hook_before_train_loop()
|
self.hook_before_train_loop()
|
||||||
|
|
||||||
|
|||||||
@@ -24,4 +24,5 @@ model Job {
|
|||||||
stop Boolean @default(false)
|
stop Boolean @default(false)
|
||||||
step Int @default(0)
|
step Int @default(0)
|
||||||
info String @default("")
|
info String @default("")
|
||||||
|
speed_string String @default("")
|
||||||
}
|
}
|
||||||
|
|||||||
@@ -68,24 +68,24 @@ export default function JobOverview({ job }: JobOverviewProps) {
|
|||||||
<div className="flex items-center space-x-4">
|
<div className="flex items-center space-x-4">
|
||||||
<HardDrive className="w-5 h-5 text-blue-400" />
|
<HardDrive className="w-5 h-5 text-blue-400" />
|
||||||
<div>
|
<div>
|
||||||
<p className="text-sm font-medium text-gray-200">{job.name}</p>
|
|
||||||
<p className="text-xs text-gray-400">Job Name</p>
|
<p className="text-xs text-gray-400">Job Name</p>
|
||||||
|
<p className="text-sm font-medium text-gray-200">{job.name}</p>
|
||||||
</div>
|
</div>
|
||||||
</div>
|
</div>
|
||||||
|
|
||||||
<div className="flex items-center space-x-4">
|
<div className="flex items-center space-x-4">
|
||||||
<Cpu className="w-5 h-5 text-purple-400" />
|
<Cpu className="w-5 h-5 text-purple-400" />
|
||||||
<div>
|
<div>
|
||||||
<p className="text-sm font-medium text-gray-200">GPUs: {job.gpu_ids}</p>
|
|
||||||
<p className="text-xs text-gray-400">Assigned GPUs</p>
|
<p className="text-xs text-gray-400">Assigned GPUs</p>
|
||||||
|
<p className="text-sm font-medium text-gray-200">GPUs: {job.gpu_ids}</p>
|
||||||
</div>
|
</div>
|
||||||
</div>
|
</div>
|
||||||
|
|
||||||
<div className="flex items-center space-x-4">
|
<div className="flex items-center space-x-4">
|
||||||
<Info className="w-5 h-5 text-amber-400" />
|
<Info className="w-5 h-5 text-amber-400" />
|
||||||
<div>
|
<div>
|
||||||
<p className="text-sm font-medium text-gray-200">{job.info}</p>
|
|
||||||
<p className="text-xs text-gray-400">Additional Information</p>
|
<p className="text-xs text-gray-400">Additional Information</p>
|
||||||
|
<p className="text-sm font-medium text-gray-200">{job.info}</p>
|
||||||
</div>
|
</div>
|
||||||
</div>
|
</div>
|
||||||
</div>
|
</div>
|
||||||
@@ -95,7 +95,7 @@ export default function JobOverview({ job }: JobOverviewProps) {
|
|||||||
{/* GPU Widget Panel */}
|
{/* GPU Widget Panel */}
|
||||||
<div className="col-span-1">
|
<div className="col-span-1">
|
||||||
<div>{isGPUInfoLoaded && gpuList.length > 0 && <GPUWidget gpu={gpuList[0]} />}</div>
|
<div>{isGPUInfoLoaded && gpuList.length > 0 && <GPUWidget gpu={gpuList[0]} />}</div>
|
||||||
<div className='mt-4'>
|
<div className="mt-4">
|
||||||
<FilesWidget jobID={job.id} />
|
<FilesWidget jobID={job.id} />
|
||||||
</div>
|
</div>
|
||||||
</div>
|
</div>
|
||||||
|
|||||||
Reference in New Issue
Block a user