Cache first-frame and audio latents alongside video latents when caching latents for LTX2

Jaret Burkett
2026-01-14 11:05:23 -07:00
parent 64fe29b182
commit 73dedbf662
6 changed files with 324 additions and 127 deletions
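With this change, each cached .safetensors entry can hold up to three tensors: the video/image latent plus, when enabled, a first-frame latent (for i2v conditioning) and an audio latent. A minimal reading sketch (the helper name is illustrative; the key names are the ones this commit writes):

from safetensors.torch import load_file

def read_cache_entry(latent_path: str):
    # 'latent' is always written; the other keys exist only when the dataset
    # was cached with do_i2v and/or audio enabled
    state_dict = load_file(latent_path, device="cpu")
    return (
        state_dict["latent"],
        state_dict.get("first_frame_latent"),  # None unless cached for i2v
        state_dict.get("audio_latent"),        # None unless audio was cached
    )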

View File

@@ -1,24 +1,31 @@
import os
import weakref
from _weakref import ReferenceType
from typing import TYPE_CHECKING, List, Union
import cv2
import torch
import random
from PIL import Image
from PIL.ImageOps import exif_transpose
from toolkit import image_utils
from toolkit.basic import get_quick_signature_string
from toolkit.dataloader_mixins import CaptionProcessingDTOMixin, ImageProcessingDTOMixin, LatentCachingFileItemDTOMixin, \
ControlFileItemDTOMixin, ArgBreakMixin, PoiFileItemDTOMixin, MaskFileItemDTOMixin, AugmentationFileItemDTOMixin, \
UnconditionalFileItemDTOMixin, ClipImageFileItemDTOMixin, InpaintControlFileItemDTOMixin, TextEmbeddingFileItemDTOMixin
from toolkit.dataloader_mixins import (
CaptionProcessingDTOMixin,
ImageProcessingDTOMixin,
LatentCachingFileItemDTOMixin,
ControlFileItemDTOMixin,
ArgBreakMixin,
PoiFileItemDTOMixin,
MaskFileItemDTOMixin,
AugmentationFileItemDTOMixin,
UnconditionalFileItemDTOMixin,
ClipImageFileItemDTOMixin,
InpaintControlFileItemDTOMixin,
TextEmbeddingFileItemDTOMixin,
)
from toolkit.prompt_utils import PromptEmbeds, concat_prompt_embeds
if TYPE_CHECKING:
from toolkit.config_modules import DatasetConfig
from toolkit.stable_diffusion_model import StableDiffusion
printed_messages = []
@@ -45,54 +52,62 @@ class FileItemDTO(
ArgBreakMixin,
):
def __init__(self, *args, **kwargs):
self.path = kwargs.get('path', '')
self.dataset_config: 'DatasetConfig' = kwargs.get('dataset_config', None)
self.path = kwargs.get("path", "")
self.dataset_config: "DatasetConfig" = kwargs.get("dataset_config", None)
self.is_video = self.dataset_config.num_frames > 1
size_database = kwargs.get('size_database', {})
dataset_root = kwargs.get('dataset_root', None)
self.encode_control_in_text_embeddings = kwargs.get('encode_control_in_text_embeddings', False)
size_database = kwargs.get("size_database", {})
dataset_root = kwargs.get("dataset_root", None)
self.encode_control_in_text_embeddings = kwargs.get(
"encode_control_in_text_embeddings", False
)
if dataset_root is not None:
# remove dataset root from path
file_key = self.path.replace(dataset_root, '')
file_key = self.path.replace(dataset_root, "")
else:
file_key = os.path.basename(self.path)
file_signature = get_quick_signature_string(self.path)
if file_signature is None:
raise Exception("Error: Could not get file signature for {self.path}")
use_db_entry = False
if file_key in size_database:
db_entry = size_database[file_key]
if db_entry is not None and len(db_entry) >= 3 and db_entry[2] == file_signature:
if (
db_entry is not None
and len(db_entry) >= 3
and db_entry[2] == file_signature
):
use_db_entry = True
if use_db_entry:
w, h, _ = size_database[file_key]
elif self.is_video:
# Open the video file
video = cv2.VideoCapture(self.path)
# Check if video opened successfully
if not video.isOpened():
raise Exception(f"Error: Could not open video file {self.path}")
# Get width and height
width = int(video.get(cv2.CAP_PROP_FRAME_WIDTH))
height = int(video.get(cv2.CAP_PROP_FRAME_HEIGHT))
w, h = width, height
# Release the video capture object immediately
video.release()
size_database[file_key] = (width, height, file_signature)
else:
if self.dataset_config.fast_image_size:
# original method is significantly faster, but some images are read sideways (likely EXIF orientation being ignored), so do the slow method by default
try:
w, h = image_utils.get_image_size(self.path)
except image_utils.UnknownImageFormat:
print_once(f'Warning: Some images in the dataset cannot be fast read. ' + \
f'This process is faster for png, jpeg')
print_once(
"Warning: Some images in the dataset cannot be fast-read. "
"Fast reading only works for png and jpeg."
)
img = exif_transpose(Image.open(self.path))
w, h = img.size
else:
@@ -101,21 +116,25 @@ class FileItemDTO(
size_database[file_key] = (w, h, file_signature)
self.width: int = w
self.height: int = h
self.dataloader_transforms = kwargs.get('dataloader_transforms', None)
self.dataloader_transforms = kwargs.get("dataloader_transforms", None)
super().__init__(*args, **kwargs)
# self.caption_path: str = kwargs.get('caption_path', None)
self.raw_caption: str = kwargs.get('raw_caption', None)
self.raw_caption: str = kwargs.get("raw_caption", None)
# we scale first, then crop
self.scale_to_width: int = kwargs.get('scale_to_width', int(self.width * self.dataset_config.scale))
self.scale_to_height: int = kwargs.get('scale_to_height', int(self.height * self.dataset_config.scale))
self.scale_to_width: int = kwargs.get(
"scale_to_width", int(self.width * self.dataset_config.scale)
)
self.scale_to_height: int = kwargs.get(
"scale_to_height", int(self.height * self.dataset_config.scale)
)
# crop values are from scaled size
self.crop_x: int = kwargs.get('crop_x', 0)
self.crop_y: int = kwargs.get('crop_y', 0)
self.crop_width: int = kwargs.get('crop_width', self.scale_to_width)
self.crop_height: int = kwargs.get('crop_height', self.scale_to_height)
self.flip_x: bool = kwargs.get('flip_x', False)
self.flip_y: bool = kwargs.get('flip_x', False)
self.crop_x: int = kwargs.get("crop_x", 0)
self.crop_y: int = kwargs.get("crop_y", 0)
self.crop_width: int = kwargs.get("crop_width", self.scale_to_width)
self.crop_height: int = kwargs.get("crop_height", self.scale_to_height)
self.flip_x: bool = kwargs.get("flip_x", False)
self.flip_y: bool = kwargs.get("flip_x", False)
self.augments: List[str] = self.dataset_config.augments
self.loss_multiplier: float = self.dataset_config.loss_multiplier
@@ -142,9 +161,8 @@ class FileItemDTO(
class DataLoaderBatchDTO:
def __init__(self, **kwargs):
try:
self.file_items: List['FileItemDTO'] = kwargs.get('file_items', None)
self.file_items: List["FileItemDTO"] = kwargs.get("file_items", None)
is_latents_cached = self.file_items[0].is_latent_cached
is_text_embedding_cached = self.file_items[0].is_text_embedding_cached
self.tensor: Union[torch.Tensor, None] = None
self.latents: Union[torch.Tensor, None] = None
self.control_tensor: Union[torch.Tensor, None] = None
@@ -156,22 +174,64 @@ class DataLoaderBatchDTO:
self.unconditional_latents: Union[torch.Tensor, None] = None
self.clip_image_embeds: Union[List[dict], None] = None
self.clip_image_embeds_unconditional: Union[List[dict], None] = None
self.sigmas: Union[torch.Tensor, None] = None # can be added elsewhere and passed along by training code
self.extra_values: Union[torch.Tensor, None] = torch.tensor([x.extra_values for x in self.file_items]) if len(self.file_items[0].extra_values) > 0 else None
self.audio_data: Union[List, None] = [x.audio_data for x in self.file_items] if self.file_items[0].audio_data is not None else None
self.sigmas: Union[torch.Tensor, None] = (
None  # can be added elsewhere and passed along by training code
)
self.extra_values: Union[torch.Tensor, None] = (
torch.tensor([x.extra_values for x in self.file_items])
if len(self.file_items[0].extra_values) > 0
else None
)
self.audio_data: Union[List, None] = (
[x.audio_data for x in self.file_items]
if self.file_items[0].audio_data is not None
else None
)
self.audio_tensor: Union[torch.Tensor, None] = None
self.first_frame_latents: Union[torch.Tensor, None] = None
self.audio_latents: Union[torch.Tensor, None] = None
# just for holding noise and preds during training
self.audio_target: Union[torch.Tensor, None] = None
self.audio_pred: Union[torch.Tensor, None] = None
if not is_latents_cached:
# only return a tensor if latents are not cached
self.tensor: torch.Tensor = torch.cat([x.tensor.unsqueeze(0) for x in self.file_items])
self.tensor: torch.Tensor = torch.cat(
[x.tensor.unsqueeze(0) for x in self.file_items]
)
# if we have encoded latents, we concatenate them
self.latents: Union[torch.Tensor, None] = None
if is_latents_cached:
self.latents = torch.cat([x.get_latent().unsqueeze(0) for x in self.file_items])
# this get_latent call will trigger loading all cached items from disk
self.latents = torch.cat(
[x.get_latent().unsqueeze(0) for x in self.file_items]
)
if any(
[x._cached_first_frame_latent is not None for x in self.file_items]
):
# find one to use as a zero template; item 0 may not have one
base_first_frame_latent = next(
x._cached_first_frame_latent
for x in self.file_items
if x._cached_first_frame_latent is not None
)
self.first_frame_latents = torch.cat(
[
x._cached_first_frame_latent.unsqueeze(0)
if x._cached_first_frame_latent is not None
else torch.zeros_like(base_first_frame_latent).unsqueeze(0)
for x in self.file_items
]
)
if any([x._cached_audio_latent is not None for x in self.file_items]):
# same zero-fill rule for audio latents
base_audio_latent = next(
x._cached_audio_latent
for x in self.file_items
if x._cached_audio_latent is not None
)
self.audio_latents = torch.cat(
[
x._cached_audio_latent.unsqueeze(0)
if x._cached_audio_latent is not None
else torch.zeros_like(base_audio_latent).unsqueeze(0)
for x in self.file_items
]
)
self.prompt_embeds: Union[PromptEmbeds, None] = None
# if self.file_items[0].control_tensor is not None:
# if any have a control tensor, we concatenate them
@@ -188,8 +248,10 @@ class DataLoaderBatchDTO:
control_tensors.append(torch.zeros_like(base_control_tensor))
else:
control_tensors.append(x.control_tensor)
self.control_tensor = torch.cat([x.unsqueeze(0) for x in control_tensors])
self.control_tensor = torch.cat(
[x.unsqueeze(0) for x in control_tensors]
)
# handle control tensor list
if any([x.control_tensor_list is not None for x in self.file_items]):
self.control_tensor_list = []
@@ -197,9 +259,10 @@ class DataLoaderBatchDTO:
if x.control_tensor_list is not None:
self.control_tensor_list.append(x.control_tensor_list)
else:
raise Exception(f"Could not find control tensors for all file items, missing for {x.path}")
raise Exception(
f"Could not find control tensors for all file items, missing for {x.path}"
)
self.inpaint_tensor: Union[torch.Tensor, None] = None
if any([x.inpaint_tensor is not None for x in self.file_items]):
# find one to use as a base
@@ -214,9 +277,13 @@ class DataLoaderBatchDTO:
inpaint_tensors.append(torch.zeros_like(base_inpaint_tensor))
else:
inpaint_tensors.append(x.inpaint_tensor)
self.inpaint_tensor = torch.cat([x.unsqueeze(0) for x in inpaint_tensors])
self.inpaint_tensor = torch.cat(
[x.unsqueeze(0) for x in inpaint_tensors]
)
self.loss_multiplier_list: List[float] = [x.loss_multiplier for x in self.file_items]
self.loss_multiplier_list: List[float] = [
x.loss_multiplier for x in self.file_items
]
if any([x.clip_image_tensor is not None for x in self.file_items]):
# find one to use as a base
@@ -228,10 +295,14 @@ class DataLoaderBatchDTO:
clip_image_tensors = []
for x in self.file_items:
if x.clip_image_tensor is None:
clip_image_tensors.append(torch.zeros_like(base_clip_image_tensor))
clip_image_tensors.append(
torch.zeros_like(base_clip_image_tensor)
)
else:
clip_image_tensors.append(x.clip_image_tensor)
self.clip_image_tensor = torch.cat([x.unsqueeze(0) for x in clip_image_tensors])
self.clip_image_tensor = torch.cat(
[x.unsqueeze(0) for x in clip_image_tensors]
)
if any([x.mask_tensor is not None for x in self.file_items]):
# find one to use as a base
@@ -259,10 +330,14 @@ class DataLoaderBatchDTO:
unaugmented_tensor = []
for x in self.file_items:
if x.unaugmented_tensor is None:
unaugmented_tensor.append(torch.zeros_like(base_unaugmented_tensor))
unaugmented_tensor.append(
torch.zeros_like(base_unaugmented_tensor)
)
else:
unaugmented_tensor.append(x.unaugmented_tensor)
self.unaugmented_tensor = torch.cat([x.unsqueeze(0) for x in unaugmented_tensor])
self.unaugmented_tensor = torch.cat(
[x.unsqueeze(0) for x in unaugmented_tensor]
)
# add unconditional tensors
if any([x.unconditional_tensor is not None for x in self.file_items]):
@@ -275,10 +350,14 @@ class DataLoaderBatchDTO:
unconditional_tensor = []
for x in self.file_items:
if x.unconditional_tensor is None:
unconditional_tensor.append(torch.zeros_like(base_unconditional_tensor))
unconditional_tensor.append(
torch.zeros_like(base_unconditional_tensor)
)
else:
unconditional_tensor.append(x.unconditional_tensor)
self.unconditional_tensor = torch.cat([x.unsqueeze(0) for x in unconditional_tensor])
self.unconditional_tensor = torch.cat(
[x.unsqueeze(0) for x in unconditional_tensor]
)
if any([x.clip_image_embeds is not None for x in self.file_items]):
self.clip_image_embeds = []
@@ -288,14 +367,20 @@ class DataLoaderBatchDTO:
else:
raise Exception("clip_image_embeds is None for some file items")
if any([x.clip_image_embeds_unconditional is not None for x in self.file_items]):
if any(
[x.clip_image_embeds_unconditional is not None for x in self.file_items]
):
self.clip_image_embeds_unconditional = []
for x in self.file_items:
if x.clip_image_embeds_unconditional is not None:
self.clip_image_embeds_unconditional.append(x.clip_image_embeds_unconditional)
self.clip_image_embeds_unconditional.append(
x.clip_image_embeds_unconditional
)
else:
raise Exception("clip_image_embeds_unconditional is None for some file items")
raise Exception(
"clip_image_embeds_unconditional is None for some file items"
)
if any([x.prompt_embeds is not None for x in self.file_items]):
# find one to use as a base
base_prompt_embeds = None
@@ -315,7 +400,7 @@ class DataLoaderBatchDTO:
y.text_embeds = [y.text_embeds]
prompt_embeds_list.append(y)
self.prompt_embeds = concat_prompt_embeds(prompt_embeds_list)
if any([x.audio_tensor is not None for x in self.file_items]):
# find one to use as a base
base_audio_tensor = None
@@ -330,7 +415,6 @@ class DataLoaderBatchDTO:
else:
audio_tensors.append(x.audio_tensor)
self.audio_tensor = torch.cat([x.unsqueeze(0) for x in audio_tensors])
except Exception as e:
print(e)
@@ -343,18 +427,12 @@ class DataLoaderBatchDTO:
return [x.network_weight for x in self.file_items]
def get_caption_list(
self,
trigger=None,
to_replace_list=None,
add_if_not_present=True
self, trigger=None, to_replace_list=None, add_if_not_present=True
):
return [x.caption for x in self.file_items]
def get_caption_short_list(
self,
trigger=None,
to_replace_list=None,
add_if_not_present=True
self, trigger=None, to_replace_list=None, add_if_not_present=True
):
return [x.caption_short for x in self.file_items]
@@ -366,11 +444,13 @@ class DataLoaderBatchDTO:
del self.audio_data
del self.audio_target
del self.audio_pred
del self.first_frame_latents
del self.audio_latents
for file_item in self.file_items:
file_item.cleanup()
@property
def dataset_config(self) -> 'DatasetConfig':
def dataset_config(self) -> "DatasetConfig":
if len(self.file_items) > 0:
return self.file_items[0].dataset_config
else:

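Downstream training code can read the new fields straight off the batch; items without a cached first-frame or audio latent come through zero-filled. A hedged usage sketch (the step function and variable names are hypothetical, not part of this commit):

def training_step(batch):  # batch: DataLoaderBatchDTO
    latents = batch.latents  # cached video/image latents, shape [B, ...]
    first_frame_cond = None
    if batch.first_frame_latents is not None:
        # first-frame latents for i2v conditioning; zeros where an item had none
        first_frame_cond = batch.first_frame_latents
    audio_cond = None
    if batch.audio_latents is not None:
        audio_cond = batch.audio_latents  # same zero-fill rule for audio
    return latents, first_frame_cond, audio_cond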
View File

@@ -456,8 +456,6 @@ class ImageProcessingDTOMixin:
transform: Union[None, transforms.Compose],
only_load_latents=False
):
if self.is_latent_cached:
raise Exception('Latent caching not supported for videos')
if self.augments is not None and len(self.augments) > 0:
raise Exception('Augments not supported for videos')
@@ -727,9 +725,6 @@ class ImageProcessingDTOMixin:
transform: Union[None, transforms.Compose],
only_load_latents=False
):
if self.dataset_config.num_frames > 1:
self.load_and_process_video(transform, only_load_latents)
return
# handle get_prompt_embedding
if self.is_text_embedding_cached:
self.load_prompt_embedding()
@@ -747,6 +742,9 @@ class ImageProcessingDTOMixin:
if self.has_unconditional:
self.load_unconditional_image()
return
if self.dataset_config.num_frames > 1:
self.load_and_process_video(transform, only_load_latents)
return
try:
img = Image.open(self.path)
img = exif_transpose(img)
@@ -1716,6 +1714,8 @@ class LatentCachingFileItemDTOMixin:
if hasattr(super(), '__init__'):
super().__init__(*args, **kwargs)
self._encoded_latent: Union[torch.Tensor, None] = None
self._cached_first_frame_latent: Union[torch.Tensor, None] = None
self._cached_audio_latent: Union[torch.Tensor, None] = None
self._latent_path: Union[str, None] = None
self.is_latent_cached = False
self.is_caching_to_disk = False
@@ -1745,6 +1745,14 @@ class LatentCachingFileItemDTOMixin:
item["flip_y"] = True
if self.dataset_config.num_frames > 1:
item["num_frames"] = self.dataset_config.num_frames
if self.dataset_config.do_i2v:
item["do_i2v"] = True
if self.dataset_config.do_audio:
item["do_audio"] = True
if self.dataset_config.audio_normalize:
item["audio_normalize"] = True
if self.dataset_config.audio_preserve_pitch:
item["audio_preserve_pitch"] = True
return item
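These flags flow into the latent info dict that is written out as safetensors metadata, so caches made with different i2v/audio settings can be told apart (an inference from the metadata write and the recalculate parameter below). Illustrative output, with made-up values:

item = file_item.get_latent_info_dict()
# e.g. {"num_frames": 49, "do_i2v": True, "do_audio": True, "audio_normalize": True}
# actual keys depend on the DatasetConfig; absent flags are simply omitted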
def get_latent_path(self: 'FileItemDTO', recalculate=False):
@@ -1769,9 +1777,15 @@ class LatentCachingFileItemDTOMixin:
if not self.is_caching_to_memory:
# we are caching on disk, don't save in memory
self._encoded_latent = None
self._cached_first_frame_latent = None
self._cached_audio_latent = None
else:
# move it back to cpu
self._encoded_latent = self._encoded_latent.to('cpu')
if self._cached_first_frame_latent is not None:
self._cached_first_frame_latent = self._cached_first_frame_latent.to('cpu')
if self._cached_audio_latent is not None:
self._cached_audio_latent = self._cached_audio_latent.to('cpu')
def get_latent(self, device=None):
if not self.is_latent_cached:
@@ -1784,6 +1798,10 @@ class LatentCachingFileItemDTOMixin:
device='cpu'
)
self._encoded_latent = state_dict['latent']
if 'first_frame_latent' in state_dict:
self._cached_first_frame_latent = state_dict['first_frame_latent']
if 'audio_latent' in state_dict:
self._cached_audio_latent = state_dict['audio_latent']
return self._encoded_latent
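Note that a single get_latent() call now also hydrates the companion latents when they were cached to disk. Assumed usage:

latent = file_item.get_latent(device='cpu')
first_frame = file_item._cached_first_frame_latent  # None unless cached with do_i2v
audio = file_item._cached_audio_latent  # None unless audio was cached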
@@ -1795,8 +1813,6 @@ class LatentCachingMixin:
self.latent_cache = {}
def cache_latents_all_latents(self: 'AiToolkitDataset'):
if self.dataset_config.num_frames > 1:
raise Exception("Error: caching latents is not supported for multi-frame datasets")
with accelerator.main_process_first():
print_acc(f"Caching latents for {self.dataset_path}")
# cache all latents to disk
@@ -1839,25 +1855,50 @@ class LatentCachingMixin:
# load it into memory
state_dict = load_file(latent_path, device='cpu')
file_item._encoded_latent = state_dict['latent'].to('cpu', dtype=self.sd.torch_dtype)
if 'first_frame_latent' in state_dict:
file_item._cached_first_frame_latent = state_dict['first_frame_latent'].to('cpu', dtype=self.sd.torch_dtype)
if 'audio_latent' in state_dict:
file_item._cached_audio_latent = state_dict['audio_latent'].to('cpu', dtype=self.sd.torch_dtype)
else:
# not saved to disk, calculate
# load the image first
file_item.load_and_process_image(self.transform, only_load_latents=True)
dtype = self.sd.torch_dtype
device = self.sd.device_torch
state_dict = OrderedDict()
first_frame_latent = None
audio_latent = None
# add batch dimension
try:
imgs = file_item.tensor.unsqueeze(0).to(device, dtype=dtype)
latent = self.sd.encode_images(imgs).squeeze(0)
if to_disk:
state_dict['latent'] = latent.clone().detach().cpu()
except Exception as e:
print_acc(f"Error processing image: {file_item.path}")
print_acc(f"Error: {str(e)}")
raise e
# do first frame
if self.dataset_config.num_frames > 1 and self.dataset_config.do_i2v:
frames = file_item.tensor.unsqueeze(0).to(device, dtype=dtype)
if len(frames.shape) == 4:
first_frames = frames
elif len(frames.shape) == 5:
first_frames = frames[:, 0]
else:
raise ValueError(f"Unknown frame shape {frames.shape}")
first_frame_latent = self.sd.encode_images(first_frames).squeeze(0)
if to_disk:
state_dict['first_frame_latent'] = first_frame_latent.clone().detach().cpu()
# audio
if file_item.audio_data is not None:
audio_latent = self.sd.encode_audio([file_item.audio_data]).squeeze(0)
if to_disk:
state_dict['audio_latent'] = audio_latent.clone().detach().cpu()
# save_latent
if to_disk:
state_dict = OrderedDict([
('latent', latent.clone().detach().cpu()),
])
# metadata
meta = get_meta_for_safetensors(file_item.get_latent_info_dict())
os.makedirs(os.path.dirname(latent_path), exist_ok=True)
@@ -1866,17 +1907,18 @@ class LatentCachingMixin:
if to_memory:
# keep it in memory
file_item._encoded_latent = latent.to('cpu', dtype=self.sd.torch_dtype)
if first_frame_latent is not None:
file_item._cached_first_frame_latent = first_frame_latent.to('cpu', dtype=self.sd.torch_dtype)
if audio_latent is not None:
file_item._cached_audio_latent = audio_latent.to('cpu', dtype=self.sd.torch_dtype)
del imgs
del latent
del file_item.tensor
file_item.cleanup()
# flush(garbage_collect=False)
file_item.is_latent_cached = True
i += 1
# flush every 100
# if i % 100 == 0:
# flush()
# restore device state
self.sd.restore_device_state()
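The first-frame branch above indexes frames[:, 0], which implies a [B, F, C, H, W] layout for batched video tensors. A standalone sketch of the same selection logic, assuming that layout:

import torch

def select_first_frames(frames: torch.Tensor) -> torch.Tensor:
    # frames: [B, C, H, W] for single images, or [B, F, C, H, W] for video clips
    if frames.dim() == 4:
        return frames  # already one frame per item
    if frames.dim() == 5:
        return frames[:, 0]  # first frame of each clip -> [B, C, H, W]
    raise ValueError(f"Unknown frame shape {tuple(frames.shape)}")

# e.g. select_first_frames(torch.randn(2, 49, 3, 512, 512)).shape == (2, 3, 512, 512)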

View File

@@ -1122,6 +1122,10 @@ class BaseModel:
latents = latents.to(device, dtype=dtype)
return latents
def encode_audio(self, audio_data_list):
# audio_data_list is a list of {"waveform": Tensor[C, L], "sample_rate": int}
raise NotImplementedError("Audio encoding not implemented for this model.")
def decode_latents(
self,

View File

@@ -2550,6 +2550,10 @@ class StableDiffusion:
latents = latents.to(device, dtype=dtype)
return latents
def encode_audio(self, audio_data_list):
# audio_data_list is a list of {"waveform": Tensor[C, L], "sample_rate": int}
raise NotImplementedError("Audio encoding not implemented for this model.")
def decode_latents(
self,