diff --git a/ui/prisma/schema.prisma b/ui/prisma/schema.prisma index b28c43bf..c23826ae 100644 --- a/ui/prisma/schema.prisma +++ b/ui/prisma/schema.prisma @@ -16,7 +16,7 @@ model Settings { model Job { id String @id @default(uuid()) name String @unique - gpu_id Int + gpu_ids String job_config String // JSON string created_at DateTime @default(now()) updated_at DateTime @updatedAt diff --git a/ui/src/app/api/jobs/[jobID]/start/route.ts b/ui/src/app/api/jobs/[jobID]/start/route.ts index b706043f..1ca970ac 100644 --- a/ui/src/app/api/jobs/[jobID]/start/route.ts +++ b/ui/src/app/api/jobs/[jobID]/start/route.ts @@ -70,7 +70,7 @@ export async function GET(request: NextRequest, { params }: { params: { jobID: s return NextResponse.json({ error: 'run.py not found' }, { status: 500 }); } - console.log('Spawning command:', `AITK_JOB_ID=${jobID} CUDA_VISIBLE_DEVICES=${job.gpu_id} ${pythonPath} ${runFilePath} ${configPath}`); + console.log('Spawning command:', `AITK_JOB_ID=${jobID} CUDA_VISIBLE_DEVICES=${job.gpu_ids} ${pythonPath} ${runFilePath} ${configPath}`); // start job const subprocess = spawn(pythonPath, [runFilePath, configPath], { @@ -79,7 +79,7 @@ export async function GET(request: NextRequest, { params }: { params: { jobID: s env: { ...process.env, AITK_JOB_ID: jobID, - CUDA_VISIBLE_DEVICES: `${job.gpu_id}`, + CUDA_VISIBLE_DEVICES: `${job.gpu_ids}`, }, cwd: TOOLKIT_ROOT, }); diff --git a/ui/src/app/api/jobs/route.ts b/ui/src/app/api/jobs/route.ts index 3e0df9f0..4d9b954e 100644 --- a/ui/src/app/api/jobs/route.ts +++ b/ui/src/app/api/jobs/route.ts @@ -9,16 +9,16 @@ export async function GET(request: Request) { try { if (id) { - const training = await prisma.job.findUnique({ + const job = await prisma.job.findUnique({ where: { id }, }); - return NextResponse.json(training); + return NextResponse.json(job); } - const trainings = await prisma.job.findMany({ + const jobs = await prisma.job.findMany({ orderBy: { created_at: 'desc' }, }); - return NextResponse.json(trainings); + return NextResponse.json({ jobs: jobs }); } catch (error) { console.error(error); return NextResponse.json({ error: 'Failed to fetch training data' }, { status: 500 }); @@ -28,7 +28,7 @@ export async function GET(request: Request) { export async function POST(request: Request) { try { const body = await request.json(); - const { id, name, job_config, gpu_id } = body; + const { id, name, job_config, gpu_ids } = body; if (id) { // Update existing training @@ -36,7 +36,7 @@ export async function POST(request: Request) { where: { id }, data: { name, - gpu_id, + gpu_ids, job_config: JSON.stringify(job_config), }, }); @@ -46,7 +46,7 @@ export async function POST(request: Request) { const training = await prisma.job.create({ data: { name, - gpu_id, + gpu_ids, job_config: JSON.stringify(job_config), }, }); diff --git a/ui/src/app/jobs/[jobID]/page.tsx b/ui/src/app/jobs/[jobID]/page.tsx index 4929804f..e28ed956 100644 --- a/ui/src/app/jobs/[jobID]/page.tsx +++ b/ui/src/app/jobs/[jobID]/page.tsx @@ -58,7 +58,7 @@ export default function JobPage({ params }: { params: { jobID: string } }) {
ID: {job.id}
Name: {job.name}
-GPU: {job.gpu_id}
+GPUs: {job.gpu_ids}
Status: {job.status}
Info: {job.info}
Step: {job.step}
diff --git a/ui/src/app/jobs/new/page.tsx b/ui/src/app/jobs/new/page.tsx index 4473df18..e5247da1 100644 --- a/ui/src/app/jobs/new/page.tsx +++ b/ui/src/app/jobs/new/page.tsx @@ -22,7 +22,7 @@ export default function TrainingForm() { const router = useRouter(); const searchParams = useSearchParams(); const runId = searchParams.get('id'); - const [gpuID, setGpuID] = useState