mirror of
https://github.com/ostris/ai-toolkit.git
synced 2026-04-25 16:59:22 +00:00
Added ability to load video datasets and train with them
This commit is contained in:
@@ -30,6 +30,10 @@ def is_native_windows():
|
||||
|
||||
if TYPE_CHECKING:
|
||||
from toolkit.stable_diffusion_model import StableDiffusion
|
||||
|
||||
|
||||
image_extensions = ['.jpg', '.jpeg', '.png', '.webp']
|
||||
video_extensions = ['.mp4', '.avi', '.mov', '.webm', '.mkv', '.wmv', '.m4v', '.flv']
|
||||
|
||||
|
||||
class RescaleTransform:
|
||||
@@ -376,8 +380,9 @@ class AiToolkitDataset(LatentCachingMixin, CLIPCachingMixin, BucketsMixin, Capti
|
||||
batch_size=1,
|
||||
sd: 'StableDiffusion' = None,
|
||||
):
|
||||
super().__init__()
|
||||
self.dataset_config = dataset_config
|
||||
self.is_video = dataset_config.num_frames > 1
|
||||
super().__init__()
|
||||
folder_path = dataset_config.folder_path
|
||||
self.dataset_path = dataset_config.dataset_path
|
||||
if self.dataset_path is None:
|
||||
@@ -407,7 +412,11 @@ class AiToolkitDataset(LatentCachingMixin, CLIPCachingMixin, BucketsMixin, Capti
|
||||
|
||||
# check if dataset_path is a folder or json
|
||||
if os.path.isdir(self.dataset_path):
|
||||
file_list = [os.path.join(root, file) for root, _, files in os.walk(self.dataset_path) for file in files if file.lower().endswith(('.jpg', '.jpeg', '.png', '.webp'))]
|
||||
extensions = image_extensions
|
||||
if self.is_video:
|
||||
# only look for videos
|
||||
extensions = video_extensions
|
||||
file_list = [os.path.join(root, file) for root, _, files in os.walk(self.dataset_path) for file in files if file.lower().endswith(tuple(extensions))]
|
||||
else:
|
||||
# assume json
|
||||
with open(self.dataset_path, 'r') as f:
|
||||
@@ -438,7 +447,10 @@ class AiToolkitDataset(LatentCachingMixin, CLIPCachingMixin, BucketsMixin, Capti
|
||||
|
||||
# this might take a while
|
||||
print_acc(f"Dataset: {self.dataset_path}")
|
||||
print_acc(f" - Preprocessing image dimensions")
|
||||
if self.is_video:
|
||||
print_acc(f" - Preprocessing video dimensions")
|
||||
else:
|
||||
print_acc(f" - Preprocessing image dimensions")
|
||||
dataset_folder = self.dataset_path
|
||||
if not os.path.isdir(self.dataset_path):
|
||||
dataset_folder = os.path.dirname(dataset_folder)
|
||||
@@ -477,17 +489,23 @@ class AiToolkitDataset(LatentCachingMixin, CLIPCachingMixin, BucketsMixin, Capti
|
||||
self.file_list.append(file_item)
|
||||
except Exception as e:
|
||||
print_acc(traceback.format_exc())
|
||||
print_acc(f"Error processing image: {file}")
|
||||
if self.is_video:
|
||||
print_acc(f"Error processing video: {file}")
|
||||
else:
|
||||
print_acc(f"Error processing image: {file}")
|
||||
print_acc(e)
|
||||
bad_count += 1
|
||||
|
||||
# save the size database
|
||||
with open(dataset_size_file, 'w') as f:
|
||||
json.dump(self.size_database, f)
|
||||
|
||||
print_acc(f" - Found {len(self.file_list)} images")
|
||||
# print_acc(f" - Found {bad_count} images that are too small")
|
||||
assert len(self.file_list) > 0, f"no images found in {self.dataset_path}"
|
||||
|
||||
if self.is_video:
|
||||
print_acc(f" - Found {len(self.file_list)} videos")
|
||||
assert len(self.file_list) > 0, f"no videos found in {self.dataset_path}"
|
||||
else:
|
||||
print_acc(f" - Found {len(self.file_list)} images")
|
||||
assert len(self.file_list) > 0, f"no images found in {self.dataset_path}"
|
||||
|
||||
# handle x axis flips
|
||||
if self.dataset_config.flip_x:
|
||||
@@ -510,8 +528,10 @@ class AiToolkitDataset(LatentCachingMixin, CLIPCachingMixin, BucketsMixin, Capti
|
||||
self.file_list.append(new_file_item)
|
||||
|
||||
if self.dataset_config.flip_x or self.dataset_config.flip_y:
|
||||
print_acc(f" - Found {len(self.file_list)} images after adding flips")
|
||||
|
||||
if self.is_video:
|
||||
print_acc(f" - Found {len(self.file_list)} videos after adding flips")
|
||||
else:
|
||||
print_acc(f" - Found {len(self.file_list)} images after adding flips")
|
||||
|
||||
self.setup_epoch()
|
||||
|
||||
@@ -539,7 +559,7 @@ class AiToolkitDataset(LatentCachingMixin, CLIPCachingMixin, BucketsMixin, Capti
|
||||
return len(self.file_list)
|
||||
|
||||
def _get_single_item(self, index) -> 'FileItemDTO':
|
||||
file_item = copy.deepcopy(self.file_list[index])
|
||||
file_item: 'FileItemDTO' = copy.deepcopy(self.file_list[index])
|
||||
file_item.load_and_process_image(self.transform)
|
||||
file_item.load_caption(self.caption_dict)
|
||||
return file_item
|
||||
|
||||
Reference in New Issue
Block a user