mirror of
https://github.com/ostris/ai-toolkit.git
synced 2026-02-28 08:13:58 +00:00
Start, stop, monitor jobs from ui working.
This commit is contained in:
90
ui/src/app/api/jobs/[jobID]/start/route.ts
Normal file
90
ui/src/app/api/jobs/[jobID]/start/route.ts
Normal file
@@ -0,0 +1,90 @@
|
||||
import { NextRequest, NextResponse } from 'next/server';
|
||||
import { PrismaClient } from '@prisma/client';
|
||||
import { TOOLKIT_ROOT, defaultTrainFolder } from '@/paths';
|
||||
import { spawn } from 'child_process';
|
||||
import path from 'path';
|
||||
import fs from 'fs';
|
||||
|
||||
|
||||
const prisma = new PrismaClient();
|
||||
|
||||
/**
 * GET /api/jobs/[jobID]/start — launch the training process for a job.
 *
 * Steps, in order:
 *  1. Look up the job; respond 404 if it does not exist.
 *  2. Validate everything that can fail cheaply (run.py present, job_config
 *     parses and has a first process entry) BEFORE touching the job row, so a
 *     failed start never leaves the job marked as 'running'.
 *  3. Write `.job_config.json` into `<TRAINING_FOLDER>/<job.name>`.
 *  4. Mark the job as running and spawn `python run.py <config>` detached,
 *     with AITK_JOB_ID and CUDA_VISIBLE_DEVICES set from the job.
 *
 * @param request Incoming request (unused).
 * @param params  Route params; `jobID` is the id of the job to start.
 * @returns The job record as loaded before the status update, or a JSON
 *          `{ error }` body with status 404 / 500.
 */
export async function GET(request: NextRequest, { params }: { params: { jobID: string } }) {
  const { jobID } = await params;

  const job = await prisma.job.findUnique({
    where: { id: jobID },
  });

  if (!job) {
    return NextResponse.json({ error: 'Job not found' }, { status: 404 });
  }

  // Fail fast: without run.py there is nothing to launch.
  const runFilePath = path.join(TOOLKIT_ROOT, 'run.py');
  if (!fs.existsSync(runFilePath)) {
    return NextResponse.json({ error: 'run.py not found' }, { status: 500 });
  }

  // Parse and validate the stored config before any side effects.
  let jobConfig: any;
  try {
    jobConfig = JSON.parse(job.job_config);
  } catch (error) {
    console.error(`Job ${jobID} has an unparsable job_config:`, error);
    return NextResponse.json({ error: 'Job config is not valid JSON' }, { status: 500 });
  }

  const processList = jobConfig && jobConfig.config ? jobConfig.config.process : undefined;
  const firstProcess = processList ? processList[0] : undefined;
  if (typeof firstProcess !== 'object' || firstProcess === null) {
    return NextResponse.json({ error: 'Job config has no process entry' }, { status: 500 });
  }

  // Point the training process at the database file under the toolkit root.
  firstProcess.sqlite_db_path = path.join(TOOLKIT_ROOT, 'aitk_db.db');

  // Collapse the settings rows into a key -> value lookup.
  const settings = await prisma.settings.findMany();
  const settingsObject: Record<string, any> = {};
  for (const setting of settings) {
    settingsObject[setting.key] = setting.value;
  }

  // An unset or empty TRAINING_FOLDER falls back to the default.
  const trainingRoot = settingsObject.TRAINING_FOLDER || defaultTrainFolder;

  // NOTE(review): job.name is joined into a filesystem path unchecked; confirm
  // that names containing path separators or '..' are rejected upstream.
  const trainingFolder = path.join(trainingRoot, job.name);
  // recursive: true makes this a no-op when the folder already exists.
  fs.mkdirSync(trainingFolder, { recursive: true });

  const configPath = path.join(trainingFolder, '.job_config.json');
  fs.writeFileSync(configPath, JSON.stringify(jobConfig, null, 2));

  const pythonPath = resolvePythonPath();

  // All preconditions hold; only now mark the job as running.
  await prisma.job.update({
    where: { id: jobID },
    data: {
      status: 'running',
      stop: false,
      info: 'Starting job...',
    },
  });

  console.log('Spawning command:', `AITK_JOB_ID=${jobID} CUDA_VISIBLE_DEVICES=${job.gpu_id} ${pythonPath} ${runFilePath} ${configPath}`);

  // Detached + ignored stdio + unref() lets the training process outlive
  // this request (and the server process).
  const subprocess = spawn(pythonPath, [runFilePath, configPath], {
    detached: true,
    stdio: 'ignore',
    env: {
      ...process.env,
      AITK_JOB_ID: jobID,
      CUDA_VISIBLE_DEVICES: `${job.gpu_id}`,
    },
    cwd: TOOLKIT_ROOT,
  });

  // A ChildProcess emits 'error' when the process cannot be spawned (e.g. the
  // interpreter is missing). Without a listener that event is thrown as an
  // uncaught exception, so record the failure on the job instead.
  subprocess.on('error', (error: Error) => {
    console.error(`Failed to start job ${jobID}:`, error);
    prisma.job
      .update({
        where: { id: jobID },
        data: {
          // NOTE(review): 'error' is assumed to be a status the UI understands — confirm.
          status: 'error',
          info: `Failed to start job: ${error.message}`,
        },
      })
      .catch((updateError: unknown) => {
        console.error(`Failed to record start failure for job ${jobID}:`, updateError);
      });
  });

  subprocess.unref();

  return NextResponse.json(job);
}

