| """Inspect your whole dataset, either unfiltered or by id.""" |
| import streamlit as st |
|
|
| from src.subpages.page import Context, Page |
| from src.utils import aggrid_interactive_table, colorize_classes |
|
|
|
|
class InspectPage(Page):
    """Page for browsing the token-level dataframe, whole or one example at a time."""

    name = "Inspect"
    icon = "search"

    def render(self, context: Context) -> None:
        """Render the Inspect page.

        Shows ``context.df_tokens`` restricted to a fixed set of columns. With
        "Filter by id" checked (the default) only the rows of one example,
        chosen in a selectbox, are shown; otherwise the whole table is shown
        in an interactive grid.
        """
        st.title(self.name)
        with st.expander("💡", expanded=True):
            st.write("Inspect your whole dataset, either unfiltered or by id.")

        df = context.df_tokens
        cols = (
            "ids input_ids token_type_ids word_ids losses tokens labels preds total_loss".split()
        )
        if "token_type_ids" not in df.columns:
            cols.remove("token_type_ids")
        # Projecting onto ``cols`` already excludes bulky columns such as
        # ``hidden_states`` / ``attention_mask``; explicitly dropping them first
        # was redundant and raised KeyError whenever they were absent.
        df = df[cols]

        if st.checkbox("Filter by id", value=True):
            ids = sorted(map(int, df.ids.unique()))
            # ``next_id`` is used as a position in ``ids``; a stale or
            # out-of-range value from session state would make selectbox fail.
            next_id = st.session_state.get("next_id", 0)
            if not 0 <= next_id < len(ids):
                next_id = 0

            example_id = st.selectbox("Select an example", ids, index=next_id)
            # Compare with the same int conversion used to build the options so
            # the filter works whether ids are stored as ints or strings.
            # The [1:-1] slice drops the first and last row of the example —
            # NOTE(review): presumably special tokens (e.g. [CLS]/[SEP]); confirm.
            df = df[df.ids.astype(int) == example_id][1:-1]

            st.dataframe(colorize_classes(df.round(3).astype(str)))
        else:
            aggrid_interactive_table(df.round(3))
|
|