mirror of
https://github.com/ostris/ai-toolkit.git
synced 2026-04-30 11:11:37 +00:00
Added a fix for windows dataloader
This commit is contained in:
@@ -21,6 +21,11 @@ from toolkit.config_modules import DatasetConfig, preprocess_dataset_raw_config
|
|||||||
from toolkit.dataloader_mixins import CaptionMixin, BucketsMixin, LatentCachingMixin, Augments, CLIPCachingMixin
|
from toolkit.dataloader_mixins import CaptionMixin, BucketsMixin, LatentCachingMixin, Augments, CLIPCachingMixin
|
||||||
from toolkit.data_transfer_object.data_loader import FileItemDTO, DataLoaderBatchDTO
|
from toolkit.data_transfer_object.data_loader import FileItemDTO, DataLoaderBatchDTO
|
||||||
|
|
||||||
|
import platform
|
||||||
|
|
||||||
|
def is_native_windows():
|
||||||
|
return platform.system() == "Windows" and platform.release() != "2"
|
||||||
|
|
||||||
if TYPE_CHECKING:
|
if TYPE_CHECKING:
|
||||||
from toolkit.stable_diffusion_model import StableDiffusion
|
from toolkit.stable_diffusion_model import StableDiffusion
|
||||||
|
|
||||||
@@ -583,6 +588,13 @@ def get_dataloader_from_datasets(
|
|||||||
|
|
||||||
# check if is caching latents
|
# check if is caching latents
|
||||||
|
|
||||||
|
dataloader_kwargs = {}
|
||||||
|
|
||||||
|
if is_native_windows():
|
||||||
|
dataloader_kwargs['num_workers'] = 0
|
||||||
|
else:
|
||||||
|
dataloader_kwargs['num_workers'] = dataset_config_list[0].num_workers
|
||||||
|
dataloader_kwargs['prefetch_factor'] = dataset_config_list[0].prefetch_factor
|
||||||
|
|
||||||
if has_buckets:
|
if has_buckets:
|
||||||
# make sure they all have buckets
|
# make sure they all have buckets
|
||||||
@@ -595,16 +607,15 @@ def get_dataloader_from_datasets(
|
|||||||
drop_last=False,
|
drop_last=False,
|
||||||
shuffle=True,
|
shuffle=True,
|
||||||
collate_fn=dto_collation, # Use the custom collate function
|
collate_fn=dto_collation, # Use the custom collate function
|
||||||
num_workers=dataset_config_list[0].num_workers,
|
**dataloader_kwargs
|
||||||
prefetch_factor=dataset_config_list[0].prefetch_factor,
|
|
||||||
)
|
)
|
||||||
else:
|
else:
|
||||||
data_loader = DataLoader(
|
data_loader = DataLoader(
|
||||||
concatenated_dataset,
|
concatenated_dataset,
|
||||||
batch_size=batch_size,
|
batch_size=batch_size,
|
||||||
shuffle=True,
|
shuffle=True,
|
||||||
num_workers=4,
|
collate_fn=dto_collation,
|
||||||
collate_fn=dto_collation
|
**dataloader_kwargs
|
||||||
)
|
)
|
||||||
return data_loader
|
return data_loader
|
||||||
|
|
||||||
|
|||||||
Reference in New Issue
Block a user