UI bug fixes and initial Windows support

This commit is contained in:
Jaret Burkett
2025-02-24 08:15:22 -07:00
parent f0fbd8bb53
commit 093f14ac19
10 changed files with 88 additions and 39 deletions

View File

@@ -17,4 +17,3 @@ You verified that this is a bug and not a feature request or question by asking
Yes/No
## Describe the bug

View File

@@ -32,5 +32,4 @@ sentencepiece
huggingface_hub
peft
gradio
python-slugify
sqlite3
python-slugify

View File

@@ -7,7 +7,7 @@
"build": "next build",
"start": "next start --port 8675",
"lint": "next lint",
"update_db": "npx prisma generate ; npx prisma db push",
"update_db": "npx prisma generate && npx prisma db push",
"format": "prettier --write \"**/*.{js,jsx,ts,tsx,css,scss}\""
},
"dependencies": {

View File

@@ -1,13 +1,18 @@
import { NextResponse } from 'next/server';
import { exec } from 'child_process';
import { promisify } from 'util';
import os from 'os';
const execAsync = promisify(exec);
export async function GET() {
try {
// Get platform
const platform = os.platform();
const isWindows = platform === 'win32';
// Check if nvidia-smi is available
const hasNvidiaSmi = await checkNvidiaSmi();
const hasNvidiaSmi = await checkNvidiaSmi(isWindows);
if (!hasNvidiaSmi) {
return NextResponse.json({
@@ -18,7 +23,7 @@ export async function GET() {
}
// Get GPU stats
const gpuStats = await getGpuStats();
const gpuStats = await getGpuStats(isWindows);
return NextResponse.json({
hasNvidiaSmi: true,
@@ -37,20 +42,29 @@ export async function GET() {
}
}
async function checkNvidiaSmi(): Promise<boolean> {
async function checkNvidiaSmi(isWindows: boolean): Promise<boolean> {
try {
await execAsync('which nvidia-smi');
if (isWindows) {
// Check if nvidia-smi is available on Windows
// It's typically located in C:\Program Files\NVIDIA Corporation\NVSMI\nvidia-smi.exe
// but we'll just try to run it directly as it may be in PATH
await execAsync('nvidia-smi -L');
} else {
// Linux/macOS check
await execAsync('which nvidia-smi');
}
return true;
} catch (error) {
return false;
}
}
async function getGpuStats() {
// Get detailed GPU information in JSON format including fan speed
const { stdout } = await execAsync(
'nvidia-smi --query-gpu=index,name,driver_version,temperature.gpu,utilization.gpu,utilization.memory,memory.total,memory.free,memory.used,power.draw,power.limit,clocks.current.graphics,clocks.current.memory,fan.speed --format=csv,noheader,nounits',
);
async function getGpuStats(isWindows: boolean) {
// Command is the same for both platforms, but the path might be different
const command = 'nvidia-smi --query-gpu=index,name,driver_version,temperature.gpu,utilization.gpu,utilization.memory,memory.total,memory.free,memory.used,power.draw,power.limit,clocks.current.graphics,clocks.current.memory,fan.speed --format=csv,noheader,nounits';
// Execute command
const { stdout } = await execAsync(command);
// Parse CSV output
const gpus = stdout
@@ -97,7 +111,7 @@ async function getGpuStats() {
memory: parseInt(clockMemory),
},
fan: {
speed: parseInt(fanSpeed), // Fan speed as percentage
speed: parseInt(fanSpeed) || 0, // Some GPUs might not report fan speed, default to 0
},
};
});

View File

@@ -1,9 +1,10 @@
import { NextRequest, NextResponse } from 'next/server';
import { PrismaClient } from '@prisma/client';
import { TOOLKIT_ROOT, defaultTrainFolder } from '@/paths';
import { TOOLKIT_ROOT } from '@/paths';
import { spawn } from 'child_process';
import path from 'path';
import fs from 'fs';
import os from 'os';
import { getTrainingFolder, getHFToken } from '@/server/settings';
const prisma = new PrismaClient();
@@ -51,9 +52,17 @@ export async function GET(request: NextRequest, { params }: { params: { jobID: s
let pythonPath = 'python';
// use .venv or venv if it exists
if (fs.existsSync(path.join(TOOLKIT_ROOT, '.venv'))) {
pythonPath = path.join(TOOLKIT_ROOT, '.venv', 'bin', 'python');
if (os.platform() === 'win32') {
pythonPath = path.join(TOOLKIT_ROOT, '.venv', 'Scripts', 'python.exe');
} else {
pythonPath = path.join(TOOLKIT_ROOT, '.venv', 'bin', 'python');
}
} else if (fs.existsSync(path.join(TOOLKIT_ROOT, 'venv'))) {
pythonPath = path.join(TOOLKIT_ROOT, 'venv', 'bin', 'python');
if (os.platform() === 'win32') {
pythonPath = path.join(TOOLKIT_ROOT, 'venv', 'Scripts', 'python.exe');
} else {
pythonPath = path.join(TOOLKIT_ROOT, 'venv', 'bin', 'python');
}
}
const runFilePath = path.join(TOOLKIT_ROOT, 'run.py');

View File

@@ -6,6 +6,7 @@ const prisma = new PrismaClient();
export async function GET(request: Request) {
const { searchParams } = new URL(request.url);
const id = searchParams.get('id');
console.log('ID:', id);
try {
if (id) {
@@ -18,6 +19,7 @@ export async function GET(request: Request) {
const jobs = await prisma.job.findMany({
orderBy: { created_at: 'desc' },
});
console.log('Jobs:', jobs);
return NextResponse.json({ jobs: jobs });
} catch (error) {
console.error(error);

View File

@@ -1,5 +1,6 @@
export interface Model {
name_or_path: string;
dev_only?: boolean;
defaults?: { [key: string]: any };
}
@@ -37,5 +38,11 @@ export const options = {
'config.process[0].model.is_lumina2': [true, false],
},
},
{
name_or_path: 'ostris/objective-reality',
dev_only: true,
defaults: {
},
},
],
} as Option;

View File

@@ -18,6 +18,8 @@ import { TopBar, MainContent } from '@/components/layout';
import { Button } from '@headlessui/react';
import { FaChevronLeft } from 'react-icons/fa';
const isDev = process.env.NODE_ENV === 'development';
export default function TrainingForm() {
const router = useRouter();
const searchParams = useSearchParams();
@@ -42,7 +44,9 @@ export default function TrainingForm() {
for (let i = 0; i < jobConfig.config.process[0].datasets.length; i++) {
const dataset = jobConfig.config.process[0].datasets[i];
if (dataset.folder_path === defaultDatasetPath) {
setJobConfig(datasetOptions[0].value, `config.process[0].datasets[${i}].folder_path`);
if (datasetOptions.length > 0) {
setJobConfig(datasetOptions[0].value, `config.process[0].datasets[${i}].folder_path`);
}
}
}
}, [datasets, settings, isSettingsLoaded, datasetFetchStatus]);
@@ -196,23 +200,32 @@ export default function TrainingForm() {
}
}
}}
options={options.model.map(model => ({
value: model.name_or_path,
label: model.name_or_path,
}))}
options={
options.model
.map(model => {
if (model.dev_only && !isDev) {
return null;
}
return {
value: model.name_or_path,
label: model.name_or_path,
};
})
.filter(x => x) as { value: string; label: string }[]
}
/>
<FormGroup label="Quantize">
<div className='grid grid-cols-2 gap-2'>
<Checkbox
label="Transformer"
checked={jobConfig.config.process[0].model.quantize}
onChange={value => setJobConfig(value, 'config.process[0].model.quantize')}
/>
<Checkbox
label="Text Encoder"
checked={jobConfig.config.process[0].model.quantize_te}
onChange={value => setJobConfig(value, 'config.process[0].model.quantize_te')}
/>
<div className="grid grid-cols-2 gap-2">
<Checkbox
label="Transformer"
checked={jobConfig.config.process[0].model.quantize}
onChange={value => setJobConfig(value, 'config.process[0].model.quantize')}
/>
<Checkbox
label="Text Encoder"
checked={jobConfig.config.process[0].model.quantize_te}
onChange={value => setJobConfig(value, 'config.process[0].model.quantize_te')}
/>
</div>
</FormGroup>
</Card>
@@ -356,7 +369,7 @@ export default function TrainingForm() {
<FormGroup label="EMA (Exponential Moving Average)">
<Checkbox
label="Use EMA"
className='pt-1'
className="pt-1"
checked={jobConfig.config.process[0].train.ema_config?.use_ema || false}
onChange={value => setJobConfig(value, 'config.process[0].train.ema_config.use_ema')}
/>
@@ -531,7 +544,7 @@ export default function TrainingForm() {
value={jobConfig.config.process[0].sample.width}
onChange={value => setJobConfig(value, 'config.process[0].sample.width')}
placeholder="eg. 1024"
min={256}
min={0}
required
/>
<NumberInput
@@ -540,7 +553,7 @@ export default function TrainingForm() {
onChange={value => setJobConfig(value, 'config.process[0].sample.height')}
placeholder="eg. 1024"
className="pt-2"
min={256}
min={0}
required
/>
</div>

View File

@@ -1,4 +1,4 @@
import React, { useState, useEffect } from 'react';
import React, { useState, useEffect, useRef } from 'react';
import { GPUApiResponse } from '@/types';
import Loading from '@/components/Loading';
import GPUWidget from '@/components/GPUWidget';
@@ -8,10 +8,15 @@ const GpuMonitor: React.FC = () => {
const [loading, setLoading] = useState<boolean>(true);
const [error, setError] = useState<string | null>(null);
const [lastUpdated, setLastUpdated] = useState<Date | null>(null);
const isFetchingGpuRef = useRef(false);
useEffect(() => {
const fetchGpuInfo = async () => {
try {
if (isFetchingGpuRef.current) {
return;
}
isFetchingGpuRef.current = true;
const response = await fetch('/api/gpu');
if (!response.ok) {
@@ -25,6 +30,7 @@ const GpuMonitor: React.FC = () => {
} catch (err) {
setError(`Failed to fetch GPU data: ${err instanceof Error ? err.message : String(err)}`);
} finally {
isFetchingGpuRef.current = false;
setLoading(false);
}
};

View File

@@ -98,8 +98,8 @@ export default function GPUWidget({ gpu }: GPUWidgetProps) {
<div>
<p className="text-xs text-gray-400">Power Draw</p>
<p className="text-sm text-gray-200">
{gpu.power.draw.toFixed(1)}W
<span className="text-gray-400 text-xs"> / {gpu.power.limit.toFixed(1)}W</span>
{gpu.power.draw?.toFixed(1)}W
<span className="text-gray-400 text-xs"> / {gpu.power.limit?.toFixed(1) || " ? "}W</span>
</p>
</div>
</div>