From 73dedbf662ca604a3035daff2d2ba4635473b7bd Mon Sep 17 00:00:00 2001 From: Jaret Burkett Date: Wed, 14 Jan 2026 11:05:23 -0700 Subject: [PATCH] Do caching of latents, first frame and audio when caching latents for LTX2 --- .../diffusion_models/ltx2/ltx2.py | 149 ++++++++---- toolkit/data_transfer_object/data_loader.py | 222 ++++++++++++------ toolkit/dataloader_mixins.py | 70 ++++-- toolkit/models/base_model.py | 4 + toolkit/stable_diffusion_model.py | 4 + version.py | 2 +- 6 files changed, 324 insertions(+), 127 deletions(-) diff --git a/extensions_built_in/diffusion_models/ltx2/ltx2.py b/extensions_built_in/diffusion_models/ltx2/ltx2.py index 37f50567..6e50d3bc 100644 --- a/extensions_built_in/diffusion_models/ltx2/ltx2.py +++ b/extensions_built_in/diffusion_models/ltx2/ltx2.py @@ -98,6 +98,41 @@ def blank_log_image_function(self, *args, **kwargs): return +class ComboVae(torch.nn.Module): + """Combines video and audio VAEs for joint encoding and decoding.""" + + def __init__( + self, + vae: AutoencoderKLLTX2Video, + audio_vae: AutoencoderKLLTX2Audio, + ) -> None: + super().__init__() + self.vae = vae + self.audio_vae = audio_vae + + @property + def device(self): + return self.vae.device + + @property + def dtype(self): + return self.vae.dtype + + def encode( + self, + *args, + **kwargs, + ): + return self.vae.encode(*args, **kwargs) + + def decode( + self, + *args, + **kwargs, + ): + return self.vae.decode(*args, **kwargs) + + class AudioProcessor(torch.nn.Module): """Converts audio waveforms to log-mel spectrograms with optional resampling.""" @@ -445,7 +480,7 @@ class LTX2Model(BaseModel): flush() # save it to the model class - self.vae = vae + self.vae = ComboVae(pipe.vae, pipe.audio_vae) self.text_encoder = text_encoder # list of text encoders self.tokenizer = tokenizer # list of tokenizers self.model = pipe.transformer @@ -467,10 +502,10 @@ class LTX2Model(BaseModel): if dtype is None: dtype = self.vae_torch_dtype - if self.vae.device == torch.device("cpu"): - self.vae.to(device) - self.vae.eval() - self.vae.requires_grad_(False) + if self.pipeline.vae.device == torch.device("cpu"): + self.pipeline.vae.to(device) + self.pipeline.vae.eval() + self.pipeline.vae.requires_grad_(False) image_list = [image.to(device, dtype=dtype) for image in image_list] @@ -489,7 +524,7 @@ class LTX2Model(BaseModel): # Stack to (B, C, T, H, W) images = torch.stack(norm_images) - latents = self.vae.encode(images).latent_dist.mode() + latents = self.pipeline.vae.encode(images).latent_dist.mode() # Normalize latents across the channel dimension [B, C, F, H, W] scaling_factor = 1.0 @@ -534,7 +569,7 @@ class LTX2Model(BaseModel): ): if self.model.device == torch.device("cpu"): self.model.to(self.device_torch) - + # handle control image if gen_config.ctrl_img is not None: # switch to image to video pipeline @@ -566,12 +601,14 @@ class LTX2Model(BaseModel): bd = self.get_bucket_divisibility() gen_config.height = (gen_config.height // bd) * bd gen_config.width = (gen_config.width // bd) * bd - + # handle control image if gen_config.ctrl_img is not None: control_img = Image.open(gen_config.ctrl_img).convert("RGB") # resize the control image - control_img = control_img.resize((gen_config.width, gen_config.height), Image.LANCZOS) + control_img = control_img.resize( + (gen_config.width, gen_config.height), Image.LANCZOS + ) # add the control image to the extra dict extra["image"] = control_img @@ -642,7 +679,8 @@ class LTX2Model(BaseModel): img = video[0] return img - def encode_audio(self, batch: 
"DataLoaderBatchDTO"): + def encode_audio(self, audio_data_list): + # audio_date_list is a list of {"waveform": waveform[C, L], "sample_rate": int(sample_rate)} if self.pipeline.audio_vae.device == torch.device("cpu"): self.pipeline.audio_vae.to(self.device_torch) @@ -650,7 +688,7 @@ class LTX2Model(BaseModel): audio_num_frames = None # do them seperatly for now - for audio_data in batch.audio_data: + for audio_data in audio_data_list: waveform = audio_data["waveform"].to( device=self.device_torch, dtype=torch.float32 ) @@ -659,26 +697,30 @@ class LTX2Model(BaseModel): # Add batch dimension if needed: [channels, samples] -> [batch, channels, samples] if waveform.dim() == 2: waveform = waveform.unsqueeze(0) - + if waveform.shape[1] == 1: # make sure it is stereo waveform = waveform.repeat(1, 2, 1) - + # Convert waveform to mel spectrogram using AudioProcessor - mel_spectrogram = self.audio_processor.waveform_to_mel(waveform, waveform_sample_rate=sample_rate) + mel_spectrogram = self.audio_processor.waveform_to_mel( + waveform, waveform_sample_rate=sample_rate + ) mel_spectrogram = mel_spectrogram.to(dtype=self.torch_dtype) # Encode mel spectrogram to latents - latents = self.pipeline.audio_vae.encode(mel_spectrogram.to(self.device_torch, dtype=self.torch_dtype)).latent_dist.mode() + latents = self.pipeline.audio_vae.encode( + mel_spectrogram.to(self.device_torch, dtype=self.torch_dtype) + ).latent_dist.mode() if audio_num_frames is None: - audio_num_frames = latents.shape[2] #(latents is [B, C, T, F]) + audio_num_frames = latents.shape[2] # (latents is [B, C, T, F]) packed_latents = self.pipeline._pack_audio_latents( latents, # patch_size=self.pipeline.transformer.config.audio_patch_size, # patch_size_t=self.pipeline.transformer.config.audio_patch_size_t, - ) # [B, L, C * M] + ) # [B, L, C * M] if output_tensor is None: output_tensor = packed_latents else: @@ -688,7 +730,7 @@ class LTX2Model(BaseModel): latents_mean = self.pipeline.audio_vae.latents_mean latents_std = self.pipeline.audio_vae.latents_std output_tensor = (output_tensor - latents_mean) / latents_std - return output_tensor, audio_num_frames + return output_tensor def get_noise_prediction( self, @@ -701,38 +743,57 @@ class LTX2Model(BaseModel): with torch.no_grad(): if self.model.device == torch.device("cpu"): self.model.to(self.device_torch) - + batch_size, C, latent_num_frames, latent_height, latent_width = ( latent_model_input.shape ) - + video_timestep = timestep.clone() - + # i2v from first frame if batch.dataset_config.do_i2v: - # videos come in (bs, num_frames, channels, height, width) - # images come in (bs, channels, height, width) - frames = batch.tensor - if len(frames.shape) == 4: - first_frames = frames - elif len(frames.shape) == 5: - first_frames = frames[:, 0] + # check to see if we had it cached + if batch.first_frame_latents is not None: + init_latents = batch.first_frame_latents.to( + self.device_torch, dtype=self.torch_dtype + ) else: - raise ValueError(f"Unknown frame shape {frames.shape}") - # first frame doesnt have time dim, add it back - init_latents = self.encode_images(first_frames, device=self.device_torch, dtype=self.torch_dtype) + # extract the first frame and encode it + # videos come in (bs, num_frames, channels, height, width) + # images come in (bs, channels, height, width) + frames = batch.tensor + if len(frames.shape) == 4: + first_frames = frames + elif len(frames.shape) == 5: + first_frames = frames[:, 0] + else: + raise ValueError(f"Unknown frame shape {frames.shape}") + # first frame 
doesnt have time dim, add it back + init_latents = self.encode_images( + first_frames, device=self.device_torch, dtype=self.torch_dtype + ) + + # expand the latents to match video frames init_latents = init_latents.repeat(1, 1, latent_num_frames, 1, 1) - mask_shape = (batch_size, 1, latent_num_frames, latent_height, latent_width) + mask_shape = ( + batch_size, + 1, + latent_num_frames, + latent_height, + latent_width, + ) # First condition is image latents and those should be kept clean. - conditioning_mask = torch.zeros(mask_shape, device=self.device_torch, dtype=self.torch_dtype) + conditioning_mask = torch.zeros( + mask_shape, device=self.device_torch, dtype=self.torch_dtype + ) conditioning_mask[:, :, 0] = 1.0 - + # use conditioning mask to replace latents latent_model_input = ( latent_model_input * (1 - conditioning_mask) + init_latents * conditioning_mask ) - + # set video timestep video_timestep = timestep.unsqueeze(-1) * (1 - conditioning_mask) @@ -746,10 +807,18 @@ class LTX2Model(BaseModel): patch_size_t=self.pipeline.transformer_temporal_patch_size, ) - if batch.audio_tensor is not None: - # use audio from the batch if available - #(1, 190, 128) - raw_audio_latents, audio_num_frames = self.encode_audio(batch) + if batch.audio_latents is not None or batch.audio_tensor is not None: + if batch.audio_latents is not None: + # we have audio latents cached + raw_audio_latents = batch.audio_latents.to( + self.device_torch, dtype=self.torch_dtype + ) + else: + # we have audio waveforms to encode + # use audio from the batch if available + raw_audio_latents = self.encode_audio(batch.audio_data) + + audio_num_frames = raw_audio_latents.shape[1] # add the audio targets to the batch for loss calculation later audio_noise = torch.randn_like(raw_audio_latents) batch.audio_target = (audio_noise - raw_audio_latents).detach() @@ -758,7 +827,6 @@ class LTX2Model(BaseModel): audio_noise, timestep, ).to(self.device_torch, dtype=self.torch_dtype) - else: # no audio num_mel_bins = self.pipeline.audio_vae.config.mel_bins @@ -766,7 +834,6 @@ class LTX2Model(BaseModel): num_channels_latents_audio = ( self.pipeline.audio_vae.config.latent_channels ) - # audio latents are (1, 126, 128), audio_num_frames = 126 audio_latents, audio_num_frames = self.pipeline.prepare_audio_latents( batch_size, num_channels_latents=num_channels_latents_audio, diff --git a/toolkit/data_transfer_object/data_loader.py b/toolkit/data_transfer_object/data_loader.py index 5e76563e..7af8de01 100644 --- a/toolkit/data_transfer_object/data_loader.py +++ b/toolkit/data_transfer_object/data_loader.py @@ -1,24 +1,31 @@ import os -import weakref -from _weakref import ReferenceType from typing import TYPE_CHECKING, List, Union import cv2 import torch -import random from PIL import Image from PIL.ImageOps import exif_transpose from toolkit import image_utils from toolkit.basic import get_quick_signature_string -from toolkit.dataloader_mixins import CaptionProcessingDTOMixin, ImageProcessingDTOMixin, LatentCachingFileItemDTOMixin, \ - ControlFileItemDTOMixin, ArgBreakMixin, PoiFileItemDTOMixin, MaskFileItemDTOMixin, AugmentationFileItemDTOMixin, \ - UnconditionalFileItemDTOMixin, ClipImageFileItemDTOMixin, InpaintControlFileItemDTOMixin, TextEmbeddingFileItemDTOMixin +from toolkit.dataloader_mixins import ( + CaptionProcessingDTOMixin, + ImageProcessingDTOMixin, + LatentCachingFileItemDTOMixin, + ControlFileItemDTOMixin, + ArgBreakMixin, + PoiFileItemDTOMixin, + MaskFileItemDTOMixin, + AugmentationFileItemDTOMixin, + 
UnconditionalFileItemDTOMixin, + ClipImageFileItemDTOMixin, + InpaintControlFileItemDTOMixin, + TextEmbeddingFileItemDTOMixin, +) from toolkit.prompt_utils import PromptEmbeds, concat_prompt_embeds if TYPE_CHECKING: from toolkit.config_modules import DatasetConfig - from toolkit.stable_diffusion_model import StableDiffusion printed_messages = [] @@ -45,54 +52,62 @@ class FileItemDTO( ArgBreakMixin, ): def __init__(self, *args, **kwargs): - self.path = kwargs.get('path', '') - self.dataset_config: 'DatasetConfig' = kwargs.get('dataset_config', None) + self.path = kwargs.get("path", "") + self.dataset_config: "DatasetConfig" = kwargs.get("dataset_config", None) self.is_video = self.dataset_config.num_frames > 1 - size_database = kwargs.get('size_database', {}) - dataset_root = kwargs.get('dataset_root', None) - self.encode_control_in_text_embeddings = kwargs.get('encode_control_in_text_embeddings', False) + size_database = kwargs.get("size_database", {}) + dataset_root = kwargs.get("dataset_root", None) + self.encode_control_in_text_embeddings = kwargs.get( + "encode_control_in_text_embeddings", False + ) if dataset_root is not None: # remove dataset root from path - file_key = self.path.replace(dataset_root, '') + file_key = self.path.replace(dataset_root, "") else: file_key = os.path.basename(self.path) - + file_signature = get_quick_signature_string(self.path) if file_signature is None: raise Exception("Error: Could not get file signature for {self.path}") - + use_db_entry = False if file_key in size_database: db_entry = size_database[file_key] - if db_entry is not None and len(db_entry) >= 3 and db_entry[2] == file_signature: + if ( + db_entry is not None + and len(db_entry) >= 3 + and db_entry[2] == file_signature + ): use_db_entry = True - + if use_db_entry: w, h, _ = size_database[file_key] elif self.is_video: # Open the video file video = cv2.VideoCapture(self.path) - + # Check if video opened successfully if not video.isOpened(): raise Exception(f"Error: Could not open video file {self.path}") - + # Get width and height width = int(video.get(cv2.CAP_PROP_FRAME_WIDTH)) height = int(video.get(cv2.CAP_PROP_FRAME_HEIGHT)) w, h = width, height - + # Release the video capture object immediately video.release() size_database[file_key] = (width, height, file_signature) else: if self.dataset_config.fast_image_size: - # original method is significantly faster, but some images are read sideways. Not sure why. Do slow method by default. + # original method is significantly faster, but some images are read sideways. Not sure why. Do slow method by default. try: w, h = image_utils.get_image_size(self.path) except image_utils.UnknownImageFormat: - print_once(f'Warning: Some images in the dataset cannot be fast read. ' + \ - f'This process is faster for png, jpeg') + print_once( + f"Warning: Some images in the dataset cannot be fast read. 
" + + f"This process is faster for png, jpeg" + ) img = exif_transpose(Image.open(self.path)) w, h = img.size else: @@ -101,21 +116,25 @@ class FileItemDTO( size_database[file_key] = (w, h, file_signature) self.width: int = w self.height: int = h - self.dataloader_transforms = kwargs.get('dataloader_transforms', None) + self.dataloader_transforms = kwargs.get("dataloader_transforms", None) super().__init__(*args, **kwargs) # self.caption_path: str = kwargs.get('caption_path', None) - self.raw_caption: str = kwargs.get('raw_caption', None) + self.raw_caption: str = kwargs.get("raw_caption", None) # we scale first, then crop - self.scale_to_width: int = kwargs.get('scale_to_width', int(self.width * self.dataset_config.scale)) - self.scale_to_height: int = kwargs.get('scale_to_height', int(self.height * self.dataset_config.scale)) + self.scale_to_width: int = kwargs.get( + "scale_to_width", int(self.width * self.dataset_config.scale) + ) + self.scale_to_height: int = kwargs.get( + "scale_to_height", int(self.height * self.dataset_config.scale) + ) # crop values are from scaled size - self.crop_x: int = kwargs.get('crop_x', 0) - self.crop_y: int = kwargs.get('crop_y', 0) - self.crop_width: int = kwargs.get('crop_width', self.scale_to_width) - self.crop_height: int = kwargs.get('crop_height', self.scale_to_height) - self.flip_x: bool = kwargs.get('flip_x', False) - self.flip_y: bool = kwargs.get('flip_x', False) + self.crop_x: int = kwargs.get("crop_x", 0) + self.crop_y: int = kwargs.get("crop_y", 0) + self.crop_width: int = kwargs.get("crop_width", self.scale_to_width) + self.crop_height: int = kwargs.get("crop_height", self.scale_to_height) + self.flip_x: bool = kwargs.get("flip_x", False) + self.flip_y: bool = kwargs.get("flip_x", False) self.augments: List[str] = self.dataset_config.augments self.loss_multiplier: float = self.dataset_config.loss_multiplier @@ -142,9 +161,8 @@ class FileItemDTO( class DataLoaderBatchDTO: def __init__(self, **kwargs): try: - self.file_items: List['FileItemDTO'] = kwargs.get('file_items', None) + self.file_items: List["FileItemDTO"] = kwargs.get("file_items", None) is_latents_cached = self.file_items[0].is_latent_cached - is_text_embedding_cached = self.file_items[0].is_text_embedding_cached self.tensor: Union[torch.Tensor, None] = None self.latents: Union[torch.Tensor, None] = None self.control_tensor: Union[torch.Tensor, None] = None @@ -156,22 +174,64 @@ class DataLoaderBatchDTO: self.unconditional_latents: Union[torch.Tensor, None] = None self.clip_image_embeds: Union[List[dict], None] = None self.clip_image_embeds_unconditional: Union[List[dict], None] = None - self.sigmas: Union[torch.Tensor, None] = None # can be added elseware and passed along training code - self.extra_values: Union[torch.Tensor, None] = torch.tensor([x.extra_values for x in self.file_items]) if len(self.file_items[0].extra_values) > 0 else None - self.audio_data: Union[List, None] = [x.audio_data for x in self.file_items] if self.file_items[0].audio_data is not None else None + self.sigmas: Union[torch.Tensor, None] = ( + None # can be added elseware and passed along training code + ) + self.extra_values: Union[torch.Tensor, None] = ( + torch.tensor([x.extra_values for x in self.file_items]) + if len(self.file_items[0].extra_values) > 0 + else None + ) + self.audio_data: Union[List, None] = ( + [x.audio_data for x in self.file_items] + if self.file_items[0].audio_data is not None + else None + ) self.audio_tensor: Union[torch.Tensor, None] = None - + self.first_frame_latents: 
Union[torch.Tensor, None] = None + self.audio_latents: Union[torch.Tensor, None] = None + # just for holding noise and preds during training self.audio_target: Union[torch.Tensor, None] = None self.audio_pred: Union[torch.Tensor, None] = None - + if not is_latents_cached: # only return a tensor if latents are not cached - self.tensor: torch.Tensor = torch.cat([x.tensor.unsqueeze(0) for x in self.file_items]) + self.tensor: torch.Tensor = torch.cat( + [x.tensor.unsqueeze(0) for x in self.file_items] + ) # if we have encoded latents, we concatenate them self.latents: Union[torch.Tensor, None] = None if is_latents_cached: - self.latents = torch.cat([x.get_latent().unsqueeze(0) for x in self.file_items]) + # this get_latent call with trigger loading all cached items from the disk + self.latents = torch.cat( + [x.get_latent().unsqueeze(0) for x in self.file_items] + ) + if any( + [x._cached_first_frame_latent is not None for x in self.file_items] + ): + self.first_frame_latents = torch.cat( + [ + x._cached_first_frame_latent.unsqueeze(0) + if x._cached_first_frame_latent is not None + else torch.zeros_like( + self.file_items[0]._cached_first_frame_latent + ).unsqueeze(0) + for x in self.file_items + ] + ) + if any([x._cached_audio_latent is not None for x in self.file_items]): + self.audio_latents = torch.cat( + [ + x._cached_audio_latent.unsqueeze(0) + if x._cached_audio_latent is not None + else torch.zeros_like( + self.file_items[0]._cached_audio_latent + ).unsqueeze(0) + for x in self.file_items + ] + ) + self.prompt_embeds: Union[PromptEmbeds, None] = None # if self.file_items[0].control_tensor is not None: # if any have a control tensor, we concatenate them @@ -188,8 +248,10 @@ class DataLoaderBatchDTO: control_tensors.append(torch.zeros_like(base_control_tensor)) else: control_tensors.append(x.control_tensor) - self.control_tensor = torch.cat([x.unsqueeze(0) for x in control_tensors]) - + self.control_tensor = torch.cat( + [x.unsqueeze(0) for x in control_tensors] + ) + # handle control tensor list if any([x.control_tensor_list is not None for x in self.file_items]): self.control_tensor_list = [] @@ -197,9 +259,10 @@ class DataLoaderBatchDTO: if x.control_tensor_list is not None: self.control_tensor_list.append(x.control_tensor_list) else: - raise Exception(f"Could not find control tensors for all file items, missing for {x.path}") - - + raise Exception( + f"Could not find control tensors for all file items, missing for {x.path}" + ) + self.inpaint_tensor: Union[torch.Tensor, None] = None if any([x.inpaint_tensor is not None for x in self.file_items]): # find one to use as a base @@ -214,9 +277,13 @@ class DataLoaderBatchDTO: inpaint_tensors.append(torch.zeros_like(base_inpaint_tensor)) else: inpaint_tensors.append(x.inpaint_tensor) - self.inpaint_tensor = torch.cat([x.unsqueeze(0) for x in inpaint_tensors]) + self.inpaint_tensor = torch.cat( + [x.unsqueeze(0) for x in inpaint_tensors] + ) - self.loss_multiplier_list: List[float] = [x.loss_multiplier for x in self.file_items] + self.loss_multiplier_list: List[float] = [ + x.loss_multiplier for x in self.file_items + ] if any([x.clip_image_tensor is not None for x in self.file_items]): # find one to use as a base @@ -228,10 +295,14 @@ class DataLoaderBatchDTO: clip_image_tensors = [] for x in self.file_items: if x.clip_image_tensor is None: - clip_image_tensors.append(torch.zeros_like(base_clip_image_tensor)) + clip_image_tensors.append( + torch.zeros_like(base_clip_image_tensor) + ) else: clip_image_tensors.append(x.clip_image_tensor) - 
self.clip_image_tensor = torch.cat([x.unsqueeze(0) for x in clip_image_tensors]) + self.clip_image_tensor = torch.cat( + [x.unsqueeze(0) for x in clip_image_tensors] + ) if any([x.mask_tensor is not None for x in self.file_items]): # find one to use as a base @@ -259,10 +330,14 @@ class DataLoaderBatchDTO: unaugmented_tensor = [] for x in self.file_items: if x.unaugmented_tensor is None: - unaugmented_tensor.append(torch.zeros_like(base_unaugmented_tensor)) + unaugmented_tensor.append( + torch.zeros_like(base_unaugmented_tensor) + ) else: unaugmented_tensor.append(x.unaugmented_tensor) - self.unaugmented_tensor = torch.cat([x.unsqueeze(0) for x in unaugmented_tensor]) + self.unaugmented_tensor = torch.cat( + [x.unsqueeze(0) for x in unaugmented_tensor] + ) # add unconditional tensors if any([x.unconditional_tensor is not None for x in self.file_items]): @@ -275,10 +350,14 @@ class DataLoaderBatchDTO: unconditional_tensor = [] for x in self.file_items: if x.unconditional_tensor is None: - unconditional_tensor.append(torch.zeros_like(base_unconditional_tensor)) + unconditional_tensor.append( + torch.zeros_like(base_unconditional_tensor) + ) else: unconditional_tensor.append(x.unconditional_tensor) - self.unconditional_tensor = torch.cat([x.unsqueeze(0) for x in unconditional_tensor]) + self.unconditional_tensor = torch.cat( + [x.unsqueeze(0) for x in unconditional_tensor] + ) if any([x.clip_image_embeds is not None for x in self.file_items]): self.clip_image_embeds = [] @@ -288,14 +367,20 @@ class DataLoaderBatchDTO: else: raise Exception("clip_image_embeds is None for some file items") - if any([x.clip_image_embeds_unconditional is not None for x in self.file_items]): + if any( + [x.clip_image_embeds_unconditional is not None for x in self.file_items] + ): self.clip_image_embeds_unconditional = [] for x in self.file_items: if x.clip_image_embeds_unconditional is not None: - self.clip_image_embeds_unconditional.append(x.clip_image_embeds_unconditional) + self.clip_image_embeds_unconditional.append( + x.clip_image_embeds_unconditional + ) else: - raise Exception("clip_image_embeds_unconditional is None for some file items") - + raise Exception( + "clip_image_embeds_unconditional is None for some file items" + ) + if any([x.prompt_embeds is not None for x in self.file_items]): # find one to use as a base base_prompt_embeds = None @@ -315,7 +400,7 @@ class DataLoaderBatchDTO: y.text_embeds = [y.text_embeds] prompt_embeds_list.append(y) self.prompt_embeds = concat_prompt_embeds(prompt_embeds_list) - + if any([x.audio_tensor is not None for x in self.file_items]): # find one to use as a base base_audio_tensor = None @@ -330,7 +415,6 @@ class DataLoaderBatchDTO: else: audio_tensors.append(x.audio_tensor) self.audio_tensor = torch.cat([x.unsqueeze(0) for x in audio_tensors]) - except Exception as e: print(e) @@ -343,18 +427,12 @@ class DataLoaderBatchDTO: return [x.network_weight for x in self.file_items] def get_caption_list( - self, - trigger=None, - to_replace_list=None, - add_if_not_present=True + self, trigger=None, to_replace_list=None, add_if_not_present=True ): return [x.caption for x in self.file_items] def get_caption_short_list( - self, - trigger=None, - to_replace_list=None, - add_if_not_present=True + self, trigger=None, to_replace_list=None, add_if_not_present=True ): return [x.caption_short for x in self.file_items] @@ -366,11 +444,13 @@ class DataLoaderBatchDTO: del self.audio_data del self.audio_target del self.audio_pred + del self.first_frame_latents + del self.audio_latents for 
file_item in self.file_items: file_item.cleanup() - + @property - def dataset_config(self) -> 'DatasetConfig': + def dataset_config(self) -> "DatasetConfig": if len(self.file_items) > 0: return self.file_items[0].dataset_config else: diff --git a/toolkit/dataloader_mixins.py b/toolkit/dataloader_mixins.py index 3b542f56..e4fb95bb 100644 --- a/toolkit/dataloader_mixins.py +++ b/toolkit/dataloader_mixins.py @@ -456,8 +456,6 @@ class ImageProcessingDTOMixin: transform: Union[None, transforms.Compose], only_load_latents=False ): - if self.is_latent_cached: - raise Exception('Latent caching not supported for videos') if self.augments is not None and len(self.augments) > 0: raise Exception('Augments not supported for videos') @@ -727,9 +725,6 @@ class ImageProcessingDTOMixin: transform: Union[None, transforms.Compose], only_load_latents=False ): - if self.dataset_config.num_frames > 1: - self.load_and_process_video(transform, only_load_latents) - return # handle get_prompt_embedding if self.is_text_embedding_cached: self.load_prompt_embedding() @@ -747,6 +742,9 @@ class ImageProcessingDTOMixin: if self.has_unconditional: self.load_unconditional_image() return + if self.dataset_config.num_frames > 1: + self.load_and_process_video(transform, only_load_latents) + return try: img = Image.open(self.path) img = exif_transpose(img) @@ -1716,6 +1714,8 @@ class LatentCachingFileItemDTOMixin: if hasattr(super(), '__init__'): super().__init__(*args, **kwargs) self._encoded_latent: Union[torch.Tensor, None] = None + self._cached_first_frame_latent: Union[torch.Tensor, None] = None + self._cached_audio_latent: Union[torch.Tensor, None] = None self._latent_path: Union[str, None] = None self.is_latent_cached = False self.is_caching_to_disk = False @@ -1745,6 +1745,14 @@ class LatentCachingFileItemDTOMixin: item["flip_y"] = True if self.dataset_config.num_frames > 1: item["num_frames"] = self.dataset_config.num_frames + if self.dataset_config.do_i2v: + item["do_i2v"] = True + if self.dataset_config.do_audio: + item["do_audio"] = True + if self.dataset_config.audio_normalize: + item["audio_normalize"] = True + if self.dataset_config.audio_preserve_pitch: + item["audio_preserve_pitch"] = True return item def get_latent_path(self: 'FileItemDTO', recalculate=False): @@ -1769,9 +1777,15 @@ class LatentCachingFileItemDTOMixin: if not self.is_caching_to_memory: # we are caching on disk, don't save in memory self._encoded_latent = None + self._cached_first_frame_latent = None + self._cached_audio_latent = None else: # move it back to cpu self._encoded_latent = self._encoded_latent.to('cpu') + if self._cached_first_frame_latent is not None: + self._cached_first_frame_latent = self._cached_first_frame_latent.to('cpu') + if self._cached_audio_latent is not None: + self._cached_audio_latent = self._cached_audio_latent.to('cpu') def get_latent(self, device=None): if not self.is_latent_cached: @@ -1784,6 +1798,10 @@ class LatentCachingFileItemDTOMixin: device='cpu' ) self._encoded_latent = state_dict['latent'] + if 'first_frame_latent' in state_dict: + self._cached_first_frame_latent = state_dict['first_frame_latent'] + if 'audio_latent' in state_dict: + self._cached_audio_latent = state_dict['audio_latent'] return self._encoded_latent @@ -1795,8 +1813,6 @@ class LatentCachingMixin: self.latent_cache = {} def cache_latents_all_latents(self: 'AiToolkitDataset'): - if self.dataset_config.num_frames > 1: - raise Exception("Error: caching latents is not supported for multi-frame datasets") with accelerator.main_process_first(): 
print_acc(f"Caching latents for {self.dataset_path}") # cache all latents to disk @@ -1839,25 +1855,50 @@ class LatentCachingMixin: # load it into memory state_dict = load_file(latent_path, device='cpu') file_item._encoded_latent = state_dict['latent'].to('cpu', dtype=self.sd.torch_dtype) + if 'first_frame_latent' in state_dict: + file_item._cached_first_frame_latent = state_dict['first_frame_latent'].to('cpu', dtype=self.sd.torch_dtype) + if 'audio_latent' in state_dict: + file_item._cached_audio_latent = state_dict['audio_latent'].to('cpu', dtype=self.sd.torch_dtype) else: # not saved to disk, calculate # load the image first file_item.load_and_process_image(self.transform, only_load_latents=True) dtype = self.sd.torch_dtype device = self.sd.device_torch + state_dict = OrderedDict() + first_frame_latent = None + audio_latent = None # add batch dimension try: imgs = file_item.tensor.unsqueeze(0).to(device, dtype=dtype) latent = self.sd.encode_images(imgs).squeeze(0) + if to_disk: + state_dict['latent'] = latent.clone().detach().cpu() except Exception as e: print_acc(f"Error processing image: {file_item.path}") print_acc(f"Error: {str(e)}") raise e + # do first frame + if self.dataset_config.num_frames > 1 and self.dataset_config.do_i2v: + frames = file_item.tensor.unsqueeze(0).to(device, dtype=dtype) + if len(frames.shape) == 4: + first_frames = frames + elif len(frames.shape) == 5: + first_frames = frames[:, 0] + else: + raise ValueError(f"Unknown frame shape {frames.shape}") + first_frame_latent = self.sd.encode_images(first_frames).squeeze(0) + if to_disk: + state_dict['first_frame_latent'] = first_frame_latent.clone().detach().cpu() + + # audio + if file_item.audio_data is not None: + audio_latent = self.sd.encode_audio([file_item.audio_data]).squeeze(0) + if to_disk: + state_dict['audio_latent'] = audio_latent.clone().detach().cpu() + # save_latent if to_disk: - state_dict = OrderedDict([ - ('latent', latent.clone().detach().cpu()), - ]) # metadata meta = get_meta_for_safetensors(file_item.get_latent_info_dict()) os.makedirs(os.path.dirname(latent_path), exist_ok=True) @@ -1866,17 +1907,18 @@ class LatentCachingMixin: if to_memory: # keep it in memory file_item._encoded_latent = latent.to('cpu', dtype=self.sd.torch_dtype) + if first_frame_latent is not None: + file_item._cached_first_frame_latent = first_frame_latent.to('cpu', dtype=self.sd.torch_dtype) + if audio_latent is not None: + file_item._cached_audio_latent = audio_latent.to('cpu', dtype=self.sd.torch_dtype) del imgs del latent del file_item.tensor + file_item.cleanup() - # flush(garbage_collect=False) file_item.is_latent_cached = True i += 1 - # flush every 100 - # if i % 100 == 0: - # flush() # restore device state self.sd.restore_device_state() diff --git a/toolkit/models/base_model.py b/toolkit/models/base_model.py index 23bd9a9c..1a9f23e7 100644 --- a/toolkit/models/base_model.py +++ b/toolkit/models/base_model.py @@ -1122,6 +1122,10 @@ class BaseModel: latents = latents.to(device, dtype=dtype) return latents + + def encode_audio(self, audio_data_list): + # audio_date_list is a list of {"waveform": waveform[C, L], "sample_rate": int(sample_rate)} + raise NotImplementedError("Audio encoding not implemented for this model.") def decode_latents( self, diff --git a/toolkit/stable_diffusion_model.py b/toolkit/stable_diffusion_model.py index d2b34b9f..f397c3d7 100644 --- a/toolkit/stable_diffusion_model.py +++ b/toolkit/stable_diffusion_model.py @@ -2550,6 +2550,10 @@ class StableDiffusion: latents = latents.to(device, 
dtype=dtype)
 
         return latents
+    
+    def encode_audio(self, audio_data_list):
+        # audio_data_list is a list of {"waveform": waveform[C, L], "sample_rate": int(sample_rate)}
+        raise NotImplementedError("Audio encoding not implemented for this model.")
 
     def decode_latents(
         self,
diff --git a/version.py b/version.py
index 63d7599f..82b3d72e 100644
--- a/version.py
+++ b/version.py
@@ -1 +1 @@
-VERSION = "0.7.18"
+VERSION = "0.7.19"
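
Usage sketch (illustrative, not part of the patch): with this change, a cached latent file written by cache_latents_all_latents() may carry, in addition to the usual 'latent' key, an optional 'first_frame_latent' (only for i2v datasets with num_frames > 1) and an optional 'audio_latent' (only when audio data is present), matching the state_dict keys above. A minimal way to inspect such a file, assuming a hypothetical cache path:

    # inspect_latent_cache.py -- illustrative only; the file path is hypothetical
    from safetensors.torch import load_file

    state_dict = load_file("path/to/latent_cache/example.safetensors", device="cpu")
    latent = state_dict["latent"]                       # video/image latent, e.g. [C, F, H, W]
    first_frame = state_dict.get("first_frame_latent")  # present only when do_i2v and num_frames > 1
    audio = state_dict.get("audio_latent")              # packed audio latent, roughly [L, C * M]
    print(latent.shape)
    print(None if first_frame is None else first_frame.shape)
    print(None if audio is None else audio.shape)

During training these tensors surface on the batch as batch.latents, batch.first_frame_latents, and batch.audio_latents, which LTX2Model.get_noise_prediction() now prefers over re-encoding the first frame or the raw waveform.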