| import json |
| import math |
| import os |
|
|
import subprocess
import sys

# Pin Gradio to 3.26.0 before it is imported further down: first remove
# whatever version is present, then install the pinned release.
# `sys.executable -m pip` targets the interpreter that is running this script,
# whereas a bare `pip` is resolved through PATH and may belong to a different
# environment. Exit codes are deliberately ignored (check=False), exactly as
# they were with os.system, so a failed uninstall does not abort start-up.
for pip_args in (
    ["uninstall", "-y", "gradio"],
    ["install", "gradio==3.26.0"],
):
    subprocess.run([sys.executable, "-m", "pip", *pip_args], check=False)
|
|
|
|
| import gradio as gr |
| import numpy as np |
| import pandas as pd |
| import plotly.express as px |
| from sklearn.datasets import fetch_20newsgroups |
| from sklearn.feature_extraction.text import TfidfVectorizer |
| from sklearn.model_selection import RandomizedSearchCV |
| from sklearn.naive_bayes import ComplementNB |
| from sklearn.pipeline import Pipeline |
|
|
|
|
# Newsgroup names offered in the category dropdown and passed on to
# fetch_20newsgroups. Order is preserved as-is because it is the order in
# which the choices are listed.
CATEGORIES = [
    # alt.*
    "alt.atheism",
    # comp.*
    "comp.graphics",
    "comp.os.ms-windows.misc",
    "comp.sys.ibm.pc.hardware",
    "comp.sys.mac.hardware",
    "comp.windows.x",
    # misc.*
    "misc.forsale",
    # rec.*
    "rec.autos",
    "rec.motorcycles",
    "rec.sport.baseball",
    "rec.sport.hockey",
    # sci.*
    "sci.crypt",
    "sci.electronics",
    "sci.med",
    "sci.space",
    # soc.*
    "soc.religion.christian",
    # talk.*
    "talk.politics.guns",
    "talk.politics.mideast",
    "talk.politics.misc",
    "talk.religion.misc",
]
|
|
|
|
def shorten_param(param_name):
    """Return the part of ``param_name`` that follows its last ``__``.

    Pipeline parameters are named ``<step>__<param>``; this drops the step
    prefix, e.g. ``"vect__max_df"`` becomes ``"max_df"``. A name that
    contains no ``__`` is returned unchanged.
    """
    # rpartition yields ("", "", param_name) when "__" is absent, so the last
    # element is the right answer in both cases.
    _, _, short_name = param_name.rpartition("__")
    return short_name
|
|
|
|
def _parse_literal(text):
    """Parse ``text`` as one Python literal; raise ValueError if it is not one."""
    import ast  # function-scope import: the module header does not import ast

    try:
        # Strip first so that items such as " 0.4" (left over from splitting
        # "0.2, 0.4" on commas) are accepted on every Python version.
        return ast.literal_eval(text.strip())
    except (ValueError, TypeError, SyntaxError) as err:
        raise ValueError(f"Cannot parse parameter value {text!r}") from err


def _parse_literal_list(text):
    """Parse a comma-separated string such as ``"0.2, 0.4"`` into a list of literals."""
    return [_parse_literal(item) for item in text.split(",")]


def _parse_ngram_ranges(text):
    """Parse ``"(1, 1), (1, 2)"`` (or a lone ``"(1, 2)"``) into a list of tuples."""
    parsed = _parse_literal(text)
    if not isinstance(parsed, (tuple, list)):
        raise ValueError(f"ngram_range must be one or more tuples, got {text!r}")
    # A lone "(1, 2)" parses to a flat pair of ints. Without this check it
    # would be searched as the two candidates 1 and 2 instead of one range.
    if len(parsed) == 2 and all(isinstance(bound, int) for bound in parsed):
        return [tuple(parsed)]
    return list(parsed)


def _fetch_newsgroups_subset(subset, categories):
    """Fetch one 20-newsgroups subset with headers, footers and quotes removed."""
    return fetch_20newsgroups(
        subset=subset,
        categories=categories,
        shuffle=True,
        random_state=42,
        remove=("headers", "footers", "quotes"),
    )


def _plot_score_tradeoff(cv_results, param_names, labels):
    """Scatter plot of mean CV score against mean scoring time, with error bars."""
    fig = px.scatter(
        cv_results,
        x="mean_score_time",
        y="mean_test_score",
        error_x="std_score_time",
        error_y="std_test_score",
        hover_data=param_names,
        labels=labels,
    )
    fig.update_layout(
        title={
            "text": "trade-off between scoring time and mean test score",
            "y": 0.95,
            "x": 0.5,
            "xanchor": "center",
            "yanchor": "top",
        }
    )
    return fig


