mirror of
https://github.com/ostris/ai-toolkit.git
synced 2026-05-01 03:31:35 +00:00
Add ability to designate a dataset as i2v or t2v for models that support it
This commit is contained in:
@@ -244,6 +244,7 @@ class Wan22Model(Wan21):
|
|||||||
conditioned_latent = latent_model_input
|
conditioned_latent = latent_model_input
|
||||||
noise_mask = None
|
noise_mask = None
|
||||||
|
|
||||||
|
if batch.dataset_config.do_i2v:
|
||||||
with torch.no_grad():
|
with torch.no_grad():
|
||||||
frames = batch.tensor
|
frames = batch.tensor
|
||||||
if len(frames.shape) == 4:
|
if len(frames.shape) == 4:
|
||||||
|
|||||||
@@ -881,6 +881,8 @@ class DatasetConfig:
|
|||||||
# if true, will use a fast method to get image sizes. This can result in errors. Do not use unless you know what you are doing
|
# if true, will use a fast method to get image sizes. This can result in errors. Do not use unless you know what you are doing
|
||||||
self.fast_image_size: bool = kwargs.get('fast_image_size', False)
|
self.fast_image_size: bool = kwargs.get('fast_image_size', False)
|
||||||
|
|
||||||
|
self.do_i2v: bool = kwargs.get('do_i2v', True) # do image to video on models that are both t2v and i2v capable
|
||||||
|
|
||||||
|
|
||||||
def preprocess_dataset_raw_config(raw_config: List[dict]) -> List[dict]:
|
def preprocess_dataset_raw_config(raw_config: List[dict]) -> List[dict]:
|
||||||
"""
|
"""
|
||||||
|
|||||||
@@ -301,3 +301,10 @@ class DataLoaderBatchDTO:
|
|||||||
del self.control_tensor
|
del self.control_tensor
|
||||||
for file_item in self.file_items:
|
for file_item in self.file_items:
|
||||||
file_item.cleanup()
|
file_item.cleanup()
|
||||||
|
|
||||||
|
@property
|
||||||
|
def dataset_config(self) -> 'DatasetConfig':
|
||||||
|
if len(self.file_items) > 0:
|
||||||
|
return self.file_items[0].dataset_config
|
||||||
|
else:
|
||||||
|
return None
|
||||||
|
|||||||
@@ -524,6 +524,16 @@ export default function SimpleJob({
|
|||||||
checked={dataset.is_reg || false}
|
checked={dataset.is_reg || false}
|
||||||
onChange={value => setJobConfig(value, `config.process[0].datasets[${i}].is_reg`)}
|
onChange={value => setJobConfig(value, `config.process[0].datasets[${i}].is_reg`)}
|
||||||
/>
|
/>
|
||||||
|
{
|
||||||
|
modelArch?.additionalSections?.includes('datasets.do_i2v') && (
|
||||||
|
<Checkbox
|
||||||
|
label="Do I2V"
|
||||||
|
checked={dataset.do_i2v || false}
|
||||||
|
onChange={value => setJobConfig(value, `config.process[0].datasets[${i}].do_i2v`)}
|
||||||
|
docKey="datasets.do_i2v"
|
||||||
|
/>
|
||||||
|
)
|
||||||
|
}
|
||||||
</FormGroup>
|
</FormGroup>
|
||||||
</div>
|
</div>
|
||||||
<div>
|
<div>
|
||||||
|
|||||||
@@ -15,6 +15,7 @@ export const defaultDatasetConfig: DatasetConfig = {
|
|||||||
controls: [],
|
controls: [],
|
||||||
shrink_video_to_frames: true,
|
shrink_video_to_frames: true,
|
||||||
num_frames: 1,
|
num_frames: 1,
|
||||||
|
do_i2v: true,
|
||||||
};
|
};
|
||||||
|
|
||||||
export const defaultJobConfig: JobConfig = {
|
export const defaultJobConfig: JobConfig = {
|
||||||
|
|||||||
@@ -3,7 +3,7 @@ import { GroupedSelectOption, SelectOption } from '@/types';
|
|||||||
type Control = 'depth' | 'line' | 'pose' | 'inpaint';
|
type Control = 'depth' | 'line' | 'pose' | 'inpaint';
|
||||||
|
|
||||||
type DisableableSections = 'model.quantize' | 'train.timestep_type' | 'network.conv';
|
type DisableableSections = 'model.quantize' | 'train.timestep_type' | 'network.conv';
|
||||||
type AdditionalSections = 'datasets.control_path' | 'sample.ctrl_img' | 'datasets.num_frames' | 'model.low_vram';
|
type AdditionalSections = 'datasets.control_path' | 'datasets.do_i2v' | 'sample.ctrl_img' | 'datasets.num_frames' | 'model.low_vram';
|
||||||
type ModelGroup = 'image' | 'video';
|
type ModelGroup = 'image' | 'video';
|
||||||
|
|
||||||
export interface ModelArch {
|
export interface ModelArch {
|
||||||
@@ -201,7 +201,7 @@ export const modelArchs: ModelArch[] = [
|
|||||||
'config.process[0].train.timestep_type': ['weighted', 'sigmoid'],
|
'config.process[0].train.timestep_type': ['weighted', 'sigmoid'],
|
||||||
},
|
},
|
||||||
disableSections: ['network.conv'],
|
disableSections: ['network.conv'],
|
||||||
additionalSections: ['sample.ctrl_img', 'datasets.num_frames', 'model.low_vram'],
|
additionalSections: ['sample.ctrl_img', 'datasets.num_frames', 'model.low_vram', 'datasets.do_i2v'],
|
||||||
},
|
},
|
||||||
{
|
{
|
||||||
name: 'lumina2',
|
name: 'lumina2',
|
||||||
|
|||||||
@@ -6,7 +6,7 @@ import dynamic from 'next/dynamic';
|
|||||||
import { CircleHelp } from 'lucide-react';
|
import { CircleHelp } from 'lucide-react';
|
||||||
import { getDoc } from '@/docs';
|
import { getDoc } from '@/docs';
|
||||||
import { openDoc } from '@/components/DocModal';
|
import { openDoc } from '@/components/DocModal';
|
||||||
import { GroupedSelectOption, SelectOption } from '@/types';
|
import { ConfigDoc, GroupedSelectOption, SelectOption } from '@/types';
|
||||||
|
|
||||||
const Select = dynamic(() => import('react-select'), { ssr: false });
|
const Select = dynamic(() => import('react-select'), { ssr: false });
|
||||||
|
|
||||||
@@ -16,7 +16,8 @@ const inputClasses =
|
|||||||
|
|
||||||
export interface InputProps {
|
export interface InputProps {
|
||||||
label?: string;
|
label?: string;
|
||||||
docKey?: string;
|
docKey?: string | null;
|
||||||
|
doc?: ConfigDoc | null;
|
||||||
className?: string;
|
className?: string;
|
||||||
placeholder?: string;
|
placeholder?: string;
|
||||||
required?: boolean;
|
required?: boolean;
|
||||||
@@ -29,9 +30,12 @@ export interface TextInputProps extends InputProps {
|
|||||||
disabled?: boolean;
|
disabled?: boolean;
|
||||||
}
|
}
|
||||||
|
|
||||||
export const TextInput = forwardRef<HTMLInputElement, TextInputProps>(
|
export const TextInput = forwardRef<HTMLInputElement, TextInputProps>((props: TextInputProps, ref) => {
|
||||||
({ label, value, onChange, placeholder, required, disabled, type = 'text', className, docKey = null }, ref) => {
|
const { label, value, onChange, placeholder, required, disabled, type = 'text', className, docKey = null } = props;
|
||||||
const doc = getDoc(docKey);
|
let { doc } = props;
|
||||||
|
if (!doc && docKey) {
|
||||||
|
doc = getDoc(docKey);
|
||||||
|
}
|
||||||
return (
|
return (
|
||||||
<div className={classNames(className)}>
|
<div className={classNames(className)}>
|
||||||
{label && (
|
{label && (
|
||||||
@@ -58,8 +62,7 @@ export const TextInput = forwardRef<HTMLInputElement, TextInputProps>(
|
|||||||
/>
|
/>
|
||||||
</div>
|
</div>
|
||||||
);
|
);
|
||||||
},
|
});
|
||||||
);
|
|
||||||
|
|
||||||
// 👇 Helpful for debugging
|
// 👇 Helpful for debugging
|
||||||
TextInput.displayName = 'TextInput';
|
TextInput.displayName = 'TextInput';
|
||||||
@@ -73,7 +76,10 @@ export interface NumberInputProps extends InputProps {
|
|||||||
|
|
||||||
export const NumberInput = (props: NumberInputProps) => {
|
export const NumberInput = (props: NumberInputProps) => {
|
||||||
const { label, value, onChange, placeholder, required, min, max, docKey = null } = props;
|
const { label, value, onChange, placeholder, required, min, max, docKey = null } = props;
|
||||||
const doc = getDoc(docKey);
|
let { doc } = props;
|
||||||
|
if (!doc && docKey) {
|
||||||
|
doc = getDoc(docKey);
|
||||||
|
}
|
||||||
|
|
||||||
// Add controlled internal state to properly handle partial inputs
|
// Add controlled internal state to properly handle partial inputs
|
||||||
const [inputValue, setInputValue] = React.useState<string | number>(value ?? '');
|
const [inputValue, setInputValue] = React.useState<string | number>(value ?? '');
|
||||||
@@ -147,12 +153,17 @@ export interface SelectInputProps extends InputProps {
|
|||||||
|
|
||||||
export const SelectInput = (props: SelectInputProps) => {
|
export const SelectInput = (props: SelectInputProps) => {
|
||||||
const { label, value, onChange, options, docKey = null } = props;
|
const { label, value, onChange, options, docKey = null } = props;
|
||||||
const doc = getDoc(docKey);
|
let { doc } = props;
|
||||||
|
if (!doc && docKey) {
|
||||||
|
doc = getDoc(docKey);
|
||||||
|
}
|
||||||
let selectedOption: SelectOption | undefined;
|
let selectedOption: SelectOption | undefined;
|
||||||
if (options && options.length > 0) {
|
if (options && options.length > 0) {
|
||||||
// see if grouped options
|
// see if grouped options
|
||||||
if ('options' in options[0]) {
|
if ('options' in options[0]) {
|
||||||
selectedOption = (options as GroupedSelectOption[]).flatMap(group => group.options).find(opt => opt.value === value);
|
selectedOption = (options as GroupedSelectOption[])
|
||||||
|
.flatMap(group => group.options)
|
||||||
|
.find(opt => opt.value === value);
|
||||||
} else {
|
} else {
|
||||||
selectedOption = (options as SelectOption[]).find(opt => opt.value === value);
|
selectedOption = (options as SelectOption[]).find(opt => opt.value === value);
|
||||||
}
|
}
|
||||||
@@ -196,10 +207,17 @@ export interface CheckboxProps {
|
|||||||
className?: string;
|
className?: string;
|
||||||
required?: boolean;
|
required?: boolean;
|
||||||
disabled?: boolean;
|
disabled?: boolean;
|
||||||
|
docKey?: string | null;
|
||||||
|
doc?: ConfigDoc | null;
|
||||||
}
|
}
|
||||||
|
|
||||||
export const Checkbox = (props: CheckboxProps) => {
|
export const Checkbox = (props: CheckboxProps) => {
|
||||||
const { label, checked, onChange, required, disabled } = props;
|
const { label, checked, onChange, required, disabled } = props;
|
||||||
|
let { doc } = props;
|
||||||
|
if (!doc && props.docKey) {
|
||||||
|
doc = getDoc(props.docKey);
|
||||||
|
}
|
||||||
|
|
||||||
const id = React.useId();
|
const id = React.useId();
|
||||||
|
|
||||||
return (
|
return (
|
||||||
@@ -227,6 +245,7 @@ export const Checkbox = (props: CheckboxProps) => {
|
|||||||
/>
|
/>
|
||||||
</button>
|
</button>
|
||||||
{label && (
|
{label && (
|
||||||
|
<>
|
||||||
<label
|
<label
|
||||||
htmlFor={id}
|
htmlFor={id}
|
||||||
className={classNames(
|
className={classNames(
|
||||||
@@ -236,6 +255,12 @@ export const Checkbox = (props: CheckboxProps) => {
|
|||||||
>
|
>
|
||||||
{label}
|
{label}
|
||||||
</label>
|
</label>
|
||||||
|
{doc && (
|
||||||
|
<div className="inline-block ml-1 text-xs text-gray-500 cursor-pointer" onClick={() => openDoc(doc)}>
|
||||||
|
<CircleHelp className="inline-block w-4 h-4 cursor-pointer" />
|
||||||
|
</div>
|
||||||
|
)}
|
||||||
|
</>
|
||||||
)}
|
)}
|
||||||
</div>
|
</div>
|
||||||
);
|
);
|
||||||
@@ -244,12 +269,17 @@ export const Checkbox = (props: CheckboxProps) => {
|
|||||||
interface FormGroupProps {
|
interface FormGroupProps {
|
||||||
label?: string;
|
label?: string;
|
||||||
className?: string;
|
className?: string;
|
||||||
docKey?: string;
|
docKey?: string | null;
|
||||||
|
doc?: ConfigDoc | null;
|
||||||
children: React.ReactNode;
|
children: React.ReactNode;
|
||||||
}
|
}
|
||||||
|
|
||||||
export const FormGroup: React.FC<FormGroupProps> = ({ label, className, children, docKey = null }) => {
|
export const FormGroup: React.FC<FormGroupProps> = props => {
|
||||||
const doc = getDoc(docKey);
|
const { label, className, children, docKey = null } = props;
|
||||||
|
let { doc } = props;
|
||||||
|
if (!doc && docKey) {
|
||||||
|
doc = getDoc(docKey);
|
||||||
|
}
|
||||||
return (
|
return (
|
||||||
<div className={classNames(className)}>
|
<div className={classNames(className)}>
|
||||||
{label && (
|
{label && (
|
||||||
|
|||||||
@@ -74,6 +74,16 @@ const docs: { [key: string]: ConfigDoc } = {
|
|||||||
</>
|
</>
|
||||||
),
|
),
|
||||||
},
|
},
|
||||||
|
'datasets.do_i2v': {
|
||||||
|
title: 'Do I2V',
|
||||||
|
description: (
|
||||||
|
<>
|
||||||
|
For video models that can handle both I2V (Image to Video) and T2V (Text to Video), this option sets this dataset
|
||||||
|
to be trained as an I2V dataset. This means that the first frame will be extracted from the video and used as the start image
|
||||||
|
for the video. If this option is not set, the dataset will be treated as a T2V dataset.
|
||||||
|
</>
|
||||||
|
),
|
||||||
|
},
|
||||||
};
|
};
|
||||||
|
|
||||||
export const getDoc = (key: string | null | undefined): ConfigDoc | null => {
|
export const getDoc = (key: string | null | undefined): ConfigDoc | null => {
|
||||||
|
|||||||
@@ -86,6 +86,7 @@ export interface DatasetConfig {
|
|||||||
control_path: string | null;
|
control_path: string | null;
|
||||||
num_frames: number;
|
num_frames: number;
|
||||||
shrink_video_to_frames: boolean;
|
shrink_video_to_frames: boolean;
|
||||||
|
do_i2v: boolean;
|
||||||
}
|
}
|
||||||
|
|
||||||
export interface EMAConfig {
|
export interface EMAConfig {
|
||||||
|
|||||||
@@ -1 +1 @@
|
|||||||
VERSION = "0.3.17"
|
VERSION = "0.3.18"
|
||||||
Reference in New Issue
Block a user