| """Script to create the model artifact |
| |
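| Trains a simple linear classifier (scikit-learn's SGDClassifier, whose default |
| hinge loss gives a linear SVM rather than a logistic regression) with grid |
| search on a synthetic dataset and stores the best model in a joblib file. |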
| |
| """ |
|
|
| import joblib |
| from sklearn.datasets import make_classification |
| from sklearn.linear_model import SGDClassifier |
| from sklearn.model_selection import GridSearchCV |
|
|
|
|
# Shared random seed: passed as ``random_state`` to both the synthetic data
# generator and the classifier.
SEED = 0

# Path of the file that main() hands to save_model() for the trained model.
FILENAME = "sklearn_model.joblib"
|
|
|
|
def get_data():
    """Generate the synthetic classification dataset.

    Returns:
        A ``(X, y)`` pair as produced by ``make_classification`` with
        2000 samples and ``random_state=SEED``.
    """
    dataset = make_classification(n_samples=2000, random_state=SEED)
    features, labels = dataset
    return features, labels
|
|
|
|
def get_model(**kwargs):
    """Build an untrained ``SGDClassifier`` seeded with ``SEED``.

    Args:
        **kwargs: Estimator parameters applied on top of the defaults via
            ``set_params``. No ``loss`` is set here, so scikit-learn's
            default is used unless one is passed in.

    Returns:
        The configured, not yet fitted, classifier.
    """
    # set_params returns the estimator itself, so the call can be chained.
    return SGDClassifier(random_state=SEED).set_params(**kwargs)
|
|
|
|
def get_hparams():
    """Return the hyper-parameter grid explored by the grid search.

    Returns:
        A dict mapping parameter names (``penalty`` and ``alpha``) to the
        list of candidate values to try. A fresh dict is built on every call.
    """
    return {
        "penalty": ["l1", "l2"],
        "alpha": [1e-5, 1e-4, 1e-3],
    }
|
|
|
|
def grid_search(model, X, y, hparams):
    """Run a cross-validated grid search over ``hparams``.

    Args:
        model: Estimator to tune.
        X: Training features.
        y: Training targets.
        hparams: Parameter grid passed to ``GridSearchCV``.

    Returns:
        The fitted ``GridSearchCV`` object (``cv=5``, scored by accuracy).
    """
    # fit() returns the search object itself, so the result can be returned
    # directly.
    return GridSearchCV(
        estimator=model,
        param_grid=hparams,
        scoring="accuracy",
        cv=5,
    ).fit(X, y)
|
|
|
|
def train(model, X, y, hparams):
    """Tune ``model`` with a grid search and return the best estimator.

    Prints the best cross-validation score (as a percentage) and the
    parameter combination that achieved it.

    Args:
        model: Estimator to tune.
        X: Training features.
        y: Training targets.
        hparams: Parameter grid forwarded to ``grid_search``.

    Returns:
        The ``best_estimator_`` of the fitted search.
    """
    result = grid_search(model, X, y, hparams=hparams)
    accuracy_pct = 100 * result.best_score_
    print(f"Best accuracy: {accuracy_pct:.1f}%")
    print(f"Best parameters: {result.best_params_}")
    return result.best_estimator_
|
|
|
|
def save_model(model, filename):
    """Persist ``model`` to ``filename`` with joblib and report the path.

    Args:
        model: Object to serialise.
        filename: Destination passed to ``joblib.dump``.
    """
    joblib.dump(model, filename)
    # Only report success after the dump call has returned.
    message = f"Stored model in '{filename}'"
    print(message)
|
|
|
|
def main():
    """Create the data, tune the classifier and write the model artifact."""
    features, labels = get_data()
    estimator = get_model()
    best_estimator = train(
        estimator,
        features,
        labels,
        hparams=get_hparams(),
    )
    save_model(best_estimator, FILENAME)
|
|
|
|
# Run the training pipeline only when executed as a script, not on import.
if __name__ == "__main__":
    main()
|
|