Files
composable_kernel/script/tools/analyze_build_trace.py
Max Podkorytov 086a1f8861 Add LLM-agnostic Docker and build analysis tools (#3576)
This commit introduces utility tools for building, testing, and analyzing
Composable Kernel. The tools are designed to be LLM-agnostic and can be
used with any AI assistant or directly from the command line.

Tools Added:
============

1. ck-docker - Docker container management
   - Start/stop ROCm-enabled containers
   - Build targets with CMake + Ninja
   - Run tests with gtest filters
   - Auto-detect GPU targets (gfx950, gfx942, etc.)
   - Per-user, per-branch container naming to avoid conflicts

2. ck-build-analysis - Build time profiling
   - Uses Clang's -ftime-trace for compilation analysis
   - Aggregates statistics across multiple trace files
   - Identifies template instantiation bottlenecks
   - Generates detailed Markdown reports with:
     * Compilation phase breakdown
     * Top expensive instantiations
     * Template family analysis
     * Data-driven optimization recommendations
   - Configurable granularity (1µs to 500µs)
   - PEP 723 compliant Python script with auto-dependency management via uv
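
   Example of the Clang flags the analysis builds on (a sketch; the exact
   invocation is assembled by ck-build-analysis):

     clang++ -ftime-trace -ftime-trace-granularity=500 -c foo.cpp -o foo.o

   Each translation unit then writes a foo.json trace next to its object
   file, which analyze_build_trace.py aggregates into one report.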

Key Features:
=============

- LLM-agnostic design (works with any AI assistant)
- Zero-configuration setup with automatic dependency installation
- Comprehensive documentation in script/tools/README*.md
- Security hardening (input validation, no command injection)
- Multi-file trace aggregation for accurate build analysis
- Jinja2-based report generation for customizable output

Implementation:
===============

- script/tools/ck-docker - Main Docker orchestration script
- script/tools/ck-build-analysis - Build analysis orchestration
- script/tools/common.sh - Shared utilities (container mgmt, GPU detection)
- script/tools/analyze_build_trace.py - PEP 723 compliant Python analyzer
- script/tools/templates/ - Jinja2 templates for report generation
- script/tools/README*.md - Comprehensive documentation

Directory Structure:
====================

script/tools/
├── README.md                          # Main overview
├── README_ck-docker.md                # ck-docker documentation
├── README_ck-build-analysis.md        # ck-build-analysis documentation
├── ck-docker                          # Docker orchestration script
├── ck-build-analysis                  # Build analysis orchestration
├── common.sh                          # Shared utilities
├── analyze_build_trace.py             # Python analyzer (PEP 723)
└── templates/
    └── build_analysis_report.md.jinja # Report template

The tools follow the Unix philosophy: do one thing well, compose easily,
and work from both CLI and programmatic contexts.
2026-01-15 08:30:23 -08:00

348 lines
11 KiB
Python
Executable File

#!/usr/bin/env python3
# Copyright (c) Advanced Micro Devices, Inc., or its affiliates.
# SPDX-License-Identifier: MIT
# /// script
# requires-python = ">=3.8"
# dependencies = [
#     "jinja2>=3.0.0",
# ]
# ///
"""
Build Time Analysis Tool for Composable Kernel

Analyzes Clang -ftime-trace output to identify template instantiation
bottlenecks and generate comprehensive build time reports.
"""
import json
import os
import re
import sys
from collections import defaultdict
from datetime import datetime

try:
    from jinja2 import Environment, FileSystemLoader
except ImportError:
    print("Error: jinja2 is required but not installed.", file=sys.stderr)
    print("Install with: apt-get install python3-jinja2", file=sys.stderr)
    print("Or with pip: pip install jinja2", file=sys.stderr)
    sys.exit(1)

def parse_arguments():
    """Parse command-line arguments."""
    if len(sys.argv) < 7:
        print(
            "Usage: analyze_build_trace.py <trace_files_or_dir> <output_file> <target> <granularity> <build_time> <template_dir>"
        )
        print(
            "  trace_files_or_dir: Comma-separated list of trace files OR directory containing .json files"
        )
        sys.exit(1)
    return {
        "trace_input": sys.argv[1],
        "output_file": sys.argv[2],
        "target": sys.argv[3],
        "granularity": sys.argv[4],
        "build_time": sys.argv[5],
        "template_dir": sys.argv[6],
    }
def find_trace_files(trace_input):
    """Find all trace files from input (file list, single file, or directory)."""
    trace_files = []
    # Check if it's a directory
    if os.path.isdir(trace_input):
        print(f"Scanning directory: {trace_input}")
        for root, dirs, files in os.walk(trace_input):
            for file in files:
                # Only pick up per-TU traces (.cpp.json / .hip.json) located
                # under CMakeFiles directories; the suffix check also skips
                # files such as compile_commands.json.
                if file.endswith((".cpp.json", ".hip.json")) and "CMakeFiles" in root:
                    trace_files.append(os.path.join(root, file))
        trace_files.sort()
    # Check if it's a comma-separated list
    elif "," in trace_input:
        trace_files = [f.strip() for f in trace_input.split(",")]
    # Single file
    else:
        trace_files = [trace_input]
    # Filter out non-existent files
    valid_files = [f for f in trace_files if os.path.isfile(f)]
    if not valid_files:
        print(f"Error: No valid trace files found in: {trace_input}", file=sys.stderr)
        sys.exit(1)
    print(f"Found {len(valid_files)} trace file(s)")
    return valid_files
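
# In a CMake + Ninja build the per-TU traces picked up above typically sit
# next to the object files, e.g. (illustrative path):
#   build/CMakeFiles/<target>.dir/src/foo.cpp.json
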
def load_trace_data(trace_files):
    """Load and parse multiple trace JSON files."""
    all_data = []
    for trace_file in trace_files:
        print(f"  Loading: {trace_file}")
        try:
            with open(trace_file, "r") as f:
                data = json.load(f)
            # Get file basename for tracking
            file_name = os.path.basename(trace_file)
            all_data.append({"file": file_name, "path": trace_file, "data": data})
        except Exception as e:
            print(f"  Warning: Failed to load {trace_file}: {e}", file=sys.stderr)
    return all_data
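
# A -ftime-trace file uses the Chrome trace-event format; the fields read
# by process_events() below look roughly like this (the detail string is
# illustrative):
#
#   {"traceEvents": [
#       {"name": "InstantiateFunction",
#        "dur": 15234,                            # microseconds
#        "args": {"detail": "ck::copy<float, 4>"}},
#       ...]}
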
def process_events(all_trace_data):
    """Process trace events from multiple files and extract statistics."""
    print("Processing events from all files...")
    template_stats = defaultdict(lambda: {"count": 0, "total_dur": 0})
    phase_stats = defaultdict(int)
    top_individual = []
    file_stats = []
    total_events = 0
    for trace_info in all_trace_data:
        file_name = trace_info["file"]
        data = trace_info["data"]
        events = data.get("traceEvents", [])
        file_template_time = 0
        file_event_count = len(events)
        total_events += file_event_count
        print(f"  Processing {file_name}: {file_event_count:,} events")
        for event in events:
            name = event.get("name", "")
            dur = int(event.get("dur", 0))  # Keep as integer microseconds
            if name and dur > 0:
                phase_stats[name] += dur
                if name in ["InstantiateFunction", "InstantiateClass"]:
                    detail = event.get("args", {}).get("detail", "")
                    top_individual.append(
                        {"detail": detail, "dur": dur, "type": name, "file": file_name}
                    )
                    file_template_time += dur
                    # Extract template name (everything before '<' or '(')
                    match = re.match(r"^([^<(]+)", detail)
                    if match:
                        template_name = match.group(1).strip()
                        # Normalize: strip the ck:: prefix so CK templates
                        # group under one family name
                        template_name = re.sub(r"^ck::", "", template_name)
                        template_stats[template_name]["count"] += 1
                        template_stats[template_name]["total_dur"] += dur
        file_stats.append(
            {
                "name": file_name,
                "events": file_event_count,
                "template_time": file_template_time,
            }
        )
    return template_stats, phase_stats, top_individual, file_stats, total_events
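
# Example of the family extraction above (illustrative detail string):
#   detail = "ck::ThreadwiseCopy<float, 4>::Run(int)"
#   -> the regex keeps the text before the first '<' or '(', giving
#      "ck::ThreadwiseCopy"; stripping the ck:: prefix yields the family
#      key "ThreadwiseCopy"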
def prepare_template_data(template_stats, phase_stats, top_individual, file_stats):
    """Prepare and calculate derived statistics for template rendering."""
    print("Sorting data...")
    # Sort data
    sorted_phases = sorted(phase_stats.items(), key=lambda x: x[1], reverse=True)
    top_individual.sort(key=lambda x: x["dur"], reverse=True)
    file_stats.sort(key=lambda x: x["template_time"], reverse=True)
    # Calculate totals
    total_template_time = sum(s["total_dur"] for s in template_stats.values())
    total_trace_time = sum(phase_stats.values())
    total_inst = sum(s["count"] for s in template_stats.values())
    # Prepare templates by time with calculated fields
    templates_by_time = []
    for name, stats in sorted(
        template_stats.items(), key=lambda x: x[1]["total_dur"], reverse=True
    ):
        templates_by_time.append(
            (
                name,
                {
                    "count": stats["count"],
                    "total_dur": stats["total_dur"],
                    "avg": stats["total_dur"] // stats["count"]
                    if stats["count"] > 0
                    else 0,
                    "pct": 100 * stats["total_dur"] / total_template_time
                    if total_template_time > 0
                    else 0,
                },
            )
        )
    # Prepare templates by count
    templates_by_count = []
    for name, stats in sorted(
        template_stats.items(), key=lambda x: x[1]["count"], reverse=True
    ):
        templates_by_count.append(
            (
                name,
                {
                    "count": stats["count"],
                    "total_dur": stats["total_dur"],
                    "avg": stats["total_dur"] // stats["count"]
                    if stats["count"] > 0
                    else 0,
                },
            )
        )
    # Add friendly type names to individual instantiations
    for inst in top_individual:
        inst["inst_type"] = "Func" if inst["type"] == "InstantiateFunction" else "Class"
    # Calculate additional metrics
    median_count = 0
    if len(template_stats) > 0:
        median_count = sorted([s["count"] for s in template_stats.values()])[
            len(template_stats) // 2
        ]
    top10_pct = 0
    if len(templates_by_time) >= 10:
        top10_pct = (
            100
            * sum(s[1]["total_dur"] for s in templates_by_time[:10])
            / total_template_time
        )
    return {
        "sorted_phases": sorted_phases,
        "top_individual": top_individual,
        "templates_by_time": templates_by_time,
        "templates_by_count": templates_by_count,
        "total_template_time": total_template_time,
        "total_trace_time": total_trace_time,
        "total_inst": total_inst,
        "median_count": median_count,
        "top10_pct": top10_pct,
        "unique_families": len(template_stats),
        "file_stats": file_stats,
    }
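
# Worked example for top10_pct (illustrative numbers): if the ten most
# expensive families account for 42 s of a 60 s total_template_time, then
#   top10_pct = 100 * 42 / 60 = 70.0
# i.e. 70% of all instantiation time is concentrated in ten families.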
def setup_jinja_environment(template_dir):
    """Set up Jinja2 environment with custom filters."""
    env = Environment(loader=FileSystemLoader(template_dir))

    def format_number(value):
        """Format number with thousand separators."""
        return f"{value:,}"

    def truncate(value, length):
        """Truncate string to length with ellipsis."""
        if len(value) > length:
            return value[: length - 3] + "..."
        return value

    def pad(value, length):
        """Pad string to specified length."""
        return f"{value:<{length}}"

    def us_to_ms(value):
        """Convert microseconds to milliseconds."""
        return value / 1000.0

    def us_to_s(value):
        """Convert microseconds to seconds."""
        return value / 1000000.0

    env.filters["format_number"] = format_number
    env.filters["truncate"] = truncate
    env.filters["pad"] = pad
    env.filters["us_to_ms"] = us_to_ms
    env.filters["us_to_s"] = us_to_s
    return env
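
# Illustrative use of these filters inside the .md.jinja template (the real
# template ships in script/tools/templates/; this row is a sketch):
#   | {{ name | truncate(60) }} | {{ stats.total_dur | us_to_s | round(2) }} s |
# Note that registering "truncate" shadows Jinja2's built-in filter of the
# same name.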
def generate_report(env, data, args, total_events, num_files):
    """Generate the final report using Jinja2 template."""
    print("Rendering report with Jinja2...")
    template = env.get_template("build_analysis_report.md.jinja")
    report_content = template.render(
        timestamp=datetime.now().strftime("%Y-%m-%d %H:%M:%S"),
        target=args["target"],
        granularity=args["granularity"],
        build_time=args["build_time"],
        total_events=total_events,
        num_files=num_files,
        total_instantiations=data["total_inst"],
        unique_families=data["unique_families"],
        total_trace_time=data["total_trace_time"],
        total_template_time=data["total_template_time"],
        phases=data["sorted_phases"],
        top_individual=data["top_individual"],
        templates_by_time=data["templates_by_time"],
        templates_by_count=data["templates_by_count"],
        median_count=data["median_count"],
        top10_pct=data["top10_pct"],
        file_stats=data["file_stats"],
    )
    return report_content

def main():
    """Main entry point for the analysis tool."""
    args = parse_arguments()
    # Find and load trace files
    trace_files = find_trace_files(args["trace_input"])
    all_trace_data = load_trace_data(trace_files)
    # Process events from all files
    template_stats, phase_stats, top_individual, file_stats, total_events = (
        process_events(all_trace_data)
    )
    # Prepare template data
    data = prepare_template_data(
        template_stats, phase_stats, top_individual, file_stats
    )
    # Setup Jinja2 environment
    env = setup_jinja_environment(args["template_dir"])
    # Generate report
    report_content = generate_report(env, data, args, total_events, len(all_trace_data))
    # Write output
    with open(args["output_file"], "w") as f:
        f.write(report_content)
    print(f"Report generated: {args['output_file']}")
    print(f"Report size: {len(report_content):,} bytes")
    print(f"Analyzed {len(all_trace_data)} file(s) with {total_events:,} total events")


if __name__ == "__main__":
    main()
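
# Because of the PEP 723 metadata block at the top of this file, the script
# can also be run with its jinja2 dependency resolved automatically by uv
# (illustrative):
#   uv run analyze_build_trace.py <trace_dir> <output> <target> \
#       <granularity> <build_time> <template_dir>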