mirror of
https://github.com/comfyanonymous/ComfyUI.git
synced 2026-02-18 06:00:03 +00:00
Compare commits
4 Commits
| Author | SHA1 | Date | |
|---|---|---|---|
|
|
83dd65f23a | ||
|
|
8ad38d2073 | ||
|
|
6c14f129af | ||
|
|
58dcc97dcf |
17
comfy/ops.py
17
comfy/ops.py
@@ -79,7 +79,7 @@ def cast_to_input(weight, input, non_blocking=False, copy=True):
|
||||
return comfy.model_management.cast_to(weight, input.dtype, input.device, non_blocking=non_blocking, copy=copy)
|
||||
|
||||
|
||||
def cast_bias_weight_with_vbar(s, dtype, device, bias_dtype, non_blocking, compute_dtype):
|
||||
def cast_bias_weight_with_vbar(s, dtype, device, bias_dtype, non_blocking, compute_dtype, want_requant):
|
||||
offload_stream = None
|
||||
xfer_dest = None
|
||||
|
||||
@@ -170,10 +170,10 @@ def cast_bias_weight_with_vbar(s, dtype, device, bias_dtype, non_blocking, compu
|
||||
#FIXME: this is not accurate, we need to be sensitive to the compute dtype
|
||||
x = lowvram_fn(x)
|
||||
if (isinstance(orig, QuantizedTensor) and
|
||||
(orig.dtype == dtype and len(fns) == 0 or update_weight)):
|
||||
(want_requant and len(fns) == 0 or update_weight)):
|
||||
seed = comfy.utils.string_to_seed(s.seed_key)
|
||||
y = QuantizedTensor.from_float(x, s.layout_type, scale="recalculate", stochastic_rounding=seed)
|
||||
if orig.dtype == dtype and len(fns) == 0:
|
||||
if want_requant and len(fns) == 0:
|
||||
#The layer actually wants our freshly saved QT
|
||||
x = y
|
||||
elif update_weight:
|
||||
@@ -194,7 +194,7 @@ def cast_bias_weight_with_vbar(s, dtype, device, bias_dtype, non_blocking, compu
|
||||
return weight, bias, (offload_stream, device if signature is not None else None, None)
|
||||
|
||||
|
||||
def cast_bias_weight(s, input=None, dtype=None, device=None, bias_dtype=None, offloadable=False, compute_dtype=None):
|
||||
def cast_bias_weight(s, input=None, dtype=None, device=None, bias_dtype=None, offloadable=False, compute_dtype=None, want_requant=False):
|
||||
# NOTE: offloadable=False is a a legacy and if you are a custom node author reading this please pass
|
||||
# offloadable=True and call uncast_bias_weight() after your last usage of the weight/bias. This
|
||||
# will add async-offload support to your cast and improve performance.
|
||||
@@ -212,7 +212,7 @@ def cast_bias_weight(s, input=None, dtype=None, device=None, bias_dtype=None, of
|
||||
non_blocking = comfy.model_management.device_supports_non_blocking(device)
|
||||
|
||||
if hasattr(s, "_v"):
|
||||
return cast_bias_weight_with_vbar(s, dtype, device, bias_dtype, non_blocking, compute_dtype)
|
||||
return cast_bias_weight_with_vbar(s, dtype, device, bias_dtype, non_blocking, compute_dtype, want_requant)
|
||||
|
||||
if offloadable and (device != s.weight.device or
|
||||
(s.bias is not None and device != s.bias.device)):
|
||||
@@ -850,8 +850,8 @@ def mixed_precision_ops(quant_config={}, compute_dtype=torch.bfloat16, full_prec
|
||||
def _forward(self, input, weight, bias):
|
||||
return torch.nn.functional.linear(input, weight, bias)
|
||||
|
||||
def forward_comfy_cast_weights(self, input, compute_dtype=None):
|
||||
weight, bias, offload_stream = cast_bias_weight(self, input, offloadable=True, compute_dtype=compute_dtype)
|
||||
def forward_comfy_cast_weights(self, input, compute_dtype=None, want_requant=False):
|
||||
weight, bias, offload_stream = cast_bias_weight(self, input, offloadable=True, compute_dtype=compute_dtype, want_requant=want_requant)
|
||||
x = self._forward(input, weight, bias)
|
||||
uncast_bias_weight(self, weight, bias, offload_stream)
|
||||
return x
|
||||
@@ -881,8 +881,7 @@ def mixed_precision_ops(quant_config={}, compute_dtype=torch.bfloat16, full_prec
|
||||
scale = comfy.model_management.cast_to_device(scale, input.device, None)
|
||||
input = QuantizedTensor.from_float(input_reshaped, self.layout_type, scale=scale)
|
||||
|
||||
|
||||
output = self.forward_comfy_cast_weights(input, compute_dtype)
|
||||
output = self.forward_comfy_cast_weights(input, compute_dtype, want_requant=isinstance(input, QuantizedTensor))
|
||||
|
||||
# Reshape output back to 3D if input was 3D
|
||||
if reshaped_3d:
|
||||
|
||||
@@ -1209,6 +1209,30 @@ class Color(ComfyTypeIO):
|
||||
def as_dict(self):
|
||||
return super().as_dict()
|
||||
|
||||
@comfytype(io_type="BOUNDING_BOX")
|
||||
class BoundingBox(ComfyTypeIO):
|
||||
class BoundingBoxDict(TypedDict):
|
||||
x: int
|
||||
y: int
|
||||
width: int
|
||||
height: int
|
||||
Type = BoundingBoxDict
|
||||
|
||||
class Input(WidgetInput):
|
||||
def __init__(self, id: str, display_name: str=None, optional=False, tooltip: str=None,
|
||||
socketless: bool=True, default: dict=None, component: str=None):
|
||||
super().__init__(id, display_name, optional, tooltip, None, default, socketless)
|
||||
self.component = component
|
||||
if default is None:
|
||||
self.default = {"x": 0, "y": 0, "width": 512, "height": 512}
|
||||
|
||||
def as_dict(self):
|
||||
d = super().as_dict()
|
||||
if self.component:
|
||||
d["component"] = self.component
|
||||
return d
|
||||
|
||||
|
||||
DYNAMIC_INPUT_LOOKUP: dict[str, Callable[[dict[str, Any], dict[str, Any], tuple[str, dict[str, Any]], str, list[str] | None], None]] = {}
|
||||
def register_dynamic_input_func(io_type: str, func: Callable[[dict[str, Any], dict[str, Any], tuple[str, dict[str, Any]], str, list[str] | None], None]):
|
||||
DYNAMIC_INPUT_LOOKUP[io_type] = func
|
||||
@@ -2190,5 +2214,6 @@ __all__ = [
|
||||
"ImageCompare",
|
||||
"PriceBadgeDepends",
|
||||
"PriceBadge",
|
||||
"BoundingBox",
|
||||
"NodeReplace",
|
||||
]
|
||||
|
||||
@@ -6,6 +6,7 @@ See: https://cloud.google.com/vertex-ai/generative-ai/docs/model-reference/infer
|
||||
import base64
|
||||
import os
|
||||
from enum import Enum
|
||||
from fnmatch import fnmatch
|
||||
from io import BytesIO
|
||||
from typing import Literal
|
||||
|
||||
@@ -119,6 +120,13 @@ async def create_image_parts(
|
||||
return image_parts
|
||||
|
||||
|
||||
def _mime_matches(mime: GeminiMimeType | None, pattern: str) -> bool:
|
||||
"""Check if a MIME type matches a pattern. Supports fnmatch globs (e.g. 'image/*')."""
|
||||
if mime is None:
|
||||
return False
|
||||
return fnmatch(mime.value, pattern)
|
||||
|
||||
|
||||
def get_parts_by_type(response: GeminiGenerateContentResponse, part_type: Literal["text"] | str) -> list[GeminiPart]:
|
||||
"""
|
||||
Filter response parts by their type.
|
||||
@@ -151,9 +159,9 @@ def get_parts_by_type(response: GeminiGenerateContentResponse, part_type: Litera
|
||||
for part in candidate.content.parts:
|
||||
if part_type == "text" and part.text:
|
||||
parts.append(part)
|
||||
elif part.inlineData and part.inlineData.mimeType == part_type:
|
||||
elif part.inlineData and _mime_matches(part.inlineData.mimeType, part_type):
|
||||
parts.append(part)
|
||||
elif part.fileData and part.fileData.mimeType == part_type:
|
||||
elif part.fileData and _mime_matches(part.fileData.mimeType, part_type):
|
||||
parts.append(part)
|
||||
|
||||
if not parts and blocked_reasons:
|
||||
@@ -178,7 +186,7 @@ def get_text_from_response(response: GeminiGenerateContentResponse) -> str:
|
||||
|
||||
async def get_image_from_response(response: GeminiGenerateContentResponse) -> Input.Image:
|
||||
image_tensors: list[Input.Image] = []
|
||||
parts = get_parts_by_type(response, "image/png")
|
||||
parts = get_parts_by_type(response, "image/*")
|
||||
for part in parts:
|
||||
if part.inlineData:
|
||||
image_data = base64.b64decode(part.inlineData.data)
|
||||
|
||||
@@ -23,8 +23,9 @@ class ImageCrop(IO.ComfyNode):
|
||||
return IO.Schema(
|
||||
node_id="ImageCrop",
|
||||
search_aliases=["trim"],
|
||||
display_name="Image Crop",
|
||||
display_name="Image Crop (Deprecated)",
|
||||
category="image/transform",
|
||||
is_deprecated=True,
|
||||
inputs=[
|
||||
IO.Image.Input("image"),
|
||||
IO.Int.Input("width", default=512, min=1, max=nodes.MAX_RESOLUTION, step=1),
|
||||
@@ -47,6 +48,57 @@ class ImageCrop(IO.ComfyNode):
|
||||
crop = execute # TODO: remove
|
||||
|
||||
|
||||
class ImageCropV2(IO.ComfyNode):
|
||||
@classmethod
|
||||
def define_schema(cls):
|
||||
return IO.Schema(
|
||||
node_id="ImageCropV2",
|
||||
search_aliases=["trim"],
|
||||
display_name="Image Crop",
|
||||
category="image/transform",
|
||||
inputs=[
|
||||
IO.Image.Input("image"),
|
||||
IO.BoundingBox.Input("crop_region", component="ImageCrop"),
|
||||
],
|
||||
outputs=[IO.Image.Output()],
|
||||
)
|
||||
|
||||
@classmethod
|
||||
def execute(cls, image, crop_region) -> IO.NodeOutput:
|
||||
x = crop_region.get("x", 0)
|
||||
y = crop_region.get("y", 0)
|
||||
width = crop_region.get("width", 512)
|
||||
height = crop_region.get("height", 512)
|
||||
|
||||
x = min(x, image.shape[2] - 1)
|
||||
y = min(y, image.shape[1] - 1)
|
||||
to_x = width + x
|
||||
to_y = height + y
|
||||
img = image[:,y:to_y, x:to_x, :]
|
||||
return IO.NodeOutput(img, ui=UI.PreviewImage(img))
|
||||
|
||||
|
||||
class BoundingBox(IO.ComfyNode):
|
||||
@classmethod
|
||||
def define_schema(cls):
|
||||
return IO.Schema(
|
||||
node_id="PrimitiveBoundingBox",
|
||||
display_name="Bounding Box",
|
||||
category="utils/primitive",
|
||||
inputs=[
|
||||
IO.Int.Input("x", default=0, min=0, max=MAX_RESOLUTION),
|
||||
IO.Int.Input("y", default=0, min=0, max=MAX_RESOLUTION),
|
||||
IO.Int.Input("width", default=512, min=1, max=MAX_RESOLUTION),
|
||||
IO.Int.Input("height", default=512, min=1, max=MAX_RESOLUTION),
|
||||
],
|
||||
outputs=[IO.BoundingBox.Output()],
|
||||
)
|
||||
|
||||
@classmethod
|
||||
def execute(cls, x, y, width, height) -> IO.NodeOutput:
|
||||
return IO.NodeOutput({"x": x, "y": y, "width": width, "height": height})
|
||||
|
||||
|
||||
class RepeatImageBatch(IO.ComfyNode):
|
||||
@classmethod
|
||||
def define_schema(cls):
|
||||
@@ -632,6 +684,8 @@ class ImagesExtension(ComfyExtension):
|
||||
async def get_node_list(self) -> list[type[IO.ComfyNode]]:
|
||||
return [
|
||||
ImageCrop,
|
||||
ImageCropV2,
|
||||
BoundingBox,
|
||||
RepeatImageBatch,
|
||||
ImageFromBatch,
|
||||
ImageAddNoise,
|
||||
|
||||
@@ -1,4 +1,4 @@
|
||||
comfyui-frontend-package==1.38.14
|
||||
comfyui-frontend-package==1.39.14
|
||||
comfyui-workflow-templates==0.8.43
|
||||
comfyui-embedded-docs==0.4.1
|
||||
torch
|
||||
|
||||
Reference in New Issue
Block a user