mirror of
https://github.com/ROCm/composable_kernel.git
synced 2026-04-19 22:39:03 +00:00
* chore: update copyright header for misc files * fix: typo in kernel resulting in ci failure
228 lines
6.4 KiB
Python
228 lines
6.4 KiB
Python
#!/usr/bin/env python3
|
|
# Copyright (c) Advanced Micro Devices, Inc., or its affiliates.
|
|
# SPDX-License-Identifier: MIT
|
|
|
|
"""
|
|
Script to convert all mermaid diagrams in CK Tile docs to SVGs.
|
|
This script:
|
|
1. Finds all mermaid blocks in RST files
|
|
2. Converts them to SVG using mmdc
|
|
3. Updates RST files to use SVG images with commented mermaid source
|
|
"""
|
|
|
|
import os
|
|
import re
|
|
import subprocess
|
|
import tempfile
|
|
from pathlib import Path
|
|
|
|
# Configuration
# Every path below is derived from the directory that contains this script.
DOCS_DIR = Path(__file__).parent
# Rendered SVGs are written here; the rewritten RST links them as "diagrams/<name>".
DIAGRAMS_DIR = DOCS_DIR.joinpath("diagrams")
# RST pages (relative to DOCS_DIR) scanned for mermaid blocks, in processing
# order. Stems are listed without the ".rst" extension, which is appended here.
RST_FILES = [
    f"{page}.rst"
    for page in (
        "convolution_example",
        "encoding_internals",
        "lds_index_swapping",
        "space_filling_curve",
        "sweep_tile",
        "tensor_coordinates",
        "thread_mapping",
        "static_distributed_tensor",
        "load_store_traits",
        "tile_window",
        "transforms",
        "descriptors",
        "coordinate_movement",
        "adaptors",
        "introduction_motivation",
        "buffer_views",
        "tensor_views",
        "coordinate_systems",
        "tile_distribution",
    )
]
|
|
|
|
# Pattern to find mermaid blocks. The directive may itself be indented by
# 3 spaces (a block that already sits inside an RST comment). Group 1 captures
# the directive body: every following line that is empty or starts with the
# 3-space RST indentation, including trailing blank lines.
MERMAID_PATTERN = re.compile(
    r"^(?:   )?\.\. mermaid::\s*\n((?:(?:\n|   .*))*)", re.MULTILINE
)
|
|
|
|
|
|
def extract_mermaid_content(block):
    """Extract the mermaid code from an RST directive body.

    Lines that start with the 3-space RST indentation have exactly those three
    spaces removed, so any deeper indentation is preserved. Blank or
    whitespace-only lines become empty lines. Any other line is dropped.
    Leading and trailing whitespace of the joined result is stripped.

    Args:
        block: Directive body text (the caller passes ``match.group(1)`` of the
            mermaid regex).

    Returns:
        The de-indented mermaid code as a single string.
    """
    # The prefix that is tested and the number of characters sliced off must
    # agree; both come from this one constant.
    indent = "   "
    content_lines = []
    for line in block.split("\n"):
        if line.startswith(indent):
            content_lines.append(line[len(indent):])
        elif not line.strip():
            content_lines.append("")
    return "\n".join(content_lines).strip()
|
|
|
|
|
|
def generate_diagram_name(file_path, diagram_index, total_in_file):
    """Build the SVG file name for one diagram of an RST file.

    A file that holds exactly one diagram gets ``<stem>.svg``. Otherwise each
    diagram is numbered starting from 1, as ``<stem>_<n>.svg``, where ``n`` is
    ``diagram_index + 1``.
    """
    stem = file_path.stem
    numbering = "" if total_in_file == 1 else f"_{diagram_index + 1}"
    return f"{stem}{numbering}.svg"
|
|
|
|
|
|
def convert_mermaid_to_svg(mermaid_code, output_path):
    """Render mermaid code to an SVG file using the mermaid CLI (``mmdc``).

    The code is written to a temporary ``.mmd`` file and rendered with the
    ``neutral`` theme on a transparent background. The temporary file is
    always removed afterwards.

    Args:
        mermaid_code: Mermaid diagram source (already de-indented).
        output_path: ``pathlib.Path`` of the SVG file to produce.

    Returns:
        True if ``mmdc`` succeeded; False if it failed or could not be started.
    """
    # delete=False so the file survives closing and mmdc can open it by name.
    with tempfile.NamedTemporaryFile(
        mode="w", suffix=".mmd", delete=False, encoding="utf-8"
    ) as tmp:
        tmp.write(mermaid_code)
        tmp_path = tmp.name

    command = [
        "mmdc",
        "-i",
        tmp_path,
        "-o",
        str(output_path),
        "-t",
        "neutral",
        "-b",
        "transparent",
    ]
    try:
        # The shell is needed on Windows for mmdc's .cmd launcher. On POSIX,
        # shell=True with an argument list would run bare "mmdc" and hand the
        # remaining items to the shell itself, silently dropping every option.
        subprocess.run(
            command,
            capture_output=True,
            text=True,
            check=True,
            shell=os.name == "nt",
        )
    except subprocess.CalledProcessError as e:
        details = (e.stderr or e.stdout or "").strip()
        print(f" ✗ Error converting diagram: {details}")
        return False
    except OSError as e:
        # e.g. FileNotFoundError when mmdc is not on PATH (no shell in between).
        print(f" ✗ Error running mmdc: {e}")
        return False
    finally:
        # Clean up temp file
        os.unlink(tmp_path)

    print(f" ✓ Generated: {output_path.name}")
    return True
