From 3a6b24f4c8cac9c0bf49a0f4e216cf68b3416850 Mon Sep 17 00:00:00 2001 From: Jaret Burkett Date: Thu, 20 Mar 2025 08:07:09 -0600 Subject: [PATCH] Added a way to secure the UI. Plus various bug fixes and quality of life updates --- README.md | 43 +++--- extensions_built_in/sd_trainer/UITrainer.py | 5 +- ui/package.json | 1 + ui/src/app/api/auth/route.ts | 6 + ui/src/app/api/datasets/create/route.tsx | 2 +- ui/src/app/api/datasets/listImages/route.ts | 28 ++-- ui/src/app/api/files/[...filePath]/route.ts | 21 +-- ui/src/app/api/gpu/route.ts | 9 +- ui/src/app/api/img/caption/route.ts | 1 - ui/src/app/api/jobs/[jobID]/start/route.ts | 4 +- ui/src/app/api/jobs/route.ts | 2 - ui/src/app/api/settings/route.ts | 2 +- ui/src/app/datasets/[datasetName]/page.tsx | 15 +- ui/src/app/datasets/page.tsx | 24 +-- ui/src/app/jobs/new/AdvancedJob.tsx | 6 +- ui/src/app/jobs/new/jobConfig.ts | 11 +- ui/src/app/jobs/new/options.ts | 2 +- ui/src/app/jobs/new/page.tsx | 59 ++++--- ui/src/app/jobs/page.tsx | 5 +- ui/src/app/layout.tsx | 19 ++- ui/src/app/page.tsx | 2 +- ui/src/app/settings/page.tsx | 29 ++-- ui/src/components/AddImagesModal.tsx | 86 +++++------ ui/src/components/AuthWrapper.tsx | 163 ++++++++++++++++++++ ui/src/components/Card.tsx | 2 +- ui/src/components/DatasetImageCard.tsx | 55 ++++--- ui/src/components/FilesWidget.tsx | 14 +- ui/src/components/GPUMonitor.tsx | 138 +++++++++-------- ui/src/components/GPUWidget.tsx | 16 +- ui/src/components/JobOverview.tsx | 6 +- ui/src/components/Loading.tsx | 2 +- ui/src/components/SampleImageCard.tsx | 9 +- ui/src/components/Sidebar.tsx | 5 +- ui/src/components/formInputs.tsx | 14 +- ui/src/hooks/useDatasetList.tsx | 6 +- ui/src/hooks/useFilesList.tsx | 6 +- ui/src/hooks/useGPUInfo.tsx | 12 +- ui/src/hooks/useJob.tsx | 8 +- ui/src/hooks/useJobsList.tsx | 6 +- ui/src/hooks/useSampleImages.tsx | 7 +- ui/src/hooks/useSettings.tsx | 8 +- ui/src/middleware.ts | 49 ++++++ ui/src/utils/api.ts | 31 ++++ ui/src/utils/basic.ts | 1 + ui/src/utils/hooks.tsx | 4 +- ui/src/utils/jobs.ts | 20 ++- ui/tailwind.config.ts | 32 ++-- 47 files changed, 618 insertions(+), 378 deletions(-) create mode 100644 ui/src/app/api/auth/route.ts create mode 100644 ui/src/components/AuthWrapper.tsx create mode 100644 ui/src/middleware.ts create mode 100644 ui/src/utils/api.ts diff --git a/README.md b/README.md index 7098d5cc..f92d3898 100644 --- a/README.md +++ b/README.md @@ -8,7 +8,7 @@ If you enjoy my work, or use it for commercial purposes, please consider sponsor Thank you to all my current supporters! -_Last updated: 2025-03-13_ +_Last updated: 2025-03-20_ ### GitHub Sponsors @@ -16,7 +16,7 @@ _Last updated: 2025-03-13_ ### Patreon Supporters -Abraham Irawan Al H Armin Behjati Bharat Prabhakar clement Delangue Cosmosis David Garrido Doron Adler Eli Slugworth EmmanuelMr18 Gili Ben Shahar HestoySeghuro . Jack Blakely Jack English Jason Jean-Tristan Marin Jodh Singh John Dopamine Joseph Rocca Kasım Açıkbaş Kristjan Retter Maciej Popławski Michael Levine Miguel Lara Misch Strotz Mohamed Oumoumad Noctre Patron Paul Fidika Plaidam Prasanth Veerina Razvan Grigore Steve Hanff Steven Simmons Sören The Local Lab Trent Hunter Un Defined Vladimir Sotnikov Wesley Reitzfeld Zoltán-Csaba Nyiró Алексей Наумов עומר מכלוף +Abraham Irawan Al H Armin Behjati Austin Robinson Bharat Prabhakar clement Delangue Cosmosis David Garrido Doron Adler Eli Slugworth EmmanuelMr18 Gili Ben Shahar HestoySeghuro . Jack Blakely Jack English Jason Jean-Tristan Marin Jodh Singh John Dopamine Joseph Rocca Kasım Açıkbaş Kristjan Retter Michael Levine Miguel Lara Misch Strotz Mohamed Oumoumad Noctre Patron Paul Fidika Plaidam Prasanth Veerina RayHell Razvan Grigore Steve Hanff Steven Simmons Sören The Local Lab Trent Hunter Un Defined Vladimir Sotnikov Wesley Reitzfeld Zoltán-Csaba Nyiró Алексей Наумов עומר מכלוף --- @@ -62,36 +62,39 @@ pip install -r requirements.txt AI Toolkit UI -The AI Toolkit UI is a web interface for the AI Toolkit. It allows you to easily start, stop, and monitor jobs. It also allows you to easily train models with a few clicks. It is still in early beta and will likely have bugs and frequent breaking changes. It is currently only tested on linux for now. +The AI Toolkit UI is a web interface for the AI Toolkit. It allows you to easily start, stop, and monitor jobs. It also allows you to easily train models with a few clicks. It also allows you to set a token for the UI to prevent unauthorized access so it is mostly safe to run on an exposed server. - -WARNING: The UI is not secure and should not be exposed to the internet. It is only meant to be run locally or on a server that does not have ports exposed. Adding additional security is on the roadmap. - -## Installing the UI +## Running the UI Requirements: - Node.js > 18 -You will need to do this with every update as well. +The UI does not need to be kept running for the jobs to run. It is only needed to start/stop/monitor jobs. The commands below +will install / update the UI and it's dependencies and start the UI. ```bash cd ui -npm install -npm run build -npm run update_db -``` - -## Running the UI - -Make sure you built it as shown above. The UI does not need to be kept running for the jobs to run. It is only needed to start/stop/monitor jobs. - -```bash -cd ui -npm run start +npm run build_and_start ``` You can now access the UI at `http://localhost:8675` or `http://:8675` if you are running it on a server. +## Securing the UI + +If you are hosting the UI on a cloud provider or any network that is not secure, I highly recommend securing it with an auth token. +You can do this by setting the environment variable `AI_TOOLKIT_AUTH` to super secure password. This token will be required to access +the UI. You can set this when starting the UI like so: + +```bash +# Linux +AI_TOOLKIT_AUTH=super_secure_password npm run build_and_start + +# Windows +set AI_TOOLKIT_AUTH=super_secure_password && npm run build_and_start + +# Windows Powershell +$env:AI_TOOLKIT_AUTH="super_secure_password"; npm run build_and_start +``` ## FLUX.1 Training diff --git a/extensions_built_in/sd_trainer/UITrainer.py b/extensions_built_in/sd_trainer/UITrainer.py index feaf4345..124f80a8 100644 --- a/extensions_built_in/sd_trainer/UITrainer.py +++ b/extensions_built_in/sd_trainer/UITrainer.py @@ -15,7 +15,8 @@ class UITrainer(SDTrainer): super(UITrainer, self).__init__(process_id, job, config, **kwargs) self.sqlite_db_path = self.config.get("sqlite_db_path", "./aitk_db.db") if not os.path.exists(self.sqlite_db_path): - raise Exception(f"SQLite database not found at {self.sqlite_db_path}") + raise Exception( + f"SQLite database not found at {self.sqlite_db_path}") print(f"Using SQLite database at {self.sqlite_db_path}") self.job_id = os.environ.get("AITK_JOB_ID", None) self.job_id = self.job_id.strip() if self.job_id is not None else None @@ -147,6 +148,8 @@ class UITrainer(SDTrainer): try: await asyncio.gather(*self._async_tasks) + except Exception as e: + print(f"Error waiting for async operations: {e}") finally: # Clear the task list after completion self._async_tasks.clear() diff --git a/ui/package.json b/ui/package.json index b08d2e58..6775b5ba 100644 --- a/ui/package.json +++ b/ui/package.json @@ -6,6 +6,7 @@ "dev": "next dev --turbopack", "build": "next build", "start": "next start --port 8675", + "build_and_start": "npm install && npm run update_db && npm run build && npm run start", "lint": "next lint", "update_db": "npx prisma generate && npx prisma db push", "format": "prettier --write \"**/*.{js,jsx,ts,tsx,css,scss}\"" diff --git a/ui/src/app/api/auth/route.ts b/ui/src/app/api/auth/route.ts new file mode 100644 index 00000000..1dc22973 --- /dev/null +++ b/ui/src/app/api/auth/route.ts @@ -0,0 +1,6 @@ +import { NextResponse } from 'next/server'; + +export async function GET() { + // if this gets hit, auth has already been verified + return NextResponse.json({ isAuthenticated: true }); +} diff --git a/ui/src/app/api/datasets/create/route.tsx b/ui/src/app/api/datasets/create/route.tsx index 62976613..f55759dc 100644 --- a/ui/src/app/api/datasets/create/route.tsx +++ b/ui/src/app/api/datasets/create/route.tsx @@ -19,4 +19,4 @@ export async function POST(request: Request) { } catch (error) { return NextResponse.json({ error: 'Failed to create dataset' }, { status: 500 }); } -} \ No newline at end of file +} diff --git a/ui/src/app/api/datasets/listImages/route.ts b/ui/src/app/api/datasets/listImages/route.ts index b1bb496a..48f01a70 100644 --- a/ui/src/app/api/datasets/listImages/route.ts +++ b/ui/src/app/api/datasets/listImages/route.ts @@ -8,31 +8,25 @@ export async function POST(request: Request) { const body = await request.json(); const { datasetName } = body; const datasetFolder = path.join(datasetsPath, datasetName); - + try { // Check if folder exists if (!fs.existsSync(datasetFolder)) { - return NextResponse.json( - { error: `Folder '${datasetName}' not found` }, - { status: 404 } - ); + return NextResponse.json({ error: `Folder '${datasetName}' not found` }, { status: 404 }); } // Find all images recursively const imageFiles = findImagesRecursively(datasetFolder); - + // Format response const result = imageFiles.map(imgPath => ({ - img_path: imgPath + img_path: imgPath, })); - + return NextResponse.json({ images: result }); } catch (error) { console.error('Error finding images:', error); - return NextResponse.json( - { error: 'Failed to process request' }, - { status: 500 } - ); + return NextResponse.json({ error: 'Failed to process request' }, { status: 500 }); } } @@ -44,13 +38,13 @@ export async function POST(request: Request) { function findImagesRecursively(dir: string): string[] { const imageExtensions = ['.png', '.jpg', '.jpeg', '.webp']; let results: string[] = []; - + const items = fs.readdirSync(dir); - + for (const item of items) { const itemPath = path.join(dir, item); const stat = fs.statSync(itemPath); - + if (stat.isDirectory()) { // If it's a directory, recursively search it results = results.concat(findImagesRecursively(itemPath)); @@ -62,6 +56,6 @@ function findImagesRecursively(dir: string): string[] { } } } - + return results; -} \ No newline at end of file +} diff --git a/ui/src/app/api/files/[...filePath]/route.ts b/ui/src/app/api/files/[...filePath]/route.ts index 44076e40..38bf9d87 100644 --- a/ui/src/app/api/files/[...filePath]/route.ts +++ b/ui/src/app/api/files/[...filePath]/route.ts @@ -16,7 +16,8 @@ export async function GET(request: NextRequest, { params }: { params: { filePath const allowedDirs = [datasetRoot, trainingRoot]; // Security check: Ensure path is in allowed directory - const isAllowed = allowedDirs.some(allowedDir => decodedFilePath.startsWith(allowedDir)) && !decodedFilePath.includes('..'); + const isAllowed = + allowedDirs.some(allowedDir => decodedFilePath.startsWith(allowedDir)) && !decodedFilePath.includes('..'); if (!isAllowed) { console.warn(`Access denied: ${decodedFilePath} not in ${allowedDirs.join(', ')}`); @@ -62,7 +63,7 @@ export async function GET(request: NextRequest, { params }: { params: { filePath 'Accept-Ranges': 'bytes', 'Cache-Control': 'public, max-age=86400', 'Content-Disposition': `attachment; filename="${encodeURIComponent(filename)}"`, - 'X-Content-Type-Options': 'nosniff' + 'X-Content-Type-Options': 'nosniff', }; if (range) { @@ -70,12 +71,12 @@ export async function GET(request: NextRequest, { params }: { params: { filePath const parts = range.replace(/bytes=/, '').split('-'); const start = parseInt(parts[0], 10); const end = parts[1] ? parseInt(parts[1], 10) : Math.min(start + 10 * 1024 * 1024, stat.size - 1); // 10MB chunks - const chunkSize = (end - start) + 1; + const chunkSize = end - start + 1; - const fileStream = fs.createReadStream(decodedFilePath, { - start, + const fileStream = fs.createReadStream(decodedFilePath, { + start, end, - highWaterMark: 64 * 1024 // 64KB buffer + highWaterMark: 64 * 1024, // 64KB buffer }); return new NextResponse(fileStream as any, { @@ -83,19 +84,19 @@ export async function GET(request: NextRequest, { params }: { params: { filePath headers: { ...commonHeaders, 'Content-Range': `bytes ${start}-${end}/${stat.size}`, - 'Content-Length': String(chunkSize) + 'Content-Length': String(chunkSize), }, }); } else { // For full file download, read directly without streaming wrapper const fileStream = fs.createReadStream(decodedFilePath, { - highWaterMark: 64 * 1024 // 64KB buffer + highWaterMark: 64 * 1024, // 64KB buffer }); return new NextResponse(fileStream as any, { headers: { ...commonHeaders, - 'Content-Length': String(stat.size) + 'Content-Length': String(stat.size), }, }); } @@ -103,4 +104,4 @@ export async function GET(request: NextRequest, { params }: { params: { filePath console.error('Error serving file:', error); return new NextResponse('Internal Server Error', { status: 500 }); } -} \ No newline at end of file +} diff --git a/ui/src/app/api/gpu/route.ts b/ui/src/app/api/gpu/route.ts index c06e4a05..8b11dbb0 100644 --- a/ui/src/app/api/gpu/route.ts +++ b/ui/src/app/api/gpu/route.ts @@ -10,7 +10,7 @@ export async function GET() { // Get platform const platform = os.platform(); const isWindows = platform === 'win32'; - + // Check if nvidia-smi is available const hasNvidiaSmi = await checkNvidiaSmi(isWindows); @@ -61,8 +61,9 @@ async function checkNvidiaSmi(isWindows: boolean): Promise { async function getGpuStats(isWindows: boolean) { // Command is the same for both platforms, but the path might be different - const command = 'nvidia-smi --query-gpu=index,name,driver_version,temperature.gpu,utilization.gpu,utilization.memory,memory.total,memory.free,memory.used,power.draw,power.limit,clocks.current.graphics,clocks.current.memory,fan.speed --format=csv,noheader,nounits'; - + const command = + 'nvidia-smi --query-gpu=index,name,driver_version,temperature.gpu,utilization.gpu,utilization.memory,memory.total,memory.free,memory.used,power.draw,power.limit,clocks.current.graphics,clocks.current.memory,fan.speed --format=csv,noheader,nounits'; + // Execute command const { stdout } = await execAsync(command); @@ -117,4 +118,4 @@ async function getGpuStats(isWindows: boolean) { }); return gpus; -} \ No newline at end of file +} diff --git a/ui/src/app/api/img/caption/route.ts b/ui/src/app/api/img/caption/route.ts index d3a150c9..df4235f9 100644 --- a/ui/src/app/api/img/caption/route.ts +++ b/ui/src/app/api/img/caption/route.ts @@ -17,7 +17,6 @@ export async function POST(request: Request) { return NextResponse.json({ error: 'Image does not exist' }, { status: 404 }); } - // check for caption const captionPath = imgPath.replace(/\.[^/.]+$/, '') + '.txt'; // save caption to file diff --git a/ui/src/app/api/jobs/[jobID]/start/route.ts b/ui/src/app/api/jobs/[jobID]/start/route.ts index 5067e90d..e8260713 100644 --- a/ui/src/app/api/jobs/[jobID]/start/route.ts +++ b/ui/src/app/api/jobs/[jobID]/start/route.ts @@ -103,7 +103,7 @@ export async function GET(request: NextRequest, { params }: { params: { jobID: s cwd: TOOLKIT_ROOT, windowsHide: false, }); - + subprocess.unref(); } else { // For non-Windows platforms, use your original approach @@ -116,7 +116,7 @@ export async function GET(request: NextRequest, { params }: { params: { jobID: s }, cwd: TOOLKIT_ROOT, }); - + subprocess.unref(); } // const subprocess = spawn(pythonPath, [runFilePath, configPath], { diff --git a/ui/src/app/api/jobs/route.ts b/ui/src/app/api/jobs/route.ts index 66843cc9..4d9b954e 100644 --- a/ui/src/app/api/jobs/route.ts +++ b/ui/src/app/api/jobs/route.ts @@ -6,7 +6,6 @@ const prisma = new PrismaClient(); export async function GET(request: Request) { const { searchParams } = new URL(request.url); const id = searchParams.get('id'); - console.log('ID:', id); try { if (id) { @@ -19,7 +18,6 @@ export async function GET(request: Request) { const jobs = await prisma.job.findMany({ orderBy: { created_at: 'desc' }, }); - console.log('Jobs:', jobs); return NextResponse.json({ jobs: jobs }); } catch (error) { console.error(error); diff --git a/ui/src/app/api/settings/route.ts b/ui/src/app/api/settings/route.ts index 055cfbb6..62528cdd 100644 --- a/ui/src/app/api/settings/route.ts +++ b/ui/src/app/api/settings/route.ts @@ -1,7 +1,7 @@ import { NextResponse } from 'next/server'; import { PrismaClient } from '@prisma/client'; import { defaultTrainFolder, defaultDatasetsFolder } from '@/paths'; -import {flushCache} from '@/server/settings'; +import { flushCache } from '@/server/settings'; const prisma = new PrismaClient(); diff --git a/ui/src/app/datasets/[datasetName]/page.tsx b/ui/src/app/datasets/[datasetName]/page.tsx index 37976431..d3bf68fc 100644 --- a/ui/src/app/datasets/[datasetName]/page.tsx +++ b/ui/src/app/datasets/[datasetName]/page.tsx @@ -6,6 +6,7 @@ import DatasetImageCard from '@/components/DatasetImageCard'; import { Button } from '@headlessui/react'; import AddImagesModal, { openImagesModal } from '@/components/AddImagesModal'; import { TopBar, MainContent } from '@/components/layout'; +import { apiClient } from '@/utils/api'; export default function DatasetPage({ params }: { params: { datasetName: string } }) { const [imgList, setImgList] = useState<{ img_path: string }[]>([]); @@ -15,15 +16,11 @@ export default function DatasetPage({ params }: { params: { datasetName: string const refreshImageList = (dbName: string) => { setStatus('loading'); - fetch('/api/datasets/listImages', { - method: 'POST', - headers: { - 'Content-Type': 'application/json', - }, - body: JSON.stringify({ datasetName: dbName }), - }) - .then(res => res.json()) - .then(data => { + console.log('Fetching images for dataset:', dbName); + apiClient + .post('/api/datasets/listImages', { datasetName: dbName }) + .then((res: any) => { + const data = res.data; console.log('Images:', data.images); // sort data.images.sort((a: { img_path: string }, b: { img_path: string }) => a.img_path.localeCompare(b.img_path)); diff --git a/ui/src/app/datasets/page.tsx b/ui/src/app/datasets/page.tsx index 9f5aeba2..c650153d 100644 --- a/ui/src/app/datasets/page.tsx +++ b/ui/src/app/datasets/page.tsx @@ -10,6 +10,7 @@ import { FaRegTrashAlt } from 'react-icons/fa'; import { openConfirm } from '@/components/ConfirmModal'; import { TopBar, MainContent } from '@/components/layout'; import UniversalTable, { TableColumn } from '@/components/UniversalTable'; +import { apiClient } from '@/utils/api'; export default function Datasets() { const { datasets, status, refreshDatasets } = useDatasetList(); @@ -54,16 +55,10 @@ export default function Datasets() { type: 'warning', confirmText: 'Delete', onConfirm: () => { - fetch('/api/datasets/delete', { - method: 'POST', - headers: { - 'Content-Type': 'application/json', - }, - body: JSON.stringify({ name: datasetName }), - }) - .then(res => res.json()) - .then(data => { - console.log('Dataset deleted:', data); + apiClient + .post('/api/datasets/delete', { name: datasetName }) + .then(() => { + console.log('Dataset deleted:', datasetName); refreshDatasets(); }) .catch(error => { @@ -76,14 +71,7 @@ export default function Datasets() { const handleCreateDataset = async (e: React.FormEvent) => { e.preventDefault(); try { - const response = await fetch('/api/datasets/create', { - method: 'POST', - headers: { - 'Content-Type': 'application/json', - }, - body: JSON.stringify({ name: newDatasetName }), - }); - const data = await response.json(); + const data = await apiClient.post('/api/datasets/create', { name: newDatasetName }).then(res => res.data); console.log('New dataset created:', data); refreshDatasets(); setNewDatasetName(''); diff --git a/ui/src/app/jobs/new/AdvancedJob.tsx b/ui/src/app/jobs/new/AdvancedJob.tsx index 9d0873e1..42365158 100644 --- a/ui/src/app/jobs/new/AdvancedJob.tsx +++ b/ui/src/app/jobs/new/AdvancedJob.tsx @@ -33,11 +33,7 @@ const yamlConfig: YAML.DocumentOptions & directives: true, }; -export default function AdvancedJob({ - jobConfig, - setJobConfig, - settings, -}: Props) { +export default function AdvancedJob({ jobConfig, setJobConfig, settings }: Props) { const [editorValue, setEditorValue] = useState(''); const lastJobConfigUpdateStringRef = useRef(''); const editorRef = useRef(null); diff --git a/ui/src/app/jobs/new/jobConfig.ts b/ui/src/app/jobs/new/jobConfig.ts index 0f2a4311..5f090479 100644 --- a/ui/src/app/jobs/new/jobConfig.ts +++ b/ui/src/app/jobs/new/jobConfig.ts @@ -30,7 +30,7 @@ export const defaultJobConfig: JobConfig = { linear: 32, linear_alpha: 32, lokr_full_rank: true, - lokr_factor: -1 + lokr_factor: -1, }, save: { dtype: 'bf16', @@ -39,9 +39,7 @@ export const defaultJobConfig: JobConfig = { save_format: 'diffusers', push_to_hub: false, }, - datasets: [ - defaultDatasetConfig - ], + datasets: [defaultDatasetConfig], train: { batch_size: 1, bypass_guidance_embedding: true, @@ -55,7 +53,7 @@ export const defaultJobConfig: JobConfig = { timestep_type: 'sigmoid', content_or_style: 'balanced', optimizer_params: { - weight_decay: 1e-4 + weight_decay: 1e-4, }, unload_text_encoder: false, lr: 0.0001, @@ -66,8 +64,7 @@ export const defaultJobConfig: JobConfig = { dtype: 'bf16', diff_output_preservation: false, diff_output_preservation_multiplier: 1.0, - diff_output_preservation_class: 'person' - + diff_output_preservation_class: 'person', }, model: { name_or_path: 'ostris/Flex.1-alpha', diff --git a/ui/src/app/jobs/new/options.ts b/ui/src/app/jobs/new/options.ts index e45f02ee..17c05b79 100644 --- a/ui/src/app/jobs/new/options.ts +++ b/ui/src/app/jobs/new/options.ts @@ -12,7 +12,7 @@ export const modelArchs = [ { name: 'flux', label: 'Flux.1' }, { name: 'wan21', label: 'Wan 2.1' }, { name: 'lumina2', label: 'Lumina2' }, -] +]; export const isVideoModelFromArch = (arch: string) => { const videoArches = ['wan21']; diff --git a/ui/src/app/jobs/new/page.tsx b/ui/src/app/jobs/new/page.tsx index 65474cdc..d1b7f9ca 100644 --- a/ui/src/app/jobs/new/page.tsx +++ b/ui/src/app/jobs/new/page.tsx @@ -19,6 +19,7 @@ import { Button } from '@headlessui/react'; import { FaChevronLeft } from 'react-icons/fa'; import SimpleJob from './SimpleJob'; import AdvancedJob from './AdvancedJob'; +import { apiClient } from '@/utils/api'; const isDev = process.env.NODE_ENV === 'development'; @@ -56,12 +57,13 @@ export default function TrainingForm() { useEffect(() => { if (runId) { - fetch(`/api/jobs?id=${runId}`) - .then(res => res.json()) + apiClient + .get(`/api/jobs?id=${runId}`) + .then(res => res.data) .then(data => { + console.log('Training:', data); setGpuIDs(data.gpu_ids); setJobConfig(JSON.parse(data.job_config)); - // setJobConfig(data.name, 'config.name'); }) .catch(error => console.error('Error fetching training:', error)); } @@ -85,33 +87,30 @@ export default function TrainingForm() { if (status === 'saving') return; setStatus('saving'); - try { - const response = await fetch('/api/jobs', { - method: 'POST', - headers: { - 'Content-Type': 'application/json', - }, - body: JSON.stringify({ - id: runId, - name: jobConfig.config.name, - gpu_ids: gpuIDs, - job_config: jobConfig, - }), - }); - - if (!response.ok) throw new Error('Failed to save training'); - - setStatus('success'); - if (!runId) { - const data = await response.json(); - router.push(`/jobs/${data.id}`); - } - setTimeout(() => setStatus('idle'), 2000); - } catch (error) { - console.error('Error saving training:', error); - setStatus('error'); - setTimeout(() => setStatus('idle'), 2000); - } + apiClient + .post('/api/jobs', { + id: runId, + name: jobConfig.config.name, + gpu_ids: gpuIDs, + job_config: jobConfig, + }) + .then(res => { + setStatus('success'); + if (runId) { + router.push(`/jobs/${runId}`); + } else { + router.push(`/jobs/${res.data.id}`); + } + }) + .catch(error => { + console.error('Error saving training:', error); + setStatus('error'); + }) + .finally(() => + setTimeout(() => { + setStatus('idle'); + }, 2000), + ); }; const handleSubmit = async (e: React.FormEvent) => { diff --git a/ui/src/app/jobs/page.tsx b/ui/src/app/jobs/page.tsx index 211ab433..5e80b313 100644 --- a/ui/src/app/jobs/page.tsx +++ b/ui/src/app/jobs/page.tsx @@ -13,10 +13,7 @@ export default function Dashboard() {
- + New Training Job
diff --git a/ui/src/app/layout.tsx b/ui/src/app/layout.tsx index 292c78c4..427909bd 100644 --- a/ui/src/app/layout.tsx +++ b/ui/src/app/layout.tsx @@ -6,6 +6,7 @@ import { ThemeProvider } from '@/components/ThemeProvider'; import ConfirmModal from '@/components/ConfirmModal'; import SampleImageModal from '@/components/SampleImageModal'; import { Suspense } from 'react'; +import AuthWrapper from '@/components/AuthWrapper'; const inter = Inter({ subsets: ['latin'] }); @@ -15,6 +16,9 @@ export const metadata: Metadata = { }; export default function RootLayout({ children }: { children: React.ReactNode }) { + // Check if the AI_TOOLKIT_AUTH environment variable is set + const authRequired = process.env.AI_TOOLKIT_AUTH ? true : false; + return ( @@ -22,13 +26,14 @@ export default function RootLayout({ children }: { children: React.ReactNode }) -
- - -
- {children} -
-
+ +
+ +
+ {children} +
+
+
diff --git a/ui/src/app/page.tsx b/ui/src/app/page.tsx index 3b59b023..f889cb61 100644 --- a/ui/src/app/page.tsx +++ b/ui/src/app/page.tsx @@ -2,4 +2,4 @@ import { redirect } from 'next/navigation'; export default function Home() { redirect('/dashboard'); -} \ No newline at end of file +} diff --git a/ui/src/app/settings/page.tsx b/ui/src/app/settings/page.tsx index 0c8cdf6b..4bf257c3 100644 --- a/ui/src/app/settings/page.tsx +++ b/ui/src/app/settings/page.tsx @@ -3,6 +3,7 @@ import { useEffect, useState } from 'react'; import useSettings from '@/hooks/useSettings'; import { TopBar, MainContent } from '@/components/layout'; +import { apiClient } from '@/utils/api'; export default function Settings() { const { settings, setSettings } = useSettings(); @@ -12,24 +13,18 @@ export default function Settings() { e.preventDefault(); setStatus('saving'); - try { - const response = await fetch('/api/settings', { - method: 'POST', - headers: { - 'Content-Type': 'application/json', - }, - body: JSON.stringify(settings), + apiClient + .post('/api/settings', settings) + .then(() => { + setStatus('success'); + }) + .catch(error => { + console.error('Error saving settings:', error); + setStatus('error'); + }) + .finally(() => { + setTimeout(() => setStatus('idle'), 2000); }); - - if (!response.ok) throw new Error('Failed to save settings'); - - setStatus('success'); - setTimeout(() => setStatus('idle'), 2000); - } catch (error) { - console.error('Error saving settings:', error); - setStatus('error'); - setTimeout(() => setStatus('idle'), 2000); - } }; const handleChange = (e: React.ChangeEvent) => { diff --git a/ui/src/components/AddImagesModal.tsx b/ui/src/components/AddImagesModal.tsx index 55cc6746..b95b512e 100644 --- a/ui/src/components/AddImagesModal.tsx +++ b/ui/src/components/AddImagesModal.tsx @@ -4,7 +4,7 @@ import { Dialog, DialogBackdrop, DialogPanel, DialogTitle } from '@headlessui/re import { FaUpload } from 'react-icons/fa'; import { useCallback, useState } from 'react'; import { useDropzone } from 'react-dropzone'; -import axios from 'axios'; +import { apiClient } from '@/utils/api'; export interface AddImagesModalState { datasetName: string; @@ -15,7 +15,7 @@ export const addImagesModalState = createGlobalState export const openImagesModal = (datasetName: string, onComplete: () => void) => { addImagesModalState.set({ datasetName, onComplete }); -} +}; export default function AddImagesModal() { const [addImagesModalInfo, setAddImagesModalInfo] = addImagesModalState.use(); @@ -36,46 +36,49 @@ export default function AddImagesModal() { } }; - const onDrop = useCallback(async (acceptedFiles: File[]) => { - if (acceptedFiles.length === 0) return; + const onDrop = useCallback( + async (acceptedFiles: File[]) => { + if (acceptedFiles.length === 0) return; - setIsUploading(true); - setUploadProgress(0); - - const formData = new FormData(); - acceptedFiles.forEach(file => { - formData.append('files', file); - }); - formData.append('datasetName', addImagesModalInfo?.datasetName || ''); - - try { - await axios.post(`/api/datasets/upload`, formData, { - headers: { - 'Content-Type': 'multipart/form-data', - }, - onUploadProgress: (progressEvent) => { - const percentCompleted = Math.round((progressEvent.loaded * 100) / (progressEvent.total || 100)); - setUploadProgress(percentCompleted); - }, - timeout: 0, // Disable timeout - }); - - onDone(); - } catch (error) { - console.error('Upload failed:', error); - } finally { - setIsUploading(false); + setIsUploading(true); setUploadProgress(0); - } - }, [addImagesModalInfo]); + + const formData = new FormData(); + acceptedFiles.forEach(file => { + formData.append('files', file); + }); + formData.append('datasetName', addImagesModalInfo?.datasetName || ''); + + try { + await apiClient.post(`/api/datasets/upload`, formData, { + headers: { + 'Content-Type': 'multipart/form-data', + }, + onUploadProgress: progressEvent => { + const percentCompleted = Math.round((progressEvent.loaded * 100) / (progressEvent.total || 100)); + setUploadProgress(percentCompleted); + }, + timeout: 0, // Disable timeout + }); + + onDone(); + } catch (error) { + console.error('Upload failed:', error); + } finally { + setIsUploading(false); + setUploadProgress(0); + } + }, + [addImagesModalInfo], + ); const { getRootProps, getInputProps, isDragActive } = useDropzone({ onDrop, accept: { 'image/*': ['.png', '.jpg', '.jpeg', '.gif', '.bmp'], - 'text/*': ['.txt'] + 'text/*': ['.txt'], }, - multiple: true + multiple: true, }); return ( @@ -105,22 +108,15 @@ export default function AddImagesModal() {

- {isDragActive - ? 'Drop the files here...' - : 'Drag & drop files here, or click to select files'} + {isDragActive ? 'Drop the files here...' : 'Drag & drop files here, or click to select files'}

{isUploading && (
-
+
-

- Uploading... {uploadProgress}% -

+

Uploading... {uploadProgress}%

)} @@ -152,4 +148,4 @@ export default function AddImagesModal() { ); -} \ No newline at end of file +} diff --git a/ui/src/components/AuthWrapper.tsx b/ui/src/components/AuthWrapper.tsx new file mode 100644 index 00000000..58f1bad5 --- /dev/null +++ b/ui/src/components/AuthWrapper.tsx @@ -0,0 +1,163 @@ +'use client'; + +import { useState, useEffect, useRef } from 'react'; +import { apiClient, isAuthorizedState } from '@/utils/api'; +import { createGlobalState } from 'react-global-hooks'; + +interface AuthWrapperProps { + authRequired: boolean; + children: React.ReactNode | React.ReactNode[]; +} + +export default function AuthWrapper({ authRequired, children }: AuthWrapperProps) { + const [token, setToken] = useState(''); + // start with true, and deauth if needed + const [isAuthorizedGlobal, setIsAuthorized] = isAuthorizedState.use(); + const [isLoading, setIsLoading] = useState(false); + const [error, setError] = useState(''); + const [isBrowser, setIsBrowser] = useState(false); + const inputRef = useRef(null); + + const isAuthorized = authRequired ? isAuthorizedGlobal : true; + + // Set isBrowser to true when component mounts + useEffect(() => { + setIsBrowser(true); + // Get token from localStorage only after component has mounted + const storedToken = localStorage.getItem('AI_TOOLKIT_AUTH') || ''; + setToken(storedToken); + checkAuth(); + }, []); + + // auto focus on input when not authorized + useEffect(() => { + if (isAuthorized) { + return; + } + setTimeout(() => { + if (inputRef.current) { + inputRef.current.focus(); + } + }, 100); + }, [isAuthorized]); + + const checkAuth = async () => { + // always get current stored token here to avoid state race conditions + const currentToken = localStorage.getItem('AI_TOOLKIT_AUTH') || ''; + if (!authRequired || isLoading || currentToken === '') { + return; + } + setIsLoading(true); + setError(''); + try { + const response = await apiClient.get('/api/auth'); + if (response.data.isAuthenticated) { + setIsAuthorized(true); + } else { + setIsAuthorized(false); + setError('Invalid token. Please try again.'); + } + } catch (err) { + setIsAuthorized(false); + console.log(err); + setError('Invalid token. Please try again.'); + } + setIsLoading(false); + }; + + const handleSubmit = async (e: React.FormEvent) => { + e.preventDefault(); + setError(''); + + if (!token.trim()) { + setError('Please enter your token'); + return; + } + + if (isBrowser) { + localStorage.setItem('AI_TOOLKIT_AUTH', token); + checkAuth(); + } + }; + + if (isAuthorized) { + return <>{children}; + } + + return ( +
+ {/* Left side - decorative or brand area */} +
+
+ {/* Replace with your own logo */} +
+ Ostris AI Toolkit +
+
+

AI Toolkit

+
+ + {/* Right side - login form */} +
+
+
+ {/* Mobile logo */} +
+ Ostris AI Toolkit +
+
+ +

AI Toolkit

+ +
+
+ + setToken(e.target.value)} + className="w-full px-4 py-3 rounded-lg bg-gray-800 border border-gray-700 focus:border-blue-500 focus:ring-2 focus:ring-blue-500 focus:ring-opacity-50 text-gray-100 transition duration-200" + placeholder="Enter your token" + /> +
+ + {error && ( +
{error}
+ )} + + +
+
+
+
+ ); +} diff --git a/ui/src/components/Card.tsx b/ui/src/components/Card.tsx index 1942d81e..13c7409b 100644 --- a/ui/src/components/Card.tsx +++ b/ui/src/components/Card.tsx @@ -12,4 +12,4 @@ const Card: React.FC = ({ title, children }) => { ); }; -export default Card; \ No newline at end of file +export default Card; diff --git a/ui/src/components/DatasetImageCard.tsx b/ui/src/components/DatasetImageCard.tsx index 0dd2e243..9bf58d43 100644 --- a/ui/src/components/DatasetImageCard.tsx +++ b/ui/src/components/DatasetImageCard.tsx @@ -2,6 +2,7 @@ import React, { useRef, useEffect, useState, ReactNode, KeyboardEvent } from 're import { FaTrashAlt } from 'react-icons/fa'; import { openConfirm } from './ConfirmModal'; import classNames from 'classnames'; +import { apiClient } from '@/utils/api'; interface DatasetImageCardProps { imageUrl: string; @@ -27,30 +28,32 @@ const DatasetImageCard: React.FC = ({ const isGettingCaption = useRef(false); const fetchCaption = async () => { - try { - if (isGettingCaption.current || isCaptionLoaded) return; - isGettingCaption.current = true; - const response = await fetch(`/api/caption/${encodeURIComponent(imageUrl)}`); - const data = await response.text(); - setCaption(data); - setSavedCaption(data); - setIsCaptionLoaded(true); - } catch (error) { - console.error('Error fetching caption:', error); - } + if (isGettingCaption.current || isCaptionLoaded) return; + isGettingCaption.current = true; + apiClient + .get(`/api/caption/${encodeURIComponent(imageUrl)}`) + .then(res => res.data) + .then(data => { + console.log('Caption fetched:', data); + + setCaption(data || ''); + setSavedCaption(data || ''); + setIsCaptionLoaded(true); + }) + .catch(error => { + console.error('Error fetching caption:', error); + }) + .finally(() => { + isGettingCaption.current = false; + }); }; const saveCaption = () => { const trimmedCaption = caption.trim(); if (trimmedCaption === savedCaption) return; - fetch('/api/img/caption', { - method: 'POST', - headers: { - 'Content-Type': 'application/json', - }, - body: JSON.stringify({ imgPath: imageUrl, caption: trimmedCaption }), - }) - .then(res => res.json()) + apiClient + .post('/api/img/caption', { imgPath: imageUrl, caption: trimmedCaption }) + .then(res => res.data) .then(data => { console.log('Caption saved:', data); setSavedCaption(trimmedCaption); @@ -129,16 +132,10 @@ const DatasetImageCard: React.FC = ({ type: 'warning', confirmText: 'Delete', onConfirm: () => { - fetch('/api/img/delete', { - method: 'POST', - headers: { - 'Content-Type': 'application/json', - }, - body: JSON.stringify({ imgPath: imageUrl }), - }) - .then(res => res.json()) - .then(data => { - console.log('Image deleted:', data); + apiClient + .post('/api/img/delete', { imgPath: imageUrl }) + .then(() => { + console.log('Image deleted:', imageUrl); onDelete(); }) .catch(error => { diff --git a/ui/src/components/FilesWidget.tsx b/ui/src/components/FilesWidget.tsx index 9c4754f8..b73dabf0 100644 --- a/ui/src/components/FilesWidget.tsx +++ b/ui/src/components/FilesWidget.tsx @@ -24,9 +24,7 @@ export default function FilesWidget({ jobID }: { jobID: string }) {

