mirror of
https://github.com/ostris/ai-toolkit.git
synced 2026-01-26 16:39:47 +00:00
Added a file signature check on the dataset size caching system to invalidate cached dimensions if the file changes.
This commit is contained in:
@@ -1,4 +1,5 @@
|
||||
import gc
|
||||
import os
|
||||
|
||||
import torch
|
||||
|
||||
@@ -54,3 +55,12 @@ def adain(content_features, style_features):
|
||||
stylized_content = normalized_content * style_std + style_mean
|
||||
|
||||
return stylized_content
|
||||
|
||||
def get_quick_signature_string(file_path):
    """Return a cheap change-detection signature for ``file_path``.

    The signature is the string ``"<size>:<mtime>"`` where ``<size>`` is the
    file size in bytes and ``<mtime>`` is the modification time truncated to
    whole seconds. It is built from a single ``os.stat`` call, so the file
    contents are never read.

    Because the modification time is truncated, a rewrite that keeps the same
    size and lands within the same second yields an identical signature.

    Args:
        file_path: Path of the file to inspect.

    Returns:
        The signature string, or ``None`` if the file could not be stat'ed
        (the error is printed in that case).
    """
    try:
        file_stats = os.stat(file_path)
    except (OSError, ValueError, TypeError) as e:
        # OSError: missing file / permission problems.
        # ValueError / TypeError: malformed path argument.
        # Best-effort by design: report the problem and let the caller decide.
        print(f"Error accessing file {file_path}: {e}")
        return None

    # Combine size and mtime into a single string
    return f"{file_stats.st_size}:{int(file_stats.st_mtime)}"
|
||||
@@ -458,7 +458,7 @@ class AiToolkitDataset(LatentCachingMixin, CLIPCachingMixin, BucketsMixin, Capti
|
||||
dataset_folder = os.path.dirname(dataset_folder)
|
||||
|
||||
dataset_size_file = os.path.join(dataset_folder, '.aitk_size.json')
|
||||
dataloader_version = "0.1.1"
|
||||
dataloader_version = "0.1.2"
|
||||
if os.path.exists(dataset_size_file):
|
||||
try:
|
||||
with open(dataset_size_file, 'r') as f:
|
||||
|
||||
@@ -10,6 +10,7 @@ from PIL import Image
|
||||
from PIL.ImageOps import exif_transpose
|
||||
|
||||
from toolkit import image_utils
|
||||
from toolkit.basic import get_quick_signature_string
|
||||
from toolkit.dataloader_mixins import CaptionProcessingDTOMixin, ImageProcessingDTOMixin, LatentCachingFileItemDTOMixin, \
|
||||
ControlFileItemDTOMixin, ArgBreakMixin, PoiFileItemDTOMixin, MaskFileItemDTOMixin, AugmentationFileItemDTOMixin, \
|
||||
UnconditionalFileItemDTOMixin, ClipImageFileItemDTOMixin, InpaintControlFileItemDTOMixin
|
||||
@@ -53,8 +54,19 @@ class FileItemDTO(
|
||||
file_key = self.path.replace(dataset_root, '')
|
||||
else:
|
||||
file_key = os.path.basename(self.path)
|
||||
|
||||
file_signature = get_quick_signature_string(self.path)
|
||||
if file_signature is None:
|
||||
raise Exception(f"Error: Could not get file signature for {self.path}")
|
||||
|
||||
use_db_entry = False
|
||||
if file_key in size_database:
|
||||
w, h = size_database[file_key]
|
||||
db_entry = size_database[file_key]
|
||||
if db_entry is not None and db_entry[2] == file_signature:
|
||||
use_db_entry = True
|
||||
|
||||
if use_db_entry:
|
||||
w, h, _ = size_database[file_key]
|
||||
elif self.is_video:
|
||||
# Open the video file
|
||||
video = cv2.VideoCapture(self.path)
|
||||
@@ -80,7 +92,7 @@ class FileItemDTO(
|
||||
# f'This process is faster for png, jpeg')
|
||||
img = exif_transpose(Image.open(self.path))
|
||||
w, h = img.size
|
||||
size_database[file_key] = (w, h)
|
||||
size_database[file_key] = (w, h, file_signature)
|
||||
self.width: int = w
|
||||
self.height: int = h
|
||||
self.dataloader_transforms = kwargs.get('dataloader_transforms', None)
|
||||
|
||||
Reference in New Issue
Block a user