diff --git a/comfy_extras/nodes_mask.py b/comfy_extras/nodes_mask.py
index c9b2a84d9..96ee1a0f8 100644
--- a/comfy_extras/nodes_mask.py
+++ b/comfy_extras/nodes_mask.py
@@ -40,23 +40,13 @@ def composite(destination, source, x, y, mask = None, multiplier = 8, resize_sou
 
     inverse_mask = torch.ones_like(mask) - mask
 
-    source_rgb = source[:, :3, :visible_height, :visible_width]
-    dest_slice = destination[..., top:bottom, left:right]
-
-    if destination.shape[1] == 4:
-        if torch.max(dest_slice) == 0:
-            destination[:, :3, top:bottom, left:right] = source_rgb
-            destination[:, 3:4, top:bottom, left:right] = mask
-        else:
-            destination[:, :3, top:bottom, left:right] = (mask * source_rgb) + (inverse_mask * dest_slice[:, :3])
-            destination[:, 3:4, top:bottom, left:right] = torch.max(mask, dest_slice[:, 3:4])
-    else:
-        source_portion = mask * source_rgb
-        destination_portion = inverse_mask * dest_slice
-        destination[..., top:bottom, left:right] = source_portion + destination_portion
+    source_portion = mask * source[..., :visible_height, :visible_width]
+    destination_portion = inverse_mask * destination[..., top:bottom, left:right]
+    destination[..., top:bottom, left:right] = source_portion + destination_portion
 
     return destination
 
+
 class LatentCompositeMasked(IO.ComfyNode):
     @classmethod
     def define_schema(cls):
@@ -95,23 +85,18 @@ class ImageCompositeMasked(IO.ComfyNode):
             display_name="Image Composite Masked",
             category="image",
             inputs=[
+                IO.Image.Input("destination"),
                 IO.Image.Input("source"),
                 IO.Int.Input("x", default=0, min=0, max=nodes.MAX_RESOLUTION, step=1),
                 IO.Int.Input("y", default=0, min=0, max=nodes.MAX_RESOLUTION, step=1),
                 IO.Boolean.Input("resize_source", default=False),
-                IO.Image.Input("destination", optional=True),
                 IO.Mask.Input("mask", optional=True),
             ],
             outputs=[IO.Image.Output()],
         )
 
     @classmethod
-    def execute(cls, source, x, y, resize_source, destination = None, mask = None) -> IO.NodeOutput:
-        if destination is None: # transparent rgba
-            B, H, W, C = source.shape
-            destination = torch.zeros((B, H, W, 4), dtype=source.dtype, device=source.device)
-            if C == 3:
-                source = torch.nn.functional.pad(source, (0, 1), value=1.0)
+    def execute(cls, destination, source, x, y, resize_source, mask = None) -> IO.NodeOutput:
         destination, source = node_helpers.image_alpha_fix(destination, source)
         destination = destination.clone().movedim(-1, 1)
         output = composite(destination, source.movedim(-1, 1), x, y, mask, 1, resize_source).movedim(1, -1)