mirror of
https://github.com/ostris/ai-toolkit.git
synced 2026-01-26 16:39:47 +00:00
Add HF token to env when spawing via ui
This commit is contained in:
@@ -4,7 +4,7 @@ import { TOOLKIT_ROOT, defaultTrainFolder } from '@/paths';
|
||||
import { spawn } from 'child_process';
|
||||
import path from 'path';
|
||||
import fs from 'fs';
|
||||
import { getTrainingFolder } from '@/server/settings';
|
||||
import { getTrainingFolder, getHFToken } from '@/server/settings';
|
||||
|
||||
const prisma = new PrismaClient();
|
||||
|
||||
@@ -60,11 +60,21 @@ export async function GET(request: NextRequest, { params }: { params: { jobID: s
|
||||
if (!fs.existsSync(runFilePath)) {
|
||||
return NextResponse.json({ error: 'run.py not found' }, { status: 500 });
|
||||
}
|
||||
const additionalEnv: any = {
|
||||
AITK_JOB_ID: jobID,
|
||||
CUDA_VISIBLE_DEVICES: `${job.gpu_ids}`,
|
||||
};
|
||||
|
||||
console.log(
|
||||
'Spawning command:',
|
||||
`AITK_JOB_ID=${jobID} CUDA_VISIBLE_DEVICES=${job.gpu_ids} ${pythonPath} ${runFilePath} ${configPath}`,
|
||||
);
|
||||
// HF_TOKEN
|
||||
const hfToken = await getHFToken();
|
||||
if (hfToken && hfToken.trim() !== '') {
|
||||
additionalEnv.HF_TOKEN = hfToken;
|
||||
}
|
||||
|
||||
// console.log(
|
||||
// 'Spawning command:',
|
||||
// `AITK_JOB_ID=${jobID} CUDA_VISIBLE_DEVICES=${job.gpu_ids} ${pythonPath} ${runFilePath} ${configPath}`,
|
||||
// );
|
||||
|
||||
// start job
|
||||
const subprocess = spawn(pythonPath, [runFilePath, configPath], {
|
||||
@@ -72,8 +82,7 @@ export async function GET(request: NextRequest, { params }: { params: { jobID: s
|
||||
stdio: 'ignore',
|
||||
env: {
|
||||
...process.env,
|
||||
AITK_JOB_ID: jobID,
|
||||
CUDA_VISIBLE_DEVICES: `${job.gpu_ids}`,
|
||||
...additionalEnv,
|
||||
},
|
||||
cwd: TOOLKIT_ROOT,
|
||||
});
|
||||
|
||||
@@ -29,7 +29,6 @@ export const getDatasetsRoot = async () => {
|
||||
return datasetsPath as string;
|
||||
};
|
||||
|
||||
|
||||
export const getTrainingFolder = async () => {
|
||||
const key = 'TRAINING_FOLDER';
|
||||
let trainingRoot = myCache.get(key) as string;
|
||||
@@ -48,3 +47,22 @@ export const getTrainingFolder = async () => {
|
||||
myCache.set(key, trainingRoot);
|
||||
return trainingRoot as string;
|
||||
};
|
||||
|
||||
export const getHFToken = async () => {
|
||||
const key = 'HF_TOKEN';
|
||||
let token = myCache.get(key) as string;
|
||||
if (token) {
|
||||
return token;
|
||||
}
|
||||
let row = await prisma.settings.findFirst({
|
||||
where: {
|
||||
key: key,
|
||||
},
|
||||
});
|
||||
token = '';
|
||||
if (row?.value && row.value !== '') {
|
||||
token = row.value;
|
||||
}
|
||||
myCache.set(key, token);
|
||||
return token;
|
||||
};
|
||||
|
||||
Reference in New Issue
Block a user