mirror of
https://github.com/ostris/ai-toolkit.git
synced 2026-03-01 16:49:56 +00:00
Added ui sopport for multi control samples and datasets. Added qwen image edit 5209 to the ui
This commit is contained in:
@@ -15,7 +15,9 @@ import { TextInput, SelectInput, Checkbox, FormGroup, NumberInput } from '@/comp
|
||||
import Card from '@/components/Card';
|
||||
import { X } from 'lucide-react';
|
||||
import AddSingleImageModal, { openAddImageModal } from '@/components/AddSingleImageModal';
|
||||
import SampleControlImage from '@/components/SampleControlImage';
|
||||
import { FlipHorizontal2, FlipVertical2 } from 'lucide-react';
|
||||
import { handleModelArchChange } from './utils';
|
||||
|
||||
type Props = {
|
||||
jobConfig: JobConfig;
|
||||
@@ -185,58 +187,7 @@ export default function SimpleJob({
|
||||
label="Model Architecture"
|
||||
value={jobConfig.config.process[0].model.arch}
|
||||
onChange={value => {
|
||||
const currentArch = modelArchs.find(a => a.name === jobConfig.config.process[0].model.arch);
|
||||
if (!currentArch || currentArch.name === value) {
|
||||
return;
|
||||
}
|
||||
// update the defaults when a model is selected
|
||||
const newArch = modelArchs.find(model => model.name === value);
|
||||
|
||||
// update vram setting
|
||||
if (!newArch?.additionalSections?.includes('model.low_vram')) {
|
||||
setJobConfig(false, 'config.process[0].model.low_vram');
|
||||
}
|
||||
|
||||
// revert defaults from previous model
|
||||
for (const key in currentArch.defaults) {
|
||||
setJobConfig(currentArch.defaults[key][1], key);
|
||||
}
|
||||
|
||||
if (newArch?.defaults) {
|
||||
for (const key in newArch.defaults) {
|
||||
setJobConfig(newArch.defaults[key][0], key);
|
||||
}
|
||||
}
|
||||
// set new model
|
||||
setJobConfig(value, 'config.process[0].model.arch');
|
||||
|
||||
// update datasets
|
||||
const hasControlPath = newArch?.additionalSections?.includes('datasets.control_path') || false;
|
||||
const hasNumFrames = newArch?.additionalSections?.includes('datasets.num_frames') || false;
|
||||
const controls = newArch?.controls ?? [];
|
||||
const datasets = jobConfig.config.process[0].datasets.map(dataset => {
|
||||
const newDataset = objectCopy(dataset);
|
||||
newDataset.controls = controls;
|
||||
if (!hasControlPath) {
|
||||
newDataset.control_path = null; // reset control path if not applicable
|
||||
}
|
||||
if (!hasNumFrames) {
|
||||
newDataset.num_frames = 1; // reset num_frames if not applicable
|
||||
}
|
||||
return newDataset;
|
||||
});
|
||||
setJobConfig(datasets, 'config.process[0].datasets');
|
||||
|
||||
// update samples
|
||||
const hasSampleCtrlImg = newArch?.additionalSections?.includes('sample.ctrl_img') || false;
|
||||
const samples = jobConfig.config.process[0].sample.samples.map(sample => {
|
||||
const newSample = objectCopy(sample);
|
||||
if (!hasSampleCtrlImg) {
|
||||
delete newSample.ctrl_img; // remove ctrl_img if not applicable
|
||||
}
|
||||
return newSample;
|
||||
});
|
||||
setJobConfig(samples, 'config.process[0].sample.samples');
|
||||
handleModelArchChange(jobConfig.config.process[0].model.arch, value, jobConfig, setJobConfig);
|
||||
}}
|
||||
options={groupedModelOptions}
|
||||
/>
|
||||
@@ -557,17 +508,19 @@ export default function SimpleJob({
|
||||
)}
|
||||
|
||||
<FormGroup label="Text Encoder Optimizations" className="pt-2">
|
||||
<Checkbox
|
||||
label="Unload TE"
|
||||
checked={jobConfig.config.process[0].train.unload_text_encoder || false}
|
||||
docKey={'train.unload_text_encoder'}
|
||||
onChange={value => {
|
||||
setJobConfig(value, 'config.process[0].train.unload_text_encoder');
|
||||
if (value) {
|
||||
setJobConfig(false, 'config.process[0].train.cache_text_embeddings');
|
||||
}
|
||||
}}
|
||||
/>
|
||||
{!disableSections.includes('train.unload_text_encoder') && (
|
||||
<Checkbox
|
||||
label="Unload TE"
|
||||
checked={jobConfig.config.process[0].train.unload_text_encoder || false}
|
||||
docKey={'train.unload_text_encoder'}
|
||||
onChange={value => {
|
||||
setJobConfig(value, 'config.process[0].train.unload_text_encoder');
|
||||
if (value) {
|
||||
setJobConfig(false, 'config.process[0].train.cache_text_embeddings');
|
||||
}
|
||||
}}
|
||||
/>
|
||||
)}
|
||||
<Checkbox
|
||||
label="Cache Text Embeddings"
|
||||
checked={jobConfig.config.process[0].train.cache_text_embeddings || false}
|
||||
@@ -642,7 +595,7 @@ export default function SimpleJob({
|
||||
<div className="grid grid-cols-1 md:grid-cols-2 lg:grid-cols-4 gap-6">
|
||||
<div>
|
||||
<SelectInput
|
||||
label="Dataset"
|
||||
label="Target Dataset"
|
||||
value={dataset.folder_path}
|
||||
onChange={value => setJobConfig(value, `config.process[0].datasets[${i}].folder_path`)}
|
||||
options={datasetOptions}
|
||||
@@ -659,6 +612,49 @@ export default function SimpleJob({
|
||||
options={[{ value: '', label: <> </> }, ...datasetOptions]}
|
||||
/>
|
||||
)}
|
||||
{modelArch?.additionalSections?.includes('datasets.multi_control_paths') && (
|
||||
<>
|
||||
<SelectInput
|
||||
label="Control Dataset 1"
|
||||
docKey="datasets.multi_control_paths"
|
||||
value={dataset.control_path_1 ?? ''}
|
||||
className="pt-2"
|
||||
onChange={value =>
|
||||
setJobConfig(
|
||||
value == '' ? null : value,
|
||||
`config.process[0].datasets[${i}].control_path_1`,
|
||||
)
|
||||
}
|
||||
options={[{ value: '', label: <> </> }, ...datasetOptions]}
|
||||
/>
|
||||
<SelectInput
|
||||
label="Control Dataset 2"
|
||||
docKey="datasets.multi_control_paths"
|
||||
value={dataset.control_path_2 ?? ''}
|
||||
className="pt-2"
|
||||
onChange={value =>
|
||||
setJobConfig(
|
||||
value == '' ? null : value,
|
||||
`config.process[0].datasets[${i}].control_path_2`,
|
||||
)
|
||||
}
|
||||
options={[{ value: '', label: <> </> }, ...datasetOptions]}
|
||||
/>
|
||||
<SelectInput
|
||||
label="Control Dataset 3"
|
||||
docKey="datasets.multi_control_paths"
|
||||
value={dataset.control_path_3 ?? ''}
|
||||
className="pt-2"
|
||||
onChange={value =>
|
||||
setJobConfig(
|
||||
value == '' ? null : value,
|
||||
`config.process[0].datasets[${i}].control_path_3`,
|
||||
)
|
||||
}
|
||||
options={[{ value: '', label: <> </> }, ...datasetOptions]}
|
||||
/>
|
||||
</>
|
||||
)}
|
||||
<NumberInput
|
||||
label="LoRA Weight"
|
||||
value={dataset.network_weight}
|
||||
@@ -1062,30 +1058,43 @@ export default function SimpleJob({
|
||||
/>
|
||||
</div>
|
||||
</div>
|
||||
|
||||
{modelArch?.additionalSections?.includes('datasets.multi_control_paths') && (
|
||||
<FormGroup label="Control Images" className="pt-2 ml-4">
|
||||
<div className="grid grid-cols-1 md:grid-cols-3 gap-2 mt-2 mt-2">
|
||||
{['ctrl_img_1', 'ctrl_img_2', 'ctrl_img_3'].map((ctrlKey, ctrl_idx) => (
|
||||
<SampleControlImage
|
||||
key={ctrlKey}
|
||||
instruction={`Add Control Image ${ctrl_idx + 1}`}
|
||||
className=""
|
||||
src={sample[ctrlKey as keyof typeof sample] as string}
|
||||
onNewImageSelected={imagePath => {
|
||||
if (!imagePath) {
|
||||
let newSamples = objectCopy(jobConfig.config.process[0].sample.samples);
|
||||
delete newSamples[i][ctrlKey as keyof typeof sample];
|
||||
setJobConfig(newSamples, 'config.process[0].sample.samples');
|
||||
} else {
|
||||
setJobConfig(imagePath, `config.process[0].sample.samples[${i}].${ctrlKey}`);
|
||||
}
|
||||
}}
|
||||
/>
|
||||
))}
|
||||
</div>
|
||||
</FormGroup>
|
||||
)}
|
||||
{modelArch?.additionalSections?.includes('sample.ctrl_img') && (
|
||||
<div
|
||||
className="h-14 w-14 mt-2 ml-4 border border-gray-500 flex items-center justify-center rounded cursor-pointer hover:bg-gray-700 transition-colors"
|
||||
style={{
|
||||
backgroundImage: sample.ctrl_img
|
||||
? `url(${`/api/img/${encodeURIComponent(sample.ctrl_img)}`})`
|
||||
: 'none',
|
||||
backgroundSize: 'cover',
|
||||
backgroundPosition: 'center',
|
||||
marginBottom: '-1rem',
|
||||
}}
|
||||
onClick={() => {
|
||||
openAddImageModal(imagePath => {
|
||||
console.log('Selected image path:', imagePath);
|
||||
if (!imagePath) return;
|
||||
<SampleControlImage
|
||||
className="mt-6 ml-4"
|
||||
src={sample.ctrl_img}
|
||||
onNewImageSelected={imagePath => {
|
||||
if (!imagePath) {
|
||||
let newSamples = objectCopy(jobConfig.config.process[0].sample.samples);
|
||||
delete newSamples[i].ctrl_img;
|
||||
setJobConfig(newSamples, 'config.process[0].sample.samples');
|
||||
} else {
|
||||
setJobConfig(imagePath, `config.process[0].sample.samples[${i}].ctrl_img`);
|
||||
});
|
||||
}
|
||||
}}
|
||||
>
|
||||
{!sample.ctrl_img && (
|
||||
<div className="text-gray-400 text-xs text-center font-bold">Add Control Image</div>
|
||||
)}
|
||||
</div>
|
||||
/>
|
||||
)}
|
||||
</div>
|
||||
<div className="pb-4"></div>
|
||||
|
||||
@@ -2,7 +2,6 @@ import { JobConfig, DatasetConfig, SliderConfig } from '@/types';
|
||||
|
||||
export const defaultDatasetConfig: DatasetConfig = {
|
||||
folder_path: '/path/to/images/folder',
|
||||
control_path: null,
|
||||
mask_path: null,
|
||||
mask_min_value: 0.1,
|
||||
default_caption: '',
|
||||
|
||||
@@ -9,12 +9,15 @@ type DisableableSections =
|
||||
| 'network.conv'
|
||||
| 'trigger_word'
|
||||
| 'train.diff_output_preservation'
|
||||
| 'train.unload_text_encoder'
|
||||
| 'slider';
|
||||
|
||||
type AdditionalSections =
|
||||
| 'datasets.control_path'
|
||||
| 'datasets.multi_control_paths'
|
||||
| 'datasets.do_i2v'
|
||||
| 'sample.ctrl_img'
|
||||
| 'sample.multi_ctrl_imgs'
|
||||
| 'datasets.num_frames'
|
||||
| 'model.multistage'
|
||||
| 'model.low_vram';
|
||||
@@ -335,6 +338,28 @@ export const modelArchs: ModelArch[] = [
|
||||
'3 bit with ARA': 'uint3|ostris/accuracy_recovery_adapters/qwen_image_edit_torchao_uint3.safetensors',
|
||||
},
|
||||
},
|
||||
{
|
||||
name: 'qwen_image_edit_plus',
|
||||
label: 'Qwen-Image-Edit-2509',
|
||||
group: 'instruction',
|
||||
defaults: {
|
||||
// default updates when [selected, unselected] in the UI
|
||||
'config.process[0].model.name_or_path': ['Qwen/Qwen-Image-Edit-2509', defaultNameOrPath],
|
||||
'config.process[0].model.quantize': [true, false],
|
||||
'config.process[0].model.quantize_te': [true, false],
|
||||
'config.process[0].model.low_vram': [true, false],
|
||||
'config.process[0].train.unload_text_encoder': [false, false],
|
||||
'config.process[0].sample.sampler': ['flowmatch', 'flowmatch'],
|
||||
'config.process[0].train.noise_scheduler': ['flowmatch', 'flowmatch'],
|
||||
'config.process[0].train.timestep_type': ['weighted', 'sigmoid'],
|
||||
'config.process[0].model.qtype': ['qfloat8', 'qfloat8'],
|
||||
},
|
||||
disableSections: ['network.conv', 'train.unload_text_encoder'],
|
||||
additionalSections: ['datasets.multi_control_paths', 'sample.multi_ctrl_imgs', 'model.low_vram'],
|
||||
accuracyRecoveryAdapters: {
|
||||
'3 bit with ARA': 'uint3|ostris/accuracy_recovery_adapters/qwen_image_edit_2509_torchao_uint3.safetensors',
|
||||
},
|
||||
},
|
||||
{
|
||||
name: 'hidream',
|
||||
label: 'HiDream',
|
||||
|
||||
105
ui/src/app/jobs/new/utils.ts
Normal file
105
ui/src/app/jobs/new/utils.ts
Normal file
@@ -0,0 +1,105 @@
|
||||
import { GroupedSelectOption, JobConfig, SelectOption } from '@/types';
|
||||
import { modelArchs, ModelArch } from './options';
|
||||
import { objectCopy } from '@/utils/basic';
|
||||
|
||||
export const handleModelArchChange = (
|
||||
currentArchName: string,
|
||||
newArchName: string,
|
||||
jobConfig: JobConfig,
|
||||
setJobConfig: (value: any, key: string) => void,
|
||||
) => {
|
||||
const currentArch = modelArchs.find(a => a.name === currentArchName);
|
||||
if (!currentArch || currentArch.name === newArchName) {
|
||||
return;
|
||||
}
|
||||
|
||||
// update the defaults when a model is selected
|
||||
const newArch = modelArchs.find(model => model.name === newArchName);
|
||||
|
||||
// update vram setting
|
||||
if (!newArch?.additionalSections?.includes('model.low_vram')) {
|
||||
setJobConfig(false, 'config.process[0].model.low_vram');
|
||||
}
|
||||
|
||||
// revert defaults from previous model
|
||||
for (const key in currentArch.defaults) {
|
||||
setJobConfig(currentArch.defaults[key][1], key);
|
||||
}
|
||||
|
||||
if (newArch?.defaults) {
|
||||
for (const key in newArch.defaults) {
|
||||
setJobConfig(newArch.defaults[key][0], key);
|
||||
}
|
||||
}
|
||||
// set new model
|
||||
setJobConfig(newArchName, 'config.process[0].model.arch');
|
||||
|
||||
// update datasets
|
||||
const hasControlPath = newArch?.additionalSections?.includes('datasets.control_path') || false;
|
||||
const hasMultiControlPaths = newArch?.additionalSections?.includes('datasets.multi_control_paths') || false;
|
||||
const hasNumFrames = newArch?.additionalSections?.includes('datasets.num_frames') || false;
|
||||
const controls = newArch?.controls ?? [];
|
||||
const datasets = jobConfig.config.process[0].datasets.map(dataset => {
|
||||
const newDataset = objectCopy(dataset);
|
||||
newDataset.controls = controls;
|
||||
if (hasMultiControlPaths) {
|
||||
// make sure the config has the multi control paths
|
||||
newDataset.control_path_1 = newDataset.control_path_1 || null;
|
||||
newDataset.control_path_2 = newDataset.control_path_2 || null;
|
||||
newDataset.control_path_3 = newDataset.control_path_3 || null;
|
||||
// if we previously had a single control path and now
|
||||
// we selected a multi control model
|
||||
if (newDataset.control_path && newDataset.control_path !== '') {
|
||||
// only set if not overwriting
|
||||
if (!newDataset.control_path_1) {
|
||||
newDataset.control_path_1 = newDataset.control_path;
|
||||
}
|
||||
}
|
||||
delete newDataset.control_path; // remove single control path
|
||||
} else if (hasControlPath) {
|
||||
newDataset.control_path = newDataset.control_path || null;
|
||||
if (newDataset.control_path_1 && newDataset.control_path_1 !== '') {
|
||||
newDataset.control_path = newDataset.control_path_1;
|
||||
}
|
||||
if (newDataset.control_path_1) {
|
||||
delete newDataset.control_path_1;
|
||||
}
|
||||
if (newDataset.control_path_2) {
|
||||
delete newDataset.control_path_2;
|
||||
}
|
||||
if (newDataset.control_path_3) {
|
||||
delete newDataset.control_path_3;
|
||||
}
|
||||
} else {
|
||||
// does not have control images
|
||||
if (newDataset.control_path) {
|
||||
delete newDataset.control_path;
|
||||
}
|
||||
if (newDataset.control_path_1) {
|
||||
delete newDataset.control_path_1;
|
||||
}
|
||||
if (newDataset.control_path_2) {
|
||||
delete newDataset.control_path_2;
|
||||
}
|
||||
if (newDataset.control_path_3) {
|
||||
delete newDataset.control_path_3;
|
||||
}
|
||||
}
|
||||
if (!hasNumFrames) {
|
||||
newDataset.num_frames = 1; // reset num_frames if not applicable
|
||||
}
|
||||
return newDataset;
|
||||
});
|
||||
setJobConfig(datasets, 'config.process[0].datasets');
|
||||
|
||||
// update samples
|
||||
const hasSampleCtrlImg = newArch?.additionalSections?.includes('sample.ctrl_img') || false;
|
||||
const samples = jobConfig.config.process[0].sample.samples.map(sample => {
|
||||
const newSample = objectCopy(sample);
|
||||
if (!hasSampleCtrlImg) {
|
||||
delete newSample.ctrl_img; // remove ctrl_img if not applicable
|
||||
}
|
||||
return newSample;
|
||||
});
|
||||
setJobConfig(samples, 'config.process[0].sample.samples');
|
||||
};
|
||||
Reference in New Issue
Block a user