mirror of
https://github.com/ostris/ai-toolkit.git
synced 2026-02-25 23:03:57 +00:00
Added checkpoint downloader
This commit is contained in:
105
ui/src/app/api/files/[...filePath]/route.ts
Normal file
105
ui/src/app/api/files/[...filePath]/route.ts
Normal file
@@ -0,0 +1,105 @@
|
||||
import { NextRequest, NextResponse } from 'next/server';
|
||||
import fs from 'fs';
|
||||
import path from 'path';
|
||||
import { getDatasetsRoot, getTrainingFolder } from '@/server/settings';
|
||||
|
||||
export async function GET(request: NextRequest, { params }: { params: { filePath: string } }) {
|
||||
const { filePath } = await params;
|
||||
try {
|
||||
// Decode the path
|
||||
const decodedFilePath = decodeURIComponent(filePath);
|
||||
|
||||
// Get allowed directories
|
||||
const datasetRoot = await getDatasetsRoot();
|
||||
const trainingRoot = await getTrainingFolder();
|
||||
const allowedDirs = [datasetRoot, trainingRoot];
|
||||
|
||||
// Security check: Ensure path is in allowed directory
|
||||
const isAllowed = allowedDirs.some(allowedDir => decodedFilePath.startsWith(allowedDir)) && !decodedFilePath.includes('..');
|
||||
|
||||
if (!isAllowed) {
|
||||
console.warn(`Access denied: ${decodedFilePath} not in ${allowedDirs.join(', ')}`);
|
||||
return new NextResponse('Access denied', { status: 403 });
|
||||
}
|
||||
|
||||
// Check if file exists
|
||||
if (!fs.existsSync(decodedFilePath)) {
|
||||
console.warn(`File not found: ${decodedFilePath}`);
|
||||
return new NextResponse('File not found', { status: 404 });
|
||||
}
|
||||
|
||||
// Get file info
|
||||
const stat = fs.statSync(decodedFilePath);
|
||||
if (!stat.isFile()) {
|
||||
return new NextResponse('Not a file', { status: 400 });
|
||||
}
|
||||
|
||||
// Get filename for Content-Disposition
|
||||
const filename = path.basename(decodedFilePath);
|
||||
|
||||
// Determine content type
|
||||
const ext = path.extname(decodedFilePath).toLowerCase();
|
||||
const contentTypeMap: { [key: string]: string } = {
|
||||
'.jpg': 'image/jpeg',
|
||||
'.jpeg': 'image/jpeg',
|
||||
'.png': 'image/png',
|
||||
'.gif': 'image/gif',
|
||||
'.webp': 'image/webp',
|
||||
'.svg': 'image/svg+xml',
|
||||
'.bmp': 'image/bmp',
|
||||
'.safetensors': 'application/octet-stream',
|
||||
};
|
||||
|
||||
const contentType = contentTypeMap[ext] || 'application/octet-stream';
|
||||
|
||||
// Get range header for partial content support
|
||||
const range = request.headers.get('range');
|
||||
|
||||
// Common headers for better download handling
|
||||
const commonHeaders = {
|
||||
'Content-Type': contentType,
|
||||
'Accept-Ranges': 'bytes',
|
||||
'Cache-Control': 'public, max-age=86400',
|
||||
'Content-Disposition': `attachment; filename="${encodeURIComponent(filename)}"`,
|
||||
'X-Content-Type-Options': 'nosniff'
|
||||
};
|
||||
|
||||
if (range) {
|
||||
// Parse range header
|
||||
const parts = range.replace(/bytes=/, '').split('-');
|
||||
const start = parseInt(parts[0], 10);
|
||||
const end = parts[1] ? parseInt(parts[1], 10) : Math.min(start + 10 * 1024 * 1024, stat.size - 1); // 10MB chunks
|
||||
const chunkSize = (end - start) + 1;
|
||||
|
||||
const fileStream = fs.createReadStream(decodedFilePath, {
|
||||
start,
|
||||
end,
|
||||
highWaterMark: 64 * 1024 // 64KB buffer
|
||||
});
|
||||
|
||||
return new NextResponse(fileStream as any, {
|
||||
status: 206,
|
||||
headers: {
|
||||
...commonHeaders,
|
||||
'Content-Range': `bytes ${start}-${end}/${stat.size}`,
|
||||
'Content-Length': String(chunkSize)
|
||||
},
|
||||
});
|
||||
} else {
|
||||
// For full file download, read directly without streaming wrapper
|
||||
const fileStream = fs.createReadStream(decodedFilePath, {
|
||||
highWaterMark: 64 * 1024 // 64KB buffer
|
||||
});
|
||||
|
||||
return new NextResponse(fileStream as any, {
|
||||
headers: {
|
||||
...commonHeaders,
|
||||
'Content-Length': String(stat.size)
|
||||
},
|
||||
});
|
||||
}
|
||||
} catch (error) {
|
||||
console.error('Error serving file:', error);
|
||||
return new NextResponse('Internal Server Error', { status: 500 });
|
||||
}
|
||||
}
|
||||
48
ui/src/app/api/jobs/[jobID]/files/route.ts
Normal file
48
ui/src/app/api/jobs/[jobID]/files/route.ts
Normal file
@@ -0,0 +1,48 @@
|
||||
import { NextRequest, NextResponse } from 'next/server';
|
||||
import { PrismaClient } from '@prisma/client';
|
||||
import path from 'path';
|
||||
import fs from 'fs';
|
||||
import { getTrainingFolder } from '@/server/settings';
|
||||
|
||||
const prisma = new PrismaClient();
|
||||
|
||||
export async function GET(request: NextRequest, { params }: { params: { jobID: string } }) {
|
||||
const { jobID } = await params;
|
||||
|
||||
const job = await prisma.job.findUnique({
|
||||
where: { id: jobID },
|
||||
});
|
||||
|
||||
if (!job) {
|
||||
return NextResponse.json({ error: 'Job not found' }, { status: 404 });
|
||||
}
|
||||
|
||||
const trainingFolder = await getTrainingFolder();
|
||||
const jobFolder = path.join(trainingFolder, job.name);
|
||||
|
||||
if (!fs.existsSync(jobFolder)) {
|
||||
return NextResponse.json({ files: [] });
|
||||
}
|
||||
|
||||
// find all img (png, jpg, jpeg) files in the samples folder
|
||||
let files = fs
|
||||
.readdirSync(jobFolder)
|
||||
.filter(file => {
|
||||
return file.endsWith('.safetensors');
|
||||
})
|
||||
.map(file => {
|
||||
return path.join(jobFolder, file);
|
||||
})
|
||||
.sort();
|
||||
|
||||
// get the file size for each file
|
||||
const fileObjects = files.map(file => {
|
||||
const stats = fs.statSync(file);
|
||||
return {
|
||||
path: file,
|
||||
size: stats.size,
|
||||
};
|
||||
});
|
||||
|
||||
return NextResponse.json({ files: fileObjects });
|
||||
}
|
||||
90
ui/src/components/FilesWidget.tsx
Normal file
90
ui/src/components/FilesWidget.tsx
Normal file
@@ -0,0 +1,90 @@
|
||||
import React from 'react';
|
||||
import useFilesList from '@/hooks/useFilesList';
|
||||
import Link from 'next/link';
|
||||
import { Loader2, AlertCircle, Download, Box, Brain } from 'lucide-react';
|
||||
|
||||
export default function FilesWidget({ jobID }: { jobID: string }) {
|
||||
const { files, status, refreshFiles } = useFilesList(jobID, 5000);
|
||||
|
||||
const cleanSize = (size: number) => {
|
||||
if (size < 1024) {
|
||||
return `${size} B`;
|
||||
} else if (size < 1024 * 1024) {
|
||||
return `${(size / 1024).toFixed(1)} KB`;
|
||||
} else if (size < 1024 * 1024 * 1024) {
|
||||
return `${(size / (1024 * 1024)).toFixed(1)} MB`;
|
||||
} else {
|
||||
return `${(size / (1024 * 1024 * 1024)).toFixed(1)} GB`;
|
||||
}
|
||||
};
|
||||
|
||||
return (
|
||||
<div className="col-span-2 bg-gray-900 rounded-xl shadow-lg overflow-hidden hover:shadow-2xl transition-all duration-300 border border-gray-800">
|
||||
<div className="bg-gray-800 px-4 py-3 flex items-center justify-between">
|
||||
<div className="flex items-center space-x-2">
|
||||
<Brain className="w-5 h-5 text-purple-400" />
|
||||
<h2 className="font-semibold text-gray-100">Model Checkpoints</h2>
|
||||
<span className="px-2 py-0.5 bg-gray-700 rounded-full text-xs text-gray-300">
|
||||
{files.length}
|
||||
</span>
|
||||
</div>
|
||||
</div>
|
||||
|
||||
<div className="p-2">
|
||||
{status === 'loading' && (
|
||||
<div className="flex items-center justify-center py-4">
|
||||
<Loader2 className="w-5 h-5 text-gray-400 animate-spin" />
|
||||
</div>
|
||||
)}
|
||||
|
||||
{status === 'error' && (
|
||||
<div className="flex items-center justify-center py-4 text-rose-400 space-x-2">
|
||||
<AlertCircle className="w-4 h-4" />
|
||||
<span className="text-sm">Error loading checkpoints</span>
|
||||
</div>
|
||||
)}
|
||||
|
||||
{['success', 'refreshing'].includes(status) && (
|
||||
<div className="space-y-1">
|
||||
{files.map((file, index) => {
|
||||
const fileName = file.path.split('/').pop() || '';
|
||||
const nameWithoutExt = fileName.replace('.safetensors', '');
|
||||
return (
|
||||
<a
|
||||
key={index}
|
||||
target='_blank'
|
||||
href={`/api/files/${encodeURIComponent(file.path)}`}
|
||||
className="group flex items-center justify-between px-2 py-1.5 rounded-lg hover:bg-gray-800 transition-all duration-200"
|
||||
>
|
||||
<div className="flex items-center space-x-2 min-w-0">
|
||||
<Box className="w-4 h-4 text-purple-400 flex-shrink-0" />
|
||||
<div className="flex flex-col min-w-0">
|
||||
<div className="flex text-sm text-gray-200">
|
||||
<span className="overflow-hidden text-ellipsis direction-rtl whitespace-nowrap">
|
||||
{nameWithoutExt}
|
||||
</span>
|
||||
</div>
|
||||
<span className="text-xs text-gray-500">.safetensors</span>
|
||||
</div>
|
||||
</div>
|
||||
<div className="flex items-center space-x-3 flex-shrink-0">
|
||||
<span className="text-xs text-gray-400">{cleanSize(file.size)}</span>
|
||||
<div className="bg-purple-500 bg-opacity-0 group-hover:bg-opacity-10 rounded-full p-1 transition-all">
|
||||
<Download className="w-3 h-3 text-purple-400" />
|
||||
</div>
|
||||
</div>
|
||||
</a>
|
||||
);
|
||||
})}
|
||||
</div>
|
||||
)}
|
||||
|
||||
{['success', 'refreshing'].includes(status) && files.length === 0 && (
|
||||
<div className="text-center py-4 text-gray-400 text-sm">
|
||||
No checkpoints available
|
||||
</div>
|
||||
)}
|
||||
</div>
|
||||
</div>
|
||||
);
|
||||
}
|
||||
@@ -1,6 +1,7 @@
|
||||
import { Job } from '@prisma/client';
|
||||
import useGPUInfo from '@/hooks/useGPUInfo';
|
||||
import GPUWidget from '@/components/GPUWidget';
|
||||
import FilesWidget from '@/components/FilesWidget';
|
||||
import { getJobConfig, getTotalSteps } from '@/utils/jobs';
|
||||
import { Cpu, HardDrive, Info } from 'lucide-react';
|
||||
import { useMemo } from 'react';
|
||||
@@ -92,7 +93,12 @@ export default function JobOverview({ job }: JobOverviewProps) {
|
||||
</div>
|
||||
|
||||
{/* GPU Widget Panel */}
|
||||
<div className="col-span-1">{isGPUInfoLoaded && gpuList.length > 0 && <GPUWidget gpu={gpuList[0]} />}</div>
|
||||
<div className="col-span-1">
|
||||
<div>{isGPUInfoLoaded && gpuList.length > 0 && <GPUWidget gpu={gpuList[0]} />}</div>
|
||||
<div className='mt-4'>
|
||||
<FilesWidget jobID={job.id} />
|
||||
</div>
|
||||
</div>
|
||||
</div>
|
||||
);
|
||||
}
|
||||
|
||||
@@ -118,6 +118,7 @@ export default function SampleImageModal() {
|
||||
const maxIdx = stepMinIdx + imageModal.numSamples - 1;
|
||||
const nextIdx = currentIdx + 1;
|
||||
if (nextIdx > maxIdx) return;
|
||||
if (nextIdx >= imageModal.sampleImages.length) return;
|
||||
openSampleImage({
|
||||
imgPath: imageModal.sampleImages[nextIdx],
|
||||
numSamples: imageModal.numSamples,
|
||||
|
||||
52
ui/src/hooks/useFilesList.tsx
Normal file
52
ui/src/hooks/useFilesList.tsx
Normal file
@@ -0,0 +1,52 @@
|
||||
'use client';
|
||||
|
||||
import { useEffect, useState, useRef } from 'react';
|
||||
|
||||
interface FileObject {
|
||||
path: string;
|
||||
size: number;
|
||||
}
|
||||
|
||||
export default function useFilesList(jobID: string, reloadInterval: null | number = null) {
|
||||
const [files, setFiles] = useState<FileObject[]>([]);
|
||||
const didInitialLoadRef = useRef(false);
|
||||
const [status, setStatus] = useState<'idle' | 'loading' | 'success' | 'error' | 'refreshing'>('idle');
|
||||
|
||||
const refreshFiles = () => {
|
||||
let loadStatus: 'loading' | 'refreshing' = 'loading';
|
||||
if (didInitialLoadRef.current) {
|
||||
loadStatus = 'refreshing';
|
||||
}
|
||||
setStatus(loadStatus);
|
||||
fetch(`/api/jobs/${jobID}/files`)
|
||||
.then(res => res.json())
|
||||
.then(data => {
|
||||
console.log('Fetched files:', data);
|
||||
if (data.files) {
|
||||
setFiles(data.files);
|
||||
}
|
||||
setStatus('success');
|
||||
didInitialLoadRef.current = true;
|
||||
})
|
||||
.catch(error => {
|
||||
console.error('Error fetching datasets:', error);
|
||||
setStatus('error');
|
||||
});
|
||||
};
|
||||
|
||||
useEffect(() => {
|
||||
refreshFiles();
|
||||
|
||||
if (reloadInterval) {
|
||||
const interval = setInterval(() => {
|
||||
refreshFiles();
|
||||
}, reloadInterval);
|
||||
|
||||
return () => {
|
||||
clearInterval(interval);
|
||||
};
|
||||
}
|
||||
}, [jobID]);
|
||||
|
||||
return { files, setFiles, status, refreshFiles };
|
||||
}
|
||||
@@ -1,7 +1,6 @@
|
||||
'use client';
|
||||
|
||||
import { useEffect, useState } from 'react';
|
||||
import { Job } from '@prisma/client';
|
||||
|
||||
export default function useSampleImages(jobID: string, reloadInterval: null | number = null) {
|
||||
const [sampleImages, setSampleImages] = useState<string[]>([]);
|
||||
|
||||
Reference in New Issue
Block a user