diff --git a/ui/src/app/datasets/page.tsx b/ui/src/app/datasets/page.tsx index 3d1d7a2b..b5304c04 100644 --- a/ui/src/app/datasets/page.tsx +++ b/ui/src/app/datasets/page.tsx @@ -5,32 +5,12 @@ import Card from '@/components/Card'; import { Modal } from '@/components/Modal'; import Link from 'next/link'; import { TextInput } from '@/components/formInputs'; +import useDatasetList from '@/hooks/useDatasetList'; export default function Datasets() { - const [datasets, setDatasets] = useState([]); + const { datasets, status, refreshDatasets } = useDatasetList(); const [newDatasetName, setNewDatasetName] = useState(''); const [isNewDatasetModalOpen, setIsNewDatasetModalOpen] = useState(false); - const [status, setStatus] = useState<'idle' | 'loading' | 'success' | 'error'>('idle'); - - const refreshDatasets = () => { - setStatus('loading'); - fetch('/api/datasets/list') - .then(res => res.json()) - .then(data => { - console.log('Datasets:', data); - // sort - data.sort((a: string, b: string) => a.localeCompare(b)); - setDatasets(data); - setStatus('success'); - }) - .catch(error => { - console.error('Error fetching datasets:', error); - setStatus('error'); - }); - }; - useEffect(() => { - refreshDatasets(); - }, []); return ( <>
diff --git a/ui/src/app/settings/page.tsx b/ui/src/app/settings/page.tsx index 9e7a7097..3b50a282 100644 --- a/ui/src/app/settings/page.tsx +++ b/ui/src/app/settings/page.tsx @@ -1,29 +1,12 @@ 'use client'; import { useEffect, useState } from 'react'; +import useSettings from '@/hooks/useSettings'; export default function Settings() { - const [settings, setSettings] = useState({ - HF_TOKEN: '', - TRAINING_FOLDER: '', - DATASETS_FOLDER: '', - }); + const { settings, setSettings } = useSettings(); const [status, setStatus] = useState<'idle' | 'saving' | 'success' | 'error'>('idle'); - useEffect(() => { - // Fetch current settings - fetch('/api/settings') - .then(res => res.json()) - .then(data => { - setSettings({ - HF_TOKEN: data.HF_TOKEN || '', - TRAINING_FOLDER: data.TRAINING_FOLDER || '', - DATASETS_FOLDER: data.DATASETS_FOLDER || '', - }); - }) - .catch(error => console.error('Error fetching settings:', error)); - }, []); - const handleSubmit = async (e: React.FormEvent) => { e.preventDefault(); setStatus('saving'); diff --git a/ui/src/app/train/page.tsx b/ui/src/app/train/page.tsx index 7298b1f1..f7936b6a 100644 --- a/ui/src/app/train/page.tsx +++ b/ui/src/app/train/page.tsx @@ -1,28 +1,50 @@ 'use client'; // todo update training folder from settings -import { useEffect, useMemo, useState } from 'react'; +import { use, useEffect, useMemo, useState } from 'react'; import { useSearchParams, useRouter } from 'next/navigation'; import { options } from './options'; -import { GPUApiResponse } from '@/types'; import { defaultJobConfig, defaultDatasetConfig } from './jobConfig'; import { JobConfig } from '@/types'; import { objectCopy } from '@/utils/basic'; import { useNestedState } from '@/utils/hooks'; import { TextInput, SelectInput, Checkbox, FormGroup, NumberInput } from '@/components/formInputs'; import Card from '@/components/Card'; -import { Trash, X } from 'lucide-react'; +import { X } from 'lucide-react'; +import useSettings from '@/hooks/useSettings'; +import useGPUInfo from '@/hooks/useGPUInfo'; +import useDatasetList from '@/hooks/useDatasetList'; +import path from 'path'; export default function TrainingForm() { const router = useRouter(); const searchParams = useSearchParams(); const runId = searchParams.get('id'); const [gpuID, setGpuID] = useState(null); - const [gpuList, setGpuList] = useState([]); + const { settings, isSettingsLoaded } = useSettings(); + const { gpuList, isGPUInfoLoaded } = useGPUInfo(); + const { datasets, status: datasetFetchStatus } = useDatasetList(); + const [datasetOptions, setDatasetOptions] = useState<{ value: string; label: string }[]>([]); const [jobConfig, setJobConfig] = useNestedState(objectCopy(defaultJobConfig)); const [status, setStatus] = useState<'idle' | 'saving' | 'success' | 'error'>('idle'); + useEffect(() => { + if (!isSettingsLoaded) return; + if (datasetFetchStatus !== 'success') return; + + const datasetOptions = datasets.map(name => ({ value: path.join(settings.DATASETS_FOLDER, name), label: name })); + setDatasetOptions(datasetOptions); + const defaultDatasetPath = defaultDatasetConfig.folder_path; + + for(let i = 0; i < jobConfig.config.process[0].datasets.length; i++) { + const dataset = jobConfig.config.process[0].datasets[i]; + if (dataset.folder_path === defaultDatasetPath) { + setJobConfig(datasetOptions[0].value, `config.process[0].datasets[${i}].folder_path`); + } + } + }, [datasets, settings, isSettingsLoaded, datasetFetchStatus]); + useEffect(() => { if (runId) { fetch(`/api/training?id=${runId}`) @@ -36,33 +58,18 @@ export default function TrainingForm() { }, [runId]); useEffect(() => { - const fetchGpuInfo = async () => { - try { - const response = await fetch('/api/gpu'); - - if (!response.ok) { - throw new Error(`HTTP error! Status: ${response.status}`); - } - - const data: GPUApiResponse = await response.json(); - setGpuList(data.gpus.map(gpu => gpu.index).sort()); - } catch (err) { - console.log(`Failed to fetch GPU data: ${err instanceof Error ? err.message : String(err)}`); - } finally { - // setLoading(false); + if (isGPUInfoLoaded) { + if (gpuID === null && gpuList.length > 0) { + setGpuID(gpuList[0]); } - }; + } + }, [gpuList, isGPUInfoLoaded]); - fetch('/api/settings') - .then(res => res.json()) - .then(data => { - setJobConfig(data.TRAINING_FOLDER, 'config.process[0].training_folder'); - }) - .catch(error => console.error('Error fetching settings:', error)); - - // Fetch immediately on component mount - fetchGpuInfo(); - }, []); + useEffect(() => { + if (isSettingsLoaded) { + setJobConfig(settings.TRAINING_FOLDER, 'config.process[0].training_folder'); + } + }, [settings, isSettingsLoaded]); const handleSubmit = async (e: React.FormEvent) => { e.preventDefault(); @@ -296,14 +303,20 @@ export default function TrainingForm() {

Dataset {i + 1}

- setJobConfig(value, `config.process[0].datasets[${i}].folder_path`)} + options={datasetOptions} + /> + {/* setJobConfig(value, `config.process[0].datasets[${i}].folder_path`)} placeholder="eg. /path/to/images/folder" required - /> - */} + {/* + /> */}
([]); + const [status, setStatus] = useState<'idle' | 'loading' | 'success' | 'error'>('idle'); + + const refreshDatasets = () => { + setStatus('loading'); + fetch('/api/datasets/list') + .then(res => res.json()) + .then(data => { + console.log('Datasets:', data); + // sort + data.sort((a: string, b: string) => a.localeCompare(b)); + setDatasets(data); + setStatus('success'); + }) + .catch(error => { + console.error('Error fetching datasets:', error); + setStatus('error'); + }); + }; + useEffect(() => { + refreshDatasets(); + }, []); + + return { datasets, setDatasets, status, refreshDatasets }; +} diff --git a/ui/src/hooks/useGPUInfo.tsx b/ui/src/hooks/useGPUInfo.tsx new file mode 100644 index 00000000..dd7e32fe --- /dev/null +++ b/ui/src/hooks/useGPUInfo.tsx @@ -0,0 +1,32 @@ +'use client'; + +import { GPUApiResponse } from '@/types'; +import { useEffect, useState } from 'react'; + +export default function useGPUInfo() { + const [gpuList, setGpuList] = useState([]); + const [isGPUInfoLoaded, setIsLoaded] = useState(false); + useEffect(() => { + const fetchGpuInfo = async () => { + try { + const response = await fetch('/api/gpu'); + + if (!response.ok) { + throw new Error(`HTTP error! Status: ${response.status}`); + } + + const data: GPUApiResponse = await response.json(); + setGpuList(data.gpus.map(gpu => gpu.index).sort()); + } catch (err) { + console.log(`Failed to fetch GPU data: ${err instanceof Error ? err.message : String(err)}`); + } finally { + setIsLoaded(true); + } + }; + + // Fetch immediately on component mount + fetchGpuInfo(); + }, []); + + return { gpuList, setGpuList, isGPUInfoLoaded }; +} diff --git a/ui/src/hooks/useSettings.tsx b/ui/src/hooks/useSettings.tsx new file mode 100644 index 00000000..15b5175e --- /dev/null +++ b/ui/src/hooks/useSettings.tsx @@ -0,0 +1,28 @@ +'use client'; + +import { useEffect, useState } from 'react'; + +export default function useSettings() { + const [settings, setSettings] = useState({ + HF_TOKEN: '', + TRAINING_FOLDER: '', + DATASETS_FOLDER: '', + }); + const [isSettingsLoaded, setIsLoaded] = useState(false); + useEffect(() => { + // Fetch current settings + fetch('/api/settings') + .then(res => res.json()) + .then(data => { + setSettings({ + HF_TOKEN: data.HF_TOKEN || '', + TRAINING_FOLDER: data.TRAINING_FOLDER || '', + DATASETS_FOLDER: data.DATASETS_FOLDER || '', + }); + setIsLoaded(true); + }) + .catch(error => console.error('Error fetching settings:', error)); + }, []); + + return { settings, setSettings, isSettingsLoaded }; +}