Add advanced config section to the captioner

This commit is contained in:
Jaret Burkett
2026-04-28 10:49:46 -06:00
parent f972b750e6
commit 43989cc19e
4 changed files with 264 additions and 219 deletions

View File

@@ -17,7 +17,7 @@ import { TopBar, MainContent } from '@/components/layout';
import { Button } from '@headlessui/react';
import { FaChevronLeft } from 'react-icons/fa';
import SimpleJob from './SimpleJob';
import AdvancedJob from './AdvancedJob';
import AdvancedConfigEditor from '@/components/AdvancedConfigEditor';
import ErrorBoundary from '@/components/ErrorBoundary';
import { apiClient } from '@/utils/api';
@@ -279,17 +279,20 @@ export default function TrainingForm() {
{showAdvancedView ? (
<div className="pt-[48px] absolute top-0 left-0 w-full h-full overflow-auto">
<AdvancedJob
jobConfig={jobConfig}
setJobConfig={setJobConfig}
status={status}
handleSubmit={handleSubmit}
runId={runId}
gpuIDs={gpuIDs}
setGpuIDs={setGpuIDs}
gpuList={gpuList}
datasetOptions={datasetOptions}
settings={settings}
<AdvancedConfigEditor
config={jobConfig}
setConfig={setJobConfig}
transformOnParse={(parsed: any) => {
try {
parsed.config.process[0].sqlite_db_path = './aitk_db.db';
parsed.config.process[0].training_folder = settings.TRAINING_FOLDER;
parsed.config.process[0].device = 'cuda';
parsed.config.process[0].performance_log_every = 10;
} catch (e) {
console.warn(e);
}
return migrateJobConfig(parsed);
}}
/>
</div>
) : (

View File

@@ -1,28 +1,16 @@
'use client';
import { useEffect, useState, useRef } from 'react';
import { JobConfig } from '@/types';
import YAML from 'yaml';
import Editor, { OnMount } from '@monaco-editor/react';
import type { editor } from 'monaco-editor';
import { Settings } from '@/hooks/useSettings';
import { migrateJobConfig } from './jobConfig';
import { useTheme } from '@/components/ThemeProvider';
type Props = {
jobConfig: JobConfig;
setJobConfig: (value: any, key?: string) => void;
status: 'idle' | 'saving' | 'success' | 'error';
handleSubmit: (event: React.FormEvent<HTMLFormElement>) => void;
runId: string | null;
gpuIDs: string | null;
setGpuIDs: (value: string | null) => void;
gpuList: any;
datasetOptions: any;
settings: Settings;
type Props<T> = {
config: T;
setConfig: (value: any, key?: string) => void;
transformOnParse?: (parsed: any) => any;
};
const isDev = process.env.NODE_ENV === 'development';
const yamlConfig: YAML.DocumentOptions &
YAML.SchemaOptions &
YAML.ParseOptions &
@@ -47,11 +35,11 @@ function toYaml(obj: any): string {
return doc.toString(yamlConfig);
}
export default function AdvancedJob({ jobConfig, setJobConfig, settings }: Props) {
export default function AdvancedConfigEditor<T>({ config, setConfig, transformOnParse }: Props<T>) {
const { theme } = useTheme();
const [editorValue, setEditorValue] = useState<string>('');
const [hasError, setHasError] = useState(false);
const lastJobConfigUpdateStringRef = useRef('');
const lastConfigUpdateStringRef = useRef('');
const editorRef = useRef<editor.IStandaloneCodeEditor | null>(null);
const monacoRef = useRef<any>(null);
@@ -66,17 +54,17 @@ export default function AdvancedJob({ jobConfig, setJobConfig, settings }: Props
// Initial content setup
try {
const yamlContent = toYaml(jobConfig);
const yamlContent = toYaml(config);
setEditorValue(yamlContent);
lastJobConfigUpdateStringRef.current = JSON.stringify(jobConfig);
lastConfigUpdateStringRef.current = JSON.stringify(config);
} catch (e) {
console.warn(e);
}
};
useEffect(() => {
const lastUpdate = lastJobConfigUpdateStringRef.current;
const currentUpdate = JSON.stringify(jobConfig);
const lastUpdate = lastConfigUpdateStringRef.current;
const currentUpdate = JSON.stringify(config);
// Skip if no changes or editor not yet mounted
if (lastUpdate === currentUpdate || !isEditorMounted.current) {
@@ -93,7 +81,7 @@ export default function AdvancedJob({ jobConfig, setJobConfig, settings }: Props
const scrollTop = editor.getScrollTop();
// Update content
const yamlContent = toYaml(jobConfig);
const yamlContent = toYaml(config);
// Only update if the content is actually different
if (yamlContent !== editor.getValue()) {
@@ -106,12 +94,12 @@ export default function AdvancedJob({ jobConfig, setJobConfig, settings }: Props
editor.setScrollTop(scrollTop);
}
lastJobConfigUpdateStringRef.current = currentUpdate;
lastConfigUpdateStringRef.current = currentUpdate;
}
} catch (e) {
console.warn(e);
}
}, [jobConfig]);
}, [config]);
const setMarkers = (errors: { message: string; line: number }[]) => {
const monaco = monacoRef.current;
@@ -132,27 +120,18 @@ export default function AdvancedJob({ jobConfig, setJobConfig, settings }: Props
if (value === undefined) return;
try {
const parsed = YAML.parse(value);
let parsed = YAML.parse(value);
setHasError(false);
setMarkers([]);
// Don't update jobConfig if the change came from the editor itself
// Don't update config if the change came from the editor itself
// to avoid a circular update loop
if (JSON.stringify(parsed) !== lastJobConfigUpdateStringRef.current) {
lastJobConfigUpdateStringRef.current = JSON.stringify(parsed);
// We have to ensure certain things are always set
try {
// parsed.config.process[0].type = 'ui_trainer';
parsed.config.process[0].sqlite_db_path = './aitk_db.db';
parsed.config.process[0].training_folder = settings.TRAINING_FOLDER;
parsed.config.process[0].device = 'cuda';
parsed.config.process[0].performance_log_every = 10;
} catch (e) {
console.warn(e);
if (JSON.stringify(parsed) !== lastConfigUpdateStringRef.current) {
if (transformOnParse) {
parsed = transformOnParse(parsed);
}
migrateJobConfig(parsed);
setJobConfig(parsed);
lastConfigUpdateStringRef.current = JSON.stringify(parsed);
setConfig(parsed);
}
} catch (e: any) {
setHasError(true);
@@ -175,6 +154,7 @@ export default function AdvancedJob({ jobConfig, setJobConfig, settings }: Props
defaultLanguage="yaml"
value={editorValue}
theme={theme === 'dark' ? 'vs-dark' : 'light'}
className="z-0"
onChange={handleChange}
onMount={handleEditorDidMount}
options={{

View File

@@ -2,32 +2,19 @@ import React, { useState, useEffect, useRef } from 'react';
import { Modal } from '@/components/Modal';
import { createGlobalState } from 'react-global-hooks';
import { useFromNull } from '@/hooks/useFromNull';
import {
Checkbox,
CreatableSelectInput,
FormGroup,
SelectInput,
TextAreaInput,
TextInput,
} from '@/components/formInputs';
import { CaptionJobConfig } from '@/types';
import { defaultCaptionJobConfig, handleCaptionerTypeChange } from '@/helpers/captionJobConfig';
import { defaultCaptionJobConfig } from '@/helpers/captionJobConfig';
import { objectCopy } from '@/utils/basic';
import { useNestedState } from '@/utils/hooks';
import {
captionerTypes,
defaultQtype,
groupedCaptionerTypes,
maxNewTokensOptions,
maxResOptions,
quantizationOptions,
} from '@/helpers/captionOptions';
import { isMac } from '@/helpers/basic';
import useGPUInfo from '@/hooks/useGPUInfo';
import { apiClient } from '@/utils/api';
import { v4 as uuidv4 } from 'uuid';
import { startJob } from '@/utils/jobs';
import { startQueue } from '@/utils/queue';
import CaptionSimpleJob from '@/components/CaptionSimpleJob';
import AdvancedConfigEditor from '@/components/AdvancedConfigEditor';
import { SelectInput } from '@/components/formInputs';
export interface CaptionDatasetModalState {
datasetPath: string;
@@ -45,6 +32,7 @@ export const CaptionDatasetModal: React.FC = () => {
const [jobConfig, setJobConfig] = useNestedState<CaptionJobConfig>(objectCopy(defaultCaptionJobConfig));
const [gpuIDs, setGpuIDs] = useState<string | null>(null);
const { gpuList, isGPUInfoLoaded } = useGPUInfo();
const [activeTab, setActiveTab] = useState<'simple' | 'advanced'>('simple');
const open = modalInfo !== null;
const isSavingRef = useRef(false);
const showGPUSelect = !isMac();
@@ -52,6 +40,7 @@ export const CaptionDatasetModal: React.FC = () => {
useFromNull(() => {
// reset the state
setJobConfig(objectCopy(defaultCaptionJobConfig));
setActiveTab('simple');
// set the path_to_caption
if (modalInfo?.datasetPath) {
setJobConfig(modalInfo.datasetPath, 'config.process[0].caption.path_to_caption');
@@ -73,8 +62,6 @@ export const CaptionDatasetModal: React.FC = () => {
setModalInfo(null);
};
const selectedCaptionOption = captionerTypes.find(option => option.name === jobConfig.config.process[0].type);
const saveJob = async () => {
if (isSavingRef.current) return;
if (!modalInfo?.datasetPath) {
@@ -111,163 +98,54 @@ export const CaptionDatasetModal: React.FC = () => {
});
};
const additionalSections = selectedCaptionOption?.additionalSections || [];
const tabButtonClass = (tab: 'simple' | 'advanced') =>
`px-4 py-2 text-sm font-medium border-b-2 transition-colors ${
activeTab === tab
? 'border-blue-500 text-blue-400'
: 'border-transparent text-gray-400 hover:text-gray-200 hover:border-gray-600'
}`;
return (
<Modal isOpen={open} onClose={handleClose} title="Caption Dataset" size="lg">
<Modal isOpen={open} onClose={handleClose} title="Caption Dataset" size={activeTab === 'advanced' ? 'xl' : 'lg'}>
<div className="space-y-4 text-gray-200">
<div className="flex items-center border-b border-gray-700 -mt-2">
<button type="button" className={tabButtonClass('simple')} onClick={() => setActiveTab('simple')}>
Simple
</button>
<button type="button" className={tabButtonClass('advanced')} onClick={() => setActiveTab('advanced')}>
Advanced
</button>
<div className="flex-1" />
{activeTab === 'advanced' && showGPUSelect && (
<div className="pb-2">
<SelectInput
value={`${gpuIDs}`}
onChange={value => setGpuIDs(value)}
options={gpuList.map((gpu: any) => ({ value: `${gpu.index}`, label: `GPU #${gpu.index}` }))}
/>
</div>
)}
</div>
<form
onSubmit={e => {
e.preventDefault();
saveJob();
}}
>
<div className="text-sm text-gray-400">
<div className="grid grid-cols-1 md:grid-cols-2 gap-4 mt-4">
<div>
<SelectInput
label="Captioner Type"
value={jobConfig.config.process[0].type}
onChange={value => {
handleCaptionerTypeChange(jobConfig.config.process[0].type, value, jobConfig, setJobConfig);
}}
options={groupedCaptionerTypes}
/>
</div>
{showGPUSelect && (
<div>
<SelectInput
label="GPU ID"
value={`${gpuIDs}`}
onChange={value => setGpuIDs(value)}
options={gpuList.map((gpu: any) => ({ value: `${gpu.index}`, label: `GPU #${gpu.index}` }))}
/>
</div>
)}
{activeTab === 'simple' ? (
<CaptionSimpleJob
jobConfig={jobConfig}
setJobConfig={setJobConfig}
gpuIDs={gpuIDs}
setGpuIDs={setGpuIDs}
gpuList={gpuList}
showGPUSelect={showGPUSelect}
/>
) : (
<div className="h-[60vh] mt-2">
<AdvancedConfigEditor config={jobConfig} setConfig={setJobConfig} />
</div>
<div className="mt-4">
<CreatableSelectInput
label="Name or Path"
value={jobConfig.config.process[0].caption.model_name_or_path}
docKey="config.process[0].caption.model_name_or_path"
onChange={(value: string | null) => {
if (value?.trim() === '') {
value = null;
}
setJobConfig(value, 'config.process[0].caption.model_name_or_path');
}}
placeholder=""
options={selectedCaptionOption?.name_or_path_options || []}
required
/>
</div>
{additionalSections.includes('caption.model_name_or_path2') && (
<div className="mt-4">
<CreatableSelectInput
label="Name or Path 2"
value={jobConfig.config.process[0].caption.model_name_or_path2 || ''}
onChange={(value: string | null) => {
if (value?.trim() === '') {
value = null;
}
setJobConfig(value, 'config.process[0].caption.model_name_or_path2');
}}
placeholder=""
options={selectedCaptionOption?.name_or_path2_options || []}
/>
</div>
)}
{additionalSections.includes('caption.fixed_caption') && (
<div className="mt-4">
<TextInput
label="Fixed Caption"
value={jobConfig.config.process[0].caption.fixed_caption || ''}
onChange={value => {
if (value?.trim() === '') {
//@ts-ignore
value = undefined;
}
setJobConfig(value, 'config.process[0].caption.fixed_caption');
}}
placeholder="Enter fixed caption (if you want the same caption for all audio files)"
/>
</div>
)}
<div className="grid grid-cols-1 md:grid-cols-2 gap-4 mt-4">
<div>
<SelectInput
label="Quantize"
value={jobConfig.config.process[0].caption.quantize ? jobConfig.config.process[0].caption.qtype : ''}
onChange={value => {
if (value === '') {
setJobConfig(false, 'config.process[0].caption.quantize');
value = defaultQtype;
} else {
setJobConfig(true, 'config.process[0].caption.quantize');
}
setJobConfig(value, 'config.process[0].caption.qtype');
}}
options={quantizationOptions}
/>
{additionalSections.includes('caption.max_res') && (
<div className="mt-4">
<SelectInput
label="Max Resolution"
value={`${jobConfig.config.process[0].caption.max_res || ''}`}
onChange={value => {
const intVal = parseInt(value);
if (!isNaN(intVal)) {
setJobConfig(intVal, 'config.process[0].caption.max_res');
}
}}
options={maxResOptions}
/>
</div>
)}
{additionalSections.includes('caption.max_new_tokens') && (
<div className="mt-4">
<SelectInput
label="Max New Tokens"
value={`${jobConfig.config.process[0].caption.max_new_tokens || ''}`}
onChange={value => {
const intVal = parseInt(value);
if (!isNaN(intVal)) {
setJobConfig(intVal, 'config.process[0].caption.max_new_tokens');
}
}}
options={maxNewTokensOptions}
/>
</div>
)}
</div>
<div>
<FormGroup label="Options">
<Checkbox
label="Low VRAM"
checked={jobConfig.config.process[0].caption.low_vram}
onChange={value => setJobConfig(value, 'config.process[0].caption.low_vram')}
/>
<Checkbox
label="Recaption"
checked={jobConfig.config.process[0].caption.recaption}
onChange={value => setJobConfig(value, 'config.process[0].caption.recaption')}
/>
</FormGroup>
</div>
</div>
{additionalSections.includes('caption.caption_prompt') && (
<div className="mt-4">
<TextAreaInput
label="Caption Prompt"
value={jobConfig.config.process[0].caption.caption_prompt || ''}
onChange={value => {
setJobConfig(value, 'config.process[0].caption.caption_prompt');
}}
placeholder="Enter caption prompt"
/>
</div>
)}
</div>
)}
<div className="mt-6 flex justify-end space-x-3">
<button

View File

@@ -0,0 +1,184 @@
import React from 'react';
import {
Checkbox,
CreatableSelectInput,
FormGroup,
SelectInput,
TextAreaInput,
TextInput,
} from '@/components/formInputs';
import { CaptionJobConfig } from '@/types';
import { handleCaptionerTypeChange } from '@/helpers/captionJobConfig';
import {
captionerTypes,
defaultQtype,
groupedCaptionerTypes,
maxNewTokensOptions,
maxResOptions,
quantizationOptions,
} from '@/helpers/captionOptions';
type Props = {
jobConfig: CaptionJobConfig;
setJobConfig: (value: any, key?: string) => void;
gpuIDs: string | null;
setGpuIDs: (value: string | null) => void;
gpuList: any;
showGPUSelect: boolean;
};
const CaptionSimpleJob: React.FC<Props> = ({ jobConfig, setJobConfig, gpuIDs, setGpuIDs, gpuList, showGPUSelect }) => {
const selectedCaptionOption = captionerTypes.find(option => option.name === jobConfig.config.process[0].type);
const additionalSections = selectedCaptionOption?.additionalSections || [];
return (
<div className="text-sm text-gray-400">
<div className="grid grid-cols-1 md:grid-cols-2 gap-4 mt-4">
<div>
<SelectInput
label="Captioner Type"
value={jobConfig.config.process[0].type}
onChange={value => {
handleCaptionerTypeChange(jobConfig.config.process[0].type, value, jobConfig, setJobConfig);
}}
options={groupedCaptionerTypes}
/>
</div>
{showGPUSelect && (
<div>
<SelectInput
label="GPU ID"
value={`${gpuIDs}`}
onChange={value => setGpuIDs(value)}
options={gpuList.map((gpu: any) => ({ value: `${gpu.index}`, label: `GPU #${gpu.index}` }))}
/>
</div>
)}
</div>
<div className="mt-4">
<CreatableSelectInput
label="Name or Path"
value={jobConfig.config.process[0].caption.model_name_or_path}
docKey="config.process[0].caption.model_name_or_path"
onChange={(value: string | null) => {
if (value?.trim() === '') {
value = null;
}
setJobConfig(value, 'config.process[0].caption.model_name_or_path');
}}
placeholder=""
options={selectedCaptionOption?.name_or_path_options || []}
required
/>
</div>
{additionalSections.includes('caption.model_name_or_path2') && (
<div className="mt-4">
<CreatableSelectInput
label="Name or Path 2"
value={jobConfig.config.process[0].caption.model_name_or_path2 || ''}
onChange={(value: string | null) => {
if (value?.trim() === '') {
value = null;
}
setJobConfig(value, 'config.process[0].caption.model_name_or_path2');
}}
placeholder=""
options={selectedCaptionOption?.name_or_path2_options || []}
/>
</div>
)}
{additionalSections.includes('caption.fixed_caption') && (
<div className="mt-4">
<TextInput
label="Fixed Caption"
value={jobConfig.config.process[0].caption.fixed_caption || ''}
onChange={value => {
if (value?.trim() === '') {
//@ts-ignore
value = undefined;
}
setJobConfig(value, 'config.process[0].caption.fixed_caption');
}}
placeholder="Enter fixed caption (if you want the same caption for all audio files)"
/>
</div>
)}
<div className="grid grid-cols-1 md:grid-cols-2 gap-4 mt-4">
<div>
<SelectInput
label="Quantize"
value={jobConfig.config.process[0].caption.quantize ? jobConfig.config.process[0].caption.qtype : ''}
onChange={value => {
if (value === '') {
setJobConfig(false, 'config.process[0].caption.quantize');
value = defaultQtype;
} else {
setJobConfig(true, 'config.process[0].caption.quantize');
}
setJobConfig(value, 'config.process[0].caption.qtype');
}}
options={quantizationOptions}
/>
{additionalSections.includes('caption.max_res') && (
<div className="mt-4">
<SelectInput
label="Max Resolution"
value={`${jobConfig.config.process[0].caption.max_res || ''}`}
onChange={value => {
const intVal = parseInt(value);
if (!isNaN(intVal)) {
setJobConfig(intVal, 'config.process[0].caption.max_res');
}
}}
options={maxResOptions}
/>
</div>
)}
{additionalSections.includes('caption.max_new_tokens') && (
<div className="mt-4">
<SelectInput
label="Max New Tokens"
value={`${jobConfig.config.process[0].caption.max_new_tokens || ''}`}
onChange={value => {
const intVal = parseInt(value);
if (!isNaN(intVal)) {
setJobConfig(intVal, 'config.process[0].caption.max_new_tokens');
}
}}
options={maxNewTokensOptions}
/>
</div>
)}
</div>
<div>
<FormGroup label="Options">
<Checkbox
label="Low VRAM"
checked={jobConfig.config.process[0].caption.low_vram}
onChange={value => setJobConfig(value, 'config.process[0].caption.low_vram')}
/>
<Checkbox
label="Recaption"
checked={jobConfig.config.process[0].caption.recaption}
onChange={value => setJobConfig(value, 'config.process[0].caption.recaption')}
/>
</FormGroup>
</div>
</div>
{additionalSections.includes('caption.caption_prompt') && (
<div className="mt-4">
<TextAreaInput
label="Caption Prompt"
value={jobConfig.config.process[0].caption.caption_prompt || ''}
onChange={value => {
setJobConfig(value, 'config.process[0].caption.caption_prompt');
}}
placeholder="Enter caption prompt"
/>
</div>
)}
</div>
);
};
export default CaptionSimpleJob;