From 8d48ad4e85874b3dca610a659e3e70233bc38ebc Mon Sep 17 00:00:00 2001 From: Jaret Burkett Date: Sun, 11 Aug 2024 10:28:39 -0600 Subject: [PATCH 1/3] fixed bug I added to demo config --- README.md | 2 +- config/examples/train_lora_flux_24gb.yaml | 1 - 2 files changed, 1 insertion(+), 2 deletions(-) diff --git a/README.md b/README.md index d2c21dc4..f60597cc 100644 --- a/README.md +++ b/README.md @@ -39,7 +39,7 @@ cd ai-toolkit git submodule update --init --recursive python -m venv venv .\venv\Scripts\activate -pip install torch --use-pep517 --extra-index-url https://download.pytorch.org/whl/cu118 +pip install torch torchvision --index-url https://download.pytorch.org/whl/cu121 pip install -r requirements.txt ``` diff --git a/config/examples/train_lora_flux_24gb.yaml b/config/examples/train_lora_flux_24gb.yaml index 74eceae9..8f862848 100644 --- a/config/examples/train_lora_flux_24gb.yaml +++ b/config/examples/train_lora_flux_24gb.yaml @@ -31,7 +31,6 @@ config: shuffle_tokens: false # shuffle caption order, split by commas cache_latents_to_disk: true # leave this true unless you know what you're doing resolution: [ 512, 768, 1024 ] # flux enjoys multiple resolutions - num_workers: 0 train: batch_size: 1 steps: 4000 # total number of steps to train 500 - 4000 is a good range From 6490a326e5b82c4773290c7df82b44b8491e5211 Mon Sep 17 00:00:00 2001 From: Jaret Burkett Date: Sun, 11 Aug 2024 10:30:55 -0600 Subject: [PATCH 2/3] Fixed issue for vaes without a shift --- toolkit/stable_diffusion_model.py | 3 ++- 1 file changed, 2 insertions(+), 1 deletion(-) diff --git a/toolkit/stable_diffusion_model.py b/toolkit/stable_diffusion_model.py index a5f59693..35c0a37c 100644 --- a/toolkit/stable_diffusion_model.py +++ b/toolkit/stable_diffusion_model.py @@ -1960,7 +1960,8 @@ class StableDiffusion: else: latents = self.vae.encode(images).latent_dist.sample() # latents = self.vae.encode(images, return_dict=False)[0] - latents = latents * (self.vae.config['scaling_factor'] - self.vae.config['shift_factor']) + shift = self.vae.config['shift_factor'] if self.vae.config['shift_factor'] is not None else 0 + latents = latents * (self.vae.config['scaling_factor'] - shift) latents = latents.to(device, dtype=dtype) return latents From 6d31c6db730a9cba3004eae1a3d7283c663d9295 Mon Sep 17 00:00:00 2001 From: Jaret Burkett Date: Sun, 11 Aug 2024 10:48:24 -0600 Subject: [PATCH 3/3] Added a fix for windows dataloader --- toolkit/data_loader.py | 19 +++++++++++++++---- 1 file changed, 15 insertions(+), 4 deletions(-) diff --git a/toolkit/data_loader.py b/toolkit/data_loader.py index 28099807..8e5186dd 100644 --- a/toolkit/data_loader.py +++ b/toolkit/data_loader.py @@ -21,6 +21,11 @@ from toolkit.config_modules import DatasetConfig, preprocess_dataset_raw_config from toolkit.dataloader_mixins import CaptionMixin, BucketsMixin, LatentCachingMixin, Augments, CLIPCachingMixin from toolkit.data_transfer_object.data_loader import FileItemDTO, DataLoaderBatchDTO +import platform + +def is_native_windows(): + return platform.system() == "Windows" and platform.release() != "2" + if TYPE_CHECKING: from toolkit.stable_diffusion_model import StableDiffusion @@ -583,6 +588,13 @@ def get_dataloader_from_datasets( # check if is caching latents + dataloader_kwargs = {} + + if is_native_windows(): + dataloader_kwargs['num_workers'] = 0 + else: + dataloader_kwargs['num_workers'] = dataset_config_list[0].num_workers + dataloader_kwargs['prefetch_factor'] = dataset_config_list[0].prefetch_factor if has_buckets: # make sure they all have buckets @@ -595,16 +607,15 @@ def get_dataloader_from_datasets( drop_last=False, shuffle=True, collate_fn=dto_collation, # Use the custom collate function - num_workers=dataset_config_list[0].num_workers, - prefetch_factor=dataset_config_list[0].prefetch_factor, + **dataloader_kwargs ) else: data_loader = DataLoader( concatenated_dataset, batch_size=batch_size, shuffle=True, - num_workers=4, - collate_fn=dto_collation + collate_fn=dto_collation, + **dataloader_kwargs ) return data_loader