Added additional config options for custom plugins I needed

This commit is contained in:
Jaret Burkett
2024-01-15 08:31:09 -07:00
parent e190fbaeb8
commit 5276975fb0
7 changed files with 37 additions and 31 deletions

View File

@@ -333,6 +333,9 @@ class BaseSDTrainProcess(BaseTrainProcess):
# remove all but the latest max_step_saves_to_keep
# items_to_remove = combined_items[:-self.save_config.max_step_saves_to_keep]
# remove duplicates
items_to_remove = list(dict.fromkeys(items_to_remove))
for item in items_to_remove:
self.print(f"Removing old save: {item}")
if os.path.isdir(item):
@@ -758,7 +761,7 @@ class BaseSDTrainProcess(BaseTrainProcess):
do_double = False
with self.timer('prepare_noise'):
num_train_timesteps = self.sd.noise_scheduler.config['num_train_timesteps']
num_train_timesteps = self.train_config.num_train_timesteps
if self.train_config.noise_scheduler in ['custom_lcm']:
# we store this value on our custom one
@@ -791,14 +794,14 @@ class BaseSDTrainProcess(BaseTrainProcess):
orig_timesteps = torch.rand((batch_size,), device=latents.device)
if content_or_style == 'content':
timestep_indices = orig_timesteps ** 3 * self.sd.noise_scheduler.config['num_train_timesteps']
timestep_indices = orig_timesteps ** 3 * self.train_config.num_train_timesteps
elif content_or_style == 'style':
timestep_indices = (1 - orig_timesteps ** 3) * self.sd.noise_scheduler.config['num_train_timesteps']
timestep_indices = (1 - orig_timesteps ** 3) * self.train_config.num_train_timesteps
timestep_indices = value_map(
timestep_indices,
0,
self.sd.noise_scheduler.config['num_train_timesteps'] - 1,
self.train_config.num_train_timesteps - 1,
min_noise_steps,
max_noise_steps - 1
)
@@ -1234,6 +1237,9 @@ class BaseSDTrainProcess(BaseTrainProcess):
# load last saved weights
if latest_save_path is not None:
self.embedding.load_embedding_from_file(latest_save_path, self.device_torch)
if self.embedding.step > 1:
self.step_num = self.embedding.step
self.start_step = self.step_num
# self.step_num = self.embedding.step
# self.start_step = self.step_num