Bug fixes. Added some functionality to help with private extensions

This commit is contained in:
Jaret Burkett
2023-10-05 07:09:34 -06:00
parent 579650eaf8
commit f73402473b
8 changed files with 99 additions and 20 deletions

View File

@@ -231,8 +231,8 @@ class BaseSDTrainProcess(BaseTrainProcess):
if self.trigger_word is not None:
# just so auto1111 will pick it up
o_dict['ss_tag_frequency'] = {
'actfig': {
'actfig': 1
[self.trigger_word ]: {
[self.trigger_word ]: 1
}
}
@@ -827,14 +827,17 @@ class BaseSDTrainProcess(BaseTrainProcess):
else: # no network, embedding or adapter
# set the device state preset before getting params
self.sd.set_device_state(self.train_device_state_preset)
# will only return savable weights and ones with grad
params = self.sd.prepare_optimizer_params(
unet=self.train_config.train_unet,
text_encoder=self.train_config.train_text_encoder,
text_encoder_lr=self.train_config.lr,
unet_lr=self.train_config.lr,
default_lr=self.train_config.lr
)
params = self.get_params()
if not params:
# will only return savable weights and ones with grad
params = self.sd.prepare_optimizer_params(
unet=self.train_config.train_unet,
text_encoder=self.train_config.train_text_encoder,
text_encoder_lr=self.train_config.lr,
unet_lr=self.train_config.lr,
default_lr=self.train_config.lr
)
# we may be using it for prompt injections
if self.adapter_config is not None:
self.setup_adapter()