mirror of
https://github.com/NVIDIA/cutlass.git
synced 2026-06-28 18:37:05 +00:00
[CuTeDSL] Add a render function hook to allow render layout natively (#3135)
* [CuTeDSL] Add a render function hook to allow render layout natively Signed-off-by: Kaining Zhong <kainingz@nvidia.com> * nit Signed-off-by: Kaining Zhong <kainingz@nvidia.com> --------- Signed-off-by: Kaining Zhong <kainingz@nvidia.com>
This commit is contained in:
@@ -9,7 +9,7 @@
|
||||
# and related documentation outside the scope permitted by the EULA
|
||||
# is strictly prohibited.
|
||||
|
||||
from typing import Callable, Union
|
||||
from typing import Callable, Optional, Union
|
||||
|
||||
from ..cute import (
|
||||
Layout,
|
||||
@@ -59,7 +59,10 @@ def tikz_color_tv(tid: int, vid: int) -> str:
|
||||
|
||||
|
||||
def print_latex(
|
||||
x: Union[Layout, ComposedLayout], *, color: Callable = tikz_color_bwx8
|
||||
x: Union[Layout, ComposedLayout],
|
||||
*,
|
||||
color: Callable = tikz_color_bwx8,
|
||||
render_func: Optional[Callable[[str], None]] = None,
|
||||
) -> None:
|
||||
"""
|
||||
Prints a layout.
|
||||
@@ -67,6 +70,8 @@ def print_latex(
|
||||
:type x: Union[Layout, ComposedLayout]
|
||||
:param color: A function that returns TiKZ colors
|
||||
:type color: Callable
|
||||
:param render_func: An user provided function to render the latex output, which only includes tikz picture section. If None, it will print to stdout.
|
||||
:type render_func: Optional[Callable]
|
||||
"""
|
||||
|
||||
if not is_static(x):
|
||||
@@ -79,11 +84,20 @@ def print_latex(
|
||||
else:
|
||||
layout = x
|
||||
|
||||
print("%% Layout: {}", layout)
|
||||
print("\\documentclass[convert]{standalone}")
|
||||
print("\\usepackage{tikz}")
|
||||
print("\\begin{document}")
|
||||
print(
|
||||
latex_output = []
|
||||
|
||||
def print_or_append(*args):
|
||||
if render_func is not None:
|
||||
latex_output.append(" ".join(str(arg) for arg in args))
|
||||
else:
|
||||
print(*args)
|
||||
|
||||
if render_func is None:
|
||||
print_or_append("%% Layout: {}", layout)
|
||||
print_or_append("\\documentclass[convert]{standalone}")
|
||||
print_or_append("\\usepackage{tikz}")
|
||||
print_or_append("\\begin{document}")
|
||||
print_or_append(
|
||||
"\\begin{tikzpicture}[x={(0cm,-1cm)},y={(1cm,0cm)},every node/.style={minimum size=1cm, outer sep=0pt}]"
|
||||
)
|
||||
|
||||
@@ -92,20 +106,24 @@ def print_latex(
|
||||
for m in range(M):
|
||||
for n in range(N):
|
||||
idx = layout((m, n))
|
||||
print("\\node[fill=")
|
||||
print(color(idx))
|
||||
print("] at (%d,%d) {%d};\n" % (m, n, idx))
|
||||
print(
|
||||
print_or_append("\\node[fill=")
|
||||
print_or_append(color(idx))
|
||||
print_or_append("] at (%d,%d) {%d};\n" % (m, n, idx))
|
||||
print_or_append(
|
||||
"\\draw[color=black,thick,shift={(-0.5,-0.5)}] (0,0) grid (%d,%d);\n\n" % (M, N)
|
||||
)
|
||||
for m in range(M):
|
||||
print("\\node at (%d,%d) {\\Large{\\texttt{%d}}};\n" % (m, -1, m))
|
||||
print_or_append("\\node at (%d,%d) {\\Large{\\texttt{%d}}};\n" % (m, -1, m))
|
||||
for n in range(N):
|
||||
print("\\node at (%d,%d) {\\Large{\\texttt{%d}}};\n" % (-1, n, n))
|
||||
print_or_append("\\node at (%d,%d) {\\Large{\\texttt{%d}}};\n" % (-1, n, n))
|
||||
|
||||
## Footer
|
||||
print("\\end{tikzpicture}")
|
||||
print("\\end{document}")
|
||||
print_or_append("\\end{tikzpicture}")
|
||||
if render_func is None:
|
||||
print_or_append("\\end{document}")
|
||||
|
||||
if render_func is not None:
|
||||
render_func(" ".join(latex_output))
|
||||
|
||||
|
||||
def print_latex_tv(
|
||||
@@ -113,6 +131,7 @@ def print_latex_tv(
|
||||
tile_mn: Union[IntTuple, Layout],
|
||||
*,
|
||||
color: Callable = tikz_color_tv,
|
||||
render_func: Optional[Callable[[str], None]] = None,
|
||||
) -> None:
|
||||
"""
|
||||
Prints a tv layout for a tile M N. Everything must be static.
|
||||
@@ -122,17 +141,28 @@ def print_latex_tv(
|
||||
:type tile_mn: Union[IntTuple, Layout]
|
||||
:param color: A function that returns TiKZ colors
|
||||
:type color: Callable
|
||||
:param render_func: An user provided function to render the latex output, which only includes tikz picture section. If None, it will print to stdout.
|
||||
:type render_func: Optional[Callable]
|
||||
"""
|
||||
if not is_static(layout_tv) or not is_static(tile_mn):
|
||||
raise ValueError("Layout tv and tile_mn must be static")
|
||||
if rank(layout_tv) != 2:
|
||||
raise ValueError("Require layout_tv to be rank 2")
|
||||
|
||||
print("%% Layout TV: {}", layout_tv)
|
||||
print("\\documentclass[convert]{standalone}")
|
||||
print("\\usepackage{tikz}")
|
||||
print("\\begin{document}")
|
||||
print(
|
||||
latex_output = []
|
||||
|
||||
def print_or_append(*args):
|
||||
if render_func is not None:
|
||||
latex_output.append(" ".join(str(arg) for arg in args))
|
||||
else:
|
||||
print(*args)
|
||||
|
||||
if render_func is None:
|
||||
print_or_append("%% Layout TV: {}", layout_tv)
|
||||
print_or_append("\\documentclass[convert]{standalone}")
|
||||
print_or_append("\\usepackage{tikz}")
|
||||
print_or_append("\\begin{document}")
|
||||
print_or_append(
|
||||
"\\begin{tikzpicture}[x={(0cm,-1cm)},y={(1cm,0cm)},every node/.style={minimum size=1cm, outer sep=0pt}]\n"
|
||||
)
|
||||
|
||||
@@ -149,19 +179,23 @@ def print_latex_tv(
|
||||
n = (idx // tile_mn.stride[1]) % tile_mn.shape[1] # type: ignore[operator, union-attr, index]
|
||||
if not filled[m][n]:
|
||||
filled[m][n] = True
|
||||
print(
|
||||
print_or_append(
|
||||
"\\node[fill=%s] at (%d,%d) {\\shortstack{T%d \\\\ V%d}};\n"
|
||||
% (color(tid, vid), m, n, tid, vid)
|
||||
)
|
||||
|
||||
print(
|
||||
print_or_append(
|
||||
"\\draw[color=black,thick,shift={(-0.5,-0.5)}] (0,0) grid (%d,%d);\n\n" % (M, N)
|
||||
)
|
||||
for m in range(M):
|
||||
print("\\node at (%d,%d) {\\Large{\\texttt{%d}}};\n" % (m, -1, m))
|
||||
print_or_append("\\node at (%d,%d) {\\Large{\\texttt{%d}}};\n" % (m, -1, m))
|
||||
for n in range(N):
|
||||
print("\\node at (%d,%d) {\\Large{\\texttt{%d}}};\n" % (-1, n, n))
|
||||
print_or_append("\\node at (%d,%d) {\\Large{\\texttt{%d}}};\n" % (-1, n, n))
|
||||
|
||||
## Footer
|
||||
print("\\end{tikzpicture}")
|
||||
print("\\end{document}")
|
||||
print_or_append("\\end{tikzpicture}")
|
||||
if render_func is None:
|
||||
print_or_append("\\end{document}")
|
||||
|
||||
if render_func is not None:
|
||||
render_func(" ".join(latex_output))
|
||||
|
||||
Reference in New Issue
Block a user