Add ability to designate a dataset as i2v or t2v for models that support it

This commit is contained in:
Jaret Burkett
2025-08-06 09:29:47 -06:00
parent 1755e58dd9
commit 5d8922fca2
10 changed files with 146 additions and 84 deletions

View File

@@ -244,6 +244,7 @@ class Wan22Model(Wan21):
conditioned_latent = latent_model_input conditioned_latent = latent_model_input
noise_mask = None noise_mask = None
if batch.dataset_config.do_i2v:
with torch.no_grad(): with torch.no_grad():
frames = batch.tensor frames = batch.tensor
if len(frames.shape) == 4: if len(frames.shape) == 4:

View File

@@ -881,6 +881,8 @@ class DatasetConfig:
# if true, will use a fask method to get image sizes. This can result in errors. Do not use unless you know what you are doing # if true, will use a fask method to get image sizes. This can result in errors. Do not use unless you know what you are doing
self.fast_image_size: bool = kwargs.get('fast_image_size', False) self.fast_image_size: bool = kwargs.get('fast_image_size', False)
self.do_i2v: bool = kwargs.get('do_i2v', True) # do image to video on models that are both t2i and i2v capable
def preprocess_dataset_raw_config(raw_config: List[dict]) -> List[dict]: def preprocess_dataset_raw_config(raw_config: List[dict]) -> List[dict]:
""" """

View File

@@ -301,3 +301,10 @@ class DataLoaderBatchDTO:
del self.control_tensor del self.control_tensor
for file_item in self.file_items: for file_item in self.file_items:
file_item.cleanup() file_item.cleanup()
@property
def dataset_config(self) -> 'DatasetConfig':
if len(self.file_items) > 0:
return self.file_items[0].dataset_config
else:
return None

View File

