Files
composable_kernel/docs/conceptual/ck_tile/convert_mermaid_to_svg.py
Aviral Goel 4dcc3e59c1 chore: update copyright header for misc files (#3402)
* chore: update copyright header for misc files

* fix: typo in kernel resulting in ci failure
2025-12-11 08:25:29 -08:00

228 lines
6.4 KiB
Python

#!/usr/bin/env python3
# Copyright (c) Advanced Micro Devices, Inc., or its affiliates.
# SPDX-License-Identifier: MIT
"""
Script to convert all mermaid diagrams in CK Tile docs to SVGs.
This script:
1. Finds all mermaid blocks in RST files
2. Converts them to SVG using mmdc
3. Updates RST files to use SVG images with commented mermaid source
"""
import os
import re
import subprocess
import tempfile
from pathlib import Path
# Configuration
# Directory that contains this script; RST files are looked up relative to it.
DOCS_DIR = Path(__file__).parent
# Directory that receives the generated SVG files.
DIAGRAMS_DIR = DOCS_DIR / "diagrams"
# RST files (relative to DOCS_DIR) that are scanned for mermaid diagrams.
RST_FILES = [
    "convolution_example.rst",
    "encoding_internals.rst",
    "lds_index_swapping.rst",
    "space_filling_curve.rst",
    "sweep_tile.rst",
    "tensor_coordinates.rst",
    "thread_mapping.rst",
    "static_distributed_tensor.rst",
    "load_store_traits.rst",
    "tile_window.rst",
    "transforms.rst",
    "descriptors.rst",
    "coordinate_movement.rst",
    "adaptors.rst",
    "introduction_motivation.rst",
    "buffer_views.rst",
    "tensor_views.rst",
    "coordinate_systems.rst",
    "tile_distribution.rst",
]
# Pattern to find mermaid blocks (can be indented with 3 spaces for commented blocks).
# Group 1 is the directive body: every following line that is either blank or
# indented by at least 3 spaces. The indentation is written as ` {3}` so the
# required width is explicit and matches the 3-character strip performed by
# extract_mermaid_content.
MERMAID_PATTERN = re.compile(
    r"^(?: {3})?\.\. mermaid::\s*\n((?:(?:\n| {3}.*))*)", re.MULTILINE
)
def extract_mermaid_content(block):
    """Extract the actual mermaid code from the block, removing RST indentation.

    Args:
        block: Body of a mermaid directive as one string, with lines
            separated by "\\n".

    Returns:
        The mermaid source with up to three leading spaces removed from each
        indented line, blank lines preserved, and surrounding whitespace
        stripped. Non-blank lines without any leading space are dropped.
    """
    content_lines = []
    for line in block.split("\n"):
        if line.startswith(" "):
            # Remove the RST indentation: at most 3 spaces, and never any
            # non-space character (a blind line[3:] would eat real content
            # from lines indented by only one or two spaces).
            indent = len(line) - len(line.lstrip(" "))
            content_lines.append(line[min(indent, 3):])
        elif line.strip() == "":
            content_lines.append("")
        # Anything else is not part of the indented directive body; skip it.
    return "\n".join(content_lines).strip()
def generate_diagram_name(file_path, diagram_index, total_in_file):
    """Build the SVG file name for one diagram of an RST file.

    The name is the RST file's stem plus ".svg". When the file holds more
    than one diagram (or none), a 1-based "_<n>" suffix derived from
    ``diagram_index`` is inserted before the extension.
    """
    stem = file_path.stem
    # Only a file with exactly one diagram gets the bare "<stem>.svg" name.
    numbering = "" if total_in_file == 1 else f"_{diagram_index + 1}"
    return f"{stem}{numbering}.svg"
def convert_mermaid_to_svg(mermaid_code, output_path):
    """Convert mermaid code to SVG using mmdc.

    The mermaid source is written to a temporary ".mmd" file, which is
    passed to ``mmdc`` together with the "neutral" theme and a transparent
    background. The temporary file is always removed afterwards.

    Args:
        mermaid_code: Mermaid diagram source text.
        output_path: ``pathlib.Path`` of the SVG file to create.

    Returns:
        True if mmdc exited successfully, False if it could not be started
        or exited with a non-zero status.
    """
    # Create a temporary file for the mermaid code. delete=False keeps it on
    # disk after the handle is closed so that mmdc can open it.
    with tempfile.NamedTemporaryFile(
        mode="w", suffix=".mmd", delete=False, encoding="utf-8"
    ) as tmp:
        tmp.write(mermaid_code)
        tmp_path = tmp.name
    try:
        subprocess.run(
            [
                "mmdc",
                "-i",
                tmp_path,
                "-o",
                str(output_path),
                "-t",
                "neutral",
                "-b",
                "transparent",
            ],
            capture_output=True,
            text=True,
            check=True,
            # The shell is only needed on Windows, where mmdc is a .cmd file.
            # On POSIX, shell=True combined with an argument list would hand
            # everything after "mmdc" to the shell itself, so mmdc would run
            # without any of its options.
            shell=os.name == "nt",
        )
    except subprocess.CalledProcessError as e:
        print(f" ✗ Error converting diagram: {e.stderr}")
        return False
    except OSError as e:
        # mmdc could not be launched at all (e.g. it is not on PATH).
        print(f" ✗ Error converting diagram: {e}")
        return False
    else:
        print(f" ✓ Generated: {output_path.name}")
        return True
    finally:
        # Clean up temp file
        os.unlink(tmp_path)
def update_rst_file(file_path, diagrams_info):
    """Update RST file to replace mermaid blocks with commented source + image reference.

    Args:
        file_path: ``pathlib.Path`` of the RST file to rewrite in place.
        diagrams_info: List of dicts with the keys "match" (regex match of
            the mermaid block in the file's current text), "svg_name" (file
            name of the generated SVG) and "position" (start offset of the
            match). The list itself is not modified.
    """
    with open(file_path, "r", encoding="utf-8") as f:
        content = f.read()
    # Replace from the end of the file towards the start so that the offsets
    # of the blocks still to be replaced stay valid. sorted() leaves the
    # caller's list untouched.
    ordered = sorted(diagrams_info, key=lambda x: x["position"], reverse=True)
    for info in ordered:
        match = info["match"]
        block_lines = match.group(0).split("\n")
        # The match also covers the blank lines that follow the directive
        # body; they must not become part of the comment.
        while block_lines and not block_lines[-1].strip():
            block_lines.pop()
        # Create commented mermaid block. The 3-space indentation keeps the
        # lines inside the RST comment opened by "..".
        commented_lines = [
            ".. ",
            "   Original mermaid diagram (edit here, then run update_diagrams.py)",
            "   ",
        ]
        commented_lines.extend(f"   {line}" for line in block_lines)
        # Add image reference
        svg_rel_path = f"diagrams/{info['svg_name']}"
        image_block = [
            "",
            f".. image:: {svg_rel_path}",
            "   :alt: Diagram",
            "   :align: center",
            "",
            # Second empty entry: guarantees a blank line between the image
            # directive and whatever text follows the replaced block.
            "",
        ]
        replacement = "\n".join(commented_lines + image_block)
        # Replace in content
        content = content[: match.start()] + replacement + content[match.end():]
    # Write back
    with open(file_path, "w", encoding="utf-8") as f:
        f.write(content)
    print(f" ✓ Updated: {file_path.name}")
def process_file(file_path):
    """Process a single RST file.

    Every mermaid block found by MERMAID_PATTERN is rendered to an SVG in
    DIAGRAMS_DIR. Blocks that were rendered successfully and are not yet
    commented out are then replaced in the RST file by a commented copy of
    the source plus an image reference.

    Args:
        file_path: ``pathlib.Path`` of the RST file to process.
    """
    print(f"\nProcessing {file_path.name}...")
    with open(file_path, "r", encoding="utf-8") as f:
        content = f.read()
    # Find all mermaid blocks
    matches = list(MERMAID_PATTERN.finditer(content))
    if not matches:
        print(" No mermaid diagrams found.")
        return
    print(f" Found {len(matches)} diagram(s)")
    diagrams_info = []
    # Process each mermaid block
    for idx, match in enumerate(matches):
        mermaid_content = extract_mermaid_content(match.group(1))
        svg_name = generate_diagram_name(file_path, idx, len(matches))
        svg_path = DIAGRAMS_DIR / svg_name
        # Convert to SVG; a failed conversion leaves the RST block untouched.
        if not convert_mermaid_to_svg(mermaid_content, svg_path):
            continue
        # An indented directive is one that an earlier run already moved
        # into an RST comment. Its SVG has just been regenerated, but
        # rewriting it again would nest the comment and add a second image
        # reference, so it is not scheduled for replacement.
        if match.group(0).startswith(" "):
            continue
        diagrams_info.append(
            {"match": match, "svg_name": svg_name, "position": match.start()}
        )
    # Update the RST file
    if diagrams_info:
        update_rst_file(file_path, diagrams_info)
def main():
    """Convert the mermaid diagrams of every file listed in RST_FILES.

    Returns:
        1 if mmdc cannot be run, otherwise 0 once all listed files have been
        visited (individual conversion failures are reported on stdout but
        do not change the return value).
    """
    print("CK Tile Mermaid to SVG Converter")
    print("=" * 50)
    # Verify mmdc is available
    try:
        subprocess.run(
            ["mmdc", "--version"],
            capture_output=True,
            check=True,
            # Only Windows needs the shell (mmdc is a .cmd file there). On
            # POSIX, shell=True with a list would drop "--version" and would
            # never raise FileNotFoundError for a missing executable.
            shell=os.name == "nt",
        )
    except (subprocess.CalledProcessError, FileNotFoundError):
        print("Error: mermaid-cli (mmdc) not found. Please install it:")
        print(" npm install -g @mermaid-js/mermaid-cli")
        return 1
    # Ensure diagrams directory exists
    DIAGRAMS_DIR.mkdir(parents=True, exist_ok=True)
    # Process each file
    for rst_file in RST_FILES:
        file_path = DOCS_DIR / rst_file
        if file_path.exists():
            process_file(file_path)
        else:
            print(f"\n⚠ Warning: {rst_file} not found")
    print("\n" + "=" * 50)
    print("✓ Conversion complete!")
    print(f"SVG files saved to: {DIAGRAMS_DIR}")
    return 0
if __name__ == "__main__":
    # Use main()'s return value as the process exit status. SystemExit is
    # raised directly because the ``exit`` helper is only installed by the
    # ``site`` module for interactive sessions and may be absent.
    raise SystemExit(main())