| """ |
| Baseball statistics application with txtai and Streamlit. |
| |
| Install txtai and streamlit (>= 1.23) to run: |
| pip install txtai streamlit |
| """ |
|
|
| import datetime |
| import math |
| import os |
| import random |
|
|
| import altair as alt |
| import numpy as np |
| import pandas as pd |
| import streamlit as st |
|
|
| from txtai import Embeddings |
|
|
|
|
| class Stats: |
| """ |
| Base stats class. Contains methods for loading, indexing and searching baseball stats. |
| """ |
|
|
| def __init__(self): |
| """ |
| Creates a new Stats instance. |
| """ |
|
|
| |
| self.columns = self.loadcolumns() |
|
|
| |
| self.stats = self.load() |
|
|
| |
| self.names = self.loadnames() |
|
|
| |
| self.vectors, self.data, self.maxyear, self.embeddings = self.index() |
|
|
| def loadcolumns(self): |
| """ |
| Returns a list of data columns. |
| |
| Returns: |
| list of columns |
| """ |
|
|
| raise NotImplementedError |
|
|
| def load(self): |
| """ |
| Loads and returns raw stats. |
| |
| Returns: |
| stats |
| """ |
|
|
| raise NotImplementedError |
|
|
| def metric(self): |
| """ |
| Primary metric column. |
| |
| Returns: |
| metric column name |
| """ |
|
|
| raise NotImplementedError |
|
|
| def vector(self, row): |
| """ |
| Build a vector for input row. |
| |
| Args: |
| row: input row |
| |
| Returns: |
| row vector |
| """ |
|
|
| raise NotImplementedError |
|
|
| def loadnames(self): |
| """ |
| Loads a name - player id dictionary. |
| |
| Returns: |
| {player name: player id} |
| """ |
|
|
| |
| names = {} |
| rows = self.stats.sort_values(by=self.metric(), ascending=False)[["nameFirst", "nameLast", "playerID"]].drop_duplicates().reset_index() |
| for x, row in rows.iterrows(): |
| |
| key = f"{row['nameFirst']} {row['nameLast']}" |
| key += f" ({row['playerID']})" if key in names else "" |
|
|
| if key not in names: |
| |
| exponent = 2 if ((len(rows) - x) / len(rows)) >= 0.95 else 1 |
|
|
| |
| score = math.pow(len(self.stats[self.stats["playerID"] == row["playerID"]]), exponent) |
|
|
| |
| names[key] = (row["playerID"], score) |
|
|
| return names |
|
|
| def index(self): |
| """ |
| Builds an embeddings index to stats data. Returns vectors, input data and embeddings index. |
| |
| Returns: |
| vectors, data, embeddings |
| """ |
|
|
| |
| vectors = {f'{row["yearID"]}{row["playerID"]}': self.transform(row) for _, row in self.stats.iterrows()} |
| data = {f'{row["yearID"]}{row["playerID"]}': dict(row) for _, row in self.stats.iterrows()} |
| maxyear = max(row["yearID"] for _, row in self.stats.iterrows()) |
|
|
| embeddings = Embeddings({"transform": Stats.transform}) |
| embeddings.index((uid, vectors[uid], None) for uid in vectors) |
|
|
| return vectors, data, maxyear, embeddings |
|
|
| def metrics(self, name): |
| """ |
| Looks up a player's active years, best statistical year and key metrics. |
| |
| Args: |
| name: player name |
| |
| Returns: |
| active, best, metrics |
| """ |
|
|
| if name in self.names: |
| |
| stats = self.stats[self.stats["playerID"] == self.names[name][0]] |
|
|
| |
| metrics = stats[["yearID", self.metric()]] |
|
|
| |
| best = int(stats.sort_values(by=self.metric(), ascending=False)["yearID"].iloc[0]) |
|
|
| |
| return metrics["yearID"].tolist(), best, metrics |
|
|
| return range(1871, datetime.datetime.today().year), 1950, None |
|
|
| def search(self, name=None, year=None, window=None, row=None, limit=10): |
| """ |
| Runs an embeddings search. This method takes either a player-year or stats row as input. |
| |
| Args: |
| name: player name to search |
| year: year to search |
| window: limit to window recent seasons |
| row: row of stats to search |
| limit: max results to return |
| |
| Returns: |
| list of results |
| """ |
|
|
| if row: |
| query = self.vector(row) |
| else: |
| |
| name = self.names.get(name) |
| query = f"{year}{name[0] if name else name}" |
| query = self.vectors.get(query) |
|
|
| results, ids = [], set() |
| if query is not None: |
| candidates = limit * 100 if window else limit * 5 |
| for uid, _ in self.embeddings.search(query, candidates): |
| |
| if uid[4:] not in ids: |
| result = self.data[uid].copy() |
|
|
| |
| if (not ids and not row) or not window or result["yearID"] > self.maxyear - window: |
| result["link"] = f'https://www.baseball-reference.com/players/{result["nameLast"].lower()[0]}/{result["bbrefID"]}.shtml' |
| results.append(result) |
| ids.add(uid[4:]) |
|
|
| if len(ids) >= limit: |
| break |
|
|
| return results |
|
|
| def transform(self, row): |
| """ |
| Transforms a stats row into a vector. |
| |
| Args: |
| row: stats row |
| |
| Returns: |
| vector |
| """ |
|
|
| if isinstance(row, np.ndarray): |
| return row |
|
|
| return np.array([0.0 if not row[x] or np.isnan(row[x]) else row[x] for x in self.columns]) |
|
|
|
|
| class Batting(Stats): |
| """ |
| Batting stats. |
| """ |
|
|
| def loadcolumns(self): |
| return [ |
| "birthMonth", |
| "yearID", |
| "age", |
| "height", |
| "weight", |
| "G", |
| "AB", |
| "R", |
| "H", |
| "1B", |
| "2B", |
| "3B", |
| "HR", |
| "RBI", |
| "SB", |
| "CS", |
| "BB", |
| "SO", |
| "IBB", |
| "HBP", |
| "SH", |
| "SF", |
| "GIDP", |
| "POS", |
| "AVG", |
| "OBP", |
| "TB", |
| "SLG", |
| "OPS", |
| "OPS+", |
| ] |
|
|
| def load(self): |
| |
| players = pd.read_csv("https://hf.co/datasets/neuml/baseballdata/resolve/main/People.csv") |
| batting = pd.read_csv("https://hf.co/datasets/neuml/baseballdata/resolve/main/Batting.csv") |
| fielding = pd.read_csv("https://hf.co/datasets/neuml/baseballdata/resolve/main/Fielding.csv") |
|
|
| |
| batting = pd.merge(players, batting, how="inner", on=["playerID"]) |
|
|
| |
| batting = batting[((batting["AB"] + batting["BB"]) >= 350) & (batting["stint"] == 1)] |
|
|
| |
| positions = self.positions(fielding) |
|
|
| |
| batting["age"] = batting["yearID"] - batting["birthYear"] |
| batting["POS"] = batting.apply(lambda row: self.position(positions, row), axis=1) |
| batting["AVG"] = batting["H"] / batting["AB"] |
| batting["OBP"] = (batting["H"] + batting["BB"]) / (batting["AB"] + batting["BB"]) |
| batting["1B"] = batting["H"] - batting["2B"] - batting["3B"] - batting["HR"] |
| batting["TB"] = batting["1B"] + 2 * batting["2B"] + 3 * batting["3B"] + 4 * batting["HR"] |
| batting["SLG"] = batting["TB"] / batting["AB"] |
| batting["OPS"] = batting["OBP"] + batting["SLG"] |
| batting["OPS+"] = 100 + (batting["OPS"] - batting["OPS"].mean()) * 100 |
|
|
| return batting |
|
|
| def metric(self): |
| return "OPS+" |
|
|
| def vector(self, row): |
| row["TB"] = row["1B"] + 2 * row["2B"] + 3 * row["3B"] + 4 * row["HR"] |
| row["AVG"] = row["H"] / row["AB"] |
| row["OBP"] = (row["H"] + row["BB"]) / (row["AB"] + row["BB"]) |
| row["SLG"] = row["TB"] / row["AB"] |
| row["OPS"] = row["OBP"] + row["SLG"] |
| row["OPS+"] = 100 + (row["OPS"] - self.stats["OPS"].mean()) * 100 |
|
|
| return self.transform(row) |
|
|
| def positions(self, fielding): |
| """ |
| Derives primary positions for players. |
| |
| Args: |
| fielding: fielding data |
| |
| Returns: |
| {player id: (position, number of games)} |
| """ |
|
|
| positions = {} |
| for _, row in fielding.iterrows(): |
| uid = f'{row["yearID"]}{row["playerID"]}' |
| position = row["POS"] if row["POS"] else 0 |
| if position == "P": |
| position = 1 |
| elif position == "C": |
| position = 2 |
| elif position == "1B": |
| position = 3 |
| elif position == "2B": |
| position = 4 |
| elif position == "3B": |
| position = 5 |
| elif position == "SS": |
| position = 6 |
| elif position == "OF": |
| position = 7 |
|
|
| |
| if uid not in positions or positions[uid][1] < row["G"]: |
| positions[uid] = (position, row["G"]) |
|
|
| return positions |
|
|
| def position(self, positions, row): |
| """ |
| Looks up primary position for player row. |
| |
| Arg: |
| positions: all player positions |
| row: player row |
| |
| Returns: |
| primary player positions |
| """ |
|
|
| uid = f'{row["yearID"]}{row["playerID"]}' |
| return positions[uid][0] if uid in positions else 0 |
|
|
|
|
| class Pitching(Stats): |
| """ |
| Pitching stats. |
| """ |
|
|
| def loadcolumns(self): |
| return [ |
| "birthMonth", |
| "yearID", |
| "age", |
| "height", |
| "weight", |
| "W", |
| "L", |
| "G", |
| "GS", |
| "CG", |
| "SHO", |
| "SV", |
| "IPouts", |
| "H", |
| "ER", |
| "HR", |
| "BB", |
| "SO", |
| "BAOpp", |
| "ERA", |
| "IBB", |
| "WP", |
| "HBP", |
| "BK", |
| "BFP", |
| "GF", |
| "R", |
| "SH", |
| "SF", |
| "GIDP", |
| "WHIP", |
| "WADJ", |
| ] |
|
|
| def load(self): |
| |
| players = pd.read_csv("https://hf.co/datasets/neuml/baseballdata/resolve/main/People.csv") |
| pitching = pd.read_csv("https://hf.co/datasets/neuml/baseballdata/resolve/main/Pitching.csv") |
|
|
| |
| pitching = pd.merge(players, pitching, how="inner", on=["playerID"]) |
|
|
| |
| pitching = pitching[(pitching["G"] >= 20) & (pitching["stint"] == 1)] |
|
|
| |
| pitching["age"] = pitching["yearID"] - pitching["birthYear"] |
| pitching["WHIP"] = (pitching["BB"] + pitching["H"]) / (pitching["IPouts"] / 3) |
| pitching["WADJ"] = (pitching["W"] + pitching["SV"]) / (pitching["ERA"] + pitching["WHIP"]) |
|
|
| return pitching |
|
|
| def metric(self): |
| return "WADJ" |
|
|
| def vector(self, row): |
| row["WHIP"] = (row["BB"] + row["H"]) / (row["IPouts"] / 3) if row["IPouts"] else None |
| row["WADJ"] = (row["W"] + row["SV"]) / (row["ERA"] + row["WHIP"]) if row["ERA"] and row["WHIP"] else None |
|
|
| return self.transform(row) |
|
|
|
|
| class Application: |
| """ |
| Main application. |
| """ |
|
|
| def __init__(self): |
| """ |
| Creates a new application. |
| """ |
|
|
| |
| self.batting = Batting() |
|
|
| |
| self.pitching = Pitching() |
|
|
| def run(self): |
| """ |
| Runs a Streamlit application. |
| """ |
|
|
| st.set_page_config( |
| layout="wide", |
| page_title="Baseball Stats", |
| page_icon="⚾" |
| ) |
| |
| st.title("⚾ Baseball Stats") |
| st.markdown( |
| """ |
| This application finds the best matching players using vector search with [txtai](https://github.com/neuml/txtai). |
| Raw data is from the [Lahman Baseball Database](https://sabr.org/lahman-database/). Read [this |
| article](https://medium.com/neuml/explore-baseball-history-with-vector-search-5778d98d6846) for more details. |
| """ |
| ) |
|
|
| player, search = st.tabs(["Player", "Search"]) |
|
|
| |
| with player: |
| self.player() |
|
|
| |
| with search: |
| self.search() |
|
|
| def player(self): |
| """ |
| Player tab. |
| """ |
|
|
| st.markdown("Match by player-season. Each player search defaults to the best season sorted by OPS or Wins Adjusted.") |
|
|
| |
| params = self.params() |
|
|
| |
| category = self.category(params.get("category"), "category") |
| stats = self.batting if category == "Batting" else self.pitching |
|
|
| |
| name = self.name(stats.names, params.get("name")) |
|
|
| |
| window = self.window(params.get("window"), "window") |
|
|
| |
| active, best, metrics = stats.metrics(name) |
|
|
| |
| year = self.year(active, params.get("year"), best) |
|
|
| |
| if len(active) > 1: |
| self.chart(category, metrics) |
|
|
| |
| results = stats.search(name, year, window) |
|
|
| |
| self.table(results, ["link", "nameFirst", "nameLast", "teamID"] + stats.columns[1:]) |
|
|
| |
| st.query_params.clear() |
| for key, value in [("category", category), ("name", name), ("window", window), ("year", year)]: |
| if value: |
| st.query_params[key] = value |
|
|
| def search(self): |
| """ |
| Stats search tab. |
| """ |
|
|
| st.markdown("Find players with similar statistics.") |
|
|
| stats, category = None, self.category("Batting", "searchcategory") |
|
|
| |
| window = self.window(None, "searchwindow") |
|
|
| with st.form("search"): |
| if category == "Batting": |
| stats, columns = self.batting, self.batting.columns[:-6] |
| elif category == "Pitching": |
| stats, columns = self.pitching, self.pitching.columns[:-2] |
|
|
| |
| inputs = st.data_editor(pd.DataFrame([dict((column, None) for column in columns)]), hide_index=True).astype(float) |
|
|
| submitted = st.form_submit_button("Search") |
| if submitted: |
| |
| results = stats.search(window=window, row=inputs.to_dict(orient="records")[0]) |
|
|
| |
| self.table(results, ["link", "nameFirst", "nameLast", "teamID"] + stats.columns[1:]) |
|
|
| def params(self): |
| """ |
| Get application parameters. This method combines URL parameters with session parameters. |
| |
| Returns: |
| parameters |
| """ |
|
|
| |
| params = {x: st.query_params.get(x) for x in ["category", "name", "window", "year"]} |
|
|
| |
| if all(x in st.session_state for x in ["category", "name", "year"]): |
| |
| params["year"] = str(st.session_state["year"]) if all(params.get(x) == st.session_state[x] for x in ["category", "name"]) else None |
|
|
| |
| params["category"] = st.session_state["category"] |
| params["name"] = st.session_state["name"] |
| params["window"] = st.session_state["window"] |
|
|
| return params |
|
|
| def category(self, category, key): |
| """ |
| Builds category input widget. |
| |
| Args: |
| category: category parameter |
| key: widget key |
| |
| Returns: |
| category component |
| """ |
|
|
| |
| categories = ["Batting", "Pitching"] |
|
|
| |
| default = categories.index(category) if category and category in categories else 0 |
|
|
| |
| return st.radio("Stat", categories, index=default, horizontal=True, key=key) |
|
|
| def window(self, window, key): |
| """ |
| Limit results to last N seasons. |
| |
| Args: |
| window: limit to window seasons |
| key: widget key |
| |
| Returns: |
| window component |
| """ |
|
|
| |
| window = st.text_input("Limit to last N seasons", value=window, key=key) |
|
|
| |
| if window and (not window.isnumeric() or int(window) < 1): |
| st.error("Window must be a number greater or equal to 1") |
| return None |
|
|
| |
| return int(window) if window else window |
|
|
| def name(self, names, name): |
| """ |
| Builds name input widget. |
| |
| Args: |
| names: list of all allowable names |
| |
| Returns: |
| name component |
| """ |
|
|
| |
| name = name if name and name in names else random.choices(list(names.keys()), weights=[names[x][1] for x in names])[0] |
|
|
| |
| names = sorted(names) |
|
|
| |
| return st.selectbox("Name", names, names.index(name), key="name") |
|
|
| def year(self, years, year, best): |
| """ |
| Builds year input widget. |
| |
| Args: |
| years: active years for a player |
| year: year parameter |
| best: default to best year if year is invalid |
| |
| Returns: |
| year component |
| """ |
|
|
| |
| year = int(year) if year and year.isdigit() and int(year) in years else best |
|
|
| |
| return int(st.select_slider("Year", years, year, key="year") if len(years) > 1 else years[0]) |
|
|
| def chart(self, category, metrics): |
| """ |
| Displays a metric chart. |
| |
| Args: |
| category: Batting or Pitching |
| metrics: player metrics to plot |
| """ |
|
|
| |
| metric = self.batting.metric() if category == "Batting" else self.pitching.metric() |
|
|
| |
| metrics["yearID"] = metrics["yearID"].astype(str) |
|
|
| |
| chart = ( |
| alt.Chart(metrics) |
| .mark_line(interpolate="monotone", point=True, strokeWidth=2.5, opacity=0.75) |
| .encode(x=alt.X("yearID", title=""), y=alt.Y(metric, scale=alt.Scale(zero=False))) |
| ) |
|
|
| |
| rule = alt.Chart(metrics).mark_rule(color="gray", strokeDash=[3, 5], opacity=0.5).encode(y=f"median({metric})") |
|
|
| |
| chart = (chart + rule).encode(y=alt.Y(title=metric)).properties(height=200).configure_axis(grid=False) |
|
|
| |
| st.altair_chart(chart + rule, theme="streamlit", width="stretch") |
|
|
| def table(self, results, columns): |
| """ |
| Displays a list of results as a table. |
| |
| Args: |
| results: list of results |
| columns: column names |
| """ |
|
|
| if results: |
| st.dataframe( |
| results, |
| column_order=columns, |
| column_config={ |
| "link": st.column_config.LinkColumn("Link", width="small", display_text=":material/open_in_new:"), |
| "yearID": st.column_config.NumberColumn("Year", format="%d"), |
| "nameFirst": "First", |
| "nameLast": "Last", |
| "teamID": "Team", |
| "age": "Age", |
| "weight": "Weight", |
| "height": "Height", |
| }, |
| ) |
| else: |
| st.write("Player-Year not found") |
|
|
|
|
| @st.cache_resource(show_spinner=False) |
| def create(): |
| """ |
| Creates and caches a Streamlit application. |
| |
| Returns: |
| Application |
| """ |
|
|
| return Application() |
|
|
|
|
| if __name__ == "__main__": |
| os.environ["TOKENIZERS_PARALLELISM"] = "false" |
|
|
| |
| app = create() |
| app.run() |