mirror of
https://github.com/ostris/ai-toolkit.git
synced 2026-05-01 03:31:35 +00:00
Moved some of the job config into base process so it will be easier to extend extensions
This commit is contained in:
@@ -21,6 +21,7 @@ process_dict = {
|
|||||||
'lora_hack': 'TrainLoRAHack',
|
'lora_hack': 'TrainLoRAHack',
|
||||||
'rescale_sd': 'TrainSDRescaleProcess',
|
'rescale_sd': 'TrainSDRescaleProcess',
|
||||||
'esrgan': 'TrainESRGANProcess',
|
'esrgan': 'TrainESRGANProcess',
|
||||||
|
'reference': 'TrainReferenceProcess',
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|
||||||
@@ -36,18 +37,9 @@ class TrainJob(BaseJob):
|
|||||||
# self.mixed_precision = self.get_conf('mixed_precision', False) # fp16
|
# self.mixed_precision = self.get_conf('mixed_precision', False) # fp16
|
||||||
self.log_dir = self.get_conf('log_dir', None)
|
self.log_dir = self.get_conf('log_dir', None)
|
||||||
|
|
||||||
self.writer = None
|
|
||||||
self.setup_tensorboard()
|
|
||||||
|
|
||||||
# loads the processes from the config
|
# loads the processes from the config
|
||||||
self.load_processes(process_dict)
|
self.load_processes(process_dict)
|
||||||
|
|
||||||
def save_training_config(self):
|
|
||||||
timestamp = datetime.now().strftime('%Y%m%d-%H%M%S')
|
|
||||||
os.makedirs(self.training_folder, exist_ok=True)
|
|
||||||
save_dif = os.path.join(self.training_folder, f'run_config_{timestamp}.yaml')
|
|
||||||
with open(save_dif, 'w') as f:
|
|
||||||
yaml.dump(self.raw_config, f)
|
|
||||||
|
|
||||||
def run(self):
|
def run(self):
|
||||||
super().run()
|
super().run()
|
||||||
@@ -56,12 +48,3 @@ class TrainJob(BaseJob):
|
|||||||
|
|
||||||
for process in self.process:
|
for process in self.process:
|
||||||
process.run()
|
process.run()
|
||||||
|
|
||||||
def setup_tensorboard(self):
|
|
||||||
if self.log_dir:
|
|
||||||
from torch.utils.tensorboard import SummaryWriter
|
|
||||||
now = datetime.now()
|
|
||||||
time_str = now.strftime('%Y%m%d-%H%M%S')
|
|
||||||
summary_name = f"{self.name}_{time_str}"
|
|
||||||
summary_dir = os.path.join(self.log_dir, summary_name)
|
|
||||||
self.writer = SummaryWriter(summary_dir)
|
|
||||||
|
|||||||
@@ -15,6 +15,8 @@ class BaseProcess(object):
|
|||||||
self.process_id = process_id
|
self.process_id = process_id
|
||||||
self.job = job
|
self.job = job
|
||||||
self.config = config
|
self.config = config
|
||||||
|
self.raw_process_config = config
|
||||||
|
self.name = self.get_conf('name', self.job.name)
|
||||||
self.meta = copy.deepcopy(self.job.meta)
|
self.meta = copy.deepcopy(self.job.meta)
|
||||||
print(json.dumps(self.config, indent=4))
|
print(json.dumps(self.config, indent=4))
|
||||||
|
|
||||||
|
|||||||
@@ -40,7 +40,6 @@ class BaseSDTrainProcess(BaseTrainProcess):
|
|||||||
self.network_config = NetworkConfig(**network_config)
|
self.network_config = NetworkConfig(**network_config)
|
||||||
else:
|
else:
|
||||||
self.network_config = None
|
self.network_config = None
|
||||||
self.training_folder = self.get_conf('training_folder', self.job.training_folder)
|
|
||||||
self.train_config = TrainConfig(**self.get_conf('train', {}))
|
self.train_config = TrainConfig(**self.get_conf('train', {}))
|
||||||
self.model_config = ModelConfig(**self.get_conf('model', {}))
|
self.model_config = ModelConfig(**self.get_conf('model', {}))
|
||||||
self.save_config = SaveConfig(**self.get_conf('save', {}))
|
self.save_config = SaveConfig(**self.get_conf('save', {}))
|
||||||
@@ -320,8 +319,6 @@ class BaseSDTrainProcess(BaseTrainProcess):
|
|||||||
unet.train()
|
unet.train()
|
||||||
params += unet.parameters()
|
params += unet.parameters()
|
||||||
|
|
||||||
# TODO recover save if training network. Maybe load from beginning
|
|
||||||
|
|
||||||
### HOOK ###
|
### HOOK ###
|
||||||
params = self.hook_add_extra_train_params(params)
|
params = self.hook_add_extra_train_params(params)
|
||||||
|
|
||||||
|
|||||||
@@ -1,14 +1,24 @@
|
|||||||
|
from datetime import datetime
|
||||||
import os
|
import os
|
||||||
from collections import OrderedDict
|
from collections import OrderedDict
|
||||||
from typing import ForwardRef
|
from typing import TYPE_CHECKING, Union
|
||||||
|
|
||||||
|
import yaml
|
||||||
|
|
||||||
from jobs.process.BaseProcess import BaseProcess
|
from jobs.process.BaseProcess import BaseProcess
|
||||||
|
|
||||||
|
if TYPE_CHECKING:
|
||||||
|
from jobs import TrainJob, BaseJob, ExtensionJob
|
||||||
|
from torch.utils.tensorboard import SummaryWriter
|
||||||
|
from tqdm import tqdm
|
||||||
|
|
||||||
|
|
||||||
class BaseTrainProcess(BaseProcess):
|
class BaseTrainProcess(BaseProcess):
|
||||||
process_id: int
|
process_id: int
|
||||||
config: OrderedDict
|
config: OrderedDict
|
||||||
progress_bar: ForwardRef('tqdm') = None
|
writer: 'SummaryWriter'
|
||||||
|
job: Union['TrainJob', 'BaseJob', 'ExtensionJob']
|
||||||
|
progress_bar: 'tqdm' = None
|
||||||
|
|
||||||
def __init__(
|
def __init__(
|
||||||
self,
|
self,
|
||||||
@@ -18,11 +28,14 @@ class BaseTrainProcess(BaseProcess):
|
|||||||
):
|
):
|
||||||
super().__init__(process_id, job, config)
|
super().__init__(process_id, job, config)
|
||||||
self.progress_bar = None
|
self.progress_bar = None
|
||||||
self.writer = self.job.writer
|
self.writer = None
|
||||||
self.training_folder = self.get_conf('training_folder', self.job.training_folder)
|
self.training_folder = self.get_conf('training_folder', self.job.training_folder)
|
||||||
self.save_root = os.path.join(self.training_folder, self.job.name)
|
self.save_root = os.path.join(self.training_folder, self.job.name)
|
||||||
self.step = 0
|
self.step = 0
|
||||||
self.first_step = 0
|
self.first_step = 0
|
||||||
|
self.log_dir = self.get_conf('log_dir', self.job.log_dir)
|
||||||
|
self.setup_tensorboard()
|
||||||
|
self.save_training_config()
|
||||||
|
|
||||||
def run(self):
|
def run(self):
|
||||||
super().run()
|
super().run()
|
||||||
@@ -37,3 +50,19 @@ class BaseTrainProcess(BaseProcess):
|
|||||||
self.progress_bar.update()
|
self.progress_bar.update()
|
||||||
else:
|
else:
|
||||||
print(*args)
|
print(*args)
|
||||||
|
|
||||||
|
def setup_tensorboard(self):
|
||||||
|
if self.log_dir:
|
||||||
|
from torch.utils.tensorboard import SummaryWriter
|
||||||
|
now = datetime.now()
|
||||||
|
time_str = now.strftime('%Y%m%d-%H%M%S')
|
||||||
|
summary_name = f"{self.name}_{time_str}"
|
||||||
|
summary_dir = os.path.join(self.log_dir, summary_name)
|
||||||
|
self.writer = SummaryWriter(summary_dir)
|
||||||
|
|
||||||
|
def save_training_config(self):
|
||||||
|
timestamp = datetime.now().strftime('%Y%m%d-%H%M%S')
|
||||||
|
os.makedirs(self.training_folder, exist_ok=True)
|
||||||
|
save_dif = os.path.join(self.training_folder, f'process_config_{timestamp}.yaml')
|
||||||
|
with open(save_dif, 'w') as f:
|
||||||
|
yaml.dump(self.raw_process_config, f)
|
||||||
|
|||||||
@@ -52,18 +52,19 @@ class TrainSliderProcess(BaseSDTrainProcess):
|
|||||||
pass
|
pass
|
||||||
|
|
||||||
def hook_before_train_loop(self):
|
def hook_before_train_loop(self):
|
||||||
self.print(f"Loading prompt file from {self.slider_config.prompt_file}")
|
|
||||||
|
|
||||||
# read line by line from file
|
# read line by line from file
|
||||||
if self.slider_config.prompt_file:
|
if self.slider_config.prompt_file:
|
||||||
|
self.print(f"Loading prompt file from {self.slider_config.prompt_file}")
|
||||||
with open(self.slider_config.prompt_file, 'r', encoding='utf-8') as f:
|
with open(self.slider_config.prompt_file, 'r', encoding='utf-8') as f:
|
||||||
self.prompt_txt_list = f.readlines()
|
self.prompt_txt_list = f.readlines()
|
||||||
# clean empty lines
|
# clean empty lines
|
||||||
self.prompt_txt_list = [line.strip() for line in self.prompt_txt_list if len(line.strip()) > 0]
|
self.prompt_txt_list = [line.strip() for line in self.prompt_txt_list if len(line.strip()) > 0]
|
||||||
|
|
||||||
self.print(f"Loaded {len(self.prompt_txt_list)} prompts. Encoding them..")
|
self.print(f"Found {len(self.prompt_txt_list)} prompts.")
|
||||||
|
|
||||||
if not self.slider_config.prompt_tensors:
|
if not self.slider_config.prompt_tensors:
|
||||||
|
print(f"Prompt tensors not found. Building prompt tensors for {self.train_config.steps} steps.")
|
||||||
# shuffle
|
# shuffle
|
||||||
random.shuffle(self.prompt_txt_list)
|
random.shuffle(self.prompt_txt_list)
|
||||||
# trim to max steps
|
# trim to max steps
|
||||||
|
|||||||
@@ -56,6 +56,7 @@ class EncodedPromptPair:
|
|||||||
# simulate torch to for tensors
|
# simulate torch to for tensors
|
||||||
def to(self, *args, **kwargs):
|
def to(self, *args, **kwargs):
|
||||||
self.target_class = self.target_class.to(*args, **kwargs)
|
self.target_class = self.target_class.to(*args, **kwargs)
|
||||||
|
self.target_class_with_neutral = self.target_class_with_neutral.to(*args, **kwargs)
|
||||||
self.positive_target = self.positive_target.to(*args, **kwargs)
|
self.positive_target = self.positive_target.to(*args, **kwargs)
|
||||||
self.positive_target_with_neutral = self.positive_target_with_neutral.to(*args, **kwargs)
|
self.positive_target_with_neutral = self.positive_target_with_neutral.to(*args, **kwargs)
|
||||||
self.negative_target = self.negative_target.to(*args, **kwargs)
|
self.negative_target = self.negative_target.to(*args, **kwargs)
|
||||||
@@ -308,7 +309,7 @@ def build_prompt_pair_batch_from_cache(
|
|||||||
prompt_pair_batch = []
|
prompt_pair_batch = []
|
||||||
|
|
||||||
if both or erase_negative:
|
if both or erase_negative:
|
||||||
print("Encoding erase negative")
|
# print("Encoding erase negative")
|
||||||
prompt_pair_batch += [
|
prompt_pair_batch += [
|
||||||
# erase standard
|
# erase standard
|
||||||
EncodedPromptPair(
|
EncodedPromptPair(
|
||||||
@@ -327,7 +328,7 @@ def build_prompt_pair_batch_from_cache(
|
|||||||
),
|
),
|
||||||
]
|
]
|
||||||
if both or enhance_positive:
|
if both or enhance_positive:
|
||||||
print("Encoding enhance positive")
|
# print("Encoding enhance positive")
|
||||||
prompt_pair_batch += [
|
prompt_pair_batch += [
|
||||||
# enhance standard, swap pos neg
|
# enhance standard, swap pos neg
|
||||||
EncodedPromptPair(
|
EncodedPromptPair(
|
||||||
@@ -346,7 +347,7 @@ def build_prompt_pair_batch_from_cache(
|
|||||||
),
|
),
|
||||||
]
|
]
|
||||||
if both or enhance_positive:
|
if both or enhance_positive:
|
||||||
print("Encoding erase positive (inverse)")
|
# print("Encoding erase positive (inverse)")
|
||||||
prompt_pair_batch += [
|
prompt_pair_batch += [
|
||||||
# erase inverted
|
# erase inverted
|
||||||
EncodedPromptPair(
|
EncodedPromptPair(
|
||||||
@@ -365,7 +366,7 @@ def build_prompt_pair_batch_from_cache(
|
|||||||
),
|
),
|
||||||
]
|
]
|
||||||
if both or erase_negative:
|
if both or erase_negative:
|
||||||
print("Encoding enhance negative (inverse)")
|
# print("Encoding enhance negative (inverse)")
|
||||||
prompt_pair_batch += [
|
prompt_pair_batch += [
|
||||||
# enhance inverted
|
# enhance inverted
|
||||||
EncodedPromptPair(
|
EncodedPromptPair(
|
||||||
|
|||||||
Reference in New Issue
Block a user