mirror of
https://github.com/ostris/ai-toolkit.git
synced 2026-04-29 10:41:28 +00:00
Added support for caching text embeddings. This is just initial support and will probably fail for some models. Still needs to be ompimized
This commit is contained in:
@@ -168,7 +168,9 @@ class QwenImageModel(BaseModel):
|
||||
text_encoder = [pipe.text_encoder]
|
||||
tokenizer = [pipe.tokenizer]
|
||||
|
||||
pipe.transformer = pipe.transformer.to(self.device_torch)
|
||||
# leave it on cpu for now
|
||||
if not self.low_vram:
|
||||
pipe.transformer = pipe.transformer.to(self.device_torch)
|
||||
|
||||
flush()
|
||||
# just to make sure everything is on the right device and dtype
|
||||
@@ -210,6 +212,7 @@ class QwenImageModel(BaseModel):
|
||||
generator: torch.Generator,
|
||||
extra: dict,
|
||||
):
|
||||
self.model.to(self.device_torch, dtype=self.torch_dtype)
|
||||
control_img = None
|
||||
if gen_config.ctrl_img is not None:
|
||||
raise NotImplementedError(
|
||||
|
||||
@@ -13,7 +13,7 @@ from torch.utils.data import DataLoader, ConcatDataset
|
||||
from toolkit import train_tools
|
||||
from toolkit.basic import value_map, adain, get_mean_std
|
||||
from toolkit.clip_vision_adapter import ClipVisionAdapter
|
||||
from toolkit.config_modules import GuidanceConfig
|
||||
from toolkit.config_modules import GenerateImageConfig
|
||||
from toolkit.data_loader import get_dataloader_datasets
|
||||
from toolkit.data_transfer_object.data_loader import DataLoaderBatchDTO, FileItemDTO
|
||||
from toolkit.guidance import get_targeted_guidance_loss, get_guidance_loss, GuidanceType
|
||||
@@ -36,6 +36,7 @@ from toolkit.train_tools import precondition_model_outputs_flow_match
|
||||
from toolkit.models.diffusion_feature_extraction import DiffusionFeatureExtractor, load_dfe
|
||||
from toolkit.util.wavelet_loss import wavelet_loss
|
||||
import torch.nn.functional as F
|
||||
from toolkit.unloader import unload_text_encoder
|
||||
|
||||
|
||||
def flush():
|
||||
@@ -108,6 +109,33 @@ class SDTrainer(BaseSDTrainProcess):
|
||||
|
||||
def before_model_load(self):
|
||||
pass
|
||||
|
||||
def cache_sample_prompts(self):
|
||||
if self.train_config.disable_sampling:
|
||||
return
|
||||
if self.sample_config is not None and self.sample_config.samples is not None and len(self.sample_config.samples) > 0:
|
||||
# cache all the samples
|
||||
self.sd.sample_prompts_cache = []
|
||||
sample_folder = os.path.join(self.save_root, 'samples')
|
||||
output_path = os.path.join(sample_folder, 'test.jpg')
|
||||
for i in range(len(self.sample_config.prompts)):
|
||||
sample_item = self.sample_config.samples[i]
|
||||
prompt = self.sample_config.prompts[i]
|
||||
|
||||
# needed so we can autoparse the prompt to handle flags
|
||||
gen_img_config = GenerateImageConfig(
|
||||
prompt=prompt, # it will autoparse the prompt
|
||||
negative_prompt=sample_item.neg,
|
||||
output_path=output_path,
|
||||
)
|
||||
positive = self.sd.encode_prompt(gen_img_config.prompt).to('cpu')
|
||||
negative = self.sd.encode_prompt(gen_img_config.negative_prompt).to('cpu')
|
||||
|
||||
self.sd.sample_prompts_cache.append({
|
||||
'conditional': positive,
|
||||
'unconditional': negative
|
||||
})
|
||||
|
||||
|
||||
def before_dataset_load(self):
|
||||
self.assistant_adapter = None
|
||||
@@ -143,6 +171,9 @@ class SDTrainer(BaseSDTrainProcess):
|
||||
|
||||
def hook_before_train_loop(self):
|
||||
super().hook_before_train_loop()
|
||||
if self.is_caching_text_embeddings:
|
||||
# make sure model is on cpu for this part so we don't oom.
|
||||
self.sd.unet.to('cpu')
|
||||
|
||||
# cache unconditional embeds (blank prompt)
|
||||
with torch.no_grad():
|
||||
@@ -195,15 +226,18 @@ class SDTrainer(BaseSDTrainProcess):
|
||||
self.negative_prompt_pool = [self.train_config.negative_prompt]
|
||||
|
||||
# handle unload text encoder
|
||||
if self.train_config.unload_text_encoder:
|
||||
if self.train_config.unload_text_encoder or self.is_caching_text_embeddings:
|
||||
with torch.no_grad():
|
||||
if self.train_config.train_text_encoder:
|
||||
raise ValueError("Cannot unload text encoder if training text encoder")
|
||||
# cache embeddings
|
||||
|
||||
print_acc("\n***** UNLOADING TEXT ENCODER *****")
|
||||
print_acc("This will train only with a blank prompt or trigger word, if set")
|
||||
print_acc("If this is not what you want, remove the unload_text_encoder flag")
|
||||
if self.is_caching_text_embeddings:
|
||||
print_acc("Embeddings cached to disk. We dont need the text encoder anymore")
|
||||
else:
|
||||
print_acc("This will train only with a blank prompt or trigger word, if set")
|
||||
print_acc("If this is not what you want, remove the unload_text_encoder flag")
|
||||
print_acc("***********************************")
|
||||
print_acc("")
|
||||
self.sd.text_encoder_to(self.device_torch)
|
||||
@@ -212,9 +246,16 @@ class SDTrainer(BaseSDTrainProcess):
|
||||
self.cached_trigger_embeds = self.sd.encode_prompt(self.trigger_word)
|
||||
if self.train_config.diff_output_preservation:
|
||||
self.diff_output_preservation_embeds = self.sd.encode_prompt(self.train_config.diff_output_preservation_class)
|
||||
|
||||
self.cache_sample_prompts()
|
||||
|
||||
# move back to cpu
|
||||
self.sd.text_encoder_to('cpu')
|
||||
# unload the text encoder
|
||||
if self.is_caching_text_embeddings:
|
||||
unload_text_encoder(self.sd)
|
||||
else:
|
||||
# todo once every model is tested to work, unload properly. Though, this will all be merged into one thing.
|
||||
# keep legacy usage for now.
|
||||
self.sd.text_encoder_to("cpu")
|
||||
flush()
|
||||
|
||||
if self.train_config.diffusion_feature_extractor_path is not None:
|
||||
@@ -923,11 +964,14 @@ class SDTrainer(BaseSDTrainProcess):
|
||||
prompt = prompt.replace(trigger, class_name)
|
||||
prompt_list[idx] = prompt
|
||||
|
||||
embeds_to_use = self.sd.encode_prompt(
|
||||
prompt_list,
|
||||
long_prompts=self.do_long_prompts).to(
|
||||
self.device_torch,
|
||||
dtype=dtype).detach()
|
||||
if batch.prompt_embeds is not None:
|
||||
embeds_to_use = batch.prompt_embeds.clone().to(self.device_torch, dtype=dtype)
|
||||
else:
|
||||
embeds_to_use = self.sd.encode_prompt(
|
||||
prompt_list,
|
||||
long_prompts=self.do_long_prompts).to(
|
||||
self.device_torch,
|
||||
dtype=dtype).detach()
|
||||
|
||||
# dont use network on this
|
||||
# self.network.multiplier = 0.0
|
||||
@@ -1294,18 +1338,24 @@ class SDTrainer(BaseSDTrainProcess):
|
||||
|
||||
with self.timer('encode_prompt'):
|
||||
unconditional_embeds = None
|
||||
if self.train_config.unload_text_encoder:
|
||||
if self.train_config.unload_text_encoder or self.is_caching_text_embeddings:
|
||||
with torch.set_grad_enabled(False):
|
||||
embeds_to_use = self.cached_blank_embeds.clone().detach().to(
|
||||
self.device_torch, dtype=dtype
|
||||
)
|
||||
if self.cached_trigger_embeds is not None and not is_reg:
|
||||
embeds_to_use = self.cached_trigger_embeds.clone().detach().to(
|
||||
if batch.prompt_embeds is not None:
|
||||
# use the cached embeds
|
||||
conditional_embeds = batch.prompt_embeds.clone().detach().to(
|
||||
self.device_torch, dtype=dtype
|
||||
)
|
||||
conditional_embeds = concat_prompt_embeds(
|
||||
[embeds_to_use] * noisy_latents.shape[0]
|
||||
)
|
||||
else:
|
||||
embeds_to_use = self.cached_blank_embeds.clone().detach().to(
|
||||
self.device_torch, dtype=dtype
|
||||
)
|
||||
if self.cached_trigger_embeds is not None and not is_reg:
|
||||
embeds_to_use = self.cached_trigger_embeds.clone().detach().to(
|
||||
self.device_torch, dtype=dtype
|
||||
)
|
||||
conditional_embeds = concat_prompt_embeds(
|
||||
[embeds_to_use] * noisy_latents.shape[0]
|
||||
)
|
||||
if self.train_config.do_cfg:
|
||||
unconditional_embeds = self.cached_blank_embeds.clone().detach().to(
|
||||
self.device_torch, dtype=dtype
|
||||
|
||||
Reference in New Issue
Block a user