Added some further extensibility for plugins

This commit is contained in:
Jaret Burkett
2023-09-19 05:41:44 -06:00
parent 61badf85a7
commit 0f105690cc
5 changed files with 70 additions and 65 deletions

View File

@@ -3,7 +3,7 @@ import glob
import inspect import inspect
from collections import OrderedDict from collections import OrderedDict
import os import os
from typing import Union from typing import Union, List
from diffusers import T2IAdapter from diffusers import T2IAdapter
# from lycoris.config import PRESET # from lycoris.config import PRESET
@@ -116,34 +116,6 @@ class BaseSDTrainProcess(BaseTrainProcess):
if self.model_config.is_xl and not self.adapter_config.adapter_type.endswith('_xl'): if self.model_config.is_xl and not self.adapter_config.adapter_type.endswith('_xl'):
self.adapter_config.adapter_type += '_xl' self.adapter_config.adapter_type += '_xl'
model_config_to_load = copy.deepcopy(self.model_config)
if self.embed_config is None and self.network_config is None and self.adapter_config is None:
# get the latest checkpoint
# check to see if we have a latest save
latest_save_path = self.get_latest_save_path()
if latest_save_path is not None:
print(f"#### IMPORTANT RESUMING FROM {latest_save_path} ####")
model_config_to_load.name_or_path = latest_save_path
meta = load_metadata_from_safetensors(latest_save_path)
# if 'training_info' in Orderdict keys
if 'training_info' in meta and 'step' in meta['training_info']:
self.step_num = meta['training_info']['step']
self.start_step = self.step_num
print(f"Found step {self.step_num} in metadata, starting from there")
# get the noise scheduler
sampler = get_sampler(self.train_config.noise_scheduler)
self.sd = StableDiffusion(
device=self.device,
model_config=model_config_to_load,
dtype=self.train_config.dtype,
custom_pipeline=self.custom_pipeline,
noise_scheduler=sampler,
)
# to hold network if there is one # to hold network if there is one
self.network: Union[Network, None] = None self.network: Union[Network, None] = None
self.adapter: Union[T2IAdapter, None] = None self.adapter: Union[T2IAdapter, None] = None
@@ -165,6 +137,13 @@ class BaseSDTrainProcess(BaseTrainProcess):
if self.network_config is not None or self.adapter_config is not None or self.embed_config is not None: if self.network_config is not None or self.adapter_config is not None or self.embed_config is not None:
self.is_fine_tuning = False self.is_fine_tuning = False
self.named_lora = False
if self.embed_config is not None or self.adapter_config is not None:
self.named_lora = True
def post_process_generate_image_config_list(self, generate_image_config_list: List[GenerateImageConfig]):
    """Extension hook for adjusting the sample-generation configs.

    ``sample()`` passes the list of configs it has built through this
    method and sends whatever comes back to ``self.sd.generate_images``.
    The base implementation hands the list back untouched; a subclass can
    override it to return a different list.

    Args:
        generate_image_config_list: The configs assembled by ``sample()``.

    Returns:
        The list of configs to generate; here, the same list object.
    """
    return generate_image_config_list
def sample(self, step=None, is_first=False): def sample(self, step=None, is_first=False):
sample_folder = os.path.join(self.save_root, 'samples') sample_folder = os.path.join(self.save_root, 'samples')
gen_img_config_list = [] gen_img_config_list = []
@@ -218,6 +197,9 @@ class BaseSDTrainProcess(BaseTrainProcess):
**extra_args **extra_args
)) ))
# post process
gen_img_config_list = self.post_process_generate_image_config_list(gen_img_config_list)
# send to be generated # send to be generated
self.sd.generate_images(gen_img_config_list, sampler=sample_config.sampler) self.sd.generate_images(gen_img_config_list, sampler=sample_config.sampler)
@@ -297,10 +279,10 @@ class BaseSDTrainProcess(BaseTrainProcess):
file_path = os.path.join(self.save_root, filename) file_path = os.path.join(self.save_root, filename)
# prepare meta # prepare meta
save_meta = get_meta_for_safetensors(self.meta, self.job.name) save_meta = get_meta_for_safetensors(self.meta, self.job.name)
if self.network is not None or self.embedding is not None or self.adapter is not None: if not self.is_fine_tuning:
if self.network is not None: if self.network is not None:
lora_name = self.job.name lora_name = self.job.name
if self.adapter_config is not None or self.embedding is not None: if self.named_lora:
# add _lora to name # add _lora to name
lora_name += '_LoRA' lora_name += '_LoRA'
@@ -438,6 +420,10 @@ class BaseSDTrainProcess(BaseTrainProcess):
# sigma = sigma.unsqueeze(-1) # sigma = sigma.unsqueeze(-1)
# return sigma # return sigma
def load_additional_training_modules(self, params):
    """Extension hook for registering extra trainable modules.

    The training setup calls this in its non-fine-tuning branch and rebinds
    its ``params`` variable to the return value, so a subclass can override
    this to return an extended or replaced collection. The base
    implementation returns ``params`` unchanged.

    Args:
        params: The parameter collection built so far by the caller
            (presumably what is later handed to the optimizer -- confirm
            against the caller).

    Returns:
        The parameter collection to continue with; here, the same object.
    """
    return params
