mirror of
https://github.com/ostris/ai-toolkit.git
synced 2026-02-22 21:33:59 +00:00
Merge branch 'main' of github.com:ostris/ai-toolkit
This commit is contained in:
@@ -39,7 +39,7 @@ cd ai-toolkit
|
||||
git submodule update --init --recursive
|
||||
python -m venv venv
|
||||
.\venv\Scripts\activate
|
||||
pip install torch --use-pep517 --extra-index-url https://download.pytorch.org/whl/cu118
|
||||
pip install torch torchvision --index-url https://download.pytorch.org/whl/cu121
|
||||
pip install -r requirements.txt
|
||||
```
|
||||
|
||||
|
||||
@@ -31,7 +31,6 @@ config:
|
||||
shuffle_tokens: false # shuffle caption order, split by commas
|
||||
cache_latents_to_disk: true # leave this true unless you know what you're doing
|
||||
resolution: [ 512, 768, 1024 ] # flux enjoys multiple resolutions
|
||||
num_workers: 0
|
||||
train:
|
||||
batch_size: 1
|
||||
steps: 4000 # total number of steps to train; 500 - 4000 is a good range
|
||||
|
||||
@@ -21,6 +21,11 @@ from toolkit.config_modules import DatasetConfig, preprocess_dataset_raw_config
|
||||
from toolkit.dataloader_mixins import CaptionMixin, BucketsMixin, LatentCachingMixin, Augments, CLIPCachingMixin
|
||||
from toolkit.data_transfer_object.data_loader import FileItemDTO, DataLoaderBatchDTO
|
||||
|
||||
import platform
|
||||
|
||||
def is_native_windows() -> bool:
    """Return True when the interpreter reports a native Windows host.

    The result is True only if ``platform.system()`` is ``"Windows"`` and
    ``platform.release()`` is not the string ``"2"``. The dataloader setup
    in this file uses the result to force ``num_workers`` to 0.

    Returns:
        bool: whether the current platform is treated as native Windows.
    """
    # NOTE(review): the release != "2" clause presumably tries to exclude a
    # WSL-style environment; confirm it is still needed, since the system
    # name comparison alone already rejects non-Windows hosts.
    return platform.system() == "Windows" and platform.release() != "2"
|
||||
|
||||
if TYPE_CHECKING:
|
||||
from toolkit.stable_diffusion_model import StableDiffusion
|
||||
|
||||
@@ -583,6 +588,13 @@ def get_dataloader_from_datasets(
|
||||
|
||||
# check if latents are being cached
|
||||
|
||||
dataloader_kwargs = {}
|
||||
|
||||
if is_native_windows():
|
||||
dataloader_kwargs['num_workers'] = 0
|
||||
else:
|
||||
dataloader_kwargs['num_workers'] = dataset_config_list[0].num_workers
|
||||
dataloader_kwargs['prefetch_factor'] = dataset_config_list[0].prefetch_factor
|
||||
|
||||
if has_buckets:
|
||||
# make sure they all have buckets
|
||||
@@ -595,16 +607,15 @@ def get_dataloader_from_datasets(
|
||||
drop_last=False,
|
||||
shuffle=True,
|
||||
collate_fn=dto_collation, # Use the custom collate function
|
||||
num_workers=dataset_config_list[0].num_workers,
|
||||
prefetch_factor=dataset_config_list[0].prefetch_factor,
|
||||
**dataloader_kwargs
|
||||
)
|
||||
else:
|
||||
data_loader = DataLoader(
|
||||
concatenated_dataset,
|
||||
batch_size=batch_size,
|
||||
shuffle=True,
|
||||
num_workers=4,
|
||||
collate_fn=dto_collation
|
||||
collate_fn=dto_collation,
|
||||
**dataloader_kwargs
|
||||
)
|
||||
return data_loader
|
||||
|
||||
|
||||
@@ -1960,7 +1960,8 @@ class StableDiffusion:
|
||||
else:
|
||||
latents = self.vae.encode(images).latent_dist.sample()
|
||||
# latents = self.vae.encode(images, return_dict=False)[0]
|
||||
latents = latents * (self.vae.config['scaling_factor'] - self.vae.config['shift_factor'])
|
||||
shift = self.vae.config['shift_factor'] if self.vae.config['shift_factor'] is not None else 0
|
||||
latents = latents * (self.vae.config['scaling_factor'] - shift)
|
||||
latents = latents.to(device, dtype=dtype)
|
||||
|
||||
return latents
|
||||
|
||||
Reference in New Issue
Block a user