mirror of
https://github.com/ostris/ai-toolkit.git
synced 2026-04-30 19:21:39 +00:00
Added a way to secure the UI. Plus various bug fixes and quality of life updates
This commit is contained in:
@@ -15,7 +15,8 @@ class UITrainer(SDTrainer):
|
|||||||
super(UITrainer, self).__init__(process_id, job, config, **kwargs)
|
super(UITrainer, self).__init__(process_id, job, config, **kwargs)
|
||||||
self.sqlite_db_path = self.config.get("sqlite_db_path", "./aitk_db.db")
|
self.sqlite_db_path = self.config.get("sqlite_db_path", "./aitk_db.db")
|
||||||
if not os.path.exists(self.sqlite_db_path):
|
if not os.path.exists(self.sqlite_db_path):
|
||||||
raise Exception(f"SQLite database not found at {self.sqlite_db_path}")
|
raise Exception(
|
||||||
|
f"SQLite database not found at {self.sqlite_db_path}")
|
||||||
print(f"Using SQLite database at {self.sqlite_db_path}")
|
print(f"Using SQLite database at {self.sqlite_db_path}")
|
||||||
self.job_id = os.environ.get("AITK_JOB_ID", None)
|
self.job_id = os.environ.get("AITK_JOB_ID", None)
|
||||||
self.job_id = self.job_id.strip() if self.job_id is not None else None
|
self.job_id = self.job_id.strip() if self.job_id is not None else None
|
||||||
@@ -147,6 +148,8 @@ class UITrainer(SDTrainer):
|
|||||||
|
|
||||||
try:
|
try:
|
||||||
await asyncio.gather(*self._async_tasks)
|
await asyncio.gather(*self._async_tasks)
|
||||||
|
except Exception as e:
|
||||||
|
print(f"Error waiting for async operations: {e}")
|
||||||
finally:
|
finally:
|
||||||
# Clear the task list after completion
|
# Clear the task list after completion
|
||||||
self._async_tasks.clear()
|
self._async_tasks.clear()
|
||||||
|
|||||||
@@ -6,6 +6,7 @@
|
|||||||
"dev": "next dev --turbopack",
|
"dev": "next dev --turbopack",
|
||||||
"build": "next build",
|
"build": "next build",
|
||||||
"start": "next start --port 8675",
|
"start": "next start --port 8675",
|
||||||
|
"build_and_start": "npm install && npm run update_db && npm run build && npm run start",
|
||||||
"lint": "next lint",
|
"lint": "next lint",
|
||||||
"update_db": "npx prisma generate && npx prisma db push",
|
"update_db": "npx prisma generate && npx prisma db push",
|
||||||
"format": "prettier --write \"**/*.{js,jsx,ts,tsx,css,scss}\""
|
"format": "prettier --write \"**/*.{js,jsx,ts,tsx,css,scss}\""
|
||||||
|
|||||||
6
ui/src/app/api/auth/route.ts
Normal file
6
ui/src/app/api/auth/route.ts
Normal file
@@ -0,0 +1,6 @@
|
|||||||
|
import { NextResponse } from 'next/server';
|
||||||
|
|
||||||
|
export async function GET() {
|
||||||
|
// if this gets hit, auth has already been verified
|
||||||
|
return NextResponse.json({ isAuthenticated: true });
|
||||||
|
}
|
||||||
@@ -19,4 +19,4 @@ export async function POST(request: Request) {
|
|||||||
} catch (error) {
|
} catch (error) {
|
||||||
return NextResponse.json({ error: 'Failed to create dataset' }, { status: 500 });
|
return NextResponse.json({ error: 'Failed to create dataset' }, { status: 500 });
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|||||||
@@ -8,31 +8,25 @@ export async function POST(request: Request) {
|
|||||||
const body = await request.json();
|
const body = await request.json();
|
||||||
const { datasetName } = body;
|
const { datasetName } = body;
|
||||||
const datasetFolder = path.join(datasetsPath, datasetName);
|
const datasetFolder = path.join(datasetsPath, datasetName);
|
||||||
|
|
||||||
try {
|
try {
|
||||||
// Check if folder exists
|
// Check if folder exists
|
||||||
if (!fs.existsSync(datasetFolder)) {
|
if (!fs.existsSync(datasetFolder)) {
|
||||||
return NextResponse.json(
|
return NextResponse.json({ error: `Folder '${datasetName}' not found` }, { status: 404 });
|
||||||
{ error: `Folder '${datasetName}' not found` },
|
|
||||||
{ status: 404 }
|
|
||||||
);
|
|
||||||
}
|
}
|
||||||
|
|
||||||
// Find all images recursively
|
// Find all images recursively
|
||||||
const imageFiles = findImagesRecursively(datasetFolder);
|
const imageFiles = findImagesRecursively(datasetFolder);
|
||||||
|
|
||||||
// Format response
|
// Format response
|
||||||
const result = imageFiles.map(imgPath => ({
|
const result = imageFiles.map(imgPath => ({
|
||||||
img_path: imgPath
|
img_path: imgPath,
|
||||||
}));
|
}));
|
||||||
|
|
||||||
return NextResponse.json({ images: result });
|
return NextResponse.json({ images: result });
|
||||||
} catch (error) {
|
} catch (error) {
|
||||||
console.error('Error finding images:', error);
|
console.error('Error finding images:', error);
|
||||||
return NextResponse.json(
|
return NextResponse.json({ error: 'Failed to process request' }, { status: 500 });
|
||||||
{ error: 'Failed to process request' },
|
|
||||||
{ status: 500 }
|
|
||||||
);
|
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
@@ -44,13 +38,13 @@ export async function POST(request: Request) {
|
|||||||
function findImagesRecursively(dir: string): string[] {
|
function findImagesRecursively(dir: string): string[] {
|
||||||
const imageExtensions = ['.png', '.jpg', '.jpeg', '.webp'];
|
const imageExtensions = ['.png', '.jpg', '.jpeg', '.webp'];
|
||||||
let results: string[] = [];
|
let results: string[] = [];
|
||||||
|
|
||||||
const items = fs.readdirSync(dir);
|
const items = fs.readdirSync(dir);
|
||||||
|
|
||||||
for (const item of items) {
|
for (const item of items) {
|
||||||
const itemPath = path.join(dir, item);
|
const itemPath = path.join(dir, item);
|
||||||
const stat = fs.statSync(itemPath);
|
const stat = fs.statSync(itemPath);
|
||||||
|
|
||||||
if (stat.isDirectory()) {
|
if (stat.isDirectory()) {
|
||||||
// If it's a directory, recursively search it
|
// If it's a directory, recursively search it
|
||||||
results = results.concat(findImagesRecursively(itemPath));
|
results = results.concat(findImagesRecursively(itemPath));
|
||||||
@@ -62,6 +56,6 @@ function findImagesRecursively(dir: string): string[] {
|
|||||||
}
|
}
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
return results;
|
return results;
|
||||||
}
|
}
|
||||||
|
|||||||
@@ -16,7 +16,8 @@ export async function GET(request: NextRequest, { params }: { params: { filePath
|
|||||||
const allowedDirs = [datasetRoot, trainingRoot];
|
const allowedDirs = [datasetRoot, trainingRoot];
|
||||||
|
|
||||||
// Security check: Ensure path is in allowed directory
|
// Security check: Ensure path is in allowed directory
|
||||||
const isAllowed = allowedDirs.some(allowedDir => decodedFilePath.startsWith(allowedDir)) && !decodedFilePath.includes('..');
|
const isAllowed =
|
||||||
|
allowedDirs.some(allowedDir => decodedFilePath.startsWith(allowedDir)) && !decodedFilePath.includes('..');
|
||||||
|
|
||||||
if (!isAllowed) {
|
if (!isAllowed) {
|
||||||
console.warn(`Access denied: ${decodedFilePath} not in ${allowedDirs.join(', ')}`);
|
console.warn(`Access denied: ${decodedFilePath} not in ${allowedDirs.join(', ')}`);
|
||||||
@@ -62,7 +63,7 @@ export async function GET(request: NextRequest, { params }: { params: { filePath
|
|||||||
'Accept-Ranges': 'bytes',
|
'Accept-Ranges': 'bytes',
|
||||||
'Cache-Control': 'public, max-age=86400',
|
'Cache-Control': 'public, max-age=86400',
|
||||||
'Content-Disposition': `attachment; filename="${encodeURIComponent(filename)}"`,
|
'Content-Disposition': `attachment; filename="${encodeURIComponent(filename)}"`,
|
||||||
'X-Content-Type-Options': 'nosniff'
|
'X-Content-Type-Options': 'nosniff',
|
||||||
};
|
};
|
||||||
|
|
||||||
if (range) {
|
if (range) {
|
||||||
@@ -70,12 +71,12 @@ export async function GET(request: NextRequest, { params }: { params: { filePath
|
|||||||
const parts = range.replace(/bytes=/, '').split('-');
|
const parts = range.replace(/bytes=/, '').split('-');
|
||||||
const start = parseInt(parts[0], 10);
|
const start = parseInt(parts[0], 10);
|
||||||
const end = parts[1] ? parseInt(parts[1], 10) : Math.min(start + 10 * 1024 * 1024, stat.size - 1); // 10MB chunks
|
const end = parts[1] ? parseInt(parts[1], 10) : Math.min(start + 10 * 1024 * 1024, stat.size - 1); // 10MB chunks
|
||||||
const chunkSize = (end - start) + 1;
|
const chunkSize = end - start + 1;
|
||||||
|
|
||||||
const fileStream = fs.createReadStream(decodedFilePath, {
|
const fileStream = fs.createReadStream(decodedFilePath, {
|
||||||
start,
|
start,
|
||||||
end,
|
end,
|
||||||
highWaterMark: 64 * 1024 // 64KB buffer
|
highWaterMark: 64 * 1024, // 64KB buffer
|
||||||
});
|
});
|
||||||
|
|
||||||
return new NextResponse(fileStream as any, {
|
return new NextResponse(fileStream as any, {
|
||||||
@@ -83,19 +84,19 @@ export async function GET(request: NextRequest, { params }: { params: { filePath
|
|||||||
headers: {
|
headers: {
|
||||||
...commonHeaders,
|
...commonHeaders,
|
||||||
'Content-Range': `bytes ${start}-${end}/${stat.size}`,
|
'Content-Range': `bytes ${start}-${end}/${stat.size}`,
|
||||||
'Content-Length': String(chunkSize)
|
'Content-Length': String(chunkSize),
|
||||||
},
|
},
|
||||||
});
|
});
|
||||||
} else {
|
} else {
|
||||||
// For full file download, read directly without streaming wrapper
|
// For full file download, read directly without streaming wrapper
|
||||||
const fileStream = fs.createReadStream(decodedFilePath, {
|
const fileStream = fs.createReadStream(decodedFilePath, {
|
||||||
highWaterMark: 64 * 1024 // 64KB buffer
|
highWaterMark: 64 * 1024, // 64KB buffer
|
||||||
});
|
});
|
||||||
|
|
||||||
return new NextResponse(fileStream as any, {
|
return new NextResponse(fileStream as any, {
|
||||||
headers: {
|
headers: {
|
||||||
...commonHeaders,
|
...commonHeaders,
|
||||||
'Content-Length': String(stat.size)
|
'Content-Length': String(stat.size),
|
||||||
},
|
},
|
||||||
});
|
});
|
||||||
}
|
}
|
||||||
@@ -103,4 +104,4 @@ export async function GET(request: NextRequest, { params }: { params: { filePath
|
|||||||
console.error('Error serving file:', error);
|
console.error('Error serving file:', error);
|
||||||
return new NextResponse('Internal Server Error', { status: 500 });
|
return new NextResponse('Internal Server Error', { status: 500 });
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|||||||
@@ -10,7 +10,7 @@ export async function GET() {
|
|||||||
// Get platform
|
// Get platform
|
||||||
const platform = os.platform();
|
const platform = os.platform();
|
||||||
const isWindows = platform === 'win32';
|
const isWindows = platform === 'win32';
|
||||||
|
|
||||||
// Check if nvidia-smi is available
|
// Check if nvidia-smi is available
|
||||||
const hasNvidiaSmi = await checkNvidiaSmi(isWindows);
|
const hasNvidiaSmi = await checkNvidiaSmi(isWindows);
|
||||||
|
|
||||||
@@ -61,8 +61,9 @@ async function checkNvidiaSmi(isWindows: boolean): Promise<boolean> {
|
|||||||
|
|
||||||
async function getGpuStats(isWindows: boolean) {
|
async function getGpuStats(isWindows: boolean) {
|
||||||
// Command is the same for both platforms, but the path might be different
|
// Command is the same for both platforms, but the path might be different
|
||||||
const command = 'nvidia-smi --query-gpu=index,name,driver_version,temperature.gpu,utilization.gpu,utilization.memory,memory.total,memory.free,memory.used,power.draw,power.limit,clocks.current.graphics,clocks.current.memory,fan.speed --format=csv,noheader,nounits';
|
const command =
|
||||||
|
'nvidia-smi --query-gpu=index,name,driver_version,temperature.gpu,utilization.gpu,utilization.memory,memory.total,memory.free,memory.used,power.draw,power.limit,clocks.current.graphics,clocks.current.memory,fan.speed --format=csv,noheader,nounits';
|
||||||
|
|
||||||
// Execute command
|
// Execute command
|
||||||
const { stdout } = await execAsync(command);
|
const { stdout } = await execAsync(command);
|
||||||
|
|
||||||
@@ -117,4 +118,4 @@ async function getGpuStats(isWindows: boolean) {
|
|||||||
});
|
});
|
||||||
|
|
||||||
return gpus;
|
return gpus;
|
||||||
}
|
}
|
||||||
|
|||||||
@@ -17,7 +17,6 @@ export async function POST(request: Request) {
|
|||||||
return NextResponse.json({ error: 'Image does not exist' }, { status: 404 });
|
return NextResponse.json({ error: 'Image does not exist' }, { status: 404 });
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|
||||||
// check for caption
|
// check for caption
|
||||||
const captionPath = imgPath.replace(/\.[^/.]+$/, '') + '.txt';
|
const captionPath = imgPath.replace(/\.[^/.]+$/, '') + '.txt';
|
||||||
// save caption to file
|
// save caption to file
|
||||||
|
|||||||
@@ -103,7 +103,7 @@ export async function GET(request: NextRequest, { params }: { params: { jobID: s
|
|||||||
cwd: TOOLKIT_ROOT,
|
cwd: TOOLKIT_ROOT,
|
||||||
windowsHide: false,
|
windowsHide: false,
|
||||||
});
|
});
|
||||||
|
|
||||||
subprocess.unref();
|
subprocess.unref();
|
||||||
} else {
|
} else {
|
||||||
// For non-Windows platforms, use your original approach
|
// For non-Windows platforms, use your original approach
|
||||||
@@ -116,7 +116,7 @@ export async function GET(request: NextRequest, { params }: { params: { jobID: s
|
|||||||
},
|
},
|
||||||
cwd: TOOLKIT_ROOT,
|
cwd: TOOLKIT_ROOT,
|
||||||
});
|
});
|
||||||
|
|
||||||
subprocess.unref();
|
subprocess.unref();
|
||||||
}
|
}
|
||||||
// const subprocess = spawn(pythonPath, [runFilePath, configPath], {
|
// const subprocess = spawn(pythonPath, [runFilePath, configPath], {
|
||||||
|
|||||||
@@ -6,7 +6,6 @@ const prisma = new PrismaClient();
|
|||||||
export async function GET(request: Request) {
|
export async function GET(request: Request) {
|
||||||
const { searchParams } = new URL(request.url);
|
const { searchParams } = new URL(request.url);
|
||||||
const id = searchParams.get('id');
|
const id = searchParams.get('id');
|
||||||
console.log('ID:', id);
|
|
||||||
|
|
||||||
try {
|
try {
|
||||||
if (id) {
|
if (id) {
|
||||||
@@ -19,7 +18,6 @@ export async function GET(request: Request) {
|
|||||||
const jobs = await prisma.job.findMany({
|
const jobs = await prisma.job.findMany({
|
||||||
orderBy: { created_at: 'desc' },
|
orderBy: { created_at: 'desc' },
|
||||||
});
|
});
|
||||||
console.log('Jobs:', jobs);
|
|
||||||
return NextResponse.json({ jobs: jobs });
|
return NextResponse.json({ jobs: jobs });
|
||||||
} catch (error) {
|
} catch (error) {
|
||||||
console.error(error);
|
console.error(error);
|
||||||
|
|||||||
@@ -1,7 +1,7 @@
|
|||||||
import { NextResponse } from 'next/server';
|
import { NextResponse } from 'next/server';
|
||||||
import { PrismaClient } from '@prisma/client';
|
import { PrismaClient } from '@prisma/client';
|
||||||
import { defaultTrainFolder, defaultDatasetsFolder } from '@/paths';
|
import { defaultTrainFolder, defaultDatasetsFolder } from '@/paths';
|
||||||
import {flushCache} from '@/server/settings';
|
import { flushCache } from '@/server/settings';
|
||||||
|
|
||||||
const prisma = new PrismaClient();
|
const prisma = new PrismaClient();
|
||||||
|
|
||||||
|
|||||||
@@ -6,6 +6,7 @@ import DatasetImageCard from '@/components/DatasetImageCard';
|
|||||||
import { Button } from '@headlessui/react';
|
import { Button } from '@headlessui/react';
|
||||||
import AddImagesModal, { openImagesModal } from '@/components/AddImagesModal';
|
import AddImagesModal, { openImagesModal } from '@/components/AddImagesModal';
|
||||||
import { TopBar, MainContent } from '@/components/layout';
|
import { TopBar, MainContent } from '@/components/layout';
|
||||||
|
import { apiClient } from '@/utils/api';
|
||||||
|
|
||||||
export default function DatasetPage({ params }: { params: { datasetName: string } }) {
|
export default function DatasetPage({ params }: { params: { datasetName: string } }) {
|
||||||
const [imgList, setImgList] = useState<{ img_path: string }[]>([]);
|
const [imgList, setImgList] = useState<{ img_path: string }[]>([]);
|
||||||
@@ -15,15 +16,11 @@ export default function DatasetPage({ params }: { params: { datasetName: string
|
|||||||
|
|
||||||
const refreshImageList = (dbName: string) => {
|
const refreshImageList = (dbName: string) => {
|
||||||
setStatus('loading');
|
setStatus('loading');
|
||||||
fetch('/api/datasets/listImages', {
|
console.log('Fetching images for dataset:', dbName);
|
||||||
method: 'POST',
|
apiClient
|
||||||
headers: {
|
.post('/api/datasets/listImages', { datasetName: dbName })
|
||||||
'Content-Type': 'application/json',
|
.then((res: any) => {
|
||||||
},
|
const data = res.data;
|
||||||
body: JSON.stringify({ datasetName: dbName }),
|
|
||||||
})
|
|
||||||
.then(res => res.json())
|
|
||||||
.then(data => {
|
|
||||||
console.log('Images:', data.images);
|
console.log('Images:', data.images);
|
||||||
// sort
|
// sort
|
||||||
data.images.sort((a: { img_path: string }, b: { img_path: string }) => a.img_path.localeCompare(b.img_path));
|
data.images.sort((a: { img_path: string }, b: { img_path: string }) => a.img_path.localeCompare(b.img_path));
|
||||||
|
|||||||
@@ -10,6 +10,7 @@ import { FaRegTrashAlt } from 'react-icons/fa';
|
|||||||
import { openConfirm } from '@/components/ConfirmModal';
|
import { openConfirm } from '@/components/ConfirmModal';
|
||||||
import { TopBar, MainContent } from '@/components/layout';
|
import { TopBar, MainContent } from '@/components/layout';
|
||||||
import UniversalTable, { TableColumn } from '@/components/UniversalTable';
|
import UniversalTable, { TableColumn } from '@/components/UniversalTable';
|
||||||
|
import { apiClient } from '@/utils/api';
|
||||||
|
|
||||||
export default function Datasets() {
|
export default function Datasets() {
|
||||||
const { datasets, status, refreshDatasets } = useDatasetList();
|
const { datasets, status, refreshDatasets } = useDatasetList();
|
||||||
@@ -54,16 +55,10 @@ export default function Datasets() {
|
|||||||
type: 'warning',
|
type: 'warning',
|
||||||
confirmText: 'Delete',
|
confirmText: 'Delete',
|
||||||
onConfirm: () => {
|
onConfirm: () => {
|
||||||
fetch('/api/datasets/delete', {
|
apiClient
|
||||||
method: 'POST',
|
.post('/api/datasets/delete', { name: datasetName })
|
||||||
headers: {
|
.then(() => {
|
||||||
'Content-Type': 'application/json',
|
console.log('Dataset deleted:', datasetName);
|
||||||
},
|
|
||||||
body: JSON.stringify({ name: datasetName }),
|
|
||||||
})
|
|
||||||
.then(res => res.json())
|
|
||||||
.then(data => {
|
|
||||||
console.log('Dataset deleted:', data);
|
|
||||||
refreshDatasets();
|
refreshDatasets();
|
||||||
})
|
})
|
||||||
.catch(error => {
|
.catch(error => {
|
||||||
@@ -76,14 +71,7 @@ export default function Datasets() {
|
|||||||
const handleCreateDataset = async (e: React.FormEvent) => {
|
const handleCreateDataset = async (e: React.FormEvent) => {
|
||||||
e.preventDefault();
|
e.preventDefault();
|
||||||
try {
|
try {
|
||||||
const response = await fetch('/api/datasets/create', {
|
const data = await apiClient.post('/api/datasets/create', { name: newDatasetName }).then(res => res.data);
|
||||||
method: 'POST',
|
|
||||||
headers: {
|
|
||||||
'Content-Type': 'application/json',
|
|
||||||
},
|
|
||||||
body: JSON.stringify({ name: newDatasetName }),
|
|
||||||
});
|
|
||||||
const data = await response.json();
|
|
||||||
console.log('New dataset created:', data);
|
console.log('New dataset created:', data);
|
||||||
refreshDatasets();
|
refreshDatasets();
|
||||||
setNewDatasetName('');
|
setNewDatasetName('');
|
||||||
|
|||||||
@@ -33,11 +33,7 @@ const yamlConfig: YAML.DocumentOptions &
|
|||||||
directives: true,
|
directives: true,
|
||||||
};
|
};
|
||||||
|
|
||||||
export default function AdvancedJob({
|
export default function AdvancedJob({ jobConfig, setJobConfig, settings }: Props) {
|
||||||
jobConfig,
|
|
||||||
setJobConfig,
|
|
||||||
settings,
|
|
||||||
}: Props) {
|
|
||||||
const [editorValue, setEditorValue] = useState<string>('');
|
const [editorValue, setEditorValue] = useState<string>('');
|
||||||
const lastJobConfigUpdateStringRef = useRef('');
|
const lastJobConfigUpdateStringRef = useRef('');
|
||||||
const editorRef = useRef<editor.IStandaloneCodeEditor | null>(null);
|
const editorRef = useRef<editor.IStandaloneCodeEditor | null>(null);
|
||||||
|
|||||||
@@ -30,7 +30,7 @@ export const defaultJobConfig: JobConfig = {
|
|||||||
linear: 32,
|
linear: 32,
|
||||||
linear_alpha: 32,
|
linear_alpha: 32,
|
||||||
lokr_full_rank: true,
|
lokr_full_rank: true,
|
||||||
lokr_factor: -1
|
lokr_factor: -1,
|
||||||
},
|
},
|
||||||
save: {
|
save: {
|
||||||
dtype: 'bf16',
|
dtype: 'bf16',
|
||||||
@@ -39,9 +39,7 @@ export const defaultJobConfig: JobConfig = {
|
|||||||
save_format: 'diffusers',
|
save_format: 'diffusers',
|
||||||
push_to_hub: false,
|
push_to_hub: false,
|
||||||
},
|
},
|
||||||
datasets: [
|
datasets: [defaultDatasetConfig],
|
||||||
defaultDatasetConfig
|
|
||||||
],
|
|
||||||
train: {
|
train: {
|
||||||
batch_size: 1,
|
batch_size: 1,
|
||||||
bypass_guidance_embedding: true,
|
bypass_guidance_embedding: true,
|
||||||
@@ -55,7 +53,7 @@ export const defaultJobConfig: JobConfig = {
|
|||||||
timestep_type: 'sigmoid',
|
timestep_type: 'sigmoid',
|
||||||
content_or_style: 'balanced',
|
content_or_style: 'balanced',
|
||||||
optimizer_params: {
|
optimizer_params: {
|
||||||
weight_decay: 1e-4
|
weight_decay: 1e-4,
|
||||||
},
|
},
|
||||||
unload_text_encoder: false,
|
unload_text_encoder: false,
|
||||||
lr: 0.0001,
|
lr: 0.0001,
|
||||||
@@ -66,8 +64,7 @@ export const defaultJobConfig: JobConfig = {
|
|||||||
dtype: 'bf16',
|
dtype: 'bf16',
|
||||||
diff_output_preservation: false,
|
diff_output_preservation: false,
|
||||||
diff_output_preservation_multiplier: 1.0,
|
diff_output_preservation_multiplier: 1.0,
|
||||||
diff_output_preservation_class: 'person'
|
diff_output_preservation_class: 'person',
|
||||||
|
|
||||||
},
|
},
|
||||||
model: {
|
model: {
|
||||||
name_or_path: 'ostris/Flex.1-alpha',
|
name_or_path: 'ostris/Flex.1-alpha',
|
||||||
|
|||||||
@@ -12,7 +12,7 @@ export const modelArchs = [
|
|||||||
{ name: 'flux', label: 'Flux.1' },
|
{ name: 'flux', label: 'Flux.1' },
|
||||||
{ name: 'wan21', label: 'Wan 2.1' },
|
{ name: 'wan21', label: 'Wan 2.1' },
|
||||||
{ name: 'lumina2', label: 'Lumina2' },
|
{ name: 'lumina2', label: 'Lumina2' },
|
||||||
]
|
];
|
||||||
|
|
||||||
export const isVideoModelFromArch = (arch: string) => {
|
export const isVideoModelFromArch = (arch: string) => {
|
||||||
const videoArches = ['wan21'];
|
const videoArches = ['wan21'];
|
||||||
|
|||||||
@@ -19,6 +19,7 @@ import { Button } from '@headlessui/react';
|
|||||||
import { FaChevronLeft } from 'react-icons/fa';
|
import { FaChevronLeft } from 'react-icons/fa';
|
||||||
import SimpleJob from './SimpleJob';
|
import SimpleJob from './SimpleJob';
|
||||||
import AdvancedJob from './AdvancedJob';
|
import AdvancedJob from './AdvancedJob';
|
||||||
|
import { apiClient } from '@/utils/api';
|
||||||
|
|
||||||
const isDev = process.env.NODE_ENV === 'development';
|
const isDev = process.env.NODE_ENV === 'development';
|
||||||
|
|
||||||
@@ -56,12 +57,13 @@ export default function TrainingForm() {
|
|||||||
|
|
||||||
useEffect(() => {
|
useEffect(() => {
|
||||||
if (runId) {
|
if (runId) {
|
||||||
fetch(`/api/jobs?id=${runId}`)
|
apiClient
|
||||||
.then(res => res.json())
|
.get(`/api/jobs?id=${runId}`)
|
||||||
|
.then(res => res.data)
|
||||||
.then(data => {
|
.then(data => {
|
||||||
|
console.log('Training:', data);
|
||||||
setGpuIDs(data.gpu_ids);
|
setGpuIDs(data.gpu_ids);
|
||||||
setJobConfig(JSON.parse(data.job_config));
|
setJobConfig(JSON.parse(data.job_config));
|
||||||
// setJobConfig(data.name, 'config.name');
|
|
||||||
})
|
})
|
||||||
.catch(error => console.error('Error fetching training:', error));
|
.catch(error => console.error('Error fetching training:', error));
|
||||||
}
|
}
|
||||||
@@ -85,33 +87,30 @@ export default function TrainingForm() {
|
|||||||
if (status === 'saving') return;
|
if (status === 'saving') return;
|
||||||
setStatus('saving');
|
setStatus('saving');
|
||||||
|
|
||||||
try {
|
apiClient
|
||||||
const response = await fetch('/api/jobs', {
|
.post('/api/jobs', {
|
||||||
method: 'POST',
|
id: runId,
|
||||||
headers: {
|
name: jobConfig.config.name,
|
||||||
'Content-Type': 'application/json',
|
gpu_ids: gpuIDs,
|
||||||
},
|
job_config: jobConfig,
|
||||||
body: JSON.stringify({
|
})
|
||||||
id: runId,
|
.then(res => {
|
||||||
name: jobConfig.config.name,
|
setStatus('success');
|
||||||
gpu_ids: gpuIDs,
|
if (runId) {
|
||||||
job_config: jobConfig,
|
router.push(`/jobs/${runId}`);
|
||||||
}),
|
} else {
|
||||||
});
|
router.push(`/jobs/${res.data.id}`);
|
||||||
|
}
|
||||||
if (!response.ok) throw new Error('Failed to save training');
|
})
|
||||||
|
.catch(error => {
|
||||||
setStatus('success');
|
console.error('Error saving training:', error);
|
||||||
if (!runId) {
|
setStatus('error');
|
||||||
const data = await response.json();
|
})
|
||||||
router.push(`/jobs/${data.id}`);
|
.finally(() =>
|
||||||
}
|
setTimeout(() => {
|
||||||
setTimeout(() => setStatus('idle'), 2000);
|
setStatus('idle');
|
||||||
} catch (error) {
|
}, 2000),
|
||||||
console.error('Error saving training:', error);
|
);
|
||||||
setStatus('error');
|
|
||||||
setTimeout(() => setStatus('idle'), 2000);
|
|
||||||
}
|
|
||||||
};
|
};
|
||||||
|
|
||||||
const handleSubmit = async (e: React.FormEvent) => {
|
const handleSubmit = async (e: React.FormEvent) => {
|
||||||
|
|||||||
@@ -13,10 +13,7 @@ export default function Dashboard() {
|
|||||||
</div>
|
</div>
|
||||||
<div className="flex-1"></div>
|
<div className="flex-1"></div>
|
||||||
<div>
|
<div>
|
||||||
<Link
|
<Link href="/jobs/new" className="text-gray-200 bg-slate-600 px-3 py-1 rounded-md">
|
||||||
href="/jobs/new"
|
|
||||||
className="text-gray-200 bg-slate-600 px-3 py-1 rounded-md"
|
|
||||||
>
|
|
||||||
New Training Job
|
New Training Job
|
||||||
</Link>
|
</Link>
|
||||||
</div>
|
</div>
|
||||||
|
|||||||
@@ -6,6 +6,7 @@ import { ThemeProvider } from '@/components/ThemeProvider';
|
|||||||
import ConfirmModal from '@/components/ConfirmModal';
|
import ConfirmModal from '@/components/ConfirmModal';
|
||||||
import SampleImageModal from '@/components/SampleImageModal';
|
import SampleImageModal from '@/components/SampleImageModal';
|
||||||
import { Suspense } from 'react';
|
import { Suspense } from 'react';
|
||||||
|
import AuthWrapper from '@/components/AuthWrapper';
|
||||||
|
|
||||||
const inter = Inter({ subsets: ['latin'] });
|
const inter = Inter({ subsets: ['latin'] });
|
||||||
|
|
||||||
@@ -15,6 +16,9 @@ export const metadata: Metadata = {
|
|||||||
};
|
};
|
||||||
|
|
||||||
export default function RootLayout({ children }: { children: React.ReactNode }) {
|
export default function RootLayout({ children }: { children: React.ReactNode }) {
|
||||||
|
// Check if the AI_TOOLKIT_AUTH environment variable is set
|
||||||
|
const authRequired = process.env.AI_TOOLKIT_AUTH ? true : false;
|
||||||
|
|
||||||
return (
|
return (
|
||||||
<html lang="en" className="dark">
|
<html lang="en" className="dark">
|
||||||
<head>
|
<head>
|
||||||
@@ -22,13 +26,14 @@ export default function RootLayout({ children }: { children: React.ReactNode })
|
|||||||
</head>
|
</head>
|
||||||
<body className={inter.className}>
|
<body className={inter.className}>
|
||||||
<ThemeProvider>
|
<ThemeProvider>
|
||||||
<div className="flex h-screen bg-gray-950">
|
<AuthWrapper authRequired={authRequired}>
|
||||||
<Sidebar />
|
<div className="flex h-screen bg-gray-950">
|
||||||
|
<Sidebar />
|
||||||
<main className="flex-1 overflow-auto bg-gray-950 text-gray-100 relative">
|
<main className="flex-1 overflow-auto bg-gray-950 text-gray-100 relative">
|
||||||
<Suspense>{children}</Suspense>
|
<Suspense>{children}</Suspense>
|
||||||
</main>
|
</main>
|
||||||
</div>
|
</div>
|
||||||
|
</AuthWrapper>
|
||||||
</ThemeProvider>
|
</ThemeProvider>
|
||||||
<ConfirmModal />
|
<ConfirmModal />
|
||||||
<SampleImageModal />
|
<SampleImageModal />
|
||||||
|
|||||||
@@ -2,4 +2,4 @@ import { redirect } from 'next/navigation';
|
|||||||
|
|
||||||
export default function Home() {
|
export default function Home() {
|
||||||
redirect('/dashboard');
|
redirect('/dashboard');
|
||||||
}
|
}
|
||||||
|
|||||||
@@ -3,6 +3,7 @@
|
|||||||
import { useEffect, useState } from 'react';
|
import { useEffect, useState } from 'react';
|
||||||
import useSettings from '@/hooks/useSettings';
|
import useSettings from '@/hooks/useSettings';
|
||||||
import { TopBar, MainContent } from '@/components/layout';
|
import { TopBar, MainContent } from '@/components/layout';
|
||||||
|
import { apiClient } from '@/utils/api';
|
||||||
|
|
||||||
export default function Settings() {
|
export default function Settings() {
|
||||||
const { settings, setSettings } = useSettings();
|
const { settings, setSettings } = useSettings();
|
||||||
@@ -12,24 +13,18 @@ export default function Settings() {
|
|||||||
e.preventDefault();
|
e.preventDefault();
|
||||||
setStatus('saving');
|
setStatus('saving');
|
||||||
|
|
||||||
try {
|
apiClient
|
||||||
const response = await fetch('/api/settings', {
|
.post('/api/settings', settings)
|
||||||
method: 'POST',
|
.then(() => {
|
||||||
headers: {
|
setStatus('success');
|
||||||
'Content-Type': 'application/json',
|
})
|
||||||
},
|
.catch(error => {
|
||||||
body: JSON.stringify(settings),
|
console.error('Error saving settings:', error);
|
||||||
|
setStatus('error');
|
||||||
|
})
|
||||||
|
.finally(() => {
|
||||||
|
setTimeout(() => setStatus('idle'), 2000);
|
||||||
});
|
});
|
||||||
|
|
||||||
if (!response.ok) throw new Error('Failed to save settings');
|
|
||||||
|
|
||||||
setStatus('success');
|
|
||||||
setTimeout(() => setStatus('idle'), 2000);
|
|
||||||
} catch (error) {
|
|
||||||
console.error('Error saving settings:', error);
|
|
||||||
setStatus('error');
|
|
||||||
setTimeout(() => setStatus('idle'), 2000);
|
|
||||||
}
|
|
||||||
};
|
};
|
||||||
|
|
||||||
const handleChange = (e: React.ChangeEvent<HTMLInputElement>) => {
|
const handleChange = (e: React.ChangeEvent<HTMLInputElement>) => {
|
||||||
|
|||||||
@@ -4,7 +4,7 @@ import { Dialog, DialogBackdrop, DialogPanel, DialogTitle } from '@headlessui/re
|
|||||||
import { FaUpload } from 'react-icons/fa';
|
import { FaUpload } from 'react-icons/fa';
|
||||||
import { useCallback, useState } from 'react';
|
import { useCallback, useState } from 'react';
|
||||||
import { useDropzone } from 'react-dropzone';
|
import { useDropzone } from 'react-dropzone';
|
||||||
import axios from 'axios';
|
import { apiClient } from '@/utils/api';
|
||||||
|
|
||||||
export interface AddImagesModalState {
|
export interface AddImagesModalState {
|
||||||
datasetName: string;
|
datasetName: string;
|
||||||
@@ -15,7 +15,7 @@ export const addImagesModalState = createGlobalState<AddImagesModalState | null>
|
|||||||
|
|
||||||
export const openImagesModal = (datasetName: string, onComplete: () => void) => {
|
export const openImagesModal = (datasetName: string, onComplete: () => void) => {
|
||||||
addImagesModalState.set({ datasetName, onComplete });
|
addImagesModalState.set({ datasetName, onComplete });
|
||||||
}
|
};
|
||||||
|
|
||||||
export default function AddImagesModal() {
|
export default function AddImagesModal() {
|
||||||
const [addImagesModalInfo, setAddImagesModalInfo] = addImagesModalState.use();
|
const [addImagesModalInfo, setAddImagesModalInfo] = addImagesModalState.use();
|
||||||
@@ -36,46 +36,49 @@ export default function AddImagesModal() {
|
|||||||
}
|
}
|
||||||
};
|
};
|
||||||
|
|
||||||
const onDrop = useCallback(async (acceptedFiles: File[]) => {
|
const onDrop = useCallback(
|
||||||
if (acceptedFiles.length === 0) return;
|
async (acceptedFiles: File[]) => {
|
||||||
|
if (acceptedFiles.length === 0) return;
|
||||||
|
|
||||||
setIsUploading(true);
|
setIsUploading(true);
|
||||||
setUploadProgress(0);
|
|
||||||
|
|
||||||
const formData = new FormData();
|
|
||||||
acceptedFiles.forEach(file => {
|
|
||||||
formData.append('files', file);
|
|
||||||
});
|
|
||||||
formData.append('datasetName', addImagesModalInfo?.datasetName || '');
|
|
||||||
|
|
||||||
try {
|
|
||||||
await axios.post(`/api/datasets/upload`, formData, {
|
|
||||||
headers: {
|
|
||||||
'Content-Type': 'multipart/form-data',
|
|
||||||
},
|
|
||||||
onUploadProgress: (progressEvent) => {
|
|
||||||
const percentCompleted = Math.round((progressEvent.loaded * 100) / (progressEvent.total || 100));
|
|
||||||
setUploadProgress(percentCompleted);
|
|
||||||
},
|
|
||||||
timeout: 0, // Disable timeout
|
|
||||||
});
|
|
||||||
|
|
||||||
onDone();
|
|
||||||
} catch (error) {
|
|
||||||
console.error('Upload failed:', error);
|
|
||||||
} finally {
|
|
||||||
setIsUploading(false);
|
|
||||||
setUploadProgress(0);
|
setUploadProgress(0);
|
||||||
}
|
|
||||||
}, [addImagesModalInfo]);
|
const formData = new FormData();
|
||||||
|
acceptedFiles.forEach(file => {
|
||||||
|
formData.append('files', file);
|
||||||
|
});
|
||||||
|
formData.append('datasetName', addImagesModalInfo?.datasetName || '');
|
||||||
|
|
||||||
|
try {
|
||||||
|
await apiClient.post(`/api/datasets/upload`, formData, {
|
||||||
|
headers: {
|
||||||
|
'Content-Type': 'multipart/form-data',
|
||||||
|
},
|
||||||
|
onUploadProgress: progressEvent => {
|
||||||
|
const percentCompleted = Math.round((progressEvent.loaded * 100) / (progressEvent.total || 100));
|
||||||
|
setUploadProgress(percentCompleted);
|
||||||
|
},
|
||||||
|
timeout: 0, // Disable timeout
|
||||||
|
});
|
||||||
|
|
||||||
|
onDone();
|
||||||
|
} catch (error) {
|
||||||
|
console.error('Upload failed:', error);
|
||||||
|
} finally {
|
||||||
|
setIsUploading(false);
|
||||||
|
setUploadProgress(0);
|
||||||
|
}
|
||||||
|
},
|
||||||
|
[addImagesModalInfo],
|
||||||
|
);
|
||||||
|
|
||||||
const { getRootProps, getInputProps, isDragActive } = useDropzone({
|
const { getRootProps, getInputProps, isDragActive } = useDropzone({
|
||||||
onDrop,
|
onDrop,
|
||||||
accept: {
|
accept: {
|
||||||
'image/*': ['.png', '.jpg', '.jpeg', '.gif', '.bmp'],
|
'image/*': ['.png', '.jpg', '.jpeg', '.gif', '.bmp'],
|
||||||
'text/*': ['.txt']
|
'text/*': ['.txt'],
|
||||||
},
|
},
|
||||||
multiple: true
|
multiple: true,
|
||||||
});
|
});
|
||||||
|
|
||||||
return (
|
return (
|
||||||
@@ -105,22 +108,15 @@ export default function AddImagesModal() {
|
|||||||
<input {...getInputProps()} />
|
<input {...getInputProps()} />
|
||||||
<FaUpload className="size-8 mb-3 text-gray-400" />
|
<FaUpload className="size-8 mb-3 text-gray-400" />
|
||||||
<p className="text-sm text-gray-200 text-center">
|
<p className="text-sm text-gray-200 text-center">
|
||||||
{isDragActive
|
{isDragActive ? 'Drop the files here...' : 'Drag & drop files here, or click to select files'}
|
||||||
? 'Drop the files here...'
|
|
||||||
: 'Drag & drop files here, or click to select files'}
|
|
||||||
</p>
|
</p>
|
||||||
</div>
|
</div>
|
||||||
{isUploading && (
|
{isUploading && (
|
||||||
<div className="mt-4">
|
<div className="mt-4">
|
||||||
<div className="w-full bg-gray-700 rounded-full h-2.5">
|
<div className="w-full bg-gray-700 rounded-full h-2.5">
|
||||||
<div
|
<div className="bg-blue-600 h-2.5 rounded-full" style={{ width: `${uploadProgress}%` }}></div>
|
||||||
className="bg-blue-600 h-2.5 rounded-full"
|
|
||||||
style={{ width: `${uploadProgress}%` }}
|
|
||||||
></div>
|
|
||||||
</div>
|
</div>
|
||||||
<p className="text-sm text-gray-300 mt-2 text-center">
|
<p className="text-sm text-gray-300 mt-2 text-center">Uploading... {uploadProgress}%</p>
|
||||||
Uploading... {uploadProgress}%
|
|
||||||
</p>
|
|
||||||
</div>
|
</div>
|
||||||
)}
|
)}
|
||||||
</div>
|
</div>
|
||||||
@@ -152,4 +148,4 @@ export default function AddImagesModal() {
|
|||||||
</div>
|
</div>
|
||||||
</Dialog>
|
</Dialog>
|
||||||
);
|
);
|
||||||
}
|
}
|
||||||
|
|||||||
163
ui/src/components/AuthWrapper.tsx
Normal file
163
ui/src/components/AuthWrapper.tsx
Normal file
@@ -0,0 +1,163 @@
|
|||||||
|
'use client';
|
||||||
|
|
||||||
|
import { useState, useEffect, useRef } from 'react';
|
||||||
|
import { apiClient, isAuthorizedState } from '@/utils/api';
|
||||||
|
import { createGlobalState } from 'react-global-hooks';
|
||||||
|
|
||||||
|
interface AuthWrapperProps {
|
||||||
|
authRequired: boolean;
|
||||||
|
children: React.ReactNode | React.ReactNode[];
|
||||||
|
}
|
||||||
|
|
||||||
|
export default function AuthWrapper({ authRequired, children }: AuthWrapperProps) {
|
||||||
|
const [token, setToken] = useState('');
|
||||||
|
// start with true, and deauth if needed
|
||||||
|
const [isAuthorizedGlobal, setIsAuthorized] = isAuthorizedState.use();
|
||||||
|
const [isLoading, setIsLoading] = useState(false);
|
||||||
|
const [error, setError] = useState('');
|
||||||
|
const [isBrowser, setIsBrowser] = useState(false);
|
||||||
|
const inputRef = useRef<HTMLInputElement>(null);
|
||||||
|
|
||||||
|
const isAuthorized = authRequired ? isAuthorizedGlobal : true;
|
||||||
|
|
||||||
|
// Set isBrowser to true when component mounts
|
||||||
|
useEffect(() => {
|
||||||
|
setIsBrowser(true);
|
||||||
|
// Get token from localStorage only after component has mounted
|
||||||
|
const storedToken = localStorage.getItem('AI_TOOLKIT_AUTH') || '';
|
||||||
|
setToken(storedToken);
|
||||||
|
checkAuth();
|
||||||
|
}, []);
|
||||||
|
|
||||||
|
// auto focus on input when not authorized
|
||||||
|
useEffect(() => {
|
||||||
|
if (isAuthorized) {
|
||||||
|
return;
|
||||||
|
}
|
||||||
|
setTimeout(() => {
|
||||||
|
if (inputRef.current) {
|
||||||
|
inputRef.current.focus();
|
||||||
|
}
|
||||||
|
}, 100);
|
||||||
|
}, [isAuthorized]);
|
||||||
|
|
||||||
|
const checkAuth = async () => {
|
||||||
|
// always get current stored token here to avoid state race conditions
|
||||||
|
const currentToken = localStorage.getItem('AI_TOOLKIT_AUTH') || '';
|
||||||
|
if (!authRequired || isLoading || currentToken === '') {
|
||||||
|
return;
|
||||||
|
}
|
||||||
|
setIsLoading(true);
|
||||||
|
setError('');
|
||||||
|
try {
|
||||||
|
const response = await apiClient.get('/api/auth');
|
||||||
|
if (response.data.isAuthenticated) {
|
||||||
|
setIsAuthorized(true);
|
||||||
|
} else {
|
||||||
|
setIsAuthorized(false);
|
||||||
|
setError('Invalid token. Please try again.');
|
||||||
|
}
|
||||||
|
} catch (err) {
|
||||||
|
setIsAuthorized(false);
|
||||||
|
console.log(err);
|
||||||
|
setError('Invalid token. Please try again.');
|
||||||
|
}
|
||||||
|
setIsLoading(false);
|
||||||
|
};
|
||||||
|
|
||||||
|
const handleSubmit = async (e: React.FormEvent) => {
|
||||||
|
e.preventDefault();
|
||||||
|
setError('');
|
||||||
|
|
||||||
|
if (!token.trim()) {
|
||||||
|
setError('Please enter your token');
|
||||||
|
return;
|
||||||
|
}
|
||||||
|
|
||||||
|
if (isBrowser) {
|
||||||
|
localStorage.setItem('AI_TOOLKIT_AUTH', token);
|
||||||
|
checkAuth();
|
||||||
|
}
|
||||||
|
};
|
||||||
|
|
||||||
|
if (isAuthorized) {
|
||||||
|
return <>{children}</>;
|
||||||
|
}
|
||||||
|
|
||||||
|
return (
|
||||||
|
<div className="flex min-h-screen bg-gray-900 text-gray-100 absolute top-0 left-0 right-0 bottom-0 scroll-auto">
|
||||||
|
{/* Left side - decorative or brand area */}
|
||||||
|
<div className="hidden lg:flex lg:w-1/2 bg-gray-800 flex-col justify-center items-center p-12">
|
||||||
|
<div className="mb-4">
|
||||||
|
{/* Replace with your own logo */}
|
||||||
|
<div className="flex items-center justify-center">
|
||||||
|
<img src="/ostris_logo.png" alt="Ostris AI Toolkit" className="w-auto h-24 inline" />
|
||||||
|
</div>
|
||||||
|
</div>
|
||||||
|
<h1 className="text-4xl mb-6">AI Toolkit</h1>
|
||||||
|
</div>
|
||||||
|
|
||||||
|
{/* Right side - login form */}
|
||||||
|
<div className="w-full lg:w-1/2 flex flex-col justify-center items-center p-8 sm:p-12">
|
||||||
|
<div className="w-full max-w-md">
|
||||||
|
<div className="lg:hidden flex justify-center mb-4">
|
||||||
|
{/* Mobile logo */}
|
||||||
|
<div className="flex items-center justify-center">
|
||||||
|
<img src="/ostris_logo.png" alt="Ostris AI Toolkit" className="w-auto h-24 inline" />
|
||||||
|
</div>
|
||||||
|
</div>
|
||||||
|
|
||||||
|
<h2 className="text-3xl text-center mb-2 lg:hidden">AI Toolkit</h2>
|
||||||
|
|
||||||
|
<form onSubmit={handleSubmit} className="space-y-6">
|
||||||
|
<div>
|
||||||
|
<label htmlFor="token" className="block text-sm font-medium text-gray-400 mb-2">
|
||||||
|
Access Token
|
||||||
|
</label>
|
||||||
|
<input
|
||||||
|
id="token"
|
||||||
|
name="token"
|
||||||
|
type="password"
|
||||||
|
autoComplete="off"
|
||||||
|
required
|
||||||
|
value={token}
|
||||||
|
ref={inputRef}
|
||||||
|
onChange={e => setToken(e.target.value)}
|
||||||
|
className="w-full px-4 py-3 rounded-lg bg-gray-800 border border-gray-700 focus:border-blue-500 focus:ring-2 focus:ring-blue-500 focus:ring-opacity-50 text-gray-100 transition duration-200"
|
||||||
|
placeholder="Enter your token"
|
||||||
|
/>
|
||||||
|
</div>
|
||||||
|
|
||||||
|
{error && (
|
||||||
|
<div className="p-3 bg-red-900/50 border border-red-800 rounded-lg text-red-200 text-sm">{error}</div>
|
||||||
|
)}
|
||||||
|
|
||||||
|
<button
|
||||||
|
type="submit"
|
||||||
|
disabled={isLoading}
|
||||||
|
className="w-full py-3 px-4 bg-blue-600 hover:bg-blue-700 rounded-lg text-white font-medium focus:outline-none focus:ring-2 focus:ring-blue-500 focus:ring-opacity-50 transition duration-200 flex items-center justify-center"
|
||||||
|
>
|
||||||
|
{isLoading ? (
|
||||||
|
<svg
|
||||||
|
className="animate-spin h-5 w-5 text-white"
|
||||||
|
xmlns="http://www.w3.org/2000/svg"
|
||||||
|
fill="none"
|
||||||
|
viewBox="0 0 24 24"
|
||||||
|
>
|
||||||
|
<circle className="opacity-25" cx="12" cy="12" r="10" stroke="currentColor" strokeWidth="4"></circle>
|
||||||
|
<path
|
||||||
|
className="opacity-75"
|
||||||
|
fill="currentColor"
|
||||||
|
d="M4 12a8 8 0 018-8V0C5.373 0 0 5.373 0 12h4zm2 5.291A7.962 7.962 0 014 12H0c0 3.042 1.135 5.824 3 7.938l3-2.647z"
|
||||||
|
></path>
|
||||||
|
</svg>
|
||||||
|
) : (
|
||||||
|
'Check Token'
|
||||||
|
)}
|
||||||
|
</button>
|
||||||
|
</form>
|
||||||
|
</div>
|
||||||
|
</div>
|
||||||
|
</div>
|
||||||
|
);
|
||||||
|
}
|
||||||
@@ -12,4 +12,4 @@ const Card: React.FC<CardProps> = ({ title, children }) => {
|
|||||||
);
|
);
|
||||||
};
|
};
|
||||||
|
|
||||||
export default Card;
|
export default Card;
|
||||||
|
|||||||
@@ -2,6 +2,7 @@ import React, { useRef, useEffect, useState, ReactNode, KeyboardEvent } from 're
|
|||||||
import { FaTrashAlt } from 'react-icons/fa';
|
import { FaTrashAlt } from 'react-icons/fa';
|
||||||
import { openConfirm } from './ConfirmModal';
|
import { openConfirm } from './ConfirmModal';
|
||||||
import classNames from 'classnames';
|
import classNames from 'classnames';
|
||||||
|
import { apiClient } from '@/utils/api';
|
||||||
|
|
||||||
interface DatasetImageCardProps {
|
interface DatasetImageCardProps {
|
||||||
imageUrl: string;
|
imageUrl: string;
|
||||||
@@ -27,30 +28,32 @@ const DatasetImageCard: React.FC<DatasetImageCardProps> = ({
|
|||||||
const isGettingCaption = useRef<boolean>(false);
|
const isGettingCaption = useRef<boolean>(false);
|
||||||
|
|
||||||
const fetchCaption = async () => {
|
const fetchCaption = async () => {
|
||||||
try {
|
if (isGettingCaption.current || isCaptionLoaded) return;
|
||||||
if (isGettingCaption.current || isCaptionLoaded) return;
|
isGettingCaption.current = true;
|
||||||
isGettingCaption.current = true;
|
apiClient
|
||||||
const response = await fetch(`/api/caption/${encodeURIComponent(imageUrl)}`);
|
.get(`/api/caption/${encodeURIComponent(imageUrl)}`)
|
||||||
const data = await response.text();
|
.then(res => res.data)
|
||||||
setCaption(data);
|
.then(data => {
|
||||||
setSavedCaption(data);
|
console.log('Caption fetched:', data);
|
||||||
setIsCaptionLoaded(true);
|
|
||||||
} catch (error) {
|
setCaption(data || '');
|
||||||
console.error('Error fetching caption:', error);
|
setSavedCaption(data || '');
|
||||||
}
|
setIsCaptionLoaded(true);
|
||||||
|
})
|
||||||
|
.catch(error => {
|
||||||
|
console.error('Error fetching caption:', error);
|
||||||
|
})
|
||||||
|
.finally(() => {
|
||||||
|
isGettingCaption.current = false;
|
||||||
|
});
|
||||||
};
|
};
|
||||||
|
|
||||||
const saveCaption = () => {
|
const saveCaption = () => {
|
||||||
const trimmedCaption = caption.trim();
|
const trimmedCaption = caption.trim();
|
||||||
if (trimmedCaption === savedCaption) return;
|
if (trimmedCaption === savedCaption) return;
|
||||||
fetch('/api/img/caption', {
|
apiClient
|
||||||
method: 'POST',
|
.post('/api/img/caption', { imgPath: imageUrl, caption: trimmedCaption })
|
||||||
headers: {
|
.then(res => res.data)
|
||||||
'Content-Type': 'application/json',
|
|
||||||
},
|
|
||||||
body: JSON.stringify({ imgPath: imageUrl, caption: trimmedCaption }),
|
|
||||||
})
|
|
||||||
.then(res => res.json())
|
|
||||||
.then(data => {
|
.then(data => {
|
||||||
console.log('Caption saved:', data);
|
console.log('Caption saved:', data);
|
||||||
setSavedCaption(trimmedCaption);
|
setSavedCaption(trimmedCaption);
|
||||||
@@ -129,16 +132,10 @@ const DatasetImageCard: React.FC<DatasetImageCardProps> = ({
|
|||||||
type: 'warning',
|
type: 'warning',
|
||||||
confirmText: 'Delete',
|
confirmText: 'Delete',
|
||||||
onConfirm: () => {
|
onConfirm: () => {
|
||||||
fetch('/api/img/delete', {
|
apiClient
|
||||||
method: 'POST',
|
.post('/api/img/delete', { imgPath: imageUrl })
|
||||||
headers: {
|
.then(() => {
|
||||||
'Content-Type': 'application/json',
|
console.log('Image deleted:', imageUrl);
|
||||||
},
|
|
||||||
body: JSON.stringify({ imgPath: imageUrl }),
|
|
||||||
})
|
|
||||||
.then(res => res.json())
|
|
||||||
.then(data => {
|
|
||||||
console.log('Image deleted:', data);
|
|
||||||
onDelete();
|
onDelete();
|
||||||
})
|
})
|
||||||
.catch(error => {
|
.catch(error => {
|
||||||
|
|||||||
@@ -24,9 +24,7 @@ export default function FilesWidget({ jobID }: { jobID: string }) {
|
|||||||
<div className="flex items-center space-x-2">
|
<div className="flex items-center space-x-2">
|
||||||
<Brain className="w-5 h-5 text-purple-400" />
|
<Brain className="w-5 h-5 text-purple-400" />
|
||||||
<h2 className="font-semibold text-gray-100">Checkpoints</h2>
|
<h2 className="font-semibold text-gray-100">Checkpoints</h2>
|
||||||
<span className="px-2 py-0.5 bg-gray-700 rounded-full text-xs text-gray-300">
|
<span className="px-2 py-0.5 bg-gray-700 rounded-full text-xs text-gray-300">{files.length}</span>
|
||||||
{files.length}
|
|
||||||
</span>
|
|
||||||
</div>
|
</div>
|
||||||
</div>
|
</div>
|
||||||
|
|
||||||
@@ -50,9 +48,9 @@ export default function FilesWidget({ jobID }: { jobID: string }) {
|
|||||||
const fileName = file.path.split('/').pop() || '';
|
const fileName = file.path.split('/').pop() || '';
|
||||||
const nameWithoutExt = fileName.replace('.safetensors', '');
|
const nameWithoutExt = fileName.replace('.safetensors', '');
|
||||||
return (
|
return (
|
||||||
<a
|
<a
|
||||||
key={index}
|
key={index}
|
||||||
target='_blank'
|
target="_blank"
|
||||||
href={`/api/files/${encodeURIComponent(file.path)}`}
|
href={`/api/files/${encodeURIComponent(file.path)}`}
|
||||||
className="group flex items-center justify-between px-2 py-1.5 rounded-lg hover:bg-gray-800 transition-all duration-200"
|
className="group flex items-center justify-between px-2 py-1.5 rounded-lg hover:bg-gray-800 transition-all duration-200"
|
||||||
>
|
>
|
||||||
@@ -80,11 +78,9 @@ export default function FilesWidget({ jobID }: { jobID: string }) {
|
|||||||
)}
|
)}
|
||||||
|
|
||||||
{['success', 'refreshing'].includes(status) && files.length === 0 && (
|
{['success', 'refreshing'].includes(status) && files.length === 0 && (
|
||||||
<div className="text-center py-4 text-gray-400 text-sm">
|
<div className="text-center py-4 text-gray-400 text-sm">No checkpoints available</div>
|
||||||
No checkpoints available
|
|
||||||
</div>
|
|
||||||
)}
|
)}
|
||||||
</div>
|
</div>
|
||||||
</div>
|
</div>
|
||||||
);
|
);
|
||||||
}
|
}
|
||||||
|
|||||||
@@ -1,7 +1,8 @@
|
|||||||
import React, { useState, useEffect, useRef } from 'react';
|
import React, { useState, useEffect, useRef, useMemo } from 'react';
|
||||||
import { GPUApiResponse } from '@/types';
|
import { GPUApiResponse } from '@/types';
|
||||||
import Loading from '@/components/Loading';
|
import Loading from '@/components/Loading';
|
||||||
import GPUWidget from '@/components/GPUWidget';
|
import GPUWidget from '@/components/GPUWidget';
|
||||||
|
import { apiClient } from '@/utils/api';
|
||||||
|
|
||||||
const GpuMonitor: React.FC = () => {
|
const GpuMonitor: React.FC = () => {
|
||||||
const [gpuData, setGpuData] = useState<GPUApiResponse | null>(null);
|
const [gpuData, setGpuData] = useState<GPUApiResponse | null>(null);
|
||||||
@@ -12,27 +13,26 @@ const GpuMonitor: React.FC = () => {
|
|||||||
|
|
||||||
useEffect(() => {
|
useEffect(() => {
|
||||||
const fetchGpuInfo = async () => {
|
const fetchGpuInfo = async () => {
|
||||||
try {
|
if (isFetchingGpuRef.current) {
|
||||||
if (isFetchingGpuRef.current) {
|
return;
|
||||||
return;
|
|
||||||
}
|
|
||||||
isFetchingGpuRef.current = true;
|
|
||||||
const response = await fetch('/api/gpu');
|
|
||||||
|
|
||||||
if (!response.ok) {
|
|
||||||
throw new Error(`HTTP error! Status: ${response.status}`);
|
|
||||||
}
|
|
||||||
|
|
||||||
const data: GPUApiResponse = await response.json();
|
|
||||||
setGpuData(data);
|
|
||||||
setLastUpdated(new Date());
|
|
||||||
setError(null);
|
|
||||||
} catch (err) {
|
|
||||||
setError(`Failed to fetch GPU data: ${err instanceof Error ? err.message : String(err)}`);
|
|
||||||
} finally {
|
|
||||||
isFetchingGpuRef.current = false;
|
|
||||||
setLoading(false);
|
|
||||||
}
|
}
|
||||||
|
setLoading(true);
|
||||||
|
isFetchingGpuRef.current = true;
|
||||||
|
apiClient
|
||||||
|
.get('/api/gpu')
|
||||||
|
.then(res => res.data)
|
||||||
|
.then(data => {
|
||||||
|
setGpuData(data);
|
||||||
|
setLastUpdated(new Date());
|
||||||
|
setError(null);
|
||||||
|
})
|
||||||
|
.catch(err => {
|
||||||
|
setError(`Failed to fetch GPU data: ${err instanceof Error ? err.message : String(err)}`);
|
||||||
|
})
|
||||||
|
.finally(() => {
|
||||||
|
isFetchingGpuRef.current = false;
|
||||||
|
setLoading(false);
|
||||||
|
});
|
||||||
};
|
};
|
||||||
|
|
||||||
// Fetch immediately on component mount
|
// Fetch immediately on component mount
|
||||||
@@ -69,46 +69,63 @@ const GpuMonitor: React.FC = () => {
|
|||||||
}
|
}
|
||||||
};
|
};
|
||||||
|
|
||||||
if (loading) {
|
console.log('state', {
|
||||||
return <Loading />;
|
loading,
|
||||||
}
|
gpuData,
|
||||||
|
error,
|
||||||
|
lastUpdated,
|
||||||
|
});
|
||||||
|
|
||||||
|
const content = useMemo(() => {
|
||||||
|
if (loading && !gpuData) {
|
||||||
|
return <Loading />;
|
||||||
|
}
|
||||||
|
|
||||||
|
if (error) {
|
||||||
|
return (
|
||||||
|
<div className="bg-red-900 border border-red-600 text-red-200 px-4 py-3 rounded relative" role="alert">
|
||||||
|
<strong className="font-bold">Error!</strong>
|
||||||
|
<span className="block sm:inline"> {error}</span>
|
||||||
|
</div>
|
||||||
|
);
|
||||||
|
}
|
||||||
|
|
||||||
|
if (!gpuData) {
|
||||||
|
return (
|
||||||
|
<div className="bg-yellow-900 border border-yellow-700 text-yellow-300 px-4 py-3 rounded relative" role="alert">
|
||||||
|
<span className="block sm:inline">No GPU data available.</span>
|
||||||
|
</div>
|
||||||
|
);
|
||||||
|
}
|
||||||
|
|
||||||
|
if (!gpuData.hasNvidiaSmi) {
|
||||||
|
return (
|
||||||
|
<div className="bg-yellow-900 border border-yellow-700 text-yellow-300 px-4 py-3 rounded relative" role="alert">
|
||||||
|
<strong className="font-bold">No NVIDIA GPUs detected!</strong>
|
||||||
|
<span className="block sm:inline"> nvidia-smi is not available on this system.</span>
|
||||||
|
{gpuData.error && <p className="mt-2 text-sm">{gpuData.error}</p>}
|
||||||
|
</div>
|
||||||
|
);
|
||||||
|
}
|
||||||
|
|
||||||
|
if (gpuData.gpus.length === 0) {
|
||||||
|
return (
|
||||||
|
<div className="bg-yellow-900 border border-yellow-700 text-yellow-300 px-4 py-3 rounded relative" role="alert">
|
||||||
|
<span className="block sm:inline">No GPUs found, but nvidia-smi is available.</span>
|
||||||
|
</div>
|
||||||
|
);
|
||||||
|
}
|
||||||
|
|
||||||
|
const gridClass = getGridClasses(gpuData?.gpus?.length || 1);
|
||||||
|
|
||||||
if (error) {
|
|
||||||
return (
|
return (
|
||||||
<div className="bg-red-100 border border-red-400 text-red-700 px-4 py-3 rounded relative" role="alert">
|
<div className={`grid ${gridClass} gap-3`}>
|
||||||
<strong className="font-bold">Error!</strong>
|
{gpuData.gpus.map((gpu, idx) => (
|
||||||
<span className="block sm:inline"> {error}</span>
|
<GPUWidget key={idx} gpu={gpu} />
|
||||||
|
))}
|
||||||
</div>
|
</div>
|
||||||
);
|
);
|
||||||
}
|
}, [loading, gpuData, error]);
|
||||||
|
|
||||||
if (!gpuData) {
|
|
||||||
return (
|
|
||||||
<div className="bg-yellow-100 border border-yellow-400 text-yellow-700 px-4 py-3 rounded relative" role="alert">
|
|
||||||
<span className="block sm:inline">No GPU data available.</span>
|
|
||||||
</div>
|
|
||||||
);
|
|
||||||
}
|
|
||||||
|
|
||||||
if (!gpuData.hasNvidiaSmi) {
|
|
||||||
return (
|
|
||||||
<div className="bg-yellow-100 border border-yellow-400 text-yellow-700 px-4 py-3 rounded relative" role="alert">
|
|
||||||
<strong className="font-bold">No NVIDIA GPUs detected!</strong>
|
|
||||||
<span className="block sm:inline"> nvidia-smi is not available on this system.</span>
|
|
||||||
{gpuData.error && <p className="mt-2 text-sm">{gpuData.error}</p>}
|
|
||||||
</div>
|
|
||||||
);
|
|
||||||
}
|
|
||||||
|
|
||||||
if (gpuData.gpus.length === 0) {
|
|
||||||
return (
|
|
||||||
<div className="bg-yellow-100 border border-yellow-400 text-yellow-700 px-4 py-3 rounded relative" role="alert">
|
|
||||||
<span className="block sm:inline">No GPUs found, but nvidia-smi is available.</span>
|
|
||||||
</div>
|
|
||||||
);
|
|
||||||
}
|
|
||||||
|
|
||||||
const gridClass = getGridClasses(gpuData.gpus.length);
|
|
||||||
|
|
||||||
return (
|
return (
|
||||||
<div className="w-full">
|
<div className="w-full">
|
||||||
@@ -116,12 +133,7 @@ const GpuMonitor: React.FC = () => {
|
|||||||
<h1 className="text-md">GPU Monitor</h1>
|
<h1 className="text-md">GPU Monitor</h1>
|
||||||
<div className="text-xs text-gray-500">Last updated: {lastUpdated?.toLocaleTimeString()}</div>
|
<div className="text-xs text-gray-500">Last updated: {lastUpdated?.toLocaleTimeString()}</div>
|
||||||
</div>
|
</div>
|
||||||
|
{content}
|
||||||
<div className={`grid ${gridClass} gap-3`}>
|
|
||||||
{gpuData.gpus.map((gpu, idx) => (
|
|
||||||
<GPUWidget key={idx} gpu={gpu} />
|
|
||||||
))}
|
|
||||||
</div>
|
|
||||||
</div>
|
</div>
|
||||||
);
|
);
|
||||||
};
|
};
|
||||||
|
|||||||
@@ -24,9 +24,7 @@ export default function GPUWidget({ gpu }: GPUWidgetProps) {
|
|||||||
<div className="bg-gray-800 px-4 py-3 flex items-center justify-between">
|
<div className="bg-gray-800 px-4 py-3 flex items-center justify-between">
|
||||||
<div className="flex items-center space-x-2">
|
<div className="flex items-center space-x-2">
|
||||||
<h2 className="font-semibold text-gray-100">{gpu.name}</h2>
|
<h2 className="font-semibold text-gray-100">{gpu.name}</h2>
|
||||||
<span className="px-2 py-0.5 bg-gray-700 rounded-full text-xs text-gray-300">
|
<span className="px-2 py-0.5 bg-gray-700 rounded-full text-xs text-gray-300">#{gpu.index}</span>
|
||||||
#{gpu.index}
|
|
||||||
</span>
|
|
||||||
</div>
|
</div>
|
||||||
</div>
|
</div>
|
||||||
|
|
||||||
@@ -38,18 +36,14 @@ export default function GPUWidget({ gpu }: GPUWidgetProps) {
|
|||||||
<Thermometer className={`w-4 h-4 ${getTemperatureColor(gpu.temperature)}`} />
|
<Thermometer className={`w-4 h-4 ${getTemperatureColor(gpu.temperature)}`} />
|
||||||
<div>
|
<div>
|
||||||
<p className="text-xs text-gray-400">Temperature</p>
|
<p className="text-xs text-gray-400">Temperature</p>
|
||||||
<p className={`text-sm font-medium ${getTemperatureColor(gpu.temperature)}`}>
|
<p className={`text-sm font-medium ${getTemperatureColor(gpu.temperature)}`}>{gpu.temperature}°C</p>
|
||||||
{gpu.temperature}°C
|
|
||||||
</p>
|
|
||||||
</div>
|
</div>
|
||||||
</div>
|
</div>
|
||||||
<div className="flex items-center space-x-2">
|
<div className="flex items-center space-x-2">
|
||||||
<Fan className="w-4 h-4 text-blue-400" />
|
<Fan className="w-4 h-4 text-blue-400" />
|
||||||
<div>
|
<div>
|
||||||
<p className="text-xs text-gray-400">Fan Speed</p>
|
<p className="text-xs text-gray-400">Fan Speed</p>
|
||||||
<p className="text-sm font-medium text-blue-400">
|
<p className="text-sm font-medium text-blue-400">{gpu.fan.speed}%</p>
|
||||||
{gpu.fan.speed}%
|
|
||||||
</p>
|
|
||||||
</div>
|
</div>
|
||||||
</div>
|
</div>
|
||||||
</div>
|
</div>
|
||||||
@@ -99,7 +93,7 @@ export default function GPUWidget({ gpu }: GPUWidgetProps) {
|
|||||||
<p className="text-xs text-gray-400">Power Draw</p>
|
<p className="text-xs text-gray-400">Power Draw</p>
|
||||||
<p className="text-sm text-gray-200">
|
<p className="text-sm text-gray-200">
|
||||||
{gpu.power.draw?.toFixed(1)}W
|
{gpu.power.draw?.toFixed(1)}W
|
||||||
<span className="text-gray-400 text-xs"> / {gpu.power.limit?.toFixed(1) || " ? "}W</span>
|
<span className="text-gray-400 text-xs"> / {gpu.power.limit?.toFixed(1) || ' ? '}W</span>
|
||||||
</p>
|
</p>
|
||||||
</div>
|
</div>
|
||||||
</div>
|
</div>
|
||||||
@@ -107,4 +101,4 @@ export default function GPUWidget({ gpu }: GPUWidgetProps) {
|
|||||||
</div>
|
</div>
|
||||||
</div>
|
</div>
|
||||||
);
|
);
|
||||||
}
|
}
|
||||||
|
|||||||
@@ -45,7 +45,9 @@ export default function JobOverview({ job }: JobOverviewProps) {
|
|||||||
{/* Job Information Panel */}
|
{/* Job Information Panel */}
|
||||||
<div className="col-span-2 bg-gray-900 rounded-xl shadow-lg overflow-hidden border border-gray-800">
|
<div className="col-span-2 bg-gray-900 rounded-xl shadow-lg overflow-hidden border border-gray-800">
|
||||||
<div className="bg-gray-800 px-4 py-3 flex items-center justify-between">
|
<div className="bg-gray-800 px-4 py-3 flex items-center justify-between">
|
||||||
<h2 className="text-gray-100"><Info className="w-5 h-5 mr-2 -mt-1 text-amber-400 inline-block" /> {job.info}</h2>
|
<h2 className="text-gray-100">
|
||||||
|
<Info className="w-5 h-5 mr-2 -mt-1 text-amber-400 inline-block" /> {job.info}
|
||||||
|
</h2>
|
||||||
<span className={`px-3 py-1 rounded-full text-sm ${getStatusColor(job.status)}`}>{job.status}</span>
|
<span className={`px-3 py-1 rounded-full text-sm ${getStatusColor(job.status)}`}>{job.status}</span>
|
||||||
</div>
|
</div>
|
||||||
|
|
||||||
@@ -85,7 +87,7 @@ export default function JobOverview({ job }: JobOverviewProps) {
|
|||||||
<Gauge className="w-5 h-5 text-green-400" />
|
<Gauge className="w-5 h-5 text-green-400" />
|
||||||
<div>
|
<div>
|
||||||
<p className="text-xs text-gray-400">Speed</p>
|
<p className="text-xs text-gray-400">Speed</p>
|
||||||
<p className="text-sm font-medium text-gray-200">{job.speed_string == "" ? "?" : job.speed_string}</p>
|
<p className="text-sm font-medium text-gray-200">{job.speed_string == '' ? '?' : job.speed_string}</p>
|
||||||
</div>
|
</div>
|
||||||
</div>
|
</div>
|
||||||
</div>
|
</div>
|
||||||
|
|||||||
@@ -1,6 +1,6 @@
|
|||||||
export default function Loading() {
|
export default function Loading() {
|
||||||
return (
|
return (
|
||||||
<div className="flex justify-center items-center h-64">
|
<div className="flex justify-center items-center h-42">
|
||||||
<div className="animate-spin rounded-full h-12 w-12 border-t-2 border-b-2 border-blue-500"></div>
|
<div className="animate-spin rounded-full h-12 w-12 border-t-2 border-b-2 border-blue-500"></div>
|
||||||
</div>
|
</div>
|
||||||
);
|
);
|
||||||
|
|||||||
@@ -11,7 +11,14 @@ interface SampleImageCardProps {
|
|||||||
onDelete?: () => void;
|
onDelete?: () => void;
|
||||||
}
|
}
|
||||||
|
|
||||||
const SampleImageCard: React.FC<SampleImageCardProps> = ({ imageUrl, alt, numSamples, sampleImages, children, className = '' }) => {
|
const SampleImageCard: React.FC<SampleImageCardProps> = ({
|
||||||
|
imageUrl,
|
||||||
|
alt,
|
||||||
|
numSamples,
|
||||||
|
sampleImages,
|
||||||
|
children,
|
||||||
|
className = '',
|
||||||
|
}) => {
|
||||||
const cardRef = useRef<HTMLDivElement>(null);
|
const cardRef = useRef<HTMLDivElement>(null);
|
||||||
const [isVisible, setIsVisible] = useState<boolean>(false);
|
const [isVisible, setIsVisible] = useState<boolean>(false);
|
||||||
const [loaded, setLoaded] = useState<boolean>(false);
|
const [loaded, setLoaded] = useState<boolean>(false);
|
||||||
|
|||||||
@@ -1,9 +1,10 @@
|
|||||||
import Link from 'next/link';
|
import Link from 'next/link';
|
||||||
import { Home, Settings, BrainCircuit, Images } from 'lucide-react';
|
import { Home, Settings, BrainCircuit, Images, Plus } from 'lucide-react';
|
||||||
|
|
||||||
const Sidebar = () => {
|
const Sidebar = () => {
|
||||||
const navigation = [
|
const navigation = [
|
||||||
{ name: 'Dashboard', href: '/dashboard', icon: Home },
|
{ name: 'Dashboard', href: '/dashboard', icon: Home },
|
||||||
|
{ name: 'New Job', href: '/jobs/new', icon: Plus },
|
||||||
{ name: 'Training Jobs', href: '/jobs', icon: BrainCircuit },
|
{ name: 'Training Jobs', href: '/jobs', icon: BrainCircuit },
|
||||||
{ name: 'Datasets', href: '/datasets', icon: Images },
|
{ name: 'Datasets', href: '/datasets', icon: Images },
|
||||||
{ name: 'Settings', href: '/settings', icon: Settings },
|
{ name: 'Settings', href: '/settings', icon: Settings },
|
||||||
@@ -33,7 +34,7 @@ const Sidebar = () => {
|
|||||||
</ul>
|
</ul>
|
||||||
</nav>
|
</nav>
|
||||||
<a href="https://patreon.com/ostris" target="_blank" rel="noreferrer" className="flex items-center space-x-2 p-4">
|
<a href="https://patreon.com/ostris" target="_blank" rel="noreferrer" className="flex items-center space-x-2 p-4">
|
||||||
<div className='min-w-[26px] min-h-[26px]'>
|
<div className="min-w-[26px] min-h-[26px]">
|
||||||
<svg
|
<svg
|
||||||
viewBox="0 0 512 512"
|
viewBox="0 0 512 512"
|
||||||
xmlns="http://www.w3.org/2000/svg"
|
xmlns="http://www.w3.org/2000/svg"
|
||||||
|
|||||||
@@ -49,7 +49,7 @@ export interface NumberInputProps extends InputProps {
|
|||||||
|
|
||||||
export const NumberInput = (props: NumberInputProps) => {
|
export const NumberInput = (props: NumberInputProps) => {
|
||||||
const { label, value, onChange, placeholder, required, min, max } = props;
|
const { label, value, onChange, placeholder, required, min, max } = props;
|
||||||
|
|
||||||
// Add controlled internal state to properly handle partial inputs
|
// Add controlled internal state to properly handle partial inputs
|
||||||
const [inputValue, setInputValue] = React.useState<string | number>(value ?? '');
|
const [inputValue, setInputValue] = React.useState<string | number>(value ?? '');
|
||||||
|
|
||||||
@@ -66,7 +66,7 @@ export const NumberInput = (props: NumberInputProps) => {
|
|||||||
value={inputValue}
|
value={inputValue}
|
||||||
onChange={e => {
|
onChange={e => {
|
||||||
const rawValue = e.target.value;
|
const rawValue = e.target.value;
|
||||||
|
|
||||||
// Update the input display with the raw value
|
// Update the input display with the raw value
|
||||||
setInputValue(rawValue);
|
setInputValue(rawValue);
|
||||||
|
|
||||||
@@ -81,7 +81,7 @@ export const NumberInput = (props: NumberInputProps) => {
|
|||||||
// Only apply constraints and call onChange when we have a valid number
|
// Only apply constraints and call onChange when we have a valid number
|
||||||
if (!isNaN(numValue)) {
|
if (!isNaN(numValue)) {
|
||||||
let constrainedValue = numValue;
|
let constrainedValue = numValue;
|
||||||
|
|
||||||
// Apply min/max constraints if they exist
|
// Apply min/max constraints if they exist
|
||||||
if (min !== undefined && constrainedValue < min) {
|
if (min !== undefined && constrainedValue < min) {
|
||||||
constrainedValue = min;
|
constrainedValue = min;
|
||||||
@@ -89,7 +89,7 @@ export const NumberInput = (props: NumberInputProps) => {
|
|||||||
if (max !== undefined && constrainedValue > max) {
|
if (max !== undefined && constrainedValue > max) {
|
||||||
constrainedValue = max;
|
constrainedValue = max;
|
||||||
}
|
}
|
||||||
|
|
||||||
onChange(constrainedValue);
|
onChange(constrainedValue);
|
||||||
}
|
}
|
||||||
}}
|
}}
|
||||||
@@ -152,14 +152,14 @@ export const Checkbox = (props: CheckboxProps) => {
|
|||||||
className={classNames(
|
className={classNames(
|
||||||
'relative inline-flex h-6 w-11 flex-shrink-0 cursor-pointer rounded-full border-2 border-transparent transition-colors duration-200 ease-in-out focus:outline-none focus:ring-2 focus:ring-blue-600 focus:ring-offset-2',
|
'relative inline-flex h-6 w-11 flex-shrink-0 cursor-pointer rounded-full border-2 border-transparent transition-colors duration-200 ease-in-out focus:outline-none focus:ring-2 focus:ring-blue-600 focus:ring-offset-2',
|
||||||
checked ? 'bg-blue-600' : 'bg-gray-700',
|
checked ? 'bg-blue-600' : 'bg-gray-700',
|
||||||
disabled ? 'opacity-50 cursor-not-allowed' : 'hover:bg-opacity-80'
|
disabled ? 'opacity-50 cursor-not-allowed' : 'hover:bg-opacity-80',
|
||||||
)}
|
)}
|
||||||
>
|
>
|
||||||
<span className="sr-only">Toggle {label}</span>
|
<span className="sr-only">Toggle {label}</span>
|
||||||
<span
|
<span
|
||||||
className={classNames(
|
className={classNames(
|
||||||
'pointer-events-none inline-block h-5 w-5 transform rounded-full bg-white shadow ring-0 transition duration-200 ease-in-out',
|
'pointer-events-none inline-block h-5 w-5 transform rounded-full bg-white shadow ring-0 transition duration-200 ease-in-out',
|
||||||
checked ? 'translate-x-5' : 'translate-x-0'
|
checked ? 'translate-x-5' : 'translate-x-0',
|
||||||
)}
|
)}
|
||||||
/>
|
/>
|
||||||
</button>
|
</button>
|
||||||
@@ -168,7 +168,7 @@ export const Checkbox = (props: CheckboxProps) => {
|
|||||||
htmlFor={id}
|
htmlFor={id}
|
||||||
className={classNames(
|
className={classNames(
|
||||||
'text-sm font-medium cursor-pointer select-none',
|
'text-sm font-medium cursor-pointer select-none',
|
||||||
disabled ? 'text-gray-500' : 'text-gray-300'
|
disabled ? 'text-gray-500' : 'text-gray-300',
|
||||||
)}
|
)}
|
||||||
>
|
>
|
||||||
{label}
|
{label}
|
||||||
|
|||||||
@@ -1,6 +1,7 @@
|
|||||||
'use client';
|
'use client';
|
||||||
|
|
||||||
import { useEffect, useState } from 'react';
|
import { useEffect, useState } from 'react';
|
||||||
|
import { apiClient } from '@/utils/api';
|
||||||
|
|
||||||
export default function useDatasetList() {
|
export default function useDatasetList() {
|
||||||
const [datasets, setDatasets] = useState<string[]>([]);
|
const [datasets, setDatasets] = useState<string[]>([]);
|
||||||
@@ -8,8 +9,9 @@ export default function useDatasetList() {
|
|||||||
|
|
||||||
const refreshDatasets = () => {
|
const refreshDatasets = () => {
|
||||||
setStatus('loading');
|
setStatus('loading');
|
||||||
fetch('/api/datasets/list')
|
apiClient
|
||||||
.then(res => res.json())
|
.get('/api/datasets/list')
|
||||||
|
.then(res => res.data)
|
||||||
.then(data => {
|
.then(data => {
|
||||||
console.log('Datasets:', data);
|
console.log('Datasets:', data);
|
||||||
// sort
|
// sort
|
||||||
|
|||||||
@@ -1,6 +1,7 @@
|
|||||||
'use client';
|
'use client';
|
||||||
|
|
||||||
import { useEffect, useState, useRef } from 'react';
|
import { useEffect, useState, useRef } from 'react';
|
||||||
|
import { apiClient } from '@/utils/api';
|
||||||
|
|
||||||
interface FileObject {
|
interface FileObject {
|
||||||
path: string;
|
path: string;
|
||||||
@@ -18,8 +19,9 @@ export default function useFilesList(jobID: string, reloadInterval: null | numbe
|
|||||||
loadStatus = 'refreshing';
|
loadStatus = 'refreshing';
|
||||||
}
|
}
|
||||||
setStatus(loadStatus);
|
setStatus(loadStatus);
|
||||||
fetch(`/api/jobs/${jobID}/files`)
|
apiClient
|
||||||
.then(res => res.json())
|
.get(`/api/jobs/${jobID}/files`)
|
||||||
|
.then(res => res.data)
|
||||||
.then(data => {
|
.then(data => {
|
||||||
console.log('Fetched files:', data);
|
console.log('Fetched files:', data);
|
||||||
if (data.files) {
|
if (data.files) {
|
||||||
|
|||||||
@@ -2,6 +2,7 @@
|
|||||||
|
|
||||||
import { GPUApiResponse, GpuInfo } from '@/types';
|
import { GPUApiResponse, GpuInfo } from '@/types';
|
||||||
import { useEffect, useState } from 'react';
|
import { useEffect, useState } from 'react';
|
||||||
|
import { apiClient } from '@/utils/api';
|
||||||
|
|
||||||
export default function useGPUInfo(gpuIds: null | number[] = null, reloadInterval: null | number = null) {
|
export default function useGPUInfo(gpuIds: null | number[] = null, reloadInterval: null | number = null) {
|
||||||
const [gpuList, setGpuList] = useState<GpuInfo[]>([]);
|
const [gpuList, setGpuList] = useState<GpuInfo[]>([]);
|
||||||
@@ -11,18 +12,11 @@ export default function useGPUInfo(gpuIds: null | number[] = null, reloadInterva
|
|||||||
const fetchGpuInfo = async () => {
|
const fetchGpuInfo = async () => {
|
||||||
setStatus('loading');
|
setStatus('loading');
|
||||||
try {
|
try {
|
||||||
const response = await fetch('/api/gpu');
|
const data: GPUApiResponse = await apiClient.get('/api/gpu').then(res => res.data);
|
||||||
|
|
||||||
if (!response.ok) {
|
|
||||||
throw new Error(`HTTP error! Status: ${response.status}`);
|
|
||||||
}
|
|
||||||
|
|
||||||
const data: GPUApiResponse = await response.json();
|
|
||||||
let gpus = data.gpus.sort((a, b) => a.index - b.index);
|
let gpus = data.gpus.sort((a, b) => a.index - b.index);
|
||||||
if (gpuIds) {
|
if (gpuIds) {
|
||||||
gpus = gpus.filter(gpu => gpuIds.includes(gpu.index));
|
gpus = gpus.filter(gpu => gpuIds.includes(gpu.index));
|
||||||
}
|
}
|
||||||
|
|
||||||
setGpuList(gpus);
|
setGpuList(gpus);
|
||||||
setStatus('success');
|
setStatus('success');
|
||||||
} catch (err) {
|
} catch (err) {
|
||||||
@@ -51,4 +45,4 @@ export default function useGPUInfo(gpuIds: null | number[] = null, reloadInterva
|
|||||||
}, [gpuIds, reloadInterval]); // Added dependencies
|
}, [gpuIds, reloadInterval]); // Added dependencies
|
||||||
|
|
||||||
return { gpuList, setGpuList, isGPUInfoLoaded, status, refreshGpuInfo: fetchGpuInfo };
|
return { gpuList, setGpuList, isGPUInfoLoaded, status, refreshGpuInfo: fetchGpuInfo };
|
||||||
}
|
}
|
||||||
|
|||||||
@@ -2,6 +2,7 @@
|
|||||||
|
|
||||||
import { useEffect, useState } from 'react';
|
import { useEffect, useState } from 'react';
|
||||||
import { Job } from '@prisma/client';
|
import { Job } from '@prisma/client';
|
||||||
|
import { apiClient } from '@/utils/api';
|
||||||
|
|
||||||
export default function useJob(jobID: string, reloadInterval: null | number = null) {
|
export default function useJob(jobID: string, reloadInterval: null | number = null) {
|
||||||
const [job, setJob] = useState<Job | null>(null);
|
const [job, setJob] = useState<Job | null>(null);
|
||||||
@@ -9,8 +10,9 @@ export default function useJob(jobID: string, reloadInterval: null | number = nu
|
|||||||
|
|
||||||
const refreshJob = () => {
|
const refreshJob = () => {
|
||||||
setStatus('loading');
|
setStatus('loading');
|
||||||
fetch(`/api/jobs?id=${jobID}`)
|
apiClient
|
||||||
.then(res => res.json())
|
.get(`/api/jobs?id=${jobID}`)
|
||||||
|
.then(res => res.data)
|
||||||
.then(data => {
|
.then(data => {
|
||||||
console.log('Job:', data);
|
console.log('Job:', data);
|
||||||
setJob(data);
|
setJob(data);
|
||||||
@@ -32,7 +34,7 @@ export default function useJob(jobID: string, reloadInterval: null | number = nu
|
|||||||
|
|
||||||
return () => {
|
return () => {
|
||||||
clearInterval(interval);
|
clearInterval(interval);
|
||||||
}
|
};
|
||||||
}
|
}
|
||||||
}, [jobID]);
|
}, [jobID]);
|
||||||
|
|
||||||
|
|||||||
@@ -2,6 +2,7 @@
|
|||||||
|
|
||||||
import { useEffect, useState } from 'react';
|
import { useEffect, useState } from 'react';
|
||||||
import { Job } from '@prisma/client';
|
import { Job } from '@prisma/client';
|
||||||
|
import { apiClient } from '@/utils/api';
|
||||||
|
|
||||||
export default function useJobsList(onlyActive = false) {
|
export default function useJobsList(onlyActive = false) {
|
||||||
const [jobs, setJobs] = useState<Job[]>([]);
|
const [jobs, setJobs] = useState<Job[]>([]);
|
||||||
@@ -9,8 +10,9 @@ export default function useJobsList(onlyActive = false) {
|
|||||||
|
|
||||||
const refreshJobs = () => {
|
const refreshJobs = () => {
|
||||||
setStatus('loading');
|
setStatus('loading');
|
||||||
fetch('/api/jobs')
|
apiClient
|
||||||
.then(res => res.json())
|
.get('/api/jobs')
|
||||||
|
.then(res => res.data)
|
||||||
.then(data => {
|
.then(data => {
|
||||||
console.log('Jobs:', data);
|
console.log('Jobs:', data);
|
||||||
if (data.error) {
|
if (data.error) {
|
||||||
|
|||||||
@@ -1,6 +1,7 @@
|
|||||||
'use client';
|
'use client';
|
||||||
|
|
||||||
import { useEffect, useState } from 'react';
|
import { useEffect, useState } from 'react';
|
||||||
|
import { apiClient } from '@/utils/api';
|
||||||
|
|
||||||
export default function useSampleImages(jobID: string, reloadInterval: null | number = null) {
|
export default function useSampleImages(jobID: string, reloadInterval: null | number = null) {
|
||||||
const [sampleImages, setSampleImages] = useState<string[]>([]);
|
const [sampleImages, setSampleImages] = useState<string[]>([]);
|
||||||
@@ -8,9 +9,11 @@ export default function useSampleImages(jobID: string, reloadInterval: null | nu
|
|||||||
|
|
||||||
const refreshSampleImages = () => {
|
const refreshSampleImages = () => {
|
||||||
setStatus('loading');
|
setStatus('loading');
|
||||||
fetch(`/api/jobs/${jobID}/samples`)
|
apiClient
|
||||||
.then(res => res.json())
|
.get(`/api/jobs/${jobID}/samples`)
|
||||||
|
.then(res => res.data)
|
||||||
.then(data => {
|
.then(data => {
|
||||||
|
console.log('Fetched sample images:', data);
|
||||||
if (data.samples) {
|
if (data.samples) {
|
||||||
setSampleImages(data.samples);
|
setSampleImages(data.samples);
|
||||||
}
|
}
|
||||||
|
|||||||
@@ -1,6 +1,7 @@
|
|||||||
'use client';
|
'use client';
|
||||||
|
|
||||||
import { useEffect, useState } from 'react';
|
import { useEffect, useState } from 'react';
|
||||||
|
import { apiClient } from '@/utils/api';
|
||||||
|
|
||||||
export interface Settings {
|
export interface Settings {
|
||||||
HF_TOKEN: string;
|
HF_TOKEN: string;
|
||||||
@@ -16,10 +17,11 @@ export default function useSettings() {
|
|||||||
});
|
});
|
||||||
const [isSettingsLoaded, setIsLoaded] = useState(false);
|
const [isSettingsLoaded, setIsLoaded] = useState(false);
|
||||||
useEffect(() => {
|
useEffect(() => {
|
||||||
// Fetch current settings
|
apiClient
|
||||||
fetch('/api/settings')
|
.get('/api/settings')
|
||||||
.then(res => res.json())
|
.then(res => res.data)
|
||||||
.then(data => {
|
.then(data => {
|
||||||
|
console.log('Settings:', data);
|
||||||
setSettings({
|
setSettings({
|
||||||
HF_TOKEN: data.HF_TOKEN || '',
|
HF_TOKEN: data.HF_TOKEN || '',
|
||||||
TRAINING_FOLDER: data.TRAINING_FOLDER || '',
|
TRAINING_FOLDER: data.TRAINING_FOLDER || '',
|
||||||
|
|||||||
49
ui/src/middleware.ts
Normal file
49
ui/src/middleware.ts
Normal file
@@ -0,0 +1,49 @@
|
|||||||
|
// middleware.ts (at the root of your project)
|
||||||
|
import { NextResponse } from 'next/server';
|
||||||
|
import type { NextRequest } from 'next/server';
|
||||||
|
|
||||||
|
// if route starts with these, approve
|
||||||
|
const publicRoutes = ['/api/img/', '/api/files/'];
|
||||||
|
|
||||||
|
export function middleware(request: NextRequest) {
|
||||||
|
// check env var for AI_TOOLKIT_AUTH, if not set, approve all requests
|
||||||
|
// if it is set make sure bearer token matches
|
||||||
|
const tokenToUse = process.env.AI_TOOLKIT_AUTH || null;
|
||||||
|
if (!tokenToUse) {
|
||||||
|
return NextResponse.next();
|
||||||
|
}
|
||||||
|
|
||||||
|
// Get the token from the headers
|
||||||
|
const token = request.headers.get('Authorization')?.split(' ')[1];
|
||||||
|
|
||||||
|
// allow public routes to pass through
|
||||||
|
if (publicRoutes.some(route => request.nextUrl.pathname.startsWith(route))) {
|
||||||
|
return NextResponse.next();
|
||||||
|
}
|
||||||
|
|
||||||
|
// Check if the route should be protected
|
||||||
|
// This will apply to all API routes that start with /api/
|
||||||
|
if (request.nextUrl.pathname.startsWith('/api/')) {
|
||||||
|
if (!token || token !== tokenToUse) {
|
||||||
|
// Return a JSON response with 401 Unauthorized
|
||||||
|
return new NextResponse(JSON.stringify({ error: 'Unauthorized' }), {
|
||||||
|
status: 401,
|
||||||
|
headers: { 'Content-Type': 'application/json' },
|
||||||
|
});
|
||||||
|
}
|
||||||
|
|
||||||
|
// For authorized users, continue
|
||||||
|
return NextResponse.next();
|
||||||
|
}
|
||||||
|
|
||||||
|
// For non-API routes, just continue
|
||||||
|
return NextResponse.next();
|
||||||
|
}
|
||||||
|
|
||||||
|
// Configure which paths this middleware will run on
|
||||||
|
export const config = {
|
||||||
|
matcher: [
|
||||||
|
// Apply to all API routes
|
||||||
|
'/api/:path*',
|
||||||
|
],
|
||||||
|
};
|
||||||
31
ui/src/utils/api.ts
Normal file
31
ui/src/utils/api.ts
Normal file
@@ -0,0 +1,31 @@
|
|||||||
|
import axios from 'axios';
|
||||||
|
import { createGlobalState } from 'react-global-hooks';
|
||||||
|
|
||||||
|
export const isAuthorizedState = createGlobalState(false);
|
||||||
|
|
||||||
|
export const apiClient = axios.create();
|
||||||
|
|
||||||
|
// Add a request interceptor to add token from localStorage
|
||||||
|
apiClient.interceptors.request.use(config => {
|
||||||
|
const token = localStorage.getItem('AI_TOOLKIT_AUTH');
|
||||||
|
if (token) {
|
||||||
|
config.headers['Authorization'] = `Bearer ${token}`;
|
||||||
|
}
|
||||||
|
return config;
|
||||||
|
});
|
||||||
|
|
||||||
|
// Add a response interceptor to handle 401 errors
|
||||||
|
apiClient.interceptors.response.use(
|
||||||
|
response => response, // Return successful responses as-is
|
||||||
|
error => {
|
||||||
|
// Check if the error is a 401 Unauthorized
|
||||||
|
if (error.response && error.response.status === 401) {
|
||||||
|
// Clear the auth token from localStorage
|
||||||
|
localStorage.removeItem('AI_TOOLKIT_AUTH');
|
||||||
|
isAuthorizedState.set(false);
|
||||||
|
}
|
||||||
|
|
||||||
|
// Reject the promise with the error so calling code can still catch it
|
||||||
|
return Promise.reject(error);
|
||||||
|
},
|
||||||
|
);
|
||||||
@@ -2,3 +2,4 @@ export const objectCopy = <T>(obj: T): T => {
|
|||||||
return JSON.parse(JSON.stringify(obj)) as T;
|
return JSON.parse(JSON.stringify(obj)) as T;
|
||||||
};
|
};
|
||||||
|
|
||||||
|
export const wait = (ms: number) => new Promise(resolve => setTimeout(resolve, ms));
|
||||||
|
|||||||
@@ -79,10 +79,10 @@ export function useNestedState<T>(initialState: T): [T, (value: any, path?: stri
|
|||||||
const setValue = React.useCallback((value: any, path?: string) => {
|
const setValue = React.useCallback((value: any, path?: string) => {
|
||||||
if (path === undefined) {
|
if (path === undefined) {
|
||||||
setState(value);
|
setState(value);
|
||||||
return
|
return;
|
||||||
}
|
}
|
||||||
setState(prevState => setNestedValue(prevState, value, path));
|
setState(prevState => setNestedValue(prevState, value, path));
|
||||||
}, []);
|
}, []);
|
||||||
|
|
||||||
return [state, setValue];
|
return [state, setValue];
|
||||||
}
|
}
|
||||||
|
|||||||
@@ -1,10 +1,12 @@
|
|||||||
import { JobConfig } from '@/types';
|
import { JobConfig } from '@/types';
|
||||||
import { Job } from '@prisma/client';
|
import { Job } from '@prisma/client';
|
||||||
|
import { apiClient } from '@/utils/api';
|
||||||
|
|
||||||
export const startJob = (jobID: string) => {
|
export const startJob = (jobID: string) => {
|
||||||
return new Promise<void>((resolve, reject) => {
|
return new Promise<void>((resolve, reject) => {
|
||||||
fetch(`/api/jobs/${jobID}/start`)
|
apiClient
|
||||||
.then(res => res.json())
|
.get(`/api/jobs/${jobID}/start`)
|
||||||
|
.then(res => res.data)
|
||||||
.then(data => {
|
.then(data => {
|
||||||
console.log('Job started:', data);
|
console.log('Job started:', data);
|
||||||
resolve();
|
resolve();
|
||||||
@@ -18,8 +20,9 @@ export const startJob = (jobID: string) => {
|
|||||||
|
|
||||||
export const stopJob = (jobID: string) => {
|
export const stopJob = (jobID: string) => {
|
||||||
return new Promise<void>((resolve, reject) => {
|
return new Promise<void>((resolve, reject) => {
|
||||||
fetch(`/api/jobs/${jobID}/stop`)
|
apiClient
|
||||||
.then(res => res.json())
|
.get(`/api/jobs/${jobID}/stop`)
|
||||||
|
.then(res => res.data)
|
||||||
.then(data => {
|
.then(data => {
|
||||||
console.log('Job stopped:', data);
|
console.log('Job stopped:', data);
|
||||||
resolve();
|
resolve();
|
||||||
@@ -33,8 +36,9 @@ export const stopJob = (jobID: string) => {
|
|||||||
|
|
||||||
export const deleteJob = (jobID: string) => {
|
export const deleteJob = (jobID: string) => {
|
||||||
return new Promise<void>((resolve, reject) => {
|
return new Promise<void>((resolve, reject) => {
|
||||||
fetch(`/api/jobs/${jobID}/delete`)
|
apiClient
|
||||||
.then(res => res.json())
|
.get(`/api/jobs/${jobID}/delete`)
|
||||||
|
.then(res => res.data)
|
||||||
.then(data => {
|
.then(data => {
|
||||||
console.log('Job deleted:', data);
|
console.log('Job deleted:', data);
|
||||||
resolve();
|
resolve();
|
||||||
@@ -67,9 +71,9 @@ export const getAvaliableJobActions = (job: Job) => {
|
|||||||
export const getNumberOfSamples = (job: Job) => {
|
export const getNumberOfSamples = (job: Job) => {
|
||||||
const jobConfig = getJobConfig(job);
|
const jobConfig = getJobConfig(job);
|
||||||
return jobConfig.config.process[0].sample?.prompts?.length || 0;
|
return jobConfig.config.process[0].sample?.prompts?.length || 0;
|
||||||
}
|
};
|
||||||
|
|
||||||
export const getTotalSteps = (job: Job) => {
|
export const getTotalSteps = (job: Job) => {
|
||||||
const jobConfig = getJobConfig(job);
|
const jobConfig = getJobConfig(job);
|
||||||
return jobConfig.config.process[0].train.steps;
|
return jobConfig.config.process[0].train.steps;
|
||||||
}
|
};
|
||||||
|
|||||||
@@ -1,26 +1,26 @@
|
|||||||
import type { Config } from "tailwindcss";
|
import type { Config } from 'tailwindcss';
|
||||||
|
|
||||||
const config: Config = {
|
const config: Config = {
|
||||||
content: [
|
content: [
|
||||||
"./src/pages/**/*.{js,ts,jsx,tsx,mdx}",
|
'./src/pages/**/*.{js,ts,jsx,tsx,mdx}',
|
||||||
"./src/components/**/*.{js,ts,jsx,tsx,mdx}",
|
'./src/components/**/*.{js,ts,jsx,tsx,mdx}',
|
||||||
"./src/app/**/*.{js,ts,jsx,tsx,mdx}",
|
'./src/app/**/*.{js,ts,jsx,tsx,mdx}',
|
||||||
],
|
],
|
||||||
darkMode: "class",
|
darkMode: 'class',
|
||||||
theme: {
|
theme: {
|
||||||
extend: {
|
extend: {
|
||||||
colors: {
|
colors: {
|
||||||
gray: {
|
gray: {
|
||||||
950: "#0a0a0a",
|
950: '#0a0a0a',
|
||||||
900: "#171717",
|
900: '#171717',
|
||||||
800: "#262626",
|
800: '#262626',
|
||||||
700: "#404040",
|
700: '#404040',
|
||||||
600: "#525252",
|
600: '#525252',
|
||||||
500: "#737373",
|
500: '#737373',
|
||||||
400: "#a3a3a3",
|
400: '#a3a3a3',
|
||||||
300: "#d4d4d4",
|
300: '#d4d4d4',
|
||||||
200: "#e5e5e5",
|
200: '#e5e5e5',
|
||||||
100: "#f5f5f5",
|
100: '#f5f5f5',
|
||||||
},
|
},
|
||||||
},
|
},
|
||||||
},
|
},
|
||||||
@@ -28,4 +28,4 @@ const config: Config = {
|
|||||||
plugins: [],
|
plugins: [],
|
||||||
};
|
};
|
||||||
|
|
||||||
export default config;
|
export default config;
|
||||||
|
|||||||
Reference in New Issue
Block a user