diff --git a/extensions_built_in/sd_trainer/SDTrainer.py b/extensions_built_in/sd_trainer/SDTrainer.py index b644f8bb..c85e1d87 100644 --- a/extensions_built_in/sd_trainer/SDTrainer.py +++ b/extensions_built_in/sd_trainer/SDTrainer.py @@ -172,7 +172,9 @@ class SDTrainer(BaseSDTrainProcess): with torch.no_grad(): dtype = get_torch_dtype(self.train_config.dtype) # dont use network on this - self.network.multiplier = 0.0 + # self.network.multiplier = 0.0 + was_network_active = self.network.is_active + self.network.is_active = False self.sd.unet.eval() prior_pred = self.sd.predict_noise( latents=noisy_latents.to(self.device_torch, dtype=dtype).detach(), @@ -187,7 +189,8 @@ class SDTrainer(BaseSDTrainProcess): if match_adapter_assist and 'down_block_additional_residuals' in pred_kwargs: del pred_kwargs['down_block_additional_residuals'] # restore network - self.network.multiplier = network_weight_list + # self.network.multiplier = network_weight_list + self.network.is_active = was_network_active return prior_pred def hook_train_loop(self, batch: 'DataLoaderBatchDTO'): @@ -197,6 +200,8 @@ class SDTrainer(BaseSDTrainProcess): dtype = get_torch_dtype(self.train_config.dtype) noisy_latents, noise, timesteps, conditioned_prompts, imgs = self.process_general_training_batch(batch) network_weight_list = batch.get_network_weight_list() + if self.train_config.single_item_batching: + network_weight_list = network_weight_list + network_weight_list has_adapter_img = batch.control_tensor is not None @@ -234,7 +239,7 @@ class SDTrainer(BaseSDTrainProcess): # sigmas = self.get_sigmas(timesteps, len(noisy_latents.shape), noisy_latents.dtype) # noisy_latents = noisy_latents / ((sigmas ** 2 + 1) ** 0.5) - mask_multiplier = 1.0 + mask_multiplier = torch.ones((noisy_latents.shape[0], 1, 1, 1), device=self.device_torch, dtype=dtype) if batch.mask_tensor is not None: with self.timer('get_mask_multiplier'): # upsampling no supported for bfloat16 @@ -297,107 +302,152 @@ class 
SDTrainer(BaseSDTrainProcess): self.optimizer.zero_grad(set_to_none=True) # activate network if it exits - with network: - with self.timer('encode_prompt'): - if grad_on_text_encoder: - with torch.set_grad_enabled(True): - conditional_embeds = self.sd.encode_prompt(conditioned_prompts, long_prompts=True).to( - self.device_torch, - dtype=dtype) - else: - with torch.set_grad_enabled(False): - # make sure it is in eval mode - if isinstance(self.sd.text_encoder, list): - for te in self.sd.text_encoder: - te.eval() - else: - self.sd.text_encoder.eval() - conditional_embeds = self.sd.encode_prompt(conditioned_prompts, long_prompts=True).to( - self.device_torch, - dtype=dtype) - # detach the embeddings - conditional_embeds = conditional_embeds.detach() + # make the batch splits + if self.train_config.single_item_batching: + batch_size = noisy_latents.shape[0] + # chunk/split everything + noisy_latents_list = torch.chunk(noisy_latents, batch_size, dim=0) + noise_list = torch.chunk(noise, batch_size, dim=0) + timesteps_list = torch.chunk(timesteps, batch_size, dim=0) + conditioned_prompts_list = [[prompt] for prompt in conditioned_prompts] + if imgs is not None: + imgs_list = torch.chunk(imgs, batch_size, dim=0) + else: + imgs_list = [None for _ in range(batch_size)] + if adapter_images is not None: + adapter_images_list = torch.chunk(adapter_images, batch_size, dim=0) + else: + adapter_images_list = [None for _ in range(batch_size)] + mask_multiplier_list = torch.chunk(mask_multiplier, batch_size, dim=0) - # flush() - pred_kwargs = {} - if has_adapter_img and ((self.adapter and isinstance(self.adapter, T2IAdapter)) or self.assistant_adapter): - with torch.set_grad_enabled(self.adapter is not None): - adapter = self.adapter if self.adapter else self.assistant_adapter - adapter_multiplier = get_adapter_multiplier() + else: + # but it all in an array + noisy_latents_list = [noisy_latents] + noise_list = [noise] + timesteps_list = [timesteps] + conditioned_prompts_list = 
[conditioned_prompts] + imgs_list = [imgs] + adapter_images_list = [adapter_images] + mask_multiplier_list = [mask_multiplier] + + + for noisy_latents, noise, timesteps, conditioned_prompts, imgs, adapter_images, mask_multiplier in zip( + noisy_latents_list, + noise_list, + timesteps_list, + conditioned_prompts_list, + imgs_list, + adapter_images_list, + mask_multiplier_list + ): + + with network: + with self.timer('encode_prompt'): + if grad_on_text_encoder: + with torch.set_grad_enabled(True): + conditional_embeds = self.sd.encode_prompt(conditioned_prompts, long_prompts=True).to( + self.device_torch, + dtype=dtype) + else: + with torch.set_grad_enabled(False): + # make sure it is in eval mode + if isinstance(self.sd.text_encoder, list): + for te in self.sd.text_encoder: + te.eval() + else: + self.sd.text_encoder.eval() + conditional_embeds = self.sd.encode_prompt(conditioned_prompts, long_prompts=True).to( + self.device_torch, + dtype=dtype) + + # detach the embeddings + conditional_embeds = conditional_embeds.detach() + + # flush() + pred_kwargs = {} + if has_adapter_img and ((self.adapter and isinstance(self.adapter, T2IAdapter)) or self.assistant_adapter): + with torch.set_grad_enabled(self.adapter is not None): + adapter = self.adapter if self.adapter else self.assistant_adapter + adapter_multiplier = get_adapter_multiplier() + with self.timer('encode_adapter'): + down_block_additional_residuals = adapter(adapter_images) + if self.assistant_adapter: + # not training. 
detach + down_block_additional_residuals = [ + sample.to(dtype=dtype).detach() * adapter_multiplier for sample in + down_block_additional_residuals + ] + else: + down_block_additional_residuals = [ + sample.to(dtype=dtype) * adapter_multiplier for sample in + down_block_additional_residuals + ] + + pred_kwargs['down_block_additional_residuals'] = down_block_additional_residuals + + prior_pred = None + if (has_adapter_img and self.assistant_adapter and match_adapter_assist) or self.do_prior_prediction: + with self.timer('prior predict'): + prior_pred = self.get_prior_prediction( + noisy_latents=noisy_latents, + conditional_embeds=conditional_embeds, + match_adapter_assist=match_adapter_assist, + network_weight_list=network_weight_list, + timesteps=timesteps, + pred_kwargs=pred_kwargs, + noise=noise, + batch=batch, + ) + + if has_adapter_img and self.adapter and isinstance(self.adapter, IPAdapter): with self.timer('encode_adapter'): - down_block_additional_residuals = adapter(adapter_images) - if self.assistant_adapter: - # not training. 
detach - down_block_additional_residuals = [ - sample.to(dtype=dtype).detach() * adapter_multiplier for sample in - down_block_additional_residuals - ] - else: - down_block_additional_residuals = [ - sample.to(dtype=dtype) * adapter_multiplier for sample in - down_block_additional_residuals - ] + with torch.no_grad(): + conditional_clip_embeds = self.adapter.get_clip_image_embeds_from_tensors(adapter_images) + conditional_embeds = self.adapter(conditional_embeds, conditional_clip_embeds) - pred_kwargs['down_block_additional_residuals'] = down_block_additional_residuals - prior_pred = None - if (has_adapter_img and self.assistant_adapter and match_adapter_assist) or self.do_prior_prediction: - with self.timer('prior predict'): - prior_pred = self.get_prior_prediction( - noisy_latents=noisy_latents, - conditional_embeds=conditional_embeds, - match_adapter_assist=match_adapter_assist, - network_weight_list=network_weight_list, - timesteps=timesteps, - pred_kwargs=pred_kwargs, - noise=noise, - batch=batch, + + with self.timer('predict_unet'): + noise_pred = self.sd.predict_noise( + latents=noisy_latents.to(self.device_torch, dtype=dtype), + conditional_embeddings=conditional_embeds.to(self.device_torch, dtype=dtype), + timestep=timesteps, + guidance_scale=1.0, + **pred_kwargs ) - if has_adapter_img and self.adapter and isinstance(self.adapter, IPAdapter): - with self.timer('encode_adapter'): - with torch.no_grad(): - conditional_clip_embeds = self.adapter.get_clip_image_embeds_from_tensors(adapter_images) - conditional_embeds = self.adapter(conditional_embeds, conditional_clip_embeds) + with self.timer('calculate_loss'): + noise = noise.to(self.device_torch, dtype=dtype).detach() + loss = self.calculate_loss( + noise_pred=noise_pred, + noise=noise, + noisy_latents=noisy_latents, + timesteps=timesteps, + batch=batch, + mask_multiplier=mask_multiplier, + prior_pred=prior_pred, + ) + # check if nan + if torch.isnan(loss): + raise ValueError("loss is nan") - with 
self.timer('predict_unet'): - noise_pred = self.sd.predict_noise( - latents=noisy_latents.to(self.device_torch, dtype=dtype), - conditional_embeddings=conditional_embeds.to(self.device_torch, dtype=dtype), - timestep=timesteps, - guidance_scale=1.0, - **pred_kwargs - ) + with self.timer('backward'): + # IMPORTANT if gradient checkpointing do not leave with network when doing backward + # it will destroy the gradients. This is because the network is a context manager + # and will change the multipliers back to 0.0 when exiting. They will be + # 0.0 for the backward pass and the gradients will be 0.0 + # I spent weeks on fighting this. DON'T DO IT + # with fsdp_overlap_step_with_backward(): + loss.backward() - with self.timer('calculate_loss'): - noise = noise.to(self.device_torch, dtype=dtype).detach() - loss = self.calculate_loss( - noise_pred=noise_pred, - noise=noise, - noisy_latents=noisy_latents, - timesteps=timesteps, - batch=batch, - mask_multiplier=mask_multiplier, - prior_pred=prior_pred, - ) - # check if nan - if torch.isnan(loss): - raise ValueError("loss is nan") - - with self.timer('backward'): - # IMPORTANT if gradient checkpointing do not leave with network when doing backward - # it will destroy the gradients. This is because the network is a context manager - # and will change the multipliers back to 0.0 when exiting. They will be - # 0.0 for the backward pass and the gradients will be 0.0 - # I spent weeks on fighting this. 
DON'T DO IT - loss.backward() - torch.nn.utils.clip_grad_norm_(self.params, self.train_config.max_grad_norm) + torch.nn.utils.clip_grad_norm_(self.params, self.train_config.max_grad_norm) # flush() with self.timer('optimizer_step'): # apply gradients self.optimizer.step() + self.optimizer.zero_grad(set_to_none=True) with self.timer('scheduler_step'): self.lr_scheduler.step() diff --git a/jobs/process/BaseSDTrainProcess.py b/jobs/process/BaseSDTrainProcess.py index 1476aaec..70ac192f 100644 --- a/jobs/process/BaseSDTrainProcess.py +++ b/jobs/process/BaseSDTrainProcess.py @@ -460,6 +460,17 @@ class BaseSDTrainProcess(BaseTrainProcess): prompts = batch.get_caption_list() is_reg_list = batch.get_is_reg_list() + is_any_reg = any([is_reg for is_reg in is_reg_list]) + + do_double = self.train_config.short_and_long_captions and not is_any_reg + + if self.train_config.short_and_long_captions and do_double: + # dont do this with regs. No point + + # double batch and add short captions to the end + prompts = prompts + batch.get_caption_short_list() + is_reg_list = is_reg_list + is_reg_list + conditioned_prompts = [] for prompt, is_reg in zip(prompts, is_reg_list): @@ -500,7 +511,7 @@ class BaseSDTrainProcess(BaseTrainProcess): # we determine noise from the differential of the latents unaugmented_latents = self.sd.encode_images(batch.unaugmented_tensor) - batch_size = latents.shape[0] + batch_size = len(batch.file_items) with self.timer('prepare_noise'): @@ -582,6 +593,21 @@ class BaseSDTrainProcess(BaseTrainProcess): # todo is this for sdxl? 
find out where this came from originally # noisy_latents = noisy_latents / ((sigmas ** 2 + 1) ** 0.5) + def double_up_tensor(tensor: torch.Tensor): + if tensor is None: + return None + return torch.cat([tensor, tensor], dim=0) + + if do_double: + noisy_latents = double_up_tensor(noisy_latents) + noise = double_up_tensor(noise) + timesteps = double_up_tensor(timesteps) + # prompts are already updated above + imgs = double_up_tensor(imgs) + batch.mask_tensor = double_up_tensor(batch.mask_tensor) + batch.control_tensor = double_up_tensor(batch.control_tensor) + + # remove grads for these noisy_latents.requires_grad = False noisy_latents = noisy_latents.detach() @@ -927,16 +953,16 @@ class BaseSDTrainProcess(BaseTrainProcess): ### HOOK ### self.hook_before_train_loop() - if self.has_first_sample_requested: + if self.has_first_sample_requested and self.step_num <= 1: self.print("Generating first sample from first sample config") self.sample(0, is_first=True) # sample first if self.train_config.skip_first_sample: self.print("Skipping first sample due to config setting") - else: + elif self.step_num <= 1: self.print("Generating baseline samples before training") - self.sample(0) + self.sample(self.step_num) self.progress_bar = ToolkitProgressBar( total=self.train_config.steps, diff --git a/toolkit/config_modules.py b/toolkit/config_modules.py index 0d214ae5..60032160 100644 --- a/toolkit/config_modules.py +++ b/toolkit/config_modules.py @@ -125,6 +125,17 @@ class TrainConfig: self.noise_multiplier = kwargs.get('noise_multiplier', 1.0) self.img_multiplier = kwargs.get('img_multiplier', 1.0) + # short long captions will double your batch size. This only works when a dataset is + # prepared with a json caption file that has both short and long captions in it. It will + # Double up every image and run it through with both short and long captions. 
The idea + # is that the network will learn how to generate good images with both short and long captions + self.short_and_long_captions = kwargs.get('short_and_long_captions', False) + + # basically gradient accumulation but we run just 1 item through the network + # and accumulate gradients. This can be used as basic gradient accumulation but is very helpful + # for training tricks that increase batch size but need a single gradient step + self.single_item_batching = kwargs.get('single_item_batching', False) + match_adapter_assist = kwargs.get('match_adapter_assist', False) self.match_adapter_chance = kwargs.get('match_adapter_chance', 0.0) self.loss_target: LossTarget = kwargs.get('loss_target', 'noise') # noise, source, unaugmented, differential_noise diff --git a/toolkit/data_transfer_object/data_loader.py b/toolkit/data_transfer_object/data_loader.py index 32012dc0..ff3b7bf7 100644 --- a/toolkit/data_transfer_object/data_loader.py +++ b/toolkit/data_transfer_object/data_loader.py @@ -82,7 +82,7 @@ class DataLoaderBatchDTO: self.control_tensor: Union[torch.Tensor, None] = None self.mask_tensor: Union[torch.Tensor, None] = None self.unaugmented_tensor: Union[torch.Tensor, None] = None - self.sigmas: Union[torch.Tensor, None] = None # can be added elseware and passed along training code + self.sigmas: Union[torch.Tensor, None] = None # can be added elseware and passed along training code if not is_latents_cached: # only return a tensor if latents are not cached self.tensor: torch.Tensor = torch.cat([x.tensor.unsqueeze(0) for x in self.file_items]) @@ -160,6 +160,19 @@ class DataLoaderBatchDTO: add_if_not_present=add_if_not_present ) for x in self.file_items] + def get_caption_short_list( + self, + trigger=None, + to_replace_list=None, + add_if_not_present=True + ): + return [x.get_caption( + trigger=trigger, + to_replace_list=to_replace_list, + add_if_not_present=add_if_not_present, + short_caption=True # select raw_caption_short; False would return the long caption again + ) for x in self.file_items] + def cleanup(self): del 
self.latents del self.tensor diff --git a/toolkit/dataloader_mixins.py b/toolkit/dataloader_mixins.py index 9ea062b5..589a22cd 100644 --- a/toolkit/dataloader_mixins.py +++ b/toolkit/dataloader_mixins.py @@ -55,6 +55,7 @@ transforms_dict = { caption_ext_list = ['txt', 'json', 'caption'] + def clean_caption(caption): # remove any newlines caption = caption.replace('\n', ', ') @@ -227,6 +228,8 @@ class CaptionProcessingDTOMixin: def __init__(self: 'FileItemDTO', *args, **kwargs): if hasattr(super(), '__init__'): super().__init__(*args, **kwargs) + self.raw_caption: str = None + self.raw_caption_short: str = None # todo allow for loading from sd-scripts style dict def load_caption(self: 'FileItemDTO', caption_dict: Union[dict, None]): @@ -235,15 +238,19 @@ class CaptionProcessingDTOMixin: pass elif caption_dict is not None and self.path in caption_dict and "caption" in caption_dict[self.path]: self.raw_caption = caption_dict[self.path]["caption"] + if 'caption_short' in caption_dict[self.path]: + self.raw_caption_short = caption_dict[self.path]["caption_short"] else: # see if prompt file exists path_no_ext = os.path.splitext(self.path)[0] prompt_ext = self.dataset_config.caption_ext prompt_path = f"{path_no_ext}.{prompt_ext}" + short_caption = None if os.path.exists(prompt_path): with open(prompt_path, 'r', encoding='utf-8') as f: prompt = f.read() + short_caption = None if prompt_path.endswith('.json'): # replace any line endings with commas for \n \r \r\n prompt = prompt.replace('\r\n', ' ') @@ -253,32 +260,36 @@ class CaptionProcessingDTOMixin: prompt = json.loads(prompt) + if 'caption_short' in prompt: # must run while prompt is still the parsed JSON object + short_caption = prompt['caption_short'] if 'caption' in prompt: prompt = prompt['caption'] - # remove any newlines - prompt = prompt.replace('\n', ', ') - # remove new lines for all operating systems - prompt = prompt.replace('\r', ', ') - prompt_split = prompt.split(',') - # remove empty strings - prompt_split = [p.strip() for p in prompt_split if p.strip()] - # join back together - prompt = ', '.join(prompt_split)
+ prompt = clean_caption(prompt) + if short_caption is not None: + short_caption = clean_caption(short_caption) else: prompt = '' if self.dataset_config.default_caption is not None: prompt = self.dataset_config.default_caption + + if short_caption is None: + short_caption = self.dataset_config.default_caption self.raw_caption = prompt + self.raw_caption_short = short_caption def get_caption( self: 'FileItemDTO', trigger=None, to_replace_list=None, - add_if_not_present=False + add_if_not_present=False, + short_caption=False ): - raw_caption = self.raw_caption + if short_caption: + raw_caption = self.raw_caption_short + else: + raw_caption = self.raw_caption if raw_caption is None: raw_caption = '' # handle dropout - if self.dataset_config.caption_dropout_rate > 0: + if self.dataset_config.caption_dropout_rate > 0 and not short_caption: # get a random float form 0 to 1 rand = random.random() if rand < self.dataset_config.caption_dropout_rate: @@ -296,7 +307,7 @@ class CaptionProcessingDTOMixin: random.shuffle(token_list) # handle token dropout - if self.dataset_config.token_dropout_rate > 0: + if self.dataset_config.token_dropout_rate > 0 and not short_caption: new_token_list = [] for token in token_list: # get a random float form 0 to 1 @@ -845,7 +856,8 @@ class LatentCachingMixin: self.sd.set_device_state_preset('cache_latents') # use tqdm to show progress - for i, file_item in tqdm(enumerate(self.file_list), desc=f'Caching latents{" to disk" if to_disk else ""}'): + i = 0 + for file_item in tqdm(self.file_list, desc=f'Caching latents{" to disk" if to_disk else ""}'): # set latent space version if self.sd.is_xl: file_item.latent_space_version = 'sdxl' @@ -891,6 +903,7 @@ class LatentCachingMixin: flush(garbage_collect=False) file_item.is_latent_cached = True + i += 1 # flush every 100 # if i % 100 == 0: # flush() diff --git a/toolkit/lora_special.py b/toolkit/lora_special.py index 5e599227..21d5cc5c 100644 --- 
a/toolkit/lora_special.py +++ b/toolkit/lora_special.py @@ -89,7 +89,8 @@ class LoRAModule(ToolkitModuleMixin, torch.nn.Module): torch.nn.init.zeros_(self.lora_up.weight) self.multiplier: Union[float, List[float]] = multiplier - self.org_module = org_module # remove in applying + # wrap the original module so it doesn't get weights updated + self.org_module = [org_module] self.dropout = dropout self.rank_dropout = rank_dropout self.module_dropout = module_dropout @@ -98,9 +99,9 @@ class LoRAModule(ToolkitModuleMixin, torch.nn.Module): self.normalize_scaler = 1.0 def apply_to(self): - self.org_forward = self.org_module.forward - self.org_module.forward = self.forward - del self.org_module + self.org_forward = self.org_module[0].forward + self.org_module[0].forward = self.forward + # del self.org_module class LoRASpecialNetwork(ToolkitNetworkMixin, LoRANetwork): @@ -170,6 +171,7 @@ class LoRASpecialNetwork(ToolkitNetworkMixin, LoRANetwork): self.multiplier = multiplier self.is_sdxl = is_sdxl self.is_v2 = is_v2 + self.is_merged_in = False if modules_dim is not None: print(f"create LoRA network from weights") diff --git a/toolkit/lycoris_special.py b/toolkit/lycoris_special.py index d007de6c..f38556e7 100644 --- a/toolkit/lycoris_special.py +++ b/toolkit/lycoris_special.py @@ -42,6 +42,7 @@ class LoConSpecialModule(ToolkitModuleMixin, LoConModule): self.lora_dim = lora_dim self.cp = False + self.scalar = nn.Parameter(torch.tensor(0.0)) orig_module_name = org_module.__class__.__name__ if orig_module_name in CONV_MODULES: diff --git a/toolkit/network_mixins.py b/toolkit/network_mixins.py index da59f4a0..c62de64b 100644 --- a/toolkit/network_mixins.py +++ b/toolkit/network_mixins.py @@ -103,7 +103,21 @@ class ToolkitModuleMixin: # this may get an additional positional arg or not def forward(self: Module, x, *args, **kwargs): - if not self.network_ref().is_active: + skip = False + network = self.network_ref() + # skip if not active + if not network.is_active: + skip = True 
+ + # skip if is merged in + if network.is_merged_in: + skip = True + + # skip if multiplier is 0 + if network._multiplier == 0: + skip = True + + if skip: # network is not active, avoid doing anything return self.org_forward(x, *args, **kwargs) @@ -191,6 +205,52 @@ class ToolkitModuleMixin: # reset the normalization scaler self.normalize_scaler = target_normalize_scaler + @torch.no_grad() + def merge_out(self: Module, merge_out_weight=1.0): + # make sure it is positive + merge_out_weight = abs(merge_out_weight) + # merging out is just merging in the negative of the weight + self.merge_in(merge_weight=-merge_out_weight) + + @torch.no_grad() + def merge_in(self: Module, merge_weight=1.0): + # get up/down weight + up_weight = self.lora_up.weight.clone().float() + down_weight = self.lora_down.weight.clone().float() + + # extract weight from org_module + org_sd = self.org_module[0].state_dict() + orig_dtype = org_sd["weight"].dtype + weight = org_sd["weight"].float() + + multiplier = merge_weight + scale = self.scale + # handle trainable scaler method locon does + if hasattr(self, 'scalar'): + scale = scale * self.scalar + + # merge weight + if len(weight.size()) == 2: + # linear + weight = weight + multiplier * (up_weight @ down_weight) * scale + elif down_weight.size()[2:4] == (1, 1): + # conv2d 1x1 + weight = ( + weight + + multiplier + * (up_weight.squeeze(3).squeeze(2) @ down_weight.squeeze(3).squeeze(2)).unsqueeze(2).unsqueeze(3) + * scale + ) + else: + # conv2d 3x3 + conved = torch.nn.functional.conv2d(down_weight.permute(1, 0, 2, 3), up_weight).permute(1, 0, 2, 3) + # print(conved.size(), weight.size(), module.stride, module.padding) + weight = weight + multiplier * conved * scale + + # set weight to org_module + org_sd["weight"] = weight.to(orig_dtype) + self.org_module[0].load_state_dict(org_sd) + class ToolkitNetworkMixin: def __init__( @@ -210,6 +270,7 @@ class ToolkitNetworkMixin: self._is_normalizing: bool = False self.is_sdxl = is_sdxl self.is_v2 = is_v2 
+ self.is_merged_in = False # super().__init__(*args, **kwargs) def get_keymap(self: Network): @@ -326,7 +387,6 @@ class ToolkitNetworkMixin: self.torch_multiplier = tensor_multiplier.clone().detach() - @property def multiplier(self) -> Union[float, List[float]]: return self._multiplier @@ -396,3 +456,15 @@ class ToolkitNetworkMixin: def apply_stored_normalizer(self: Network, target_normalize_scaler: float = 1.0): for module in self.get_all_modules(): module.apply_stored_normalizer(target_normalize_scaler) + + def merge_in(self, merge_weight=1.0): + self.is_merged_in = True + for module in self.get_all_modules(): + module.merge_in(merge_weight) + + def merge_out(self, merge_weight=1.0): + if not self.is_merged_in: + return + self.is_merged_in = False + for module in self.get_all_modules(): + module.merge_out(merge_weight) diff --git a/toolkit/stable_diffusion_model.py b/toolkit/stable_diffusion_model.py index dd17406d..effa014d 100644 --- a/toolkit/stable_diffusion_model.py +++ b/toolkit/stable_diffusion_model.py @@ -62,6 +62,7 @@ class BlankNetwork: self.multiplier = 1.0 self.is_active = True self.is_normalizing = False + self.is_merged_in = False def apply_stored_normalizer(self, target_normalize_scaler: float = 1.0): pass @@ -267,10 +268,18 @@ class StableDiffusion: @torch.no_grad() def generate_images(self, image_configs: List[GenerateImageConfig], sampler=None): + merge_multiplier = 1.0 # sample_folder = os.path.join(self.save_root, 'samples') if self.network is not None: self.network.eval() network = self.network + # check if we have the same network weight for all samples. 
If we do, we can merge in + # the network to drastically speed up inference + unique_network_weights = set([x.network_multiplier for x in image_configs]) + if len(unique_network_weights) == 1: + can_merge_in = True + merge_multiplier = unique_network_weights.pop() + network.merge_in(merge_weight=merge_multiplier) else: network = BlankNetwork() @@ -462,6 +471,9 @@ class StableDiffusion: self.network.train() self.network.multiplier = start_multiplier self.network.is_normalizing = was_network_normalizing + + if network.is_merged_in: + network.merge_out(merge_multiplier) # self.tokenizer.to(original_device_dict['tokenizer']) def get_latent_noise(