Added some further extendability for plugins

This commit is contained in:
Jaret Burkett
2023-09-19 05:41:44 -06:00
parent 61badf85a7
commit 0f105690cc
5 changed files with 70 additions and 65 deletions

View File

@@ -3,7 +3,7 @@ import glob
import inspect
from collections import OrderedDict
import os
from typing import Union
from typing import Union, List
from diffusers import T2IAdapter
# from lycoris.config import PRESET
@@ -116,34 +116,6 @@ class BaseSDTrainProcess(BaseTrainProcess):
if self.model_config.is_xl and not self.adapter_config.adapter_type.endswith('_xl'):
self.adapter_config.adapter_type += '_xl'
model_config_to_load = copy.deepcopy(self.model_config)
if self.embed_config is None and self.network_config is None and self.adapter_config is None:
# get the latest checkpoint
# check to see if we have a latest save
latest_save_path = self.get_latest_save_path()
if latest_save_path is not None:
print(f"#### IMPORTANT RESUMING FROM {latest_save_path} ####")
model_config_to_load.name_or_path = latest_save_path
meta = load_metadata_from_safetensors(latest_save_path)
# if 'training_info' in Orderdict keys
if 'training_info' in meta and 'step' in meta['training_info']:
self.step_num = meta['training_info']['step']
self.start_step = self.step_num
print(f"Found step {self.step_num} in metadata, starting from there")
# get the noise scheduler
sampler = get_sampler(self.train_config.noise_scheduler)
self.sd = StableDiffusion(
device=self.device,
model_config=model_config_to_load,
dtype=self.train_config.dtype,
custom_pipeline=self.custom_pipeline,
noise_scheduler=sampler,
)
# to hold network if there is one
self.network: Union[Network, None] = None
self.adapter: Union[T2IAdapter, None] = None
@@ -165,6 +137,13 @@ class BaseSDTrainProcess(BaseTrainProcess):
if self.network_config is not None or self.adapter_config is not None or self.embed_config is not None:
self.is_fine_tuning = False
self.named_lora = False
if self.embed_config is not None or self.adapter_config is not None:
self.named_lora = True
def post_process_generate_image_config_list(self, generate_image_config_list: List[GenerateImageConfig]):
# override in subclass
return generate_image_config_list
def sample(self, step=None, is_first=False):
sample_folder = os.path.join(self.save_root, 'samples')
gen_img_config_list = []
@@ -218,6 +197,9 @@ class BaseSDTrainProcess(BaseTrainProcess):
**extra_args
))
# post process
gen_img_config_list = self.post_process_generate_image_config_list(gen_img_config_list)
# send to be generated
self.sd.generate_images(gen_img_config_list, sampler=sample_config.sampler)
@@ -297,10 +279,10 @@ class BaseSDTrainProcess(BaseTrainProcess):
file_path = os.path.join(self.save_root, filename)
# prepare meta
save_meta = get_meta_for_safetensors(self.meta, self.job.name)
if self.network is not None or self.embedding is not None or self.adapter is not None:
if not self.is_fine_tuning:
if self.network is not None:
lora_name = self.job.name
if self.adapter_config is not None or self.embedding is not None:
if self.named_lora:
# add _lora to name
lora_name += '_LoRA'
@@ -438,6 +420,10 @@ class BaseSDTrainProcess(BaseTrainProcess):
# sigma = sigma.unsqueeze(-1)
# return sigma
def load_additional_training_modules(self, params):
# override in subclass
return params
def process_general_training_batch(self, batch: 'DataLoaderBatchDTO'):
with torch.no_grad():
prompts = batch.get_caption_list()
@@ -548,6 +534,33 @@ class BaseSDTrainProcess(BaseTrainProcess):
### HOOK ###
self.hook_before_model_load()
model_config_to_load = copy.deepcopy(self.model_config)
if self.is_fine_tuning:
# get the latest checkpoint
# check to see if we have a latest save
latest_save_path = self.get_latest_save_path()
if latest_save_path is not None:
print(f"#### IMPORTANT RESUMING FROM {latest_save_path} ####")
model_config_to_load.name_or_path = latest_save_path
meta = load_metadata_from_safetensors(latest_save_path)
# if 'training_info' in Orderdict keys
if 'training_info' in meta and 'step' in meta['training_info']:
self.step_num = meta['training_info']['step']
self.start_step = self.step_num
print(f"Found step {self.step_num} in metadata, starting from there")
# get the noise scheduler
sampler = get_sampler(self.train_config.noise_scheduler)
self.sd = StableDiffusion(
device=self.device,
model_config=model_config_to_load,
dtype=self.train_config.dtype,
custom_pipeline=self.custom_pipeline,
noise_scheduler=sampler,
)
# run base sd process run
self.sd.load_model()
@@ -611,7 +624,7 @@ class BaseSDTrainProcess(BaseTrainProcess):
self.data_loader_reg = get_dataloader_from_datasets(self.datasets_reg, self.train_config.batch_size,
self.sd)
params = []
if self.network_config is not None or self.adapter_config is not None or self.embed_config is not None:
if not self.is_fine_tuning:
if self.network_config is not None:
# TODO should we completely switch to LycorisSpecialNetwork?
@@ -678,7 +691,7 @@ class BaseSDTrainProcess(BaseTrainProcess):
lora_name = self.name
# need to adapt name so they are not mixed up
if self.adapter_config is not None or self.embedding is not None:
if self.named_lora:
lora_name = f"{lora_name}_LoRA"
latest_save_path = self.get_latest_save_path(lora_name)
@@ -758,6 +771,9 @@ class BaseSDTrainProcess(BaseTrainProcess):
})
self.sd.adapter = self.adapter
flush()
params = self.load_additional_training_modules(params)
else: # no network, embedding or adapter
# set the device state preset before getting params
self.sd.set_device_state(self.train_device_state_preset)

