From 5d8922fca2a039e655e470e93f62fc21d259f889 Mon Sep 17 00:00:00 2001
From: Jaret Burkett
Date: Wed, 6 Aug 2025 09:29:47 -0600
Subject: [PATCH] Add ability to designate a dataset as i2v or t2v for models
 that support it
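
Each dataset now carries a `do_i2v` flag (default: true) so that, on models
that can train both modes, a dataset can be marked as I2V (the first frame
is extracted and used as the start image) or as plain T2V. A rough sketch of
a UI-side dataset entry; `folder_path` is shown only as an illustrative key:

    const i2vDataset: DatasetConfig = {
      ...defaultDatasetConfig,
      folder_path: '/path/to/videos',
      num_frames: 33,
      do_i2v: true, // set false to train this dataset as plain T2V
    };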
---
 .../diffusion_models/wan22/wan22_model.py   |  69 +++++-----
 toolkit/config_modules.py                   |   2 +
 toolkit/data_transfer_object/data_loader.py |   7 +
 ui/src/app/jobs/new/SimpleJob.tsx           |  10 ++
 ui/src/app/jobs/new/jobConfig.ts            |   1 +
 ui/src/app/jobs/new/options.ts              |   4 +-
 ui/src/components/formInputs.tsx            | 124 +++++++++++-------
 ui/src/docs.tsx                             |  10 ++
 ui/src/types.ts                             |   1 +
 version.py                                  |   2 +-
 10 files changed, 146 insertions(+), 84 deletions(-)

diff --git a/extensions_built_in/diffusion_models/wan22/wan22_model.py b/extensions_built_in/diffusion_models/wan22/wan22_model.py
index 0b0c63ea..328ea807 100644
--- a/extensions_built_in/diffusion_models/wan22/wan22_model.py
+++ b/extensions_built_in/diffusion_models/wan22/wan22_model.py
@@ -243,41 +243,42 @@ class Wan22Model(Wan21):
         # for wan, only do i2v for video for now. Images do normal t2i
         conditioned_latent = latent_model_input
         noise_mask = None
+
+        if batch.dataset_config.do_i2v:
+            with torch.no_grad():
+                frames = batch.tensor
+                if len(frames.shape) == 4:
+                    first_frames = frames
+                elif len(frames.shape) == 5:
+                    first_frames = frames[:, 0]
+                    # Add conditioning using the standalone function
+                    conditioned_latent, noise_mask = add_first_frame_conditioning_v22(
+                        latent_model_input=latent_model_input.to(
+                            self.device_torch, self.torch_dtype
+                        ),
+                        first_frame=first_frames.to(self.device_torch, self.torch_dtype),
+                        vae=self.vae,
+                    )
+                else:
+                    raise ValueError(f"Unknown frame shape {frames.shape}")
-        with torch.no_grad():
-            frames = batch.tensor
-            if len(frames.shape) == 4:
-                first_frames = frames
-            elif len(frames.shape) == 5:
-                first_frames = frames[:, 0]
-                # Add conditioning using the standalone function
-                conditioned_latent, noise_mask = add_first_frame_conditioning_v22(
-                    latent_model_input=latent_model_input.to(
-                        self.device_torch, self.torch_dtype
-                    ),
-                    first_frame=first_frames.to(self.device_torch, self.torch_dtype),
-                    vae=self.vae,
-                )
-            else:
-                raise ValueError(f"Unknown frame shape {frames.shape}")
-
-        # make the noise mask
-        if noise_mask is None:
-            noise_mask = torch.ones(
-                conditioned_latent.shape,
-                dtype=conditioned_latent.dtype,
-                device=conditioned_latent.device,
-            )
-        # todo write this better
-        t_chunks = torch.chunk(timestep, timestep.shape[0])
-        out_t_chunks = []
-        for t in t_chunks:
-            # seq_len: num_latent_frames * latent_height//2 * latent_width//2
-            temp_ts = (noise_mask[0][0][:, ::2, ::2] * t).flatten()
-            # batch_size, seq_len
-            temp_ts = temp_ts.unsqueeze(0)
-            out_t_chunks.append(temp_ts)
-        timestep = torch.cat(out_t_chunks, dim=0)
+            # make the noise mask
+            if noise_mask is None:
+                noise_mask = torch.ones(
+                    conditioned_latent.shape,
+                    dtype=conditioned_latent.dtype,
+                    device=conditioned_latent.device,
+                )
+            # todo write this better
+            t_chunks = torch.chunk(timestep, timestep.shape[0])
+            out_t_chunks = []
+            for t in t_chunks:
+                # seq_len: num_latent_frames * latent_height//2 * latent_width//2
+                temp_ts = (noise_mask[0][0][:, ::2, ::2] * t).flatten()
+                # batch_size, seq_len
+                temp_ts = temp_ts.unsqueeze(0)
+                out_t_chunks.append(temp_ts)
+            timestep = torch.cat(out_t_chunks, dim=0)
 
         noise_pred = self.model(
             hidden_states=conditioned_latent,
diff --git a/toolkit/config_modules.py b/toolkit/config_modules.py
index adf5dc6f..98454b74 100644
--- a/toolkit/config_modules.py
+++ b/toolkit/config_modules.py
@@ -880,6 +880,8 @@ class DatasetConfig:
         # if true, will use a fast method to get image sizes. This can result in errors. Do not use unless you know what you are doing
         self.fast_image_size: bool = kwargs.get('fast_image_size', False)
+
+        self.do_i2v: bool = kwargs.get('do_i2v', True)  # do image to video on models that are both t2v and i2v capable
 
 
 def preprocess_dataset_raw_config(raw_config: List[dict]) -> List[dict]:
diff --git a/toolkit/data_transfer_object/data_loader.py b/toolkit/data_transfer_object/data_loader.py
index 9078fbbb..bcc6c918 100644
--- a/toolkit/data_transfer_object/data_loader.py
+++ b/toolkit/data_transfer_object/data_loader.py
@@ -301,3 +301,10 @@ class DataLoaderBatchDTO:
             del self.control_tensor
         for file_item in self.file_items:
             file_item.cleanup()
+
+    @property
+    def dataset_config(self) -> 'DatasetConfig':
+        if len(self.file_items) > 0:
+            return self.file_items[0].dataset_config
+        else:
+            return None
diff --git a/ui/src/app/jobs/new/SimpleJob.tsx b/ui/src/app/jobs/new/SimpleJob.tsx
index dcf38314..016070bc 100644
--- a/ui/src/app/jobs/new/SimpleJob.tsx
+++ b/ui/src/app/jobs/new/SimpleJob.tsx
@@ -524,6 +524,16 @@ export default function SimpleJob({
                     checked={dataset.is_reg || false}
                     onChange={value => setJobConfig(value, `config.process[0].datasets[${i}].is_reg`)}
                   />
+                  {
+                    modelArch?.additionalSections?.includes('datasets.do_i2v') && (
+                      <Checkbox
+                        label="Do I2V"
+                        checked={dataset.do_i2v ?? true}
+                        onChange={value => setJobConfig(value, `config.process[0].datasets[${i}].do_i2v`)}
+                        docKey="datasets.do_i2v"
+                      />
+                    )
+                  }
diff --git a/ui/src/app/jobs/new/jobConfig.ts b/ui/src/app/jobs/new/jobConfig.ts
index 28155082..ba24e7e4 100644
--- a/ui/src/app/jobs/new/jobConfig.ts
+++ b/ui/src/app/jobs/new/jobConfig.ts
@@ -15,6 +15,7 @@ export const defaultDatasetConfig: DatasetConfig = {
   controls: [],
   shrink_video_to_frames: true,
   num_frames: 1,
+  do_i2v: true,
 };
 
 export const defaultJobConfig: JobConfig = {
diff --git a/ui/src/app/jobs/new/options.ts b/ui/src/app/jobs/new/options.ts
index bd508ecf..b2af8b86 100644
--- a/ui/src/app/jobs/new/options.ts
+++ b/ui/src/app/jobs/new/options.ts
@@ -3,7 +3,7 @@ import { GroupedSelectOption, SelectOption } from '@/types';
 
 type Control = 'depth' | 'line' | 'pose' | 'inpaint';
 type DisableableSections = 'model.quantize' | 'train.timestep_type' | 'network.conv';
-type AdditionalSections = 'datasets.control_path' | 'sample.ctrl_img' | 'datasets.num_frames' | 'model.low_vram';
+type AdditionalSections = 'datasets.control_path' | 'datasets.do_i2v' | 'sample.ctrl_img' | 'datasets.num_frames' | 'model.low_vram';
 type ModelGroup = 'image' | 'video';
 
 export interface ModelArch {
@@ -201,7 +201,7 @@ export const modelArchs: ModelArch[] = [
       'config.process[0].train.timestep_type': ['weighted', 'sigmoid'],
     },
     disableSections: ['network.conv'],
-    additionalSections: ['sample.ctrl_img', 'datasets.num_frames', 'model.low_vram'],
+    additionalSections: ['sample.ctrl_img', 'datasets.num_frames', 'model.low_vram', 'datasets.do_i2v'],
   },
   {
     name: 'lumina2',
diff --git a/ui/src/components/formInputs.tsx b/ui/src/components/formInputs.tsx
index 5613cb4e..3f4616ca 100644
--- a/ui/src/components/formInputs.tsx
+++ b/ui/src/components/formInputs.tsx
@@ -6,7 +6,7 @@ import dynamic from 'next/dynamic';
 import { CircleHelp } from 'lucide-react';
 import { getDoc } from '@/docs';
 import { openDoc } from '@/components/DocModal';
-import { GroupedSelectOption, SelectOption } from '@/types';
+import { ConfigDoc, GroupedSelectOption, SelectOption } from '@/types';
 
 const Select = dynamic(() => import('react-select'), { ssr: false });
 
@@ -16,7 +16,8 @@ const inputClasses =
 
 export interface InputProps {
   label?: string;
-  docKey?: string;
+  docKey?: string | null;
+  doc?: ConfigDoc | null;
   className?: string;
   placeholder?: string;
   required?: boolean;
@@ -29,37 +30,39 @@ export interface TextInputProps extends InputProps {
   disabled?: boolean;
 }
 
-export const TextInput = forwardRef(
-  ({ label, value, onChange, placeholder, required, disabled, type = 'text', className, docKey = null }, ref) => {
-    const doc = getDoc(docKey);
-    return (
-      <div className={className}>
-        {label && (
-
-        )}
-        <input
-          ref={ref as any}
-          type={type}
-          value={value}
-          onChange={e => {
-            if (!disabled) onChange(e.target.value);
-          }}
-          className={`${inputClasses} ${disabled ? 'opacity-30 cursor-not-allowed' : ''}`}
-          placeholder={placeholder}
-          required={required}
-          disabled={disabled}
-        />
-      </div>
-    );
-  },
-);
+export const TextInput = forwardRef((props: TextInputProps, ref) => {
+  const { label, value, onChange, placeholder, required, disabled, type = 'text', className, docKey = null } = props;
+  let { doc } = props;
+  if (!doc && docKey) {
+    doc = getDoc(docKey);
+  }
+  return (
+    <div className={className}>
+      {label && (
+
+      )}
+      <input
+        ref={ref as any}
+        type={type}
+        value={value}
+        onChange={e => {
+          if (!disabled) onChange(e.target.value);
+        }}
+        className={`${inputClasses} ${disabled ? 'opacity-30 cursor-not-allowed' : ''}`}
+        placeholder={placeholder}
+        required={required}
+        disabled={disabled}
+      />
+    </div>
+  );
+});
 
 // 👇 Helpful for debugging
 TextInput.displayName = 'TextInput';
 
@@ -73,7 +76,10 @@ export interface NumberInputProps extends InputProps {
 
 export const NumberInput = (props: NumberInputProps) => {
   const { label, value, onChange, placeholder, required, min, max, docKey = null } = props;
-  const doc = getDoc(docKey);
+  let { doc } = props;
+  if (!doc && docKey) {
+    doc = getDoc(docKey);
+  }
 
   // Add controlled internal state to properly handle partial inputs
   const [inputValue, setInputValue] = React.useState(value ?? '');
@@ -147,12 +153,17 @@ export interface SelectInputProps extends InputProps {
 
 export const SelectInput = (props: SelectInputProps) => {
   const { label, value, onChange, options, docKey = null } = props;
-  const doc = getDoc(docKey);
+  let { doc } = props;
+  if (!doc && docKey) {
+    doc = getDoc(docKey);
+  }
   let selectedOption: SelectOption | undefined;
   if (options && options.length > 0) {
     // see if grouped options
     if ('options' in options[0]) {
-      selectedOption = (options as GroupedSelectOption[]).flatMap(group => group.options).find(opt => opt.value === value);
+      selectedOption = (options as GroupedSelectOption[])
+        .flatMap(group => group.options)
+        .find(opt => opt.value === value);
     } else {
       selectedOption = (options as SelectOption[]).find(opt => opt.value === value);
     }
@@ -196,10 +207,17 @@ export interface CheckboxProps {
   className?: string;
   required?: boolean;
   disabled?: boolean;
+  docKey?: string | null;
+  doc?: ConfigDoc | null;
 }
 
 export const Checkbox = (props: CheckboxProps) => {
   const { label, checked, onChange, required, disabled } = props;
+  let { doc } = props;
+  if (!doc && props.docKey) {
+    doc = getDoc(props.docKey);
+  }
+
   const id = React.useId();
 
   return (
@@ -227,15 +245,22 @@ export const Checkbox = (props: CheckboxProps) => {
       />
       {label && (
-
+
       )}
   );
 };
 
@@ -244,12 +269,17 @@ export const Checkbox = (props: CheckboxProps) => {
 
 interface FormGroupProps {
   label?: string;
   className?: string;
-  docKey?: string;
+  docKey?: string | null;
+  doc?: ConfigDoc | null;
   children: React.ReactNode;
 }
 
-export const FormGroup: React.FC<FormGroupProps> = ({ label, className, children, docKey = null }) => {
-  const doc = getDoc(docKey);
+export const FormGroup: React.FC<FormGroupProps> = props => {
+  const { label, className, children, docKey = null } = props;
+  let { doc } = props;
+  if (!doc && docKey) {
+    doc = getDoc(docKey);
+  }
   return (
     {label && (
diff --git a/ui/src/docs.tsx b/ui/src/docs.tsx
index 69f2032a..88d6c479 100644
--- a/ui/src/docs.tsx
+++ b/ui/src/docs.tsx
@@ -74,6 +74,16 @@ const docs: { [key: string]: ConfigDoc } = {
       </>
     ),
   },
+  'datasets.do_i2v': {
+    title: 'Do I2V',
+    description: (
+      <>
+        For video models that can handle both I2V (Image to Video) and T2V (Text to Video), this option trains this
+        dataset as I2V: the first frame is extracted from each video and used as the start image for generating the
+        rest of the video. If this option is not set, the dataset is treated as a T2V dataset.
+      </>
+    ),
+  },
 };
 
 export const getDoc = (key: string | null | undefined): ConfigDoc | null => {
diff --git a/ui/src/types.ts b/ui/src/types.ts
index ef0007dc..08b22c8a 100644
--- a/ui/src/types.ts
+++ b/ui/src/types.ts
@@ -86,6 +86,7 @@ export interface DatasetConfig {
   control_path: string | null;
   num_frames: number;
   shrink_video_to_frames: boolean;
+  do_i2v: boolean;
 }
 
 export interface EMAConfig {
diff --git a/version.py b/version.py
index 08d5be86..be340cc3 100644
--- a/version.py
+++ b/version.py
@@ -1 +1 @@
-VERSION = "0.3.17"
\ No newline at end of file
+VERSION = "0.3.18"
\ No newline at end of file
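
Note (outside the patch): the form inputs now take either a `docKey`, resolved
through getDoc() as before, or a pre-resolved `doc` object that skips the
registry lookup. A minimal sketch, assuming Checkbox is imported from
ui/src/components/formInputs and `DoI2vToggle` is a hypothetical caller:

    function DoI2vToggle({ value, onChange }: { value: boolean; onChange: (v: boolean) => void }) {
      return (
        <>
          {/* docKey: looked up in the docs registry (ui/src/docs.tsx) */}
          <Checkbox label="Do I2V" checked={value} onChange={onChange} docKey="datasets.do_i2v" />
          {/* doc: passed directly as a ConfigDoc, no registry lookup */}
          <Checkbox label="Do I2V" checked={value} onChange={onChange} doc={{ title: 'Do I2V', description: 'Train as I2V' }} />
        </>
      );
    }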