mirror of
https://github.com/ostris/ai-toolkit.git
synced 2026-02-28 16:23:56 +00:00
Add support for wan training in ui
This commit is contained in:
@@ -29,7 +29,7 @@ export async function GET(request: NextRequest, { params }: { params: { jobID: s
|
||||
const samples = fs
|
||||
.readdirSync(samplesFolder)
|
||||
.filter(file => {
|
||||
return file.endsWith('.png') || file.endsWith('.jpg') || file.endsWith('.jpeg');
|
||||
return file.endsWith('.png') || file.endsWith('.jpg') || file.endsWith('.jpeg') || file.endsWith('.webp');
|
||||
})
|
||||
.map(file => {
|
||||
return path.join(samplesFolder, file);
|
||||
|
||||
@@ -57,6 +57,7 @@ export const defaultJobConfig: JobConfig = {
|
||||
optimizer_params: {
|
||||
weight_decay: 1e-4
|
||||
},
|
||||
unload_text_encoder: false,
|
||||
lr: 0.0001,
|
||||
ema_config: {
|
||||
use_ema: true,
|
||||
@@ -70,9 +71,10 @@ export const defaultJobConfig: JobConfig = {
|
||||
},
|
||||
model: {
|
||||
name_or_path: 'ostris/Flex.1-alpha',
|
||||
is_flux: true,
|
||||
quantize: true,
|
||||
quantize_te: true
|
||||
quantize_te: true,
|
||||
arch: 'flux',
|
||||
low_vram: false,
|
||||
},
|
||||
sample: {
|
||||
sampler: 'flowmatch',
|
||||
@@ -96,6 +98,8 @@ export const defaultJobConfig: JobConfig = {
|
||||
walk_seed: true,
|
||||
guidance_scale: 4,
|
||||
sample_steps: 25,
|
||||
num_frames: 1,
|
||||
fps: 1,
|
||||
},
|
||||
},
|
||||
],
|
||||
|
||||
@@ -8,6 +8,19 @@ export interface Option {
|
||||
model: Model[];
|
||||
}
|
||||
|
||||
export const modelArchs = [
|
||||
{ name: 'flux', label: 'Flux.1' },
|
||||
{ name: 'wan21', label: 'Wan 2.1' },
|
||||
{ name: 'lumina2', label: 'Lumina2' },
|
||||
]
|
||||
|
||||
export const isVideoModelFromArch = (arch: string) => {
|
||||
const videoArches = ['wan21'];
|
||||
return videoArches.includes(arch);
|
||||
};
|
||||
|
||||
const defaultModelArch = 'flux';
|
||||
|
||||
export const options = {
|
||||
model: [
|
||||
{
|
||||
@@ -16,7 +29,7 @@ export const options = {
|
||||
// default updates when [selected, unselected] in the UI
|
||||
'config.process[0].model.quantize': [true, false],
|
||||
'config.process[0].model.quantize_te': [true, false],
|
||||
'config.process[0].model.is_flux': [true, false],
|
||||
'config.process[0].model.arch': ['flux', defaultModelArch],
|
||||
'config.process[0].train.bypass_guidance_embedding': [true, false],
|
||||
'config.process[0].sample.sampler': ['flowmatch', 'flowmatch'],
|
||||
'config.process[0].train.noise_scheduler': ['flowmatch', 'flowmatch'],
|
||||
@@ -28,18 +41,44 @@ export const options = {
|
||||
// default updates when [selected, unselected] in the UI
|
||||
'config.process[0].model.quantize': [true, false],
|
||||
'config.process[0].model.quantize_te': [true, false],
|
||||
'config.process[0].model.is_flux': [true, false],
|
||||
'config.process[0].model.arch': ['flux', defaultModelArch],
|
||||
'config.process[0].sample.sampler': ['flowmatch', 'flowmatch'],
|
||||
'config.process[0].train.noise_scheduler': ['flowmatch', 'flowmatch'],
|
||||
},
|
||||
},
|
||||
{
|
||||
name_or_path: 'Wan-AI/Wan2.1-T2V-1.3B-Diffusers',
|
||||
defaults: {
|
||||
// default updates when [selected, unselected] in the UI
|
||||
'config.process[0].model.quantize': [false, false],
|
||||
'config.process[0].model.quantize_te': [true, false],
|
||||
'config.process[0].model.arch': ['wan21', defaultModelArch],
|
||||
'config.process[0].sample.sampler': ['flowmatch', 'flowmatch'],
|
||||
'config.process[0].train.noise_scheduler': ['flowmatch', 'flowmatch'],
|
||||
'config.process[0].sample.num_frames': [40, 1],
|
||||
'config.process[0].sample.fps': [15, 1],
|
||||
},
|
||||
},
|
||||
{
|
||||
name_or_path: 'Wan-AI/Wan2.1-T2V-14B-Diffusers',
|
||||
defaults: {
|
||||
// default updates when [selected, unselected] in the UI
|
||||
'config.process[0].model.quantize': [true, false],
|
||||
'config.process[0].model.quantize_te': [true, false],
|
||||
'config.process[0].model.arch': ['wan21', defaultModelArch],
|
||||
'config.process[0].sample.sampler': ['flowmatch', 'flowmatch'],
|
||||
'config.process[0].train.noise_scheduler': ['flowmatch', 'flowmatch'],
|
||||
'config.process[0].sample.num_frames': [40, 1],
|
||||
'config.process[0].sample.fps': [15, 1],
|
||||
},
|
||||
},
|
||||
{
|
||||
name_or_path: 'Alpha-VLLM/Lumina-Image-2.0',
|
||||
defaults: {
|
||||
// default updates when [selected, unselected] in the UI
|
||||
'config.process[0].model.quantize': [false, false],
|
||||
'config.process[0].model.quantize_te': [true, false],
|
||||
'config.process[0].model.is_lumina2': [true, false],
|
||||
'config.process[0].model.arch': ['lumina2', defaultModelArch],
|
||||
'config.process[0].sample.sampler': ['flowmatch', 'flowmatch'],
|
||||
'config.process[0].train.noise_scheduler': ['flowmatch', 'flowmatch'],
|
||||
},
|
||||
@@ -50,6 +89,7 @@ export const options = {
|
||||
defaults: {
|
||||
'config.process[0].sample.sampler': ['ddpm', 'flowmatch'],
|
||||
'config.process[0].train.noise_scheduler': ['ddpm', 'flowmatch'],
|
||||
'config.process[0].model.arch': ['sd1', defaultModelArch],
|
||||
},
|
||||
},
|
||||
],
|
||||
|
||||
@@ -2,7 +2,7 @@
|
||||
|
||||
import { useEffect, useState } from 'react';
|
||||
import { useSearchParams, useRouter } from 'next/navigation';
|
||||
import { options } from './options';
|
||||
import { options, modelArchs, isVideoModelFromArch } from './options';
|
||||
import { defaultJobConfig, defaultDatasetConfig } from './jobConfig';
|
||||
import { JobConfig } from '@/types';
|
||||
import { objectCopy } from '@/utils/basic';
|
||||
@@ -33,6 +33,8 @@ export default function TrainingForm() {
|
||||
const [jobConfig, setJobConfig] = useNestedState<JobConfig>(objectCopy(defaultJobConfig));
|
||||
const [status, setStatus] = useState<'idle' | 'saving' | 'success' | 'error'>('idle');
|
||||
|
||||
const isVideoModel = isVideoModelFromArch(jobConfig.config.process[0].model.arch);
|
||||
|
||||
useEffect(() => {
|
||||
if (!isSettingsLoaded) return;
|
||||
if (datasetFetchStatus !== 'success') return;
|
||||
@@ -212,6 +214,30 @@ export default function TrainingForm() {
|
||||
.filter(x => x) as { value: string; label: string }[]
|
||||
}
|
||||
/>
|
||||
<SelectInput
|
||||
label="Model Architecture"
|
||||
value={jobConfig.config.process[0].model.arch}
|
||||
onChange={value => {
|
||||
const currentArch = modelArchs.find(
|
||||
a => a.name === jobConfig.config.process[0].model.arch,
|
||||
);
|
||||
if (!currentArch || currentArch.name === value) {
|
||||
return;
|
||||
}
|
||||
// set new model
|
||||
setJobConfig(value, 'config.process[0].model.arch');
|
||||
}}
|
||||
options={
|
||||
modelArchs
|
||||
.map(model => {
|
||||
return {
|
||||
value: model.name,
|
||||
label: model.label,
|
||||
};
|
||||
})
|
||||
.filter(x => x) as { value: string; label: string }[]
|
||||
}
|
||||
/>
|
||||
<FormGroup label="Quantize">
|
||||
<div className="grid grid-cols-2 gap-2">
|
||||
<Checkbox
|
||||
@@ -406,6 +432,15 @@ export default function TrainingForm() {
|
||||
placeholder="eg. 0.99"
|
||||
min={0}
|
||||
/>
|
||||
<FormGroup label="Unload Text Encoder" className="pt-2">
|
||||
<div className="grid grid-cols-2 gap-2">
|
||||
<Checkbox
|
||||
label="Unload TE"
|
||||
checked={jobConfig.config.process[0].train.unload_text_encoder || false}
|
||||
onChange={value => setJobConfig(value, 'config.process[0].train.unload_text_encoder')}
|
||||
/>
|
||||
</div>
|
||||
</FormGroup>
|
||||
</div>
|
||||
<div>
|
||||
<FormGroup label="Regularization">
|
||||
@@ -552,7 +587,7 @@ export default function TrainingForm() {
|
||||
</div>
|
||||
<div>
|
||||
<Card title="Sample Configuration">
|
||||
<div className="grid grid-cols-1 md:grid-cols-2 lg:grid-cols-4 gap-6">
|
||||
<div className={isVideoModel ? "grid grid-cols-1 md:grid-cols-3 lg:grid-cols-5 gap-6" : "grid grid-cols-1 md:grid-cols-2 lg:grid-cols-4 gap-6"}>
|
||||
<div>
|
||||
<NumberInput
|
||||
label="Sample Every"
|
||||
@@ -628,6 +663,26 @@ export default function TrainingForm() {
|
||||
onChange={value => setJobConfig(value, 'config.process[0].sample.walk_seed')}
|
||||
/>
|
||||
</div>
|
||||
{ isVideoModel && (
|
||||
<div>
|
||||
<NumberInput
|
||||
label="Num Frames"
|
||||
value={jobConfig.config.process[0].sample.num_frames}
|
||||
onChange={value => setJobConfig(value, 'config.process[0].sample.num_frames')}
|
||||
placeholder="eg. 0"
|
||||
min={0}
|
||||
required
|
||||
/>
|
||||
<NumberInput
|
||||
label="FPS"
|
||||
value={jobConfig.config.process[0].sample.fps}
|
||||
onChange={value => setJobConfig(value, 'config.process[0].sample.fps')}
|
||||
placeholder="eg. 0"
|
||||
min={0}
|
||||
required
|
||||
/>
|
||||
</div>
|
||||
)}
|
||||
</div>
|
||||
<FormGroup
|
||||
label={`Sample Prompts (${jobConfig.config.process[0].sample.prompts.length})`}
|
||||
|
||||
@@ -99,6 +99,7 @@ export interface TrainConfig {
|
||||
lr: number;
|
||||
ema_config?: EMAConfig;
|
||||
dtype: string;
|
||||
unload_text_encoder: boolean;
|
||||
optimizer_params: {
|
||||
weight_decay: number;
|
||||
};
|
||||
@@ -113,11 +114,11 @@ export interface QuantizeKwargsConfig {
|
||||
|
||||
export interface ModelConfig {
|
||||
name_or_path: string;
|
||||
is_flux?: boolean;
|
||||
is_lumina2?: boolean;
|
||||
quantize: boolean;
|
||||
quantize_te: boolean;
|
||||
quantize_kwargs?: QuantizeKwargsConfig;
|
||||
arch: string;
|
||||
low_vram: boolean;
|
||||
}
|
||||
|
||||
export interface SampleConfig {
|
||||
@@ -131,6 +132,8 @@ export interface SampleConfig {
|
||||
walk_seed: boolean;
|
||||
guidance_scale: number;
|
||||
sample_steps: number;
|
||||
num_frames: number;
|
||||
fps: number;
|
||||
}
|
||||
|
||||
export interface ProcessConfig {
|
||||
|
||||
Reference in New Issue
Block a user