mirror of
https://github.com/ostris/ai-toolkit.git
synced 2026-01-26 16:39:47 +00:00
Added additional config options for custom plugins I needed
This commit is contained in:
@@ -333,6 +333,9 @@ class BaseSDTrainProcess(BaseTrainProcess):
|
||||
# remove all but the latest max_step_saves_to_keep
|
||||
# items_to_remove = combined_items[:-self.save_config.max_step_saves_to_keep]
|
||||
|
||||
# remove duplicates
|
||||
items_to_remove = list(dict.fromkeys(items_to_remove))
|
||||
|
||||
for item in items_to_remove:
|
||||
self.print(f"Removing old save: {item}")
|
||||
if os.path.isdir(item):
|
||||
@@ -758,7 +761,7 @@ class BaseSDTrainProcess(BaseTrainProcess):
|
||||
do_double = False
|
||||
|
||||
with self.timer('prepare_noise'):
|
||||
num_train_timesteps = self.sd.noise_scheduler.config['num_train_timesteps']
|
||||
num_train_timesteps = self.train_config.num_train_timesteps
|
||||
|
||||
if self.train_config.noise_scheduler in ['custom_lcm']:
|
||||
# we store this value on our custom one
|
||||
@@ -791,14 +794,14 @@ class BaseSDTrainProcess(BaseTrainProcess):
|
||||
orig_timesteps = torch.rand((batch_size,), device=latents.device)
|
||||
|
||||
if content_or_style == 'content':
|
||||
timestep_indices = orig_timesteps ** 3 * self.sd.noise_scheduler.config['num_train_timesteps']
|
||||
timestep_indices = orig_timesteps ** 3 * self.train_config.num_train_timesteps
|
||||
elif content_or_style == 'style':
|
||||
timestep_indices = (1 - orig_timesteps ** 3) * self.sd.noise_scheduler.config['num_train_timesteps']
|
||||
timestep_indices = (1 - orig_timesteps ** 3) * self.train_config.num_train_timesteps
|
||||
|
||||
timestep_indices = value_map(
|
||||
timestep_indices,
|
||||
0,
|
||||
self.sd.noise_scheduler.config['num_train_timesteps'] - 1,
|
||||
self.train_config.num_train_timesteps - 1,
|
||||
min_noise_steps,
|
||||
max_noise_steps - 1
|
||||
)
|
||||
@@ -1234,6 +1237,9 @@ class BaseSDTrainProcess(BaseTrainProcess):
|
||||
# load last saved weights
|
||||
if latest_save_path is not None:
|
||||
self.embedding.load_embedding_from_file(latest_save_path, self.device_torch)
|
||||
if self.embedding.step > 1:
|
||||
self.step_num = self.embedding.step
|
||||
self.start_step = self.step_num
|
||||
|
||||
# self.step_num = self.embedding.step
|
||||
# self.start_step = self.step_num
|
||||
|
||||
Reference in New Issue
Block a user