From af108bb964bdeba4b77872ae0fd766dddb7b79ca Mon Sep 17 00:00:00 2001
From: Jaret Burkett
Date: Mon, 12 Aug 2024 09:19:40 -0600
Subject: [PATCH] Bug fix with dataloader. Added a flag to completely disable
 sampling

---
 jobs/process/BaseSDTrainProcess.py | 9 ++++++---
 toolkit/config_modules.py          | 1 +
 toolkit/data_loader.py             | 5 ++++-
 3 files changed, 11 insertions(+), 4 deletions(-)

diff --git a/jobs/process/BaseSDTrainProcess.py b/jobs/process/BaseSDTrainProcess.py
index d83c6081..474457d8 100644
--- a/jobs/process/BaseSDTrainProcess.py
+++ b/jobs/process/BaseSDTrainProcess.py
@@ -1570,12 +1570,12 @@ class BaseSDTrainProcess(BaseTrainProcess):
         ### HOOK ###
         self.hook_before_train_loop()
 
-        if self.has_first_sample_requested and self.step_num <= 1:
+        if self.has_first_sample_requested and self.step_num <= 1 and not self.train_config.disable_sampling:
             self.print("Generating first sample from first sample config")
             self.sample(0, is_first=True)
 
         # sample first
-        if self.train_config.skip_first_sample:
+        if self.train_config.skip_first_sample or self.train_config.disable_sampling:
             self.print("Skipping first sample due to config setting")
         elif self.step_num <= 1 or self.train_config.force_first_sample:
             self.print("Generating baseline samples before training")
@@ -1643,6 +1643,8 @@ class BaseSDTrainProcess(BaseTrainProcess):
             is_reg_step = False
             is_save_step = self.save_config.save_every and self.step_num % self.save_config.save_every == 0
             is_sample_step = self.sample_config.sample_every and self.step_num % self.sample_config.sample_every == 0
+            if self.train_config.disable_sampling:
+                is_sample_step = False
             # don't do a reg step on sample or save steps as we dont want to normalize on those
             if step % 2 == 0 and dataloader_reg is not None and not is_save_step and not is_sample_step:
                 try:
@@ -1784,7 +1786,8 @@ class BaseSDTrainProcess(BaseTrainProcess):
         self.progress_bar.close()
         if self.train_config.free_u:
             self.sd.pipeline.disable_freeu()
-        self.sample(self.step_num)
+        if not self.train_config.disable_sampling:
+            self.sample(self.step_num)
         print("")
         self.save()
 
diff --git a/toolkit/config_modules.py b/toolkit/config_modules.py
index a8041c5a..561d97e8 100644
--- a/toolkit/config_modules.py
+++ b/toolkit/config_modules.py
@@ -357,6 +357,7 @@ class TrainConfig:
         self.target_norm_std = kwargs.get('target_norm_std', None)
         self.target_norm_std_value = kwargs.get('target_norm_std_value', 1.0)
         self.linear_timesteps = kwargs.get('linear_timesteps', False)
+        self.disable_sampling = kwargs.get('disable_sampling', False)
 
 
 class ModelConfig:
diff --git a/toolkit/data_loader.py b/toolkit/data_loader.py
index 8e5186dd..1a842dd2 100644
--- a/toolkit/data_loader.py
+++ b/toolkit/data_loader.py
@@ -437,7 +437,10 @@ class AiToolkitDataset(LatentCachingMixin, CLIPCachingMixin, BucketsMixin, Capti
         # this might take a while
         print(f"Dataset: {self.dataset_path}")
         print(f" - Preprocessing image dimensions")
-        dataset_size_file = os.path.join(self.dataset_path, '.aitk_size.json')
+        dataset_folder = self.dataset_path
+        if not os.path.isdir(self.dataset_path):
+            dataset_folder = os.path.dirname(dataset_folder)
+        dataset_size_file = os.path.join(dataset_folder, '.aitk_size.json')
         if os.path.exists(dataset_size_file):
             with open(dataset_size_file, 'r') as f:
                 self.size_database = json.load(f)
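
Usage note (not part of the patch): the new flag is read from kwargs in
TrainConfig, like the other options around it, so any training config that
feeds TrainConfig can set it. A minimal sketch, assuming TrainConfig's
__init__ accepts arbitrary keyword arguments as the kwargs.get() calls in
the hunk above suggest:

    from toolkit.config_modules import TrainConfig

    # disable_sampling defaults to False when omitted (see kwargs.get above)
    train_config = TrainConfig(disable_sampling=True)

    # With the flag set, the training loop skips the first sample, forces
    # is_sample_step to False on every step, and skips the final sample.
    assert train_config.disable_sampling is True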
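
The data_loader.py change reads in isolation as the helper below; a minimal
sketch mirroring the patched lines, with a hypothetical function name and
example paths added for illustration:

    import os

    def size_cache_path(dataset_path: str) -> str:
        # The size cache must sit in a real folder. If dataset_path points
        # at a file rather than a directory, fall back to the file's parent
        # folder instead of joining onto the file path itself (the bug
        # being fixed here).
        dataset_folder = dataset_path
        if not os.path.isdir(dataset_path):
            dataset_folder = os.path.dirname(dataset_folder)
        return os.path.join(dataset_folder, '.aitk_size.json')

    # Hypothetical paths for illustration:
    # size_cache_path('/data/my_set')          -> '/data/my_set/.aitk_size.json'
    # size_cache_path('/data/my_set/list.txt') -> '/data/my_set/.aitk_size.json'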