diff --git a/toolkit/data_loader.py b/toolkit/data_loader.py index 28099807..8e5186dd 100644 --- a/toolkit/data_loader.py +++ b/toolkit/data_loader.py @@ -21,6 +21,11 @@ from toolkit.config_modules import DatasetConfig, preprocess_dataset_raw_config from toolkit.dataloader_mixins import CaptionMixin, BucketsMixin, LatentCachingMixin, Augments, CLIPCachingMixin from toolkit.data_transfer_object.data_loader import FileItemDTO, DataLoaderBatchDTO +import platform + +def is_native_windows(): + return platform.system() == "Windows" and platform.release() != "2" + if TYPE_CHECKING: from toolkit.stable_diffusion_model import StableDiffusion @@ -583,6 +588,13 @@ def get_dataloader_from_datasets( # check if is caching latents + dataloader_kwargs = {} + + if is_native_windows(): + dataloader_kwargs['num_workers'] = 0 + else: + dataloader_kwargs['num_workers'] = dataset_config_list[0].num_workers + dataloader_kwargs['prefetch_factor'] = dataset_config_list[0].prefetch_factor if has_buckets: # make sure they all have buckets @@ -595,16 +607,15 @@ def get_dataloader_from_datasets( drop_last=False, shuffle=True, collate_fn=dto_collation, # Use the custom collate function - num_workers=dataset_config_list[0].num_workers, - prefetch_factor=dataset_config_list[0].prefetch_factor, + **dataloader_kwargs ) else: data_loader = DataLoader( concatenated_dataset, batch_size=batch_size, shuffle=True, - num_workers=4, - collate_fn=dto_collation + collate_fn=dto_collation, + **dataloader_kwargs ) return data_loader