[CuTeDSL] Add a render function hook to allow render layout natively (#3135)

* [CuTeDSL] Add a render function hook to allow render layout natively

Signed-off-by: Kaining Zhong <kainingz@nvidia.com>

* nit

Signed-off-by: Kaining Zhong <kainingz@nvidia.com>

---------

Signed-off-by: Kaining Zhong <kainingz@nvidia.com>
This commit is contained in:
Kaining Zhong
2026-06-26 14:14:55 -05:00
committed by GitHub
parent d4b4b494c3
commit 12ff513cea

View File

@@ -9,7 +9,7 @@
# and related documentation outside the scope permitted by the EULA
# is strictly prohibited.
from typing import Callable, Union
from typing import Callable, Optional, Union
from ..cute import (
Layout,
@@ -59,7 +59,10 @@ def tikz_color_tv(tid: int, vid: int) -> str:
def print_latex(
x: Union[Layout, ComposedLayout], *, color: Callable = tikz_color_bwx8
x: Union[Layout, ComposedLayout],
*,
color: Callable = tikz_color_bwx8,
render_func: Optional[Callable[[str], None]] = None,
) -> None:
"""
Prints a layout.
@@ -67,6 +70,8 @@ def print_latex(
:type x: Union[Layout, ComposedLayout]
:param color: A function that returns TiKZ colors
:type color: Callable
:param render_func: An user provided function to render the latex output, which only includes tikz picture section. If None, it will print to stdout.
:type render_func: Optional[Callable]
"""
if not is_static(x):
@@ -79,11 +84,20 @@ def print_latex(
else:
layout = x
print("%% Layout: {}", layout)
print("\\documentclass[convert]{standalone}")
print("\\usepackage{tikz}")
print("\\begin{document}")
print(
latex_output = []
def print_or_append(*args):
if render_func is not None:
latex_output.append(" ".join(str(arg) for arg in args))
else:
print(*args)
if render_func is None:
print_or_append("%% Layout: {}", layout)
print_or_append("\\documentclass[convert]{standalone}")
print_or_append("\\usepackage{tikz}")
print_or_append("\\begin{document}")
print_or_append(
"\\begin{tikzpicture}[x={(0cm,-1cm)},y={(1cm,0cm)},every node/.style={minimum size=1cm, outer sep=0pt}]"
)
@@ -92,20 +106,24 @@ def print_latex(
for m in range(M):
for n in range(N):
idx = layout((m, n))
print("\\node[fill=")
print(color(idx))
print("] at (%d,%d) {%d};\n" % (m, n, idx))
print(
print_or_append("\\node[fill=")
print_or_append(color(idx))
print_or_append("] at (%d,%d) {%d};\n" % (m, n, idx))
print_or_append(
"\\draw[color=black,thick,shift={(-0.5,-0.5)}] (0,0) grid (%d,%d);\n\n" % (M, N)
)
for m in range(M):
print("\\node at (%d,%d) {\\Large{\\texttt{%d}}};\n" % (m, -1, m))
print_or_append("\\node at (%d,%d) {\\Large{\\texttt{%d}}};\n" % (m, -1, m))
for n in range(N):
print("\\node at (%d,%d) {\\Large{\\texttt{%d}}};\n" % (-1, n, n))
print_or_append("\\node at (%d,%d) {\\Large{\\texttt{%d}}};\n" % (-1, n, n))
## Footer
print("\\end{tikzpicture}")
print("\\end{document}")
print_or_append("\\end{tikzpicture}")
if render_func is None:
print_or_append("\\end{document}")
if render_func is not None:
render_func(" ".join(latex_output))
def print_latex_tv(
@@ -113,6 +131,7 @@ def print_latex_tv(
tile_mn: Union[IntTuple, Layout],
*,
color: Callable = tikz_color_tv,
render_func: Optional[Callable[[str], None]] = None,
) -> None:
"""
Prints a tv layout for a tile M N. Everything must be static.
@@ -122,17 +141,28 @@ def print_latex_tv(
:type tile_mn: Union[IntTuple, Layout]
:param color: A function that returns TiKZ colors
:type color: Callable
:param render_func: An user provided function to render the latex output, which only includes tikz picture section. If None, it will print to stdout.
:type render_func: Optional[Callable]
"""
if not is_static(layout_tv) or not is_static(tile_mn):
raise ValueError("Layout tv and tile_mn must be static")
if rank(layout_tv) != 2:
raise ValueError("Require layout_tv to be rank 2")
print("%% Layout TV: {}", layout_tv)
print("\\documentclass[convert]{standalone}")
print("\\usepackage{tikz}")
print("\\begin{document}")
print(
latex_output = []
def print_or_append(*args):
if render_func is not None:
latex_output.append(" ".join(str(arg) for arg in args))
else:
print(*args)
if render_func is None:
print_or_append("%% Layout TV: {}", layout_tv)
print_or_append("\\documentclass[convert]{standalone}")
print_or_append("\\usepackage{tikz}")
print_or_append("\\begin{document}")
print_or_append(
"\\begin{tikzpicture}[x={(0cm,-1cm)},y={(1cm,0cm)},every node/.style={minimum size=1cm, outer sep=0pt}]\n"
)
@@ -149,19 +179,23 @@ def print_latex_tv(
n = (idx // tile_mn.stride[1]) % tile_mn.shape[1] # type: ignore[operator, union-attr, index]
if not filled[m][n]:
filled[m][n] = True
print(
print_or_append(
"\\node[fill=%s] at (%d,%d) {\\shortstack{T%d \\\\ V%d}};\n"
% (color(tid, vid), m, n, tid, vid)
)
print(
print_or_append(
"\\draw[color=black,thick,shift={(-0.5,-0.5)}] (0,0) grid (%d,%d);\n\n" % (M, N)
)
for m in range(M):
print("\\node at (%d,%d) {\\Large{\\texttt{%d}}};\n" % (m, -1, m))
print_or_append("\\node at (%d,%d) {\\Large{\\texttt{%d}}};\n" % (m, -1, m))
for n in range(N):
print("\\node at (%d,%d) {\\Large{\\texttt{%d}}};\n" % (-1, n, n))
print_or_append("\\node at (%d,%d) {\\Large{\\texttt{%d}}};\n" % (-1, n, n))
## Footer
print("\\end{tikzpicture}")
print("\\end{document}")
print_or_append("\\end{tikzpicture}")
if render_func is None:
print_or_append("\\end{document}")
if render_func is not None:
render_func(" ".join(latex_output))