From 8537a8557fda9eb25c30809bfdbab8c42d00a95a Mon Sep 17 00:00:00 2001 From: Jaret Burkett Date: Fri, 11 Jul 2025 11:28:40 -0600 Subject: [PATCH] Add simple ui settings to train Wan i2v models. --- README.md | 22 +++++---- docker/Dockerfile | 4 +- run.py | 1 + ui/src/app/jobs/new/SimpleJob.tsx | 36 +++++++++------ ui/src/app/jobs/new/jobConfig.ts | 4 +- ui/src/app/jobs/new/options.ts | 76 +++++++++++++++++++++++++++++-- ui/src/components/formInputs.tsx | 13 +++++- ui/src/docs.tsx | 17 +++++++ ui/src/types.ts | 11 +++++ version.py | 2 +- 10 files changed, 155 insertions(+), 31 deletions(-) diff --git a/README.md b/README.md index 68020639..db05500f 100644 --- a/README.md +++ b/README.md @@ -120,7 +120,7 @@ cd ai-toolkit python3 -m venv venv source venv/bin/activate # install torch first -pip3 install --no-cache-dir torch==2.7.0 torchvision==0.22.0 --index-url https://download.pytorch.org/whl/cu126 +pip3 install --no-cache-dir torch==2.7.0 torchvision==0.22.0 torchaudio==2.7.0 --index-url https://download.pytorch.org/whl/cu126 pip3 install -r requirements.txt ``` @@ -130,7 +130,7 @@ git clone https://github.com/ostris/ai-toolkit.git cd ai-toolkit python -m venv venv .\venv\Scripts\activate -pip install --no-cache-dir torch==2.7.0 torchvision==0.22.0 --index-url https://download.pytorch.org/whl/cu126 +pip install --no-cache-dir torch==2.7.0 torchvision==0.22.0 torchaudio==2.7.0 --index-url https://download.pytorch.org/whl/cu126 pip install -r requirements.txt ``` @@ -425,28 +425,32 @@ Everything else should work the same including layer targeting. Only larger updates are listed here. There are usually smaller daily updated that are omitted. -### June 29, 2024 +### Jul 11, 2025 +- Added better video config settings to the UI for video models. 
+- Added Wan I2V training to the UI + +### June 29, 2025 - Fixed issue where Kontext forced sizes on sampling -### June 26, 2024 +### June 26, 2025 - Added support for FLUX.1 Kontext training - added support for instruction dataset training -### June 25, 2024 +### June 25, 2025 - Added support for OmniGen2 training - -### June 17, 2024 +### June 17, 2025 - Performance optimizations for batch preparation - Added some docs via a popup for items in the simple ui explaining what settings do. Still a WIP -### June 16, 2024 +### June 16, 2025 - Hide control images in the UI when viewing datasets - WIP on mean flow loss -### June 12, 2024 +### June 12, 2025 - Fixed issue that resulted in blank captions in the dataloader -### June 10, 2024 +### June 10, 2025 - Decided to keep track up updates in the readme - Added support for SDXL in the UI - Added support for SD 1.5 in the UI diff --git a/docker/Dockerfile b/docker/Dockerfile index 377b5164..a4f68f44 100644 --- a/docker/Dockerfile +++ b/docker/Dockerfile @@ -49,7 +49,7 @@ WORKDIR /app RUN ln -s /usr/bin/python3 /usr/bin/python # install pytorch before cache bust to avoid redownloading pytorch -RUN pip install --pre --no-cache-dir torch torchvision --index-url https://download.pytorch.org/whl/nightly/cu128 +RUN pip install --pre --no-cache-dir torch torchvision torchaudio --index-url https://download.pytorch.org/whl/nightly/cu128 # Fix cache busting by moving CACHEBUST to right before git clone ARG CACHEBUST=1234 @@ -63,7 +63,7 @@ WORKDIR /app/ai-toolkit # Install Python dependencies RUN pip install --no-cache-dir -r requirements.txt && \ - pip install --pre --no-cache-dir torch torchvision --index-url https://download.pytorch.org/whl/nightly/cu128 --force && \ + pip install --pre --no-cache-dir torch torchvision torchaudio --index-url https://download.pytorch.org/whl/nightly/cu128 --force && \ pip install setuptools==69.5.1 --no-cache-dir # Build UI diff --git a/run.py b/run.py index 4c36046d..1a2c5c98 100644 --- a/run.py 
+++ b/run.py @@ -1,5 +1,6 @@ import os os.environ["HF_HUB_ENABLE_HF_TRANSFER"] = "1" +os.environ["NO_ALBUMENTATIONS_UPDATE"] = "1" import sys from typing import Union, OrderedDict from dotenv import load_dotenv diff --git a/ui/src/app/jobs/new/SimpleJob.tsx b/ui/src/app/jobs/new/SimpleJob.tsx index 82675e3b..157bcc6c 100644 --- a/ui/src/app/jobs/new/SimpleJob.tsx +++ b/ui/src/app/jobs/new/SimpleJob.tsx @@ -1,6 +1,6 @@ 'use client'; import { useMemo } from 'react'; -import { modelArchs, ModelArch } from './options'; +import { modelArchs, ModelArch, groupedModelOptions } from './options'; import { defaultDatasetConfig } from './jobConfig'; import { JobConfig } from '@/types'; import { objectCopy } from '@/utils/basic'; @@ -37,7 +37,7 @@ export default function SimpleJob({ return modelArchs.find(a => a.name === jobConfig.config.process[0].model.arch) as ModelArch; }, [jobConfig.config.process[0].model.arch]); - const isVideoModel = !!modelArch?.isVideoModel; + const isVideoModel = !!(modelArch?.group === 'video'); return ( <> @@ -100,8 +100,9 @@ export default function SimpleJob({ // set new model setJobConfig(value, 'config.process[0].model.arch'); - // update controls for datasets + // update datasets const hasControlPath = newArch?.additionalSections?.includes('datasets.control_path') || false; + const hasNumFrames = newArch?.additionalSections?.includes('datasets.num_frames') || false; const controls = newArch?.controls ?? 
[]; const datasets = jobConfig.config.process[0].datasets.map(dataset => { const newDataset = objectCopy(dataset); @@ -109,20 +110,14 @@ export default function SimpleJob({ if (!hasControlPath) { newDataset.control_path = null; // reset control path if not applicable } + if (!hasNumFrames) { + newDataset.num_frames = 1; // reset num_frames if not applicable + } return newDataset; }); setJobConfig(datasets, 'config.process[0].datasets'); }} - options={ - modelArchs - .map(model => { - return { - value: model.name, - label: model.label, - }; - }) - .filter(x => x) as { value: string; label: string }[] - } + options={groupedModelOptions} /> setJobConfig(value == '' ? null : value, `config.process[0].datasets[${i}].control_path`) } @@ -452,6 +448,18 @@ export default function SimpleJob({ min={0} required /> + {modelArch?.additionalSections?.includes('datasets.num_frames') && ( + setJobConfig(value, `config.process[0].datasets[${i}].num_frames`)} + placeholder="eg. 41" + min={1} + required + /> + )}
@@ -586,6 +594,7 @@ export default function SimpleJob({ value={jobConfig.config.process[0].sample.num_frames} onChange={value => setJobConfig(value, 'config.process[0].sample.num_frames')} placeholder="eg. 0" + className="pt-2" min={0} required /> @@ -594,6 +603,7 @@ export default function SimpleJob({ value={jobConfig.config.process[0].sample.fps} onChange={value => setJobConfig(value, 'config.process[0].sample.fps')} placeholder="eg. 0" + className="pt-2" min={0} required /> diff --git a/ui/src/app/jobs/new/jobConfig.ts b/ui/src/app/jobs/new/jobConfig.ts index 66b7f7bf..94c0de6a 100644 --- a/ui/src/app/jobs/new/jobConfig.ts +++ b/ui/src/app/jobs/new/jobConfig.ts @@ -12,7 +12,9 @@ export const defaultDatasetConfig: DatasetConfig = { is_reg: false, network_weight: 1, resolution: [512, 768, 1024], - controls: [] + controls: [], + shrink_video_to_frames: true, + num_frames: 1, }; export const defaultJobConfig: JobConfig = { diff --git a/ui/src/app/jobs/new/options.ts b/ui/src/app/jobs/new/options.ts index 3d43c630..5c34f15a 100644 --- a/ui/src/app/jobs/new/options.ts +++ b/ui/src/app/jobs/new/options.ts @@ -1,11 +1,15 @@ +import { GroupedSelectOption } from "@/types"; + type Control = 'depth' | 'line' | 'pose' | 'inpaint'; type DisableableSections = 'model.quantize' | 'train.timestep_type' | 'network.conv'; -type AdditionalSections = 'datasets.control_path' | 'sample.ctrl_img' +type AdditionalSections = 'datasets.control_path' | 'sample.ctrl_img' | 'datasets.num_frames'; +type ModelGroup = 'image' | 'video'; export interface ModelArch { name: string; label: string; + group: ModelGroup; controls?: Control[]; isVideoModel?: boolean; defaults?: { [key: string]: any }; @@ -16,10 +20,13 @@ export interface ModelArch { const defaultNameOrPath = ''; + + export const modelArchs: ModelArch[] = [ { name: 'flux', label: 'FLUX.1', + group: 'image', defaults: { // default updates when [selected, unselected] in the UI 'config.process[0].model.name_or_path': 
['black-forest-labs/FLUX.1-dev', defaultNameOrPath], @@ -33,6 +40,7 @@ export const modelArchs: ModelArch[] = [ { name: 'flux_kontext', label: 'FLUX.1-Kontext-dev', + group: 'image', defaults: { // default updates when [selected, unselected] in the UI 'config.process[0].model.name_or_path': ['black-forest-labs/FLUX.1-Kontext-dev', defaultNameOrPath], @@ -48,6 +56,7 @@ export const modelArchs: ModelArch[] = [ { name: 'flex1', label: 'Flex.1', + group: 'image', defaults: { // default updates when [selected, unselected] in the UI 'config.process[0].model.name_or_path': ['ostris/Flex.1-alpha', defaultNameOrPath], @@ -62,6 +71,7 @@ export const modelArchs: ModelArch[] = [ { name: 'flex2', label: 'Flex.2', + group: 'image', controls: ['depth', 'line', 'pose', 'inpaint'], defaults: { // default updates when [selected, unselected] in the UI @@ -89,6 +99,7 @@ export const modelArchs: ModelArch[] = [ { name: 'chroma', label: 'Chroma', + group: 'image', defaults: { // default updates when [selected, unselected] in the UI 'config.process[0].model.name_or_path': ['lodestones/Chroma', defaultNameOrPath], @@ -102,6 +113,7 @@ export const modelArchs: ModelArch[] = [ { name: 'wan21:1b', label: 'Wan 2.1 (1.3B)', + group: 'video', isVideoModel: true, defaults: { // default updates when [selected, unselected] in the UI @@ -110,14 +122,52 @@ export const modelArchs: ModelArch[] = [ 'config.process[0].model.quantize_te': [true, false], 'config.process[0].sample.sampler': ['flowmatch', 'flowmatch'], 'config.process[0].train.noise_scheduler': ['flowmatch', 'flowmatch'], - 'config.process[0].sample.num_frames': [40, 1], + 'config.process[0].sample.num_frames': [41, 1], 'config.process[0].sample.fps': [15, 1], }, disableSections: ['network.conv'], + additionalSections: ['datasets.num_frames'], + }, + { + name: 'wan21_i2v:14b480p', + label: 'Wan 2.1 I2V (14B-480P)', + group: 'video', + isVideoModel: true, + defaults: { + // default updates when [selected, unselected] in the UI + 
'config.process[0].model.name_or_path': ['Wan-AI/Wan2.1-I2V-14B-480P-Diffusers', defaultNameOrPath], + 'config.process[0].model.quantize': [true, false], + 'config.process[0].model.quantize_te': [true, false], + 'config.process[0].sample.sampler': ['flowmatch', 'flowmatch'], + 'config.process[0].train.noise_scheduler': ['flowmatch', 'flowmatch'], + 'config.process[0].sample.num_frames': [41, 1], + 'config.process[0].sample.fps': [15, 1], + }, + disableSections: ['network.conv'], + additionalSections: ['sample.ctrl_img', 'datasets.num_frames'], + }, + { + name: 'wan21_i2v:14b', + label: 'Wan 2.1 I2V (14B-720P)', + group: 'video', + isVideoModel: true, + defaults: { + // default updates when [selected, unselected] in the UI + 'config.process[0].model.name_or_path': ['Wan-AI/Wan2.1-I2V-14B-720P-Diffusers', defaultNameOrPath], + 'config.process[0].model.quantize': [true, false], + 'config.process[0].model.quantize_te': [true, false], + 'config.process[0].sample.sampler': ['flowmatch', 'flowmatch'], + 'config.process[0].train.noise_scheduler': ['flowmatch', 'flowmatch'], + 'config.process[0].sample.num_frames': [41, 1], + 'config.process[0].sample.fps': [15, 1], + }, + disableSections: ['network.conv'], + additionalSections: ['sample.ctrl_img', 'datasets.num_frames'], }, { name: 'wan21:14b', label: 'Wan 2.1 (14B)', + group: 'video', isVideoModel: true, defaults: { // default updates when [selected, unselected] in the UI @@ -126,14 +176,16 @@ export const modelArchs: ModelArch[] = [ 'config.process[0].model.quantize_te': [true, false], 'config.process[0].sample.sampler': ['flowmatch', 'flowmatch'], 'config.process[0].train.noise_scheduler': ['flowmatch', 'flowmatch'], - 'config.process[0].sample.num_frames': [40, 1], + 'config.process[0].sample.num_frames': [41, 1], 'config.process[0].sample.fps': [15, 1], }, disableSections: ['network.conv'], + additionalSections: ['datasets.num_frames'], }, { name: 'lumina2', label: 'Lumina2', + group: 'image', defaults: { // default 
updates when [selected, unselected] in the UI 'config.process[0].model.name_or_path': ['Alpha-VLLM/Lumina-Image-2.0', defaultNameOrPath], @@ -147,6 +199,7 @@ export const modelArchs: ModelArch[] = [ { name: 'hidream', label: 'HiDream', + group: 'image', defaults: { // default updates when [selected, unselected] in the UI 'config.process[0].model.name_or_path': ['HiDream-ai/HiDream-I1-Full', defaultNameOrPath], @@ -163,6 +216,7 @@ export const modelArchs: ModelArch[] = [ { name: 'sdxl', label: 'SDXL', + group: 'image', defaults: { // default updates when [selected, unselected] in the UI 'config.process[0].model.name_or_path': ['stabilityai/stable-diffusion-xl-base-1.0', defaultNameOrPath], @@ -177,6 +231,7 @@ export const modelArchs: ModelArch[] = [ { name: 'sd15', label: 'SD 1.5', + group: 'image', defaults: { // default updates when [selected, unselected] in the UI 'config.process[0].model.name_or_path': ['stable-diffusion-v1-5/stable-diffusion-v1-5', defaultNameOrPath], @@ -191,6 +246,7 @@ export const modelArchs: ModelArch[] = [ { name: 'omnigen2', label: 'OmniGen2', + group: 'image', defaults: { // default updates when [selected, unselected] in the UI 'config.process[0].model.name_or_path': ['OmniGen2/OmniGen2', defaultNameOrPath], @@ -206,3 +262,17 @@ export const modelArchs: ModelArch[] = [ // Sort by label, case-insensitive return a.label.localeCompare(b.label, undefined, { sensitivity: 'base' }) }) as any; + + +export const groupedModelOptions: GroupedSelectOption[] = modelArchs.reduce((acc, arch) => { + const group = acc.find(g => g.label === arch.group); + if (group) { + group.options.push({ value: arch.name, label: arch.label }); + } else { + acc.push({ + label: arch.group, + options: [{ value: arch.name, label: arch.label }], + }); + } + return acc; +}, [] as GroupedSelectOption[]); diff --git a/ui/src/components/formInputs.tsx b/ui/src/components/formInputs.tsx index a556a266..5613cb4e 100644 --- a/ui/src/components/formInputs.tsx +++ 
b/ui/src/components/formInputs.tsx @@ -6,6 +6,7 @@ import dynamic from 'next/dynamic'; import { CircleHelp } from 'lucide-react'; import { getDoc } from '@/docs'; import { openDoc } from '@/components/DocModal'; +import { GroupedSelectOption, SelectOption } from '@/types'; const Select = dynamic(() => import('react-select'), { ssr: false }); @@ -141,13 +142,21 @@ export interface SelectInputProps extends InputProps { value: string; disabled?: boolean; onChange: (value: string) => void; - options: { value: string; label: string }[]; + options: GroupedSelectOption[] | SelectOption[]; } export const SelectInput = (props: SelectInputProps) => { const { label, value, onChange, options, docKey = null } = props; const doc = getDoc(docKey); - const selectedOption = options.find(option => option.value === value); + let selectedOption: SelectOption | undefined; + if (options && options.length > 0) { + // see if grouped options + if ('options' in options[0]) { + selectedOption = (options as GroupedSelectOption[]).flatMap(group => group.options).find(opt => opt.value === value); + } else { + selectedOption = (options as SelectOption[]).find(opt => opt.value === value); + } + } return (
), }, + 'datasets.num_frames': { + title: 'Number of Frames', + description: ( + <> + This sets the number of frames to shrink videos to for a video dataset. If this dataset is images, set this to 1 for one frame. + If your dataset is only videos, frames will be extracted evenly spaced from the videos in the dataset. +
+
+ It is best to trim your videos to the proper length before training. Wan is 16 frames per second. Doing 81 frames will result in a 5-second video. + So you would want all of your videos trimmed to around 5 seconds for best results.
+
+ Example: Setting this to 81 and having 2 videos in your dataset, one is 2 seconds and one is 90 seconds long, will result in 81 + evenly spaced frames for each video making the 2 second video appear slow and the 90second video appear very fast. + + ), + }, }; export const getDoc = (key: string | null | undefined): ConfigDoc | null => { diff --git a/ui/src/types.ts b/ui/src/types.ts index c4e76da9..02e89b5c 100644 --- a/ui/src/types.ts +++ b/ui/src/types.ts @@ -84,6 +84,8 @@ export interface DatasetConfig { resolution: number[]; controls: string[]; control_path: string | null; + num_frames: number; + shrink_video_to_frames: boolean; } export interface EMAConfig { @@ -181,3 +183,12 @@ export interface ConfigDoc { title: string; description: React.ReactNode; } + +export interface SelectOption { + readonly value: string; + readonly label: string; +} +export interface GroupedSelectOption { + readonly label: string; + readonly options: SelectOption[]; +} diff --git a/version.py b/version.py index 460b359a..090987e5 100644 --- a/version.py +++ b/version.py @@ -1 +1 @@ -VERSION = "0.3.5" \ No newline at end of file +VERSION = "0.3.6" \ No newline at end of file