mirror of
https://github.com/ostris/ai-toolkit.git
synced 2026-01-26 16:39:47 +00:00
UI Bug fixes and initial windows support
This commit is contained in:
1
.github/ISSUE_TEMPLATE/bug_report.md
vendored
1
.github/ISSUE_TEMPLATE/bug_report.md
vendored
@@ -17,4 +17,3 @@ You verified that this is a bug and not a feature request or question by asking
|
||||
Yes/No
|
||||
|
||||
## Describe the bug
|
||||
|
||||
|
||||
@@ -33,4 +33,3 @@ huggingface_hub
|
||||
peft
|
||||
gradio
|
||||
python-slugify
|
||||
sqlite3
|
||||
@@ -7,7 +7,7 @@
|
||||
"build": "next build",
|
||||
"start": "next start --port 8675",
|
||||
"lint": "next lint",
|
||||
"update_db": "npx prisma generate ; npx prisma db push",
|
||||
"update_db": "npx prisma generate && npx prisma db push",
|
||||
"format": "prettier --write \"**/*.{js,jsx,ts,tsx,css,scss}\""
|
||||
},
|
||||
"dependencies": {
|
||||
|
||||
@@ -1,13 +1,18 @@
|
||||
import { NextResponse } from 'next/server';
|
||||
import { exec } from 'child_process';
|
||||
import { promisify } from 'util';
|
||||
import os from 'os';
|
||||
|
||||
const execAsync = promisify(exec);
|
||||
|
||||
export async function GET() {
|
||||
try {
|
||||
// Get platform
|
||||
const platform = os.platform();
|
||||
const isWindows = platform === 'win32';
|
||||
|
||||
// Check if nvidia-smi is available
|
||||
const hasNvidiaSmi = await checkNvidiaSmi();
|
||||
const hasNvidiaSmi = await checkNvidiaSmi(isWindows);
|
||||
|
||||
if (!hasNvidiaSmi) {
|
||||
return NextResponse.json({
|
||||
@@ -18,7 +23,7 @@ export async function GET() {
|
||||
}
|
||||
|
||||
// Get GPU stats
|
||||
const gpuStats = await getGpuStats();
|
||||
const gpuStats = await getGpuStats(isWindows);
|
||||
|
||||
return NextResponse.json({
|
||||
hasNvidiaSmi: true,
|
||||
@@ -37,20 +42,29 @@ export async function GET() {
|
||||
}
|
||||
}
|
||||
|
||||
async function checkNvidiaSmi(): Promise<boolean> {
|
||||
async function checkNvidiaSmi(isWindows: boolean): Promise<boolean> {
|
||||
try {
|
||||
if (isWindows) {
|
||||
// Check if nvidia-smi is available on Windows
|
||||
// It's typically located in C:\Program Files\NVIDIA Corporation\NVSMI\nvidia-smi.exe
|
||||
// but we'll just try to run it directly as it may be in PATH
|
||||
await execAsync('nvidia-smi -L');
|
||||
} else {
|
||||
// Linux/macOS check
|
||||
await execAsync('which nvidia-smi');
|
||||
}
|
||||
return true;
|
||||
} catch (error) {
|
||||
return false;
|
||||
}
|
||||
}
|
||||
|
||||
async function getGpuStats() {
|
||||
// Get detailed GPU information in JSON format including fan speed
|
||||
const { stdout } = await execAsync(
|
||||
'nvidia-smi --query-gpu=index,name,driver_version,temperature.gpu,utilization.gpu,utilization.memory,memory.total,memory.free,memory.used,power.draw,power.limit,clocks.current.graphics,clocks.current.memory,fan.speed --format=csv,noheader,nounits',
|
||||
);
|
||||
async function getGpuStats(isWindows: boolean) {
|
||||
// Command is the same for both platforms, but the path might be different
|
||||
const command = 'nvidia-smi --query-gpu=index,name,driver_version,temperature.gpu,utilization.gpu,utilization.memory,memory.total,memory.free,memory.used,power.draw,power.limit,clocks.current.graphics,clocks.current.memory,fan.speed --format=csv,noheader,nounits';
|
||||
|
||||
// Execute command
|
||||
const { stdout } = await execAsync(command);
|
||||
|
||||
// Parse CSV output
|
||||
const gpus = stdout
|
||||
@@ -97,7 +111,7 @@ async function getGpuStats() {
|
||||
memory: parseInt(clockMemory),
|
||||
},
|
||||
fan: {
|
||||
speed: parseInt(fanSpeed), // Fan speed as percentage
|
||||
speed: parseInt(fanSpeed) || 0, // Some GPUs might not report fan speed, default to 0
|
||||
},
|
||||
};
|
||||
});
|
||||
|
||||
@@ -1,9 +1,10 @@
|
||||
import { NextRequest, NextResponse } from 'next/server';
|
||||
import { PrismaClient } from '@prisma/client';
|
||||
import { TOOLKIT_ROOT, defaultTrainFolder } from '@/paths';
|
||||
import { TOOLKIT_ROOT } from '@/paths';
|
||||
import { spawn } from 'child_process';
|
||||
import path from 'path';
|
||||
import fs from 'fs';
|
||||
import os from 'os';
|
||||
import { getTrainingFolder, getHFToken } from '@/server/settings';
|
||||
|
||||
const prisma = new PrismaClient();
|
||||
@@ -51,10 +52,18 @@ export async function GET(request: NextRequest, { params }: { params: { jobID: s
|
||||
let pythonPath = 'python';
|
||||
// use .venv or venv if it exists
|
||||
if (fs.existsSync(path.join(TOOLKIT_ROOT, '.venv'))) {
|
||||
if (os.platform() === 'win32') {
|
||||
pythonPath = path.join(TOOLKIT_ROOT, '.venv', 'Scripts', 'python.exe');
|
||||
} else {
|
||||
pythonPath = path.join(TOOLKIT_ROOT, '.venv', 'bin', 'python');
|
||||
}
|
||||
} else if (fs.existsSync(path.join(TOOLKIT_ROOT, 'venv'))) {
|
||||
if (os.platform() === 'win32') {
|
||||
pythonPath = path.join(TOOLKIT_ROOT, 'venv', 'Scripts', 'python.exe');
|
||||
} else {
|
||||
pythonPath = path.join(TOOLKIT_ROOT, 'venv', 'bin', 'python');
|
||||
}
|
||||
}
|
||||
|
||||
const runFilePath = path.join(TOOLKIT_ROOT, 'run.py');
|
||||
if (!fs.existsSync(runFilePath)) {
|
||||
|
||||
@@ -6,6 +6,7 @@ const prisma = new PrismaClient();
|
||||
export async function GET(request: Request) {
|
||||
const { searchParams } = new URL(request.url);
|
||||
const id = searchParams.get('id');
|
||||
console.log('ID:', id);
|
||||
|
||||
try {
|
||||
if (id) {
|
||||
@@ -18,6 +19,7 @@ export async function GET(request: Request) {
|
||||
const jobs = await prisma.job.findMany({
|
||||
orderBy: { created_at: 'desc' },
|
||||
});
|
||||
console.log('Jobs:', jobs);
|
||||
return NextResponse.json({ jobs: jobs });
|
||||
} catch (error) {
|
||||
console.error(error);
|
||||
|
||||
@@ -1,5 +1,6 @@
|
||||
export interface Model {
|
||||
name_or_path: string;
|
||||
dev_only?: boolean;
|
||||
defaults?: { [key: string]: any };
|
||||
}
|
||||
|
||||
@@ -37,5 +38,11 @@ export const options = {
|
||||
'config.process[0].model.is_lumina2': [true, false],
|
||||
},
|
||||
},
|
||||
{
|
||||
name_or_path: 'ostris/objective-reality',
|
||||
dev_only: true,
|
||||
defaults: {
|
||||
},
|
||||
},
|
||||
],
|
||||
} as Option;
|
||||
|
||||
@@ -18,6 +18,8 @@ import { TopBar, MainContent } from '@/components/layout';
|
||||
import { Button } from '@headlessui/react';
|
||||
import { FaChevronLeft } from 'react-icons/fa';
|
||||
|
||||
const isDev = process.env.NODE_ENV === 'development';
|
||||
|
||||
export default function TrainingForm() {
|
||||
const router = useRouter();
|
||||
const searchParams = useSearchParams();
|
||||
@@ -42,9 +44,11 @@ export default function TrainingForm() {
|
||||
for (let i = 0; i < jobConfig.config.process[0].datasets.length; i++) {
|
||||
const dataset = jobConfig.config.process[0].datasets[i];
|
||||
if (dataset.folder_path === defaultDatasetPath) {
|
||||
if (datasetOptions.length > 0) {
|
||||
setJobConfig(datasetOptions[0].value, `config.process[0].datasets[${i}].folder_path`);
|
||||
}
|
||||
}
|
||||
}
|
||||
}, [datasets, settings, isSettingsLoaded, datasetFetchStatus]);
|
||||
|
||||
useEffect(() => {
|
||||
@@ -196,13 +200,22 @@ export default function TrainingForm() {
|
||||
}
|
||||
}
|
||||
}}
|
||||
options={options.model.map(model => ({
|
||||
options={
|
||||
options.model
|
||||
.map(model => {
|
||||
if (model.dev_only && !isDev) {
|
||||
return null;
|
||||
}
|
||||
return {
|
||||
value: model.name_or_path,
|
||||
label: model.name_or_path,
|
||||
}))}
|
||||
};
|
||||
})
|
||||
.filter(x => x) as { value: string; label: string }[]
|
||||
}
|
||||
/>
|
||||
<FormGroup label="Quantize">
|
||||
<div className='grid grid-cols-2 gap-2'>
|
||||
<div className="grid grid-cols-2 gap-2">
|
||||
<Checkbox
|
||||
label="Transformer"
|
||||
checked={jobConfig.config.process[0].model.quantize}
|
||||
@@ -356,7 +369,7 @@ export default function TrainingForm() {
|
||||
<FormGroup label="EMA (Exponential Moving Average)">
|
||||
<Checkbox
|
||||
label="Use EMA"
|
||||
className='pt-1'
|
||||
className="pt-1"
|
||||
checked={jobConfig.config.process[0].train.ema_config?.use_ema || false}
|
||||
onChange={value => setJobConfig(value, 'config.process[0].train.ema_config.use_ema')}
|
||||
/>
|
||||
@@ -531,7 +544,7 @@ export default function TrainingForm() {
|
||||
value={jobConfig.config.process[0].sample.width}
|
||||
onChange={value => setJobConfig(value, 'config.process[0].sample.width')}
|
||||
placeholder="eg. 1024"
|
||||
min={256}
|
||||
min={0}
|
||||
required
|
||||
/>
|
||||
<NumberInput
|
||||
@@ -540,7 +553,7 @@ export default function TrainingForm() {
|
||||
onChange={value => setJobConfig(value, 'config.process[0].sample.height')}
|
||||
placeholder="eg. 1024"
|
||||
className="pt-2"
|
||||
min={256}
|
||||
min={0}
|
||||
required
|
||||
/>
|
||||
</div>
|
||||
|
||||
@@ -1,4 +1,4 @@
|
||||
import React, { useState, useEffect } from 'react';
|
||||
import React, { useState, useEffect, useRef } from 'react';
|
||||
import { GPUApiResponse } from '@/types';
|
||||
import Loading from '@/components/Loading';
|
||||
import GPUWidget from '@/components/GPUWidget';
|
||||
@@ -8,10 +8,15 @@ const GpuMonitor: React.FC = () => {
|
||||
const [loading, setLoading] = useState<boolean>(true);
|
||||
const [error, setError] = useState<string | null>(null);
|
||||
const [lastUpdated, setLastUpdated] = useState<Date | null>(null);
|
||||
const isFetchingGpuRef = useRef(false);
|
||||
|
||||
useEffect(() => {
|
||||
const fetchGpuInfo = async () => {
|
||||
try {
|
||||
if (isFetchingGpuRef.current) {
|
||||
return;
|
||||
}
|
||||
isFetchingGpuRef.current = true;
|
||||
const response = await fetch('/api/gpu');
|
||||
|
||||
if (!response.ok) {
|
||||
@@ -25,6 +30,7 @@ const GpuMonitor: React.FC = () => {
|
||||
} catch (err) {
|
||||
setError(`Failed to fetch GPU data: ${err instanceof Error ? err.message : String(err)}`);
|
||||
} finally {
|
||||
isFetchingGpuRef.current = false;
|
||||
setLoading(false);
|
||||
}
|
||||
};
|
||||
|
||||
@@ -98,8 +98,8 @@ export default function GPUWidget({ gpu }: GPUWidgetProps) {
|
||||
<div>
|
||||
<p className="text-xs text-gray-400">Power Draw</p>
|
||||
<p className="text-sm text-gray-200">
|
||||
{gpu.power.draw.toFixed(1)}W
|
||||
<span className="text-gray-400 text-xs"> / {gpu.power.limit.toFixed(1)}W</span>
|
||||
{gpu.power.draw?.toFixed(1)}W
|
||||
<span className="text-gray-400 text-xs"> / {gpu.power.limit?.toFixed(1) || " ? "}W</span>
|
||||
</p>
|
||||
</div>
|
||||
</div>
|
||||
|
||||
Reference in New Issue
Block a user