Improve analysis script.

This commit is contained in:
Ville Pietilä
2025-07-01 12:28:24 +00:00
parent ef6262782f
commit fbfe57e7c2

View File

@@ -7,6 +7,7 @@ import pandas as pd
import csv
import matplotlib
from collections import defaultdict
import numpy as np
matplotlib.use('Agg') # Use a non-interactive backend
from matplotlib import pyplot as plt
@@ -519,6 +520,69 @@ def plot_subscription_factor_per_instance(kgemm_to_subscription_per_instance, ou
plt.savefig(file_name, dpi=150)
plt.close()
def plot_performance(fixed_split_k_tflops, best_occupancy_split_k_tflops, gemm_m, gemm_n, gemm_k,
arithmetic_intensity, output_dir, suffix, op_name):
"""Plot the performance of fixed split-k vs best occupancy split-k."""
plt.figure(figsize=(12, 8))
# Convert to float for plotting
fixed_split_k_tflops = fixed_split_k_tflops.astype(float).values
best_occupancy_split_k_tflops = best_occupancy_split_k_tflops.astype(float).values
gemm_m_arr = gemm_m.astype(float).values
gemm_n_arr = gemm_n.astype(float).values
gemm_k_arr = gemm_k.astype(float).values
ai_arr = arithmetic_intensity.astype(float).values
perf = (best_occupancy_split_k_tflops / fixed_split_k_tflops) * 100.0
x_values = np.log(gemm_k_arr)
y_values = np.log(gemm_m_arr * gemm_n_arr)
# Heat map with axis gemm_m * gemm_n and gemm_k
scatter = plt.scatter(x_values, y_values,
c=perf,
cmap='coolwarm',
edgecolor='black',
alpha=0.7,
s=40, # Size of the points
norm=plt.Normalize(vmin=50, vmax=150)) # Normalize colors: blue (<100%), red (>100%)
title = op_name if op_name else 'Performance of Best Occupancy Split-K vs Fixed Split-K'
title_size = 14 if op_name else 16
plt.colorbar(label='Performance (%)')
plt.title(title, fontsize=title_size)
plt.xlabel('log(K)', fontsize=14)
plt.ylabel('log(M * N)', fontsize=14)
plt.grid(True, linestyle='--', alpha=0.7)
plt.tight_layout()
file_name = os.path.join(output_dir, f'performance_heatmap_k_mn{suffix}.png')
plt.savefig(file_name, dpi=150)
print(f"Saved performance heatmap to: {file_name}")
# Heat map with axis log(gemm_k) and log(ai_arr)
y_values = np.log(ai_arr)
plt.figure(figsize=(12, 8))
scatter = plt.scatter(x_values, y_values,
c=perf,
cmap='coolwarm',
edgecolor='black',
alpha=0.7,
s=40, # Size of the points
norm=plt.Normalize(vmin=50, vmax=150)) # Normalize colors: blue (<100%), red (>100%)
plt.colorbar(label='Performance (%)')
plt.title(title, fontsize=title_size)
plt.xlabel('log(K)', fontsize=14)
plt.ylabel('log(Arithmetic Intensity)', fontsize=14)
plt.grid(True, linestyle='--', alpha=0.7)
plt.tight_layout()
file_name = os.path.join(output_dir, f'performance_heatmap_k_ai{suffix}.png')
plt.savefig(file_name, dpi=150)
print(f"Saved performance heatmap to: {file_name}")
def main():
args = parse_cli_args()
@@ -539,43 +603,30 @@ def main():
best_occupancy_split_k_times = df[4]
best_occupancy_split_k_values = df[5]
else:
fixed_split_k_ops = df[0]
fixed_split_k_times = df[1]
fixed_split_k_values = df[2]
best_occupancy_split_k_ops = df[4]
best_occupancy_split_k_times = df[5]
best_occupancy_split_k_values = df[6]
best_occupancy_subs_factor = df[7]
gemm_k = df[9]
valid_mask1 = df[10] == "SplitKStrategy::FixedSplitK"
valid_mask2 = df[16] == "SplitKStrategy::BestOccupancy"
valid_mask = valid_mask1 & valid_mask2
df_offset = 11
step = 10
max_df_columns = len(df.columns)
kgemm_to_subscription_per_instance = {}
try:
for i in range(df_offset, max_df_columns, step):
columns_of_interest = [i, i+6, i+8] # op, subs_factor, k_gemm columns
subset_df = df[columns_of_interest].copy()
subset_df.columns = ['op', 'subs_factor', 'k_gemm']
subset_df = subset_df.dropna()
# Clean data
op = subset_df['op']
subs_factor = subset_df['subs_factor']
k_gemm = subset_df['k_gemm']
assert len(op) == len(subs_factor) == len(k_gemm), \
f"Length mismatch in columns {i} ({len(op)}), {i + 6} ({len(subs_factor)}), {i + 8} ({len(k_gemm)})"
for j in range(len(op)):
if op.iloc[j] not in kgemm_to_subscription_per_instance:
kgemm_to_subscription_per_instance[op.iloc[j]] = []
kgemm_value = k_gemm.iloc[j]
subs_factor_value = subs_factor.iloc[j]
kgemm_to_subscription_per_instance[op.iloc[j]].append((kgemm_value, subs_factor_value))
except:
print("Cannot parse subscription factor data, skipping this part.")
gemm_m = df[0][valid_mask]
gemm_n = df[1][valid_mask]
gemm_k = df[2][valid_mask]
arithmetic_intensity = df[3][valid_mask]
data_type = df[4][valid_mask]
fixed_split_k_ops = df[5][valid_mask]
fixed_split_k_times = df[6][valid_mask]
fixed_split_k_tflops = df[7][valid_mask]
fixed_split_k_values = df[8][valid_mask]
# 9 - rank
# 10 - strategy
best_occupancy_split_k_ops = df[11][valid_mask]
best_occupancy_split_k_times = df[12][valid_mask]
best_occupancy_split_k_tflops = df[13][valid_mask]
best_occupancy_split_k_values = df[14][valid_mask]
# 15 - rank
# 16 - strategy
#17 - total number of candidate ops.
suffix = f"_{args.label}" if args.label else ""
@@ -649,34 +700,8 @@ def main():
non_standard_counts[val] = non_standard_counts.get(val, 0) + 1
plot_split_k_distribution(non_standard_counts, best_occupancy_split_k_count, args, suffix)
try:
valid_gemm_k_mask = (gemm_k != "N/A") & (~pd.isna(gemm_k))
gemm_k_values = gemm_k[valid_gemm_k_mask].astype(int)
subs_factor_values = best_occupancy_subs_factor[valid_gemm_k_mask].astype(int)
plot_subscription_factor(
gemm_k_values, subs_factor_values, args.output_dir, suffix)
plot_subscription_factor_per_instance(
kgemm_to_subscription_per_instance, args.output_dir, suffix)
for key in kgemm_to_subscription_per_instance:
if key.startswith("Device"):
gemm_k_values = []
subs_factor_values = []
for kgemm, subs_factor in kgemm_to_subscription_per_instance[key]:
if kgemm != "N/A" and not pd.isna(kgemm):
gemm_k_values.append(int(kgemm))
subs_factor_values.append(int(subs_factor))
plot_subscription_factor(
gemm_k_values, subs_factor_values, args.output_dir, suffix, key)
# Print the names of the different instances
print("Instances with subscription factor data:")
for instance in kgemm_to_subscription_per_instance.keys():
print(f" - {instance}: {len(kgemm_to_subscription_per_instance[instance])} data points")
except:
print("Cannot plot subscription factor data, skipping this part.")
plot_performance(fixed_split_k_tflops, best_occupancy_split_k_tflops, gemm_m, gemm_n, gemm_k, arithmetic_intensity, args.output_dir, suffix, op_name)
if __name__ == "__main__":
main()