| |
|
|
| import gradio as gr |
| import numpy as np |
| import pandas as pd |
| import glob, os |
| import shoe_outlines_lib as sol |
| import matplotlib.pyplot as plt |
| import onnxruntime |
| import cv2 |
|
|
# Per-channel (means, stds) used to normalise images before inference.
imagenet_stats = ([0.485, 0.456, 0.406], [0.229, 0.224, 0.225])

# Shaped (3, 1, 1) so they broadcast over channel-first (N, C, H, W) batches.
imagenet_means = np.asarray(imagenet_stats[0], dtype=np.float32).reshape(3, 1, 1)
imagenet_stds = np.asarray(imagenet_stats[1], dtype=np.float32).reshape(3, 1, 1)

# Target size handed to cv2.resize in get_predictions (cv2 reads it as
# (width, height)).
sz = (160, 256)
|
|
| |
# Load the ONNX classifier once at import time; get_predictions reuses this
# session for every call. The path is relative, so it is resolved against the
# process's current working directory.
ort_session = onnxruntime.InferenceSession(
    'shod-model.onnx',
)
|
|
|
|
def csv2image_fig(csv_file):
    """Load one outline file and build its classifier image and preview plot.

    Args:
        csv_file: Outline file reference, forwarded to ``sol.csv2dfs`` wrapped
            in a one-element list.

    Returns:
        A ``(image, fig, fname)`` tuple: the raster returned by
        ``sol.coordsdf2image``, a matplotlib ``Figure`` showing the filled
        outline, and the ``name`` attribute of the loaded frame.
    """
    # Local import: a standalone Figure is never registered with pyplot, so it
    # is garbage-collected normally instead of accumulating in pyplot's global
    # figure registry on every upload of a long-running server.
    from matplotlib.figure import Figure

    df = sol.csv2dfs([csv_file])[0]
    # Read the name before concatenating: pd.concat builds a new frame.
    fname = df.name
    # Repeat the first point at the end so the outline is drawn as a closed loop.
    df = pd.concat([df, df.iloc[[0]]], ignore_index=True)
    df = sol.norm_by_x(df)
    image = sol.coordsdf2image(df)

    fig = Figure(figsize=(2, 4))
    # Draw through explicit Axes methods rather than pyplot's implicit
    # "current figure", which is shared global state across concurrent handlers.
    ax = fig.subplots()
    ax.plot(df['x'], df['y'], marker='', linestyle='-', color='b', label='Line')
    ax.fill(df['x'], df['y'], color='blue', alpha=0.2)
    ax.axis('equal')
    ax.axis('off')
    ax.invert_yaxis()
    return image, fig, fname
|
|
| |
def get_predictions(images, bs=8):
    """Return the model's "Shoe" confidence for each image.

    Class 0 is "No shoe", class 1 is "Shoe".

    Args:
        images: A single image as ``np.ndarray`` or an iterable of images.
            Each image is resized with ``cv2.resize`` to ``sz`` and must carry
            a trailing channel axis, because the stacked batch is transposed
            from (N, H, W, C) to (N, C, H, W). The channel count has to
            broadcast against the 3-channel normalisation constants.
        bs: Number of images sent to the ONNX session per inference call.

    Returns:
        1-D array holding the softmax probability of class 1 for each image,
        in input order. An empty float32 array is returned for empty input.
    """

    def _softmax(logits):
        # Subtract the row maximum first for numerical stability.
        shifted = logits - np.max(logits, axis=-1, keepdims=True)
        exp_logits = np.exp(shifted)
        return exp_logits / np.sum(exp_logits, axis=-1, keepdims=True)

    if isinstance(images, np.ndarray):
        images = [images]
    images = list(images)
    if not images:
        # np.stack rejects an empty list; report "no predictions" instead.
        return np.empty((0,), dtype=np.float32)

    batch = np.stack([cv2.resize(image, sz) for image in images])
    batch = batch.transpose(0, 3, 1, 2).astype(np.float32)
    batch = (batch / 255.0 - imagenet_means) / imagenet_stds

    # Look the input name up once and concatenate once, instead of repeating
    # both on every mini-batch.
    input_name = ort_session.get_inputs()[0].name
    logits = [
        ort_session.run(None, {input_name: batch[start:start + bs]})[0]
        for start in range(0, len(batch), bs)
    ]
    return _softmax(np.concatenate(logits))[:, 1]
|
|
|
|
# Custom stylesheet passed to gr.Blocks(css=...).
# The logo rule matches both the element id and the class name: the logo image
# is tagged with elem_id="logo", which a class-only ".logo" selector never hits.
css = """
h1 {
    text-align: center;
    display:block;
    vertical-align: middle;
}
#title-column {
    padding: 0px !important; /* Remove padding from the parent column */
    gap: 0px !important; /* Ensure gap is zero */
}
#title-and-subtitle {
    margin: 0px !important;
    padding: 0px !important;
}
#title-and-subtitle .prose h1 {
    margin: 0px !important;
    padding: 0px !important;
}
#title-and-subtitle .prose p {
    margin: 0px !important;
    padding: 0px !important;
    color: gray !important; /* Added */
    text-align: center !important; /* Added */
    font-style: italic !important; /* Added */
}
#logo, .logo {
    max-height: 128px;
    display: inline-block;
    vertical-align: middle;
}
.gradio-container {
    width: 1200px !important; /* Use !important to override defaults if needed */
    margin: 0 auto;
}
"""
|
|
def _to_three_channels(image):
    """Return *image* with a trailing 3-channel axis, replicating a 2-D input."""
    if image.ndim == 2:
        return np.tile(image[..., None], (1, 1, 3))
    return image


