mirror of
https://github.com/ostris/ai-toolkit.git
synced 2026-05-01 11:41:35 +00:00
UI Bug fixes and initial windows support
This commit is contained in:
1
.github/ISSUE_TEMPLATE/bug_report.md
vendored
1
.github/ISSUE_TEMPLATE/bug_report.md
vendored
@@ -17,4 +17,3 @@ You verified that this is a bug and not a feature request or question by asking
|
|||||||
Yes/No
|
Yes/No
|
||||||
|
|
||||||
## Describe the bug
|
## Describe the bug
|
||||||
|
|
||||||
|
|||||||
@@ -33,4 +33,3 @@ huggingface_hub
|
|||||||
peft
|
peft
|
||||||
gradio
|
gradio
|
||||||
python-slugify
|
python-slugify
|
||||||
sqlite3
|
|
||||||
@@ -7,7 +7,7 @@
|
|||||||
"build": "next build",
|
"build": "next build",
|
||||||
"start": "next start --port 8675",
|
"start": "next start --port 8675",
|
||||||
"lint": "next lint",
|
"lint": "next lint",
|
||||||
"update_db": "npx prisma generate ; npx prisma db push",
|
"update_db": "npx prisma generate && npx prisma db push",
|
||||||
"format": "prettier --write \"**/*.{js,jsx,ts,tsx,css,scss}\""
|
"format": "prettier --write \"**/*.{js,jsx,ts,tsx,css,scss}\""
|
||||||
},
|
},
|
||||||
"dependencies": {
|
"dependencies": {
|
||||||
|
|||||||
@@ -1,13 +1,18 @@
|
|||||||
import { NextResponse } from 'next/server';
|
import { NextResponse } from 'next/server';
|
||||||
import { exec } from 'child_process';
|
import { exec } from 'child_process';
|
||||||
import { promisify } from 'util';
|
import { promisify } from 'util';
|
||||||
|
import os from 'os';
|
||||||
|
|
||||||
const execAsync = promisify(exec);
|
const execAsync = promisify(exec);
|
||||||
|
|
||||||
export async function GET() {
|
export async function GET() {
|
||||||
try {
|
try {
|
||||||
|
// Get platform
|
||||||
|
const platform = os.platform();
|
||||||
|
const isWindows = platform === 'win32';
|
||||||
|
|
||||||
// Check if nvidia-smi is available
|
// Check if nvidia-smi is available
|
||||||
const hasNvidiaSmi = await checkNvidiaSmi();
|
const hasNvidiaSmi = await checkNvidiaSmi(isWindows);
|
||||||
|
|
||||||
if (!hasNvidiaSmi) {
|
if (!hasNvidiaSmi) {
|
||||||
return NextResponse.json({
|
return NextResponse.json({
|
||||||
@@ -18,7 +23,7 @@ export async function GET() {
|
|||||||
}
|
}
|
||||||
|
|
||||||
// Get GPU stats
|
// Get GPU stats
|
||||||
const gpuStats = await getGpuStats();
|
const gpuStats = await getGpuStats(isWindows);
|
||||||
|
|
||||||
return NextResponse.json({
|
return NextResponse.json({
|
||||||
hasNvidiaSmi: true,
|
hasNvidiaSmi: true,
|
||||||
@@ -37,20 +42,29 @@ export async function GET() {
|
|||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
async function checkNvidiaSmi(): Promise<boolean> {
|
async function checkNvidiaSmi(isWindows: boolean): Promise<boolean> {
|
||||||
try {
|
try {
|
||||||
|
if (isWindows) {
|
||||||
|
// Check if nvidia-smi is available on Windows
|
||||||
|
// It's typically located in C:\Program Files\NVIDIA Corporation\NVSMI\nvidia-smi.exe
|
||||||
|
// but we'll just try to run it directly as it may be in PATH
|
||||||
|
await execAsync('nvidia-smi -L');
|
||||||
|
} else {
|
||||||
|
// Linux/macOS check
|
||||||
await execAsync('which nvidia-smi');
|
await execAsync('which nvidia-smi');
|
||||||
|
}
|
||||||
return true;
|
return true;
|
||||||
} catch (error) {
|
} catch (error) {
|
||||||
return false;
|
return false;
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
async function getGpuStats() {
|
async function getGpuStats(isWindows: boolean) {
|
||||||
// Get detailed GPU information in JSON format including fan speed
|
// Command is the same for both platforms, but the path might be different
|
||||||
const { stdout } = await execAsync(
|
const command = 'nvidia-smi --query-gpu=index,name,driver_version,temperature.gpu,utilization.gpu,utilization.memory,memory.total,memory.free,memory.used,power.draw,power.limit,clocks.current.graphics,clocks.current.memory,fan.speed --format=csv,noheader,nounits';
|
||||||
'nvidia-smi --query-gpu=index,name,driver_version,temperature.gpu,utilization.gpu,utilization.memory,memory.total,memory.free,memory.used,power.draw,power.limit,clocks.current.graphics,clocks.current.memory,fan.speed --format=csv,noheader,nounits',
|
|
||||||
);
|
// Execute command
|
||||||
|
const { stdout } = await execAsync(command);
|
||||||
|
|
||||||
// Parse CSV output
|
// Parse CSV output
|
||||||
const gpus = stdout
|
const gpus = stdout
|
||||||
@@ -97,7 +111,7 @@ async function getGpuStats() {
|
|||||||
memory: parseInt(clockMemory),
|
memory: parseInt(clockMemory),
|
||||||
},
|
},
|
||||||
fan: {
|
fan: {
|
||||||
speed: parseInt(fanSpeed), // Fan speed as percentage
|
speed: parseInt(fanSpeed) || 0, // Some GPUs might not report fan speed, default to 0
|
||||||
},
|
},
|
||||||
};
|
};
|
||||||
});
|
});
|
||||||
|
|||||||
@@ -1,9 +1,10 @@
|
|||||||
import { NextRequest, NextResponse } from 'next/server';
|
import { NextRequest, NextResponse } from 'next/server';
|
||||||
import { PrismaClient } from '@prisma/client';
|
import { PrismaClient } from '@prisma/client';
|
||||||
import { TOOLKIT_ROOT, defaultTrainFolder } from '@/paths';
|
import { TOOLKIT_ROOT } from '@/paths';
|
||||||
import { spawn } from 'child_process';
|
import { spawn } from 'child_process';
|
||||||
import path from 'path';
|
import path from 'path';
|
||||||
import fs from 'fs';
|
import fs from 'fs';
|
||||||
|
import os from 'os';
|
||||||
import { getTrainingFolder, getHFToken } from '@/server/settings';
|
import { getTrainingFolder, getHFToken } from '@/server/settings';
|
||||||
|
|
||||||
const prisma = new PrismaClient();
|
const prisma = new PrismaClient();
|
||||||
@@ -51,10 +52,18 @@ export async function GET(request: NextRequest, { params }: { params: { jobID: s
|
|||||||
let pythonPath = 'python';
|
let pythonPath = 'python';
|
||||||
// use .venv or venv if it exists
|
// use .venv or venv if it exists
|
||||||
if (fs.existsSync(path.join(TOOLKIT_ROOT, '.venv'))) {
|
if (fs.existsSync(path.join(TOOLKIT_ROOT, '.venv'))) {
|
||||||
|
if (os.platform() === 'win32') {
|
||||||
|
pythonPath = path.join(TOOLKIT_ROOT, '.venv', 'Scripts', 'python.exe');
|
||||||
|
} else {
|
||||||
pythonPath = path.join(TOOLKIT_ROOT, '.venv', 'bin', 'python');
|
pythonPath = path.join(TOOLKIT_ROOT, '.venv', 'bin', 'python');
|
||||||
|
}
|
||||||
} else if (fs.existsSync(path.join(TOOLKIT_ROOT, 'venv'))) {
|
} else if (fs.existsSync(path.join(TOOLKIT_ROOT, 'venv'))) {
|
||||||
|
if (os.platform() === 'win32') {
|
||||||
|
pythonPath = path.join(TOOLKIT_ROOT, 'venv', 'Scripts', 'python.exe');
|
||||||
|
} else {
|
||||||
pythonPath = path.join(TOOLKIT_ROOT, 'venv', 'bin', 'python');
|
pythonPath = path.join(TOOLKIT_ROOT, 'venv', 'bin', 'python');
|
||||||
}
|
}
|
||||||
|
}
|
||||||
|
|
||||||
const runFilePath = path.join(TOOLKIT_ROOT, 'run.py');
|
const runFilePath = path.join(TOOLKIT_ROOT, 'run.py');
|
||||||
if (!fs.existsSync(runFilePath)) {
|
if (!fs.existsSync(runFilePath)) {
|
||||||
|
|||||||
@@ -6,6 +6,7 @@ const prisma = new PrismaClient();
|
|||||||
export async function GET(request: Request) {
|
export async function GET(request: Request) {
|
||||||
const { searchParams } = new URL(request.url);
|
const { searchParams } = new URL(request.url);
|
||||||
const id = searchParams.get('id');
|
const id = searchParams.get('id');
|
||||||
|
console.log('ID:', id);
|
||||||
|
|
||||||
try {
|
try {
|
||||||
if (id) {
|
if (id) {
|
||||||
@@ -18,6 +19,7 @@ export async function GET(request: Request) {
|
|||||||
const jobs = await prisma.job.findMany({
|
const jobs = await prisma.job.findMany({
|
||||||
orderBy: { created_at: 'desc' },
|
orderBy: { created_at: 'desc' },
|
||||||
});
|
});
|
||||||
|
console.log('Jobs:', jobs);
|
||||||
return NextResponse.json({ jobs: jobs });
|
return NextResponse.json({ jobs: jobs });
|
||||||
} catch (error) {
|
} catch (error) {
|
||||||
console.error(error);
|
console.error(error);
|
||||||
|
|||||||
@@ -1,5 +1,6 @@
|
|||||||
export interface Model {
|
export interface Model {
|
||||||
name_or_path: string;
|
name_or_path: string;
|
||||||
|
dev_only?: boolean;
|
||||||
defaults?: { [key: string]: any };
|
defaults?: { [key: string]: any };
|
||||||
}
|
}
|
||||||
|
|
||||||
@@ -37,5 +38,11 @@ export const options = {
|
|||||||
'config.process[0].model.is_lumina2': [true, false],
|
'config.process[0].model.is_lumina2': [true, false],
|
||||||
},
|
},
|
||||||
},
|
},
|
||||||
|
{
|
||||||
|
name_or_path: 'ostris/objective-reality',
|
||||||
|
dev_only: true,
|
||||||
|
defaults: {
|
||||||
|
},
|
||||||
|
},
|
||||||
],
|
],
|
||||||
} as Option;
|
} as Option;
|
||||||
|
|||||||
@@ -18,6 +18,8 @@ import { TopBar, MainContent } from '@/components/layout';
|
|||||||
import { Button } from '@headlessui/react';
|
import { Button } from '@headlessui/react';
|
||||||
import { FaChevronLeft } from 'react-icons/fa';
|
import { FaChevronLeft } from 'react-icons/fa';
|
||||||
|
|
||||||
|
const isDev = process.env.NODE_ENV === 'development';
|
||||||
|
|
||||||
export default function TrainingForm() {
|
export default function TrainingForm() {
|
||||||
const router = useRouter();
|
const router = useRouter();
|
||||||
const searchParams = useSearchParams();
|
const searchParams = useSearchParams();
|
||||||
@@ -42,9 +44,11 @@ export default function TrainingForm() {
|
|||||||
for (let i = 0; i < jobConfig.config.process[0].datasets.length; i++) {
|
for (let i = 0; i < jobConfig.config.process[0].datasets.length; i++) {
|
||||||
const dataset = jobConfig.config.process[0].datasets[i];
|
const dataset = jobConfig.config.process[0].datasets[i];
|
||||||
if (dataset.folder_path === defaultDatasetPath) {
|
if (dataset.folder_path === defaultDatasetPath) {
|
||||||
|
if (datasetOptions.length > 0) {
|
||||||
setJobConfig(datasetOptions[0].value, `config.process[0].datasets[${i}].folder_path`);
|
setJobConfig(datasetOptions[0].value, `config.process[0].datasets[${i}].folder_path`);
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
}
|
||||||
}, [datasets, settings, isSettingsLoaded, datasetFetchStatus]);
|
}, [datasets, settings, isSettingsLoaded, datasetFetchStatus]);
|
||||||
|
|
||||||
useEffect(() => {
|
useEffect(() => {
|
||||||
@@ -196,13 +200,22 @@ export default function TrainingForm() {
|
|||||||
}
|
}
|
||||||
}
|
}
|
||||||
}}
|
}}
|
||||||
options={options.model.map(model => ({
|
options={
|
||||||
|
options.model
|
||||||
|
.map(model => {
|
||||||
|
if (model.dev_only && !isDev) {
|
||||||
|
return null;
|
||||||
|
}
|
||||||
|
return {
|
||||||
value: model.name_or_path,
|
value: model.name_or_path,
|
||||||
label: model.name_or_path,
|
label: model.name_or_path,
|
||||||
}))}
|
};
|
||||||
|
})
|
||||||
|
.filter(x => x) as { value: string; label: string }[]
|
||||||
|
}
|
||||||
/>
|
/>
|
||||||
<FormGroup label="Quantize">
|
<FormGroup label="Quantize">
|
||||||
<div className='grid grid-cols-2 gap-2'>
|
<div className="grid grid-cols-2 gap-2">
|
||||||
<Checkbox
|
<Checkbox
|
||||||
label="Transformer"
|
label="Transformer"
|
||||||
checked={jobConfig.config.process[0].model.quantize}
|
checked={jobConfig.config.process[0].model.quantize}
|
||||||
@@ -356,7 +369,7 @@ export default function TrainingForm() {
|
|||||||
<FormGroup label="EMA (Exponential Moving Average)">
|
<FormGroup label="EMA (Exponential Moving Average)">
|
||||||
<Checkbox
|
<Checkbox
|
||||||
label="Use EMA"
|
label="Use EMA"
|
||||||
className='pt-1'
|
className="pt-1"
|
||||||
checked={jobConfig.config.process[0].train.ema_config?.use_ema || false}
|
checked={jobConfig.config.process[0].train.ema_config?.use_ema || false}
|
||||||
onChange={value => setJobConfig(value, 'config.process[0].train.ema_config.use_ema')}
|
onChange={value => setJobConfig(value, 'config.process[0].train.ema_config.use_ema')}
|
||||||
/>
|
/>
|
||||||
@@ -531,7 +544,7 @@ export default function TrainingForm() {
|
|||||||
value={jobConfig.config.process[0].sample.width}
|
value={jobConfig.config.process[0].sample.width}
|
||||||
onChange={value => setJobConfig(value, 'config.process[0].sample.width')}
|
onChange={value => setJobConfig(value, 'config.process[0].sample.width')}
|
||||||
placeholder="eg. 1024"
|
placeholder="eg. 1024"
|
||||||
min={256}
|
min={0}
|
||||||
required
|
required
|
||||||
/>
|
/>
|
||||||
<NumberInput
|
<NumberInput
|
||||||
@@ -540,7 +553,7 @@ export default function TrainingForm() {
|
|||||||
onChange={value => setJobConfig(value, 'config.process[0].sample.height')}
|
onChange={value => setJobConfig(value, 'config.process[0].sample.height')}
|
||||||
placeholder="eg. 1024"
|
placeholder="eg. 1024"
|
||||||
className="pt-2"
|
className="pt-2"
|
||||||
min={256}
|
min={0}
|
||||||
required
|
required
|
||||||
/>
|
/>
|
||||||
</div>
|
</div>
|
||||||
|
|||||||
@@ -1,4 +1,4 @@
|
|||||||
import React, { useState, useEffect } from 'react';
|
import React, { useState, useEffect, useRef } from 'react';
|
||||||
import { GPUApiResponse } from '@/types';
|
import { GPUApiResponse } from '@/types';
|
||||||
import Loading from '@/components/Loading';
|
import Loading from '@/components/Loading';
|
||||||
import GPUWidget from '@/components/GPUWidget';
|
import GPUWidget from '@/components/GPUWidget';
|
||||||
@@ -8,10 +8,15 @@ const GpuMonitor: React.FC = () => {
|
|||||||
const [loading, setLoading] = useState<boolean>(true);
|
const [loading, setLoading] = useState<boolean>(true);
|
||||||
const [error, setError] = useState<string | null>(null);
|
const [error, setError] = useState<string | null>(null);
|
||||||
const [lastUpdated, setLastUpdated] = useState<Date | null>(null);
|
const [lastUpdated, setLastUpdated] = useState<Date | null>(null);
|
||||||
|
const isFetchingGpuRef = useRef(false);
|
||||||
|
|
||||||
useEffect(() => {
|
useEffect(() => {
|
||||||
const fetchGpuInfo = async () => {
|
const fetchGpuInfo = async () => {
|
||||||
try {
|
try {
|
||||||
|
if (isFetchingGpuRef.current) {
|
||||||
|
return;
|
||||||
|
}
|
||||||
|
isFetchingGpuRef.current = true;
|
||||||
const response = await fetch('/api/gpu');
|
const response = await fetch('/api/gpu');
|
||||||
|
|
||||||
if (!response.ok) {
|
if (!response.ok) {
|
||||||
@@ -25,6 +30,7 @@ const GpuMonitor: React.FC = () => {
|
|||||||
} catch (err) {
|
} catch (err) {
|
||||||
setError(`Failed to fetch GPU data: ${err instanceof Error ? err.message : String(err)}`);
|
setError(`Failed to fetch GPU data: ${err instanceof Error ? err.message : String(err)}`);
|
||||||
} finally {
|
} finally {
|
||||||
|
isFetchingGpuRef.current = false;
|
||||||
setLoading(false);
|
setLoading(false);
|
||||||
}
|
}
|
||||||
};
|
};
|
||||||
|
|||||||
@@ -98,8 +98,8 @@ export default function GPUWidget({ gpu }: GPUWidgetProps) {
|
|||||||
<div>
|
<div>
|
||||||
<p className="text-xs text-gray-400">Power Draw</p>
|
<p className="text-xs text-gray-400">Power Draw</p>
|
||||||
<p className="text-sm text-gray-200">
|
<p className="text-sm text-gray-200">
|
||||||
{gpu.power.draw.toFixed(1)}W
|
{gpu.power.draw?.toFixed(1)}W
|
||||||
<span className="text-gray-400 text-xs"> / {gpu.power.limit.toFixed(1)}W</span>
|
<span className="text-gray-400 text-xs"> / {gpu.power.limit?.toFixed(1) || " ? "}W</span>
|
||||||
</p>
|
</p>
|
||||||
</div>
|
</div>
|
||||||
</div>
|
</div>
|
||||||
|
|||||||
Reference in New Issue
Block a user