Added support for caching text embeddings. This is just initial support and will probably fail for some models. Still needs to be optimized

This commit is contained in:
Jaret Burkett
2025-08-07 10:27:55 -06:00
parent 4c4a10d439
commit bb6db3d635
16 changed files with 485 additions and 195 deletions

View File

@@ -389,22 +389,40 @@ export default function SimpleJob({
onChange={value => setJobConfig(value, 'config.process[0].train.ema_config.use_ema')}
/>
</FormGroup>
<NumberInput
label="EMA Decay"
className="pt-2"
value={jobConfig.config.process[0].train.ema_config?.ema_decay as number}
onChange={value => setJobConfig(value, 'config.process[0].train.ema_config?.ema_decay')}
placeholder="eg. 0.99"
min={0}
/>
<FormGroup label="Unload Text Encoder" className="pt-2">
<div className="grid grid-cols-2 gap-2">
<Checkbox
label="Unload TE"
checked={jobConfig.config.process[0].train.unload_text_encoder || false}
onChange={value => setJobConfig(value, 'config.process[0].train.unload_text_encoder')}
/>
</div>
{jobConfig.config.process[0].train.ema_config?.use_ema && (
<NumberInput
label="EMA Decay"
className="pt-2"
value={jobConfig.config.process[0].train.ema_config?.ema_decay as number}
onChange={value => setJobConfig(value, 'config.process[0].train.ema_config?.ema_decay')}
placeholder="eg. 0.99"
min={0}
/>
)}
<FormGroup label="Text Encoder Optimizations" className="pt-2">
<Checkbox
label="Unload TE"
checked={jobConfig.config.process[0].train.unload_text_encoder || false}
docKey={'train.unload_text_encoder'}
onChange={(value) => {
setJobConfig(value, 'config.process[0].train.unload_text_encoder')
if (value) {
setJobConfig(false, 'config.process[0].train.cache_text_embeddings');
}
}}
/>
<Checkbox
label="Cache Text Embeddings"
checked={jobConfig.config.process[0].train.cache_text_embeddings || false}
docKey={'train.cache_text_embeddings'}
onChange={(value) => {
setJobConfig(value, 'config.process[0].train.cache_text_embeddings')
if (value) {
setJobConfig(false, 'config.process[0].train.unload_text_encoder')
}
}}
/>
</FormGroup>
</div>
<div>
@@ -416,21 +434,27 @@ export default function SimpleJob({
onChange={value => setJobConfig(value, 'config.process[0].train.diff_output_preservation')}
/>
</FormGroup>
<NumberInput
label="DOP Loss Multiplier"
className="pt-2"
value={jobConfig.config.process[0].train.diff_output_preservation_multiplier as number}
onChange={value => setJobConfig(value, 'config.process[0].train.diff_output_preservation_multiplier')}
placeholder="eg. 1.0"
min={0}
/>
<TextInput
label="DOP Preservation Class"
className="pt-2"
value={jobConfig.config.process[0].train.diff_output_preservation_class as string}
onChange={value => setJobConfig(value, 'config.process[0].train.diff_output_preservation_class')}
placeholder="eg. woman"
/>
{jobConfig.config.process[0].train.diff_output_preservation && (
<>
<NumberInput
label="DOP Loss Multiplier"
className="pt-2"
value={jobConfig.config.process[0].train.diff_output_preservation_multiplier as number}
onChange={value =>
setJobConfig(value, 'config.process[0].train.diff_output_preservation_multiplier')
}
placeholder="eg. 1.0"
min={0}
/>
<TextInput
label="DOP Preservation Class"
className="pt-2"
value={jobConfig.config.process[0].train.diff_output_preservation_class as string}
onChange={value => setJobConfig(value, 'config.process[0].train.diff_output_preservation_class')}
placeholder="eg. woman"
/>
</>
)}
</div>
</div>
</Card>
@@ -524,16 +548,14 @@ export default function SimpleJob({
checked={dataset.is_reg || false}
onChange={value => setJobConfig(value, `config.process[0].datasets[${i}].is_reg`)}
/>
{
modelArch?.additionalSections?.includes('datasets.do_i2v') && (
<Checkbox
label="Do I2V"
checked={dataset.do_i2v || false}
onChange={value => setJobConfig(value, `config.process[0].datasets[${i}].do_i2v`)}
docKey="datasets.do_i2v"
/>
)
}
{modelArch?.additionalSections?.includes('datasets.do_i2v') && (
<Checkbox
label="Do I2V"
checked={dataset.do_i2v || false}
onChange={value => setJobConfig(value, `config.process[0].datasets[${i}].do_i2v`)}
docKey="datasets.do_i2v"
/>
)}
</FormGroup>
</div>
<div>

View File

@@ -66,6 +66,7 @@ export const defaultJobConfig: JobConfig = {
weight_decay: 1e-4,
},
unload_text_encoder: false,
cache_text_embeddings: false,
lr: 0.0001,
ema_config: {
use_ema: false,

View File

@@ -12,12 +12,12 @@ const docs: { [key: string]: ConfigDoc } = {
</>
),
},
'gpuids': {
gpuids: {
title: 'GPU ID',
description: (
<>
This is the GPU that will be used for training. Only one GPU can be used per job at a time via the UI currently.
However, you can start multiple jobs in parallel, each using a different GPU.
This is the GPU that will be used for training. Only one GPU can be used per job at a time via the UI currently.
However, you can start multiple jobs in parallel, each using a different GPU.
</>
),
},
@@ -25,17 +25,19 @@ const docs: { [key: string]: ConfigDoc } = {
title: 'Trigger Word',
description: (
<>
Optional: This will be the word or token used to trigger your concept or character.
<br />
<br />
When using a trigger word,
If your captions do not contain the trigger word, it will be added automatically the beginning of the caption. If you do not have
captions, the caption will become just the trigger word. If you want to have variable trigger words in your captions to put it in different spots,
you can use the <code>{'[trigger]'}</code> placeholder in your captions. This will be automatically replaced with your trigger word.
<br />
<br />
Trigger words will not automatically be added to your test prompts, so you will need to either add your trigger word manually or use the
<code>{'[trigger]'}</code> placeholder in your test prompts as well.
Optional: This will be the word or token used to trigger your concept or character.
<br />
<br />
When using a trigger word, if your captions do not contain the trigger word, it will be added automatically to
the beginning of the caption. If you do not have captions, the caption will become just the trigger word. If you
want to have variable trigger words in your captions to put it in different spots, you can use the{' '}
<code>{'[trigger]'}</code> placeholder in your captions. This will be automatically replaced with your trigger
word.
<br />
<br />
Trigger words will not automatically be added to your test prompts, so you will need to either add your trigger
word manually or use the{' '}
<code>{'[trigger]'}</code> placeholder in your test prompts as well.
</>
),
},
@@ -43,8 +45,9 @@ const docs: { [key: string]: ConfigDoc } = {
title: 'Name or Path',
description: (
<>
The name of a diffusers repo on Huggingface or the local path to the base model you want to train from. The folder needs to be in
diffusers format for most models. For some models, such as SDXL and SD1, you can put the path to an all in one safetensors checkpoint here.
The name of a diffusers repo on Huggingface or the local path to the base model you want to train from. The
folder needs to be in diffusers format for most models. For some models, such as SDXL and SD1, you can put the
path to an all in one safetensors checkpoint here.
</>
),
},
@@ -52,8 +55,8 @@ const docs: { [key: string]: ConfigDoc } = {
title: 'Control Dataset',
description: (
<>
The control dataset needs to have files that match the filenames of your training dataset. They should be matching file pairs.
These images are fed as control/input images during training.
The control dataset needs to have files that match the filenames of your training dataset. They should be
matching file pairs. These images are fed as control/input images during training.
</>
),
},
@@ -61,16 +64,19 @@ const docs: { [key: string]: ConfigDoc } = {
title: 'Number of Frames',
description: (
<>
This sets the number of frames to shrink videos to for a video dataset. If this dataset is images, set this to 1 for one frame.
If your dataset is only videos, frames will be extracted evenly spaced from the videos in the dataset.
<br/>
<br/>
It is best to trim your videos to the proper length before training. Wan is 16 frames a second. Doing 81 frames will result in a 5 second video.
So you would want all of your videos trimmed to around 5 seconds for best results.
<br/>
<br/>
Example: Setting this to 81 and having 2 videos in your dataset, one is 2 seconds and one is 90 seconds long, will result in 81
evenly spaced frames for each video making the 2 second video appear slow and the 90second video appear very fast.
This sets the number of frames to shrink videos to for a video dataset. If this dataset is images, set this to 1
for one frame. If your dataset is only videos, frames will be extracted evenly spaced from the videos in the
dataset.
<br />
<br />
It is best to trim your videos to the proper length before training. Wan is 16 frames a second. Doing 81 frames
will result in a 5 second video. So you would want all of your videos trimmed to around 5 seconds for best
results.
<br />
<br />
Example: Setting this to 81 and having 2 videos in your dataset, one is 2 seconds and one is 90 seconds long,
will result in 81 evenly spaced frames for each video, making the 2-second video appear slow and the 90-second
video appear very fast.
</>
),
},
@@ -78,9 +84,30 @@ const docs: { [key: string]: ConfigDoc } = {
title: 'Do I2V',
description: (
<>
For video models that can handle both I2V (Image to Video) and T2V (Text to Video), this option sets this dataset
to be trained as an I2V dataset. This means that the first frame will be extracted from the video and used as the start image
for the video. If this option is not set, the dataset will be treated as a T2V dataset.
For video models that can handle both I2V (Image to Video) and T2V (Text to Video), this option sets this
dataset to be trained as an I2V dataset. This means that the first frame will be extracted from the video and
used as the start image for the video. If this option is not set, the dataset will be treated as a T2V dataset.
</>
),
},
'train.unload_text_encoder': {
title: 'Unload Text Encoder',
description: (
<>
Unloading the text encoder will cache the trigger word and the sample prompts and unload the text encoder from
the GPU. Captions for the dataset will be ignored.
</>
),
},
'train.cache_text_embeddings': {
title: 'Cache Text Embeddings',
description: (
<>
<small>(experimental)</small>
<br />
Caching text embeddings will process and cache all the text embeddings from the text encoder to the disk. The
text encoder will be unloaded from the GPU. This does not work with things that dynamically change the prompt
such as trigger words, caption dropout, etc.
</>
),
},

View File

@@ -110,6 +110,7 @@ export interface TrainConfig {
ema_config?: EMAConfig;
dtype: string;
unload_text_encoder: boolean;
cache_text_embeddings: boolean;
optimizer_params: {
weight_decay: number;
};