| from fastai.vision.all import * |
| import gradio as gr |
|
|
def is_cat(x):
    """Report whether the first character of ``x`` is an uppercase letter.

    Not called anywhere in this file; presumably it must exist under this
    exact name so ``load_learner`` can unpickle a model that was trained
    with it as its label function (NOTE(review): confirm).
    """
    leading_char = x[0]
    return leading_char.isupper()
|
|
# NOTE(review): fastai's load_learner is pickle-based -- only ever load a
# model.pkl you built or otherwise trust.
learn = load_learner('model.pkl')


# Display names, in the same order as the probabilities returned by
# learn.predict. The model's labels come from is_cat (False/True), and fastai
# sorts its category vocab, so index 0 is False (not a cat -> dog) and index 1
# is True (cat). The previous ('Cat', 'Dog') ordering swapped the labels.
# NOTE(review): confirm with `learn.dls.vocab == [False, True]`.
categories = ('Dog', 'Cat')
|
|
# Markdown H1 headings announcing a confident prediction; each template has a
# single "{}" slot that calculate() fills in with the category name.
prompts = [
    "# " + template
    for template in (
        "Definitely a {}!",
        "Well, that must be a {}!",
        "Oh, that's a {}!",
        "That's a {}!",
        "Looks like a {} to me!",
    )
]


# Markdown H1 headings shown when no category clears the confidence threshold.
failure_prompts = [
    "# " + message
    for message in (
        "I'm not sure what that is.",
        "I don't know what that thing is.",
        "I've never seen that before.",
        "Looks familiar, but unsure.",
        "Something, something?",
        "Beats me.",
    )
]
|
|
def classify_image(img):
    """Run the learner on ``img`` and pair each category name with its probability.

    Returns a dict ``{category_name: float}`` built by zipping the module-level
    ``categories`` with the probability sequence (third element) returned by
    ``learn.predict``.
    """
    _decoded, _index, probabilities = learn.predict(img)
    return {name: float(p) for name, p in zip(categories, probabilities)}
|
|
| def calculate(confidence_threshold, img): |
| classifications = classify_image(img) |
| classification = random.choice(failure_prompts) |
| for key, value in classifications.items(): |
| if value > confidence_threshold: |
| classification = random.choice(prompts).format(key) |
| break |
|
|
| return [classification, classifications] |
|
|
|
|
# Gradio page. Inside each `with` context, components appear in the order they
# are created (or .render()ed), so statement order here defines the layout.
# NOTE(review): this nesting was reconstructed from a copy with lost
# indentation -- confirm where the button and the Examples section belong.
with gr.Blocks() as ui:

    # Created with render=False so btn.click can reference them before they are
    # placed into the right-hand column via .render() further down.
    # The heading starts as "Dog or Cat?" and is replaced by calculate()'s heading.
    heading = gr.Markdown(" # Dog or Cat?", render=False)
    results = gr.Label(value="Waiting to receive image.", label="Details", show_label=False, render=False)

    with gr.Row(equal_height=True):

        # Left column: image input, confidence threshold and the Classify button.
        with gr.Column():
            gr.Markdown("Upload an image of a cat or a dog.")

            with gr.Group():
                image = gr.Image(show_label=False, height=300)
                # Passed to calculate() as its confidence_threshold argument.
                confidence = gr.Slider(minimum=0.0, maximum=1.0, value=0.7, label="Confidence Threshold")
                btn = gr.Button(value="Classify")
                # calculate(confidence, image) returns [heading_markdown, details];
                # the two values update `heading` and `results` respectively.
                btn.click(calculate, inputs=[confidence, image], outputs=[heading, results])

        # Right column: the details label first, then the result heading.
        with gr.Column():
            gr.Markdown("Then wait for the magic to happen")
            with gr.Group():
                results.render()
                heading.render()

    # Example images; selecting one fills the `image` input. No fn is given to
    # gr.Examples, so the user still has to press "Classify".
    gr.Markdown(" # Examples")
    with gr.Group():
        # NOTE(review): relative paths -- assumes the app is launched from the
        # directory that contains images/.
        gr.Examples(inputs=image, examples=['images/cat1.jpeg', 'images/cat2.jpeg', 'images/cat3.jpeg', 'images/dog1.jpeg', 'images/dog2.jpeg', 'images/dog3.jpeg'])


# Start the Gradio server only when run as a script, not when imported.
if __name__ == "__main__":
    ui.launch()