diff --git a/ui/src/app/api/jobs/[jobID]/start/route.ts b/ui/src/app/api/jobs/[jobID]/start/route.ts index dcf5795d..758f1d68 100644 --- a/ui/src/app/api/jobs/[jobID]/start/route.ts +++ b/ui/src/app/api/jobs/[jobID]/start/route.ts @@ -4,7 +4,7 @@ import { TOOLKIT_ROOT, defaultTrainFolder } from '@/paths'; import { spawn } from 'child_process'; import path from 'path'; import fs from 'fs'; -import { getTrainingFolder } from '@/server/settings'; +import { getTrainingFolder, getHFToken } from '@/server/settings'; const prisma = new PrismaClient(); @@ -60,11 +60,21 @@ export async function GET(request: NextRequest, { params }: { params: { jobID: s if (!fs.existsSync(runFilePath)) { return NextResponse.json({ error: 'run.py not found' }, { status: 500 }); } + const additionalEnv: any = { + AITK_JOB_ID: jobID, + CUDA_VISIBLE_DEVICES: `${job.gpu_ids}`, + }; - console.log( - 'Spawning command:', - `AITK_JOB_ID=${jobID} CUDA_VISIBLE_DEVICES=${job.gpu_ids} ${pythonPath} ${runFilePath} ${configPath}`, - ); + // HF_TOKEN + const hfToken = await getHFToken(); + if (hfToken && hfToken.trim() !== '') { + additionalEnv.HF_TOKEN = hfToken; + } + + // console.log( + // 'Spawning command:', + // `AITK_JOB_ID=${jobID} CUDA_VISIBLE_DEVICES=${job.gpu_ids} ${pythonPath} ${runFilePath} ${configPath}`, + // ); // start job const subprocess = spawn(pythonPath, [runFilePath, configPath], { @@ -72,8 +82,7 @@ export async function GET(request: NextRequest, { params }: { params: { jobID: s stdio: 'ignore', env: { ...process.env, - AITK_JOB_ID: jobID, - CUDA_VISIBLE_DEVICES: `${job.gpu_ids}`, + ...additionalEnv, }, cwd: TOOLKIT_ROOT, }); diff --git a/ui/src/server/settings.ts b/ui/src/server/settings.ts index a93d2b95..efebc8cc 100644 --- a/ui/src/server/settings.ts +++ b/ui/src/server/settings.ts @@ -29,7 +29,6 @@ export const getDatasetsRoot = async () => { return datasetsPath as string; }; - export const getTrainingFolder = async () => { const key = 'TRAINING_FOLDER'; let trainingRoot = myCache.get(key) as string; @@ -48,3 +47,22 @@ export const getTrainingFolder = async () => { myCache.set(key, trainingRoot); return trainingRoot as string; }; + +export const getHFToken = async () => { + const key = 'HF_TOKEN'; + let token = myCache.get(key) as string; + if (token) { + return token; + } + let row = await prisma.settings.findFirst({ + where: { + key: key, + }, + }); + token = ''; + if (row?.value && row.value !== '') { + token = row.value; + } + myCache.set(key, token); + return token; +};