mirror of
https://github.com/ostris/ai-toolkit.git
synced 2026-01-26 16:39:47 +00:00
Added a way to secure the UI. Plus various bug fixes and quality of life updates
This commit is contained in:
@@ -15,7 +15,8 @@ class UITrainer(SDTrainer):
|
||||
super(UITrainer, self).__init__(process_id, job, config, **kwargs)
|
||||
self.sqlite_db_path = self.config.get("sqlite_db_path", "./aitk_db.db")
|
||||
if not os.path.exists(self.sqlite_db_path):
|
||||
raise Exception(f"SQLite database not found at {self.sqlite_db_path}")
|
||||
raise Exception(
|
||||
f"SQLite database not found at {self.sqlite_db_path}")
|
||||
print(f"Using SQLite database at {self.sqlite_db_path}")
|
||||
self.job_id = os.environ.get("AITK_JOB_ID", None)
|
||||
self.job_id = self.job_id.strip() if self.job_id is not None else None
|
||||
@@ -147,6 +148,8 @@ class UITrainer(SDTrainer):
|
||||
|
||||
try:
|
||||
await asyncio.gather(*self._async_tasks)
|
||||
except Exception as e:
|
||||
print(f"Error waiting for async operations: {e}")
|
||||
finally:
|
||||
# Clear the task list after completion
|
||||
self._async_tasks.clear()
|
||||
|
||||
@@ -6,6 +6,7 @@
|
||||
"dev": "next dev --turbopack",
|
||||
"build": "next build",
|
||||
"start": "next start --port 8675",
|
||||
"build_and_start": "npm install && npm run update_db && npm run build && npm run start",
|
||||
"lint": "next lint",
|
||||
"update_db": "npx prisma generate && npx prisma db push",
|
||||
"format": "prettier --write \"**/*.{js,jsx,ts,tsx,css,scss}\""
|
||||
|
||||
6
ui/src/app/api/auth/route.ts
Normal file
6
ui/src/app/api/auth/route.ts
Normal file
@@ -0,0 +1,6 @@
|
||||
import { NextResponse } from 'next/server';
|
||||
|
||||
export async function GET() {
|
||||
// if this gets hit, auth has already been verified
|
||||
return NextResponse.json({ isAuthenticated: true });
|
||||
}
|
||||
@@ -19,4 +19,4 @@ export async function POST(request: Request) {
|
||||
} catch (error) {
|
||||
return NextResponse.json({ error: 'Failed to create dataset' }, { status: 500 });
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
@@ -8,31 +8,25 @@ export async function POST(request: Request) {
|
||||
const body = await request.json();
|
||||
const { datasetName } = body;
|
||||
const datasetFolder = path.join(datasetsPath, datasetName);
|
||||
|
||||
|
||||
try {
|
||||
// Check if folder exists
|
||||
if (!fs.existsSync(datasetFolder)) {
|
||||
return NextResponse.json(
|
||||
{ error: `Folder '${datasetName}' not found` },
|
||||
{ status: 404 }
|
||||
);
|
||||
return NextResponse.json({ error: `Folder '${datasetName}' not found` }, { status: 404 });
|
||||
}
|
||||
|
||||
// Find all images recursively
|
||||
const imageFiles = findImagesRecursively(datasetFolder);
|
||||
|
||||
|
||||
// Format response
|
||||
const result = imageFiles.map(imgPath => ({
|
||||
img_path: imgPath
|
||||
img_path: imgPath,
|
||||
}));
|
||||
|
||||
|
||||
return NextResponse.json({ images: result });
|
||||
} catch (error) {
|
||||
console.error('Error finding images:', error);
|
||||
return NextResponse.json(
|
||||
{ error: 'Failed to process request' },
|
||||
{ status: 500 }
|
||||
);
|
||||
return NextResponse.json({ error: 'Failed to process request' }, { status: 500 });
|
||||
}
|
||||
}
|
||||
|
||||
@@ -44,13 +38,13 @@ export async function POST(request: Request) {
|
||||
function findImagesRecursively(dir: string): string[] {
|
||||
const imageExtensions = ['.png', '.jpg', '.jpeg', '.webp'];
|
||||
let results: string[] = [];
|
||||
|
||||
|
||||
const items = fs.readdirSync(dir);
|
||||
|
||||
|
||||
for (const item of items) {
|
||||
const itemPath = path.join(dir, item);
|
||||
const stat = fs.statSync(itemPath);
|
||||
|
||||
|
||||
if (stat.isDirectory()) {
|
||||
// If it's a directory, recursively search it
|
||||
results = results.concat(findImagesRecursively(itemPath));
|
||||
@@ -62,6 +56,6 @@ function findImagesRecursively(dir: string): string[] {
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
|
||||
return results;
|
||||
}
|
||||
}
|
||||
|
||||
@@ -16,7 +16,8 @@ export async function GET(request: NextRequest, { params }: { params: { filePath
|
||||
const allowedDirs = [datasetRoot, trainingRoot];
|
||||
|
||||
// Security check: Ensure path is in allowed directory
|
||||
const isAllowed = allowedDirs.some(allowedDir => decodedFilePath.startsWith(allowedDir)) && !decodedFilePath.includes('..');
|
||||
const isAllowed =
|
||||
allowedDirs.some(allowedDir => decodedFilePath.startsWith(allowedDir)) && !decodedFilePath.includes('..');
|
||||
|
||||
if (!isAllowed) {
|
||||
console.warn(`Access denied: ${decodedFilePath} not in ${allowedDirs.join(', ')}`);
|
||||
@@ -62,7 +63,7 @@ export async function GET(request: NextRequest, { params }: { params: { filePath
|
||||
'Accept-Ranges': 'bytes',
|
||||
'Cache-Control': 'public, max-age=86400',
|
||||
'Content-Disposition': `attachment; filename="${encodeURIComponent(filename)}"`,
|
||||
'X-Content-Type-Options': 'nosniff'
|
||||
'X-Content-Type-Options': 'nosniff',
|
||||
};
|
||||
|
||||
if (range) {
|
||||
@@ -70,12 +71,12 @@ export async function GET(request: NextRequest, { params }: { params: { filePath
|
||||
const parts = range.replace(/bytes=/, '').split('-');
|
||||
const start = parseInt(parts[0], 10);
|
||||
const end = parts[1] ? parseInt(parts[1], 10) : Math.min(start + 10 * 1024 * 1024, stat.size - 1); // 10MB chunks
|
||||
const chunkSize = (end - start) + 1;
|
||||
const chunkSize = end - start + 1;
|
||||
|
||||
const fileStream = fs.createReadStream(decodedFilePath, {
|
||||
start,
|
||||
const fileStream = fs.createReadStream(decodedFilePath, {
|
||||
start,
|
||||
end,
|
||||
highWaterMark: 64 * 1024 // 64KB buffer
|
||||
highWaterMark: 64 * 1024, // 64KB buffer
|
||||
});
|
||||
|
||||
return new NextResponse(fileStream as any, {
|
||||
@@ -83,19 +84,19 @@ export async function GET(request: NextRequest, { params }: { params: { filePath
|
||||
headers: {
|
||||
...commonHeaders,
|
||||
'Content-Range': `bytes ${start}-${end}/${stat.size}`,
|
||||
'Content-Length': String(chunkSize)
|
||||
'Content-Length': String(chunkSize),
|
||||
},
|
||||
});
|
||||
} else {
|
||||
// For full file download, read directly without streaming wrapper
|
||||
const fileStream = fs.createReadStream(decodedFilePath, {
|
||||
highWaterMark: 64 * 1024 // 64KB buffer
|
||||
highWaterMark: 64 * 1024, // 64KB buffer
|
||||
});
|
||||
|
||||
return new NextResponse(fileStream as any, {
|
||||
headers: {
|
||||
...commonHeaders,
|
||||
'Content-Length': String(stat.size)
|
||||
'Content-Length': String(stat.size),
|
||||
},
|
||||
});
|
||||
}
|
||||
@@ -103,4 +104,4 @@ export async function GET(request: NextRequest, { params }: { params: { filePath
|
||||
console.error('Error serving file:', error);
|
||||
return new NextResponse('Internal Server Error', { status: 500 });
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
@@ -10,7 +10,7 @@ export async function GET() {
|
||||
// Get platform
|
||||
const platform = os.platform();
|
||||
const isWindows = platform === 'win32';
|
||||
|
||||
|
||||
// Check if nvidia-smi is available
|
||||
const hasNvidiaSmi = await checkNvidiaSmi(isWindows);
|
||||
|
||||
@@ -61,8 +61,9 @@ async function checkNvidiaSmi(isWindows: boolean): Promise<boolean> {
|
||||
|
||||
async function getGpuStats(isWindows: boolean) {
|
||||
// Command is the same for both platforms, but the path might be different
|
||||
const command = 'nvidia-smi --query-gpu=index,name,driver_version,temperature.gpu,utilization.gpu,utilization.memory,memory.total,memory.free,memory.used,power.draw,power.limit,clocks.current.graphics,clocks.current.memory,fan.speed --format=csv,noheader,nounits';
|
||||
|
||||
const command =
|
||||
'nvidia-smi --query-gpu=index,name,driver_version,temperature.gpu,utilization.gpu,utilization.memory,memory.total,memory.free,memory.used,power.draw,power.limit,clocks.current.graphics,clocks.current.memory,fan.speed --format=csv,noheader,nounits';
|
||||
|
||||
// Execute command
|
||||
const { stdout } = await execAsync(command);
|
||||
|
||||
@@ -117,4 +118,4 @@ async function getGpuStats(isWindows: boolean) {
|
||||
});
|
||||
|
||||
return gpus;
|
||||
}
|
||||
}
|
||||
|
||||
@@ -17,7 +17,6 @@ export async function POST(request: Request) {
|
||||
return NextResponse.json({ error: 'Image does not exist' }, { status: 404 });
|
||||
}
|
||||
|
||||
|
||||
// check for caption
|
||||
const captionPath = imgPath.replace(/\.[^/.]+$/, '') + '.txt';
|
||||
// save caption to file
|
||||
|
||||
@@ -103,7 +103,7 @@ export async function GET(request: NextRequest, { params }: { params: { jobID: s
|
||||
cwd: TOOLKIT_ROOT,
|
||||
windowsHide: false,
|
||||
});
|
||||
|
||||
|
||||
subprocess.unref();
|
||||
} else {
|
||||
// For non-Windows platforms, use your original approach
|
||||
@@ -116,7 +116,7 @@ export async function GET(request: NextRequest, { params }: { params: { jobID: s
|
||||
},
|
||||
cwd: TOOLKIT_ROOT,
|
||||
});
|
||||
|
||||
|
||||
subprocess.unref();
|
||||
}
|
||||
// const subprocess = spawn(pythonPath, [runFilePath, configPath], {
|
||||
|
||||
@@ -6,7 +6,6 @@ const prisma = new PrismaClient();
|
||||
export async function GET(request: Request) {
|
||||
const { searchParams } = new URL(request.url);
|
||||
const id = searchParams.get('id');
|
||||
console.log('ID:', id);
|
||||
|
||||
try {
|
||||
if (id) {
|
||||
@@ -19,7 +18,6 @@ export async function GET(request: Request) {
|
||||
const jobs = await prisma.job.findMany({
|
||||
orderBy: { created_at: 'desc' },
|
||||
});
|
||||
console.log('Jobs:', jobs);
|
||||
return NextResponse.json({ jobs: jobs });
|
||||
} catch (error) {
|
||||
console.error(error);
|
||||
|
||||
@@ -1,7 +1,7 @@
|
||||
import { NextResponse } from 'next/server';
|
||||
import { PrismaClient } from '@prisma/client';
|
||||
import { defaultTrainFolder, defaultDatasetsFolder } from '@/paths';
|
||||
import {flushCache} from '@/server/settings';
|
||||
import { flushCache } from '@/server/settings';
|
||||
|
||||
const prisma = new PrismaClient();
|
||||
|
||||
|
||||
@@ -6,6 +6,7 @@ import DatasetImageCard from '@/components/DatasetImageCard';
|
||||
import { Button } from '@headlessui/react';
|
||||
import AddImagesModal, { openImagesModal } from '@/components/AddImagesModal';
|
||||
import { TopBar, MainContent } from '@/components/layout';
|
||||
import { apiClient } from '@/utils/api';
|
||||
|
||||
export default function DatasetPage({ params }: { params: { datasetName: string } }) {
|
||||
const [imgList, setImgList] = useState<{ img_path: string }[]>([]);
|
||||
@@ -15,15 +16,11 @@ export default function DatasetPage({ params }: { params: { datasetName: string
|
||||
|
||||
const refreshImageList = (dbName: string) => {
|
||||
setStatus('loading');
|
||||
fetch('/api/datasets/listImages', {
|
||||
method: 'POST',
|
||||
headers: {
|
||||
'Content-Type': 'application/json',
|
||||
},
|
||||
body: JSON.stringify({ datasetName: dbName }),
|
||||
})
|
||||
.then(res => res.json())
|
||||
.then(data => {
|
||||
console.log('Fetching images for dataset:', dbName);
|
||||
apiClient
|
||||
.post('/api/datasets/listImages', { datasetName: dbName })
|
||||
.then((res: any) => {
|
||||
const data = res.data;
|
||||
console.log('Images:', data.images);
|
||||
// sort
|
||||
data.images.sort((a: { img_path: string }, b: { img_path: string }) => a.img_path.localeCompare(b.img_path));
|
||||
|
||||
@@ -10,6 +10,7 @@ import { FaRegTrashAlt } from 'react-icons/fa';
|
||||
import { openConfirm } from '@/components/ConfirmModal';
|
||||
import { TopBar, MainContent } from '@/components/layout';
|
||||
import UniversalTable, { TableColumn } from '@/components/UniversalTable';
|
||||
import { apiClient } from '@/utils/api';
|
||||
|
||||
export default function Datasets() {
|
||||
const { datasets, status, refreshDatasets } = useDatasetList();
|
||||
@@ -54,16 +55,10 @@ export default function Datasets() {
|
||||
type: 'warning',
|
||||
confirmText: 'Delete',
|
||||
onConfirm: () => {
|
||||
fetch('/api/datasets/delete', {
|
||||
method: 'POST',
|
||||
headers: {
|
||||
'Content-Type': 'application/json',
|
||||
},
|
||||
body: JSON.stringify({ name: datasetName }),
|
||||
})
|
||||
.then(res => res.json())
|
||||
.then(data => {
|
||||
console.log('Dataset deleted:', data);
|
||||
apiClient
|
||||
.post('/api/datasets/delete', { name: datasetName })
|
||||
.then(() => {
|
||||
console.log('Dataset deleted:', datasetName);
|
||||
refreshDatasets();
|
||||
})
|
||||
.catch(error => {
|
||||
@@ -76,14 +71,7 @@ export default function Datasets() {
|
||||
const handleCreateDataset = async (e: React.FormEvent) => {
|
||||
e.preventDefault();
|
||||
try {
|
||||
const response = await fetch('/api/datasets/create', {
|
||||
method: 'POST',
|
||||
headers: {
|
||||
'Content-Type': 'application/json',
|
||||
},
|
||||
body: JSON.stringify({ name: newDatasetName }),
|
||||
});
|
||||
const data = await response.json();
|
||||
const data = await apiClient.post('/api/datasets/create', { name: newDatasetName }).then(res => res.data);
|
||||
console.log('New dataset created:', data);
|
||||
refreshDatasets();
|
||||
setNewDatasetName('');
|
||||
|
||||
@@ -33,11 +33,7 @@ const yamlConfig: YAML.DocumentOptions &
|
||||
directives: true,
|
||||
};
|
||||
|
||||
export default function AdvancedJob({
|
||||
jobConfig,
|
||||
setJobConfig,
|
||||
settings,
|
||||
}: Props) {
|
||||
export default function AdvancedJob({ jobConfig, setJobConfig, settings }: Props) {
|
||||
const [editorValue, setEditorValue] = useState<string>('');
|
||||
const lastJobConfigUpdateStringRef = useRef('');
|
||||
const editorRef = useRef<editor.IStandaloneCodeEditor | null>(null);
|
||||
|
||||
@@ -30,7 +30,7 @@ export const defaultJobConfig: JobConfig = {
|
||||
linear: 32,
|
||||
linear_alpha: 32,
|
||||
lokr_full_rank: true,
|
||||
lokr_factor: -1
|
||||
lokr_factor: -1,
|
||||
},
|
||||
save: {
|
||||
dtype: 'bf16',
|
||||
@@ -39,9 +39,7 @@ export const defaultJobConfig: JobConfig = {
|
||||
save_format: 'diffusers',
|
||||
push_to_hub: false,
|
||||
},
|
||||
datasets: [
|
||||
defaultDatasetConfig
|
||||
],
|
||||
datasets: [defaultDatasetConfig],
|
||||
train: {
|
||||
batch_size: 1,
|
||||
bypass_guidance_embedding: true,
|
||||
@@ -55,7 +53,7 @@ export const defaultJobConfig: JobConfig = {
|
||||
timestep_type: 'sigmoid',
|
||||
content_or_style: 'balanced',
|
||||
optimizer_params: {
|
||||
weight_decay: 1e-4
|
||||
weight_decay: 1e-4,
|
||||
},
|
||||
unload_text_encoder: false,
|
||||
lr: 0.0001,
|
||||
@@ -66,8 +64,7 @@ export const defaultJobConfig: JobConfig = {
|
||||
dtype: 'bf16',
|
||||
diff_output_preservation: false,
|
||||
diff_output_preservation_multiplier: 1.0,
|
||||
diff_output_preservation_class: 'person'
|
||||
|
||||
diff_output_preservation_class: 'person',
|
||||
},
|
||||
model: {
|
||||
name_or_path: 'ostris/Flex.1-alpha',
|
||||
|
||||
@@ -12,7 +12,7 @@ export const modelArchs = [
|
||||
{ name: 'flux', label: 'Flux.1' },
|
||||
{ name: 'wan21', label: 'Wan 2.1' },
|
||||
{ name: 'lumina2', label: 'Lumina2' },
|
||||
]
|
||||
];
|
||||
|
||||
export const isVideoModelFromArch = (arch: string) => {
|
||||
const videoArches = ['wan21'];
|
||||
|
||||
@@ -19,6 +19,7 @@ import { Button } from '@headlessui/react';
|
||||
import { FaChevronLeft } from 'react-icons/fa';
|
||||
import SimpleJob from './SimpleJob';
|
||||
import AdvancedJob from './AdvancedJob';
|
||||
import { apiClient } from '@/utils/api';
|
||||
|
||||
const isDev = process.env.NODE_ENV === 'development';
|
||||
|
||||
@@ -56,12 +57,13 @@ export default function TrainingForm() {
|
||||
|
||||
useEffect(() => {
|
||||
if (runId) {
|
||||
fetch(`/api/jobs?id=${runId}`)
|
||||
.then(res => res.json())
|
||||
apiClient
|
||||
.get(`/api/jobs?id=${runId}`)
|
||||
.then(res => res.data)
|
||||
.then(data => {
|
||||
console.log('Training:', data);
|
||||
setGpuIDs(data.gpu_ids);
|
||||
setJobConfig(JSON.parse(data.job_config));
|
||||
// setJobConfig(data.name, 'config.name');
|
||||
})
|
||||
.catch(error => console.error('Error fetching training:', error));
|
||||
}
|
||||
@@ -85,33 +87,30 @@ export default function TrainingForm() {
|
||||
if (status === 'saving') return;
|
||||
setStatus('saving');
|
||||
|
||||
try {
|
||||
const response = await fetch('/api/jobs', {
|
||||
method: 'POST',
|
||||
headers: {
|
||||
'Content-Type': 'application/json',
|
||||
},
|
||||
body: JSON.stringify({
|
||||
id: runId,
|
||||
name: jobConfig.config.name,
|
||||
gpu_ids: gpuIDs,
|
||||
job_config: jobConfig,
|
||||
}),
|
||||
});
|
||||
|
||||
if (!response.ok) throw new Error('Failed to save training');
|
||||
|
||||
setStatus('success');
|
||||
if (!runId) {
|
||||
const data = await response.json();
|
||||
router.push(`/jobs/${data.id}`);
|
||||
}
|
||||
setTimeout(() => setStatus('idle'), 2000);
|
||||
} catch (error) {
|
||||
console.error('Error saving training:', error);
|
||||
setStatus('error');
|
||||
setTimeout(() => setStatus('idle'), 2000);
|
||||
}
|
||||
apiClient
|
||||
.post('/api/jobs', {
|
||||
id: runId,
|
||||
name: jobConfig.config.name,
|
||||
gpu_ids: gpuIDs,
|
||||
job_config: jobConfig,
|
||||
})
|
||||
.then(res => {
|
||||
setStatus('success');
|
||||
if (runId) {
|
||||
router.push(`/jobs/${runId}`);
|
||||
} else {
|
||||
router.push(`/jobs/${res.data.id}`);
|
||||
}
|
||||
})
|
||||
.catch(error => {
|
||||
console.error('Error saving training:', error);
|
||||
setStatus('error');
|
||||
})
|
||||
.finally(() =>
|
||||
setTimeout(() => {
|
||||
setStatus('idle');
|
||||
}, 2000),
|
||||
);
|
||||
};
|
||||
|
||||
const handleSubmit = async (e: React.FormEvent) => {
|
||||
|
||||
@@ -13,10 +13,7 @@ export default function Dashboard() {
|
||||
</div>
|
||||
<div className="flex-1"></div>
|
||||
<div>
|
||||
<Link
|
||||
href="/jobs/new"
|
||||
className="text-gray-200 bg-slate-600 px-3 py-1 rounded-md"
|
||||
>
|
||||
<Link href="/jobs/new" className="text-gray-200 bg-slate-600 px-3 py-1 rounded-md">
|
||||
New Training Job
|
||||
</Link>
|
||||
</div>
|
||||
|
||||
@@ -6,6 +6,7 @@ import { ThemeProvider } from '@/components/ThemeProvider';
|
||||
import ConfirmModal from '@/components/ConfirmModal';
|
||||
import SampleImageModal from '@/components/SampleImageModal';
|
||||
import { Suspense } from 'react';
|
||||
import AuthWrapper from '@/components/AuthWrapper';
|
||||
|
||||
const inter = Inter({ subsets: ['latin'] });
|
||||
|
||||
@@ -15,6 +16,9 @@ export const metadata: Metadata = {
|
||||
};
|
||||
|
||||
export default function RootLayout({ children }: { children: React.ReactNode }) {
|
||||
// Check if the AI_TOOLKIT_AUTH environment variable is set
|
||||
const authRequired = process.env.AI_TOOLKIT_AUTH ? true : false;
|
||||
|
||||
return (
|
||||
<html lang="en" className="dark">
|
||||
<head>
|
||||
@@ -22,13 +26,14 @@ export default function RootLayout({ children }: { children: React.ReactNode })
|
||||
</head>
|
||||
<body className={inter.className}>
|
||||
<ThemeProvider>
|
||||
<div className="flex h-screen bg-gray-950">
|
||||
<Sidebar />
|
||||
|
||||
<main className="flex-1 overflow-auto bg-gray-950 text-gray-100 relative">
|
||||
<Suspense>{children}</Suspense>
|
||||
</main>
|
||||
</div>
|
||||
<AuthWrapper authRequired={authRequired}>
|
||||
<div className="flex h-screen bg-gray-950">
|
||||
<Sidebar />
|
||||
<main className="flex-1 overflow-auto bg-gray-950 text-gray-100 relative">
|
||||
<Suspense>{children}</Suspense>
|
||||
</main>
|
||||
</div>
|
||||
</AuthWrapper>
|
||||
</ThemeProvider>
|
||||
<ConfirmModal />
|
||||
<SampleImageModal />
|
||||
|
||||
@@ -2,4 +2,4 @@ import { redirect } from 'next/navigation';
|
||||
|
||||
export default function Home() {
|
||||
redirect('/dashboard');
|
||||
}
|
||||
}
|
||||
|
||||
@@ -3,6 +3,7 @@
|
||||
import { useEffect, useState } from 'react';
|
||||
import useSettings from '@/hooks/useSettings';
|
||||
import { TopBar, MainContent } from '@/components/layout';
|
||||
import { apiClient } from '@/utils/api';
|
||||
|
||||
export default function Settings() {
|
||||
const { settings, setSettings } = useSettings();
|
||||
@@ -12,24 +13,18 @@ export default function Settings() {
|
||||
e.preventDefault();
|
||||
setStatus('saving');
|
||||
|
||||
try {
|
||||
const response = await fetch('/api/settings', {
|
||||
method: 'POST',
|
||||
headers: {
|
||||
'Content-Type': 'application/json',
|
||||
},
|
||||
body: JSON.stringify(settings),
|
||||
apiClient
|
||||
.post('/api/settings', settings)
|
||||
.then(() => {
|
||||
setStatus('success');
|
||||
})
|
||||
.catch(error => {
|
||||
console.error('Error saving settings:', error);
|
||||
setStatus('error');
|
||||
})
|
||||
.finally(() => {
|
||||
setTimeout(() => setStatus('idle'), 2000);
|
||||
});
|
||||
|
||||
if (!response.ok) throw new Error('Failed to save settings');
|
||||
|
||||
setStatus('success');
|
||||
setTimeout(() => setStatus('idle'), 2000);
|
||||
} catch (error) {
|
||||
console.error('Error saving settings:', error);
|
||||
setStatus('error');
|
||||
setTimeout(() => setStatus('idle'), 2000);
|
||||
}
|
||||
};
|
||||
|
||||
const handleChange = (e: React.ChangeEvent<HTMLInputElement>) => {
|
||||
|
||||
@@ -4,7 +4,7 @@ import { Dialog, DialogBackdrop, DialogPanel, DialogTitle } from '@headlessui/re
|
||||
import { FaUpload } from 'react-icons/fa';
|
||||
import { useCallback, useState } from 'react';
|
||||
import { useDropzone } from 'react-dropzone';
|
||||
import axios from 'axios';
|
||||
import { apiClient } from '@/utils/api';
|
||||
|
||||
export interface AddImagesModalState {
|
||||
datasetName: string;
|
||||
@@ -15,7 +15,7 @@ export const addImagesModalState = createGlobalState<AddImagesModalState | null>
|
||||
|
||||
export const openImagesModal = (datasetName: string, onComplete: () => void) => {
|
||||
addImagesModalState.set({ datasetName, onComplete });
|
||||
}
|
||||
};
|
||||
|
||||
export default function AddImagesModal() {
|
||||
const [addImagesModalInfo, setAddImagesModalInfo] = addImagesModalState.use();
|
||||
@@ -36,46 +36,49 @@ export default function AddImagesModal() {
|
||||
}
|
||||
};
|
||||
|
||||
const onDrop = useCallback(async (acceptedFiles: File[]) => {
|
||||
if (acceptedFiles.length === 0) return;
|
||||
const onDrop = useCallback(
|
||||
async (acceptedFiles: File[]) => {
|
||||
if (acceptedFiles.length === 0) return;
|
||||
|
||||
setIsUploading(true);
|
||||
setUploadProgress(0);
|
||||
|
||||
const formData = new FormData();
|
||||
acceptedFiles.forEach(file => {
|
||||
formData.append('files', file);
|
||||
});
|
||||
formData.append('datasetName', addImagesModalInfo?.datasetName || '');
|
||||
|
||||
try {
|
||||
await axios.post(`/api/datasets/upload`, formData, {
|
||||
headers: {
|
||||
'Content-Type': 'multipart/form-data',
|
||||
},
|
||||
onUploadProgress: (progressEvent) => {
|
||||
const percentCompleted = Math.round((progressEvent.loaded * 100) / (progressEvent.total || 100));
|
||||
setUploadProgress(percentCompleted);
|
||||
},
|
||||
timeout: 0, // Disable timeout
|
||||
});
|
||||
|
||||
onDone();
|
||||
} catch (error) {
|
||||
console.error('Upload failed:', error);
|
||||
} finally {
|
||||
setIsUploading(false);
|
||||
setIsUploading(true);
|
||||
setUploadProgress(0);
|
||||
}
|
||||
}, [addImagesModalInfo]);
|
||||
|
||||
const formData = new FormData();
|
||||
acceptedFiles.forEach(file => {
|
||||
formData.append('files', file);
|
||||
});
|
||||
formData.append('datasetName', addImagesModalInfo?.datasetName || '');
|
||||
|
||||
try {
|
||||
await apiClient.post(`/api/datasets/upload`, formData, {
|
||||
headers: {
|
||||
'Content-Type': 'multipart/form-data',
|
||||
},
|
||||
onUploadProgress: progressEvent => {
|
||||
const percentCompleted = Math.round((progressEvent.loaded * 100) / (progressEvent.total || 100));
|
||||
setUploadProgress(percentCompleted);
|
||||
},
|
||||
timeout: 0, // Disable timeout
|
||||
});
|
||||
|
||||
onDone();
|
||||
} catch (error) {
|
||||
console.error('Upload failed:', error);
|
||||
} finally {
|
||||
setIsUploading(false);
|
||||
setUploadProgress(0);
|
||||
}
|
||||
},
|
||||
[addImagesModalInfo],
|
||||
);
|
||||
|
||||
const { getRootProps, getInputProps, isDragActive } = useDropzone({
|
||||
onDrop,
|
||||
accept: {
|
||||
'image/*': ['.png', '.jpg', '.jpeg', '.gif', '.bmp'],
|
||||
'text/*': ['.txt']
|
||||
'text/*': ['.txt'],
|
||||
},
|
||||
multiple: true
|
||||
multiple: true,
|
||||
});
|
||||
|
||||
return (
|
||||
@@ -105,22 +108,15 @@ export default function AddImagesModal() {
|
||||
<input {...getInputProps()} />
|
||||
<FaUpload className="size-8 mb-3 text-gray-400" />
|
||||
<p className="text-sm text-gray-200 text-center">
|
||||
{isDragActive
|
||||
? 'Drop the files here...'
|
||||
: 'Drag & drop files here, or click to select files'}
|
||||
{isDragActive ? 'Drop the files here...' : 'Drag & drop files here, or click to select files'}
|
||||
</p>
|
||||
</div>
|
||||
{isUploading && (
|
||||
<div className="mt-4">
|
||||
<div className="w-full bg-gray-700 rounded-full h-2.5">
|
||||
<div
|
||||
className="bg-blue-600 h-2.5 rounded-full"
|
||||
style={{ width: `${uploadProgress}%` }}
|
||||
></div>
|
||||
<div className="bg-blue-600 h-2.5 rounded-full" style={{ width: `${uploadProgress}%` }}></div>
|
||||
</div>
|
||||
<p className="text-sm text-gray-300 mt-2 text-center">
|
||||
Uploading... {uploadProgress}%
|
||||
</p>
|
||||
<p className="text-sm text-gray-300 mt-2 text-center">Uploading... {uploadProgress}%</p>
|
||||
</div>
|
||||
)}
|
||||
</div>
|
||||
@@ -152,4 +148,4 @@ export default function AddImagesModal() {
|
||||
</div>
|
||||
</Dialog>
|
||||
);
|
||||
}
|
||||
}
|
||||
|
||||
163
ui/src/components/AuthWrapper.tsx
Normal file
163
ui/src/components/AuthWrapper.tsx
Normal file
@@ -0,0 +1,163 @@
|
||||
'use client';
|
||||
|
||||
import { useState, useEffect, useRef } from 'react';
|
||||
import { apiClient, isAuthorizedState } from '@/utils/api';
|
||||
import { createGlobalState } from 'react-global-hooks';
|
||||
|
||||
interface AuthWrapperProps {
|
||||
authRequired: boolean;
|
||||
children: React.ReactNode | React.ReactNode[];
|
||||
}
|
||||
|
||||
export default function AuthWrapper({ authRequired, children }: AuthWrapperProps) {
|
||||
const [token, setToken] = useState('');
|
||||
// start with true, and deauth if needed
|
||||
const [isAuthorizedGlobal, setIsAuthorized] = isAuthorizedState.use();
|
||||
const [isLoading, setIsLoading] = useState(false);
|
||||
const [error, setError] = useState('');
|
||||
const [isBrowser, setIsBrowser] = useState(false);
|
||||
const inputRef = useRef<HTMLInputElement>(null);
|
||||
|
||||
const isAuthorized = authRequired ? isAuthorizedGlobal : true;
|
||||
|
||||
// Set isBrowser to true when component mounts
|
||||
useEffect(() => {
|
||||
setIsBrowser(true);
|
||||
// Get token from localStorage only after component has mounted
|
||||
const storedToken = localStorage.getItem('AI_TOOLKIT_AUTH') || '';
|
||||
setToken(storedToken);
|
||||
checkAuth();
|
||||
}, []);
|
||||
|
||||
// auto focus on input when not authorized
|
||||
useEffect(() => {
|
||||
if (isAuthorized) {
|
||||
return;
|
||||
}
|
||||
setTimeout(() => {
|
||||
if (inputRef.current) {
|
||||
inputRef.current.focus();
|
||||
}
|
||||
}, 100);
|
||||
}, [isAuthorized]);
|
||||
|
||||
const checkAuth = async () => {
|
||||
// always get current stored token here to avoid state race conditions
|
||||
const currentToken = localStorage.getItem('AI_TOOLKIT_AUTH') || '';
|
||||
if (!authRequired || isLoading || currentToken === '') {
|
||||
return;
|
||||
}
|
||||
setIsLoading(true);
|
||||
setError('');
|
||||
try {
|
||||
const response = await apiClient.get('/api/auth');
|
||||
if (response.data.isAuthenticated) {
|
||||
setIsAuthorized(true);
|
||||
} else {
|
||||
setIsAuthorized(false);
|
||||
setError('Invalid token. Please try again.');
|
||||
}
|
||||
} catch (err) {
|
||||
setIsAuthorized(false);
|
||||
console.log(err);
|
||||
setError('Invalid token. Please try again.');
|
||||
}
|
||||
setIsLoading(false);
|
||||
};
|
||||
|
||||
const handleSubmit = async (e: React.FormEvent) => {
|
||||
e.preventDefault();
|
||||
setError('');
|
||||
|
||||
if (!token.trim()) {
|
||||
setError('Please enter your token');
|
||||
return;
|
||||
}
|
||||
|
||||
if (isBrowser) {
|
||||
localStorage.setItem('AI_TOOLKIT_AUTH', token);
|
||||
checkAuth();
|
||||
}
|
||||
};
|
||||
|
||||
if (isAuthorized) {
|
||||
return <>{children}</>;
|
||||
}
|
||||
|
||||
return (
|
||||
<div className="flex min-h-screen bg-gray-900 text-gray-100 absolute top-0 left-0 right-0 bottom-0 scroll-auto">
|
||||
{/* Left side - decorative or brand area */}
|
||||
<div className="hidden lg:flex lg:w-1/2 bg-gray-800 flex-col justify-center items-center p-12">
|
||||
<div className="mb-4">
|
||||
{/* Replace with your own logo */}
|
||||
<div className="flex items-center justify-center">
|
||||
<img src="/ostris_logo.png" alt="Ostris AI Toolkit" className="w-auto h-24 inline" />
|
||||
</div>
|
||||
</div>
|
||||
<h1 className="text-4xl mb-6">AI Toolkit</h1>
|
||||
</div>
|
||||
|
||||
{/* Right side - login form */}
|
||||
<div className="w-full lg:w-1/2 flex flex-col justify-center items-center p-8 sm:p-12">
|
||||
<div className="w-full max-w-md">
|
||||
<div className="lg:hidden flex justify-center mb-4">
|
||||
{/* Mobile logo */}
|
||||
<div className="flex items-center justify-center">
|
||||
<img src="/ostris_logo.png" alt="Ostris AI Toolkit" className="w-auto h-24 inline" />
|
||||
</div>
|
||||
</div>
|
||||
|
||||
<h2 className="text-3xl text-center mb-2 lg:hidden">AI Toolkit</h2>
|
||||
|
||||
<form onSubmit={handleSubmit} className="space-y-6">
|
||||
<div>
|
||||
<label htmlFor="token" className="block text-sm font-medium text-gray-400 mb-2">
|
||||
Access Token
|
||||
</label>
|
||||
<input
|
||||
id="token"
|
||||
name="token"
|
||||
type="password"
|
||||
autoComplete="off"
|
||||
required
|
||||
value={token}
|
||||
ref={inputRef}
|
||||
onChange={e => setToken(e.target.value)}
|
||||
className="w-full px-4 py-3 rounded-lg bg-gray-800 border border-gray-700 focus:border-blue-500 focus:ring-2 focus:ring-blue-500 focus:ring-opacity-50 text-gray-100 transition duration-200"
|
||||
placeholder="Enter your token"
|
||||
/>
|
||||
</div>
|
||||
|
||||
{error && (
|
||||
<div className="p-3 bg-red-900/50 border border-red-800 rounded-lg text-red-200 text-sm">{error}</div>
|
||||
)}
|
||||
|
||||
<button
|
||||
type="submit"
|
||||
disabled={isLoading}
|
||||
className="w-full py-3 px-4 bg-blue-600 hover:bg-blue-700 rounded-lg text-white font-medium focus:outline-none focus:ring-2 focus:ring-blue-500 focus:ring-opacity-50 transition duration-200 flex items-center justify-center"
|
||||
>
|
||||
{isLoading ? (
|
||||
<svg
|
||||
className="animate-spin h-5 w-5 text-white"
|
||||
xmlns="http://www.w3.org/2000/svg"
|
||||
fill="none"
|
||||
viewBox="0 0 24 24"
|
||||
>
|
||||
<circle className="opacity-25" cx="12" cy="12" r="10" stroke="currentColor" strokeWidth="4"></circle>
|
||||
<path
|
||||
className="opacity-75"
|
||||
fill="currentColor"
|
||||
d="M4 12a8 8 0 018-8V0C5.373 0 0 5.373 0 12h4zm2 5.291A7.962 7.962 0 014 12H0c0 3.042 1.135 5.824 3 7.938l3-2.647z"
|
||||
></path>
|
||||
</svg>
|
||||
) : (
|
||||
'Check Token'
|
||||
)}
|
||||
</button>
|
||||
</form>
|
||||
</div>
|
||||
</div>
|
||||
</div>
|
||||
);
|
||||
}
|
||||
@@ -12,4 +12,4 @@ const Card: React.FC<CardProps> = ({ title, children }) => {
|
||||
);
|
||||
};
|
||||
|
||||
export default Card;
|
||||
export default Card;
|
||||
|
||||
@@ -2,6 +2,7 @@ import React, { useRef, useEffect, useState, ReactNode, KeyboardEvent } from 're
|
||||
import { FaTrashAlt } from 'react-icons/fa';
|
||||
import { openConfirm } from './ConfirmModal';
|
||||
import classNames from 'classnames';
|
||||
import { apiClient } from '@/utils/api';
|
||||
|
||||
interface DatasetImageCardProps {
|
||||
imageUrl: string;
|
||||
@@ -27,30 +28,32 @@ const DatasetImageCard: React.FC<DatasetImageCardProps> = ({
|
||||
const isGettingCaption = useRef<boolean>(false);
|
||||
|
||||
const fetchCaption = async () => {
|
||||
try {
|
||||
if (isGettingCaption.current || isCaptionLoaded) return;
|
||||
isGettingCaption.current = true;
|
||||
const response = await fetch(`/api/caption/${encodeURIComponent(imageUrl)}`);
|
||||
const data = await response.text();
|
||||
setCaption(data);
|
||||
setSavedCaption(data);
|
||||
setIsCaptionLoaded(true);
|
||||
} catch (error) {
|
||||
console.error('Error fetching caption:', error);
|
||||
}
|
||||
if (isGettingCaption.current || isCaptionLoaded) return;
|
||||
isGettingCaption.current = true;
|
||||
apiClient
|
||||
.get(`/api/caption/${encodeURIComponent(imageUrl)}`)
|
||||
.then(res => res.data)
|
||||
.then(data => {
|
||||
console.log('Caption fetched:', data);
|
||||
|
||||
setCaption(data || '');
|
||||
setSavedCaption(data || '');
|
||||
setIsCaptionLoaded(true);
|
||||
})
|
||||
.catch(error => {
|
||||
console.error('Error fetching caption:', error);
|
||||
})
|
||||
.finally(() => {
|
||||
isGettingCaption.current = false;
|
||||
});
|
||||
};
|
||||
|
||||
const saveCaption = () => {
|
||||
const trimmedCaption = caption.trim();
|
||||
if (trimmedCaption === savedCaption) return;
|
||||
fetch('/api/img/caption', {
|
||||
method: 'POST',
|
||||
headers: {
|
||||
'Content-Type': 'application/json',
|
||||
},
|
||||
body: JSON.stringify({ imgPath: imageUrl, caption: trimmedCaption }),
|
||||
})
|
||||
.then(res => res.json())
|
||||
apiClient
|
||||
.post('/api/img/caption', { imgPath: imageUrl, caption: trimmedCaption })
|
||||
.then(res => res.data)
|
||||
.then(data => {
|
||||
console.log('Caption saved:', data);
|
||||
setSavedCaption(trimmedCaption);
|
||||
@@ -129,16 +132,10 @@ const DatasetImageCard: React.FC<DatasetImageCardProps> = ({
|
||||
type: 'warning',
|
||||
confirmText: 'Delete',
|
||||
onConfirm: () => {
|
||||
fetch('/api/img/delete', {
|
||||
method: 'POST',
|
||||
headers: {
|
||||
'Content-Type': 'application/json',
|
||||
},
|
||||
body: JSON.stringify({ imgPath: imageUrl }),
|
||||
})
|
||||
.then(res => res.json())
|
||||
.then(data => {
|
||||
console.log('Image deleted:', data);
|
||||
apiClient
|
||||
.post('/api/img/delete', { imgPath: imageUrl })
|
||||
.then(() => {
|
||||
console.log('Image deleted:', imageUrl);
|
||||
onDelete();
|
||||
})
|
||||
.catch(error => {
|
||||
|
||||
@@ -24,9 +24,7 @@ export default function FilesWidget({ jobID }: { jobID: string }) {
|
||||
<div className="flex items-center space-x-2">
|
||||
<Brain className="w-5 h-5 text-purple-400" />
|
||||
<h2 className="font-semibold text-gray-100">Checkpoints</h2>
|
||||
<span className="px-2 py-0.5 bg-gray-700 rounded-full text-xs text-gray-300">
|
||||
{files.length}
|
||||
</span>
|
||||
<span className="px-2 py-0.5 bg-gray-700 rounded-full text-xs text-gray-300">{files.length}</span>
|
||||
</div>
|
||||
</div>
|
||||
|
||||
@@ -50,9 +48,9 @@ export default function FilesWidget({ jobID }: { jobID: string }) {
|
||||
const fileName = file.path.split('/').pop() || '';
|
||||
const nameWithoutExt = fileName.replace('.safetensors', '');
|
||||
return (
|
||||
<a
|
||||
<a
|
||||
key={index}
|
||||
target='_blank'
|
||||
target="_blank"
|
||||
href={`/api/files/${encodeURIComponent(file.path)}`}
|
||||
className="group flex items-center justify-between px-2 py-1.5 rounded-lg hover:bg-gray-800 transition-all duration-200"
|
||||
>
|
||||
@@ -80,11 +78,9 @@ export default function FilesWidget({ jobID }: { jobID: string }) {
|
||||
)}
|
||||
|
||||
{['success', 'refreshing'].includes(status) && files.length === 0 && (
|
||||
<div className="text-center py-4 text-gray-400 text-sm">
|
||||
No checkpoints available
|
||||
</div>
|
||||
<div className="text-center py-4 text-gray-400 text-sm">No checkpoints available</div>
|
||||
)}
|
||||
</div>
|
||||
</div>
|
||||
);
|
||||
}
|
||||
}
|
||||
|
||||
@@ -1,7 +1,8 @@
|
||||
import React, { useState, useEffect, useRef } from 'react';
|
||||
import React, { useState, useEffect, useRef, useMemo } from 'react';
|
||||
import { GPUApiResponse } from '@/types';
|
||||
import Loading from '@/components/Loading';
|
||||
import GPUWidget from '@/components/GPUWidget';
|
||||
import { apiClient } from '@/utils/api';
|
||||
|
||||
const GpuMonitor: React.FC = () => {
|
||||
const [gpuData, setGpuData] = useState<GPUApiResponse | null>(null);
|
||||
@@ -12,27 +13,26 @@ const GpuMonitor: React.FC = () => {
|
||||
|
||||
useEffect(() => {
|
||||
const fetchGpuInfo = async () => {
|
||||
try {
|
||||
if (isFetchingGpuRef.current) {
|
||||
return;
|
||||
}
|
||||
isFetchingGpuRef.current = true;
|
||||
const response = await fetch('/api/gpu');
|
||||
|
||||
if (!response.ok) {
|
||||
throw new Error(`HTTP error! Status: ${response.status}`);
|
||||
}
|
||||
|
||||
const data: GPUApiResponse = await response.json();
|
||||
setGpuData(data);
|
||||
setLastUpdated(new Date());
|
||||
setError(null);
|
||||
} catch (err) {
|
||||
setError(`Failed to fetch GPU data: ${err instanceof Error ? err.message : String(err)}`);
|
||||
} finally {
|
||||
isFetchingGpuRef.current = false;
|
||||
setLoading(false);
|
||||
if (isFetchingGpuRef.current) {
|
||||
return;
|
||||
}
|
||||
setLoading(true);
|
||||
isFetchingGpuRef.current = true;
|
||||
apiClient
|
||||
.get('/api/gpu')
|
||||
.then(res => res.data)
|
||||
.then(data => {
|
||||
setGpuData(data);
|
||||
setLastUpdated(new Date());
|
||||
setError(null);
|
||||
})
|
||||
.catch(err => {
|
||||
setError(`Failed to fetch GPU data: ${err instanceof Error ? err.message : String(err)}`);
|
||||
})
|
||||
.finally(() => {
|
||||
isFetchingGpuRef.current = false;
|
||||
setLoading(false);
|
||||
});
|
||||
};
|
||||
|
||||
// Fetch immediately on component mount
|
||||
@@ -69,46 +69,63 @@ const GpuMonitor: React.FC = () => {
|
||||
}
|
||||
};
|
||||
|
||||
if (loading) {
|
||||
return <Loading />;
|
||||
}
|
||||
console.log('state', {
|
||||
loading,
|
||||
gpuData,
|
||||
error,
|
||||
lastUpdated,
|
||||
});
|
||||
|
||||
const content = useMemo(() => {
|
||||
if (loading && !gpuData) {
|
||||
return <Loading />;
|
||||
}
|
||||
|
||||
if (error) {
|
||||
return (
|
||||
<div className="bg-red-900 border border-red-600 text-red-200 px-4 py-3 rounded relative" role="alert">
|
||||
<strong className="font-bold">Error!</strong>
|
||||
<span className="block sm:inline"> {error}</span>
|
||||
</div>
|
||||
);
|
||||
}
|
||||
|
||||
if (!gpuData) {
|
||||
return (
|
||||
<div className="bg-yellow-900 border border-yellow-700 text-yellow-300 px-4 py-3 rounded relative" role="alert">
|
||||
<span className="block sm:inline">No GPU data available.</span>
|
||||
</div>
|
||||
);
|
||||
}
|
||||
|
||||
if (!gpuData.hasNvidiaSmi) {
|
||||
return (
|
||||
<div className="bg-yellow-900 border border-yellow-700 text-yellow-300 px-4 py-3 rounded relative" role="alert">
|
||||
<strong className="font-bold">No NVIDIA GPUs detected!</strong>
|
||||
<span className="block sm:inline"> nvidia-smi is not available on this system.</span>
|
||||
{gpuData.error && <p className="mt-2 text-sm">{gpuData.error}</p>}
|
||||
</div>
|
||||
);
|
||||
}
|
||||
|
||||
if (gpuData.gpus.length === 0) {
|
||||
return (
|
||||
<div className="bg-yellow-900 border border-yellow-700 text-yellow-300 px-4 py-3 rounded relative" role="alert">
|
||||
<span className="block sm:inline">No GPUs found, but nvidia-smi is available.</span>
|
||||
</div>
|
||||
);
|
||||
}
|
||||
|
||||
const gridClass = getGridClasses(gpuData?.gpus?.length || 1);
|
||||
|
||||
if (error) {
|
||||
return (
|
||||
<div className="bg-red-100 border border-red-400 text-red-700 px-4 py-3 rounded relative" role="alert">
|
||||
<strong className="font-bold">Error!</strong>
|
||||
<span className="block sm:inline"> {error}</span>
|
||||
<div className={`grid ${gridClass} gap-3`}>
|
||||
{gpuData.gpus.map((gpu, idx) => (
|
||||
<GPUWidget key={idx} gpu={gpu} />
|
||||
))}
|
||||
</div>
|
||||
);
|
||||
}
|
||||
|
||||
if (!gpuData) {
|
||||
return (
|
||||
<div className="bg-yellow-100 border border-yellow-400 text-yellow-700 px-4 py-3 rounded relative" role="alert">
|
||||
<span className="block sm:inline">No GPU data available.</span>
|
||||
</div>
|
||||
);
|
||||
}
|
||||
|
||||
if (!gpuData.hasNvidiaSmi) {
|
||||
return (
|
||||
<div className="bg-yellow-100 border border-yellow-400 text-yellow-700 px-4 py-3 rounded relative" role="alert">
|
||||
<strong className="font-bold">No NVIDIA GPUs detected!</strong>
|
||||
<span className="block sm:inline"> nvidia-smi is not available on this system.</span>
|
||||
{gpuData.error && <p className="mt-2 text-sm">{gpuData.error}</p>}
|
||||
</div>
|
||||
);
|
||||
}
|
||||
|
||||
if (gpuData.gpus.length === 0) {
|
||||
return (
|
||||
<div className="bg-yellow-100 border border-yellow-400 text-yellow-700 px-4 py-3 rounded relative" role="alert">
|
||||
<span className="block sm:inline">No GPUs found, but nvidia-smi is available.</span>
|
||||
</div>
|
||||
);
|
||||
}
|
||||
|
||||
const gridClass = getGridClasses(gpuData.gpus.length);
|
||||
}, [loading, gpuData, error]);
|
||||
|
||||
return (
|
||||
<div className="w-full">
|
||||
@@ -116,12 +133,7 @@ const GpuMonitor: React.FC = () => {
|
||||
<h1 className="text-md">GPU Monitor</h1>
|
||||
<div className="text-xs text-gray-500">Last updated: {lastUpdated?.toLocaleTimeString()}</div>
|
||||
</div>
|
||||
|
||||
<div className={`grid ${gridClass} gap-3`}>
|
||||
{gpuData.gpus.map((gpu, idx) => (
|
||||
<GPUWidget key={idx} gpu={gpu} />
|
||||
))}
|
||||
</div>
|
||||
{content}
|
||||
</div>
|
||||
);
|
||||
};
|
||||
|
||||
@@ -24,9 +24,7 @@ export default function GPUWidget({ gpu }: GPUWidgetProps) {
|
||||
<div className="bg-gray-800 px-4 py-3 flex items-center justify-between">
|
||||
<div className="flex items-center space-x-2">
|
||||
<h2 className="font-semibold text-gray-100">{gpu.name}</h2>
|
||||
<span className="px-2 py-0.5 bg-gray-700 rounded-full text-xs text-gray-300">
|
||||
#{gpu.index}
|
||||
</span>
|
||||
<span className="px-2 py-0.5 bg-gray-700 rounded-full text-xs text-gray-300">#{gpu.index}</span>
|
||||
</div>
|
||||
</div>
|
||||
|
||||
@@ -38,18 +36,14 @@ export default function GPUWidget({ gpu }: GPUWidgetProps) {
|
||||
<Thermometer className={`w-4 h-4 ${getTemperatureColor(gpu.temperature)}`} />
|
||||
<div>
|
||||
<p className="text-xs text-gray-400">Temperature</p>
|
||||
<p className={`text-sm font-medium ${getTemperatureColor(gpu.temperature)}`}>
|
||||
{gpu.temperature}°C
|
||||
</p>
|
||||
<p className={`text-sm font-medium ${getTemperatureColor(gpu.temperature)}`}>{gpu.temperature}°C</p>
|
||||
</div>
|
||||
</div>
|
||||
<div className="flex items-center space-x-2">
|
||||
<Fan className="w-4 h-4 text-blue-400" />
|
||||
<div>
|
||||
<p className="text-xs text-gray-400">Fan Speed</p>
|
||||
<p className="text-sm font-medium text-blue-400">
|
||||
{gpu.fan.speed}%
|
||||
</p>
|
||||
<p className="text-sm font-medium text-blue-400">{gpu.fan.speed}%</p>
|
||||
</div>
|
||||
</div>
|
||||
</div>
|
||||
@@ -99,7 +93,7 @@ export default function GPUWidget({ gpu }: GPUWidgetProps) {
|
||||
<p className="text-xs text-gray-400">Power Draw</p>
|
||||
<p className="text-sm text-gray-200">
|
||||
{gpu.power.draw?.toFixed(1)}W
|
||||
<span className="text-gray-400 text-xs"> / {gpu.power.limit?.toFixed(1) || " ? "}W</span>
|
||||
<span className="text-gray-400 text-xs"> / {gpu.power.limit?.toFixed(1) || ' ? '}W</span>
|
||||
</p>
|
||||
</div>
|
||||
</div>
|
||||
@@ -107,4 +101,4 @@ export default function GPUWidget({ gpu }: GPUWidgetProps) {
|
||||
</div>
|
||||
</div>
|
||||
);
|
||||
}
|
||||
}
|
||||
|
||||
@@ -45,7 +45,9 @@ export default function JobOverview({ job }: JobOverviewProps) {
|
||||
{/* Job Information Panel */}
|
||||
<div className="col-span-2 bg-gray-900 rounded-xl shadow-lg overflow-hidden border border-gray-800">
|
||||
<div className="bg-gray-800 px-4 py-3 flex items-center justify-between">
|
||||
<h2 className="text-gray-100"><Info className="w-5 h-5 mr-2 -mt-1 text-amber-400 inline-block" /> {job.info}</h2>
|
||||
<h2 className="text-gray-100">
|
||||
<Info className="w-5 h-5 mr-2 -mt-1 text-amber-400 inline-block" /> {job.info}
|
||||
</h2>
|
||||
<span className={`px-3 py-1 rounded-full text-sm ${getStatusColor(job.status)}`}>{job.status}</span>
|
||||
</div>
|
||||
|
||||
@@ -85,7 +87,7 @@ export default function JobOverview({ job }: JobOverviewProps) {
|
||||
<Gauge className="w-5 h-5 text-green-400" />
|
||||
<div>
|
||||
<p className="text-xs text-gray-400">Speed</p>
|
||||
<p className="text-sm font-medium text-gray-200">{job.speed_string == "" ? "?" : job.speed_string}</p>
|
||||
<p className="text-sm font-medium text-gray-200">{job.speed_string == '' ? '?' : job.speed_string}</p>
|
||||
</div>
|
||||
</div>
|
||||
</div>
|
||||
|
||||
@@ -1,6 +1,6 @@
|
||||
export default function Loading() {
|
||||
return (
|
||||
<div className="flex justify-center items-center h-64">
|
||||
<div className="flex justify-center items-center h-42">
|
||||
<div className="animate-spin rounded-full h-12 w-12 border-t-2 border-b-2 border-blue-500"></div>
|
||||
</div>
|
||||
);
|
||||
|
||||
@@ -11,7 +11,14 @@ interface SampleImageCardProps {
|
||||
onDelete?: () => void;
|
||||
}
|
||||
|
||||
const SampleImageCard: React.FC<SampleImageCardProps> = ({ imageUrl, alt, numSamples, sampleImages, children, className = '' }) => {
|
||||
const SampleImageCard: React.FC<SampleImageCardProps> = ({
|
||||
imageUrl,
|
||||
alt,
|
||||
numSamples,
|
||||
sampleImages,
|
||||
children,
|
||||
className = '',
|
||||
}) => {
|
||||
const cardRef = useRef<HTMLDivElement>(null);
|
||||
const [isVisible, setIsVisible] = useState<boolean>(false);
|
||||
const [loaded, setLoaded] = useState<boolean>(false);
|
||||
|
||||
@@ -1,9 +1,10 @@
|
||||
import Link from 'next/link';
|
||||
import { Home, Settings, BrainCircuit, Images } from 'lucide-react';
|
||||
import { Home, Settings, BrainCircuit, Images, Plus } from 'lucide-react';
|
||||
|
||||
const Sidebar = () => {
|
||||
const navigation = [
|
||||
{ name: 'Dashboard', href: '/dashboard', icon: Home },
|
||||
{ name: 'New Job', href: '/jobs/new', icon: Plus },
|
||||
{ name: 'Training Jobs', href: '/jobs', icon: BrainCircuit },
|
||||
{ name: 'Datasets', href: '/datasets', icon: Images },
|
||||
{ name: 'Settings', href: '/settings', icon: Settings },
|
||||
@@ -33,7 +34,7 @@ const Sidebar = () => {
|
||||
</ul>
|
||||
</nav>
|
||||
<a href="https://patreon.com/ostris" target="_blank" rel="noreferrer" className="flex items-center space-x-2 p-4">
|
||||
<div className='min-w-[26px] min-h-[26px]'>
|
||||
<div className="min-w-[26px] min-h-[26px]">
|
||||
<svg
|
||||
viewBox="0 0 512 512"
|
||||
xmlns="http://www.w3.org/2000/svg"
|
||||
|
||||
@@ -49,7 +49,7 @@ export interface NumberInputProps extends InputProps {
|
||||
|
||||
export const NumberInput = (props: NumberInputProps) => {
|
||||
const { label, value, onChange, placeholder, required, min, max } = props;
|
||||
|
||||
|
||||
// Add controlled internal state to properly handle partial inputs
|
||||
const [inputValue, setInputValue] = React.useState<string | number>(value ?? '');
|
||||
|
||||
@@ -66,7 +66,7 @@ export const NumberInput = (props: NumberInputProps) => {
|
||||
value={inputValue}
|
||||
onChange={e => {
|
||||
const rawValue = e.target.value;
|
||||
|
||||
|
||||
// Update the input display with the raw value
|
||||
setInputValue(rawValue);
|
||||
|
||||
@@ -81,7 +81,7 @@ export const NumberInput = (props: NumberInputProps) => {
|
||||
// Only apply constraints and call onChange when we have a valid number
|
||||
if (!isNaN(numValue)) {
|
||||
let constrainedValue = numValue;
|
||||
|
||||
|
||||
// Apply min/max constraints if they exist
|
||||
if (min !== undefined && constrainedValue < min) {
|
||||
constrainedValue = min;
|
||||
@@ -89,7 +89,7 @@ export const NumberInput = (props: NumberInputProps) => {
|
||||
if (max !== undefined && constrainedValue > max) {
|
||||
constrainedValue = max;
|
||||
}
|
||||
|
||||
|
||||
onChange(constrainedValue);
|
||||
}
|
||||
}}
|
||||
@@ -152,14 +152,14 @@ export const Checkbox = (props: CheckboxProps) => {
|
||||
className={classNames(
|
||||
'relative inline-flex h-6 w-11 flex-shrink-0 cursor-pointer rounded-full border-2 border-transparent transition-colors duration-200 ease-in-out focus:outline-none focus:ring-2 focus:ring-blue-600 focus:ring-offset-2',
|
||||
checked ? 'bg-blue-600' : 'bg-gray-700',
|
||||
disabled ? 'opacity-50 cursor-not-allowed' : 'hover:bg-opacity-80'
|
||||
disabled ? 'opacity-50 cursor-not-allowed' : 'hover:bg-opacity-80',
|
||||
)}
|
||||
>
|
||||
<span className="sr-only">Toggle {label}</span>
|
||||
<span
|
||||
className={classNames(
|
||||
'pointer-events-none inline-block h-5 w-5 transform rounded-full bg-white shadow ring-0 transition duration-200 ease-in-out',
|
||||
checked ? 'translate-x-5' : 'translate-x-0'
|
||||
checked ? 'translate-x-5' : 'translate-x-0',
|
||||
)}
|
||||
/>
|
||||
</button>
|
||||
@@ -168,7 +168,7 @@ export const Checkbox = (props: CheckboxProps) => {
|
||||
htmlFor={id}
|
||||
className={classNames(
|
||||
'text-sm font-medium cursor-pointer select-none',
|
||||
disabled ? 'text-gray-500' : 'text-gray-300'
|
||||
disabled ? 'text-gray-500' : 'text-gray-300',
|
||||
)}
|
||||
>
|
||||
{label}
|
||||
|
||||
@@ -1,6 +1,7 @@
|
||||
'use client';
|
||||
|
||||
import { useEffect, useState } from 'react';
|
||||
import { apiClient } from '@/utils/api';
|
||||
|
||||
export default function useDatasetList() {
|
||||
const [datasets, setDatasets] = useState<string[]>([]);
|
||||
@@ -8,8 +9,9 @@ export default function useDatasetList() {
|
||||
|
||||
const refreshDatasets = () => {
|
||||
setStatus('loading');
|
||||
fetch('/api/datasets/list')
|
||||
.then(res => res.json())
|
||||
apiClient
|
||||
.get('/api/datasets/list')
|
||||
.then(res => res.data)
|
||||
.then(data => {
|
||||
console.log('Datasets:', data);
|
||||
// sort
|
||||
|
||||
@@ -1,6 +1,7 @@
|
||||
'use client';
|
||||
|
||||
import { useEffect, useState, useRef } from 'react';
|
||||
import { apiClient } from '@/utils/api';
|
||||
|
||||
interface FileObject {
|
||||
path: string;
|
||||
@@ -18,8 +19,9 @@ export default function useFilesList(jobID: string, reloadInterval: null | numbe
|
||||
loadStatus = 'refreshing';
|
||||
}
|
||||
setStatus(loadStatus);
|
||||
fetch(`/api/jobs/${jobID}/files`)
|
||||
.then(res => res.json())
|
||||
apiClient
|
||||
.get(`/api/jobs/${jobID}/files`)
|
||||
.then(res => res.data)
|
||||
.then(data => {
|
||||
console.log('Fetched files:', data);
|
||||
if (data.files) {
|
||||
|
||||
@@ -2,6 +2,7 @@
|
||||
|
||||
import { GPUApiResponse, GpuInfo } from '@/types';
|
||||
import { useEffect, useState } from 'react';
|
||||
import { apiClient } from '@/utils/api';
|
||||
|
||||
export default function useGPUInfo(gpuIds: null | number[] = null, reloadInterval: null | number = null) {
|
||||
const [gpuList, setGpuList] = useState<GpuInfo[]>([]);
|
||||
@@ -11,18 +12,11 @@ export default function useGPUInfo(gpuIds: null | number[] = null, reloadInterva
|
||||
const fetchGpuInfo = async () => {
|
||||
setStatus('loading');
|
||||
try {
|
||||
const response = await fetch('/api/gpu');
|
||||
|
||||
if (!response.ok) {
|
||||
throw new Error(`HTTP error! Status: ${response.status}`);
|
||||
}
|
||||
|
||||
const data: GPUApiResponse = await response.json();
|
||||
const data: GPUApiResponse = await apiClient.get('/api/gpu').then(res => res.data);
|
||||
let gpus = data.gpus.sort((a, b) => a.index - b.index);
|
||||
if (gpuIds) {
|
||||
gpus = gpus.filter(gpu => gpuIds.includes(gpu.index));
|
||||
}
|
||||
|
||||
setGpuList(gpus);
|
||||
setStatus('success');
|
||||
} catch (err) {
|
||||
@@ -51,4 +45,4 @@ export default function useGPUInfo(gpuIds: null | number[] = null, reloadInterva
|
||||
}, [gpuIds, reloadInterval]); // Added dependencies
|
||||
|
||||
return { gpuList, setGpuList, isGPUInfoLoaded, status, refreshGpuInfo: fetchGpuInfo };
|
||||
}
|
||||
}
|
||||
|
||||
@@ -2,6 +2,7 @@
|
||||
|
||||
import { useEffect, useState } from 'react';
|
||||
import { Job } from '@prisma/client';
|
||||
import { apiClient } from '@/utils/api';
|
||||
|
||||
export default function useJob(jobID: string, reloadInterval: null | number = null) {
|
||||
const [job, setJob] = useState<Job | null>(null);
|
||||
@@ -9,8 +10,9 @@ export default function useJob(jobID: string, reloadInterval: null | number = nu
|
||||
|
||||
const refreshJob = () => {
|
||||
setStatus('loading');
|
||||
fetch(`/api/jobs?id=${jobID}`)
|
||||
.then(res => res.json())
|
||||
apiClient
|
||||
.get(`/api/jobs?id=${jobID}`)
|
||||
.then(res => res.data)
|
||||
.then(data => {
|
||||
console.log('Job:', data);
|
||||
setJob(data);
|
||||
@@ -32,7 +34,7 @@ export default function useJob(jobID: string, reloadInterval: null | number = nu
|
||||
|
||||
return () => {
|
||||
clearInterval(interval);
|
||||
}
|
||||
};
|
||||
}
|
||||
}, [jobID]);
|
||||
|
||||
|
||||
@@ -2,6 +2,7 @@
|
||||
|
||||
import { useEffect, useState } from 'react';
|
||||
import { Job } from '@prisma/client';
|
||||
import { apiClient } from '@/utils/api';
|
||||
|
||||
export default function useJobsList(onlyActive = false) {
|
||||
const [jobs, setJobs] = useState<Job[]>([]);
|
||||
@@ -9,8 +10,9 @@ export default function useJobsList(onlyActive = false) {
|
||||
|
||||
const refreshJobs = () => {
|
||||
setStatus('loading');
|
||||
fetch('/api/jobs')
|
||||
.then(res => res.json())
|
||||
apiClient
|
||||
.get('/api/jobs')
|
||||
.then(res => res.data)
|
||||
.then(data => {
|
||||
console.log('Jobs:', data);
|
||||
if (data.error) {
|
||||
|
||||
@@ -1,6 +1,7 @@
|
||||
'use client';
|
||||
|
||||
import { useEffect, useState } from 'react';
|
||||
import { apiClient } from '@/utils/api';
|
||||
|
||||
export default function useSampleImages(jobID: string, reloadInterval: null | number = null) {
|
||||
const [sampleImages, setSampleImages] = useState<string[]>([]);
|
||||
@@ -8,9 +9,11 @@ export default function useSampleImages(jobID: string, reloadInterval: null | nu
|
||||
|
||||
const refreshSampleImages = () => {
|
||||
setStatus('loading');
|
||||
fetch(`/api/jobs/${jobID}/samples`)
|
||||
.then(res => res.json())
|
||||
apiClient
|
||||
.get(`/api/jobs/${jobID}/samples`)
|
||||
.then(res => res.data)
|
||||
.then(data => {
|
||||
console.log('Fetched sample images:', data);
|
||||
if (data.samples) {
|
||||
setSampleImages(data.samples);
|
||||
}
|
||||
|
||||
@@ -1,6 +1,7 @@
|
||||
'use client';
|
||||
|
||||
import { useEffect, useState } from 'react';
|
||||
import { apiClient } from '@/utils/api';
|
||||
|
||||
export interface Settings {
|
||||
HF_TOKEN: string;
|
||||
@@ -16,10 +17,11 @@ export default function useSettings() {
|
||||
});
|
||||
const [isSettingsLoaded, setIsLoaded] = useState(false);
|
||||
useEffect(() => {
|
||||
// Fetch current settings
|
||||
fetch('/api/settings')
|
||||
.then(res => res.json())
|
||||
apiClient
|
||||
.get('/api/settings')
|
||||
.then(res => res.data)
|
||||
.then(data => {
|
||||
console.log('Settings:', data);
|
||||
setSettings({
|
||||
HF_TOKEN: data.HF_TOKEN || '',
|
||||
TRAINING_FOLDER: data.TRAINING_FOLDER || '',
|
||||
|
||||
49
ui/src/middleware.ts
Normal file
49
ui/src/middleware.ts
Normal file
@@ -0,0 +1,49 @@
|
||||
// middleware.ts (at the root of your project)
|
||||
import { NextResponse } from 'next/server';
|
||||
import type { NextRequest } from 'next/server';
|
||||
|
||||
// if route starts with these, approve
|
||||
const publicRoutes = ['/api/img/', '/api/files/'];
|
||||
|
||||
export function middleware(request: NextRequest) {
|
||||
// check env var for AI_TOOLKIT_AUTH, if not set, approve all requests
|
||||
// if it is set make sure bearer token matches
|
||||
const tokenToUse = process.env.AI_TOOLKIT_AUTH || null;
|
||||
if (!tokenToUse) {
|
||||
return NextResponse.next();
|
||||
}
|
||||
|
||||
// Get the token from the headers
|
||||
const token = request.headers.get('Authorization')?.split(' ')[1];
|
||||
|
||||
// allow public routes to pass through
|
||||
if (publicRoutes.some(route => request.nextUrl.pathname.startsWith(route))) {
|
||||
return NextResponse.next();
|
||||
}
|
||||
|
||||
// Check if the route should be protected
|
||||
// This will apply to all API routes that start with /api/
|
||||
if (request.nextUrl.pathname.startsWith('/api/')) {
|
||||
if (!token || token !== tokenToUse) {
|
||||
// Return a JSON response with 401 Unauthorized
|
||||
return new NextResponse(JSON.stringify({ error: 'Unauthorized' }), {
|
||||
status: 401,
|
||||
headers: { 'Content-Type': 'application/json' },
|
||||
});
|
||||
}
|
||||
|
||||
// For authorized users, continue
|
||||
return NextResponse.next();
|
||||
}
|
||||
|
||||
// For non-API routes, just continue
|
||||
return NextResponse.next();
|
||||
}
|
||||
|
||||
// Configure which paths this middleware will run on
|
||||
export const config = {
|
||||
matcher: [
|
||||
// Apply to all API routes
|
||||
'/api/:path*',
|
||||
],
|
||||
};
|
||||
31
ui/src/utils/api.ts
Normal file
31
ui/src/utils/api.ts
Normal file
@@ -0,0 +1,31 @@
|
||||
import axios from 'axios';
|
||||
import { createGlobalState } from 'react-global-hooks';
|
||||
|
||||
export const isAuthorizedState = createGlobalState(false);
|
||||
|
||||
export const apiClient = axios.create();
|
||||
|
||||
// Add a request interceptor to add token from localStorage
|
||||
apiClient.interceptors.request.use(config => {
|
||||
const token = localStorage.getItem('AI_TOOLKIT_AUTH');
|
||||
if (token) {
|
||||
config.headers['Authorization'] = `Bearer ${token}`;
|
||||
}
|
||||
return config;
|
||||
});
|
||||
|
||||
// Add a response interceptor to handle 401 errors
|
||||
apiClient.interceptors.response.use(
|
||||
response => response, // Return successful responses as-is
|
||||
error => {
|
||||
// Check if the error is a 401 Unauthorized
|
||||
if (error.response && error.response.status === 401) {
|
||||
// Clear the auth token from localStorage
|
||||
localStorage.removeItem('AI_TOOLKIT_AUTH');
|
||||
isAuthorizedState.set(false);
|
||||
}
|
||||
|
||||
// Reject the promise with the error so calling code can still catch it
|
||||
return Promise.reject(error);
|
||||
},
|
||||
);
|
||||
@@ -2,3 +2,4 @@ export const objectCopy = <T>(obj: T): T => {
|
||||
return JSON.parse(JSON.stringify(obj)) as T;
|
||||
};
|
||||
|
||||
export const wait = (ms: number) => new Promise(resolve => setTimeout(resolve, ms));
|
||||
|
||||
@@ -79,10 +79,10 @@ export function useNestedState<T>(initialState: T): [T, (value: any, path?: stri
|
||||
const setValue = React.useCallback((value: any, path?: string) => {
|
||||
if (path === undefined) {
|
||||
setState(value);
|
||||
return
|
||||
return;
|
||||
}
|
||||
setState(prevState => setNestedValue(prevState, value, path));
|
||||
}, []);
|
||||
|
||||
return [state, setValue];
|
||||
}
|
||||
}
|
||||
|
||||
@@ -1,10 +1,12 @@
|
||||
import { JobConfig } from '@/types';
|
||||
import { Job } from '@prisma/client';
|
||||
import { apiClient } from '@/utils/api';
|
||||
|
||||
export const startJob = (jobID: string) => {
|
||||
return new Promise<void>((resolve, reject) => {
|
||||
fetch(`/api/jobs/${jobID}/start`)
|
||||
.then(res => res.json())
|
||||
apiClient
|
||||
.get(`/api/jobs/${jobID}/start`)
|
||||
.then(res => res.data)
|
||||
.then(data => {
|
||||
console.log('Job started:', data);
|
||||
resolve();
|
||||
@@ -18,8 +20,9 @@ export const startJob = (jobID: string) => {
|
||||
|
||||
export const stopJob = (jobID: string) => {
|
||||
return new Promise<void>((resolve, reject) => {
|
||||
fetch(`/api/jobs/${jobID}/stop`)
|
||||
.then(res => res.json())
|
||||
apiClient
|
||||
.get(`/api/jobs/${jobID}/stop`)
|
||||
.then(res => res.data)
|
||||
.then(data => {
|
||||
console.log('Job stopped:', data);
|
||||
resolve();
|
||||
@@ -33,8 +36,9 @@ export const stopJob = (jobID: string) => {
|
||||
|
||||
export const deleteJob = (jobID: string) => {
|
||||
return new Promise<void>((resolve, reject) => {
|
||||
fetch(`/api/jobs/${jobID}/delete`)
|
||||
.then(res => res.json())
|
||||
apiClient
|
||||
.get(`/api/jobs/${jobID}/delete`)
|
||||
.then(res => res.data)
|
||||
.then(data => {
|
||||
console.log('Job deleted:', data);
|
||||
resolve();
|
||||
@@ -67,9 +71,9 @@ export const getAvaliableJobActions = (job: Job) => {
|
||||
export const getNumberOfSamples = (job: Job) => {
|
||||
const jobConfig = getJobConfig(job);
|
||||
return jobConfig.config.process[0].sample?.prompts?.length || 0;
|
||||
}
|
||||
};
|
||||
|
||||
export const getTotalSteps = (job: Job) => {
|
||||
const jobConfig = getJobConfig(job);
|
||||
return jobConfig.config.process[0].train.steps;
|
||||
}
|
||||
};
|
||||
|
||||
@@ -1,26 +1,26 @@
|
||||
import type { Config } from "tailwindcss";
|
||||
import type { Config } from 'tailwindcss';
|
||||
|
||||
const config: Config = {
|
||||
content: [
|
||||
"./src/pages/**/*.{js,ts,jsx,tsx,mdx}",
|
||||
"./src/components/**/*.{js,ts,jsx,tsx,mdx}",
|
||||
"./src/app/**/*.{js,ts,jsx,tsx,mdx}",
|
||||
'./src/pages/**/*.{js,ts,jsx,tsx,mdx}',
|
||||
'./src/components/**/*.{js,ts,jsx,tsx,mdx}',
|
||||
'./src/app/**/*.{js,ts,jsx,tsx,mdx}',
|
||||
],
|
||||
darkMode: "class",
|
||||
darkMode: 'class',
|
||||
theme: {
|
||||
extend: {
|
||||
colors: {
|
||||
gray: {
|
||||
950: "#0a0a0a",
|
||||
900: "#171717",
|
||||
800: "#262626",
|
||||
700: "#404040",
|
||||
600: "#525252",
|
||||
500: "#737373",
|
||||
400: "#a3a3a3",
|
||||
300: "#d4d4d4",
|
||||
200: "#e5e5e5",
|
||||
100: "#f5f5f5",
|
||||
950: '#0a0a0a',
|
||||
900: '#171717',
|
||||
800: '#262626',
|
||||
700: '#404040',
|
||||
600: '#525252',
|
||||
500: '#737373',
|
||||
400: '#a3a3a3',
|
||||
300: '#d4d4d4',
|
||||
200: '#e5e5e5',
|
||||
100: '#f5f5f5',
|
||||
},
|
||||
},
|
||||
},
|
||||
@@ -28,4 +28,4 @@ const config: Config = {
|
||||
plugins: [],
|
||||
};
|
||||
|
||||
export default config;
|
||||
export default config;
|
||||
|
||||
Reference in New Issue
Block a user