mirror of
https://github.com/ostris/ai-toolkit.git
synced 2026-04-28 02:01:29 +00:00
Setup a very basic ui
This commit is contained in:
38
ui/src/app/api/settings/route.ts
Normal file
38
ui/src/app/api/settings/route.ts
Normal file
@@ -0,0 +1,38 @@
|
||||
import { NextResponse } from 'next/server';
|
||||
import { PrismaClient } from '@prisma/client';
|
||||
|
||||
const prisma = new PrismaClient();
|
||||
|
||||
export async function GET() {
|
||||
try {
|
||||
const settings = await prisma.settings.findMany();
|
||||
return NextResponse.json(settings.reduce((acc, curr) => ({ ...acc, [curr.key]: curr.value }), {}));
|
||||
} catch (error) {
|
||||
return NextResponse.json({ error: 'Failed to fetch settings' }, { status: 500 });
|
||||
}
|
||||
}
|
||||
|
||||
export async function POST(request: Request) {
|
||||
try {
|
||||
const body = await request.json();
|
||||
const { HF_TOKEN, TRAINING_FOLDER } = body;
|
||||
|
||||
// Upsert both settings
|
||||
await Promise.all([
|
||||
prisma.settings.upsert({
|
||||
where: { key: 'HF_TOKEN' },
|
||||
update: { value: HF_TOKEN },
|
||||
create: { key: 'HF_TOKEN', value: HF_TOKEN },
|
||||
}),
|
||||
prisma.settings.upsert({
|
||||
where: { key: 'TRAINING_FOLDER' },
|
||||
update: { value: TRAINING_FOLDER },
|
||||
create: { key: 'TRAINING_FOLDER', value: TRAINING_FOLDER },
|
||||
}),
|
||||
]);
|
||||
|
||||
return NextResponse.json({ success: true });
|
||||
} catch (error) {
|
||||
return NextResponse.json({ error: 'Failed to update settings' }, { status: 500 });
|
||||
}
|
||||
}
|
||||
55
ui/src/app/api/training/route.ts
Normal file
55
ui/src/app/api/training/route.ts
Normal file
@@ -0,0 +1,55 @@
|
||||
import { NextResponse } from 'next/server';
|
||||
import { PrismaClient } from '@prisma/client';
|
||||
|
||||
const prisma = new PrismaClient();
|
||||
|
||||
export async function GET(request: Request) {
|
||||
const { searchParams } = new URL(request.url);
|
||||
const id = searchParams.get('id');
|
||||
|
||||
try {
|
||||
if (id) {
|
||||
const training = await prisma.training.findUnique({
|
||||
where: { id },
|
||||
});
|
||||
return NextResponse.json(training);
|
||||
}
|
||||
|
||||
const trainings = await prisma.training.findMany({
|
||||
orderBy: { created_at: 'desc' },
|
||||
});
|
||||
return NextResponse.json(trainings);
|
||||
} catch (error) {
|
||||
return NextResponse.json({ error: 'Failed to fetch training data' }, { status: 500 });
|
||||
}
|
||||
}
|
||||
|
||||
export async function POST(request: Request) {
|
||||
try {
|
||||
const body = await request.json();
|
||||
const { id, name, run_data } = body;
|
||||
|
||||
if (id) {
|
||||
// Update existing training
|
||||
const training = await prisma.training.update({
|
||||
where: { id },
|
||||
data: {
|
||||
name,
|
||||
run_data: JSON.stringify(run_data),
|
||||
},
|
||||
});
|
||||
return NextResponse.json(training);
|
||||
} else {
|
||||
// Create new training
|
||||
const training = await prisma.training.create({
|
||||
data: {
|
||||
name,
|
||||
run_data: JSON.stringify(run_data),
|
||||
},
|
||||
});
|
||||
return NextResponse.json(training);
|
||||
}
|
||||
} catch (error) {
|
||||
return NextResponse.json({ error: 'Failed to save training data' }, { status: 500 });
|
||||
}
|
||||
}
|
||||
15
ui/src/app/dashboard/page.tsx
Normal file
15
ui/src/app/dashboard/page.tsx
Normal file
@@ -0,0 +1,15 @@
|
||||
export default function Dashboard() {
|
||||
return (
|
||||
<div className="space-y-6">
|
||||
<h1 className="text-3xl font-bold">Dashboard</h1>
|
||||
<div className="grid grid-cols-1 md:grid-cols-2 lg:grid-cols-3 gap-6">
|
||||
{[1, 2, 3].map(i => (
|
||||
<div key={i} className="p-6 bg-gray-800 rounded-lg">
|
||||
<h2 className="text-xl font-semibold mb-2">Card {i}</h2>
|
||||
<p className="text-gray-400">Example dashboard card content</p>
|
||||
</div>
|
||||
))}
|
||||
</div>
|
||||
</div>
|
||||
);
|
||||
}
|
||||
BIN
ui/src/app/favicon.ico
Normal file
BIN
ui/src/app/favicon.ico
Normal file
Binary file not shown.
|
After Width: | Height: | Size: 25 KiB |
21
ui/src/app/globals.css
Normal file
21
ui/src/app/globals.css
Normal file
@@ -0,0 +1,21 @@
|
||||
@tailwind base;
|
||||
@tailwind components;
|
||||
@tailwind utilities;
|
||||
|
||||
:root {
|
||||
--background: #ffffff;
|
||||
--foreground: #171717;
|
||||
}
|
||||
|
||||
@media (prefers-color-scheme: dark) {
|
||||
:root {
|
||||
--background: #0a0a0a;
|
||||
--foreground: #ededed;
|
||||
}
|
||||
}
|
||||
|
||||
body {
|
||||
color: var(--foreground);
|
||||
background: var(--background);
|
||||
font-family: Arial, Helvetica, sans-serif;
|
||||
}
|
||||
27
ui/src/app/layout.tsx
Normal file
27
ui/src/app/layout.tsx
Normal file
@@ -0,0 +1,27 @@
|
||||
import type { Metadata } from 'next';
|
||||
import { Inter } from 'next/font/google';
|
||||
import './globals.css';
|
||||
import Sidebar from '@/components/Sidebar';
|
||||
import { ThemeProvider } from '@/components/ThemeProvider';
|
||||
|
||||
const inter = Inter({ subsets: ['latin'] });
|
||||
|
||||
export const metadata: Metadata = {
|
||||
title: 'Ostris - AI Toolkit',
|
||||
description: 'A toolkit for building AI things.',
|
||||
};
|
||||
|
||||
export default function RootLayout({ children }: { children: React.ReactNode }) {
|
||||
return (
|
||||
<html lang="en" className="dark">
|
||||
<body className={inter.className}>
|
||||
<ThemeProvider>
|
||||
<div className="flex h-screen bg-gray-950">
|
||||
<Sidebar />
|
||||
<main className="flex-1 p-8 overflow-auto bg-gray-950 text-gray-100">{children}</main>
|
||||
</div>
|
||||
</ThemeProvider>
|
||||
</body>
|
||||
</html>
|
||||
);
|
||||
}
|
||||
5
ui/src/app/page.tsx
Normal file
5
ui/src/app/page.tsx
Normal file
@@ -0,0 +1,5 @@
|
||||
import { redirect } from 'next/navigation';
|
||||
|
||||
export default function Home() {
|
||||
redirect('/dashboard');
|
||||
}
|
||||
114
ui/src/app/settings/page.tsx
Normal file
114
ui/src/app/settings/page.tsx
Normal file
@@ -0,0 +1,114 @@
|
||||
'use client';
|
||||
|
||||
import { useEffect, useState } from 'react';
|
||||
|
||||
export default function Settings() {
|
||||
const [settings, setSettings] = useState({
|
||||
HF_TOKEN: '',
|
||||
TRAINING_FOLDER: '',
|
||||
});
|
||||
const [status, setStatus] = useState<'idle' | 'saving' | 'success' | 'error'>('idle');
|
||||
|
||||
useEffect(() => {
|
||||
// Fetch current settings
|
||||
fetch('/api/settings')
|
||||
.then(res => res.json())
|
||||
.then(data => {
|
||||
setSettings({
|
||||
HF_TOKEN: data.HF_TOKEN || '',
|
||||
TRAINING_FOLDER: data.TRAINING_FOLDER || '',
|
||||
});
|
||||
})
|
||||
.catch(error => console.error('Error fetching settings:', error));
|
||||
}, []);
|
||||
|
||||
const handleSubmit = async (e: React.FormEvent) => {
|
||||
e.preventDefault();
|
||||
setStatus('saving');
|
||||
|
||||
try {
|
||||
const response = await fetch('/api/settings', {
|
||||
method: 'POST',
|
||||
headers: {
|
||||
'Content-Type': 'application/json',
|
||||
},
|
||||
body: JSON.stringify(settings),
|
||||
});
|
||||
|
||||
if (!response.ok) throw new Error('Failed to save settings');
|
||||
|
||||
setStatus('success');
|
||||
setTimeout(() => setStatus('idle'), 2000);
|
||||
} catch (error) {
|
||||
console.error('Error saving settings:', error);
|
||||
setStatus('error');
|
||||
setTimeout(() => setStatus('idle'), 2000);
|
||||
}
|
||||
};
|
||||
|
||||
const handleChange = (e: React.ChangeEvent<HTMLInputElement>) => {
|
||||
const { name, value } = e.target;
|
||||
setSettings(prev => ({ ...prev, [name]: value }));
|
||||
};
|
||||
|
||||
return (
|
||||
<div className="max-w-2xl mx-auto space-y-6">
|
||||
<h1 className="text-3xl font-bold mb-8">Settings</h1>
|
||||
|
||||
<form onSubmit={handleSubmit} className="space-y-6">
|
||||
<div className="space-y-4">
|
||||
<div>
|
||||
<label htmlFor="HF_TOKEN" className="block text-sm font-medium mb-2">
|
||||
Hugging Face Token
|
||||
<div className="text-gray-500 text-sm ml-1">
|
||||
Create a Read token on{' '}
|
||||
<a href="https://huggingface.co/settings/tokens" target="_blank" rel="noreferrer">
|
||||
{' '}
|
||||
Huggingface
|
||||
</a> if you need to access gated/private models.
|
||||
</div>
|
||||
</label>
|
||||
<input
|
||||
type="password"
|
||||
id="HF_TOKEN"
|
||||
name="HF_TOKEN"
|
||||
value={settings.HF_TOKEN}
|
||||
onChange={handleChange}
|
||||
className="w-full px-4 py-2 bg-gray-800 border border-gray-700 rounded-lg focus:ring-2 focus:ring-gray-600 focus:border-transparent"
|
||||
placeholder="Enter your Hugging Face token"
|
||||
/>
|
||||
</div>
|
||||
|
||||
<div>
|
||||
<label htmlFor="TRAINING_FOLDER" className="block text-sm font-medium mb-2">
|
||||
Training Folder Path
|
||||
<div className="text-gray-500 text-sm ml-1">
|
||||
We will store your training information here. Must be an absolute path. If blank, it will default to the output folder in the project root.
|
||||
</div>
|
||||
</label>
|
||||
<input
|
||||
type="text"
|
||||
id="TRAINING_FOLDER"
|
||||
name="TRAINING_FOLDER"
|
||||
value={settings.TRAINING_FOLDER}
|
||||
onChange={handleChange}
|
||||
className="w-full px-4 py-2 bg-gray-800 border border-gray-700 rounded-lg focus:ring-2 focus:ring-gray-600 focus:border-transparent"
|
||||
placeholder="Enter training folder path"
|
||||
/>
|
||||
</div>
|
||||
</div>
|
||||
|
||||
<button
|
||||
type="submit"
|
||||
disabled={status === 'saving'}
|
||||
className="w-full px-4 py-2 bg-gray-700 hover:bg-gray-600 rounded-lg transition-colors disabled:opacity-50 disabled:cursor-not-allowed"
|
||||
>
|
||||
{status === 'saving' ? 'Saving...' : 'Save Settings'}
|
||||
</button>
|
||||
|
||||
{status === 'success' && <p className="text-green-500 text-center">Settings saved successfully!</p>}
|
||||
{status === 'error' && <p className="text-red-500 text-center">Error saving settings. Please try again.</p>}
|
||||
</form>
|
||||
</div>
|
||||
);
|
||||
}
|
||||
37
ui/src/app/train/options.ts
Normal file
37
ui/src/app/train/options.ts
Normal file
@@ -0,0 +1,37 @@
|
||||
export interface Model {
|
||||
name_or_path: string;
|
||||
model_kwargs?: Record<string, boolean>;
|
||||
train_kwargs?: Record<string, boolean>;
|
||||
}
|
||||
|
||||
export interface Option {
|
||||
model: Model[];
|
||||
}
|
||||
|
||||
|
||||
export const options = {
|
||||
model: [
|
||||
{
|
||||
name_or_path: "ostris/Flex.1-alpha",
|
||||
model_kwargs: {
|
||||
"is_flux": true
|
||||
},
|
||||
train_kwargs: {
|
||||
"bypass_guidance_embedding": true
|
||||
}
|
||||
},
|
||||
{
|
||||
name_or_path: "black-forest-labs/FLUX.1-dev",
|
||||
model_kwargs: {
|
||||
"is_flux": true
|
||||
},
|
||||
},
|
||||
{
|
||||
name_or_path: "Alpha-VLLM/Lumina-Image-2.0",
|
||||
model_kwargs: {
|
||||
"is_lumina2": true
|
||||
},
|
||||
},
|
||||
]
|
||||
|
||||
} as Option;
|
||||
150
ui/src/app/train/page.tsx
Normal file
150
ui/src/app/train/page.tsx
Normal file
@@ -0,0 +1,150 @@
|
||||
'use client';
|
||||
|
||||
import { useEffect, useState } from 'react';
|
||||
import { useSearchParams, useRouter } from 'next/navigation';
|
||||
import { options } from './options';
|
||||
|
||||
interface TrainingData {
|
||||
modelConfig: {
|
||||
name_or_path: string;
|
||||
steps: number;
|
||||
batchSize: number;
|
||||
learningRate: number;
|
||||
};
|
||||
}
|
||||
|
||||
const defaultTrainingData: TrainingData = {
|
||||
modelConfig: {
|
||||
name_or_path: 'ostris/Flex.1-alpha',
|
||||
steps: 100,
|
||||
batchSize: 32,
|
||||
learningRate: 0.001,
|
||||
},
|
||||
};
|
||||
|
||||
export default function TrainingForm() {
|
||||
const router = useRouter();
|
||||
const searchParams = useSearchParams();
|
||||
const runId = searchParams.get('id');
|
||||
|
||||
const [name, setName] = useState('');
|
||||
const [trainingData, setTrainingData] = useState<TrainingData>(defaultTrainingData);
|
||||
const [status, setStatus] = useState<'idle' | 'saving' | 'success' | 'error'>('idle');
|
||||
|
||||
useEffect(() => {
|
||||
if (runId) {
|
||||
fetch(`/api/training?id=${runId}`)
|
||||
.then(res => res.json())
|
||||
.then(data => {
|
||||
setName(data.name);
|
||||
setTrainingData(JSON.parse(data.run_data));
|
||||
})
|
||||
.catch(error => console.error('Error fetching training:', error));
|
||||
}
|
||||
}, [runId]);
|
||||
|
||||
const handleSubmit = async (e: React.FormEvent) => {
|
||||
e.preventDefault();
|
||||
setStatus('saving');
|
||||
|
||||
try {
|
||||
const response = await fetch('/api/training', {
|
||||
method: 'POST',
|
||||
headers: {
|
||||
'Content-Type': 'application/json',
|
||||
},
|
||||
body: JSON.stringify({
|
||||
id: runId,
|
||||
name,
|
||||
run_data: trainingData,
|
||||
}),
|
||||
});
|
||||
|
||||
if (!response.ok) throw new Error('Failed to save training');
|
||||
|
||||
setStatus('success');
|
||||
if (!runId) {
|
||||
const data = await response.json();
|
||||
router.push(`/training?id=${data.id}`);
|
||||
}
|
||||
setTimeout(() => setStatus('idle'), 2000);
|
||||
} catch (error) {
|
||||
console.error('Error saving training:', error);
|
||||
setStatus('error');
|
||||
setTimeout(() => setStatus('idle'), 2000);
|
||||
}
|
||||
};
|
||||
|
||||
const updateSection = (section: keyof TrainingData, data: any) => {
|
||||
setTrainingData(prev => ({
|
||||
...prev,
|
||||
[section]: { ...prev[section], ...data },
|
||||
}));
|
||||
};
|
||||
|
||||
const modelOptions = options.model.map(model => model.name_or_path);
|
||||
|
||||
return (
|
||||
<div className="max-w-4xl mx-auto space-y-8 pb-12">
|
||||
<h1 className="text-3xl font-bold mb-8">{runId ? 'Edit Training Run' : 'New Training Run'}</h1>
|
||||
|
||||
<form onSubmit={handleSubmit} className="space-y-8">
|
||||
<div className="space-y-4">
|
||||
<label htmlFor="name" className="block text-sm font-medium mb-2">
|
||||
Training Name
|
||||
</label>
|
||||
<input
|
||||
type="text"
|
||||
id="name"
|
||||
value={name}
|
||||
onChange={e => setName(e.target.value)}
|
||||
className="w-full px-4 py-2 bg-gray-800 border border-gray-700 rounded-lg focus:ring-2 focus:ring-gray-600 focus:border-transparent"
|
||||
placeholder="Enter training name"
|
||||
required
|
||||
/>
|
||||
</div>
|
||||
|
||||
{/* Model Configuration Section */}
|
||||
<section className="space-y-4 p-6 bg-gray-900 rounded-lg">
|
||||
<h2 className="text-xl font-bold mb-4">Model Configuration</h2>
|
||||
<div className="grid grid-cols-2 gap-4">
|
||||
<div>
|
||||
<label className="block text-sm font-medium mb-2">Model</label>
|
||||
<select
|
||||
value={trainingData.modelConfig.name_or_path}
|
||||
onChange={e => updateSection('modelConfig', { name_or_path: e.target.value })}
|
||||
className="w-full px-4 py-2 bg-gray-800 border border-gray-700 rounded-lg"
|
||||
>
|
||||
{modelOptions.map(model => (
|
||||
<option key={model} value={model}>
|
||||
{model}
|
||||
</option>
|
||||
))}
|
||||
</select>
|
||||
</div>
|
||||
<div>
|
||||
<label className="block text-sm font-medium mb-2">Epochs</label>
|
||||
<input
|
||||
type="number"
|
||||
value={trainingData.modelConfig.steps}
|
||||
onChange={e => updateSection('modelConfig', { steps: Number(e.target.value) })}
|
||||
className="w-full px-4 py-2 bg-gray-800 border border-gray-700 rounded-lg"
|
||||
/>
|
||||
</div>
|
||||
</div>
|
||||
</section>
|
||||
|
||||
<button
|
||||
type="submit"
|
||||
disabled={status === 'saving'}
|
||||
className="w-full px-4 py-2 bg-gray-700 hover:bg-gray-600 rounded-lg transition-colors disabled:opacity-50 disabled:cursor-not-allowed"
|
||||
>
|
||||
{status === 'saving' ? 'Saving...' : runId ? 'Update Training' : 'Create Training'}
|
||||
</button>
|
||||
|
||||
{status === 'success' && <p className="text-green-500 text-center">Training saved successfully!</p>}
|
||||
{status === 'error' && <p className="text-red-500 text-center">Error saving training. Please try again.</p>}
|
||||
</form>
|
||||
</div>
|
||||
);
|
||||
}
|
||||
38
ui/src/components/Sidebar.tsx
Normal file
38
ui/src/components/Sidebar.tsx
Normal file
@@ -0,0 +1,38 @@
|
||||
import Link from 'next/link';
|
||||
import { Home, Settings, BarChart2, BrainCircuit } from 'lucide-react';
|
||||
|
||||
const Sidebar = () => {
|
||||
const navigation = [
|
||||
{ name: 'Dashboard', href: '/dashboard', icon: Home },
|
||||
{ name: 'Train', href: '/train', icon: BrainCircuit },
|
||||
{ name: 'Settings', href: '/settings', icon: Settings },
|
||||
];
|
||||
|
||||
return (
|
||||
<div className="flex flex-col w-64 bg-gray-900 text-gray-100">
|
||||
<div className="p-4">
|
||||
<h1 className="text-xl">
|
||||
<img src="/ostris_logo.png" alt="Ostris AI Toolkit" className="w-auto h-8 mr-3 inline" />
|
||||
Ostris - AI Toolkit
|
||||
</h1>
|
||||
</div>
|
||||
<nav className="flex-1">
|
||||
<ul className="px-2 py-4 space-y-2">
|
||||
{navigation.map(item => (
|
||||
<li key={item.name}>
|
||||
<Link
|
||||
href={item.href}
|
||||
className="flex items-center px-4 py-2 text-gray-300 hover:bg-gray-800 rounded-lg transition-colors"
|
||||
>
|
||||
<item.icon className="w-5 h-5 mr-3" />
|
||||
{item.name}
|
||||
</Link>
|
||||
</li>
|
||||
))}
|
||||
</ul>
|
||||
</nav>
|
||||
</div>
|
||||
);
|
||||
};
|
||||
|
||||
export default Sidebar;
|
||||
11
ui/src/components/ThemeProvider.tsx
Normal file
11
ui/src/components/ThemeProvider.tsx
Normal file
@@ -0,0 +1,11 @@
|
||||
'use client';
|
||||
|
||||
import { createContext, useContext, useEffect, useState } from 'react';
|
||||
|
||||
const ThemeContext = createContext({ isDark: true });
|
||||
|
||||
export const ThemeProvider = ({ children }: { children: React.ReactNode }) => {
|
||||
const [isDark, setIsDark] = useState(true);
|
||||
|
||||
return <ThemeContext.Provider value={{ isDark }}>{children}</ThemeContext.Provider>;
|
||||
};
|
||||
Reference in New Issue
Block a user