Bug fix with dataloader. Added a flag to completly disable sampling

This commit is contained in:
Jaret Burkett
2024-08-12 09:19:40 -06:00
parent 89d61a3b8e
commit af108bb964
3 changed files with 11 additions and 4 deletions

View File

@@ -1570,12 +1570,12 @@ class BaseSDTrainProcess(BaseTrainProcess):
### HOOK ###
self.hook_before_train_loop()
if self.has_first_sample_requested and self.step_num <= 1:
if self.has_first_sample_requested and self.step_num <= 1 and not self.train_config.disable_sampling:
self.print("Generating first sample from first sample config")
self.sample(0, is_first=True)
# sample first
if self.train_config.skip_first_sample:
if self.train_config.skip_first_sample or self.train_config.disable_sampling:
self.print("Skipping first sample due to config setting")
elif self.step_num <= 1 or self.train_config.force_first_sample:
self.print("Generating baseline samples before training")
@@ -1643,6 +1643,8 @@ class BaseSDTrainProcess(BaseTrainProcess):
is_reg_step = False
is_save_step = self.save_config.save_every and self.step_num % self.save_config.save_every == 0
is_sample_step = self.sample_config.sample_every and self.step_num % self.sample_config.sample_every == 0
if self.train_config.disable_sampling:
is_sample_step = False
# don't do a reg step on sample or save steps as we dont want to normalize on those
if step % 2 == 0 and dataloader_reg is not None and not is_save_step and not is_sample_step:
try:
@@ -1784,7 +1786,8 @@ class BaseSDTrainProcess(BaseTrainProcess):
self.progress_bar.close()
if self.train_config.free_u:
self.sd.pipeline.disable_freeu()
self.sample(self.step_num)
if not self.train_config.disable_sampling:
self.sample(self.step_num)
print("")
self.save()