mirror of
https://github.com/ostris/ai-toolkit.git
synced 2026-04-29 02:31:17 +00:00
Base for loopback lora training setup, still working on proper sliders
This commit is contained in:
@@ -1,10 +1,14 @@
|
||||
import os
|
||||
from collections import OrderedDict
|
||||
from typing import ForwardRef
|
||||
|
||||
from jobs.process.BaseProcess import BaseProcess
|
||||
|
||||
|
||||
class BaseTrainProcess(BaseProcess):
|
||||
process_id: int
|
||||
config: OrderedDict
|
||||
progress_bar: ForwardRef('tqdm') = None
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
@@ -13,8 +17,23 @@ class BaseTrainProcess(BaseProcess):
|
||||
config: OrderedDict
|
||||
):
|
||||
super().__init__(process_id, job, config)
|
||||
self.progress_bar = None
|
||||
self.writer = self.job.writer
|
||||
self.training_folder = self.get_conf('training_folder', self.job.training_folder)
|
||||
self.save_root = os.path.join(self.training_folder, self.job.name)
|
||||
self.step = 0
|
||||
self.first_step = 0
|
||||
|
||||
def run(self):
|
||||
super().run()
|
||||
# implement in child class
|
||||
# be sure to call super().run() first
|
||||
pass
|
||||
|
||||
# def print(self, message, **kwargs):
|
||||
def print(self, *args):
|
||||
if self.progress_bar is not None:
|
||||
self.progress_bar.write(' '.join(map(str, args)))
|
||||
self.progress_bar.update()
|
||||
else:
|
||||
print(*args)
|
||||
|
||||
Reference in New Issue
Block a user