Fixed a big issue with the bucketing dataloader and added random cropping to a point of interest

This commit is contained in:
Jaret Burkett
2023-10-02 18:31:08 -06:00
parent 320e109c5f
commit 579650eaf8
6 changed files with 264 additions and 72 deletions

View File

@@ -346,6 +346,7 @@ class AiToolkitDataset(LatentCachingMixin, BucketsMixin, CaptionMixin, Dataset):
self.is_caching_latents = dataset_config.cache_latents or dataset_config.cache_latents_to_disk
self.is_caching_latents_to_memory = dataset_config.cache_latents
self.is_caching_latents_to_disk = dataset_config.cache_latents_to_disk
self.epoch_num = 0
self.sd = sd
@@ -426,13 +427,20 @@ class AiToolkitDataset(LatentCachingMixin, BucketsMixin, CaptionMixin, Dataset):
self.setup_epoch()
def setup_epoch(self):
# Prepare the dataset for an epoch, using self.epoch_num to tell the first call from later ones.
# NOTE(review): this span comes from a diff view whose +/- markers and indentation were stripped.
# The flat bucket/latent-cache statements directly below look like the pre-change body, and the
# `if self.epoch_num == 0:` section after them looks like the post-change body -- confirm against
# the repository before relying on either reading.
# TODO: set this up to redo cropping and everything else
# do not call for now
if self.dataset_config.buckets:
# setup buckets
self.setup_buckets()
if self.is_caching_latents:
self.cache_latents_all_latents()
# First call (epoch_num is still 0): build buckets when dataset_config.buckets is set and
# cache latents when is_caching_latents is set.
if self.epoch_num == 0:
# initial setup
# do not call for now
if self.dataset_config.buckets:
# setup buckets
self.setup_buckets()
if self.is_caching_latents:
self.cache_latents_all_latents()
# presumably this `else` pairs with the `epoch_num == 0` check above (later calls) -- TODO confirm
else:
if self.dataset_config.poi is not None:
# handle cropping to a specific point of interest
# setup buckets every epoch
self.setup_buckets(quiet=True)
# presumably runs on every call, outside the if/else, so the next call takes the `else` path -- TODO confirm
self.epoch_num += 1
def __len__(self):
if self.dataset_config.buckets:
@@ -450,6 +458,9 @@ class AiToolkitDataset(LatentCachingMixin, BucketsMixin, CaptionMixin, Dataset):
# for buckets we collate ourselves for now
# todo allow a scheduler to dynamically make buckets
# we collate ourselves
if len(self.batch_indices) - 1 < item:
# tried everything to solve this. No way to reset length when redoing things. Pick another index
item = random.randint(0, len(self.batch_indices) - 1)
idx_list = self.batch_indices[item]
return [self._get_single_item(idx) for idx in idx_list]
else:
@@ -523,3 +534,27 @@ def get_dataloader_from_datasets(
collate_fn=dto_collation
)
return data_loader
def _setup_epoch_on_members(members):
    """Call ``setup_epoch()`` on every member that defines it, then set its ``len`` to None.

    Members without a ``setup_epoch`` attribute are left untouched.
    """
    for member in members:
        if not hasattr(member, 'setup_epoch'):
            continue
        member.setup_epoch()
        member.len = None


def trigger_dataloader_setup_epoch(dataloader: DataLoader):
    """Run per-epoch setup on whatever dataset(s) sit behind ``dataloader``.

    Three layouts of ``dataloader.dataset`` are recognised:

    * a ``list`` of datasets -- each entry is either a container exposing
      ``.datasets`` (its members are set up) or a dataset exposing
      ``setup_epoch`` (set up directly);
    * a single object exposing ``setup_epoch``;
    * a single container exposing ``.datasets``.

    Every object that has ``setup_epoch()`` called on it also has its ``len``
    attribute set to ``None``, as do the dataloader itself and a top-level
    ``.datasets`` container. A ``.datasets`` container found inside a list
    does not have its own ``len`` touched.

    Args:
        dataloader: The loader whose underlying dataset(s) should be refreshed.
    """
    # hacky but needed because of different types of datasets and dataloaders
    dataloader.len = None
    target = dataloader.dataset

    if isinstance(target, list):
        for entry in target:
            if hasattr(entry, 'datasets'):
                # Container entry: only its members are refreshed.
                _setup_epoch_on_members(entry.datasets)
            else:
                # Plain entry: refreshed only if it defines setup_epoch.
                _setup_epoch_on_members([entry])
        return

    if hasattr(target, 'setup_epoch'):
        target.setup_epoch()
        target.len = None
        return

    if hasattr(target, 'datasets'):
        target.len = None
        _setup_epoch_on_members(target.datasets)