mirror of
https://github.com/ostris/ai-toolkit.git
synced 2026-05-01 03:31:35 +00:00
210 lines
6.9 KiB
TypeScript
210 lines
6.9 KiB
TypeScript
'use client';
|
|
|
|
import { useEffect, useState } from 'react';
|
|
import { useSearchParams, useRouter } from 'next/navigation';
|
|
import { options, modelArchs, isVideoModelFromArch } from './options';
|
|
import { defaultJobConfig, defaultDatasetConfig } from './jobConfig';
|
|
import { JobConfig } from '@/types';
|
|
import { objectCopy } from '@/utils/basic';
|
|
import { useNestedState } from '@/utils/hooks';
|
|
import { TextInput, SelectInput, Checkbox, FormGroup, NumberInput } from '@/components/formInputs';
|
|
import Card from '@/components/Card';
|
|
import { X } from 'lucide-react';
|
|
import useSettings from '@/hooks/useSettings';
|
|
import useGPUInfo from '@/hooks/useGPUInfo';
|
|
import useDatasetList from '@/hooks/useDatasetList';
|
|
import path from 'path';
|
|
import { TopBar, MainContent } from '@/components/layout';
|
|
import { Button } from '@headlessui/react';
|
|
import { FaChevronLeft } from 'react-icons/fa';
|
|
import SimpleJob from './SimpleJob';
|
|
import AdvancedJob from './AdvancedJob';
|
|
import ErrorBoundary from '@/components/ErrorBoundary';
|
|
import { apiClient } from '@/utils/api';
|
|
|
|
// True when NODE_ENV is exactly 'development'.
// NOTE(review): not referenced anywhere in the visible part of this file — confirm it is still needed.
const isDev = process.env.NODE_ENV === 'development';
|
|
|
|
/**
 * Form page for creating a new training job or editing an existing one.
 *
 * When the URL has an `id` query parameter, the matching job is fetched from
 * `/api/jobs` and its stored GPU ids and job config are loaded into the form;
 * otherwise the form starts from a copy of `defaultJobConfig`. The user can
 * toggle between the simple and the advanced editor, and saving posts the
 * config to `/api/jobs` and navigates to the job's page.
 */
