mirror of
https://github.com/ostris/ai-toolkit.git
synced 2026-04-30 19:21:39 +00:00
Bug fixes. Added some functionality to help with private extensions
This commit is contained in:
@@ -231,8 +231,8 @@ class BaseSDTrainProcess(BaseTrainProcess):
|
||||
if self.trigger_word is not None:
|
||||
# just so auto1111 will pick it up
|
||||
o_dict['ss_tag_frequency'] = {
|
||||
'actfig': {
|
||||
'actfig': 1
|
||||
[self.trigger_word ]: {
|
||||
[self.trigger_word ]: 1
|
||||
}
|
||||
}
|
||||
|
||||
@@ -827,14 +827,17 @@ class BaseSDTrainProcess(BaseTrainProcess):
|
||||
else: # no network, embedding or adapter
|
||||
# set the device state preset before getting params
|
||||
self.sd.set_device_state(self.train_device_state_preset)
|
||||
# will only return savable weights and ones with grad
|
||||
params = self.sd.prepare_optimizer_params(
|
||||
unet=self.train_config.train_unet,
|
||||
text_encoder=self.train_config.train_text_encoder,
|
||||
text_encoder_lr=self.train_config.lr,
|
||||
unet_lr=self.train_config.lr,
|
||||
default_lr=self.train_config.lr
|
||||
)
|
||||
|
||||
params = self.get_params()
|
||||
if not params:
|
||||
# will only return savable weights and ones with grad
|
||||
params = self.sd.prepare_optimizer_params(
|
||||
unet=self.train_config.train_unet,
|
||||
text_encoder=self.train_config.train_text_encoder,
|
||||
text_encoder_lr=self.train_config.lr,
|
||||
unet_lr=self.train_config.lr,
|
||||
default_lr=self.train_config.lr
|
||||
)
|
||||
# we may be using it for prompt injections
|
||||
if self.adapter_config is not None:
|
||||
self.setup_adapter()
|
||||
|
||||
Reference in New Issue
Block a user