mirror of
https://github.com/ostris/ai-toolkit.git
synced 2026-04-25 08:49:14 +00:00
Added support for caching text embeddings. This is just initial support and will probably fail for some models. Still needs to be optimized.
This commit is contained in:
@@ -389,22 +389,40 @@ export default function SimpleJob({
|
||||
onChange={value => setJobConfig(value, 'config.process[0].train.ema_config.use_ema')}
|
||||
/>
|
||||
</FormGroup>
|
||||
<NumberInput
|
||||
label="EMA Decay"
|
||||
className="pt-2"
|
||||
value={jobConfig.config.process[0].train.ema_config?.ema_decay as number}
|
||||
onChange={value => setJobConfig(value, 'config.process[0].train.ema_config?.ema_decay')}
|
||||
placeholder="eg. 0.99"
|
||||
min={0}
|
||||
/>
|
||||
<FormGroup label="Unload Text Encoder" className="pt-2">
|
||||
<div className="grid grid-cols-2 gap-2">
|
||||
<Checkbox
|
||||
label="Unload TE"
|
||||
checked={jobConfig.config.process[0].train.unload_text_encoder || false}
|
||||
onChange={value => setJobConfig(value, 'config.process[0].train.unload_text_encoder')}
|
||||
/>
|
||||
</div>
|
||||
{jobConfig.config.process[0].train.ema_config?.use_ema && (
|
||||
<NumberInput
|
||||
label="EMA Decay"
|
||||
className="pt-2"
|
||||
value={jobConfig.config.process[0].train.ema_config?.ema_decay as number}
|
||||
onChange={value => setJobConfig(value, 'config.process[0].train.ema_config?.ema_decay')}
|
||||
placeholder="eg. 0.99"
|
||||
min={0}
|
||||
/>
|
||||
)}
|
||||
|
||||
<FormGroup label="Text Encoder Optimizations" className="pt-2">
|
||||
<Checkbox
|
||||
label="Unload TE"
|
||||
checked={jobConfig.config.process[0].train.unload_text_encoder || false}
|
||||
docKey={'train.unload_text_encoder'}
|
||||
onChange={(value) => {
|
||||
setJobConfig(value, 'config.process[0].train.unload_text_encoder')
|
||||
if (value) {
|
||||
setJobConfig(false, 'config.process[0].train.cache_text_embeddings');
|
||||
}
|
||||
}}
|
||||
/>
|
||||
<Checkbox
|
||||
label="Cache Text Embeddings"
|
||||
checked={jobConfig.config.process[0].train.cache_text_embeddings || false}
|
||||
docKey={'train.cache_text_embeddings'}
|
||||
onChange={(value) => {
|
||||
setJobConfig(value, 'config.process[0].train.cache_text_embeddings')
|
||||
if (value) {
|
||||
setJobConfig(false, 'config.process[0].train.unload_text_encoder')
|
||||
}
|
||||
}}
|
||||
/>
|
||||
</FormGroup>
|
||||
</div>
|
||||
<div>
|
||||
@@ -416,21 +434,27 @@ export default function SimpleJob({
|
||||
onChange={value => setJobConfig(value, 'config.process[0].train.diff_output_preservation')}
|
||||
/>
|
||||
</FormGroup>
|
||||
<NumberInput
|
||||
label="DOP Loss Multiplier"
|
||||
className="pt-2"
|
||||
value={jobConfig.config.process[0].train.diff_output_preservation_multiplier as number}
|
||||
onChange={value => setJobConfig(value, 'config.process[0].train.diff_output_preservation_multiplier')}
|
||||
placeholder="eg. 1.0"
|
||||
min={0}
|
||||
/>
|
||||
<TextInput
|
||||
label="DOP Preservation Class"
|
||||
className="pt-2"
|
||||
value={jobConfig.config.process[0].train.diff_output_preservation_class as string}
|
||||
onChange={value => setJobConfig(value, 'config.process[0].train.diff_output_preservation_class')}
|
||||
placeholder="eg. woman"
|
||||
/>
|
||||
{jobConfig.config.process[0].train.diff_output_preservation && (
|
||||
<>
|
||||
<NumberInput
|
||||
label="DOP Loss Multiplier"
|
||||
className="pt-2"
|
||||
value={jobConfig.config.process[0].train.diff_output_preservation_multiplier as number}
|
||||
onChange={value =>
|
||||
setJobConfig(value, 'config.process[0].train.diff_output_preservation_multiplier')
|
||||
}
|
||||
placeholder="eg. 1.0"
|
||||
min={0}
|
||||
/>
|
||||
<TextInput
|
||||
label="DOP Preservation Class"
|
||||
className="pt-2"
|
||||
value={jobConfig.config.process[0].train.diff_output_preservation_class as string}
|
||||
onChange={value => setJobConfig(value, 'config.process[0].train.diff_output_preservation_class')}
|
||||
placeholder="eg. woman"
|
||||
/>
|
||||
</>
|
||||
)}
|
||||
</div>
|
||||
</div>
|
||||
</Card>
|
||||
@@ -524,16 +548,14 @@ export default function SimpleJob({
|
||||
checked={dataset.is_reg || false}
|
||||
onChange={value => setJobConfig(value, `config.process[0].datasets[${i}].is_reg`)}
|
||||
/>
|
||||
{
|
||||
modelArch?.additionalSections?.includes('datasets.do_i2v') && (
|
||||
<Checkbox
|
||||
label="Do I2V"
|
||||
checked={dataset.do_i2v || false}
|
||||
onChange={value => setJobConfig(value, `config.process[0].datasets[${i}].do_i2v`)}
|
||||
docKey="datasets.do_i2v"
|
||||
/>
|
||||
)
|
||||
}
|
||||
{modelArch?.additionalSections?.includes('datasets.do_i2v') && (
|
||||
<Checkbox
|
||||
label="Do I2V"
|
||||
checked={dataset.do_i2v || false}
|
||||
onChange={value => setJobConfig(value, `config.process[0].datasets[${i}].do_i2v`)}
|
||||
docKey="datasets.do_i2v"
|
||||
/>
|
||||
)}
|
||||
</FormGroup>
|
||||
</div>
|
||||
<div>
|
||||
|
||||
@@ -66,6 +66,7 @@ export const defaultJobConfig: JobConfig = {
|
||||
weight_decay: 1e-4,
|
||||
},
|
||||
unload_text_encoder: false,
|
||||
cache_text_embeddings: false,
|
||||
lr: 0.0001,
|
||||
ema_config: {
|
||||
use_ema: false,
|
||||
|
||||
@@ -12,12 +12,12 @@ const docs: { [key: string]: ConfigDoc } = {
|
||||
</>
|
||||
),
|
||||
},
|
||||
'gpuids': {
|
||||
gpuids: {
|
||||
title: 'GPU ID',
|
||||
description: (
|
||||
<>
|
||||
This is the GPU that will be used for training. Only one GPU can be used per job at a time via the UI currently.
|
||||
However, you can start multiple jobs in parallel, each using a different GPU.
|
||||
This is the GPU that will be used for training. Only one GPU can be used per job at a time via the UI currently.
|
||||
However, you can start multiple jobs in parallel, each using a different GPU.
|
||||
</>
|
||||
),
|
||||
},
|
||||
@@ -25,17 +25,19 @@ const docs: { [key: string]: ConfigDoc } = {
|
||||
title: 'Trigger Word',
|
||||
description: (
|
||||
<>
|
||||
Optional: This will be the word or token used to trigger your concept or character.
|
||||
<br />
|
||||
<br />
|
||||
When using a trigger word,
|
||||
If your captions do not contain the trigger word, it will be added automatically the beginning of the caption. If you do not have
|
||||
captions, the caption will become just the trigger word. If you want to have variable trigger words in your captions to put it in different spots,
|
||||
you can use the <code>{'[trigger]'}</code> placeholder in your captions. This will be automatically replaced with your trigger word.
|
||||
<br />
|
||||
<br />
|
||||
Trigger words will not automatically be added to your test prompts, so you will need to either add your trigger word manually or use the
|
||||
<code>{'[trigger]'}</code> placeholder in your test prompts as well.
|
||||
Optional: This will be the word or token used to trigger your concept or character.
|
||||
<br />
|
||||
<br />
|
||||
When using a trigger word, If your captions do not contain the trigger word, it will be added automatically the
|
||||
beginning of the caption. If you do not have captions, the caption will become just the trigger word. If you
|
||||
want to have variable trigger words in your captions to put it in different spots, you can use the{' '}
|
||||
<code>{'[trigger]'}</code> placeholder in your captions. This will be automatically replaced with your trigger
|
||||
word.
|
||||
<br />
|
||||
<br />
|
||||
Trigger words will not automatically be added to your test prompts, so you will need to either add your trigger
|
||||
word manually or use the
|
||||
<code>{'[trigger]'}</code> placeholder in your test prompts as well.
|
||||
</>
|
||||
),
|
||||
},
|
||||
@@ -43,8 +45,9 @@ const docs: { [key: string]: ConfigDoc } = {
|
||||
title: 'Name or Path',
|
||||
description: (
|
||||
<>
|
||||
The name of a diffusers repo on Huggingface or the local path to the base model you want to train from. The folder needs to be in
|
||||
diffusers format for most models. For some models, such as SDXL and SD1, you can put the path to an all in one safetensors checkpoint here.
|
||||
The name of a diffusers repo on Huggingface or the local path to the base model you want to train from. The
|
||||
folder needs to be in diffusers format for most models. For some models, such as SDXL and SD1, you can put the
|
||||
path to an all in one safetensors checkpoint here.
|
||||
</>
|
||||
),
|
||||
},
|
||||
@@ -52,8 +55,8 @@ const docs: { [key: string]: ConfigDoc } = {
|
||||
title: 'Control Dataset',
|
||||
description: (
|
||||
<>
|
||||
The control dataset needs to have files that match the filenames of your training dataset. They should be matching file pairs.
|
||||
These images are fed as control/input images during training.
|
||||
The control dataset needs to have files that match the filenames of your training dataset. They should be
|
||||
matching file pairs. These images are fed as control/input images during training.
|
||||
</>
|
||||
),
|
||||
},
|
||||
@@ -61,16 +64,19 @@ const docs: { [key: string]: ConfigDoc } = {
|
||||
title: 'Number of Frames',
|
||||
description: (
|
||||
<>
|
||||
This sets the number of frames to shrink videos to for a video dataset. If this dataset is images, set this to 1 for one frame.
|
||||
If your dataset is only videos, frames will be extracted evenly spaced from the videos in the dataset.
|
||||
<br/>
|
||||
<br/>
|
||||
It is best to trim your videos to the proper length before training. Wan is 16 frames a second. Doing 81 frames will result in a 5 second video.
|
||||
So you would want all of your videos trimmed to around 5 seconds for best results.
|
||||
<br/>
|
||||
<br/>
|
||||
Example: Setting this to 81 and having 2 videos in your dataset, one is 2 seconds and one is 90 seconds long, will result in 81
|
||||
evenly spaced frames for each video making the 2 second video appear slow and the 90second video appear very fast.
|
||||
This sets the number of frames to shrink videos to for a video dataset. If this dataset is images, set this to 1
|
||||
for one frame. If your dataset is only videos, frames will be extracted evenly spaced from the videos in the
|
||||
dataset.
|
||||
<br />
|
||||
<br />
|
||||
It is best to trim your videos to the proper length before training. Wan is 16 frames a second. Doing 81 frames
|
||||
will result in a 5 second video. So you would want all of your videos trimmed to around 5 seconds for best
|
||||
results.
|
||||
<br />
|
||||
<br />
|
||||
Example: Setting this to 81 and having 2 videos in your dataset, one is 2 seconds and one is 90 seconds long,
|
||||
will result in 81 evenly spaced frames for each video making the 2 second video appear slow and the 90second
|
||||
video appear very fast.
|
||||
</>
|
||||
),
|
||||
},
|
||||
@@ -78,9 +84,30 @@ const docs: { [key: string]: ConfigDoc } = {
|
||||
title: 'Do I2V',
|
||||
description: (
|
||||
<>
|
||||
For video models that can handle both I2V (Image to Video) and T2V (Text to Video), this option sets this dataset
|
||||
to be trained as an I2V dataset. This means that the first frame will be extracted from the video and used as the start image
|
||||
for the video. If this option is not set, the dataset will be treated as a T2V dataset.
|
||||
For video models that can handle both I2V (Image to Video) and T2V (Text to Video), this option sets this
|
||||
dataset to be trained as an I2V dataset. This means that the first frame will be extracted from the video and
|
||||
used as the start image for the video. If this option is not set, the dataset will be treated as a T2V dataset.
|
||||
</>
|
||||
),
|
||||
},
|
||||
'train.unload_text_encoder': {
|
||||
title: 'Unload Text Encoder',
|
||||
description: (
|
||||
<>
|
||||
Unloading text encoder will cache the trigger word and the sample prompts and unload the text encoder from the
|
||||
GPU. Captions in for the dataset will be ignored
|
||||
</>
|
||||
),
|
||||
},
|
||||
'train.cache_text_embeddings': {
|
||||
title: 'Cache Text Embeddings',
|
||||
description: (
|
||||
<>
|
||||
<small>(experimental)</small>
|
||||
<br />
|
||||
Caching text embeddings will process and cache all the text embeddings from the text encoder to the disk. The
|
||||
text encoder will be unloaded from the GPU. This does not work with things that dynamically change the prompt
|
||||
such as trigger words, caption dropout, etc.
|
||||
</>
|
||||
),
|
||||
},
|
||||
|
||||
@@ -110,6 +110,7 @@ export interface TrainConfig {
|
||||
ema_config?: EMAConfig;
|
||||
dtype: string;
|
||||
unload_text_encoder: boolean;
|
||||
cache_text_embeddings: boolean;
|
||||
optimizer_params: {
|
||||
weight_decay: number;
|
||||
};
|
||||
|
||||
Reference in New Issue
Block a user