mirror of
https://github.com/ostris/ai-toolkit.git
synced 2026-01-26 16:39:47 +00:00
Added advanced mode yaml editor to the ui
This commit is contained in:
37
ui/package-lock.json
generated
37
ui/package-lock.json
generated
@@ -9,6 +9,7 @@
|
||||
"version": "0.1.0",
|
||||
"dependencies": {
|
||||
"@headlessui/react": "^2.2.0",
|
||||
"@monaco-editor/react": "^4.7.0",
|
||||
"@prisma/client": "^6.3.1",
|
||||
"axios": "^1.7.9",
|
||||
"classnames": "^2.5.1",
|
||||
@@ -21,7 +22,8 @@
|
||||
"react-dropzone": "^14.3.5",
|
||||
"react-global-hooks": "^1.3.5",
|
||||
"react-icons": "^5.5.0",
|
||||
"sqlite3": "^5.1.7"
|
||||
"sqlite3": "^5.1.7",
|
||||
"yaml": "^2.7.0"
|
||||
},
|
||||
"devDependencies": {
|
||||
"@types/node": "^20",
|
||||
@@ -535,6 +537,27 @@
|
||||
"@jridgewell/sourcemap-codec": "^1.4.14"
|
||||
}
|
||||
},
|
||||
"node_modules/@monaco-editor/loader": {
|
||||
"version": "1.5.0",
|
||||
"resolved": "https://registry.npmjs.org/@monaco-editor/loader/-/loader-1.5.0.tgz",
|
||||
"integrity": "sha512-hKoGSM+7aAc7eRTRjpqAZucPmoNOC4UUbknb/VNoTkEIkCPhqV8LfbsgM1webRM7S/z21eHEx9Fkwx8Z/C/+Xw==",
|
||||
"dependencies": {
|
||||
"state-local": "^1.0.6"
|
||||
}
|
||||
},
|
||||
"node_modules/@monaco-editor/react": {
|
||||
"version": "4.7.0",
|
||||
"resolved": "https://registry.npmjs.org/@monaco-editor/react/-/react-4.7.0.tgz",
|
||||
"integrity": "sha512-cyzXQCtO47ydzxpQtCGSQGOC8Gk3ZUeBXFAxD+CWXYFo5OqZyZUonFl0DwUlTyAfRHntBfw2p3w4s9R6oe1eCA==",
|
||||
"dependencies": {
|
||||
"@monaco-editor/loader": "^1.5.0"
|
||||
},
|
||||
"peerDependencies": {
|
||||
"monaco-editor": ">= 0.25.0 < 1",
|
||||
"react": "^16.8.0 || ^17.0.0 || ^18.0.0 || ^19.0.0",
|
||||
"react-dom": "^16.8.0 || ^17.0.0 || ^18.0.0 || ^19.0.0"
|
||||
}
|
||||
},
|
||||
"node_modules/@next/env": {
|
||||
"version": "15.1.7",
|
||||
"resolved": "https://registry.npmjs.org/@next/env/-/env-15.1.7.tgz",
|
||||
@@ -2751,6 +2774,12 @@
|
||||
"integrity": "sha512-gKLcREMhtuZRwRAfqP3RFW+TK4JqApVBtOIftVgjuABpAtpxhPGaDcfvbhNvD0B8iD1oUr/txX35NjcaY6Ns/A==",
|
||||
"license": "MIT"
|
||||
},
|
||||
"node_modules/monaco-editor": {
|
||||
"version": "0.52.2",
|
||||
"resolved": "https://registry.npmjs.org/monaco-editor/-/monaco-editor-0.52.2.tgz",
|
||||
"integrity": "sha512-GEQWEZmfkOGLdd3XK8ryrfWz3AIP8YymVXiPHEdewrUq7mh0qrKrfHLNCXcbB6sTnMLnOZ3ztSiKcciFUkIJwQ==",
|
||||
"peer": true
|
||||
},
|
||||
"node_modules/ms": {
|
||||
"version": "2.1.3",
|
||||
"resolved": "https://registry.npmjs.org/ms/-/ms-2.1.3.tgz",
|
||||
@@ -3947,6 +3976,11 @@
|
||||
"node": ">=8"
|
||||
}
|
||||
},
|
||||
"node_modules/state-local": {
|
||||
"version": "1.0.7",
|
||||
"resolved": "https://registry.npmjs.org/state-local/-/state-local-1.0.7.tgz",
|
||||
"integrity": "sha512-HTEHMNieakEnoe33shBYcZ7NX83ACUjCu8c40iOGEZsngj9zRnkqS9j1pqQPXwobB0ZcVTk27REb7COQ0UR59w=="
|
||||
},
|
||||
"node_modules/streamsearch": {
|
||||
"version": "1.1.0",
|
||||
"resolved": "https://registry.npmjs.org/streamsearch/-/streamsearch-1.1.0.tgz",
|
||||
@@ -4504,7 +4538,6 @@
|
||||
"version": "2.7.0",
|
||||
"resolved": "https://registry.npmjs.org/yaml/-/yaml-2.7.0.tgz",
|
||||
"integrity": "sha512-+hSoy/QHluxmC9kCIJyL/uyFmLmc+e5CFR5Wa+bpIhIj85LVb9ZH2nVnqrHoSvKogwODv0ClqZkmiSSaIH5LTA==",
|
||||
"dev": true,
|
||||
"bin": {
|
||||
"yaml": "bin.mjs"
|
||||
},
|
||||
|
||||
@@ -12,6 +12,7 @@
|
||||
},
|
||||
"dependencies": {
|
||||
"@headlessui/react": "^2.2.0",
|
||||
"@monaco-editor/react": "^4.7.0",
|
||||
"@prisma/client": "^6.3.1",
|
||||
"axios": "^1.7.9",
|
||||
"classnames": "^2.5.1",
|
||||
@@ -24,7 +25,8 @@
|
||||
"react-dropzone": "^14.3.5",
|
||||
"react-global-hooks": "^1.3.5",
|
||||
"react-icons": "^5.5.0",
|
||||
"sqlite3": "^5.1.7"
|
||||
"sqlite3": "^5.1.7",
|
||||
"yaml": "^2.7.0"
|
||||
},
|
||||
"devDependencies": {
|
||||
"@types/node": "^20",
|
||||
|
||||
@@ -22,13 +22,13 @@ export const defaultJobConfig: JobConfig = {
|
||||
type: 'ui_trainer',
|
||||
training_folder: 'output',
|
||||
sqlite_db_path: './aitk_db.db',
|
||||
device: 'cuda:0',
|
||||
device: 'cuda',
|
||||
trigger_word: null,
|
||||
performance_log_every: 10,
|
||||
network: {
|
||||
type: 'lora',
|
||||
linear: 16,
|
||||
linear_alpha: 16,
|
||||
linear: 32,
|
||||
linear_alpha: 32,
|
||||
lokr_full_rank: true,
|
||||
lokr_factor: -1
|
||||
},
|
||||
|
||||
@@ -17,6 +17,8 @@ import path from 'path';
|
||||
import { TopBar, MainContent } from '@/components/layout';
|
||||
import { Button } from '@headlessui/react';
|
||||
import { FaChevronLeft } from 'react-icons/fa';
|
||||
import SimpleJob from './SimpleJob';
|
||||
import AdvancedJob from './AdvancedJob';
|
||||
|
||||
const isDev = process.env.NODE_ENV === 'development';
|
||||
|
||||
@@ -29,12 +31,11 @@ export default function TrainingForm() {
|
||||
const { gpuList, isGPUInfoLoaded } = useGPUInfo();
|
||||
const { datasets, status: datasetFetchStatus } = useDatasetList();
|
||||
const [datasetOptions, setDatasetOptions] = useState<{ value: string; label: string }[]>([]);
|
||||
const [showAdvancedView, setShowAdvancedView] = useState(false);
|
||||
|
||||
const [jobConfig, setJobConfig] = useNestedState<JobConfig>(objectCopy(defaultJobConfig));
|
||||
const [status, setStatus] = useState<'idle' | 'saving' | 'success' | 'error'>('idle');
|
||||
|
||||
const isVideoModel = isVideoModelFromArch(jobConfig.config.process[0].model.arch);
|
||||
|
||||
useEffect(() => {
|
||||
if (!isSettingsLoaded) return;
|
||||
if (datasetFetchStatus !== 'success') return;
|
||||
@@ -130,6 +131,27 @@ export default function TrainingForm() {
|
||||
<h1 className="text-lg">{runId ? 'Edit Training Job' : 'New Training Job'}</h1>
|
||||
</div>
|
||||
<div className="flex-1"></div>
|
||||
{showAdvancedView && (
|
||||
<>
|
||||
<div>
|
||||
<SelectInput
|
||||
value={`${gpuIDs}`}
|
||||
onChange={value => setGpuIDs(value)}
|
||||
options={gpuList.map((gpu: any) => ({ value: `${gpu.index}`, label: `GPU #${gpu.index}` }))}
|
||||
/>
|
||||
</div>
|
||||
<div className="mx-4 bg-gray-200 dark:bg-gray-800 w-1 h-6"></div>
|
||||
</>
|
||||
)}
|
||||
|
||||
<div className="pr-2">
|
||||
<Button
|
||||
className="text-gray-200 bg-gray-800 px-3 py-1 rounded-md"
|
||||
onClick={() => setShowAdvancedView(!showAdvancedView)}
|
||||
>
|
||||
{showAdvancedView ? 'Show Simple' : 'Show Advanced'}
|
||||
</Button>
|
||||
</div>
|
||||
<div>
|
||||
<Button
|
||||
className="text-gray-200 bg-green-800 px-3 py-1 rounded-md"
|
||||
@@ -140,601 +162,39 @@ export default function TrainingForm() {
|
||||
</Button>
|
||||
</div>
|
||||
</TopBar>
|
||||
<MainContent>
|
||||
<form onSubmit={handleSubmit} className="space-y-8">
|
||||
<div className="grid grid-cols-1 md:grid-cols-2 lg:grid-cols-4 gap-6">
|
||||
<Card title="Job Settings">
|
||||
<TextInput
|
||||
label="Training Name"
|
||||
value={jobConfig.config.name}
|
||||
onChange={value => setJobConfig(value, 'config.name')}
|
||||
placeholder="Enter training name"
|
||||
disabled={runId !== null}
|
||||
required
|
||||
/>
|
||||
<SelectInput
|
||||
label="GPU ID"
|
||||
value={`${gpuIDs}`}
|
||||
onChange={value => setGpuIDs(value)}
|
||||
options={gpuList.map(gpu => ({ value: `${gpu.index}`, label: `GPU #${gpu.index}` }))}
|
||||
/>
|
||||
<TextInput
|
||||
label="Trigger Word"
|
||||
value={jobConfig.config.process[0].trigger_word || ''}
|
||||
onChange={(value: string | null) => {
|
||||
if (value?.trim() === '') {
|
||||
value = null;
|
||||
}
|
||||
setJobConfig(value, 'config.process[0].trigger_word');
|
||||
}}
|
||||
placeholder=""
|
||||
required
|
||||
/>
|
||||
</Card>
|
||||
|
||||
{/* Model Configuration Section */}
|
||||
<Card title="Model Configuration">
|
||||
<SelectInput
|
||||
label="Name or Path"
|
||||
value={jobConfig.config.process[0].model.name_or_path}
|
||||
onChange={value => {
|
||||
// see if model changed
|
||||
const currentModel = options.model.find(
|
||||
model => model.name_or_path === jobConfig.config.process[0].model.name_or_path,
|
||||
);
|
||||
if (!currentModel || currentModel.name_or_path === value) {
|
||||
// model has not changed
|
||||
return;
|
||||
}
|
||||
// revert defaults from previous model
|
||||
for (const key in currentModel.defaults) {
|
||||
setJobConfig(currentModel.defaults[key][1], key);
|
||||
}
|
||||
// set new model
|
||||
setJobConfig(value, 'config.process[0].model.name_or_path');
|
||||
// update the defaults when a model is selected
|
||||
const model = options.model.find(model => model.name_or_path === value);
|
||||
if (model?.defaults) {
|
||||
for (const key in model.defaults) {
|
||||
setJobConfig(model.defaults[key][0], key);
|
||||
}
|
||||
}
|
||||
}}
|
||||
options={
|
||||
options.model
|
||||
.map(model => {
|
||||
if (model.dev_only && !isDev) {
|
||||
return null;
|
||||
}
|
||||
return {
|
||||
value: model.name_or_path,
|
||||
label: model.name_or_path,
|
||||
};
|
||||
})
|
||||
.filter(x => x) as { value: string; label: string }[]
|
||||
}
|
||||
/>
|
||||
<SelectInput
|
||||
label="Model Architecture"
|
||||
value={jobConfig.config.process[0].model.arch}
|
||||
onChange={value => {
|
||||
const currentArch = modelArchs.find(
|
||||
a => a.name === jobConfig.config.process[0].model.arch,
|
||||
);
|
||||
if (!currentArch || currentArch.name === value) {
|
||||
return;
|
||||
}
|
||||
// set new model
|
||||
setJobConfig(value, 'config.process[0].model.arch');
|
||||
}}
|
||||
options={
|
||||
modelArchs
|
||||
.map(model => {
|
||||
return {
|
||||
value: model.name,
|
||||
label: model.label,
|
||||
};
|
||||
})
|
||||
.filter(x => x) as { value: string; label: string }[]
|
||||
}
|
||||
/>
|
||||
<FormGroup label="Quantize">
|
||||
<div className="grid grid-cols-2 gap-2">
|
||||
<Checkbox
|
||||
label="Transformer"
|
||||
checked={jobConfig.config.process[0].model.quantize}
|
||||
onChange={value => setJobConfig(value, 'config.process[0].model.quantize')}
|
||||
/>
|
||||
<Checkbox
|
||||
label="Text Encoder"
|
||||
checked={jobConfig.config.process[0].model.quantize_te}
|
||||
onChange={value => setJobConfig(value, 'config.process[0].model.quantize_te')}
|
||||
/>
|
||||
</div>
|
||||
</FormGroup>
|
||||
</Card>
|
||||
<Card title="Target Configuration">
|
||||
<SelectInput
|
||||
label="Target Type"
|
||||
value={jobConfig.config.process[0].network?.type ?? 'lora'}
|
||||
onChange={value => setJobConfig(value, 'config.process[0].network.type')}
|
||||
options={[
|
||||
{ value: 'lora', label: 'LoRA' },
|
||||
{ value: 'lokr', label: 'LoKr' },
|
||||
]}
|
||||
/>
|
||||
{jobConfig.config.process[0].network?.type == 'lokr' && (
|
||||
<SelectInput
|
||||
label="LoKr Factor"
|
||||
value={ `${jobConfig.config.process[0].network?.lokr_factor ?? -1}`}
|
||||
onChange={value => setJobConfig(parseInt(value), 'config.process[0].network.lokr_factor')}
|
||||
options={[
|
||||
{ value: '-1', label: 'Auto' },
|
||||
{ value: '4', label: '4' },
|
||||
{ value: '8', label: '8' },
|
||||
{ value: '16', label: '16' },
|
||||
{ value: '32', label: '32' },
|
||||
]}
|
||||
/>
|
||||
)}
|
||||
{jobConfig.config.process[0].network?.type == 'lora' && (
|
||||
<NumberInput
|
||||
label="Linear Rank"
|
||||
value={jobConfig.config.process[0].network.linear}
|
||||
onChange={value => {
|
||||
console.log('onChange', value);
|
||||
setJobConfig(value, 'config.process[0].network.linear');
|
||||
setJobConfig(value, 'config.process[0].network.linear_alpha');
|
||||
}}
|
||||
placeholder="eg. 16"
|
||||
min={0}
|
||||
max={1024}
|
||||
required
|
||||
/>
|
||||
)}
|
||||
</Card>
|
||||
<Card title="Save Configuration">
|
||||
<SelectInput
|
||||
label="Data Type"
|
||||
value={jobConfig.config.process[0].save.dtype}
|
||||
onChange={value => setJobConfig(value, 'config.process[0].save.dtype')}
|
||||
options={[
|
||||
{ value: 'bf16', label: 'BF16' },
|
||||
{ value: 'fp16', label: 'FP16' },
|
||||
{ value: 'fp32', label: 'FP32' },
|
||||
]}
|
||||
/>
|
||||
<NumberInput
|
||||
label="Save Every"
|
||||
value={jobConfig.config.process[0].save.save_every}
|
||||
onChange={value => setJobConfig(value, 'config.process[0].save.save_every')}
|
||||
placeholder="eg. 250"
|
||||
min={1}
|
||||
required
|
||||
/>
|
||||
<NumberInput
|
||||
label="Max Step Saves to Keep"
|
||||
value={jobConfig.config.process[0].save.max_step_saves_to_keep}
|
||||
onChange={value => setJobConfig(value, 'config.process[0].save.max_step_saves_to_keep')}
|
||||
placeholder="eg. 4"
|
||||
min={1}
|
||||
required
|
||||
/>
|
||||
</Card>
|
||||
</div>
|
||||
<div>
|
||||
<Card title="Training Configuration">
|
||||
<div className="grid grid-cols-1 md:grid-cols-3 lg:grid-cols-5 gap-6">
|
||||
<div>
|
||||
<NumberInput
|
||||
label="Batch Size"
|
||||
value={jobConfig.config.process[0].train.batch_size}
|
||||
onChange={value => setJobConfig(value, 'config.process[0].train.batch_size')}
|
||||
placeholder="eg. 4"
|
||||
min={1}
|
||||
required
|
||||
/>
|
||||
<NumberInput
|
||||
label="Gradient Accumulation"
|
||||
className="pt-2"
|
||||
value={jobConfig.config.process[0].train.gradient_accumulation}
|
||||
onChange={value => setJobConfig(value, 'config.process[0].train.gradient_accumulation')}
|
||||
placeholder="eg. 1"
|
||||
min={1}
|
||||
required
|
||||
/>
|
||||
<NumberInput
|
||||
label="Steps"
|
||||
className="pt-2"
|
||||
value={jobConfig.config.process[0].train.steps}
|
||||
onChange={value => setJobConfig(value, 'config.process[0].train.steps')}
|
||||
placeholder="eg. 2000"
|
||||
min={1}
|
||||
required
|
||||
/>
|
||||
</div>
|
||||
<div>
|
||||
<SelectInput
|
||||
label="Optimizer"
|
||||
value={jobConfig.config.process[0].train.optimizer}
|
||||
onChange={value => setJobConfig(value, 'config.process[0].train.optimizer')}
|
||||
options={[
|
||||
{ value: 'adamw8bit', label: 'AdamW8Bit' },
|
||||
{ value: 'adafactor', label: 'Adafactor' },
|
||||
]}
|
||||
/>
|
||||
<NumberInput
|
||||
label="Learning Rate"
|
||||
className="pt-2"
|
||||
value={jobConfig.config.process[0].train.lr}
|
||||
onChange={value => setJobConfig(value, 'config.process[0].train.lr')}
|
||||
placeholder="eg. 0.0001"
|
||||
min={0}
|
||||
required
|
||||
/>
|
||||
<NumberInput
|
||||
label="Weight Decay"
|
||||
className="pt-2"
|
||||
value={jobConfig.config.process[0].train.optimizer_params.weight_decay}
|
||||
onChange={value => setJobConfig(value, 'config.process[0].train.optimizer_params.weight_decay')}
|
||||
placeholder="eg. 0.0001"
|
||||
min={0}
|
||||
required
|
||||
/>
|
||||
</div>
|
||||
<div>
|
||||
<SelectInput
|
||||
label="Timestep Type"
|
||||
value={jobConfig.config.process[0].train.timestep_type}
|
||||
onChange={value => setJobConfig(value, 'config.process[0].train.timestep_type')}
|
||||
options={[
|
||||
{ value: 'sigmoid', label: 'Sigmoid' },
|
||||
{ value: 'linear', label: 'Linear' },
|
||||
{ value: 'flux_shift', label: 'Flux Shift' },
|
||||
]}
|
||||
/>
|
||||
<SelectInput
|
||||
label="Timestep Bias"
|
||||
className="pt-2"
|
||||
value={jobConfig.config.process[0].train.content_or_style}
|
||||
onChange={value => setJobConfig(value, 'config.process[0].train.content_or_style')}
|
||||
options={[
|
||||
{ value: 'balanced', label: 'Balanced' },
|
||||
{ value: 'content', label: 'High Noise' },
|
||||
{ value: 'style', label: 'Low Noise' },
|
||||
]}
|
||||
/>
|
||||
<SelectInput
|
||||
label="Noise Scheduler"
|
||||
className="pt-2"
|
||||
value={jobConfig.config.process[0].train.noise_scheduler}
|
||||
onChange={value => setJobConfig(value, 'config.process[0].train.noise_scheduler')}
|
||||
options={[
|
||||
{ value: 'flowmatch', label: 'FlowMatch' },
|
||||
{ value: 'ddpm', label: 'DDPM' },
|
||||
]}
|
||||
/>
|
||||
</div>
|
||||
<div>
|
||||
<FormGroup label="EMA (Exponential Moving Average)">
|
||||
<Checkbox
|
||||
label="Use EMA"
|
||||
className="pt-1"
|
||||
checked={jobConfig.config.process[0].train.ema_config?.use_ema || false}
|
||||
onChange={value => setJobConfig(value, 'config.process[0].train.ema_config.use_ema')}
|
||||
/>
|
||||
</FormGroup>
|
||||
<NumberInput
|
||||
label="EMA Decay"
|
||||
className="pt-2"
|
||||
value={jobConfig.config.process[0].train.ema_config?.ema_decay as number}
|
||||
onChange={value => setJobConfig(value, 'config.process[0].train.ema_config?.ema_decay')}
|
||||
placeholder="eg. 0.99"
|
||||
min={0}
|
||||
/>
|
||||
<FormGroup label="Unload Text Encoder" className="pt-2">
|
||||
<div className="grid grid-cols-2 gap-2">
|
||||
<Checkbox
|
||||
label="Unload TE"
|
||||
checked={jobConfig.config.process[0].train.unload_text_encoder || false}
|
||||
onChange={value => setJobConfig(value, 'config.process[0].train.unload_text_encoder')}
|
||||
/>
|
||||
</div>
|
||||
</FormGroup>
|
||||
</div>
|
||||
<div>
|
||||
<FormGroup label="Regularization">
|
||||
<Checkbox
|
||||
label="Differtial Output Preservation"
|
||||
className="pt-1"
|
||||
checked={jobConfig.config.process[0].train.diff_output_preservation || false}
|
||||
onChange={value => setJobConfig(value, 'config.process[0].train.diff_output_preservation')}
|
||||
/>
|
||||
</FormGroup>
|
||||
<NumberInput
|
||||
label="DFE Loss Multiplier"
|
||||
className="pt-2"
|
||||
value={jobConfig.config.process[0].train.diff_output_preservation_multiplier as number}
|
||||
onChange={value =>
|
||||
setJobConfig(value, 'config.process[0].train.diff_output_preservation_multiplier')
|
||||
}
|
||||
placeholder="eg. 1.0"
|
||||
min={0}
|
||||
/>
|
||||
<TextInput
|
||||
label="DFE Preservation Class"
|
||||
className="pt-2"
|
||||
value={jobConfig.config.process[0].train.diff_output_preservation_class as string}
|
||||
onChange={value => setJobConfig(value, 'config.process[0].train.diff_output_preservation_class')}
|
||||
placeholder="eg. woman"
|
||||
/>
|
||||
</div>
|
||||
</div>
|
||||
</Card>
|
||||
</div>
|
||||
<div>
|
||||
<Card title="Datasets">
|
||||
<>
|
||||
{jobConfig.config.process[0].datasets.map((dataset, i) => (
|
||||
<div key={i} className="p-4 rounded-lg bg-gray-800 relative">
|
||||
<button
|
||||
type="button"
|
||||
onClick={() =>
|
||||
setJobConfig(
|
||||
jobConfig.config.process[0].datasets.filter((_, index) => index !== i),
|
||||
'config.process[0].datasets',
|
||||
)
|
||||
}
|
||||
className="absolute top-2 right-2 bg-red-800 hover:bg-red-700 rounded-full p-1 text-sm transition-colors"
|
||||
>
|
||||
<X />
|
||||
</button>
|
||||
<h2 className="text-lg font-bold mb-4">Dataset {i + 1}</h2>
|
||||
<div className="grid grid-cols-1 md:grid-cols-2 lg:grid-cols-4 gap-6">
|
||||
<div>
|
||||
<SelectInput
|
||||
label="Dataset"
|
||||
value={dataset.folder_path}
|
||||
onChange={value => setJobConfig(value, `config.process[0].datasets[${i}].folder_path`)}
|
||||
options={datasetOptions}
|
||||
/>
|
||||
<NumberInput
|
||||
label="LoRA Weight"
|
||||
value={dataset.network_weight}
|
||||
className="pt-2"
|
||||
onChange={value => setJobConfig(value, `config.process[0].datasets[${i}].network_weight`)}
|
||||
placeholder="eg. 1.0"
|
||||
/>
|
||||
</div>
|
||||
<div>
|
||||
<TextInput
|
||||
label="Default Caption"
|
||||
value={dataset.default_caption}
|
||||
onChange={value => setJobConfig(value, `config.process[0].datasets[${i}].default_caption`)}
|
||||
placeholder="eg. A photo of a cat"
|
||||
/>
|
||||
<NumberInput
|
||||
label="Caption Dropout Rate"
|
||||
className="pt-2"
|
||||
value={dataset.caption_dropout_rate}
|
||||
onChange={value =>
|
||||
setJobConfig(value, `config.process[0].datasets[${i}].caption_dropout_rate`)
|
||||
}
|
||||
placeholder="eg. 0.05"
|
||||
min={0}
|
||||
required
|
||||
/>
|
||||
</div>
|
||||
<div>
|
||||
<FormGroup label="Settings" className="">
|
||||
<Checkbox
|
||||
label="Cache Latents"
|
||||
checked={dataset.cache_latents_to_disk || false}
|
||||
onChange={value =>
|
||||
setJobConfig(value, `config.process[0].datasets[${i}].cache_latents_to_disk`)
|
||||
}
|
||||
/>
|
||||
<Checkbox
|
||||
label="Is Regularization"
|
||||
checked={dataset.is_reg || false}
|
||||
onChange={value => setJobConfig(value, `config.process[0].datasets[${i}].is_reg`)}
|
||||
/>
|
||||
</FormGroup>
|
||||
</div>
|
||||
<div>
|
||||
<FormGroup label="Resolutions" className="pt-2">
|
||||
<div className="grid grid-cols-2 gap-2">
|
||||
{[
|
||||
[256, 512, 768],
|
||||
[1024, 1280, 1536],
|
||||
].map(resGroup => (
|
||||
<div key={resGroup[0]} className="space-y-2">
|
||||
{resGroup.map(res => (
|
||||
<Checkbox
|
||||
key={res}
|
||||
label={res.toString()}
|
||||
checked={dataset.resolution.includes(res)}
|
||||
onChange={value => {
|
||||
const resolutions = dataset.resolution.includes(res)
|
||||
? dataset.resolution.filter(r => r !== res)
|
||||
: [...dataset.resolution, res];
|
||||
setJobConfig(resolutions, `config.process[0].datasets[${i}].resolution`);
|
||||
}}
|
||||
/>
|
||||
))}
|
||||
</div>
|
||||
))}
|
||||
</div>
|
||||
</FormGroup>
|
||||
</div>
|
||||
</div>
|
||||
</div>
|
||||
))}
|
||||
<button
|
||||
type="button"
|
||||
onClick={() =>
|
||||
setJobConfig(
|
||||
[...jobConfig.config.process[0].datasets, objectCopy(defaultDatasetConfig)],
|
||||
'config.process[0].datasets',
|
||||
)
|
||||
}
|
||||
className="w-full px-4 py-2 bg-gray-700 hover:bg-gray-600 rounded-lg transition-colors"
|
||||
>
|
||||
Add Dataset
|
||||
</button>
|
||||
</>
|
||||
</Card>
|
||||
</div>
|
||||
<div>
|
||||
<Card title="Sample Configuration">
|
||||
<div className={isVideoModel ? "grid grid-cols-1 md:grid-cols-3 lg:grid-cols-5 gap-6" : "grid grid-cols-1 md:grid-cols-2 lg:grid-cols-4 gap-6"}>
|
||||
<div>
|
||||
<NumberInput
|
||||
label="Sample Every"
|
||||
value={jobConfig.config.process[0].sample.sample_every}
|
||||
onChange={value => setJobConfig(value, 'config.process[0].sample.sample_every')}
|
||||
placeholder="eg. 250"
|
||||
min={1}
|
||||
required
|
||||
/>
|
||||
<SelectInput
|
||||
label="Sampler"
|
||||
className="pt-2"
|
||||
value={jobConfig.config.process[0].sample.sampler}
|
||||
onChange={value => setJobConfig(value, 'config.process[0].sample.sampler')}
|
||||
options={[
|
||||
{ value: 'flowmatch', label: 'FlowMatch' },
|
||||
{ value: 'ddpm', label: 'DDPM' },
|
||||
]}
|
||||
/>
|
||||
</div>
|
||||
<div>
|
||||
<NumberInput
|
||||
label="Guidance Scale"
|
||||
value={jobConfig.config.process[0].sample.guidance_scale}
|
||||
onChange={value => setJobConfig(value, 'config.process[0].sample.guidance_scale')}
|
||||
placeholder="eg. 1.0"
|
||||
min={0}
|
||||
required
|
||||
/>
|
||||
<NumberInput
|
||||
label="Sample Steps"
|
||||
value={jobConfig.config.process[0].sample.sample_steps}
|
||||
onChange={value => setJobConfig(value, 'config.process[0].sample.sample_steps')}
|
||||
placeholder="eg. 1"
|
||||
className="pt-2"
|
||||
min={1}
|
||||
required
|
||||
/>
|
||||
</div>
|
||||
<div>
|
||||
<NumberInput
|
||||
label="Width"
|
||||
value={jobConfig.config.process[0].sample.width}
|
||||
onChange={value => setJobConfig(value, 'config.process[0].sample.width')}
|
||||
placeholder="eg. 1024"
|
||||
min={0}
|
||||
required
|
||||
/>
|
||||
<NumberInput
|
||||
label="Height"
|
||||
value={jobConfig.config.process[0].sample.height}
|
||||
onChange={value => setJobConfig(value, 'config.process[0].sample.height')}
|
||||
placeholder="eg. 1024"
|
||||
className="pt-2"
|
||||
min={0}
|
||||
required
|
||||
/>
|
||||
</div>
|
||||
{showAdvancedView ? (
|
||||
<div className="pt-[48px] absolute top-0 left-0 w-full h-full overflow-auto">
|
||||
<AdvancedJob
|
||||
jobConfig={jobConfig}
|
||||
setJobConfig={setJobConfig}
|
||||
status={status}
|
||||
handleSubmit={handleSubmit}
|
||||
runId={runId}
|
||||
gpuIDs={gpuIDs}
|
||||
setGpuIDs={setGpuIDs}
|
||||
gpuList={gpuList}
|
||||
datasetOptions={datasetOptions}
|
||||
settings={settings}
|
||||
/>
|
||||
</div>
|
||||
) : (
|
||||
<MainContent>
|
||||
<SimpleJob
|
||||
jobConfig={jobConfig}
|
||||
setJobConfig={setJobConfig}
|
||||
status={status}
|
||||
handleSubmit={handleSubmit}
|
||||
runId={runId}
|
||||
gpuIDs={gpuIDs}
|
||||
setGpuIDs={setGpuIDs}
|
||||
gpuList={gpuList}
|
||||
datasetOptions={datasetOptions}
|
||||
/>
|
||||
|
||||
<div>
|
||||
<NumberInput
|
||||
label="Seed"
|
||||
value={jobConfig.config.process[0].sample.seed}
|
||||
onChange={value => setJobConfig(value, 'config.process[0].sample.seed')}
|
||||
placeholder="eg. 0"
|
||||
min={0}
|
||||
required
|
||||
/>
|
||||
<Checkbox
|
||||
label="Walk Seed"
|
||||
className="pt-4 pl-2"
|
||||
checked={jobConfig.config.process[0].sample.walk_seed}
|
||||
onChange={value => setJobConfig(value, 'config.process[0].sample.walk_seed')}
|
||||
/>
|
||||
</div>
|
||||
{ isVideoModel && (
|
||||
<div>
|
||||
<NumberInput
|
||||
label="Num Frames"
|
||||
value={jobConfig.config.process[0].sample.num_frames}
|
||||
onChange={value => setJobConfig(value, 'config.process[0].sample.num_frames')}
|
||||
placeholder="eg. 0"
|
||||
min={0}
|
||||
required
|
||||
/>
|
||||
<NumberInput
|
||||
label="FPS"
|
||||
value={jobConfig.config.process[0].sample.fps}
|
||||
onChange={value => setJobConfig(value, 'config.process[0].sample.fps')}
|
||||
placeholder="eg. 0"
|
||||
min={0}
|
||||
required
|
||||
/>
|
||||
</div>
|
||||
)}
|
||||
</div>
|
||||
<FormGroup
|
||||
label={`Sample Prompts (${jobConfig.config.process[0].sample.prompts.length})`}
|
||||
className="pt-2"
|
||||
>
|
||||
{jobConfig.config.process[0].sample.prompts.map((prompt, i) => (
|
||||
<div key={i} className="flex items-center space-x-2">
|
||||
<div className="flex-1">
|
||||
<TextInput
|
||||
value={prompt}
|
||||
onChange={value => setJobConfig(value, `config.process[0].sample.prompts[${i}]`)}
|
||||
placeholder="Enter prompt"
|
||||
required
|
||||
/>
|
||||
</div>
|
||||
<div>
|
||||
<button
|
||||
type="button"
|
||||
onClick={() =>
|
||||
setJobConfig(
|
||||
jobConfig.config.process[0].sample.prompts.filter((_, index) => index !== i),
|
||||
'config.process[0].sample.prompts',
|
||||
)
|
||||
}
|
||||
className="rounded-full p-1 text-sm"
|
||||
>
|
||||
<X />
|
||||
</button>
|
||||
</div>
|
||||
</div>
|
||||
))}
|
||||
<button
|
||||
type="button"
|
||||
onClick={() =>
|
||||
setJobConfig(
|
||||
[...jobConfig.config.process[0].sample.prompts, ''],
|
||||
'config.process[0].sample.prompts',
|
||||
)
|
||||
}
|
||||
className="w-full px-4 py-2 bg-gray-700 hover:bg-gray-600 rounded-lg transition-colors"
|
||||
>
|
||||
Add Prompt
|
||||
</button>
|
||||
</FormGroup>
|
||||
</Card>
|
||||
</div>
|
||||
|
||||
{status === 'success' && <p className="text-green-500 text-center">Training saved successfully!</p>}
|
||||
{status === 'error' && <p className="text-red-500 text-center">Error saving training. Please try again.</p>}
|
||||
</form>
|
||||
<div className="pt-20"></div>
|
||||
</MainContent>
|
||||
<div className="pt-20"></div>
|
||||
</MainContent>
|
||||
)}
|
||||
</>
|
||||
);
|
||||
}
|
||||
|
||||
@@ -2,6 +2,12 @@
|
||||
|
||||
import { useEffect, useState } from 'react';
|
||||
|
||||
export interface Settings {
|
||||
HF_TOKEN: string;
|
||||
TRAINING_FOLDER: string;
|
||||
DATASETS_FOLDER: string;
|
||||
}
|
||||
|
||||
export default function useSettings() {
|
||||
const [settings, setSettings] = useState({
|
||||
HF_TOKEN: '',
|
||||
|
||||
Reference in New Issue
Block a user