mirror of
https://github.com/ostris/ai-toolkit.git
synced 2026-04-26 17:29:27 +00:00
Added support for new concept slider training script to CLI and UI
This commit is contained in:
@@ -108,7 +108,7 @@ export default function AdvancedJob({ jobConfig, setJobConfig, settings }: Props
|
||||
|
||||
// We have to ensure certain things are always set
|
||||
try {
|
||||
parsed.config.process[0].type = 'ui_trainer';
|
||||
// parsed.config.process[0].type = 'ui_trainer';
|
||||
parsed.config.process[0].sqlite_db_path = './aitk_db.db';
|
||||
parsed.config.process[0].training_folder = settings.TRAINING_FOLDER;
|
||||
parsed.config.process[0].device = 'cuda';
|
||||
|
||||
@@ -1,6 +1,13 @@
|
||||
'use client';
|
||||
import { useMemo } from 'react';
|
||||
import { modelArchs, ModelArch, groupedModelOptions, quantizationOptions, defaultQtype } from './options';
|
||||
import {
|
||||
modelArchs,
|
||||
ModelArch,
|
||||
groupedModelOptions,
|
||||
quantizationOptions,
|
||||
defaultQtype,
|
||||
jobTypeOptions,
|
||||
} from './options';
|
||||
import { defaultDatasetConfig } from './jobConfig';
|
||||
import { GroupedSelectOption, JobConfig, SelectOption } from '@/types';
|
||||
import { objectCopy } from '@/utils/basic';
|
||||
@@ -8,7 +15,7 @@ import { TextInput, SelectInput, Checkbox, FormGroup, NumberInput } from '@/comp
|
||||
import Card from '@/components/Card';
|
||||
import { X } from 'lucide-react';
|
||||
import AddSingleImageModal, { openAddImageModal } from '@/components/AddSingleImageModal';
|
||||
import {FlipHorizontal2, FlipVertical2} from "lucide-react"
|
||||
import { FlipHorizontal2, FlipVertical2 } from 'lucide-react';
|
||||
|
||||
type Props = {
|
||||
jobConfig: JobConfig;
|
||||
@@ -39,6 +46,21 @@ export default function SimpleJob({
|
||||
return modelArchs.find(a => a.name === jobConfig.config.process[0].model.arch) as ModelArch;
|
||||
}, [jobConfig.config.process[0].model.arch]);
|
||||
|
||||
const jobType = useMemo(() => {
|
||||
return jobTypeOptions.find(j => j.value === jobConfig.config.process[0].type);
|
||||
}, [jobConfig.config.process[0].type]);
|
||||
|
||||
const disableSections = useMemo(() => {
|
||||
let sections: string[] = [];
|
||||
if (modelArch?.disableSections) {
|
||||
sections = sections.concat(modelArch.disableSections);
|
||||
}
|
||||
if (jobType?.disableSections) {
|
||||
sections = sections.concat(jobType.disableSections);
|
||||
}
|
||||
return sections;
|
||||
}, [modelArch, jobType]);
|
||||
|
||||
const isVideoModel = !!(modelArch?.group === 'video');
|
||||
|
||||
const numTopCards = useMemo(() => {
|
||||
@@ -46,12 +68,14 @@ export default function SimpleJob({
|
||||
if (modelArch?.additionalSections?.includes('model.multistage')) {
|
||||
count += 1; // add multistage card
|
||||
}
|
||||
if (!modelArch?.disableSections?.includes('model.quantize')) {
|
||||
if (!disableSections.includes('model.quantize')) {
|
||||
count += 1; // add quantization card
|
||||
}
|
||||
if (!disableSections.includes('slider')) {
|
||||
count += 1; // add slider card
|
||||
}
|
||||
return count;
|
||||
|
||||
}, [modelArch]);
|
||||
}, [modelArch, disableSections]);
|
||||
|
||||
let topBarClass = 'grid grid-cols-1 md:grid-cols-2 lg:grid-cols-4 xl:grid-cols-4 gap-6';
|
||||
|
||||
@@ -62,6 +86,20 @@ export default function SimpleJob({
|
||||
topBarClass = 'grid grid-cols-1 md:grid-cols-2 lg:grid-cols-3 xl:grid-cols-3 2xl:grid-cols-6 gap-6';
|
||||
}
|
||||
|
||||
const numTrainingCols = useMemo(() => {
|
||||
let count = 4;
|
||||
if (!disableSections.includes('train.diff_output_preservation')) {
|
||||
count += 1;
|
||||
}
|
||||
return count;
|
||||
}, [disableSections]);
|
||||
|
||||
let trainingBarClass = 'grid grid-cols-1 md:grid-cols-2 lg:grid-cols-4 gap-6';
|
||||
|
||||
if (numTrainingCols == 5) {
|
||||
trainingBarClass = 'grid grid-cols-1 md:grid-cols-3 lg:grid-cols-5 gap-6';
|
||||
}
|
||||
|
||||
const transformerQuantizationOptions: GroupedSelectOption[] | SelectOption[] = useMemo(() => {
|
||||
const hasARA = modelArch?.accuracyRecoveryAdapters && Object.keys(modelArch.accuracyRecoveryAdapters).length > 0;
|
||||
if (!hasARA) {
|
||||
@@ -78,7 +116,7 @@ export default function SimpleJob({
|
||||
let ARAs: SelectOption[] = [];
|
||||
if (modelArch.accuracyRecoveryAdapters) {
|
||||
for (const [label, value] of Object.entries(modelArch.accuracyRecoveryAdapters)) {
|
||||
ARAs.push({ value, label });
|
||||
ARAs.push({ value, label });
|
||||
}
|
||||
}
|
||||
if (ARAs.length > 0) {
|
||||
@@ -124,19 +162,21 @@ export default function SimpleJob({
|
||||
onChange={value => setGpuIDs(value)}
|
||||
options={gpuList.map((gpu: any) => ({ value: `${gpu.index}`, label: `GPU #${gpu.index}` }))}
|
||||
/>
|
||||
<TextInput
|
||||
label="Trigger Word"
|
||||
value={jobConfig.config.process[0].trigger_word || ''}
|
||||
docKey="config.process[0].trigger_word"
|
||||
onChange={(value: string | null) => {
|
||||
if (value?.trim() === '') {
|
||||
value = null;
|
||||
}
|
||||
setJobConfig(value, 'config.process[0].trigger_word');
|
||||
}}
|
||||
placeholder=""
|
||||
required
|
||||
/>
|
||||
{disableSections.includes('trigger_word') ? null : (
|
||||
<TextInput
|
||||
label="Trigger Word"
|
||||
value={jobConfig.config.process[0].trigger_word || ''}
|
||||
docKey="config.process[0].trigger_word"
|
||||
onChange={(value: string | null) => {
|
||||
if (value?.trim() === '') {
|
||||
value = null;
|
||||
}
|
||||
setJobConfig(value, 'config.process[0].trigger_word');
|
||||
}}
|
||||
placeholder=""
|
||||
required
|
||||
/>
|
||||
)}
|
||||
</Card>
|
||||
|
||||
{/* Model Configuration Section */}
|
||||
@@ -223,7 +263,7 @@ export default function SimpleJob({
|
||||
</FormGroup>
|
||||
)}
|
||||
</Card>
|
||||
{modelArch?.disableSections?.includes('model.quantize') ? null : (
|
||||
{disableSections.includes('model.quantize') ? null : (
|
||||
<Card title="Quantization">
|
||||
<SelectInput
|
||||
label="Transformer"
|
||||
@@ -270,14 +310,14 @@ export default function SimpleJob({
|
||||
/>
|
||||
</FormGroup>
|
||||
<NumberInput
|
||||
label="Switch Every"
|
||||
value={jobConfig.config.process[0].train.switch_boundary_every}
|
||||
onChange={value => setJobConfig(value, 'config.process[0].train.switch_boundary_every')}
|
||||
placeholder="eg. 1"
|
||||
docKey={'train.switch_boundary_every'}
|
||||
min={1}
|
||||
required
|
||||
/>
|
||||
label="Switch Every"
|
||||
value={jobConfig.config.process[0].train.switch_boundary_every}
|
||||
onChange={value => setJobConfig(value, 'config.process[0].train.switch_boundary_every')}
|
||||
placeholder="eg. 1"
|
||||
docKey={'train.switch_boundary_every'}
|
||||
min={1}
|
||||
required
|
||||
/>
|
||||
</Card>
|
||||
)}
|
||||
<Card title="Target">
|
||||
@@ -319,7 +359,7 @@ export default function SimpleJob({
|
||||
max={1024}
|
||||
required
|
||||
/>
|
||||
{modelArch?.disableSections?.includes('network.conv') ? null : (
|
||||
{disableSections.includes('network.conv') ? null : (
|
||||
<NumberInput
|
||||
label="Conv Rank"
|
||||
value={jobConfig.config.process[0].network.conv}
|
||||
@@ -336,6 +376,38 @@ export default function SimpleJob({
|
||||
</>
|
||||
)}
|
||||
</Card>
|
||||
{!disableSections.includes('slider') && (
|
||||
<Card title="Slider">
|
||||
<TextInput
|
||||
label="Target Class"
|
||||
className=""
|
||||
value={jobConfig.config.process[0].slider?.target_class ?? ''}
|
||||
onChange={value => setJobConfig(value, 'config.process[0].slider.target_class')}
|
||||
placeholder="eg. person"
|
||||
/>
|
||||
<TextInput
|
||||
label="Positive Prompt"
|
||||
className=""
|
||||
value={jobConfig.config.process[0].slider?.positive_prompt ?? ''}
|
||||
onChange={value => setJobConfig(value, 'config.process[0].slider.positive_prompt')}
|
||||
placeholder="eg. person who is happy"
|
||||
/>
|
||||
<TextInput
|
||||
label="Negative Prompt"
|
||||
className=""
|
||||
value={jobConfig.config.process[0].slider?.negative_prompt ?? ''}
|
||||
onChange={value => setJobConfig(value, 'config.process[0].slider.negative_prompt')}
|
||||
placeholder="eg. person who is sad"
|
||||
/>
|
||||
<TextInput
|
||||
label="Anchor Class"
|
||||
className=""
|
||||
value={jobConfig.config.process[0].slider?.anchor_class ?? ''}
|
||||
onChange={value => setJobConfig(value, 'config.process[0].slider.anchor_class')}
|
||||
placeholder=""
|
||||
/>
|
||||
</Card>
|
||||
)}
|
||||
<Card title="Save">
|
||||
<SelectInput
|
||||
label="Data Type"
|
||||
@@ -367,7 +439,7 @@ export default function SimpleJob({
|
||||
</div>
|
||||
<div>
|
||||
<Card title="Training">
|
||||
<div className="grid grid-cols-1 md:grid-cols-3 lg:grid-cols-5 gap-6">
|
||||
<div className={trainingBarClass}>
|
||||
<div>
|
||||
<NumberInput
|
||||
label="Batch Size"
|
||||
@@ -426,11 +498,11 @@ export default function SimpleJob({
|
||||
/>
|
||||
</div>
|
||||
<div>
|
||||
{modelArch?.disableSections?.includes('train.timestep_type') ? null : (
|
||||
{disableSections.includes('train.timestep_type') ? null : (
|
||||
<SelectInput
|
||||
label="Timestep Type"
|
||||
value={jobConfig.config.process[0].train.timestep_type}
|
||||
disabled={modelArch?.disableSections?.includes('train.timestep_type') || false}
|
||||
disabled={disableSections.includes('train.timestep_type') || false}
|
||||
onChange={value => setJobConfig(value, 'config.process[0].train.timestep_type')}
|
||||
options={[
|
||||
{ value: 'sigmoid', label: 'Sigmoid' },
|
||||
@@ -508,33 +580,39 @@ export default function SimpleJob({
|
||||
</FormGroup>
|
||||
</div>
|
||||
<div>
|
||||
<FormGroup label="Regularization">
|
||||
<Checkbox
|
||||
label="Differtial Output Preservation"
|
||||
className="pt-1"
|
||||
checked={jobConfig.config.process[0].train.diff_output_preservation || false}
|
||||
onChange={value => setJobConfig(value, 'config.process[0].train.diff_output_preservation')}
|
||||
/>
|
||||
</FormGroup>
|
||||
{jobConfig.config.process[0].train.diff_output_preservation && (
|
||||
{disableSections.includes('train.diff_output_preservation') ? null : (
|
||||
<>
|
||||
<NumberInput
|
||||
label="DOP Loss Multiplier"
|
||||
className="pt-2"
|
||||
value={jobConfig.config.process[0].train.diff_output_preservation_multiplier as number}
|
||||
onChange={value =>
|
||||
setJobConfig(value, 'config.process[0].train.diff_output_preservation_multiplier')
|
||||
}
|
||||
placeholder="eg. 1.0"
|
||||
min={0}
|
||||
/>
|
||||
<TextInput
|
||||
label="DOP Preservation Class"
|
||||
className="pt-2"
|
||||
value={jobConfig.config.process[0].train.diff_output_preservation_class as string}
|
||||
onChange={value => setJobConfig(value, 'config.process[0].train.diff_output_preservation_class')}
|
||||
placeholder="eg. woman"
|
||||
/>
|
||||
<FormGroup label="Regularization">
|
||||
<Checkbox
|
||||
label="Differential Output Preservation"
|
||||
className="pt-1"
|
||||
checked={jobConfig.config.process[0].train.diff_output_preservation || false}
|
||||
onChange={value => setJobConfig(value, 'config.process[0].train.diff_output_preservation')}
|
||||
/>
|
||||
</FormGroup>
|
||||
{jobConfig.config.process[0].train.diff_output_preservation && (
|
||||
<>
|
||||
<NumberInput
|
||||
label="DOP Loss Multiplier"
|
||||
className="pt-2"
|
||||
value={jobConfig.config.process[0].train.diff_output_preservation_multiplier as number}
|
||||
onChange={value =>
|
||||
setJobConfig(value, 'config.process[0].train.diff_output_preservation_multiplier')
|
||||
}
|
||||
placeholder="eg. 1.0"
|
||||
min={0}
|
||||
/>
|
||||
<TextInput
|
||||
label="DOP Preservation Class"
|
||||
className="pt-2"
|
||||
value={jobConfig.config.process[0].train.diff_output_preservation_class as string}
|
||||
onChange={value =>
|
||||
setJobConfig(value, 'config.process[0].train.diff_output_preservation_class')
|
||||
}
|
||||
placeholder="eg. woman"
|
||||
/>
|
||||
</>
|
||||
)}
|
||||
</>
|
||||
)}
|
||||
</div>
|
||||
@@ -641,12 +719,20 @@ export default function SimpleJob({
|
||||
</FormGroup>
|
||||
<FormGroup label="Flipping" docKey={'datasets.flip'} className="mt-2">
|
||||
<Checkbox
|
||||
label={<>Flip X <FlipHorizontal2 className="inline-block w-4 h-4 ml-1" /></>}
|
||||
label={
|
||||
<>
|
||||
Flip X <FlipHorizontal2 className="inline-block w-4 h-4 ml-1" />
|
||||
</>
|
||||
}
|
||||
checked={dataset.flip_x || false}
|
||||
onChange={value => setJobConfig(value, `config.process[0].datasets[${i}].flip_x`)}
|
||||
/>
|
||||
<Checkbox
|
||||
label={<>Flip Y <FlipVertical2 className="inline-block w-4 h-4 ml-1" /></>}
|
||||
label={
|
||||
<>
|
||||
Flip Y <FlipVertical2 className="inline-block w-4 h-4 ml-1" />
|
||||
</>
|
||||
}
|
||||
checked={dataset.flip_y || false}
|
||||
onChange={value => setJobConfig(value, `config.process[0].datasets[${i}].flip_y`)}
|
||||
/>
|
||||
@@ -812,7 +898,7 @@ export default function SimpleJob({
|
||||
onChange={value => {
|
||||
setJobConfig(value, 'config.process[0].train.skip_first_sample');
|
||||
// cannot do both, so disable the other
|
||||
if (value){
|
||||
if (value) {
|
||||
setJobConfig(false, 'config.process[0].train.force_first_sample');
|
||||
}
|
||||
}}
|
||||
@@ -827,7 +913,7 @@ export default function SimpleJob({
|
||||
onChange={value => {
|
||||
setJobConfig(value, 'config.process[0].train.force_first_sample');
|
||||
// cannot do both, so disable the other
|
||||
if (value){
|
||||
if (value) {
|
||||
setJobConfig(false, 'config.process[0].train.skip_first_sample');
|
||||
}
|
||||
}}
|
||||
@@ -841,7 +927,7 @@ export default function SimpleJob({
|
||||
onChange={value => {
|
||||
setJobConfig(value, 'config.process[0].train.disable_sampling');
|
||||
// cannot do both, so disable the other
|
||||
if (value){
|
||||
if (value) {
|
||||
setJobConfig(false, 'config.process[0].train.force_first_sample');
|
||||
}
|
||||
}}
|
||||
@@ -866,6 +952,113 @@ export default function SimpleJob({
|
||||
placeholder="Enter prompt"
|
||||
required
|
||||
/>
|
||||
<div className="grid grid-cols-1 md:grid-cols-2 lg:grid-cols-4 gap-4 mt-2">
|
||||
<TextInput
|
||||
label={`Width`}
|
||||
value={sample.width ? `${sample.width}` : ''}
|
||||
onChange={value => {
|
||||
// remove any non-numeric characters
|
||||
value = value.replace(/\D/g, '');
|
||||
if (value === '') {
|
||||
// remove the key from the config if empty
|
||||
let newConfig = objectCopy(jobConfig);
|
||||
if (newConfig.config.process[0].sample.samples[i]) {
|
||||
delete newConfig.config.process[0].sample.samples[i].width;
|
||||
setJobConfig(
|
||||
newConfig.config.process[0].sample.samples,
|
||||
'config.process[0].sample.samples',
|
||||
);
|
||||
}
|
||||
} else {
|
||||
const intValue = parseInt(value);
|
||||
if (!isNaN(intValue)) {
|
||||
setJobConfig(intValue, `config.process[0].sample.samples[${i}].width`);
|
||||
} else {
|
||||
console.warn('Invalid width value:', value);
|
||||
}
|
||||
}
|
||||
}}
|
||||
placeholder={`${jobConfig.config.process[0].sample.width} (default)`}
|
||||
/>
|
||||
<TextInput
|
||||
label={`Height`}
|
||||
value={sample.height ? `${sample.height}` : ''}
|
||||
onChange={value => {
|
||||
// remove any non-numeric characters
|
||||
value = value.replace(/\D/g, '');
|
||||
if (value === '') {
|
||||
// remove the key from the config if empty
|
||||
let newConfig = objectCopy(jobConfig);
|
||||
if (newConfig.config.process[0].sample.samples[i]) {
|
||||
delete newConfig.config.process[0].sample.samples[i].height;
|
||||
setJobConfig(
|
||||
newConfig.config.process[0].sample.samples,
|
||||
'config.process[0].sample.samples',
|
||||
);
|
||||
}
|
||||
} else {
|
||||
const intValue = parseInt(value);
|
||||
if (!isNaN(intValue)) {
|
||||
setJobConfig(intValue, `config.process[0].sample.samples[${i}].height`);
|
||||
} else {
|
||||
console.warn('Invalid height value:', value);
|
||||
}
|
||||
}
|
||||
}}
|
||||
placeholder={`${jobConfig.config.process[0].sample.height} (default)`}
|
||||
/>
|
||||
<TextInput
|
||||
label={`Seed`}
|
||||
value={sample.seed ? `${sample.seed}` : ''}
|
||||
onChange={value => {
|
||||
// remove any non-numeric characters
|
||||
value = value.replace(/\D/g, '');
|
||||
if (value === '') {
|
||||
// remove the key from the config if empty
|
||||
let newConfig = objectCopy(jobConfig);
|
||||
if (newConfig.config.process[0].sample.samples[i]) {
|
||||
delete newConfig.config.process[0].sample.samples[i].seed;
|
||||
setJobConfig(
|
||||
newConfig.config.process[0].sample.samples,
|
||||
'config.process[0].sample.samples',
|
||||
);
|
||||
}
|
||||
} else {
|
||||
const intValue = parseInt(value);
|
||||
if (!isNaN(intValue)) {
|
||||
setJobConfig(intValue, `config.process[0].sample.samples[${i}].seed`);
|
||||
} else {
|
||||
console.warn('Invalid seed value:', value);
|
||||
}
|
||||
}
|
||||
}}
|
||||
placeholder={`${jobConfig.config.process[0].sample.walk_seed ? jobConfig.config.process[0].sample.seed + i : jobConfig.config.process[0].sample.seed} (default)`}
|
||||
/>
|
||||
<TextInput
|
||||
label={`LoRA Scale`}
|
||||
value={sample.network_multiplier ? `${sample.network_multiplier}` : ''}
|
||||
onChange={value => {
|
||||
// remove any non-numeric, - or . characters
|
||||
value = value.replace(/[^0-9.-]/g, '');
|
||||
if (value === '') {
|
||||
// remove the key from the config if empty
|
||||
let newConfig = objectCopy(jobConfig);
|
||||
if (newConfig.config.process[0].sample.samples[i]) {
|
||||
delete newConfig.config.process[0].sample.samples[i].network_multiplier;
|
||||
setJobConfig(
|
||||
newConfig.config.process[0].sample.samples,
|
||||
'config.process[0].sample.samples',
|
||||
);
|
||||
}
|
||||
} else {
|
||||
// set it as a string
|
||||
setJobConfig(value, `config.process[0].sample.samples[${i}].network_multiplier`);
|
||||
return;
|
||||
}
|
||||
}}
|
||||
placeholder={`1.0 (default)`}
|
||||
/>
|
||||
</div>
|
||||
</div>
|
||||
|
||||
{modelArch?.additionalSections?.includes('sample.ctrl_img') && (
|
||||
|
||||
@@ -1,4 +1,4 @@
|
||||
import { JobConfig, DatasetConfig } from '@/types';
|
||||
import { JobConfig, DatasetConfig, SliderConfig } from '@/types';
|
||||
|
||||
export const defaultDatasetConfig: DatasetConfig = {
|
||||
folder_path: '/path/to/images/folder',
|
||||
@@ -20,13 +20,22 @@ export const defaultDatasetConfig: DatasetConfig = {
|
||||
flip_y: false,
|
||||
};
|
||||
|
||||
export const defaultSliderConfig: SliderConfig = {
|
||||
guidance_strength: 3.0,
|
||||
anchor_strength: 1.0,
|
||||
positive_prompt: 'person who is happy',
|
||||
negative_prompt: 'person who is sad',
|
||||
target_class: 'person',
|
||||
anchor_class: "",
|
||||
};
|
||||
|
||||
export const defaultJobConfig: JobConfig = {
|
||||
job: 'extension',
|
||||
config: {
|
||||
name: 'my_first_lora_v1',
|
||||
process: [
|
||||
{
|
||||
type: 'ui_trainer',
|
||||
type: 'diffusion_trainer',
|
||||
training_folder: 'output',
|
||||
sqlite_db_path: './aitk_db.db',
|
||||
device: 'cuda',
|
||||
@@ -100,7 +109,7 @@ export const defaultJobConfig: JobConfig = {
|
||||
height: 1024,
|
||||
samples: [
|
||||
{
|
||||
prompt: 'woman with red hair, playing chess at the park, bomb going off in the background'
|
||||
prompt: 'woman with red hair, playing chess at the park, bomb going off in the background',
|
||||
},
|
||||
{
|
||||
prompt: 'a woman holding a coffee cup, in a beanie, sitting at a cafe',
|
||||
@@ -109,7 +118,8 @@ export const defaultJobConfig: JobConfig = {
|
||||
prompt: 'a horse is a DJ at a night club, fish eye lens, smoke machine, lazer lights, holding a martini',
|
||||
},
|
||||
{
|
||||
prompt: 'a man showing off his cool new t shirt at the beach, a shark is jumping out of the water in the background',
|
||||
prompt:
|
||||
'a man showing off his cool new t shirt at the beach, a shark is jumping out of the water in the background',
|
||||
},
|
||||
{
|
||||
prompt: 'a bear building a log cabin in the snow covered mountains',
|
||||
@@ -121,13 +131,15 @@ export const defaultJobConfig: JobConfig = {
|
||||
prompt: 'hipster man with a beard, building a chair, in a wood shop',
|
||||
},
|
||||
{
|
||||
prompt: 'photo of a man, white background, medium shot, modeling clothing, studio lighting, white backdrop',
|
||||
prompt:
|
||||
'photo of a man, white background, medium shot, modeling clothing, studio lighting, white backdrop',
|
||||
},
|
||||
{
|
||||
prompt: "a man holding a sign that says, 'this is a sign'",
|
||||
},
|
||||
{
|
||||
prompt: 'a bulldog, in a post apocalyptic world, with a shotgun, in a leather jacket, in a desert, with a motorcycle',
|
||||
prompt:
|
||||
'a bulldog, in a post apocalyptic world, with a shotgun, in a leather jacket, in a desert, with a motorcycle',
|
||||
},
|
||||
],
|
||||
neg: '',
|
||||
|
||||
@@ -1,8 +1,16 @@
|
||||
import { GroupedSelectOption, SelectOption } from '@/types';
|
||||
import { GroupedSelectOption, SelectOption, JobConfig } from '@/types';
|
||||
import { defaultSliderConfig } from './jobConfig';
|
||||
|
||||
type Control = 'depth' | 'line' | 'pose' | 'inpaint';
|
||||
|
||||
type DisableableSections = 'model.quantize' | 'train.timestep_type' | 'network.conv';
|
||||
type DisableableSections =
|
||||
| 'model.quantize'
|
||||
| 'train.timestep_type'
|
||||
| 'network.conv'
|
||||
| 'trigger_word'
|
||||
| 'train.diff_output_preservation'
|
||||
| 'slider';
|
||||
|
||||
type AdditionalSections =
|
||||
| 'datasets.control_path'
|
||||
| 'datasets.do_i2v'
|
||||
@@ -439,3 +447,33 @@ export const quantizationOptions: SelectOption[] = [
|
||||
];
|
||||
|
||||
export const defaultQtype = 'qfloat8';
|
||||
|
||||
interface JobTypeOption extends SelectOption {
|
||||
disableSections?: DisableableSections[];
|
||||
processSections?: string[];
|
||||
onActivate?: (config: JobConfig) => JobConfig;
|
||||
onDeactivate?: (config: JobConfig) => JobConfig;
|
||||
}
|
||||
|
||||
export const jobTypeOptions: JobTypeOption[] = [
|
||||
{
|
||||
value: 'diffusion_trainer',
|
||||
label: 'LoRA Trainer',
|
||||
disableSections: ['slider'],
|
||||
},
|
||||
{
|
||||
value: 'concept_slider',
|
||||
label: 'Concept Slider',
|
||||
disableSections: ['trigger_word', 'train.diff_output_preservation'],
|
||||
onActivate: (config: JobConfig) => {
|
||||
// add default slider config
|
||||
config.config.process[0].slider = { ...defaultSliderConfig };
|
||||
return config;
|
||||
},
|
||||
onDeactivate: (config: JobConfig) => {
|
||||
// remove slider config
|
||||
delete config.config.process[0].slider;
|
||||
return config;
|
||||
},
|
||||
},
|
||||
];
|
||||
|
||||
@@ -3,6 +3,7 @@
|
||||
import { useEffect, useState } from 'react';
|
||||
import { useSearchParams, useRouter } from 'next/navigation';
|
||||
import { defaultJobConfig, defaultDatasetConfig, migrateJobConfig } from './jobConfig';
|
||||
import { jobTypeOptions } from './options';
|
||||
import { JobConfig } from '@/types';
|
||||
import { objectCopy } from '@/utils/basic';
|
||||
import { useNestedState } from '@/utils/hooks';
|
||||
@@ -144,6 +145,38 @@ export default function TrainingForm() {
|
||||
<div className="mx-4 bg-gray-200 dark:bg-gray-800 w-1 h-6"></div>
|
||||
</>
|
||||
)}
|
||||
{!showAdvancedView && (
|
||||
<>
|
||||
<div>
|
||||
<SelectInput
|
||||
value={`${jobConfig?.config.process[0].type}`}
|
||||
onChange={value => {
|
||||
// undo current job type changes
|
||||
const currentOption = jobTypeOptions.find(
|
||||
option => option.value === jobConfig?.config.process[0].type,
|
||||
);
|
||||
if (currentOption && currentOption.onDeactivate) {
|
||||
setJobConfig(currentOption.onDeactivate(objectCopy(jobConfig)));
|
||||
}
|
||||
const option = jobTypeOptions.find(option => option.value === value);
|
||||
if (option) {
|
||||
if (option.onActivate) {
|
||||
setJobConfig(option.onActivate(objectCopy(jobConfig)));
|
||||
}
|
||||
jobTypeOptions.forEach(opt => {
|
||||
if (opt.value !== option.value && opt.onDeactivate) {
|
||||
setJobConfig(opt.onDeactivate(objectCopy(jobConfig)));
|
||||
}
|
||||
});
|
||||
}
|
||||
setJobConfig(value, 'config.process[0].type');
|
||||
}}
|
||||
options={jobTypeOptions}
|
||||
/>
|
||||
</div>
|
||||
<div className="mx-4 bg-gray-200 dark:bg-gray-800 w-1 h-6"></div>
|
||||
</>
|
||||
)}
|
||||
|
||||
<div className="pr-2">
|
||||
<Button
|
||||
|
||||
@@ -68,8 +68,8 @@ export const TextInput = forwardRef<HTMLInputElement, TextInputProps>((props: Te
|
||||
TextInput.displayName = 'TextInput';
|
||||
|
||||
export interface NumberInputProps extends InputProps {
|
||||
value: number;
|
||||
onChange: (value: number) => void;
|
||||
value: number | null;
|
||||
onChange: (value: number | null) => void;
|
||||
min?: number;
|
||||
max?: number;
|
||||
}
|
||||
|
||||
@@ -143,7 +143,7 @@ export interface ModelConfig {
|
||||
|
||||
export interface SampleItem {
|
||||
prompt: string;
|
||||
width?: number
|
||||
width?: number;
|
||||
height?: number;
|
||||
neg?: string;
|
||||
seed?: number;
|
||||
@@ -153,6 +153,7 @@ export interface SampleItem {
|
||||
num_frames?: number;
|
||||
ctrl_img?: string | null;
|
||||
ctrl_idx?: number;
|
||||
network_multiplier?: number;
|
||||
}
|
||||
|
||||
export interface SampleConfig {
|
||||
@@ -171,14 +172,24 @@ export interface SampleConfig {
|
||||
fps: number;
|
||||
}
|
||||
|
||||
export interface SliderConfig {
|
||||
guidance_strength?: number;
|
||||
anchor_strength?: number;
|
||||
positive_prompt?: string;
|
||||
negative_prompt?: string;
|
||||
target_class?: string;
|
||||
anchor_class?: string | null;
|
||||
}
|
||||
|
||||
export interface ProcessConfig {
|
||||
type: 'ui_trainer';
|
||||
type: string;
|
||||
sqlite_db_path?: string;
|
||||
training_folder: string;
|
||||
performance_log_every: number;
|
||||
trigger_word: string | null;
|
||||
device: string;
|
||||
network?: NetworkConfig;
|
||||
slider?: SliderConfig;
|
||||
save: SaveConfig;
|
||||
datasets: DatasetConfig[];
|
||||
train: TrainConfig;
|
||||
|
||||
Reference in New Issue
Block a user