diff --git a/python/CuTeDSL/cutlass/utils/print_latex.py b/python/CuTeDSL/cutlass/utils/print_latex.py index c32fd8caf..ac87201ed 100644 --- a/python/CuTeDSL/cutlass/utils/print_latex.py +++ b/python/CuTeDSL/cutlass/utils/print_latex.py @@ -9,7 +9,7 @@ # and related documentation outside the scope permitted by the EULA # is strictly prohibited. -from typing import Callable, Union +from typing import Callable, Optional, Union from ..cute import ( Layout, @@ -59,7 +59,10 @@ def tikz_color_tv(tid: int, vid: int) -> str: def print_latex( - x: Union[Layout, ComposedLayout], *, color: Callable = tikz_color_bwx8 + x: Union[Layout, ComposedLayout], + *, + color: Callable = tikz_color_bwx8, + render_func: Optional[Callable[[str], None]] = None, ) -> None: """ Prints a layout. @@ -67,6 +70,8 @@ def print_latex( :type x: Union[Layout, ComposedLayout] :param color: A function that returns TiKZ colors :type color: Callable + :param render_func: An user provided function to render the latex output, which only includes tikz picture section. If None, it will print to stdout. + :type render_func: Optional[Callable] """ if not is_static(x): @@ -79,11 +84,20 @@ def print_latex( else: layout = x - print("%% Layout: {}", layout) - print("\\documentclass[convert]{standalone}") - print("\\usepackage{tikz}") - print("\\begin{document}") - print( + latex_output = [] + + def print_or_append(*args): + if render_func is not None: + latex_output.append(" ".join(str(arg) for arg in args)) + else: + print(*args) + + if render_func is None: + print_or_append("%% Layout: {}", layout) + print_or_append("\\documentclass[convert]{standalone}") + print_or_append("\\usepackage{tikz}") + print_or_append("\\begin{document}") + print_or_append( "\\begin{tikzpicture}[x={(0cm,-1cm)},y={(1cm,0cm)},every node/.style={minimum size=1cm, outer sep=0pt}]" ) @@ -92,20 +106,24 @@ def print_latex( for m in range(M): for n in range(N): idx = layout((m, n)) - print("\\node[fill=") - print(color(idx)) - print("] at (%d,%d) {%d};\n" % (m, n, idx)) - print( + print_or_append("\\node[fill=") + print_or_append(color(idx)) + print_or_append("] at (%d,%d) {%d};\n" % (m, n, idx)) + print_or_append( "\\draw[color=black,thick,shift={(-0.5,-0.5)}] (0,0) grid (%d,%d);\n\n" % (M, N) ) for m in range(M): - print("\\node at (%d,%d) {\\Large{\\texttt{%d}}};\n" % (m, -1, m)) + print_or_append("\\node at (%d,%d) {\\Large{\\texttt{%d}}};\n" % (m, -1, m)) for n in range(N): - print("\\node at (%d,%d) {\\Large{\\texttt{%d}}};\n" % (-1, n, n)) + print_or_append("\\node at (%d,%d) {\\Large{\\texttt{%d}}};\n" % (-1, n, n)) ## Footer - print("\\end{tikzpicture}") - print("\\end{document}") + print_or_append("\\end{tikzpicture}") + if render_func is None: + print_or_append("\\end{document}") + + if render_func is not None: + render_func(" ".join(latex_output)) def print_latex_tv( @@ -113,6 +131,7 @@ def print_latex_tv( tile_mn: Union[IntTuple, Layout], *, color: Callable = tikz_color_tv, + render_func: Optional[Callable[[str], None]] = None, ) -> None: """ Prints a tv layout for a tile M N. Everything must be static. @@ -122,17 +141,28 @@ def print_latex_tv( :type tile_mn: Union[IntTuple, Layout] :param color: A function that returns TiKZ colors :type color: Callable + :param render_func: An user provided function to render the latex output, which only includes tikz picture section. If None, it will print to stdout. + :type render_func: Optional[Callable] """ if not is_static(layout_tv) or not is_static(tile_mn): raise ValueError("Layout tv and tile_mn must be static") if rank(layout_tv) != 2: raise ValueError("Require layout_tv to be rank 2") - print("%% Layout TV: {}", layout_tv) - print("\\documentclass[convert]{standalone}") - print("\\usepackage{tikz}") - print("\\begin{document}") - print( + latex_output = [] + + def print_or_append(*args): + if render_func is not None: + latex_output.append(" ".join(str(arg) for arg in args)) + else: + print(*args) + + if render_func is None: + print_or_append("%% Layout TV: {}", layout_tv) + print_or_append("\\documentclass[convert]{standalone}") + print_or_append("\\usepackage{tikz}") + print_or_append("\\begin{document}") + print_or_append( "\\begin{tikzpicture}[x={(0cm,-1cm)},y={(1cm,0cm)},every node/.style={minimum size=1cm, outer sep=0pt}]\n" ) @@ -149,19 +179,23 @@ def print_latex_tv( n = (idx // tile_mn.stride[1]) % tile_mn.shape[1] # type: ignore[operator, union-attr, index] if not filled[m][n]: filled[m][n] = True - print( + print_or_append( "\\node[fill=%s] at (%d,%d) {\\shortstack{T%d \\\\ V%d}};\n" % (color(tid, vid), m, n, tid, vid) ) - print( + print_or_append( "\\draw[color=black,thick,shift={(-0.5,-0.5)}] (0,0) grid (%d,%d);\n\n" % (M, N) ) for m in range(M): - print("\\node at (%d,%d) {\\Large{\\texttt{%d}}};\n" % (m, -1, m)) + print_or_append("\\node at (%d,%d) {\\Large{\\texttt{%d}}};\n" % (m, -1, m)) for n in range(N): - print("\\node at (%d,%d) {\\Large{\\texttt{%d}}};\n" % (-1, n, n)) + print_or_append("\\node at (%d,%d) {\\Large{\\texttt{%d}}};\n" % (-1, n, n)) ## Footer - print("\\end{tikzpicture}") - print("\\end{document}") + print_or_append("\\end{tikzpicture}") + if render_func is None: + print_or_append("\\end{document}") + + if render_func is not None: + render_func(" ".join(latex_output))