Improvements to dataloader

This commit is contained in:
Jaret Burkett
2024-04-27 09:28:28 -06:00
parent 5da3613e0b
commit b96913d73c
5 changed files with 21 additions and 17 deletions

View File

@@ -524,6 +524,9 @@ class DatasetConfig:
self.replacements: List[str] = kwargs.get('replacements', [])
self.loss_multiplier: float = kwargs.get('loss_multiplier', 1.0)
self.num_workers: int = kwargs.get('num_workers', 4)
self.prefetch_factor: int = kwargs.get('prefetch_factor', 2)
def preprocess_dataset_raw_config(raw_config: List[dict]) -> List[dict]:
"""

View File

@@ -585,7 +585,8 @@ def get_dataloader_from_datasets(
drop_last=False,
shuffle=True,
collate_fn=dto_collation, # Use the custom collate function
num_workers=8
num_workers=dataset_config_list[0].num_workers,
prefetch_factor=dataset_config_list[0].prefetch_factor,
)
else:
data_loader = DataLoader(

View File

@@ -220,11 +220,7 @@ class DataLoaderBatchDTO:
to_replace_list=None,
add_if_not_present=True
):
return [x.get_caption(
trigger=trigger,
to_replace_list=to_replace_list,
add_if_not_present=add_if_not_present
) for x in self.file_items]
return [x.caption for x in self.file_items]
def get_caption_short_list(
self,
@@ -232,12 +228,7 @@ class DataLoaderBatchDTO:
to_replace_list=None,
add_if_not_present=True
):
return [x.get_caption(
trigger=trigger,
to_replace_list=to_replace_list,
add_if_not_present=add_if_not_present,
short_caption=True
) for x in self.file_items]
return [x.caption_short for x in self.file_items]
def cleanup(self):
del self.latents

View File

@@ -263,6 +263,8 @@ class CaptionProcessingDTOMixin:
super().__init__(*args, **kwargs)
self.raw_caption: str = None
self.raw_caption_short: str = None
self.caption: str = None
self.caption_short: str = None
# todo allow for loading from sd-scripts style dict
def load_caption(self: 'FileItemDTO', caption_dict: Union[dict, None]):
@@ -308,6 +310,10 @@ class CaptionProcessingDTOMixin:
self.raw_caption = prompt
self.raw_caption_short = short_caption
self.caption = self.get_caption()
if self.raw_caption_short is not None:
self.caption_short = self.get_caption(short_caption=True)
def get_caption(
self: 'FileItemDTO',
trigger=None,
@@ -367,12 +373,13 @@ class CaptionProcessingDTOMixin:
num_triggers = random.randint(0, num_triggers)
if num_triggers > 0:
triggers = random.sample(self.dataset_config.random_triggers, num_triggers)
caption = caption + ', ' + ', '.join(triggers)
# add random triggers
for i in range(num_triggers):
caption = caption + ', ' + random.choice(self.dataset_config.random_triggers)
# for i in range(num_triggers):
# # fastest method
# trigger = self.dataset_config.random_triggers[int(random.random() * (len(self.dataset_config.random_triggers)))]
# caption = caption + ', ' + trigger
if self.dataset_config.shuffle_tokens:
# shuffle again

View File

@@ -330,6 +330,7 @@ class StableDiffusion:
requires_safety_checker=False,
safety_checker=None,
variant="fp16",
trust_remote_code=True,
**load_args
).to(self.device_torch)
else:
@@ -341,6 +342,7 @@ class StableDiffusion:
requires_safety_checker=False,
torch_dtype=self.torch_dtype,
safety_checker=None,
trust_remote_code=True,
**load_args
).to(self.device_torch)
flush()