Added support for training vision direct weight adapters

This commit is contained in:
Jaret Burkett
2024-09-05 10:11:44 -06:00
parent 5c8fcc8a4e
commit 3a1f464132
3 changed files with 78 additions and 16 deletions

View File

@@ -1584,6 +1584,8 @@ class SDTrainer(BaseSDTrainProcess):
self.scaler.update()
self.optimizer.zero_grad(set_to_none=True)
if self.adapter and isinstance(self.adapter, CustomAdapter):
self.adapter.post_weight_update()
if self.ema is not None:
with self.timer('ema_update'):
self.ema.update()