Added a way to secure the UI. Plus various bug fixes and quality of life updates

This commit is contained in:
Jaret Burkett
2025-03-20 08:07:09 -06:00
parent bbfd6ef0fe
commit 3a6b24f4c8
47 changed files with 618 additions and 378 deletions

File diff suppressed because one or more lines are too long

View File

@@ -15,7 +15,8 @@ class UITrainer(SDTrainer):
super(UITrainer, self).__init__(process_id, job, config, **kwargs)
self.sqlite_db_path = self.config.get("sqlite_db_path", "./aitk_db.db")
if not os.path.exists(self.sqlite_db_path):
raise Exception(f"SQLite database not found at {self.sqlite_db_path}")
raise Exception(
f"SQLite database not found at {self.sqlite_db_path}")
print(f"Using SQLite database at {self.sqlite_db_path}")
self.job_id = os.environ.get("AITK_JOB_ID", None)
self.job_id = self.job_id.strip() if self.job_id is not None else None
@@ -147,6 +148,8 @@ class UITrainer(SDTrainer):
try:
await asyncio.gather(*self._async_tasks)
except Exception as e:
print(f"Error waiting for async operations: {e}")
finally:
# Clear the task list after completion
self._async_tasks.clear()

View File

@@ -6,6 +6,7 @@
"dev": "next dev --turbopack",
"build": "next build",
"start": "next start --port 8675",
"build_and_start": "npm install && npm run update_db && npm run build && npm run start",
"lint": "next lint",
"update_db": "npx prisma generate && npx prisma db push",
"format": "prettier --write \"**/*.{js,jsx,ts,tsx,css,scss}\""

View File

@@ -0,0 +1,6 @@
import { NextResponse } from 'next/server';
export async function GET() {
// if this gets hit, auth has already been verified
return NextResponse.json({ isAuthenticated: true });
}

View File

