Added some features for an LCM condenser plugin

Jaret Burkett
2023-11-15 08:56:45 -07:00
parent 4f9cdd916a
commit e47006ed70
4 changed files with 80 additions and 16 deletions


@@ -656,9 +656,14 @@ class BaseSDTrainProcess(BaseTrainProcess):
         with self.timer('prepare_noise'):
-            self.sd.noise_scheduler.set_timesteps(
-                1000, device=self.device_torch
-            )
+            if self.train_config.noise_scheduler == 'lcm':
+                self.sd.noise_scheduler.set_timesteps(
+                    1000, device=self.device_torch, original_inference_steps=1000
+                )
+            else:
+                self.sd.noise_scheduler.set_timesteps(
+                    1000, device=self.device_torch
+                )
             # if self.train_config.timestep_sampling == 'style' or self.train_config.timestep_sampling == 'content':
             if self.train_config.content_or_style in ['style', 'content']:
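A minimal standalone sketch of the scheduler branch above, assuming plain diffusers schedulers stand in for self.sd.noise_scheduler (the repo's wrapper may differ) and a hypothetical noise_scheduler_name in place of train_config.noise_scheduler. LCMScheduler.set_timesteps accepts original_inference_steps, so passing 1000 spreads the LCM timesteps over the full training schedule rather than the scheduler's default 50-step subset.

# Hedged sketch: diffusers schedulers stand in for self.sd.noise_scheduler
import torch
from diffusers import DDPMScheduler, LCMScheduler

device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
noise_scheduler_name = 'lcm'  # hypothetical stand-in for train_config.noise_scheduler

if noise_scheduler_name == 'lcm':
    noise_scheduler = LCMScheduler()
    # original_inference_steps=1000 keeps the full 1000-step range instead of
    # the default LCM subset of original inference steps
    noise_scheduler.set_timesteps(1000, device=device, original_inference_steps=1000)
else:
    noise_scheduler = DDPMScheduler()
    noise_scheduler.set_timesteps(1000, device=device)

print(noise_scheduler.timesteps[:5])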
@@ -1136,6 +1141,8 @@ class BaseSDTrainProcess(BaseTrainProcess):
         optimizer_state_file_path = os.path.join(self.save_root, optimizer_state_filename)
         if os.path.exists(optimizer_state_file_path):
             # try to load
+            # previous param groups
+            # previous_params = copy.deepcopy(optimizer.param_groups)
             try:
                 print(f"Loading optimizer state from {optimizer_state_file_path}")
                 optimizer_state_dict = torch.load(optimizer_state_file_path)
@@ -1144,6 +1151,9 @@ class BaseSDTrainProcess(BaseTrainProcess):
print(f"Failed to load optimizer state from {optimizer_state_file_path}")
print(e)
# Update the learning rates if they changed
# optimizer.param_groups = previous_params
lr_scheduler_params = self.train_config.lr_scheduler_params
# make sure it had bare minimum
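The commented-out lines in these two hunks sketch snapshotting optimizer.param_groups before loading a saved optimizer state and putting the current values back afterward. A minimal, self-contained sketch of that idea follows; the model and file path are hypothetical, and it copies back only the learning rate (matching the "update the learning rates if they changed" note) rather than reassigning the whole param_groups list.

# Hedged sketch: restore config-defined learning rates after loading optimizer state
import os
import torch

model = torch.nn.Linear(4, 4)
optimizer = torch.optim.AdamW(model.parameters(), lr=1e-4)
optimizer_state_file_path = 'optimizer.pt'  # hypothetical path

if os.path.exists(optimizer_state_file_path):
    # remember the learning rates that came from the current config
    previous_lrs = [group['lr'] for group in optimizer.param_groups]
    try:
        print(f"Loading optimizer state from {optimizer_state_file_path}")
        optimizer_state_dict = torch.load(optimizer_state_file_path)
        optimizer.load_state_dict(optimizer_state_dict)
    except Exception as e:
        print(f"Failed to load optimizer state from {optimizer_state_file_path}")
        print(e)
    # update the learning rates if they changed since the state was saved
    for group, lr in zip(optimizer.param_groups, previous_lrs):
        group['lr'] = lr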