mirror of
https://github.com/ostris/ai-toolkit.git
synced 2026-04-24 16:29:26 +00:00
Bug fix with dataloader. Added a flag to completly disable sampling
This commit is contained in:
@@ -1570,12 +1570,12 @@ class BaseSDTrainProcess(BaseTrainProcess):
|
||||
### HOOK ###
|
||||
self.hook_before_train_loop()
|
||||
|
||||
if self.has_first_sample_requested and self.step_num <= 1:
|
||||
if self.has_first_sample_requested and self.step_num <= 1 and not self.train_config.disable_sampling:
|
||||
self.print("Generating first sample from first sample config")
|
||||
self.sample(0, is_first=True)
|
||||
|
||||
# sample first
|
||||
if self.train_config.skip_first_sample:
|
||||
if self.train_config.skip_first_sample or self.train_config.disable_sampling:
|
||||
self.print("Skipping first sample due to config setting")
|
||||
elif self.step_num <= 1 or self.train_config.force_first_sample:
|
||||
self.print("Generating baseline samples before training")
|
||||
@@ -1643,6 +1643,8 @@ class BaseSDTrainProcess(BaseTrainProcess):
|
||||
is_reg_step = False
|
||||
is_save_step = self.save_config.save_every and self.step_num % self.save_config.save_every == 0
|
||||
is_sample_step = self.sample_config.sample_every and self.step_num % self.sample_config.sample_every == 0
|
||||
if self.train_config.disable_sampling:
|
||||
is_sample_step = False
|
||||
# don't do a reg step on sample or save steps as we dont want to normalize on those
|
||||
if step % 2 == 0 and dataloader_reg is not None and not is_save_step and not is_sample_step:
|
||||
try:
|
||||
@@ -1784,7 +1786,8 @@ class BaseSDTrainProcess(BaseTrainProcess):
|
||||
self.progress_bar.close()
|
||||
if self.train_config.free_u:
|
||||
self.sd.pipeline.disable_freeu()
|
||||
self.sample(self.step_num)
|
||||
if not self.train_config.disable_sampling:
|
||||
self.sample(self.step_num)
|
||||
print("")
|
||||
self.save()
|
||||
|
||||
|
||||
Reference in New Issue
Block a user