Default to only training mse. Did a lot of cleanup with script. Added logging via tensorboard.

This commit is contained in:
Jaret Burkett
2023-07-18 09:40:51 -06:00
parent 94d52572d4
commit 17c13eef88
3 changed files with 162 additions and 77 deletions

View File

@@ -28,6 +28,9 @@ class TrainJob(BaseJob):
self.mixed_precision = self.get_conf('mixed_precision', False) # fp16
self.logging_dir = self.get_conf('logging_dir', None)
self.writer = None
self.setup_tensorboard()
# loads the processes from the config
self.load_processes(process_dict)
@@ -38,3 +41,11 @@ class TrainJob(BaseJob):
for process in self.process:
process.run()
def setup_tensorboard(self):
if self.logging_dir:
from torch.utils.tensorboard import SummaryWriter
self.writer = SummaryWriter(
log_dir=self.logging_dir,
filename_suffix=f"_{self.name}"
)