# UI definition: a header (logo + title) followed by two tabs, one for a single
# outline and one for a batch of outlines.
with gr.Blocks(css=css) as app:
    with gr.Column():
        with gr.Row():
            gr.Image(
                value="paleostep-logo-cropped-128.png",
                interactive=False,
                show_label=False,
                show_download_button=False,
                show_share_button=False,
                container=False,
                show_fullscreen_button=False,
                elem_id="logo",
            )
        with gr.Row():
            with gr.Column(elem_id="title-column"):
                gr.Markdown("""
                # STEP: Shod Track Estimated Percentage
                <p style='color: gray; text-align: center; font-style: italic; margin: 0; padding: 0;'>Mysteriously Accurate Rim Curvature INdex</p>
                """, elem_id="title-and-subtitle")

    with gr.Tab('Single outline classification'):
        with gr.Row():
            gr_input = gr.File(
                file_types=['.csv', '.xlsx', '.json'],
                file_count="single",
                label="Upload Outline File",
            )

        with gr.Row():
            gr.Label(value="Upload a .csv/.xlsx/.json file", visible=True, show_label=False)

        with gr.Row():
            gr_plot = gr.Plot(label="Outline Plot", show_label=True, visible=False)

        with gr.Row():
            gr_label = gr.Label(label="Classification", visible=False, show_label=False)

        def _classify_image(csv_file):
            """Classify one uploaded outline.

            Returns five values matching the ``outputs`` list of the upload
            event below: three updates for the label, then two for the plot.
            On any failure the error text is shown in the label and the plot
            is hidden.
            """
            try:
                image, fig, fname = csv2image_fig(csv_file)
                confidence = get_predictions([_to_three_channels(image)]).item()
                classification = "Shoe" if confidence >= 0.5 else "No shoe"
                return (
                    classification,
                    {f"Shoe confidence: {100*confidence:.1f}": confidence},
                    gr.update(visible=True),
                    fig,
                    gr.update(visible=True, label=fname),
                )
            except Exception as e:
                # UI boundary: surface the message to the user instead of crashing.
                return str(e), str(e), gr.update(visible=True), None, gr.update(visible=False)

        # Each component appears once per value the handler returns for it.
        gr_input.upload(
            fn=_classify_image,
            inputs=[gr_input],
            outputs=[gr_label, gr_label, gr_label, gr_plot, gr_plot],
        )

        # Removing the file resets both result widgets and hides them again.
        gr_input.clear(
            fn=lambda: (None, None, gr.update(visible=False), gr.update(visible=False)),
            inputs=[],
            outputs=[gr_label, gr_plot, gr_label, gr_plot],
        )

    with gr.Tab('Batch classification'):
        with gr.Row():
            gr_input_batch = gr.File(
                file_types=['.csv', '.xlsx', '.json'],
                file_count="multiple",
                label="Upload Outline File(s)",
            )
        with gr.Row():
            gr.Label(value="Upload multiple .csv/.xlsx/.json files.", visible=True, show_label=False)
        with gr.Row(visible=True):
            with gr.Column():
                gr_df = gr.Dataframe(label="Outlines", visible=False, show_label=False, row_count=10)
                gr_results_file = gr.File(visible=False)

        def _classify_batch(csv_files):
            """Classify every uploaded outline and export the results as CSV.

            Returns three values matching the ``outputs`` list of the upload
            event below: the styled results table, a visibility update for the
            table, and an update exposing the written CSV for download. On any
            failure a one-cell ``Error`` table is returned and the download is
            hidden.
            """
            try:
                # Drop result files left in the working directory by earlier runs.
                for stale in glob.glob("classification_results_*.csv"):
                    os.remove(stale)

                dfs = sol.csv2dfs(csv_files)
                # NOTE(review): unlike csv2image_fig, outlines are rasterised here
                # without closing the loop or applying sol.norm_by_x -- confirm
                # that the single and batch paths are meant to differ.
                images = [_to_three_channels(sol.coordsdf2image(df)) for df in dfs]
                confidences = get_predictions(images)

                out = [
                    {
                        'Outline file': df.name,
                        'Points': len(df),
                        'Confidence': 100 * confidence,
                    }
                    for df, confidence in zip(dfs, confidences)
                ]

                df_out = pd.DataFrame(out)
                timestamp = pd.Timestamp.now().strftime('%Y%m%d_%H%M%S')
                filename = f"classification_results_{timestamp}.csv"
                df_out.to_csv(filename, index=False)

                return (
                    df_out.style.format({'Confidence': '{:.1f}%'}),
                    gr.update(visible=True),
                    gr.update(visible=True, value=filename),
                )

            except Exception as e:
                # UI boundary: show the message in the table instead of crashing.
                return pd.DataFrame({'Error': [str(e)]}), gr.update(visible=True), gr.update(visible=False)

        gr_input_batch.upload(
            fn=_classify_batch,
            inputs=[gr_input_batch],
            outputs=[gr_df, gr_df, gr_results_file],
        )

        # Removing the files empties the table and hides table and download.
        gr_input_batch.clear(
            fn=lambda: (None, gr.update(visible=False), gr.update(visible=False)),
            inputs=[],
            outputs=[gr_df, gr_df, gr_results_file],
        )
|
|
|
|
# Start the Gradio server: no public share link, no debug mode, and the API
# documentation page is not shown.
app.launch(share=False, debug=False, show_api=False)
|
|