Added checkpoint downloader

This commit is contained in:
Jaret Burkett
2025-02-22 16:48:15 -07:00
parent 6e19e7449e
commit a280f78c69
7 changed files with 303 additions and 2 deletions

View File

@@ -0,0 +1,105 @@
import { NextRequest, NextResponse } from 'next/server';
import fs from 'fs';
import path from 'path';
import { getDatasetsRoot, getTrainingFolder } from '@/server/settings';
export async function GET(request: NextRequest, { params }: { params: { filePath: string } }) {
const { filePath } = await params;
try {
// Decode the path
const decodedFilePath = decodeURIComponent(filePath);
// Get allowed directories
const datasetRoot = await getDatasetsRoot();
const trainingRoot = await getTrainingFolder();
const allowedDirs = [datasetRoot, trainingRoot];
// Security check: Ensure path is in allowed directory
const isAllowed = allowedDirs.some(allowedDir => decodedFilePath.startsWith(allowedDir)) && !decodedFilePath.includes('..');
if (!isAllowed) {
console.warn(`Access denied: ${decodedFilePath} not in ${allowedDirs.join(', ')}`);
return new NextResponse('Access denied', { status: 403 });
}
// Check if file exists
if (!fs.existsSync(decodedFilePath)) {
console.warn(`File not found: ${decodedFilePath}`);
return new NextResponse('File not found', { status: 404 });
}
// Get file info
const stat = fs.statSync(decodedFilePath);
if (!stat.isFile()) {
return new NextResponse('Not a file', { status: 400 });
}
// Get filename for Content-Disposition
const filename = path.basename(decodedFilePath);
// Determine content type
const ext = path.extname(decodedFilePath).toLowerCase();
const contentTypeMap: { [key: string]: string } = {
'.jpg': 'image/jpeg',
'.jpeg': 'image/jpeg',
'.png': 'image/png',
'.gif': 'image/gif',
'.webp': 'image/webp',
'.svg': 'image/svg+xml',
'.bmp': 'image/bmp',
'.safetensors': 'application/octet-stream',
};
const contentType = contentTypeMap[ext] || 'application/octet-stream';
// Get range header for partial content support
const range = request.headers.get('range');
// Common headers for better download handling
const commonHeaders = {
'Content-Type': contentType,
'Accept-Ranges': 'bytes',
'Cache-Control': 'public, max-age=86400',
'Content-Disposition': `attachment; filename="${encodeURIComponent(filename)}"`,
'X-Content-Type-Options': 'nosniff'
};
if (range) {
// Parse range header
const parts = range.replace(/bytes=/, '').split('-');
const start = parseInt(parts[0], 10);
const end = parts[1] ? parseInt(parts[1], 10) : Math.min(start + 10 * 1024 * 1024, stat.size - 1); // 10MB chunks
const chunkSize = (end - start) + 1;
const fileStream = fs.createReadStream(decodedFilePath, {
start,
end,
highWaterMark: 64 * 1024 // 64KB buffer
});
return new NextResponse(fileStream as any, {
status: 206,
headers: {
...commonHeaders,
'Content-Range': `bytes ${start}-${end}/${stat.size}`,
'Content-Length': String(chunkSize)
},
});
} else {
// For full file download, read directly without streaming wrapper
const fileStream = fs.createReadStream(decodedFilePath, {
highWaterMark: 64 * 1024 // 64KB buffer
});
return new NextResponse(fileStream as any, {
headers: {
...commonHeaders,
'Content-Length': String(stat.size)
},
});
}
} catch (error) {
console.error('Error serving file:', error);
return new NextResponse('Internal Server Error', { status: 500 });
}
}

View File

@@ -0,0 +1,48 @@
import { NextRequest, NextResponse } from 'next/server';
import { PrismaClient } from '@prisma/client';
import path from 'path';
import fs from 'fs';
import { getTrainingFolder } from '@/server/settings';
const prisma = new PrismaClient();
export async function GET(request: NextRequest, { params }: { params: { jobID: string } }) {
const { jobID } = await params;
const job = await prisma.job.findUnique({
where: { id: jobID },
});
if (!job) {
return NextResponse.json({ error: 'Job not found' }, { status: 404 });
}
const trainingFolder = await getTrainingFolder();
const jobFolder = path.join(trainingFolder, job.name);
if (!fs.existsSync(jobFolder)) {
return NextResponse.json({ files: [] });
}
// find all img (png, jpg, jpeg) files in the samples folder
let files = fs
.readdirSync(jobFolder)
.filter(file => {
return file.endsWith('.safetensors');
})
.map(file => {
return path.join(jobFolder, file);
})
.sort();
// get the file size for each file
const fileObjects = files.map(file => {
const stats = fs.statSync(file);
return {
path: file,
size: stats.size,
};
});
return NextResponse.json({ files: fileObjects });
}

View File

