Spaces:
Running
Running
| from pathlib import Path | |
| from PIL import Image | |
| import gradio as gr | |
| import pandas as pd | |
| import numpy as np | |
| import altair as alt | |
| from transformers import AutoModelForImageClassification, AutoImageProcessor | |
# Hugging Face Hub id of the fine-tuned drawing-originality model.
modelname = "POrg/ocsai-d-web"
# Loaded once at import time and shared by every request. nn.Module.eval()
# returns the module itself, so evaluation mode is set inline.
model = AutoModelForImageClassification.from_pretrained(modelname).eval()
# Preprocessing (resize/normalise) matching the model checkpoint.
image_processor = AutoImageProcessor.from_pretrained(modelname)
# Blank prompt canvases, keyed by image id (presumably MTCI item ids);
# values are paths relative to the working directory. Dict order decides the
# order of the examples built from it.
prompt_images = {
    item: f"./blanks/{item}_blank.png"
    for item in (
        "Images_11", "Images_4", "Images_17", "Images_9",
        "Images_3", "Images_8", "Images_13", "Images_15",
        "Images_12", "Images_7", "Images_56", "Images_19",
    )
}
# Lookup table of normalised score per percentile; every column is parsed as
# float. Used both for the background curve and by get_percentile.
dist = pd.read_csv('./score_norm_distribution.csv', dtype=float)
# Static background curve (percentile on x, normalised score on y); the
# per-request marker and label are layered on top of it.
base_chart = (
    alt.Chart(dist)
    .mark_line()
    .encode(x='percentile', y='score_norm')
)
def get_percentile(score, distribution=None):
    """Look up the percentile of a normalised originality score.

    Args:
        score: Normalised score (get_predictions clamps it to [0, 1]).
        distribution: Table with ``percentile`` and ``score_norm`` columns.
            Defaults to the module-level ``dist``. Rows are assumed to be
            sorted by ``score_norm`` ascending (TODO confirm for the CSV).

    Returns:
        The ``percentile`` of the last row whose ``score_norm`` is <= ``score``.
        A score below every tabulated ``score_norm`` maps to the first
        (lowest) row's percentile instead of raising IndexError.
    """
    if distribution is None:
        distribution = dist
    at_or_below = distribution[distribution['score_norm'] <= score]
    if at_or_below.empty:
        # Clamp to the bottom of the table rather than failing on an empty
        # selection (reachable because scores are clamped to 0 upstream).
        return distribution['percentile'].iloc[0]
    # Select by column name so the CSV's column order does not matter.
    return at_or_below['percentile'].iloc[-1]
def inverse_scale(logits):
    """Map a 0-1 scaled value back onto the JRT score range.

    The JRT scores were min-max scaled into [0, 1] using min -3.024 and
    max 3.164 (range 6.188); this applies the inverse, ``x * range + min``.
    Plain arithmetic, so scalars and NumPy arrays both work.
    """
    jrt_min, jrt_range = -3.024, 6.188
    return jrt_min + logits * jrt_range
def get_predictions(img):
    """Score one drawing for originality.

    Args:
        img: Image accepted by ``image_processor``; classify_image passes a
            PIL image converted to RGB.

    Returns:
        dict with keys:
          - ``originality``: model output clamped to [0, 1], rounded to 2 dp;
          - ``jrt``: that clamped score mapped back to the JRT scale via
            ``inverse_scale``, rounded to 2 dp;
          - ``percentile``: ``get_percentile`` of the clamped score.
    """
    import torch  # local: only needed to disable autograd for inference

    inputs = image_processor(img, return_tensors="pt")
    # Inference only: don't build an autograd graph for the forward pass.
    with torch.no_grad():
        prediction = model(**inputs)
    # First (only) image in the batch, first logit -- presumably the single
    # regression output of the model.
    score = prediction.logits[0].detach().numpy()[0]
    # Clamp to the 0-1 range the scaler and percentile table operate on.
    score = min(max(score, 0), 1)
    return {
        'originality': np.round(score, 2),
        # Convert the actual score; passing a constant here made every
        # drawing report the same JRT value.
        'jrt': np.round(inverse_scale(score), 2),
        'percentile': get_percentile(score),
    }
