mirror of
https://github.com/ostris/ai-toolkit.git
synced 2026-04-25 00:39:22 +00:00
Added methods to the dataloader to automatically generate controls for line, mask, inpainting, depth, and pose.
This commit is contained in:
@@ -18,7 +18,7 @@ import albumentations as A
|
||||
|
||||
from toolkit.buckets import get_bucket_for_image_size, BucketResolution
|
||||
from toolkit.config_modules import DatasetConfig, preprocess_dataset_raw_config
|
||||
from toolkit.dataloader_mixins import CaptionMixin, BucketsMixin, LatentCachingMixin, Augments, CLIPCachingMixin
|
||||
from toolkit.dataloader_mixins import CaptionMixin, BucketsMixin, LatentCachingMixin, Augments, CLIPCachingMixin, ControlCachingMixin
|
||||
from toolkit.data_transfer_object.data_loader import FileItemDTO, DataLoaderBatchDTO
|
||||
from toolkit.print import print_acc
|
||||
from toolkit.accelerator import get_accelerator
|
||||
@@ -372,7 +372,7 @@ class PairedImageDataset(Dataset):
|
||||
return img, prompt, (self.neg_weight, self.pos_weight)
|
||||
|
||||
|
||||
class AiToolkitDataset(LatentCachingMixin, CLIPCachingMixin, BucketsMixin, CaptionMixin, Dataset):
|
||||
class AiToolkitDataset(LatentCachingMixin, ControlCachingMixin, CLIPCachingMixin, BucketsMixin, CaptionMixin, Dataset):
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
@@ -394,6 +394,7 @@ class AiToolkitDataset(LatentCachingMixin, CLIPCachingMixin, BucketsMixin, Capti
|
||||
self.is_caching_latents_to_memory = dataset_config.cache_latents
|
||||
self.is_caching_latents_to_disk = dataset_config.cache_latents_to_disk
|
||||
self.is_caching_clip_vision_to_disk = dataset_config.cache_clip_vision_to_disk
|
||||
self.is_generating_controls = len(dataset_config.controls) > 0
|
||||
self.epoch_num = 0
|
||||
|
||||
self.sd = sd
|
||||
@@ -425,6 +426,9 @@ class AiToolkitDataset(LatentCachingMixin, CLIPCachingMixin, BucketsMixin, Capti
|
||||
self.caption_dict = json.load(f)
|
||||
# keys are file paths
|
||||
file_list = list(self.caption_dict.keys())
|
||||
|
||||
# remove items in the _controls_ folder
|
||||
file_list = [x for x in file_list if not os.path.basename(os.path.dirname(x)) == "_controls"]
|
||||
|
||||
if self.dataset_config.num_repeats > 1:
|
||||
# repeat the list
|
||||
@@ -548,6 +552,9 @@ class AiToolkitDataset(LatentCachingMixin, CLIPCachingMixin, BucketsMixin, Capti
|
||||
self.cache_latents_all_latents()
|
||||
if self.is_caching_clip_vision_to_disk:
|
||||
self.cache_clip_vision_to_disk()
|
||||
if self.is_generating_controls:
|
||||
# always do this last
|
||||
self.setup_controls()
|
||||
else:
|
||||
if self.dataset_config.poi is not None:
|
||||
# handle cropping to a specific point of interest
|
||||
|
||||
Reference in New Issue
Block a user