diff --git a/toolkit/config_modules.py b/toolkit/config_modules.py index d5ab38b6..4eba8ff4 100644 --- a/toolkit/config_modules.py +++ b/toolkit/config_modules.py @@ -524,6 +524,9 @@ class DatasetConfig: self.replacements: List[str] = kwargs.get('replacements', []) self.loss_multiplier: float = kwargs.get('loss_multiplier', 1.0) + self.num_workers: int = kwargs.get('num_workers', 4) + self.prefetch_factor: int = kwargs.get('prefetch_factor', 2) + def preprocess_dataset_raw_config(raw_config: List[dict]) -> List[dict]: """ diff --git a/toolkit/data_loader.py b/toolkit/data_loader.py index 28e1c5ce..a73fb001 100644 --- a/toolkit/data_loader.py +++ b/toolkit/data_loader.py @@ -585,7 +585,8 @@ def get_dataloader_from_datasets( drop_last=False, shuffle=True, collate_fn=dto_collation, # Use the custom collate function - num_workers=8 + num_workers=dataset_config_list[0].num_workers, + prefetch_factor=dataset_config_list[0].prefetch_factor, ) else: data_loader = DataLoader( diff --git a/toolkit/data_transfer_object/data_loader.py b/toolkit/data_transfer_object/data_loader.py index abc846af..02c5e327 100644 --- a/toolkit/data_transfer_object/data_loader.py +++ b/toolkit/data_transfer_object/data_loader.py @@ -220,11 +220,7 @@ class DataLoaderBatchDTO: to_replace_list=None, add_if_not_present=True ): - return [x.get_caption( - trigger=trigger, - to_replace_list=to_replace_list, - add_if_not_present=add_if_not_present - ) for x in self.file_items] + return [x.caption for x in self.file_items] def get_caption_short_list( self, @@ -232,12 +228,7 @@ class DataLoaderBatchDTO: to_replace_list=None, add_if_not_present=True ): - return [x.get_caption( - trigger=trigger, - to_replace_list=to_replace_list, - add_if_not_present=add_if_not_present, - short_caption=True - ) for x in self.file_items] + return [x.caption_short for x in self.file_items] def cleanup(self): del self.latents diff --git a/toolkit/dataloader_mixins.py b/toolkit/dataloader_mixins.py index ccaf6bc9..a0007d50 100644 --- a/toolkit/dataloader_mixins.py +++ b/toolkit/dataloader_mixins.py @@ -263,6 +263,8 @@ class CaptionProcessingDTOMixin: super().__init__(*args, **kwargs) self.raw_caption: str = None self.raw_caption_short: str = None + self.caption: str = None + self.caption_short: str = None # todo allow for loading from sd-scripts style dict def load_caption(self: 'FileItemDTO', caption_dict: Union[dict, None]): @@ -308,6 +310,10 @@ class CaptionProcessingDTOMixin: self.raw_caption = prompt self.raw_caption_short = short_caption + self.caption = self.get_caption() + if self.raw_caption_short is not None: + self.caption_short = self.get_caption(short_caption=True) + def get_caption( self: 'FileItemDTO', trigger=None, @@ -367,12 +373,13 @@ class CaptionProcessingDTOMixin: num_triggers = random.randint(0, num_triggers) if num_triggers > 0: + triggers = random.sample(self.dataset_config.random_triggers, num_triggers) + caption = caption + ', ' + ', '.join(triggers) # add random triggers - for i in range(num_triggers): - - - - caption = caption + ', ' + random.choice(self.dataset_config.random_triggers) + # for i in range(num_triggers): + # # fastest method + # trigger = self.dataset_config.random_triggers[int(random.random() * (len(self.dataset_config.random_triggers)))] + # caption = caption + ', ' + trigger if self.dataset_config.shuffle_tokens: # shuffle again diff --git a/toolkit/stable_diffusion_model.py b/toolkit/stable_diffusion_model.py index b99569e7..cf7225b7 100644 --- a/toolkit/stable_diffusion_model.py +++ b/toolkit/stable_diffusion_model.py @@ -330,6 +330,7 @@ class StableDiffusion: requires_safety_checker=False, safety_checker=None, variant="fp16", + trust_remote_code=True, **load_args ).to(self.device_torch) else: @@ -341,6 +342,7 @@ class StableDiffusion: requires_safety_checker=False, torch_dtype=self.torch_dtype, safety_checker=None, + trust_remote_code=True, **load_args ).to(self.device_torch) flush()