Added specialized scaler training to IP adapters

This commit is contained in:
Jaret Burkett
2024-04-05 08:17:09 -06:00
parent 427847ac4c
commit 7284aab7c0
7 changed files with 182 additions and 29 deletions

View File

@@ -1293,11 +1293,18 @@ class BaseSDTrainProcess(BaseTrainProcess):
if self.adapter_config is not None:
self.setup_adapter()
if self.adapter_config.train:
# set trainable params
params.append({
'params': self.adapter.parameters(),
'lr': self.train_config.adapter_lr
})
if isinstance(self.adapter, IPAdapter):
# we have custom LR groups for IPAdapter
adapter_param_groups = self.adapter.get_parameter_groups(self.train_config.adapter_lr)
for group in adapter_param_groups:
params.append(group)
else:
# set trainable params
params.append({
'params': self.adapter.parameters(),
'lr': self.train_config.adapter_lr
})
if self.train_config.gradient_checkpointing:
self.adapter.enable_gradient_checkpointing()