diff --git a/extensions_built_in/sd_trainer/DiffusionTrainer.py b/extensions_built_in/sd_trainer/DiffusionTrainer.py index 730a1aee..f39611b1 100644 --- a/extensions_built_in/sd_trainer/DiffusionTrainer.py +++ b/extensions_built_in/sd_trainer/DiffusionTrainer.py @@ -124,6 +124,19 @@ class DiffusionTrainer(SDTrainer): return _check_stop() + def should_return_to_queue(self): + if not self.is_ui_trainer: + return False + def _check_return_to_queue(): + with self._db_connect() as conn: + cursor = conn.cursor() + cursor.execute( + "SELECT return_to_queue FROM Job WHERE id = ?", (self.job_id,)) + return_to_queue = cursor.fetchone() + return False if return_to_queue is None else return_to_queue[0] == 1 + + return _check_return_to_queue() + def maybe_stop(self): if not self.is_ui_trainer: return @@ -132,6 +145,11 @@ class DiffusionTrainer(SDTrainer): self._update_status("stopped", "Job stopped")) self.is_stopping = True raise Exception("Job stopped") + if self.should_return_to_queue(): + self._run_async_operation( + self._update_status("queued", "Job queued")) + self.is_stopping = True + raise Exception("Job returning to queue") async def _update_key(self, key, value): if not self.accelerator.is_main_process: diff --git a/extensions_built_in/sd_trainer/UITrainer.py b/extensions_built_in/sd_trainer/UITrainer.py index 96a162e4..8b5fa796 100644 --- a/extensions_built_in/sd_trainer/UITrainer.py +++ b/extensions_built_in/sd_trainer/UITrainer.py @@ -115,6 +115,17 @@ class UITrainer(SDTrainer): return False if stop is None else stop[0] == 1 return _check_stop() + + def should_return_to_queue(self): + def _check_return_to_queue(): + with self._db_connect() as conn: + cursor = conn.cursor() + cursor.execute( + "SELECT return_to_queue FROM Job WHERE id = ?", (self.job_id,)) + return_to_queue = cursor.fetchone() + return False if return_to_queue is None else return_to_queue[0] == 1 + + return _check_return_to_queue() def maybe_stop(self): if self.should_stop(): @@ -122,6 +133,11 @@ class UITrainer(SDTrainer): self._update_status("stopped", "Job stopped")) self.is_stopping = True raise Exception("Job stopped") + if self.should_return_to_queue(): + self._run_async_operation( + self._update_status("queued", "Job queued")) + self.is_stopping = True + raise Exception("Job returning to queue") async def _update_key(self, key, value): if not self.accelerator.is_main_process: diff --git a/ui/cron/actions/processQueue.ts b/ui/cron/actions/processQueue.ts new file mode 100644 index 00000000..175f613e --- /dev/null +++ b/ui/cron/actions/processQueue.ts @@ -0,0 +1,71 @@ +import prisma from '../prisma'; + +import { Job, Queue } from '@prisma/client'; +import startJob from './startJob'; + +export default async function processQueue() { + const queues: Queue[] = await prisma.queue.findMany({ + orderBy: { + id: 'asc', + }, + }); + + for (const queue of queues) { + if (!queue.is_running) { + // stop any running jobs first + const runningJobs: Job[] = await prisma.job.findMany({ + where: { + status: 'running', + gpu_ids: queue.gpu_ids, + }, + }); + + for (const job of runningJobs) { + console.log(`Stopping job ${job.id} on GPU(s) ${job.gpu_ids}`); + await prisma.job.update({ + where: { id: job.id }, + data: { + return_to_queue: true, + info: 'Stopping job...', + }, + }); + } + } + if (queue.is_running) { + // first see if one is already running, status of running or stopping + const runningJob: Job | null = await prisma.job.findFirst({ + where: { + status: { in: ['running', 'stopping'] }, + gpu_ids: queue.gpu_ids, + }, + }); + + if (runningJob) { + // already running, nothing to do + continue; // skip to next queue + } else { + // find the next job in the queue + const nextJob: Job | null = await prisma.job.findFirst({ + where: { + status: 'queued', + gpu_ids: queue.gpu_ids, + }, + orderBy: { + queue_position: 'asc', + }, + }); + if (nextJob) { + console.log(`Starting job ${nextJob.id} on GPU(s) ${nextJob.gpu_ids}`); + await startJob(nextJob.id); + } else { + // no more jobs, stop the queue + console.log(`No more jobs in queue for GPU(s) ${queue.gpu_ids}, stopping queue`); + await prisma.queue.update({ + where: { id: queue.id }, + data: { is_running: false }, + }); + } + } + } + } +} diff --git a/ui/cron/actions/startJob.ts b/ui/cron/actions/startJob.ts new file mode 100644 index 00000000..3a609a30 --- /dev/null +++ b/ui/cron/actions/startJob.ts @@ -0,0 +1,179 @@ +import prisma from '../prisma'; +import { Job } from '@prisma/client'; +import { spawn } from 'child_process'; +import path from 'path'; +import fs from 'fs'; +import { TOOLKIT_ROOT, getTrainingFolder, getHFToken } from '../paths'; +const isWindows = process.platform === 'win32'; + +const startAndWatchJob = (job: Job) => { + // starts and watches the job asynchronously + return new Promise(async (resolve, reject) => { + const jobID = job.id; + + // setup the training + const trainingRoot = await getTrainingFolder(); + + const trainingFolder = path.join(trainingRoot, job.name); + if (!fs.existsSync(trainingFolder)) { + fs.mkdirSync(trainingFolder, { recursive: true }); + } + + // make the config file + const configPath = path.join(trainingFolder, '.job_config.json'); + + //log to path + const logPath = path.join(trainingFolder, 'log.txt'); + + try { + // if the log path exists, move it to a folder called logs and rename it {num}_log.txt, looking for the highest num + // if the log path does not exist, create it + if (fs.existsSync(logPath)) { + const logsFolder = path.join(trainingFolder, 'logs'); + if (!fs.existsSync(logsFolder)) { + fs.mkdirSync(logsFolder, { recursive: true }); + } + + let num = 0; + while (fs.existsSync(path.join(logsFolder, `${num}_log.txt`))) { + num++; + } + + fs.renameSync(logPath, path.join(logsFolder, `${num}_log.txt`)); + } + } catch (e) { + console.error('Error moving log file:', e); + } + + // update the config dataset path + const jobConfig = JSON.parse(job.job_config); + jobConfig.config.process[0].sqlite_db_path = path.join(TOOLKIT_ROOT, 'aitk_db.db'); + + // write the config file + fs.writeFileSync(configPath, JSON.stringify(jobConfig, null, 2)); + + let pythonPath = 'python'; + // use .venv or venv if it exists + if (fs.existsSync(path.join(TOOLKIT_ROOT, '.venv'))) { + if (isWindows) { + pythonPath = path.join(TOOLKIT_ROOT, '.venv', 'Scripts', 'python.exe'); + } else { + pythonPath = path.join(TOOLKIT_ROOT, '.venv', 'bin', 'python'); + } + } else if (fs.existsSync(path.join(TOOLKIT_ROOT, 'venv'))) { + if (isWindows) { + pythonPath = path.join(TOOLKIT_ROOT, 'venv', 'Scripts', 'python.exe'); + } else { + pythonPath = path.join(TOOLKIT_ROOT, 'venv', 'bin', 'python'); + } + } + + const runFilePath = path.join(TOOLKIT_ROOT, 'run.py'); + if (!fs.existsSync(runFilePath)) { + console.error(`run.py not found at path: ${runFilePath}`); + await prisma.job.update({ + where: { id: jobID }, + data: { + status: 'error', + info: `Error launching job: run.py not found`, + }, + }); + return; + } + + const additionalEnv: any = { + AITK_JOB_ID: jobID, + CUDA_DEVICE_ORDER: 'PCI_BUS_ID', + CUDA_VISIBLE_DEVICES: `${job.gpu_ids}`, + IS_AI_TOOLKIT_UI: '1', + }; + + // HF_TOKEN + const hfToken = await getHFToken(); + if (hfToken && hfToken.trim() !== '') { + additionalEnv.HF_TOKEN = hfToken; + } + + // Add the --log argument to the command + const args = [runFilePath, configPath, '--log', logPath]; + + try { + let subprocess; + + if (isWindows) { + // Spawn Python directly on Windows so the process can survive parent exit + subprocess = spawn(pythonPath, args, { + env: { + ...process.env, + ...additionalEnv, + }, + cwd: TOOLKIT_ROOT, + detached: true, + windowsHide: true, + stdio: 'ignore', // don't tie stdio to parent + }); + } else { + // For non-Windows platforms, fully detach and ignore stdio so it survives daemon-like + subprocess = spawn(pythonPath, args, { + detached: true, + stdio: 'ignore', + env: { + ...process.env, + ...additionalEnv, + }, + cwd: TOOLKIT_ROOT, + }); + } + + // Important: let the child run independently of this Node process. + if (subprocess.unref) { + subprocess.unref(); + } + + // Optionally write a pid file for future management (stop/inspect) without keeping streams open + try { + fs.writeFileSync(path.join(trainingFolder, 'pid.txt'), String(subprocess.pid ?? ''), { flag: 'w' }); + } catch (e) { + console.error('Error writing pid file:', e); + } + + // (No stdout/stderr listeners — logging should go to --log handled by your Python) + // (No monitoring loop — the whole point is to let it live past this worker) + } catch (error: any) { + // Handle any exceptions during process launch + console.error('Error launching process:', error); + + await prisma.job.update({ + where: { id: jobID }, + data: { + status: 'error', + info: `Error launching job: ${error?.message || 'Unknown error'}`, + }, + }); + return; + } + // Resolve the promise immediately after starting the process + resolve(); + }); +}; + +export default async function startJob(jobID: string) { + const job: Job | null = await prisma.job.findUnique({ + where: { id: jobID }, + }); + if (!job) { + console.error(`Job with ID ${jobID} not found`); + return; + } + // update job status to 'running', this will run sync so we don't start multiple jobs. + await prisma.job.update({ + where: { id: jobID }, + data: { + status: 'running', + stop: false, + info: 'Starting job...', + }, + }); + // start and watch the job asynchronously so the cron can continue + startAndWatchJob(job); +} diff --git a/ui/cron/paths.ts b/ui/cron/paths.ts new file mode 100644 index 00000000..ef28b973 --- /dev/null +++ b/ui/cron/paths.ts @@ -0,0 +1,37 @@ +import path from 'path'; +import prisma from './prisma'; + +export const TOOLKIT_ROOT = path.resolve('@', '..', '..'); +export const defaultTrainFolder = path.join(TOOLKIT_ROOT, 'output'); +export const defaultDatasetsFolder = path.join(TOOLKIT_ROOT, 'datasets'); +export const defaultDataRoot = path.join(TOOLKIT_ROOT, 'data'); + +console.log('TOOLKIT_ROOT:', TOOLKIT_ROOT); + +export const getTrainingFolder = async () => { + const key = 'TRAINING_FOLDER'; + let row = await prisma.settings.findFirst({ + where: { + key: key, + }, + }); + let trainingRoot = defaultTrainFolder; + if (row?.value && row.value !== '') { + trainingRoot = row.value; + } + return trainingRoot as string; +}; + +export const getHFToken = async () => { + const key = 'HF_TOKEN'; + let row = await prisma.settings.findFirst({ + where: { + key: key, + }, + }); + let token = ''; + if (row?.value && row.value !== '') { + token = row.value; + } + return token; +}; diff --git a/ui/cron/prisma.ts b/ui/cron/prisma.ts new file mode 100644 index 00000000..56d96d4d --- /dev/null +++ b/ui/cron/prisma.ts @@ -0,0 +1,4 @@ +import { PrismaClient } from '@prisma/client'; +const prisma = new PrismaClient(); + +export default prisma; diff --git a/ui/cron/worker.ts b/ui/cron/worker.ts index 589393a4..dd1c275d 100644 --- a/ui/cron/worker.ts +++ b/ui/cron/worker.ts @@ -1,3 +1,4 @@ +import processQueue from './actions/processQueue'; class CronWorker { interval: number; is_running: boolean; @@ -23,7 +24,9 @@ class CronWorker { this.is_running = false; } - async loop() {} + async loop() { + await processQueue(); + } } // it automatically starts the loop diff --git a/ui/package.json b/ui/package.json index 683e8cbc..1a5cdfdb 100644 --- a/ui/package.json +++ b/ui/package.json @@ -3,9 +3,9 @@ "version": "0.1.0", "private": true, "scripts": { - "dev": "concurrently -k -n WORKER,UI \"ts-node-dev --respawn --watch cron --transpile-only cron/worker.ts\" \"next dev --turbopack\"", + "dev": "concurrently -k -n WORKER,UI \"ts-node-dev --project tsconfig.worker.json --respawn --watch cron --transpile-only cron/worker.ts\" \"next dev --turbopack\"", "build": "tsc -p tsconfig.worker.json && next build", - "start": "concurrently --restart-tries -1 --restart-after 1000 -n WORKER,UI \"node dist/worker.js\" \"next start --port 8675\"", + "start": "concurrently --restart-tries -1 --restart-after 1000 -n WORKER,UI \"node dist/cron/worker.js\" \"next start --port 8675\"", "build_and_start": "npm install && npm run update_db && npm run build && npm run start", "lint": "next lint", "update_db": "npx prisma generate && npx prisma db push", diff --git a/ui/prisma/schema.prisma b/ui/prisma/schema.prisma index 96f6399b..f0429332 100644 --- a/ui/prisma/schema.prisma +++ b/ui/prisma/schema.prisma @@ -13,26 +13,29 @@ model Settings { value String } -model Job { - id String @id @default(uuid()) - name String @unique - gpu_ids String - job_config String // JSON string - created_at DateTime @default(now()) - updated_at DateTime @updatedAt - status String @default("stopped") - stop Boolean @default(false) - step Int @default(0) - info String @default("") - speed_string String @default("") +model Queue { + id Int @id @default(autoincrement()) + gpu_ids String @unique + is_running Boolean @default(false) + + @@index([gpu_ids]) } -model Queue { - id String @id @default(uuid()) - channel String - job_id String - created_at DateTime @default(now()) - updated_at DateTime @updatedAt - status String @default("waiting") - @@index([job_id, channel]) -} \ No newline at end of file +model Job { + id String @id @default(uuid()) + name String @unique + gpu_ids String + job_config String // JSON string + created_at DateTime @default(now()) + updated_at DateTime @updatedAt + status String @default("stopped") + stop Boolean @default(false) + return_to_queue Boolean @default(false) // same as stop, but will be set to 'queued' when stopped + step Int @default(0) + info String @default("") + speed_string String @default("") + queue_position Int @default(0) + + @@index([status]) + @@index([gpu_ids]) +} diff --git a/ui/src/app/api/gpu/route.ts b/ui/src/app/api/gpu/route.ts index 8b11dbb0..54255848 100644 --- a/ui/src/app/api/gpu/route.ts +++ b/ui/src/app/api/gpu/route.ts @@ -65,7 +65,9 @@ async function getGpuStats(isWindows: boolean) { 'nvidia-smi --query-gpu=index,name,driver_version,temperature.gpu,utilization.gpu,utilization.memory,memory.total,memory.free,memory.used,power.draw,power.limit,clocks.current.graphics,clocks.current.memory,fan.speed --format=csv,noheader,nounits'; // Execute command - const { stdout } = await execAsync(command); + const { stdout } = await execAsync(command, { + env: { ...process.env, CUDA_DEVICE_ORDER: 'PCI_BUS_ID' }, + }); // Parse CSV output const gpus = stdout diff --git a/ui/src/app/api/jobs/[jobID]/start/route.ts b/ui/src/app/api/jobs/[jobID]/start/route.ts index e26c1e49..0417dcff 100644 --- a/ui/src/app/api/jobs/[jobID]/start/route.ts +++ b/ui/src/app/api/jobs/[jobID]/start/route.ts @@ -1,12 +1,5 @@ import { NextRequest, NextResponse } from 'next/server'; import { PrismaClient } from '@prisma/client'; -import { TOOLKIT_ROOT } from '@/paths'; -import { spawn } from 'child_process'; -import path from 'path'; -import fs from 'fs'; -import os from 'os'; -import { getTrainingFolder, getHFToken } from '@/server/settings'; -const isWindows = process.platform === 'win32'; const prisma = new PrismaClient(); @@ -21,195 +14,46 @@ export async function GET(request: NextRequest, { params }: { params: { jobID: s return NextResponse.json({ error: 'Job not found' }, { status: 404 }); } - // update job status to 'running' + // get highest queue position + const highestQueuePosition = await prisma.job.aggregate({ + _max: { + queue_position: true, + }, + }); + const newQueuePosition = (highestQueuePosition._max.queue_position || 0) + 1000; + await prisma.job.update({ where: { id: jobID }, - data: { - status: 'running', - stop: false, - info: 'Starting job...', + data: { queue_position: newQueuePosition }, + }); + + // make sure the queue is running + const queue = await prisma.queue.findFirst({ + where: { + gpu_ids: job.gpu_ids, }, }); - // setup the training - const trainingRoot = await getTrainingFolder(); - - const trainingFolder = path.join(trainingRoot, job.name); - if (!fs.existsSync(trainingFolder)) { - fs.mkdirSync(trainingFolder, { recursive: true }); - } - - // make the config file - const configPath = path.join(trainingFolder, '.job_config.json'); - - //log to path - const logPath = path.join(trainingFolder, 'log.txt'); - - try { - // if the log path exists, move it to a folder called logs and rename it {num}_log.txt, looking for the highest num - // if the log path does not exist, create it - if (fs.existsSync(logPath)) { - const logsFolder = path.join(trainingFolder, 'logs'); - if (!fs.existsSync(logsFolder)) { - fs.mkdirSync(logsFolder, { recursive: true }); - } - - let num = 0; - while (fs.existsSync(path.join(logsFolder, `${num}_log.txt`))) { - num++; - } - - fs.renameSync(logPath, path.join(logsFolder, `${num}_log.txt`)); - } - } catch (e) { - console.error('Error moving log file:', e); - } - - // update the config dataset path - const jobConfig = JSON.parse(job.job_config); - jobConfig.config.process[0].sqlite_db_path = path.join(TOOLKIT_ROOT, 'aitk_db.db'); - - // write the config file - fs.writeFileSync(configPath, JSON.stringify(jobConfig, null, 2)); - - let pythonPath = 'python'; - // use .venv or venv if it exists - if (fs.existsSync(path.join(TOOLKIT_ROOT, '.venv'))) { - if (isWindows) { - pythonPath = path.join(TOOLKIT_ROOT, '.venv', 'Scripts', 'python.exe'); - } else { - pythonPath = path.join(TOOLKIT_ROOT, '.venv', 'bin', 'python'); - } - } else if (fs.existsSync(path.join(TOOLKIT_ROOT, 'venv'))) { - if (isWindows) { - pythonPath = path.join(TOOLKIT_ROOT, 'venv', 'Scripts', 'python.exe'); - } else { - pythonPath = path.join(TOOLKIT_ROOT, 'venv', 'bin', 'python'); - } - } - - const runFilePath = path.join(TOOLKIT_ROOT, 'run.py'); - if (!fs.existsSync(runFilePath)) { - return NextResponse.json({ error: 'run.py not found' }, { status: 500 }); - } - - const additionalEnv: any = { - AITK_JOB_ID: jobID, - CUDA_VISIBLE_DEVICES: `${job.gpu_ids}`, - IS_AI_TOOLKIT_UI: '1' - }; - - // HF_TOKEN - const hfToken = await getHFToken(); - if (hfToken && hfToken.trim() !== '') { - additionalEnv.HF_TOKEN = hfToken; - } - - // Add the --log argument to the command - const args = [runFilePath, configPath, '--log', logPath]; - - try { - let subprocess; - - if (isWindows) { - // For Windows, use 'cmd.exe' to open a new command window - subprocess = spawn('cmd.exe', ['/c', 'start', 'cmd.exe', '/k', pythonPath, ...args], { - env: { - ...process.env, - ...additionalEnv, - }, - cwd: TOOLKIT_ROOT, - windowsHide: false, - }); - } else { - // For non-Windows platforms - subprocess = spawn(pythonPath, args, { - detached: true, - stdio: ['ignore', 'pipe', 'pipe'], // Changed from 'ignore' to capture output - env: { - ...process.env, - ...additionalEnv, - }, - cwd: TOOLKIT_ROOT, - }); - } - - // Start monitoring in the background without blocking the response - const monitorProcess = async () => { - const startTime = Date.now(); - let errorOutput = ''; - let stdoutput = ''; - - if (subprocess.stderr) { - subprocess.stderr.on('data', data => { - errorOutput += data.toString(); - }); - subprocess.stdout.on('data', data => { - stdoutput += data.toString(); - // truncate to only get the last 500 characters - if (stdoutput.length > 500) { - stdoutput = stdoutput.substring(stdoutput.length - 500); - } - }); - } - - subprocess.on('exit', async code => { - const currentTime = Date.now(); - const duration = (currentTime - startTime) / 1000; - console.log(`Job ${jobID} exited with code ${code} after ${duration} seconds.`); - // wait for 5 seconds to give it time to stop itself. It id still has a status of running in the db, update it to stopped - await new Promise(resolve => setTimeout(resolve, 5000)); - const updatedJob = await prisma.job.findUnique({ - where: { id: jobID }, - }); - if (updatedJob?.status === 'running') { - let errorString = errorOutput; - if (errorString.trim() === '') { - errorString = stdoutput; - } - await prisma.job.update({ - where: { id: jobID }, - data: { - status: 'error', - info: `Error launching job: ${errorString.substring(0, 500)}`, - }, - }); - } - }); - - // Wait 30 seconds before releasing the process - await new Promise(resolve => setTimeout(resolve, 30000)); - // Detach the process for non-Windows systems - if (!isWindows && subprocess.unref) { - subprocess.unref(); - } - }; - - // Start the monitoring without awaiting it - monitorProcess().catch(err => { - console.error(`Error in process monitoring for job ${jobID}:`, err); - }); - - // Return the response immediately - return NextResponse.json(job); - } catch (error: any) { - // Handle any exceptions during process launch - console.error('Error launching process:', error); - - await prisma.job.update({ - where: { id: jobID }, + // if queue doesn't exist, create it + if (!queue) { + await prisma.queue.create({ data: { - status: 'error', - info: `Error launching job: ${error?.message || 'Unknown error'}`, + gpu_ids: job.gpu_ids, + is_running: false, }, }); - - return NextResponse.json( - { - error: 'Failed to launch job process', - details: error?.message || 'Unknown error', - }, - { status: 500 }, - ); } + + await prisma.job.update({ + where: { id: jobID }, + data: { + status: 'queued', + stop: false, + return_to_queue: false, + info: 'Job queued', + }, + }); + + // Return the response immediately + return NextResponse.json(job); } diff --git a/ui/src/app/api/jobs/route.ts b/ui/src/app/api/jobs/route.ts index c489088d..56ca8501 100644 --- a/ui/src/app/api/jobs/route.ts +++ b/ui/src/app/api/jobs/route.ts @@ -42,12 +42,21 @@ export async function POST(request: Request) { }); return NextResponse.json(training); } else { + // find the highest queue position and add 1000 + const highestQueuePosition = await prisma.job.aggregate({ + _max: { + queue_position: true, + }, + }); + const newQueuePosition = (highestQueuePosition._max.queue_position || 0) + 1000; + // Create new training const training = await prisma.job.create({ data: { name, gpu_ids, job_config: JSON.stringify(job_config), + queue_position: newQueuePosition, }, }); return NextResponse.json(training); diff --git a/ui/src/app/api/queue/[queueID]/start/route.ts b/ui/src/app/api/queue/[queueID]/start/route.ts new file mode 100644 index 00000000..d67ff4ec --- /dev/null +++ b/ui/src/app/api/queue/[queueID]/start/route.ts @@ -0,0 +1,27 @@ +import { NextRequest, NextResponse } from 'next/server'; +import { PrismaClient } from '@prisma/client'; + +const prisma = new PrismaClient(); + +export async function GET(request: NextRequest, { params }: { params: { queueID: string } }) { + const { queueID } = await params; + + const queue = await prisma.queue.findUnique({ + where: { gpu_ids: queueID }, + }); + + if (!queue) { + // create it if it doesn't exist + const newQueue = await prisma.queue.create({ + data: { gpu_ids: queueID, is_running: true }, + }); + return NextResponse.json(newQueue); + } + + await prisma.queue.update({ + where: { id: queue.id }, + data: { is_running: true }, + }); + + return NextResponse.json(queue); +} diff --git a/ui/src/app/api/queue/[queueID]/stop/route.ts b/ui/src/app/api/queue/[queueID]/stop/route.ts new file mode 100644 index 00000000..87e608b1 --- /dev/null +++ b/ui/src/app/api/queue/[queueID]/stop/route.ts @@ -0,0 +1,23 @@ +import { NextRequest, NextResponse } from 'next/server'; +import { PrismaClient } from '@prisma/client'; + +const prisma = new PrismaClient(); + +export async function GET(request: NextRequest, { params }: { params: { queueID: string } }) { + const { queueID } = await params; + + const queue = await prisma.queue.findUnique({ + where: { gpu_ids: queueID }, + }); + + if (!queue) { + return NextResponse.json({ error: 'Queue not found' }, { status: 404 }); + } + + await prisma.queue.update({ + where: { id: queue.id }, + data: { is_running: false }, + }); + + return NextResponse.json(queue); +} diff --git a/ui/src/app/api/queue/route.ts b/ui/src/app/api/queue/route.ts new file mode 100644 index 00000000..08c6d4bc --- /dev/null +++ b/ui/src/app/api/queue/route.ts @@ -0,0 +1,18 @@ +import { NextResponse } from 'next/server'; +import { PrismaClient } from '@prisma/client'; + +const prisma = new PrismaClient(); + +export async function GET(request: Request) { + const { searchParams } = new URL(request.url); + + try { + const queues = await prisma.queue.findMany({ + orderBy: { gpu_ids: 'asc' }, + }); + return NextResponse.json({ queues: queues }); + } catch (error) { + console.error(error); + return NextResponse.json({ error: 'Failed to fetch queue' }, { status: 500 }); + } +} diff --git a/ui/src/app/dashboard/page.tsx b/ui/src/app/dashboard/page.tsx index e6965ac3..bb644a5a 100644 --- a/ui/src/app/dashboard/page.tsx +++ b/ui/src/app/dashboard/page.tsx @@ -18,7 +18,7 @@ export default function Dashboard() {
-

Active Jobs

+

Queues

View All
diff --git a/ui/src/app/jobs/[jobID]/page.tsx b/ui/src/app/jobs/[jobID]/page.tsx index 96fc1936..d66f9cf5 100644 --- a/ui/src/app/jobs/[jobID]/page.tsx +++ b/ui/src/app/jobs/[jobID]/page.tsx @@ -5,7 +5,7 @@ import { FaChevronLeft } from 'react-icons/fa'; import { Button } from '@headlessui/react'; import { TopBar, MainContent } from '@/components/layout'; import useJob from '@/hooks/useJob'; -import SampleImages, {SampleImagesMenu} from '@/components/SampleImages'; +import SampleImages, { SampleImagesMenu } from '@/components/SampleImages'; import JobOverview from '@/components/JobOverview'; import { redirect } from 'next/navigation'; import JobActionBar from '@/components/JobActionBar'; @@ -73,6 +73,7 @@ export default function JobPage({ params }: { params: { jobID: string } }) { afterDelete={() => { redirect('/jobs'); }} + autoStartQueue={true} /> )} @@ -98,15 +99,12 @@ export default function JobPage({ params }: { params: { jobID: string } }) { {page.name} ))} - { - page?.menuItem && ( - <> -
-
- - - ) - } + {page?.menuItem && ( + <> +
+ + + )}
); diff --git a/ui/src/app/jobs/page.tsx b/ui/src/app/jobs/page.tsx index 5e80b313..d34e5bcf 100644 --- a/ui/src/app/jobs/page.tsx +++ b/ui/src/app/jobs/page.tsx @@ -9,7 +9,7 @@ export default function Dashboard() { <>
-

Training Jobs

+

Training Queue

diff --git a/ui/src/components/GPUWidget.tsx b/ui/src/components/GPUWidget.tsx index e21b0190..8c98f9c7 100644 --- a/ui/src/components/GPUWidget.tsx +++ b/ui/src/components/GPUWidget.tsx @@ -1,6 +1,6 @@ import React from 'react'; import { GpuInfo } from '@/types'; -import { ChevronRight, Thermometer, Zap, Clock, HardDrive, Fan, Cpu } from 'lucide-react'; +import { Thermometer, Zap, Clock, HardDrive, Fan, Cpu } from 'lucide-react'; interface GPUWidgetProps { gpu: GpuInfo; @@ -24,7 +24,7 @@ export default function GPUWidget({ gpu }: GPUWidgetProps) {

{gpu.name}

- #{gpu.index} + # {gpu.index}
diff --git a/ui/src/components/JobActionBar.tsx b/ui/src/components/JobActionBar.tsx index c6500c97..8ea3611b 100644 --- a/ui/src/components/JobActionBar.tsx +++ b/ui/src/components/JobActionBar.tsx @@ -1,9 +1,10 @@ import Link from 'next/link'; -import { Eye, Trash2, Pen, Play, Pause, Cog } from 'lucide-react'; +import { Eye, Trash2, Pen, Play, Pause, Cog, X } from 'lucide-react'; import { Button } from '@headlessui/react'; import { openConfirm } from '@/components/ConfirmModal'; import { Job } from '@prisma/client'; import { startJob, stopJob, deleteJob, getAvaliableJobActions, markJobAsStopped } from '@/utils/jobs'; +import { startQueue } from '@/utils/queue'; import { Menu, MenuButton, MenuItem, MenuItems } from '@headlessui/react'; interface JobActionBarProps { @@ -12,10 +13,18 @@ interface JobActionBarProps { afterDelete?: () => void; hideView?: boolean; className?: string; + autoStartQueue?: boolean; } -export default function JobActionBar({ job, onRefresh, afterDelete, className, hideView }: JobActionBarProps) { - const { canStart, canStop, canDelete, canEdit } = getAvaliableJobActions(job); +export default function JobActionBar({ + job, + onRefresh, + afterDelete, + className, + hideView, + autoStartQueue = false, +}: JobActionBarProps) { + const { canStart, canStop, canDelete, canEdit, canRemoveFromQueue } = getAvaliableJobActions(job); if (!afterDelete) afterDelete = onRefresh; @@ -26,6 +35,10 @@ export default function JobActionBar({ job, onRefresh, afterDelete, className, h onClick={async () => { if (!canStart) return; await startJob(job.id); + // start the queue as well + if (autoStartQueue) { + await startQueue(job.gpu_ids); + } if (onRefresh) onRefresh(); }} className={`ml-2 opacity-100`} @@ -33,6 +46,18 @@ export default function JobActionBar({ job, onRefresh, afterDelete, className, h )} + {canRemoveFromQueue && ( + + )} {canStop && ( + + ) : ( + <> + Queue Stopped + + + )} +
+
+ + + ); + })} + {!onlyActive && Object.keys(jobsDict).includes('Idle') && ( +
+
+
+

Idle

+
+
+ +
+ )} + + ); } diff --git a/ui/src/components/Sidebar.tsx b/ui/src/components/Sidebar.tsx index a5b3e2d6..ca6a21ca 100644 --- a/ui/src/components/Sidebar.tsx +++ b/ui/src/components/Sidebar.tsx @@ -1,17 +1,18 @@ import Link from 'next/link'; -import { Home, Settings, BrainCircuit, Images, Plus} from 'lucide-react'; -import { FaXTwitter, FaDiscord, FaYoutube } from "react-icons/fa6"; +import { Home, Settings, BrainCircuit, Images, Plus } from 'lucide-react'; +import { FaXTwitter, FaDiscord, FaYoutube } from 'react-icons/fa6'; const Sidebar = () => { const navigation = [ { name: 'Dashboard', href: '/dashboard', icon: Home }, { name: 'New Job', href: '/jobs/new', icon: Plus }, - { name: 'Training Jobs', href: '/jobs', icon: BrainCircuit }, + { name: 'Training Queue', href: '/jobs', icon: BrainCircuit }, { name: 'Datasets', href: '/datasets', icon: Images }, { name: 'Settings', href: '/settings', icon: Settings }, ]; - const socialsBoxClass = 'flex flex-col items-center justify-center p-1 hover:bg-gray-800 rounded-lg transition-colors'; + const socialsBoxClass = + 'flex flex-col items-center justify-center p-1 hover:bg-gray-800 rounded-lg transition-colors'; const socialIconClass = 'w-5 h-5 text-gray-400 hover:text-white'; return ( @@ -60,30 +61,15 @@ const Sidebar = () => { {/* Social links grid */}
- + {/* Discord */} - + {/* YouTube */} - + {/* X */} diff --git a/ui/src/components/UniversalTable.tsx b/ui/src/components/UniversalTable.tsx index b86b9bc8..036e3711 100644 --- a/ui/src/components/UniversalTable.tsx +++ b/ui/src/components/UniversalTable.tsx @@ -16,10 +16,17 @@ interface TableProps { columns: TableColumn[]; rows: TableRow[]; isLoading: boolean; + theadClassName?: string; onRefresh: () => void; } -export default function UniversalTable({ columns, rows, isLoading, onRefresh = () => {} }: TableProps) { +export default function UniversalTable({ + columns, + rows, + isLoading, + theadClassName = 'text-gray-400', + onRefresh = () => {}, +}: TableProps) { return (
{isLoading ? ( @@ -39,7 +46,7 @@ export default function UniversalTable({ columns, rows, isLoading, onRefresh = ( ) : (
- + {columns.map(column => (
diff --git a/ui/src/hooks/useJobsList.tsx b/ui/src/hooks/useJobsList.tsx index 6f1e3af9..615bf423 100644 --- a/ui/src/hooks/useJobsList.tsx +++ b/ui/src/hooks/useJobsList.tsx @@ -4,7 +4,7 @@ import { useEffect, useState } from 'react'; import { Job } from '@prisma/client'; import { apiClient } from '@/utils/api'; -export default function useJobsList(onlyActive = false) { +export default function useJobsList(onlyActive = false, reloadInterval: null | number = null) { const [jobs, setJobs] = useState([]); const [status, setStatus] = useState<'idle' | 'loading' | 'success' | 'error'>('idle'); @@ -20,19 +20,26 @@ export default function useJobsList(onlyActive = false) { setStatus('error'); } else { if (onlyActive) { - data.jobs = data.jobs.filter((job: Job) => job.status === 'running'); + data.jobs = data.jobs.filter((job: Job) => ['running', 'queued', 'stopping'].includes(job.status)); } setJobs(data.jobs); setStatus('success'); } }) .catch(error => { - console.error('Error fetching datasets:', error); + console.error('Error fetching jobs:', error); setStatus('error'); }); }; useEffect(() => { refreshJobs(); + + if (reloadInterval) { + const interval = setInterval(() => { + refreshJobs(); + }, reloadInterval); + return () => clearInterval(interval); + } }, []); return { jobs, setJobs, status, refreshJobs }; diff --git a/ui/src/hooks/useQueueList.tsx b/ui/src/hooks/useQueueList.tsx new file mode 100644 index 00000000..429ac012 --- /dev/null +++ b/ui/src/hooks/useQueueList.tsx @@ -0,0 +1,36 @@ +'use client'; + +import { useEffect, useState } from 'react'; +import { Queue } from '@prisma/client'; +import { apiClient } from '@/utils/api'; + +export default function useQueueList() { + const [queues, setQueues] = useState([]); + const [status, setStatus] = useState<'idle' | 'loading' | 'success' | 'error'>('idle'); + + const refreshQueues = () => { + setStatus('loading'); + apiClient + .get('/api/queue') + .then(res => res.data) + .then(data => { + console.log('Queues:', data); + if (data.error) { + console.log('Error fetching queues:', data.error); + setStatus('error'); + } else { + setQueues(data.queues); + setStatus('success'); + } + }) + .catch(error => { + console.error('Error fetching queues:', error); + setStatus('error'); + }); + }; + useEffect(() => { + refreshQueues(); + }, []); + + return { queues, setQueues, status, refreshQueues }; +} diff --git a/ui/src/types.ts b/ui/src/types.ts index d86220a6..a0b871e6 100644 --- a/ui/src/types.ts +++ b/ui/src/types.ts @@ -246,3 +246,5 @@ export interface GroupedSelectOption { readonly label: string; readonly options: SelectOption[]; } + +export type JobStatus = 'queued' | 'running' | 'stopping' | 'stopped' | 'completed' | 'error'; diff --git a/ui/src/utils/jobs.ts b/ui/src/utils/jobs.ts index ed656edc..8e485458 100644 --- a/ui/src/utils/jobs.ts +++ b/ui/src/utils/jobs.ts @@ -73,15 +73,16 @@ export const getJobConfig = (job: Job) => { export const getAvaliableJobActions = (job: Job) => { const jobConfig = getJobConfig(job); const isStopping = job.stop && job.status === 'running'; - const canDelete = ['completed', 'stopped', 'error'].includes(job.status) && !isStopping; - const canEdit = ['completed', 'stopped', 'error'].includes(job.status) && !isStopping; + const canDelete = ['queued', 'completed', 'stopped', 'error'].includes(job.status) && !isStopping; + const canEdit = ['queued','completed', 'stopped', 'error'].includes(job.status) && !isStopping; + const canRemoveFromQueue = job.status === 'queued'; const canStop = job.status === 'running' && !isStopping; let canStart = ['stopped', 'error'].includes(job.status) && !isStopping; // can resume if more steps were added if (job.status === 'completed' && jobConfig.config.process[0].train.steps > job.step && !isStopping) { canStart = true; } - return { canDelete, canEdit, canStop, canStart }; + return { canDelete, canEdit, canStop, canStart, canRemoveFromQueue }; }; export const getNumberOfSamples = (job: Job) => { diff --git a/ui/src/utils/queue.ts b/ui/src/utils/queue.ts new file mode 100644 index 00000000..e25f2850 --- /dev/null +++ b/ui/src/utils/queue.ts @@ -0,0 +1,32 @@ +import { apiClient } from '@/utils/api'; + +export const startQueue = (queueID: string) => { + return new Promise((resolve, reject) => { + apiClient + .get(`/api/queue/${queueID}/start`) + .then(res => res.data) + .then(data => { + console.log('Queue started:', data); + resolve(); + }) + .catch(error => { + console.error('Error starting queue:', error); + reject(error); + }); + }); +}; +export const stopQueue = (queueID: string) => { + return new Promise((resolve, reject) => { + apiClient + .get(`/api/queue/${queueID}/stop`) + .then(res => res.data) + .then(data => { + console.log('Queue stopped:', data); + resolve(); + }) + .catch(error => { + console.error('Error stopping queue:', error); + reject(error); + }); + }); +}; diff --git a/ui/tsconfig.worker.json b/ui/tsconfig.worker.json index 6b4d9531..d459d1fc 100644 --- a/ui/tsconfig.worker.json +++ b/ui/tsconfig.worker.json @@ -3,11 +3,18 @@ "compilerOptions": { "module": "commonjs", "target": "es2020", - "outDir": "dist", + "outDir": "dist/cron", "moduleResolution": "node", "types": [ "node" - ] + ], + "esModuleInterop": true, + "allowSyntheticDefaultImports": true, + "paths": { + "@/*": [ + "./cron/*" + ] + } }, "include": [ "cron/**/*.ts" diff --git a/version.py b/version.py index a3555564..30854d5d 100644 --- a/version.py +++ b/version.py @@ -1 +1 @@ -VERSION = "0.6.5" \ No newline at end of file +VERSION = "0.7.0" \ No newline at end of file