| |
| |
|
|
| import os |
| import glob |
| import json |
| import re |
| import matplotlib.pyplot as plt |
|
|
| |
# Directory scanned for valid_score_{in,ood}_<step>.json files; main() also
# writes the resulting valid_accuracy.png plot into this directory.
BASE_DIR: str = "/pfs/lichenyi/work/evaluation"
|
|
def collect_accuracies(base_dir: str):
    """Collect per-checkpoint validation accuracies from ``base_dir``.

    Scans ``base_dir`` for files named exactly ``valid_score_in_<step>.json``
    and ``valid_score_ood_<step>.json`` and reads ``summary.accuracy`` from
    each one.

    Files are skipped when:
      * the name does not match the pattern exactly,
      * the JSON cannot be decoded (a message is printed), or
      * there is no ``summary.accuracy`` value (including a missing/null or
        non-object ``summary``).

    Args:
        base_dir: Directory containing the ``valid_score_*.json`` files. It is
            used literally, so glob metacharacters such as ``[`` in the path
            are not treated as wildcards.

    Returns:
        A tuple ``(in_acc, ood_acc)`` of dicts mapping the integer step parsed
        from the filename to the accuracy read from that file, for the
        in-domain ("in") and out-of-domain ("ood") splits respectively.
    """
    # Escape the directory so characters like "[" or "*" in the path are not
    # interpreted as glob patterns (which would silently match nothing).
    pattern = os.path.join(glob.escape(base_dir), "valid_score_*.json")
    regex = re.compile(r"valid_score_(in|ood)_(\d+)\.json")

    in_acc = {}
    ood_acc = {}

    for path in sorted(glob.glob(pattern)):
        # fullmatch (not match): the glob also yields names such as
        # "valid_score_in_100.json.json", which must not be read as step 100.
        m = regex.fullmatch(os.path.basename(path))
        if not m:
            continue
        split, step = m.group(1), int(m.group(2))

        try:
            with open(path, "r", encoding="utf-8") as f:
                data = json.load(f)
        except (json.JSONDecodeError, UnicodeDecodeError) as err:
            # e.g. a score file that is still being written: report it and
            # keep going rather than aborting the whole collection.
            print(f"skipping unparseable score file {path}: {err}")
            continue

        # Tolerate a non-object top level or a missing/null/non-object
        # "summary" instead of raising AttributeError.
        summary = data.get("summary") if isinstance(data, dict) else None
        acc = summary.get("accuracy") if isinstance(summary, dict) else None
        if acc is None:
            continue

        target = in_acc if split == "in" else ood_acc
        target[step] = acc

    return in_acc, ood_acc
|
|
|
|
def plot_accuracies(in_acc, ood_acc, out_path="valid_accuracy.png"):
    """Plot in-domain and OOD validation accuracy per step and save the image.

    Args:
        in_acc: Mapping ``step -> accuracy`` for the in-domain split
            (dict[int, float]). Drawn as a solid line with circle markers,
            in ascending step order; omitted when empty.
        ood_acc: Mapping ``step -> accuracy`` for the out-of-domain split
            (dict[int, float]). Drawn as a dashed line with square markers,
            in ascending step order; omitted when empty.
        out_path: Destination path of the saved figure (written at 300 dpi).
    """
    fig, ax = plt.subplots(figsize=(8, 5))
    try:
        if in_acc:
            steps_in = sorted(in_acc)
            vals_in = [in_acc[s] for s in steps_in]
            ax.plot(steps_in, vals_in, marker="o", label="in (ID)")

        if ood_acc:
            steps_ood = sorted(ood_acc)
            vals_ood = [ood_acc[s] for s in steps_ood]
            ax.plot(steps_ood, vals_ood, marker="s", linestyle="--", label="ood (OOD)")

        ax.set_xlabel("checkpoint / step")
        ax.set_ylabel("accuracy")
        ax.set_title("Validation Accuracy (in vs ood)")
        ax.grid(True, linestyle=":")
        # Only request a legend when a labelled line exists; otherwise
        # matplotlib warns that no artists with labels were found.
        if in_acc or ood_acc:
            ax.legend()
        fig.tight_layout()
        fig.savefig(out_path, dpi=300)
    finally:
        # Release the figure so repeated calls do not accumulate open figures.
        plt.close(fig)
| |
| |
|
|
|
|
def main():
    """Collect accuracies from BASE_DIR, print them, and save the plot there."""
    in_domain, out_of_domain = collect_accuracies(BASE_DIR)

    for header, results in (
        ("in-domain checkpoints and accuracies:", in_domain),
        ("ood checkpoints and accuracies:", out_of_domain),
    ):
        print(header, results)

    figure_path = os.path.join(BASE_DIR, "valid_accuracy.png")
    plot_accuracies(in_domain, out_of_domain, out_path=figure_path)


if __name__ == "__main__":
    main()
|
|