View File

@@ -287,14 +287,18 @@ class TrainESRGANProcess(BaseTrainProcess):
self.model.eval()
def process_and_save(img, target_img, save_path):
output = self.model(img.to(self.device, dtype=self.esrgan_dtype))
img = img.to(self.device, dtype=self.esrgan_dtype)
output = self.model(img)
# output = (output / 2 + 0.5).clamp(0, 1)
output = output.clamp(0, 1)
img = img.clamp(0, 1)
# we always cast to float32 as this does not cause significant overhead and is compatible with bfloat16
output = output.cpu().permute(0, 2, 3, 1).squeeze(0).float().numpy()
img = img.cpu().permute(0, 2, 3, 1).squeeze(0).float().numpy()
# convert to pillow image
output = Image.fromarray((output * 255).astype(np.uint8))
img = Image.fromarray((img * 255).astype(np.uint8))
if isinstance(target_img, torch.Tensor):
# convert to pil
@@ -306,16 +310,23 @@ class TrainESRGANProcess(BaseTrainProcess):
(self.resolution * self.upscale_sample, self.resolution * self.upscale_sample),
resample=Image.NEAREST
)
img = img.resize(
(self.resolution * self.upscale_sample, self.resolution * self.upscale_sample),
resample=Image.NEAREST
)
width, height = output.size
# stack input image and decoded image
target_image = target_img.resize((width, height))
output = output.resize((width, height))
img = img.resize((width, height))
output_img = Image.new('RGB', (width * 2, height))
output_img.paste(target_image, (0, 0))
output_img = Image.new('RGB', (width * 3, height))
output_img.paste(img, (0, 0))
output_img.paste(output, (width, 0))
output_img.paste(target_image, (width * 2, 0))
output_img.save(save_path)
@@ -346,7 +357,7 @@ class TrainESRGANProcess(BaseTrainProcess):
seconds_since_epoch = int(time.time())
# zero-pad 2 digits
i_str = str(i).zfill(2)
filename = f"{seconds_since_epoch}{step_num}_{i_str}.png"
filename = f"{seconds_since_epoch}{step_num}_{i_str}.jpg"
process_and_save(img, target_image, os.path.join(sample_folder, filename))
if batch is not None:
@@ -362,7 +373,7 @@ class TrainESRGANProcess(BaseTrainProcess):
seconds_since_epoch = int(time.time())
# zero-pad 2 digits
i_str = str(i).zfill(2)
filename = f"{seconds_since_epoch}{step_num}_{i_str}.png"
filename = f"{seconds_since_epoch}{step_num}_{i_str}.jpg"
process_and_save(batch_inputs[i], batch_targets[i], os.path.join(batch_sample_folder, filename))
self.model.train()