def _plot_parallel_coordinates(cv_results, param_names, labels):
    """Parallel-coordinates plot of the searched parameters, coloured by CV score."""
    column_results = param_names + ["mean_test_score", "mean_score_time"]

    # Every column passes through unchanged unless overridden below.
    transform_funcs = dict.fromkeys(column_results, lambda x: x)
    # alpha is shown on a base-10 logarithmic scale.
    transform_funcs["alpha"] = math.log10
    # "l2" is mapped to 2 and any other norm to 1 so the axis is numeric.
    transform_funcs["norm"] = lambda x: 2 if x == "l2" else 1
    # An ngram range is represented by its upper bound.
    transform_funcs["ngram_range"] = lambda x: x[1]

    fig = px.parallel_coordinates(
        cv_results[column_results].apply(transform_funcs),
        color="mean_test_score",
        color_continuous_scale=px.colors.sequential.Viridis_r,
        labels=labels,
    )
    fig.update_layout(
        title={
            "text": "Parallel coordinates plot of text classifier pipeline",
            "y": 0.99,
            "x": 0.5,
            "xanchor": "center",
            "yanchor": "top",
        }
    )
    return fig


def train_model(categories, vect__max_df, vect__min_df, vect__ngram_range, vect__norm):
    """Run a randomized hyper-parameter search over a TF-IDF + ComplementNB pipeline.

    Args:
        categories: Newsgroup names handed to ``fetch_20newsgroups``.
        vect__max_df: Comma-separated numeric literals, e.g. ``"0.2, 0.4"``.
        vect__min_df: Comma-separated numeric literals, e.g. ``"1, 3, 5"``.
        vect__ngram_range: One or more tuples, e.g. ``"(1, 1), (1, 2)"``.
        vect__norm: Comma-separated norm names, e.g. ``"l1, l2"``.

    Returns:
        A 4-tuple ``(fig, fig2, best_parameters, test_accuracy)``: the
        score/time scatter plot, the parallel-coordinates plot, the best
        estimator's parameters as a JSON string, and the best estimator's
        score on the test subset.

    Raises:
        ValueError: If a numeric or ngram_range field is not a plain Python
            literal.
    """
    # SECURITY: these strings arrive from user-editable text boxes. They were
    # previously passed to eval(), which executes arbitrary code; they are now
    # parsed with ast.literal_eval, so only plain literals (numbers, tuples)
    # are accepted and expressions such as "1/10" are rejected.
    parameters_grid = {
        "vect__max_df": _parse_literal_list(vect__max_df),
        "vect__min_df": _parse_literal_list(vect__min_df),
        "vect__ngram_range": _parse_ngram_ranges(vect__ngram_range),
        "vect__norm": [value.strip() for value in vect__norm.split(",")],
        "clf__alpha": np.logspace(-6, 6, 13),
    }

    print(parameters_grid)

    data_train = _fetch_newsgroups_subset("train", categories)
    data_test = _fetch_newsgroups_subset("test", categories)

    pipeline = Pipeline(
        [
            ("vect", TfidfVectorizer()),
            ("clf", ComplementNB()),
        ]
    )

    random_search = RandomizedSearchCV(
        estimator=pipeline,
        param_distributions=parameters_grid,
        n_iter=40,
        random_state=0,
        n_jobs=2,
        verbose=1,
    )
    random_search.fit(data_train.data, data_train.target)

    # default=str stringifies anything json cannot encode natively
    # (the estimator objects among the parameters, for instance).
    best_parameters = json.dumps(
        random_search.best_estimator_.get_params(),
        indent=4,
        sort_keys=True,
        default=str,
    )
    test_accuracy = random_search.score(data_test.data, data_test.target)

    # Strip the "vect__"/"clf__"/"param_" style prefixes so the plots can
    # refer to columns by their short names.
    cv_results = pd.DataFrame(random_search.cv_results_)
    cv_results = cv_results.rename(shorten_param, axis=1)

    param_names = [shorten_param(name) for name in parameters_grid]
    labels = {
        "mean_score_time": "CV Score time (s)",
        "mean_test_score": "CV score (accuracy)",
    }

    fig = _plot_score_tradeoff(cv_results, param_names, labels)
    fig2 = _plot_parallel_coordinates(cv_results, param_names, labels)

    return fig, fig2, best_parameters, test_accuracy
|
|
|
|
def load_description(name):
    """Return the contents of ``./descriptions/<name>.md``.

    Args:
        name: File name without the ``.md`` suffix, relative to the
            ``descriptions`` directory under the current working directory.
            It may contain a sub-path such as ``"parameter_grid/alpha"``.

    Returns:
        The file's text.

    Raises:
        OSError: If the file cannot be opened.
    """
    # Decode explicitly as UTF-8 so the Markdown reads the same on every
    # platform instead of depending on the locale's default encoding.
    with open(f"./descriptions/{name}.md", "r", encoding="utf-8") as description_file:
        return description_file.read()