export default function TrainingForm() {
  const router = useRouter();
  const searchParams = useSearchParams();
  // Present when editing an existing job, null when creating a new one.
  const runId = searchParams.get('id');
  const [gpuIDs, setGpuIDs] = useState<string | null>(null);
  const { settings, isSettingsLoaded } = useSettings();
  const { gpuList, isGPUInfoLoaded } = useGPUInfo();
  const { datasets, status: datasetFetchStatus } = useDatasetList();
  const [datasetOptions, setDatasetOptions] = useState<{ value: string; label: string }[]>([]);
  const [showAdvancedView, setShowAdvancedView] = useState(false);

  const [jobConfig, setJobConfig] = useNestedState<JobConfig>(objectCopy(defaultJobConfig));
  const [status, setStatus] = useState<'idle' | 'saving' | 'success' | 'error'>('idle');

  // Build the dataset dropdown options once both settings and the dataset list
  // are available, and point any dataset entry that still uses the placeholder
  // default path at the first real dataset.
  useEffect(() => {
    if (!isSettingsLoaded) return;
    if (datasetFetchStatus !== 'success') return;

    // Named differently from the `datasetOptions` state to avoid shadowing it.
    const nextDatasetOptions = datasets.map(name => ({
      value: path.join(settings.DATASETS_FOLDER, name),
      label: name,
    }));
    setDatasetOptions(nextDatasetOptions);

    // Nothing to substitute for the placeholder path when no datasets exist.
    if (nextDatasetOptions.length === 0) return;

    const defaultDatasetPath = defaultDatasetConfig.folder_path;
    for (let i = 0; i < jobConfig.config.process[0].datasets.length; i++) {
      const dataset = jobConfig.config.process[0].datasets[i];
      if (dataset.folder_path === defaultDatasetPath) {
        setJobConfig(nextDatasetOptions[0].value, `config.process[0].datasets[${i}].folder_path`);
      }
    }
  }, [datasets, settings, isSettingsLoaded, datasetFetchStatus]);

  // When editing, load the stored job and populate the form from it.
  useEffect(() => {
    if (!runId) return;

    // Set by the cleanup below so a response that arrives after unmount, or
    // after `runId` changed, cannot overwrite newer state.
    let cancelled = false;

    apiClient
      .get(`/api/jobs?id=${encodeURIComponent(runId)}`)
      .then(res => res.data)
      .then(data => {
        if (cancelled) return;
        console.log('Training:', data);
        setGpuIDs(data.gpu_ids);
        // `job_config` is parsed as JSON here; a parse failure is reported by
        // the catch below.
        setJobConfig(JSON.parse(data.job_config));
      })
      .catch(error => console.error('Error fetching training:', error));

    return () => {
      cancelled = true;
    };
  }, [runId]);

  // Default to the first reported GPU unless a GPU selection already exists.
  useEffect(() => {
    if (isGPUInfoLoaded && gpuIDs === null && gpuList.length > 0) {
      setGpuIDs(`${gpuList[0].index}`);
    }
  }, [gpuList, isGPUInfoLoaded]);

  // Keep the job's training folder in sync with the loaded settings.
  useEffect(() => {
    if (isSettingsLoaded) {
      setJobConfig(settings.TRAINING_FOLDER, 'config.process[0].training_folder');
    }
  }, [settings, isSettingsLoaded]);

  /**
   * Posts the current job config to `/api/jobs` and, on success, navigates to
   * the job's page. Calls made while a save is already in flight are ignored.
   * Failures are reported with an alert and reflected in `status` as 'error'.
   */
  const saveJob = async () => {
    if (status === 'saving') return;
    setStatus('saving');

    apiClient
      .post('/api/jobs', {
        id: runId,
        name: jobConfig.config.name,
        gpu_ids: gpuIDs,
        job_config: jobConfig,
      })
      .then(res => {
        setStatus('success');
        if (runId) {
          router.push(`/jobs/${runId}`);
        } else {
          router.push(`/jobs/${res.data.id}`);
        }
      })
      .catch(error => {
        // Leave the 'saving' state right away so the save button is usable
        // again as soon as the alert is dismissed.
        setStatus('error');
        if (error.response?.status === 409) {
          alert('Training name already exists. Please choose a different name.');
        } else {
          alert('Failed to save job. Please try again.');
        }
        console.error('Error saving training:', error);
      })
      .finally(() =>
        setTimeout(() => {
          // Reset the transient 'success'/'error' state, but never clobber a
          // newer save that started while this timer was pending.
          setStatus(current => (current === 'saving' ? current : 'idle'));
        }, 2000),
      );
  };

  /** Form submit handler: suppresses the native submit and saves the job. */
  const handleSubmit = async (e: React.FormEvent) => {
    e.preventDefault();
    saveJob();
  };

  return (
    <>
      <TopBar>
        <div>
          <Button className="text-gray-500 dark:text-gray-300 px-3 mt-1" onClick={() => history.back()}>
            <FaChevronLeft />
          </Button>
        </div>
        <div>
          <h1 className="text-lg">{runId ? 'Edit Training Job' : 'New Training Job'}</h1>
        </div>
        <div className="flex-1"></div>
        {/* The GPU selector in the top bar is only rendered in the advanced view. */}
        {showAdvancedView && (
          <>
            <div>
              <SelectInput
                value={`${gpuIDs}`}
                onChange={value => setGpuIDs(value)}
                options={gpuList.map((gpu: any) => ({ value: `${gpu.index}`, label: `GPU #${gpu.index}` }))}
              />
            </div>
            <div className="mx-4 bg-gray-200 dark:bg-gray-800 w-1 h-6"></div>
          </>
        )}

        <div className="pr-2">
          <Button
            className="text-gray-200 bg-gray-800 px-3 py-1 rounded-md"
            onClick={() => setShowAdvancedView(!showAdvancedView)}
          >
            {showAdvancedView ? 'Show Simple' : 'Show Advanced'}
          </Button>
        </div>
        <div>
          <Button
            className="text-gray-200 bg-green-800 px-3 py-1 rounded-md"
            onClick={() => saveJob()}
            disabled={status === 'saving'}
          >
            {status === 'saving' ? 'Saving...' : runId ? 'Update Job' : 'Create Job'}
          </Button>
        </div>
      </TopBar>

      {showAdvancedView ? (
        <div className="pt-[48px] absolute top-0 left-0 w-full h-full overflow-auto">
          <AdvancedJob
            jobConfig={jobConfig}
            setJobConfig={setJobConfig}
            status={status}
            handleSubmit={handleSubmit}
            runId={runId}
            gpuIDs={gpuIDs}
            setGpuIDs={setGpuIDs}
            gpuList={gpuList}
            datasetOptions={datasetOptions}
            settings={settings}
          />
        </div>
      ) : (
        <MainContent>
          {/* If the simple editor throws while rendering, show a hint to switch views. */}
          <ErrorBoundary
            fallback={
              <div className="flex items-center justify-center h-64 text-lg text-red-600 font-medium bg-red-100 dark:bg-red-900/20 dark:text-red-400 border border-red-300 dark:border-red-700 rounded-lg">
                Advanced job detected. Please switch to advanced view to continue.
              </div>
            }
          >
            <SimpleJob
              jobConfig={jobConfig}
              setJobConfig={setJobConfig}
              status={status}
              handleSubmit={handleSubmit}
              runId={runId}
              gpuIDs={gpuIDs}
              setGpuIDs={setGpuIDs}
              gpuList={gpuList}
              datasetOptions={datasetOptions}
            />
          </ErrorBoundary>

          <div className="pt-20"></div>
        </MainContent>
      )}
    </>
  );
}