View File

@@ -66,7 +66,7 @@ class AdapterConfig:
self.in_channels: int = kwargs.get('in_channels', 3)
self.channels: List[int] = kwargs.get('channels', [320, 640, 1280, 1280])
self.num_res_blocks: int = kwargs.get('num_res_blocks', 2)
self.downscale_factor: int = kwargs.get('downscale_factor', 16)
self.downscale_factor: int = kwargs.get('downscale_factor', 8)
self.adapter_type: str = kwargs.get('adapter_type', 'full_adapter')
self.image_dir: str = kwargs.get('image_dir', None)
self.test_img_path: str = kwargs.get('test_img_path', None)

View File

@@ -119,13 +119,13 @@ class LycorisSpecialNetwork(ToolkitNetworkMixin, LycorisNetwork):
# 'SiLU',
# 'ModuleList',
# 'DownBlock2D',
'ResnetBlock2D', # need
# 'ResnetBlock2D', # need
# 'GroupNorm',
# 'LoRACompatibleConv',
# 'LoRACompatibleLinear',
# 'Dropout',
# 'CrossAttnDownBlock2D', # needed
'Transformer2DModel', # maybe not, has duplicates
# 'Transformer2DModel', # maybe not, has duplicates
# 'BasicTransformerBlock', # duplicates
# 'LayerNorm',
# 'Attention',

View File

@@ -57,35 +57,13 @@ class ToolkitModuleMixin:
self.normalize_scaler = 1.0
self._multiplier: Union[float, list, torch.Tensor] = None
# this allows us to set different multipliers on a per item in a batch basis
# allowing us to run positive and negative weights in the same batch
def set_multiplier(self: Module, multiplier):
device = self.lora_down.weight.device
dtype = self.lora_down.weight.dtype
with torch.no_grad():
tensor_multiplier = None
if isinstance(multiplier, int) or isinstance(multiplier, float):
tensor_multiplier = torch.tensor((multiplier,)).to(device, dtype=dtype)
elif isinstance(multiplier, list):
tensor_list = []
for m in multiplier:
if isinstance(m, int) or isinstance(m, float):
tensor_list.append(torch.tensor((m,)).to(device, dtype=dtype))
elif isinstance(m, torch.Tensor):
tensor_list.append(m.clone().detach().to(device, dtype=dtype))
tensor_multiplier = torch.cat(tensor_list)
elif isinstance(multiplier, torch.Tensor):
tensor_multiplier = multiplier.clone().detach().to(device, dtype=dtype)
self._multiplier = tensor_multiplier.clone().detach()
def _call_forward(self: Module, x):
# module dropout
if self.module_dropout is not None and self.training:
if torch.rand(1) < self.module_dropout:
return 0.0 # added to original forward
if hasattr(self, 'lora_mid') and hasattr(self, 'cp') and self.cp:
if hasattr(self, 'lora_mid') and self.lora_mid is not None:
lx = self.lora_mid(self.lora_down(x))
else:
try:
@@ -379,7 +357,7 @@ class ToolkitNetworkMixin:
for lora in loras:
lora.to(device, dtype)
def get_all_modules(self: Network):
def get_all_modules(self: Network) -> List[Module]:
loras = []
if hasattr(self, 'unet_loras'):
loras += self.unet_loras