mirror of
https://github.com/ostris/ai-toolkit.git
synced 2026-01-26 08:29:45 +00:00
Add simple ui settings to train Wan i2v models.
This commit is contained in:
22
README.md
22
README.md
@@ -120,7 +120,7 @@ cd ai-toolkit
|
||||
python3 -m venv venv
|
||||
source venv/bin/activate
|
||||
# install torch first
|
||||
pip3 install --no-cache-dir torch==2.7.0 torchvision==0.22.0 --index-url https://download.pytorch.org/whl/cu126
|
||||
pip3 install --no-cache-dir torch==2.7.0 torchvision==0.22.0 torchaudio==2.7.0 --index-url https://download.pytorch.org/whl/cu126
|
||||
pip3 install -r requirements.txt
|
||||
```
|
||||
|
||||
@@ -130,7 +130,7 @@ git clone https://github.com/ostris/ai-toolkit.git
|
||||
cd ai-toolkit
|
||||
python -m venv venv
|
||||
.\venv\Scripts\activate
|
||||
pip install --no-cache-dir torch==2.7.0 torchvision==0.22.0 --index-url https://download.pytorch.org/whl/cu126
|
||||
pip install --no-cache-dir torch==2.7.0 torchvision==0.22.0 torchaudio==2.7.0 --index-url https://download.pytorch.org/whl/cu126
|
||||
pip install -r requirements.txt
|
||||
```
|
||||
|
||||
@@ -425,28 +425,32 @@ Everything else should work the same including layer targeting.
|
||||
|
||||
Only larger updates are listed here. There are usually smaller daily updates that are omitted.
|
||||
|
||||
### June 29, 2024
|
||||
### Jul 11, 2025
|
||||
- Added better video config settings to the UI for video models.
|
||||
- Added Wan I2V training to the UI
|
||||
|
||||
### June 29, 2025
|
||||
- Fixed issue where Kontext forced sizes on sampling
|
||||
|
||||
### June 26, 2024
|
||||
### June 26, 2025
|
||||
- Added support for FLUX.1 Kontext training
|
||||
- added support for instruction dataset training
|
||||
|
||||
### June 25, 2024
|
||||
### June 25, 2025
|
||||
- Added support for OmniGen2 training
|
||||
-
|
||||
### June 17, 2024
|
||||
### June 17, 2025
|
||||
- Performance optimizations for batch preparation
|
||||
- Added some docs via a popup for items in the simple ui explaining what settings do. Still a WIP
|
||||
|
||||
### June 16, 2024
|
||||
### June 16, 2025
|
||||
- Hide control images in the UI when viewing datasets
|
||||
- WIP on mean flow loss
|
||||
|
||||
### June 12, 2024
|
||||
### June 12, 2025
|
||||
- Fixed issue that resulted in blank captions in the dataloader
|
||||
|
||||
### June 10, 2024
|
||||
### June 10, 2025
|
||||
- Decided to keep track of updates in the readme
|
||||
- Added support for SDXL in the UI
|
||||
- Added support for SD 1.5 in the UI
|
||||
|
||||
@@ -49,7 +49,7 @@ WORKDIR /app
|
||||
RUN ln -s /usr/bin/python3 /usr/bin/python
|
||||
|
||||
# install pytorch before cache bust to avoid redownloading pytorch
|
||||
RUN pip install --pre --no-cache-dir torch torchvision --index-url https://download.pytorch.org/whl/nightly/cu128
|
||||
RUN pip install --pre --no-cache-dir torch torchvision torchaudio --index-url https://download.pytorch.org/whl/nightly/cu128
|
||||
|
||||
# Fix cache busting by moving CACHEBUST to right before git clone
|
||||
ARG CACHEBUST=1234
|
||||
@@ -63,7 +63,7 @@ WORKDIR /app/ai-toolkit
|
||||
|
||||
# Install Python dependencies
|
||||
RUN pip install --no-cache-dir -r requirements.txt && \
|
||||
pip install --pre --no-cache-dir torch torchvision --index-url https://download.pytorch.org/whl/nightly/cu128 --force && \
|
||||
pip install --pre --no-cache-dir torch torchvision torchaudio --index-url https://download.pytorch.org/whl/nightly/cu128 --force && \
|
||||
pip install setuptools==69.5.1 --no-cache-dir
|
||||
|
||||
# Build UI
|
||||
|
||||
1
run.py
1
run.py
@@ -1,5 +1,6 @@
|
||||
import os
|
||||
os.environ["HF_HUB_ENABLE_HF_TRANSFER"] = "1"
|
||||
os.environ["NO_ALBUMENTATIONS_UPDATE"] = "1"
|
||||
import sys
|
||||
from typing import Union, OrderedDict
|
||||
from dotenv import load_dotenv
|
||||
|
||||
@@ -1,6 +1,6 @@
|
||||
'use client';
|
||||
import { useMemo } from 'react';
|
||||
import { modelArchs, ModelArch } from './options';
|
||||
import { modelArchs, ModelArch, groupedModelOptions } from './options';
|
||||
import { defaultDatasetConfig } from './jobConfig';
|
||||
import { JobConfig } from '@/types';
|
||||
import { objectCopy } from '@/utils/basic';
|
||||
@@ -37,7 +37,7 @@ export default function SimpleJob({
|
||||
return modelArchs.find(a => a.name === jobConfig.config.process[0].model.arch) as ModelArch;
|
||||
}, [jobConfig.config.process[0].model.arch]);
|
||||
|
||||
const isVideoModel = !!modelArch?.isVideoModel;
|
||||
const isVideoModel = !!(modelArch?.group === 'video');
|
||||
|
||||
return (
|
||||
<>
|
||||
@@ -100,8 +100,9 @@ export default function SimpleJob({
|
||||
// set new model
|
||||
setJobConfig(value, 'config.process[0].model.arch');
|
||||
|
||||
// update controls for datasets
|
||||
// update datasets
|
||||
const hasControlPath = newArch?.additionalSections?.includes('datasets.control_path') || false;
|
||||
const hasNumFrames = newArch?.additionalSections?.includes('datasets.num_frames') || false;
|
||||
const controls = newArch?.controls ?? [];
|
||||
const datasets = jobConfig.config.process[0].datasets.map(dataset => {
|
||||
const newDataset = objectCopy(dataset);
|
||||
@@ -109,20 +110,14 @@ export default function SimpleJob({
|
||||
if (!hasControlPath) {
|
||||
newDataset.control_path = null; // reset control path if not applicable
|
||||
}
|
||||
if (!hasNumFrames) {
|
||||
newDataset.num_frames = 1; // reset num_frames if not applicable
|
||||
}
|
||||
return newDataset;
|
||||
});
|
||||
setJobConfig(datasets, 'config.process[0].datasets');
|
||||
}}
|
||||
options={
|
||||
modelArchs
|
||||
.map(model => {
|
||||
return {
|
||||
value: model.name,
|
||||
label: model.label,
|
||||
};
|
||||
})
|
||||
.filter(x => x) as { value: string; label: string }[]
|
||||
}
|
||||
options={groupedModelOptions}
|
||||
/>
|
||||
<TextInput
|
||||
label="Name or Path"
|
||||
@@ -422,6 +417,7 @@ export default function SimpleJob({
|
||||
label="Control Dataset"
|
||||
docKey="datasets.control_path"
|
||||
value={dataset.control_path ?? ''}
|
||||
className="pt-2"
|
||||
onChange={value =>
|
||||
setJobConfig(value == '' ? null : value, `config.process[0].datasets[${i}].control_path`)
|
||||
}
|
||||
@@ -452,6 +448,18 @@ export default function SimpleJob({
|
||||
min={0}
|
||||
required
|
||||
/>
|
||||
{modelArch?.additionalSections?.includes('datasets.num_frames') && (
|
||||
<NumberInput
|
||||
label="Num Frames"
|
||||
className="pt-2"
|
||||
docKey="datasets.num_frames"
|
||||
value={dataset.num_frames}
|
||||
onChange={value => setJobConfig(value, `config.process[0].datasets[${i}].num_frames`)}
|
||||
placeholder="eg. 41"
|
||||
min={1}
|
||||
required
|
||||
/>
|
||||
)}
|
||||
</div>
|
||||
<div>
|
||||
<FormGroup label="Settings" className="">
|
||||
@@ -586,6 +594,7 @@ export default function SimpleJob({
|
||||
value={jobConfig.config.process[0].sample.num_frames}
|
||||
onChange={value => setJobConfig(value, 'config.process[0].sample.num_frames')}
|
||||
placeholder="eg. 0"
|
||||
className="pt-2"
|
||||
min={0}
|
||||
required
|
||||
/>
|
||||
@@ -594,6 +603,7 @@ export default function SimpleJob({
|
||||
value={jobConfig.config.process[0].sample.fps}
|
||||
onChange={value => setJobConfig(value, 'config.process[0].sample.fps')}
|
||||
placeholder="eg. 0"
|
||||
className="pt-2"
|
||||
min={0}
|
||||
required
|
||||
/>
|
||||
|
||||
@@ -12,7 +12,9 @@ export const defaultDatasetConfig: DatasetConfig = {
|
||||
is_reg: false,
|
||||
network_weight: 1,
|
||||
resolution: [512, 768, 1024],
|
||||
controls: []
|
||||
controls: [],
|
||||
shrink_video_to_frames: true,
|
||||
num_frames: 1,
|
||||
};
|
||||
|
||||
export const defaultJobConfig: JobConfig = {
|
||||
|
||||
@@ -1,11 +1,15 @@
|
||||
import { GroupedSelectOption } from "@/types";
|
||||
|
||||
type Control = 'depth' | 'line' | 'pose' | 'inpaint';
|
||||
|
||||
type DisableableSections = 'model.quantize' | 'train.timestep_type' | 'network.conv';
|
||||
type AdditionalSections = 'datasets.control_path' | 'sample.ctrl_img'
|
||||
type AdditionalSections = 'datasets.control_path' | 'sample.ctrl_img' | 'datasets.num_frames';
|
||||
type ModelGroup = 'image' | 'video';
|
||||
|
||||
export interface ModelArch {
|
||||
name: string;
|
||||
label: string;
|
||||
group: ModelGroup;
|
||||
controls?: Control[];
|
||||
isVideoModel?: boolean;
|
||||
defaults?: { [key: string]: any };
|
||||
@@ -16,10 +20,13 @@ export interface ModelArch {
|
||||
const defaultNameOrPath = '';
|
||||
|
||||
|
||||
|
||||
|
||||
export const modelArchs: ModelArch[] = [
|
||||
{
|
||||
name: 'flux',
|
||||
label: 'FLUX.1',
|
||||
group: 'image',
|
||||
defaults: {
|
||||
// default updates when [selected, unselected] in the UI
|
||||
'config.process[0].model.name_or_path': ['black-forest-labs/FLUX.1-dev', defaultNameOrPath],
|
||||
@@ -33,6 +40,7 @@ export const modelArchs: ModelArch[] = [
|
||||
{
|
||||
name: 'flux_kontext',
|
||||
label: 'FLUX.1-Kontext-dev',
|
||||
group: 'image',
|
||||
defaults: {
|
||||
// default updates when [selected, unselected] in the UI
|
||||
'config.process[0].model.name_or_path': ['black-forest-labs/FLUX.1-Kontext-dev', defaultNameOrPath],
|
||||
@@ -48,6 +56,7 @@ export const modelArchs: ModelArch[] = [
|
||||
{
|
||||
name: 'flex1',
|
||||
label: 'Flex.1',
|
||||
group: 'image',
|
||||
defaults: {
|
||||
// default updates when [selected, unselected] in the UI
|
||||
'config.process[0].model.name_or_path': ['ostris/Flex.1-alpha', defaultNameOrPath],
|
||||
@@ -62,6 +71,7 @@ export const modelArchs: ModelArch[] = [
|
||||
{
|
||||
name: 'flex2',
|
||||
label: 'Flex.2',
|
||||
group: 'image',
|
||||
controls: ['depth', 'line', 'pose', 'inpaint'],
|
||||
defaults: {
|
||||
// default updates when [selected, unselected] in the UI
|
||||
@@ -89,6 +99,7 @@ export const modelArchs: ModelArch[] = [
|
||||
{
|
||||
name: 'chroma',
|
||||
label: 'Chroma',
|
||||
group: 'image',
|
||||
defaults: {
|
||||
// default updates when [selected, unselected] in the UI
|
||||
'config.process[0].model.name_or_path': ['lodestones/Chroma', defaultNameOrPath],
|
||||
@@ -102,6 +113,7 @@ export const modelArchs: ModelArch[] = [
|
||||
{
|
||||
name: 'wan21:1b',
|
||||
label: 'Wan 2.1 (1.3B)',
|
||||
group: 'video',
|
||||
isVideoModel: true,
|
||||
defaults: {
|
||||
// default updates when [selected, unselected] in the UI
|
||||
@@ -110,14 +122,52 @@ export const modelArchs: ModelArch[] = [
|
||||
'config.process[0].model.quantize_te': [true, false],
|
||||
'config.process[0].sample.sampler': ['flowmatch', 'flowmatch'],
|
||||
'config.process[0].train.noise_scheduler': ['flowmatch', 'flowmatch'],
|
||||
'config.process[0].sample.num_frames': [40, 1],
|
||||
'config.process[0].sample.num_frames': [41, 1],
|
||||
'config.process[0].sample.fps': [15, 1],
|
||||
},
|
||||
disableSections: ['network.conv'],
|
||||
additionalSections: ['datasets.num_frames'],
|
||||
},
|
||||
{
|
||||
name: 'wan21_i2v:14b480p',
|
||||
label: 'Wan 2.1 I2V (14B-480P)',
|
||||
group: 'video',
|
||||
isVideoModel: true,
|
||||
defaults: {
|
||||
// default updates when [selected, unselected] in the UI
|
||||
'config.process[0].model.name_or_path': ['Wan-AI/Wan2.1-I2V-14B-480P-Diffusers', defaultNameOrPath],
|
||||
'config.process[0].model.quantize': [true, false],
|
||||
'config.process[0].model.quantize_te': [true, false],
|
||||
'config.process[0].sample.sampler': ['flowmatch', 'flowmatch'],
|
||||
'config.process[0].train.noise_scheduler': ['flowmatch', 'flowmatch'],
|
||||
'config.process[0].sample.num_frames': [41, 1],
|
||||
'config.process[0].sample.fps': [15, 1],
|
||||
},
|
||||
disableSections: ['network.conv'],
|
||||
additionalSections: ['sample.ctrl_img', 'datasets.num_frames'],
|
||||
},
|
||||
{
|
||||
name: 'wan21_i2v:14b',
|
||||
label: 'Wan 2.1 I2V (14B-720P)',
|
||||
group: 'video',
|
||||
isVideoModel: true,
|
||||
defaults: {
|
||||
// default updates when [selected, unselected] in the UI
|
||||
'config.process[0].model.name_or_path': ['Wan-AI/Wan2.1-I2V-14B-720P-Diffusers', defaultNameOrPath],
|
||||
'config.process[0].model.quantize': [true, false],
|
||||
'config.process[0].model.quantize_te': [true, false],
|
||||
'config.process[0].sample.sampler': ['flowmatch', 'flowmatch'],
|
||||
'config.process[0].train.noise_scheduler': ['flowmatch', 'flowmatch'],
|
||||
'config.process[0].sample.num_frames': [41, 1],
|
||||
'config.process[0].sample.fps': [15, 1],
|
||||
},
|
||||
disableSections: ['network.conv'],
|
||||
additionalSections: ['sample.ctrl_img', 'datasets.num_frames'],
|
||||
},
|
||||
{
|
||||
name: 'wan21:14b',
|
||||
label: 'Wan 2.1 (14B)',
|
||||
group: 'video',
|
||||
isVideoModel: true,
|
||||
defaults: {
|
||||
// default updates when [selected, unselected] in the UI
|
||||
@@ -126,14 +176,16 @@ export const modelArchs: ModelArch[] = [
|
||||
'config.process[0].model.quantize_te': [true, false],
|
||||
'config.process[0].sample.sampler': ['flowmatch', 'flowmatch'],
|
||||
'config.process[0].train.noise_scheduler': ['flowmatch', 'flowmatch'],
|
||||
'config.process[0].sample.num_frames': [40, 1],
|
||||
'config.process[0].sample.num_frames': [41, 1],
|
||||
'config.process[0].sample.fps': [15, 1],
|
||||
},
|
||||
disableSections: ['network.conv'],
|
||||
additionalSections: ['datasets.num_frames'],
|
||||
},
|
||||
{
|
||||
name: 'lumina2',
|
||||
label: 'Lumina2',
|
||||
group: 'image',
|
||||
defaults: {
|
||||
// default updates when [selected, unselected] in the UI
|
||||
'config.process[0].model.name_or_path': ['Alpha-VLLM/Lumina-Image-2.0', defaultNameOrPath],
|
||||
@@ -147,6 +199,7 @@ export const modelArchs: ModelArch[] = [
|
||||
{
|
||||
name: 'hidream',
|
||||
label: 'HiDream',
|
||||
group: 'image',
|
||||
defaults: {
|
||||
// default updates when [selected, unselected] in the UI
|
||||
'config.process[0].model.name_or_path': ['HiDream-ai/HiDream-I1-Full', defaultNameOrPath],
|
||||
@@ -163,6 +216,7 @@ export const modelArchs: ModelArch[] = [
|
||||
{
|
||||
name: 'sdxl',
|
||||
label: 'SDXL',
|
||||
group: 'image',
|
||||
defaults: {
|
||||
// default updates when [selected, unselected] in the UI
|
||||
'config.process[0].model.name_or_path': ['stabilityai/stable-diffusion-xl-base-1.0', defaultNameOrPath],
|
||||
@@ -177,6 +231,7 @@ export const modelArchs: ModelArch[] = [
|
||||
{
|
||||
name: 'sd15',
|
||||
label: 'SD 1.5',
|
||||
group: 'image',
|
||||
defaults: {
|
||||
// default updates when [selected, unselected] in the UI
|
||||
'config.process[0].model.name_or_path': ['stable-diffusion-v1-5/stable-diffusion-v1-5', defaultNameOrPath],
|
||||
@@ -191,6 +246,7 @@ export const modelArchs: ModelArch[] = [
|
||||
{
|
||||
name: 'omnigen2',
|
||||
label: 'OmniGen2',
|
||||
group: 'image',
|
||||
defaults: {
|
||||
// default updates when [selected, unselected] in the UI
|
||||
'config.process[0].model.name_or_path': ['OmniGen2/OmniGen2', defaultNameOrPath],
|
||||
@@ -206,3 +262,17 @@ export const modelArchs: ModelArch[] = [
|
||||
// Sort by label, case-insensitive
|
||||
return a.label.localeCompare(b.label, undefined, { sensitivity: 'base' })
|
||||
}) as any;
|
||||
|
||||
|
||||
export const groupedModelOptions: GroupedSelectOption[] = modelArchs.reduce((acc, arch) => {
|
||||
const group = acc.find(g => g.label === arch.group);
|
||||
if (group) {
|
||||
group.options.push({ value: arch.name, label: arch.label });
|
||||
} else {
|
||||
acc.push({
|
||||
label: arch.group,
|
||||
options: [{ value: arch.name, label: arch.label }],
|
||||
});
|
||||
}
|
||||
return acc;
|
||||
}, [] as GroupedSelectOption[]);
|
||||
|
||||
@@ -6,6 +6,7 @@ import dynamic from 'next/dynamic';
|
||||
import { CircleHelp } from 'lucide-react';
|
||||
import { getDoc } from '@/docs';
|
||||
import { openDoc } from '@/components/DocModal';
|
||||
import { GroupedSelectOption, SelectOption } from '@/types';
|
||||
|
||||
const Select = dynamic(() => import('react-select'), { ssr: false });
|
||||
|
||||
@@ -141,13 +142,21 @@ export interface SelectInputProps extends InputProps {
|
||||
value: string;
|
||||
disabled?: boolean;
|
||||
onChange: (value: string) => void;
|
||||
options: { value: string; label: string }[];
|
||||
options: GroupedSelectOption[] | SelectOption[];
|
||||
}
|
||||
|
||||
export const SelectInput = (props: SelectInputProps) => {
|
||||
const { label, value, onChange, options, docKey = null } = props;
|
||||
const doc = getDoc(docKey);
|
||||
const selectedOption = options.find(option => option.value === value);
|
||||
let selectedOption: SelectOption | undefined;
|
||||
if (options && options.length > 0) {
|
||||
// see if grouped options
|
||||
if ('options' in options[0]) {
|
||||
selectedOption = (options as GroupedSelectOption[]).flatMap(group => group.options).find(opt => opt.value === value);
|
||||
} else {
|
||||
selectedOption = (options as SelectOption[]).find(opt => opt.value === value);
|
||||
}
|
||||
}
|
||||
return (
|
||||
<div
|
||||
className={classNames(props.className, {
|
||||
|
||||
@@ -57,6 +57,23 @@ const docs: { [key: string]: ConfigDoc } = {
|
||||
</>
|
||||
),
|
||||
},
|
||||
'datasets.num_frames': {
|
||||
title: 'Number of Frames',
|
||||
description: (
|
||||
<>
|
||||
This sets the number of frames to shrink videos to for a video dataset. If this dataset is images, set this to 1 for one frame.
|
||||
If your dataset is only videos, frames will be extracted evenly spaced from the videos in the dataset.
|
||||
<br/>
|
||||
<br/>
|
||||
It is best to trim your videos to the proper length before training. Wan is 16 frames a second. Doing 81 frames will result in a 5 second video.
|
||||
So you would want all of your videos trimmed to around 5 seconds for best results.
|
||||
<br/>
|
||||
<br/>
|
||||
Example: Setting this to 81 and having 2 videos in your dataset, one is 2 seconds and one is 90 seconds long, will result in 81
|
||||
evenly spaced frames for each video, making the 2-second video appear slow and the 90-second video appear very fast.
|
||||
</>
|
||||
),
|
||||
},
|
||||
};
|
||||
|
||||
export const getDoc = (key: string | null | undefined): ConfigDoc | null => {
|
||||
|
||||
@@ -84,6 +84,8 @@ export interface DatasetConfig {
|
||||
resolution: number[];
|
||||
controls: string[];
|
||||
control_path: string | null;
|
||||
num_frames: number;
|
||||
shrink_video_to_frames: boolean;
|
||||
}
|
||||
|
||||
export interface EMAConfig {
|
||||
@@ -181,3 +183,12 @@ export interface ConfigDoc {
|
||||
title: string;
|
||||
description: React.ReactNode;
|
||||
}
|
||||
|
||||
export interface SelectOption {
|
||||
readonly value: string;
|
||||
readonly label: string;
|
||||
}
|
||||
export interface GroupedSelectOption {
|
||||
readonly label: string;
|
||||
readonly options: SelectOption[];
|
||||
}
|
||||
|
||||
@@ -1 +1 @@
|
||||
VERSION = "0.3.5"
|
||||
VERSION = "0.3.6"
|
||||
Reference in New Issue
Block a user