Added support for new concept slider training script to CLI and UI

This commit is contained in:
Jaret Burkett
2025-09-16 10:22:34 -06:00
parent 3666b112a8
commit 218f673e3d
13 changed files with 996 additions and 78 deletions

View File

@@ -108,7 +108,7 @@ export default function AdvancedJob({ jobConfig, setJobConfig, settings }: Props
// We have to ensure certain things are always set
try {
parsed.config.process[0].type = 'ui_trainer';
// parsed.config.process[0].type = 'ui_trainer';
parsed.config.process[0].sqlite_db_path = './aitk_db.db';
parsed.config.process[0].training_folder = settings.TRAINING_FOLDER;
parsed.config.process[0].device = 'cuda';

View File

@@ -1,6 +1,13 @@
'use client';
import { useMemo } from 'react';
import { modelArchs, ModelArch, groupedModelOptions, quantizationOptions, defaultQtype } from './options';
import {
modelArchs,
ModelArch,
groupedModelOptions,
quantizationOptions,
defaultQtype,
jobTypeOptions,
} from './options';
import { defaultDatasetConfig } from './jobConfig';
import { GroupedSelectOption, JobConfig, SelectOption } from '@/types';
import { objectCopy } from '@/utils/basic';
@@ -8,7 +15,7 @@ import { TextInput, SelectInput, Checkbox, FormGroup, NumberInput } from '@/comp
import Card from '@/components/Card';
import { X } from 'lucide-react';
import AddSingleImageModal, { openAddImageModal } from '@/components/AddSingleImageModal';
import {FlipHorizontal2, FlipVertical2} from "lucide-react"
import { FlipHorizontal2, FlipVertical2 } from 'lucide-react';
type Props = {
jobConfig: JobConfig;
@@ -39,6 +46,21 @@ export default function SimpleJob({
return modelArchs.find(a => a.name === jobConfig.config.process[0].model.arch) as ModelArch;
}, [jobConfig.config.process[0].model.arch]);
const jobType = useMemo(() => {
return jobTypeOptions.find(j => j.value === jobConfig.config.process[0].type);
}, [jobConfig.config.process[0].type]);
const disableSections = useMemo(() => {
let sections: string[] = [];
if (modelArch?.disableSections) {
sections = sections.concat(modelArch.disableSections);
}
if (jobType?.disableSections) {
sections = sections.concat(jobType.disableSections);
}
return sections;
}, [modelArch, jobType]);
const isVideoModel = !!(modelArch?.group === 'video');
const numTopCards = useMemo(() => {
@@ -46,12 +68,14 @@ export default function SimpleJob({
if (modelArch?.additionalSections?.includes('model.multistage')) {
count += 1; // add multistage card
}
if (!modelArch?.disableSections?.includes('model.quantize')) {
if (!disableSections.includes('model.quantize')) {
count += 1; // add quantization card
}
if (!disableSections.includes('slider')) {
count += 1; // add slider card
}
return count;
}, [modelArch]);
}, [modelArch, disableSections]);
let topBarClass = 'grid grid-cols-1 md:grid-cols-2 lg:grid-cols-4 xl:grid-cols-4 gap-6';
@@ -62,6 +86,20 @@ export default function SimpleJob({
topBarClass = 'grid grid-cols-1 md:grid-cols-2 lg:grid-cols-3 xl:grid-cols-3 2xl:grid-cols-6 gap-6';
}
const numTrainingCols = useMemo(() => {
let count = 4;
if (!disableSections.includes('train.diff_output_preservation')) {
count += 1;
}
return count;
}, [disableSections]);
let trainingBarClass = 'grid grid-cols-1 md:grid-cols-2 lg:grid-cols-4 gap-6';
if (numTrainingCols == 5) {
trainingBarClass = 'grid grid-cols-1 md:grid-cols-3 lg:grid-cols-5 gap-6';
}
const transformerQuantizationOptions: GroupedSelectOption[] | SelectOption[] = useMemo(() => {
const hasARA = modelArch?.accuracyRecoveryAdapters && Object.keys(modelArch.accuracyRecoveryAdapters).length > 0;
if (!hasARA) {
@@ -78,7 +116,7 @@ export default function SimpleJob({
let ARAs: SelectOption[] = [];
if (modelArch.accuracyRecoveryAdapters) {
for (const [label, value] of Object.entries(modelArch.accuracyRecoveryAdapters)) {
ARAs.push({ value, label });
ARAs.push({ value, label });
}
}
if (ARAs.length > 0) {
@@ -124,19 +162,21 @@ export default function SimpleJob({
onChange={value => setGpuIDs(value)}
options={gpuList.map((gpu: any) => ({ value: `${gpu.index}`, label: `GPU #${gpu.index}` }))}
/>
<TextInput
label="Trigger Word"
value={jobConfig.config.process[0].trigger_word || ''}
docKey="config.process[0].trigger_word"
onChange={(value: string | null) => {
if (value?.trim() === '') {
value = null;
}
setJobConfig(value, 'config.process[0].trigger_word');
}}
placeholder=""
required
/>
{disableSections.includes('trigger_word') ? null : (
<TextInput
label="Trigger Word"
value={jobConfig.config.process[0].trigger_word || ''}
docKey="config.process[0].trigger_word"
onChange={(value: string | null) => {
if (value?.trim() === '') {
value = null;
}
setJobConfig(value, 'config.process[0].trigger_word');
}}
placeholder=""
required
/>
)}
</Card>
{/* Model Configuration Section */}
@@ -223,7 +263,7 @@ export default function SimpleJob({
</FormGroup>
)}
</Card>
{modelArch?.disableSections?.includes('model.quantize') ? null : (
{disableSections.includes('model.quantize') ? null : (
<Card title="Quantization">
<SelectInput
label="Transformer"
@@ -270,14 +310,14 @@ export default function SimpleJob({
/>
</FormGroup>
<NumberInput
label="Switch Every"
value={jobConfig.config.process[0].train.switch_boundary_every}
onChange={value => setJobConfig(value, 'config.process[0].train.switch_boundary_every')}
placeholder="eg. 1"
docKey={'train.switch_boundary_every'}
min={1}
required
/>
label="Switch Every"
value={jobConfig.config.process[0].train.switch_boundary_every}
onChange={value => setJobConfig(value, 'config.process[0].train.switch_boundary_every')}
placeholder="eg. 1"
docKey={'train.switch_boundary_every'}
min={1}
required
/>
</Card>
)}
<Card title="Target">
@@ -319,7 +359,7 @@ export default function SimpleJob({
max={1024}
required
/>
{modelArch?.disableSections?.includes('network.conv') ? null : (
{disableSections.includes('network.conv') ? null : (
<NumberInput
label="Conv Rank"
value={jobConfig.config.process[0].network.conv}
@@ -336,6 +376,38 @@ export default function SimpleJob({
</>
)}
</Card>
{!disableSections.includes('slider') && (
<Card title="Slider">
<TextInput
label="Target Class"
className=""
value={jobConfig.config.process[0].slider?.target_class ?? ''}
onChange={value => setJobConfig(value, 'config.process[0].slider.target_class')}
placeholder="eg. person"
/>
<TextInput
label="Positive Prompt"
className=""
value={jobConfig.config.process[0].slider?.positive_prompt ?? ''}
onChange={value => setJobConfig(value, 'config.process[0].slider.positive_prompt')}
placeholder="eg. person who is happy"
/>
<TextInput
label="Negative Prompt"
className=""
value={jobConfig.config.process[0].slider?.negative_prompt ?? ''}
onChange={value => setJobConfig(value, 'config.process[0].slider.negative_prompt')}
placeholder="eg. person who is sad"
/>
<TextInput
label="Anchor Class"
className=""
value={jobConfig.config.process[0].slider?.anchor_class ?? ''}
onChange={value => setJobConfig(value, 'config.process[0].slider.anchor_class')}
placeholder=""
/>
</Card>
)}
<Card title="Save">
<SelectInput
label="Data Type"
@@ -367,7 +439,7 @@ export default function SimpleJob({
</div>
<div>
<Card title="Training">
<div className="grid grid-cols-1 md:grid-cols-3 lg:grid-cols-5 gap-6">
<div className={trainingBarClass}>
<div>
<NumberInput
label="Batch Size"
@@ -426,11 +498,11 @@ export default function SimpleJob({
/>
</div>
<div>
{modelArch?.disableSections?.includes('train.timestep_type') ? null : (
{disableSections.includes('train.timestep_type') ? null : (
<SelectInput
label="Timestep Type"
value={jobConfig.config.process[0].train.timestep_type}
disabled={modelArch?.disableSections?.includes('train.timestep_type') || false}
disabled={disableSections.includes('train.timestep_type') || false}
onChange={value => setJobConfig(value, 'config.process[0].train.timestep_type')}
options={[
{ value: 'sigmoid', label: 'Sigmoid' },
@@ -508,33 +580,39 @@ export default function SimpleJob({
</FormGroup>
</div>
<div>
<FormGroup label="Regularization">
<Checkbox
label="Differtial Output Preservation"
className="pt-1"
checked={jobConfig.config.process[0].train.diff_output_preservation || false}
onChange={value => setJobConfig(value, 'config.process[0].train.diff_output_preservation')}
/>
</FormGroup>
{jobConfig.config.process[0].train.diff_output_preservation && (
{disableSections.includes('train.diff_output_preservation') ? null : (
<>
<NumberInput
label="DOP Loss Multiplier"
className="pt-2"
value={jobConfig.config.process[0].train.diff_output_preservation_multiplier as number}
onChange={value =>
setJobConfig(value, 'config.process[0].train.diff_output_preservation_multiplier')
}
placeholder="eg. 1.0"
min={0}
/>
<TextInput
label="DOP Preservation Class"
className="pt-2"
value={jobConfig.config.process[0].train.diff_output_preservation_class as string}
onChange={value => setJobConfig(value, 'config.process[0].train.diff_output_preservation_class')}
placeholder="eg. woman"
/>
<FormGroup label="Regularization">
<Checkbox
label="Differential Output Preservation"
className="pt-1"
checked={jobConfig.config.process[0].train.diff_output_preservation || false}
onChange={value => setJobConfig(value, 'config.process[0].train.diff_output_preservation')}
/>
</FormGroup>
{jobConfig.config.process[0].train.diff_output_preservation && (
<>
<NumberInput
label="DOP Loss Multiplier"
className="pt-2"
value={jobConfig.config.process[0].train.diff_output_preservation_multiplier as number}
onChange={value =>
setJobConfig(value, 'config.process[0].train.diff_output_preservation_multiplier')
}
placeholder="eg. 1.0"
min={0}
/>
<TextInput
label="DOP Preservation Class"
className="pt-2"
value={jobConfig.config.process[0].train.diff_output_preservation_class as string}
onChange={value =>
setJobConfig(value, 'config.process[0].train.diff_output_preservation_class')
}
placeholder="eg. woman"
/>
</>
)}
</>
)}
</div>
@@ -641,12 +719,20 @@ export default function SimpleJob({
</FormGroup>
<FormGroup label="Flipping" docKey={'datasets.flip'} className="mt-2">
<Checkbox
label={<>Flip X <FlipHorizontal2 className="inline-block w-4 h-4 ml-1" /></>}
label={
<>
Flip X <FlipHorizontal2 className="inline-block w-4 h-4 ml-1" />
</>
}
checked={dataset.flip_x || false}
onChange={value => setJobConfig(value, `config.process[0].datasets[${i}].flip_x`)}
/>
<Checkbox
label={<>Flip Y <FlipVertical2 className="inline-block w-4 h-4 ml-1" /></>}
label={
<>
Flip Y <FlipVertical2 className="inline-block w-4 h-4 ml-1" />
</>
}
checked={dataset.flip_y || false}
onChange={value => setJobConfig(value, `config.process[0].datasets[${i}].flip_y`)}
/>
@@ -812,7 +898,7 @@ export default function SimpleJob({
onChange={value => {
setJobConfig(value, 'config.process[0].train.skip_first_sample');
// cannot do both, so disable the other
if (value){
if (value) {
setJobConfig(false, 'config.process[0].train.force_first_sample');
}
}}
@@ -827,7 +913,7 @@ export default function SimpleJob({
onChange={value => {
setJobConfig(value, 'config.process[0].train.force_first_sample');
// cannot do both, so disable the other
if (value){
if (value) {
setJobConfig(false, 'config.process[0].train.skip_first_sample');
}
}}
@@ -841,7 +927,7 @@ export default function SimpleJob({
onChange={value => {
setJobConfig(value, 'config.process[0].train.disable_sampling');
// cannot do both, so disable the other
if (value){
if (value) {
setJobConfig(false, 'config.process[0].train.force_first_sample');
}
}}
@@ -866,6 +952,113 @@ export default function SimpleJob({
placeholder="Enter prompt"
required
/>
<div className="grid grid-cols-1 md:grid-cols-2 lg:grid-cols-4 gap-4 mt-2">
<TextInput
label={`Width`}
value={sample.width ? `${sample.width}` : ''}
onChange={value => {
// remove any non-numeric characters
value = value.replace(/\D/g, '');
if (value === '') {
// remove the key from the config if empty
let newConfig = objectCopy(jobConfig);
if (newConfig.config.process[0].sample.samples[i]) {
delete newConfig.config.process[0].sample.samples[i].width;
setJobConfig(
newConfig.config.process[0].sample.samples,
'config.process[0].sample.samples',
);
}
} else {
const intValue = parseInt(value);
if (!isNaN(intValue)) {
setJobConfig(intValue, `config.process[0].sample.samples[${i}].width`);
} else {
console.warn('Invalid width value:', value);
}
}
}}
placeholder={`${jobConfig.config.process[0].sample.width} (default)`}
/>
<TextInput
label={`Height`}
value={sample.height ? `${sample.height}` : ''}
onChange={value => {
// remove any non-numeric characters
value = value.replace(/\D/g, '');
if (value === '') {
// remove the key from the config if empty
let newConfig = objectCopy(jobConfig);
if (newConfig.config.process[0].sample.samples[i]) {
delete newConfig.config.process[0].sample.samples[i].height;
setJobConfig(
newConfig.config.process[0].sample.samples,
'config.process[0].sample.samples',
);
}
} else {
const intValue = parseInt(value);
if (!isNaN(intValue)) {
setJobConfig(intValue, `config.process[0].sample.samples[${i}].height`);
} else {
console.warn('Invalid height value:', value);
}
}
}}
placeholder={`${jobConfig.config.process[0].sample.height} (default)`}
/>
<TextInput
label={`Seed`}
value={sample.seed ? `${sample.seed}` : ''}
onChange={value => {
// remove any non-numeric characters
value = value.replace(/\D/g, '');
if (value === '') {
// remove the key from the config if empty
let newConfig = objectCopy(jobConfig);
if (newConfig.config.process[0].sample.samples[i]) {
delete newConfig.config.process[0].sample.samples[i].seed;
setJobConfig(
newConfig.config.process[0].sample.samples,
'config.process[0].sample.samples',
);
}
} else {
const intValue = parseInt(value);
if (!isNaN(intValue)) {
setJobConfig(intValue, `config.process[0].sample.samples[${i}].seed`);
} else {
console.warn('Invalid seed value:', value);
}
}
}}
placeholder={`${jobConfig.config.process[0].sample.walk_seed ? jobConfig.config.process[0].sample.seed + i : jobConfig.config.process[0].sample.seed} (default)`}
/>
<TextInput
label={`LoRA Scale`}
value={sample.network_multiplier ? `${sample.network_multiplier}` : ''}
onChange={value => {
// remove any non-numeric, - or . characters
value = value.replace(/[^0-9.-]/g, '');
if (value === '') {
// remove the key from the config if empty
let newConfig = objectCopy(jobConfig);
if (newConfig.config.process[0].sample.samples[i]) {
delete newConfig.config.process[0].sample.samples[i].network_multiplier;
setJobConfig(
newConfig.config.process[0].sample.samples,
'config.process[0].sample.samples',
);
}
} else {
// set it as a string
setJobConfig(value, `config.process[0].sample.samples[${i}].network_multiplier`);
return;
}
}}
placeholder={`1.0 (default)`}
/>
</div>
</div>
{modelArch?.additionalSections?.includes('sample.ctrl_img') && (

View File

@@ -1,4 +1,4 @@
import { JobConfig, DatasetConfig } from '@/types';
import { JobConfig, DatasetConfig, SliderConfig } from '@/types';
export const defaultDatasetConfig: DatasetConfig = {
folder_path: '/path/to/images/folder',
@@ -20,13 +20,22 @@ export const defaultDatasetConfig: DatasetConfig = {
flip_y: false,
};
export const defaultSliderConfig: SliderConfig = {
guidance_strength: 3.0,
anchor_strength: 1.0,
positive_prompt: 'person who is happy',
negative_prompt: 'person who is sad',
target_class: 'person',
anchor_class: "",
};
export const defaultJobConfig: JobConfig = {
job: 'extension',
config: {
name: 'my_first_lora_v1',
process: [
{
type: 'ui_trainer',
type: 'diffusion_trainer',
training_folder: 'output',
sqlite_db_path: './aitk_db.db',
device: 'cuda',
@@ -100,7 +109,7 @@ export const defaultJobConfig: JobConfig = {
height: 1024,
samples: [
{
prompt: 'woman with red hair, playing chess at the park, bomb going off in the background'
prompt: 'woman with red hair, playing chess at the park, bomb going off in the background',
},
{
prompt: 'a woman holding a coffee cup, in a beanie, sitting at a cafe',
@@ -109,7 +118,8 @@ export const defaultJobConfig: JobConfig = {
prompt: 'a horse is a DJ at a night club, fish eye lens, smoke machine, lazer lights, holding a martini',
},
{
prompt: 'a man showing off his cool new t shirt at the beach, a shark is jumping out of the water in the background',
prompt:
'a man showing off his cool new t shirt at the beach, a shark is jumping out of the water in the background',
},
{
prompt: 'a bear building a log cabin in the snow covered mountains',
@@ -121,13 +131,15 @@ export const defaultJobConfig: JobConfig = {
prompt: 'hipster man with a beard, building a chair, in a wood shop',
},
{
prompt: 'photo of a man, white background, medium shot, modeling clothing, studio lighting, white backdrop',
prompt:
'photo of a man, white background, medium shot, modeling clothing, studio lighting, white backdrop',
},
{
prompt: "a man holding a sign that says, 'this is a sign'",
},
{
prompt: 'a bulldog, in a post apocalyptic world, with a shotgun, in a leather jacket, in a desert, with a motorcycle',
prompt:
'a bulldog, in a post apocalyptic world, with a shotgun, in a leather jacket, in a desert, with a motorcycle',
},
],
neg: '',

View File

@@ -1,8 +1,16 @@
import { GroupedSelectOption, SelectOption } from '@/types';
import { GroupedSelectOption, SelectOption, JobConfig } from '@/types';
import { defaultSliderConfig } from './jobConfig';
type Control = 'depth' | 'line' | 'pose' | 'inpaint';
type DisableableSections = 'model.quantize' | 'train.timestep_type' | 'network.conv';
type DisableableSections =
| 'model.quantize'
| 'train.timestep_type'
| 'network.conv'
| 'trigger_word'
| 'train.diff_output_preservation'
| 'slider';
type AdditionalSections =
| 'datasets.control_path'
| 'datasets.do_i2v'
@@ -439,3 +447,33 @@ export const quantizationOptions: SelectOption[] = [
];
export const defaultQtype = 'qfloat8';
interface JobTypeOption extends SelectOption {
disableSections?: DisableableSections[];
processSections?: string[];
onActivate?: (config: JobConfig) => JobConfig;
onDeactivate?: (config: JobConfig) => JobConfig;
}
export const jobTypeOptions: JobTypeOption[] = [
{
value: 'diffusion_trainer',
label: 'LoRA Trainer',
disableSections: ['slider'],
},
{
value: 'concept_slider',
label: 'Concept Slider',
disableSections: ['trigger_word', 'train.diff_output_preservation'],
onActivate: (config: JobConfig) => {
// add default slider config
config.config.process[0].slider = { ...defaultSliderConfig };
return config;
},
onDeactivate: (config: JobConfig) => {
// remove slider config
delete config.config.process[0].slider;
return config;
},
},
];

View File

@@ -3,6 +3,7 @@
import { useEffect, useState } from 'react';
import { useSearchParams, useRouter } from 'next/navigation';
import { defaultJobConfig, defaultDatasetConfig, migrateJobConfig } from './jobConfig';
import { jobTypeOptions } from './options';
import { JobConfig } from '@/types';
import { objectCopy } from '@/utils/basic';
import { useNestedState } from '@/utils/hooks';
@@ -144,6 +145,38 @@ export default function TrainingForm() {
<div className="mx-4 bg-gray-200 dark:bg-gray-800 w-1 h-6"></div>
</>
)}
{!showAdvancedView && (
<>
<div>
<SelectInput
value={`${jobConfig?.config.process[0].type}`}
onChange={value => {
// undo current job type changes
const currentOption = jobTypeOptions.find(
option => option.value === jobConfig?.config.process[0].type,
);
if (currentOption && currentOption.onDeactivate) {
setJobConfig(currentOption.onDeactivate(objectCopy(jobConfig)));
}
const option = jobTypeOptions.find(option => option.value === value);
if (option) {
if (option.onActivate) {
setJobConfig(option.onActivate(objectCopy(jobConfig)));
}
jobTypeOptions.forEach(opt => {
if (opt.value !== option.value && opt.onDeactivate) {
setJobConfig(opt.onDeactivate(objectCopy(jobConfig)));
}
});
}
setJobConfig(value, 'config.process[0].type');
}}
options={jobTypeOptions}
/>
</div>
<div className="mx-4 bg-gray-200 dark:bg-gray-800 w-1 h-6"></div>
</>
)}
<div className="pr-2">
<Button

View File

@@ -68,8 +68,8 @@ export const TextInput = forwardRef<HTMLInputElement, TextInputProps>((props: Te
TextInput.displayName = 'TextInput';
export interface NumberInputProps extends InputProps {
value: number;
onChange: (value: number) => void;
value: number | null;
onChange: (value: number | null) => void;
min?: number;
max?: number;
}

View File

@@ -143,7 +143,7 @@ export interface ModelConfig {
export interface SampleItem {
prompt: string;
width?: number
width?: number;
height?: number;
neg?: string;
seed?: number;
@@ -153,6 +153,7 @@ export interface SampleItem {
num_frames?: number;
ctrl_img?: string | null;
ctrl_idx?: number;
network_multiplier?: number;
}
export interface SampleConfig {
@@ -171,14 +172,24 @@ export interface SampleConfig {
fps: number;
}
export interface SliderConfig {
guidance_strength?: number;
anchor_strength?: number;
positive_prompt?: string;
negative_prompt?: string;
target_class?: string;
anchor_class?: string | null;
}
export interface ProcessConfig {
type: 'ui_trainer';
type: string;
sqlite_db_path?: string;
training_folder: string;
performance_log_every: number;
trigger_word: string | null;
device: string;
network?: NetworkConfig;
slider?: SliderConfig;
save: SaveConfig;
datasets: DatasetConfig[];
train: TrainConfig;