Mirror of https://github.com/ostris/ai-toolkit.git, synced 2026-03-14 15:07:22 +00:00
Bug fixes. Added the ability to use L1 loss. Various other tests and improvements.
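The diff below covers only the data-loader changes; the L1 loss option mentioned above is not part of it. As a rough sketch of what such an option usually amounts to (the compute_loss helper and the loss_type key here are hypothetical, not taken from this commit), the training loss could be switched between L1 and MSE like this:

import torch
import torch.nn.functional as F

def compute_loss(pred: torch.Tensor, target: torch.Tensor, loss_type: str = "mse") -> torch.Tensor:
    # Illustrative only: pick L1 or MSE based on a config value.
    # "loss_type" is a hypothetical key, not the actual ai-toolkit config name.
    if loss_type == "l1":
        return F.l1_loss(pred, target, reduction="mean")
    return F.mse_loss(pred, target, reduction="mean")

# Example with noise-prediction-shaped tensors:
pred = torch.randn(4, 4, 64, 64)
target = torch.randn(4, 4, 64, 64)
print(compute_loss(pred, target, loss_type="l1"))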
@@ -18,7 +18,7 @@ import albumentations as A
 from toolkit.buckets import get_bucket_for_image_size, BucketResolution
 from toolkit.config_modules import DatasetConfig, preprocess_dataset_raw_config
-from toolkit.dataloader_mixins import CaptionMixin, BucketsMixin, LatentCachingMixin, Augments
+from toolkit.dataloader_mixins import CaptionMixin, BucketsMixin, LatentCachingMixin, Augments, CLIPCachingMixin
 from toolkit.data_transfer_object.data_loader import FileItemDTO, DataLoaderBatchDTO
 
 if TYPE_CHECKING:
@@ -355,7 +355,7 @@ class PairedImageDataset(Dataset):
         return img, prompt, (self.neg_weight, self.pos_weight)
 
 
-class AiToolkitDataset(LatentCachingMixin, BucketsMixin, CaptionMixin, Dataset):
+class AiToolkitDataset(LatentCachingMixin, CLIPCachingMixin, BucketsMixin, CaptionMixin, Dataset):
 
     def __init__(
             self,
@@ -373,6 +373,7 @@ class AiToolkitDataset(LatentCachingMixin, BucketsMixin, CaptionMixin, Dataset):
         self.is_caching_latents = dataset_config.cache_latents or dataset_config.cache_latents_to_disk
         self.is_caching_latents_to_memory = dataset_config.cache_latents
         self.is_caching_latents_to_disk = dataset_config.cache_latents_to_disk
+        self.is_caching_clip_vision_to_disk = dataset_config.cache_clip_vision_to_disk
         self.epoch_num = 0
 
         self.sd = sd
@@ -482,6 +483,8 @@ class AiToolkitDataset(LatentCachingMixin, BucketsMixin, CaptionMixin, Dataset):
             self.setup_buckets()
             if self.is_caching_latents:
                 self.cache_latents_all_latents()
+            if self.is_caching_clip_vision_to_disk:
+                self.cache_clip_vision_to_disk()
         else:
             if self.dataset_config.poi is not None:
                 # handle cropping to a specific point of interest
@@ -611,3 +614,19 @@ def trigger_dataloader_setup_epoch(dataloader: DataLoader):
         if hasattr(sub_dataset, 'setup_epoch'):
             sub_dataset.setup_epoch()
             sub_dataset.len = None
+
+
+def get_dataloader_datasets(dataloader: DataLoader):
+    # hacky but needed because of different types of datasets and dataloaders
+    if isinstance(dataloader.dataset, list):
+        datasets = []
+        for dataset in dataloader.dataset:
+            if hasattr(dataset, 'datasets'):
+                for sub_dataset in dataset.datasets:
+                    datasets.append(sub_dataset)
+            else:
+                datasets.append(dataset)
+        return datasets
+    elif hasattr(dataloader.dataset, 'datasets'):
+        return dataloader.dataset.datasets
+    else:
+        return [dataloader.dataset]
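The new get_dataloader_datasets helper flattens whatever sits behind a DataLoader: a plain dataset, a wrapper exposing a datasets attribute (such as ConcatDataset), or a list of such objects. A minimal sketch of how it behaves with stand-in datasets; the import path and the assumption that trigger_dataloader_setup_epoch iterates over these sub-datasets are not fully shown above:

from torch.utils.data import ConcatDataset, DataLoader, Dataset

# Assumption: the two helpers from the diff live in toolkit.data_loader.
from toolkit.data_loader import get_dataloader_datasets, trigger_dataloader_setup_epoch

class DummyDataset(Dataset):
    # Stand-in that only exposes the pieces trigger_dataloader_setup_epoch touches.
    def __init__(self, n):
        self.len = n
        self.epoch_num = 0

    def __len__(self):
        return self.len if self.len is not None else 0

    def __getitem__(self, idx):
        return idx

    def setup_epoch(self):
        self.epoch_num += 1

loader = DataLoader(ConcatDataset([DummyDataset(10), DummyDataset(20)]), batch_size=4)

# ConcatDataset exposes .datasets, so the wrapper is unwrapped into its two members.
print(len(get_dataloader_datasets(loader)))  # 2

# setup_epoch is called on each sub-dataset and its cached len is reset to None.
trigger_dataloader_setup_epoch(loader)
print([d.epoch_num for d in get_dataloader_datasets(loader)])  # [1, 1]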
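The new is_caching_clip_vision_to_disk attribute mirrors the existing latent-caching flags and reads dataset_config.cache_clip_vision_to_disk. A rough sketch of a dataset config that enables both disk caches; passing these fields as keyword arguments to DatasetConfig, and the folder_path name, are assumptions rather than something shown in this diff:

from toolkit.config_modules import DatasetConfig  # same import as in the diff above

# Field names follow the attributes read in the diff; kwargs-style construction is assumed.
dataset_config = DatasetConfig(
    folder_path="/path/to/training/images",  # hypothetical dataset location
    cache_latents_to_disk=True,              # drives is_caching_latents / is_caching_latents_to_disk
    cache_clip_vision_to_disk=True,          # drives the new is_caching_clip_vision_to_disk
)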