diff --git a/requirements.txt b/requirements.txt
index 3fbf7da7..a9831469 100644
--- a/requirements.txt
+++ b/requirements.txt
@@ -33,3 +33,4 @@ huggingface_hub
 peft
 gradio
 python-slugify
+opencv-python
\ No newline at end of file
diff --git a/testing/test_bucket_dataloader.py b/testing/test_bucket_dataloader.py
index 31d97f2d..5904968d 100644
--- a/testing/test_bucket_dataloader.py
+++ b/testing/test_bucket_dataloader.py
@@ -13,7 +13,7 @@ from transformers import CLIPImageProcessor
 sys.path.append(os.path.dirname(os.path.dirname(os.path.abspath(__file__))))
 from toolkit.paths import SD_SCRIPTS_ROOT
 import torchvision.transforms.functional
-from toolkit.image_utils import show_img, show_tensors
+from toolkit.image_utils import save_tensors, show_img, show_tensors
 
 sys.path.append(SD_SCRIPTS_ROOT)
@@ -28,13 +28,18 @@ from tqdm import tqdm
 parser = argparse.ArgumentParser()
 parser.add_argument('dataset_folder', type=str, default='input')
 parser.add_argument('--epochs', type=int, default=1)
-
+parser.add_argument('--num_frames', type=int, default=1)
+parser.add_argument('--output_path', type=str, default=None)
 args = parser.parse_args()
 
+if args.output_path is not None:
+    args.output_path = os.path.abspath(args.output_path)
+    os.makedirs(args.output_path, exist_ok=True)
+
 dataset_folder = args.dataset_folder
-resolution = 1024
+resolution = 512
 bucket_tolerance = 64
 batch_size = 1
@@ -63,6 +68,8 @@ dataset_config = DatasetConfig(
     # clip_image_path='/mnt/Datasets2/regs/yetibear_xl_v14/random_aspect/',
     buckets=True,
     bucket_tolerance=bucket_tolerance,
+    shrink_video_to_frames=True,
+    num_frames=args.num_frames,
     # poi='person',
     # shuffle_augmentations=True,
     # augmentations=[
@@ -80,11 +87,17 @@ dataloader: DataLoader = get_dataloader_from_datasets([dataset_config], batch_si
 # run through an epoch ang check sizes
 dataloader_iterator = iter(dataloader)
+idx = 0
 for epoch in range(args.epochs):
     for batch in tqdm(dataloader):
         batch: 'DataLoaderBatchDTO'
         img_batch = batch.tensor
-        batch_size, channels, height, width = img_batch.shape
+        frames = 1
+        if len(img_batch.shape) == 5:
+            frames = img_batch.shape[1]
+            batch_size, frames, channels, height, width = img_batch.shape
+        else:
+            batch_size, channels, height, width = img_batch.shape
 
         # img_batch = color_block_imgs(img_batch, neg1_1=True)
@@ -110,15 +123,18 @@ for epoch in range(args.epochs):
         big_img = img_batch
         # big_img = big_img.clamp(-1, 1)
 
+        if args.output_path is not None:
+            save_tensors(big_img, os.path.join(args.output_path, f'{idx}.png'))
+        else:
+            show_tensors(big_img)
-        show_tensors(big_img)
+        # convert to image
+        # img = transforms.ToPILImage()(big_img)
+        #
+        # show_img(img)
-        # convert to image
-        # img = transforms.ToPILImage()(big_img)
-        #
-        # show_img(img)
-
-        time.sleep(0.2)
+        time.sleep(0.2)
+        idx += 1
 
     # if not last epoch
     if epoch < args.epochs - 1:
         trigger_dataloader_setup_epoch(dataloader)
diff --git a/toolkit/buckets.py b/toolkit/buckets.py
index 835c9eb9..3b0cbf19 100644
--- a/toolkit/buckets.py
+++ b/toolkit/buckets.py
@@ -56,51 +56,6 @@ resolutions_1024: List[BucketResolution] = [
     {"width": 128, "height": 8192},
 ]
 
-# Even numbers so they can be patched easier
-resolutions_dit_1024: List[BucketResolution] = [
-    # Base resolution
-    {"width": 1024, "height": 1024},
-    # widescreen
-    {"width": 2048, "height": 512},
-    {"width": 1792, "height": 576},
-    {"width": 1728, "height": 576},
-    {"width": 1664, "height": 576},
-    {"width": 1600, "height": 640},
-    {"width": 1536, "height": 640},
-    {"width": 1472, "height": 704},
-    {"width": 1408, "height": 704},
-    {"width": 1344, "height": 704},
-    {"width": 1344, "height": 768},
-    {"width": 1280, "height": 768},
-    {"width": 1216, "height": 832},
-    {"width": 1152, "height": 832},
-    {"width": 1152, "height": 896},
-    {"width": 1088, "height": 896},
-    {"width": 1088, "height": 960},
-    {"width": 1024, "height": 960},
-    # portrait
-    {"width": 960, "height": 1024},
-    {"width": 960, "height": 1088},
-    {"width": 896, "height": 1088},
-    {"width": 896, "height": 1152}, # 2:3
-    {"width": 832, "height": 1152},
-    {"width": 832, "height": 1216},
-    {"width": 768, "height": 1280},
-    {"width": 768, "height": 1344},
-    {"width": 704, "height": 1408},
-    {"width": 704, "height": 1472},
-    {"width": 640, "height": 1536},
-    {"width": 640, "height": 1600},
-    {"width": 576, "height": 1664},
-    {"width": 576, "height": 1728},
-    {"width": 576, "height": 1792},
-    {"width": 512, "height": 1856},
-    {"width": 512, "height": 1920},
-    {"width": 512, "height": 1984},
-    {"width": 512, "height": 2048},
-]
-
-
 def get_bucket_sizes(resolution: int = 512, divisibility: int = 8) -> List[BucketResolution]:
     # determine scaler form 1024 to resolution
     scaler = resolution / 1024
@@ -171,4 +126,4 @@ def get_bucket_for_image_size(
     if closest_bucket is None:
         raise ValueError("No suitable bucket found")
 
-    return closest_bucket
+    return closest_bucket
\ No newline at end of file
diff --git a/toolkit/config_modules.py b/toolkit/config_modules.py
index d67b79eb..a3bcd1ce 100644
--- a/toolkit/config_modules.py
+++ b/toolkit/config_modules.py
@@ -763,6 +763,22 @@ class DatasetConfig:
         self.square_crop: bool = kwargs.get('square_crop', False)
         # apply same augmentations to control images. Usually want this true unless special case
         self.replay_transforms: bool = kwargs.get('replay_transforms', True)
+
+        # for video
+        # if num_frames is greater than 1, the dataloader will look for video files.
+        # num_frames will be the number of frames in the training batch. If num_frames is 1, it will look for images.
+        self.num_frames: int = kwargs.get('num_frames', 1)
+        # if true, will shrink the video to num_frames. For instance, if we have a video with 100 frames and num_frames is 10,
+        # we pull 10 frames spaced evenly from the first frame to the last
+        self.shrink_video_to_frames: bool = kwargs.get('shrink_video_to_frames', True)
+        # fps is only used if shrink_video_to_frames is false. This will attempt to pull num_frames at the given fps.
+        # it will select a random start frame and pull the frames at the given fps
+        # this could have various issues with shorter videos and videos with variable fps
+        # I recommend trimming your videos to the desired length and using shrink_video_to_frames (the default)
+        self.fps: int = kwargs.get('fps', 16)
+
+        # debug the frame count and frame selection. You don't need this; it is only for debugging.
+        self.debug: bool = kwargs.get('debug', False)
 
 
 def preprocess_dataset_raw_config(raw_config: List[dict]) -> List[dict]:
diff --git a/toolkit/data_loader.py b/toolkit/data_loader.py
index 1fd6a3b8..d3845d2a 100644
--- a/toolkit/data_loader.py
+++ b/toolkit/data_loader.py
@@ -30,6 +30,10 @@ def is_native_windows():
 
 if TYPE_CHECKING:
     from toolkit.stable_diffusion_model import StableDiffusion
+
+
+image_extensions = ['.jpg', '.jpeg', '.png', '.webp']
+video_extensions = ['.mp4', '.avi', '.mov', '.webm', '.mkv', '.wmv', '.m4v', '.flv']
 
 
 class RescaleTransform:
@@ -376,8 +380,9 @@ class AiToolkitDataset(LatentCachingMixin, CLIPCachingMixin, BucketsMixin, Capti
             batch_size=1,
             sd: 'StableDiffusion' = None,
     ):
-        super().__init__()
         self.dataset_config = dataset_config
+        self.is_video = dataset_config.num_frames > 1
+        super().__init__()
         folder_path = dataset_config.folder_path
         self.dataset_path = dataset_config.dataset_path
         if self.dataset_path is None:
@@ -407,7 +412,11 @@ class AiToolkitDataset(LatentCachingMixin, CLIPCachingMixin, BucketsMixin, Capti
 
         # check if dataset_path is a folder or json
         if os.path.isdir(self.dataset_path):
-            file_list = [os.path.join(root, file) for root, _, files in os.walk(self.dataset_path) for file in files if file.lower().endswith(('.jpg', '.jpeg', '.png', '.webp'))]
+            extensions = image_extensions
+            if self.is_video:
+                # only look for videos
+                extensions = video_extensions
+            file_list = [os.path.join(root, file) for root, _, files in os.walk(self.dataset_path) for file in files if file.lower().endswith(tuple(extensions))]
         else:
             # assume json
             with open(self.dataset_path, 'r') as f:
@@ -438,7 +447,10 @@ class AiToolkitDataset(LatentCachingMixin, CLIPCachingMixin, BucketsMixin, Capti
         # this might take a while
         print_acc(f"Dataset: {self.dataset_path}")
-        print_acc(f" - Preprocessing image dimensions")
+        if self.is_video:
+            print_acc(f" - Preprocessing video dimensions")
+        else:
+            print_acc(f" - Preprocessing image dimensions")
         dataset_folder = self.dataset_path
         if not os.path.isdir(self.dataset_path):
             dataset_folder = os.path.dirname(dataset_folder)
@@ -477,17 +489,23 @@ class AiToolkitDataset(LatentCachingMixin, CLIPCachingMixin, BucketsMixin, Capti
                     self.file_list.append(file_item)
                 except Exception as e:
                     print_acc(traceback.format_exc())
-                    print_acc(f"Error processing image: {file}")
+                    if self.is_video:
+                        print_acc(f"Error processing video: {file}")
+                    else:
+                        print_acc(f"Error processing image: {file}")
                     print_acc(e)
                     bad_count += 1
 
         # save the size database
         with open(dataset_size_file, 'w') as f:
             json.dump(self.size_database, f)
-
-        print_acc(f" - Found {len(self.file_list)} images")
-        # print_acc(f" - Found {bad_count} images that are too small")
-        assert len(self.file_list) > 0, f"no images found in {self.dataset_path}"
+
+        if self.is_video:
+            print_acc(f" - Found {len(self.file_list)} videos")
+            assert len(self.file_list) > 0, f"no videos found in {self.dataset_path}"
+        else:
+            print_acc(f" - Found {len(self.file_list)} images")
+            assert len(self.file_list) > 0, f"no images found in {self.dataset_path}"
 
         # handle x axis flips
         if self.dataset_config.flip_x:
@@ -510,8 +528,10 @@ class AiToolkitDataset(LatentCachingMixin, CLIPCachingMixin, BucketsMixin, Capti
                 self.file_list.append(new_file_item)
 
         if self.dataset_config.flip_x or self.dataset_config.flip_y:
-            print_acc(f" - Found {len(self.file_list)} images after adding flips")
-
+            if self.is_video:
+                print_acc(f" - Found {len(self.file_list)} videos after adding flips")
+            else:
+                print_acc(f" - Found {len(self.file_list)} images after adding flips")
 
         self.setup_epoch()
@@ -539,7 +559,7 @@ class AiToolkitDataset(LatentCachingMixin, CLIPCachingMixin, BucketsMixin, Capti
         return len(self.file_list)
 
     def _get_single_item(self, index) -> 'FileItemDTO':
-        file_item = copy.deepcopy(self.file_list[index])
+        file_item: 'FileItemDTO' = copy.deepcopy(self.file_list[index])
         file_item.load_and_process_image(self.transform)
         file_item.load_caption(self.caption_dict)
         return file_item
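Worth noting: when num_frames is greater than 1 the walker above only collects files with video extensions, so images in the same folder are silently ignored. A quick standalone check (paths are placeholders) that mirrors the same filter:

import os

video_extensions = ['.mp4', '.avi', '.mov', '.webm', '.mkv', '.wmv', '.m4v', '.flv']
dataset_path = '/path/to/video/clips'  # placeholder path

# same os.walk + extension filter as the dataset loader above
videos = [
    os.path.join(root, f)
    for root, _, files in os.walk(dataset_path)
    for f in files
    if f.lower().endswith(tuple(video_extensions))
]
print(f"{len(videos)} video files would be used from {dataset_path}")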
diff --git a/toolkit/data_transfer_object/data_loader.py b/toolkit/data_transfer_object/data_loader.py
index 34239f40..095fa014 100644
--- a/toolkit/data_transfer_object/data_loader.py
+++ b/toolkit/data_transfer_object/data_loader.py
@@ -2,6 +2,7 @@ import os
 import weakref
 from _weakref import ReferenceType
 from typing import TYPE_CHECKING, List, Union
+import cv2
 
 import torch
 import random
@@ -43,6 +44,7 @@ class FileItemDTO(
     def __init__(self, *args, **kwargs):
         self.path = kwargs.get('path', '')
         self.dataset_config: 'DatasetConfig' = kwargs.get('dataset_config', None)
+        self.is_video = self.dataset_config.num_frames > 1
         size_database = kwargs.get('size_database', {})
         dataset_root = kwargs.get('dataset_root', None)
         if dataset_root is not None:
@@ -52,6 +54,21 @@ class FileItemDTO(
         file_key = os.path.basename(self.path)
         if file_key in size_database:
             w, h = size_database[file_key]
+        elif self.is_video:
+            # Open the video file
+            video = cv2.VideoCapture(self.path)
+
+            # Check if video opened successfully
+            if not video.isOpened():
+                raise Exception(f"Error: Could not open video file {self.path}")
+
+            # Get width and height
+            width = int(video.get(cv2.CAP_PROP_FRAME_WIDTH))
+            height = int(video.get(cv2.CAP_PROP_FRAME_HEIGHT))
+
+            # Release the video capture object immediately
+            video.release()
+            size_database[file_key] = (width, height)
         else:
             # original method is significantly faster, but some images are read sideways. Not sure why. Do slow method for now.
             # process width and height
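The same cv2 probe can be useful on its own when curating a dataset, for example to confirm a clip has enough frames for a chosen num_frames and fps. A hypothetical helper, not part of this change, using only the OpenCV properties already read above and in the frame loader below:

import cv2

def probe_video(path: str) -> dict:
    # Open the clip, read its basic properties, then release the handle
    cap = cv2.VideoCapture(path)
    if not cap.isOpened():
        raise Exception(f"Error: Could not open video file {path}")
    info = {
        'width': int(cap.get(cv2.CAP_PROP_FRAME_WIDTH)),
        'height': int(cap.get(cv2.CAP_PROP_FRAME_HEIGHT)),
        'total_frames': int(cap.get(cv2.CAP_PROP_FRAME_COUNT)),
        'fps': cap.get(cv2.CAP_PROP_FPS),
    }
    cap.release()
    return info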
diff --git a/toolkit/dataloader_mixins.py b/toolkit/dataloader_mixins.py
index 3e655554..f57fb263 100644
--- a/toolkit/dataloader_mixins.py
+++ b/toolkit/dataloader_mixins.py
@@ -7,6 +7,7 @@ import os
 import random
 from collections import OrderedDict
 from typing import TYPE_CHECKING, List, Dict, Union
+import traceback
 
 import cv2
 import numpy as np
@@ -430,11 +431,205 @@ class CaptionProcessingDTOMixin:
 
 
 class ImageProcessingDTOMixin:
+    def load_and_process_video(
+            self: 'FileItemDTO',
+            transform: Union[None, transforms.Compose],
+            only_load_latents=False
+    ):
+        if self.is_latent_cached:
+            raise Exception('Latent caching not supported for videos')
+
+        if self.augments is not None and len(self.augments) > 0:
+            raise Exception('Augments not supported for videos')
+
+        if self.has_augmentations:
+            raise Exception('Augmentations not supported for videos')
+
+        if not self.dataset_config.buckets:
+            raise Exception('Buckets required for video processing')
+
+        try:
+            # Use OpenCV to capture video frames
+            cap = cv2.VideoCapture(self.path)
+
+            if not cap.isOpened():
+                raise Exception(f"Failed to open video file: {self.path}")
+
+            # Get video properties
+            total_frames = int(cap.get(cv2.CAP_PROP_FRAME_COUNT))
+            video_fps = cap.get(cv2.CAP_PROP_FPS)
+
+            # Calculate the max valid frame index (accounting for zero-indexing)
+            max_frame_index = total_frames - 1
+
+            # Only log video properties if in debug mode
+            if hasattr(self.dataset_config, 'debug') and self.dataset_config.debug:
+                print_acc(f"Video properties: {self.path}")
+                print_acc(f"  Total frames: {total_frames}")
+                print_acc(f"  Max valid frame index: {max_frame_index}")
+                print_acc(f"  FPS: {video_fps}")
+
+            frames_to_extract = []
+
+            # Always stretch/shrink to the requested number of frames if needed
+            if self.dataset_config.shrink_video_to_frames or total_frames < self.dataset_config.num_frames:
+                # Distribute frames evenly across the entire video
+                interval = max_frame_index / (self.dataset_config.num_frames - 1) if self.dataset_config.num_frames > 1 else 0
+                frames_to_extract = [min(int(round(i * interval)), max_frame_index) for i in range(self.dataset_config.num_frames)]
+            else:
+                # Calculate frame interval based on FPS ratio
+                fps_ratio = video_fps / self.dataset_config.fps
+                frame_interval = max(1, int(round(fps_ratio)))
+
+                # Calculate max consecutive frames we can extract at desired FPS
+                max_consecutive_frames = (total_frames // frame_interval)
+
+                if max_consecutive_frames < self.dataset_config.num_frames:
+                    # Not enough frames at desired FPS, so stretch instead
+                    interval = max_frame_index / (self.dataset_config.num_frames - 1) if self.dataset_config.num_frames > 1 else 0
+                    frames_to_extract = [min(int(round(i * interval)), max_frame_index) for i in range(self.dataset_config.num_frames)]
+                else:
+                    # Calculate max start frame to ensure we can get all num_frames
+                    max_start_frame = max_frame_index - ((self.dataset_config.num_frames - 1) * frame_interval)
+                    start_frame = random.randint(0, max(0, max_start_frame))
+
+                    # Generate list of frames to extract
+                    frames_to_extract = [start_frame + (i * frame_interval) for i in range(self.dataset_config.num_frames)]
+
+            # Final safety check - ensure no frame exceeds max valid index
+            frames_to_extract = [min(frame_idx, max_frame_index) for frame_idx in frames_to_extract]
+
+            # Only log frames to extract if in debug mode
+            if hasattr(self.dataset_config, 'debug') and self.dataset_config.debug:
+                print_acc(f"  Frames to extract: {frames_to_extract}")
+
+            # Extract frames
+            frames = []
+            for frame_idx in frames_to_extract:
+                # Safety check - ensure frame_idx is within bounds (silently fix)
+                if frame_idx > max_frame_index:
+                    frame_idx = max_frame_index
+
+                # Set frame position
+                cap.set(cv2.CAP_PROP_POS_FRAMES, frame_idx)
+
+                # Silently verify position was set correctly (no warnings unless debug mode)
+                if hasattr(self.dataset_config, 'debug') and self.dataset_config.debug:
+                    actual_pos = int(cap.get(cv2.CAP_PROP_POS_FRAMES))
+                    if actual_pos != frame_idx:
+                        print_acc(f"Warning: Failed to set exact frame position. Requested: {frame_idx}, Actual: {actual_pos}")
+
+                ret, frame = cap.read()
+                if not ret:
+                    # Try to provide more detailed error information
+                    actual_frame = int(cap.get(cv2.CAP_PROP_POS_FRAMES))
+                    frame_pos_info = f"Requested frame: {frame_idx}, Actual frame position: {actual_frame}"
+
+                    # Try to read the next available frame as a fallback
+                    fallback_success = False
+                    for fallback_offset in [1, -1, 5, -5, 10, -10]:
+                        fallback_pos = max(0, min(frame_idx + fallback_offset, max_frame_index))
+                        cap.set(cv2.CAP_PROP_POS_FRAMES, fallback_pos)
+                        fallback_ret, fallback_frame = cap.read()
+                        if fallback_ret:
+                            # Only log in debug mode
+                            if hasattr(self.dataset_config, 'debug') and self.dataset_config.debug:
+                                print_acc(f"Falling back to nearby frame {fallback_pos} instead of {frame_idx}")
+                            frame = fallback_frame
+                            fallback_success = True
+                            break
+                    else:
+                        # No fallback worked, raise a more detailed exception
+                        video_info = f"Video: {self.path}, Total frames: {total_frames}, FPS: {video_fps}"
+                        raise Exception(f"Failed to read frame {frame_idx} from video. {frame_pos_info}. {video_info}")
+
+                # Convert BGR to RGB
+                frame = cv2.cvtColor(frame, cv2.COLOR_BGR2RGB)
+
+                # Convert to PIL Image
+                img = Image.fromarray(frame)
+
+                # Apply the same processing as for single images
+                img = img.convert('RGB')
+
+                if self.flip_x:
+                    img = img.transpose(Image.FLIP_LEFT_RIGHT)
+                if self.flip_y:
+                    img = img.transpose(Image.FLIP_TOP_BOTTOM)
+
+                # Apply bucketing
+                img = img.resize((self.scale_to_width, self.scale_to_height), Image.BICUBIC)
+                img = img.crop((
+                    self.crop_x,
+                    self.crop_y,
+                    self.crop_x + self.crop_width,
+                    self.crop_y + self.crop_height
+                ))
+
+                # Apply transform if provided
+                if transform:
+                    img = transform(img)
+
+                frames.append(img)
+
+            # Release the video capture
+            cap.release()
+
+            # Stack frames into tensor [frames, channels, height, width]
+            self.tensor = torch.stack(frames)
+
+            # Only log success in debug mode
+            if hasattr(self.dataset_config, 'debug') and self.dataset_config.debug:
+                print_acc(f"Successfully loaded video with {len(frames)} frames: {self.path}")
+
+        except Exception as e:
+            # Print full traceback
+            traceback.print_exc()
+
+            # Provide more context about the error
+            error_msg = str(e)
+            try:
+                if 'Failed to read frame' in error_msg and cap is not None:
+                    # Try to get more info about the video that failed
+                    cap_status = "Opened" if cap.isOpened() else "Closed"
+                    current_pos = int(cap.get(cv2.CAP_PROP_POS_FRAMES)) if cap.isOpened() else "Unknown"
+                    reported_total = int(cap.get(cv2.CAP_PROP_FRAME_COUNT)) if cap.isOpened() else "Unknown"
+
+                    print_acc(f"Video details when error occurred:")
+                    print_acc(f"  Cap status: {cap_status}")
+                    print_acc(f"  Current position: {current_pos}")
+                    print_acc(f"  Reported total frames: {reported_total}")
+
+                    # Try to verify if the video is corrupted
+                    if cap.isOpened():
+                        cap.set(cv2.CAP_PROP_POS_FRAMES, 0)  # Go to start
+                        start_ret, _ = cap.read()
+
+                        # Try to read the last frame to check if it's accessible
+                        if reported_total > 0:
+                            cap.set(cv2.CAP_PROP_POS_FRAMES, reported_total - 1)
+                            end_ret, _ = cap.read()
+                            print_acc(f"  Can read first frame: {start_ret}, Can read last frame: {end_ret}")
+
+                # Close the cap if it's still open
+                cap.release()
+            except Exception as debug_err:
+                print_acc(f"Error during error diagnosis: {debug_err}")
+
+            print_acc(f"Error: {error_msg}")
+            print_acc(f"Error loading video: {self.path}")
+
+            # Re-raise with more detailed information
+            raise Exception(f"Video loading error ({self.path}): {error_msg}") from e
+
     def load_and_process_image(
             self: 'FileItemDTO',
             transform: Union[None, transforms.Compose],
             only_load_latents=False
     ):
+        if self.dataset_config.num_frames > 1:
+            self.load_and_process_video(transform, only_load_latents)
+            return
         # if we are caching latents, just do that
         if self.is_latent_cached:
             self.get_latent()
@@ -1379,6 +1574,8 @@ class LatentCachingMixin:
         self.latent_cache = {}
 
     def cache_latents_all_latents(self: 'AiToolkitDataset'):
+        if self.dataset_config.num_frames > 1:
+            raise Exception("Error: caching latents is not supported for multi-frame datasets")
         with accelerator.main_process_first():
             print_acc(f"Caching latents for {self.dataset_path}")
             # cache all latents to disk
@@ -1409,7 +1606,7 @@ class LatentCachingMixin:
                 elif self.sd.model_config.is_pixart_sigma:
                     file_item.latent_space_version = 'sdxl'
                 else:
-                    file_item.latent_space_version = 'sd1'
+                    file_item.latent_space_version = self.sd.model_config.arch
                 file_item.is_caching_to_disk = to_disk
                 file_item.is_caching_to_memory = to_memory
                 file_item.latent_load_device = self.sd.device
diff --git a/toolkit/image_utils.py b/toolkit/image_utils.py
index 9b9f306e..0c536c43 100644
--- a/toolkit/image_utils.py
+++ b/toolkit/image_utils.py
@@ -12,6 +12,7 @@ import cv2
 import numpy as np
 import torch
 from diffusers import AutoencoderTiny
+from PIL import Image as PILImage
 
 FILE_UNKNOWN = "Sorry, don't know how to get size for this file."
@@ -480,7 +481,26 @@ def show_tensors(imgs: torch.Tensor, name='AI Toolkit'):
     img_numpy = img_numpy.astype(np.uint8)
 
     show_img(img_numpy[0], name=name)
 
+
+def save_tensors(imgs: torch.Tensor, path='output.png'):
+    if len(imgs.shape) == 5 and imgs.shape[0] == 1:
+        imgs = imgs.squeeze(0)
+    if len(imgs.shape) == 4:
+        img_list = torch.chunk(imgs, imgs.shape[0], dim=0)
+    else:
+        img_list = [imgs]
+    img = torch.cat(img_list, dim=3)
+    img = img / 2 + 0.5
+    img_numpy = img.to(torch.float32).detach().cpu().numpy()
+    img_numpy = np.clip(img_numpy, 0, 1) * 255
+    img_numpy = img_numpy.transpose(0, 2, 3, 1)
+    img_numpy = img_numpy.astype(np.uint8)
+    # concat images to one
+    img_numpy = np.concatenate(img_numpy, axis=1)
+    # convert to PIL
+    img_pil = PILImage.fromarray(img_numpy)
+    img_pil.save(path)
 
 def show_latents(latents: torch.Tensor, vae: 'AutoencoderTiny', name='AI Toolkit'):
     if vae.device == 'cpu':
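For reference, save_tensors tiles everything into a single strip: a 5D batch of shape (batch, frames, channels, height, width) with batch size 1 is squeezed and the frames are laid out side by side in one PNG. A small illustrative call, where the random tensor stands in for batch.tensor from the dataloader with values in [-1, 1]:

import torch
from toolkit.image_utils import save_tensors

# stand-in for a video batch from the dataloader: (bs, frames, channels, height, width)
video_batch = torch.rand(1, 8, 3, 256, 256) * 2 - 1
save_tensors(video_batch, 'preview.png')  # writes the 8 frames tiled horizontally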
diff --git a/toolkit/models/wan21/wan21.py b/toolkit/models/wan21/wan21.py
index 886cbcc6..e6c52f10 100644
--- a/toolkit/models/wan21/wan21.py
+++ b/toolkit/models/wan21/wan21.py
@@ -585,7 +585,6 @@ class Wan21(BaseModel):
         if dtype is None:
             dtype = self.vae_torch_dtype
 
-        latent_list = []
         # Move to vae to device if on cpu
         if self.vae.device == 'cpu':
             self.vae.to(device)
@@ -593,18 +592,43 @@ class Wan21(BaseModel):
             self.vae.requires_grad_(False)
         # move to device and dtype
         image_list = [image.to(device, dtype=dtype) for image in image_list]
+
+        # We need to detect video if we have it.
+        # videos come in (num_frames, channels, height, width)
+        # images come in (channels, height, width)
+        # we need to add a frame dimension to images and remap the video to (channels, num_frames, height, width)
+
+        if len(image_list[0].shape) == 3:
+            image_list = [image.unsqueeze(1) for image in image_list]
+        elif len(image_list[0].shape) == 4:
+            image_list = [image.permute(1, 0, 2, 3) for image in image_list]
+        else:
+            raise ValueError(f"Image shape is not correct, got {list(image_list[0].shape)}")
 
         VAE_SCALE_FACTOR = 8
 
         # resize images if not divisible by 8
+        # now we need to resize considering the shape (channels, num_frames, height, width)
         for i in range(len(image_list)):
             image = image_list[i]
-            if image.shape[1] % VAE_SCALE_FACTOR != 0 or image.shape[2] % VAE_SCALE_FACTOR != 0:
-                image_list[i] = Resize((image.shape[1] // VAE_SCALE_FACTOR * VAE_SCALE_FACTOR,
-                                        image.shape[2] // VAE_SCALE_FACTOR * VAE_SCALE_FACTOR))(image)
+            if image.shape[2] % VAE_SCALE_FACTOR != 0 or image.shape[3] % VAE_SCALE_FACTOR != 0:
+                # Create resized frames by handling each frame separately
+                c, f, h, w = image.shape
+                target_h = h // VAE_SCALE_FACTOR * VAE_SCALE_FACTOR
+                target_w = w // VAE_SCALE_FACTOR * VAE_SCALE_FACTOR
+
+                # We need to process each frame separately
+                resized_frames = []
+                for frame_idx in range(f):
+                    frame = image[:, frame_idx, :, :]  # Extract single frame (channels, height, width)
+                    resized_frame = Resize((target_h, target_w))(frame)
+                    resized_frames.append(resized_frame.unsqueeze(1))  # Add frame dimension back
+
+                # Concatenate all frames back together along the frame dimension
+                image_list[i] = torch.cat(resized_frames, dim=1)
 
         images = torch.stack(image_list)
-        images = images.unsqueeze(2)
+        # images = images.unsqueeze(2) # adds frame dimension so (bs, ch, h, w) -> (bs, ch, 1, h, w)
         latents = self.vae.encode(images).latent_dist.sample()
 
         latents_mean = (