diff --git a/ui/src/app/api/files/[...filePath]/route.ts b/ui/src/app/api/files/[...filePath]/route.ts new file mode 100644 index 00000000..9c9b9398 --- /dev/null +++ b/ui/src/app/api/files/[...filePath]/route.ts @@ -0,0 +1,105 @@ +import { NextRequest, NextResponse } from 'next/server'; +import fs from 'fs'; +import path from 'path'; +import { getDatasetsRoot, getTrainingFolder } from '@/server/settings'; + +export async function GET(request: NextRequest, { params }: { params: { filePath: string } }) { + const { filePath } = await params; + try { + // Decode the path + const decodedFilePath = decodeURIComponent(filePath); + + // Get allowed directories + const datasetRoot = await getDatasetsRoot(); + const trainingRoot = await getTrainingFolder(); + const allowedDirs = [datasetRoot, trainingRoot]; + + // Security check: Ensure path is in allowed directory + const isAllowed = allowedDirs.some(allowedDir => decodedFilePath.startsWith(allowedDir)) && !decodedFilePath.includes('..'); + + if (!isAllowed) { + console.warn(`Access denied: ${decodedFilePath} not in ${allowedDirs.join(', ')}`); + return new NextResponse('Access denied', { status: 403 }); + } + + // Check if file exists + if (!fs.existsSync(decodedFilePath)) { + console.warn(`File not found: ${decodedFilePath}`); + return new NextResponse('File not found', { status: 404 }); + } + + // Get file info + const stat = fs.statSync(decodedFilePath); + if (!stat.isFile()) { + return new NextResponse('Not a file', { status: 400 }); + } + + // Get filename for Content-Disposition + const filename = path.basename(decodedFilePath); + + // Determine content type + const ext = path.extname(decodedFilePath).toLowerCase(); + const contentTypeMap: { [key: string]: string } = { + '.jpg': 'image/jpeg', + '.jpeg': 'image/jpeg', + '.png': 'image/png', + '.gif': 'image/gif', + '.webp': 'image/webp', + '.svg': 'image/svg+xml', + '.bmp': 'image/bmp', + '.safetensors': 'application/octet-stream', + }; + + const contentType = contentTypeMap[ext] || 'application/octet-stream'; + + // Get range header for partial content support + const range = request.headers.get('range'); + + // Common headers for better download handling + const commonHeaders = { + 'Content-Type': contentType, + 'Accept-Ranges': 'bytes', + 'Cache-Control': 'public, max-age=86400', + 'Content-Disposition': `attachment; filename="${encodeURIComponent(filename)}"`, + 'X-Content-Type-Options': 'nosniff' + }; + + if (range) { + // Parse range header + const parts = range.replace(/bytes=/, '').split('-'); + const start = parseInt(parts[0], 10); + const end = parts[1] ? parseInt(parts[1], 10) : Math.min(start + 10 * 1024 * 1024, stat.size - 1); // 10MB chunks + const chunkSize = (end - start) + 1; + + const fileStream = fs.createReadStream(decodedFilePath, { + start, + end, + highWaterMark: 64 * 1024 // 64KB buffer + }); + + return new NextResponse(fileStream as any, { + status: 206, + headers: { + ...commonHeaders, + 'Content-Range': `bytes ${start}-${end}/${stat.size}`, + 'Content-Length': String(chunkSize) + }, + }); + } else { + // For full file download, read directly without streaming wrapper + const fileStream = fs.createReadStream(decodedFilePath, { + highWaterMark: 64 * 1024 // 64KB buffer + }); + + return new NextResponse(fileStream as any, { + headers: { + ...commonHeaders, + 'Content-Length': String(stat.size) + }, + }); + } + } catch (error) { + console.error('Error serving file:', error); + return new NextResponse('Internal Server Error', { status: 500 }); + } +} \ No newline at end of file diff --git a/ui/src/app/api/jobs/[jobID]/files/route.ts b/ui/src/app/api/jobs/[jobID]/files/route.ts new file mode 100644 index 00000000..f75fe6ce --- /dev/null +++ b/ui/src/app/api/jobs/[jobID]/files/route.ts @@ -0,0 +1,48 @@ +import { NextRequest, NextResponse } from 'next/server'; +import { PrismaClient } from '@prisma/client'; +import path from 'path'; +import fs from 'fs'; +import { getTrainingFolder } from '@/server/settings'; + +const prisma = new PrismaClient(); + +export async function GET(request: NextRequest, { params }: { params: { jobID: string } }) { + const { jobID } = await params; + + const job = await prisma.job.findUnique({ + where: { id: jobID }, + }); + + if (!job) { + return NextResponse.json({ error: 'Job not found' }, { status: 404 }); + } + + const trainingFolder = await getTrainingFolder(); + const jobFolder = path.join(trainingFolder, job.name); + + if (!fs.existsSync(jobFolder)) { + return NextResponse.json({ files: [] }); + } + + // find all img (png, jpg, jpeg) files in the samples folder + let files = fs + .readdirSync(jobFolder) + .filter(file => { + return file.endsWith('.safetensors'); + }) + .map(file => { + return path.join(jobFolder, file); + }) + .sort(); + + // get the file size for each file + const fileObjects = files.map(file => { + const stats = fs.statSync(file); + return { + path: file, + size: stats.size, + }; + }); + + return NextResponse.json({ files: fileObjects }); +} diff --git a/ui/src/components/FilesWidget.tsx b/ui/src/components/FilesWidget.tsx new file mode 100644 index 00000000..7bfe755e --- /dev/null +++ b/ui/src/components/FilesWidget.tsx @@ -0,0 +1,90 @@ +import React from 'react'; +import useFilesList from '@/hooks/useFilesList'; +import Link from 'next/link'; +import { Loader2, AlertCircle, Download, Box, Brain } from 'lucide-react'; + +export default function FilesWidget({ jobID }: { jobID: string }) { + const { files, status, refreshFiles } = useFilesList(jobID, 5000); + + const cleanSize = (size: number) => { + if (size < 1024) { + return `${size} B`; + } else if (size < 1024 * 1024) { + return `${(size / 1024).toFixed(1)} KB`; + } else if (size < 1024 * 1024 * 1024) { + return `${(size / (1024 * 1024)).toFixed(1)} MB`; + } else { + return `${(size / (1024 * 1024 * 1024)).toFixed(1)} GB`; + } + }; + + return ( +
+
+
+ +

Model Checkpoints

+ + {files.length} + +
+
+ +
+ {status === 'loading' && ( +
+ +
+ )} + + {status === 'error' && ( +
+ + Error loading checkpoints +
+ )} + + {['success', 'refreshing'].includes(status) && ( +
+ {files.map((file, index) => { + const fileName = file.path.split('/').pop() || ''; + const nameWithoutExt = fileName.replace('.safetensors', ''); + return ( + +
+ +
+
+ + {nameWithoutExt} + +
+ .safetensors +
+
+
+ {cleanSize(file.size)} +
+ +
+
+
+ ); + })} +
+ )} + + {['success', 'refreshing'].includes(status) && files.length === 0 && ( +
+ No checkpoints available +
+ )} +
+
+ ); +} \ No newline at end of file diff --git a/ui/src/components/JobOverview.tsx b/ui/src/components/JobOverview.tsx index 92502491..278b06a6 100644 --- a/ui/src/components/JobOverview.tsx +++ b/ui/src/components/JobOverview.tsx @@ -1,6 +1,7 @@ import { Job } from '@prisma/client'; import useGPUInfo from '@/hooks/useGPUInfo'; import GPUWidget from '@/components/GPUWidget'; +import FilesWidget from '@/components/FilesWidget'; import { getJobConfig, getTotalSteps } from '@/utils/jobs'; import { Cpu, HardDrive, Info } from 'lucide-react'; import { useMemo } from 'react'; @@ -92,7 +93,12 @@ export default function JobOverview({ job }: JobOverviewProps) { {/* GPU Widget Panel */} -
{isGPUInfoLoaded && gpuList.length > 0 && }
+
+
{isGPUInfoLoaded && gpuList.length > 0 && }
+
+ +
+
); } diff --git a/ui/src/components/SampleImageModal.tsx b/ui/src/components/SampleImageModal.tsx index e492a3f1..a98734e0 100644 --- a/ui/src/components/SampleImageModal.tsx +++ b/ui/src/components/SampleImageModal.tsx @@ -118,6 +118,7 @@ export default function SampleImageModal() { const maxIdx = stepMinIdx + imageModal.numSamples - 1; const nextIdx = currentIdx + 1; if (nextIdx > maxIdx) return; + if (nextIdx >= imageModal.sampleImages.length) return; openSampleImage({ imgPath: imageModal.sampleImages[nextIdx], numSamples: imageModal.numSamples, diff --git a/ui/src/hooks/useFilesList.tsx b/ui/src/hooks/useFilesList.tsx new file mode 100644 index 00000000..7268a3cb --- /dev/null +++ b/ui/src/hooks/useFilesList.tsx @@ -0,0 +1,52 @@ +'use client'; + +import { useEffect, useState, useRef } from 'react'; + +interface FileObject { + path: string; + size: number; +} + +export default function useFilesList(jobID: string, reloadInterval: null | number = null) { + const [files, setFiles] = useState([]); + const didInitialLoadRef = useRef(false); + const [status, setStatus] = useState<'idle' | 'loading' | 'success' | 'error' | 'refreshing'>('idle'); + + const refreshFiles = () => { + let loadStatus: 'loading' | 'refreshing' = 'loading'; + if (didInitialLoadRef.current) { + loadStatus = 'refreshing'; + } + setStatus(loadStatus); + fetch(`/api/jobs/${jobID}/files`) + .then(res => res.json()) + .then(data => { + console.log('Fetched files:', data); + if (data.files) { + setFiles(data.files); + } + setStatus('success'); + didInitialLoadRef.current = true; + }) + .catch(error => { + console.error('Error fetching datasets:', error); + setStatus('error'); + }); + }; + + useEffect(() => { + refreshFiles(); + + if (reloadInterval) { + const interval = setInterval(() => { + refreshFiles(); + }, reloadInterval); + + return () => { + clearInterval(interval); + }; + } + }, [jobID]); + + return { files, setFiles, status, refreshFiles }; +} diff --git a/ui/src/hooks/useSampleImages.tsx b/ui/src/hooks/useSampleImages.tsx index 5407f72a..ccf07493 100644 --- a/ui/src/hooks/useSampleImages.tsx +++ b/ui/src/hooks/useSampleImages.tsx @@ -1,7 +1,6 @@ 'use client'; import { useEffect, useState } from 'react'; -import { Job } from '@prisma/client'; export default function useSampleImages(jobID: string, reloadInterval: null | number = null) { const [sampleImages, setSampleImages] = useState([]);