Add support for Wan 2.1 video model training in the UI

This commit is contained in:
Jaret Burkett
2025-03-13 18:54:27 -06:00
parent 31e057d9a3
commit cf4216e6b8
5 changed files with 112 additions and 10 deletions

View File

@@ -29,7 +29,7 @@ export async function GET(request: NextRequest, { params }: { params: { jobID: s
const samples = fs
.readdirSync(samplesFolder)
.filter(file => {
return file.endsWith('.png') || file.endsWith('.jpg') || file.endsWith('.jpeg');
return file.endsWith('.png') || file.endsWith('.jpg') || file.endsWith('.jpeg') || file.endsWith('.webp');
})
.map(file => {
return path.join(samplesFolder, file);

View File

@@ -57,6 +57,7 @@ export const defaultJobConfig: JobConfig = {
optimizer_params: {
weight_decay: 1e-4
},
unload_text_encoder: false,
lr: 0.0001,
ema_config: {
use_ema: true,
@@ -70,9 +71,10 @@ export const defaultJobConfig: JobConfig = {
},
model: {
name_or_path: 'ostris/Flex.1-alpha',
is_flux: true,
quantize: true,
quantize_te: true
quantize_te: true,
arch: 'flux',
low_vram: false,
},
sample: {
sampler: 'flowmatch',
@@ -96,6 +98,8 @@ export const defaultJobConfig: JobConfig = {
walk_seed: true,
guidance_scale: 4,
sample_steps: 25,
num_frames: 1,
fps: 1,
},
},
],

View File

@@ -8,6 +8,19 @@ export interface Option {
model: Model[];
}
export const modelArchs = [
{ name: 'flux', label: 'Flux.1' },
{ name: 'wan21', label: 'Wan 2.1' },
{ name: 'lumina2', label: 'Lumina2' },
]
export const isVideoModelFromArch = (arch: string) => {
const videoArches = ['wan21'];
return videoArches.includes(arch);
};
const defaultModelArch = 'flux';
export const options = {
model: [
{
@@ -16,7 +29,7 @@ export const options = {
// default updates when [selected, unselected] in the UI
'config.process[0].model.quantize': [true, false],
'config.process[0].model.quantize_te': [true, false],
'config.process[0].model.is_flux': [true, false],
'config.process[0].model.arch': ['flux', defaultModelArch],
'config.process[0].train.bypass_guidance_embedding': [true, false],
'config.process[0].sample.sampler': ['flowmatch', 'flowmatch'],
'config.process[0].train.noise_scheduler': ['flowmatch', 'flowmatch'],
@@ -28,18 +41,44 @@ export const options = {
// default updates when [selected, unselected] in the UI
'config.process[0].model.quantize': [true, false],
'config.process[0].model.quantize_te': [true, false],
'config.process[0].model.is_flux': [true, false],
'config.process[0].model.arch': ['flux', defaultModelArch],
'config.process[0].sample.sampler': ['flowmatch', 'flowmatch'],
'config.process[0].train.noise_scheduler': ['flowmatch', 'flowmatch'],
},
},
{
name_or_path: 'Wan-AI/Wan2.1-T2V-1.3B-Diffusers',
defaults: {
// default updates when [selected, unselected] in the UI
'config.process[0].model.quantize': [false, false],
'config.process[0].model.quantize_te': [true, false],
'config.process[0].model.arch': ['wan21', defaultModelArch],
'config.process[0].sample.sampler': ['flowmatch', 'flowmatch'],
'config.process[0].train.noise_scheduler': ['flowmatch', 'flowmatch'],
'config.process[0].sample.num_frames': [40, 1],
'config.process[0].sample.fps': [15, 1],
},
},
{
name_or_path: 'Wan-AI/Wan2.1-T2V-14B-Diffusers',
defaults: {
// default updates when [selected, unselected] in the UI
'config.process[0].model.quantize': [true, false],
'config.process[0].model.quantize_te': [true, false],
'config.process[0].model.arch': ['wan21', defaultModelArch],
'config.process[0].sample.sampler': ['flowmatch', 'flowmatch'],
'config.process[0].train.noise_scheduler': ['flowmatch', 'flowmatch'],
'config.process[0].sample.num_frames': [40, 1],
'config.process[0].sample.fps': [15, 1],
},
},
{
name_or_path: 'Alpha-VLLM/Lumina-Image-2.0',
defaults: {
// default updates when [selected, unselected] in the UI
'config.process[0].model.quantize': [false, false],
'config.process[0].model.quantize_te': [true, false],
'config.process[0].model.is_lumina2': [true, false],
'config.process[0].model.arch': ['lumina2', defaultModelArch],
'config.process[0].sample.sampler': ['flowmatch', 'flowmatch'],
'config.process[0].train.noise_scheduler': ['flowmatch', 'flowmatch'],
},
@@ -50,6 +89,7 @@ export const options = {
defaults: {
'config.process[0].sample.sampler': ['ddpm', 'flowmatch'],
'config.process[0].train.noise_scheduler': ['ddpm', 'flowmatch'],
'config.process[0].model.arch': ['sd1', defaultModelArch],
},
},
],

View File

@@ -2,7 +2,7 @@
import { useEffect, useState } from 'react';
import { useSearchParams, useRouter } from 'next/navigation';
import { options } from './options';
import { options, modelArchs, isVideoModelFromArch } from './options';
import { defaultJobConfig, defaultDatasetConfig } from './jobConfig';
import { JobConfig } from '@/types';
import { objectCopy } from '@/utils/basic';
@@ -33,6 +33,8 @@ export default function TrainingForm() {
const [jobConfig, setJobConfig] = useNestedState<JobConfig>(objectCopy(defaultJobConfig));
const [status, setStatus] = useState<'idle' | 'saving' | 'success' | 'error'>('idle');
const isVideoModel = isVideoModelFromArch(jobConfig.config.process[0].model.arch);
useEffect(() => {
if (!isSettingsLoaded) return;
if (datasetFetchStatus !== 'success') return;
@@ -212,6 +214,30 @@ export default function TrainingForm() {
.filter(x => x) as { value: string; label: string }[]
}
/>
<SelectInput
label="Model Architecture"
value={jobConfig.config.process[0].model.arch}
onChange={value => {
const currentArch = modelArchs.find(
a => a.name === jobConfig.config.process[0].model.arch,
);
if (!currentArch || currentArch.name === value) {
return;
}
// set new model
setJobConfig(value, 'config.process[0].model.arch');
}}
options={
modelArchs
.map(model => {
return {
value: model.name,
label: model.label,
};
})
.filter(x => x) as { value: string; label: string }[]
}
/>
<FormGroup label="Quantize">
<div className="grid grid-cols-2 gap-2">
<Checkbox
@@ -406,6 +432,15 @@ export default function TrainingForm() {
placeholder="eg. 0.99"
min={0}
/>
<FormGroup label="Unload Text Encoder" className="pt-2">
<div className="grid grid-cols-2 gap-2">
<Checkbox
label="Unload TE"
checked={jobConfig.config.process[0].train.unload_text_encoder || false}
onChange={value => setJobConfig(value, 'config.process[0].train.unload_text_encoder')}
/>
</div>
</FormGroup>
</div>
<div>
<FormGroup label="Regularization">
@@ -552,7 +587,7 @@ export default function TrainingForm() {
</div>
<div>
<Card title="Sample Configuration">
<div className="grid grid-cols-1 md:grid-cols-2 lg:grid-cols-4 gap-6">
<div className={isVideoModel ? "grid grid-cols-1 md:grid-cols-3 lg:grid-cols-5 gap-6" : "grid grid-cols-1 md:grid-cols-2 lg:grid-cols-4 gap-6"}>
<div>
<NumberInput
label="Sample Every"
@@ -628,6 +663,26 @@ export default function TrainingForm() {
onChange={value => setJobConfig(value, 'config.process[0].sample.walk_seed')}
/>
</div>
{ isVideoModel && (
<div>
<NumberInput
label="Num Frames"
value={jobConfig.config.process[0].sample.num_frames}
onChange={value => setJobConfig(value, 'config.process[0].sample.num_frames')}
placeholder="eg. 0"
min={0}
required
/>
<NumberInput
label="FPS"
value={jobConfig.config.process[0].sample.fps}
onChange={value => setJobConfig(value, 'config.process[0].sample.fps')}
placeholder="eg. 0"
min={0}
required
/>
</div>
)}
</div>
<FormGroup
label={`Sample Prompts (${jobConfig.config.process[0].sample.prompts.length})`}

View File

@@ -99,6 +99,7 @@ export interface TrainConfig {
lr: number;
ema_config?: EMAConfig;
dtype: string;
unload_text_encoder: boolean;
optimizer_params: {
weight_decay: number;
};
@@ -113,11 +114,11 @@ export interface QuantizeKwargsConfig {
export interface ModelConfig {
name_or_path: string;
is_flux?: boolean;
is_lumina2?: boolean;
quantize: boolean;
quantize_te: boolean;
quantize_kwargs?: QuantizeKwargsConfig;
arch: string;
low_vram: boolean;
}
export interface SampleConfig {
@@ -131,6 +132,8 @@ export interface SampleConfig {
walk_seed: boolean;
guidance_scale: number;
sample_steps: number;
num_frames: number;
fps: number;
}
export interface ProcessConfig {