mirror of
https://github.com/ostris/ai-toolkit.git
synced 2026-01-26 16:39:47 +00:00
Added some further extendability for plugins
This commit is contained in:
@@ -3,7 +3,7 @@ import glob
|
||||
import inspect
|
||||
from collections import OrderedDict
|
||||
import os
|
||||
from typing import Union
|
||||
from typing import Union, List
|
||||
|
||||
from diffusers import T2IAdapter
|
||||
# from lycoris.config import PRESET
|
||||
@@ -116,34 +116,6 @@ class BaseSDTrainProcess(BaseTrainProcess):
|
||||
if self.model_config.is_xl and not self.adapter_config.adapter_type.endswith('_xl'):
|
||||
self.adapter_config.adapter_type += '_xl'
|
||||
|
||||
model_config_to_load = copy.deepcopy(self.model_config)
|
||||
|
||||
if self.embed_config is None and self.network_config is None and self.adapter_config is None:
|
||||
# get the latest checkpoint
|
||||
# check to see if we have a latest save
|
||||
latest_save_path = self.get_latest_save_path()
|
||||
|
||||
if latest_save_path is not None:
|
||||
print(f"#### IMPORTANT RESUMING FROM {latest_save_path} ####")
|
||||
model_config_to_load.name_or_path = latest_save_path
|
||||
meta = load_metadata_from_safetensors(latest_save_path)
|
||||
# if 'training_info' is in the metadata (OrderedDict) keys
|
||||
if 'training_info' in meta and 'step' in meta['training_info']:
|
||||
self.step_num = meta['training_info']['step']
|
||||
self.start_step = self.step_num
|
||||
print(f"Found step {self.step_num} in metadata, starting from there")
|
||||
|
||||
# get the noise scheduler
|
||||
sampler = get_sampler(self.train_config.noise_scheduler)
|
||||
|
||||
self.sd = StableDiffusion(
|
||||
device=self.device,
|
||||
model_config=model_config_to_load,
|
||||
dtype=self.train_config.dtype,
|
||||
custom_pipeline=self.custom_pipeline,
|
||||
noise_scheduler=sampler,
|
||||
)
|
||||
|
||||
# to hold network if there is one
|
||||
self.network: Union[Network, None] = None
|
||||
self.adapter: Union[T2IAdapter, None] = None
|
||||
@@ -165,6 +137,13 @@ class BaseSDTrainProcess(BaseTrainProcess):
|
||||
if self.network_config is not None or self.adapter_config is not None or self.embed_config is not None:
|
||||
self.is_fine_tuning = False
|
||||
|
||||
self.named_lora = False
|
||||
if self.embed_config is not None or self.adapter_config is not None:
|
||||
self.named_lora = True
|
||||
def post_process_generate_image_config_list(self, generate_image_config_list: List[GenerateImageConfig]):
    """Hook for subclasses to adjust the sample configs before generation.

    ``sample`` passes the list of image configs it has built through this
    method and uses the returned list for generation. The base
    implementation hands the list back untouched.

    Args:
        generate_image_config_list: The configs assembled for sampling.

    Returns:
        The list of configs to generate; here, the same object received.
    """
    return generate_image_config_list
|
||||
|
||||
def sample(self, step=None, is_first=False):
|
||||
sample_folder = os.path.join(self.save_root, 'samples')
|
||||
gen_img_config_list = []
|
||||
@@ -218,6 +197,9 @@ class BaseSDTrainProcess(BaseTrainProcess):
|
||||
**extra_args
|
||||
))
|
||||
|
||||
# post process
|
||||
gen_img_config_list = self.post_process_generate_image_config_list(gen_img_config_list)
|
||||
|
||||
# send to be generated
|
||||
self.sd.generate_images(gen_img_config_list, sampler=sample_config.sampler)
|
||||
|
||||
@@ -297,10 +279,10 @@ class BaseSDTrainProcess(BaseTrainProcess):
|
||||
file_path = os.path.join(self.save_root, filename)
|
||||
# prepare meta
|
||||
save_meta = get_meta_for_safetensors(self.meta, self.job.name)
|
||||
if self.network is not None or self.embedding is not None or self.adapter is not None:
|
||||
if not self.is_fine_tuning:
|
||||
if self.network is not None:
|
||||
lora_name = self.job.name
|
||||
if self.adapter_config is not None or self.embedding is not None:
|
||||
if self.named_lora:
|
||||
# add _lora to name
|
||||
lora_name += '_LoRA'
|
||||
|
||||
@@ -438,6 +420,10 @@ class BaseSDTrainProcess(BaseTrainProcess):
|
||||
# sigma = sigma.unsqueeze(-1)
|
||||
# return sigma
|
||||
|
||||
def load_additional_training_modules(self, params):
    """Hook for subclasses to extend the collected training params.

    The training setup reassigns its ``params`` to whatever this method
    returns. The base implementation returns ``params`` unchanged.

    Args:
        params: The training params gathered so far.

    Returns:
        The params to train with; here, the same object received.
    """
    return params
