Bug fixes, speed improvements, compatability adjustments withdiffusers updates

This commit is contained in:
Jaret Burkett
2023-09-13 07:03:53 -06:00
parent d8d1e6fd1e
commit ae70200d3c
8 changed files with 52 additions and 11 deletions

View File

@@ -1,8 +1,10 @@
import random
from datetime import datetime
import os
from collections import OrderedDict
from typing import TYPE_CHECKING, Union
import torch
import yaml
from jobs.process.BaseProcess import BaseProcess
@@ -28,6 +30,14 @@ class BaseTrainProcess(BaseProcess):
self.job: Union['TrainJob', 'BaseJob', 'ExtensionJob']
self.progress_bar: 'tqdm' = None
self.training_seed = self.get_conf('training_seed', self.job.training_seed if hasattr(self.job, 'training_seed') else None)
# if training seed is set, use it
if self.training_seed is not None:
torch.manual_seed(self.training_seed)
if torch.cuda.is_available():
torch.cuda.manual_seed(self.training_seed)
random.seed(self.training_seed)
self.progress_bar = None
self.writer = None
self.training_folder = self.get_conf('training_folder',