Bug fix for dataloader. Added a flag to completely disable sampling

This commit is contained in:
Jaret Burkett
2024-08-12 09:19:40 -06:00
parent 89d61a3b8e
commit af108bb964
3 changed files with 11 additions and 4 deletions

View File

@@ -1570,12 +1570,12 @@ class BaseSDTrainProcess(BaseTrainProcess):
### HOOK ###
self.hook_before_train_loop()
if self.has_first_sample_requested and self.step_num <= 1:
if self.has_first_sample_requested and self.step_num <= 1 and not self.train_config.disable_sampling:
self.print("Generating first sample from first sample config")
self.sample(0, is_first=True)
# sample first
if self.train_config.skip_first_sample:
if self.train_config.skip_first_sample or self.train_config.disable_sampling:
self.print("Skipping first sample due to config setting")
elif self.step_num <= 1 or self.train_config.force_first_sample:
self.print("Generating baseline samples before training")
@@ -1643,6 +1643,8 @@ class BaseSDTrainProcess(BaseTrainProcess):
is_reg_step = False
is_save_step = self.save_config.save_every and self.step_num % self.save_config.save_every == 0
is_sample_step = self.sample_config.sample_every and self.step_num % self.sample_config.sample_every == 0
if self.train_config.disable_sampling:
is_sample_step = False
# don't do a reg step on sample or save steps as we dont want to normalize on those
if step % 2 == 0 and dataloader_reg is not None and not is_save_step and not is_sample_step:
try:
@@ -1784,7 +1786,8 @@ class BaseSDTrainProcess(BaseTrainProcess):
self.progress_bar.close()
if self.train_config.free_u:
self.sd.pipeline.disable_freeu()
self.sample(self.step_num)
if not self.train_config.disable_sampling:
self.sample(self.step_num)
print("")
self.save()

View File

@@ -357,6 +357,7 @@ class TrainConfig:
self.target_norm_std = kwargs.get('target_norm_std', None)
self.target_norm_std_value = kwargs.get('target_norm_std_value', 1.0)
self.linear_timesteps = kwargs.get('linear_timesteps', False)
self.disable_sampling = kwargs.get('disable_sampling', False)
class ModelConfig:

View File

@@ -437,7 +437,10 @@ class AiToolkitDataset(LatentCachingMixin, CLIPCachingMixin, BucketsMixin, Capti
# this might take a while
print(f"Dataset: {self.dataset_path}")
print(f" - Preprocessing image dimensions")
dataset_size_file = os.path.join(self.dataset_path, '.aitk_size.json')
dataset_folder = self.dataset_path
if not os.path.isdir(self.dataset_path):
dataset_folder = os.path.dirname(dataset_folder)
dataset_size_file = os.path.join(dataset_folder, '.aitk_size.json')
if os.path.exists(dataset_size_file):
with open(dataset_size_file, 'r') as f:
self.size_database = json.load(f)