Added a file signature check on the dataset size caching system to invalidate cached dimensions if the file changes.

This commit is contained in:
Jaret Burkett
2025-04-01 07:39:36 -06:00
parent 5ea19b6292
commit 3d131fb27a
3 changed files with 25 additions and 3 deletions

View File

@@ -1,4 +1,5 @@
import gc
import os
import torch
@@ -54,3 +55,12 @@ def adain(content_features, style_features):
stylized_content = normalized_content * style_std + style_mean
return stylized_content
def get_quick_signature_string(file_path):
try:
file_stats = os.stat(file_path)
# Combine size and mtime into a single string
return f"{file_stats.st_size}:{int(file_stats.st_mtime)}"
except Exception as e:
print(f"Error accessing file {file_path}: {e}")
return None

View File

@@ -458,7 +458,7 @@ class AiToolkitDataset(LatentCachingMixin, CLIPCachingMixin, BucketsMixin, Capti
dataset_folder = os.path.dirname(dataset_folder)
dataset_size_file = os.path.join(dataset_folder, '.aitk_size.json')
dataloader_version = "0.1.1"
dataloader_version = "0.1.2"
if os.path.exists(dataset_size_file):
try:
with open(dataset_size_file, 'r') as f:

View File

@@ -10,6 +10,7 @@ from PIL import Image
from PIL.ImageOps import exif_transpose
from toolkit import image_utils
from toolkit.basic import get_quick_signature_string
from toolkit.dataloader_mixins import CaptionProcessingDTOMixin, ImageProcessingDTOMixin, LatentCachingFileItemDTOMixin, \
ControlFileItemDTOMixin, ArgBreakMixin, PoiFileItemDTOMixin, MaskFileItemDTOMixin, AugmentationFileItemDTOMixin, \
UnconditionalFileItemDTOMixin, ClipImageFileItemDTOMixin, InpaintControlFileItemDTOMixin
@@ -53,8 +54,19 @@ class FileItemDTO(
file_key = self.path.replace(dataset_root, '')
else:
file_key = os.path.basename(self.path)
file_signature = get_quick_signature_string(self.path)
if file_signature is None:
raise Exception("Error: Could not get file signature for {self.path}")
use_db_entry = False
if file_key in size_database:
w, h = size_database[file_key]
db_entry = size_database[file_key]
if db_entry is not None and db_entry[2] == file_signature:
use_db_entry = True
if use_db_entry:
w, h, _ = size_database[file_key]
elif self.is_video:
# Open the video file
video = cv2.VideoCapture(self.path)
@@ -80,7 +92,7 @@ class FileItemDTO(
# f'This process is faster for png, jpeg')
img = exif_transpose(Image.open(self.path))
w, h = img.size
size_database[file_key] = (w, h)
size_database[file_key] = (w, h, file_signature)
self.width: int = w
self.height: int = h
self.dataloader_transforms = kwargs.get('dataloader_transforms', None)