diff --git a/extensions_built_in/sd_trainer/UITrainer.py b/extensions_built_in/sd_trainer/UITrainer.py
index dae73d56..7a34337b 100644
--- a/extensions_built_in/sd_trainer/UITrainer.py
+++ b/extensions_built_in/sd_trainer/UITrainer.py
@@ -1,3 +1,4 @@
+from collections import OrderedDict
import os
import sqlite3
import asyncio
@@ -9,9 +10,10 @@ AITK_Status = Literal["running", "stopped", "error", "completed"]
class UITrainer(SDTrainer):
- def __init__(self):
- super(UITrainer, self).__init__()
- self.sqlite_db_path = self.config.get("sqlite_db_path", "data.sqlite")
+ def __init__(self, process_id: int, job, config: OrderedDict, **kwargs):
+ super(UITrainer, self).__init__(process_id, job, config, **kwargs)
+ self.sqlite_db_path = self.config.get("sqlite_db_path", "./aitk_db.db")
+ print(f"Using SQLite database at {self.sqlite_db_path}")
self.job_id = os.environ.get("AITK_JOB_ID", None)
if self.job_id is None:
raise Exception("AITK_JOB_ID not set")
@@ -19,13 +21,31 @@ class UITrainer(SDTrainer):
# Create a thread pool for database operations
self.thread_pool = concurrent.futures.ThreadPoolExecutor(max_workers=1)
# Initialize the status
- asyncio.run(self._update_status("running", "Starting"))
-
+ self._run_async_operation(self._update_status("running", "Starting"))
+
+ def _run_async_operation(self, coro):
+ """Helper method to run an async coroutine in a new event loop."""
+ try:
+ loop = asyncio.get_event_loop()
+ except RuntimeError:
+ # No event loop exists, create a new one
+ loop = asyncio.new_event_loop()
+ asyncio.set_event_loop(loop)
+
+ # Run the coroutine in the event loop
+ if loop.is_running():
+ # If we're already in an event loop, create a future
+ future = asyncio.run_coroutine_threadsafe(coro, loop)
+ # Fire-and-forget: the future's result is deliberately discarded (call future.result() to block)
+ else:
+ # If no loop is running, run the coroutine and close the loop
+ loop.run_until_complete(coro)
+
async def _execute_db_operation(self, operation_func):
"""Execute a database operation in a separate thread to avoid blocking."""
loop = asyncio.get_event_loop()
return await loop.run_in_executor(self.thread_pool, operation_func)
-
+
def _db_connect(self):
"""Create a new connection for each operation to avoid locking."""
conn = sqlite3.connect(self.sqlite_db_path, timeout=10.0)
@@ -36,47 +56,49 @@ class UITrainer(SDTrainer):
def _check_stop():
with self._db_connect() as conn:
cursor = conn.cursor()
- cursor.execute("SELECT stop FROM jobs WHERE job_id = ?", (self.job_id,))
+ cursor.execute(
+ "SELECT stop FROM Job WHERE id = ?", (self.job_id,))
stop = cursor.fetchone()
return False if stop is None else stop[0] == 1
-
+
# For this one we need a synchronous result, so we'll run it directly
return _check_stop()
def maybe_stop(self):
if self.should_stop():
- asyncio.run(self._update_status("stopped", "Job stopped"))
+ self._run_async_operation(
+ self._update_status("stopped", "Job stopped"))
self.is_stopping = True
raise Exception("Job stopped")
async def _update_step(self):
if not self.accelerator.is_main_process:
return
-
+
def _do_update():
with self._db_connect() as conn:
cursor = conn.cursor()
cursor.execute("BEGIN IMMEDIATE") # Get an immediate lock
try:
cursor.execute(
- "UPDATE jobs SET step = ? WHERE job_id = ?",
+ "UPDATE Job SET step = ? WHERE id = ?",
(self.step_num, self.job_id)
)
finally:
cursor.execute("COMMIT") # Release the lock
-
+
await self._execute_db_operation(_do_update)
def update_step(self):
"""Non-blocking update of the step count."""
if self.accelerator.is_main_process:
- # Start the async operation without waiting for it
- asyncio.create_task(self._update_step())
+ # Use the helper method to run the async operation
+ self._run_async_operation(self._update_step())
async def _update_status(self, status: AITK_Status, info: Optional[str] = None):
if not self.accelerator.is_main_process:
return
-
+
def _do_update():
with self._db_connect() as conn:
cursor = conn.cursor()
@@ -84,24 +106,24 @@ class UITrainer(SDTrainer):
try:
if info is not None:
cursor.execute(
- "UPDATE jobs SET status = ?, info = ? WHERE job_id = ?",
+ "UPDATE Job SET status = ?, info = ? WHERE id = ?",
(status, info, self.job_id)
)
else:
cursor.execute(
- "UPDATE jobs SET status = ? WHERE job_id = ?",
+ "UPDATE Job SET status = ? WHERE id = ?",
(status, self.job_id)
)
finally:
cursor.execute("COMMIT") # Release the lock
-
+
await self._execute_db_operation(_do_update)
def update_status(self, status: AITK_Status, info: Optional[str] = None):
"""Non-blocking update of status."""
if self.accelerator.is_main_process:
- # Start the async operation without waiting for it
- asyncio.create_task(self._update_status(status, info))
+ # Use the helper method to run the async operation
+ self._run_async_operation(self._update_status(status, info))
def on_error(self, e: Exception):
super(UITrainer, self).on_error(e)
@@ -121,30 +143,36 @@ class UITrainer(SDTrainer):
def hook_before_model_load(self):
super().hook_before_model_load()
+ self.maybe_stop()
self.update_status("running", "Loading model")
def before_dataset_load(self):
super().before_dataset_load()
+ self.maybe_stop()
self.update_status("running", "Loading dataset")
def hook_before_train_loop(self):
super().hook_before_train_loop()
+ self.maybe_stop()
self.update_status("running", "Training")
def sample_step_hook(self, img_num, total_imgs):
super().sample_step_hook(img_num, total_imgs)
- # subtract a since this is called after the image is generated
+ self.maybe_stop()
self.update_status(
- "running", f"Generating images - {img_num - 1} of {total_imgs}")
+ "running", f"Generating images - {img_num + 1}/{total_imgs}")
def sample(self, step=None, is_first=False):
self.maybe_stop()
total_imgs = len(self.sample_config.prompts)
- self.update_status("running", f"Generating images - 1 of {total_imgs}")
+ self.update_status("running", f"Generating images - 0/{total_imgs}")
super().sample(step, is_first)
+ self.maybe_stop()
self.update_status("running", "Training")
-
+
def save(self, step=None):
+ self.maybe_stop()
self.update_status("running", "Saving model")
super().save(step)
- self.update_status("running", "Training")
\ No newline at end of file
+ self.maybe_stop()
+ self.update_status("running", "Training")
diff --git a/jobs/process/BaseSDTrainProcess.py b/jobs/process/BaseSDTrainProcess.py
index a180aec5..150e1e2f 100644
--- a/jobs/process/BaseSDTrainProcess.py
+++ b/jobs/process/BaseSDTrainProcess.py
@@ -1428,7 +1428,7 @@ class BaseSDTrainProcess(BaseTrainProcess):
# run base sd process run
self.sd.load_model()
- self.sd.add_after_sample_image_hook(self.after_sample_image_hook)
+ self.sd.add_after_sample_image_hook(self.sample_step_hook)
dtype = get_torch_dtype(self.train_config.dtype)
diff --git a/run.py b/run.py
index ce3553a9..d4ccda2a 100644
--- a/run.py
+++ b/run.py
@@ -88,7 +88,10 @@ def main():
except Exception as e:
print_acc(f"Error running job: {e}")
jobs_failed += 1
- job.process[0].on_error(e)
+ try:
+ job.process[0].on_error(e)
+ except Exception as e2:
+ print_acc(f"Error running on_error: {e2}")
if not args.recover:
print_end_message(jobs_completed, jobs_failed)
raise e
diff --git a/ui/prisma/schema.prisma b/ui/prisma/schema.prisma
index 76a56899..b28c43bf 100644
--- a/ui/prisma/schema.prisma
+++ b/ui/prisma/schema.prisma
@@ -4,7 +4,7 @@ generator client {
datasource db {
provider = "sqlite"
- url = "file:../aitk_db.db"
+ url = "file:../../aitk_db.db"
}
model Settings {
@@ -13,9 +13,9 @@ model Settings {
value String
}
-model Training {
+model Job {
id String @id @default(uuid())
- name String
+ name String @unique
gpu_id Int
job_config String // JSON string
created_at DateTime @default(now())
diff --git a/ui/src/app/api/jobs/[jobID]/start/route.ts b/ui/src/app/api/jobs/[jobID]/start/route.ts
new file mode 100644
index 00000000..b706043f
--- /dev/null
+++ b/ui/src/app/api/jobs/[jobID]/start/route.ts
@@ -0,0 +1,90 @@
+import { NextRequest, NextResponse } from 'next/server';
+import { PrismaClient } from '@prisma/client';
+import { TOOLKIT_ROOT, defaultTrainFolder } from '@/paths';
+import { spawn } from 'child_process';
+import path from 'path';
+import fs from 'fs';
+
+
+const prisma = new PrismaClient();
+
+export async function GET(request: NextRequest, { params }: { params: { jobID: string } }) {
+ const { jobID } = await params;
+
+ const job = await prisma.job.findUnique({
+ where: { id: jobID },
+ });
+
+ if (!job) {
+ return NextResponse.json({ error: 'Job not found' }, { status: 404 });
+ }
+
+ // update job status to 'running'
+ await prisma.job.update({
+ where: { id: jobID },
+ data: {
+ status: 'running',
+ stop: false,
+ info: 'Starting job...',
+ },
+ });
+
+ // setup the training
+ const settings = await prisma.settings.findMany();
+ const settingsObject = settings.reduce((acc: any, setting) => {
+ acc[setting.key] = setting.value;
+ return acc;
+ }, {});
+
+ // if TRAINING_FOLDER is not set, use default
+ if (!settingsObject.TRAINING_FOLDER || settingsObject.TRAINING_FOLDER === '') {
+ settingsObject.TRAINING_FOLDER = defaultTrainFolder;
+ }
+
+ const trainingFolder = path.join(settingsObject.TRAINING_FOLDER, job.name);
+ if (!fs.existsSync(trainingFolder)) {
+ fs.mkdirSync(trainingFolder, { recursive: true });
+ }
+
+ // make the config file
+ const configPath = path.join(trainingFolder, '.job_config.json');
+
+ // update the config dataset path
+ const jobConfig = JSON.parse(job.job_config);
+ jobConfig.config.process[0].sqlite_db_path = path.join(TOOLKIT_ROOT, 'aitk_db.db');
+
+
+ // write the config file
+ fs.writeFileSync(configPath, JSON.stringify(jobConfig, null, 2));
+
+ let pythonPath = 'python';
+ // use .venv or venv if it exists
+ if (fs.existsSync(path.join(TOOLKIT_ROOT, '.venv'))) {
+ pythonPath = path.join(TOOLKIT_ROOT, '.venv', 'bin', 'python');
+ } else if (fs.existsSync(path.join(TOOLKIT_ROOT, 'venv'))) {
+ pythonPath = path.join(TOOLKIT_ROOT, 'venv', 'bin', 'python');
+ }
+
+ const runFilePath = path.join(TOOLKIT_ROOT, 'run.py');
+ if (!fs.existsSync(runFilePath)) {
+ return NextResponse.json({ error: 'run.py not found' }, { status: 500 });
+ }
+
+ console.log('Spawning command:', `AITK_JOB_ID=${jobID} CUDA_VISIBLE_DEVICES=${job.gpu_id} ${pythonPath} ${runFilePath} ${configPath}`);
+
+ // start job
+ const subprocess = spawn(pythonPath, [runFilePath, configPath], {
+ detached: true,
+ stdio: 'ignore',
+ env: {
+ ...process.env,
+ AITK_JOB_ID: jobID,
+ CUDA_VISIBLE_DEVICES: `${job.gpu_id}`,
+ },
+ cwd: TOOLKIT_ROOT,
+ });
+
+ subprocess.unref();
+
+ return NextResponse.json(job);
+}
diff --git a/ui/src/app/api/jobs/[jobID]/stop/route.ts b/ui/src/app/api/jobs/[jobID]/stop/route.ts
new file mode 100644
index 00000000..73b352df
--- /dev/null
+++ b/ui/src/app/api/jobs/[jobID]/stop/route.ts
@@ -0,0 +1,23 @@
+import { NextRequest, NextResponse } from 'next/server';
+import { PrismaClient } from '@prisma/client';
+
+const prisma = new PrismaClient();
+
+export async function GET(request: NextRequest, { params }: { params: { jobID: string } }) {
+ const { jobID } = await params;
+
+ const job = await prisma.job.findUnique({
+ where: { id: jobID },
+ });
+
+ // mark the job as stopping: set the stop flag so the trainer's should_stop() poll picks it up
+ await prisma.job.update({
+ where: { id: jobID },
+ data: {
+ stop: true,
+ info: 'Stopping job...',
+ },
+ });
+
+ return NextResponse.json(job);
+}
diff --git a/ui/src/app/api/training/route.ts b/ui/src/app/api/jobs/route.ts
similarity index 84%
rename from ui/src/app/api/training/route.ts
rename to ui/src/app/api/jobs/route.ts
index 64d3d73a..3e0df9f0 100644
--- a/ui/src/app/api/training/route.ts
+++ b/ui/src/app/api/jobs/route.ts
@@ -9,17 +9,18 @@ export async function GET(request: Request) {
try {
if (id) {
- const training = await prisma.training.findUnique({
+ const training = await prisma.job.findUnique({
where: { id },
});
return NextResponse.json(training);
}
- const trainings = await prisma.training.findMany({
+ const trainings = await prisma.job.findMany({
orderBy: { created_at: 'desc' },
});
return NextResponse.json(trainings);
} catch (error) {
+ console.error(error);
return NextResponse.json({ error: 'Failed to fetch training data' }, { status: 500 });
}
}
@@ -31,7 +32,7 @@ export async function POST(request: Request) {
if (id) {
// Update existing training
- const training = await prisma.training.update({
+ const training = await prisma.job.update({
where: { id },
data: {
name,
@@ -42,7 +43,7 @@ export async function POST(request: Request) {
return NextResponse.json(training);
} else {
// Create new training
- const training = await prisma.training.create({
+ const training = await prisma.job.create({
data: {
name,
gpu_id,
diff --git a/ui/src/app/jobs/[jobID]/page.tsx b/ui/src/app/jobs/[jobID]/page.tsx
new file mode 100644
index 00000000..4929804f
--- /dev/null
+++ b/ui/src/app/jobs/[jobID]/page.tsx
@@ -0,0 +1,72 @@
+'use client';
+
+import { useEffect, useState, use } from 'react';
+import { FaChevronLeft } from 'react-icons/fa';
+import { Button } from '@headlessui/react';
+import { TopBar, MainContent } from '@/components/layout';
+import useJob from '@/hooks/useJob';
+import { startJob, stopJob } from '@/utils/jobs';
+
+export default function JobPage({ params }: { params: { jobID: string } }) {
+ const usableParams = use(params as any) as { jobID: string };
+ const jobID = usableParams.jobID;
+ const { job, status, refreshJobs } = useJob(jobID, 5000);
+
+ return (
+ <>
+ {/* Fixed top bar */}
+ Loading... Error fetching job ID: {job.id} Name: {job.name} GPU: {job.gpu_id} Status: {job.status} Info: {job.info} Step: {job.step}Job: {job?.name}
+ Job Details
+ {runId ? 'Edit Training Run' : 'New Training Run'}
+
+ {runId ? 'Edit Training Job' : 'New Training Job'}
Training Jobs
+
No jobs available
+ +| Name | +Steps | +GPU | +Status | +Info | +
|---|---|---|---|---|
| + {job.name} | +
+
+
+ {job.step} / {totalSteps}
+
+
+
+
+
+ |
+ {job.gpu_id} | +{job.status} | +{job.info} | +