diff --git a/README.md b/README.md index d2c21dc4..f60597cc 100644 --- a/README.md +++ b/README.md @@ -39,7 +39,7 @@ cd ai-toolkit git submodule update --init --recursive python -m venv venv .\venv\Scripts\activate -pip install torch --use-pep517 --extra-index-url https://download.pytorch.org/whl/cu118 +pip install torch torchvision --index-url https://download.pytorch.org/whl/cu121 pip install -r requirements.txt ``` diff --git a/config/examples/train_lora_flux_24gb.yaml b/config/examples/train_lora_flux_24gb.yaml index 74eceae9..8f862848 100644 --- a/config/examples/train_lora_flux_24gb.yaml +++ b/config/examples/train_lora_flux_24gb.yaml @@ -31,7 +31,6 @@ config: shuffle_tokens: false # shuffle caption order, split by commas cache_latents_to_disk: true # leave this true unless you know what you're doing resolution: [ 512, 768, 1024 ] # flux enjoys multiple resolutions - num_workers: 0 train: batch_size: 1 steps: 4000 # total number of steps to train 500 - 4000 is a good range diff --git a/toolkit/data_loader.py b/toolkit/data_loader.py index 28099807..8e5186dd 100644 --- a/toolkit/data_loader.py +++ b/toolkit/data_loader.py @@ -21,6 +21,11 @@ from toolkit.config_modules import DatasetConfig, preprocess_dataset_raw_config from toolkit.dataloader_mixins import CaptionMixin, BucketsMixin, LatentCachingMixin, Augments, CLIPCachingMixin from toolkit.data_transfer_object.data_loader import FileItemDTO, DataLoaderBatchDTO +import platform + +def is_native_windows(): + return platform.system() == "Windows" and platform.release() != "2" + if TYPE_CHECKING: from toolkit.stable_diffusion_model import StableDiffusion @@ -583,6 +588,13 @@ def get_dataloader_from_datasets( # check if is caching latents + dataloader_kwargs = {} + + if is_native_windows(): + dataloader_kwargs['num_workers'] = 0 + else: + dataloader_kwargs['num_workers'] = dataset_config_list[0].num_workers + if dataloader_kwargs['num_workers'] > 0: dataloader_kwargs['prefetch_factor'] = dataset_config_list[0].prefetch_factor if has_buckets: # make
sure they all have buckets @@ -595,16 +607,15 @@ def get_dataloader_from_datasets( drop_last=False, shuffle=True, collate_fn=dto_collation, # Use the custom collate function - num_workers=dataset_config_list[0].num_workers, - prefetch_factor=dataset_config_list[0].prefetch_factor, + **dataloader_kwargs ) else: data_loader = DataLoader( concatenated_dataset, batch_size=batch_size, shuffle=True, - num_workers=4, - collate_fn=dto_collation + collate_fn=dto_collation, + **dataloader_kwargs ) return data_loader diff --git a/toolkit/stable_diffusion_model.py b/toolkit/stable_diffusion_model.py index a5f59693..35c0a37c 100644 --- a/toolkit/stable_diffusion_model.py +++ b/toolkit/stable_diffusion_model.py @@ -1960,7 +1960,8 @@ class StableDiffusion: else: latents = self.vae.encode(images).latent_dist.sample() # latents = self.vae.encode(images, return_dict=False)[0] - latents = latents * (self.vae.config['scaling_factor'] - self.vae.config['shift_factor']) + shift = self.vae.config['shift_factor'] if self.vae.config['shift_factor'] is not None else 0 + latents = latents * (self.vae.config['scaling_factor'] - shift) latents = latents.to(device, dtype=dtype) return latents