Support multiple outputs

Author: pythongosssss
Date: 2026-01-24 12:55:06 -08:00
Parent: 521ca3b5d2
Commit: 5b0fb64d20

@@ -20,8 +20,9 @@ class SizeModeInput(TypedDict):
     height: int

 MAX_IMAGES = 5 # u_image0-4
 MAX_UNIFORMS = 5 # u_float0-4, u_int0-4
+MAX_OUTPUTS = 4 # fragColor0-3 (MRT)

 logger = logging.getLogger(__name__)
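
MAX_OUTPUTS = 4 lines up with the minimum number of color attachments OpenGL ES 3.0 guarantees. As a quick sanity check on a given machine, the driver's actual limit can be queried from the moderngl context; a minimal sketch, assuming the limit is reported in ctx.info (the lookup is guarded in case that key is absent):

    import moderngl

    # Sketch only: ask the driver how many color attachments it supports.
    # GLES 3.0 guarantees at least 4, which matches MAX_OUTPUTS above.
    ctx = moderngl.create_standalone_context()
    print(ctx.info.get("GL_MAX_COLOR_ATTACHMENTS", "not reported"))
    ctx.release()
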
@@ -31,6 +32,7 @@ except ImportError as e:
     raise RuntimeError(f"ModernGL is not available.\n{get_missing_requirements_message()}") from e

 # Default NOOP fragment shader that passes through the input image unchanged
+# For multiple outputs, use: layout(location = 0) out vec4 fragColor0; etc.
 DEFAULT_FRAGMENT_SHADER = """#version 300 es
 precision highp float;
@@ -38,10 +40,10 @@ uniform sampler2D u_image0;
 uniform vec2 u_resolution;

 in vec2 v_texcoord;

-out vec4 fragColor;
+layout(location = 0) out vec4 fragColor0;

 void main() {
-    fragColor = texture(u_image0, v_texcoord);
+    fragColor0 = texture(u_image0, v_texcoord);
 }
 """
@@ -130,10 +132,10 @@ def _image_to_texture(ctx: moderngl.Context, image: np.ndarray) -> moderngl.Text
     return texture


-def _texture_to_image(fbo: moderngl.Framebuffer, channels: int = 4) -> np.ndarray:
+def _texture_to_image(fbo: moderngl.Framebuffer, attachment: int = 0, channels: int = 4) -> np.ndarray:
     width, height = fbo.size
-    data = fbo.read(components=channels)
+    data = fbo.read(components=channels, attachment=attachment)
     image = np.frombuffer(data, dtype=np.uint8).reshape((height, width, channels))
     image = np.ascontiguousarray(np.flipud(image))
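
moderngl's Framebuffer.read accepts an attachment index, so every render target comes back through the same path. A standalone sketch of the conversion this helper performs (the function name is illustrative):

    import moderngl
    import numpy as np

    def read_attachment(fbo: moderngl.Framebuffer, attachment: int = 0) -> np.ndarray:
        # Pull raw RGBA8 bytes from the requested color attachment.
        width, height = fbo.size
        data = fbo.read(components=4, attachment=attachment)
        image = np.frombuffer(data, dtype=np.uint8).reshape((height, width, 4))
        # GL's origin is bottom-left; flip to the usual top-left image layout.
        return np.ascontiguousarray(np.flipud(image))
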
@@ -170,11 +172,15 @@ def _render_shader(
     height: int,
     textures: list[moderngl.Texture],
     uniforms: dict[str, int | float],
-) -> np.ndarray:
-    # Create output texture and framebuffer
-    output_texture = ctx.texture((width, height), 4)
-    output_texture.filter = (moderngl.LINEAR, moderngl.LINEAR)
-    fbo = ctx.framebuffer(color_attachments=[output_texture])
+) -> list[np.ndarray]:
+    # Create output textures
+    output_textures = []
+    for _ in range(MAX_OUTPUTS):
+        tex = ctx.texture((width, height), 4)
+        tex.filter = (moderngl.LINEAR, moderngl.LINEAR)
+        output_textures.append(tex)
+    fbo = ctx.framebuffer(color_attachments=output_textures)

     # Full-screen quad vertices (position + texcoord)
     vertices = np.array([
@@ -212,12 +218,16 @@ def _render_shader(
         fbo.clear(0.0, 0.0, 0.0, 1.0)
         vao.render(moderngl.TRIANGLE_STRIP)

-        # Read result
-        return _texture_to_image(fbo, channels=4)
+        # Read results from all attachments
+        results = []
+        for i in range(MAX_OUTPUTS):
+            results.append(_texture_to_image(fbo, attachment=i, channels=4))
+        return results
     finally:
         vao.release()
         vbo.release()
-        output_texture.release()
+        for tex in output_textures:
+            tex.release()
         fbo.release()
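
Putting the pieces together, the MRT plumbing is one texture per output attached to a single framebuffer, rendered once and read back by index. A compact standalone sketch, assuming a headless context can be created (the actual draw call is elided):

    import moderngl

    MAX_OUTPUTS = 4
    SIZE = (64, 64)

    ctx = moderngl.create_standalone_context()
    targets = [ctx.texture(SIZE, 4) for _ in range(MAX_OUTPUTS)]
    fbo = ctx.framebuffer(color_attachments=targets)
    try:
        fbo.use()
        fbo.clear(0.0, 0.0, 0.0, 1.0)
        # ... bind the shader program and render the full-screen quad here ...
        frames = [fbo.read(components=4, attachment=i) for i in range(MAX_OUTPUTS)]
        print(len(frames), len(frames[0]))  # 4 buffers, each 64 * 64 * 4 bytes
    finally:
        for tex in targets:
            tex.release()
        fbo.release()
        ctx.release()
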
@@ -311,8 +321,9 @@ class GLSLShader(io.ComfyNode):
             category="image/shader",
             description=(
                 f"Apply GLSL fragment shaders to images. "
-                f"Uniforms: u_image0-{MAX_IMAGES-1} (sampler2D), u_resolution (vec2), "
-                f"u_float0-{MAX_UNIFORMS-1}, u_int0-{MAX_UNIFORMS-1}."
+                f"Inputs: u_image0-{MAX_IMAGES-1} (sampler2D), u_resolution (vec2), "
+                f"u_float0-{MAX_UNIFORMS-1}, u_int0-{MAX_UNIFORMS-1}. "
+                f"Outputs: layout(location = 0-{MAX_OUTPUTS-1}) out vec4 fragColor0-{MAX_OUTPUTS-1}."
             ),
             inputs=[
                 io.String.Input(
@@ -343,7 +354,10 @@ class GLSLShader(io.ComfyNode):
                 io.Autogrow.Input("ints", template=int_template),
             ],
             outputs=[
-                io.Image.Output(display_name="IMAGE"),
+                io.Image.Output(display_name="IMAGE0"),
+                io.Image.Output(display_name="IMAGE1"),
+                io.Image.Output(display_name="IMAGE2"),
+                io.Image.Output(display_name="IMAGE3"),
             ],
         )
@@ -375,17 +389,22 @@ class GLSLShader(io.ComfyNode):
         with _gl_context(force_software=args.cpu) as ctx:
             with _shader_program(ctx, fragment_shader) as program:
-                output_images = []
+                # Collect outputs for each render target across all batches
+                all_outputs: list[list[torch.Tensor]] = [[] for _ in range(MAX_OUTPUTS)]

                 for b in range(batch_size):
                     with _textures_context(ctx, image_list, b) as textures:
-                        result = _render_shader(ctx, program, out_width, out_height, textures, uniforms)
-                        output_images.append(torch.from_numpy(result))
+                        results = _render_shader(ctx, program, out_width, out_height, textures, uniforms)
+                        for i, result in enumerate(results):
+                            all_outputs[i].append(torch.from_numpy(result))

-                output_batch = torch.stack(output_images, dim=0)
-                if output_batch.shape[-1] == 4:
-                    output_batch = output_batch[:, :, :, :3]
+                # Stack batches for each output
+                output_values = []
+                for i in range(MAX_OUTPUTS):
+                    output_batch = torch.stack(all_outputs[i], dim=0)
+                    output_values.append(output_batch)

-        return io.NodeOutput(output_batch, ui=cls._build_ui_output(image_list, output_batch))
+        return io.NodeOutput(*output_values, ui=cls._build_ui_output(image_list, output_values[0]))

     @classmethod
     def _build_ui_output(cls, image_list: list[torch.Tensor], output_batch: torch.Tensor) -> dict[str, list]: