Add simple UI settings to train Wan I2V models.

Jaret Burkett
2025-07-11 11:28:40 -06:00
parent 6e2beef8dd
commit 8537a8557f
10 changed files with 155 additions and 31 deletions

View File

@@ -120,7 +120,7 @@ cd ai-toolkit
python3 -m venv venv
source venv/bin/activate
# install torch first
pip3 install --no-cache-dir torch==2.7.0 torchvision==0.22.0 --index-url https://download.pytorch.org/whl/cu126
pip3 install --no-cache-dir torch==2.7.0 torchvision==0.22.0 torchaudio==2.7.0 --index-url https://download.pytorch.org/whl/cu126
pip3 install -r requirements.txt
```
@@ -130,7 +130,7 @@ git clone https://github.com/ostris/ai-toolkit.git
cd ai-toolkit
python -m venv venv
.\venv\Scripts\activate
pip install --no-cache-dir torch==2.7.0 torchvision==0.22.0 --index-url https://download.pytorch.org/whl/cu126
pip install --no-cache-dir torch==2.7.0 torchvision==0.22.0 torchaudio==2.7.0 --index-url https://download.pytorch.org/whl/cu126
pip install -r requirements.txt
```
@@ -425,28 +425,32 @@ Everything else should work the same including layer targeting.
Only larger updates are listed here. There are usually smaller daily updates that are omitted.
### June 29, 2024
### July 11, 2025
- Added better video config settings to the UI for video models.
- Added Wan I2V training to the UI
### June 29, 2025
- Fixed issue where Kontext forced sizes on sampling
### June 26, 2024
### June 26, 2025
- Added support for FLUX.1 Kontext training
- Added support for instruction dataset training
### June 25, 2024
### June 25, 2025
- Added support for OmniGen2 training
### June 17, 2024
### June 17, 2025
- Performance optimizations for batch preparation
- Added some docs via a popup for items in the simple UI explaining what settings do. Still a WIP
### June 16, 2024
### June 16, 2025
- Hide control images in the UI when viewing datasets
- WIP on mean flow loss
### June 12, 2024
### June 12, 2025
- Fixed issue that resulted in blank captions in the dataloader
### June 10, 2024
### June 10, 2025
- Decided to keep track of updates in the readme
- Added support for SDXL in the UI
- Added support for SD 1.5 in the UI

View File

@@ -49,7 +49,7 @@ WORKDIR /app
RUN ln -s /usr/bin/python3 /usr/bin/python
# install pytorch before cache bust to avoid redownloading pytorch
RUN pip install --pre --no-cache-dir torch torchvision --index-url https://download.pytorch.org/whl/nightly/cu128
RUN pip install --pre --no-cache-dir torch torchvision torchaudio --index-url https://download.pytorch.org/whl/nightly/cu128
# Fix cache busting by moving CACHEBUST to right before git clone
ARG CACHEBUST=1234
@@ -63,7 +63,7 @@ WORKDIR /app/ai-toolkit
# Install Python dependencies
RUN pip install --no-cache-dir -r requirements.txt && \
pip install --pre --no-cache-dir torch torchvision --index-url https://download.pytorch.org/whl/nightly/cu128 --force && \
pip install --pre --no-cache-dir torch torchvision torchaudio --index-url https://download.pytorch.org/whl/nightly/cu128 --force && \
pip install setuptools==69.5.1 --no-cache-dir
# Build UI

run.py
View File

@@ -1,5 +1,6 @@
import os
os.environ["HF_HUB_ENABLE_HF_TRANSFER"] = "1"
os.environ["NO_ALBUMENTATIONS_UPDATE"] = "1"
import sys
from typing import Union, OrderedDict
from dotenv import load_dotenv

View File

@@ -1,6 +1,6 @@
'use client';
import { useMemo } from 'react';
import { modelArchs, ModelArch } from './options';
import { modelArchs, ModelArch, groupedModelOptions } from './options';
import { defaultDatasetConfig } from './jobConfig';
import { JobConfig } from '@/types';
import { objectCopy } from '@/utils/basic';
@@ -37,7 +37,7 @@ export default function SimpleJob({
return modelArchs.find(a => a.name === jobConfig.config.process[0].model.arch) as ModelArch;
}, [jobConfig.config.process[0].model.arch]);
const isVideoModel = !!modelArch?.isVideoModel;
const isVideoModel = !!(modelArch?.group === 'video');
return (
<>
@@ -100,8 +100,9 @@ export default function SimpleJob({
// set new model
setJobConfig(value, 'config.process[0].model.arch');
// update controls for datasets
// update datasets
const hasControlPath = newArch?.additionalSections?.includes('datasets.control_path') || false;
const hasNumFrames = newArch?.additionalSections?.includes('datasets.num_frames') || false;
const controls = newArch?.controls ?? [];
const datasets = jobConfig.config.process[0].datasets.map(dataset => {
const newDataset = objectCopy(dataset);
@@ -109,20 +110,14 @@ export default function SimpleJob({
if (!hasControlPath) {
newDataset.control_path = null; // reset control path if not applicable
}
if (!hasNumFrames) {
newDataset.num_frames = 1; // reset num_frames if not applicable
}
return newDataset;
});
setJobConfig(datasets, 'config.process[0].datasets');
}}
options={
modelArchs
.map(model => {
return {
value: model.name,
label: model.label,
};
})
.filter(x => x) as { value: string; label: string }[]
}
options={groupedModelOptions}
/>
<TextInput
label="Name or Path"
@@ -422,6 +417,7 @@ export default function SimpleJob({
label="Control Dataset"
docKey="datasets.control_path"
value={dataset.control_path ?? ''}
className="pt-2"
onChange={value =>
setJobConfig(value == '' ? null : value, `config.process[0].datasets[${i}].control_path`)
}
@@ -452,6 +448,18 @@ export default function SimpleJob({
min={0}
required
/>
{modelArch?.additionalSections?.includes('datasets.num_frames') && (
<NumberInput
label="Num Frames"
className="pt-2"
docKey="datasets.num_frames"
value={dataset.num_frames}
onChange={value => setJobConfig(value, `config.process[0].datasets[${i}].num_frames`)}
placeholder="eg. 41"
min={1}
required
/>
)}
</div>
<div>
<FormGroup label="Settings" className="">
@@ -586,6 +594,7 @@ export default function SimpleJob({
value={jobConfig.config.process[0].sample.num_frames}
onChange={value => setJobConfig(value, 'config.process[0].sample.num_frames')}
placeholder="eg. 0"
className="pt-2"
min={0}
required
/>
@@ -594,6 +603,7 @@ export default function SimpleJob({
value={jobConfig.config.process[0].sample.fps}
onChange={value => setJobConfig(value, 'config.process[0].sample.fps')}
placeholder="eg. 0"
className="pt-2"
min={0}
required
/>

View File

@@ -12,7 +12,9 @@ export const defaultDatasetConfig: DatasetConfig = {
is_reg: false,
network_weight: 1,
resolution: [512, 768, 1024],
controls: []
controls: [],
shrink_video_to_frames: true,
num_frames: 1,
};
export const defaultJobConfig: JobConfig = {

View File

@@ -1,11 +1,15 @@
import { GroupedSelectOption } from "@/types";
type Control = 'depth' | 'line' | 'pose' | 'inpaint';
type DisableableSections = 'model.quantize' | 'train.timestep_type' | 'network.conv';
type AdditionalSections = 'datasets.control_path' | 'sample.ctrl_img'
type AdditionalSections = 'datasets.control_path' | 'sample.ctrl_img' | 'datasets.num_frames';
type ModelGroup = 'image' | 'video';
export interface ModelArch {
name: string;
label: string;
group: ModelGroup;
controls?: Control[];
isVideoModel?: boolean;
defaults?: { [key: string]: any };
@@ -16,10 +20,13 @@ export interface ModelArch {
const defaultNameOrPath = '';
export const modelArchs: ModelArch[] = [
{
name: 'flux',
label: 'FLUX.1',
group: 'image',
defaults: {
// default updates when [selected, unselected] in the UI
'config.process[0].model.name_or_path': ['black-forest-labs/FLUX.1-dev', defaultNameOrPath],
@@ -33,6 +40,7 @@ export const modelArchs: ModelArch[] = [
{
name: 'flux_kontext',
label: 'FLUX.1-Kontext-dev',
group: 'image',
defaults: {
// default updates when [selected, unselected] in the UI
'config.process[0].model.name_or_path': ['black-forest-labs/FLUX.1-Kontext-dev', defaultNameOrPath],
@@ -48,6 +56,7 @@ export const modelArchs: ModelArch[] = [
{
name: 'flex1',
label: 'Flex.1',
group: 'image',
defaults: {
// default updates when [selected, unselected] in the UI
'config.process[0].model.name_or_path': ['ostris/Flex.1-alpha', defaultNameOrPath],
@@ -62,6 +71,7 @@ export const modelArchs: ModelArch[] = [
{
name: 'flex2',
label: 'Flex.2',
group: 'image',
controls: ['depth', 'line', 'pose', 'inpaint'],
defaults: {
// default updates when [selected, unselected] in the UI
@@ -89,6 +99,7 @@ export const modelArchs: ModelArch[] = [
{
name: 'chroma',
label: 'Chroma',
group: 'image',
defaults: {
// default updates when [selected, unselected] in the UI
'config.process[0].model.name_or_path': ['lodestones/Chroma', defaultNameOrPath],
@@ -102,6 +113,7 @@ export const modelArchs: ModelArch[] = [
{
name: 'wan21:1b',
label: 'Wan 2.1 (1.3B)',
group: 'video',
isVideoModel: true,
defaults: {
// default updates when [selected, unselected] in the UI
@@ -110,14 +122,52 @@ export const modelArchs: ModelArch[] = [
'config.process[0].model.quantize_te': [true, false],
'config.process[0].sample.sampler': ['flowmatch', 'flowmatch'],
'config.process[0].train.noise_scheduler': ['flowmatch', 'flowmatch'],
'config.process[0].sample.num_frames': [40, 1],
'config.process[0].sample.num_frames': [41, 1],
'config.process[0].sample.fps': [15, 1],
},
disableSections: ['network.conv'],
additionalSections: ['datasets.num_frames'],
},
{
name: 'wan21_i2v:14b480p',
label: 'Wan 2.1 I2V (14B-480P)',
group: 'video',
isVideoModel: true,
defaults: {
// default updates when [selected, unselected] in the UI
'config.process[0].model.name_or_path': ['Wan-AI/Wan2.1-I2V-14B-480P-Diffusers', defaultNameOrPath],
'config.process[0].model.quantize': [true, false],
'config.process[0].model.quantize_te': [true, false],
'config.process[0].sample.sampler': ['flowmatch', 'flowmatch'],
'config.process[0].train.noise_scheduler': ['flowmatch', 'flowmatch'],
'config.process[0].sample.num_frames': [41, 1],
'config.process[0].sample.fps': [15, 1],
},
disableSections: ['network.conv'],
additionalSections: ['sample.ctrl_img', 'datasets.num_frames'],
},
{
name: 'wan21_i2v:14b',
label: 'Wan 2.1 I2V (14B-720P)',
group: 'video',
isVideoModel: true,
defaults: {
// default updates when [selected, unselected] in the UI
'config.process[0].model.name_or_path': ['Wan-AI/Wan2.1-I2V-14B-720P-Diffusers', defaultNameOrPath],
'config.process[0].model.quantize': [true, false],
'config.process[0].model.quantize_te': [true, false],
'config.process[0].sample.sampler': ['flowmatch', 'flowmatch'],
'config.process[0].train.noise_scheduler': ['flowmatch', 'flowmatch'],
'config.process[0].sample.num_frames': [41, 1],
'config.process[0].sample.fps': [15, 1],
},
disableSections: ['network.conv'],
additionalSections: ['sample.ctrl_img', 'datasets.num_frames'],
},
{
name: 'wan21:14b',
label: 'Wan 2.1 (14B)',
group: 'video',
isVideoModel: true,
defaults: {
// default updates when [selected, unselected] in the UI
@@ -126,14 +176,16 @@ export const modelArchs: ModelArch[] = [
'config.process[0].model.quantize_te': [true, false],
'config.process[0].sample.sampler': ['flowmatch', 'flowmatch'],
'config.process[0].train.noise_scheduler': ['flowmatch', 'flowmatch'],
'config.process[0].sample.num_frames': [40, 1],
'config.process[0].sample.num_frames': [41, 1],
'config.process[0].sample.fps': [15, 1],
},
disableSections: ['network.conv'],
additionalSections: ['datasets.num_frames'],
},
{
name: 'lumina2',
label: 'Lumina2',
group: 'image',
defaults: {
// default updates when [selected, unselected] in the UI
'config.process[0].model.name_or_path': ['Alpha-VLLM/Lumina-Image-2.0', defaultNameOrPath],
@@ -147,6 +199,7 @@ export const modelArchs: ModelArch[] = [
{
name: 'hidream',
label: 'HiDream',
group: 'image',
defaults: {
// default updates when [selected, unselected] in the UI
'config.process[0].model.name_or_path': ['HiDream-ai/HiDream-I1-Full', defaultNameOrPath],
@@ -163,6 +216,7 @@ export const modelArchs: ModelArch[] = [
{
name: 'sdxl',
label: 'SDXL',
group: 'image',
defaults: {
// default updates when [selected, unselected] in the UI
'config.process[0].model.name_or_path': ['stabilityai/stable-diffusion-xl-base-1.0', defaultNameOrPath],
@@ -177,6 +231,7 @@ export const modelArchs: ModelArch[] = [
{
name: 'sd15',
label: 'SD 1.5',
group: 'image',
defaults: {
// default updates when [selected, unselected] in the UI
'config.process[0].model.name_or_path': ['stable-diffusion-v1-5/stable-diffusion-v1-5', defaultNameOrPath],
@@ -191,6 +246,7 @@ export const modelArchs: ModelArch[] = [
{
name: 'omnigen2',
label: 'OmniGen2',
group: 'image',
defaults: {
// default updates when [selected, unselected] in the UI
'config.process[0].model.name_or_path': ['OmniGen2/OmniGen2', defaultNameOrPath],
@@ -206,3 +262,17 @@ export const modelArchs: ModelArch[] = [
// Sort by label, case-insensitive
return a.label.localeCompare(b.label, undefined, { sensitivity: 'base' })
}) as any;
export const groupedModelOptions: GroupedSelectOption[] = modelArchs.reduce((acc, arch) => {
const group = acc.find(g => g.label === arch.group);
if (group) {
group.options.push({ value: arch.name, label: arch.label });
} else {
acc.push({
label: arch.group,
options: [{ value: arch.name, label: arch.label }],
});
}
return acc;
}, [] as GroupedSelectOption[]);
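
The `defaults` entries above pair each config path with `[selected, unselected]` values, as the inline comments note: index 0 is written when the arch is chosen, index 1 when it is deselected. Below is a minimal sketch of how such pairs can be applied on an arch switch; `applyArchDefaults` and the `setJobConfig` signature are illustrative assumptions, not the actual SimpleJob handler.

```ts
import { modelArchs } from './options';

// Hypothetical sketch: apply [selected, unselected] default pairs when the
// user switches architecture. setJobConfig(value, path) is assumed to write
// a value at a dotted config path, like the setter used in SimpleJob.
function applyArchDefaults(
  prevArch: string | undefined,
  nextArch: string,
  setJobConfig: (value: unknown, path: string) => void,
): void {
  const prev = modelArchs.find(a => a.name === prevArch);
  const next = modelArchs.find(a => a.name === nextArch);
  for (const [path, pair] of Object.entries(prev?.defaults ?? {})) {
    setJobConfig(pair[1], path); // restore neutral value for the outgoing arch
  }
  for (const [path, pair] of Object.entries(next?.defaults ?? {})) {
    setJobConfig(pair[0], path); // apply the incoming arch's default
  }
}
```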

View File

@@ -6,6 +6,7 @@ import dynamic from 'next/dynamic';
import { CircleHelp } from 'lucide-react';
import { getDoc } from '@/docs';
import { openDoc } from '@/components/DocModal';
import { GroupedSelectOption, SelectOption } from '@/types';
const Select = dynamic(() => import('react-select'), { ssr: false });
@@ -141,13 +142,21 @@ export interface SelectInputProps extends InputProps {
value: string;
disabled?: boolean;
onChange: (value: string) => void;
options: { value: string; label: string }[];
options: GroupedSelectOption[] | SelectOption[];
}
export const SelectInput = (props: SelectInputProps) => {
const { label, value, onChange, options, docKey = null } = props;
const doc = getDoc(docKey);
const selectedOption = options.find(option => option.value === value);
let selectedOption: SelectOption | undefined;
if (options && options.length > 0) {
// see if grouped options
if ('options' in options[0]) {
selectedOption = (options as GroupedSelectOption[]).flatMap(group => group.options).find(opt => opt.value === value);
} else {
selectedOption = (options as SelectOption[]).find(opt => opt.value === value);
}
}
return (
<div
className={classNames(props.className, {

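`SelectInput` now accepts flat or grouped options; react-select renders `{ label, options }` entries as option groups natively, so `groupedModelOptions` can be passed straight through. The lookup branch above could also be written as a single flatten-then-find. A sketch, assuming the `SelectOption` and `GroupedSelectOption` types from `@/types`:

```ts
import { GroupedSelectOption, SelectOption } from '@/types';

// Sketch: flatten grouped options before searching; the `in` check narrows
// each entry to either a group (has `options`) or a flat option.
function findSelected(
  options: (SelectOption | GroupedSelectOption)[],
  value: string,
): SelectOption | undefined {
  return options
    .flatMap(o => ('options' in o ? o.options : [o]))
    .find(opt => opt.value === value);
}
```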
View File

@@ -57,6 +57,23 @@ const docs: { [key: string]: ConfigDoc } = {
</>
),
},
'datasets.num_frames': {
title: 'Number of Frames',
description: (
<>
This sets the number of frames videos in the dataset are shrunk to. If the dataset contains images, set this to 1 (a single frame).
If the dataset contains videos, that many frames will be extracted, evenly spaced, from each video.
<br/>
<br/>
It is best to trim your videos to the proper length before training. Wan runs at 16 frames per second, so 81 frames corresponds to a
roughly 5-second video, and you would want all of your videos trimmed to around 5 seconds for best results.
<br/>
<br/>
Example: with this set to 81 and two videos in the dataset, one 2 seconds long and one 90 seconds long, 81 evenly spaced
frames will be extracted from each video, making the 2-second video appear slow and the 90-second video appear very fast.
</>
),
},
};
export const getDoc = (key: string | null | undefined): ConfigDoc | null => {
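
The slow/fast behavior described in the doc above falls out of even spacing. A hypothetical sketch of the index math (the real extraction happens in the Python dataloader, not in the UI):

```ts
// Sketch: pick numFrames indices evenly spaced across a clip of totalFrames.
function evenlySpacedFrameIndices(totalFrames: number, numFrames: number): number[] {
  if (numFrames <= 1 || totalFrames <= 1) return [0];
  const step = (totalFrames - 1) / (numFrames - 1);
  return Array.from({ length: numFrames }, (_, i) => Math.round(i * step));
}

// At Wan's 16 fps: a 2-second clip (32 frames) sampled to 81 indices repeats
// frames (slow motion); a 90-second clip (1440 frames) steps about 18 frames
// at a time (fast motion).
evenlySpacedFrameIndices(32, 81);   // step ≈ 0.39, many duplicate indices
evenlySpacedFrameIndices(1440, 81); // step ≈ 18, large skips
```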

View File

@@ -84,6 +84,8 @@ export interface DatasetConfig {
resolution: number[];
controls: string[];
control_path: string | null;
num_frames: number;
shrink_video_to_frames: boolean;
}
export interface EMAConfig {
@@ -181,3 +183,12 @@ export interface ConfigDoc {
title: string;
description: React.ReactNode;
}
export interface SelectOption {
readonly value: string;
readonly label: string;
}
export interface GroupedSelectOption {
readonly label: string;
readonly options: SelectOption[];
}
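
For reference, the reducer in options.ts yields values of this shape (names and labels taken from the archs defined above):

```ts
const grouped: GroupedSelectOption[] = [
  { label: 'image', options: [{ value: 'flux', label: 'FLUX.1' }] },
  { label: 'video', options: [{ value: 'wan21_i2v:14b480p', label: 'Wan 2.1 I2V (14B-480P)' }] },
];
```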

View File

@@ -1 +1 @@
VERSION = "0.3.5"
VERSION = "0.3.6"