add_extra_torch_module_during_sampling for animatediff

This commit is contained in:
lllyasviel
2024-02-03 01:39:12 -08:00
parent 67c87e6667
commit e3fca95012


@@ -1,4 +1,5 @@
import copy
import torch
from ldm_patched.modules.model_patcher import ModelPatcher
from ldm_patched.modules.sample import convert_cond
@@ -42,6 +43,23 @@ class UnetPatcher(ModelPatcher):
        self.extra_model_patchers_during_sampling.append(model_patcher)
        return
    def add_extra_torch_module_during_sampling(self, m: torch.nn.Module, cast_to_unet_dtype: bool = True):
        # Use this method to bind an extra torch.nn.Module to this UNet during sampling.
        # The module `m` will be delegated to Forge's memory management system:
        # `m` will be loaded to the GPU every time sampling starts,
        # `m` will be unloaded if necessary, and
        # `m` will be taken into account when Forge estimates GPU memory capacity
        # and decides whether to use module offload to allow a larger batch size.
        # Set cast_to_unet_dtype=True if you want `m` to have the same dtype as the UNet during sampling.

        if cast_to_unet_dtype:
            m.to(self.model.diffusion_model.dtype)

        patcher = ModelPatcher(model=m, load_device=self.load_device, offload_device=self.offload_device)

        self.add_extra_model_patcher_during_sampling(patcher)
        return
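
For context, a minimal usage sketch of the new method from an extension's point of view. The `bind_motion_module` wrapper, the `torch.nn.Sequential` stand-in for an AnimateDiff-style motion module, and the `p.sd_model.forge_objects.unet` clone-and-write-back pattern are assumptions for illustration, not part of this commit; only `add_extra_torch_module_during_sampling` itself is defined above.

import torch

def bind_motion_module(p):
    # Hypothetical stand-in for an AnimateDiff motion module; any torch.nn.Module works.
    motion_module = torch.nn.Sequential(
        torch.nn.Linear(320, 320),
        torch.nn.SiLU(),
        torch.nn.Linear(320, 320),
    )

    # Assumed Forge pattern: clone the active UnetPatcher, modify the clone, write it back.
    unet = p.sd_model.forge_objects.unet.clone()

    # Bind the module: it is cast to the UNet dtype, wrapped in a ModelPatcher,
    # and loaded/offloaded by Forge's memory management together with the UNet.
    unet.add_extra_torch_module_during_sampling(motion_module, cast_to_unet_dtype=True)

    p.sd_model.forge_objects.unet = unet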
    def add_patched_controlnet(self, cnet):
        cnet.set_previous_controlnet(self.controlnet_linked_list)
        self.controlnet_linked_list = cnet