mirror of
https://github.com/ostris/ai-toolkit.git
synced 2026-02-22 13:23:56 +00:00
Config ui section is coming along
This commit is contained in:
6
ui/package-lock.json
generated
6
ui/package-lock.json
generated
@@ -9,6 +9,7 @@
|
||||
"version": "0.1.0",
|
||||
"dependencies": {
|
||||
"@prisma/client": "^6.3.1",
|
||||
"classnames": "^2.5.1",
|
||||
"lucide-react": "^0.475.0",
|
||||
"next": "15.1.7",
|
||||
"prisma": "^6.3.1",
|
||||
@@ -1182,6 +1183,11 @@
|
||||
"node": ">=10"
|
||||
}
|
||||
},
|
||||
"node_modules/classnames": {
|
||||
"version": "2.5.1",
|
||||
"resolved": "https://registry.npmjs.org/classnames/-/classnames-2.5.1.tgz",
|
||||
"integrity": "sha512-saHYOzhIQs6wy2sVxTM6bUDsQO4F50V9RQ22qBpEdCW+I+/Wmke2HOl6lS6dTpdxVhb88/I6+Hs+438c3lfUow=="
|
||||
},
|
||||
"node_modules/clean-stack": {
|
||||
"version": "2.2.0",
|
||||
"resolved": "https://registry.npmjs.org/clean-stack/-/clean-stack-2.2.0.tgz",
|
||||
|
||||
@@ -11,6 +11,7 @@
|
||||
},
|
||||
"dependencies": {
|
||||
"@prisma/client": "^6.3.1",
|
||||
"classnames": "^2.5.1",
|
||||
"lucide-react": "^0.475.0",
|
||||
"next": "15.1.7",
|
||||
"prisma": "^6.3.1",
|
||||
|
||||
@@ -17,7 +17,8 @@ model Settings {
|
||||
model Training {
|
||||
id String @id @default(uuid())
|
||||
name String
|
||||
run_data String // JSON string
|
||||
gpu_id Int
|
||||
job_config String // JSON string
|
||||
created_at DateTime @default(now())
|
||||
updated_at DateTime @updatedAt
|
||||
}
|
||||
102
ui/src/app/api/gpu/route.ts
Normal file
102
ui/src/app/api/gpu/route.ts
Normal file
@@ -0,0 +1,102 @@
|
||||
import { NextResponse } from 'next/server';
|
||||
import { exec } from 'child_process';
|
||||
import { promisify } from 'util';
|
||||
|
||||
const execAsync = promisify(exec);
|
||||
|
||||
export async function GET() {
|
||||
try {
|
||||
// Check if nvidia-smi is available
|
||||
const hasNvidiaSmi = await checkNvidiaSmi();
|
||||
|
||||
if (!hasNvidiaSmi) {
|
||||
return NextResponse.json({
|
||||
hasNvidiaSmi: false,
|
||||
gpus: [],
|
||||
error: 'nvidia-smi not found or not accessible',
|
||||
});
|
||||
}
|
||||
|
||||
// Get GPU stats
|
||||
const gpuStats = await getGpuStats();
|
||||
|
||||
return NextResponse.json({
|
||||
hasNvidiaSmi: true,
|
||||
gpus: gpuStats,
|
||||
});
|
||||
} catch (error) {
|
||||
console.error('Error fetching NVIDIA GPU stats:', error);
|
||||
return NextResponse.json(
|
||||
{
|
||||
hasNvidiaSmi: false,
|
||||
gpus: [],
|
||||
error: `Failed to fetch GPU stats: ${error instanceof Error ? error.message : String(error)}`,
|
||||
},
|
||||
{ status: 500 },
|
||||
);
|
||||
}
|
||||
}
|
||||
|
||||
async function checkNvidiaSmi(): Promise<boolean> {
|
||||
try {
|
||||
await execAsync('which nvidia-smi');
|
||||
return true;
|
||||
} catch (error) {
|
||||
return false;
|
||||
}
|
||||
}
|
||||
|
||||
async function getGpuStats() {
|
||||
// Get detailed GPU information in JSON format
|
||||
const { stdout } = await execAsync(
|
||||
'nvidia-smi --query-gpu=index,name,driver_version,temperature.gpu,utilization.gpu,utilization.memory,memory.total,memory.free,memory.used,power.draw,power.limit,clocks.current.graphics,clocks.current.memory --format=csv,noheader,nounits',
|
||||
);
|
||||
|
||||
// Parse CSV output
|
||||
const gpus = stdout
|
||||
.trim()
|
||||
.split('\n')
|
||||
.map(line => {
|
||||
const [
|
||||
index,
|
||||
name,
|
||||
driverVersion,
|
||||
temperature,
|
||||
gpuUtil,
|
||||
memoryUtil,
|
||||
memoryTotal,
|
||||
memoryFree,
|
||||
memoryUsed,
|
||||
powerDraw,
|
||||
powerLimit,
|
||||
clockGraphics,
|
||||
clockMemory,
|
||||
] = line.split(', ').map(item => item.trim());
|
||||
|
||||
return {
|
||||
index: parseInt(index),
|
||||
name,
|
||||
driverVersion,
|
||||
temperature: parseInt(temperature),
|
||||
utilization: {
|
||||
gpu: parseInt(gpuUtil),
|
||||
memory: parseInt(memoryUtil),
|
||||
},
|
||||
memory: {
|
||||
total: parseInt(memoryTotal),
|
||||
free: parseInt(memoryFree),
|
||||
used: parseInt(memoryUsed),
|
||||
},
|
||||
power: {
|
||||
draw: parseFloat(powerDraw),
|
||||
limit: parseFloat(powerLimit),
|
||||
},
|
||||
clocks: {
|
||||
graphics: parseInt(clockGraphics),
|
||||
memory: parseInt(clockMemory),
|
||||
},
|
||||
};
|
||||
});
|
||||
|
||||
return gpus;
|
||||
}
|
||||
@@ -1,12 +1,21 @@
|
||||
import { NextResponse } from 'next/server';
|
||||
import { PrismaClient } from '@prisma/client';
|
||||
import { defaultTrainFolder } from '@/paths';
|
||||
|
||||
const prisma = new PrismaClient();
|
||||
|
||||
export async function GET() {
|
||||
try {
|
||||
const settings = await prisma.settings.findMany();
|
||||
return NextResponse.json(settings.reduce((acc, curr) => ({ ...acc, [curr.key]: curr.value }), {}));
|
||||
const settingsObject = settings.reduce((acc: any, setting) => {
|
||||
acc[setting.key] = setting.value;
|
||||
return acc;
|
||||
}, {});
|
||||
// if TRAINING_FOLDER is not set, use default
|
||||
if (!settingsObject.TRAINING_FOLDER || settingsObject.TRAINING_FOLDER === '') {
|
||||
settingsObject.TRAINING_FOLDER = defaultTrainFolder;
|
||||
}
|
||||
return NextResponse.json(settingsObject);
|
||||
} catch (error) {
|
||||
return NextResponse.json({ error: 'Failed to fetch settings' }, { status: 500 });
|
||||
}
|
||||
|
||||
@@ -27,7 +27,7 @@ export async function GET(request: Request) {
|
||||
export async function POST(request: Request) {
|
||||
try {
|
||||
const body = await request.json();
|
||||
const { id, name, run_data } = body;
|
||||
const { id, name, job_config, gpu_id } = body;
|
||||
|
||||
if (id) {
|
||||
// Update existing training
|
||||
@@ -35,7 +35,8 @@ export async function POST(request: Request) {
|
||||
where: { id },
|
||||
data: {
|
||||
name,
|
||||
run_data: JSON.stringify(run_data),
|
||||
gpu_id,
|
||||
job_config: JSON.stringify(job_config),
|
||||
},
|
||||
});
|
||||
return NextResponse.json(training);
|
||||
@@ -44,7 +45,8 @@ export async function POST(request: Request) {
|
||||
const training = await prisma.training.create({
|
||||
data: {
|
||||
name,
|
||||
run_data: JSON.stringify(run_data),
|
||||
gpu_id,
|
||||
job_config: JSON.stringify(job_config),
|
||||
},
|
||||
});
|
||||
return NextResponse.json(training);
|
||||
|
||||
@@ -1,15 +1,12 @@
|
||||
'use client';
|
||||
|
||||
import GpuMonitor from '@/components/GPUMonitor';
|
||||
|
||||
export default function Dashboard() {
|
||||
return (
|
||||
<div className="space-y-6">
|
||||
<h1 className="text-3xl font-bold">Dashboard</h1>
|
||||
<div className="grid grid-cols-1 md:grid-cols-2 lg:grid-cols-3 gap-6">
|
||||
{[1, 2, 3].map(i => (
|
||||
<div key={i} className="p-6 bg-gray-800 rounded-lg">
|
||||
<h2 className="text-xl font-semibold mb-2">Card {i}</h2>
|
||||
<p className="text-gray-400">Example dashboard card content</p>
|
||||
</div>
|
||||
))}
|
||||
</div>
|
||||
<GpuMonitor />
|
||||
</div>
|
||||
);
|
||||
}
|
||||
|
||||
95
ui/src/app/train/jobConfig.ts
Normal file
95
ui/src/app/train/jobConfig.ts
Normal file
@@ -0,0 +1,95 @@
|
||||
import { JobConfig, DatasetConfig } from '@/types';
|
||||
|
||||
export const defaultDatasetConfig: DatasetConfig = {
|
||||
folder_path: '/path/to/images/folder',
|
||||
mask_path: null,
|
||||
mask_min_value: 0.1,
|
||||
default_caption: '',
|
||||
caption_ext: 'txt',
|
||||
caption_dropout_rate: 0.05,
|
||||
cache_latents_to_disk: false,
|
||||
is_reg: false,
|
||||
network_weight: 1,
|
||||
resolution: [512, 768, 1024],
|
||||
};
|
||||
|
||||
export const defaultJobConfig: JobConfig = {
|
||||
job: 'extension',
|
||||
config: {
|
||||
name: 'my_first_flex_lora_v1',
|
||||
process: [
|
||||
{
|
||||
type: 'sd_trainer',
|
||||
training_folder: 'output',
|
||||
device: 'cuda:0',
|
||||
network: {
|
||||
type: 'lora',
|
||||
linear: 16,
|
||||
linear_alpha: 16,
|
||||
},
|
||||
save: {
|
||||
dtype: 'bf16',
|
||||
save_every: 250,
|
||||
max_step_saves_to_keep: 4,
|
||||
push_to_hub: false,
|
||||
},
|
||||
datasets: [
|
||||
defaultDatasetConfig
|
||||
],
|
||||
train: {
|
||||
batch_size: 1,
|
||||
bypass_guidance_embedding: true,
|
||||
steps: 2000,
|
||||
gradient_accumulation: 1,
|
||||
train_unet: true,
|
||||
train_text_encoder: false,
|
||||
gradient_checkpointing: true,
|
||||
noise_scheduler: 'flowmatch',
|
||||
optimizer: 'adamw8bit',
|
||||
optimizer_params: {
|
||||
weight_decay: 1e-4
|
||||
},
|
||||
lr: 0.0001,
|
||||
ema_config: {
|
||||
use_ema: true,
|
||||
ema_decay: 0.99,
|
||||
},
|
||||
dtype: 'bf16',
|
||||
},
|
||||
model: {
|
||||
name_or_path: 'ostris/Flex.1-alpha',
|
||||
is_flux: true,
|
||||
quantize: true,
|
||||
quantize_te: true
|
||||
},
|
||||
sample: {
|
||||
sampler: 'flowmatch',
|
||||
sample_every: 250,
|
||||
width: 1024,
|
||||
height: 1024,
|
||||
prompts: [
|
||||
'woman with red hair, playing chess at the park, bomb going off in the background',
|
||||
'a woman holding a coffee cup, in a beanie, sitting at a cafe',
|
||||
'a horse is a DJ at a night club, fish eye lens, smoke machine, lazer lights, holding a martini',
|
||||
'a man showing off his cool new t shirt at the beach, a shark is jumping out of the water in the background',
|
||||
'a bear building a log cabin in the snow covered mountains',
|
||||
'woman playing the guitar, on stage, singing a song, laser lights, punk rocker',
|
||||
'hipster man with a beard, building a chair, in a wood shop',
|
||||
'photo of a man, white background, medium shot, modeling clothing, studio lighting, white backdrop',
|
||||
"a man holding a sign that says, 'this is a sign'",
|
||||
'a bulldog, in a post apocalyptic world, with a shotgun, in a leather jacket, in a desert, with a motorcycle',
|
||||
],
|
||||
neg: '',
|
||||
seed: 42,
|
||||
walk_seed: true,
|
||||
guidance_scale: 4,
|
||||
sample_steps: 25,
|
||||
},
|
||||
},
|
||||
],
|
||||
},
|
||||
meta: {
|
||||
name: '[name]',
|
||||
version: '1.0',
|
||||
},
|
||||
};
|
||||
@@ -1,37 +1,41 @@
|
||||
export interface Model {
|
||||
name_or_path: string;
|
||||
model_kwargs?: Record<string, boolean>;
|
||||
train_kwargs?: Record<string, boolean>;
|
||||
name_or_path: string;
|
||||
defaults?: { [key: string]: any };
|
||||
}
|
||||
|
||||
export interface Option {
|
||||
model: Model[];
|
||||
model: Model[];
|
||||
}
|
||||
|
||||
|
||||
export const options = {
|
||||
model: [
|
||||
{
|
||||
name_or_path: "ostris/Flex.1-alpha",
|
||||
model_kwargs: {
|
||||
"is_flux": true
|
||||
},
|
||||
train_kwargs: {
|
||||
"bypass_guidance_embedding": true
|
||||
}
|
||||
},
|
||||
{
|
||||
name_or_path: "black-forest-labs/FLUX.1-dev",
|
||||
model_kwargs: {
|
||||
"is_flux": true
|
||||
},
|
||||
},
|
||||
{
|
||||
name_or_path: "Alpha-VLLM/Lumina-Image-2.0",
|
||||
model_kwargs: {
|
||||
"is_lumina2": true
|
||||
},
|
||||
},
|
||||
]
|
||||
|
||||
} as Option;
|
||||
model: [
|
||||
{
|
||||
name_or_path: 'ostris/Flex.1-alpha',
|
||||
defaults: {
|
||||
// default updates when [selected, unselected] in the UI
|
||||
'config.process[0].model.quantize': [true, false],
|
||||
'config.process[0].model.quantize_te': [true, false],
|
||||
'config.process[0].model.is_flux': [true, false],
|
||||
'config.process[0].train.bypass_guidance_embedding': [true, false],
|
||||
},
|
||||
},
|
||||
{
|
||||
name_or_path: 'black-forest-labs/FLUX.1-dev',
|
||||
defaults: {
|
||||
// default updates when [selected, unselected] in the UI
|
||||
'config.process[0].model.quantize': [true, false],
|
||||
'config.process[0].model.quantize_te': [true, false],
|
||||
'config.process[0].model.is_flux': [true, false],
|
||||
},
|
||||
},
|
||||
{
|
||||
name_or_path: 'Alpha-VLLM/Lumina-Image-2.0',
|
||||
defaults: {
|
||||
// default updates when [selected, unselected] in the UI
|
||||
'config.process[0].model.quantize': [false, false],
|
||||
'config.process[0].model.quantize_te': [true, false],
|
||||
'config.process[0].model.is_lumina2': [true, false],
|
||||
},
|
||||
},
|
||||
],
|
||||
} as Option;
|
||||
|
||||
@@ -1,34 +1,26 @@
|
||||
'use client';
|
||||
// todo update training folder from settings
|
||||
|
||||
import { useEffect, useState } from 'react';
|
||||
import { useEffect, useMemo, useState } from 'react';
|
||||
import { useSearchParams, useRouter } from 'next/navigation';
|
||||
import { options } from './options';
|
||||
|
||||
interface TrainingData {
|
||||
modelConfig: {
|
||||
name_or_path: string;
|
||||
steps: number;
|
||||
batchSize: number;
|
||||
learningRate: number;
|
||||
};
|
||||
}
|
||||
|
||||
const defaultTrainingData: TrainingData = {
|
||||
modelConfig: {
|
||||
name_or_path: 'ostris/Flex.1-alpha',
|
||||
steps: 100,
|
||||
batchSize: 32,
|
||||
learningRate: 0.001,
|
||||
},
|
||||
};
|
||||
import { GPUApiResponse } from '@/types';
|
||||
import { defaultJobConfig, defaultDatasetConfig } from './jobConfig';
|
||||
import { JobConfig } from '@/types';
|
||||
import { objectCopy } from '@/utils/basic';
|
||||
import { useNestedState } from '@/utils/hooks';
|
||||
import { TextInput, SelectInput, Checkbox, FormGroup, NumberInput } from '@/components/formInputs';
|
||||
import Card from '@/components/Card';
|
||||
import { Trash, X } from 'lucide-react';
|
||||
|
||||
export default function TrainingForm() {
|
||||
const router = useRouter();
|
||||
const searchParams = useSearchParams();
|
||||
const runId = searchParams.get('id');
|
||||
const [gpuID, setGpuID] = useState<number | null>(null);
|
||||
const [gpuList, setGpuList] = useState<number[]>([]);
|
||||
|
||||
const [name, setName] = useState('');
|
||||
const [trainingData, setTrainingData] = useState<TrainingData>(defaultTrainingData);
|
||||
const [jobConfig, setJobConfig] = useNestedState<JobConfig>(objectCopy(defaultJobConfig));
|
||||
const [status, setStatus] = useState<'idle' | 'saving' | 'success' | 'error'>('idle');
|
||||
|
||||
useEffect(() => {
|
||||
@@ -36,13 +28,42 @@ export default function TrainingForm() {
|
||||
fetch(`/api/training?id=${runId}`)
|
||||
.then(res => res.json())
|
||||
.then(data => {
|
||||
setName(data.name);
|
||||
setTrainingData(JSON.parse(data.run_data));
|
||||
setGpuID(data.gpu_id);
|
||||
setJobConfig(JSON.parse(data.job_config));
|
||||
})
|
||||
.catch(error => console.error('Error fetching training:', error));
|
||||
}
|
||||
}, [runId]);
|
||||
|
||||
useEffect(() => {
|
||||
const fetchGpuInfo = async () => {
|
||||
try {
|
||||
const response = await fetch('/api/gpu');
|
||||
|
||||
if (!response.ok) {
|
||||
throw new Error(`HTTP error! Status: ${response.status}`);
|
||||
}
|
||||
|
||||
const data: GPUApiResponse = await response.json();
|
||||
setGpuList(data.gpus.map(gpu => gpu.index).sort());
|
||||
} catch (err) {
|
||||
console.log(`Failed to fetch GPU data: ${err instanceof Error ? err.message : String(err)}`);
|
||||
} finally {
|
||||
// setLoading(false);
|
||||
}
|
||||
};
|
||||
|
||||
fetch('/api/settings')
|
||||
.then(res => res.json())
|
||||
.then(data => {
|
||||
setJobConfig(data.TRAINING_FOLDER, 'config.process[0].training_folder');
|
||||
})
|
||||
.catch(error => console.error('Error fetching settings:', error));
|
||||
|
||||
// Fetch immediately on component mount
|
||||
fetchGpuInfo();
|
||||
}, []);
|
||||
|
||||
const handleSubmit = async (e: React.FormEvent) => {
|
||||
e.preventDefault();
|
||||
setStatus('saving');
|
||||
@@ -55,8 +76,9 @@ export default function TrainingForm() {
|
||||
},
|
||||
body: JSON.stringify({
|
||||
id: runId,
|
||||
name,
|
||||
run_data: trainingData,
|
||||
name: jobConfig.config.name,
|
||||
gpu_id: gpuID,
|
||||
job_config: jobConfig,
|
||||
}),
|
||||
});
|
||||
|
||||
@@ -75,64 +97,428 @@ export default function TrainingForm() {
|
||||
}
|
||||
};
|
||||
|
||||
const updateSection = (section: keyof TrainingData, data: any) => {
|
||||
setTrainingData(prev => ({
|
||||
...prev,
|
||||
[section]: { ...prev[section], ...data },
|
||||
}));
|
||||
};
|
||||
|
||||
const modelOptions = options.model.map(model => model.name_or_path);
|
||||
|
||||
return (
|
||||
<div className="max-w-4xl mx-auto space-y-8 pb-12">
|
||||
<h1 className="text-3xl font-bold mb-8">{runId ? 'Edit Training Run' : 'New Training Run'}</h1>
|
||||
<div className="space-y-6">
|
||||
<h1 className="text-xl font-bold mb-8">{runId ? 'Edit Training Run' : 'New Training Run'}</h1>
|
||||
|
||||
<form onSubmit={handleSubmit} className="space-y-8">
|
||||
<div className="space-y-4">
|
||||
<label htmlFor="name" className="block text-sm font-medium mb-2">
|
||||
Training Name
|
||||
</label>
|
||||
<input
|
||||
type="text"
|
||||
id="name"
|
||||
value={name}
|
||||
onChange={e => setName(e.target.value)}
|
||||
className="w-full px-4 py-2 bg-gray-800 border border-gray-700 rounded-lg focus:ring-2 focus:ring-gray-600 focus:border-transparent"
|
||||
placeholder="Enter training name"
|
||||
required
|
||||
/>
|
||||
</div>
|
||||
<div className="grid grid-cols-1 md:grid-cols-2 lg:grid-cols-4 gap-6">
|
||||
<Card title="Job Settings">
|
||||
<TextInput
|
||||
label="Training Name"
|
||||
value={jobConfig.config.name}
|
||||
onChange={value => setJobConfig(value, 'config.name')}
|
||||
placeholder="Enter training name"
|
||||
required
|
||||
/>
|
||||
<SelectInput
|
||||
label="GPU ID"
|
||||
value={`${gpuID}`}
|
||||
className="pt-2"
|
||||
onChange={value => setGpuID(parseInt(value))}
|
||||
options={gpuList.map(gpu => ({ value: `${gpu}`, label: `GPU #${gpu}` }))}
|
||||
/>
|
||||
</Card>
|
||||
|
||||
{/* Model Configuration Section */}
|
||||
<section className="space-y-4 p-6 bg-gray-900 rounded-lg">
|
||||
<h2 className="text-xl font-bold mb-4">Model Configuration</h2>
|
||||
<div className="grid grid-cols-2 gap-4">
|
||||
<div>
|
||||
<label className="block text-sm font-medium mb-2">Model</label>
|
||||
<select
|
||||
value={trainingData.modelConfig.name_or_path}
|
||||
onChange={e => updateSection('modelConfig', { name_or_path: e.target.value })}
|
||||
className="w-full px-4 py-2 bg-gray-800 border border-gray-700 rounded-lg"
|
||||
>
|
||||
{modelOptions.map(model => (
|
||||
<option key={model} value={model}>
|
||||
{model}
|
||||
</option>
|
||||
))}
|
||||
</select>
|
||||
</div>
|
||||
<div>
|
||||
<label className="block text-sm font-medium mb-2">Epochs</label>
|
||||
<input
|
||||
type="number"
|
||||
value={trainingData.modelConfig.steps}
|
||||
onChange={e => updateSection('modelConfig', { steps: Number(e.target.value) })}
|
||||
className="w-full px-4 py-2 bg-gray-800 border border-gray-700 rounded-lg"
|
||||
{/* Model Configuration Section */}
|
||||
<Card title="Model Configuration">
|
||||
<SelectInput
|
||||
label="Name or Path"
|
||||
value={jobConfig.config.process[0].model.name_or_path}
|
||||
onChange={value => {
|
||||
// see if model changed
|
||||
const currentModel = options.model.find(
|
||||
model => model.name_or_path === jobConfig.config.process[0].model.name_or_path,
|
||||
);
|
||||
if (!currentModel || currentModel.name_or_path === value) {
|
||||
// model has not changed
|
||||
return;
|
||||
}
|
||||
// revert defaults from previous model
|
||||
for (const key in currentModel.defaults) {
|
||||
setJobConfig(currentModel.defaults[key][1], key);
|
||||
}
|
||||
// set new model
|
||||
setJobConfig(value, 'config.process[0].model.name_or_path');
|
||||
// update the defaults when a model is selected
|
||||
const model = options.model.find(model => model.name_or_path === value);
|
||||
if (model?.defaults) {
|
||||
for (const key in model.defaults) {
|
||||
setJobConfig(model.defaults[key][0], key);
|
||||
}
|
||||
}
|
||||
}}
|
||||
options={options.model.map(model => ({
|
||||
value: model.name_or_path,
|
||||
label: model.name_or_path,
|
||||
}))}
|
||||
/>
|
||||
<FormGroup label="Quantize" className="pt-2">
|
||||
<Checkbox
|
||||
label="Transformer"
|
||||
checked={jobConfig.config.process[0].model.quantize}
|
||||
onChange={value => setJobConfig(value, 'config.process[0].model.quantize')}
|
||||
/>
|
||||
<Checkbox
|
||||
label="Text Encoder"
|
||||
checked={jobConfig.config.process[0].model.quantize_te}
|
||||
onChange={value => setJobConfig(value, 'config.process[0].model.quantize_te')}
|
||||
/>
|
||||
</FormGroup>
|
||||
</Card>
|
||||
{jobConfig.config.process[0].network?.linear && (
|
||||
<Card title="LoRA Configuration">
|
||||
<NumberInput
|
||||
label="Linear Rank"
|
||||
value={jobConfig.config.process[0].network.linear}
|
||||
onChange={value => {
|
||||
setJobConfig(value, 'config.process[0].network.linear');
|
||||
setJobConfig(value, 'config.process[0].network.linear_alpha');
|
||||
}}
|
||||
placeholder="eg. 16"
|
||||
min={1}
|
||||
max={1024}
|
||||
required
|
||||
/>
|
||||
</Card>
|
||||
)}
|
||||
<Card title="Save Configuration">
|
||||
<SelectInput
|
||||
label="Data Type"
|
||||
value={jobConfig.config.process[0].save.dtype}
|
||||
onChange={value => setJobConfig(value, 'config.process[0].save.dtype')}
|
||||
options={[
|
||||
{ value: 'bf16', label: 'BF16' },
|
||||
{ value: 'fp16', label: 'FP16' },
|
||||
{ value: 'fp32', label: 'FP32' },
|
||||
]}
|
||||
/>
|
||||
<NumberInput
|
||||
label="Save Every"
|
||||
value={jobConfig.config.process[0].save.save_every}
|
||||
onChange={value => setJobConfig(value, 'config.process[0].save.save_every')}
|
||||
placeholder="eg. 250"
|
||||
min={1}
|
||||
required
|
||||
/>
|
||||
<NumberInput
|
||||
label="Max Step Saves to Keep"
|
||||
value={jobConfig.config.process[0].save.max_step_saves_to_keep}
|
||||
onChange={value => setJobConfig(value, 'config.process[0].save.max_step_saves_to_keep')}
|
||||
placeholder="eg. 4"
|
||||
min={1}
|
||||
required
|
||||
/>
|
||||
</Card>
|
||||
</div>
|
||||
<div>
|
||||
<Card title="Training Configuration">
|
||||
<div className="grid grid-cols-1 md:grid-cols-2 lg:grid-cols-4 gap-6">
|
||||
<div>
|
||||
<NumberInput
|
||||
label="Batch Size"
|
||||
className="pt-2"
|
||||
value={jobConfig.config.process[0].train.batch_size}
|
||||
onChange={value => setJobConfig(value, 'config.process[0].train.batch_size')}
|
||||
placeholder="eg. 4"
|
||||
min={1}
|
||||
required
|
||||
/>
|
||||
<NumberInput
|
||||
label="Gradient Accumulation"
|
||||
className="pt-2"
|
||||
value={jobConfig.config.process[0].train.gradient_accumulation}
|
||||
onChange={value => setJobConfig(value, 'config.process[0].train.gradient_accumulation')}
|
||||
placeholder="eg. 1"
|
||||
min={1}
|
||||
required
|
||||
/>
|
||||
<NumberInput
|
||||
label="Steps"
|
||||
className="pt-2"
|
||||
value={jobConfig.config.process[0].train.steps}
|
||||
onChange={value => setJobConfig(value, 'config.process[0].train.steps')}
|
||||
placeholder="eg. 2000"
|
||||
min={1}
|
||||
required
|
||||
/>
|
||||
</div>
|
||||
<div>
|
||||
<SelectInput
|
||||
label="Optimizer"
|
||||
className="pt-2"
|
||||
value={jobConfig.config.process[0].train.optimizer}
|
||||
onChange={value => setJobConfig(value, 'config.process[0].train.optimizer')}
|
||||
options={[
|
||||
{ value: 'adamw8bit', label: 'AdamW8Bit' },
|
||||
{ value: 'adafactor', label: 'Adafactor' },
|
||||
]}
|
||||
/>
|
||||
<NumberInput
|
||||
label="Learning Rate"
|
||||
className="pt-2"
|
||||
value={jobConfig.config.process[0].train.lr}
|
||||
onChange={value => setJobConfig(value, 'config.process[0].train.lr')}
|
||||
placeholder="eg. 0.0001"
|
||||
min={0}
|
||||
required
|
||||
/>
|
||||
<NumberInput
|
||||
label="Weight Decay"
|
||||
className="pt-2"
|
||||
value={jobConfig.config.process[0].train.optimizer_params.weight_decay}
|
||||
onChange={value => setJobConfig(value, 'config.process[0].train.optimizer_params.weight_decay')}
|
||||
placeholder="eg. 0.0001"
|
||||
min={0}
|
||||
required
|
||||
/>
|
||||
</div>
|
||||
</div>
|
||||
</div>
|
||||
</section>
|
||||
</Card>
|
||||
</div>
|
||||
<div>
|
||||
<Card title="Datasets">
|
||||
<>
|
||||
{jobConfig.config.process[0].datasets.map((dataset, i) => (
|
||||
<div key={i} className="p-4 rounded-lg bg-gray-800 relative">
|
||||
<button
|
||||
type="button"
|
||||
onClick={() =>
|
||||
setJobConfig(
|
||||
jobConfig.config.process[0].datasets.filter((_, index) => index !== i),
|
||||
'config.process[0].datasets',
|
||||
)
|
||||
}
|
||||
className="absolute top-2 right-2 bg-red-800 hover:bg-red-700 rounded-full p-1 text-sm transition-colors"
|
||||
>
|
||||
<X />
|
||||
</button>
|
||||
<h2 className="text-lg font-bold mb-4">Dataset {i + 1}</h2>
|
||||
<div className="grid grid-cols-1 md:grid-cols-2 lg:grid-cols-4 gap-6">
|
||||
<div>
|
||||
<TextInput
|
||||
label="Folder Path"
|
||||
value={dataset.folder_path}
|
||||
onChange={value => setJobConfig(value, `config.process[0].datasets[${i}].folder_path`)}
|
||||
placeholder="eg. /path/to/images/folder"
|
||||
required
|
||||
/>
|
||||
<TextInput
|
||||
label="Mask Folder Path"
|
||||
className="pt-2"
|
||||
value={dataset.mask_path || ''}
|
||||
onChange={value => {
|
||||
let setValue: string | null = value;
|
||||
if (!setValue || setValue.trim() === '') {
|
||||
setValue = null;
|
||||
}
|
||||
setJobConfig(setValue, `config.process[0].datasets[${i}].mask_path`);
|
||||
}}
|
||||
placeholder="eg. /path/to/masks/folder"
|
||||
/>
|
||||
<NumberInput
|
||||
label="Mask Min Value"
|
||||
className="pt-2"
|
||||
value={dataset.mask_min_value}
|
||||
onChange={value => setJobConfig(value, `config.process[0].datasets[${i}].mask_min_value`)}
|
||||
placeholder="eg. 0.1"
|
||||
min={0}
|
||||
max={1}
|
||||
required
|
||||
/>
|
||||
</div>
|
||||
<div>
|
||||
<TextInput
|
||||
label="Default Caption"
|
||||
value={dataset.default_caption}
|
||||
onChange={value => setJobConfig(value, `config.process[0].datasets[${i}].default_caption`)}
|
||||
placeholder="eg. A photo of a cat"
|
||||
/>
|
||||
<TextInput
|
||||
label="Caption Extension"
|
||||
className="pt-2"
|
||||
value={dataset.caption_ext}
|
||||
onChange={value => setJobConfig(value, `config.process[0].datasets[${i}].caption_ext`)}
|
||||
placeholder="eg. txt"
|
||||
required
|
||||
/>
|
||||
<NumberInput
|
||||
label="Caption Dropout Rate"
|
||||
className="pt-2"
|
||||
value={dataset.caption_dropout_rate}
|
||||
onChange={value => setJobConfig(value, `config.process[0].datasets[${i}].caption_dropout_rate`)}
|
||||
placeholder="eg. 0.05"
|
||||
min={0}
|
||||
required
|
||||
/>
|
||||
</div>
|
||||
<div>
|
||||
<FormGroup label="Settings" className="">
|
||||
<Checkbox
|
||||
label="Cache Latents to Disk"
|
||||
checked={dataset.cache_latents_to_disk || false}
|
||||
onChange={value =>
|
||||
setJobConfig(value, `config.process[0].datasets[${i}].cache_latents_to_disk`)
|
||||
}
|
||||
/>
|
||||
<Checkbox
|
||||
label="Is Regularization"
|
||||
className="pt-2"
|
||||
checked={dataset.is_reg || false}
|
||||
onChange={value => setJobConfig(value, `config.process[0].datasets[${i}].is_reg`)}
|
||||
/>
|
||||
</FormGroup>
|
||||
</div>
|
||||
<div>
|
||||
<FormGroup label="Resolutions" className="pt-2">
|
||||
{[256, 512, 768, 1024, 1280].map(res => (
|
||||
<Checkbox
|
||||
key={res}
|
||||
label={res.toString()}
|
||||
checked={dataset.resolution.includes(res)}
|
||||
onChange={value => {
|
||||
const resolutions = dataset.resolution.includes(res)
|
||||
? dataset.resolution.filter(r => r !== res)
|
||||
: [...dataset.resolution, res];
|
||||
setJobConfig(resolutions, `config.process[0].datasets[${i}].resolution`);
|
||||
}}
|
||||
/>
|
||||
))}
|
||||
</FormGroup>
|
||||
</div>
|
||||
</div>
|
||||
</div>
|
||||
))}
|
||||
<button
|
||||
type="button"
|
||||
onClick={() =>
|
||||
setJobConfig(
|
||||
[...jobConfig.config.process[0].datasets, objectCopy(defaultDatasetConfig)],
|
||||
'config.process[0].datasets',
|
||||
)
|
||||
}
|
||||
className="w-full px-4 py-2 bg-gray-700 hover:bg-gray-600 rounded-lg transition-colors"
|
||||
>
|
||||
Add Dataset
|
||||
</button>
|
||||
</>
|
||||
</Card>
|
||||
</div>
|
||||
<div>
|
||||
<Card title="Sample Configuration">
|
||||
<div className="grid grid-cols-1 md:grid-cols-2 lg:grid-cols-4 gap-6">
|
||||
<div>
|
||||
<NumberInput
|
||||
label="Sample Every"
|
||||
value={jobConfig.config.process[0].sample.sample_every}
|
||||
onChange={value => setJobConfig(value, 'config.process[0].sample.sample_every')}
|
||||
placeholder="eg. 250"
|
||||
min={1}
|
||||
required
|
||||
/>
|
||||
<SelectInput
|
||||
label="Sampler"
|
||||
className="pt-2"
|
||||
value={jobConfig.config.process[0].sample.sampler}
|
||||
onChange={value => setJobConfig(value, 'config.process[0].sample.sampler')}
|
||||
options={[{ value: 'flowmatch', label: 'FlowMatch' }]}
|
||||
/>
|
||||
</div>
|
||||
<div>
|
||||
<NumberInput
|
||||
label="Guidance Scale"
|
||||
value={jobConfig.config.process[0].sample.guidance_scale}
|
||||
onChange={value => setJobConfig(value, 'config.process[0].sample.guidance_scale')}
|
||||
placeholder="eg. 1.0"
|
||||
min={0}
|
||||
required
|
||||
/>
|
||||
<NumberInput
|
||||
label="Sample Steps"
|
||||
value={jobConfig.config.process[0].sample.sample_steps}
|
||||
onChange={value => setJobConfig(value, 'config.process[0].sample.sample_steps')}
|
||||
placeholder="eg. 1"
|
||||
className="pt-2"
|
||||
min={1}
|
||||
required
|
||||
/>
|
||||
</div>
|
||||
<div>
|
||||
<NumberInput
|
||||
label="Width"
|
||||
value={jobConfig.config.process[0].sample.width}
|
||||
onChange={value => setJobConfig(value, 'config.process[0].sample.width')}
|
||||
placeholder="eg. 1024"
|
||||
min={256}
|
||||
required
|
||||
/>
|
||||
<NumberInput
|
||||
label="Height"
|
||||
value={jobConfig.config.process[0].sample.height}
|
||||
onChange={value => setJobConfig(value, 'config.process[0].sample.height')}
|
||||
placeholder="eg. 1024"
|
||||
className="pt-2"
|
||||
min={256}
|
||||
required
|
||||
/>
|
||||
</div>
|
||||
|
||||
<div>
|
||||
<NumberInput
|
||||
label="Seed"
|
||||
value={jobConfig.config.process[0].sample.seed}
|
||||
onChange={value => setJobConfig(value, 'config.process[0].sample.seed')}
|
||||
placeholder="eg. 0"
|
||||
min={0}
|
||||
required
|
||||
/>
|
||||
<Checkbox
|
||||
label="Walk Seed"
|
||||
className="pt-4 pl-2"
|
||||
checked={jobConfig.config.process[0].sample.walk_seed}
|
||||
onChange={value => setJobConfig(value, 'config.process[0].sample.walk_seed')}
|
||||
/>
|
||||
</div>
|
||||
</div>
|
||||
<FormGroup label={`Sample Prompts (${jobConfig.config.process[0].sample.prompts.length})`} className="pt-2">
|
||||
{jobConfig.config.process[0].sample.prompts.map((prompt, i) => (
|
||||
<div key={i} className="flex items-center space-x-2">
|
||||
<div className="flex-1">
|
||||
<TextInput
|
||||
value={prompt}
|
||||
onChange={value => setJobConfig(value, `config.process[0].sample.prompts[${i}]`)}
|
||||
placeholder="Enter prompt"
|
||||
required
|
||||
/>
|
||||
</div>
|
||||
<div>
|
||||
<button
|
||||
type="button"
|
||||
onClick={() =>
|
||||
setJobConfig(
|
||||
jobConfig.config.process[0].sample.prompts.filter((_, index) => index !== i),
|
||||
'config.process[0].sample.prompts',
|
||||
)
|
||||
}
|
||||
className="rounded-full p-1 text-sm"
|
||||
>
|
||||
<X />
|
||||
</button>
|
||||
</div>
|
||||
</div>
|
||||
))}
|
||||
<button
|
||||
type="button"
|
||||
onClick={() =>
|
||||
setJobConfig([...jobConfig.config.process[0].sample.prompts, ''], 'config.process[0].sample.prompts')
|
||||
}
|
||||
className="w-full px-4 py-2 bg-gray-700 hover:bg-gray-600 rounded-lg transition-colors"
|
||||
>
|
||||
Add Prompt
|
||||
</button>
|
||||
</FormGroup>
|
||||
</Card>
|
||||
</div>
|
||||
|
||||
<button
|
||||
type="submit"
|
||||
|
||||
15
ui/src/components/Card.tsx
Normal file
15
ui/src/components/Card.tsx
Normal file
@@ -0,0 +1,15 @@
|
||||
interface CardProps {
|
||||
title?: string;
|
||||
children?: React.ReactNode;
|
||||
}
|
||||
|
||||
const Card: React.FC<CardProps> = ({ title, children }) => {
|
||||
return (
|
||||
<section className="space-y-4 p-6 bg-gray-900 rounded-lg">
|
||||
{title && <h2 className="text-lg mb-4 font-semibold uppercase text-gray-500">{title}</h2>}
|
||||
{children ? children : null}
|
||||
</section>
|
||||
);
|
||||
};
|
||||
|
||||
export default Card;
|
||||
195
ui/src/components/GPUMonitor.tsx
Normal file
195
ui/src/components/GPUMonitor.tsx
Normal file
@@ -0,0 +1,195 @@
|
||||
// components/GpuMonitor.tsx
|
||||
import React, { useState, useEffect } from 'react';
|
||||
import { GPUApiResponse } from '@/types';
|
||||
|
||||
const GpuMonitor: React.FC = () => {
|
||||
const [gpuData, setGpuData] = useState<GPUApiResponse | null>(null);
|
||||
const [loading, setLoading] = useState<boolean>(true);
|
||||
const [error, setError] = useState<string | null>(null);
|
||||
const [lastUpdated, setLastUpdated] = useState<Date | null>(null);
|
||||
|
||||
useEffect(() => {
|
||||
const fetchGpuInfo = async () => {
|
||||
try {
|
||||
const response = await fetch('/api/gpu');
|
||||
|
||||
if (!response.ok) {
|
||||
throw new Error(`HTTP error! Status: ${response.status}`);
|
||||
}
|
||||
|
||||
const data: GPUApiResponse = await response.json();
|
||||
setGpuData(data);
|
||||
setLastUpdated(new Date());
|
||||
setError(null);
|
||||
} catch (err) {
|
||||
setError(`Failed to fetch GPU data: ${err instanceof Error ? err.message : String(err)}`);
|
||||
} finally {
|
||||
setLoading(false);
|
||||
}
|
||||
};
|
||||
|
||||
// Fetch immediately on component mount
|
||||
fetchGpuInfo();
|
||||
|
||||
// Set up interval to fetch every 1 seconds
|
||||
const intervalId = setInterval(fetchGpuInfo, 1000);
|
||||
|
||||
// Clean up interval on component unmount
|
||||
return () => clearInterval(intervalId);
|
||||
}, []);
|
||||
|
||||
// Helper to format memory values
|
||||
const formatMemory = (mb: number): string => {
|
||||
if (mb >= 1024) {
|
||||
return `${(mb / 1024).toFixed(2)} GB`;
|
||||
}
|
||||
return `${mb} MB`;
|
||||
};
|
||||
|
||||
// Helper to determine background color based on utilization
|
||||
const getUtilizationColor = (percent: number): string => {
|
||||
if (percent < 30) return 'bg-green-100';
|
||||
if (percent < 70) return 'bg-yellow-100';
|
||||
return 'bg-red-100';
|
||||
};
|
||||
|
||||
// Helper to determine text color based on utilization
|
||||
const getUtilizationTextColor = (percent: number): string => {
|
||||
if (percent < 30) return 'text-green-800';
|
||||
if (percent < 70) return 'text-yellow-800';
|
||||
return 'text-red-800';
|
||||
};
|
||||
|
||||
// Helper to determine temperature color
|
||||
const getTemperatureColor = (temp: number): string => {
|
||||
if (temp < 50) return 'text-green-600';
|
||||
if (temp < 80) return 'text-yellow-600';
|
||||
return 'text-red-600';
|
||||
};
|
||||
|
||||
if (loading) {
|
||||
return (
|
||||
<div className="flex justify-center items-center h-64">
|
||||
<div className="animate-spin rounded-full h-12 w-12 border-t-2 border-b-2 border-blue-500"></div>
|
||||
</div>
|
||||
);
|
||||
}
|
||||
|
||||
if (error) {
|
||||
return (
|
||||
<div className="bg-red-100 border border-red-400 text-red-700 px-4 py-3 rounded relative" role="alert">
|
||||
<strong className="font-bold">Error!</strong>
|
||||
<span className="block sm:inline"> {error}</span>
|
||||
</div>
|
||||
);
|
||||
}
|
||||
|
||||
if (!gpuData) {
|
||||
return (
|
||||
<div className="bg-yellow-100 border border-yellow-400 text-yellow-700 px-4 py-3 rounded relative" role="alert">
|
||||
<span className="block sm:inline">No GPU data available.</span>
|
||||
</div>
|
||||
);
|
||||
}
|
||||
|
||||
if (!gpuData.hasNvidiaSmi) {
|
||||
return (
|
||||
<div className="bg-yellow-100 border border-yellow-400 text-yellow-700 px-4 py-3 rounded relative" role="alert">
|
||||
<strong className="font-bold">No NVIDIA GPUs detected!</strong>
|
||||
<span className="block sm:inline"> nvidia-smi is not available on this system.</span>
|
||||
{gpuData.error && <p className="mt-2 text-sm">{gpuData.error}</p>}
|
||||
</div>
|
||||
);
|
||||
}
|
||||
|
||||
if (gpuData.gpus.length === 0) {
|
||||
return (
|
||||
<div className="bg-yellow-100 border border-yellow-400 text-yellow-700 px-4 py-3 rounded relative" role="alert">
|
||||
<span className="block sm:inline">No GPUs found, but nvidia-smi is available.</span>
|
||||
</div>
|
||||
);
|
||||
}
|
||||
|
||||
return (
|
||||
<div className="container mx-auto py-2">
|
||||
<div className="flex justify-between items-center mb-2">
|
||||
<h1 className="text-lg font-bold">GPU Monitor</h1>
|
||||
<div className="text-xs text-gray-500">Last updated: {lastUpdated?.toLocaleTimeString()}</div>
|
||||
</div>
|
||||
|
||||
<div className="grid grid-cols-1 md:grid-cols-2 lg:grid-cols-3 gap-3">
|
||||
{gpuData.gpus.map(gpu => (
|
||||
<div
|
||||
key={gpu.index}
|
||||
className="bg-gray-900 rounded-lg shadow-lg overflow-hidden hover:shadow-xl transition-shadow duration-300 px-2 py-2"
|
||||
>
|
||||
<div className="bg-gray-800 text-white px-2 py-1 flex justify-between items-center">
|
||||
<h2 className="font-bold text-sm truncate">{gpu.name}</h2>
|
||||
<span className="text-xs bg-gray-700 rounded px-1 py-0.5">GPU #{gpu.index}</span>
|
||||
</div>
|
||||
|
||||
<div className="p-2">
|
||||
<div className="mb-2 flex items-center">
|
||||
<p className="text-xs text-gray-500 mr-1">Temperature:</p>
|
||||
<p className={`text-sm font-bold ${getTemperatureColor(gpu.temperature)}`}>{gpu.temperature}°C</p>
|
||||
</div>
|
||||
|
||||
<div className="">
|
||||
<p className="text-xs text-gray-600 mb-0.5">GPU Utilization</p>
|
||||
<div className="w-full bg-gray-500 rounded-full h-1.5">
|
||||
<div
|
||||
className={`h-1.5 rounded-full ${gpu.utilization.gpu < 30 ? 'bg-green-500' : gpu.utilization.gpu < 70 ? 'bg-yellow-500' : 'bg-red-500'}`}
|
||||
style={{ width: `${gpu.utilization.gpu}%` }}
|
||||
></div>
|
||||
</div>
|
||||
<p className="text-right text-xs mt-0.5">{gpu.utilization.gpu}%</p>
|
||||
</div>
|
||||
|
||||
<div className="mb-2">
|
||||
<p className="text-xs text-gray-600 mb-0.5">Memory Utilization</p>
|
||||
<div className="w-full bg-gray-500 rounded-full h-1.5">
|
||||
<div
|
||||
className="h-1.5 rounded-full bg-blue-500"
|
||||
style={{ width: `${(gpu.memory.used / gpu.memory.total) * 100}%` }}
|
||||
></div>
|
||||
</div>
|
||||
<div className="flex justify-between text-xs mt-0.5">
|
||||
<span>
|
||||
{formatMemory(gpu.memory.used)} / {formatMemory(gpu.memory.total)}
|
||||
</span>
|
||||
<span>{((gpu.memory.used / gpu.memory.total) * 100).toFixed(1)}%</span>
|
||||
</div>
|
||||
</div>
|
||||
|
||||
<div className="grid grid-cols-2 gap-2 mb-2">
|
||||
<div>
|
||||
<p className="text-xs text-gray-500 mb-0.5">Power</p>
|
||||
<p className="text-sm font-medium">
|
||||
{gpu.power.draw.toFixed(1)}W / {gpu.power.limit.toFixed(1)}W
|
||||
</p>
|
||||
</div>
|
||||
<div>
|
||||
<p className="text-xs text-gray-500 mb-0.5">Memory Clock</p>
|
||||
<p className="text-sm font-medium">{gpu.clocks.memory} MHz</p>
|
||||
</div>
|
||||
</div>
|
||||
|
||||
<div className="mt-1 pt-1 border-t border-gray-600 grid grid-cols-2 gap-2">
|
||||
<div>
|
||||
<p className="text-xs text-gray-500 mb-0.5">Graphics Clock</p>
|
||||
<p className="text-sm font-medium">{gpu.clocks.graphics} MHz</p>
|
||||
</div>
|
||||
<div className="">
|
||||
<p className="text-xs text-gray-500 mb-0.5">Driver Version</p>
|
||||
<p className="text-sm font-medium">{gpu.driverVersion}</p>
|
||||
</div>
|
||||
</div>
|
||||
</div>
|
||||
</div>
|
||||
))}
|
||||
</div>
|
||||
</div>
|
||||
);
|
||||
};
|
||||
|
||||
export default GpuMonitor;
|
||||
150
ui/src/components/formInputs.tsx
Normal file
150
ui/src/components/formInputs.tsx
Normal file
@@ -0,0 +1,150 @@
|
||||
import React from 'react';
|
||||
import classNames from 'classnames';
|
||||
|
||||
const labelClasses = 'block text-sm mb-2 text-gray-300';
|
||||
const inputClasses =
|
||||
'w-full px-4 py-2 bg-gray-800 border border-gray-700 rounded-lg focus:ring-2 focus:ring-gray-600 focus:border-transparent';
|
||||
|
||||
export interface InputProps {
|
||||
label?: string;
|
||||
className?: string;
|
||||
placeholder?: string;
|
||||
required?: boolean;
|
||||
}
|
||||
|
||||
export interface TextInputProps extends InputProps {
|
||||
value: string;
|
||||
onChange: (value: string) => void;
|
||||
type?: 'text' | 'password';
|
||||
}
|
||||
|
||||
export const TextInput = (props: TextInputProps) => {
|
||||
const { label, value, onChange, placeholder, required } = props;
|
||||
return (
|
||||
<div className={classNames(props.className)}>
|
||||
{label && <label className={labelClasses}>{label}</label>}
|
||||
<input
|
||||
type={props.type || 'text'}
|
||||
value={value}
|
||||
onChange={e => onChange(e.target.value)}
|
||||
className={inputClasses}
|
||||
placeholder={placeholder}
|
||||
required={required}
|
||||
/>
|
||||
</div>
|
||||
);
|
||||
};
|
||||
|
||||
export interface NumberInputProps extends InputProps {
|
||||
value: number;
|
||||
onChange: (value: number) => void;
|
||||
min?: number;
|
||||
max?: number;
|
||||
}
|
||||
|
||||
export const NumberInput = (props: NumberInputProps) => {
|
||||
const { label, value, onChange, placeholder, required, min, max } = props;
|
||||
return (
|
||||
<div className={classNames(props.className)}>
|
||||
{label && <label className={labelClasses}>{label}</label>}
|
||||
<input
|
||||
type="number"
|
||||
value={value}
|
||||
onChange={(e) => {
|
||||
let value = Number(e.target.value);
|
||||
if (isNaN(value)) value = 0;
|
||||
if (min !== undefined && value < min) value = min;
|
||||
if (max !== undefined && value > max) value = max;
|
||||
onChange(value);
|
||||
}}
|
||||
className={inputClasses}
|
||||
placeholder={placeholder}
|
||||
required={required}
|
||||
min={min}
|
||||
max={max}
|
||||
/>
|
||||
</div>
|
||||
);
|
||||
};
|
||||
|
||||
export interface SelectInputProps extends InputProps {
|
||||
value: string;
|
||||
onChange: (value: string) => void;
|
||||
options: { value: string; label: string }[];
|
||||
}
|
||||
|
||||
export const SelectInput = (props: SelectInputProps) => {
|
||||
const { label, value, onChange, options } = props;
|
||||
return (
|
||||
<div className={classNames(props.className)}>
|
||||
{label && <label className={labelClasses}>{label}</label>}
|
||||
<select value={value} onChange={e => onChange(e.target.value)} className={inputClasses}>
|
||||
{options.map(option => (
|
||||
<option key={option.value} value={option.value}>
|
||||
{option.label}
|
||||
</option>
|
||||
))}
|
||||
</select>
|
||||
</div>
|
||||
);
|
||||
};
|
||||
|
||||
export interface CheckboxProps {
|
||||
label?: string;
|
||||
checked: boolean;
|
||||
onChange: (checked: boolean) => void;
|
||||
className?: string;
|
||||
required?: boolean;
|
||||
disabled?: boolean;
|
||||
}
|
||||
|
||||
export const Checkbox = (props: CheckboxProps) => {
|
||||
const { label, checked, onChange, required, disabled } = props;
|
||||
const id = React.useId(); // Generate unique ID for label association
|
||||
|
||||
return (
|
||||
<div className={classNames('flex items-center', props.className)}>
|
||||
<div className="relative flex items-start">
|
||||
<div className="flex items-center h-5">
|
||||
<input
|
||||
id={id}
|
||||
type="checkbox"
|
||||
checked={checked}
|
||||
onChange={e => onChange(e.target.checked)}
|
||||
className="w-4 h-4 rounded border-gray-700 bg-gray-800 text-indigo-600 focus:ring-2 focus:ring-indigo-500 focus:ring-offset-1 focus:ring-offset-gray-900 cursor-pointer transition-colors"
|
||||
required={required}
|
||||
disabled={disabled}
|
||||
/>
|
||||
</div>
|
||||
{label && (
|
||||
<div className="ml-3 text-sm">
|
||||
<label
|
||||
htmlFor={id}
|
||||
className={classNames(
|
||||
'font-medium cursor-pointer select-none',
|
||||
disabled ? 'text-gray-500' : 'text-gray-300',
|
||||
)}
|
||||
>
|
||||
{label}
|
||||
</label>
|
||||
</div>
|
||||
)}
|
||||
</div>
|
||||
</div>
|
||||
);
|
||||
};
|
||||
|
||||
interface FormGroupProps {
|
||||
label?: string;
|
||||
className?: string;
|
||||
children: React.ReactNode;
|
||||
}
|
||||
|
||||
export const FormGroup: React.FC<FormGroupProps> = ({ label, className, children }) => {
|
||||
return (
|
||||
<div className={classNames(className)}>
|
||||
{label && <label className={labelClasses}>{label}</label>}
|
||||
<div className="px-4 space-y-2">{children}</div>
|
||||
</div>
|
||||
);
|
||||
};
|
||||
3
ui/src/paths.ts
Normal file
3
ui/src/paths.ts
Normal file
@@ -0,0 +1,3 @@
|
||||
import path from 'path';
|
||||
export const TOOLKIT_ROOT = path.resolve('@', '..', '..');
|
||||
export const defaultTrainFolder = path.join(TOOLKIT_ROOT, 'output');
|
||||
149
ui/src/types.ts
Normal file
149
ui/src/types.ts
Normal file
@@ -0,0 +1,149 @@
|
||||
/**
|
||||
* GPU API response
|
||||
*/
|
||||
|
||||
export interface GpuUtilization {
|
||||
gpu: number;
|
||||
memory: number;
|
||||
}
|
||||
|
||||
export interface GpuMemory {
|
||||
total: number;
|
||||
free: number;
|
||||
used: number;
|
||||
}
|
||||
|
||||
export interface GpuPower {
|
||||
draw: number;
|
||||
limit: number;
|
||||
}
|
||||
|
||||
export interface GpuClocks {
|
||||
graphics: number;
|
||||
memory: number;
|
||||
}
|
||||
|
||||
export interface GpuInfo {
|
||||
index: number;
|
||||
name: string;
|
||||
driverVersion: string;
|
||||
temperature: number;
|
||||
utilization: GpuUtilization;
|
||||
memory: GpuMemory;
|
||||
power: GpuPower;
|
||||
clocks: GpuClocks;
|
||||
}
|
||||
|
||||
export interface GPUApiResponse {
|
||||
hasNvidiaSmi: boolean;
|
||||
gpus: GpuInfo[];
|
||||
error?: string;
|
||||
}
|
||||
|
||||
/**
|
||||
* Training configuration
|
||||
*/
|
||||
|
||||
export interface NetworkConfig {
|
||||
type: 'lora';
|
||||
linear: number;
|
||||
linear_alpha: number;
|
||||
}
|
||||
|
||||
export interface SaveConfig {
|
||||
dtype: string;
|
||||
save_every: number;
|
||||
max_step_saves_to_keep: number;
|
||||
push_to_hub: boolean;
|
||||
}
|
||||
|
||||
export interface DatasetConfig {
|
||||
folder_path: string;
|
||||
mask_path: string | null;
|
||||
mask_min_value: number;
|
||||
default_caption: string;
|
||||
caption_ext: string;
|
||||
caption_dropout_rate: number;
|
||||
shuffle_tokens?: boolean;
|
||||
is_reg: boolean;
|
||||
network_weight: number;
|
||||
cache_latents_to_disk?: boolean;
|
||||
resolution: number[];
|
||||
}
|
||||
|
||||
export interface EMAConfig {
|
||||
use_ema: boolean;
|
||||
ema_decay: number;
|
||||
}
|
||||
|
||||
export interface TrainConfig {
|
||||
batch_size: number;
|
||||
bypass_guidance_embedding?: boolean;
|
||||
steps: number;
|
||||
gradient_accumulation: number;
|
||||
train_unet: boolean;
|
||||
train_text_encoder: boolean;
|
||||
gradient_checkpointing: boolean;
|
||||
noise_scheduler: string;
|
||||
optimizer: string;
|
||||
lr: number;
|
||||
ema_config?: EMAConfig;
|
||||
dtype: string;
|
||||
optimizer_params: {
|
||||
weight_decay: number;
|
||||
};
|
||||
}
|
||||
|
||||
export interface QuantizeKwargsConfig {
|
||||
exclude: string[];
|
||||
}
|
||||
|
||||
export interface ModelConfig {
|
||||
name_or_path: string;
|
||||
is_flux?: boolean;
|
||||
is_lumina2?: boolean;
|
||||
quantize: boolean;
|
||||
quantize_te: boolean;
|
||||
quantize_kwargs?: QuantizeKwargsConfig;
|
||||
}
|
||||
|
||||
export interface SampleConfig {
|
||||
sampler: string;
|
||||
sample_every: number;
|
||||
width: number;
|
||||
height: number;
|
||||
prompts: string[];
|
||||
neg: string;
|
||||
seed: number;
|
||||
walk_seed: boolean;
|
||||
guidance_scale: number;
|
||||
sample_steps: number;
|
||||
}
|
||||
|
||||
export interface ProcessConfig {
|
||||
type: 'sd_trainer';
|
||||
training_folder: string;
|
||||
device: string;
|
||||
network?: NetworkConfig;
|
||||
save: SaveConfig;
|
||||
datasets: DatasetConfig[];
|
||||
train: TrainConfig;
|
||||
model: ModelConfig;
|
||||
sample: SampleConfig;
|
||||
}
|
||||
|
||||
export interface ConfigObject {
|
||||
name: string;
|
||||
process: ProcessConfig[];
|
||||
}
|
||||
|
||||
export interface MetaConfig {
|
||||
name: string;
|
||||
version: string;
|
||||
}
|
||||
|
||||
export interface JobConfig {
|
||||
job: string;
|
||||
config: ConfigObject;
|
||||
meta: MetaConfig;
|
||||
}
|
||||
4
ui/src/utils/basic.ts
Normal file
4
ui/src/utils/basic.ts
Normal file
@@ -0,0 +1,4 @@
|
||||
export const objectCopy = <T>(obj: T): T => {
|
||||
return JSON.parse(JSON.stringify(obj)) as T;
|
||||
};
|
||||
|
||||
84
ui/src/utils/hooks.tsx
Normal file
84
ui/src/utils/hooks.tsx
Normal file
@@ -0,0 +1,84 @@
|
||||
import React from 'react';
|
||||
|
||||
/**
|
||||
* Updates a deeply nested value in an object using a string path
|
||||
* @param obj The object to update
|
||||
* @param value The new value to set
|
||||
* @param path String path to the property (e.g. 'config.process[0].model.name_or_path')
|
||||
* @returns A new object with the updated value
|
||||
*/
|
||||
export function setNestedValue<T, V>(obj: T, value: V, path?: string): T {
|
||||
// Create a copy of the original object to maintain immutability
|
||||
const result = { ...obj };
|
||||
|
||||
// if path is not provided, be root path
|
||||
if (!path) {
|
||||
path = '';
|
||||
}
|
||||
|
||||
// Split the path into segments
|
||||
const pathArray = path.split('.').flatMap(segment => {
|
||||
// Handle array notation like 'process[0]'
|
||||
const arrayMatch = segment.match(/^([^\[]+)(\[\d+\])+/);
|
||||
if (arrayMatch) {
|
||||
const propName = arrayMatch[1];
|
||||
const indices = segment
|
||||
.substring(propName.length)
|
||||
.match(/\[(\d+)\]/g)
|
||||
?.map(idx => parseInt(idx.substring(1, idx.length - 1)));
|
||||
|
||||
// Return property name followed by array indices
|
||||
return [propName, ...(indices || [])];
|
||||
}
|
||||
return segment;
|
||||
});
|
||||
|
||||
// Navigate to the target location
|
||||
let current: any = result;
|
||||
for (let i = 0; i < pathArray.length - 1; i++) {
|
||||
const key = pathArray[i];
|
||||
|
||||
// If current key is a number, treat it as an array index
|
||||
if (typeof key === 'number') {
|
||||
if (!Array.isArray(current)) {
|
||||
throw new Error(`Cannot access index ${key} of non-array`);
|
||||
}
|
||||
// Create a copy of the array to maintain immutability
|
||||
current = [...current];
|
||||
} else {
|
||||
// For object properties, create a new object if it doesn't exist
|
||||
if (current[key] === undefined) {
|
||||
// Check if the next key is a number, if so create an array, otherwise an object
|
||||
const nextKey = pathArray[i + 1];
|
||||
current[key] = typeof nextKey === 'number' ? [] : {};
|
||||
} else {
|
||||
// Create a shallow copy to maintain immutability
|
||||
current[key] = Array.isArray(current[key]) ? [...current[key]] : { ...current[key] };
|
||||
}
|
||||
}
|
||||
|
||||
// Move to the next level
|
||||
current = current[key];
|
||||
}
|
||||
|
||||
// Set the value at the final path segment
|
||||
const finalKey = pathArray[pathArray.length - 1];
|
||||
current[finalKey] = value;
|
||||
|
||||
return result;
|
||||
}
|
||||
|
||||
/**
|
||||
* Custom hook for managing a complex state object with string path updates
|
||||
* @param initialState The initial state object
|
||||
* @returns [state, setValue] tuple
|
||||
*/
|
||||
export function useNestedState<T>(initialState: T): [T, (value: any, path?: string) => void] {
|
||||
const [state, setState] = React.useState<T>(initialState);
|
||||
|
||||
const setValue = React.useCallback((value: any, path?: string) => {
|
||||
setState(prevState => setNestedValue(prevState, value, path));
|
||||
}, []);
|
||||
|
||||
return [state, setValue];
|
||||
}
|
||||
Reference in New Issue
Block a user