Add simple ui settings to train Wan i2v models.

This commit is contained in:
Jaret Burkett
2025-07-11 11:28:40 -06:00
parent 6e2beef8dd
commit 8537a8557f
10 changed files with 155 additions and 31 deletions

View File

@@ -120,7 +120,7 @@ cd ai-toolkit
python3 -m venv venv python3 -m venv venv
source venv/bin/activate source venv/bin/activate
# install torch first # install torch first
pip3 install --no-cache-dir torch==2.7.0 torchvision==0.22.0 --index-url https://download.pytorch.org/whl/cu126 pip3 install --no-cache-dir torch==2.7.0 torchvision==0.22.0 torchaudio==2.7.0 --index-url https://download.pytorch.org/whl/cu126
pip3 install -r requirements.txt pip3 install -r requirements.txt
``` ```
@@ -130,7 +130,7 @@ git clone https://github.com/ostris/ai-toolkit.git
cd ai-toolkit cd ai-toolkit
python -m venv venv python -m venv venv
.\venv\Scripts\activate .\venv\Scripts\activate
pip install --no-cache-dir torch==2.7.0 torchvision==0.22.0 --index-url https://download.pytorch.org/whl/cu126 pip install --no-cache-dir torch==2.7.0 torchvision==0.22.0 torchaudio==2.7.0 --index-url https://download.pytorch.org/whl/cu126
pip install -r requirements.txt pip install -r requirements.txt
``` ```
@@ -425,28 +425,32 @@ Everything else should work the same including layer targeting.
Only larger updates are listed here. There are usually smaller daily updates that are omitted. Only larger updates are listed here. There are usually smaller daily updates that are omitted.
### June 29, 2024 ### Jul 11, 2025
- Added better video config settings to the UI for video models.
- Added Wan I2V training to the UI
### June 29, 2025
- Fixed issue where Kontext forced sizes on sampling - Fixed issue where Kontext forced sizes on sampling
### June 26, 2024 ### June 26, 2025
- Added support for FLUX.1 Kontext training - Added support for FLUX.1 Kontext training
- added support for instruction dataset training - added support for instruction dataset training
### June 25, 2024 ### June 25, 2025
- Added support for OmniGen2 training - Added support for OmniGen2 training
- -
### June 17, 2024 ### June 17, 2025
- Performance optimizations for batch preparation - Performance optimizations for batch preparation
- Added some docs via a popup for items in the simple ui explaining what settings do. Still a WIP - Added some docs via a popup for items in the simple ui explaining what settings do. Still a WIP
### June 16, 2024 ### June 16, 2025
- Hide control images in the UI when viewing datasets - Hide control images in the UI when viewing datasets
- WIP on mean flow loss - WIP on mean flow loss
### June 12, 2024 ### June 12, 2025
- Fixed issue that resulted in blank captions in the dataloader - Fixed issue that resulted in blank captions in the dataloader
### June 10, 2024 ### June 10, 2025
- Decided to keep track of updates in the readme - Decided to keep track of updates in the readme
- Added support for SDXL in the UI - Added support for SDXL in the UI
- Added support for SD 1.5 in the UI - Added support for SD 1.5 in the UI

View File

@@ -49,7 +49,7 @@ WORKDIR /app
RUN ln -s /usr/bin/python3 /usr/bin/python RUN ln -s /usr/bin/python3 /usr/bin/python
# install pytorch before cache bust to avoid redownloading pytorch # install pytorch before cache bust to avoid redownloading pytorch
RUN pip install --pre --no-cache-dir torch torchvision --index-url https://download.pytorch.org/whl/nightly/cu128 RUN pip install --pre --no-cache-dir torch torchvision torchaudio --index-url https://download.pytorch.org/whl/nightly/cu128
# Fix cache busting by moving CACHEBUST to right before git clone # Fix cache busting by moving CACHEBUST to right before git clone
ARG CACHEBUST=1234 ARG CACHEBUST=1234
@@ -63,7 +63,7 @@ WORKDIR /app/ai-toolkit
# Install Python dependencies # Install Python dependencies
RUN pip install --no-cache-dir -r requirements.txt && \ RUN pip install --no-cache-dir -r requirements.txt && \
pip install --pre --no-cache-dir torch torchvision --index-url https://download.pytorch.org/whl/nightly/cu128 --force && \ pip install --pre --no-cache-dir torch torchvision torchaudio --index-url https://download.pytorch.org/whl/nightly/cu128 --force && \
pip install setuptools==69.5.1 --no-cache-dir pip install setuptools==69.5.1 --no-cache-dir
# Build UI # Build UI

