Add ability to designate a dataset as i2v or t2v for models that support it

This commit is contained in:
Jaret Burkett
2025-08-06 09:29:47 -06:00
parent 1755e58dd9
commit 5d8922fca2
10 changed files with 146 additions and 84 deletions

View File

@@ -243,41 +243,42 @@ class Wan22Model(Wan21):
# for wan, only do i2v for video for now. Images do normal t2i # for wan, only do i2v for video for now. Images do normal t2i
conditioned_latent = latent_model_input conditioned_latent = latent_model_input
noise_mask = None noise_mask = None
if batch.dataset_config.do_i2v:
with torch.no_grad():
frames = batch.tensor
if len(frames.shape) == 4:
first_frames = frames
elif len(frames.shape) == 5:
first_frames = frames[:, 0]
# Add conditioning using the standalone function
conditioned_latent, noise_mask = add_first_frame_conditioning_v22(
latent_model_input=latent_model_input.to(
self.device_torch, self.torch_dtype
),
first_frame=first_frames.to(self.device_torch, self.torch_dtype),
vae=self.vae,
)
else:
raise ValueError(f"Unknown frame shape {frames.shape}")
with torch.no_grad(): # make the noise mask
frames = batch.tensor if noise_mask is None:
if len(frames.shape) == 4: noise_mask = torch.ones(
first_frames = frames conditioned_latent.shape,
elif len(frames.shape) == 5: dtype=conditioned_latent.dtype,
first_frames = frames[:, 0] device=conditioned_latent.device,
# Add conditioning using the standalone function )
conditioned_latent, noise_mask = add_first_frame_conditioning_v22( # todo write this better
latent_model_input=latent_model_input.to( t_chunks = torch.chunk(timestep, timestep.shape[0])
self.device_torch, self.torch_dtype out_t_chunks = []
), for t in t_chunks:
first_frame=first_frames.to(self.device_torch, self.torch_dtype), # seq_len: num_latent_frames * latent_height//2 * latent_width//2
vae=self.vae, temp_ts = (noise_mask[0][0][:, ::2, ::2] * t).flatten()
) # batch_size, seq_len
else: temp_ts = temp_ts.unsqueeze(0)
raise ValueError(f"Unknown frame shape {frames.shape}") out_t_chunks.append(temp_ts)
timestep = torch.cat(out_t_chunks, dim=0)
# make the noise mask
if noise_mask is None:
noise_mask = torch.ones(
conditioned_latent.shape,
dtype=conditioned_latent.dtype,
device=conditioned_latent.device,
)
# todo write this better
t_chunks = torch.chunk(timestep, timestep.shape[0])
out_t_chunks = []
for t in t_chunks:
# seq_len: num_latent_frames * latent_height//2 * latent_width//2
temp_ts = (noise_mask[0][0][:, ::2, ::2] * t).flatten()
# batch_size, seq_len
temp_ts = temp_ts.unsqueeze(0)
out_t_chunks.append(temp_ts)
timestep = torch.cat(out_t_chunks, dim=0)
noise_pred = self.model( noise_pred = self.model(
hidden_states=conditioned_latent, hidden_states=conditioned_latent,

View File

@@ -880,6 +880,8 @@ class DatasetConfig:
# if true, will use a fast method to get image sizes. This can result in errors. Do not use unless you know what you are doing # if true, will use a fast method to get image sizes. This can result in errors. Do not use unless you know what you are doing
self.fast_image_size: bool = kwargs.get('fast_image_size', False) self.fast_image_size: bool = kwargs.get('fast_image_size', False)
self.do_i2v: bool = kwargs.get('do_i2v', True) # do image to video on models that are both t2i and i2v capable
def preprocess_dataset_raw_config(raw_config: List[dict]) -> List[dict]: def preprocess_dataset_raw_config(raw_config: List[dict]) -> List[dict]:

View File

@@ -301,3 +301,10 @@ class DataLoaderBatchDTO:
del self.control_tensor del self.control_tensor
for file_item in self.file_items: for file_item in self.file_items:
file_item.cleanup() file_item.cleanup()
@property
def dataset_config(self) -> 'DatasetConfig | None':
    """Return the dataset config of the first file item in this batch.

    Only ``self.file_items[0]`` is consulted; the other items are not checked.
    NOTE(review): this presumably relies on every item in a batch coming from
    the same dataset -- confirm against the data loader.

    Returns:
        The ``dataset_config`` attribute of the first file item, or ``None``
        when the batch has no file items. Callers must handle the ``None``
        case before reading attributes off the result.
    """
    # A falsy item list (empty, or None) has no config to report; checking
    # truthiness instead of len() also avoids a TypeError on None.
    if not self.file_items:
        return None
    return self.file_items[0].dataset_config

View File

@@ -524,6 +524,16 @@ export default function SimpleJob({
checked={dataset.is_reg || false} checked={dataset.is_reg || false}
onChange={value => setJobConfig(value, `config.process[0].datasets[${i}].is_reg`)} onChange={value => setJobConfig(value, `config.process[0].datasets[${i}].is_reg`)}
/> />
{
modelArch?.additionalSections?.includes('datasets.do_i2v') && (
<Checkbox
label="Do I2V"
checked={dataset.do_i2v || false}
onChange={value => setJobConfig(value, `config.process[0].datasets[${i}].do_i2v`)}
docKey="datasets.do_i2v"
/>
)
}
</FormGroup> </FormGroup>
</div> </div>
<div> <div>

View File

@@ -15,6 +15,7 @@ export const defaultDatasetConfig: DatasetConfig = {
controls: [], controls: [],
shrink_video_to_frames: true, shrink_video_to_frames: true,
num_frames: 1, num_frames: 1,
do_i2v: true,
}; };
export const defaultJobConfig: JobConfig = { export const defaultJobConfig: JobConfig = {

View File

@@ -3,7 +3,7 @@ import { GroupedSelectOption, SelectOption } from '@/types';
type Control = 'depth' | 'line' | 'pose' | 'inpaint'; type Control = 'depth' | 'line' | 'pose' | 'inpaint';
type DisableableSections = 'model.quantize' | 'train.timestep_type' | 'network.conv'; type DisableableSections = 'model.quantize' | 'train.timestep_type' | 'network.conv';
type AdditionalSections = 'datasets.control_path' | 'sample.ctrl_img' | 'datasets.num_frames' | 'model.low_vram'; type AdditionalSections = 'datasets.control_path' | 'datasets.do_i2v' | 'sample.ctrl_img' | 'datasets.num_frames' | 'model.low_vram';
type ModelGroup = 'image' | 'video'; type ModelGroup = 'image' | 'video';
export interface ModelArch { export interface ModelArch {
@@ -201,7 +201,7 @@ export const modelArchs: ModelArch[] = [
'config.process[0].train.timestep_type': ['weighted', 'sigmoid'], 'config.process[0].train.timestep_type': ['weighted', 'sigmoid'],
}, },
disableSections: ['network.conv'], disableSections: ['network.conv'],
additionalSections: ['sample.ctrl_img', 'datasets.num_frames', 'model.low_vram'], additionalSections: ['sample.ctrl_img', 'datasets.num_frames', 'model.low_vram', 'datasets.do_i2v'],
}, },
{ {
name: 'lumina2', name: 'lumina2',

View File

@@ -6,7 +6,7 @@ import dynamic from 'next/dynamic';
import { CircleHelp } from 'lucide-react'; import { CircleHelp } from 'lucide-react';
import { getDoc } from '@/docs'; import { getDoc } from '@/docs';
import { openDoc } from '@/components/DocModal'; import { openDoc } from '@/components/DocModal';
import { GroupedSelectOption, SelectOption } from '@/types'; import { ConfigDoc, GroupedSelectOption, SelectOption } from '@/types';
const Select = dynamic(() => import('react-select'), { ssr: false }); const Select = dynamic(() => import('react-select'), { ssr: false });
@@ -16,7 +16,8 @@ const inputClasses =
export interface InputProps { export interface InputProps {
label?: string; label?: string;
docKey?: string; docKey?: string | null;
doc?: ConfigDoc | null;
className?: string; className?: string;
placeholder?: string; placeholder?: string;
required?: boolean; required?: boolean;
@@ -29,37 +30,39 @@ export interface TextInputProps extends InputProps {
disabled?: boolean; disabled?: boolean;
} }
export const TextInput = forwardRef<HTMLInputElement, TextInputProps>( export const TextInput = forwardRef<HTMLInputElement, TextInputProps>((props: TextInputProps, ref) => {
({ label, value, onChange, placeholder, required, disabled, type = 'text', className, docKey = null }, ref) => { const { label, value, onChange, placeholder, required, disabled, type = 'text', className, docKey = null } = props;
const doc = getDoc(docKey); let { doc } = props;
return ( if (!doc && docKey) {
<div className={classNames(className)}> doc = getDoc(docKey);
{label && ( }
<label className={labelClasses}> return (
{label}{' '} <div className={classNames(className)}>
{doc && ( {label && (
<div className="inline-block ml-1 text-xs text-gray-500 cursor-pointer" onClick={() => openDoc(doc)}> <label className={labelClasses}>
<CircleHelp className="inline-block w-4 h-4 cursor-pointer" /> {label}{' '}
</div> {doc && (
)} <div className="inline-block ml-1 text-xs text-gray-500 cursor-pointer" onClick={() => openDoc(doc)}>
</label> <CircleHelp className="inline-block w-4 h-4 cursor-pointer" />
)} </div>
<input )}
ref={ref} </label>
type={type} )}
value={value} <input
onChange={e => { ref={ref}
if (!disabled) onChange(e.target.value); type={type}
}} value={value}
className={`${inputClasses} ${disabled ? 'opacity-30 cursor-not-allowed' : ''}`} onChange={e => {
placeholder={placeholder} if (!disabled) onChange(e.target.value);
required={required} }}
disabled={disabled} className={`${inputClasses} ${disabled ? 'opacity-30 cursor-not-allowed' : ''}`}
/> placeholder={placeholder}
</div> required={required}
); disabled={disabled}
}, />
); </div>
);
});
// 👇 Helpful for debugging // 👇 Helpful for debugging
TextInput.displayName = 'TextInput'; TextInput.displayName = 'TextInput';
@@ -73,7 +76,10 @@ export interface NumberInputProps extends InputProps {
export const NumberInput = (props: NumberInputProps) => { export const NumberInput = (props: NumberInputProps) => {
const { label, value, onChange, placeholder, required, min, max, docKey = null } = props; const { label, value, onChange, placeholder, required, min, max, docKey = null } = props;
const doc = getDoc(docKey); let { doc } = props;
if (!doc && docKey) {
doc = getDoc(docKey);
}
// Add controlled internal state to properly handle partial inputs // Add controlled internal state to properly handle partial inputs
const [inputValue, setInputValue] = React.useState<string | number>(value ?? ''); const [inputValue, setInputValue] = React.useState<string | number>(value ?? '');
@@ -147,12 +153,17 @@ export interface SelectInputProps extends InputProps {
export const SelectInput = (props: SelectInputProps) => { export const SelectInput = (props: SelectInputProps) => {
const { label, value, onChange, options, docKey = null } = props; const { label, value, onChange, options, docKey = null } = props;
const doc = getDoc(docKey); let { doc } = props;
if (!doc && docKey) {
doc = getDoc(docKey);
}
let selectedOption: SelectOption | undefined; let selectedOption: SelectOption | undefined;
if (options && options.length > 0) { if (options && options.length > 0) {
// see if grouped options // see if grouped options
if ('options' in options[0]) { if ('options' in options[0]) {
selectedOption = (options as GroupedSelectOption[]).flatMap(group => group.options).find(opt => opt.value === value); selectedOption = (options as GroupedSelectOption[])
.flatMap(group => group.options)
.find(opt => opt.value === value);
} else { } else {
selectedOption = (options as SelectOption[]).find(opt => opt.value === value); selectedOption = (options as SelectOption[]).find(opt => opt.value === value);
} }
@@ -196,10 +207,17 @@ export interface CheckboxProps {
className?: string; className?: string;
required?: boolean; required?: boolean;
disabled?: boolean; disabled?: boolean;
docKey?: string | null;
doc?: ConfigDoc | null;
} }
export const Checkbox = (props: CheckboxProps) => { export const Checkbox = (props: CheckboxProps) => {
const { label, checked, onChange, required, disabled } = props; const { label, checked, onChange, required, disabled } = props;
let { doc } = props;
if (!doc && props.docKey) {
doc = getDoc(props.docKey);
}
const id = React.useId(); const id = React.useId();
return ( return (
@@ -227,15 +245,22 @@ export const Checkbox = (props: CheckboxProps) => {
/> />
</button> </button>
{label && ( {label && (
<label <>
htmlFor={id} <label
className={classNames( htmlFor={id}
'text-sm font-medium cursor-pointer select-none', className={classNames(
disabled ? 'text-gray-500' : 'text-gray-300', 'text-sm font-medium cursor-pointer select-none',
disabled ? 'text-gray-500' : 'text-gray-300',
)}
>
{label}
</label>
{doc && (
<div className="inline-block ml-1 text-xs text-gray-500 cursor-pointer" onClick={() => openDoc(doc)}>
<CircleHelp className="inline-block w-4 h-4 cursor-pointer" />
</div>
)} )}
> </>
{label}
</label>
)} )}
</div> </div>
); );
@@ -244,12 +269,17 @@ export const Checkbox = (props: CheckboxProps) => {
interface FormGroupProps { interface FormGroupProps {
label?: string; label?: string;
className?: string; className?: string;
docKey?: string; docKey?: string | null;
doc?: ConfigDoc | null;
children: React.ReactNode; children: React.ReactNode;
} }
export const FormGroup: React.FC<FormGroupProps> = ({ label, className, children, docKey = null }) => { export const FormGroup: React.FC<FormGroupProps> = props => {
const doc = getDoc(docKey); const { label, className, children, docKey = null } = props;
let { doc } = props;
if (!doc && docKey) {
doc = getDoc(docKey);
}
return ( return (
<div className={classNames(className)}> <div className={classNames(className)}>
{label && ( {label && (

View File

@@ -74,6 +74,16 @@ const docs: { [key: string]: ConfigDoc } = {
</> </>
), ),
}, },
'datasets.do_i2v': {
title: 'Do I2V',
description: (
<>
For video models that can handle both I2V (Image to Video) and T2V (Text to Video), this option sets this dataset
to be trained as an I2V dataset. This means that the first frame will be extracted from the video and used as the start image
for the video. If this option is not set, the dataset will be treated as a T2V dataset.
</>
),
},
}; };
export const getDoc = (key: string | null | undefined): ConfigDoc | null => { export const getDoc = (key: string | null | undefined): ConfigDoc | null => {

View File

@@ -86,6 +86,7 @@ export interface DatasetConfig {
control_path: string | null; control_path: string | null;
num_frames: number; num_frames: number;
shrink_video_to_frames: boolean; shrink_video_to_frames: boolean;
do_i2v: boolean;
} }
export interface EMAConfig { export interface EMAConfig {

View File

@@ -1 +1 @@
VERSION = "0.3.17" VERSION = "0.3.18"