mirror of
https://github.com/ostris/ai-toolkit.git
synced 2026-01-26 16:39:47 +00:00
Added some further extendability for plugins
This commit is contained in:
@@ -3,7 +3,7 @@ import glob
|
||||
import inspect
|
||||
from collections import OrderedDict
|
||||
import os
|
||||
from typing import Union
|
||||
from typing import Union, List
|
||||
|
||||
from diffusers import T2IAdapter
|
||||
# from lycoris.config import PRESET
|
||||
@@ -116,34 +116,6 @@ class BaseSDTrainProcess(BaseTrainProcess):
|
||||
if self.model_config.is_xl and not self.adapter_config.adapter_type.endswith('_xl'):
|
||||
self.adapter_config.adapter_type += '_xl'
|
||||
|
||||
model_config_to_load = copy.deepcopy(self.model_config)
|
||||
|
||||
if self.embed_config is None and self.network_config is None and self.adapter_config is None:
|
||||
# get the latest checkpoint
|
||||
# check to see if we have a latest save
|
||||
latest_save_path = self.get_latest_save_path()
|
||||
|
||||
if latest_save_path is not None:
|
||||
print(f"#### IMPORTANT RESUMING FROM {latest_save_path} ####")
|
||||
model_config_to_load.name_or_path = latest_save_path
|
||||
meta = load_metadata_from_safetensors(latest_save_path)
|
||||
# if 'training_info' in Orderdict keys
|
||||
if 'training_info' in meta and 'step' in meta['training_info']:
|
||||
self.step_num = meta['training_info']['step']
|
||||
self.start_step = self.step_num
|
||||
print(f"Found step {self.step_num} in metadata, starting from there")
|
||||
|
||||
# get the noise scheduler
|
||||
sampler = get_sampler(self.train_config.noise_scheduler)
|
||||
|
||||
self.sd = StableDiffusion(
|
||||
device=self.device,
|
||||
model_config=model_config_to_load,
|
||||
dtype=self.train_config.dtype,
|
||||
custom_pipeline=self.custom_pipeline,
|
||||
noise_scheduler=sampler,
|
||||
)
|
||||
|
||||
# to hold network if there is one
|
||||
self.network: Union[Network, None] = None
|
||||
self.adapter: Union[T2IAdapter, None] = None
|
||||
@@ -165,6 +137,13 @@ class BaseSDTrainProcess(BaseTrainProcess):
|
||||
if self.network_config is not None or self.adapter_config is not None or self.embed_config is not None:
|
||||
self.is_fine_tuning = False
|
||||
|
||||
self.named_lora = False
|
||||
if self.embed_config is not None or self.adapter_config is not None:
|
||||
self.named_lora = True
|
||||
def post_process_generate_image_config_list(self, generate_image_config_list: List[GenerateImageConfig]):
|
||||
# override in subclass
|
||||
return generate_image_config_list
|
||||
|
||||
def sample(self, step=None, is_first=False):
|
||||
sample_folder = os.path.join(self.save_root, 'samples')
|
||||
gen_img_config_list = []
|
||||
@@ -218,6 +197,9 @@ class BaseSDTrainProcess(BaseTrainProcess):
|
||||
**extra_args
|
||||
))
|
||||
|
||||
# post process
|
||||
gen_img_config_list = self.post_process_generate_image_config_list(gen_img_config_list)
|
||||
|
||||
# send to be generated
|
||||
self.sd.generate_images(gen_img_config_list, sampler=sample_config.sampler)
|
||||
|
||||
@@ -297,10 +279,10 @@ class BaseSDTrainProcess(BaseTrainProcess):
|
||||
file_path = os.path.join(self.save_root, filename)
|
||||
# prepare meta
|
||||
save_meta = get_meta_for_safetensors(self.meta, self.job.name)
|
||||
if self.network is not None or self.embedding is not None or self.adapter is not None:
|
||||
if not self.is_fine_tuning:
|
||||
if self.network is not None:
|
||||
lora_name = self.job.name
|
||||
if self.adapter_config is not None or self.embedding is not None:
|
||||
if self.named_lora:
|
||||
# add _lora to name
|
||||
lora_name += '_LoRA'
|
||||
|
||||
@@ -438,6 +420,10 @@ class BaseSDTrainProcess(BaseTrainProcess):
|
||||
# sigma = sigma.unsqueeze(-1)
|
||||
# return sigma
|
||||
|
||||
def load_additional_training_modules(self, params):
|
||||
# override in subclass
|
||||
return params
|
||||
|
||||
def process_general_training_batch(self, batch: 'DataLoaderBatchDTO'):
|
||||
with torch.no_grad():
|
||||
prompts = batch.get_caption_list()
|
||||
@@ -548,6 +534,33 @@ class BaseSDTrainProcess(BaseTrainProcess):
|
||||
|
||||
### HOOK ###
|
||||
self.hook_before_model_load()
|
||||
model_config_to_load = copy.deepcopy(self.model_config)
|
||||
|
||||
if self.is_fine_tuning:
|
||||
# get the latest checkpoint
|
||||
# check to see if we have a latest save
|
||||
latest_save_path = self.get_latest_save_path()
|
||||
|
||||
if latest_save_path is not None:
|
||||
print(f"#### IMPORTANT RESUMING FROM {latest_save_path} ####")
|
||||
model_config_to_load.name_or_path = latest_save_path
|
||||
meta = load_metadata_from_safetensors(latest_save_path)
|
||||
# if 'training_info' in Orderdict keys
|
||||
if 'training_info' in meta and 'step' in meta['training_info']:
|
||||
self.step_num = meta['training_info']['step']
|
||||
self.start_step = self.step_num
|
||||
print(f"Found step {self.step_num} in metadata, starting from there")
|
||||
|
||||
# get the noise scheduler
|
||||
sampler = get_sampler(self.train_config.noise_scheduler)
|
||||
|
||||
self.sd = StableDiffusion(
|
||||
device=self.device,
|
||||
model_config=model_config_to_load,
|
||||
dtype=self.train_config.dtype,
|
||||
custom_pipeline=self.custom_pipeline,
|
||||
noise_scheduler=sampler,
|
||||
)
|
||||
# run base sd process run
|
||||
self.sd.load_model()
|
||||
|
||||
@@ -611,7 +624,7 @@ class BaseSDTrainProcess(BaseTrainProcess):
|
||||
self.data_loader_reg = get_dataloader_from_datasets(self.datasets_reg, self.train_config.batch_size,
|
||||
self.sd)
|
||||
params = []
|
||||
if self.network_config is not None or self.adapter_config is not None or self.embed_config is not None:
|
||||
if not self.is_fine_tuning:
|
||||
if self.network_config is not None:
|
||||
# TODO should we completely switch to LycorisSpecialNetwork?
|
||||
|
||||
@@ -678,7 +691,7 @@ class BaseSDTrainProcess(BaseTrainProcess):
|
||||
|
||||
lora_name = self.name
|
||||
# need to adapt name so they are not mixed up
|
||||
if self.adapter_config is not None or self.embedding is not None:
|
||||
if self.named_lora:
|
||||
lora_name = f"{lora_name}_LoRA"
|
||||
|
||||
latest_save_path = self.get_latest_save_path(lora_name)
|
||||
@@ -758,6 +771,9 @@ class BaseSDTrainProcess(BaseTrainProcess):
|
||||
})
|
||||
self.sd.adapter = self.adapter
|
||||
flush()
|
||||
|
||||
params = self.load_additional_training_modules(params)
|
||||
|
||||
else: # no network, embedding or adapter
|
||||
# set the device state preset before getting params
|
||||
self.sd.set_device_state(self.train_device_state_preset)
|
||||
|
||||
Reference in New Issue
Block a user