Bug fixes. added ability to use l1 loss. varous other tests and improvements

This commit is contained in:
Jaret Burkett
2024-01-31 06:30:54 -07:00
parent 92b9c71d44
commit 1ae1017748
9 changed files with 474 additions and 23 deletions

View File

@@ -18,7 +18,7 @@ import albumentations as A
from toolkit.buckets import get_bucket_for_image_size, BucketResolution
from toolkit.config_modules import DatasetConfig, preprocess_dataset_raw_config
from toolkit.dataloader_mixins import CaptionMixin, BucketsMixin, LatentCachingMixin, Augments
from toolkit.dataloader_mixins import CaptionMixin, BucketsMixin, LatentCachingMixin, Augments, CLIPCachingMixin
from toolkit.data_transfer_object.data_loader import FileItemDTO, DataLoaderBatchDTO
if TYPE_CHECKING:
@@ -355,7 +355,7 @@ class PairedImageDataset(Dataset):
return img, prompt, (self.neg_weight, self.pos_weight)
class AiToolkitDataset(LatentCachingMixin, BucketsMixin, CaptionMixin, Dataset):
class AiToolkitDataset(LatentCachingMixin, CLIPCachingMixin, BucketsMixin, CaptionMixin, Dataset):
def __init__(
self,
@@ -373,6 +373,7 @@ class AiToolkitDataset(LatentCachingMixin, BucketsMixin, CaptionMixin, Dataset):
self.is_caching_latents = dataset_config.cache_latents or dataset_config.cache_latents_to_disk
self.is_caching_latents_to_memory = dataset_config.cache_latents
self.is_caching_latents_to_disk = dataset_config.cache_latents_to_disk
self.is_caching_clip_vision_to_disk = dataset_config.cache_clip_vision_to_disk
self.epoch_num = 0
self.sd = sd
@@ -482,6 +483,8 @@ class AiToolkitDataset(LatentCachingMixin, BucketsMixin, CaptionMixin, Dataset):
self.setup_buckets()
if self.is_caching_latents:
self.cache_latents_all_latents()
if self.is_caching_clip_vision_to_disk:
self.cache_clip_vision_to_disk()
else:
if self.dataset_config.poi is not None:
# handle cropping to a specific point of interest
@@ -611,3 +614,19 @@ def trigger_dataloader_setup_epoch(dataloader: DataLoader):
if hasattr(sub_dataset, 'setup_epoch'):
sub_dataset.setup_epoch()
sub_dataset.len = None
def get_dataloader_datasets(dataloader: DataLoader):
# hacky but needed because of different types of datasets and dataloaders
if isinstance(dataloader.dataset, list):
datasets = []
for dataset in dataloader.dataset:
if hasattr(dataset, 'datasets'):
for sub_dataset in dataset.datasets:
datasets.append(sub_dataset)
else:
datasets.append(dataset)
return datasets
elif hasattr(dataloader.dataset, 'datasets'):
return dataloader.dataset.datasets
else:
return [dataloader.dataset]