Added ability to load video datasets and train with them

Author: Jaret Burkett
Date: 2025-03-19 09:54:26 -06:00
Parent: fa187b1208
Commit: b829983b16
9 changed files with 340 additions and 74 deletions


@@ -30,6 +30,10 @@ def is_native_windows():
 if TYPE_CHECKING:
     from toolkit.stable_diffusion_model import StableDiffusion

+image_extensions = ['.jpg', '.jpeg', '.png', '.webp']
+video_extensions = ['.mp4', '.avi', '.mov', '.webm', '.mkv', '.wmv', '.m4v', '.flv']

 class RescaleTransform:
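
A note on how these lists are consumed: filenames are lowercased and tested with str.endswith, which accepts a tuple of suffixes, so one call covers every extension. A minimal sketch of that matching (the sample filenames are invented for illustration):

    video_extensions = ['.mp4', '.avi', '.mov', '.webm', '.mkv', '.wmv', '.m4v', '.flv']

    # hypothetical filenames, only to show the case-insensitive match
    sample_files = ['clip_001.MP4', 'clip_002.webm', 'notes.txt']

    videos = [f for f in sample_files if f.lower().endswith(tuple(video_extensions))]
    print(videos)  # ['clip_001.MP4', 'clip_002.webm']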
@@ -376,8 +380,9 @@ class AiToolkitDataset(LatentCachingMixin, CLIPCachingMixin, BucketsMixin, Capti
             batch_size=1,
             sd: 'StableDiffusion' = None,
     ):
-        super().__init__()
         self.dataset_config = dataset_config
+        self.is_video = dataset_config.num_frames > 1
+        super().__init__()
         folder_path = dataset_config.folder_path
         self.dataset_path = dataset_config.dataset_path
         if self.dataset_path is None:
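
The is_video flag is derived entirely from the dataset config: any num_frames greater than 1 puts the dataset into video mode. A sketch of that decision in isolation (this DatasetConfig stand-in is hypothetical; only num_frames and folder_path appear in the hunk above):

    from dataclasses import dataclass

    @dataclass
    class DatasetConfig:  # hypothetical stand-in for the real config class
        folder_path: str
        num_frames: int = 1  # 1 = single images, >1 = video clips

    cfg = DatasetConfig(folder_path='/data/clips', num_frames=16)
    is_video = cfg.num_frames > 1  # same test the diff adds in __init__
    print(is_video)  # True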
@@ -407,7 +412,11 @@ class AiToolkitDataset(LatentCachingMixin, CLIPCachingMixin, BucketsMixin, Capti
         # check if dataset_path is a folder or json
         if os.path.isdir(self.dataset_path):
-            file_list = [os.path.join(root, file) for root, _, files in os.walk(self.dataset_path) for file in files if file.lower().endswith(('.jpg', '.jpeg', '.png', '.webp'))]
+            extensions = image_extensions
+            if self.is_video:
+                # only look for videos
+                extensions = video_extensions
+            file_list = [os.path.join(root, file) for root, _, files in os.walk(self.dataset_path) for file in files if file.lower().endswith(tuple(extensions))]
         else:
             # assume json
             with open(self.dataset_path, 'r') as f:
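
The dataset path can be either a directory, walked recursively for files with a matching extension, or a JSON manifest. A standalone sketch of that branching, assuming the manifest is simply a JSON list (the real schema is not visible in this hunk):

    import json
    import os

    def list_dataset_files(dataset_path, extensions):
        if os.path.isdir(dataset_path):
            # walk the tree and keep files whose suffix matches
            return [
                os.path.join(root, file)
                for root, _, files in os.walk(dataset_path)
                for file in files
                if file.lower().endswith(tuple(extensions))
            ]
        # otherwise assume a JSON manifest; the exact schema is an assumption here
        with open(dataset_path, 'r') as f:
            return json.load(f)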
@@ -438,7 +447,10 @@ class AiToolkitDataset(LatentCachingMixin, CLIPCachingMixin, BucketsMixin, Capti
         # this might take a while
         print_acc(f"Dataset: {self.dataset_path}")
-        print_acc(f" - Preprocessing image dimensions")
+        if self.is_video:
+            print_acc(f" - Preprocessing video dimensions")
+        else:
+            print_acc(f" - Preprocessing image dimensions")
         dataset_folder = self.dataset_path
         if not os.path.isdir(self.dataset_path):
             dataset_folder = os.path.dirname(dataset_folder)
@@ -477,17 +489,23 @@ class AiToolkitDataset(LatentCachingMixin, CLIPCachingMixin, BucketsMixin, Capti
                 self.file_list.append(file_item)
             except Exception as e:
                 print_acc(traceback.format_exc())
-                print_acc(f"Error processing image: {file}")
+                if self.is_video:
+                    print_acc(f"Error processing video: {file}")
+                else:
+                    print_acc(f"Error processing image: {file}")
                 print_acc(e)
                 bad_count += 1

         # save the size database
         with open(dataset_size_file, 'w') as f:
             json.dump(self.size_database, f)

-        print_acc(f" - Found {len(self.file_list)} images")
         # print_acc(f" - Found {bad_count} images that are too small")
-        assert len(self.file_list) > 0, f"no images found in {self.dataset_path}"
+        if self.is_video:
+            print_acc(f" - Found {len(self.file_list)} videos")
+            assert len(self.file_list) > 0, f"no videos found in {self.dataset_path}"
+        else:
+            print_acc(f" - Found {len(self.file_list)} images")
+            assert len(self.file_list) > 0, f"no images found in {self.dataset_path}"

         # handle x axis flips
         if self.dataset_config.flip_x:
@@ -510,8 +528,10 @@ class AiToolkitDataset(LatentCachingMixin, CLIPCachingMixin, BucketsMixin, Capti
                 self.file_list.append(new_file_item)

         if self.dataset_config.flip_x or self.dataset_config.flip_y:
-            print_acc(f" - Found {len(self.file_list)} images after adding flips")
+            if self.is_video:
+                print_acc(f" - Found {len(self.file_list)} videos after adding flips")
+            else:
+                print_acc(f" - Found {len(self.file_list)} images after adding flips")

         self.setup_epoch()
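
Flip augmentation appends a flipped copy of each file item, so enabling flip_x (or flip_y) roughly doubles the reported count. A toy sketch of the doubling with minimal stand-in items (the real FileItemDTO carries far more state):

    import copy

    file_list = [{'path': 'a.mp4'}, {'path': 'b.mp4'}]  # stand-in items

    flipped = []
    for item in file_list:
        new_item = copy.deepcopy(item)
        new_item['flip_x'] = True  # hypothetical flag on the stand-in
        flipped.append(new_item)
    file_list.extend(flipped)

    print(len(file_list))  # 4: originals plus their flipped copies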
@@ -539,7 +559,7 @@ class AiToolkitDataset(LatentCachingMixin, CLIPCachingMixin, BucketsMixin, Capti
         return len(self.file_list)

     def _get_single_item(self, index) -> 'FileItemDTO':
-        file_item = copy.deepcopy(self.file_list[index])
+        file_item: 'FileItemDTO' = copy.deepcopy(self.file_list[index])
         file_item.load_and_process_image(self.transform)
         file_item.load_caption(self.caption_dict)
         return file_item
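
The added type annotation is cosmetic, but the copy.deepcopy it annotates matters: each access returns an independent FileItemDTO, so loading pixels and captions mutates the copy rather than the cached list. A small illustration of that isolation with a generic stand-in object:

    import copy

    cached = [{'path': 'clip.mp4', 'pixels': None}]  # stand-in for file_list

    item = copy.deepcopy(cached[0])
    item['pixels'] = 'decoded frames'  # per-item work happens on the copy

    print(cached[0]['pixels'])  # None: the cached entry is untouched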