|
|
|
|
# Markdown credit line rendered below the introductory description.
AUTHOR = (
    "\n"
    "Created by [@dominguesm](https://huggingface.co/dominguesm) based on "
    "[scikit-learn docs](https://scikit-learn.org/stable/auto_examples/"
    "model_selection/plot_grid_search_text_feature_extraction.html)"
    "\n"
)
|
|
|
|
# Gradio UI: introduction, category picker, parameter-grid inputs, a "Train"
# button wired to train_model, and the result widgets it fills in.
with gr.Blocks(theme=gr.themes.Soft()) as app:
    with gr.Row():
        with gr.Column():
            gr.Markdown("# Sample pipeline for text feature extraction and evaluation")
            gr.Markdown(load_description("description_part1"))
            gr.Markdown(load_description("description_part2"))
            gr.Markdown(AUTHOR)

    with gr.Row():
        with gr.Column():
            gr.Markdown("""## CATEGORY SELECTION""")
            gr.Markdown(load_description("description_category_selection"))
            drop_categories = gr.Dropdown(
                CATEGORIES,
                value=["alt.atheism", "talk.religion.misc"],
                multiselect=True,
                label="Categories",
                info="Please select up to two categories that you want to receive training on.",
                max_choices=2,
                interactive=True,
            )
    with gr.Row():
        with gr.Tab("PARAMETERS GRID"):
            gr.Markdown(load_description("description_parameter_grid"))
            with gr.Row():
                with gr.Column():
                    # Read-only: train_model always searches
                    # np.logspace(-6, 6, 13) for alpha, so display exactly
                    # that grid rather than only its first three values.
                    clf__alpha = gr.Textbox(
                        label="Classifier Alpha (clf__alpha)",
                        value=", ".join(
                            f"{alpha:.0e}" for alpha in np.logspace(-6, 6, 13)
                        ),
                        info="Due to practical considerations, this parameter was kept constant.",
                        interactive=False,
                    )
                    vect__max_df = gr.Textbox(
                        label="Vectorizer max_df (vect__max_df)",
                        value="0.2, 0.4, 0.6, 0.8, 1.0",
                        info="Values ranging from 0 to 1.0, separated by a comma.",
                        interactive=True,
                    )
                    vect__min_df = gr.Textbox(
                        label="Vectorizer min_df (vect__min_df)",
                        value="1, 3, 5, 10",
                        info="Values ranging from 0 to 1.0, separated by a comma, or integers separated by a comma. If float, the parameter represents a proportion of documents, integer absolute counts.",
                        interactive=True,
                    )
                with gr.Column():
                    vect__ngram_range = gr.Textbox(
                        label="Vectorizer ngram_range (vect__ngram_range)",
                        value="(1, 1), (1, 2)",
                        info="""Tuples of integer values separated by a comma. For example an `ngram_range` of `(1, 1)` means only unigrams, `(1, 2)` means unigrams and bigrams, and `(2, 2)` means only bigrams.""",
                        interactive=True,
                    )
                    vect__norm = gr.Textbox(
                        label="Vectorizer norm (vect__norm)",
                        value="l1, l2",
                        info="'l1' or 'l2', separated by a comma",
                        interactive=True,
                    )

        with gr.Tab("DESCRIPTION OF PARAMETERS"):
            # One heading plus one Markdown file per parameter, in this order.
            for section_title, description_name in (
                ("Classifier Alpha", "alpha"),
                ("Vectorizer max_df", "max_df"),
                ("Vectorizer min_df", "min_df"),
                ("Vectorizer ngram_range", "ngram_range"),
                ("Vectorizer norm", "norm"),
            ):
                gr.Markdown(f"### {section_title}")
                gr.Markdown(load_description(f"parameter_grid/{description_name}"))

    with gr.Row():
        gr.Markdown(
            """
            ## MODEL PIPELINE
            ```python
            pipeline = Pipeline(
                [
                    ("vect", TfidfVectorizer()),
                    ("clf", ComplementNB()),
                ]
            )
            ```
            """
        )
    with gr.Row():
        with gr.Column():
            gr.Markdown("""## TRAINING""")
            with gr.Row():
                brn_train = gr.Button("Train").style(container=False)

    gr.Markdown("## RESULTS")
    with gr.Row():
        best_parameters = gr.Textbox(label="Best parameters")
        test_accuracy = gr.Textbox(label="Test accuracy")

    plot_trade = gr.Plot(label="")
    plot_coordinates = gr.Plot(label="")

    # Inputs are passed positionally, so their order must match train_model's
    # signature; outputs must match the order of its returned tuple.
    brn_train.click(
        train_model,
        inputs=[
            drop_categories,
            vect__max_df,
            vect__min_df,
            vect__ngram_range,
            vect__norm,
        ],
        outputs=[plot_trade, plot_coordinates, best_parameters, test_accuracy],
    )

app.launch()
|
|