mirror of
https://github.com/ostris/ai-toolkit.git
synced 2026-01-26 16:39:47 +00:00
Added support for Flex.2 in the UI
This commit is contained in:
@@ -1,6 +1,6 @@
|
||||
'use client';
|
||||
|
||||
import { options, modelArchs, isVideoModelFromArch } from './options';
|
||||
import { useMemo } from 'react';
|
||||
import { modelArchs, ModelArch } from './options';
|
||||
import { defaultDatasetConfig } from './jobConfig';
|
||||
import { JobConfig } from '@/types';
|
||||
import { objectCopy } from '@/utils/basic';
|
||||
@@ -33,7 +33,13 @@ export default function SimpleJob({
|
||||
gpuList,
|
||||
datasetOptions,
|
||||
}: Props) {
|
||||
const isVideoModel = isVideoModelFromArch(jobConfig.config.process[0].model.arch);
|
||||
|
||||
const modelArch = useMemo(() => {
|
||||
return modelArchs.find(a => a.name === jobConfig.config.process[0].model.arch) as ModelArch;
|
||||
}, [jobConfig.config.process[0].model.arch]);
|
||||
|
||||
const isVideoModel = !!modelArch?.isVideoModel;
|
||||
|
||||
return (
|
||||
<>
|
||||
<form onSubmit={handleSubmit} className="space-y-8">
|
||||
@@ -91,6 +97,16 @@ export default function SimpleJob({
|
||||
}
|
||||
// set new model
|
||||
setJobConfig(value, 'config.process[0].model.arch');
|
||||
|
||||
// update controls for datasets
|
||||
const controls = newArch?.controls ?? [];
|
||||
const datasets = jobConfig.config.process[0].datasets.map(dataset => {
|
||||
const newDataset = objectCopy(dataset);
|
||||
newDataset.controls = controls;
|
||||
return newDataset;
|
||||
}
|
||||
);
|
||||
setJobConfig(datasets, 'config.process[0].datasets');
|
||||
}}
|
||||
options={
|
||||
modelArchs
|
||||
@@ -445,12 +461,16 @@ export default function SimpleJob({
|
||||
))}
|
||||
<button
|
||||
type="button"
|
||||
onClick={() =>
|
||||
onClick={() => {
|
||||
const newDataset = objectCopy(defaultDatasetConfig);
|
||||
// automaticallt add the controls for a new dataset
|
||||
const controls = modelArch?.controls ?? [];
|
||||
newDataset.controls = controls;
|
||||
setJobConfig(
|
||||
[...jobConfig.config.process[0].datasets, objectCopy(defaultDatasetConfig)],
|
||||
[...jobConfig.config.process[0].datasets, newDataset],
|
||||
'config.process[0].datasets',
|
||||
)
|
||||
}
|
||||
}}
|
||||
className="w-full px-4 py-2 bg-gray-700 hover:bg-gray-600 rounded-lg transition-colors"
|
||||
>
|
||||
Add Dataset
|
||||
|
||||
@@ -11,6 +11,7 @@ export const defaultDatasetConfig: DatasetConfig = {
|
||||
is_reg: false,
|
||||
network_weight: 1,
|
||||
resolution: [512, 768, 1024],
|
||||
controls: []
|
||||
};
|
||||
|
||||
export const defaultJobConfig: JobConfig = {
|
||||
@@ -75,6 +76,7 @@ export const defaultJobConfig: JobConfig = {
|
||||
quantize_te: true,
|
||||
arch: 'flex1',
|
||||
low_vram: false,
|
||||
model_kwargs: {},
|
||||
},
|
||||
sample: {
|
||||
sampler: 'flowmatch',
|
||||
|
||||
@@ -1,17 +1,11 @@
|
||||
export interface Model {
|
||||
name_or_path: string;
|
||||
arch: string;
|
||||
dev_only?: boolean;
|
||||
defaults?: { [key: string]: any };
|
||||
}
|
||||
|
||||
export interface Option {
|
||||
model: Model[];
|
||||
}
|
||||
type Control = 'depth' | 'line' | 'pose' | 'inpaint';
|
||||
|
||||
export interface ModelArch {
|
||||
name: string;
|
||||
label: string;
|
||||
controls?: Control[];
|
||||
isVideoModel?: boolean;
|
||||
defaults?: { [key: string]: [any, any] };
|
||||
}
|
||||
|
||||
@@ -43,7 +37,32 @@ export const modelArchs: ModelArch[] = [
|
||||
'config.process[0].train.noise_scheduler': ['flowmatch', 'flowmatch'],
|
||||
},
|
||||
},
|
||||
// { name: 'flex2', label: 'Flex.2' },
|
||||
{
|
||||
name: 'flex2',
|
||||
label: 'Flex.2',
|
||||
controls: ['depth', 'line', 'pose', 'inpaint'],
|
||||
defaults: {
|
||||
// default updates when [selected, unselected] in the UI
|
||||
'config.process[0].model.name_or_path': ['ostris/Flex.2-preview', defaultNameOrPath],
|
||||
'config.process[0].model.quantize': [true, false],
|
||||
'config.process[0].model.quantize_te': [true, false],
|
||||
'config.process[0].model.model_kwargs': [
|
||||
{
|
||||
invert_inpaint_mask_chance: 0.2,
|
||||
inpaint_dropout: 0.5,
|
||||
control_dropout: 0.5,
|
||||
inpaint_random_chance: 0.2,
|
||||
do_random_inpainting: true,
|
||||
random_blur_mask: true,
|
||||
random_dialate_mask: true,
|
||||
},
|
||||
{},
|
||||
],
|
||||
'config.process[0].train.bypass_guidance_embedding': [true, false],
|
||||
'config.process[0].sample.sampler': ['flowmatch', 'flowmatch'],
|
||||
'config.process[0].train.noise_scheduler': ['flowmatch', 'flowmatch'],
|
||||
},
|
||||
},
|
||||
{
|
||||
name: 'chroma',
|
||||
label: 'Chroma',
|
||||
@@ -59,6 +78,7 @@ export const modelArchs: ModelArch[] = [
|
||||
{
|
||||
name: 'wan21:1b',
|
||||
label: 'Wan 2.1 (1.3B)',
|
||||
isVideoModel: true,
|
||||
defaults: {
|
||||
// default updates when [selected, unselected] in the UI
|
||||
'config.process[0].model.name_or_path': ['Wan-AI/Wan2.1-T2V-1.3B-Diffusers', defaultNameOrPath],
|
||||
@@ -73,6 +93,7 @@ export const modelArchs: ModelArch[] = [
|
||||
{
|
||||
name: 'wan21:14b',
|
||||
label: 'Wan 2.1 (14B)',
|
||||
isVideoModel: true,
|
||||
defaults: {
|
||||
// default updates when [selected, unselected] in the UI
|
||||
'config.process[0].model.name_or_path': ['Wan-AI/Wan2.1-T2V-14B-Diffuserss', defaultNameOrPath],
|
||||
@@ -112,109 +133,3 @@ export const modelArchs: ModelArch[] = [
|
||||
},
|
||||
},
|
||||
];
|
||||
|
||||
export const isVideoModelFromArch = (arch: string) => {
|
||||
const videoArches = ['wan21'];
|
||||
return videoArches.includes(arch);
|
||||
};
|
||||
|
||||
const defaultModelArch = 'flux';
|
||||
|
||||
export const options: Option = {
|
||||
model: [
|
||||
{
|
||||
name_or_path: 'ostris/Flex.1-alpha',
|
||||
arch: 'flex1',
|
||||
defaults: {
|
||||
// default updates when [selected, unselected] in the UI
|
||||
'config.process[0].model.quantize': [true, false],
|
||||
'config.process[0].model.quantize_te': [true, false],
|
||||
'config.process[0].train.bypass_guidance_embedding': [true, false],
|
||||
'config.process[0].sample.sampler': ['flowmatch', 'flowmatch'],
|
||||
'config.process[0].train.noise_scheduler': ['flowmatch', 'flowmatch'],
|
||||
},
|
||||
},
|
||||
{
|
||||
name_or_path: 'black-forest-labs/FLUX.1-dev',
|
||||
arch: 'flux',
|
||||
defaults: {
|
||||
// default updates when [selected, unselected] in the UI
|
||||
'config.process[0].model.quantize': [true, false],
|
||||
'config.process[0].model.quantize_te': [true, false],
|
||||
'config.process[0].sample.sampler': ['flowmatch', 'flowmatch'],
|
||||
'config.process[0].train.noise_scheduler': ['flowmatch', 'flowmatch'],
|
||||
},
|
||||
},
|
||||
{
|
||||
name_or_path: 'lodestones/Chroma',
|
||||
arch: 'chroma',
|
||||
defaults: {
|
||||
// default updates when [selected, unselected] in the UI
|
||||
'config.process[0].model.quantize': [true, false],
|
||||
'config.process[0].model.quantize_te': [true, false],
|
||||
'config.process[0].sample.sampler': ['flowmatch', 'flowmatch'],
|
||||
'config.process[0].train.noise_scheduler': ['flowmatch', 'flowmatch'],
|
||||
},
|
||||
},
|
||||
{
|
||||
name_or_path: 'Wan-AI/Wan2.1-T2V-1.3B-Diffusers',
|
||||
arch: 'wan21',
|
||||
defaults: {
|
||||
// default updates when [selected, unselected] in the UI
|
||||
'config.process[0].model.quantize': [false, false],
|
||||
'config.process[0].model.quantize_te': [true, false],
|
||||
'config.process[0].sample.sampler': ['flowmatch', 'flowmatch'],
|
||||
'config.process[0].train.noise_scheduler': ['flowmatch', 'flowmatch'],
|
||||
'config.process[0].sample.num_frames': [40, 1],
|
||||
'config.process[0].sample.fps': [15, 1],
|
||||
},
|
||||
},
|
||||
{
|
||||
name_or_path: 'Wan-AI/Wan2.1-T2V-14B-Diffusers',
|
||||
arch: 'wan21',
|
||||
defaults: {
|
||||
// default updates when [selected, unselected] in the UI
|
||||
'config.process[0].model.quantize': [true, false],
|
||||
'config.process[0].model.quantize_te': [true, false],
|
||||
'config.process[0].sample.sampler': ['flowmatch', 'flowmatch'],
|
||||
'config.process[0].train.noise_scheduler': ['flowmatch', 'flowmatch'],
|
||||
'config.process[0].sample.num_frames': [40, 1],
|
||||
'config.process[0].sample.fps': [15, 1],
|
||||
},
|
||||
},
|
||||
{
|
||||
name_or_path: 'Alpha-VLLM/Lumina-Image-2.0',
|
||||
arch: 'lumina2',
|
||||
defaults: {
|
||||
// default updates when [selected, unselected] in the UI
|
||||
'config.process[0].model.quantize': [false, false],
|
||||
'config.process[0].model.quantize_te': [true, false],
|
||||
'config.process[0].sample.sampler': ['flowmatch', 'flowmatch'],
|
||||
'config.process[0].train.noise_scheduler': ['flowmatch', 'flowmatch'],
|
||||
},
|
||||
},
|
||||
{
|
||||
name_or_path: 'HiDream-ai/HiDream-I1-Full',
|
||||
arch: 'hidream',
|
||||
defaults: {
|
||||
// default updates when [selected, unselected] in the UI
|
||||
'config.process[0].model.quantize': [true, false],
|
||||
'config.process[0].model.quantize_te': [true, false],
|
||||
'config.process[0].sample.sampler': ['flowmatch', 'flowmatch'],
|
||||
'config.process[0].train.noise_scheduler': ['flowmatch', 'flowmatch'],
|
||||
'config.process[0].train.lr': [0.0002, 0.0001],
|
||||
'config.process[0].train.timestep_type': ['shift', 'sigmoid'],
|
||||
'config.process[0].network.network_kwargs.ignore_if_contains': [['ff_i.experts', 'ff_i.gate'], []],
|
||||
},
|
||||
},
|
||||
{
|
||||
name_or_path: 'ostris/objective-reality',
|
||||
arch: 'sd1',
|
||||
dev_only: true,
|
||||
defaults: {
|
||||
'config.process[0].sample.sampler': ['ddpm', 'flowmatch'],
|
||||
'config.process[0].train.noise_scheduler': ['ddpm', 'flowmatch'],
|
||||
},
|
||||
},
|
||||
],
|
||||
} as Option;
|
||||
|
||||
@@ -2,14 +2,11 @@
|
||||
|
||||
import { useEffect, useState } from 'react';
|
||||
import { useSearchParams, useRouter } from 'next/navigation';
|
||||
import { options, modelArchs, isVideoModelFromArch } from './options';
|
||||
import { defaultJobConfig, defaultDatasetConfig } from './jobConfig';
|
||||
import { JobConfig } from '@/types';
|
||||
import { objectCopy } from '@/utils/basic';
|
||||
import { useNestedState } from '@/utils/hooks';
|
||||
import { TextInput, SelectInput, Checkbox, FormGroup, NumberInput } from '@/components/formInputs';
|
||||
import Card from '@/components/Card';
|
||||
import { X } from 'lucide-react';
|
||||
import { SelectInput} from '@/components/formInputs';
|
||||
import useSettings from '@/hooks/useSettings';
|
||||
import useGPUInfo from '@/hooks/useGPUInfo';
|
||||
import useDatasetList from '@/hooks/useDatasetList';
|
||||
|
||||
@@ -80,6 +80,7 @@ export interface DatasetConfig {
|
||||
network_weight: number;
|
||||
cache_latents_to_disk?: boolean;
|
||||
resolution: number[];
|
||||
controls: string[];
|
||||
}
|
||||
|
||||
export interface EMAConfig {
|
||||
@@ -122,6 +123,7 @@ export interface ModelConfig {
|
||||
quantize_kwargs?: QuantizeKwargsConfig;
|
||||
arch: string;
|
||||
low_vram: boolean;
|
||||
model_kwargs: {[key: string]: any};
|
||||
}
|
||||
|
||||
export interface SampleConfig {
|
||||
|
||||
Reference in New Issue
Block a user