mirror of
https://github.com/ostris/ai-toolkit.git
synced 2026-04-27 09:44:02 +00:00
Added some features for an LCM condenser plugin
This commit is contained in:
@@ -656,9 +656,14 @@ class BaseSDTrainProcess(BaseTrainProcess):
|
||||
|
||||
with self.timer('prepare_noise'):
|
||||
|
||||
self.sd.noise_scheduler.set_timesteps(
|
||||
1000, device=self.device_torch
|
||||
)
|
||||
if self.train_config.noise_scheduler == 'lcm':
|
||||
self.sd.noise_scheduler.set_timesteps(
|
||||
1000, device=self.device_torch, original_inference_steps=1000
|
||||
)
|
||||
else:
|
||||
self.sd.noise_scheduler.set_timesteps(
|
||||
1000, device=self.device_torch
|
||||
)
|
||||
|
||||
# if self.train_config.timestep_sampling == 'style' or self.train_config.timestep_sampling == 'content':
|
||||
if self.train_config.content_or_style in ['style', 'content']:
|
||||
@@ -1136,6 +1141,8 @@ class BaseSDTrainProcess(BaseTrainProcess):
|
||||
optimizer_state_file_path = os.path.join(self.save_root, optimizer_state_filename)
|
||||
if os.path.exists(optimizer_state_file_path):
|
||||
# try to load
|
||||
# previous param groups
|
||||
# previous_params = copy.deepcopy(optimizer.param_groups)
|
||||
try:
|
||||
print(f"Loading optimizer state from {optimizer_state_file_path}")
|
||||
optimizer_state_dict = torch.load(optimizer_state_file_path)
|
||||
@@ -1144,6 +1151,9 @@ class BaseSDTrainProcess(BaseTrainProcess):
|
||||
print(f"Failed to load optimizer state from {optimizer_state_file_path}")
|
||||
print(e)
|
||||
|
||||
# Update the learning rates if they changed
|
||||
# optimizer.param_groups = previous_params
|
||||
|
||||
lr_scheduler_params = self.train_config.lr_scheduler_params
|
||||
|
||||
# make sure it had bare minimum
|
||||
|
||||
Reference in New Issue
Block a user