Made it easy to add control images to the samples in the UI

This commit is contained in:
Jaret Burkett
2025-07-17 12:00:48 -06:00
parent e25d2feddf
commit 8610c6ed7f
16 changed files with 400 additions and 57 deletions

3
.gitignore vendored
View File

@@ -180,4 +180,5 @@ cython_debug/
.DS_Store
._.DS_Store
aitk_db.db
/notes.md
/notes.md
/data

View File

@@ -425,6 +425,9 @@ Everything else should work the same including layer targeting.
Only larger updates are listed here. There are usually smaller daily updates that are omitted.
### Jul 17, 2025
- Make it easy to add control images to the samples in the UI
### Jul 11, 2025
- Added better video config settings to the UI for video models.
- Added Wan I2V training to the UI

13
ui/package-lock.json generated
View File

@@ -24,6 +24,7 @@
"react-icons": "^5.5.0",
"react-select": "^5.10.1",
"sqlite3": "^5.1.7",
"uuid": "^11.1.0",
"yaml": "^2.7.0"
},
"devDependencies": {
@@ -5370,6 +5371,18 @@
"resolved": "https://registry.npmjs.org/util-deprecate/-/util-deprecate-1.0.2.tgz",
"integrity": "sha512-EPD5q1uXyFxJpCrLnCc1nHnq3gOa6DZBocAIiI2TaSCA7VCJ1UJDMagCzIkXNsUYfD1daK//LTEQ8xiIbrHtcw=="
},
"node_modules/uuid": {
"version": "11.1.0",
"resolved": "https://registry.npmjs.org/uuid/-/uuid-11.1.0.tgz",
"integrity": "sha512-0/A9rDy9P7cJ+8w1c9WD9V//9Wj15Ce2MPz8Ri6032usz+NfePxx5AcN3bN+r6ZL6jEo066/yNYB3tn4pQEx+A==",
"funding": [
"https://github.com/sponsors/broofa",
"https://github.com/sponsors/ctavan"
],
"bin": {
"uuid": "dist/esm/bin/uuid"
}
},
"node_modules/v8-compile-cache-lib": {
"version": "3.0.1",
"resolved": "https://registry.npmjs.org/v8-compile-cache-lib/-/v8-compile-cache-lib-3.0.1.tgz",

View File

@@ -28,6 +28,7 @@
"react-icons": "^5.5.0",
"react-select": "^5.10.1",
"sqlite3": "^5.1.7",
"uuid": "^11.1.0",
"yaml": "^2.7.0"
},
"devDependencies": {

View File

@@ -2,7 +2,7 @@
import { NextRequest, NextResponse } from 'next/server';
import fs from 'fs';
import path from 'path';
import { getDatasetsRoot, getTrainingFolder } from '@/server/settings';
import { getDatasetsRoot, getTrainingFolder, getDataRoot } from '@/server/settings';
export async function GET(request: NextRequest, { params }: { params: { imagePath: string } }) {
const { imagePath } = await params;
@@ -13,8 +13,9 @@ export async function GET(request: NextRequest, { params }: { params: { imagePat
// Get allowed directories
const datasetRoot = await getDatasetsRoot();
const trainingRoot = await getTrainingFolder();
const dataRoot = await getDataRoot();
const allowedDirs = [datasetRoot, trainingRoot];
const allowedDirs = [datasetRoot, trainingRoot, dataRoot];
// Security check: Ensure path is in allowed directory
const isAllowed = allowedDirs.some(allowedDir => filepath.startsWith(allowedDir)) && !filepath.includes('..');

View File

@@ -0,0 +1,58 @@
// src/app/api/img/upload/route.ts
import { NextRequest, NextResponse } from 'next/server';
import { writeFile, mkdir } from 'fs/promises';
import { join } from 'path';
import { getDataRoot } from '@/server/settings';
import {v4 as uuidv4} from 'uuid';
/**
 * Picks the extension to store an uploaded file under.
 *
 * The name comes from the client, so the part after the last dot is only used
 * when it is a short alphanumeric token; anything else (no dot at all, empty
 * suffix, path separators, odd characters) falls back to `fallback`.
 */
function getSafeExtension(fileName: string, fallback = 'jpg'): string {
  const dotIndex = fileName.lastIndexOf('.');
  if (dotIndex < 0) {
    return fallback;
  }
  const candidate = fileName.slice(dotIndex + 1);
  return /^[A-Za-z0-9]{1,10}$/.test(candidate) ? candidate : fallback;
}

/**
 * Handles a multipart upload of one or more images sent under the `files` field.
 *
 * Each file is written to `<data root>/images/<uuid>.<extension>`; the directory
 * is created on demand. Responds with `{ message, files }` where `files` holds
 * the written paths in the same order as the uploaded entries.
 *
 * Responses:
 * - 400 when the form contains no file entries under `files`
 * - 500 when the data root is unavailable or any read/write step throws
 */
export async function POST(request: NextRequest) {
  try {
    const dataRoot = await getDataRoot();
    if (!dataRoot) {
      return NextResponse.json({ error: 'Data root path not found' }, { status: 500 });
    }
    const imgRoot = join(dataRoot, 'images');

    const formData = await request.formData();
    // getAll() yields File | string entries; plain string fields are not uploads.
    const files = formData.getAll('files').filter((entry): entry is File => typeof entry !== 'string');
    if (files.length === 0) {
      return NextResponse.json({ error: 'No files provided' }, { status: 400 });
    }

    // recursive: true also makes this a no-op when the directory already exists
    await mkdir(imgRoot, { recursive: true });

    const savedFiles = await Promise.all(
      files.map(async file => {
        const buffer = Buffer.from(await file.arrayBuffer());
        // The stored name is a fresh UUID; only the (validated) extension is taken from the client.
        const filePath = join(imgRoot, `${uuidv4()}.${getSafeExtension(file.name)}`);
        await writeFile(filePath, buffer);
        return filePath;
      }),
    );

    return NextResponse.json({
      message: 'Files uploaded successfully',
      files: savedFiles,
    });
  } catch (error) {
    console.error('Upload error:', error);
    return NextResponse.json({ error: 'Error uploading files' }, { status: 500 });
  }
}
// Increase payload size limit (default is 4mb)
// Exported config object with an `api` section: body parsing is switched off and
// the response limit is set to '50mb'.
// NOTE(review): the comment above talks about the payload (request) size, but the only
// size configured here is `responseLimit` — confirm the request-body limit is really raised.
// NOTE(review): this file exports a `POST(request: NextRequest)` route handler; `config.api`
// looks like the API-route (pages) convention — confirm Next.js honors it for this handler.
export const config = {
api: {
bodyParser: false,
responseLimit: '50mb',
},
};

View File

@@ -5,6 +5,7 @@ import YAML from 'yaml';
import Editor, { OnMount } from '@monaco-editor/react';
import type { editor } from 'monaco-editor';
import { Settings } from '@/hooks/useSettings';
import { migrateJobConfig } from './jobConfig';
type Props = {
jobConfig: JobConfig;
@@ -115,6 +116,7 @@ export default function AdvancedJob({ jobConfig, setJobConfig, settings }: Props
} catch (e) {
console.warn(e);
}
migrateJobConfig(parsed);
setJobConfig(parsed);
}
} catch (e) {

View File

@@ -7,6 +7,7 @@ import { objectCopy } from '@/utils/basic';
import { TextInput, SelectInput, Checkbox, FormGroup, NumberInput } from '@/components/formInputs';
import Card from '@/components/Card';
import { X } from 'lucide-react';
import AddSingleImageModal, { openAddImageModal } from '@/components/AddSingleImageModal';
type Props = {
jobConfig: JobConfig;
@@ -116,6 +117,17 @@ export default function SimpleJob({
return newDataset;
});
setJobConfig(datasets, 'config.process[0].datasets');
// update samples
const hasSampleCtrlImg = newArch?.additionalSections?.includes('sample.ctrl_img') || false;
const samples = jobConfig.config.process[0].sample.samples.map(sample => {
const newSample = objectCopy(sample);
if (!hasSampleCtrlImg) {
delete newSample.ctrl_img; // remove ctrl_img if not applicable
}
return newSample;
});
setJobConfig(samples, 'config.process[0].sample.samples');
}}
options={groupedModelOptions}
/>
@@ -648,32 +660,58 @@ export default function SimpleJob({
</FormGroup>
</div>
</div>
<FormGroup label={`Sample Prompts (${jobConfig.config.process[0].sample.prompts.length})`} className="pt-2">
{modelArch?.additionalSections?.includes('sample.ctrl_img') && (
<div className="text-sm text-gray-100 mb-2 py-2 px-4 bg-yellow-700 rounded-lg">
<p className="font-semibold mb-1">Control Images</p>
To use control images on samples, add --ctrl_img to the prompts below.
<br />
Example: <code className="bg-yellow-900 p-1">make this a cartoon --ctrl_img /path/to/image.png</code>
</div>
)}
{jobConfig.config.process[0].sample.prompts.map((prompt, i) => (
<div key={i} className="flex items-center space-x-2">
<FormGroup label={`Sample Prompts (${jobConfig.config.process[0].sample.samples.length})`} className="pt-2">
<div></div>
</FormGroup>
{jobConfig.config.process[0].sample.samples.map((sample, i) => (
<div key={i} className="rounded-lg pl-4 pr-1 mb-4 bg-gray-950">
<div className="flex items-center space-x-2">
<div className="flex-1">
<TextInput
value={prompt}
onChange={value => setJobConfig(value, `config.process[0].sample.prompts[${i}]`)}
placeholder="Enter prompt"
required
/>
<div className="flex">
<div className="flex-1">
<TextInput
label={`Prompt`}
value={sample.prompt}
onChange={value => setJobConfig(value, `config.process[0].sample.samples[${i}].prompt`)}
placeholder="Enter prompt"
required
/>
</div>
{modelArch?.additionalSections?.includes('sample.ctrl_img') && (
<div
className="h-14 w-14 mt-2 ml-4 border border-gray-500 flex items-center justify-center rounded cursor-pointer hover:bg-gray-700 transition-colors"
style={{
backgroundImage: sample.ctrl_img
? `url(${`/api/img/${encodeURIComponent(sample.ctrl_img)}`})`
: 'none',
backgroundSize: 'cover',
backgroundPosition: 'center',
marginBottom: '-1rem',
}}
onClick={() => {
openAddImageModal(imagePath => {
console.log('Selected image path:', imagePath);
if (!imagePath) return;
setJobConfig(imagePath, `config.process[0].sample.samples[${i}].ctrl_img`);
});
}}
>
{!sample.ctrl_img && (
<div className="text-gray-400 text-xs text-center font-bold">Add Control Image</div>
)}
</div>
)}
</div>
<div className="pb-4"></div>
</div>
<div>
<button
type="button"
onClick={() =>
setJobConfig(
jobConfig.config.process[0].sample.prompts.filter((_, index) => index !== i),
'config.process[0].sample.prompts',
jobConfig.config.process[0].sample.samples.filter((_, index) => index !== i),
'config.process[0].sample.samples',
)
}
className="rounded-full p-1 text-sm"
@@ -682,23 +720,27 @@ export default function SimpleJob({
</button>
</div>
</div>
))}
<button
type="button"
onClick={() =>
setJobConfig([...jobConfig.config.process[0].sample.prompts, ''], 'config.process[0].sample.prompts')
}
className="w-full px-4 py-2 bg-gray-700 hover:bg-gray-600 rounded-lg transition-colors"
>
Add Prompt
</button>
</FormGroup>
</div>
))}
<button
type="button"
onClick={() =>
setJobConfig(
[...jobConfig.config.process[0].sample.samples, { prompt: '' }],
'config.process[0].sample.samples',
)
}
className="w-full px-4 py-2 bg-gray-700 hover:bg-gray-600 rounded-lg transition-colors"
>
Add Prompt
</button>
</Card>
</div>
{status === 'success' && <p className="text-green-500 text-center">Training saved successfully!</p>}
{status === 'error' && <p className="text-red-500 text-center">Error saving training. Please try again.</p>}
</form>
<AddSingleImageModal />
</>
);
}

View File

@@ -90,17 +90,37 @@ export const defaultJobConfig: JobConfig = {
sample_every: 250,
width: 1024,
height: 1024,
prompts: [
'woman with red hair, playing chess at the park, bomb going off in the background',
'a woman holding a coffee cup, in a beanie, sitting at a cafe',
'a horse is a DJ at a night club, fish eye lens, smoke machine, lazer lights, holding a martini',
'a man showing off his cool new t shirt at the beach, a shark is jumping out of the water in the background',
'a bear building a log cabin in the snow covered mountains',
'woman playing the guitar, on stage, singing a song, laser lights, punk rocker',
'hipster man with a beard, building a chair, in a wood shop',
'photo of a man, white background, medium shot, modeling clothing, studio lighting, white backdrop',
"a man holding a sign that says, 'this is a sign'",
'a bulldog, in a post apocalyptic world, with a shotgun, in a leather jacket, in a desert, with a motorcycle',
samples: [
{
prompt: 'woman with red hair, playing chess at the park, bomb going off in the background'
},
{
prompt: 'a woman holding a coffee cup, in a beanie, sitting at a cafe',
},
{
prompt: 'a horse is a DJ at a night club, fish eye lens, smoke machine, lazer lights, holding a martini',
},
{
prompt: 'a man showing off his cool new t shirt at the beach, a shark is jumping out of the water in the background',
},
{
prompt: 'a bear building a log cabin in the snow covered mountains',
},
{
prompt: 'woman playing the guitar, on stage, singing a song, laser lights, punk rocker',
},
{
prompt: 'hipster man with a beard, building a chair, in a wood shop',
},
{
prompt: 'photo of a man, white background, medium shot, modeling clothing, studio lighting, white backdrop',
},
{
prompt: "a man holding a sign that says, 'this is a sign'",
},
{
prompt: 'a bulldog, in a post apocalyptic world, with a shotgun, in a leather jacket, in a desert, with a motorcycle',
},
],
neg: '',
seed: 42,
@@ -118,3 +138,23 @@ export const defaultJobConfig: JobConfig = {
version: '1.0',
},
};
/**
 * Upgrades a job config from the legacy `sample.prompts` string array to the
 * `sample.samples` array of `{ prompt }` objects (first process entry only).
 *
 * The config is migrated IN PLACE and the same object is returned, so callers
 * may either use the return value or keep using the object they passed in.
 *
 * - Non-empty `prompts`: converted to `samples` (replacing any existing
 *   `samples`), then `prompts` is removed.
 * - Empty `prompts`: `prompts` is removed and `samples` is kept if it is already
 *   an array, otherwise initialised to `[]`, so consumers never see a stale
 *   empty `prompts` or a missing `samples`.
 * - No `prompts` array, or no `config.process[0].sample`: left untouched.
 *
 * @param jobConfig config to migrate; null/undefined is tolerated and returned as-is
 * @returns the same `jobConfig` reference
 */
export const migrateJobConfig = (jobConfig: JobConfig): JobConfig => {
  const sample = jobConfig?.config?.process?.[0]?.sample;
  if (!sample || !Array.isArray(sample.prompts)) {
    return jobConfig;
  }

  if (sample.prompts.length > 0) {
    // upgrade prompt strings to samples
    sample.samples = sample.prompts.map(prompt => ({ prompt }));
  } else if (!Array.isArray(sample.samples)) {
    sample.samples = [];
  }
  delete sample.prompts;

  return jobConfig;
};

View File

@@ -2,11 +2,11 @@
import { useEffect, useState } from 'react';
import { useSearchParams, useRouter } from 'next/navigation';
import { defaultJobConfig, defaultDatasetConfig } from './jobConfig';
import { defaultJobConfig, defaultDatasetConfig, migrateJobConfig } from './jobConfig';
import { JobConfig } from '@/types';
import { objectCopy } from '@/utils/basic';
import { useNestedState } from '@/utils/hooks';
import { SelectInput} from '@/components/formInputs';
import { SelectInput } from '@/components/formInputs';
import useSettings from '@/hooks/useSettings';
import useGPUInfo from '@/hooks/useGPUInfo';
import useDatasetList from '@/hooks/useDatasetList';
@@ -61,7 +61,7 @@ export default function TrainingForm() {
.then(data => {
console.log('Training:', data);
setGpuIDs(data.gpu_ids);
setJobConfig(JSON.parse(data.job_config));
setJobConfig(migrateJobConfig(JSON.parse(data.job_config)));
})
.catch(error => console.error('Error fetching training:', error));
}
@@ -181,11 +181,13 @@ export default function TrainingForm() {
</div>
) : (
<MainContent>
<ErrorBoundary fallback={
<div className="flex items-center justify-center h-64 text-lg text-red-600 font-medium bg-red-100 dark:bg-red-900/20 dark:text-red-400 border border-red-300 dark:border-red-700 rounded-lg">
Advanced job detected. Please switch to advanced view to continue.
</div>
}>
<ErrorBoundary
fallback={
<div className="flex items-center justify-center h-64 text-lg text-red-600 font-medium bg-red-100 dark:bg-red-900/20 dark:text-red-400 border border-red-300 dark:border-red-700 rounded-lg">
Advanced job detected. Please switch to advanced view to continue.
</div>
}
>
<SimpleJob
jobConfig={jobConfig}
setJobConfig={setJobConfig}
@@ -204,4 +206,4 @@ export default function TrainingForm() {
)}
</>
);
}
}

View File

@@ -0,0 +1,141 @@
'use client';
import { createGlobalState } from 'react-global-hooks';
import { Dialog, DialogBackdrop, DialogPanel, DialogTitle } from '@headlessui/react';
import { FaUpload } from 'react-icons/fa';
import { useCallback, useState } from 'react';
import { useDropzone } from 'react-dropzone';
import { apiClient } from '@/utils/api';
/** Callback that receives the chosen image path (the type also allows null). */
type ImageCompleteHandler = (imagePath: string | null) => void;

/** Payload held in the global modal state while a request to pick an image is pending. */
export interface AddSingleImageModalState {
  onComplete?: ImageCompleteHandler;
}

/** Global modal state; it starts as `null` and is populated by `openAddImageModal`. */
export const addSingleImageModalState = createGlobalState<AddSingleImageModalState | null>(null);

/**
 * Requests the add-image modal by storing the completion callback in the global state.
 *
 * @param onComplete called with the resulting image path
 */
export function openAddImageModal(onComplete: ImageCompleteHandler): void {
  addSingleImageModalState.set({ onComplete });
}
/**
 * Modal dialog that uploads a single image and reports its stored path.
 *
 * The dialog is open whenever the global `addSingleImageModalState` is non-null.
 * A dropped/selected image is posted to `/api/img/upload`; on success the first
 * returned path (or null) is handed to the pending `onComplete` callback and the
 * modal closes. While an upload is running the modal cannot be cancelled.
 * Upload failures are logged and leave the modal open.
 */
export default function AddSingleImageModal() {
  const [addSingleImageModalInfo, setAddSingleImageModalInfo] = addSingleImageModalState.use();
  const [uploadProgress, setUploadProgress] = useState<number>(0);
  const [isUploading, setIsUploading] = useState<boolean>(false);
  const open = addSingleImageModalInfo !== null;

  // Closing is ignored while an upload is in flight.
  const onCancel = () => {
    if (!isUploading) {
      setAddSingleImageModalInfo(null);
    }
  };

  const onDrop = useCallback(
    async (acceptedFiles: File[]) => {
      if (acceptedFiles.length === 0) return;

      setIsUploading(true);
      setUploadProgress(0);

      const formData = new FormData();
      acceptedFiles.forEach(file => {
        formData.append('files', file);
      });

      try {
        const resp = await apiClient.post(`/api/img/upload`, formData, {
          headers: {
            'Content-Type': 'multipart/form-data',
          },
          onUploadProgress: progressEvent => {
            // Without a known total no percentage can be computed; dividing by a
            // placeholder would turn the raw byte count into a bogus "percent".
            const total = progressEvent.total;
            if (!total) return;
            const percentCompleted = Math.min(100, Math.round((progressEvent.loaded * 100) / total));
            setUploadProgress(percentCompleted);
          },
          timeout: 0, // Disable timeout
        });
        console.log('Upload successful:', resp.data);

        // Complete directly here instead of going through an `isUploading` guard:
        // the upload is, by definition, still flagged as running at this point.
        const onComplete = addSingleImageModalInfo?.onComplete;
        if (onComplete) {
          onComplete(resp.data.files[0] || null);
          setAddSingleImageModalInfo(null);
        }
      } catch (error) {
        console.error('Upload failed:', error);
      } finally {
        setIsUploading(false);
        setUploadProgress(0);
      }
    },
    [addSingleImageModalInfo, setAddSingleImageModalInfo],
  );

  const { getRootProps, getInputProps, isDragActive } = useDropzone({
    onDrop,
    accept: {
      'image/*': ['.png', '.jpg', '.jpeg', '.gif', '.bmp', '.webp'],
    },
    multiple: false,
  });

  return (
    <Dialog open={open} onClose={onCancel} className="relative z-10">
      <DialogBackdrop
        transition
        className="fixed inset-0 bg-gray-900/75 transition-opacity data-closed:opacity-0 data-enter:duration-300 data-enter:ease-out data-leave:duration-200 data-leave:ease-in"
      />
      <div className="fixed inset-0 z-10 w-screen overflow-y-auto">
        <div className="flex min-h-full items-end justify-center p-4 text-center sm:items-center sm:p-0">
          <DialogPanel
            transition
            className="relative transform overflow-hidden rounded-lg bg-gray-800 text-left shadow-xl transition-all data-closed:translate-y-4 data-closed:opacity-0 data-enter:duration-300 data-enter:ease-out data-leave:duration-200 data-leave:ease-in sm:my-8 sm:w-full sm:max-w-lg data-closed:sm:translate-y-0 data-closed:sm:scale-95"
          >
            <div className="bg-gray-800 px-4 pt-5 pb-4 sm:p-6 sm:pb-4">
              <div className="text-center">
                <DialogTitle as="h3" className="text-base font-semibold text-gray-200 mb-4">
                  Add Control Image
                </DialogTitle>
                <div className="w-full">
                  <div
                    {...getRootProps()}
                    className={`h-40 w-full flex flex-col items-center justify-center border-2 border-dashed rounded-lg cursor-pointer transition-colors duration-200
                      ${isDragActive ? 'border-blue-500 bg-blue-50/10' : 'border-gray-600'}`}
                  >
                    <input {...getInputProps()} />
                    <FaUpload className="size-8 mb-3 text-gray-400" />
                    <p className="text-sm text-gray-200 text-center">
                      {isDragActive ? 'Drop the image here...' : 'Drag & drop an image here, or click to select one'}
                    </p>
                  </div>
                  {isUploading && (
                    <div className="mt-4">
                      <div className="w-full bg-gray-700 rounded-full h-2.5">
                        <div className="bg-blue-600 h-2.5 rounded-full" style={{ width: `${uploadProgress}%` }}></div>
                      </div>
                      <p className="text-sm text-gray-300 mt-2 text-center">Uploading... {uploadProgress}%</p>
                    </div>
                  )}
                </div>
              </div>
            </div>
            <div className="bg-gray-700 px-4 py-3 sm:flex sm:flex-row-reverse sm:px-6">
              <button
                type="button"
                data-autofocus
                onClick={onCancel}
                disabled={isUploading}
                className={`mt-3 inline-flex w-full justify-center rounded-md bg-gray-800 px-3 py-2 text-sm font-semibold text-gray-200 hover:bg-gray-800 sm:mt-0 sm:w-auto ring-0
                  ${isUploading ? 'opacity-50 cursor-not-allowed' : ''}`}
              >
                Cancel
              </button>
            </div>
          </DialogPanel>
        </div>
      </div>
    </Dialog>
  );
}

View File

@@ -14,7 +14,11 @@ export default function SampleImages({ job }: SampleImagesProps) {
if (job?.job_config) {
const jobConfig = JSON.parse(job.job_config) as JobConfig;
const sampleConfig = jobConfig.config.process[0].sample;
return sampleConfig.prompts.length;
if (sampleConfig.prompts) {
return sampleConfig.prompts.length;
} else {
return sampleConfig.samples.length;
}
}
return 10;
}, [job]);

View File

@@ -2,3 +2,4 @@ import path from 'path';
/**
 * Root directory of the toolkit.
 *
 * `path.resolve` anchors the relative segments at the process working directory,
 * so `'@', '..', '..'` collapses to the parent directory of the cwd.
 * NOTE(review): presumably the server is always started from the UI folder — confirm.
 */
export const TOOLKIT_ROOT = path.resolve('@', '..', '..');

/** Builds the path of a folder located directly under TOOLKIT_ROOT. */
const underToolkitRoot = (folderName: string): string => path.join(TOOLKIT_ROOT, folderName);

/** Default training folder: `<TOOLKIT_ROOT>/output`. */
export const defaultTrainFolder = underToolkitRoot('output');
/** Default datasets folder: `<TOOLKIT_ROOT>/datasets`. */
export const defaultDatasetsFolder = underToolkitRoot('datasets');
/** Default data root: `<TOOLKIT_ROOT>/data`. */
export const defaultDataRoot = underToolkitRoot('data');

View File

@@ -1,5 +1,5 @@
import { PrismaClient } from '@prisma/client';
import { defaultDatasetsFolder } from '@/paths';
import { defaultDatasetsFolder, defaultDataRoot } from '@/paths';
import { defaultTrainFolder } from '@/paths';
import NodeCache from 'node-cache';
@@ -66,3 +66,22 @@ export const getHFToken = async () => {
myCache.set(key, token);
return token;
};
/**
 * Resolves the data root directory.
 *
 * Lookup order: the in-memory cache, then the `DATA_ROOT` row of the settings
 * table, then `defaultDataRoot` when the row is missing or its value is empty.
 * The resolved value is stored in the cache before it is returned.
 */
export const getDataRoot = async () => {
  const settingKey = 'DATA_ROOT';

  const cachedRoot = myCache.get(settingKey) as string;
  if (cachedRoot) {
    return cachedRoot;
  }

  const settingRow = await prisma.settings.findFirst({
    where: {
      key: settingKey,
    },
  });

  // A missing row and an empty value both mean "use the default".
  const storedValue = settingRow?.value;
  const resolvedRoot = storedValue && storedValue !== '' ? storedValue : defaultDataRoot;

  myCache.set(settingKey, resolvedRoot);
  return resolvedRoot;
};

View File

@@ -133,12 +133,27 @@ export interface ModelConfig {
model_kwargs: { [key: string]: any };
}
/**
 * One entry of `SampleConfig.samples`: a sample prompt plus optional per-sample settings.
 *
 * Only `prompt` is required.
 * NOTE(review): the optional fields that share a name with `SampleConfig` fields
 * (e.g. `width`, `height`, `neg`, `seed`) presumably override those values for this
 * sample only — confirm against the code that consumes the config.
 */
export interface SampleItem {
  /** Prompt text for this sample. */
  prompt: string;
  width?: number;
  height?: number;
  neg?: string;
  seed?: number;
  guidance_scale?: number;
  sample_steps?: number;
  fps?: number;
  num_frames?: number;
  /** Path of the control image attached to this sample; may be omitted or null. */
  ctrl_img?: string | null;
  ctrl_idx?: number;
}
export interface SampleConfig {
sampler: string;
sample_every: number;
width: number;
height: number;
prompts: string[];
prompts?: string[];
samples: SampleItem[];
neg: string;
seed: number;
walk_seed: boolean;

View File

@@ -1 +1 @@
VERSION = "0.3.7"
VERSION = "0.3.8"