mirror of
https://github.com/ostris/ai-toolkit.git
synced 2026-01-26 16:39:47 +00:00
Mor ui work
This commit is contained in:
@@ -16,7 +16,7 @@ model Settings {
|
||||
model Job {
|
||||
id String @id @default(uuid())
|
||||
name String @unique
|
||||
gpu_id Int
|
||||
gpu_ids String
|
||||
job_config String // JSON string
|
||||
created_at DateTime @default(now())
|
||||
updated_at DateTime @updatedAt
|
||||
|
||||
@@ -70,7 +70,7 @@ export async function GET(request: NextRequest, { params }: { params: { jobID: s
|
||||
return NextResponse.json({ error: 'run.py not found' }, { status: 500 });
|
||||
}
|
||||
|
||||
console.log('Spawning command:', `AITK_JOB_ID=${jobID} CUDA_VISIBLE_DEVICES=${job.gpu_id} ${pythonPath} ${runFilePath} ${configPath}`);
|
||||
console.log('Spawning command:', `AITK_JOB_ID=${jobID} CUDA_VISIBLE_DEVICES=${job.gpu_ids} ${pythonPath} ${runFilePath} ${configPath}`);
|
||||
|
||||
// start job
|
||||
const subprocess = spawn(pythonPath, [runFilePath, configPath], {
|
||||
@@ -79,7 +79,7 @@ export async function GET(request: NextRequest, { params }: { params: { jobID: s
|
||||
env: {
|
||||
...process.env,
|
||||
AITK_JOB_ID: jobID,
|
||||
CUDA_VISIBLE_DEVICES: `${job.gpu_id}`,
|
||||
CUDA_VISIBLE_DEVICES: `${job.gpu_ids}`,
|
||||
},
|
||||
cwd: TOOLKIT_ROOT,
|
||||
});
|
||||
|
||||
@@ -9,16 +9,16 @@ export async function GET(request: Request) {
|
||||
|
||||
try {
|
||||
if (id) {
|
||||
const training = await prisma.job.findUnique({
|
||||
const job = await prisma.job.findUnique({
|
||||
where: { id },
|
||||
});
|
||||
return NextResponse.json(training);
|
||||
return NextResponse.json(job);
|
||||
}
|
||||
|
||||
const trainings = await prisma.job.findMany({
|
||||
const jobs = await prisma.job.findMany({
|
||||
orderBy: { created_at: 'desc' },
|
||||
});
|
||||
return NextResponse.json(trainings);
|
||||
return NextResponse.json({ jobs: jobs });
|
||||
} catch (error) {
|
||||
console.error(error);
|
||||
return NextResponse.json({ error: 'Failed to fetch training data' }, { status: 500 });
|
||||
@@ -28,7 +28,7 @@ export async function GET(request: Request) {
|
||||
export async function POST(request: Request) {
|
||||
try {
|
||||
const body = await request.json();
|
||||
const { id, name, job_config, gpu_id } = body;
|
||||
const { id, name, job_config, gpu_ids } = body;
|
||||
|
||||
if (id) {
|
||||
// Update existing training
|
||||
@@ -36,7 +36,7 @@ export async function POST(request: Request) {
|
||||
where: { id },
|
||||
data: {
|
||||
name,
|
||||
gpu_id,
|
||||
gpu_ids,
|
||||
job_config: JSON.stringify(job_config),
|
||||
},
|
||||
});
|
||||
@@ -46,7 +46,7 @@ export async function POST(request: Request) {
|
||||
const training = await prisma.job.create({
|
||||
data: {
|
||||
name,
|
||||
gpu_id,
|
||||
gpu_ids,
|
||||
job_config: JSON.stringify(job_config),
|
||||
},
|
||||
});
|
||||
|
||||
@@ -58,7 +58,7 @@ export default function JobPage({ params }: { params: { jobID: string } }) {
|
||||
<h2 className="text-lg font-semibold">Job Details</h2>
|
||||
<p className="text-gray-400">ID: {job.id}</p>
|
||||
<p className="text-gray-400">Name: {job.name}</p>
|
||||
<p className="text-gray-400">GPU: {job.gpu_id}</p>
|
||||
<p className="text-gray-400">GPUs: {job.gpu_ids}</p>
|
||||
<p className="text-gray-400">Status: {job.status}</p>
|
||||
<p className="text-gray-400">Info: {job.info}</p>
|
||||
<p className="text-gray-400">Step: {job.step}</p>
|
||||
|
||||
@@ -22,7 +22,7 @@ export default function TrainingForm() {
|
||||
const router = useRouter();
|
||||
const searchParams = useSearchParams();
|
||||
const runId = searchParams.get('id');
|
||||
const [gpuID, setGpuID] = useState<number | null>(null);
|
||||
const [gpuIDs, setGpuIDs] = useState<string | null>(null);
|
||||
const { settings, isSettingsLoaded } = useSettings();
|
||||
const { gpuList, isGPUInfoLoaded } = useGPUInfo();
|
||||
const { datasets, status: datasetFetchStatus } = useDatasetList();
|
||||
@@ -52,7 +52,7 @@ export default function TrainingForm() {
|
||||
fetch(`/api/jobs?id=${runId}`)
|
||||
.then(res => res.json())
|
||||
.then(data => {
|
||||
setGpuID(data.gpu_id);
|
||||
setGpuIDs(data.gpu_ids);
|
||||
setJobConfig(JSON.parse(data.job_config));
|
||||
})
|
||||
.catch(error => console.error('Error fetching training:', error));
|
||||
@@ -61,8 +61,8 @@ export default function TrainingForm() {
|
||||
|
||||
useEffect(() => {
|
||||
if (isGPUInfoLoaded) {
|
||||
if (gpuID === null && gpuList.length > 0) {
|
||||
setGpuID(gpuList[0]);
|
||||
if (gpuIDs === null && gpuList.length > 0) {
|
||||
setGpuIDs(`${gpuList[0]}`);
|
||||
}
|
||||
}
|
||||
}, [gpuList, isGPUInfoLoaded]);
|
||||
@@ -73,8 +73,8 @@ export default function TrainingForm() {
|
||||
}
|
||||
}, [settings, isSettingsLoaded]);
|
||||
|
||||
const handleSubmit = async (e: React.FormEvent) => {
|
||||
e.preventDefault();
|
||||
const saveJob = async () => {
|
||||
if (status === 'saving') return;
|
||||
setStatus('saving');
|
||||
|
||||
try {
|
||||
@@ -86,7 +86,7 @@ export default function TrainingForm() {
|
||||
body: JSON.stringify({
|
||||
id: runId,
|
||||
name: jobConfig.config.name,
|
||||
gpu_id: gpuID,
|
||||
gpu_ids: gpuIDs,
|
||||
job_config: jobConfig,
|
||||
}),
|
||||
});
|
||||
@@ -106,6 +106,11 @@ export default function TrainingForm() {
|
||||
}
|
||||
};
|
||||
|
||||
const handleSubmit = async (e: React.FormEvent) => {
|
||||
e.preventDefault();
|
||||
saveJob();
|
||||
};
|
||||
|
||||
return (
|
||||
<>
|
||||
<TopBar>
|
||||
@@ -118,6 +123,15 @@ export default function TrainingForm() {
|
||||
<h1 className="text-lg">{runId ? 'Edit Training Job' : 'New Training Job'}</h1>
|
||||
</div>
|
||||
<div className="flex-1"></div>
|
||||
<div>
|
||||
<Button
|
||||
className="text-gray-200 bg-green-800 px-3 py-1 rounded-md"
|
||||
onClick={() => saveJob()}
|
||||
disabled={status === 'saving'}
|
||||
>
|
||||
{status === 'saving' ? 'Saving...' : runId ? 'Update Job' : 'Create Job'}
|
||||
</Button>
|
||||
</div>
|
||||
</TopBar>
|
||||
<MainContent>
|
||||
<form onSubmit={handleSubmit} className="space-y-8">
|
||||
@@ -132,9 +146,9 @@ export default function TrainingForm() {
|
||||
/>
|
||||
<SelectInput
|
||||
label="GPU ID"
|
||||
value={`${gpuID}`}
|
||||
value={`${gpuIDs}`}
|
||||
className="pt-2"
|
||||
onChange={value => setGpuID(parseInt(value))}
|
||||
onChange={value => setGpuIDs(value)}
|
||||
options={gpuList.map(gpu => ({ value: `${gpu}`, label: `GPU #${gpu}` }))}
|
||||
/>
|
||||
</Card>
|
||||
@@ -553,17 +567,10 @@ export default function TrainingForm() {
|
||||
</Card>
|
||||
</div>
|
||||
|
||||
<button
|
||||
type="submit"
|
||||
disabled={status === 'saving'}
|
||||
className="w-full px-4 py-2 bg-gray-700 hover:bg-gray-600 rounded-lg transition-colors disabled:opacity-50 disabled:cursor-not-allowed"
|
||||
>
|
||||
{status === 'saving' ? 'Saving...' : runId ? 'Update Training' : 'Create Training'}
|
||||
</button>
|
||||
|
||||
{status === 'success' && <p className="text-green-500 text-center">Training saved successfully!</p>}
|
||||
{status === 'error' && <p className="text-red-500 text-center">Error saving training. Please try again.</p>}
|
||||
</form>
|
||||
<div className="pt-20"></div>
|
||||
</MainContent>
|
||||
</>
|
||||
);
|
||||
|
||||
@@ -38,7 +38,7 @@ export default function JobsTable(props: JobsTableProps) {
|
||||
</tr>
|
||||
</thead>
|
||||
<tbody>
|
||||
{jobs.map((job, index) => {
|
||||
{jobs?.map((job, index) => {
|
||||
const jobConfig: JobConfig = JSON.parse(job.job_config);
|
||||
const totalSteps = jobConfig.config.process[0].train.steps;
|
||||
|
||||
@@ -68,7 +68,7 @@ export default function JobsTable(props: JobsTableProps) {
|
||||
</div>
|
||||
</div>
|
||||
</td>
|
||||
<td className="px-3 py-2">{job.gpu_id}</td>
|
||||
<td className="px-3 py-2">{job.gpu_ids}</td>
|
||||
<td className={`px-3 py-2 ${statusClass}`}>{job.status}</td>
|
||||
<td className="px-3 py-2 truncate max-w-xs">{job.info}</td>
|
||||
</tr>
|
||||
|
||||
@@ -13,8 +13,13 @@ export default function useJobsList() {
|
||||
.then(res => res.json())
|
||||
.then(data => {
|
||||
console.log('Jobs:', data);
|
||||
setJobs(data);
|
||||
setStatus('success');
|
||||
if (data.error) {
|
||||
console.log('Error fetching jobs:', data.error);
|
||||
setStatus('error');
|
||||
} else {
|
||||
setJobs(data.jobs);
|
||||
setStatus('success');
|
||||
}
|
||||
})
|
||||
.catch(error => {
|
||||
console.error('Error fetching datasets:', error);
|
||||
|
||||
Reference in New Issue
Block a user