| import json |
| from pathlib import Path |
|
|
| import numpy as np |
| import pandas as pd |
| import pytest |
| import socceraction.spadl as spadl |
| import socceraction.xthreat as xt |
| from pandera.typing import DataFrame, Series |
| from pytest_mock import MockerFixture |
| from sklearn.exceptions import NotFittedError |
| from socceraction.spadl import SPADLSchema |
| from socceraction.spadl.config import field_length, field_width |
|
|
|
|
class TestGridCount:
    """Tests for counting the number of actions occurring in each grid cell.

    Grid cells are represented by 2D pitch coordinates. The (0,0) coordinate
    corresponds to the bottom left corner of the pitch. The 2D coordinates are
    mapped to a flat index. For a 2x2 grid, these flat indices are:
        0 1
        2 3
    """

    # Grid dimensions handed to every helper under test (a 2x2 grid).
    N = 2
    M = 2

    def test_get_cell_indexes(self) -> None:
        """It should map pitch coordinates to a 2D cell index."""
        xs = Series[float]([0, field_length / 2 - 1, field_length])
        ys = Series[float]([0, field_width / 2 + 1, field_width])

        x_cells, y_cells = xt._get_cell_indexes(xs, ys, self.N, self.M)

        expected_x_cells = pd.Series([0, 0, 1])
        expected_y_cells = pd.Series([0, 1, 1])
        pd.testing.assert_series_equal(x_cells, expected_x_cells)
        pd.testing.assert_series_equal(y_cells, expected_y_cells)

    def test_get_cell_indexes_out_of_bounds(self) -> None:
        """It should map out-of-bounds coordinates to the nearest cell index."""
        # One point beyond each extreme of the pitch, on both axes.
        xs = Series[float]([-10, field_length + 10])
        ys = Series[float]([-10, field_width + 10])

        x_cells, y_cells = xt._get_cell_indexes(xs, ys, self.N, self.M)

        expected_cells = pd.Series([0, 1])
        pd.testing.assert_series_equal(x_cells, expected_cells)
        pd.testing.assert_series_equal(y_cells, pd.Series([0, 1]))

    def test_get_flat_indexes(self) -> None:
        """It should map pitch coordinates to a flat index."""
        half_length = field_length / 2
        half_width = field_width / 2
        xs = Series[float]([0, half_length - 1, half_length + 1, field_length])
        ys = Series[float]([0, half_width + 1, half_width - 1, field_width])

        flat = xt._get_flat_indexes(xs, ys, self.N, self.M)

        # Expected order follows the layout in the class docstring.
        pd.testing.assert_series_equal(flat, pd.Series([2, 0, 3, 1]))

    def test_count(self) -> None:
        """It should return the number of occurrences in each grid cell."""
        # The last point lies outside the pitch on both axes.
        xs = Series[float]([0, field_length / 2 - 1, field_length, field_length + 10])
        ys = Series[float]([0, field_width / 2 + 1, field_width, field_width + 10])

        counts = xt._count(xs, ys, self.N, self.M)

        expected_counts = [[1, 2], [1, 0]]
        np.testing.assert_array_equal(counts, expected_counts)
|
|
|
|
class TestModelPersistency:
    """Tests for saving an xT grid to, and loading it from, a JSON file."""

    def test_save_model(self, tmp_path: Path) -> None:
        """It should save a trained xT grid to a JSON file."""
        target = tmp_path / "xt_model.json"
        fitted = xt.ExpectedThreat()
        # A grid of ones stands in for a trained surface.
        fitted.xT = np.ones((fitted.w, fitted.l))

        fitted.save_model(str(target))

        written = target.read_text()
        assert written == json.dumps(fitted.xT.tolist())

    def test_save_model_not_fitted(self, tmp_path: Path) -> None:
        """It should raise an exception when saving an unfitted model."""
        target = tmp_path / "xt_model.json"
        unfitted = xt.ExpectedThreat()

        # A freshly constructed model must be rejected ...
        with pytest.raises(NotFittedError):
            unfitted.save_model(str(target))

        # ... and so must a model whose grid is explicitly all zeros.
        unfitted.xT = np.zeros((unfitted.w, unfitted.l))
        with pytest.raises(NotFittedError):
            unfitted.save_model(str(target))

    def test_save_model_file_exists(self, tmp_path: Path) -> None:
        """It should raise an exception when the file exists."""
        # Create the destination up front so the save collides with it.
        target = tmp_path / "xt_model.json"
        target.write_text("create file")

        fitted = xt.ExpectedThreat()
        fitted.xT = np.ones((fitted.w, fitted.l))

        # Refusing to overwrite must raise ...
        with pytest.raises(ValueError):
            fitted.save_model(str(target), overwrite=False)

        # ... while an explicit overwrite must go through without an error.
        fitted.save_model(str(target), overwrite=True)

    def test_load_model(self, tmp_path: Path) -> None:
        """It should load a saved xT grid from a JSON file."""
        # A 2x2 grid serialized by hand, independent of save_model.
        stored_grid = [[0.1, 0.2], [0.1, 0.0]]
        source = tmp_path / "xt_model.json"
        source.write_text(json.dumps(stored_grid))

        loaded = xt.load_model(str(source))

        # The grid dimensions are derived from the stored data.
        assert loaded.w == 2
        assert loaded.l == 2
        np.testing.assert_array_equal(loaded.xT, stored_grid)