def classify_image(img_dict: dict):
    """Score the drawing in an ImageEditor value and plot its percentile.

    Args:
        img_dict: Value passed by ``gr.ImageEditor`` -- a dict with
            ``background``, ``layers`` and ``composite`` keys. ``None`` or a
            dict without a composite is accepted defensively.

    Returns:
        ``base_chart`` layered with a red marker at (percentile, originality)
        and a "Percentile: N" label, or ``None`` when there is no composite
        image to score.
    """
    # The composite (background plus strokes) is what the model scores.
    img = img_dict.get('composite') if img_dict else None
    if img is None:
        return None

    p = get_predictions(img.convert('RGB'))
    label = f"Percentile: {int(p['percentile'])}"
    label_df = pd.DataFrame({'y': [p['originality']],
                             'x': [p['percentile']],
                             'text': [label]})

    # Red triangle marking where this drawing falls on the distribution.
    point = alt.Chart(label_df).mark_point(
        shape='triangle',
        size=200,
        filled=True,
        color='red'
    ).encode(
        x='x',
        y='y'
    )
    # Text label offset up and to the right of the marker.
    txt = alt.Chart(label_df).mark_text(
        align='left',
        baseline='middle',
        dx=10, dy=-10,
        fontSize=14
    ).encode(
        y='y',
        x='x',
        text='text'
    )
    return base_chart + point + txt
def update_editor(background, img_editor):
    """Reset an ImageEditor value so it shows only ``background``.

    Mutates ``img_editor`` in place (new background, no drawn layers, no
    composite) and returns that same dict object.
    NOTE(review): not referenced anywhere in the visible code.
    """
    img_editor.update(background=background, layers=[], composite=None)
    return img_editor
# Drawing canvas. type='pil' hands classify_image PIL images; it starts on
# the Images_11 blank with nothing drawn. The brush is thin (size 2) and
# limited to black and two greys. Transforms and extra layers are disabled,
# and new images can only come from upload or clipboard.
editor = gr.ImageEditor(
    type='pil',
    value={
        'background': Image.open(prompt_images['Images_11']),
        'composite': None,
        'layers': [],
    },
    brush=gr.Brush(
        default_size=2,
        colors=["#000000", '#333333', '#666666'],
        color_mode="fixed",
    ),
    transforms=[],
    sources=('upload', 'clipboard'),
    layers=False,
)
# One example per blank prompt, in prompt_images order. Each is a single-input
# row holding a fresh ImageEditor value: the blank as background, nothing drawn.
examples = [
    [{'background': Image.open(path), 'composite': None, 'layers': []}]
    for path in prompt_images.values()
]
# Single-input, single-output app: the editor's value goes to
# classify_image and the returned Altair chart is rendered in a Plot.
demo = gr.Interface(
    fn=classify_image,
    inputs=[editor],
    outputs=gr.Plot(),
    title="Ocsai-D",
    description=(
        "Complete the drawing and classify the originality. "
        "Choose the brush icon below the image to start editing.\n\n"
        "Model from *A Comparison of Supervised and Unsupervised Learning Methods "
        "in Automated Scoring of Figural Tests of Creativity* "
        "([preprint](http://dx.doi.org/10.13140/RG.2.2.26865.25444)).\n\n"
        "Examples are from MTCI "
        "([Barbot 2018](https://pubmed.ncbi.nlm.nih.gov/30618952/))."
    ),
    examples=examples,
    # Examples are not precomputed at startup.
    cache_examples=False,
)
# Launched at import time; debug=True blocks the main thread and surfaces
# errors in the console.
demo.launch(debug=True)