From 6d31c6db730a9cba3004eae1a3d7283c663d9295 Mon Sep 17 00:00:00 2001 From: Jaret Burkett Date: Sun, 11 Aug 2024 10:48:24 -0600 Subject: [PATCH] Added a fix for windows dataloader --- toolkit/data_loader.py | 19 +++++++++++++++---- 1 file changed, 15 insertions(+), 4 deletions(-) diff --git a/toolkit/data_loader.py b/toolkit/data_loader.py index 28099807..8e5186dd 100644 --- a/toolkit/data_loader.py +++ b/toolkit/data_loader.py @@ -21,6 +21,11 @@ from toolkit.config_modules import DatasetConfig, preprocess_dataset_raw_config from toolkit.dataloader_mixins import CaptionMixin, BucketsMixin, LatentCachingMixin, Augments, CLIPCachingMixin from toolkit.data_transfer_object.data_loader import FileItemDTO, DataLoaderBatchDTO +import platform + +def is_native_windows(): + return platform.system() == "Windows" and platform.release() != "2" + if TYPE_CHECKING: from toolkit.stable_diffusion_model import StableDiffusion @@ -583,6 +588,13 @@ def get_dataloader_from_datasets( # check if is caching latents + dataloader_kwargs = {} + + if is_native_windows(): + dataloader_kwargs['num_workers'] = 0 + else: + dataloader_kwargs['num_workers'] = dataset_config_list[0].num_workers + dataloader_kwargs['prefetch_factor'] = dataset_config_list[0].prefetch_factor if has_buckets: # make sure they all have buckets @@ -595,16 +607,15 @@ def get_dataloader_from_datasets( drop_last=False, shuffle=True, collate_fn=dto_collation, # Use the custom collate function - num_workers=dataset_config_list[0].num_workers, - prefetch_factor=dataset_config_list[0].prefetch_factor, + **dataloader_kwargs ) else: data_loader = DataLoader( concatenated_dataset, batch_size=batch_size, shuffle=True, - num_workers=4, - collate_fn=dto_collation + collate_fn=dto_collation, + **dataloader_kwargs ) return data_loader