mirror of
https://github.com/ostris/ai-toolkit.git
synced 2026-04-28 10:11:14 +00:00
Bug fixes, speed improvements, compatability adjustments withdiffusers updates
This commit is contained in:
@@ -1,8 +1,10 @@
|
||||
import random
|
||||
from datetime import datetime
|
||||
import os
|
||||
from collections import OrderedDict
|
||||
from typing import TYPE_CHECKING, Union
|
||||
|
||||
import torch
|
||||
import yaml
|
||||
|
||||
from jobs.process.BaseProcess import BaseProcess
|
||||
@@ -28,6 +30,14 @@ class BaseTrainProcess(BaseProcess):
|
||||
self.job: Union['TrainJob', 'BaseJob', 'ExtensionJob']
|
||||
self.progress_bar: 'tqdm' = None
|
||||
|
||||
self.training_seed = self.get_conf('training_seed', self.job.training_seed if hasattr(self.job, 'training_seed') else None)
|
||||
# if training seed is set, use it
|
||||
if self.training_seed is not None:
|
||||
torch.manual_seed(self.training_seed)
|
||||
if torch.cuda.is_available():
|
||||
torch.cuda.manual_seed(self.training_seed)
|
||||
random.seed(self.training_seed)
|
||||
|
||||
self.progress_bar = None
|
||||
self.writer = None
|
||||
self.training_folder = self.get_conf('training_folder',
|
||||
|
||||
Reference in New Issue
Block a user