def process_general_training_batch(self, batch: 'DataLoaderBatchDTO'): def process_general_training_batch(self, batch: 'DataLoaderBatchDTO'):
with torch.no_grad(): with torch.no_grad():
prompts = batch.get_caption_list() prompts = batch.get_caption_list()
@@ -548,6 +534,33 @@ class BaseSDTrainProcess(BaseTrainProcess):
### HOOK ### ### HOOK ###
self.hook_before_model_load() self.hook_before_model_load()
model_config_to_load = copy.deepcopy(self.model_config)
if self.is_fine_tuning:
# get the latest checkpoint
# check to see if we have a latest save
latest_save_path = self.get_latest_save_path()
if latest_save_path is not None:
print(f"#### IMPORTANT RESUMING FROM {latest_save_path} ####")
model_config_to_load.name_or_path = latest_save_path
meta = load_metadata_from_safetensors(latest_save_path)
# if 'training_info' in Orderdict keys
if 'training_info' in meta and 'step' in meta['training_info']:
self.step_num = meta['training_info']['step']
self.start_step = self.step_num
print(f"Found step {self.step_num} in metadata, starting from there")
# get the noise scheduler
sampler = get_sampler(self.train_config.noise_scheduler)
self.sd = StableDiffusion(
device=self.device,
model_config=model_config_to_load,
dtype=self.train_config.dtype,
custom_pipeline=self.custom_pipeline,
noise_scheduler=sampler,
)
# run base sd process run # run base sd process run
self.sd.load_model() self.sd.load_model()
@@ -611,7 +624,7 @@ class BaseSDTrainProcess(BaseTrainProcess):
self.data_loader_reg = get_dataloader_from_datasets(self.datasets_reg, self.train_config.batch_size, self.data_loader_reg = get_dataloader_from_datasets(self.datasets_reg, self.train_config.batch_size,
self.sd) self.sd)
params = [] params = []
if self.network_config is not None or self.adapter_config is not None or self.embed_config is not None: if not self.is_fine_tuning:
if self.network_config is not None: if self.network_config is not None:
# TODO should we completely switch to LycorisSpecialNetwork? # TODO should we completely switch to LycorisSpecialNetwork?
@@ -678,7 +691,7 @@ class BaseSDTrainProcess(BaseTrainProcess):
lora_name = self.name lora_name = self.name
# need to adapt name so they are not mixed up # need to adapt name so they are not mixed up
if self.adapter_config is not None or self.embedding is not None: if self.named_lora:
lora_name = f"{lora_name}_LoRA" lora_name = f"{lora_name}_LoRA"
latest_save_path = self.get_latest_save_path(lora_name) latest_save_path = self.get_latest_save_path(lora_name)
@@ -758,6 +771,9 @@ class BaseSDTrainProcess(BaseTrainProcess):
}) })
self.sd.adapter = self.adapter self.sd.adapter = self.adapter
flush() flush()
params = self.load_additional_training_modules(params)
else: # no network, embedding or adapter else: # no network, embedding or adapter
# set the device state preset before getting params # set the device state preset before getting params
self.sd.set_device_state(self.train_device_state_preset) self.sd.set_device_state(self.train_device_state_preset)

View File

