Also cache first-frame and audio latents when caching latents for LTX2
@@ -1,24 +1,31 @@
 import os
 import weakref
 from _weakref import ReferenceType
 from typing import TYPE_CHECKING, List, Union
 import cv2
 import torch
 import random
 
 from PIL import Image
 from PIL.ImageOps import exif_transpose
 
 from toolkit import image_utils
 from toolkit.basic import get_quick_signature_string
-from toolkit.dataloader_mixins import CaptionProcessingDTOMixin, ImageProcessingDTOMixin, LatentCachingFileItemDTOMixin, \
-    ControlFileItemDTOMixin, ArgBreakMixin, PoiFileItemDTOMixin, MaskFileItemDTOMixin, AugmentationFileItemDTOMixin, \
-    UnconditionalFileItemDTOMixin, ClipImageFileItemDTOMixin, InpaintControlFileItemDTOMixin, TextEmbeddingFileItemDTOMixin
+from toolkit.dataloader_mixins import (
+    CaptionProcessingDTOMixin,
+    ImageProcessingDTOMixin,
+    LatentCachingFileItemDTOMixin,
+    ControlFileItemDTOMixin,
+    ArgBreakMixin,
+    PoiFileItemDTOMixin,
+    MaskFileItemDTOMixin,
+    AugmentationFileItemDTOMixin,
+    UnconditionalFileItemDTOMixin,
+    ClipImageFileItemDTOMixin,
+    InpaintControlFileItemDTOMixin,
+    TextEmbeddingFileItemDTOMixin,
+)
 from toolkit.prompt_utils import PromptEmbeds, concat_prompt_embeds
 
 if TYPE_CHECKING:
     from toolkit.config_modules import DatasetConfig
     from toolkit.stable_diffusion_model import StableDiffusion
 
 printed_messages = []
 
@@ -45,54 +52,62 @@ class FileItemDTO(
     ArgBreakMixin,
 ):
     def __init__(self, *args, **kwargs):
-        self.path = kwargs.get('path', '')
-        self.dataset_config: 'DatasetConfig' = kwargs.get('dataset_config', None)
+        self.path = kwargs.get("path", "")
+        self.dataset_config: "DatasetConfig" = kwargs.get("dataset_config", None)
         self.is_video = self.dataset_config.num_frames > 1
-        size_database = kwargs.get('size_database', {})
-        dataset_root = kwargs.get('dataset_root', None)
-        self.encode_control_in_text_embeddings = kwargs.get('encode_control_in_text_embeddings', False)
+        size_database = kwargs.get("size_database", {})
+        dataset_root = kwargs.get("dataset_root", None)
+        self.encode_control_in_text_embeddings = kwargs.get(
+            "encode_control_in_text_embeddings", False
+        )
         if dataset_root is not None:
             # remove dataset root from path
-            file_key = self.path.replace(dataset_root, '')
+            file_key = self.path.replace(dataset_root, "")
         else:
             file_key = os.path.basename(self.path)
 
         file_signature = get_quick_signature_string(self.path)
         if file_signature is None:
             raise Exception("Error: Could not get file signature for {self.path}")
 
         use_db_entry = False
         if file_key in size_database:
             db_entry = size_database[file_key]
-            if db_entry is not None and len(db_entry) >= 3 and db_entry[2] == file_signature:
+            if (
+                db_entry is not None
+                and len(db_entry) >= 3
+                and db_entry[2] == file_signature
+            ):
                 use_db_entry = True
 
         if use_db_entry:
             w, h, _ = size_database[file_key]
         elif self.is_video:
             # Open the video file
             video = cv2.VideoCapture(self.path)
 
             # Check if video opened successfully
             if not video.isOpened():
                 raise Exception(f"Error: Could not open video file {self.path}")
 
             # Get width and height
             width = int(video.get(cv2.CAP_PROP_FRAME_WIDTH))
             height = int(video.get(cv2.CAP_PROP_FRAME_HEIGHT))
             w, h = width, height
 
             # Release the video capture object immediately
             video.release()
             size_database[file_key] = (width, height, file_signature)
         else:
             if self.dataset_config.fast_image_size:
                 # original method is significantly faster, but some images are read sideways. Not sure why. Do slow method by default.
                 try:
                     w, h = image_utils.get_image_size(self.path)
                 except image_utils.UnknownImageFormat:
-                    print_once(f'Warning: Some images in the dataset cannot be fast read. ' + \
-                               f'This process is faster for png, jpeg')
+                    print_once(
+                        f"Warning: Some images in the dataset cannot be fast read. "
+                        + f"This process is faster for png, jpeg"
+                    )
                     img = exif_transpose(Image.open(self.path))
                     w, h = img.size
             else:
@@ -101,21 +116,25 @@ class FileItemDTO(
         size_database[file_key] = (w, h, file_signature)
         self.width: int = w
         self.height: int = h
-        self.dataloader_transforms = kwargs.get('dataloader_transforms', None)
+        self.dataloader_transforms = kwargs.get("dataloader_transforms", None)
         super().__init__(*args, **kwargs)
 
         # self.caption_path: str = kwargs.get('caption_path', None)
-        self.raw_caption: str = kwargs.get('raw_caption', None)
+        self.raw_caption: str = kwargs.get("raw_caption", None)
         # we scale first, then crop
-        self.scale_to_width: int = kwargs.get('scale_to_width', int(self.width * self.dataset_config.scale))
-        self.scale_to_height: int = kwargs.get('scale_to_height', int(self.height * self.dataset_config.scale))
+        self.scale_to_width: int = kwargs.get(
+            "scale_to_width", int(self.width * self.dataset_config.scale)
+        )
+        self.scale_to_height: int = kwargs.get(
+            "scale_to_height", int(self.height * self.dataset_config.scale)
+        )
         # crop values are from scaled size
-        self.crop_x: int = kwargs.get('crop_x', 0)
-        self.crop_y: int = kwargs.get('crop_y', 0)
-        self.crop_width: int = kwargs.get('crop_width', self.scale_to_width)
-        self.crop_height: int = kwargs.get('crop_height', self.scale_to_height)
-        self.flip_x: bool = kwargs.get('flip_x', False)
-        self.flip_y: bool = kwargs.get('flip_x', False)
+        self.crop_x: int = kwargs.get("crop_x", 0)
+        self.crop_y: int = kwargs.get("crop_y", 0)
+        self.crop_width: int = kwargs.get("crop_width", self.scale_to_width)
+        self.crop_height: int = kwargs.get("crop_height", self.scale_to_height)
+        self.flip_x: bool = kwargs.get("flip_x", False)
+        self.flip_y: bool = kwargs.get("flip_x", False)
         self.augments: List[str] = self.dataset_config.augments
         self.loss_multiplier: float = self.dataset_config.loss_multiplier
 
@@ -142,9 +161,8 @@ class FileItemDTO(
 class DataLoaderBatchDTO:
     def __init__(self, **kwargs):
         try:
-            self.file_items: List['FileItemDTO'] = kwargs.get('file_items', None)
+            self.file_items: List["FileItemDTO"] = kwargs.get("file_items", None)
             is_latents_cached = self.file_items[0].is_latent_cached
             is_text_embedding_cached = self.file_items[0].is_text_embedding_cached
             self.tensor: Union[torch.Tensor, None] = None
             self.latents: Union[torch.Tensor, None] = None
             self.control_tensor: Union[torch.Tensor, None] = None
@@ -156,22 +174,64 @@ class DataLoaderBatchDTO:
             self.unconditional_latents: Union[torch.Tensor, None] = None
             self.clip_image_embeds: Union[List[dict], None] = None
             self.clip_image_embeds_unconditional: Union[List[dict], None] = None
-            self.sigmas: Union[torch.Tensor, None] = None  # can be added elseware and passed along training code
-            self.extra_values: Union[torch.Tensor, None] = torch.tensor([x.extra_values for x in self.file_items]) if len(self.file_items[0].extra_values) > 0 else None
-            self.audio_data: Union[List, None] = [x.audio_data for x in self.file_items] if self.file_items[0].audio_data is not None else None
+            self.sigmas: Union[torch.Tensor, None] = (
+                None  # can be added elseware and passed along training code
+            )
+            self.extra_values: Union[torch.Tensor, None] = (
+                torch.tensor([x.extra_values for x in self.file_items])
+                if len(self.file_items[0].extra_values) > 0
+                else None
+            )
+            self.audio_data: Union[List, None] = (
+                [x.audio_data for x in self.file_items]
+                if self.file_items[0].audio_data is not None
+                else None
+            )
             self.audio_tensor: Union[torch.Tensor, None] = None
 
+            self.first_frame_latents: Union[torch.Tensor, None] = None
+            self.audio_latents: Union[torch.Tensor, None] = None
+
             # just for holding noise and preds during training
             self.audio_target: Union[torch.Tensor, None] = None
             self.audio_pred: Union[torch.Tensor, None] = None
 
             if not is_latents_cached:
                 # only return a tensor if latents are not cached
-                self.tensor: torch.Tensor = torch.cat([x.tensor.unsqueeze(0) for x in self.file_items])
+                self.tensor: torch.Tensor = torch.cat(
+                    [x.tensor.unsqueeze(0) for x in self.file_items]
+                )
             # if we have encoded latents, we concatenate them
             self.latents: Union[torch.Tensor, None] = None
             if is_latents_cached:
-                self.latents = torch.cat([x.get_latent().unsqueeze(0) for x in self.file_items])
+                # this get_latent call with trigger loading all cached items from the disk
+                self.latents = torch.cat(
+                    [x.get_latent().unsqueeze(0) for x in self.file_items]
+                )
+                if any(
+                    [x._cached_first_frame_latent is not None for x in self.file_items]
+                ):
+                    self.first_frame_latents = torch.cat(
+                        [
+                            x._cached_first_frame_latent.unsqueeze(0)
+                            if x._cached_first_frame_latent is not None
+                            else torch.zeros_like(
+                                self.file_items[0]._cached_first_frame_latent
+                            ).unsqueeze(0)
+                            for x in self.file_items
+                        ]
+                    )
+                if any([x._cached_audio_latent is not None for x in self.file_items]):
+                    self.audio_latents = torch.cat(
+                        [
+                            x._cached_audio_latent.unsqueeze(0)
+                            if x._cached_audio_latent is not None
+                            else torch.zeros_like(
+                                self.file_items[0]._cached_audio_latent
+                            ).unsqueeze(0)
+                            for x in self.file_items
+                        ]
+                    )
 
             self.prompt_embeds: Union[PromptEmbeds, None] = None
             # if self.file_items[0].control_tensor is not None:
             # if any have a control tensor, we concatenate them
@@ -188,8 +248,10 @@ class DataLoaderBatchDTO:
                         control_tensors.append(torch.zeros_like(base_control_tensor))
                     else:
                         control_tensors.append(x.control_tensor)
-                self.control_tensor = torch.cat([x.unsqueeze(0) for x in control_tensors])
+                self.control_tensor = torch.cat(
+                    [x.unsqueeze(0) for x in control_tensors]
+                )
 
             # handle control tensor list
             if any([x.control_tensor_list is not None for x in self.file_items]):
                 self.control_tensor_list = []
@@ -197,9 +259,10 @@ class DataLoaderBatchDTO:
                     if x.control_tensor_list is not None:
                         self.control_tensor_list.append(x.control_tensor_list)
                     else:
-                        raise Exception(f"Could not find control tensors for all file items, missing for {x.path}")
+                        raise Exception(
+                            f"Could not find control tensors for all file items, missing for {x.path}"
+                        )
 
             self.inpaint_tensor: Union[torch.Tensor, None] = None
             if any([x.inpaint_tensor is not None for x in self.file_items]):
                 # find one to use as a base
@@ -214,9 +277,13 @@ class DataLoaderBatchDTO:
                         inpaint_tensors.append(torch.zeros_like(base_inpaint_tensor))
                     else:
                         inpaint_tensors.append(x.inpaint_tensor)
-                self.inpaint_tensor = torch.cat([x.unsqueeze(0) for x in inpaint_tensors])
+                self.inpaint_tensor = torch.cat(
+                    [x.unsqueeze(0) for x in inpaint_tensors]
+                )
 
-            self.loss_multiplier_list: List[float] = [x.loss_multiplier for x in self.file_items]
+            self.loss_multiplier_list: List[float] = [
+                x.loss_multiplier for x in self.file_items
+            ]
 
             if any([x.clip_image_tensor is not None for x in self.file_items]):
                 # find one to use as a base
@@ -228,10 +295,14 @@ class DataLoaderBatchDTO:
                 clip_image_tensors = []
                 for x in self.file_items:
                     if x.clip_image_tensor is None:
-                        clip_image_tensors.append(torch.zeros_like(base_clip_image_tensor))
+                        clip_image_tensors.append(
+                            torch.zeros_like(base_clip_image_tensor)
+                        )
                     else:
                         clip_image_tensors.append(x.clip_image_tensor)
-                self.clip_image_tensor = torch.cat([x.unsqueeze(0) for x in clip_image_tensors])
+                self.clip_image_tensor = torch.cat(
+                    [x.unsqueeze(0) for x in clip_image_tensors]
+                )
 
             if any([x.mask_tensor is not None for x in self.file_items]):
                 # find one to use as a base
@@ -259,10 +330,14 @@ class DataLoaderBatchDTO:
                 unaugmented_tensor = []
                 for x in self.file_items:
                     if x.unaugmented_tensor is None:
-                        unaugmented_tensor.append(torch.zeros_like(base_unaugmented_tensor))
+                        unaugmented_tensor.append(
+                            torch.zeros_like(base_unaugmented_tensor)
+                        )
                     else:
                         unaugmented_tensor.append(x.unaugmented_tensor)
-                self.unaugmented_tensor = torch.cat([x.unsqueeze(0) for x in unaugmented_tensor])
+                self.unaugmented_tensor = torch.cat(
+                    [x.unsqueeze(0) for x in unaugmented_tensor]
+                )
 
             # add unconditional tensors
             if any([x.unconditional_tensor is not None for x in self.file_items]):
@@ -275,10 +350,14 @@ class DataLoaderBatchDTO:
                 unconditional_tensor = []
                 for x in self.file_items:
                     if x.unconditional_tensor is None:
-                        unconditional_tensor.append(torch.zeros_like(base_unconditional_tensor))
+                        unconditional_tensor.append(
+                            torch.zeros_like(base_unconditional_tensor)
+                        )
                     else:
                         unconditional_tensor.append(x.unconditional_tensor)
-                self.unconditional_tensor = torch.cat([x.unsqueeze(0) for x in unconditional_tensor])
+                self.unconditional_tensor = torch.cat(
+                    [x.unsqueeze(0) for x in unconditional_tensor]
+                )
 
             if any([x.clip_image_embeds is not None for x in self.file_items]):
                 self.clip_image_embeds = []
@@ -288,14 +367,20 @@ class DataLoaderBatchDTO:
                     else:
                         raise Exception("clip_image_embeds is None for some file items")
 
-            if any([x.clip_image_embeds_unconditional is not None for x in self.file_items]):
+            if any(
+                [x.clip_image_embeds_unconditional is not None for x in self.file_items]
+            ):
                 self.clip_image_embeds_unconditional = []
                 for x in self.file_items:
                     if x.clip_image_embeds_unconditional is not None:
-                        self.clip_image_embeds_unconditional.append(x.clip_image_embeds_unconditional)
+                        self.clip_image_embeds_unconditional.append(
+                            x.clip_image_embeds_unconditional
+                        )
                     else:
-                        raise Exception("clip_image_embeds_unconditional is None for some file items")
+                        raise Exception(
+                            "clip_image_embeds_unconditional is None for some file items"
+                        )
 
             if any([x.prompt_embeds is not None for x in self.file_items]):
                 # find one to use as a base
                 base_prompt_embeds = None
@@ -315,7 +400,7 @@ class DataLoaderBatchDTO:
                         y.text_embeds = [y.text_embeds]
                     prompt_embeds_list.append(y)
                 self.prompt_embeds = concat_prompt_embeds(prompt_embeds_list)
 
             if any([x.audio_tensor is not None for x in self.file_items]):
                 # find one to use as a base
                 base_audio_tensor = None
@@ -330,7 +415,6 @@ class DataLoaderBatchDTO:
                     else:
                         audio_tensors.append(x.audio_tensor)
                 self.audio_tensor = torch.cat([x.unsqueeze(0) for x in audio_tensors])
-
 
         except Exception as e:
             print(e)
@@ -343,18 +427,12 @@ class DataLoaderBatchDTO:
         return [x.network_weight for x in self.file_items]
 
     def get_caption_list(
-        self,
-        trigger=None,
-        to_replace_list=None,
-        add_if_not_present=True
+        self, trigger=None, to_replace_list=None, add_if_not_present=True
     ):
         return [x.caption for x in self.file_items]
 
     def get_caption_short_list(
-        self,
-        trigger=None,
-        to_replace_list=None,
-        add_if_not_present=True
+        self, trigger=None, to_replace_list=None, add_if_not_present=True
     ):
         return [x.caption_short for x in self.file_items]
 
@@ -366,11 +444,13 @@ class DataLoaderBatchDTO:
         del self.audio_data
         del self.audio_target
         del self.audio_pred
+        del self.first_frame_latents
+        del self.audio_latents
         for file_item in self.file_items:
             file_item.cleanup()
 
     @property
-    def dataset_config(self) -> 'DatasetConfig':
+    def dataset_config(self) -> "DatasetConfig":
         if len(self.file_items) > 0:
             return self.file_items[0].dataset_config
         else:
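
For context, the core of this change is the collation step in `DataLoaderBatchDTO.__init__`: when latents are cached, the batch now also stacks each item's cached first-frame latent and audio latent, substituting a zeros tensor for items that have no cache entry. Below is a minimal, self-contained sketch of that pattern; the `Item` class, `collate_first_frame_latents` helper, and tensor shapes are illustrative stand-ins, not the toolkit's real types. Note one deliberate deviation: the committed code takes `file_items[0]`'s latent as the `zeros_like` template, while this sketch picks the first item that actually has a latent, which also covers batches where item 0's cache is missing.

```python
from typing import List, Optional

import torch


class Item:
    """Illustrative stand-in for a FileItemDTO with an optional cached latent."""

    def __init__(self, first_frame_latent: Optional[torch.Tensor] = None):
        self._cached_first_frame_latent = first_frame_latent


def collate_first_frame_latents(items: List[Item]) -> Optional[torch.Tensor]:
    # If no item has a cached first-frame latent, the batch field stays None.
    if not any(x._cached_first_frame_latent is not None for x in items):
        return None
    # Use the first available latent as a shape/dtype template for zero-fill.
    template = next(
        x._cached_first_frame_latent
        for x in items
        if x._cached_first_frame_latent is not None
    )
    return torch.cat(
        [
            (
                x._cached_first_frame_latent
                if x._cached_first_frame_latent is not None
                else torch.zeros_like(template)  # zero-fill for missing entries
            ).unsqueeze(0)  # add the batch dimension before concatenation
            for x in items
        ]
    )


# Example: a 3-item batch where the middle item has no cached first frame.
latent = torch.randn(4, 8, 8)
batch = [Item(latent), Item(None), Item(latent.clone())]
stacked = collate_first_frame_latents(batch)
print(stacked.shape)  # torch.Size([3, 4, 8, 8])
```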