Added support for caching text embeddings. This is initial support and will probably fail for some models. It still needs to be optimized.
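The hunks below wire the cache into the trainer but do not show the on-disk format. As a minimal sketch of the general idea — compute each embedding once and persist it keyed by a hash of the prompt text — where cache_dir and encode_fn are illustrative placeholders, not names from this repo:

import hashlib
import os

import torch


def get_cached_embedding(prompt: str, cache_dir: str, encode_fn):
    # key the cache entry by a stable hash of the prompt text
    key = hashlib.sha256(prompt.encode("utf-8")).hexdigest()
    path = os.path.join(cache_dir, f"{key}.pt")
    if os.path.exists(path):
        # reuse the previously computed embedding
        return torch.load(path, map_location="cpu")
    # compute once with the text encoder, then persist to disk
    with torch.no_grad():
        embeds = encode_fn(prompt)
    os.makedirs(cache_dir, exist_ok=True)
    torch.save(embeds, path)
    return embeds

Keeping the cached tensors on the CPU is what later allows the text encoder to be unloaded entirely.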

Jaret Burkett
2025-08-07 10:27:55 -06:00
parent 4c4a10d439
commit bb6db3d635
16 changed files with 485 additions and 195 deletions


@@ -168,7 +168,9 @@ class QwenImageModel(BaseModel):
         text_encoder = [pipe.text_encoder]
         tokenizer = [pipe.tokenizer]

-        pipe.transformer = pipe.transformer.to(self.device_torch)
+        # leave it on cpu for now
+        if not self.low_vram:
+            pipe.transformer = pipe.transformer.to(self.device_torch)

         flush()
         # just to make sure everything is on the right device and dtype
@@ -210,6 +212,7 @@ class QwenImageModel(BaseModel):
         generator: torch.Generator,
         extra: dict,
     ):
+        self.model.to(self.device_torch, dtype=self.torch_dtype)
         control_img = None
         if gen_config.ctrl_img is not None:
             raise NotImplementedError(
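Because low-VRAM mode now leaves the transformer on the CPU (previous hunk), generation has to move the model onto the accelerator explicitly, which is what the added self.model.to(...) line does. A sketch of the general pattern, with run_fn standing in for the actual sampling call (flush()'s body is not in this excerpt; it presumably wraps a cache flush):

import torch


def run_on_gpu_temporarily(model: torch.nn.Module, run_fn, device="cuda", dtype=torch.bfloat16):
    # move the module up only for the duration of the call, then park it back on CPU
    model.to(device, dtype=dtype)
    try:
        return run_fn(model)
    finally:
        model.to("cpu")
        torch.cuda.empty_cache()  # release freed VRAM back to the driver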


@@ -13,7 +13,7 @@ from torch.utils.data import DataLoader, ConcatDataset
 from toolkit import train_tools
 from toolkit.basic import value_map, adain, get_mean_std
 from toolkit.clip_vision_adapter import ClipVisionAdapter
-from toolkit.config_modules import GuidanceConfig
+from toolkit.config_modules import GenerateImageConfig
 from toolkit.data_loader import get_dataloader_datasets
 from toolkit.data_transfer_object.data_loader import DataLoaderBatchDTO, FileItemDTO
 from toolkit.guidance import get_targeted_guidance_loss, get_guidance_loss, GuidanceType
@@ -36,6 +36,7 @@ from toolkit.train_tools import precondition_model_outputs_flow_match
 from toolkit.models.diffusion_feature_extraction import DiffusionFeatureExtractor, load_dfe
 from toolkit.util.wavelet_loss import wavelet_loss
 import torch.nn.functional as F
+from toolkit.unloader import unload_text_encoder


 def flush():
@@ -108,6 +109,33 @@ class SDTrainer(BaseSDTrainProcess):
     def before_model_load(self):
         pass

+    def cache_sample_prompts(self):
+        if self.train_config.disable_sampling:
+            return
+        if self.sample_config is not None and self.sample_config.samples is not None and len(self.sample_config.samples) > 0:
+            # cache all the samples
+            self.sd.sample_prompts_cache = []
+            sample_folder = os.path.join(self.save_root, 'samples')
+            output_path = os.path.join(sample_folder, 'test.jpg')
+            for i in range(len(self.sample_config.prompts)):
+                sample_item = self.sample_config.samples[i]
+                prompt = self.sample_config.prompts[i]
+                # needed so we can autoparse the prompt to handle flags
+                gen_img_config = GenerateImageConfig(
+                    prompt=prompt,  # it will autoparse the prompt
+                    negative_prompt=sample_item.neg,
+                    output_path=output_path,
+                )
+                positive = self.sd.encode_prompt(gen_img_config.prompt).to('cpu')
+                negative = self.sd.encode_prompt(gen_img_config.negative_prompt).to('cpu')
+                self.sd.sample_prompts_cache.append({
+                    'conditional': positive,
+                    'unconditional': negative,
+                })
+
     def before_dataset_load(self):
         self.assistant_adapter = None
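The consuming side of sample_prompts_cache is not part of this diff; a plausible sketch of how a sampler could use the cached pairs, with the denoising call elided:

def sample_from_cache(sd, device, dtype):
    # hypothetical consumer of sd.sample_prompts_cache (not shown in this commit)
    for entry in sd.sample_prompts_cache:
        conditional = entry['conditional'].to(device, dtype=dtype)
        unconditional = entry['unconditional'].to(device, dtype=dtype)
        # run the denoising loop with the pre-encoded embeddings so the
        # text encoder can stay unloaded
        ...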
@@ -143,6 +171,9 @@ class SDTrainer(BaseSDTrainProcess):
     def hook_before_train_loop(self):
         super().hook_before_train_loop()
+        if self.is_caching_text_embeddings:
+            # make sure model is on cpu for this part so we don't oom
+            self.sd.unet.to('cpu')

         # cache unconditional embeds (blank prompt)
         with torch.no_grad():
@@ -195,15 +226,18 @@ class SDTrainer(BaseSDTrainProcess):
             self.negative_prompt_pool = [self.train_config.negative_prompt]

         # handle unload text encoder
-        if self.train_config.unload_text_encoder:
+        if self.train_config.unload_text_encoder or self.is_caching_text_embeddings:
             with torch.no_grad():
                 if self.train_config.train_text_encoder:
                     raise ValueError("Cannot unload text encoder if training text encoder")
                 # cache embeddings
                 print_acc("\n***** UNLOADING TEXT ENCODER *****")
-                print_acc("This will train only with a blank prompt or trigger word, if set")
-                print_acc("If this is not what you want, remove the unload_text_encoder flag")
+                if self.is_caching_text_embeddings:
+                    print_acc("Embeddings cached to disk. We don't need the text encoder anymore")
+                else:
+                    print_acc("This will train only with a blank prompt or trigger word, if set")
+                    print_acc("If this is not what you want, remove the unload_text_encoder flag")
                 print_acc("***********************************")
                 print_acc("")
                 self.sd.text_encoder_to(self.device_torch)
@@ -212,9 +246,16 @@ class SDTrainer(BaseSDTrainProcess):
                 self.cached_trigger_embeds = self.sd.encode_prompt(self.trigger_word)
                 if self.train_config.diff_output_preservation:
                     self.diff_output_preservation_embeds = self.sd.encode_prompt(self.train_config.diff_output_preservation_class)
+                self.cache_sample_prompts()

-            # move back to cpu
-            self.sd.text_encoder_to('cpu')
+            # unload the text encoder
+            if self.is_caching_text_embeddings:
+                unload_text_encoder(self.sd)
+            else:
+                # TODO: once every model is tested to work, unload properly. Though this will all be merged into one thing.
+                # keep legacy usage for now.
+                self.sd.text_encoder_to("cpu")
             flush()

         if self.train_config.diffusion_feature_extractor_path is not None:
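toolkit.unloader.unload_text_encoder is introduced by this commit, but its body does not appear in this excerpt. A plausible sketch of what such a helper might do — an assumption, not the actual implementation:

import gc

import torch


def unload_text_encoder(sd):
    # plausible sketch only; the real toolkit.unloader implementation
    # is not part of this diff
    encoders = sd.text_encoder if isinstance(sd.text_encoder, list) else [sd.text_encoder]
    for te in encoders:
        if te is not None:
            te.to('cpu')
    sd.text_encoder = None   # drop the reference so it can be garbage collected
    gc.collect()
    torch.cuda.empty_cache()  # hand the freed VRAM back to the driver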
@@ -923,11 +964,14 @@ class SDTrainer(BaseSDTrainProcess):
                         prompt = prompt.replace(trigger, class_name)
                         prompt_list[idx] = prompt

-                embeds_to_use = self.sd.encode_prompt(
-                    prompt_list,
-                    long_prompts=self.do_long_prompts).to(
-                        self.device_torch,
-                        dtype=dtype).detach()
+                if batch.prompt_embeds is not None:
+                    embeds_to_use = batch.prompt_embeds.clone().to(self.device_torch, dtype=dtype)
+                else:
+                    embeds_to_use = self.sd.encode_prompt(
+                        prompt_list,
+                        long_prompts=self.do_long_prompts).to(
+                            self.device_torch,
+                            dtype=dtype).detach()

                 # don't use network on this
                 # self.network.multiplier = 0.0
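The branch on batch.prompt_embeds implies the dataloader now attaches pre-computed embeddings to each batch; that side is not shown here. A hypothetical sketch of how a batch object could load them from the disk cache (embed_file and cache_dir are invented for illustration):

import os

import torch


class CachedEmbedBatch:
    # simplified stand-in for the toolkit's DataLoaderBatchDTO; only the
    # prompt_embeds field matches the trainer code above
    def __init__(self, file_items, cache_dir):
        tensors = [
            torch.load(os.path.join(cache_dir, item.embed_file), map_location="cpu")
            for item in file_items
        ]
        # stack the per-item embeddings into one batched tensor
        self.prompt_embeds = torch.cat(tensors, dim=0)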
@@ -1294,18 +1338,24 @@ class SDTrainer(BaseSDTrainProcess):
                 with self.timer('encode_prompt'):
                     unconditional_embeds = None
-                    if self.train_config.unload_text_encoder:
+                    if self.train_config.unload_text_encoder or self.is_caching_text_embeddings:
                         with torch.set_grad_enabled(False):
-                            embeds_to_use = self.cached_blank_embeds.clone().detach().to(
-                                self.device_torch, dtype=dtype
-                            )
-                            if self.cached_trigger_embeds is not None and not is_reg:
-                                embeds_to_use = self.cached_trigger_embeds.clone().detach().to(
-                                    self.device_torch, dtype=dtype
-                                )
-                            conditional_embeds = concat_prompt_embeds(
-                                [embeds_to_use] * noisy_latents.shape[0]
-                            )
+                            if batch.prompt_embeds is not None:
+                                # use the cached embeds
+                                conditional_embeds = batch.prompt_embeds.clone().detach().to(
+                                    self.device_torch, dtype=dtype
+                                )
+                            else:
+                                embeds_to_use = self.cached_blank_embeds.clone().detach().to(
+                                    self.device_torch, dtype=dtype
+                                )
+                                if self.cached_trigger_embeds is not None and not is_reg:
+                                    embeds_to_use = self.cached_trigger_embeds.clone().detach().to(
+                                        self.device_torch, dtype=dtype
+                                    )
+                                conditional_embeds = concat_prompt_embeds(
+                                    [embeds_to_use] * noisy_latents.shape[0]
+                                )
                             if self.train_config.do_cfg:
                                 unconditional_embeds = self.cached_blank_embeds.clone().detach().to(
                                     self.device_torch, dtype=dtype