mirror of
https://github.com/ostris/ai-toolkit.git
synced 2026-05-01 11:41:35 +00:00
Default to only training mse. Did a lot of cleanup with script. Added logging via tensorboard.
This commit is contained in:
@@ -28,6 +28,9 @@ class TrainJob(BaseJob):
|
||||
self.mixed_precision = self.get_conf('mixed_precision', False) # fp16
|
||||
self.logging_dir = self.get_conf('logging_dir', None)
|
||||
|
||||
self.writer = None
|
||||
self.setup_tensorboard()
|
||||
|
||||
# loads the processes from the config
|
||||
self.load_processes(process_dict)
|
||||
|
||||
@@ -38,3 +41,11 @@ class TrainJob(BaseJob):
|
||||
|
||||
for process in self.process:
|
||||
process.run()
|
||||
|
||||
def setup_tensorboard(self):
|
||||
if self.logging_dir:
|
||||
from torch.utils.tensorboard import SummaryWriter
|
||||
self.writer = SummaryWriter(
|
||||
log_dir=self.logging_dir,
|
||||
filename_suffix=f"_{self.name}"
|
||||
)
|
||||
|
||||
Reference in New Issue
Block a user