From cef7d9e594e7570561fbcba6bac9adeb5c49830c Mon Sep 17 00:00:00 2001
From: Jaret Burkett
Date: Wed, 19 Feb 2025 07:52:24 -0700
Subject: [PATCH] Config ui section is coming along

---
 ui/package-lock.json             |   6 +
 ui/package.json                  |   1 +
 ui/prisma/schema.prisma          |   3 +-
 ui/src/app/api/gpu/route.ts      | 102 ++++++
 ui/src/app/api/settings/route.ts |  11 +-
 ui/src/app/api/training/route.ts |   8 +-
 ui/src/app/dashboard/page.tsx    |  13 +-
 ui/src/app/train/jobConfig.ts    |  95 ++++++
 ui/src/app/train/options.ts      |  64 ++--
 ui/src/app/train/page.tsx        | 540 ++++++++++++++++++++++++++-----
 ui/src/components/Card.tsx       |  15 +
 ui/src/components/GPUMonitor.tsx | 195 +++++++++++
 ui/src/components/formInputs.tsx | 150 +++++++++
 ui/src/paths.ts                  |   3 +
 ui/src/types.ts                  | 149 +++++++++
 ui/src/utils/basic.ts            |   4 +
 ui/src/utils/hooks.tsx           |  84 +++++
 17 files changed, 1323 insertions(+), 120 deletions(-)
 create mode 100644 ui/src/app/api/gpu/route.ts
 create mode 100644 ui/src/app/train/jobConfig.ts
 create mode 100644 ui/src/components/Card.tsx
 create mode 100644 ui/src/components/GPUMonitor.tsx
 create mode 100644 ui/src/components/formInputs.tsx
 create mode 100644 ui/src/paths.ts
 create mode 100644 ui/src/types.ts
 create mode 100644 ui/src/utils/basic.ts
 create mode 100644 ui/src/utils/hooks.tsx

diff --git a/ui/package-lock.json b/ui/package-lock.json
index 32ffebd3..c8965cdb 100644
--- a/ui/package-lock.json
+++ b/ui/package-lock.json
@@ -9,6 +9,7 @@
       "version": "0.1.0",
       "dependencies": {
         "@prisma/client": "^6.3.1",
+        "classnames": "^2.5.1",
         "lucide-react": "^0.475.0",
         "next": "15.1.7",
         "prisma": "^6.3.1",
@@ -1182,6 +1183,11 @@
         "node": ">=10"
       }
     },
+    "node_modules/classnames": {
+      "version": "2.5.1",
+      "resolved": "https://registry.npmjs.org/classnames/-/classnames-2.5.1.tgz",
+      "integrity": "sha512-saHYOzhIQs6wy2sVxTM6bUDsQO4F50V9RQ22qBpEdCW+I+/Wmke2HOl6lS6dTpdxVhb88/I6+Hs+438c3lfUow=="
+    },
     "node_modules/clean-stack": {
       "version": "2.2.0",
       "resolved": "https://registry.npmjs.org/clean-stack/-/clean-stack-2.2.0.tgz",
diff --git a/ui/package.json b/ui/package.json
index a4d551ca..81d9c913 100644
--- a/ui/package.json
+++ b/ui/package.json
@@ -11,6 +11,7 @@
   },
   "dependencies": {
     "@prisma/client": "^6.3.1",
+    "classnames": "^2.5.1",
     "lucide-react": "^0.475.0",
     "next": "15.1.7",
     "prisma": "^6.3.1",
diff --git a/ui/prisma/schema.prisma b/ui/prisma/schema.prisma
index 66bf99ec..4f3380e3 100644
--- a/ui/prisma/schema.prisma
+++ b/ui/prisma/schema.prisma
@@ -17,7 +17,8 @@ model Settings {
 model Training {
   id         String   @id @default(uuid())
   name       String
-  run_data   String // JSON string
+  gpu_id     Int
+  job_config String // JSON string
   created_at DateTime @default(now())
   updated_at DateTime @updatedAt
 }
\ No newline at end of file
diff --git a/ui/src/app/api/gpu/route.ts b/ui/src/app/api/gpu/route.ts
new file mode 100644
index 00000000..93a084d2
--- /dev/null
+++ b/ui/src/app/api/gpu/route.ts
@@ -0,0 +1,102 @@
+import { NextResponse } from 'next/server';
+import { exec } from 'child_process';
+import { promisify } from 'util';
+
+const execAsync = promisify(exec);
+
+export async function GET() {
+  try {
+    // Check if nvidia-smi is available
+    const hasNvidiaSmi = await checkNvidiaSmi();
+
+    if (!hasNvidiaSmi) {
+      return NextResponse.json({
+        hasNvidiaSmi: false,
+        gpus: [],
+        error: 'nvidia-smi not found or not accessible',
+      });
+    }
+
+    // Get GPU stats
+    const gpuStats = await getGpuStats();
+
+    return NextResponse.json({
+      hasNvidiaSmi: true,
+      gpus: gpuStats,
+    });
+  } catch (error) {
+    console.error('Error fetching NVIDIA GPU stats:', error);
+
+    return NextResponse.json(
+      {
+        hasNvidiaSmi: false,
+        gpus: [],
+        error: `Failed to fetch GPU stats: ${error instanceof Error ? error.message : String(error)}`,
+      },
+      { status: 500 },
+    );
+  }
+}
+
+async function checkNvidiaSmi(): Promise<boolean> {
+  try {
+    await execAsync('which nvidia-smi');
+    return true;
+  } catch (error) {
+    return false;
+  }
+}
+
+async function getGpuStats() {
+  // Get detailed GPU information in CSV format
+  const { stdout } = await execAsync(
+    'nvidia-smi --query-gpu=index,name,driver_version,temperature.gpu,utilization.gpu,utilization.memory,memory.total,memory.free,memory.used,power.draw,power.limit,clocks.current.graphics,clocks.current.memory --format=csv,noheader,nounits',
+  );
+
+  // Parse CSV output
+  const gpus = stdout
+    .trim()
+    .split('\n')
+    .map(line => {
+      const [
+        index,
+        name,
+        driverVersion,
+        temperature,
+        gpuUtil,
+        memoryUtil,
+        memoryTotal,
+        memoryFree,
+        memoryUsed,
+        powerDraw,
+        powerLimit,
+        clockGraphics,
+        clockMemory,
+      ] = line.split(', ').map(item => item.trim());
+
+      return {
+        index: parseInt(index),
+        name,
+        driverVersion,
+        temperature: parseInt(temperature),
+        utilization: {
+          gpu: parseInt(gpuUtil),
+          memory: parseInt(memoryUtil),
+        },
+        memory: {
+          total: parseInt(memoryTotal),
+          free: parseInt(memoryFree),
+          used: parseInt(memoryUsed),
+        },
+        power: {
+          draw: parseFloat(powerDraw),
+          limit: parseFloat(powerLimit),
+        },
+        clocks: {
+          graphics: parseInt(clockGraphics),
+          memory: parseInt(clockMemory),
+        },
+      };
+    });
+
+  return gpus;
+}
diff --git a/ui/src/app/api/settings/route.ts b/ui/src/app/api/settings/route.ts
index 3346f0a9..8062afde 100644
--- a/ui/src/app/api/settings/route.ts
+++ b/ui/src/app/api/settings/route.ts
@@ -1,12 +1,21 @@
 import { NextResponse } from 'next/server';
 import { PrismaClient } from '@prisma/client';
+import { defaultTrainFolder } from '@/paths';
 
 const prisma = new PrismaClient();
 
 export async function GET() {
   try {
     const settings = await prisma.settings.findMany();
-    return NextResponse.json(settings.reduce((acc, curr) => ({ ...acc, [curr.key]: curr.value }), {}));
+    const settingsObject = settings.reduce((acc: any, setting) => {
+      acc[setting.key] = setting.value;
+      return acc;
+    }, {});
+    // if TRAINING_FOLDER is not set, use default
+    if (!settingsObject.TRAINING_FOLDER || settingsObject.TRAINING_FOLDER === '') {
+      settingsObject.TRAINING_FOLDER = defaultTrainFolder;
+    }
+    return NextResponse.json(settingsObject);
   } catch (error) {
     return NextResponse.json({ error: 'Failed to fetch settings' }, { status: 500 });
   }
diff --git a/ui/src/app/api/training/route.ts b/ui/src/app/api/training/route.ts
index 1020bd24..64d3d73a 100644
--- a/ui/src/app/api/training/route.ts
+++ b/ui/src/app/api/training/route.ts
@@ -27,7 +27,7 @@ export async function GET(request: Request) {
 export async function POST(request: Request) {
   try {
     const body = await request.json();
-    const { id, name, run_data } = body;
+    const { id, name, job_config, gpu_id } = body;
 
     if (id) {
       // Update existing training
@@ -35,7 +35,8 @@ export async function POST(request: Request) {
         where: { id },
         data: {
           name,
-          run_data: JSON.stringify(run_data),
+          gpu_id,
+          job_config: JSON.stringify(job_config),
         },
       });
       return NextResponse.json(training);
@@ -44,7 +45,8 @@ export async function POST(request: Request) {
       const training = await prisma.training.create({
         data: {
           name,
-          run_data: JSON.stringify(run_data),
+          gpu_id,
+          job_config: JSON.stringify(job_config),
         },
       });
       return NextResponse.json(training);
diff --git a/ui/src/app/dashboard/page.tsx b/ui/src/app/dashboard/page.tsx
index 4c9eb54d..fd05095f 100644
--- a/ui/src/app/dashboard/page.tsx
+++ b/ui/src/app/dashboard/page.tsx
@@ -1,15 +1,12 @@
+'use client';
+
+import GpuMonitor from '@/components/GPUMonitor';
+
 export default function Dashboard() {
   return (

Dashboard

-
- {[1, 2, 3].map(i => ( -
-

Card {i}

-

Example dashboard card content

-
- ))} -
+      <GpuMonitor />
   );
 }
diff --git a/ui/src/app/train/jobConfig.ts b/ui/src/app/train/jobConfig.ts
new file mode 100644
index 00000000..fc97dcdf
--- /dev/null
+++ b/ui/src/app/train/jobConfig.ts
@@ -0,0 +1,95 @@
+import { JobConfig, DatasetConfig } from '@/types';
+
+export const defaultDatasetConfig: DatasetConfig = {
+  folder_path: '/path/to/images/folder',
+  mask_path: null,
+  mask_min_value: 0.1,
+  default_caption: '',
+  caption_ext: 'txt',
+  caption_dropout_rate: 0.05,
+  cache_latents_to_disk: false,
+  is_reg: false,
+  network_weight: 1,
+  resolution: [512, 768, 1024],
+};
+
+export const defaultJobConfig: JobConfig = {
+  job: 'extension',
+  config: {
+    name: 'my_first_flex_lora_v1',
+    process: [
+      {
+        type: 'sd_trainer',
+        training_folder: 'output',
+        device: 'cuda:0',
+        network: {
+          type: 'lora',
+          linear: 16,
+          linear_alpha: 16,
+        },
+        save: {
+          dtype: 'bf16',
+          save_every: 250,
+          max_step_saves_to_keep: 4,
+          push_to_hub: false,
+        },
+        datasets: [
+          defaultDatasetConfig
+        ],
+        train: {
+          batch_size: 1,
+          bypass_guidance_embedding: true,
+          steps: 2000,
+          gradient_accumulation: 1,
+          train_unet: true,
+          train_text_encoder: false,
+          gradient_checkpointing: true,
+          noise_scheduler: 'flowmatch',
+          optimizer: 'adamw8bit',
+          optimizer_params: {
+            weight_decay: 1e-4
+          },
+          lr: 0.0001,
+          ema_config: {
+            use_ema: true,
+            ema_decay: 0.99,
+          },
+          dtype: 'bf16',
+        },
+        model: {
+          name_or_path: 'ostris/Flex.1-alpha',
+          is_flux: true,
+          quantize: true,
+          quantize_te: true
+        },
+        sample: {
+          sampler: 'flowmatch',
+          sample_every: 250,
+          width: 1024,
+          height: 1024,
+          prompts: [
+            'woman with red hair, playing chess at the park, bomb going off in the background',
+            'a woman holding a coffee cup, in a beanie, sitting at a cafe',
+            'a horse is a DJ at a night club, fish eye lens, smoke machine, lazer lights, holding a martini',
+            'a man showing off his cool new t shirt at the beach, a shark is jumping out of the water in the background',
+            'a bear building a log cabin in the snow covered mountains',
+            'woman playing the guitar, on stage, singing a song, laser lights, punk rocker',
+            'hipster man with a beard, building a chair, in a wood shop',
+            'photo of a man, white background, medium shot, modeling clothing, studio lighting, white backdrop',
+            "a man holding a sign that says, 'this is a sign'",
+            'a bulldog, in a post apocalyptic world, with a shotgun, in a leather jacket, in a desert, with a motorcycle',
+          ],
+          neg: '',
+          seed: 42,
+          walk_seed: true,
+          guidance_scale: 4,
+          sample_steps: 25,
+        },
+      },
+    ],
+  },
+  meta: {
+    name: '[name]',
+    version: '1.0',
+  },
+};
diff --git a/ui/src/app/train/options.ts b/ui/src/app/train/options.ts
index a32aa1f0..f856f21a 100644
--- a/ui/src/app/train/options.ts
+++ b/ui/src/app/train/options.ts
@@ -1,37 +1,41 @@
 export interface Model {
-    name_or_path: string;
-    model_kwargs?: Record<string, any>;
-    train_kwargs?: Record<string, any>;
+  name_or_path: string;
+  defaults?: { [key: string]: any };
 }
 
 export interface Option {
-    model: Model[];
+  model: Model[];
 }
 
-
 export const options = {
-    model: [
-        {
-            name_or_path: "ostris/Flex.1-alpha",
-            model_kwargs: {
-                "is_flux": true
-            },
-            train_kwargs: {
-                "bypass_guidance_embedding": true
-            }
-        },
-        {
-            name_or_path: "black-forest-labs/FLUX.1-dev",
-            model_kwargs: {
-                "is_flux": true
-            },
-        },
-        {
-            name_or_path: "Alpha-VLLM/Lumina-Image-2.0",
-            model_kwargs: {
-                "is_lumina2": true
-            },
-        },
-    ]
-
-} as Option;
\ No newline at end of file
+  model: [
+    {
+      name_or_path: 'ostris/Flex.1-alpha',
+      defaults: {
+        // default updates when [selected, unselected] in the UI
+        'config.process[0].model.quantize': [true, false],
+        'config.process[0].model.quantize_te': [true, false],
+        'config.process[0].model.is_flux': [true, false],
+        'config.process[0].train.bypass_guidance_embedding': [true, false],
+      },
+    },
+    {
+      name_or_path: 'black-forest-labs/FLUX.1-dev',
+      defaults: {
+        // default updates when [selected, unselected] in the UI
+        'config.process[0].model.quantize': [true, false],
+        'config.process[0].model.quantize_te': [true, false],
+        'config.process[0].model.is_flux': [true, false],
+      },
+    },
+    {
+      name_or_path: 'Alpha-VLLM/Lumina-Image-2.0',
+      defaults: {
+        // default updates when [selected, unselected] in the UI
+        'config.process[0].model.quantize': [false, false],
+        'config.process[0].model.quantize_te': [true, false],
+        'config.process[0].model.is_lumina2': [true, false],
+      },
+    },
+  ],
+} as Option;
diff --git a/ui/src/app/train/page.tsx b/ui/src/app/train/page.tsx
index a3eaee89..7298b1f1 100644
--- a/ui/src/app/train/page.tsx
+++ b/ui/src/app/train/page.tsx
@@ -1,34 +1,26 @@
 'use client';
+// todo update training folder from settings
 
-import { useEffect, useState } from 'react';
+import { useEffect, useMemo, useState } from 'react';
 import { useSearchParams, useRouter } from 'next/navigation';
 import { options } from './options';
-
-interface TrainingData {
-  modelConfig: {
-    name_or_path: string;
-    steps: number;
-    batchSize: number;
-    learningRate: number;
-  };
-}
-
-const defaultTrainingData: TrainingData = {
-  modelConfig: {
-    name_or_path: 'ostris/Flex.1-alpha',
-    steps: 100,
-    batchSize: 32,
-    learningRate: 0.001,
-  },
-};
+import { GPUApiResponse } from '@/types';
+import { defaultJobConfig, defaultDatasetConfig } from './jobConfig';
+import { JobConfig } from '@/types';
+import { objectCopy } from '@/utils/basic';
+import { useNestedState } from '@/utils/hooks';
+import { TextInput, SelectInput, Checkbox, FormGroup, NumberInput } from '@/components/formInputs';
+import Card from '@/components/Card';
+import { Trash, X } from 'lucide-react';
 
 export default function TrainingForm() {
   const router = useRouter();
   const searchParams = useSearchParams();
   const runId = searchParams.get('id');
+  const [gpuID, setGpuID] = useState<number | null>(null);
+  const [gpuList, setGpuList] = useState<number[]>([]);
 
-  const [name, setName] = useState('');
-  const [trainingData, setTrainingData] = useState(defaultTrainingData);
+  const [jobConfig, setJobConfig] = useNestedState<JobConfig>(objectCopy(defaultJobConfig));
   const [status, setStatus] = useState<'idle' | 'saving' | 'success' | 'error'>('idle');
 
   useEffect(() => {
@@ -36,13 +28,42 @@ export default function TrainingForm() {
       fetch(`/api/training?id=${runId}`)
         .then(res => res.json())
         .then(data => {
-          setName(data.name);
-          setTrainingData(JSON.parse(data.run_data));
+          setGpuID(data.gpu_id);
+          setJobConfig(JSON.parse(data.job_config));
         })
         .catch(error => console.error('Error fetching training:', error));
     }
   }, [runId]);
 
+  useEffect(() => {
+    const fetchGpuInfo = async () => {
+      try {
+        const response = await fetch('/api/gpu');
+
+        if (!response.ok) {
+          throw new Error(`HTTP error! Status: ${response.status}`);
+        }
+
+        const data: GPUApiResponse = await response.json();
+        setGpuList(data.gpus.map(gpu => gpu.index).sort());
+      } catch (err) {
+        console.log(`Failed to fetch GPU data: ${err instanceof Error ? err.message : String(err)}`);
+      } finally {
+        // setLoading(false);
+      }
+    };
+
+    fetch('/api/settings')
+      .then(res => res.json())
+      .then(data => {
+        setJobConfig(data.TRAINING_FOLDER, 'config.process[0].training_folder');
+      })
+      .catch(error => console.error('Error fetching settings:', error));
+
+    // Fetch immediately on component mount
+    fetchGpuInfo();
+  }, []);
+
   const handleSubmit = async (e: React.FormEvent) => {
     e.preventDefault();
     setStatus('saving');
@@ -55,8 +76,9 @@ export default function TrainingForm() {
         },
         body: JSON.stringify({
           id: runId,
-          name,
-          run_data: trainingData,
+          name: jobConfig.config.name,
+          gpu_id: gpuID,
+          job_config: jobConfig,
         }),
       });
 
@@ -75,64 +97,428 @@ export default function TrainingForm() {
     }
   };
 
-  const updateSection = (section: keyof TrainingData, data: any) => {
-    setTrainingData(prev => ({
-      ...prev,
-      [section]: { ...prev[section], ...data },
-    }));
-  };
-
-  const modelOptions = options.model.map(model => model.name_or_path);
-
   return (
-

{runId ? 'Edit Training Run' : 'New Training Run'}

+
+

{runId ? 'Edit Training Run' : 'New Training Run'}

-
- - setName(e.target.value)} - className="w-full px-4 py-2 bg-gray-800 border border-gray-700 rounded-lg focus:ring-2 focus:ring-gray-600 focus:border-transparent" - placeholder="Enter training name" - required - /> -
+
+ + setJobConfig(value, 'config.name')} + placeholder="Enter training name" + required + /> + setGpuID(parseInt(value))} + options={gpuList.map(gpu => ({ value: `${gpu}`, label: `GPU #${gpu}` }))} + /> + - {/* Model Configuration Section */} -
-

Model Configuration

-
-
- - -
-
- - updateSection('modelConfig', { steps: Number(e.target.value) })} - className="w-full px-4 py-2 bg-gray-800 border border-gray-700 rounded-lg" + {/* Model Configuration Section */} + + { + // see if model changed + const currentModel = options.model.find( + model => model.name_or_path === jobConfig.config.process[0].model.name_or_path, + ); + if (!currentModel || currentModel.name_or_path === value) { + // model has not changed + return; + } + // revert defaults from previous model + for (const key in currentModel.defaults) { + setJobConfig(currentModel.defaults[key][1], key); + } + // set new model + setJobConfig(value, 'config.process[0].model.name_or_path'); + // update the defaults when a model is selected + const model = options.model.find(model => model.name_or_path === value); + if (model?.defaults) { + for (const key in model.defaults) { + setJobConfig(model.defaults[key][0], key); + } + } + }} + options={options.model.map(model => ({ + value: model.name_or_path, + label: model.name_or_path, + }))} + /> + + setJobConfig(value, 'config.process[0].model.quantize')} /> + setJobConfig(value, 'config.process[0].model.quantize_te')} + /> + + + {jobConfig.config.process[0].network?.linear && ( + + { + setJobConfig(value, 'config.process[0].network.linear'); + setJobConfig(value, 'config.process[0].network.linear_alpha'); + }} + placeholder="eg. 16" + min={1} + max={1024} + required + /> + + )} + + setJobConfig(value, 'config.process[0].save.dtype')} + options={[ + { value: 'bf16', label: 'BF16' }, + { value: 'fp16', label: 'FP16' }, + { value: 'fp32', label: 'FP32' }, + ]} + /> + setJobConfig(value, 'config.process[0].save.save_every')} + placeholder="eg. 250" + min={1} + required + /> + setJobConfig(value, 'config.process[0].save.max_step_saves_to_keep')} + placeholder="eg. 4" + min={1} + required + /> + +
+
+ +
+
+ setJobConfig(value, 'config.process[0].train.batch_size')} + placeholder="eg. 4" + min={1} + required + /> + setJobConfig(value, 'config.process[0].train.gradient_accumulation')} + placeholder="eg. 1" + min={1} + required + /> + setJobConfig(value, 'config.process[0].train.steps')} + placeholder="eg. 2000" + min={1} + required + /> +
+
+ setJobConfig(value, 'config.process[0].train.optimizer')} + options={[ + { value: 'adamw8bit', label: 'AdamW8Bit' }, + { value: 'adafactor', label: 'Adafactor' }, + ]} + /> + setJobConfig(value, 'config.process[0].train.lr')} + placeholder="eg. 0.0001" + min={0} + required + /> + setJobConfig(value, 'config.process[0].train.optimizer_params.weight_decay')} + placeholder="eg. 0.0001" + min={0} + required + /> +
-
-
+ +
+
+ + <> + {jobConfig.config.process[0].datasets.map((dataset, i) => ( +
+ +

Dataset {i + 1}

+
+
+ setJobConfig(value, `config.process[0].datasets[${i}].folder_path`)} + placeholder="eg. /path/to/images/folder" + required + /> + { + let setValue: string | null = value; + if (!setValue || setValue.trim() === '') { + setValue = null; + } + setJobConfig(setValue, `config.process[0].datasets[${i}].mask_path`); + }} + placeholder="eg. /path/to/masks/folder" + /> + setJobConfig(value, `config.process[0].datasets[${i}].mask_min_value`)} + placeholder="eg. 0.1" + min={0} + max={1} + required + /> +
+
+ setJobConfig(value, `config.process[0].datasets[${i}].default_caption`)} + placeholder="eg. A photo of a cat" + /> + setJobConfig(value, `config.process[0].datasets[${i}].caption_ext`)} + placeholder="eg. txt" + required + /> + setJobConfig(value, `config.process[0].datasets[${i}].caption_dropout_rate`)} + placeholder="eg. 0.05" + min={0} + required + /> +
+
+ + + setJobConfig(value, `config.process[0].datasets[${i}].cache_latents_to_disk`) + } + /> + setJobConfig(value, `config.process[0].datasets[${i}].is_reg`)} + /> + +
+
+ + {[256, 512, 768, 1024, 1280].map(res => ( + { + const resolutions = dataset.resolution.includes(res) + ? dataset.resolution.filter(r => r !== res) + : [...dataset.resolution, res]; + setJobConfig(resolutions, `config.process[0].datasets[${i}].resolution`); + }} + /> + ))} + +
+
+
+ ))} + + +
+
+
+ +
+
+ setJobConfig(value, 'config.process[0].sample.sample_every')} + placeholder="eg. 250" + min={1} + required + /> + setJobConfig(value, 'config.process[0].sample.sampler')} + options={[{ value: 'flowmatch', label: 'FlowMatch' }]} + /> +
+
+ setJobConfig(value, 'config.process[0].sample.guidance_scale')} + placeholder="eg. 1.0" + min={0} + required + /> + setJobConfig(value, 'config.process[0].sample.sample_steps')} + placeholder="eg. 1" + className="pt-2" + min={1} + required + /> +
+
+ setJobConfig(value, 'config.process[0].sample.width')} + placeholder="eg. 1024" + min={256} + required + /> + setJobConfig(value, 'config.process[0].sample.height')} + placeholder="eg. 1024" + className="pt-2" + min={256} + required + /> +
+ +
+ setJobConfig(value, 'config.process[0].sample.seed')} + placeholder="eg. 0" + min={0} + required + /> + setJobConfig(value, 'config.process[0].sample.walk_seed')} + /> +
+
+ + {jobConfig.config.process[0].sample.prompts.map((prompt, i) => ( +
+
+ setJobConfig(value, `config.process[0].sample.prompts[${i}]`)} + placeholder="Enter prompt" + required + /> +
+
+ +
+
+ ))} + +
+
+
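
Every field in the form above writes through setJobConfig(value, 'dotted.path'), the setter returned by the useNestedState hook this patch adds in ui/src/utils/hooks.tsx (the hook itself is outside this excerpt). As a rough illustration only, a minimal sketch assuming lodash-style path parsing and copy-on-write updates, not the actual contents of hooks.tsx, such a setter could look like this:

import { useCallback, useState } from 'react';

// Turn "config.process[0].train.lr" into ['config', 'process', 0, 'train', 'lr'].
function parsePath(path: string): (string | number)[] {
  return path
    .replace(/\[(\d+)\]/g, '.$1')
    .split('.')
    .map(part => (/^\d+$/.test(part) ? Number(part) : part));
}

// Return a copy of `obj` with `value` placed at `path`, cloning only the
// objects/arrays along the way so React state stays immutable.
function setIn<T>(obj: T, path: (string | number)[], value: unknown): T {
  if (path.length === 0) return value as T;
  const [head, ...rest] = path;
  const copy: any = Array.isArray(obj) ? [...(obj as any)] : { ...(obj as any) };
  copy[head] = setIn(copy[head], rest, value);
  return copy;
}

// Sketch of the call shape used by page.tsx; the real hook may differ.
export function useNestedState<T>(initial: T) {
  const [state, setState] = useState<T>(initial);

  // setValue(next) replaces the whole object, as in setJobConfig(JSON.parse(data.job_config));
  // setValue(next, 'config.process[0].training_folder') updates a single nested field.
  const setValue = useCallback((value: any, path?: string) => {
    if (!path) {
      setState(value as T);
      return;
    }
    setState(prev => setIn(prev, parsePath(path), value));
  }, []);

  return [state, setValue] as const;
}

Both call shapes used in page.tsx, replacing the whole config after loading a run and updating one nested key such as 'config.process[0].datasets[0].folder_path', fit this signature; the same dotted paths are also what options.ts uses to apply and revert the per-model [selected, unselected] defaults tuples.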