@@ -0,0 +1,90 @@
import React from 'react';
import useFilesList from '@/hooks/useFilesList';
import Link from 'next/link';
import { Loader2, AlertCircle, Download, Box, Brain } from 'lucide-react';
export default function FilesWidget({ jobID }: { jobID: string }) {
const { files, status, refreshFiles } = useFilesList(jobID, 5000);
const cleanSize = (size: number) => {
if (size < 1024) {
return `${size} B`;
} else if (size < 1024 * 1024) {
return `${(size / 1024).toFixed(1)} KB`;
} else if (size < 1024 * 1024 * 1024) {
return `${(size / (1024 * 1024)).toFixed(1)} MB`;
} else {
return `${(size / (1024 * 1024 * 1024)).toFixed(1)} GB`;
}
};
return (
<div className="col-span-2 bg-gray-900 rounded-xl shadow-lg overflow-hidden hover:shadow-2xl transition-all duration-300 border border-gray-800">
<div className="bg-gray-800 px-4 py-3 flex items-center justify-between">
<div className="flex items-center space-x-2">
<Brain className="w-5 h-5 text-purple-400" />
<h2 className="font-semibold text-gray-100">Model Checkpoints</h2>
<span className="px-2 py-0.5 bg-gray-700 rounded-full text-xs text-gray-300">
{files.length}
</span>
</div>
</div>
<div className="p-2">
{status === 'loading' && (
<div className="flex items-center justify-center py-4">
<Loader2 className="w-5 h-5 text-gray-400 animate-spin" />
</div>
)}
{status === 'error' && (
<div className="flex items-center justify-center py-4 text-rose-400 space-x-2">
<AlertCircle className="w-4 h-4" />
<span className="text-sm">Error loading checkpoints</span>
</div>
)}
{['success', 'refreshing'].includes(status) && (
<div className="space-y-1">
{files.map((file, index) => {
const fileName = file.path.split('/').pop() || '';
const nameWithoutExt = fileName.replace('.safetensors', '');
return (
<a
key={index}
target='_blank'
href={`/api/files/${encodeURIComponent(file.path)}`}
className="group flex items-center justify-between px-2 py-1.5 rounded-lg hover:bg-gray-800 transition-all duration-200"
>
<div className="flex items-center space-x-2 min-w-0">
<Box className="w-4 h-4 text-purple-400 flex-shrink-0" />
<div className="flex flex-col min-w-0">
<div className="flex text-sm text-gray-200">
<span className="overflow-hidden text-ellipsis direction-rtl whitespace-nowrap">
{nameWithoutExt}
</span>
</div>
<span className="text-xs text-gray-500">.safetensors</span>
</div>
</div>
<div className="flex items-center space-x-3 flex-shrink-0">
<span className="text-xs text-gray-400">{cleanSize(file.size)}</span>
<div className="bg-purple-500 bg-opacity-0 group-hover:bg-opacity-10 rounded-full p-1 transition-all">
<Download className="w-3 h-3 text-purple-400" />
</div>
</div>
</a>
);
})}
</div>
)}
{['success', 'refreshing'].includes(status) && files.length === 0 && (
<div className="text-center py-4 text-gray-400 text-sm">
No checkpoints available
</div>
)}
</div>
</div>
);
}

View File

@@ -1,6 +1,7 @@
import { Job } from '@prisma/client';
import useGPUInfo from '@/hooks/useGPUInfo';
import GPUWidget from '@/components/GPUWidget';
import FilesWidget from '@/components/FilesWidget';
import { getJobConfig, getTotalSteps } from '@/utils/jobs';
import { Cpu, HardDrive, Info } from 'lucide-react';
import { useMemo } from 'react';
@@ -92,7 +93,12 @@ export default function JobOverview({ job }: JobOverviewProps) {
</div>
{/* GPU Widget Panel */}
<div className="col-span-1">{isGPUInfoLoaded && gpuList.length > 0 && <GPUWidget gpu={gpuList[0]} />}</div>
<div className="col-span-1">
<div>{isGPUInfoLoaded && gpuList.length > 0 && <GPUWidget gpu={gpuList[0]} />}</div>
<div className='mt-4'>
<FilesWidget jobID={job.id} />
</div>
</div>
</div>
);
}

View File

@@ -118,6 +118,7 @@ export default function SampleImageModal() {
const maxIdx = stepMinIdx + imageModal.numSamples - 1;
const nextIdx = currentIdx + 1;
if (nextIdx > maxIdx) return;
if (nextIdx >= imageModal.sampleImages.length) return;
openSampleImage({
imgPath: imageModal.sampleImages[nextIdx],
numSamples: imageModal.numSamples,

View File

@@ -0,0 +1,52 @@
'use client';
import { useEffect, useState, useRef } from 'react';
interface FileObject {
path: string;
size: number;
}
export default function useFilesList(jobID: string, reloadInterval: null | number = null) {
const [files, setFiles] = useState<FileObject[]>([]);
const didInitialLoadRef = useRef(false);
const [status, setStatus] = useState<'idle' | 'loading' | 'success' | 'error' | 'refreshing'>('idle');
const refreshFiles = () => {
let loadStatus: 'loading' | 'refreshing' = 'loading';
if (didInitialLoadRef.current) {
loadStatus = 'refreshing';
}
setStatus(loadStatus);
fetch(`/api/jobs/${jobID}/files`)
.then(res => res.json())
.then(data => {
console.log('Fetched files:', data);
if (data.files) {
setFiles(data.files);
}
setStatus('success');
didInitialLoadRef.current = true;
})
.catch(error => {
console.error('Error fetching datasets:', error);
setStatus('error');
});
};
useEffect(() => {
refreshFiles();
if (reloadInterval) {
const interval = setInterval(() => {
refreshFiles();
}, reloadInterval);
return () => {
clearInterval(interval);
};
}
}, [jobID]);
return { files, setFiles, status, refreshFiles };
}

View File

@@ -1,7 +1,6 @@
'use client';
import { useEffect, useState } from 'react';
import { Job } from '@prisma/client';
export default function useSampleImages(jobID: string, reloadInterval: null | number = null) {
const [sampleImages, setSampleImages] = useState<string[]>([]);