| |
| import os, shutil |
|
|
| |
# Point every library cache at /tmp. This must run before the
# huggingface_hub / fastai / torch imports further down, because those
# libraries read these variables when they are imported.
os.environ.update(
    {
        "HF_HOME": "/tmp/hf",
        "TRANSFORMERS_CACHE": "/tmp/hf/transformers",
        "TORCH_HOME": "/tmp/torch",
        "PIP_NO_CACHE_DIR": "1",
    }
)

# Best-effort cleanup of per-user cache directories; missing ones are ignored.
for stale_dir in ("/home/user/.cache", "/home/user/.fastai", "/home/user/.torch"):
    shutil.rmtree(stale_dir, ignore_errors=True)
|
|
| from functools import lru_cache |
| from huggingface_hub import hf_hub_download |
| from fastai.vision.all import load_learner |
| import gradio as gr |
|
|
# Hugging Face Hub repository and file holding the exported fastai Learner.
REPO_ID, FILENAME = "daleef/my-fastai-bear", "model.pkl"
|
|
@lru_cache(maxsize=1)
def get_learner():
    """Download the exported fastai Learner from the Hub and load it.

    Memoized with ``lru_cache(maxsize=1)``, so the download and load happen
    only on the first call; later calls return the same Learner object.

    Returns:
        The Learner produced by ``load_learner`` from ``FILENAME`` in
        ``REPO_ID``, downloaded into ``/tmp/model``.
    """
    # ``local_dir_use_symlinks`` is deprecated in huggingface_hub >= 0.23.
    # It is ignored there (files in ``local_dir`` are always real files) and
    # only emits a warning, so it is omitted.
    pkl_path = hf_hub_download(
        repo_id=REPO_ID,
        filename=FILENAME,
        local_dir="/tmp/model",
    )
    # SECURITY NOTE: load_learner unpickles the file, which can execute
    # arbitrary code. Only point REPO_ID at a repository you trust.
    return load_learner(pkl_path)
|
|
def classify_image(img):
    """Classify *img* with the cached Learner and return per-class scores.

    Args:
        img: The value from the ``gr.Image(type="pil")`` input. Gradio passes
            ``None`` when the user submits without uploading an image.

    Returns:
        A dict mapping each label in ``learn.dls.vocab`` to its predicted
        probability as a plain ``float``, as consumed by ``gr.Label``.

    Raises:
        gr.Error: If no image was provided.
    """
    if img is None:
        # Without this guard, learn.predict(None) fails deep inside fastai
        # with an opaque traceback; gr.Error shows a readable UI message.
        raise gr.Error("Please upload an image first.")
    learn = get_learner()
    # predict yields three values; only the probability vector is needed to
    # build the complete label -> score mapping.
    _, _, probs = learn.predict(img)
    return {label: float(p) for label, p in zip(learn.dls.vocab, probs)}
|
|
# Gradio UI: a single image upload in, the top-3 class probabilities out.
demo = gr.Interface(
    title="Fastai Bear Classifier",
    description="Upload a bear image to classify.",
    fn=classify_image,
    inputs=gr.Image(
        label="Upload an image",
        type="pil",
        width=192,
        height=192,
    ),
    outputs=gr.Label(num_top_classes=3),
)

# Start the web server only when executed as a script, not on import.
if __name__ == "__main__":
    demo.launch()