/**
 * Pick the Python interpreter used to run the toolkit.
 *
 * Prefers a virtualenv under TOOLKIT_ROOT (`.venv` first, then `venv`),
 * using the platform's venv layout (`Scripts/python.exe` on Windows,
 * `bin/python` elsewhere). Falls back to `python` from PATH when neither
 * virtualenv contains an interpreter.
 */
function resolvePythonPath(): string {
  const interpreterParts = process.platform === 'win32' ? ['Scripts', 'python.exe'] : ['bin', 'python'];

  for (const venvDir of ['.venv', 'venv']) {
    const candidate = path.join(TOOLKIT_ROOT, venvDir, ...interpreterParts);
    if (fs.existsSync(candidate)) {
      return candidate;
    }
  }

  return 'python';
}
|
||||
23
ui/src/app/api/jobs/[jobID]/stop/route.ts
Normal file
23
ui/src/app/api/jobs/[jobID]/stop/route.ts
Normal file
@@ -0,0 +1,23 @@
|
||||
import { NextRequest, NextResponse } from 'next/server';
|
||||
import { PrismaClient } from '@prisma/client';
|
||||
|
||||
const prisma = new PrismaClient();
|
||||
|
||||
/**
 * GET /api/jobs/[jobID]/stop — request that a job stop.
 *
 * This handler does not terminate any process itself; it only sets the
 * job's `stop` flag and updates its `info` text.
 *
 * @param request Incoming request (unused).
 * @param params  Route params; `jobID` is the id of the job to stop.
 * @returns The job record as loaded before the update, or a JSON
 *          `{ error }` body with status 404 when the job does not exist.
 */
export async function GET(request: NextRequest, { params }: { params: { jobID: string } }) {
  const { jobID } = await params;

  const job = await prisma.job.findUnique({
    where: { id: jobID },
  });

  // Guard first: updating a missing record would throw and surface as an
  // unhandled 500 instead of a clean 404.
  if (!job) {
    return NextResponse.json({ error: 'Job not found' }, { status: 404 });
  }

  // flag the job as stop-requested
  await prisma.job.update({
    where: { id: jobID },
    data: {
      stop: true,
      info: 'Stopping job...',
    },
  });

  return NextResponse.json(job);
}
|
||||
@@ -9,17 +9,18 @@ export async function GET(request: Request) {
|
||||
|
||||
try {
|
||||
if (id) {
|
||||
const training = await prisma.training.findUnique({
|
||||
const training = await prisma.job.findUnique({
|
||||
where: { id },
|
||||
});
|
||||
return NextResponse.json(training);
|
||||
}
|
||||
|
||||
const trainings = await prisma.training.findMany({
|
||||
const trainings = await prisma.job.findMany({
|
||||
orderBy: { created_at: 'desc' },
|
||||
});
|
||||
return NextResponse.json(trainings);
|
||||
} catch (error) {
|
||||
console.error(error);
|
||||
return NextResponse.json({ error: 'Failed to fetch training data' }, { status: 500 });
|
||||
}
|
||||
}
|
||||
@@ -31,7 +32,7 @@ export async function POST(request: Request) {
|
||||
|
||||
if (id) {
|
||||
// Update existing training
|
||||
const training = await prisma.training.update({
|
||||
const training = await prisma.job.update({
|
||||
where: { id },
|
||||
data: {
|
||||
name,
|
||||
@@ -42,7 +43,7 @@ export async function POST(request: Request) {
|
||||
return NextResponse.json(training);
|
||||
} else {
|
||||
// Create new training
|
||||
const training = await prisma.training.create({
|
||||
const training = await prisma.job.create({
|
||||
data: {
|
||||
name,
|
||||
gpu_id,
|
||||
Reference in New Issue
Block a user