Mirror of https://github.com/ostris/ai-toolkit.git
Added ability to load video datasets and train with them
@@ -33,3 +33,4 @@ huggingface_hub
peft
gradio
python-slugify
opencv-python
@@ -13,7 +13,7 @@ from transformers import CLIPImageProcessor
sys.path.append(os.path.dirname(os.path.dirname(os.path.abspath(__file__))))
from toolkit.paths import SD_SCRIPTS_ROOT
import torchvision.transforms.functional
from toolkit.image_utils import show_img, show_tensors
from toolkit.image_utils import save_tensors, show_img, show_tensors

sys.path.append(SD_SCRIPTS_ROOT)

@@ -28,13 +28,18 @@ from tqdm import tqdm
parser = argparse.ArgumentParser()
parser.add_argument('dataset_folder', type=str, default='input')
parser.add_argument('--epochs', type=int, default=1)

parser.add_argument('--num_frames', type=int, default=1)
parser.add_argument('--output_path', type=str, default=None)


args = parser.parse_args()

if args.output_path is not None:
    args.output_path = os.path.abspath(args.output_path)
    os.makedirs(args.output_path, exist_ok=True)

dataset_folder = args.dataset_folder
resolution = 1024
resolution = 512
bucket_tolerance = 64
batch_size = 1

@@ -63,6 +68,8 @@ dataset_config = DatasetConfig(
    # clip_image_path='/mnt/Datasets2/regs/yetibear_xl_v14/random_aspect/',
    buckets=True,
    bucket_tolerance=bucket_tolerance,
    shrink_video_to_frames=True,
    num_frames=args.num_frames,
    # poi='person',
    # shuffle_augmentations=True,
    # augmentations=[
@@ -80,11 +87,17 @@ dataloader: DataLoader = get_dataloader_from_datasets([dataset_config], batch_si

# run through an epoch and check sizes
dataloader_iterator = iter(dataloader)
idx = 0
for epoch in range(args.epochs):
    for batch in tqdm(dataloader):
        batch: 'DataLoaderBatchDTO'
        img_batch = batch.tensor
        batch_size, channels, height, width = img_batch.shape
        frames = 1
        if len(img_batch.shape) == 5:
            frames = img_batch.shape[1]
            batch_size, frames, channels, height, width = img_batch.shape
        else:
            batch_size, channels, height, width = img_batch.shape

        # img_batch = color_block_imgs(img_batch, neg1_1=True)

@@ -110,15 +123,18 @@ for epoch in range(args.epochs):

big_img = img_batch
# big_img = big_img.clamp(-1, 1)
if args.output_path is not None:
    save_tensors(big_img, os.path.join(args.output_path, f'{idx}.png'))
else:
    show_tensors(big_img)

show_tensors(big_img)
# convert to image
# img = transforms.ToPILImage()(big_img)
#
# show_img(img)

# convert to image
# img = transforms.ToPILImage()(big_img)
#
# show_img(img)

time.sleep(0.2)
time.sleep(0.2)
idx += 1
# if not last epoch
if epoch < args.epochs - 1:
    trigger_dataloader_setup_epoch(dataloader)

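For reference, the shape handling above boils down to this standalone sketch with a made-up tensor: video batches carry an extra frames dimension that image batches do not.

import torch

batch_tensor = torch.randn(2, 16, 3, 512, 512)  # made-up video batch
if len(batch_tensor.shape) == 5:
    batch_size, frames, channels, height, width = batch_tensor.shape
else:
    frames = 1
    batch_size, channels, height, width = batch_tensor.shape
print(batch_size, frames, channels, height, width)  # 2 16 3 512 512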
@@ -56,51 +56,6 @@ resolutions_1024: List[BucketResolution] = [
    {"width": 128, "height": 8192},
]

# Even numbers so they can be patched easier
resolutions_dit_1024: List[BucketResolution] = [
    # Base resolution
    {"width": 1024, "height": 1024},
    # widescreen
    {"width": 2048, "height": 512},
    {"width": 1792, "height": 576},
    {"width": 1728, "height": 576},
    {"width": 1664, "height": 576},
    {"width": 1600, "height": 640},
    {"width": 1536, "height": 640},
    {"width": 1472, "height": 704},
    {"width": 1408, "height": 704},
    {"width": 1344, "height": 704},
    {"width": 1344, "height": 768},
    {"width": 1280, "height": 768},
    {"width": 1216, "height": 832},
    {"width": 1152, "height": 832},
    {"width": 1152, "height": 896},
    {"width": 1088, "height": 896},
    {"width": 1088, "height": 960},
    {"width": 1024, "height": 960},
    # portrait
    {"width": 960, "height": 1024},
    {"width": 960, "height": 1088},
    {"width": 896, "height": 1088},
    {"width": 896, "height": 1152}, # 2:3
    {"width": 832, "height": 1152},
    {"width": 832, "height": 1216},
    {"width": 768, "height": 1280},
    {"width": 768, "height": 1344},
    {"width": 704, "height": 1408},
    {"width": 704, "height": 1472},
    {"width": 640, "height": 1536},
    {"width": 640, "height": 1600},
    {"width": 576, "height": 1664},
    {"width": 576, "height": 1728},
    {"width": 576, "height": 1792},
    {"width": 512, "height": 1856},
    {"width": 512, "height": 1920},
    {"width": 512, "height": 1984},
    {"width": 512, "height": 2048},
]


def get_bucket_sizes(resolution: int = 512, divisibility: int = 8) -> List[BucketResolution]:
    # determine scaler from 1024 to resolution
    scaler = resolution / 1024
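For intuition, get_bucket_sizes scales each 1024-based bucket down to the requested resolution; a minimal sketch of that scaling (not the function's actual body, and the helper name is made up):

def scale_bucket(bucket: dict, resolution: int = 512, divisibility: int = 8) -> dict:
    # scale a 1024-based bucket to the target resolution, snapping to the divisibility
    scaler = resolution / 1024
    return {
        "width": int(bucket["width"] * scaler) // divisibility * divisibility,
        "height": int(bucket["height"] * scaler) // divisibility * divisibility,
    }

# e.g. scale_bucket({"width": 1344, "height": 768}, resolution=512) -> {"width": 672, "height": 384}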
@@ -171,4 +126,4 @@ def get_bucket_for_image_size(
    if closest_bucket is None:
        raise ValueError("No suitable bucket found")

    return closest_bucket
    return closest_bucket
@@ -763,6 +763,22 @@ class DatasetConfig:
        self.square_crop: bool = kwargs.get('square_crop', False)
        # apply the same augmentations to control images. Usually want this true unless it is a special case
        self.replay_transforms: bool = kwargs.get('replay_transforms', True)

        # for video
        # if num_frames is greater than 1, the dataloader will look for video files.
        # num_frames will be the number of frames in the training batch. If num_frames is 1, it will look for images
        self.num_frames: int = kwargs.get('num_frames', 1)
        # if true, will shrink the video to our frame count. For instance, if we have a video with 100 frames and num_frames is 10,
        # we would pull frame 0, 10, 20, 30, 40, 50, 60, 70, 80, 90 so they are evenly spaced
        self.shrink_video_to_frames: bool = kwargs.get('shrink_video_to_frames', True)
        # fps is only used if shrink_video_to_frames is false. This will attempt to pull num_frames at the given fps
        # it will select a random start frame and pull the frames at the given fps
        # this could have various issues with shorter videos and videos with variable fps
        # I recommend trimming your videos to the desired length and using shrink_video_to_frames (default)
        self.fps: int = kwargs.get('fps', 16)

        # debug the frame count and frame selection. You don't need this. It is for debugging.
        self.debug: bool = kwargs.get('debug', False)


def preprocess_dataset_raw_config(raw_config: List[dict]) -> List[dict]:

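As a rough usage sketch, the new video options could be set on a dataset config like this (the folder path and values are illustrative, not defaults from this commit):

video_dataset = DatasetConfig(
    folder_path='/path/to/videos',   # hypothetical dataset location
    buckets=True,                    # required for video processing
    num_frames=16,                   # >1 makes the dataloader look for video files
    shrink_video_to_frames=True,     # pull 16 evenly spaced frames from each clip
    # fps=16,                        # only used when shrink_video_to_frames is False
)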
@@ -30,6 +30,10 @@ def is_native_windows():

if TYPE_CHECKING:
    from toolkit.stable_diffusion_model import StableDiffusion


image_extensions = ['.jpg', '.jpeg', '.png', '.webp']
video_extensions = ['.mp4', '.avi', '.mov', '.webm', '.mkv', '.wmv', '.m4v', '.flv']


class RescaleTransform:
@@ -376,8 +380,9 @@ class AiToolkitDataset(LatentCachingMixin, CLIPCachingMixin, BucketsMixin, Capti
        batch_size=1,
        sd: 'StableDiffusion' = None,
    ):
        super().__init__()
        self.dataset_config = dataset_config
        self.is_video = dataset_config.num_frames > 1
        super().__init__()
        folder_path = dataset_config.folder_path
        self.dataset_path = dataset_config.dataset_path
        if self.dataset_path is None:
@@ -407,7 +412,11 @@ class AiToolkitDataset(LatentCachingMixin, CLIPCachingMixin, BucketsMixin, Capti

        # check if dataset_path is a folder or json
        if os.path.isdir(self.dataset_path):
            file_list = [os.path.join(root, file) for root, _, files in os.walk(self.dataset_path) for file in files if file.lower().endswith(('.jpg', '.jpeg', '.png', '.webp'))]
            extensions = image_extensions
            if self.is_video:
                # only look for videos
                extensions = video_extensions
            file_list = [os.path.join(root, file) for root, _, files in os.walk(self.dataset_path) for file in files if file.lower().endswith(tuple(extensions))]
        else:
            # assume json
            with open(self.dataset_path, 'r') as f:
@@ -438,7 +447,10 @@ class AiToolkitDataset(LatentCachingMixin, CLIPCachingMixin, BucketsMixin, Capti

        # this might take a while
        print_acc(f"Dataset: {self.dataset_path}")
        print_acc(f" - Preprocessing image dimensions")
        if self.is_video:
            print_acc(f" - Preprocessing video dimensions")
        else:
            print_acc(f" - Preprocessing image dimensions")
        dataset_folder = self.dataset_path
        if not os.path.isdir(self.dataset_path):
            dataset_folder = os.path.dirname(dataset_folder)
@@ -477,17 +489,23 @@ class AiToolkitDataset(LatentCachingMixin, CLIPCachingMixin, BucketsMixin, Capti
                self.file_list.append(file_item)
            except Exception as e:
                print_acc(traceback.format_exc())
                print_acc(f"Error processing image: {file}")
                if self.is_video:
                    print_acc(f"Error processing video: {file}")
                else:
                    print_acc(f"Error processing image: {file}")
                print_acc(e)
                bad_count += 1

        # save the size database
        with open(dataset_size_file, 'w') as f:
            json.dump(self.size_database, f)

        print_acc(f" - Found {len(self.file_list)} images")
        # print_acc(f" - Found {bad_count} images that are too small")
        assert len(self.file_list) > 0, f"no images found in {self.dataset_path}"

        if self.is_video:
            print_acc(f" - Found {len(self.file_list)} videos")
            assert len(self.file_list) > 0, f"no videos found in {self.dataset_path}"
        else:
            print_acc(f" - Found {len(self.file_list)} images")
            assert len(self.file_list) > 0, f"no images found in {self.dataset_path}"

        # handle x axis flips
        if self.dataset_config.flip_x:
@@ -510,8 +528,10 @@ class AiToolkitDataset(LatentCachingMixin, CLIPCachingMixin, BucketsMixin, Capti
            self.file_list.append(new_file_item)

        if self.dataset_config.flip_x or self.dataset_config.flip_y:
            print_acc(f" - Found {len(self.file_list)} images after adding flips")

            if self.is_video:
                print_acc(f" - Found {len(self.file_list)} videos after adding flips")
            else:
                print_acc(f" - Found {len(self.file_list)} images after adding flips")

        self.setup_epoch()

@@ -539,7 +559,7 @@ class AiToolkitDataset(LatentCachingMixin, CLIPCachingMixin, BucketsMixin, Capti
        return len(self.file_list)

    def _get_single_item(self, index) -> 'FileItemDTO':
        file_item = copy.deepcopy(self.file_list[index])
        file_item: 'FileItemDTO' = copy.deepcopy(self.file_list[index])
        file_item.load_and_process_image(self.transform)
        file_item.load_caption(self.caption_dict)
        return file_item

@@ -2,6 +2,7 @@ import os
import weakref
from _weakref import ReferenceType
from typing import TYPE_CHECKING, List, Union
import cv2
import torch
import random

@@ -43,6 +44,7 @@ class FileItemDTO(
    def __init__(self, *args, **kwargs):
        self.path = kwargs.get('path', '')
        self.dataset_config: 'DatasetConfig' = kwargs.get('dataset_config', None)
        self.is_video = self.dataset_config.num_frames > 1
        size_database = kwargs.get('size_database', {})
        dataset_root = kwargs.get('dataset_root', None)
        if dataset_root is not None:
@@ -52,6 +54,21 @@ class FileItemDTO(
        file_key = os.path.basename(self.path)
        if file_key in size_database:
            w, h = size_database[file_key]
        elif self.is_video:
            # Open the video file
            video = cv2.VideoCapture(self.path)

            # Check if video opened successfully
            if not video.isOpened():
                raise Exception(f"Error: Could not open video file {self.path}")

            # Get width and height
            width = int(video.get(cv2.CAP_PROP_FRAME_WIDTH))
            height = int(video.get(cv2.CAP_PROP_FRAME_HEIGHT))

            # Release the video capture object immediately
            video.release()
            size_database[file_key] = (width, height)
        else:
            # original method is significantly faster, but some images are read sideways. Not sure why. Do slow method for now.
            # process width and height

@@ -7,6 +7,7 @@ import os
import random
from collections import OrderedDict
from typing import TYPE_CHECKING, List, Dict, Union
import traceback

import cv2
import numpy as np
@@ -430,11 +431,205 @@ class CaptionProcessingDTOMixin:


class ImageProcessingDTOMixin:
    def load_and_process_video(
        self: 'FileItemDTO',
        transform: Union[None, transforms.Compose],
        only_load_latents=False
    ):
        if self.is_latent_cached:
            raise Exception('Latent caching not supported for videos')

        if self.augments is not None and len(self.augments) > 0:
            raise Exception('Augments not supported for videos')

        if self.has_augmentations:
            raise Exception('Augmentations not supported for videos')

        if not self.dataset_config.buckets:
            raise Exception('Buckets required for video processing')

        try:
            # Use OpenCV to capture video frames
            cap = cv2.VideoCapture(self.path)

            if not cap.isOpened():
                raise Exception(f"Failed to open video file: {self.path}")

            # Get video properties
            total_frames = int(cap.get(cv2.CAP_PROP_FRAME_COUNT))
            video_fps = cap.get(cv2.CAP_PROP_FPS)

            # Calculate the max valid frame index (accounting for zero-indexing)
            max_frame_index = total_frames - 1

            # Only log video properties if in debug mode
            if hasattr(self.dataset_config, 'debug') and self.dataset_config.debug:
                print_acc(f"Video properties: {self.path}")
                print_acc(f" Total frames: {total_frames}")
                print_acc(f" Max valid frame index: {max_frame_index}")
                print_acc(f" FPS: {video_fps}")

            frames_to_extract = []

            # Always stretch/shrink to the requested number of frames if needed
            if self.dataset_config.shrink_video_to_frames or total_frames < self.dataset_config.num_frames:
                # Distribute frames evenly across the entire video
                interval = max_frame_index / (self.dataset_config.num_frames - 1) if self.dataset_config.num_frames > 1 else 0
                frames_to_extract = [min(int(round(i * interval)), max_frame_index) for i in range(self.dataset_config.num_frames)]
            else:
                # Calculate frame interval based on FPS ratio
                fps_ratio = video_fps / self.dataset_config.fps
                frame_interval = max(1, int(round(fps_ratio)))

                # Calculate max consecutive frames we can extract at desired FPS
                max_consecutive_frames = (total_frames // frame_interval)

                if max_consecutive_frames < self.dataset_config.num_frames:
                    # Not enough frames at desired FPS, so stretch instead
                    interval = max_frame_index / (self.dataset_config.num_frames - 1) if self.dataset_config.num_frames > 1 else 0
                    frames_to_extract = [min(int(round(i * interval)), max_frame_index) for i in range(self.dataset_config.num_frames)]
                else:
                    # Calculate max start frame to ensure we can get all num_frames
                    max_start_frame = max_frame_index - ((self.dataset_config.num_frames - 1) * frame_interval)
                    start_frame = random.randint(0, max(0, max_start_frame))

                    # Generate list of frames to extract
                    frames_to_extract = [start_frame + (i * frame_interval) for i in range(self.dataset_config.num_frames)]

            # Final safety check - ensure no frame exceeds max valid index
            frames_to_extract = [min(frame_idx, max_frame_index) for frame_idx in frames_to_extract]

            # Only log frames to extract if in debug mode
            if hasattr(self.dataset_config, 'debug') and self.dataset_config.debug:
                print_acc(f" Frames to extract: {frames_to_extract}")

            # Extract frames
            frames = []
            for frame_idx in frames_to_extract:
                # Safety check - ensure frame_idx is within bounds (silently fix)
                if frame_idx > max_frame_index:
                    frame_idx = max_frame_index

                # Set frame position
                cap.set(cv2.CAP_PROP_POS_FRAMES, frame_idx)

                # Silently verify position was set correctly (no warnings unless debug mode)
                if hasattr(self.dataset_config, 'debug') and self.dataset_config.debug:
                    actual_pos = int(cap.get(cv2.CAP_PROP_POS_FRAMES))
                    if actual_pos != frame_idx:
                        print_acc(f"Warning: Failed to set exact frame position. Requested: {frame_idx}, Actual: {actual_pos}")

                ret, frame = cap.read()
                if not ret:
                    # Try to provide more detailed error information
                    actual_frame = int(cap.get(cv2.CAP_PROP_POS_FRAMES))
                    frame_pos_info = f"Requested frame: {frame_idx}, Actual frame position: {actual_frame}"

                    # Try to read the next available frame as a fallback
                    fallback_success = False
                    for fallback_offset in [1, -1, 5, -5, 10, -10]:
                        fallback_pos = max(0, min(frame_idx + fallback_offset, max_frame_index))
                        cap.set(cv2.CAP_PROP_POS_FRAMES, fallback_pos)
                        fallback_ret, fallback_frame = cap.read()
                        if fallback_ret:
                            # Only log in debug mode
                            if hasattr(self.dataset_config, 'debug') and self.dataset_config.debug:
                                print_acc(f"Falling back to nearby frame {fallback_pos} instead of {frame_idx}")
                            frame = fallback_frame
                            fallback_success = True
                            break
                    else:
                        # No fallback worked, raise a more detailed exception
                        video_info = f"Video: {self.path}, Total frames: {total_frames}, FPS: {video_fps}"
                        raise Exception(f"Failed to read frame {frame_idx} from video. {frame_pos_info}. {video_info}")

                # Convert BGR to RGB
                frame = cv2.cvtColor(frame, cv2.COLOR_BGR2RGB)

                # Convert to PIL Image
                img = Image.fromarray(frame)

                # Apply the same processing as for single images
                img = img.convert('RGB')

                if self.flip_x:
                    img = img.transpose(Image.FLIP_LEFT_RIGHT)
                if self.flip_y:
                    img = img.transpose(Image.FLIP_TOP_BOTTOM)

                # Apply bucketing
                img = img.resize((self.scale_to_width, self.scale_to_height), Image.BICUBIC)
                img = img.crop((
                    self.crop_x,
                    self.crop_y,
                    self.crop_x + self.crop_width,
                    self.crop_y + self.crop_height
                ))

                # Apply transform if provided
                if transform:
                    img = transform(img)

                frames.append(img)

            # Release the video capture
            cap.release()

            # Stack frames into tensor [frames, channels, height, width]
            self.tensor = torch.stack(frames)

            # Only log success in debug mode
            if hasattr(self.dataset_config, 'debug') and self.dataset_config.debug:
                print_acc(f"Successfully loaded video with {len(frames)} frames: {self.path}")

        except Exception as e:
            # Print full traceback
            traceback.print_exc()

            # Provide more context about the error
            error_msg = str(e)
            try:
                if 'Failed to read frame' in error_msg and cap is not None:
                    # Try to get more info about the video that failed
                    cap_status = "Opened" if cap.isOpened() else "Closed"
                    current_pos = int(cap.get(cv2.CAP_PROP_POS_FRAMES)) if cap.isOpened() else "Unknown"
                    reported_total = int(cap.get(cv2.CAP_PROP_FRAME_COUNT)) if cap.isOpened() else "Unknown"

                    print_acc(f"Video details when error occurred:")
                    print_acc(f" Cap status: {cap_status}")
                    print_acc(f" Current position: {current_pos}")
                    print_acc(f" Reported total frames: {reported_total}")

                    # Try to verify if the video is corrupted
                    if cap.isOpened():
                        cap.set(cv2.CAP_PROP_POS_FRAMES, 0)  # Go to start
                        start_ret, _ = cap.read()

                        # Try to read the last frame to check if it's accessible
                        if reported_total > 0:
                            cap.set(cv2.CAP_PROP_POS_FRAMES, reported_total - 1)
                            end_ret, _ = cap.read()
                            print_acc(f" Can read first frame: {start_ret}, Can read last frame: {end_ret}")

                    # Close the cap if it's still open
                    cap.release()
            except Exception as debug_err:
                print_acc(f"Error during error diagnosis: {debug_err}")

            print_acc(f"Error: {error_msg}")
            print_acc(f"Error loading video: {self.path}")

            # Re-raise with more detailed information
            raise Exception(f"Video loading error ({self.path}): {error_msg}") from e

    def load_and_process_image(
        self: 'FileItemDTO',
        transform: Union[None, transforms.Compose],
        only_load_latents=False
    ):
        if self.dataset_config.num_frames > 1:
            self.load_and_process_video(transform, only_load_latents)
            return
        # if we are caching latents, just do that
        if self.is_latent_cached:
            self.get_latent()
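A condensed sketch of the frame-selection logic above, without the OpenCV plumbing (assumes num_frames > 1 and omits the fallback for clips that are too short at the target fps; names mirror the code):

import random

def pick_frame_indices(total_frames, video_fps, num_frames, target_fps, shrink_to_frames=True):
    # return the frame indices the loader would read from a clip
    max_frame_index = total_frames - 1
    if shrink_to_frames or total_frames < num_frames:
        # spread num_frames evenly across the whole clip
        interval = max_frame_index / (num_frames - 1)
        return [min(int(round(i * interval)), max_frame_index) for i in range(num_frames)]
    # otherwise sample consecutive frames at target_fps from a random start
    frame_interval = max(1, int(round(video_fps / target_fps)))
    max_start = max(0, max_frame_index - (num_frames - 1) * frame_interval)
    start = random.randint(0, max_start)
    return [start + i * frame_interval for i in range(num_frames)]

# e.g. pick_frame_indices(100, 30.0, 10, 16) -> [0, 11, 22, 33, 44, 55, 66, 77, 88, 99]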
@@ -1379,6 +1574,8 @@ class LatentCachingMixin:
        self.latent_cache = {}

    def cache_latents_all_latents(self: 'AiToolkitDataset'):
        if self.dataset_config.num_frames > 1:
            raise Exception("Error: caching latents is not supported for multi-frame datasets")
        with accelerator.main_process_first():
            print_acc(f"Caching latents for {self.dataset_path}")
            # cache all latents to disk
@@ -1409,7 +1606,7 @@ class LatentCachingMixin:
            elif self.sd.model_config.is_pixart_sigma:
                file_item.latent_space_version = 'sdxl'
            else:
                file_item.latent_space_version = 'sd1'
            file_item.latent_space_version = self.sd.model_config.arch
            file_item.is_caching_to_disk = to_disk
            file_item.is_caching_to_memory = to_memory
            file_item.latent_load_device = self.sd.device

@@ -12,6 +12,7 @@ import cv2
import numpy as np
import torch
from diffusers import AutoencoderTiny
from PIL import Image as PILImage

FILE_UNKNOWN = "Sorry, don't know how to get size for this file."

@@ -480,7 +481,26 @@ def show_tensors(imgs: torch.Tensor, name='AI Toolkit'):
    img_numpy = img_numpy.astype(np.uint8)

    show_img(img_numpy[0], name=name)

def save_tensors(imgs: torch.Tensor, path='output.png'):
    if len(imgs.shape) == 5 and imgs.shape[0] == 1:
        imgs = imgs.squeeze(0)
    if len(imgs.shape) == 4:
        img_list = torch.chunk(imgs, imgs.shape[0], dim=0)
    else:
        img_list = [imgs]

    img = torch.cat(img_list, dim=3)
    img = img / 2 + 0.5
    img_numpy = img.to(torch.float32).detach().cpu().numpy()
    img_numpy = np.clip(img_numpy, 0, 1) * 255
    img_numpy = img_numpy.transpose(0, 2, 3, 1)
    img_numpy = img_numpy.astype(np.uint8)
    # concat images to one
    img_numpy = np.concatenate(img_numpy, axis=1)
    # convert to pil
    img_pil = PILImage.fromarray(img_numpy)
    img_pil.save(path)

def show_latents(latents: torch.Tensor, vae: 'AutoencoderTiny', name='AI Toolkit'):
    if vae.device == 'cpu':

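A small usage sketch for the new save_tensors helper (the tensor and output path are made up; real inputs come from the dataloader in [-1, 1]):

import torch

video_batch = torch.rand(1, 8, 3, 256, 256) * 2 - 1  # (batch, frames, channels, height, width)
save_tensors(video_batch, '/tmp/frames_strip.png')   # frames end up side by side in one image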
@@ -585,7 +585,6 @@ class Wan21(BaseModel):
        if dtype is None:
            dtype = self.vae_torch_dtype

        latent_list = []
        # Move the vae to device if it is on cpu
        if self.vae.device == 'cpu':
            self.vae.to(device)
@@ -593,18 +592,43 @@ class Wan21(BaseModel):
        self.vae.requires_grad_(False)
        # move to device and dtype
        image_list = [image.to(device, dtype=dtype) for image in image_list]

        # We need to detect video if we have it.
        # videos come in (num_frames, channels, height, width)
        # images come in (channels, height, width)
        # we need to add a frame dimension to images and remap the video to (channels, num_frames, height, width)

        if len(image_list[0].shape) == 3:
            image_list = [image.unsqueeze(1) for image in image_list]
        elif len(image_list[0].shape) == 4:
            image_list = [image.permute(1, 0, 2, 3) for image in image_list]
        else:
            raise ValueError(f"Image shape is not correct, got {list(image_list[0].shape)}")

        VAE_SCALE_FACTOR = 8

        # resize images if not divisible by 8
        # now we need to resize considering the shape (channels, num_frames, height, width)
        for i in range(len(image_list)):
            image = image_list[i]
            if image.shape[1] % VAE_SCALE_FACTOR != 0 or image.shape[2] % VAE_SCALE_FACTOR != 0:
                image_list[i] = Resize((image.shape[1] // VAE_SCALE_FACTOR * VAE_SCALE_FACTOR,
                                        image.shape[2] // VAE_SCALE_FACTOR * VAE_SCALE_FACTOR))(image)
            if image.shape[2] % VAE_SCALE_FACTOR != 0 or image.shape[3] % VAE_SCALE_FACTOR != 0:
                # Create resized frames by handling each frame separately
                c, f, h, w = image.shape
                target_h = h // VAE_SCALE_FACTOR * VAE_SCALE_FACTOR
                target_w = w // VAE_SCALE_FACTOR * VAE_SCALE_FACTOR

                # We need to process each frame separately
                resized_frames = []
                for frame_idx in range(f):
                    frame = image[:, frame_idx, :, :]  # Extract single frame (channels, height, width)
                    resized_frame = Resize((target_h, target_w))(frame)
                    resized_frames.append(resized_frame.unsqueeze(1))  # Add frame dimension back

                # Concatenate all frames back together along the frame dimension
                image_list[i] = torch.cat(resized_frames, dim=1)

        images = torch.stack(image_list)
        images = images.unsqueeze(2)
        # images = images.unsqueeze(2) # adds frame dimension so (bs, ch, h, w) -> (bs, ch, 1, h, w)
        latents = self.vae.encode(images).latent_dist.sample()

        latents_mean = (
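To make the shape handling above concrete, a small sketch of the remapping the encoder expects (shapes only, no VAE involved):

import torch

single_image = torch.randn(3, 512, 512)      # (channels, height, width)
video_clip = torch.randn(16, 3, 512, 512)    # (num_frames, channels, height, width)

as_image = single_image.unsqueeze(1)         # -> (3, 1, 512, 512): a one-frame "video"
as_video = video_clip.permute(1, 0, 2, 3)    # -> (3, 16, 512, 512): channels first, then frames

batch = torch.stack([as_video])              # -> (1, 3, 16, 512, 512), ready for a video VAE
print(as_image.shape, as_video.shape, batch.shape)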