Files
ResBench/evaluate/plot_pass.py
jultrishyyy ae1cc41f21 first commit
2025-02-20 20:38:50 +00:00

137 lines
4.3 KiB
Python

import json
import os
import re

import matplotlib.pyplot as plt
import pandas as pd
import seaborn as sns
# --- Utility Functions ---
def compute_module_pass(solution_list, k):
    """
    Return 1 if any of the first ``k`` solutions passed, otherwise 0.

    A solution counts as passing when its "pass" field, converted to a string,
    stripped and lowercased, equals "true". Converting with ``str`` first means
    a JSON boolean ``true`` (loaded as Python ``True``) is accepted and a
    ``null``/``false`` value is treated as a failure instead of raising
    AttributeError; string values behave exactly as before.

    Args:
        solution_list: sequence of solution dicts, each optionally carrying a
            "pass" key (a missing key counts as not passing).
        k: number of leading solutions to inspect; a k larger than the list
            simply inspects every solution.

    Returns:
        int: 1 if at least one of the first k solutions passed, else 0.
    """
    return int(any(
        str(sol.get("pass", "")).strip().lower() == "true"
        for sol in solution_list[:k]
    ))
def compute_pass_at_k_for_modules(modules, k):
    """
    Return the fraction of ``modules`` that pass@k.

    Each module is expected to be a dict with a "solutions" list; it counts as
    passing when ``compute_module_pass`` reports a pass within its first k
    solutions. An empty ``modules`` list yields 0 instead of dividing by zero.
    """
    module_count = len(modules)
    if not module_count:
        return 0
    passing = 0
    for module in modules:
        passing += compute_module_pass(module["solutions"], k)
    return passing / module_count
def compute_overall_pass_at_k(llm_data, ks):
    """
    Compute pass@k for one LLM, pooling the modules of every category.

    Args:
        llm_data: mapping of category name -> list of modules.
        ks: the k values to evaluate.

    Returns:
        dict mapping each k in ``ks`` to the pass@k fraction over all modules.
    """
    pooled_modules = [
        module
        for category_modules in llm_data.values()
        for module in category_modules
    ]
    return {k: compute_pass_at_k_for_modules(pooled_modules, k) for k in ks}
def compute_category_pass_at_k(llm_data, ks):
    """
    Compute pass@k separately for each category of one LLM.

    Args:
        llm_data: mapping of category name -> list of modules.
        ks: the k values to evaluate.

    Returns:
        dict mapping each category name to a dict of k -> pass@k.
    """
    return {
        category: {
            k: compute_pass_at_k_for_modules(category_modules, k) for k in ks
        }
        for category, category_modules in llm_data.items()
    }
# --- Main processing and plotting ---
# k values to evaluate pass@k for: 1 through 15 inclusive.
ks = list(range(1, 16))
# Load the JSON file (adjust filename if necessary). JSON text is UTF-8, so the
# encoding is given explicitly instead of relying on the platform default.
input_json_file = "solutions.json"
with open(input_json_file, "r", encoding="utf-8") as f:
    data = json.load(f)
# Per-LLM results:
#   {llm: {"overall": {k: pass@k}, "categories": {category: {k: pass@k}}}}
llm_results = {
    llm: {
        "overall": compute_overall_pass_at_k(llm_data, ks),
        "categories": compute_category_pass_at_k(llm_data, ks),
    }
    for llm, llm_data in data.items()
}
# --- Plot Overall Pass@k for each LLM ---
# One line per LLM showing how the pooled pass@k grows with k; the figure is
# saved to ./figures/overall_pass_at_k.png and then closed.
os.makedirs("./figures", exist_ok=True)  # savefig does not create directories
plt.figure(figsize=(10, 6))
for llm, res in llm_results.items():
    plt.plot(ks, [res["overall"][k] for k in ks], marker='o', label=llm)
plt.xticks(ks)  # label every evaluated k on the x-axis
plt.xlabel("k", fontsize=14)
plt.ylabel("Overall Pass@k", fontsize=14)
plt.title("Overall Pass@k across k for each LLM", fontsize=16)
plt.legend(loc="upper left", bbox_to_anchor=(1, 1))  # legend outside the axes
plt.grid(True)
plt.tight_layout()
# bbox_inches="tight" keeps the outside-the-axes legend from being clipped.
plt.savefig("./figures/overall_pass_at_k.png", bbox_inches="tight")
plt.close()
# --- Plot Per-Category Pass@k for all LLMs, one heatmap per k ---
os.makedirs("./figures", exist_ok=True)  # savefig does not create directories
for k in ks:
    # Gather pass@k for THIS k as category -> {llm -> value}. Categories are the
    # union across LLMs; a category an LLM lacks becomes NaN (a blank cell).
    category_pass_k = {}
    for llm, res in llm_results.items():
        for cat, kdict in res["categories"].items():
            category_pass_k.setdefault(cat, {})[llm] = kdict[k]
    # Rows = categories, columns = LLMs.
    df_heatmap = pd.DataFrame.from_dict(category_pass_k, orient="index")
    # Plot heatmap
    plt.figure(figsize=(10, 6))
    sns.heatmap(df_heatmap, annot=True, cmap="Blues", linewidths=0.5, fmt=".2f")
    plt.title(f"Pass@{k} Heatmap for Each LLM Across Categories", fontsize=16, fontweight="bold")
    plt.xlabel("LLM", fontsize=14, fontweight="bold")
    plt.ylabel("Category", fontsize=14, fontweight="bold")
    plt.xticks(rotation=45, ha="right", fontsize=12)
    plt.yticks(fontsize=12)
    plt.tight_layout()
    heatmap_path = f"./figures/per_category_pass_k{k}_heatmap.png"
    plt.savefig(heatmap_path)
    plt.close()  # one figure per k; release each once saved
# --- (Optional) Print the computed results ---
# First a one-line overall summary per LLM, then each LLM's per-category table.
print("Overall Pass@k per LLM:")
for model_name, model_res in llm_results.items():
    print(f"{model_name}: {model_res['overall']}")
print("\nPer-Category Pass@k per LLM:")
for model_name, model_res in llm_results.items():
    print(f"{model_name}:")
    for category, per_k in model_res["categories"].items():
        print(f" {category}: {per_k}")