mirror of
https://github.com/lllyasviel/stable-diffusion-webui-forge.git
synced 2026-02-10 18:09:58 +00:00
add_extra_torch_module_during_sampling for animatediff
This commit is contained in:
@@ -1,4 +1,5 @@
|
||||
import copy
|
||||
import torch
|
||||
|
||||
from ldm_patched.modules.model_patcher import ModelPatcher
|
||||
from ldm_patched.modules.sample import convert_cond
|
||||
@@ -42,6 +43,23 @@ class UnetPatcher(ModelPatcher):
|
||||
self.extra_model_patchers_during_sampling.append(model_patcher)
|
||||
return
|
||||
|
||||
def add_extra_torch_module_during_sampling(self, m: torch.nn.Module, cast_to_unet_dtype: bool = True):
|
||||
# Use this method to bind an extra torch.nn.Module to this UNet during sampling.
|
||||
# This model `m` will be delegated to Forge memory management system.
|
||||
# `m` will be loaded to GPU everytime when sampling starts.
|
||||
# `m` will be unloaded if necessary.
|
||||
# `m` will influence Forge's judgement about use GPU memory or
|
||||
# capacity and decide whether to use module offload to make user's batch size larger.
|
||||
# Use cast_to_unet_dtype if you want `m` to have same dtype with unet during sampling.
|
||||
|
||||
if cast_to_unet_dtype:
|
||||
m.to(self.model.diffusion_model.dtype)
|
||||
|
||||
patcher = ModelPatcher(model=m, load_device=self.load_device, offload_device=self.offload_device)
|
||||
|
||||
self.add_extra_model_patcher_during_sampling(patcher)
|
||||
return
|
||||
|
||||
def add_patched_controlnet(self, cnet):
|
||||
cnet.set_previous_controlnet(self.controlnet_linked_list)
|
||||
self.controlnet_linked_list = cnet
|
||||
|
||||
Reference in New Issue
Block a user