diff --git a/jobs/process/BaseSDTrainProcess.py b/jobs/process/BaseSDTrainProcess.py
index 90f11354..927286e4 100644
--- a/jobs/process/BaseSDTrainProcess.py
+++ b/jobs/process/BaseSDTrainProcess.py
@@ -3,7 +3,7 @@ import glob
 import inspect
 from collections import OrderedDict
 import os
-from typing import Union
+from typing import Union, List
 
 from diffusers import T2IAdapter
 # from lycoris.config import PRESET
@@ -116,34 +116,6 @@ class BaseSDTrainProcess(BaseTrainProcess):
             if self.model_config.is_xl and not self.adapter_config.adapter_type.endswith('_xl'):
                 self.adapter_config.adapter_type += '_xl'
 
-        model_config_to_load = copy.deepcopy(self.model_config)
-
-        if self.embed_config is None and self.network_config is None and self.adapter_config is None:
-            # get the latest checkpoint
-            # check to see if we have a latest save
-            latest_save_path = self.get_latest_save_path()
-
-            if latest_save_path is not None:
-                print(f"#### IMPORTANT RESUMING FROM {latest_save_path} ####")
-                model_config_to_load.name_or_path = latest_save_path
-                meta = load_metadata_from_safetensors(latest_save_path)
-                # if 'training_info' in Orderdict keys
-                if 'training_info' in meta and 'step' in meta['training_info']:
-                    self.step_num = meta['training_info']['step']
-                    self.start_step = self.step_num
-                    print(f"Found step {self.step_num} in metadata, starting from there")
-
-        # get the noise scheduler
-        sampler = get_sampler(self.train_config.noise_scheduler)
-
-        self.sd = StableDiffusion(
-            device=self.device,
-            model_config=model_config_to_load,
-            dtype=self.train_config.dtype,
-            custom_pipeline=self.custom_pipeline,
-            noise_scheduler=sampler,
-        )
-
         # to hold network if there is one
         self.network: Union[Network, None] = None
         self.adapter: Union[T2IAdapter, None] = None
@@ -165,6 +137,13 @@ class BaseSDTrainProcess(BaseTrainProcess):
         if self.network_config is not None or self.adapter_config is not None or self.embed_config is not None:
             self.is_fine_tuning = False
+        self.named_lora = False
+        if self.embed_config is not None or self.adapter_config is not None:
+            self.named_lora = True
+
+    def post_process_generate_image_config_list(self, generate_image_config_list: List[GenerateImageConfig]):
+        # override in subclass
+        return generate_image_config_list
 
     def sample(self, step=None, is_first=False):
         sample_folder = os.path.join(self.save_root, 'samples')
         gen_img_config_list = []
@@ -218,6 +197,9 @@ class BaseSDTrainProcess(BaseTrainProcess):
                 **extra_args
             ))
 
+        # post process
+        gen_img_config_list = self.post_process_generate_image_config_list(gen_img_config_list)
+
         # send to be generated
         self.sd.generate_images(gen_img_config_list, sampler=sample_config.sampler)
 
@@ -297,10 +279,10 @@ class BaseSDTrainProcess(BaseTrainProcess):
         file_path = os.path.join(self.save_root, filename)
         # prepare meta
         save_meta = get_meta_for_safetensors(self.meta, self.job.name)
-        if self.network is not None or self.embedding is not None or self.adapter is not None:
+        if not self.is_fine_tuning:
             if self.network is not None:
                 lora_name = self.job.name
-                if self.adapter_config is not None or self.embedding is not None:
+                if self.named_lora:
                     # add _lora to name
                     lora_name += '_LoRA'
 
@@ -438,6 +420,10 @@ class BaseSDTrainProcess(BaseTrainProcess):
             # sigma = sigma.unsqueeze(-1)
             # return sigma
 
+    def load_additional_training_modules(self, params):
+        # override in subclass
+        return params
+
     def process_general_training_batch(self, batch: 'DataLoaderBatchDTO'):
         with torch.no_grad():
             prompts = batch.get_caption_list()
@@ -548,6 +534,33 @@ class BaseSDTrainProcess(BaseTrainProcess):
 
         ### HOOK ###
         self.hook_before_model_load()
+        model_config_to_load = copy.deepcopy(self.model_config)
+
+        if self.is_fine_tuning:
+            # get the latest checkpoint
+            # check to see if we have a latest save
+            latest_save_path = self.get_latest_save_path()
+
+            if latest_save_path is not None:
+                print(f"#### IMPORTANT RESUMING FROM {latest_save_path} ####")
+                model_config_to_load.name_or_path = latest_save_path
+                meta = load_metadata_from_safetensors(latest_save_path)
+                # if 'training_info' in Orderdict keys
+                if 'training_info' in meta and 'step' in meta['training_info']:
+                    self.step_num = meta['training_info']['step']
+                    self.start_step = self.step_num
+                    print(f"Found step {self.step_num} in metadata, starting from there")
+
+        # get the noise scheduler
+        sampler = get_sampler(self.train_config.noise_scheduler)
+
+        self.sd = StableDiffusion(
+            device=self.device,
+            model_config=model_config_to_load,
+            dtype=self.train_config.dtype,
+            custom_pipeline=self.custom_pipeline,
+            noise_scheduler=sampler,
+        )
 
         # run base sd process run
         self.sd.load_model()
@@ -611,7 +624,7 @@ class BaseSDTrainProcess(BaseTrainProcess):
             self.data_loader_reg = get_dataloader_from_datasets(self.datasets_reg, self.train_config.batch_size, self.sd)
 
         params = []
-        if self.network_config is not None or self.adapter_config is not None or self.embed_config is not None:
+        if not self.is_fine_tuning:
             if self.network_config is not None:
 
                 # TODO should we completely switch to LycorisSpecialNetwork?
@@ -678,7 +691,7 @@ class BaseSDTrainProcess(BaseTrainProcess):
 
                 lora_name = self.name
                 # need to adapt name so they are not mixed up
-                if self.adapter_config is not None or self.embedding is not None:
+                if self.named_lora:
                     lora_name = f"{lora_name}_LoRA"
 
                 latest_save_path = self.get_latest_save_path(lora_name)
@@ -758,6 +771,9 @@ class BaseSDTrainProcess(BaseTrainProcess):
                 })
                 self.sd.adapter = self.adapter
                 flush()
+
+            params = self.load_additional_training_modules(params)
+
         else:  # no network, embedding or adapter
             # set the device state preset before getting params
             self.sd.set_device_state(self.train_device_state_preset)
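Note: this diff gives `BaseSDTrainProcess` two `# override in subclass` hooks: `post_process_generate_image_config_list`, called on the sample configs just before `self.sd.generate_images(...)`, and `load_additional_training_modules`, called on the optimizer `params` list after network/embedding/adapter params are collected. A minimal sketch of a subclass using them; the class name, the extra module, the `lr` param-group key, and the `.prompt` tweak are illustrative assumptions, not part of this diff:

```python
from jobs.process.BaseSDTrainProcess import BaseSDTrainProcess


class MyAdapterTrainProcess(BaseSDTrainProcess):  # hypothetical subclass
    def post_process_generate_image_config_list(self, generate_image_config_list):
        # adjust every sample config before images are generated
        for gen_config in generate_image_config_list:
            gen_config.prompt = f"{gen_config.prompt}, detailed"  # hypothetical tweak
        return generate_image_config_list

    def load_additional_training_modules(self, params):
        # hand an extra trainable module's parameters to the optimizer
        params.append({
            'params': self.my_extra_module.parameters(),  # hypothetical module
            'lr': self.train_config.lr,
        })
        return params
```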
diff --git a/jobs/process/TrainESRGANProcess.py b/jobs/process/TrainESRGANProcess.py
index c00d92bd..4ff3a69d 100644
--- a/jobs/process/TrainESRGANProcess.py
+++ b/jobs/process/TrainESRGANProcess.py
@@ -287,14 +287,18 @@ class TrainESRGANProcess(BaseTrainProcess):
         self.model.eval()
 
         def process_and_save(img, target_img, save_path):
-            output = self.model(img.to(self.device, dtype=self.esrgan_dtype))
+            img = img.to(self.device, dtype=self.esrgan_dtype)
+            output = self.model(img)
             # output = (output / 2 + 0.5).clamp(0, 1)
             output = output.clamp(0, 1)
+            img = img.clamp(0, 1)
 
             # we always cast to float32 as this does not cause significant overhead and is compatible with bfloat16
             output = output.cpu().permute(0, 2, 3, 1).squeeze(0).float().numpy()
+            img = img.cpu().permute(0, 2, 3, 1).squeeze(0).float().numpy()
 
             # convert to pillow image
             output = Image.fromarray((output * 255).astype(np.uint8))
+            img = Image.fromarray((img * 255).astype(np.uint8))
 
             if isinstance(target_img, torch.Tensor):
                 # convert to pil
@@ -306,16 +310,23 @@ class TrainESRGANProcess(BaseTrainProcess):
                 (self.resolution * self.upscale_sample, self.resolution * self.upscale_sample),
                 resample=Image.NEAREST
             )
+            img = img.resize(
+                (self.resolution * self.upscale_sample, self.resolution * self.upscale_sample),
+                resample=Image.NEAREST
+            )
 
             width, height = output.size
 
             # stack input image and decoded image
             target_image = target_img.resize((width, height))
             output = output.resize((width, height))
+            img = img.resize((width, height))
 
-            output_img = Image.new('RGB', (width * 2, height))
-            output_img.paste(target_image, (0, 0))
+            output_img = Image.new('RGB', (width * 3, height))
+
+            output_img.paste(img, (0, 0))
             output_img.paste(output, (width, 0))
+            output_img.paste(target_image, (width * 2, 0))
 
             output_img.save(save_path)
 
@@ -346,7 +357,7 @@ class TrainESRGANProcess(BaseTrainProcess):
                 seconds_since_epoch = int(time.time())
                 # zero-pad 2 digits
                 i_str = str(i).zfill(2)
-                filename = f"{seconds_since_epoch}{step_num}_{i_str}.png"
+                filename = f"{seconds_since_epoch}{step_num}_{i_str}.jpg"
                 process_and_save(img, target_image, os.path.join(sample_folder, filename))
 
         if batch is not None:
@@ -362,7 +373,7 @@ class TrainESRGANProcess(BaseTrainProcess):
                 seconds_since_epoch = int(time.time())
                 # zero-pad 2 digits
                 i_str = str(i).zfill(2)
-                filename = f"{seconds_since_epoch}{step_num}_{i_str}.png"
+                filename = f"{seconds_since_epoch}{step_num}_{i_str}.jpg"
                 process_and_save(batch_inputs[i], batch_targets[i], os.path.join(batch_sample_folder, filename))
 
         self.model.train()
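Note: ESRGAN sample sheets now hold three panels instead of two: the low-res input at `(0, 0)`, the model output at `(width, 0)`, and the ground-truth target at `(width * 2, 0)`, saved as `.jpg` rather than `.png`. A self-contained sketch of the same strip layout; the function and panel names are descriptive only, not part of the diff:

```python
from PIL import Image

def make_comparison_strip(panels, width, height):
    # paste each panel left-to-right, e.g. input | output | target
    strip = Image.new('RGB', (width * len(panels), height))
    for i, panel in enumerate(panels):
        strip.paste(panel.resize((width, height)), (width * i, 0))
    return strip
```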
diff --git a/toolkit/config_modules.py b/toolkit/config_modules.py
index 9bf76ab4..bd76b7b4 100644
--- a/toolkit/config_modules.py
+++ b/toolkit/config_modules.py
@@ -66,7 +66,7 @@ class AdapterConfig:
         self.in_channels: int = kwargs.get('in_channels', 3)
         self.channels: List[int] = kwargs.get('channels', [320, 640, 1280, 1280])
         self.num_res_blocks: int = kwargs.get('num_res_blocks', 2)
-        self.downscale_factor: int = kwargs.get('downscale_factor', 16)
+        self.downscale_factor: int = kwargs.get('downscale_factor', 8)
         self.adapter_type: str = kwargs.get('adapter_type', 'full_adapter')
         self.image_dir: str = kwargs.get('image_dir', None)
         self.test_img_path: str = kwargs.get('test_img_path', None)
diff --git a/toolkit/lycoris_special.py b/toolkit/lycoris_special.py
index 05df65b3..36cdd8b7 100644
--- a/toolkit/lycoris_special.py
+++ b/toolkit/lycoris_special.py
@@ -119,13 +119,13 @@ class LycorisSpecialNetwork(ToolkitNetworkMixin, LycorisNetwork):
         # 'SiLU',
         # 'ModuleList',
         # 'DownBlock2D',
-        'ResnetBlock2D',  # need
+        # 'ResnetBlock2D',  # need
         # 'GroupNorm',
         # 'LoRACompatibleConv',
         # 'LoRACompatibleLinear',
         # 'Dropout',
         # 'CrossAttnDownBlock2D',  # needed
-        'Transformer2DModel',  # maybe not, has duplicates
+        # 'Transformer2DModel',  # maybe not, has duplicates
         # 'BasicTransformerBlock',  # duplicates
         # 'LayerNorm',
         # 'Attention',
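Note: commenting out `ResnetBlock2D` and `Transformer2DModel` shrinks the Lycoris target-module list to avoid the duplicate wrapping flagged in the comments. The `downscale_factor` default of 8 matches `diffusers.T2IAdapter`, whose `full_adapter` type pixel-unshuffles the conditioning image by that factor. A sketch of how the `AdapterConfig` fields appear to map onto the adapter constructor; the one-to-one mapping is inferred from the matching field names, not shown in this diff:

```python
from diffusers import T2IAdapter

# mirror of the AdapterConfig defaults above
adapter = T2IAdapter(
    in_channels=3,
    channels=[320, 640, 1280, 1280],
    num_res_blocks=2,
    downscale_factor=8,
    adapter_type='full_adapter',
)
```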
diff --git a/toolkit/network_mixins.py b/toolkit/network_mixins.py
index 925bb3b4..b592670f 100644
--- a/toolkit/network_mixins.py
+++ b/toolkit/network_mixins.py
@@ -57,35 +57,13 @@ class ToolkitModuleMixin:
         self.normalize_scaler = 1.0
         self._multiplier: Union[float, list, torch.Tensor] = None
 
-    # this allows us to set different multipliers on a per item in a batch basis
-    # allowing us to run positive and negative weights in the same batch
-    def set_multiplier(self: Module, multiplier):
-        device = self.lora_down.weight.device
-        dtype = self.lora_down.weight.dtype
-        with torch.no_grad():
-            tensor_multiplier = None
-            if isinstance(multiplier, int) or isinstance(multiplier, float):
-                tensor_multiplier = torch.tensor((multiplier,)).to(device, dtype=dtype)
-            elif isinstance(multiplier, list):
-                tensor_list = []
-                for m in multiplier:
-                    if isinstance(m, int) or isinstance(m, float):
-                        tensor_list.append(torch.tensor((m,)).to(device, dtype=dtype))
-                    elif isinstance(m, torch.Tensor):
-                        tensor_list.append(m.clone().detach().to(device, dtype=dtype))
-                tensor_multiplier = torch.cat(tensor_list)
-            elif isinstance(multiplier, torch.Tensor):
-                tensor_multiplier = multiplier.clone().detach().to(device, dtype=dtype)
-
-            self._multiplier = tensor_multiplier.clone().detach()
-
     def _call_forward(self: Module, x):
         # module dropout
         if self.module_dropout is not None and self.training:
             if torch.rand(1) < self.module_dropout:
                 return 0.0
         # added to original forward
-        if hasattr(self, 'lora_mid') and hasattr(self, 'cp') and self.cp:
+        if hasattr(self, 'lora_mid') and self.lora_mid is not None:
             lx = self.lora_mid(self.lora_down(x))
         else:
             try:
@@ -379,7 +357,7 @@ class ToolkitNetworkMixin:
         for lora in loras:
             lora.to(device, dtype)
 
-    def get_all_modules(self: Network):
+    def get_all_modules(self: Network) -> List[Module]:
         loras = []
         if hasattr(self, 'unet_loras'):
             loras += self.unet_loras
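Note: `_call_forward` now dispatches on `self.lora_mid is not None` instead of the `cp` flag, so the mid block is used exactly when it exists, and `get_all_modules` gains a `List[Module]` return annotation. A small illustrative loop over the typed result; `network` stands for any instance carrying `ToolkitNetworkMixin`, and the freezing use case is an assumption:

```python
# mirrors how `to()` above walks the modules; freezing is just an example
for lora_module in network.get_all_modules():
    lora_module.requires_grad_(False)
```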