mirror of
https://github.com/ostris/ai-toolkit.git
synced 2026-03-03 17:49:49 +00:00
Bug fixes
This commit is contained in:
@@ -282,12 +282,15 @@ class BaseSDTrainProcess(BaseTrainProcess):
|
||||
# prepare meta
|
||||
save_meta = get_meta_for_safetensors(self.meta, self.job.name)
|
||||
if self.network is not None:
|
||||
prev_multiplier = self.network.multiplier
|
||||
self.network.multiplier = 1.0
|
||||
# TODO handle dreambooth, fine tuning, etc
|
||||
self.network.save_weights(
|
||||
file_path,
|
||||
dtype=get_torch_dtype(self.save_config.dtype),
|
||||
metadata=save_meta
|
||||
)
|
||||
self.network.multiplier = prev_multiplier
|
||||
else:
|
||||
self.sd.save(
|
||||
file_path,
|
||||
@@ -639,7 +642,7 @@ class BaseSDTrainProcess(BaseTrainProcess):
|
||||
### HOOK ###
|
||||
loss_dict = self.hook_train_loop()
|
||||
|
||||
if self.train_config.optimizer.startswith('dadaptation'):
|
||||
if self.train_config.optimizer.lower().startswith('dadaptation'):
|
||||
learning_rate = (
|
||||
optimizer.param_groups[0]["d"] *
|
||||
optimizer.param_groups[0]["lr"]
|
||||
|
||||
Reference in New Issue
Block a user