@@ -287,14 +287,18 @@ class TrainESRGANProcess(BaseTrainProcess):
self.model.eval() self.model.eval()
def process_and_save(img, target_img, save_path): def process_and_save(img, target_img, save_path):
output = self.model(img.to(self.device, dtype=self.esrgan_dtype)) img = img.to(self.device, dtype=self.esrgan_dtype)
output = self.model(img)
# output = (output / 2 + 0.5).clamp(0, 1) # output = (output / 2 + 0.5).clamp(0, 1)
output = output.clamp(0, 1) output = output.clamp(0, 1)
img = img.clamp(0, 1)
# we always cast to float32 as this does not cause significant overhead and is compatible with bfloat16 # we always cast to float32 as this does not cause significant overhead and is compatible with bfloat16
output = output.cpu().permute(0, 2, 3, 1).squeeze(0).float().numpy() output = output.cpu().permute(0, 2, 3, 1).squeeze(0).float().numpy()
img = img.cpu().permute(0, 2, 3, 1).squeeze(0).float().numpy()
# convert to pillow image # convert to pillow image
output = Image.fromarray((output * 255).astype(np.uint8)) output = Image.fromarray((output * 255).astype(np.uint8))
img = Image.fromarray((img * 255).astype(np.uint8))
if isinstance(target_img, torch.Tensor): if isinstance(target_img, torch.Tensor):
# convert to pil # convert to pil
@@ -306,16 +310,23 @@ class TrainESRGANProcess(BaseTrainProcess):
(self.resolution * self.upscale_sample, self.resolution * self.upscale_sample), (self.resolution * self.upscale_sample, self.resolution * self.upscale_sample),
resample=Image.NEAREST resample=Image.NEAREST
) )
img = img.resize(
(self.resolution * self.upscale_sample, self.resolution * self.upscale_sample),
resample=Image.NEAREST
)
width, height = output.size width, height = output.size
# stack input image and decoded image # stack input image and decoded image
target_image = target_img.resize((width, height)) target_image = target_img.resize((width, height))
output = output.resize((width, height)) output = output.resize((width, height))
img = img.resize((width, height))
output_img = Image.new('RGB', (width * 2, height)) output_img = Image.new('RGB', (width * 3, height))
output_img.paste(target_image, (0, 0))
output_img.paste(img, (0, 0))
output_img.paste(output, (width, 0)) output_img.paste(output, (width, 0))
output_img.paste(target_image, (width * 2, 0))
output_img.save(save_path) output_img.save(save_path)
@@ -346,7 +357,7 @@ class TrainESRGANProcess(BaseTrainProcess):
seconds_since_epoch = int(time.time()) seconds_since_epoch = int(time.time())
# zero-pad 2 digits # zero-pad 2 digits
i_str = str(i).zfill(2) i_str = str(i).zfill(2)
filename = f"{seconds_since_epoch}{step_num}_{i_str}.png" filename = f"{seconds_since_epoch}{step_num}_{i_str}.jpg"
process_and_save(img, target_image, os.path.join(sample_folder, filename)) process_and_save(img, target_image, os.path.join(sample_folder, filename))
if batch is not None: if batch is not None:
@@ -362,7 +373,7 @@ class TrainESRGANProcess(BaseTrainProcess):
seconds_since_epoch = int(time.time()) seconds_since_epoch = int(time.time())
# zero-pad 2 digits # zero-pad 2 digits
i_str = str(i).zfill(2) i_str = str(i).zfill(2)
filename = f"{seconds_since_epoch}{step_num}_{i_str}.png" filename = f"{seconds_since_epoch}{step_num}_{i_str}.jpg"
process_and_save(batch_inputs[i], batch_targets[i], os.path.join(batch_sample_folder, filename)) process_and_save(batch_inputs[i], batch_targets[i], os.path.join(batch_sample_folder, filename))
self.model.train() self.model.train()

View File

@@ -66,7 +66,7 @@ class AdapterConfig:
self.in_channels: int = kwargs.get('in_channels', 3) self.in_channels: int = kwargs.get('in_channels', 3)
self.channels: List[int] = kwargs.get('channels', [320, 640, 1280, 1280]) self.channels: List[int] = kwargs.get('channels', [320, 640, 1280, 1280])
self.num_res_blocks: int = kwargs.get('num_res_blocks', 2) self.num_res_blocks: int = kwargs.get('num_res_blocks', 2)
self.downscale_factor: int = kwargs.get('downscale_factor', 16) self.downscale_factor: int = kwargs.get('downscale_factor', 8)
self.adapter_type: str = kwargs.get('adapter_type', 'full_adapter') self.adapter_type: str = kwargs.get('adapter_type', 'full_adapter')
self.image_dir: str = kwargs.get('image_dir', None) self.image_dir: str = kwargs.get('image_dir', None)
self.test_img_path: str = kwargs.get('test_img_path', None) self.test_img_path: str = kwargs.get('test_img_path', None)

