mirror of
https://github.com/comfyanonymous/ComfyUI.git
synced 2026-03-01 03:14:05 +00:00
feat: Support SDPose-OOD (#12661)
This commit is contained in:
@@ -18,6 +18,8 @@ import comfy.patcher_extension
|
||||
import comfy.ops
|
||||
ops = comfy.ops.disable_weight_init
|
||||
|
||||
from ..sdpose import HeatmapHead
|
||||
|
||||
class TimestepBlock(nn.Module):
|
||||
"""
|
||||
Any module where forward() takes timestep embeddings as a second argument.
|
||||
@@ -441,6 +443,7 @@ class UNetModel(nn.Module):
|
||||
disable_temporal_crossattention=False,
|
||||
max_ddpm_temb_period=10000,
|
||||
attn_precision=None,
|
||||
heatmap_head=False,
|
||||
device=None,
|
||||
operations=ops,
|
||||
):
|
||||
@@ -827,6 +830,9 @@ class UNetModel(nn.Module):
|
||||
#nn.LogSoftmax(dim=1) # change to cross_entropy and produce non-normalized logits
|
||||
)
|
||||
|
||||
if heatmap_head:
|
||||
self.heatmap_head = HeatmapHead(device=device, dtype=self.dtype, operations=operations)
|
||||
|
||||
def forward(self, x, timesteps=None, context=None, y=None, control=None, transformer_options={}, **kwargs):
|
||||
return comfy.patcher_extension.WrapperExecutor.new_class_executor(
|
||||
self._forward,
|
||||
|
||||
Reference in New Issue
Block a user