Do caching of latents, first frame and audio when caching latents for LTX2

This commit is contained in:
Jaret Burkett
2026-01-14 11:05:23 -07:00
parent 64fe29b182
commit 73dedbf662
6 changed files with 324 additions and 127 deletions

View File

@@ -456,8 +456,6 @@ class ImageProcessingDTOMixin:
transform: Union[None, transforms.Compose],
only_load_latents=False
):
if self.is_latent_cached:
raise Exception('Latent caching not supported for videos')
if self.augments is not None and len(self.augments) > 0:
raise Exception('Augments not supported for videos')
@@ -727,9 +725,6 @@ class ImageProcessingDTOMixin:
transform: Union[None, transforms.Compose],
only_load_latents=False
):
if self.dataset_config.num_frames > 1:
self.load_and_process_video(transform, only_load_latents)
return
# handle get_prompt_embedding
if self.is_text_embedding_cached:
self.load_prompt_embedding()
@@ -747,6 +742,9 @@ class ImageProcessingDTOMixin:
if self.has_unconditional:
self.load_unconditional_image()
return
if self.dataset_config.num_frames > 1:
self.load_and_process_video(transform, only_load_latents)
return
try:
img = Image.open(self.path)
img = exif_transpose(img)
@@ -1716,6 +1714,8 @@ class LatentCachingFileItemDTOMixin:
if hasattr(super(), '__init__'):
super().__init__(*args, **kwargs)
self._encoded_latent: Union[torch.Tensor, None] = None
self._cached_first_frame_latent: Union[torch.Tensor, None] = None
self._cached_audio_latent: Union[torch.Tensor, None] = None
self._latent_path: Union[str, None] = None
self.is_latent_cached = False
self.is_caching_to_disk = False
@@ -1745,6 +1745,14 @@ class LatentCachingFileItemDTOMixin:
item["flip_y"] = True
if self.dataset_config.num_frames > 1:
item["num_frames"] = self.dataset_config.num_frames
if self.dataset_config.do_i2v:
item["do_i2v"] = True
if self.dataset_config.do_audio:
item["do_audio"] = True
if self.dataset_config.audio_normalize:
item["audio_normalize"] = True
if self.dataset_config.audio_preserve_pitch:
item["audio_preserve_pitch"] = True
return item
def get_latent_path(self: 'FileItemDTO', recalculate=False):
@@ -1769,9 +1777,15 @@ class LatentCachingFileItemDTOMixin:
if not self.is_caching_to_memory:
# we are caching on disk, don't save in memory
self._encoded_latent = None
self._cached_first_frame_latent = None
self._cached_audio_latent = None
else:
# move it back to cpu
self._encoded_latent = self._encoded_latent.to('cpu')
if self._cached_first_frame_latent is not None:
self._cached_first_frame_latent = self._cached_first_frame_latent.to('cpu')
if self._cached_audio_latent is not None:
self._cached_audio_latent = self._cached_audio_latent.to('cpu')
def get_latent(self, device=None):
if not self.is_latent_cached:
@@ -1784,6 +1798,10 @@ class LatentCachingFileItemDTOMixin:
device='cpu'
)
self._encoded_latent = state_dict['latent']
if 'first_frame_latent' in state_dict:
self._cached_first_frame_latent = state_dict['first_frame_latent']
if 'audio_latent' in state_dict:
self._cached_audio_latent = state_dict['audio_latent']
return self._encoded_latent
@@ -1795,8 +1813,6 @@ class LatentCachingMixin:
self.latent_cache = {}
def cache_latents_all_latents(self: 'AiToolkitDataset'):
if self.dataset_config.num_frames > 1:
raise Exception("Error: caching latents is not supported for multi-frame datasets")
with accelerator.main_process_first():
print_acc(f"Caching latents for {self.dataset_path}")
# cache all latents to disk
@@ -1839,25 +1855,50 @@ class LatentCachingMixin:
# load it into memory
state_dict = load_file(latent_path, device='cpu')
file_item._encoded_latent = state_dict['latent'].to('cpu', dtype=self.sd.torch_dtype)
if 'first_frame_latent' in state_dict:
file_item._cached_first_frame_latent = state_dict['first_frame_latent'].to('cpu', dtype=self.sd.torch_dtype)
if 'audio_latent' in state_dict:
file_item._cached_audio_latent = state_dict['audio_latent'].to('cpu', dtype=self.sd.torch_dtype)
else:
# not saved to disk, calculate
# load the image first
file_item.load_and_process_image(self.transform, only_load_latents=True)
dtype = self.sd.torch_dtype
device = self.sd.device_torch
state_dict = OrderedDict()
first_frame_latent = None
audio_latent = None
# add batch dimension
try:
imgs = file_item.tensor.unsqueeze(0).to(device, dtype=dtype)
latent = self.sd.encode_images(imgs).squeeze(0)
if to_disk:
state_dict['latent'] = latent.clone().detach().cpu()
except Exception as e:
print_acc(f"Error processing image: {file_item.path}")
print_acc(f"Error: {str(e)}")
raise e
# do first frame
if self.dataset_config.num_frames > 1 and self.dataset_config.do_i2v:
frames = file_item.tensor.unsqueeze(0).to(device, dtype=dtype)
if len(frames.shape) == 4:
first_frames = frames
elif len(frames.shape) == 5:
first_frames = frames[:, 0]
else:
raise ValueError(f"Unknown frame shape {frames.shape}")
first_frame_latent = self.sd.encode_images(first_frames).squeeze(0)
if to_disk:
state_dict['first_frame_latent'] = first_frame_latent.clone().detach().cpu()
# audio
if file_item.audio_data is not None:
audio_latent = self.sd.encode_audio([file_item.audio_data]).squeeze(0)
if to_disk:
state_dict['audio_latent'] = audio_latent.clone().detach().cpu()
# save_latent
if to_disk:
state_dict = OrderedDict([
('latent', latent.clone().detach().cpu()),
])
# metadata
meta = get_meta_for_safetensors(file_item.get_latent_info_dict())
os.makedirs(os.path.dirname(latent_path), exist_ok=True)
@@ -1866,17 +1907,18 @@ class LatentCachingMixin:
if to_memory:
# keep it in memory
file_item._encoded_latent = latent.to('cpu', dtype=self.sd.torch_dtype)
if first_frame_latent is not None:
file_item._cached_first_frame_latent = first_frame_latent.to('cpu', dtype=self.sd.torch_dtype)
if audio_latent is not None:
file_item._cached_audio_latent = audio_latent.to('cpu', dtype=self.sd.torch_dtype)
del imgs
del latent
del file_item.tensor
file_item.cleanup()
# flush(garbage_collect=False)
file_item.is_latent_cached = True
i += 1
# flush every 100
# if i % 100 == 0:
# flush()
# restore device state
self.sd.restore_device_state()