Mirror of https://github.com/ostris/ai-toolkit.git, synced 2026-01-26 16:39:47 +00:00
Cache latents, first frame, and audio when caching latents for LTX2
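With this change, each cached item is written to a single safetensors file: the image or video latent under the 'latent' key, plus optional 'first_frame_latent' and 'audio_latent' entries that the dataloader picks up again when it assembles a batch. A minimal sketch of reading one cache entry back (the file name below is hypothetical; in the toolkit the real path comes from get_latent_path()):

    from safetensors.torch import load_file

    # Hypothetical cache file path, used here only for illustration.
    state_dict = load_file("latent_cache/clip_0001.safetensors", device="cpu")

    latent = state_dict["latent"]  # always present: the encoded image/video latent
    # The extra keys are only written for i2v / audio datasets, so guard with .get().
    first_frame_latent = state_dict.get("first_frame_latent")
    audio_latent = state_dict.get("audio_latent")
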
@@ -98,6 +98,41 @@ def blank_log_image_function(self, *args, **kwargs):
     return


+class ComboVae(torch.nn.Module):
+    """Combines video and audio VAEs for joint encoding and decoding."""
+
+    def __init__(
+        self,
+        vae: AutoencoderKLLTX2Video,
+        audio_vae: AutoencoderKLLTX2Audio,
+    ) -> None:
+        super().__init__()
+        self.vae = vae
+        self.audio_vae = audio_vae
+
+    @property
+    def device(self):
+        return self.vae.device
+
+    @property
+    def dtype(self):
+        return self.vae.dtype
+
+    def encode(
+        self,
+        *args,
+        **kwargs,
+    ):
+        return self.vae.encode(*args, **kwargs)
+
+    def decode(
+        self,
+        *args,
+        **kwargs,
+    ):
+        return self.vae.decode(*args, **kwargs)
+
+
 class AudioProcessor(torch.nn.Module):
     """Converts audio waveforms to log-mel spectrograms with optional resampling."""

@@ -445,7 +480,7 @@ class LTX2Model(BaseModel):
         flush()

         # save it to the model class
-        self.vae = vae
+        self.vae = ComboVae(pipe.vae, pipe.audio_vae)
         self.text_encoder = text_encoder  # list of text encoders
         self.tokenizer = tokenizer  # list of tokenizers
         self.model = pipe.transformer
@@ -467,10 +502,10 @@ class LTX2Model(BaseModel):
         if dtype is None:
             dtype = self.vae_torch_dtype

-        if self.vae.device == torch.device("cpu"):
-            self.vae.to(device)
-        self.vae.eval()
-        self.vae.requires_grad_(False)
+        if self.pipeline.vae.device == torch.device("cpu"):
+            self.pipeline.vae.to(device)
+        self.pipeline.vae.eval()
+        self.pipeline.vae.requires_grad_(False)

         image_list = [image.to(device, dtype=dtype) for image in image_list]

@@ -489,7 +524,7 @@ class LTX2Model(BaseModel):
         # Stack to (B, C, T, H, W)
         images = torch.stack(norm_images)

-        latents = self.vae.encode(images).latent_dist.mode()
+        latents = self.pipeline.vae.encode(images).latent_dist.mode()

         # Normalize latents across the channel dimension [B, C, F, H, W]
         scaling_factor = 1.0
@@ -571,7 +606,9 @@ class LTX2Model(BaseModel):
         if gen_config.ctrl_img is not None:
             control_img = Image.open(gen_config.ctrl_img).convert("RGB")
             # resize the control image
-            control_img = control_img.resize((gen_config.width, gen_config.height), Image.LANCZOS)
+            control_img = control_img.resize(
+                (gen_config.width, gen_config.height), Image.LANCZOS
+            )
             # add the control image to the extra dict
             extra["image"] = control_img

@@ -642,7 +679,8 @@ class LTX2Model(BaseModel):
         img = video[0]
         return img

-    def encode_audio(self, batch: "DataLoaderBatchDTO"):
+    def encode_audio(self, audio_data_list):
+        # audio_date_list is a list of {"waveform": waveform[C, L], "sample_rate": int(sample_rate)}
         if self.pipeline.audio_vae.device == torch.device("cpu"):
             self.pipeline.audio_vae.to(self.device_torch)

@@ -650,7 +688,7 @@ class LTX2Model(BaseModel):
         audio_num_frames = None

         # do them seperatly for now
-        for audio_data in batch.audio_data:
+        for audio_data in audio_data_list:
             waveform = audio_data["waveform"].to(
                 device=self.device_torch, dtype=torch.float32
             )
@@ -665,14 +703,18 @@ class LTX2Model(BaseModel):
                 waveform = waveform.repeat(1, 2, 1)

             # Convert waveform to mel spectrogram using AudioProcessor
-            mel_spectrogram = self.audio_processor.waveform_to_mel(waveform, waveform_sample_rate=sample_rate)
+            mel_spectrogram = self.audio_processor.waveform_to_mel(
+                waveform, waveform_sample_rate=sample_rate
+            )
             mel_spectrogram = mel_spectrogram.to(dtype=self.torch_dtype)

             # Encode mel spectrogram to latents
-            latents = self.pipeline.audio_vae.encode(mel_spectrogram.to(self.device_torch, dtype=self.torch_dtype)).latent_dist.mode()
+            latents = self.pipeline.audio_vae.encode(
+                mel_spectrogram.to(self.device_torch, dtype=self.torch_dtype)
+            ).latent_dist.mode()

             if audio_num_frames is None:
-                audio_num_frames = latents.shape[2] #(latents is [B, C, T, F])
+                audio_num_frames = latents.shape[2]  # (latents is [B, C, T, F])

             packed_latents = self.pipeline._pack_audio_latents(
                 latents,
@@ -688,7 +730,7 @@ class LTX2Model(BaseModel):
         latents_mean = self.pipeline.audio_vae.latents_mean
         latents_std = self.pipeline.audio_vae.latents_std
         output_tensor = (output_tensor - latents_mean) / latents_std
-        return output_tensor, audio_num_frames
+        return output_tensor

     def get_noise_prediction(
         self,
@@ -710,6 +752,13 @@ class LTX2Model(BaseModel):

         # i2v from first frame
         if batch.dataset_config.do_i2v:
+            # check to see if we had it cached
+            if batch.first_frame_latents is not None:
+                init_latents = batch.first_frame_latents.to(
+                    self.device_torch, dtype=self.torch_dtype
+                )
+            else:
+                # extract the first frame and encode it
                 # videos come in (bs, num_frames, channels, height, width)
                 # images come in (bs, channels, height, width)
                 frames = batch.tensor
@@ -720,11 +769,23 @@ class LTX2Model(BaseModel):
                 else:
                     raise ValueError(f"Unknown frame shape {frames.shape}")
                 # first frame doesnt have time dim, add it back
-                init_latents = self.encode_images(first_frames, device=self.device_torch, dtype=self.torch_dtype)
+                init_latents = self.encode_images(
+                    first_frames, device=self.device_torch, dtype=self.torch_dtype
+                )
+
+            # expand the latents to match video frames
             init_latents = init_latents.repeat(1, 1, latent_num_frames, 1, 1)
-            mask_shape = (batch_size, 1, latent_num_frames, latent_height, latent_width)
+            mask_shape = (
+                batch_size,
+                1,
+                latent_num_frames,
+                latent_height,
+                latent_width,
+            )
             # First condition is image latents and those should be kept clean.
-            conditioning_mask = torch.zeros(mask_shape, device=self.device_torch, dtype=self.torch_dtype)
+            conditioning_mask = torch.zeros(
+                mask_shape, device=self.device_torch, dtype=self.torch_dtype
+            )
             conditioning_mask[:, :, 0] = 1.0

             # use conditioning mask to replace latents
@@ -746,10 +807,18 @@ class LTX2Model(BaseModel):
             patch_size_t=self.pipeline.transformer_temporal_patch_size,
         )

-        if batch.audio_tensor is not None:
+        if batch.audio_latents is not None or batch.audio_tensor is not None:
+            if batch.audio_latents is not None:
+                # we have audio latents cached
+                raw_audio_latents = batch.audio_latents.to(
+                    self.device_torch, dtype=self.torch_dtype
+                )
+            else:
+                # we have audio waveforms to encode
                 # use audio from the batch if available
-            #(1, 190, 128)
-            raw_audio_latents, audio_num_frames = self.encode_audio(batch)
+                raw_audio_latents = self.encode_audio(batch.audio_data)
+
+            audio_num_frames = raw_audio_latents.shape[1]
             # add the audio targets to the batch for loss calculation later
             audio_noise = torch.randn_like(raw_audio_latents)
             batch.audio_target = (audio_noise - raw_audio_latents).detach()
@@ -758,7 +827,6 @@ class LTX2Model(BaseModel):
                 audio_noise,
                 timestep,
             ).to(self.device_torch, dtype=self.torch_dtype)
-
         else:
             # no audio
             num_mel_bins = self.pipeline.audio_vae.config.mel_bins
@@ -766,7 +834,6 @@ class LTX2Model(BaseModel):
             num_channels_latents_audio = (
                 self.pipeline.audio_vae.config.latent_channels
             )
-            # audio latents are (1, 126, 128), audio_num_frames = 126
             audio_latents, audio_num_frames = self.pipeline.prepare_audio_latents(
                 batch_size,
                 num_channels_latents=num_channels_latents_audio,
@@ -1,24 +1,31 @@
 import os
-import weakref
-from _weakref import ReferenceType
 from typing import TYPE_CHECKING, List, Union
 import cv2
 import torch
-import random

 from PIL import Image
 from PIL.ImageOps import exif_transpose

 from toolkit import image_utils
 from toolkit.basic import get_quick_signature_string
-from toolkit.dataloader_mixins import CaptionProcessingDTOMixin, ImageProcessingDTOMixin, LatentCachingFileItemDTOMixin, \
-    ControlFileItemDTOMixin, ArgBreakMixin, PoiFileItemDTOMixin, MaskFileItemDTOMixin, AugmentationFileItemDTOMixin, \
-    UnconditionalFileItemDTOMixin, ClipImageFileItemDTOMixin, InpaintControlFileItemDTOMixin, TextEmbeddingFileItemDTOMixin
+from toolkit.dataloader_mixins import (
+    CaptionProcessingDTOMixin,
+    ImageProcessingDTOMixin,
+    LatentCachingFileItemDTOMixin,
+    ControlFileItemDTOMixin,
+    ArgBreakMixin,
+    PoiFileItemDTOMixin,
+    MaskFileItemDTOMixin,
+    AugmentationFileItemDTOMixin,
+    UnconditionalFileItemDTOMixin,
+    ClipImageFileItemDTOMixin,
+    InpaintControlFileItemDTOMixin,
+    TextEmbeddingFileItemDTOMixin,
+)
 from toolkit.prompt_utils import PromptEmbeds, concat_prompt_embeds

 if TYPE_CHECKING:
     from toolkit.config_modules import DatasetConfig
-    from toolkit.stable_diffusion_model import StableDiffusion

 printed_messages = []
@@ -45,15 +52,17 @@ class FileItemDTO(
     ArgBreakMixin,
 ):
     def __init__(self, *args, **kwargs):
-        self.path = kwargs.get('path', '')
-        self.dataset_config: 'DatasetConfig' = kwargs.get('dataset_config', None)
+        self.path = kwargs.get("path", "")
+        self.dataset_config: "DatasetConfig" = kwargs.get("dataset_config", None)
         self.is_video = self.dataset_config.num_frames > 1
-        size_database = kwargs.get('size_database', {})
-        dataset_root = kwargs.get('dataset_root', None)
-        self.encode_control_in_text_embeddings = kwargs.get('encode_control_in_text_embeddings', False)
+        size_database = kwargs.get("size_database", {})
+        dataset_root = kwargs.get("dataset_root", None)
+        self.encode_control_in_text_embeddings = kwargs.get(
+            "encode_control_in_text_embeddings", False
+        )
         if dataset_root is not None:
             # remove dataset root from path
-            file_key = self.path.replace(dataset_root, '')
+            file_key = self.path.replace(dataset_root, "")
         else:
             file_key = os.path.basename(self.path)

@@ -64,7 +73,11 @@ class FileItemDTO(
         use_db_entry = False
         if file_key in size_database:
             db_entry = size_database[file_key]
-            if db_entry is not None and len(db_entry) >= 3 and db_entry[2] == file_signature:
+            if (
+                db_entry is not None
+                and len(db_entry) >= 3
+                and db_entry[2] == file_signature
+            ):
                 use_db_entry = True

         if use_db_entry:
@@ -91,8 +104,10 @@ class FileItemDTO(
                 try:
                     w, h = image_utils.get_image_size(self.path)
                 except image_utils.UnknownImageFormat:
-                    print_once(f'Warning: Some images in the dataset cannot be fast read. ' + \
-                               f'This process is faster for png, jpeg')
+                    print_once(
+                        f"Warning: Some images in the dataset cannot be fast read. "
+                        + f"This process is faster for png, jpeg"
+                    )
                     img = exif_transpose(Image.open(self.path))
                     w, h = img.size
             else:
@@ -101,21 +116,25 @@ class FileItemDTO(
             size_database[file_key] = (w, h, file_signature)
         self.width: int = w
         self.height: int = h
-        self.dataloader_transforms = kwargs.get('dataloader_transforms', None)
+        self.dataloader_transforms = kwargs.get("dataloader_transforms", None)
         super().__init__(*args, **kwargs)

         # self.caption_path: str = kwargs.get('caption_path', None)
-        self.raw_caption: str = kwargs.get('raw_caption', None)
+        self.raw_caption: str = kwargs.get("raw_caption", None)
         # we scale first, then crop
-        self.scale_to_width: int = kwargs.get('scale_to_width', int(self.width * self.dataset_config.scale))
-        self.scale_to_height: int = kwargs.get('scale_to_height', int(self.height * self.dataset_config.scale))
+        self.scale_to_width: int = kwargs.get(
+            "scale_to_width", int(self.width * self.dataset_config.scale)
+        )
+        self.scale_to_height: int = kwargs.get(
+            "scale_to_height", int(self.height * self.dataset_config.scale)
+        )
         # crop values are from scaled size
-        self.crop_x: int = kwargs.get('crop_x', 0)
-        self.crop_y: int = kwargs.get('crop_y', 0)
-        self.crop_width: int = kwargs.get('crop_width', self.scale_to_width)
-        self.crop_height: int = kwargs.get('crop_height', self.scale_to_height)
-        self.flip_x: bool = kwargs.get('flip_x', False)
-        self.flip_y: bool = kwargs.get('flip_x', False)
+        self.crop_x: int = kwargs.get("crop_x", 0)
+        self.crop_y: int = kwargs.get("crop_y", 0)
+        self.crop_width: int = kwargs.get("crop_width", self.scale_to_width)
+        self.crop_height: int = kwargs.get("crop_height", self.scale_to_height)
+        self.flip_x: bool = kwargs.get("flip_x", False)
+        self.flip_y: bool = kwargs.get("flip_x", False)
         self.augments: List[str] = self.dataset_config.augments
         self.loss_multiplier: float = self.dataset_config.loss_multiplier

@@ -142,9 +161,8 @@ class FileItemDTO(
 class DataLoaderBatchDTO:
     def __init__(self, **kwargs):
         try:
-            self.file_items: List['FileItemDTO'] = kwargs.get('file_items', None)
+            self.file_items: List["FileItemDTO"] = kwargs.get("file_items", None)
             is_latents_cached = self.file_items[0].is_latent_cached
-            is_text_embedding_cached = self.file_items[0].is_text_embedding_cached
             self.tensor: Union[torch.Tensor, None] = None
             self.latents: Union[torch.Tensor, None] = None
             self.control_tensor: Union[torch.Tensor, None] = None
@@ -156,10 +174,22 @@ class DataLoaderBatchDTO:
             self.unconditional_latents: Union[torch.Tensor, None] = None
             self.clip_image_embeds: Union[List[dict], None] = None
             self.clip_image_embeds_unconditional: Union[List[dict], None] = None
-            self.sigmas: Union[torch.Tensor, None] = None  # can be added elseware and passed along training code
-            self.extra_values: Union[torch.Tensor, None] = torch.tensor([x.extra_values for x in self.file_items]) if len(self.file_items[0].extra_values) > 0 else None
-            self.audio_data: Union[List, None] = [x.audio_data for x in self.file_items] if self.file_items[0].audio_data is not None else None
+            self.sigmas: Union[torch.Tensor, None] = (
+                None  # can be added elseware and passed along training code
+            )
+            self.extra_values: Union[torch.Tensor, None] = (
+                torch.tensor([x.extra_values for x in self.file_items])
+                if len(self.file_items[0].extra_values) > 0
+                else None
+            )
+            self.audio_data: Union[List, None] = (
+                [x.audio_data for x in self.file_items]
+                if self.file_items[0].audio_data is not None
+                else None
+            )
             self.audio_tensor: Union[torch.Tensor, None] = None
+            self.first_frame_latents: Union[torch.Tensor, None] = None
+            self.audio_latents: Union[torch.Tensor, None] = None

             # just for holding noise and preds during training
             self.audio_target: Union[torch.Tensor, None] = None
@@ -167,11 +197,41 @@ class DataLoaderBatchDTO:

             if not is_latents_cached:
                 # only return a tensor if latents are not cached
-                self.tensor: torch.Tensor = torch.cat([x.tensor.unsqueeze(0) for x in self.file_items])
+                self.tensor: torch.Tensor = torch.cat(
+                    [x.tensor.unsqueeze(0) for x in self.file_items]
+                )
             # if we have encoded latents, we concatenate them
             self.latents: Union[torch.Tensor, None] = None
             if is_latents_cached:
-                self.latents = torch.cat([x.get_latent().unsqueeze(0) for x in self.file_items])
+                # this get_latent call with trigger loading all cached items from the disk
+                self.latents = torch.cat(
+                    [x.get_latent().unsqueeze(0) for x in self.file_items]
+                )
+                if any(
+                    [x._cached_first_frame_latent is not None for x in self.file_items]
+                ):
+                    self.first_frame_latents = torch.cat(
+                        [
+                            x._cached_first_frame_latent.unsqueeze(0)
+                            if x._cached_first_frame_latent is not None
+                            else torch.zeros_like(
+                                self.file_items[0]._cached_first_frame_latent
+                            ).unsqueeze(0)
+                            for x in self.file_items
+                        ]
+                    )
+                if any([x._cached_audio_latent is not None for x in self.file_items]):
+                    self.audio_latents = torch.cat(
+                        [
+                            x._cached_audio_latent.unsqueeze(0)
+                            if x._cached_audio_latent is not None
+                            else torch.zeros_like(
+                                self.file_items[0]._cached_audio_latent
+                            ).unsqueeze(0)
+                            for x in self.file_items
+                        ]
+                    )

             self.prompt_embeds: Union[PromptEmbeds, None] = None
             # if self.file_items[0].control_tensor is not None:
             # if any have a control tensor, we concatenate them
@@ -188,7 +248,9 @@ class DataLoaderBatchDTO:
                         control_tensors.append(torch.zeros_like(base_control_tensor))
                     else:
                         control_tensors.append(x.control_tensor)
-                self.control_tensor = torch.cat([x.unsqueeze(0) for x in control_tensors])
+                self.control_tensor = torch.cat(
+                    [x.unsqueeze(0) for x in control_tensors]
+                )

             # handle control tensor list
             if any([x.control_tensor_list is not None for x in self.file_items]):
@@ -197,8 +259,9 @@ class DataLoaderBatchDTO:
                     if x.control_tensor_list is not None:
                         self.control_tensor_list.append(x.control_tensor_list)
                     else:
-                        raise Exception(f"Could not find control tensors for all file items, missing for {x.path}")
+                        raise Exception(
+                            f"Could not find control tensors for all file items, missing for {x.path}"
+                        )

             self.inpaint_tensor: Union[torch.Tensor, None] = None
             if any([x.inpaint_tensor is not None for x in self.file_items]):
@@ -214,9 +277,13 @@ class DataLoaderBatchDTO:
                         inpaint_tensors.append(torch.zeros_like(base_inpaint_tensor))
                     else:
                         inpaint_tensors.append(x.inpaint_tensor)
-                self.inpaint_tensor = torch.cat([x.unsqueeze(0) for x in inpaint_tensors])
+                self.inpaint_tensor = torch.cat(
+                    [x.unsqueeze(0) for x in inpaint_tensors]
+                )

-            self.loss_multiplier_list: List[float] = [x.loss_multiplier for x in self.file_items]
+            self.loss_multiplier_list: List[float] = [
+                x.loss_multiplier for x in self.file_items
+            ]

             if any([x.clip_image_tensor is not None for x in self.file_items]):
                 # find one to use as a base
@@ -228,10 +295,14 @@ class DataLoaderBatchDTO:
                 clip_image_tensors = []
                 for x in self.file_items:
                     if x.clip_image_tensor is None:
-                        clip_image_tensors.append(torch.zeros_like(base_clip_image_tensor))
+                        clip_image_tensors.append(
+                            torch.zeros_like(base_clip_image_tensor)
+                        )
                     else:
                         clip_image_tensors.append(x.clip_image_tensor)
-                self.clip_image_tensor = torch.cat([x.unsqueeze(0) for x in clip_image_tensors])
+                self.clip_image_tensor = torch.cat(
+                    [x.unsqueeze(0) for x in clip_image_tensors]
+                )

             if any([x.mask_tensor is not None for x in self.file_items]):
                 # find one to use as a base
@@ -259,10 +330,14 @@ class DataLoaderBatchDTO:
                 unaugmented_tensor = []
                 for x in self.file_items:
                     if x.unaugmented_tensor is None:
-                        unaugmented_tensor.append(torch.zeros_like(base_unaugmented_tensor))
+                        unaugmented_tensor.append(
+                            torch.zeros_like(base_unaugmented_tensor)
+                        )
                     else:
                         unaugmented_tensor.append(x.unaugmented_tensor)
-                self.unaugmented_tensor = torch.cat([x.unsqueeze(0) for x in unaugmented_tensor])
+                self.unaugmented_tensor = torch.cat(
+                    [x.unsqueeze(0) for x in unaugmented_tensor]
+                )

             # add unconditional tensors
             if any([x.unconditional_tensor is not None for x in self.file_items]):
@@ -275,10 +350,14 @@ class DataLoaderBatchDTO:
                 unconditional_tensor = []
                 for x in self.file_items:
                     if x.unconditional_tensor is None:
-                        unconditional_tensor.append(torch.zeros_like(base_unconditional_tensor))
+                        unconditional_tensor.append(
+                            torch.zeros_like(base_unconditional_tensor)
+                        )
                     else:
                         unconditional_tensor.append(x.unconditional_tensor)
-                self.unconditional_tensor = torch.cat([x.unsqueeze(0) for x in unconditional_tensor])
+                self.unconditional_tensor = torch.cat(
+                    [x.unsqueeze(0) for x in unconditional_tensor]
+                )

             if any([x.clip_image_embeds is not None for x in self.file_items]):
                 self.clip_image_embeds = []
@@ -288,13 +367,19 @@ class DataLoaderBatchDTO:
                     else:
                         raise Exception("clip_image_embeds is None for some file items")

-            if any([x.clip_image_embeds_unconditional is not None for x in self.file_items]):
+            if any(
+                [x.clip_image_embeds_unconditional is not None for x in self.file_items]
+            ):
                 self.clip_image_embeds_unconditional = []
                 for x in self.file_items:
                     if x.clip_image_embeds_unconditional is not None:
-                        self.clip_image_embeds_unconditional.append(x.clip_image_embeds_unconditional)
+                        self.clip_image_embeds_unconditional.append(
+                            x.clip_image_embeds_unconditional
+                        )
                     else:
-                        raise Exception("clip_image_embeds_unconditional is None for some file items")
+                        raise Exception(
+                            "clip_image_embeds_unconditional is None for some file items"
+                        )

             if any([x.prompt_embeds is not None for x in self.file_items]):
                 # find one to use as a base
@@ -331,7 +416,6 @@ class DataLoaderBatchDTO:
                     audio_tensors.append(x.audio_tensor)
                 self.audio_tensor = torch.cat([x.unsqueeze(0) for x in audio_tensors])

-
         except Exception as e:
             print(e)
             raise e
@@ -343,18 +427,12 @@ class DataLoaderBatchDTO:
         return [x.network_weight for x in self.file_items]

     def get_caption_list(
-            self,
-            trigger=None,
-            to_replace_list=None,
-            add_if_not_present=True
+        self, trigger=None, to_replace_list=None, add_if_not_present=True
     ):
         return [x.caption for x in self.file_items]

     def get_caption_short_list(
-            self,
-            trigger=None,
-            to_replace_list=None,
-            add_if_not_present=True
+        self, trigger=None, to_replace_list=None, add_if_not_present=True
     ):
         return [x.caption_short for x in self.file_items]

@@ -366,11 +444,13 @@ class DataLoaderBatchDTO:
         del self.audio_data
         del self.audio_target
         del self.audio_pred
+        del self.first_frame_latents
+        del self.audio_latents
         for file_item in self.file_items:
             file_item.cleanup()

     @property
-    def dataset_config(self) -> 'DatasetConfig':
+    def dataset_config(self) -> "DatasetConfig":
         if len(self.file_items) > 0:
             return self.file_items[0].dataset_config
         else:
@@ -456,8 +456,6 @@ class ImageProcessingDTOMixin:
         transform: Union[None, transforms.Compose],
         only_load_latents=False
     ):
-        if self.is_latent_cached:
-            raise Exception('Latent caching not supported for videos')

         if self.augments is not None and len(self.augments) > 0:
             raise Exception('Augments not supported for videos')
@@ -727,9 +725,6 @@ class ImageProcessingDTOMixin:
         transform: Union[None, transforms.Compose],
         only_load_latents=False
     ):
-        if self.dataset_config.num_frames > 1:
-            self.load_and_process_video(transform, only_load_latents)
-            return
         # handle get_prompt_embedding
         if self.is_text_embedding_cached:
             self.load_prompt_embedding()
@@ -747,6 +742,9 @@ class ImageProcessingDTOMixin:
             if self.has_unconditional:
                 self.load_unconditional_image()
             return
+        if self.dataset_config.num_frames > 1:
+            self.load_and_process_video(transform, only_load_latents)
+            return
         try:
             img = Image.open(self.path)
             img = exif_transpose(img)
@@ -1716,6 +1714,8 @@ class LatentCachingFileItemDTOMixin:
         if hasattr(super(), '__init__'):
             super().__init__(*args, **kwargs)
         self._encoded_latent: Union[torch.Tensor, None] = None
+        self._cached_first_frame_latent: Union[torch.Tensor, None] = None
+        self._cached_audio_latent: Union[torch.Tensor, None] = None
         self._latent_path: Union[str, None] = None
         self.is_latent_cached = False
         self.is_caching_to_disk = False
@@ -1745,6 +1745,14 @@ class LatentCachingFileItemDTOMixin:
             item["flip_y"] = True
         if self.dataset_config.num_frames > 1:
             item["num_frames"] = self.dataset_config.num_frames
+        if self.dataset_config.do_i2v:
+            item["do_i2v"] = True
+        if self.dataset_config.do_audio:
+            item["do_audio"] = True
+        if self.dataset_config.audio_normalize:
+            item["audio_normalize"] = True
+        if self.dataset_config.audio_preserve_pitch:
+            item["audio_preserve_pitch"] = True
         return item

     def get_latent_path(self: 'FileItemDTO', recalculate=False):
@@ -1769,9 +1777,15 @@ class LatentCachingFileItemDTOMixin:
         if not self.is_caching_to_memory:
             # we are caching on disk, don't save in memory
             self._encoded_latent = None
+            self._cached_first_frame_latent = None
+            self._cached_audio_latent = None
         else:
             # move it back to cpu
             self._encoded_latent = self._encoded_latent.to('cpu')
+            if self._cached_first_frame_latent is not None:
+                self._cached_first_frame_latent = self._cached_first_frame_latent.to('cpu')
+            if self._cached_audio_latent is not None:
+                self._cached_audio_latent = self._cached_audio_latent.to('cpu')

     def get_latent(self, device=None):
         if not self.is_latent_cached:
@@ -1784,6 +1798,10 @@ class LatentCachingFileItemDTOMixin:
                 device='cpu'
             )
             self._encoded_latent = state_dict['latent']
+            if 'first_frame_latent' in state_dict:
+                self._cached_first_frame_latent = state_dict['first_frame_latent']
+            if 'audio_latent' in state_dict:
+                self._cached_audio_latent = state_dict['audio_latent']
         return self._encoded_latent


@@ -1795,8 +1813,6 @@ class LatentCachingMixin:
         self.latent_cache = {}

     def cache_latents_all_latents(self: 'AiToolkitDataset'):
-        if self.dataset_config.num_frames > 1:
-            raise Exception("Error: caching latents is not supported for multi-frame datasets")
         with accelerator.main_process_first():
             print_acc(f"Caching latents for {self.dataset_path}")
             # cache all latents to disk
@@ -1839,25 +1855,50 @@ class LatentCachingMixin:
                     # load it into memory
                    state_dict = load_file(latent_path, device='cpu')
                    file_item._encoded_latent = state_dict['latent'].to('cpu', dtype=self.sd.torch_dtype)
+                    if 'first_frame_latent' in state_dict:
+                        file_item._cached_first_frame_latent = state_dict['first_frame_latent'].to('cpu', dtype=self.sd.torch_dtype)
+                    if 'audio_latent' in state_dict:
+                        file_item._cached_audio_latent = state_dict['audio_latent'].to('cpu', dtype=self.sd.torch_dtype)
                 else:
                     # not saved to disk, calculate
                     # load the image first
                     file_item.load_and_process_image(self.transform, only_load_latents=True)
                     dtype = self.sd.torch_dtype
                     device = self.sd.device_torch
+                    state_dict = OrderedDict()
+                    first_frame_latent = None
+                    audio_latent = None
                     # add batch dimension
                     try:
                         imgs = file_item.tensor.unsqueeze(0).to(device, dtype=dtype)
                         latent = self.sd.encode_images(imgs).squeeze(0)
+                        if to_disk:
+                            state_dict['latent'] = latent.clone().detach().cpu()
                     except Exception as e:
                         print_acc(f"Error processing image: {file_item.path}")
                         print_acc(f"Error: {str(e)}")
                         raise e
+                    # do first frame
+                    if self.dataset_config.num_frames > 1 and self.dataset_config.do_i2v:
+                        frames = file_item.tensor.unsqueeze(0).to(device, dtype=dtype)
+                        if len(frames.shape) == 4:
+                            first_frames = frames
+                        elif len(frames.shape) == 5:
+                            first_frames = frames[:, 0]
+                        else:
+                            raise ValueError(f"Unknown frame shape {frames.shape}")
+                        first_frame_latent = self.sd.encode_images(first_frames).squeeze(0)
+                        if to_disk:
+                            state_dict['first_frame_latent'] = first_frame_latent.clone().detach().cpu()
+
+                    # audio
+                    if file_item.audio_data is not None:
+                        audio_latent = self.sd.encode_audio([file_item.audio_data]).squeeze(0)
+                        if to_disk:
+                            state_dict['audio_latent'] = audio_latent.clone().detach().cpu()

                     # save_latent
                     if to_disk:
-                        state_dict = OrderedDict([
-                            ('latent', latent.clone().detach().cpu()),
-                        ])
                         # metadata
                         meta = get_meta_for_safetensors(file_item.get_latent_info_dict())
                         os.makedirs(os.path.dirname(latent_path), exist_ok=True)
@@ -1866,17 +1907,18 @@ class LatentCachingMixin:
                 if to_memory:
                     # keep it in memory
                     file_item._encoded_latent = latent.to('cpu', dtype=self.sd.torch_dtype)
+                    if first_frame_latent is not None:
+                        file_item._cached_first_frame_latent = first_frame_latent.to('cpu', dtype=self.sd.torch_dtype)
+                    if audio_latent is not None:
+                        file_item._cached_audio_latent = audio_latent.to('cpu', dtype=self.sd.torch_dtype)

                 del imgs
                 del latent
                 del file_item.tensor
+                file_item.cleanup()

-                # flush(garbage_collect=False)
                 file_item.is_latent_cached = True
                 i += 1
-                # flush every 100
-                # if i % 100 == 0:
-                #     flush()

             # restore device state
             self.sd.restore_device_state()
@@ -1123,6 +1123,10 @@ class BaseModel:

         return latents

+    def encode_audio(self, audio_data_list):
+        # audio_date_list is a list of {"waveform": waveform[C, L], "sample_rate": int(sample_rate)}
+        raise NotImplementedError("Audio encoding not implemented for this model.")
+
     def decode_latents(
         self,
         latents: torch.Tensor,
@@ -2551,6 +2551,10 @@ class StableDiffusion:

         return latents

+    def encode_audio(self, audio_data_list):
+        # audio_date_list is a list of {"waveform": waveform[C, L], "sample_rate": int(sample_rate)}
+        raise NotImplementedError("Audio encoding not implemented for this model.")
+
     def decode_latents(
         self,
         latents: torch.Tensor,
@@ -1 +1 @@
-VERSION = "0.7.18"
+VERSION = "0.7.19"