|
|
|
|
|
|
def update_rst_file(file_path, diagrams_info):
    """Replace mermaid blocks in an RST file with commented source plus an image.

    Each matched block becomes an RST comment that keeps the original mermaid
    directive, so it can still be edited. The comment is followed by an
    ``.. image::`` directive that points at ``diagrams/<svg_name>``.

    Args:
        file_path: ``pathlib.Path`` of the RST file, rewritten in place.
        diagrams_info: List of dicts with keys ``"match"`` (a regex match
            found in the file's current content), ``"svg_name"`` and
            ``"position"`` (the match start offset). The list is not modified.
    """
    indent = "   "  # RST indentation for comment bodies and directive options

    with open(file_path, "r", encoding="utf-8") as f:
        content = f.read()

    # Replace from the end of the file backwards so the offsets of earlier
    # matches stay valid while later text changes length.
    for info in sorted(diagrams_info, key=lambda x: x["position"], reverse=True):
        match = info["match"]
        start_pos = match.start()
        end_pos = match.end()

        # The match also swallows the blank lines after the block. Drop them
        # here; exactly one separator is re-added after the image directive.
        mermaid_block = match.group(0).rstrip()

        # Create commented mermaid block (blank lines stay truly empty so no
        # trailing whitespace is written).
        commented_lines = [
            ".. ",
            f"{indent}Original mermaid diagram (edit here, then run update_diagrams.py)",
            "",
        ]
        commented_lines.extend(
            f"{indent}{line}" if line.strip() else ""
            for line in mermaid_block.split("\n")
        )

        # Add image reference
        image_block = [
            "",
            f".. image:: diagrams/{info['svg_name']}",
            f"{indent}:alt: Diagram",
            f"{indent}:align: center",
        ]

        # RST needs a blank line between the directive and following text;
        # at end of file a single terminating newline is enough.
        tail = "\n" if end_pos >= len(content) else "\n\n"
        replacement = "\n".join(commented_lines + image_block) + tail

        content = content[:start_pos] + replacement + content[end_pos:]

    with open(file_path, "w", encoding="utf-8") as f:
        f.write(content)

    print(f" ✓ Updated: {file_path.name}")
|
|
|
|
|
|
def process_file(file_path):
    """Render every mermaid block of one RST file and rewrite the file.

    Blocks are located with ``MERMAID_PATTERN`` and each one is rendered into
    ``DIAGRAMS_DIR``. Only blocks whose SVG was generated successfully are
    replaced in the RST source. If none succeeded, the file is not rewritten.
    """
    print(f"\nProcessing {file_path.name}...")

    with open(file_path, "r", encoding="utf-8") as handle:
        text = handle.read()

    found = list(MERMAID_PATTERN.finditer(text))
    if not found:
        print(" No mermaid diagrams found.")
        return

    count = len(found)
    print(f" Found {count} diagram(s)")

    converted = []
    for ordinal, block_match in enumerate(found):
        svg_name = generate_diagram_name(file_path, ordinal, count)
        source = extract_mermaid_content(block_match.group(1))
        if not convert_mermaid_to_svg(source, DIAGRAMS_DIR / svg_name):
            continue  # leave this block untouched in the RST
        converted.append(
            {"match": block_match, "svg_name": svg_name, "position": block_match.start()}
        )

    if converted:
        update_rst_file(file_path, converted)
|
|
|
|
|
|
def main():
    """Convert the mermaid diagrams of every file listed in ``RST_FILES``.

    Checks that the mermaid CLI can be run and makes sure ``DIAGRAMS_DIR``
    exists. It then calls ``process_file`` on each listed RST file that
    exists; missing files only produce a warning.

    Returns:
        1 if ``mmdc`` could not be run. Otherwise 0 once every listed file
        has been visited; individual diagram failures are only reported.
    """
    print("CK Tile Mermaid to SVG Converter")
    print("=" * 50)

    # Verify mmdc is available. The shell is used only on Windows (mmdc is a
    # .cmd launcher there). On POSIX, shell=True with an argument list would
    # run bare "mmdc" and pass "--version" to the shell instead of to mmdc.
    try:
        subprocess.run(
            ["mmdc", "--version"],
            capture_output=True,
            check=True,
            shell=os.name == "nt",
        )
    except (subprocess.CalledProcessError, OSError):
        # OSError covers FileNotFoundError when mmdc is not on PATH.
        print("Error: mermaid-cli (mmdc) not found. Please install it:")
        print(" npm install -g @mermaid-js/mermaid-cli")
        return 1

    # Ensure diagrams directory exists
    DIAGRAMS_DIR.mkdir(parents=True, exist_ok=True)

    # Process each file
    for rst_file in RST_FILES:
        file_path = DOCS_DIR / rst_file
        if file_path.exists():
            process_file(file_path)
        else:
            print(f"\n⚠ Warning: {rst_file} not found")

    print("\n" + "=" * 50)
    print("✓ Conversion complete!")
    print(f"SVG files saved to: {DIAGRAMS_DIR}")

    return 0
|
|
|
|
|
|
if __name__ == "__main__":
    # SystemExit is always available, unlike the site-injected exit() helper
    # (absent under "python -S" or in embedded interpreters).
    raise SystemExit(main())
|