| """Show count, mean and median loss per token and label.""" |
| import streamlit as st |
|
|
| from src.subpages.page import Context, Page |
| from src.utils import AgGrid, aggrid_interactive_table |
|
|
|
|
# NOTE(review): st.cache is deprecated in newer Streamlit releases in favour
# of st.cache_data; confirm the pinned Streamlit version before migrating.
@st.cache
def get_loss_by_token(df_tokens):
    """Summarise the loss distribution for each distinct token.

    Groups ``df_tokens`` by its ``"tokens"`` column and computes the count,
    mean, median and sum of the ``"losses"`` column within each group.

    Returns a DataFrame with columns ``tokens, count, mean, median, sum``,
    ordered by total loss (``sum``), largest first.
    """
    token_groups = df_tokens.groupby("tokens")[["losses"]]
    summary = token_groups.agg(["count", "mean", "median", "sum"])
    # Selecting [["losses"]] (a list) makes agg produce two-level columns
    # ("losses", <stat>); keep only the statistic names.
    summary = summary.droplevel(0, axis=1)
    summary = summary.sort_values("sum", ascending=False)
    return summary.reset_index()
|
|
|
|
# NOTE(review): st.cache is deprecated in newer Streamlit releases in favour
# of st.cache_data; confirm the pinned Streamlit version before migrating.
@st.cache
def get_loss_by_label(df_tokens):
    """Summarise the loss distribution for each label.

    Groups ``df_tokens`` by its ``"labels"`` column and computes the count,
    mean, median and sum of the ``"losses"`` column within each group.

    Returns a DataFrame with columns ``labels, count, mean, median, sum``,
    ordered by mean loss, highest first.
    """
    label_groups = df_tokens.groupby("labels")[["losses"]]
    summary = label_groups.agg(["count", "mean", "median", "sum"])
    # agg on a one-column list selection yields ("losses", <stat>) columns;
    # drop the outer level so the columns are just the statistic names.
    summary = summary.droplevel(0, axis=1)
    summary = summary.sort_values("mean", ascending=False)
    return summary.reset_index()
|
|
|
|
class LossesPage(Page):
    """Page showing count/mean/median/sum loss tables per token and per label."""

    # Page title; rendered via st.title in render().
    name = "Loss by Token/Label"
    # Icon identifier; presumably consumed by the page navigation — confirm in Page.
    icon = "sort-alpha-down"

    def render(self, context: Context):
        """Draw the token-level and label-level loss tables side by side.

        Args:
            context: Shared page context. This method reads
                ``context.df_tokens_merged`` and ``context.df_tokens_cleaned``
                (assumed to be DataFrames with ``tokens``/``labels``/``losses``
                columns, as the summary helpers require — TODO confirm).
        """
        st.title(self.name)

        with st.expander("💡", expanded=True):
            st.write("Show count, mean and median loss per token and label.")
            st.write(
                "Look out for tokens that have a big gap between mean and median, indicating systematic labeling issues."
            )

        # Two content columns separated by a narrow, unused spacer column.
        token_col, _spacer, label_col = st.columns([8, 1, 6])

        with token_col:
            st.subheader("💬 Loss by Token")

            # The widget value is stored under "merge_tokens" (its key) and is
            # also mirrored into "_merge_tokens" — presumably so the choice can
            # be read elsewhere when this widget is not rendered; confirm.
            st.session_state["_merge_tokens"] = st.checkbox(
                "Merge tokens", value=True, key="merge_tokens"
            )
            if st.session_state["merge_tokens"]:
                token_source = context.df_tokens_merged
            else:
                token_source = context.df_tokens_cleaned
            token_table = get_loss_by_token(token_source)
            aggrid_interactive_table(token_table.round(3))

            st.write(
                "_Caveat: Even though tokens have contextual representations, we average them to get these summary statistics._"
            )

        with label_col:
            st.subheader("🏷️ Loss by Label")
            label_table = get_loss_by_label(context.df_tokens_cleaned)
            AgGrid(label_table.round(3), height=200)
|
|