Samples work in ui now

This commit is contained in:
Jaret Burkett
2025-02-21 20:28:52 -07:00
parent 2b6e66e0cb
commit 710c6de1c9
24 changed files with 408 additions and 68 deletions

20
ui/package-lock.json generated
View File

@@ -14,6 +14,7 @@
"classnames": "^2.5.1",
"lucide-react": "^0.475.0",
"next": "15.1.7",
"node-cache": "^5.1.2",
"prisma": "^6.3.1",
"react": "^19.0.0",
"react-dom": "^19.0.0",
@@ -1414,6 +1415,14 @@
"resolved": "https://registry.npmjs.org/client-only/-/client-only-0.0.1.tgz",
"integrity": "sha512-IV3Ou0jSMzZrd3pZ48nLkT9DA7Ag1pnPzaiQhpW7c3RbcqqzvzzVu+L8gfqMp/8IM2MQtSiqaCxrrcfu8I8rMA=="
},
"node_modules/clone": {
"version": "2.1.2",
"resolved": "https://registry.npmjs.org/clone/-/clone-2.1.2.tgz",
"integrity": "sha512-3Pe/CF1Nn94hyhIYpjtiLhdCoEoz0DqQ+988E9gmeEdQZlojxnOb74wctFyuwWQHzqyf9X7C7MG8juUpqBJT8w==",
"engines": {
"node": ">=0.8"
}
},
"node_modules/clsx": {
"version": "2.1.1",
"resolved": "https://registry.npmjs.org/clsx/-/clsx-2.1.1.tgz",
@@ -2891,6 +2900,17 @@
"integrity": "sha512-5m3bsyrjFWE1xf7nz7YXdN4udnVtXK6/Yfgn5qnahL6bCkf2yKt4k3nuTKAtT4r3IG8JNR2ncsIMdZuAzJjHQQ==",
"license": "MIT"
},
"node_modules/node-cache": {
"version": "5.1.2",
"resolved": "https://registry.npmjs.org/node-cache/-/node-cache-5.1.2.tgz",
"integrity": "sha512-t1QzWwnk4sjLWaQAS8CHgOJ+RAfmHpxFWmc36IWTiWHQfs0w5JDMBS1b1ZxQteo0vVVuWJvIUKHDkkeK7vIGCg==",
"dependencies": {
"clone": "2.x"
},
"engines": {
"node": ">= 8.0.0"
}
},
"node_modules/node-gyp": {
"version": "8.4.1",
"resolved": "https://registry.npmjs.org/node-gyp/-/node-gyp-8.4.1.tgz",

View File

@@ -17,6 +17,7 @@
"classnames": "^2.5.1",
"lucide-react": "^0.475.0",
"next": "15.1.7",
"node-cache": "^5.1.2",
"prisma": "^6.3.1",
"react": "^19.0.0",
"react-dom": "^19.0.0",

View File

@@ -2,7 +2,7 @@
import { NextRequest, NextResponse } from 'next/server';
import fs from 'fs';
import path from 'path';
import { getDatasetsRoot } from '@/app/api/datasets/utils';
import { getDatasetsRoot } from '@/server/settings';
export async function GET(request: NextRequest, { params }: { params: { imagePath: string } }) {
const { imagePath } = await params;

View File

@@ -1,7 +1,7 @@
import { NextResponse } from 'next/server';
import fs from 'fs';
import path from 'path';
import { getDatasetsRoot } from '@/app/api/datasets/utils';
import { getDatasetsRoot } from '@/server/settings';
export async function POST(request: Request) {
try {

View File

@@ -1,7 +1,7 @@
import { NextResponse } from 'next/server';
import fs from 'fs';
import path from 'path';
import { getDatasetsRoot } from '@/app/api/datasets/utils';
import { getDatasetsRoot } from '@/server/settings';
export async function POST(request: Request) {
try {

View File

@@ -1,6 +1,6 @@
import { NextResponse } from 'next/server';
import fs from 'fs';
import { getDatasetsRoot } from '@/app/api/datasets/utils';
import { getDatasetsRoot } from '@/server/settings';
export async function GET() {
try {

View File

@@ -1,7 +1,7 @@
import { NextResponse } from 'next/server';
import fs from 'fs';
import path from 'path';
import { getDatasetsRoot } from '@/app/api/datasets/utils';
import { getDatasetsRoot } from '@/server/settings';
export async function POST(request: Request) {
const datasetsPath = await getDatasetsRoot();

View File

@@ -2,7 +2,7 @@
import { NextRequest, NextResponse } from 'next/server';
import { writeFile, mkdir } from 'fs/promises';
import { join } from 'path';
import { getDatasetsRoot } from '@/app/api/datasets/utils';
import { getDatasetsRoot } from '@/server/settings';
export async function POST(request: NextRequest) {
try {

View File

@@ -1,17 +0,0 @@
import { PrismaClient } from '@prisma/client';
import { defaultDatasetsFolder } from '@/paths';
const prisma = new PrismaClient();
export const getDatasetsRoot = async () => {
let row = await prisma.settings.findFirst({
where: {
key: 'DATASETS_FOLDER',
},
});
let datasetsPath = defaultDatasetsFolder;
if (row?.value && row.value !== '') {
datasetsPath = row.value;
}
return datasetsPath;
};

View File

@@ -2,23 +2,25 @@
import { NextRequest, NextResponse } from 'next/server';
import fs from 'fs';
import path from 'path';
import { getDatasetsRoot } from '@/app/api/datasets/utils';
import { getDatasetsRoot, getTrainingFolder } from '@/server/settings';
export async function GET(request: NextRequest, { params }: { params: { imagePath: string } }) {
const { imagePath } = await params;
try {
// Decode the path
const filepath = decodeURIComponent(imagePath);
console.log('Serving image:', filepath);
// Get allowed directories
const allowedDir = await getDatasetsRoot();
const datasetRoot = await getDatasetsRoot();
const trainingRoot = await getTrainingFolder();
const allowedDirs = [datasetRoot, trainingRoot];
// Security check: Ensure path is in allowed directory
const isAllowed = filepath.startsWith(allowedDir) && !filepath.includes('..');
const isAllowed = allowedDirs.some(allowedDir => filepath.startsWith(allowedDir)) && !filepath.includes('..');
if (!isAllowed) {
console.warn(`Access denied: ${filepath} not in ${allowedDir}`);
console.warn(`Access denied: ${filepath} not in ${allowedDirs.join(', ')}`);
return new NextResponse('Access denied', { status: 403 });
}

View File

@@ -1,6 +1,6 @@
import { NextResponse } from 'next/server';
import fs from 'fs';
import { getDatasetsRoot } from '@/app/api/datasets/utils';
import { getDatasetsRoot } from '@/server/settings';
export async function POST(request: Request) {
try {

View File

@@ -1,6 +1,6 @@
import { NextResponse } from 'next/server';
import fs from 'fs';
import { getDatasetsRoot } from '@/app/api/datasets/utils';
import { getDatasetsRoot } from '@/server/settings';
export async function POST(request: Request) {
try {

View File

@@ -0,0 +1,40 @@
import { NextRequest, NextResponse } from 'next/server';
import { PrismaClient } from '@prisma/client';
import path from 'path';
import fs from 'fs';
import { getTrainingFolder } from '@/server/settings';
const prisma = new PrismaClient();
export async function GET(request: NextRequest, { params }: { params: { jobID: string } }) {
const { jobID } = await params;
const job = await prisma.job.findUnique({
where: { id: jobID },
});
if (!job) {
return NextResponse.json({ error: 'Job not found' }, { status: 404 });
}
// setup the training
const trainingFolder = await getTrainingFolder();
const samplesFolder = path.join(trainingFolder, job.name, 'samples');
if (!fs.existsSync(samplesFolder)) {
return NextResponse.json({ samples: [] });
}
// find all img (png, jpg, jpeg) files in the samples folder
const samples = fs
.readdirSync(samplesFolder)
.filter(file => {
return file.endsWith('.png') || file.endsWith('.jpg') || file.endsWith('.jpeg');
})
.map(file => {
return path.join(samplesFolder, file);
})
.sort();
return NextResponse.json({ samples });
}

View File

@@ -4,7 +4,7 @@ import { TOOLKIT_ROOT, defaultTrainFolder } from '@/paths';
import { spawn } from 'child_process';
import path from 'path';
import fs from 'fs';
import { getTrainingFolder } from '@/server/settings';
const prisma = new PrismaClient();
@@ -30,18 +30,10 @@ export async function GET(request: NextRequest, { params }: { params: { jobID: s
});
// setup the training
const settings = await prisma.settings.findMany();
const settingsObject = settings.reduce((acc: any, setting) => {
acc[setting.key] = setting.value;
return acc;
}, {});
// if TRAINING_FOLDER is not set, use default
if (!settingsObject.TRAINING_FOLDER || settingsObject.TRAINING_FOLDER === '') {
settingsObject.TRAINING_FOLDER = defaultTrainFolder;
}
const trainingRoot = await getTrainingFolder();
const trainingFolder = path.join(settingsObject.TRAINING_FOLDER, job.name);
const trainingFolder = path.join(trainingRoot, job.name);
if (!fs.existsSync(trainingFolder)) {
fs.mkdirSync(trainingFolder, { recursive: true });
}
@@ -53,7 +45,6 @@ export async function GET(request: NextRequest, { params }: { params: { jobID: s
const jobConfig = JSON.parse(job.job_config);
jobConfig.config.process[0].sqlite_db_path = path.join(TOOLKIT_ROOT, 'aitk_db.db');
// write the config file
fs.writeFileSync(configPath, JSON.stringify(jobConfig, null, 2));
@@ -70,7 +61,10 @@ export async function GET(request: NextRequest, { params }: { params: { jobID: s
return NextResponse.json({ error: 'run.py not found' }, { status: 500 });
}
console.log('Spawning command:', `AITK_JOB_ID=${jobID} CUDA_VISIBLE_DEVICES=${job.gpu_ids} ${pythonPath} ${runFilePath} ${configPath}`);
console.log(
'Spawning command:',
`AITK_JOB_ID=${jobID} CUDA_VISIBLE_DEVICES=${job.gpu_ids} ${pythonPath} ${runFilePath} ${configPath}`,
);
// start job
const subprocess = spawn(pythonPath, [runFilePath, configPath], {
@@ -83,7 +77,7 @@ export async function GET(request: NextRequest, { params }: { params: { jobID: s
},
cwd: TOOLKIT_ROOT,
});
subprocess.unref();
return NextResponse.json(job);

View File

@@ -1,6 +1,7 @@
import { NextResponse } from 'next/server';
import { PrismaClient } from '@prisma/client';
import { defaultTrainFolder, defaultDatasetsFolder } from '@/paths';
import {flushCache} from '@/server/settings';
const prisma = new PrismaClient();
@@ -49,6 +50,8 @@ export async function POST(request: Request) {
}),
]);
flushCache();
return NextResponse.json({ success: true });
} catch (error) {
return NextResponse.json({ error: 'Failed to update settings' }, { status: 500 });

View File

@@ -1,16 +1,41 @@
'use client';
import { useEffect, useState, use } from 'react';
import { useMemo, useState, use } from 'react';
import { FaChevronLeft } from 'react-icons/fa';
import { Button } from '@headlessui/react';
import { TopBar, MainContent } from '@/components/layout';
import useJob from '@/hooks/useJob';
import { startJob, stopJob } from '@/utils/jobs';
import SampleImages from '@/components/SampleImages';
import JobOverview from '@/components/JobOverview';
import { JobConfig } from '@/types';
type PageKey = 'overview' | 'samples';
interface Page {
name: string;
value: PageKey;
}
const pages: Page[] = [
{ name: 'Overview', value: 'overview' },
{ name: 'Samples', value: 'samples' },
];
export default function JobPage({ params }: { params: { jobID: string } }) {
const usableParams = use(params as any) as { jobID: string };
const jobID = usableParams.jobID;
const { job, status, refreshJobs } = useJob(jobID, 5000);
const { job, status, refreshJob } = useJob(jobID, 5000);
const [pageKey, setPageKey] = useState<PageKey>('overview');
const numSamples = useMemo(() => {
if (job?.job_config) {
const jobConfig = JSON.parse(job.job_config) as JobConfig;
const sampleConfig = jobConfig.config.process[0].sample;
return sampleConfig.prompts.length;
}
return 10;
}, [job]);
return (
<>
@@ -28,8 +53,8 @@ export default function JobPage({ params }: { params: { jobID: string } }) {
{job?.status === 'running' && (
<Button
onClick={async () => {
await stopJob(jobID);
refreshJobs();
await stopJob(jobID);
refreshJob();
}}
className="bg-red-500 text-white px-4 py-1 rounded-sm"
>
@@ -39,8 +64,8 @@ export default function JobPage({ params }: { params: { jobID: string } }) {
{(job?.status === 'stopped' || job?.status === 'error') && (
<Button
onClick={async () => {
await startJob(jobID);
refreshJobs();
await startJob(jobID);
refreshJob();
}}
className="bg-green-800 text-white px-4 py-1 rounded-sm"
>
@@ -48,25 +73,27 @@ export default function JobPage({ params }: { params: { jobID: string } }) {
</Button>
)}
</TopBar>
<MainContent>
<MainContent className="pt-24">
{status === 'loading' && job == null && <p>Loading...</p>}
{status === 'error' && job == null && <p>Error fetching job</p>}
{job && (
<>
<div className="grid grid-cols-1 gap-4 md:grid-cols-2">
<div className="">
<h2 className="text-lg font-semibold">Job Details</h2>
<p className="text-gray-400">ID: {job.id}</p>
<p className="text-gray-400">Name: {job.name}</p>
<p className="text-gray-400">GPUs: {job.gpu_ids}</p>
<p className="text-gray-400">Status: {job.status}</p>
<p className="text-gray-400">Info: {job.info}</p>
<p className="text-gray-400">Step: {job.step}</p>
</div>
</div>
{pageKey === 'overview' && <JobOverview job={job} />}
{pageKey === 'samples' && <SampleImages job={job} />}
</>
)}
</MainContent>
<div className="bg-gray-800 absolute top-12 left-0 w-full h-8 flex items-center px-2 text-sm">
{pages.map(page => (
<Button
key={page.value}
onClick={() => setPageKey(page.value)}
className={`px-4 py-1 h-8 ${page.value === pageKey ? 'bg-gray-300 dark:bg-gray-700' : ''}`}
>
{page.name}
</Button>
))}
</div>
</>
);
}

View File

@@ -19,7 +19,7 @@ export default function RootLayout({ children }: { children: React.ReactNode })
<ThemeProvider>
<div className="flex h-screen bg-gray-950">
<Sidebar />
<main className="flex-1 p-8 overflow-auto bg-gray-950 text-gray-100 relative">{children}</main>
<main className="flex-1 overflow-auto bg-gray-950 text-gray-100 relative">{children}</main>
</div>
</ThemeProvider>
<ConfirmModal />

View File

@@ -0,0 +1,23 @@
import { Job } from '@prisma/client';
interface JobOverviewProps {
job: Job;
}
export default function JobOverview({ job }: JobOverviewProps) {
return (
<>
<div className="grid grid-cols-1 gap-4 md:grid-cols-2">
<div className="">
<h2 className="text-lg font-semibold">Job Details</h2>
<p className="text-gray-400">ID: {job.id}</p>
<p className="text-gray-400">Name: {job.name}</p>
<p className="text-gray-400">GPUs: {job.gpu_ids}</p>
<p className="text-gray-400">Status: {job.status}</p>
<p className="text-gray-400">Info: {job.info}</p>
<p className="text-gray-400">Step: {job.step}</p>
</div>
</div>
</>
);
}

View File

@@ -0,0 +1,67 @@
import React, { useRef, useEffect, useState, ReactNode } from 'react';
interface SampleImageCardProps {
imageUrl: string;
alt: string;
children?: ReactNode;
className?: string;
onDelete?: () => void;
}
const SampleImageCard: React.FC<SampleImageCardProps> = ({ imageUrl, alt, children, className = '' }) => {
const cardRef = useRef<HTMLDivElement>(null);
const [isVisible, setIsVisible] = useState<boolean>(false);
const [loaded, setLoaded] = useState<boolean>(false);
useEffect(() => {
// Create intersection observer to check visibility
const observer = new IntersectionObserver(
entries => {
if (entries[0].isIntersecting) {
setIsVisible(true);
observer.disconnect();
}
},
{ threshold: 0.1 },
);
if (cardRef.current) {
observer.observe(cardRef.current);
}
return () => {
observer.disconnect();
};
}, []);
const handleLoad = (): void => {
setLoaded(true);
};
return (
<div className={`flex flex-col ${className}`}>
{/* Square image container */}
<div
ref={cardRef}
className="relative w-full"
style={{ paddingBottom: '100%' }} // Make it square
>
<div className="absolute inset-0 rounded-t-lg shadow-md">
{isVisible && (
<img
src={`/api/img/${encodeURIComponent(imageUrl)}`}
alt={alt}
onLoad={handleLoad}
className={`w-full h-full object-contain transition-opacity duration-300 ${
loaded ? 'opacity-100' : 'opacity-0'
}`}
/>
)}
{children && <div className="absolute inset-0 flex items-center justify-center">{children}</div>}
</div>
</div>
</div>
);
};
export default SampleImageCard;

View File

@@ -0,0 +1,89 @@
import { useMemo } from 'react';
import useSampleImages from '@/hooks/useSampleImages';
import SampleImageCard from './SampleImageCard';
import { Job } from '@prisma/client';
import { JobConfig } from '@/types';
interface SampleImagesProps {
job: Job;
}
export default function SampleImages({ job }: SampleImagesProps) {
const { sampleImages, status, refreshSampleImages } = useSampleImages(job.id, 5000);
const numSamples = useMemo(() => {
if (job?.job_config) {
const jobConfig = JSON.parse(job.job_config) as JobConfig;
const sampleConfig = jobConfig.config.process[0].sample;
return sampleConfig.prompts.length;
}
return 10;
}, [job]);
// Use direct Tailwind class without string interpolation
// This way Tailwind can properly generate the class
// I hate this, but it's the only way to make it work
const gridColsClass = useMemo(() => {
const cols = Math.min(numSamples, 20);
switch (cols) {
case 1:
return 'grid-cols-1';
case 2:
return 'grid-cols-2';
case 3:
return 'grid-cols-3';
case 4:
return 'grid-cols-4';
case 5:
return 'grid-cols-5';
case 6:
return 'grid-cols-6';
case 7:
return 'grid-cols-7';
case 8:
return 'grid-cols-8';
case 9:
return 'grid-cols-9';
case 10:
return 'grid-cols-10';
case 11:
return 'grid-cols-11';
case 12:
return 'grid-cols-12';
case 13:
return 'grid-cols-13';
case 14:
return 'grid-cols-14';
case 15:
return 'grid-cols-15';
case 16:
return 'grid-cols-16';
case 17:
return 'grid-cols-17';
case 18:
return 'grid-cols-18';
case 19:
return 'grid-cols-19';
case 20:
return 'grid-cols-20';
default:
return 'grid-cols-1';
}
}, [numSamples]);
return (
<div>
<div className='pb-4'>
{status === 'loading' && sampleImages.length === 0 && <p>Loading...</p>}
{status === 'error' && <p>Error fetching sample images</p>}
{sampleImages && (
<div className={`grid ${gridColsClass} gap-1`}>
{sampleImages.map((sample: string) => (
<SampleImageCard key={sample} imageUrl={sample} alt="Sample Image" />
))}
</div>
)}
</div>
</div>
);
}

View File

@@ -20,7 +20,7 @@ export const TopBar: React.FC<Props> = ({ children, className }) => {
export const MainContent: React.FC<Props> = ({ children, className }) => {
return (
<div className={classNames('pt-16 px-4 absolute top-0 left-0 w-full h-full overflow-auto', className)}>
<div className={classNames('pt-14 px-4 absolute top-0 left-0 w-full h-full overflow-auto', className)}>
{children ? children : null}
</div>
);

View File

@@ -7,7 +7,7 @@ export default function useJob(jobID: string, reloadInterval: null | number = nu
const [job, setJob] = useState<Job | null>(null);
const [status, setStatus] = useState<'idle' | 'loading' | 'success' | 'error'>('idle');
const refreshJobs = () => {
const refreshJob = () => {
setStatus('loading');
fetch(`/api/jobs?id=${jobID}`)
.then(res => res.json())
@@ -23,11 +23,11 @@ export default function useJob(jobID: string, reloadInterval: null | number = nu
};
useEffect(() => {
refreshJobs();
refreshJob();
if (reloadInterval) {
const interval = setInterval(() => {
refreshJobs();
refreshJob();
}, reloadInterval);
return () => {
@@ -36,5 +36,5 @@ export default function useJob(jobID: string, reloadInterval: null | number = nu
}
}, [jobID]);
return { job, setJob, status, refreshJobs };
return { job, setJob, status, refreshJob };
}

View File

@@ -0,0 +1,41 @@
'use client';
import { useEffect, useState } from 'react';
import { Job } from '@prisma/client';
export default function useSampleImages(jobID: string, reloadInterval: null | number = null) {
const [sampleImages, setSampleImages] = useState<string[]>([]);
const [status, setStatus] = useState<'idle' | 'loading' | 'success' | 'error'>('idle');
const refreshSampleImages = () => {
setStatus('loading');
fetch(`/api/jobs/${jobID}/samples`)
.then(res => res.json())
.then(data => {
if (data.samples) {
setSampleImages(data.samples);
}
setStatus('success');
})
.catch(error => {
console.error('Error fetching datasets:', error);
setStatus('error');
});
};
useEffect(() => {
refreshSampleImages();
if (reloadInterval) {
const interval = setInterval(() => {
refreshSampleImages();
}, reloadInterval);
return () => {
clearInterval(interval);
};
}
}, [jobID]);
return { sampleImages, setSampleImages, status, refreshSampleImages };
}

50
ui/src/server/settings.ts Normal file
View File

@@ -0,0 +1,50 @@
import { PrismaClient } from '@prisma/client';
import { defaultDatasetsFolder } from '@/paths';
import { defaultTrainFolder } from '@/paths';
import NodeCache from 'node-cache';
const myCache = new NodeCache();
const prisma = new PrismaClient();
export const flushCache = () => {
myCache.flushAll();
};
export const getDatasetsRoot = async () => {
const key = 'DATASETS_FOLDER';
let datasetsPath = myCache.get(key) as string;
if (datasetsPath) {
return datasetsPath;
}
let row = await prisma.settings.findFirst({
where: {
key: 'DATASETS_FOLDER',
},
});
datasetsPath = defaultDatasetsFolder;
if (row?.value && row.value !== '') {
datasetsPath = row.value;
}
myCache.set(key, datasetsPath);
return datasetsPath as string;
};
export const getTrainingFolder = async () => {
const key = 'TRAINING_FOLDER';
let trainingRoot = myCache.get(key) as string;
if (trainingRoot) {
return trainingRoot;
}
let row = await prisma.settings.findFirst({
where: {
key: key,
},
});
trainingRoot = defaultTrainFolder;
if (row?.value && row.value !== '') {
trainingRoot = row.value;
}
myCache.set(key, trainingRoot);
return trainingRoot as string;
};