Cache latents, the first frame, and audio when caching latents for LTX2

This commit is contained in:
Jaret Burkett
2026-01-14 11:05:23 -07:00
parent 64fe29b182
commit 73dedbf662
6 changed files with 324 additions and 127 deletions

View File

@@ -98,6 +98,41 @@ def blank_log_image_function(self, *args, **kwargs):
    return


+class ComboVae(torch.nn.Module):
+    """Combines video and audio VAEs for joint encoding and decoding."""
+
+    def __init__(
+        self,
+        vae: AutoencoderKLLTX2Video,
+        audio_vae: AutoencoderKLLTX2Audio,
+    ) -> None:
+        super().__init__()
+        self.vae = vae
+        self.audio_vae = audio_vae
+
+    @property
+    def device(self):
+        return self.vae.device
+
+    @property
+    def dtype(self):
+        return self.vae.dtype
+
+    def encode(
+        self,
+        *args,
+        **kwargs,
+    ):
+        return self.vae.encode(*args, **kwargs)
+
+    def decode(
+        self,
+        *args,
+        **kwargs,
+    ):
+        return self.vae.decode(*args, **kwargs)
+
+
class AudioProcessor(torch.nn.Module):
    """Converts audio waveforms to log-mel spectrograms with optional resampling."""
@@ -445,7 +480,7 @@ class LTX2Model(BaseModel):
        flush()
        # save it to the model class
-        self.vae = vae
+        self.vae = ComboVae(pipe.vae, pipe.audio_vae)
        self.text_encoder = text_encoder  # list of text encoders
        self.tokenizer = tokenizer  # list of tokenizers
        self.model = pipe.transformer
@@ -467,10 +502,10 @@ class LTX2Model(BaseModel):
        if dtype is None:
            dtype = self.vae_torch_dtype
-        if self.vae.device == torch.device("cpu"):
-            self.vae.to(device)
-        self.vae.eval()
-        self.vae.requires_grad_(False)
+        if self.pipeline.vae.device == torch.device("cpu"):
+            self.pipeline.vae.to(device)
+        self.pipeline.vae.eval()
+        self.pipeline.vae.requires_grad_(False)
        image_list = [image.to(device, dtype=dtype) for image in image_list]
@@ -489,7 +524,7 @@ class LTX2Model(BaseModel):
        # Stack to (B, C, T, H, W)
        images = torch.stack(norm_images)
-        latents = self.vae.encode(images).latent_dist.mode()
+        latents = self.pipeline.vae.encode(images).latent_dist.mode()
        # Normalize latents across the channel dimension [B, C, F, H, W]
        scaling_factor = 1.0
@@ -571,7 +606,9 @@ class LTX2Model(BaseModel):
        if gen_config.ctrl_img is not None:
            control_img = Image.open(gen_config.ctrl_img).convert("RGB")
            # resize the control image
-            control_img = control_img.resize((gen_config.width, gen_config.height), Image.LANCZOS)
+            control_img = control_img.resize(
+                (gen_config.width, gen_config.height), Image.LANCZOS
+            )
            # add the control image to the extra dict
            extra["image"] = control_img
@@ -642,7 +679,8 @@ class LTX2Model(BaseModel):
        img = video[0]
        return img

-    def encode_audio(self, batch: "DataLoaderBatchDTO"):
+    def encode_audio(self, audio_data_list):
+        # audio_data_list is a list of {"waveform": waveform[C, L], "sample_rate": int(sample_rate)}
        if self.pipeline.audio_vae.device == torch.device("cpu"):
            self.pipeline.audio_vae.to(self.device_torch)
@@ -650,7 +688,7 @@ class LTX2Model(BaseModel):
        audio_num_frames = None
        # do them separately for now
-        for audio_data in batch.audio_data:
+        for audio_data in audio_data_list:
            waveform = audio_data["waveform"].to(
                device=self.device_torch, dtype=torch.float32
            )
@@ -665,20 +703,24 @@ class LTX2Model(BaseModel):
                waveform = waveform.repeat(1, 2, 1)

            # Convert waveform to mel spectrogram using AudioProcessor
-            mel_spectrogram = self.audio_processor.waveform_to_mel(waveform, waveform_sample_rate=sample_rate)
+            mel_spectrogram = self.audio_processor.waveform_to_mel(
+                waveform, waveform_sample_rate=sample_rate
+            )
            mel_spectrogram = mel_spectrogram.to(dtype=self.torch_dtype)

            # Encode mel spectrogram to latents
-            latents = self.pipeline.audio_vae.encode(mel_spectrogram.to(self.device_torch, dtype=self.torch_dtype)).latent_dist.mode()
+            latents = self.pipeline.audio_vae.encode(
+                mel_spectrogram.to(self.device_torch, dtype=self.torch_dtype)
+            ).latent_dist.mode()

            if audio_num_frames is None:
-                audio_num_frames = latents.shape[2] #(latents is [B, C, T, F])
+                audio_num_frames = latents.shape[2]  # (latents is [B, C, T, F])

            packed_latents = self.pipeline._pack_audio_latents(
                latents,
                # patch_size=self.pipeline.transformer.config.audio_patch_size,
                # patch_size_t=self.pipeline.transformer.config.audio_patch_size_t,
            )  # [B, L, C * M]

            if output_tensor is None:
                output_tensor = packed_latents
            else:
@@ -688,7 +730,7 @@ class LTX2Model(BaseModel):
        latents_mean = self.pipeline.audio_vae.latents_mean
        latents_std = self.pipeline.audio_vae.latents_std
        output_tensor = (output_tensor - latents_mean) / latents_std
-        return output_tensor, audio_num_frames
+        return output_tensor

    def get_noise_prediction(
        self,
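As a quick reference for the new encode_audio() signature, a hedged sketch of what a caller passes in. The dict layout follows the comment in the hunk above; the sample rate, duration, and waveform shape are placeholder assumptions, and ltx2_model stands in for an LTX2Model instance.

    import torch

    # Hypothetical input for encode_audio(); assumed [B, C, L] waveform layout,
    # mono audio is repeated to stereo inside the method.
    audio_data_list = [
        {"waveform": torch.randn(1, 1, 48_000 * 5), "sample_rate": 48_000},
    ]
    packed_audio_latents = ltx2_model.encode_audio(audio_data_list)
    # -> packed, mean/std-normalized audio latents of shape [B, L, C * M] per the hunk above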
@@ -710,21 +752,40 @@ class LTX2Model(BaseModel):
        # i2v from first frame
        if batch.dataset_config.do_i2v:
-            # videos come in (bs, num_frames, channels, height, width)
-            # images come in (bs, channels, height, width)
-            frames = batch.tensor
-            if len(frames.shape) == 4:
-                first_frames = frames
-            elif len(frames.shape) == 5:
-                first_frames = frames[:, 0]
-            else:
-                raise ValueError(f"Unknown frame shape {frames.shape}")
-            # first frame doesnt have time dim, add it back
-            init_latents = self.encode_images(first_frames, device=self.device_torch, dtype=self.torch_dtype)
+            # check to see if we had it cached
+            if batch.first_frame_latents is not None:
+                init_latents = batch.first_frame_latents.to(
+                    self.device_torch, dtype=self.torch_dtype
+                )
+            else:
+                # extract the first frame and encode it
+                # videos come in (bs, num_frames, channels, height, width)
+                # images come in (bs, channels, height, width)
+                frames = batch.tensor
+                if len(frames.shape) == 4:
+                    first_frames = frames
+                elif len(frames.shape) == 5:
+                    first_frames = frames[:, 0]
+                else:
+                    raise ValueError(f"Unknown frame shape {frames.shape}")
+                # first frame doesn't have time dim, add it back
+                init_latents = self.encode_images(
+                    first_frames, device=self.device_torch, dtype=self.torch_dtype
+                )
+            # expand the latents to match video frames
            init_latents = init_latents.repeat(1, 1, latent_num_frames, 1, 1)

-            mask_shape = (batch_size, 1, latent_num_frames, latent_height, latent_width)
+            mask_shape = (
+                batch_size,
+                1,
+                latent_num_frames,
+                latent_height,
+                latent_width,
+            )
            # First condition is image latents and those should be kept clean.
-            conditioning_mask = torch.zeros(mask_shape, device=self.device_torch, dtype=self.torch_dtype)
+            conditioning_mask = torch.zeros(
+                mask_shape, device=self.device_torch, dtype=self.torch_dtype
+            )
            conditioning_mask[:, :, 0] = 1.0

            # use conditioning mask to replace latents
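The hunk ends just before the mask is applied. As a generic illustration of the pattern the last comment refers to, and not the exact code in this file, a [B, 1, T, H, W] mask like this is typically used to keep the conditioning frame's latents clean while every other frame keeps its noise:

    # Illustrative only: one common way such a conditioning mask is applied.
    # `noisy_latents` and `init_latents` are assumed to be [B, C, T, H, W].
    latents = conditioning_mask * init_latents + (1.0 - conditioning_mask) * noisy_latents
    # frame 0 stays the clean first-frame latent; the remaining frames stay noised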
@@ -746,10 +807,18 @@ class LTX2Model(BaseModel):
            patch_size_t=self.pipeline.transformer_temporal_patch_size,
        )

-        if batch.audio_tensor is not None:
-            # use audio from the batch if available
-            #(1, 190, 128)
-            raw_audio_latents, audio_num_frames = self.encode_audio(batch)
+        if batch.audio_latents is not None or batch.audio_tensor is not None:
+            if batch.audio_latents is not None:
+                # we have audio latents cached
+                raw_audio_latents = batch.audio_latents.to(
+                    self.device_torch, dtype=self.torch_dtype
+                )
+            else:
+                # we have audio waveforms to encode
+                # use audio from the batch if available
+                raw_audio_latents = self.encode_audio(batch.audio_data)
+            audio_num_frames = raw_audio_latents.shape[1]

            # add the audio targets to the batch for loss calculation later
            audio_noise = torch.randn_like(raw_audio_latents)
            batch.audio_target = (audio_noise - raw_audio_latents).detach()
@@ -758,7 +827,6 @@ class LTX2Model(BaseModel):
                audio_noise,
                timestep,
            ).to(self.device_torch, dtype=self.torch_dtype)
        else:
            # no audio
            num_mel_bins = self.pipeline.audio_vae.config.mel_bins
@@ -766,7 +834,6 @@ class LTX2Model(BaseModel):
            num_channels_latents_audio = (
                self.pipeline.audio_vae.config.latent_channels
            )
-            # audio latents are (1, 126, 128), audio_num_frames = 126
            audio_latents, audio_num_frames = self.pipeline.prepare_audio_latents(
                batch_size,
                num_channels_latents=num_channels_latents_audio,

View File

@@ -1,24 +1,31 @@
import os
-import weakref
-from _weakref import ReferenceType
from typing import TYPE_CHECKING, List, Union

import cv2
import torch
-import random
from PIL import Image
from PIL.ImageOps import exif_transpose

from toolkit import image_utils
from toolkit.basic import get_quick_signature_string
-from toolkit.dataloader_mixins import CaptionProcessingDTOMixin, ImageProcessingDTOMixin, LatentCachingFileItemDTOMixin, \
-    ControlFileItemDTOMixin, ArgBreakMixin, PoiFileItemDTOMixin, MaskFileItemDTOMixin, AugmentationFileItemDTOMixin, \
-    UnconditionalFileItemDTOMixin, ClipImageFileItemDTOMixin, InpaintControlFileItemDTOMixin, TextEmbeddingFileItemDTOMixin
+from toolkit.dataloader_mixins import (
+    CaptionProcessingDTOMixin,
+    ImageProcessingDTOMixin,
+    LatentCachingFileItemDTOMixin,
+    ControlFileItemDTOMixin,
+    ArgBreakMixin,
+    PoiFileItemDTOMixin,
+    MaskFileItemDTOMixin,
+    AugmentationFileItemDTOMixin,
+    UnconditionalFileItemDTOMixin,
+    ClipImageFileItemDTOMixin,
+    InpaintControlFileItemDTOMixin,
+    TextEmbeddingFileItemDTOMixin,
+)
from toolkit.prompt_utils import PromptEmbeds, concat_prompt_embeds

if TYPE_CHECKING:
    from toolkit.config_modules import DatasetConfig
-    from toolkit.stable_diffusion_model import StableDiffusion

printed_messages = []
@@ -45,15 +52,17 @@ class FileItemDTO(
    ArgBreakMixin,
):
    def __init__(self, *args, **kwargs):
-        self.path = kwargs.get('path', '')
-        self.dataset_config: 'DatasetConfig' = kwargs.get('dataset_config', None)
+        self.path = kwargs.get("path", "")
+        self.dataset_config: "DatasetConfig" = kwargs.get("dataset_config", None)
        self.is_video = self.dataset_config.num_frames > 1
-        size_database = kwargs.get('size_database', {})
-        dataset_root = kwargs.get('dataset_root', None)
-        self.encode_control_in_text_embeddings = kwargs.get('encode_control_in_text_embeddings', False)
+        size_database = kwargs.get("size_database", {})
+        dataset_root = kwargs.get("dataset_root", None)
+        self.encode_control_in_text_embeddings = kwargs.get(
+            "encode_control_in_text_embeddings", False
+        )
        if dataset_root is not None:
            # remove dataset root from path
-            file_key = self.path.replace(dataset_root, '')
+            file_key = self.path.replace(dataset_root, "")
        else:
            file_key = os.path.basename(self.path)
@@ -64,7 +73,11 @@ class FileItemDTO(
        use_db_entry = False
        if file_key in size_database:
            db_entry = size_database[file_key]
-            if db_entry is not None and len(db_entry) >= 3 and db_entry[2] == file_signature:
+            if (
+                db_entry is not None
+                and len(db_entry) >= 3
+                and db_entry[2] == file_signature
+            ):
                use_db_entry = True

        if use_db_entry:
@@ -87,12 +100,14 @@ class FileItemDTO(
            size_database[file_key] = (width, height, file_signature)
        else:
            if self.dataset_config.fast_image_size:
                # original method is significantly faster, but some images are read sideways. Not sure why. Do slow method by default.
                try:
                    w, h = image_utils.get_image_size(self.path)
                except image_utils.UnknownImageFormat:
-                    print_once(f'Warning: Some images in the dataset cannot be fast read. ' + \
-                               f'This process is faster for png, jpeg')
+                    print_once(
+                        f"Warning: Some images in the dataset cannot be fast read. "
+                        + f"This process is faster for png, jpeg"
+                    )
                    img = exif_transpose(Image.open(self.path))
                    w, h = img.size
            else:
@@ -101,21 +116,25 @@ class FileItemDTO(
            size_database[file_key] = (w, h, file_signature)
        self.width: int = w
        self.height: int = h
-        self.dataloader_transforms = kwargs.get('dataloader_transforms', None)
+        self.dataloader_transforms = kwargs.get("dataloader_transforms", None)

        super().__init__(*args, **kwargs)
        # self.caption_path: str = kwargs.get('caption_path', None)
-        self.raw_caption: str = kwargs.get('raw_caption', None)
+        self.raw_caption: str = kwargs.get("raw_caption", None)

        # we scale first, then crop
-        self.scale_to_width: int = kwargs.get('scale_to_width', int(self.width * self.dataset_config.scale))
-        self.scale_to_height: int = kwargs.get('scale_to_height', int(self.height * self.dataset_config.scale))
+        self.scale_to_width: int = kwargs.get(
+            "scale_to_width", int(self.width * self.dataset_config.scale)
+        )
+        self.scale_to_height: int = kwargs.get(
+            "scale_to_height", int(self.height * self.dataset_config.scale)
+        )

        # crop values are from scaled size
-        self.crop_x: int = kwargs.get('crop_x', 0)
-        self.crop_y: int = kwargs.get('crop_y', 0)
-        self.crop_width: int = kwargs.get('crop_width', self.scale_to_width)
-        self.crop_height: int = kwargs.get('crop_height', self.scale_to_height)
-        self.flip_x: bool = kwargs.get('flip_x', False)
-        self.flip_y: bool = kwargs.get('flip_x', False)
+        self.crop_x: int = kwargs.get("crop_x", 0)
+        self.crop_y: int = kwargs.get("crop_y", 0)
+        self.crop_width: int = kwargs.get("crop_width", self.scale_to_width)
+        self.crop_height: int = kwargs.get("crop_height", self.scale_to_height)
+        self.flip_x: bool = kwargs.get("flip_x", False)
+        self.flip_y: bool = kwargs.get("flip_x", False)

        self.augments: List[str] = self.dataset_config.augments
        self.loss_multiplier: float = self.dataset_config.loss_multiplier
@@ -142,9 +161,8 @@ class FileItemDTO(
class DataLoaderBatchDTO:
    def __init__(self, **kwargs):
        try:
-            self.file_items: List['FileItemDTO'] = kwargs.get('file_items', None)
+            self.file_items: List["FileItemDTO"] = kwargs.get("file_items", None)
            is_latents_cached = self.file_items[0].is_latent_cached
-            is_text_embedding_cached = self.file_items[0].is_text_embedding_cached
            self.tensor: Union[torch.Tensor, None] = None
            self.latents: Union[torch.Tensor, None] = None
            self.control_tensor: Union[torch.Tensor, None] = None
@@ -156,10 +174,22 @@ class DataLoaderBatchDTO:
            self.unconditional_latents: Union[torch.Tensor, None] = None
            self.clip_image_embeds: Union[List[dict], None] = None
            self.clip_image_embeds_unconditional: Union[List[dict], None] = None
-            self.sigmas: Union[torch.Tensor, None] = None # can be added elseware and passed along training code
-            self.extra_values: Union[torch.Tensor, None] = torch.tensor([x.extra_values for x in self.file_items]) if len(self.file_items[0].extra_values) > 0 else None
-            self.audio_data: Union[List, None] = [x.audio_data for x in self.file_items] if self.file_items[0].audio_data is not None else None
+            self.sigmas: Union[torch.Tensor, None] = (
+                None  # can be added elsewhere and passed along training code
+            )
+            self.extra_values: Union[torch.Tensor, None] = (
+                torch.tensor([x.extra_values for x in self.file_items])
+                if len(self.file_items[0].extra_values) > 0
+                else None
+            )
+            self.audio_data: Union[List, None] = (
+                [x.audio_data for x in self.file_items]
+                if self.file_items[0].audio_data is not None
+                else None
+            )
            self.audio_tensor: Union[torch.Tensor, None] = None
+            self.first_frame_latents: Union[torch.Tensor, None] = None
+            self.audio_latents: Union[torch.Tensor, None] = None
            # just for holding noise and preds during training
            self.audio_target: Union[torch.Tensor, None] = None
@@ -167,11 +197,41 @@ class DataLoaderBatchDTO:
            if not is_latents_cached:
                # only return a tensor if latents are not cached
-                self.tensor: torch.Tensor = torch.cat([x.tensor.unsqueeze(0) for x in self.file_items])
+                self.tensor: torch.Tensor = torch.cat(
+                    [x.tensor.unsqueeze(0) for x in self.file_items]
+                )

            # if we have encoded latents, we concatenate them
            self.latents: Union[torch.Tensor, None] = None
            if is_latents_cached:
-                self.latents = torch.cat([x.get_latent().unsqueeze(0) for x in self.file_items])
+                # this get_latent call will trigger loading all cached items from the disk
+                self.latents = torch.cat(
+                    [x.get_latent().unsqueeze(0) for x in self.file_items]
+                )
+                if any(
+                    [x._cached_first_frame_latent is not None for x in self.file_items]
+                ):
+                    self.first_frame_latents = torch.cat(
+                        [
+                            x._cached_first_frame_latent.unsqueeze(0)
+                            if x._cached_first_frame_latent is not None
+                            else torch.zeros_like(
+                                self.file_items[0]._cached_first_frame_latent
+                            ).unsqueeze(0)
+                            for x in self.file_items
+                        ]
+                    )
+                if any([x._cached_audio_latent is not None for x in self.file_items]):
+                    self.audio_latents = torch.cat(
+                        [
+                            x._cached_audio_latent.unsqueeze(0)
+                            if x._cached_audio_latent is not None
+                            else torch.zeros_like(
+                                self.file_items[0]._cached_audio_latent
+                            ).unsqueeze(0)
+                            for x in self.file_items
+                        ]
+                    )

            self.prompt_embeds: Union[PromptEmbeds, None] = None
            # if self.file_items[0].control_tensor is not None:
            # if any have a control tensor, we concatenate them
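One thing worth flagging in the fallback above: when an item has no cached first-frame or audio latent, the zero placeholder is built from self.file_items[0]._cached_..., which assumes the first item in the batch does have one. A slightly more defensive version of the same collation pattern (illustrative, with an assumed helper name; not what the commit ships) would pick any non-None latent as the template:

    # Illustrative sketch: same collation idea, but the zeros template comes from
    # whichever file item actually has a cached latent.
    def collate_optional_latents(file_items, attr: str):
        values = [getattr(x, attr) for x in file_items]
        template = next((v for v in values if v is not None), None)
        if template is None:
            return None  # nothing cached for this key in the whole batch
        return torch.cat(
            [(v if v is not None else torch.zeros_like(template)).unsqueeze(0) for v in values]
        )

    # e.g. first_frame_latents = collate_optional_latents(self.file_items, "_cached_first_frame_latent")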
@@ -188,7 +248,9 @@ class DataLoaderBatchDTO:
                        control_tensors.append(torch.zeros_like(base_control_tensor))
                    else:
                        control_tensors.append(x.control_tensor)
-                self.control_tensor = torch.cat([x.unsqueeze(0) for x in control_tensors])
+                self.control_tensor = torch.cat(
+                    [x.unsqueeze(0) for x in control_tensors]
+                )

            # handle control tensor list
            if any([x.control_tensor_list is not None for x in self.file_items]):
@@ -197,8 +259,9 @@ class DataLoaderBatchDTO:
                    if x.control_tensor_list is not None:
                        self.control_tensor_list.append(x.control_tensor_list)
                    else:
-                        raise Exception(f"Could not find control tensors for all file items, missing for {x.path}")
+                        raise Exception(
+                            f"Could not find control tensors for all file items, missing for {x.path}"
+                        )

            self.inpaint_tensor: Union[torch.Tensor, None] = None
            if any([x.inpaint_tensor is not None for x in self.file_items]):
@@ -214,9 +277,13 @@ class DataLoaderBatchDTO:
                        inpaint_tensors.append(torch.zeros_like(base_inpaint_tensor))
                    else:
                        inpaint_tensors.append(x.inpaint_tensor)
-                self.inpaint_tensor = torch.cat([x.unsqueeze(0) for x in inpaint_tensors])
+                self.inpaint_tensor = torch.cat(
+                    [x.unsqueeze(0) for x in inpaint_tensors]
+                )

-            self.loss_multiplier_list: List[float] = [x.loss_multiplier for x in self.file_items]
+            self.loss_multiplier_list: List[float] = [
+                x.loss_multiplier for x in self.file_items
+            ]

            if any([x.clip_image_tensor is not None for x in self.file_items]):
                # find one to use as a base
@@ -228,10 +295,14 @@ class DataLoaderBatchDTO:
                clip_image_tensors = []
                for x in self.file_items:
                    if x.clip_image_tensor is None:
-                        clip_image_tensors.append(torch.zeros_like(base_clip_image_tensor))
+                        clip_image_tensors.append(
+                            torch.zeros_like(base_clip_image_tensor)
+                        )
                    else:
                        clip_image_tensors.append(x.clip_image_tensor)
-                self.clip_image_tensor = torch.cat([x.unsqueeze(0) for x in clip_image_tensors])
+                self.clip_image_tensor = torch.cat(
+                    [x.unsqueeze(0) for x in clip_image_tensors]
+                )

            if any([x.mask_tensor is not None for x in self.file_items]):
                # find one to use as a base
@@ -259,10 +330,14 @@ class DataLoaderBatchDTO:
                unaugmented_tensor = []
                for x in self.file_items:
                    if x.unaugmented_tensor is None:
-                        unaugmented_tensor.append(torch.zeros_like(base_unaugmented_tensor))
+                        unaugmented_tensor.append(
+                            torch.zeros_like(base_unaugmented_tensor)
+                        )
                    else:
                        unaugmented_tensor.append(x.unaugmented_tensor)
-                self.unaugmented_tensor = torch.cat([x.unsqueeze(0) for x in unaugmented_tensor])
+                self.unaugmented_tensor = torch.cat(
+                    [x.unsqueeze(0) for x in unaugmented_tensor]
+                )

            # add unconditional tensors
            if any([x.unconditional_tensor is not None for x in self.file_items]):
@@ -275,10 +350,14 @@ class DataLoaderBatchDTO:
                unconditional_tensor = []
                for x in self.file_items:
                    if x.unconditional_tensor is None:
-                        unconditional_tensor.append(torch.zeros_like(base_unconditional_tensor))
+                        unconditional_tensor.append(
+                            torch.zeros_like(base_unconditional_tensor)
+                        )
                    else:
                        unconditional_tensor.append(x.unconditional_tensor)
-                self.unconditional_tensor = torch.cat([x.unsqueeze(0) for x in unconditional_tensor])
+                self.unconditional_tensor = torch.cat(
+                    [x.unsqueeze(0) for x in unconditional_tensor]
+                )

            if any([x.clip_image_embeds is not None for x in self.file_items]):
                self.clip_image_embeds = []
@@ -288,13 +367,19 @@ class DataLoaderBatchDTO:
                    else:
                        raise Exception("clip_image_embeds is None for some file items")

-            if any([x.clip_image_embeds_unconditional is not None for x in self.file_items]):
+            if any(
+                [x.clip_image_embeds_unconditional is not None for x in self.file_items]
+            ):
                self.clip_image_embeds_unconditional = []
                for x in self.file_items:
                    if x.clip_image_embeds_unconditional is not None:
-                        self.clip_image_embeds_unconditional.append(x.clip_image_embeds_unconditional)
+                        self.clip_image_embeds_unconditional.append(
+                            x.clip_image_embeds_unconditional
+                        )
                    else:
-                        raise Exception("clip_image_embeds_unconditional is None for some file items")
+                        raise Exception(
+                            "clip_image_embeds_unconditional is None for some file items"
+                        )

            if any([x.prompt_embeds is not None for x in self.file_items]):
                # find one to use as a base
@@ -331,7 +416,6 @@ class DataLoaderBatchDTO:
                        audio_tensors.append(x.audio_tensor)
                self.audio_tensor = torch.cat([x.unsqueeze(0) for x in audio_tensors])
        except Exception as e:
            print(e)
            raise e
@@ -343,18 +427,12 @@ class DataLoaderBatchDTO:
        return [x.network_weight for x in self.file_items]

    def get_caption_list(
-        self,
-        trigger=None,
-        to_replace_list=None,
-        add_if_not_present=True
+        self, trigger=None, to_replace_list=None, add_if_not_present=True
    ):
        return [x.caption for x in self.file_items]

    def get_caption_short_list(
-        self,
-        trigger=None,
-        to_replace_list=None,
-        add_if_not_present=True
+        self, trigger=None, to_replace_list=None, add_if_not_present=True
    ):
        return [x.caption_short for x in self.file_items]
@@ -366,11 +444,13 @@ class DataLoaderBatchDTO:
        del self.audio_data
        del self.audio_target
        del self.audio_pred
+        del self.first_frame_latents
+        del self.audio_latents
        for file_item in self.file_items:
            file_item.cleanup()

    @property
-    def dataset_config(self) -> 'DatasetConfig':
+    def dataset_config(self) -> "DatasetConfig":
        if len(self.file_items) > 0:
            return self.file_items[0].dataset_config
        else:

View File

@@ -456,8 +456,6 @@ class ImageProcessingDTOMixin:
        transform: Union[None, transforms.Compose],
        only_load_latents=False
    ):
-        if self.is_latent_cached:
-            raise Exception('Latent caching not supported for videos')
        if self.augments is not None and len(self.augments) > 0:
            raise Exception('Augments not supported for videos')
@@ -727,9 +725,6 @@ class ImageProcessingDTOMixin:
        transform: Union[None, transforms.Compose],
        only_load_latents=False
    ):
-        if self.dataset_config.num_frames > 1:
-            self.load_and_process_video(transform, only_load_latents)
-            return
        # handle get_prompt_embedding
        if self.is_text_embedding_cached:
            self.load_prompt_embedding()
@@ -747,6 +742,9 @@ class ImageProcessingDTOMixin:
            if self.has_unconditional:
                self.load_unconditional_image()
            return
+        if self.dataset_config.num_frames > 1:
+            self.load_and_process_video(transform, only_load_latents)
+            return
        try:
            img = Image.open(self.path)
            img = exif_transpose(img)
@@ -1716,6 +1714,8 @@ class LatentCachingFileItemDTOMixin:
        if hasattr(super(), '__init__'):
            super().__init__(*args, **kwargs)
        self._encoded_latent: Union[torch.Tensor, None] = None
+        self._cached_first_frame_latent: Union[torch.Tensor, None] = None
+        self._cached_audio_latent: Union[torch.Tensor, None] = None
        self._latent_path: Union[str, None] = None
        self.is_latent_cached = False
        self.is_caching_to_disk = False
@@ -1745,6 +1745,14 @@ class LatentCachingFileItemDTOMixin:
item["flip_y"] = True item["flip_y"] = True
if self.dataset_config.num_frames > 1: if self.dataset_config.num_frames > 1:
item["num_frames"] = self.dataset_config.num_frames item["num_frames"] = self.dataset_config.num_frames
if self.dataset_config.do_i2v:
item["do_i2v"] = True
if self.dataset_config.do_audio:
item["do_audio"] = True
if self.dataset_config.audio_normalize:
item["audio_normalize"] = True
if self.dataset_config.audio_preserve_pitch:
item["audio_preserve_pitch"] = True
return item return item
def get_latent_path(self: 'FileItemDTO', recalculate=False): def get_latent_path(self: 'FileItemDTO', recalculate=False):
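These flags travel with the cached file: a later hunk in this file passes get_latent_info_dict() through get_meta_for_safetensors() when the latent is written to disk, so an i2v/audio run is distinguishable from a plain run. A hedged sketch of what that metadata might look like for one item; the values are made up and the real helper may serialize differently:

    # Illustrative metadata for one cached item; keys mirror get_latent_info_dict() above.
    info = {
        "num_frames": 81,
        "do_i2v": True,
        "do_audio": True,
        "audio_normalize": True,
    }
    meta = {k: str(v) for k, v in info.items()}  # safetensors metadata is string-valued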
@@ -1769,9 +1777,15 @@ class LatentCachingFileItemDTOMixin:
        if not self.is_caching_to_memory:
            # we are caching on disk, don't save in memory
            self._encoded_latent = None
+            self._cached_first_frame_latent = None
+            self._cached_audio_latent = None
        else:
            # move it back to cpu
            self._encoded_latent = self._encoded_latent.to('cpu')
+            if self._cached_first_frame_latent is not None:
+                self._cached_first_frame_latent = self._cached_first_frame_latent.to('cpu')
+            if self._cached_audio_latent is not None:
+                self._cached_audio_latent = self._cached_audio_latent.to('cpu')

    def get_latent(self, device=None):
        if not self.is_latent_cached:
@@ -1784,6 +1798,10 @@ class LatentCachingFileItemDTOMixin:
                device='cpu'
            )
            self._encoded_latent = state_dict['latent']
+            if 'first_frame_latent' in state_dict:
+                self._cached_first_frame_latent = state_dict['first_frame_latent']
+            if 'audio_latent' in state_dict:
+                self._cached_audio_latent = state_dict['audio_latent']
        return self._encoded_latent
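Putting the two halves together, each cached .safetensors file now carries up to three tensors: 'latent' (always), plus 'first_frame_latent' and 'audio_latent' when the dataset enables i2v and audio. A minimal sketch of reading such a file, assuming the safetensors package and a hypothetical cache path:

    from safetensors.torch import load_file

    # Hypothetical cached file produced by cache_latents_all_latents() below.
    state_dict = load_file("latent_cache/clip_0001.safetensors", device="cpu")

    video_latent = state_dict["latent"]                        # always present
    first_frame_latent = state_dict.get("first_frame_latent")  # only for do_i2v video datasets
    audio_latent = state_dict.get("audio_latent")              # only when the item has audio_data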
@@ -1795,8 +1813,6 @@ class LatentCachingMixin:
        self.latent_cache = {}

    def cache_latents_all_latents(self: 'AiToolkitDataset'):
-        if self.dataset_config.num_frames > 1:
-            raise Exception("Error: caching latents is not supported for multi-frame datasets")
        with accelerator.main_process_first():
            print_acc(f"Caching latents for {self.dataset_path}")
            # cache all latents to disk
@@ -1839,25 +1855,50 @@ class LatentCachingMixin:
                        # load it into memory
                        state_dict = load_file(latent_path, device='cpu')
                        file_item._encoded_latent = state_dict['latent'].to('cpu', dtype=self.sd.torch_dtype)
+                        if 'first_frame_latent' in state_dict:
+                            file_item._cached_first_frame_latent = state_dict['first_frame_latent'].to('cpu', dtype=self.sd.torch_dtype)
+                        if 'audio_latent' in state_dict:
+                            file_item._cached_audio_latent = state_dict['audio_latent'].to('cpu', dtype=self.sd.torch_dtype)
                else:
                    # not saved to disk, calculate
                    # load the image first
                    file_item.load_and_process_image(self.transform, only_load_latents=True)
                    dtype = self.sd.torch_dtype
                    device = self.sd.device_torch
+                    state_dict = OrderedDict()
+                    first_frame_latent = None
+                    audio_latent = None
                    # add batch dimension
                    try:
                        imgs = file_item.tensor.unsqueeze(0).to(device, dtype=dtype)
                        latent = self.sd.encode_images(imgs).squeeze(0)
+                        if to_disk:
+                            state_dict['latent'] = latent.clone().detach().cpu()
                    except Exception as e:
                        print_acc(f"Error processing image: {file_item.path}")
                        print_acc(f"Error: {str(e)}")
                        raise e

+                    # do first frame
+                    if self.dataset_config.num_frames > 1 and self.dataset_config.do_i2v:
+                        frames = file_item.tensor.unsqueeze(0).to(device, dtype=dtype)
+                        if len(frames.shape) == 4:
+                            first_frames = frames
+                        elif len(frames.shape) == 5:
+                            first_frames = frames[:, 0]
+                        else:
+                            raise ValueError(f"Unknown frame shape {frames.shape}")
+                        first_frame_latent = self.sd.encode_images(first_frames).squeeze(0)
+                        if to_disk:
+                            state_dict['first_frame_latent'] = first_frame_latent.clone().detach().cpu()
+
+                    # audio
+                    if file_item.audio_data is not None:
+                        audio_latent = self.sd.encode_audio([file_item.audio_data]).squeeze(0)
+                        if to_disk:
+                            state_dict['audio_latent'] = audio_latent.clone().detach().cpu()

                    # save_latent
                    if to_disk:
-                        state_dict = OrderedDict([
-                            ('latent', latent.clone().detach().cpu()),
-                        ])
                        # metadata
                        meta = get_meta_for_safetensors(file_item.get_latent_info_dict())
                        os.makedirs(os.path.dirname(latent_path), exist_ok=True)
@@ -1866,17 +1907,18 @@ class LatentCachingMixin:
                    if to_memory:
                        # keep it in memory
                        file_item._encoded_latent = latent.to('cpu', dtype=self.sd.torch_dtype)
+                        if first_frame_latent is not None:
+                            file_item._cached_first_frame_latent = first_frame_latent.to('cpu', dtype=self.sd.torch_dtype)
+                        if audio_latent is not None:
+                            file_item._cached_audio_latent = audio_latent.to('cpu', dtype=self.sd.torch_dtype)

                    del imgs
                    del latent
                    del file_item.tensor
+                    file_item.cleanup()
+                    # flush(garbage_collect=False)

                file_item.is_latent_cached = True
                i += 1
-                # flush every 100
-                # if i % 100 == 0:
-                #     flush()

        # restore device state
        self.sd.restore_device_state()
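Taken together, the caching pass and the training pass now line up roughly as sketched below. This is a condensed restatement of the hunks above, not literal code from the repo; device and dtype abbreviate self.device_torch and self.torch_dtype.

    # cache_latents_all_latents()      -> writes 'latent' (+ 'first_frame_latent', 'audio_latent')
    # get_latent() / DataLoaderBatchDTO -> exposes batch.latents, batch.first_frame_latents, batch.audio_latents
    # LTX2Model.get_noise_prediction()  -> reuses cached tensors instead of re-encoding:

    if batch.first_frame_latents is not None:
        init_latents = batch.first_frame_latents.to(device, dtype=dtype)  # cached i2v conditioning
    else:
        init_latents = self.encode_images(first_frames, device=device, dtype=dtype)

    if batch.audio_latents is not None:
        raw_audio_latents = batch.audio_latents.to(device, dtype=dtype)   # cached audio latents
    elif batch.audio_tensor is not None:
        raw_audio_latents = self.encode_audio(batch.audio_data)           # fall back to encoding waveforms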

View File

@@ -1123,6 +1123,10 @@ class BaseModel:
        return latents

+    def encode_audio(self, audio_data_list):
+        # audio_data_list is a list of {"waveform": waveform[C, L], "sample_rate": int(sample_rate)}
+        raise NotImplementedError("Audio encoding not implemented for this model.")

    def decode_latents(
        self,
        latents: torch.Tensor,

View File

@@ -2551,6 +2551,10 @@ class StableDiffusion:
        return latents

+    def encode_audio(self, audio_data_list):
+        # audio_data_list is a list of {"waveform": waveform[C, L], "sample_rate": int(sample_rate)}
+        raise NotImplementedError("Audio encoding not implemented for this model.")

    def decode_latents(
        self,
        latents: torch.Tensor,

View File

@@ -1 +1 @@
VERSION = "0.7.18" VERSION = "0.7.19"