@@ -12,10 +12,7 @@ export async function POST(request: Request) {
try {
// Check if folder exists
if (!fs.existsSync(datasetFolder)) {
return NextResponse.json(
{ error: `Folder '${datasetName}' not found` },
{ status: 404 }
);
return NextResponse.json({ error: `Folder '${datasetName}' not found` }, { status: 404 });
}
// Find all images recursively
@@ -23,16 +20,13 @@ export async function POST(request: Request) {
// Format response
const result = imageFiles.map(imgPath => ({
img_path: imgPath
img_path: imgPath,
}));
return NextResponse.json({ images: result });
} catch (error) {
console.error('Error finding images:', error);
return NextResponse.json(
{ error: 'Failed to process request' },
{ status: 500 }
);
return NextResponse.json({ error: 'Failed to process request' }, { status: 500 });
}
}

View File

@@ -16,7 +16,8 @@ export async function GET(request: NextRequest, { params }: { params: { filePath
const allowedDirs = [datasetRoot, trainingRoot];
// Security check: Ensure path is in allowed directory
const isAllowed = allowedDirs.some(allowedDir => decodedFilePath.startsWith(allowedDir)) && !decodedFilePath.includes('..');
const isAllowed =
allowedDirs.some(allowedDir => decodedFilePath.startsWith(allowedDir)) && !decodedFilePath.includes('..');
if (!isAllowed) {
console.warn(`Access denied: ${decodedFilePath} not in ${allowedDirs.join(', ')}`);
@@ -62,7 +63,7 @@ export async function GET(request: NextRequest, { params }: { params: { filePath
'Accept-Ranges': 'bytes',
'Cache-Control': 'public, max-age=86400',
'Content-Disposition': `attachment; filename="${encodeURIComponent(filename)}"`,
'X-Content-Type-Options': 'nosniff'
'X-Content-Type-Options': 'nosniff',
};
if (range) {
@@ -70,12 +71,12 @@ export async function GET(request: NextRequest, { params }: { params: { filePath
const parts = range.replace(/bytes=/, '').split('-');
const start = parseInt(parts[0], 10);
const end = parts[1] ? parseInt(parts[1], 10) : Math.min(start + 10 * 1024 * 1024, stat.size - 1); // 10MB chunks
const chunkSize = (end - start) + 1;
const chunkSize = end - start + 1;
const fileStream = fs.createReadStream(decodedFilePath, {
start,
end,
highWaterMark: 64 * 1024 // 64KB buffer
highWaterMark: 64 * 1024, // 64KB buffer
});
return new NextResponse(fileStream as any, {
@@ -83,19 +84,19 @@ export async function GET(request: NextRequest, { params }: { params: { filePath
headers: {
...commonHeaders,
'Content-Range': `bytes ${start}-${end}/${stat.size}`,
'Content-Length': String(chunkSize)
'Content-Length': String(chunkSize),
},
});
} else {
// For full file download, read directly without streaming wrapper
const fileStream = fs.createReadStream(decodedFilePath, {
highWaterMark: 64 * 1024 // 64KB buffer
highWaterMark: 64 * 1024, // 64KB buffer
});
return new NextResponse(fileStream as any, {
headers: {
...commonHeaders,
'Content-Length': String(stat.size)
'Content-Length': String(stat.size),
},
});
}

View File

@@ -61,7 +61,8 @@ async function checkNvidiaSmi(isWindows: boolean): Promise<boolean> {
async function getGpuStats(isWindows: boolean) {
// Command is the same for both platforms, but the path might be different
const command = 'nvidia-smi --query-gpu=index,name,driver_version,temperature.gpu,utilization.gpu,utilization.memory,memory.total,memory.free,memory.used,power.draw,power.limit,clocks.current.graphics,clocks.current.memory,fan.speed --format=csv,noheader,nounits';
const command =
'nvidia-smi --query-gpu=index,name,driver_version,temperature.gpu,utilization.gpu,utilization.memory,memory.total,memory.free,memory.used,power.draw,power.limit,clocks.current.graphics,clocks.current.memory,fan.speed --format=csv,noheader,nounits';
// Execute command
const { stdout } = await execAsync(command);

View File

@@ -17,7 +17,6 @@ export async function POST(request: Request) {
return NextResponse.json({ error: 'Image does not exist' }, { status: 404 });
}
// check for caption
const captionPath = imgPath.replace(/\.[^/.]+$/, '') + '.txt';
// save caption to file

View File

@@ -6,7 +6,6 @@ const prisma = new PrismaClient();
export async function GET(request: Request) {
const { searchParams } = new URL(request.url);
const id = searchParams.get('id');
console.log('ID:', id);
try {
if (id) {
@@ -19,7 +18,6 @@ export async function GET(request: Request) {
const jobs = await prisma.job.findMany({
orderBy: { created_at: 'desc' },
});
console.log('Jobs:', jobs);
return NextResponse.json({ jobs: jobs });
} catch (error) {
console.error(error);

View File

@@ -1,7 +1,7 @@
import { NextResponse } from 'next/server';
import { PrismaClient } from '@prisma/client';
import { defaultTrainFolder, defaultDatasetsFolder } from '@/paths';
import {flushCache} from '@/server/settings';
import { flushCache } from '@/server/settings';
const prisma = new PrismaClient();

View File

@@ -6,6 +6,7 @@ import DatasetImageCard from '@/components/DatasetImageCard';
import { Button } from '@headlessui/react';
import AddImagesModal, { openImagesModal } from '@/components/AddImagesModal';
import { TopBar, MainContent } from '@/components/layout';
import { apiClient } from '@/utils/api';
export default function DatasetPage({ params }: { params: { datasetName: string } }) {
const [imgList, setImgList] = useState<{ img_path: string }[]>([]);
@@ -15,15 +16,11 @@ export default function DatasetPage({ params }: { params: { datasetName: string
const refreshImageList = (dbName: string) => {
setStatus('loading');
fetch('/api/datasets/listImages', {
method: 'POST',
headers: {
'Content-Type': 'application/json',
},
body: JSON.stringify({ datasetName: dbName }),
})
.then(res => res.json())
.then(data => {
console.log('Fetching images for dataset:', dbName);
apiClient
.post('/api/datasets/listImages', { datasetName: dbName })
.then((res: any) => {
const data = res.data;
console.log('Images:', data.images);
// sort
data.images.sort((a: { img_path: string }, b: { img_path: string }) => a.img_path.localeCompare(b.img_path));

View File

@@ -10,6 +10,7 @@ import { FaRegTrashAlt } from 'react-icons/fa';
import { openConfirm } from '@/components/ConfirmModal';
import { TopBar, MainContent } from '@/components/layout';
import UniversalTable, { TableColumn } from '@/components/UniversalTable';
import { apiClient } from '@/utils/api';
export default function Datasets() {
const { datasets, status, refreshDatasets } = useDatasetList();
@@ -54,16 +55,10 @@ export default function Datasets() {
type: 'warning',
confirmText: 'Delete',
onConfirm: () => {
fetch('/api/datasets/delete', {
method: 'POST',
headers: {
'Content-Type': 'application/json',
},
body: JSON.stringify({ name: datasetName }),
})
.then(res => res.json())
.then(data => {
console.log('Dataset deleted:', data);
apiClient
.post('/api/datasets/delete', { name: datasetName })
.then(() => {
console.log('Dataset deleted:', datasetName);
refreshDatasets();
})
.catch(error => {
@@ -76,14 +71,7 @@ export default function Datasets() {
const handleCreateDataset = async (e: React.FormEvent) => {
e.preventDefault();
try {
const response = await fetch('/api/datasets/create', {
method: 'POST',
headers: {
'Content-Type': 'application/json',
},
body: JSON.stringify({ name: newDatasetName }),
});
const data = await response.json();
const data = await apiClient.post('/api/datasets/create', { name: newDatasetName }).then(res => res.data);
console.log('New dataset created:', data);
refreshDatasets();
setNewDatasetName('');

View File

@@ -33,11 +33,7 @@ const yamlConfig: YAML.DocumentOptions &
directives: true,
};
export default function AdvancedJob({
jobConfig,
setJobConfig,
settings,
}: Props) {
export default function AdvancedJob({ jobConfig, setJobConfig, settings }: Props) {
const [editorValue, setEditorValue] = useState<string>('');
const lastJobConfigUpdateStringRef = useRef('');
const editorRef = useRef<editor.IStandaloneCodeEditor | null>(null);

View File

@@ -30,7 +30,7 @@ export const defaultJobConfig: JobConfig = {
linear: 32,
linear_alpha: 32,
lokr_full_rank: true,
lokr_factor: -1
lokr_factor: -1,
},
save: {
dtype: 'bf16',
@@ -39,9 +39,7 @@ export const defaultJobConfig: JobConfig = {
save_format: 'diffusers',
push_to_hub: false,
},
datasets: [
defaultDatasetConfig
],
datasets: [defaultDatasetConfig],
train: {
batch_size: 1,
bypass_guidance_embedding: true,
@@ -55,7 +53,7 @@ export const defaultJobConfig: JobConfig = {
timestep_type: 'sigmoid',
content_or_style: 'balanced',
optimizer_params: {
weight_decay: 1e-4
weight_decay: 1e-4,
},
unload_text_encoder: false,
lr: 0.0001,
@@ -66,8 +64,7 @@ export const defaultJobConfig: JobConfig = {
dtype: 'bf16',
diff_output_preservation: false,
diff_output_preservation_multiplier: 1.0,
diff_output_preservation_class: 'person'
diff_output_preservation_class: 'person',
},
model: {
name_or_path: 'ostris/Flex.1-alpha',

View File

@@ -12,7 +12,7 @@ export const modelArchs = [
{ name: 'flux', label: 'Flux.1' },
{ name: 'wan21', label: 'Wan 2.1' },
{ name: 'lumina2', label: 'Lumina2' },
]
];
export const isVideoModelFromArch = (arch: string) => {
const videoArches = ['wan21'];

View File

@@ -19,6 +19,7 @@ import { Button } from '@headlessui/react';
import { FaChevronLeft } from 'react-icons/fa';
import SimpleJob from './SimpleJob';
import AdvancedJob from './AdvancedJob';
import { apiClient } from '@/utils/api';
const isDev = process.env.NODE_ENV === 'development';
@@ -56,12 +57,13 @@ export default function TrainingForm() {
useEffect(() => {
if (runId) {
fetch(`/api/jobs?id=${runId}`)
.then(res => res.json())
apiClient
.get(`/api/jobs?id=${runId}`)
.then(res => res.data)
.then(data => {
console.log('Training:', data);
setGpuIDs(data.gpu_ids);
setJobConfig(JSON.parse(data.job_config));
// setJobConfig(data.name, 'config.name');
})
.catch(error => console.error('Error fetching training:', error));
}
@@ -85,33 +87,30 @@ export default function TrainingForm() {
if (status === 'saving') return;
setStatus('saving');
try {
const response = await fetch('/api/jobs', {
method: 'POST',
headers: {
'Content-Type': 'application/json',
},
body: JSON.stringify({
id: runId,
name: jobConfig.config.name,
gpu_ids: gpuIDs,
job_config: jobConfig,
}),
});
if (!response.ok) throw new Error('Failed to save training');
setStatus('success');
if (!runId) {
const data = await response.json();
router.push(`/jobs/${data.id}`);
}
setTimeout(() => setStatus('idle'), 2000);
} catch (error) {
console.error('Error saving training:', error);
setStatus('error');
setTimeout(() => setStatus('idle'), 2000);
}
apiClient
.post('/api/jobs', {
id: runId,
name: jobConfig.config.name,
gpu_ids: gpuIDs,
job_config: jobConfig,
})
.then(res => {
setStatus('success');
if (runId) {
router.push(`/jobs/${runId}`);
} else {
router.push(`/jobs/${res.data.id}`);
}
})
.catch(error => {
console.error('Error saving training:', error);
setStatus('error');
})
.finally(() =>
setTimeout(() => {
setStatus('idle');
}, 2000),
);
};
const handleSubmit = async (e: React.FormEvent) => {

View File

@@ -13,10 +13,7 @@ export default function Dashboard() {
</div>
<div className="flex-1"></div>
<div>
<Link
href="/jobs/new"
className="text-gray-200 bg-slate-600 px-3 py-1 rounded-md"
>
<Link href="/jobs/new" className="text-gray-200 bg-slate-600 px-3 py-1 rounded-md">
New Training Job
</Link>
</div>

View File

@@ -6,6 +6,7 @@ import { ThemeProvider } from '@/components/ThemeProvider';
import ConfirmModal from '@/components/ConfirmModal';
import SampleImageModal from '@/components/SampleImageModal';
import { Suspense } from 'react';
import AuthWrapper from '@/components/AuthWrapper';
const inter = Inter({ subsets: ['latin'] });
@@ -15,6 +16,9 @@ export const metadata: Metadata = {
};
export default function RootLayout({ children }: { children: React.ReactNode }) {
// Check if the AI_TOOLKIT_AUTH environment variable is set
const authRequired = process.env.AI_TOOLKIT_AUTH ? true : false;
return (
<html lang="en" className="dark">
<head>
@@ -22,13 +26,14 @@ export default function RootLayout({ children }: { children: React.ReactNode })
</head>
<body className={inter.className}>
<ThemeProvider>
<div className="flex h-screen bg-gray-950">
<Sidebar />
<main className="flex-1 overflow-auto bg-gray-950 text-gray-100 relative">
<Suspense>{children}</Suspense>
</main>
</div>
<AuthWrapper authRequired={authRequired}>
<div className="flex h-screen bg-gray-950">
<Sidebar />
<main className="flex-1 overflow-auto bg-gray-950 text-gray-100 relative">
<Suspense>{children}</Suspense>
</main>
</div>
</AuthWrapper>
</ThemeProvider>
<ConfirmModal />
<SampleImageModal />

View File

@@ -3,6 +3,7 @@
import { useEffect, useState } from 'react';
import useSettings from '@/hooks/useSettings';
import { TopBar, MainContent } from '@/components/layout';
import { apiClient } from '@/utils/api';
export default function Settings() {
const { settings, setSettings } = useSettings();
@@ -12,24 +13,18 @@ export default function Settings() {
e.preventDefault();
setStatus('saving');
try {
const response = await fetch('/api/settings', {
method: 'POST',
headers: {
'Content-Type': 'application/json',
},
body: JSON.stringify(settings),
apiClient
.post('/api/settings', settings)
.then(() => {
setStatus('success');
})
.catch(error => {
console.error('Error saving settings:', error);
setStatus('error');
})
.finally(() => {
setTimeout(() => setStatus('idle'), 2000);
});
if (!response.ok) throw new Error('Failed to save settings');
setStatus('success');
setTimeout(() => setStatus('idle'), 2000);
} catch (error) {
console.error('Error saving settings:', error);
setStatus('error');
setTimeout(() => setStatus('idle'), 2000);
}
};
const handleChange = (e: React.ChangeEvent<HTMLInputElement>) => {

View File

@@ -4,7 +4,7 @@ import { Dialog, DialogBackdrop, DialogPanel, DialogTitle } from '@headlessui/re
import { FaUpload } from 'react-icons/fa';
import { useCallback, useState } from 'react';
import { useDropzone } from 'react-dropzone';
import axios from 'axios';
import { apiClient } from '@/utils/api';
export interface AddImagesModalState {
datasetName: string;
@@ -15,7 +15,7 @@ export const addImagesModalState = createGlobalState<AddImagesModalState | null>
export const openImagesModal = (datasetName: string, onComplete: () => void) => {
addImagesModalState.set({ datasetName, onComplete });
}
};
export default function AddImagesModal() {
const [addImagesModalInfo, setAddImagesModalInfo] = addImagesModalState.use();
@@ -36,46 +36,49 @@ export default function AddImagesModal() {
}
};
const onDrop = useCallback(async (acceptedFiles: File[]) => {
if (acceptedFiles.length === 0) return;
const onDrop = useCallback(
async (acceptedFiles: File[]) => {
if (acceptedFiles.length === 0) return;
setIsUploading(true);
setUploadProgress(0);
const formData = new FormData();
acceptedFiles.forEach(file => {
formData.append('files', file);
});
formData.append('datasetName', addImagesModalInfo?.datasetName || '');
try {
await axios.post(`/api/datasets/upload`, formData, {
headers: {
'Content-Type': 'multipart/form-data',
},
onUploadProgress: (progressEvent) => {
const percentCompleted = Math.round((progressEvent.loaded * 100) / (progressEvent.total || 100));
setUploadProgress(percentCompleted);
},
timeout: 0, // Disable timeout
});
onDone();
} catch (error) {
console.error('Upload failed:', error);
} finally {
setIsUploading(false);
setIsUploading(true);
setUploadProgress(0);
}
}, [addImagesModalInfo]);
const formData = new FormData();
acceptedFiles.forEach(file => {
formData.append('files', file);
});
formData.append('datasetName', addImagesModalInfo?.datasetName || '');
try {
await apiClient.post(`/api/datasets/upload`, formData, {
headers: {
'Content-Type': 'multipart/form-data',
},
onUploadProgress: progressEvent => {
const percentCompleted = Math.round((progressEvent.loaded * 100) / (progressEvent.total || 100));
setUploadProgress(percentCompleted);
},
timeout: 0, // Disable timeout
});
onDone();
} catch (error) {
console.error('Upload failed:', error);
} finally {
setIsUploading(false);
setUploadProgress(0);
}
},
[addImagesModalInfo],
);
const { getRootProps, getInputProps, isDragActive } = useDropzone({
onDrop,
accept: {
'image/*': ['.png', '.jpg', '.jpeg', '.gif', '.bmp'],
'text/*': ['.txt']
'text/*': ['.txt'],
},
multiple: true
multiple: true,
});
return (
@@ -105,22 +108,15 @@ export default function AddImagesModal() {
<input {...getInputProps()} />
<FaUpload className="size-8 mb-3 text-gray-400" />
<p className="text-sm text-gray-200 text-center">
{isDragActive
? 'Drop the files here...'
: 'Drag & drop files here, or click to select files'}
{isDragActive ? 'Drop the files here...' : 'Drag & drop files here, or click to select files'}
</p>
</div>
{isUploading && (
<div className="mt-4">
<div className="w-full bg-gray-700 rounded-full h-2.5">
<div
className="bg-blue-600 h-2.5 rounded-full"
style={{ width: `${uploadProgress}%` }}
></div>
<div className="bg-blue-600 h-2.5 rounded-full" style={{ width: `${uploadProgress}%` }}></div>
</div>
<p className="text-sm text-gray-300 mt-2 text-center">
Uploading... {uploadProgress}%
</p>
<p className="text-sm text-gray-300 mt-2 text-center">Uploading... {uploadProgress}%</p>
</div>
)}
</div>

View File

@@ -0,0 +1,163 @@
'use client';
import { useState, useEffect, useRef } from 'react';
import { apiClient, isAuthorizedState } from '@/utils/api';
import { createGlobalState } from 'react-global-hooks';
interface AuthWrapperProps {
authRequired: boolean;
children: React.ReactNode | React.ReactNode[];
}
export default function AuthWrapper({ authRequired, children }: AuthWrapperProps) {
const [token, setToken] = useState('');
// start with true, and deauth if needed
const [isAuthorizedGlobal, setIsAuthorized] = isAuthorizedState.use();
const [isLoading, setIsLoading] = useState(false);
const [error, setError] = useState('');
const [isBrowser, setIsBrowser] = useState(false);
const inputRef = useRef<HTMLInputElement>(null);
const isAuthorized = authRequired ? isAuthorizedGlobal : true;
// Set isBrowser to true when component mounts
useEffect(() => {
setIsBrowser(true);
// Get token from localStorage only after component has mounted
const storedToken = localStorage.getItem('AI_TOOLKIT_AUTH') || '';
setToken(storedToken);
checkAuth();
}, []);
// auto focus on input when not authorized
useEffect(() => {
if (isAuthorized) {
return;
}
setTimeout(() => {
if (inputRef.current) {
inputRef.current.focus();
}
}, 100);
}, [isAuthorized]);
const checkAuth = async () => {
// always get current stored token here to avoid state race conditions
const currentToken = localStorage.getItem('AI_TOOLKIT_AUTH') || '';
if (!authRequired || isLoading || currentToken === '') {
return;
}
setIsLoading(true);
setError('');
try {
const response = await apiClient.get('/api/auth');
if (response.data.isAuthenticated) {
setIsAuthorized(true);
} else {
setIsAuthorized(false);
setError('Invalid token. Please try again.');
}
} catch (err) {
setIsAuthorized(false);
console.log(err);
setError('Invalid token. Please try again.');
}
setIsLoading(false);
};
const handleSubmit = async (e: React.FormEvent) => {
e.preventDefault();
setError('');
if (!token.trim()) {
setError('Please enter your token');
return;
}
if (isBrowser) {
localStorage.setItem('AI_TOOLKIT_AUTH', token);
checkAuth();
}
};
if (isAuthorized) {
return <>{children}</>;
}
return (
<div className="flex min-h-screen bg-gray-900 text-gray-100 absolute top-0 left-0 right-0 bottom-0 scroll-auto">
{/* Left side - decorative or brand area */}
<div className="hidden lg:flex lg:w-1/2 bg-gray-800 flex-col justify-center items-center p-12">
<div className="mb-4">
{/* Replace with your own logo */}
<div className="flex items-center justify-center">
<img src="/ostris_logo.png" alt="Ostris AI Toolkit" className="w-auto h-24 inline" />
</div>
</div>
<h1 className="text-4xl mb-6">AI Toolkit</h1>
</div>
{/* Right side - login form */}
<div className="w-full lg:w-1/2 flex flex-col justify-center items-center p-8 sm:p-12">
<div className="w-full max-w-md">
<div className="lg:hidden flex justify-center mb-4">
{/* Mobile logo */}
<div className="flex items-center justify-center">
<img src="/ostris_logo.png" alt="Ostris AI Toolkit" className="w-auto h-24 inline" />
</div>
</div>
<h2 className="text-3xl text-center mb-2 lg:hidden">AI Toolkit</h2>
<form onSubmit={handleSubmit} className="space-y-6">
<div>
<label htmlFor="token" className="block text-sm font-medium text-gray-400 mb-2">
Access Token
</label>
<input
id="token"
name="token"
type="password"
autoComplete="off"
required
value={token}
ref={inputRef}
onChange={e => setToken(e.target.value)}
className="w-full px-4 py-3 rounded-lg bg-gray-800 border border-gray-700 focus:border-blue-500 focus:ring-2 focus:ring-blue-500 focus:ring-opacity-50 text-gray-100 transition duration-200"
placeholder="Enter your token"
/>
</div>
{error && (
<div className="p-3 bg-red-900/50 border border-red-800 rounded-lg text-red-200 text-sm">{error}</div>
)}
<button
type="submit"
disabled={isLoading}
className="w-full py-3 px-4 bg-blue-600 hover:bg-blue-700 rounded-lg text-white font-medium focus:outline-none focus:ring-2 focus:ring-blue-500 focus:ring-opacity-50 transition duration-200 flex items-center justify-center"
>
{isLoading ? (
<svg
className="animate-spin h-5 w-5 text-white"
xmlns="http://www.w3.org/2000/svg"
fill="none"
viewBox="0 0 24 24"
>
<circle className="opacity-25" cx="12" cy="12" r="10" stroke="currentColor" strokeWidth="4"></circle>
<path
className="opacity-75"
fill="currentColor"
d="M4 12a8 8 0 018-8V0C5.373 0 0 5.373 0 12h4zm2 5.291A7.962 7.962 0 014 12H0c0 3.042 1.135 5.824 3 7.938l3-2.647z"
></path>
</svg>
) : (
'Check Token'
)}
</button>
</form>
</div>
</div>
</div>
);
}

View File

@@ -2,6 +2,7 @@ import React, { useRef, useEffect, useState, ReactNode, KeyboardEvent } from 're
import { FaTrashAlt } from 'react-icons/fa';
import { openConfirm } from './ConfirmModal';
import classNames from 'classnames';
import { apiClient } from '@/utils/api';
interface DatasetImageCardProps {
imageUrl: string;
@@ -27,30 +28,32 @@ const DatasetImageCard: React.FC<DatasetImageCardProps> = ({
const isGettingCaption = useRef<boolean>(false);
const fetchCaption = async () => {
try {
if (isGettingCaption.current || isCaptionLoaded) return;
isGettingCaption.current = true;
const response = await fetch(`/api/caption/${encodeURIComponent(imageUrl)}`);
const data = await response.text();
setCaption(data);
setSavedCaption(data);
setIsCaptionLoaded(true);
} catch (error) {
console.error('Error fetching caption:', error);
}
if (isGettingCaption.current || isCaptionLoaded) return;
isGettingCaption.current = true;
apiClient
.get(`/api/caption/${encodeURIComponent(imageUrl)}`)
.then(res => res.data)
.then(data => {
console.log('Caption fetched:', data);
setCaption(data || '');
setSavedCaption(data || '');
setIsCaptionLoaded(true);
})
.catch(error => {
console.error('Error fetching caption:', error);
})
.finally(() => {
isGettingCaption.current = false;
});
};
const saveCaption = () => {
const trimmedCaption = caption.trim();
if (trimmedCaption === savedCaption) return;
fetch('/api/img/caption', {
method: 'POST',
headers: {
'Content-Type': 'application/json',
},
body: JSON.stringify({ imgPath: imageUrl, caption: trimmedCaption }),
})
.then(res => res.json())
apiClient
.post('/api/img/caption', { imgPath: imageUrl, caption: trimmedCaption })
.then(res => res.data)
.then(data => {
console.log('Caption saved:', data);
setSavedCaption(trimmedCaption);
@@ -129,16 +132,10 @@ const DatasetImageCard: React.FC<DatasetImageCardProps> = ({
type: 'warning',
confirmText: 'Delete',
onConfirm: () => {
fetch('/api/img/delete', {
method: 'POST',
headers: {
'Content-Type': 'application/json',
},
body: JSON.stringify({ imgPath: imageUrl }),
})
.then(res => res.json())
.then(data => {
console.log('Image deleted:', data);
apiClient
.post('/api/img/delete', { imgPath: imageUrl })
.then(() => {
console.log('Image deleted:', imageUrl);
onDelete();
})
.catch(error => {

View File

@@ -24,9 +24,7 @@ export default function FilesWidget({ jobID }: { jobID: string }) {
<div className="flex items-center space-x-2">
<Brain className="w-5 h-5 text-purple-400" />
<h2 className="font-semibold text-gray-100">Checkpoints</h2>
<span className="px-2 py-0.5 bg-gray-700 rounded-full text-xs text-gray-300">
{files.length}
</span>
<span className="px-2 py-0.5 bg-gray-700 rounded-full text-xs text-gray-300">{files.length}</span>
</div>
</div>
@@ -52,7 +50,7 @@ export default function FilesWidget({ jobID }: { jobID: string }) {
return (
<a
key={index}
target='_blank'
target="_blank"
href={`/api/files/${encodeURIComponent(file.path)}`}
className="group flex items-center justify-between px-2 py-1.5 rounded-lg hover:bg-gray-800 transition-all duration-200"
>
@@ -80,9 +78,7 @@ export default function FilesWidget({ jobID }: { jobID: string }) {
)}
{['success', 'refreshing'].includes(status) && files.length === 0 && (
<div className="text-center py-4 text-gray-400 text-sm">
No checkpoints available
</div>
<div className="text-center py-4 text-gray-400 text-sm">No checkpoints available</div>
)}
</div>
</div>

View File

@@ -1,7 +1,8 @@
import React, { useState, useEffect, useRef } from 'react';
import React, { useState, useEffect, useRef, useMemo } from 'react';
import { GPUApiResponse } from '@/types';
import Loading from '@/components/Loading';
import GPUWidget from '@/components/GPUWidget';
import { apiClient } from '@/utils/api';
const GpuMonitor: React.FC = () => {
const [gpuData, setGpuData] = useState<GPUApiResponse | null>(null);
@@ -12,27 +13,26 @@ const GpuMonitor: React.FC = () => {
useEffect(() => {
const fetchGpuInfo = async () => {
try {
if (isFetchingGpuRef.current) {
return;
}
isFetchingGpuRef.current = true;
const response = await fetch('/api/gpu');
if (!response.ok) {
throw new Error(`HTTP error! Status: ${response.status}`);
}
const data: GPUApiResponse = await response.json();
setGpuData(data);
setLastUpdated(new Date());
setError(null);
} catch (err) {
setError(`Failed to fetch GPU data: ${err instanceof Error ? err.message : String(err)}`);
} finally {
isFetchingGpuRef.current = false;
setLoading(false);
if (isFetchingGpuRef.current) {
return;
}
setLoading(true);
isFetchingGpuRef.current = true;
apiClient
.get('/api/gpu')
.then(res => res.data)
.then(data => {
setGpuData(data);
setLastUpdated(new Date());
setError(null);
})
.catch(err => {
setError(`Failed to fetch GPU data: ${err instanceof Error ? err.message : String(err)}`);
})
.finally(() => {
isFetchingGpuRef.current = false;
setLoading(false);
});
};
// Fetch immediately on component mount
@@ -69,46 +69,63 @@ const GpuMonitor: React.FC = () => {
}
};
if (loading) {
return <Loading />;
}
console.log('state', {
loading,
gpuData,
error,
lastUpdated,
});
const content = useMemo(() => {
if (loading && !gpuData) {
return <Loading />;
}
if (error) {
return (
<div className="bg-red-900 border border-red-600 text-red-200 px-4 py-3 rounded relative" role="alert">
<strong className="font-bold">Error!</strong>
<span className="block sm:inline"> {error}</span>
</div>
);
}
if (!gpuData) {
return (
<div className="bg-yellow-900 border border-yellow-700 text-yellow-300 px-4 py-3 rounded relative" role="alert">
<span className="block sm:inline">No GPU data available.</span>
</div>
);
}
if (!gpuData.hasNvidiaSmi) {
return (
<div className="bg-yellow-900 border border-yellow-700 text-yellow-300 px-4 py-3 rounded relative" role="alert">
<strong className="font-bold">No NVIDIA GPUs detected!</strong>
<span className="block sm:inline"> nvidia-smi is not available on this system.</span>
{gpuData.error && <p className="mt-2 text-sm">{gpuData.error}</p>}
</div>
);
}
if (gpuData.gpus.length === 0) {
return (
<div className="bg-yellow-900 border border-yellow-700 text-yellow-300 px-4 py-3 rounded relative" role="alert">
<span className="block sm:inline">No GPUs found, but nvidia-smi is available.</span>
</div>
);
}
const gridClass = getGridClasses(gpuData?.gpus?.length || 1);
if (error) {
return (
<div className="bg-red-100 border border-red-400 text-red-700 px-4 py-3 rounded relative" role="alert">
<strong className="font-bold">Error!</strong>
<span className="block sm:inline"> {error}</span>
<div className={`grid ${gridClass} gap-3`}>
{gpuData.gpus.map((gpu, idx) => (
<GPUWidget key={idx} gpu={gpu} />
))}
</div>
);
}
if (!gpuData) {
return (
<div className="bg-yellow-100 border border-yellow-400 text-yellow-700 px-4 py-3 rounded relative" role="alert">
<span className="block sm:inline">No GPU data available.</span>
</div>
);
}
if (!gpuData.hasNvidiaSmi) {
return (
<div className="bg-yellow-100 border border-yellow-400 text-yellow-700 px-4 py-3 rounded relative" role="alert">
<strong className="font-bold">No NVIDIA GPUs detected!</strong>
<span className="block sm:inline"> nvidia-smi is not available on this system.</span>
{gpuData.error && <p className="mt-2 text-sm">{gpuData.error}</p>}
</div>
);
}
if (gpuData.gpus.length === 0) {
return (
<div className="bg-yellow-100 border border-yellow-400 text-yellow-700 px-4 py-3 rounded relative" role="alert">
<span className="block sm:inline">No GPUs found, but nvidia-smi is available.</span>
</div>
);
}
const gridClass = getGridClasses(gpuData.gpus.length);
}, [loading, gpuData, error]);
return (
<div className="w-full">
@@ -116,12 +133,7 @@ const GpuMonitor: React.FC = () => {
<h1 className="text-md">GPU Monitor</h1>
<div className="text-xs text-gray-500">Last updated: {lastUpdated?.toLocaleTimeString()}</div>
</div>
<div className={`grid ${gridClass} gap-3`}>
{gpuData.gpus.map((gpu, idx) => (
<GPUWidget key={idx} gpu={gpu} />
))}
</div>
{content}
</div>
);
};

View File

@@ -24,9 +24,7 @@ export default function GPUWidget({ gpu }: GPUWidgetProps) {
<div className="bg-gray-800 px-4 py-3 flex items-center justify-between">
<div className="flex items-center space-x-2">
<h2 className="font-semibold text-gray-100">{gpu.name}</h2>
<span className="px-2 py-0.5 bg-gray-700 rounded-full text-xs text-gray-300">
#{gpu.index}
</span>
<span className="px-2 py-0.5 bg-gray-700 rounded-full text-xs text-gray-300">#{gpu.index}</span>
</div>
</div>
@@ -38,18 +36,14 @@ export default function GPUWidget({ gpu }: GPUWidgetProps) {
<Thermometer className={`w-4 h-4 ${getTemperatureColor(gpu.temperature)}`} />
<div>
<p className="text-xs text-gray-400">Temperature</p>
<p className={`text-sm font-medium ${getTemperatureColor(gpu.temperature)}`}>
{gpu.temperature}°C
</p>
<p className={`text-sm font-medium ${getTemperatureColor(gpu.temperature)}`}>{gpu.temperature}°C</p>
</div>
</div>
<div className="flex items-center space-x-2">
<Fan className="w-4 h-4 text-blue-400" />
<div>
<p className="text-xs text-gray-400">Fan Speed</p>
<p className="text-sm font-medium text-blue-400">
{gpu.fan.speed}%
</p>
<p className="text-sm font-medium text-blue-400">{gpu.fan.speed}%</p>
</div>
</div>
</div>
@@ -99,7 +93,7 @@ export default function GPUWidget({ gpu }: GPUWidgetProps) {
<p className="text-xs text-gray-400">Power Draw</p>
<p className="text-sm text-gray-200">
{gpu.power.draw?.toFixed(1)}W
<span className="text-gray-400 text-xs"> / {gpu.power.limit?.toFixed(1) || " ? "}W</span>
<span className="text-gray-400 text-xs"> / {gpu.power.limit?.toFixed(1) || ' ? '}W</span>
</p>
</div>
</div>

View File

@@ -45,7 +45,9 @@ export default function JobOverview({ job }: JobOverviewProps) {
{/* Job Information Panel */}
<div className="col-span-2 bg-gray-900 rounded-xl shadow-lg overflow-hidden border border-gray-800">
<div className="bg-gray-800 px-4 py-3 flex items-center justify-between">
<h2 className="text-gray-100"><Info className="w-5 h-5 mr-2 -mt-1 text-amber-400 inline-block" /> {job.info}</h2>
<h2 className="text-gray-100">
<Info className="w-5 h-5 mr-2 -mt-1 text-amber-400 inline-block" /> {job.info}
</h2>
<span className={`px-3 py-1 rounded-full text-sm ${getStatusColor(job.status)}`}>{job.status}</span>
</div>
@@ -85,7 +87,7 @@ export default function JobOverview({ job }: JobOverviewProps) {
<Gauge className="w-5 h-5 text-green-400" />
<div>
<p className="text-xs text-gray-400">Speed</p>
<p className="text-sm font-medium text-gray-200">{job.speed_string == "" ? "?" : job.speed_string}</p>
<p className="text-sm font-medium text-gray-200">{job.speed_string == '' ? '?' : job.speed_string}</p>
</div>
</div>
</div>

View File

@@ -1,6 +1,6 @@
export default function Loading() {
return (
<div className="flex justify-center items-center h-64">
<div className="flex justify-center items-center h-42">
<div className="animate-spin rounded-full h-12 w-12 border-t-2 border-b-2 border-blue-500"></div>
</div>
);

View File

@@ -11,7 +11,14 @@ interface SampleImageCardProps {
onDelete?: () => void;
}
const SampleImageCard: React.FC<SampleImageCardProps> = ({ imageUrl, alt, numSamples, sampleImages, children, className = '' }) => {
const SampleImageCard: React.FC<SampleImageCardProps> = ({
imageUrl,
alt,
numSamples,
sampleImages,
children,
className = '',
}) => {
const cardRef = useRef<HTMLDivElement>(null);
const [isVisible, setIsVisible] = useState<boolean>(false);
const [loaded, setLoaded] = useState<boolean>(false);

View File

@@ -1,9 +1,10 @@
import Link from 'next/link';
import { Home, Settings, BrainCircuit, Images } from 'lucide-react';
import { Home, Settings, BrainCircuit, Images, Plus } from 'lucide-react';
const Sidebar = () => {
const navigation = [
{ name: 'Dashboard', href: '/dashboard', icon: Home },
{ name: 'New Job', href: '/jobs/new', icon: Plus },
{ name: 'Training Jobs', href: '/jobs', icon: BrainCircuit },
{ name: 'Datasets', href: '/datasets', icon: Images },
{ name: 'Settings', href: '/settings', icon: Settings },
@@ -33,7 +34,7 @@ const Sidebar = () => {
</ul>
</nav>
<a href="https://patreon.com/ostris" target="_blank" rel="noreferrer" className="flex items-center space-x-2 p-4">
<div className='min-w-[26px] min-h-[26px]'>
<div className="min-w-[26px] min-h-[26px]">
<svg
viewBox="0 0 512 512"
xmlns="http://www.w3.org/2000/svg"

View File

@@ -152,14 +152,14 @@ export const Checkbox = (props: CheckboxProps) => {
className={classNames(
'relative inline-flex h-6 w-11 flex-shrink-0 cursor-pointer rounded-full border-2 border-transparent transition-colors duration-200 ease-in-out focus:outline-none focus:ring-2 focus:ring-blue-600 focus:ring-offset-2',
checked ? 'bg-blue-600' : 'bg-gray-700',
disabled ? 'opacity-50 cursor-not-allowed' : 'hover:bg-opacity-80'
disabled ? 'opacity-50 cursor-not-allowed' : 'hover:bg-opacity-80',
)}
>
<span className="sr-only">Toggle {label}</span>
<span
className={classNames(
'pointer-events-none inline-block h-5 w-5 transform rounded-full bg-white shadow ring-0 transition duration-200 ease-in-out',
checked ? 'translate-x-5' : 'translate-x-0'
checked ? 'translate-x-5' : 'translate-x-0',
)}
/>
</button>
@@ -168,7 +168,7 @@ export const Checkbox = (props: CheckboxProps) => {
htmlFor={id}
className={classNames(
'text-sm font-medium cursor-pointer select-none',
disabled ? 'text-gray-500' : 'text-gray-300'
disabled ? 'text-gray-500' : 'text-gray-300',
)}
>
{label}

View File

@@ -1,6 +1,7 @@
'use client';
import { useEffect, useState } from 'react';
import { apiClient } from '@/utils/api';
export default function useDatasetList() {
const [datasets, setDatasets] = useState<string[]>([]);
@@ -8,8 +9,9 @@ export default function useDatasetList() {
const refreshDatasets = () => {
setStatus('loading');
fetch('/api/datasets/list')
.then(res => res.json())
apiClient
.get('/api/datasets/list')
.then(res => res.data)
.then(data => {
console.log('Datasets:', data);
// sort

View File

@@ -1,6 +1,7 @@
'use client';
import { useEffect, useState, useRef } from 'react';
import { apiClient } from '@/utils/api';
interface FileObject {
path: string;
@@ -18,8 +19,9 @@ export default function useFilesList(jobID: string, reloadInterval: null | numbe
loadStatus = 'refreshing';
}
setStatus(loadStatus);
fetch(`/api/jobs/${jobID}/files`)
.then(res => res.json())
apiClient
.get(`/api/jobs/${jobID}/files`)
.then(res => res.data)
.then(data => {
console.log('Fetched files:', data);
if (data.files) {

View File

@@ -2,6 +2,7 @@
import { GPUApiResponse, GpuInfo } from '@/types';
import { useEffect, useState } from 'react';
import { apiClient } from '@/utils/api';
export default function useGPUInfo(gpuIds: null | number[] = null, reloadInterval: null | number = null) {
const [gpuList, setGpuList] = useState<GpuInfo[]>([]);
@@ -11,18 +12,11 @@ export default function useGPUInfo(gpuIds: null | number[] = null, reloadInterva
const fetchGpuInfo = async () => {
setStatus('loading');
try {
const response = await fetch('/api/gpu');
if (!response.ok) {
throw new Error(`HTTP error! Status: ${response.status}`);
}
const data: GPUApiResponse = await response.json();
const data: GPUApiResponse = await apiClient.get('/api/gpu').then(res => res.data);
let gpus = data.gpus.sort((a, b) => a.index - b.index);
if (gpuIds) {
gpus = gpus.filter(gpu => gpuIds.includes(gpu.index));
}
setGpuList(gpus);
setStatus('success');
} catch (err) {

View File

@@ -2,6 +2,7 @@
import { useEffect, useState } from 'react';
import { Job } from '@prisma/client';
import { apiClient } from '@/utils/api';
export default function useJob(jobID: string, reloadInterval: null | number = null) {
const [job, setJob] = useState<Job | null>(null);
@@ -9,8 +10,9 @@ export default function useJob(jobID: string, reloadInterval: null | number = nu
const refreshJob = () => {
setStatus('loading');
fetch(`/api/jobs?id=${jobID}`)
.then(res => res.json())
apiClient
.get(`/api/jobs?id=${jobID}`)
.then(res => res.data)
.then(data => {
console.log('Job:', data);
setJob(data);
@@ -32,7 +34,7 @@ export default function useJob(jobID: string, reloadInterval: null | number = nu
return () => {
clearInterval(interval);
}
};
}
}, [jobID]);

View File

@@ -2,6 +2,7 @@
import { useEffect, useState } from 'react';
import { Job } from '@prisma/client';
import { apiClient } from '@/utils/api';
export default function useJobsList(onlyActive = false) {
const [jobs, setJobs] = useState<Job[]>([]);
@@ -9,8 +10,9 @@ export default function useJobsList(onlyActive = false) {
const refreshJobs = () => {
setStatus('loading');
fetch('/api/jobs')
.then(res => res.json())
apiClient
.get('/api/jobs')
.then(res => res.data)
.then(data => {
console.log('Jobs:', data);
if (data.error) {

View File

@@ -1,6 +1,7 @@
'use client';
import { useEffect, useState } from 'react';
import { apiClient } from '@/utils/api';
export default function useSampleImages(jobID: string, reloadInterval: null | number = null) {
const [sampleImages, setSampleImages] = useState<string[]>([]);
@@ -8,9 +9,11 @@ export default function useSampleImages(jobID: string, reloadInterval: null | nu
const refreshSampleImages = () => {
setStatus('loading');
fetch(`/api/jobs/${jobID}/samples`)
.then(res => res.json())
apiClient
.get(`/api/jobs/${jobID}/samples`)
.then(res => res.data)
.then(data => {
console.log('Fetched sample images:', data);
if (data.samples) {
setSampleImages(data.samples);
}

View File

@@ -1,6 +1,7 @@
'use client';
import { useEffect, useState } from 'react';
import { apiClient } from '@/utils/api';
export interface Settings {
HF_TOKEN: string;
@@ -16,10 +17,11 @@ export default function useSettings() {
});
const [isSettingsLoaded, setIsLoaded] = useState(false);
useEffect(() => {
// Fetch current settings
fetch('/api/settings')
.then(res => res.json())
apiClient
.get('/api/settings')
.then(res => res.data)
.then(data => {
console.log('Settings:', data);
setSettings({
HF_TOKEN: data.HF_TOKEN || '',
TRAINING_FOLDER: data.TRAINING_FOLDER || '',

49
ui/src/middleware.ts Normal file
View File

@@ -0,0 +1,49 @@
// middleware.ts (at the root of your project)
import { NextResponse } from 'next/server';
import type { NextRequest } from 'next/server';
// if route starts with these, approve
const publicRoutes = ['/api/img/', '/api/files/'];
export function middleware(request: NextRequest) {
// check env var for AI_TOOLKIT_AUTH, if not set, approve all requests
// if it is set make sure bearer token matches
const tokenToUse = process.env.AI_TOOLKIT_AUTH || null;
if (!tokenToUse) {
return NextResponse.next();
}
// Get the token from the headers
const token = request.headers.get('Authorization')?.split(' ')[1];
// allow public routes to pass through
if (publicRoutes.some(route => request.nextUrl.pathname.startsWith(route))) {
return NextResponse.next();
}
// Check if the route should be protected
// This will apply to all API routes that start with /api/
if (request.nextUrl.pathname.startsWith('/api/')) {
if (!token || token !== tokenToUse) {
// Return a JSON response with 401 Unauthorized
return new NextResponse(JSON.stringify({ error: 'Unauthorized' }), {
status: 401,
headers: { 'Content-Type': 'application/json' },
});
}
// For authorized users, continue
return NextResponse.next();
}
// For non-API routes, just continue
return NextResponse.next();
}
// Configure which paths this middleware will run on
export const config = {
matcher: [
// Apply to all API routes
'/api/:path*',
],
};

31
ui/src/utils/api.ts Normal file
View File

@@ -0,0 +1,31 @@
import axios from 'axios';
import { createGlobalState } from 'react-global-hooks';
export const isAuthorizedState = createGlobalState(false);
export const apiClient = axios.create();
// Add a request interceptor to add token from localStorage
apiClient.interceptors.request.use(config => {
const token = localStorage.getItem('AI_TOOLKIT_AUTH');
if (token) {
config.headers['Authorization'] = `Bearer ${token}`;
}
return config;
});
// Add a response interceptor to handle 401 errors
apiClient.interceptors.response.use(
response => response, // Return successful responses as-is
error => {
// Check if the error is a 401 Unauthorized
if (error.response && error.response.status === 401) {
// Clear the auth token from localStorage
localStorage.removeItem('AI_TOOLKIT_AUTH');
isAuthorizedState.set(false);
}
// Reject the promise with the error so calling code can still catch it
return Promise.reject(error);
},
);

View File

@@ -2,3 +2,4 @@ export const objectCopy = <T>(obj: T): T => {
return JSON.parse(JSON.stringify(obj)) as T;
};
export const wait = (ms: number) => new Promise(resolve => setTimeout(resolve, ms));

View File

@@ -79,7 +79,7 @@ export function useNestedState<T>(initialState: T): [T, (value: any, path?: stri
const setValue = React.useCallback((value: any, path?: string) => {
if (path === undefined) {
setState(value);
return
return;
}
setState(prevState => setNestedValue(prevState, value, path));
}, []);

View File

@@ -1,10 +1,12 @@
import { JobConfig } from '@/types';
import { Job } from '@prisma/client';
import { apiClient } from '@/utils/api';
export const startJob = (jobID: string) => {
return new Promise<void>((resolve, reject) => {
fetch(`/api/jobs/${jobID}/start`)
.then(res => res.json())
apiClient
.get(`/api/jobs/${jobID}/start`)
.then(res => res.data)
.then(data => {
console.log('Job started:', data);
resolve();
@@ -18,8 +20,9 @@ export const startJob = (jobID: string) => {
export const stopJob = (jobID: string) => {
return new Promise<void>((resolve, reject) => {
fetch(`/api/jobs/${jobID}/stop`)
.then(res => res.json())
apiClient
.get(`/api/jobs/${jobID}/stop`)
.then(res => res.data)
.then(data => {
console.log('Job stopped:', data);
resolve();
@@ -33,8 +36,9 @@ export const stopJob = (jobID: string) => {
export const deleteJob = (jobID: string) => {
return new Promise<void>((resolve, reject) => {
fetch(`/api/jobs/${jobID}/delete`)
.then(res => res.json())
apiClient
.get(`/api/jobs/${jobID}/delete`)
.then(res => res.data)
.then(data => {
console.log('Job deleted:', data);
resolve();
@@ -67,9 +71,9 @@ export const getAvaliableJobActions = (job: Job) => {
export const getNumberOfSamples = (job: Job) => {
const jobConfig = getJobConfig(job);
return jobConfig.config.process[0].sample?.prompts?.length || 0;
}
};
export const getTotalSteps = (job: Job) => {
const jobConfig = getJobConfig(job);
return jobConfig.config.process[0].train.steps;
}
};

View File

@@ -1,26 +1,26 @@
import type { Config } from "tailwindcss";
import type { Config } from 'tailwindcss';
const config: Config = {
content: [
"./src/pages/**/*.{js,ts,jsx,tsx,mdx}",
"./src/components/**/*.{js,ts,jsx,tsx,mdx}",
"./src/app/**/*.{js,ts,jsx,tsx,mdx}",
'./src/pages/**/*.{js,ts,jsx,tsx,mdx}',
'./src/components/**/*.{js,ts,jsx,tsx,mdx}',
'./src/app/**/*.{js,ts,jsx,tsx,mdx}',
],
darkMode: "class",
darkMode: 'class',
theme: {
extend: {
colors: {
gray: {
950: "#0a0a0a",
900: "#171717",
800: "#262626",
700: "#404040",
600: "#525252",
500: "#737373",
400: "#a3a3a3",
300: "#d4d4d4",
200: "#e5e5e5",
100: "#f5f5f5",
950: '#0a0a0a',
900: '#171717',
800: '#262626',
700: '#404040',
600: '#525252',
500: '#737373',
400: '#a3a3a3',
300: '#d4d4d4',
200: '#e5e5e5',
100: '#f5f5f5',
},
},
},