Bug fixes

This commit is contained in:
Jaret Burkett
2023-07-29 13:39:57 -06:00
parent 2305e55c82
commit 9cdf2dd6e4
2 changed files with 12 additions and 8 deletions

View File

@@ -282,12 +282,15 @@ class BaseSDTrainProcess(BaseTrainProcess):
# prepare meta
save_meta = get_meta_for_safetensors(self.meta, self.job.name)
if self.network is not None:
prev_multiplier = self.network.multiplier
self.network.multiplier = 1.0
# TODO handle dreambooth, fine tuning, etc
self.network.save_weights(
file_path,
dtype=get_torch_dtype(self.save_config.dtype),
metadata=save_meta
)
self.network.multiplier = prev_multiplier
else:
self.sd.save(
file_path,
@@ -639,7 +642,7 @@ class BaseSDTrainProcess(BaseTrainProcess):
### HOOK ###
loss_dict = self.hook_train_loop()
if self.train_config.optimizer.startswith('dadaptation'):
if self.train_config.optimizer.lower().startswith('dadaptation'):
learning_rate = (
optimizer.param_groups[0]["d"] *
optimizer.param_groups[0]["lr"]