| """ |
| Demo is based on https://scikit-learn.org/stable/auto_examples/feature_selection/plot_rfe_digits.html |
| """ |
| from sklearn.svm import SVC |
| from sklearn.datasets import load_digits |
| from sklearn.feature_selection import RFE |
| import matplotlib.pyplot as plt |
|
|
| |
# Load the hand-written digit images and flatten each image into one row
# of pixel features, so every pixel becomes a candidate feature for RFE.
digits = load_digits()
X = digits.images.reshape((digits.images.shape[0], -1))
y = digits.target

# Linear-kernel SVC exposes per-feature coefficients, which RFE uses to rank
# and eliminate pixels.
svc = SVC(C=1, kernel="linear")
|
|
|
|
def recursive_feature_elimination(n_features_to_select, step, esimator=svc):
    """Rank the digit pixels with recursive feature elimination (RFE).

    Fits ``sklearn.feature_selection.RFE`` on the module-level digits data
    ``X``/``y`` and draws the resulting per-pixel ranking as a heat map laid
    out on the original image grid.

    Args:
        n_features_to_select: Number of pixels still selected when elimination
            stops; all of them receive rank 1.
        step: Number of least-important pixels removed at each iteration.
        esimator: Estimator used to rank features; defaults to the module's
            linear ``SVC``. (The misspelt name is kept so existing keyword
            callers keep working.)

    Returns:
        The matplotlib ``Figure`` containing the ranking plot, suitable as the
        output of a ``gr.Plot`` component.
    """
    # UI sliders may deliver values such as 5.0; RFE treats float arguments
    # as fractions of the feature count (and rejects floats above 1), so pass
    # plain ints.
    rfe = RFE(
        estimator=esimator,
        n_features_to_select=int(n_features_to_select),
        step=int(step),
    )
    rfe.fit(X, y)
    # One rank per pixel, reshaped back onto the image grid for display.
    ranking = rfe.ranking_.reshape(digits.images[0].shape)

    # Draw on a single explicit figure; pyplot's matshow() would otherwise
    # open a new figure of its own on every call.
    fig, ax = plt.subplots()
    image = ax.matshow(ranking, cmap=plt.cm.Blues)
    fig.colorbar(image, ax=ax)
    ax.set_title("Ranking of pixels with RFE")
    # Detach the figure from pyplot's manager so repeated requests don't pile
    # up open figures; the Figure object itself can still be rendered/saved.
    plt.close(fig)
    return fig
|
|
|
|
| import gradio as gr |
|
|
title = " Illustration of Recursive feature elimination.🌲 "

# Gradio UI: two sliders drive recursive_feature_elimination and the
# resulting ranking heat map is shown in a Plot component.
with gr.Blocks(title=title) as demo:
    gr.Markdown(f"# {title}")
    gr.Markdown(
        "This example demonstrates recursive feature elimination. <br>"
        "Dataset is `load_digits()`, which contains 8x8 images of hand-written digits. <br>"
        "**Parameters** <br> **Number of features to select**: The number of features left at the end of the feature selection process. <br>"
        "**Step**: Number of features to remove at each iteration; the least important ones are removed. <br>"
    )

    gr.Markdown(
        "Support vector classifier is used as estimator to rank features. <br>"
    )

    gr.Markdown(
        "Demo is based on [sklearn docs](https://scikit-learn.org/stable/auto_examples/feature_selection/plot_rfe_digits.html)."
    )
    with gr.Row():
        # RFE rejects 0 for both integer parameters, so the sliders start at 1.
        n_features_to_select = gr.Slider(
            minimum=1, maximum=20, step=1, value=1, label="Number of features to select"
        )
        step = gr.Slider(minimum=1, maximum=20, step=1, value=1, label="Step")

    btn = gr.Button(value="Submit")

    btn.click(
        recursive_feature_elimination,
        inputs=[n_features_to_select, step],
        outputs=gr.Plot(
            label="Recursive feature elimination of pixels in digit classification"
        ),
    )

    gr.Markdown(
        "Plot shows the RFE ranking of each pixel in the classification of the digits: "
        "rank 1 marks selected pixels, higher ranks were eliminated earlier. <br>"
    )


demo.launch()
|
|