mirror of
https://github.com/ostris/ai-toolkit.git
synced 2026-01-26 16:39:47 +00:00
Add ability to designate a dataset as i2v or t2v for models that support it
This commit is contained in:
@@ -524,6 +524,16 @@ export default function SimpleJob({
|
||||
checked={dataset.is_reg || false}
|
||||
onChange={value => setJobConfig(value, `config.process[0].datasets[${i}].is_reg`)}
|
||||
/>
|
||||
{
|
||||
modelArch?.additionalSections?.includes('datasets.do_i2v') && (
|
||||
<Checkbox
|
||||
label="Do I2V"
|
||||
checked={dataset.do_i2v || false}
|
||||
onChange={value => setJobConfig(value, `config.process[0].datasets[${i}].do_i2v`)}
|
||||
docKey="datasets.do_i2v"
|
||||
/>
|
||||
)
|
||||
}
|
||||
</FormGroup>
|
||||
</div>
|
||||
<div>
|
||||
|
||||
@@ -15,6 +15,7 @@ export const defaultDatasetConfig: DatasetConfig = {
|
||||
controls: [],
|
||||
shrink_video_to_frames: true,
|
||||
num_frames: 1,
|
||||
do_i2v: true,
|
||||
};
|
||||
|
||||
export const defaultJobConfig: JobConfig = {
|
||||
|
||||
@@ -3,7 +3,7 @@ import { GroupedSelectOption, SelectOption } from '@/types';
|
||||
type Control = 'depth' | 'line' | 'pose' | 'inpaint';
|
||||
|
||||
type DisableableSections = 'model.quantize' | 'train.timestep_type' | 'network.conv';
|
||||
type AdditionalSections = 'datasets.control_path' | 'sample.ctrl_img' | 'datasets.num_frames' | 'model.low_vram';
|
||||
type AdditionalSections = 'datasets.control_path' | 'datasets.do_i2v' | 'sample.ctrl_img' | 'datasets.num_frames' | 'model.low_vram';
|
||||
type ModelGroup = 'image' | 'video';
|
||||
|
||||
export interface ModelArch {
|
||||
@@ -201,7 +201,7 @@ export const modelArchs: ModelArch[] = [
|
||||
'config.process[0].train.timestep_type': ['weighted', 'sigmoid'],
|
||||
},
|
||||
disableSections: ['network.conv'],
|
||||
additionalSections: ['sample.ctrl_img', 'datasets.num_frames', 'model.low_vram'],
|
||||
additionalSections: ['sample.ctrl_img', 'datasets.num_frames', 'model.low_vram', 'datasets.do_i2v'],
|
||||
},
|
||||
{
|
||||
name: 'lumina2',
|
||||
|
||||
@@ -6,7 +6,7 @@ import dynamic from 'next/dynamic';
|
||||
import { CircleHelp } from 'lucide-react';
|
||||
import { getDoc } from '@/docs';
|
||||
import { openDoc } from '@/components/DocModal';
|
||||
import { GroupedSelectOption, SelectOption } from '@/types';
|
||||
import { ConfigDoc, GroupedSelectOption, SelectOption } from '@/types';
|
||||
|
||||
const Select = dynamic(() => import('react-select'), { ssr: false });
|
||||
|
||||
@@ -16,7 +16,8 @@ const inputClasses =
|
||||
|
||||
export interface InputProps {
|
||||
label?: string;
|
||||
docKey?: string;
|
||||
docKey?: string | null;
|
||||
doc?: ConfigDoc | null;
|
||||
className?: string;
|
||||
placeholder?: string;
|
||||
required?: boolean;
|
||||
@@ -29,37 +30,39 @@ export interface TextInputProps extends InputProps {
|
||||
disabled?: boolean;
|
||||
}
|
||||
|
||||
export const TextInput = forwardRef<HTMLInputElement, TextInputProps>(
|
||||
({ label, value, onChange, placeholder, required, disabled, type = 'text', className, docKey = null }, ref) => {
|
||||
const doc = getDoc(docKey);
|
||||
return (
|
||||
<div className={classNames(className)}>
|
||||
{label && (
|
||||
<label className={labelClasses}>
|
||||
{label}{' '}
|
||||
{doc && (
|
||||
<div className="inline-block ml-1 text-xs text-gray-500 cursor-pointer" onClick={() => openDoc(doc)}>
|
||||
<CircleHelp className="inline-block w-4 h-4 cursor-pointer" />
|
||||
</div>
|
||||
)}
|
||||
</label>
|
||||
)}
|
||||
<input
|
||||
ref={ref}
|
||||
type={type}
|
||||
value={value}
|
||||
onChange={e => {
|
||||
if (!disabled) onChange(e.target.value);
|
||||
}}
|
||||
className={`${inputClasses} ${disabled ? 'opacity-30 cursor-not-allowed' : ''}`}
|
||||
placeholder={placeholder}
|
||||
required={required}
|
||||
disabled={disabled}
|
||||
/>
|
||||
</div>
|
||||
);
|
||||
},
|
||||
);
|
||||
export const TextInput = forwardRef<HTMLInputElement, TextInputProps>((props: TextInputProps, ref) => {
|
||||
const { label, value, onChange, placeholder, required, disabled, type = 'text', className, docKey = null } = props;
|
||||
let { doc } = props;
|
||||
if (!doc && docKey) {
|
||||
doc = getDoc(docKey);
|
||||
}
|
||||
return (
|
||||
<div className={classNames(className)}>
|
||||
{label && (
|
||||
<label className={labelClasses}>
|
||||
{label}{' '}
|
||||
{doc && (
|
||||
<div className="inline-block ml-1 text-xs text-gray-500 cursor-pointer" onClick={() => openDoc(doc)}>
|
||||
<CircleHelp className="inline-block w-4 h-4 cursor-pointer" />
|
||||
</div>
|
||||
)}
|
||||
</label>
|
||||
)}
|
||||
<input
|
||||
ref={ref}
|
||||
type={type}
|
||||
value={value}
|
||||
onChange={e => {
|
||||
if (!disabled) onChange(e.target.value);
|
||||
}}
|
||||
className={`${inputClasses} ${disabled ? 'opacity-30 cursor-not-allowed' : ''}`}
|
||||
placeholder={placeholder}
|
||||
required={required}
|
||||
disabled={disabled}
|
||||
/>
|
||||
</div>
|
||||
);
|
||||
});
|
||||
|
||||
// 👇 Helpful for debugging
|
||||
TextInput.displayName = 'TextInput';
|
||||
@@ -73,7 +76,10 @@ export interface NumberInputProps extends InputProps {
|
||||
|
||||
export const NumberInput = (props: NumberInputProps) => {
|
||||
const { label, value, onChange, placeholder, required, min, max, docKey = null } = props;
|
||||
const doc = getDoc(docKey);
|
||||
let { doc } = props;
|
||||
if (!doc && docKey) {
|
||||
doc = getDoc(docKey);
|
||||
}
|
||||
|
||||
// Add controlled internal state to properly handle partial inputs
|
||||
const [inputValue, setInputValue] = React.useState<string | number>(value ?? '');
|
||||
@@ -147,12 +153,17 @@ export interface SelectInputProps extends InputProps {
|
||||
|
||||
export const SelectInput = (props: SelectInputProps) => {
|
||||
const { label, value, onChange, options, docKey = null } = props;
|
||||
const doc = getDoc(docKey);
|
||||
let { doc } = props;
|
||||
if (!doc && docKey) {
|
||||
doc = getDoc(docKey);
|
||||
}
|
||||
let selectedOption: SelectOption | undefined;
|
||||
if (options && options.length > 0) {
|
||||
// see if grouped options
|
||||
if ('options' in options[0]) {
|
||||
selectedOption = (options as GroupedSelectOption[]).flatMap(group => group.options).find(opt => opt.value === value);
|
||||
selectedOption = (options as GroupedSelectOption[])
|
||||
.flatMap(group => group.options)
|
||||
.find(opt => opt.value === value);
|
||||
} else {
|
||||
selectedOption = (options as SelectOption[]).find(opt => opt.value === value);
|
||||
}
|
||||
@@ -196,10 +207,17 @@ export interface CheckboxProps {
|
||||
className?: string;
|
||||
required?: boolean;
|
||||
disabled?: boolean;
|
||||
docKey?: string | null;
|
||||
doc?: ConfigDoc | null;
|
||||
}
|
||||
|
||||
export const Checkbox = (props: CheckboxProps) => {
|
||||
const { label, checked, onChange, required, disabled } = props;
|
||||
let { doc } = props;
|
||||
if (!doc && props.docKey) {
|
||||
doc = getDoc(props.docKey);
|
||||
}
|
||||
|
||||
const id = React.useId();
|
||||
|
||||
return (
|
||||
@@ -227,15 +245,22 @@ export const Checkbox = (props: CheckboxProps) => {
|
||||
/>
|
||||
</button>
|
||||
{label && (
|
||||
<label
|
||||
htmlFor={id}
|
||||
className={classNames(
|
||||
'text-sm font-medium cursor-pointer select-none',
|
||||
disabled ? 'text-gray-500' : 'text-gray-300',
|
||||
<>
|
||||
<label
|
||||
htmlFor={id}
|
||||
className={classNames(
|
||||
'text-sm font-medium cursor-pointer select-none',
|
||||
disabled ? 'text-gray-500' : 'text-gray-300',
|
||||
)}
|
||||
>
|
||||
{label}
|
||||
</label>
|
||||
{doc && (
|
||||
<div className="inline-block ml-1 text-xs text-gray-500 cursor-pointer" onClick={() => openDoc(doc)}>
|
||||
<CircleHelp className="inline-block w-4 h-4 cursor-pointer" />
|
||||
</div>
|
||||
)}
|
||||
>
|
||||
{label}
|
||||
</label>
|
||||
</>
|
||||
)}
|
||||
</div>
|
||||
);
|
||||
@@ -244,12 +269,17 @@ export const Checkbox = (props: CheckboxProps) => {
|
||||
interface FormGroupProps {
|
||||
label?: string;
|
||||
className?: string;
|
||||
docKey?: string;
|
||||
docKey?: string | null;
|
||||
doc?: ConfigDoc | null;
|
||||
children: React.ReactNode;
|
||||
}
|
||||
|
||||
export const FormGroup: React.FC<FormGroupProps> = ({ label, className, children, docKey = null }) => {
|
||||
const doc = getDoc(docKey);
|
||||
export const FormGroup: React.FC<FormGroupProps> = props => {
|
||||
const { label, className, children, docKey = null } = props;
|
||||
let { doc } = props;
|
||||
if (!doc && docKey) {
|
||||
doc = getDoc(docKey);
|
||||
}
|
||||
return (
|
||||
<div className={classNames(className)}>
|
||||
{label && (
|
||||
|
||||
@@ -74,6 +74,16 @@ const docs: { [key: string]: ConfigDoc } = {
|
||||
</>
|
||||
),
|
||||
},
|
||||
'datasets.do_i2v': {
|
||||
title: 'Do I2V',
|
||||
description: (
|
||||
<>
|
||||
For video models that can handle both I2V (Image to Video) and T2V (Text to Video), this option sets this dataset
|
||||
to be trained as an I2V dataset. This means that the first frame will be extracted from the video and used as the start image
|
||||
for the video. If this option is not set, the dataset will be treated as a T2V dataset.
|
||||
</>
|
||||
),
|
||||
},
|
||||
};
|
||||
|
||||
export const getDoc = (key: string | null | undefined): ConfigDoc | null => {
|
||||
|
||||
@@ -86,6 +86,7 @@ export interface DatasetConfig {
|
||||
control_path: string | null;
|
||||
num_frames: number;
|
||||
shrink_video_to_frames: boolean;
|
||||
do_i2v: boolean;
|
||||
}
|
||||
|
||||
export interface EMAConfig {
|
||||
|
||||
Reference in New Issue
Block a user