mirror of
https://github.com/ostris/ai-toolkit.git
synced 2026-04-27 17:51:41 +00:00
Samples work in ui now
This commit is contained in:
20
ui/package-lock.json
generated
20
ui/package-lock.json
generated
@@ -14,6 +14,7 @@
|
||||
"classnames": "^2.5.1",
|
||||
"lucide-react": "^0.475.0",
|
||||
"next": "15.1.7",
|
||||
"node-cache": "^5.1.2",
|
||||
"prisma": "^6.3.1",
|
||||
"react": "^19.0.0",
|
||||
"react-dom": "^19.0.0",
|
||||
@@ -1414,6 +1415,14 @@
|
||||
"resolved": "https://registry.npmjs.org/client-only/-/client-only-0.0.1.tgz",
|
||||
"integrity": "sha512-IV3Ou0jSMzZrd3pZ48nLkT9DA7Ag1pnPzaiQhpW7c3RbcqqzvzzVu+L8gfqMp/8IM2MQtSiqaCxrrcfu8I8rMA=="
|
||||
},
|
||||
"node_modules/clone": {
|
||||
"version": "2.1.2",
|
||||
"resolved": "https://registry.npmjs.org/clone/-/clone-2.1.2.tgz",
|
||||
"integrity": "sha512-3Pe/CF1Nn94hyhIYpjtiLhdCoEoz0DqQ+988E9gmeEdQZlojxnOb74wctFyuwWQHzqyf9X7C7MG8juUpqBJT8w==",
|
||||
"engines": {
|
||||
"node": ">=0.8"
|
||||
}
|
||||
},
|
||||
"node_modules/clsx": {
|
||||
"version": "2.1.1",
|
||||
"resolved": "https://registry.npmjs.org/clsx/-/clsx-2.1.1.tgz",
|
||||
@@ -2891,6 +2900,17 @@
|
||||
"integrity": "sha512-5m3bsyrjFWE1xf7nz7YXdN4udnVtXK6/Yfgn5qnahL6bCkf2yKt4k3nuTKAtT4r3IG8JNR2ncsIMdZuAzJjHQQ==",
|
||||
"license": "MIT"
|
||||
},
|
||||
"node_modules/node-cache": {
|
||||
"version": "5.1.2",
|
||||
"resolved": "https://registry.npmjs.org/node-cache/-/node-cache-5.1.2.tgz",
|
||||
"integrity": "sha512-t1QzWwnk4sjLWaQAS8CHgOJ+RAfmHpxFWmc36IWTiWHQfs0w5JDMBS1b1ZxQteo0vVVuWJvIUKHDkkeK7vIGCg==",
|
||||
"dependencies": {
|
||||
"clone": "2.x"
|
||||
},
|
||||
"engines": {
|
||||
"node": ">= 8.0.0"
|
||||
}
|
||||
},
|
||||
"node_modules/node-gyp": {
|
||||
"version": "8.4.1",
|
||||
"resolved": "https://registry.npmjs.org/node-gyp/-/node-gyp-8.4.1.tgz",
|
||||
|
||||
@@ -17,6 +17,7 @@
|
||||
"classnames": "^2.5.1",
|
||||
"lucide-react": "^0.475.0",
|
||||
"next": "15.1.7",
|
||||
"node-cache": "^5.1.2",
|
||||
"prisma": "^6.3.1",
|
||||
"react": "^19.0.0",
|
||||
"react-dom": "^19.0.0",
|
||||
|
||||
@@ -2,7 +2,7 @@
|
||||
import { NextRequest, NextResponse } from 'next/server';
|
||||
import fs from 'fs';
|
||||
import path from 'path';
|
||||
import { getDatasetsRoot } from '@/app/api/datasets/utils';
|
||||
import { getDatasetsRoot } from '@/server/settings';
|
||||
|
||||
export async function GET(request: NextRequest, { params }: { params: { imagePath: string } }) {
|
||||
const { imagePath } = await params;
|
||||
|
||||
@@ -1,7 +1,7 @@
|
||||
import { NextResponse } from 'next/server';
|
||||
import fs from 'fs';
|
||||
import path from 'path';
|
||||
import { getDatasetsRoot } from '@/app/api/datasets/utils';
|
||||
import { getDatasetsRoot } from '@/server/settings';
|
||||
|
||||
export async function POST(request: Request) {
|
||||
try {
|
||||
|
||||
@@ -1,7 +1,7 @@
|
||||
import { NextResponse } from 'next/server';
|
||||
import fs from 'fs';
|
||||
import path from 'path';
|
||||
import { getDatasetsRoot } from '@/app/api/datasets/utils';
|
||||
import { getDatasetsRoot } from '@/server/settings';
|
||||
|
||||
export async function POST(request: Request) {
|
||||
try {
|
||||
|
||||
@@ -1,6 +1,6 @@
|
||||
import { NextResponse } from 'next/server';
|
||||
import fs from 'fs';
|
||||
import { getDatasetsRoot } from '@/app/api/datasets/utils';
|
||||
import { getDatasetsRoot } from '@/server/settings';
|
||||
|
||||
export async function GET() {
|
||||
try {
|
||||
|
||||
@@ -1,7 +1,7 @@
|
||||
import { NextResponse } from 'next/server';
|
||||
import fs from 'fs';
|
||||
import path from 'path';
|
||||
import { getDatasetsRoot } from '@/app/api/datasets/utils';
|
||||
import { getDatasetsRoot } from '@/server/settings';
|
||||
|
||||
export async function POST(request: Request) {
|
||||
const datasetsPath = await getDatasetsRoot();
|
||||
|
||||
@@ -2,7 +2,7 @@
|
||||
import { NextRequest, NextResponse } from 'next/server';
|
||||
import { writeFile, mkdir } from 'fs/promises';
|
||||
import { join } from 'path';
|
||||
import { getDatasetsRoot } from '@/app/api/datasets/utils';
|
||||
import { getDatasetsRoot } from '@/server/settings';
|
||||
|
||||
export async function POST(request: NextRequest) {
|
||||
try {
|
||||
|
||||
@@ -1,17 +0,0 @@
|
||||
import { PrismaClient } from '@prisma/client';
|
||||
import { defaultDatasetsFolder } from '@/paths';
|
||||
|
||||
const prisma = new PrismaClient();
|
||||
|
||||
export const getDatasetsRoot = async () => {
|
||||
let row = await prisma.settings.findFirst({
|
||||
where: {
|
||||
key: 'DATASETS_FOLDER',
|
||||
},
|
||||
});
|
||||
let datasetsPath = defaultDatasetsFolder;
|
||||
if (row?.value && row.value !== '') {
|
||||
datasetsPath = row.value;
|
||||
}
|
||||
return datasetsPath;
|
||||
};
|
||||
@@ -2,23 +2,25 @@
|
||||
import { NextRequest, NextResponse } from 'next/server';
|
||||
import fs from 'fs';
|
||||
import path from 'path';
|
||||
import { getDatasetsRoot } from '@/app/api/datasets/utils';
|
||||
import { getDatasetsRoot, getTrainingFolder } from '@/server/settings';
|
||||
|
||||
export async function GET(request: NextRequest, { params }: { params: { imagePath: string } }) {
|
||||
const { imagePath } = await params;
|
||||
try {
|
||||
// Decode the path
|
||||
const filepath = decodeURIComponent(imagePath);
|
||||
console.log('Serving image:', filepath);
|
||||
|
||||
// Get allowed directories
|
||||
const allowedDir = await getDatasetsRoot();
|
||||
const datasetRoot = await getDatasetsRoot();
|
||||
const trainingRoot = await getTrainingFolder();
|
||||
|
||||
const allowedDirs = [datasetRoot, trainingRoot];
|
||||
|
||||
// Security check: Ensure path is in allowed directory
|
||||
const isAllowed = filepath.startsWith(allowedDir) && !filepath.includes('..');
|
||||
const isAllowed = allowedDirs.some(allowedDir => filepath.startsWith(allowedDir)) && !filepath.includes('..');
|
||||
|
||||
if (!isAllowed) {
|
||||
console.warn(`Access denied: ${filepath} not in ${allowedDir}`);
|
||||
console.warn(`Access denied: ${filepath} not in ${allowedDirs.join(', ')}`);
|
||||
return new NextResponse('Access denied', { status: 403 });
|
||||
}
|
||||
|
||||
|
||||
@@ -1,6 +1,6 @@
|
||||
import { NextResponse } from 'next/server';
|
||||
import fs from 'fs';
|
||||
import { getDatasetsRoot } from '@/app/api/datasets/utils';
|
||||
import { getDatasetsRoot } from '@/server/settings';
|
||||
|
||||
export async function POST(request: Request) {
|
||||
try {
|
||||
|
||||
@@ -1,6 +1,6 @@
|
||||
import { NextResponse } from 'next/server';
|
||||
import fs from 'fs';
|
||||
import { getDatasetsRoot } from '@/app/api/datasets/utils';
|
||||
import { getDatasetsRoot } from '@/server/settings';
|
||||
|
||||
export async function POST(request: Request) {
|
||||
try {
|
||||
|
||||
40
ui/src/app/api/jobs/[jobID]/samples/route.ts
Normal file
40
ui/src/app/api/jobs/[jobID]/samples/route.ts
Normal file
@@ -0,0 +1,40 @@
|
||||
import { NextRequest, NextResponse } from 'next/server';
|
||||
import { PrismaClient } from '@prisma/client';
|
||||
import path from 'path';
|
||||
import fs from 'fs';
|
||||
import { getTrainingFolder } from '@/server/settings';
|
||||
|
||||
const prisma = new PrismaClient();
|
||||
|
||||
export async function GET(request: NextRequest, { params }: { params: { jobID: string } }) {
|
||||
const { jobID } = await params;
|
||||
|
||||
const job = await prisma.job.findUnique({
|
||||
where: { id: jobID },
|
||||
});
|
||||
|
||||
if (!job) {
|
||||
return NextResponse.json({ error: 'Job not found' }, { status: 404 });
|
||||
}
|
||||
|
||||
// setup the training
|
||||
const trainingFolder = await getTrainingFolder();
|
||||
|
||||
const samplesFolder = path.join(trainingFolder, job.name, 'samples');
|
||||
if (!fs.existsSync(samplesFolder)) {
|
||||
return NextResponse.json({ samples: [] });
|
||||
}
|
||||
|
||||
// find all img (png, jpg, jpeg) files in the samples folder
|
||||
const samples = fs
|
||||
.readdirSync(samplesFolder)
|
||||
.filter(file => {
|
||||
return file.endsWith('.png') || file.endsWith('.jpg') || file.endsWith('.jpeg');
|
||||
})
|
||||
.map(file => {
|
||||
return path.join(samplesFolder, file);
|
||||
})
|
||||
.sort();
|
||||
|
||||
return NextResponse.json({ samples });
|
||||
}
|
||||
@@ -4,7 +4,7 @@ import { TOOLKIT_ROOT, defaultTrainFolder } from '@/paths';
|
||||
import { spawn } from 'child_process';
|
||||
import path from 'path';
|
||||
import fs from 'fs';
|
||||
|
||||
import { getTrainingFolder } from '@/server/settings';
|
||||
|
||||
const prisma = new PrismaClient();
|
||||
|
||||
@@ -30,18 +30,10 @@ export async function GET(request: NextRequest, { params }: { params: { jobID: s
|
||||
});
|
||||
|
||||
// setup the training
|
||||
const settings = await prisma.settings.findMany();
|
||||
const settingsObject = settings.reduce((acc: any, setting) => {
|
||||
acc[setting.key] = setting.value;
|
||||
return acc;
|
||||
}, {});
|
||||
|
||||
// if TRAINING_FOLDER is not set, use default
|
||||
if (!settingsObject.TRAINING_FOLDER || settingsObject.TRAINING_FOLDER === '') {
|
||||
settingsObject.TRAINING_FOLDER = defaultTrainFolder;
|
||||
}
|
||||
const trainingRoot = await getTrainingFolder();
|
||||
|
||||
const trainingFolder = path.join(settingsObject.TRAINING_FOLDER, job.name);
|
||||
const trainingFolder = path.join(trainingRoot, job.name);
|
||||
if (!fs.existsSync(trainingFolder)) {
|
||||
fs.mkdirSync(trainingFolder, { recursive: true });
|
||||
}
|
||||
@@ -53,7 +45,6 @@ export async function GET(request: NextRequest, { params }: { params: { jobID: s
|
||||
const jobConfig = JSON.parse(job.job_config);
|
||||
jobConfig.config.process[0].sqlite_db_path = path.join(TOOLKIT_ROOT, 'aitk_db.db');
|
||||
|
||||
|
||||
// write the config file
|
||||
fs.writeFileSync(configPath, JSON.stringify(jobConfig, null, 2));
|
||||
|
||||
@@ -70,7 +61,10 @@ export async function GET(request: NextRequest, { params }: { params: { jobID: s
|
||||
return NextResponse.json({ error: 'run.py not found' }, { status: 500 });
|
||||
}
|
||||
|
||||
console.log('Spawning command:', `AITK_JOB_ID=${jobID} CUDA_VISIBLE_DEVICES=${job.gpu_ids} ${pythonPath} ${runFilePath} ${configPath}`);
|
||||
console.log(
|
||||
'Spawning command:',
|
||||
`AITK_JOB_ID=${jobID} CUDA_VISIBLE_DEVICES=${job.gpu_ids} ${pythonPath} ${runFilePath} ${configPath}`,
|
||||
);
|
||||
|
||||
// start job
|
||||
const subprocess = spawn(pythonPath, [runFilePath, configPath], {
|
||||
@@ -83,7 +77,7 @@ export async function GET(request: NextRequest, { params }: { params: { jobID: s
|
||||
},
|
||||
cwd: TOOLKIT_ROOT,
|
||||
});
|
||||
|
||||
|
||||
subprocess.unref();
|
||||
|
||||
return NextResponse.json(job);
|
||||
|
||||
@@ -1,6 +1,7 @@
|
||||
import { NextResponse } from 'next/server';
|
||||
import { PrismaClient } from '@prisma/client';
|
||||
import { defaultTrainFolder, defaultDatasetsFolder } from '@/paths';
|
||||
import {flushCache} from '@/server/settings';
|
||||
|
||||
const prisma = new PrismaClient();
|
||||
|
||||
@@ -49,6 +50,8 @@ export async function POST(request: Request) {
|
||||
}),
|
||||
]);
|
||||
|
||||
flushCache();
|
||||
|
||||
return NextResponse.json({ success: true });
|
||||
} catch (error) {
|
||||
return NextResponse.json({ error: 'Failed to update settings' }, { status: 500 });
|
||||
|
||||
@@ -1,16 +1,41 @@
|
||||
'use client';
|
||||
|
||||
import { useEffect, useState, use } from 'react';
|
||||
import { useMemo, useState, use } from 'react';
|
||||
import { FaChevronLeft } from 'react-icons/fa';
|
||||
import { Button } from '@headlessui/react';
|
||||
import { TopBar, MainContent } from '@/components/layout';
|
||||
import useJob from '@/hooks/useJob';
|
||||
import { startJob, stopJob } from '@/utils/jobs';
|
||||
import SampleImages from '@/components/SampleImages';
|
||||
import JobOverview from '@/components/JobOverview';
|
||||
import { JobConfig } from '@/types';
|
||||
|
||||
type PageKey = 'overview' | 'samples';
|
||||
|
||||
interface Page {
|
||||
name: string;
|
||||
value: PageKey;
|
||||
}
|
||||
|
||||
const pages: Page[] = [
|
||||
{ name: 'Overview', value: 'overview' },
|
||||
{ name: 'Samples', value: 'samples' },
|
||||
];
|
||||
|
||||
export default function JobPage({ params }: { params: { jobID: string } }) {
|
||||
const usableParams = use(params as any) as { jobID: string };
|
||||
const jobID = usableParams.jobID;
|
||||
const { job, status, refreshJobs } = useJob(jobID, 5000);
|
||||
const { job, status, refreshJob } = useJob(jobID, 5000);
|
||||
const [pageKey, setPageKey] = useState<PageKey>('overview');
|
||||
|
||||
const numSamples = useMemo(() => {
|
||||
if (job?.job_config) {
|
||||
const jobConfig = JSON.parse(job.job_config) as JobConfig;
|
||||
const sampleConfig = jobConfig.config.process[0].sample;
|
||||
return sampleConfig.prompts.length;
|
||||
}
|
||||
return 10;
|
||||
}, [job]);
|
||||
|
||||
return (
|
||||
<>
|
||||
@@ -28,8 +53,8 @@ export default function JobPage({ params }: { params: { jobID: string } }) {
|
||||
{job?.status === 'running' && (
|
||||
<Button
|
||||
onClick={async () => {
|
||||
await stopJob(jobID);
|
||||
refreshJobs();
|
||||
await stopJob(jobID);
|
||||
refreshJob();
|
||||
}}
|
||||
className="bg-red-500 text-white px-4 py-1 rounded-sm"
|
||||
>
|
||||
@@ -39,8 +64,8 @@ export default function JobPage({ params }: { params: { jobID: string } }) {
|
||||
{(job?.status === 'stopped' || job?.status === 'error') && (
|
||||
<Button
|
||||
onClick={async () => {
|
||||
await startJob(jobID);
|
||||
refreshJobs();
|
||||
await startJob(jobID);
|
||||
refreshJob();
|
||||
}}
|
||||
className="bg-green-800 text-white px-4 py-1 rounded-sm"
|
||||
>
|
||||
@@ -48,25 +73,27 @@ export default function JobPage({ params }: { params: { jobID: string } }) {
|
||||
</Button>
|
||||
)}
|
||||
</TopBar>
|
||||
<MainContent>
|
||||
<MainContent className="pt-24">
|
||||
{status === 'loading' && job == null && <p>Loading...</p>}
|
||||
{status === 'error' && job == null && <p>Error fetching job</p>}
|
||||
{job && (
|
||||
<>
|
||||
<div className="grid grid-cols-1 gap-4 md:grid-cols-2">
|
||||
<div className="">
|
||||
<h2 className="text-lg font-semibold">Job Details</h2>
|
||||
<p className="text-gray-400">ID: {job.id}</p>
|
||||
<p className="text-gray-400">Name: {job.name}</p>
|
||||
<p className="text-gray-400">GPUs: {job.gpu_ids}</p>
|
||||
<p className="text-gray-400">Status: {job.status}</p>
|
||||
<p className="text-gray-400">Info: {job.info}</p>
|
||||
<p className="text-gray-400">Step: {job.step}</p>
|
||||
</div>
|
||||
</div>
|
||||
{pageKey === 'overview' && <JobOverview job={job} />}
|
||||
{pageKey === 'samples' && <SampleImages job={job} />}
|
||||
</>
|
||||
)}
|
||||
</MainContent>
|
||||
<div className="bg-gray-800 absolute top-12 left-0 w-full h-8 flex items-center px-2 text-sm">
|
||||
{pages.map(page => (
|
||||
<Button
|
||||
key={page.value}
|
||||
onClick={() => setPageKey(page.value)}
|
||||
className={`px-4 py-1 h-8 ${page.value === pageKey ? 'bg-gray-300 dark:bg-gray-700' : ''}`}
|
||||
>
|
||||
{page.name}
|
||||
</Button>
|
||||
))}
|
||||
</div>
|
||||
</>
|
||||
);
|
||||
}
|
||||
|
||||
@@ -19,7 +19,7 @@ export default function RootLayout({ children }: { children: React.ReactNode })
|
||||
<ThemeProvider>
|
||||
<div className="flex h-screen bg-gray-950">
|
||||
<Sidebar />
|
||||
<main className="flex-1 p-8 overflow-auto bg-gray-950 text-gray-100 relative">{children}</main>
|
||||
<main className="flex-1 overflow-auto bg-gray-950 text-gray-100 relative">{children}</main>
|
||||
</div>
|
||||
</ThemeProvider>
|
||||
<ConfirmModal />
|
||||
|
||||
23
ui/src/components/JobOverview.tsx
Normal file
23
ui/src/components/JobOverview.tsx
Normal file
@@ -0,0 +1,23 @@
|
||||
import { Job } from '@prisma/client';
|
||||
|
||||
interface JobOverviewProps {
|
||||
job: Job;
|
||||
}
|
||||
|
||||
export default function JobOverview({ job }: JobOverviewProps) {
|
||||
return (
|
||||
<>
|
||||
<div className="grid grid-cols-1 gap-4 md:grid-cols-2">
|
||||
<div className="">
|
||||
<h2 className="text-lg font-semibold">Job Details</h2>
|
||||
<p className="text-gray-400">ID: {job.id}</p>
|
||||
<p className="text-gray-400">Name: {job.name}</p>
|
||||
<p className="text-gray-400">GPUs: {job.gpu_ids}</p>
|
||||
<p className="text-gray-400">Status: {job.status}</p>
|
||||
<p className="text-gray-400">Info: {job.info}</p>
|
||||
<p className="text-gray-400">Step: {job.step}</p>
|
||||
</div>
|
||||
</div>
|
||||
</>
|
||||
);
|
||||
}
|
||||
67
ui/src/components/SampleImageCard.tsx
Normal file
67
ui/src/components/SampleImageCard.tsx
Normal file
@@ -0,0 +1,67 @@
|
||||
import React, { useRef, useEffect, useState, ReactNode } from 'react';
|
||||
|
||||
interface SampleImageCardProps {
|
||||
imageUrl: string;
|
||||
alt: string;
|
||||
children?: ReactNode;
|
||||
className?: string;
|
||||
onDelete?: () => void;
|
||||
}
|
||||
|
||||
const SampleImageCard: React.FC<SampleImageCardProps> = ({ imageUrl, alt, children, className = '' }) => {
|
||||
const cardRef = useRef<HTMLDivElement>(null);
|
||||
const [isVisible, setIsVisible] = useState<boolean>(false);
|
||||
const [loaded, setLoaded] = useState<boolean>(false);
|
||||
|
||||
useEffect(() => {
|
||||
// Create intersection observer to check visibility
|
||||
const observer = new IntersectionObserver(
|
||||
entries => {
|
||||
if (entries[0].isIntersecting) {
|
||||
setIsVisible(true);
|
||||
observer.disconnect();
|
||||
}
|
||||
},
|
||||
{ threshold: 0.1 },
|
||||
);
|
||||
|
||||
if (cardRef.current) {
|
||||
observer.observe(cardRef.current);
|
||||
}
|
||||
|
||||
return () => {
|
||||
observer.disconnect();
|
||||
};
|
||||
}, []);
|
||||
|
||||
const handleLoad = (): void => {
|
||||
setLoaded(true);
|
||||
};
|
||||
|
||||
return (
|
||||
<div className={`flex flex-col ${className}`}>
|
||||
{/* Square image container */}
|
||||
<div
|
||||
ref={cardRef}
|
||||
className="relative w-full"
|
||||
style={{ paddingBottom: '100%' }} // Make it square
|
||||
>
|
||||
<div className="absolute inset-0 rounded-t-lg shadow-md">
|
||||
{isVisible && (
|
||||
<img
|
||||
src={`/api/img/${encodeURIComponent(imageUrl)}`}
|
||||
alt={alt}
|
||||
onLoad={handleLoad}
|
||||
className={`w-full h-full object-contain transition-opacity duration-300 ${
|
||||
loaded ? 'opacity-100' : 'opacity-0'
|
||||
}`}
|
||||
/>
|
||||
)}
|
||||
{children && <div className="absolute inset-0 flex items-center justify-center">{children}</div>}
|
||||
</div>
|
||||
</div>
|
||||
</div>
|
||||
);
|
||||
};
|
||||
|
||||
export default SampleImageCard;
|
||||
89
ui/src/components/SampleImages.tsx
Normal file
89
ui/src/components/SampleImages.tsx
Normal file
@@ -0,0 +1,89 @@
|
||||
import { useMemo } from 'react';
|
||||
import useSampleImages from '@/hooks/useSampleImages';
|
||||
import SampleImageCard from './SampleImageCard';
|
||||
import { Job } from '@prisma/client';
|
||||
import { JobConfig } from '@/types';
|
||||
|
||||
interface SampleImagesProps {
|
||||
job: Job;
|
||||
}
|
||||
|
||||
export default function SampleImages({ job }: SampleImagesProps) {
|
||||
const { sampleImages, status, refreshSampleImages } = useSampleImages(job.id, 5000);
|
||||
const numSamples = useMemo(() => {
|
||||
if (job?.job_config) {
|
||||
const jobConfig = JSON.parse(job.job_config) as JobConfig;
|
||||
const sampleConfig = jobConfig.config.process[0].sample;
|
||||
return sampleConfig.prompts.length;
|
||||
}
|
||||
return 10;
|
||||
}, [job]);
|
||||
|
||||
// Use direct Tailwind class without string interpolation
|
||||
// This way Tailwind can properly generate the class
|
||||
// I hate this, but it's the only way to make it work
|
||||
const gridColsClass = useMemo(() => {
|
||||
const cols = Math.min(numSamples, 20);
|
||||
|
||||
switch (cols) {
|
||||
case 1:
|
||||
return 'grid-cols-1';
|
||||
case 2:
|
||||
return 'grid-cols-2';
|
||||
case 3:
|
||||
return 'grid-cols-3';
|
||||
case 4:
|
||||
return 'grid-cols-4';
|
||||
case 5:
|
||||
return 'grid-cols-5';
|
||||
case 6:
|
||||
return 'grid-cols-6';
|
||||
case 7:
|
||||
return 'grid-cols-7';
|
||||
case 8:
|
||||
return 'grid-cols-8';
|
||||
case 9:
|
||||
return 'grid-cols-9';
|
||||
case 10:
|
||||
return 'grid-cols-10';
|
||||
case 11:
|
||||
return 'grid-cols-11';
|
||||
case 12:
|
||||
return 'grid-cols-12';
|
||||
case 13:
|
||||
return 'grid-cols-13';
|
||||
case 14:
|
||||
return 'grid-cols-14';
|
||||
case 15:
|
||||
return 'grid-cols-15';
|
||||
case 16:
|
||||
return 'grid-cols-16';
|
||||
case 17:
|
||||
return 'grid-cols-17';
|
||||
case 18:
|
||||
return 'grid-cols-18';
|
||||
case 19:
|
||||
return 'grid-cols-19';
|
||||
case 20:
|
||||
return 'grid-cols-20';
|
||||
default:
|
||||
return 'grid-cols-1';
|
||||
}
|
||||
}, [numSamples]);
|
||||
|
||||
return (
|
||||
<div>
|
||||
<div className='pb-4'>
|
||||
{status === 'loading' && sampleImages.length === 0 && <p>Loading...</p>}
|
||||
{status === 'error' && <p>Error fetching sample images</p>}
|
||||
{sampleImages && (
|
||||
<div className={`grid ${gridColsClass} gap-1`}>
|
||||
{sampleImages.map((sample: string) => (
|
||||
<SampleImageCard key={sample} imageUrl={sample} alt="Sample Image" />
|
||||
))}
|
||||
</div>
|
||||
)}
|
||||
</div>
|
||||
</div>
|
||||
);
|
||||
}
|
||||
@@ -20,7 +20,7 @@ export const TopBar: React.FC<Props> = ({ children, className }) => {
|
||||
|
||||
export const MainContent: React.FC<Props> = ({ children, className }) => {
|
||||
return (
|
||||
<div className={classNames('pt-16 px-4 absolute top-0 left-0 w-full h-full overflow-auto', className)}>
|
||||
<div className={classNames('pt-14 px-4 absolute top-0 left-0 w-full h-full overflow-auto', className)}>
|
||||
{children ? children : null}
|
||||
</div>
|
||||
);
|
||||
|
||||
@@ -7,7 +7,7 @@ export default function useJob(jobID: string, reloadInterval: null | number = nu
|
||||
const [job, setJob] = useState<Job | null>(null);
|
||||
const [status, setStatus] = useState<'idle' | 'loading' | 'success' | 'error'>('idle');
|
||||
|
||||
const refreshJobs = () => {
|
||||
const refreshJob = () => {
|
||||
setStatus('loading');
|
||||
fetch(`/api/jobs?id=${jobID}`)
|
||||
.then(res => res.json())
|
||||
@@ -23,11 +23,11 @@ export default function useJob(jobID: string, reloadInterval: null | number = nu
|
||||
};
|
||||
|
||||
useEffect(() => {
|
||||
refreshJobs();
|
||||
refreshJob();
|
||||
|
||||
if (reloadInterval) {
|
||||
const interval = setInterval(() => {
|
||||
refreshJobs();
|
||||
refreshJob();
|
||||
}, reloadInterval);
|
||||
|
||||
return () => {
|
||||
@@ -36,5 +36,5 @@ export default function useJob(jobID: string, reloadInterval: null | number = nu
|
||||
}
|
||||
}, [jobID]);
|
||||
|
||||
return { job, setJob, status, refreshJobs };
|
||||
return { job, setJob, status, refreshJob };
|
||||
}
|
||||
|
||||
41
ui/src/hooks/useSampleImages.tsx
Normal file
41
ui/src/hooks/useSampleImages.tsx
Normal file
@@ -0,0 +1,41 @@
|
||||
'use client';
|
||||
|
||||
import { useEffect, useState } from 'react';
|
||||
import { Job } from '@prisma/client';
|
||||
|
||||
export default function useSampleImages(jobID: string, reloadInterval: null | number = null) {
|
||||
const [sampleImages, setSampleImages] = useState<string[]>([]);
|
||||
const [status, setStatus] = useState<'idle' | 'loading' | 'success' | 'error'>('idle');
|
||||
|
||||
const refreshSampleImages = () => {
|
||||
setStatus('loading');
|
||||
fetch(`/api/jobs/${jobID}/samples`)
|
||||
.then(res => res.json())
|
||||
.then(data => {
|
||||
if (data.samples) {
|
||||
setSampleImages(data.samples);
|
||||
}
|
||||
setStatus('success');
|
||||
})
|
||||
.catch(error => {
|
||||
console.error('Error fetching datasets:', error);
|
||||
setStatus('error');
|
||||
});
|
||||
};
|
||||
|
||||
useEffect(() => {
|
||||
refreshSampleImages();
|
||||
|
||||
if (reloadInterval) {
|
||||
const interval = setInterval(() => {
|
||||
refreshSampleImages();
|
||||
}, reloadInterval);
|
||||
|
||||
return () => {
|
||||
clearInterval(interval);
|
||||
};
|
||||
}
|
||||
}, [jobID]);
|
||||
|
||||
return { sampleImages, setSampleImages, status, refreshSampleImages };
|
||||
}
|
||||
50
ui/src/server/settings.ts
Normal file
50
ui/src/server/settings.ts
Normal file
@@ -0,0 +1,50 @@
|
||||
import { PrismaClient } from '@prisma/client';
|
||||
import { defaultDatasetsFolder } from '@/paths';
|
||||
import { defaultTrainFolder } from '@/paths';
|
||||
import NodeCache from 'node-cache';
|
||||
|
||||
const myCache = new NodeCache();
|
||||
const prisma = new PrismaClient();
|
||||
|
||||
export const flushCache = () => {
|
||||
myCache.flushAll();
|
||||
};
|
||||
|
||||
export const getDatasetsRoot = async () => {
|
||||
const key = 'DATASETS_FOLDER';
|
||||
let datasetsPath = myCache.get(key) as string;
|
||||
if (datasetsPath) {
|
||||
return datasetsPath;
|
||||
}
|
||||
let row = await prisma.settings.findFirst({
|
||||
where: {
|
||||
key: 'DATASETS_FOLDER',
|
||||
},
|
||||
});
|
||||
datasetsPath = defaultDatasetsFolder;
|
||||
if (row?.value && row.value !== '') {
|
||||
datasetsPath = row.value;
|
||||
}
|
||||
myCache.set(key, datasetsPath);
|
||||
return datasetsPath as string;
|
||||
};
|
||||
|
||||
|
||||
export const getTrainingFolder = async () => {
|
||||
const key = 'TRAINING_FOLDER';
|
||||
let trainingRoot = myCache.get(key) as string;
|
||||
if (trainingRoot) {
|
||||
return trainingRoot;
|
||||
}
|
||||
let row = await prisma.settings.findFirst({
|
||||
where: {
|
||||
key: key,
|
||||
},
|
||||
});
|
||||
trainingRoot = defaultTrainFolder;
|
||||
if (row?.value && row.value !== '') {
|
||||
trainingRoot = row.value;
|
||||
}
|
||||
myCache.set(key, trainingRoot);
|
||||
return trainingRoot as string;
|
||||
};
|
||||
Reference in New Issue
Block a user