mirror of
https://github.com/ostris/ai-toolkit.git
synced 2026-01-26 16:39:47 +00:00
Add ability to designate a dataset as i2v or t2v for models that support it
This commit is contained in:
@@ -243,41 +243,42 @@ class Wan22Model(Wan21):
|
|||||||
# for wan, only do i2v for video for now. Images do normal t2i
|
# for wan, only do i2v for video for now. Images do normal t2i
|
||||||
conditioned_latent = latent_model_input
|
conditioned_latent = latent_model_input
|
||||||
noise_mask = None
|
noise_mask = None
|
||||||
|
|
||||||
|
if batch.dataset_config.do_i2v:
|
||||||
|
with torch.no_grad():
|
||||||
|
frames = batch.tensor
|
||||||
|
if len(frames.shape) == 4:
|
||||||
|
first_frames = frames
|
||||||
|
elif len(frames.shape) == 5:
|
||||||
|
first_frames = frames[:, 0]
|
||||||
|
# Add conditioning using the standalone function
|
||||||
|
conditioned_latent, noise_mask = add_first_frame_conditioning_v22(
|
||||||
|
latent_model_input=latent_model_input.to(
|
||||||
|
self.device_torch, self.torch_dtype
|
||||||
|
),
|
||||||
|
first_frame=first_frames.to(self.device_torch, self.torch_dtype),
|
||||||
|
vae=self.vae,
|
||||||
|
)
|
||||||
|
else:
|
||||||
|
raise ValueError(f"Unknown frame shape {frames.shape}")
|
||||||
|
|
||||||
with torch.no_grad():
|
# make the noise mask
|
||||||
frames = batch.tensor
|
if noise_mask is None:
|
||||||
if len(frames.shape) == 4:
|
noise_mask = torch.ones(
|
||||||
first_frames = frames
|
conditioned_latent.shape,
|
||||||
elif len(frames.shape) == 5:
|
dtype=conditioned_latent.dtype,
|
||||||
first_frames = frames[:, 0]
|
device=conditioned_latent.device,
|
||||||
# Add conditioning using the standalone function
|
)
|
||||||
conditioned_latent, noise_mask = add_first_frame_conditioning_v22(
|
# todo write this better
|
||||||
latent_model_input=latent_model_input.to(
|
t_chunks = torch.chunk(timestep, timestep.shape[0])
|
||||||
self.device_torch, self.torch_dtype
|
out_t_chunks = []
|
||||||
),
|
for t in t_chunks:
|
||||||
first_frame=first_frames.to(self.device_torch, self.torch_dtype),
|
# seq_len: num_latent_frames * latent_height//2 * latent_width//2
|
||||||
vae=self.vae,
|
temp_ts = (noise_mask[0][0][:, ::2, ::2] * t).flatten()
|
||||||
)
|
# batch_size, seq_len
|
||||||
else:
|
temp_ts = temp_ts.unsqueeze(0)
|
||||||
raise ValueError(f"Unknown frame shape {frames.shape}")
|
out_t_chunks.append(temp_ts)
|
||||||
|
timestep = torch.cat(out_t_chunks, dim=0)
|
||||||
# make the noise mask
|
|
||||||
if noise_mask is None:
|
|
||||||
noise_mask = torch.ones(
|
|
||||||
conditioned_latent.shape,
|
|
||||||
dtype=conditioned_latent.dtype,
|
|
||||||
device=conditioned_latent.device,
|
|
||||||
)
|
|
||||||
# todo write this better
|
|
||||||
t_chunks = torch.chunk(timestep, timestep.shape[0])
|
|
||||||
out_t_chunks = []
|
|
||||||
for t in t_chunks:
|
|
||||||
# seq_len: num_latent_frames * latent_height//2 * latent_width//2
|
|
||||||
temp_ts = (noise_mask[0][0][:, ::2, ::2] * t).flatten()
|
|
||||||
# batch_size, seq_len
|
|
||||||
temp_ts = temp_ts.unsqueeze(0)
|
|
||||||
out_t_chunks.append(temp_ts)
|
|
||||||
timestep = torch.cat(out_t_chunks, dim=0)
|
|
||||||
|
|
||||||
noise_pred = self.model(
|
noise_pred = self.model(
|
||||||
hidden_states=conditioned_latent,
|
hidden_states=conditioned_latent,
|
||||||
|
|||||||
@@ -880,6 +880,8 @@ class DatasetConfig:
|
|||||||
|
|
||||||
# if true, will use a fask method to get image sizes. This can result in errors. Do not use unless you know what you are doing
|
# if true, will use a fask method to get image sizes. This can result in errors. Do not use unless you know what you are doing
|
||||||
self.fast_image_size: bool = kwargs.get('fast_image_size', False)
|
self.fast_image_size: bool = kwargs.get('fast_image_size', False)
|
||||||
|
|
||||||
|
self.do_i2v: bool = kwargs.get('do_i2v', True) # do image to video on models that are both t2i and i2v capable
|
||||||
|
|
||||||
|
|
||||||
def preprocess_dataset_raw_config(raw_config: List[dict]) -> List[dict]:
|
def preprocess_dataset_raw_config(raw_config: List[dict]) -> List[dict]:
|
||||||
|
|||||||
@@ -301,3 +301,10 @@ class DataLoaderBatchDTO:
|
|||||||
del self.control_tensor
|
del self.control_tensor
|
||||||
for file_item in self.file_items:
|
for file_item in self.file_items:
|
||||||
file_item.cleanup()
|
file_item.cleanup()
|
||||||
|
|
||||||
|
@property
|
||||||
|
def dataset_config(self) -> 'DatasetConfig':
|
||||||
|
if len(self.file_items) > 0:
|
||||||
|
return self.file_items[0].dataset_config
|
||||||
|
else:
|
||||||
|
return None
|
||||||
|
|||||||
@@ -524,6 +524,16 @@ export default function SimpleJob({
|
|||||||
checked={dataset.is_reg || false}
|
checked={dataset.is_reg || false}
|
||||||
onChange={value => setJobConfig(value, `config.process[0].datasets[${i}].is_reg`)}
|
onChange={value => setJobConfig(value, `config.process[0].datasets[${i}].is_reg`)}
|
||||||
/>
|
/>
|
||||||
|
{
|
||||||
|
modelArch?.additionalSections?.includes('datasets.do_i2v') && (
|
||||||
|
<Checkbox
|
||||||
|
label="Do I2V"
|
||||||
|
checked={dataset.do_i2v || false}
|
||||||
|
onChange={value => setJobConfig(value, `config.process[0].datasets[${i}].do_i2v`)}
|
||||||
|
docKey="datasets.do_i2v"
|
||||||
|
/>
|
||||||
|
)
|
||||||
|
}
|
||||||
</FormGroup>
|
</FormGroup>
|
||||||
</div>
|
</div>
|
||||||
<div>
|
<div>
|
||||||
|
|||||||
@@ -15,6 +15,7 @@ export const defaultDatasetConfig: DatasetConfig = {
|
|||||||
controls: [],
|
controls: [],
|
||||||
shrink_video_to_frames: true,
|
shrink_video_to_frames: true,
|
||||||
num_frames: 1,
|
num_frames: 1,
|
||||||
|
do_i2v: true,
|
||||||
};
|
};
|
||||||
|
|
||||||
export const defaultJobConfig: JobConfig = {
|
export const defaultJobConfig: JobConfig = {
|
||||||
|
|||||||
@@ -3,7 +3,7 @@ import { GroupedSelectOption, SelectOption } from '@/types';
|
|||||||
type Control = 'depth' | 'line' | 'pose' | 'inpaint';
|
type Control = 'depth' | 'line' | 'pose' | 'inpaint';
|
||||||
|
|
||||||
type DisableableSections = 'model.quantize' | 'train.timestep_type' | 'network.conv';
|
type DisableableSections = 'model.quantize' | 'train.timestep_type' | 'network.conv';
|
||||||
type AdditionalSections = 'datasets.control_path' | 'sample.ctrl_img' | 'datasets.num_frames' | 'model.low_vram';
|
type AdditionalSections = 'datasets.control_path' | 'datasets.do_i2v' | 'sample.ctrl_img' | 'datasets.num_frames' | 'model.low_vram';
|
||||||
type ModelGroup = 'image' | 'video';
|
type ModelGroup = 'image' | 'video';
|
||||||
|
|
||||||
export interface ModelArch {
|
export interface ModelArch {
|
||||||
@@ -201,7 +201,7 @@ export const modelArchs: ModelArch[] = [
|
|||||||
'config.process[0].train.timestep_type': ['weighted', 'sigmoid'],
|
'config.process[0].train.timestep_type': ['weighted', 'sigmoid'],
|
||||||
},
|
},
|
||||||
disableSections: ['network.conv'],
|
disableSections: ['network.conv'],
|
||||||
additionalSections: ['sample.ctrl_img', 'datasets.num_frames', 'model.low_vram'],
|
additionalSections: ['sample.ctrl_img', 'datasets.num_frames', 'model.low_vram', 'datasets.do_i2v'],
|
||||||
},
|
},
|
||||||
{
|
{
|
||||||
name: 'lumina2',
|
name: 'lumina2',
|
||||||
|
|||||||
@@ -6,7 +6,7 @@ import dynamic from 'next/dynamic';
|
|||||||
import { CircleHelp } from 'lucide-react';
|
import { CircleHelp } from 'lucide-react';
|
||||||
import { getDoc } from '@/docs';
|
import { getDoc } from '@/docs';
|
||||||
import { openDoc } from '@/components/DocModal';
|
import { openDoc } from '@/components/DocModal';
|
||||||
import { GroupedSelectOption, SelectOption } from '@/types';
|
import { ConfigDoc, GroupedSelectOption, SelectOption } from '@/types';
|
||||||
|
|
||||||
const Select = dynamic(() => import('react-select'), { ssr: false });
|
const Select = dynamic(() => import('react-select'), { ssr: false });
|
||||||
|
|
||||||
@@ -16,7 +16,8 @@ const inputClasses =
|
|||||||
|
|
||||||
export interface InputProps {
|
export interface InputProps {
|
||||||
label?: string;
|
label?: string;
|
||||||
docKey?: string;
|
docKey?: string | null;
|
||||||
|
doc?: ConfigDoc | null;
|
||||||
className?: string;
|
className?: string;
|
||||||
placeholder?: string;
|
placeholder?: string;
|
||||||
required?: boolean;
|
required?: boolean;
|
||||||
@@ -29,37 +30,39 @@ export interface TextInputProps extends InputProps {
|
|||||||
disabled?: boolean;
|
disabled?: boolean;
|
||||||
}
|
}
|
||||||
|
|
||||||
export const TextInput = forwardRef<HTMLInputElement, TextInputProps>(
|
export const TextInput = forwardRef<HTMLInputElement, TextInputProps>((props: TextInputProps, ref) => {
|
||||||
({ label, value, onChange, placeholder, required, disabled, type = 'text', className, docKey = null }, ref) => {
|
const { label, value, onChange, placeholder, required, disabled, type = 'text', className, docKey = null } = props;
|
||||||
const doc = getDoc(docKey);
|
let { doc } = props;
|
||||||
return (
|
if (!doc && docKey) {
|
||||||
<div className={classNames(className)}>
|
doc = getDoc(docKey);
|
||||||
{label && (
|
}
|
||||||
<label className={labelClasses}>
|
return (
|
||||||
{label}{' '}
|
<div className={classNames(className)}>
|
||||||
{doc && (
|
{label && (
|
||||||
<div className="inline-block ml-1 text-xs text-gray-500 cursor-pointer" onClick={() => openDoc(doc)}>
|
<label className={labelClasses}>
|
||||||
<CircleHelp className="inline-block w-4 h-4 cursor-pointer" />
|
{label}{' '}
|
||||||
</div>
|
{doc && (
|
||||||
)}
|
<div className="inline-block ml-1 text-xs text-gray-500 cursor-pointer" onClick={() => openDoc(doc)}>
|
||||||
</label>
|
<CircleHelp className="inline-block w-4 h-4 cursor-pointer" />
|
||||||
)}
|
</div>
|
||||||
<input
|
)}
|
||||||
ref={ref}
|
</label>
|
||||||
type={type}
|
)}
|
||||||
value={value}
|
<input
|
||||||
onChange={e => {
|
ref={ref}
|
||||||
if (!disabled) onChange(e.target.value);
|
type={type}
|
||||||
}}
|
value={value}
|
||||||
className={`${inputClasses} ${disabled ? 'opacity-30 cursor-not-allowed' : ''}`}
|
onChange={e => {
|
||||||
placeholder={placeholder}
|
if (!disabled) onChange(e.target.value);
|
||||||
required={required}
|
}}
|
||||||
disabled={disabled}
|
className={`${inputClasses} ${disabled ? 'opacity-30 cursor-not-allowed' : ''}`}
|
||||||
/>
|
placeholder={placeholder}
|
||||||
</div>
|
required={required}
|
||||||
);
|
disabled={disabled}
|
||||||
},
|
/>
|
||||||
);
|
</div>
|
||||||
|
);
|
||||||
|
});
|
||||||
|
|
||||||
// 👇 Helpful for debugging
|
// 👇 Helpful for debugging
|
||||||
TextInput.displayName = 'TextInput';
|
TextInput.displayName = 'TextInput';
|
||||||
@@ -73,7 +76,10 @@ export interface NumberInputProps extends InputProps {
|
|||||||
|
|
||||||
export const NumberInput = (props: NumberInputProps) => {
|
export const NumberInput = (props: NumberInputProps) => {
|
||||||
const { label, value, onChange, placeholder, required, min, max, docKey = null } = props;
|
const { label, value, onChange, placeholder, required, min, max, docKey = null } = props;
|
||||||
const doc = getDoc(docKey);
|
let { doc } = props;
|
||||||
|
if (!doc && docKey) {
|
||||||
|
doc = getDoc(docKey);
|
||||||
|
}
|
||||||
|
|
||||||
// Add controlled internal state to properly handle partial inputs
|
// Add controlled internal state to properly handle partial inputs
|
||||||
const [inputValue, setInputValue] = React.useState<string | number>(value ?? '');
|
const [inputValue, setInputValue] = React.useState<string | number>(value ?? '');
|
||||||
@@ -147,12 +153,17 @@ export interface SelectInputProps extends InputProps {
|
|||||||
|
|
||||||
export const SelectInput = (props: SelectInputProps) => {
|
export const SelectInput = (props: SelectInputProps) => {
|
||||||
const { label, value, onChange, options, docKey = null } = props;
|
const { label, value, onChange, options, docKey = null } = props;
|
||||||
const doc = getDoc(docKey);
|
let { doc } = props;
|
||||||
|
if (!doc && docKey) {
|
||||||
|
doc = getDoc(docKey);
|
||||||
|
}
|
||||||
let selectedOption: SelectOption | undefined;
|
let selectedOption: SelectOption | undefined;
|
||||||
if (options && options.length > 0) {
|
if (options && options.length > 0) {
|
||||||
// see if grouped options
|
// see if grouped options
|
||||||
if ('options' in options[0]) {
|
if ('options' in options[0]) {
|
||||||
selectedOption = (options as GroupedSelectOption[]).flatMap(group => group.options).find(opt => opt.value === value);
|
selectedOption = (options as GroupedSelectOption[])
|
||||||
|
.flatMap(group => group.options)
|
||||||
|
.find(opt => opt.value === value);
|
||||||
} else {
|
} else {
|
||||||
selectedOption = (options as SelectOption[]).find(opt => opt.value === value);
|
selectedOption = (options as SelectOption[]).find(opt => opt.value === value);
|
||||||
}
|
}
|
||||||
@@ -196,10 +207,17 @@ export interface CheckboxProps {
|
|||||||
className?: string;
|
className?: string;
|
||||||
required?: boolean;
|
required?: boolean;
|
||||||
disabled?: boolean;
|
disabled?: boolean;
|
||||||
|
docKey?: string | null;
|
||||||
|
doc?: ConfigDoc | null;
|
||||||
}
|
}
|
||||||
|
|
||||||
export const Checkbox = (props: CheckboxProps) => {
|
export const Checkbox = (props: CheckboxProps) => {
|
||||||
const { label, checked, onChange, required, disabled } = props;
|
const { label, checked, onChange, required, disabled } = props;
|
||||||
|
let { doc } = props;
|
||||||
|
if (!doc && props.docKey) {
|
||||||
|
doc = getDoc(props.docKey);
|
||||||
|
}
|
||||||
|
|
||||||
const id = React.useId();
|
const id = React.useId();
|
||||||
|
|
||||||
return (
|
return (
|
||||||
@@ -227,15 +245,22 @@ export const Checkbox = (props: CheckboxProps) => {
|
|||||||
/>
|
/>
|
||||||
</button>
|
</button>
|
||||||
{label && (
|
{label && (
|
||||||
<label
|
<>
|
||||||
htmlFor={id}
|
<label
|
||||||
className={classNames(
|
htmlFor={id}
|
||||||
'text-sm font-medium cursor-pointer select-none',
|
className={classNames(
|
||||||
disabled ? 'text-gray-500' : 'text-gray-300',
|
'text-sm font-medium cursor-pointer select-none',
|
||||||
|
disabled ? 'text-gray-500' : 'text-gray-300',
|
||||||
|
)}
|
||||||
|
>
|
||||||
|
{label}
|
||||||
|
</label>
|
||||||
|
{doc && (
|
||||||
|
<div className="inline-block ml-1 text-xs text-gray-500 cursor-pointer" onClick={() => openDoc(doc)}>
|
||||||
|
<CircleHelp className="inline-block w-4 h-4 cursor-pointer" />
|
||||||
|
</div>
|
||||||
)}
|
)}
|
||||||
>
|
</>
|
||||||
{label}
|
|
||||||
</label>
|
|
||||||
)}
|
)}
|
||||||
</div>
|
</div>
|
||||||
);
|
);
|
||||||
@@ -244,12 +269,17 @@ export const Checkbox = (props: CheckboxProps) => {
|
|||||||
interface FormGroupProps {
|
interface FormGroupProps {
|
||||||
label?: string;
|
label?: string;
|
||||||
className?: string;
|
className?: string;
|
||||||
docKey?: string;
|
docKey?: string | null;
|
||||||
|
doc?: ConfigDoc | null;
|
||||||
children: React.ReactNode;
|
children: React.ReactNode;
|
||||||
}
|
}
|
||||||
|
|
||||||
export const FormGroup: React.FC<FormGroupProps> = ({ label, className, children, docKey = null }) => {
|
export const FormGroup: React.FC<FormGroupProps> = props => {
|
||||||
const doc = getDoc(docKey);
|
const { label, className, children, docKey = null } = props;
|
||||||
|
let { doc } = props;
|
||||||
|
if (!doc && docKey) {
|
||||||
|
doc = getDoc(docKey);
|
||||||
|
}
|
||||||
return (
|
return (
|
||||||
<div className={classNames(className)}>
|
<div className={classNames(className)}>
|
||||||
{label && (
|
{label && (
|
||||||
|
|||||||
@@ -74,6 +74,16 @@ const docs: { [key: string]: ConfigDoc } = {
|
|||||||
</>
|
</>
|
||||||
),
|
),
|
||||||
},
|
},
|
||||||
|
'datasets.do_i2v': {
|
||||||
|
title: 'Do I2V',
|
||||||
|
description: (
|
||||||
|
<>
|
||||||
|
For video models that can handle both I2V (Image to Video) and T2V (Text to Video), this option sets this dataset
|
||||||
|
to be trained as an I2V dataset. This means that the first frame will be extracted from the video and used as the start image
|
||||||
|
for the video. If this option is not set, the dataset will be treated as a T2V dataset.
|
||||||
|
</>
|
||||||
|
),
|
||||||
|
},
|
||||||
};
|
};
|
||||||
|
|
||||||
export const getDoc = (key: string | null | undefined): ConfigDoc | null => {
|
export const getDoc = (key: string | null | undefined): ConfigDoc | null => {
|
||||||
|
|||||||
@@ -86,6 +86,7 @@ export interface DatasetConfig {
|
|||||||
control_path: string | null;
|
control_path: string | null;
|
||||||
num_frames: number;
|
num_frames: number;
|
||||||
shrink_video_to_frames: boolean;
|
shrink_video_to_frames: boolean;
|
||||||
|
do_i2v: boolean;
|
||||||
}
|
}
|
||||||
|
|
||||||
export interface EMAConfig {
|
export interface EMAConfig {
|
||||||
|
|||||||
@@ -1 +1 @@
|
|||||||
VERSION = "0.3.17"
|
VERSION = "0.3.18"
|
||||||
Reference in New Issue
Block a user