mirror of
https://github.com/ostris/ai-toolkit.git
synced 2026-01-26 16:39:47 +00:00
Start, stop, monitor jobs from ui working.
This commit is contained in:
@@ -1,3 +1,4 @@
|
||||
from collections import OrderedDict
|
||||
import os
|
||||
import sqlite3
|
||||
import asyncio
|
||||
@@ -9,9 +10,10 @@ AITK_Status = Literal["running", "stopped", "error", "completed"]
|
||||
|
||||
|
||||
class UITrainer(SDTrainer):
|
||||
def __init__(self):
|
||||
super(UITrainer, self).__init__()
|
||||
self.sqlite_db_path = self.config.get("sqlite_db_path", "data.sqlite")
|
||||
def __init__(self, process_id: int, job, config: OrderedDict, **kwargs):
|
||||
super(UITrainer, self).__init__(process_id, job, config, **kwargs)
|
||||
self.sqlite_db_path = self.config.get("sqlite_db_path", "./aitk_db.db")
|
||||
print(f"Using SQLite database at {self.sqlite_db_path}")
|
||||
self.job_id = os.environ.get("AITK_JOB_ID", None)
|
||||
if self.job_id is None:
|
||||
raise Exception("AITK_JOB_ID not set")
|
||||
@@ -19,13 +21,31 @@ class UITrainer(SDTrainer):
|
||||
# Create a thread pool for database operations
|
||||
self.thread_pool = concurrent.futures.ThreadPoolExecutor(max_workers=1)
|
||||
# Initialize the status
|
||||
asyncio.run(self._update_status("running", "Starting"))
|
||||
|
||||
self._run_async_operation(self._update_status("running", "Starting"))
|
||||
|
||||
def _run_async_operation(self, coro):
|
||||
"""Helper method to run an async coroutine in a new event loop."""
|
||||
try:
|
||||
loop = asyncio.get_event_loop()
|
||||
except RuntimeError:
|
||||
# No event loop exists, create a new one
|
||||
loop = asyncio.new_event_loop()
|
||||
asyncio.set_event_loop(loop)
|
||||
|
||||
# Run the coroutine in the event loop
|
||||
if loop.is_running():
|
||||
# If we're already in an event loop, create a future
|
||||
future = asyncio.run_coroutine_threadsafe(coro, loop)
|
||||
# We could wait for the result if needed: future.result()
|
||||
else:
|
||||
# If no loop is running, run the coroutine and close the loop
|
||||
loop.run_until_complete(coro)
|
||||
|
||||
async def _execute_db_operation(self, operation_func):
|
||||
"""Execute a database operation in a separate thread to avoid blocking."""
|
||||
loop = asyncio.get_event_loop()
|
||||
return await loop.run_in_executor(self.thread_pool, operation_func)
|
||||
|
||||
|
||||
def _db_connect(self):
|
||||
"""Create a new connection for each operation to avoid locking."""
|
||||
conn = sqlite3.connect(self.sqlite_db_path, timeout=10.0)
|
||||
@@ -36,47 +56,49 @@ class UITrainer(SDTrainer):
|
||||
def _check_stop():
|
||||
with self._db_connect() as conn:
|
||||
cursor = conn.cursor()
|
||||
cursor.execute("SELECT stop FROM jobs WHERE job_id = ?", (self.job_id,))
|
||||
cursor.execute(
|
||||
"SELECT stop FROM Job WHERE id = ?", (self.job_id,))
|
||||
stop = cursor.fetchone()
|
||||
return False if stop is None else stop[0] == 1
|
||||
|
||||
|
||||
# For this one we need a synchronous result, so we'll run it directly
|
||||
return _check_stop()
|
||||
|
||||
def maybe_stop(self):
|
||||
if self.should_stop():
|
||||
asyncio.run(self._update_status("stopped", "Job stopped"))
|
||||
self._run_async_operation(
|
||||
self._update_status("stopped", "Job stopped"))
|
||||
self.is_stopping = True
|
||||
raise Exception("Job stopped")
|
||||
|
||||
async def _update_step(self):
|
||||
if not self.accelerator.is_main_process:
|
||||
return
|
||||
|
||||
|
||||
def _do_update():
|
||||
with self._db_connect() as conn:
|
||||
cursor = conn.cursor()
|
||||
cursor.execute("BEGIN IMMEDIATE") # Get an immediate lock
|
||||
try:
|
||||
cursor.execute(
|
||||
"UPDATE jobs SET step = ? WHERE job_id = ?",
|
||||
"UPDATE Job SET step = ? WHERE id = ?",
|
||||
(self.step_num, self.job_id)
|
||||
)
|
||||
finally:
|
||||
cursor.execute("COMMIT") # Release the lock
|
||||
|
||||
|
||||
await self._execute_db_operation(_do_update)
|
||||
|
||||
def update_step(self):
|
||||
"""Non-blocking update of the step count."""
|
||||
if self.accelerator.is_main_process:
|
||||
# Start the async operation without waiting for it
|
||||
asyncio.create_task(self._update_step())
|
||||
# Use the helper method to run the async operation
|
||||
self._run_async_operation(self._update_step())
|
||||
|
||||
async def _update_status(self, status: AITK_Status, info: Optional[str] = None):
|
||||
if not self.accelerator.is_main_process:
|
||||
return
|
||||
|
||||
|
||||
def _do_update():
|
||||
with self._db_connect() as conn:
|
||||
cursor = conn.cursor()
|
||||
@@ -84,24 +106,24 @@ class UITrainer(SDTrainer):
|
||||
try:
|
||||
if info is not None:
|
||||
cursor.execute(
|
||||
"UPDATE jobs SET status = ?, info = ? WHERE job_id = ?",
|
||||
"UPDATE Job SET status = ?, info = ? WHERE id = ?",
|
||||
(status, info, self.job_id)
|
||||
)
|
||||
else:
|
||||
cursor.execute(
|
||||
"UPDATE jobs SET status = ? WHERE job_id = ?",
|
||||
"UPDATE Job SET status = ? WHERE id = ?",
|
||||
(status, self.job_id)
|
||||
)
|
||||
finally:
|
||||
cursor.execute("COMMIT") # Release the lock
|
||||
|
||||
|
||||
await self._execute_db_operation(_do_update)
|
||||
|
||||
def update_status(self, status: AITK_Status, info: Optional[str] = None):
|
||||
"""Non-blocking update of status."""
|
||||
if self.accelerator.is_main_process:
|
||||
# Start the async operation without waiting for it
|
||||
asyncio.create_task(self._update_status(status, info))
|
||||
# Use the helper method to run the async operation
|
||||
self._run_async_operation(self._update_status(status, info))
|
||||
|
||||
def on_error(self, e: Exception):
|
||||
super(UITrainer, self).on_error(e)
|
||||
@@ -121,30 +143,36 @@ class UITrainer(SDTrainer):
|
||||
|
||||
def hook_before_model_load(self):
|
||||
super().hook_before_model_load()
|
||||
self.maybe_stop()
|
||||
self.update_status("running", "Loading model")
|
||||
|
||||
def before_dataset_load(self):
|
||||
super().before_dataset_load()
|
||||
self.maybe_stop()
|
||||
self.update_status("running", "Loading dataset")
|
||||
|
||||
def hook_before_train_loop(self):
|
||||
super().hook_before_train_loop()
|
||||
self.maybe_stop()
|
||||
self.update_status("running", "Training")
|
||||
|
||||
def sample_step_hook(self, img_num, total_imgs):
|
||||
super().sample_step_hook(img_num, total_imgs)
|
||||
# subtract a since this is called after the image is generated
|
||||
self.maybe_stop()
|
||||
self.update_status(
|
||||
"running", f"Generating images - {img_num - 1} of {total_imgs}")
|
||||
"running", f"Generating images - {img_num + 1}/{total_imgs}")
|
||||
|
||||
def sample(self, step=None, is_first=False):
|
||||
self.maybe_stop()
|
||||
total_imgs = len(self.sample_config.prompts)
|
||||
self.update_status("running", f"Generating images - 1 of {total_imgs}")
|
||||
self.update_status("running", f"Generating images - 0/{total_imgs}")
|
||||
super().sample(step, is_first)
|
||||
self.maybe_stop()
|
||||
self.update_status("running", "Training")
|
||||
|
||||
|
||||
def save(self, step=None):
|
||||
self.maybe_stop()
|
||||
self.update_status("running", "Saving model")
|
||||
super().save(step)
|
||||
self.update_status("running", "Training")
|
||||
self.maybe_stop()
|
||||
self.update_status("running", "Training")
|
||||
|
||||
@@ -1428,7 +1428,7 @@ class BaseSDTrainProcess(BaseTrainProcess):
|
||||
# run base sd process run
|
||||
self.sd.load_model()
|
||||
|
||||
self.sd.add_after_sample_image_hook(self.after_sample_image_hook)
|
||||
self.sd.add_after_sample_image_hook(self.sample_step_hook)
|
||||
|
||||
dtype = get_torch_dtype(self.train_config.dtype)
|
||||
|
||||
|
||||
5
run.py
5
run.py
@@ -88,7 +88,10 @@ def main():
|
||||
except Exception as e:
|
||||
print_acc(f"Error running job: {e}")
|
||||
jobs_failed += 1
|
||||
job.process[0].on_error(e)
|
||||
try:
|
||||
job.process[0].on_error(e)
|
||||
except Exception as e2:
|
||||
print_acc(f"Error running on_error: {e2}")
|
||||
if not args.recover:
|
||||
print_end_message(jobs_completed, jobs_failed)
|
||||
raise e
|
||||
|
||||
@@ -4,7 +4,7 @@ generator client {
|
||||
|
||||
datasource db {
|
||||
provider = "sqlite"
|
||||
url = "file:../aitk_db.db"
|
||||
url = "file:../../aitk_db.db"
|
||||
}
|
||||
|
||||
model Settings {
|
||||
@@ -13,9 +13,9 @@ model Settings {
|
||||
value String
|
||||
}
|
||||
|
||||
model Training {
|
||||
model Job {
|
||||
id String @id @default(uuid())
|
||||
name String
|
||||
name String @unique
|
||||
gpu_id Int
|
||||
job_config String // JSON string
|
||||
created_at DateTime @default(now())
|
||||
|
||||
90
ui/src/app/api/jobs/[jobID]/start/route.ts
Normal file
90
ui/src/app/api/jobs/[jobID]/start/route.ts
Normal file
@@ -0,0 +1,90 @@
|
||||
import { NextRequest, NextResponse } from 'next/server';
|
||||
import { PrismaClient } from '@prisma/client';
|
||||
import { TOOLKIT_ROOT, defaultTrainFolder } from '@/paths';
|
||||
import { spawn } from 'child_process';
|
||||
import path from 'path';
|
||||
import fs from 'fs';
|
||||
|
||||
|
||||
const prisma = new PrismaClient();
|
||||
|
||||
export async function GET(request: NextRequest, { params }: { params: { jobID: string } }) {
|
||||
const { jobID } = await params;
|
||||
|
||||
const job = await prisma.job.findUnique({
|
||||
where: { id: jobID },
|
||||
});
|
||||
|
||||
if (!job) {
|
||||
return NextResponse.json({ error: 'Job not found' }, { status: 404 });
|
||||
}
|
||||
|
||||
// update job status to 'running'
|
||||
await prisma.job.update({
|
||||
where: { id: jobID },
|
||||
data: {
|
||||
status: 'running',
|
||||
stop: false,
|
||||
info: 'Starting job...',
|
||||
},
|
||||
});
|
||||
|
||||
// setup the training
|
||||
const settings = await prisma.settings.findMany();
|
||||
const settingsObject = settings.reduce((acc: any, setting) => {
|
||||
acc[setting.key] = setting.value;
|
||||
return acc;
|
||||
}, {});
|
||||
|
||||
// if TRAINING_FOLDER is not set, use default
|
||||
if (!settingsObject.TRAINING_FOLDER || settingsObject.TRAINING_FOLDER === '') {
|
||||
settingsObject.TRAINING_FOLDER = defaultTrainFolder;
|
||||
}
|
||||
|
||||
const trainingFolder = path.join(settingsObject.TRAINING_FOLDER, job.name);
|
||||
if (!fs.existsSync(trainingFolder)) {
|
||||
fs.mkdirSync(trainingFolder, { recursive: true });
|
||||
}
|
||||
|
||||
// make the config file
|
||||
const configPath = path.join(trainingFolder, '.job_config.json');
|
||||
|
||||
// update the config dataset path
|
||||
const jobConfig = JSON.parse(job.job_config);
|
||||
jobConfig.config.process[0].sqlite_db_path = path.join(TOOLKIT_ROOT, 'aitk_db.db');
|
||||
|
||||
|
||||
// write the config file
|
||||
fs.writeFileSync(configPath, JSON.stringify(jobConfig, null, 2));
|
||||
|
||||
let pythonPath = 'python';
|
||||
// use .venv or venv if it exists
|
||||
if (fs.existsSync(path.join(TOOLKIT_ROOT, '.venv'))) {
|
||||
pythonPath = path.join(TOOLKIT_ROOT, '.venv', 'bin', 'python');
|
||||
} else if (fs.existsSync(path.join(TOOLKIT_ROOT, 'venv'))) {
|
||||
pythonPath = path.join(TOOLKIT_ROOT, 'venv', 'bin', 'python');
|
||||
}
|
||||
|
||||
const runFilePath = path.join(TOOLKIT_ROOT, 'run.py');
|
||||
if (!fs.existsSync(runFilePath)) {
|
||||
return NextResponse.json({ error: 'run.py not found' }, { status: 500 });
|
||||
}
|
||||
|
||||
console.log('Spawning command:', `AITK_JOB_ID=${jobID} CUDA_VISIBLE_DEVICES=${job.gpu_id} ${pythonPath} ${runFilePath} ${configPath}`);
|
||||
|
||||
// start job
|
||||
const subprocess = spawn(pythonPath, [runFilePath, configPath], {
|
||||
detached: true,
|
||||
stdio: 'ignore',
|
||||
env: {
|
||||
...process.env,
|
||||
AITK_JOB_ID: jobID,
|
||||
CUDA_VISIBLE_DEVICES: `${job.gpu_id}`,
|
||||
},
|
||||
cwd: TOOLKIT_ROOT,
|
||||
});
|
||||
|
||||
subprocess.unref();
|
||||
|
||||
return NextResponse.json(job);
|
||||
}
|
||||
23
ui/src/app/api/jobs/[jobID]/stop/route.ts
Normal file
23
ui/src/app/api/jobs/[jobID]/stop/route.ts
Normal file
@@ -0,0 +1,23 @@
|
||||
import { NextRequest, NextResponse } from 'next/server';
|
||||
import { PrismaClient } from '@prisma/client';
|
||||
|
||||
const prisma = new PrismaClient();
|
||||
|
||||
export async function GET(request: NextRequest, { params }: { params: { jobID: string } }) {
|
||||
const { jobID } = await params;
|
||||
|
||||
const job = await prisma.job.findUnique({
|
||||
where: { id: jobID },
|
||||
});
|
||||
|
||||
// update job status to 'running'
|
||||
await prisma.job.update({
|
||||
where: { id: jobID },
|
||||
data: {
|
||||
stop: true,
|
||||
info: 'Stopping job...',
|
||||
},
|
||||
});
|
||||
|
||||
return NextResponse.json(job);
|
||||
}
|
||||
@@ -9,17 +9,18 @@ export async function GET(request: Request) {
|
||||
|
||||
try {
|
||||
if (id) {
|
||||
const training = await prisma.training.findUnique({
|
||||
const training = await prisma.job.findUnique({
|
||||
where: { id },
|
||||
});
|
||||
return NextResponse.json(training);
|
||||
}
|
||||
|
||||
const trainings = await prisma.training.findMany({
|
||||
const trainings = await prisma.job.findMany({
|
||||
orderBy: { created_at: 'desc' },
|
||||
});
|
||||
return NextResponse.json(trainings);
|
||||
} catch (error) {
|
||||
console.error(error);
|
||||
return NextResponse.json({ error: 'Failed to fetch training data' }, { status: 500 });
|
||||
}
|
||||
}
|
||||
@@ -31,7 +32,7 @@ export async function POST(request: Request) {
|
||||
|
||||
if (id) {
|
||||
// Update existing training
|
||||
const training = await prisma.training.update({
|
||||
const training = await prisma.job.update({
|
||||
where: { id },
|
||||
data: {
|
||||
name,
|
||||
@@ -42,7 +43,7 @@ export async function POST(request: Request) {
|
||||
return NextResponse.json(training);
|
||||
} else {
|
||||
// Create new training
|
||||
const training = await prisma.training.create({
|
||||
const training = await prisma.job.create({
|
||||
data: {
|
||||
name,
|
||||
gpu_id,
|
||||
72
ui/src/app/jobs/[jobID]/page.tsx
Normal file
72
ui/src/app/jobs/[jobID]/page.tsx
Normal file
@@ -0,0 +1,72 @@
|
||||
'use client';
|
||||
|
||||
import { useEffect, useState, use } from 'react';
|
||||
import { FaChevronLeft } from 'react-icons/fa';
|
||||
import { Button } from '@headlessui/react';
|
||||
import { TopBar, MainContent } from '@/components/layout';
|
||||
import useJob from '@/hooks/useJob';
|
||||
import { startJob, stopJob } from '@/utils/jobs';
|
||||
|
||||
export default function JobPage({ params }: { params: { jobID: string } }) {
|
||||
const usableParams = use(params as any) as { jobID: string };
|
||||
const jobID = usableParams.jobID;
|
||||
const { job, status, refreshJobs } = useJob(jobID, 5000);
|
||||
|
||||
return (
|
||||
<>
|
||||
{/* Fixed top bar */}
|
||||
<TopBar>
|
||||
<div>
|
||||
<Button className="text-gray-500 dark:text-gray-300 px-3 mt-1" onClick={() => history.back()}>
|
||||
<FaChevronLeft />
|
||||
</Button>
|
||||
</div>
|
||||
<div>
|
||||
<h1 className="text-lg">Job: {job?.name}</h1>
|
||||
</div>
|
||||
<div className="flex-1"></div>
|
||||
{job?.status === 'running' && (
|
||||
<Button
|
||||
onClick={async () => {
|
||||
await stopJob(jobID);
|
||||
refreshJobs();
|
||||
}}
|
||||
className="bg-red-500 text-white px-4 py-1 rounded-sm"
|
||||
>
|
||||
Stop
|
||||
</Button>
|
||||
)}
|
||||
{(job?.status === 'stopped' || job?.status === 'error') && (
|
||||
<Button
|
||||
onClick={async () => {
|
||||
await startJob(jobID);
|
||||
refreshJobs();
|
||||
}}
|
||||
className="bg-green-800 text-white px-4 py-1 rounded-sm"
|
||||
>
|
||||
Start
|
||||
</Button>
|
||||
)}
|
||||
</TopBar>
|
||||
<MainContent>
|
||||
{status === 'loading' && job == null && <p>Loading...</p>}
|
||||
{status === 'error' && job == null && <p>Error fetching job</p>}
|
||||
{job && (
|
||||
<>
|
||||
<div className="grid grid-cols-1 gap-4 md:grid-cols-2">
|
||||
<div className="">
|
||||
<h2 className="text-lg font-semibold">Job Details</h2>
|
||||
<p className="text-gray-400">ID: {job.id}</p>
|
||||
<p className="text-gray-400">Name: {job.name}</p>
|
||||
<p className="text-gray-400">GPU: {job.gpu_id}</p>
|
||||
<p className="text-gray-400">Status: {job.status}</p>
|
||||
<p className="text-gray-400">Info: {job.info}</p>
|
||||
<p className="text-gray-400">Step: {job.step}</p>
|
||||
</div>
|
||||
</div>
|
||||
</>
|
||||
)}
|
||||
</MainContent>
|
||||
</>
|
||||
);
|
||||
}
|
||||
@@ -15,6 +15,8 @@ import useGPUInfo from '@/hooks/useGPUInfo';
|
||||
import useDatasetList from '@/hooks/useDatasetList';
|
||||
import path from 'path';
|
||||
import { TopBar, MainContent } from '@/components/layout';
|
||||
import { Button } from '@headlessui/react';
|
||||
import { FaChevronLeft } from 'react-icons/fa';
|
||||
|
||||
export default function TrainingForm() {
|
||||
const router = useRouter();
|
||||
@@ -47,7 +49,7 @@ export default function TrainingForm() {
|
||||
|
||||
useEffect(() => {
|
||||
if (runId) {
|
||||
fetch(`/api/training?id=${runId}`)
|
||||
fetch(`/api/jobs?id=${runId}`)
|
||||
.then(res => res.json())
|
||||
.then(data => {
|
||||
setGpuID(data.gpu_id);
|
||||
@@ -76,7 +78,7 @@ export default function TrainingForm() {
|
||||
setStatus('saving');
|
||||
|
||||
try {
|
||||
const response = await fetch('/api/training', {
|
||||
const response = await fetch('/api/jobs', {
|
||||
method: 'POST',
|
||||
headers: {
|
||||
'Content-Type': 'application/json',
|
||||
@@ -94,7 +96,7 @@ export default function TrainingForm() {
|
||||
setStatus('success');
|
||||
if (!runId) {
|
||||
const data = await response.json();
|
||||
router.push(`/training?id=${data.id}`);
|
||||
router.push(`/jobs/${data.id}`);
|
||||
}
|
||||
setTimeout(() => setStatus('idle'), 2000);
|
||||
} catch (error) {
|
||||
@@ -108,7 +110,12 @@ export default function TrainingForm() {
|
||||
<>
|
||||
<TopBar>
|
||||
<div>
|
||||
<h1 className="text-lg">{runId ? 'Edit Training Run' : 'New Training Run'}</h1>
|
||||
<Button className="text-gray-500 dark:text-gray-300 px-3 mt-1" onClick={() => history.back()}>
|
||||
<FaChevronLeft />
|
||||
</Button>
|
||||
</div>
|
||||
<div>
|
||||
<h1 className="text-lg">{runId ? 'Edit Training Job' : 'New Training Job'}</h1>
|
||||
</div>
|
||||
<div className="flex-1"></div>
|
||||
</TopBar>
|
||||
29
ui/src/app/jobs/page.tsx
Normal file
29
ui/src/app/jobs/page.tsx
Normal file
@@ -0,0 +1,29 @@
|
||||
'use client';
|
||||
|
||||
import JobsTable from '@/components/JobsTable';
|
||||
import { TopBar, MainContent } from '@/components/layout';
|
||||
import Link from 'next/link';
|
||||
|
||||
export default function Dashboard() {
|
||||
return (
|
||||
<>
|
||||
<TopBar>
|
||||
<div>
|
||||
<h1 className="text-lg">Training Jobs</h1>
|
||||
</div>
|
||||
<div className="flex-1"></div>
|
||||
<div>
|
||||
<Link
|
||||
href="/jobs/new"
|
||||
className="text-gray-200 bg-slate-600 px-3 py-1 rounded-md"
|
||||
>
|
||||
New Training Job
|
||||
</Link>
|
||||
</div>
|
||||
</TopBar>
|
||||
<MainContent>
|
||||
<JobsTable />
|
||||
</MainContent>
|
||||
</>
|
||||
);
|
||||
}
|
||||
@@ -108,9 +108,9 @@ const GpuMonitor: React.FC = () => {
|
||||
}
|
||||
|
||||
return (
|
||||
<div className="container mx-auto py-2">
|
||||
<div className="">
|
||||
<div className="flex justify-between items-center mb-2">
|
||||
<h1 className="text-lg font-bold">GPU Monitor</h1>
|
||||
<h1 className="text-md">GPU Monitor</h1>
|
||||
<div className="text-xs text-gray-500">Last updated: {lastUpdated?.toLocaleTimeString()}</div>
|
||||
</div>
|
||||
|
||||
|
||||
83
ui/src/components/JobsTable.tsx
Normal file
83
ui/src/components/JobsTable.tsx
Normal file
@@ -0,0 +1,83 @@
|
||||
import useJobsList from '@/hooks/useJobsList';
|
||||
import Loading from './Loading';
|
||||
import { JobConfig } from '@/types';
|
||||
import Link from 'next/link';
|
||||
|
||||
interface JobsTableProps {}
|
||||
|
||||
export default function JobsTable(props: JobsTableProps) {
|
||||
const { jobs, status, refreshJobs } = useJobsList();
|
||||
const isLoading = status === 'loading';
|
||||
|
||||
return (
|
||||
<div className="w-full bg-gray-900 rounded-md shadow-md">
|
||||
{isLoading ? (
|
||||
<div className="p-4 flex justify-center">
|
||||
<Loading />
|
||||
</div>
|
||||
) : jobs.length === 0 ? (
|
||||
<div className="p-6 text-center text-gray-400">
|
||||
<p className="text-sm">No jobs available</p>
|
||||
<button
|
||||
onClick={() => refreshJobs()}
|
||||
className="mt-2 px-3 py-1 text-xs bg-gray-800 hover:bg-gray-700 text-gray-300 rounded transition-colors"
|
||||
>
|
||||
Refresh
|
||||
</button>
|
||||
</div>
|
||||
) : (
|
||||
<div className="overflow-x-auto">
|
||||
<table className="w-full text-sm text-left text-gray-300">
|
||||
<thead className="text-xs uppercase bg-gray-800 text-gray-400">
|
||||
<tr>
|
||||
<th className="px-3 py-2">Name</th>
|
||||
<th className="px-3 py-2">Steps</th>
|
||||
<th className="px-3 py-2">GPU</th>
|
||||
<th className="px-3 py-2">Status</th>
|
||||
<th className="px-3 py-2">Info</th>
|
||||
</tr>
|
||||
</thead>
|
||||
<tbody>
|
||||
{jobs.map((job, index) => {
|
||||
const jobConfig: JobConfig = JSON.parse(job.job_config);
|
||||
const totalSteps = jobConfig.config.process[0].train.steps;
|
||||
|
||||
// Style for alternating rows
|
||||
const rowClass = index % 2 === 0 ? 'bg-gray-900' : 'bg-gray-800';
|
||||
|
||||
// Style based on job status
|
||||
let statusClass = 'text-gray-400';
|
||||
if (job.status === 'completed') statusClass = 'text-green-400';
|
||||
if (job.status === 'failed') statusClass = 'text-red-400';
|
||||
if (job.status === 'running') statusClass = 'text-blue-400';
|
||||
|
||||
return (
|
||||
<tr key={job.id} className={`${rowClass} border-b border-gray-700 hover:bg-gray-700`}>
|
||||
<td className="px-3 py-2 font-medium whitespace-nowrap">
|
||||
<Link href={`/jobs/${job.id}`}>{job.name}</Link></td>
|
||||
<td className="px-3 py-2">
|
||||
<div className="flex items-center">
|
||||
<span>
|
||||
{job.step} / {totalSteps}
|
||||
</span>
|
||||
<div className="w-16 bg-gray-700 rounded-full h-1.5 ml-2">
|
||||
<div
|
||||
className="bg-blue-500 h-1.5 rounded-full"
|
||||
style={{ width: `${(job.step / totalSteps) * 100}%` }}
|
||||
></div>
|
||||
</div>
|
||||
</div>
|
||||
</td>
|
||||
<td className="px-3 py-2">{job.gpu_id}</td>
|
||||
<td className={`px-3 py-2 ${statusClass}`}>{job.status}</td>
|
||||
<td className="px-3 py-2 truncate max-w-xs">{job.info}</td>
|
||||
</tr>
|
||||
);
|
||||
})}
|
||||
</tbody>
|
||||
</table>
|
||||
</div>
|
||||
)}
|
||||
</div>
|
||||
);
|
||||
}
|
||||
@@ -4,7 +4,7 @@ import { Home, Settings, BrainCircuit, Images } from 'lucide-react';
|
||||
const Sidebar = () => {
|
||||
const navigation = [
|
||||
{ name: 'Dashboard', href: '/dashboard', icon: Home },
|
||||
{ name: 'Train', href: '/train', icon: BrainCircuit },
|
||||
{ name: 'Training Jobs', href: '/jobs', icon: BrainCircuit },
|
||||
{ name: 'Datasets', href: '/datasets', icon: Images },
|
||||
{ name: 'Settings', href: '/settings', icon: Settings },
|
||||
];
|
||||
|
||||
40
ui/src/hooks/useJob.tsx
Normal file
40
ui/src/hooks/useJob.tsx
Normal file
@@ -0,0 +1,40 @@
|
||||
'use client';
|
||||
|
||||
import { useEffect, useState } from 'react';
|
||||
import { Job } from '@prisma/client';
|
||||
|
||||
export default function useJob(jobID: string, reloadInterval: null | number = null) {
|
||||
const [job, setJob] = useState<Job | null>(null);
|
||||
const [status, setStatus] = useState<'idle' | 'loading' | 'success' | 'error'>('idle');
|
||||
|
||||
const refreshJobs = () => {
|
||||
setStatus('loading');
|
||||
fetch(`/api/jobs?id=${jobID}`)
|
||||
.then(res => res.json())
|
||||
.then(data => {
|
||||
console.log('Job:', data);
|
||||
setJob(data);
|
||||
setStatus('success');
|
||||
})
|
||||
.catch(error => {
|
||||
console.error('Error fetching datasets:', error);
|
||||
setStatus('error');
|
||||
});
|
||||
};
|
||||
|
||||
useEffect(() => {
|
||||
refreshJobs();
|
||||
|
||||
if (reloadInterval) {
|
||||
const interval = setInterval(() => {
|
||||
refreshJobs();
|
||||
}, reloadInterval);
|
||||
|
||||
return () => {
|
||||
clearInterval(interval);
|
||||
}
|
||||
}
|
||||
}, [jobID]);
|
||||
|
||||
return { job, setJob, status, refreshJobs };
|
||||
}
|
||||
29
ui/src/hooks/useJobsList.tsx
Normal file
29
ui/src/hooks/useJobsList.tsx
Normal file
@@ -0,0 +1,29 @@
|
||||
'use client';
|
||||
|
||||
import { useEffect, useState } from 'react';
|
||||
import { Job } from '@prisma/client';
|
||||
|
||||
export default function useJobsList() {
|
||||
const [jobs, setJobs] = useState<Job[]>([]);
|
||||
const [status, setStatus] = useState<'idle' | 'loading' | 'success' | 'error'>('idle');
|
||||
|
||||
const refreshJobs = () => {
|
||||
setStatus('loading');
|
||||
fetch('/api/jobs')
|
||||
.then(res => res.json())
|
||||
.then(data => {
|
||||
console.log('Jobs:', data);
|
||||
setJobs(data);
|
||||
setStatus('success');
|
||||
})
|
||||
.catch(error => {
|
||||
console.error('Error fetching datasets:', error);
|
||||
setStatus('error');
|
||||
});
|
||||
};
|
||||
useEffect(() => {
|
||||
refreshJobs();
|
||||
}, []);
|
||||
|
||||
return { jobs, setJobs, status, refreshJobs };
|
||||
}
|
||||
29
ui/src/utils/jobs.ts
Normal file
29
ui/src/utils/jobs.ts
Normal file
@@ -0,0 +1,29 @@
|
||||
export const startJob = (jobID: string) => {
|
||||
return new Promise<void>((resolve, reject) => {
|
||||
fetch(`/api/jobs/${jobID}/start`)
|
||||
.then(res => res.json())
|
||||
.then(data => {
|
||||
console.log('Job started:', data);
|
||||
resolve();
|
||||
})
|
||||
.catch(error => {
|
||||
console.error('Error starting job:', error);
|
||||
reject(error);
|
||||
});
|
||||
});
|
||||
};
|
||||
|
||||
export const stopJob = (jobID: string) => {
|
||||
return new Promise<void>((resolve, reject) => {
|
||||
fetch(`/api/jobs/${jobID}/stop`)
|
||||
.then(res => res.json())
|
||||
.then(data => {
|
||||
console.log('Job stopped:', data);
|
||||
resolve();
|
||||
})
|
||||
.catch(error => {
|
||||
console.error('Error stopping job:', error);
|
||||
reject(error);
|
||||
});
|
||||
});
|
||||
};
|
||||
Reference in New Issue
Block a user