View File

@@ -119,13 +119,13 @@ class LycorisSpecialNetwork(ToolkitNetworkMixin, LycorisNetwork):
# 'SiLU', # 'SiLU',
# 'ModuleList', # 'ModuleList',
# 'DownBlock2D', # 'DownBlock2D',
'ResnetBlock2D', # need # 'ResnetBlock2D', # need
# 'GroupNorm', # 'GroupNorm',
# 'LoRACompatibleConv', # 'LoRACompatibleConv',
# 'LoRACompatibleLinear', # 'LoRACompatibleLinear',
# 'Dropout', # 'Dropout',
# 'CrossAttnDownBlock2D', # needed # 'CrossAttnDownBlock2D', # needed
'Transformer2DModel', # maybe not, has duplicates # 'Transformer2DModel', # maybe not, has duplicates
# 'BasicTransformerBlock', # duplicates # 'BasicTransformerBlock', # duplicates
# 'LayerNorm', # 'LayerNorm',
# 'Attention', # 'Attention',

View File

@@ -57,35 +57,13 @@ class ToolkitModuleMixin:
self.normalize_scaler = 1.0 self.normalize_scaler = 1.0
self._multiplier: Union[float, list, torch.Tensor] = None self._multiplier: Union[float, list, torch.Tensor] = None
# this allows us to set different multipliers on a per item in a batch basis
# allowing us to run positive and negative weights in the same batch
def set_multiplier(self: Module, multiplier):
    """Store ``multiplier`` on the module as a detached tensor.

    The value is converted to a tensor on the device and dtype of
    ``self.lora_down.weight`` and saved to ``self._multiplier``. Passing a
    list yields one entry per list item, which is what allows a different
    multiplier for each item in a batch.

    Args:
        multiplier: One of
            * an ``int`` or ``float`` -- stored as a one-element tensor;
            * a ``torch.Tensor`` -- copied, detached and moved;
            * a ``list`` whose entries are ``int``, ``float`` or
              ``torch.Tensor`` -- each entry is converted as above and the
              results are joined with ``torch.cat``.

    Raises:
        TypeError: If ``multiplier``, or any entry of a list, is not one of
            the supported types. (Previously an unsupported value failed
            with an ``AttributeError`` on ``None`` and unsupported list
            entries were silently dropped, producing a tensor shorter than
            the list that was passed in.)
    """
    weight = self.lora_down.weight
    device, dtype = weight.device, weight.dtype

    def _to_tensor(value):
        # Scalars become one-element tensors so they can be concatenated
        # with tensor entries when a list is given.
        if isinstance(value, (int, float)):
            return torch.tensor((value,)).to(device, dtype=dtype)
        if isinstance(value, torch.Tensor):
            return value.clone().detach().to(device, dtype=dtype)
        raise TypeError(
            "multiplier must be an int, float, torch.Tensor, or a list of those; "
            f"got {type(value).__name__}"
        )

    with torch.no_grad():
        if isinstance(multiplier, list):
            tensor_multiplier = torch.cat([_to_tensor(m) for m in multiplier])
        else:
            tensor_multiplier = _to_tensor(multiplier)

        # Keep a private copy so later changes to the caller's tensor do not
        # leak into the stored multiplier.
        self._multiplier = tensor_multiplier.clone().detach()
def _call_forward(self: Module, x): def _call_forward(self: Module, x):
# module dropout # module dropout
if self.module_dropout is not None and self.training: if self.module_dropout is not None and self.training:
if torch.rand(1) < self.module_dropout: if torch.rand(1) < self.module_dropout:
return 0.0 # added to original forward return 0.0 # added to original forward
if hasattr(self, 'lora_mid') and hasattr(self, 'cp') and self.cp: if hasattr(self, 'lora_mid') and self.lora_mid is not None:
lx = self.lora_mid(self.lora_down(x)) lx = self.lora_mid(self.lora_down(x))
else: else:
try: try:
@@ -379,7 +357,7 @@ class ToolkitNetworkMixin:
for lora in loras: for lora in loras:
lora.to(device, dtype) lora.to(device, dtype)
def get_all_modules(self: Network): def get_all_modules(self: Network) -> List[Module]:
loras = [] loras = []
if hasattr(self, 'unet_loras'): if hasattr(self, 'unet_loras'):
loras += self.unet_loras loras += self.unet_loras