mirror of
https://github.com/ostris/ai-toolkit.git
synced 2026-04-25 08:49:14 +00:00
Cache the latents, first-frame latent, and audio latent when caching latents for LTX2
This commit is contained in:
@@ -456,8 +456,6 @@ class ImageProcessingDTOMixin:
|
||||
transform: Union[None, transforms.Compose],
|
||||
only_load_latents=False
|
||||
):
|
||||
if self.is_latent_cached:
|
||||
raise Exception('Latent caching not supported for videos')
|
||||
|
||||
if self.augments is not None and len(self.augments) > 0:
|
||||
raise Exception('Augments not supported for videos')
|
||||
@@ -727,9 +725,6 @@ class ImageProcessingDTOMixin:
|
||||
transform: Union[None, transforms.Compose],
|
||||
only_load_latents=False
|
||||
):
|
||||
if self.dataset_config.num_frames > 1:
|
||||
self.load_and_process_video(transform, only_load_latents)
|
||||
return
|
||||
# handle get_prompt_embedding
|
||||
if self.is_text_embedding_cached:
|
||||
self.load_prompt_embedding()
|
||||
@@ -747,6 +742,9 @@ class ImageProcessingDTOMixin:
|
||||
if self.has_unconditional:
|
||||
self.load_unconditional_image()
|
||||
return
|
||||
if self.dataset_config.num_frames > 1:
|
||||
self.load_and_process_video(transform, only_load_latents)
|
||||
return
|
||||
try:
|
||||
img = Image.open(self.path)
|
||||
img = exif_transpose(img)
|
||||
@@ -1716,6 +1714,8 @@ class LatentCachingFileItemDTOMixin:
|
||||
if hasattr(super(), '__init__'):
|
||||
super().__init__(*args, **kwargs)
|
||||
self._encoded_latent: Union[torch.Tensor, None] = None
|
||||
self._cached_first_frame_latent: Union[torch.Tensor, None] = None
|
||||
self._cached_audio_latent: Union[torch.Tensor, None] = None
|
||||
self._latent_path: Union[str, None] = None
|
||||
self.is_latent_cached = False
|
||||
self.is_caching_to_disk = False
|
||||
@@ -1745,6 +1745,14 @@ class LatentCachingFileItemDTOMixin:
|
||||
item["flip_y"] = True
|
||||
if self.dataset_config.num_frames > 1:
|
||||
item["num_frames"] = self.dataset_config.num_frames
|
||||
if self.dataset_config.do_i2v:
|
||||
item["do_i2v"] = True
|
||||
if self.dataset_config.do_audio:
|
||||
item["do_audio"] = True
|
||||
if self.dataset_config.audio_normalize:
|
||||
item["audio_normalize"] = True
|
||||
if self.dataset_config.audio_preserve_pitch:
|
||||
item["audio_preserve_pitch"] = True
|
||||
return item
|
||||
|
||||
def get_latent_path(self: 'FileItemDTO', recalculate=False):
|
||||
@@ -1769,9 +1777,15 @@ class LatentCachingFileItemDTOMixin:
|
||||
if not self.is_caching_to_memory:
|
||||
# we are caching on disk, don't save in memory
|
||||
self._encoded_latent = None
|
||||
self._cached_first_frame_latent = None
|
||||
self._cached_audio_latent = None
|
||||
else:
|
||||
# move it back to cpu
|
||||
self._encoded_latent = self._encoded_latent.to('cpu')
|
||||
if self._cached_first_frame_latent is not None:
|
||||
self._cached_first_frame_latent = self._cached_first_frame_latent.to('cpu')
|
||||
if self._cached_audio_latent is not None:
|
||||
self._cached_audio_latent = self._cached_audio_latent.to('cpu')
|
||||
|
||||
def get_latent(self, device=None):
|
||||
if not self.is_latent_cached:
|
||||
@@ -1784,6 +1798,10 @@ class LatentCachingFileItemDTOMixin:
|
||||
device='cpu'
|
||||
)
|
||||
self._encoded_latent = state_dict['latent']
|
||||
if 'first_frame_latent' in state_dict:
|
||||
self._cached_first_frame_latent = state_dict['first_frame_latent']
|
||||
if 'audio_latent' in state_dict:
|
||||
self._cached_audio_latent = state_dict['audio_latent']
|
||||
return self._encoded_latent
|
||||
|
||||
|
||||
@@ -1795,8 +1813,6 @@ class LatentCachingMixin:
|
||||
self.latent_cache = {}
|
||||
|
||||
def cache_latents_all_latents(self: 'AiToolkitDataset'):
|
||||
if self.dataset_config.num_frames > 1:
|
||||
raise Exception("Error: caching latents is not supported for multi-frame datasets")
|
||||
with accelerator.main_process_first():
|
||||
print_acc(f"Caching latents for {self.dataset_path}")
|
||||
# cache all latents to disk
|
||||
@@ -1839,25 +1855,50 @@ class LatentCachingMixin:
|
||||
# load it into memory
|
||||
state_dict = load_file(latent_path, device='cpu')
|
||||
file_item._encoded_latent = state_dict['latent'].to('cpu', dtype=self.sd.torch_dtype)
|
||||
if 'first_frame_latent' in state_dict:
|
||||
file_item._cached_first_frame_latent = state_dict['first_frame_latent'].to('cpu', dtype=self.sd.torch_dtype)
|
||||
if 'audio_latent' in state_dict:
|
||||
file_item._cached_audio_latent = state_dict['audio_latent'].to('cpu', dtype=self.sd.torch_dtype)
|
||||
else:
|
||||
# not saved to disk, calculate
|
||||
# load the image first
|
||||
file_item.load_and_process_image(self.transform, only_load_latents=True)
|
||||
dtype = self.sd.torch_dtype
|
||||
device = self.sd.device_torch
|
||||
state_dict = OrderedDict()
|
||||
first_frame_latent = None
|
||||
audio_latent = None
|
||||
# add batch dimension
|
||||
try:
|
||||
imgs = file_item.tensor.unsqueeze(0).to(device, dtype=dtype)
|
||||
latent = self.sd.encode_images(imgs).squeeze(0)
|
||||
if to_disk:
|
||||
state_dict['latent'] = latent.clone().detach().cpu()
|
||||
except Exception as e:
|
||||
print_acc(f"Error processing image: {file_item.path}")
|
||||
print_acc(f"Error: {str(e)}")
|
||||
raise e
|
||||
# do first frame
|
||||
if self.dataset_config.num_frames > 1 and self.dataset_config.do_i2v:
|
||||
frames = file_item.tensor.unsqueeze(0).to(device, dtype=dtype)
|
||||
if len(frames.shape) == 4:
|
||||
first_frames = frames
|
||||
elif len(frames.shape) == 5:
|
||||
first_frames = frames[:, 0]
|
||||
else:
|
||||
raise ValueError(f"Unknown frame shape {frames.shape}")
|
||||
first_frame_latent = self.sd.encode_images(first_frames).squeeze(0)
|
||||
if to_disk:
|
||||
state_dict['first_frame_latent'] = first_frame_latent.clone().detach().cpu()
|
||||
|
||||
# audio
|
||||
if file_item.audio_data is not None:
|
||||
audio_latent = self.sd.encode_audio([file_item.audio_data]).squeeze(0)
|
||||
if to_disk:
|
||||
state_dict['audio_latent'] = audio_latent.clone().detach().cpu()
|
||||
|
||||
# save_latent
|
||||
if to_disk:
|
||||
state_dict = OrderedDict([
|
||||
('latent', latent.clone().detach().cpu()),
|
||||
])
|
||||
# metadata
|
||||
meta = get_meta_for_safetensors(file_item.get_latent_info_dict())
|
||||
os.makedirs(os.path.dirname(latent_path), exist_ok=True)
|
||||
@@ -1866,17 +1907,18 @@ class LatentCachingMixin:
|
||||
if to_memory:
|
||||
# keep it in memory
|
||||
file_item._encoded_latent = latent.to('cpu', dtype=self.sd.torch_dtype)
|
||||
if first_frame_latent is not None:
|
||||
file_item._cached_first_frame_latent = first_frame_latent.to('cpu', dtype=self.sd.torch_dtype)
|
||||
if audio_latent is not None:
|
||||
file_item._cached_audio_latent = audio_latent.to('cpu', dtype=self.sd.torch_dtype)
|
||||
|
||||
del imgs
|
||||
del latent
|
||||
del file_item.tensor
|
||||
file_item.cleanup()
|
||||
|
||||
# flush(garbage_collect=False)
|
||||
file_item.is_latent_cached = True
|
||||
i += 1
|
||||
# flush every 100
|
||||
# if i % 100 == 0:
|
||||
# flush()
|
||||
|
||||
# restore device state
|
||||
self.sd.restore_device_state()
|
||||
|
||||
Reference in New Issue
Block a user