| import matplotlib as mpl |
|
|
| mpl.use("Agg") |
| import argparse |
| import os |
| import pandas as pd |
| import seaborn as sns |
| import matplotlib.pyplot as plt |
| import matplotlib |
| import IPython |
|
|
# Global matplotlib font settings applied before any figure is created.
font = dict(size=22)
matplotlib.rc("font", **font)

# Seaborn "paper" context with its font sizes scaled by 2.0.
# NOTE(review): set_context also sets font sizes, so it most likely
# overrides the size=22 configured just above -- confirm which is intended.
sns.set_context("paper", font_scale=2.0)
|
|
|
|
def mkdir_if_missing(dst_dir):
    """Create ``dst_dir`` (including missing parents) unless it already exists.

    A single ``os.makedirs(..., exist_ok=True)`` call is used instead of an
    ``exists()`` check followed by ``makedirs()``. That avoids the race in which
    the directory appears between the check and the creation, which would
    otherwise raise ``FileExistsError``.

    Args:
        dst_dir: Path of the directory to create.

    Raises:
        FileExistsError: If ``dst_dir`` exists but is not a directory.
    """
    os.makedirs(dst_dir, exist_ok=True)
|
|
|
|
def save_figure(name, title="", out_root="output/output_figures", max_name_len=30):
    """Save the current pyplot figure to ``<out_root>/<name[:max_name_len]>/output.png``.

    The function first sets the title (if one is given) and applies
    ``tight_layout``. It then prints the output directory, creates it if
    needed, writes the PNG, and finally clears the current figure.

    Args:
        name: Identifier used as the output sub-directory name. Only its first
            ``max_name_len`` characters are kept.
        title: Title set on the current axes. Skipped when empty or None.
        out_root: Directory under which the per-figure sub-directory is created.
        max_name_len: Maximum number of characters of ``name`` used for the
            sub-directory name.
    """
    # Truthiness check also tolerates title=None (len(None) would raise).
    if title:
        plt.title(title)
    plt.tight_layout()
    # Build the directory once so the printed, created and written paths
    # cannot drift apart. The name is truncated, presumably to keep directory
    # names short. Distinct names that share this prefix will write to the
    # same directory and overwrite each other's output.png.
    out_dir = f"{out_root}/{name[:max_name_len]}"
    print(out_dir)
    mkdir_if_missing(out_dir)
    plt.savefig(f"{out_dir}/output.png")
    plt.clf()
|
|
|
|
def _load_eval_results(multirun_out, stats_root="output/output_stats"):
    """Load ``<stats_root>/<run>/eval_results.csv`` for every run in ``multirun_out``.

    Args:
        multirun_out: Comma-separated run directory names. Surrounding
            whitespace is stripped and empty entries are ignored.
        stats_root: Directory that contains the run directories.

    Returns:
        A list of DataFrames, one per run whose CSV exists, in sorted run-name
        order. Runs without a CSV are reported on stdout and skipped.
    """
    # Empty entries (e.g. from a trailing comma) would otherwise point at
    # <stats_root>/eval_results.csv itself.
    run_dirs = sorted(
        entry.strip() for entry in multirun_out.split(",") if entry.strip()
    )
    dfs = []
    for run_dir in run_dirs:
        stats_path = os.path.join(stats_root, run_dir, "eval_results.csv")
        if os.path.exists(stats_path):
            dfs.append(pd.read_csv(stats_path))
        else:
            print("skip:", stats_path)
    return dfs


def main(multirun_out, title):
    """Plot ``success`` per ``metric`` and ``model`` across several evaluation runs.

    Each comma-separated entry of ``multirun_out`` names a run directory under
    ``output/output_stats``. Its ``eval_results.csv`` is loaded when present.
    All rows are concatenated and drawn as a grouped seaborn bar plot with
    error bars of one standard deviation. The plot is saved through
    ``save_figure``.

    Args:
        multirun_out: Comma-separated run directory names, e.g. ``"run_a,run_b"``.
        title: Plot title. The number of runs actually loaded is appended.

    Raises:
        ValueError: If none of the listed runs has an ``eval_results.csv``.
    """
    dfs = _load_eval_results(multirun_out)
    if not dfs:
        # pd.concat([]) would fail with an opaque "No objects to concatenate".
        raise ValueError(f"No eval_results.csv found for runs: {multirun_out!r}")
    # ignore_index gives the combined frame unique row labels instead of
    # repeating each run's 0..n-1 index.
    df = pd.concat(dfs, ignore_index=True)
    print(df)
    title += f" run: {len(dfs)} "

    fig, ax = plt.subplots(figsize=(16, 8))
    sns.barplot(
        data=df,
        x="metric",
        y="success",
        hue="model",
        errorbar=("sd", 1),  # error bars span +/- 1 standard deviation
        palette="deep",
        ax=ax,
    )

    # Print each bar's value (2 decimals) centred inside the bar.
    for container in ax.containers:
        ax.bar_label(container, label_type="center", fontsize="x-large", fmt="%.2f")

    save_figure(f"{multirun_out}_{title}", title)
    # save_figure only clears the figure; close it to release its resources.
    plt.close(fig)
|
|
if __name__ == "__main__":
    parser = argparse.ArgumentParser(
        description="Plot success per metric and model from eval_results.csv files."
    )
    # Required: main() calls .split(",") on this value, so a missing flag
    # would otherwise surface as an AttributeError on None, not a usage error.
    parser.add_argument(
        "--multirun_out",
        type=str,
        required=True,
        help="Comma-separated run directory names under output/output_stats.",
    )
    parser.add_argument("--title", type=str, default="", help="Plot title prefix.")

    args = parser.parse_args()
    main(args.multirun_out, args.title)
|
|