mirror of
https://github.com/ostris/ai-toolkit.git
synced 2026-01-26 16:39:47 +00:00
Add simple ui settings to train Wan i2v models.
This commit is contained in:
22
README.md
22
README.md
@@ -120,7 +120,7 @@ cd ai-toolkit
|
|||||||
python3 -m venv venv
|
python3 -m venv venv
|
||||||
source venv/bin/activate
|
source venv/bin/activate
|
||||||
# install torch first
|
# install torch first
|
||||||
pip3 install --no-cache-dir torch==2.7.0 torchvision==0.22.0 --index-url https://download.pytorch.org/whl/cu126
|
pip3 install --no-cache-dir torch==2.7.0 torchvision==0.22.0 torchaudio==2.7.0 --index-url https://download.pytorch.org/whl/cu126
|
||||||
pip3 install -r requirements.txt
|
pip3 install -r requirements.txt
|
||||||
```
|
```
|
||||||
|
|
||||||
@@ -130,7 +130,7 @@ git clone https://github.com/ostris/ai-toolkit.git
|
|||||||
cd ai-toolkit
|
cd ai-toolkit
|
||||||
python -m venv venv
|
python -m venv venv
|
||||||
.\venv\Scripts\activate
|
.\venv\Scripts\activate
|
||||||
pip install --no-cache-dir torch==2.7.0 torchvision==0.22.0 --index-url https://download.pytorch.org/whl/cu126
|
pip install --no-cache-dir torch==2.7.0 torchvision==0.22.0 torchaudio==2.7.0 --index-url https://download.pytorch.org/whl/cu126
|
||||||
pip install -r requirements.txt
|
pip install -r requirements.txt
|
||||||
```
|
```
|
||||||
|
|
||||||
@@ -425,28 +425,32 @@ Everything else should work the same including layer targeting.
|
|||||||
|
|
||||||
Only larger updates are listed here. There are usually smaller daily updated that are omitted.
|
Only larger updates are listed here. There are usually smaller daily updated that are omitted.
|
||||||
|
|
||||||
### June 29, 2024
|
### Jul 11, 2025
|
||||||
|
- Added better video config settings to the UI for video models.
|
||||||
|
- Added Wan I2V training to the UI
|
||||||
|
|
||||||
|
### June 29, 2025
|
||||||
- Fixed issue where Kontext forced sizes on sampling
|
- Fixed issue where Kontext forced sizes on sampling
|
||||||
|
|
||||||
### June 26, 2024
|
### June 26, 2025
|
||||||
- Added support for FLUX.1 Kontext training
|
- Added support for FLUX.1 Kontext training
|
||||||
- added support for instruction dataset training
|
- added support for instruction dataset training
|
||||||
|
|
||||||
### June 25, 2024
|
### June 25, 2025
|
||||||
- Added support for OmniGen2 training
|
- Added support for OmniGen2 training
|
||||||
-
|
-
|
||||||
### June 17, 2024
|
### June 17, 2025
|
||||||
- Performance optimizations for batch preparation
|
- Performance optimizations for batch preparation
|
||||||
- Added some docs via a popup for items in the simple ui explaining what settings do. Still a WIP
|
- Added some docs via a popup for items in the simple ui explaining what settings do. Still a WIP
|
||||||
|
|
||||||
### June 16, 2024
|
### June 16, 2025
|
||||||
- Hide control images in the UI when viewing datasets
|
- Hide control images in the UI when viewing datasets
|
||||||
- WIP on mean flow loss
|
- WIP on mean flow loss
|
||||||
|
|
||||||
### June 12, 2024
|
### June 12, 2025
|
||||||
- Fixed issue that resulted in blank captions in the dataloader
|
- Fixed issue that resulted in blank captions in the dataloader
|
||||||
|
|
||||||
### June 10, 2024
|
### June 10, 2025
|
||||||
- Decided to keep track up updates in the readme
|
- Decided to keep track up updates in the readme
|
||||||
- Added support for SDXL in the UI
|
- Added support for SDXL in the UI
|
||||||
- Added support for SD 1.5 in the UI
|
- Added support for SD 1.5 in the UI
|
||||||
|
|||||||
@@ -49,7 +49,7 @@ WORKDIR /app
|
|||||||
RUN ln -s /usr/bin/python3 /usr/bin/python
|
RUN ln -s /usr/bin/python3 /usr/bin/python
|
||||||
|
|
||||||
# install pytorch before cache bust to avoid redownloading pytorch
|
# install pytorch before cache bust to avoid redownloading pytorch
|
||||||
RUN pip install --pre --no-cache-dir torch torchvision --index-url https://download.pytorch.org/whl/nightly/cu128
|
RUN pip install --pre --no-cache-dir torch torchvision torchaudio --index-url https://download.pytorch.org/whl/nightly/cu128
|
||||||
|
|
||||||
# Fix cache busting by moving CACHEBUST to right before git clone
|
# Fix cache busting by moving CACHEBUST to right before git clone
|
||||||
ARG CACHEBUST=1234
|
ARG CACHEBUST=1234
|
||||||
@@ -63,7 +63,7 @@ WORKDIR /app/ai-toolkit
|
|||||||
|
|
||||||
# Install Python dependencies
|
# Install Python dependencies
|
||||||
RUN pip install --no-cache-dir -r requirements.txt && \
|
RUN pip install --no-cache-dir -r requirements.txt && \
|
||||||
pip install --pre --no-cache-dir torch torchvision --index-url https://download.pytorch.org/whl/nightly/cu128 --force && \
|
pip install --pre --no-cache-dir torch torchvision torchaudio --index-url https://download.pytorch.org/whl/nightly/cu128 --force && \
|
||||||
pip install setuptools==69.5.1 --no-cache-dir
|
pip install setuptools==69.5.1 --no-cache-dir
|
||||||
|
|
||||||
# Build UI
|
# Build UI
|
||||||
|
|||||||
1
run.py
1
run.py
@@ -1,5 +1,6 @@
|
|||||||
import os
|
import os
|
||||||
os.environ["HF_HUB_ENABLE_HF_TRANSFER"] = "1"
|
os.environ["HF_HUB_ENABLE_HF_TRANSFER"] = "1"
|
||||||
|
os.environ["NO_ALBUMENTATIONS_UPDATE"] = "1"
|
||||||
import sys
|
import sys
|
||||||
from typing import Union, OrderedDict
|
from typing import Union, OrderedDict
|
||||||
from dotenv import load_dotenv
|
from dotenv import load_dotenv
|
||||||
|
|||||||
@@ -1,6 +1,6 @@
|
|||||||
'use client';
|
'use client';
|
||||||
import { useMemo } from 'react';
|
import { useMemo } from 'react';
|
||||||
import { modelArchs, ModelArch } from './options';
|
import { modelArchs, ModelArch, groupedModelOptions } from './options';
|
||||||
import { defaultDatasetConfig } from './jobConfig';
|
import { defaultDatasetConfig } from './jobConfig';
|
||||||
import { JobConfig } from '@/types';
|
import { JobConfig } from '@/types';
|
||||||
import { objectCopy } from '@/utils/basic';
|
import { objectCopy } from '@/utils/basic';
|
||||||
@@ -37,7 +37,7 @@ export default function SimpleJob({
|
|||||||
return modelArchs.find(a => a.name === jobConfig.config.process[0].model.arch) as ModelArch;
|
return modelArchs.find(a => a.name === jobConfig.config.process[0].model.arch) as ModelArch;
|
||||||
}, [jobConfig.config.process[0].model.arch]);
|
}, [jobConfig.config.process[0].model.arch]);
|
||||||
|
|
||||||
const isVideoModel = !!modelArch?.isVideoModel;
|
const isVideoModel = !!(modelArch?.group === 'video');
|
||||||
|
|
||||||
return (
|
return (
|
||||||
<>
|
<>
|
||||||
@@ -100,8 +100,9 @@ export default function SimpleJob({
|
|||||||
// set new model
|
// set new model
|
||||||
setJobConfig(value, 'config.process[0].model.arch');
|
setJobConfig(value, 'config.process[0].model.arch');
|
||||||
|
|
||||||
// update controls for datasets
|
// update datasets
|
||||||
const hasControlPath = newArch?.additionalSections?.includes('datasets.control_path') || false;
|
const hasControlPath = newArch?.additionalSections?.includes('datasets.control_path') || false;
|
||||||
|
const hasNumFrames = newArch?.additionalSections?.includes('datasets.num_frames') || false;
|
||||||
const controls = newArch?.controls ?? [];
|
const controls = newArch?.controls ?? [];
|
||||||
const datasets = jobConfig.config.process[0].datasets.map(dataset => {
|
const datasets = jobConfig.config.process[0].datasets.map(dataset => {
|
||||||
const newDataset = objectCopy(dataset);
|
const newDataset = objectCopy(dataset);
|
||||||
@@ -109,20 +110,14 @@ export default function SimpleJob({
|
|||||||
if (!hasControlPath) {
|
if (!hasControlPath) {
|
||||||
newDataset.control_path = null; // reset control path if not applicable
|
newDataset.control_path = null; // reset control path if not applicable
|
||||||
}
|
}
|
||||||
|
if (!hasNumFrames) {
|
||||||
|
newDataset.num_frames = 1; // reset num_frames if not applicable
|
||||||
|
}
|
||||||
return newDataset;
|
return newDataset;
|
||||||
});
|
});
|
||||||
setJobConfig(datasets, 'config.process[0].datasets');
|
setJobConfig(datasets, 'config.process[0].datasets');
|
||||||
}}
|
}}
|
||||||
options={
|
options={groupedModelOptions}
|
||||||
modelArchs
|
|
||||||
.map(model => {
|
|
||||||
return {
|
|
||||||
value: model.name,
|
|
||||||
label: model.label,
|
|
||||||
};
|
|
||||||
})
|
|
||||||
.filter(x => x) as { value: string; label: string }[]
|
|
||||||
}
|
|
||||||
/>
|
/>
|
||||||
<TextInput
|
<TextInput
|
||||||
label="Name or Path"
|
label="Name or Path"
|
||||||
@@ -422,6 +417,7 @@ export default function SimpleJob({
|
|||||||
label="Control Dataset"
|
label="Control Dataset"
|
||||||
docKey="datasets.control_path"
|
docKey="datasets.control_path"
|
||||||
value={dataset.control_path ?? ''}
|
value={dataset.control_path ?? ''}
|
||||||
|
className="pt-2"
|
||||||
onChange={value =>
|
onChange={value =>
|
||||||
setJobConfig(value == '' ? null : value, `config.process[0].datasets[${i}].control_path`)
|
setJobConfig(value == '' ? null : value, `config.process[0].datasets[${i}].control_path`)
|
||||||
}
|
}
|
||||||
@@ -452,6 +448,18 @@ export default function SimpleJob({
|
|||||||
min={0}
|
min={0}
|
||||||
required
|
required
|
||||||
/>
|
/>
|
||||||
|
{modelArch?.additionalSections?.includes('datasets.num_frames') && (
|
||||||
|
<NumberInput
|
||||||
|
label="Num Frames"
|
||||||
|
className="pt-2"
|
||||||
|
docKey="datasets.num_frames"
|
||||||
|
value={dataset.num_frames}
|
||||||
|
onChange={value => setJobConfig(value, `config.process[0].datasets[${i}].num_frames`)}
|
||||||
|
placeholder="eg. 41"
|
||||||
|
min={1}
|
||||||
|
required
|
||||||
|
/>
|
||||||
|
)}
|
||||||
</div>
|
</div>
|
||||||
<div>
|
<div>
|
||||||
<FormGroup label="Settings" className="">
|
<FormGroup label="Settings" className="">
|
||||||
@@ -586,6 +594,7 @@ export default function SimpleJob({
|
|||||||
value={jobConfig.config.process[0].sample.num_frames}
|
value={jobConfig.config.process[0].sample.num_frames}
|
||||||
onChange={value => setJobConfig(value, 'config.process[0].sample.num_frames')}
|
onChange={value => setJobConfig(value, 'config.process[0].sample.num_frames')}
|
||||||
placeholder="eg. 0"
|
placeholder="eg. 0"
|
||||||
|
className="pt-2"
|
||||||
min={0}
|
min={0}
|
||||||
required
|
required
|
||||||
/>
|
/>
|
||||||
@@ -594,6 +603,7 @@ export default function SimpleJob({
|
|||||||
value={jobConfig.config.process[0].sample.fps}
|
value={jobConfig.config.process[0].sample.fps}
|
||||||
onChange={value => setJobConfig(value, 'config.process[0].sample.fps')}
|
onChange={value => setJobConfig(value, 'config.process[0].sample.fps')}
|
||||||
placeholder="eg. 0"
|
placeholder="eg. 0"
|
||||||
|
className="pt-2"
|
||||||
min={0}
|
min={0}
|
||||||
required
|
required
|
||||||
/>
|
/>
|
||||||
|
|||||||
@@ -12,7 +12,9 @@ export const defaultDatasetConfig: DatasetConfig = {
|
|||||||
is_reg: false,
|
is_reg: false,
|
||||||
network_weight: 1,
|
network_weight: 1,
|
||||||
resolution: [512, 768, 1024],
|
resolution: [512, 768, 1024],
|
||||||
controls: []
|
controls: [],
|
||||||
|
shrink_video_to_frames: true,
|
||||||
|
num_frames: 1,
|
||||||
};
|
};
|
||||||
|
|
||||||
export const defaultJobConfig: JobConfig = {
|
export const defaultJobConfig: JobConfig = {
|
||||||
|
|||||||
@@ -1,11 +1,15 @@
|
|||||||
|
import { GroupedSelectOption } from "@/types";
|
||||||
|
|
||||||
type Control = 'depth' | 'line' | 'pose' | 'inpaint';
|
type Control = 'depth' | 'line' | 'pose' | 'inpaint';
|
||||||
|
|
||||||
type DisableableSections = 'model.quantize' | 'train.timestep_type' | 'network.conv';
|
type DisableableSections = 'model.quantize' | 'train.timestep_type' | 'network.conv';
|
||||||
type AdditionalSections = 'datasets.control_path' | 'sample.ctrl_img'
|
type AdditionalSections = 'datasets.control_path' | 'sample.ctrl_img' | 'datasets.num_frames';
|
||||||
|
type ModelGroup = 'image' | 'video';
|
||||||
|
|
||||||
export interface ModelArch {
|
export interface ModelArch {
|
||||||
name: string;
|
name: string;
|
||||||
label: string;
|
label: string;
|
||||||
|
group: ModelGroup;
|
||||||
controls?: Control[];
|
controls?: Control[];
|
||||||
isVideoModel?: boolean;
|
isVideoModel?: boolean;
|
||||||
defaults?: { [key: string]: any };
|
defaults?: { [key: string]: any };
|
||||||
@@ -16,10 +20,13 @@ export interface ModelArch {
|
|||||||
const defaultNameOrPath = '';
|
const defaultNameOrPath = '';
|
||||||
|
|
||||||
|
|
||||||
|
|
||||||
|
|
||||||
export const modelArchs: ModelArch[] = [
|
export const modelArchs: ModelArch[] = [
|
||||||
{
|
{
|
||||||
name: 'flux',
|
name: 'flux',
|
||||||
label: 'FLUX.1',
|
label: 'FLUX.1',
|
||||||
|
group: 'image',
|
||||||
defaults: {
|
defaults: {
|
||||||
// default updates when [selected, unselected] in the UI
|
// default updates when [selected, unselected] in the UI
|
||||||
'config.process[0].model.name_or_path': ['black-forest-labs/FLUX.1-dev', defaultNameOrPath],
|
'config.process[0].model.name_or_path': ['black-forest-labs/FLUX.1-dev', defaultNameOrPath],
|
||||||
@@ -33,6 +40,7 @@ export const modelArchs: ModelArch[] = [
|
|||||||
{
|
{
|
||||||
name: 'flux_kontext',
|
name: 'flux_kontext',
|
||||||
label: 'FLUX.1-Kontext-dev',
|
label: 'FLUX.1-Kontext-dev',
|
||||||
|
group: 'image',
|
||||||
defaults: {
|
defaults: {
|
||||||
// default updates when [selected, unselected] in the UI
|
// default updates when [selected, unselected] in the UI
|
||||||
'config.process[0].model.name_or_path': ['black-forest-labs/FLUX.1-Kontext-dev', defaultNameOrPath],
|
'config.process[0].model.name_or_path': ['black-forest-labs/FLUX.1-Kontext-dev', defaultNameOrPath],
|
||||||
@@ -48,6 +56,7 @@ export const modelArchs: ModelArch[] = [
|
|||||||
{
|
{
|
||||||
name: 'flex1',
|
name: 'flex1',
|
||||||
label: 'Flex.1',
|
label: 'Flex.1',
|
||||||
|
group: 'image',
|
||||||
defaults: {
|
defaults: {
|
||||||
// default updates when [selected, unselected] in the UI
|
// default updates when [selected, unselected] in the UI
|
||||||
'config.process[0].model.name_or_path': ['ostris/Flex.1-alpha', defaultNameOrPath],
|
'config.process[0].model.name_or_path': ['ostris/Flex.1-alpha', defaultNameOrPath],
|
||||||
@@ -62,6 +71,7 @@ export const modelArchs: ModelArch[] = [
|
|||||||
{
|
{
|
||||||
name: 'flex2',
|
name: 'flex2',
|
||||||
label: 'Flex.2',
|
label: 'Flex.2',
|
||||||
|
group: 'image',
|
||||||
controls: ['depth', 'line', 'pose', 'inpaint'],
|
controls: ['depth', 'line', 'pose', 'inpaint'],
|
||||||
defaults: {
|
defaults: {
|
||||||
// default updates when [selected, unselected] in the UI
|
// default updates when [selected, unselected] in the UI
|
||||||
@@ -89,6 +99,7 @@ export const modelArchs: ModelArch[] = [
|
|||||||
{
|
{
|
||||||
name: 'chroma',
|
name: 'chroma',
|
||||||
label: 'Chroma',
|
label: 'Chroma',
|
||||||
|
group: 'image',
|
||||||
defaults: {
|
defaults: {
|
||||||
// default updates when [selected, unselected] in the UI
|
// default updates when [selected, unselected] in the UI
|
||||||
'config.process[0].model.name_or_path': ['lodestones/Chroma', defaultNameOrPath],
|
'config.process[0].model.name_or_path': ['lodestones/Chroma', defaultNameOrPath],
|
||||||
@@ -102,6 +113,7 @@ export const modelArchs: ModelArch[] = [
|
|||||||
{
|
{
|
||||||
name: 'wan21:1b',
|
name: 'wan21:1b',
|
||||||
label: 'Wan 2.1 (1.3B)',
|
label: 'Wan 2.1 (1.3B)',
|
||||||
|
group: 'video',
|
||||||
isVideoModel: true,
|
isVideoModel: true,
|
||||||
defaults: {
|
defaults: {
|
||||||
// default updates when [selected, unselected] in the UI
|
// default updates when [selected, unselected] in the UI
|
||||||
@@ -110,14 +122,52 @@ export const modelArchs: ModelArch[] = [
|
|||||||
'config.process[0].model.quantize_te': [true, false],
|
'config.process[0].model.quantize_te': [true, false],
|
||||||
'config.process[0].sample.sampler': ['flowmatch', 'flowmatch'],
|
'config.process[0].sample.sampler': ['flowmatch', 'flowmatch'],
|
||||||
'config.process[0].train.noise_scheduler': ['flowmatch', 'flowmatch'],
|
'config.process[0].train.noise_scheduler': ['flowmatch', 'flowmatch'],
|
||||||
'config.process[0].sample.num_frames': [40, 1],
|
'config.process[0].sample.num_frames': [41, 1],
|
||||||
'config.process[0].sample.fps': [15, 1],
|
'config.process[0].sample.fps': [15, 1],
|
||||||
},
|
},
|
||||||
disableSections: ['network.conv'],
|
disableSections: ['network.conv'],
|
||||||
|
additionalSections: ['datasets.num_frames'],
|
||||||
|
},
|
||||||
|
{
|
||||||
|
name: 'wan21_i2v:14b480p',
|
||||||
|
label: 'Wan 2.1 I2V (14B-480P)',
|
||||||
|
group: 'video',
|
||||||
|
isVideoModel: true,
|
||||||
|
defaults: {
|
||||||
|
// default updates when [selected, unselected] in the UI
|
||||||
|
'config.process[0].model.name_or_path': ['Wan-AI/Wan2.1-I2V-14B-480P-Diffusers', defaultNameOrPath],
|
||||||
|
'config.process[0].model.quantize': [true, false],
|
||||||
|
'config.process[0].model.quantize_te': [true, false],
|
||||||
|
'config.process[0].sample.sampler': ['flowmatch', 'flowmatch'],
|
||||||
|
'config.process[0].train.noise_scheduler': ['flowmatch', 'flowmatch'],
|
||||||
|
'config.process[0].sample.num_frames': [41, 1],
|
||||||
|
'config.process[0].sample.fps': [15, 1],
|
||||||
|
},
|
||||||
|
disableSections: ['network.conv'],
|
||||||
|
additionalSections: ['sample.ctrl_img', 'datasets.num_frames'],
|
||||||
|
},
|
||||||
|
{
|
||||||
|
name: 'wan21_i2v:14b',
|
||||||
|
label: 'Wan 2.1 I2V (14B-720P)',
|
||||||
|
group: 'video',
|
||||||
|
isVideoModel: true,
|
||||||
|
defaults: {
|
||||||
|
// default updates when [selected, unselected] in the UI
|
||||||
|
'config.process[0].model.name_or_path': ['Wan-AI/Wan2.1-I2V-14B-720P-Diffusers', defaultNameOrPath],
|
||||||
|
'config.process[0].model.quantize': [true, false],
|
||||||
|
'config.process[0].model.quantize_te': [true, false],
|
||||||
|
'config.process[0].sample.sampler': ['flowmatch', 'flowmatch'],
|
||||||
|
'config.process[0].train.noise_scheduler': ['flowmatch', 'flowmatch'],
|
||||||
|
'config.process[0].sample.num_frames': [41, 1],
|
||||||
|
'config.process[0].sample.fps': [15, 1],
|
||||||
|
},
|
||||||
|
disableSections: ['network.conv'],
|
||||||
|
additionalSections: ['sample.ctrl_img', 'datasets.num_frames'],
|
||||||
},
|
},
|
||||||
{
|
{
|
||||||
name: 'wan21:14b',
|
name: 'wan21:14b',
|
||||||
label: 'Wan 2.1 (14B)',
|
label: 'Wan 2.1 (14B)',
|
||||||
|
group: 'video',
|
||||||
isVideoModel: true,
|
isVideoModel: true,
|
||||||
defaults: {
|
defaults: {
|
||||||
// default updates when [selected, unselected] in the UI
|
// default updates when [selected, unselected] in the UI
|
||||||
@@ -126,14 +176,16 @@ export const modelArchs: ModelArch[] = [
|
|||||||
'config.process[0].model.quantize_te': [true, false],
|
'config.process[0].model.quantize_te': [true, false],
|
||||||
'config.process[0].sample.sampler': ['flowmatch', 'flowmatch'],
|
'config.process[0].sample.sampler': ['flowmatch', 'flowmatch'],
|
||||||
'config.process[0].train.noise_scheduler': ['flowmatch', 'flowmatch'],
|
'config.process[0].train.noise_scheduler': ['flowmatch', 'flowmatch'],
|
||||||
'config.process[0].sample.num_frames': [40, 1],
|
'config.process[0].sample.num_frames': [41, 1],
|
||||||
'config.process[0].sample.fps': [15, 1],
|
'config.process[0].sample.fps': [15, 1],
|
||||||
},
|
},
|
||||||
disableSections: ['network.conv'],
|
disableSections: ['network.conv'],
|
||||||
|
additionalSections: ['datasets.num_frames'],
|
||||||
},
|
},
|
||||||
{
|
{
|
||||||
name: 'lumina2',
|
name: 'lumina2',
|
||||||
label: 'Lumina2',
|
label: 'Lumina2',
|
||||||
|
group: 'image',
|
||||||
defaults: {
|
defaults: {
|
||||||
// default updates when [selected, unselected] in the UI
|
// default updates when [selected, unselected] in the UI
|
||||||
'config.process[0].model.name_or_path': ['Alpha-VLLM/Lumina-Image-2.0', defaultNameOrPath],
|
'config.process[0].model.name_or_path': ['Alpha-VLLM/Lumina-Image-2.0', defaultNameOrPath],
|
||||||
@@ -147,6 +199,7 @@ export const modelArchs: ModelArch[] = [
|
|||||||
{
|
{
|
||||||
name: 'hidream',
|
name: 'hidream',
|
||||||
label: 'HiDream',
|
label: 'HiDream',
|
||||||
|
group: 'image',
|
||||||
defaults: {
|
defaults: {
|
||||||
// default updates when [selected, unselected] in the UI
|
// default updates when [selected, unselected] in the UI
|
||||||
'config.process[0].model.name_or_path': ['HiDream-ai/HiDream-I1-Full', defaultNameOrPath],
|
'config.process[0].model.name_or_path': ['HiDream-ai/HiDream-I1-Full', defaultNameOrPath],
|
||||||
@@ -163,6 +216,7 @@ export const modelArchs: ModelArch[] = [
|
|||||||
{
|
{
|
||||||
name: 'sdxl',
|
name: 'sdxl',
|
||||||
label: 'SDXL',
|
label: 'SDXL',
|
||||||
|
group: 'image',
|
||||||
defaults: {
|
defaults: {
|
||||||
// default updates when [selected, unselected] in the UI
|
// default updates when [selected, unselected] in the UI
|
||||||
'config.process[0].model.name_or_path': ['stabilityai/stable-diffusion-xl-base-1.0', defaultNameOrPath],
|
'config.process[0].model.name_or_path': ['stabilityai/stable-diffusion-xl-base-1.0', defaultNameOrPath],
|
||||||
@@ -177,6 +231,7 @@ export const modelArchs: ModelArch[] = [
|
|||||||
{
|
{
|
||||||
name: 'sd15',
|
name: 'sd15',
|
||||||
label: 'SD 1.5',
|
label: 'SD 1.5',
|
||||||
|
group: 'image',
|
||||||
defaults: {
|
defaults: {
|
||||||
// default updates when [selected, unselected] in the UI
|
// default updates when [selected, unselected] in the UI
|
||||||
'config.process[0].model.name_or_path': ['stable-diffusion-v1-5/stable-diffusion-v1-5', defaultNameOrPath],
|
'config.process[0].model.name_or_path': ['stable-diffusion-v1-5/stable-diffusion-v1-5', defaultNameOrPath],
|
||||||
@@ -191,6 +246,7 @@ export const modelArchs: ModelArch[] = [
|
|||||||
{
|
{
|
||||||
name: 'omnigen2',
|
name: 'omnigen2',
|
||||||
label: 'OmniGen2',
|
label: 'OmniGen2',
|
||||||
|
group: 'image',
|
||||||
defaults: {
|
defaults: {
|
||||||
// default updates when [selected, unselected] in the UI
|
// default updates when [selected, unselected] in the UI
|
||||||
'config.process[0].model.name_or_path': ['OmniGen2/OmniGen2', defaultNameOrPath],
|
'config.process[0].model.name_or_path': ['OmniGen2/OmniGen2', defaultNameOrPath],
|
||||||
@@ -206,3 +262,17 @@ export const modelArchs: ModelArch[] = [
|
|||||||
// Sort by label, case-insensitive
|
// Sort by label, case-insensitive
|
||||||
return a.label.localeCompare(b.label, undefined, { sensitivity: 'base' })
|
return a.label.localeCompare(b.label, undefined, { sensitivity: 'base' })
|
||||||
}) as any;
|
}) as any;
|
||||||
|
|
||||||
|
|
||||||
|
export const groupedModelOptions: GroupedSelectOption[] = modelArchs.reduce((acc, arch) => {
|
||||||
|
const group = acc.find(g => g.label === arch.group);
|
||||||
|
if (group) {
|
||||||
|
group.options.push({ value: arch.name, label: arch.label });
|
||||||
|
} else {
|
||||||
|
acc.push({
|
||||||
|
label: arch.group,
|
||||||
|
options: [{ value: arch.name, label: arch.label }],
|
||||||
|
});
|
||||||
|
}
|
||||||
|
return acc;
|
||||||
|
}, [] as GroupedSelectOption[]);
|
||||||
|
|||||||
@@ -6,6 +6,7 @@ import dynamic from 'next/dynamic';
|
|||||||
import { CircleHelp } from 'lucide-react';
|
import { CircleHelp } from 'lucide-react';
|
||||||
import { getDoc } from '@/docs';
|
import { getDoc } from '@/docs';
|
||||||
import { openDoc } from '@/components/DocModal';
|
import { openDoc } from '@/components/DocModal';
|
||||||
|
import { GroupedSelectOption, SelectOption } from '@/types';
|
||||||
|
|
||||||
const Select = dynamic(() => import('react-select'), { ssr: false });
|
const Select = dynamic(() => import('react-select'), { ssr: false });
|
||||||
|
|
||||||
@@ -141,13 +142,21 @@ export interface SelectInputProps extends InputProps {
|
|||||||
value: string;
|
value: string;
|
||||||
disabled?: boolean;
|
disabled?: boolean;
|
||||||
onChange: (value: string) => void;
|
onChange: (value: string) => void;
|
||||||
options: { value: string; label: string }[];
|
options: GroupedSelectOption[] | SelectOption[];
|
||||||
}
|
}
|
||||||
|
|
||||||
export const SelectInput = (props: SelectInputProps) => {
|
export const SelectInput = (props: SelectInputProps) => {
|
||||||
const { label, value, onChange, options, docKey = null } = props;
|
const { label, value, onChange, options, docKey = null } = props;
|
||||||
const doc = getDoc(docKey);
|
const doc = getDoc(docKey);
|
||||||
const selectedOption = options.find(option => option.value === value);
|
let selectedOption: SelectOption | undefined;
|
||||||
|
if (options && options.length > 0) {
|
||||||
|
// see if grouped options
|
||||||
|
if ('options' in options[0]) {
|
||||||
|
selectedOption = (options as GroupedSelectOption[]).flatMap(group => group.options).find(opt => opt.value === value);
|
||||||
|
} else {
|
||||||
|
selectedOption = (options as SelectOption[]).find(opt => opt.value === value);
|
||||||
|
}
|
||||||
|
}
|
||||||
return (
|
return (
|
||||||
<div
|
<div
|
||||||
className={classNames(props.className, {
|
className={classNames(props.className, {
|
||||||
|
|||||||
@@ -57,6 +57,23 @@ const docs: { [key: string]: ConfigDoc } = {
|
|||||||
</>
|
</>
|
||||||
),
|
),
|
||||||
},
|
},
|
||||||
|
'datasets.num_frames': {
|
||||||
|
title: 'Number of Frames',
|
||||||
|
description: (
|
||||||
|
<>
|
||||||
|
This sets the number of frames to shrink videos to for a video dataset. If this dataset is images, set this to 1 for one frame.
|
||||||
|
If your dataset is only videos, frames will be extracted evenly spaced from the videos in the dataset.
|
||||||
|
<br/>
|
||||||
|
<br/>
|
||||||
|
It is best to trim your videos to the proper length before training. Wan is 16 frames a second. Doing 81 frames will result in a 5 second video.
|
||||||
|
So you would want all of your videos trimmed to around 5 seconds for best results.
|
||||||
|
<br/>
|
||||||
|
<br/>
|
||||||
|
Example: Setting this to 81 and having 2 videos in your dataset, one is 2 seconds and one is 90 seconds long, will result in 81
|
||||||
|
evenly spaced frames for each video making the 2 second video appear slow and the 90second video appear very fast.
|
||||||
|
</>
|
||||||
|
),
|
||||||
|
},
|
||||||
};
|
};
|
||||||
|
|
||||||
export const getDoc = (key: string | null | undefined): ConfigDoc | null => {
|
export const getDoc = (key: string | null | undefined): ConfigDoc | null => {
|
||||||
|
|||||||
@@ -84,6 +84,8 @@ export interface DatasetConfig {
|
|||||||
resolution: number[];
|
resolution: number[];
|
||||||
controls: string[];
|
controls: string[];
|
||||||
control_path: string | null;
|
control_path: string | null;
|
||||||
|
num_frames: number;
|
||||||
|
shrink_video_to_frames: boolean;
|
||||||
}
|
}
|
||||||
|
|
||||||
export interface EMAConfig {
|
export interface EMAConfig {
|
||||||
@@ -181,3 +183,12 @@ export interface ConfigDoc {
|
|||||||
title: string;
|
title: string;
|
||||||
description: React.ReactNode;
|
description: React.ReactNode;
|
||||||
}
|
}
|
||||||
|
|
||||||
|
export interface SelectOption {
|
||||||
|
readonly value: string;
|
||||||
|
readonly label: string;
|
||||||
|
}
|
||||||
|
export interface GroupedSelectOption {
|
||||||
|
readonly label: string;
|
||||||
|
readonly options: SelectOption[];
|
||||||
|
}
|
||||||
|
|||||||
@@ -1 +1 @@
|
|||||||
VERSION = "0.3.5"
|
VERSION = "0.3.6"
|
||||||
Reference in New Issue
Block a user