From 002279cec335268639355a793aee60c4f4ba6f73 Mon Sep 17 00:00:00 2001 From: Jaret Burkett Date: Tue, 24 Oct 2023 16:02:07 -0600 Subject: [PATCH] Allow short and long caption combinations like from the new captioning system. Merge the network into the model before inference and re-extract when done. Doubles inference speed on locon models during inference. Allow splitting a batch into individual components and run them through alone. Basically gradient accumulation with single batch size. --- extensions_built_in/sd_trainer/SDTrainer.py | 230 ++++++++++++-------- jobs/process/BaseSDTrainProcess.py | 34 ++- toolkit/config_modules.py | 11 + toolkit/data_transfer_object/data_loader.py | 15 +- toolkit/dataloader_mixins.py | 41 ++-- toolkit/lora_special.py | 10 +- toolkit/lycoris_special.py | 1 + toolkit/network_mixins.py | 76 ++++++- toolkit/stable_diffusion_model.py | 12 + 9 files changed, 315 insertions(+), 115 deletions(-) diff --git a/extensions_built_in/sd_trainer/SDTrainer.py b/extensions_built_in/sd_trainer/SDTrainer.py index b644f8bb..c85e1d87 100644 --- a/extensions_built_in/sd_trainer/SDTrainer.py +++ b/extensions_built_in/sd_trainer/SDTrainer.py @@ -172,7 +172,9 @@ class SDTrainer(BaseSDTrainProcess): with torch.no_grad(): dtype = get_torch_dtype(self.train_config.dtype) # dont use network on this - self.network.multiplier = 0.0 + # self.network.multiplier = 0.0 + was_network_active = self.network.is_active + self.network.is_active = False self.sd.unet.eval() prior_pred = self.sd.predict_noise( latents=noisy_latents.to(self.device_torch, dtype=dtype).detach(), @@ -187,7 +189,8 @@ class SDTrainer(BaseSDTrainProcess): if match_adapter_assist and 'down_block_additional_residuals' in pred_kwargs: del pred_kwargs['down_block_additional_residuals'] # restore network - self.network.multiplier = network_weight_list + # self.network.multiplier = network_weight_list + self.network.is_active = was_network_active return prior_pred def hook_train_loop(self, batch: 
'DataLoaderBatchDTO'): @@ -197,6 +200,8 @@ class SDTrainer(BaseSDTrainProcess): dtype = get_torch_dtype(self.train_config.dtype) noisy_latents, noise, timesteps, conditioned_prompts, imgs = self.process_general_training_batch(batch) network_weight_list = batch.get_network_weight_list() + if self.train_config.single_item_batching: + network_weight_list = network_weight_list + network_weight_list has_adapter_img = batch.control_tensor is not None @@ -234,7 +239,7 @@ class SDTrainer(BaseSDTrainProcess): # sigmas = self.get_sigmas(timesteps, len(noisy_latents.shape), noisy_latents.dtype) # noisy_latents = noisy_latents / ((sigmas ** 2 + 1) ** 0.5) - mask_multiplier = 1.0 + mask_multiplier = torch.ones((noisy_latents.shape[0], 1, 1, 1), device=self.device_torch, dtype=dtype) if batch.mask_tensor is not None: with self.timer('get_mask_multiplier'): # upsampling no supported for bfloat16 @@ -297,107 +302,152 @@ class SDTrainer(BaseSDTrainProcess): self.optimizer.zero_grad(set_to_none=True) # activate network if it exits - with network: - with self.timer('encode_prompt'): - if grad_on_text_encoder: - with torch.set_grad_enabled(True): - conditional_embeds = self.sd.encode_prompt(conditioned_prompts, long_prompts=True).to( - self.device_torch, - dtype=dtype) - else: - with torch.set_grad_enabled(False): - # make sure it is in eval mode - if isinstance(self.sd.text_encoder, list): - for te in self.sd.text_encoder: - te.eval() - else: - self.sd.text_encoder.eval() - conditional_embeds = self.sd.encode_prompt(conditioned_prompts, long_prompts=True).to( - self.device_torch, - dtype=dtype) - # detach the embeddings - conditional_embeds = conditional_embeds.detach() + # make the batch splits + if self.train_config.single_item_batching: + batch_size = noisy_latents.shape[0] + # chunk/split everything + noisy_latents_list = torch.chunk(noisy_latents, batch_size, dim=0) + noise_list = torch.chunk(noise, batch_size, dim=0) + timesteps_list = torch.chunk(timesteps, batch_size, dim=0) 
+ conditioned_prompts_list = [[prompt] for prompt in conditioned_prompts] + if imgs is not None: + imgs_list = torch.chunk(imgs, batch_size, dim=0) + else: + imgs_list = [None for _ in range(batch_size)] + if adapter_images is not None: + adapter_images_list = torch.chunk(adapter_images, batch_size, dim=0) + else: + adapter_images_list = [None for _ in range(batch_size)] + mask_multiplier_list = torch.chunk(mask_multiplier, batch_size, dim=0) - # flush() - pred_kwargs = {} - if has_adapter_img and ((self.adapter and isinstance(self.adapter, T2IAdapter)) or self.assistant_adapter): - with torch.set_grad_enabled(self.adapter is not None): - adapter = self.adapter if self.adapter else self.assistant_adapter - adapter_multiplier = get_adapter_multiplier() + else: + # but it all in an array + noisy_latents_list = [noisy_latents] + noise_list = [noise] + timesteps_list = [timesteps] + conditioned_prompts_list = [conditioned_prompts] + imgs_list = [imgs] + adapter_images_list = [adapter_images] + mask_multiplier_list = [mask_multiplier] + + + for noisy_latents, noise, timesteps, conditioned_prompts, imgs, adapter_images, mask_multiplier in zip( + noisy_latents_list, + noise_list, + timesteps_list, + conditioned_prompts_list, + imgs_list, + adapter_images_list, + mask_multiplier_list + ): + + with network: + with self.timer('encode_prompt'): + if grad_on_text_encoder: + with torch.set_grad_enabled(True): + conditional_embeds = self.sd.encode_prompt(conditioned_prompts, long_prompts=True).to( + self.device_torch, + dtype=dtype) + else: + with torch.set_grad_enabled(False): + # make sure it is in eval mode + if isinstance(self.sd.text_encoder, list): + for te in self.sd.text_encoder: + te.eval() + else: + self.sd.text_encoder.eval() + conditional_embeds = self.sd.encode_prompt(conditioned_prompts, long_prompts=True).to( + self.device_torch, + dtype=dtype) + + # detach the embeddings + conditional_embeds = conditional_embeds.detach() + + # flush() + pred_kwargs = {} + if 
has_adapter_img and ((self.adapter and isinstance(self.adapter, T2IAdapter)) or self.assistant_adapter): + with torch.set_grad_enabled(self.adapter is not None): + adapter = self.adapter if self.adapter else self.assistant_adapter + adapter_multiplier = get_adapter_multiplier() + with self.timer('encode_adapter'): + down_block_additional_residuals = adapter(adapter_images) + if self.assistant_adapter: + # not training. detach + down_block_additional_residuals = [ + sample.to(dtype=dtype).detach() * adapter_multiplier for sample in + down_block_additional_residuals + ] + else: + down_block_additional_residuals = [ + sample.to(dtype=dtype) * adapter_multiplier for sample in + down_block_additional_residuals + ] + + pred_kwargs['down_block_additional_residuals'] = down_block_additional_residuals + + prior_pred = None + if (has_adapter_img and self.assistant_adapter and match_adapter_assist) or self.do_prior_prediction: + with self.timer('prior predict'): + prior_pred = self.get_prior_prediction( + noisy_latents=noisy_latents, + conditional_embeds=conditional_embeds, + match_adapter_assist=match_adapter_assist, + network_weight_list=network_weight_list, + timesteps=timesteps, + pred_kwargs=pred_kwargs, + noise=noise, + batch=batch, + ) + + if has_adapter_img and self.adapter and isinstance(self.adapter, IPAdapter): with self.timer('encode_adapter'): - down_block_additional_residuals = adapter(adapter_images) - if self.assistant_adapter: - # not training. 
detach - down_block_additional_residuals = [ - sample.to(dtype=dtype).detach() * adapter_multiplier for sample in - down_block_additional_residuals - ] - else: - down_block_additional_residuals = [ - sample.to(dtype=dtype) * adapter_multiplier for sample in - down_block_additional_residuals - ] + with torch.no_grad(): + conditional_clip_embeds = self.adapter.get_clip_image_embeds_from_tensors(adapter_images) + conditional_embeds = self.adapter(conditional_embeds, conditional_clip_embeds) - pred_kwargs['down_block_additional_residuals'] = down_block_additional_residuals - prior_pred = None - if (has_adapter_img and self.assistant_adapter and match_adapter_assist) or self.do_prior_prediction: - with self.timer('prior predict'): - prior_pred = self.get_prior_prediction( - noisy_latents=noisy_latents, - conditional_embeds=conditional_embeds, - match_adapter_assist=match_adapter_assist, - network_weight_list=network_weight_list, - timesteps=timesteps, - pred_kwargs=pred_kwargs, - noise=noise, - batch=batch, + + with self.timer('predict_unet'): + noise_pred = self.sd.predict_noise( + latents=noisy_latents.to(self.device_torch, dtype=dtype), + conditional_embeddings=conditional_embeds.to(self.device_torch, dtype=dtype), + timestep=timesteps, + guidance_scale=1.0, + **pred_kwargs ) - if has_adapter_img and self.adapter and isinstance(self.adapter, IPAdapter): - with self.timer('encode_adapter'): - with torch.no_grad(): - conditional_clip_embeds = self.adapter.get_clip_image_embeds_from_tensors(adapter_images) - conditional_embeds = self.adapter(conditional_embeds, conditional_clip_embeds) + with self.timer('calculate_loss'): + noise = noise.to(self.device_torch, dtype=dtype).detach() + loss = self.calculate_loss( + noise_pred=noise_pred, + noise=noise, + noisy_latents=noisy_latents, + timesteps=timesteps, + batch=batch, + mask_multiplier=mask_multiplier, + prior_pred=prior_pred, + ) + # check if nan + if torch.isnan(loss): + raise ValueError("loss is nan") - with 
self.timer('predict_unet'): - noise_pred = self.sd.predict_noise( - latents=noisy_latents.to(self.device_torch, dtype=dtype), - conditional_embeddings=conditional_embeds.to(self.device_torch, dtype=dtype), - timestep=timesteps, - guidance_scale=1.0, - **pred_kwargs - ) + with self.timer('backward'): + # IMPORTANT if gradient checkpointing do not leave with network when doing backward + # it will destroy the gradients. This is because the network is a context manager + # and will change the multipliers back to 0.0 when exiting. They will be + # 0.0 for the backward pass and the gradients will be 0.0 + # I spent weeks on fighting this. DON'T DO IT + # with fsdp_overlap_step_with_backward(): + loss.backward() - with self.timer('calculate_loss'): - noise = noise.to(self.device_torch, dtype=dtype).detach() - loss = self.calculate_loss( - noise_pred=noise_pred, - noise=noise, - noisy_latents=noisy_latents, - timesteps=timesteps, - batch=batch, - mask_multiplier=mask_multiplier, - prior_pred=prior_pred, - ) - # check if nan - if torch.isnan(loss): - raise ValueError("loss is nan") - - with self.timer('backward'): - # IMPORTANT if gradient checkpointing do not leave with network when doing backward - # it will destroy the gradients. This is because the network is a context manager - # and will change the multipliers back to 0.0 when exiting. They will be - # 0.0 for the backward pass and the gradients will be 0.0 - # I spent weeks on fighting this. 
DON'T DO IT - loss.backward() - torch.nn.utils.clip_grad_norm_(self.params, self.train_config.max_grad_norm) + torch.nn.utils.clip_grad_norm_(self.params, self.train_config.max_grad_norm) # flush() with self.timer('optimizer_step'): # apply gradients self.optimizer.step() + self.optimizer.zero_grad(set_to_none=True) with self.timer('scheduler_step'): self.lr_scheduler.step() diff --git a/jobs/process/BaseSDTrainProcess.py b/jobs/process/BaseSDTrainProcess.py index 1476aaec..70ac192f 100644 --- a/jobs/process/BaseSDTrainProcess.py +++ b/jobs/process/BaseSDTrainProcess.py @@ -460,6 +460,17 @@ class BaseSDTrainProcess(BaseTrainProcess): prompts = batch.get_caption_list() is_reg_list = batch.get_is_reg_list() + is_any_reg = any([is_reg for is_reg in is_reg_list]) + + do_double = self.train_config.short_and_long_captions and not is_any_reg + + if self.train_config.short_and_long_captions and do_double: + # dont do this with regs. No point + + # double batch and add short captions to the end + prompts = prompts + batch.get_caption_short_list() + is_reg_list = is_reg_list + is_reg_list + conditioned_prompts = [] for prompt, is_reg in zip(prompts, is_reg_list): @@ -500,7 +511,7 @@ class BaseSDTrainProcess(BaseTrainProcess): # we determine noise from the differential of the latents unaugmented_latents = self.sd.encode_images(batch.unaugmented_tensor) - batch_size = latents.shape[0] + batch_size = len(batch.file_items) with self.timer('prepare_noise'): @@ -582,6 +593,21 @@ class BaseSDTrainProcess(BaseTrainProcess): # todo is this for sdxl? 
find out where this came from originally # noisy_latents = noisy_latents / ((sigmas ** 2 + 1) ** 0.5) + def double_up_tensor(tensor: torch.Tensor): + if tensor is None: + return None + return torch.cat([tensor, tensor], dim=0) + + if do_double: + noisy_latents = double_up_tensor(noisy_latents) + noise = double_up_tensor(noise) + timesteps = double_up_tensor(timesteps) + # prompts are already updated above + imgs = double_up_tensor(imgs) + batch.mask_tensor = double_up_tensor(batch.mask_tensor) + batch.control_tensor = double_up_tensor(batch.control_tensor) + + # remove grads for these noisy_latents.requires_grad = False noisy_latents = noisy_latents.detach() @@ -927,16 +953,16 @@ class BaseSDTrainProcess(BaseTrainProcess): ### HOOK ### self.hook_before_train_loop() - if self.has_first_sample_requested: + if self.has_first_sample_requested and self.step_num <= 1: self.print("Generating first sample from first sample config") self.sample(0, is_first=True) # sample first if self.train_config.skip_first_sample: self.print("Skipping first sample due to config setting") - else: + elif self.step_num <= 1: self.print("Generating baseline samples before training") - self.sample(0) + self.sample(self.step_num) self.progress_bar = ToolkitProgressBar( total=self.train_config.steps, diff --git a/toolkit/config_modules.py b/toolkit/config_modules.py index 0d214ae5..60032160 100644 --- a/toolkit/config_modules.py +++ b/toolkit/config_modules.py @@ -125,6 +125,17 @@ class TrainConfig: self.noise_multiplier = kwargs.get('noise_multiplier', 1.0) self.img_multiplier = kwargs.get('img_multiplier', 1.0) + # short long captions will double your batch size. This only works when a dataset is + # prepared with a json caption file that has both short and long captions in it. It will + # Double up every image and run it through with both short and long captions. 
The idea + # is that the network will learn how to generate good images with both short and long captions + self.short_and_long_captions = kwargs.get('short_and_long_captions', False) + + # basically gradient accumulation but we run just 1 item through the network + # and accumulate gradients. This can be used as basic gradient accumulation but is very helpful + # for training tricks that increase batch size but need a single gradient step + self.single_item_batching = kwargs.get('single_item_batching', False) + match_adapter_assist = kwargs.get('match_adapter_assist', False) self.match_adapter_chance = kwargs.get('match_adapter_chance', 0.0) self.loss_target: LossTarget = kwargs.get('loss_target', 'noise') # noise, source, unaugmented, differential_noise diff --git a/toolkit/data_transfer_object/data_loader.py b/toolkit/data_transfer_object/data_loader.py index 32012dc0..ff3b7bf7 100644 --- a/toolkit/data_transfer_object/data_loader.py +++ b/toolkit/data_transfer_object/data_loader.py @@ -82,7 +82,7 @@ class DataLoaderBatchDTO: self.control_tensor: Union[torch.Tensor, None] = None self.mask_tensor: Union[torch.Tensor, None] = None self.unaugmented_tensor: Union[torch.Tensor, None] = None - self.sigmas: Union[torch.Tensor, None] = None # can be added elseware and passed along training code + self.sigmas: Union[torch.Tensor, None] = None # can be added elseware and passed along training code if not is_latents_cached: # only return a tensor if latents are not cached self.tensor: torch.Tensor = torch.cat([x.tensor.unsqueeze(0) for x in self.file_items]) @@ -160,6 +160,19 @@ class DataLoaderBatchDTO: add_if_not_present=add_if_not_present ) for x in self.file_items] + def get_caption_short_list( + self, + trigger=None, + to_replace_list=None, + add_if_not_present=True + ): + return [x.get_caption( + trigger=trigger, + to_replace_list=to_replace_list, + add_if_not_present=add_if_not_present, + short_caption=False + ) for x in self.file_items] + def cleanup(self): del 
self.latents del self.tensor diff --git a/toolkit/dataloader_mixins.py b/toolkit/dataloader_mixins.py index 9ea062b5..589a22cd 100644 --- a/toolkit/dataloader_mixins.py +++ b/toolkit/dataloader_mixins.py @@ -55,6 +55,7 @@ transforms_dict = { caption_ext_list = ['txt', 'json', 'caption'] + def clean_caption(caption): # remove any newlines caption = caption.replace('\n', ', ') @@ -227,6 +228,8 @@ class CaptionProcessingDTOMixin: def __init__(self: 'FileItemDTO', *args, **kwargs): if hasattr(super(), '__init__'): super().__init__(*args, **kwargs) + self.raw_caption: str = None + self.raw_caption_short: str = None # todo allow for loading from sd-scripts style dict def load_caption(self: 'FileItemDTO', caption_dict: Union[dict, None]): @@ -235,15 +238,19 @@ class CaptionProcessingDTOMixin: pass elif caption_dict is not None and self.path in caption_dict and "caption" in caption_dict[self.path]: self.raw_caption = caption_dict[self.path]["caption"] + if 'caption_short' in caption_dict[self.path]: + self.raw_caption_short = caption_dict[self.path]["caption_short"] else: # see if prompt file exists path_no_ext = os.path.splitext(self.path)[0] prompt_ext = self.dataset_config.caption_ext prompt_path = f"{path_no_ext}.{prompt_ext}" + short_caption = None if os.path.exists(prompt_path): with open(prompt_path, 'r', encoding='utf-8') as f: prompt = f.read() + short_caption = None if prompt_path.endswith('.json'): # replace any line endings with commas for \n \r \r\n prompt = prompt.replace('\r\n', ' ') @@ -253,32 +260,36 @@ class CaptionProcessingDTOMixin: prompt = json.loads(prompt) if 'caption' in prompt: prompt = prompt['caption'] - # remove any newlines - prompt = prompt.replace('\n', ', ') - # remove new lines for all operating systems - prompt = prompt.replace('\r', ', ') - prompt_split = prompt.split(',') - # remove empty strings - prompt_split = [p.strip() for p in prompt_split if p.strip()] - # join back together - prompt = ', '.join(prompt_split) + if 'caption_short' 
in prompt: + short_caption = prompt['caption_short'] + prompt = clean_caption(prompt) + if short_caption is not None: + short_caption = clean_caption(short_caption) else: prompt = '' if self.dataset_config.default_caption is not None: prompt = self.dataset_config.default_caption + + if short_caption is None: + short_caption = self.dataset_config.default_caption self.raw_caption = prompt + self.raw_caption_short = short_caption def get_caption( self: 'FileItemDTO', trigger=None, to_replace_list=None, - add_if_not_present=False + add_if_not_present=False, + short_caption=False ): - raw_caption = self.raw_caption + if short_caption: + raw_caption = self.raw_caption_short + else: + raw_caption = self.raw_caption if raw_caption is None: raw_caption = '' # handle dropout - if self.dataset_config.caption_dropout_rate > 0: + if self.dataset_config.caption_dropout_rate > 0 and not short_caption: # get a random float form 0 to 1 rand = random.random() if rand < self.dataset_config.caption_dropout_rate: @@ -296,7 +307,7 @@ class CaptionProcessingDTOMixin: random.shuffle(token_list) # handle token dropout - if self.dataset_config.token_dropout_rate > 0: + if self.dataset_config.token_dropout_rate > 0 and not short_caption: new_token_list = [] for token in token_list: # get a random float form 0 to 1 @@ -845,7 +856,8 @@ class LatentCachingMixin: self.sd.set_device_state_preset('cache_latents') # use tqdm to show progress - for i, file_item in tqdm(enumerate(self.file_list), desc=f'Caching latents{" to disk" if to_disk else ""}'): + i = 0 + for file_item in tqdm(self.file_list, desc=f'Caching latents{" to disk" if to_disk else ""}'): # set latent space version if self.sd.is_xl: file_item.latent_space_version = 'sdxl' @@ -891,6 +903,7 @@ class LatentCachingMixin: flush(garbage_collect=False) file_item.is_latent_cached = True + i += 1 # flush every 100 # if i % 100 == 0: # flush() diff --git a/toolkit/lora_special.py b/toolkit/lora_special.py index 5e599227..21d5cc5c 100644 --- 
a/toolkit/lora_special.py +++ b/toolkit/lora_special.py @@ -89,7 +89,8 @@ class LoRAModule(ToolkitModuleMixin, torch.nn.Module): torch.nn.init.zeros_(self.lora_up.weight) self.multiplier: Union[float, List[float]] = multiplier - self.org_module = org_module # remove in applying + # wrap the original module so it doesn't get weights updated + self.org_module = [org_module] self.dropout = dropout self.rank_dropout = rank_dropout self.module_dropout = module_dropout @@ -98,9 +99,9 @@ class LoRAModule(ToolkitModuleMixin, torch.nn.Module): self.normalize_scaler = 1.0 def apply_to(self): - self.org_forward = self.org_module.forward - self.org_module.forward = self.forward - del self.org_module + self.org_forward = self.org_module[0].forward + self.org_module[0].forward = self.forward + # del self.org_module class LoRASpecialNetwork(ToolkitNetworkMixin, LoRANetwork): @@ -170,6 +171,7 @@ class LoRASpecialNetwork(ToolkitNetworkMixin, LoRANetwork): self.multiplier = multiplier self.is_sdxl = is_sdxl self.is_v2 = is_v2 + self.is_merged_in = False if modules_dim is not None: print(f"create LoRA network from weights") diff --git a/toolkit/lycoris_special.py b/toolkit/lycoris_special.py index d007de6c..f38556e7 100644 --- a/toolkit/lycoris_special.py +++ b/toolkit/lycoris_special.py @@ -42,6 +42,7 @@ class LoConSpecialModule(ToolkitModuleMixin, LoConModule): self.lora_dim = lora_dim self.cp = False + self.scalar = nn.Parameter(torch.tensor(0.0)) orig_module_name = org_module.__class__.__name__ if orig_module_name in CONV_MODULES: diff --git a/toolkit/network_mixins.py b/toolkit/network_mixins.py index da59f4a0..c62de64b 100644 --- a/toolkit/network_mixins.py +++ b/toolkit/network_mixins.py @@ -103,7 +103,21 @@ class ToolkitModuleMixin: # this may get an additional positional arg or not def forward(self: Module, x, *args, **kwargs): - if not self.network_ref().is_active: + skip = False + network = self.network_ref() + # skip if not active + if not network.is_active: + skip = True 
+ + # skip if is merged in + if network.is_merged_in: + skip = True + + # skip if multiplier is 0 + if network._multiplier == 0: + skip = True + + if skip: # network is not active, avoid doing anything return self.org_forward(x, *args, **kwargs) @@ -191,6 +205,52 @@ class ToolkitModuleMixin: # reset the normalization scaler self.normalize_scaler = target_normalize_scaler + @torch.no_grad() + def merge_out(self: Module, merge_out_weight=1.0): + # make sure it is positive + merge_out_weight = abs(merge_out_weight) + # merging out is just merging in the negative of the weight + self.merge_in(merge_weight=-merge_out_weight) + + @torch.no_grad() + def merge_in(self: Module, merge_weight=1.0): + # get up/down weight + up_weight = self.lora_up.weight.clone().float() + down_weight = self.lora_down.weight.clone().float() + + # extract weight from org_module + org_sd = self.org_module[0].state_dict() + orig_dtype = org_sd["weight"].dtype + weight = org_sd["weight"].float() + + multiplier = merge_weight + scale = self.scale + # handle trainable scaler method locon does + if hasattr(self, 'scalar'): + scale = scale * self.scalar + + # merge weight + if len(weight.size()) == 2: + # linear + weight = weight + multiplier * (up_weight @ down_weight) * scale + elif down_weight.size()[2:4] == (1, 1): + # conv2d 1x1 + weight = ( + weight + + multiplier + * (up_weight.squeeze(3).squeeze(2) @ down_weight.squeeze(3).squeeze(2)).unsqueeze(2).unsqueeze(3) + * scale + ) + else: + # conv2d 3x3 + conved = torch.nn.functional.conv2d(down_weight.permute(1, 0, 2, 3), up_weight).permute(1, 0, 2, 3) + # print(conved.size(), weight.size(), module.stride, module.padding) + weight = weight + multiplier * conved * scale + + # set weight to org_module + org_sd["weight"] = weight.to(orig_dtype) + self.org_module[0].load_state_dict(org_sd) + class ToolkitNetworkMixin: def __init__( @@ -210,6 +270,7 @@ class ToolkitNetworkMixin: self._is_normalizing: bool = False self.is_sdxl = is_sdxl self.is_v2 = is_v2 
+ self.is_merged_in = False # super().__init__(*args, **kwargs) def get_keymap(self: Network): @@ -326,7 +387,6 @@ class ToolkitNetworkMixin: self.torch_multiplier = tensor_multiplier.clone().detach() - @property def multiplier(self) -> Union[float, List[float]]: return self._multiplier @@ -396,3 +456,15 @@ class ToolkitNetworkMixin: def apply_stored_normalizer(self: Network, target_normalize_scaler: float = 1.0): for module in self.get_all_modules(): module.apply_stored_normalizer(target_normalize_scaler) + + def merge_in(self, merge_weight=1.0): + self.is_merged_in = True + for module in self.get_all_modules(): + module.merge_in(merge_weight) + + def merge_out(self, merge_weight=1.0): + if not self.is_merged_in: + return + self.is_merged_in = False + for module in self.get_all_modules(): + module.merge_out(merge_weight) diff --git a/toolkit/stable_diffusion_model.py b/toolkit/stable_diffusion_model.py index dd17406d..effa014d 100644 --- a/toolkit/stable_diffusion_model.py +++ b/toolkit/stable_diffusion_model.py @@ -62,6 +62,7 @@ class BlankNetwork: self.multiplier = 1.0 self.is_active = True self.is_normalizing = False + self.is_merged_in = False def apply_stored_normalizer(self, target_normalize_scaler: float = 1.0): pass @@ -267,10 +268,18 @@ class StableDiffusion: @torch.no_grad() def generate_images(self, image_configs: List[GenerateImageConfig], sampler=None): + merge_multiplier = 1.0 # sample_folder = os.path.join(self.save_root, 'samples') if self.network is not None: self.network.eval() network = self.network + # check if we have the same network weight for all samples. 
If we do, we can merge in th + # the network to drastically speed up inference + unique_network_weights = set([x.network_multiplier for x in image_configs]) + if len(unique_network_weights) == 1: + can_merge_in = True + merge_multiplier = unique_network_weights.pop() + network.merge_in(merge_weight=merge_multiplier) else: network = BlankNetwork() @@ -462,6 +471,9 @@ class StableDiffusion: self.network.train() self.network.multiplier = start_multiplier self.network.is_normalizing = was_network_normalizing + + if network.is_merged_in: + network.merge_out(merge_multiplier) # self.tokenizer.to(original_device_dict['tokenizer']) def get_latent_noise(