|
||||
|
||||
def process_general_training_batch(self, batch: 'DataLoaderBatchDTO'):
|
||||
with torch.no_grad():
|
||||
prompts = batch.get_caption_list()
|
||||
@@ -548,6 +534,33 @@ class BaseSDTrainProcess(BaseTrainProcess):
|
||||
|
||||
### HOOK ###
|
||||
self.hook_before_model_load()
|
||||
model_config_to_load = copy.deepcopy(self.model_config)
|
||||
|
||||
if self.is_fine_tuning:
|
||||
# get the latest checkpoint
|
||||
# check to see if we have a latest save
|
||||
latest_save_path = self.get_latest_save_path()
|
||||
|
||||
if latest_save_path is not None:
|
||||
print(f"#### IMPORTANT RESUMING FROM {latest_save_path} ####")
|
||||
model_config_to_load.name_or_path = latest_save_path
|
||||
meta = load_metadata_from_safetensors(latest_save_path)
|
||||
# if 'training_info' is in the metadata (OrderedDict) keys
|
||||
if 'training_info' in meta and 'step' in meta['training_info']:
|
||||
self.step_num = meta['training_info']['step']
|
||||
self.start_step = self.step_num
|
||||
print(f"Found step {self.step_num} in metadata, starting from there")
|
||||
|
||||
# get the noise scheduler
|
||||
sampler = get_sampler(self.train_config.noise_scheduler)
|
||||
|
||||
self.sd = StableDiffusion(
|
||||
device=self.device,
|
||||
model_config=model_config_to_load,
|
||||
dtype=self.train_config.dtype,
|
||||
custom_pipeline=self.custom_pipeline,
|
||||
noise_scheduler=sampler,
|
||||
)
|
||||
# run base sd process run
|
||||
self.sd.load_model()
|
||||
|
||||
@@ -611,7 +624,7 @@ class BaseSDTrainProcess(BaseTrainProcess):
|
||||
self.data_loader_reg = get_dataloader_from_datasets(self.datasets_reg, self.train_config.batch_size,
|
||||
self.sd)
|
||||
params = []
|
||||
if self.network_config is not None or self.adapter_config is not None or self.embed_config is not None:
|
||||
if not self.is_fine_tuning:
|
||||
if self.network_config is not None:
|
||||
# TODO should we completely switch to LycorisSpecialNetwork?
|
||||
|
||||
@@ -678,7 +691,7 @@ class BaseSDTrainProcess(BaseTrainProcess):
|
||||
|
||||
lora_name = self.name
|
||||
# need to adapt name so they are not mixed up
|
||||
if self.adapter_config is not None or self.embedding is not None:
|
||||
if self.named_lora:
|
||||
lora_name = f"{lora_name}_LoRA"
|
||||
|
||||
latest_save_path = self.get_latest_save_path(lora_name)
|
||||
@@ -758,6 +771,9 @@ class BaseSDTrainProcess(BaseTrainProcess):
|
||||
})
|
||||
self.sd.adapter = self.adapter
|
||||
flush()
|
||||
|
||||
params = self.load_additional_training_modules(params)
|
||||
|
||||
else: # no network, embedding or adapter
|
||||
# set the device state preset before getting params
|
||||
self.sd.set_device_state(self.train_device_state_preset)
|
||||
|
||||
@@ -287,14 +287,18 @@ class TrainESRGANProcess(BaseTrainProcess):
|
||||
self.model.eval()
|
||||
|
||||
def process_and_save(img, target_img, save_path):
|
||||
output = self.model(img.to(self.device, dtype=self.esrgan_dtype))
|
||||
img = img.to(self.device, dtype=self.esrgan_dtype)
|
||||
output = self.model(img)
|
||||
# output = (output / 2 + 0.5).clamp(0, 1)
|
||||
output = output.clamp(0, 1)
|
||||
img = img.clamp(0, 1)
|
||||
# we always cast to float32 as this does not cause significant overhead and is compatible with bfloat16
|
||||
output = output.cpu().permute(0, 2, 3, 1).squeeze(0).float().numpy()
|
||||
img = img.cpu().permute(0, 2, 3, 1).squeeze(0).float().numpy()
|
||||
|
||||
# convert to pillow image
|
||||
output = Image.fromarray((output * 255).astype(np.uint8))
|
||||
img = Image.fromarray((img * 255).astype(np.uint8))
|
||||
|
||||
if isinstance(target_img, torch.Tensor):
|
||||
# convert to pil
|
||||
@@ -306,16 +310,23 @@ class TrainESRGANProcess(BaseTrainProcess):
|
||||
(self.resolution * self.upscale_sample, self.resolution * self.upscale_sample),
|
||||
resample=Image.NEAREST
|
||||
)
|
||||
img = img.resize(
|
||||
(self.resolution * self.upscale_sample, self.resolution * self.upscale_sample),
|
||||
resample=Image.NEAREST
|
||||
)
|
||||
|
||||
width, height = output.size
|
||||
|
||||
# stack input image and decoded image
|
||||
target_image = target_img.resize((width, height))
|
||||
output = output.resize((width, height))
|
||||
img = img.resize((width, height))
|
||||
|
||||
output_img = Image.new('RGB', (width * 2, height))
|
||||
output_img.paste(target_image, (0, 0))
|
||||
output_img = Image.new('RGB', (width * 3, height))
|
||||
|
||||
output_img.paste(img, (0, 0))
|
||||
output_img.paste(output, (width, 0))
|
||||
output_img.paste(target_image, (width * 2, 0))
|
||||
|
||||
output_img.save(save_path)
|
||||
|
||||
@@ -346,7 +357,7 @@ class TrainESRGANProcess(BaseTrainProcess):
|
||||
seconds_since_epoch = int(time.time())
|
||||
# zero-pad 2 digits
|
||||
i_str = str(i).zfill(2)
|
||||
filename = f"{seconds_since_epoch}{step_num}_{i_str}.png"
|
||||
filename = f"{seconds_since_epoch}{step_num}_{i_str}.jpg"
|
||||
process_and_save(img, target_image, os.path.join(sample_folder, filename))
|
||||
|
||||
if batch is not None:
|
||||
@@ -362,7 +373,7 @@ class TrainESRGANProcess(BaseTrainProcess):
|
||||
seconds_since_epoch = int(time.time())
|
||||
# zero-pad 2 digits
|
||||
i_str = str(i).zfill(2)
|
||||
filename = f"{seconds_since_epoch}{step_num}_{i_str}.png"
|
||||
filename = f"{seconds_since_epoch}{step_num}_{i_str}.jpg"
|
||||
process_and_save(batch_inputs[i], batch_targets[i], os.path.join(batch_sample_folder, filename))
|
||||
|
||||
self.model.train()
|
||||
|
||||
@@ -66,7 +66,7 @@ class AdapterConfig:
|
||||
self.in_channels: int = kwargs.get('in_channels', 3)
|
||||
self.channels: List[int] = kwargs.get('channels', [320, 640, 1280, 1280])
|
||||
self.num_res_blocks: int = kwargs.get('num_res_blocks', 2)
|
||||
self.downscale_factor: int = kwargs.get('downscale_factor', 16)
|
||||
self.downscale_factor: int = kwargs.get('downscale_factor', 8)
|
||||
self.adapter_type: str = kwargs.get('adapter_type', 'full_adapter')
|
||||
self.image_dir: str = kwargs.get('image_dir', None)
|
||||
self.test_img_path: str = kwargs.get('test_img_path', None)
|
||||
|
||||
@@ -119,13 +119,13 @@ class LycorisSpecialNetwork(ToolkitNetworkMixin, LycorisNetwork):
|
||||
# 'SiLU',
|
||||
# 'ModuleList',
|
||||
# 'DownBlock2D',
|
||||
'ResnetBlock2D', # need
|
||||
# 'ResnetBlock2D', # need
|
||||
# 'GroupNorm',
|
||||
# 'LoRACompatibleConv',
|
||||
# 'LoRACompatibleLinear',
|
||||
# 'Dropout',
|
||||
# 'CrossAttnDownBlock2D', # needed
|
||||
'Transformer2DModel', # maybe not, has duplicates
|
||||
# 'Transformer2DModel', # maybe not, has duplicates
|
||||
# 'BasicTransformerBlock', # duplicates
|
||||
# 'LayerNorm',
|
||||
# 'Attention',
|
||||
|
||||
@@ -57,35 +57,13 @@ class ToolkitModuleMixin:
|
||||
self.normalize_scaler = 1.0
|
||||
self._multiplier: Union[float, list, torch.Tensor] = None
|
||||
|
||||
# this allows us to set different multipliers on a per item in a batch basis
|
||||
# allowing us to run positive and negative weights in the same batch
|
||||
def set_multiplier(self: Module, multiplier):
    """Store ``multiplier`` on the module as a detached tensor.

    The value is converted to the device and dtype of
    ``self.lora_down.weight`` and saved in ``self._multiplier``. A list
    produces one entry per element, so a single batch can carry a different
    (for example positive or negative) weight for each item.

    Args:
        multiplier: An ``int``/``float``, a ``torch.Tensor``, or a list whose
            elements are ``int``/``float``/``torch.Tensor``.

    Raises:
        TypeError: If ``multiplier`` or one of its list elements has an
            unsupported type. Previously such input either failed with an
            opaque ``AttributeError`` or was silently dropped from the list,
            leaving a multiplier shorter than the batch.
    """
    weight = self.lora_down.weight
    device, dtype = weight.device, weight.dtype

    def to_tensor(value):
        # Scalars become 1-element tensors so they can be concatenated.
        if isinstance(value, (int, float)):
            return torch.tensor((value,)).to(device, dtype=dtype)
        if isinstance(value, torch.Tensor):
            return value.detach().to(device, dtype=dtype)
        raise TypeError(
            "multiplier must be an int, float, torch.Tensor or a list of "
            f"those, got {type(value).__name__}"
        )

    with torch.no_grad():
        if isinstance(multiplier, list):
            parts = []
            for item in multiplier:
                part = to_tensor(item)
                # torch.cat rejects 0-dim tensors; promote them to shape (1,).
                parts.append(part.reshape(1) if part.dim() == 0 else part)
            combined = torch.cat(parts)
        else:
            combined = to_tensor(multiplier)

        # One final copy so the stored tensor never aliases the caller's data.
        self._multiplier = combined.clone().detach()
|
||||
|
||||
def _call_forward(self: Module, x):
|
||||
# module dropout
|
||||
if self.module_dropout is not None and self.training:
|
||||
if torch.rand(1) < self.module_dropout:
|
||||
return 0.0 # added to original forward
|
||||
|
||||
if hasattr(self, 'lora_mid') and hasattr(self, 'cp') and self.cp:
|
||||
if hasattr(self, 'lora_mid') and self.lora_mid is not None:
|
||||
lx = self.lora_mid(self.lora_down(x))
|
||||
else:
|
||||
try:
|
||||
@@ -379,7 +357,7 @@ class ToolkitNetworkMixin:
|
||||
for lora in loras:
|
||||
lora.to(device, dtype)
|
||||
|
||||
def get_all_modules(self: Network):
|
||||
def get_all_modules(self: Network) -> List[Module]:
|
||||
loras = []
|
||||
if hasattr(self, 'unet_loras'):
|
||||
loras += self.unet_loras
|
||||
|
||||
Reference in New Issue
Block a user