mirror of
https://github.com/ostris/ai-toolkit.git
synced 2026-01-26 16:39:47 +00:00
Bug fix with dataloader. Added a flag to completly disable sampling
This commit is contained in:
@@ -1570,12 +1570,12 @@ class BaseSDTrainProcess(BaseTrainProcess):
|
||||
### HOOK ###
|
||||
self.hook_before_train_loop()
|
||||
|
||||
if self.has_first_sample_requested and self.step_num <= 1:
|
||||
if self.has_first_sample_requested and self.step_num <= 1 and not self.train_config.disable_sampling:
|
||||
self.print("Generating first sample from first sample config")
|
||||
self.sample(0, is_first=True)
|
||||
|
||||
# sample first
|
||||
if self.train_config.skip_first_sample:
|
||||
if self.train_config.skip_first_sample or self.train_config.disable_sampling:
|
||||
self.print("Skipping first sample due to config setting")
|
||||
elif self.step_num <= 1 or self.train_config.force_first_sample:
|
||||
self.print("Generating baseline samples before training")
|
||||
@@ -1643,6 +1643,8 @@ class BaseSDTrainProcess(BaseTrainProcess):
|
||||
is_reg_step = False
|
||||
is_save_step = self.save_config.save_every and self.step_num % self.save_config.save_every == 0
|
||||
is_sample_step = self.sample_config.sample_every and self.step_num % self.sample_config.sample_every == 0
|
||||
if self.train_config.disable_sampling:
|
||||
is_sample_step = False
|
||||
# don't do a reg step on sample or save steps as we dont want to normalize on those
|
||||
if step % 2 == 0 and dataloader_reg is not None and not is_save_step and not is_sample_step:
|
||||
try:
|
||||
@@ -1784,7 +1786,8 @@ class BaseSDTrainProcess(BaseTrainProcess):
|
||||
self.progress_bar.close()
|
||||
if self.train_config.free_u:
|
||||
self.sd.pipeline.disable_freeu()
|
||||
self.sample(self.step_num)
|
||||
if not self.train_config.disable_sampling:
|
||||
self.sample(self.step_num)
|
||||
print("")
|
||||
self.save()
|
||||
|
||||
|
||||
@@ -357,6 +357,7 @@ class TrainConfig:
|
||||
self.target_norm_std = kwargs.get('target_norm_std', None)
|
||||
self.target_norm_std_value = kwargs.get('target_norm_std_value', 1.0)
|
||||
self.linear_timesteps = kwargs.get('linear_timesteps', False)
|
||||
self.disable_sampling = kwargs.get('disable_sampling', False)
|
||||
|
||||
|
||||
class ModelConfig:
|
||||
|
||||
@@ -437,7 +437,10 @@ class AiToolkitDataset(LatentCachingMixin, CLIPCachingMixin, BucketsMixin, Capti
|
||||
# this might take a while
|
||||
print(f"Dataset: {self.dataset_path}")
|
||||
print(f" - Preprocessing image dimensions")
|
||||
dataset_size_file = os.path.join(self.dataset_path, '.aitk_size.json')
|
||||
dataset_folder = self.dataset_path
|
||||
if not os.path.isdir(self.dataset_path):
|
||||
dataset_folder = os.path.dirname(dataset_folder)
|
||||
dataset_size_file = os.path.join(dataset_folder, '.aitk_size.json')
|
||||
if os.path.exists(dataset_size_file):
|
||||
with open(dataset_size_file, 'r') as f:
|
||||
self.size_database = json.load(f)
|
||||
|
||||
Reference in New Issue
Block a user