mirror of
https://github.com/ostris/ai-toolkit.git
synced 2026-03-01 16:49:56 +00:00
Improvements to dataloader
This commit is contained in:
@@ -524,6 +524,9 @@ class DatasetConfig:
|
||||
self.replacements: List[str] = kwargs.get('replacements', [])
|
||||
self.loss_multiplier: float = kwargs.get('loss_multiplier', 1.0)
|
||||
|
||||
self.num_workers: int = kwargs.get('num_workers', 4)
|
||||
self.prefetch_factor: int = kwargs.get('prefetch_factor', 2)
|
||||
|
||||
|
||||
def preprocess_dataset_raw_config(raw_config: List[dict]) -> List[dict]:
|
||||
"""
|
||||
|
||||
@@ -585,7 +585,8 @@ def get_dataloader_from_datasets(
|
||||
drop_last=False,
|
||||
shuffle=True,
|
||||
collate_fn=dto_collation, # Use the custom collate function
|
||||
num_workers=8
|
||||
num_workers=dataset_config_list[0].num_workers,
|
||||
prefetch_factor=dataset_config_list[0].prefetch_factor,
|
||||
)
|
||||
else:
|
||||
data_loader = DataLoader(
|
||||
|
||||
@@ -220,11 +220,7 @@ class DataLoaderBatchDTO:
|
||||
to_replace_list=None,
|
||||
add_if_not_present=True
|
||||
):
|
||||
return [x.get_caption(
|
||||
trigger=trigger,
|
||||
to_replace_list=to_replace_list,
|
||||
add_if_not_present=add_if_not_present
|
||||
) for x in self.file_items]
|
||||
return [x.caption for x in self.file_items]
|
||||
|
||||
def get_caption_short_list(
|
||||
self,
|
||||
@@ -232,12 +228,7 @@ class DataLoaderBatchDTO:
|
||||
to_replace_list=None,
|
||||
add_if_not_present=True
|
||||
):
|
||||
return [x.get_caption(
|
||||
trigger=trigger,
|
||||
to_replace_list=to_replace_list,
|
||||
add_if_not_present=add_if_not_present,
|
||||
short_caption=True
|
||||
) for x in self.file_items]
|
||||
return [x.caption_short for x in self.file_items]
|
||||
|
||||
def cleanup(self):
|
||||
del self.latents
|
||||
|
||||
@@ -263,6 +263,8 @@ class CaptionProcessingDTOMixin:
|
||||
super().__init__(*args, **kwargs)
|
||||
self.raw_caption: str = None
|
||||
self.raw_caption_short: str = None
|
||||
self.caption: str = None
|
||||
self.caption_short: str = None
|
||||
|
||||
# todo allow for loading from sd-scripts style dict
|
||||
def load_caption(self: 'FileItemDTO', caption_dict: Union[dict, None]):
|
||||
@@ -308,6 +310,10 @@ class CaptionProcessingDTOMixin:
|
||||
self.raw_caption = prompt
|
||||
self.raw_caption_short = short_caption
|
||||
|
||||
self.caption = self.get_caption()
|
||||
if self.raw_caption_short is not None:
|
||||
self.caption_short = self.get_caption(short_caption=True)
|
||||
|
||||
def get_caption(
|
||||
self: 'FileItemDTO',
|
||||
trigger=None,
|
||||
@@ -367,12 +373,13 @@ class CaptionProcessingDTOMixin:
|
||||
num_triggers = random.randint(0, num_triggers)
|
||||
|
||||
if num_triggers > 0:
|
||||
triggers = random.sample(self.dataset_config.random_triggers, num_triggers)
|
||||
caption = caption + ', ' + ', '.join(triggers)
|
||||
# add random triggers
|
||||
for i in range(num_triggers):
|
||||
|
||||
|
||||
|
||||
caption = caption + ', ' + random.choice(self.dataset_config.random_triggers)
|
||||
# for i in range(num_triggers):
|
||||
# # fastest method
|
||||
# trigger = self.dataset_config.random_triggers[int(random.random() * (len(self.dataset_config.random_triggers)))]
|
||||
# caption = caption + ', ' + trigger
|
||||
|
||||
if self.dataset_config.shuffle_tokens:
|
||||
# shuffle again
|
||||
|
||||
@@ -330,6 +330,7 @@ class StableDiffusion:
|
||||
requires_safety_checker=False,
|
||||
safety_checker=None,
|
||||
variant="fp16",
|
||||
trust_remote_code=True,
|
||||
**load_args
|
||||
).to(self.device_torch)
|
||||
else:
|
||||
@@ -341,6 +342,7 @@ class StableDiffusion:
|
||||
requires_safety_checker=False,
|
||||
torch_dtype=self.torch_dtype,
|
||||
safety_checker=None,
|
||||
trust_remote_code=True,
|
||||
**load_args
|
||||
).to(self.device_torch)
|
||||
flush()
|
||||
|
||||
Reference in New Issue
Block a user