mirror of https://github.com/ostris/ai-toolkit.git
Fixed a big issue with the bucketing dataloader and added random cropping to a point of interest
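The diff itself does not show the cropping math, only the hook that re-runs bucketing when a point of interest is configured. Purely as an illustration (the function name, the (x, y, w, h) layout of the poi box, and the pixel units are assumptions, not the project's API), picking a random crop window that still contains a point-of-interest box could look like this:

    import random

    def random_crop_containing_poi(img_w, img_h, crop_w, crop_h, poi):
        """Pick a random (left, top, right, bottom) crop that keeps the poi box inside."""
        px, py, pw, ph = poi  # assumed (x, y, w, h) box in pixels
        # range of left/top edges that keeps the poi fully inside the crop,
        # clamped to the image bounds
        min_left = max(0, px + pw - crop_w)
        max_left = max(min_left, min(px, img_w - crop_w))
        min_top = max(0, py + ph - crop_h)
        max_top = max(min_top, min(py, img_h - crop_h))
        left = random.randint(min_left, max_left)
        top = random.randint(min_top, max_top)
        return left, top, left + crop_w, top + crop_h

    print(random_crop_containing_poi(1024, 768, 512, 512, poi=(300, 200, 128, 128)))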
@@ -346,6 +346,7 @@ class AiToolkitDataset(LatentCachingMixin, BucketsMixin, CaptionMixin, Dataset):
         self.is_caching_latents = dataset_config.cache_latents or dataset_config.cache_latents_to_disk
         self.is_caching_latents_to_memory = dataset_config.cache_latents
         self.is_caching_latents_to_disk = dataset_config.cache_latents_to_disk
+        self.epoch_num = 0

         self.sd = sd

@@ -426,13 +427,20 @@ class AiToolkitDataset(LatentCachingMixin, BucketsMixin, CaptionMixin, Dataset):
         self.setup_epoch()

     def setup_epoch(self):
-        # TODO: set this up to redo cropping and everything else
-        # do not call for now
-        if self.dataset_config.buckets:
-            # setup buckets
-            self.setup_buckets()
-        if self.is_caching_latents:
-            self.cache_latents_all_latents()
+        if self.epoch_num == 0:
+            # initial setup
+            # do not call for now
+            if self.dataset_config.buckets:
+                # setup buckets
+                self.setup_buckets()
+            if self.is_caching_latents:
+                self.cache_latents_all_latents()
+        else:
+            if self.dataset_config.poi is not None:
+                # handle cropping to a specific point of interest
+                # setup buckets every epoch
+                self.setup_buckets(quiet=True)
+        self.epoch_num += 1

     def __len__(self):
         if self.dataset_config.buckets:
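In short: the first call to setup_epoch() does the full bucket setup and latent caching, while later calls only re-run bucketing (quietly) when a point of interest is configured, so crops around the poi get re-randomized each epoch. As a reading aid only, here is a self-contained toy model of that flow; the EpochSetupSketch name and its stub methods are invented for illustration and are not part of ai-toolkit:

    class EpochSetupSketch:
        """Toy model of the setup_epoch() branching above, with the real work stubbed out."""

        def __init__(self, use_buckets=True, cache_latents=False, poi=None):
            self.use_buckets = use_buckets
            self.cache_latents = cache_latents
            self.poi = poi
            self.epoch_num = 0

        def setup_buckets(self, quiet=False):
            if not quiet:
                print("building buckets")

        def cache_all_latents(self):
            print("caching latents")

        def setup_epoch(self):
            if self.epoch_num == 0:
                # full initial setup happens exactly once
                if self.use_buckets:
                    self.setup_buckets()
                if self.cache_latents:
                    self.cache_all_latents()
            else:
                # later epochs only re-bucket (quietly) when a poi is set,
                # so crops around the point of interest are re-randomized
                if self.poi is not None:
                    self.setup_buckets(quiet=True)
            self.epoch_num += 1

    sketch = EpochSetupSketch(poi=(10, 10, 64, 64))
    for _ in range(3):
        sketch.setup_epoch()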
@@ -450,6 +458,9 @@ class AiToolkitDataset(LatentCachingMixin, BucketsMixin, CaptionMixin, Dataset):
             # for buckets we collate ourselves for now
             # todo allow a scheduler to dynamically make buckets
             # we collate ourselves
+            if len(self.batch_indices) - 1 < item:
+                # tried everything to solve this. No way to reset length when redoing things. Pick another index
+                item = random.randint(0, len(self.batch_indices) - 1)
             idx_list = self.batch_indices[item]
             return [self._get_single_item(idx) for idx in idx_list]
         else:
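The added guard covers the case where per-epoch re-bucketing shrinks batch_indices while the sampler still requests indices from the old, longer length; stale out-of-range requests are remapped to a random valid batch instead of raising IndexError. A tiny standalone illustration of that remapping (not the project's code, names invented for the example):

    import random

    batch_indices = [[0, 1], [2, 3], [4, 5]]  # pretend re-bucketing shrank this list

    def safe_lookup(item):
        # mirrors the guard above: stale indices are replaced by a random valid one
        if len(batch_indices) - 1 < item:
            item = random.randint(0, len(batch_indices) - 1)
        return batch_indices[item]

    print(safe_lookup(7))  # index left over from before re-bucketing; remapped at random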
@@ -523,3 +534,27 @@ def get_dataloader_from_datasets(
         collate_fn=dto_collation
     )
     return data_loader
+
+
+def trigger_dataloader_setup_epoch(dataloader: DataLoader):
+    # hacky but needed because of different types of datasets and dataloaders
+    dataloader.len = None
+    if isinstance(dataloader.dataset, list):
+        for dataset in dataloader.dataset:
+            if hasattr(dataset, 'datasets'):
+                for sub_dataset in dataset.datasets:
+                    if hasattr(sub_dataset, 'setup_epoch'):
+                        sub_dataset.setup_epoch()
+                        sub_dataset.len = None
+            elif hasattr(dataset, 'setup_epoch'):
+                dataset.setup_epoch()
+                dataset.len = None
+    elif hasattr(dataloader.dataset, 'setup_epoch'):
+        dataloader.dataset.setup_epoch()
+        dataloader.dataset.len = None
+    elif hasattr(dataloader.dataset, 'datasets'):
+        dataloader.dataset.len = None
+        for sub_dataset in dataloader.dataset.datasets:
+            if hasattr(sub_dataset, 'setup_epoch'):
+                sub_dataset.setup_epoch()
+                sub_dataset.len = None
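A hedged usage sketch of the new hook: a training loop can call trigger_dataloader_setup_epoch between epochs so that per-epoch re-bucketing and re-cropping actually run. The ToyDataset below is only a stand-in exposing the setup_epoch()/len attributes the hook probes for, and the import path is an assumption; in real use the dataloader comes from get_dataloader_from_datasets:

    import torch
    from torch.utils.data import DataLoader, Dataset

    # module path is an assumption; trigger_dataloader_setup_epoch is the function added above
    from toolkit.data_loader import trigger_dataloader_setup_epoch

    class ToyDataset(Dataset):
        """Stand-in exposing only the attributes the hook looks for."""

        def __init__(self):
            self.len = None

        def setup_epoch(self):
            print("re-running per-epoch setup (re-cropping/re-bucketing would happen here)")

        def __len__(self):
            return 4

        def __getitem__(self, idx):
            return torch.zeros(1)

    dataloader = DataLoader(ToyDataset(), batch_size=2)
    for epoch in range(2):
        for batch in dataloader:
            pass  # training step would go here
        trigger_dataloader_setup_epoch(dataloader)  # reset lengths and redo per-epoch setup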