Added methods to the dataloader to automatically generate controls for line, mask, inpainting, depth, and pose.

This commit is contained in:
Jaret Burkett
2025-04-09 13:35:04 -06:00
parent 615b0d0e94
commit 96ba2fd129
5 changed files with 205 additions and 11 deletions

View File

@@ -18,7 +18,7 @@ import albumentations as A
from toolkit.buckets import get_bucket_for_image_size, BucketResolution
from toolkit.config_modules import DatasetConfig, preprocess_dataset_raw_config
from toolkit.dataloader_mixins import CaptionMixin, BucketsMixin, LatentCachingMixin, Augments, CLIPCachingMixin
from toolkit.dataloader_mixins import CaptionMixin, BucketsMixin, LatentCachingMixin, Augments, CLIPCachingMixin, ControlCachingMixin
from toolkit.data_transfer_object.data_loader import FileItemDTO, DataLoaderBatchDTO
from toolkit.print import print_acc
from toolkit.accelerator import get_accelerator
@@ -372,7 +372,7 @@ class PairedImageDataset(Dataset):
return img, prompt, (self.neg_weight, self.pos_weight)
class AiToolkitDataset(LatentCachingMixin, CLIPCachingMixin, BucketsMixin, CaptionMixin, Dataset):
class AiToolkitDataset(LatentCachingMixin, ControlCachingMixin, CLIPCachingMixin, BucketsMixin, CaptionMixin, Dataset):
def __init__(
self,
@@ -394,6 +394,7 @@ class AiToolkitDataset(LatentCachingMixin, CLIPCachingMixin, BucketsMixin, Capti
self.is_caching_latents_to_memory = dataset_config.cache_latents
self.is_caching_latents_to_disk = dataset_config.cache_latents_to_disk
self.is_caching_clip_vision_to_disk = dataset_config.cache_clip_vision_to_disk
self.is_generating_controls = len(dataset_config.controls) > 0
self.epoch_num = 0
self.sd = sd
@@ -425,6 +426,9 @@ class AiToolkitDataset(LatentCachingMixin, CLIPCachingMixin, BucketsMixin, Capti
self.caption_dict = json.load(f)
# keys are file paths
file_list = list(self.caption_dict.keys())
# remove items in the _controls_ folder
file_list = [x for x in file_list if not os.path.basename(os.path.dirname(x)) == "_controls"]
if self.dataset_config.num_repeats > 1:
# repeat the list
@@ -548,6 +552,9 @@ class AiToolkitDataset(LatentCachingMixin, CLIPCachingMixin, BucketsMixin, Capti
self.cache_latents_all_latents()
if self.is_caching_clip_vision_to_disk:
self.cache_clip_vision_to_disk()
if self.is_generating_controls:
# always do this last
self.setup_controls()
else:
if self.dataset_config.poi is not None:
# handle cropping to a specific point of interest