first commit
This commit is contained in:
136
evaluate/plot_pass.py
Normal file
136
evaluate/plot_pass.py
Normal file
@@ -0,0 +1,136 @@
|
||||
import json
import os
import re

import matplotlib.pyplot as plt
import pandas as pd
import seaborn as sns
|
||||
|
||||
# --- Utility Functions ---
|
||||
|
||||
def compute_module_pass(solution_list, k):
    """
    Report whether any of the first k solutions for a module passed.

    Args:
        solution_list: List of solution dicts; each may carry a "pass" entry.
        k: Number of leading solutions to inspect.

    Returns:
        1 if at least one of the first k solutions has a "pass" value whose
        text form (after stripping and lowercasing) equals "true",
        otherwise 0.
    """
    # str() normalises non-string values first: a JSON boolean (True/False)
    # or null (None) would otherwise raise AttributeError on .strip().
    return int(
        any(
            str(sol.get("pass", "")).strip().lower() == "true"
            for sol in solution_list[:k]
        )
    )
|
||||
|
||||
def compute_pass_at_k_for_modules(modules, k):
    """
    Compute the fraction of modules that pass@k.

    Each module is expected to be a mapping with a "solutions" list; a module
    counts as passing when compute_module_pass reports 1 for its first k
    solutions.

    Returns 0 when the list of modules is empty, otherwise
    (number of passing modules) / (number of modules).
    """
    module_count = len(modules)
    if module_count == 0:
        # Nothing to evaluate; avoid dividing by zero.
        return 0

    hits = 0
    for module in modules:
        hits += compute_module_pass(module["solutions"], k)
    return hits / module_count
|
||||
|
||||
def compute_overall_pass_at_k(llm_data, ks):
    """
    Compute pass@k over every module of one LLM, ignoring category boundaries.

    llm_data maps category names to lists of modules; all of those lists are
    pooled into a single list before scoring.

    Returns a dictionary mapping each k in ks to the pooled pass@k value.
    """
    # Flatten the per-category module lists into one list.
    pooled_modules = [
        module
        for category_modules in llm_data.values()
        for module in category_modules
    ]
    return {k: compute_pass_at_k_for_modules(pooled_modules, k) for k in ks}
|
||||
|
||||
def compute_category_pass_at_k(llm_data, ks):
    """
    Compute pass@k separately for each category (type) of one LLM.

    llm_data maps category names to lists of modules.

    Returns a dictionary mapping each category name to a dictionary of
    k -> pass@k for that category's modules.
    """
    return {
        category: {
            k: compute_pass_at_k_for_modules(category_modules, k) for k in ks
        }
        for category, category_modules in llm_data.items()
    }
|
||||
|
||||
# --- Main processing and plotting ---
|
||||
|
||||
# Choose the k values you want to evaluate pass@k for (1 through 15).
ks = list(range(1, 16))

# Load the JSON file. Its top level maps each LLM name to that LLM's data,
# which the compute_* helpers treat as a mapping of category -> modules.
input_json_file = "solutions.json"  # adjust filename if necessary
# Read as UTF-8 explicitly rather than relying on the platform's default
# text encoding, which can mis-decode non-ASCII content.
with open(input_json_file, "r", encoding="utf-8") as f:
    data = json.load(f)

# Computed pass@k results per LLM:
#   llm_results[llm]["overall"]    -> {k: pass@k over all modules}
#   llm_results[llm]["categories"] -> {category: {k: pass@k}}
llm_results = {}
for llm, llm_data in data.items():
    overall = compute_overall_pass_at_k(llm_data, ks)
    categories = compute_category_pass_at_k(llm_data, ks)
    llm_results[llm] = {
        "overall": overall,
        "categories": categories,
    }
|
||||
|
||||
# --- Plot Overall Pass@k for each LLM ---
# One line per LLM: x = k, y = overall pass@k.
plt.figure(figsize=(10, 6))
for llm in llm_results:
    overall_scores = [llm_results[llm]["overall"][k] for k in ks]
    plt.plot(ks, overall_scores, marker='o', label=llm)

# NOTE: the finishing steps below (axes, legend, save, show) are currently
# disabled, so this figure is drawn but never written to disk or displayed.
# plt.xticks(ks)  # Ensure all values from 1 to 15 are shown
# plt.xlabel("k", fontsize=14)
# plt.ylabel("Overall Pass@k", fontsize=14)
# plt.title("Overall Pass@k across k for each LLM", fontsize=16)  # Larger title
# plt.legend(loc="upper left", bbox_to_anchor=(1, 1))  # Legend outside the plot
# plt.grid(True)
# plt.tight_layout()
# plt.savefig("./figures/overall_pass_at_k.png")
# plt.show()
|
||||
|
||||
|
||||
# --- Plot Per-Category Pass@k for all LLMs, one figure per k ---
# Each figure is a heatmap with categories as rows and LLMs as columns.
# A category that an LLM does not have simply ends up as a missing cell.

# savefig does not create missing directories, so make sure the target exists.
os.makedirs("./figures", exist_ok=True)

for k in ks:
    # Gather this k's values as category -> {llm -> pass@k}. This is rebuilt
    # per k so that every saved figure shows the k named in its file name
    # (previously every file contained the Pass@15 values).
    category_pass_k = {}
    for llm, res in llm_results.items():
        for cat, kdict in res["categories"].items():
            category_pass_k.setdefault(cat, {})[llm] = kdict[k]

    # Convert to DataFrame; transpose so rows are categories, columns are LLMs.
    df_heatmap = pd.DataFrame.from_dict(category_pass_k).T

    # Plot heatmap
    fig = plt.figure(figsize=(10, 6))
    sns.heatmap(df_heatmap, annot=True, cmap="Blues", linewidths=0.5, fmt=".2f")

    plt.title(
        f"Pass@{k} Heatmap for Each LLM Across Categories",
        fontsize=16,
        fontweight="bold",
    )
    plt.xlabel("LLM", fontsize=14, fontweight="bold")
    plt.ylabel("Category", fontsize=14, fontweight="bold")

    plt.xticks(rotation=45, ha="right", fontsize=12)
    plt.yticks(fontsize=12)

    plt.tight_layout()
    heatmap_path = f"./figures/per_category_pass_k{k}_heatmap.png"
    plt.savefig(heatmap_path)
    # Release the figure; otherwise one open figure accumulates per k.
    plt.close(fig)
|
||||
|
||||
# --- (Optional) Print the computed results ---
# First the overall k -> pass@k mapping of every LLM ...
print("Overall Pass@k per LLM:")
for llm in llm_results:
    res = llm_results[llm]
    print(f"{llm}: {res['overall']}")

# ... then, per LLM, the k -> pass@k mapping of each of its categories.
print("\nPer-Category Pass@k per LLM:")
for llm in llm_results:
    print(f"{llm}:")
    per_category = llm_results[llm]["categories"]
    for cat in per_category:
        kdict = per_category[cat]
        print(f" {cat}: {kdict}")
|
||||
Reference in New Issue
Block a user