Config ui section is coming along

This commit is contained in:
Jaret Burkett
2025-02-19 07:52:24 -07:00
parent b13fcc1039
commit cef7d9e594
17 changed files with 1323 additions and 120 deletions

6
ui/package-lock.json generated
View File

@@ -9,6 +9,7 @@
"version": "0.1.0",
"dependencies": {
"@prisma/client": "^6.3.1",
"classnames": "^2.5.1",
"lucide-react": "^0.475.0",
"next": "15.1.7",
"prisma": "^6.3.1",
@@ -1182,6 +1183,11 @@
"node": ">=10"
}
},
"node_modules/classnames": {
"version": "2.5.1",
"resolved": "https://registry.npmjs.org/classnames/-/classnames-2.5.1.tgz",
"integrity": "sha512-saHYOzhIQs6wy2sVxTM6bUDsQO4F50V9RQ22qBpEdCW+I+/Wmke2HOl6lS6dTpdxVhb88/I6+Hs+438c3lfUow=="
},
"node_modules/clean-stack": {
"version": "2.2.0",
"resolved": "https://registry.npmjs.org/clean-stack/-/clean-stack-2.2.0.tgz",

View File

@@ -11,6 +11,7 @@
},
"dependencies": {
"@prisma/client": "^6.3.1",
"classnames": "^2.5.1",
"lucide-react": "^0.475.0",
"next": "15.1.7",
"prisma": "^6.3.1",

View File

@@ -17,7 +17,8 @@ model Settings {
model Training {
id String @id @default(uuid())
name String
run_data String // JSON string
gpu_id Int
job_config String // JSON string
created_at DateTime @default(now())
updated_at DateTime @updatedAt
}

102
ui/src/app/api/gpu/route.ts Normal file
View File

@@ -0,0 +1,102 @@
import { NextResponse } from 'next/server';
import { exec } from 'child_process';
import { promisify } from 'util';
const execAsync = promisify(exec);
export async function GET() {
try {
// Check if nvidia-smi is available
const hasNvidiaSmi = await checkNvidiaSmi();
if (!hasNvidiaSmi) {
return NextResponse.json({
hasNvidiaSmi: false,
gpus: [],
error: 'nvidia-smi not found or not accessible',
});
}
// Get GPU stats
const gpuStats = await getGpuStats();
return NextResponse.json({
hasNvidiaSmi: true,
gpus: gpuStats,
});
} catch (error) {
console.error('Error fetching NVIDIA GPU stats:', error);
return NextResponse.json(
{
hasNvidiaSmi: false,
gpus: [],
error: `Failed to fetch GPU stats: ${error instanceof Error ? error.message : String(error)}`,
},
{ status: 500 },
);
}
}
async function checkNvidiaSmi(): Promise<boolean> {
try {
await execAsync('which nvidia-smi');
return true;
} catch (error) {
return false;
}
}
async function getGpuStats() {
// Get detailed GPU information in JSON format
const { stdout } = await execAsync(
'nvidia-smi --query-gpu=index,name,driver_version,temperature.gpu,utilization.gpu,utilization.memory,memory.total,memory.free,memory.used,power.draw,power.limit,clocks.current.graphics,clocks.current.memory --format=csv,noheader,nounits',
);
// Parse CSV output
const gpus = stdout
.trim()
.split('\n')
.map(line => {
const [
index,
name,
driverVersion,
temperature,
gpuUtil,
memoryUtil,
memoryTotal,
memoryFree,
memoryUsed,
powerDraw,
powerLimit,
clockGraphics,
clockMemory,
] = line.split(', ').map(item => item.trim());
return {
index: parseInt(index),
name,
driverVersion,
temperature: parseInt(temperature),
utilization: {
gpu: parseInt(gpuUtil),
memory: parseInt(memoryUtil),
},
memory: {
total: parseInt(memoryTotal),
free: parseInt(memoryFree),
used: parseInt(memoryUsed),
},
power: {
draw: parseFloat(powerDraw),
limit: parseFloat(powerLimit),
},
clocks: {
graphics: parseInt(clockGraphics),
memory: parseInt(clockMemory),
},
};
});
return gpus;
}

View File

@@ -1,12 +1,21 @@
import { NextResponse } from 'next/server';
import { PrismaClient } from '@prisma/client';
import { defaultTrainFolder } from '@/paths';
const prisma = new PrismaClient();
export async function GET() {
try {
const settings = await prisma.settings.findMany();
return NextResponse.json(settings.reduce((acc, curr) => ({ ...acc, [curr.key]: curr.value }), {}));
const settingsObject = settings.reduce((acc: any, setting) => {
acc[setting.key] = setting.value;
return acc;
}, {});
// if TRAINING_FOLDER is not set, use default
if (!settingsObject.TRAINING_FOLDER || settingsObject.TRAINING_FOLDER === '') {
settingsObject.TRAINING_FOLDER = defaultTrainFolder;
}
return NextResponse.json(settingsObject);
} catch (error) {
return NextResponse.json({ error: 'Failed to fetch settings' }, { status: 500 });
}

View File

@@ -27,7 +27,7 @@ export async function GET(request: Request) {
export async function POST(request: Request) {
try {
const body = await request.json();
const { id, name, run_data } = body;
const { id, name, job_config, gpu_id } = body;
if (id) {
// Update existing training
@@ -35,7 +35,8 @@ export async function POST(request: Request) {
where: { id },
data: {
name,
run_data: JSON.stringify(run_data),
gpu_id,
job_config: JSON.stringify(job_config),
},
});
return NextResponse.json(training);
@@ -44,7 +45,8 @@ export async function POST(request: Request) {
const training = await prisma.training.create({
data: {
name,
run_data: JSON.stringify(run_data),
gpu_id,
job_config: JSON.stringify(job_config),
},
});
return NextResponse.json(training);

View File

@@ -1,15 +1,12 @@
'use client';
import GpuMonitor from '@/components/GPUMonitor';
export default function Dashboard() {
return (
<div className="space-y-6">
<h1 className="text-3xl font-bold">Dashboard</h1>
<div className="grid grid-cols-1 md:grid-cols-2 lg:grid-cols-3 gap-6">
{[1, 2, 3].map(i => (
<div key={i} className="p-6 bg-gray-800 rounded-lg">
<h2 className="text-xl font-semibold mb-2">Card {i}</h2>
<p className="text-gray-400">Example dashboard card content</p>
</div>
))}
</div>
<GpuMonitor />
</div>
);
}

View File

@@ -0,0 +1,95 @@
import { JobConfig, DatasetConfig } from '@/types';
export const defaultDatasetConfig: DatasetConfig = {
folder_path: '/path/to/images/folder',
mask_path: null,
mask_min_value: 0.1,
default_caption: '',
caption_ext: 'txt',
caption_dropout_rate: 0.05,
cache_latents_to_disk: false,
is_reg: false,
network_weight: 1,
resolution: [512, 768, 1024],
};
export const defaultJobConfig: JobConfig = {
job: 'extension',
config: {
name: 'my_first_flex_lora_v1',
process: [
{
type: 'sd_trainer',
training_folder: 'output',
device: 'cuda:0',
network: {
type: 'lora',
linear: 16,
linear_alpha: 16,
},
save: {
dtype: 'bf16',
save_every: 250,
max_step_saves_to_keep: 4,
push_to_hub: false,
},
datasets: [
defaultDatasetConfig
],
train: {
batch_size: 1,
bypass_guidance_embedding: true,
steps: 2000,
gradient_accumulation: 1,
train_unet: true,
train_text_encoder: false,
gradient_checkpointing: true,
noise_scheduler: 'flowmatch',
optimizer: 'adamw8bit',
optimizer_params: {
weight_decay: 1e-4
},
lr: 0.0001,
ema_config: {
use_ema: true,
ema_decay: 0.99,
},
dtype: 'bf16',
},
model: {
name_or_path: 'ostris/Flex.1-alpha',
is_flux: true,
quantize: true,
quantize_te: true
},
sample: {
sampler: 'flowmatch',
sample_every: 250,
width: 1024,
height: 1024,
prompts: [
'woman with red hair, playing chess at the park, bomb going off in the background',
'a woman holding a coffee cup, in a beanie, sitting at a cafe',
'a horse is a DJ at a night club, fish eye lens, smoke machine, lazer lights, holding a martini',
'a man showing off his cool new t shirt at the beach, a shark is jumping out of the water in the background',
'a bear building a log cabin in the snow covered mountains',
'woman playing the guitar, on stage, singing a song, laser lights, punk rocker',
'hipster man with a beard, building a chair, in a wood shop',
'photo of a man, white background, medium shot, modeling clothing, studio lighting, white backdrop',
"a man holding a sign that says, 'this is a sign'",
'a bulldog, in a post apocalyptic world, with a shotgun, in a leather jacket, in a desert, with a motorcycle',
],
neg: '',
seed: 42,
walk_seed: true,
guidance_scale: 4,
sample_steps: 25,
},
},
],
},
meta: {
name: '[name]',
version: '1.0',
},
};

View File

@@ -1,37 +1,41 @@
export interface Model {
name_or_path: string;
model_kwargs?: Record<string, boolean>;
train_kwargs?: Record<string, boolean>;
name_or_path: string;
defaults?: { [key: string]: any };
}
export interface Option {
model: Model[];
model: Model[];
}
export const options = {
model: [
{
name_or_path: "ostris/Flex.1-alpha",
model_kwargs: {
"is_flux": true
},
train_kwargs: {
"bypass_guidance_embedding": true
}
},
{
name_or_path: "black-forest-labs/FLUX.1-dev",
model_kwargs: {
"is_flux": true
},
},
{
name_or_path: "Alpha-VLLM/Lumina-Image-2.0",
model_kwargs: {
"is_lumina2": true
},
},
]
} as Option;
model: [
{
name_or_path: 'ostris/Flex.1-alpha',
defaults: {
// default updates when [selected, unselected] in the UI
'config.process[0].model.quantize': [true, false],
'config.process[0].model.quantize_te': [true, false],
'config.process[0].model.is_flux': [true, false],
'config.process[0].train.bypass_guidance_embedding': [true, false],
},
},
{
name_or_path: 'black-forest-labs/FLUX.1-dev',
defaults: {
// default updates when [selected, unselected] in the UI
'config.process[0].model.quantize': [true, false],
'config.process[0].model.quantize_te': [true, false],
'config.process[0].model.is_flux': [true, false],
},
},
{
name_or_path: 'Alpha-VLLM/Lumina-Image-2.0',
defaults: {
// default updates when [selected, unselected] in the UI
'config.process[0].model.quantize': [false, false],
'config.process[0].model.quantize_te': [true, false],
'config.process[0].model.is_lumina2': [true, false],
},
},
],
} as Option;

View File

@@ -1,34 +1,26 @@
'use client';
// todo update training folder from settings
import { useEffect, useState } from 'react';
import { useEffect, useMemo, useState } from 'react';
import { useSearchParams, useRouter } from 'next/navigation';
import { options } from './options';
interface TrainingData {
modelConfig: {
name_or_path: string;
steps: number;
batchSize: number;
learningRate: number;
};
}
const defaultTrainingData: TrainingData = {
modelConfig: {
name_or_path: 'ostris/Flex.1-alpha',
steps: 100,
batchSize: 32,
learningRate: 0.001,
},
};
import { GPUApiResponse } from '@/types';
import { defaultJobConfig, defaultDatasetConfig } from './jobConfig';
import { JobConfig } from '@/types';
import { objectCopy } from '@/utils/basic';
import { useNestedState } from '@/utils/hooks';
import { TextInput, SelectInput, Checkbox, FormGroup, NumberInput } from '@/components/formInputs';
import Card from '@/components/Card';
import { Trash, X } from 'lucide-react';
export default function TrainingForm() {
const router = useRouter();
const searchParams = useSearchParams();
const runId = searchParams.get('id');
const [gpuID, setGpuID] = useState<number | null>(null);
const [gpuList, setGpuList] = useState<number[]>([]);
const [name, setName] = useState('');
const [trainingData, setTrainingData] = useState<TrainingData>(defaultTrainingData);
const [jobConfig, setJobConfig] = useNestedState<JobConfig>(objectCopy(defaultJobConfig));
const [status, setStatus] = useState<'idle' | 'saving' | 'success' | 'error'>('idle');
useEffect(() => {
@@ -36,13 +28,42 @@ export default function TrainingForm() {
fetch(`/api/training?id=${runId}`)
.then(res => res.json())
.then(data => {
setName(data.name);
setTrainingData(JSON.parse(data.run_data));
setGpuID(data.gpu_id);
setJobConfig(JSON.parse(data.job_config));
})
.catch(error => console.error('Error fetching training:', error));
}
}, [runId]);
useEffect(() => {
const fetchGpuInfo = async () => {
try {
const response = await fetch('/api/gpu');
if (!response.ok) {
throw new Error(`HTTP error! Status: ${response.status}`);
}
const data: GPUApiResponse = await response.json();
setGpuList(data.gpus.map(gpu => gpu.index).sort());
} catch (err) {
console.log(`Failed to fetch GPU data: ${err instanceof Error ? err.message : String(err)}`);
} finally {
// setLoading(false);
}
};
fetch('/api/settings')
.then(res => res.json())
.then(data => {
setJobConfig(data.TRAINING_FOLDER, 'config.process[0].training_folder');
})
.catch(error => console.error('Error fetching settings:', error));
// Fetch immediately on component mount
fetchGpuInfo();
}, []);
const handleSubmit = async (e: React.FormEvent) => {
e.preventDefault();
setStatus('saving');
@@ -55,8 +76,9 @@ export default function TrainingForm() {
},
body: JSON.stringify({
id: runId,
name,
run_data: trainingData,
name: jobConfig.config.name,
gpu_id: gpuID,
job_config: jobConfig,
}),
});
@@ -75,64 +97,428 @@ export default function TrainingForm() {
}
};
const updateSection = (section: keyof TrainingData, data: any) => {
setTrainingData(prev => ({
...prev,
[section]: { ...prev[section], ...data },
}));
};
const modelOptions = options.model.map(model => model.name_or_path);
return (
<div className="max-w-4xl mx-auto space-y-8 pb-12">
<h1 className="text-3xl font-bold mb-8">{runId ? 'Edit Training Run' : 'New Training Run'}</h1>
<div className="space-y-6">
<h1 className="text-xl font-bold mb-8">{runId ? 'Edit Training Run' : 'New Training Run'}</h1>
<form onSubmit={handleSubmit} className="space-y-8">
<div className="space-y-4">
<label htmlFor="name" className="block text-sm font-medium mb-2">
Training Name
</label>
<input
type="text"
id="name"
value={name}
onChange={e => setName(e.target.value)}
className="w-full px-4 py-2 bg-gray-800 border border-gray-700 rounded-lg focus:ring-2 focus:ring-gray-600 focus:border-transparent"
placeholder="Enter training name"
required
/>
</div>
<div className="grid grid-cols-1 md:grid-cols-2 lg:grid-cols-4 gap-6">
<Card title="Job Settings">
<TextInput
label="Training Name"
value={jobConfig.config.name}
onChange={value => setJobConfig(value, 'config.name')}
placeholder="Enter training name"
required
/>
<SelectInput
label="GPU ID"
value={`${gpuID}`}
className="pt-2"
onChange={value => setGpuID(parseInt(value))}
options={gpuList.map(gpu => ({ value: `${gpu}`, label: `GPU #${gpu}` }))}
/>
</Card>
{/* Model Configuration Section */}
<section className="space-y-4 p-6 bg-gray-900 rounded-lg">
<h2 className="text-xl font-bold mb-4">Model Configuration</h2>
<div className="grid grid-cols-2 gap-4">
<div>
<label className="block text-sm font-medium mb-2">Model</label>
<select
value={trainingData.modelConfig.name_or_path}
onChange={e => updateSection('modelConfig', { name_or_path: e.target.value })}
className="w-full px-4 py-2 bg-gray-800 border border-gray-700 rounded-lg"
>
{modelOptions.map(model => (
<option key={model} value={model}>
{model}
</option>
))}
</select>
</div>
<div>
<label className="block text-sm font-medium mb-2">Epochs</label>
<input
type="number"
value={trainingData.modelConfig.steps}
onChange={e => updateSection('modelConfig', { steps: Number(e.target.value) })}
className="w-full px-4 py-2 bg-gray-800 border border-gray-700 rounded-lg"
{/* Model Configuration Section */}
<Card title="Model Configuration">
<SelectInput
label="Name or Path"
value={jobConfig.config.process[0].model.name_or_path}
onChange={value => {
// see if model changed
const currentModel = options.model.find(
model => model.name_or_path === jobConfig.config.process[0].model.name_or_path,
);
if (!currentModel || currentModel.name_or_path === value) {
// model has not changed
return;
}
// revert defaults from previous model
for (const key in currentModel.defaults) {
setJobConfig(currentModel.defaults[key][1], key);
}
// set new model
setJobConfig(value, 'config.process[0].model.name_or_path');
// update the defaults when a model is selected
const model = options.model.find(model => model.name_or_path === value);
if (model?.defaults) {
for (const key in model.defaults) {
setJobConfig(model.defaults[key][0], key);
}
}
}}
options={options.model.map(model => ({
value: model.name_or_path,
label: model.name_or_path,
}))}
/>
<FormGroup label="Quantize" className="pt-2">
<Checkbox
label="Transformer"
checked={jobConfig.config.process[0].model.quantize}
onChange={value => setJobConfig(value, 'config.process[0].model.quantize')}
/>
<Checkbox
label="Text Encoder"
checked={jobConfig.config.process[0].model.quantize_te}
onChange={value => setJobConfig(value, 'config.process[0].model.quantize_te')}
/>
</FormGroup>
</Card>
{jobConfig.config.process[0].network?.linear && (
<Card title="LoRA Configuration">
<NumberInput
label="Linear Rank"
value={jobConfig.config.process[0].network.linear}
onChange={value => {
setJobConfig(value, 'config.process[0].network.linear');
setJobConfig(value, 'config.process[0].network.linear_alpha');
}}
placeholder="eg. 16"
min={1}
max={1024}
required
/>
</Card>
)}
<Card title="Save Configuration">
<SelectInput
label="Data Type"
value={jobConfig.config.process[0].save.dtype}
onChange={value => setJobConfig(value, 'config.process[0].save.dtype')}
options={[
{ value: 'bf16', label: 'BF16' },
{ value: 'fp16', label: 'FP16' },
{ value: 'fp32', label: 'FP32' },
]}
/>
<NumberInput
label="Save Every"
value={jobConfig.config.process[0].save.save_every}
onChange={value => setJobConfig(value, 'config.process[0].save.save_every')}
placeholder="eg. 250"
min={1}
required
/>
<NumberInput
label="Max Step Saves to Keep"
value={jobConfig.config.process[0].save.max_step_saves_to_keep}
onChange={value => setJobConfig(value, 'config.process[0].save.max_step_saves_to_keep')}
placeholder="eg. 4"
min={1}
required
/>
</Card>
</div>
<div>
<Card title="Training Configuration">
<div className="grid grid-cols-1 md:grid-cols-2 lg:grid-cols-4 gap-6">
<div>
<NumberInput
label="Batch Size"
className="pt-2"
value={jobConfig.config.process[0].train.batch_size}
onChange={value => setJobConfig(value, 'config.process[0].train.batch_size')}
placeholder="eg. 4"
min={1}
required
/>
<NumberInput
label="Gradient Accumulation"
className="pt-2"
value={jobConfig.config.process[0].train.gradient_accumulation}
onChange={value => setJobConfig(value, 'config.process[0].train.gradient_accumulation')}
placeholder="eg. 1"
min={1}
required
/>
<NumberInput
label="Steps"
className="pt-2"
value={jobConfig.config.process[0].train.steps}
onChange={value => setJobConfig(value, 'config.process[0].train.steps')}
placeholder="eg. 2000"
min={1}
required
/>
</div>
<div>
<SelectInput
label="Optimizer"
className="pt-2"
value={jobConfig.config.process[0].train.optimizer}
onChange={value => setJobConfig(value, 'config.process[0].train.optimizer')}
options={[
{ value: 'adamw8bit', label: 'AdamW8Bit' },
{ value: 'adafactor', label: 'Adafactor' },
]}
/>
<NumberInput
label="Learning Rate"
className="pt-2"
value={jobConfig.config.process[0].train.lr}
onChange={value => setJobConfig(value, 'config.process[0].train.lr')}
placeholder="eg. 0.0001"
min={0}
required
/>
<NumberInput
label="Weight Decay"
className="pt-2"
value={jobConfig.config.process[0].train.optimizer_params.weight_decay}
onChange={value => setJobConfig(value, 'config.process[0].train.optimizer_params.weight_decay')}
placeholder="eg. 0.0001"
min={0}
required
/>
</div>
</div>
</div>
</section>
</Card>
</div>
<div>
<Card title="Datasets">
<>
{jobConfig.config.process[0].datasets.map((dataset, i) => (
<div key={i} className="p-4 rounded-lg bg-gray-800 relative">
<button
type="button"
onClick={() =>
setJobConfig(
jobConfig.config.process[0].datasets.filter((_, index) => index !== i),
'config.process[0].datasets',
)
}
className="absolute top-2 right-2 bg-red-800 hover:bg-red-700 rounded-full p-1 text-sm transition-colors"
>
<X />
</button>
<h2 className="text-lg font-bold mb-4">Dataset {i + 1}</h2>
<div className="grid grid-cols-1 md:grid-cols-2 lg:grid-cols-4 gap-6">
<div>
<TextInput
label="Folder Path"
value={dataset.folder_path}
onChange={value => setJobConfig(value, `config.process[0].datasets[${i}].folder_path`)}
placeholder="eg. /path/to/images/folder"
required
/>
<TextInput
label="Mask Folder Path"
className="pt-2"
value={dataset.mask_path || ''}
onChange={value => {
let setValue: string | null = value;
if (!setValue || setValue.trim() === '') {
setValue = null;
}
setJobConfig(setValue, `config.process[0].datasets[${i}].mask_path`);
}}
placeholder="eg. /path/to/masks/folder"
/>
<NumberInput
label="Mask Min Value"
className="pt-2"
value={dataset.mask_min_value}
onChange={value => setJobConfig(value, `config.process[0].datasets[${i}].mask_min_value`)}
placeholder="eg. 0.1"
min={0}
max={1}
required
/>
</div>
<div>
<TextInput
label="Default Caption"
value={dataset.default_caption}
onChange={value => setJobConfig(value, `config.process[0].datasets[${i}].default_caption`)}
placeholder="eg. A photo of a cat"
/>
<TextInput
label="Caption Extension"
className="pt-2"
value={dataset.caption_ext}
onChange={value => setJobConfig(value, `config.process[0].datasets[${i}].caption_ext`)}
placeholder="eg. txt"
required
/>
<NumberInput
label="Caption Dropout Rate"
className="pt-2"
value={dataset.caption_dropout_rate}
onChange={value => setJobConfig(value, `config.process[0].datasets[${i}].caption_dropout_rate`)}
placeholder="eg. 0.05"
min={0}
required
/>
</div>
<div>
<FormGroup label="Settings" className="">
<Checkbox
label="Cache Latents to Disk"
checked={dataset.cache_latents_to_disk || false}
onChange={value =>
setJobConfig(value, `config.process[0].datasets[${i}].cache_latents_to_disk`)
}
/>
<Checkbox
label="Is Regularization"
className="pt-2"
checked={dataset.is_reg || false}
onChange={value => setJobConfig(value, `config.process[0].datasets[${i}].is_reg`)}
/>
</FormGroup>
</div>
<div>
<FormGroup label="Resolutions" className="pt-2">
{[256, 512, 768, 1024, 1280].map(res => (
<Checkbox
key={res}
label={res.toString()}
checked={dataset.resolution.includes(res)}
onChange={value => {
const resolutions = dataset.resolution.includes(res)
? dataset.resolution.filter(r => r !== res)
: [...dataset.resolution, res];
setJobConfig(resolutions, `config.process[0].datasets[${i}].resolution`);
}}
/>
))}
</FormGroup>
</div>
</div>
</div>
))}
<button
type="button"
onClick={() =>
setJobConfig(
[...jobConfig.config.process[0].datasets, objectCopy(defaultDatasetConfig)],
'config.process[0].datasets',
)
}
className="w-full px-4 py-2 bg-gray-700 hover:bg-gray-600 rounded-lg transition-colors"
>
Add Dataset
</button>
</>
</Card>
</div>
<div>
<Card title="Sample Configuration">
<div className="grid grid-cols-1 md:grid-cols-2 lg:grid-cols-4 gap-6">
<div>
<NumberInput
label="Sample Every"
value={jobConfig.config.process[0].sample.sample_every}
onChange={value => setJobConfig(value, 'config.process[0].sample.sample_every')}
placeholder="eg. 250"
min={1}
required
/>
<SelectInput
label="Sampler"
className="pt-2"
value={jobConfig.config.process[0].sample.sampler}
onChange={value => setJobConfig(value, 'config.process[0].sample.sampler')}
options={[{ value: 'flowmatch', label: 'FlowMatch' }]}
/>
</div>
<div>
<NumberInput
label="Guidance Scale"
value={jobConfig.config.process[0].sample.guidance_scale}
onChange={value => setJobConfig(value, 'config.process[0].sample.guidance_scale')}
placeholder="eg. 1.0"
min={0}
required
/>
<NumberInput
label="Sample Steps"
value={jobConfig.config.process[0].sample.sample_steps}
onChange={value => setJobConfig(value, 'config.process[0].sample.sample_steps')}
placeholder="eg. 1"
className="pt-2"
min={1}
required
/>
</div>
<div>
<NumberInput
label="Width"
value={jobConfig.config.process[0].sample.width}
onChange={value => setJobConfig(value, 'config.process[0].sample.width')}
placeholder="eg. 1024"
min={256}
required
/>
<NumberInput
label="Height"
value={jobConfig.config.process[0].sample.height}
onChange={value => setJobConfig(value, 'config.process[0].sample.height')}
placeholder="eg. 1024"
className="pt-2"
min={256}
required
/>
</div>
<div>
<NumberInput
label="Seed"
value={jobConfig.config.process[0].sample.seed}
onChange={value => setJobConfig(value, 'config.process[0].sample.seed')}
placeholder="eg. 0"
min={0}
required
/>
<Checkbox
label="Walk Seed"
className="pt-4 pl-2"
checked={jobConfig.config.process[0].sample.walk_seed}
onChange={value => setJobConfig(value, 'config.process[0].sample.walk_seed')}
/>
</div>
</div>
<FormGroup label={`Sample Prompts (${jobConfig.config.process[0].sample.prompts.length})`} className="pt-2">
{jobConfig.config.process[0].sample.prompts.map((prompt, i) => (
<div key={i} className="flex items-center space-x-2">
<div className="flex-1">
<TextInput
value={prompt}
onChange={value => setJobConfig(value, `config.process[0].sample.prompts[${i}]`)}
placeholder="Enter prompt"
required
/>
</div>
<div>
<button
type="button"
onClick={() =>
setJobConfig(
jobConfig.config.process[0].sample.prompts.filter((_, index) => index !== i),
'config.process[0].sample.prompts',
)
}
className="rounded-full p-1 text-sm"
>
<X />
</button>
</div>
</div>
))}
<button
type="button"
onClick={() =>
setJobConfig([...jobConfig.config.process[0].sample.prompts, ''], 'config.process[0].sample.prompts')
}
className="w-full px-4 py-2 bg-gray-700 hover:bg-gray-600 rounded-lg transition-colors"
>
Add Prompt
</button>
</FormGroup>
</Card>
</div>
<button
type="submit"

View File

@@ -0,0 +1,15 @@
interface CardProps {
title?: string;
children?: React.ReactNode;
}
const Card: React.FC<CardProps> = ({ title, children }) => {
return (
<section className="space-y-4 p-6 bg-gray-900 rounded-lg">
{title && <h2 className="text-lg mb-4 font-semibold uppercase text-gray-500">{title}</h2>}
{children ? children : null}
</section>
);
};
export default Card;

View File

@@ -0,0 +1,195 @@
// components/GpuMonitor.tsx
import React, { useState, useEffect } from 'react';
import { GPUApiResponse } from '@/types';
const GpuMonitor: React.FC = () => {
const [gpuData, setGpuData] = useState<GPUApiResponse | null>(null);
const [loading, setLoading] = useState<boolean>(true);
const [error, setError] = useState<string | null>(null);
const [lastUpdated, setLastUpdated] = useState<Date | null>(null);
useEffect(() => {
const fetchGpuInfo = async () => {
try {
const response = await fetch('/api/gpu');
if (!response.ok) {
throw new Error(`HTTP error! Status: ${response.status}`);
}
const data: GPUApiResponse = await response.json();
setGpuData(data);
setLastUpdated(new Date());
setError(null);
} catch (err) {
setError(`Failed to fetch GPU data: ${err instanceof Error ? err.message : String(err)}`);
} finally {
setLoading(false);
}
};
// Fetch immediately on component mount
fetchGpuInfo();
// Set up interval to fetch every 1 seconds
const intervalId = setInterval(fetchGpuInfo, 1000);
// Clean up interval on component unmount
return () => clearInterval(intervalId);
}, []);
// Helper to format memory values
const formatMemory = (mb: number): string => {
if (mb >= 1024) {
return `${(mb / 1024).toFixed(2)} GB`;
}
return `${mb} MB`;
};
// Helper to determine background color based on utilization
const getUtilizationColor = (percent: number): string => {
if (percent < 30) return 'bg-green-100';
if (percent < 70) return 'bg-yellow-100';
return 'bg-red-100';
};
// Helper to determine text color based on utilization
const getUtilizationTextColor = (percent: number): string => {
if (percent < 30) return 'text-green-800';
if (percent < 70) return 'text-yellow-800';
return 'text-red-800';
};
// Helper to determine temperature color
const getTemperatureColor = (temp: number): string => {
if (temp < 50) return 'text-green-600';
if (temp < 80) return 'text-yellow-600';
return 'text-red-600';
};
if (loading) {
return (
<div className="flex justify-center items-center h-64">
<div className="animate-spin rounded-full h-12 w-12 border-t-2 border-b-2 border-blue-500"></div>
</div>
);
}
if (error) {
return (
<div className="bg-red-100 border border-red-400 text-red-700 px-4 py-3 rounded relative" role="alert">
<strong className="font-bold">Error!</strong>
<span className="block sm:inline"> {error}</span>
</div>
);
}
if (!gpuData) {
return (
<div className="bg-yellow-100 border border-yellow-400 text-yellow-700 px-4 py-3 rounded relative" role="alert">
<span className="block sm:inline">No GPU data available.</span>
</div>
);
}
if (!gpuData.hasNvidiaSmi) {
return (
<div className="bg-yellow-100 border border-yellow-400 text-yellow-700 px-4 py-3 rounded relative" role="alert">
<strong className="font-bold">No NVIDIA GPUs detected!</strong>
<span className="block sm:inline"> nvidia-smi is not available on this system.</span>
{gpuData.error && <p className="mt-2 text-sm">{gpuData.error}</p>}
</div>
);
}
if (gpuData.gpus.length === 0) {
return (
<div className="bg-yellow-100 border border-yellow-400 text-yellow-700 px-4 py-3 rounded relative" role="alert">
<span className="block sm:inline">No GPUs found, but nvidia-smi is available.</span>
</div>
);
}
return (
<div className="container mx-auto py-2">
<div className="flex justify-between items-center mb-2">
<h1 className="text-lg font-bold">GPU Monitor</h1>
<div className="text-xs text-gray-500">Last updated: {lastUpdated?.toLocaleTimeString()}</div>
</div>
<div className="grid grid-cols-1 md:grid-cols-2 lg:grid-cols-3 gap-3">
{gpuData.gpus.map(gpu => (
<div
key={gpu.index}
className="bg-gray-900 rounded-lg shadow-lg overflow-hidden hover:shadow-xl transition-shadow duration-300 px-2 py-2"
>
<div className="bg-gray-800 text-white px-2 py-1 flex justify-between items-center">
<h2 className="font-bold text-sm truncate">{gpu.name}</h2>
<span className="text-xs bg-gray-700 rounded px-1 py-0.5">GPU #{gpu.index}</span>
</div>
<div className="p-2">
<div className="mb-2 flex items-center">
<p className="text-xs text-gray-500 mr-1">Temperature:</p>
<p className={`text-sm font-bold ${getTemperatureColor(gpu.temperature)}`}>{gpu.temperature}°C</p>
</div>
<div className="">
<p className="text-xs text-gray-600 mb-0.5">GPU Utilization</p>
<div className="w-full bg-gray-500 rounded-full h-1.5">
<div
className={`h-1.5 rounded-full ${gpu.utilization.gpu < 30 ? 'bg-green-500' : gpu.utilization.gpu < 70 ? 'bg-yellow-500' : 'bg-red-500'}`}
style={{ width: `${gpu.utilization.gpu}%` }}
></div>
</div>
<p className="text-right text-xs mt-0.5">{gpu.utilization.gpu}%</p>
</div>
<div className="mb-2">
<p className="text-xs text-gray-600 mb-0.5">Memory Utilization</p>
<div className="w-full bg-gray-500 rounded-full h-1.5">
<div
className="h-1.5 rounded-full bg-blue-500"
style={{ width: `${(gpu.memory.used / gpu.memory.total) * 100}%` }}
></div>
</div>
<div className="flex justify-between text-xs mt-0.5">
<span>
{formatMemory(gpu.memory.used)} / {formatMemory(gpu.memory.total)}
</span>
<span>{((gpu.memory.used / gpu.memory.total) * 100).toFixed(1)}%</span>
</div>
</div>
<div className="grid grid-cols-2 gap-2 mb-2">
<div>
<p className="text-xs text-gray-500 mb-0.5">Power</p>
<p className="text-sm font-medium">
{gpu.power.draw.toFixed(1)}W / {gpu.power.limit.toFixed(1)}W
</p>
</div>
<div>
<p className="text-xs text-gray-500 mb-0.5">Memory Clock</p>
<p className="text-sm font-medium">{gpu.clocks.memory} MHz</p>
</div>
</div>
<div className="mt-1 pt-1 border-t border-gray-600 grid grid-cols-2 gap-2">
<div>
<p className="text-xs text-gray-500 mb-0.5">Graphics Clock</p>
<p className="text-sm font-medium">{gpu.clocks.graphics} MHz</p>
</div>
<div className="">
<p className="text-xs text-gray-500 mb-0.5">Driver Version</p>
<p className="text-sm font-medium">{gpu.driverVersion}</p>
</div>
</div>
</div>
</div>
))}
</div>
</div>
);
};
export default GpuMonitor;

View File

@@ -0,0 +1,150 @@
import React from 'react';
import classNames from 'classnames';
const labelClasses = 'block text-sm mb-2 text-gray-300';
const inputClasses =
'w-full px-4 py-2 bg-gray-800 border border-gray-700 rounded-lg focus:ring-2 focus:ring-gray-600 focus:border-transparent';
export interface InputProps {
label?: string;
className?: string;
placeholder?: string;
required?: boolean;
}
export interface TextInputProps extends InputProps {
value: string;
onChange: (value: string) => void;
type?: 'text' | 'password';
}
export const TextInput = (props: TextInputProps) => {
const { label, value, onChange, placeholder, required } = props;
return (
<div className={classNames(props.className)}>
{label && <label className={labelClasses}>{label}</label>}
<input
type={props.type || 'text'}
value={value}
onChange={e => onChange(e.target.value)}
className={inputClasses}
placeholder={placeholder}
required={required}
/>
</div>
);
};
export interface NumberInputProps extends InputProps {
value: number;
onChange: (value: number) => void;
min?: number;
max?: number;
}
export const NumberInput = (props: NumberInputProps) => {
const { label, value, onChange, placeholder, required, min, max } = props;
return (
<div className={classNames(props.className)}>
{label && <label className={labelClasses}>{label}</label>}
<input
type="number"
value={value}
onChange={(e) => {
let value = Number(e.target.value);
if (isNaN(value)) value = 0;
if (min !== undefined && value < min) value = min;
if (max !== undefined && value > max) value = max;
onChange(value);
}}
className={inputClasses}
placeholder={placeholder}
required={required}
min={min}
max={max}
/>
</div>
);
};
export interface SelectInputProps extends InputProps {
value: string;
onChange: (value: string) => void;
options: { value: string; label: string }[];
}
export const SelectInput = (props: SelectInputProps) => {
const { label, value, onChange, options } = props;
return (
<div className={classNames(props.className)}>
{label && <label className={labelClasses}>{label}</label>}
<select value={value} onChange={e => onChange(e.target.value)} className={inputClasses}>
{options.map(option => (
<option key={option.value} value={option.value}>
{option.label}
</option>
))}
</select>
</div>
);
};
export interface CheckboxProps {
label?: string;
checked: boolean;
onChange: (checked: boolean) => void;
className?: string;
required?: boolean;
disabled?: boolean;
}
export const Checkbox = (props: CheckboxProps) => {
const { label, checked, onChange, required, disabled } = props;
const id = React.useId(); // Generate unique ID for label association
return (
<div className={classNames('flex items-center', props.className)}>
<div className="relative flex items-start">
<div className="flex items-center h-5">
<input
id={id}
type="checkbox"
checked={checked}
onChange={e => onChange(e.target.checked)}
className="w-4 h-4 rounded border-gray-700 bg-gray-800 text-indigo-600 focus:ring-2 focus:ring-indigo-500 focus:ring-offset-1 focus:ring-offset-gray-900 cursor-pointer transition-colors"
required={required}
disabled={disabled}
/>
</div>
{label && (
<div className="ml-3 text-sm">
<label
htmlFor={id}
className={classNames(
'font-medium cursor-pointer select-none',
disabled ? 'text-gray-500' : 'text-gray-300',
)}
>
{label}
</label>
</div>
)}
</div>
</div>
);
};
interface FormGroupProps {
label?: string;
className?: string;
children: React.ReactNode;
}
export const FormGroup: React.FC<FormGroupProps> = ({ label, className, children }) => {
return (
<div className={classNames(className)}>
{label && <label className={labelClasses}>{label}</label>}
<div className="px-4 space-y-2">{children}</div>
</div>
);
};

3
ui/src/paths.ts Normal file
View File

@@ -0,0 +1,3 @@
import path from 'path';
export const TOOLKIT_ROOT = path.resolve('@', '..', '..');
export const defaultTrainFolder = path.join(TOOLKIT_ROOT, 'output');

149
ui/src/types.ts Normal file
View File

@@ -0,0 +1,149 @@
/**
* GPU API response
*/
export interface GpuUtilization {
gpu: number;
memory: number;
}
export interface GpuMemory {
total: number;
free: number;
used: number;
}
export interface GpuPower {
draw: number;
limit: number;
}
export interface GpuClocks {
graphics: number;
memory: number;
}
export interface GpuInfo {
index: number;
name: string;
driverVersion: string;
temperature: number;
utilization: GpuUtilization;
memory: GpuMemory;
power: GpuPower;
clocks: GpuClocks;
}
export interface GPUApiResponse {
hasNvidiaSmi: boolean;
gpus: GpuInfo[];
error?: string;
}
/**
* Training configuration
*/
export interface NetworkConfig {
type: 'lora';
linear: number;
linear_alpha: number;
}
export interface SaveConfig {
dtype: string;
save_every: number;
max_step_saves_to_keep: number;
push_to_hub: boolean;
}
export interface DatasetConfig {
folder_path: string;
mask_path: string | null;
mask_min_value: number;
default_caption: string;
caption_ext: string;
caption_dropout_rate: number;
shuffle_tokens?: boolean;
is_reg: boolean;
network_weight: number;
cache_latents_to_disk?: boolean;
resolution: number[];
}
export interface EMAConfig {
use_ema: boolean;
ema_decay: number;
}
export interface TrainConfig {
batch_size: number;
bypass_guidance_embedding?: boolean;
steps: number;
gradient_accumulation: number;
train_unet: boolean;
train_text_encoder: boolean;
gradient_checkpointing: boolean;
noise_scheduler: string;
optimizer: string;
lr: number;
ema_config?: EMAConfig;
dtype: string;
optimizer_params: {
weight_decay: number;
};
}
export interface QuantizeKwargsConfig {
exclude: string[];
}
export interface ModelConfig {
name_or_path: string;
is_flux?: boolean;
is_lumina2?: boolean;
quantize: boolean;
quantize_te: boolean;
quantize_kwargs?: QuantizeKwargsConfig;
}
export interface SampleConfig {
sampler: string;
sample_every: number;
width: number;
height: number;
prompts: string[];
neg: string;
seed: number;
walk_seed: boolean;
guidance_scale: number;
sample_steps: number;
}
export interface ProcessConfig {
type: 'sd_trainer';
training_folder: string;
device: string;
network?: NetworkConfig;
save: SaveConfig;
datasets: DatasetConfig[];
train: TrainConfig;
model: ModelConfig;
sample: SampleConfig;
}
export interface ConfigObject {
name: string;
process: ProcessConfig[];
}
export interface MetaConfig {
name: string;
version: string;
}
export interface JobConfig {
job: string;
config: ConfigObject;
meta: MetaConfig;
}

4
ui/src/utils/basic.ts Normal file
View File

@@ -0,0 +1,4 @@
export const objectCopy = <T>(obj: T): T => {
return JSON.parse(JSON.stringify(obj)) as T;
};

84
ui/src/utils/hooks.tsx Normal file
View File

@@ -0,0 +1,84 @@
import React from 'react';
/**
* Updates a deeply nested value in an object using a string path
* @param obj The object to update
* @param value The new value to set
* @param path String path to the property (e.g. 'config.process[0].model.name_or_path')
* @returns A new object with the updated value
*/
export function setNestedValue<T, V>(obj: T, value: V, path?: string): T {
// Create a copy of the original object to maintain immutability
const result = { ...obj };
// if path is not provided, be root path
if (!path) {
path = '';
}
// Split the path into segments
const pathArray = path.split('.').flatMap(segment => {
// Handle array notation like 'process[0]'
const arrayMatch = segment.match(/^([^\[]+)(\[\d+\])+/);
if (arrayMatch) {
const propName = arrayMatch[1];
const indices = segment
.substring(propName.length)
.match(/\[(\d+)\]/g)
?.map(idx => parseInt(idx.substring(1, idx.length - 1)));
// Return property name followed by array indices
return [propName, ...(indices || [])];
}
return segment;
});
// Navigate to the target location
let current: any = result;
for (let i = 0; i < pathArray.length - 1; i++) {
const key = pathArray[i];
// If current key is a number, treat it as an array index
if (typeof key === 'number') {
if (!Array.isArray(current)) {
throw new Error(`Cannot access index ${key} of non-array`);
}
// Create a copy of the array to maintain immutability
current = [...current];
} else {
// For object properties, create a new object if it doesn't exist
if (current[key] === undefined) {
// Check if the next key is a number, if so create an array, otherwise an object
const nextKey = pathArray[i + 1];
current[key] = typeof nextKey === 'number' ? [] : {};
} else {
// Create a shallow copy to maintain immutability
current[key] = Array.isArray(current[key]) ? [...current[key]] : { ...current[key] };
}
}
// Move to the next level
current = current[key];
}
// Set the value at the final path segment
const finalKey = pathArray[pathArray.length - 1];
current[finalKey] = value;
return result;
}
/**
* Custom hook for managing a complex state object with string path updates
* @param initialState The initial state object
* @returns [state, setValue] tuple
*/
export function useNestedState<T>(initialState: T): [T, (value: any, path?: string) => void] {
const [state, setState] = React.useState<T>(initialState);
const setValue = React.useCallback((value: any, path?: string) => {
setState(prevState => setNestedValue(prevState, value, path));
}, []);
return [state, setValue];
}