mirror of
https://github.com/ostris/ai-toolkit.git
synced 2026-05-01 03:31:35 +00:00
Added specialized scaler training to ip adapters
This commit is contained in:
@@ -1293,11 +1293,18 @@ class BaseSDTrainProcess(BaseTrainProcess):
|
||||
if self.adapter_config is not None:
|
||||
self.setup_adapter()
|
||||
if self.adapter_config.train:
|
||||
# set trainable params
|
||||
params.append({
|
||||
'params': self.adapter.parameters(),
|
||||
'lr': self.train_config.adapter_lr
|
||||
})
|
||||
|
||||
if isinstance(self.adapter, IPAdapter):
|
||||
# we have custom LR groups for IPAdapter
|
||||
adapter_param_groups = self.adapter.get_parameter_groups(self.train_config.adapter_lr)
|
||||
for group in adapter_param_groups:
|
||||
params.append(group)
|
||||
else:
|
||||
# set trainable params
|
||||
params.append({
|
||||
'params': self.adapter.parameters(),
|
||||
'lr': self.train_config.adapter_lr
|
||||
})
|
||||
|
||||
if self.train_config.gradient_checkpointing:
|
||||
self.adapter.enable_gradient_checkpointing()
|
||||
|
||||
Reference in New Issue
Block a user