Added some further extendability for plugins

This commit is contained in:
Jaret Burkett
2023-09-19 05:41:44 -06:00
parent 61badf85a7
commit 0f105690cc
5 changed files with 70 additions and 65 deletions

View File

@@ -3,7 +3,7 @@ import glob
import inspect
from collections import OrderedDict
import os
from typing import Union
from typing import Union, List
from diffusers import T2IAdapter
# from lycoris.config import PRESET
@@ -116,34 +116,6 @@ class BaseSDTrainProcess(BaseTrainProcess):
if self.model_config.is_xl and not self.adapter_config.adapter_type.endswith('_xl'):
self.adapter_config.adapter_type += '_xl'
model_config_to_load = copy.deepcopy(self.model_config)
if self.embed_config is None and self.network_config is None and self.adapter_config is None:
# get the latest checkpoint
# check to see if we have a latest save
latest_save_path = self.get_latest_save_path()
if latest_save_path is not None:
print(f"#### IMPORTANT RESUMING FROM {latest_save_path} ####")
model_config_to_load.name_or_path = latest_save_path
meta = load_metadata_from_safetensors(latest_save_path)
# if 'training_info' in Orderdict keys
if 'training_info' in meta and 'step' in meta['training_info']:
self.step_num = meta['training_info']['step']
self.start_step = self.step_num
print(f"Found step {self.step_num} in metadata, starting from there")
# get the noise scheduler
sampler = get_sampler(self.train_config.noise_scheduler)
self.sd = StableDiffusion(
device=self.device,
model_config=model_config_to_load,
dtype=self.train_config.dtype,
custom_pipeline=self.custom_pipeline,
noise_scheduler=sampler,
)
# to hold network if there is one
self.network: Union[Network, None] = None
self.adapter: Union[T2IAdapter, None] = None
@@ -165,6 +137,13 @@ class BaseSDTrainProcess(BaseTrainProcess):
if self.network_config is not None or self.adapter_config is not None or self.embed_config is not None:
self.is_fine_tuning = False
self.named_lora = False
if self.embed_config is not None or self.adapter_config is not None:
self.named_lora = True
def post_process_generate_image_config_list(self, generate_image_config_list: List[GenerateImageConfig]):
# override in subclass
return generate_image_config_list
def sample(self, step=None, is_first=False):
sample_folder = os.path.join(self.save_root, 'samples')
gen_img_config_list = []
@@ -218,6 +197,9 @@ class BaseSDTrainProcess(BaseTrainProcess):
**extra_args
))
# post process
gen_img_config_list = self.post_process_generate_image_config_list(gen_img_config_list)
# send to be generated
self.sd.generate_images(gen_img_config_list, sampler=sample_config.sampler)
@@ -297,10 +279,10 @@ class BaseSDTrainProcess(BaseTrainProcess):
file_path = os.path.join(self.save_root, filename)
# prepare meta
save_meta = get_meta_for_safetensors(self.meta, self.job.name)
if self.network is not None or self.embedding is not None or self.adapter is not None:
if not self.is_fine_tuning:
if self.network is not None:
lora_name = self.job.name
if self.adapter_config is not None or self.embedding is not None:
if self.named_lora:
# add _lora to name
lora_name += '_LoRA'
@@ -438,6 +420,10 @@ class BaseSDTrainProcess(BaseTrainProcess):
# sigma = sigma.unsqueeze(-1)
# return sigma
def load_additional_training_modules(self, params):
# override in subclass
return params
def process_general_training_batch(self, batch: 'DataLoaderBatchDTO'):
with torch.no_grad():
prompts = batch.get_caption_list()
@@ -548,6 +534,33 @@ class BaseSDTrainProcess(BaseTrainProcess):
### HOOK ###
self.hook_before_model_load()
model_config_to_load = copy.deepcopy(self.model_config)
if self.is_fine_tuning:
# get the latest checkpoint
# check to see if we have a latest save
latest_save_path = self.get_latest_save_path()
if latest_save_path is not None:
print(f"#### IMPORTANT RESUMING FROM {latest_save_path} ####")
model_config_to_load.name_or_path = latest_save_path
meta = load_metadata_from_safetensors(latest_save_path)
# if 'training_info' in Orderdict keys
if 'training_info' in meta and 'step' in meta['training_info']:
self.step_num = meta['training_info']['step']
self.start_step = self.step_num
print(f"Found step {self.step_num} in metadata, starting from there")
# get the noise scheduler
sampler = get_sampler(self.train_config.noise_scheduler)
self.sd = StableDiffusion(
device=self.device,
model_config=model_config_to_load,
dtype=self.train_config.dtype,
custom_pipeline=self.custom_pipeline,
noise_scheduler=sampler,
)
# run base sd process run
self.sd.load_model()
@@ -611,7 +624,7 @@ class BaseSDTrainProcess(BaseTrainProcess):
self.data_loader_reg = get_dataloader_from_datasets(self.datasets_reg, self.train_config.batch_size,
self.sd)
params = []
if self.network_config is not None or self.adapter_config is not None or self.embed_config is not None:
if not self.is_fine_tuning:
if self.network_config is not None:
# TODO should we completely switch to LycorisSpecialNetwork?
@@ -678,7 +691,7 @@ class BaseSDTrainProcess(BaseTrainProcess):
lora_name = self.name
# need to adapt name so they are not mixed up
if self.adapter_config is not None or self.embedding is not None:
if self.named_lora:
lora_name = f"{lora_name}_LoRA"
latest_save_path = self.get_latest_save_path(lora_name)
@@ -758,6 +771,9 @@ class BaseSDTrainProcess(BaseTrainProcess):
})
self.sd.adapter = self.adapter
flush()
params = self.load_additional_training_modules(params)
else: # no network, embedding or adapter
# set the device state preset before getting params
self.sd.set_device_state(self.train_device_state_preset)