Checkpoints

- - {files.length} - + {files.length}
@@ -50,9 +48,9 @@ export default function FilesWidget({ jobID }: { jobID: string }) { const fileName = file.path.split('/').pop() || ''; const nameWithoutExt = fileName.replace('.safetensors', ''); return ( - @@ -80,11 +78,9 @@ export default function FilesWidget({ jobID }: { jobID: string }) { )} {['success', 'refreshing'].includes(status) && files.length === 0 && ( -
- No checkpoints available -
+
No checkpoints available
)} ); -} \ No newline at end of file +} diff --git a/ui/src/components/GPUMonitor.tsx b/ui/src/components/GPUMonitor.tsx index fedba605..1f255fc7 100644 --- a/ui/src/components/GPUMonitor.tsx +++ b/ui/src/components/GPUMonitor.tsx @@ -1,7 +1,8 @@ -import React, { useState, useEffect, useRef } from 'react'; +import React, { useState, useEffect, useRef, useMemo } from 'react'; import { GPUApiResponse } from '@/types'; import Loading from '@/components/Loading'; import GPUWidget from '@/components/GPUWidget'; +import { apiClient } from '@/utils/api'; const GpuMonitor: React.FC = () => { const [gpuData, setGpuData] = useState(null); @@ -12,27 +13,26 @@ const GpuMonitor: React.FC = () => { useEffect(() => { const fetchGpuInfo = async () => { - try { - if (isFetchingGpuRef.current) { - return; - } - isFetchingGpuRef.current = true; - const response = await fetch('/api/gpu'); - - if (!response.ok) { - throw new Error(`HTTP error! Status: ${response.status}`); - } - - const data: GPUApiResponse = await response.json(); - setGpuData(data); - setLastUpdated(new Date()); - setError(null); - } catch (err) { - setError(`Failed to fetch GPU data: ${err instanceof Error ? err.message : String(err)}`); - } finally { - isFetchingGpuRef.current = false; - setLoading(false); + if (isFetchingGpuRef.current) { + return; } + setLoading(true); + isFetchingGpuRef.current = true; + apiClient + .get('/api/gpu') + .then(res => res.data) + .then(data => { + setGpuData(data); + setLastUpdated(new Date()); + setError(null); + }) + .catch(err => { + setError(`Failed to fetch GPU data: ${err instanceof Error ? err.message : String(err)}`); + }) + .finally(() => { + isFetchingGpuRef.current = false; + setLoading(false); + }); }; // Fetch immediately on component mount @@ -69,46 +69,63 @@ const GpuMonitor: React.FC = () => { } }; - if (loading) { - return ; - } + console.log('state', { + loading, + gpuData, + error, + lastUpdated, + }); + + const content = useMemo(() => { + if (loading && !gpuData) { + return ; + } + + if (error) { + return ( +
+ Error! + {error} +
+ ); + } + + if (!gpuData) { + return ( +
+ No GPU data available. +
+ ); + } + + if (!gpuData.hasNvidiaSmi) { + return ( +
+ No NVIDIA GPUs detected! + nvidia-smi is not available on this system. + {gpuData.error &&

{gpuData.error}

} +
+ ); + } + + if (gpuData.gpus.length === 0) { + return ( +
+ No GPUs found, but nvidia-smi is available. +
+ ); + } + + const gridClass = getGridClasses(gpuData?.gpus?.length || 1); - if (error) { return ( -
- Error! - {error} +
+ {gpuData.gpus.map((gpu, idx) => ( + + ))}
); - } - - if (!gpuData) { - return ( -
- No GPU data available. -
- ); - } - - if (!gpuData.hasNvidiaSmi) { - return ( -
- No NVIDIA GPUs detected! - nvidia-smi is not available on this system. - {gpuData.error &&

{gpuData.error}

} -
- ); - } - - if (gpuData.gpus.length === 0) { - return ( -
- No GPUs found, but nvidia-smi is available. -
- ); - } - - const gridClass = getGridClasses(gpuData.gpus.length); + }, [loading, gpuData, error]); return (
@@ -116,12 +133,7 @@ const GpuMonitor: React.FC = () => {

GPU Monitor

Last updated: {lastUpdated?.toLocaleTimeString()}
- -
- {gpuData.gpus.map((gpu, idx) => ( - - ))} -
+ {content}
); }; diff --git a/ui/src/components/GPUWidget.tsx b/ui/src/components/GPUWidget.tsx index bc0dfc62..e21b0190 100644 --- a/ui/src/components/GPUWidget.tsx +++ b/ui/src/components/GPUWidget.tsx @@ -24,9 +24,7 @@ export default function GPUWidget({ gpu }: GPUWidgetProps) {

{gpu.name}

- - #{gpu.index} - + #{gpu.index}
@@ -38,18 +36,14 @@ export default function GPUWidget({ gpu }: GPUWidgetProps) {

Temperature

-

- {gpu.temperature}°C -

+

{gpu.temperature}°C

Fan Speed

-

- {gpu.fan.speed}% -

+

{gpu.fan.speed}%

@@ -99,7 +93,7 @@ export default function GPUWidget({ gpu }: GPUWidgetProps) {

Power Draw

{gpu.power.draw?.toFixed(1)}W - / {gpu.power.limit?.toFixed(1) || " ? "}W + / {gpu.power.limit?.toFixed(1) || ' ? '}W

@@ -107,4 +101,4 @@ export default function GPUWidget({ gpu }: GPUWidgetProps) { ); -} \ No newline at end of file +} diff --git a/ui/src/components/JobOverview.tsx b/ui/src/components/JobOverview.tsx index 2580374c..a29d4fb0 100644 --- a/ui/src/components/JobOverview.tsx +++ b/ui/src/components/JobOverview.tsx @@ -45,7 +45,9 @@ export default function JobOverview({ job }: JobOverviewProps) { {/* Job Information Panel */}
-

{job.info}

+

+ {job.info} +

{job.status}
@@ -85,7 +87,7 @@ export default function JobOverview({ job }: JobOverviewProps) {

Speed

-

{job.speed_string == "" ? "?" : job.speed_string}

+

{job.speed_string == '' ? '?' : job.speed_string}

diff --git a/ui/src/components/Loading.tsx b/ui/src/components/Loading.tsx index 08093792..141a3103 100644 --- a/ui/src/components/Loading.tsx +++ b/ui/src/components/Loading.tsx @@ -1,6 +1,6 @@ export default function Loading() { return ( -
+
); diff --git a/ui/src/components/SampleImageCard.tsx b/ui/src/components/SampleImageCard.tsx index 857d3ffa..947c594a 100644 --- a/ui/src/components/SampleImageCard.tsx +++ b/ui/src/components/SampleImageCard.tsx @@ -11,7 +11,14 @@ interface SampleImageCardProps { onDelete?: () => void; } -const SampleImageCard: React.FC = ({ imageUrl, alt, numSamples, sampleImages, children, className = '' }) => { +const SampleImageCard: React.FC = ({ + imageUrl, + alt, + numSamples, + sampleImages, + children, + className = '', +}) => { const cardRef = useRef(null); const [isVisible, setIsVisible] = useState(false); const [loaded, setLoaded] = useState(false); diff --git a/ui/src/components/Sidebar.tsx b/ui/src/components/Sidebar.tsx index 1f50d570..14ec4542 100644 --- a/ui/src/components/Sidebar.tsx +++ b/ui/src/components/Sidebar.tsx @@ -1,9 +1,10 @@ import Link from 'next/link'; -import { Home, Settings, BrainCircuit, Images } from 'lucide-react'; +import { Home, Settings, BrainCircuit, Images, Plus } from 'lucide-react'; const Sidebar = () => { const navigation = [ { name: 'Dashboard', href: '/dashboard', icon: Home }, + { name: 'New Job', href: '/jobs/new', icon: Plus }, { name: 'Training Jobs', href: '/jobs', icon: BrainCircuit }, { name: 'Datasets', href: '/datasets', icon: Images }, { name: 'Settings', href: '/settings', icon: Settings }, @@ -33,7 +34,7 @@ const Sidebar = () => {
-
+
{ const { label, value, onChange, placeholder, required, min, max } = props; - + // Add controlled internal state to properly handle partial inputs const [inputValue, setInputValue] = React.useState(value ?? ''); @@ -66,7 +66,7 @@ export const NumberInput = (props: NumberInputProps) => { value={inputValue} onChange={e => { const rawValue = e.target.value; - + // Update the input display with the raw value setInputValue(rawValue); @@ -81,7 +81,7 @@ export const NumberInput = (props: NumberInputProps) => { // Only apply constraints and call onChange when we have a valid number if (!isNaN(numValue)) { let constrainedValue = numValue; - + // Apply min/max constraints if they exist if (min !== undefined && constrainedValue < min) { constrainedValue = min; @@ -89,7 +89,7 @@ export const NumberInput = (props: NumberInputProps) => { if (max !== undefined && constrainedValue > max) { constrainedValue = max; } - + onChange(constrainedValue); } }} @@ -152,14 +152,14 @@ export const Checkbox = (props: CheckboxProps) => { className={classNames( 'relative inline-flex h-6 w-11 flex-shrink-0 cursor-pointer rounded-full border-2 border-transparent transition-colors duration-200 ease-in-out focus:outline-none focus:ring-2 focus:ring-blue-600 focus:ring-offset-2', checked ? 'bg-blue-600' : 'bg-gray-700', - disabled ? 'opacity-50 cursor-not-allowed' : 'hover:bg-opacity-80' + disabled ? 'opacity-50 cursor-not-allowed' : 'hover:bg-opacity-80', )} > Toggle {label} @@ -168,7 +168,7 @@ export const Checkbox = (props: CheckboxProps) => { htmlFor={id} className={classNames( 'text-sm font-medium cursor-pointer select-none', - disabled ? 'text-gray-500' : 'text-gray-300' + disabled ? 'text-gray-500' : 'text-gray-300', )} > {label} diff --git a/ui/src/hooks/useDatasetList.tsx b/ui/src/hooks/useDatasetList.tsx index 5760a947..480e36d0 100644 --- a/ui/src/hooks/useDatasetList.tsx +++ b/ui/src/hooks/useDatasetList.tsx @@ -1,6 +1,7 @@ 'use client'; import { useEffect, useState } from 'react'; +import { apiClient } from '@/utils/api'; export default function useDatasetList() { const [datasets, setDatasets] = useState([]); @@ -8,8 +9,9 @@ export default function useDatasetList() { const refreshDatasets = () => { setStatus('loading'); - fetch('/api/datasets/list') - .then(res => res.json()) + apiClient + .get('/api/datasets/list') + .then(res => res.data) .then(data => { console.log('Datasets:', data); // sort diff --git a/ui/src/hooks/useFilesList.tsx b/ui/src/hooks/useFilesList.tsx index 7268a3cb..e73e6c69 100644 --- a/ui/src/hooks/useFilesList.tsx +++ b/ui/src/hooks/useFilesList.tsx @@ -1,6 +1,7 @@ 'use client'; import { useEffect, useState, useRef } from 'react'; +import { apiClient } from '@/utils/api'; interface FileObject { path: string; @@ -18,8 +19,9 @@ export default function useFilesList(jobID: string, reloadInterval: null | numbe loadStatus = 'refreshing'; } setStatus(loadStatus); - fetch(`/api/jobs/${jobID}/files`) - .then(res => res.json()) + apiClient + .get(`/api/jobs/${jobID}/files`) + .then(res => res.data) .then(data => { console.log('Fetched files:', data); if (data.files) { diff --git a/ui/src/hooks/useGPUInfo.tsx b/ui/src/hooks/useGPUInfo.tsx index 5f2eda38..b8f60405 100644 --- a/ui/src/hooks/useGPUInfo.tsx +++ b/ui/src/hooks/useGPUInfo.tsx @@ -2,6 +2,7 @@ import { GPUApiResponse, GpuInfo } from '@/types'; import { useEffect, useState } from 'react'; +import { apiClient } from '@/utils/api'; export default function useGPUInfo(gpuIds: null | number[] = null, reloadInterval: null | number = null) { const [gpuList, setGpuList] = useState([]); @@ -11,18 +12,11 @@ export default function useGPUInfo(gpuIds: null | number[] = null, reloadInterva const fetchGpuInfo = async () => { setStatus('loading'); try { - const response = await fetch('/api/gpu'); - - if (!response.ok) { - throw new Error(`HTTP error! Status: ${response.status}`); - } - - const data: GPUApiResponse = await response.json(); + const data: GPUApiResponse = await apiClient.get('/api/gpu').then(res => res.data); let gpus = data.gpus.sort((a, b) => a.index - b.index); if (gpuIds) { gpus = gpus.filter(gpu => gpuIds.includes(gpu.index)); } - setGpuList(gpus); setStatus('success'); } catch (err) { @@ -51,4 +45,4 @@ export default function useGPUInfo(gpuIds: null | number[] = null, reloadInterva }, [gpuIds, reloadInterval]); // Added dependencies return { gpuList, setGpuList, isGPUInfoLoaded, status, refreshGpuInfo: fetchGpuInfo }; -} \ No newline at end of file +} diff --git a/ui/src/hooks/useJob.tsx b/ui/src/hooks/useJob.tsx index e4318233..5c43f9e5 100644 --- a/ui/src/hooks/useJob.tsx +++ b/ui/src/hooks/useJob.tsx @@ -2,6 +2,7 @@ import { useEffect, useState } from 'react'; import { Job } from '@prisma/client'; +import { apiClient } from '@/utils/api'; export default function useJob(jobID: string, reloadInterval: null | number = null) { const [job, setJob] = useState(null); @@ -9,8 +10,9 @@ export default function useJob(jobID: string, reloadInterval: null | number = nu const refreshJob = () => { setStatus('loading'); - fetch(`/api/jobs?id=${jobID}`) - .then(res => res.json()) + apiClient + .get(`/api/jobs?id=${jobID}`) + .then(res => res.data) .then(data => { console.log('Job:', data); setJob(data); @@ -32,7 +34,7 @@ export default function useJob(jobID: string, reloadInterval: null | number = nu return () => { clearInterval(interval); - } + }; } }, [jobID]); diff --git a/ui/src/hooks/useJobsList.tsx b/ui/src/hooks/useJobsList.tsx index a1c3d2d7..6f1e3af9 100644 --- a/ui/src/hooks/useJobsList.tsx +++ b/ui/src/hooks/useJobsList.tsx @@ -2,6 +2,7 @@ import { useEffect, useState } from 'react'; import { Job } from '@prisma/client'; +import { apiClient } from '@/utils/api'; export default function useJobsList(onlyActive = false) { const [jobs, setJobs] = useState([]); @@ -9,8 +10,9 @@ export default function useJobsList(onlyActive = false) { const refreshJobs = () => { setStatus('loading'); - fetch('/api/jobs') - .then(res => res.json()) + apiClient + .get('/api/jobs') + .then(res => res.data) .then(data => { console.log('Jobs:', data); if (data.error) { diff --git a/ui/src/hooks/useSampleImages.tsx b/ui/src/hooks/useSampleImages.tsx index ccf07493..8b79a8d6 100644 --- a/ui/src/hooks/useSampleImages.tsx +++ b/ui/src/hooks/useSampleImages.tsx @@ -1,6 +1,7 @@ 'use client'; import { useEffect, useState } from 'react'; +import { apiClient } from '@/utils/api'; export default function useSampleImages(jobID: string, reloadInterval: null | number = null) { const [sampleImages, setSampleImages] = useState([]); @@ -8,9 +9,11 @@ export default function useSampleImages(jobID: string, reloadInterval: null | nu const refreshSampleImages = () => { setStatus('loading'); - fetch(`/api/jobs/${jobID}/samples`) - .then(res => res.json()) + apiClient + .get(`/api/jobs/${jobID}/samples`) + .then(res => res.data) .then(data => { + console.log('Fetched sample images:', data); if (data.samples) { setSampleImages(data.samples); } diff --git a/ui/src/hooks/useSettings.tsx b/ui/src/hooks/useSettings.tsx index 7d17bbd5..35fcc538 100644 --- a/ui/src/hooks/useSettings.tsx +++ b/ui/src/hooks/useSettings.tsx @@ -1,6 +1,7 @@ 'use client'; import { useEffect, useState } from 'react'; +import { apiClient } from '@/utils/api'; export interface Settings { HF_TOKEN: string; @@ -16,10 +17,11 @@ export default function useSettings() { }); const [isSettingsLoaded, setIsLoaded] = useState(false); useEffect(() => { - // Fetch current settings - fetch('/api/settings') - .then(res => res.json()) + apiClient + .get('/api/settings') + .then(res => res.data) .then(data => { + console.log('Settings:', data); setSettings({ HF_TOKEN: data.HF_TOKEN || '', TRAINING_FOLDER: data.TRAINING_FOLDER || '', diff --git a/ui/src/middleware.ts b/ui/src/middleware.ts new file mode 100644 index 00000000..bf198d1e --- /dev/null +++ b/ui/src/middleware.ts @@ -0,0 +1,49 @@ +// middleware.ts (at the root of your project) +import { NextResponse } from 'next/server'; +import type { NextRequest } from 'next/server'; + +// if route starts with these, approve +const publicRoutes = ['/api/img/', '/api/files/']; + +export function middleware(request: NextRequest) { + // check env var for AI_TOOLKIT_AUTH, if not set, approve all requests + // if it is set make sure bearer token matches + const tokenToUse = process.env.AI_TOOLKIT_AUTH || null; + if (!tokenToUse) { + return NextResponse.next(); + } + + // Get the token from the headers + const token = request.headers.get('Authorization')?.split(' ')[1]; + + // allow public routes to pass through + if (publicRoutes.some(route => request.nextUrl.pathname.startsWith(route))) { + return NextResponse.next(); + } + + // Check if the route should be protected + // This will apply to all API routes that start with /api/ + if (request.nextUrl.pathname.startsWith('/api/')) { + if (!token || token !== tokenToUse) { + // Return a JSON response with 401 Unauthorized + return new NextResponse(JSON.stringify({ error: 'Unauthorized' }), { + status: 401, + headers: { 'Content-Type': 'application/json' }, + }); + } + + // For authorized users, continue + return NextResponse.next(); + } + + // For non-API routes, just continue + return NextResponse.next(); +} + +// Configure which paths this middleware will run on +export const config = { + matcher: [ + // Apply to all API routes + '/api/:path*', + ], +}; diff --git a/ui/src/utils/api.ts b/ui/src/utils/api.ts new file mode 100644 index 00000000..5bf3716e --- /dev/null +++ b/ui/src/utils/api.ts @@ -0,0 +1,31 @@ +import axios from 'axios'; +import { createGlobalState } from 'react-global-hooks'; + +export const isAuthorizedState = createGlobalState(false); + +export const apiClient = axios.create(); + +// Add a request interceptor to add token from localStorage +apiClient.interceptors.request.use(config => { + const token = localStorage.getItem('AI_TOOLKIT_AUTH'); + if (token) { + config.headers['Authorization'] = `Bearer ${token}`; + } + return config; +}); + +// Add a response interceptor to handle 401 errors +apiClient.interceptors.response.use( + response => response, // Return successful responses as-is + error => { + // Check if the error is a 401 Unauthorized + if (error.response && error.response.status === 401) { + // Clear the auth token from localStorage + localStorage.removeItem('AI_TOOLKIT_AUTH'); + isAuthorizedState.set(false); + } + + // Reject the promise with the error so calling code can still catch it + return Promise.reject(error); + }, +); diff --git a/ui/src/utils/basic.ts b/ui/src/utils/basic.ts index a06e7ee6..29bff697 100644 --- a/ui/src/utils/basic.ts +++ b/ui/src/utils/basic.ts @@ -2,3 +2,4 @@ export const objectCopy = (obj: T): T => { return JSON.parse(JSON.stringify(obj)) as T; }; +export const wait = (ms: number) => new Promise(resolve => setTimeout(resolve, ms)); diff --git a/ui/src/utils/hooks.tsx b/ui/src/utils/hooks.tsx index a3a66b12..f96af344 100644 --- a/ui/src/utils/hooks.tsx +++ b/ui/src/utils/hooks.tsx @@ -79,10 +79,10 @@ export function useNestedState(initialState: T): [T, (value: any, path?: stri const setValue = React.useCallback((value: any, path?: string) => { if (path === undefined) { setState(value); - return + return; } setState(prevState => setNestedValue(prevState, value, path)); }, []); return [state, setValue]; -} \ No newline at end of file +} diff --git a/ui/src/utils/jobs.ts b/ui/src/utils/jobs.ts index 93624d12..3a74870e 100644 --- a/ui/src/utils/jobs.ts +++ b/ui/src/utils/jobs.ts @@ -1,10 +1,12 @@ import { JobConfig } from '@/types'; import { Job } from '@prisma/client'; +import { apiClient } from '@/utils/api'; export const startJob = (jobID: string) => { return new Promise((resolve, reject) => { - fetch(`/api/jobs/${jobID}/start`) - .then(res => res.json()) + apiClient + .get(`/api/jobs/${jobID}/start`) + .then(res => res.data) .then(data => { console.log('Job started:', data); resolve(); @@ -18,8 +20,9 @@ export const startJob = (jobID: string) => { export const stopJob = (jobID: string) => { return new Promise((resolve, reject) => { - fetch(`/api/jobs/${jobID}/stop`) - .then(res => res.json()) + apiClient + .get(`/api/jobs/${jobID}/stop`) + .then(res => res.data) .then(data => { console.log('Job stopped:', data); resolve(); @@ -33,8 +36,9 @@ export const stopJob = (jobID: string) => { export const deleteJob = (jobID: string) => { return new Promise((resolve, reject) => { - fetch(`/api/jobs/${jobID}/delete`) - .then(res => res.json()) + apiClient + .get(`/api/jobs/${jobID}/delete`) + .then(res => res.data) .then(data => { console.log('Job deleted:', data); resolve(); @@ -67,9 +71,9 @@ export const getAvaliableJobActions = (job: Job) => { export const getNumberOfSamples = (job: Job) => { const jobConfig = getJobConfig(job); return jobConfig.config.process[0].sample?.prompts?.length || 0; -} +}; export const getTotalSteps = (job: Job) => { const jobConfig = getJobConfig(job); return jobConfig.config.process[0].train.steps; -} +}; diff --git a/ui/tailwind.config.ts b/ui/tailwind.config.ts index 31f4dc3e..433a6ade 100644 --- a/ui/tailwind.config.ts +++ b/ui/tailwind.config.ts @@ -1,26 +1,26 @@ -import type { Config } from "tailwindcss"; +import type { Config } from 'tailwindcss'; const config: Config = { content: [ - "./src/pages/**/*.{js,ts,jsx,tsx,mdx}", - "./src/components/**/*.{js,ts,jsx,tsx,mdx}", - "./src/app/**/*.{js,ts,jsx,tsx,mdx}", + './src/pages/**/*.{js,ts,jsx,tsx,mdx}', + './src/components/**/*.{js,ts,jsx,tsx,mdx}', + './src/app/**/*.{js,ts,jsx,tsx,mdx}', ], - darkMode: "class", + darkMode: 'class', theme: { extend: { colors: { gray: { - 950: "#0a0a0a", - 900: "#171717", - 800: "#262626", - 700: "#404040", - 600: "#525252", - 500: "#737373", - 400: "#a3a3a3", - 300: "#d4d4d4", - 200: "#e5e5e5", - 100: "#f5f5f5", + 950: '#0a0a0a', + 900: '#171717', + 800: '#262626', + 700: '#404040', + 600: '#525252', + 500: '#737373', + 400: '#a3a3a3', + 300: '#d4d4d4', + 200: '#e5e5e5', + 100: '#f5f5f5', }, }, }, @@ -28,4 +28,4 @@ const config: Config = { plugins: [], }; -export default config; \ No newline at end of file +export default config;