Added support for Flex.2 in the UI

This commit is contained in:
Jaret Burkett
2025-05-07 12:41:51 -06:00
parent 43cb5603ad
commit 25e150b370
5 changed files with 62 additions and 126 deletions

View File

@@ -1,6 +1,6 @@
'use client';
import { options, modelArchs, isVideoModelFromArch } from './options';
import { useMemo } from 'react';
import { modelArchs, ModelArch } from './options';
import { defaultDatasetConfig } from './jobConfig';
import { JobConfig } from '@/types';
import { objectCopy } from '@/utils/basic';
@@ -33,7 +33,13 @@ export default function SimpleJob({
gpuList,
datasetOptions,
}: Props) {
const isVideoModel = isVideoModelFromArch(jobConfig.config.process[0].model.arch);
const modelArch = useMemo(() => {
return modelArchs.find(a => a.name === jobConfig.config.process[0].model.arch) as ModelArch;
}, [jobConfig.config.process[0].model.arch]);
const isVideoModel = !!modelArch?.isVideoModel;
return (
<>
<form onSubmit={handleSubmit} className="space-y-8">
@@ -91,6 +97,16 @@ export default function SimpleJob({
}
// set new model
setJobConfig(value, 'config.process[0].model.arch');
// update controls for datasets
const controls = newArch?.controls ?? [];
const datasets = jobConfig.config.process[0].datasets.map(dataset => {
const newDataset = objectCopy(dataset);
newDataset.controls = controls;
return newDataset;
}
);
setJobConfig(datasets, 'config.process[0].datasets');
}}
options={
modelArchs
@@ -445,12 +461,16 @@ export default function SimpleJob({
))}
<button
type="button"
onClick={() =>
onClick={() => {
const newDataset = objectCopy(defaultDatasetConfig);
// automaticallt add the controls for a new dataset
const controls = modelArch?.controls ?? [];
newDataset.controls = controls;
setJobConfig(
[...jobConfig.config.process[0].datasets, objectCopy(defaultDatasetConfig)],
[...jobConfig.config.process[0].datasets, newDataset],
'config.process[0].datasets',
)
}
}}
className="w-full px-4 py-2 bg-gray-700 hover:bg-gray-600 rounded-lg transition-colors"
>
Add Dataset

View File

@@ -11,6 +11,7 @@ export const defaultDatasetConfig: DatasetConfig = {
is_reg: false,
network_weight: 1,
resolution: [512, 768, 1024],
controls: []
};
export const defaultJobConfig: JobConfig = {
@@ -75,6 +76,7 @@ export const defaultJobConfig: JobConfig = {
quantize_te: true,
arch: 'flex1',
low_vram: false,
model_kwargs: {},
},
sample: {
sampler: 'flowmatch',

View File

@@ -1,17 +1,11 @@
export interface Model {
name_or_path: string;
arch: string;
dev_only?: boolean;
defaults?: { [key: string]: any };
}
export interface Option {
model: Model[];
}
type Control = 'depth' | 'line' | 'pose' | 'inpaint';
export interface ModelArch {
name: string;
label: string;
controls?: Control[];
isVideoModel?: boolean;
defaults?: { [key: string]: [any, any] };
}
@@ -43,7 +37,32 @@ export const modelArchs: ModelArch[] = [
'config.process[0].train.noise_scheduler': ['flowmatch', 'flowmatch'],
},
},
// { name: 'flex2', label: 'Flex.2' },
{
name: 'flex2',
label: 'Flex.2',
controls: ['depth', 'line', 'pose', 'inpaint'],
defaults: {
// default updates when [selected, unselected] in the UI
'config.process[0].model.name_or_path': ['ostris/Flex.2-preview', defaultNameOrPath],
'config.process[0].model.quantize': [true, false],
'config.process[0].model.quantize_te': [true, false],
'config.process[0].model.model_kwargs': [
{
invert_inpaint_mask_chance: 0.2,
inpaint_dropout: 0.5,
control_dropout: 0.5,
inpaint_random_chance: 0.2,
do_random_inpainting: true,
random_blur_mask: true,
random_dialate_mask: true,
},
{},
],
'config.process[0].train.bypass_guidance_embedding': [true, false],
'config.process[0].sample.sampler': ['flowmatch', 'flowmatch'],
'config.process[0].train.noise_scheduler': ['flowmatch', 'flowmatch'],
},
},
{
name: 'chroma',
label: 'Chroma',
@@ -59,6 +78,7 @@ export const modelArchs: ModelArch[] = [
{
name: 'wan21:1b',
label: 'Wan 2.1 (1.3B)',
isVideoModel: true,
defaults: {
// default updates when [selected, unselected] in the UI
'config.process[0].model.name_or_path': ['Wan-AI/Wan2.1-T2V-1.3B-Diffusers', defaultNameOrPath],
@@ -73,6 +93,7 @@ export const modelArchs: ModelArch[] = [
{
name: 'wan21:14b',
label: 'Wan 2.1 (14B)',
isVideoModel: true,
defaults: {
// default updates when [selected, unselected] in the UI
'config.process[0].model.name_or_path': ['Wan-AI/Wan2.1-T2V-14B-Diffuserss', defaultNameOrPath],
@@ -112,109 +133,3 @@ export const modelArchs: ModelArch[] = [
},
},
];
export const isVideoModelFromArch = (arch: string) => {
const videoArches = ['wan21'];
return videoArches.includes(arch);
};
const defaultModelArch = 'flux';
export const options: Option = {
model: [
{
name_or_path: 'ostris/Flex.1-alpha',
arch: 'flex1',
defaults: {
// default updates when [selected, unselected] in the UI
'config.process[0].model.quantize': [true, false],
'config.process[0].model.quantize_te': [true, false],
'config.process[0].train.bypass_guidance_embedding': [true, false],
'config.process[0].sample.sampler': ['flowmatch', 'flowmatch'],
'config.process[0].train.noise_scheduler': ['flowmatch', 'flowmatch'],
},
},
{
name_or_path: 'black-forest-labs/FLUX.1-dev',
arch: 'flux',
defaults: {
// default updates when [selected, unselected] in the UI
'config.process[0].model.quantize': [true, false],
'config.process[0].model.quantize_te': [true, false],
'config.process[0].sample.sampler': ['flowmatch', 'flowmatch'],
'config.process[0].train.noise_scheduler': ['flowmatch', 'flowmatch'],
},
},
{
name_or_path: 'lodestones/Chroma',
arch: 'chroma',
defaults: {
// default updates when [selected, unselected] in the UI
'config.process[0].model.quantize': [true, false],
'config.process[0].model.quantize_te': [true, false],
'config.process[0].sample.sampler': ['flowmatch', 'flowmatch'],
'config.process[0].train.noise_scheduler': ['flowmatch', 'flowmatch'],
},
},
{
name_or_path: 'Wan-AI/Wan2.1-T2V-1.3B-Diffusers',
arch: 'wan21',
defaults: {
// default updates when [selected, unselected] in the UI
'config.process[0].model.quantize': [false, false],
'config.process[0].model.quantize_te': [true, false],
'config.process[0].sample.sampler': ['flowmatch', 'flowmatch'],
'config.process[0].train.noise_scheduler': ['flowmatch', 'flowmatch'],
'config.process[0].sample.num_frames': [40, 1],
'config.process[0].sample.fps': [15, 1],
},
},
{
name_or_path: 'Wan-AI/Wan2.1-T2V-14B-Diffusers',
arch: 'wan21',
defaults: {
// default updates when [selected, unselected] in the UI
'config.process[0].model.quantize': [true, false],
'config.process[0].model.quantize_te': [true, false],
'config.process[0].sample.sampler': ['flowmatch', 'flowmatch'],
'config.process[0].train.noise_scheduler': ['flowmatch', 'flowmatch'],
'config.process[0].sample.num_frames': [40, 1],
'config.process[0].sample.fps': [15, 1],
},
},
{
name_or_path: 'Alpha-VLLM/Lumina-Image-2.0',
arch: 'lumina2',
defaults: {
// default updates when [selected, unselected] in the UI
'config.process[0].model.quantize': [false, false],
'config.process[0].model.quantize_te': [true, false],
'config.process[0].sample.sampler': ['flowmatch', 'flowmatch'],
'config.process[0].train.noise_scheduler': ['flowmatch', 'flowmatch'],
},
},
{
name_or_path: 'HiDream-ai/HiDream-I1-Full',
arch: 'hidream',
defaults: {
// default updates when [selected, unselected] in the UI
'config.process[0].model.quantize': [true, false],
'config.process[0].model.quantize_te': [true, false],
'config.process[0].sample.sampler': ['flowmatch', 'flowmatch'],
'config.process[0].train.noise_scheduler': ['flowmatch', 'flowmatch'],
'config.process[0].train.lr': [0.0002, 0.0001],
'config.process[0].train.timestep_type': ['shift', 'sigmoid'],
'config.process[0].network.network_kwargs.ignore_if_contains': [['ff_i.experts', 'ff_i.gate'], []],
},
},
{
name_or_path: 'ostris/objective-reality',
arch: 'sd1',
dev_only: true,
defaults: {
'config.process[0].sample.sampler': ['ddpm', 'flowmatch'],
'config.process[0].train.noise_scheduler': ['ddpm', 'flowmatch'],
},
},
],
} as Option;

View File

@@ -2,14 +2,11 @@
import { useEffect, useState } from 'react';
import { useSearchParams, useRouter } from 'next/navigation';
import { options, modelArchs, isVideoModelFromArch } from './options';
import { defaultJobConfig, defaultDatasetConfig } from './jobConfig';
import { JobConfig } from '@/types';
import { objectCopy } from '@/utils/basic';
import { useNestedState } from '@/utils/hooks';
import { TextInput, SelectInput, Checkbox, FormGroup, NumberInput } from '@/components/formInputs';
import Card from '@/components/Card';
import { X } from 'lucide-react';
import { SelectInput} from '@/components/formInputs';
import useSettings from '@/hooks/useSettings';
import useGPUInfo from '@/hooks/useGPUInfo';
import useDatasetList from '@/hooks/useDatasetList';

View File

@@ -80,6 +80,7 @@ export interface DatasetConfig {
network_weight: number;
cache_latents_to_disk?: boolean;
resolution: number[];
controls: string[];
}
export interface EMAConfig {
@@ -122,6 +123,7 @@ export interface ModelConfig {
quantize_kwargs?: QuantizeKwargsConfig;
arch: string;
low_vram: boolean;
model_kwargs: {[key: string]: any};
}
export interface SampleConfig {