mirror of
https://github.com/ostris/ai-toolkit.git
synced 2026-05-01 11:41:35 +00:00
Add ability to designate a dataset as i2v or t2v for models that support it
This commit is contained in:
@@ -244,6 +244,7 @@ class Wan22Model(Wan21):
|
||||
conditioned_latent = latent_model_input
|
||||
noise_mask = None
|
||||
|
||||
if batch.dataset_config.do_i2v:
|
||||
with torch.no_grad():
|
||||
frames = batch.tensor
|
||||
if len(frames.shape) == 4:
|
||||
|
||||
@@ -881,6 +881,8 @@ class DatasetConfig:
|
||||
# if true, will use a fask method to get image sizes. This can result in errors. Do not use unless you know what you are doing
|
||||
self.fast_image_size: bool = kwargs.get('fast_image_size', False)
|
||||
|
||||
self.do_i2v: bool = kwargs.get('do_i2v', True) # do image to video on models that are both t2i and i2v capable
|
||||
|
||||
|
||||
def preprocess_dataset_raw_config(raw_config: List[dict]) -> List[dict]:
|
||||
"""
|
||||
|
||||
@@ -301,3 +301,10 @@ class DataLoaderBatchDTO:
|
||||
del self.control_tensor
|
||||
for file_item in self.file_items:
|
||||
file_item.cleanup()
|
||||
|
||||
@property
|
||||
def dataset_config(self) -> 'DatasetConfig':
|
||||
if len(self.file_items) > 0:
|
||||
return self.file_items[0].dataset_config
|
||||
else:
|
||||
return None
|
||||
|
||||
@@ -524,6 +524,16 @@ export default function SimpleJob({
|
||||
checked={dataset.is_reg || false}
|
||||
onChange={value => setJobConfig(value, `config.process[0].datasets[${i}].is_reg`)}
|
||||
/>
|
||||
{
|
||||
modelArch?.additionalSections?.includes('datasets.do_i2v') && (
|
||||
<Checkbox
|
||||
label="Do I2V"
|
||||
checked={dataset.do_i2v || false}
|
||||
onChange={value => setJobConfig(value, `config.process[0].datasets[${i}].do_i2v`)}
|
||||
docKey="datasets.do_i2v"
|
||||
/>
|
||||
)
|
||||
}
|
||||
</FormGroup>
|
||||
</div>
|
||||
<div>
|
||||
|
||||
@@ -15,6 +15,7 @@ export const defaultDatasetConfig: DatasetConfig = {
|
||||
controls: [],
|
||||
shrink_video_to_frames: true,
|
||||
num_frames: 1,
|
||||
do_i2v: true,
|
||||
};
|
||||
|
||||
export const defaultJobConfig: JobConfig = {
|
||||
|
||||
@@ -3,7 +3,7 @@ import { GroupedSelectOption, SelectOption } from '@/types';
|
||||
type Control = 'depth' | 'line' | 'pose' | 'inpaint';
|
||||
|
||||
type DisableableSections = 'model.quantize' | 'train.timestep_type' | 'network.conv';
|
||||
type AdditionalSections = 'datasets.control_path' | 'sample.ctrl_img' | 'datasets.num_frames' | 'model.low_vram';
|
||||
type AdditionalSections = 'datasets.control_path' | 'datasets.do_i2v' | 'sample.ctrl_img' | 'datasets.num_frames' | 'model.low_vram';
|
||||
type ModelGroup = 'image' | 'video';
|
||||
|
||||
export interface ModelArch {
|
||||
@@ -201,7 +201,7 @@ export const modelArchs: ModelArch[] = [
|
||||
'config.process[0].train.timestep_type': ['weighted', 'sigmoid'],
|
||||
},
|
||||
disableSections: ['network.conv'],
|
||||
additionalSections: ['sample.ctrl_img', 'datasets.num_frames', 'model.low_vram'],
|
||||
additionalSections: ['sample.ctrl_img', 'datasets.num_frames', 'model.low_vram', 'datasets.do_i2v'],
|
||||
},
|
||||
{
|
||||
name: 'lumina2',
|
||||
|
||||
@@ -6,7 +6,7 @@ import dynamic from 'next/dynamic';
|
||||
import { CircleHelp } from 'lucide-react';
|
||||
import { getDoc } from '@/docs';
|
||||
import { openDoc } from '@/components/DocModal';
|
||||
import { GroupedSelectOption, SelectOption } from '@/types';
|
||||
import { ConfigDoc, GroupedSelectOption, SelectOption } from '@/types';
|
||||
|
||||
const Select = dynamic(() => import('react-select'), { ssr: false });
|
||||
|
||||
@@ -16,7 +16,8 @@ const inputClasses =
|
||||
|
||||
export interface InputProps {
|
||||
label?: string;
|
||||
docKey?: string;
|
||||
docKey?: string | null;
|
||||
doc?: ConfigDoc | null;
|
||||
className?: string;
|
||||
placeholder?: string;
|
||||
required?: boolean;
|
||||
@@ -29,9 +30,12 @@ export interface TextInputProps extends InputProps {
|
||||
disabled?: boolean;
|
||||
}
|
||||
|
||||
export const TextInput = forwardRef<HTMLInputElement, TextInputProps>(
|
||||
({ label, value, onChange, placeholder, required, disabled, type = 'text', className, docKey = null }, ref) => {
|
||||
const doc = getDoc(docKey);
|
||||
export const TextInput = forwardRef<HTMLInputElement, TextInputProps>((props: TextInputProps, ref) => {
|
||||
const { label, value, onChange, placeholder, required, disabled, type = 'text', className, docKey = null } = props;
|
||||
let { doc } = props;
|
||||
if (!doc && docKey) {
|
||||
doc = getDoc(docKey);
|
||||
}
|
||||
return (
|
||||
<div className={classNames(className)}>
|
||||
{label && (
|
||||
@@ -58,8 +62,7 @@ export const TextInput = forwardRef<HTMLInputElement, TextInputProps>(
|
||||
/>
|
||||
</div>
|
||||
);
|
||||
},
|
||||
);
|
||||
});
|
||||
|
||||
// 👇 Helpful for debugging
|
||||
TextInput.displayName = 'TextInput';
|
||||
@@ -73,7 +76,10 @@ export interface NumberInputProps extends InputProps {
|
||||
|
||||
export const NumberInput = (props: NumberInputProps) => {
|
||||
const { label, value, onChange, placeholder, required, min, max, docKey = null } = props;
|
||||
const doc = getDoc(docKey);
|
||||
let { doc } = props;
|
||||
if (!doc && docKey) {
|
||||
doc = getDoc(docKey);
|
||||
}
|
||||
|
||||
// Add controlled internal state to properly handle partial inputs
|
||||
const [inputValue, setInputValue] = React.useState<string | number>(value ?? '');
|
||||
@@ -147,12 +153,17 @@ export interface SelectInputProps extends InputProps {
|
||||
|
||||
export const SelectInput = (props: SelectInputProps) => {
|
||||
const { label, value, onChange, options, docKey = null } = props;
|
||||
const doc = getDoc(docKey);
|
||||
let { doc } = props;
|
||||
if (!doc && docKey) {
|
||||
doc = getDoc(docKey);
|
||||
}
|
||||
let selectedOption: SelectOption | undefined;
|
||||
if (options && options.length > 0) {
|
||||
// see if grouped options
|
||||
if ('options' in options[0]) {
|
||||
selectedOption = (options as GroupedSelectOption[]).flatMap(group => group.options).find(opt => opt.value === value);
|
||||
selectedOption = (options as GroupedSelectOption[])
|
||||
.flatMap(group => group.options)
|
||||
.find(opt => opt.value === value);
|
||||
} else {
|
||||
selectedOption = (options as SelectOption[]).find(opt => opt.value === value);
|
||||
}
|
||||
@@ -196,10 +207,17 @@ export interface CheckboxProps {
|
||||
className?: string;
|
||||
required?: boolean;
|
||||
disabled?: boolean;
|
||||
docKey?: string | null;
|
||||
doc?: ConfigDoc | null;
|
||||
}
|
||||
|
||||
export const Checkbox = (props: CheckboxProps) => {
|
||||
const { label, checked, onChange, required, disabled } = props;
|
||||
let { doc } = props;
|
||||
if (!doc && props.docKey) {
|
||||
doc = getDoc(props.docKey);
|
||||
}
|
||||
|
||||
const id = React.useId();
|
||||
|
||||
return (
|
||||
@@ -227,6 +245,7 @@ export const Checkbox = (props: CheckboxProps) => {
|
||||
/>
|
||||
</button>
|
||||
{label && (
|
||||
<>
|
||||
<label
|
||||
htmlFor={id}
|
||||
className={classNames(
|
||||
@@ -236,6 +255,12 @@ export const Checkbox = (props: CheckboxProps) => {
|
||||
>
|
||||
{label}
|
||||
</label>
|
||||
{doc && (
|
||||
<div className="inline-block ml-1 text-xs text-gray-500 cursor-pointer" onClick={() => openDoc(doc)}>
|
||||
<CircleHelp className="inline-block w-4 h-4 cursor-pointer" />
|
||||
</div>
|
||||
)}
|
||||
</>
|
||||
)}
|
||||
</div>
|
||||
);
|
||||
@@ -244,12 +269,17 @@ export const Checkbox = (props: CheckboxProps) => {
|
||||
interface FormGroupProps {
|
||||
label?: string;
|
||||
className?: string;
|
||||
docKey?: string;
|
||||
docKey?: string | null;
|
||||
doc?: ConfigDoc | null;
|
||||
children: React.ReactNode;
|
||||
}
|
||||
|
||||
export const FormGroup: React.FC<FormGroupProps> = ({ label, className, children, docKey = null }) => {
|
||||
const doc = getDoc(docKey);
|
||||
export const FormGroup: React.FC<FormGroupProps> = props => {
|
||||
const { label, className, children, docKey = null } = props;
|
||||
let { doc } = props;
|
||||
if (!doc && docKey) {
|
||||
doc = getDoc(docKey);
|
||||
}
|
||||
return (
|
||||
<div className={classNames(className)}>
|
||||
{label && (
|
||||
|
||||
@@ -74,6 +74,16 @@ const docs: { [key: string]: ConfigDoc } = {
|
||||
</>
|
||||
),
|
||||
},
|
||||
'datasets.do_i2v': {
|
||||
title: 'Do I2V',
|
||||
description: (
|
||||
<>
|
||||
For video models that can handle both I2V (Image to Video) and T2V (Text to Video), this option sets this dataset
|
||||
to be trained as an I2V dataset. This means that the first frame will be extracted from the video and used as the start image
|
||||
for the video. If this option is not set, the dataset will be treated as a T2V dataset.
|
||||
</>
|
||||
),
|
||||
},
|
||||
};
|
||||
|
||||
export const getDoc = (key: string | null | undefined): ConfigDoc | null => {
|
||||
|
||||
@@ -86,6 +86,7 @@ export interface DatasetConfig {
|
||||
control_path: string | null;
|
||||
num_frames: number;
|
||||
shrink_video_to_frames: boolean;
|
||||
do_i2v: boolean;
|
||||
}
|
||||
|
||||
export interface EMAConfig {
|
||||
|
||||
@@ -1 +1 @@
|
||||
VERSION = "0.3.17"
|
||||
VERSION = "0.3.18"
|
||||
Reference in New Issue
Block a user