+
{
const { label, value, onChange, placeholder, required, min, max } = props;
-
+
// Add controlled internal state to properly handle partial inputs
const [inputValue, setInputValue] = React.useState(value ?? '');
@@ -66,7 +66,7 @@ export const NumberInput = (props: NumberInputProps) => {
value={inputValue}
onChange={e => {
const rawValue = e.target.value;
-
+
// Update the input display with the raw value
setInputValue(rawValue);
@@ -81,7 +81,7 @@ export const NumberInput = (props: NumberInputProps) => {
// Only apply constraints and call onChange when we have a valid number
if (!isNaN(numValue)) {
let constrainedValue = numValue;
-
+
// Apply min/max constraints if they exist
if (min !== undefined && constrainedValue < min) {
constrainedValue = min;
@@ -89,7 +89,7 @@ export const NumberInput = (props: NumberInputProps) => {
if (max !== undefined && constrainedValue > max) {
constrainedValue = max;
}
-
+
onChange(constrainedValue);
}
}}
@@ -152,14 +152,14 @@ export const Checkbox = (props: CheckboxProps) => {
className={classNames(
'relative inline-flex h-6 w-11 flex-shrink-0 cursor-pointer rounded-full border-2 border-transparent transition-colors duration-200 ease-in-out focus:outline-none focus:ring-2 focus:ring-blue-600 focus:ring-offset-2',
checked ? 'bg-blue-600' : 'bg-gray-700',
- disabled ? 'opacity-50 cursor-not-allowed' : 'hover:bg-opacity-80'
+ disabled ? 'opacity-50 cursor-not-allowed' : 'hover:bg-opacity-80',
)}
>
Toggle {label}
@@ -168,7 +168,7 @@ export const Checkbox = (props: CheckboxProps) => {
htmlFor={id}
className={classNames(
'text-sm font-medium cursor-pointer select-none',
- disabled ? 'text-gray-500' : 'text-gray-300'
+ disabled ? 'text-gray-500' : 'text-gray-300',
)}
>
{label}
diff --git a/ui/src/hooks/useDatasetList.tsx b/ui/src/hooks/useDatasetList.tsx
index 5760a947..480e36d0 100644
--- a/ui/src/hooks/useDatasetList.tsx
+++ b/ui/src/hooks/useDatasetList.tsx
@@ -1,6 +1,7 @@
'use client';
import { useEffect, useState } from 'react';
+import { apiClient } from '@/utils/api';
export default function useDatasetList() {
const [datasets, setDatasets] = useState([]);
@@ -8,8 +9,9 @@ export default function useDatasetList() {
const refreshDatasets = () => {
setStatus('loading');
- fetch('/api/datasets/list')
- .then(res => res.json())
+ apiClient
+ .get('/api/datasets/list')
+ .then(res => res.data)
.then(data => {
console.log('Datasets:', data);
// sort
diff --git a/ui/src/hooks/useFilesList.tsx b/ui/src/hooks/useFilesList.tsx
index 7268a3cb..e73e6c69 100644
--- a/ui/src/hooks/useFilesList.tsx
+++ b/ui/src/hooks/useFilesList.tsx
@@ -1,6 +1,7 @@
'use client';
import { useEffect, useState, useRef } from 'react';
+import { apiClient } from '@/utils/api';
interface FileObject {
path: string;
@@ -18,8 +19,9 @@ export default function useFilesList(jobID: string, reloadInterval: null | numbe
loadStatus = 'refreshing';
}
setStatus(loadStatus);
- fetch(`/api/jobs/${jobID}/files`)
- .then(res => res.json())
+ apiClient
+ .get(`/api/jobs/${jobID}/files`)
+ .then(res => res.data)
.then(data => {
console.log('Fetched files:', data);
if (data.files) {
diff --git a/ui/src/hooks/useGPUInfo.tsx b/ui/src/hooks/useGPUInfo.tsx
index 5f2eda38..b8f60405 100644
--- a/ui/src/hooks/useGPUInfo.tsx
+++ b/ui/src/hooks/useGPUInfo.tsx
@@ -2,6 +2,7 @@
import { GPUApiResponse, GpuInfo } from '@/types';
import { useEffect, useState } from 'react';
+import { apiClient } from '@/utils/api';
export default function useGPUInfo(gpuIds: null | number[] = null, reloadInterval: null | number = null) {
const [gpuList, setGpuList] = useState([]);
@@ -11,18 +12,11 @@ export default function useGPUInfo(gpuIds: null | number[] = null, reloadInterva
const fetchGpuInfo = async () => {
setStatus('loading');
try {
- const response = await fetch('/api/gpu');
-
- if (!response.ok) {
- throw new Error(`HTTP error! Status: ${response.status}`);
- }
-
- const data: GPUApiResponse = await response.json();
+ const data: GPUApiResponse = await apiClient.get('/api/gpu').then(res => res.data);
let gpus = data.gpus.sort((a, b) => a.index - b.index);
if (gpuIds) {
gpus = gpus.filter(gpu => gpuIds.includes(gpu.index));
}
-
setGpuList(gpus);
setStatus('success');
} catch (err) {
@@ -51,4 +45,4 @@ export default function useGPUInfo(gpuIds: null | number[] = null, reloadInterva
}, [gpuIds, reloadInterval]); // Added dependencies
return { gpuList, setGpuList, isGPUInfoLoaded, status, refreshGpuInfo: fetchGpuInfo };
-}
\ No newline at end of file
+}
diff --git a/ui/src/hooks/useJob.tsx b/ui/src/hooks/useJob.tsx
index e4318233..5c43f9e5 100644
--- a/ui/src/hooks/useJob.tsx
+++ b/ui/src/hooks/useJob.tsx
@@ -2,6 +2,7 @@
import { useEffect, useState } from 'react';
import { Job } from '@prisma/client';
+import { apiClient } from '@/utils/api';
export default function useJob(jobID: string, reloadInterval: null | number = null) {
const [job, setJob] = useState(null);
@@ -9,8 +10,9 @@ export default function useJob(jobID: string, reloadInterval: null | number = nu
const refreshJob = () => {
setStatus('loading');
- fetch(`/api/jobs?id=${jobID}`)
- .then(res => res.json())
+ apiClient
+ .get(`/api/jobs?id=${jobID}`)
+ .then(res => res.data)
.then(data => {
console.log('Job:', data);
setJob(data);
@@ -32,7 +34,7 @@ export default function useJob(jobID: string, reloadInterval: null | number = nu
return () => {
clearInterval(interval);
- }
+ };
}
}, [jobID]);
diff --git a/ui/src/hooks/useJobsList.tsx b/ui/src/hooks/useJobsList.tsx
index a1c3d2d7..6f1e3af9 100644
--- a/ui/src/hooks/useJobsList.tsx
+++ b/ui/src/hooks/useJobsList.tsx
@@ -2,6 +2,7 @@
import { useEffect, useState } from 'react';
import { Job } from '@prisma/client';
+import { apiClient } from '@/utils/api';
export default function useJobsList(onlyActive = false) {
const [jobs, setJobs] = useState([]);
@@ -9,8 +10,9 @@ export default function useJobsList(onlyActive = false) {
const refreshJobs = () => {
setStatus('loading');
- fetch('/api/jobs')
- .then(res => res.json())
+ apiClient
+ .get('/api/jobs')
+ .then(res => res.data)
.then(data => {
console.log('Jobs:', data);
if (data.error) {
diff --git a/ui/src/hooks/useSampleImages.tsx b/ui/src/hooks/useSampleImages.tsx
index ccf07493..8b79a8d6 100644
--- a/ui/src/hooks/useSampleImages.tsx
+++ b/ui/src/hooks/useSampleImages.tsx
@@ -1,6 +1,7 @@
'use client';
import { useEffect, useState } from 'react';
+import { apiClient } from '@/utils/api';
export default function useSampleImages(jobID: string, reloadInterval: null | number = null) {
const [sampleImages, setSampleImages] = useState([]);
@@ -8,9 +9,11 @@ export default function useSampleImages(jobID: string, reloadInterval: null | nu
const refreshSampleImages = () => {
setStatus('loading');
- fetch(`/api/jobs/${jobID}/samples`)
- .then(res => res.json())
+ apiClient
+ .get(`/api/jobs/${jobID}/samples`)
+ .then(res => res.data)
.then(data => {
+ console.log('Fetched sample images:', data);
if (data.samples) {
setSampleImages(data.samples);
}
diff --git a/ui/src/hooks/useSettings.tsx b/ui/src/hooks/useSettings.tsx
index 7d17bbd5..35fcc538 100644
--- a/ui/src/hooks/useSettings.tsx
+++ b/ui/src/hooks/useSettings.tsx
@@ -1,6 +1,7 @@
'use client';
import { useEffect, useState } from 'react';
+import { apiClient } from '@/utils/api';
export interface Settings {
HF_TOKEN: string;
@@ -16,10 +17,11 @@ export default function useSettings() {
});
const [isSettingsLoaded, setIsLoaded] = useState(false);
useEffect(() => {
- // Fetch current settings
- fetch('/api/settings')
- .then(res => res.json())
+ apiClient
+ .get('/api/settings')
+ .then(res => res.data)
.then(data => {
+ console.log('Settings:', data);
setSettings({
HF_TOKEN: data.HF_TOKEN || '',
TRAINING_FOLDER: data.TRAINING_FOLDER || '',
diff --git a/ui/src/middleware.ts b/ui/src/middleware.ts
new file mode 100644
index 00000000..bf198d1e
--- /dev/null
+++ b/ui/src/middleware.ts
@@ -0,0 +1,49 @@
+// middleware.ts (at the root of your project)
+import { NextResponse } from 'next/server';
+import type { NextRequest } from 'next/server';
+
+// if route starts with these, approve
+const publicRoutes = ['/api/img/', '/api/files/'];
+
+export function middleware(request: NextRequest) {
+ // check env var for AI_TOOLKIT_AUTH, if not set, approve all requests
+ // if it is set make sure bearer token matches
+ const tokenToUse = process.env.AI_TOOLKIT_AUTH || null;
+ if (!tokenToUse) {
+ return NextResponse.next();
+ }
+
+ // Get the token from the headers
+ const token = request.headers.get('Authorization')?.split(' ')[1];
+
+ // allow public routes to pass through
+ if (publicRoutes.some(route => request.nextUrl.pathname.startsWith(route))) {
+ return NextResponse.next();
+ }
+
+ // Check if the route should be protected
+ // This will apply to all API routes that start with /api/
+ if (request.nextUrl.pathname.startsWith('/api/')) {
+ if (!token || token !== tokenToUse) {
+ // Return a JSON response with 401 Unauthorized
+ return new NextResponse(JSON.stringify({ error: 'Unauthorized' }), {
+ status: 401,
+ headers: { 'Content-Type': 'application/json' },
+ });
+ }
+
+ // For authorized users, continue
+ return NextResponse.next();
+ }
+
+ // For non-API routes, just continue
+ return NextResponse.next();
+}
+
+// Configure which paths this middleware will run on
+export const config = {
+ matcher: [
+ // Apply to all API routes
+ '/api/:path*',
+ ],
+};
diff --git a/ui/src/utils/api.ts b/ui/src/utils/api.ts
new file mode 100644
index 00000000..5bf3716e
--- /dev/null
+++ b/ui/src/utils/api.ts
@@ -0,0 +1,31 @@
+import axios from 'axios';
+import { createGlobalState } from 'react-global-hooks';
+
+export const isAuthorizedState = createGlobalState(false);
+
+export const apiClient = axios.create();
+
+// Add a request interceptor to add token from localStorage
+apiClient.interceptors.request.use(config => {
+ const token = localStorage.getItem('AI_TOOLKIT_AUTH');
+ if (token) {
+ config.headers['Authorization'] = `Bearer ${token}`;
+ }
+ return config;
+});
+
+// Add a response interceptor to handle 401 errors
+apiClient.interceptors.response.use(
+ response => response, // Return successful responses as-is
+ error => {
+ // Check if the error is a 401 Unauthorized
+ if (error.response && error.response.status === 401) {
+ // Clear the auth token from localStorage
+ localStorage.removeItem('AI_TOOLKIT_AUTH');
+ isAuthorizedState.set(false);
+ }
+
+ // Reject the promise with the error so calling code can still catch it
+ return Promise.reject(error);
+ },
+);
diff --git a/ui/src/utils/basic.ts b/ui/src/utils/basic.ts
index a06e7ee6..29bff697 100644
--- a/ui/src/utils/basic.ts
+++ b/ui/src/utils/basic.ts
@@ -2,3 +2,4 @@ export const objectCopy = (obj: T): T => {
return JSON.parse(JSON.stringify(obj)) as T;
};
+export const wait = (ms: number) => new Promise(resolve => setTimeout(resolve, ms));
diff --git a/ui/src/utils/hooks.tsx b/ui/src/utils/hooks.tsx
index a3a66b12..f96af344 100644
--- a/ui/src/utils/hooks.tsx
+++ b/ui/src/utils/hooks.tsx
@@ -79,10 +79,10 @@ export function useNestedState(initialState: T): [T, (value: any, path?: stri
const setValue = React.useCallback((value: any, path?: string) => {
if (path === undefined) {
setState(value);
- return
+ return;
}
setState(prevState => setNestedValue(prevState, value, path));
}, []);
return [state, setValue];
-}
\ No newline at end of file
+}
diff --git a/ui/src/utils/jobs.ts b/ui/src/utils/jobs.ts
index 93624d12..3a74870e 100644
--- a/ui/src/utils/jobs.ts
+++ b/ui/src/utils/jobs.ts
@@ -1,10 +1,12 @@
import { JobConfig } from '@/types';
import { Job } from '@prisma/client';
+import { apiClient } from '@/utils/api';
export const startJob = (jobID: string) => {
return new Promise((resolve, reject) => {
- fetch(`/api/jobs/${jobID}/start`)
- .then(res => res.json())
+ apiClient
+ .get(`/api/jobs/${jobID}/start`)
+ .then(res => res.data)
.then(data => {
console.log('Job started:', data);
resolve();
@@ -18,8 +20,9 @@ export const startJob = (jobID: string) => {
export const stopJob = (jobID: string) => {
return new Promise((resolve, reject) => {
- fetch(`/api/jobs/${jobID}/stop`)
- .then(res => res.json())
+ apiClient
+ .get(`/api/jobs/${jobID}/stop`)
+ .then(res => res.data)
.then(data => {
console.log('Job stopped:', data);
resolve();
@@ -33,8 +36,9 @@ export const stopJob = (jobID: string) => {
export const deleteJob = (jobID: string) => {
return new Promise((resolve, reject) => {
- fetch(`/api/jobs/${jobID}/delete`)
- .then(res => res.json())
+ apiClient
+ .get(`/api/jobs/${jobID}/delete`)
+ .then(res => res.data)
.then(data => {
console.log('Job deleted:', data);
resolve();
@@ -67,9 +71,9 @@ export const getAvaliableJobActions = (job: Job) => {
export const getNumberOfSamples = (job: Job) => {
const jobConfig = getJobConfig(job);
return jobConfig.config.process[0].sample?.prompts?.length || 0;
-}
+};
export const getTotalSteps = (job: Job) => {
const jobConfig = getJobConfig(job);
return jobConfig.config.process[0].train.steps;
-}
+};
diff --git a/ui/tailwind.config.ts b/ui/tailwind.config.ts
index 31f4dc3e..433a6ade 100644
--- a/ui/tailwind.config.ts
+++ b/ui/tailwind.config.ts
@@ -1,26 +1,26 @@
-import type { Config } from "tailwindcss";
+import type { Config } from 'tailwindcss';
const config: Config = {
content: [
- "./src/pages/**/*.{js,ts,jsx,tsx,mdx}",
- "./src/components/**/*.{js,ts,jsx,tsx,mdx}",
- "./src/app/**/*.{js,ts,jsx,tsx,mdx}",
+ './src/pages/**/*.{js,ts,jsx,tsx,mdx}',
+ './src/components/**/*.{js,ts,jsx,tsx,mdx}',
+ './src/app/**/*.{js,ts,jsx,tsx,mdx}',
],
- darkMode: "class",
+ darkMode: 'class',
theme: {
extend: {
colors: {
gray: {
- 950: "#0a0a0a",
- 900: "#171717",
- 800: "#262626",
- 700: "#404040",
- 600: "#525252",
- 500: "#737373",
- 400: "#a3a3a3",
- 300: "#d4d4d4",
- 200: "#e5e5e5",
- 100: "#f5f5f5",
+ 950: '#0a0a0a',
+ 900: '#171717',
+ 800: '#262626',
+ 700: '#404040',
+ 600: '#525252',
+ 500: '#737373',
+ 400: '#a3a3a3',
+ 300: '#d4d4d4',
+ 200: '#e5e5e5',
+ 100: '#f5f5f5',
},
},
},
@@ -28,4 +28,4 @@ const config: Config = {
plugins: [],
};
-export default config;
\ No newline at end of file
+export default config;