Bug fixes and improvements to token injection

This commit is contained in:
Jaret Burkett
2023-09-08 06:10:59 -06:00
parent 92a086d5a5
commit ce4f9fe02a
5 changed files with 74 additions and 63 deletions

View File

@@ -684,6 +684,7 @@ class BaseSDTrainProcess(BaseTrainProcess):
# zero any gradients
optimizer.zero_grad()
flush()
self.lr_scheduler.step(self.step_num)
@@ -721,6 +722,15 @@ class BaseSDTrainProcess(BaseTrainProcess):
### HOOK ###
loss_dict = self.hook_train_loop(batch)
flush()
# setup the networks to gradient checkpointing and everything works
if self.embedding is not None or self.train_config.train_text_encoder:
if isinstance(self.sd.text_encoder, list):
for te in self.sd.text_encoder:
te.train()
else:
self.sd.text_encoder.train()
self.sd.unet.train()
with torch.no_grad():
if self.train_config.optimizer.lower().startswith('dadaptation') or \