diff --git a/modules_forge/unet_patcher.py b/modules_forge/unet_patcher.py
index 2e455859..ad52b093 100644
--- a/modules_forge/unet_patcher.py
+++ b/modules_forge/unet_patcher.py
@@ -1,4 +1,5 @@
 import copy
+import torch
 
 from ldm_patched.modules.model_patcher import ModelPatcher
 from ldm_patched.modules.sample import convert_cond
@@ -42,6 +43,23 @@ class UnetPatcher(ModelPatcher):
         self.extra_model_patchers_during_sampling.append(model_patcher)
         return
 
+    def add_extra_torch_module_during_sampling(self, m: torch.nn.Module, cast_to_unet_dtype: bool = True):
+        # Use this method to bind an extra torch.nn.Module to this UNet during sampling.
+        # The module `m` is delegated to Forge's memory management system:
+        # it will be loaded onto the GPU every time sampling starts and
+        # unloaded again when necessary.
+        # `m` also factors into Forge's estimate of GPU memory usage and capacity,
+        # which decides whether module offload is used to allow a larger batch size.
+        # Set cast_to_unet_dtype=True if `m` should use the same dtype as the UNet during sampling.
+
+        if cast_to_unet_dtype:
+            m.to(self.model.diffusion_model.dtype)
+
+        patcher = ModelPatcher(model=m, load_device=self.load_device, offload_device=self.offload_device)
+
+        self.add_extra_model_patcher_during_sampling(patcher)
+        return
+
     def add_patched_controlnet(self, cnet):
         cnet.set_previous_controlnet(self.controlnet_linked_list)
         self.controlnet_linked_list = cnet
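
For context, here is a minimal sketch of how an extension might call the new method, assuming it already holds a `UnetPatcher` (for example a clone of `p.sd_model.forge_objects.unet` inside a Forge extension). The `FeatureAdapter` module, its channel size, and the `bind_adapter` helper are hypothetical and only illustrate the call.

```python
import torch


# Hypothetical auxiliary module that should live alongside the UNet during sampling.
class FeatureAdapter(torch.nn.Module):
    def __init__(self, channels: int = 320):
        super().__init__()
        self.proj = torch.nn.Linear(channels, channels)

    def forward(self, x):
        return self.proj(x)


def bind_adapter(unet_patcher):
    # `unet_patcher` is assumed to be a UnetPatcher, e.g. a clone of
    # p.sd_model.forge_objects.unet obtained inside a Forge extension.
    adapter = FeatureAdapter()

    # Hand the module over to Forge's memory management: it is cast to the
    # UNet's dtype, moved to the GPU when sampling starts, offloaded when
    # needed, and counted when Forge estimates available VRAM.
    unet_patcher.add_extra_torch_module_during_sampling(adapter, cast_to_unet_dtype=True)
    return adapter
```

The adapter is wrapped in its own `ModelPatcher` internally, so the extension does not manage device placement itself; it only keeps a reference to the module for use in its own hooks.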