1
run.py
View File

@@ -1,5 +1,6 @@
import os import os
os.environ["HF_HUB_ENABLE_HF_TRANSFER"] = "1" os.environ["HF_HUB_ENABLE_HF_TRANSFER"] = "1"
os.environ["NO_ALBUMENTATIONS_UPDATE"] = "1"
import sys import sys
from typing import Union, OrderedDict from typing import Union, OrderedDict
from dotenv import load_dotenv from dotenv import load_dotenv

View File

@@ -1,6 +1,6 @@
'use client'; 'use client';
import { useMemo } from 'react'; import { useMemo } from 'react';
import { modelArchs, ModelArch } from './options'; import { modelArchs, ModelArch, groupedModelOptions } from './options';
import { defaultDatasetConfig } from './jobConfig'; import { defaultDatasetConfig } from './jobConfig';
import { JobConfig } from '@/types'; import { JobConfig } from '@/types';
import { objectCopy } from '@/utils/basic'; import { objectCopy } from '@/utils/basic';
@@ -37,7 +37,7 @@ export default function SimpleJob({
return modelArchs.find(a => a.name === jobConfig.config.process[0].model.arch) as ModelArch; return modelArchs.find(a => a.name === jobConfig.config.process[0].model.arch) as ModelArch;
}, [jobConfig.config.process[0].model.arch]); }, [jobConfig.config.process[0].model.arch]);
const isVideoModel = !!modelArch?.isVideoModel; const isVideoModel = !!(modelArch?.group === 'video');
return ( return (
<> <>
@@ -100,8 +100,9 @@ export default function SimpleJob({
// set new model // set new model
setJobConfig(value, 'config.process[0].model.arch'); setJobConfig(value, 'config.process[0].model.arch');
// update controls for datasets // update datasets
const hasControlPath = newArch?.additionalSections?.includes('datasets.control_path') || false; const hasControlPath = newArch?.additionalSections?.includes('datasets.control_path') || false;
const hasNumFrames = newArch?.additionalSections?.includes('datasets.num_frames') || false;
const controls = newArch?.controls ?? []; const controls = newArch?.controls ?? [];
const datasets = jobConfig.config.process[0].datasets.map(dataset => { const datasets = jobConfig.config.process[0].datasets.map(dataset => {
const newDataset = objectCopy(dataset); const newDataset = objectCopy(dataset);
@@ -109,20 +110,14 @@ export default function SimpleJob({
if (!hasControlPath) { if (!hasControlPath) {
newDataset.control_path = null; // reset control path if not applicable newDataset.control_path = null; // reset control path if not applicable
} }
if (!hasNumFrames) {
newDataset.num_frames = 1; // reset num_frames if not applicable
}
return newDataset; return newDataset;
}); });
setJobConfig(datasets, 'config.process[0].datasets'); setJobConfig(datasets, 'config.process[0].datasets');
}} }}
options={ options={groupedModelOptions}
modelArchs
.map(model => {
return {
value: model.name,
label: model.label,
};
})
.filter(x => x) as { value: string; label: string }[]
}
/> />
<TextInput <TextInput
label="Name or Path" label="Name or Path"
@@ -422,6 +417,7 @@ export default function SimpleJob({
label="Control Dataset" label="Control Dataset"
docKey="datasets.control_path" docKey="datasets.control_path"
value={dataset.control_path ?? ''} value={dataset.control_path ?? ''}
className="pt-2"
onChange={value => onChange={value =>
setJobConfig(value == '' ? null : value, `config.process[0].datasets[${i}].control_path`) setJobConfig(value == '' ? null : value, `config.process[0].datasets[${i}].control_path`)
} }
@@ -452,6 +448,18 @@ export default function SimpleJob({
min={0} min={0}
required required
/> />
{modelArch?.additionalSections?.includes('datasets.num_frames') && (
<NumberInput
label="Num Frames"
className="pt-2"
docKey="datasets.num_frames"
value={dataset.num_frames}
onChange={value => setJobConfig(value, `config.process[0].datasets[${i}].num_frames`)}
placeholder="eg. 41"
min={1}
required
/>
)}
</div> </div>
<div> <div>
<FormGroup label="Settings" className=""> <FormGroup label="Settings" className="">
@@ -586,6 +594,7 @@ export default function SimpleJob({
value={jobConfig.config.process[0].sample.num_frames} value={jobConfig.config.process[0].sample.num_frames}
onChange={value => setJobConfig(value, 'config.process[0].sample.num_frames')} onChange={value => setJobConfig(value, 'config.process[0].sample.num_frames')}
placeholder="eg. 0" placeholder="eg. 0"
className="pt-2"
min={0} min={0}
required required
/> />
@@ -594,6 +603,7 @@ export default function SimpleJob({
value={jobConfig.config.process[0].sample.fps} value={jobConfig.config.process[0].sample.fps}
onChange={value => setJobConfig(value, 'config.process[0].sample.fps')} onChange={value => setJobConfig(value, 'config.process[0].sample.fps')}
placeholder="eg. 0" placeholder="eg. 0"
className="pt-2"
min={0} min={0}
required required
/> />

View File

@@ -12,7 +12,9 @@ export const defaultDatasetConfig: DatasetConfig = {
is_reg: false, is_reg: false,
network_weight: 1, network_weight: 1,
resolution: [512, 768, 1024], resolution: [512, 768, 1024],
controls: [] controls: [],
shrink_video_to_frames: true,
num_frames: 1,
}; };
export const defaultJobConfig: JobConfig = { export const defaultJobConfig: JobConfig = {

View File

@@ -1,11 +1,15 @@
import { GroupedSelectOption } from "@/types";
type Control = 'depth' | 'line' | 'pose' | 'inpaint'; type Control = 'depth' | 'line' | 'pose' | 'inpaint';
type DisableableSections = 'model.quantize' | 'train.timestep_type' | 'network.conv'; type DisableableSections = 'model.quantize' | 'train.timestep_type' | 'network.conv';
type AdditionalSections = 'datasets.control_path' | 'sample.ctrl_img' type AdditionalSections = 'datasets.control_path' | 'sample.ctrl_img' | 'datasets.num_frames';
type ModelGroup = 'image' | 'video';
export interface ModelArch { export interface ModelArch {
name: string; name: string;
label: string; label: string;
group: ModelGroup;
controls?: Control[]; controls?: Control[];
isVideoModel?: boolean; isVideoModel?: boolean;
defaults?: { [key: string]: any }; defaults?: { [key: string]: any };
@@ -16,10 +20,13 @@ export interface ModelArch {
const defaultNameOrPath = ''; const defaultNameOrPath = '';
export const modelArchs: ModelArch[] = [ export const modelArchs: ModelArch[] = [
{ {
name: 'flux', name: 'flux',
label: 'FLUX.1', label: 'FLUX.1',
group: 'image',
defaults: { defaults: {
// default updates when [selected, unselected] in the UI // default updates when [selected, unselected] in the UI
'config.process[0].model.name_or_path': ['black-forest-labs/FLUX.1-dev', defaultNameOrPath], 'config.process[0].model.name_or_path': ['black-forest-labs/FLUX.1-dev', defaultNameOrPath],
@@ -33,6 +40,7 @@ export const modelArchs: ModelArch[] = [
{ {
name: 'flux_kontext', name: 'flux_kontext',
label: 'FLUX.1-Kontext-dev', label: 'FLUX.1-Kontext-dev',
group: 'image',
defaults: { defaults: {
// default updates when [selected, unselected] in the UI // default updates when [selected, unselected] in the UI
'config.process[0].model.name_or_path': ['black-forest-labs/FLUX.1-Kontext-dev', defaultNameOrPath], 'config.process[0].model.name_or_path': ['black-forest-labs/FLUX.1-Kontext-dev', defaultNameOrPath],
@@ -48,6 +56,7 @@ export const modelArchs: ModelArch[] = [
{ {
name: 'flex1', name: 'flex1',
label: 'Flex.1', label: 'Flex.1',
group: 'image',
defaults: { defaults: {
// default updates when [selected, unselected] in the UI // default updates when [selected, unselected] in the UI
'config.process[0].model.name_or_path': ['ostris/Flex.1-alpha', defaultNameOrPath], 'config.process[0].model.name_or_path': ['ostris/Flex.1-alpha', defaultNameOrPath],
@@ -62,6 +71,7 @@ export const modelArchs: ModelArch[] = [
{ {
name: 'flex2', name: 'flex2',
label: 'Flex.2', label: 'Flex.2',
group: 'image',
controls: ['depth', 'line', 'pose', 'inpaint'], controls: ['depth', 'line', 'pose', 'inpaint'],
defaults: { defaults: {
// default updates when [selected, unselected] in the UI // default updates when [selected, unselected] in the UI
@@ -89,6 +99,7 @@ export const modelArchs: ModelArch[] = [
{ {
name: 'chroma', name: 'chroma',
label: 'Chroma', label: 'Chroma',
group: 'image',
defaults: { defaults: {
// default updates when [selected, unselected] in the UI // default updates when [selected, unselected] in the UI
'config.process[0].model.name_or_path': ['lodestones/Chroma', defaultNameOrPath], 'config.process[0].model.name_or_path': ['lodestones/Chroma', defaultNameOrPath],
@@ -102,6 +113,7 @@ export const modelArchs: ModelArch[] = [
{ {
name: 'wan21:1b', name: 'wan21:1b',
label: 'Wan 2.1 (1.3B)', label: 'Wan 2.1 (1.3B)',
group: 'video',
isVideoModel: true, isVideoModel: true,
defaults: { defaults: {
// default updates when [selected, unselected] in the UI // default updates when [selected, unselected] in the UI
@@ -110,14 +122,52 @@ export const modelArchs: ModelArch[] = [
'config.process[0].model.quantize_te': [true, false], 'config.process[0].model.quantize_te': [true, false],
'config.process[0].sample.sampler': ['flowmatch', 'flowmatch'], 'config.process[0].sample.sampler': ['flowmatch', 'flowmatch'],
'config.process[0].train.noise_scheduler': ['flowmatch', 'flowmatch'], 'config.process[0].train.noise_scheduler': ['flowmatch', 'flowmatch'],
'config.process[0].sample.num_frames': [40, 1], 'config.process[0].sample.num_frames': [41, 1],
'config.process[0].sample.fps': [15, 1], 'config.process[0].sample.fps': [15, 1],
}, },
disableSections: ['network.conv'], disableSections: ['network.conv'],
additionalSections: ['datasets.num_frames'],
},
{
name: 'wan21_i2v:14b480p',
label: 'Wan 2.1 I2V (14B-480P)',
group: 'video',
isVideoModel: true,
defaults: {
// default updates when [selected, unselected] in the UI
'config.process[0].model.name_or_path': ['Wan-AI/Wan2.1-I2V-14B-480P-Diffusers', defaultNameOrPath],
'config.process[0].model.quantize': [true, false],
'config.process[0].model.quantize_te': [true, false],
'config.process[0].sample.sampler': ['flowmatch', 'flowmatch'],
'config.process[0].train.noise_scheduler': ['flowmatch', 'flowmatch'],
'config.process[0].sample.num_frames': [41, 1],
'config.process[0].sample.fps': [15, 1],
},
disableSections: ['network.conv'],
additionalSections: ['sample.ctrl_img', 'datasets.num_frames'],
},
{
name: 'wan21_i2v:14b',
label: 'Wan 2.1 I2V (14B-720P)',
group: 'video',
isVideoModel: true,
defaults: {
// default updates when [selected, unselected] in the UI
'config.process[0].model.name_or_path': ['Wan-AI/Wan2.1-I2V-14B-720P-Diffusers', defaultNameOrPath],
'config.process[0].model.quantize': [true, false],
'config.process[0].model.quantize_te': [true, false],
'config.process[0].sample.sampler': ['flowmatch', 'flowmatch'],
'config.process[0].train.noise_scheduler': ['flowmatch', 'flowmatch'],
'config.process[0].sample.num_frames': [41, 1],
'config.process[0].sample.fps': [15, 1],
},
disableSections: ['network.conv'],
additionalSections: ['sample.ctrl_img', 'datasets.num_frames'],
}, },
{ {
name: 'wan21:14b', name: 'wan21:14b',
label: 'Wan 2.1 (14B)', label: 'Wan 2.1 (14B)',
group: 'video',
isVideoModel: true, isVideoModel: true,
defaults: { defaults: {
// default updates when [selected, unselected] in the UI // default updates when [selected, unselected] in the UI
@@ -126,14 +176,16 @@ export const modelArchs: ModelArch[] = [
'config.process[0].model.quantize_te': [true, false], 'config.process[0].model.quantize_te': [true, false],
'config.process[0].sample.sampler': ['flowmatch', 'flowmatch'], 'config.process[0].sample.sampler': ['flowmatch', 'flowmatch'],
'config.process[0].train.noise_scheduler': ['flowmatch', 'flowmatch'], 'config.process[0].train.noise_scheduler': ['flowmatch', 'flowmatch'],
'config.process[0].sample.num_frames': [40, 1], 'config.process[0].sample.num_frames': [41, 1],
'config.process[0].sample.fps': [15, 1], 'config.process[0].sample.fps': [15, 1],
}, },
disableSections: ['network.conv'], disableSections: ['network.conv'],
additionalSections: ['datasets.num_frames'],
}, },
{ {
name: 'lumina2', name: 'lumina2',
label: 'Lumina2', label: 'Lumina2',
group: 'image',
defaults: { defaults: {
// default updates when [selected, unselected] in the UI // default updates when [selected, unselected] in the UI
'config.process[0].model.name_or_path': ['Alpha-VLLM/Lumina-Image-2.0', defaultNameOrPath], 'config.process[0].model.name_or_path': ['Alpha-VLLM/Lumina-Image-2.0', defaultNameOrPath],
@@ -147,6 +199,7 @@ export const modelArchs: ModelArch[] = [
{ {
name: 'hidream', name: 'hidream',
label: 'HiDream', label: 'HiDream',
group: 'image',
defaults: { defaults: {
// default updates when [selected, unselected] in the UI // default updates when [selected, unselected] in the UI
'config.process[0].model.name_or_path': ['HiDream-ai/HiDream-I1-Full', defaultNameOrPath], 'config.process[0].model.name_or_path': ['HiDream-ai/HiDream-I1-Full', defaultNameOrPath],
@@ -163,6 +216,7 @@ export const modelArchs: ModelArch[] = [
{ {
name: 'sdxl', name: 'sdxl',
label: 'SDXL', label: 'SDXL',
group: 'image',
defaults: { defaults: {
// default updates when [selected, unselected] in the UI // default updates when [selected, unselected] in the UI
'config.process[0].model.name_or_path': ['stabilityai/stable-diffusion-xl-base-1.0', defaultNameOrPath], 'config.process[0].model.name_or_path': ['stabilityai/stable-diffusion-xl-base-1.0', defaultNameOrPath],
@@ -177,6 +231,7 @@ export const modelArchs: ModelArch[] = [
{ {
name: 'sd15', name: 'sd15',
label: 'SD 1.5', label: 'SD 1.5',
group: 'image',
defaults: { defaults: {
// default updates when [selected, unselected] in the UI // default updates when [selected, unselected] in the UI
'config.process[0].model.name_or_path': ['stable-diffusion-v1-5/stable-diffusion-v1-5', defaultNameOrPath], 'config.process[0].model.name_or_path': ['stable-diffusion-v1-5/stable-diffusion-v1-5', defaultNameOrPath],
@@ -191,6 +246,7 @@ export const modelArchs: ModelArch[] = [
{ {
name: 'omnigen2', name: 'omnigen2',
label: 'OmniGen2', label: 'OmniGen2',
group: 'image',
defaults: { defaults: {
// default updates when [selected, unselected] in the UI // default updates when [selected, unselected] in the UI
'config.process[0].model.name_or_path': ['OmniGen2/OmniGen2', defaultNameOrPath], 'config.process[0].model.name_or_path': ['OmniGen2/OmniGen2', defaultNameOrPath],
@@ -206,3 +262,17 @@ export const modelArchs: ModelArch[] = [
// Sort by label, case-insensitive // Sort by label, case-insensitive
return a.label.localeCompare(b.label, undefined, { sensitivity: 'base' }) return a.label.localeCompare(b.label, undefined, { sensitivity: 'base' })
}) as any; }) as any;
// Options for the model-arch <SelectInput>, grouped by model family (arch.group,
// e.g. 'image' / 'video'). Group order follows the first appearance of each group
// in modelArchs; option order within a group matches modelArchs order.
export const groupedModelOptions: GroupedSelectOption[] = (() => {
  const optionsByGroup = new Map<string, { value: string; label: string }[]>();
  for (const arch of modelArchs) {
    const option = { value: arch.name, label: arch.label };
    const bucket = optionsByGroup.get(arch.group);
    if (bucket) {
      bucket.push(option);
    } else {
      optionsByGroup.set(arch.group, [option]);
    }
  }
  return [...optionsByGroup.entries()].map(([label, options]) => ({ label, options }));
})();

View File

@@ -6,6 +6,7 @@ import dynamic from 'next/dynamic';
import { CircleHelp } from 'lucide-react'; import { CircleHelp } from 'lucide-react';
import { getDoc } from '@/docs'; import { getDoc } from '@/docs';
import { openDoc } from '@/components/DocModal'; import { openDoc } from '@/components/DocModal';
import { GroupedSelectOption, SelectOption } from '@/types';
const Select = dynamic(() => import('react-select'), { ssr: false }); const Select = dynamic(() => import('react-select'), { ssr: false });
@@ -141,13 +142,21 @@ export interface SelectInputProps extends InputProps {
value: string; value: string;
disabled?: boolean; disabled?: boolean;
onChange: (value: string) => void; onChange: (value: string) => void;
options: { value: string; label: string }[]; options: GroupedSelectOption[] | SelectOption[];
} }
export const SelectInput = (props: SelectInputProps) => { export const SelectInput = (props: SelectInputProps) => {
const { label, value, onChange, options, docKey = null } = props; const { label, value, onChange, options, docKey = null } = props;
const doc = getDoc(docKey); const doc = getDoc(docKey);
const selectedOption = options.find(option => option.value === value); let selectedOption: SelectOption | undefined;
if (options && options.length > 0) {
// see if grouped options
if ('options' in options[0]) {
selectedOption = (options as GroupedSelectOption[]).flatMap(group => group.options).find(opt => opt.value === value);
} else {
selectedOption = (options as SelectOption[]).find(opt => opt.value === value);
}
}
return ( return (
<div <div
className={classNames(props.className, { className={classNames(props.className, {

View File

@@ -57,6 +57,23 @@ const docs: { [key: string]: ConfigDoc } = {
</> </>
), ),
}, },
'datasets.num_frames': {
title: 'Number of Frames',
description: (
<>
This sets the number of frames to shrink videos to for a video dataset. If this dataset is images, set this to 1 for one frame.
If your dataset is only videos, frames will be extracted evenly spaced from the videos in the dataset.
<br/>
<br/>
It is best to trim your videos to the proper length before training. Wan is 16 frames a second. Doing 81 frames will result in a 5 second video.
So you would want all of your videos trimmed to around 5 seconds for best results.
<br/>
<br/>
Example: Setting this to 81 and having 2 videos in your dataset, one is 2 seconds and one is 90 seconds long, will result in 81
evenly spaced frames for each video, making the 2 second video appear slow and the 90 second video appear very fast.
</>
),
},
}; };
export const getDoc = (key: string | null | undefined): ConfigDoc | null => { export const getDoc = (key: string | null | undefined): ConfigDoc | null => {

View File

@@ -84,6 +84,8 @@ export interface DatasetConfig {
resolution: number[]; resolution: number[];
controls: string[]; controls: string[];
control_path: string | null; control_path: string | null;
num_frames: number;
shrink_video_to_frames: boolean;
} }
export interface EMAConfig { export interface EMAConfig {
@@ -181,3 +183,12 @@ export interface ConfigDoc {
title: string; title: string;
description: React.ReactNode; description: React.ReactNode;
} }
/** A single choice for a SelectInput: `value` is the stored key, `label` is shown to the user. */
export interface SelectOption {
  readonly value: string;
  readonly label: string;
}

/**
 * A labeled group of select options (react-select "grouped options" shape),
 * e.g. the 'image' / 'video' model-family groups built by groupedModelOptions.
 */
export interface GroupedSelectOption {
  readonly label: string;
  readonly options: SelectOption[];
}

View File

@@ -1 +1 @@
VERSION = "0.3.5" VERSION = "0.3.6"