Mirror of https://github.com/ostris/ai-toolkit.git (synced 2026-02-19 20:04:04 +00:00)
Added support for caching text embeddings. This is initial support and will probably fail for some models. It still needs to be optimized.
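
As a rough sketch of how the new flag is meant to be used (the exact job schema may differ; cache_text_embeddings is the kwargs key this commit adds to both TrainConfig and DatasetConfig, and folder_path below is only illustrative):

    # hedged sketch: enabling the new cache, based on the kwargs this commit reads
    from toolkit.config_modules import TrainConfig, DatasetConfig

    train_config = TrainConfig(cache_text_embeddings=True)  # toggles caching for every dataset
    dataset_config = DatasetConfig(folder_path='/path/to/images', cache_text_embeddings=True)  # or enable per dataset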
@@ -168,7 +168,9 @@ class QwenImageModel(BaseModel):
        text_encoder = [pipe.text_encoder]
        tokenizer = [pipe.tokenizer]

        pipe.transformer = pipe.transformer.to(self.device_torch)
        # leave it on cpu for now
        if not self.low_vram:
            pipe.transformer = pipe.transformer.to(self.device_torch)

        flush()
        # just to make sure everything is on the right device and dtype
@@ -210,6 +212,7 @@ class QwenImageModel(BaseModel):
        generator: torch.Generator,
        extra: dict,
    ):
        self.model.to(self.device_torch, dtype=self.torch_dtype)
        control_img = None
        if gen_config.ctrl_img is not None:
            raise NotImplementedError(
@@ -13,7 +13,7 @@ from torch.utils.data import DataLoader, ConcatDataset
from toolkit import train_tools
from toolkit.basic import value_map, adain, get_mean_std
from toolkit.clip_vision_adapter import ClipVisionAdapter
from toolkit.config_modules import GuidanceConfig
from toolkit.config_modules import GenerateImageConfig
from toolkit.data_loader import get_dataloader_datasets
from toolkit.data_transfer_object.data_loader import DataLoaderBatchDTO, FileItemDTO
from toolkit.guidance import get_targeted_guidance_loss, get_guidance_loss, GuidanceType
@@ -36,6 +36,7 @@ from toolkit.train_tools import precondition_model_outputs_flow_match
from toolkit.models.diffusion_feature_extraction import DiffusionFeatureExtractor, load_dfe
from toolkit.util.wavelet_loss import wavelet_loss
import torch.nn.functional as F
from toolkit.unloader import unload_text_encoder


def flush():
@@ -108,6 +109,33 @@ class SDTrainer(BaseSDTrainProcess):

    def before_model_load(self):
        pass

    def cache_sample_prompts(self):
        if self.train_config.disable_sampling:
            return
        if self.sample_config is not None and self.sample_config.samples is not None and len(self.sample_config.samples) > 0:
            # cache all the samples
            self.sd.sample_prompts_cache = []
            sample_folder = os.path.join(self.save_root, 'samples')
            output_path = os.path.join(sample_folder, 'test.jpg')
            for i in range(len(self.sample_config.prompts)):
                sample_item = self.sample_config.samples[i]
                prompt = self.sample_config.prompts[i]

                # needed so we can autoparse the prompt to handle flags
                gen_img_config = GenerateImageConfig(
                    prompt=prompt,  # it will autoparse the prompt
                    negative_prompt=sample_item.neg,
                    output_path=output_path,
                )
                positive = self.sd.encode_prompt(gen_img_config.prompt).to('cpu')
                negative = self.sd.encode_prompt(gen_img_config.negative_prompt).to('cpu')

                self.sd.sample_prompts_cache.append({
                    'conditional': positive,
                    'unconditional': negative
                })

    def before_dataset_load(self):
        self.assistant_adapter = None
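
For orientation, a condensed sketch of how this cache is consumed later during sampling (see the BaseModel and StableDiffusion hunks further down; sd stands in for the model wrapper):

    # condensed from the sampling hunks below; device/dtype names follow the diff
    embeds = sd.sample_prompts_cache[i]
    conditional_embeds = embeds['conditional'].to(sd.device_torch, dtype=sd.torch_dtype)
    unconditional_embeds = embeds['unconditional'].to(sd.device_torch, dtype=sd.torch_dtype)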
@@ -143,6 +171,9 @@ class SDTrainer(BaseSDTrainProcess):

    def hook_before_train_loop(self):
        super().hook_before_train_loop()
        if self.is_caching_text_embeddings:
            # make sure model is on cpu for this part so we don't oom.
            self.sd.unet.to('cpu')

        # cache unconditional embeds (blank prompt)
        with torch.no_grad():
@@ -195,15 +226,18 @@ class SDTrainer(BaseSDTrainProcess):
            self.negative_prompt_pool = [self.train_config.negative_prompt]

        # handle unload text encoder
        if self.train_config.unload_text_encoder:
        if self.train_config.unload_text_encoder or self.is_caching_text_embeddings:
            with torch.no_grad():
                if self.train_config.train_text_encoder:
                    raise ValueError("Cannot unload text encoder if training text encoder")
                # cache embeddings

                print_acc("\n***** UNLOADING TEXT ENCODER *****")
                print_acc("This will train only with a blank prompt or trigger word, if set")
                print_acc("If this is not what you want, remove the unload_text_encoder flag")
                if self.is_caching_text_embeddings:
                    print_acc("Embeddings cached to disk. We don't need the text encoder anymore")
                else:
                    print_acc("This will train only with a blank prompt or trigger word, if set")
                    print_acc("If this is not what you want, remove the unload_text_encoder flag")
                print_acc("***********************************")
                print_acc("")
                self.sd.text_encoder_to(self.device_torch)
@@ -212,9 +246,16 @@ class SDTrainer(BaseSDTrainProcess):
                self.cached_trigger_embeds = self.sd.encode_prompt(self.trigger_word)
                if self.train_config.diff_output_preservation:
                    self.diff_output_preservation_embeds = self.sd.encode_prompt(self.train_config.diff_output_preservation_class)

                self.cache_sample_prompts()

                # move back to cpu
                self.sd.text_encoder_to('cpu')
                # unload the text encoder
                if self.is_caching_text_embeddings:
                    unload_text_encoder(self.sd)
                else:
                    # todo once every model is tested to work, unload properly. Though, this will all be merged into one thing.
                    # keep legacy usage for now.
                    self.sd.text_encoder_to("cpu")
                flush()

        if self.train_config.diffusion_feature_extractor_path is not None:
@@ -923,11 +964,14 @@ class SDTrainer(BaseSDTrainProcess):
                prompt = prompt.replace(trigger, class_name)
                prompt_list[idx] = prompt

            embeds_to_use = self.sd.encode_prompt(
                prompt_list,
                long_prompts=self.do_long_prompts).to(
                self.device_torch,
                dtype=dtype).detach()
            if batch.prompt_embeds is not None:
                embeds_to_use = batch.prompt_embeds.clone().to(self.device_torch, dtype=dtype)
            else:
                embeds_to_use = self.sd.encode_prompt(
                    prompt_list,
                    long_prompts=self.do_long_prompts).to(
                    self.device_torch,
                    dtype=dtype).detach()

            # don't use network on this
            # self.network.multiplier = 0.0
@@ -1294,18 +1338,24 @@ class SDTrainer(BaseSDTrainProcess):

            with self.timer('encode_prompt'):
                unconditional_embeds = None
                if self.train_config.unload_text_encoder:
                if self.train_config.unload_text_encoder or self.is_caching_text_embeddings:
                    with torch.set_grad_enabled(False):
                        embeds_to_use = self.cached_blank_embeds.clone().detach().to(
                            self.device_torch, dtype=dtype
                        )
                        if self.cached_trigger_embeds is not None and not is_reg:
                            embeds_to_use = self.cached_trigger_embeds.clone().detach().to(
                        if batch.prompt_embeds is not None:
                            # use the cached embeds
                            conditional_embeds = batch.prompt_embeds.clone().detach().to(
                                self.device_torch, dtype=dtype
                            )
                        conditional_embeds = concat_prompt_embeds(
                            [embeds_to_use] * noisy_latents.shape[0]
                        )
                        else:
                            embeds_to_use = self.cached_blank_embeds.clone().detach().to(
                                self.device_torch, dtype=dtype
                            )
                            if self.cached_trigger_embeds is not None and not is_reg:
                                embeds_to_use = self.cached_trigger_embeds.clone().detach().to(
                                    self.device_torch, dtype=dtype
                                )
                            conditional_embeds = concat_prompt_embeds(
                                [embeds_to_use] * noisy_latents.shape[0]
                            )
                        if self.train_config.do_cfg:
                            unconditional_embeds = self.cached_blank_embeds.clone().detach().to(
                                self.device_torch, dtype=dtype
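
Since the old and new lines are interleaved above, here is the new branch in isolation, a sketch distilled from this hunk:

    # distilled sketch: when a batch carries cached embeds, no encoder call is made
    if batch.prompt_embeds is not None:
        conditional_embeds = batch.prompt_embeds.clone().detach().to(self.device_torch, dtype=dtype)
    else:
        embeds_to_use = self.cached_blank_embeds.clone().detach().to(self.device_torch, dtype=dtype)
        if self.cached_trigger_embeds is not None and not is_reg:
            embeds_to_use = self.cached_trigger_embeds.clone().detach().to(self.device_torch, dtype=dtype)
        conditional_embeds = concat_prompt_embeds([embeds_to_use] * noisy_latents.shape[0])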
@@ -145,7 +145,14 @@ class BaseSDTrainProcess(BaseTrainProcess):
        raw_datasets = preprocess_dataset_raw_config(raw_datasets)
        self.datasets = None
        self.datasets_reg = None
        self.dataset_configs: List[DatasetConfig] = []
        self.params = []

        # add dataset text embedding cache to their config
        if self.train_config.cache_text_embeddings:
            for raw_dataset in raw_datasets:
                raw_dataset['cache_text_embeddings'] = True

        if raw_datasets is not None and len(raw_datasets) > 0:
            for raw_dataset in raw_datasets:
                dataset = DatasetConfig(**raw_dataset)
@@ -160,6 +167,15 @@ class BaseSDTrainProcess(BaseTrainProcess):
                    if self.datasets is None:
                        self.datasets = []
                    self.datasets.append(dataset)
                self.dataset_configs.append(dataset)

        self.is_caching_text_embeddings = any(
            dataset.cache_text_embeddings for dataset in self.dataset_configs
        )

        # cannot train trigger word if caching text embeddings
        if self.is_caching_text_embeddings and self.trigger_word is not None:
            raise ValueError("Cannot train trigger word if caching text embeddings. Please remove the trigger word or disable text embedding caching.")

        self.embed_config = None
        embedding_raw = self.get_conf('embedding', None)
@@ -206,7 +222,7 @@ class BaseSDTrainProcess(BaseTrainProcess):
            train_embedding=self.embed_config is not None,
            train_decorator=self.decorator_config is not None,
            train_refiner=self.train_config.train_refiner,
            unload_text_encoder=self.train_config.unload_text_encoder,
            unload_text_encoder=self.train_config.unload_text_encoder or self.is_caching_text_embeddings,
            require_grads=False  # we ensure them later
        )
@@ -220,7 +236,7 @@ class BaseSDTrainProcess(BaseTrainProcess):
            train_embedding=self.embed_config is not None,
            train_decorator=self.decorator_config is not None,
            train_refiner=self.train_config.train_refiner,
            unload_text_encoder=self.train_config.unload_text_encoder,
            unload_text_encoder=self.train_config.unload_text_encoder or self.is_caching_text_embeddings,
            require_grads=True  # We check for grads when getting params
        )
@@ -235,7 +251,7 @@ class BaseSDTrainProcess(BaseTrainProcess):
        self.snr_gos: Union[LearnableSNRGamma, None] = None
        self.ema: ExponentialMovingAverage = None

        validate_configs(self.train_config, self.model_config, self.save_config)
        validate_configs(self.train_config, self.model_config, self.save_config, self.dataset_configs)

        do_profiler = self.get_conf('torch_profiler', False)
        self.torch_profiler = None if not do_profiler else torch.profiler.profile(
@@ -482,6 +482,8 @@ class TrainConfig:
        # will cache a blank prompt or the trigger word, and unload the text encoder to cpu
        # will make training faster and use less vram
        self.unload_text_encoder = kwargs.get('unload_text_encoder', False)
        # will toggle all datasets to cache text embeddings
        self.cache_text_embeddings: bool = kwargs.get('cache_text_embeddings', False)
        # for swapping which parameters are trained during training
        self.do_paramiter_swapping = kwargs.get('do_paramiter_swapping', False)
        # 0.1 is 10% of the parameters active at a time; lower is less vram, higher is more
@@ -1189,6 +1191,7 @@ def validate_configs(
        train_config: TrainConfig,
        model_config: ModelConfig,
        save_config: SaveConfig,
        dataset_configs: List[DatasetConfig]
):
    if model_config.is_flux:
        if save_config.save_format != 'diffusers':
@@ -1200,3 +1203,18 @@ def validate_configs(
    if train_config.bypass_guidance_embedding and train_config.do_guidance_loss:
        raise ValueError("Cannot bypass guidance embedding and do guidance loss at the same time. "
                         "Please set bypass_guidance_embedding to False or do_guidance_loss to False.")

    # see if any datasets are caching text embeddings
    is_caching_text_embeddings = any(dataset.cache_text_embeddings for dataset in dataset_configs)
    if is_caching_text_embeddings:

        # check if they are doing differential output preservation
        if train_config.diff_output_preservation:
            raise ValueError("Cannot use differential output preservation with caching text embeddings. Please set diff_output_preservation to False.")

        # make sure they are all cached
        for dataset in dataset_configs:
            if not dataset.cache_text_embeddings:
                raise ValueError("All datasets must have cache_text_embeddings set to True when caching text embeddings is enabled.")
@@ -558,6 +558,8 @@ class AiToolkitDataset(LatentCachingMixin, ControlCachingMixin, CLIPCachingMixin
            self.cache_latents_all_latents()
        if self.is_caching_clip_vision_to_disk:
            self.cache_clip_vision_to_disk()
        if self.is_caching_text_embeddings:
            self.cache_text_embeddings()
        if self.is_generating_controls:
            # always do this last
            self.setup_controls()
@@ -13,8 +13,8 @@ from toolkit import image_utils
from toolkit.basic import get_quick_signature_string
from toolkit.dataloader_mixins import CaptionProcessingDTOMixin, ImageProcessingDTOMixin, LatentCachingFileItemDTOMixin, \
    ControlFileItemDTOMixin, ArgBreakMixin, PoiFileItemDTOMixin, MaskFileItemDTOMixin, AugmentationFileItemDTOMixin, \
    UnconditionalFileItemDTOMixin, ClipImageFileItemDTOMixin, InpaintControlFileItemDTOMixin
    UnconditionalFileItemDTOMixin, ClipImageFileItemDTOMixin, InpaintControlFileItemDTOMixin, TextEmbeddingFileItemDTOMixin
from toolkit.prompt_utils import PromptEmbeds, concat_prompt_embeds

if TYPE_CHECKING:
    from toolkit.config_modules import DatasetConfig
@@ -32,6 +32,7 @@ def print_once(msg):

class FileItemDTO(
    LatentCachingFileItemDTOMixin,
    TextEmbeddingFileItemDTOMixin,
    CaptionProcessingDTOMixin,
    ImageProcessingDTOMixin,
    ControlFileItemDTOMixin,
@@ -124,6 +125,7 @@ class FileItemDTO(
    def cleanup(self):
        self.tensor = None
        self.cleanup_latent()
        self.cleanup_text_embedding()
        self.cleanup_control()
        self.cleanup_inpaint()
        self.cleanup_clip_image()
@@ -136,6 +138,7 @@ class DataLoaderBatchDTO:
        try:
            self.file_items: List['FileItemDTO'] = kwargs.get('file_items', None)
            is_latents_cached = self.file_items[0].is_latent_cached
            is_text_embedding_cached = self.file_items[0].is_text_embedding_cached
            self.tensor: Union[torch.Tensor, None] = None
            self.latents: Union[torch.Tensor, None] = None
            self.control_tensor: Union[torch.Tensor, None] = None
@@ -156,6 +159,7 @@ class DataLoaderBatchDTO:
            if is_latents_cached:
                self.latents = torch.cat([x.get_latent().unsqueeze(0) for x in self.file_items])
            self.control_tensor: Union[torch.Tensor, None] = None
            self.prompt_embeds: Union[PromptEmbeds, None] = None
            # if self.file_items[0].control_tensor is not None:
            # if any have a control tensor, we concatenate them
            if any([x.control_tensor is not None for x in self.file_items]):
@@ -268,6 +272,22 @@ class DataLoaderBatchDTO:
                        self.clip_image_embeds_unconditional.append(x.clip_image_embeds_unconditional)
                    else:
                        raise Exception("clip_image_embeds_unconditional is None for some file items")

            if any([x.prompt_embeds is not None for x in self.file_items]):
                # find one to use as a base
                base_prompt_embeds = None
                for x in self.file_items:
                    if x.prompt_embeds is not None:
                        base_prompt_embeds = x.prompt_embeds
                        break
                prompt_embeds_list = []
                for x in self.file_items:
                    if x.prompt_embeds is None:
                        prompt_embeds_list.append(base_prompt_embeds)
                    else:
                        prompt_embeds_list.append(x.prompt_embeds)
                self.prompt_embeds = concat_prompt_embeds(prompt_embeds_list)

        except Exception as e:
            print(e)
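
The fallback above exists so that concat_prompt_embeds always receives one entry per file item; a compact equivalent of the loop, as a sketch using the same names:

    # compact equivalent of the collation loop above (sketch)
    base = next(x.prompt_embeds for x in self.file_items if x.prompt_embeds is not None)
    self.prompt_embeds = concat_prompt_embeds(
        [x.prompt_embeds if x.prompt_embeds is not None else base for x in self.file_items]
    )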
@@ -29,6 +29,7 @@ from PIL.ImageOps import exif_transpose
import albumentations as A
from toolkit.print import print_acc
from toolkit.accelerator import get_accelerator
from toolkit.prompt_utils import PromptEmbeds

from toolkit.train_tools import get_torch_dtype
@@ -301,7 +302,7 @@ class CaptionProcessingDTOMixin:
        self.extra_values: List[float] = dataset_config.extra_values

    # todo allow for loading from sd-scripts style dict
    def load_caption(self: 'FileItemDTO', caption_dict: Union[dict, None]):
    def load_caption(self: 'FileItemDTO', caption_dict: Union[dict, None]=None):
        if self.raw_caption is not None:
            # we already loaded it
            pass
@@ -635,6 +636,9 @@ class ImageProcessingDTOMixin:
        if self.dataset_config.num_frames > 1:
            self.load_and_process_video(transform, only_load_latents)
            return
        # handle get_prompt_embedding
        if self.is_text_embedding_cached:
            self.load_prompt_embedding()
        # if we are caching latents, just do that
        if self.is_latent_cached:
            self.get_latent()
@@ -1773,6 +1777,61 @@ class LatentCachingMixin:
        self.sd.restore_device_state()


class TextEmbeddingFileItemDTOMixin:
    def __init__(self, *args, **kwargs):
        # if we have super, call it
        if hasattr(super(), '__init__'):
            super().__init__(*args, **kwargs)
        self.prompt_embeds: Union[PromptEmbeds, None] = None
        self._text_embedding_path: Union[str, None] = None
        self.is_text_embedding_cached = False
        self.text_embedding_load_device = 'cpu'
        self.text_embedding_space_version = 'sd1'
        self.text_embedding_version = 1

    def get_text_embedding_info_dict(self: 'FileItemDTO'):
        # make sure the caption is loaded here
        # TODO: we need a way to cache all the other features like trigger words, DOP, etc. For now, we need to throw an error if not compatible.
        if self.caption is None:
            self.load_caption()
        # throw an error if [trigger] is in the caption, as we cannot inject it while caching
        if '[trigger]' in self.caption:
            raise Exception("Error: [trigger] in caption is not supported when caching text embeddings. Please remove it from the caption.")
        item = OrderedDict([
            ("caption", self.caption),
            ("text_embedding_space_version", self.text_embedding_space_version),
            ("text_embedding_version", self.text_embedding_version),
        ])
        return item

    def get_text_embedding_path(self: 'FileItemDTO', recalculate=False):
        if self._text_embedding_path is not None and not recalculate:
            return self._text_embedding_path
        else:
            # we store text embeddings in a folder named _t_e_cache in the same path as the image
            img_dir = os.path.dirname(self.path)
            te_dir = os.path.join(img_dir, '_t_e_cache')
            hash_dict = self.get_text_embedding_info_dict()
            filename_no_ext = os.path.splitext(os.path.basename(self.path))[0]
            # get base64 hash of md5 checksum of hash_dict
            hash_input = json.dumps(hash_dict, sort_keys=True).encode('utf-8')
            hash_str = base64.urlsafe_b64encode(hashlib.md5(hash_input).digest()).decode('ascii')
            hash_str = hash_str.replace('=', '')
            self._text_embedding_path = os.path.join(te_dir, f'{filename_no_ext}_{hash_str}.safetensors')

        return self._text_embedding_path

    def cleanup_text_embedding(self):
        if self.prompt_embeds is not None:
            # we are caching on disk, don't save in memory
            self.prompt_embeds = None

    def load_prompt_embedding(self, device=None):
        if not self.is_text_embedding_cached:
            return
        if self.prompt_embeds is None:
            # load it from disk
            self.prompt_embeds = PromptEmbeds.load(self.get_text_embedding_path())


class TextEmbeddingCachingMixin:
    def __init__(self: 'AiToolkitDataset', **kwargs):
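
A standalone sketch of the cache-path hashing above, with hypothetical values, showing the kind of filename it produces:

    import base64
    import hashlib
    import json
    import os

    hash_dict = {"caption": "a photo of a cat", "text_embedding_space_version": "sd1", "text_embedding_version": 1}
    hash_input = json.dumps(hash_dict, sort_keys=True).encode('utf-8')
    # md5 digest is 16 bytes, so the url-safe base64 string is 22 chars once '=' padding is stripped
    hash_str = base64.urlsafe_b64encode(hashlib.md5(hash_input).digest()).decode('ascii').replace('=', '')
    print(os.path.join('_t_e_cache', f'img001_{hash_str}.safetensors'))
    # -> _t_e_cache/img001_<22-char-hash>.safetensors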
@@ -1780,90 +1839,36 @@ class TextEmbeddingCachingMixin:
        if hasattr(super(), '__init__'):
            super().__init__(**kwargs)
        self.is_caching_text_embeddings = self.dataset_config.cache_text_embeddings
        if self.is_caching_text_embeddings:
            raise Exception("Error: caching text embeddings is a WIP and is not supported yet. Please set cache_text_embeddings to False in the dataset config")

    def cache_text_embeddings(self: 'AiToolkitDataset'):

        with accelerator.main_process_first():
            print_acc(f"Caching text_embeddings for {self.dataset_path}")
            # cache all latents to disk
            to_disk = self.is_caching_latents_to_disk
            to_memory = self.is_caching_latents_to_memory
            print_acc(" - Saving text embeddings to disk")
            # move sd items to cpu except for vae
            self.sd.set_device_state_preset('cache_latents')

            did_move = False

            # use tqdm to show progress
            i = 0
            for file_item in tqdm(self.file_list, desc=f'Caching latents{" to disk" if to_disk else ""}'):
                # set latent space version
                if self.sd.model_config.latent_space_version is not None:
                    file_item.latent_space_version = self.sd.model_config.latent_space_version
                elif self.sd.is_xl:
                    file_item.latent_space_version = 'sdxl'
                elif self.sd.is_v3:
                    file_item.latent_space_version = 'sd3'
                elif self.sd.is_auraflow:
                    file_item.latent_space_version = 'sdxl'
                elif self.sd.is_flux:
                    file_item.latent_space_version = 'flux1'
                elif self.sd.model_config.is_pixart_sigma:
                    file_item.latent_space_version = 'sdxl'
                else:
                    file_item.latent_space_version = self.sd.model_config.arch
                file_item.is_caching_to_disk = to_disk
                file_item.is_caching_to_memory = to_memory
            for file_item in tqdm(self.file_list, desc='Caching text embeddings to disk'):
                file_item.text_embedding_space_version = self.sd.model_config.arch
                file_item.latent_load_device = self.sd.device

                latent_path = file_item.get_latent_path(recalculate=True)
                # check if it is saved to disk already
                if os.path.exists(latent_path):
                    if to_memory:
                        # load it into memory
                        state_dict = load_file(latent_path, device='cpu')
                        file_item._encoded_latent = state_dict['latent'].to('cpu', dtype=self.sd.torch_dtype)
                else:
                    # not saved to disk, calculate
                    # load the image first
                    file_item.load_and_process_image(self.transform, only_load_latents=True)
                    dtype = self.sd.torch_dtype
                    device = self.sd.device_torch
                    # add batch dimension
                    try:
                        imgs = file_item.tensor.unsqueeze(0).to(device, dtype=dtype)
                        latent = self.sd.encode_images(imgs).squeeze(0)
                    except Exception as e:
                        print_acc(f"Error processing image: {file_item.path}")
                        print_acc(f"Error: {str(e)}")
                        raise e
                    # save_latent
                    if to_disk:
                        state_dict = OrderedDict([
                            ('latent', latent.clone().detach().cpu()),
                        ])
                        # metadata
                        meta = get_meta_for_safetensors(file_item.get_latent_info_dict())
                        os.makedirs(os.path.dirname(latent_path), exist_ok=True)
                        save_file(state_dict, latent_path, metadata=meta)

                    if to_memory:
                        # keep it in memory
                        file_item._encoded_latent = latent.to('cpu', dtype=self.sd.torch_dtype)

                    del imgs
                    del latent
                    del file_item.tensor

                # flush(garbage_collect=False)
                file_item.is_latent_cached = True
                text_embedding_path = file_item.get_text_embedding_path(recalculate=True)
                # only process if not saved to disk
                if not os.path.exists(text_embedding_path):
                    # load if not loaded
                    if not did_move:
                        self.sd.set_device_state_preset('cache_text_encoder')
                        did_move = True
                    prompt_embeds: PromptEmbeds = self.sd.encode_prompt(file_item.caption)
                    # save it
                    prompt_embeds.save(text_embedding_path)
                    del prompt_embeds
                file_item.is_text_embedding_cached = True
                i += 1
                # flush every 100
                # if i % 100 == 0:
                #     flush()

            # restore device state
            self.sd.restore_device_state()
            # if did_move:
            #     self.sd.restore_device_state()


class CLIPCachingMixin:
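
With the removed latent-caching lines stripped out, the new per-item flow above reduces to roughly this sketch (names taken from the hunk):

    # distilled sketch of cache_text_embeddings
    for file_item in tqdm(self.file_list, desc='Caching text embeddings to disk'):
        file_item.text_embedding_space_version = self.sd.model_config.arch
        text_embedding_path = file_item.get_text_embedding_path(recalculate=True)
        if not os.path.exists(text_embedding_path):
            if not did_move:
                # move the text encoder to the GPU only on the first cache miss
                self.sd.set_device_state_preset('cache_text_encoder')
                did_move = True
            prompt_embeds = self.sd.encode_prompt(file_item.caption)
            prompt_embeds.save(text_embedding_path)
            del prompt_embeds
        file_item.is_text_embedding_cached = True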
@@ -168,6 +168,8 @@ class BaseModel:
        self._after_sample_img_hooks = []
        self._status_update_hooks = []
        self.is_transformer = False

        self.sample_prompts_cache = None

    # properties for old arch for backwards compatibility
    @property
@@ -484,19 +486,23 @@ class BaseModel:
                quad_count=4
            )

            # encode the prompt ourselves so we can do fun stuff with embeddings
            if isinstance(self.adapter, CustomAdapter):
                self.adapter.is_unconditional_run = False
            conditional_embeds = self.encode_prompt(
                gen_config.prompt, gen_config.prompt_2, force_all=True)
            if self.sample_prompts_cache is not None:
                conditional_embeds = self.sample_prompts_cache[i]['conditional'].to(self.device_torch, dtype=self.torch_dtype)
                unconditional_embeds = self.sample_prompts_cache[i]['unconditional'].to(self.device_torch, dtype=self.torch_dtype)
            else:
                # encode the prompt ourselves so we can do fun stuff with embeddings
                if isinstance(self.adapter, CustomAdapter):
                    self.adapter.is_unconditional_run = False
                conditional_embeds = self.encode_prompt(
                    gen_config.prompt, gen_config.prompt_2, force_all=True)

            if isinstance(self.adapter, CustomAdapter):
                self.adapter.is_unconditional_run = True
            unconditional_embeds = self.encode_prompt(
                gen_config.negative_prompt, gen_config.negative_prompt_2, force_all=True
            )
            if isinstance(self.adapter, CustomAdapter):
                self.adapter.is_unconditional_run = False
                if isinstance(self.adapter, CustomAdapter):
                    self.adapter.is_unconditional_run = True
                unconditional_embeds = self.encode_prompt(
                    gen_config.negative_prompt, gen_config.negative_prompt_2, force_all=True
                )
                if isinstance(self.adapter, CustomAdapter):
                    self.adapter.is_unconditional_run = False

            # allow any manipulations to take place to embeddings
            gen_config.post_process_embeddings(
@@ -92,6 +92,56 @@ class PromptEmbeds:
            pe.attention_mask = pe.attention_mask.expand(batch_size, -1)
        return pe

    def save(self, path: str):
        """
        Save the prompt embeds to a file.
        :param path: The path to save the prompt embeds.
        """
        pe = self.clone()
        state_dict = {}
        if isinstance(pe.text_embeds, list) or isinstance(pe.text_embeds, tuple):
            for i, text_embed in enumerate(pe.text_embeds):
                state_dict[f"text_embed_{i}"] = text_embed.cpu()
        else:
            state_dict["text_embed"] = pe.text_embeds.cpu()

        if pe.pooled_embeds is not None:
            state_dict["pooled_embed"] = pe.pooled_embeds.cpu()
        if pe.attention_mask is not None:
            state_dict["attention_mask"] = pe.attention_mask.cpu()
        os.makedirs(os.path.dirname(path), exist_ok=True)
        save_file(state_dict, path)

    @classmethod
    def load(cls, path: str) -> 'PromptEmbeds':
        """
        Load the prompt embeds from a file.
        :param path: The path to load the prompt embeds from.
        :return: An instance of PromptEmbeds.
        """
        state_dict = load_file(path, device='cpu')
        text_embeds = []
        pooled_embeds = None
        attention_mask = None
        for key in sorted(state_dict.keys()):
            if key.startswith("text_embed_"):
                text_embeds.append(state_dict[key])
            elif key == "text_embed":
                text_embeds.append(state_dict[key])
            elif key == "pooled_embed":
                pooled_embeds = state_dict[key]
            elif key == "attention_mask":
                attention_mask = state_dict[key]
        pe = cls(None)
        pe.text_embeds = text_embeds
        if len(text_embeds) == 1:
            pe.text_embeds = text_embeds[0]
        if pooled_embeds is not None:
            pe.pooled_embeds = pooled_embeds
        if attention_mask is not None:
            pe.attention_mask = attention_mask
        return pe


class EncodedPromptPair:
    def __init__(
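
A round-trip usage sketch for save/load; the constructor argument is assumed here, since only cls(None) appears in the diff:

    import torch

    pe = PromptEmbeds(torch.randn(1, 77, 768))        # assumed: constructor accepts raw text embeds
    pe.pooled_embeds = torch.randn(1, 768)
    pe.save('/tmp/_t_e_cache/example.safetensors')    # creates the directory if needed
    restored = PromptEmbeds.load('/tmp/_t_e_cache/example.safetensors')
    assert torch.equal(restored.text_embeds, pe.text_embeds.cpu())
    assert torch.equal(restored.pooled_embeds, pe.pooled_embeds.cpu())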
@@ -209,6 +209,8 @@ class StableDiffusion:
        # todo update this based on the model
        self.is_transformer = False

        self.sample_prompts_cache = None

    # properties for old arch for backwards compatibility
    @property
    def is_xl(self):
@@ -1426,18 +1428,22 @@ class StableDiffusion:
                quad_count=4
            )

            # encode the prompt ourselves so we can do fun stuff with embeddings
            if isinstance(self.adapter, CustomAdapter):
                self.adapter.is_unconditional_run = False
            conditional_embeds = self.encode_prompt(gen_config.prompt, gen_config.prompt_2, force_all=True)
            if self.sample_prompts_cache is not None:
                conditional_embeds = self.sample_prompts_cache[i]['conditional'].to(self.device_torch, dtype=self.torch_dtype)
                unconditional_embeds = self.sample_prompts_cache[i]['unconditional'].to(self.device_torch, dtype=self.torch_dtype)
            else:
                # encode the prompt ourselves so we can do fun stuff with embeddings
                if isinstance(self.adapter, CustomAdapter):
                    self.adapter.is_unconditional_run = False
                conditional_embeds = self.encode_prompt(gen_config.prompt, gen_config.prompt_2, force_all=True)

            if isinstance(self.adapter, CustomAdapter):
                self.adapter.is_unconditional_run = True
            unconditional_embeds = self.encode_prompt(
                gen_config.negative_prompt, gen_config.negative_prompt_2, force_all=True
            )
            if isinstance(self.adapter, CustomAdapter):
                self.adapter.is_unconditional_run = False
                if isinstance(self.adapter, CustomAdapter):
                    self.adapter.is_unconditional_run = True
                unconditional_embeds = self.encode_prompt(
                    gen_config.negative_prompt, gen_config.negative_prompt_2, force_all=True
                )
                if isinstance(self.adapter, CustomAdapter):
                    self.adapter.is_unconditional_run = False

            # allow any manipulations to take place to embeddings
            gen_config.post_process_embeddings(
toolkit/unloader.py (new file, 63 lines)
@@ -0,0 +1,63 @@
import torch
from toolkit.basic import flush
from typing import TYPE_CHECKING


if TYPE_CHECKING:
    from toolkit.models.base_model import BaseModel


class FakeTextEncoder(torch.nn.Module):
    def __init__(self, device, dtype):
        super().__init__()
        # register a dummy parameter to avoid errors in some cases
        self.dummy_param = torch.nn.Parameter(torch.zeros(1))
        self._device = device
        self._dtype = dtype

    def forward(self, *args, **kwargs):
        raise NotImplementedError(
            "This is a fake text encoder and should not be used for inference."
        )

    @property
    def device(self):
        return self._device

    @property
    def dtype(self):
        return self._dtype

    def to(self, *args, **kwargs):
        return self


def unload_text_encoder(model: "BaseModel"):
    # unload the text encoder in a way that will work with all models and will not throw errors.
    # we need to make it appear as a text encoder module without actually having one, so all
    # .to() calls and the like will still work.

    if model.text_encoder is not None:
        if isinstance(model.text_encoder, list):
            text_encoder_list = []
            pipe = model.pipeline

            # the pipeline stores text encoders like text_encoder, text_encoder_2, text_encoder_3, etc.
            if hasattr(pipe, "text_encoder"):
                te = FakeTextEncoder(device=model.device_torch, dtype=model.torch_dtype)
                text_encoder_list.append(te)
                pipe.text_encoder = te

            i = 2
            while hasattr(pipe, f"text_encoder_{i}"):
                te = FakeTextEncoder(device=model.device_torch, dtype=model.torch_dtype)
                text_encoder_list.append(te)
                setattr(pipe, f"text_encoder_{i}", te)
                i += 1
            model.text_encoder = text_encoder_list
        else:
            # only has a single text encoder
            model.text_encoder = FakeTextEncoder(device=model.device_torch, dtype=model.torch_dtype)

    flush()
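
A usage sketch, assuming sd is an already-loaded model wrapper whose embeddings have been cached:

    unload_text_encoder(sd)        # swaps the real encoder(s) for FakeTextEncoder stubs and flushes VRAM
    sd.text_encoder[0].to('cuda')  # safe: .to() on the stub is a no-op and returns the stub (list case assumed)
    # any actual encode call would raise NotImplementedError, which is the point:
    # after caching, nothing should need the text encoder again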
@@ -389,22 +389,40 @@ export default function SimpleJob({
              onChange={value => setJobConfig(value, 'config.process[0].train.ema_config.use_ema')}
            />
          </FormGroup>
          <NumberInput
            label="EMA Decay"
            className="pt-2"
            value={jobConfig.config.process[0].train.ema_config?.ema_decay as number}
            onChange={value => setJobConfig(value, 'config.process[0].train.ema_config?.ema_decay')}
            placeholder="eg. 0.99"
            min={0}
          />
          <FormGroup label="Unload Text Encoder" className="pt-2">
            <div className="grid grid-cols-2 gap-2">
              <Checkbox
                label="Unload TE"
                checked={jobConfig.config.process[0].train.unload_text_encoder || false}
                onChange={value => setJobConfig(value, 'config.process[0].train.unload_text_encoder')}
              />
            </div>
          {jobConfig.config.process[0].train.ema_config?.use_ema && (
            <NumberInput
              label="EMA Decay"
              className="pt-2"
              value={jobConfig.config.process[0].train.ema_config?.ema_decay as number}
              onChange={value => setJobConfig(value, 'config.process[0].train.ema_config?.ema_decay')}
              placeholder="eg. 0.99"
              min={0}
            />
          )}

          <FormGroup label="Text Encoder Optimizations" className="pt-2">
            <Checkbox
              label="Unload TE"
              checked={jobConfig.config.process[0].train.unload_text_encoder || false}
              docKey={'train.unload_text_encoder'}
              onChange={(value) => {
                setJobConfig(value, 'config.process[0].train.unload_text_encoder')
                if (value) {
                  setJobConfig(false, 'config.process[0].train.cache_text_embeddings');
                }
              }}
            />
            <Checkbox
              label="Cache Text Embeddings"
              checked={jobConfig.config.process[0].train.cache_text_embeddings || false}
              docKey={'train.cache_text_embeddings'}
              onChange={(value) => {
                setJobConfig(value, 'config.process[0].train.cache_text_embeddings')
                if (value) {
                  setJobConfig(false, 'config.process[0].train.unload_text_encoder')
                }
              }}
            />
          </FormGroup>
        </div>
        <div>
@@ -416,21 +434,27 @@ export default function SimpleJob({
              onChange={value => setJobConfig(value, 'config.process[0].train.diff_output_preservation')}
            />
          </FormGroup>
          <NumberInput
            label="DOP Loss Multiplier"
            className="pt-2"
            value={jobConfig.config.process[0].train.diff_output_preservation_multiplier as number}
            onChange={value => setJobConfig(value, 'config.process[0].train.diff_output_preservation_multiplier')}
            placeholder="eg. 1.0"
            min={0}
          />
          <TextInput
            label="DOP Preservation Class"
            className="pt-2"
            value={jobConfig.config.process[0].train.diff_output_preservation_class as string}
            onChange={value => setJobConfig(value, 'config.process[0].train.diff_output_preservation_class')}
            placeholder="eg. woman"
          />
          {jobConfig.config.process[0].train.diff_output_preservation && (
            <>
              <NumberInput
                label="DOP Loss Multiplier"
                className="pt-2"
                value={jobConfig.config.process[0].train.diff_output_preservation_multiplier as number}
                onChange={value =>
                  setJobConfig(value, 'config.process[0].train.diff_output_preservation_multiplier')
                }
                placeholder="eg. 1.0"
                min={0}
              />
              <TextInput
                label="DOP Preservation Class"
                className="pt-2"
                value={jobConfig.config.process[0].train.diff_output_preservation_class as string}
                onChange={value => setJobConfig(value, 'config.process[0].train.diff_output_preservation_class')}
                placeholder="eg. woman"
              />
            </>
          )}
        </div>
      </div>
    </Card>
@@ -524,16 +548,14 @@ export default function SimpleJob({
                    checked={dataset.is_reg || false}
                    onChange={value => setJobConfig(value, `config.process[0].datasets[${i}].is_reg`)}
                  />
                  {
                    modelArch?.additionalSections?.includes('datasets.do_i2v') && (
                      <Checkbox
                        label="Do I2V"
                        checked={dataset.do_i2v || false}
                        onChange={value => setJobConfig(value, `config.process[0].datasets[${i}].do_i2v`)}
                        docKey="datasets.do_i2v"
                      />
                    )
                  }
                  {modelArch?.additionalSections?.includes('datasets.do_i2v') && (
                    <Checkbox
                      label="Do I2V"
                      checked={dataset.do_i2v || false}
                      onChange={value => setJobConfig(value, `config.process[0].datasets[${i}].do_i2v`)}
                      docKey="datasets.do_i2v"
                    />
                  )}
                </FormGroup>
              </div>
              <div>
@@ -66,6 +66,7 @@ export const defaultJobConfig: JobConfig = {
      weight_decay: 1e-4,
    },
    unload_text_encoder: false,
    cache_text_embeddings: false,
    lr: 0.0001,
    ema_config: {
      use_ema: false,
@@ -12,12 +12,12 @@ const docs: { [key: string]: ConfigDoc } = {
      </>
    ),
  },
  'gpuids': {
  gpuids: {
    title: 'GPU ID',
    description: (
      <>
        This is the GPU that will be used for training. Only one GPU can be used per job at a time via the UI currently.
        However, you can start multiple jobs in parallel, each using a different GPU.
        This is the GPU that will be used for training. Only one GPU can be used per job at a time via the UI currently.
        However, you can start multiple jobs in parallel, each using a different GPU.
      </>
    ),
  },
@@ -25,17 +25,19 @@ const docs: { [key: string]: ConfigDoc } = {
    title: 'Trigger Word',
    description: (
      <>
        Optional: This will be the word or token used to trigger your concept or character.
        <br />
        <br />
        When using a trigger word,
        If your captions do not contain the trigger word, it will be added automatically the beginning of the caption. If you do not have
        captions, the caption will become just the trigger word. If you want to have variable trigger words in your captions to put it in different spots,
        you can use the <code>{'[trigger]'}</code> placeholder in your captions. This will be automatically replaced with your trigger word.
        <br />
        <br />
        Trigger words will not automatically be added to your test prompts, so you will need to either add your trigger word manually or use the
        <code>{'[trigger]'}</code> placeholder in your test prompts as well.
        Optional: This will be the word or token used to trigger your concept or character.
        <br />
        <br />
        When using a trigger word, if your captions do not contain the trigger word, it will be added automatically to
        the beginning of the caption. If you do not have captions, the caption will become just the trigger word. If
        you want to have variable trigger words in your captions to put it in different spots, you can use the{' '}
        <code>{'[trigger]'}</code> placeholder in your captions. This will be automatically replaced with your trigger
        word.
        <br />
        <br />
        Trigger words will not automatically be added to your test prompts, so you will need to either add your trigger
        word manually or use the{' '}
        <code>{'[trigger]'}</code> placeholder in your test prompts as well.
      </>
    ),
  },
@@ -43,8 +45,9 @@ const docs: { [key: string]: ConfigDoc } = {
    title: 'Name or Path',
    description: (
      <>
        The name of a diffusers repo on Huggingface or the local path to the base model you want to train from. The folder needs to be in
        diffusers format for most models. For some models, such as SDXL and SD1, you can put the path to an all in one safetensors checkpoint here.
        The name of a diffusers repo on Huggingface or the local path to the base model you want to train from. The
        folder needs to be in diffusers format for most models. For some models, such as SDXL and SD1, you can put the
        path to an all in one safetensors checkpoint here.
      </>
    ),
  },
@@ -52,8 +55,8 @@ const docs: { [key: string]: ConfigDoc } = {
    title: 'Control Dataset',
    description: (
      <>
        The control dataset needs to have files that match the filenames of your training dataset. They should be matching file pairs.
        These images are fed as control/input images during training.
        The control dataset needs to have files that match the filenames of your training dataset. They should be
        matching file pairs. These images are fed as control/input images during training.
      </>
    ),
  },
@@ -61,16 +64,19 @@ const docs: { [key: string]: ConfigDoc } = {
    title: 'Number of Frames',
    description: (
      <>
        This sets the number of frames to shrink videos to for a video dataset. If this dataset is images, set this to 1 for one frame.
        If your dataset is only videos, frames will be extracted evenly spaced from the videos in the dataset.
        <br/>
        <br/>
        It is best to trim your videos to the proper length before training. Wan is 16 frames a second. Doing 81 frames will result in a 5 second video.
        So you would want all of your videos trimmed to around 5 seconds for best results.
        <br/>
        <br/>
        Example: Setting this to 81 and having 2 videos in your dataset, one is 2 seconds and one is 90 seconds long, will result in 81
        evenly spaced frames for each video making the 2 second video appear slow and the 90second video appear very fast.
        This sets the number of frames to shrink videos to for a video dataset. If this dataset is images, set this to 1
        for one frame. If your dataset is only videos, frames will be extracted evenly spaced from the videos in the
        dataset.
        <br />
        <br />
        It is best to trim your videos to the proper length before training. Wan is 16 frames a second. Doing 81 frames
        will result in a 5 second video. So you would want all of your videos trimmed to around 5 seconds for best
        results.
        <br />
        <br />
        Example: Setting this to 81 and having 2 videos in your dataset, one is 2 seconds and one is 90 seconds long,
        will result in 81 evenly spaced frames for each video, making the 2 second video appear slow and the 90 second
        video appear very fast.
      </>
    ),
  },
@@ -78,9 +84,30 @@ const docs: { [key: string]: ConfigDoc } = {
    title: 'Do I2V',
    description: (
      <>
        For video models that can handle both I2V (Image to Video) and T2V (Text to Video), this option sets this dataset
        to be trained as an I2V dataset. This means that the first frame will be extracted from the video and used as the start image
        for the video. If this option is not set, the dataset will be treated as a T2V dataset.
        For video models that can handle both I2V (Image to Video) and T2V (Text to Video), this option sets this
        dataset to be trained as an I2V dataset. This means that the first frame will be extracted from the video and
        used as the start image for the video. If this option is not set, the dataset will be treated as a T2V dataset.
      </>
    ),
  },
  'train.unload_text_encoder': {
    title: 'Unload Text Encoder',
    description: (
      <>
        Unloading the text encoder will cache the trigger word and the sample prompts and unload the text encoder from
        the GPU. Captions for the dataset will be ignored.
      </>
    ),
  },
  'train.cache_text_embeddings': {
    title: 'Cache Text Embeddings',
    description: (
      <>
        <small>(experimental)</small>
        <br />
        Caching text embeddings will process and cache all the text embeddings from the text encoder to the disk. The
        text encoder will be unloaded from the GPU. This does not work with things that dynamically change the prompt,
        such as trigger words, caption dropout, etc.
      </>
    ),
  },
@@ -110,6 +110,7 @@ export interface TrainConfig {
  ema_config?: EMAConfig;
  dtype: string;
  unload_text_encoder: boolean;
  cache_text_embeddings: boolean;
  optimizer_params: {
    weight_decay: number;
  };
@@ -1 +1 @@
VERSION = "0.3.18"
VERSION = "0.4.0"