Memory optimizations. Default to using cudamalloc when torch 2.0 for mem allocation

This commit is contained in:
Jaret Burkett
2023-09-12 04:30:23 -06:00
parent e8583860ad
commit d74dd636ee
5 changed files with 104 additions and 5 deletions

View File

@@ -813,6 +813,10 @@ class BaseSDTrainProcess(BaseTrainProcess):
if isinstance(batch, DataLoaderBatchDTO):
batch.cleanup()
# flush every 10 steps
if self.step_num % 10 == 0:
flush()
self.progress_bar.close()
self.sample(self.step_num + 1)
print("")