diff --git a/ui/src/app/api/jobs/[jobID]/samples/route.ts b/ui/src/app/api/jobs/[jobID]/samples/route.ts index 26af0c05..2a98a6ea 100644 --- a/ui/src/app/api/jobs/[jobID]/samples/route.ts +++ b/ui/src/app/api/jobs/[jobID]/samples/route.ts @@ -29,7 +29,7 @@ export async function GET(request: NextRequest, { params }: { params: { jobID: s const samples = fs .readdirSync(samplesFolder) .filter(file => { - return file.endsWith('.png') || file.endsWith('.jpg') || file.endsWith('.jpeg'); + return file.endsWith('.png') || file.endsWith('.jpg') || file.endsWith('.jpeg') || file.endsWith('.webp'); }) .map(file => { return path.join(samplesFolder, file); diff --git a/ui/src/app/jobs/new/jobConfig.ts b/ui/src/app/jobs/new/jobConfig.ts index 8098171d..c03fd18c 100644 --- a/ui/src/app/jobs/new/jobConfig.ts +++ b/ui/src/app/jobs/new/jobConfig.ts @@ -57,6 +57,7 @@ export const defaultJobConfig: JobConfig = { optimizer_params: { weight_decay: 1e-4 }, + unload_text_encoder: false, lr: 0.0001, ema_config: { use_ema: true, @@ -70,9 +71,10 @@ export const defaultJobConfig: JobConfig = { }, model: { name_or_path: 'ostris/Flex.1-alpha', - is_flux: true, quantize: true, - quantize_te: true + quantize_te: true, + arch: 'flux', + low_vram: false, }, sample: { sampler: 'flowmatch', @@ -96,6 +98,8 @@ export const defaultJobConfig: JobConfig = { walk_seed: true, guidance_scale: 4, sample_steps: 25, + num_frames: 1, + fps: 1, }, }, ], diff --git a/ui/src/app/jobs/new/options.ts b/ui/src/app/jobs/new/options.ts index 05361cda..e45f02ee 100644 --- a/ui/src/app/jobs/new/options.ts +++ b/ui/src/app/jobs/new/options.ts @@ -8,6 +8,19 @@ export interface Option { model: Model[]; } +export const modelArchs = [ + { name: 'flux', label: 'Flux.1' }, + { name: 'wan21', label: 'Wan 2.1' }, + { name: 'lumina2', label: 'Lumina2' }, +] + +export const isVideoModelFromArch = (arch: string) => { + const videoArches = ['wan21']; + return videoArches.includes(arch); +}; + +const defaultModelArch = 'flux'; + export 
const options = { model: [ { @@ -16,7 +29,7 @@ export const options = { // default updates when [selected, unselected] in the UI 'config.process[0].model.quantize': [true, false], 'config.process[0].model.quantize_te': [true, false], - 'config.process[0].model.is_flux': [true, false], + 'config.process[0].model.arch': ['flux', defaultModelArch], 'config.process[0].train.bypass_guidance_embedding': [true, false], 'config.process[0].sample.sampler': ['flowmatch', 'flowmatch'], 'config.process[0].train.noise_scheduler': ['flowmatch', 'flowmatch'], @@ -28,18 +41,44 @@ export const options = { // default updates when [selected, unselected] in the UI 'config.process[0].model.quantize': [true, false], 'config.process[0].model.quantize_te': [true, false], - 'config.process[0].model.is_flux': [true, false], + 'config.process[0].model.arch': ['flux', defaultModelArch], 'config.process[0].sample.sampler': ['flowmatch', 'flowmatch'], 'config.process[0].train.noise_scheduler': ['flowmatch', 'flowmatch'], }, }, + { + name_or_path: 'Wan-AI/Wan2.1-T2V-1.3B-Diffusers', + defaults: { + // default updates when [selected, unselected] in the UI + 'config.process[0].model.quantize': [false, false], + 'config.process[0].model.quantize_te': [true, false], + 'config.process[0].model.arch': ['wan21', defaultModelArch], + 'config.process[0].sample.sampler': ['flowmatch', 'flowmatch'], + 'config.process[0].train.noise_scheduler': ['flowmatch', 'flowmatch'], + 'config.process[0].sample.num_frames': [40, 1], + 'config.process[0].sample.fps': [15, 1], + }, + }, + { + name_or_path: 'Wan-AI/Wan2.1-T2V-14B-Diffusers', + defaults: { + // default updates when [selected, unselected] in the UI + 'config.process[0].model.quantize': [true, false], + 'config.process[0].model.quantize_te': [true, false], + 'config.process[0].model.arch': ['wan21', defaultModelArch], + 'config.process[0].sample.sampler': ['flowmatch', 'flowmatch'], + 'config.process[0].train.noise_scheduler': ['flowmatch', 'flowmatch'], + 
'config.process[0].sample.num_frames': [40, 1], + 'config.process[0].sample.fps': [15, 1], + }, + }, { name_or_path: 'Alpha-VLLM/Lumina-Image-2.0', defaults: { // default updates when [selected, unselected] in the UI 'config.process[0].model.quantize': [false, false], 'config.process[0].model.quantize_te': [true, false], - 'config.process[0].model.is_lumina2': [true, false], + 'config.process[0].model.arch': ['lumina2', defaultModelArch], 'config.process[0].sample.sampler': ['flowmatch', 'flowmatch'], 'config.process[0].train.noise_scheduler': ['flowmatch', 'flowmatch'], }, @@ -50,6 +89,7 @@ export const options = { defaults: { 'config.process[0].sample.sampler': ['ddpm', 'flowmatch'], 'config.process[0].train.noise_scheduler': ['ddpm', 'flowmatch'], + 'config.process[0].model.arch': ['sd1', defaultModelArch], }, }, ], diff --git a/ui/src/app/jobs/new/page.tsx b/ui/src/app/jobs/new/page.tsx index 4652cf20..871d63c2 100644 --- a/ui/src/app/jobs/new/page.tsx +++ b/ui/src/app/jobs/new/page.tsx @@ -2,7 +2,7 @@ import { useEffect, useState } from 'react'; import { useSearchParams, useRouter } from 'next/navigation'; -import { options } from './options'; +import { options, modelArchs, isVideoModelFromArch } from './options'; import { defaultJobConfig, defaultDatasetConfig } from './jobConfig'; import { JobConfig } from '@/types'; import { objectCopy } from '@/utils/basic'; @@ -33,6 +33,8 @@ export default function TrainingForm() { const [jobConfig, setJobConfig] = useNestedState(objectCopy(defaultJobConfig)); const [status, setStatus] = useState<'idle' | 'saving' | 'success' | 'error'>('idle'); + const isVideoModel = isVideoModelFromArch(jobConfig.config.process[0].model.arch); + useEffect(() => { if (!isSettingsLoaded) return; if (datasetFetchStatus !== 'success') return; @@ -212,6 +214,30 @@ export default function TrainingForm() { .filter(x => x) as { value: string; label: string }[] } /> + { + const currentArch = modelArchs.find( + a => a.name === 
jobConfig.config.process[0].model.arch, + ); + if (!currentArch || currentArch.name === value) { + return; + } + // set new model arch + setJobConfig(value, 'config.process[0].model.arch'); + }} + options={ + modelArchs + .map(model => { + return { + value: model.name, + label: model.label, + }; + }) + .filter(x => x) as { value: string; label: string }[] + } + />
+ +
+ setJobConfig(value, 'config.process[0].train.unload_text_encoder')} + /> +
+
@@ -552,7 +587,7 @@ export default function TrainingForm() {
-
+
setJobConfig(value, 'config.process[0].sample.walk_seed')} />
+ { isVideoModel && ( +
+ setJobConfig(value, 'config.process[0].sample.num_frames')} + placeholder="eg. 0" + min={0} + required + /> + setJobConfig(value, 'config.process[0].sample.fps')} + placeholder="eg. 0" + min={0} + required + /> +
+ )}