|
|
|
|
def test_get_move_actions(spadl_actions: DataFrame[SPADLSchema]) -> None:
    """It should filter passes, dribbles and crosses."""
    # Resolve the SPADL type ids of the action types counted as moves.
    move_type_ids = [
        spadl.config.actiontypes.index(type_name)
        for type_name in ("pass", "dribble", "cross")
    ]

    selected = xt.get_move_actions(spadl_actions)

    only_move_types = selected.type_id.isin(move_type_ids)
    assert only_move_types.all()
|
|
|
|
def test_get_successful_move_actions(spadl_actions: DataFrame[SPADLSchema]) -> None:
    """It should filter successful passes, dribbles and crosses."""
    # Resolve the SPADL type ids of the action types counted as moves.
    move_type_ids = [
        spadl.config.actiontypes.index(type_name)
        for type_name in ("pass", "dribble", "cross")
    ]

    selected = xt.get_successful_move_actions(spadl_actions)

    # Every remaining action must be a move ...
    assert selected.type_id.isin(move_type_ids).all()
    # ... and must carry the "success" result.
    succeeded = selected.result_id == spadl.config.results.index("success")
    assert succeeded.all()
|
|
|
|
def test_action_prob(spadl_actions: DataFrame[SPADLSchema]) -> None:
    """It should return the proportion of shots and moves for each cell.

    The probabilities are requested for a grid built from the arguments
    ``10`` and ``5``; both returned matrices are expected to have shape
    ``(5, 10)``.
    """
    shot_prob, move_prob = xt.action_prob(spadl_actions, 10, 5)
    assert shot_prob.shape == (5, 10)
    assert move_prob.shape == (5, 10)
    # At least one cell must contain shots and at least one must contain moves.
    assert np.any(shot_prob > 0)
    assert np.any(move_prob > 0)
    # In every cell the two proportions must sum to either 1 or 0 (presumably
    # 0 marks a cell without any actions -- confirm against the implementation).
    # Compare with a tolerance: the sum of two floating-point quotients may
    # differ from 1 by a rounding error, which an exact `==` would reject.
    total_prob = move_prob + shot_prob
    assert np.all(np.isclose(total_prob, 1) | np.isclose(total_prob, 0))
|
|
|
|
def test_scoring_prob(spadl_actions: DataFrame[SPADLSchema]) -> None:
    """It should return the proportion of successful shots for each cell."""
    shot_type_id = spadl.config.actiontypes.index("shot")
    success_result_id = spadl.config.results.index("success")

    # Boolean masks over the fixture: all shots, and the shots that succeeded.
    is_shot = spadl_actions.type_id == shot_type_id
    is_goal = is_shot & (spadl_actions.result_id == success_result_id)

    # With a single 1x1 cell, the whole data set falls into that one cell.
    grid = xt.scoring_prob(spadl_actions, 1, 1)

    assert grid.shape == (1, 1)
    expected_rate = sum(is_goal) / sum(is_shot)
    assert expected_rate == grid[0]
|
|
|
|
def test_move_transition_matrix() -> None:
    """It should return the move transition matrix."""
    pass_id = spadl.config.actiontypes.index("pass")
    success_id = spadl.config.results.index("success")

    def _stationary_pass(action_id: int, time_seconds: float) -> dict:
        """Build one successful pass that starts and ends at (10, 10)."""
        return {
            "game_id": 1,
            "original_event_id": "a",
            "action_id": action_id,
            "period_id": 1,
            "time_seconds": time_seconds,
            "team_id": 1,
            "player_id": 1,
            "start_x": 10.0,
            "end_x": 10.0,
            "start_y": 10.0,
            "end_y": 10.0,
            "bodypart_id": 1,
            "type_id": pass_id,
            "result_id": success_id,
        }

    # Two passes that differ only in their action id and timestamp.
    rows = [_stationary_pass(1, 1.0), _stationary_pass(2, 1.2)]
    spadl_actions = DataFrame[SPADLSchema](rows)

    transitions = xt.move_transition_matrix(spadl_actions, 2, 2)

    # A 2x2 grid has 4 cells, so the matrix is expected to be 4x4.
    assert np.sum(transitions) == 1
    assert transitions.shape == (4, 4)
    # Both passes stay within one cell, which the test expects at flat index 2.
    assert transitions[2, 2] == 1