@@ -524,6 +524,16 @@ export default function SimpleJob({
checked={dataset.is_reg || false} checked={dataset.is_reg || false}
onChange={value => setJobConfig(value, `config.process[0].datasets[${i}].is_reg`)} onChange={value => setJobConfig(value, `config.process[0].datasets[${i}].is_reg`)}
/> />
{
modelArch?.additionalSections?.includes('datasets.do_i2v') && (
<Checkbox
label="Do I2V"
checked={dataset.do_i2v || false}
onChange={value => setJobConfig(value, `config.process[0].datasets[${i}].do_i2v`)}
docKey="datasets.do_i2v"
/>
)
}
</FormGroup> </FormGroup>
</div> </div>
<div> <div>

View File

@@ -15,6 +15,7 @@ export const defaultDatasetConfig: DatasetConfig = {
controls: [], controls: [],
shrink_video_to_frames: true, shrink_video_to_frames: true,
num_frames: 1, num_frames: 1,
do_i2v: true,
}; };
export const defaultJobConfig: JobConfig = { export const defaultJobConfig: JobConfig = {

View File

@@ -3,7 +3,7 @@ import { GroupedSelectOption, SelectOption } from '@/types';
type Control = 'depth' | 'line' | 'pose' | 'inpaint'; type Control = 'depth' | 'line' | 'pose' | 'inpaint';
type DisableableSections = 'model.quantize' | 'train.timestep_type' | 'network.conv'; type DisableableSections = 'model.quantize' | 'train.timestep_type' | 'network.conv';
type AdditionalSections = 'datasets.control_path' | 'sample.ctrl_img' | 'datasets.num_frames' | 'model.low_vram'; type AdditionalSections = 'datasets.control_path' | 'datasets.do_i2v' | 'sample.ctrl_img' | 'datasets.num_frames' | 'model.low_vram';
type ModelGroup = 'image' | 'video'; type ModelGroup = 'image' | 'video';
export interface ModelArch { export interface ModelArch {
@@ -201,7 +201,7 @@ export const modelArchs: ModelArch[] = [
'config.process[0].train.timestep_type': ['weighted', 'sigmoid'], 'config.process[0].train.timestep_type': ['weighted', 'sigmoid'],
}, },
disableSections: ['network.conv'], disableSections: ['network.conv'],
additionalSections: ['sample.ctrl_img', 'datasets.num_frames', 'model.low_vram'], additionalSections: ['sample.ctrl_img', 'datasets.num_frames', 'model.low_vram', 'datasets.do_i2v'],
}, },
{ {
name: 'lumina2', name: 'lumina2',

View File

@@ -6,7 +6,7 @@ import dynamic from 'next/dynamic';
import { CircleHelp } from 'lucide-react'; import { CircleHelp } from 'lucide-react';
import { getDoc } from '@/docs'; import { getDoc } from '@/docs';
import { openDoc } from '@/components/DocModal'; import { openDoc } from '@/components/DocModal';
import { GroupedSelectOption, SelectOption } from '@/types'; import { ConfigDoc, GroupedSelectOption, SelectOption } from '@/types';
const Select = dynamic(() => import('react-select'), { ssr: false }); const Select = dynamic(() => import('react-select'), { ssr: false });
@@ -16,7 +16,8 @@ const inputClasses =
export interface InputProps { export interface InputProps {
label?: string; label?: string;
docKey?: string; docKey?: string | null;
doc?: ConfigDoc | null;
className?: string; className?: string;
placeholder?: string; placeholder?: string;
required?: boolean; required?: boolean;
@@ -29,9 +30,12 @@ export interface TextInputProps extends InputProps {
disabled?: boolean; disabled?: boolean;
} }
export const TextInput = forwardRef<HTMLInputElement, TextInputProps>( export const TextInput = forwardRef<HTMLInputElement, TextInputProps>((props: TextInputProps, ref) => {
({ label, value, onChange, placeholder, required, disabled, type = 'text', className, docKey = null }, ref) => { const { label, value, onChange, placeholder, required, disabled, type = 'text', className, docKey = null } = props;
const doc = getDoc(docKey); let { doc } = props;
if (!doc && docKey) {
doc = getDoc(docKey);
}
return ( return (
<div className={classNames(className)}> <div className={classNames(className)}>
{label && ( {label && (
@@ -58,8 +62,7 @@ export const TextInput = forwardRef<HTMLInputElement, TextInputProps>(
/> />
</div> </div>
); );
}, });
);
// 👇 Helpful for debugging // 👇 Helpful for debugging
TextInput.displayName = 'TextInput'; TextInput.displayName = 'TextInput';
@@ -73,7 +76,10 @@ export interface NumberInputProps extends InputProps {
export const NumberInput = (props: NumberInputProps) => { export const NumberInput = (props: NumberInputProps) => {
const { label, value, onChange, placeholder, required, min, max, docKey = null } = props; const { label, value, onChange, placeholder, required, min, max, docKey = null } = props;
const doc = getDoc(docKey); let { doc } = props;
if (!doc && docKey) {
doc = getDoc(docKey);
}
// Add controlled internal state to properly handle partial inputs // Add controlled internal state to properly handle partial inputs
const [inputValue, setInputValue] = React.useState<string | number>(value ?? ''); const [inputValue, setInputValue] = React.useState<string | number>(value ?? '');
@@ -147,12 +153,17 @@ export interface SelectInputProps extends InputProps {
export const SelectInput = (props: SelectInputProps) => { export const SelectInput = (props: SelectInputProps) => {
const { label, value, onChange, options, docKey = null } = props; const { label, value, onChange, options, docKey = null } = props;
const doc = getDoc(docKey); let { doc } = props;
if (!doc && docKey) {
doc = getDoc(docKey);
}
let selectedOption: SelectOption | undefined; let selectedOption: SelectOption | undefined;
if (options && options.length > 0) { if (options && options.length > 0) {
// see if grouped options // see if grouped options
if ('options' in options[0]) { if ('options' in options[0]) {
selectedOption = (options as GroupedSelectOption[]).flatMap(group => group.options).find(opt => opt.value === value); selectedOption = (options as GroupedSelectOption[])
.flatMap(group => group.options)
.find(opt => opt.value === value);
} else { } else {
selectedOption = (options as SelectOption[]).find(opt => opt.value === value); selectedOption = (options as SelectOption[]).find(opt => opt.value === value);
} }
@@ -196,10 +207,17 @@ export interface CheckboxProps {
className?: string; className?: string;
required?: boolean; required?: boolean;
disabled?: boolean; disabled?: boolean;
docKey?: string | null;
doc?: ConfigDoc | null;
} }
export const Checkbox = (props: CheckboxProps) => { export const Checkbox = (props: CheckboxProps) => {
const { label, checked, onChange, required, disabled } = props; const { label, checked, onChange, required, disabled } = props;
let { doc } = props;
if (!doc && props.docKey) {
doc = getDoc(props.docKey);
}
const id = React.useId(); const id = React.useId();
return ( return (
@@ -227,6 +245,7 @@ export const Checkbox = (props: CheckboxProps) => {
/> />
</button> </button>
{label && ( {label && (
<>
<label <label
htmlFor={id} htmlFor={id}
className={classNames( className={classNames(
@@ -236,6 +255,12 @@ export const Checkbox = (props: CheckboxProps) => {
> >
{label} {label}
</label> </label>
{doc && (
<div className="inline-block ml-1 text-xs text-gray-500 cursor-pointer" onClick={() => openDoc(doc)}>
<CircleHelp className="inline-block w-4 h-4 cursor-pointer" />
</div>
)}
</>
)} )}
</div> </div>
); );
@@ -244,12 +269,17 @@ export const Checkbox = (props: CheckboxProps) => {
interface FormGroupProps { interface FormGroupProps {
label?: string; label?: string;
className?: string; className?: string;
docKey?: string; docKey?: string | null;
doc?: ConfigDoc | null;
children: React.ReactNode; children: React.ReactNode;
} }
export const FormGroup: React.FC<FormGroupProps> = ({ label, className, children, docKey = null }) => { export const FormGroup: React.FC<FormGroupProps> = props => {
const doc = getDoc(docKey); const { label, className, children, docKey = null } = props;
let { doc } = props;
if (!doc && docKey) {
doc = getDoc(docKey);
}
return ( return (
<div className={classNames(className)}> <div className={classNames(className)}>
{label && ( {label && (

View File

@@ -74,6 +74,16 @@ const docs: { [key: string]: ConfigDoc } = {
</> </>
), ),
}, },
'datasets.do_i2v': {
title: 'Do I2V',
description: (
<>
For video models that can handle both I2V (Image to Video) and T2V (Text to Video), this option sets this dataset
to be trained as an I2V dataset. This means that the first frame will be extracted from the video and used as the start image
for the video. If this option is not set, the dataset will be treated as a T2V dataset.
</>
),
},
}; };
export const getDoc = (key: string | null | undefined): ConfigDoc | null => { export const getDoc = (key: string | null | undefined): ConfigDoc | null => {

View File

@@ -86,6 +86,7 @@ export interface DatasetConfig {
control_path: string | null; control_path: string | null;
num_frames: number; num_frames: number;
shrink_video_to_frames: boolean; shrink_video_to_frames: boolean;
do_i2v: boolean;
} }
export interface EMAConfig { export interface EMAConfig {

View File

@@ -1 +1 @@
VERSION = "0.3.17" VERSION = "0.3.18"