|
|
|
|
def test_xt_model_init() -> None:
    """It should initialize all instance variables.

    The constructor arguments must be stored as given, the xT grid must sum
    to zero, every probability/transition matrix must still be ``None`` and
    no heatmaps may have been recorded yet.
    """
    xTModel = xt.ExpectedThreat(l=8, w=6, eps=1e-3)
    # Constructor arguments are stored unchanged.
    assert xTModel.l == 8
    assert xTModel.w == 6
    assert xTModel.eps == 1e-3
    # Nothing has been learned before fitting.
    assert np.sum(xTModel.xT) == 0
    # Each matrix is checked exactly once (the original repeated the
    # scoring_prob_matrix assertion verbatim).
    assert xTModel.scoring_prob_matrix is None
    assert xTModel.shot_prob_matrix is None
    assert xTModel.move_prob_matrix is None
    assert xTModel.transition_matrix is None
    assert len(xTModel.heatmaps) == 0
|
|
|
|
def test_xt_model_fit(spadl_actions: DataFrame[SPADLSchema]) -> None:
    """It should update all instance variables."""
    model = xt.ExpectedThreat()
    model.fit(spadl_actions)

    # Every matrix that starts out unset must be populated by fitting.
    fitted_matrices = (
        "scoring_prob_matrix",
        "shot_prob_matrix",
        "move_prob_matrix",
        "transition_matrix",
    )
    for attr_name in fitted_matrices:
        assert getattr(model, attr_name) is not None

    # Fitting must have recorded heatmaps and produced a non-zero xT grid.
    assert len(model.heatmaps) > 0
    assert np.sum(model.xT) > 0
|
|
|
|
def test_xt_model_rate_not_fitted(spadl_actions: DataFrame[SPADLSchema]) -> None:
    """It should raise a NotFittedError."""
    # The model is deliberately never fitted before rating.
    unfitted = xt.ExpectedThreat()

    with pytest.raises(NotFittedError):
        unfitted.rate(spadl_actions)
|
|
|
|
def test_xt_model_rate(spadl_actions: DataFrame[SPADLSchema]) -> None:
    """It should rate all successful move actions and assign all other actions NaN."""
    model = xt.ExpectedThreat()
    model.fit(spadl_actions)

    # NOTE(review): the index labels are used as positions into the ratings
    # array below, which assumes the fixture has a default 0..n-1 index -- confirm.
    rated_idx = xt.get_successful_move_actions(spadl_actions).index
    ratings = model.rate(spadl_actions)

    # One rating per input action.
    assert ratings.shape == (len(spadl_actions),)

    # Successful moves all receive a value ...
    rated = ratings[rated_idx]
    assert not np.isnan(rated).any()

    # ... and every other action is left as NaN.
    unrated = np.delete(ratings, rated_idx)
    assert np.isnan(unrated).all()
|
|
|
|
def test_interpolate_xt_grid_no_scipy(mocker: MockerFixture) -> None:
    """It should raise an ImportError if scipy is not installed."""
    # Replace the module-level `interp2d` name with None for this test.
    mocker.patch.object(xt, "interp2d", None)
    model = xt.ExpectedThreat()

    expected_message = "Interpolation requires scipy to be installed."
    with pytest.raises(ImportError, match=expected_message):
        model.interpolator()
|
|
|
|
@pytest.fixture(scope="session")
def xt_model(sb_worldcup_data: pd.HDFStore) -> xt.ExpectedThreat:
    """Provide a 16x12 xT model fitted on the StatsBomb World Cup data.

    The fixture is session-scoped, so the model is fitted once and shared by
    all tests that request it.
    """
    games = sb_worldcup_data["games"].set_index("game_id")

    # Load each game's actions and normalize them with play_left_to_right,
    # passing the game's home team id.
    per_game_actions = [
        spadl.play_left_to_right(
            sb_worldcup_data[f"actions/game_{game_id}"], game.home_team_id
        )
        for game_id, game in games.iterrows()
    ]
    training_actions = DataFrame[SPADLSchema](pd.concat(per_game_actions))

    model = xt.ExpectedThreat(l=16, w=12)
    model.fit(training_actions)
    return model
|
|
|
|
@pytest.mark.e2e
def test_predict(sb_worldcup_data: pd.HDFStore, xt_model: xt.ExpectedThreat) -> None:
    """It should return one float64 rating for each action of a game."""
    games = sb_worldcup_data["games"]
    # Rate the last game listed in the store.
    game = games.iloc[-1]
    actions = sb_worldcup_data[f"actions/game_{game.game_id}"]
    ratings = xt_model.rate(actions)
    # Compare dtypes by equality; an identity check (`is`) only passes as long
    # as NumPy happens to hand back the very same dtype instance.
    assert ratings.dtype == np.dtype(np.float64)
    assert len(ratings) == len(actions)
|
|
|
|
@pytest.mark.e2e
def test_predict_with_interpolation(
    sb_worldcup_data: pd.HDFStore, xt_model: xt.ExpectedThreat
) -> None:
    """It should return one float64 rating per action when interpolating."""
    games = sb_worldcup_data["games"]
    # Rate the last game listed in the store.
    game = games.iloc[-1]
    actions = sb_worldcup_data[f"actions/game_{game.game_id}"]
    ratings = xt_model.rate(actions, use_interpolation=True)
    # Compare dtypes by equality; an identity check (`is`) only passes as long
    # as NumPy happens to hand back the very same dtype instance.
    assert ratings.dtype == np.dtype(np.float64)
    assert len(ratings) == len(actions)
|
|